jax.extend.linear_util.WrappedFun

class jax.extend.linear_util.WrappedFun(f, f_transformed, transforms, stores, params, in_type, debug_info)

Represents a function f to which transforms are to be applied.

Parameters:
  • f (Callable) – the function to be transformed.

  • f_transformed (Callable) – the transformed function, i.e. f with the transforms applied; this is what call_wrapped invokes.

  • transforms – a sequence of (gen, gen_static_args) tuples representing transformations to apply to f. Here gen is a generator function and gen_static_args is a tuple of static arguments for the generator. See the description at the start of this module for the expected behavior of the generator.

  • stores (tuple[Store | EqualStore | None, ...]) – one out_store per transform, receiving that transform's auxiliary output; an entry may be None when the transform has no auxiliary output.

  • params – extra parameters to pass as keyword arguments to f, along with the transformed keyword arguments.

  • in_type – optional type information about the inputs.

  • debug_info (DebugInfo) – debugging info about the function being wrapped.
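To make the (gen, gen_static_args) pairs concrete, the sketch below models a single transform in plain Python. It assumes the two-yield generator protocol described at the start of the linear_util module: the generator is called with its static arguments followed by the dynamic call arguments, yields the transformed (args, kwargs), receives the underlying function's result, and then yields the transformed result. This is a simplified model, not JAX's code; scale_then_shift and run_one_transform are hypothetical names.

```python
# Toy model of one linear_util-style transform (not JAX's implementation).


def scale_then_shift(factor, offset, *args, **kwargs):
    """Hypothetical transform generator.

    factor and offset play the role of gen_static_args; the remaining
    positional and keyword arguments are the dynamic call arguments.
    """
    # First yield: the transformed (args, kwargs); receives f's result.
    result = yield tuple(a * factor for a in args), kwargs
    # Second yield: the transformed result.
    yield result + offset


def run_one_transform(gen, gen_static_args, f, *args, **kwargs):
    """Drive a single transform generator around f."""
    it = gen(*gen_static_args, *args, **kwargs)
    new_args, new_kwargs = next(it)       # run to the first yield
    result = f(*new_args, **new_kwargs)   # call the underlying function
    out = it.send(result)                 # resume to the second yield
    it.close()
    return out


def f(x, y):
    return x + y


print(run_one_transform(scale_then_shift, (10, 1), f, 2, 3))  # 51
```

Here (10, 1) is the static part fixed when the transform is attached, while 2 and 3 arrive at call time: they are scaled to 20 and 30, f returns 50, and the transform shifts the result to 51.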

__init__(f, f_transformed, transforms, stores, params, in_type, debug_info)
Parameters:
  • f (Callable)

  • f_transformed (Callable)

  • stores (tuple[Store | EqualStore | None, ...])

  • debug_info (DebugInfo)
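In JAX code these objects are typically produced by a factory such as wrap_init in the same module rather than by calling the constructor directly. The sketch below is a hypothetical pure-Python record mirroring the seven constructor arguments, plus a toy factory for a function that has no transforms yet. Two details are assumptions of the sketch, not documented behavior: f_transformed starts out as f with params bound as keyword arguments, and params is kept as a sorted tuple of (name, value) pairs so the record stays hashable.

```python
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class ToyWrappedFunFields:
    """Hypothetical mirror of the seven constructor arguments (not JAX's class)."""
    f: Callable
    f_transformed: Callable
    transforms: tuple
    stores: tuple
    params: tuple
    in_type: Optional[Any]
    debug_info: Any


def toy_wrap_init(f, params=None, debug_info=None):
    """Hypothetical factory: wrap f with no transforms applied yet."""
    params = dict(params or {})
    return ToyWrappedFunFields(
        f=f,
        f_transformed=partial(f, **params),  # only params to apply so far
        transforms=(),                       # no transforms yet
        stores=(),                           # hence no stores either
        params=tuple(sorted(params.items())),
        in_type=None,
        debug_info=debug_info,
    )


def scaled_sum(x, y, *, scale):
    return scale * (x + y)


wf = toy_wrap_init(scaled_sum, {"scale": 2}, debug_info="scaled_sum")
print(wf.params)               # (('scale', 2),)
print(wf.f_transformed(1, 2))  # 6
```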

Methods

  • __init__(f, f_transformed, transforms, ...)

  • call_wrapped(*args, **kwargs) – Calls the transformed function.

  • populate_stores(stores) – Copy the values from the stores into self.stores.

  • wrap(gen, gen_static_args, out_store) – Add another transform and its store.
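The interplay of wrap, call_wrapped and populate_stores can be modeled in a few lines of plain Python. This is a simplified sketch, not JAX's implementation: ToyWrappedFun, Store, add_one and double_and_record are hypothetical names; the toy omits f_transformed, params, in_type and debug_info; its Store raises RuntimeError where JAX uses its own exception. It also assumes two rules from the module description: arguments pass through the most recently added transform first and results through the earliest added transform first, and a transform that has an output store yields a (result, aux) pair on its second yield.

```python
from dataclasses import dataclass
from typing import Callable


class Store:
    """Toy single-assignment box for one transform's auxiliary output."""

    _EMPTY = object()

    def __init__(self):
        self._val = Store._EMPTY

    def store(self, val):
        if self._val is not Store._EMPTY:
            raise RuntimeError("store already holds a value")
        self._val = val

    @property
    def val(self):
        if self._val is Store._EMPTY:
            raise RuntimeError("store is empty")
        return self._val


@dataclass(frozen=True)
class ToyWrappedFun:
    """Hypothetical, stripped-down model of WrappedFun (not JAX's class)."""

    f: Callable
    transforms: tuple = ()  # (gen, gen_static_args), most recently added first
    stores: tuple = ()      # one entry per transform; None means no aux output

    def wrap(self, gen, gen_static_args, out_store):
        """Return a new object with one more transform and its store."""
        return ToyWrappedFun(self.f,
                             ((gen, gen_static_args),) + self.transforms,
                             (out_store,) + self.stores)

    def call_wrapped(self, *args, **kwargs):
        """Run the arguments in, call f, then run the result back out."""
        stack = []
        # Arguments go through the most recently added transform first.
        for (gen, static_args), store in zip(self.transforms, self.stores):
            it = gen(*static_args, *args, **kwargs)
            args, kwargs = next(it)
            stack.append((it, store))
        out = self.f(*args, **kwargs)
        # Results go through the earliest added transform first.
        while stack:
            it, store = stack.pop()
            out = it.send(out)
            if store is not None:
                out, aux = out  # transforms with a store yield (result, aux)
                store.store(aux)
        return out

    def populate_stores(self, stores):
        """Copy values from another sequence of stores into self.stores."""
        for mine, theirs in zip(self.stores, stores):
            if mine is not None:
                mine.store(theirs.val)


def add_one(*args, **kwargs):
    out = yield tuple(a + 1 for a in args), kwargs
    yield out


def double_and_record(*args, **kwargs):
    out = yield tuple(a * 2 for a in args), kwargs
    yield out, ("saw", args)  # second element is the auxiliary output


wf = ToyWrappedFun(lambda x: x * 100)
aux = Store()
wf = wf.wrap(add_one, (), None).wrap(double_and_record, (), aux)
print(wf.call_wrapped(3))  # 700: double (6), then add_one (7), then f
print(aux.val)             # ('saw', (3,))

fresh = Store()
other = ToyWrappedFun(lambda x: x).wrap(double_and_record, (), fresh)
other.populate_stores((aux,))
print(fresh.val)           # ('saw', (3,))
```

Because wrap returns a fresh object in this sketch, a partially wrapped function stays reusable as the base for different transform stacks.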

Attributes

f

f_transformed

transforms

stores

params

in_type

debug_info