jax.numpy.ediff1d
- jax.numpy.ediff1d(ary, to_end=None, to_begin=None)
Compute the differences of the elements of the flattened array.
JAX implementation of numpy.ediff1d().
- Parameters:
ary (ArrayLike) – input array or scalar.
to_end (ArrayLike | None) – scalar or array, optional, default=None. Specifies the numbers to append to the resulting array.
to_begin (ArrayLike | None) – scalar or array, optional, default=None. Specifies the numbers to prepend to the resulting array.
- Returns:
A 1-D array containing the differences between consecutive elements of the flattened input array, with to_begin prepended and to_end appended when they are given.
- Return type:
Array
Note
Unlike NumPy’s implementation of ediff1d, jax.numpy.ediff1d() will not issue an error if casting to_end or to_begin to the type of ary loses precision.
See also
jax.numpy.diff(): Computes the n-th order difference between elements of the array along a given axis.
jax.numpy.cumsum(): Computes the cumulative sum of the elements of the array along a given axis.
jax.numpy.gradient(): Computes the gradient of an N-dimensional array.
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([2, 3, 5, 9, 1, 4])
>>> jnp.ediff1d(a)
Array([ 1,  2,  4, -8,  3], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10)
Array([-10,   1,   2,   4,  -8,   3], dtype=int32)
>>> jnp.ediff1d(a, to_end=jnp.array([20, 30]))
Array([ 1,  2,  4, -8,  3, 20, 30], dtype=int32)
>>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30]))
Array([-10,   1,   2,   4,  -8,   3,  20,  30], dtype=int32)
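The note above about lossy casting can be checked directly. The following is a minimal sketch, assuming an integer input and a floating-point `to_end`; the specific values are illustrative and not taken from the documentation. It only inspects the dtype and shape of the result, which is what the note makes a claim about.

```python
import jax.numpy as jnp

a = jnp.array([2, 3, 5, 9, 1, 4])            # integer input array
to_end = jnp.array([20.5, 30.5])             # floating-point values (illustrative)

# No error is raised even though casting to_end to an integer dtype
# discards the fractional part.
out = jnp.ediff1d(a, to_end=to_end)

print(out.dtype == a.dtype)   # the result keeps the dtype of `ary`
print(out.shape)              # 5 differences plus 2 appended values
```

If the fractional part matters, one option is to convert `ary` to a floating-point dtype before calling `jnp.ediff1d`, so that nothing is truncated.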
For an array with ndim > 1, the differences are computed after flattening the input array.
>>> a1 = jnp.array([[2, -1, 4, 7],
...                 [3, 5, -6, 9]])
>>> jnp.ediff1d(a1)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
>>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9])
>>> jnp.ediff1d(a2)
Array([ -3,   5,   3,  -4,   2, -11,  15], dtype=int32)
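The behavior described on this page (flatten, take consecutive differences, then prepend and append) can be reproduced from other jax.numpy functions. The sketch below is illustrative only: `ediff1d_sketch` is a hypothetical helper written for this example, not the library's implementation, and it assumes that `to_begin` and `to_end` are cast to the dtype of `ary`, as the note above describes.

```python
import jax.numpy as jnp

def ediff1d_sketch(ary, to_end=None, to_begin=None):
    """Hypothetical helper mirroring the documented behavior of ediff1d."""
    flat = jnp.ravel(jnp.asarray(ary))        # flatten any input to 1-D
    pieces = [jnp.diff(flat)]                 # first-order differences
    if to_begin is not None:
        # Assumption: values are cast to the dtype of the input array.
        pieces.insert(0, jnp.ravel(jnp.asarray(to_begin, dtype=flat.dtype)))
    if to_end is not None:
        pieces.append(jnp.ravel(jnp.asarray(to_end, dtype=flat.dtype)))
    return jnp.concatenate(pieces)

a1 = jnp.array([[2, -1, 4, 7],
                [3, 5, -6, 9]])
expected = jnp.ediff1d(a1, to_begin=-10, to_end=jnp.array([20, 30]))
result = ediff1d_sketch(a1, to_begin=-10, to_end=jnp.array([20, 30]))

print(jnp.array_equal(result, expected))
```

Comparing against `jnp.ediff1d` itself, rather than against hard-coded values, keeps the check valid regardless of the default integer width in use.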