jax.numpy.dstack

jax.numpy.dstack(tup, dtype=None)

Stack arrays depth-wise.

JAX implementation of numpy.dstack().

For arrays of three or more dimensions, this is equivalent to jax.numpy.concatenate() with axis=2.
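The equivalence can be checked directly. The sketch below compares the two calls for rank-3 inputs. For lower-rank inputs it reproduces the rank-3 promotion with jax.numpy.atleast_3d, which is an assumption about how the promotion is carried out (it matches the promotion rule described under Parameters). The array names are illustrative only.

```python
import jax.numpy as jnp

# Two 3D arrays that agree on every axis except the third.
a = jnp.ones((2, 3, 1))
b = jnp.zeros((2, 3, 4))

# For rank >= 3 inputs, dstack is a concatenation along axis 2.
stacked = jnp.dstack([a, b])
concatenated = jnp.concatenate([a, b], axis=2)
print(stacked.shape)                                   # (2, 3, 5)
print(bool(jnp.array_equal(stacked, concatenated)))    # True

# Lower-rank inputs are promoted to rank 3 first; atleast_3d reproduces that
# step here (a 1D array of length n becomes shape (1, n, 1)).
x = jnp.arange(3)
y = jnp.arange(3, 6)
manual = jnp.concatenate([jnp.atleast_3d(x), jnp.atleast_3d(y)], axis=2)
print(manual.shape)                                    # (1, 3, 2)
print(bool(jnp.array_equal(jnp.dstack([x, y]), manual)))  # True
```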

Parameters:
  • tup (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to stack; each must have the same shape along all but the third axis. Input arrays will be promoted to at least rank 3. If a single array is given, it is treated as equivalent to tup = unstack(tup) (that is, split along its first axis), but the implementation avoids explicit unstacking.

  • dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype will be determined via the type promotion rules described in Type promotion semantics.
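A short sketch of the two parameters in action, assuming a recent JAX release. The list comprehension over rows stands in for unstack along axis 0, so the example does not depend on any particular unstack helper; the array names are hypothetical.

```python
import jax.numpy as jnp

arr = jnp.arange(6).reshape(2, 3)

# Passing one array: dstack treats it as the sequence of its slices along
# axis 0, here the two rows [0, 1, 2] and [3, 4, 5].
from_array = jnp.dstack(arr)
from_rows = jnp.dstack([arr[i] for i in range(arr.shape[0])])
print(from_array.shape)                               # (1, 3, 2)
print(bool(jnp.array_equal(from_array, from_rows)))   # True

# Each row of shape (3,) is promoted to rank 3 as (1, 3, 1) before stacking,
# so the second depth slice holds the second row.
print(jnp.atleast_3d(arr[0]).shape)                   # (1, 3, 1)
print(from_array[0, :, 1].tolist())                   # [3, 4, 5]

# An explicit dtype overrides the result of type promotion.
as_float = jnp.dstack([jnp.arange(3), jnp.arange(3)], dtype=jnp.float32)
print(as_float.dtype)                                 # float32
```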

Returns:

the stacked result.

Return type:

Array

See also

Examples

Scalar values:

>>> jnp.dstack([1, 2, 3])
Array([[[1, 2, 3]]], dtype=int32, weak_type=True)

1D arrays:

>>> x = jnp.arange(3)
>>> y = jnp.ones(3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)

2D arrays:

>>> x = x.reshape(1, 3)
>>> y = y.reshape(1, 3)
>>> jnp.dstack([x, y])
Array([[[0., 1.],
        [1., 1.],
        [2., 1.]]], dtype=float32)
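Arrays of different rank can also be combined, provided they match on every axis except the third after promotion. In this sketch (array names are illustrative), a 2D array of shape (2, 3) is promoted to (2, 3, 1) and stacked with a (2, 3, 2) array.

```python
import jax.numpy as jnp

flat = jnp.zeros((2, 3))     # promoted to shape (2, 3, 1)
deep = jnp.ones((2, 3, 2))   # already rank 3
out = jnp.dstack([flat, deep])
print(out.shape)                          # (2, 3, 3)

# The first depth slice comes from `flat`, the remaining two from `deep`.
print(bool(jnp.all(out[..., 0] == 0)))    # True
print(bool(jnp.all(out[..., 1:] == 1)))   # True
```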