jax.scipy.linalg.polar
- jax.scipy.linalg.polar(a, side='right', *, method='qdwh', eps=None, max_iterations=None)
Computes the polar decomposition.
Given the \(m \times n\) matrix \(a\), returns the factors of the polar decomposition \(u\) (also \(m \times n\)) and \(p\) such that \(a = up\) (if side is "right"; \(p\) is \(n \times n\)) or \(a = pu\) (if side is "left"; \(p\) is \(m \times m\)), where \(p\) is positive semidefinite. If \(a\) is nonsingular, \(p\) is positive definite and the decomposition is unique. \(u\) has orthonormal columns unless \(n > m\), in which case it has orthonormal rows.
Writing the SVD of \(a\) as \(a = u_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}\), we have \(u = u_\mathit{svd} \cdot v^h_\mathit{svd}\). Thus the unitary factor \(u\) can be constructed by applying the sign function to the singular values of \(a\) or, if \(a\) is Hermitian, to its eigenvalues.
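For a Hermitian input, the sign-function view can be checked directly. The sketch below uses an illustrative real symmetric, nonsingular matrix `h` (not from the original text), eigendecomposes it as \(h = V \Lambda V^h\) with `jnp.linalg.eigh`, and builds \(u = V \operatorname{sign}(\Lambda) V^h\) and \(p = V |\Lambda| V^h\), so that \(up = V \Lambda V^h = h\). It then compares both factors with `jax.scipy.linalg.polar`. The `atol=1e-4` tolerance is an assumption chosen for the default float32 dtype on CPU.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Illustrative real symmetric (hence Hermitian), nonsingular matrix.
h = jnp.array([[2., 1.],
               [1., -3.]])

# Eigendecomposition h = V diag(w) V^h.
w, v = jnp.linalg.eigh(h)

# Unitary factor: apply the sign function to the eigenvalues.
u_eig = (v * jnp.sign(w)) @ v.conj().T
# Positive-definite factor: replace each eigenvalue by its absolute value.
p_eig = (v * jnp.abs(w)) @ v.conj().T

# Library result with the default side="right" and method="qdwh".
u, p = jax.scipy.linalg.polar(h)

print(jnp.allclose(u, u_eig, atol=1e-4))
print(jnp.allclose(p, p_eig, atol=1e-4))
print(jnp.allclose(u_eig @ p_eig, h, atol=1e-4))
```

Because `h` is nonsingular, the decomposition is unique, so the eigenvalue-based factors and the library's factors should agree up to rounding.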
Several methods exist to compute the polar decomposition. Currently two are supported:
- method="svd": Computes the SVD of \(a\) and then forms \(u = u_\mathit{svd} \cdot v^h_\mathit{svd}\).
- method="qdwh": Applies the QDWH (QR-based Dynamically Weighted Halley) algorithm.
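On a nonsingular input the decomposition is unique, so both methods should return the same factors up to floating-point error. The sketch below reuses the 3x3 matrix from the Examples, forms \(u = u_\mathit{svd} \cdot v^h_\mathit{svd}\) by hand from `jnp.linalg.svd`, together with the right-side factor \(p = v_\mathit{svd} \cdot s_\mathit{svd} \cdot v^h_\mathit{svd}\) (which follows from \(a = (u_\mathit{svd} v^h_\mathit{svd})(v_\mathit{svd} s_\mathit{svd} v^h_\mathit{svd})\)), and compares them with both `method` values. The `atol=1e-4` tolerance is an assumption for float32 on CPU.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Nonsingular 3x3 matrix (the same one used in the Examples).
a = jnp.array([[1., 2., 3.],
               [5., 4., 2.],
               [3., 2., 1.]])

# Thin SVD: a = u_svd @ diag(s_svd) @ vh_svd.
u_svd, s_svd, vh_svd = jnp.linalg.svd(a, full_matrices=False)

# Unitary factor u = u_svd @ vh_svd.
u_manual = u_svd @ vh_svd
# Right-side positive-semidefinite factor p = v_svd @ diag(s_svd) @ vh_svd.
p_manual = (vh_svd.conj().T * s_svd) @ vh_svd

u_from_svd, p_from_svd = jax.scipy.linalg.polar(a, method="svd")
u_from_qdwh, p_from_qdwh = jax.scipy.linalg.polar(a, method="qdwh")

print(jnp.allclose(u_manual, u_from_svd, atol=1e-4))
print(jnp.allclose(p_manual, p_from_svd, atol=1e-4))
print(jnp.allclose(u_from_qdwh, u_from_svd, atol=1e-4))
print(jnp.allclose(p_from_qdwh, p_from_svd, atol=1e-4))
```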
- Parameters:
a (ArrayLike) – The \(m \times n\) input matrix.
side (str) – Determines whether a right or left polar decomposition is computed. If side is "right", then \(a = up\). If side is "left", then \(a = pu\). The default is "right".
method (str) – Determines the algorithm used, as described above.
precision – A jax.lax.Precision object specifying the matmul precision.
eps (float | None) – The final result will satisfy \(\left|x_k - x_{k-1}\right| < \left|x_k\right| (4\epsilon)^{\frac{1}{3}}\), where \(x_k\) are the QDWH iterates and \(\epsilon\) is eps. Ignored if method is not "qdwh".
max_iterations (int | None) – Iterations will terminate after this many steps even if the above condition is unsatisfied. Ignored if method is not "qdwh".
- Returns:
A (unitary, posdef) tuple, where unitary is the unitary factor (\(m \times n\)) and posdef is the positive-semidefinite factor. posdef is either \(n \times n\) or \(m \times m\), depending on whether side is "right" or "left", respectively.
- Return type:
tuple[Array, Array]
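The shape rules for unitary and posdef can be checked on non-square inputs. The sketch below uses an illustrative 3x2 matrix and its 2x3 transpose (neither is from the original text) and passes method="svd" throughout; that choice is an assumption to keep the sketch simple, since this page does not state which shape and side combinations method="qdwh" accepts. Tolerances assume float32 on CPU.

```python
import jax.numpy as jnp
import jax.scipy.linalg

# Illustrative tall (3x2) matrix with full column rank.
tall = jnp.array([[1., 2.],
                  [3., 4.],
                  [5., 7.]])

# Right decomposition: tall = u @ p, with u 3x2 and p 2x2.
u_r, p_r = jax.scipy.linalg.polar(tall, side="right", method="svd")
print(u_r.shape, p_r.shape)

# Left decomposition: tall = p @ u, with u 3x2 and p 3x3.
u_l, p_l = jax.scipy.linalg.polar(tall, side="left", method="svd")
print(u_l.shape, p_l.shape)

print(jnp.allclose(u_r @ p_r, tall, atol=1e-4))
print(jnp.allclose(p_l @ u_l, tall, atol=1e-4))

# m >= n: u has orthonormal columns, u^T u = I_2.
print(jnp.allclose(u_r.T @ u_r, jnp.eye(2), atol=1e-5))

# n > m: u has orthonormal rows, u u^T = I_2.
u_w, _ = jax.scipy.linalg.polar(tall.T, method="svd")
print(jnp.allclose(u_w @ u_w.T, jnp.eye(2), atol=1e-5))

# Both posdef factors are symmetric with eigenvalues >= 0 up to rounding
# (p_l has rank 2, so one of its eigenvalues is numerically zero).
for pd in (p_r, p_l):
    print(jnp.allclose(pd, pd.T, atol=1e-5),
          bool(jnp.all(jnp.linalg.eigvalsh(pd) > -1e-4)))
```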
Examples
Polar decomposition of a 3x3 matrix:
>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> a = jnp.array([[1., 2., 3.],
...                [5., 4., 2.],
...                [3., 2., 1.]])
>>> U, P = jax.scipy.linalg.polar(a)
U is a unitary matrix:
>>> jnp.round(U.T @ U)
Array([[ 1., -0., -0.],
       [-0.,  1.,  0.],
       [-0.,  0.,  1.]], dtype=float32)
P is a positive-semidefinite matrix:
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(P)
[[4.79 3.25 1.23]
 [3.25 3.06 2.01]
 [1.23 2.01 2.91]]
The original matrix can be reconstructed by multiplying U and P:
>>> a_reconstructed = U @ P
>>> jnp.allclose(a, a_reconstructed)
Array(True, dtype=bool)