jax.numpy.where
- jax.numpy.where(condition, x=None, y=None, /, *, size=None, fill_value=None)
Select elements from two arrays based on a condition.
JAX implementation of numpy.where().
Note
When only condition is provided, jnp.where(condition) is equivalent to jnp.nonzero(condition). For that case, refer to the documentation of jax.numpy.nonzero(). The docstring below focuses on the case where x and y are specified.
The three-argument form of jnp.where lowers to jax.lax.select().
- Parameters:
  condition – boolean array. Must be broadcast-compatible with x and y when they are specified.
  x – arraylike. Should be broadcast-compatible with condition and y, and typecast-compatible with y.
  y – arraylike. Should be broadcast-compatible with condition and x, and typecast-compatible with x.
  size – integer, only referenced when x and y are None. For details, see jax.numpy.nonzero().
  fill_value – only referenced when x and y are None. For details, see jax.numpy.nonzero().
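As an illustration of the broadcast and typecast requirements on condition, x, and y, the sketch below uses arbitrary example arrays (not from the original docs): a (2, 1) condition, a length-3 integer x, and a float32 scalar y broadcast to a common (2, 3) shape, and the int32/float32 inputs promote to float32.

```python
import jax.numpy as jnp

condition = jnp.array([[True], [False]])  # shape (2, 1)
x = jnp.arange(3)                         # shape (3,), integer dtype
y = jnp.float32(-1.0)                     # float32 scalar, shape ()

# All three inputs broadcast to shape (2, 3); the integer and float32
# inputs are promoted to a common float32 result dtype.
out = jnp.where(condition, x, y)
print(out.shape, out.dtype)
print(out)
```

Row 0 takes its values from x because the condition is True there; row 1 is filled from the broadcast scalar y.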
- Returns:
  An array of dtype jnp.result_type(x, y) with values drawn from x where condition is True, and from y where condition is False. If x and y are None, the function behaves differently; see jax.numpy.nonzero() for a description of the return type.
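To connect the return dtype with the statement that the three-argument form lowers to jax.lax.select(), the following is a rough sketch built from public APIs: promote x and y to jnp.result_type(x, y), broadcast everything to one shape (lax.select requires identical operand shapes), then select. The helper where_via_select is hypothetical and only approximates jnp.where; JAX's actual implementation may differ in details.

```python
import jax.numpy as jnp
from jax import lax

def where_via_select(condition, x, y):
    # Hypothetical helper for illustration only, not part of JAX.
    dtype = jnp.result_type(x, y)           # common dtype of the two branches
    x = jnp.asarray(x, dtype=dtype)
    y = jnp.asarray(y, dtype=dtype)
    condition = jnp.asarray(condition, dtype=bool)
    # lax.select needs pred, on_true, and on_false to share one shape.
    condition, x, y = jnp.broadcast_arrays(condition, x, y)
    return lax.select(condition, x, y)

a = jnp.arange(6).reshape(2, 3)
expected = jnp.where(a > 2, a, 0.5)
result = where_via_select(a > 2, a, 0.5)
print(result.dtype, jnp.array_equal(result, expected))
```

Here the int32 array and the Python float 0.5 promote to float32, which is the dtype jnp.where also returns.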
Notes
Special care is needed when the x or y input to jax.numpy.where() could have a value of NaN. Specifically, when a gradient is taken with jax.grad() (reverse-mode differentiation), a NaN in either x or y will propagate into the gradient, regardless of the value of condition. More information on this behavior and workarounds is available in the JAX FAQ.
Examples
When x and y are not provided, where behaves equivalently to jax.numpy.nonzero():
>>> x = jnp.arange(10)
>>> jnp.where(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
>>> jnp.nonzero(x > 4)
(Array([5, 6, 7, 8, 9], dtype=int32),)
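The one-argument form returns an array whose length depends on the values in condition, so it cannot be traced under jax.jit unless size is given. The sketch below (indices_above is a hypothetical name) passes size and fill_value, which are forwarded to jax.numpy.nonzero(): extra slots are padded with fill_value, and surplus matches are truncated.

```python
import jax
import jax.numpy as jnp

@jax.jit
def indices_above(x, threshold):
    # A static size makes the output shape known at trace time;
    # unused slots are filled with fill_value.
    return jnp.where(x > threshold, size=7, fill_value=-1)

x = jnp.arange(10)
(padded,) = indices_above(x, 4)          # 5 matches, padded to length 7
(truncated,) = jnp.where(x > 4, size=3)  # only the first 3 matches kept
print(padded)
print(truncated)
```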
When x and y are provided, where selects between them based on the specified condition:
>>> jnp.where(x > 4, x, 0)
Array([0, 0, 0, 0, 0, 5, 6, 7, 8, 9], dtype=int32)
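The NaN caveat in the Notes can be reproduced with jnp.sqrt on a negative input: the forward value is masked out correctly, but the gradient is still NaN, because the discarded branch contributes 0 * NaN. One commonly used workaround, often called the "double where" trick, first replaces unsafe inputs with a harmless placeholder (1.0 here, an arbitrary choice) so the discarded branch never produces a NaN. The function names below are hypothetical.

```python
import jax
import jax.numpy as jnp

def sqrt_or_zero_naive(x):
    # For x < 0, jnp.sqrt(x) is NaN; where discards the value,
    # but the NaN derivative still leaks into the gradient.
    return jnp.where(x > 0, jnp.sqrt(x), 0.0)

def sqrt_or_zero_safe(x):
    # "Double where": sanitize the input first so sqrt only sees
    # valid values, then select the result as before.
    safe_x = jnp.where(x > 0, x, 1.0)
    return jnp.where(x > 0, jnp.sqrt(safe_x), 0.0)

grad_naive = jax.grad(sqrt_or_zero_naive)(-1.0)  # NaN
grad_safe = jax.grad(sqrt_or_zero_safe)(-1.0)    # 0.0
print(grad_naive, grad_safe)
```

For positive inputs both versions compute the same value and gradient; the extra where only changes what the unselected branch sees.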