jax.numpy.sqrt
- jax.numpy.sqrt(x, /)
Calculates element-wise non-negative square root of the input array.
JAX implementation of numpy.sqrt.
- Parameters:
x (ArrayLike) – input array or scalar.
- Returns:
An array containing the non-negative square root of the elements of x.
- Return type:
Note
For real-valued negative inputs, jnp.sqrt produces a nan output.
For complex-valued negative inputs, jnp.sqrt produces a complex output.
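The difference between real and complex handling of negative values can be sketched as follows. This is a minimal illustration, not part of the original docstring; the variable names are hypothetical, and casting with `astype(jnp.complex64)` is just one assumed way to opt into complex output.

```python
import jax.numpy as jnp

values = jnp.array([-4.0, 9.0])

# Real-valued input: the negative element has no real square root, so it becomes nan.
real_result = jnp.sqrt(values)

# Complex-valued input: the same negative element yields a complex root instead.
complex_result = jnp.sqrt(values.astype(jnp.complex64))

print(real_result)
print(complex_result)
```

Casting to a complex dtype before calling `jnp.sqrt` is a common way to avoid nan when negative inputs are expected.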
See also
jax.numpy.square(): Calculates the element-wise square of the input.
jax.numpy.power(): Calculates the element-wise base x1 exponential of x2.
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([-8-6j, 1j, 4])
>>> with jnp.printoptions(precision=3, suppress=True):
...     jnp.sqrt(x)
Array([1.   -3.j   , 0.707+0.707j, 2.   +0.j   ], dtype=complex64)
>>> jnp.sqrt(-1)
Array(nan, dtype=float32, weak_type=True)
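As a further sketch, the relationship between `jnp.sqrt` and the related functions `jax.numpy.square()` and `jax.numpy.power()` can be checked numerically for non-negative real inputs. The variable names here are illustrative assumptions, and results are compared approximately because of floating-point rounding.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 4.0, 9.0])

roots = jnp.sqrt(x)

# Squaring the roots should recover the original non-negative inputs.
squared_back = jnp.square(roots)

# Raising to the power 0.5 should agree with sqrt for non-negative inputs.
via_power = jnp.power(x, 0.5)

print(bool(jnp.allclose(squared_back, x)))
print(bool(jnp.allclose(via_power, roots)))
```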