jax.extend.xla.register_hlo_module_transformation
- jax.extend.xla.register_hlo_module_transformation(callback, *, name, stage=PipelineStage.PRE_SCHEDULER, platforms=None)
Register a custom compiler pass that transforms HLO modules.
The registered pass will be called during XLA compilation at the specified pipeline stage. The callback receives a serialized HloModuleProto as bytes and should return either:
  - the modified serialized HloModuleProto bytes, if the module was changed, or
  - None, if no changes were made.
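Because the callback must return None when it leaves the module untouched, one convenient pattern is to write the transformation as a plain bytes -> bytes function and wrap it so that unchanged output is reported as None. The sketch below assumes that approach; only_if_changed and noop_pass are hypothetical names used for illustration, not part of jax.extend.xla, and the check is a byte-for-byte comparison of the serialized proto.

```python
from typing import Callable, Optional


# Hypothetical helper (not a JAX API): adapts a bytes -> bytes transform to the
# callback contract by returning None when the serialized module is unchanged.
def only_if_changed(
    transform: Callable[[bytes], bytes],
) -> Callable[[bytes], Optional[bytes]]:
    def callback(module_proto: bytes) -> Optional[bytes]:
        # module_proto is a serialized HloModuleProto.
        new_proto = transform(module_proto)
        # Report "no changes" as None rather than returning identical bytes.
        return None if new_proto == module_proto else new_proto

    return callback


# Identity transform: never modifies the module, so the callback returns None.
noop_pass = only_if_changed(lambda module_proto: module_proto)
```

Note that re-serializing a protobuf message is not guaranteed to reproduce identical bytes, so a real pass that parses and re-serializes the module should track whether it actually changed anything instead of relying only on this comparison.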
Multiple callbacks registered at the same stage are added to a queue and invoked in the order in which they were registered.
- Parameters:
  - callback (Callable[[bytes], bytes | None]) – A function (bytes) -> bytes | None that receives a serialized HloModuleProto and optionally returns a modified one.
  - name (str) – A name for the compiler pass.
  - stage (PipelineStage) – The pipeline stage at which the pass runs. Must be a member of the PipelineStage enum.
  - platforms (Sequence[str] | str | None) – The platform or platforms to register the pass for (e.g. "cpu", "tpu"). Can be a single platform string or a sequence of strings. If None, the pass is registered for all known backends by default.
- Return type:
None
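A registration sketch follows. It assumes a JAX release that ships jax.extend.xla.register_hlo_module_transformation with the signature above; the import is guarded so the snippet registers nothing on versions that lack it. The two callbacks only record the size of the module they receive and return None, so they leave compilation unchanged; their names (log_size_a, log_size_b) and the seen list are hypothetical. Both use the default stage, so per the description above they should run in registration order, log_size_a first. The platforms argument is passed once as a string and once as a sequence to show both accepted forms.

```python
from typing import List, Optional, Tuple

try:
    # Assumes a JAX version that provides this module and function.
    from jax.extend import xla as jex_xla

    register = jex_xla.register_hlo_module_transformation
except (ImportError, AttributeError):
    register = None

# (pass name, size in bytes of the serialized HloModuleProto it received)
seen: List[Tuple[str, int]] = []


def log_size_a(module_proto: bytes) -> Optional[bytes]:
    seen.append(("log_size_a", len(module_proto)))
    return None  # no changes to the module


def log_size_b(module_proto: bytes) -> Optional[bytes]:
    seen.append(("log_size_b", len(module_proto)))
    return None  # no changes to the module


if register is not None:
    # Same (default) stage, so these are queued in registration order: a, then b.
    register(log_size_a, name="log_size_a", platforms="cpu")
    register(log_size_b, name="log_size_b", platforms=["cpu"])
```

After registration, a later CPU compilation (not a cache hit) would be expected to append one entry per pass to seen, with log_size_a before log_size_b; since both return None, the compiled program is unaffected.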