jax.lax.conv_dimension_numbers
- jax.lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)
Converts convolution dimension_numbers to a ConvDimensionNumbers.
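The sketch below is a minimal illustration that runs each accepted form of dimension_numbers through the converter and then passes the result to lax.conv_general_dilated. The NHWC input and HWIO kernel layouts and the concrete shapes are hypothetical choices made for the example. For the conversion itself, only the number of dimensions in lhs_shape and rhs_shape is checked, not their sizes.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical shapes: an NHWC image batch and an HWIO kernel.
lhs_shape = (2, 8, 8, 3)  # (batch, height, width, in_channels)
rhs_shape = (3, 3, 3, 4)  # (kernel_h, kernel_w, in_channels, out_channels)

# String form: one layout string each for the input, kernel and output.
dn = lax.conv_dimension_numbers(lhs_shape, rhs_shape, ("NHWC", "HWIO", "NHWC"))
# Each spec lists axis indices in canonical order:
#   lhs/out: (batch, feature, spatial...) -> (0, 3, 1, 2)
#   rhs:     (out_feature, in_feature, spatial...) -> (3, 2, 0, 1)
print(dn)

# None: every spec is the identity order (0, 1, 2, 3), i.e. the
# default ('NCHW', 'OIHW', 'NCHW') layouts.
default_dn = lax.conv_dimension_numbers(lhs_shape, rhs_shape, None)

# An existing ConvDimensionNumbers is returned unchanged.
same_dn = lax.conv_dimension_numbers(lhs_shape, rhs_shape, dn)

# The canonical object can be passed straight to a lax convolution.
lhs = jnp.ones(lhs_shape)
rhs = jnp.ones(rhs_shape)
out = lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding="SAME", dimension_numbers=dn
)
print(out.shape)  # (2, 8, 8, 4): NHWC output with 4 channels
```

Because every spec starts with the batch (or output-feature) axis and then the feature (or input-feature) axis, the NHWC input maps to (0, 3, 1, 2): batch is axis 0, channels are axis 3, and the spatial axes are 1 and 2.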
- Parameters:
lhs_shape – tuple of nonnegative integers, shape of the convolution input.
rhs_shape – tuple of nonnegative integers, shape of the convolution kernel.
dimension_numbers – None, a length-3 tuple/list of strings giving the input, kernel and output layouts (e.g. ('NHWC', 'HWIO', 'NHWC')), or a ConvDimensionNumbers object.
- Returns:
A ConvDimensionNumbers object that represents dimension_numbers in the canonical form used by lax functions.
- Return type: