jax.tree.unflatten
- jax.tree.unflatten(treedef, leaves)
Reconstructs a pytree from the treedef and the leaves. This is the inverse of tree_flatten(), which is also exposed as jax.tree.flatten().
- Parameters:
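treedef (tree_util.PyTreeDef) – the tree structure to rebuild, typically the treedef returned by tree_flatten().
leaves (Iterable[tree_util.Leaf]) – the iterable of leaves to use for reconstruction. It must yield exactly treedef.num_leaves items, in the same order in which tree_flatten() returns the leaves of a tree with this structure.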
- Returns:
The reconstructed pytree, containing the leaves placed in the structure described by treedef.
- Return type:
Any
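The sketch below illustrates the round-trip property and the leaf-count requirement described above. It assumes only jax.tree.flatten, jax.tree.unflatten, and PyTreeDef.num_leaves. The helper unflatten_checked is hypothetical, not part of JAX. It simply validates the number of leaves before rebuilding so that a mismatch produces a clear, explicit error.

```python
import jax

# A nested pytree mixing a dict, a tuple, a list, and None.
tree = {"b": (1, 2), "a": [3, None]}

leaves, treedef = jax.tree.flatten(tree)
# Dict keys are visited in sorted order and None contributes no leaves,
# so the flattened leaves are [3, 1, 2] and treedef.num_leaves is 3.
print(leaves, treedef.num_leaves)

# Round trip: unflattening the unmodified leaves gives back an equal pytree.
roundtrip = jax.tree.unflatten(treedef, leaves)
print(roundtrip == tree)


def unflatten_checked(treedef, leaves):
    """Hypothetical helper: check the leaf count, then rebuild the pytree."""
    leaves = list(leaves)  # materialize so the iterable can be counted
    if len(leaves) != treedef.num_leaves:
        raise ValueError(
            f"expected {treedef.num_leaves} leaves, got {len(leaves)}"
        )
    return jax.tree.unflatten(treedef, leaves)


# Any iterable works, including a generator, as long as the count matches.
rebuilt = unflatten_checked(treedef, (x * 10 for x in leaves))
print(rebuilt)
```

Because the leaves are consumed in flattening order, new values must be supplied in that same order, not in the order the keys were written in the source.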
Examples
>>> import jax
>>> vals, treedef = jax.tree.flatten([1, (2, 3), [4, 5]])
>>> newvals = [100, 200, 300, 400, 500]
>>> jax.tree.unflatten(treedef, newvals)
[100, (200, 300), [400, 500]]
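A common use of this function is a manual flatten, transform, and unflatten pass over a structure of arrays. The following is a sketch under illustrative assumptions: the parameter names "w" and "b" are arbitrary, and the comparison with jax.tree.map is included only to show that both routes give the same result. The replacement leaves do not have to be arrays. The second call stores each array's shape in the same structure.

```python
import jax
import jax.numpy as jnp

# Illustrative parameter pytree; the keys "w" and "b" are arbitrary.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Flatten once, transform the flat list of leaves, then restore the structure.
leaves, treedef = jax.tree.flatten(params)
scaled = jax.tree.unflatten(treedef, [2.0 * x for x in leaves])

# jax.tree.map applies a function leaf-wise and should give the same values.
expected = jax.tree.map(lambda x: 2.0 * x, params)
same = all(
    bool(jnp.array_equal(s, e))
    for s, e in zip(jax.tree.leaves(scaled), jax.tree.leaves(expected))
)
print(same)

# Leaves may be arbitrary Python objects, e.g. the shape of each array.
shapes = jax.tree.unflatten(treedef, [x.shape for x in leaves])
print(shapes)
```

Working with the flat list directly is handy when the transformation needs all the leaves at once, for example to compute a global norm, which a per-leaf jax.tree.map callback cannot see.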