jax.lax.linalg.qr

jax.lax.linalg.qr(x: ArrayLike, *, pivoting: Literal[False], full_matrices: bool = True, use_magma: bool | None = None) → tuple[Array, Array]
jax.lax.linalg.qr(x: ArrayLike, *, pivoting: Literal[True], full_matrices: bool = True, use_magma: bool | None = None) → tuple[Array, Array, Array]
jax.lax.linalg.qr(x: ArrayLike, *, pivoting: bool = False, full_matrices: bool = True, use_magma: bool | None = None) → tuple[Array, Array] | tuple[Array, Array, Array]

QR decomposition.

Computes the QR decomposition

\[A = Q \, R\]

of matrices \(A\), such that \(Q\) is a unitary (orthogonal) matrix, and \(R\) is an upper-triangular matrix.
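As a minimal sketch of the factorization above, the following factors a small matrix with the default arguments and checks the stated properties. The matrix values are arbitrary illustration data, and the tolerances assume JAX's default float32 precision.

```python
import jax.numpy as jnp
from jax import lax

# Arbitrary 3x2 example matrix (illustrative values only).
a = jnp.array([[1.0, 2.0],
               [3.0, 4.0],
               [5.0, 6.0]])

# Defaults: pivoting=False, full_matrices=True.
q, r = lax.linalg.qr(a)

# Q @ R reproduces A up to floating-point error.
reconstructs = bool(jnp.allclose(q @ r, a, atol=1e-4))

# Q is orthogonal: Q^T Q is the identity.
is_orthogonal = bool(jnp.allclose(q.T @ q, jnp.eye(q.shape[0]), atol=1e-4))

# R is upper-triangular: entries below the main diagonal are zero.
is_upper = bool(jnp.allclose(r, jnp.triu(r), atol=1e-6))
```

All three flags should evaluate to True for a well-conditioned real input such as this one.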

Parameters:
  • x – A batch of matrices with shape [..., m, n].

  • pivoting – Allows the QR decomposition to be rank-revealing. If True, compute the column-pivoted decomposition A[:, P] = Q @ R, where P is chosen such that the magnitudes of the diagonal entries of R are non-increasing. Currently supported on CPU and GPU backends only.

  • full_matrices – Determines if full or reduced matrices are returned; see below.

  • use_magma – Locally override the jax_use_magma flag. If True, the pivoted QR factorization is computed using MAGMA. If False, the computation is done using LAPACK on the host CPU. If None (default), the behavior is controlled by the jax_use_magma flag. This argument is only used on GPU.

Returns:

A pair of arrays (q, r), if pivoting=False, otherwise (q, r, p).

Array q is a unitary (orthogonal) matrix, with shape [..., m, m] if full_matrices=True, or [..., m, min(m, n)] if full_matrices=False.

Array r is an upper-triangular matrix with shape [..., m, n] if full_matrices=True, or [..., min(m, n), n] if full_matrices=False.

Array p is an index vector of column pivots with shape [..., n], returned only if pivoting=True.
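The shape rules above can be checked with a small sketch on a batch of two 4x3 matrices, so m = 4 and n = 3. The input values are arbitrary; only the shapes matter here.

```python
import jax.numpy as jnp
from jax import lax

# Batch of two 4x3 matrices: shape [..., m, n] with m=4, n=3.
x = jnp.arange(24.0).reshape(2, 4, 3)

# Full factorization: q is [..., m, m], r is [..., m, n].
q_full, r_full = lax.linalg.qr(x, full_matrices=True)

# Reduced factorization: q is [..., m, min(m, n)], r is [..., min(m, n), n].
q_red, r_red = lax.linalg.qr(x, full_matrices=False)

shapes = {
    "q_full": q_full.shape,  # (2, 4, 4)
    "r_full": r_full.shape,  # (2, 4, 3)
    "q_red": q_red.shape,    # (2, 4, 3)
    "r_red": r_red.shape,    # (2, 3, 3)
}
```

The reduced form is usually enough for least-squares problems and avoids storing the extra m - min(m, n) columns of q.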

Notes

  • MAGMA support is experimental; see jax.lax.linalg.eig() for further assumptions and limitations.

  • If jax_use_magma is set to "auto", the MAGMA implementation will be used if the library can be found, and the input matrix is sufficiently large (has at least 2048 columns).