jax.numpy.average

jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)

Compute the weighted average.

JAX implementation of numpy.average().

Parameters:
  • a (ArrayLike) – array to be averaged

  • axis (Axis | None) – an optional integer or sequence of integers specifying the axis or axes along which the mean is computed. If not specified, the mean is computed over all axes.

  • weights (ArrayLike | None) – an optional array of weights for a weighted average. Must be broadcast-compatible with a.

  • returned (bool) – If False (default) then return only the average. If True then return both the average and the normalization factor (i.e. the sum of weights).

  • keepdims (bool) – If True, reduced axes are left in the result with size 1. If False (default) then reduced axes are squeezed out.

Returns:

The average as an array, or a tuple of arrays (average, normalization) if returned is True.

Return type:

Array | tuple[Array, Array]

Examples

Simple average:

>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)

Weighted average:

>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)
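
As a rough cross-check, the weighted average above should agree with the explicit formula sum(w * x) / sum(w). The following is a minimal sketch of that formula, not necessarily how jnp.average is implemented internally; the name manual is purely illustrative.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 4])
weights = jnp.array([2, 1, 3, 2, 2])

# Explicit weighted mean: sum of weighted values divided by sum of weights.
manual = jnp.sum(weights * x) / jnp.sum(weights)

# Compare against jnp.average, allowing for floating-point tolerance.
matches = jnp.allclose(manual, jnp.average(x, weights=weights))
print(matches)
```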

Use returned=True to optionally return the normalization, i.e. the sum of weights:

>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))
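
The returned pair can be unpacked directly. As a small sketch (the names avg, norm and weighted_sum are illustrative, not part of the API), multiplying the average by the normalization recovers the weighted sum:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 4])
weights = jnp.array([2, 1, 3, 2, 2])

# Unpack the (average, normalization) tuple returned when returned=True.
avg, norm = jnp.average(x, weights=weights, returned=True)

# average * sum-of-weights gives back the weighted sum of the inputs.
weighted_sum = avg * norm
print(jnp.allclose(weighted_sum, jnp.sum(weights * x)))
```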

Weighted average along a specified axis:

>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
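
The keepdims parameter is not exercised by the examples above. The sketch below, using the same inputs, shows how it affects the output shape; the centering step at the end is only one illustrative use of the retained axis.

```python
import jax.numpy as jnp

x = jnp.array([[8, 2, 7],
               [3, 6, 4]])
weights = jnp.array([1, 2, 3])

# Default: the reduced axis is squeezed out, giving shape (2,).
avg = jnp.average(x, weights=weights, axis=1)

# keepdims=True: the reduced axis is kept with size 1, giving shape (2, 1).
avg_kept = jnp.average(x, weights=weights, axis=1, keepdims=True)

print(avg.shape, avg_kept.shape)

# The size-1 axis lets the result broadcast against x, e.g. to center each row.
centered = x - avg_kept
```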