jax.numpy.take
jax.numpy.take(a, indices, axis=None, out=None, mode=None, unique_indices=False, indices_are_sorted=False, fill_value=None)
Take elements from an array.
JAX implementation of numpy.take(), implemented in terms of jax.lax.gather(). JAX’s behavior differs from NumPy in the case of out-of-bound indices (NumPy's default mode="raise" raises an IndexError); see the mode parameter below.
Parameters:
a (Array | ndarray | bool | number | bool | int | float | complex) – array from which to take values.
indices (Array | ndarray | bool | number | bool | int | float | complex) – N-dimensional array of integer indices of values to take from the array.
axis (int | None) – the axis along which to take values. If not specified, the array will be flattened before indexing is applied.
mode (str | None) – Out-of-bounds indexing mode, either "fill" or "clip". The default mode="fill" returns invalid values (e.g. NaN) for out-of-bounds indices; the fill_value argument gives control over this value. For more discussion of mode options, see jax.numpy.ndarray.at.
fill_value (bool | number | bool | int | float | complex | None) – The fill value to return for out-of-bounds indices when mode is ‘fill’. Ignored otherwise. Defaults to NaN for inexact types, the most negative value for signed types, the largest positive value for unsigned types, and True for booleans.
unique_indices (bool) – If True, the implementation will assume that the indices are unique, which can result in more efficient execution on some backends. If set to True and indices are not unique, the output is undefined.
indices_are_sorted (bool) – If True, the implementation will assume that the indices are sorted in ascending order, which can lead to more efficient execution on some backends. If set to True and indices are not sorted, the output is undefined.
out (None) – Not supported; must be None. Present for compatibility with the NumPy signature.
Returns:
Array of values extracted from a.
Return type:
Array
See also
jax.numpy.ndarray.at: take values via indexing syntax.
jax.numpy.take_along_axis(): take values along an axis.
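To illustrate how jnp.take differs from jnp.take_along_axis, here is a small sketch with arbitrary values: take applies one list of indices to every slice along the axis, while take_along_axis expects indices with the same number of dimensions as the array, so each row can select its own elements. The variable names are chosen only for illustration.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

# jnp.take: the same indices [2, 0] are applied to every row of x.
same_for_all = jnp.take(x, jnp.array([2, 0]), axis=1)
# -> [[3., 1.], [6., 4.]]

# jnp.take_along_axis: indices have the same ndim as x,
# so row i of the result is built from row i of per_row_idx.
per_row_idx = jnp.array([[2, 0],
                         [1, 1]])
per_row = jnp.take_along_axis(x, per_row_idx, axis=1)
# -> [[3., 1.], [5., 5.]]
```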
Examples
>>> x = jnp.array([[1., 2., 3.],
...                [4., 5., 6.]])
>>> indices = jnp.array([2, 0])
Passing no axis results in indexing into the flattened array:
>>> jnp.take(x, indices)
Array([3., 1.], dtype=float32)
>>> x.ravel()[indices]  # equivalent indexing syntax
Array([3., 1.], dtype=float32)
Passing an axis results in applying the index to every subarray along the axis:
>>> jnp.take(x, indices, axis=1)
Array([[3., 1.],
       [6., 4.]], dtype=float32)
>>> x[:, indices]  # equivalent indexing syntax
Array([[3., 1.],
       [6., 4.]], dtype=float32)
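Because indices may be N-dimensional, the result's shape follows the same rule as numpy.take(): the indexed axis is replaced by indices.shape, giving a.shape[:axis] + indices.shape + a.shape[axis+1:]. The sketch below checks this with an arbitrary 2-D index array (idx is a name chosen for illustration).

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

# A 2-D array of indices into axis 1 (which has length 3).
idx = jnp.array([[0, 2],
                 [1, 1]])

out = jnp.take(x, idx, axis=1)
# Axis 1 is replaced by idx.shape:
# x.shape[:1] + idx.shape + x.shape[2:] == (2,) + (2, 2) + () == (2, 2, 2)
print(out.shape)  # (2, 2, 2)

# The same selection written with indexing syntax on the chosen axis.
print(bool(jnp.array_equal(out, x[:, idx])))  # True

# With axis=None, x is flattened first, so the result simply has idx's shape.
# x.ravel() is [1, 2, 3, 4, 5, 6], so flat is [[1., 3.], [2., 2.]].
flat = jnp.take(x, idx)
```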
Out-of-bound indices are filled with invalid values. For float inputs, this is NaN:
>>> jnp.take(x, indices, axis=0)
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)
>>> x.at[indices].get(mode='fill', fill_value=jnp.nan)  # equivalent indexing syntax
Array([[nan, nan, nan],
       [ 1.,  2.,  3.]], dtype=float32)
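For non-float inputs the default fill depends on the dtype, as described under fill_value. A minimal sketch with made-up values: a signed integer array is filled with the most negative value of its dtype, a boolean array with True, and an explicit fill_value overrides either default. Only non-negative out-of-range indices are used here, to keep the example about the fill behavior alone.

```python
import jax.numpy as jnp

xi = jnp.array([10, 20, 30])   # signed integer array
idx = jnp.array([0, 5])        # 5 is out of bounds for an axis of length 3

# Default mode="fill": for signed integers the fill is the most negative
# value of the dtype, i.e. jnp.iinfo(xi.dtype).min.
default_fill = jnp.take(xi, idx)

# An explicit fill_value overrides the dtype-dependent default.
custom_fill = jnp.take(xi, idx, fill_value=-1)

# For booleans the default fill is True.
bool_fill = jnp.take(jnp.array([False, False]), jnp.array([3]))

print(default_fill, custom_fill, bool_fill)
```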
This default out-of-bound behavior can be adjusted using the mode parameter; for example, we can instead clip to the last valid value:

>>> jnp.take(x, indices, axis=0, mode='clip')
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
>>> x.at[indices].get(mode='clip')  # equivalent indexing syntax
Array([[4., 5., 6.],
       [1., 2., 3.]], dtype=float32)
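The unique_indices and indices_are_sorted flags are performance hints rather than behavior changes. A minimal sketch of passing them inside a jax.jit-compiled function, assuming the caller can guarantee that the indices are sorted and unique (gather_rows is a hypothetical helper name). Whether the hints make execution faster depends on the backend; if the promise does not hold, the output is undefined.

```python
import jax
import jax.numpy as jnp

@jax.jit
def gather_rows(table, rows):
    # Hypothetical helper. The flags promise that `rows` is sorted in
    # ascending order and has no duplicates; they are only valid if true.
    return jnp.take(table, rows, axis=0,
                    unique_indices=True, indices_are_sorted=True)

table = jnp.arange(12.0).reshape(4, 3)
rows = jnp.array([0, 2, 3])           # sorted and unique, as promised

fast = gather_rows(table, rows)
plain = jnp.take(table, rows, axis=0)  # same gather without the hints
print(bool(jnp.array_equal(fast, plain)))  # True
```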