jax.tree_util.tree_unflatten
- jax.tree_util.tree_unflatten(treedef, leaves)
Alias of jax.tree.unflatten().
- Parameters:
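treedef (PyTreeDef): the tree structure to rebuild, typically the second value returned by jax.tree_util.tree_flatten()
leaves (Iterable[Leaf]): the leaves to place into that structure, in the order tree_flatten() produces them; their number must equal treedef.num_leaves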
- Return type:
Any
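A minimal sketch of the usual flatten, transform, unflatten round trip. The example pytree and the doubling step are illustrative assumptions, not part of the API; the point is that tree_unflatten rebuilds the original structure from a PyTreeDef and a new list of leaves given in the same order.

```python
import jax
import jax.numpy as jnp

# An illustrative pytree: a dict holding a tuple of Python ints and an array.
tree = {"a": (1, 2), "b": jnp.arange(3)}

# Flatten into a list of leaves plus a PyTreeDef describing the structure.
# Dict keys are visited in sorted order, so leaves == [1, 2, Array([0, 1, 2])].
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Transform each leaf (here: double it), keeping the leaf order unchanged.
new_leaves = [x * 2 for x in leaves]

# Rebuild a pytree with the original structure from the new leaves.
new_tree = jax.tree_util.tree_unflatten(treedef, new_leaves)
# new_tree == {"a": (2, 4), "b": Array([0, 2, 4])}
```

Because tree_unflatten only fills the structure, the leaves may be any Python objects; the list just has to contain exactly treedef.num_leaves items.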