jax.numpy.ldexp

jax.numpy.ldexp(x1, x2, /)

Compute x1 * 2 ** x2

JAX implementation of numpy.ldexp().

Note that XLA does not provide an ldexp operation, so JAX implements this function with a standard multiplication and exponentiation, i.e. as x1 * 2 ** x2.
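The multiply-and-exponentiate approach in the note can be sketched as follows. This is a minimal illustration, not JAX's actual source: the helper name `ldexp_sketch` is hypothetical, and the dtype handling (promoting `x1` to a float and casting `x2` to that float dtype before exponentiating) is an assumption made so that negative exponents produce fractions.

```python
import jax.numpy as jnp

def ldexp_sketch(x1, x2):
    """Hypothetical helper: x1 * 2 ** x2 via plain multiplication and exponentiation."""
    x1 = jnp.asarray(x1)
    # Promote x1 to a floating dtype (e.g. int32 -> float32 with default settings).
    x1 = x1.astype(jnp.result_type(x1, 1.0))
    # Cast the integer exponents to the same float dtype so that 2 ** x2 is a
    # floating-point power and negative exponents yield fractions such as 0.5.
    x2 = jnp.asarray(x2).astype(x1.dtype)
    scale = jnp.asarray(2.0, dtype=x1.dtype) ** x2
    # x1 and scale are combined with the usual NumPy broadcasting rules.
    return x1 * scale

print(ldexp_sketch(jnp.arange(5.0), 10))
print(ldexp_sketch(jnp.array([3.0, 3.0]), jnp.array([-1, 1])))
```

One consequence of this naive formulation: the intermediate 2 ** x2 can overflow even when the final product is representable. In float32, ldexp_sketch(0.25, 129) yields inf, although 0.25 * 2**129 = 2**127 fits in float32, and a zero x1 paired with such an exponent yields nan (0 * inf). The library function may guard against edge cases like these; this sketch does not.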

Parameters:
  • x1 (ArrayLike) – real-valued input array.

  • x2 (ArrayLike) – integer input array. Must be broadcast-compatible with x1.
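Because x2 only needs to be broadcast-compatible with x1, a column of exponents can scale a row of values in one call. The array values below are arbitrary and chosen purely for illustration.

```python
import jax.numpy as jnp

x1 = jnp.array([1.0, 1.5, 3.0])   # shape (3,), real-valued inputs
x2 = jnp.array([[0], [2]])        # shape (2, 1), integer exponents
result = jnp.ldexp(x1, x2)        # broadcasts to shape (2, 3)

print(result.shape)  # (2, 3)
print(result)        # row 0 is x1 * 2**0, row 1 is x1 * 2**2
```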

Returns:

x1 * 2 ** x2 computed element-wise.

Return type:

Array

See also

  • jax.numpy.frexp(): decompose values into mantissa and exponent.

Examples

>>> import jax.numpy as jnp
>>> x1 = jnp.arange(5.0)
>>> x2 = 10
>>> jnp.ldexp(x1, x2)
Array([   0., 1024., 2048., 3072., 4096.], dtype=float32)

ldexp can be used to reconstruct the input to frexp:

>>> x = jnp.array([2., 3., 5., 11.])
>>> m, e = jnp.frexp(x)
>>> m
Array([0.5   , 0.75  , 0.625 , 0.6875], dtype=float32)
>>> e
Array([2, 2, 3, 4], dtype=int32)
>>> jnp.ldexp(m, e)
Array([ 2.,  3.,  5., 11.], dtype=float32)