jax.numpy.linalg.det

jax.numpy.linalg.det(a)

Compute the determinant of a square matrix or a batch of square matrices.

JAX implementation of numpy.linalg.det().

Parameters:

a (ArrayLike) – array of shape (..., M, M) for which to compute the determinant.

Returns:

An array of determinants of shape a.shape[:-2].

Return type:

Array

See also

jax.scipy.linalg.det(): SciPy-style API for the determinant.

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> jnp.linalg.det(a)
Array(-2., dtype=float32)
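
The shape rule stated under Returns (a.shape[:-2]) can be illustrated with a batched input. The following is a minimal sketch, not part of the original documentation; the names batch and dets are chosen here for illustration only.

```python
import jax.numpy as jnp

# A batch of two 2x2 matrices, so the input has shape (2, 2, 2).
batch = jnp.array([[[1.0, 2.0],
                    [3.0, 4.0]],
                   [[2.0, 0.0],
                    [0.0, 3.0]]])

# The determinant is taken over the last two axes, one value per matrix.
dets = jnp.linalg.det(batch)

# The leading (batch) dimensions are preserved: batch.shape[:-2] == (2,)
print(dets.shape)  # (2,)
```

The two entries should be close to -2 and 6, up to float32 rounding, so comparisons are best done with a tolerance (for example jnp.allclose) rather than exact equality.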