jax.numpy.unique_all
- jax.numpy.unique_all(x, /, *, size=None, fill_value=None)
Return unique values from x, along with indices, inverse indices, and counts.
JAX implementation of numpy.unique_all(); this is equivalent to calling jax.numpy.unique() with return_index, return_inverse, return_counts, and equal_nan set to True.
Because the size of the output of unique_all is data-dependent, the function is not typically compatible with jit() and other JAX transformations. The JAX version adds the optional size argument, which must be specified statically for jnp.unique_all to be used in such contexts.
- Parameters:
x (ArrayLike) – N-dimensional array from which unique values will be extracted.
size (int | None) – if specified, return only the first size sorted unique elements. If there are fewer unique elements than size indicates, the return value will be padded with fill_value.
fill_value (ArrayLike | None) – when size is specified and there are fewer than the indicated number of elements, fill the remaining entries with fill_value. Defaults to the minimum unique value.
- Returns:
A tuple (values, indices, inverse_indices, counts), with the following properties:
values: an array of shape (n_unique,) containing the unique values from x.
indices: an array of shape (n_unique,). Contains the indices of the first occurrence of each unique value in x. For 1D inputs, x[indices] is equivalent to values.
inverse_indices: an array of shape x.shape. Contains the indices within values of each value in x. For 1D inputs, values[inverse_indices] is equivalent to x.
counts: an array of shape (n_unique,). Contains the number of occurrences of each unique value in x.
See also
jax.numpy.unique(): general function for computing unique values.
jax.numpy.unique_values(): compute only values.
jax.numpy.unique_counts(): compute only values and counts.
jax.numpy.unique_inverse(): compute only values and inverse.
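As a rough sketch of how these functions relate, the snippet below compares the fields returned by unique_all with the outputs of jax.numpy.unique() and the narrower helpers. The input array is an arbitrary illustration, and equal_nan is left at its default because the data contains no NaNs.

```python
import jax.numpy as jnp

x = jnp.array([3, 4, 1, 3, 1])  # illustrative input, contains no NaNs

full = jnp.unique_all(x)

# The general function, asked for the same three extra outputs.
values, indices, inverse, counts = jnp.unique(
    x, return_index=True, return_inverse=True, return_counts=True
)

same_as_unique = (
    bool(jnp.array_equal(full.values, values))
    and bool(jnp.array_equal(full.indices, indices))
    and bool(jnp.array_equal(full.inverse_indices, inverse))
    and bool(jnp.array_equal(full.counts, counts))
)

# The narrower helpers return subsets of the same information.
only_values = jnp.unique_values(x)
values_and_counts = jnp.unique_counts(x)
values_and_inverse = jnp.unique_inverse(x)
```

When only one or two of the outputs are needed, the narrower helpers avoid computing the rest.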
Examples
Here we compute the unique values in a 1D array:
>>> import jax.numpy as jnp
>>> x = jnp.array([3, 4, 1, 3, 1])
>>> result = jnp.unique_all(x)
The result is a NamedTuple with four named attributes. The values attribute contains the unique values from the array:
>>> result.values
Array([1, 3, 4], dtype=int32)
The indices attribute contains the indices of the unique values within the input array:
>>> result.indices
Array([2, 0, 1], dtype=int32)
>>> jnp.all(result.values == x[result.indices])
Array(True, dtype=bool)
The inverse_indices attribute contains the indices of the input within values:
>>> result.inverse_indices
Array([1, 2, 0, 1, 0], dtype=int32)
>>> jnp.all(x == result.values[result.inverse_indices])
Array(True, dtype=bool)
The counts attribute contains the counts of each unique value in the input:
>>> result.counts
Array([2, 2, 1], dtype=int32)
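The shapes described under Returns are easier to see with a multi-dimensional input. The following is a minimal sketch using an arbitrary 2D array; it assumes that indices refers to positions in the flattened input, following NumPy's convention for unique with no axis argument.

```python
import jax.numpy as jnp

x2d = jnp.array([[3, 4, 1],
                 [3, 1, 1]])  # illustrative 2D input
result2d = jnp.unique_all(x2d)

# values, indices, and counts are one-dimensional, with one entry per unique value.
n_unique = result2d.values.shape[0]

# inverse_indices has the shape of the input, so indexing values with it
# reconstructs the original array.
roundtrip = result2d.values[result2d.inverse_indices]

# Assumption: indices point into the flattened input.
first_occurrences = x2d.ravel()[result2d.indices]
```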
For examples of the size and fill_value arguments, see jax.numpy.unique().
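For unique_all itself, a minimal sketch of using a static size under jit() might look like the following. The wrapper name unique_all_fixed and the fill value of -1 are hypothetical choices made for illustration, not part of the JAX API.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical wrapper: size is marked static so the output shape is known at trace time.
@partial(jax.jit, static_argnames="size")
def unique_all_fixed(x, size):
    # fill_value=-1 is an illustrative sentinel for the padded entries of values.
    return jnp.unique_all(x, size=size, fill_value=-1)


x = jnp.array([3, 4, 1, 3, 1])

# More slots than unique values: values is padded up to length 5.
padded = unique_all_fixed(x, size=5)

# Fewer slots than unique values: only the first 2 sorted unique values are kept.
truncated = unique_all_fixed(x, size=2)
```

Each distinct size triggers a separate compilation, since it is part of the static signature of the jitted function.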