jax.numpy.take_along_axis
- jax.numpy.take_along_axis(arr, indices, axis, mode=None, fill_value=None)
Take elements from an array.
JAX implementation of numpy.take_along_axis(), implemented in terms of jax.lax.gather(). JAX’s behavior differs from NumPy in the case of out-of-bounds indices; see the mode parameter below.
- Parameters:
arr (Array | ndarray | bool | number | bool | int | float | complex) – array from which to take values.
indices (Array | ndarray | bool | number | bool | int | float | complex) – array of integer indices. If axis is None, must be one-dimensional. If axis is not None, must have arr.ndim == indices.ndim, and arr must be broadcast-compatible with indices along dimensions other than axis.
axis (int | None) – the axis along which to take values. If None, the array is flattened before indexing is applied.
mode (str | GatherScatterMode | None) – out-of-bounds indexing mode, either "fill" or "clip". The default, mode="fill", returns invalid values (e.g. NaN) for out-of-bounds indices. For more discussion of mode options, see jax.numpy.ndarray.at.
fill_value (bool | number | bool | int | float | complex | None) – the value returned for out-of-bounds indices when mode="fill"; ignored otherwise. If None, a dtype-dependent invalid value is used (NaN for floating-point inputs).
- Returns:
Array of values extracted from arr.
- Return type:
Array
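The indices and axis rules above can be illustrated with a short sketch (the arrays are arbitrary example data): with axis=None the indices address positions in the flattened array, and with an explicit axis a size-1 index dimension broadcasts against the matching dimension of arr.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

# axis=None: x is flattened to [1., 2., 3., 4., 5., 6.] first,
# so the one-dimensional indices refer to flat positions.
flat_result = jnp.take_along_axis(x, jnp.array([5, 0, 3]), axis=None)
# flat_result -> [6., 1., 4.]

# axis=1 with indices of shape (1, 2): the size-1 leading dimension
# broadcasts against x's first dimension (size 2), so every row uses [2, 0].
row_result = jnp.take_along_axis(x, jnp.array([[2, 0]]), axis=1)
# row_result -> [[3., 1.], [6., 4.]]
```

The output shape follows the broadcast of indices.shape with arr.shape (with the axis dimension taken from indices), which is why row_result has shape (2, 2).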
See also
- jax.numpy.ndarray.at: take values via indexing syntax.
- jax.numpy.take(): take the same indices along every axis slice.
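To contrast the two functions listed above, here is a small sketch on example data: jnp.take applies one set of indices to every slice, while jnp.take_along_axis lets each slice use its own indices. Broadcasting a single row of indices makes the two coincide in this case.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

# jnp.take: the same column indices [0, 2] are applied to every row.
same_cols = jnp.take(x, jnp.array([0, 2]), axis=1)
# -> [[1., 3.], [4., 6.]]

# jnp.take_along_axis: each row gets its own column indices.
per_row = jnp.take_along_axis(x, jnp.array([[0, 2], [1, 0]]), axis=1)
# -> [[1., 3.], [5., 4.]]

# A single row of indices (shape (1, 2)) broadcasts over all rows,
# which reproduces the jnp.take result above.
broadcast = jnp.take_along_axis(x, jnp.array([[0, 2]]), axis=1)
assert jnp.array_equal(broadcast, same_cols)
```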
Examples
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([[0, 2],
...                      [1, 0]])
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1., 3.],
       [5., 4.]], dtype=float32)
>>> x[jnp.arange(2)[:, None], indices]  # equivalent via indexing syntax
Array([[1., 3.],
       [5., 4.]], dtype=float32)
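The indexing-syntax equivalence above generalizes to any axis: pair the indices with an arange along every other dimension, shaped so that it broadcasts. The helper below, take_along_axis_via_indexing, is a hypothetical illustration (not part of JAX) and assumes arr.ndim == indices.ndim and in-bounds indices; out-of-bounds handling is left to JAX's default indexing behavior.

```python
import jax.numpy as jnp


def take_along_axis_via_indexing(arr, indices, axis):
    """Hypothetical helper: emulate take_along_axis with advanced indexing.

    Assumes arr.ndim == indices.ndim and that all indices are in bounds.
    """
    axis = axis % arr.ndim  # allow negative axis values
    index_arrays = []
    for d in range(arr.ndim):
        if d == axis:
            # Along the gather axis, use the caller's indices directly.
            index_arrays.append(indices)
        else:
            # Elsewhere, use 0..n-1 shaped (1, ..., n, ..., 1) so it
            # broadcasts against indices and selects the matching slice.
            shape = [1] * arr.ndim
            shape[d] = arr.shape[d]
            index_arrays.append(jnp.arange(arr.shape[d]).reshape(shape))
    return arr[tuple(index_arrays)]


x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
result = take_along_axis_via_indexing(x, jnp.array([[0, 2], [1, 0]]), axis=1)
# result -> [[1., 3.], [5., 4.]]
```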
Out-of-bounds indices are filled with invalid values. For float inputs, this is NaN:
>>> indices = jnp.array([[1, 0, 2]])
>>> jnp.take_along_axis(x, indices, axis=0)
Array([[ 4.,  2., nan]], dtype=float32)
>>> x.at[indices, jnp.arange(3)].get(
...     mode='fill', fill_value=jnp.nan)  # equivalent via indexing syntax
Array([[ 4.,  2., nan]], dtype=float32)
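A sketch of the other out-of-bounds options on the same data: fill_value replaces the default NaN, and mode="clip" clamps the index to the last valid position instead. For comparison, numpy.take_along_axis raises an IndexError for the same indices, which is the behavioral difference mentioned at the top of this page.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])
indices = jnp.array([[1, 0, 2]])  # 2 is out of bounds along axis 0 (size 2)

filled = jnp.take_along_axis(x, indices, axis=0)                   # [[4., 2., nan]]
custom = jnp.take_along_axis(x, indices, axis=0, fill_value=-1.0)  # [[4., 2., -1.]]
clipped = jnp.take_along_axis(x, indices, axis=0, mode="clip")     # [[4., 2., 6.]]

# NumPy does not fill or clip; it raises for out-of-bounds indices.
try:
    np.take_along_axis(np.asarray(x), np.asarray(indices), axis=0)
    numpy_raised = False
except IndexError:
    numpy_raised = True
```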
take_along_axis is helpful for extracting values from multi-dimensional argsorts and arg reductions. For example, here we compute argsort() indices along an axis, and use take_along_axis to construct the sorted array:
>>> x = jnp.array([[5, 3, 4],
...                [2, 7, 6]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 2, 0],
       [0, 2, 1]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[3, 4, 5],
       [2, 6, 7]], dtype=int32)
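The same argsort indices can reorder a different array of matching shape, which is a common way to sort one array by the keys held in another. The values array below is made-up example data. Reversing the indices with jnp.flip gives a descending order; note that with tied keys this also reverses the order of the ties.

```python
import jax.numpy as jnp

keys = jnp.array([[5, 3, 4],
                  [2, 7, 6]])
values = jnp.array([[10, 20, 30],
                    [40, 50, 60]])

order = jnp.argsort(keys, axis=1)  # [[1, 2, 0], [0, 2, 1]]

# Sort each row of `values` by the corresponding row of `keys`.
values_by_key = jnp.take_along_axis(values, order, axis=1)
# -> [[20, 30, 10], [40, 60, 50]]

# Descending order: reverse the argsort indices along the same axis.
desc = jnp.take_along_axis(keys, jnp.flip(order, axis=1), axis=1)
# -> [[5, 4, 3], [7, 6, 2]]
```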
Similarly, we can use argmin() with keepdims=True and use take_along_axis to extract the minimum value in each row:
>>> idx = jnp.argmin(x, axis=1, keepdims=True)
>>> idx
Array([[1],
       [0]], dtype=int32)
>>> jnp.take_along_axis(x, idx, axis=1)
Array([[3],
       [2]], dtype=int32)
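Two further patterns in the same spirit, sketched on the array above: slicing the argsort indices keeps the k smallest entries per row (k=2 here is an arbitrary choice), and argmax() with keepdims=True along axis 0 recovers column-wise maxima, which should match jnp.max with keepdims=True.

```python
import jax.numpy as jnp

x = jnp.array([[5, 3, 4],
               [2, 7, 6]])

# k smallest entries per row: keep the first k columns of the argsort.
k = 2
smallest_idx = jnp.argsort(x, axis=1)[:, :k]             # [[1, 2], [0, 2]]
smallest = jnp.take_along_axis(x, smallest_idx, axis=1)  # [[3, 4], [2, 6]]

# Column-wise maxima via argmax along axis 0 with keepdims=True.
max_idx = jnp.argmax(x, axis=0, keepdims=True)   # [[0, 1, 1]]
col_max = jnp.take_along_axis(x, max_idx, axis=0)  # [[5, 7, 6]]
```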