jax.lax.reduce_sum

jax.lax.reduce_sum(operand, axes)

Compute the sum of elements over one or more array axes.

Parameters:
  • operand (ArrayLike) – array over which to sum. Must have numerical dtype.

  • axes (Sequence[int]) – sequence of zero or more unique integers specifying the axes over which to sum. Each entry must satisfy 0 <= axis < operand.ndim.

Returns:

An array of the same dtype as operand, with shape equal to operand.shape with the dimensions listed in axes removed.

Return type:

Array

Notes

Unlike jax.numpy.sum(), jax.lax.reduce_sum() does not upcast narrow-width types for accumulation, so sums of 8-bit or 16-bit types may be subject to overflow (integer types) or rounding errors (floating-point types).

See also