jax.lax.stack
- jax.lax.stack(operands, axis=0)
Joins a sequence of arrays along a new axis.
- Parameters:
operands (Sequence[ArrayLike]) – a sequence of arrays to stack. All arrays must have the same shape.
axis (int) – the position of the new axis in the result, i.e. the axis along which the operands are stacked.
- Returns:
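An array containing the stacked operands. It has one more dimension than each operand, and its size along axis equals the number of operands.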
- Return type:
Array
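The shape rule implied above can be written out explicitly: stacking n operands of shape s at axis k yields shape s[:k] + (n,) + s[k:]. The sketch below checks that rule against jnp.stack, assuming lax.stack follows the same shape semantics (as the examples on this page suggest). The helper stacked_shape is hypothetical and only handles non-negative axes.

```python
import jax.numpy as jnp


def stacked_shape(shape, num_operands, axis):
    """Hypothetical helper: predicted shape after stacking `num_operands`
    arrays of `shape` along a new axis at position `axis`."""
    # The new axis may be inserted anywhere from 0 to len(shape) inclusive.
    if not 0 <= axis <= len(shape):
        raise ValueError(f"axis {axis} out of range for {len(shape)}-d operands")
    return tuple(shape[:axis]) + (num_operands,) + tuple(shape[axis:])


a = jnp.zeros((4, 3))
b = jnp.ones((4, 3))

for axis in range(a.ndim + 1):
    out = jnp.stack([a, b], axis=axis)
    # The new axis of size 2 (the number of operands) appears at `axis`.
    assert out.shape == stacked_shape(a.shape, 2, axis)
    print(axis, out.shape)
```

With two operands of shape (4, 3), the new axis of size 2 lands at positions 0, 1 and 2, giving shapes (2, 4, 3), (4, 2, 3) and (4, 3, 2).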
Examples
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([1, 2])
>>> y = jnp.array([3, 4])
>>> lax.stack([x, y], axis=0)
Array([[1, 2],
       [3, 4]], dtype=int32)
>>> lax.stack([x, y], axis=1)
Array([[1, 3],
       [2, 4]], dtype=int32)
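One way to see what "joins along a new axis" means is to insert a size-1 axis into every operand and then concatenate along that axis. The sketch below builds this from lax.expand_dims and lax.concatenate. It is an illustration of the semantics, not a description of how lax.stack is actually implemented; the helper stack_via_concat is hypothetical and only handles non-negative axes.

```python
import jax.numpy as jnp
from jax import lax


def stack_via_concat(operands, axis=0):
    """Hypothetical helper: stack by giving each operand a size-1 axis at
    `axis`, then concatenating the results along that axis."""
    # Each operand of shape s becomes shape s[:axis] + (1,) + s[axis:].
    expanded = [lax.expand_dims(jnp.asarray(op), (axis,)) for op in operands]
    # Concatenating n such arrays along `axis` grows that axis to size n.
    return lax.concatenate(expanded, axis)


x = jnp.array([1, 2])
y = jnp.array([3, 4])

print(stack_via_concat([x, y], axis=0))
print(stack_via_concat([x, y], axis=1))
# The result should agree with jnp.stack for every valid axis.
for axis in range(x.ndim + 1):
    assert jnp.array_equal(stack_via_concat([x, y], axis), jnp.stack([x, y], axis))
```

This reproduces the values in the examples above: axis=0 gives [[1, 2], [3, 4]] and axis=1 gives [[1, 3], [2, 4]].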