jax.scipy.linalg.sqrtm
- jax.scipy.linalg.sqrtm(A, blocksize=1)
Compute the matrix square root.
JAX implementation of scipy.linalg.sqrtm().
- Parameters:
A (ArrayLike) – array of shape (N, N).
blocksize (int) – Not supported in JAX; JAX always uses blocksize=1.
- Returns:
An array of shape (N, N) containing the matrix square root of A.
- Return type:
Array
See also
Examples
>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> a = jnp.array([[1., 2., 3.],
...                [2., 4., 2.],
...                [3., 2., 1.]])
>>> sqrt_a = jax.scipy.linalg.sqrtm(a)
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(sqrt_a)
[[0.92+0.71j 0.54+0.j   0.92-0.71j]
 [0.54+0.j   1.85+0.j   0.54-0.j  ]
 [0.92-0.71j 0.54-0.j   0.92+0.71j]]
By definition, the matrix square root multiplied by itself should equal the input:
>>> jnp.allclose(a, sqrt_a @ sqrt_a)
Array(True, dtype=bool)
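As a further sanity check, one can use a matrix whose square root is known in closed form. The sketch below assumes a symmetric positive definite input: [[5, 4], [4, 5]] has eigenvalues 9 and 1, so its principal square root has eigenvalues 3 and 1 and works out by hand to [[2, 1], [1, 2]]. The matrix `b` and the tolerance are illustrative choices; the result may come back with a complex dtype, so only agreement up to float32 rounding is expected.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Symmetric positive definite input with eigenvalues 9 and 1.
b = jnp.array([[5., 4.],
               [4., 5.]])
sqrt_b = jax.scipy.linalg.sqrtm(b)

# Hand-computed principal square root: [[2, 1], [1, 2]] squares to b.
expected = jnp.array([[2., 1.],
                      [1., 2.]])

# Compare against the known root and re-check the defining property.
print(jnp.allclose(sqrt_b, expected, atol=1e-4))
print(jnp.allclose(sqrt_b @ sqrt_b, b, atol=1e-4))
```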
Notes
This function implements the complex Schur method described in [1]. It does not use recursive blocking to speed up computations, because a Sylvester equation solver is not yet available in JAX.
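The unblocked complex Schur method can be sketched as follows. This is an illustrative reimplementation, not JAX's internal code, and the function name `sqrtm_complex_schur` is hypothetical. It computes a complex Schur form A = Z T Z^H, builds an upper-triangular R with R @ R == T one entry at a time, and maps back with Z R Z^H. It assumes R[i, i] + R[j, j] is nonzero for all i < j, which holds, for example, whenever the eigenvalues of A are distinct. `jax.scipy.linalg.schur` may only be supported on CPU backends, and the Python loops make this suitable only for small matrices.

```python
import jax.numpy as jnp
from jax.scipy.linalg import schur


def sqrtm_complex_schur(A):
    """Hypothetical sketch of the unblocked complex Schur square root.

    Assumes A is square and R[i, i] + R[j, j] != 0 for all i < j
    (true, e.g., when the eigenvalues of A are distinct).
    """
    # Step 1: complex Schur form A = Z @ T @ Z^H with T upper triangular.
    A = jnp.asarray(A).astype(jnp.complex64)
    T, Z = schur(A, output="complex")
    n = T.shape[0]

    # Step 2: upper-triangular R with R @ R == T.
    # The diagonal holds principal square roots of the eigenvalues.
    R = jnp.diag(jnp.sqrt(jnp.diag(T)))
    # Off-diagonal entries follow from equating (R @ R)[i, j] with T[i, j]:
    #   R[i, j] = (T[i, j] - sum_{k=i+1}^{j-1} R[i, k] R[k, j]) / (R[i, i] + R[j, j])
    # Columns go left to right and rows bottom to top, so every entry on
    # the right-hand side has already been computed.
    for j in range(1, n):
        for i in range(j - 1, -1, -1):
            s = jnp.dot(R[i, i + 1:j], R[i + 1:j, j])
            R = R.at[i, j].set((T[i, j] - s) / (R[i, i] + R[j, j]))

    # Step 3: undo the unitary similarity: sqrt(A) = Z @ R @ Z^H.
    return Z @ R @ Z.conj().T


a = jnp.array([[1., 2., 3.],
               [2., 4., 2.],
               [3., 2., 1.]])
r = sqrtm_complex_schur(a)
print(jnp.allclose(r @ r, a, atol=1e-4))
```

This `a` has a negative eigenvalue, and the branch that `jnp.sqrt` picks on the diagonal depends on the sign of the imaginary part produced by the Schur step, so the sketch may return a different square root than `jax.scipy.linalg.sqrtm` for such inputs. For that reason the check compares R @ R with A instead of comparing against `sqrtm` directly.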
References