jax.numpy.mask_indices
- jax.numpy.mask_indices(n, mask_func, k=0, *, size=None)
Return indices of a mask of an (n, n) array.
- Parameters:
n (int) – static integer array dimension.
mask_func (Callable[[ArrayLike, int], Array]) – a function that takes a shape (n, n) array and an optional offset k, and returns a shape (n, n) mask. Examples of functions with this signature are triu() and tril().
k (int) – a scalar value passed to mask_func.
size (int | None) – optional argument specifying the static size of the output arrays. This is passed to nonzero() when generating the indices from the mask.
- Returns:
a tuple of indices where the mask returned by mask_func is nonzero.
- Return type:
tuple[Array, Array]
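The description above suggests a simple mental model: apply `mask_func` to an all-ones `(n, n)` array and collect the coordinates of its nonzero entries. The sketch below illustrates that model, assuming `jnp.mask_indices` behaves like NumPy's `mask_indices` with `size` forwarded to `jnp.nonzero`. The name `mask_indices_sketch` is hypothetical and is not part of JAX.

```python
import jax.numpy as jnp


def mask_indices_sketch(n, mask_func, k=0, *, size=None):
    # Hypothetical re-implementation, for illustration only.
    # An all-ones (n, n) array means mask_func decides which positions
    # survive purely from their location, not from the array's values.
    ones = jnp.ones((n, n))
    mask = mask_func(ones, k)
    # jnp.nonzero returns one index array per dimension (rows, columns);
    # `size` fixes the output length so the call can be traced by jax.jit.
    return jnp.nonzero(mask, size=size)


rows, cols = mask_indices_sketch(3, jnp.triu)
print(rows)  # [0 0 0 1 1 2]
print(cols)  # [0 1 2 1 2 2]
```

Comparing this sketch against `jnp.mask_indices` for `triu()` and `tril()` with several offsets is a quick way to confirm the model.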
See also
jax.numpy.triu_indices(): compute mask_indices for triu().
jax.numpy.tril_indices(): compute mask_indices for tril().
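The relationship stated in See also can be checked directly: for the built-in masks, `mask_indices` should yield the same coordinates, in the same row-major order, as the dedicated helpers. A minimal check, assuming square `n x n` inputs (the values of `n` and `k` are arbitrary choices):

```python
import jax.numpy as jnp

n, k = 4, 1
# Strictly upper triangle (k=1) and strictly lower triangle (k=-1).
upper_a = jnp.mask_indices(n, jnp.triu, k)
upper_b = jnp.triu_indices(n, k)
lower_a = jnp.mask_indices(n, jnp.tril, -k)
lower_b = jnp.tril_indices(n, -k)

# Each result is a (rows, cols) pair; compare the two pairs element-wise.
for a, b in ((upper_a, upper_b), (lower_a, lower_b)):
    assert all(x.tolist() == y.tolist() for x, y in zip(a, b))
print("mask_indices matches triu_indices / tril_indices")
```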
Examples
Calling mask_indices on built-in masking functions:
>>> jnp.mask_indices(3, jnp.triu)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
Calling mask_indices on a custom masking function:
>>> def mask_func(x, k=0):
...   i = jnp.arange(x.shape[0])[:, None]
...   j = jnp.arange(x.shape[1])
...   return (i + 1) % (j + 1 + k) == 0
>>> mask_func(jnp.ones((3, 3)))
Array([[ True, False, False],
       [ True,  True, False],
       [ True, False,  True]], dtype=bool)
>>> jnp.mask_indices(3, mask_func)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))
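A custom masking function can also make use of the `k` argument, which `mask_indices` passes positionally as the second argument. As an illustration (the function name `band` is hypothetical), the mask below keeps entries within `k` positions of the main diagonal, so `k=1` yields the indices of a tridiagonal matrix:

```python
import jax.numpy as jnp


def band(x, k=0):
    # Row and column coordinates, broadcast against each other to (n, n).
    i = jnp.arange(x.shape[0])[:, None]
    j = jnp.arange(x.shape[1])
    # Keep positions whose distance from the main diagonal is at most k.
    return jnp.abs(i - j) <= k


rows, cols = jnp.mask_indices(3, band, k=1)
print(rows)  # [0 0 1 1 1 2 2]
print(cols)  # [0 1 0 1 2 1 2]
```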