jax.tree_util.all_leaves
- jax.tree_util.all_leaves(iterable, is_leaf=None)
Tests whether all elements in the given iterable are leaves.
This function is useful in advanced cases. For example, if a library allows arbitrary map operations on a flat iterable of leaves, it may want to check whether the result is still a flat iterable of leaves.
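A minimal sketch of that use case, assuming a hypothetical library helper `map_leaves` (not part of JAX) that applies a user function to each element of a flat list of leaves. Calling `all_leaves` on the result lets the helper reject user functions that turn a leaf into a pytree, such as a tuple.

```python
import jax.numpy as jnp
from jax.tree_util import all_leaves, tree_leaves


def map_leaves(fn, leaves):
    """Hypothetical helper: apply `fn` to every element of a flat list of
    leaves and check that the result is still a flat list of leaves."""
    out = [fn(x) for x in leaves]
    if not all_leaves(out):
        raise ValueError("fn must map each leaf to a single leaf, not a pytree")
    return out


leaves = tree_leaves({"a": [1.0, 2.0], "b": jnp.ones(3)})

# Doubling keeps every element a leaf, so the check passes.
doubled = map_leaves(lambda x: x * 2, leaves)

# Returning a tuple turns each leaf into a pytree node, so the check fails.
try:
    map_leaves(lambda x: (x, x), leaves)
except ValueError as err:
    print("rejected:", err)
```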
- Parameters:
iterable (Iterable[Any]) – Iterable of leaves.
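is_leaf (Callable[[Any], bool] | None) – Optionally specified function called at each flattening step. It should return a boolean: True stops the traversal and treats the whole subtree as a leaf, while False lets flattening traverse the current object.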
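The effect of `is_leaf` can be seen with a list of tuples. This is an illustrative sketch; the lambda passed as `is_leaf` is an arbitrary choice made for the example.

```python
from jax.tree_util import all_leaves

pairs = [(1, 2), (3, 4)]

# By default a tuple is a pytree node, so its elements are not leaves.
print(all_leaves(pairs))  # False

# Treating tuples as leaves via is_leaf makes each pair count as one leaf.
print(all_leaves(pairs, is_leaf=lambda x: isinstance(x, tuple)))  # True
```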
- Returns:
A boolean indicating if all elements in the input are leaves.
- Return type:
bool
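One way to read the return value: every element is a leaf exactly when flattening the list with `tree_leaves` gives back those same elements, one per element. The hypothetical `all_leaves_sketch` below restates that idea for illustration only; it is not JAX's implementation, and it assumes `tree_leaves` returns leaf objects unchanged rather than copies.

```python
from jax.tree_util import all_leaves, tree_leaves


def all_leaves_sketch(iterable, is_leaf=None):
    """Hypothetical restatement of the check: every element is a leaf exactly
    when flattening the list yields the elements themselves, one per element."""
    items = list(iterable)
    flat = tree_leaves(items, is_leaf=is_leaf)
    # Compare by identity so array leaves are never compared elementwise.
    return len(flat) == len(items) and all(a is b for a, b in zip(flat, items))


print(all_leaves_sketch([1, 2.0, "x"]))  # True
print(all_leaves_sketch([1, {"a": 2}]))  # False: a dict is a pytree node
print(all_leaves([1, 2.0, "x"]), all_leaves([1, {"a": 2}]))  # True False
```

Comparing with `is` rather than `==` avoids the ambiguous truth value that elementwise comparison of array leaves would produce.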
Examples
>>> import jax
>>> from jax.tree_util import all_leaves
>>> tree = {"a": [1, 2, 3]}
>>> assert all_leaves(jax.tree_util.tree_leaves(tree))
>>> assert not all_leaves([tree])