jax.numpy.flip

jax.numpy.flip(m, axis=None)

Reverse the order of elements of an array along the given axis.

JAX implementation of numpy.flip().
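Because jnp.flip mirrors numpy.flip, the two should agree on the same input. The sketch below checks this for a few axis arguments. It assumes both NumPy and JAX are installed and only compares values. It makes no claim about how either library implements the operation internally.

```python
import numpy as np
import jax.numpy as jnp

# A small 3-D test array; int32 matches JAX's default integer width.
x_np = np.arange(24, dtype=np.int32).reshape(2, 3, 4)
x_jnp = jnp.asarray(x_np)

# Compare jnp.flip against numpy.flip for several axis arguments,
# including None (all axes), a negative axis, and a tuple of axes.
for axis in [None, 0, 1, 2, -1, (0, 2)]:
    expected = np.flip(x_np, axis=axis)
    result = jnp.flip(x_jnp, axis=axis)
    assert np.array_equal(np.asarray(result), expected), axis

print("jnp.flip matches numpy.flip for all tested axes")
```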

Parameters:
  • m (ArrayLike) – input array.

  • axis (int | Sequence[int] | None) – integer or sequence of integers specifying the axis or axes along which the array elements are reversed. Default is None, which flips along all axes.
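One way to read the axis parameter is as "apply the slice [::-1] on each listed axis and leave the others alone." The sketch below encodes that reading in a hypothetical helper, flip_via_slicing (not part of JAX), and checks it against jnp.flip. It assumes axes are in range and does not validate them.

```python
import jax.numpy as jnp

def flip_via_slicing(x, axis=None):
    """Hypothetical reference helper: reverse x along `axis` using basic slicing."""
    if axis is None:
        axes = range(x.ndim)            # None means every axis
    elif isinstance(axis, int):
        axes = [axis]                   # a single integer axis
    else:
        axes = axis                     # a sequence of integers
    # Normalize negative axes (e.g. -1 -> x.ndim - 1); no range check is done.
    axes = {a % x.ndim for a in axes}
    # slice(None, None, -1) is ::-1 along that axis; slice(None) keeps it as-is.
    index = tuple(slice(None, None, -1) if d in axes else slice(None)
                  for d in range(x.ndim))
    return x[index]

x = jnp.arange(24).reshape(2, 3, 4)
for axis in [None, 0, 2, -1, [0, 1], (1, 2)]:
    assert jnp.array_equal(flip_via_slicing(x, axis), jnp.flip(x, axis=axis))
```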

Returns:

An array with the elements in reverse order along axis.

Return type:

Array
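The returned Array is a new value with the same shape and dtype as the input. Since JAX arrays are immutable, the input is not modified. The sketch below checks those properties and also runs the flip under jax.jit. It assumes a recent JAX release where jax.Array is the public array type. The axis is fixed inside the lambda so that it stays a static Python value during tracing.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
y = jnp.flip(x, axis=1)

# The result is a jax.Array with the same shape and dtype as the input.
assert isinstance(y, jax.Array)
assert y.shape == x.shape and y.dtype == x.dtype

# JAX arrays are immutable, so the input is left unchanged.
assert jnp.array_equal(x, jnp.array([[1, 2, 3], [4, 5, 6]]))

# jnp.flip can be traced by jax.jit; axis=1 is a Python constant here.
reverse_each_row = jax.jit(lambda a: jnp.flip(a, axis=1))
assert jnp.array_equal(reverse_each_row(x), y)
```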

See also

Examples

>>> import jax.numpy as jnp
>>> x1 = jnp.array([[1, 2],
...                 [3, 4]])
>>> jnp.flip(x1)
Array([[4, 3],
       [2, 1]], dtype=int32)
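With the default axis=None, a 2-D array is reversed along both axes. That is the same as slicing with [::-1, ::-1], and also the same as a half-turn rotation with jnp.rot90(..., k=2). The sketch below checks both equivalences on the x1 array from the example above.

```python
import jax.numpy as jnp

x1 = jnp.array([[1, 2],
                [3, 4]])

flipped = jnp.flip(x1)                 # axis=None: reverse both axes
by_slicing = x1[::-1, ::-1]            # reverse rows, then columns
by_rotation = jnp.rot90(x1, k=2)       # two quarter-turns = one half-turn

assert jnp.array_equal(flipped, by_slicing)
assert jnp.array_equal(flipped, by_rotation)
print(flipped)
```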

If axis is an integer, jax.numpy.flip reverses the array along that axis only.

>>> jnp.flip(x1, axis=1)
Array([[2, 1],
       [4, 3]], dtype=int32)
>>> x2 = jnp.arange(1, 9).reshape(2, 2, 2)
>>> x2
Array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]], dtype=int32)
>>> jnp.flip(x2)
Array([[[8, 7],
        [6, 5]],

       [[4, 3],
        [2, 1]]], dtype=int32)
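For 2-D input, the single-axis flips described above coincide with jnp.flipud (axis 0, up/down) and jnp.fliplr (axis 1, left/right). For the 3-D x2, the default flip is the same as slicing [::-1] on all three axes. The sketch below checks these cases using the arrays from the examples.

```python
import jax.numpy as jnp

x1 = jnp.array([[1, 2],
                [3, 4]])

# axis=0 reverses the order of the rows (up/down).
assert jnp.array_equal(jnp.flip(x1, axis=0), jnp.flipud(x1))
assert jnp.array_equal(jnp.flip(x1, axis=0), jnp.array([[3, 4], [1, 2]]))

# axis=1 reverses the elements within each row (left/right).
assert jnp.array_equal(jnp.flip(x1, axis=1), jnp.fliplr(x1))
assert jnp.array_equal(jnp.flip(x1, axis=1), jnp.array([[2, 1], [4, 3]]))

# With axis=None, a 3-D array is reversed along every axis.
x2 = jnp.arange(1, 9).reshape(2, 2, 2)
assert jnp.array_equal(jnp.flip(x2), x2[::-1, ::-1, ::-1])
```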

If axis is a sequence of integers, jax.numpy.flip reverses the array along each of the specified axes.

>>> jnp.flip(x2, axis=[1, 2])
Array([[[4, 3],
        [2, 1]],

       [[8, 7],
        [6, 5]]], dtype=int32)
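Flipping along several axes at once gives the same result as flipping one axis at a time, in any order, because reversals along different axes commute. The sketch below checks this for the x2 example. It also compares against jax.lax.rev, a lower-level function that reverses an operand along the given dimensions. Only the outputs are compared; this is not a claim about how jnp.flip is implemented.

```python
import jax
import jax.numpy as jnp

x2 = jnp.arange(1, 9).reshape(2, 2, 2)

multi = jnp.flip(x2, axis=[1, 2])

# Flipping along several axes equals flipping one axis at a time...
step_by_step = jnp.flip(jnp.flip(x2, axis=1), axis=2)
# ...and the order of the single-axis flips does not matter.
other_order = jnp.flip(jnp.flip(x2, axis=2), axis=1)
# jax.lax.rev reverses the operand along the listed dimensions.
via_lax = jax.lax.rev(x2, (1, 2))

assert jnp.array_equal(multi, step_by_step)
assert jnp.array_equal(multi, other_order)
assert jnp.array_equal(multi, via_lax)

# Axis 0 is untouched: the first block still holds the values 1 to 4, reversed.
assert jnp.array_equal(multi[0], jnp.array([[4, 3], [2, 1]]))
```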