jax.lib module

The jax.lib package is a set of internal tools and types for bridging between JAX’s Python frontend and its XLA backend.

jax.lib.xla_bridge

get_backend([platform])

Returns the XLA backend client for platform, or the default backend if platform is omitted.
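
The following is a minimal sketch of querying backends. It assumes a working JAX install with the CPU backend available. Some newer JAX releases deprecate jax.lib.xla_bridge in favor of jax.extend.backend, so the import falls back to that module. Treat the fallback path as an assumption about your JAX version, not as part of this page's API.

```python
import jax

# Use the documented location first. Assumption: on newer JAX versions where
# jax.lib.xla_bridge no longer provides get_backend, jax.extend.backend does.
try:
    from jax.lib.xla_bridge import get_backend
except ImportError:
    from jax.extend.backend import get_backend

# No argument: the backend JAX selects by default (e.g. CPU, GPU, or TPU).
default_backend = get_backend()

# Explicit platform name: request the CPU backend specifically.
cpu_backend = get_backend("cpu")

print("default platform:", default_backend.platform)
print("cpu devices:", cpu_backend.device_count())
```

The returned client object exposes the platform name and its devices. jax.default_backend() reports the same platform string as get_backend().platform.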

get_compile_options(num_replicas, num_partitions)

Returns the compile options to use, as derived from flag values.

jax.lib.xla_client

register_custom_call_target(name, fn[, ...])

Registers a custom call target.