jax.extend.linear_util.cache

jax.extend.linear_util.cache(call, *, explain=None)

Memoization decorator for functions taking a WrappedFun as first argument.

Parameters:
  • call (Callable) – a Python callable that takes a WrappedFun as its first argument. The underlying transforms and params on the WrappedFun are used as part of the memoization cache key.

  • explain (Callable[[WrappedFun, bool, dict, tuple], None] | None) – a function that is invoked on each cache miss to log an explanation of the miss. It is called as explain(fun, is_cache_first_use, cache, key).

Returns:

A memoized version of call.
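The following is a minimal pure-Python sketch of the memoization pattern described above. It is not JAX's implementation. `WrappedFunSketch` and `cache_sketch` are hypothetical stand-ins. Only two details come from this page: the key incorporates the WrappedFun's transforms and params, and `explain` receives (fun, is_cache_first_use, cache, key) on a miss. Everything else is an assumption, including adding the extra call arguments to the key and keeping one cache per base function. The real cache may also fold additional context into its key.

```python
import weakref
from typing import Any, Callable, Optional


class WrappedFunSketch:
    """Hypothetical stand-in for a WrappedFun: a base callable `f` plus the
    `transforms` and `params` that the cache folds into its key."""

    def __init__(self, f: Callable, transforms: tuple = (), params: tuple = ()):
        self.f = f
        self.transforms = transforms
        self.params = params  # must be hashable, e.g. a tuple of (name, value) pairs


def cache_sketch(call: Callable, *, explain: Optional[Callable] = None) -> Callable:
    # Assumption: one cache dict per base function, held weakly so that the
    # cache does not keep the function object alive.
    fun_caches: "weakref.WeakKeyDictionary[Callable, dict]" = weakref.WeakKeyDictionary()

    def memoized_fun(fun: WrappedFunSketch, *args: Any) -> Any:
        cache = fun_caches.setdefault(fun.f, {})
        # Key = transforms + params (per the docs) + extra args (assumption).
        key = (fun.transforms, fun.params, args)
        if key in cache:
            return cache[key]  # hit: skip `call` entirely
        if explain is not None:
            # An empty cache means this is the first use for this base function.
            explain(fun, not cache, cache, key)
        result = call(fun, *args)
        cache[key] = result
        return result

    return memoized_fun


# Usage: count how often the expensive step runs and log every miss.
calls = []
misses = []


def expensive(fun: WrappedFunSketch, x: int) -> int:
    calls.append(x)  # stand-in for costly work such as tracing
    return fun.f(x)


def log_miss(fun, is_cache_first_use, cache, key):
    misses.append((is_cache_first_use, key))


def square(x: int) -> int:
    return x * x


memo = cache_sketch(expensive, explain=log_miss)
wf = WrappedFunSketch(square, transforms=("t1",), params=(("scale", 1),))
memo(wf, 3)   # miss: first use of this function's cache
memo(wf, 3)   # hit: served from the cache
wf2 = WrappedFunSketch(square, transforms=("t1",), params=(("scale", 2),))
memo(wf2, 3)  # miss: same base function, but different params
```

Because `wf` and `wf2` share the same base function, they share one per-function cache. The differing params still produce distinct keys, so the second miss reports `is_cache_first_use=False`.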