jax.numpy.sort
jax.numpy.sort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)
Return a sorted copy of an array.
JAX implementation of numpy.sort().
- Parameters:
a (Array | ndarray | bool | number | bool | int | float | complex) – array to sort
axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.
stable (bool) – boolean specifying whether a stable sort should be used, i.e. whether equal elements keep their original relative order. Default=True.
descending (bool) – boolean specifying whether to sort in descending order. Default=False.
kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.
order (None) – not supported by JAX
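The axis, descending and stable parameters above can be exercised with a short sketch. The 2x3 input is an arbitrary array chosen for illustration, and the expected values in the comments come from sorting each row by hand.

```python
import jax.numpy as jnp

# Arbitrary 2x3 example input (illustrative only).
x = jnp.array([[5, 1, 3],
               [2, 6, 4]])

# Default: sort along the last axis (each row), smallest value first.
ascending = jnp.sort(x)                     # [[1, 3, 5], [2, 4, 6]]

# descending=True puts the largest value first in each row.
descending = jnp.sort(x, descending=True)   # [[5, 3, 1], [6, 4, 2]]

# axis=None flattens x before sorting, giving a 1-D result.
flat = jnp.sort(x, axis=None)               # [1, 2, 3, 4, 5, 6]

# stable=False permits an unstable algorithm. For integer input, equal
# elements are indistinguishable, so the sorted values match the stable sort.
unstable = jnp.sort(x, stable=False)
```

Stability only changes which of several equal elements lands where, so its effect is easiest to observe through jax.numpy.argsort(), which accepts the same stable flag.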
- Returns:
Sorted array of shape a.shape (if axis is an integer) or of shape (a.size,) (if axis is None).
- Return type:
Array
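The shape rule in the Returns entry can be checked directly. This is a minimal sketch on a made-up 3x4 array: an integer axis preserves a.shape, while axis=None yields a 1-D array of length a.size.

```python
import jax.numpy as jnp

# Hypothetical 3x4 input with a few repeated values.
a = jnp.array([[9, 2, 7, 2],
               [0, 5, 1, 8],
               [3, 3, 6, 4]])

# Integer axis: the result keeps the input's shape.
rows = jnp.sort(a, axis=1)   # each row sorted independently
cols = jnp.sort(a, axis=0)   # each column sorted independently

# axis=None: the input is flattened, so the result has shape (a.size,).
flat = jnp.sort(a, axis=None)

print(rows.shape, cols.shape, flat.shape)   # (3, 4) (3, 4) (12,)
```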
Examples
Simple 1-dimensional sort:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> jnp.sort(x)
Array([1, 1, 2, 3, 4, 5], dtype=int32)
Sort along the last axis of an array:
>>> x = jnp.array([[2, 1, 3],
...                [4, 3, 6]])
>>> jnp.sort(x, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
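A few further cases, written as a sketch with illustrative inputs: sorting along the first axis, calling jnp.sort inside jax.jit, and sorting floating-point data containing NaN. Placing NaN at the end matches NumPy's documented behaviour; treat it as an assumption to confirm against the jax.lax.sort documentation for your JAX version.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[4, 1, 6],
               [2, 3, 5]])

# axis=0 sorts each column independently.
by_column = jnp.sort(x, axis=0)   # [[2, 1, 5], [4, 3, 6]]

# jnp.sort works inside jit-compiled code. The axis is fixed inside the
# wrapped function so that it stays a static Python int, not a traced value.
sort_columns = jax.jit(lambda v: jnp.sort(v, axis=0))
jitted = sort_columns(x)

# Floating-point input: NaN values are placed at the end of the sort.
f = jnp.array([3.0, jnp.nan, 1.0])
f_sorted = jnp.sort(f)
```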
See also
jax.numpy.argsort(): return indices of sorted values.
jax.numpy.lexsort(): lexicographical sort of multiple arrays.
jax.lax.sort(): lower-level function wrapping XLA’s Sort operator.
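The related functions listed above can be compared in a short sketch on the input from the first example. The two lexsort key arrays are made-up values, chosen only to show that the last key in the tuple is the primary sort key (the same convention NumPy uses).

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])

# argsort returns the permutation that sorts x; indexing x with it
# reproduces jnp.sort(x). argsort is stable by default, so the two 1s
# (at positions 0 and 5) keep their original order.
idx = jnp.argsort(x)          # [0, 5, 4, 1, 3, 2]
via_argsort = x[idx]

# jax.lax.sort sorts along one dimension (the last by default).
via_lax = jax.lax.sort(x)

# lexsort sorts by several keys; the LAST key in the tuple is the primary one.
primary = jnp.array([1, 0, 1, 0])     # hypothetical primary key
secondary = jnp.array([3, 2, 1, 0])   # hypothetical tie-breaking key
order = jnp.lexsort((secondary, primary))   # [3, 1, 2, 0]
```

Indexing with argsort costs an extra gather compared with jnp.sort, so it is mainly useful when the permutation itself is needed, for example to reorder a second array by the values of the first.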