
7.2. Networks Using Blocks (VGG)

While AlexNet proved that deep convolutional neural networks can achieve good results, it did not offer a general template to guide subsequent researchers in designing new networks. In the following sections, we will introduce several heuristic concepts commonly used to design deep networks.

Progress in this field mirrors that in chip design, where engineers went from placing transistors to logical elements to logic blocks. Similarly, the design of neural network architectures has grown progressively more abstract, with researchers moving from thinking in terms of individual neurons to whole layers, and now to blocks, that is, repeating patterns of layers.

The idea of using blocks first emerged from the Visual Geometry Group (VGG) at Oxford University, in their eponymously named VGG network. It is easy to implement these repeated structures in code with any modern deep learning framework by using loops and subroutines.

7.2.1. VGG Blocks

The basic building block of classic convolutional networks is a sequence of the following layers: (i) a convolutional layer (with padding to maintain the resolution), (ii) a nonlinearity such as a ReLU, (iii) a pooling layer such as a max pooling layer. One VGG block consists of a sequence of convolutional layers, followed by a max pooling layer for spatial downsampling. In the original VGG paper [Simonyan & Zisserman, 2014], the authors employed convolutions with \(3\times3\) kernels and \(2 \times 2\) max pooling with a stride of \(2\) (halving the resolution after each block). In the code below, we define a function called vggBlock to implement one VGG block. The function takes two arguments: the number of convolutional layers, numConvs, and the number of output channels, numChannels.

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Training.java
%load ../utils/Accumulator.java
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;
public SequentialBlock vggBlock(int numConvs, int numChannels) {

    SequentialBlock tempBlock = new SequentialBlock();
    for (int i = 0; i < numConvs; i++) {
        // DJL has default stride of 1x1, so don't need to set it explicitly.
        tempBlock
                .add(Conv2d.builder()
                        .setFilters(numChannels)
                        .setKernelShape(new Shape(3, 3))
                        .optPadding(new Shape(1, 1))
                        .build()
                )
                .add(Activation::relu);
    }
    tempBlock.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));
    return tempBlock;
}
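
The claim that the padded \(3\times3\) convolutions keep the resolution while the \(2\times2\) pooling halves it can be checked with the standard output-size formula \(\lfloor (n + 2p - k)/s \rfloor + 1\). The following is a minimal sketch in plain Java; the class OutputSize and its method are hypothetical helpers introduced here for illustration and are not part of DJL.

```java
/** Hypothetical helper (not part of DJL) for convolution and pooling output sizes. */
final class OutputSize {

    /** Output size along one axis: floor((size + 2 * padding - kernel) / stride) + 1. */
    static int of(int size, int kernel, int padding, int stride) {
        // Integer division floors the result because all operands are non-negative here.
        return (size + 2 * padding - kernel) / stride + 1;
    }

    public static void main(String[] args) {
        int afterConv = of(224, 3, 1, 1);        // 3x3 convolution, padding 1, stride 1
        int afterPool = of(afterConv, 2, 0, 2);  // 2x2 max pooling, stride 2
        System.out.println(afterConv + " -> " + afterPool); // prints "224 -> 112"
    }
}
```

Because every convolution in a block leaves the height and width unchanged, only the final pooling layer of the block changes the spatial size.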

7.2.2. VGG Network

Like AlexNet and LeNet, the VGG network can be partitioned into two parts: the first consisting mostly of convolutional and pooling layers and the second consisting of fully-connected layers. The convolutional portion of the net connects several vggBlock modules in succession, as shown in Fig. 7.2.1. The variable convArch is an array of pairs (one per block), where each pair contains two values: the number of convolutional layers and the number of output channels, which are precisely the arguments required to call the vggBlock function. The fully-connected module is identical to that covered in AlexNet.

https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/vgg.svg

Fig. 7.2.1 Designing a network from building blocks

The original VGG network has 5 convolutional blocks: the first two contain one convolutional layer each, and the last three contain two convolutional layers each. The first block has 64 output channels and each subsequent block doubles the number of output channels, until that number reaches \(512\). Since this network uses \(8\) convolutional layers and \(3\) fully-connected layers, it is often called VGG-11.

int[][] convArch = {{1, 64}, {1, 128}, {2, 256}, {2, 512}, {2, 512}};
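
The name VGG-11 can be recovered from convArch by summing the per-block layer counts and adding the three fully-connected layers. Here is a small sketch of that arithmetic; the class VggDepth and its methods are hypothetical names used only for this illustration.

```java
/** Hypothetical helper that counts the layers with weights described by convArch. */
final class VggDepth {

    /** Sums the first entry (number of convolutional layers) of every block. */
    static int convLayers(int[][] convArch) {
        int total = 0;
        for (int[] blockSpec : convArch) {
            total += blockSpec[0];
        }
        return total;
    }

    /** Convolutional layers plus the fully-connected layers that follow them. */
    static int weightLayers(int[][] convArch, int numFullyConnected) {
        return convLayers(convArch) + numFullyConnected;
    }

    public static void main(String[] args) {
        int[][] convArch = {{1, 64}, {1, 128}, {2, 256}, {2, 512}, {2, 512}};
        System.out.println("conv layers: " + convLayers(convArch));        // prints "conv layers: 8"
        System.out.println("VGG-" + weightLayers(convArch, 3));            // prints "VGG-11"
    }
}
```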

The following code implements VGG-11. This is a simple matter of executing a for loop over convArch.

public SequentialBlock VGG(int[][] convArch) {

    SequentialBlock block = new SequentialBlock();
    // The convolutional layer part
    for (int i = 0; i < convArch.length; i++) {
        block.add(vggBlock(convArch[i][0], convArch[i][1]));
    }

    // The fully connected layer part
    block
        .add(Blocks.batchFlattenBlock())
        .add(Linear
                .builder()
                .setUnits(4096)
                .build())
        .add(Activation::relu)
        .add(Dropout
                .builder()
                .optRate(0.5f)
                .build())
        .add(Linear
                .builder()
                .setUnits(4096)
                .build())
        .add(Activation::relu)
        .add(Dropout
                .builder()
                .optRate(0.5f)
                .build())
        .add(Linear.builder().setUnits(10).build());

    return block;
}

SequentialBlock block = VGG(convArch);

Next, we will construct a single-channel data example with a height and width of 224 to observe the output shape of each layer.

float lr = 0.05f;
Model model = Model.newInstance("vgg-display");
model.setBlock(block);

Loss loss = Loss.softmaxCrossEntropyLoss();

Tracker lrt = Tracker.fixed(lr);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss).optOptimizer(sgd) // Optimizer (loss function)
                .optDevices(Engine.getInstance().getDevices(1)) // single GPU
                .addEvaluator(new Accuracy()) // Model Accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);

Shape inputShape = new Shape(1, 1, 224, 224);

try(NDManager manager = NDManager.newBaseManager()) {
    NDArray X = manager.randomUniform(0f, 1.0f, inputShape);
    trainer.initialize(inputShape);

    Shape currentShape = X.getShape();

    for (int i = 0; i < block.getChildren().size(); i++) {
        Shape[] newShape = block.getChildren().get(i).getValue().getOutputShapes(new Shape[]{currentShape});
        currentShape = newShape[0];
        System.out.println(block.getChildren().get(i).getKey() + " layer output : " + currentShape);
    }
}
// save memory on VGG params
model.close();
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.056 ms.
01SequentialBlock layer output : (1, 64, 112, 112)
02SequentialBlock layer output : (1, 128, 56, 56)
03SequentialBlock layer output : (1, 256, 28, 28)
04SequentialBlock layer output : (1, 512, 14, 14)
05SequentialBlock layer output : (1, 512, 7, 7)
06LambdaBlock layer output : (1, 25088)
07Linear layer output : (1, 4096)
08LambdaBlock layer output : (1, 4096)
09Dropout layer output : (1, 4096)
10Linear layer output : (1, 4096)
11LambdaBlock layer output : (1, 4096)
12Dropout layer output : (1, 4096)
13Linear layer output : (1, 10)

As you can see, we halve height and width at each block, finally reaching a height and width of 7 before flattening the representations for processing by the fully-connected layer.
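
The same shapes can be reproduced without building the model, which is a quick way to sanity-check an architecture before allocating parameters. The sketch below assumes, as in vggBlock, that every convolution preserves height and width and that each block ends with a \(2\times2\) max pooling of stride 2; the class VggShapeTrace is a hypothetical helper, not a DJL API.

```java
/** Hypothetical helper that traces shapes through the convolutional blocks by hand. */
final class VggShapeTrace {

    /** Returns {channels, height, width} after all blocks for a square input. */
    static int[] afterConvBlocks(int[][] convArch, int inputChannels, int inputSize) {
        int channels = inputChannels;
        int size = inputSize;
        for (int[] blockSpec : convArch) {
            channels = blockSpec[1]; // the convolutions set the channel count
            size = size / 2;         // 2x2 max pooling with stride 2 halves the size
        }
        return new int[]{channels, size, size};
    }

    /** Number of features per example after flattening. */
    static int flattened(int[] shape) {
        return shape[0] * shape[1] * shape[2];
    }

    public static void main(String[] args) {
        int[][] convArch = {{1, 64}, {1, 128}, {2, 256}, {2, 512}, {2, 512}};
        int[] shape = afterConvBlocks(convArch, 1, 224);
        // prints "512 x 7 x 7 -> 25088"
        System.out.println(shape[0] + " x " + shape[1] + " x " + shape[2] + " -> " + flattened(shape));
    }
}
```

The flattened size of 25088 matches the output of the flattening layer printed above.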

7.2.3. Model Training

Since VGG-11 is more computationally heavy than AlexNet, we construct a network with a smaller number of channels. This is more than sufficient for training on Fashion-MNIST.
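
To get a feel for how much a channel reduction saves, the sketch below counts the parameters of the convolutional part before and after dividing every channel count by 4. It assumes a single input channel, \(3\times3\) kernels, and one bias per output channel; the class ConvParamCount and its methods are hypothetical names, not part of DJL.

```java
/** Hypothetical helper that counts parameters of the 3x3 convolutions described by convArch. */
final class ConvParamCount {

    /** Returns a copy of convArch with every channel count divided by ratio. */
    static int[][] scaleChannels(int[][] convArch, int ratio) {
        int[][] scaled = new int[convArch.length][];
        for (int i = 0; i < convArch.length; i++) {
            scaled[i] = new int[]{convArch[i][0], convArch[i][1] / ratio};
        }
        return scaled;
    }

    /** Weights (9 * in * out) plus biases (out) summed over all convolutional layers. */
    static long convParams(int[][] convArch, int inputChannels) {
        long total = 0;
        int in = inputChannels;
        for (int[] blockSpec : convArch) {
            int out = blockSpec[1];
            for (int layer = 0; layer < blockSpec[0]; layer++) {
                total += 9L * in * out + out;
                in = out; // later layers in the block take the block's channel count as input
            }
        }
        return total;
    }

    public static void main(String[] args) {
        int[][] convArch = {{1, 64}, {1, 128}, {2, 256}, {2, 512}, {2, 512}};
        long full = convParams(convArch, 1);
        long small = convParams(scaleChannels(convArch, 4), 1);
        System.out.println("full: " + full + ", scaled: " + small);
        System.out.printf("reduction factor: %.1f%n", (double) full / small);
    }
}
```

Since the weight count of a convolution is proportional to the product of its input and output channels, dividing both by 4 shrinks the convolutional part by roughly a factor of 16.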

int ratio = 4;

for(int i=0; i < convArch.length; i++){
    convArch[i][1] = convArch[i][1] / ratio;
}

inputShape = new Shape(1, 1, 96, 96); // resize the input shape to save memory

Model model = Model.newInstance("vgg-tiny");
SequentialBlock newBlock = VGG(convArch);
model.setBlock(newBlock);
Loss loss = Loss.softmaxCrossEntropyLoss();

Tracker lrt = Tracker.fixed(lr);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss).optOptimizer(sgd) // Optimizer (loss function)
                .optDevices(Engine.getInstance().getDevices(1)) // single GPU
                .addEvaluator(new Accuracy()) // Model Accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

trainer = model.newTrainer(config);
trainer.initialize(inputShape);
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.026 ms.
int batchSize = 128;
int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

double[] trainLoss;
double[] testAccuracy;
double[] epochCount;
double[] trainAccuracy;

epochCount = new double[numEpochs];

for (int i = 0; i < epochCount.length; i++) {
    epochCount[i] = i+1;
}

FashionMnist trainIter = FashionMnist.builder()
        .addTransform(new Resize(96))
        .addTransform(new ToTensor())
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

FashionMnist testIter = FashionMnist.builder()
        .addTransform(new Resize(96))
        .addTransform(new ToTensor())
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

trainIter.prepare();
testIter.prepare();

Apart from using a slightly larger learning rate, the model training process is similar to that of AlexNet in the last section.

Map<String, double[]> evaluatorMetrics = new HashMap<>();
double avgTrainTimePerEpoch = Training.trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics);
Training:    100% |████████████████████████████████████████| Accuracy: 0.59, SoftmaxCrossEntropyLoss: 1.23
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.59, SoftmaxCrossEntropyLoss: 1.22
INFO Validate: Accuracy: 0.81, SoftmaxCrossEntropyLoss: 0.49
Training:    100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.43
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.43
INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33
Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
INFO Validate: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.29
Training:    100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.28
INFO Validate: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.25
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.25
INFO Validate: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.25
Training:    100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.25
Training:    100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
Training:    100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.19
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.19
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

System.out.printf("loss %.3f,", trainLoss[numEpochs - 1]);
System.out.printf(" train acc %.3f,", trainAccuracy[numEpochs - 1]);
System.out.printf(" test acc %.3f\n", testAccuracy[numEpochs - 1]);
System.out.printf("%.1f examples/sec", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
System.out.println();
loss 0.195, train acc 0.928, test acc 0.909
5887.1 examples/sec
https://d2l-java-resources.s3.amazonaws.com/img/chapter_convolution-modern-cnn-VGG.png

Fig. 7.2.2 Training loss, training accuracy, and test accuracy of VGG on Fashion-MNIST.

String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
// The "train acc" labels start where the "train loss" labels end.
Arrays.fill(lossLabel, trainLoss.length, trainLoss.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc");

Table data = Table.create("Data").addColumns(
    DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
    DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),
    StringColumn.create("lossLabel", lossLabel)
);

render(LinePlot.create("", data, "epoch", "metrics", "lossLabel"), "text/html");

7.2.4. Summary

  • VGG-11 constructs a network using reusable convolutional blocks. Different VGG models can be defined by the differences in the number of convolutional layers and output channels in each block.

  • The use of blocks leads to very compact representations of the network definition. It allows for efficient design of complex networks.

  • In their work, Simonyan and Zisserman experimented with various architectures. In particular, they found that several layers of deep and narrow convolutions (i.e., \(3 \times 3\)) were more effective than fewer layers of wider convolutions.
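
One commonly cited piece of arithmetic behind the last point is that a stack of \(n\) stride-1 \(3\times3\) convolutions covers the same receptive field as a single \((2n+1)\times(2n+1)\) convolution while using fewer weights. The sketch below compares the two, assuming equal input and output channel counts and ignoring biases; the class StackedConvs is a hypothetical helper introduced only for this comparison.

```java
/** Hypothetical helper comparing stacked small convolutions with one wide convolution. */
final class StackedConvs {

    /** Receptive field of numLayers stacked kernel x kernel convolutions with stride 1. */
    static int receptiveField(int numLayers, int kernel) {
        return numLayers * (kernel - 1) + 1;
    }

    /** Weights (biases ignored) when every layer maps channels to channels. */
    static long weights(int numLayers, int kernel, int channels) {
        return (long) numLayers * kernel * kernel * channels * channels;
    }

    public static void main(String[] args) {
        int channels = 64;
        for (int n = 2; n <= 3; n++) {
            int field = receptiveField(n, 3);
            System.out.println(n + " stacked 3x3 layers: field " + field + "x" + field
                    + ", weights " + weights(n, 3, channels)
                    + " vs single " + field + "x" + field + " layer: "
                    + weights(1, field, channels));
        }
    }
}
```

The stacked version also applies a nonlinearity after every layer, so it is deeper as well as cheaper in weights for the same receptive field.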

7.2.5. Exercises

  1. When printing out the dimensions of the layers we only saw 8 results rather than 11. Where did the information for the remaining 3 layers go?

  2. Compared with AlexNet, VGG is much slower in terms of computation, and it also needs more GPU memory. Try to analyze the reasons for this.

  3. Try to change the height and width of the images in Fashion-MNIST from 224 to 96. What influence does this have on the experiments?

  4. Refer to Table 1 in [Simonyan & Zisserman, 2014] to construct other common models, such as VGG-16 or VGG-19.