jax.numpy.linalg.tensorinv
- jax.numpy.linalg.tensorinv(a, ind=2)
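Compute the tensor inverse of an array.
JAX implementation of numpy.linalg.tensorinv().
This computes the inverse of the tensordot() operation with the same ind value: the result ainv satisfies tensordot(ainv, a, axes=ind) == identity, where the identity has shape (*a.shape[ind:], *a.shape[ind:]).
- Parameters: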
a (ArrayLike) – array to be inverted. Must have prod(a.shape[:ind]) == prod(a.shape[ind:]).
ind (int) – positive integer specifying the number of indices in the tensor product.
- Returns:
array of shape (*a.shape[ind:], *a.shape[:ind]) containing the tensor inverse of a.
- Return type:
Array
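The shape contract in the Parameters and Returns fields can be illustrated with a minimal sketch that flattens a into a square matrix, inverts it with jnp.linalg.inv, and unflattens the result. The helper name tensorinv_via_reshape is hypothetical, and this is one assumed way to meet the documented contract, not a copy of JAX's source. Because a matrix inverse is unique, the final comparison against jnp.linalg.tensorinv should agree up to floating-point error for any correct implementation.

```python
import math

import jax
import jax.numpy as jnp


def tensorinv_via_reshape(a, ind=2):
    """Hypothetical helper: tensor inverse via an ordinary matrix inverse.

    Illustrative sketch of the documented shape contract; not JAX's source.
    """
    a = jnp.asarray(a)
    if ind <= 0:
        raise ValueError(f"ind must be a positive integer, got {ind}")
    # The first `ind` axes become matrix rows, the remaining axes become columns.
    row_shape, col_shape = a.shape[:ind], a.shape[ind:]
    n_rows, n_cols = math.prod(row_shape), math.prod(col_shape)
    if n_rows != n_cols:
        raise ValueError(
            f"prod(a.shape[:ind])={n_rows} != prod(a.shape[ind:])={n_cols}"
        )
    # Invert the square (n_rows, n_cols) matrix ...
    inv_matrix = jnp.linalg.inv(a.reshape((n_rows, n_cols)))
    # ... and unflatten it to (*a.shape[ind:], *a.shape[:ind]).
    return inv_matrix.reshape(col_shape + row_shape)


key = jax.random.key(0)
# Identity plus a small perturbation keeps the flattened 6x6 matrix well conditioned.
a = jnp.eye(6).reshape((2, 3, 6)) + 0.1 * jax.random.normal(key, (2, 3, 6))
ainv = tensorinv_via_reshape(a, ind=2)
print(ainv.shape)  # (6, 2, 3)
# Compare with the library function, and check that ainv inverts tensordot with axes=ind.
print(jnp.allclose(ainv, jnp.linalg.tensorinv(a, ind=2), atol=1e-4))
print(jnp.allclose(jnp.tensordot(ainv, a, axes=2), jnp.eye(6), atol=1e-4))
```

The product check mirrors the requirement on a: the flattened matrix must be square for an ordinary inverse to exist.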
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.key(1337)
>>> x = jax.random.normal(key, shape=(2, 2, 4))
>>> xinv = jnp.linalg.tensorinv(x, 2)
>>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2)
>>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4)
Array(True, dtype=bool)
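Because the result inverts tensordot(), it can be used to solve a tensor equation tensordot(a, x, axes=1) == b for an unknown x of shape a.shape[2:]. The sketch below assumes a well-conditioned a (identity plus a small random perturbation, chosen only for illustration) and verifies the solution through its residual.

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.key(42))
# Well-conditioned coefficient tensor: identity plus a small perturbation.
a = jnp.eye(6).reshape((2, 3, 6)) + 0.1 * jax.random.normal(k1, (2, 3, 6))
b = jax.random.normal(k2, (2, 3))  # right-hand side, shape a.shape[:2]

# The unknown x has shape a.shape[2:] == (6,) and must satisfy
# jnp.tensordot(a, x, axes=1) == b.
ainv = jnp.linalg.tensorinv(a, ind=2)  # shape (6, 2, 3)
x = jnp.tensordot(ainv, b, axes=2)     # contract ainv's trailing (2, 3) axes with b
residual = jnp.tensordot(a, x, axes=1) - b
print(x.shape)  # (6,)
print(jnp.max(jnp.abs(residual)))
```

For a single right-hand side, solving the flattened linear system directly (for example with jnp.linalg.solve) is generally cheaper and numerically more stable than forming the full inverse; an explicit tensor inverse mainly pays off when it is reused for many values of b.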