jax.scipy.linalg.det
- jax.scipy.linalg.det(a, overwrite_a=False, check_finite=True)
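Compute the determinant of a matrix.
JAX implementation of scipy.linalg.det().
- Parameters:
  - a (ArrayLike): input array, of shape (..., N, N)
  - overwrite_a (bool): unused by JAX
  - check_finite (bool): unused by JAX
- Return type: Array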
- Returns: determinant of shape a.shape[:-2]
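The output-shape rule can be checked on a stacked input: every axis except the trailing two (matrix) axes is treated as a batch axis and kept in the result. A minimal sketch, assuming random 3 x 3 matrices with an arbitrarily chosen batch shape of (2, 4):

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg

# A batch of 2 x 4 = 8 square matrices, each of size 3 x 3.
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2, 4, 3, 3))

d = jax.scipy.linalg.det(a)

# The two trailing (matrix) axes are reduced away; the batch axes remain.
print(a.shape[:-2], d.shape)  # (2, 4) (2, 4)
```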
See also
jax.numpy.linalg.det(): NumPy-style determinant API
Examples
Determinant of a small 2D array:
>>> import jax.numpy as jnp
>>> import jax.scipy.linalg
>>> x = jnp.array([[1., 2.],
...                [3., 4.]])
>>> jax.scipy.linalg.det(x)
Array(-2., dtype=float32)
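For a 2 x 2 matrix the determinant reduces to the familiar closed form, which reproduces the value in the example above:

```latex
\det\begin{pmatrix} a & b \\ c & d \end{pmatrix} = ad - bc,
\qquad
\det\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix} = 1 \cdot 4 - 2 \cdot 3 = -2.
```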
Batch-wise determinant of multiple 2D arrays:
>>> x = jnp.array([[[1., 2.],
...                 [3., 4.]],
...                [[8., 5.],
...                 [7., 9.]]])
>>> jax.scipy.linalg.det(x)
Array([-2., 37.], dtype=float32)
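As a sanity check, this function, the NumPy-style jax.numpy.linalg.det(), and the reference scipy.linalg.det() should agree up to floating-point rounding. A minimal sketch, assuming SciPy is installed; the reference values are computed one matrix at a time so the check does not depend on whether the installed SciPy version accepts stacked input. Because overwrite_a and check_finite are accepted only for SciPy compatibility, changing them is expected to leave the result unchanged.

```python
import numpy as np
import scipy.linalg
import jax.numpy as jnp
import jax.scipy.linalg

# Same batch of two 2 x 2 matrices as in the example above.
x = jnp.array([[[1., 2.],
                [3., 4.]],
               [[8., 5.],
                [7., 9.]]])

d_jax_scipy = jax.scipy.linalg.det(x)
d_jax_numpy = jnp.linalg.det(x)

# Reference: SciPy applied to each matrix separately, in float64.
d_ref = np.array([scipy.linalg.det(np.asarray(m)) for m in x])

# The SciPy-compatibility flags are passed but should have no effect.
d_flags = jax.scipy.linalg.det(x, overwrite_a=True, check_finite=False)

print(np.allclose(d_jax_scipy, d_jax_numpy))  # True
print(np.allclose(d_jax_scipy, d_ref))        # True
print(np.allclose(d_jax_scipy, d_flags))      # True
```

Comparing with np.allclose rather than exact equality allows for the small rounding differences expected between float32 and float64 LU factorizations.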