jax.tree_util.flatten_one_level_with_keys

jax.tree_util.flatten_one_level_with_keys(tree)

Flatten the given pytree node by one level, returning each immediate child paired with its key, along with the node's auxiliary data.

Parameters:

tree (Any)

Return type:

tuple[Iterable[KeyLeafPair], Hashable]