jax.lax.pad

jax.lax.pad(operand, padding_value, padding_config)

Applies low, high, and/or interior padding to an array.

Wraps XLA’s Pad operator.

Parameters:
  • operand (ArrayLike) – an array to be padded.

  • padding_value (ArrayLike) – the value to be inserted as padding. Must have the same dtype as operand.

  • padding_config (Sequence[tuple[int, int, int]]) – a sequence of (low, high, interior) tuples of integers, giving the amount of low, high, and interior (dilation) padding to insert in each dimension.
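The size of each output dimension follows from its (low, high, interior) triple. A dimension of size n becomes low + high + n + max(n - 1, 0) * interior elements long, because interior padding is inserted only between neighbouring elements. The sketch below checks that arithmetic against the shape lax.pad actually returns. `padded_shape` is a hypothetical helper written for this illustration, not part of JAX.

```python
import jax.numpy as jnp
from jax import lax


def padded_shape(shape, padding_config):
    """Hypothetical helper: predict the output shape of lax.pad.

    Each dimension of size n with (low, high, interior) padding grows to
    low + high + n + max(n - 1, 0) * interior elements.
    """
    return tuple(
        low + high + n + max(n - 1, 0) * interior
        for n, (low, high, interior) in zip(shape, padding_config)
    )


x = jnp.zeros((2, 3), dtype=jnp.int32)
config = [(1, 2, 0), (0, 1, 2)]

predicted = padded_shape(x.shape, config)  # (1+2+2, 0+1+3+2*2) == (5, 8)
actual = lax.pad(x, 0, config).shape
print(predicted, actual)
```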

Returns:

The operand array with padding value padding_value inserted in each dimension according to the padding_config.

Return type:

Array

Examples

>>> from jax import lax
>>> import jax.numpy as jnp

Pad a 1-dimensional array with zeros, placing two zeros in front and three at the end:

>>> x = jnp.array([1, 2, 3, 4])
>>> lax.pad(x, 0, [(2, 3, 0)])
Array([0, 0, 1, 2, 3, 4, 0, 0, 0], dtype=int32)
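As far as we are aware, lax.pad (like XLA's Pad) also accepts negative low and high amounts, which remove elements from the edges instead of adding them; interior padding must remain non-negative. This sketch assumes that behaviour and shows cropping on its own and mixed with ordinary padding.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3, 4])

# Assumed behaviour: negative low/high drop elements from the edges.
cropped = lax.pad(x, 0, [(-1, -2, 0)])  # drop 1 element in front, 2 at the end
mixed = lax.pad(x, 0, [(-1, 1, 0)])     # drop 1 in front, append one zero
print(cropped)  # [2]
print(mixed)    # [2 3 4 0]
```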

Pad a 1-dimensional array with interior zeros, i.e. insert a single zero between each pair of adjacent values:

>>> lax.pad(x, 0, [(0, 0, 1)])
Array([1, 0, 2, 0, 3, 0, 4], dtype=int32)
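Interior padding never adds values at the ends of a dimension, so it is often combined with low and high padding. The minimal sketch below places a zero before, between, and after every element (a "zero-stuffing" pattern); the resulting length is 1 + 1 + 4 + 3 * 1 = 9.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3, 4])
# low=1 and high=1 add one zero at each end; interior=1 adds one zero
# between each pair of neighbouring elements.
y = lax.pad(x, 0, [(1, 1, 1)])
print(y)  # [0 1 0 2 0 3 0 4 0]
```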

Pad a 2-dimensional array with the value -1, adding two elements at the front and at the end of each dimension:

>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> lax.pad(x, -1, [(2, 2, 0), (2, 2, 0)])
Array([[-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1,  1,  2,  3, -1, -1],
       [-1, -1,  4,  5,  6, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1],
       [-1, -1, -1, -1, -1, -1, -1]], dtype=int32)
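To pad only some dimensions, give the others a (0, 0, 0) triple. Because padding_value must have the same dtype as operand, the sketch below builds it explicitly with `x.dtype`. Padding with -inf before taking a maximum is shown here only as an illustrative use, chosen because -inf never wins a max reduction.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]], dtype=jnp.float32)

# Padding value with the same dtype as the operand.
pad_value = jnp.array(-jnp.inf, dtype=x.dtype)

# (0, 0, 0) leaves axis 0 unchanged; axis 1 gets one element on each side.
y = lax.pad(x, pad_value, [(0, 0, 0), (1, 1, 0)])
print(y.shape)        # (2, 5)
print(y.max(axis=1))  # [3. 6.]
```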