jax.numpy.pad
jax.numpy.pad(array, pad_width, mode='constant', **kwargs)
Add padding to an array.
JAX implementation of numpy.pad().

Parameters:
array (ArrayLike) – array to pad.
pad_width (PadValueLike[int | Array | np.ndarray]) – specify the pad width for each dimension of an array. Padding widths may be specified separately for before and after the array. Options are:
  - int or (int,): pad each array dimension with the same number of values both before and after.
  - (before, after): pad each array dimension with before elements before, and after elements after.
  - ((before_1, after_1), (before_2, after_2), ..., (before_N, after_N)): specify distinct before and after values for each array dimension.
mode (str | Callable[..., Any]) – a string or callable. Supported pad modes are:
  - 'constant' (default): pad with a constant value, which defaults to zero.
  - 'empty': pad with empty values (i.e. zero).
  - 'edge': pad with the edge values of the array.
  - 'wrap': pad by wrapping the array around periodically.
  - 'linear_ramp': pad with a linear ramp from the edge values to the specified end_values.
  - 'maximum': pad with the maximum value.
  - 'mean': pad with the mean value.
  - 'median': pad with the median value.
  - 'minimum': pad with the minimum value.
  - 'reflect': pad by reflection about the edge value, without repeating the edge value.
  - 'symmetric': pad by symmetric reflection, which repeats the edge value.
  - <callable>: a callable function. See Notes below.
constant_values – referenced for mode = 'constant'. Specify the constant value to pad with.
stat_length – referenced for mode in ['maximum', 'mean', 'median', 'minimum']. An integer or tuple specifying the number of edge values to use when calculating the statistic.
end_values – referenced for mode = 'linear_ramp'. Specify the end values (default 0) to ramp the padding values to.
reflect_type – referenced for mode in ['reflect', 'symmetric']. Specify whether to use even (the default) or odd reflection.
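A short sketch of the parameters above in use. It assumes the default 32-bit configuration and uses float inputs so that statistics and ramps are not rounded to integers; the expected values in the comments follow NumPy's numpy.pad() semantics for the same arguments.

```python
import jax.numpy as jnp

m = jnp.ones((2, 3))
x = jnp.array([1.0, 2.0, 3.0, 4.0])

# pad_width: one int, one (before, after) pair, or one pair per dimension.
same_all = jnp.pad(m, 1)                 # 1 before and 1 after on both axes: shape (4, 5)
pair = jnp.pad(m, (1, 2))                # 1 before and 2 after on both axes: shape (5, 6)
per_axis = jnp.pad(m, ((0, 1), (2, 0)))  # axis 0: 0/1, axis 1: 2/0: shape (3, 5)

# constant_values as a (before, after) pair pads each side with its own value.
const = jnp.pad(x, 2, constant_values=(-1.0, 9.0))
# -> [-1, -1, 1, 2, 3, 4, 9, 9]

# stat_length=2: each statistic uses only the 2 values nearest that edge.
mean2 = jnp.pad(x, 2, mode="mean", stat_length=2)
# -> [1.5, 1.5, 1, 2, 3, 4, 3.5, 3.5]

# end_values: ramp from 0 up to the first value, and from the last value up to 10.
ramp = jnp.pad(x, 2, mode="linear_ramp", end_values=(0.0, 10.0))
# -> [0, 0.5, 1, 2, 3, 4, 7, 10]

# reflect_type="odd": reflected values are mirrored through the edge value.
odd = jnp.pad(x, 2, mode="reflect", reflect_type="odd")
# -> [-1, 0, 1, 2, 3, 4, 5, 6]
```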
Returns:
  A padded copy of array.
Return type:
  Array
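Since jnp.pad is a JAX implementation of numpy.pad(), one way to sanity-check the returned array is to compare it with NumPy. The sketch below assumes NumPy is installed and uses an arbitrary (3, 4) float32 input; it checks that each dimension grows by before + after with the dtype preserved, then compares every string mode except 'empty', whose padding NumPy leaves uninitialized.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(12, dtype=np.float32).reshape(3, 4)
x = jnp.asarray(x_np)
pad_width = ((1, 2), (2, 1))

# Each output dimension is before + size + after; the dtype is unchanged.
out = jnp.pad(x, pad_width, mode="edge")
expected_shape = tuple(n + b + a for n, (b, a) in zip(x.shape, pad_width))
print(out.shape, expected_shape, out.dtype)  # (6, 7) (6, 7) float32

# Compare each string mode (except 'empty') against NumPy's numpy.pad.
modes = ["constant", "edge", "wrap", "linear_ramp", "maximum",
         "mean", "median", "minimum", "reflect", "symmetric"]
matches = {
    mode: bool(np.allclose(np.asarray(jnp.pad(x, pad_width, mode=mode)),
                           np.pad(x_np, pad_width, mode=mode)))
    for mode in modes
}
print(matches)
```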
Notes

When mode is callable, it should have the following signature:

    def pad_func(row: Array, pad_width: tuple[int, int], iaxis: int, kwargs: dict) -> Array: ...

Here row is a 1D slice of the padded array along axis iaxis, with the pad values filled with zeros. pad_width is a tuple specifying the (before, after) padding sizes, and kwargs are any additional keyword arguments passed to the jax.numpy.pad() function.

Note that while in NumPy the function should modify row in place, in JAX the function should return the modified row, since JAX arrays are immutable. In JAX, the custom padding function will be mapped across the padded axis using the jax.vmap() transformation.

See also
  - jax.numpy.resize(): resize an array.
  - jax.numpy.tile(): create a larger array by tiling a smaller array.
  - jax.numpy.repeat(): create a larger array by repeating values of a smaller array.
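These related functions overlap with some padding modes. For a small 1D array, the sketch below shows that wrapping by the full length gives the same result as jnp.tile() and jnp.resize(), and that edge padding on one side can be written with jnp.repeat() and per-element repeat counts. The equivalences are illustrations for this input, not general identities.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Wrapping by the full length on one side repeats the whole array, like tiling.
wrapped = jnp.pad(x, (0, 3), mode="wrap")       # [1, 2, 3, 1, 2, 3]
tiled = jnp.tile(x, 2)                          # [1, 2, 3, 1, 2, 3]

# jnp.resize fills the new length by cycling through the input.
resized = jnp.resize(x, 6)                      # [1, 2, 3, 1, 2, 3]

# Edge padding on the right repeats the last value; jnp.repeat can express
# this with per-element repeat counts (concrete counts, outside jit).
edged = jnp.pad(x, (0, 2), mode="edge")         # [1, 2, 3, 3, 3]
repeated = jnp.repeat(x, jnp.array([1, 1, 3]))  # [1, 2, 3, 3, 3]
```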
Examples
Pad a 1-dimensional array with zeros:
>>> import jax.numpy as jnp
>>> x = jnp.array([10, 20, 30, 40])
>>> jnp.pad(x, 2)
Array([ 0,  0, 10, 20, 30, 40,  0,  0], dtype=int32)
>>> jnp.pad(x, (2, 4))
Array([ 0,  0, 10, 20, 30, 40,  0,  0,  0,  0], dtype=int32)
Pad a 1-dimensional array with specified values:
>>> jnp.pad(x, 2, constant_values=99)
Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32)
Pad a 1-dimensional array with the mean array value:
>>> jnp.pad(x, 2, mode='mean')
Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32)
Pad a 1-dimensional array with reflected values:
>>> jnp.pad(x, 2, mode='reflect')
Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32)
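The same reflect padding can run inside JAX transformations. Because the output shape depends on pad_width, the sketch below keeps the width static under jax.jit via static_argnames (reflect_pad is a hypothetical helper name), and then uses jax.vmap to pad each row of a small batch independently.

```python
from functools import partial

import jax
import jax.numpy as jnp

# The output shape depends on the pad width, so under jit the width must be a
# static Python value rather than a traced array.
@partial(jax.jit, static_argnames=("width",))
def reflect_pad(x, width):
    return jnp.pad(x, width, mode="reflect")

x = jnp.array([10, 20, 30, 40])
reflected = reflect_pad(x, width=2)  # [30, 20, 10, 20, 30, 40, 30, 20]

# jax.vmap pads each row of the batch independently with the same widths.
batch = jnp.array([[1, 2, 3], [4, 5, 6]])
padded_rows = jax.vmap(lambda row: jnp.pad(row, (1, 1), mode="edge"))(batch)
# -> [[1, 1, 2, 3, 3], [4, 4, 5, 6, 6]]
```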
Pad a 2-dimensional array with different paddings in each dimension:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.pad(x, ((1, 2), (3, 0)))
Array([[0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 2, 3],
       [0, 0, 0, 4, 5, 6],
       [0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0]], dtype=int32)
Pad a 1-dimensional array with a custom padding function:
>>> def custom_pad(row, pad_width, iaxis, kwargs):
...   # row represents a 1D slice of the zero-padded array.
...   before, after = pad_width
...   before_value = kwargs.get('before_value', 0)
...   after_value = kwargs.get('after_value', 0)
...   row = row.at[:before].set(before_value)
...   return row.at[len(row) - after:].set(after_value)
>>> x = jnp.array([2, 3, 4])
>>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10)
Array([-10, -10,   2,   3,   4,  10,  10], dtype=int32)
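A custom function also works on multi-dimensional arrays. The hypothetical edge_pad below reimplements mode='edge' using the layout described in the Notes: the original values of each 1D slice sit between the before and after regions, which arrive filled with zeros. The 2D result assumes, as in NumPy, that the axes are padded one after another, so the corners are filled during the last axis pass.

```python
import jax.numpy as jnp

def edge_pad(row, pad_width, iaxis, kwargs):
    """Hypothetical pad function that reproduces mode='edge'.

    `row` is a 1D slice of the zero-padded array; the original values occupy
    row[before:len(row) - after]. `iaxis` and `kwargs` are unused here.
    """
    before, after = pad_width
    n = row.shape[0]
    first = row[before]        # first original value in this slice
    last = row[n - after - 1]  # last original value in this slice
    row = row.at[:before].set(first)
    return row.at[n - after:].set(last)

x = jnp.array([[1, 2], [3, 4]])
custom = jnp.pad(x, 1, mode=edge_pad)
builtin = jnp.pad(x, 1, mode="edge")
# Both are expected to equal
# [[1, 1, 2, 2],
#  [1, 1, 2, 2],
#  [3, 3, 4, 4],
#  [3, 3, 4, 4]]
```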