jax.numpy.percentile
- jax.numpy.percentile(a, q, axis=None, out=None, overwrite_input=False, method='linear', keepdims=False, *, interpolation=Deprecated)
Compute the percentile of the data along the specified axis.
JAX implementation of numpy.percentile().
- Parameters:
a (ArrayLike) – N-dimensional array input.
q (ArrayLike) – scalar or 1-dimensional array specifying the desired percentiles. q should contain integer or floating point values between 0 and 100.
axis (int | tuple[int, ...] | None) – optional axis or tuple of axes along which to compute the percentile.
out (None) – not implemented by JAX; will error if not None.
overwrite_input (bool) – not implemented by JAX; will error if not False.
method (str) – the interpolation method to use. One of ["linear", "lower", "higher", "midpoint", "nearest"]. Default is "linear".
keepdims (bool) – if True, then the returned array will have the same number of dimensions as the input. Default is False.
interpolation (str | DeprecatedArg) – deprecated alias of the method argument. Will result in a DeprecationWarning if used.
- Returns:
An array containing the specified percentiles along the specified axes.
- Return type:
See also
jax.numpy.quantile(): compute the quantile (0.0-1.0)
jax.numpy.nanpercentile(): compute the percentile while ignoring NaNs
Examples
Computing the median and quartiles of a 1D array:
>>> x = jnp.array([0, 1, 2, 3, 4, 5, 6])
>>> q = jnp.array([25, 50, 75])
>>> jnp.percentile(x, q)
Array([1.5, 3. , 4.5], dtype=float32)
Computing the same percentiles with nearest rather than linear interpolation:
>>> jnp.percentile(x, q, method='nearest')
Array([1., 3., 4.], dtype=float32)
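The axis and keepdims parameters described above can be illustrated with a small 2D input. The following is a minimal sketch, not part of the original examples: the array `data` and the result names are made up for illustration, and it assumes the default `method='linear'`.

```python
import jax.numpy as jnp

# Illustrative 2D input (hypothetical data): 2 rows, 4 columns.
data = jnp.array([[1., 2., 3., 4.],
                  [5., 6., 7., 8.]])

# 50th percentile (median) of each column: reduce over axis 0.
col_median = jnp.percentile(data, 50, axis=0)                  # shape (4,)

# Median of each row; keepdims=True keeps the reduced axis with size 1.
row_median = jnp.percentile(data, 50, axis=1, keepdims=True)   # shape (2, 1)

# Several percentiles at once: the leading axis of the result indexes q.
quartiles = jnp.array([25, 50, 75])
row_quartiles = jnp.percentile(data, quartiles, axis=1)        # shape (3, 2)

# jnp.quantile takes the same request on the 0.0-1.0 scale.
via_quantile = jnp.quantile(data, quartiles / 100, axis=1)
```

With a 1-dimensional q, the requested percentiles run along the first axis of the output, so `row_quartiles[:, 0]` holds the three quartiles of the first row.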