
11.5. Minibatch Stochastic Gradient Descent

So far we have encountered two extremes in the approach to gradient-based learning: Section 11.3 uses the full dataset to compute gradients and to update parameters, one pass at a time. Conversely, Section 11.4 processes one observation at a time to make progress. Each of them has its own drawbacks. Gradient descent is not particularly data efficient whenever the data are very similar, since it spends a full pass over many near-duplicate examples on a single update. Stochastic gradient descent is not particularly computationally efficient, since processing one example at a time prevents CPUs and GPUs from exploiting the full power of vectorization. This suggests that there might be a happy medium, and in fact, that is what we have been using so far in the examples we discussed.

11.5.1. Vectorization and Caches

At the heart of the decision to use minibatches is computational efficiency. This is most easily understood when considering parallelization to multiple GPUs and multiple servers. In this case we need to send at least one image to each GPU. With 8 GPUs per server and 16 servers we already arrive at a minibatch size of 128.

Things are a bit more subtle when it comes to single GPUs or even CPUs. These devices have multiple types of memory, often multiple types of compute units, and different bandwidth constraints between them. For instance, a CPU has a small number of registers and then L1, L2 and in some cases even L3 cache (which is shared between the different processor cores). These caches are of increasing size and latency (and at the same time they are of decreasing bandwidth). Suffice it to say, the processor is capable of performing many more operations than what the main memory interface is able to provide.

  • A 2GHz CPU with 16 cores and AVX-512 vectorization can process up to \(2 \cdot 10^9 \cdot 16 \cdot 32 \approx 10^{12}\) bytes per second. The capability of GPUs easily exceeds this number by a factor of 100. On the other hand, a midrange server processor might not have much more than 100 GB/s bandwidth, i.e., less than one tenth of what would be required to keep the processor fed. To make matters worse, not all memory access is created equal: first, memory interfaces are typically 64 bits wide or wider (e.g., up to 384 bits on GPUs), hence reading a single byte incurs the cost of a much wider access.

  • Second, there is significant overhead for the first access, whereas sequential access is relatively cheap (this is often called a burst read). There are many more things to keep in mind, such as caching when we have multiple sockets, chiplets and other structures. A detailed discussion of this is beyond the scope of this section. See, e.g., this Wikipedia article for a more in-depth discussion.
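The back-of-the-envelope figures in the first bullet can be checked with a few lines of Java. This is only a sketch of the arithmetic: the clock rate, core count, 32 bytes per core per cycle, the 100 GB/s bandwidth and the interface widths are the illustrative values quoted above, not measurements of any particular machine, and the class name is hypothetical.

```java
// Back-of-the-envelope check of the compute-vs-memory figures quoted in the text.
// All constants are the illustrative values from the bullet list, not measurements.
class BandwidthEstimate {
    static final double CLOCK_HZ = 2e9;       // 2 GHz
    static final int CORES = 16;
    static final int BYTES_PER_CYCLE = 32;    // per core, as assumed in the text
    static final double MEMORY_BW = 100e9;    // 100 GB/s main-memory bandwidth

    // Bytes per second the cores could consume if every cycle were fully used
    static double peakBytesPerSecond() {
        return CLOCK_HZ * CORES * BYTES_PER_CYCLE;
    }

    // Fraction of that demand the memory interface can actually supply
    static double fractionSupplied() {
        return MEMORY_BW / peakBytesPerSecond();
    }

    // Bytes transferred when reading a single byte over an interface of the given width
    static int bytesMovedPerSingleByteRead(int interfaceBits) {
        return interfaceBits / 8;
    }

    public static void main(String[] args) {
        System.out.printf("peak demand: %.3e bytes/s%n", peakBytesPerSecond());
        System.out.printf("memory supplies %.1f%% of it%n", 100 * fractionSupplied());
        System.out.println("64-bit interface moves " + bytesMovedPerSingleByteRead(64) + " bytes per 1-byte read");
        System.out.println("384-bit interface moves " + bytesMovedPerSingleByteRead(384) + " bytes per 1-byte read");
    }
}
```

The peak demand comes out at about \(1.02 \cdot 10^{12}\) bytes per second, so 100 GB/s covers just under 10% of it, which is the "less than one tenth" stated above; a one-byte read over a 64-bit or 384-bit interface moves 8 or 48 bytes respectively.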

The way to alleviate these constraints is to use a hierarchy of CPU caches which are actually fast enough to supply the processor with data. This is the driving force behind batching in deep learning. To keep matters simple, consider matrix-matrix multiplication, say \(\mathbf{A} = \mathbf{B}\mathbf{C}\). We have a number of options for calculating \(\mathbf{A}\). For instance we could try the following:

  1. We could compute \(\mathbf{A}_{ij} = \mathbf{B}_{i,:} \mathbf{C}_{:,j}^\top\), i.e., we could compute it element-wise by means of dot products.

  2. We could compute \(\mathbf{A}_{:,j} = \mathbf{B} \mathbf{C}_{:,j}^\top\), i.e., we could compute it one column at a time. Likewise we could compute \(\mathbf{A}\) one row \(\mathbf{A}_{i,:}\) at a time.

  3. We could simply compute \(\mathbf{A} = \mathbf{B} \mathbf{C}\).

  4. We could break \(\mathbf{B}\) and \(\mathbf{C}\) into smaller block matrices and compute \(\mathbf{A}\) one block at a time.
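Options 1 and 4 can be contrasted in plain Java without any framework. This is a sketch for illustration only: it is not how DJL or its native backend implement matrix multiplication, the class and method names are hypothetical, and the tile size of 32 is an arbitrary choice rather than a tuned value. Matrices are stored row-major in flat arrays, so element \((i, j)\) of an \(n \times n\) matrix sits at index \(i \cdot n + j\); a row is contiguous in memory, while consecutive entries of a column are \(n\) elements apart.

```java
import java.util.Random;

// Plain-Java sketch of option 1 (one dot product per element) and option 4 (blocked product).
class BlockedMatMul {
    // Option 1: A[i][j] is the dot product of row i of B and column j of C
    static double[] naive(double[] B, double[] C, int n) {
        double[] A = new double[n * n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                double s = 0.0;
                for (int k = 0; k < n; k++) {
                    s += B[i * n + k] * C[k * n + j]; // C is read with stride n
                }
                A[i * n + j] = s;
            }
        }
        return A;
    }

    // Option 4: work on tile x tile blocks so the active parts of A, B and C stay small
    static double[] blocked(double[] B, double[] C, int n, int tile) {
        double[] A = new double[n * n];
        for (int ii = 0; ii < n; ii += tile) {
            for (int kk = 0; kk < n; kk += tile) {
                for (int jj = 0; jj < n; jj += tile) {
                    int iMax = Math.min(ii + tile, n), kMax = Math.min(kk + tile, n), jMax = Math.min(jj + tile, n);
                    for (int i = ii; i < iMax; i++) {
                        for (int k = kk; k < kMax; k++) {
                            double b = B[i * n + k];
                            for (int j = jj; j < jMax; j++) {
                                A[i * n + j] += b * C[k * n + j]; // contiguous reads of C, writes of A
                            }
                        }
                    }
                }
            }
        }
        return A;
    }

    static double[] random(int n, Random rng) {
        double[] M = new double[n * n];
        for (int i = 0; i < M.length; i++) M[i] = rng.nextGaussian();
        return M;
    }

    public static void main(String[] args) {
        int n = 256;
        Random rng = new Random(0);
        double[] B = random(n, rng), C = random(n, rng);
        long t0 = System.nanoTime();
        double[] A1 = naive(B, C, n);
        long t1 = System.nanoTime();
        double[] A2 = blocked(B, C, n, 32);
        long t2 = System.nanoTime();
        double maxDiff = 0.0;
        for (int i = 0; i < A1.length; i++) maxDiff = Math.max(maxDiff, Math.abs(A1[i] - A2[i]));
        System.out.printf("naive %.1f ms, blocked %.1f ms, max difference %.2e%n",
                (t1 - t0) / 1e6, (t2 - t1) / 1e6, maxDiff);
    }
}
```

Both functions add the products for each \(\mathbf{A}_{ij}\) in the same order of \(k\), so their results agree exactly. The printed timings depend on the machine and on JIT warm-up, and part of any speedup of the blocked version comes from its i-k-j loop order, which reads \(\mathbf{C}\) contiguously, rather than from tiling alone.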

If we follow the first option, we will need to copy one row and one column vector into the CPU each time we want to compute an element \(\mathbf{A}_{ij}\). Even worse, since matrix elements are stored sequentially, reading one of the two vectors (the column, for a row-major layout) requires accessing many disjoint memory locations. The second option is much more favorable. In it, we are able to keep the column vector \(\mathbf{C}_{:,j}\) in the CPU cache while we keep on traversing through \(\mathbf{B}\). This halves the memory bandwidth requirement with correspondingly faster access. Of course, option 3 is most desirable. Unfortunately, most matrices might not entirely fit into cache (this is what we are discussing after all). However, option 4 offers a practically useful alternative: we can move blocks of the matrix into cache and multiply them locally. Optimized libraries take care of this for us. Let us have a look at how efficient these operations are in practice.

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/StopWatch.java
%load ../utils/Training.java
%load ../utils/Accumulator.java
import ai.djl.basicdataset.tabular.*;
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;
NDManager manager = NDManager.newBaseManager();
StopWatch stopWatch = new StopWatch();
NDArray A = manager.zeros(new Shape(256, 256));
NDArray B = manager.randomNormal(new Shape(256, 256));
NDArray C = manager.randomNormal(new Shape(256, 256));

Element-wise assignment simply iterates over all rows and columns of \(\mathbf{B}\) and \(\mathbf{C}\) respectively to assign the value to \(\mathbf{A}\).

// Compute A = B C one element at a time
stopWatch.start();
for (int i = 0; i < 256; i++) {
    for (int j = 0; j < 256; j++) {
        A.set(new NDIndex(i, j),
              B.get(new NDIndex(String.format("%d, :", i)))
              .dot(C.get(new NDIndex(String.format(":, %d", j)))));
    }
}
stopWatch.stop();
41.596357294

A faster strategy is to perform column-wise assignment.

// Compute A = B C one column at a time
stopWatch.start();
for (int j = 0; j < 256; j++) {
    A.set(new NDIndex(String.format(":, %d", j)), B.dot(C.get(new NDIndex(String.format(":, %d", j)))));
}
stopWatch.stop();
0.172290598

Last, the most effective manner is to perform the entire operation in one block. Let us see what the respective speed of the operations is.

// Compute A = B C in one go
stopWatch.start();
A = B.dot(C);
stopWatch.stop();

// Multiply and add count as separate operations (fused in practice)
float[] gigaflops = new float[stopWatch.getTimes().size()];
for (int i = 0; i < stopWatch.getTimes().size(); i++) {
    gigaflops[i] = (float)(2 / stopWatch.getTimes().get(i));
}
String.format("Performance in Gigaflops: element %.3f, column %.3f, full %.3f", gigaflops[0], gigaflops[1], gigaflops[2]);
Performance in Gigaflops: element 0.048, column 11.608, full 65.164
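For reference, multiplying two \(n \times n\) matrices takes \(n^3\) multiplications and \(n^2(n-1)\) additions, i.e., about \(2n^3\) floating-point operations; for \(n = 256\) that is \(2 \cdot 256^3 \approx 3.4 \cdot 10^7\). The expression 2 / time in the cell above therefore corresponds to a budget of \(2 \cdot 10^9\) operations, so its values are best read as relative speeds of the three strategies rather than absolute GFLOP/s. A small helper that uses the \(2n^3\) count could look like this (the class and method names are illustrative, not part of DJL or the d2l utilities):

```java
// Convert a measured wall-clock time for an n x n by n x n product into GFLOP/s,
// counting one multiply and one add per inner-loop step (about 2 * n^3 operations).
class MatMulFlops {
    static double flops(int n) {
        return 2.0 * n * n * n;
    }

    static double gigaflops(int n, double seconds) {
        return flops(n) / seconds / 1e9;
    }

    public static void main(String[] args) {
        int n = 256;
        System.out.printf("operations for n = %d: %.0f%n", n, flops(n));
        // Example: a product that took 0.01 s would run at about 3.36 GFLOP/s
        System.out.printf("GFLOP/s at 0.01 s: %.2f%n", gigaflops(n, 0.01));
    }
}
```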

11.5.2. Minibatches

In the past we took it for granted that we would read minibatches of data rather than single observations to update parameters. We now give a brief justification for it. Processing single observations requires us to perform many single matrix-vector (or even vector-vector) multiplications, which is quite expensive and which incurs a significant overhead on behalf of the underlying deep learning framework. This applies both to evaluating a network when applied to data (often referred to as inference) and when computing gradients to update parameters. That is, this applies whenever we perform \(\mathbf{w} \leftarrow \mathbf{w} - \eta_t \mathbf{g}_t\) where

(11.5.1)\[\mathbf{g}_t = \partial_{\mathbf{w}} f(\mathbf{x}_{t}, \mathbf{w})\]

We can increase the computational efficiency of this operation by applying it to a minibatch of observations at a time. That is, we replace the gradient \(\mathbf{g}_t\) over a single observation by one over a small batch

(11.5.2)\[\mathbf{g}_t = \partial_{\mathbf{w}} \frac{1}{|\mathcal{B}_t|} \sum_{i \in \mathcal{B}_t} f(\mathbf{x}_{i}, \mathbf{w})\]

Let us see what this does to the statistical properties of \(\mathbf{g}_t\): since both \(\mathbf{x}_t\) and also all elements of the minibatch \(\mathcal{B}_t\) are drawn uniformly at random from the training set, the expectation of the gradient remains unchanged. The variance, on the other hand, is reduced significantly. Since the minibatch gradient is composed of \(b := |\mathcal{B}_t|\) independent gradients which are being averaged, its standard deviation is reduced by a factor of \(b^{-\frac{1}{2}}\). This, by itself, is a good thing, since it means that the updates are more reliably aligned with the full gradient.
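To see the \(b^{-\frac{1}{2}}\) scaling numerically, the sketch below treats a fixed array of numbers as per-example gradients of a single scalar parameter, draws minibatches of size \(b\) uniformly at random with replacement (so the \(b\) draws are independent, as the argument assumes), and measures the standard deviation of the minibatch average over many trials. The Gaussian population, its size, the number of trials and the class name are arbitrary choices for illustration.

```java
import java.util.Random;

// Empirical check that averaging b independent gradients shrinks the standard deviation by b^(-1/2).
class MinibatchVariance {
    // Standard deviation of the minibatch mean over `trials` independent minibatches
    static double stdOfMinibatchMean(double[] grads, int b, int trials, Random rng) {
        double sum = 0.0, sumSq = 0.0;
        for (int t = 0; t < trials; t++) {
            double mean = 0.0;
            for (int k = 0; k < b; k++) {
                mean += grads[rng.nextInt(grads.length)]; // sample with replacement
            }
            mean /= b;
            sum += mean;
            sumSq += mean * mean;
        }
        double mu = sum / trials;
        return Math.sqrt(sumSq / trials - mu * mu);
    }

    // Standard deviation of the whole population of per-example gradients
    static double populationStd(double[] x) {
        double mu = 0.0;
        for (double v : x) mu += v;
        mu /= x.length;
        double var = 0.0;
        for (double v : x) var += (v - mu) * (v - mu);
        return Math.sqrt(var / x.length);
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        double[] grads = new double[10_000];
        for (int i = 0; i < grads.length; i++) grads[i] = 1.0 + 2.0 * rng.nextGaussian();
        double sigma = populationStd(grads);
        for (int b : new int[]{1, 4, 16, 64}) {
            double measured = stdOfMinibatchMean(grads, b, 20_000, rng);
            System.out.printf("b = %2d  measured std %.4f  sigma/sqrt(b) %.4f%n",
                    b, measured, sigma / Math.sqrt(b));
        }
    }
}
```

Sampling without replacement, as a shuffled pass over the data does, gives a slightly smaller spread, by the finite-population correction \(\sqrt{(N-b)/(N-1)}\) for a population of size \(N\), which is negligible when \(b\) is much smaller than \(N\).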

Naively this would indicate that choosing a large minibatch \(\mathcal{B}_t\) would be universally desirable. Alas, after some point, the additional reduction in standard deviation is minimal when compared to the linear increase in computational cost. In practice we pick a minibatch that is large enough to offer good computational efficiency while still fitting into the memory of a GPU. To illustrate the savings let us have a look at some code. In it we perform the same matrix-matrix multiplication, but this time broken up into “minibatches” of 64 columns at a time.

stopWatch.start();
for (int j = 0; j < 256; j+=64) {
    A.set(new NDIndex(String.format(":, %d:%d", j, j + 64)),
        B.dot(C.get(new NDIndex(String.format(":, %d:%d", j, j + 64)))));
}
stopWatch.stop();

String.format("Performance in Gigaflops: block %.3f\n", 2 / stopWatch.getTimes().get(3));
Performance in Gigaflops: block 50.857

As we can see, the computation on the minibatch is essentially as efficient as on the full matrix. A word of caution is in order. In Section 7.5 we used a type of regularization (batch normalization) that was heavily dependent on the amount of variance in a minibatch. As we increase the minibatch size, the variance decreases and with it the benefit of the noise injection due to batch normalization. See e.g., [Ioffe, 2017] for details on how to rescale and compute the appropriate terms.

11.5.3. Reading the Dataset

Let us have a look at how minibatches are efficiently generated from data. To compare the optimization algorithms, we use a dataset developed by NASA to test the wing noise from different aircraft. For convenience we only use the first \(1,500\) examples. The data are whitened as a preprocessing step, i.e., we remove the mean and rescale the variance to \(1\) per coordinate.

NDManager manager = NDManager.newBaseManager();

public AirfoilRandomAccess getDataCh11(int batchSize, int n) throws IOException, TranslateException {
    // Load data
    AirfoilRandomAccess airfoil = AirfoilRandomAccess.builder()
            .optUsage(Dataset.Usage.TRAIN)
            .setSampling(batchSize, true)
            .optNormalize(true)
            .optLimit(n)
            .build();
    return airfoil;
}
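The per-coordinate whitening described above can be sketched on a plain array. We have not checked exactly which transform optNormalize(true) applies inside AirfoilRandomAccess, so this is an illustration of the operation described in the text, not a copy of DJL's code. It assumes the population variance (dividing by \(n\)); some implementations divide by \(n - 1\) instead. The class name and the toy values in main are made up, not rows of the airfoil dataset.

```java
// Per-column standardization: subtract the mean and divide by the standard deviation
// so that every feature ends up with mean 0 and variance 1.
class Standardize {
    static double[][] standardize(double[][] X) {
        int n = X.length, d = X[0].length;
        double[][] Z = new double[n][d];
        for (int j = 0; j < d; j++) {
            double mean = 0.0;
            for (int i = 0; i < n; i++) mean += X[i][j];
            mean /= n;
            double var = 0.0;
            for (int i = 0; i < n; i++) var += (X[i][j] - mean) * (X[i][j] - mean);
            double std = Math.sqrt(var / n); // population standard deviation
            // A constant column is centered but left unscaled to avoid dividing by zero
            double scale = std > 0 ? std : 1.0;
            for (int i = 0; i < n; i++) Z[i][j] = (X[i][j] - mean) / scale;
        }
        return Z;
    }

    public static void main(String[] args) {
        double[][] X = {{800, 0.0}, {1000, 2.0}, {1250, 4.0}, {1600, 6.0}}; // toy values
        for (double[] row : standardize(X)) System.out.printf("%8.4f %8.4f%n", row[0], row[1]);
    }
}
```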

11.5.4. Implementation from Scratch

Recall the minibatch SGD implementation from Section 3.2. In the following we provide a slightly more general implementation. For convenience it has the same call signature as the other optimization algorithms introduced later in this chapter. Specifically, we add the state input states and place the hyperparameters in the map hyperparams. In addition, we average the loss over the examples of each minibatch in the training function, so the gradient in the optimization algorithm does not need to be divided by the batch size.

public class Optimization {
    public static void sgd(NDList params, NDList states, Map<String, Float> hyperparams) {
        for (int i = 0; i < params.size(); i++) {
            NDArray param = params.get(i);
            // Update param
            // param = param - param.gradient * lr
            param.subi(param.getGradient().mul(hyperparams.get("lr")));
        }
    }
}

Next, we implement a generic training function to facilitate the use of the other optimization algorithms introduced later in this chapter. It initializes a linear regression model and can be used to train the model with minibatch SGD and other algorithms introduced subsequently.

public static float evaluateLoss(Iterable<Batch> dataIterator, NDArray w, NDArray b) {
    Accumulator metric = new Accumulator(2);  // sumLoss, numExamples

    for (Batch batch : dataIterator) {
        NDArray X = batch.getData().head();
        NDArray y = batch.getLabels().head();
        NDArray yHat = Training.linreg(X, w, b);
        float lossSum = Training.squaredLoss(yHat, y).sum().getFloat();

        metric.add(new float[]{lossSum, (float) y.size()});
        batch.close();
    }
    return metric.get(0) / metric.get(1);
}
public static class LossTime {
    public float[] loss;
    public float[] time;

    public LossTime(float[] loss, float[] time) {
        this.loss = loss;
        this.time = time;
    }
}
public void plotLossEpoch(float[] loss, float[] epoch) {
    Table data = Table.create("data")
        .addColumns(
            DoubleColumn.create("epoch", Functions.floatToDoubleArray(epoch)),
            DoubleColumn.create("loss", Functions.floatToDoubleArray(loss))
    );
    display(LinePlot.create("loss vs. epoch", data, "epoch", "loss"));
}
public float[] arrayListToFloat (ArrayList<Double> arrayList) {
    float[] ret = new float[arrayList.size()];

    for (int i = 0; i < arrayList.size(); i++) {
        ret[i] = arrayList.get(i).floatValue();
    }
    return ret;
}

@FunctionalInterface
public static interface TrainerConsumer {
    void train(NDList params, NDList states, Map<String, Float> hyperparams);

}

public static LossTime trainCh11(TrainerConsumer trainer, NDList states, Map<String, Float> hyperparams,
                                AirfoilRandomAccess dataset,
                                int featureDim, int numEpochs) throws IOException, TranslateException {
    NDManager manager = NDManager.newBaseManager();
    NDArray w = manager.randomNormal(0, 0.01f, new Shape(featureDim, 1), DataType.FLOAT32);
    NDArray b = manager.zeros(new Shape(1));

    w.setRequiresGradient(true);
    b.setRequiresGradient(true);

    NDList params = new NDList(w, b);
    int n = 0;
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();

    float lastLoss = -1;
    ArrayList<Double> loss = new ArrayList<>();
    ArrayList<Double> epoch = new ArrayList<>();

    for (int i = 0; i < numEpochs; i++) {
        for (Batch batch : dataset.getData(manager)) {
            int len = (int) dataset.size() / batch.getSize();  // number of batches
            NDArray X = batch.getData().head();
            NDArray y = batch.getLabels().head();

            NDArray l;
            try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
                NDArray yHat = Training.linreg(X, params.get(0), params.get(1));
                l = Training.squaredLoss(yHat, y).mean();
                gc.backward(l);
            }

            trainer.train(params, states, hyperparams);
            n += X.getShape().get(0);

            if (n % 200 == 0) {
                stopWatch.stop();
                lastLoss = evaluateLoss(dataset.getData(manager), params.get(0), params.get(1));
                loss.add((double) lastLoss);
                double lastEpoch = 1.0 * n / X.getShape().get(0) / len;
                epoch.add(lastEpoch);
                stopWatch.start();
            }

            batch.close();
        }
    }
    float[] lossArray = arrayListToFloat(loss);
    float[] epochArray = arrayListToFloat(epoch);
    plotLossEpoch(lossArray, epochArray);
    System.out.printf("loss: %.3f, %.3f sec/epoch\n", lastLoss, stopWatch.avg());
    float[] timeArray = arrayListToFloat(stopWatch.cumsum());
    return new LossTime(lossArray, timeArray);
}
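The loop in trainCh11 can also be written without DJL, which makes the role of the batch size explicit. The sketch below is a hypothetical, dependency-free re-implementation for linear regression with the halved squared loss; the synthetic data (true weights 2 and -3.4, bias 4.2, noise standard deviation 0.01), the learning rates and the class name are assumptions for illustration. Like trainCh11, it averages the per-example gradients over the minibatch, so each update is \(\mathbf{w} \leftarrow \mathbf{w} - \eta\) times the mean gradient. It shuffles explicitly at the start of each epoch so that every example is used exactly once per epoch.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class PlainMinibatchSgd {
    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int j = 0; j < a.length; j++) s += a[j] * b[j];
        return s;
    }

    // Mean of 0.5 * (x.w + b - y)^2 over the whole dataset
    static double meanLoss(double[][] X, double[] y, double[] w, double b) {
        double sum = 0.0;
        for (int i = 0; i < X.length; i++) {
            double err = dot(X[i], w) + b - y[i];
            sum += 0.5 * err * err;
        }
        return sum / X.length;
    }

    // Minibatch SGD; returns {w_0, ..., w_{d-1}, b}
    static double[] train(double[][] X, double[] y, int batchSize, double lr, int numEpochs, long seed) {
        int n = X.length, d = X[0].length;
        double[] w = new double[d];
        double b = 0.0;
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < n; i++) order.add(i);
        Random rng = new Random(seed);
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            Collections.shuffle(order, rng);              // random permutation per epoch
            for (int start = 0; start < n; start += batchSize) {
                int end = Math.min(start + batchSize, n);  // last batch may be smaller
                double[] gw = new double[d];
                double gb = 0.0;
                for (int k = start; k < end; k++) {
                    int i = order.get(k);
                    double err = dot(X[i], w) + b - y[i];  // d(0.5 * err^2) / d(prediction)
                    for (int j = 0; j < d; j++) gw[j] += err * X[i][j];
                    gb += err;
                }
                int m = end - start;
                for (int j = 0; j < d; j++) w[j] -= lr * gw[j] / m;  // step along the averaged gradient
                b -= lr * gb / m;
            }
        }
        double[] out = Arrays.copyOf(w, d + 1);
        out[d] = b;
        return out;
    }

    // Synthetic data y = 2 x_0 - 3.4 x_1 + 4.2 + noise; fills yOut and returns X
    static double[][] makeData(int n, long seed, double[] yOut) {
        Random rng = new Random(seed);
        double[][] X = new double[n][2];
        for (int i = 0; i < n; i++) {
            X[i][0] = rng.nextGaussian();
            X[i][1] = rng.nextGaussian();
            yOut[i] = 2.0 * X[i][0] - 3.4 * X[i][1] + 4.2 + 0.01 * rng.nextGaussian();
        }
        return X;
    }

    public static void main(String[] args) {
        int n = 1000;
        double[] y = new double[n];
        double[][] X = makeData(n, 0L, y);
        for (int batchSize : new int[]{n, 100, 10, 1}) {
            double lr = batchSize == 1 ? 0.005 : 0.05;
            double[] p = train(X, y, batchSize, lr, 2, 1L);
            System.out.printf("batch %4d: w = [%.3f, %.3f], b = %.3f, loss = %.5f%n",
                    batchSize, p[0], p[1], p[2], meanLoss(X, y, new double[]{p[0], p[1]}, p[2]));
        }
    }
}
```

With the full batch, two epochs mean only two updates, so that run stays far from the true parameters, mirroring the stalled gradient descent run in the text.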

Let us see how optimization proceeds for batch gradient descent. This can be achieved by setting the minibatch size to 1500 (i.e., to the total number of examples). As a result the model parameters are updated only once per epoch. There is little progress. In fact, after 6 steps progress stalls.

public static LossTime trainSgd(float lr, int batchSize, int numEpochs) throws IOException, TranslateException {
    AirfoilRandomAccess dataset = getDataCh11(batchSize, 1500);
    int featureDim = dataset.getColumnNames().size();

    Map<String, Float> hyperparams = new HashMap<>();
    hyperparams.put("lr", lr);

    return trainCh11(Optimization::sgd, new NDList(), hyperparams, dataset, featureDim, numEpochs);
}

LossTime gdRes = trainSgd(1f, 1500, 10);
loss: 0.251, 0.675 sec/epoch

When the batch size equals 1, we use SGD for optimization. For simplicity of implementation we picked a constant (albeit small) learning rate. In SGD, the model parameters are updated whenever an example is processed. In our case this amounts to 1500 updates per epoch. As we can see, the decline in the value of the objective function slows down after one epoch. Although both procedures process 1500 examples within one epoch, SGD consumes more time per epoch than gradient descent in our experiment. This is because SGD updates the parameters more frequently and because it is less efficient to process single observations one at a time. Note that the printed "sec/epoch" figure is the average time between successive loss evaluations: trainCh11 evaluates the loss every 200 examples for SGD but only every two epochs for full-batch gradient descent, so the two printed numbers are not directly comparable.

LossTime sgdRes = trainSgd(0.005f, 1, 2);
loss: 0.244, 0.267 sec/epoch

Last, when the batch size equals 100, we use minibatch SGD for optimization. The time required per epoch is shorter than the time needed for SGD and the time for batch gradient descent.

LossTime mini1Res = trainSgd(0.4f, 100, 2);
loss: 0.249, 0.044 sec/epoch

When we reduce the batch size to 10, the time for each epoch increases because the workload for each batch is less efficient to execute.

LossTime mini2Res = trainSgd(0.05f, 10, 2);
loss: 0.245, 0.062 sec/epoch

Finally, we compare the time versus loss for the previous four experiments. As can be seen, although SGD converges faster than GD in terms of the number of examples processed, it uses more time than GD to reach the same loss, because computing the gradient example by example is not efficient. Minibatch SGD is able to trade off convergence speed and computational efficiency. A minibatch size of 10 is more efficient than SGD; a minibatch size of 100 even outperforms GD in terms of runtime.

public String[] getTypeArray(LossTime lossTime, String name) {
    String[] type = new String[lossTime.time.length];
    for (int i = 0; i < type.length; i++) {
        type[i] = name;
    }
    return type;
}
// Converts a float array to a log scale
float[] convertLogScale(float[] array) {
    float[] newArray = new float[array.length];
    for (int i = 0; i < array.length; i++) {
        newArray[i] = (float) Math.log10(array[i]);
    }
    return newArray;
}
float[] time = ArrayUtils.addAll(ArrayUtils.addAll(gdRes.time, sgdRes.time),
                                 ArrayUtils.addAll(mini1Res.time, mini2Res.time));
float[] loss = ArrayUtils.addAll(ArrayUtils.addAll(gdRes.loss, sgdRes.loss),
                                 ArrayUtils.addAll(mini1Res.loss, mini2Res.loss));
String[] type = ArrayUtils.addAll(ArrayUtils.addAll(getTypeArray(gdRes, "gd"),
                                                    getTypeArray(sgdRes, "sgd")),
                                  ArrayUtils.addAll(getTypeArray(mini1Res, "batch size = 100"),
                                  getTypeArray(mini2Res, "batch size = 10")));
Table data = Table.create("data")
    .addColumns(
        DoubleColumn.create("log time (sec)", Functions.floatToDoubleArray(convertLogScale(time))),
        DoubleColumn.create("loss", Functions.floatToDoubleArray(loss)),
        StringColumn.create("type", type)
    );
LinePlot.create("loss vs. time", data, "log time (sec)", "loss", "type");
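The four runs differ mainly in how many parameter updates they make for the same 1500 examples. The helper below counts updates per epoch as \(\lceil n / b \rceil\), treating a final partial batch as one more update; the class and method names are illustrative only.

```java
// Number of parameter updates per epoch for n examples and batch size b,
// counting a final partial batch as one more update: ceil(n / b).
class UpdatesPerEpoch {
    static int updatesPerEpoch(int n, int batchSize) {
        return (n + batchSize - 1) / batchSize; // integer ceiling of n / batchSize
    }

    public static void main(String[] args) {
        int n = 1500;
        for (int b : new int[]{1500, 100, 10, 1}) {
            System.out.printf("batch size %4d -> %4d updates per epoch%n", b, updatesPerEpoch(n, b));
        }
    }
}
```

For batch sizes 1500, 100, 10 and 1 this gives 1, 15, 150 and 1500 updates per epoch: gradient descent takes one large step per epoch, while SGD takes 1500 cheap but noisy ones, each paying the per-call framework overhead discussed earlier.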

11.5.5. Concise Implementation

In DJL, we can use the Optimizer package to access different optimization algorithms. This is used to implement a generic training function. We will use this throughout the current chapter.

public void trainConciseCh11(Optimizer sgd, AirfoilRandomAccess dataset,
                            int numEpochs) throws IOException, TranslateException {
    // Initialization
    NDManager manager = NDManager.newBaseManager();

    SequentialBlock net = new SequentialBlock();
    Linear linear = Linear.builder().setUnits(1).build();
    net.add(linear);
    net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);

    Model model = Model.newInstance("concise implementation");
    model.setBlock(net);

    Loss loss = Loss.l2Loss();

    DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd)
        .addEvaluator(new Accuracy()) // Model Accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

    Trainer trainer = model.newTrainer(config);

    int n = 0;
    StopWatch stopWatch = new StopWatch();
    stopWatch.start();

    trainer.initialize(new Shape(10, 5));

    Metrics metrics = new Metrics();
    trainer.setMetrics(metrics);

    float lastLoss = -1;

    ArrayList<Double> lossArray = new ArrayList<>();
    ArrayList<Double> epochArray = new ArrayList<>();

    // Loop over epochs so that numEpochs is honored
    for (int epoch = 0; epoch < numEpochs; epoch++) {
        for (Batch batch : trainer.iterateDataset(dataset)) {
            int len = (int) dataset.size() / batch.getSize();  // number of batches

            NDArray X = batch.getData().head();
            EasyTrain.trainBatch(trainer, batch);
            trainer.step();

            n += X.getShape().get(0);

            if (n % 200 == 0) {
                // Pause the timer (once) so that evaluation time is not counted
                stopWatch.stop();
                lastLoss = evaluateLoss(dataset.getData(manager), linear.getParameters().get(0).getValue().getArray()
                                .reshape(new Shape(dataset.getColumnNames().size(), 1)),
                        linear.getParameters().get(1).getValue().getArray());

                lossArray.add((double) lastLoss);
                double lastEpoch = 1.0 * n / X.getShape().get(0) / len;
                epochArray.add(lastEpoch);
                stopWatch.start();
            }
            batch.close();
        }
    }
    plotLossEpoch(arrayListToFloat(lossArray), arrayListToFloat(epochArray));

    System.out.printf("loss: %.3f, %.3f sec/epoch\n", lastLoss, stopWatch.avg());
}

Using DJL to repeat the last experiment shows similar behavior.

AirfoilRandomAccess airfoilDataset = getDataCh11(10, 1500);

Tracker lrt = Tracker.fixed(0.05f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

trainConciseCh11(sgd, airfoilDataset, 2);
INFO Training on: 4 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.059 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 1.00, L2Loss: 0.28
loss: 0.305, 1.743 sec/epoch

11.5.6. Summary

  • Vectorization makes code more efficient due to reduced overhead arising from the deep learning framework and due to better memory locality and caching on CPUs and GPUs.

  • There is a trade-off between statistical efficiency arising from SGD and computational efficiency arising from processing large batches of data at a time.

  • Minibatch stochastic gradient descent offers the best of both worlds: computational and statistical efficiency.

  • In minibatch SGD we process batches of data obtained by a random permutation of the training data (i.e., each observation is processed only once per epoch, albeit in random order).

  • It is advisable to decay the learning rates during training.

  • In general, minibatch SGD is faster than SGD and gradient descent for convergence to a smaller risk, when measured in terms of clock time.

11.5.7. Exercises

  1. Modify the batch size and learning rate and observe the rate of decline for the value of the objective function and the time consumed in each epoch.

  2. Read the DJL documentation and explore the different learning rate trackers in ai.djl.training.optimizer.tracker to see how they affect training. Try using a FactorTracker to reduce the learning rate to 1/10 of its previous value after each epoch.

  3. Compare minibatch SGD with a variant that actually samples with replacement from the training set. What happens?

  4. An evil genie replicates your dataset without telling you (i.e., each observation occurs twice and your dataset grows to twice its original size, but nobody told you). How does the behavior of SGD, minibatch SGD and that of gradient descent change?