Introduction to debugging

This section introduces you to a set of built-in JAX debugging methods — jax.debug.print(), jax.debug.breakpoint(), and jax.debug.callback() — that you can use with various JAX transformations.

Let’s begin with jax.debug.print().

jax.debug.print for simple inspection

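Here is a rule of thumb: use jax.debug.print() for traced (dynamic) array values with jax.jit(), jax.vmap() and others, and use Python's print() for static values, such as dtypes and array shapes.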

Recall from Just-in-time compilation that when you transform a function with jax.jit(), the Python code runs with abstract tracers in place of your arrays. As a result, Python's print() function only prints the tracer, not the array's value:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  print("print(x) ->", x)
  y = jnp.sin(x)
  print("print(y) ->", y)
  return y

result = f(2.)
print(x) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>
print(y) -> Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>

Python’s print executes at trace-time, before the runtime values exist. If you want to print the actual runtime values, you can use jax.debug.print():

@jax.jit
def f(x):
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {y}", y=y)
  return y

result = f(2.)
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.9092974066734314
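Static metadata such as shape and dtype is known while tracing, so Python's print() can show it even under jax.jit(); only the array contents need jax.debug.print(). A minimal sketch combining the two (the function name scale is illustrative only):

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale(x):
  # Shape and dtype are static: they are known while tracing, so Python's
  # print() shows them directly (it runs once per trace, not on every call).
  print("print(x.shape, x.dtype) ->", x.shape, x.dtype)
  # The array contents only exist at runtime, so use jax.debug.print().
  jax.debug.print("jax.debug.print(x) -> {x}", x=x)
  return 2 * x

out = scale(jnp.arange(3.0))
```

Because Python's print() runs only while tracing, calling scale again with an input of the same shape and dtype typically prints just the jax.debug.print() line.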

Similarly, within jax.vmap(), using Python’s print will only print the tracer; to print the values being mapped over, use jax.debug.print():

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  y = jnp.sin(x)
  jax.debug.print("jax.debug.print(y) -> {}", y)
  return y

xs = jnp.arange(3.)

result = jax.vmap(f)(xs)
jax.debug.print(x) -> 0.0
jax.debug.print(x) -> 1.0
jax.debug.print(x) -> 2.0
jax.debug.print(y) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(y) -> 0.9092974066734314

Here’s the result with jax.lax.map(), which is a sequential map rather than a vectorization:

result = jax.lax.map(f, xs)
jax.debug.print(y) -> 0.0
jax.debug.print(x) -> 0.0
jax.debug.print(y) -> 0.8414709568023682
jax.debug.print(x) -> 1.0
jax.debug.print(y) -> 0.9092974066734314
jax.debug.print(x) -> 2.0

Notice the order is different, as jax.vmap() and jax.lax.map() compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.
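Although the print order differs, the two maps should agree on the values themselves. A minimal sketch checking that, with the prints removed (g is an illustrative stand-in for f):

```python
import jax
import jax.numpy as jnp

def g(x):
  return jnp.sin(x)

xs = jnp.arange(3.0)
vectorized = jax.vmap(g)(xs)     # one batched computation over all elements
sequential = jax.lax.map(g, xs)  # a loop applying g to one element at a time
```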

Below is an example with jax.grad(), where jax.debug.print() only prints the forward pass. In this case its behavior is similar to Python's print(), but jax.debug.print() stays consistent if you also apply jax.jit() to the call.

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

result = jax.grad(f)(1.)
jax.debug.print(x) -> 1.0
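Because jax.debug.print() inside f only reports the forward pass, one way to see the gradient itself is to print it after jax.grad() has computed it. A minimal sketch under jax.jit(); the wrapper name grad_and_print is hypothetical:

```python
import jax

def f(x):
  jax.debug.print("jax.debug.print(x) -> {}", x)
  return x ** 2

@jax.jit
def grad_and_print(x):
  dfdx = jax.grad(f)(x)
  # The print inside f fires on the forward pass only; to see the gradient,
  # print it explicitly once jax.grad() has produced it.
  jax.debug.print("jax.debug.print(dfdx) -> {}", dfdx)
  return dfdx

dfdx = grad_and_print(1.0)
```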

When the printed arguments don't depend on one another, calls to jax.debug.print() may be emitted in a different order once the function is staged out by a JAX transformation. If you need the original order (for example, x: ... first and then y: ... second), pass the ordered=True parameter.

For example:

@jax.jit
def f(x, y):
  jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
  jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
  return x + y

f(1, 2)
jax.debug.print(x) -> 1
jax.debug.print(y) -> 2
Array(3, dtype=int32, weak_type=True)

To learn more about jax.debug.print() and its Sharp Bits, refer to Advanced debugging.

jax.debug.breakpoint for pdb-like debugging

Summary: Use jax.debug.breakpoint() to pause the execution of your JAX program to inspect values.

To pause your compiled JAX program at certain points during debugging, you can use jax.debug.breakpoint(). The prompt is similar to Python pdb, and it allows you to inspect the values in the call stack. In fact, jax.debug.breakpoint() is an application of jax.debug.callback() that captures information about the call stack.

To print all available commands during a breakpoint debugging session, use the help command. (The full set of debugger commands, the Sharp Bits, and the debugger's strengths and limitations are covered in Advanced debugging.)

Here is an example of what a debugger session might look like:

@jax.jit
def f(x):
  y, z = jnp.sin(x), jnp.cos(x)
  jax.debug.breakpoint()
  return y * z
f(2.) # ==> Pauses during execution

(Figure: JAX debugger)

For value-dependent breakpointing, you can use runtime conditionals like jax.lax.cond():

def breakpoint_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    jax.debug.breakpoint()
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def f(x, y):
  z = x / y
  breakpoint_if_nonfinite(z)
  return z

f(2., 1.) # ==> No breakpoint
Array(2., dtype=float32, weak_type=True)
f(2., 0.) # ==> Pauses during execution
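If you want the same value-dependent check without an interactive prompt, for example in a non-interactive job, one option is to replace jax.debug.breakpoint() with jax.debug.callback() so that the offending value is recorded on the host. A minimal sketch; report_nonfinite, callback_if_nonfinite, safe_divide and the nonfinite_events list are hypothetical names, and jax.effects_barrier() is used to wait for pending callbacks before the list is inspected:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical host-side record of non-finite values (illustration only).
nonfinite_events = []

def report_nonfinite(x):
  # Runs on the host with a concrete array value.
  nonfinite_events.append(np.asarray(x))

def callback_if_nonfinite(x):
  is_finite = jnp.isfinite(x).all()
  def true_fn(x):
    pass
  def false_fn(x):
    # Only the branch actually taken at runtime executes its callback.
    jax.debug.callback(report_nonfinite, x)
  jax.lax.cond(is_finite, true_fn, false_fn, x)

@jax.jit
def safe_divide(x, y):
  z = x / y
  callback_if_nonfinite(z)
  return z

safe_divide(2., 1.)    # finite result: nothing is recorded
safe_divide(2., 0.)    # inf: the callback records the value
jax.effects_barrier()  # wait for any pending callbacks to finish
```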

jax.debug.callback for more control during debugging

Both jax.debug.print() and jax.debug.breakpoint() are implemented using the more flexible jax.debug.callback(), which gives greater control over the host-side logic executed via a Python callback. It is compatible with jax.jit(), jax.vmap(), jax.grad() and other transformations (refer to the Flavors of callback table in External callbacks for more information).

For example:

import logging

def log_value(x):
  logging.warning(f'Logged value: {x}')

@jax.jit
def f(x):
  jax.debug.callback(log_value, x)
  return x

f(1.0);
WARNING:root:Logged value: 1.0

This callback is compatible with other transformations, including jax.vmap() and jax.grad():

x = jnp.arange(5.0)
jax.vmap(f)(x);
WARNING:root:Logged value: 0.0
WARNING:root:Logged value: 1.0
WARNING:root:Logged value: 2.0
WARNING:root:Logged value: 3.0
WARNING:root:Logged value: 4.0
jax.grad(f)(1.0);
WARNING:root:Logged value: 1.0

This can make jax.debug.callback() useful for general-purpose debugging.
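Since the callback runs ordinary Python on the host, it can also collect values for later inspection instead of logging them. A minimal sketch that records summary statistics of an intermediate array; record_stats, layer and the recorded dict are hypothetical names, and functools.partial binds the text label so that only the array is passed through jax.debug.callback():

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical host-side store for debugging statistics.
recorded = {}

def record_stats(name, x):
  # x arrives on the host as a concrete array; compute stats with NumPy.
  x = np.asarray(x)
  recorded[name] = (float(x.mean()), float(x.max()))

@jax.jit
def layer(x):
  h = jnp.tanh(x)
  # Bind the label up front so only the array h goes through the callback.
  jax.debug.callback(functools.partial(record_stats, "h"), h)
  return h.sum()

out = layer(jnp.zeros(4))
jax.effects_barrier()  # make sure the callback has run before reading `recorded`
```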

You can learn more about jax.debug.callback() and other kinds of JAX callbacks in External callbacks.

Next steps

Check out Advanced debugging to learn more about debugging in JAX.