Configuration Options#
JAX provides various configuration options to customize its behavior. These options control everything from numerical precision to debugging features.
How to Use Configuration Options#
JAX configuration options can be set in several ways:
Environment variables (set before running your program):
export JAX_ENABLE_X64=True
python my_program.py
Runtime configuration (in your Python code):
import jax
jax.config.update("jax_enable_x64", True)
Command-line flags (using Abseil):
# In your code:
import jax
jax.config.parse_flags_with_absl()

# When running:
python my_program.py --jax_enable_x64=True
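As a minimal sketch of how the first two mechanisms interact, the script below sets the environment variable from Python before JAX is imported (the environment is consulted when JAX's configuration is first loaded, so setting it after the import may have no effect), then overrides the value at runtime and reads it back. It assumes options can be read as attributes of jax.config, which is how current JAX versions expose them.

```python
import os

# Environment variables are read when JAX is imported, so set them first.
os.environ["JAX_ENABLE_X64"] = "True"

import jax

# Each option can be read back as an attribute of jax.config.
print(jax.config.jax_enable_x64)  # True

# A runtime update overrides the value taken from the environment.
jax.config.update("jax_enable_x64", False)
print(jax.config.jax_enable_x64)  # False
```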
Common Configuration Options#
Here are some of the most frequently used configuration options:
- jax_enable_x64: Enable 64-bit floating-point precision
- jax_disable_jit: Disable JIT compilation for debugging
- jax_debug_nans: Check for and raise errors on NaNs
- jax_platforms: Control which backends (CPU/GPU/TPU) JAX will initialize
- jax_numpy_rank_promotion: Control automatic rank promotion behavior
- jax_default_matmul_precision: Set default precision for matrix multiplication operations
All Configuration Options#
Below is a complete list of all available JAX configuration options:
Check Vma#
- Type:
bool
- Default Value:
False
- Configuration String:
'check_vma'
- Environment Variable:
CHECK_VMA
Internal implementation detail of shard_map. DO NOT USE.
Eager Constant Folding#
- Type:
bool
- Default Value:
False
- Configuration String:
'eager_constant_folding'
- Environment Variable:
EAGER_CONSTANT_FOLDING
Attempt constant folding during staging.
Jax2Tf Associative Scan Reductions#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax2tf_associative_scan_reductions'
- Environment Variable:
JAX2TF_ASSOCIATIVE_SCAN_REDUCTIONS
JAX has two separate lowering rules for the cumulative reduction primitives (cumsum, cumprod, cummax, cummin). On CPUs and GPUs it uses a lax.associative_scan, while for TPUs it uses the HLO ReduceWindow. The latter has a slow implementation on CPUs and GPUs. By default, jax2tf uses the TPU lowering. Set this flag to True to use the associative scan lowering instead, and only if it makes a difference for your application. See the jax2tf README.md for more details.
Jax2Tf Default Native Serialization#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax2tf_default_native_serialization'
- Environment Variable:
JAX2TF_DEFAULT_NATIVE_SERIALIZATION
Sets the default value of the native_serialization parameter to jax2tf.convert. Prefer using the parameter instead of the flag; the flag may be removed in the future. Starting with JAX 0.4.31, non-native serialization is deprecated.
Array Garbage Collection Guard#
- Type:
Enum values:
'allow', 'log', 'fatal'
- Default Value:
None
- Configuration String:
'jax_array_garbage_collection_guard'
- Environment Variable:
JAX_ARRAY_GARBAGE_COLLECTION_GUARD
Select garbage collection guard level for jax.Array objects.
This option can be used to control what happens when a jax.Array object is garbage collected. It is desirable for jax.Array objects to be freed by Python reference counting rather than garbage collection in order to avoid device memory being held by the arrays until garbage collection occurs.
Valid values are:
- allow: do not log garbage collection of jax.Array objects.
- log: log an error when a jax.Array is garbage collected.
- fatal: fatal error if a jax.Array is garbage collected.
Default is allow. Note that not all cycles may be detected.
Backend Target#
- Type:
str
- Default Value:
''
- Configuration String:
'jax_backend_target'
- Environment Variable:
JAX_BACKEND_TARGET
Bcoo Cusparse Lowering#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_bcoo_cusparse_lowering'
- Environment Variable:
JAX_BCOO_CUSPARSE_LOWERING
Enables lowering BCOO ops to cuSparse.
Check Proxy Envs#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax_check_proxy_envs'
- Environment Variable:
JAX_CHECK_PROXY_ENVS
Check proxy variables in the user's environment and emit warnings.
Check Tracer Leaks#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_check_tracer_leaks'
- Environment Variable:
JAX_CHECK_TRACER_LEAKS
Turn on checking for leaked tracers as soon as a trace completes. Enabling leak checking may have performance impacts: some caching is disabled, and other overheads may be added. Additionally, be aware that some Python debuggers can cause false positives, so it is recommended to disable any debuggers while leak checking is enabled.
Compilation Cache Dir#
- Type:
str
- Default Value:
None
- Configuration String:
'jax_compilation_cache_dir'
- Environment Variable:
JAX_COMPILATION_CACHE_DIR
Path for the cache. Precedence: 1. A call to compilation_cache.set_cache_dir(). 2. The value of this flag set in the command line or by default.
Compilation Cache Expect Pgle#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_compilation_cache_expect_pgle'
- Environment Variable:
JAX_COMPILATION_CACHE_EXPECT_PGLE
If set to True, compilation cache entries that were compiled with profile data (i.e. PGLE was enabled and the requisite number of executions were profiled) will be preferentially loaded, even if PGLE is not currently enabled. A warning will be printed when no preferred cache entry is found.
Compilation Cache Include Metadata In Key#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_compilation_cache_include_metadata_in_key'
- Environment Variable:
JAX_COMPILATION_CACHE_INCLUDE_METADATA_IN_KEY
Include metadata, such as file names and line numbers, in the compilation cache key. If false, the cache will still get hits even if functions or files are moved, etc. However, it means that executables loaded from the cache may have stale metadata, which may show up in, e.g., profiles.
Compilation Cache Max Size#
- Type:
int
- Default Value:
-1
- Configuration String:
'jax_compilation_cache_max_size'
- Environment Variable:
JAX_COMPILATION_CACHE_MAX_SIZE
The maximum size (in bytes) allowed for the persistent compilation cache. When set, the least recently accessed cache entries will be deleted once the total cache directory size exceeds the specified limit. Caching will be disabled if this value is set to 0. A special value of -1 indicates no limit, allowing the cache size to grow indefinitely.
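A minimal sketch of enabling the persistent compilation cache at runtime. The temporary directory is an illustrative assumption (a real setup would use a directory that survives between runs), and the 1 GiB cap and zero compile-time threshold are example values, not recommendations; jax_persistent_cache_min_compile_time_secs is described further down this page.

```python
import tempfile

import jax
import jax.numpy as jnp

# Illustrative cache location; use a persistent path in practice.
cache_dir = tempfile.mkdtemp(prefix="jax_cache_")

jax.config.update("jax_compilation_cache_dir", cache_dir)
# Cap the cache at 1 GiB; least recently accessed entries are evicted beyond that.
jax.config.update("jax_compilation_cache_max_size", 1 << 30)
# Cache every compilation, even fast ones (the default threshold is 1.0 s).
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)

@jax.jit
def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

result = f(jnp.arange(4.0))
print(result)
```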
Compiler Detailed Logging Min Ops#
- Type:
int
- Default Value:
10
- Configuration String:
'jax_compiler_detailed_logging_min_ops'
- Environment Variable:
JAX_COMPILER_DETAILED_LOGGING_MIN_OPS
How big should a module be in MLIR operations before JAX enables detailed compiler logging? The intent of this flag is to suppress detailed logging for small/uninteresting computations.
Compiler Enable Remat Pass#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax_compiler_enable_remat_pass'
- Environment Variable:
JAX_COMPILER_ENABLE_REMAT_PASS
Enable or disable the rematerialization HLO pass. This is useful for letting XLA automatically trade off memory and compute when encountering OOM errors. However, you are likely to get better results by rematerializing manually with jax.checkpoint.
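The manual alternative wraps a sub-computation in jax.checkpoint so that its intermediates are recomputed during the backward pass instead of being stored. A minimal sketch; the layer function is a made-up example, and the point is only that the gradient is unchanged while the memory/compute trade-off is chosen explicitly.

```python
import jax
import jax.numpy as jnp

def layer(x):
    # Intermediates such as sin(x) would normally be saved for the backward pass.
    return jnp.sin(jnp.sin(x))

def loss(x):
    return jnp.sum(layer(x))

def loss_remat(x):
    # jax.checkpoint recomputes layer's intermediates during differentiation.
    return jnp.sum(jax.checkpoint(layer)(x))

x = jnp.linspace(0.0, 1.0, 5)
g = jax.grad(loss)(x)
g_remat = jax.grad(loss_remat)(x)
print(jnp.allclose(g, g_remat))  # True
```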
Cpu Collectives Implementation#
- Type:
Enum values:
'gloo', 'mpi', 'megascale'
- Default Value:
'gloo'
- Configuration String:
'jax_cpu_collectives_implementation'
- Environment Variable:
JAX_CPU_COLLECTIVES_IMPLEMENTATION
Cross-process collective implementation used on CPU. Must be one of (“gloo”, “mpi”)
Cpu Enable Async Dispatch#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax_cpu_enable_async_dispatch'
- Environment Variable:
JAX_CPU_ENABLE_ASYNC_DISPATCH
Only applies to non-parallel computations. If False, run computations inline without async dispatch.
Cpu Enable Gloo Collectives#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_cpu_enable_gloo_collectives'
- Environment Variable:
JAX_CPU_ENABLE_GLOO_COLLECTIVES
Deprecated, please use jax_cpu_collectives_implementation instead.
Cuda Visible Devices#
- Type:
str
- Default Value:
'all'
- Configuration String:
'jax_cuda_visible_devices'
- Environment Variable:
JAX_CUDA_VISIBLE_DEVICES
Custom Vjp Disable Shape Check#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_custom_vjp_disable_shape_check'
- Environment Variable:
JAX_CUSTOM_VJP_DISABLE_SHAPE_CHECK
Disable the check from #19009 to enable some custom_vjp hacks. This will be enabled by default in future versions of JAX, at which point all uses of the flag will be considered deprecated (following the API compatibility policy).
Debug Infs#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_debug_infs'
- Environment Variable:
JAX_DEBUG_INFS
Add inf checks to every operation. When an inf is detected on the output of a jit-compiled computation, call into the un-compiled version in an attempt to more precisely identify the operation which produced the inf.
Debug Key Reuse#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_debug_key_reuse'
- Environment Variable:
JAX_DEBUG_KEY_REUSE
Turn on experimental key reuse checking. With this configuration enabled, typed PRNG keys (i.e. keys created with jax.random.key()) will have their usage tracked, and incorrect reuse of a previously-used key will lead to an error. Currently enabling this leads to a small Python overhead on every call to a JIT-compiled function with keys as inputs or outputs.
Debug Log Modules#
- Type:
str
- Default Value:
''
- Configuration String:
'jax_debug_log_modules'
- Environment Variable:
JAX_DEBUG_LOG_MODULES
Comma-separated list of module names (e.g. “jax” or “jax._src.xla_bridge,jax._src.dispatch”) to enable debug logging for.
Debug Nans#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_debug_nans'
- Environment Variable:
JAX_DEBUG_NANS
Add nan checks to every operation. When a nan is detected on the output of a jit-compiled computation, call into the un-compiled version in an attempt to more precisely identify the operation which produced the nan.
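A minimal sketch of the behavior described above: with jax_debug_nans enabled, a jitted function whose output contains a NaN is re-run op by op, and the offending operation raises a FloatingPointError. The log-of-a-negative-number function is an illustrative choice.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

@jax.jit
def f(x):
    return jnp.log(x)  # log of a negative number produces nan

caught = False
try:
    f(jnp.array(-1.0))
except FloatingPointError as e:
    caught = True
    print("caught:", type(e).__name__)
```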
Default Device#
- Type:
str
- Default Value:
None
- Configuration String:
'jax_default_device'
- Environment Variable:
JAX_DEFAULT_DEVICE
Configure the default device for JAX operations. Set to a Device object (e.g. jax.devices("cpu")[0]) to use that Device as the default device for JAX operations and jit’d function calls (there is no effect on multi-device computations, e.g. pmapped function calls). Set to None to use the system default device. See Controlling data and computation placement on devices for more information on device placement.
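A minimal sketch of both ways to set the default device: globally through the config option, or for a block through the jax.default_device context manager. It assumes a CPU device is available, which is the case in standard JAX installs.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# Process-wide default for single-device operations.
jax.config.update("jax_default_device", cpu)
x = jnp.ones(3)
print(x.devices())

# Scoped default for the operations inside the block.
with jax.default_device(cpu):
    y = jnp.zeros(2)
print(y.devices())
```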
Default Matmul Precision#
- Type:
Enum values:
'default', 'high', 'highest', 'bfloat16', 'tensorfloat32', 'float32', 'ANY_F8_ANY_F8_F32', 'ANY_F8_ANY_F8_F32_FAST_ACCUM', 'ANY_F8_ANY_F8_ANY', 'ANY_F8_ANY_F8_ANY_FAST_ACCUM', 'F16_F16_F16', 'F16_F16_F32', 'BF16_BF16_BF16', 'BF16_BF16_F32', 'BF16_BF16_F32_X3', 'BF16_BF16_F32_X6', 'BF16_BF16_F32_X9', 'TF32_TF32_F32', 'TF32_TF32_F32_X3', 'F32_F32_F32', 'F64_F64_F64'
- Default Value:
None
- Configuration String:
'jax_default_matmul_precision'
- Environment Variable:
JAX_DEFAULT_MATMUL_PRECISION
Control the default matmul and conv precision for 32bit inputs.
Some platforms, like TPU, offer configurable precision levels for matrix multiplication and convolution computations, trading off accuracy for speed. The precision can be controlled for each operation; for example, see the jax.lax.conv_general_dilated() and jax.lax.dot() docstrings. But it can be useful to control the default behavior obtained when an operation is not given a specific precision.
This option can be used to control the default precision level for computations involved in matrix multiplication and convolution on 32bit inputs. The levels roughly describe the precision at which scalar products are computed. The ‘bfloat16’ option is the fastest and least precise; ‘float32’ is similar to full float32 precision; ‘tensorfloat32’ is intermediate.
This parameter can also be used to specify an accumulation “algorithm” for functions that perform matrix multiplications, like jax.lax.dot(). To specify an algorithm, set this option to the name of a DotAlgorithmPreset.
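A minimal sketch of setting the default matmul precision globally and overriding it for one block with the jax.default_matmul_precision context manager. The random 64x64 inputs are illustrative. On CPU these settings may not change the result at all; the trade-off mostly shows up on TPU and GPU.

```python
import numpy as np
import jax
import jax.numpy as jnp

rng = np.random.default_rng(0)
a_np = rng.standard_normal((64, 64)).astype(np.float32)
b_np = rng.standard_normal((64, 64)).astype(np.float32)
a, b = jnp.asarray(a_np), jnp.asarray(b_np)

# Process-wide default for matmuls and convolutions on 32-bit inputs.
jax.config.update("jax_default_matmul_precision", "highest")
c = a @ b

# Scoped override: trade accuracy for speed inside this block only.
with jax.default_matmul_precision("bfloat16"):
    c_fast = a @ b
```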
Default Prng Impl#
- Type:
Enum values:
'threefry2x32', 'rbg', 'unsafe_rbg'
- Default Value:
'threefry2x32'
- Configuration String:
'jax_default_prng_impl'
- Environment Variable:
JAX_DEFAULT_PRNG_IMPL
Select the default PRNG implementation, used when one is not explicitly provided at seeding time.
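A minimal sketch showing that the default applies only when no implementation is given at seeding time. The key-data shapes in the comments are those of the rbg and threefry2x32 implementations, used here as a visible way to tell the two apart.

```python
import jax

jax.config.update("jax_default_prng_impl", "rbg")

# No impl given, so the configured default (rbg) is used.
key = jax.random.key(0)
print(jax.random.key_data(key).shape)  # (4,)

# An explicit impl at seeding time takes precedence over the default.
key_tf = jax.random.key(0, impl="threefry2x32")
print(jax.random.key_data(key_tf).shape)  # (2,)
```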
Disable Jit#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_disable_jit'
- Environment Variable:
JAX_DISABLE_JIT
Disable JIT compilation and just call the original Python function.
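A minimal sketch of why disabling JIT helps debugging: value-dependent Python control flow fails while tracing under jit, but runs when JIT is disabled, because the function then sees concrete arrays. The jax.disable_jit context manager is used here; jax.config.update("jax_disable_jit", True) has the same effect globally.

```python
import jax
import jax.numpy as jnp

@jax.jit
def double_if_positive(x):
    # Python `if` on an array value needs a concrete value, not a tracer.
    if x > 0:
        return 2 * x
    return jnp.zeros_like(x)

raised = False
try:
    double_if_positive(jnp.array(3.0))
except jax.errors.ConcretizationTypeError:
    raised = True  # tracing cannot evaluate `x > 0`

with jax.disable_jit():
    result = double_if_positive(jnp.array(3.0))
print(result)  # 6.0
```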
Disable Most Optimizations#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_disable_most_optimizations'
- Environment Variable:
JAX_DISABLE_MOST_OPTIMIZATIONS
Disable Vmap Shmap Error#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_disable_vmap_shmap_error'
- Environment Variable:
JAX_DISABLE_VMAP_SHMAP_ERROR
Temporary workaround to disable an error check in vmap-of-shmap.
Disallow Mesh Context Manager#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_disallow_mesh_context_manager'
- Environment Variable:
JAX_DISALLOW_MESH_CONTEXT_MANAGER
If set to True, trying to use a mesh as a context manager will result in a RuntimeError.
Distributed Debug#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_distributed_debug'
- Environment Variable:
JAX_DISTRIBUTED_DEBUG
Enable logging useful for debugging multi-process distributed computations. Logging is performed with the Python logging module at WARNING level.
Dump Ir To#
- Type:
str
- Default Value:
''
- Configuration String:
'jax_dump_ir_to'
- Environment Variable:
JAX_DUMP_IR_TO
Path to which the IR that is emitted by JAX should be dumped as text files. If omitted, JAX will not dump IR. Supports the special value ‘sponge’ to pick the path from the environment variable TEST_UNDECLARED_OUTPUTS_DIR.
Dynamic Shapes#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_dynamic_shapes'
- Environment Variable:
JAX_DYNAMIC_SHAPES
Enables experimental features for staging out computations with dynamic shapes.
Enable Checks#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_enable_checks'
- Environment Variable:
JAX_ENABLE_CHECKS
Turn on invariant checking for JAX internals. Makes things slower.
Enable Compilation Cache#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax_enable_compilation_cache'
- Environment Variable:
JAX_ENABLE_COMPILATION_CACHE
If set to False, the compilation cache will be disabled regardless of whether set_cache_dir() was called. If set to True, the path could be set to a default value or via a call to set_cache_dir().
Enable Custom Prng#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_enable_custom_prng'
- Environment Variable:
JAX_ENABLE_CUSTOM_PRNG
Enables an internal upgrade that allows one to define custom pseudo-random number generator implementations. This will be enabled by default in future versions of JAX, at which point all uses of the flag will be considered deprecated (following the API compatibility policy).
Enable Custom Vjp By Custom Transpose#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_enable_custom_vjp_by_custom_transpose'
- Environment Variable:
JAX_ENABLE_CUSTOM_VJP_BY_CUSTOM_TRANSPOSE
Enables an internal upgrade that implements jax.custom_vjp by reduction to jax.custom_jvp and jax.custom_transpose. This will be enabled by default in future versions of JAX, at which point all uses of the flag will be considered deprecated (following the API compatibility policy).
Enable Empty Arrays#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_enable_empty_arrays'
- Environment Variable:
JAX_ENABLE_EMPTY_ARRAYS
Enable the creation of an Array from an empty list of single-device arrays. This is to support MPMD/pipeline parallelism in McJAX (WIP).
Enable Pgle#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_enable_pgle'
- Environment Variable:
JAX_ENABLE_PGLE
If set to True and jax_pgle_profiling_runs is set to a value greater than 0, modules will be recompiled after running the specified number of times, with the collected profile data provided to the profile-guided latency estimator.
Enable X64#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_enable_x64'
- Environment Variable:
JAX_ENABLE_X64
Enable the use of 64-bit types (e.g. float64, int64).
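A minimal sketch of the effect on default dtypes, assuming the option has not already been enabled through the environment. With 64-bit types disabled, JAX creates 32-bit arrays by default; after enabling them, Python floats and large Python ints map to float64 and int64.

```python
import jax
import jax.numpy as jnp

dtype_before = jnp.array(1.0).dtype
print(dtype_before)  # float32

jax.config.update("jax_enable_x64", True)
print(jnp.array(1.0).dtype)  # float64
print(jnp.array(2**40).dtype)  # int64
```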
Error Checking Behavior Divide#
- Type:
Enum values:
'ignore', 'raise'
- Default Value:
'ignore'
- Configuration String:
'jax_error_checking_behavior_divide'
- Environment Variable:
JAX_ERROR_CHECKING_BEHAVIOR_DIVIDE
Specify the behavior when a divide by zero is encountered. Options are “ignore” or “raise”.
Error Checking Behavior Nan#
- Type:
Enum values:
'ignore', 'raise'
- Default Value:
'ignore'
- Configuration String:
'jax_error_checking_behavior_nan'
- Environment Variable:
JAX_ERROR_CHECKING_BEHAVIOR_NAN
Specify the behavior when a NaN is encountered. Options are “ignore” or “raise”.
Error Checking Behavior Oob#
- Type:
Enum values:
'ignore', 'raise'
- Default Value:
'ignore'
- Configuration String:
'jax_error_checking_behavior_oob'
- Environment Variable:
JAX_ERROR_CHECKING_BEHAVIOR_OOB
Specify the behavior when an out of bounds access is encountered. Options are “ignore” or “raise”.
Exec Time Optimization Effort#
- Type:
float
- Default Value:
0.0
- Configuration String:
'jax_exec_time_optimization_effort'
- Environment Variable:
JAX_EXEC_TIME_OPTIMIZATION_EFFORT
Effort for minimizing execution time (higher means more effort), valid range [-1.0, 1.0].
Experimental Unsafe Xla Runtime Errors#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_experimental_unsafe_xla_runtime_errors'
- Environment Variable:
JAX_EXPERIMENTAL_UNSAFE_XLA_RUNTIME_ERRORS
Enable XLA runtime errors for jax.experimental.checkify.checks on CPU and GPU. These errors are async, might get lost and are not very readable. But, they crash the computation and enable you to write jittable checks without needing to checkify. Does not work under pmap/pjit.
Explain Cache Misses#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_explain_cache_misses'
- Environment Variable:
JAX_EXPLAIN_CACHE_MISSES
Each time there is a miss on one of the main caches (e.g. the tracing cache), log an explanation. Logging is performed with the Python logging module. When this option is set, the log level is WARNING; otherwise the level is DEBUG.
Export Calling Convention Version#
- Type:
int
- Default Value:
9
- Configuration String:
'jax_export_calling_convention_version'
- Environment Variable:
JAX_EXPORT_CALLING_CONVENTION_VERSION
The calling convention version number to use for exporting. This must be within the range of versions supported by the tf.XlaCallModule used in your deployment environment. See https://docs.jax.dev/en/latest/export/shape_poly.html#calling-convention-versions.
Export Ignore Forward Compatibility#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_export_ignore_forward_compatibility'
- Environment Variable:
JAX_EXPORT_IGNORE_FORWARD_COMPATIBILITY
Whether to ignore the forward compatibility lowering rules. See https://docs.jax.dev/en/latest/export/export.html#compatibility-guarantees-for-custom-calls.
High Dynamic Range Gumbel#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_high_dynamic_range_gumbel'
- Environment Variable:
JAX_HIGH_DYNAMIC_RANGE_GUMBEL
If True, Gumbel noise draws two samples to cover low-probability events with more precision.
Hlo Source File Canonicalization Regex#
- Type:
str
- Default Value:
None
- Configuration String:
'jax_hlo_source_file_canonicalization_regex'
- Environment Variable:
JAX_HLO_SOURCE_FILE_CANONICALIZATION_REGEX
Used to canonicalize the source_path metadata of HLO instructions by removing the given regex. If set, re.sub() is called on each source_file with the given regex, and all matches are removed. This can be used to avoid spurious cache misses when using the persistent compilation cache, which includes HLO metadata in the cache key.
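A minimal sketch of the intended use: two checkouts of the same code in different directories produce identical source paths once the machine-specific prefix is removed, so they can share persistent cache entries. The directory names and the regex are hypothetical, and the re.sub call only mirrors the documented behavior to show what the option would do to each path.

```python
import re

import jax

# Hypothetical layout: strip everything up to and including the checkout directory.
regex = r"^.*/my_checkout/"
jax.config.update("jax_hlo_source_file_canonicalization_regex", regex)

# What the option does to each source_file, per the description above.
paths = [
    "/home/alice/my_checkout/model/train.py",
    "/tmp/ci/my_checkout/model/train.py",
]
canonical = [re.sub(regex, "", p) for p in paths]
print(canonical)  # ['model/train.py', 'model/train.py']
```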
Include Debug Info In Dumps#
- Type:
str
- Default Value:
'True'
- Configuration String:
'jax_include_debug_info_in_dumps'
- Environment Variable:
JAX_INCLUDE_DEBUG_INFO_IN_DUMPS
Determines whether to keep debug symbols and location information when dumping IR code. By default, debug information is preserved in the IR dump. To avoid exposing source code and potentially sensitive information, set this to false.
Include Full Tracebacks In Locations#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax_include_full_tracebacks_in_locations'
- Environment Variable:
JAX_INCLUDE_FULL_TRACEBACKS_IN_LOCATIONS
Include Python tracebacks in MLIR locations in IR emitted by JAX.
Legacy Prng Key#
- Type:
Enum values:
'allow', 'warn', 'error'
- Default Value:
'allow'
- Configuration String:
'jax_legacy_prng_key'
- Environment Variable:
JAX_LEGACY_PRNG_KEY
Specify the behavior when raw PRNG keys are passed to jax.random APIs.
Log Checkpoint Residuals#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_log_checkpoint_residuals'
- Environment Variable:
JAX_LOG_CHECKPOINT_RESIDUALS
Log a message every time jax.checkpoint (aka jax.remat) is partially evaluated (e.g. for autodiff), printing what residuals are saved.
Log Compiles#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_log_compiles'
- Environment Variable:
JAX_LOG_COMPILES
Log a message each time jit or pmap compiles an XLA computation. Logging is performed with the Python logging module. When this option is set, the log level is WARNING; otherwise the level is DEBUG.
Logging Level#
- Type:
Enum values:
'NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'
- Default Value:
'NOTSET'
- Configuration String:
'jax_logging_level'
- Environment Variable:
JAX_LOGGING_LEVEL
Set the corresponding logging level on all jax loggers. Only string values from [“NOTSET”, “DEBUG”, “INFO”, “WARNING”, “ERROR”, “CRITICAL”] are accepted. If None, the logging level will not be set. Includes C++ logging.
Memory Fitting Effort#
- Type:
float
- Default Value:
0.0
- Configuration String:
'jax_memory_fitting_effort'
- Environment Variable:
JAX_MEMORY_FITTING_EFFORT
Effort for minimizing memory usage (higher means more effort), valid range [-1.0, 1.0].
Memory Fitting Level#
- Type:
Enum values:
'UNKNOWN', 'O0', 'O1', 'O2', 'O3'
- Default Value:
'O2'
- Configuration String:
'jax_memory_fitting_level'
- Environment Variable:
JAX_MEMORY_FITTING_LEVEL
The degree to which the compiler should attempt to make the program fit in memory.
Mock Gpu Topology#
- Type:
str
- Default Value:
''
- Configuration String:
'jax_mock_gpu_topology'
- Environment Variable:
JAX_MOCK_GPU_TOPOLOGY
Mock multi-host GPU topology in GPU client. The value should be of the form “<number-of-slices> x <number-of-hosts-per-slice> x <number-of-devices-per-host>”. Empty string turns off mocking.
Mosaic Allow Hlo#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_mosaic_allow_hlo'
- Environment Variable:
JAX_MOSAIC_ALLOW_HLO
Allow HLO dialects in Mosaic.
Mutable Array Checks#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_mutable_array_checks'
- Environment Variable:
JAX_MUTABLE_ARRAY_CHECKS
Enable error checks for mutable arrays that rule out aliasing. This will be enabled by default in future versions of JAX, at which point all uses of the flag will be considered deprecated (following the API compatibility policy).
No Tracing#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_no_tracing'
- Environment Variable:
JAX_NO_TRACING
Disallow tracing for JIT compilation.
Num Cpu Devices#
- Type:
int
- Default Value:
-1
- Configuration String:
'jax_num_cpu_devices'
- Environment Variable:
JAX_NUM_CPU_DEVICES
Number of CPU devices to use. If not provided, the value of the XLA flag --xla_force_host_platform_device_count is used. Must be set before JAX is initialized.
Numpy Dtype Promotion#
- Type:
Enum values:
'standard', 'strict'
- Default Value:
'standard'
- Configuration String:
'jax_numpy_dtype_promotion'
- Environment Variable:
JAX_NUMPY_DTYPE_PROMOTION
Specify the rules used for implicit type promotion in operations between arrays. Options are “standard” or “strict”; in strict-mode, binary operations between arrays of differing strongly-specified dtypes will result in an error.
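A minimal sketch contrasting the two modes with an int32/float32 pair. Under "standard" the result is promoted to float32; under "strict" the mixed operation is rejected. The code catches ValueError on the assumption that JAX's type-promotion error subclasses it.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.int32)
y = jnp.ones(3, dtype=jnp.float32)

standard_dtype = (x + y).dtype
print(standard_dtype)  # float32

jax.config.update("jax_numpy_dtype_promotion", "strict")
strict_raised = False
try:
    x + y
except ValueError:
    strict_raised = True  # no implicit promotion path in strict mode

# Explicit casts still work in strict mode.
print((x.astype(jnp.float32) + y).dtype)  # float32
```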
Numpy Rank Promotion#
- Type:
Enum values:
'allow', 'warn', 'raise'
- Default Value:
'allow'
- Configuration String:
'jax_numpy_rank_promotion'
- Environment Variable:
JAX_NUMPY_RANK_PROMOTION
Control NumPy-style automatic rank promotion broadcasting (“allow”, “warn”, or “raise”).
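A minimal sketch of the "raise" mode: adding a rank-1 array to a rank-2 array would normally promote the rank-1 operand implicitly, but with this setting it raises a ValueError. Making the rank explicit with a new axis keeps ordinary broadcasting available.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_numpy_rank_promotion", "raise")

a = jnp.ones((2, 3))
b = jnp.ones(3)  # rank 1: would be implicitly promoted to shape (1, 3)

rank_raised = False
try:
    a + b
except ValueError:
    rank_raised = True

# Same ranks, so only ordinary broadcasting of size-1 axes is involved.
explicit = a + b[None, :]
print(explicit.shape)  # (2, 3)
```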
Optimization Level#
- Type:
Enum values:
'UNKNOWN', 'O0', 'O1', 'O2', 'O3'
- Default Value:
'UNKNOWN'
- Configuration String:
'jax_optimization_level'
- Environment Variable:
JAX_OPTIMIZATION_LEVEL
The degree to which the compiler should optimize for execution time.
Pallas Dump Promela To#
- Type:
str
- Default Value:
''
- Configuration String:
'jax_pallas_dump_promela_to'
- Environment Variable:
JAX_PALLAS_DUMP_PROMELA_TO
If set, dumps a Promela model of the kernel to the specified directory. The model can verify that the kernel is free of data races, deadlocks, etc.
Pallas Enable Runtime Assert#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_pallas_enable_runtime_assert'
- Environment Variable:
JAX_PALLAS_ENABLE_RUNTIME_ASSERT
If set, enables runtime assertions in the kernel via checkify.check. Otherwise, runtime asserts will be ignored unless functionalized using checkify.checkify.
Pallas Use Mosaic Gpu#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_pallas_use_mosaic_gpu'
- Environment Variable:
JAX_PALLAS_USE_MOSAIC_GPU
If True, lower Pallas kernels to the experimental Mosaic GPU dialect, instead of Triton IR.
Pallas Verbose Errors#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax_pallas_verbose_errors'
- Environment Variable:
JAX_PALLAS_VERBOSE_ERRORS
If True, print verbose error messages for Pallas kernels.
Persistent Cache Enable Xla Caches#
- Type:
str
- Default Value:
'xla_gpu_per_fusion_autotune_cache_dir'
- Configuration String:
'jax_persistent_cache_enable_xla_caches'
- Environment Variable:
JAX_PERSISTENT_CACHE_ENABLE_XLA_CACHES
When the persistent cache is enabled, additional XLA caching will also be enabled automatically. This option can be used to configure which XLA caching methods will be enabled.
Persistent Cache Min Compile Time Secs#
- Type:
float
- Default Value:
1.0
- Configuration String:
'jax_persistent_cache_min_compile_time_secs'
- Environment Variable:
JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS
The minimum compile time of a computation to be written to the persistent compilation cache. This threshold can be raised to decrease the number of entries written to the cache.
Persistent Cache Min Entry Size Bytes#
- Type:
int
- Default Value:
0
- Configuration String:
'jax_persistent_cache_min_entry_size_bytes'
- Environment Variable:
JAX_PERSISTENT_CACHE_MIN_ENTRY_SIZE_BYTES
The minimum size (in bytes) of an entry that will be cached in the persistent compilation cache:
- -1: disable the size restriction and prevent overrides.
- Leave at default (0) to allow for overrides. The override will typically ensure that the minimum size is optimal for the filesystem being used for the cache.
- > 0: the actual minimum size desired; no overrides.
Pgle Aggregation Percentile#
- Type:
int
- Default Value:
90
- Configuration String:
'jax_pgle_aggregation_percentile'
- Environment Variable:
JAX_PGLE_AGGREGATION_PERCENTILE
Percentile used to aggregate performance data between devices when PGLE is used.
Pgle Profiling Runs#
- Type:
int
- Default Value:
3
- Configuration String:
'jax_pgle_profiling_runs'
- Environment Variable:
JAX_PGLE_PROFILING_RUNS
Number of times a module should be profiled before recompilation when PGLE is used.
Pjrt Client Create Options#
- Type:
str
- Default Value:
None
- Configuration String:
'jax_pjrt_client_create_options'
- Environment Variable:
JAX_PJRT_CLIENT_CREATE_OPTIONS
A set of key-value pairs in the format of “k1:v1;k2:v2” strings provided to a device platform pjrt client as extra arguments.
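A minimal sketch of producing the expected "k1:v1;k2:v2" string from a Python dict. The helper name format_pjrt_options is hypothetical, and so are the keys and values: which keys a PJRT client accepts depends on the device platform. The value is set through the environment before JAX is imported so it is in place when the client is created.

```python
import os

def format_pjrt_options(options: dict[str, str]) -> str:
    """Hypothetical helper: join key-value pairs as "k1:v1;k2:v2"."""
    return ";".join(f"{key}:{value}" for key, value in options.items())

# Placeholder keys; real keys depend on the platform's PJRT client.
opts = format_pjrt_options({"k1": "v1", "k2": "v2"})
print(opts)  # k1:v1;k2:v2

# Set before importing JAX so the option is picked up at client creation.
os.environ["JAX_PJRT_CLIENT_CREATE_OPTIONS"] = opts
```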
Platform Name#
- Type:
str
- Default Value:
''
- Configuration String:
'jax_platform_name'
- Environment Variable:
JAX_PLATFORM_NAME
Platforms#
- Type:
str
- Default Value:
None
- Configuration String:
'jax_platforms'
- Environment Variable:
JAX_PLATFORMS
Comma-separated list of platform names specifying which platforms JAX should initialize. If any of the platforms in this list are not successfully initialized, an exception will be raised and the program will be aborted. The first platform in the list will be the default platform. For example, config.jax_platforms=cpu,tpu means that CPU and TPU backends will be initialized, and the CPU backend will be used unless otherwise specified. If TPU initialization fails, it will raise an exception. By default, JAX will try to initialize all available platforms and will default to GPU or TPU if available, and fall back to CPU otherwise.
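A minimal sketch of restricting JAX to the CPU backend, a common choice for tests on machines that also have accelerators. The variable is set from Python before JAX is imported; exporting JAX_PLATFORMS=cpu in the shell works the same way.

```python
import os

# Only initialize the CPU backend; must be set before backends are initialized.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

print(jax.default_backend())  # cpu
print(jax.devices()[0].platform)  # cpu
```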
Pmap No Rank Reduction#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax_pmap_no_rank_reduction'
- Environment Variable:
JAX_PMAP_NO_RANK_REDUCTION
If True, pmap shards have the same rank as their enclosing array.
Pmap Shmap Merge#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_pmap_shmap_merge'
- Environment Variable:
JAX_PMAP_SHMAP_MERGE
If True, pmap and shard_map API will be merged. This will be enabled by default in future versions of JAX, at which point all uses of the flag will be considered deprecated (following the API compatibility policy).
Pprint Use Color#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax_pprint_use_color'
- Environment Variable:
JAX_PPRINT_USE_COLOR
Enable jaxpr pretty-printing with colorful syntax highlighting.
Raise Persistent Cache Errors#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_raise_persistent_cache_errors'
- Environment Variable:
JAX_RAISE_PERSISTENT_CACHE_ERRORS
If true, exceptions raised when reading or writing to the persistent compilation cache will be allowed through, halting program execution if not manually caught. If false, exceptions are caught and raised as warnings, allowing program execution to continue. Defaults to false so cache bugs or intermittent issues are non-fatal.
Random Seed Offset#
- Type:
int
- Default Value:
0
- Configuration String:
'jax_random_seed_offset'
- Environment Variable:
JAX_RANDOM_SEED_OFFSET
Offset to all random seeds (e.g. argument to jax.random.key()).
Remove Custom Partitioning Ptr From Cache Key#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_remove_custom_partitioning_ptr_from_cache_key'
- Environment Variable:
JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY
If set to True, remove the custom partitioning pointer present in the precompiled StableHLO before hashing during cache key computation. This is a potentially unsafe flag to set and only users who are sure of what they are trying to achieve should set it.
Rocm Visible Devices#
- Type:
str
- Default Value:
'all'
- Configuration String:
'jax_rocm_visible_devices'
- Environment Variable:
JAX_ROCM_VISIBLE_DEVICES
Softmax Custom Jvp#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_softmax_custom_jvp'
- Environment Variable:
JAX_SOFTMAX_CUSTOM_JVP
Use a new custom_jvp rule for jax.nn.softmax. The new rule should improve memory usage and stability. Set True to use the new behavior. See jax-ml/jax#15677. This will be enabled by default in future versions of JAX, at which point all uses of the flag will be considered deprecated (following the API compatibility policy).
Threefry Gpu Kernel Lowering#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_threefry_gpu_kernel_lowering'
- Environment Variable:
JAX_THREEFRY_GPU_KERNEL_LOWERING
On GPU, lower threefry PRNG operations to a kernel implementation. This makes compile times faster at a potential runtime memory cost.
Threefry Partitionable#
- Type:
bool
- Default Value:
True
- Configuration String:
'jax_threefry_partitionable'
- Environment Variable:
JAX_THREEFRY_PARTITIONABLE
Enables internal threefry PRNG implementation changes that make it automatically partitionable in some cases. Without this flag, standard jax.random pseudo-random number generation may incur extraneous communication and/or redundant distributed computation. With this flag, those communication overheads disappear in some cases. As the default value above shows, this flag is already enabled by default; following the API compatibility policy, uses of the flag will eventually be considered deprecated.
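The flag changes only how random bits are generated internally, so user code keeps calling jax.random as usual. The sketch below sets it explicitly (redundant with the default listed above) and draws a batch of uniform samples. Note that toggling the flag may change the values produced for a given key, so code that depends on exact random values should pin it.

```python
import jax
import jax.numpy as jnp

# Already the default per the listing above; set explicitly for clarity.
jax.config.update("jax_threefry_partitionable", True)

key = jax.random.key(42)
samples = jax.random.uniform(key, shape=(4, 8))

# Uniform samples lie in the half-open interval [0, 1).
in_range = bool(jnp.all((samples >= 0) & (samples < 1)))
print(samples.shape, in_range)
```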
Traceback Filtering#
- Type:
Enum values:
'off', 'tracebackhide', 'remove_frames', 'quiet_remove_frames', 'auto'
- Default Value:
'auto'
- Configuration String:
'jax_traceback_filtering'
- Environment Variable:
JAX_TRACEBACK_FILTERING
Controls how JAX filters internal frames out of tracebacks. Valid values are:
- off: disables traceback filtering.
- auto: uses tracebackhide if running under a sufficiently new IPython, or remove_frames otherwise.
- tracebackhide: adds __tracebackhide__ annotations to hidden stack frames, which some traceback printers support.
- remove_frames: removes hidden frames from tracebacks, and adds the unfiltered traceback as the __cause__ of the exception.
- quiet_remove_frames: removes hidden frames from tracebacks, and adds a brief message (to the __cause__ of the exception) describing that this has happened.
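When debugging JAX itself, it can help to see every frame. A minimal sketch: switch filtering off, trigger a shape error inside a jitted function (a deliberately invalid reshape chosen for illustration), then restore the default 'auto' mode.

```python
import jax
import jax.numpy as jnp

# Show all frames, including JAX-internal ones.
jax.config.update("jax_traceback_filtering", "off")


@jax.jit
def bad_reshape(x):
    return x.reshape(5)  # a size-3 array cannot be reshaped to size 5


caught = None
try:
    bad_reshape(jnp.ones(3))
except TypeError as err:
    caught = err  # the printed traceback would now include internal frames

# Return to the default filtering mode.
jax.config.update("jax_traceback_filtering", "auto")
```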
Traceback In Locations Limit#
- Type:
int
- Default Value:
10
- Configuration String:
'jax_traceback_in_locations_limit'
- Environment Variable:
JAX_TRACEBACK_IN_LOCATIONS_LIMIT
Limits the number of Python traceback frames included in MLIR locations. If set to a negative value, the traceback is not limited.
Tracer Error Num Traceback Frames#
- Type:
int
- Default Value:
5
- Configuration String:
'jax_tracer_error_num_traceback_frames'
- Environment Variable:
JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES
Set the number of stack frames in JAX tracer error messages.
Transfer Guard#
- Type:
Enum values:
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
- Default Value:
None
- Configuration String:
'jax_transfer_guard'
- Environment Variable:
JAX_TRANSFER_GUARD
Select the transfer guard level for all transfers. This option is set-only: to read the level for a specific direction, use the corresponding per-direction option. When unset (the None default above), the effective level is “allow”.
Transfer Guard Device To Device#
- Type:
Enum values:
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
- Default Value:
None
- Configuration String:
'jax_transfer_guard_device_to_device'
- Environment Variable:
JAX_TRANSFER_GUARD_DEVICE_TO_DEVICE
Select the transfer guard level for device-to-device transfers. When unset (None), the effective level is “allow”.
Transfer Guard Device To Host#
- Type:
Enum values:
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
- Default Value:
None
- Configuration String:
'jax_transfer_guard_device_to_host'
- Environment Variable:
JAX_TRANSFER_GUARD_DEVICE_TO_HOST
Select the transfer guard level for device-to-host transfers. When unset (None), the effective level is “allow”.
Transfer Guard Host To Device#
- Type:
Enum values:
'allow', 'log', 'disallow', 'log_explicit', 'disallow_explicit'
- Default Value:
None
- Configuration String:
'jax_transfer_guard_host_to_device'
- Environment Variable:
JAX_TRANSFER_GUARD_HOST_TO_DEVICE
Select the transfer guard level for host-to-device transfers. When unset (None), the effective level is “allow”.
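The per-direction options let you guard one direction only. A minimal sketch: with host-to-device set to 'disallow', an explicit jax.device_put still works, but passing a NumPy array straight to a jitted function (an implicit transfer) is rejected. The function add_one is a hypothetical example, compiled once beforehand so that only argument transfers happen under the guard.

```python
import jax
import numpy as np

host_data = np.arange(3, dtype=np.float32)
add_one = jax.jit(lambda a: a + 1)
add_one(jax.device_put(host_data))  # compile while all transfers are allowed

# Guard only the host-to-device direction.
jax.config.update("jax_transfer_guard_host_to_device", "disallow")
try:
    # Explicit transfers remain allowed under "disallow".
    on_device = jax.device_put(host_data)
    result = add_one(on_device)
    try:
        # A NumPy argument to a jitted function is transferred implicitly.
        add_one(host_data)
        implicit_allowed = True
    except Exception:
        implicit_allowed = False
finally:
    # Restore the permissive level for the rest of the program.
    jax.config.update("jax_transfer_guard_host_to_device", "allow")
```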
Use Direct Linearize#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_use_direct_linearize'
- Environment Variable:
JAX_USE_DIRECT_LINEARIZE
Use direct linearization instead of JVP followed by partial evaluation.
Use Magma#
- Type:
Enum values:
'off', 'on', 'auto'
- Default Value:
'auto'
- Configuration String:
'jax_use_magma'
- Environment Variable:
JAX_USE_MAGMA
Enable experimental support for MAGMA-backed lax.linalg.eig on GPU. See the documentation for lax.linalg.eig for more details about how to use this feature.
Use Shardy Partitioner#
- Type:
bool
- Default Value:
False
- Configuration String:
'jax_use_shardy_partitioner'
- Environment Variable:
JAX_USE_SHARDY_PARTITIONER
Whether to lower to Shardy. Shardy is a new open-source propagation framework for MLIR and is currently experimental in JAX. See www.github.com/openxla/shardy. This will be enabled by default in future versions of JAX, at which point all uses of the flag will be considered deprecated (following the API compatibility policy).
Xla Backend#
- Type:
str
- Default Value:
''
- Configuration String:
'jax_xla_backend'
- Environment Variable:
JAX_XLA_BACKEND
Xla Profile Version#
- Type:
int
- Default Value:
0
- Configuration String:
'jax_xla_profile_version'
- Environment Variable:
JAX_XLA_PROFILE_VERSION
Optional profile version for XLA compilation. This is meaningful only when XLA is configured to support the remote compilation profile feature.
Mock Num Gpu Processes#
- Type:
int
- Default Value:
0
- Configuration String:
'mock_num_gpu_processes'
- Environment Variable:
MOCK_NUM_GPU_PROCESSES
Mock number of JAX processes in the GPU client. A value of zero turns off mocking.
Mosaic Use Python Pipeline#
- Type:
bool
- Default Value:
False
- Configuration String:
'mosaic_use_python_pipeline'
- Environment Variable:
MOSAIC_USE_PYTHON_PIPELINE
Run the initial Mosaic MLIR passes from Python when as_tpu_kernel is called (for Pallas, this happens at JAX lowering time) instead of later within XLA.