jax.scipy.linalg.rsf2csf
- jax.scipy.linalg.rsf2csf(T, Z, check_finite=True)
Convert real Schur form to complex Schur form.
JAX implementation of scipy.linalg.rsf2csf().
- Parameters:
T (ArrayLike) – array of shape (..., N, N) containing the real Schur form of a matrix, such as the T returned by jax.scipy.linalg.schur().
Z (ArrayLike) – array of shape (..., N, N) containing the corresponding Schur transformation matrix.
check_finite (bool) – unused by JAX.
- Returns:
A tuple of complex-valued arrays (T, Z) of the same shape as the inputs, containing the complex Schur form and the associated Schur transformation matrix.
- Return type:
tuple[Array, Array]
See also
jax.scipy.linalg.schur(): Schur decomposition
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> A = jnp.array([[0., 3., 3.],
...                [0., 1., 2.],
...                [2., 0., 1.]])
>>> Tr, Zr = jax.scipy.linalg.schur(A)
>>> Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr)
Both the real and the complex forms can be used to reconstruct the input matrix to float32 precision:
>>> jnp.allclose(Zr @ Tr @ Zr.T, A, atol=1E-5)
Array(True, dtype=bool)
>>> jnp.allclose(Zc @ Tc @ Zc.conj().T, A, atol=1E-5)
Array(True, dtype=bool)
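The reconstructions above use Zr.T and Zc.conj().T in place of matrix inverses, which is valid because the real Schur vectors Zr form an orthogonal matrix and the complex ones Zc form a unitary matrix. A short standalone check of both properties on the same example (the 1e-5 tolerance is an assumption suited to float32, mirroring the one above):

```python
import jax.numpy as jnp
import jax.scipy.linalg

A = jnp.array([[0., 3., 3.],
               [0., 1., 2.],
               [2., 0., 1.]])
Tr, Zr = jax.scipy.linalg.schur(A)
Tc, Zc = jax.scipy.linalg.rsf2csf(Tr, Zr)

I = jnp.eye(3)
# Real Schur vectors: Zr.T @ Zr == I, so Zr.T acts as the inverse of Zr.
orthogonal = jnp.allclose(Zr.T @ Zr, I, atol=1e-5)
# Complex Schur vectors: Zc^H @ Zc == I, so Zc.conj().T acts as the inverse of Zc.
unitary = jnp.allclose(Zc.conj().T @ Zc, I, atol=1e-5)
print(orthogonal, unitary)
```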
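The real Schur form is only quasi-upper-triangular: each complex-conjugate pair of eigenvalues occupies a 2×2 block on the diagonal, leaving a nonzero entry just below it, as in this case: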
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(Tr)
[[ 3.76 -2.17  1.38]
 [ 0.   -0.88 -0.35]
 [ 0.    2.37 -0.88]]
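The nonzero entry below the diagonal belongs to the trailing 2×2 block of Tr, whose eigenvalues form a complex-conjugate pair. As a worked check using the printed (rounded) entries, and writing the block as [[a, b], [c, a]] because its two diagonal entries are equal here:

```latex
\begin{pmatrix} a & b \\ c & a \end{pmatrix}
= \begin{pmatrix} -0.88 & -0.35 \\ 2.37 & -0.88 \end{pmatrix},
\qquad
(a-\lambda)^2 - bc = 0
\;\Rightarrow\;
\lambda = a \pm \sqrt{bc}
= -0.88 \pm i\sqrt{0.35 \cdot 2.37}
\approx -0.88 \pm 0.91\,i
```

These agree with the last two diagonal entries of the complex form Tc, where the pair appears separately instead of being bundled in a 2×2 block.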
By contrast, the complex form is truly upper-triangular:
>>> with jnp.printoptions(precision=2, suppress=True):
...     print(Tc)
[[ 3.76+0.j    1.29-0.78j  2.02-0.5j ]
 [ 0.  +0.j   -0.88+0.91j -2.02+0.j  ]
 [ 0.  +0.j    0.  +0.j   -0.88-0.91j]]
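To make the conversion concrete, here is a minimal sketch modeled on the Givens-rotation loop in SciPy's scipy.linalg.rsf2csf; it is not JAX's actual implementation. Walking up the subdiagonal, each 2×2 diagonal block of the real form is split by a unitary similarity built from one of its complex eigenvalues, and the same rotation is folded into Z so that Z T Z^H still reproduces the input. The name rsf2csf_sketch is hypothetical, and because it uses eager Python loops and branches, it assumes a single unbatched (N, N) input and is not jit-compatible.

```python
import jax.numpy as jnp
import jax.scipy.linalg


def rsf2csf_sketch(T, Z):
    """Hypothetical sketch of a real-to-complex Schur conversion (not JAX's code).

    Modeled on the Givens-rotation loop in SciPy's rsf2csf. It uses Python
    control flow on concrete values, so it expects a single 2-D input and
    cannot be traced by jax.jit.
    """
    dtype = jnp.promote_types(jnp.asarray(T).dtype, jnp.complex64)
    T = jnp.asarray(T, dtype=dtype)
    Z = jnp.asarray(Z, dtype=dtype)
    eps = jnp.finfo(jnp.real(T).dtype).eps
    n = T.shape[0]
    # Walk up the subdiagonal, from the bottom-right corner to the top-left.
    for m in range(n - 1, 0, -1):
        sub = T[m, m - 1]
        # A non-negligible subdiagonal entry marks a 2x2 block to split.
        if abs(sub) > eps * (abs(T[m - 1, m - 1]) + abs(T[m, m])):
            p, q, v = T[m - 1, m - 1], T[m - 1, m], T[m, m]
            # mu = (an eigenvalue of the 2x2 block) - T[m, m]; either root works,
            # so take the one larger in magnitude to avoid cancellation.
            half = (p - v) / 2
            root = jnp.sqrt(half * half + q * sub)
            mu = jnp.where(jnp.abs(half + root) >= jnp.abs(half - root),
                           half + root, half - root)
            r = jnp.sqrt(jnp.abs(mu) ** 2 + jnp.abs(sub) ** 2)
            c, s = mu / r, sub / r
            # Unitary rotation; the first column of G^H is proportional to the
            # block eigenvector (mu, sub).
            G = jnp.array([[jnp.conj(c), s], [-s, c]], dtype=dtype)
            # Apply the similarity T <- G T G^H to the rows/columns it touches.
            T = T.at[m - 1:m + 1, m - 1:].set(G @ T[m - 1:m + 1, m - 1:])
            T = T.at[:m + 1, m - 1:m + 1].set(T[:m + 1, m - 1:m + 1] @ G.conj().T)
            # Accumulate the rotation so Z @ T @ Z^H still reproduces the input.
            Z = Z.at[:, m - 1:m + 1].set(Z[:, m - 1:m + 1] @ G.conj().T)
        # The entry is zero in exact arithmetic; clear the rounding residue.
        T = T.at[m, m - 1].set(0)
    return T, Z


A = jnp.array([[0., 3., 3.],
               [0., 1., 2.],
               [2., 0., 1.]])
Tr, Zr = jax.scipy.linalg.schur(A)
Tc, Zc = rsf2csf_sketch(Tr, Zr)
```

On this example the result should be upper-triangular and reconstruct A to float32 precision. The order of the two complex eigenvalues on the diagonal may differ from jax.scipy.linalg.rsf2csf, since either eigenvalue of a block can be used to build the rotation.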