jax.tree_util.tree_unflatten

jax.tree_util.tree_unflatten(treedef, leaves)

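Alias of jax.tree.unflatten(). Reconstructs a pytree from a treedef and a sequence of leaves; it is the inverse of jax.tree_util.tree_flatten().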
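A minimal round-trip sketch: flatten a small pytree, transform its leaves, and rebuild a tree with the original structure. The example tree and the "multiply by 10" transformation are illustrative assumptions. The final comparison with jax.tree.unflatten assumes a JAX version that provides the jax.tree namespace.

```python
import jax

# An illustrative pytree: a dict whose "b" entry is a tuple.
tree = {"a": 1, "b": (2, 3)}

# Split the tree into its leaves and its structure (treedef).
# Dict entries are ordered by sorted key, so the leaves are [1, 2, 3].
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Transform the leaves, then place them back into the same structure.
new_leaves = [x * 10 for x in leaves]
rebuilt = jax.tree_util.tree_unflatten(treedef, new_leaves)
print(rebuilt)

# The alias should produce the same result.
same = jax.tree.unflatten(treedef, new_leaves)
```

The container types are preserved: "b" comes back as a tuple, not a list, because that information lives in treedef rather than in the leaves.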

Parameters:
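  • treedef (PyTreeDef) – the tree structure to reconstruct.

  • leaves (Iterable[Leaf]) – the leaves to place into that structure; it must supply exactly as many leaves as treedef describes.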

Return type:

Any
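Because the return type is Any, the shape of the result is determined entirely by treedef. The sketch below builds a zero-filled copy of an example parameter tree, using jax.tree_util.tree_structure to obtain the treedef and PyTreeDef.num_leaves to size the leaves list. The params dictionary and its values are hypothetical.

```python
import jax

# A hypothetical parameter tree.
params = {"w": [1.0, 2.0], "b": 0.5}

# Obtain only the structure; num_leaves counts the leaves it expects (3 here).
treedef = jax.tree_util.tree_structure(params)

# Supply exactly treedef.num_leaves values to rebuild a same-shaped tree.
zeros = jax.tree_util.tree_unflatten(treedef, [0.0] * treedef.num_leaves)
print(zeros)
```

Sizing the list from num_leaves keeps the leaf count in sync with the structure; supplying too few or too many leaves raises an error instead of producing a partial tree.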