jax.default_matmul_precision
- jax.default_matmul_precision = <jax._src.config.State object>
Context manager for the jax_default_matmul_precision config option.
Controls the default precision of matrix multiplications and convolutions on 32-bit inputs.
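Below is a minimal sketch of the two usual ways to set this option. jax.config.update changes the process-wide default. Calling jax.default_matmul_precision(...) in a with statement overrides that default only inside the block. It assumes a JAX version that accepts the string values named in this section, and the arrays and variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Process-wide default for matmuls/convs on 32-bit inputs traced from now on
jax.config.update("jax_default_matmul_precision", "float32")

a = jnp.ones((2, 3), dtype=jnp.float32)
b = jnp.ones((3, 4), dtype=jnp.float32)

# Scoped override: only operations inside the block see "bfloat16";
# the previous value ("float32") is restored when the block exits
with jax.default_matmul_precision("bfloat16"):
    z_bf16 = a @ b
    inside_value = jax.config.jax_default_matmul_precision

z_f32 = a @ b  # uses the global "float32" default again
outside_value = jax.config.jax_default_matmul_precision
```

The context-manager form is convenient for tests or for a single hot loop. The global form suits a script that should use one precision throughout.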
Some platforms, like TPU, offer configurable precision levels for matrix multiplication and convolution computations, trading off accuracy for speed. The precision can be set for each operation through its precision argument; for example, see the jax.lax.conv_general_dilated() and jax.lax.dot() docstrings. But it can be useful to control the default behavior obtained when an operation is not given a specific precision.
This option can be used to control the default precision level for computations involved in matrix multiplication and convolution on 32-bit inputs. The levels roughly describe the precision at which scalar products are computed. The 'bfloat16' option is the fastest and least precise; 'float32' is similar to full float32 precision; 'tensorfloat32' is intermediate.
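The following sketch compares two matmuls. One relies on the configured default. The other passes an explicit per-operation precision (jax.lax.Precision.HIGHEST), which takes priority over the default. It assumes float32 inputs with illustrative shapes and values. Under jax.jit the default is read when the function is traced, so the with block should enclose the call that triggers tracing. How much the levels differ is platform dependent, and on backends without configurable matmul precision the results may be identical.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(12, dtype=jnp.float32).reshape(3, 4) / 7.0
y = jnp.arange(20, dtype=jnp.float32).reshape(4, 5) / 3.0

@jax.jit
def matmul_default(lhs, rhs):
    # No precision argument: the default in effect at trace time is used
    return jnp.matmul(lhs, rhs)

def matmul_highest(lhs, rhs):
    # Explicit per-operation precision overrides the configured default
    return jax.lax.dot(lhs, rhs, precision=jax.lax.Precision.HIGHEST)

# Trace and run the jitted function with the 'float32' default in effect
with jax.default_matmul_precision("float32"):
    z_default = matmul_default(x, y)

z_highest = matmul_highest(x, y)

# Float64 NumPy product used as an accuracy reference
reference = np.asarray(x, dtype=np.float64) @ np.asarray(y, dtype=np.float64)
print("max error (default):", np.max(np.abs(np.asarray(z_default) - reference)))
print("max error (HIGHEST):", np.max(np.abs(np.asarray(z_highest) - reference)))
```

Swapping "float32" for "bfloat16" or "tensorfloat32" in the with statement is how you would measure the accuracy cost of the faster levels on a platform that supports them.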
This parameter can also be used to specify an accumulation 'algorithm' for functions that perform matrix multiplications, like jax.lax.dot(). To specify an algorithm, set this option to the name of a DotAlgorithmPreset.
- Parameters:
new_val (Any)