jax.scipy.linalg.convolution_matrix
- jax.scipy.linalg.convolution_matrix(a, n, mode='full')
Construct a convolution matrix.
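JAX implementation of scipy.linalg.convolution_matrix().
Builds a Toeplitz matrix \(A\) such that A @ v equals jnp.convolve(a, v, mode). The returned array has n columns; the row count \(k\) depends on mode:
- 'full': \(k = m + n - 1\)
- 'same': \(k = \max(m, n)\)
- 'valid': \(k = \max(m, n) - \min(m, n) + 1\)
where \(m\) is the size of a along the last axis.
- Parameters:
  - a: array of coefficients that fill the matrix; its size along the last axis is \(m\).
  - n: number of columns of the matrix, i.e. the length of the vector v it multiplies.
  - mode: one of 'full' (the default), 'same', or 'valid'; selects the row count \(k\) as described above.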
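For a 1-D a, the construction can be sketched directly: entry (i, j) of the 'full' matrix is a[i - j] (zero outside the band), and the 'same' and 'valid' matrices keep a centered block of its rows, which is the trimming numpy.convolve applies. The helper convolution_matrix_sketch below is hypothetical, handles only 1-D a, and is not JAX's implementation; it simply checks A @ v against jnp.convolve for each mode.

```python
import jax.numpy as jnp


def convolution_matrix_sketch(a, n, mode="full"):
    """Hypothetical helper for a 1-D `a` only.

    An illustrative sketch, not JAX's implementation.
    """
    a = jnp.asarray(a)
    m = a.shape[0]
    full_rows = m + n - 1
    # Row count k for each mode, matching the formulas above.
    if mode == "full":
        k = full_rows
    elif mode == "same":
        k = max(m, n)
    elif mode == "valid":
        k = max(m, n) - min(m, n) + 1
    else:
        raise ValueError(f"mode must be 'full', 'same' or 'valid', got {mode!r}")
    # 'same' and 'valid' keep a centered slice of the rows of the 'full' matrix.
    start = (full_rows - k) // 2
    rows = jnp.arange(start, start + k)[:, None]  # row indices in 'full' coordinates, shape (k, 1)
    cols = jnp.arange(n)[None, :]                 # column indices, shape (1, n)
    offset = rows - cols                          # Toeplitz structure: entry (i, j) is a[i - j]
    in_band = (offset >= 0) & (offset < m)
    # Clip so the gather stays in bounds, then zero the entries outside the band.
    return jnp.where(in_band, a[jnp.clip(offset, 0, m - 1)], 0)


a = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([0.5, -1.0, 2.0, 4.0])
for mode in ("full", "same", "valid"):
    A = convolution_matrix_sketch(a, v.shape[0], mode)
    # Shape is (k, n); the product should agree with jnp.convolve in the same mode.
    print(mode, A.shape, bool(jnp.allclose(A @ v, jnp.convolve(a, v, mode=mode))))
```

Building the matrix from index arithmetic and jnp.where keeps all shapes determined by m, n, and mode alone, so the same code serves integer and floating-point a.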
- Returns:
  A convolution matrix of shape (..., k, n), where k depends on mode as described above.
- Return type:
  Array
See also
Examples
>>> jax.scipy.linalg.convolution_matrix(jnp.array([-1, 4, -2]), 5, mode='same')
Array([[ 4, -1,  0,  0,  0],
       [-2,  4, -1,  0,  0],
       [ 0, -2,  4, -1,  0],
       [ 0,  0, -2,  4, -1],
       [ 0,  0,  0, -2,  4]], dtype=int32)