jax.numpy.inner

jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)

Compute the inner product of two arrays.

JAX implementation of numpy.inner().

Unlike jax.numpy.matmul() or jax.numpy.dot(), this always performs a contraction along the last dimension of each input.
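To make the contrast with jax.numpy.dot() concrete, here is a minimal sketch on two 2D inputs (the shapes and values are arbitrary illustrations). jnp.dot contracts the last axis of its first argument with the second-to-last axis of its second, so it only matches jnp.inner once b is transposed. The einsum line spells out the same last-axis-with-last-axis contraction.

```python
import jax.numpy as jnp

# Two 2D inputs whose last axes have the same size (3).
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(4, 3)

# inner contracts the last axis of a with the last axis of b.
via_inner = jnp.inner(a, b)

# dot contracts the last axis of a with the second-to-last axis of b,
# so b must be transposed to express the same contraction.
# (jnp.dot(a, b) itself would fail here: 3 != 4.)
via_dot = jnp.dot(a, b.T)

# The same contraction written out index by index.
via_einsum = jnp.einsum("ik,jk->ij", a, b)

print(via_inner.shape)  # (2, 4)
```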

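Parameters:

a (ArrayLike): array of shape (..., N).
b (ArrayLike): array of shape (..., N), with the same size N along its last axis as a.
precision (PrecisionLike): either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of a and b.
preferred_element_type (DTypeLike | None): either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
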
Returns:

array of shape (*a.shape[:-1], *b.shape[:-1]) containing the inner product of every vector along the last axis of a with every vector along the last axis of b.

Return type:

Array
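The keyword-only precision and preferred_element_type arguments in the signature control how the contraction is computed. A minimal sketch of both follows; the int8 inputs and their values are illustrative assumptions, and the practical effect of precision depends on the backend (it matters mostly on GPU/TPU).

```python
import jax
import jax.numpy as jnp

x = jnp.array([100, 100, 100], dtype=jnp.int8)
y = jnp.array([2, 2, 2], dtype=jnp.int8)

# preferred_element_type requests the accumulation/result dtype.
# The true result (600) does not fit in int8 (max 127), so ask for int32.
wide = jnp.inner(x, y, preferred_element_type=jnp.int32)
print(wide.dtype, wide)  # int32 600

# precision selects the accuracy/speed trade-off of the contraction;
# HIGHEST requests the most accurate algorithm the backend offers.
p = jnp.inner(jnp.ones(3), jnp.ones(3), precision=jax.lax.Precision.HIGHEST)
```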

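See also

jax.numpy.vecdot(): conjugate vector multiplication along a specified axis.
jax.numpy.tensordot(): general tensor contraction.
jax.numpy.matmul(): general batched matrix and vector multiplication.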

Examples

For 1D inputs, this computes the standard (non-conjugate) vector inner product, sum(a * b); complex entries of a are not conjugated:

>>> import jax.numpy as jnp
>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)
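To show what "non-conjugate" means in practice, this minimal sketch contrasts jnp.inner with jnp.vdot, which conjugates its first argument. For a purely imaginary vector, the plain product of the vector with itself is negative, while the conjugated (Hermitian) product gives the squared norm.

```python
import jax.numpy as jnp

z = jnp.array([1j, 3j, 4j])

# No conjugation: (1j)**2 + (3j)**2 + (4j)**2 = -1 - 9 - 16 = -26.
plain = jnp.inner(z, z)

# vdot conjugates its first argument: |1j|**2 + |3j|**2 + |4j|**2 = 26.
hermitian = jnp.vdot(z, z)
```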

For multi-dimensional inputs, the leading (batch) dimensions of a and b are stacked in the output rather than broadcast against each other:

>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)
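The stacking rule extends to inputs with more dimensions, and it differs from broadcasting even when the leading shapes match. A minimal sketch (shapes chosen only for illustration): a 3D a and a 2D b give an output of shape (*a.shape[:-1], *b.shape[:-1]). For two (2, 4) inputs, jnp.inner pairs every row with every row, whereas a broadcast multiply-and-sum pairs row i only with row i and so recovers just the diagonal.

```python
import jax.numpy as jnp

# Leading dims (2, 3) for a and (5,) for b; both end in an axis of size 4.
a = jnp.arange(24.0).reshape(2, 3, 4)
b = jnp.arange(20.0).reshape(5, 4)

out = jnp.inner(a, b)
# a's leading dims followed by b's leading dims.
print(out.shape)  # (2, 3, 5)

# Each entry is the 1D inner product of one vector from a and one from b.
assert jnp.allclose(out[1, 2, 3], jnp.dot(a[1, 2], b[3]))

# Same leading shape on both sides: stacking vs. broadcasting.
c = jnp.arange(8.0).reshape(2, 4)
d = jnp.arange(8.0, 16.0).reshape(2, 4)
stacked = jnp.inner(c, d)            # every row of c with every row of d: (2, 2)
broadcast = jnp.sum(c * d, axis=-1)  # row i of c with row i of d only: (2,)

# The broadcast result is the diagonal of the stacked one.
assert jnp.allclose(jnp.diagonal(stacked), broadcast)
```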