jax.stages module
Interfaces to stages of the compiled execution process.
JAX transformations that compile just in time for execution, such as jax.jit and jax.pmap, also support a common means of explicit lowering and compilation ahead of time. This module defines types that represent the stages of this process.
For more, see the AOT walkthrough.
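The sketch below walks one function through each stage explicitly. It assumes a recent JAX release in which jitted functions expose `trace()`. The function `f` and its inputs are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical function used to walk through each stage.
    return jnp.sin(x) * 2.0 + y


x = jnp.arange(4.0)
y = jnp.ones(4)

jitted = jax.jit(f)            # Wrapped: call it, or stage it explicitly
traced = jitted.trace(x, y)    # Traced: specialized to these argument types
lowered = traced.lower()       # Lowered: compiler input, not yet compiled
compiled = lowered.compile()   # Compiled: an executable ready to run

for stage in (traced, lowered, compiled):
    print(type(stage).__name__)

# The compiled executable is called like the original function.
print(compiled(x, y))
```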
Classes
- class jax.stages.Wrapped(*args, **kwargs)
A function ready to be traced, lowered, and compiled.
This protocol reflects the output of functions such as jax.jit. Calling it results in JIT (just-in-time) lowering, compilation, and execution. It can also be explicitly lowered prior to compilation, and the result compiled prior to execution.
- lower(*args, **kwargs)
Lower this function explicitly for the given arguments.
This is a shortcut for self.trace(*args, **kwargs).lower().
A lowered function is staged out of Python and translated to a compiler’s input language, possibly in a backend-dependent manner. It is ready for compilation but not yet compiled.
- Returns:
A Lowered instance representing the lowering.
- Return type:
Lowered
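The shortcut can be illustrated by lowering the same function both ways. This sketch assumes that only shapes and dtypes matter for lowering, so it passes a `jax.ShapeDtypeStruct` in place of a concrete array. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical function: (8, 3) @ (3, 8) -> (8, 8)
    return jnp.tanh(x) @ x.T


# Abstract stand-in for an argument: shape and dtype only, no data.
spec = jax.ShapeDtypeStruct((8, 3), jnp.float32)

lowered_direct = jax.jit(f).lower(spec)              # the shortcut
lowered_via_trace = jax.jit(f).trace(spec).lower()   # the long form it abbreviates

# Both paths produce a Lowered object.
print(type(lowered_direct).__name__, type(lowered_via_trace).__name__)
```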
- class jax.stages.Traced(jaxpr, args_info, fun_name, out_tree, lower_callable, args_flat=None, arg_names=None, num_consts=0)
Traced form of a function specialized to argument types and values.
A traced computation is ready for lowering. This class carries the traced representation with the remaining information needed to later lower, compile, and execute it.
- Parameters:
jaxpr (core.ClosedJaxpr)
args_info (Any)
num_consts (int)
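A minimal sketch of obtaining and inspecting a Traced object. It assumes a JAX version where `jax.jit(...).trace()` is available and the traced representation is reachable through the `jaxpr` attribute named in the constructor signature. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical function to trace.
    return jnp.cos(x) + 1.0


traced = jax.jit(f).trace(jnp.zeros(3))

# The traced representation (a ClosedJaxpr), before any lowering happens.
print(traced.jaxpr)

# The same object continues down the pipeline.
compiled = traced.lower().compile()
print(compiled(jnp.zeros(3)))  # cos(0) + 1 = 2 for each element
```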
- class jax.stages.Lowered(lowering, args_info, out_tree, no_kwargs=False)
Lowering of a function specialized to argument types and values.
A lowering is a computation ready for compilation. This class carries a lowering together with the remaining information needed to later compile and execute it. It also provides a common API for querying properties of lowered computations across JAX’s various lowering paths (jit(), pmap(), etc.).
- Parameters:
lowering (XlaLowering)
args_info (Any)
out_tree (tree_util.PyTreeDef)
no_kwargs (bool)
- as_text(dialect=None, *, debug_info=False)
A human-readable text representation of this lowering.
Intended for visualization and debugging purposes. This need not be a valid or reliable serialization. Use jax.export if you want reliable and portable serialization.
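To illustrate the distinction drawn here, this sketch dumps a lowering with `as_text()` for inspection, then uses `jax.export` for an actual serialize/deserialize round trip. It assumes a recent JAX release that ships the public `jax.export` module (`export`, `Exported.serialize`, `deserialize`, and `Exported.call`). The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import export


def f(x):
    # Hypothetical function to inspect and serialize.
    return jnp.exp(x).sum()


lowered = jax.jit(f).lower(jnp.ones(5))

# For eyeballing only: by default a StableHLO (MLIR) text dump.
text = lowered.as_text()
print(text[:300])

# For portable serialization, use jax.export rather than parsing as_text().
exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((5,), jnp.float32))
blob = exported.serialize()            # bytes suitable for storing on disk
restored = export.deserialize(blob)    # rebuild the Exported object
print(restored.call(jnp.ones(5)))
```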
- compile(compiler_options=None)
Compile, returning a corresponding Compiled instance.
- Parameters:
compiler_options (CompilerOptions | None)
- Return type:
Compiled
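A short sketch of compiling a lowering and calling the result. It also shows that a Compiled object rejects arguments whose shapes differ from those it was lowered for. The assumption is that current JAX releases report this mismatch as a `TypeError`. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical function to compile.
    return 2 * x


compiled = jax.jit(f).lower(jnp.ones((2, 2))).compile()
print(compiled(jnp.ones((2, 2))))

# Unlike the jitted function, a Compiled object does not retrace for new shapes.
rejected = False
try:
    compiled(jnp.ones((3, 3)))
except TypeError:
    rejected = True
print("mismatched shape rejected:", rejected)
```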
- compiler_ir(dialect=None)
An arbitrary object representation of this lowering.
Intended for debugging purposes. This is neither a valid nor a reliable serialization. The output has no guarantee of consistency across invocations. Use jax.export if you want reliable and portable serialization.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Parameters:
dialect (str | None) – Optional string specifying a lowering dialect (e.g. “stablehlo” or “hlo”).
- Return type:
Any | None
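A sketch of requesting the StableHLO form of a lowering as an in-memory object rather than as text. The `None` check mirrors the documented contract. Rendering the object with `str()` assumes that the result is an MLIR module, which is the case on common backends. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical function: elementwise ReLU.
    return jnp.maximum(x, 0.0)


lowered = jax.jit(f).lower(jnp.ones(4))

ir = lowered.compiler_ir(dialect="stablehlo")
if ir is None:
    print("compiler IR unavailable on this backend")
else:
    print(type(ir))          # an IR object, not a string
    print(str(ir)[:300])     # str() renders the module as MLIR text
```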
- cost_analysis()
A summary of execution cost estimates.
Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
Any | None
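The structure of the estimate is unspecified, so code that reads it should be defensive. This sketch lowers a hypothetical matrix-multiply function from abstract inputs and only prints whatever comes back.

```python
import jax
import jax.numpy as jnp


def f(a, b):
    # Hypothetical function: a 64x64 matrix multiply.
    return a @ b


spec = jax.ShapeDtypeStruct((64, 64), jnp.float32)
lowered = jax.jit(f).lower(spec, spec)

costs = lowered.cost_analysis()
# The structure may change between versions, so treat it as opaque data.
if costs is None:
    print("no cost estimate available")
else:
    print(costs)
```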
- class jax.stages.Compiled(executable, args_info, out_tree, no_kwargs=False)
Compiled representation of a function specialized to argument types and values.
A compiled computation is associated with an executable and the remaining information needed to execute it. It also provides a common API for querying properties of compiled computations across JAX’s various compilation paths and backends.
- Parameters:
args_info (Any)
out_tree (tree_util.PyTreeDef)
- as_text()
A human-readable text representation of this executable.
Intended for visualization and debugging purposes. This is neither a valid nor a reliable serialization.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
str | None
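Comparing the two text forms is one way to see what the compiler changed. This sketch assumes the default dialect for `Lowered.as_text()` and prints both forms. The exact contents vary by backend and version. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical function with a repeated subexpression.
    return (x + 1.0) * (x + 1.0)


lowered = jax.jit(f).lower(jnp.ones(8))
compiled = lowered.compile()

# Lowered.as_text(): the program handed to the compiler.
# Compiled.as_text(): what the compiler produced after its passes (may be None).
before = lowered.as_text()
after = compiled.as_text()
print(before)
print(after if after is not None else "executable text unavailable")
```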
- cost_analysis()
A summary of execution cost estimates.
Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
Any | None
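A defensive sketch for reading a FLOP estimate. It assumes the result is either a dict or, as in some older JAX versions, a list of per-program dicts, and that a `"flops"` key may be present. The contract above guarantees none of this. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(a, b):
    # Hypothetical function: a 32x32 matrix multiply.
    return a @ b


spec = jax.ShapeDtypeStruct((32, 32), jnp.float32)
compiled = jax.jit(f).lower(spec, spec).compile()

costs = compiled.cost_analysis()
# Normalize: a list of dicts (older versions) becomes its first entry.
if isinstance(costs, list):
    costs = costs[0] if costs else None
flops = costs.get("flops") if isinstance(costs, dict) else None
print("estimated flops:", flops)
```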
- property in_tree: tree_util.PyTreeDef
Tree structure of the pair (positional arguments, keyword arguments).
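A small sketch of inspecting `in_tree` for a call that mixes a plain array with a dict argument. The expected leaf count follows from the documented form: the tree covers the (positional arguments, keyword arguments) pair, and each array counts as one leaf. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x, params):
    # Hypothetical function taking an array and a dict of arrays.
    return x + params["bias"]


x = jnp.ones(3)
compiled = jax.jit(f).lower(x, {"bias": jnp.zeros(3)}).compile()

print(compiled.in_tree)             # structure of ((x, {"bias": ...}), {})
print(compiled.in_tree.num_leaves)  # x plus the single "bias" leaf
```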
- memory_analysis()
A summary of estimated memory requirements.
Intended for visualization and debugging purposes. The object output by this is some simple data structure that can easily be printed or serialized (e.g. nested dicts, lists, and tuples with numeric leaves). However, its structure can be arbitrary: it may be inconsistent across versions of JAX and jaxlib, or even across invocations.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
Any | None
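A hedged sketch of reading the memory summary. The attribute names used here (`argument_size_in_bytes` and related fields) are an assumption based on the memory-stats object that current XLA backends return. Using `getattr` keeps the code working if those attributes are absent. The function `f` is hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical function over a 1024-element vector.
    return jnp.tanh(x) * 3.0


compiled = jax.jit(f).lower(jax.ShapeDtypeStruct((1024,), jnp.float32)).compile()

mem = compiled.memory_analysis()
if mem is None:
    print("memory analysis unavailable on this backend")
else:
    # Assumed attribute names; missing ones print "n/a".
    for name in ("argument_size_in_bytes", "output_size_in_bytes", "temp_size_in_bytes"):
        print(name, getattr(mem, name, "n/a"))
```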
- runtime_executable()
An arbitrary object representation of this executable.
Intended for debugging purposes. This is neither a valid nor a reliable serialization. The output has no guarantee of consistency across invocations.
Returns None if unavailable, e.g. based on backend, compiler, or runtime.
- Return type:
Any | None
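A minimal sketch of retrieving the backend executable. The type of the returned object and the methods it offers depend on jaxlib and the backend, so this only reports its type name. The squaring lambda is a hypothetical example.

```python
import jax
import jax.numpy as jnp

# Hypothetical function, compiled ahead of time for a length-2 vector.
compiled = jax.jit(lambda x: x * x).lower(jnp.ones(2)).compile()

# Backend-specific object: inspect it, but do not rely on its interface.
exe = compiled.runtime_executable()
print(type(exe).__name__ if exe is not None else "unavailable")
```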