12.3. Concise Implementation for Multiple GPUs
Implementing parallelism from scratch for every new model is no fun.
Moreover, there is significant benefit in optimizing synchronization
tools for high performance. In the following we will show how to do this
using DJL. The math and the algorithms are the same as in the previous
section (sec_multi_gpu). As before, we begin by importing the required
modules (quite unsurprisingly, you will need at least two GPUs to run
this notebook).
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Training.java
import ai.djl.basicdataset.cv.classification.*;
import ai.djl.metric.*;
import org.apache.commons.lang3.ArrayUtils;
12.3.1. A Toy Network
Let us use a slightly more meaningful network than LeNet from the previous section that is still sufficiently easy and quick to train. We pick a ResNet-18 variant [He et al., 2016a]. Since the input images are tiny, we modify it slightly. In particular, the difference from Section 7.6 is that we use a smaller convolution kernel, stride, and padding at the beginning (a 3×3 kernel with padding 1 and the default stride). Moreover, we remove the max-pooling layer.
class Residual extends AbstractBlock {
    private static final byte VERSION = 2;

    public ParallelBlock block;

    public Residual(int numChannels, boolean use1x1Conv, Shape strideShape) {
        super(VERSION);

        // Main branch: two 3x3 convolutions, each followed by batch normalization
        SequentialBlock b1 = new SequentialBlock();
        b1.add(Conv2d.builder()
                        .setFilters(numChannels)
                        .setKernelShape(new Shape(3, 3))
                        .optPadding(new Shape(1, 1))
                        .optStride(strideShape)
                        .build())
                .add(BatchNorm.builder().build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setFilters(numChannels)
                        .setKernelShape(new Shape(3, 3))
                        .optPadding(new Shape(1, 1))
                        .build())
                .add(BatchNorm.builder().build());

        // Shortcut branch: a strided 1x1 convolution if requested, identity otherwise
        SequentialBlock conv1x1 = new SequentialBlock();
        if (use1x1Conv) {
            conv1x1.add(Conv2d.builder()
                    .setFilters(numChannels)
                    .setKernelShape(new Shape(1, 1))
                    .optStride(strideShape)
                    .build());
        } else {
            conv1x1.add(Blocks.identityBlock());
        }

        // Merge both branches: element-wise sum followed by ReLU
        block = addChildBlock("residualBlock", new ParallelBlock(
                list -> {
                    NDList unit = list.get(0);
                    NDList parallel = list.get(1);
                    return new NDList(
                            unit.singletonOrThrow()
                                    .add(parallel.singletonOrThrow())
                                    .getNDArrayInternal()
                                    .relu());
                },
                Arrays.asList(b1, conv1x1)));
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        return block.forward(parameterStore, inputs, training);
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputs) {
        Shape[] current = inputs;
        // Loop variable renamed to avoid shadowing the `block` field
        for (Block child : block.getChildren().values()) {
            current = child.getOutputShapes(current);
        }
        return current;
    }

    @Override
    protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        block.initialize(manager, dataType, inputShapes);
    }
}
public SequentialBlock resnetBlock(int numChannels, int numResiduals, boolean isFirstBlock) {
    SequentialBlock blk = new SequentialBlock();
    for (int i = 0; i < numResiduals; i++) {
        if (i == 0 && !isFirstBlock) {
            // First residual of every later stage: stride 2 with a 1x1 convolution on the shortcut
            blk.add(new Residual(numChannels, true, new Shape(2, 2)));
        } else {
            blk.add(new Residual(numChannels, false, new Shape(1, 1)));
        }
    }
    return blk;
}
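The merge function handed to ParallelBlock in Residual adds the outputs of the two branches and applies ReLU. A minimal plain-Java sketch of that arithmetic follows; it uses float arrays instead of NDArray, and the class and method names are hypothetical illustrations rather than part of DJL.

```java
import java.util.Arrays;

/** Plain-Java illustration of the residual merge step: relu(main + shortcut). */
final class ResidualMergeSketch {

    /** Adds the two branch outputs element-wise and applies ReLU. */
    static float[] merge(float[] mainBranch, float[] shortcut) {
        if (mainBranch.length != shortcut.length) {
            throw new IllegalArgumentException("Branch outputs must have the same length");
        }
        float[] out = new float[mainBranch.length];
        for (int i = 0; i < out.length; i++) {
            out[i] = Math.max(0f, mainBranch[i] + shortcut[i]);
        }
        return out;
    }

    public static void main(String[] args) {
        float[] mainBranch = {0.5f, -2.0f, 1.0f};
        float[] shortcut = {1.0f, 0.5f, -1.0f};
        // prints [1.5, 0.0, 0.0]
        System.out.println(Arrays.toString(merge(mainBranch, shortcut)));
    }
}
```

The length check mirrors the reason the shortcut needs a strided 1x1 convolution whenever the main branch changes the shape: both operands of the sum must agree.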
int numClass = 10;
// This model uses a smaller convolution kernel, stride, and padding and
// removes the maximum pooling layer
SequentialBlock net = new SequentialBlock();
net
    .add(
        Conv2d.builder()
            .setFilters(64)
            .setKernelShape(new Shape(3, 3))
            .optPadding(new Shape(1, 1))
            .build())
    .add(BatchNorm.builder().build())
    .add(Activation::relu)
    .add(resnetBlock(64, 2, true))
    .add(resnetBlock(128, 2, false))
    .add(resnetBlock(256, 2, false))
    .add(resnetBlock(512, 2, false))
    .add(Pool.globalAvgPool2dBlock())
    .add(Linear.builder().setUnits(numClass).build());
SequentialBlock {
Conv2d
BatchNorm
LambdaBlock
SequentialBlock {
Residual {
residualBlock {
SequentialBlock {
Conv2d
BatchNorm
LambdaBlock
Conv2d
BatchNorm
}
SequentialBlock {
identity
}
}
}
Residual {
residualBlock {
SequentialBlock {
Conv2d
BatchNorm
LambdaBlock
Conv2d
BatchNorm
}
SequentialBlock {
identity
}
}
}
}
SequentialBlock {
Residual {
residualBlock {
SequentialBlock {
Conv2d
BatchNorm
LambdaBlock
Conv2d
BatchNorm
}
SequentialBlock {
Conv2d
}
}
}
Residual {
residualBlock {
SequentialBlock {
Conv2d
BatchNorm
LambdaBlock
Conv2d
BatchNorm
}
SequentialBlock {
identity
}
}
}
}
SequentialBlock {
Residual {
residualBlock {
SequentialBlock {
Conv2d
BatchNorm
LambdaBlock
Conv2d
BatchNorm
}
SequentialBlock {
Conv2d
}
}
}
Residual {
residualBlock {
SequentialBlock {
Conv2d
BatchNorm
LambdaBlock
Conv2d
BatchNorm
}
SequentialBlock {
identity
}
}
}
}
SequentialBlock {
Residual {
residualBlock {
SequentialBlock {
Conv2d
BatchNorm
LambdaBlock
Conv2d
BatchNorm
}
SequentialBlock {
Conv2d
}
}
}
Residual {
residualBlock {
SequentialBlock {
Conv2d
BatchNorm
LambdaBlock
Conv2d
BatchNorm
}
SequentialBlock {
identity
}
}
}
}
globalAvgPool2d
Linear
}
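To see how the spatial size evolves through this network, one can apply the standard convolution output-size formula floor((n + 2p - k) / s) + 1 stage by stage. The sketch below is plain Java with hypothetical names and assumes 28×28 inputs, the shape used for the forward pass later in this section.

```java
/** Traces the spatial size of a 28x28 input through the stem and the four stages. */
final class ConvShapeSketch {

    /** Output size along one spatial axis: floor((n + 2p - k) / s) + 1. */
    static int outSize(int n, int kernel, int padding, int stride) {
        return (n + 2 * padding - kernel) / stride + 1;
    }

    public static void main(String[] args) {
        // Stem: 3x3 kernel, padding 1, stride 1 keeps the resolution at 28
        int size = outSize(28, 3, 1, 1);
        System.out.println("stem: " + size);

        // Stage 1 uses stride 1 throughout; stages 2 to 4 start with a stride-2 residual
        for (int stage = 2; stage <= 4; stage++) {
            int mainBranch = outSize(size, 3, 1, 2); // 3x3 conv, padding 1, stride 2
            int shortcut = outSize(size, 1, 0, 2);   // 1x1 conv, no padding, stride 2
            System.out.println("stage " + stage + ": main=" + mainBranch + ", shortcut=" + shortcut);
            size = mainBranch; // sizes go 14, 7, 4
        }
    }
}
```

Both branches of each strided residual produce the same size, which is what makes the element-wise sum valid. Because the stem does not downsample, the final stage still sees a 4×4 feature map before global average pooling.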
12.3.2. Parameter Initialization and Logistics
The setInitializer method allows us to set initial defaults for
parameters on a device of our choice. For a refresher see
Section 4.8. What is particularly convenient is
that it also lets us initialize the network on multiple devices
simultaneously. In the code below, the initializer is passed through
optInitializer and the devices through optDevices on the training
configuration. Let us see how this works in practice.
Model model = Model.newInstance("training-multiple-gpus-1");
model.setBlock(net);
Loss loss = Loss.softmaxCrossEntropyLoss();
Tracker lrt = Tracker.fixed(0.1f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // Optimizer (SGD with a fixed learning rate)
        .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // Initializer for the weights
        .optDevices(Engine.getInstance().getDevices(1)) // Devices to train on (at most 1 GPU)
        .addEvaluator(new Accuracy()) // Model accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging
Trainer trainer = model.newTrainer(config);
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.070 ms.
Using the split function introduced in the previous section, we can
divide a minibatch of data and copy the portions to a list of devices.
This list is an array of Device objects, such as the one returned by
Engine.getInstance().getDevices above. Here we split the data with the
help of the Batchifier class. The network object automatically uses the
appropriate GPU to compute the value of the forward propagation. As
before, we generate 4 observations and split them into 4 slices. Since
the trainer above was configured with a single device, all four results
are reported on gpu(0).
NDManager manager = NDManager.newBaseManager();
NDArray X = manager.randomUniform(0f, 1.0f, new Shape(4, 1, 28, 28));
trainer.initialize(X.getShape());
NDList[] res = Batchifier.STACK.split(new NDList(X), 4, true);
ParameterStore parameterStore = new ParameterStore(manager, true);
System.out.println(net.forward(parameterStore, new NDList(res[0]), false).singletonOrThrow());
System.out.println(net.forward(parameterStore, new NDList(res[1]), false).singletonOrThrow());
System.out.println(net.forward(parameterStore, new NDList(res[2]), false).singletonOrThrow());
System.out.println(net.forward(parameterStore, new NDList(res[3]), false).singletonOrThrow());
ND: (1, 10) gpu(0) float32
[[-2.53076792e-07, 2.19176854e-06, -2.05096558e-06, -2.80443487e-07, -1.65612937e-06, 5.92275399e-07, -4.38029275e-07, 1.43108821e-07, 1.86682854e-07, 8.35030505e-07],
]
ND: (1, 10) gpu(0) float32
[[-3.17955994e-07, 1.94063477e-06, -1.82914255e-06, 1.36083145e-09, -1.45861077e-06, 4.11562326e-07, -8.99586439e-07, 1.97685665e-07, 2.77768578e-07, 6.80656115e-07],
]
ND: (1, 10) gpu(0) float32
[[-1.82850158e-07, 2.26233874e-06, -2.24626365e-06, 8.68596715e-08, -1.29084265e-06, 9.33801005e-07, -1.04999901e-06, 1.76022922e-07, 3.97307645e-08, 9.49504113e-07],
]
ND: (1, 10) gpu(0) float32
[[-1.78178539e-07, 1.59132321e-06, -2.00916884e-06, -2.30666600e-07, -1.31331467e-06, 5.71873784e-07, -4.02916669e-07, 1.11762461e-07, 3.40592749e-07, 8.89963815e-07],
]
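The splitting step itself is simple bookkeeping. As a rough illustration, the following plain-Java sketch cuts a batch of sample indices into equally sized contiguous slices, one per device. It is a hypothetical stand-in for what an even split does conceptually, not DJL's Batchifier implementation.

```java
import java.util.Arrays;

/** Splits a batch of sample indices into equally sized contiguous slices. */
final class BatchSplitSketch {

    static int[][] splitEvenly(int[] batch, int numSlices) {
        if (numSlices <= 0 || batch.length % numSlices != 0) {
            throw new IllegalArgumentException("Batch size must be divisible by a positive slice count");
        }
        int sliceSize = batch.length / numSlices;
        int[][] slices = new int[numSlices][];
        for (int s = 0; s < numSlices; s++) {
            // Slice s holds samples [s * sliceSize, (s + 1) * sliceSize)
            slices[s] = Arrays.copyOfRange(batch, s * sliceSize, (s + 1) * sliceSize);
        }
        return slices;
    }

    public static void main(String[] args) {
        int[] batch = {0, 1, 2, 3};
        // Four slices of one sample each, as in the forward pass above
        System.out.println(Arrays.deepToString(splitEvenly(batch, 4)));
        // Two slices of two samples each, as one would use with two devices
        System.out.println(Arrays.deepToString(splitEvenly(batch, 2)));
    }
}
```

Requiring an even split keeps every device equally loaded, which is why batch sizes in data-parallel training are usually chosen as a multiple of the device count.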
Once data passes through the network, the corresponding parameters are initialized on the device the data passed through. This means that initialization happens on a per-device basis.
net.getChildren().values().get(0).getParameters().get("weight").getArray().get(new NDIndex("0:1"));
ND: (1, 1, 3, 3) gpu(0) float32
[[[[ 0.0053, -0.0018, -0.0141],
[-0.0094, -0.0146, 0.0094],
[ 0.002 , 0.0189, 0.0014],
],
],
]
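The printed weights have magnitudes around 0.01, which matches the initializer configured above. The sketch below draws values in the same spirit with the JDK's random number generator. It assumes that the argument 0.01f is the standard deviation of a zero-mean normal distribution; the class name, method names, and seed are hypothetical.

```java
import java.util.Random;

/** Draws weights from a zero-mean normal distribution with a given standard deviation. */
final class NormalInitSketch {

    static float[] sample(int n, float sigma, long seed) {
        Random rng = new Random(seed);
        float[] weights = new float[n];
        for (int i = 0; i < n; i++) {
            weights[i] = (float) (rng.nextGaussian() * sigma);
        }
        return weights;
    }

    /** Empirical standard deviation of the sampled values. */
    static double std(float[] values) {
        double mean = 0.0;
        for (float v : values) {
            mean += v;
        }
        mean /= values.length;
        double sumSq = 0.0;
        for (float v : values) {
            sumSq += (v - mean) * (v - mean);
        }
        return Math.sqrt(sumSq / values.length);
    }

    public static void main(String[] args) {
        float[] weights = sample(10_000, 0.01f, 42L);
        System.out.printf("empirical std: %.4f%n", std(weights));
    }
}
```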
12.3.3. Training
As before, the training code needs to perform a number of basic functions for efficient parallelism:
Network parameters need to be initialized across all devices.
While iterating over the dataset, minibatches are to be divided across all devices.
We compute the loss and its gradient in parallel across devices.
Gradients are aggregated across devices (by the trainer) and parameters are updated accordingly.
In the end we compute the accuracy (again in parallel) to report the final performance of the network. The training routine is quite similar to implementations in previous chapters, except that we need to split and aggregate data.
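The steps listed above can be sketched on a toy problem without any deep learning library. The following plain-Java example fits y = w·x with squared loss: each simulated device computes the gradient on its shard, the partial gradients are summed, and a single update is applied. All names are hypothetical, and the shards are processed sequentially here, whereas real devices would work concurrently.

```java
/** One data-parallel SGD step for the scalar model y = w * x with squared loss. */
final class DataParallelStepSketch {

    /** Gradient of 0.5 * (w * x - y)^2 with respect to w, summed over samples [from, to). */
    static double shardGradient(double w, double[] x, double[] y, int from, int to) {
        double grad = 0.0;
        for (int i = from; i < to; i++) {
            grad += (w * x[i] - y[i]) * x[i];
        }
        return grad;
    }

    /** Splits the batch evenly, sums the per-shard gradients, and applies one update. */
    static double step(double w, double[] x, double[] y, int numDevices, double lr) {
        if (numDevices <= 0 || x.length % numDevices != 0) {
            throw new IllegalArgumentException("Batch size must be divisible by the device count");
        }
        int shardSize = x.length / numDevices;
        double total = 0.0;
        for (int d = 0; d < numDevices; d++) {
            // Aggregation: add up the partial gradient of every device
            total += shardGradient(w, x, y, d * shardSize, (d + 1) * shardSize);
        }
        // Average over the whole minibatch before updating the shared parameter
        return w - lr * total / x.length;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4};
        double[] y = {2, 4, 6, 8};
        System.out.println("1 device:  " + step(0.0, x, y, 1, 0.1));
        System.out.println("2 devices: " + step(0.0, x, y, 2, 0.1));
    }
}
```

Both calls produce the same updated parameter, because summing per-shard gradients reproduces the full-batch gradient. This is why splitting a minibatch across devices leaves the mathematics of the update unchanged.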
int numEpochs = Integer.getInteger("MAX_EPOCH", 10);
double[] testAccuracy;
double[] epochCount;
// x-axis values for plotting: epochs 1..numEpochs
epochCount = new double[numEpochs];
for (int i = 0; i < epochCount.length; i++) {
    epochCount[i] = (i + 1);
}
Map<String, double[]> evaluatorMetrics = new HashMap<>();
double avgTrainTimePerEpoch = 0;
public void train(int numEpochs, Trainer trainer, int batchSize) throws IOException, TranslateException {
    FashionMnist trainIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TRAIN)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();
    FashionMnist testIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TEST)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();
    trainIter.prepare();
    testIter.prepare();
    Map<String, double[]> evaluatorMetrics = new HashMap<>();
    double avgTrainTime = 0;
    trainer.setMetrics(new Metrics());
    // EasyTrain splits every batch across the trainer's devices
    EasyTrain.fit(trainer, numEpochs, trainIter, testIter);
    Metrics metrics = trainer.getMetrics();
    // Collect the per-epoch validation value of every evaluator
    trainer.getEvaluators().stream()
            .forEach(evaluator -> {
                evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());
            });
    avgTrainTime = metrics.mean("epoch");
    testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");
    System.out.printf("test acc %.2f\n", testAccuracy[numEpochs - 1]);
    System.out.println(avgTrainTime / Math.pow(10, 9) + " sec/epoch \n");
}
12.3.4. Experiments
Let us see how this works in practice. As a warmup we train the network on a single GPU.
Table data = null;
// We will check if we have at least 1 GPU available. If yes, we run the training on 1 GPU.
if (Engine.getInstance().getGpuCount() >= 1) {
    train(numEpochs, trainer, 256);
    data = Table.create("Data");
    data = data.addColumns(
        DoubleColumn.create("X", epochCount),
        DoubleColumn.create("testAccuracy", testAccuracy)
    );
}
Training: 100% |████████████████████████████████████████| Accuracy: 0.77, SoftmaxCrossEntropyLoss: 0.62
Validating: 100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.77, SoftmaxCrossEntropyLoss: 0.62
INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.43
Training: 100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26
Validating: 100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26
INFO Validate: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28
Training: 100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.20
Validating: 100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.20
INFO Validate: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.27
Training: 100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
Validating: 100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
INFO Validate: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Training: 100% |████████████████████████████████████████| Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
Validating: 100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.35
Training: 100% |████████████████████████████████████████| Accuracy: 0.96, SoftmaxCrossEntropyLoss: 0.11
Validating: 100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.96, SoftmaxCrossEntropyLoss: 0.11
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.37
Training: 100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
Validating: 100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.28
Training: 100% |████████████████████████████████████████| Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.06
Validating: 100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.06
INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.51
Training: 100% |████████████████████████████████████████| Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.05
Validating: 100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.05
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.46
Training: 100% |████████████████████████████████████████| Accuracy: 0.99, SoftmaxCrossEntropyLoss: 0.04
Validating: 100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.99, SoftmaxCrossEntropyLoss: 0.04
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.37
test acc 0.91
19.9016687377 sec/epoch
// Uncomment to view the graph if you have 1 GPU. The call sits outside the if block because render does not work inside it.
// render(LinePlot.create("", data, "x", "testAccuracy"), "text/html");
Table data = Table.create("Data");
// We will check if we have more than 1 GPU available. If yes, we run the training on 2 GPUs.
if (Engine.getInstance().getGpuCount() >= 2) {
    X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 28, 28));
    Model model = Model.newInstance("training-multiple-gpus-2");
    model.setBlock(net);
    loss = Loss.softmaxCrossEntropyLoss();
    Tracker lrt = Tracker.fixed(0.2f);
    Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
    DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
            .optOptimizer(sgd) // Optimizer (SGD with a fixed learning rate)
            .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // Initializer for the weights
            .optDevices(Engine.getInstance().getDevices(2)) // Devices to train on (at most 2 GPUs)
            .addEvaluator(new Accuracy()) // Model accuracy
            .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging
    Trainer trainer = model.newTrainer(config);
    trainer.initialize(X.getShape());
    Map<String, double[]> evaluatorMetrics = new HashMap<>();
    double avgTrainTimePerEpoch = 0;
    train(numEpochs, trainer, 512);
    data = data.addColumns(
        DoubleColumn.create("X", epochCount),
        DoubleColumn.create("testAccuracy", testAccuracy)
    );
}
INFO Training on: 2 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.020 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 0.57, SoftmaxCrossEntropyLoss: 1.28
Validating: 100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.57, SoftmaxCrossEntropyLoss: 1.26
INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.72
Training: 100% |████████████████████████████████████████| Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.45
Validating: 100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.45
INFO Validate: Accuracy: 0.71, SoftmaxCrossEntropyLoss: 0.81
Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33
Validating: 100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33
INFO Validate: Accuracy: 0.80, SoftmaxCrossEntropyLoss: 0.61
Training: 100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28
Validating: 100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28
INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.74
Training: 100% |████████████████████████████████████████| Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
Validating: 100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.43
Training: 100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
Validating: 100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.33
Training: 100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.19
Validating: 100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.19
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.49
Training: 100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
Validating: 100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.44
Training: 100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.15
Validating: 100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.15
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.36
Training: 100% |████████████████████████████████████████| Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
Validating: 100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.37
test acc 0.87
14.4473384254 sec/epoch
// Uncomment to view the graph if you have 2 GPUs. The call sits outside the if block because render does not work inside it.
// render(LinePlot.create("", data, "x", "testAccuracy"), "text/html");
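The two logged epoch times can be turned into a speedup and a parallel efficiency figure. The sketch below uses the timings printed above (about 19.90 s/epoch on one GPU and 14.45 s/epoch on two); the class and method names are hypothetical. Note that the two runs also differ in batch size (256 versus 512) and learning rate (0.1 versus 0.2), so this compares wall-clock time per epoch only, not convergence.

```java
/** Speedup and parallel efficiency derived from two measured epoch times. */
final class SpeedupSketch {

    /** How many times faster the parallel run is than the baseline. */
    static double speedup(double baselineSeconds, double parallelSeconds) {
        return baselineSeconds / parallelSeconds;
    }

    /** Fraction of ideal linear scaling that was achieved. */
    static double efficiency(double speedup, int numDevices) {
        return speedup / numDevices;
    }

    public static void main(String[] args) {
        double oneGpu = 19.9016687377;  // sec/epoch logged for 1 GPU
        double twoGpus = 14.4473384254; // sec/epoch logged for 2 GPUs
        double s = speedup(oneGpu, twoGpus);
        System.out.printf("speedup: %.2fx%n", s);                   // about 1.38x
        System.out.printf("efficiency: %.2f%n", efficiency(s, 2));  // about 0.69
    }
}
```

A speedup well below 2 on two devices is plausible for a model and dataset this small, where synchronization and data loading take up a noticeable share of each step.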
12.3.5. Summary
Data is automatically evaluated on the devices where the data can be found.
Take care to initialize the networks on each device before trying to access the parameters on that device. Otherwise you will encounter an error.
The optimization algorithms automatically aggregate over multiple GPUs.
12.3.6. Exercises
This section uses ResNet-18. Try different numbers of epochs, batch sizes, and learning rates. Use more GPUs for computation. What happens if you try this on a p2.16xlarge instance with 16 GPUs?
Sometimes, different devices provide different computing power. We could use the GPUs and the CPU at the same time. How should we divide the work? Is it worth the effort? Why? Why not?