jax.numpy.power
jax.numpy.power(x1, x2, /)
Calculate element-wise base x1 exponential of x2.
JAX implementation of numpy.power.
- Parameters:
x1 (ArrayLike) – scalar or array. Specifies the bases.
x2 (ArrayLike) – scalar or array. Specifies the exponent.
x1 and x2 should either have the same shape or be broadcast compatible.
- Returns:
An array containing the base x1 exponentials of x2 with the same dtype as the input.
- Return type:
Array
Note
- When x2 is a concrete integer scalar, jnp.power lowers to jax.lax.integer_pow().
- When x2 is a traced scalar or an array, jnp.power lowers to jax.lax.pow().
- jnp.power raises a TypeError when an integer type is raised to a negative integer power.
- jnp.power returns nan when a negative value is raised to a non-integer power.
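The behaviors listed in the note can be observed with a short sketch. It assumes a default JAX installation; the variable names (`jaxpr_concrete`, `jaxpr_traced`, `raised`, `neg_result`) are illustrative and not part of the JAX API. `jax.make_jaxpr` is used here only to print the primitives that `jnp.power` emits, and the exact printed form may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

# Concrete integer exponent: the traced program should contain integer_pow.
jaxpr_concrete = jax.make_jaxpr(lambda x: jnp.power(x, 3))(2.0)
print(jaxpr_concrete)

# Traced exponent: the exponent is a function argument, so pow is used.
jaxpr_traced = jax.make_jaxpr(lambda x, y: jnp.power(x, y))(2.0, 3.0)
print(jaxpr_traced)

# Integer base with a concrete negative integer exponent raises TypeError.
try:
    jnp.power(jnp.array([2, 4]), -1)
    raised = False
except TypeError:
    raised = True
print("TypeError raised:", raised)

# Negative base with a non-integer exponent yields nan.
neg_result = jnp.power(jnp.array([-8.0]), 0.5)
print(neg_result)
```

If integer bases must be raised to negative powers, one option is to convert the bases to a floating-point dtype first, for example with `jnp.asarray(x, dtype=jnp.float32)`.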
See also
- jax.lax.pow(): Computes element-wise power, \(x^y\).
- jax.lax.integer_pow(): Computes element-wise power \(x^y\), where \(y\) is a fixed integer.
- jax.numpy.float_power(): Computes the first array raised to the power of second array, element-wise, by promoting to the inexact dtype.
- jax.numpy.pow(): Computes the first array raised to the power of second array, element-wise.
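To make the difference between the related functions concrete, the following sketch compares `jnp.power` with `jnp.float_power` on integer inputs. The names `bases`, `int_result`, `float_result`, and `op_result` are illustrative only. The `**` operator on JAX arrays is included for comparison and is expected to give the same values as `jnp.power`.

```python
import jax.numpy as jnp

bases = jnp.array([2, 4, 5])

# Integer bases with an integer exponent keep an integer dtype.
int_result = jnp.power(bases, 2)

# float_power promotes the inputs to an inexact (floating-point) dtype.
float_result = jnp.float_power(bases, 2)

# The ** operator on JAX arrays, shown for comparison.
op_result = bases ** 2

print(int_result.dtype, float_result.dtype, op_result.dtype)
```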
Examples
Inputs with scalar integers:
>>> jnp.power(4, 3)
Array(64, dtype=int32, weak_type=True)
Inputs with same shape:
>>> x1 = jnp.array([2, 4, 5])
>>> x2 = jnp.array([3, 0.5, 2])
>>> jnp.power(x1, x2)
Array([ 8.,  2., 25.], dtype=float32)
Inputs with broadcast compatibility:
>>> x3 = jnp.array([-2, 3, 1])
>>> x4 = jnp.array([[4, 1, 6],
...                 [1.3, 3, 5]])
>>> jnp.power(x3, x4)
Array([[16.,  3.,  1.],
       [nan, 27.,  1.]], dtype=float32)
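As a follow-up to the broadcasting example, this sketch checks the result shape against `jnp.broadcast_shapes` and inspects the result dtype. The names `bases`, `exponents`, `expected_shape`, and `result` are illustrative, and the dtype comment assumes JAX's default 32-bit mode.

```python
import jax.numpy as jnp

bases = jnp.array([-2, 3, 1])              # shape (3,), integer dtype
exponents = jnp.array([[4, 1, 6],
                       [1.3, 3, 5]])       # shape (2, 3), floating dtype

# Shape that the two inputs broadcast to.
expected_shape = jnp.broadcast_shapes(bases.shape, exponents.shape)

result = jnp.power(bases, exponents)

print(expected_shape, result.shape)
# Mixed integer and floating inputs give a floating-point result.
print(result.dtype)
```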