jax.scipy.linalg.solve_triangular
- jax.scipy.linalg.solve_triangular(a, b, trans=0, lower=False, unit_diagonal=False, overwrite_b=False, debug=None, check_finite=True)
Solve a triangular linear system of equations.
JAX implementation of scipy.linalg.solve_triangular().

This solves a (batched) linear system of equations a @ x = b for x, given a triangular matrix a and a vector or matrix b.

- Parameters:
a (ArrayLike) – array of shape (..., N, N). Only part of the array will be accessed, depending on the lower and unit_diagonal arguments.
b (ArrayLike) – array of shape (..., N) or (..., N, M).
lower (bool) – If True, only use the lower triangle of the input; if False (default), only use the upper triangle.
unit_diagonal (bool) – If True, ignore the diagonal elements of a and assume they are 1 (default: False).
trans (int | str) – specifies which system to solve (default: 0). Options are:
  - 0 or 'N': solve \(Ax=b\)
  - 1 or 'T': solve \(A^Tx=b\)
  - 2 or 'C': solve \(A^Hx=b\)
overwrite_b (bool) – unused by JAX
debug (Any | None) – unused by JAX
check_finite (bool) – unused by JAX
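The effect of lower and unit_diagonal on which entries of a are read can be illustrated with a small lower-triangular system. This is a minimal sketch relying on the documented behavior that only one triangle is accessed; the matrix values (including the placeholder 99. entries standing in for unused data) and the names a_low and a_unit are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Lower-triangular system; the strict upper triangle holds placeholder
# values (99.) that lower=True should never read.
a_low = jnp.array([[2., 99., 99.],
                   [1., 4., 99.],
                   [3., 2., 5.]])
b = jnp.array([2., 9., 20.])

# Solve using only the lower triangle of a_low.
x_lower = solve_triangular(a_low, b, lower=True)
# Reference: the same system with the upper triangle explicitly zeroed.
x_ref = solve_triangular(jnp.tril(a_low), b, lower=True)
# x_lower ≈ [1., 2., 2.6], identical to x_ref

# unit_diagonal=True also ignores the stored diagonal (2, 4, 5) and
# treats every diagonal entry as 1.
x_unit = solve_triangular(a_low, b, lower=True, unit_diagonal=True)
# The matrix that is effectively used: strict lower triangle plus identity.
a_unit = jnp.tril(a_low, k=-1) + jnp.eye(3)
# x_unit ≈ [2., 7., 0.], and a_unit @ x_unit ≈ b
```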
- Returns:
An array of the same shape as b containing the solution to the linear system.
- Return type:
Array
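The batching and shape rules above can be checked with randomly generated inputs. This is a sketch: the seed, the batch size of 4, and M = 2 right-hand sides are arbitrary choices, and the diagonal is fixed at 2 only so that every system in the batch is nonsingular.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))

# Batch of 4 upper-triangular 3x3 matrices, shape (4, 3, 3): random strict
# upper triangle plus a diagonal of 2s so each system is nonsingular.
a = jnp.triu(jax.random.normal(key_a, (4, 3, 3)), k=1) + 2.0 * jnp.eye(3)

# Matrix right-hand side of shape (..., N, M) with M = 2.
b = jax.random.normal(key_b, (4, 3, 2))
x = solve_triangular(a, b)
print(x.shape)  # (4, 3, 2), the same shape as b

# Vector right-hand side of shape (..., N): one vector per batch element.
b_vec = b[..., 0]
x_vec = solve_triangular(a, b_vec)
print(x_vec.shape)  # (4, 3)

# Each batch element satisfies a[i] @ x[i] ≈ b[i].
print(jnp.allclose(a @ x, b, atol=1e-4))  # True
```

Note that a @ x_vec would not broadcast as a batched matrix-vector product, so a residual check for the vector case needs something like jnp.einsum("bij,bj->bi", a, x_vec).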
See also
jax.scipy.linalg.solve(): Solve a general linear system.

Examples
A simple 3x3 triangular linear system:
>>> A = jnp.array([[1., 2., 3.],
...                [0., 3., 2.],
...                [0., 0., 5.]])
>>> b = jnp.array([10., 8., 5.])
>>> x = jax.scipy.linalg.solve_triangular(A, b)
>>> x
Array([3., 2., 1.], dtype=float32)
Confirming that the result solves the system:
>>> jnp.allclose(A @ x, b)
Array(True, dtype=bool)
Computing the transposed problem:
>>> x = jax.scipy.linalg.solve_triangular(A, b, trans='T')
>>> x
Array([10. , -4. , -3.4], dtype=float32)
Confirming that the result solves the system:
>>> jnp.allclose(A.T @ x, b)
Array(True, dtype=bool)
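For complex input, trans='T' and trans='C' solve different systems, since only the latter conjugates a. The sketch below uses an arbitrary 2x2 upper-triangular complex matrix and cross-checks both results against the general solver jax.scipy.linalg.solve() applied to explicitly formed transposes.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve, solve_triangular

# Upper-triangular complex matrix (values chosen for illustration).
A = jnp.array([[1 + 1j, 2 + 0j],
               [0 + 0j, 3 - 1j]], dtype=jnp.complex64)
b = jnp.array([1 + 0j, 2 + 2j], dtype=jnp.complex64)

# trans='C' (equivalently 2) solves the conjugate-transposed system A^H x = b.
x_c = solve_triangular(A, b, trans='C')
# trans='T' (equivalently 1) uses the plain transpose: A^T x = b.
x_t = solve_triangular(A, b, trans='T')

# Cross-check both against the general solver on explicitly formed matrices.
print(jnp.allclose(x_c, solve(A.conj().T, b), atol=1e-5))  # True
print(jnp.allclose(x_t, solve(A.T, b), atol=1e-5))         # True
# The two solutions differ for complex input:
# x_c ≈ [0.5+0.5j, 0.4+0.2j], x_t ≈ [0.5-0.5j, 0+1j]
```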