jax.numpy.argpartition

jax.numpy.argpartition(a, kth, axis=-1)

Returns indices that partially sort an array.

JAX implementation of numpy.argpartition(). The JAX version differs from NumPy in the treatment of NaN entries: NaNs with the sign bit set (negative NaNs) are sorted to the beginning of the array, whereas NumPy places all NaNs at the end.

Parameters:
  • a (Array | ndarray | bool | number | int | float | complex) – array to be partitioned.

  • kth (int) – static integer index about which to partition the array.

  • axis (int) – static integer axis along which to partition the array; default is -1.
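Because kth and axis are static, a function that wraps jnp.argpartition under jax.jit can mark them with static_argnames so they stay plain Python ints during tracing. The sketch below is illustrative only (the helper name partition_along is hypothetical): it partitions a small 2-D array along each axis and gathers the reordered values with jnp.take_along_axis.

```python
import functools

import jax
import jax.numpy as jnp


# kth and axis are marked static: they fix the structure of the output,
# so they must be plain Python ints when the function is traced.
@functools.partial(jax.jit, static_argnames=("kth", "axis"))
def partition_along(a, kth, axis=-1):  # hypothetical helper
    idx = jnp.argpartition(a, kth, axis=axis)
    # Reorder the values along the same axis using the returned indices.
    return idx, jnp.take_along_axis(a, idx, axis=axis)


m = jnp.array([[6, 8, 4, 3, 1],
               [9, 7, 5, 2, 3]])
idx_rows, vals_rows = partition_along(m, 2)          # each row (axis=-1)
idx_cols, vals_cols = partition_along(m, 0, axis=0)  # each column
```

With kth=0 along an axis of length 2, each column's smaller value lands in row 0 and its larger value in row 1.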

Returns:

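Indices which partition a at the kth value along axis. Entries before position kth are indices of values less than or equal to the pivot, and entries after position kth are indices of values greater than or equal to it. The pivot, found at position kth of the result, is the value that would occupy position kth if a were fully sorted along axis; in general it is not take(a, kth, axis).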

Return type:

Array
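The pivot in the return-value description is the kth-smallest value (counting from zero), which is generally different from the input's own element at index kth. A quick check of that contract, using the same array as the examples and comparing against a full jnp.sort:

```python
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4
idx = jnp.argpartition(x, kth)
parts = x[idx]

# The pivot is the value a full sort puts at position kth (4 here),
# not x[kth] (which is 1 here).
pivot = jnp.sort(x)[kth]

assert parts[kth] == pivot
assert jnp.all(parts[:kth] <= pivot)      # left side: no larger than pivot
assert jnp.all(parts[kth + 1:] >= pivot)  # right side: no smaller than pivot
# idx is a permutation of 0 .. x.size - 1.
assert jnp.array_equal(jnp.sort(idx), jnp.arange(x.size))
```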

Note

The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
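To make the two-top_k strategy concrete, here is a minimal 1-D sketch of one way it can be done. It is not JAX's actual source, the name argpartition_1d is hypothetical, and it assumes a signed numeric input (so negation is safe) with 0 <= kth < len(a) - 1. The first call collects the kth + 1 smallest entries; a 0/1 proxy array then lets the second call collect every remaining position without picking any index twice, even when values are tied.

```python
import jax
import jax.numpy as jnp


def argpartition_1d(a, kth):
    """Hypothetical sketch: 1-D argpartition from two lax.top_k calls.

    Assumes a signed numeric 1-D array and 0 <= kth < a.shape[0] - 1.
    """
    n = a.shape[0]
    # Call 1: the kth + 1 largest entries of -a are the kth + 1 smallest
    # entries of a, returned smallest first, so the pivot comes last.
    _, bottom_ind = jax.lax.top_k(-a, kth + 1)
    # Proxy array: 0 at the positions already taken, 1 everywhere else.
    proxy = jnp.ones(n).at[bottom_ind].set(0)
    # Call 2: the remaining n - kth - 1 positions are exactly those set to 1.
    _, top_ind = jax.lax.top_k(proxy, n - kth - 1)
    return jnp.concatenate([bottom_ind, top_ind])


x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
idx = argpartition_1d(x, 4)
parts = x[idx]  # pivot 4 at position 4; values <= 4 before it, >= 4 after
```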


Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> idx = jnp.argpartition(x, kth)
>>> idx
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)

The result is a sequence of indices that partially sorts the input. All indices before position kth point to values no greater than the pivot value, and all indices after kth point to values no less than the pivot value:

>>> x_partitioned = x[idx]
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [6 8 9 7 5]

Notice that within smallest_values and within largest_values, the order of the elements is arbitrary and implementation-dependent.
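If the order inside a partition matters, two options are sketched below, assuming a signed numeric dtype so that negation is safe. The first sorts only the short slice left of the pivot. The second follows the Note's suggestion and calls jax.lax.top_k on the negated array: top_k returns its k entries largest first, so negating back yields the smallest values of x in ascending order along with their indices.

```python
import jax
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4

# Option 1: partition first, then sort only the short left-hand slice.
idx = jnp.argpartition(x, kth)
smallest_sorted = jnp.sort(x[idx][:kth])

# Option 2: a single lax.top_k call on -x; the largest entries of -x
# are the smallest entries of x, returned in ascending order of x.
neg_vals, small_idx = jax.lax.top_k(-x, kth)
smallest_via_top_k = -neg_vals
```

Option 2 skips the partition entirely, which is the case the Note describes where calling jax.lax.top_k directly may be more efficient.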