Omnistaging#
mattjj@ Sept 25 2020
This is more of an upgrade guide than a design doc.
tl;dr#
What’s going on?#
A change to JAX’s tracing infrastructure called “omnistaging” (jax-ml/jax#3370) was switched on in jax==0.2.0. This change improves memory performance and trace execution time, and simplifies JAX internals, but may cause some existing code to break. Breakage is usually a result of buggy code, so long-term it’s best to fix the bugs, but omnistaging can also be disabled as a temporary workaround. And we’re happy to help you with fixes!
How do I know if omnistaging broke my code?#
The easiest way to tell if omnistaging is responsible is to disable omnistaging and see if the issues go away. See the What issues can arise when omnistaging is switched on? section below.
How can I disable omnistaging for now?#
Note: this applies to JAX versions 0.2.0 through 0.2.11; omnistaging cannot be disabled in JAX versions 0.2.12 and higher.
It is temporarily possible to disable omnistaging by

1. setting the shell environment variable JAX_OMNISTAGING to something falsey;
2. setting the boolean flag jax_omnistaging to something falsey if your code parses flags with absl;
3. using this statement near the top of your main file:

jax.config.disable_omnistaging()
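For example, a minimal main file using the third option might begin like this (a sketch; again, this only works on JAX versions 0.2.0 through 0.2.11):

import jax

# Must run before any jit/pmap tracing happens.
jax.config.disable_omnistaging()

# ... the rest of the program runs with omnistaging off.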
How do I fix bugs exposed by omnistaging?#
By far the most common issue with omnistaging is using jax.numpy to compute shape values or other trace-time constants. See the code block below for a quick example, and for full details along with other issues see the section What issues can arise when omnistaging is switched on?.
Instead of this:
@jit
def f(x):
  input_size = jnp.prod(x.shape)
  if input_size > 100:
    ...
do this:
import numpy as np

@jit
def f(x):
  input_size = np.prod(x.shape)
  if input_size > 100:
    ...
Instead of thinking of jax.numpy as a drop-in replacement for numpy, it’s now better to think of using jax.numpy operations only when you want to perform a computation on an accelerator (like your GPU).
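As an illustration of that division of labor, here’s a hedged sketch (the function and its names are hypothetical, not from the original guide): plain numpy handles trace-time constants on the host, while jax.numpy handles the accelerator computation.

import numpy as np
import jax.numpy as jnp
from jax import jit

@jit
def scaled_tanh(x):
  depth = x.shape[-1]            # a Python int, known at trace time
  scale = 1.0 / np.sqrt(depth)   # host-side numpy: a trace-time constant
  return jnp.tanh(x * scale)     # jax.numpy: staged out to the accelerator

scaled_tanh(jnp.ones((3, 4)))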
What is “omnistaging” and why is it useful?#
Omnistaging is the name for a JAX core upgrade aimed at staging out more
computation from op-by-op Python to XLA, and avoiding any “trace-time constant
folding” in jit
, pmap
, and control flow primitives. As a result, omnistaging
improves JAX’s memory performance (sometimes dramatically) both by reducing
fragmentation during tracing and by producing fewer large compile-time constants
for XLA. It can also improve tracing performance by eliminating op-by-op
execution at tracing time. Further, omnistaging simplifies JAX core internals,
fixing many outstanding bugs and setting the stage for important upcoming
features.
The name “omnistaging” means staging out everything possible.
Toy example#
JAX transformations like jit and pmap stage out computations to XLA. That is, we apply them to functions comprising multiple primitive operations so that, rather than being executed one at a time from Python, the operations are all part of one end-to-end optimized XLA computation.
But exactly which operations get staged out? Until omnistaging, JAX staged out computation based on data dependence only. Here’s an example function, followed by the XLA HLO program it stages out before the omnistaging change:
from jax import jit
import jax.numpy as jnp
@jit
def f(x):
  y = jnp.add(1, 1)
  return x * y

f(3)
ENTRY jit_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(parameter.1, constant.3)
  ROOT tuple.5 = (s32[]) tuple(multiply.4)
}
Notice that the add operation is not staged out. Instead, we only see a multiply.
Here’s the HLO generated from this function after the omnistaging change:
ENTRY jit_f.8 {
  constant.2 = pred[] constant(false)
  parameter.1 = s32[] parameter(0)
  constant.3 = s32[] constant(1)
  constant.4 = s32[] constant(1)
  add.5 = s32[] add(constant.3, constant.4)
  multiply.6 = s32[] multiply(parameter.1, add.5)
  ROOT tuple.7 = (s32[]) tuple(multiply.6)
}
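If you want to reproduce HLO like this yourself, one way in JAX versions of this vintage is jax.xla_computation, which traces a function and returns an object whose as_hlo_text method gives the staged-out program:

from jax import xla_computation
import jax.numpy as jnp

def f(x):
  y = jnp.add(1, 1)
  return x * y

# Trace f and print the XLA HLO program it stages out.
print(xla_computation(f)(3).as_hlo_text())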
Slightly less toy example#
Here’s a less toy example that can arise in practice when we want to create boolean masks:
import numpy as np
import jax.numpy as jnp
from jax import jit, lax

@jit
def select_tril(x):
  mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
  return lax.select(mask, x, jnp.zeros_like(x))  # lax.select is like jnp.where

x = np.arange(12).reshape((3, 4))
select_tril(x)
Before omnistaging:
ENTRY jit_select_tril.8 {
  constant.3 = pred[] constant(false)
  constant.1 = pred[3,4]{1,0} constant({...})
  parameter.2 = s32[3,4]{1,0} parameter(0)
  constant.4 = s32[] constant(0)
  broadcast.5 = s32[3,4]{1,0} broadcast(constant.4), dimensions={}
  select.6 = s32[3,4]{1,0} select(constant.1, parameter.2, broadcast.5)
  ROOT tuple.7 = (s32[3,4]{1,0}) tuple(select.6)
}
The select operation is staged out, but the operations for constructing the constant mask are not. Rather than being staged out, the operations that construct mask are executed op-by-op at Python tracing time, and XLA only sees a compile-time constant constant.1 representing the value of mask. That’s unfortunate, because if we had staged out the operations for constructing mask, XLA could have fused them into the select and avoided materializing the result at all. As a result we end up wasting memory with a potentially large constant, wasting time dispatching multiple un-fused op-by-op XLA computations, and potentially even fragmenting memory.
(The broadcast that corresponds to the construction of the zeros array for jnp.zeros_like(x) is staged out because JAX is lazy about very simple expressions from jax-ml/jax#1668. After omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)
The reason the creation of mask is not staged out is that, before omnistaging, jit operates based on data dependence. That is, jit stages out only those operations in a function that have a data dependence on an argument. Control flow primitives and pmap behave similarly. In the case of select_tril, the operations to construct the constant mask do not have a data dependence on the argument x, so they are not staged out; only the lax.select call has a data dependence.
With omnistaging, all jax.numpy calls in the dynamic context of a jit-transformed function are staged out to XLA. That is, after omnistaging the computation XLA sees for select_tril is
ENTRY jit_select_tril.16 {
  constant.4 = pred[] constant(false)
  iota.1 = s32[3]{0} iota(), iota_dimension=0
  broadcast.5 = s32[3,1]{1,0} broadcast(iota.1), dimensions={0}
  reshape.7 = s32[3]{0} reshape(broadcast.5)
  broadcast.8 = s32[3,4]{1,0} broadcast(reshape.7), dimensions={0}
  iota.2 = s32[4]{0} iota(), iota_dimension=0
  broadcast.6 = s32[1,4]{1,0} broadcast(iota.2), dimensions={1}
  reshape.9 = s32[4]{0} reshape(broadcast.6)
  broadcast.10 = s32[3,4]{1,0} broadcast(reshape.9), dimensions={1}
  compare.11 = pred[3,4]{1,0} compare(broadcast.8, broadcast.10), direction=GT
  parameter.3 = s32[3,4]{1,0} parameter(0)
  constant.12 = s32[] constant(0)
  broadcast.13 = s32[3,4]{1,0} broadcast(constant.12), dimensions={}
  select.14 = s32[3,4]{1,0} select(compare.11, parameter.3, broadcast.13)
  ROOT tuple.15 = (s32[3,4]{1,0}) tuple(select.14)
}
What issues can arise when omnistaging is switched on?#
As a consequence of staging out all jax.numpy operations from Python to XLA when in the dynamic context of a jit or pmap, some code that worked previously can start raising loud errors. As explained below, these behaviors were already buggy before omnistaging, but omnistaging makes them into hard errors.
Using jax.numpy for shape computations#
Example#
from jax import jit
import jax.numpy as jnp
@jit
def ex1(x):
  size = jnp.prod(jnp.array(x.shape))
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))
Error message#
[... full traceback ...]
  File "/home/mattjj/packages/jax/jax/core.py", line 862, in raise_concretization_error
    raise ConcretizationTypeError(msg)
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function ex1 at ex1.py:4, this value became a tracer due to JAX operations on these lines:

  operation c:int32[] = reduce_prod[ axes=(0,) ] b:int32[2]
    from line ex1.py:6 (ex1)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
Explanation#
With omnistaging, we can’t use jax.numpy for shape computations as in the use of jnp.prod above, because in the dynamic context of a jit function those operations will be staged out of Python as values to be computed at execution time, yet we need them to be compile-time (and hence trace-time) constants. Before omnistaging, this code wouldn’t have raised an error, but it was a common performance bug: the jnp.prod computation would have been executed on the device at tracing time, meaning extra compilation, transfers, synchronization, allocations, and potentially memory fragmentation.
Solution#
The solution is simply to use the original numpy for shape calculations like these. Not only do we avoid the error, but we also keep the computations on the host (and with lower overheads).
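Concretely, the ex1 example above can be fixed like this:

import numpy as np
from jax import jit
import jax.numpy as jnp

@jit
def ex1(x):
  size = np.prod(x.shape)  # plain numpy, computed on the host at trace time
  return x.reshape((size,))

ex1(jnp.ones((3, 4)))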
This issue was common enough in code that we tried to make the error message especially good. In addition to the stack trace showing where an abstract tracer value caused a problem (the jnp.reshape line in the full stack trace), we also explain why this value became a tracer in the first place by pointing to the upstream primitive operation that caused it to become an abstract tracer (the reduce_prod from jnp.prod on ex1.py:6) and to which jit-decorated function the tracer belongs (ex1 at ex1.py:4).
Side-effects#
Example#
from jax import jit
from jax import random

key = random.PRNGKey(0)

def init():
  global key
  key, subkey = random.split(key)
  return random.normal(subkey, ())

print(init())  # -1.2515389
print(init())  # -0.58665067

init = jit(init)
print(init())  # 0.48648298
print(init())  # 0.48648298  !!
That last call has repeated randomness but no hard error, because we aren’t re-executing the Python. But if we look at key, we see an escaped tracer when omnistaging is on:
print(key) # Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>
Before omnistaging, the random.split call would not be staged out and so we wouldn’t get an escaped tracer. The code would still be buggy in that the jitted function wouldn’t be reproducing the semantics of the original function (because of the repeated use of the same PRNG key), ultimately due to the side effect. With omnistaging on, if we touch key again, we’ll get an escaped tracer error:
random.normal(key, ())
Error message#
[... full stack trace ...]
  File "/home/mattjj/packages/jax/jax/interpreters/partial_eval.py", line 836, in _assert_live
    raise core.escaped_tracer_error(msg)
jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state. Detail: tracer created on line example.py:8 (init).
Explanation#
The second largest category of omnistaging issues we found had to do with side-effecting code. This code already voided the JAX warranty by transforming effectful functions, but due to pre-omnistaging “trace-time constant folding” behavior, some side effecting functions could nevertheless behave correctly. Omnistaging catches more of these errors.
Solution#
The solution is to identify JAX-transformed functions that rely on side effects, and to rewrite them not to be effectful.
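For the init example above, one such rewrite (a sketch) is to thread the PRNG key explicitly instead of mutating a global:

from jax import jit, random

@jit
def init(key):
  # The key is an explicit input and output: no global state is mutated.
  key, subkey = random.split(key)
  return random.normal(subkey, ()), key

key = random.PRNGKey(0)
val1, key = init(key)
val2, key = init(key)  # fresh randomness on each call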
Small numerical differences based on XLA optimizations#
Because omnistaging stages out more computations to XLA, rather than executing some of them at trace time, XLA’s optimizations can reorder floating point operations. As a result, we’ve seen numerical behaviors change in a way that causes tests with overly tight tolerances to fail when omnistaging is switched on.
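If a test fails this way, the usual fix is to compare against a reference with a tolerance rather than demanding bitwise equality. A hypothetical example (the computation and tolerance here are illustrative, not from the original guide):

import numpy as np
import jax.numpy as jnp
from jax import jit

x = np.linspace(0.0, 1.0, 101)
result = jit(lambda v: jnp.sum(v * v))(x)
reference = np.sum(x * x)
# Allow for XLA reassociating float ops (and for float32 accumulation).
np.testing.assert_allclose(result, reference, rtol=1e-5)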
Dependence on JAX internal APIs that changed#
Omnistaging involved some big revisions to JAX’s core code, including removing or changing internal functions. Any code that relies on such internal JAX APIs can break when omnistaging is switched on, either with build errors (from pytype) or runtime errors.
Triggering XLA compile time bugs#
Because omnistaging involves staging out more code to XLA, we’ve seen it trigger pre-existing XLA compile-time bugs on some backends. The best thing to do with these is to report them so we can work with the XLA teams on fixes.