jax.scipy.linalg.convolution_matrix


jax.scipy.linalg.convolution_matrix(a, n, mode='full')

Construct a convolution matrix.

JAX implementation of scipy.linalg.convolution_matrix().

Builds a Toeplitz matrix \(A\) such that A @ v equals jnp.convolve(a, v, mode=mode) for any 1-D v of length n. The returned array has n columns; the row count \(k\) depends on mode:

  • 'full': \(k = m + n - 1\)

  • 'same': \(k = \max(m, n)\)

  • 'valid': \(k = \max(m, n) - \min(m, n) + 1\)

where \(m\) is the size of a along the last axis.
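To make the Toeplitz structure and the row counts concrete, here is a minimal from-scratch sketch. It assumes that entry (i, j) of the 'full' matrix is a[i - j] when 0 <= i - j < m (so column j holds a shifted down by j rows), and that 'same' and 'valid' keep exactly the rows that jnp.convolve keeps. The helper conv_matrix_sketch is hypothetical and is not JAX's implementation; it only checks A @ v against jnp.convolve for each mode.

```python
import jax.numpy as jnp


def conv_matrix_sketch(a, n, mode="full"):
    """Hypothetical reference construction (not JAX's implementation)."""
    a = jnp.asarray(a)
    m = a.shape[-1]
    # Offsets d = i - j for every (row, column) pair of the 'full' matrix.
    d = jnp.arange(m + n - 1)[:, None] - jnp.arange(n)[None, :]
    # Entry (i, j) is a[i - j] inside the band 0 <= i - j < m, zero elsewhere.
    full = jnp.where((d >= 0) & (d < m), a[..., jnp.clip(d, 0, m - 1)], 0)
    lo, hi = min(m, n), max(m, n)
    if mode == "full":
        return full                                      # k = m + n - 1 rows
    if mode == "same":
        start = (lo - 1) // 2                            # centered, like jnp.convolve
        return full[..., start:start + hi, :]            # k = max(m, n) rows
    if mode == "valid":
        start = lo - 1                                   # first row with full overlap
        return full[..., start:start + hi - lo + 1, :]   # k = max - min + 1 rows
    raise ValueError(f"unknown mode: {mode!r}")


a = jnp.array([-1.0, 4.0, -2.0])
v = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
for mode in ("full", "same", "valid"):
    A = conv_matrix_sketch(a, v.shape[0], mode)
    print(mode, A.shape, bool(jnp.allclose(A @ v, jnp.convolve(a, v, mode=mode))))
# full (7, 5) True
# same (5, 5) True
# valid (3, 5) True
```

Building the whole 'full' band once and slicing rows keeps the three modes consistent with one another, since the 'same' and 'valid' outputs are contiguous row subsets of the 'full' one.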

Parameters:
  • a (ArrayLike) – array of shape (..., m) to convolve. Must have m >= 1.

  • n (int) – number of columns in the output. Must be a positive integer.

  • mode (str) – one of 'full', 'same', 'valid'. Defaults to 'full'.

Returns:

A convolution matrix of shape (..., k, n), where k depends on mode as described above.

Return type:

Array
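The (..., m) input and (..., k, n) output shapes suggest that each leading index of a gets its own convolution matrix. Assuming that reading, the sketch below builds a batch of 'full'-mode matrices by hand and checks every slice against jax.vmap applied to jnp.convolve. The helper full_conv_matrix is hypothetical and only covers mode='full'.

```python
import jax
import jax.numpy as jnp


def full_conv_matrix(a, n):
    """Hypothetical 'full'-mode construction for a of shape (..., m)."""
    m = a.shape[-1]
    d = jnp.arange(m + n - 1)[:, None] - jnp.arange(n)[None, :]
    # Advanced indexing on the last axis maps (..., m) -> (..., m + n - 1, n).
    return jnp.where((d >= 0) & (d < m), a[..., jnp.clip(d, 0, m - 1)], 0)


a = jnp.array([[-1.0, 4.0, -2.0],
               [ 1.0, 1.0,  1.0]])         # batch of two kernels, shape (2, 3)
v = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])   # n = 5

A = full_conv_matrix(a, v.shape[0])
print(A.shape)  # (2, 7, 5)

# A @ v broadcasts over the batch axis; compare with a per-row jnp.convolve.
expected = jax.vmap(lambda row: jnp.convolve(row, v))(a)
print(bool(jnp.allclose(A @ v, expected)))  # True
```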

Examples

>>> import jax.scipy.linalg
>>> import jax.numpy as jnp
>>> jax.scipy.linalg.convolution_matrix(jnp.array([-1, 4, -2]), 5, mode='same')
Array([[ 4, -1,  0,  0,  0],
       [-2,  4, -1,  0,  0],
       [ 0, -2,  4, -1,  0],
       [ 0,  0, -2,  4, -1],
       [ 0,  0,  0, -2,  4]], dtype=int32)
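As a quick check of the defining identity on this example, the printed 'same'-mode matrix can be copied as a literal and multiplied by a length-5 vector; the product should match jnp.convolve with mode='same'. The vector v is an arbitrary choice, and float values are used so the comparison stays in floating point.

```python
import jax.numpy as jnp

# The 'same'-mode matrix printed above, copied as a float literal.
A = jnp.array([[ 4., -1.,  0.,  0.,  0.],
               [-2.,  4., -1.,  0.,  0.],
               [ 0., -2.,  4., -1.,  0.],
               [ 0.,  0., -2.,  4., -1.],
               [ 0.,  0.,  0., -2.,  4.]])
a = jnp.array([-1., 4., -2.])
v = jnp.array([1., 2., 3., 4., 5.])  # arbitrary test vector of length n = 5

print(A @ v)                            # values 2, 3, 4, 5, 12
print(jnp.convolve(a, v, mode='same'))  # the same values
```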