jax.lax.custom_linear_solve
- jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve=None, symmetric=False, has_aux=False)
Perform a matrix-free linear solve with implicitly defined gradients.
This function lets you override or define gradients for a linear solve directly via implicit differentiation at the solution, rather than by differentiating through the solve operation. This can sometimes be much faster or more numerically stable. Differentiating through the solve operation may also be unsupported altogether, e.g., if solve uses lax.while_loop, which does not support reverse-mode differentiation.
Required invariant:
    x = solve(matvec, b)  # solve the linear equation
    assert matvec(x) == b  # not checked
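As a minimal sketch of the invariant and of implicit differentiation, assume a small non-symmetric matrix A (illustrative values) and a hypothetical helper explicit_jacobian_solve that materializes the linear map it receives with jax.jacobian and then calls jnp.linalg.solve. Implicitly differentiating A x = b gives dx = A^{-1}(db - dA x), so reverse mode needs the cotangent of b, which is A^{-T} applied to the cotangent of x. Because A is not symmetric, the same helper is passed as transpose_solve, where it receives vecmat instead of matvec. The gradients with respect to A and b are then compared against differentiating jnp.linalg.solve directly.

```python
import jax
import jax.numpy as jnp
from jax import lax


def explicit_jacobian_solve(f, rhs):
    # Hypothetical helper: materialize the linear map f as a dense matrix
    # (its Jacobian, since f is linear) and solve densely.
    # custom_linear_solve never differentiates this function.
    mat = jax.jacobian(f)(rhs)
    return jnp.linalg.solve(mat, rhs)


def solve_custom(A, b):
    # matvec closes over A, so gradients w.r.t. A flow through matvec.
    matvec = lambda x: A @ x
    # A is not symmetric, so transpose_solve is needed for reverse mode.
    return lax.custom_linear_solve(matvec, b, solve=explicit_jacobian_solve,
                                   transpose_solve=explicit_jacobian_solve)


def loss_custom(A, b):
    return jnp.sum(solve_custom(A, b) ** 2)


def loss_reference(A, b):
    # Same loss, differentiating through jnp.linalg.solve directly.
    return jnp.sum(jnp.linalg.solve(A, b) ** 2)


# Illustrative, well-conditioned, non-symmetric matrix and right-hand side.
A = jnp.array([[4.0, 1.0, 0.0],
               [2.0, 5.0, 1.0],
               [0.0, 1.0, 3.0]])
b = jnp.array([1.0, 2.0, 3.0])

x = solve_custom(A, b)
grads_custom = jax.grad(loss_custom, argnums=(0, 1))(A, b)
grads_reference = jax.grad(loss_reference, argnums=(0, 1))(A, b)

print("invariant holds:", jnp.allclose(A @ x, b, atol=1e-5))
for g_c, g_r in zip(grads_custom, grads_reference):
    print("gradients match:", jnp.allclose(g_c, g_r, rtol=1e-4, atol=1e-4))
```

Materializing the matrix defeats the purpose of a matrix-free solve for large problems; it is used here only to keep the comparison with jnp.linalg.solve easy to read.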
- Parameters:
matvec (Callable) – linear function to invert. Must be differentiable.
b (Any) – constant right-hand side of the equation. May be any nested structure of arrays.
solve (Callable[[Callable, Any], Any]) – higher-level function that computes the solution to the linear equation, i.e., matvec(solve(matvec, x)) == x for all x of the same form as b. This function need not be differentiable.
transpose_solve (Callable[[Callable, Any], Any] | None) – higher-level function for solving the transposed linear equation, i.e., vecmat(transpose_solve(vecmat, x)) == x, where vecmat is the transpose of the linear map matvec (computed automatically with autodiff). Required for reverse-mode automatic differentiation, unless symmetric=True, in which case solve provides the default value.
symmetric – bool indicating whether it is safe to assume that the linear map corresponds to a symmetric matrix, i.e., matvec == vecmat.
has_aux – bool indicating whether the solve and transpose_solve functions return auxiliary data, such as solver diagnostics, as a second output, i.e., as a pair (solution, aux).
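The symmetric and has_aux options can be sketched with an iterative solver. The cg_solve function below is a hypothetical, minimal conjugate gradient implementation (not jax.scipy.sparse.linalg.cg) built on lax.while_loop; it assumes a symmetric positive-definite matrix and returns a pair (x, residual_norm) so that has_aux=True applies. With symmetric=True and no transpose_solve, solve is reused for the transposed system. The matrix values and all names other than the JAX APIs are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cg_solve(f, rhs, tol=1e-5, maxiter=100):
    # Hypothetical minimal conjugate gradient solver for f(x) == rhs,
    # assuming f is symmetric positive definite. Returns (x, residual_norm)
    # so the residual can be passed out through has_aux=True.
    x0 = jnp.zeros_like(rhs)
    r0 = rhs - f(x0)
    rs0 = jnp.vdot(r0, r0)
    threshold = tol * jnp.linalg.norm(rhs)  # relative stopping criterion

    def cond(state):
        _, _, _, rs, k = state
        return (jnp.sqrt(rs) > threshold) & (k < maxiter)

    def body(state):
        x, r, p, rs, k = state
        Ap = f(p)
        alpha = rs / jnp.vdot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = jnp.vdot(r, r)
        p = r + (rs_new / rs) * p
        return x, r, p, rs_new, k + 1

    init = (x0, r0, r0, rs0, jnp.int32(0))
    x, _, _, rs, _ = lax.while_loop(cond, body, init)
    return x, jnp.sqrt(rs)


def solve_cg(A, b):
    matvec = lambda x: A @ x
    # symmetric=True: transpose_solve defaults to solve (cg_solve).
    # has_aux=True: cg_solve returns (x, residual_norm).
    return lax.custom_linear_solve(matvec, b, solve=cg_solve,
                                   symmetric=True, has_aux=True)


def loss_cg(A, b):
    x, _ = solve_cg(A, b)  # the aux output is ignored by the loss
    return jnp.sum(x ** 2)


def loss_reference(A, b):
    return jnp.sum(jnp.linalg.solve(A, b) ** 2)


# Illustrative symmetric positive-definite matrix and right-hand side.
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])

x, residual = solve_cg(A, b)
grads_cg = jax.grad(loss_cg, argnums=(0, 1))(A, b)
grads_reference = jax.grad(loss_reference, argnums=(0, 1))(A, b)

# Differentiating cg_solve directly fails: lax.while_loop does not
# support reverse-mode differentiation.
try:
    jax.grad(lambda v: jnp.sum(cg_solve(lambda y: A @ y, v)[0] ** 2))(b)
    direct_grad_failed = False
except ValueError:
    direct_grad_failed = True

print("residual norm:", residual)
print("direct grad through while_loop failed:", direct_grad_failed)
for g_c, g_r in zip(grads_cg, grads_reference):
    print("gradients match:", jnp.allclose(g_c, g_r, rtol=1e-3, atol=1e-3))
```

The contrast is the point of the sketch: jax.grad cannot reverse-differentiate the while_loop inside cg_solve on its own, while the same solver wrapped in custom_linear_solve is never differentiated and gets its gradients from the implicit rule instead.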
- Returns:
- Result of solve(matvec, b), with gradients defined assuming that the solution x satisfies the linear equation matvec(x) == b.
- Result of