jax.tree_util.register_pytree_node

jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func, flatten_with_keys_func=None)

Extends the set of types that are considered internal nodes in pytrees.

See the Examples section below for usage.

Parameters:
  • nodetype (type[T]) – a Python type to register as a pytree.

  • flatten_func (Callable[[T], tuple[_Children, _AuxData]]) – a function to be used during flattening, taking a value of type nodetype and returning a pair, with (1) an iterable for the children to be flattened recursively, and (2) some hashable auxiliary data to be stored in the treedef and to be passed to the unflatten_func.

  • unflatten_func (Callable[[_AuxData, _Children], T]) – a function taking two arguments: the auxiliary data that was returned by flatten_func and stored in the treedef, and the unflattened children. The function should return an instance of nodetype.

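  • flatten_with_keys_func (Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None) – an optional function like flatten_func, but returning an iterable of (key, child) pairs in place of the plain children, together with the same auxiliary data. The keys describe where each child sits in the node and are used by path-aware utilities such as jax.tree_util.tree_flatten_with_path.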

Return type:

None


Examples

First we’ll define a custom type:

>>> import jax
>>> import jax.numpy as jnp
>>> class MyContainer:
...   def __init__(self, size):
...     self.x = jnp.zeros(size)
...     self.y = jnp.ones(size)
...     self.size = size

If we try using this in a JIT-compiled function, we’ll get an error because JAX does not yet know how to handle this type:

>>> m = MyContainer(size=5)
>>> def f(m):
...   return m.x + m.y + jnp.arange(m.size)
>>> jax.jit(f)(m)  
Traceback (most recent call last):
  ...
TypeError: Cannot interpret value of type <class 'jax.tree_util.MyContainer'> as an abstract array; it does not have a dtype attribute

To make JAX recognize our type, we register it as a pytree:

>>> def flatten_func(obj):
...   children = (obj.x, obj.y)  # children must contain arrays & pytrees
...   aux_data = (obj.size,)  # aux_data must contain static, hashable data.
...   return (children, aux_data)
...
>>> def unflatten_func(aux_data, children):
...   # Here we avoid `__init__` because it has extra logic we don't require:
...   obj = object.__new__(MyContainer)
...   obj.x, obj.y = children
...   obj.size, = aux_data
...   return obj
...
>>> jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)
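
Registration makes MyContainer usable with the generic tree utilities, not only with jax.jit. The sketch below round-trips an instance through jax.tree_util.tree_flatten and jax.tree_util.tree_unflatten, then applies jax.tree_util.tree_map to its children. It repeats the definitions above so that it can run on its own in a fresh interpreter:

```python
import jax
import jax.numpy as jnp


# Repeats the MyContainer definition and registration from the example above.
class MyContainer:
    def __init__(self, size):
        self.x = jnp.zeros(size)
        self.y = jnp.ones(size)
        self.size = size


def flatten_func(obj):
    return (obj.x, obj.y), (obj.size,)


def unflatten_func(aux_data, children):
    obj = object.__new__(MyContainer)
    obj.x, obj.y = children
    obj.size, = aux_data
    return obj


jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)

m = MyContainer(size=5)

# Flattening yields only the array children; size is kept in the treedef.
leaves, treedef = jax.tree_util.tree_flatten(m)
print(len(leaves))

# Unflattening calls unflatten_func with the stored aux data and the leaves.
m2 = jax.tree_util.tree_unflatten(treedef, leaves)
print(type(m2).__name__, m2.size)

# tree_map transforms the array children and carries the aux data through unchanged.
doubled = jax.tree_util.tree_map(lambda a: a * 2 + 1, m)
print(doubled.x, doubled.y, doubled.size)
```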

Now with this defined, we can use instances of this type in JIT-compiled functions.

>>> jax.jit(f)(m)
Array([1., 2., 3., 4., 5.], dtype=float32)
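
Because size is returned as auxiliary data, it is part of the treedef rather than a traced value. This is why f can pass it to jnp.arange as a plain Python int. A sketch, assuming jax.jit's usual behaviour of caching a trace per input tree structure and shapes: trace_count is a hypothetical counter that increments only while f is being traced. Comparing treedefs isolates the role of the aux data, since a treedef holds no array shapes:

```python
import jax
import jax.numpy as jnp


# Repeats the MyContainer definition and registration from the example above.
class MyContainer:
    def __init__(self, size):
        self.x = jnp.zeros(size)
        self.y = jnp.ones(size)
        self.size = size


def flatten_func(obj):
    return (obj.x, obj.y), (obj.size,)


def unflatten_func(aux_data, children):
    obj = object.__new__(MyContainer)
    obj.x, obj.y = children
    obj.size, = aux_data
    return obj


jax.tree_util.register_pytree_node(MyContainer, flatten_func, unflatten_func)

trace_count = 0  # hypothetical counter, for illustration only


@jax.jit
def f(m):
    global trace_count
    trace_count += 1  # Python side effect: runs only while f is traced
    return m.x + m.y + jnp.arange(m.size)


f(MyContainer(size=5))
f(MyContainer(size=5))  # equal treedef and shapes: the cached trace is reused
f(MyContainer(size=3))  # different aux data (and here shapes too): traced again
print(trace_count)

# Treedefs compare equal only when their aux data compare equal.
same = jax.tree_util.tree_structure(MyContainer(5)) == jax.tree_util.tree_structure(MyContainer(5))
diff = jax.tree_util.tree_structure(MyContainer(5)) == jax.tree_util.tree_structure(MyContainer(3))
print(same, diff)
```

This is also why aux_data must be hashable: it takes part in the comparisons that decide whether a cached trace can be reused.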