jax.tree_util.register_static
- jax.tree_util.register_static(cls)
Registers cls as a pytree with no leaves.
Instances are treated as static by jax.jit(), jax.pmap(), etc. This can be an alternative to labeling inputs as static using jit’s static_argnums and static_argnames kwargs, pmap’s static_broadcasted_argnums, etc.
- Parameters:
cls (type[H]) – type to be registered as static. Must be hashable, as defined in https://docs.python.org/3/glossary.html#term-hashable.
- Returns:
The input class cls is returned unchanged after being added to JAX’s pytree registry. This allows register_static to be used as a decorator.
- Return type:
type[H]
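The hashability requirement and the decorator usage described above can be sketched with a frozen dataclass. This is a minimal illustration, not part of the official documentation: ScaleConfig, its fields, and the function apply are hypothetical names, and the sketch assumes that a frozen dataclass (which is hashable by default) is an acceptable class to register.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical config class for illustration. frozen=True (with the default
# eq=True) makes the dataclass hashable, which register_static requires.
@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True)
class ScaleConfig:
    mode: str
    factor: float


@jax.jit
def apply(x, cfg):
    # cfg contributes no leaves, so its fields are plain Python values
    # at trace time and can drive ordinary Python control flow.
    if cfg.mode == "mul":
        return x * cfg.factor
    return x / cfg.factor


doubled = apply(jnp.arange(3.0), ScaleConfig("mul", 2.0))
halved = apply(jnp.arange(3.0), ScaleConfig("div", 2.0))
```

Because the decorator returns the class unchanged, ScaleConfig can still be constructed and compared like any other dataclass outside of jitted code.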
Examples
>>> import jax
>>> @jax.tree_util.register_static
... class StaticStr(str):
...   pass
This static string can now be used directly in jax.jit()-compiled functions, without marking the variable static using static_argnums:
>>> @jax.jit
... def f(x, y, s):
...   return x + y if s == 'add' else x - y
...
>>> f(1, 2, StaticStr('add'))
Array(3, dtype=int32, weak_type=True)
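To make "a pytree with no leaves" more concrete, the following sketch flattens a registered instance and then checks how often a jitted function is traced. The trace_log list is a hypothetical helper added only for illustration. The sketch also assumes that jit reuses a cached trace when the static value compares equal and retraces when it differs, which is how static arguments behave in general.

```python
import jax


@jax.tree_util.register_static
class StaticStr(str):
    pass


s = StaticStr("add")

# Flattening yields no leaves: the object is carried in the tree
# structure (treedef) rather than in the data.
leaves, treedef = jax.tree_util.tree_flatten(s)
restored = jax.tree_util.tree_unflatten(treedef, leaves)

trace_log = []  # hypothetical helper: appended to only while tracing


@jax.jit
def f(x, y, s):
    trace_log.append(str(s))  # Python side effect, runs once per trace
    return x + y if s == "add" else x - y


f(1, 2, StaticStr("add"))  # first call: traces with s == 'add'
f(5, 6, StaticStr("add"))  # equal static value: no new trace expected
f(1, 2, StaticStr("sub"))  # different static value: traces again
```

As with static_argnums, each distinct static value leads to its own trace and compilation, so registering a class as static suits values that take few distinct settings.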