jax.numpy.vdot
- jax.numpy.vdot(a, b, *, precision=None, preferred_element_type=None)
Perform a conjugate multiplication of two 1D vectors.
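Written out for inputs of length n, the conjugate product is the sum below. The overline denotes complex conjugation, which is applied to the first argument only and is a no-op for real inputs. This follows the standard numpy.vdot definition; the symbols a_i, b_i and n are notation introduced here for illustration.

```latex
\operatorname{vdot}(a, b) = \sum_{i=0}^{n-1} \overline{a_i}\, b_i,
\qquad n = \texttt{a.size} = \texttt{b.size}
```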
JAX implementation of numpy.vdot().
- Parameters:
a (Array | ndarray | bool | number | bool | int | float | complex) – first input array; if it is not 1D, it will be flattened.
b (Array | ndarray | bool | number | bool | int | float | complex) – second input array; if it is not 1D, it will be flattened. Must have a.size == b.size.
precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of a and b.
preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating that results should be accumulated in and returned with that datatype.
- Returns:
Scalar array (shape ()) containing the conjugate vector product of the inputs.
- Return type:
Array
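The sketch below illustrates two of the parameter descriptions: multi-dimensional inputs are flattened before the conjugate product is taken, and preferred_element_type changes the accumulation and result dtype. The arrays a, b, h and g are arbitrary example values chosen for illustration, and the float16/float32 pairing is just one plausible use of preferred_element_type, not a recommendation.

```python
import jax.numpy as jnp

# Two 2x2 complex arrays (a.size == b.size == 4). vdot flattens both
# to length-4 vectors before taking the conjugate product.
a = jnp.array([[1 + 1j, 2 - 1j],
               [0 + 2j, 3 + 0j]], dtype=jnp.complex64)
b = jnp.array([[2 + 0j, 1 + 1j],
               [1 - 1j, 1 + 0j]], dtype=jnp.complex64)

result = jnp.vdot(a, b)

# The same quantity written out: conjugate the flattened first input,
# multiply elementwise by the flattened second input, and sum.
manual = jnp.sum(jnp.conj(a.ravel()) * b.ravel())

print(result.shape)  # ()

# preferred_element_type: float16 inputs normally give a float16 result;
# requesting float32 makes vdot accumulate in and return float32.
h = jnp.arange(4, dtype=jnp.float16)
g = jnp.ones(4, dtype=jnp.float16)
default_out = jnp.vdot(h, g)
wide_out = jnp.vdot(h, g, preferred_element_type=jnp.float32)

print(default_out.dtype, wide_out.dtype)  # float16 float32
```

Because the flattened form is just a sum of elementwise products, the inputs only need matching sizes, not matching shapes.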
See also
- jax.numpy.vecdot(): batched vector product.
- jax.numpy.matmul(): general matrix multiplication.
- jax.lax.dot_general(): general N-dimensional batched dot product.
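To make the contrast with jax.numpy.vecdot() concrete, here is a small sketch on a stack of vectors. The arrays x and y are illustrative values, and real-valued inputs are used deliberately so that the comparison does not depend on how each function treats complex values. vecdot keeps the batch axis, while vdot flattens everything into a single scalar.

```python
import jax.numpy as jnp

# Three real vectors of length 2, stacked along the leading (batch) axis.
x = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
y = jnp.array([[1., 0.], [0., 1.], [1., 1.]])

# vecdot contracts only the last axis and keeps the batch axis: shape (3,).
per_row = jnp.vecdot(x, y)

# vdot flattens both inputs first, so it contracts everything: shape ().
total = jnp.vdot(x, y)

print(per_row.shape, total.shape)  # (3,) ()
```

For these inputs the vdot result equals the sum of the per-row vecdot results, since flattening simply concatenates the rows.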
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1j, 2j, 3j])
>>> y = jnp.array([1., 2., 3.])
>>> jnp.vdot(x, y)
Array(0.-14.j, dtype=complex64)
Note the difference between this and dot(), which does not conjugate the first input when it is complex:
>>> jnp.dot(x, y)
Array(0.+14.j, dtype=complex64)
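Because only the first argument is conjugated, argument order matters for complex inputs. The following sketch, reusing the example vectors, checks two consequences of the definition: swapping the arguments conjugates the result, and conjugating the first input by hand makes dot() agree with vdot(). The variable names forward, swapped and via_dot are illustrative only.

```python
import jax.numpy as jnp

x = jnp.array([1j, 2j, 3j])
y = jnp.array([1., 2., 3.])

# vdot conjugates whichever argument comes first.
forward = jnp.vdot(x, y)  # conj(x) . y
swapped = jnp.vdot(y, x)  # conj(y) . x, and conj(y) == y since y is real

# Conjugating the first input explicitly makes dot() match vdot().
via_dot = jnp.dot(jnp.conj(x), y)

print(forward, swapped, via_dot)
```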