Distributed arrays and automatic parallelization#
This tutorial discusses parallelism via jax.Array, the unified array object model available in JAX v0.4.1 and newer.
from typing import Optional
import numpy as np
import jax
import jax.numpy as jnp
⚠️ WARNING: The notebook requires 8 devices to run.
if len(jax.local_devices()) < 8:
  raise Exception("Notebook requires 8 devices to run")
Intro and a quick example#
By reading this tutorial notebook, you'll learn about jax.Array, a unified datatype for representing arrays, even with physical storage spanning multiple devices. You'll also learn about how using jax.Arrays together with jax.jit can provide automatic compiler-based parallelization.
Before we think step by step, here's a quick example.
First, we'll create a jax.Array sharded across multiple devices:
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
# Create a Sharding object to distribute a value across devices:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
Next, we'll apply a computation to it and visualize how the result values are stored across multiple devices too:
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
The evaluation of the jnp.sin application was automatically parallelized across the devices on which the input values (and output values) are stored:
# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.
25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
Now let's look at each of these pieces in more detail!
Computation follows data sharding and is automatically parallelized#
With sharded input data, the compiler can give us parallel computation. In particular, functions decorated with jax.jit can operate over sharded arrays without copying data onto a single device. Instead, computation follows sharding: based on the sharding of the input data, the compiler decides shardings for intermediates and output values, and parallelizes their evaluation, even inserting communication operations as necessary.
For example, the simplest computation is an elementwise one:
mesh = jax.make_mesh((4, 2), ('a', 'b'))
x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
print('input sharding:')
jax.debug.visualize_array_sharding(x)
y = jnp.sin(x)
print('output sharding:')
jax.debug.visualize_array_sharding(y)
input sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
output sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
Here, for the elementwise operation jnp.sin, the compiler chose the output sharding to be the same as the input. Moreover, the compiler automatically parallelized the computation, so that each device computed its output shard from its input shard in parallel.
In other words, even though we wrote the jnp.sin computation as if a single machine were to execute it, the compiler splits up the computation for us and executes it on multiple devices.
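If you prefer a programmatic check over the visualization, note that every jax.Array exposes its sharding via the .sharding attribute. Here is a minimal sketch (not part of the original notebook), using the x and y defined just above:
# The elementwise result reports a sharding laid out over the same ('a', 'b')
# mesh as the input; printing both makes the "output follows input" choice visible.
print(x.sharding)
print(y.sharding)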
We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs:
y = jax.device_put(x, NamedSharding(mesh, P('a', None)))
z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
jax.debug.visualize_array_sharding(z)
w = jnp.dot(y, z)
print('out sharding:')
jax.debug.visualize_array_sharding(w)
lhs sharding:
┌───────────────────────┐
│        TPU 0,1        │
├───────────────────────┤
│        TPU 2,3        │
├───────────────────────┤
│        TPU 6,7        │
├───────────────────────┤
│        TPU 4,5        │
└───────────────────────┘
rhs sharding:
┌───────────┬───────────┐
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
└───────────┴───────────┘
out sharding:
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
Here the compiler chose the output sharding so that it could maximally parallelize the computation: without needing communication, each device already has the input shards it needs to compute its output shard.
How can we be sure it's actually running in parallel? We can do a simple timing experiment:
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
┌───────────────────────┐
│                       │
│                       │
│         TPU 0         │
│                       │
│                       │
└───────────────────────┘
np.allclose(jnp.dot(x_single, x_single),
            jnp.dot(y, z))
True
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)
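If you're curious what the compiler actually decided, one way to peek (a sketch, assuming a recent JAX with the jax.stages lowering/compilation API) is to compile the dot for these sharded operands yourself and inspect the compiled module text; collectives such as all-gather or all-reduce appear there whenever communication was inserted:
# Sketch: explicitly lower and compile jnp.dot for the sharded y and z,
# then look at a prefix of the (long) compiled module.
compiled = jax.jit(jnp.dot).lower(y, z).compile()
print(compiled.as_text()[:1000])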
Even copying a sharded Array produces a result with the sharding of the input:
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
So computation follows data placement: when we explicitly shard data with jax.device_put and apply functions to that data, the compiler attempts to parallelize the computation and decide the output sharding. This policy for sharded data is a generalization of JAX's policy of following explicit device placement.
When explicit shardings disagree, JAX errors#
But what if two arguments to a computation are explicitly placed on different sets of devices, or with incompatible device orders? In these ambiguous cases, an error is raised:
import textwrap
from termcolor import colored
def print_exception(e):
  name = colored(f'{type(e).__name__}', 'red', force_color=True)
  print(textwrap.fill(f'{name}: {str(e)}'))
sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))
sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))
y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on
platform TPU
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]
sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))
sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))
y = jax.device_put(x, sharding1)
z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
ValueError: Received incompatible devices for jitted
computation. Got argument x1 of jax.numpy.add with shape int32[24] and
device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of
jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,
4, 5] on platform TPU
We say arrays that have been explicitly placed or sharded with jax.device_put are committed to their device(s), and so won't be automatically moved. See the device placement FAQ for more information.
When arrays are not explicitly placed or sharded with jax.device_put, they are placed uncommitted on the default device. Unlike committed arrays, uncommitted arrays can be moved and resharded automatically: that is, uncommitted arrays can be arguments to a computation even if other arguments are explicitly placed on different devices.
For example, the outputs of jnp.zeros, jnp.arange, and jnp.array are uncommitted:
y = jax.device_put(x, sharding1)
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
no error!
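Another way to see the distinction, as a quick sketch reusing y and sharding1 from above: when a committed array meets an uncommitted one, the result follows the committed array's placement.
# y is committed to sharding1, while jnp.ones_like(y) is uncommitted,
# so the sum is computed with (and keeps) y's sharding.
out = y + jnp.ones_like(y)
jax.debug.visualize_array_sharding(out)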
Constraining shardings of intermediates in jitted code#
While the compiler will attempt to decide how a function's intermediate values and outputs should be sharded, we can also give it hints using jax.lax.with_sharding_constraint. Using jax.lax.with_sharding_constraint is much like jax.device_put, except we use it inside staged-out (i.e. jit-decorated) functions:
mesh = jax.make_mesh((4, 2), ('x', 'y'))
x = jax.random.normal(jax.random.key(0), (8192, 8192))
x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────┬───────┬───────┬───────┐
│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │
├───────┼───────┼───────┼───────┤
│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │
└───────┴───────┴───────┴───────┘
@jax.jit
def f(x):
  x = x + 1
  y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
  return y
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
┌──────────┬──────────┐
│  TPU 0   │  TPU 1   │
├──────────┼──────────┤
│  TPU 2   │  TPU 3   │
├──────────┼──────────┤
│  TPU 6   │  TPU 7   │
├──────────┼──────────┤
│  TPU 4   │  TPU 5   │
└──────────┴──────────┘
┌───────────────────────┐
│                       │
│  TPU 0,1,2,3,4,5,6,7  │
│                       │
└───────────────────────┘
By adding with_sharding_constraint, we've constrained the sharding of the output. In addition to respecting the annotation on a particular intermediate, the compiler will use annotations to decide shardings for other values.
It's often a good practice to annotate the outputs of computations, for example based on how the values are ultimately consumed.
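One way to do that, sketched below (the names g and out_sharding are just for illustration, and this assumes a JAX version whose jax.jit accepts out_shardings), is to pass an explicit output sharding to jax.jit rather than annotating inside the function body:
# Request a replicated output, e.g. because every device will read the result next.
out_sharding = NamedSharding(mesh, P())

def g(x):
  return jnp.sum(x ** 2, axis=0)

g_jit = jax.jit(g, out_shardings=out_sharding)
jax.debug.visualize_array_sharding(g_jit(x))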
Examples: neural networks#
⚠️ WARNING: The following is meant to be a simple demonstration of automatic sharding propagation with jax.Array, but it may not reflect best practices for real examples. For instance, real examples may require more use of with_sharding_constraint.
We can use jax.device_put and jax.jit's computation-follows-sharding features to parallelize computation in neural networks. Here are some simple examples, based on this basic neural network:
import jax
import jax.numpy as jnp
def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.maximum(outputs, 0)
  return outputs

def loss(params, batch):
  inputs, targets = batch
  predictions = predict(params, inputs)
  return jnp.mean(jnp.sum((predictions - targets)**2, axis=-1))
loss_jit = jax.jit(loss)
gradfun = jax.jit(jax.grad(loss))
def init_layer(key, n_in, n_out):
  k1, k2 = jax.random.split(key)
  W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
  b = jax.random.normal(k2, (n_out,))
  return W, b

def init_model(key, layer_sizes, batch_size):
  key, *keys = jax.random.split(key, len(layer_sizes))
  params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
  key, *keys = jax.random.split(key, 3)
  inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
  targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
  return params, (inputs, targets)
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192
params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
8-way batch data parallelism#
mesh = jax.make_mesh((8,), ('batch',))
sharding = NamedSharding(mesh, P('batch'))
replicated_sharding = NamedSharding(mesh, P())
batch = jax.device_put(batch, sharding)
params = jax.device_put(params, replicated_sharding)
loss_jit(params, batch)
Array(23.469475, dtype=float32)
step_size = 1e-5
for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch))
10.760109
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
batch_single = jax.device_put(batch, jax.devices()[0])
params_single = jax.device_put(params, jax.devices()[0])
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
4-way batch data parallelism and 2-way model tensor parallelism#
mesh = jax.make_mesh((4, 2), ('batch', 'model'))
batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 6,7│
├───────┤
│TPU 4,5│
└───────┘
┌───────┐
│TPU 0,1│
├───────┤
│TPU 2,3│
├───────┤
│TPU 6,7│
├───────┤
│TPU 4,5│
└───────┘
replicated_sharding = NamedSharding(mesh, P())
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
W1 = jax.device_put(W1, replicated_sharding)
b1 = jax.device_put(b1, replicated_sharding)
W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))
W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
b3 = jax.device_put(b3, replicated_sharding)
W4 = jax.device_put(W4, replicated_sharding)
b4 = jax.device_put(b4, replicated_sharding)
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
jax.debug.visualize_array_sharding(W2)
┌───────────┬───────────┐
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
└───────────┴───────────┘
jax.debug.visualize_array_sharding(W3)
┌───────────────────────┐
│      TPU 0,2,4,6      │
├───────────────────────┤
│      TPU 1,3,5,7      │
└───────────────────────┘
print(loss_jit(params, batch))
10.760109
step_size = 1e-5
for _ in range(30):
  grads = gradfun(params, batch)
  params = [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

print(loss_jit(params, batch))
10.752513
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
┌───────────┬───────────┐
│           │           │
│TPU 0,2,4,6│TPU 1,3,5,7│
│           │           │
└───────────┴───────────┘
┌───────────────────────┐
│      TPU 0,2,4,6      │
├───────────────────────┤
│      TPU 1,3,5,7      │
└───────────────────────┘
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)