jax.numpy.put_along_axis
- jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)
Put values into the destination array by matching 1d index and data slices.
JAX implementation of numpy.put_along_axis(). The semantics of numpy.put_along_axis() are to modify arrays in-place, which is not possible for JAX’s immutable arrays. The JAX version instead returns a modified copy of the input, and adds the inplace parameter, which the user must set to False as a reminder of this API difference.
- Parameters:
arr (Array | ndarray | bool | number | bool | int | float | complex) – array into which values will be put.
indices (Array | ndarray | bool | number | bool | int | float | complex) – array of indices at which to put values.
values (Array | ndarray | bool | number | bool | int | float | complex) – array of values to put into the array.
axis (int | None) – the axis along which to put values. If None, the array is flattened before indexing is applied.
inplace (bool) – must be set to False to indicate that the input is not modified in-place, but rather a modified copy is returned.
mode (str | None) – Out-of-bounds indexing mode. For more discussion of mode options, see jax.numpy.ndarray.at.
- Returns:
A copy of arr with the specified entries updated.
- Return type:
Array
See also
- jax.numpy.put(): put elements into an array at given indices.
- jax.numpy.place(): place elements into an array via boolean mask.
- jax.numpy.ndarray.at(): array updates using NumPy-style indexing.
- jax.numpy.take(): extract values from an array at given indices.
- jax.numpy.take_along_axis(): extract values from an array along an axis.
Examples
>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> print(i)
[[1]
 [0]]
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
>>> print(b)
[[10 99 20]
 [99 40 50]]