jax.tree_util.build_tree

jax.tree_util.build_tree(treedef, xs)

Build a pytree with the structure of treedef from a nested iterable structure.

DEPRECATED: Use jax.tree.unflatten() instead.

Parameters:
  • treedef (PyTreeDef) – the PyTreeDef structure to build.

  • xs (Any) – nested iterables matching the arity of the treedef

Returns:

object with structure defined by treedef

Return type:

Any

Examples

>>> import jax
>>> tree = [(1, 2), {'a': 3, 'b': 4}]
>>> treedef = jax.tree.structure(tree)
>>> jax.tree.unflatten(treedef, [10, 11, 12, 13])
[(10, 11), {'a': 12, 'b': 13}]
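Because this function is deprecated, a migration sketch may be useful. The snippet below is illustrative rather than authoritative: it assumes a JAX version that provides the jax.tree namespace, and it assumes the nested input mirrors the container arity of the treedef (one inner iterable per child). The names nested, flat, and from_nested are hypothetical and chosen only for this example.

```python
import jax

# Template tree whose structure we want to reuse.
tree = [(1, 2), {'a': 3, 'b': 4}]
leaves, treedef = jax.tree.flatten(tree)

# jax.tree.unflatten expects a flat sequence with one entry per leaf.
new_leaves = [x * 10 for x in leaves]
rebuilt = jax.tree.unflatten(treedef, new_leaves)

# Assumption: data that was previously passed as nested iterables is
# flattened by hand here before being handed to jax.tree.unflatten.
nested = [[10, 11], [12, 13]]  # hypothetical nested input
flat = [leaf for group in nested for leaf in group]
from_nested = jax.tree.unflatten(treedef, flat)

print(rebuilt)
print(from_nested)
```

The main difference to keep in mind is the shape of the second argument: jax.tree.unflatten takes the leaves as a flat sequence in treedef order, so any nesting in the input has to be removed first.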