jax.numpy.argpartition
- jax.numpy.argpartition(a, kth, axis=-1)
Returns indices that partially sort an array.
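JAX implementation of numpy.argpartition(). The JAX version differs from NumPy in the treatment of NaN entries: NaNs with the sign bit set (negative NaNs) are sorted to the beginning of the array.
- Parameters:
  - a: array to be partitioned.
  - kth: static integer index about which to partition a.
  - axis: integer axis along which to partition a (default -1, the last axis).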
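- Returns:
Indices which partition a around its pivot along axis. The pivot is the value that would sit at position kth if a were fully sorted along axis, i.e. jnp.take(jnp.sort(a, axis=axis), kth, axis=axis). The entries before kth are indices of values no larger than the pivot, and the entries after kth are indices of values no smaller than the pivot.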
- Return type:
Array
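As a sketch of what this guarantee means along an arbitrary axis, the hypothetical helper is_valid_partition below (it is not part of jax.numpy) checks that the returned indices form a permutation along the axis and that the values on each side of position kth compare correctly with the pivot, computed here from a fully sorted copy. It assumes a numeric array without NaNs.

```python
import jax.numpy as jnp


def is_valid_partition(a, idx, kth, axis=-1):
    """Hypothetical checker for the argpartition guarantee (not a JAX API).

    Assumes `a` is numeric, NaN-free, and `kth` is a valid index along `axis`.
    """
    # Move the partition axis last so [..., :kth] slicing works for any axis.
    a_last = jnp.moveaxis(a, axis, -1)
    idx_last = jnp.moveaxis(idx, axis, -1)
    # Reorder the values according to the returned indices.
    part = jnp.take_along_axis(a_last, idx_last, axis=-1)
    # Pivot: entry kth of a fully sorted copy, kept as a size-1 axis so it
    # broadcasts against the slices below.
    pivot = jnp.sort(a_last, axis=-1)[..., kth:kth + 1]
    # Each row of indices must be a permutation of 0..n-1.
    n = a_last.shape[-1]
    is_perm = jnp.all(jnp.sort(idx_last, axis=-1) == jnp.arange(n))
    before_ok = jnp.all(part[..., :kth] <= pivot)
    at_ok = jnp.all(part[..., kth:kth + 1] == pivot)
    after_ok = jnp.all(part[..., kth + 1:] >= pivot)
    return bool(is_perm & before_ok & at_ok & after_ok)


a = jnp.array([[6, 8, 4, 3, 1],
               [9, 7, 5, 2, 3],
               [0, 0, 1, 1, 0]])
# Partition each row around position kth=2 (axis=1) ...
rows_ok = is_valid_partition(a, jnp.argpartition(a, 2, axis=1), 2, axis=1)
# ... and each column around position kth=1 (axis=0).
cols_ok = is_valid_partition(a, jnp.argpartition(a, 1, axis=0), 1, axis=0)
print(rows_ok, cols_ok)
```

The last row contains repeated values, which is why the checker uses "no larger than" and "no smaller than" comparisons rather than strict ones.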
Note
The JAX version requires the kth argument to be a static integer rather than a general array. This is implemented via two calls to jax.lax.top_k(). If you’re only accessing the top or bottom k values of the output, it may be more efficient to call jax.lax.top_k() directly.
See also
- jax.numpy.partition(): direct partial sort
- jax.numpy.argsort(): full indirect sort
- jax.lax.top_k(): directly find the top k entries
- jax.lax.approx_max_k(): compute the approximate top k entries
- jax.lax.approx_min_k(): compute the approximate bottom k entries
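To illustrate how a partition can be assembled from two jax.lax.top_k() calls, here is a minimal 1-D sketch. The function argpartition_1d_sketch is hypothetical and is not JAX's actual implementation, which also handles arbitrary axes, unsigned integers, and NaNs. It assumes a NaN-free, signed-integer or floating-point 1-D input and a static kth with 0 <= kth < len(x) - 1. The second part follows the Note above: when only the k smallest values are needed, a single top_k() call on the negated array is enough.

```python
import jax
import jax.numpy as jnp


def argpartition_1d_sketch(x, kth):
    """Hypothetical 1-D argpartition built from two jax.lax.top_k calls.

    Assumes a NaN-free signed-integer or float `x` and a static Python int
    with 0 <= kth < x.shape[0] - 1.
    """
    n = x.shape[0]
    # Call 1: top_k of the negated array returns the indices of the kth + 1
    # smallest values in ascending order of value, so position kth is the pivot.
    _, bottom_idx = jax.lax.top_k(-x, kth + 1)
    # Mark the already-selected positions with 0 and all others with 1.
    mask = jnp.ones(n, dtype=jnp.float32).at[bottom_idx].set(0.0)
    # Call 2: the n - kth - 1 positions still marked 1 are the remaining indices.
    _, top_idx = jax.lax.top_k(mask, n - kth - 1)
    return jnp.concatenate([bottom_idx, top_idx])


x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
idx = argpartition_1d_sketch(x, 4)
print(x[idx])

# Only the 4 smallest values are needed: one top_k call on -x suffices.
neg_vals, smallest_idx = jax.lax.top_k(-x, 4)
print(-neg_vals)  # [1 2 3 3]
```

Using a 0/1 mask for the second call, rather than the values themselves, means duplicates of the pivot value cannot cause an index to be selected twice.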
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
>>> kth = 4
>>> idx = jnp.argpartition(x, kth)
>>> idx
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
The result is a sequence of indices that partially sort the input. All indices before kth are of values no larger than the pivot value, and all indices after kth are of values no smaller than the pivot value:
>>> x_partitioned = x[idx]
>>> smallest_values = x_partitioned[:kth]
>>> pivot_value = x_partitioned[kth]
>>> largest_values = x_partitioned[kth + 1:]
>>> print(smallest_values, pivot_value, largest_values)
[1 2 3 3] 4 [6 8 9 7 5]
Notice that within smallest_values and largest_values, the returned order is arbitrary and implementation-dependent.
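If a deterministic order is needed on one side of the partition, one option (a sketch, not something argpartition does itself) is to sort just that slice afterwards with jax.numpy.argsort(). The names head and head_sorted are illustrative only. Because only kth + 1 elements are sorted, the extra work depends on kth rather than on the full length of x.

```python
import jax.numpy as jnp

x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
kth = 4
idx = jnp.argpartition(x, kth)

# Indices of the kth + 1 smallest values (positions 0..kth), in arbitrary order.
head = idx[:kth + 1]
# Sort only this slice by value to get a canonical ascending order.
head_sorted = head[jnp.argsort(x[head])]
print(x[head_sorted])  # [1 2 3 3 4]
```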