jax.experimental.pallas.BlockSpec
- class jax.experimental.pallas.BlockSpec(block_shape=None, index_map=None, indexing_mode=None, pipeline_mode=None, *, memory_space=None)
Specifies how an array should be sliced for each invocation of a kernel.
The block_shape is a sequence of int | None values, or BlockDim types (e.g. pl.Element, pl.Squeezed, pl.Blocked, pl.BoundedSlice). Each entry specifies the size of the corresponding block dimension. None marks a dimension that is squeezed out of the kernel. The BlockDim types allow finer-grained control over how a dimension is indexed. The index_map must return a tuple of the same length as block_shape, with each entry interpreted according to the BlockDim type of that dimension.
See the guide "BlockSpec, a.k.a. how to chunk up inputs" and the docstrings of the individual BlockDim types for more details.
- Parameters:
block_shape (Sequence[BlockDim | int | None] | None)
index_map (Callable[..., Any] | None)
indexing_mode (Any | None)
pipeline_mode (Buffered | None)
memory_space (Any | None)
- __init__(block_shape=None, index_map=None, indexing_mode=None, pipeline_mode=None, *, memory_space=None)
- Parameters:
block_shape (Sequence[BlockDim | int | None] | None)
index_map (Callable[..., Any] | None)
indexing_mode (Any | None)
pipeline_mode (Buffered | None)
memory_space (Any | None)
- Return type:
None
Methods
__init__([block_shape, index_map, ...])
replace(**changes): Return a new object replacing specified fields with new values.
to_block_mapping(origin, array_aval, *, ...)
Attributes
block_shape
index_map
indexing_mode
memory_space
pipeline_mode