jax.scipy.fft.idctn

jax.scipy.fft.idctn(x, type=2, s=None, axes=None, norm=None)

Computes the multidimensional inverse discrete cosine transform (DCT) of the input.

JAX implementation of scipy.fft.idctn().
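A multidimensional inverse DCT is separable: applying a one-dimensional inverse DCT along each transformed axis in turn gives the same result. The minimal sketch below checks this against jax.scipy.fft.idct for a 2-D input. The random key, input shape, and tolerance are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idct, idctn

# Illustrative 2-D input; any real array would do.
x = jax.random.normal(jax.random.key(0), (3, 4))

# Multidimensional inverse DCT over all axes (axes=None).
full = idctn(x)

# The same transform built from 1-D inverse DCTs, one axis at a time.
by_axis = idct(idct(x, axis=0), axis=1)

same = bool(jnp.allclose(full, by_axis, atol=1e-5))
print(same)
```

Because each 1-D transform acts on a different axis, the order of the two idct calls does not change the result.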

Parameters:
  • x (Array) – array

  • type (int) – integer, default = 2. Currently only type 2 is supported.

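  • s (Sequence[int] | None) – integer or sequence of integers. Specifies the shape of the result. If not specified, it will default to the shape of x along the specified axes. Entries of s are matched to the transformed axes in order; when axes is None, s=[2] therefore sets the length of axis 0.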

  • axes (Sequence[int] | None) – integer or sequence of integers. Specifies the axes along which the transform will be computed. If not specified, the transform is computed along all axes of x.

  • norm (str | None) – string. The normalization mode: one of [None, "backward", "ortho"]. The default is None, which is equivalent to "backward".
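The normalization mode only inverts correctly when jax.scipy.fft.dctn and idctn use the same norm. With "ortho", the transform is orthonormal, so it also preserves the sum of squares (Parseval's relation). The sketch below checks both properties for None and "ortho"; the input shapes, random keys, and tolerances are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn, idctn

x = jax.random.normal(jax.random.key(1), (4, 5))

# dctn followed by idctn with the same norm should recover x.
roundtrip_ok = {
    str(norm): bool(jnp.allclose(x, idctn(dctn(x, norm=norm), norm=norm), atol=1e-5))
    for norm in (None, "ortho")
}
print(roundtrip_ok)

# With norm="ortho" the inverse transform is orthonormal, so the
# total energy of the coefficients equals that of the output.
coeffs = jax.random.normal(jax.random.key(2), (4, 5))
energy_in = jnp.sum(coeffs ** 2)
energy_out = jnp.sum(idctn(coeffs, norm="ortho") ** 2)
preserved = bool(jnp.allclose(energy_in, energy_out, rtol=1e-4))
print(preserved)
```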

Returns:

array containing the inverse discrete cosine transform of x

Return type:

Array

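See also

  • jax.scipy.fft.dctn: multidimensional DCT

  • jax.scipy.fft.dct: one-dimensional DCT

  • jax.scipy.fft.idct: one-dimensional inverse DCT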

Examples

By default (axes=None), jax.scipy.fft.idctn computes the transform along all axes of the input, which for this 2-D example means both axes:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...    print(jax.scipy.fft.idctn(x))
[[ 0.12  0.11 -0.15]
 [ 0.07  0.17 -0.03]
 [ 0.19 -0.07 -0.02]]
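Leaving axes as None is the same as listing every axis explicitly. The sketch below checks this for the 2-D input above and for an illustrative 3-D array; the shapes and random keys are arbitrary choices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

# axes=None transforms every axis, so for 2-D input it matches axes=[0, 1].
default = idctn(x)
explicit = idctn(x, axes=[0, 1])
print(bool(jnp.allclose(default, explicit)))

# The same holds for higher-dimensional input.
y = jax.random.normal(jax.random.key(4), (2, 3, 4))
same_3d = bool(jnp.allclose(idctn(y), idctn(y, axes=[0, 1, 2])))
print(same_3d)
```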

With s=[2] and axes=None, the single entry of s applies to axis 0: the output has length 2 along axis 0, while its length along axis 1 stays the same as the input's.

>>> with jnp.printoptions(precision=2, suppress=True):
...  print(jax.scipy.fft.idctn(x, s=[2]))
[[ 0.15  0.21 -0.18]
 [ 0.24 -0.01 -0.02]]
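To see where the (2, 3) shape comes from, the sketch below compares the call with transforming only the first two rows of x. It assumes, as SciPy documents for s, that a length shorter than the input truncates the input along that axis before the transform; the comparison is an illustration of that assumption, not an API guarantee.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

# s=[2] with axes=None constrains only axis 0.
out = idctn(x, s=[2])
print(out.shape)  # (2, 3)

# Assumed equivalent: crop axis 0 to length 2, then transform all axes.
cropped = idctn(x[:2, :])
print(bool(jnp.allclose(out, cropped, atol=1e-6)))
```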

With s=[2] and axes=[1], the transform is computed only along axis 1, and the output has length 2 along that axis. Axis 0 is not transformed and keeps the input's length.

>>> with jnp.printoptions(precision=2, suppress=True):
...  print(jax.scipy.fft.idctn(x, s=[2], axes=[1]))
[[ 1.12 -0.31]
 [ 0.04 -0.08]
 [ 0.05 -0.3 ]]
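When only one axis is transformed, idctn reduces to the one-dimensional inverse DCT. The sketch below compares it with jax.scipy.fft.idct using n=2 along axis 1; the input is the same illustrative 3x3 array as above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idct, idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

# Transform only axis 1, with output length 2 along that axis.
out = idctn(x, s=[2], axes=[1])

# The 1-D routine with n=2 along axis 1 should give the same array.
ref = idct(x, n=2, axis=1)
print(out.shape)  # (3, 2)
print(bool(jnp.allclose(out, ref, atol=1e-6)))
```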

With s=[2, 4], the output has shape (2, 4).

>>> with jnp.printoptions(precision=2, suppress=True):
...  print(jax.scipy.fft.idctn(x, s=[2, 4]))
[[ 0.1   0.18  0.07 -0.16]
 [ 0.2   0.06 -0.03 -0.01]]
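Here axis 0 shrinks from 3 to 2 while axis 1 grows from 3 to 4. The sketch below reproduces the result by resizing the input by hand, assuming (as SciPy documents for s) that shorter lengths crop the input and longer lengths zero-pad it before the transform. The manual resizing is an illustration of that assumption, not how the library is guaranteed to work internally.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import idctn

x = jax.random.normal(jax.random.key(0), (3, 3))

out = idctn(x, s=[2, 4])

# Assumed equivalent input: keep the first 2 rows (crop axis 0) and
# append one column of zeros (pad axis 1 from 3 to 4).
resized = jnp.pad(x[:2, :], ((0, 0), (0, 1)))
ref = idctn(resized)
print(out.shape)  # (2, 4)
print(bool(jnp.allclose(out, ref, atol=1e-6)))
```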

Because idctn inverts dctn when both use the same norm (here the default), jax.scipy.fft.idctn can reconstruct x from the output of jax.scipy.fft.dctn:

>>> x_dctn = jax.scipy.fft.dctn(x)
>>> jnp.allclose(x, jax.scipy.fft.idctn(x_dctn))
Array(True, dtype=bool)
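The inverse relationship also holds in the other direction, and it holds for a subset of axes as long as both calls use the same axes and norm. The sketch below checks both cases; the coefficient array, its shape, and the tolerance are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn, idctn

coeffs = jax.random.normal(jax.random.key(3), (3, 4))

# idctn followed by dctn recovers the original coefficients.
back = dctn(idctn(coeffs))
print(bool(jnp.allclose(coeffs, back, atol=1e-5)))

# A round trip restricted to axis 1: both calls use axes=[1].
rows = idctn(dctn(coeffs, axes=[1]), axes=[1])
print(bool(jnp.allclose(coeffs, rows, atol=1e-5)))
```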