GPU performance tips#

This document focuses on performance tips for neural network workloads.

Matmul precision#

On recent GPU generations, such as the Nvidia A100 generation or later, it can be a good idea to perform most computations in bfloat16 precision. For example, if using Flax, instantiate Dense layers using flax.linen.Dense(..., dtype=jax.numpy.bfloat16).
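As a framework-independent illustration of the same idea, the sketch below casts the inputs of a matmul to bfloat16 using only jax.numpy. The helper name dense_bf16 and the array shapes are hypothetical and chosen only for the example; whether bfloat16 is accurate enough depends on the model.

```python
import jax
import jax.numpy as jnp


def dense_bf16(x, w):
    """Matmul with both operands cast to bfloat16 (hypothetical helper)."""
    return jnp.dot(x.astype(jnp.bfloat16), w.astype(jnp.bfloat16))


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 128), dtype=jnp.float32)
w = jax.random.normal(key_w, (128, 64), dtype=jnp.float32)

# The result keeps the bfloat16 dtype of its operands.
y = jax.jit(dense_bf16)(x, w)
print(y.dtype, y.shape)
```

If some parts of a model need higher precision, the result can be cast back with y.astype(jnp.float32) at that boundary.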

XLA performance flags#

Note

JAX-Toolbox also has a page on NVIDIA XLA performance FLAGS.

The existence and exact behavior of XLA flags may be jaxlib-version dependent.

As of jaxlib==0.4.18 (released Oct 6 2023), setting these XLA flags can improve performance. Some are related to communication between GPUs, and so are only relevant when running computations on multiple devices, while others are related to code generation on each device.

Some of these may be set by default in future releases.

These flags can be set via the XLA_FLAGS shell environment variable. For example, we can add this to the top of a Python file:

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
)

For more examples, see also XLA Flags recommended for Pax training on Nvidia GPUs.
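Assigning os.environ['XLA_FLAGS'] directly replaces any flags already exported in the shell. A minimal sketch of a helper that appends flags instead is shown below; add_xla_flags is a hypothetical name, not part of JAX. It is safest to call it before importing jax, so that the flags are in place before XLA starts.

```python
import os


def add_xla_flags(*flags):
    """Append flags to XLA_FLAGS, keeping any flags that are already set."""
    existing = os.environ.get("XLA_FLAGS", "").split()
    for flag in flags:
        if flag not in existing:
            existing.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(existing)
    return os.environ["XLA_FLAGS"]


# Call this before `import jax`.
add_xla_flags(
    "--xla_gpu_triton_gemm_any=True",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
)
print(os.environ["XLA_FLAGS"])
```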

Code generation flags#

  • --xla_gpu_triton_gemm_any Use the Triton-based GEMM (matmul) emitter for any GEMM that it supports. The default value is False.

Communication tips#

Auto and manual PGLE#

The Profile Guided Latency Estimator (PGLE) workflow measures the actual running time of compute and collectives, then feeds the profile information back into the XLA compiler so that it can make better scheduling decisions.

The Profile Guided Latency Estimator can be used manually or automatically. In auto mode, JAX collects profile information and recompiles a module within a single run. In manual mode, you need to run a task twice: the first time to collect and save profiles, and the second time to compile and run with the provided data.

Auto PGLE#

Auto PGLE can be turned on by setting the following environment variables:

Mandatory:

export JAX_ENABLE_PGLE=true

# For JAX version <= 0.5.0 make sure to include:
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"

Optional:

export JAX_PGLE_PROFILING_RUNS=3
export JAX_PGLE_AGGREGATION_PERCENTILE=85

# Right now the auto PGLE profile collection doesn't work with command buffers.
# If the command buffer is enabled, auto PGLE will disable it during profile
# collection and re-enable it after the recompilation. If you need consistent
# command buffer behavior with and without the PGLE profile, you can disable it
# manually:
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_command_buffer=''"

Alternatively, this can be set within JAX as follows:

import jax
from jax._src import config

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  # Run with the profiler collecting performance information.
  train_step()
  # Automatically re-compile with PGLE profile results
  train_step()
  ...
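The environment variables listed above can also be set from Python rather than from the shell. The sketch below uses only os.environ and the variable names given in this section; the values are the example values from above. JAX reads its JAX_* environment variables when it is imported, so the assignments need to come before import jax.

```python
import os

# Set these before `import jax`; they are read when JAX is imported.
os.environ["JAX_ENABLE_PGLE"] = "true"

# Optional tuning knobs (example values from this section).
os.environ["JAX_PGLE_PROFILING_RUNS"] = "3"
os.environ["JAX_PGLE_AGGREGATION_PERCENTILE"] = "85"

pgle_settings = {k: v for k, v in os.environ.items() if "PGLE" in k}
print(pgle_settings)
```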

You can control the number of runs used to collect profile data by changing JAX_PGLE_PROFILING_RUNS. Increasing this parameter leads to better profile information, but it also increases the number of non-optimized training steps.

Decreasing the JAX_PGLE_AGGREGATION_PERCENTILE parameter might help filter out non-relevant measurements when performance varies too much between steps.
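To see why a lower percentile can suppress noisy measurements, consider the generic illustration below. It is not XLA's aggregation code: the nearest-rank percentile definition, the helper name, and the latency values are all assumptions made for the example.

```python
def nearest_rank_percentile(samples, percentile):
    """Nearest-rank percentile of a non-empty list (integer percentile 1..100)."""
    ordered = sorted(samples)
    # Ceiling division with integers avoids floating-point rounding surprises.
    rank = max(1, -(-percentile * len(ordered) // 100))
    return ordered[rank - 1]


# Hypothetical latencies (microseconds) of one operation over five profiling
# runs; the last run is an outlier.
latencies_us = [100, 102, 98, 101, 250]

high = nearest_rank_percentile(latencies_us, 85)  # picks the outlier
low = nearest_rank_percentile(latencies_us, 50)   # picks a typical value
print(high, low)
```

In this toy data, the 85th percentile lands on the outlier while the 50th percentile lands on a typical measurement.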

Attention: Auto PGLE doesn’t work for pre-compiled modules. Since JAX needs to recompile the module during execution, auto PGLE works neither for AoT compilation nor for the following case:

import jax
from jax._src import config

train_step_compiled = jax.jit(train_step).lower().compile()

with config.enable_pgle(True), config.pgle_profiling_runs(1):
  train_step_compiled()
  # No effect since module was pre-compiled.
  train_step_compiled()

Manual PGLE#

If you still want to use the Profile Guided Latency Estimator manually, the workflow in XLA/GPU is:

    1. Run your workload once, with async collectives and latency hiding scheduler enabled.

You could do so by setting:

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"

    2. Collect and post-process a profile using the JAX profiler, saving the extracted instruction latencies into a binary protobuf file.

import logging
import os
from etils import epath
import jax
from jax.experimental import profiler as exp_profiler

# Define your profile directory
profile_dir = 'gs://my_bucket/profile'
jax.profiler.start_trace(profile_dir)

# run your workflow
# for i in range(10):
#   train_step()

# Stop trace
jax.profiler.stop_trace()
profile_dir = epath.Path(profile_dir)
directories = profile_dir.glob('plugins/profile/*/')
directories = [d for d in directories if d.is_dir()]
rundir = directories[-1]
logging.info('rundir: %s', rundir)

# Post process the profile
fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))

# Save the profile proto to a file.
dump_dir = rundir / 'profile.pb'
dump_dir.parent.mkdir(parents=True, exist_ok=True)
dump_dir.write_bytes(fdo_profile)

After this step, you will get a profile.pb file under the rundir printed in the code.

    3. Run the workload again, feeding that file into the compilation.

You need to pass the profile.pb file to the --xla_gpu_pgle_profile_file_or_directory_path flag.

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"
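The same flags can be assembled in Python before JAX is imported. In the sketch below, manual_pgle_flags is a hypothetical helper and the profile path is the placeholder from the shell example above; replace it with the location of your own profile.pb.

```python
import os


def manual_pgle_flags(profile_path):
    """Build the XLA_FLAGS value for the second, profile-guided run."""
    return " ".join([
        "--xla_gpu_enable_latency_hiding_scheduler=true",
        f"--xla_gpu_pgle_profile_file_or_directory_path={profile_path}",
    ])


profile_path = "/path/to/profile/profile.pb"  # placeholder path
if not os.path.exists(profile_path):
    print(f"warning: {profile_path} does not exist")

# Set before `import jax` so the flags are in place when XLA starts.
os.environ["XLA_FLAGS"] = manual_pgle_flags(profile_path)
```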

To enable logging in XLA and check that the profile is being used, set the logging level to include INFO:

export TF_CPP_MIN_LOG_LEVEL=0

Run the real workflow. If you find log lines like these in the output, the profile is being used by the latency hiding scheduler:

2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
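For long logs, a small script can do this check. The sketch below searches for the two message fragments shown above; the helper name is hypothetical, and the exact message text may differ between XLA versions, so treat the marker strings as assumptions.

```python
PGLE_LOG_MARKERS = (
    "Using PGLE profile from",
    "Found profile, using profile guided latency estimator",
)


def find_pgle_log_lines(log_text):
    """Return the log lines that contain one of the PGLE marker strings."""
    return [
        line for line in log_text.splitlines()
        if any(marker in line for marker in PGLE_LOG_MARKERS)
    ]


sample_log = """\
2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
"""

matches = find_pgle_log_lines(sample_log)
print(f"{len(matches)} PGLE log line(s) found")
```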

Flags#

  • --xla_gpu_enable_latency_hiding_scheduler This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False.

  • --xla_gpu_memory_limit_slop_factor This flag serves as a multiplier applied to the total available memory, creating a threshold that guides the Latency Hiding Scheduler (LHS) in balancing memory reduction and latency hiding optimizations. The default value is 95.

    This factor effectively establishes a memory limit for compiler passes, determining when the scheduler should prioritize:

    1. Memory reduction: When memory usage approaches or exceeds the calculated threshold.

    2. Latency hiding: When memory usage is below the threshold, allowing for more aggressive optimizations that may temporarily increase memory usage but improve overall performance.

    By adjusting this factor, users can fine-tune the trade-off between memory efficiency and performance optimizations.

  • --xla_gpu_enable_pipelined_collectives When using pipeline parallelism, this flag enables overlapping the (i+1)-th layer weight AllGather with the i-th layer computation. It also enables overlapping (i+1)-th layer weight Reduce/ReduceScatter with i-th layer’s computation. The default value is False. There are some bugs when this flag is turned on.

  • --xla_gpu_collective_permute_decomposer_threshold This flag is useful when performing GSPMD pipelining. Setting a nonzero threshold decomposes CollectivePermutes into CollectivePermuteReceiveDone and CollectivePermuteSendDone pairs, so that computation can be performed between each corresponding ReceiveDone/SendDone pair and hence achieve more overlap. By default the threshold is 0 and there is no decomposition. Setting it to threshold > 0 such as --xla_gpu_collective_permute_decomposer_threshold=1024 can enable this feature.

  • --xla_gpu_all_gather_combine_threshold_bytes --xla_gpu_reduce_scatter_combine_threshold_bytes --xla_gpu_all_reduce_combine_threshold_bytes These flags tune when to combine multiple small AllGather/ReduceScatter/AllReduce into one big AllGather/ReduceScatter/AllReduce to reduce time spent on cross-device communication. For example, for the AllGather/ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer’s weight AllGather/ReduceScatter. By default, the combine_threshold_bytes is set to 256.
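As a rough way to pick a starting value for the combine thresholds, the sketch below estimates the weight bytes of one Transformer layer and formats the corresponding flags. The layer dimensions, the bfloat16 weights, and the parameter-count formula (four attention projections plus a two-matrix MLP, ignoring biases and normalization) are illustrative assumptions. With sharded weights the per-device sizes differ, so treat the result as a starting point for tuning rather than a recommended value.

```python
def transformer_layer_weight_bytes(d_model, d_ff, bytes_per_param=2):
    """Rough weight size of one Transformer layer (hypothetical estimate)."""
    attention_params = 4 * d_model * d_model  # Q, K, V and output projections
    mlp_params = 2 * d_model * d_ff           # up and down projections
    return (attention_params + mlp_params) * bytes_per_param


# Hypothetical layer shape with bfloat16 (2-byte) weights.
threshold = transformer_layer_weight_bytes(d_model=4096, d_ff=16384)

combine_flags = " ".join([
    f"--xla_gpu_all_gather_combine_threshold_bytes={threshold}",
    f"--xla_gpu_reduce_scatter_combine_threshold_bytes={threshold}",
])
print(combine_flags)
```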

NCCL flags#

These Nvidia NCCL flag values may be useful for single-host multi-device computations on Nvidia GPUs:

import os

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
  "NCCL_PROTO": "SIMPLE,LL,LL128",
})

These NCCL flags could improve single-host communication speed. These flags don’t seem useful for multi-host communication yet.

Multi-Process#

We recommend using one process per GPU and not one per node. In some cases, this can speed up jitted computation. The jax.distributed.initialize() API will automatically understand that configuration when run under SLURM. However, this is only a rule of thumb, and it may be useful to test both one process per GPU and one process per node on your use case.
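A minimal sketch of a training script entry point for this setup follows. It assumes the job is launched with one SLURM task per GPU, and it uses the presence of the SLURM_JOB_ID environment variable as a simple heuristic for "running under SLURM"; that check and the helper name are assumptions of this example, not JAX API.

```python
import os


def init_distributed_if_slurm():
    """Call jax.distributed.initialize() only when launched under SLURM."""
    if "SLURM_JOB_ID" not in os.environ:
        return False  # plain single-process run; nothing to initialize
    import jax  # imported here so the check above has no JAX dependency
    # With no arguments, the cluster configuration is detected automatically.
    jax.distributed.initialize()
    return True


if __name__ == "__main__":
    distributed = init_distributed_if_slurm()
    print("distributed initialized:", distributed)
```

Call the helper before any other JAX computation, since the distributed runtime has to be set up before devices are used.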