jax.tree_util.tree_structure
- jax.tree_util.tree_structure(tree, is_leaf=None)
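Alias of jax.tree.structure().
- Parameters:
tree (Any) – the pytree whose structure is returned.
is_leaf (Callable[[Any], bool] | None) – optional predicate called on each subtree; wherever it returns True, that subtree is treated as a leaf and is not traversed further.
- Return type:
PyTreeDef – the container structure of tree, with every leaf replaced by a placeholder.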
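A minimal sketch of typical use, assuming a recent JAX release in which `jax.tree.structure` is available. The variable names (`tree`, `treedef`, `tuple_as_leaf`) are illustrative only. The example shows that the result captures only the container skeleton, that it matches the `jax.tree.structure` spelling, that it can be refilled with `jax.tree_util.tree_unflatten`, and how `is_leaf` stops traversal early.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Illustrative nested container: a dict holding two arrays and a tuple of scalars.
tree = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3), "meta": (1, 2.0)}

# tree_structure returns only the container skeleton (a PyTreeDef), not the leaves.
treedef = tree_util.tree_structure(tree)
print(treedef)             # dict children are ordered by sorted key
print(treedef.num_leaves)  # number of leaf slots in the skeleton

# The tree_util alias and the jax.tree spelling produce equal PyTreeDefs.
assert treedef == jax.tree.structure(tree)

# Leaf values play no role: any tree with the same containers has the same structure.
same_shape = tree_util.tree_structure({"w": 0, "b": 0, "meta": (0, 0)}) == treedef

# A PyTreeDef can be refilled with new leaves, supplied in flattening order
# (here: b, meta[0], meta[1], w).
rebuilt = tree_util.tree_unflatten(treedef, list(range(treedef.num_leaves)))

# is_leaf stops traversal early: here every tuple is treated as a single leaf.
tuple_as_leaf = tree_util.tree_structure(tree, is_leaf=lambda x: isinstance(x, tuple))
print(tuple_as_leaf.num_leaves)
```

Because only the skeleton is kept, comparing two PyTreeDefs is a cheap way to check that two pytrees are structurally compatible before combining them leaf by leaf.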