jax.numpy.tile
- jax.numpy.tile(A, reps)
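Construct an array by repeating A along specified dimensions.

JAX implementation of numpy.tile().

If A is an array of shape (d1, d2, ..., dn) and reps is a sequence of n integers, the resulting array has shape (reps[0] * d1, reps[1] * d2, ..., reps[n-1] * dn), with A tiled along each dimension. An integer reps is treated as (reps,). If len(reps) differs from A.ndim, whichever of A.shape and reps is shorter is first padded on the left with 1s, so a reps longer than A.ndim adds new leading axes.
- Parameters: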
A (ArrayLike) – input array to be repeated. Can be of any shape or dimension.
reps (DimSize | Sequence[DimSize]) – an integer or a sequence of integers giving the number of repetitions along each axis.
- Returns:
a new array where the input array has been repeated according to reps.
- Return type:
Array
See also
jax.numpy.repeat(): Construct an array from repeated elements.
jax.numpy.broadcast_to(): Broadcast an array to a specified shape.
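The difference between tile and the related functions is easiest to see on the same input. The sketch below applies all three to a small 1-D array (the array x and the repetition count 2 are arbitrary choices for illustration): tile repeats the whole array as a block, repeat duplicates each element in place, and broadcast_to expands the array to a target shape. For 1-D input, tiling with reps=(2, 1) gives the same values as broadcasting to shape (2, 3).

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# tile repeats the whole array as a block: [1, 2, 3, 1, 2, 3]
tiled = jnp.tile(x, 2)

# repeat duplicates each element in place: [1, 1, 2, 2, 3, 3]
repeated = jnp.repeat(x, 2)

# broadcast_to expands x to a (2, 3) shape: [[1, 2, 3], [1, 2, 3]]
broadcast = jnp.broadcast_to(x, (2, 3))

# tile with reps=(2, 1) adds a leading axis of length 2 and matches the broadcast
stacked = jnp.tile(x, (2, 1))

print(tiled)
print(repeated)
print(broadcast)
print(stacked)
```

When the goal is just to view an array at a larger compatible shape, broadcast_to is the more direct choice; tile is needed when the input must be repeated along an axis it already has.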
Examples
>>> import jax.numpy as jnp
>>> arr = jnp.array([1, 2])
>>> jnp.tile(arr, 2)
Array([1, 2, 1, 2], dtype=int32)
>>> arr = jnp.array([[1, 2],
...                  [3, 4]])
>>> jnp.tile(arr, (2, 1))
Array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)
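The output-shape rule can be checked with a small sketch. It assumes numpy.tile semantics for mismatched lengths: an integer reps is treated as (reps,), and whichever of A.shape and reps is shorter is left-padded with 1s before the elementwise product. The helper expected_tile_shape is hypothetical (it is not part of JAX); the loop compares its prediction with jnp.tile and compares the values against NumPy's own np.tile.

```python
import numpy as np
import jax.numpy as jnp


def expected_tile_shape(shape, reps):
    """Hypothetical helper: predict the shape of tile(A, reps) from A.shape.

    An int reps is treated as (reps,); the shorter of shape/reps is
    left-padded with 1s, then the two are multiplied elementwise.
    """
    reps = (reps,) if isinstance(reps, int) else tuple(reps)
    n = max(len(shape), len(reps))
    shape = (1,) * (n - len(shape)) + tuple(shape)
    reps = (1,) * (n - len(reps)) + reps
    return tuple(r * d for r, d in zip(reps, shape))


a = jnp.arange(6).reshape(2, 3)

# Expected shapes: 2 -> (2, 6), (2, 1) -> (4, 3),
# (3, 2) -> (6, 6), (2, 1, 1) -> (2, 2, 3) (a new leading axis)
for reps in [2, (2, 1), (3, 2), (2, 1, 1)]:
    out = jnp.tile(a, reps)
    assert out.shape == expected_tile_shape(a.shape, reps)
    # Values should agree with NumPy's reference implementation
    np.testing.assert_array_equal(np.asarray(out), np.tile(np.asarray(a), reps))
    print(reps, out.shape)
```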