jax.tree_util.flatten_one_level
- jax.tree_util.flatten_one_level(tree)
Flatten the given pytree node by one level.
- Parameters:
tree (Any) – A valid pytree node, either built-in or registered via
register_pytree_node() or related functions.
- Returns:
A pair consisting of the node's immediate (one-level) children and its hashable metadata.
- Raises:
ValueError – If the given object is a leaf, i.e. not a built-in container and not a container
registered via register_pytree_node() or register_pytree_with_keys().
- Return type:
tuple[Iterable[Any], Hashable]
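As a rough illustration of the contract described by the Parameters, Returns, and Raises entries, the sketch below registers a hypothetical Point class with register_pytree_node() and flattens one instance by a single level, then shows a plain integer leaf being rejected with ValueError. Point, its fields, and the two helper functions are made up for this example; the import fallback simply mirrors the private module path used in the Examples section, in case the public name is unavailable in a given JAX release.

```python
from jax.tree_util import register_pytree_node

try:
    from jax.tree_util import flatten_one_level
except ImportError:
    # Fall back to the module path used in the Examples section.
    from jax._src.tree_util import flatten_one_level


class Point:
    """Hypothetical container used only for this illustration."""

    def __init__(self, x, y, label):
        self.x = x
        self.y = y
        self.label = label  # static, hashable metadata rather than a child


def point_flatten(p):
    # Return (children, aux_data); aux_data should be hashable.
    return (p.x, p.y), p.label


def point_unflatten(label, children):
    # Rebuild a Point from the aux_data and the children.
    x, y = children
    return Point(x, y, label)


register_pytree_node(Point, point_flatten, point_unflatten)

# One level only: the list [1, 2] stays intact as a single child.
children, meta = flatten_one_level(Point([1, 2], 3, "origin"))
print(list(children), meta)

# A plain int is a leaf, not a container, so it is rejected.
rejected = False
try:
    flatten_one_level(42)
except ValueError as err:
    rejected = True
    print("ValueError:", err)
```

Note that label comes back as metadata rather than as a child, so any further flattening never visits it.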
Examples
>>> import jax
>>> from jax._src.tree_util import flatten_one_level
>>> flattened, meta = flatten_one_level({'a': [1, 2], 'b': {'c': 3}})
>>> flattened
([1, 2], {'c': 3})
>>> meta
('a', 'b')
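Because flatten_one_level stops after a single level, applying it repeatedly reproduces full flattening. The sketch below uses a hypothetical recursive helper, collect_leaves, that treats anything raising ValueError as a leaf and compares the result with jax.tree_util.tree_leaves on the same input as the example above. It is a simplification for illustration (it ignores options such as is_leaf) and not a description of how tree_leaves is implemented.

```python
import jax

try:
    from jax.tree_util import flatten_one_level
except ImportError:
    # Fall back to the module path used in the Examples section.
    from jax._src.tree_util import flatten_one_level


def collect_leaves(tree):
    """Hypothetical helper: apply flatten_one_level until only leaves remain."""
    try:
        children, _meta = flatten_one_level(tree)
    except ValueError:
        # Not a built-in or registered container: treat it as a leaf.
        return [tree]
    leaves = []
    for child in children:
        # Recurse into each immediate child, preserving order.
        leaves.extend(collect_leaves(child))
    return leaves


tree = {'a': [1, 2], 'b': {'c': 3}}
print(collect_leaves(tree))             # [1, 2, 3]
print(jax.tree_util.tree_leaves(tree))  # [1, 2, 3]
```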