jax.tree_util.build_tree
- jax.tree_util.build_tree(treedef, xs)
Build a pytree from a treedef and a nested iterable structure.
DEPRECATED: Use jax.tree.unflatten() instead.
- Parameters:
treedef (PyTreeDef) – the PyTreeDef structure to build.
xs (Any) – nested iterables matching the arity of each node in treedef.
- Returns:
object with structure defined by treedef
- Return type:
Any
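A minimal sketch of the return contract. It uses jax.tree.unflatten, the recommended replacement, rather than the deprecated build_tree itself, and the example tree and leaf values are arbitrary. The point is that the rebuilt object's structure is exactly the treedef it was built from.

```python
import jax

# An arbitrary example pytree: a dict holding a list and a tuple.
tree = {'w': [1.0, 2.0], 'b': (3.0,)}
treedef = jax.tree.structure(tree)

# Supply new leaves in flattening order (dict keys are visited in sorted
# order, so 'b' comes before 'w'); the container layout comes from treedef.
result = jax.tree.unflatten(treedef, [0.1, 0.2, 0.3])

# The returned object's structure equals the treedef it was built from.
assert jax.tree.structure(result) == treedef
print(result)
```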
See also
Examples
>>> import jax
>>> tree = [(1, 2), {'a': 3, 'b': 4}]
>>> treedef = jax.tree.structure(tree)
>>> jax.tree_util.tree_unflatten(treedef, [10, 11, 12, 13])
[(10, 11), {'a': 12, 'b': 13}]
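build_tree takes nested iterables, while jax.tree.unflatten takes a flat sequence of leaves, so code migrating away from build_tree may need to flatten its input first. The sketch below uses a hypothetical helper, nested_to_unflatten (not part of JAX). It assumes the nested input is built only from plain lists or tuples, that no leaf is itself a container or None, and that dict children appear in sorted-key order, matching JAX's flattening order.

```python
import jax

def nested_to_unflatten(treedef, xs):
    """Hypothetical migration helper: rebuild a pytree from nested lists/tuples.

    Assumes `xs` mirrors `treedef` node by node using plain lists or tuples,
    and that no leaf value is itself a list, tuple, dict, or None.
    """
    # Lists and tuples are pytree containers, so jax.tree.leaves flattens the
    # nested iterables left to right into a single flat list of leaves.
    flat = jax.tree.leaves(xs)
    # Only the leaf count is checked here, not the exact nesting.
    if len(flat) != treedef.num_leaves:
        raise ValueError(f"expected {treedef.num_leaves} leaves, got {len(flat)}")
    return jax.tree.unflatten(treedef, flat)

tree = [(1, 2), {'a': 3, 'b': 4}]
treedef = jax.tree.structure(tree)

# Nested iterables in the build_tree style: one inner list per child node,
# with the dict's children given in sorted-key order.
rebuilt = nested_to_unflatten(treedef, [[10, 11], [12, 13]])
print(rebuilt)
```

Flattening with jax.tree.leaves keeps the helper short; a stricter version would walk treedef node by node to reject inputs whose nesting differs but whose leaf count happens to match.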