jax.lax.gather#
- jax.lax.gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)[source]#
Gather operator.
Wraps XLA’s Gather operator.
gather()
is a low-level operator with complicated semantics, and most JAX users will never need to call it directly. Instead, you should prefer using Numpy-style indexing, and/orjax.numpy.ndarray.at()
, perhaps in combination withjax.vmap()
.- Parameters:
operand (ArrayLike) – an array from which slices should be taken
start_indices (ArrayLike) – the indices at which slices should be taken
dimension_numbers (GatherDimensionNumbers) – a lax.GatherDimensionNumbers object that describes how dimensions of operand, start_indices and the output relate.
slice_sizes (Shape) – the size of each slice. Must be a sequence of non-negative integers with length equal to ndim(operand).
indices_are_sorted (bool) – whether indices is known to be sorted. If true, may improve performance on some backends.
unique_indices (bool) – whether the elements gathered from
operand
are guaranteed not to overlap with each other. IfTrue
, this may improve performance on some backends. JAX does not check this promise: if the elements overlap the behavior is undefined.mode (str | GatherScatterMode | None | None) – how to handle indices that are out of bounds: when set to
'clip'
, indices are clamped so that the slice is within bounds, and when set to'fill'
or'drop'
gather returns a slice full offill_value
for the affected slice. The behavior for out-of-bounds indices when set to'promise_in_bounds'
is implementation-defined.fill_value – the fill value to return for out-of-bounds slices when mode is
'fill'
. Ignored otherwise. Defaults toNaN
for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, andTrue
for booleans.
- Returns:
An array containing the gather output.
- Return type:
Examples
As mentioned above, you should basically never use
gather()
directly, and instead use NumPy-style indexing expressions to gather values from arrays.For example, here is how you can extract values at particular indices using straightforward indexing semantics, which will lower to XLA’s Gather operator:
>>> import jax.numpy as jnp >>> x = jnp.array([10, 11, 12]) >>> indices = jnp.array([0, 1, 1, 2, 2, 2])
>>> x[indices] Array([10, 11, 11, 12, 12, 12], dtype=int32)
For control over settings like
indices_are_sorted
,unique_indices
,mode
, andfill_value
, you can use thejax.numpy.ndarray.at
syntax:>>> x.at[indices].get(indices_are_sorted=True, mode="promise_in_bounds") Array([10, 11, 11, 12, 12, 12], dtype=int32)
By comparison, here is the equivalent function call using
gather()
directly, which is not something typical users should ever need to do:>>> from jax import lax >>> lax.gather(x, indices[:, None], slice_sizes=(1,), ... dimension_numbers=lax.GatherDimensionNumbers( ... offset_dims=(), ... collapsed_slice_dims=(0,), ... start_index_map=(0,)), ... indices_are_sorted=True, ... mode=lax.GatherScatterMode.PROMISE_IN_BOUNDS) Array([10, 11, 11, 12, 12, 12], dtype=int32)