jax.numpy.lexsort

jax.numpy.lexsort(keys, axis=-1)
Sort a sequence of keys in lexicographic order.
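JAX implementation of numpy.lexsort().

Parameters:
- keys: a sequence of arrays to sort; all arrays must have the same shape. The last key in the sequence is used as the primary key.
- axis: the axis along which to sort (default: -1).

Returns:
An array of integers of shape keys[0].shape giving the indices of the entries in lexicographically-sorted order.

Return type:
Array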
See also
- jax.numpy.argsort(): indirectly sort a single array, returning the sorting indices.
- jax.lax.sort(): direct XLA sorting API.
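jax.lax.sort() can itself perform a lexicographic multi-key sort through its num_keys argument. The sketch below uses a hypothetical helper, lexsort_via_lax_sort, to express lexsort() on top of it. It assumes that lax.sort treats its first operand as the most significant key, so the keys are passed in reverse order. An iota operand is carried along to recover the permutation. This is only an illustration and is not necessarily how JAX implements lexsort().

```python
import jax.numpy as jnp
from jax import lax


def lexsort_via_lax_sort(keys, axis=-1):
    """Hypothetical helper: emulate jnp.lexsort with a single multi-key lax.sort."""
    keys = [jnp.asarray(k) for k in keys]
    shape = keys[0].shape
    dim = axis % len(shape)  # normalize a negative axis
    # Positions along the sort axis; after sorting they form the permutation.
    iota = lax.broadcasted_iota(jnp.int32, shape, dim)
    # lax.sort compares the first num_keys operands lexicographically, most
    # significant first, so the last lexsort key must be passed first.
    *_, perm = lax.sort(
        (*reversed(keys), iota), dimension=dim, is_stable=True, num_keys=len(keys)
    )
    return perm


key1 = jnp.array([4, 2, 3, 2, 5])
key2 = jnp.array([2, 1, 1, 2, 2])
print(lexsort_via_lax_sort([key1, key2]))  # [1 2 3 0 4]
```

The stable flag matters here. Entries that tie on every key keep their original relative order, which matches lexsort().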
Examples
lexsort() with a single key is equivalent to argsort():

>>> key1 = jnp.array([4, 2, 3, 2, 5])
>>> jnp.lexsort([key1])
Array([1, 3, 2, 0, 4], dtype=int32)
>>> jnp.argsort(key1)
Array([1, 3, 2, 0, 4], dtype=int32)
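Both functions sort stably, so the equivalence also holds when a key contains many ties. The randomized check below is a quick sketch. The seed, length, and value range are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp

# A small value range guarantees many duplicate entries in the key.
rng = jax.random.PRNGKey(0)
key = jax.random.randint(rng, shape=(1000,), minval=0, maxval=10)

# With a single key, lexsort should reproduce argsort exactly, ties included.
same = bool(jnp.array_equal(jnp.lexsort([key]), jnp.argsort(key)))
print(same)  # True
```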
With multiple keys, lexsort() uses the last key as the primary key:

>>> key2 = jnp.array([2, 1, 1, 2, 2])
>>> jnp.lexsort([key1, key2])
Array([1, 2, 3, 0, 4], dtype=int32)
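The last key acts as the primary key because a lexicographic sort can be built from stable single-key sorts. These sorts run from the least significant key (the first) to the most significant key (the last), which is the same idea as a least-significant-digit radix sort. In the sketch below, lexsort_via_argsort is a hypothetical helper. It handles 1-D keys only and relies on jnp.argsort being stable by default.

```python
import jax.numpy as jnp


def lexsort_via_argsort(keys):
    """Hypothetical helper: emulate jnp.lexsort for 1-D keys with stable argsorts."""
    # Start from the identity permutation.
    order = jnp.arange(keys[0].shape[0])
    # Stable-sort by each key in turn, least significant (first) to most
    # significant (last); earlier orderings survive among ties of later keys.
    for key in keys:
        order = order[jnp.argsort(key[order])]
    return order


key1 = jnp.array([4, 2, 3, 2, 5])
key2 = jnp.array([2, 1, 1, 2, 2])
print(lexsort_via_argsort([key1, key2]))  # [1 2 3 0 4]
print(jnp.lexsort([key1, key2]))          # [1 2 3 0 4]
```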
The meaning of the indices becomes clearer when printing the sorted keys:

>>> indices = jnp.lexsort([key1, key2])
>>> print(f"{key1[indices]}\n{key2[indices]}")
[2 3 2 4 5]
[1 1 2 2 2]
Notice that the elements of key2 appear in order, and within each run of duplicated values in key2, the corresponding elements of key1 also appear in order.

For multi-dimensional inputs, lexsort() defaults to sorting along the last axis:

>>> key1 = jnp.array([[2, 4, 2, 3],
...                   [3, 1, 2, 2]])
>>> key2 = jnp.array([[1, 2, 1, 3],
...                   [2, 1, 2, 1]])
>>> jnp.lexsort([key1, key2])
Array([[0, 2, 1, 3],
       [1, 3, 2, 0]], dtype=int32)
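As in the one-dimensional case, the returned indices can be used to gather the sorted keys. Because the indices refer to positions along the last axis, jnp.take_along_axis applies them row by row. This sketch repeats the 2-D keys from the example above so that it is self-contained.

```python
import jax.numpy as jnp

key1 = jnp.array([[2, 4, 2, 3],
                  [3, 1, 2, 2]])
key2 = jnp.array([[1, 2, 1, 3],
                  [2, 1, 2, 1]])
indices = jnp.lexsort([key1, key2])  # each row is sorted independently

# Gather each row of the keys in sorted order.
sorted_key2 = jnp.take_along_axis(key2, indices, axis=-1)
sorted_key1 = jnp.take_along_axis(key1, indices, axis=-1)
print(sorted_key2)
print(sorted_key1)
```

In each row, sorted_key2 is non-decreasing. Within each run of equal key2 values, sorted_key1 is also non-decreasing.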
A different sort axis can be chosen using the axis keyword; here we sort along the leading axis:

>>> jnp.lexsort([key1, key2], axis=0)
Array([[0, 1, 0, 1],
       [1, 0, 1, 0]], dtype=int32)
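Sorting along axis=0 should give the same result as transposing the keys, sorting along the last axis, and transposing the indices back. The sketch below checks this for 2-D keys only. For higher-dimensional inputs, jnp.moveaxis would take the place of the transpose.

```python
import jax.numpy as jnp

key1 = jnp.array([[2, 4, 2, 3],
                  [3, 1, 2, 2]])
key2 = jnp.array([[1, 2, 1, 3],
                  [2, 1, 2, 1]])

# Sort each column directly along the leading axis.
direct = jnp.lexsort([key1, key2], axis=0)

# Equivalent: turn columns into rows, sort along the last axis, transpose back.
via_transpose = jnp.lexsort([key1.T, key2.T]).T

print(bool(jnp.array_equal(direct, via_transpose)))  # True
```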