jax.numpy.logaddexp2
- jax.numpy.logaddexp2 = <jnp.ufunc 'logaddexp2'>
Base-2 logarithm of the sum of base-2 exponentials of the inputs, computed in a way that avoids overflow.
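The overflow this function avoids is easy to reproduce in float32, whose largest finite value is just under 2**128. The sketch below compares a naive evaluation with jnp.logaddexp2. It also defines a hypothetical helper, logaddexp2_sketch, which applies the max-shift identity log2(2^a + 2^b) = max(a, b) + log2(1 + 2^(-|a - b|)). The helper is for illustration only. It is not part of the JAX API and does not necessarily match how JAX implements the function.

```python
import jax.numpy as jnp


def logaddexp2_sketch(x1, x2):
    """Hypothetical helper (not a JAX API): stable log2(2**x1 + 2**x2)."""
    x1 = jnp.asarray(x1, dtype=jnp.float32)
    x2 = jnp.asarray(x2, dtype=jnp.float32)
    m = jnp.maximum(x1, x2)
    delta = x1 - x2
    # -|delta| <= 0, so exp2(-|delta|) lies in (0, 1] and cannot overflow.
    # log1p(y) / ln 2 computes log2(1 + y) accurately even for tiny y.
    shifted = m + jnp.log1p(jnp.exp2(-jnp.abs(delta))) / jnp.log(2.0)
    # If both inputs are the same infinity, delta = inf - inf = nan;
    # x1 + x2 then yields the correct limit (+inf or -inf).
    return jnp.where(jnp.isnan(delta), x1 + x2, shifted)


# 2**200 exceeds the float32 maximum (about 3.4e38, just under 2**128).
a = jnp.float32(200.0)
b = jnp.float32(200.0)

naive = jnp.log2(jnp.exp2(a) + jnp.exp2(b))  # exp2 overflows to inf first
stable = jnp.logaddexp2(a, b)                 # stays finite: 200 + log2(2)
sketch = logaddexp2_sketch(a, b)

print(naive, stable, sketch)  # prints: inf 201.0 201.0
```

The shift by max(a, b) is the same trick used for the natural-log logaddexp. After the shift, the only exponential that is evaluated has a non-positive exponent.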
JAX implementation of numpy.logaddexp2.
- Parameters:
x1 – input array or scalar.
x2 – input array or scalar.
x1 and x2 should either have the same shape or be broadcast compatible.
args (ArrayLike)
out (None)
where (None)
- Returns:
An array containing the result, \(\log_2(2^{x_1} + 2^{x_2})\), element-wise.
- Return type:
Any
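The inputs are broadcast against each other, so the returned array has their broadcast shape. The short check below illustrates this with a (2, 3) array, a (3,) array, and a Python scalar. The input values are arbitrary. The final call shows that shapes (2, 3) and (2,) are rejected. The exact exception type is treated as an assumption here, so the sketch catches both TypeError and ValueError.

```python
import jax.numpy as jnp

x1 = jnp.array([[0.0, 1.0, 2.0],
                [3.0, 4.0, 5.0]])  # shape (2, 3)
x2 = jnp.array([0.0, 1.0, 2.0])    # shape (3,), broadcast along the rows

out = jnp.logaddexp2(x1, x2)
print(out.shape)  # (2, 3)

# Row 0 of x1 equals x2, and log2(2**x + 2**x) = x + 1 (up to rounding).
print(jnp.allclose(out[0], x2 + 1.0))  # True

# A Python scalar broadcasts against any array shape.
print(jnp.logaddexp2(x1, 1.0).shape)  # (2, 3)

# Shapes (2, 3) and (2,) are not broadcast compatible, so the call fails.
rejected = False
try:
    jnp.logaddexp2(x1, jnp.array([1.0, 2.0]))
except (TypeError, ValueError):
    rejected = True
print(rejected)  # True
```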
See also
- jax.numpy.logaddexp(): Computes log(exp(x1) + exp(x2)), element-wise.
- jax.numpy.log2(): Calculates the base-2 logarithm of x, element-wise.
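jax.numpy.logaddexp and this function are related by a change of base. Since \(2^x = e^{x \ln 2}\), it follows that \(\log_2(2^{x_1} + 2^{x_2}) = \mathrm{logaddexp}(x_1 \ln 2,\, x_2 \ln 2) / \ln 2\). The check below confirms this identity numerically on float versions of the inputs from the example. It illustrates the math only and makes no claim about how either function is implemented.

```python
import math

import jax.numpy as jnp

x1 = jnp.array([[3.0, -1.0, 4.0],
                [8.0, 5.0, -2.0]])
x2 = jnp.array([2.0, 3.0, -5.0])
ln2 = math.log(2.0)

base2 = jnp.logaddexp2(x1, x2)
# Change of base: 2**x = exp(x * ln 2), and log2(y) = ln(y) / ln 2.
via_natural = jnp.logaddexp(x1 * ln2, x2 * ln2) / ln2

print(jnp.allclose(base2, via_natural))  # True
```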
Examples
```python
>>> import jax.numpy as jnp
>>> x1 = jnp.array([[3, -1, 4],
...                 [8, 5, -2]])
>>> x2 = jnp.array([2, 3, -5])
>>> result1 = jnp.logaddexp2(x1, x2)
>>> result2 = jnp.log2(jnp.exp2(x1) + jnp.exp2(x2))
>>> jnp.allclose(result1, result2)
Array(True, dtype=bool)
```