jax.numpy.minimum
jax.numpy.minimum(x, y, /)
Return element-wise minimum of the input arrays.
JAX implementation of numpy.minimum.

- Parameters:
  - x (ArrayLike) – input array or scalar.
  - y (ArrayLike) – input array or scalar. Both x and y should either have the same shape or be broadcast compatible.
- Returns:
  An array containing the element-wise minimum of x and y.
- Return type:
  Array
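The broadcasting rule in the parameter description can be checked directly. The sketch below uses arbitrary illustrative values: first a Python scalar, which broadcasts against every element and so acts as an upper bound, then a pair of arrays with shapes (3, 1) and (4,), which broadcast to a (3, 4) result.

```python
import jax.numpy as jnp

# A Python scalar broadcasts against every element, so every value above 3
# is replaced by 3 (illustrative values only).
x = jnp.array([4, -1, 7, 2])
capped = jnp.minimum(x, 3)
print(capped.tolist())  # [3, -1, 3, 2]

# Shapes (3, 1) and (4,) are broadcast compatible: the missing leading
# dimension of `row` is treated as 1, and each size-1 dimension stretches.
col = jnp.arange(3).reshape(3, 1)  # [[0], [1], [2]]
row = jnp.array([2, 0, 1, 3])      # shape (4,)
grid = jnp.minimum(col, row)
print(grid.shape)  # (3, 4)
```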
Note
For each pair of elements, jnp.minimum returns:
- the smaller of the two if both elements are finite numbers.
- nan if either element is nan.
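The NaN rule in the note can be contrasted with jax.numpy.fmin(), which the See also list describes as ignoring NaNs. The input values below are arbitrary; the infinity entries are included only to show that infinities are compared like ordinary numbers rather than treated like NaN.

```python
import jax.numpy as jnp

a = jnp.array([1.0, jnp.nan, 3.0, jnp.inf])
b = jnp.array([2.0, 0.5, jnp.nan, -jnp.inf])

m = jnp.minimum(a, b)  # a NaN in either input propagates to the result
f = jnp.fmin(a, b)     # a NaN is ignored when the other element is a number
print(m.tolist())  # [1.0, nan, nan, -inf]
print(f.tolist())  # [1.0, 0.5, 3.0, -inf]
```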
See also
- jax.numpy.maximum(): Returns the element-wise maximum of the input arrays.
- jax.numpy.fmin(): Returns the element-wise minimum of the input arrays, ignoring NaNs.
- jax.numpy.amin(): Returns the minimum of array elements along a given axis.
- jax.numpy.nanmin(): Returns the minimum of array elements along a given axis, ignoring NaNs.
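One way to relate the element-wise jnp.minimum to the reductions in the See also list: folding jnp.minimum over the rows of a 2-D array gives the same column-wise result as jnp.amin(..., axis=0), and jnp.nanmin skips a NaN entry. This is a small illustration with made-up values, not a description of how jnp.amin is implemented internally.

```python
import functools

import jax.numpy as jnp

x = jnp.array([[3.0, 7.0, 1.0],
               [5.0, 2.0, 8.0],
               [4.0, 6.0, 0.5]])

# Iterating over a 2-D array yields its rows; folding jnp.minimum over them
# gives the column-wise minimum.
folded = functools.reduce(jnp.minimum, x)
print(folded.tolist())               # [3.0, 2.0, 0.5]
print(jnp.amin(x, axis=0).tolist())  # [3.0, 2.0, 0.5]

# jnp.nanmin skips the NaN entry instead of propagating it.
y = jnp.array([2.0, jnp.nan, 1.0])
print(float(jnp.nanmin(y)))          # 1.0
```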
Examples
Inputs with x.shape == y.shape:
>>> import jax.numpy as jnp
>>> x = jnp.array([2, 3, 5, 1])
>>> y = jnp.array([-3, 6, -4, 7])
>>> jnp.minimum(x, y)
Array([-3,  3, -4,  1], dtype=int32)
Inputs with broadcast-compatible shapes:
>>> x1 = jnp.array([[1, 5, 2],
...                 [-3, 4, 7]])
>>> y1 = jnp.array([-2, 3, 6])
>>> jnp.minimum(x1, y1)
Array([[-2,  3,  2],
       [-3,  3,  6]], dtype=int32)
Inputs with nan:
>>> nan = jnp.nan
>>> x2 = jnp.array([[2.5, nan, -2],
...                 [nan, 5, 6],
...                 [-4, 3, 7]])
>>> y2 = jnp.array([1, nan, 5])
>>> jnp.minimum(x2, y2)
Array([[ 1., nan, -2.],
       [nan, nan,  5.],
       [-4., nan,  5.]], dtype=float32)
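Because jnp.minimum is an ordinary element-wise function, it composes with JAX transformations such as jax.jit. As one sketch, a hypothetical helper clip_to_range (not part of JAX) bounds values from below with jnp.maximum and from above with jnp.minimum; for NaN-free inputs with lo <= hi, this should agree with jax.numpy.clip().

```python
import jax
import jax.numpy as jnp


@jax.jit
def clip_to_range(x, lo, hi):
    # Hypothetical helper: raise values below `lo`, then cap values above `hi`.
    return jnp.minimum(jnp.maximum(x, lo), hi)


x = jnp.array([-2.5, -0.5, 0.0, 0.75, 4.0])
clipped = clip_to_range(x, -1.0, 1.0)
print(clipped.tolist())  # [-1.0, -0.5, 0.0, 0.75, 1.0]

# Compare against the library's own clipping function.
print(bool(jnp.array_equal(clipped, jnp.clip(x, -1.0, 1.0))))  # True
```

Applying jnp.maximum first and jnp.minimum second means the upper bound wins if lo > hi, which is one reason the comparison above assumes lo <= hi.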