jax.numpy.take_along_axis

jax.numpy.take_along_axis(arr, indices, axis, mode=None, fill_value=None)

Take elements from an array along a given axis, using an array of indices.

JAX implementation of numpy.take_along_axis(), implemented in terms of jax.lax.gather(). JAX’s behavior differs from NumPy in the case of out-of-bound indices; see the mode parameter below.

Parameters:
  • arr (Array | ndarray | bool | number | int | float | complex) – array from which to take values.

  • indices (Array | ndarray | bool | number | int | float | complex) – array of integer indices. If axis is None, must be one-dimensional. If axis is not None, must satisfy arr.ndim == indices.ndim, and arr must be broadcast-compatible with indices along all dimensions other than axis.

  • axis (int | None) – the axis along which to take values. If None, arr is flattened before indexing is applied.

  • mode (str | GatherScatterMode | None) – out-of-bounds indexing mode, either "fill" or "clip". The default, mode="fill", returns invalid values (e.g. NaN) for out-of-bounds indices. For more discussion of mode options, see jax.numpy.ndarray.at.

  • fill_value (bool | number | int | float | complex | None) – the value returned for out-of-bounds indices when mode="fill". If not specified, an invalid value such as NaN (for floating-point inputs) is used.
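A short sketch of two behaviours described in the parameter list: with axis=None, arr is flattened and indices must be one-dimensional; with an integer axis, indices may have size 1 in a non-axis dimension and is then broadcast against arr. The arrays are arbitrary illustration data.

```python
import jax.numpy as jnp

x = jnp.array([[10, 20, 30],
               [40, 50, 60]])

# axis=None: x is flattened to [10, 20, 30, 40, 50, 60] before indexing,
# so the one-dimensional indices address positions in the flattened array.
flat = jnp.take_along_axis(x, jnp.array([5, 0, 3]), axis=None)
print(flat)  # [60 10 40]

# axis=1 with indices of shape (1, 2): the size-1 leading dimension is
# broadcast against x, so the same column picks apply to every row.
cols = jnp.take_along_axis(x, jnp.array([[2, 0]]), axis=1)
print(cols)  # [[30 10]
             #  [60 40]]
```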

Returns:

Array of values extracted from arr.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([[0, 2],
...                      [1, 0]])
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1., 3.],
       [5., 4.]], dtype=float32)
>>> x[jnp.arange(2)[:, None], indices]  # equivalent via indexing syntax
Array([[1., 3.],
       [5., 4.]], dtype=float32)
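The indexing-syntax equivalent above generalizes to any number of dimensions and any axis: place indices in the axis slot and a broadcastable arange in every other slot. The helper take_along_axis_via_indexing below is hypothetical, written only to illustrate this equivalence, and assumes every index is in bounds (it does not implement the mode or fill_value behaviour).

```python
import jax.numpy as jnp


def take_along_axis_via_indexing(x, indices, axis):
    """Hypothetical helper: reproduce take_along_axis with advanced indexing.

    Assumes x.ndim == indices.ndim and that every index is in bounds.
    """
    axis = axis % x.ndim  # allow negative axis values
    index_tuple = []
    for dim in range(x.ndim):
        if dim == axis:
            # The caller's indices select positions along `axis`.
            index_tuple.append(indices)
        else:
            # An arange over `dim`, shaped to broadcast against the other
            # dims, keeps each output element at the same position along `dim`.
            shape = [1] * x.ndim
            shape[dim] = x.shape[dim]
            index_tuple.append(jnp.arange(x.shape[dim]).reshape(shape))
    return x[tuple(index_tuple)]


x3 = jnp.arange(24).reshape(2, 3, 4)
idx3 = jnp.arange(16).reshape(2, 2, 4) % 3  # values 0..2, valid along axis 1

print(jnp.array_equal(jnp.take_along_axis(x3, idx3, axis=1),
                      take_along_axis_via_indexing(x3, idx3, axis=1)))  # True
```

Each output element result[i, j, k] is x3[i, idx3[i, j, k], k], which is exactly the gather that take_along_axis performs along axis 1.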

With the default mode="fill", out-of-bound indices produce invalid values; for float inputs, the invalid value is NaN:

>>> indices = jnp.array([[1, 0, 2]])
>>> jnp.take_along_axis(x, indices, axis=0)
Array([[ 4.,  2., nan]], dtype=float32)
>>> x.at[indices, jnp.arange(3)].get(
...     mode='fill', fill_value=jnp.nan)  # equivalent via indexing syntax
Array([[ 4.,  2., nan]], dtype=float32)
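The mode and fill_value parameters control what out-of-bound positions produce. This sketch reuses the arrays from the example above: with mode="clip" the out-of-bound index 2 along axis 0 is clamped to the last valid row (index 1), and with mode="fill" an explicit fill_value replaces NaN.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
indices = jnp.array([[1, 0, 2]])  # 2 is out of bounds along axis 0 (size 2)

# mode="clip": out-of-bound indices are clamped into range, so 2 becomes 1
# and the last column reads x[1, 2] == 6.
clipped = jnp.take_along_axis(x, indices, axis=0, mode="clip")

# mode="fill" with an explicit fill_value used instead of NaN.
filled = jnp.take_along_axis(x, indices, axis=0, mode="fill", fill_value=-1.0)

print(clipped)  # [[4. 2. 6.]]
print(filled)   # [[ 4.  2. -1.]]
```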

take_along_axis is helpful for extracting values from multi-dimensional argsorts and arg reductions. For example, here we compute argsort() indices along an axis and use take_along_axis to construct the sorted array:

>>> x = jnp.array([[5, 3, 4],
...                [2, 7, 6]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 2, 0],
       [0, 2, 1]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[3, 4, 5],
       [2, 6, 7]], dtype=int32)
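Because the argsort indices are an ordinary array, they can also reorder a second array of the same shape, for instance sorting per-row labels by their scores. In this sketch, scores and labels are illustrative data, and descending order is obtained by argsorting the negated scores (a choice made here that suits floats and signed integers).

```python
import jax.numpy as jnp

scores = jnp.array([[0.3, 0.9, 0.1],
                    [0.5, 0.2, 0.8]])
labels = jnp.array([[10, 11, 12],
                    [20, 21, 22]])

# Descending order per row: argsort the negated scores.
order = jnp.argsort(-scores, axis=1)  # [[1, 0, 2], [2, 0, 1]]

# The same indices reorder both arrays consistently, row by row.
sorted_scores = jnp.take_along_axis(scores, order, axis=1)
sorted_labels = jnp.take_along_axis(labels, order, axis=1)

print(sorted_labels)  # [[11 10 12]
                      #  [22 20 21]]
```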

Similarly, we can use argmin() with keepdims=True, so that the indices keep the same number of dimensions as x, and then use take_along_axis to extract the minimum value of each row:

>>> idx = jnp.argmin(x, axis=1, keepdims=True)
>>> idx
Array([[1],
       [0]], dtype=int32)
>>> jnp.take_along_axis(x, idx, axis=1)
Array([[3],
       [2]], dtype=int32)
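As a sanity check, the values extracted this way should match a direct reduction. This sketch compares the argmin result against jnp.min, then applies the same recipe with argmax along axis 0 to pick each column's maximum.

```python
import jax.numpy as jnp

x = jnp.array([[5, 3, 4],
               [2, 7, 6]])

# Row-wise minimum via argmin + take_along_axis matches jnp.min directly.
idx = jnp.argmin(x, axis=1, keepdims=True)
row_min = jnp.take_along_axis(x, idx, axis=1)
assert bool(jnp.array_equal(row_min, jnp.min(x, axis=1, keepdims=True)))

# Column-wise maximum: argmax along axis 0 gives indices of shape (1, 3).
col_idx = jnp.argmax(x, axis=0, keepdims=True)     # [[0, 1, 1]]
col_max = jnp.take_along_axis(x, col_idx, axis=0)  # [[5, 7, 6]]
print(col_max)
```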