jax.numpy.round
- jax.numpy.round(a, decimals=0, out=None)
Round input evenly to the given number of decimals.
JAX implementation of numpy.round().
- Parameters:
a (ArrayLike) – input array or scalar.
decimals (int) – int, default=0. Number of decimal places to which the input is rounded. It must be specified statically. Not implemented for decimals < 0.
out (None) – Unused by JAX.
- Returns:
An array containing the values rounded to the specified decimals, with the same shape and dtype as a.
- Return type:
Note
jnp.round rounds values that lie exactly halfway between two rounded decimal values to the nearest even value.
See also
jax.numpy.floor(): Rounds the input to the nearest integer downwards.
jax.numpy.ceil(): Rounds the input to the nearest integer upwards.
jax.numpy.fix() and numpy.trunc(): Round the input to the nearest integer towards zero.
Examples
>>> x = jnp.array([1.532, 3.267, 6.149])
>>> jnp.round(x)
Array([2., 3., 6.], dtype=float32)
>>> jnp.round(x, decimals=2)
Array([1.53, 3.27, 6.15], dtype=float32)
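Because decimals must be specified statically, a function that forwards it to jnp.round under jax.jit would typically need to mark that argument as static. The following is a minimal sketch of one way to do this; the wrapper name round_to is hypothetical and not part of JAX.

```python
import functools

import jax
import jax.numpy as jnp


# `round_to` is a hypothetical helper used only for illustration.
# Marking `decimals` as static keeps it a concrete Python int during
# tracing, which is what jnp.round requires.
@functools.partial(jax.jit, static_argnames="decimals")
def round_to(x, decimals=0):
    return jnp.round(x, decimals=decimals)


x = jnp.array([1.532, 3.267, 6.149])
result = round_to(x, decimals=2)
print(result)
```

Each distinct value of decimals triggers a separate compilation, so this pattern fits best when only a few values are used.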
For values exactly halfway between rounded values:
>>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5])
>>> jnp.round(x1)
Array([10., 22., 12., 32.], dtype=float32)
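To contrast jnp.round with the related functions listed under See also, the sketch below applies each of them to the same halfway values, including negative ones. The input values are chosen only for illustration, and the comments describe the expected behavior from the descriptions above.

```python
import jax.numpy as jnp

x = jnp.array([-2.5, -1.5, 0.5, 1.5, 2.5])

rounded = jnp.round(x)    # halfway values go to the nearest even integer
floored = jnp.floor(x)    # always rounds downwards
ceiled = jnp.ceil(x)      # always rounds upwards
truncated = jnp.trunc(x)  # rounds towards zero
fixed = jnp.fix(x)        # also rounds towards zero

for name, value in [("round", rounded), ("floor", floored),
                    ("ceil", ceiled), ("trunc", truncated), ("fix", fixed)]:
    print(name, value)
```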