jax.numpy.argsort

jax.numpy.argsort(a, axis=-1, *, kind=None, order=None, stable=True, descending=False)

Return indices that sort an array.

JAX implementation of numpy.argsort().
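The following sketch checks the correspondence with NumPy by comparing jnp.argsort against numpy.argsort on the same input. It assumes NumPy is installed, and it passes kind="stable" to NumPy. JAX sorts stably by default, but NumPy's default algorithm does not promise any particular order for equal elements.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.array([1, 3, 5, 4, 2, 1])

# NumPy's default kind="quicksort" is not guaranteed to be stable, so request
# a stable sort explicitly to match JAX's default of stable=True.
np_indices = np.argsort(x_np, kind="stable")
jax_indices = jnp.argsort(jnp.asarray(x_np))

print(np_indices)               # [0 5 4 1 3 2]
print(np.asarray(jax_indices))  # [0 5 4 1 3 2]
```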

Parameters:
  • a (Array | ndarray | bool_ | number | bool | int | float | complex) – array to sort.

  • axis (int | None) – integer axis along which to sort. Defaults to -1, i.e. the last axis. If None, then a is flattened before being sorted.

  • stable (bool) – boolean specifying whether to use a stable sort, which preserves the relative order of equal elements. Default=True.

  • descending (bool) – boolean specifying whether to sort in descending order. Default=False.

  • kind (None) – deprecated; specify the sort algorithm with stable=True or stable=False instead.

  • order (None) – not supported by JAX.
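The stable and descending parameters can be illustrated with an input that contains ties. This is a small sketch: the exact indices are asserted only for the default stable ascending sort, where stability fixes the order of ties. For descending=True and stable=False, only the resulting sorted values are shown.

```python
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1, 3])

# Default: ascending and stable, so the two 1s (positions 1 and 3) and
# the two 3s (positions 0 and 4) keep their original relative order.
asc = jnp.argsort(x)  # values [1, 3, 2, 0, 4]

# descending=True orders the values from largest to smallest.
desc = jnp.argsort(x, descending=True)
print(x[desc])  # values [3, 3, 2, 1, 1]

# stable=False makes no promise about the relative order of equal elements;
# the gathered values are still in ascending order.
unstable = jnp.argsort(x, stable=False)
print(x[unstable])  # values [1, 1, 2, 3, 3]
```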

Returns:

Array of indices that sort a along the given axis. The result has shape a.shape if axis is an integer, or shape (a.size,) if axis is None.

Return type:

Array
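To make the two possible result shapes concrete, the sketch below sorts a 2-D array with axis=None and with an integer axis. With axis=None the indices refer to positions in the flattened array, so the sketch uses jnp.unravel_index to map them back to (row, column) coordinates.

```python
import jax.numpy as jnp

x = jnp.array([[2, 1, 3],
               [6, 4, 3]])

# axis=None: x is flattened first, so the result has shape (x.size,).
flat_idx = jnp.argsort(x, axis=None)
print(flat_idx.shape)  # (6,)

# The indices point into x.ravel(), not into x itself.
flat_sorted = x.ravel()[flat_idx]  # values [1, 2, 3, 3, 4, 6]

# Convert flat indices back to (row, col) coordinates of the original array.
rows, cols = jnp.unravel_index(flat_idx, x.shape)
same_values = x[rows, cols]  # values [1, 2, 3, 3, 4, 6]

# An integer axis keeps the input's shape.
print(jnp.argsort(x, axis=1).shape)  # (2, 3)
```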

Examples

Simple 1-dimensional sort:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
>>> indices = jnp.argsort(x)
>>> indices
Array([0, 5, 4, 1, 3, 2], dtype=int32)
>>> x[indices]
Array([1, 1, 2, 3, 4, 5], dtype=int32)
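Because argsort returns indices rather than values, the same permutation can reorder any array aligned with x. The sketch below first confirms that gathering with the indices matches jnp.sort, then applies them to a second array; labels is hypothetical example data introduced only for illustration.

```python
import jax.numpy as jnp

x = jnp.array([1, 3, 5, 4, 2, 1])
indices = jnp.argsort(x)

# Gathering x with its own sort indices gives the same result as jnp.sort(x).
print(jnp.array_equal(x[indices], jnp.sort(x)))  # True

# Hypothetical companion data: one label per element of x.
labels = jnp.array([10, 30, 50, 40, 20, 11])
# Reorder labels so they follow the ascending order of x.
sorted_labels = labels[indices]  # values [10, 11, 20, 30, 40, 50]
```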

Sort along the last axis of an array:

>>> x = jnp.array([[2, 1, 3],
...                [6, 4, 3]])
>>> indices = jnp.argsort(x, axis=1)
>>> indices
Array([[1, 0, 2],
       [2, 1, 0]], dtype=int32)
>>> jnp.take_along_axis(x, indices, axis=1)
Array([[1, 2, 3],
       [3, 4, 6]], dtype=int32)
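The same pattern works along other axes. This sketch sorts each column with axis=0 and gathers the values with jnp.take_along_axis; the last column contains a tie, which the default stable sort leaves in its original order. It also wraps the computation in jax.jit, with axis passed as a Python constant inside the traced function.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[5, 1, 3],
               [2, 4, 3]])

# Sort each column independently: axis=0 runs down the rows.
indices = jnp.argsort(x, axis=0)  # values [[1, 0, 0], [0, 1, 1]]
sorted_cols = jnp.take_along_axis(x, indices, axis=0)  # [[2, 1, 3], [5, 4, 3]]

@jax.jit
def sort_columns(a):
    # Column-wise sort expressed with argsort plus take_along_axis.
    return jnp.take_along_axis(a, jnp.argsort(a, axis=0), axis=0)

jitted_result = sort_columns(x)
```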

See also