jax.tree.map_with_path
- jax.tree.map_with_path(f, tree, *rest, is_leaf=None)
Maps a multi-input function over the key paths and leaves of one or more pytrees to produce a new pytree.
This is a more powerful alternative to tree_map: in addition to the leaf values, f receives the key path of each leaf as an input argument.
- Parameters:
f (Callable[..., Any]) – function that takes 2 + len(rest) arguments: the key path of a leaf, the leaf of tree at that path, and the corresponding value from each pytree in rest.
tree (Any) – a pytree to be mapped over, with each leaf's key path passed as the first positional argument and the leaf itself as the second argument to f.
*rest (Any) – a tuple of pytrees, each of which has the same structure as tree or has tree as a prefix.
is_leaf (Callable[[Any], bool] | None) – an optional function called on each node during flattening; if it returns True, the node is treated as a leaf and is not traversed further.
- Returns:
A new pytree with the same structure as tree but with the value at each leaf given by f(kp, x, *xs), where kp is the key path of the corresponding leaf in tree, x is the leaf value, and xs is the tuple of values at the corresponding nodes in rest.
- Return type:
Any
Examples
>>> import jax
>>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
[1, 3, 5]