jax.lax.dynamic_update_slice_in_dim
- jax.lax.dynamic_update_slice_in_dim(operand, update, start_index, axis, *, allow_negative_indices=True)
Convenience wrapper around dynamic_update_slice() to update a slice along a single axis.
- Parameters:
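operand (Array | np.ndarray) – the array to update.
update (ArrayLike) – an array containing the new values to write onto operand.
start_index (ArrayLike) – a single scalar index giving the position along axis at which the update begins.
axis (int) – the axis along which start_index is applied and the update is placed.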
allow_negative_indices (bool) – boolean specifying whether negative indices are allowed. If true, negative indices are taken relative to the end of the array. If false, negative indices are out of bounds and the result is implementation defined.
- Returns:
The updated array.
- Return type:
Array
Examples
>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_slice_in_dim
>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 1., 1., 0.], dtype=float32)
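Because start_index is an ArrayLike rather than a plain Python int, it can be a traced value, so one jitted function can write the update at a position that is only known at run time. A minimal sketch, where the function name write_at is hypothetical and not part of JAX:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def write_at(x, y, start):
    # `start` is traced here, so its concrete value is not known while tracing;
    # the same compiled function serves every start position.
    return lax.dynamic_update_slice_in_dim(x, y, start, axis=0)


x = jnp.zeros(6)
y = jnp.ones(3)
print(write_at(x, y, 2))  # [0. 0. 1. 1. 1. 0.]
print(write_at(x, y, 1))  # [0. 1. 1. 1. 0. 0.]
```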
If the update would extend past the end of the array, the start index is clamped so that the whole update fits inside it:
>>> dynamic_update_slice_in_dim(x, y, 3, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice_in_dim(x, y, 5, axis=0)
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
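The adjustment corresponds to XLA's clamping rule for dynamic update slices: the effective start along axis ends up in the range [0, operand.shape[axis] - update.shape[axis]]. The sketch below reproduces that rule for the one-dimensional case, assuming that with the default allow_negative_indices=True a negative start_index is first taken relative to the end and only then clamped. The helper effective_start is hypothetical and only models the documented behaviour; it is checked against the real function.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def effective_start(start_index, operand_len, update_len):
    # Hypothetical helper, not part of JAX.
    # Negative indices count from the end (allow_negative_indices=True).
    if start_index < 0:
        start_index += operand_len
    # Clamp so that the whole update fits inside the operand.
    return int(np.clip(start_index, 0, operand_len - update_len))


x = jnp.zeros(6)
y = jnp.ones(3)
results = []
for start in [2, 3, 5, -4, -10]:
    s = effective_start(start, x.shape[0], y.shape[0])
    # Build the expected result by hand with an ordinary NumPy slice.
    manual = np.zeros(6, dtype=np.float32)
    manual[s:s + y.shape[0]] = 1.0
    actual = np.asarray(lax.dynamic_update_slice_in_dim(x, y, start, axis=0))
    results.append((start, s, bool(np.array_equal(manual, actual))))

for row in results:
    print(row)
```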
Here is an example of a two-dimensional slice update:
>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 4))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [0., 0., 0., 0.]], dtype=float32)
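The same call works along any axis. As a small sketch (the (4, 2) update shape is chosen here for illustration and does not come from the examples above), writing two columns of ones starting at column 1 uses axis=1; along the remaining axis the update starts at row 0:

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))
y = jnp.ones((4, 2))
# Place the (4, 2) block starting at column 1.
cols = lax.dynamic_update_slice_in_dim(x, y, 1, axis=1)
print(cols)
# [[0. 1. 1. 0.]
#  [0. 1. 1. 0.]
#  [0. 1. 1. 0.]
#  [0. 1. 1. 0.]]
```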
Note that the shape of the additional axes in update need not match the associated dimensions of the operand; along those axes the update is written starting at index 0:
>>> y = jnp.ones((2, 3))
>>> dynamic_update_slice_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.]], dtype=float32)
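The behaviour above is consistent with the wrapper passing start_index on axis and 0 on every other axis to jax.lax.dynamic_update_slice. The sketch below assumes exactly that (the helper update_slice_in_dim_sketch is hypothetical, not JAX's source) and checks it against the previous example. It also shows that placing the narrower update at a non-zero column requires calling dynamic_update_slice directly with explicit start indices.

```python
import jax.numpy as jnp
from jax import lax


def update_slice_in_dim_sketch(operand, update, start_index, axis):
    # Hypothetical helper: use start_index on `axis` and 0 on every other axis.
    start_indices = [0] * operand.ndim
    start_indices[axis] = start_index
    return lax.dynamic_update_slice(operand, update, start_indices)


x = jnp.zeros((4, 4))
y = jnp.ones((2, 3))

# Should match lax.dynamic_update_slice_in_dim(x, y, 1, axis=0).
sketch = update_slice_in_dim_sketch(x, y, 1, axis=0)
reference = lax.dynamic_update_slice_in_dim(x, y, 1, axis=0)
print(bool(jnp.array_equal(sketch, reference)))  # True

# To start the columns at 1 instead of 0, give both start indices explicitly.
shifted = lax.dynamic_update_slice(x, y, (1, 1))
print(shifted)
# [[0. 0. 0. 0.]
#  [0. 1. 1. 1.]
#  [0. 1. 1. 1.]
#  [0. 0. 0. 0.]]
```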