jax.numpy.ediff1d

jax.numpy.ediff1d(ary, to_end=None, to_begin=None)

Compute the differences between consecutive elements of the flattened array.

JAX implementation of numpy.ediff1d().
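The description above amounts to three steps: flatten the input, take differences between consecutive elements, and concatenate any optional extra values. The following is a minimal sketch of that reading, built only from public jax.numpy functions. The name `ediff1d_sketch` is hypothetical, and the code is an illustration that reproduces the documented behaviour, not necessarily how JAX implements the function internally.

```python
import jax.numpy as jnp

def ediff1d_sketch(ary, to_end=None, to_begin=None):
    """Hypothetical re-implementation of jnp.ediff1d, for illustration only."""
    flat = jnp.ravel(jnp.asarray(ary))   # work on the flattened input
    result = flat[1:] - flat[:-1]        # differences of consecutive elements
    if to_begin is not None:
        # Cast the prepended values to the input dtype (no precision check).
        begin = jnp.ravel(jnp.asarray(to_begin, dtype=flat.dtype))
        result = jnp.concatenate([begin, result])
    if to_end is not None:
        # Cast the appended values to the input dtype (no precision check).
        end = jnp.ravel(jnp.asarray(to_end, dtype=flat.dtype))
        result = jnp.concatenate([result, end])
    return result

a = jnp.array([[2, -1, 4, 7], [3, 5, -6, 9]])
extras = dict(to_begin=-10, to_end=jnp.array([20, 30]))
# Compare the sketch with the library function on the same inputs.
print(jnp.array_equal(ediff1d_sketch(a, **extras), jnp.ediff1d(a, **extras)))
```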

Parameters:
  • ary (ArrayLike) – input array or scalar.

  • to_end (ArrayLike | None) – scalar or array, optional, default=None. Values to append to the end of the resulting array of differences.

  • to_begin (ArrayLike | None) – scalar or array, optional, default=None. Values to prepend to the beginning of the resulting array of differences.

Returns:

A 1-D array containing the differences between consecutive elements of the flattened input array, preceded by to_begin and followed by to_end when those are given.

Return type:

Array

Note

Unlike NumPy’s implementation of ediff1d, jax.numpy.ediff1d() will not issue an error if casting to_end or to_begin to the type of ary loses precision.
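To illustrate the note, the sketch below passes a float to_begin alongside an integer input. The JAX call returns an array with the input's integer dtype instead of raising. The same call in NumPy is expected to raise a TypeError, because NumPy checks that to_begin can be cast to the input dtype under the same_kind rule. The exact prepended integer value is not asserted here.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.array([1, 2, 4])  # integer input (int32 unless x64 mode is enabled)

# JAX: 1.5 is cast to a's integer dtype without an error being raised.
out = jnp.ediff1d(a, to_begin=1.5)
print(out, out.dtype)

# NumPy: the equivalent call rejects the lossy float-to-int cast.
try:
    np.ediff1d(np.array([1, 2, 4]), to_begin=1.5)
    numpy_raised = False
except TypeError:
    numpy_raised = True
print("NumPy raised TypeError:", numpy_raised)
```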

See also

  • jax.numpy.diff(): Computes the n-th order difference between elements of the array along a given axis.

  • jax.numpy.cumsum(): Computes the cumulative sum of the elements of the array along a given axis.

  • jax.numpy.gradient(): Computes the gradient of an N-dimensional array.
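The functions listed above are closely related to ediff1d. As a small sketch using only documented jax.numpy calls: for a 1-D input with no extra values, ediff1d agrees with the first-order jnp.diff, and prepending the first element through to_begin lets jnp.cumsum rebuild the original array.

```python
import jax.numpy as jnp

a = jnp.array([2, 3, 5, 9, 1, 4])

# For 1-D input without to_begin/to_end, ediff1d matches the first-order diff.
print(jnp.array_equal(jnp.ediff1d(a), jnp.diff(a)))

# Keeping a[0] in front of the differences makes cumsum an exact inverse.
rebuilt = jnp.cumsum(jnp.ediff1d(a, to_begin=a[0]))
print(jnp.array_equal(rebuilt, a))
```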

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)

For arrays with ndim > 1, the differences are computed after flattening the input array.

>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
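Because the input is flattened first, ediff1d also includes the difference across the boundary between rows. The sketch below contrasts this with jnp.diff, which by default operates along the last axis and keeps the rows separate, so the boundary difference (3 - 7 = -4) appears only in the ediff1d result.

```python
import jax.numpy as jnp

a1 = jnp.array([[2, -1, 4, 7],
                [3, 5, -6, 9]])

# Flattened differences: [-3, 5, 3, -4, 2, -11, 15]; -4 is a1[1, 0] - a1[0, 3].
flat_diffs = jnp.ediff1d(a1)

# Per-row differences along the last axis: [[-3, 5, 3], [2, -11, 15]].
row_diffs = jnp.diff(a1)

print(flat_diffs)
print(row_diffs)
```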