jax.numpy.mask_indices

jax.numpy.mask_indices(n, mask_func, k=0, *, size=None)

Return the indices of the nonzero entries of a mask of an (n, n) array.

Parameters:
  • n (int) – static integer array dimension.

  • mask_func (Callable[[ArrayLike, int], Array]) – a function that takes a shape (n, n) array and an optional offset k, and returns a shape (n, n) mask. Examples of functions with this signature are triu() and tril().

  • k (int) – a scalar value passed to mask_func.

  • size (int | None) – optional argument specifying the static size of the output arrays. This is passed to nonzero() when generating the indices from the mask.

Returns:

a tuple (rows, cols) of index arrays locating the nonzero entries of the mask returned by mask_func.

Return type:

tuple[Array, Array]
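Because the result is a (rows, cols) pair, it can be used directly to index any (n, n) array. The sketch below extracts the lower triangle of a small matrix and compares the indices against jnp.nonzero applied to the mask of an all-ones array; that comparison assumes mask_indices evaluates mask_func on an all-ones (n, n) input, as NumPy's numpy.mask_indices does.

```python
import jax.numpy as jnp

n = 3
x = jnp.arange(n * n).reshape(n, n)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# The returned tuple works as an advanced index into an (n, n) array.
idx = jnp.mask_indices(n, jnp.tril)
lower_values = x[idx]  # -> [0, 3, 4, 6, 7, 8]

# Assumed equivalent formulation: nonzero of the mask of an all-ones array.
ref = jnp.nonzero(jnp.tril(jnp.ones((n, n))))
matches_nonzero = all(jnp.array_equal(a, b) for a, b in zip(idx, ref))
```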


Examples

Calling mask_indices on built-in masking functions:

>>> jnp.mask_indices(3, jnp.triu)
(Array([0, 0, 0, 1, 1, 2], dtype=int32), Array([0, 1, 2, 1, 2, 2], dtype=int32))
>>> jnp.mask_indices(3, jnp.tril)
(Array([0, 1, 1, 2, 2, 2], dtype=int32), Array([0, 0, 1, 0, 1, 2], dtype=int32))
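The k argument is passed through to mask_func, so with triu and tril it shifts the diagonal. The sketch below selects the strictly upper and strictly lower triangles and checks them against jnp.triu_indices and jnp.tril_indices, which compute the same index pairs for these two masks directly from n and k.

```python
import jax.numpy as jnp

# k=1 with triu keeps only entries above the main diagonal;
# k=-1 with tril keeps only entries below it.
strict_upper = jnp.mask_indices(3, jnp.triu, 1)
strict_lower = jnp.mask_indices(3, jnp.tril, -1)

# For triu/tril, the dedicated index functions give the same row-major result.
same_upper = all(jnp.array_equal(a, b)
                 for a, b in zip(strict_upper, jnp.triu_indices(3, k=1)))
same_lower = all(jnp.array_equal(a, b)
                 for a, b in zip(strict_lower, jnp.tril_indices(3, k=-1)))
```

mask_indices is mainly useful when the mask is not one of these built-in triangles.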

Calling mask_indices on a custom masking function:

>>> def mask_func(x, k=0):
...   i = jnp.arange(x.shape[0])[:, None]
...   j = jnp.arange(x.shape[1])
...   return (i + 1) % (j + 1 + k) == 0
>>> mask_func(jnp.ones((3, 3)))
Array([[ True, False, False],
       [ True,  True, False],
       [ True, False,  True]], dtype=bool)
>>> jnp.mask_indices(3, mask_func)
(Array([0, 1, 1, 2, 2], dtype=int32), Array([0, 0, 1, 0, 2], dtype=int32))
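A custom mask function can also make use of the offset k. The sketch below uses a hypothetical band_mask that marks entries within k of the main diagonal; with k=1 it selects the tridiagonal band of a 3x3 array. Like the example above, it only reads the shape of its input.

```python
import jax.numpy as jnp

def band_mask(x, k=0):
    """Hypothetical mask: True where |row - col| <= k."""
    i = jnp.arange(x.shape[0])[:, None]  # row indices as a column vector
    j = jnp.arange(x.shape[1])           # column indices as a row vector
    return jnp.abs(i - j) <= k

# k=1 is forwarded to band_mask, selecting the tridiagonal band.
rows, cols = jnp.mask_indices(3, band_mask, 1)
```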