jax.numpy.dot

jax.numpy.dot(a, b, *, precision=None, preferred_element_type=None, out_sharding=None)

Compute the dot product of two arrays.

JAX implementation of numpy.dot().

This differs from jax.numpy.matmul() in two respects:

  • if either a or b is a scalar, dot is equivalent to jax.numpy.multiply(), whereas matmul raises an error.

  • if a and b have more than 2 dimensions, their batch dimensions are stacked rather than broadcast.
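The scalar case can be checked directly. The following is a minimal sketch, assuming a default JAX configuration. The text above does not name the exception type that matmul raises for a 0-d operand, so the sketch catches both TypeError and ValueError rather than relying on one of them.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Scalar operand: dot behaves like element-wise multiplication.
scalar_dot = jnp.dot(x, 2)
same_as_multiply = bool(jnp.array_equal(scalar_dot, jnp.multiply(x, 2)))

# matmul rejects a 0-d operand instead of multiplying.
# (Exception type is an assumption, hence the tuple.)
try:
    jnp.matmul(x, 2)
    matmul_raised = False
except (TypeError, ValueError):
    matmul_raised = True

print(scalar_dot, same_as_multiply, matmul_raised)
```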

Parameters:
Returns:

An array containing the dot product of the inputs. Unlike matmul(), the batch dimensions of a and b are stacked rather than broadcast. If a has shape (*a_batch, N), the output shape is (*a_batch,) when b has shape (N,), and (*a_batch, *b_batch, M) when b has shape (*b_batch, N, M).
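The shape rule for non-scalar inputs can be written out as a small function and compared against jnp.dot. In the sketch below, dot_output_shape is a hypothetical helper for illustration only and is not part of jax.numpy. It assumes both inputs have at least one dimension.

```python
import jax.numpy as jnp

def dot_output_shape(a_shape, b_shape):
    """Hypothetical helper: predict jnp.dot's output shape for non-scalar inputs.

    a_shape is (*a_batch, N); b_shape is (N,) or (*b_batch, N, M).
    """
    *a_batch, n = a_shape
    if len(b_shape) == 1:
        # 1-D b: contract the last axis of a with the only axis of b.
        if b_shape[0] != n:
            raise ValueError("contracting dimensions differ")
        return tuple(a_batch)
    *b_batch, n_b, m = b_shape
    if n_b != n:
        raise ValueError("contracting dimensions differ")
    # Batch dimensions of a and b are concatenated (stacked), not broadcast.
    return (*a_batch, *b_batch, m)

for a_shape, b_shape in [((3, 2, 4), (3, 4, 1)), ((2, 4), (4,)), ((5, 4), (2, 4, 6))]:
    predicted = dot_output_shape(a_shape, b_shape)
    actual = jnp.dot(jnp.zeros(a_shape), jnp.zeros(b_shape)).shape
    print(a_shape, b_shape, "->", predicted, actual)
```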

Return type:

Array
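The keyword arguments in the signature can change the returned Array. The sketch below is a minimal example that assumes preferred_element_type sets the accumulation and output dtype, and that precision accepts a jax.lax.Precision value, as in jax.lax.dot_general. The matrix values are arbitrary and were chosen so that every product is exactly representable in bfloat16.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.bfloat16)
b = jnp.array([[5.0, 6.0], [7.0, 8.0]], dtype=jnp.bfloat16)

# Default: the result dtype follows the (promoted) input dtype.
default_out = jnp.dot(a, b)

# Request float32 accumulation and a float32 result from bfloat16 inputs.
wide_out = jnp.dot(a, b, preferred_element_type=jnp.float32)

# precision takes a lax.Precision value; HIGHEST asks for the most accurate
# algorithm on backends that offer several.
precise_out = jnp.dot(a.astype(jnp.float32), b.astype(jnp.float32),
                      precision=jax.lax.Precision.HIGHEST)

print(default_out.dtype, wide_out.dtype, precise_out.dtype)
```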

See also

Examples

When one input is a scalar, dot computes the element-wise product:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)

For vector or matrix inputs, dot computes the vector or matrix product:

>>> M = jnp.array([[2, 3, 4],
...                [5, 6, 7],
...                [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51,  60,  29],
       [ 96, 114,  62],
       [ 61,  78,  95]], dtype=int32)
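For inputs with at most two dimensions, dot and matmul (the @ operator) agree. The two differences listed above only appear with scalars or with more than two dimensions. The sketch below reuses x and M from the examples and also shows the vector-vector and vector-matrix cases.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
M = jnp.array([[2, 3, 4],
               [5, 6, 7],
               [8, 9, 0]])

# Vector-vector: the inner product, returned as a 0-d array.
inner = jnp.dot(x, x)

# Matrix-vector and matrix-matrix: identical to the @ operator.
same_mv = bool(jnp.array_equal(jnp.dot(M, x), M @ x))
same_mm = bool(jnp.array_equal(jnp.dot(M, M), M @ M))

# Vector-matrix: x is treated as a row vector.
row = jnp.dot(x, M)

print(inner, same_mv, same_mm, row)
```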

For higher-dimensional matrix products, batch dimensions are stacked, whereas in matmul() they are broadcast. For example:

>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)
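One way to see what "stacked" means is to write both products with jnp.einsum. dot contracts the last axis of a with the second-to-last axis of b and keeps every batch index of both operands. matmul pairs batch index i of a only with batch index i of b, so its result is the diagonal of the stacked result over the two batch axes. The sketch below illustrates this on random inputs. The comparison tolerance is an assumption meant to absorb float32 summation-order differences.

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (3, 2, 4))
b = jax.random.normal(key_b, (3, 4, 1))

# Stacked: every batch index of a meets every batch index of b.
stacked = jnp.dot(a, b)                               # shape (3, 2, 3, 1)
stacked_einsum = jnp.einsum("ijk,lkm->ijlm", a, b)

# Broadcast: batch index i of a meets only batch index i of b.
broadcast = jnp.matmul(a, b)                          # shape (3, 2, 1)

# Select stacked[i, :, i, :] for each i; the advanced-index axis comes first.
idx = jnp.arange(3)
diagonal = stacked[idx, :, idx, :]                    # shape (3, 2, 1)

print(jnp.allclose(stacked, stacked_einsum, atol=1e-4),
      jnp.allclose(broadcast, diagonal, atol=1e-4))
```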