jax.lax.gt

jax.lax.gt(x, y)
Elementwise greater-than: \(x > y\).
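Here is a minimal usage sketch. It assumes a recent JAX release with the default 32-bit dtypes, and the input values are illustrative only. Both operands share one dtype and shape, and the NaN entry shows that any greater-than comparison involving NaN evaluates to False.

```python
import jax.numpy as jnp
from jax import lax

# Both operands are float32 with shape (4,): lax.gt requires matching dtypes.
x = jnp.array([1.0, 2.0, 3.0, jnp.nan], dtype=jnp.float32)
y = jnp.array([2.0, 2.0, 1.0, 0.0], dtype=jnp.float32)

result = lax.gt(x, y)   # elementwise x > y
print(result.tolist())  # [False, False, True, False]  (NaN compares False)
print(result.dtype)     # bool
```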
This function lowers directly to the stablehlo.compare operation with comparison_direction=GT and compare_type set according to the input dtype.
- Parameters:
  - x (ArrayLike) – Input array. x and y must have matching non-complex dtypes. If neither is a scalar, they must also have the same number of dimensions and be broadcast compatible (each pair of corresponding dimensions is equal, or one of them is 1).
  - y (ArrayLike) – Input array, subject to the same constraints as x.
- Returns:
  A boolean array of shape lax.broadcast_shapes(x.shape, y.shape) containing the elementwise greater-than comparison.
- Return type:
  Array
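The sketch below checks both the documented return contract and the stablehlo.compare lowering described earlier. It assumes a JAX version whose `lax.gt` supports size-1 broadcasting, as this page documents. The printed StableHLO format differs between versions, so the check only searches for substrings rather than matching exact text.

```python
import jax
import jax.numpy as jnp
from jax import lax

col = jnp.array([[0.0], [5.0]], dtype=jnp.float32)      # shape (2, 1)
row = jnp.array([[1.0, 2.0, 3.0]], dtype=jnp.float32)  # shape (1, 3)

out = lax.gt(col, row)
expected_shape = lax.broadcast_shapes(col.shape, row.shape)

print(expected_shape)                          # (2, 3)
print(out.shape == expected_shape, out.dtype)  # True bool
print(out.tolist())  # [[False, False, False], [True, True, True]]

# Lower the same call to StableHLO text and look for the compare op.
hlo_text = jax.jit(lax.gt).lower(col, row).as_text()
has_compare_gt = "stablehlo.compare" in hlo_text and "GT" in hlo_text
print(has_compare_gt)  # True
```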
See also
- jax.numpy.greater(): NumPy wrapper for this API, also accessible via the x > y operator on JAX arrays.
- jax.lax.eq(): elementwise equal
- jax.lax.ne(): elementwise not-equal
- jax.lax.ge(): elementwise greater-than-or-equal
- jax.lax.le(): elementwise less-than-or-equal
- jax.lax.lt(): elementwise less-than
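The relationships listed under See also can be checked directly. In the sketch below, the example values are arbitrary. The sketch shows three things:

- `lax.gt`, `jnp.greater`, and the `>` operator agree on same-dtype inputs.
- Swapping operands turns `gt` into `lt`.
- The NumPy wrapper applies type promotion, which `lax.gt` does not.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([3, 1, 4, 1, 5], dtype=jnp.int32)
y = jnp.array([2, 7, 1, 8, 2], dtype=jnp.int32)

via_lax = lax.gt(x, y)
via_jnp = jnp.greater(x, y)
via_op = x > y

# All three spellings agree on same-dtype, same-shape inputs.
same = bool(jnp.array_equal(via_lax, via_jnp)) and bool(jnp.array_equal(via_lax, via_op))
print(same, via_lax.tolist())  # True [True, False, True, False, True]

# Swapping operands turns gt into lt.
swapped_ok = bool(jnp.array_equal(lax.gt(x, y), lax.lt(y, x)))
print(swapped_ok)  # True

# jnp.greater promotes int32 vs float32; lax.gt would raise TypeError here.
z = jnp.array([2.5, 0.5, 4.0, 0.0, 5.5], dtype=jnp.float32)
promoted = jnp.greater(x, z)
print(promoted.tolist())  # [True, True, False, True, False]
```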