jax.numpy.square

jax.numpy.square(x, /)

Calculate the element-wise square of the input array.

JAX implementation of numpy.square.

Parameters:

x (ArrayLike) – input array or scalar.

Returns:

An array containing the square of the elements of x.

Return type:

Array
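The parameter description above says that x may be an array or a scalar. The following is a minimal sketch of both cases; the input values and variable names are chosen purely for illustration.

```python
import jax
import jax.numpy as jnp

# A Python scalar is accepted in place of an array (illustrative value).
scalar_result = jnp.square(-4.0)

# For array input, the result has the same shape as the input.
matrix = jnp.ones((2, 3))
matrix_result = jnp.square(matrix)

print(isinstance(scalar_result, jax.Array))
print(scalar_result.shape)
print(matrix_result.shape)
```

In this sketch the scalar case yields a zero-dimensional Array, which can be converted back with float() if a plain Python number is needed.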

Note

jnp.square is equivalent to computing jnp.power(x, 2).
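As a rough check of the equivalence noted above, this sketch compares jnp.square with jnp.power(x, 2) and with the plain product x * x. The sample values are arbitrary, and the comparison uses jnp.allclose so that it holds up to floating-point tolerance.

```python
import jax.numpy as jnp

# Arbitrary sample values chosen for illustration.
x = jnp.array([3.0, -2.0, 5.3, 1.0])

squared = jnp.square(x)
via_power = jnp.power(x, 2)
via_product = x * x

# All three results should agree up to floating-point tolerance.
print(jnp.allclose(squared, via_power))
print(jnp.allclose(squared, via_product))
```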

See also

  • jax.numpy.sqrt(): Calculates the element-wise non-negative square root of the input array.

  • jax.numpy.power(): Calculates x1 raised to the power of x2, element-wise.

  • jax.lax.integer_pow(): Computes element-wise power \(x^y\), where \(y\) is a fixed integer.

  • jax.numpy.float_power(): Computes the first array raised to the power of the second array, element-wise, promoting the inputs to an inexact dtype.
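The related functions listed above can be compared side by side. This is a small sketch under assumed sample inputs, not an exhaustive comparison; the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3.0, -2.0, 5.3, 1.0])
squared = jnp.square(x)

# sqrt returns the non-negative root, so it recovers |x| rather than x.
roots = jnp.sqrt(squared)

# integer_pow takes the exponent as a fixed Python integer.
via_integer_pow = jax.lax.integer_pow(x, 2)

# float_power promotes integer inputs to an inexact (floating) dtype.
ints = jnp.array([2, 4, 5, 6])
promoted = jnp.float_power(ints, 2)

print(roots)
print(via_integer_pow)
print(promoted.dtype)
```

Note that jnp.sqrt does not invert jnp.square for negative inputs: the sign of the original value is lost.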

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([3, -2, 5.3, 1])
>>> jnp.square(x)
Array([ 9.      ,  4.      , 28.090002,  1.      ], dtype=float32)
>>> jnp.power(x, 2)
Array([ 9.      ,  4.      , 28.090002,  1.      ], dtype=float32)

For integer inputs:

>>> x1 = jnp.array([2, 4, 5, 6])
>>> jnp.square(x1)
Array([ 4, 16, 25, 36], dtype=int32)

For complex-valued inputs:

>>> x2 = jnp.array([1-3j, -1j, 2])
>>> jnp.square(x2)
Array([-8.-6.j, -1.+0.j,  4.+0.j], dtype=complex64)
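For complex inputs, the square is the plain product of each element with itself, which differs from the squared magnitude. The sketch below contrasts the two using the same input as the example above; the variable names are illustrative.

```python
import jax.numpy as jnp

z = jnp.array([1 - 3j, -1j, 2])

# square multiplies each element by itself, with no conjugation.
squared = jnp.square(z)

# The squared magnitude is real-valued and uses the conjugate instead.
magnitude_sq = (z * jnp.conj(z)).real

print(squared)
print(magnitude_sq)
```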