jax.lax.reduce_max
- jax.lax.reduce_max(operand, axes)
Compute the maximum of elements over one or more array axes.
- Parameters:
operand (ArrayLike) – array over which to compute maximum.
axes (Sequence[int]) – sequence of zero or more unique integers specifying the axes over which to reduce. Each entry must satisfy 0 <= axis < operand.ndim.
- Returns:
An array of the same dtype as operand, with shape corresponding to the dimensions of operand.shape with axes removed.
- Return type:
Array
See also
- jax.numpy.max(): more flexible NumPy-style max-reduction API, built around jax.lax.reduce_max().
- Other low-level jax.lax reduction operators: jax.lax.reduce_sum(), jax.lax.reduce_prod(), jax.lax.reduce_min(), jax.lax.reduce_and(), jax.lax.reduce_or(), jax.lax.reduce_xor().