jax.numpy.fft.fftshift

jax.numpy.fft.fftshift(x, axes=None)

Shift the zero-frequency FFT component to the center of the spectrum.

JAX implementation of numpy.fft.fftshift().
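As in NumPy, the shift can be understood as a circular roll. Each selected axis of length n is rolled by n // 2 positions, which moves index 0 (the zero-frequency entry) to index n // 2. The sketch below restates that idea with jnp.roll. The helper name fftshift_via_roll is hypothetical and only illustrates the behavior; it is not JAX's own source.

```python
import jax.numpy as jnp

def fftshift_via_roll(x, axes=None):
    """Hypothetical helper: roll each selected axis by n // 2 (illustration only)."""
    x = jnp.asarray(x)
    if axes is None:
        axes = tuple(range(x.ndim))      # default: shift every axis
    elif isinstance(axes, int):
        axes = (axes,)                   # normalize a single axis to a tuple
    shifts = tuple(x.shape[ax] // 2 for ax in axes)
    return jnp.roll(x, shifts, axis=axes)

x = jnp.arange(10).reshape(2, 5)
print(jnp.array_equal(fftshift_via_roll(x), jnp.fft.fftshift(x)))              # True
print(jnp.array_equal(fftshift_via_roll(x, 1), jnp.fft.fftshift(x, axes=1)))   # True
```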

Parameters:
  • x (ArrayLike) – N-dimensional array of frequencies.

  • axes (int | Sequence[int] | None) – optional integer or sequence of integers specifying which axes to shift. If None (default), all axes are shifted.
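The axes parameter restricts which dimensions are shifted. In this small sketch, an integer array of shape (2, 3) is an arbitrary stand-in for a 2-D spectrum: passing axes=1 shifts each row only, while the default (axes=None) also shifts along the first axis.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
# [[0 1 2]
#  [3 4 5]]

# Shift only the last axis: each row is rolled by 3 // 2 = 1.
print(jnp.fft.fftshift(x, axes=1))
# [[2 0 1]
#  [5 3 4]]

# Default axes=None shifts every axis: the rows are also rolled by 2 // 2 = 1.
print(jnp.fft.fftshift(x))
# [[5 3 4]
#  [2 0 1]]
```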

Returns:

A shifted copy of x.

Return type:

Array

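See also

  • jax.numpy.fft.ifftshift(): inverse of fftshift.

  • jax.numpy.fft.fftfreq(): generate FFT frequencies.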

Examples

Generate FFT frequencies with fftfreq():

>>> import jax.numpy as jnp
>>> freq = jnp.fft.fftfreq(5)
>>> freq
Array([ 0. ,  0.2,  0.4, -0.4, -0.2], dtype=float32)

Use fftshift to shift the zero-frequency entry to the middle of the array:

>>> shifted_freq = jnp.fft.fftshift(freq)
>>> shifted_freq
Array([-0.4, -0.2,  0. ,  0.2,  0.4], dtype=float32)
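In practice fftshift is typically applied to an FFT result and to the matching fftfreq() grid together, so each coefficient stays aligned with its frequency. The signal below (a cosine completing 2 cycles over 8 samples) is an arbitrary choice for illustration; its energy sits at ±0.25 cycles per sample, which after shifting lie symmetrically around the centered zero-frequency bin.

```python
import jax.numpy as jnp

n = 8
t = jnp.arange(n)
signal = jnp.cos(2 * jnp.pi * 2 * t / n)   # 2 cycles over 8 samples -> 0.25 cycles/sample

spectrum = jnp.fft.fft(signal)             # zero frequency first
freqs = jnp.fft.fftfreq(n)                 # same ordering as the spectrum

# Shift both arrays the same way so frequencies and coefficients stay paired.
centered_freqs = jnp.fft.fftshift(freqs)   # now runs from -0.5 up to 0.375
centered_mag = jnp.abs(jnp.fft.fftshift(spectrum))

# The two large coefficients (magnitude n / 2 = 4) sit at -0.25 and 0.25.
peaks = centered_freqs[centered_mag > 1.0]
print(peaks)
```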

Undo the shift with ifftshift() to recover the original frequencies:

>>> jnp.fft.ifftshift(shifted_freq)
Array([ 0. ,  0.2,  0.4, -0.4, -0.2], dtype=float32)
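ifftshift() matters because fftshift is not its own inverse for odd lengths: it rolls by n // 2, and two such rolls add up to a full roll of n only when n is even. A quick check follows, using integer arrays instead of frequencies only so the comparisons are exact.

```python
import jax.numpy as jnp

odd = jnp.arange(5)    # stand-in for a length-5 spectrum
even = jnp.arange(6)   # stand-in for a length-6 spectrum

# ifftshift undoes fftshift for any length.
print(jnp.array_equal(jnp.fft.ifftshift(jnp.fft.fftshift(odd)), odd))     # True
# For even n, two rolls by n // 2 make a full roll by n, so fftshift undoes itself.
print(jnp.array_equal(jnp.fft.fftshift(jnp.fft.fftshift(even)), even))    # True
# For odd n, two rolls by n // 2 total n - 1, so the original is not recovered.
print(jnp.fft.fftshift(jnp.fft.fftshift(odd)))                            # [1 2 3 4 0]
```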