jax.tree_util.treedef_children
- jax.tree_util.treedef_children(treedef)
Return a list of treedefs for the immediate children of treedef.
- Parameters:
treedef (PyTreeDef) – a single PyTreeDef
- Returns:
a list of PyTreeDefs representing the children of treedef.
- Return type:
list[PyTreeDef]
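Because the function returns only the immediate children, a full traversal of a tree structure has to recurse. The sketch below does this twice, once to count leaves and once to measure how deeply containers are nested. It assumes that a leaf treedef and a childless node such as None or () both return an empty list, and that `treedef.num_leaves` is 1 for a leaf and 0 for a childless node. The helper names `count_leaves` and `nesting_depth` are hypothetical, not part of JAX.

```python
import jax


def count_leaves(treedef) -> int:
    """Hypothetical helper: count leaves by recursing through immediate children."""
    children = jax.tree_util.treedef_children(treedef)
    if not children:
        # No children: either a leaf (num_leaves == 1) or a childless node
        # such as None or () (num_leaves == 0), so use the treedef's own count.
        return treedef.num_leaves
    return sum(count_leaves(child) for child in children)


def nesting_depth(treedef) -> int:
    """Hypothetical helper: number of container levels above the deepest node."""
    children = jax.tree_util.treedef_children(treedef)
    if not children:
        return 0
    return 1 + max(nesting_depth(child) for child in children)


x = [(1, 2), 3, {'a': 4}, None]
treedef = jax.tree.structure(x)
print(count_leaves(treedef), treedef.num_leaves)
print(nesting_depth(treedef))
```

The recursive count should agree with `treedef.num_leaves`; comparing the two is a quick way to check that the traversal visits every subtree.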
Examples
>>> import jax
>>> x = [(1, 2), 3, {'a': 4}]
>>> treedef = jax.tree.structure(x)
>>> jax.tree_util.treedef_children(treedef)
[PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})]
>>> _ == [jax.tree.structure(vals) for vals in x]
True
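Each child treedef carries its own `num_leaves`, so the children can be used to cut a flat leaf list into per-child slices and rebuild each immediate subtree. The sketch below assumes that flattening is depth-first, meaning each child's leaves form a contiguous run in the same order that `treedef_children` lists the children. The helper name `split_by_child` is hypothetical, not a JAX API.

```python
import jax


def split_by_child(tree):
    """Hypothetical helper: rebuild the immediate subtrees of `tree` from its leaves."""
    treedef = jax.tree.structure(tree)
    leaves = jax.tree.leaves(tree)
    subtrees, start = [], 0
    for child in jax.tree_util.treedef_children(treedef):
        n = child.num_leaves  # number of leaves owned by this child
        subtrees.append(jax.tree.unflatten(child, leaves[start:start + n]))
        start += n
    return subtrees


x = [(1, 2), 3, {'a': 4}]
print(split_by_child(x))  # [(1, 2), 3, {'a': 4}]
```

For a dict, the children follow JAX's sorted-key order rather than insertion order, so `split_by_child({'b': [5], 'a': 6})` yields the value under `'a'` first.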
See also