jax.lax.max
- jax.lax.max(x, y)
Elementwise maximum: \(\mathrm{max}(x, y)\).
This function lowers directly to the stablehlo.maximum operation for non-complex inputs. For complex numbers, this uses a lexicographic comparison on the (real, imaginary) pairs.
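The following is a minimal sketch of both behaviours, assuming a recent JAX release. The input values are illustrative only. For complex inputs it shows that the comparison is lexicographic on (real, imaginary) rather than by magnitude.

```python
import jax.numpy as jnp
from jax import lax

# Real inputs: an ordinary elementwise maximum.
x = jnp.array([1.0, 5.0, -2.0], dtype=jnp.float32)
y = jnp.array([3.0, 4.0, -2.5], dtype=jnp.float32)
real_max = lax.max(x, y)        # [3.0, 5.0, -2.0]

# Complex inputs: real parts are compared first; only on a tie
# do the imaginary parts decide. Magnitude plays no role.
a = jnp.array([1 + 5j, 2 + 0j, 3 + 1j], dtype=jnp.complex64)
b = jnp.array([1 + 2j, 1 + 9j, 3 + 4j], dtype=jnp.complex64)
complex_max = lax.max(a, b)     # [1+5j, 2+0j, 3+4j]
```

Look at the middle element. 1+9j has the larger magnitude, but 2+0j is selected because its real part is larger.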
- Parameters:
  x (ArrayLike) – Input arrays. Must have matching dtypes. If neither is a scalar, x and y must have the same rank and be broadcast compatible.
  y (ArrayLike) – Input arrays. Must have matching dtypes. If neither is a scalar, x and y must have the same rank and be broadcast compatible.
- Returns:
  An array of the same dtype as x and y containing the elementwise maximum.
- Return type:
  Array
See also
  jax.numpy.maximum(): more flexible NumPy-style maximum.
  jax.lax.reduce_max(): maximum along an axis of an array.
  jax.lax.min(): elementwise minimum.