jax.extend.xla.register_hlo_module_transformation


jax.extend.xla.register_hlo_module_transformation(callback, *, name, stage=PipelineStage.PRE_SCHEDULER, platforms=None)

Register a custom compiler pass that transforms HLO modules.

The registered pass will be called during XLA compilation at the specified pipeline stage. The callback receives a serialized HloModuleProto as bytes and should return either:

  • Modified serialized HloModuleProto bytes if the module was changed.

  • None if no changes were made.
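Here is a minimal sketch of this callback contract. The function only inspects the module and always returns None, so it never reports a change. The names record_module_size and seen_module_sizes are hypothetical. The callback also works on the raw bytes only, because decoding the HloModuleProto would require XLA's protobuf bindings, which this sketch deliberately avoids.

```python
from typing import Optional

# Hypothetical bookkeeping: the size of every serialized module the pass has seen.
seen_module_sizes: list[int] = []


def record_module_size(serialized_module: bytes) -> Optional[bytes]:
    """Inspect-only pass: note the size of the serialized HloModuleProto.

    Returning None tells the compiler that the module was not changed.
    """
    seen_module_sizes.append(len(serialized_module))
    return None
```

A pass that does rewrite the module would instead return the newly serialized HloModuleProto bytes.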

Multiple callbacks registered at the same stage are added to a queue and invoked in the order in which they were registered.
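The ordering rule can be illustrated with a toy, pure-Python model. The names toy_register and toy_run_stage are hypothetical, and this is not how JAX implements its queue. The assumption that each pass receives the module as returned by the previous pass is also ours, not this documentation's.

```python
from collections import defaultdict
from typing import Callable, Optional

HloCallback = Callable[[bytes], Optional[bytes]]

# Toy model: one FIFO queue of (name, callback) pairs per stage name.
# It only mimics the documented ordering; it is not JAX's implementation.
_stage_queues: dict[str, list[tuple[str, HloCallback]]] = defaultdict(list)


def toy_register(callback: HloCallback, *, name: str, stage: str = "PRE_SCHEDULER") -> None:
    """Append the callback to its stage's queue, preserving registration order."""
    _stage_queues[stage].append((name, callback))


def toy_run_stage(stage: str, module: bytes) -> bytes:
    """Invoke every callback registered for `stage`, in registration order.

    Assumption: each pass sees the module as left by the previous pass,
    and a None result leaves the module unchanged.
    """
    for _name, callback in _stage_queues[stage]:
        result = callback(module)
        if result is not None:
            module = result
    return module
```

In this model, registering a pass that returns None and then a pass that returns m + b"!" at the same stage, and running that stage on b"hlo", calls the first pass before the second and yields b"hlo!".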

Parameters:
  • callback (Callable[[bytes], bytes | None]) – A function (bytes) -> bytes | None that receives a serialized HloModuleProto and optionally returns a modified one.

  • name (str) – A name for the compiler pass.

  • stage (PipelineStage) – The pipeline stage at which the pass runs. Must be a member of the PipelineStage enum; defaults to PipelineStage.PRE_SCHEDULER.

  • platforms (Sequence[str] | str | None) – The platforms to register the pass for (e.g. "cpu", "tpu"), given either as a single platform string or as a sequence of strings. If None (the default), the pass is registered for all known backends.
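The following sketch shows a registration call that uses only the documented callback, name, and platforms arguments. The pass and helper names are hypothetical. The import sits inside the helper so that defining it has no side effects. Call register_inspect_only_pass() before the compilations the pass should affect.

```python
from typing import Optional


def inspect_only_pass(serialized_module: bytes) -> Optional[bytes]:
    """Hypothetical pass that never changes the module."""
    # A real pass would parse the HloModuleProto here; returning None
    # reports that no changes were made.
    return None


def register_inspect_only_pass() -> None:
    """Register the pass for CPU and TPU compilations."""
    from jax.extend import xla as jex_xla

    jex_xla.register_hlo_module_transformation(
        inspect_only_pass,
        name="inspect_only_pass",  # hypothetical pass name
        platforms=["cpu", "tpu"],  # a single string such as "cpu" is also accepted
        # stage is omitted, so the default PipelineStage.PRE_SCHEDULER applies
    )
```

Passing platforms=None, or leaving it out, would register the same pass for all known backends instead.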

Return type:

None