jax.numpy.nan_to_num
- jax.numpy.nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None)
Replace NaN and infinite entries in an array.
JAX implementation of numpy.nan_to_num().
- Parameters:
x (ArrayLike) – array of values to be replaced. If it does not have an inexact dtype, it is returned unmodified.
copy (bool) – unused by JAX
nan (ArrayLike) – value to substitute for NaN entries. Defaults to 0.0.
posinf (ArrayLike | None) – value to substitute for positive infinite entries. Defaults to the largest finite value representable by the dtype of x.
neginf (ArrayLike | None) – value to substitute for negative infinite entries. Defaults to the most negative finite value representable by the dtype of x.
- Returns:
A copy of x with the requested substitutions.
- Return type:
Array
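The parameter defaults and the pass-through rule can be checked directly. This is a small sketch, assuming a float16 input only to keep the numbers short; it compares the substituted values against `jnp.finfo` for the input dtype, confirms that the input array is left as it was, and checks that an integer array comes back with its dtype and values unchanged.

```python
import jax.numpy as jnp

# Float input: NaN -> nan (default 0.0), +inf -> finfo.max, -inf -> finfo.min.
x = jnp.array([jnp.nan, jnp.inf, -jnp.inf], dtype=jnp.float16)
out = jnp.nan_to_num(x)
info = jnp.finfo(x.dtype)
print(out, info.max, info.min)

# The input still holds its NaN and inf entries: JAX arrays are immutable,
# so the result is always a new array, and `copy` is unused by JAX.
print(bool(jnp.isnan(x[0])), bool(jnp.isinf(x[1])))

# Integer input has no inexact dtype, so it is returned unmodified.
xi = jnp.array([1, 2, 3])
out_i = jnp.nan_to_num(xi)
print(out_i.dtype == xi.dtype, bool((out_i == xi).all()))
```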
See also
jax.numpy.isnan(): return True where the array contains NaN
jax.numpy.isposinf(): return True where the array contains +inf
jax.numpy.isneginf(): return True where the array contains -inf
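The masks listed above are enough to reproduce the substitutions with jax.numpy.where(). The following is a minimal sketch, not JAX's actual source: `nan_to_num_sketch` is a hypothetical name, and it assumes a real floating-point input (integer and complex inputs are not handled).

```python
import jax.numpy as jnp

def nan_to_num_sketch(x, nan=0.0, posinf=None, neginf=None):
    """Hypothetical helper: rebuilds nan_to_num from isnan/isposinf/isneginf.

    Assumes x has a real floating dtype; integer and complex inputs are not handled.
    """
    x = jnp.asarray(x)
    info = jnp.finfo(x.dtype)
    # Fall back to the dtype's largest / most negative finite values.
    posinf = info.max if posinf is None else posinf
    neginf = info.min if neginf is None else neginf
    # Each jnp.where call returns a new array; the input is never modified.
    out = jnp.where(jnp.isnan(x), jnp.asarray(nan, dtype=x.dtype), x)
    out = jnp.where(jnp.isposinf(out), jnp.asarray(posinf, dtype=x.dtype), out)
    out = jnp.where(jnp.isneginf(out), jnp.asarray(neginf, dtype=x.dtype), out)
    return out

x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
print(nan_to_num_sketch(x, posinf=999, neginf=-999))
```

Casting each substitute to `x.dtype` keeps the output dtype equal to the input dtype, which matches the float32 results shown in the examples.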
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
Default substitution values:
>>> jnp.nan_to_num(x)
Array([ 0.0000000e+00,  0.0000000e+00,  1.0000000e+00,  3.4028235e+38,
        2.0000000e+00, -3.4028235e+38], dtype=float32)
Overriding substitutions for -inf and +inf:
>>> jnp.nan_to_num(x, posinf=999, neginf=-999)
Array([   0.,    0.,    1.,  999.,    2., -999.], dtype=float32)
If you only wish to substitute for NaN values while leaving inf values untouched, using where() with jax.numpy.isnan() is a better option:
>>> jnp.where(jnp.isnan(x), 0, x)
Array([  0.,   0.,   1.,  inf,   2., -inf], dtype=float32)
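The reverse case follows the same pattern. As a sketch, with `replace_inf_only` as a hypothetical helper name, the function below substitutes 0 for both infinities while leaving NaN entries in place. It is wrapped in `jax.jit` to show that the where-based approach compiles like any other jax.numpy expression.

```python
import jax
import jax.numpy as jnp

@jax.jit
def replace_inf_only(x):
    # Hypothetical helper: mirror image of the where/isnan pattern,
    # replacing +inf and -inf with 0 and keeping NaN entries as they are.
    return jnp.where(jnp.isinf(x), 0, x)

x = jnp.array([0, jnp.nan, 1, jnp.inf, 2, -jnp.inf])
print(replace_inf_only(x))
```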