jax.lax.reduce_sum
- jax.lax.reduce_sum(operand, axes)
Compute the sum of elements over one or more array axes.
- Parameters:
operand (ArrayLike) – array over which to sum. Must have numerical dtype.
axes (Sequence[int]) – sequence of zero or more unique integers specifying the axes over which to sum. Each entry must satisfy 0 <= axis < operand.ndim.
- Returns:
An array of the same dtype as operand, with shape corresponding to the dimensions of operand.shape with axes removed.
- Return type:
Notes
Unlike jax.numpy.sum(), jax.lax.reduce_sum() does not upcast narrow-width types for accumulation, so sums of 8-bit or 16-bit types may be subject to rounding errors.
See also
jax.numpy.sum(): more flexible NumPy-style summation API, built around jax.lax.reduce_sum().
Other low-level jax.lax reduction operators: jax.lax.reduce_prod(), jax.lax.reduce_max(), jax.lax.reduce_min(), jax.lax.reduce_and(), jax.lax.reduce_or(), jax.lax.reduce_xor().