jax.lax.scatter_apply
- jax.lax.scatter_apply(operand, scatter_indices, func, dimension_numbers, *, update_shape=(), indices_are_sorted=False, unique_indices=False, mode=None)
Scatter-apply operator.
Wraps XLA’s Scatter operator, where the values of operand at the locations given by scatter_indices are replaced with the result of applying func to them, with duplicate indices resulting in multiple applications of func.
The semantics of scatter are complicated, and its API might change in the future. For most use cases, you should prefer the jax.numpy.ndarray.at property on JAX arrays, which uses the familiar NumPy indexing syntax.
Note that in the current implementation, scatter_apply is not compatible with automatic differentiation.
- Parameters:
operand (Array) – an array to which the scatter should be applied
scatter_indices (Array) – an array that gives the indices in operand at which func should be applied.
func (Callable[[Array], Array]) – unary function that will be applied at each index.
dimension_numbers (ScatterDimensionNumbers) – a lax.ScatterDimensionNumbers object that describes how the dimensions of operand, scatter_indices, the (implicit) updates of shape update_shape, and the output relate.
update_shape (Shape) – the shape of the updates at the given indices.
indices_are_sorted (bool) – whether scatter_indices is known to be sorted. If True, this may improve performance on some backends.
unique_indices (bool) – whether the elements to be updated in operand are guaranteed not to overlap with each other. If True, this may improve performance on some backends. JAX does not check this promise: if the updated elements overlap when unique_indices is True, the behavior is undefined.
mode (str | GatherScatterMode | None) – how to handle indices that are out of bounds: when set to ‘clip’, indices are clamped so that the slice is within bounds, and when set to ‘fill’ or ‘drop’, out-of-bounds updates are dropped. The behavior for out-of-bounds indices when set to ‘promise_in_bounds’ is implementation-defined.
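A minimal sketch of the basic behavior on a 1-D array, including a repeated index. The variable names (x, idx, dnums, double) are illustrative only. It uses lax.ScatterDimensionNumbers with only its three standard fields, and compares the result against the jax.numpy.ndarray.at property's apply method, which is the recommended alternative. The expected values follow from the documented rule that duplicate indices apply func multiple times.

```python
import jax.numpy as jnp
from jax import lax

# Operand: the 1-D array [0., 1., 2., 3., 4.].
x = jnp.arange(5.0)

# Each row of `idx` is one index vector into `x`; index 3 appears twice.
idx = jnp.array([[1], [3], [3]])

# Scalar updates: no window dimensions, operand dimension 0 is "inserted"
# (window size 1), and each index vector addresses operand dimension 0.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)

double = lambda v: v * 2  # unary function applied at each scattered index

# update_shape=(3,) matches the three index vectors in `idx`.
y = lax.scatter_apply(x, idx, double, dnums, update_shape=(3,))
# y holds [0, 2, 2, 12, 4]: index 1 is doubled once, index 3 twice.

# The same update expressed through the .at property.
y_at = x.at[jnp.array([1, 3, 3])].apply(double)
```

The .at form builds equivalent dimension numbers from NumPy-style indexing, which is why it is usually the more convenient choice.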
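To illustrate how dimension_numbers and update_shape interact, here is a sketch (with hypothetical names m, rows, dnums) that applies jnp.negative to whole rows of a 2-D array. Dimension 0 of the implicit updates enumerates the index vectors, dimension 1 is a window spanning a full row, and operand dimension 0 is listed as inserted because the window has size 1 there.

```python
import jax.numpy as jnp
from jax import lax

# A 3x4 operand whose rows are [0..3], [4..7], [8..11].
m = jnp.arange(12.0).reshape(3, 4)

# Two index vectors, each selecting a row (operand dimension 0).
rows = jnp.array([[0], [2]])

# Updates of shape (2, 4): dim 0 indexes the two index vectors, dim 1 is a
# window covering operand dimension 1; operand dimension 0 is inserted.
dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(1,),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)

out = lax.scatter_apply(m, rows, jnp.negative, dnums, update_shape=(2, 4))
# Rows 0 and 2 are negated; row 1 is left unchanged.
```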
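A short sketch of the mode parameter, using a hypothetical out-of-bounds index 7 on a length-5 array. With ‘drop’ (or ‘fill’) the out-of-bounds update is discarded; with ‘clip’ the index is clamped to the last in-bounds position, 4. The names x, idx, and add_ten are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5.0)          # valid indices are 0..4
idx = jnp.array([[1], [7]])  # index 7 is out of bounds

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)
add_ten = lambda v: v + 10

# 'drop': the out-of-bounds update is discarded, so only index 1 changes.
dropped = lax.scatter_apply(x, idx, add_ten, dnums, update_shape=(2,), mode="drop")

# 'clip': index 7 is clamped to 4, so indices 1 and 4 change.
clipped = lax.scatter_apply(x, idx, add_ten, dnums, update_shape=(2,), mode="clip")
```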
- Returns:
An array containing the result of applying func to operand at the given indices.
- Return type: