jax.numpy.linalg.det
jax.numpy.linalg.det(a)
Compute the determinant of an array.
JAX implementation of numpy.linalg.det().
Parameters:
a (ArrayLike) – array of shape (..., M, M) for which to compute the determinant.
Returns:
An array of determinants of shape a.shape[:-2].
Return type:
Array
See also
jax.scipy.linalg.det(): SciPy-style API for determinant.
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.det(a)
Array(-2., dtype=float32)
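The leading dimensions of a are treated as batch dimensions, so a stack of matrices yields one determinant per matrix. The following is a minimal sketch of that behavior; the names batch and dets are illustrative and not part of the API.

```python
import jax.numpy as jnp

# A stack of three 2x2 matrices, so batch.shape == (3, 2, 2).
batch = jnp.array([
    [[1.0, 2.0], [3.0, 4.0]],   # determinant is about -2
    [[2.0, 0.0], [0.0, 2.0]],   # determinant is about 4
    [[1.0, 1.0], [1.0, 1.0]],   # singular, determinant is about 0
])

# det operates on the last two axes and keeps the leading ones.
dets = jnp.linalg.det(batch)

# The result has shape batch.shape[:-2].
print(dets.shape)  # (3,)
```

Because the computation is done in floating point, results may differ slightly from the exact values, so comparisons are better made with a tolerance (for example jnp.allclose) than with exact equality.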