jax.numpy.put_along_axis

jax.numpy.put_along_axis(arr, indices, values, axis, inplace=True, *, mode=None)

Put values into the destination array by matching 1d index and data slices.

JAX implementation of numpy.put_along_axis().

The semantics of numpy.put_along_axis() are to modify arrays in place, which is not possible for JAX’s immutable arrays. The JAX version instead returns a modified copy of the input, and adds the inplace parameter, which must be set to False by the user as a reminder of this API difference.

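Parameters:

arr: array into which values will be put.
indices: array of indices at which to put values.
values: array of values to put into the array.
axis: the axis along which to put values. If None, the array is flattened before indexing is applied.
inplace: must be set to False, to indicate that the input is not modified in place and a modified copy is returned instead.
mode: out-of-bounds indexing mode. For more discussion of the mode options, see jax.numpy.ndarray.at.

Returns:

A copy of arr with the specified entries updated.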

Return type:

Array

See also

Examples

>>> from jax import numpy as jnp
>>> a = jnp.array([[10, 30, 20], [60, 40, 50]])
>>> i = jnp.argmax(a, axis=1, keepdims=True)
>>> print(i)
[[1]
 [0]]
>>> b = jnp.put_along_axis(a, i, 99, axis=1, inplace=False)
>>> print(b)
[[10 99 20]
 [99 40 50]]