Run this notebook online:\ |Binder| or Colab: |Colab| .. |Binder| image:: https://mybinder.org/badge_logo.svg :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_computational-performance/multiple-gpus-concise.ipynb .. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_computational-performance/multiple-gpus-concise.ipynb .. _sec_multi_gpu_concise: Concise Implementation for Multiple GPUs ======================================== Implementing parallelism from scratch for every new model is no fun. Moreover, there is significant benefit in optimizing synchronization tools for high performance. In the following we will show how to do this using DJL. The math and the algorithms are the same as in :numref:`sec_multi_gpu`. As before we begin by importing the required modules (quite unsurprisingly you will need at least two GPUs to run this notebook). .. code:: java %load ../utils/djl-imports %load ../utils/plot-utils %load ../utils/Training.java .. code:: java import ai.djl.basicdataset.cv.classification.*; import ai.djl.metric.*; import org.apache.commons.lang3.ArrayUtils; A Toy Network ------------- Let us use a slightly more meaningful network than LeNet from the previous section that's still sufficiently easy and quick to train. We pick a ResNet-18 variant :cite:`He.Zhang.Ren.ea.2016`. Since the input images are tiny we modify it slightly. In particular, the difference to :numref:`sec_resnet` is that we use a smaller convolution kernel, stride, and padding at the beginning. Moreover, we remove the max-pooling layer. .. code:: java class Residual extends AbstractBlock { private static final byte VERSION = 2; public ParallelBlock block; public Residual(int numChannels, boolean use1x1Conv, Shape strideShape) { super(VERSION); SequentialBlock b1; SequentialBlock conv1x1; b1 = new SequentialBlock(); b1.add(Conv2d.builder() .setFilters(numChannels) .setKernelShape(new Shape(3, 3)) .optPadding(new Shape(1, 1)) .optStride(strideShape) .build()) .add(BatchNorm.builder().build()) .add(Activation::relu) .add(Conv2d.builder() .setFilters(numChannels) .setKernelShape(new Shape(3, 3)) .optPadding(new Shape(1, 1)) .build()) .add(BatchNorm.builder().build()); if (use1x1Conv) { conv1x1 = new SequentialBlock(); conv1x1.add(Conv2d.builder() .setFilters(numChannels) .setKernelShape(new Shape(1, 1)) .optStride(strideShape) .build()); } else { conv1x1 = new SequentialBlock(); conv1x1.add(Blocks.identityBlock()); } block = addChildBlock("residualBlock", new ParallelBlock( list -> { NDList unit = list.get(0); NDList parallel = list.get(1); return new NDList( unit.singletonOrThrow() .add(parallel.singletonOrThrow()) .getNDArrayInternal() .relu()); }, Arrays.asList(b1, conv1x1))); } @Override protected NDList forwardInternal( ParameterStore parameterStore, NDList inputs, boolean training, PairList params) { return block.forward(parameterStore, inputs, training); } @Override public Shape[] getOutputShapes(Shape[] inputs) { Shape[] current = inputs; for (Block block : block.getChildren().values()) { current = block.getOutputShapes(current); } return current; } @Override protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) { block.initialize(manager, dataType, inputShapes); } } .. 
code:: java public SequentialBlock resnetBlock(int numChannels, int numResiduals, boolean isFirstBlock) { SequentialBlock blk = new SequentialBlock(); for (int i = 0; i < numResiduals; i++) { if (i == 0 && !isFirstBlock) { blk.add(new Residual(numChannels, true, new Shape(2, 2))); } else { blk.add(new Residual(numChannels, false, new Shape(1, 1))); } } return blk; } int numClass = 10; // This model uses a smaller convolution kernel, stride, and padding and // removes the maximum pooling layer SequentialBlock net = new SequentialBlock(); net .add( Conv2d.builder() .setFilters(64) .setKernelShape(new Shape(3, 3)) .optPadding(new Shape(1, 1)) .build()) .add(BatchNorm.builder().build()) .add(Activation::relu) .add(resnetBlock(64, 2, true)) .add(resnetBlock(128, 2, false)) .add(resnetBlock(256, 2, false)) .add(resnetBlock(512, 2, false)) .add(Pool.globalAvgPool2dBlock()) .add(Linear.builder().setUnits(numClass).build()); .. parsed-literal:: :class: output SequentialBlock { Conv2d BatchNorm LambdaBlock SequentialBlock { Residual { residualBlock { SequentialBlock { Conv2d BatchNorm LambdaBlock Conv2d BatchNorm } SequentialBlock { identity } } } Residual { residualBlock { SequentialBlock { Conv2d BatchNorm LambdaBlock Conv2d BatchNorm } SequentialBlock { identity } } } } SequentialBlock { Residual { residualBlock { SequentialBlock { Conv2d BatchNorm LambdaBlock Conv2d BatchNorm } SequentialBlock { Conv2d } } } Residual { residualBlock { SequentialBlock { Conv2d BatchNorm LambdaBlock Conv2d BatchNorm } SequentialBlock { identity } } } } SequentialBlock { Residual { residualBlock { SequentialBlock { Conv2d BatchNorm LambdaBlock Conv2d BatchNorm } SequentialBlock { Conv2d } } } Residual { residualBlock { SequentialBlock { Conv2d BatchNorm LambdaBlock Conv2d BatchNorm } SequentialBlock { identity } } } } SequentialBlock { Residual { residualBlock { SequentialBlock { Conv2d BatchNorm LambdaBlock Conv2d BatchNorm } SequentialBlock { Conv2d } } } Residual { residualBlock { SequentialBlock { Conv2d BatchNorm LambdaBlock Conv2d BatchNorm } SequentialBlock { identity } } } } globalAvgPool2d Linear } Parameter Initialization and Logistics -------------------------------------- The ``optInitializer`` method allows us to set initial defaults for parameters on a device of our choice. For a refresher, see :numref:`sec_numerical_stability`. What is particularly convenient is that it also lets us initialize the network on *multiple* devices simultaneously. Let us see how this works in practice. .. code:: java Model model = Model.newInstance("training-multiple-gpus-1"); model.setBlock(net); Loss loss = Loss.softmaxCrossEntropyLoss(); Tracker lrt = Tracker.fixed(0.1f); Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build(); DefaultTrainingConfig config = new DefaultTrainingConfig(loss) .optOptimizer(sgd) // Optimizer .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // setting the initializer .optDevices(Engine.getInstance().getDevices(1)) // setting the number of GPUs needed .addEvaluator(new Accuracy()) // Model Accuracy .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging Trainer trainer = model.newTrainer(config); .. parsed-literal:: :class: output INFO Training on: 1 GPUs. INFO Load MXNet Engine Version 1.9.0 in 0.070 ms. Using the ``split`` function from the previous section, we can divide a minibatch of data and copy portions to the list of devices, which can be retrieved from the ``Device`` array returned by ``Engine.getInstance().getDevices()``.
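To make the device bookkeeping concrete, the minimal sketch below shows one way to query the devices the engine exposes; the variable names are illustrative, while the ``Engine`` calls are the same ones used by the training configuration in this section.

.. code:: java

    // A minimal sketch (illustrative variable names): ask the engine which devices it can use.
    // getDevices(2) returns up to two GPUs and falls back to the CPU if none are available.
    Device[] devices = Engine.getInstance().getDevices(2);
    System.out.println("Training devices: " + Arrays.toString(devices));
    System.out.println("GPUs visible to the engine: " + Engine.getInstance().getGpuCount());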
We can then split the data with the help of the ``Batchifier`` class. The network object *automatically* uses the appropriate GPU to compute the value of the forward propagation. As before, we generate 4 observations and split them over the GPUs. .. code:: java NDManager manager = NDManager.newBaseManager(); NDArray X = manager.randomUniform(0f, 1.0f, new Shape(4, 1, 28, 28)); trainer.initialize(X.getShape()); NDList[] res = Batchifier.STACK.split(new NDList(X), 4, true); ParameterStore parameterStore = new ParameterStore(manager, true); System.out.println(net.forward(parameterStore, new NDList(res[0]), false).singletonOrThrow()); System.out.println(net.forward(parameterStore, new NDList(res[1]), false).singletonOrThrow()); System.out.println(net.forward(parameterStore, new NDList(res[2]), false).singletonOrThrow()); System.out.println(net.forward(parameterStore, new NDList(res[3]), false).singletonOrThrow()); .. parsed-literal:: :class: output ND: (1, 10) gpu(0) float32 [[-2.53076792e-07, 2.19176854e-06, -2.05096558e-06, -2.80443487e-07, -1.65612937e-06, 5.92275399e-07, -4.38029275e-07, 1.43108821e-07, 1.86682854e-07, 8.35030505e-07], ] ND: (1, 10) gpu(0) float32 [[-3.17955994e-07, 1.94063477e-06, -1.82914255e-06, 1.36083145e-09, -1.45861077e-06, 4.11562326e-07, -8.99586439e-07, 1.97685665e-07, 2.77768578e-07, 6.80656115e-07], ] ND: (1, 10) gpu(0) float32 [[-1.82850158e-07, 2.26233874e-06, -2.24626365e-06, 8.68596715e-08, -1.29084265e-06, 9.33801005e-07, -1.04999901e-06, 1.76022922e-07, 3.97307645e-08, 9.49504113e-07], ] ND: (1, 10) gpu(0) float32 [[-1.78178539e-07, 1.59132321e-06, -2.00916884e-06, -2.30666600e-07, -1.31331467e-06, 5.71873784e-07, -4.02916669e-07, 1.11762461e-07, 3.40592749e-07, 8.89963815e-07], ] Once data passes through the network, the corresponding parameters are initialized *on the device the data passed through*. This means that initialization happens on a per-device basis. .. code:: java net.getChildren().values().get(0).getParameters().get("weight").getArray().get(new NDIndex("0:1")); .. parsed-literal:: :class: output ND: (1, 1, 3, 3) gpu(0) float32 [[[[ 0.0053, -0.0018, -0.0141], [-0.0094, -0.0146, 0.0094], [ 0.002 , 0.0189, 0.0014], ], ], ] Training -------- As before, the training code needs to perform a number of basic functions for efficient parallelism: - Network parameters need to be initialized across all devices. - While iterating over the dataset, minibatches are divided across all devices. - We compute the loss and its gradient in parallel across devices. - Losses are aggregated (by the trainer method) and parameters are updated accordingly. In the end, we compute the accuracy (again in parallel) to report the final performance of the network. The training routine is quite similar to implementations in previous chapters, except that we need to split and aggregate data. .. code:: java int numEpochs = Integer.getInteger("MAX_EPOCH", 10); double[] testAccuracy; double[] epochCount; epochCount = new double[numEpochs]; for (int i = 0; i < epochCount.length; i++) { epochCount[i] = (i + 1); } Map<String, double[]> evaluatorMetrics = new HashMap<>(); double avgTrainTimePerEpoch = 0; ..
code:: java public void train(int numEpochs, Trainer trainer, int batchSize) throws IOException, TranslateException { FashionMnist trainIter = FashionMnist.builder() .optUsage(Dataset.Usage.TRAIN) .setSampling(batchSize, true) .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE)) .build(); FashionMnist testIter = FashionMnist.builder() .optUsage(Dataset.Usage.TEST) .setSampling(batchSize, true) .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE)) .build(); trainIter.prepare(); testIter.prepare(); Map evaluatorMetrics = new HashMap<>(); double avgTrainTime = 0; trainer.setMetrics(new Metrics()); EasyTrain.fit(trainer, numEpochs, trainIter, testIter); Metrics metrics = trainer.getMetrics(); trainer.getEvaluators().stream() .forEach(evaluator -> { evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream() .mapToDouble(x -> x.getValue().doubleValue()).toArray()); }); avgTrainTime = metrics.mean("epoch"); testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy"); System.out.printf("test acc %.2f\n" , testAccuracy[numEpochs-1]); System.out.println(avgTrainTime / Math.pow(10, 9) + " sec/epoch \n"); } Experiments ----------- Let us see how this works in practice. As a warmup we train the network on a single GPU. .. code:: java Table data = null; // We will check if we have at least 1 GPU available. If yes, we run the training on 1 GPU. if (Engine.getInstance().getGpuCount() >= 1) { train(numEpochs, trainer, 256); data = Table.create("Data"); data = data.addColumns( DoubleColumn.create("X", epochCount), DoubleColumn.create("testAccuracy", testAccuracy) ); } .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.77, SoftmaxCrossEntropyLoss: 0.62 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 1 finished. INFO Train: Accuracy: 0.77, SoftmaxCrossEntropyLoss: 0.62 INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.43 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 2 finished. INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26 INFO Validate: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.20 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 3 finished. INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.20 INFO Validate: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.27 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 4 finished. INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17 INFO Validate: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 5 finished. INFO Train: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14 INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.35 .. 
parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.96, SoftmaxCrossEntropyLoss: 0.11 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 6 finished. INFO Train: Accuracy: 0.96, SoftmaxCrossEntropyLoss: 0.11 INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.37 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 7 finished. INFO Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09 INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.28 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.06 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 8 finished. INFO Train: Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.06 INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.51 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.05 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 9 finished. INFO Train: Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.05 INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.46 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.99, SoftmaxCrossEntropyLoss: 0.04 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 10 finished. INFO Train: Accuracy: 0.99, SoftmaxCrossEntropyLoss: 0.04 INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.37 .. parsed-literal:: :class: output test acc 0.91 19.9016687377 sec/epoch .. code:: java // Uncomment to view the graph if you have at least 1 GPU; the render function does not work inside the scope of the if block above. // render(LinePlot.create("", data, "x", "testAccuracy"), "text/html"); .. figure:: https://d2l-java-resources.s3.amazonaws.com/img/training-with-1-gpu.png Test accuracy vs. epoch when training on a single GPU. .. code:: java Table data = Table.create("Data"); // We will check if we have at least 2 GPUs available. If yes, we run the training on 2 GPUs. if (Engine.getInstance().getGpuCount() >= 2) { X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 28, 28)); Model model = Model.newInstance("training-multiple-gpus-2"); model.setBlock(net); loss = Loss.softmaxCrossEntropyLoss(); Tracker lrt = Tracker.fixed(0.2f); Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build(); DefaultTrainingConfig config = new DefaultTrainingConfig(loss) .optOptimizer(sgd) // Optimizer .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // setting the initializer .optDevices(Engine.getInstance().getDevices(2)) // setting the number of GPUs needed .addEvaluator(new Accuracy()) // Model Accuracy .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging Trainer trainer = model.newTrainer(config); trainer.initialize(X.getShape()); Map<String, double[]> evaluatorMetrics = new HashMap<>(); double avgTrainTimePerEpoch = 0; train(numEpochs, trainer, 512); data = data.addColumns( DoubleColumn.create("X", epochCount), DoubleColumn.create("testAccuracy", testAccuracy) ); } .. parsed-literal:: :class: output INFO Training on: 2 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.020 ms. .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.57, SoftmaxCrossEntropyLoss: 1.28 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 1 finished. INFO Train: Accuracy: 0.57, SoftmaxCrossEntropyLoss: 1.26 INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.72 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.45 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 2 finished. INFO Train: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.45 INFO Validate: Accuracy: 0.71, SoftmaxCrossEntropyLoss: 0.81 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 3 finished. INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33 INFO Validate: Accuracy: 0.80, SoftmaxCrossEntropyLoss: 0.61 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 4 finished. INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28 INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.74 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 5 finished. INFO Train: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24 INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.43 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 6 finished. INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21 INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.33 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.19 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 7 finished. INFO Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.19 INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.49 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 8 finished. INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17 INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.44 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.15 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 9 finished. INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.15 INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.36 .. 
parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 10 finished. INFO Train: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14 INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.37 .. parsed-literal:: :class: output test acc 0.87 14.4473384254 sec/epoch .. code:: java // Uncomment to view the graph if you have at least 2 GPUs; the render function does not work inside the scope of the if block above. // render(LinePlot.create("", data, "x", "testAccuracy"), "text/html"); .. figure:: https://d2l-java-resources.s3.amazonaws.com/img/training-with-2-gpu.png Test accuracy vs. epoch when training on two GPUs. Summary ------- - Data is automatically evaluated on the devices where the data can be found. - Take care to initialize the networks on each device before trying to access the parameters on that device. Otherwise you will encounter an error. - The optimization algorithms automatically aggregate over multiple GPUs. Exercises --------- 1. This section uses ResNet-18. Try different epochs, batch sizes, and learning rates. Use more GPUs for computation. What happens if you try this on a p2.16xlarge instance with 16 GPUs? (A possible starting point is sketched after these exercises.) 2. Sometimes, different devices provide different computing power. We could use the GPUs and the CPU at the same time. How should we divide the work? Is it worth the effort? Why? Why not?
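For the first exercise, one possible starting point is sketched below. It is only a sketch under the assumption that you want to request every GPU the engine reports: the configuration mirrors the one used above, and scaling the learning rate and batch size with the device count is a common heuristic rather than something DJL prescribes.

.. code:: java

    // Hypothetical starting point for exercise 1: request every GPU the engine reports.
    int gpuCount = Math.max(1, Engine.getInstance().getGpuCount());
    Device[] allDevices = Engine.getInstance().getDevices(gpuCount);

    DefaultTrainingConfig multiGpuConfig = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
            // Heuristic: scale the fixed learning rate with the number of devices.
            .optOptimizer(Optimizer.sgd()
                    .setLearningRateTracker(Tracker.fixed(0.1f * gpuCount))
                    .build())
            .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT)
            .optDevices(allDevices)
            .addEvaluator(new Accuracy())
            .addTrainingListeners(TrainingListener.Defaults.logging());

    // Reuse the train(...) helper defined above, e.g. with a batch size that grows with gpuCount:
    // Model model = Model.newInstance("training-multiple-gpus-n");
    // model.setBlock(net);
    // Trainer trainer = model.newTrainer(multiGpuConfig);
    // trainer.initialize(new Shape(1, 1, 28, 28));
    // train(numEpochs, trainer, 256 * gpuCount);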