jax.lax.conv_dimension_numbers

jax.lax.conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers)

Converts a convolution dimension_numbers specification to a ConvDimensionNumbers object.

Parameters:
  • lhs_shape – tuple of nonnegative integers, shape of the convolution input.

  • rhs_shape – tuple of nonnegative integers, shape of the convolution kernel. Must have the same number of dimensions as lhs_shape.

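  • dimension_numbers – None, a tuple or list of three layout strings for the input, kernel, and output (for example ('NHWC', 'HWIO', 'NHWC')), or a ConvDimensionNumbers object. If None, the default layout ('NCHW', 'OIHW', 'NCHW') is used (generalized to the number of spatial dimensions). A ConvDimensionNumbers object is returned unchanged.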
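The three accepted forms of dimension_numbers can be sketched as follows. This is a minimal illustration, assuming a 2-D convolution over a batch of 8x8 RGB images with 3x3 kernels and 16 output features; the shapes are arbitrary choices for the example.

```python
from jax import lax

lhs_shape = (1, 8, 8, 3)    # NHWC: batch, height, width, channels
rhs_shape = (3, 3, 3, 16)   # HWIO: kernel height, kernel width, in features, out features

# 1. A tuple of layout strings for (input, kernel, output).
dn = lax.conv_dimension_numbers(lhs_shape, rhs_shape, ('NHWC', 'HWIO', 'NHWC'))
print(dn)
# ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))

# 2. None: the default NCHW / OIHW / NCHW layout, i.e. identity orderings.
dn_default = lax.conv_dimension_numbers((1, 3, 8, 8), (16, 3, 3, 3), None)
print(dn_default)
# ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(0, 1, 2, 3), out_spec=(0, 1, 2, 3))

# 3. An existing ConvDimensionNumbers object is passed through unchanged.
dn_again = lax.conv_dimension_numbers(lhs_shape, rhs_shape, dn)
assert dn_again == dn
```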

Returns:

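A ConvDimensionNumbers object (a named tuple with fields lhs_spec, rhs_spec, and out_spec) that represents dimension_numbers in the canonical form used by lax functions. Each field is a tuple of dimension indices: lhs_spec and out_spec list the batch dimension, then the feature dimension, then the spatial dimensions; rhs_spec lists the output-feature dimension, then the input-feature dimension, then the spatial dimensions.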
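One way to read the canonical form is as a permutation: transposing an array by the corresponding spec moves it into the default NCHW (or OIHW) order. The sketch below assumes an NHWC input and an HWIO kernel with deliberately non-square spatial sizes so the reordering is visible; the shapes are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

img = jnp.zeros((1, 5, 7, 3))      # NHWC input: H=5, W=7, C=3
kernel = jnp.zeros((2, 4, 3, 16))  # HWIO kernel: H=2, W=4, I=3, O=16

dn = lax.conv_dimension_numbers(img.shape, kernel.shape, ('NHWC', 'HWIO', 'NHWC'))

# lhs_spec = (batch, feature, spatial...) indices, so this yields NCHW order.
nchw = jnp.transpose(img, dn.lhs_spec)
print(nchw.shape)   # (1, 3, 5, 7)

# rhs_spec = (out feature, in feature, spatial...) indices, so this yields OIHW order.
oihw = jnp.transpose(kernel, dn.rhs_spec)
print(oihw.shape)   # (16, 3, 2, 4)
```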

Return type:

ConvDimensionNumbers
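The returned object can be passed as the dimension_numbers argument of lax.conv_general_dilated. The sketch below assumes a stride of 1 and 'SAME' padding, both arbitrary choices for the example. Because lax.conv_general_dilated also accepts the layout strings directly, the two calls should agree; precomputing the object may simply be convenient when the same layout is reused across many calls.

```python
import jax.numpy as jnp
from jax import lax

img = jnp.ones((1, 8, 8, 3))       # NHWC input
kernel = jnp.ones((3, 3, 3, 16))   # HWIO kernel

dn = lax.conv_dimension_numbers(img.shape, kernel.shape, ('NHWC', 'HWIO', 'NHWC'))

# Convolution using the precomputed ConvDimensionNumbers.
out = lax.conv_general_dilated(
    img, kernel,
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=dn,
)

# Same convolution, letting lax convert the layout strings itself.
out_from_strings = lax.conv_general_dilated(
    img, kernel,
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)

print(out.shape)  # (1, 8, 8, 16), output is laid out as NHWC
```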