
12.3. Concise Implementation for Multiple GPUs

Implementing parallelism from scratch for every new model is no fun. Moreover, there is significant benefit in optimizing synchronization tools for high performance. In the following we show how to do this using DJL. The math and the algorithms are the same as in the previous section. As before, we begin by importing the required modules (unsurprisingly, you will need at least two GPUs to run this notebook).

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Training.java
import ai.djl.basicdataset.cv.classification.*;
import ai.djl.metric.*;
import org.apache.commons.lang3.ArrayUtils;

12.3.1. A Toy Network

Let us use a slightly more meaningful network than the LeNet from the previous section, one that is still sufficiently easy and quick to train. We pick a ResNet-18 variant [He et al., 2016a]. Since the input images are tiny, we modify it slightly. In particular, the difference from Section 7.6 is that we use a smaller convolution kernel, stride, and padding at the beginning. Moreover, we remove the max-pooling layer.
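A little arithmetic shows why the stem is changed. The sketch below assumes the standard ResNet-18 stem from Section 7.6 (a 7x7 convolution with stride 2 and padding 3, followed by 3x3 max-pooling with stride 2 and padding 1). It applies the usual output-size formula floor((n + 2p - k) / s) + 1 to a 28x28 Fashion-MNIST image. The class and method names are hypothetical helpers, not part of DJL.

```java
// Hypothetical helper: side length of the output of a convolution or pooling
// window applied to an n x n input, using floor((n + 2p - k) / s) + 1.
class StemShapes {
    static int outSize(int n, int kernel, int stride, int padding) {
        // integer division of non-negative values is a floor
        return (n + 2 * padding - kernel) / stride + 1;
    }

    public static void main(String[] args) {
        int n = 28; // Fashion-MNIST images are 28 x 28

        // Standard ResNet-18 stem: 7x7 conv (stride 2, padding 3), then 3x3 max-pool (stride 2, padding 1)
        int afterConv = outSize(n, 7, 2, 3);
        int afterPool = outSize(afterConv, 3, 2, 1);
        System.out.println("standard stem: " + n + " -> " + afterConv + " -> " + afterPool);

        // Modified stem used in this section: 3x3 conv (stride 1, padding 1), no max-pool
        int modified = outSize(n, 3, 1, 1);
        System.out.println("modified stem: " + n + " -> " + modified);
    }
}
```

On a 28x28 input, the standard stem would shrink the feature map to 7x7 (28 -> 14 -> 7) before the first residual stage. The modified stem keeps the full 28x28 resolution.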

class Residual extends AbstractBlock {

    private static final byte VERSION = 2;

    public ParallelBlock block;

    public Residual(int numChannels, boolean use1x1Conv, Shape strideShape) {
        super(VERSION);

        SequentialBlock b1;
        SequentialBlock conv1x1;

        b1 = new SequentialBlock();

        b1.add(Conv2d.builder()
                .setFilters(numChannels)
                .setKernelShape(new Shape(3, 3))
                .optPadding(new Shape(1, 1))
                .optStride(strideShape)
                .build())
                .add(BatchNorm.builder().build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setFilters(numChannels)
                        .setKernelShape(new Shape(3, 3))
                        .optPadding(new Shape(1, 1))
                        .build())
                .add(BatchNorm.builder().build());

        if (use1x1Conv) {
            conv1x1 = new SequentialBlock();
            conv1x1.add(Conv2d.builder()
                    .setFilters(numChannels)
                    .setKernelShape(new Shape(1, 1))
                    .optStride(strideShape)
                    .build());
        } else {
            conv1x1 = new SequentialBlock();
            conv1x1.add(Blocks.identityBlock());
        }

        block = addChildBlock("residualBlock", new ParallelBlock(
                list -> {
                    NDList unit = list.get(0);
                    NDList parallel = list.get(1);
                    return new NDList(
                            unit.singletonOrThrow()
                                    .add(parallel.singletonOrThrow())
                                    .getNDArrayInternal()
                                    .relu());
                },
                Arrays.asList(b1, conv1x1)));
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        return block.forward(parameterStore, inputs, training);
    }

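    @Override
    public Shape[] getOutputShapes(Shape[] inputs) {
        // Both branches of the ParallelBlock see the same input and produce the same shape,
        // so the output shape is that of the main path (b1). Chaining the two branches one
        // after the other would apply the stride-2 shortcut a second time.
        return block.getChildren().values().get(0).getOutputShapes(inputs);
    }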

    @Override
    protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        block.initialize(manager, dataType, inputShapes);
    }
}
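// Stack numResiduals residual units. In every stage except the first, the first unit
// halves the height and width (stride 2) and changes the channel count, so its
// shortcut needs a 1x1 convolution.
public SequentialBlock resnetBlock(int numChannels, int numResiduals, boolean isFirstBlock) {
    SequentialBlock blk = new SequentialBlock();
    for (int i = 0; i < numResiduals; i++) {
        if (i == 0 && !isFirstBlock) {
            blk.add(new Residual(numChannels, true, new Shape(2, 2)));
        } else {
            blk.add(new Residual(numChannels, false, new Shape(1, 1)));
        }
    }
    return blk;
}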

int numClass = 10;
// This model uses a smaller convolution kernel, stride, and padding and
// removes the maximum pooling layer
SequentialBlock net = new SequentialBlock();
net
    .add(
            Conv2d.builder()
                    .setFilters(64)
                    .setKernelShape(new Shape(3, 3))
                    .optPadding(new Shape(1, 1))
                    .build())
    .add(BatchNorm.builder().build())
    .add(Activation::relu)
    .add(resnetBlock(64, 2, true))
    .add(resnetBlock(128, 2, false))
    .add(resnetBlock(256, 2, false))
    .add(resnetBlock(512, 2, false))
    .add(Pool.globalAvgPool2dBlock())
    .add(Linear.builder().setUnits(numClass).build());
SequentialBlock {
    Conv2d
    BatchNorm
    LambdaBlock
    SequentialBlock {
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
    }
    SequentialBlock {
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    Conv2d
                            }
                    }
            }
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
    }
    SequentialBlock {
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    Conv2d
                            }
                    }
            }
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
    }
    SequentialBlock {
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    Conv2d
                            }
                    }
            }
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
    }
    globalAvgPool2d
    Linear
}
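The printed structure shows that only the first Residual of the second, third, and fourth stages has a Conv2d on its shortcut branch. The shape bookkeeping below is a hedged sketch of why. Those residuals halve the height and width (stride 2) and change the channel count, so the identity shortcut could no longer be added to the main path. The sketch assumes a 64-channel, 28x28 output from the modified stem. StageShapes and its methods are illustrative names, not part of DJL.

```java
// Illustrative shape trace through the four resnetBlock stages of the modified ResNet-18.
class StageShapes {
    // Same formula as for any convolution: floor((n + 2p - k) / s) + 1
    static int outSize(int n, int kernel, int stride, int padding) {
        return (n + 2 * padding - kernel) / stride + 1;
    }

    // Returns {channels, side} after each stage, starting from a side x side stem output.
    static int[][] trace(int side) {
        int[] channels = {64, 128, 256, 512};
        int[][] result = new int[channels.length][];
        for (int stage = 0; stage < channels.length; stage++) {
            // Only the first residual of stages 2-4 downsamples (stride 2);
            // all other 3x3 convs use stride 1 and padding 1 and keep the size.
            int stride = (stage == 0) ? 1 : 2;
            int mainPath = outSize(side, 3, stride, 1);      // first 3x3 conv on the main path
            int shortcut = (stage == 0)
                    ? side                                   // identity shortcut (64 -> 64 channels)
                    : outSize(side, 1, stride, 0);           // 1x1 conv shortcut, also fixes channels
            if (mainPath != shortcut) {
                throw new IllegalStateException("branches disagree at stage " + (stage + 1));
            }
            side = mainPath;
            result[stage] = new int[] {channels[stage], side};
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] stages = trace(28);
        for (int i = 0; i < stages.length; i++) {
            System.out.println("stage " + (i + 1) + ": " + stages[i][0] + " channels, "
                    + stages[i][1] + " x " + stages[i][1]);
        }
    }
}
```

The trace gives 64x28x28, 128x14x14, 256x7x7, and 512x4x4. Global average pooling then reduces each of the 512 channels to a single value for the final Linear layer.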

12.3.2. Parameter Initialization and Logistics

The optInitializer method of DefaultTrainingConfig allows us to set initial defaults for parameters, and optDevices selects the devices to train on. For a refresher see Section 4.8. What is particularly convenient is that this also lets us initialize the network on multiple devices simultaneously. Let us see how this works in practice.
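Conceptually, initializing on several devices means drawing the parameters once and giving each device an identical copy, so that all replicas start from the same point. The plain-Java sketch below illustrates this idea. It assumes a normal distribution with mean 0 and standard deviation 0.01, matching the NormalInitializer(0.01f) configured below. It is not how DJL implements initialization internally, and the "devices" are simulated with ordinary arrays.

```java
import java.util.Arrays;
import java.util.Random;

// Conceptual sketch (not DJL internals): initialize once, then replicate to every device.
class ReplicatedInit {
    // Draw `size` weights from N(0, sigma^2) with a fixed seed.
    static float[] normalInit(int size, float sigma, long seed) {
        Random rng = new Random(seed);
        float[] w = new float[size];
        for (int i = 0; i < size; i++) {
            w[i] = (float) (rng.nextGaussian() * sigma);
        }
        return w;
    }

    // One independent copy per simulated device; Arrays.copyOf stands in for a host-to-device transfer.
    static float[][] replicate(float[] weights, int numDevices) {
        float[][] replicas = new float[numDevices][];
        for (int d = 0; d < numDevices; d++) {
            replicas[d] = Arrays.copyOf(weights, weights.length);
        }
        return replicas;
    }

    public static void main(String[] args) {
        float[] w = normalInit(9, 0.01f, 42L); // e.g. one 3x3 kernel
        float[][] replicas = replicate(w, 2);
        System.out.println("replicas identical: " + Arrays.equals(replicas[0], replicas[1]));
    }
}
```

Each replica is a separate array, so each device can later compute on its own copy. They only stay identical if every device applies the same update.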

Model model = Model.newInstance("training-multiple-gpus-1");
model.setBlock(net);

Loss loss = Loss.softmaxCrossEntropyLoss();

Tracker lrt = Tracker.fixed(0.1f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // SGD optimizer with a fixed learning rate of 0.1
        .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // initialize weights from N(0, 0.01^2)
        .optDevices(Engine.getInstance().getDevices(1)) // use at most 1 GPU
        .addEvaluator(new Accuracy()) // model accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // logging

Trainer trainer = model.newTrainer(config);
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.070 ms.

In the previous section we wrote a split function to divide a minibatch of data and copy portions to each device. In DJL, the list of devices used by the trainer can be retrieved with trainer.getDevices(), and the data can be split with the help of the Batchifier class. The network automatically computes the forward propagation on the device where its input resides. As before, we generate 4 observations and split them into four slices. Since this trainer uses a single GPU, all of them reside on gpu(0).
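The split itself is simple bookkeeping. One possible policy is sketched below in plain Java: each device gets a contiguous slice whose size differs from the others by at most one example. The helper names are hypothetical, and this is not necessarily the rule that DJL's Batchifier follows.

```java
import java.util.Arrays;

// Hypothetical split policy: n examples into k contiguous slices of near-equal size.
class MinibatchSplit {
    static int[] sliceSizes(int n, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("need at least one device");
        }
        int[] sizes = new int[k];
        for (int i = 0; i < k; i++) {
            // the first n % k slices take one extra example
            sizes[i] = n / k + (i < n % k ? 1 : 0);
        }
        return sizes;
    }

    // Start index (inclusive) of each slice within the original minibatch.
    static int[] sliceStarts(int[] sizes) {
        int[] starts = new int[sizes.length];
        for (int i = 1; i < sizes.length; i++) {
            starts[i] = starts[i - 1] + sizes[i - 1];
        }
        return starts;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(sliceSizes(4, 4)));   // [1, 1, 1, 1]
        System.out.println(Arrays.toString(sliceSizes(512, 2))); // [256, 256]
        System.out.println(Arrays.toString(sliceSizes(10, 4)));  // [3, 3, 2, 2]
    }
}
```

Keeping the slices balanced matters because a synchronous step waits for the slowest device. A device that receives more examples than the others delays all of them.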

NDManager manager = NDManager.newBaseManager();
NDArray X = manager.randomUniform(0f, 1.0f, new Shape(4, 1, 28, 28));
trainer.initialize(X.getShape());

NDList[] res = Batchifier.STACK.split(new NDList(X), 4, true);

ParameterStore parameterStore = new ParameterStore(manager, true);

// Run the forward pass on each slice in evaluation mode (training = false)
for (NDList slice : res) {
    System.out.println(net.forward(parameterStore, slice, false).singletonOrThrow());
}
ND: (1, 10) gpu(0) float32
[[-2.53076792e-07,  2.19176854e-06, -2.05096558e-06, -2.80443487e-07, -1.65612937e-06,  5.92275399e-07, -4.38029275e-07,  1.43108821e-07,  1.86682854e-07,  8.35030505e-07],
]

ND: (1, 10) gpu(0) float32
[[-3.17955994e-07,  1.94063477e-06, -1.82914255e-06,  1.36083145e-09, -1.45861077e-06,  4.11562326e-07, -8.99586439e-07,  1.97685665e-07,  2.77768578e-07,  6.80656115e-07],
]

ND: (1, 10) gpu(0) float32
[[-1.82850158e-07,  2.26233874e-06, -2.24626365e-06,  8.68596715e-08, -1.29084265e-06,  9.33801005e-07, -1.04999901e-06,  1.76022922e-07,  3.97307645e-08,  9.49504113e-07],
]

ND: (1, 10) gpu(0) float32
[[-1.78178539e-07,  1.59132321e-06, -2.00916884e-06, -2.30666600e-07, -1.31331467e-06,  5.71873784e-07, -4.02916669e-07,  1.11762461e-07,  3.40592749e-07,  8.89963815e-07],
]

Once data passes through the network, the corresponding parameters are initialized on the device the data passed through. This means that initialization happens on a per-device basis.

net.getChildren().values().get(0).getParameters().get("weight").getArray().get(new NDIndex("0:1"));
ND: (1, 1, 3, 3) gpu(0) float32
[[[[ 0.0053, -0.0018, -0.0141],
   [-0.0094, -0.0146,  0.0094],
   [ 0.002 ,  0.0189,  0.0014],
  ],
 ],
]

12.3.3. Training

As before, the training code needs to perform a number of basic functions for efficient parallelism:

  • Network parameters need to be initialized across all devices.

  • While iterating over the dataset, minibatches are divided across all devices.

  • We compute the loss and its gradient in parallel across devices.

  • Gradients are aggregated across devices (by the trainer) and parameters are updated accordingly.

Finally, we compute the accuracy (again in parallel) to report the final performance of the network. The training routine is quite similar to the implementations in previous chapters, except that the data must be split across devices and the results aggregated; EasyTrain.fit splits each batch across the trainer's devices for us.
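The gradient aggregation step can be illustrated without any framework. The sketch below assumes a one-parameter linear model y ≈ w * x with squared loss and two simulated devices. Each device computes the mean gradient on its own slice. The per-device gradients are then combined with weights proportional to slice size, which is the role an allreduce plays on real hardware, and the same update is applied to every replica. All names are illustrative. In DJL this aggregation happens automatically during training.

```java
// Conceptual sketch of one synchronous data-parallel SGD step (not DJL internals).
class DataParallelStep {
    // Mean gradient of (w * x - y)^2 with respect to w over examples [from, to)
    static double sliceGradient(double w, double[] x, double[] y, int from, int to) {
        double g = 0;
        for (int i = from; i < to; i++) {
            g += 2 * (w * x[i] - y[i]) * x[i];
        }
        return g / (to - from);
    }

    // Combine per-device mean gradients, weighting each by its slice size,
    // which reproduces the mean gradient over the whole minibatch.
    static double aggregate(double[] grads, int[] sizes) {
        double sum = 0;
        int total = 0;
        for (int d = 0; d < grads.length; d++) {
            sum += grads[d] * sizes[d];
            total += sizes[d];
        }
        return sum / total;
    }

    public static void main(String[] args) {
        double[] x = {1, 2, 3, 4};
        double[] y = {2, 4, 6, 8}; // generated by w = 2
        double w = 0.5;
        double lr = 0.05;

        // Two simulated devices, two examples each
        int[] sizes = {2, 2};
        double[] grads = {
            sliceGradient(w, x, y, 0, 2), // device 0
            sliceGradient(w, x, y, 2, 4)  // device 1
        };
        double g = aggregate(grads, sizes);
        double full = sliceGradient(w, x, y, 0, 4);
        System.out.println("aggregated = " + g + ", full batch = " + full);

        // Every replica applies the same update, so all copies of w stay in sync
        w -= lr * g;
        System.out.println("updated w = " + w);
    }
}
```

The weighted average matches the full-minibatch gradient, so synchronized data-parallel training takes the same steps as single-device training on the same minibatches, up to floating-point rounding. This holds for models without per-device state. Batch normalization is an exception: it computes statistics on each device's slice.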

int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

double[] testAccuracy;
double[] epochCount;

epochCount = new double[numEpochs];

for (int i = 0; i < epochCount.length; i++) {
    epochCount[i] = (i + 1);
}

public void train(int numEpochs, Trainer trainer, int batchSize) throws IOException, TranslateException {

    FashionMnist trainIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TRAIN)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();
    FashionMnist testIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TEST)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();

    trainIter.prepare();
    testIter.prepare();

    Map<String, double[]> evaluatorMetrics = new HashMap<>();
    double avgTrainTime = 0;

    trainer.setMetrics(new Metrics());

    // EasyTrain.fit splits every batch across the trainer's devices (set via optDevices)
    EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

    Metrics metrics = trainer.getMetrics();

    trainer.getEvaluators().stream()
            .forEach(evaluator -> {
                evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());
            });

    avgTrainTime = metrics.mean("epoch");
    testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");
    System.out.printf("test acc %.2f\n" , testAccuracy[numEpochs-1]);
    System.out.println(avgTrainTime / Math.pow(10, 9) + " sec/epoch \n");
}

12.3.4. Experiments

Let us see how this works in practice. As a warmup we train the network on a single GPU.

Table data = null;
// We will check if we have at least 1 GPU available. If yes, we run the training on 1 GPU.
if (Engine.getInstance().getGpuCount() >= 1) {
    train(numEpochs, trainer, 256);

    data = Table.create("Data");
    data = data.addColumns(
            DoubleColumn.create("X", epochCount),
            DoubleColumn.create("testAccuracy", testAccuracy)
    );
}
Training:    100% |████████████████████████████████████████| Accuracy: 0.77, SoftmaxCrossEntropyLoss: 0.62
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.77, SoftmaxCrossEntropyLoss: 0.62
INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.43
Training:    100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26
INFO Validate: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28
Training:    100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.20
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.20
INFO Validate: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.27
Training:    100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
INFO Validate: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.35
Training:    100% |████████████████████████████████████████| Accuracy: 0.96, SoftmaxCrossEntropyLoss: 0.11
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.96, SoftmaxCrossEntropyLoss: 0.11
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.37
Training:    100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.28
Training:    100% |████████████████████████████████████████| Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.06
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.06
INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.51
Training:    100% |████████████████████████████████████████| Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.05
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.05
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.46
Training:    100% |████████████████████████████████████████| Accuracy: 0.99, SoftmaxCrossEntropyLoss: 0.04
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.99, SoftmaxCrossEntropyLoss: 0.04
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.37
test acc 0.91
19.9016687377 sec/epoch
// Uncomment to view the graph if you have 1 GPU; the render function does not work inside the scope of the if statement.
// render(LinePlot.create("", data, "x", "testAccuracy"), "text/html");
https://d2l-java-resources.s3.amazonaws.com/img/training-with-1-gpu.png

Fig. 12.3.1 Test accuracy per epoch when training on 1 GPU.

Next, we train the same network on two GPUs.
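For the two-GPU run that follows, the code doubles the batch size from 256 to 512 and the learning rate from 0.1 to 0.2. The arithmetic below shows what this means per device. Reading the learning-rate change as the common linear scaling heuristic (scale the learning rate with the total batch size) is an interpretation assumed here, not something this section states. The class and method names are illustrative.

```java
// Worked arithmetic for the batch-size and learning-rate settings of the two runs.
class ScalingSettings {
    // Examples each device processes per step when the batch is split evenly
    static int perDeviceBatch(int totalBatch, int numDevices) {
        return totalBatch / numDevices;
    }

    // Linear scaling heuristic (assumption): learning rate proportional to total batch size
    static double scaledLearningRate(double baseLr, int baseBatch, int newBatch) {
        return baseLr * newBatch / baseBatch;
    }

    public static void main(String[] args) {
        // 1 GPU: batch 256, lr 0.1; 2 GPUs: batch 512
        System.out.println("per-GPU batch, 1 GPU:  " + perDeviceBatch(256, 1));
        System.out.println("per-GPU batch, 2 GPUs: " + perDeviceBatch(512, 2));
        System.out.println("linearly scaled lr:    " + scaledLearningRate(0.1, 256, 512));
    }
}
```

With two GPUs, each device still processes 256 examples per step, as in the single-GPU run, while each optimizer step now covers twice as many examples.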

Table data = Table.create("Data");

// We will check if we have at least 2 GPUs available. If so, we run the training on 2 GPUs.
if (Engine.getInstance().getGpuCount() >= 2) {

    X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 28, 28));

    Model model = Model.newInstance("training-multiple-gpus-2");
    model.setBlock(net);

    loss = Loss.softmaxCrossEntropyLoss();

    Tracker lrt = Tracker.fixed(0.2f);
    Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

    DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
            .optOptimizer(sgd) // SGD optimizer with a fixed learning rate of 0.2
            .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // initialize weights from N(0, 0.01^2)
            .optDevices(Engine.getInstance().getDevices(2)) // use at most 2 GPUs
            .addEvaluator(new Accuracy()) // model accuracy
            .addTrainingListeners(TrainingListener.Defaults.logging()); // logging

    Trainer trainer = model.newTrainer(config);

    trainer.initialize(X.getShape());


    train(numEpochs, trainer, 512);

    data = data.addColumns(
        DoubleColumn.create("X", epochCount),
        DoubleColumn.create("testAccuracy", testAccuracy)
    );
}
INFO Training on: 2 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.020 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.57, SoftmaxCrossEntropyLoss: 1.28
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.57, SoftmaxCrossEntropyLoss: 1.26
INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.72
Training:    100% |████████████████████████████████████████| Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.45
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.45
INFO Validate: Accuracy: 0.71, SoftmaxCrossEntropyLoss: 0.81
Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33
INFO Validate: Accuracy: 0.80, SoftmaxCrossEntropyLoss: 0.61
Training:    100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28
INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.74
Training:    100% |████████████████████████████████████████| Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.43
Training:    100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.33
Training:    100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.19
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.19
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.49
Training:    100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.44
Training:    100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.15
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.15
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.36
Training:    100% |████████████████████████████████████████| Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.37
test acc 0.87
14.4473384254 sec/epoch
// Uncomment to view the graph if you have 2 GPUs; the render function does not work inside the scope of the if statement.
// render(LinePlot.create("", data, "x", "testAccuracy"), "text/html");
https://d2l-java-resources.s3.amazonaws.com/img/training-with-2-gpu.png

Fig. 12.3.2 Test accuracy per epoch when training on 2 GPUs.
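The logged timings give a rough estimate of the speedup and parallel efficiency. The sketch below simply divides the two reported sec/epoch figures (19.9016687377 for one GPU and 14.4473384254 for two). Treat the result as approximate: the runs also differ in batch size and learning rate, and the per-epoch time may include overhead beyond the gradient computation itself. The class name is illustrative.

```java
// Speedup and parallel efficiency computed from the sec/epoch values logged above.
class SpeedupFromLogs {
    static double speedup(double baselineSecPerEpoch, double parallelSecPerEpoch) {
        return baselineSecPerEpoch / parallelSecPerEpoch;
    }

    // Fraction of ideal linear scaling achieved with numDevices devices
    static double efficiency(double speedup, int numDevices) {
        return speedup / numDevices;
    }

    public static void main(String[] args) {
        double oneGpu = 19.9016687377;  // sec/epoch reported for the 1-GPU run
        double twoGpus = 14.4473384254; // sec/epoch reported for the 2-GPU run
        double s = speedup(oneGpu, twoGpus);
        System.out.printf("speedup %.2fx, parallel efficiency %.0f%%%n", s, 100 * efficiency(s, 2));
    }
}
```

With these numbers, the two GPUs deliver a speedup of roughly 1.38x, or about 69% parallel efficiency.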

12.3.5. Summary

  • Data is automatically evaluated on the devices where the data can be found.

  • Take care to initialize the networks on each device before trying to access the parameters on that device. Otherwise you will encounter an error.

  • The optimization algorithms automatically aggregate over multiple GPUs.

12.3.6. Exercises

  1. This section uses ResNet-18. Try different numbers of epochs, batch sizes, and learning rates. Use more GPUs for computation. What happens if you try this on a p2.16xlarge instance with 16 GPUs?

  2. Sometimes, different devices provide different computing power. We could use the GPUs and the CPU at the same time. How should we divide the work? Is it worth the effort? Why? Why not?
