jax.tree_util.flatten_one_level

jax.tree_util.flatten_one_level(tree)

Flatten the given pytree node by one level.

Parameters:

tree (Any) – A valid pytree node, either built-in or registered via register_pytree_node() or related functions.

Returns:

A pair consisting of the pytree's flattened children and its hashable metadata.

Raises:
  • ValueError – If the given pytree is not a built-in container or a container registered via register_pytree_node or register_pytree_with_keys.

Return type:

tuple[Iterable[Any], Hashable]

Examples

>>> import jax
>>> from jax.tree_util import flatten_one_level
>>> flattened, meta = flatten_one_level({'a': [1, 2], 'b': {'c': 3}})
>>> flattened
([1, 2], {'c': 3})
>>> meta
('a', 'b')