jax.numpy.linalg.matrix_rank
- jax.numpy.linalg.matrix_rank(M, rtol=None, *, tol=Deprecated)
Compute the rank of a matrix.
JAX implementation of numpy.linalg.matrix_rank().

The rank is calculated via the Singular Value Decomposition (SVD): it is the number of singular values greater than the specified tolerance.
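The SVD-based definition above can be sketched directly with `jnp.linalg.svd`. This is an illustration, not JAX's source: the helper `rank_via_svd` is hypothetical, and its default tolerance, `S.max() * max(N, K) * eps`, is an assumption borrowed from NumPy's documented default. JAX's actual default may differ in detail.

```python
import jax.numpy as jnp


def rank_via_svd(M):
    """Hypothetical helper: count singular values above a default tolerance.

    Assumption: the tolerance mirrors NumPy's documented default,
    S.max() * max(N, K) * eps; JAX's exact default may differ.
    """
    M = jnp.asarray(M, dtype=jnp.float32)
    # Singular values only, shape (..., min(N, K)).
    S = jnp.linalg.svd(M, compute_uv=False)
    eps = jnp.finfo(S.dtype).eps
    # Per-matrix threshold; keepdims lets it broadcast against S.
    tol = S.max(axis=-1, keepdims=True) * max(M.shape[-2:]) * eps
    # Rank = number of singular values strictly above the threshold.
    return jnp.sum(S > tol, axis=-1)


a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([[1.0, 0.0], [0.0, 0.0]])
print(rank_via_svd(a), jnp.linalg.matrix_rank(a))  # 2 2
print(rank_via_svd(b), jnp.linalg.matrix_rank(b))  # 1 1
```

Because the threshold is computed per matrix, the same helper also works on a stacked batch of shape `(..., N, K)`.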
- Parameters:
  - M (ArrayLike) – array of shape (..., N, K) whose rank is to be computed.
  - rtol (ArrayLike | None) – optional array of shape (...) specifying the tolerance. Singular values smaller than rtol * largest_singular_value are considered to be zero. If rtol is None (the default), a reasonable default is chosen based on the floating-point precision of the input.
  - tol (ArrayLike | DeprecatedArg | None) – deprecated alias of the rtol argument. Using it results in a DeprecationWarning.
- Returns:
  array of shape M.shape[:-2] (the batch dimensions (...)) giving the rank of each matrix.
- Return type:
  Array
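A minimal sketch of the batched behavior, assuming the default tolerance: stacking three 4x3 matrices gives an input of shape `(3, 4, 3)`, and one rank is returned per matrix, so the result has the batch shape `(3,)`. The variable names are illustrative only.

```python
import jax.numpy as jnp

full = jnp.eye(4, 3)            # 4x3 with three unit pivots: rank 3
rank2 = full.at[2, 2].set(0.0)  # zero out one pivot: rank 2
zeros = jnp.zeros((4, 3))       # all singular values are zero: rank 0
batch = jnp.stack([full, rank2, zeros])  # shape (3, 4, 3)

ranks = jnp.linalg.matrix_rank(batch)
print(ranks.shape)  # (3,)
print(ranks)        # [3 2 0]
```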
Notes
The rank calculation may be inaccurate for matrices with very small singular values or for matrices that are numerically ill-conditioned. In such cases, consider adjusting the rtol parameter or using a more specialized rank computation method.
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.matrix_rank(a)
Array(2, dtype=int32)

>>> b = jnp.array([[1, 0],  # Rank-deficient matrix
...                [0, 0]])
>>> jnp.linalg.matrix_rank(b)
Array(1, dtype=int32)
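The following sketch illustrates the Notes section: how rtol changes which singular values count, and how a mathematically full-rank but ill-conditioned matrix can be reported as rank-deficient under the default tolerance. Diagonal matrices are used so the singular values are known exactly, and the largest singular value is 1.0 so the cutoff is simply rtol. Inputs are forced to float32; the roughly 2.4e-7 figure assumes a default of about `max(N, K) * eps` times the largest singular value, which is an assumption rather than a documented JAX value.

```python
import jax.numpy as jnp

# Singular values of a diagonal matrix are its (absolute) diagonal entries.
M = jnp.diag(jnp.array([1.0, 0.5, 1e-3], dtype=jnp.float32))
print(jnp.linalg.matrix_rank(M))             # 3: all singular values kept
print(jnp.linalg.matrix_rank(M, rtol=1e-2))  # 2: 1e-3 falls below 1e-2 * 1.0
print(jnp.linalg.matrix_rank(M, rtol=0.6))   # 1: only 1.0 survives

# Mathematically full rank, but the smallest singular value (1e-8) is below
# an assumed float32 default cutoff of about 2 * eps * 1.0 ~= 2.4e-7.
ill = jnp.diag(jnp.array([1.0, 1e-8], dtype=jnp.float32))
print(jnp.linalg.matrix_rank(ill))             # 1: 1e-8 treated as zero
print(jnp.linalg.matrix_rank(ill, rtol=1e-9))  # 2: a smaller rtol keeps it
```

Lowering rtol recovers the small singular value here, but in general a cutoff close to machine precision also admits rounding noise, which is why the default scales with eps.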