jax.numpy.sinc
jax.numpy.sinc(x, /)
Calculate the normalized sinc function.
JAX implementation of numpy.sinc().
The normalized sinc function is given by
\[\mathrm{sinc}(x) = \frac{\sin(\pi x)}{\pi x},\]
where sinc(0) returns the limit value of 1. The sinc function is smooth and infinitely differentiable.
- Parameters:
x (ArrayLike) – input array; will be promoted to an inexact type.
- Returns:
An array of the same shape as x containing the result.
- Return type:
Array
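A quick check of the promotion and shape behavior described in the Parameters and Returns entries, plus a spot comparison against numpy.sinc. This is a sketch assuming JAX's default 32-bit configuration (jax_enable_x64 off); the grid and tolerances are illustrative choices, not an official test.

```python
import jax.numpy as jnp
import numpy as np

# Integer input (int32 under JAX's default 32-bit mode, i.e. jax_enable_x64 off).
x_int = jnp.arange(-3, 4)
y = jnp.sinc(x_int)

print(x_int.dtype, y.dtype)  # int32 float32 under the default configuration
print(y.shape == x_int.shape)  # True: the output keeps the input's shape

# Spot-check against NumPy's reference implementation on a float grid.
# Tolerances are chosen loosely for float32 arithmetic (assumption).
x = np.linspace(-4.0, 4.0, 101)
np.testing.assert_allclose(jnp.sinc(x), np.sinc(x), rtol=1e-5, atol=1e-6)
```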
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([-1, -0.5, 0, 0.5, 1])
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sinc(x)
Array([-0.   ,  0.637,  1.   ,  0.637, -0.   ], dtype=float32)
Compare this to the naive approach to computing the function, which is undefined at zero:
>>> with jnp.printoptions(precision=3, suppress=True):
...   jnp.sin(jnp.pi * x) / (jnp.pi * x)
Array([-0.   ,  0.637,    nan,  0.637, -0.   ], dtype=float32)
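The nan above comes from evaluating 0/0 at x = 0. A commonly used workaround in hand-written code is the "double where" pattern sketched below; the function name double_where_sinc is hypothetical and not part of JAX. It reproduces the values of jnp.sinc and gives 0 for the first derivative at zero (which happens to be correct), but because it substitutes a constant at x = 0, the second derivative there evaluates to zero instead of -π²/3. This is why a dedicated gradient rule is needed for accurate higher-order derivatives at zero.

```python
import jax
import jax.numpy as jnp


def double_where_sinc(x):
    """Hypothetical hand-written sinc using the "double where" pattern."""
    is_zero = x == 0
    # Substitute a harmless value (1.0) at zero so 0/0 is never evaluated,
    # not even in the branch that the outer jnp.where discards.
    safe_x = jnp.where(is_zero, 1.0, x)
    pi_x = jnp.pi * safe_x
    # Select the limit value 1 at zero, sin(pi x) / (pi x) elsewhere.
    return jnp.where(is_zero, 1.0, jnp.sin(pi_x) / pi_x)


x = jnp.array([-1, -0.5, 0, 0.5, 1])
values_match = bool(jnp.allclose(double_where_sinc(x), jnp.sinc(x)))
print(values_match)  # True

# First derivatives at zero: both are zero.
d1_manual = jax.grad(double_where_sinc)(0.0)
d1_jax = jax.grad(jnp.sinc)(0.0)

# Second derivatives differ: at x == 0 the pattern selects a constant, so its
# second derivative there is zero, while the exact value is -pi**2 / 3.
d2_manual = jax.grad(jax.grad(double_where_sinc))(0.0)
d2_jax = jax.grad(jax.grad(jnp.sinc))(0.0)
print(float(d2_manual), float(d2_jax))
```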
JAX defines a custom gradient rule for sinc to allow accurate evaluation of the gradient at zero even for higher-order derivatives:
>>> import jax
>>> f = jnp.sinc
>>> for i in range(1, 6):
...   f = jax.grad(f)
...   print(f"(d/dx)^{i} f(0.0) = {f(0.0):.2f}")
...
(d/dx)^1 f(0.0) = 0.00
(d/dx)^2 f(0.0) = -3.29
(d/dx)^3 f(0.0) = 0.00
(d/dx)^4 f(0.0) = 19.48
(d/dx)^5 f(0.0) = 0.00
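The derivative values at zero can be cross-checked against the Maclaurin series sinc(x) = Σ_k (-1)^k (πx)^{2k} / (2k+1)!. Odd derivatives vanish at zero, and the (2k)-th derivative equals (-1)^k π^{2k} / (2k+1), which gives -π²/3 ≈ -3.29 and π⁴/5 ≈ 19.48 for orders 2 and 4. The sketch below compares repeated jax.grad against this closed form up to order 6; the helper sinc_derivative_at_zero is hypothetical, written only for this comparison.

```python
import math

import jax
import jax.numpy as jnp


def sinc_derivative_at_zero(n):
    """n-th derivative of the normalized sinc at x = 0, from its Maclaurin series.

    sinc(x) = sum_k (-1)**k (pi x)**(2k) / (2k + 1)!, so odd derivatives vanish
    and the (2k)-th derivative equals (-1)**k * pi**(2k) / (2k + 1).
    """
    if n % 2:
        return 0.0
    k = n // 2
    return (-1) ** k * math.pi ** n / (n + 1)


f = jnp.sinc
results = []
for n in range(1, 7):
    f = jax.grad(f)  # f is now the n-th derivative of jnp.sinc
    results.append((n, float(f(0.0)), sinc_derivative_at_zero(n)))

for n, autodiff_value, series_value in results:
    print(f"n={n}: jax.grad -> {autodiff_value:.2f}, series -> {series_value:.2f}")
```

Because the series coefficients grow like π^{2k}, the comparison uses a relative tolerance for the larger even-order values.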