jax.numpy.stack

jax.numpy.stack(arrays, axis=0, out=None, dtype=None)

Join arrays along a new axis.

JAX implementation of numpy.stack().
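Conceptually, stacking along a new axis is the same as inserting a length-1 axis into every input and then concatenating the inputs along that axis. The sketch below checks this equivalence using jnp.expand_dims and jnp.concatenate. stack_via_concat is a hypothetical helper for illustration. It is not meant to describe how JAX itself implements stack.

```python
import jax.numpy as jnp

def stack_via_concat(arrays, axis=0):
    # Hypothetical helper: add a length-1 axis to each input at `axis`,
    # then join the inputs along that new axis.
    expanded = [jnp.expand_dims(a, axis) for a in arrays]
    return jnp.concatenate(expanded, axis=axis)

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

# The helper agrees with jnp.stack for positive and negative axes.
for axis in (0, 1, -1):
    assert jnp.array_equal(stack_via_concat([x, y], axis), jnp.stack([x, y], axis=axis))
```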

Parameters:
  • arrays (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; all must have the same shape. If a single array is given, it is treated as equivalent to arrays = unstack(arrays), i.e. its slices along axis 0 are stacked, but the implementation avoids explicit unstacking.

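  • axis (int) – the axis along which to stack (default 0). This is the position of the new axis in the result; negative values count from the end of the result's dimensions.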

  • out (None) – unused by JAX.

  • dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype is determined by the type promotion rules described in Type promotion semantics.
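The following sketch illustrates two of the parameters above: passing a single array instead of a sequence, and the dtype argument. The variable names are illustrative only.

```python
import jax.numpy as jnp

m = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# A single array is treated like the tuple of its slices along axis 0,
# i.e. like passing jnp.unstack(m).
from_array = jnp.stack(m, axis=1)
from_slices = jnp.stack(jnp.unstack(m), axis=1)
assert jnp.array_equal(from_array, from_slices)
print(from_array.shape)  # (3, 2)

# Without dtype, int32 and float32 inputs promote to float32.
a = jnp.array([1, 2], dtype=jnp.int32)
b = jnp.array([0.5, 1.5], dtype=jnp.float32)
print(jnp.stack([a, b]).dtype)  # float32

# An explicit dtype sets the result type directly.
print(jnp.stack([a, a], dtype=jnp.float32).dtype)  # float32
```

Here stacking the two rows of m along axis 1 places them side by side as columns, so the result equals the transpose of m.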

Returns:

the stacked result, which has one more dimension than each input array.

Return type:

Array
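A short sketch of the result's shape: the output gains one axis of length len(arrays), placed at position axis. For inputs with ndim dimensions, valid values of axis therefore run from -(ndim + 1) to ndim. The zero-filled inputs are placeholders.

```python
import jax.numpy as jnp

# Three inputs, each of shape (2, 4).
arrays = [jnp.zeros((2, 4)) for _ in range(3)]

# The new axis has length 3 (the number of inputs) and sits at `axis`.
for axis in (0, 1, 2, -1):
    print(axis, jnp.stack(arrays, axis=axis).shape)
# 0 (3, 2, 4)
# 1 (2, 3, 4)
# 2 (2, 4, 3)
# -1 (2, 4, 3)
```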

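See also

  • jax.numpy.unstack() – inverse of stack().

  • jax.numpy.concatenate() – join arrays along an existing axis.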

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
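Because every input must have the same shape, mixing lengths fails. When the goal is to extend an existing axis rather than add a new one, jnp.concatenate is the usual choice. The sketch assumes the shape check surfaces as a ValueError.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
z = jnp.array([7, 8])  # length 2, not 3

# stack requires identical shapes, so this call raises.
try:
    jnp.stack([x, z])
    raised = False
except ValueError:
    raised = True
print(raised)  # True

# concatenate joins along the existing axis and accepts differing lengths there.
print(jnp.concatenate([x, z]))  # [1 2 3 7 8]
```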

unstack() performs the inverse operation:

>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)
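The inverse relationship also holds inside compiled code and for any axis, including negative ones. In this sketch the roundtrip helper is hypothetical. axis is marked static with static_argnames because it must be a concrete Python int that determines the shapes of the intermediate slices.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="axis")
def roundtrip(arr, axis):
    # Split `arr` into slices along `axis`, then restack them at the same position.
    parts = jnp.unstack(arr, axis=axis)
    return jnp.stack(parts, axis=axis)

arr = jnp.arange(24).reshape(2, 3, 4)
for axis in (0, 1, 2, -1):
    assert jnp.array_equal(roundtrip(arr, axis=axis), arr)
```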