jax.lax module
jax.lax is a library of primitive operations that underpins libraries such as jax.numpy. Transformation rules, such as JVP and batching rules, are typically defined as transformations on jax.lax primitives.
Many of the primitives are thin wrappers around equivalent XLA operations, described by the XLA operation semantics documentation. In a few cases JAX diverges from XLA, usually to ensure that the set of operations is closed under the operation of JVP and transpose rules.
Where possible, prefer to use libraries such as jax.numpy instead of using jax.lax directly. The jax.numpy API follows NumPy, and is therefore more stable and less likely to change than the jax.lax API.
Operators
- Elementwise absolute value: \(|x|\).
- Elementwise arc cosine: \(\mathrm{acos}(x)\).
- Elementwise inverse hyperbolic cosine: \(\mathrm{acosh}(x)\).
- Elementwise addition: \(x + y\).
- Merges one or more XLA token values.
- Returns max
- Returns min
- Computes the index of the maximum element along
- Computes the index of the minimum element along
- Elementwise arc sine: \(\mathrm{asin}(x)\).
- Elementwise inverse hyperbolic sine: \(\mathrm{asinh}(x)\).
- Elementwise arc tangent: \(\mathrm{atan}(x)\).
- Elementwise two-term arc tangent: \(\mathrm{atan}({x \over y})\).
- Elementwise inverse hyperbolic tangent: \(\mathrm{atanh}(x)\).
- Batch matrix multiplication.
- Exponentially scaled modified Bessel function of order 0: \(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)\).
- Exponentially scaled modified Bessel function of order 1: \(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)\).
- Elementwise regularized incomplete beta integral.
- Elementwise bitcast.
- Elementwise AND: \(x \wedge y\).
- Elementwise NOT: \(\neg x\).
- Elementwise OR: \(x \vee y\).
- Elementwise exclusive OR: \(x \oplus y\).
- Elementwise popcount, count the number of set bits in each element.
- Broadcasts an array, adding new leading dimensions
- Wraps XLA's BroadcastInDim operator.
- Returns the shape that results from NumPy broadcasting of shapes.
- Adds leading dimensions of
- Convenience wrapper around
- Elementwise cube root: \(\sqrt[3]{x}\).
- Elementwise ceiling: \(\left\lceil x \right\rceil\).
- Elementwise clamp.
- Elementwise count-leading-zeros.
- Collapses dimensions of an array into a single dimension.
- Elementwise make complex number: \(x + jy\).
- Composite with semantics defined by the decomposition function.
- Concatenates a sequence of arrays along dimension.
- Elementwise complex conjugate function: \(\overline{x}\).
- Convenience wrapper around conv_general_dilated.
- Elementwise cast.
- Converts convolution dimension_numbers to a ConvDimensionNumbers.
- General n-dimensional convolution operator, with optional dilation.
- General n-dimensional unshared convolution operator with optional dilation.
- Extract patches subject to the receptive field of conv_general_dilated.
- Convenience wrapper for calculating the N-d convolution "transpose".
- Convenience wrapper around conv_general_dilated.
- Elementwise cosine: \(\mathrm{cos}(x)\).
- Elementwise hyperbolic cosine: \(\mathrm{cosh}(x)\).
- Computes a cumulative logsumexp along axis.
- Computes a cumulative maximum along axis.
- Computes a cumulative minimum along axis.
- Computes a cumulative product along axis.
- Computes a cumulative sum along axis.
- Elementwise digamma: \(\psi(x)\).
- Elementwise division: \(x \over y\).
- Vector/vector, matrix/vector, and matrix/matrix multiplication.
- General dot product/contraction operator.
- Convenience wrapper around dynamic_slice to perform int indexing.
- Wraps XLA's DynamicSlice operator.
- Convenience wrapper around
- Convenience wrapper around
- Wraps XLA's DynamicUpdateSlice operator.
- Convenience wrapper around
- Elementwise equals: \(x = y\).
- Elementwise error function: \(\mathrm{erf}(x)\).
- Elementwise complementary error function: \(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\).
- Elementwise inverse error function: \(\mathrm{erf}^{-1}(x)\).
- Elementwise exponential: \(e^x\).
- Elementwise base-2 exponential: \(2^x\).
- Insert any number of size 1 dimensions into an array.
- Elementwise \(e^{x} - 1\).
- Elementwise floor: \(\left\lfloor x \right\rfloor\).
- Returns an array of shape filled with fill_value.
- Create a full array like np.full based on the example array x.
- Gather operator.
- Elementwise greater-than-or-equals: \(x \geq y\).
- Elementwise greater-than: \(x > y\).
- Elementwise regularized incomplete gamma function.
- Elementwise complementary regularized incomplete gamma function.
- Elementwise extract imaginary part: \(\mathrm{Im}(x)\).
- Convenience wrapper around
- Elementwise power: \(x^y\), where \(y\) is a static integer.
- Wraps XLA's Iota operator.
- Elementwise \(\mathrm{isfinite}\).
- Elementwise less-than-or-equals: \(x \leq y\).
- Elementwise log gamma: \(\mathrm{log}(\Gamma(x))\).
- Elementwise natural logarithm: \(\mathrm{log}(x)\).
- Elementwise \(\mathrm{log}(1 + x)\).
- Elementwise logistic (sigmoid) function: \(\frac{1}{1 + e^{-x}}\).
- Elementwise less-than: \(x < y\).
- Elementwise maximum: \(\mathrm{max}(x, y)\).
- Elementwise minimum: \(\mathrm{min}(x, y)\).
- Elementwise multiplication: \(x \times y\).
- Elementwise not-equals: \(x \neq y\).
- Elementwise negation: \(-x\).
- Returns the next representable value after
- Prevents the compiler from moving operations across the barrier.
- Applies low, high, and/or interior padding to an array.
- Stages out platform-specific code.
- Elementwise polygamma: \(\psi^{(m)}(x)\).
- Elementwise power: \(x^y\).
- Elementwise derivative of samples from Gamma(a, 1).
- Elementwise extract real part: \(\mathrm{Re}(x)\).
- Elementwise reciprocal: \(1 \over x\).
- Wraps XLA's Reduce operator.
- Compute the bitwise AND of elements over one or more array axes.
- Compute the maximum of elements over one or more array axes.
- Compute the minimum of elements over one or more array axes.
- Compute the bitwise OR of elements over one or more array axes.
- Wraps XLA's ReducePrecision operator.
- Compute the product of elements over one or more array axes.
- Compute the sum of elements over one or more array axes.
- Wraps XLA's ReduceWindowWithGeneralPadding operator.
- Compute the bitwise XOR of elements over one or more array axes.
- Elementwise remainder: \(x \bmod y\).
- Wraps XLA's Reshape operator.
- Wraps XLA's Rev operator.
- Stateless PRNG bit generator.
- Stateful PRNG generator.
- Elementwise round.
- Elementwise reciprocal square root: \(1 \over \sqrt{x}\).
- Scatter-update operator.
- Scatter-add operator.
- Scatter-apply operator.
- Scatter-max operator.
- Scatter-min operator.
- Scatter-multiply operator.
- Elementwise left shift: \(x \ll y\).
- Elementwise arithmetic right shift: \(x \gg y\).
- Elementwise logical right shift: \(x \gg y\).
- Elementwise sign.
- Elementwise sine: \(\mathrm{sin}(x)\).
- Elementwise hyperbolic sine: \(\mathrm{sinh}(x)\).
- Wraps XLA's Slice operator.
- Convenience wrapper around
- Wraps XLA's Sort operator.
- Sorts
- Splits an array along
- Elementwise square root: \(\sqrt{x}\).
- Elementwise square: \(x^2\).
- Squeeze any number of size 1 dimensions from an array.
- Elementwise subtraction: \(x - y\).
- Elementwise tangent: \(\mathrm{tan}(x)\).
- Elementwise hyperbolic tangent: \(\mathrm{tanh}(x)\).
- Returns top
- Wraps XLA's Transpose operator.
- Elementwise Hurwitz zeta function: \(\zeta(x, q)\).
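Two of the operators above are the general dot product/contraction operator and the wrapper around XLA's DynamicSlice, exposed as lax.dot_general and lax.dynamic_slice. This minimal sketch computes a batched matrix product by spelling out the contracting and batch dimensions explicitly, then reads a fixed-size window starting at a given offset:

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3, 4))  # (batch, m, k)
b = jnp.ones((2, 4, 5))  # (batch, k, n)

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)):
# contract a's axis 2 with b's axis 1 and treat axis 0 of both as a batch axis.
dn = (((2,), (1,)), ((0,), (0,)))
c = lax.dot_general(a, b, dimension_numbers=dn)
print(c.shape)  # (2, 3, 5): batch dims first, then a's free dims, then b's

# dynamic_slice takes start indices (which may be traced values) and static sizes.
x = jnp.arange(10)
window = lax.dynamic_slice(x, (3,), (4,))
print(window)  # [3 4 5 6]
```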
Control flow operators
- Performs a scan with an associative binary operation, in parallel.
- Conditionally apply
- Loop from
- Map a function over leading array axes.
- Scan a function over leading array axes while carrying along state.
- Selects between two branches based on a boolean predicate.
- Selects array values from multiple cases.
- Apply exactly one of the
- Call
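This minimal sketch covers three of the control flow operators above (lax.scan, lax.fori_loop and lax.cond) and computes a running sum, a loop-carried sum and a guarded reciprocal. The helper safe_reciprocal is a hypothetical name used only for illustration.

```python
import jax.numpy as jnp
from jax import lax

# scan: thread a running total through the sequence and emit it at each step.
def step(carry, x):
    carry = carry + x
    return carry, carry

total, running = lax.scan(step, jnp.zeros(()), jnp.arange(1.0, 5.0))

# fori_loop: add the loop index i to an accumulator for i = 0, ..., 4.
acc = lax.fori_loop(0, 5, lambda i, s: s + i, jnp.int32(0))

# cond: both branches must return outputs with the same shape and dtype.
def safe_reciprocal(x):  # hypothetical helper
    return lax.cond(x != 0, lambda v: 1.0 / v, lambda v: jnp.zeros_like(v), x)

print(total, running, acc, safe_reciprocal(jnp.float32(4.0)))
```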
Custom gradient operators
- Stops gradient computation.
- Perform a matrix-free linear solve with implicitly defined gradients.
- Differentiably solve for the roots of a function.
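You can see the effect of stopping gradients by differentiating a product in which one factor is wrapped in lax.stop_gradient. This minimal sketch compares that derivative with the ordinary derivative of x * x:

```python
import jax
from jax import lax

def f(x):
    # stop_gradient makes the second factor a constant for differentiation,
    # so d/dx [x * stop_gradient(x)] evaluates to x rather than 2 * x.
    return x * lax.stop_gradient(x)

g_stopped = jax.grad(f)(3.0)
g_plain = jax.grad(lambda x: x * x)(3.0)
print(g_stopped, g_plain)  # 3.0 6.0
```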
Parallel operators
- Gather values of x across all replicas.
- Materialize the mapped axis and map a different axis.
- Compute an all-reduce sum on
- Like
- Compute an all-reduce max on
- Compute an all-reduce min on
- Compute an all-reduce mean on
- Perform a collective permutation according to the permutation
- Convenience wrapper of jax.lax.ppermute with alternate permutation encoding
- Swap the pmapped axis
- Return the index along the mapped axis
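Collectives such as lax.psum and lax.axis_index operate over a named mapped axis. They are commonly used under jax.pmap or shard_map. jax.vmap with an axis_name also binds a named axis, so this minimal sketch runs on a single device. The function normalize and the axis name "batch" are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp
from jax import lax

def normalize(x):  # hypothetical helper
    # psum sums x over every instance mapped along the named axis "batch".
    return x / lax.psum(x, axis_name="batch")

out = jax.vmap(normalize, axis_name="batch")(jnp.array([1.0, 2.0, 3.0, 4.0]))

# axis_index returns each instance's position along the named axis.
idx = jax.vmap(lambda _: lax.axis_index("batch"), axis_name="batch")(jnp.zeros(3))
print(out, idx)
```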
Linear algebra operators (jax.lax.linalg)
- Cholesky decomposition.
- Cholesky rank-1 update.
- Eigendecomposition of a general matrix.
- Eigendecomposition of a Hermitian matrix.
- Reduces a square matrix to upper Hessenberg form.
- Product of elementary Householder reflectors.
- LU decomposition with partial pivoting.
- Converts the pivots (row swaps) returned by LU to a permutation.
- QR-based dynamically weighted Halley iteration for polar decomposition.
- QR decomposition.
- Schur decomposition.
- Singular value decomposition.
- Enum for SVD algorithm.
- Symmetric product.
- Triangular solve.
- Reduces a symmetric/Hermitian matrix to tridiagonal form.
- Computes the solution of a tridiagonal linear system.
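This minimal sketch combines the Cholesky decomposition and triangular solve listed above to solve a small symmetric positive definite system a @ x = b. It assumes the jax.lax.linalg functions cholesky and triangular_solve with the keyword arguments left_side, lower and transpose_a. The matrix values are arbitrary.

```python
import jax.numpy as jnp
from jax.lax import linalg

a = jnp.array([[4.0, 2.0], [2.0, 3.0]])  # symmetric positive definite
b = jnp.array([[2.0], [1.0]])

# Cholesky factor: lower-triangular l with a == l @ l.T.
l = linalg.cholesky(a)

# Solve a @ x = b with two triangular solves: l @ y = b, then l.T @ x = y.
y = linalg.triangular_solve(l, b, left_side=True, lower=True)
x = linalg.triangular_solve(l, y, left_side=True, lower=True, transpose_a=True)
print(l, x)
```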
Argument classes#
- class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
Describes batch, spatial, and feature dimensions of a convolution.
- Parameters:
lhs_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (batch dimension, feature dimension, spatial dimensions…).
rhs_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (out feature dimension, in feature dimension, spatial dimensions…).
out_spec (Sequence[int]) – a tuple of nonnegative integer dimension numbers containing (batch dimension, feature dimension, spatial dimensions…).
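In practice, ConvDimensionNumbers is usually built from layout strings with lax.conv_dimension_numbers rather than by hand. This minimal sketch assumes NHWC inputs and HWIO kernels. It shows the integer specs produced for those layouts and passes them to lax.conv_general_dilated:

```python
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 5, 5, 2))  # NHWC: batch, height, width, in-channels
rhs = jnp.ones((3, 3, 2, 4))  # HWIO: height, width, in-channels, out-channels

dn = lax.conv_dimension_numbers(lhs.shape, rhs.shape, ("NHWC", "HWIO", "NHWC"))
print(dn)
# ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2))

# "SAME" padding with unit strides keeps the 5x5 spatial size.
out = lax.conv_general_dilated(lhs, rhs, window_strides=(1, 1), padding="SAME",
                               dimension_numbers=dn)
print(out.shape)  # (1, 5, 5, 4)
```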
- jax.lax.ConvGeneralDilatedDimensionNumbers
- class jax.lax.DotAlgorithm(lhs_precision_type, rhs_precision_type, accumulation_type, lhs_component_count=1, rhs_component_count=1, num_primitive_operations=1, allow_imprecise_accumulation=False)
Specify the algorithm used for computing dot products.
When used to specify the precision input to dot(), dot_general(), and other dot product functions, this data structure controls the properties of the algorithm used for computing the dot product. This API controls the precision used for the computation, and allows users to access hardware-specific accelerations.
Support for these algorithms is platform dependent, and using an unsupported algorithm will raise a Python exception when the computation is compiled. The algorithms that are known to be supported on at least some platforms are listed in the DotAlgorithmPreset enum, and these are a good starting point for experimenting with this API.
A "dot algorithm" is specified by the following parameters:
lhs_precision_type and rhs_precision_type – the data types that the LHS and RHS of the operation are rounded to.
accumulation_type – the data type used for accumulation.
lhs_component_count, rhs_component_count, and num_primitive_operations – apply to algorithms that decompose the LHS and/or RHS into multiple components and execute multiple operations on those values, usually to emulate a higher precision. For algorithms with no decomposition, these values should be set to 1.
allow_imprecise_accumulation – specifies whether accumulation in lower precision is permitted for some steps (e.g. CUBLASLT_MATMUL_DESC_FAST_ACCUM).
The StableHLO spec for the dot operation doesn't require that the precision types be the same as the storage types for the inputs or outputs, but some platforms may require that these types match. Furthermore, the return type of dot_general() is always defined by the accumulation_type parameter of the input algorithm, if specified.
Examples
Accumulate two 16-bit floats using a 32-bit float accumulator (for 1-D inputs, dot() computes the inner product):
>>> algorithm = DotAlgorithm(
...     lhs_precision_type=np.float16,
...     rhs_precision_type=np.float16,
...     accumulation_type=np.float32,
... )
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> dot(lhs, rhs, precision=algorithm)
Array(30., dtype=float16)
Or, equivalently, using a preset:
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm)
Array(30., dtype=float16)
Presets can also be specified by name:
>>> dot(lhs, rhs, precision="F16_F16_F32")
Array(30., dtype=float16)
The preferred_element_type parameter can be used to return the output without downcasting the accumulation type:
>>> dot(lhs, rhs, precision="F16_F16_F32", preferred_element_type=np.float32)
Array(30., dtype=float32)
- class jax.lax.DotAlgorithmPreset(value)
An enum of known algorithms for computing dot products.
This Enum provides a named set of DotAlgorithm objects that are known to be supported on at least one platform. See the DotAlgorithm documentation for more details about the behavior of these algorithms.
An algorithm can be selected from this list when calling dot(), dot_general(), or most other JAX dot product functions, by passing either a member of this Enum or its name as a string using the precision argument.
For example, users can specify the preset using this Enum directly:
>>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16)
>>> algorithm = DotAlgorithmPreset.F16_F16_F32
>>> dot(lhs, rhs, precision=algorithm)
Array(30., dtype=float16)
or, equivalently, they can be specified by name:
>>> dot(lhs, rhs, precision="F16_F16_F32")
Array(30., dtype=float16)
The names of the presets are typically LHS_RHS_ACCUM, where LHS and RHS are the element types of the lhs and rhs inputs respectively, and ACCUM is the element type of the accumulator. Some presets have an extra suffix, and the meaning of each of these is documented below. The supported presets are:
- DEFAULT = 1
An algorithm will be selected based on input and output types.
- ANY_F8_ANY_F8_F32 = 2
Accepts any float8 input types and accumulates into float32.
- ANY_F8_ANY_F8_F32_FAST_ACCUM = 3
Like ANY_F8_ANY_F8_F32, but using faster accumulation at the cost of lower accuracy.
- ANY_F8_ANY_F8_ANY = 4
Like ANY_F8_ANY_F8_F32, but the accumulation type is controlled by preferred_element_type.
- ANY_F8_ANY_F8_ANY_FAST_ACCUM = 5
Like ANY_F8_ANY_F8_F32_FAST_ACCUM, but the accumulation type is controlled by preferred_element_type.
- F16_F16_F16 = 6
- F16_F16_F32 = 7
- BF16_BF16_BF16 = 8
- BF16_BF16_F32 = 9
- BF16_BF16_F32_X3 = 10
The _X3 suffix indicates that the algorithm uses 3 operations to emulate higher precision.
- BF16_BF16_F32_X6 = 11
Like BF16_BF16_F32_X3, but using 6 operations instead of 3.
- BF16_BF16_F32_X9 = 12
Like BF16_BF16_F32_X3, but using 9 operations instead of 3.
- TF32_TF32_F32 = 13
- TF32_TF32_F32_X3 = 14
The _X3 suffix indicates that the algorithm uses 3 operations to emulate higher precision.
- F32_F32_F32 = 15
- F64_F64_F64 = 16
- class jax.lax.FftType(value)
Describes which FFT operation to perform.
- FFT = 0
Forward complex-to-complex FFT.
- IFFT = 1
Inverse complex-to-complex FFT.
- RFFT = 2
Forward real-to-complex FFT.
- IRFFT = 3
Inverse of the real-to-complex FFT (complex input, real output).
- class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map, operand_batching_dims=(), start_indices_batching_dims=())
Describes the dimension number arguments to XLA's Gather operator. See the XLA documentation for more details of what the dimension numbers mean.
- Parameters:
offset_dims (tuple[int, ...]) – the set of dimensions in the gather output that offset into an array sliced from operand. Must be a tuple of integers in ascending order, each representing a dimension number of the output.
collapsed_slice_dims (tuple[int, ...]) – the set of dimensions i in operand that have slice_sizes[i] == 1 and that should not have a corresponding dimension in the output of the gather. Must be a tuple of integers in ascending order.
start_index_map (tuple[int, ...]) – for each dimension in start_indices, gives the corresponding dimension in the operand that is to be sliced. Must be a tuple of integers with size equal to start_indices.shape[-1].
operand_batching_dims (tuple[int, ...]) – the set of batching dimensions i in operand that have slice_sizes[i] == 1 and that should have a corresponding dimension in both the start_indices (at the same index in start_indices_batching_dims) and output of the gather. Must be a tuple of integers in ascending order.
start_indices_batching_dims (tuple[int, ...]) – the set of batching dimensions i in start_indices that should have a corresponding dimension in both the operand (at the same index in operand_batching_dims) and output of the gather. Must be a tuple of integers (order is fixed based on correspondence with operand_batching_dims).
Unlike XLA’s GatherDimensionNumbers structure, index_vector_dim is implicit; there is always an index vector dimension and it must always be the last dimension. To gather scalar indices, add a trailing dimension of size 1.
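In this minimal sketch, the GatherDimensionNumbers select whole rows of a 2-D operand, which is equivalent to integer-array indexing with operand[jnp.array([2, 0])]. Note the trailing size-1 index vector dimension on indices, as required above.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(12.0).reshape(4, 3)
# Two index vectors of length 1; the trailing dimension is the index vector.
indices = jnp.array([[2], [0]])

dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,),           # output dim 1 holds the 3 columns of each slice
    collapsed_slice_dims=(0,),  # the size-1 row dimension of each slice is dropped
    start_index_map=(0,),       # each index selects a position along operand dim 0
)
rows = lax.gather(operand, indices, dnums, slice_sizes=(1, 3))
print(rows)  # rows 2 and 0 of operand, shape (2, 3)
```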
- class jax.lax.GatherScatterMode(value)
Describes how to handle out-of-bounds indices in a gather or scatter.
Possible values are:
- CLIP:
Indices will be clamped to the nearest in-range value, i.e., such that the entire window to be gathered is in-range.
- FILL_OR_DROP:
If any part of a gathered window is out of bounds, the entire returned window, including elements that would otherwise be in bounds, is filled with a constant. If any part of a scattered window is out of bounds, the entire window is discarded.
- PROMISE_IN_BOUNDS:
The user promises that indices are in bounds. No additional checking will be performed. In practice, with the current XLA implementation this means that out-of-bounds gathers will be clamped but out-of-bounds scatters will be discarded. Gradients will not be correct if indices are out-of-bounds.
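A single out-of-bounds row index shows how CLIP differs from FILL_OR_DROP. This minimal sketch uses row-gathering dimension numbers with lax.gather. It passes an explicit fill_value of -1.0, an arbitrary choice, for the FILL_OR_DROP case.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(12.0).reshape(4, 3)
indices = jnp.array([[5]])  # row 5 does not exist: operand has only 4 rows
dnums = lax.GatherDimensionNumbers(
    offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,))

# CLIP clamps the start index so the window lands on the last valid row.
clipped = lax.gather(operand, indices, dnums, (1, 3),
                     mode=lax.GatherScatterMode.CLIP)

# FILL_OR_DROP replaces the whole out-of-bounds window with fill_value.
filled = lax.gather(operand, indices, dnums, (1, 3),
                    mode=lax.GatherScatterMode.FILL_OR_DROP, fill_value=-1.0)
print(clipped, filled)
```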
- class jax.lax.Precision(value)
Precision enum for lax matrix multiply related functions.
The device-dependent precision argument to JAX functions generally controls the tradeoff between speed and accuracy for array computations on accelerator backends (i.e. TPU and GPU). It has no impact on CPU backends. It only has an effect on float32 computations, and does not affect the input/output datatypes. Members are:
- DEFAULT:
Fastest mode, but least accurate. On TPU: performs float32 computations in bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100 GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases: 'default', 'fastest'.
- HIGH:
Slower but more accurate. On TPU: performs float32 computations in 3 bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise float32. Aliases: 'high'.
- HIGHEST:
Slowest but most accurate. On TPU: performs float32 computations in 6 bfloat16 passes. On GPU: uses float32. Aliases: 'highest'.
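You can pass Precision either as an enum member or as one of its string aliases. In this minimal sketch, both calls request HIGHEST. On CPU the setting has no effect, so both calls produce the exact float32 product of these small integer-valued matrices.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32)
b = jnp.array([[5.0, 6.0], [7.0, 8.0]], dtype=jnp.float32)

# Request the most accurate float32 matmul with an enum member...
c_enum = lax.dot(a, b, precision=lax.Precision.HIGHEST)
# ...or with the equivalent string alias, here through jax.numpy.
c_str = jnp.matmul(a, b, precision="highest")
print(c_enum, c_str)
```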
- jax.lax.PrecisionLike
alias of None | str | Precision | tuple[str, str] | tuple[Precision, Precision] | DotAlgorithm | DotAlgorithmPreset
- class jax.lax.RandomAlgorithm(value)
Describes which PRNG algorithm to use for rng_bit_generator.
- RNG_DEFAULT = 0
The platform’s default algorithm.
- RNG_THREE_FRY = 1
The Threefry-2x32 PRNG algorithm.
- RNG_PHILOX = 2
The Philox-4x32 PRNG algorithm.
- class jax.lax.RoundingMethod(value)
Rounding strategies for handling halfway values (e.g., 0.5) in jax.lax.round().
- AWAY_FROM_ZERO = 0
Rounds halfway values away from zero (e.g., 0.5 -> 1, -0.5 -> -1).
- TO_NEAREST_EVEN = 1
Rounds halfway values to the nearest even integer. This is also known as “banker’s rounding” (e.g., 0.5 -> 0, 1.5 -> 2).
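This minimal sketch applies both rounding methods to the same halfway values with lax.round:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([0.5, 1.5, 2.5, -0.5])

away = lax.round(x, lax.RoundingMethod.AWAY_FROM_ZERO)
even = lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
print(away)  # [ 1.  2.  3. -1.]
print(even)  # [ 0.  2.  2. -0.]
```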
- class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, operand_batching_dims=(), scatter_indices_batching_dims=())
Describes the dimension number arguments to XLA's Scatter operator. See the XLA documentation for more details of what the dimension numbers mean.
- Parameters:
update_window_dims (Sequence[int]) – the set of dimensions in the updates that are window dimensions. Must be a tuple of integers in ascending order, each representing a dimension number.
inserted_window_dims (Sequence[int]) – the set of size 1 window dimensions that must be inserted into the shape of updates. Must be a tuple of integers in ascending order, each representing a dimension number of the output. These are the mirror image of collapsed_slice_dims in the case of gather.
scatter_dims_to_operand_dims (Sequence[int]) – for each dimension in scatter_indices, gives the corresponding dimension in operand. Must be a sequence of integers with size equal to scatter_indices.shape[-1].
operand_batching_dims (Sequence[int]) – the set of batching dimensions i in operand that should have a corresponding dimension in both the scatter_indices (at the same index in scatter_indices_batching_dims) and updates. Must be a tuple of integers in ascending order. These are the mirror image of operand_batching_dims in the case of gather.
scatter_indices_batching_dims (Sequence[int]) – the set of batching dimensions i in scatter_indices that should have a corresponding dimension in both the operand (at the same index in operand_batching_dims) and the updates. Must be a tuple of integers (order is fixed based on correspondence with operand_batching_dims). These are the mirror image of start_indices_batching_dims in the case of gather.
Unlike XLA’s ScatterDimensionNumbers structure, index_vector_dim is implicit; there is always an index vector dimension and it must always be the last dimension. To scatter scalar indices, add a trailing dimension of size 1.
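In this minimal sketch, the ScatterDimensionNumbers scatter scalar updates into a 1-D operand with lax.scatter_add. Repeated indices accumulate, so index 1 receives both 10 and 30. As with gather, each index vector carries a trailing dimension of size 1.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros(5)
indices = jnp.array([[1], [3], [1]])  # trailing size-1 dim is the index vector
updates = jnp.array([10.0, 20.0, 30.0])

dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(),              # each update is a scalar: no window dims
    inserted_window_dims=(0,),          # the size-1 window along operand dim 0 is implicit
    scatter_dims_to_operand_dims=(0,),  # each index addresses operand dim 0
)
out = lax.scatter_add(operand, indices, updates, dnums)
print(out)  # [ 0. 40.  0. 20.  0.]
```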