jax.numpy.linalg.pinv
- jax.numpy.linalg.pinv(a, rtol=None, hermitian=False, *, rcond=Deprecated)
Compute the (Moore-Penrose) pseudo-inverse of a matrix.
JAX implementation of numpy.linalg.pinv().
- Parameters:
  - a (ArrayLike) – array of shape (..., M, N) containing matrices to pseudo-invert.
  - rtol (ArrayLike | None) – float or array_like of shape a.shape[:-2]. Specifies the cutoff for small singular values: singular values smaller than rtol * largest_singular_value are treated as zero. The default is determined based on the floating point precision of the dtype.
  - hermitian (bool) – if True, then the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False).
  - rcond (ArrayLike | DeprecatedArg | None) – deprecated alias of the rtol argument. Will result in a DeprecationWarning if used.
- Returns:
  An array of shape (..., N, M) containing the pseudo-inverse of a.
- Return type:
  Array
See also
jax.numpy.linalg.inv(): multiplicative inverse of a square matrix.
Notes
jax.numpy.linalg.pinv() differs from numpy.linalg.pinv() in the default value of the cutoff (rtol, previously rcond): in NumPy, the default is 1e-15. In JAX, the default is 10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps.
Examples
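As a minimal sketch of the default cutoff described in the Notes, the snippet below computes 10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps by hand and passes it explicitly as rtol. The names default_rtol, explicit, implicit, and same are illustrative only, and the comparison assumes the default float32 precision.

```python
import jax.numpy as jnp

a = jnp.array([[1., 2.],
               [3., 4.],
               [5., 6.]])

# Default cutoff as stated in the Notes:
# 10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps
num_rows, num_cols = a.shape[-2:]
default_rtol = float(10. * max(num_rows, num_cols) * jnp.finfo(a.dtype).eps)

# Passing the hand-computed value should agree with leaving rtol unset.
explicit = jnp.linalg.pinv(a, rtol=default_rtol)
implicit = jnp.linalg.pinv(a)
same = bool(jnp.allclose(explicit, implicit, atol=1e-5))
```

Singular values at or below rtol times the largest singular value are dropped, so a larger rtol discards more of the spectrum.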
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4],
...                [5, 6]])
>>> a_pinv = jnp.linalg.pinv(a)
>>> a_pinv
Array([[-1.333332  , -0.33333257,  0.6666657 ],
       [ 1.0833322 ,  0.33333272, -0.41666582]], dtype=float32)
The pseudo-inverse operates as a multiplicative inverse so long as the output is not rank-deficient:
>>> jnp.allclose(a_pinv @ a, jnp.eye(2), atol=1E-4)
Array(True, dtype=bool)
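For contrast, here is a hedged sketch of the rank-deficient case. The matrix b is a made-up example whose second column is twice the first, so it has rank 1. Under the default cutoff, the pseudo-inverse is then expected to fail as a left inverse while still satisfying the Moore-Penrose identity b @ b_pinv @ b == b. The variable names are illustrative only.

```python
import jax.numpy as jnp

# Rank-1 matrix: the second column is twice the first.
b = jnp.array([[1., 2.],
               [2., 4.],
               [3., 6.]])
b_pinv = jnp.linalg.pinv(b)

# Check whether b_pinv acts as a left inverse of b
# (it should not, since b lacks full column rank).
is_left_inverse = bool(jnp.allclose(b_pinv @ b, jnp.eye(2), atol=1e-4))

# Check the defining Moore-Penrose identity b @ b_pinv @ b == b.
satisfies_identity = bool(jnp.allclose(b @ b_pinv @ b, b, atol=1e-4))
```

In this case b_pinv @ b is the orthogonal projector onto the row space of b rather than the identity matrix.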