Key concepts#

This section briefly introduces some key concepts of the JAX package.

JAX arrays (jax.Array)#

The default array implementation in JAX is jax.Array. In many ways it is similar to the numpy.ndarray type that you may be familiar with from the NumPy package, but it has some important differences.

Array creation#

We typically don’t call the jax.Array constructor directly, but rather create arrays via JAX API functions. For example, jax.numpy provides familiar NumPy-style array construction functionality such as jax.numpy.zeros(), jax.numpy.linspace(), jax.numpy.arange(), etc.

import jax
import jax.numpy as jnp

x = jnp.arange(5)
isinstance(x, jax.Array)
True

If you use Python type annotations in your code, jax.Array is the appropriate annotation for JAX array objects (see jax.typing for more discussion).
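As a minimal sketch of the points above, the following creates arrays through jax.numpy and annotates a function with jax.Array. The helper name normalize is hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def normalize(x: jax.Array) -> jax.Array:
  """Scale x to unit Euclidean norm (hypothetical helper for illustration)."""
  return x / jnp.linalg.norm(x)

zeros = jnp.zeros((2, 3))          # 2x3 array of zeros
grid = jnp.linspace(0.0, 1.0, 5)   # 5 evenly spaced points in [0, 1]
unit = normalize(jnp.array([3.0, 4.0]))

print(isinstance(zeros, jax.Array), grid.shape, unit)
```

All three values are jax.Array instances, even though none was built by calling the jax.Array constructor.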

Array devices and sharding#

JAX Array objects have a devices method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:

x.devices()
{CpuDevice(id=0)}

In general, an array may be sharded across multiple devices, in a manner that can be inspected via the sharding attribute:

x.sharding
SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

Here the array is on a single device, but in general a JAX array can be sharded across multiple devices, or even multiple hosts. To read more about sharded arrays and parallel computation, refer to Introduction to parallel programming.
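The sketch below shows one way to place an array on a specific device and then inspect it. It assumes only that at least one device is visible to JAX; the device it reports depends on your hardware.

```python
import jax
import jax.numpy as jnp

devices = jax.devices()   # all devices visible to the default backend
x = jnp.arange(5)

# Explicitly place a copy of x on the first available device.
y = jax.device_put(x, devices[0])

print(y.devices())
print(y.sharding)
```

On a CPU-only machine this reports the same single CPU device shown above. On a machine with accelerators, the device and memory kind will differ.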

Transformations#

Along with functions to operate on arrays, JAX includes a number of transformations which operate on JAX functions. These include jax.jit() (just-in-time compilation), jax.vmap() (vectorization), and jax.grad() (gradients), as well as several others. Transformations accept a function as an argument, and return a new transformed function. For example, here’s how you might JIT-compile a simple SELU function:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jax.jit(selu)
print(selu_jit(1.0))
1.05

Often you’ll see transformations applied using Python’s decorator syntax for convenience:

@jax.jit
def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

Transformations like jit(), vmap(), grad(), and others are key to using JAX effectively, and we’ll cover them in detail in later sections.
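Transformations compose, because each one returns an ordinary function. As a rough illustration, the sketch below reuses the selu function from above, differentiates it with grad(), and then batches the derivative with vmap(). The names dselu and batched_dselu are hypothetical.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# grad() differentiates a scalar-valued function with respect to its first argument.
dselu = jax.grad(selu)
print(dselu(1.0))   # slope on the positive branch equals lambda_

# vmap() maps the scalar derivative over a batch of inputs.
batched_dselu = jax.vmap(dselu)
batch_slopes = batched_dselu(jnp.array([-1.0, 0.5, 2.0]))
print(batch_slopes)
```

Because grad() expects a scalar output, vmap() is what lets the derivative apply elementwise to a whole array here.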

Tracing#

The magic behind transformations is the notion of a Tracer. Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order to extract the sequence of operations that the function encodes.

You can see this by printing any array value within transformed JAX code; for example:

@jax.jit
def f(x):
  print(x)
  return x + 1

x = jnp.arange(5)
result = f(x)
Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace>

The value printed is not the array x, but a Tracer instance that represents essential attributes of x, such as its shape and dtype. By executing the function with traced values, JAX can determine the sequence of operations encoded by the function before those operations are actually executed: transformations like jit(), vmap(), and grad() can then map this sequence of input operations to a transformed sequence of operations.
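One practical consequence is that Python side effects inside a jit-compiled function run only while the function is being traced, not on every call. The sketch below makes this visible with a list, trace_log, which is a hypothetical name used only to record when tracing happens.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical list used only to record when tracing happens

@jax.jit
def f(x):
  trace_log.append(x.shape)  # Python side effect: runs only while tracing
  return x + 1

f(jnp.arange(5))    # first call with shape (5,): traces and compiles
f(jnp.arange(5))    # same shape and dtype: reuses the cached computation
f(jnp.arange(3))    # new shape: triggers a fresh trace

print(trace_log)
```

Three calls produce only two log entries, one per distinct input shape, because the second call never re-executes the Python body.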

Jaxprs#

JAX has its own intermediate representation for sequences of operations, known as a jaxpr. A jaxpr (short for JAX exPRession) is a simple representation of a functional program, comprising a sequence of primitive operations.

For example, consider the selu function we defined above:

def selu(x, alpha=1.67, lambda_=1.05):
  return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

We can use the jax.make_jaxpr() utility to convert this function into a jaxpr given a particular input:

x = jnp.arange(5.0)
jax.make_jaxpr(selu)(x)
{ lambda ; a:f32[5]. let
    b:bool[5] = gt a 0.0
    c:f32[5] = exp a
    d:f32[5] = mul 1.6699999570846558 c
    e:f32[5] = sub d 1.6699999570846558
    f:f32[5] = pjit[
      name=_where
      jaxpr={ lambda ; b:bool[5] a:f32[5] e:f32[5]. let
          f:f32[5] = select_n b e a
        in (f,) }
    ] b a e
    g:f32[5] = mul 1.0499999523162842 f
  in (g,) }

Comparing this to the Python function definition, we see that it encodes the precise sequence of operations that the function represents. We’ll go into more depth about jaxprs later in JAX internals: The jaxpr language.
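A jaxpr can also be inspected programmatically. The following sketch assumes the object returned by jax.make_jaxpr() exposes its equations through the jaxpr.eqns attribute, each with a primitive.name. The function scaled_exp is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def scaled_exp(x):  # hypothetical example function
  return 2.0 * jnp.exp(x)

closed_jaxpr = jax.make_jaxpr(scaled_exp)(jnp.arange(3.0))

# Each equation records one primitive applied to its input variables.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(closed_jaxpr)
print(primitive_names)
```

The exact printed form of a jaxpr can vary between JAX versions, so treat the output as illustrative rather than stable.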

Pytrees#

JAX functions and transformations fundamentally operate on arrays, but in practice it is convenient to write code that works with collections of arrays: for example, a neural network might organize its parameters in a dictionary of arrays with meaningful keys. Rather than handle such structures on a case-by-case basis, JAX relies on the pytree abstraction to treat such collections in a uniform manner.

Here are some examples of objects that can be treated as pytrees:

# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef([*, *, (*, *)])
[1, 2, Array([0, 1, 2], dtype=int32), Array([1., 1.], dtype=float32)]

# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef({'W': *, 'b': *, 'n': *})
[Array([[1., 1.],
       [1., 1.]], dtype=float32), Array([0., 0.], dtype=float32), 5]

# Named tuple of parameters
from typing import NamedTuple

class Params(NamedTuple):
  a: int
  b: float

params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
PyTreeDef(CustomNode(namedtuple[Params], [*, *]))
[1, 5.0]

JAX has a number of general-purpose utilities for working with pytrees; for example, the function jax.tree.map() can be used to map a function to every leaf in a tree, and jax.tree.reduce() can be used to apply a reduction across the leaves in a tree.
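As a small sketch of these two utilities, the example below doubles every leaf of a parameter dictionary and then counts the total number of elements across its leaves. The variable names are illustrative only.

```python
import operator
import jax
import jax.numpy as jnp

params = {'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# Apply a function to every leaf; the dictionary structure is preserved.
doubled = jax.tree.map(lambda leaf: 2 * leaf, params)

# Reduce across leaves: here, the total number of elements in the tree.
sizes = jax.tree.map(lambda leaf: leaf.size, params)
total_size = jax.tree.reduce(operator.add, sizes)

print(doubled['W'])
print(total_size)
```

Note that jax.tree.map() returns a new pytree with the same structure as its input and does not modify the original.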

You can learn more in the Working with pytrees tutorial.

Pseudorandom numbers#

Generally, JAX strives to be compatible with NumPy, but pseudorandom number generation is a notable exception. NumPy supports a method of pseudorandom number generation that is based on a global state, which can be set using numpy.random.seed(). Global random state interacts poorly with JAX’s compute model and makes it difficult to enforce reproducibility across different threads, processes, and devices. JAX instead tracks state explicitly via a random key:

from jax import random

key = random.key(43)
print(key)
Array((), dtype=key<fry>) overlaying:
[ 0 43]

The key is effectively a stand-in for NumPy’s hidden state object, but we pass it explicitly to jax.random functions. Importantly, random functions consume the key, but do not modify it: feeding the same key object to a random function will always result in the same sample being generated.

print(random.normal(key))
print(random.normal(key))
0.07520543
0.07520543

The rule of thumb is: never reuse keys (unless you want identical outputs).

In order to generate different and independent samples, you must split() the key explicitly before passing it to a random function:

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration.
draw 0: -1.9133632183074951
draw 1: -1.4749839305877686
draw 2: -0.36703771352767944

Note that this code is thread safe, since the local random state eliminates possible race conditions involving global state. jax.random.split() is a deterministic function that converts one key into several independent (in the pseudorandomness sense) keys.
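split() can also produce several keys in a single call, which is convenient when a batch of independent draws is needed. The sketch below wraps this in a hypothetical helper, draw, to show that the whole procedure is reproducible from the seed alone.

```python
import jax
import jax.numpy as jnp
from jax import random

def draw(seed):
  """Return three independent normal samples (hypothetical helper)."""
  key = random.key(seed)
  subkeys = random.split(key, 3)   # array of three keys; key is now consumed
  # Use each subkey exactly once, mapping the sampler over the key array.
  return jax.vmap(random.normal)(subkeys)

samples = draw(0)
print(samples)
print(jnp.array_equal(samples, draw(0)))   # same seed, same samples
```

Each subkey is consumed exactly once, so the three samples differ from one another while the overall result stays deterministic.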

For more on pseudorandom numbers in JAX, see the Pseudorandom numbers tutorial.