jax.export module

jax.export is a library for exporting and serializing JAX functions for persistent archival.

See the Exporting and serialization documentation.
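A minimal sketch of the full round trip (export, serialize, deserialize, call). It assumes a recent JAX release where jax.export is available. The function f and the input shape are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def f(x):  # hypothetical example function
    return jnp.sin(x) * 2.0

# Export for a fixed input signature: a float32 vector of length 4.
exp = export.export(jax.jit(f))(jax.ShapeDtypeStruct((4,), np.float32))

# Serialize to bytes for archival, then restore (e.g., in another process).
blob = exp.serialize()
rehydrated = export.deserialize(blob)

# Call the restored function with concrete arguments of the exported shape.
x = np.arange(4, dtype=np.float32)
result = rehydrated.call(x)
print(result)
```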

Classes

class jax.export.Exported(fun_name, in_tree, in_avals, out_tree, out_avals, in_shardings_hlo, out_shardings_hlo, nr_devices, platforms, ordered_effects, unordered_effects, disabled_safety_checks, mlir_module_serialized, calling_convention_version, module_kept_var_idx, uses_global_constants, _get_vjp)

A JAX function lowered to StableHLO.

Parameters:
  • fun_name (str)

  • in_tree (tree_util.PyTreeDef)

  • in_avals (tuple[core.ShapedArray, ...])

  • out_tree (tree_util.PyTreeDef)

  • out_avals (tuple[core.ShapedArray, ...])

  • in_shardings_hlo (tuple[HloSharding | None, ...])

  • out_shardings_hlo (tuple[HloSharding | None, ...])

  • nr_devices (int)

  • platforms (tuple[str, ...])

  • ordered_effects (tuple[effects.Effect, ...])

  • unordered_effects (tuple[effects.Effect, ...])

  • disabled_safety_checks (Sequence[DisabledSafetyCheck])

  • mlir_module_serialized (bytes)

  • calling_convention_version (int)

  • module_kept_var_idx (tuple[int, ...])

  • uses_global_constants (bool)

  • _get_vjp (Callable[[Exported], Exported] | None)
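These fields can be read directly from an Exported object. The following is a sketch assuming a single-device setup. scale_and_sum is a hypothetical example function.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def scale_and_sum(x, y):  # hypothetical example function
    return jnp.sum(x * y)

exp = export.export(jax.jit(scale_and_sum))(
    jax.ShapeDtypeStruct((3,), np.float32),
    jax.ShapeDtypeStruct((3,), np.float32),
)

print(exp.fun_name)    # name of the exported function
print(exp.in_avals)    # flat tuple of input abstract values
print(exp.out_avals)   # flat tuple of output abstract values
print(exp.nr_devices)  # number of devices the module was lowered for
print(exp.platforms)   # by default, the platform of the default backend
```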

fun_name

the name of the exported function, used in error messages.

Type:

str

in_tree

a PyTreeDef describing the tuple (args, kwargs) of the lowered JAX function. The actual lowering does not depend on the in_tree, but this can be used to invoke the exported function using the same argument structure.

Type:

tree_util.PyTreeDef

in_avals

the flat tuple of input abstract values. May contain dimension expressions in the shapes.

Type:

tuple[core.ShapedArray, …]

out_tree

a PyTreeDef describing the result of the lowered JAX function.

Type:

tree_util.PyTreeDef
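The following sketch shows that in_tree records the (args, kwargs) structure and out_tree records the result structure. It also shows that call accepts the same argument structure. The function combine and its dict keys are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def combine(params, *, scale):  # hypothetical example function
    # params is a dict pytree; scale is a keyword-only argument.
    return {"sum": params["a"] + params["b"], "scaled": params["a"] * scale}

spec = jax.ShapeDtypeStruct((2,), np.float32)
exp = export.export(jax.jit(combine))({"a": spec, "b": spec}, scale=spec)

print(exp.in_tree)   # structure of ((params,), {"scale": ...})
print(exp.out_tree)  # structure of the result dict

# Invoke with the same argument structure used at export time.
a = np.ones(2, np.float32)
out = exp.call({"a": a, "b": a}, scale=np.full(2, 3.0, np.float32))
```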

out_avals

the flat tuple of output abstract values. May contain dimension expressions in the shapes, with dimension variables among those in in_avals.

Type:

tuple[core.ShapedArray, …]
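The following sketch shows dimension expressions in in_avals and out_avals. It assumes that export.symbolic_shape is used to declare the dimension variable b. The reduction function is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Declare a shape with a symbolic batch dimension "b" and a fixed dimension 4.
spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), np.float32)
exp = export.export(jax.jit(lambda x: jnp.sum(x, axis=1)))(spec)

print(exp.in_avals)   # input shape contains the dimension variable b
print(exp.out_avals)  # output shape is expressed in terms of the same b

# The same Exported accepts any concrete batch size.
res = exp.call(np.ones((3, 4), np.float32))
```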

in_shardings_hlo

the flattened input shardings, a sequence as long as in_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh. See in_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.

Type:

tuple[HloSharding | None, …]

out_shardings_hlo

the flattened output shardings, a sequence as long as out_avals. None means unspecified sharding. Note that these do not include the mesh or the actual devices used in the mesh. See out_shardings_jax for a way to turn these into sharding specifications that can be used with JAX APIs.

Type:

tuple[HloSharding | None, …]

nr_devices

the number of devices that the module has been lowered for.

Type:

int

platforms

a tuple containing the platforms for which the function should be exported. The set of platforms in JAX is open-ended; users can add platforms. JAX built-in platforms are: ‘tpu’, ‘cpu’, ‘cuda’, ‘rocm’. See https://jax.readthedocs.io/en/latest/export/export.html#cross-platform-and-multi-platform-export.

Type:

tuple[str, …]
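The following is a sketch of multi-platform export. It assumes the code runs on a CPU machine. Lowering for 'tpu' does not require TPU hardware. Calling does require that the current platform be listed.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# Lower one module that can run on either CPU or TPU.
exp = export.export(jax.jit(jnp.cos), platforms=("cpu", "tpu"))(
    jax.ShapeDtypeStruct((3,), np.float32))
print(exp.platforms)

# On a CPU machine, the CPU variant is selected at call time.
x = np.array([0.0, 1.0, 2.0], np.float32)
res = exp.call(x)
```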

ordered_effects

the ordered effects present in the serialized module. This is present from serialization version 9. See https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention for the calling convention in presence of ordered effects.

Type:

tuple[effects.Effect, …]

unordered_effects

the unordered effects present in the serialized module. This is present from serialization version 9.

Type:

tuple[effects.Effect, …]

mlir_module_serialized

the serialized lowered VHLO module.

Type:

bytes

calling_convention_version

a version number for the calling convention of the exported module. See more versioning details at https://jax.readthedocs.io/en/latest/export/export.html#calling-convention-versions.

Type:

int
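The following sketch checks that a freshly exported module's calling convention version falls within the range this JAX release supports. It uses the minimum_supported_calling_convention_version and maximum_supported_calling_convention_version module values.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(jnp.tanh))(jax.ShapeDtypeStruct((2,), np.float32))

version = exp.calling_convention_version
lo = export.minimum_supported_calling_convention_version
hi = export.maximum_supported_calling_convention_version
print(lo, version, hi)
```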

module_kept_var_idx

the sorted indices of the arguments among in_avals that must be passed to the module. The other arguments have been dropped because they are not used.

Type:

tuple[int, …]
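The following sketch assumes jax.jit's default behavior of dropping unused arguments. The hypothetical function uses_only_x ignores y, so only index 0 is expected to be kept. Callers still pass both arguments.

```python
import jax
import numpy as np
from jax import export

def uses_only_x(x, y):  # hypothetical example function
    del y  # intentionally unused
    return x * 2.0

spec = jax.ShapeDtypeStruct((2,), np.float32)
exp = export.export(jax.jit(uses_only_x))(spec, spec)
print(exp.module_kept_var_idx)  # indices into in_avals passed to the module

# Callers still pass every argument; the dropped one is not forwarded.
res = exp.call(np.ones(2, np.float32), np.zeros(2, np.float32))
```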

uses_global_constants

whether the mlir_module_serialized uses shape polymorphism or multi-platform export. This may be because in_avals contains dimension variables, or due to inner calls of Exported modules that have dimension variables or platform index arguments. Such modules need shape refinement before XLA compilation.

Type:

bool
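The following sketch contrasts a static-shape export with a shape-polymorphic one. Based on the description above, only the polymorphic export is expected to set uses_global_constants.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

static = export.export(jax.jit(jnp.sin))(jax.ShapeDtypeStruct((4,), np.float32))
poly = export.export(jax.jit(jnp.sin))(
    jax.ShapeDtypeStruct(export.symbolic_shape("n"), np.float32))

print(static.uses_global_constants)  # static shapes, single platform
print(poly.uses_global_constants)    # has the dimension variable n
```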

disabled_safety_checks

a list of descriptors of safety checks that have been disabled at export time. See the docstring for DisabledSafetyCheck.

Type:

Sequence[DisabledSafetyCheck]

_get_vjp

an optional function that takes the current exported function and returns the exported VJP function. The VJP function takes a flat list of arguments, starting with the primal arguments and followed by a cotangent argument for each primal output. It returns a tuple with the cotangents corresponding to the flattened primal inputs.

Type:

Callable[[Exported], Exported] | None

See the description of the calling convention for the mlir_module at https://jax.readthedocs.io/en/latest/export/export.html#module-calling-convention.

call(*args, **kwargs)

Call an exported function from a JAX program.

Parameters:
  • args – the positional arguments to pass to the exported function. This should be a pytree of arrays with the same pytree structure as the arguments for which the function was exported.

  • kwargs – the keyword arguments to pass to the exported function.

Returns:

a pytree of result arrays, with the same structure as the results of the exported function.

The invocation supports reverse-mode AD, and all the features supported by exporting: shape polymorphism, multi-platform, device polymorphism. See the examples in the JAX export documentation at https://jax.readthedocs.io/en/latest/export/export.html.
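The following sketch calls an Exported from inside a jitted JAX function and differentiates through it. It assumes a CPU run with float32 inputs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(jnp.sin))(jax.ShapeDtypeStruct((3,), np.float32))

@jax.jit
def g(x):
    # The exported function composes with ordinary JAX code under jit.
    return exp.call(x) + 1.0

x = np.array([0.0, 0.5, 1.0], np.float32)
out = g(x)
# Reverse-mode AD through the exported function.
grads = jax.grad(lambda v: jnp.sum(exp.call(v)))(x)
```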

has_vjp()

Returns whether this Exported supports VJP.

Return type:

bool
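The following sketch assumes that serialize() with the default vjp_order=0 omits the VJP. Under that assumption, the deserialized object reports has_vjp() as False unless a higher order is requested.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(jnp.sin))(jax.ShapeDtypeStruct((), np.float32))
print(exp.has_vjp())  # a freshly exported function can compute its VJP

no_vjp = export.deserialize(exp.serialize())              # default vjp_order=0
with_vjp = export.deserialize(exp.serialize(vjp_order=1))
print(no_vjp.has_vjp(), with_vjp.has_vjp())
```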

in_shardings_jax(mesh)

Creates Shardings corresponding to self.in_shardings_hlo.

The Exported object stores in_shardings_hlo as HloShardings, which are independent of a mesh or set of devices. This method constructs Shardings that can be used in JAX APIs such as jax.jit or jax.device_put.

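Example usage (the output shown assumes 8 devices, e.g., 8 CPU devices obtained with XLA_FLAGS=--xla_force_host_platform_device_count=8):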

>>> import jax
>>> import numpy as np
>>> from jax import export, sharding
>>> # Prepare the exported object:
>>> exp_mesh = sharding.Mesh(jax.devices(), ("a",))
>>> exp = export.export(jax.jit(lambda x: jax.numpy.add(x, x),
...                             in_shardings=sharding.NamedSharding(exp_mesh, sharding.PartitionSpec("a")))
...     )(np.arange(jax.device_count()))
>>> exp.in_shardings_hlo
({devices=[8]<=[8]},)
>>> # Create a mesh for running the exported object
>>> run_mesh = sharding.Mesh(jax.devices()[::-1], ("b",))
>>> # Put the args and kwargs on the appropriate devices
>>> run_arg = jax.device_put(np.arange(jax.device_count()),
...     exp.in_shardings_jax(run_mesh)[0])
>>> res = exp.call(run_arg)
>>> res.addressable_shards
[Shard(device=CpuDevice(id=7), index=(slice(0, 1, None),), replica_id=0, data=[0]),
 Shard(device=CpuDevice(id=6), index=(slice(1, 2, None),), replica_id=0, data=[2]),
 Shard(device=CpuDevice(id=5), index=(slice(2, 3, None),), replica_id=0, data=[4]),
 Shard(device=CpuDevice(id=4), index=(slice(3, 4, None),), replica_id=0, data=[6]),
 Shard(device=CpuDevice(id=3), index=(slice(4, 5, None),), replica_id=0, data=[8]),
 Shard(device=CpuDevice(id=2), index=(slice(5, 6, None),), replica_id=0, data=[10]),
 Shard(device=CpuDevice(id=1), index=(slice(6, 7, None),), replica_id=0, data=[12]),
 Shard(device=CpuDevice(id=0), index=(slice(7, 8, None),), replica_id=0, data=[14])]

Parameters:

mesh (sharding.Mesh)

Return type:

Sequence[sharding.Sharding | None]

mlir_module()

A string representation of the mlir_module_serialized.

Return type:

str
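The following sketch compares the human-readable module text with the serialized bytes it is derived from. The exact text varies between JAX versions, so it is best inspected rather than matched exactly.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(jnp.cos))(jax.ShapeDtypeStruct((2,), np.float32))

text = exp.mlir_module()            # human-readable module text
raw = exp.mlir_module_serialized    # the serialized bytes it is derived from
print(text)
print(len(raw), "bytes")
```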

out_shardings_jax(mesh)

Creates Shardings corresponding to self.out_shardings_hlo.

See the documentation for in_shardings_jax.

Parameters:

mesh (sharding.Mesh)

Return type:

Sequence[sharding.Sharding | None]

serialize(vjp_order=0)

Serializes an Exported.

Parameters:

vjp_order (int) – The maximum VJP order to include. E.g., the value 2 means that we serialize the primal function and two orders of the VJP function. This should allow second-order reverse-mode differentiation of the deserialized function, i.e., jax.grad(jax.grad(f)).

Return type:

bytearray
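The following is a sketch of the vjp_order=2 case described above. It assumes a scalar float32 input. Mathematically, the first and second derivatives of sin are cos(x) and -sin(x).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(jnp.sin))(jax.ShapeDtypeStruct((), np.float32))

# Keep the primal function plus two orders of VJP in the serialized form.
rehydrated = export.deserialize(exp.serialize(vjp_order=2))

x = np.float32(0.5)
d1 = jax.grad(rehydrated.call)(x)            # mathematically cos(x)
d2 = jax.grad(jax.grad(rehydrated.call))(x)  # mathematically -sin(x)
print(d1, d2)
```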

vjp()

Gets the exported VJP.

Returns None if not available, which can happen if the Exported has been loaded from an external format without a VJP.

Return type:

Exported
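The following sketch follows the VJP calling convention described for _get_vjp. The arguments are the primal arguments first, followed by one cotangent per primal output. The result is a tuple of cotangents for the flattened primal inputs. The scalar sin function is illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(jax.jit(jnp.sin))(jax.ShapeDtypeStruct((), np.float32))
exp_vjp = exp.vjp()

x = np.float32(0.5)   # primal input
ct = np.float32(1.0)  # cotangent for the single primal output
(x_ct,) = exp_vjp.call(x, ct)  # one cotangent per flattened primal input
print(x_ct)
```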

class jax.export.DisabledSafetyCheck(_impl)

A safety check that should be skipped on (de)serialization.

Most of these checks are performed on serialization, but some are deferred to deserialization. The list of disabled checks is attached to the serialization, e.g., as a sequence of string attributes to jax.export.Exported or of tf.XlaCallModuleOp.

When using jax2tf, you can disable more deserialization safety checks by passing TF_XLA_FLAGS=--tf_xla_call_module_disabled_checks=platform.

Parameters:

_impl (str)

classmethod custom_call(target_name)

Allows the serialization of a call target not known to be stable.

Has effect only on serialization.

Parameters:

target_name (str) – the name of the custom call target to allow.

Return type:

DisabledSafetyCheck

is_custom_call()

Returns the custom call target allowed by this directive.

Return type:

str | None
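The following sketch constructs disabled checks and queries them with is_custom_call. my_custom_target is a hypothetical call target name.

```python
from jax import export

# Hypothetical target name, used only for illustration.
allow = export.DisabledSafetyCheck.custom_call("my_custom_target")
plat = export.DisabledSafetyCheck.platform()

print(allow.is_custom_call())  # the allowed target name
print(plat.is_custom_call())   # None: not a custom-call check
```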

classmethod platform()

Allows the compilation platform to differ from the export platform.

Has effect only on deserialization.

Return type:

DisabledSafetyCheck
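The following sketch passes disabled checks at export time through the disabled_checks argument of export. The platform check only takes effect later, at deserialization. Meanwhile, it is recorded in disabled_safety_checks.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

exp = export.export(
    jax.jit(jnp.sin),
    disabled_checks=[export.DisabledSafetyCheck.platform()],
)(jax.ShapeDtypeStruct((2,), np.float32))

# The disabled checks travel with the Exported object.
print(exp.disabled_safety_checks)
```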

Functions

export(fun_jit, *[, platforms, disabled_checks])

Exports a JAX function for persistent serialization.

deserialize(blob)

Deserializes an Exported.

minimum_supported_calling_convention_version

The minimum supported calling convention version (an int); see Calling convention versions.

maximum_supported_calling_convention_version

The maximum supported calling convention version (an int); see Calling convention versions.

default_export_platform()

Retrieves the default export platform.

register_pytree_node_serialization(nodetype, ...)

Registers a custom PyTree node for serialization and deserialization.

register_namedtuple_serialization(nodetype, ...)

Registers a namedtuple for serialization and deserialization.

Constants

jax.export.minimum_supported_serialization_version

The minimum supported serialization version; see Calling convention versions.

jax.export.maximum_supported_serialization_version

The maximum supported serialization version; see Calling convention versions.