jax.lax.unstack

jax.lax.unstack(x, axis=0)

Unstacks an array along an axis, yielding one array for each index along that axis.

Parameters:
  • x (ArrayLike) – the array to unstack.

  • axis (int) – the axis along which to unstack the array.

Returns:

A tuple of x.shape[axis] arrays. Each element has the shape of x with axis removed.

Return type:

tuple[Array, …]
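To make the return semantics concrete, here is a minimal sketch of an equivalent computation built from lax.index_in_dim. The helper unstack_reference is hypothetical and not part of JAX; it assumes a non-negative axis, and jax.lax.unstack itself may be implemented differently.

```python
import jax.numpy as jnp
from jax import lax


def unstack_reference(x, axis=0):
  """Hypothetical helper illustrating lax.unstack semantics.

  Not JAX's implementation. Assumes a non-negative `axis`.
  Takes each slice along `axis` and drops that axis, so the result
  holds x.shape[axis] arrays of rank x.ndim - 1.
  """
  x = jnp.asarray(x)
  # keepdims=False removes the indexed axis from each slice.
  return tuple(
      lax.index_in_dim(x, i, axis=axis, keepdims=False)
      for i in range(x.shape[axis])
  )


x = jnp.arange(24).reshape(2, 3, 4)
parts = unstack_reference(x, axis=1)
print(len(parts))      # 3
print(parts[0].shape)  # (2, 4)
```

Each element parts[i] equals x[:, i, :], which matches the description of the returned tuple above.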

Examples

>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([[1, 2], [3, 4]])
>>> lax.unstack(x, axis=0)
(Array([1, 2], dtype=int32), Array([3, 4], dtype=int32))
>>> lax.unstack(x, axis=1)
(Array([1, 3], dtype=int32), Array([2, 4], dtype=int32))