jax.lax.dynamic_index_in_dim

jax.lax.dynamic_index_in_dim(operand, index, axis=0, keepdims=True, *, allow_negative_indices=True)

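Convenience wrapper around dynamic_slice to perform integer indexing.

This is roughly equivalent to the Python indexing expression operand[..., index], except that the index is applied along the specified axis rather than the last one. With keepdims=True (the default), the indexed axis is kept with size 1, so the result corresponds to the slice index:index + 1 along that axis.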
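The following sketch checks that equivalence on a small 2-D array by comparing dynamic_index_in_dim against an explicit size-1 lax.dynamic_slice and against ordinary slicing and indexing. The array contents and the index value are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)
i = 2

# Index along axis 0, keeping the indexed axis with size 1.
a = lax.dynamic_index_in_dim(x, i, axis=0, keepdims=True)
# The same selection written as a size-1 dynamic_slice along axis 0.
b = lax.dynamic_slice(x, (i, 0), (1, x.shape[1]))
# The same selection written with ordinary slicing.
c = x[i:i + 1]

# With keepdims=False the indexed axis is dropped, like plain integer indexing.
d = lax.dynamic_index_in_dim(x, i, axis=0, keepdims=False)
e = x[i]

print(a.shape, d.shape)  # (1, 4) (4,)
```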

Parameters:
  • operand (Array | np.ndarray) – an array to slice.

  • index (int | Array) – the (possibly dynamic) index to select along axis.

  • axis (int) – the axis along which to index (default 0).

  • keepdims (bool) – whether the output should have the same rank as the input (default True). If False, the indexed axis is removed from the result.

  • allow_negative_indices (bool) – whether negative indices are allowed (default True). If True, negative indices are taken relative to the end of the axis. If False, negative indices are out of bounds and the result is implementation-defined.
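As a minimal sketch of the default allow_negative_indices=True behavior, the snippet below selects elements counting from the end of a 1-D array, first with a Python int and then with an index held in a JAX array. The allow_negative_indices=False case is deliberately not shown, because its result is implementation-defined.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5)

# With the default allow_negative_indices=True, -1 counts from the end of the
# axis, so this selects the last element, just like x[-1].
last = lax.dynamic_index_in_dim(x, -1, keepdims=False)
print(last)  # 4

# A negative index stored in a JAX array is handled the same way.
idx = jnp.array(-2)
second_to_last = lax.dynamic_index_in_dim(x, idx, keepdims=False)
print(second_to_last)  # 3
```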

Returns:

An array containing the slice.

Return type:

Array

Examples

Here is a one-dimensional example:

>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_index_in_dim
>>> x = jnp.arange(5)
>>> dynamic_index_in_dim(x, 1)
Array([1], dtype=int32)
>>> dynamic_index_in_dim(x, 1, keepdims=False)
Array(1, dtype=int32)

Here is a two-dimensional example:

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)
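Because index may be dynamic, the function also works when the index is a traced value whose concrete number is only known at run time. The sketch below, where sum_rows is a hypothetical helper written for illustration, uses the loop counter of lax.fori_loop inside a jax.jit-compiled function to select one row per iteration and accumulate the rows.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def sum_rows(x):
    """Hypothetical helper: sums the rows of a 2-D array one row at a time."""
    def body(i, acc):
        # i is a traced loop counter, so the row index is dynamic.
        row = lax.dynamic_index_in_dim(x, i, axis=0, keepdims=False)
        return acc + row

    init = jnp.zeros(x.shape[1], dtype=x.dtype)
    return lax.fori_loop(0, x.shape[0], body, init)


x = jnp.arange(12).reshape(3, 4)
print(sum_rows(x))  # [12 15 18 21]
```

The result matches x.sum(axis=0); the point of the sketch is the traced index, not an efficient way to sum rows.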