jax.lax.sub
- jax.lax.sub(x, y)
Elementwise subtraction: \(x - y\).
This function lowers directly to the stablehlo.subtract operation.
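One way to observe this lowering is to inspect the IR that JAX emits for a jitted call. The sketch below assumes a reasonably recent JAX release in which `jax.jit(...).lower(...).as_text()` prints StableHLO; the input values are arbitrary and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative inputs with matching dtype and shape.
x = jnp.array([3.0, 5.0], dtype=jnp.float32)
y = jnp.array([1.0, 2.0], dtype=jnp.float32)

# Lower the jitted function without executing it, then render the IR as text.
lowered = jax.jit(lax.sub).lower(x, y)
hlo_text = lowered.as_text()

# On recent JAX versions the printed module should contain a subtract op.
print(hlo_text)
```

Older releases may print a different dialect name, so the exact op spelling in the output should be treated as version dependent.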
- Parameters:
x (ArrayLike) – Input arrays. Must have matching numerical dtypes. If neither is a scalar, x and y must have the same number of dimensions and be broadcast compatible.
y (ArrayLike) – Input arrays. Must have matching numerical dtypes. If neither is a scalar, x and y must have the same number of dimensions and be broadcast compatible.
- Returns:
An array of the same dtype as x and y containing the difference of each pair of broadcasted entries.
- Return type:
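A minimal sketch of the parameter and return rules described above, using made-up values: both inputs share a dtype and a rank, a size-1 dimension is broadcast, and a Python scalar is also accepted as the second operand.

```python
import jax.numpy as jnp
from jax import lax

# Two float32 arrays with the same number of dimensions (rank 2).
x = jnp.array([[5.0, 7.0, 9.0],
               [1.0, 2.0, 3.0]], dtype=jnp.float32)   # shape (2, 3)
y = jnp.array([[1.0, 2.0, 3.0]], dtype=jnp.float32)   # shape (1, 3)

# The size-1 leading dimension of y is broadcast against x.
diff = lax.sub(x, y)

# A scalar operand is exempt from the equal-rank requirement.
scalar_diff = lax.sub(x, 1.0)

print(diff.shape, diff.dtype)
print(scalar_diff)
```

The result keeps the shared input dtype, here float32.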
See also
jax.numpy.subtract(): NumPy-style subtraction supporting inputs with mixed dtypes and ranks.
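The contrast between the two functions can be sketched as follows. The inputs are illustrative, and the example assumes JAX's default type promotion, under which int32 combined with float32 gives float32. The explicit cast and reshape at the end are one possible way to prepare inputs for `lax.sub`, not the only one.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.arange(6, dtype=jnp.int32).reshape(2, 3)    # int32, shape (2, 3)
b = jnp.array([0.5, 1.5, 2.5], dtype=jnp.float32)   # float32, shape (3,)

# lax.sub is strict: mismatched dtypes and ranks are rejected.
try:
    lax.sub(a, b)
    lax_error = None
except TypeError as err:
    lax_error = err

# jnp.subtract promotes dtypes and broadcasts ranks NumPy-style.
np_style = jnp.subtract(a, b)

# Equivalent strict call: cast to a common dtype and match ranks by hand.
manual = lax.sub(a.astype(jnp.float32), b.reshape(1, 3))

print(lax_error)
print(np_style)
print(manual)
```

Making the cast and reshape explicit is sometimes preferred in library code, since no promotion or rank change happens implicitly.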