jax.debug.print
- jax.debug.print(fmt, *args, ordered=False, **kwargs)
Prints values and works in staged out JAX functions.
This function does not work with f-strings because formatting is delayed. So instead of jax.debug.print(f"hello {bar}"), write jax.debug.print("hello {bar}", bar=bar).

This function is a thin convenience wrapper around jax.debug.callback(). The implementation is essentially:

    def debug_print(fmt: str, *args, **kwargs):
      jax.debug.callback(
          lambda *args, **kwargs: print(fmt.format(*args, **kwargs)),
          *args, **kwargs)
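As a minimal sketch of the keyword-argument form described above, the following prints a value from inside a jitted function. The function name scaled_sum and the input array are hypothetical, chosen only for illustration. The trailing jax.effects_barrier() call is there so that any pending debug output is flushed before the script continues.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x):
    total = jnp.sum(x) * 2.0
    # Pass the value as a keyword argument; the format string is filled in
    # later, at run time, once `total` has a concrete value.
    # An f-string would instead format the tracer at trace time.
    jax.debug.print("scaled sum = {total}", total=total)
    return total


result = scaled_sum(jnp.arange(4.0))
jax.effects_barrier()  # wait for pending debug callbacks to finish
```

The same call also works outside jax.jit, where the value is already concrete.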
It may be useful to call jax.debug.callback() directly instead of this convenience wrapper. For example, to get debug printing in logs, you might use jax.debug.callback() together with logging.log.
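One way the logging suggestion might look is sketched below. It assumes a basic root-logger configuration, and the names log_value and double are hypothetical helpers rather than part of JAX. The callback receives concrete values on the host, so it can hand them to logging.log like any other Python object.

```python
import logging

import jax
import jax.numpy as jnp

# Assumed logging setup; a real application would configure this elsewhere.
logging.basicConfig(level=logging.INFO)


def log_value(x):
    # Runs on the host with a concrete array value, not a tracer.
    logging.log(logging.INFO, "x = %s", x)


@jax.jit
def double(x):
    # Route the debug output through the logging module instead of print.
    jax.debug.callback(log_value, x)
    return x * 2


out = double(jnp.float32(3.0))
jax.effects_barrier()  # wait for pending debug callbacks to finish
```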
- Parameters:
fmt (str) – A format string, e.g. "hello {x}", that will be used to format input arguments, like str.format. See the Python docs on string formatting and format string syntax.
*args – A list of positional arguments to be formatted, as if passed to fmt.format.
ordered (bool) – A keyword-only argument used to indicate whether or not the staged-out computation will enforce ordering of this jax.debug.print with respect to other ordered jax.debug.print calls.
**kwargs – Additional keyword arguments to be formatted, as if passed to fmt.format.
- Return type:
None
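The positional-argument form and the ordered parameter described above can be sketched together as follows. The function name two_steps is hypothetical. Both calls pass ordered=True, which asks the staged-out computation to keep these two prints in program order relative to each other.

```python
import jax
import jax.numpy as jnp


@jax.jit
def two_steps(x):
    # Positional arguments fill "{}" placeholders, as with str.format.
    jax.debug.print("before: {}", x, ordered=True)
    y = x + 1
    jax.debug.print("after: {}", y, ordered=True)
    return y


out = two_steps(jnp.int32(1))
jax.effects_barrier()  # wait for pending debug callbacks to finish
```

With the default ordered=False, no ordering between separate jax.debug.print calls is requested.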