jax.numpy.argsort
jax.numpy.argsort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)
Return indices that sort an array.
JAX implementation of numpy.argsort().

Parameters:
a (Array | ndarray | bool | number | int | float | complex) – array to sort.
axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.
stable (bool) – boolean specifying whether a stable sort should be used. Default=True.
descending (bool) – boolean specifying whether to sort in descending order. Default=False.
kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.
order (None) – not supported by JAX.
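A short sketch of how descending and stable affect the result when the input contains ties. It assumes a JAX version recent enough to accept the descending keyword; the array x is purely illustrative. With stable=False, equal elements may come back in any relative order, so only the gathered values are guaranteed to be sorted.

```python
import jax.numpy as jnp

# Illustrative input with ties: two 1s (positions 1 and 3) and two 3s (positions 0 and 4).
x = jnp.array([3, 1, 2, 1, 3])

# Default: ascending and stable, so tied elements keep their original order.
asc = jnp.argsort(x)
print(asc)          # [1 3 2 0 4]

# descending=True puts the largest values first.
desc = jnp.argsort(x, descending=True)
print(x[desc])      # [3 3 2 1 1]

# stable=False permits an unstable sort; tie order is unspecified,
# but gathering with the indices still yields sorted values.
unstable = jnp.argsort(x, stable=False)
print(x[unstable])  # [1 1 2 3 3]
```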
Returns:
Array of indices that sort an array. The returned array has shape a.shape (if axis is an integer) or shape (a.size,) (if axis is None).
Return type:
Array
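The shape rule for axis=None can be illustrated with a small sketch: the input is flattened first, so the indices refer to positions in x.ravel(). The use of jnp.unravel_index to recover (row, column) coordinates is an illustrative addition, not something argsort does for you.

```python
import jax.numpy as jnp

x = jnp.array([[2, 1, 3],
               [6, 4, 3]])

# axis=None flattens x, so the result has shape (x.size,) and
# indexes into x.ravel() rather than into a single row.
flat_idx = jnp.argsort(x, axis=None)
print(flat_idx.shape)       # (6,)
print(x.ravel()[flat_idx])  # [1 2 3 3 4 6]

# Convert flat positions back to (row, col) coordinates of the 2-D array.
rows, cols = jnp.unravel_index(flat_idx, x.shape)
print(x[rows, cols])        # [1 2 3 3 4 6]
```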
Examples
Simple 1-dimensional sort:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> indices = jnp.argsort(x)
>>> indices
Array([0, 5, 4, 1, 3, 2], dtype=int32)
>>> x[indices]
Array([1, 1, 2, 3, 4, 5], dtype=int32)
Sort along the last axis of an array:
>>> x = jnp.array([[2, 1, 3],
...                [6, 4, 3]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 0, 2],
       [2, 1, 0]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
See also
jax.numpy.sort(): return sorted values directly.
jax.numpy.lexsort(): lexicographical sort of multiple arrays.
jax.lax.sort(): lower-level function wrapping XLA's Sort operator.
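To relate argsort to the functions listed above, the sketch below checks that gathering with the argsort indices reproduces jnp.sort, and shows one lower-level way to obtain the same indices: sorting the keys with jax.lax.sort_key_val while carrying an index array along as the values. This is an illustration of the idea, not a claim about how jnp.argsort is implemented internally.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 3, 5, 4, 2, 1])
idx = jnp.argsort(x)

# Gathering with the argsort indices gives the same values as jnp.sort.
print(jnp.array_equal(x[idx], jnp.sort(x)))  # True

# Lower-level route: sort the keys and carry positions 0..n-1 along as values.
# With is_stable=True, ties keep their original order, matching stable argsort.
_, idx_lax = lax.sort_key_val(x, jnp.arange(x.size), is_stable=True)
print(jnp.array_equal(idx_lax, idx))  # True
```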