jax.numpy.stack
- jax.numpy.stack(arrays, axis=0, out=None, dtype=None)
Join arrays along a new axis.
JAX implementation of numpy.stack().
- Parameters:
arrays (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; each must have the same shape. If a single array is given, it is treated as equivalent to arrays = unstack(arrays), but the implementation avoids explicit unstacking.
axis (int) – the axis of the result along which the input arrays are stacked. Defaults to 0; negative values count from the end of the result.
out (None) – unused by JAX.
dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype is determined by the type promotion rules described in Type promotion semantics.
- Returns:
the stacked result.
- Return type:
Array
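The parameters above can be illustrated with a short sketch. It assumes a JAX version that provides `jnp.unstack` (referenced in the See also list) and uses only public `jax.numpy` functions. It shows that stacking k arrays inserts a new axis of length k at position `axis`, that the result matches giving each input a length-1 axis and concatenating, that a single array input behaves like its unstacked slices, and that `dtype` overrides the promoted type. The helper `stack_via_concatenate` is hypothetical and exists only for comparison; it is not how JAX is claimed to implement `stack`.

```python
import jax.numpy as jnp


def stack_via_concatenate(arrays, axis=0):
    """Hypothetical helper for illustration only: give each input a new
    length-1 axis at `axis`, then concatenate along that new axis."""
    return jnp.concatenate([jnp.expand_dims(a, axis) for a in arrays], axis=axis)


# Three inputs of shape (2, 4), so the new axis has length 3.
parts = [jnp.full((2, 4), i) for i in range(3)]

for axis in (0, 1, 2, -1):
    # For every axis, stack agrees with the expand_dims + concatenate construction.
    assert jnp.array_equal(jnp.stack(parts, axis=axis), stack_via_concatenate(parts, axis))

print(jnp.stack(parts, axis=0).shape)   # (3, 2, 4)
print(jnp.stack(parts, axis=1).shape)   # (2, 3, 4)
print(jnp.stack(parts, axis=-1).shape)  # (2, 4, 3): -1 is the last axis of the result

# A single array is treated like the sequence of its slices along axis 0.
arr = jnp.arange(6).reshape(2, 3)
from_array = jnp.stack(arr, axis=1)
from_pieces = jnp.stack(jnp.unstack(arr), axis=1)
print(jnp.array_equal(from_array, from_pieces))  # True
print(jnp.array_equal(from_array, arr.T))        # True

# `dtype` overrides the type that promotion would otherwise produce.
ints = jnp.array([1, 2])
floats = jnp.array([0.5, 1.5])
print(jnp.stack([ints, floats]).dtype)                   # a floating dtype (float32 unless 64-bit mode is on)
print(jnp.stack([ints, ints], dtype=jnp.float32).dtype)  # float32
```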
See also
jax.numpy.unstack(): inverse of stack.
jax.numpy.concatenate(): concatenation along existing axes.
jax.numpy.vstack(): stack vertically, i.e. along axis 0.
jax.numpy.hstack(): stack horizontally, i.e. along axis 1.
jax.numpy.dstack(): stack depth-wise, i.e. along axis 2.
jax.numpy.column_stack(): stack columns.
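For one-dimensional inputs, several of the related functions listed above can be compared directly with `stack`. The sketch below assumes two length-3 integer vectors; the equivalences it checks are stated for this 1-D case only and are not claimed for arbitrary input ranks.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

# concatenate joins along an existing axis, so no new axis appears.
print(jnp.concatenate([x, y]).shape)    # (6,)

# stack adds a new axis of length 2.
print(jnp.stack([x, y]).shape)          # (2, 3)
print(jnp.stack([x, y], axis=1).shape)  # (3, 2)

# For 1-D inputs, vstack matches stack along axis 0 ...
print(jnp.array_equal(jnp.vstack([x, y]), jnp.stack([x, y], axis=0)))        # True
# ... and column_stack matches stack along axis 1.
print(jnp.array_equal(jnp.column_stack([x, y]), jnp.stack([x, y], axis=1)))  # True
# hstack joins 1-D inputs end to end instead of adding an axis.
print(jnp.array_equal(jnp.hstack([x, y]), jnp.concatenate([x, y])))          # True
# dstack first promotes each 1-D input to shape (1, 3, 1), then joins on axis 2.
print(jnp.dstack([x, y]).shape)         # (1, 3, 2)
```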
Examples
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3])
>>> y = jnp.array([4, 5, 6])
>>> jnp.stack([x, y])
Array([[1, 2, 3],
       [4, 5, 6]], dtype=int32)
>>> jnp.stack([x, y], axis=1)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
unstack() performs the inverse operation:

>>> arr = jnp.stack([x, y], axis=1)
>>> x, y = jnp.unstack(arr, axis=1)
>>> x
Array([1, 2, 3], dtype=int32)
>>> y
Array([4, 5, 6], dtype=int32)
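The inverse relationship between stack() and unstack() is not specific to axis=1. As a sketch, assuming a JAX version that provides `jnp.unstack` with a keyword `axis` argument, the loop below checks that unstacking along the same axis recovers every input, in order, for each valid axis of a 3-D result, including negative values.

```python
import jax.numpy as jnp

# Three distinct inputs of shape (2, 4), so the stacked result is 3-D.
inputs = [jnp.arange(8).reshape(2, 4) + 10 * i for i in range(3)]

for axis in range(-3, 3):
    stacked = jnp.stack(inputs, axis=axis)
    recovered = jnp.unstack(stacked, axis=axis)
    # unstack returns one array per entry along `axis`, in the original order.
    assert len(recovered) == len(inputs)
    assert all(jnp.array_equal(r, a) for r, a in zip(recovered, inputs))

print("round trip holds for axes -3 through 2")
```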