jax.tree module

Utilities for working with tree-like container data structures.

The jax.tree namespace contains aliases of utilities from jax.tree_util.
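Because jax.tree only re-exports these utilities, the same operation can usually be written through either namespace. The minimal sketch below applies `jax.tree.map` and `jax.tree_util.tree_map` to the same small pytree of Python ints and checks that the results agree. It compares results only and does not assume the two names refer to the same function object. The tree itself is an arbitrary example.

```python
import jax

tree = {"a": 1, "b": (2, 3)}

# The same operation through both namespaces.
via_tree = jax.tree.map(lambda x: x * 10, tree)
via_tree_util = jax.tree_util.tree_map(lambda x: x * 10, tree)

print(via_tree)                   # {'a': 10, 'b': (20, 30)}
print(via_tree == via_tree_util)  # True
```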

List of Functions

all(tree, *[, is_leaf])

Calls Python's built-in all() over the leaves of a tree.
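Since each leaf is tested for truthiness, passing multi-element arrays directly to all() would fail. The sketch below therefore reduces each array to a Python bool with `jax.tree.map` first and then combines those bools. The `params` dict and the "all entries positive" check are illustrative assumptions, not part of the API.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

# Turn each array into a single Python bool, then combine the bools.
is_positive = jax.tree.map(lambda a: bool((a > 0).all()), params)
print(jax.tree.all(is_positive))  # True

# A single falsy leaf makes the whole result False.
print(jax.tree.all([True, 0, 3]))  # False
```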

flatten(tree[, is_leaf])

Flattens a pytree.
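A short example of what flatten returns: a flat list of leaves plus a treedef that records the container structure. Note that dict entries are visited in sorted-key order and that None is treated as an empty subtree rather than a leaf. The tree is an arbitrary example.

```python
import jax

tree = {"b": [1, 2], "a": (3, None)}
leaves, treedef = jax.tree.flatten(tree)

# "a" comes before "b"; the None inside "a" contributes no leaf.
print(leaves)              # [3, 1, 2]
print(treedef.num_leaves)  # 3
```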

flatten_with_path(tree[, is_leaf])

Flattens a pytree like tree_flatten, but also returns each leaf's key path.

leaves(tree[, is_leaf])

Gets the leaves of a pytree.
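A common use of leaves is aggregating over every array in a parameter tree. The sketch below counts parameters in a hypothetical `params` dict. It also shows how `is_leaf` can stop traversal early so that a whole subtree (here, any tuple) is returned as a single leaf.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# Sum the element counts of all array leaves.
n_params = sum(x.size for x in jax.tree.leaves(params))
print(n_params)  # 9

# is_leaf keeps every tuple intact instead of descending into it.
kept = jax.tree.leaves([1, (2, 3)], is_leaf=lambda x: isinstance(x, tuple))
print(kept)  # [1, (2, 3)]
```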

leaves_with_path(tree[, is_leaf])

Gets the leaves of a pytree like tree_leaves, but also returns each leaf's key path.

map(f, tree, *rest[, is_leaf])

Maps a multi-input function over pytree args to produce a new pytree.
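With several tree arguments, map calls `f` with one leaf from each tree, matched by position, so the extra trees need a matching structure. The sketch below shows a hypothetical SGD-style update. The `params`, `grads`, and `lr` names are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(1.0)}
lr = 0.1  # hypothetical learning rate

# f receives (param_leaf, grad_leaf) pairs taken from the same positions.
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
print(new_params["w"])  # approximately [0.99, 1.98]
print(new_params["b"])  # approximately 0.4
```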

map_with_path(f, tree, *rest[, is_leaf])

Maps a multi-input function over pytree key paths and args to produce a new pytree.

reduce(function, tree[, initializer, is_leaf])

Calls functools.reduce() over the leaves of a tree.
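reduce folds a two-argument function over the leaves in flattening order, and an optional initializer can be passed as the third argument. The sketch below computes a sum and a maximum over a small tree of ints. The tree is an arbitrary example.

```python
import operator

import jax

tree = {"a": 1, "b": [2, 3]}

total = jax.tree.reduce(operator.add, tree)  # (1 + 2) + 3
largest = jax.tree.reduce(max, tree, 0)      # starts from the initializer 0
print(total, largest)  # 6 3
```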

structure(tree[, is_leaf])

Gets the treedef for a pytree.
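A treedef records container types and keys, not leaf values or array shapes. This makes comparing treedefs a simple way to check whether two pytrees are structurally compatible. The three dicts below are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

a = {"w": jnp.ones(2), "b": 0.0}
b = {"w": jnp.zeros(5), "b": 1.0}  # different values and shapes, same structure
c = {"w": jnp.ones(2)}             # missing key "b"

print(jax.tree.structure(a) == jax.tree.structure(b))  # True
print(jax.tree.structure(a) == jax.tree.structure(c))  # False
print(jax.tree.structure(a).num_leaves)                # 2
```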

transpose(outer_treedef, inner_treedef, ...)

Transforms a tree with structure (outer, inner) into one with structure (inner, outer).
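A typical use of transpose is turning a list of records into a record of lists. In the sketch below, the outer structure is a two-element list and the inner structure is a dict with keys "x" and "y". Both treedefs are built with `jax.tree.structure` from placeholder trees, and the pytree to transpose is passed as the third argument. The record data is an illustrative assumption.

```python
import jax

records = [{"x": 1, "y": 2}, {"x": 3, "y": 4}]

# Placeholder trees describe the structures; their leaf values are ignored.
outer = jax.tree.structure([0, 0])
inner = jax.tree.structure({"x": 0, "y": 0})

columns = jax.tree.transpose(outer, inner, records)
print(columns)  # {'x': [1, 3], 'y': [2, 4]}
```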

unflatten(treedef, leaves)

Reconstructs a pytree from the treedef and the leaves.
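unflatten is the inverse of flatten: given a treedef and a list containing exactly `treedef.num_leaves` values, it rebuilds the same container structure around those values. The sketch below doubles every leaf of an arbitrary example tree and checks the round trip.

```python
import jax

tree = {"a": (1, 2), "b": [3]}
leaves, treedef = jax.tree.flatten(tree)

# Rebuild the original containers around new leaf values.
doubled = jax.tree.unflatten(treedef, [x * 2 for x in leaves])
print(doubled)  # {'a': (2, 4), 'b': [6]}

# Unflattening the unchanged leaves recovers the original tree.
print(jax.tree.unflatten(treedef, leaves) == tree)  # True
```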