jax.numpy.bincount

jax.numpy.bincount(x, weights=None, minlength=0, *, length=None)

Count the number of occurrences of each value in an integer array.

JAX implementation of numpy.bincount().

For an array of non-negative integers x, this function returns an array counts such that counts[i] contains the number of occurrences of the value i in x. When neither minlength nor length is given, counts has size x.max() + 1.
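
The relationship between x and counts can be illustrated with a minimal sketch that rebuilds the same result using an explicit scatter-add. The helper name bincount_reference is hypothetical and exists only for illustration; it is not necessarily how the library computes the result.

```python
import jax.numpy as jnp


def bincount_reference(x):
    # Hypothetical helper: one bin per value in 0..x.max(),
    # incremented once for every occurrence of that value.
    n_bins = int(x.max()) + 1
    return jnp.zeros(n_bins, dtype=jnp.int32).at[x].add(1)


x = jnp.array([0, 1, 1, 4])
counts = jnp.bincount(x)            # bins for the values 0, 1, 2, 3, 4
reference = bincount_reference(x)   # same counts via scatter-add
```

Here the largest value is 4, so both arrays have five entries, and the bins for the absent values 2 and 3 are zero.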

The JAX version has a few differences from the NumPy version:

  • In NumPy, passing an array x with negative entries will result in an error. In JAX, negative values are clipped to zero.

  • JAX adds an optional length parameter which can be used to statically specify the length of the output array so that this function can be used with transformations like jax.jit(). In this case, values greater than or equal to length will be dropped.
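
The two differences can be seen side by side in a short sketch. The input values are made up for illustration, and the variable names are not part of either library.

```python
import numpy as np
import jax.numpy as jnp

values = [-2, 0, 0, 3, 7]

# NumPy rejects negative entries with a ValueError.
try:
    np.bincount(np.array(values))
    numpy_raised = False
except ValueError:
    numpy_raised = True

# JAX clips -2 to 0, so the first bin receives three counts.
clipped = jnp.bincount(jnp.array(values))

# With a static length of 4 the output has bins 0..3,
# so the value 7 falls outside the output and is dropped.
fixed = jnp.bincount(jnp.array(values), length=4)
```

Because clipping and dropping happen silently, it may be worth validating inputs separately when out-of-range values would indicate a bug.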

Parameters:
  • x (ArrayLike) – 1-dimensional array of non-negative integers

  • weights (ArrayLike | None) – optional array of weights associated with x. If not specified, the weight for each entry will be 1.

  • minlength (int) – the minimum length of the output counts array.

  • length (int | None) – the length of the output counts array. Must be specified statically for bincount to be used with jax.jit() and other JAX transformations.
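
As a rough sketch of how minlength and length differ: minlength only sets a lower bound on the output size, while length fixes it. The example also assumes that tracing without a static length raises jax.errors.ConcretizationTypeError, since the output shape would then depend on the data. Binding length with functools.partial is one possible way to make it static and is used here only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.array([1, 1, 2, 3, 3, 3])

# minlength pads the output but never truncates it.
padded = jnp.bincount(x, minlength=6)     # 6 bins
unpadded = jnp.bincount(x, minlength=2)   # still x.max() + 1 = 4 bins

# Without a static length, the output shape cannot be determined under jit.
try:
    jax.jit(jnp.bincount)(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# Binding length ahead of time gives jit a fixed output shape.
counts = jax.jit(partial(jnp.bincount, length=6))(x)
```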

Returns:

An array of counts or summed weights reflecting the number of occurrences of values in x.

Return type:

Array

Examples

Basic bincount:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 1, 2, 3, 3, 3])
>>> jnp.bincount(x)
Array([0, 2, 1, 3], dtype=int32)

Weighted bincount:

>>> weights = jnp.array([1, 2, 3, 4, 5, 6])
>>> jnp.bincount(x, weights)
Array([ 0,  3,  3, 15], dtype=int32)
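
Weights do not have to be integers. The following sketch, with made-up labels and scores, combines a weighted and an unweighted call to compute a per-label mean; the variable names are illustrative only.

```python
import jax.numpy as jnp

labels = jnp.array([1, 1, 2, 3, 3, 3])
scores = jnp.array([0.5, 1.5, 4.0, 1.0, 2.0, 3.0])

sums = jnp.bincount(labels, weights=scores)   # per-label sum of scores
counts = jnp.bincount(labels)                 # per-label number of entries

# Label 0 never occurs, so guard against dividing by zero.
means = jnp.where(counts > 0, sums / jnp.maximum(counts, 1), 0.0)
```

With floating-point weights the summed result is floating point as well, whereas the unweighted counts stay integer.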

Specifying a static length makes this jit-compatible:

>>> jit_bincount = jax.jit(jnp.bincount, static_argnames=['length'])
>>> jit_bincount(x, length=5)
Array([0, 2, 1, 3, 0], dtype=int32)

Any negative numbers are clipped to the first bin, and numbers beyond the specified length are dropped:

>>> x = jnp.array([-1, -1, 1, 3, 10])
>>> jnp.bincount(x, length=5)
Array([2, 1, 0, 1, 0], dtype=int32)