jax.lax.dynamic_update_index_in_dim

jax.lax.dynamic_update_index_in_dim(operand, update, index, axis, *, allow_negative_indices=True)

Convenience wrapper around dynamic_update_slice() to update a slice of size 1 along a single axis.
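To make the "convenience wrapper" relationship concrete, here is a small sketch comparing the wrapper with an explicit dynamic_update_slice() call. It assumes that an update one rank lower than operand gets a size-1 axis inserted at axis, as the examples on this page suggest. The two calls are expected to agree here, but this is an illustration, not the library's actual implementation.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((3, 4))
row = jnp.arange(4.0)  # shape (4,): one rank lower than x

# Via the wrapper: write `row` into row 2 of `x` (axis 0).
a = lax.dynamic_update_index_in_dim(x, row, 2, axis=0)

# A roughly equivalent explicit call: add the size-1 axis by hand and give
# one start index per dimension (0 for the axis that is not being indexed).
b = lax.dynamic_update_slice(x, row[None, :], (2, 0))

# a[2] is now [0., 1., 2., 3.]; rows 0 and 1 stay zero.
print(jnp.array_equal(a, b))  # True
```

The wrapper saves you from building the full start-index tuple and from reshaping the update yourself.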

Parameters:
  • operand (Array | np.ndarray) – the array to update.

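  • update (ArrayLike) – an array containing the new values to write onto operand. It may have the same rank as operand, or one fewer dimension, in which case a size-1 axis is inserted at axis.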

  • index (ArrayLike) – a single scalar index into axis.

  • axis (int) – the axis of the update.

  • allow_negative_indices (bool) – whether negative indices are allowed. If True, negative indices are taken relative to the end of the array. If False, negative indices are out of bounds and the result is implementation-defined.

Returns:

A new array with the same shape as operand, in which the slice at index along axis has been replaced by update.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_index_in_dim
>>> x = jnp.zeros(6)
>>> y = 1.0
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
>>> y = jnp.array([1.0])
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)

If the specified index is out of bounds, the index will be clipped to the valid range:

>>> dynamic_update_index_in_dim(x, y, 10, axis=0)
Array([0., 0., 0., 0., 0., 1.], dtype=float32)
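Because index is a run-time value, the same clamping should apply when it is traced, for example inside jax.jit. The helper set_one below is a hypothetical name used only for illustration; the sketch assumes default settings (float32 arrays, allow_negative_indices=True).

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def set_one(x, i):
    # Hypothetical helper: `i` is traced under jit, so its value is only
    # known at run time; clamping keeps the size-1 update inside `x`.
    one = jnp.ones((), dtype=x.dtype)
    return lax.dynamic_update_index_in_dim(x, one, i, axis=0)

x = jnp.zeros(6)
print(set_one(x, 3))   # [0. 0. 0. 1. 0. 0.]
print(set_one(x, 10))  # clamped to the last valid position: [0. 0. 0. 0. 0. 1.]
```

Both calls should reuse a single compiled function, since only the value of i changes between them, not its shape or dtype.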

Here is an example of a two-dimensional dynamic index update:

>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones(4)
>>> dynamic_update_index_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 1.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
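The same call can target a column by choosing axis=1. This sketch assumes, as in the example above, that a rank-1 update is given a size-1 axis at position axis, so ones(4) is written as a (4, 1) column.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))
col = jnp.ones(4)

# Write `col` into column 1 (axis=1 instead of axis=0).
out = lax.dynamic_update_index_in_dim(x, col, 1, axis=1)
print(out)
# [[0. 1. 0. 0.]
#  [0. 1. 0. 0.]
#  [0. 1. 0. 0.]
#  [0. 1. 0. 0.]]
```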

Note that the remaining axes of update need not match the corresponding dimensions of operand; a smaller update is written starting at index 0 along each of those axes:

>>> y = jnp.ones((1, 3))
>>> dynamic_update_index_in_dim(x, y, 1, 0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
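A common reason to reach for this function is writing one slice per iteration of a loop whose counter is traced, such as jax.lax.fori_loop. The helper fill_rows is a hypothetical name for illustration; the sketch fills row i of a zero matrix with the value i.

```python
import jax.numpy as jnp
from jax import lax

def fill_rows(n_rows, n_cols):
    # Hypothetical helper: returns an (n_rows, n_cols) float32 array whose
    # row i is filled with the value i.
    def body(i, acc):
        # `i` is a traced loop counter, so the row being written is dynamic.
        row = jnp.full((n_cols,), i, dtype=acc.dtype)
        return lax.dynamic_update_index_in_dim(acc, row, i, axis=0)

    init = jnp.zeros((n_rows, n_cols))
    return lax.fori_loop(0, n_rows, body, init)

print(fill_rows(3, 2))
# [[0. 0.]
#  [1. 1.]
#  [2. 2.]]
```

Because row has one fewer axis than acc, the loop body needs no explicit reshape before the update.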