jax.tree module
Utilities for working with tree-like container data structures.
The `jax.tree` namespace contains aliases of utilities from `jax.tree_util`.
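The following minimal sketch shows that the `jax.tree` names compute the same results as their `jax.tree_util` counterparts. It assumes a JAX release recent enough to expose these functions under `jax.tree`. The pytree and variable names are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Illustrative pytree: a dict holding a tuple of ints and an array.
tree = {"a": (1, 2), "b": jnp.arange(3)}

# jax.tree.leaves corresponds to jax.tree_util.tree_leaves.
leaves_new = jax.tree.leaves(tree)
leaves_old = tree_util.tree_leaves(tree)

# jax.tree.map corresponds to jax.tree_util.tree_map.
doubled_new = jax.tree.map(lambda x: x * 2, tree)
doubled_old = tree_util.tree_map(lambda x: x * 2, tree)
```

Both spellings accept the same arguments in this example, so `jax.tree.map` reads as a shorter spelling of `jax.tree_util.tree_map`.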
List of Functions
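| Function | Description |
|---|---|
| `all` | Calls `all()` over the leaves of a tree. |
| `flatten` | Flattens a pytree. |
| `flatten_with_path` | Flattens a pytree like `flatten`, but also returns each leaf's key path. |
| `leaves` | Gets the leaves of a pytree. |
| `leaves_with_path` | Gets the leaves of a pytree like `leaves`, and also returns each leaf's key path. |
| `map` | Maps a multi-input function over pytree args to produce a new pytree. |
| `map_with_path` | Maps a multi-input function over pytree key path and args to produce a new pytree. |
| `reduce` | Calls `reduce()` over the leaves of a tree. |
| `structure` | Gets the treedef for a pytree. |
| `transpose` | Transforms a tree having tree structure (outer, inner) into one having structure (inner, outer). |
| `unflatten` | Reconstructs a pytree from the treedef and the leaves. |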
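A minimal sketch of the core functions in the list above, assuming a recent JAX release. The pytree `params` and the other variable names are hypothetical and chosen only for illustration. Dict keys are visited in sorted order when flattening, which determines the leaf order.

```python
import operator

import jax
import jax.numpy as jnp

# Illustrative pytree: a dict of arrays plus a tuple of Python ints.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2), "meta": (3, 4)}

# flatten splits the tree into a flat list of leaves and a treedef.
# Dict keys are visited in sorted order, so the leaves are b, 3, 4, w.
leaves, treedef = jax.tree.flatten(params)

# unflatten reverses flatten, rebuilding a tree with the same structure.
rebuilt = jax.tree.unflatten(treedef, leaves)

# structure returns only the treedef; leaves returns only the leaves.
same_structure = jax.tree.structure(params) == treedef
n_leaves = len(jax.tree.leaves(params))

# map applies a function leaf-wise; with several trees of matching
# structure, corresponding leaves are passed to the function together.
summed = jax.tree.map(lambda x, y: x + y, params, params)

# reduce folds a binary function over the leaves (here: total element count).
sizes = jax.tree.map(lambda x: jnp.asarray(x).size, params)
total_elems = jax.tree.reduce(operator.add, sizes)

# all applies Python's all() to the leaves; the leaves here are plain bools,
# because calling bool() on a multi-element array would raise an error.
all_nonneg = jax.tree.all(
    jax.tree.map(lambda x: bool(jnp.all(jnp.asarray(x) >= 0)), params)
)
```

Mapping each array leaf to a plain Python value before calling `reduce` or `all` keeps those functions operating on scalars rather than on arrays.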