jax.numpy.sqrt

jax.numpy.sqrt(x, /)

Calculates the element-wise non-negative square root of the input array.

JAX implementation of numpy.sqrt.
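As a quick sanity check, the sketch below compares jnp.sqrt against numpy.sqrt on the same non-negative float32 data. The sample values are arbitrary and chosen only for illustration. Agreement is checked up to floating-point tolerance rather than bit-for-bit, since the two libraries may use different kernels.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary non-negative sample data (illustrative only).
x = np.array([0.0, 0.25, 2.0, 9.0, 1e6], dtype=np.float32)

expected = np.sqrt(x)   # NumPy reference result (float32 in, float32 out)
actual = jnp.sqrt(x)    # JAX result, returned as a jax.Array

# On non-negative inputs the two should agree to floating-point tolerance.
assert np.allclose(np.asarray(actual), expected)
print(np.asarray(actual))
```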

Parameters:

x (ArrayLike) – input array or scalar.

Returns:

An array containing the non-negative square root of the elements of x.

Return type:

Array
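The sketch below illustrates the return value on integer input, assuming JAX's default configuration. Integer arrays are promoted to a floating-point dtype (float32 unless 64-bit mode is enabled via jax_enable_x64), and each result is the non-negative root, so 4 maps to 2.0 rather than -2.0.

```python
import jax.numpy as jnp

ints = jnp.array([0, 1, 4, 9, 16])   # integer input
roots = jnp.sqrt(ints)

# Integer inputs are promoted to an inexact (floating-point) dtype.
assert jnp.issubdtype(roots.dtype, jnp.floating)

# Every result is the non-negative square root of the corresponding element.
assert bool(jnp.all(roots >= 0))
assert bool(jnp.allclose(roots, jnp.array([0.0, 1.0, 2.0, 3.0, 4.0])))
print(roots.dtype, roots)
```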

Note

  • For real-valued negative inputs, jnp.sqrt produces a nan output.

  • For negative values in complex-valued inputs, jnp.sqrt produces a complex output: the principal square root, whose real part is non-negative.
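The sketch below shows both cases from the note side by side. A negative entry in a real-valued array yields nan, while casting the same array to a complex dtype first (here complex64, an illustrative choice) yields the principal complex root, so -4 maps to 2j.

```python
import jax.numpy as jnp

x = jnp.array([-4.0, 0.0, 9.0])   # real-valued input with one negative entry

real_result = jnp.sqrt(x)
# The negative entry has no real square root, so it becomes nan.
assert bool(jnp.isnan(real_result[0]))

# Casting to a complex dtype first gives the principal complex root instead.
complex_result = jnp.sqrt(x.astype(jnp.complex64))
assert bool(jnp.allclose(complex_result, jnp.array([2j, 0, 3])))
print(real_result)
print(complex_result)
```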

See also

Examples

>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sqrt(x)
Array([1.   -3.j   , 0.707+0.707j, 2.   +0.j   ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)