jax.lax.dynamic_update_slice
- jax.lax.dynamic_update_slice(operand, update, start_indices, *, allow_negative_indices=True)
Wraps XLA’s DynamicUpdateSlice operator.
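The "dynamic" in the name refers to the start indices, which may be values known only at run time, such as traced values inside `jax.jit`. The minimal sketch below illustrates this. `write_window` is a hypothetical helper, and the printed results assume the default float32 precision.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def write_window(buf, window, offset):
    # Under jit, `offset` is a traced scalar, so the write position is only
    # known at run time. Ordinary slicing such as buf.at[offset:offset + 3]
    # requires static slice bounds and is rejected for a traced offset.
    return lax.dynamic_update_slice(buf, window, (offset,))


buf = jnp.zeros(6)
window = jnp.ones(3)
out = write_window(buf, window, 1)       # writes positions 1, 2, 3
out_late = write_window(buf, window, 4)  # start is clamped to 3 so the window fits
print(out)       # [0. 1. 1. 1. 0. 0.]
print(out_late)  # [0. 0. 0. 1. 1. 1.]
```

The offset is a runtime argument rather than part of the function's structure. The same jitted function can therefore typically be reused for different offsets without retracing.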
- Parameters:
operand (Array | np.ndarray) – the array to update.
update (ArrayLike) – an array containing the new values to write onto operand.
start_indices (Array | Sequence[ArrayLike]) – a list of scalar indices, one per dimension.
allow_negative_indices (bool | Sequence[bool]) – a bool or a sequence of bools, one per dimension; if a single bool is passed, it applies to all dimensions. For each dimension, if true, negative indices are permitted and are interpreted relative to the end of the array. If false, negative indices are treated as out of bounds and the result is implementation-defined (typically, the index is clamped to the first index).
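With the default `allow_negative_indices=True`, a negative start index counts back from the end of the dimension, much as in NumPy indexing. The sketch below checks this for both a Python integer and a traced index under `jax.jit`. It does not exercise `allow_negative_indices=False`, because that result is implementation-defined.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(6)
y = jnp.ones(2)

# -2 is interpreted as 6 - 2 = 4, so the update covers the last two entries.
eager = lax.dynamic_update_slice(x, y, (-2,))

# The same normalization applies when the index is a traced value.
traced = jax.jit(lambda i: lax.dynamic_update_slice(x, y, (i,)))(-2)

print(eager)   # [0. 0. 0. 0. 1. 1.]
print(traced)  # [0. 0. 0. 0. 1. 1.]
```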
- Returns:
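A new array with the same shape and dtype as operand, in which the region beginning at start_indices has been replaced by update.
- Return type:
Array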
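JAX arrays are immutable, so the function returns a new array and leaves `operand` untouched. The short sketch below shows this and confirms that the result keeps the operand's shape and dtype. The array sizes are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((3, 4), dtype=jnp.float32)
patch = jnp.ones((1, 2), dtype=jnp.float32)

# Write the 1x2 patch into row 2, starting at column 1.
out = lax.dynamic_update_slice(x, patch, (2, 1))

print(out.shape, out.dtype)  # (3, 4) float32
print(out[2])                # [0. 1. 1. 0.]
print(float(x.sum()))        # 0.0  (the operand is unchanged)
```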
Examples
Here is an example of a one-dimensional slice update:
>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_slice
>>> x = jnp.zeros(6)
>>> y = jnp.ones(3)
>>> dynamic_update_slice(x, y, (2,))
Array([0., 0., 1., 1., 1., 0.], dtype=float32)
If the update would extend past the end of the array, the start index is clamped so that the update fits entirely inside it. Here, any start index greater than 3 (= 6 - 3) is treated as 3:
>>> dynamic_update_slice(x, y, (3,))
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
>>> dynamic_update_slice(x, y, (5,))
Array([0., 0., 0., 1., 1., 1.], dtype=float32)
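The clamping rule for a single dimension can be modeled in plain NumPy. The model first maps a negative start to `start + n`, then clamps it into `[0, n - m]`, where `n` is the operand length and `m` is the update length. `reference_update_slice` is a hypothetical helper written only to illustrate the rule. It is not part of JAX, and it assumes the default `allow_negative_indices=True`. The loop compares it against `lax.dynamic_update_slice` for every start index from -6 to 6.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def reference_update_slice(operand, update, start):
    """Hypothetical 1-D NumPy model of the clamping behaviour (illustration only)."""
    n, m = operand.shape[0], update.shape[0]
    if start < 0:                      # negative indices count from the end
        start += n
    start = min(max(start, 0), n - m)  # clamp so the whole update fits
    out = operand.copy()
    out[start:start + m] = update
    return out


x = np.zeros(6, dtype=np.float32)
y = np.ones(3, dtype=np.float32)
for s in range(-6, 7):
    expected = reference_update_slice(x, y, s)
    actual = np.asarray(lax.dynamic_update_slice(jnp.asarray(x), jnp.asarray(y), (s,)))
    assert np.array_equal(expected, actual), s
```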
Here is an example of a two-dimensional slice update:
>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones((2, 2))
>>> dynamic_update_slice(x, y, (1, 2))
Array([[0., 0., 0., 0.],
       [0., 0., 1., 1.],
       [0., 0., 1., 1.],
       [0., 0., 0., 0.]], dtype=float32)
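A common use of multi-dimensional updates is writing one row per iteration inside `lax.fori_loop`, where the loop counter is a traced value. The sketch below fills a 4x3 matrix so that row `i` holds the value `i`. The function `fill_row` is a hypothetical name used only for illustration.

```python
import jax.numpy as jnp
from jax import lax


def fill_row(i, m):
    # Build a 1x3 row filled with the (traced) loop index i.
    row = jnp.full((1, m.shape[1]), i, dtype=m.dtype)
    # Write it at row i, column 0; i is only known at run time.
    return lax.dynamic_update_slice(m, row, (i, 0))


m = lax.fori_loop(0, 4, fill_row, jnp.zeros((4, 3)))
print(m)
# [[0. 0. 0.]
#  [1. 1. 1.]
#  [2. 2. 2.]
#  [3. 3. 3.]]
```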
See also
lax.dynamic_update_index_in_dim
lax.dynamic_update_slice_in_dim
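The two related functions are conveniences for updating along a single axis. The sketch below compares each one with an equivalent `lax.dynamic_update_slice` call that uses a start index of 0 on every other axis. It assumes that `dynamic_update_index_in_dim` accepts an update with the indexed axis removed and re-inserts that axis internally.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((3, 4))

# Update columns 1..2 along axis 1.
cols = jnp.ones((3, 2))
a = lax.dynamic_update_slice_in_dim(x, cols, 1, axis=1)
b = lax.dynamic_update_slice(x, cols, (0, 1))

# Replace row 2 along axis 0; the update has shape (4,), without the row axis.
row = jnp.ones(4)
c = lax.dynamic_update_index_in_dim(x, row, 2, axis=0)
d = lax.dynamic_update_slice(x, row[None, :], (2, 0))

print(bool((a == b).all()), bool((c == d).all()))  # True True
```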