jax.lax.custom_root
- jax.lax.custom_root(f, initial_guess, solve, tangent_solve, has_aux=False)
Differentiably solve for the roots of a function.
This is a low-level routine, mostly intended for internal use in JAX. Gradients of custom_root() are defined with respect to closed-over variables from the provided function f via the implicit function theorem: https://en.wikipedia.org/wiki/Implicit_function_theorem
- Parameters:
f (Callable) – function for which to find a root. Should accept a single argument and return a tree of arrays with the same structure as its input.
initial_guess (Any) – initial guess for a zero of f.
solve (Callable[[Callable, Any], Any]) – function to solve for the roots of f. Should take two positional arguments, f and initial_guess, and return a solution with the same structure as initial_guess such that f(solution) == 0. In other words, the following is assumed to be true (but not checked):
    solution = solve(f, initial_guess)
    error = f(solution)
    assert all(error == 0)
tangent_solve (Callable[[Callable, Any], Any]) – function to solve the tangent system. Should take two positional arguments, a linear function g (the function f linearized at its root) and a tree of array(s) y with the same structure as initial_guess, and return a solution x such that g(x) = y:
- For scalar y, use lambda g, y: y / g(1.0).
- For vector y, you could use a linear solve with the Jacobian, if the dimensionality of y is not too large: lambda g, y: np.linalg.solve(jacobian(g)(y), y).
has_aux – bool indicating whether the solve function returns auxiliary data, such as solver diagnostics, as a second return value.
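The two tangent_solve suggestions above can be tried on their own, outside custom_root. The following is a minimal sketch: the helper names scalar_tangent_solve and dense_tangent_solve, the matrix A, and the test functions g_scalar and g_vector are assumptions made for illustration, and jnp.linalg.solve stands in for the np.linalg.solve mentioned above.

```python
import jax
import jax.numpy as jnp


def scalar_tangent_solve(g, y):
    # g is linear in a scalar, so g(x) == g(1.0) * x and x = y / g(1.0).
    return y / g(1.0)


def dense_tangent_solve(g, y):
    # Materialize the matrix of the linear map g via its Jacobian,
    # then solve the dense linear system. Only sensible for small y.
    return jnp.linalg.solve(jax.jacobian(g)(y), y)


# Hypothetical linear maps, used only to exercise the helpers.
A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
g_scalar = lambda x: 3.0 * x
g_vector = lambda x: A @ x

x_scalar = scalar_tangent_solve(g_scalar, 6.0)

y_vec = jnp.array([3.0, 5.0])
x_vector = dense_tangent_solve(g_vector, y_vec)
```

In both cases the returned x should satisfy g(x) = y up to floating-point error. The dense variant builds the full Jacobian, so a matrix-free iterative solver may be a better fit when y is large.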
- Returns:
The result of calling solve(f, initial_guess) with gradients defined via implicit differentiation assuming f(solve(f, initial_guess)) == 0.
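As a minimal sketch of how the arguments fit together, the following computes a square root by finding the root of f(x) = x**2 - a, then differentiates the result with respect to the closed-over value a. The helper names newton_solve, scalar_tangent_solve, and sqrt_via_root, along with the fixed count of 20 Newton iterations, are assumptions for illustration and not part of the JAX API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_solve(f, x0):
    # Hypothetical solver: a fixed number of Newton steps on a scalar function.
    def body(_, x):
        return x - f(x) / jax.grad(f)(x)
    return lax.fori_loop(0, 20, body, x0)


def scalar_tangent_solve(g, y):
    # g is the linearization of f at the root; for scalars, x = y / g(1.0).
    return y / g(1.0)


def sqrt_via_root(a):
    a = jnp.asarray(a, dtype=jnp.float32)
    # f closes over a, so gradients of the root flow back to a.
    f = lambda x: x ** 2 - a
    initial_guess = jnp.ones_like(a)
    return lax.custom_root(f, initial_guess, newton_solve, scalar_tangent_solve)


value = sqrt_via_root(4.0)               # should be close to 2.0
gradient = jax.grad(sqrt_via_root)(4.0)  # should be close to 1 / (2 * sqrt(4.0))
```

The gradient here does not come from differentiating through the Newton iterations. It is obtained by implicit differentiation at the returned solution, which is why only tangent_solve needs to handle the linearized system.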