jax.lax.map
- jax.lax.map(f, xs, *, batch_size=None)
Map a function over leading array axes.
Like Python’s built-in map, except that inputs and outputs are in the form of stacked arrays. Consider using the vmap() transform instead, unless you need to apply a function element by element for reduced memory usage or heterogeneous computation with other control flow primitives.
When xs is an array type, the semantics of map() are given by this Python implementation:
    def map(f, xs):
        return np.stack([f(x) for x in xs])
Like scan(), map() is implemented in terms of JAX primitives, so many of the same advantages over a Python loop apply: xs may be an arbitrary nested pytree type, and the mapped computation is compiled only once.
If batch_size is provided, the computation is executed in batches of that size and parallelized using vmap(). This can be used either as a more performant version of map or as a memory-efficient version of vmap. If the length of the mapped axis is not divisible by the batch size, the remainder is processed in a separate vmap and concatenated to the result.
>>> x = jnp.ones((10, 3, 4))
>>> def f(x):
...   print('inner shape:', x.shape)
...   return x + 1
>>> y = lax.map(f, x, batch_size=3)
inner shape: (3, 4)
inner shape: (3, 4)
>>> y.shape
(10, 3, 4)
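In the example above, “inner shape” is printed twice. The first print happens while tracing the batched computation, which covers the first 9 elements as three batches of 3. The second happens while tracing the remainder computation for the last element. Both traces report the per-element shape (3, 4), because vmap() hides the batch axis from f.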
- Parameters:
f – a Python function to apply element-wise over the first axis or axes of xs.
xs – values over which to map along the leading axis.
batch_size (int | None) – (optional) integer specifying the size of the batch for each step to execute in parallel.
- Returns:
Mapped values, with the outputs of f stacked along a new leading axis.
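Because xs may be a nested pytree, f receives one slice of every leaf and may itself return a pytree, whose leaves are stacked along a new leading axis. The following is a minimal sketch, assuming a dict input whose key names are arbitrary. All leaves of xs must share the same leading-axis length.

```python
import jax.numpy as jnp
from jax import lax

# xs is a pytree (a dict of arrays); both leaves have leading length 4.
xs = {"a": jnp.arange(4.0), "b": jnp.ones((4, 2))}

def f(elem):
    # elem has the same dict structure with the leading axis removed:
    # elem["a"] has shape () and elem["b"] has shape (2,).
    return {"scaled": elem["b"] * elem["a"], "total": elem["a"] + elem["b"].sum()}

ys = lax.map(f, xs)
print(ys["scaled"].shape, ys["total"].shape)  # (4, 2) (4,)
```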