jax.scipy.fft.dctn
- jax.scipy.fft.dctn(x, type=2, s=None, axes=None, norm=None)
Computes the multidimensional discrete cosine transform of the input.
JAX implementation of scipy.fft.dctn().
- Parameters:
x (Array) – input array.
type (int) – integer, default = 2. Currently only type 2 is supported.
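s (Sequence[int] | None) – integer or sequence of integers. Specifies the shape of the result. If not specified, it defaults to the shape of x along the specified axes. An entry smaller than the corresponding input length truncates that axis, and a larger entry zero-pads it at the end.
axes (Sequence[int] | None) – integer or sequence of integers. Specifies the axes along which the transform is computed. If None (the default), the transform is computed over all axes.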
norm (str | None) – string. The normalization mode: one of [None, "backward", "ortho"]. The default is None, which is equivalent to "backward".
- Returns:
array containing the discrete cosine transform of x.
- Return type:
Array
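For intuition, the type-2 transform with the default normalization (None) can be written out directly. The sketch below assumes the unnormalized DCT-II scaling that SciPy documents, C[k, m] = 2 cos(π k (2m + 1) / (2N)), builds that matrix for each axis of a 2-D input, and compares the result against dctn. The helpers dct2_matrix and dctn_2d_explicit are hypothetical and only for illustration; a dense matrix product is O(N²) per axis, so this is a check, not a replacement for the FFT-based routine.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn


def dct2_matrix(n: int) -> jax.Array:
    """Hypothetical helper: unnormalized DCT-II matrix, C[k, m] = 2*cos(pi*k*(2m+1)/(2n))."""
    k = jnp.arange(n)[:, None]  # output frequency index (rows)
    m = jnp.arange(n)[None, :]  # input sample index (columns)
    return 2.0 * jnp.cos(jnp.pi * k * (2 * m + 1) / (2 * n))


def dctn_2d_explicit(x: jax.Array) -> jax.Array:
    """Hypothetical helper: 2-D DCT-II with default normalization via dense matrices."""
    c0 = dct2_matrix(x.shape[0])
    c1 = dct2_matrix(x.shape[1])
    # c0 @ x transforms along axis 0; multiplying by c1.T transforms along axis 1.
    return c0 @ x @ c1.T


# Assumed example input: any real 2-D array works.
x = jax.random.normal(jax.random.key(0), (3, 4))
print(jnp.allclose(dctn(x), dctn_2d_explicit(x), atol=1e-4))  # True
```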
See also
- jax.scipy.fft.dct(): one-dimensional DCT
- jax.scipy.fft.idct(): one-dimensional inverse DCT
- jax.scipy.fft.idctn(): multidimensional inverse DCT
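As a quick illustration of how dctn pairs with idctn, the sketch below (assuming an arbitrary real 2-D input) transforms with norm="ortho", inverts with jax.scipy.fft.idctn using the same norm, and checks that the orthonormal transform preserves the sum of squares (Parseval's relation). The tolerances are assumptions suited to float32.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn, idctn

# Assumed example input.
x = jax.random.normal(jax.random.key(0), (3, 3))

# Orthonormal forward transform, then the matching inverse.
y = dctn(x, norm="ortho")
x_back = idctn(y, norm="ortho")

print(jnp.allclose(x, x_back, atol=1e-4))                     # round trip recovers x
print(jnp.allclose(jnp.sum(x**2), jnp.sum(y**2), rtol=1e-4))  # energy is preserved
```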
Examples
jax.scipy.fft.dctn computes the transform along all axes (here, both axes of the 2-D input) by default, when the axes argument is None.
>>> import jax
>>> import jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (3, 3))
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x))
[[ 12.01   6.2  -10.17]
 [  8.84   9.65  -3.54]
 [ 11.25  -1.54  -0.88]]
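Because the DCT-II is separable, the default call above should agree with applying the one-dimensional jax.scipy.fft.dct() along axis 0 and then along axis 1. A minimal check, assuming the same kind of (3, 3) random input:

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct, dctn

# Assumed example input, matching the shape used above.
x = jax.random.normal(jax.random.key(0), (3, 3))

# Transform along axis 0, then along axis 1 (the order does not matter).
y_sep = dct(dct(x, axis=0), axis=1)
print(jnp.allclose(dctn(x), y_sep, atol=1e-4))  # True
```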
When s=[2], the transform along axis 0 has length 2 (the input is truncated to its first two rows), while the length along axis 1 stays the same as in the input.
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x, s=[2]))
[[ 9.36 10.22 -8.53]
 [11.57  2.85 -2.06]]
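To see the truncation explicitly, the sketch below assumes a (3, 3) random input and compares dctn(x, s=[2]) with the transform of the first two rows only. The cropped variable name is illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn

# Assumed example input.
x = jax.random.normal(jax.random.key(0), (3, 3))

# s=[2] keeps only the first 2 entries along axis 0; axis 1 keeps its full length.
y_s = dctn(x, s=[2])
y_crop = dctn(x[:2, :])
print(y_s.shape)                             # (2, 3)
print(jnp.allclose(y_s, y_crop, atol=1e-4))  # True
```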
When s=[2] and axes=[1], the transform is computed only along axis 1, with length 2; axis 0 keeps the same length as in the input.
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x, s=[2], axes=[1]))
[[ 7.3  -0.57]
 [ 0.19 -0.36]
 [-0.   -1.4 ]]
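With a single axis, the call reduces to a one-dimensional transform. The sketch below, assuming a (3, 3) random input, checks that dctn(x, s=[2], axes=[1]) matches jax.scipy.fft.dct() along axis 1 with n=2, and also matches transforming the first two columns directly.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dct, dctn

# Assumed example input.
x = jax.random.normal(jax.random.key(0), (3, 3))

y_n = dctn(x, s=[2], axes=[1])
# The same transform expressed with the 1-D API and with explicit cropping.
y_1d = dct(x, n=2, axis=1)
y_crop = dct(x[:, :2], axis=1)
print(y_n.shape)                             # (3, 2)
print(jnp.allclose(y_n, y_1d, atol=1e-4))    # True
print(jnp.allclose(y_n, y_crop, atol=1e-4))  # True
```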
When s=[2, 4], the shape of the transform is (2, 4): axis 0 is truncated to length 2 and axis 1 is zero-padded to length 4.
>>> with jnp.printoptions(precision=2, suppress=True):
...   print(jax.scipy.fft.dctn(x, s=[2, 4]))
[[  9.36  11.23   2.12 -10.97]
 [ 11.57   5.86  -1.37  -1.58]]
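The mixed crop-and-pad case can be reproduced by hand. The sketch below assumes a (3, 3) random input, crops axis 0 to two rows, appends one column of zeros along axis 1 with jnp.pad, and compares the default transform of that array with dctn(x, s=[2, 4]). The name x_adj is illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.fft import dctn

# Assumed example input.
x = jax.random.normal(jax.random.key(0), (3, 3))

y_s = dctn(x, s=[2, 4])
# Crop axis 0 to length 2, then zero-pad axis 1 from 3 to 4 at the end.
x_adj = jnp.pad(x[:2, :], ((0, 0), (0, 1)))
print(y_s.shape)                                  # (2, 4)
print(jnp.allclose(y_s, dctn(x_adj), atol=1e-4))  # True
```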