jax.nn.standardize
- jax.nn.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None, *, algorithm='fast')
Standardizes input to zero mean and unit variance.
The standardization is given by:
\[x_{std} = \frac{x - \langle x\rangle}{\sqrt{\langle(x - \langle x\rangle)^2\rangle + \epsilon}}\]
where \(\langle x\rangle\) indicates the mean of \(x\), and \(\epsilon\) is a small correction factor introduced to avoid division by zero.
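As a rough illustration of the formula, the sketch below transcribes it directly into JAX and compares the result with jax.nn.standardize called with its default arguments. The helper name standardize_reference and the sample values are assumptions made for this example only; they are not part of the library.

```python
import jax
import jax.numpy as jnp

def standardize_reference(x, axis=-1, epsilon=1e-5):
    # Hypothetical helper: a direct transcription of the formula above.
    # keepdims=True lets the statistics broadcast back against x.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    variance = jnp.mean(jnp.square(x - mean), axis=axis, keepdims=True)
    return (x - mean) / jnp.sqrt(variance + epsilon)

# Illustrative input: two rows, each standardized along the last axis.
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 20.0, 30.0, 40.0]])

expected = standardize_reference(x)
actual = jax.nn.standardize(x)  # defaults: axis=-1, epsilon=1e-5
```

For well-scaled inputs like these, the two results should agree to within floating-point tolerance, and each row of the output should have a mean close to zero.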
- Parameters:
x (ArrayLike) – input array to be standardized.
axis (Axis) – integer, tuple of integers, or None (all axes), representing the axes along which to standardize. Defaults to the last axis (-1).
mean (ArrayLike | None) – optionally specify the mean used for standardization. If not specified, then x.mean(axis, where=where) will be used.
variance (ArrayLike | None) – optionally specify the variance used for standardization. If not specified, then x.var(axis, where=where) will be used.
epsilon (ArrayLike) – correction factor added to variance to avoid division by zero; defaults to 1e-5.
where (ArrayLike | None) – optional boolean mask specifying which elements to use when computing the mean and variance.
algorithm (str) – variance computation algorithm. "fast" uses mean(x^2) - mean(x)^2, which may be faster but can suffer from catastrophic cancellation and produce different results in eager vs JIT contexts. "stable" uses the two-pass formula mean((x - mean(x))^2), which is numerically stable. Default is "fast" for backward compatibility.
- Returns:
An array of the same shape as x containing the standardized input.
- Return type:
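The difference between the two variance formulas named under the algorithm parameter can be sketched as follows. The helpers variance_fast and variance_stable are hypothetical: they re-implement the two formulas from the description rather than calling the library's internals, and the input is chosen (a small spread around a large offset, in float32) so that cancellation is easy to see.

```python
import jax
import jax.numpy as jnp

def variance_fast(x, axis=-1):
    # Hypothetical helper: one-pass formula mean(x^2) - mean(x)^2.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    return jnp.mean(jnp.square(x), axis=axis, keepdims=True) - jnp.square(mean)

def variance_stable(x, axis=-1):
    # Hypothetical helper: two-pass formula mean((x - mean(x))^2).
    mean = jnp.mean(x, axis=axis, keepdims=True)
    return jnp.mean(jnp.square(x - mean), axis=axis, keepdims=True)

# Small spread around a large offset; the exact variance is 1.25.
x = jnp.array([10000.0, 10001.0, 10002.0, 10003.0], dtype=jnp.float32)

fast = variance_fast(x)      # squares near 1e8 exceed float32 precision
stable = variance_stable(x)  # deviations of +/-0.5 and +/-1.5 are exact

# A precomputed variance can be supplied through the `variance` argument.
y = jax.nn.standardize(x, variance=stable)
```

In this sketch the two-pass result recovers 1.25, while the one-pass result is dominated by rounding error in the squared terms. Passing a separately computed variance through the variance argument is one way to control which formula is used.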