
jax.numpy.argsort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)

Return indices that sort an array.

JAX implementation of numpy.argsort().

Parameters:

  • a (Array | ndarray | bool | number | int | float | complex) – array to sort

  • axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.

  • stable (bool) – boolean specifying whether a stable sort should be used. Default=True.

  • descending (bool) – boolean specifying whether to sort in descending order. Default=False.

  • kind (None) – deprecated; instead specify sort algorithm using stable=True or stable=False.

  • order (None) – not supported by JAX


Returns:

Array of indices that sort the array. The returned array has shape a.shape if axis is an integer, or shape (a.size,) if axis is None.

Return type:

Array

Simple 1-dimensional sort:

>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> indices = jnp.argsort(x)
>>> indices
Array([0, 5, 4, 1, 3, 2], dtype=int32)
>>> x[indices]
Array([1, 1, 2, 3, 4, 5], dtype=int32)
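
The stable and descending parameters described above can be illustrated with the same input. This is a minimal sketch rather than part of the official examples; the variable names (desc_indices, desc_values, asc_indices) are chosen here for illustration only.

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])

# descending=True returns indices that order x from largest to smallest.
desc_indices = jnp.argsort(x, descending=True)
desc_values = x[desc_indices]  # values run 5, 4, 3, 2, 1, 1

# stable=True (the default) keeps equal elements in their original order:
# the two 1s sit at positions 0 and 5, so index 0 is listed before index 5.
asc_indices = jnp.argsort(x, stable=True)

print(desc_values)
print(asc_indices)
```

Passing stable=False allows a non-stable sort, in which case the relative order of the indices of equal elements should not be relied on.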

Sort along the last axis of an array:

>>> x = jnp.array([[2, 1, 3],
...                [6, 4, 3]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 0, 2],
       [2, 1, 0]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
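
The axis=None case from the parameter description can be sketched in the same way: the input is flattened before sorting, so the returned indices refer to the flattened array and have shape (a.size,). The names flat_indices and flat_sorted are illustrative, not part of the JAX API.

```python
import jax.numpy as jnp

x = jnp.array([[2, 1, 3],
               [6, 4, 3]])

# axis=None flattens x first, so the indices have shape (x.size,)
# and must be applied to the flattened array, not to x itself.
flat_indices = jnp.argsort(x, axis=None)
flat_sorted = x.ravel()[flat_indices]  # values run 1, 2, 3, 3, 4, 6

print(flat_indices.shape)
print(flat_sorted)
```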

See also