jax.lax.dynamic_update_index_in_dim
- jax.lax.dynamic_update_index_in_dim(operand, update, index, axis, *, allow_negative_indices=True)
Convenience wrapper around dynamic_update_slice() to update a slice of size 1 in a single axis.
- Parameters:
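operand (Array | np.ndarray) – the array to update.
update (ArrayLike) – an array containing the new values to write into operand. Its rank must equal that of operand, or be one less, in which case a size-1 axis is inserted at position axis.
index (ArrayLike) – a single scalar index giving the position along axis at which update is written.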
axis (int) – the axis of the update.
allow_negative_indices (bool) – boolean specifying whether negative indices are allowed. If True, negative indices are interpreted relative to the end of the array. If False, negative indices are treated as out of bounds and the result is implementation-defined.
- Returns:
The updated array.
- Return type:
Array
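As a rough illustration of the "convenience wrapper" relationship, the sketch below rebuilds the behavior on top of lax.dynamic_update_slice(). The helper name update_index_in_dim_sketch is hypothetical; it omits the allow_negative_indices option and is not JAX's actual implementation. It shows the two steps involved: inserting a size-1 axis when update has one fewer dimension than operand, and starting the slice at index along axis and at 0 along every other axis.

```python
import jax.numpy as jnp
from jax import lax


def update_index_in_dim_sketch(operand, update, index, axis):
    """Hypothetical re-implementation, for illustration only.

    Omits the allow_negative_indices option of the real function.
    """
    operand = jnp.asarray(operand)
    # lax.dynamic_update_slice expects matching dtypes, so cast the update.
    update = jnp.asarray(update, dtype=operand.dtype)
    # If update is missing the indexed axis, insert it with size 1.
    if update.ndim == operand.ndim - 1:
        update = jnp.expand_dims(update, axis)
    # Start at `index` along `axis` and at 0 along every other axis.
    start_indices = [0] * operand.ndim
    start_indices[axis] = index
    return lax.dynamic_update_slice(operand, update, start_indices)
```

For the inputs used in the examples, the sketch produces the same arrays as lax.dynamic_update_index_in_dim.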
Examples
>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_update_index_in_dim
>>> x = jnp.zeros(6)
>>> y = 1.0
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
>>> y = jnp.array([1.0])
>>> dynamic_update_index_in_dim(x, y, 2, axis=0)
Array([0., 0., 1., 0., 0., 0.], dtype=float32)
If the specified index is out of bounds, the index will be clipped to the valid range:
>>> dynamic_update_index_in_dim(x, y, 10, axis=0)
Array([0., 0., 0., 0., 0., 1.], dtype=float32)
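Negative indices behave differently from large positive ones: with the default allow_negative_indices=True, they are first interpreted relative to the end of the axis and only then clipped. A minimal sketch, assuming the default setting:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros(6)
y = jnp.array([1.0])

# With the default allow_negative_indices=True, -1 counts from the end,
# so this writes into the last position (index 5).
last = lax.dynamic_update_index_in_dim(x, y, -1, axis=0)
print(last)  # [0. 0. 0. 0. 0. 1.]

# -2 maps to index 4.
second_last = lax.dynamic_update_index_in_dim(x, y, -2, axis=0)
print(second_last)  # [0. 0. 0. 0. 1. 0.]
```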
Here is an example of a two-dimensional dynamic index update:
>>> x = jnp.zeros((4, 4))
>>> y = jnp.ones(4)
>>> dynamic_update_index_in_dim(x, y, 1, axis=0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 1.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
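Because index may be a traced value rather than a Python integer, the function can be called inside jax.jit-compiled code, for example to write one row per iteration of lax.fori_loop. The function fill_rows is a hypothetical example, sketched under that assumption:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def fill_rows(x):
    # Hypothetical helper: overwrite row i of x with the value i for every row.
    def body(i, acc):
        # i is a traced loop counter, not a Python int.
        row = jnp.full((acc.shape[1],), i, dtype=acc.dtype)
        # row has one fewer dimension than acc, so a size-1 axis is
        # inserted at axis 0 before the update is written.
        return lax.dynamic_update_index_in_dim(acc, row, i, axis=0)

    return lax.fori_loop(0, x.shape[0], body, x)


out = fill_rows(jnp.zeros((4, 4)))
```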
Note that the shape of the additional axes in update need not match the associated dimensions of operand, as long as they are no larger:
>>> y = jnp.ones((1, 3))
>>> dynamic_update_index_in_dim(x, y, 1, 0)
Array([[0., 0., 0., 0.],
       [1., 1., 1., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]], dtype=float32)
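The same rules apply along other axes. In this sketch, a 1-D update with axis=1 gains a size-1 axis and overwrites an entire column, while a (2, 1) update, which already has the operand's rank, covers only the first two rows of that column:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))

# A 1-D update of length 4 is expanded to shape (4, 1),
# so it overwrites all of column 2.
col = lax.dynamic_update_index_in_dim(x, jnp.ones(4), 2, axis=1)

# An update of shape (2, 1) is not expanded; it is written starting at
# row 0 of column 2 and leaves rows 2 and 3 untouched.
partial = lax.dynamic_update_index_in_dim(x, jnp.ones((2, 1)), 2, axis=1)
```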