jax.distributed module

initialize([coordinator_address, ...])    Initializes the JAX distributed system.
shutdown()    Shuts down the distributed system.
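
The two functions are typically used as a pair around a multi-process program. The sketch below shows one way this might look. The environment variable names (COORDINATOR_ADDRESS, NUM_PROCESSES, PROCESS_ID) and the helper names are hypothetical and chosen only for illustration; they are not part of JAX. On clusters where JAX can detect the environment automatically, initialize() may be callable with no arguments.

```python
import os
from typing import Any, Mapping, Optional

import jax


def distributed_config_from_env(
    env: Optional[Mapping[str, str]] = None,
) -> dict[str, Any]:
    """Build keyword arguments for jax.distributed.initialize.

    The variable names read here are hypothetical, not a JAX convention.
    """
    env = os.environ if env is None else env
    return {
        # "host:port" of process 0, reachable from every process.
        "coordinator_address": env["COORDINATOR_ADDRESS"],
        # Total number of processes taking part in the job.
        "num_processes": int(env["NUM_PROCESSES"]),
        # This process's index, in the range [0, num_processes).
        "process_id": int(env["PROCESS_ID"]),
    }


def run_distributed() -> None:
    """Initialize the distributed system, report topology, then shut down."""
    # Call this before running any JAX computation in the process.
    jax.distributed.initialize(**distributed_config_from_env())
    try:
        print(f"process {jax.process_index()} of {jax.process_count()}")
        print(
            f"local devices: {jax.local_device_count()}, "
            f"global devices: {jax.device_count()}"
        )
    finally:
        # Release the connection to the coordinator even if the body raises.
        jax.distributed.shutdown()
```

Each process in the job would call run_distributed() with the same COORDINATOR_ADDRESS and NUM_PROCESSES but its own PROCESS_ID. The try/finally is a design choice in this sketch so that shutdown() still runs when the body fails.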