jax.numpy.concatenate

jax.numpy.concatenate(arrays, axis=0, dtype=None)

Join arrays along an existing axis.

JAX implementation of numpy.concatenate().

Parameters:
  • arrays (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to concatenate; all must have the same shape except along the specified axis. If a single array is given, it is treated as equivalent to arrays = unstack(arrays), but the implementation avoids explicit unstacking.

  • axis (int | None) – the axis along which to concatenate. Defaults to 0. If None, the arrays are flattened before concatenation.

  • dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype is determined by the type promotion rules described in Type promotion semantics.
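
The parameters above can be illustrated with a minimal sketch. The variable names are illustrative only, and the exact default dtypes depend on whether 64-bit mode is enabled in your JAX configuration:

```python
import jax.numpy as jnp

x = jnp.arange(3)            # integer array
y = jnp.array([0.5, 1.5])    # floating-point array

# Without dtype, the result dtype follows JAX type promotion
# (integer combined with float yields a floating-point result).
promoted = jnp.concatenate([x, y])

# With dtype, the result has the requested type.
as_float = jnp.concatenate([x, x], dtype=jnp.float32)

# A single 2D array is treated as the sequence of its rows,
# so the rows are joined end to end along axis 0.
rows = jnp.arange(6).reshape(2, 3)
joined = jnp.concatenate(rows)   # shape (6,)
```

Passing dtype explicitly is mainly useful when the promoted type is not the one you want downstream.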

Returns:

the concatenated result.

Return type:

Array

Examples

One-dimensional concatenation:

>>> x = jnp.arange(3)
>>> y = jnp.zeros(3, dtype=int)
>>> jnp.concatenate([x, y])
Array([0, 1, 2, 0, 0, 0], dtype=int32)

Two-dimensional concatenation:

>>> x = jnp.ones((2, 3))
>>> y = jnp.zeros((2, 1))
>>> jnp.concatenate([x, y], axis=1)
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.]], dtype=float32)
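
Flattened concatenation with axis=None can be sketched as follows. This reuses the shapes from the two-dimensional example; because the inputs are flattened first, their shapes do not need to agree:

```python
import jax.numpy as jnp

x = jnp.ones((2, 3))
y = jnp.zeros((2, 1))

# axis=None flattens each input before joining, so the result is 1D
# with x.size + y.size elements.
flat = jnp.concatenate([x, y], axis=None)   # shape (8,)
```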