jax.scipy.interpolate.RegularGridInterpolator


class jax.scipy.interpolate.RegularGridInterpolator(points, values, method='linear', bounds_error=False, fill_value=nan)

Interpolate points on a regular rectangular grid.

JAX implementation of scipy.interpolate.RegularGridInterpolator().

Parameters:
  • points – length-N sequence of arrays specifying the grid coordinates.

  • values – N-dimensional array specifying the grid values.

  • method – interpolation method, either "linear" or "nearest".

  • bounds_error – not implemented by JAX; only the default False is supported.

  • fill_value – value returned for points outside the grid, defaults to NaN.

Returns:

callable interpolation object.

Return type:

interpolator
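The sketch below illustrates the method and fill_value parameters on a small 3×3 grid with float values. It is an illustration, not an excerpt from the JAX sources, and it assumes that a query coordinate lying outside the first and last grid values on any axis counts as out of bounds (and so receives fill_value), while "nearest" snaps each coordinate to the closest grid node. The variable names are illustrative only.

```python
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

# Illustrative 3x3 grid: x in {1, 2, 3}, y in {4, 5, 6}.
points = (jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0, 6.0]))
values = jnp.array([[10.0, 20.0, 30.0],
                    [40.0, 50.0, 60.0],
                    [70.0, 80.0, 90.0]])

query = jnp.array([[1.4, 5.6],   # inside the grid
                   [0.0, 5.0]])  # x = 0 lies below the grid on the first axis

# "nearest": (1.4, 5.6) snaps to node (1, 6), whose value is 30;
# the out-of-range query is replaced by fill_value.
nearest = RegularGridInterpolator(points, values, method="nearest",
                                  fill_value=-1.0)
result = nearest(query)
print(result)

# "linear" with the default fill_value (NaN): the inside query is
# interpolated from the four surrounding nodes, the outside one becomes NaN.
linear = RegularGridInterpolator(points, values)
result_linear = linear(query)
print(result_linear)
```

Because fill_value is applied after interpolation, the same out-of-range point yields -1.0 with the first interpolator and NaN with the second.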

Examples

>>> import jax.numpy as jnp
>>> from jax.scipy.interpolate import RegularGridInterpolator
>>> points = (jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
>>> values = jnp.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
>>> interpolate = RegularGridInterpolator(points, values, method='linear')
>>> query_points = jnp.array([[1.5, 4.5], [2.2, 5.8]])
>>> interpolate(query_points)
Array([30., 64.], dtype=float32)
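Since the interpolator is built from jax.numpy operations, it can presumably be called under jax.jit and differentiated with jax.grad with respect to the query point. The sketch below illustrates that assumption; the wrapper f is hypothetical. The example grid values equal 10 + 30(x - 1) + 10(y - 4) at every node, so with method='linear' the interpolant reproduces that affine function and its gradient at an interior point should be (30, 10).

```python
import jax
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

points = (jnp.array([1.0, 2.0, 3.0]), jnp.array([4.0, 5.0, 6.0]))
values = jnp.array([[10.0, 20.0, 30.0],
                    [40.0, 50.0, 60.0],
                    [70.0, 80.0, 90.0]])
interp = RegularGridInterpolator(points, values, method="linear")


@jax.jit
def f(xy):
    # Hypothetical scalar wrapper: the interpolator expects query points of
    # shape (..., ndim), so add a batch axis and take the single result.
    return interp(xy[None, :])[0]


xy = jnp.array([1.5, 4.5])
value = f(xy)           # interpolated value at (1.5, 4.5)
grad = jax.grad(f)(xy)  # partial derivatives with respect to x and y
print(value, grad)
```

Gradients flow through the fractional position of the query inside its grid cell; the cell index itself is integer-valued and contributes no derivative.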

Note

Unlike scipy.interpolate.RegularGridInterpolator, JAX requires each axis in points to be strictly increasing. SciPy accepts any monotonic axis (increasing or decreasing), but a decreasing axis is not supported here and will silently produce incorrect results. Reorder the grid (and the corresponding axis of values) so each axis is strictly increasing before constructing the interpolator.
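One way to follow this advice is a small preprocessing step that reverses every decreasing axis together with the matching axis of values. The helper make_increasing below is hypothetical (not part of JAX); it assumes each axis is already strictly monotonic, and it must run eagerly, outside jax.jit, because it branches on concrete array values.

```python
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator


def make_increasing(points, values):
    """Hypothetical helper: flip decreasing grid axes so each is increasing.

    Assumes every axis in `points` is strictly monotonic. Uses Python
    control flow on concrete values, so call it outside jax.jit.
    """
    new_points = []
    for axis, p in enumerate(points):
        p = jnp.asarray(p)
        if p[0] > p[-1]:  # decreasing axis: reverse coordinates and data
            p = jnp.flip(p)
            values = jnp.flip(values, axis=axis)
        new_points.append(p)
    return tuple(new_points), values


# Grid whose second axis is decreasing (accepted by SciPy, not by JAX).
x = jnp.array([1.0, 2.0, 3.0])
y_desc = jnp.array([6.0, 5.0, 4.0])
values_desc = jnp.array([[30.0, 20.0, 10.0],
                         [60.0, 50.0, 40.0],
                         [90.0, 80.0, 70.0]])

points, values = make_increasing((x, y_desc), values_desc)
interp = RegularGridInterpolator(points, values)
result = interp(jnp.array([[1.5, 4.5], [2.2, 5.8]]))
print(result)
```

Comparing only the first and last coordinates is enough to detect direction here because of the strict-monotonicity assumption; after flipping, the grid and values match the example above, so the results should agree with it.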

__init__(points, values, method='linear', bounds_error=False, fill_value=nan)

Methods

__init__(points, values[, method, ...])