Pallas Changelog
This is the list of changes specific to jax.experimental.pallas. For changes to JAX as a whole, see the main JAX changelog.
Released with jax 0.5.0
New functionality
Added vector support for jax.experimental.pallas.debug_print() on TPU.
Released with jax 0.4.37
New functionality
Added support for DotAlgorithmPreset precision arguments for dot lowering on the Triton backend.
Released with jax 0.4.36 (December 6, 2024)
Released with jax 0.4.35 (October 22, 2024)
Removals
Removed the previously deprecated aliases jax.experimental.pallas.tpu.CostEstimate and jax.experimental.tpu.run_scoped(). Both are now available in jax.experimental.pallas.
New functionality
Added a cost estimate tool, pl.estimate_cost(), for automatically constructing a kernel cost estimate from a JAX reference function.
Released with jax 0.4.34 (October 4, 2024)
Changes
jax.experimental.pallas.debug_print() no longer requires all arguments to be scalars. The restrictions on the arguments are backend-specific: non-scalar arguments are currently supported only on GPU, when using Triton.
jax.experimental.pallas.BlockSpec no longer supports the previously deprecated argument order, where index_map comes before block_shape.
Deprecations
The jax.experimental.pallas.gpu submodule is deprecated to avoid ambiguity with jax.experimental.pallas.mosaic_gpu. To use the Triton backend, import jax.experimental.pallas.triton.
New functionality
jax.experimental.pallas.pallas_call() now accepts scratch_shapes, a PyTree specifying backend-specific temporary objects needed by the kernel, for example, buffers and synchronization primitives.
checkify.check() can now be used to insert runtime asserts when pallas_call is called within the pltpu.enable_runtime_assert(True) context manager.
Released with jax 0.4.33 (September 16, 2024)
Released with jax 0.4.32 (September 11, 2024)
Changes
The kernel function is no longer allowed to close over constants. Instead, all needed arrays must be passed as inputs, with proper block specs (#22746).
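A minimal sketch of the pattern this requires: an array that a kernel might otherwise capture from an enclosing scope is passed to pallas_call as an explicit input. With no grid and no in_specs, each input is presented to the kernel as a whole-array block. The names add_bias and add_bias_kernel are hypothetical, and interpret=True is used so the example runs on CPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_bias_kernel(x_ref, bias_ref, o_ref):
    # The bias arrives as a ref like any other input, rather than being
    # captured from an enclosing Python scope.
    o_ref[...] = x_ref[...] + bias_ref[...]


def add_bias(x, bias):
    return pl.pallas_call(
        add_bias_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # run on CPU for illustration
    )(x, bias)


x = jnp.ones((8, 128), dtype=jnp.float32)
bias = jnp.full((8, 128), 3.0, dtype=jnp.float32)
out = add_bias(x, bias)
```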
New functionality
Improved error messages for mistakes in the signature of index map functions; the messages now include the name and source location of the index map.
Released with jax 0.4.31 (July 29, 2024)
Changes
jax.experimental.pallas.BlockSpec now expects block_shape to be passed before index_map. The old argument order is deprecated and will be removed in a future release.
jax.experimental.pallas.GridSpec no longer has the in_specs_tree and out_specs_tree fields, and the in_specs and out_specs trees now store the values as pytrees of BlockSpec. Previously, in_specs and out_specs were flattened (#22552).
The method compute_index of jax.experimental.pallas.GridSpec has been removed because it is private. Similarly, get_grid_mapping and unzip_dynamic_bounds have been removed from BlockSpec (#22593).
Fixed the interpret mode to work with BlockSpecs that involve padding (#22275). In interpret mode, padding is filled with NaN to help debug out-of-bounds errors; this behavior is not present when running in custom kernel mode and should not be depended on.
Previously, many APIs meant to be private could be imported from jax.experimental.pallas.pallas. This is no longer possible.
New Functionality
Added documentation for BlockSpec: Grids and BlockSpecs.
Improved error messages for the jax.experimental.pallas.pallas_call() API.
Added TPU lowering rules for lax.shift_right_arithmetic (#22279) and lax.erf_inv (#22310).
Added initial support for shape polymorphism for the Pallas TPU custom kernels (#22084).
Added TPU support for checkify (#22480).
Added clearer error messages when the block sizes do not match the TPU requirements. Previously, the errors came from the Mosaic backend and did not have useful Python stack traces.
Added support for TPU lowering with 1D blocks, and relaxed the requirements for the block sizes with at least 2 dimensions: the last 2 dimensions must be divisible by 8 and 128 respectively, unless they span the entire corresponding array dimension. Previously, block dimensions that spanned the entire array were allowed only if the block dimensions in the last two dimensions were smaller than 8 and 128 respectively.
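The relaxed block-size rule can be expressed as a small check. This is a minimal sketch: satisfies_tpu_block_rule is a hypothetical helper, not part of Pallas, and it covers only blocks with at least 2 dimensions and only the rule as stated here, not any other Mosaic constraints.

```python
def satisfies_tpu_block_rule(block_shape, array_shape):
    """Hypothetical check of the TPU block-size rule for blocks with >= 2 dims.

    The last two block dimensions must be divisible by 8 and 128 respectively,
    unless they span the entire corresponding array dimension.
    """
    if len(block_shape) != len(array_shape):
        raise ValueError("block_shape and array_shape must have the same rank")
    if len(block_shape) < 2:
        raise ValueError("this sketch only covers blocks with at least 2 dimensions")
    # Pair each of the last two block dims with its array dim and required multiple.
    for block_dim, array_dim, multiple in zip(block_shape[-2:], array_shape[-2:], (8, 128)):
        spans_whole_dim = block_dim == array_dim
        if not spans_whole_dim and block_dim % multiple != 0:
            return False
    return True


print(satisfies_tpu_block_rule((8, 128), (64, 256)))  # divisible by 8 and 128
print(satisfies_tpu_block_rule((64, 100), (64, 100)))  # spans both dimensions
print(satisfies_tpu_block_rule((4, 128), (64, 256)))   # 4 is neither a multiple of 8 nor the full dim
```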
Released with JAX 0.4.30 (June 18, 2024)
New Functionality
Added checkify support for jax.experimental.pallas.pallas_call() in interpret mode (#21862).
Improved support for PRNG keys for TPU kernels (#21773).
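A minimal sketch of checkify with an interpret-mode pallas_call: the kernel calls checkify.check, and wrapping the caller in checkify.checkify turns a failed check into an error value instead of an exception. The names safe_log and safe_log_kernel are hypothetical; the exact wording of the returned error string may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.experimental import pallas as pl


def safe_log_kernel(x_ref, o_ref):
    x = x_ref[...]
    # Runtime assert on a scalar predicate; surfaced through checkify's error value.
    checkify.check(jnp.all(x > 0), "all inputs must be positive")
    o_ref[...] = jnp.log(x)


def safe_log(x):
    return pl.pallas_call(
        safe_log_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # checkify support described here is for interpret mode
    )(x)


checked_log = checkify.checkify(safe_log)
err_ok, out_ok = checked_log(jnp.ones((8, 128), dtype=jnp.float32))
err_bad, _ = checked_log(-jnp.ones((8, 128), dtype=jnp.float32))
print(err_ok.get())   # None when every check passed
print(err_bad.get())  # a message describing the failed check
```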