jax.numpy.average
- jax.numpy.average(a, axis=None, weights=None, returned=False, keepdims=False)
Compute the weighted average.
JAX implementation of numpy.average().
- Parameters:
a (ArrayLike) – array to be averaged
axis (Axis | None) – an optional integer or sequence of integers specifying the axis or axes along which to compute the average. If not specified, the average is computed over all axes.
weights (ArrayLike | None) – an optional array of weights for a weighted average. Must be broadcast-compatible with a.
returned (bool) – If False (default), return only the average. If True, return both the average and the normalization factor (i.e. the sum of weights).
keepdims (bool) – If True, reduced axes are left in the result with size 1. If False (default) then reduced axes are squeezed out.
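For a reduction along a single axis, the weighted average amounts to sum(a * weights) / sum(weights) taken over that axis. The sketch below checks this against jnp.average. It assumes weights with the same shape as a, and `manual_average` is a hypothetical helper for illustration, not part of JAX. It also shows keepdims=True keeping the reduced axis with size 1, so the result broadcasts back against the input.

```python
import jax.numpy as jnp

def manual_average(a, weights, axis):
    # Hypothetical helper (not a JAX API); assumes `weights` has the same shape as `a`.
    a = jnp.asarray(a, dtype=jnp.float32)
    weights = jnp.asarray(weights, dtype=jnp.float32)
    # Weighted sum divided by the sum of weights along `axis`.
    return jnp.sum(a * weights, axis=axis) / jnp.sum(weights, axis=axis)

x = jnp.array([[8., 2., 7.],
               [3., 6., 4.]])
w = jnp.array([[1., 2., 3.],
               [1., 2., 3.]])

print(manual_average(x, w, axis=1))        # [5.5 4.5]
print(jnp.average(x, weights=w, axis=1))   # [5.5 4.5]

# keepdims=True keeps the reduced axis with size 1, so the result
# broadcasts against `x` (here: subtract each row's weighted average).
row_avg = jnp.average(x, weights=w, axis=1, keepdims=True)
print(row_avg.shape)   # (2, 1)
centered = x - row_avg
```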
- Returns:
An array average, or a tuple of arrays (average, normalization) if returned is True.
- Return type:
Array | tuple[Array, Array]
See also
jax.numpy.mean(): unweighted mean.
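As a rough illustration of how the two functions relate: without weights, jnp.average reduces the same way as jnp.mean, and constant weights cancel out of sum(a * w) / sum(w), so they also reproduce the mean. The array values and the weight 0.5 below are arbitrary choices for the sketch.

```python
import jax.numpy as jnp

x = jnp.array([[1., 2., 3.],
               [4., 5., 6.]])

# Without weights, average reduces exactly like mean.
print(jnp.average(x, axis=0))   # [2.5 3.5 4.5]
print(jnp.mean(x, axis=0))      # [2.5 3.5 4.5]

# Constant weights cancel in sum(a * w) / sum(w), so they also give the mean.
uniform = jnp.full(x.shape, 0.5)
print(jnp.average(x, weights=uniform, axis=0))   # [2.5 3.5 4.5]
```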
Examples
Simple average:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 2, 4])
>>> jnp.average(x)
Array(2.4, dtype=float32)
Weighted average:
>>> weights = jnp.array([2, 1, 3, 2, 2])
>>> jnp.average(x, weights=weights)
Array(2.5, dtype=float32)
Use returned=True to optionally return the normalization, i.e. the sum of weights:
>>> jnp.average(x, returned=True)
(Array(2.4, dtype=float32), Array(5., dtype=float32))
>>> jnp.average(x, weights=weights, returned=True)
(Array(2.5, dtype=float32), Array(10., dtype=float32))
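Because the normalization is the sum of weights, multiplying it back into the average recovers the weighted sum sum(x * weights). The sketch below unpacks the returned tuple directly. It also shows that, when reducing along an axis, the normalization comes back with the same shape as the average; this mirrors NumPy's behavior and is stated here as an observation of this example, not as a documented guarantee.

```python
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 2., 4.])
weights = jnp.array([2., 1., 3., 2., 2.])

# With returned=True the result is a tuple, so it can be unpacked directly.
avg, norm = jnp.average(x, weights=weights, returned=True)

# norm is the sum of weights, so avg * norm recovers sum(x * weights).
weighted_sum = avg * norm
print(avg, norm, weighted_sum)   # 2.5 10.0 25.0

# Along an axis, the normalization has one entry per reduced row.
m = jnp.array([[8., 2., 7.],
               [3., 6., 4.]])
row_avg, row_norm = jnp.average(m, weights=jnp.array([1., 2., 3.]),
                                axis=1, returned=True)
print(row_avg, row_norm)   # [5.5 4.5] [6. 6.]
```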
Weighted average along a specified axis:
>>> x = jnp.array([[8, 2, 7],
...                [3, 6, 4]])
>>> weights = jnp.array([1, 2, 3])
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
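Since jnp.average is an ordinary jax.numpy function, it should compose with JAX transformations such as jax.grad and jax.jit. The sketch below differentiates a hypothetical squared-error loss with respect to the weights; the loss and the target value 3.0 are illustrative assumptions, not part of the API. Analytically, d(average)/dw_i = (x_i - average) / sum(w), which gives the gradient noted in the comment.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1., 2., 3., 2., 4.])
target = 3.0  # hypothetical target value, chosen only for illustration

def loss(weights):
    # Hypothetical loss: squared distance between the weighted average and the target.
    return (jnp.average(x, weights=weights) - target) ** 2

weights = jnp.array([2., 1., 3., 2., 2.])

# Gradient of the loss with respect to the weights, compiled with jit.
grad_fn = jax.jit(jax.grad(loss))
g = grad_fn(weights)
print(loss(weights))   # 0.25
print(g)               # approximately [0.15 0.05 -0.05 0.05 -0.15]
```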