jax.tree.leaves_with_path

jax.tree.leaves_with_path(tree, is_leaf=None)

Gets the leaves of a pytree in the same order as tree_leaves, and pairs each leaf with its key path.

Parameters:
  • tree (Any) – a pytree. If it contains a custom type, that type should preferably be registered with register_pytree_with_keys so that its children receive descriptive keys in the returned paths.

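  • is_leaf (Callable[[Any], bool] | None) – an optionally specified function that is called at each flattening step. It should return a boolean: True stops the traversal and treats the whole subtree as a leaf, while False lets flattening continue into the current object. If None (the default), the ordinary pytree flattening rules apply.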

Returns:

A list of (key path, leaf) pairs, one for each leaf of tree.

Return type:

list[tuple[tree_util.KeyPath, Any]]

Examples

>>> import jax
>>> jax.tree.leaves_with_path([1, {'x': 3}])
[((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]