jax.lax.round
- jax.lax.round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO)
Elementwise round.
Rounds values to the nearest integer. This function lowers directly to the stablehlo.round operation.
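You can check the lowering on your own install by lowering a jitted call without compiling it and inspecting the StableHLO text. This is a minimal sketch. It assumes a recent JAX in which `jax.jit(...).lower(...).as_text()` returns StableHLO. In the versions we know of, the rounding op prints as `stablehlo.round_nearest_afz` or `stablehlo.round_nearest_even` depending on `rounding_method`, but the exact spelling may differ between releases.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.5, 1.5, 2.5], dtype=jnp.float32)

# Lower (without compiling) a jitted call for each rounding method and
# capture the StableHLO module text that would be handed to the compiler.
afz_text = jax.jit(lax.round).lower(x).as_text()
even_text = jax.jit(
    lambda v: lax.round(v, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)
).lower(x).as_text()

# Show every line that mentions "round" (this includes the module name line).
for text in (afz_text, even_text):
    print([line.strip() for line in text.splitlines() if "round" in line])
```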
- Parameters:
x (ArrayLike) – an array or scalar value to round. Must have a floating-point dtype.
rounding_method (RoundingMethod) – the method to use when rounding halfway values (e.g., 0.5). See jax.lax.RoundingMethod for possible values.
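The sketch below shows which inputs `rounding_method` actually affects. Away from the exact .5 midpoint, both methods round to the nearest integer and agree. At the midpoint they break the tie differently. It also shows that integer input is rejected. We expect a `TypeError` from lax's dtype check, but that specific exception type is an assumption based on current JAX behavior.

```python
import jax.numpy as jnp
from jax import lax

afz = lax.RoundingMethod.AWAY_FROM_ZERO
even = lax.RoundingMethod.TO_NEAREST_EVEN

# Off-midpoint values: both methods pick the nearest integer, so they agree.
x = jnp.array([-2.7, -1.2, 0.4, 2.3], dtype=jnp.float32)
same_afz = lax.round(x, rounding_method=afz)    # -3, -1, 0, 2
same_even = lax.round(x, rounding_method=even)  # -3, -1, 0, 2

# Exact halfway values: the two methods break the tie differently.
half = jnp.array([0.5, 1.5, 2.5], dtype=jnp.float32)
half_afz = lax.round(half, rounding_method=afz)    # 1, 2, 3
half_even = lax.round(half, rounding_method=even)  # 0, 2, 2

# Integer input is rejected because x must have a floating-point dtype.
try:
    lax.round(jnp.array([1, 2, 3]))
    int_rejected = False
except TypeError:
    int_rejected = True
print("integer input rejected:", int_rejected)
```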
- Returns:
An array of the same shape and dtype as x, containing the elementwise rounding of x.
- Return type:
Array
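The result is still a floating-point array, not an integer one. The sketch below checks that shape and dtype are preserved for a `float16` input. It then casts explicitly when integer values are needed. The choice of `int32` as the target dtype is only an illustration.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[0.4, 1.6], [2.5, -3.5]], dtype=jnp.float16)
y = lax.round(x)

# The result keeps the input's shape and floating-point dtype.
print(y.shape, y.dtype)  # (2, 2) float16

# Integer values require an explicit cast after rounding.
y_int = y.astype(jnp.int32)
print(y_int.tolist())  # [[0, 2], [3, -4]]
```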
See also
jax.lax.floor(): round to the next integer toward negative infinity
jax.lax.ceil(): round to the next integer toward positive infinity
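To compare the three functions, the sketch below applies each one to the same inputs. `round` goes to the nearest integer, with halves moving away from zero by default. `floor` and `ceil` always move in a fixed direction, whatever the fractional part. The sample values are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([-1.5, -0.7, 0.2, 1.5], dtype=jnp.float32)

r = lax.round(x)  # nearest, halves away from zero: -2, -1, 0, 2
f = lax.floor(x)  # toward negative infinity:        -2, -1, 0, 1
c = lax.ceil(x)   # toward positive infinity:        -1, -0, 1, 2

for name, vals in (("round", r), ("floor", f), ("ceil", c)):
    print(name, vals.tolist())
```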
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
>>> jax.lax.round(x)  # default method is AWAY_FROM_ZERO
Array([-2., -1., -1., 0., 1., 1., 2.], dtype=float32)
>>> jax.lax.round(x, rounding_method=jax.lax.RoundingMethod.TO_NEAREST_EVEN)
Array([-2., -1., -0., 0., 0., 1., 2.], dtype=float32)
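If you want a non-default method inside compiled code, one option is to bind `rounding_method` before jitting. The compiled function then takes only the array argument. This is a sketch of one design choice, not a documented recommendation. `round_half_even` is a hypothetical helper name. Passing `static_argnames="rounding_method"` to `jax.jit` would be an alternative way to keep the enum out of tracing.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical helper: bind the rounding method up front so that jit only
# traces the array argument and the enum stays a plain Python value.
round_half_even = jax.jit(
    functools.partial(lax.round, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)
)

x = jnp.array([0.5, 1.5, 2.5, 3.5], dtype=jnp.float32)
print(round_half_even(x).tolist())  # [0.0, 2.0, 2.0, 4.0]
```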