jax.extend.linear_util.WrappedFun
- class jax.extend.linear_util.WrappedFun(f, f_transformed, transforms, stores, params, in_type, debug_info)
Represents a function f to which transforms are to be applied.
- Parameters:
f (Callable) – the function to be transformed.
f_transformed (Callable) – the transformed function, i.e. f with the transforms applied; this is the function that call_wrapped calls.
transforms – a list of (gen, gen_static_args) tuples representing the transformations to apply to f. Here gen is a generator function and gen_static_args is a tuple of static arguments for the generator. See the description at the start of this module for the expected behavior of the generator.
stores (tuple[Store | EqualStore | None, ...]) – one out_store (or None) per transform, holding that transform's auxiliary output.
params – extra parameters to pass as keyword arguments to f, along with the transformed keyword arguments.
debug_info (DebugInfo) – debugging info about the function being wrapped.
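The module-level description of the generator protocol is not reproduced on this page. The sketch below assumes the following convention (check the linear_util module docstring of your JAX version): the generator is called with gen_static_args followed by the call's arguments, first yields the (possibly rewritten) (args, kwargs) to pass on, receives the inner function's result via send, and then yields the final result, paired with auxiliary output when the transform has a store. The names scale_outputs and f are hypothetical; this illustrates the protocol and is not JAX code.

```python
# Hypothetical generator transform following the assumed protocol:
# gen_static_args come first (here: factor), then the call's arguments.
def scale_outputs(factor, *args, **kwargs):
    # First yield: pass the inputs on unchanged and receive f's result.
    ans = yield args, kwargs
    # Second yield: the final output, paired with auxiliary output for a store.
    yield ans * factor, {"unscaled": ans}


def f(x, y):
    return x + y


gen = scale_outputs(10, 1, 2)        # gen_static_args=(10,), call args=(1, 2)
args, kwargs = next(gen)             # -> ((1, 2), {})
out, aux = gen.send(f(*args, **kwargs))
print(out, aux)                      # prints: 30 {'unscaled': 3}
```

Driving the generator by hand like this makes the two-phase structure visible: one step before f runs (inputs) and one after (outputs plus aux).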
- __init__(f, f_transformed, transforms, stores, params, in_type, debug_info)
- Parameters:
f (Callable)
f_transformed (Callable)
stores (tuple[Store | EqualStore | None, ...])
debug_info (DebugInfo)
Methods
__init__(f, f_transformed, transforms, ...)
call_wrapped(*args, **kwargs) – Calls the transformed function.
populate_stores(stores) – Copy the values from the stores into self.stores.
wrap(gen, gen_static_args, out_store) – Add another transform and its store.
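To make the interplay of transforms, stores, wrap, call_wrapped, and populate_stores concrete, here is a toy pure-Python model. ToyWrappedFun, ToyStore, record_inputs, double_output, and add are all hypothetical names, and the control flow assumes the generator protocol described for transforms (static args first, one yield to pass inputs on, one yield for the result plus any auxiliary output). It is a sketch of the idea, not JAX's implementation: f_transformed, in_type, and debug_info are omitted, and here call_wrapped drives the generators directly.

```python
class ToyStore:
    """Hypothetical single-assignment box for one transform's aux output."""

    _EMPTY = object()

    def __init__(self):
        self._val = ToyStore._EMPTY

    def store(self, val):
        if self._val is not ToyStore._EMPTY:
            raise RuntimeError("store already filled")
        self._val = val

    @property
    def val(self):
        if self._val is ToyStore._EMPTY:
            raise RuntimeError("store is empty")
        return self._val


class ToyWrappedFun:
    """Toy model of WrappedFun: f plus a stack of generator transforms."""

    def __init__(self, f, transforms=(), stores=(), params=None):
        self.f = f
        self.transforms = tuple(transforms)  # (gen, gen_static_args) pairs
        self.stores = tuple(stores)          # one store (or None) per transform
        self.params = dict(params or {})     # extra keyword arguments for f

    def wrap(self, gen, gen_static_args, out_store):
        # Return a new object; the newest transform goes first (outermost).
        return ToyWrappedFun(self.f,
                             ((gen, tuple(gen_static_args)),) + self.transforms,
                             (out_store,) + self.stores,
                             self.params)

    def call_wrapped(self, *args, **kwargs):
        stack = []
        for (gen, static_args), store in zip(self.transforms, self.stores):
            g = gen(*static_args, *args, **kwargs)
            args, kwargs = next(g)          # transform may rewrite the inputs
            stack.append((g, store))
        ans = self.f(*args, **dict(self.params, **kwargs))
        while stack:
            g, store = stack.pop()          # innermost transform finishes first
            ans = g.send(ans)               # transform may rewrite the output
            if store is not None:
                ans, aux = ans              # split off the auxiliary output
                store.store(aux)
        return ans

    def populate_stores(self, stores):
        # Copy aux values from another run's stores into self.stores.
        for mine, other in zip(self.stores, stores):
            if mine is not None:
                mine.store(other.val)


def record_inputs(*args, **kwargs):
    ans = yield args, kwargs
    yield ans, args                         # aux output: the positional inputs


def double_output(*args, **kwargs):
    ans = yield args, kwargs
    yield 2 * ans


def add(x, y, scale=1):
    return (x + y) * scale


store = ToyStore()
wf = ToyWrappedFun(add, params={"scale": 3})
wf = wf.wrap(record_inputs, (), store).wrap(double_output, (), None)
result = wf.call_wrapped(1, 2)
print(result)                 # prints: 18
print(store.val)              # prints: (1, 2)

fresh = ToyStore()
wf2 = ToyWrappedFun(add, params={"scale": 3})
wf2 = wf2.wrap(record_inputs, (), fresh).wrap(double_output, (), None)
wf2.populate_stores(wf.stores)  # reuse aux values without calling add again
print(fresh.val)              # prints: (1, 2)
```

Because each toy store accepts a single value, calling wf.call_wrapped a second time would raise. populate_stores is how a second wrapper with the same transform stack can receive aux values from a run that already happened, which is presumably useful when a previously computed result is reused from a cache.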
Attributes
f
f_transformed
transforms
stores
params
in_type
debug_info