jax.lax.index_in_dim
jax.lax.index_in_dim(operand, index, axis=0, keepdims=True)
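Convenience wrapper around lax.slice() to perform integer indexing. For a non-negative index, this is effectively equivalent to operand[..., index:index+1] when keepdims=True, or operand[..., index] when keepdims=False, with the indexing applied on the specified axis.
- Parameters:
  - operand: an array to index.
  - index: an integer index into the specified axis.
  - axis: the axis along which to apply the index (defaults to 0).
  - keepdims: whether the output preserves the rank of the input (defaults to True). When False, the indexed axis is removed.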
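- Returns:
The subarray at the specified index. With keepdims=True it has the same rank as operand and size 1 along axis; with keepdims=False, axis is removed.
- Return type:
Array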
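The "convenience wrapper around lax.slice()" relationship can be sketched directly with lax.slice and lax.squeeze. The helper name index_in_dim_sketch is hypothetical, and the sketch assumes a non-negative, in-bounds Python int index; it illustrates the idea and is not the library's actual implementation.

```python
import jax.numpy as jnp
from jax import lax


def index_in_dim_sketch(operand, index, axis=0, keepdims=True):
    """Hypothetical re-implementation of lax.index_in_dim for illustration.

    Assumes `index` is a non-negative Python int smaller than operand.shape[axis].
    """
    # Take every axis in full (start 0, limit = size)...
    start = [0] * operand.ndim
    limit = list(operand.shape)
    # ...except the indexed axis, where only [index, index + 1) is kept.
    start[axis] = index
    limit[axis] = index + 1
    result = lax.slice(operand, start, limit)
    if keepdims:
        return result  # size-1 axis retained, rank unchanged
    # Drop the size-1 axis so the rank goes down by one.
    return lax.squeeze(result, (axis,))


x = jnp.arange(12).reshape(3, 4)
row = index_in_dim_sketch(x, 1)                          # like x[1:2, :]
col = index_in_dim_sketch(x, 1, axis=1, keepdims=False)  # like x[:, 1]
print(row.shape, col.shape)  # prints (1, 4) (3,)
# Cross-check against the library function.
assert jnp.array_equal(row, lax.index_in_dim(x, 1))
```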
Examples
Here is a one-dimensional example:
>>> x = jnp.arange(4)
>>> lax.index_in_dim(x, 2)
Array([2], dtype=int32)
>>> lax.index_in_dim(x, 2, keepdims=False)
Array(2, dtype=int32)
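Because the result is built from lax.slice, whose start and limit indices are static integers, this sketch assumes the index must be a concrete Python integer at trace time. One way to use the function under jax.jit is to mark index and keepdims as static arguments; the wrapper name take_elem is hypothetical. The sketch also checks the 1-D results against basic slicing (x[2:3]) and integer indexing (x[2]).

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical wrapper: `index` and `keepdims` are static, so they stay plain
# Python values while tracing, as the slice-based indexing expects.
@partial(jax.jit, static_argnames=("index", "keepdims"))
def take_elem(x, index, keepdims=True):
    return lax.index_in_dim(x, index, axis=0, keepdims=keepdims)


x = jnp.arange(4)
kept = take_elem(x, index=2)                     # shape (1,), like x[2:3]
dropped = take_elem(x, index=2, keepdims=False)  # shape (), like x[2]
print(kept.shape, dropped.shape)  # prints (1,) ()
assert jnp.array_equal(kept, x[2:3])
assert jnp.array_equal(dropped, x[2])
```

Each distinct value of a static argument triggers a fresh trace and compilation, so this pattern suits indices that take only a few values; for an index computed at runtime, lax.dynamic_index_in_dim is the dynamic counterpart.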
Here are some two-dimensional examples:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> lax.index_in_dim(x, 1)
Array([[4, 5, 6, 7]], dtype=int32)
>>> lax.index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)
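The 2-D examples follow a general shape rule: indexing axis `axis` of an array with shape s gives shape s with s[axis] replaced by 1 (keepdims=True) or removed (keepdims=False). The sketch below checks that rule on a 3-D array, whose shape is an arbitrary choice, and compares the results against basic slicing and against jnp.take with a scalar index, which for in-bounds indices selects the same elements.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24).reshape(2, 3, 4)  # arbitrary 3-D example

for axis in range(x.ndim):
    for i in range(x.shape[axis]):
        kept = lax.index_in_dim(x, i, axis=axis)  # keepdims=True by default
        dropped = lax.index_in_dim(x, i, axis=axis, keepdims=False)

        # keepdims=True: the indexed axis shrinks to size 1.
        assert kept.shape == x.shape[:axis] + (1,) + x.shape[axis + 1:]
        # keepdims=False: the indexed axis disappears.
        assert dropped.shape == x.shape[:axis] + x.shape[axis + 1:]

        # Same elements as basic slicing i:i+1 on `axis`...
        slicer = [slice(None)] * x.ndim
        slicer[axis] = slice(i, i + 1)
        assert jnp.array_equal(kept, x[tuple(slicer)])
        # ...and as jnp.take with a scalar index on `axis`.
        assert jnp.array_equal(dropped, jnp.take(x, i, axis=axis))

print("shape rule holds for every axis and index")
```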