jax.export module
jax.export is a library for exporting and serializing JAX functions for persistent archival.
See the Exporting and serialization documentation.
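The typical round trip can be sketched as follows, assuming a recent JAX release that ships the jax.export module: export a jitted function for concrete argument shapes, serialize it to bytes, deserialize it, and call the result. The function f is an arbitrary illustrative example, not part of the API.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


def f(x):
    # An arbitrary example function to export.
    return jnp.sin(x) * 2.0


x = np.arange(4, dtype=np.float32)

# Export for arguments with the same shape and dtype as x.
exp = export.export(jax.jit(f))(x)

# Serialize to bytes for persistent archival (e.g., writing to a file).
blob = exp.serialize()

# Later, possibly in another process: deserialize and invoke.
rehydrated = export.deserialize(blob)
result = rehydrated.call(x)
print(result)
```

The deserialized object does not need the original Python source of f; it only carries the StableHLO module and the metadata described below.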
Classes
- class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)
A JAX function lowered to StableHLO.
- Parameters:
fun_name (str)
in_tree (tree_util.PyTreeDef)
in_avals (tuple[core.ShapedArray, ...])
out_tree (tree_util.PyTreeDef)
out_avals (tuple[core.ShapedArray, ...])
in_shardings_hlo (tuple[HloSharding | None, ...])
out_shardings_hlo (tuple[HloSharding | None, ...])
nr_devices (int)
platforms (tuple[str, ...])
ordered_effects (tuple[effects.Effect, ...])
unordered_effects (tuple[effects.Effect, ...])
disabled_safety_checks (Sequence[DisabledSafetyCheck])
mlir_module_serialized (bytes)
calling_convention_version (int)
module_kept_var_idx (tuple[int, ...])
uses_global_constants (bool)
_get_vjp (Callable[[Exported], Exported] | None)
- in_tree
a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX function. The actual lowering does not depend on the in_tree, but this can be used to invoke the exported function using the same argument structure.
- Type:
tree_util.PyTreeDef
- in_avals
the flat tuple of input abstract values. May contain dimension expressions in the shapes.
- Type:
tuple[core.ShapedArray, …]
- out_tree
a PyTreeDef describing the result of the lowered JAX function.
- Type:
tree_util.PyTreeDef
- out_avals
the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in in_avals.
- Type:
tuple[core.ShapedArray, …]
- in_shardings_hlo
the flattened input shardings, a sequence as long as in_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh. See in_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.
- Type:
tuple[HloSharding | None, …]
- out_shardings_hlo
the flattened output shardings, a sequence as long as out_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh. See out_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.
- Type:
tuple[HloSharding | None, …]
- platforms
a tuple containing the platforms for which the function should be exported. The set of platforms in JAX is open-ended; users can add platforms. JAX built-in platforms are: ‘tpu’, ‘cpu’, ‘cuda’, ‘rocm’. See https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export.
- Type:
tuple[str, …]
- ordered_effects
the ordered effects present in the serialized module. This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention for the calling convention in presence of ordered effects.
- Type:
tuple[effects.Effect, …]
- unordered_effects
the unordered effects present in the serialized module. This is present from serialization version 9.
- Type:
tuple[effects.Effect, …]
- calling_convention_version
a version number for the calling convention of the exported module. See more versioning details at https://jax.readthedocs.io/en/latest/export/export.html#calling-convention-versions.
- Type:
int
- module_kept_var_idx
the sorted indices of the arguments among in_avals that must be passed to the module. The other arguments have been dropped because they are not used.
- Type:
tuple[int, …]
- uses_global_constants
whether the mlir_module_serialized uses shape polymorphism or multi-platform export. This may be because in_avals contains dimension variables, or due to inner calls of Exported modules that have dimension variables or platform index arguments. Such modules need shape refinement before XLA compilation.
- Type:
bool
- disabled_safety_checks
a list of descriptors of safety checks that have been disabled at export time. See the documentation for DisabledSafetyCheck.
- Type:
Sequence[DisabledSafetyCheck]
- _get_vjp
an optional function that takes the current exported function and returns the exported VJP function. The VJP function takes a flat list of arguments, starting with the primal arguments and followed by a cotangent argument for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs.
- Type:
Callable[[Exported], Exported] | None
See a [description of the calling convention for the mlir_module](https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention).
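The attributes above can be inspected directly on an Exported object. The sketch below assumes a JAX release where jax.export.symbolic_shape is available; it exports with a symbolic batch dimension "b" so that in_avals and out_avals carry a dimension variable, and then calls the same module with two different batch sizes. The function row_sums is a hypothetical example.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


def row_sums(x):
    # Sums each row; the output has one element per row.
    return jnp.sum(x, axis=1)


# "b" is a dimension variable; the second dimension is fixed to 4.
x_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.float32)
exp = export.export(jax.jit(row_sums))(x_spec)

print(exp.fun_name)      # name of the exported function
print(exp.in_avals)      # flat input avals; the shape mentions "b"
print(exp.out_avals)     # output shape expressed in terms of "b"
print(exp.platforms)     # platforms the module was lowered for
print(exp.in_tree)       # PyTreeDef for the (args, kwargs) tuple
print(exp.calling_convention_version)

# A single exported module serves several batch sizes.
small = exp.call(np.ones((3, 4), np.float32))
large = exp.call(np.ones((5, 4), np.float32))
```

Because no platforms argument was passed, the module is expected to be lowered only for the default export platform.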
- call(*args, **kwargs)
Call an exported function from a JAX program.
- Parameters:
args – the positional arguments to pass to the exported function. This should be a pytree of arrays with the same pytree structure as the arguments for which the function was exported.
kwargs – the keyword arguments to pass to the exported function.
- Returns:
a pytree of result arrays, with the same structure as the results of the exported function.
The invocation supports reverse-mode AD, and all the features supported by exporting: shape polymorphism, multi-platform, device polymorphism. See the examples in the [JAX export documentation](https://jax.readthedocs.io/en/latest/export/export.html).
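A minimal sketch of calling an exported function from a JAX program: it is embedded in a larger jitted function and differentiated with jax.grad. The function loss and the wrapper caller are illustrative only; reverse-mode AD here relies on the VJP that a freshly exported object can produce.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


def loss(x):
    # Scalar-valued example function, so that jax.grad applies.
    return jnp.sum(jnp.sin(x))


x = np.linspace(0.0, 1.0, 5, dtype=np.float32)
exp = export.export(jax.jit(loss))(x)


@jax.jit
def caller(x):
    # The exported function composes with other JAX operations.
    return 2.0 * exp.call(x)


doubled = caller(x)

# Reverse-mode AD through the exported function: d/dx sum(sin(x)) = cos(x).
grad = jax.grad(exp.call)(x)
```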
- in_shardings_jax(mesh)
Creates Shardings corresponding to self.in_shardings_hlo.
The Exported object stores in_shardings_hlo as HloShardings, which are independent of a mesh or set of devices. This method constructs Shardings that can be used in JAX APIs such as jax.jit or jax.device_put.
Example usage:
>>> import jax
>>> import numpy as np
>>> from jax import export, sharding
>>> # The outputs shown below assume 8 CPU devices.
>>> # Prepare the exported object:
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
...                             in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
...     )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
>>> # Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
>>> # Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
...     exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
 Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
 Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
 Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
 Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
 Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
 Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
 Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]
- Parameters:
mesh (sharding.Mesh)
- Return type:
Sequence[sharding.Sharding | None]
- out_shardings_jax(mesh)
Creates Shardings corresponding to self.out_shardings_hlo.
See documentation for in_shardings_jax.
- Parameters:
mesh (sharding.Mesh)
- Return type:
Sequence[sharding.Sharding | None]
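A sketch of using out_shardings_jax together with in_shardings_jax, assuming a one-axis mesh over all local devices (a single CPU device also works). The function doubler and the axis name "a" are illustrative. A None entry means the sharding was unspecified, so the sketch only applies a recovered output sharding when one is present.

```python
import jax
import numpy as np
from jax import export
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A one-axis mesh over all local devices.
mesh = Mesh(np.array(jax.devices()), ("a",))
spec = NamedSharding(mesh, PartitionSpec("a"))

doubler = jax.jit(lambda x: x * 2, in_shardings=spec, out_shardings=spec)
x = np.arange(2 * jax.device_count(), dtype=np.float32)
exp = export.export(doubler)(x)

# Turn the mesh-independent HloShardings back into JAX shardings.
in_shardings = exp.in_shardings_jax(mesh)
out_shardings = exp.out_shardings_jax(mesh)

# Place the input as at export time, call, then place the result
# according to the recovered output sharding (if one was specified).
res = exp.call(jax.device_put(x, in_shardings[0]))
if out_shardings[0] is not None:
    res = jax.device_put(res, out_shardings[0])
```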
- serialize(vjp_order=0)
Serializes an Exported.
- Parameters:
vjp_order (int) – The maximum VJP order to include. E.g., the value 2 means that we serialize the primal function and two orders of the VJP function. This should allow 2nd-order reverse-mode differentiation of the deserialized function, i.e., jax.grad(jax.grad(f)).
- Return type:
bytearray
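A minimal sketch of the effect of vjp_order: serializing with vjp_order=1 includes one order of the VJP, so first-order reverse-mode AD should work on the deserialized function. With the default vjp_order=0 we would expect differentiating the deserialized function to fail, since no VJP was serialized. The function f is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


def f(x):
    # Scalar-valued example function.
    return jnp.sum(jnp.sin(x))


x = np.linspace(0.0, 1.0, 5, dtype=np.float32)
exp = export.export(jax.jit(f))(x)

# Serialize the primal function plus one order of VJP.
blob = exp.serialize(vjp_order=1)
restored = export.deserialize(blob)

# First-order gradient of the deserialized function: cos(x).
grad = jax.grad(restored.call)(x)
```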
- class jax.export.DisabledSafetyCheck(_impl)
A safety check that should be skipped on (de)serialization.
Most of these checks are performed on serialization, but some are deferred to deserialization. The list of disabled checks is attached to the serialization, e.g., as a sequence of string attributes to jax.export.Exported or of tf.XlaCallModuleOp.
When using jax2tf, you can disable more deserialization safety checks by passing TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform.
- Parameters:
_impl (str)
- classmethod custom_call(target_name)
Allows the serialization of a call target not known to be stable.
Has an effect only on serialization.
- Parameters:
target_name (str) – the name of the custom call target to allow.
- Return type:
DisabledSafetyCheck
- is_custom_call()
Returns the custom call target allowed by this directive.
- Return type:
str | None
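A sketch of constructing a DisabledSafetyCheck and passing it at export time, assuming the disabled_checks keyword argument of jax.export.export. The target name "my_custom_target" is hypothetical; since the exported function does not use it, the directive only relaxes a check and is recorded on the result in disabled_safety_checks.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# "my_custom_target" is a hypothetical custom call target name.
check = export.DisabledSafetyCheck.custom_call("my_custom_target")
print(check.is_custom_call())  # the allowed target name


def f(x):
    return jnp.cos(x)


exp = export.export(jax.jit(f), disabled_checks=[check])(
    np.zeros(3, np.float32))
print(exp.disabled_safety_checks)
```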
Functions
- Exports a JAX function for persistent serialization.
- Deserializes an Exported.
- int([x]) -> integer int(x, base=10) -> integer
- int([x]) -> integer int(x, base=10) -> integer
- Retrieves the default export platform.
- Registers a custom PyTree node for serialization and deserialization.
- Registers a namedtuple for serialization and deserialization.
Constants
- jax.export.minimum_supported_serialization_version
The minimum supported serialization version; see Calling convention versions.
- jax.export.maximum_supported_serialization_version
The maximum supported serialization version; see Calling convention versions.
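A version-range check can be sketched as follows. It assumes the supported bounds are exposed as the integer constants jax.export.minimum_supported_calling_convention_version and jax.export.maximum_supported_calling_convention_version; these names are an assumption and may differ from the constants listed above depending on the JAX release. The function f is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export


def f(x):
    return jnp.tanh(x)


exp = export.export(jax.jit(f))(np.zeros(2, np.float32))

# Assumed names for the supported calling convention version bounds.
lo_v = export.minimum_supported_calling_convention_version
hi_v = export.maximum_supported_calling_convention_version

version = exp.calling_convention_version
print(f"exported with version {version}; supported range [{lo_v}, {hi_v}]")
```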