jax.lax.reduce_max

jax.lax.reduce_max(operand, axes)

Compute the maximum of elements over one or more array axes.

Parameters:
  • operand (ArrayLike) – array over which to compute the maximum.

  • axes (Sequence[int]) – sequence of zero or more unique integers specifying the axes over which to reduce. Each entry must satisfy 0 <= axis < operand.ndim.

Returns:

An array with the same dtype as operand, whose shape is operand.shape with the dimensions listed in axes removed.

Return type:

Array

See also