jax.lax.stack

jax.lax.stack(operands, axis=0)

Joins a sequence of arrays along a new axis.

Parameters:
  • operands (Sequence[ArrayLike]) – a sequence of arrays to stack. All arrays must have the same shape.

  • axis (int) – the position of the new axis in the result. Defaults to 0.

Returns:

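An array containing the stacked operands. If there are n operands, each of shape S, the result has the shape of S with a new dimension of size n inserted at position axis.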

Return type:

Array
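The shape rule above can be checked with a short sketch. It uses jax.numpy.stack as a stand-in, assuming it follows the same new-axis convention as lax.stack; the operand shape (2, 4) and the `shapes` dictionary are illustrative choices, not part of the API.

```python
import jax.numpy as jnp

# Three operands, each of shape (2, 4).
a = jnp.zeros((2, 4))
b = jnp.ones((2, 4))
c = jnp.full((2, 4), 2.0)

# A new dimension of size 3 (one entry per operand) is inserted at `axis`.
# jnp.stack is used here as a stand-in for lax.stack (assumed equivalent).
shapes = {}
for axis in range(3):
    shapes[axis] = jnp.stack([a, b, c], axis=axis).shape

for axis, shape in shapes.items():
    print(axis, shape)
```

With axis=0 the result has shape (3, 2, 4), with axis=1 it is (2, 3, 4), and with axis=2 it is (2, 4, 3).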

Examples

>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.array([1, 2])
>>> y = jnp.array([3, 4])
>>> lax.stack([x, y], axis=0)
Array([[1, 2],
       [3, 4]], dtype=int32)
>>> lax.stack([x, y], axis=1)
Array([[1, 3],
       [2, 4]], dtype=int32)
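Stacking can also be understood as two lower-level steps: give every operand a size-1 axis at the target position, then concatenate along that axis. The sketch below expresses this with lax.expand_dims and lax.concatenate. The helper name `stack_via_concat` is hypothetical, and it assumes a non-negative axis and operands that share both shape and dtype; it illustrates the semantics and is not a description of how lax.stack is implemented.

```python
import jax.numpy as jnp
from jax import lax


def stack_via_concat(operands, axis=0):
    """Hypothetical helper: stack same-shape arrays along a new axis.

    Assumes `axis` is non-negative and all operands share shape and dtype.
    """
    # Insert a size-1 dimension at position `axis` in every operand.
    expanded = [lax.expand_dims(op, (axis,)) for op in operands]
    # Join the operands along that new dimension.
    return lax.concatenate(expanded, axis)


x = jnp.array([1, 2])
y = jnp.array([3, 4])
rows = stack_via_concat([x, y], axis=0)  # operands become rows
cols = stack_via_concat([x, y], axis=1)  # operands become columns
print(rows)
print(cols)
```

The two results hold the same values as the axis=0 and axis=1 examples above.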