jax.lax.scan
jax.lax.scan(f, init, xs=None, length=None, reverse=False, unroll=1, _split_transpose=False)
Scan a function over leading array axes while carrying along state.
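Here is a minimal sketch of the idea, assuming only `jax` and `jax.numpy`. `running_max_step` is a hypothetical step function made up for this illustration. Its carry is the largest value seen so far, and its per-step output is a boolean flag recording whether the current element set a new maximum. Note that the carry and the per-step output can have different types.

```python
import jax
import jax.numpy as jnp

def running_max_step(best, x):
    # best: the carried state (largest value seen so far)
    # x: the current slice of xs along its leading axis
    is_new_max = x > best
    # return (new carry, per-step output)
    return jnp.maximum(best, x), is_new_max

xs = jnp.array([3.0, 1.0, 4.0, 1.0, 5.0], dtype=jnp.float32)
best, flags = jax.lax.scan(running_max_step, jnp.float32(-jnp.inf), xs)
# best == 5.0
# flags == [True, False, True, False, True]
```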
The Haskell-like type signature, in brief, is:

scan :: (c -> a -> (c, b)) -> c -> [a] -> (c, [b])

where, for any array type specifier t, [t] represents the type with an additional leading axis, and if t is a pytree (container) type with array leaves, then [t]
represents the type with the same pytree structure and corresponding leaves, each with an additional leading axis.

When the type of xs (denoted a above) is an array type or None, and the type of ys (denoted b above) is an array type, the semantics of scan() are given roughly by this Python implementation:

```python
import numpy as np

def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    # thread the carry through f and collect each per-step output
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
```
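The following sketch checks that `jax.lax.scan` agrees with a plain-Python loop that follows the reference semantics. `python_scan` is a hypothetical re-implementation written only for this comparison. `ema_step` is a hypothetical exponential-moving-average step function. It uses only arithmetic operators, so the same function runs on NumPy values in the loop and on JAX tracers inside `scan()`. The results are compared with a tolerance because the two paths may round floating-point values differently.

```python
import numpy as np
import jax

def python_scan(f, init, xs, length=None):
    # Plain-Python loop: thread the carry through each slice of xs
    # and stack the per-step outputs along a new leading axis.
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def ema_step(carry, x):
    # The carry is the current moving average; each step also
    # emits the updated average as its output slice.
    new_carry = 0.9 * carry + 0.1 * x
    return new_carry, new_carry

xs = np.arange(1.0, 6.0, dtype=np.float32)
init = np.float32(0.0)

ref_carry, ref_ys = python_scan(ema_step, init, xs)
jax_carry, jax_ys = jax.lax.scan(ema_step, init, xs)

# Same results up to floating-point rounding
print(np.allclose(ref_carry, jax_carry, rtol=1e-5))  # True
print(np.allclose(ref_ys, jax_ys, rtol=1e-5))        # True
```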
Unlike that Python version, both xs and ys may be arbitrary pytree values, so multiple arrays can be scanned over at once and produce multiple output arrays. None is actually a special case of this, as it represents an empty pytree.

Also unlike that Python version, scan() is a JAX primitive and is lowered to a single WhileOp. This makes it useful for reducing compilation times for JIT-compiled functions, because native Python loop constructs inside a jit() function are unrolled, leading to large XLA computations.

Finally, the loop-carried value carry must hold a fixed shape and dtype across all iterations. Being consistent only up to NumPy rank/shape broadcasting and dtype promotion rules, for example, is not enough. In other words, the type c in the type signature above represents an array with a fixed shape and dtype, or a nested tuple/list/dict container with a fixed structure and arrays of fixed shape and dtype at the leaves.

Note: scan() compiles f, so while scan() can be combined with jit(), wrapping it in jit() is usually unnecessary.

- Parameters:
  - f (Callable[[Carry, X], tuple[Carry, Y]]) – a Python function to be scanned, of type c -> a -> (c, b). This means f accepts two arguments, where the first is a value of the loop carry and the second is a slice of xs along its leading axis. It returns a pair, where the first element is the new value of the loop carry and the second is a slice of the output.
  - init (Carry) – the initial loop carry value, of type c. It can be a scalar, an array, or any pytree (nested Python tuple/list/dict) thereof. This value must have the same structure, leaf shapes, and leaf dtypes as the first element of the pair returned by f.
  - xs (X | None) – the value of type [a] over which to scan along the leading axis, where [a] can be an array or any pytree (nested Python tuple/list/dict) thereof with consistent leading-axis sizes.
  - length (int | None) – optional integer specifying the number of loop iterations. It must agree with the sizes of the leading axes of the arrays in xs, and it can be used to perform scans where no input xs are needed.
  - reverse (bool) – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse. Running in reverse is equivalent to reversing the leading axes of the arrays in both xs and ys.
  - unroll (int | bool) – optional positive int or bool specifying how many scan iterations to unroll within a single iteration of the loop in the underlying scan primitive. If an integer is provided, it sets how many unrolled loop iterations run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (unroll=True) or left completely rolled (unroll=False).
  - _split_transpose (bool) – experimental optional bool specifying whether to further split the transpose into a scan (computing activation gradients) and a map (computing gradients corresponding to the array arguments). Enabling this may increase memory requirements. For that reason it is an experimental feature that may evolve or even be rolled back.
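The sketch below shows the length, reverse, and unroll parameters in action. `count_step` and `cumsum_step` are hypothetical step functions made up for illustration. The expected values in the comments follow from the documented behaviour. With xs=None, length alone sets the iteration count. reverse=True is equivalent to reversing the leading axes of both xs and ys, so a running sum becomes a suffix sum. unroll changes only how the loop is lowered, not the result.

```python
import jax
import jax.numpy as jnp

def count_step(carry, _):
    # With xs=None the second argument is always None; only the carry matters.
    # Emit the counter value *before* incrementing it.
    return carry + 1, carry

# No xs: `length` alone fixes the number of iterations.
final, ys = jax.lax.scan(count_step, jnp.int32(0), xs=None, length=4)
# final == 4, ys == [0, 1, 2, 3]

def cumsum_step(carry, x):
    # Running sum: the new total is both the next carry and the output slice.
    total = carry + x
    return total, total

xs = jnp.array([1, 2, 3, 4], dtype=jnp.int32)

# reverse=True walks xs from the last element to the first; the outputs are
# stored at the index of the element that produced them, giving a suffix sum.
_, rev = jax.lax.scan(cumsum_step, jnp.int32(0), xs, reverse=True)
# rev == [10, 9, 7, 4]

# unroll only affects how the loop is lowered; the values are identical.
_, fwd = jax.lax.scan(cumsum_step, jnp.int32(0), xs)
_, fwd_unrolled = jax.lax.scan(cumsum_step, jnp.int32(0), xs, unroll=2)
# fwd == fwd_unrolled == [1, 3, 6, 10]
```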
- Returns:
  A pair of type (c, [b]), where the first element is the final loop carry value and the second element holds the stacked outputs of the second output of f when scanned over the leading axis of the inputs.
- Return type:
  tuple[Carry, Y]
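Because the carry, xs, and the per-step output may all be pytrees, the returned (c, [b]) pair mirrors those structures. The following sketch uses a dict carry, a (values, weights) tuple as xs, and a tuple output. The step function `step` and the weighted-sum task are hypothetical examples. The leaf dtypes of `init` (float32 and int32) are chosen to match what `step` returns, since the carry must keep a fixed shape and dtype across iterations.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # carry is a dict pytree; x is one (value, weight) slice of the xs tuple
    value, weight = x
    weighted = value * weight
    new_carry = {
        "total": carry["total"] + weighted,  # stays float32
        "count": carry["count"] + 1,         # stays int32
    }
    # The output is also a pytree: each leaf gets its own stacked leading axis.
    return new_carry, (new_carry["total"], weighted)

values = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
weights = jnp.array([0.5, 0.5, 2.0], dtype=jnp.float32)
init = {"total": jnp.float32(0.0), "count": jnp.int32(0)}

final, (running_totals, weighted) = jax.lax.scan(step, init, (values, weights))
# final["total"] == 7.5, final["count"] == 3
# running_totals == [0.5, 1.5, 7.5]
# weighted == [0.5, 1.0, 6.0]
```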