jax.numpy.concatenate
- jax.numpy.concatenate(arrays, axis=0, dtype=None)
Join arrays along an existing axis.
JAX implementation of numpy.concatenate().
- Parameters:
arrays (np.ndarray | Array | Sequence[ArrayLike]) – a sequence of arrays to concatenate; each must have the same shape except along the specified axis. If a single array is given it will be treated equivalently to arrays = unstack(arrays), but the implementation will avoid explicit unstacking.
axis (int | None) – the axis along which to concatenate. Defaults to 0.
dtype (DTypeLike | None) – optional dtype of the resulting array. If not specified, the dtype is determined by the type promotion rules described in Type promotion semantics.
- Returns:
the concatenated result.
- Return type:
Array
See also
jax.lax.concatenate(): XLA concatenation API.
jax.numpy.concat(): Array API version of this function.
jax.numpy.stack(): concatenate arrays along a new axis.
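The difference between joining along an existing axis and joining along a new one is easiest to see in the output shapes. The following is a minimal sketch; the variable names are illustrative only, and it assumes a JAX version that provides jax.numpy.concat().

```python
import jax.numpy as jnp

x = jnp.arange(3)       # shape (3,)
y = jnp.arange(3, 6)    # shape (3,)

# concatenate joins along an existing axis: the length of axis 0 grows.
joined = jnp.concatenate([x, y])

# stack joins along a new axis: the result gains a dimension.
stacked = jnp.stack([x, y])

# concat is the Array API spelling of the same operation as concatenate.
joined_api = jnp.concat([x, y])

print(joined.shape, stacked.shape, joined_api.shape)
```

Here concatenate and concat produce a one-dimensional result of length 6, while stack produces a two-dimensional result with the inputs as rows.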
Examples
One-dimensional concatenation:
>>> x = jnp.arange(3)
>>> y = jnp.zeros(3, dtype=int)
>>> jnp.concatenate([x, y])
Array([0, 1, 2, 0, 0, 0], dtype=int32)
Two-dimensional concatenation:
>>> x = jnp.ones((2, 3))
>>> y = jnp.zeros((2, 1))
>>> jnp.concatenate([x, y], axis=1)
Array([[1., 1., 1., 0.],
       [1., 1., 1., 0.]], dtype=float32)
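The dtype parameter and the single-array form of arrays described under Parameters can be sketched in the same way. This is an illustrative example rather than part of the original documentation; the variable names are assumptions chosen for the sketch.

```python
import jax.numpy as jnp

ints = jnp.arange(3)                      # integer array
floats = jnp.ones(3, dtype=jnp.float32)   # float32 array

# Without dtype, the result dtype follows JAX type promotion:
# an integer array combined with float32 promotes to float32.
promoted = jnp.concatenate([ints, floats])

# With dtype, the result is produced in the requested dtype.
explicit = jnp.concatenate([ints, ints], dtype=jnp.float32)

# A single 2D array is treated like the sequence of its rows,
# so its rows are joined end to end along axis 0.
grid = jnp.arange(6).reshape(2, 3)
flat = jnp.concatenate(grid)

print(promoted.dtype, explicit.dtype, flat.shape)
```

Passing dtype explicitly may be useful when the inputs share an integer type but a floating-point result is needed downstream.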