jax.numpy.median
- jax.numpy.median(a, axis=None, out=None, overwrite_input=False, keepdims=False)
Return the median of array elements along a given axis.
JAX implementation of numpy.median().
- Parameters:
a (ArrayLike) – input array.
axis (int | tuple[int, ...] | None) – optional, int or sequence of ints, default=None. Axis or axes along which the median is computed. If None, the median is computed over the flattened array.
keepdims (bool) – bool, default=False. If True, reduced axes are left in the result with size 1.
out (None) – Unused by JAX.
overwrite_input (bool) – Unused by JAX.
- Returns:
An array of the median along the given axis.
- Return type:
Array
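The axis parameter also accepts a tuple of ints. As a minimal sketch (the 3-D array below is an illustrative assumption, not one of the documented examples), a tuple covering every axis should match the default flattened result, while a tuple covering a subset of axes keeps the remaining ones:

```python
import jax.numpy as jnp

# Illustrative 3-D input of shape (2, 2, 3) holding the values 0..11.
x = jnp.arange(12).reshape(2, 2, 3)

# A tuple covering every axis reduces to a single value,
# matching the default axis=None (flattened) behavior.
m_all = jnp.median(x, axis=(0, 1, 2))
m_flat = jnp.median(x)

# A tuple covering axes 0 and 1 leaves axis 2, so the result has shape (3,).
m_01 = jnp.median(x, axis=(0, 1))

print(m_all, m_flat)  # both 5.5
print(m_01)           # medians 4.5, 5.5, 6.5
```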
See also
jax.numpy.mean(): Compute the mean of array elements over a given axis.
jax.numpy.max(): Compute the maximum of array elements over a given axis.
jax.numpy.min(): Compute the minimum of array elements over a given axis.
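To contrast median with jax.numpy.mean(), the short sketch below uses a made-up array containing one extreme value. The mean is pulled toward the outlier, while the median stays at the middle sorted element:

```python
import jax.numpy as jnp

# Made-up data with one extreme value.
data = jnp.array([1.0, 2.0, 3.0, 4.0, 1000.0])

mean = jnp.mean(data)      # (1 + 2 + 3 + 4 + 1000) / 5 = 202.0
median = jnp.median(data)  # middle of 5 sorted values = 3.0

print(mean, median)
```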
Examples
By default, the median is computed for the flattened array.
>>> x = jnp.array([[2, 4, 7, 1],
...                [3, 5, 9, 2],
...                [6, 1, 8, 3]])
>>> jnp.median(x)
Array(3.5, dtype=float32)
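The 3.5 above is the average of the two middle values (3 and 4) of the 12 sorted elements, which is also why an integer input yields a floating-point result. A minimal sketch on made-up inputs contrasting odd and even lengths:

```python
import jax.numpy as jnp

# Odd length: the median is the middle element after sorting (1, 3, 7).
odd = jnp.median(jnp.array([7, 1, 3]))

# Even length: the median is the average of the two middle elements (1, 3, 4, 7).
even = jnp.median(jnp.array([7, 1, 3, 4]))

print(odd, even)   # 3.0 and 3.5
print(even.dtype)  # a floating dtype, even though the input is integer
```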
If axis=1, the median is computed along axis 1.
>>> jnp.median(x, axis=1)
Array([3. , 4. , 4.5], dtype=float32)
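Along axis 0 the reduction instead produces one median per column, and a negative axis counts from the end. The sketch below reuses the same matrix and, as a sanity check, compares the results against numpy.median (assuming NumPy is installed, which JAX itself depends on):

```python
import numpy as np
import jax.numpy as jnp

x = jnp.array([[2, 4, 7, 1],
               [3, 5, 9, 2],
               [6, 1, 8, 3]])

# axis=0 reduces over the rows: one median per column.
col_medians = jnp.median(x, axis=0)   # medians 3, 4, 8, 2

# A negative axis counts from the end, so axis=-1 is the same as axis=1 here.
row_medians = jnp.median(x, axis=-1)  # medians 3, 4, 4.5

# Sanity check against NumPy's reference implementation.
assert np.allclose(col_medians, np.median(np.asarray(x), axis=0))
assert np.allclose(row_medians, np.median(np.asarray(x), axis=1))
```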
If keepdims=True, the ndim of the output is equal to that of the input.
>>> jnp.median(x, axis=1, keepdims=True)
Array([[3. ],
       [4. ],
       [4.5]], dtype=float32)
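The size-1 axis kept by keepdims=True is what makes the result broadcast back against the input. A minimal sketch, reusing the same matrix, that centers each row on its own median; without keepdims the (3,) result would not broadcast against the (3, 4) input:

```python
import jax.numpy as jnp

x = jnp.array([[2, 4, 7, 1],
               [3, 5, 9, 2],
               [6, 1, 8, 3]])

# keepdims=True keeps the reduced axis with size 1: shape (3, 1), not (3,).
row_med = jnp.median(x, axis=1, keepdims=True)

# (3, 4) minus (3, 1) broadcasts, subtracting each row's own median.
centered = x - row_med

print(row_med.shape)  # (3, 1)
print(centered)
```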