jax.numpy.matmul

jax.numpy.matmul(a, b, *, precision=None, preferred_element_type=None)
Perform a matrix multiplication.

JAX implementation of numpy.matmul().

- Parameters:
  - a (Array | ndarray | bool | number | bool | int | float | complex) – first input array, of shape (N,) or (..., K, N).
  - b (Array | ndarray | bool | number | bool | int | float | complex) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of a.
  - precision (None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two such values indicating the precision of a and b.
  - preferred_element_type (str | type[Any] | dtype | SupportsDType | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Returns:
  array containing the matrix product of the inputs. Shape is a.shape[:-1] if b.ndim == 1, otherwise the shape is (..., K, M), where leading dimensions of a and b are broadcast together. If a is 1-D, the K dimension is absent from the result.
- Return type:
  Array
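The shape rules above can be checked directly. This is a minimal sketch; the shapes are arbitrary illustrative choices, and the 1-D `a` case follows the NumPy convention that a 1-D first argument is treated as a single row whose dimension is removed afterwards.

```python
import jax.numpy as jnp

# Illustrative shapes only: K=3, N=4, M=6.
a = jnp.ones((2, 1, 3, 4))  # leading dims (2, 1), then (K, N)
b = jnp.ones((5, 4, 6))     # leading dims (5,), then (N, M)
v = jnp.ones(4)             # 1-D vector of length N

batched = jnp.matmul(a, b)  # leading dims (2, 1) and (5,) broadcast to (2, 5)
mat_vec = jnp.matmul(a, v)  # b is 1-D: result shape is a.shape[:-1]
vec_mat = jnp.matmul(v, b)  # a is 1-D: the K dimension is dropped

print(batched.shape)  # (2, 5, 3, 6)
print(mat_vec.shape)  # (2, 1, 3)
print(vec_mat.shape)  # (5, 6)
```

Because every input is all ones, each output entry is simply N = 4.0, which makes it easy to confirm that the contraction runs over the shared N axis.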
See also
- jax.numpy.linalg.vecdot(): batched vector product.
- jax.numpy.linalg.tensordot(): batched tensor product.
- jax.lax.dot_general(): general N-dimensional batched dot product.
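To relate matmul to the lower-level jax.lax.dot_general() listed above, the following sketch expresses a batched (B, K, N) @ (B, N, M) product as a dot_general call. The shapes and the `dims` tuple are illustrative assumptions chosen for this example, not the exact call that jnp.matmul makes internally.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(3, 2, 4)  # (B, K, N) with B=3, K=2, N=4
b = jnp.arange(60.0).reshape(3, 4, 5)  # (B, N, M) with N=4, M=5

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)):
# contract a's last axis with b's second-to-last axis, batch over axis 0 of both.
dims = (((2,), (1,)), ((0,), (0,)))

# dot_general orders output axes as batch dims, then lhs free dims, then rhs free dims.
via_dot_general = jax.lax.dot_general(a, b, dims)
via_matmul = jnp.matmul(a, b)

print(via_dot_general.shape)  # (3, 2, 5)
```

One practical difference: dot_general requires batch dimensions of equal size, whereas matmul also broadcasts mismatched leading dimensions (such as (2, 1) against (5,)), so matmul is usually the more convenient entry point for NumPy-style code.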
Examples
Vector dot product:

>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.matmul(a, b)
Array(32, dtype=int32)
Matrix dot product:
>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[1, 2],
...                [3, 4],
...                [5, 6]])
>>> jnp.matmul(a, b)
Array([[22, 28],
       [49, 64]], dtype=int32)
For convenience, in all cases you can do the same computation using the @ operator:

>>> a @ b
Array([[22, 28],
       [49, 64]], dtype=int32)
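The precision and preferred_element_type parameters described above can be exercised as follows. This is a hedged sketch with made-up example data: small int8 inputs whose true sums (64 products of 3 × 2 = 384 per entry) do not fit in int8, and a float32 product requested at Precision.HIGHEST. On CPU the HIGHEST setting is not expected to change the numbers; it matters mainly on accelerators such as TPUs, where the default float32 precision can use lower-precision arithmetic.

```python
import jax
import jax.numpy as jnp

# Hypothetical example data: each output entry sums 64 products of 3 * 2.
a = jnp.full((4, 64), 3, dtype=jnp.int8)
b = jnp.full((64, 2), 2, dtype=jnp.int8)

# Default: the result dtype follows the input types (int8 here),
# so the true value 384 cannot be represented in the output.
out_default = jnp.matmul(a, b)

# Accumulate to, and return, int32 instead.
out_int32 = jnp.matmul(a, b, preferred_element_type=jnp.int32)

# float32 inputs with the highest available precision requested.
x = jnp.linspace(0.0, 1.0, 12, dtype=jnp.float32).reshape(3, 4)
y = jnp.ones((4, 2), dtype=jnp.float32)
out_highest = jnp.matmul(x, y, precision=jax.lax.Precision.HIGHEST)

print(out_default.dtype)                 # int8
print(out_int32.dtype, out_int32[0, 0])  # int32 384
print(out_highest.shape)                 # (3, 2)
```

Requesting a wider preferred_element_type is the usual way to avoid overflow with narrow integer inputs, at the cost of a larger output array.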