jax.tree_util.treedef_children

jax.tree_util.treedef_children(treedef)

Return a list of treedefs for the immediate children of treedef.

Parameters:

treedef (PyTreeDef) – a single PyTreeDef

Returns:

a list of PyTreeDefs representing the children of treedef.

Return type:

list[PyTreeDef]

Examples

>>> import jax
>>> x = [(1, 2), 3, {'a': 4}]
>>> treedef = jax.tree.structure(x)
>>> jax.tree_util.treedef_children(treedef)
[PyTreeDef((*, *)), PyTreeDef(*), PyTreeDef({'a': *})]
>>> _ == [jax.tree.structure(vals) for vals in x]
True
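
As a further illustration, the child treedefs can be used to split a flat list of leaves back into one subtree per immediate child. This is a minimal sketch rather than a canonical recipe: the names `tree`, `subtrees`, and `offset` are illustrative, and it assumes the leaves returned by `jax.tree_util.tree_flatten` are ordered child by child, which is how flattening behaves for the built-in containers used here.

```python
import jax

# Illustrative input: a list with a tuple, a bare leaf, and a dict as children.
tree = [(1, 2), 3, {'a': 4}]
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Walk the immediate children, consuming each child's share of the flat leaves.
subtrees = []
offset = 0
for child in jax.tree_util.treedef_children(treedef):
    n = child.num_leaves  # number of leaves belonging to this child
    subtrees.append(jax.tree_util.tree_unflatten(child, leaves[offset:offset + n]))
    offset += n

# Every leaf of the parent is accounted for by exactly one child.
assert offset == treedef.num_leaves
print(subtrees)
```

Here each entry of `subtrees` is rebuilt from its own child treedef, so the result has the same structure as the corresponding element of the original list.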