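jax.tree_util.flatten_one_level_with_keys

jax.tree_util.flatten_one_level_with_keys(tree)

Flatten the given pytree node by one level, with keys. The result pairs each immediate child of the node with its key entry and includes the node's auxiliary data. Nested subtrees are not expanded further.

Parameters:
tree (Any) – the pytree node to flatten.

Return type:
tuple[Iterable[KeyLeafPair], Hashable]
(each KeyLeafPair is a (key, child) tuple; the Hashable is the node's auxiliary data)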