jax.numpy.argwhere
jax.numpy.argwhere(a, *, size=None, fill_value=None)
Find the indices of nonzero array elements.

JAX implementation of numpy.argwhere().

jnp.argwhere(x) is essentially equivalent to jnp.column_stack(jnp.nonzero(x)), with special handling for zero-dimensional (i.e. scalar) inputs.

Because the size of the output of argwhere is data-dependent, the function is not typically compatible with JIT. The JAX version adds the optional size argument, which specifies the size of the leading dimension of the output. It must be specified statically for jnp.argwhere to be compiled with non-static operands. See jax.numpy.nonzero() for a full discussion of size and its semantics.

- Parameters:
a (ArrayLike) – array for which to find nonzero elements.
size (int | None) – optional static integer specifying the number of nonzero elements to return, i.e. the length of the leading dimension of the output. This must be specified in order to use argwhere within JAX transformations like jax.jit(). See jax.numpy.nonzero() for more information.
fill_value (ArrayLike | None) – optional array specifying the fill value for padded entries when size is specified. See jax.numpy.nonzero() for more information.
- Returns:
a two-dimensional array of shape [size, a.ndim]. If size is not specified as an argument, it is equal to the number of nonzero elements in a.
- Return type:
Array
Examples
Two-dimensional array:
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 0, 2],
...                [0, 3, 0]])
>>> jnp.argwhere(x)
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
Equivalent computation using jax.numpy.column_stack() and jax.numpy.nonzero():

>>> jnp.column_stack(jnp.nonzero(x))
Array([[0, 0],
       [0, 2],
       [1, 1]], dtype=int32)
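The size and fill_value parameters fix the leading dimension of the output regardless of how many nonzero elements the input actually has. The sketch below (outside of jit, for readability) shows both directions: when size exceeds the nonzero count, the extra rows are padded with fill_value; when it is smaller, the result is truncated to the first size rows. The choice of -1 as a fill value is only an illustrative assumption, picked because it cannot be confused with a real index.

```python
import jax.numpy as jnp

x = jnp.array([[1, 0, 2],
               [0, 3, 0]])  # three nonzero entries

# size > number of nonzeros: extra rows are filled with fill_value.
padded = jnp.argwhere(x, size=5, fill_value=-1)

# size < number of nonzeros: only the first `size` rows are kept.
truncated = jnp.argwhere(x, size=2)

print(padded.shape)  # [size, x.ndim]
print(padded)
print(truncated)
```

In both cases the output shape is [size, x.ndim], which is what makes the result usable when the shape must be known ahead of time.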
Special case for zero-dimensional (i.e. scalar) inputs:
>>> jnp.argwhere(1)
Array([], shape=(1, 0), dtype=int32)
>>> jnp.argwhere(0)
Array([], shape=(0, 0), dtype=int32)
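Because argwhere's output shape depends on the data, calling it under jax.jit() requires a static size. A minimal sketch, assuming a hypothetical helper named nonzero_coords that marks size as static via static_argnames; the second half shows that jitting argwhere without size fails during tracing.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `size` fixes the leading dimension of the output, so it must be a static
# (compile-time) argument rather than a traced array.
@partial(jax.jit, static_argnames="size")
def nonzero_coords(x, size):  # hypothetical helper, for illustration only
    # fill_value=-1 marks padding rows so they cannot be confused with index 0.
    return jnp.argwhere(x, size=size, fill_value=-1)


x = jnp.array([[1, 0, 2],
               [0, 3, 0]])
coords = nonzero_coords(x, size=4)
print(coords)  # three real coordinates followed by one padding row of -1s

# Without `size`, the output shape depends on the data, so tracing fails.
try:
    jax.jit(jnp.argwhere)(x)
    jit_without_size_failed = False
except jax.errors.ConcretizationTypeError:
    jit_without_size_failed = True
print(jit_without_size_failed)
```

Changing size triggers a recompilation, since it is part of the static signature; keeping it fixed across calls avoids that cost.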