13.6 Concise Implementation for Multiple GPUs
Implementing parallelism from scratch for every new model is no fun. Moreover, there is significant benefit in optimizing synchronization tools for high performance. In the following we will show how to do this using high-level APIs of deep learning frameworks. The mathematics and the algorithms are the same as in Section 13.5. Quite unsurprisingly, you will need at least two GPUs to run the code in this section.
# PyTorch
from d2l import torch as d2l
import torch
from torch import nn
# JAX
from d2l import jax as d2l
import functools
import jax
from jax import numpy as jnp
from flax import linen as nn
import optax
from flax.training import train_state
import flax
import numpy as np
# MXNet
from d2l import mxnet as d2l
from mxnet import autograd, gluon, init, np, npx
from mxnet.gluon import nn
npx.set_np()
13.6.1 A Toy Network
Let’s use a network that is slightly more meaningful than LeNet from Section 13.5 but still easy and quick to train. We pick a ResNet-18 variant (He et al. 2016). Since the input images are tiny, we modify it slightly. In particular, the difference from Section 8.6 is that we use a smaller convolution kernel, stride, and padding at the beginning. Moreover, we remove the max-pooling layer.
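To see why the stem matters for tiny images, the following sketch tracks the spatial size of a 28×28 Fashion-MNIST image using the usual output-size formula ⌊(n + 2p − k)/s⌋ + 1. The helpers `out_size` and `trace` are illustrative names of our own, and we assume that the original stem from Section 8.6 is a 7×7 convolution (stride 2, padding 3) followed by 3×3 max-pooling (stride 2, padding 1), and that each of stages 2 to 4 starts with a stride-2 3×3 convolution with padding 1.

```python
def out_size(n, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer."""
    return (n + 2 * padding - kernel) // stride + 1


def trace(n, stem):
    """Spatial sizes after the stem and after each of the four stages."""
    for kernel, stride, padding in stem:
        n = out_size(n, kernel, stride, padding)
    sizes = [n]
    # Stage 1 keeps the resolution; stages 2-4 each begin with a stride-2
    # 3x3 convolution (padding 1) that roughly halves it
    for stride in (1, 2, 2, 2):
        n = out_size(n, 3, stride, 1)
        sizes.append(n)
    return sizes


original_stem = [(7, 2, 3), (3, 2, 1)]  # assumed: 7x7 conv, then 3x3 max-pool
modified_stem = [(3, 1, 1)]             # 3x3 conv, stride 1, no max-pool

print(trace(28, original_stem))  # [7, 7, 4, 2, 1]
print(trace(28, modified_stem))  # [28, 28, 14, 7, 4]
```

With the assumed original stem the last stage would see a 1×1 feature map, leaving global average pooling nothing to average; the modified stem keeps a 4×4 map at that point.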
def resnet18(num_classes, in_channels=1):
    """A slightly modified ResNet-18 model."""
    def resnet_block(in_channels, out_channels, num_residuals,
                     first_block=False):
        blk = []
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.append(d2l.Residual(out_channels, use_1x1conv=True,
                                        strides=2))
            else:
                blk.append(d2l.Residual(out_channels))
        return nn.Sequential(*blk)

    # This model uses a smaller convolution kernel, stride, and padding and
    # removes the max-pooling layer
    net = nn.Sequential(
        nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU())
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True))
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("resnet_block4", resnet_block(256, 512, 2))
    net.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1, 1)))
    net.add_module("fc", nn.Sequential(nn.Flatten(),
                                       nn.Linear(512, num_classes)))
    return net
class ResNet18(nn.Module):
    """A slightly modified ResNet-18 model."""
    num_classes: int = 10
    training: bool = True

    def setup(self):
        self.net = nn.Sequential([
            nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding='same'),
            nn.BatchNorm(not self.training),
            nn.relu,
            # ResNet blocks
            d2l.Residual(64, training=self.training),
            d2l.Residual(64, training=self.training),
            d2l.Residual(128, use_1x1conv=True, strides=(2, 2),
                         training=self.training),
            d2l.Residual(128, training=self.training),
            d2l.Residual(256, use_1x1conv=True, strides=(2, 2),
                         training=self.training),
            d2l.Residual(256, training=self.training),
            d2l.Residual(512, use_1x1conv=True, strides=(2, 2),
                         training=self.training),
            d2l.Residual(512, training=self.training),
            # Global average pooling and classifier
            lambda x: x.mean(axis=(1, 2)),
            nn.Dense(self.num_classes),
        ])

    def __call__(self, x):
        return self.net(x)
def resnet18(num_classes):
    """A slightly modified ResNet-18 model."""
    def resnet_block(num_channels, num_residuals, first_block=False):
        blk = nn.Sequential()
        for i in range(num_residuals):
            if i == 0 and not first_block:
                blk.add(d2l.Residual(
                    num_channels, use_1x1conv=True, strides=2))
            else:
                blk.add(d2l.Residual(num_channels))
        return blk

    net = nn.Sequential()
    # This model uses a smaller convolution kernel, stride, and padding and
    # removes the max-pooling layer
    net.add(nn.Conv2D(64, kernel_size=3, strides=1, padding=1),
            nn.BatchNorm(), nn.Activation('relu'))
    net.add(resnet_block(64, 2, first_block=True),
            resnet_block(128, 2),
            resnet_block(256, 2),
            resnet_block(512, 2))
    net.add(nn.GlobalAvgPool2D(), nn.Dense(num_classes))
    return net
13.6.2 Network Initialization
In PyTorch, we will initialize the network inside the training loop. For a refresher on initialization methods see Section 5.4.
In JAX, we initialize the model parameters and create a TrainState that bundles the parameters with the optimizer. For multi-GPU training, we replicate the state across devices by stacking one copy per device along a new leading axis, which is the axis that jax.pmap maps over. The helper flax.jax_utils.replicate offers similar functionality.
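Replicating a state for jax.pmap amounts to giving every leaf of the parameter tree a new leading axis with one copy per device; the JAX training code below does this with jax.tree.map and jnp.stack. The following NumPy-only sketch shows the resulting shapes. The `tree_map` helper and the toy parameter dictionary are hypothetical stand-ins for jax.tree.map and a real Flax parameter tree.

```python
import numpy as np


def tree_map(fn, tree):
    """Apply fn to every leaf of a nested dict (a tiny stand-in for jax.tree.map)."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


def replicate(tree, num_devices):
    """Stack one copy of every leaf per device along a new leading axis."""
    return tree_map(lambda x: np.stack([x] * num_devices), tree)


# Hypothetical toy parameter tree: a conv kernel (HWIO layout) and a dense layer
params = {'conv': {'kernel': np.zeros((3, 3, 1, 64))},
          'dense': {'kernel': np.zeros((512, 10)), 'bias': np.zeros(10)}}
replicated = replicate(params, num_devices=4)
print(replicated['conv']['kernel'].shape)  # (4, 3, 3, 1, 64)
print(replicated['dense']['bias'].shape)   # (4, 10)
```

Because every device starts from an identical copy and later applies the same averaged gradients, the copies stay in sync during training.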
In Gluon, the initialize function allows us to initialize parameters on a device of our choice. For a refresher on initialization methods see Section 5.4. What is particularly convenient is that it also allows us to initialize the network on multiple devices simultaneously. Let’s see how this works in practice.
net = resnet18(10)
# Get a list of GPUs
devices = d2l.try_all_gpus()
# We will initialize the network inside the training loop
net = ResNet18(num_classes=10)
# Count available devices (GPUs/TPUs)
num_devices = jax.local_device_count()
print(f'Using {num_devices} devices: {jax.devices()}')
# We will initialize the network inside the training loop
Using 4 devices: [CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]
net = resnet18(10)
# Get a list of GPUs
devices = d2l.try_all_gpus()
# Initialize all the parameters of the network
net.initialize(init=init.Normal(sigma=0.01), ctx=devices)
Using the split_and_load function introduced in Section 13.5, we can divide a minibatch of data and copy portions to the list of devices provided by the devices variable. The network instance automatically uses the appropriate GPU to compute the forward propagation. Here we generate 4 observations and split them over the GPUs.
x = np.random.uniform(size=(4, 1, 28, 28))
x_shards = gluon.utils.split_and_load(x, devices)
net(x_shards[0]), net(x_shards[1])
(array([[ 2.2610195e-06, 2.2045988e-06, -5.4046795e-06, 1.2869961e-06,
5.1373149e-06, -3.8298003e-06, 1.4338968e-07, 5.4683442e-06,
-2.8279201e-06, -3.9651122e-06]], ctx=gpu(0)),
array([[ 2.0698672e-06, 2.0084667e-06, -5.6382496e-06, 1.0498482e-06,
5.5506434e-06, -4.1065477e-06, 6.0830178e-07, 5.4521761e-06,
-3.7365016e-06, -4.1891649e-06]], ctx=gpu(1)))
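The splitting step itself can be sketched with plain NumPy. We assume the batch size is divisible by the number of devices. `split_batch_np` and `to_device_axis` are hypothetical helpers: they only slice or reshape the array and copy nothing to a GPU. The second layout, with a leading device axis, is the one the reshape_batch helper in the JAX training code prepares for jax.pmap.

```python
import numpy as np


def split_batch_np(X, num_devices):
    """Split a minibatch into equally sized shards along the batch axis."""
    assert X.shape[0] % num_devices == 0, 'batch size must divide evenly'
    return np.split(X, num_devices, axis=0)


def to_device_axis(X, num_devices):
    """Reshape (batch, ...) into (num_devices, batch // num_devices, ...)."""
    return X.reshape(num_devices, X.shape[0] // num_devices, *X.shape[1:])


x = np.random.uniform(size=(4, 1, 28, 28))
shards = split_batch_np(x, num_devices=2)
print([s.shape for s in shards])   # [(2, 1, 28, 28), (2, 1, 28, 28)]
print(to_device_axis(x, 2).shape)  # (2, 2, 1, 28, 28)
# Both layouts hold the same observations in the same order
print(np.array_equal(to_device_axis(x, 2)[1], shards[1]))  # True
```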
Initialization is deferred: only once data first passes through the network are the parameters actually created, and they are created on the devices listed when calling initialize. This means that initialization happens on a per-device basis. Since we listed only GPUs (e.g., GPU 0 and GPU 1), the network is initialized only there, and not on the CPU. In fact, the parameters do not even exist on the CPU. We can verify this by printing out the parameters and observing any errors that might arise.
weight = net[0].params.get('weight')
try:
    weight.data()
except RuntimeError:
    print('not initialized on cpu')
weight.data(devices[0])[0], weight.data(devices[1])[0]
not initialized on cpu
(array([[[ 0.01382882, -0.01183044, 0.01417865],
[-0.00319718, 0.00439528, 0.02562625],
[-0.00835081, 0.01387452, -0.01035946]]], ctx=gpu(0)),
array([[[ 0.01382882, -0.01183044, 0.01417865],
[-0.00319718, 0.00439528, 0.02562625],
[-0.00835081, 0.01387452, -0.01035946]]], ctx=gpu(1)))
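To make the per-device bookkeeping concrete, here is a toy model of it. `LazyParam` and its methods are hypothetical and much simpler than Gluon's Parameter: weights are drawn only when the first forward pass reveals their shape, one identical copy is stored per listed device (matching the identical values printed for gpu(0) and gpu(1) above), and reading a copy on any other device raises an error. Devices are plain strings here.

```python
import numpy as np


class LazyParam:
    """Toy parameter with deferred, per-device initialization (hypothetical)."""

    def __init__(self, devices, sigma=0.01):
        self.devices = list(devices)
        self.sigma = sigma
        self._data = {}  # device name -> array; stays empty until first use

    def first_forward(self, shape, seed=0):
        """Initialize once the input (and hence the weight) shape is known."""
        if not self._data:
            w = np.random.default_rng(seed).normal(0.0, self.sigma, size=shape)
            # Every listed device receives an identical copy of one draw
            self._data = {d: w.copy() for d in self.devices}

    def data(self, device):
        """Return the copy on `device`, or fail if none exists there."""
        if device not in self._data:
            raise RuntimeError(f'parameter not initialized on {device}')
        return self._data[device]


weight = LazyParam(devices=['gpu(0)', 'gpu(1)'])
weight.first_forward(shape=(64, 1, 3, 3))
try:
    weight.data('cpu(0)')
except RuntimeError as err:
    print(err)  # parameter not initialized on cpu(0)
print(np.array_equal(weight.data('gpu(0)'), weight.data('gpu(1)')))  # True
```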
Next, let’s replace the code that evaluates accuracy with one that works in parallel across multiple devices. This serves as a replacement for the evaluate_accuracy_gpu function from Section 7.6. The main difference is that we split a minibatch before invoking the network. All else is essentially identical.
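The aggregation works because accuracy is a ratio of counts: summing the number of correct predictions over all shards and dividing by the total number of labels gives exactly the accuracy on the full minibatch, even when shards differ in size. A NumPy check with random, purely illustrative logits:

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(10, 10))  # 10 examples, 10 classes
labels = rng.integers(0, 10, size=10)


def num_correct(y_hat, y):
    """Number of correct predictions (a count, not a mean)."""
    return float((y_hat.argmax(axis=1) == y).sum())


full_acc = num_correct(logits, labels) / labels.size

# Uneven shards, e.g. 4 + 3 + 3 examples on three devices
shard_logits = np.array_split(logits, 3)
shard_labels = np.array_split(labels, 3)
correct = sum(num_correct(p, y) for p, y in zip(shard_logits, shard_labels))
print(correct / labels.size == full_acc)  # True
```

Averaging per-shard accuracies instead would give a different answer whenever shards have different sizes, which is why the counts are accumulated.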
def evaluate_accuracy_gpus(net, data_iter, split_f=d2l.split_batch):
    """Compute the accuracy for a model on a dataset using multiple GPUs."""
    # Query the list of devices
    devices = list(net.collect_params().values())[0].list_ctx()
    # No. of correct predictions, no. of predictions
    metric = d2l.Accumulator(2)
    for features, labels in data_iter:
        X_shards, y_shards = split_f(features, labels, devices)
        # Run in parallel
        pred_shards = [net(X_shard) for X_shard in X_shards]
        metric.add(sum(float(d2l.accuracy(pred_shard, y_shard)) for
                       pred_shard, y_shard in zip(
                           pred_shards, y_shards)), labels.size)
    return metric[0] / metric[1]
13.6.3 Training
As before, the training code needs to perform several basic functions for efficient parallelism:
- Network parameters need to be initialized across all devices.
- While iterating over the dataset, minibatches are divided across all devices.
- We compute the loss and its gradient in parallel across devices.
- Gradients are aggregated and parameters are updated accordingly.
In the end we compute the accuracy (again in parallel) to report the final performance of the network. The training routine is quite similar to implementations in previous chapters, except that we need to split and aggregate data.
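The aggregation step is what keeps data-parallel training equivalent to single-device training: with equally sized shards, the mean of the per-shard gradients of a mean loss equals the gradient on the whole minibatch. This is what the jax.lax.pmean call in the JAX training step computes; the other two frameworks perform an equivalent aggregation internally. The minimal NumPy sketch below checks the identity for linear regression with squared loss; the model, data, and learning rate are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))  # minibatch of 8 examples with 3 features
y = rng.normal(size=8)
w = np.zeros(3)


def grad(w, X, y):
    """Gradient of the loss 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)


num_devices = 4
# Each "device" gets an equally sized shard and computes its own gradient
shard_grads = [grad(w, Xs, ys) for Xs, ys in
               zip(np.split(X, num_devices), np.split(y, num_devices))]
# All-reduce with a mean, like jax.lax.pmean(grads, axis_name='batch')
avg_grad = np.mean(shard_grads, axis=0)
print(np.allclose(avg_grad, grad(w, X, y)))  # True

lr = 0.1
w = w - lr * avg_grad  # every replica applies the same update
```

If the shards had different sizes, a plain mean of shard gradients would over-weight the smaller shards; the shards would then have to be weighted by their sizes.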
def train(net, num_gpus, batch_size, lr):
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    devices = [d2l.try_gpu(i) for i in range(num_gpus)]
    # nn.DataParallel expects the parameters on devices[0]; a dummy forward
    # pass also materializes lazily initialized layers before weight init
    net.to(devices[0])
    net.eval()
    with torch.no_grad():
        net(torch.zeros((1, 1, 28, 28), device=devices[0]))

    def init_weights(module):
        if type(module) in [nn.Linear, nn.Conv2d]:
            nn.init.normal_(module.weight, std=0.01)
    net.apply(init_weights)
    # Set the model on multiple GPUs
    net = nn.DataParallel(net, device_ids=devices)
    trainer = torch.optim.SGD(net.parameters(), lr)
    loss = nn.CrossEntropyLoss()
    timer, num_epochs = d2l.Timer(), 10
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    for epoch in range(num_epochs):
        net.train()
        timer.start()
        for X, y in train_iter:
            trainer.zero_grad()
            X, y = X.to(devices[0]), y.to(devices[0])
            l = loss(net(X), y)
            l.backward()
            trainer.step()
        timer.stop()
        animator.add(epoch + 1, (d2l.evaluate_accuracy_gpu(net, test_iter),))
    print(f'test acc: {animator.Y[0][-1]:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {str(devices)}')
def train(num_devices, batch_size, lr):
    data = d2l.FashionMNIST(batch_size=batch_size)
    train_iter = data.get_dataloader(train=True)
    test_iter = data.get_dataloader(train=False)
    net = ResNet18(num_classes=10, training=True)
    # Initialize parameters with a dummy batch (NHWC layout)
    dummy_input = jnp.ones((1, 28, 28, 1))
    key = jax.random.PRNGKey(0)
    variables = net.init(key, dummy_input)
    params = variables['params']
    batch_stats = variables.get('batch_stats', {})
    # Create optimizer and training state
    tx = optax.sgd(lr)

    class TrainState(train_state.TrainState):
        batch_stats: dict

    state = TrainState.create(apply_fn=net.apply, params=params,
                              tx=tx, batch_stats=batch_stats)
    # Replicate state across the requested number of devices: one copy per
    # device along a new leading axis, which pmap maps over
    state = jax.tree.map(
        lambda x: jnp.stack([x] * num_devices), state)

    @functools.partial(jax.pmap, axis_name='batch')
    def train_step(state, images, labels):
        """A single training step on one device."""
        def loss_fn(params):
            logits, updates = state.apply_fn(
                {'params': params, 'batch_stats': state.batch_stats},
                images, mutable=['batch_stats'])
            loss = optax.softmax_cross_entropy_with_integer_labels(
                logits, labels).mean()
            return loss, updates
        (loss, updates), grads = jax.value_and_grad(
            loss_fn, has_aux=True)(state.params)
        # Average gradients across devices
        grads = jax.lax.pmean(grads, axis_name='batch')
        state = state.apply_gradients(grads=grads)
        state = state.replace(
            batch_stats=updates['batch_stats'])
        return state, loss

    @functools.partial(jax.pmap, axis_name='batch')
    def eval_step(state, images, labels):
        """Count correct predictions on one device."""
        logits, _ = state.apply_fn(
            {'params': state.params,
             'batch_stats': state.batch_stats},
            images, mutable=['batch_stats'])
        return (logits.argmax(axis=-1) == labels).sum(), labels.shape[0]

    def reshape_batch(X, y, num_devices):
        """Reshape a batch for pmap: (batch, ...) -> (num_devices, per_device, ...)."""
        per_device = X.shape[0] // num_devices
        X = X[:per_device * num_devices].reshape(
            num_devices, per_device, *X.shape[1:])
        y = y[:per_device * num_devices].reshape(num_devices, per_device)
        return X, y

    timer, num_epochs = d2l.Timer(), 10
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    for epoch in range(num_epochs):
        timer.start()
        for X, y in train_iter:
            X, y = np.array(X), np.array(y)
            X, y = reshape_batch(X, y, num_devices)
            state, loss = train_step(state, X, y)
        # JAX dispatches work asynchronously; wait for the last step to finish
        loss.block_until_ready()
        timer.stop()
        # Evaluate accuracy
        correct, total = 0, 0
        for X, y in test_iter:
            X, y = np.array(X), np.array(y)
            X, y = reshape_batch(X, y, num_devices)
            c, t = eval_step(state, X, y)
            correct += int(c.sum())
            total += int(t.sum())
        test_acc = correct / total
        animator.add(epoch + 1, (test_acc,))
    print(f'test acc: {test_acc:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {num_devices} devices')
def train(num_gpus, batch_size, lr):
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
    ctx = [d2l.try_gpu(i) for i in range(num_gpus)]
    net.initialize(init=init.Normal(sigma=0.01), ctx=ctx, force_reinit=True)
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr})
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    timer, num_epochs = d2l.Timer(), 10
    animator = d2l.Animator('epoch', 'test acc', xlim=[1, num_epochs])
    for epoch in range(num_epochs):
        timer.start()
        for features, labels in train_iter:
            X_shards, y_shards = d2l.split_batch(features, labels, ctx)
            with autograd.record():
                ls = [loss(net(X_shard), y_shard) for X_shard, y_shard
                      in zip(X_shards, y_shards)]
            for l in ls:
                l.backward()
            trainer.step(batch_size)
        npx.waitall()
        timer.stop()
        animator.add(epoch + 1, (evaluate_accuracy_gpus(net, test_iter),))
    print(f'test acc: {animator.Y[0][-1]:.2f}, {timer.avg():.1f} sec/epoch '
          f'on {str(ctx)}')
Let’s see how this works in practice. As a warm-up we train the network on a single GPU.
train(net, num_gpus=1, batch_size=256, lr=0.1)
test acc: 0.84, 5.9 sec/epoch on [device(type='cuda', index=0)]
train(num_devices=1, batch_size=256, lr=0.1)
test acc: 0.92, 10.1 sec/epoch on 4 devices
train(num_gpus=1, batch_size=256, lr=0.1)
test acc: 0.93, 6.4 sec/epoch on [gpu(0)]
Next we use 2 GPUs for training, doubling the minibatch size and the learning rate. Compared with LeNet evaluated in Section 13.5, ResNet-18 is considerably more complex. This is where parallelization shows its advantage: the time spent on computation is considerably larger than the time spent synchronizing parameters, so the parallelization overhead matters less and training scales better.
train(net, num_gpus=2, batch_size=512, lr=0.2)
test acc: 0.87, 8.7 sec/epoch on [device(type='cuda', index=0), device(type='cuda', index=1)]
train(num_devices=2, batch_size=512, lr=0.2)
test acc: 0.91, 6.4 sec/epoch on 4 devices
train(num_gpus=2, batch_size=512, lr=0.2)
test acc: 0.92, 3.7 sec/epoch on [gpu(0), gpu(1)]
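The 2-GPU runs above double both the minibatch size and the learning rate, so each GPU still processes 256 examples per step while the number of steps per epoch roughly halves, and the ratio lr / batch_size stays fixed. The argument that a compute-heavy model benefits more from parallelism can be captured by a toy cost model in which one step takes compute / n + sync time on n devices. The cost constants below are made up for illustration and are not measurements; the only real figure is the 60,000 images in the Fashion-MNIST training set.

```python
import math

TRAIN_SIZE = 60000  # images in the Fashion-MNIST training set


def schedule(num_gpus, batch_size):
    """Per-device batch size and optimizer steps per epoch."""
    return batch_size // num_gpus, math.ceil(TRAIN_SIZE / batch_size)


def speedup(compute, sync, n):
    """Toy model: compute time splits across n devices, sync time does not."""
    step = compute / n + (sync if n > 1 else 0.0)
    return compute / step


print(schedule(1, 256))  # (256, 235)
print(schedule(2, 512))  # (256, 118)
# Hypothetical costs in arbitrary units: a heavy model vs. a light one
print(round(speedup(compute=10.0, sync=1.0, n=2), 2))  # 1.67
print(round(speedup(compute=1.0, sync=1.0, n=2), 2))   # 0.67
```

With these made-up constants two devices help the compute-heavy model and slow down the light one, mirroring the comparison with LeNet; actual timings also depend on how each framework implements synchronization.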
13.6.4 Summary
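- JAX provides jax.pmap for data-parallel training across multiple devices; gradients are aggregated across devices with an explicit jax.lax.pmean call inside the mapped training step.
- Flax’s jax_utils.replicate and jax_utils.unreplicate can distribute state across devices and collect it again; the code in this section replicates the state by stacking copies along a leading axis instead.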
- Gluon provides primitives for model initialization across multiple devices by providing a context list.
- Data is automatically evaluated on the devices where the data can be found.
- Take care to initialize the networks on each device before trying to access the parameters on that device. Otherwise you will encounter an error.
- The optimization algorithms automatically aggregate over multiple GPUs.
13.6.5 Exercises
- This section uses ResNet-18. Try different epochs, batch sizes, and learning rates. Use more GPUs for computation. What happens if you try this with 16 GPUs (e.g., on an AWS p2.16xlarge instance) or, in JAX, with TPUs?
- Sometimes, different devices provide different computing power. We could use the GPUs and the CPU at the same time. How should we divide the work? Is it worth the effort? Why? Why not?
- In JAX, what happens if we replace jax.pmap with jax.vmap? How does the behavior differ?
- In MXNet, what happens if we drop npx.waitall()? How would you modify training such that you have an overlap of up to two steps for parallelism?