jax.numpy.histogramdd
- jax.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)
Compute an N-dimensional histogram.

JAX implementation of numpy.histogramdd().

- Parameters:
sample (ArrayLike) – Input array of shape (N, D) representing N points in D dimensions.
bins (ArrayLike | list[ArrayLike]) – Number of bins in each dimension of the histogram (default: 10). May also be a length-D sequence whose entries are, per dimension, either an integer bin count or an array of bin edges.
range (Sequence[None | Array | Sequence[ArrayLike]] | None) – Length-D sequence of (min, max) pairs specifying the range for each dimension. If not specified, the range is inferred from the data. Points outside the range are not counted.
weights (ArrayLike | None) – An optional array of shape (N,) giving one weight per data point, so its length must match the first dimension of sample. If not specified, each data point is weighted equally.
density (bool | None) – If True, return a probability density: each bin's (weighted) count is divided by the total (weighted) count within the range and by the bin's volume, so the histogram integrates to 1. If False or None (the default), return the (weighted) counts per bin.
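- Returns:
A tuple (histogram, bin_edges), where histogram is a D-dimensional array holding the (optionally weighted or normalized) value for each bin, and bin_edges is a list of D arrays, each giving the bin boundaries along one dimension.
- Return type:
tuple[Array, list[Array]]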
See also
- jax.numpy.histogram(): Compute the histogram of a 1D array.
- jax.numpy.histogram2d(): Compute the histogram of a 2D array.
- jax.numpy.histogram_bin_edges(): Compute the bin edges for a histogram.
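For one-dimensional data, histogramdd() should agree with jax.numpy.histogram() listed above once the data is reshaped to (N, 1). The following is a minimal sketch of that consistency check, assuming the default inferred range and the same integer bin count for both calls; the seed and variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Illustrative 1D data (200 standard-normal samples).
x = jax.random.normal(jax.random.key(1), (200,))

# histogramdd expects shape (N, D), so the 1D array becomes a single column (D = 1).
counts_dd, edges_dd = jnp.histogramdd(x[:, None], bins=8)
counts_1d, edges_1d = jnp.histogram(x, bins=8)

# bin_edges from histogramdd is a list with one array per dimension.
print(jnp.allclose(counts_dd, counts_1d))
print(jnp.allclose(edges_dd[0], edges_1d))
```

Note that the two functions may return different dtypes for unweighted counts, which is why the comparison uses allclose rather than checking dtypes.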
Examples
A histogram over 100 points in three dimensions:
>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.key(42)
>>> a = jax.random.normal(key, (100, 3))
>>> counts, bin_edges = jnp.histogramdd(a, bins=6,
...                                     range=[(-3, 3), (-3, 3), (-3, 3)])
>>> counts.shape
(6, 6, 6)
>>> bin_edges  # doctest: +NORMALIZE_WHITESPACE
[Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32)]
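The bins and weights parameters can be combined per dimension. This minimal sketch (the data, seeds, and variable names are illustrative assumptions) gives the first dimension an integer bin count and the second explicit, uneven edges, attaches one weight per point, and checks that, because every point lies inside the bins, the histogram total matches the sum of the weights.

```python
import jax
import jax.numpy as jnp

# Illustrative data: 50 points in the unit square and one weight per point.
k_points, k_weights = jax.random.split(jax.random.key(0))
points = jax.random.uniform(k_points, (50, 2))   # shape (N, D) = (50, 2)
weights = jax.random.uniform(k_weights, (50,))   # shape (N,)

# Dimension 0: 4 equal-width bins over the data's range.
# Dimension 1: explicit, uneven edges covering [0, 1].
y_edges = jnp.array([0.0, 0.25, 0.5, 1.0])
hist, edges = jnp.histogramdd(points, bins=[4, y_edges], weights=weights)

print(hist.shape)  # (4, 3): 4 bins along dimension 0, 3 along dimension 1
# Every point falls inside the bins, so the weighted total equals the weight sum.
print(jnp.allclose(hist.sum(), weights.sum()))
```

Each entry of edges has one more element than the number of bins along that dimension, which is how the shape of hist follows from bin_edges.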
Using density=True returns a normalized histogram:
>>> density, bin_edges = jnp.histogramdd(a, density=True)
>>> bin_widths = map(jnp.diff, bin_edges)
>>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij')
>>> normed = jnp.sum(density * dx * dy * dz)
>>> jnp.allclose(normed, 1.0)
Array(True, dtype=bool)
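For unweighted data, the density values can be reproduced from the raw counts: divide each bin's count by the total in-range count and by the bin's volume (the product of its widths along each dimension). The sketch below is a hedged illustration of that formula, assuming the same inferred bins for both calls; the bin count of 4 and the variable names are illustrative choices.

```python
import jax
import jax.numpy as jnp

a = jax.random.normal(jax.random.key(42), (100, 3))

# Raw counts and normalized density over the same (inferred) bins.
counts, edges = jnp.histogramdd(a, bins=4)
density, _ = jnp.histogramdd(a, bins=4, density=True)

# Volume of each bin: outer product of the per-dimension bin widths.
wx, wy, wz = (jnp.diff(e) for e in edges)
volume = wx[:, None, None] * wy[None, :, None] * wz[None, None, :]

# density = counts / (total in-range count * bin volume)
manual = counts / (counts.sum() * volume)
print(jnp.allclose(density, manual))
```

Using the total in-range count rather than N keeps the formula valid when an explicit range excludes some points.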