Defining new JAX types with hijax
JAX’s built-in currency is the array: functions you transform take arrays in
and produce arrays out, and every intermediate the tracing machinery sees has
an array type like f32[3,4]. When you want to work with aggregate data, the
usual tool is a pytree: you bundle arrays into containers, and JAX
transparently flattens the bundle into its array leaves at every boundary.
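To see that flattening concretely, here is a small sketch using jax.tree_util on a plain dict. The dict and its field names are only for illustration. JAX sees two independent array leaves plus a structure descriptor, and nothing else.

```python
import jax
import jax.numpy as jnp

# A plain dict is a pytree; the field names here are illustrative only.
bundle = {
    "qvalue": jnp.zeros((2, 3), dtype=jnp.int8),
    "scale": jnp.ones((2,), dtype=jnp.float32),
}

# Flattening yields the array leaves (dict keys are visited in sorted order)
# plus a treedef describing the container structure.
leaves, treedef = jax.tree_util.tree_flatten(bundle)
print([leaf.dtype for leaf in leaves])

# Unflattening rebuilds an equivalent container from the leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(sorted(rebuilt))  # ['qvalue', 'scale']
```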
But sometimes transparency is exactly what you don’t want. Some data is best modeled as a new type, with its own identity:
it should appear in jaxprs as a single value of a single type, not as a spray of array leaves;
it has internal invariants, so users should only produce and consume it through a fixed set of operations, rather than by freely constructing or pattern-matching its components;
its tangent type may differ from its primal structure, so that derivatives with respect to it aren’t just “the same pytree, but for tangents”;
it may have its own notion of batching under vmap.
Hijax types (or “hi types”) provide this. You subclass HiType to define
the type, register a Python class as carrying values of that type, and write
hijax primitives whose input and output types mention the new type. This
document walks through the whole story with one running example: a
quantized array type.
We’ll assume some familiarity with hijax primitives; see
Custom derivative rules with hijax primitives for an introduction to them. Like everything
hijax, this is experimental: expect imports from jax.experimental.hijax,
and expect the APIs to evolve.
TL;DR
Subclass HiType and implement lo_ty, lower_val, and raise_val to say how the type and its values lower to ordinary (“lojax”) arrays. Then call register_hitype to associate your value class with your type.
Write VJPHiPrimitive subclasses whose in_avals/out_aval mention the new type. These are the only way values of the type get produced and consumed.
For autodiff, implement to_tangent_aval on the type, and VJP/JVP rules on the primitives.
For vmap, implement dec_rank and inc_rank on the type, along with a MappingSpec subclass of your own design, and batch rules on the primitives. Mapped-over hi type arguments require an explicit axis_size and spec-valued in_axes/out_axes entries.
Example: quantized arrays
Say we want to work with arrays quantized to int8. A quantized array is
really a pair of arrays: the int8 values, and a floating point scale
shared by each row (that is, we quantize along the last axis, one scale per
row, as in common per-row/per-channel quantization schemes):
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
# (8 CPU devices, for the shard_map section at the end)
from dataclasses import dataclass
import jax
import jax.numpy as jnp
@dataclass(frozen=True)
class QArray:
    qvalue: jax.Array  # int8[*leading, n]
    scale: jax.Array   # f32[*leading]
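For reference, here is the same per-row scheme written with plain NumPy arrays, outside JAX. The names quantize_ref and dequantize_ref are hypothetical helpers used only for illustration. Each row's scale is its largest absolute value divided by 127, so the largest-magnitude entry in each row maps to ±127. Rounding then bounds the reconstruction error by half a scale step. Like the Quantize primitive below, this sketch does not handle an all-zero row, which would divide by zero.

```python
import numpy as np

def quantize_ref(x):
    # One float32 scale per row: the row's largest magnitude maps to 127.
    scale = (np.max(np.abs(x), axis=-1) / 127.0).astype(np.float32)
    # Divide each row by its scale, round to the nearest integer, store as int8.
    qvalue = np.round(x / scale[..., None]).astype(np.int8)
    return qvalue, scale

def dequantize_ref(qvalue, scale):
    # Broadcast each row's scale back over that row.
    return qvalue.astype(np.float32) * scale[..., None]

x = np.array([[1., 2., 3.],
              [4., -5., 6.]], dtype=np.float32)
qvalue, scale = quantize_ref(x)
x_hat = dequantize_ref(qvalue, scale)
print(qvalue)
# Rounding error is at most half a quantization step per element.
print(np.all(np.abs(x_hat - x) <= scale[:, None] / 2 + 1e-6))  # True
```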
We could register QArray as a pytree and be done. But consider what we’d
give up:
Invariants. The two components are coupled: scale must have the shape of qvalue minus its last axis, and qvalue is only meaningful together with its scale. As a pytree, nothing stops code from crossing the streams, and under transformations JAX itself sees only independent leaves.
Types in jaxprs. As a pytree, a quantized array appears in traced code as two unrelated array values. We'd rather see one value, of one type, so that jaxprs say what they mean.
Tangents. A quantized array's values live on a discrete grid, so they can't be perturbed infinitesimally. But a pytree's tangent type is forced to be the pytree of its leaves' tangent types, and the tangent type of an integer array like qvalue is a float0 array, which can only carry a trivial payload. So as a pytree, a quantized array would admit no useful perturbations at all. What we want is to choose a tangent type for the quantized array as a whole, such as the continuous f32 arrays that the quantized values approximate.
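You can see the float0 tangent directly with plain JAX arrays. In this sketch, f is an arbitrary illustrative function. jax.grad is asked to differentiate with respect to an int8 argument, which requires allow_int=True, and the resulting "gradient" has dtype float0 and carries no information.

```python
import jax
import jax.numpy as jnp

def f(qvalue):
    # Any function of an integer array; only the input's dtype matters here.
    return jnp.sum(qvalue.astype(jnp.float32) * 0.5)

qvalue = jnp.array([42, 85, 127], dtype=jnp.int8)
# allow_int=True permits integer inputs; their cotangents are trivial float0 values.
ct = jax.grad(f, allow_int=True)(qvalue)
print(ct.dtype == jax.dtypes.float0)  # True
print(ct.shape)  # (3,)
```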
So instead we’ll make QArray a hijax type.
The type
A hijax type is a subclass of HiType. The required core is small:
lo_ty says which lojax (array) types make up the type;
lower_val and raise_val convert values to and from that list of arrays;
the type must be hashable and comparable for equality (a frozen dataclass gives us both).
This is like the pytree flatten/unflatten interface, but it lives at the level of types: given only the type, JAX can compute the lowered types, without needing a value in hand.
from jax.experimental.hijax import HiType, ShapedArray, register_hitype
@dataclass(frozen=True)
class QArrayTy(HiType):
    shape: tuple[int, ...]
    # lowering: which array types make up this type, and how values convert
    def lo_ty(self):
        return [ShapedArray(self.shape, jnp.dtype('int8')),
                ShapedArray(self.shape[:-1], jnp.dtype('float32'))]
    def lower_val(self, q):
        return [q.qvalue, q.scale]
    def raise_val(self, qvalue, scale):
        return QArray(qvalue, scale)
    # autodiff: tangents of quantized arrays are plain float arrays (see below)
    def to_tangent_aval(self):
        return ShapedArray(self.shape, jnp.dtype('float32'))
    # printing, e.g. in jaxprs
    def str_short(self, short_dtypes=False, mesh_axis_types=False):
        return f'q8[{",".join(map(str, self.shape))}]'
    __repr__ = str_short
register_hitype(QArray, lambda q: QArrayTy(q.qvalue.shape))
The register_hitype call associates the value class with the type: its
second argument computes the type of any given value, analogous to how
jax.typeof maps an array to its ShapedArray type. Indeed after
registration, jax.typeof works on QArrays, and JAX transformations
accept them anywhere a value is expected.
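To make the type-level point concrete, here is a standalone sketch that mimics the three lowering methods with plain NumPy, without subclassing HiType or calling register_hitype. The class MiniQArrayTy and its (shape, dtype) pairs are hypothetical stand-ins for QArrayTy and ShapedArray, and a (qvalue, scale) tuple stands in for a QArray. Note that lo_ty needs only the type's shape, never a value.

```python
from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True)
class MiniQArrayTy:
    # Hypothetical stand-in for QArrayTy: a hashable, comparable type descriptor.
    shape: tuple

    def lo_ty(self):
        # Computed from the type alone: (shape, dtype) pairs stand in for ShapedArray.
        return [(self.shape, np.dtype('int8')),
                (self.shape[:-1], np.dtype('float32'))]

    def lower_val(self, q):
        # q is a (qvalue, scale) tuple here, standing in for a QArray.
        return list(q)

    def raise_val(self, qvalue, scale):
        # Check the components against the lowered types before rebuilding.
        for arr, (shape, dtype) in zip((qvalue, scale), self.lo_ty()):
            assert arr.shape == shape and arr.dtype == dtype
        return (qvalue, scale)

ty = MiniQArrayTy((2, 3))
print(ty.lo_ty())  # lowered types, with no value in hand
# Frozen dataclasses are hashable and compare by field values.
print(ty == MiniQArrayTy((2, 3)), hash(ty) == hash(MiniQArrayTy((2, 3))))  # True True

q = (np.zeros((2, 3), np.int8), np.ones((2,), np.float32))
roundtrip = ty.raise_val(*ty.lower_val(q))
```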
The primitives
With a pytree, users construct and take apart values freely. With a hijax type, values are produced and consumed only by hijax primitives whose declared types mention the new type. That’s where invariants get enforced: if every primitive preserves them, they always hold.
Our two primitives are quantize and dequantize, written with the
VJPHiPrimitive API from Custom derivative rules with hijax primitives. Each declares
its input and output types, gives its implementation in expand, and
(looking ahead to autodiff) carries a straight-through-estimator VJP rule:
from jax.experimental.hijax import VJPHiPrimitive
class Quantize(VJPHiPrimitive):
    def __init__(self, x_aval):
        if x_aval.dtype != jnp.dtype('float32'):
            raise TypeError(f'quantize expects float32 input, got {x_aval.dtype}')
        self.in_avals = (x_aval,)
        self.out_aval = QArrayTy(x_aval.shape)
        self.params = {}
        super().__init__()
    def expand(self, x):
        # one scale per row (an all-zero row would divide by zero here)
        scale = jnp.max(jnp.abs(x), axis=-1) / 127.
        qvalue = jnp.round(x / scale[..., None]).astype(jnp.int8)
        return QArray(qvalue, scale)
    # straight-through estimator: differentiate as if it's the identity
    def vjp_fwd(self, nzs_in, x):
        return self(x), None
    def vjp_bwd_retval(self, _res, g):
        return (g,)
class Dequantize(VJPHiPrimitive):
    def __init__(self, q_aval):
        self.in_avals = (q_aval,)
        self.out_aval = ShapedArray(q_aval.shape, jnp.dtype('float32'))
        self.params = {}
        super().__init__()
    def expand(self, qx):
        return qx.qvalue.astype('float32') * qx.scale[..., None]
    def vjp_fwd(self, nzs_in, qx):
        return self(qx), None
    def vjp_bwd_retval(self, _res, g):
        return (g,)
def quantize(x):
    return Quantize(jax.typeof(x))(x)
def dequantize(qx):
    return Dequantize(jax.typeof(qx))(qx)
Notice that Quantize’s out_aval and Dequantize’s in_avals are
QArrayTys: the new type appears in primitive type signatures just like
array types do. Also notice that expand freely constructs and inspects
QArray values: primitive implementations live inside the abstraction
boundary.
Everything works eagerly:
x = jnp.array([[1., 2., 3.],
               [4., -5., 6.]])
qx = quantize(x)
print(qx)
print(jax.typeof(qx))
print(dequantize(qx))
QArray(qvalue=Array([[ 42, 85, 127],
[ 85, -106, 127]], dtype=int8), scale=Array([0.02362205, 0.04724409], dtype=float32))
q8[2,3]
[[ 0.992126 2.007874 3. ]
[ 4.015748 -5.007874 6. ]]
Hi types in jaxprs
When we trace, the quantized array appears as a single value of type
q8[2,3], produced by one equation and consumed by another:
jax.make_jaxpr(lambda x: dequantize(quantize(x)))(x)
{ lambda ; a:f32[2,3]. let
    b:q8[2,3] = call_hi_primitive[_prim=Quantize[{}]] a
    c:f32[2,3] = call_hi_primitive[_prim=Dequantize[{}]] b
  in (c,) }
Compare to the pytree approach, where the quantized array would show up as two
unrelated array-typed values (an int8 array and an f32 array), with no indication that they pair up. The hi type
only disappears at lowering time, when expand is traced and each
q8[...]-typed value is expanded into the array components given by
lo_ty.
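For contrast, here is a sketch of a pytree-style dequantize that takes the components as a plain tuple. dequantize_tuple is a hypothetical name. Tracing it shows two separate array inputs, int8[2,3] and float32[2], and nothing in the jaxpr ties them together.

```python
import jax
import jax.numpy as jnp

def dequantize_tuple(q):
    # Pytree version: q is just a (qvalue, scale) tuple of arrays.
    qvalue, scale = q
    return qvalue.astype(jnp.float32) * scale[..., None]

q = (jnp.zeros((2, 3), jnp.int8), jnp.ones((2,), jnp.float32))
closed = jax.make_jaxpr(dequantize_tuple)(q)
print(closed)
# The tuple has been flattened into two independent jaxpr inputs.
print([str(v.aval) for v in closed.jaxpr.invars])
```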
jit works, with quantized arrays as arguments, results, and
intermediates:
print(jax.jit(lambda x: dequantize(quantize(x)))(x))  # QArray intermediate
qx2 = jax.jit(quantize)(x) # QArray result
print(jax.typeof(qx2))
print(jax.jit(dequantize)(qx2)) # QArray argument
[[ 0.992126 2.007874 3. ]
[ 4.015748 -5.007874 6. ]]
q8[2,3]
[[ 0.992126 2.007874 3. ]
[ 4.015748 -5.007874 6. ]]
Autodiff and tangent types
Here’s where hi types earn their keep. On the type, we implemented
def to_tangent_aval(self):
    return ShapedArray(self.shape, jnp.dtype('float32'))
which says: the tangent type of a quantized array is a plain f32 array.
No pytree can express this: a pytree’s tangent type is always the pytree
of its leaves’ tangent types, and for the int8 leaf qvalue that means
a trivial float0 tangent.
Together with the straight-through VJP rules on the primitives, gradients flow through quantization as if it were the identity:
def f(x):
    return jnp.sum(dequantize(quantize(x)))
print(jax.grad(f)(x))
[[1. 1. 1.]
[1. 1. 1.]]
And differentiating with respect to a quantized array input produces a plain float array, as the tangent type dictates:
def g(qx):
    return jnp.sum(dequantize(qx) ** 2)
print(jax.grad(g)(qx))
print(jax.typeof(jax.grad(g)(qx)))
[[ 1.984252 4.015748 6. ]
[ 8.031496 -10.015748 12. ]]
float32[2,3]
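For comparison, the same straight-through behavior can be written for plain arrays with jax.custom_vjp as a "fake quantization". The function rounds in the forward pass and passes cotangents through unchanged in the backward pass. This is a swapped-in technique, not the hijax mechanism: the quantized representation never appears as a value of its own, and fake_quant is a hypothetical name.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def fake_quant(x):
    # Quantize and immediately dequantize, staying in float32 throughout.
    scale = jnp.max(jnp.abs(x), axis=-1, keepdims=True) / 127.
    return jnp.round(x / scale) * scale

def fake_quant_fwd(x):
    return fake_quant(x), None  # no residuals needed

def fake_quant_bwd(_res, g):
    return (g,)  # straight-through: treat the op as the identity

fake_quant.defvjp(fake_quant_fwd, fake_quant_bwd)

x = jnp.array([[1., 2., 3.],
               [4., -5., 6.]])
grads = jax.grad(lambda x: jnp.sum(fake_quant(x)))(x)
print(grads)  # all ones, as with the hijax primitives
```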
Notice that making the tangent type an f32 array was a choice, and
there’s a real design space here. We could instead have made the tangent
type of QArrayTy be QArrayTy itself, so that tangents and cotangents
are quantized too — a different tradeoff, sensible for different
applications. (For that choice, since the tangent type is then a hi type,
we’d also implement vspace_zero and vspace_add on it so autodiff can
instantiate and accumulate cotangents.) This flexibility is why hi types
are a user extension point: for each piece of JAX — tracing, lowering,
autodiff, and batching — you set up how your type participates, however
your situation needs.
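To see why vspace_add would need real logic under that alternative, consider adding two quantized arrays. Their per-row scales generally differ, so the int8 values can't simply be summed. The NumPy sketch below shows one possible strategy: dequantize, add, and requantize. The helper names are hypothetical, and this is not how hijax itself defines vspace_add.

```python
import numpy as np

def quantize_ref(x):
    # Per-row int8 quantization along the last axis.
    scale = (np.max(np.abs(x), axis=-1) / 127.0).astype(np.float32)
    return np.round(x / scale[..., None]).astype(np.int8), scale

def dequantize_ref(qvalue, scale):
    return qvalue.astype(np.float32) * scale[..., None]

def qadd(a, b):
    # Scales differ row by row, so add in float32 and requantize the sum.
    return quantize_ref(dequantize_ref(*a) + dequantize_ref(*b))

a = quantize_ref(np.array([[1., 2., 3.]], np.float32))     # scale 3/127
b = quantize_ref(np.array([[10., 0., -10.]], np.float32))  # scale 10/127
total = qadd(a, b)
print(dequantize_ref(*total))  # close to [[11., 2., -7.]]
```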
vmap and mapping specs
What does it mean to map over a quantized array? For arrays, vmap’s
in_axes and out_axes are axis indices, and JAX can infer the mapped
axis size from the argument’s shape. For a general hi type, JAX doesn’t
guess: you define a “mapping spec” type that says how values of your type
are mapped, users pass instances of it as in_axes/out_axes entries, and
they pass axis_size explicitly when it can’t be inferred from an array
argument.
For our quantized arrays, thanks to the per-row scales, a batch of
QArrays is just a bigger QArray: stacking n quantized arrays of type
q8[2,3] along a new leading axis gives a q8[n,2,3], with qvalue of
shape (n, 2, 3) and scale of shape (n, 2). So the only mapping notion
we need is “the leading axis,” and our spec type doesn’t need to carry any
data at all:
from jax.experimental.hijax import MappingSpec
@dataclass(frozen=True)
class QArraySpec(MappingSpec):
    pass  # QArrays are only mapped along their leading axis
(Specs can be as rich as your type demands. A tuple-like hi type might use
a spec carrying one axis per component; see the TupSpec example in
tests/hijax_test.py.)
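The claim that a stack of QArrays is itself a QArray can be checked with a plain NumPy reference. quantize_ref is a hypothetical helper mirroring Quantize.expand. Quantizing each (2, 3) slice separately and stacking the results gives exactly the components obtained by quantizing the stacked (4, 2, 3) array, because every scale depends only on its own row.

```python
import numpy as np

def quantize_ref(x):
    # Per-row int8 quantization along the last axis, valid at any rank.
    scale = (np.max(np.abs(x), axis=-1) / 127.0).astype(np.float32)
    return np.round(x / scale[..., None]).astype(np.int8), scale

xs = np.arange(24., dtype=np.float32).reshape(4, 2, 3)

# Quantize each (2, 3) element separately, then stack along a new leading axis.
parts = [quantize_ref(x) for x in xs]
stacked_q = np.stack([q for q, _ in parts])
stacked_s = np.stack([s for _, s in parts])

# Quantize the whole (4, 2, 3) batch at once.
batch_q, batch_s = quantize_ref(xs)
print(stacked_q.shape, stacked_s.shape)  # (4, 2, 3) (4, 2)
print(np.array_equal(stacked_q, batch_q), np.array_equal(stacked_s, batch_s))  # True True
```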
On the type, we implement dec_rank and inc_rank, the hi type analogues
of “remove the mapped axis” and “add the mapped axis.” They take the axis
size and a spec, and return the element type and the batched type,
respectively:
def qarray_dec_rank(self, size, spec):
    assert isinstance(spec, QArraySpec) and self.shape[0] == size
    return QArrayTy(self.shape[1:])
def qarray_inc_rank(self, size, spec):
    assert isinstance(spec, QArraySpec)
    return QArrayTy((size, *self.shape))
QArrayTy.dec_rank = qarray_dec_rank
QArrayTy.inc_rank = qarray_inc_rank
(We’re attaching methods to the class as we go, notebook-style; in real
code these would just be more methods in the QArrayTy class definition.)
On the primitives, we implement batch rules. A batch rule receives the
batched arguments along with their mapping specs (None for unbatched
arguments, an integer axis for batched array arguments, and a spec instance
for batched hi type arguments), and returns the batched result along with
its mapping spec. Note that a rule should be prepared for any combination
of batched and unbatched arguments:
def quantize_batch(self, axis_data, args, in_dims):
    x, = args
    d, = in_dims
    if d is None:
        return quantize(x), None
    x = jnp.moveaxis(x, d, 0)
    return quantize(x), QArraySpec()
Quantize.batch = quantize_batch
def dequantize_batch(self, axis_data, args, in_dims):
    qx, = args
    d, = in_dims
    if d is None:
        return dequantize(qx), None
    assert isinstance(d, QArraySpec)
    return dequantize(qx), 0
Dequantize.batch = dequantize_batch
Because per-row quantization applies at any rank, both rules can just apply the unbatched operation to the stacked value — the hallmark of a type whose batches are values of the same type family.
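The moveaxis step in quantize_batch is what lets the rule accept a batch axis in any position. Here is a NumPy sketch of that reasoning, with quantize_ref as a hypothetical helper mirroring Quantize.expand. With the batch axis at position 1, quantizing each batch element on its own agrees with moving the batch axis to the front and quantizing once.

```python
import numpy as np

def quantize_ref(x):
    # Per-row int8 quantization along the last axis, valid at any rank.
    scale = (np.max(np.abs(x), axis=-1) / 127.0).astype(np.float32)
    return np.round(x / scale[..., None]).astype(np.int8), scale

# A batch of 4 elements of shape (2, 3), stored with the batch axis at d=1.
d = 1
xs = np.arange(24., dtype=np.float32).reshape(4, 2, 3).transpose(1, 0, 2)  # (2, 4, 3)

# Reference: quantize each batch element on its own, then stack at the front.
per_elem = [quantize_ref(np.take(xs, i, axis=d)) for i in range(xs.shape[d])]
ref_q = np.stack([q for q, _ in per_elem])

# What the batch rule does: move the batch axis to 0, then quantize once.
rule_q, rule_s = quantize_ref(np.moveaxis(xs, d, 0))
print(rule_q.shape, rule_s.shape)  # (4, 2, 3) (4, 2)
print(np.array_equal(ref_q, rule_q))  # True
```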
Now we can vmap. Mapping to a quantized array output, the axis size is
inferred from the array argument as usual, and we pass a spec for
out_axes:
xs = jnp.arange(24., dtype='float32').reshape(4, 2, 3)
qxs = jax.vmap(quantize, out_axes=QArraySpec())(xs)
print(jax.typeof(qxs))
print(qxs.qvalue.shape, qxs.scale.shape)
q8[4,2,3]
(4, 2, 3) (4, 2)
Mapping over a quantized array input, we pass a spec for in_axes — and
since there’s no array argument to infer the axis size from, we must pass
axis_size explicitly:
xs_roundtrip = jax.vmap(dequantize, in_axes=QArraySpec(), axis_size=4)(qxs)
print(jax.typeof(xs_roundtrip))
float32[4,2,3]
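The same rule applies to ordinary arrays: when no argument is mapped, vmap has nothing to infer the axis size from. In this plain-JAX sketch, double is an illustrative function and every in_axes entry is None. axis_size must therefore be given, and the unbatched result is broadcast along a new leading axis of that size.

```python
import jax
import jax.numpy as jnp

def double(x):
    return 2 * x

x = jnp.ones(3)
# No mapped inputs (in_axes=None), so the axis size must be passed explicitly.
ys = jax.vmap(double, in_axes=None, axis_size=4)(x)
print(ys.shape)  # (4, 3)
```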
All the usual compositions work — vmap of jit,
print(jax.typeof(jax.vmap(jax.jit(dequantize), in_axes=QArraySpec(),
                          axis_size=4)(qxs)))
float32[4,2,3]
vmap of grad, and so on:
def norm_quantized(x):
    return jnp.sum(dequantize(quantize(x)) ** 2)
print(jax.vmap(jax.grad(norm_quantized))(xs).shape)
(4, 2, 3)
What we haven’t covered
A few more corners of the interface: types can implement
leading_axis_spec so that hi type values can be carried through
jax.lax.scan, and on the primitive side there are hooks for customizing
rematerialization and dead code elimination.
As ever with hijax, tests/hijax_test.py is a good source of worked
examples, and Custom derivative rules with hijax primitives covers the primitive-side
API — including JVP rules, symbolic zeros, and custom linearization — in
more depth.