jax.lax.round

jax.lax.round(x, rounding_method=RoundingMethod.AWAY_FROM_ZERO)

Elementwise round.

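Rounds each value to the nearest integer, resolving halfway cases according to rounding_method. This function lowers directly to a single StableHLO rounding operation: stablehlo.round_nearest_afz for RoundingMethod.AWAY_FROM_ZERO, or stablehlo.round_nearest_even for RoundingMethod.TO_NEAREST_EVEN.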
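To see which StableHLO operation a particular call produces, one option is to lower a jitted wrapper and inspect the module text. This is a sketch that assumes a JAX version in which Lowered.as_text() prints StableHLO; the wrapper names f_default and f_even are made up for illustration, and the exact module text may differ between releases.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.array([-1.5, -0.5, 0.5, 1.5], dtype=jnp.float32)

def f_default(v):
    # Default rounding_method: halfway values go away from zero.
    return lax.round(v)

def f_even(v):
    # Halfway values go to the nearest even integer.
    return lax.round(v, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)

# Lower each jitted function for this input and grab the module text.
afz_text = jax.jit(f_default).lower(x).as_text()
even_text = jax.jit(f_even).lower(x).as_text()

for name, text in [("AWAY_FROM_ZERO", afz_text), ("TO_NEAREST_EVEN", even_text)]:
    # Keep only the lines that mention a rounding op.
    ops = [line.strip() for line in text.splitlines() if "round" in line]
    print(name, ops)
```

In recent releases the first module should contain a round_nearest_afz op and the second a round_nearest_even op.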

Parameters:
  • x (ArrayLike) – an array or scalar value to round. Must have floating-point type.

  • rounding_method (RoundingMethod) – the rule used to break ties, i.e. values exactly halfway between two integers (e.g., 0.5 or -2.5). Defaults to RoundingMethod.AWAY_FROM_ZERO. See jax.lax.RoundingMethod for possible values.
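The tie-breaking rule is the main thing to watch when porting code. NumPy's round, and therefore jax.numpy.round, rounds ties to even, as does Python's built-in round(), so both agree with TO_NEAREST_EVEN rather than with this function's AWAY_FROM_ZERO default. A small check on halfway values (the sample list is arbitrary):

```python
import jax.numpy as jnp
from jax import lax

ties = [-2.5, -1.5, -0.5, 0.5, 1.5, 2.5]
x = jnp.array(ties, dtype=jnp.float32)

# lax.round default (AWAY_FROM_ZERO): -3., -2., -1., 1., 2., 3.
afz = lax.round(x)

# Ties to even: -2., -2., -0., 0., 2., 2.
even = lax.round(x, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)

# jax.numpy.round follows NumPy's round-half-to-even convention.
np_style = jnp.round(x)

# Python's built-in round() also breaks ties to even (and returns ints).
py_style = [round(v) for v in ties]  # [-2, -2, 0, 0, 2, 2]

print(afz, even, np_style, py_style)
```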

Returns:

An array of the same shape and dtype as x, containing the elementwise rounding of x.

Return type:

Array
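Because the result keeps the input's shape and dtype, lax.round does not promote precision or change layout. A quick sketch with a half-precision 2-D array and a 0-d float32 scalar (the sample values are arbitrary):

```python
import jax.numpy as jnp
from jax import lax

# A 2x3 float16 input; the output stays float16 with shape (2, 3).
x16 = jnp.array([[0.4, 0.5, 0.6], [-1.5, 2.5, 3.7]], dtype=jnp.float16)
r16 = lax.round(x16)  # values: [[0, 1, 1], [-2, 3, 4]]

# A scalar input comes back as a 0-d array of the same dtype.
r_scalar = lax.round(jnp.float32(2.5))  # 3.0, away from zero

print(r16.shape, r16.dtype, r16)
print(r_scalar.shape, r_scalar.dtype, r_scalar)
```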

See also

Examples

>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5])
>>> lax.round(x)  # the default method is AWAY_FROM_ZERO
Array([-2., -1., -1.,  0.,  1.,  1.,  2.], dtype=float32)
>>> lax.round(x, rounding_method=lax.RoundingMethod.TO_NEAREST_EVEN)
Array([-2., -1., -0.,  0.,  0.,  1.,  2.], dtype=float32)
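For intuition, the default AWAY_FROM_ZERO rule can be written out with ordinary jax.numpy operations: truncate toward zero, then step one unit away from zero whenever the discarded fraction is at least one half. The helper round_half_away_from_zero below is hypothetical and only illustrates the semantics; JAX itself lowers lax.round to a single StableHLO op. Subtracting the truncated value is exact in binary floating point, which is why this sketch avoids the common floor(x + 0.5) shortcut: in float32, 0.49999997 + 0.5 already rounds up to 1.0, so the shortcut would return 1 instead of 0.

```python
import jax.numpy as jnp
from jax import lax

def round_half_away_from_zero(x):
    """Hypothetical reference for lax.round's default method (illustration only)."""
    x = jnp.asarray(x)
    t = jnp.trunc(x)   # integer part, rounded toward zero
    frac = x - t       # discarded fraction; exact in binary floating point
    # If the discarded part is at least one half, step one unit away from zero.
    return jnp.where(jnp.abs(frac) >= 0.5, t + jnp.sign(x), t)

vals = jnp.array([-2.5, -1.5, -0.5, -0.3, 0.0, 0.3, 0.5, 1.5, 2.5,
                  0.49999997, 1e8], dtype=jnp.float32)
ours = round_half_away_from_zero(vals)
ref = lax.round(vals)
print(jnp.array_equal(ours, ref))  # expected: True
```

Using jnp.where keeps the sketch branch-free, so it also works unchanged under jax.jit and jax.vmap.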