jax.lax.unstack
- jax.lax.unstack(x, axis=0)
Unstacks an array along an axis.
- Parameters:
x (ArrayLike) – the array to unstack.
axis (int) – the axis along which to unstack the array.
- Returns:
A tuple of x.shape[axis] arrays, one per index along axis, each with that axis removed.
- Return type:
tuple[Array, ...]
Examples
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([[1, 2], [3, 4]])
>>> lax.unstack(x, axis=0)
(Array([1, 2], dtype=int32), Array([3, 4], dtype=int32))
>>> lax.unstack(x, axis=1)
(Array([1, 3], dtype=int32), Array([2, 4], dtype=int32))
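The behavior described above can be sketched with a small reference helper. This is an illustration only, not the library's implementation: `unstack_reference` is a hypothetical name, and it assumes that unstacking is equivalent to taking one slice per index along `axis` and dropping that axis. The comparison against `lax.unstack` is guarded because older JAX releases may not provide it.

```python
import jax.numpy as jnp
from jax import lax


def unstack_reference(x, axis=0):
    """Hypothetical helper (not part of JAX): one slice per index along `axis`."""
    x = jnp.asarray(x)
    return tuple(
        # keepdims=False drops the indexed axis from each slice.
        lax.index_in_dim(x, i, axis=axis, keepdims=False)
        for i in range(x.shape[axis])
    )


x = jnp.arange(24).reshape(2, 3, 4)

# Unstacking along axis 1 yields 3 arrays, each of shape (2, 4).
parts = unstack_reference(x, axis=1)

# Stacking the pieces back along the same axis recovers the input.
restacked = jnp.stack(parts, axis=1)

# Cross-check against lax.unstack when the installed JAX provides it.
if hasattr(lax, "unstack"):
    for ours, theirs in zip(parts, lax.unstack(x, axis=1)):
        assert jnp.array_equal(ours, theirs)
```

Under that assumption, `jnp.stack` along the same axis acts as the inverse of unstacking, which is a convenient sanity check when refactoring code that splits and recombines arrays.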