jax.numpy.apply_along_axis
jax.numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)
Apply a function to 1D array slices along an axis.
JAX implementation of numpy.apply_along_axis(). While NumPy implements this iteratively, JAX implements this via jax.vmap(), and so func1d must be compatible with vmap.

- Parameters:
  - func1d (Callable) – a callable function with signature func1d(arr, /, *args, **kwargs), where *args and **kwargs are the additional positional and keyword arguments passed to apply_along_axis().
  - axis (int) – integer axis along which to apply the function.
  - arr (ArrayLike) – the array over which to apply the function.
  - args – additional positional arguments are passed through to func1d.
  - kwargs – additional keyword arguments are passed through to func1d.
- Returns:
  The result of func1d applied along the specified axis.
- Return type:
  Array
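Because func1d is traced by jax.vmap() rather than called on concrete slices, Python control flow that inspects array values (for example an `if` on a slice's sum) fails, while a jax.numpy equivalent such as jnp.where() works. The following is a minimal sketch; the helper names are hypothetical. The exception shown is what recent JAX versions raise for bool() on a traced value (TracerBoolConversionError, a subclass of jax.errors.ConcretizationTypeError).

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, -2, 3],
               [-4, 5, -6]])

# Hypothetical helper: a Python `if` on a traced value is not vmap-compatible.
def positive_sum_python_if(row):
    total = jnp.sum(row)
    if total > 0:  # bool() of a traced value raises under vmap
        return total
    return jnp.zeros_like(total)

# Hypothetical helper: the same logic with jnp.where, which vmap can batch.
def positive_sum(row):
    total = jnp.sum(row)
    return jnp.where(total > 0, total, 0)

try:
    jnp.apply_along_axis(positive_sum_python_if, 1, x)
except jax.errors.ConcretizationTypeError as err:
    print("not vmap-compatible:", type(err).__name__)

result = jnp.apply_along_axis(positive_sum, 1, x)
print(result)  # row sums are 2 and -5, so this prints [2 0]
```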
See also
- jax.vmap(): a more direct way to create a vectorized version of a function.
- jax.numpy.apply_over_axes(): repeatedly apply a function over multiple axes.
- jax.numpy.vectorize(): create a vectorized version of a function.
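Of the alternatives listed above, jax.numpy.vectorize() with a generalized-ufunc signature such as "(n)->()" maps a 1D core function over all leading dimensions, applying it along the last axis. Moving the target axis to the end with jnp.moveaxis() should therefore reproduce apply_along_axis(). The sketch below checks this under that assumption, using a sum-of-squares function:

```python
import jax.numpy as jnp

def func1d(v):
    # Sum of squares of a 1D slice.
    return jnp.sum(v ** 2)

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# Core signature "(n)->()": consumes the last axis, returns one scalar per slice.
vfunc = jnp.vectorize(func1d, signature="(n)->()")

# Move the target axis to the end so vectorize applies func1d along it.
along_0 = vfunc(jnp.moveaxis(x, 0, -1))  # [17 29 45]
along_1 = vfunc(x)                       # [14 77]

assert (along_0 == jnp.apply_along_axis(func1d, 0, x)).all()
assert (along_1 == jnp.apply_along_axis(func1d, 1, x)).all()
```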
Examples
A simple example in two dimensions, where the function is applied either row-wise or column-wise:
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> def func1d(x):
...   return jnp.sum(x ** 2)
>>> jnp.apply_along_axis(func1d, 0, x)
Array([17, 29, 45], dtype=int32)
>>> jnp.apply_along_axis(func1d, 1, x)
Array([14, 77], dtype=int32)
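Since this is a JAX implementation of numpy.apply_along_axis(), its results should match NumPy's for a vmap-compatible function. The sketch below checks that on the same 2D input. It also uses a negative axis, which, as in NumPy, counts from the last dimension. The NumPy twin function is a hypothetical helper added only for the comparison.

```python
import numpy as np
import jax.numpy as jnp

def func1d(v):
    # Sum of squares of a 1D slice, traced by vmap inside jnp.apply_along_axis.
    return jnp.sum(v ** 2)

def func1d_np(v):
    # Hypothetical NumPy twin; np.apply_along_axis calls it on concrete slices.
    return np.sum(v ** 2)

x_np = np.array([[1, 2, 3],
                 [4, 5, 6]])
x = jnp.asarray(x_np)

for axis in (0, 1, -1):
    jax_out = jnp.apply_along_axis(func1d, axis, x)
    np_out = np.apply_along_axis(func1d_np, axis, x_np)
    np.testing.assert_array_equal(np.asarray(jax_out), np_out)

neg = jnp.apply_along_axis(func1d, -1, x)
print(neg)  # [14 77], same as axis=1
```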
For 2D inputs, this can be equivalently expressed using jax.vmap(), though note that vmap specifies the mapped axis rather than the applied axis:

>>> jax.vmap(func1d, in_axes=1)(x)  # same as applying along axis 0
Array([17, 29, 45], dtype=int32)
>>> jax.vmap(func1d, in_axes=0)(x)  # same as applying along axis 1
Array([14, 77], dtype=int32)
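When func1d returns a 1D array instead of a scalar, apply_along_axis() places the new output axis where the applied axis was, as NumPy does. The equivalent vmap() call then also needs out_axes. A sketch with jnp.cumsum as the per-slice function:

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# Apply cumsum down each column (axis 0); each call returns a length-2 vector.
a = jnp.apply_along_axis(jnp.cumsum, 0, x)

# vmap over columns (in_axes=1) and put the mapped axis back at position 1.
b = jax.vmap(jnp.cumsum, in_axes=1, out_axes=1)(x)

print(a.shape)  # (2, 3); a holds [[1, 2, 3], [5, 7, 9]]
assert (a == b).all()
assert (a == jnp.cumsum(x, axis=0)).all()
```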
For 3D inputs, apply_along_axis() is equivalent to mapping over the other two dimensions:

>>> x_3d = jnp.arange(24).reshape(2, 3, 4)
>>> jnp.apply_along_axis(func1d, 2, x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)
>>> jax.vmap(jax.vmap(func1d))(x_3d)
Array([[  14,  126,  366],
       [ 734, 1230, 1854]], dtype=int32)
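The 3D doctest applies along the last axis, where nested vmap() calls with the default in_axes=0 suffice. For an interior axis, the inner vmap() must map over the trailing axis instead. The sketch below applies the same sum-of-squares function along axis 1:

```python
import jax
import jax.numpy as jnp

def func1d(v):
    # Sum of squares of a 1D slice.
    return jnp.sum(v ** 2)

x_3d = jnp.arange(24).reshape(2, 3, 4)

# Apply along axis 1: each slice is x_3d[i, :, k], of length 3.
a = jnp.apply_along_axis(func1d, 1, x_3d)

# Outer vmap maps axis 0; inner vmap maps the last axis of each (3, 4) block.
b = jax.vmap(jax.vmap(func1d, in_axes=1))(x_3d)

print(a.shape)  # (2, 4)
assert (a == b).all()
assert (a == jnp.sum(x_3d ** 2, axis=1)).all()
```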
The applied function may also take arbitrary positional or keyword arguments, which should be passed directly as additional arguments to apply_along_axis():

>>> def func1d(x, exponent):
...   return jnp.sum(x ** exponent)
>>> jnp.apply_along_axis(func1d, 0, x, exponent=3)
Array([ 65, 133, 243], dtype=int32)
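Extra arguments are forwarded unchanged to every call rather than being sliced, so an array argument is shared by all slices. The sketch below passes a hypothetical weight vector positionally. It also shows that binding a keyword argument beforehand with functools.partial gives the same result as passing it through apply_along_axis():

```python
import functools
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

def weighted_sum(v, weights):
    # `weights` is the same full array for every slice; only `v` is sliced.
    return jnp.dot(v, weights)

w = jnp.array([1, 10, 100])  # hypothetical weights, one per column
weighted = jnp.apply_along_axis(weighted_sum, 1, x, w)
print(weighted)  # [321 654]

def power_sum(v, exponent):
    return jnp.sum(v ** exponent)

# Passing exponent=3 through apply_along_axis ...
a = jnp.apply_along_axis(power_sum, 0, x, exponent=3)
# ... matches binding it beforehand with functools.partial.
b = jnp.apply_along_axis(functools.partial(power_sum, exponent=3), 0, x)
assert (a == b).all()  # both are [65 133 243]
```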