jax.numpy.tile

jax.numpy.tile(A, reps)

Construct an array by repeating A along specified dimensions.

JAX implementation of numpy.tile().
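Because the function mirrors numpy.tile(), a quick sanity check on a given input is to compare the two directly. This is a minimal sketch, assuming NumPy is installed alongside JAX. Values are compared rather than dtypes, because JAX defaults to 32-bit integer types unless 64-bit mode is enabled.

```python
import numpy as np
import jax.numpy as jnp

a = np.arange(6).reshape(2, 3)

# Pair each reps value with the NumPy and JAX results for the same input.
cases = {
    reps: (np.tile(a, reps), jnp.tile(a, reps))
    for reps in [2, (2, 1), (3, 1, 2), (1, 1, 1, 2)]
}

for reps, (expected, result) in cases.items():
    # Shapes and values should agree; dtypes may differ (int64 vs int32).
    print(reps, expected.shape, result.shape, np.array_equal(np.asarray(result), expected))
```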

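If A is an array of shape (d1, d2, ..., dn) and reps is a sequence of n integers, the resulting array will have shape (reps[0] * d1, reps[1] * d2, ..., reps[n-1] * dn), with A tiled along each dimension. If reps has fewer than n entries, it is padded with leading 1s; if it has more, A is promoted to that many dimensions by prepending axes of length 1.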
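One way to see why the shape rule holds: tiling is equivalent to inserting a length-1 axis before each axis of A, broadcasting those new axes to the repetition counts, and then merging each (reps[i], d_i) pair into one axis. The function tile_sketch below is a hypothetical illustration of that idea built from reshape and jnp.broadcast_to. It is not presented as JAX's actual source, and it assumes reps is a plain Python int or a tuple/list of ints.

```python
import jax.numpy as jnp


def tile_sketch(a, reps):
    """Hypothetical re-implementation of tile, for illustration only."""
    a = jnp.asarray(a)
    # Treat a bare Python int as a one-element sequence (NumPy integer scalars are not handled).
    reps = (reps,) if isinstance(reps, int) else tuple(reps)
    # Pad the shorter of a.shape / reps with leading 1s so both have the same length.
    ndim = max(a.ndim, len(reps))
    shape = (1,) * (ndim - a.ndim) + tuple(a.shape)
    reps = (1,) * (ndim - len(reps)) + reps
    # Insert a length-1 axis before every axis: (1, d1, 1, d2, ..., 1, dn).
    interleaved = tuple(s for d in shape for s in (1, d))
    # Broadcast each inserted axis to its repetition count: (r1, d1, r2, d2, ..., rn, dn).
    target = tuple(s for r, d in zip(reps, shape) for s in (r, d))
    tiled = jnp.broadcast_to(a.reshape(interleaved), target)
    # Merge each (r_i, d_i) pair into a single axis of length r_i * d_i.
    return tiled.reshape(tuple(r * d for r, d in zip(reps, shape)))
```

For example, tile_sketch(jnp.array([1, 2]), 2) broadcasts the (1, 2) intermediate to (2, 2) and flattens it to [1, 2, 1, 2], matching jnp.tile.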

Parameters:
  • A (ArrayLike) – input array to be repeated. Can be of any shape or dimension.

  • reps (DimSize | Sequence[DimSize]) – the number of repetitions along each axis. A single integer is treated as a one-element sequence.
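The padding behaviour of reps is easiest to see from output shapes. This snippet uses only jnp.tile and jnp.ones; the array a is an arbitrary illustration.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3))

# An integer is treated as a one-element sequence, then padded to (1, 2).
int_reps = jnp.tile(a, 2)
# A full-length sequence repeats each axis independently.
seq_reps = jnp.tile(a, (2, 2))
# More entries than a.ndim: a is treated as if it had shape (1, 2, 3).
longer = jnp.tile(a, (2, 1, 1))

print(int_reps.shape, seq_reps.shape, longer.shape)  # (2, 6) (4, 6) (2, 2, 3)
```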

Returns:

a new array where the input array has been repeated according to reps.

Return type:

Array
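Because reps determines the shape of the returned array, it needs to be a concrete Python value when tile is called inside jax.jit. This is a minimal sketch, assuming reps is passed as a hashable tuple and marked static with static_argnums; tile_jit is a hypothetical name.

```python
from functools import partial

import jax
import jax.numpy as jnp


# reps fixes the output shape, so it is declared static; it must be hashable
# (a tuple, not a list).
@partial(jax.jit, static_argnums=1)
def tile_jit(a, reps):
    return jnp.tile(a, reps)


x = jnp.arange(3)
y = tile_jit(x, (2, 2))  # shape (2, 6)
```

Each distinct reps value triggers a fresh trace and compilation, which is the usual cost of static arguments.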

See also

Examples

>>> import jax.numpy as jnp
>>> arr = jnp.array([1, 2])
>>> jnp.tile(arr, 2)
Array([1, 2, 1, 2], dtype=int32)
>>> arr = jnp.array([[1, 2],
...                  [3, 4]])
>>> jnp.tile(arr, (2, 1))
Array([[1, 2],
       [3, 4],
       [1, 2],
       [3, 4]], dtype=int32)
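tile repeats the whole array as a block, which is easy to confuse with jnp.repeat, which repeats individual elements. This short comparison uses only those two functions on small illustrative inputs.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
tiled = jnp.tile(x, 2)        # whole array repeated: [1, 2, 3, 1, 2, 3]
repeated = jnp.repeat(x, 2)   # each element repeated: [1, 1, 2, 2, 3, 3]

m = jnp.array([[1, 2],
               [3, 4]])
# Along columns: tile copies the block side by side, repeat duplicates each column.
tiled_cols = jnp.tile(m, (1, 2))             # [[1, 2, 1, 2], [3, 4, 3, 4]]
repeated_cols = jnp.repeat(m, 2, axis=1)     # [[1, 1, 2, 2], [3, 3, 4, 4]]
```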