jax.numpy.vdot

jax.numpy.vdot(a, b, *, precision=None, preferred_element_type=None)

Perform a conjugate multiplication of two 1D vectors.

JAX implementation of numpy.vdot().
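As a rough mental model, jnp.vdot behaves like this: flatten both inputs, conjugate the first, and sum the elementwise products. The sketch below is not JAX's actual source. The helper name `vdot_reference` is hypothetical and exists only for comparison with `jnp.vdot` on 2D complex inputs.

```python
import jax.numpy as jnp

def vdot_reference(a, b):
    """Hypothetical reference version of jnp.vdot, for illustration only."""
    a = jnp.ravel(jnp.asarray(a))  # flatten the first input to 1D
    b = jnp.ravel(jnp.asarray(b))  # flatten the second input to 1D
    # Conjugate the first input (a no-op for real dtypes), multiply, and sum.
    return jnp.sum(jnp.conj(a) * b)

a = jnp.array([[1 + 2j, 3 - 1j], [0 + 1j, 2 + 0j]])
b = jnp.array([[2 + 0j, 1 + 1j], [1 - 1j, 4 + 0j]])

print(jnp.vdot(a, b))         # built-in version, flattens the 2x2 inputs
print(vdot_reference(a, b))   # hypothetical reference version
```

Both calls should agree. Note that the 2D inputs are treated as length-4 vectors rather than being matrix-multiplied.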

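Parameters:

a – first input array. If not 1D, it will be flattened.
b – second input array. If not 1D, it will be flattened. Must have a.size == b.size.
precision – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of a and b.
preferred_element_type – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.

Returns: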

Scalar array (shape ()) containing the conjugate vector product of the inputs.

Return type:

Array
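The result is always a 0-d array, and its dtype can be steered with the keyword-only arguments from the signature. The sketch below assumes small int8 inputs whose product fits in int8, so the default and widened results hold the same value; only the dtype differs. jax.lax.Precision.HIGHEST is used as one possible precision setting.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3], dtype=jnp.int8)
y = jnp.array([4, 5, 6], dtype=jnp.int8)

# Default: the result dtype follows the (promoted) input dtypes.
r_default = jnp.vdot(x, y)

# Ask for accumulation in, and a result of, a wider integer dtype.
r_wide = jnp.vdot(x, y, preferred_element_type=jnp.int32)

# Request the highest available precision for floating-point inputs.
f = jnp.linspace(0.0, 1.0, 5)
r_precise = jnp.vdot(f, f, precision=jax.lax.Precision.HIGHEST)

print(r_default.shape, r_default.dtype)  # () int8
print(r_wide.shape, r_wide.dtype)        # () int32
print(r_precise)
```

Widening with preferred_element_type is mainly useful when low-precision integer or float inputs could overflow or lose accuracy during accumulation.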

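See also

jax.numpy.vecdot(): batched vector product.
jax.numpy.matmul(): general matrix multiplication.
jax.lax.dot_general(): general N-dimensional batched dot product.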

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([1j, 2j, 3j])
>>> y = jnp.array([1., 2., 3.])
>>> jnp.vdot(x, y)
Array(0.-14.j, dtype=complex64)

Note the difference from jax.numpy.dot(), which does not conjugate the first input when it is complex:

>>> jnp.dot(x, y)
Array(0.+14.j, dtype=complex64)
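The two results above are related in a simple way: vdot(x, y) equals dot(conj(x), y), and swapping the arguments of vdot conjugates the result. The short check below reuses the same x and y as the examples above.

```python
import jax.numpy as jnp

x = jnp.array([1j, 2j, 3j])
y = jnp.array([1., 2., 3.])

# vdot conjugates its first argument, so it matches dot applied to conj(x).
v = jnp.vdot(x, y)
d = jnp.dot(jnp.conj(x), y)

# Swapping the arguments conjugates the result: vdot(y, x) == conj(vdot(x, y)).
v_swapped = jnp.vdot(y, x)

print(v, d, v_swapped)
```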