jax.numpy.histogramdd

jax.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)

Compute an N-dimensional histogram.

JAX implementation of numpy.histogramdd().

Parameters:
  • sample (ArrayLike) – input array of shape (N, D) representing N points in D dimensions.

  • bins (ArrayLike | list[ArrayLike]) – Either an integer giving the number of bins in every dimension (default: 10), or a length-D sequence containing, for each dimension, a number of bins or an array of bin edges.

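  • range (Sequence[None | Array | Sequence[ArrayLike]] | None) – Length-D sequence of (lower, upper) pairs giving the range of each dimension. If not specified (or if an individual entry is None), the range of that dimension is inferred from the minimum and maximum of the data along it. The range is ignored for any dimension whose bins are given as explicit edges.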

  • weights (ArrayLike | None) – An optional array of shape (N,) giving one weight per data point (that is, per row of sample). If not specified, each data point is weighted equally.

  • density (bool | None) – If True, return the normalized histogram in units of counts per unit volume. If False (the default), return the (weighted) counts per bin.
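The parameter forms above can be exercised on a tiny input. The sketch below uses four made-up 2-D points (the values of `sample`, `x_edges`, `y_edges`, and `w` are illustrative assumptions, not from the reference) to show per-dimension bin counts, explicit bin edges, and a weights array of shape (N,).

```python
import jax.numpy as jnp

# Four made-up 2-D points, one row per point: sample.shape == (N, D) == (4, 2).
sample = jnp.array([[0.1, 0.2],
                    [0.4, 0.9],
                    [0.6, 0.3],
                    [0.9, 0.8]])

# bins as one integer per dimension: 4 bins along x and 2 along y, each over [0, 1].
counts, edges = jnp.histogramdd(sample, bins=(4, 2), range=[(0, 1), (0, 1)])
print(counts.shape)  # (4, 2)

# bins as explicit edge arrays; no range is needed because the edges fix it.
x_edges = jnp.array([0.0, 0.5, 1.0])
y_edges = jnp.array([0.0, 0.5, 1.0])
counts2, _ = jnp.histogramdd(sample, bins=[x_edges, y_edges])
print(counts2)  # each of the four 2x2 cells holds exactly one point

# weights has one value per point, so its shape is (N,) == (4,), not (4, 2).
w = jnp.array([1.0, 2.0, 3.0, 4.0])
weighted, _ = jnp.histogramdd(sample, bins=[x_edges, y_edges], weights=w)
print(weighted)  # each cell now holds the weight of the point that fell in it
```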

Returns:

A tuple (histogram, bin_edges), where histogram is a D-dimensional array holding the (weighted) count in each bin, or the density if density=True, and bin_edges is a list of D arrays giving the bin boundaries along each dimension.

Return type:

tuple[Array, list[Array]]
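A quick check of how the two return values fit together: each dimension with k bins has k + 1 edges, so the histogram's shape can be read off the edge arrays. This sketch uses 50 made-up uniform points in two dimensions (the key and sizes are arbitrary assumptions).

```python
import jax
import jax.numpy as jnp

# 50 made-up points in D = 2 dimensions.
sample = jax.random.uniform(jax.random.key(0), (50, 2))

hist, bin_edges = jnp.histogramdd(sample, bins=(5, 3))

print(type(bin_edges), len(bin_edges))  # a Python list, one edge array per dimension
print([e.shape for e in bin_edges])     # [(6,), (4,)]: bins + 1 edges per dimension
print(hist.shape)                       # (5, 3): one entry per bin
print(int(hist.sum()))                  # 50: the inferred ranges cover every point
```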


Examples

A histogram over 100 points in three dimensions:

>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.key(42)
>>> a = jax.random.normal(key, (100, 3))
>>> counts, bin_edges = jnp.histogramdd(a, bins=6,
...                                     range=[(-3, 3), (-3, 3), (-3, 3)])
>>> counts.shape
(6, 6, 6)
>>> bin_edges  
[Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32)]
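Because range fixes the outer bin edges, points with any coordinate outside it are not counted in any bin. The sketch below rebuilds the same 100-point sample as the example above and compares the total count with the number of points lying entirely inside [-3, 3]; the variable name `in_range` is illustrative.

```python
import jax
import jax.numpy as jnp

# Same setup as the example above: 100 standard-normal points in 3-D.
key = jax.random.key(42)
a = jax.random.normal(key, (100, 3))
counts, bin_edges = jnp.histogramdd(a, bins=6,
                                    range=[(-3, 3), (-3, 3), (-3, 3)])

# A point is counted only if all three coordinates lie within [-3, 3].
in_range = jnp.all((a >= -3) & (a <= 3), axis=1)
print(int(counts.sum()), int(in_range.sum()))  # the two totals agree
```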

Using density=True returns a normalized histogram, whose values multiplied by the corresponding bin volumes sum to 1:

>>> density, bin_edges = jnp.histogramdd(a, density=True)
>>> bin_widths = map(jnp.diff, bin_edges)
>>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij')
>>> normed = jnp.sum(density * dx * dy * dz)
>>> jnp.allclose(normed, 1.0)
Array(True, dtype=bool)
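The normalization can also be reproduced by hand: each density value equals the bin's count divided by the total count and by the bin's volume. This sketch reuses the 100-point sample with the default 10 bins per dimension; `volume` and `manual` are illustrative names, and the comparison uses `jnp.allclose` to tolerate float32 rounding.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
a = jax.random.normal(key, (100, 3))

counts, edges = jnp.histogramdd(a)               # raw counts, default bins=10
density, _ = jnp.histogramdd(a, density=True)    # normalized histogram

# Volume of each bin: outer product of the per-dimension bin widths.
dx, dy, dz = jnp.meshgrid(*[jnp.diff(e) for e in edges], indexing='ij')
volume = dx * dy * dz

# density = counts / (total count * bin volume)
manual = counts / (counts.sum() * volume)
print(jnp.allclose(density, manual))
```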