jax.numpy.dot
- jax.numpy.dot(a, b, *, precision=None, preferred_element_type=None)
Compute the dot product of two arrays.
JAX implementation of numpy.dot(). This differs from jax.numpy.matmul() in two respects:

- if either a or b is a scalar, the result of dot is equivalent to jax.numpy.multiply(), while the result of matmul is an error.
- if a and b have more than 2 dimensions, the batch indices are stacked rather than broadcast.
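The first difference is easy to check directly. A minimal sketch, assuming only jax.numpy is available (the exact exception type raised by matmul is an implementation detail, so both TypeError and ValueError are caught):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# A scalar operand makes dot behave like element-wise multiplication:
print(jnp.dot(x, 2))  # same result as jnp.multiply(x, 2)

# matmul rejects scalar operands instead:
try:
    jnp.matmul(x, 2)
except (TypeError, ValueError) as e:
    print("matmul failed:", e)
```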
- Parameters:
  - a (Array | ndarray | bool_ | number | bool | int | float | complex) – first input array, of shape (..., N).
  - b (Array | ndarray | bool_ | number | bool | int | float | complex) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of a.
  - precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend; a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST); or a tuple of two such values indicating the precision of a and b respectively.
  - preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
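A short sketch of how the two keyword arguments are used (the float16 input dtype is chosen here purely for illustration):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 3), dtype=jnp.float16)
y = jnp.ones((3, 4), dtype=jnp.float16)

# With preferred_element_type, results are accumulated in (and returned
# as) the requested dtype rather than the promoted input dtype:
z = jnp.dot(x, y, preferred_element_type=jnp.float32)
print(z.dtype)  # float32

# precision accepts a Precision enum value (or a pair, one per operand);
# it trades speed for accuracy on some backends without changing dtype:
w = jnp.dot(x, y, precision=lax.Precision.HIGHEST)
print(w.dtype, w.shape)
```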
- Returns:
  array containing the dot product of the inputs, with batch dimensions of a and b stacked rather than broadcast.
- Return type:
  Array
See also
- jax.numpy.matmul(): broadcasted batched matmul.
- jax.lax.dot_general(): general batched matrix multiplication.
Examples
For scalar inputs, dot computes the element-wise product:

>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)
For vector or matrix inputs, dot computes the vector or matrix product:

>>> M = jnp.array([[2, 3, 4],
...                [5, 6, 7],
...                [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51,  60,  29],
       [ 96, 114,  62],
       [ 61,  78,  95]], dtype=int32)
For higher-dimensional matrix products, batch dimensions are stacked, whereas in matmul() they are broadcast. For example:

>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)