jax.lax.top_k
jax.lax.top_k(operand, k)
Returns top k values and their indices along the last axis of operand.
- Parameters:
operand (ArrayLike) – N-dimensional array of non-complex type.
k (int) – integer specifying the number of top entries to return.
- Returns:
A tuple (values, indices), where values is an array containing the top k values along the last axis, in descending order, and indices is an array containing the positions of those values within the last axis of operand.
- Return type:
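Because operand may be N-dimensional and the selection runs along the last axis, any leading axes act as batch axes: each row is handled independently and the last axis shrinks to length k. The sketch below illustrates this on a small 2-D array and checks two properties of the result under that reading: gathering operand at indices with jnp.take_along_axis recovers values, and values equal a descending sort of each row truncated to k entries. The input data is made up for illustration.

```python
import jax
import jax.numpy as jnp

# A 2-D operand: top_k works along the last axis, so each row is
# processed independently and the leading axis behaves like a batch axis.
x = jnp.array([[0.5, 2.0, -1.0, 3.0],
               [4.0, 0.0, 1.5, 2.5]])
k = 3
values, indices = jax.lax.top_k(x, k)  # both have shape (2, 3)

# Gathering x at `indices` along the last axis recovers `values`.
gathered = jnp.take_along_axis(x, indices, axis=-1)

# The values come back largest-first: they equal a descending sort of
# each row, truncated to k entries.
reference = jnp.flip(jnp.sort(x, axis=-1), axis=-1)[..., :k]

# values  -> [[3., 2., 0.5], [4., 2.5, 1.5]]
# indices -> [[3, 1, 0], [0, 3, 2]]
```

The indices array can therefore be reused to gather matching entries from any other array with the same leading shape (for example, labels associated with each score).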
Examples
Find the largest three values, and their indices, within an array:
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([9., 3., 6., 4., 10.])
>>> values, indices = jax.lax.top_k(x, 3)
>>> values
Array([10.,  9.,  6.], dtype=float32)
>>> indices
Array([4, 0, 2], dtype=int32)
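The signature documented here only selects the largest entries, and only along the last axis. The sketch below shows two common workarounds built solely on that behavior: selecting the k smallest entries by negating the input, and selecting along another axis by moving it to the end with jnp.moveaxis. The names bottom_k and top_k_along_axis are hypothetical helpers, not part of jax.lax. The negation trick assumes a floating-point operand: for signed integer types the most negative value cannot be negated without overflow, and unsigned values wrap around.

```python
import jax
import jax.numpy as jnp


def bottom_k(x, k):
    """Hypothetical helper: the k smallest values along the last axis.

    Negating x turns its smallest entries into the largest ones, so
    jax.lax.top_k(-x, k) selects them; negating the returned values
    restores the original sign. Results come out smallest-first.
    Assumes a floating-point dtype (see the note on integer overflow).
    """
    neg_values, indices = jax.lax.top_k(-x, k)
    return -neg_values, indices


def top_k_along_axis(x, k, axis):
    """Hypothetical helper: top-k along an arbitrary axis.

    Moves `axis` to the end, applies jax.lax.top_k there, then moves
    the resulting length-k axis back to position `axis`.
    """
    moved = jnp.moveaxis(x, axis, -1)
    vals, idx = jax.lax.top_k(moved, k)
    return jnp.moveaxis(vals, -1, axis), jnp.moveaxis(idx, -1, axis)


x = jnp.array([9., 3., 6., 4., 10.])
small_values, small_indices = bottom_k(x, 2)
# small_values -> [3., 4.], small_indices -> [1, 3]

m = jnp.array([[1., 5., 3., 4.],
               [8., 2., 7., 6.]])
col_values, col_indices = top_k_along_axis(m, 1, axis=0)
# col_values -> [[8., 5., 7., 6.]], col_indices -> [[1, 0, 1, 1]]
```

Moving the axis rather than transposing by hand keeps the helper valid for any number of dimensions, and the output keeps the input's axis order with the selected axis reduced to length k.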