jax.numpy.bincount
- jax.numpy.bincount(x, weights=None, minlength=0, *, length=None)
Count the number of occurrences of each value in an integer array.
JAX implementation of numpy.bincount().

For an array of non-negative integers x, this function returns an array counts of size x.max() + 1 (or minlength, if that is larger), such that counts[i] contains the number of occurrences of the value i in x.

The JAX version has a few differences from the NumPy version:
- In NumPy, passing an array x with negative entries results in an error. In JAX, negative values are clipped to zero, so they are counted in bin 0.
- JAX adds an optional length parameter, which can be used to statically specify the length of the output array so that this function can be used with transformations like jax.jit(). In this case, values greater than or equal to length are dropped.
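The clip-then-drop behavior described above can be sketched with JAX's indexed-update API. The helper bincount_sketch below is hypothetical and is not JAX's actual implementation; it assumes a non-empty 1-D integer x and only aims to reproduce the documented semantics: negative values go to bin 0, and values at or past the output length are discarded.

```python
import jax.numpy as jnp


def bincount_sketch(x, weights=None, minlength=0, length=None):
    """Hypothetical re-implementation of the documented jnp.bincount semantics.

    Assumes x is a non-empty 1-D integer array.
    """
    x = jnp.asarray(x)
    # Each entry contributes 1 unless explicit weights are given.
    weights = jnp.ones_like(x) if weights is None else jnp.asarray(weights)
    # Negative values are clipped to zero, so they land in bin 0.
    idx = jnp.clip(x, 0)
    if length is None:
        # Data-dependent output size: only possible outside jit, because it
        # needs the concrete value of the largest (clipped) entry.
        length = max(minlength, int(idx.max()) + 1)
    # Scatter-add each weight into its bin; mode="drop" silently discards
    # indices >= length instead of raising an error or clamping them.
    return jnp.zeros(length, dtype=weights.dtype).at[idx].add(weights, mode="drop")
```

For example, bincount_sketch(jnp.array([-1, -1, 1, 3, 10]), length=5) gives [2, 1, 0, 1, 0], which agrees with jnp.bincount on the same input.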
- Parameters:
x (ArrayLike) – 1-dimensional array of non-negative integers.
weights (ArrayLike | None) – optional array of weights associated with x. If not specified, the weight for each entry will be 1.
minlength (int) – the minimum length of the output counts array.
length (int | None) – the length of the output counts array. Must be specified statically for bincount to be used with jax.jit() and other JAX transformations.
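As a quick illustration of the minlength and weights parameters: minlength only pads the output with empty bins (it never truncates it), and with weights each bin holds the sum of the weights of its entries rather than a count. The input values here are arbitrary.

```python
import jax.numpy as jnp

x = jnp.array([0, 1, 1, 3])

# minlength pads the output with zero-count bins beyond x.max() + 1.
counts = jnp.bincount(x, minlength=6)
# -> [1, 2, 0, 1, 0, 0]

# A minlength smaller than x.max() + 1 has no effect: the output still has 4 bins.
short = jnp.bincount(x, minlength=2)

# With weights, each bin holds the sum of the weights of its entries
# instead of a count, and the output takes the weights' dtype.
w = jnp.array([0.5, 0.25, 0.25, 2.0])
sums = jnp.bincount(x, weights=w)
# -> [0.5, 0.5, 0.0, 2.0]
```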
- Returns:
An array of counts or summed weights reflecting the number of occurrences of values in x.
- Return type:
Array
Examples
Basic bincount:
>>> x = jnp.array([1, 1, 2, 3, 3, 3])
>>> jnp.bincount(x)
Array([0, 2, 1, 3], dtype=int32)
Weighted bincount:
>>> weights = jnp.array([1, 2, 3, 4, 5, 6])
>>> jnp.bincount(x, weights)
Array([ 0,  3,  3, 15], dtype=int32)
Specifying a static length makes this jit-compatible:
>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
>>> jit_bincount(x, length=5)
Array([0, 2, 1, 3, 0], dtype=int32)
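Because a static length fixes the output shape, bincount can also be composed with other transformations such as jax.vmap, for example to build one histogram per row of a batch. This is a sketch: row_counts is a hypothetical helper, and the choice of 4 bins is arbitrary.

```python
import jax
import jax.numpy as jnp


# Hypothetical helper: a per-row histogram with a fixed number of bins.
# The static length keeps every row's output the same shape, which vmap needs.
def row_counts(row):
    return jnp.bincount(row, length=4)


batch = jnp.array([[0, 1, 1, 3],
                   [2, 2, 2, 0]])

# vmap maps row_counts over the leading axis; jit compiles the batched function.
batched = jax.jit(jax.vmap(row_counts))
result = batched(batch)
# -> [[1, 2, 0, 1],
#     [1, 0, 3, 0]]
```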
Any negative numbers are clipped to the first bin, and numbers greater than or equal to the specified length are dropped:
>>> x = jnp.array([-1, -1, 1, 3, 10])
>>> jnp.bincount(x, length=5)
Array([2, 1, 0, 1, 0], dtype=int32)