jax.numpy.linalg.qr
- jax.numpy.linalg.qr(a, mode='reduced')
Compute the QR decomposition of an array.
JAX implementation of numpy.linalg.qr().
The QR decomposition of a matrix A is given by
\[A = QR\]
where Q is a unitary matrix (i.e. \(Q^HQ=I\)) and R is an upper-triangular matrix.
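The definition above can be checked numerically. The following is a minimal sketch on a complex input, where the conjugate transpose \(Q^H\) differs from the plain transpose; the matrix values and tolerances are arbitrary choices made for illustration.

```python
import jax.numpy as jnp

# A small complex matrix (values chosen arbitrarily for illustration).
a = jnp.array([[1 + 2j, 2 - 1j],
               [0 + 1j, 3 + 0j],
               [2 + 0j, 1 + 1j]])

Q, R = jnp.linalg.qr(a)  # default mode="reduced"

# Q^H Q = I: for complex input the conjugate transpose is required.
QhQ = Q.conj().T @ Q
identity_ok = bool(jnp.allclose(QhQ, jnp.eye(2), atol=1e-5))

# R is upper triangular: zeroing the strictly lower part changes nothing.
triangular_ok = bool(jnp.allclose(R, jnp.triu(R)))

# A = QR up to floating-point error.
reconstruct_ok = bool(jnp.allclose(Q @ R, a, atol=1e-5))

print(identity_ok, triangular_ok, reconstruct_ok)
```

For real-valued inputs, `Q.conj().T` reduces to `Q.T`, which is the form used in the Examples section.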
- Parameters:
a (ArrayLike) – array of shape (…, M, N)
mode (str) – Computational mode. Supported values are:
  - "reduced" (default): return Q of shape (..., M, K) and R of shape (..., K, N), where K = min(M, N).
  - "complete": return Q of shape (..., M, M) and R of shape (..., M, N).
  - "raw": return LAPACK-internal representations of shape (..., M, N) and (..., K).
  - "r": return R only.
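As a rough illustration of how the modes differ, the sketch below factors a tall matrix (M = 4, N = 3, so K = 3) in the "reduced", "complete", and "r" modes. The input values are arbitrary, and the "raw" mode is left out because its LAPACK-internal output is not meant to be inspected directly.

```python
import jax.numpy as jnp

# A tall input with M = 4 rows and N = 3 columns, so K = min(M, N) = 3.
a = jnp.arange(12.0).reshape(4, 3) + jnp.eye(4, 3)

Q_red, R_red = jnp.linalg.qr(a, mode="reduced")
Q_full, R_full = jnp.linalg.qr(a, mode="complete")
R_only = jnp.linalg.qr(a, mode="r")  # a single array, not a tuple

print(Q_red.shape, R_red.shape)    # (4, 3) (3, 3)
print(Q_full.shape, R_full.shape)  # (4, 4) (4, 3)

# Both factorizations reconstruct the input up to floating-point error.
reduced_ok = bool(jnp.allclose(Q_red @ R_red, a, atol=1e-4))
complete_ok = bool(jnp.allclose(Q_full @ R_full, a, atol=1e-4))

# The "r" result is upper triangular, like the R factor in the other modes.
r_triangular = bool(jnp.allclose(R_only, jnp.triu(R_only)))

print(reduced_ok, complete_ok, r_triangular)
```

The "reduced" mode is usually sufficient when only the product QR or a least-squares solve is needed; "complete" additionally yields a full orthonormal basis of the M-dimensional space.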
- Returns:
A tuple (Q, R) (if mode is not "r"), otherwise an array R, where:
  - Q is an orthogonal matrix of shape (..., M, K) (if mode is "reduced") or (..., M, M) (if mode is "complete").
  - R is an upper-triangular matrix of shape (..., M, N) (if mode is "r" or "complete") or (..., K, N) (if mode is "reduced"),
with K = min(M, N).
- Return type:
Array | QRResult
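The leading `...` in the shapes above stands for optional batch dimensions. A minimal sketch, assuming a batch of two arbitrary 4x3 matrices, shows the returned pair being unpacked and the factorization holding for every batch entry:

```python
import jax.numpy as jnp

# A batch of two 4x3 matrices: shape (2, M, N) with M = 4, N = 3.
batch = jnp.stack([
    jnp.arange(12.0).reshape(4, 3) + jnp.eye(4, 3),
    jnp.ones((4, 3)) + 2.0 * jnp.eye(4, 3),
])

result = jnp.linalg.qr(batch)  # default mode="reduced"

# The result unpacks like an ordinary 2-tuple.
Q, R = result
print(Q.shape, R.shape)  # (2, 4, 3) (2, 3, 3)

# Matrix products broadcast over the batch axis, so each matrix is recovered.
reconstruct_ok = bool(jnp.allclose(Q @ R, batch, atol=1e-4))

# Orthonormal columns per batch entry: transpose only the last two axes.
QtQ = jnp.swapaxes(Q, -1, -2) @ Q
orthonormal_ok = bool(jnp.allclose(QtQ, jnp.eye(3), atol=1e-4))

print(reconstruct_ok, orthonormal_ok)
```

Note that `Q.T` would reverse all axes of a batched array, which is why `jnp.swapaxes` is used here instead.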
See also
- jax.scipy.linalg.qr(): SciPy-style QR decomposition API
- jax.lax.linalg.qr(): XLA-style QR decomposition API
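For orientation, the sketch below calls all three APIs on the same matrix. It assumes that `mode="economic"` in the SciPy-style API and `full_matrices=False` in the XLA-style API correspond to `mode="reduced"` here; that correspondence is drawn from those functions' own documentation, not from this page.

```python
import jax.numpy as jnp
from jax import lax
from jax.scipy import linalg as jsp_linalg

a = jnp.array([[1., 2., 3., 4.],
               [5., 4., 2., 1.],
               [6., 3., 1., 5.]])

# NumPy-style API (this function), default mode="reduced".
Q_np, R_np = jnp.linalg.qr(a)

# SciPy-style API; "economic" is assumed to match "reduced".
Q_sp, R_sp = jsp_linalg.qr(a, mode="economic")

# XLA-style API; full_matrices=False is assumed to match "reduced".
Q_lax, R_lax = lax.linalg.qr(a, full_matrices=False)

same_as_scipy = bool(jnp.allclose(Q_np, Q_sp) and jnp.allclose(R_np, R_sp))
same_as_lax = bool(jnp.allclose(Q_np, Q_lax) and jnp.allclose(R_np, R_lax))
print(same_as_scipy, same_as_lax)
```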
Examples
Compute the QR decomposition of a matrix:
>>> import jax.numpy as jnp
>>> a = jnp.array([[1., 2., 3., 4.],
...                [5., 4., 2., 1.],
...                [6., 3., 1., 5.]])
>>> Q, R = jnp.linalg.qr(a)
>>> Q
Array([[-0.12700021, -0.7581426 , -0.6396022 ],
       [-0.63500065, -0.43322435,  0.63960224],
       [-0.7620008 ,  0.48737738, -0.42640156]], dtype=float32)
>>> R
Array([[-7.8740077, -5.080005 , -2.4130025, -4.953006 ],
       [ 0.       , -1.7870499, -2.6534991, -1.028908 ],
       [ 0.       ,  0.       , -1.0660033, -4.050814 ]], dtype=float32)
Check that Q is orthonormal:
>>> jnp.allclose(Q.T @ Q, jnp.eye(3), atol=1E-5)
Array(True, dtype=bool)
Reconstruct the input:
>>> jnp.allclose(Q @ R, a)
Array(True, dtype=bool)