jax.lax.platform_dependent
- jax.lax.platform_dependent(*args, default=None, **per_platform)
Stages out platform-specific code.
In JAX, the actual platform on which a computation runs is determined very late, e.g., based on where the data is located. When using AOT lowering or serialization, the computation may be compiled and executed on a different machine, or even on a platform that is not available at lowering time. This means that it is not safe to write platform-dependent code using Python conditionals, e.g., based on the current default JAX platform. Instead, one can use platform_dependent.

Usage:
```python
def cpu_code(*args): ...
def tpu_code(*args): ...
def other_platforms_code(*args): ...

res = platform_dependent(*args, cpu=cpu_code, tpu=tpu_code,
                         default=other_platforms_code)
```
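To make the contrast with a Python conditional concrete, the sketch below gives the usage pattern illustrative bodies (sin on CPU, cos on TPU, tanh elsewhere; these choices are assumptions, not part of the API) and compares it with a function that branches on jax.default_backend() at trace time. jax.make_jaxpr is used only to inspect which alternatives end up in the traced program.

```python
import jax
import jax.numpy as jnp

# Illustrative branches (assumptions): all take and return arrays of the same
# shape and dtype, as the underlying switch() requires.
def cpu_code(x):
    return jnp.sin(x)

def tpu_code(x):
    return jnp.cos(x)

def other_platforms_code(x):
    return jnp.tanh(x)

def python_conditional(x):
    # Anti-pattern: the branch is fixed at trace time by the default backend
    # of the tracing machine, so only one alternative is ever traced.
    backend = jax.default_backend()
    if backend == "cpu":
        return cpu_code(x)
    if backend == "tpu":
        return tpu_code(x)
    return other_platforms_code(x)

def staged(x):
    # All alternatives are traced; the choice is deferred to lowering.
    return jax.lax.platform_dependent(
        x, cpu=cpu_code, tpu=tpu_code, default=other_platforms_code)

x = jnp.linspace(0.0, 1.0, 4)
jaxpr_python = str(jax.make_jaxpr(python_conditional)(x))
jaxpr_staged = str(jax.make_jaxpr(staged)(x))
res = jax.jit(staged)(x)  # executes the branch for the current platform
```

On a CPU-only machine, jaxpr_python should mention only sin, while jaxpr_staged contains all three alternatives, and res should match cpu_code(x).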
When the staged-out code is executed on a CPU, the platform_dependent call in the usage example is equivalent to cpu_code(*args); on a TPU, it is equivalent to tpu_code(*args); and on any other platform, to other_platforms_code(*args). Unlike with a Python conditional, all alternatives are traced and staged out to Jaxpr. This is similar to, and is implemented in terms of, switch(), from which it inherits the behavior under transformations.

Unlike with switch(), the choice of what gets executed is made earlier: in most cases during lowering, when the lowering platform is known. In the rare case of multi-platform lowering and serialization, the StableHLO code contains a conditional on the actual platform. This conditional is resolved just in time, before compilation, once the compilation platform is known. As a result, the compiler never sees a conditional.

- Parameters:
*args (Any) – JAX arrays passed to each of the branches. May be PyTrees.
**per_platform (Callable[..., _T]) – branches to use for different platforms. The branches are JAX callables invoked with *args. The keywords are platform names, e.g., 'cpu', 'tpu', 'cuda', 'rocm'.
default (Callable[..., _T] | None) – optional default branch to use for a platform not mentioned in per_platform. If there is no default, there will be an error when the code is lowered for a platform not mentioned in per_platform.
- Returns:
The value per_platform[execution_platform](*args), or default(*args) if the execution platform is not one of the per_platform keywords.
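For the multi-platform lowering and serialization case, here is a sketch that assumes the jax.export module (public in recent JAX releases): the function is exported for both CPU and TPU from whatever machine runs the code, so the StableHLO keeps the platform conditional, and the artifact is then called on the current machine, where it should return the branch for that platform. The branches are illustrative, and the call is only attempted when the current backend is one of the exported platforms.

```python
import jax
import jax.numpy as jnp
from jax import export

def f(x):
    # Illustrative branches: sin on CPU, cos on TPU, tanh elsewhere.
    return jax.lax.platform_dependent(
        x, cpu=jnp.sin, tpu=jnp.cos, default=jnp.tanh)

x = jnp.linspace(0.0, 1.0, 4)

# Lower once for two platforms; a TPU need not be present for this step.
exp = export.export(jax.jit(f), platforms=("cpu", "tpu"))(x)

# Calling an Exported only works on a platform it was lowered for; the
# platform conditional is resolved when the module is compiled here.
res = exp.call(x) if jax.default_backend() in exp.platforms else None
```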