jax.lax.dynamic_update_slice_in_dim

jax.lax.dynamic_update_slice_in_dim(operand, update, start_index, axis, *, allow_negative_indices=True)

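Convenience wrapper around dynamic_update_slice() that updates a slice along a single axis. The update is written starting at start_index along axis and starting at index 0 along every other axis.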
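The wrapper only fills in the start indices for you. A minimal sketch of the equivalent dynamic_update_slice() call, assuming a 3×4 operand and an update along axis 1 (the shapes are arbitrary choices for illustration):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((3, 4))
y = jnp.ones((3, 2))

# Update along axis 1, starting at column 1.
a = lax.dynamic_update_slice_in_dim(x, y, 1, axis=1)

# The same update spelled out with dynamic_update_slice:
# start at 0 on axis 0 and at 1 on axis 1.
b = lax.dynamic_update_slice(x, y, (0, 1))

print(jnp.array_equal(a, b))  # True
```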

Parameters:
  • operand (Array | np.ndarray) – the array to update.

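  • update (ArrayLike) – an array containing the new values to write into operand. It must have the same number of dimensions as operand, and its size along each dimension must not exceed the operand's.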

  • start_index (ArrayLike) – a scalar index giving the start position of the update along axis.

  • axis (int) – the axis along which start_index is applied.

  • allow_negative_indices (bool) – whether negative indices are allowed. If True, negative indices are taken relative to the end of the array. If False, negative indices are out of bounds and the result is implementation defined.
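With the default allow_negative_indices=True, a negative start_index counts from the end of axis before the usual clamping is applied. A short sketch on a 6-element operand:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(6)
y = jnp.ones(3)

# -3 is interpreted as 6 - 3 = 3, so the update covers indices 3, 4 and 5.
out = lax.dynamic_update_slice_in_dim(x, y, -3, axis=0)
print(out)  # [0. 0. 0. 1. 1. 1.]
```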

Returns:

A new array with the same shape and dtype as operand, with update written at the given position.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_slice_in_dim
>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 1., 1., 0.], dtype=float32)

An update that would extend past the end of the array is not truncated; instead, the start index is clamped so that the whole update fits:

>>> dynamic_update_slice_in_dim(x, y, 3, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice_in_dim(x, y, 5, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
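The adjustment is a clamp: along axis, the effective start is limited to the range [0, dim_size - update_size]. A pure-Python sketch of that arithmetic (clamped_start is a hypothetical helper for illustration, not part of JAX) reproduces the start positions used in the examples above:

```python
def clamped_start(start_index: int, dim_size: int, update_size: int) -> int:
    """Hypothetical helper: the start position actually used along the axis."""
    # With allow_negative_indices=True, negative indices count from the end.
    if start_index < 0:
        start_index += dim_size
    # Clamp so that 0 <= start and start + update_size <= dim_size.
    return max(0, min(start_index, dim_size - update_size))


print(clamped_start(2, 6, 3))  # 2
print(clamped_start(3, 6, 3))  # 3
print(clamped_start(5, 6, 3))  # 3
```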

Here is an example of a two-dimensional slice update:

>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 4))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [0., 0., 0., 0.]], dtype=float32)
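The point of the dynamic variant is that start_index may be a traced value, such as a loop counter inside jax.lax.fori_loop or an argument to a jax.jit-compiled function, whereas NumPy-style slice bounds must be static. A minimal sketch that writes one row per iteration; fill_rows is a hypothetical helper for illustration:

```python
import jax
import jax.numpy as jnp
from jax import lax


def fill_rows(n_rows: int, n_cols: int) -> jax.Array:
    buf = jnp.zeros((n_rows, n_cols))

    def body(i, buf):
        # i is a traced integer here; it is passed directly as start_index.
        row = jnp.full((1, n_cols), i, dtype=buf.dtype)
        return lax.dynamic_update_slice_in_dim(buf, row, i, axis=0)

    # Each row i ends up filled with the value i.
    return lax.fori_loop(0, n_rows, body, buf)


print(fill_rows(4, 3))
```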

The update need not span the other axes fully: along every axis other than axis, it is written starting at index 0, so a smaller update leaves the remaining entries unchanged:

>>> y = jnp.ones((2, 3))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)
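Because the other axes start at index 0, the result above matches a static indexed update of a block anchored at column 0. A sketch that checks this with the .at[...].set(...) syntax on jax arrays, assuming a concrete (non-traced) start index:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))
y = jnp.ones((2, 3))

dynamic = lax.dynamic_update_slice_in_dim(x, y, 1, axis=0)
# Rows 1..2 along axis 0; columns 0..2 along the other axis (start 0).
static = x.at[1:3, 0:3].set(y)

print(jnp.array_equal(dynamic, static))  # True
```

The static form is only a drop-in equivalent when the start index is known at trace time and the update fits without clamping.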