jax.numpy.std

jax.numpy.std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False, *, where=None, correction=None)

Compute the standard deviation along a given axis.

JAX implementation of numpy.std().

Parameters:
  • a (ArrayLike) – input array.

  • axis (Axis | None) – optional, int or sequence of ints, default=None. Axis or axes along which the standard deviation is computed. If None, the standard deviation is computed over all axes.

  • dtype (DTypeLike | None) – optional, default=None. The dtype of the output array.

  • ddof (int) – int, default=0. Delta degrees of freedom. The divisor used in the computation is N - ddof, where N is the number of elements along the given axis.

  • keepdims (bool) – bool, default=False. If True, reduced axes are left in the result as dimensions of size 1.

  • where (ArrayLike | None) – optional, boolean array, default=None. Elements to include in the standard deviation. Must be broadcast-compatible with the input.

  • correction (int | float | None) – int or float, default=None. Alternative name for ddof. ddof and correction cannot both be provided.

  • out (None) – Unused by JAX.

Returns:

An array of the standard deviation along the given axis.

Return type:

Array
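The N - ddof divisor described for ddof can be checked against a direct implementation of the definition. The sketch below assumes JAX is installed; std_reference is a hypothetical helper written only for illustration (it handles just an int or None axis) and is not part of the JAX API.

```python
import jax.numpy as jnp


def std_reference(a, axis=None, ddof=0):
    """Hypothetical helper (not a JAX API): standard deviation from its
    definition, sqrt(sum((a - mean)**2) / (N - ddof)).

    Only handles axis=None or a single int axis.
    """
    a = jnp.asarray(a, dtype=jnp.float32)
    # keepdims=True so the mean broadcasts back against `a`
    mean = jnp.mean(a, axis=axis, keepdims=True)
    sq_dev = (a - mean) ** 2
    # N is the number of elements being reduced over
    n = a.size if axis is None else a.shape[axis]
    return jnp.sqrt(jnp.sum(sq_dev, axis=axis) / (n - ddof))


x = jnp.array([[1, 3, 4, 2],
               [4, 2, 5, 3],
               [5, 4, 2, 3]])

for ddof in (0, 1):
    # Compare the hand-written formula with jnp.std along axis 0
    print(ddof, jnp.allclose(std_reference(x, axis=0, ddof=ddof),
                             jnp.std(x, axis=0, ddof=ddof)))
```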


Examples

By default, jnp.std computes the standard deviation along all axes.

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 3, 4, 2],
...                [4, 2, 5, 3],
...                [5, 4, 2, 3]])
>>> with jnp.printoptions(precision=2, suppress=True):
...   jnp.std(x)
Array(1.21, dtype=float32)
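Reducing over all axes should be equivalent to listing every axis explicitly (axis accepts a sequence of ints) or to flattening the array first. A short check, assuming JAX is installed and redefining x so the snippet is self-contained:

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [4, 2, 5, 3],
               [5, 4, 2, 3]])

s_all = jnp.std(x)                 # axis=None: reduce over every axis
s_tuple = jnp.std(x, axis=(0, 1))  # the same axes, listed explicitly
s_flat = jnp.std(x.ravel())        # the same 12 values as a 1-D array

print(s_all.shape)  # (): a 0-d (scalar) array
print(jnp.allclose(s_all, s_tuple), jnp.allclose(s_all, s_flat))
```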

With axis=0, the standard deviation is computed along axis 0, giving one value per column.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0))
[1.7  0.82 1.25 0.47]
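Choosing a different axis changes which values are grouped together. As a sketch on the same x, axis=1 (or equivalently axis=-1) gives one value per row. Every row of this particular x has squared deviations from its mean summing to 5, so each row's entry is sqrt(5 / 4) ≈ 1.118.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [4, 2, 5, 3],
               [5, 4, 2, 3]])

col_std = jnp.std(x, axis=0)    # one value per column
row_std = jnp.std(x, axis=1)    # one value per row
last_std = jnp.std(x, axis=-1)  # negative axes count from the end

print(col_std.shape, row_std.shape)  # (4,) (3,)
print(row_std)
```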

To preserve the number of dimensions of the input, set keepdims=True.

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True))
[[1.7  0.82 1.25 0.47]]
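One common reason to keep reduced axes is to broadcast the statistic back against the input, for example to standardize each row. With axis=1, jnp.std(x, axis=1) has shape (3,), which does not broadcast against the (3, 4) input; with keepdims=True it has shape (3, 1), which does. The standardization here is an illustrative use case assumed for this sketch, not something jnp.std does itself.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [4, 2, 5, 3],
               [5, 4, 2, 3]])

# keepdims=True keeps the reduced axis as size 1: shape (3, 1) instead of (3,)
row_mean = jnp.mean(x, axis=1, keepdims=True)
row_std = jnp.std(x, axis=1, keepdims=True)

# (3, 4) op (3, 1) broadcasts across each row
z = (x - row_mean) / row_std

print(row_std.shape)       # (3, 1)
print(jnp.std(z, axis=1))  # each row now has unit standard deviation
```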

With ddof=1, the divisor is N - 1 instead of N:

>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jnp.std(x, axis=0, keepdims=True, ddof=1))
[[2.08 1.   1.53 0.58]]
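Because only the divisor changes, the ddof=1 result should equal the ddof=0 result scaled by sqrt(N / (N - 1)); along axis 0 of x, N = 3. The sketch below checks that relation and that correction=1, the alternative name listed in the parameters, gives the same values.

```python
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [4, 2, 5, 3],
               [5, 4, 2, 3]])
n = x.shape[0]  # N: number of elements reduced along axis 0

pop = jnp.std(x, axis=0)                      # divisor N
samp = jnp.std(x, axis=0, ddof=1)             # divisor N - 1
samp_corr = jnp.std(x, axis=0, correction=1)  # same, via the alternative name

# Only the divisor differs, so the ratio is sqrt(N / (N - 1)) in every column
print(jnp.allclose(samp, pop * jnp.sqrt(n / (n - 1))))
print(jnp.allclose(samp, samp_corr))
```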

To compute the standard deviation over only selected elements of the array, pass a boolean mask as where.

>>> where = jnp.array([[1, 0, 1, 0],
...                    [0, 1, 0, 1],
...                    [1, 1, 1, 0]], dtype=bool)
>>> jnp.std(x, axis=0, keepdims=True, where=where)
Array([[2., 1., 1., 0.]], dtype=float32)
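The output above indicates that N counts only the selected elements: column 0 keeps 1 and 5, whose standard deviation with divisor 2 is 2. For a full reduction, where=mask should therefore match selecting the elements with boolean indexing first. A hedged design note: boolean indexing produces a data-dependent shape, so it runs eagerly but not inside jax.jit, whereas where keeps shapes static and can be traced. This sketch assumes JAX is installed and uses mask as a renamed copy of the where array above.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 3, 4, 2],
               [4, 2, 5, 3],
               [5, 4, 2, 3]])
mask = jnp.array([[1, 0, 1, 0],
                  [0, 1, 0, 1],
                  [1, 1, 1, 0]], dtype=bool)

# Full reduction restricted to the True entries of the mask
masked = jnp.std(x, where=mask)

# Same statistic on the selected values; boolean indexing needs concrete
# (non-traced) arrays, so this line runs eagerly only
selected = jnp.std(x[mask])

# `where` has a static shape, so the masked version also works under jit
masked_jit = jax.jit(lambda a, m: jnp.std(a, where=m))(x, mask)

print(jnp.allclose(masked, selected), jnp.allclose(masked, masked_jit))
```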