jax.lax.dynamic_slice
- jax.lax.dynamic_slice(operand, start_indices, slice_sizes, *, allow_negative_indices=True)
Wraps XLA’s DynamicSlice operator.
- Parameters:
operand (Array | np.ndarray) – an array to slice.
start_indices (Array | np.ndarray | Sequence[ArrayLike]) – a list of scalar indices, one per dimension. These values may be dynamic.
slice_sizes (Shape) – the size of the slice along each dimension. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT-compiled function, only static values are supported, because all JAX arrays inside JIT must have a statically known shape.
allow_negative_indices (bool | Sequence[bool]) – a bool or sequence of bools, one per dimension; if a bool is passed, it applies to all dimensions. For each dimension, if true, negative indices are permitted and are interpreted relative to the end of the array. If false, negative indices are treated as if they were out of bounds and the result is implementation-defined, typically clamped to the first index.
- Returns:
An array containing the slice.
- Return type:
Array
Examples
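A one-dimensional sketch of the static/dynamic split described in the parameter list: under `jax.jit` the start index may be a traced value, while each slice size must be a concrete Python integer known at trace time. The function names `window` and `window_n` are hypothetical, and marking the size argument static with `functools.partial(jax.jit, static_argnums=...)` is one common pattern rather than the only option.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def window(x, i):
    # `i` is traced (dynamic) under jit; the slice size (3,) is a Python
    # int literal, so the output shape is known at trace time.
    return lax.dynamic_slice(x, (i,), (3,))


# Marking `n` as static lets the caller choose the slice size, at the cost
# of one recompilation per distinct value of `n`.
@partial(jax.jit, static_argnums=2)
def window_n(x, i, n):
    return lax.dynamic_slice(x, (i,), (n,))


x = jnp.arange(10)
print(window(x, 2))       # [2 3 4]
print(window_n(x, 5, 4))  # [5 6 7 8]
```

Passing `n` as an ordinary traced argument instead would fail at trace time, since the output shape would then depend on a runtime value. Note also that `window(x, 9)` returns three elements rather than a shorter array, because of the start-index adjustment for out-of-bounds slices described below.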
Here is a simple two-dimensional dynamic slice:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
Array([[ 5,  6,  7],
       [ 9, 10, 11]], dtype=int32)
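When the requested window lies entirely inside the array, `dynamic_slice(x, (i, j), (h, w))` selects the same elements as the static slice `x[i:i+h, j:j+w]`; the difference is that `i` and `j` may be traced values. The check below is a small illustration of that equivalence for the two-dimensional example above, written with explicit imports.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)

i, j = 1, 1  # start indices
h, w = 2, 3  # slice sizes; rows 1..2 and columns 1..3 are fully in bounds

dyn = lax.dynamic_slice(x, (i, j), (h, w))
static = x[i:i + h, j:j + w]

print(dyn.tolist())                 # [[5, 6, 7], [9, 10, 11]]
print(bool((dyn == static).all()))  # True
```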
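Note the potentially surprising behavior when the requested slice overruns the bounds of the array: rather than raising an error or returning a smaller slice, the start index is adjusted so that a slice of the requested size is returned. Here the start index along the second axis is moved from 1 to 0: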
>>> dynamic_slice(x, (1, 1), (2, 4))
Array([[ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
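The adjustment can be described per axis: a negative start is first wrapped by adding the axis length (the default `allow_negative_indices=True` behavior), and the result is then clamped to the range `[0, dim - size]`. The helper `clamp_start` below is a hypothetical pure-Python model of that rule, written for illustration and checked against `lax.dynamic_slice` on the overrunning example; it is not part of the JAX API.

```python
import jax.numpy as jnp
from jax import lax


def clamp_start(start, dim, size):
    """Hypothetical model of the per-axis start-index adjustment."""
    if start < 0:
        start += dim  # negative indices count from the end of the axis
    # Clamp so that the whole window [start, start + size) fits in the axis.
    return min(max(start, 0), dim - size)


x = jnp.arange(12).reshape(3, 4)
starts, sizes = (1, 1), (2, 4)

adjusted = [clamp_start(s, d, n) for s, d, n in zip(starts, x.shape, sizes)]
print(adjusted)  # [1, 0]

expected = x[adjusted[0]:adjusted[0] + sizes[0],
             adjusted[1]:adjusted[1] + sizes[1]]
result = lax.dynamic_slice(x, starts, sizes)
print(bool((result == expected).all()))  # True

# A negative start on a 1-D array: -2 wraps to 8, then clamps to 10 - 3 = 7.
v = jnp.arange(10)
print(lax.dynamic_slice(v, (-2,), (3,)).tolist())  # [7, 8, 9]
```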