jax.numpy.kron

jax.numpy.kron(a, b)

Compute the Kronecker product of two input arrays.

JAX implementation of numpy.kron().
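The sketch below checks that jnp.kron agrees with numpy.kron on the same inputs. It assumes NumPy is installed alongside JAX. Plain NumPy arrays are passed directly because jnp.kron accepts any ArrayLike. The dtypes of the two results may differ, because JAX uses 32-bit types unless 64-bit mode is enabled, so only the values are compared.

```python
import numpy as np
import jax.numpy as jnp

a = np.arange(6).reshape(2, 3)      # NumPy input; jnp.kron accepts any ArrayLike
b = np.array([[1, -1],
              [2, 0]])

jax_result = jnp.kron(a, b)
numpy_result = np.kron(a, b)

print(jax_result.shape, numpy_result.shape)                   # (4, 6) (4, 6)
# Compare values only: dtypes may differ (JAX uses 32-bit types by default).
print(np.array_equal(np.asarray(jax_result), numpy_result))   # True
```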

The Kronecker product is an operation on two matrices of arbitrary size that produces a block matrix. Each element of the first matrix a scales a full copy of the second matrix b, and these scaled copies are tiled in the same arrangement as the elements of a. If a has shape (m, n) and b has shape (p, q), the result has shape (m * p, n * q), and the block in block-row i and block-column j equals a[i, j] * b.
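The block description above can be checked by building the product in two ways for 2-D inputs. The first tiles the blocks a[i, j] * b with jnp.block. The second applies the index rule out[i * p + k, j * q + l] = a[i, j] * b[k, l] with jnp.einsum and a reshape. The helpers kron_by_blocks and kron_by_einsum are hypothetical and exist only for illustration. They are not part of JAX, and this is not a claim about how jnp.kron is implemented internally.

```python
import jax.numpy as jnp

def kron_by_blocks(a, b):
    """Hypothetical helper (not part of JAX): build the Kronecker product of
    two 2-D arrays by tiling the blocks a[i, j] * b explicitly."""
    m, n = a.shape
    blocks = [[a[i, j] * b for j in range(n)] for i in range(m)]
    return jnp.block(blocks)

def kron_by_einsum(a, b):
    """Hypothetical helper (not part of JAX): same result via the index rule
    out[i * p + k, j * q + l] = a[i, j] * b[k, l]."""
    m, n = a.shape
    p, q = b.shape
    # einsum yields shape (m, p, n, q); the reshape merges (i, k) and (j, l).
    return jnp.einsum("ij,kl->ikjl", a, b).reshape(m * p, n * q)

a = jnp.array([[1, 2, 3],
               [4, 5, 6]])      # shape (2, 3)
b = jnp.array([[0, 1],
               [1, 0]])         # shape (2, 2)

expected = jnp.kron(a, b)       # shape (2 * 2, 3 * 2) = (4, 6)
print(expected.shape)                                          # (4, 6)
print(bool(jnp.array_equal(kron_by_blocks(a, b), expected)))   # True
print(bool(jnp.array_equal(kron_by_einsum(a, b), expected)))   # True
```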

Parameters:
  • a (ArrayLike) – first input array with any shape.

  • b (ArrayLike) – second input array with any shape.

Returns:

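A new array representing the Kronecker product of the inputs a and b. The shape of the output is the element-wise product of the input shapes. If a and b have different numbers of dimensions, the lower-dimensional input is first treated as having leading axes of size 1. For example, shape (4,) is treated as (1, 4) when it is paired with a 2-D array.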
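The shape rule can be observed directly. The sketch below prints output shapes for inputs with equal and unequal numbers of dimensions, and also for a scalar first argument. The array names and values are illustrative only.

```python
import jax.numpy as jnp

# Equal numbers of dimensions: the output shape is the axis-by-axis product.
x = jnp.ones((2, 1, 3))
y = jnp.ones((1, 4, 2))
print(jnp.kron(x, y).shape)  # (2, 4, 6)

# Different numbers of dimensions: the lower-dimensional input behaves as if
# it had extra leading axes of size 1, so v's shape (4,) acts like (1, 4).
u = jnp.ones((2, 3))
v = jnp.ones((4,))
print(jnp.kron(u, v).shape)  # (2, 12)
print(jnp.kron(v, u).shape)  # (2, 12)

# A scalar is padded the same way (to shape (1, 1) here), so kron(s, w) == s * w.
s = 3.0
w = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(bool(jnp.array_equal(jnp.kron(s, w), s * w)))  # True
```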

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2],
...                [3, 4]])
>>> b = jnp.array([[5, 6],
...                [7, 8]])
>>> jnp.kron(a, b)
Array([[ 5,  6, 10, 12],
       [ 7,  8, 14, 16],
       [15, 18, 20, 24],
       [21, 24, 28, 32]], dtype=int32)
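Two further cases follow the same rule. When the first argument is an identity matrix, every off-diagonal element of a is 0. The result is therefore block diagonal, with copies of b on the diagonal. For 1-D inputs, each element of the first vector scales a copy of the second. The values below are illustrative, and .tolist() is used only to print plain Python lists.

```python
import jax.numpy as jnp

b = jnp.array([[5, 6],
               [7, 8]])

# Identity as the first argument: off-diagonal blocks are 0 * b, giving a
# block-diagonal matrix with b repeated along the diagonal.
block_diag = jnp.kron(jnp.eye(2, dtype=b.dtype), b)
print(block_diag.tolist())
# [[5, 6, 0, 0], [7, 8, 0, 0], [0, 0, 5, 6], [0, 0, 7, 8]]

# 1-D inputs: [1 * [3, 4], 2 * [3, 4]] concatenated.
vec = jnp.kron(jnp.array([1, 2]), jnp.array([3, 4]))
print(vec.tolist())  # [3, 4, 6, 8]
```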