jax.lax.pad
- jax.lax.pad(operand, padding_value, padding_config)
Applies low, high, and/or interior padding to an array.
Wraps XLA’s Pad operator.
- Parameters:
operand (ArrayLike) – an array to be padded.
padding_value (ArrayLike) – the value to be inserted as padding. Must have the same dtype as operand.
padding_config (Sequence[tuple[int, int, int]]) – a sequence of (low, high, interior) tuples of integers, giving the amount of low, high, and interior (dilation) padding to insert in each dimension.
- Returns:
The operand array with the padding value padding_value inserted in each dimension according to padding_config.
- Return type:
Array
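As a rough sketch of how padding_config determines the result's shape: along a dimension of length n, interior padding inserts `interior` values between each pair of adjacent elements, and low/high padding is then added at the edges. The helper `padded_size` below is hypothetical (it is not part of JAX) and assumes non-negative padding amounts; the example checks it against the shape that lax.pad actually returns.

```python
import jax.numpy as jnp
from jax import lax

def padded_size(n, low, high, interior):
    # Hypothetical helper, not a JAX API; assumes non-negative low/high/interior.
    # An axis of length n has max(n - 1, 0) gaps, each receiving `interior` values.
    return low + n + max(n - 1, 0) * interior + high

x = jnp.arange(6).reshape(2, 3)
config = [(1, 0, 2), (0, 2, 1)]

# Predicted shape from the formula versus the shape lax.pad produces.
expected = tuple(padded_size(n, *cfg) for n, cfg in zip(x.shape, config))
out = lax.pad(x, 0, config)
print(expected, out.shape)  # (5, 7) (5, 7)
```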
Examples
>>> from jax import lax
>>> import jax.numpy as jnp
Pad a 1-dimensional array with zeros. We’ll specify two zeros in front and three at the end:
>>> x = jnp.array([1, 2, 3, 4])
>>> lax.pad(x, 0, [(2, 3, 0)])
Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
Pad a 1-dimensional array with interior zeros, i.e., insert a single zero between each pair of adjacent values:
>>> lax.pad(x, 0, [(0, 0, 1)])
Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
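Edge and interior padding can be combined in a single (low, high, interior) tuple. The sketch below, using the same 1-dimensional input as the examples above, inserts two zeros between neighbouring elements and one zero at each end; the interior values go strictly between elements, while low and high are added outside the first and last elements.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3, 4])

# One zero before, one zero after, and two zeros between each pair of neighbours.
y = lax.pad(x, 0, [(1, 1, 2)])
print(y.tolist())  # [0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0]
```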
Pad a 2-dimensional array with the value -1 at the front and end, with a pad size of 2 in each dimension:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
Array([[-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1,  1,  2,  3, -1, -1],
       [-1, -1,  4,  5,  6, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
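For edge-only padding with a constant (interior set to 0 in every dimension), the result should match jax.numpy.pad in its default "constant" mode; the sketch below checks that on the 2-dimensional input from the previous example. It then pads one axis only, giving each dimension its own (low, high, interior) tuple, which is something jnp.pad's padding widths do not express directly.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# Edge-only padding: lax.pad with interior=0 versus jnp.pad with a constant fill.
a = lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
b = jnp.pad(x, ((2, 2), (2, 2)), constant_values=-1)
print(bool((a == b).all()))  # True

# Interior padding along the column axis only; rows are left unchanged.
c = lax.pad(x, 0, [(0, 0, 0), (0, 0, 1)])
print(c.tolist())  # [[1, 0, 2, 0, 3], [4, 0, 5, 0, 6]]
```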