jax.tree_util.register_dataclass
- jax.tree_util.register_dataclass(nodetype, data_fields=None, meta_fields=None, drop_fields=())
Extends the set of types that are considered internal nodes in pytrees.
This differs from register_pytree_with_keys_class in that the C++ registries use the optimized C++ dataclass builtin instead of the argument functions.
See Extending pytrees for more information about registering pytrees.
- Parameters:
  - nodetype (Typ) – a Python type to treat as an internal pytree node. This is assumed to have the semantics of a dataclass: namely, class attributes represent the whole of the object state, and can be passed as keywords to the class constructor to create a copy of the object. All defined attributes should be listed among meta_fields or data_fields.
  - meta_fields (Sequence[str] | None) – metadata field names: these are attributes which will be treated as static when this pytree is passed to jax.jit(). meta_fields is optional only if nodetype is a dataclass, in which case individual fields can be marked static via dataclasses.field() (see examples below). Metadata fields must be static, hashable, immutable objects, as these objects are used to generate JIT cache keys. In particular, metadata fields cannot contain jax.Array or numpy.ndarray objects.
  - data_fields (Sequence[str] | None) – data field names: these are attributes which will be treated as non-static when this pytree is passed to jax.jit(). data_fields is optional only if nodetype is a dataclass, in which case fields are assumed data fields unless marked via dataclasses.field() (see examples below). Data fields must be JAX-compatible objects such as arrays (jax.Array or numpy.ndarray), scalars, or pytrees whose leaves are arrays or scalars. Note that None is a valid data field, as JAX recognizes this as an empty pytree.
  - drop_fields (Sequence[str])
- Returns:
  The input class nodetype is returned unchanged after being added to JAX's pytree registry, so that register_dataclass() can be used as a decorator.
- Return type:
  Typ
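Two of the points above can be checked directly: the class is returned unchanged, and None is accepted as a data field because it is an empty pytree. The following is a minimal sketch, not taken from the JAX documentation; the `Affine` class and its field names are hypothetical, and the registration is written as a plain function call with explicit field lists rather than as a decorator.

```python
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class Affine:  # hypothetical example class, not part of JAX
    weight: jax.Array
    bias: Optional[jax.Array]
    name: str


# Plain function call instead of decorator syntax; the same class comes back.
registered = jax.tree_util.register_dataclass(
    Affine, data_fields=["weight", "bias"], meta_fields=["name"]
)

# bias=None is an empty pytree, so it contributes no leaves.
layer = Affine(weight=jnp.ones((2, 2)), bias=None, name="dense")
leaves, treedef = jax.tree.flatten(layer)

# Unflattening restores the None field and the static name.
restored = jax.tree.unflatten(treedef, leaves)
```

Here `registered is Affine` holds, `leaves` contains only the weight array, and `restored.bias` is None again after the round trip.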
Examples
In JAX v0.4.35 or older, you must specify data_fields and meta_fields in order to use this decorator:
>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass
>>> from functools import partial
...
>>> @partial(jax.tree_util.register_dataclass,
...          data_fields=['x', 'y'],
...          meta_fields=['op'])
... @dataclass
... class MyStruct:
...     x: jax.Array
...     y: jax.Array
...     op: str
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
Starting in JAX v0.4.36, the data_fields and meta_fields arguments are optional for dataclass() inputs, with fields defaulting to data_fields unless marked as static using static metadata in dataclasses.field().
>>> import jax
>>> import jax.numpy as jnp
>>> from dataclasses import dataclass, field
...
>>> @jax.tree_util.register_dataclass
... @dataclass
... class MyStruct:
...     x: jax.Array  # defaults to non-static data field
...     y: jax.Array  # defaults to non-static data field
...     op: str = field(metadata=dict(static=True))  # marked as static meta field.
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
Once this class is registered, it can be used with functions in jax.tree and jax.tree_util:
>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
[Array([1., 1., 1.], dtype=float32), Array([0, 1, 2], dtype=int32)]
>>> treedef
PyTreeDef(CustomNode(MyStruct[('add',)], [*, *]))
>>> jax.tree.unflatten(treedef, leaves)
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
In particular, this registration allows m to be passed seamlessly through code wrapped in jax.jit() and other JAX transformations, with data_fields being treated as dynamic arguments, and meta_fields being treated as static arguments:
>>> @jax.jit
... def compiled_func(m):
...     if m.op == 'add':
...         return m.x + m.y
...     else:
...         raise ValueError(f"{m.op=}")
...
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)