7.3. Network in Network (NiN)

LeNet, AlexNet, and VGG all share a common design pattern: extract features exploiting spatial structure via a sequence of convolutions and pooling layers and then post-process the representations via fully-connected layers. The improvements upon LeNet by AlexNet and VGG mainly lie in how these later networks widen and deepen these two modules. Alternatively, one could imagine using fully-connected layers earlier in the process. However, a careless use of dense layers might give up the spatial structure of the representation entirely. Network in Network (NiN) blocks offer an alternative. They were proposed in [Lin et al., 2013] based on a very simple insight: use an MLP on the channels for each pixel separately.

7.3.1. NiN Blocks

Recall that the inputs and outputs of convolutional layers consist of four-dimensional arrays with axes corresponding to the batch, channel, height, and width. Also recall that the inputs and outputs of fully-connected layers are typically two-dimensional arrays with axes corresponding to the batch and the features. The idea behind NiN is to apply a fully-connected layer at each pixel location (for each height and width). If we tie the weights across all spatial locations, we can think of this as a \(1\times 1\) convolutional layer (as described in Section 6.4) or as a fully-connected layer acting independently on each pixel location. Another way to view this is to think of each element in the spatial dimensions (height and width) as equivalent to an example and each channel as equivalent to a feature. Fig. 7.3.1 illustrates the main structural differences between NiN and AlexNet, VGG, and other networks.

https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/nin.svg

Fig. 7.3.1 The figure on the left shows the network structure of AlexNet and VGG, and the figure on the right shows the network structure of NiN.
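
The equivalence between a \(1\times 1\) convolution and a per-pixel fully-connected layer can be checked numerically. The following is a minimal sketch in plain Java arrays rather than DJL; the class and method names (OneByOneConvDemo, conv1x1, densePerPixel) are hypothetical, and a single example with layout (channel, height, width) is assumed.

```java
// Sketch only: plain Java arrays, not DJL. All names are hypothetical.
final class OneByOneConvDemo {

    // 1x1 convolution. x: (cIn, h, w); weight: (cOut, cIn); bias: (cOut).
    static double[][][] conv1x1(double[][][] x, double[][] weight, double[] bias) {
        int cIn = x.length, h = x[0].length, w = x[0][0].length;
        double[][][] y = new double[weight.length][h][w];
        for (int o = 0; o < weight.length; o++) {
            for (int i = 0; i < h; i++) {
                for (int j = 0; j < w; j++) {
                    double sum = bias[o];
                    for (int c = 0; c < cIn; c++) {
                        sum += weight[o][c] * x[c][i][j];
                    }
                    y[o][i][j] = sum;
                }
            }
        }
        return y;
    }

    // Fully-connected layer applied to one feature vector.
    static double[] dense(double[] features, double[][] weight, double[] bias) {
        double[] out = new double[weight.length];
        for (int o = 0; o < weight.length; o++) {
            double sum = bias[o];
            for (int c = 0; c < features.length; c++) {
                sum += weight[o][c] * features[c];
            }
            out[o] = sum;
        }
        return out;
    }

    // Treat each pixel as an example and its channels as features,
    // then apply the same dense layer (tied weights) at every location.
    static double[][][] densePerPixel(double[][][] x, double[][] weight, double[] bias) {
        int cIn = x.length, h = x[0].length, w = x[0][0].length;
        double[][][] y = new double[weight.length][h][w];
        for (int i = 0; i < h; i++) {
            for (int j = 0; j < w; j++) {
                double[] pixel = new double[cIn];
                for (int c = 0; c < cIn; c++) {
                    pixel[c] = x[c][i][j];
                }
                double[] out = dense(pixel, weight, bias);
                for (int o = 0; o < out.length; o++) {
                    y[o][i][j] = out[o];
                }
            }
        }
        return y;
    }

    // Largest absolute element-wise difference between two equally shaped arrays.
    static double maxAbsDiff(double[][][] a, double[][][] b) {
        double max = 0.0;
        for (int o = 0; o < a.length; o++) {
            for (int i = 0; i < a[o].length; i++) {
                for (int j = 0; j < a[o][i].length; j++) {
                    max = Math.max(max, Math.abs(a[o][i][j] - b[o][i][j]));
                }
            }
        }
        return max;
    }

    public static void main(String[] args) {
        double[][][] x = {{{1.0, 2.0}, {3.0, 4.0}}, {{0.5, -1.0}, {2.0, 0.0}}};
        double[][] weight = {{3.0, 4.0}, {-1.0, 0.5}, {0.0, 2.0}};
        double[] bias = {0.5, 0.0, -1.0};
        double diff = maxAbsDiff(conv1x1(x, weight, bias), densePerPixel(x, weight, bias));
        System.out.println("max difference: " + diff);
    }
}
```

Both routines read the same weight matrix, so the only difference is how the loops are organized: the convolution iterates over output channels first, while the per-pixel view iterates over locations first.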

The NiN block consists of one convolutional layer followed by two \(1\times 1\) convolutional layers that act as per-pixel fully-connected layers with ReLU activations. The window shape of the first convolutional layer is typically set by the user. The subsequent window shapes are fixed to \(1 \times 1\).

%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.7.0-SNAPSHOT
%maven ai.djl:model-zoo:0.7.0-SNAPSHOT
%maven ai.djl:basicdataset:0.7.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26

%maven ai.djl.mxnet:mxnet-engine:0.7.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-a
%%loadFromPOM
<dependency>
    <groupId>tech.tablesaw</groupId>
    <artifactId>tablesaw-jsplot</artifactId>
    <version>0.30.4</version>
</dependency>
%load ../utils/plot-utils.ipynb
%load ../utils/Training.java
%load ../utils/Accumulator.java
import java.nio.file.*;
import ai.djl.*;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.basicdataset.FashionMnist;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.translate.Pipeline;
import tech.tablesaw.api.*;
import tech.tablesaw.plotly.api.*;
import tech.tablesaw.plotly.components.*;
import tech.tablesaw.plotly.Plot;
import tech.tablesaw.plotly.components.Figure;
import org.apache.commons.lang3.ArrayUtils;
// setting the seed for demonstration purpose. You can remove it when you run the notebook
Engine.getInstance().setRandomSeed(5555);

public SequentialBlock niNBlock(int numChannels, Shape kernelShape,
                     Shape strideShape, Shape paddingShape){

    SequentialBlock tempBlock = new SequentialBlock();

    tempBlock.add(Conv2d.builder()
              .setKernelShape(kernelShape)
              .optStride(strideShape)
              .optPadding(paddingShape)
              .setFilters(numChannels)
              .build())
        .add(Activation::relu)
        .add(Conv2d.builder()
              .setKernelShape(new Shape(1, 1))
              .setFilters(numChannels)
              .build())
        .add(Activation::relu)
        .add(Conv2d.builder()
              .setKernelShape(new Shape(1, 1))
              .setFilters(numChannels)
              .build())
        .add(Activation::relu);

    return tempBlock;
}

7.3.2. NiN Model

The original NiN network was proposed shortly after AlexNet and clearly draws some inspiration from it. NiN uses convolutional layers with window shapes of \(11\times 11\), \(5\times 5\), and \(3\times 3\), and the corresponding numbers of output channels are the same as in AlexNet. Each NiN block is followed by a maximum pooling layer with a stride of 2 and a window shape of \(3\times 3\).

One significant difference between NiN and AlexNet is that NiN avoids dense connections altogether. Instead, NiN uses an NiN block with a number of output channels equal to the number of label classes, followed by a global average pooling layer, yielding a vector of logits. One advantage of NiN’s design is that it significantly reduces the number of required model parameters. However, in practice, this design sometimes requires increased model training time.
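
To make the global average pooling step concrete, here is a minimal sketch in plain Java (not the DJL implementation). It assumes a single example stored as (channel, height, width); the class name GlobalAvgPoolDemo is hypothetical. Each channel is reduced to the mean of its spatial map, so a block with 10 output channels yields 10 logits.

```java
// Sketch only: plain Java arrays, not DJL. Names are hypothetical.
final class GlobalAvgPoolDemo {

    // x: (channels, height, width). Returns one value per channel:
    // the mean over all spatial locations of that channel.
    static double[] globalAvgPool(double[][][] x) {
        double[] out = new double[x.length];
        for (int c = 0; c < x.length; c++) {
            double sum = 0.0;
            int count = 0;
            for (double[] row : x[c]) {
                for (double v : row) {
                    sum += v;
                    count++;
                }
            }
            out[c] = sum / count;
        }
        return out;
    }

    public static void main(String[] args) {
        // Two channels of size 2x2: the first averages to 2.5, the second to 10.0.
        double[][][] x = {{{1.0, 2.0}, {3.0, 4.0}}, {{10.0, 10.0}, {10.0, 10.0}}};
        System.out.println(java.util.Arrays.toString(globalAvgPool(x)));
    }
}
```

Because the pooling window always covers the whole feature map, this step has no learnable parameters and works for any input height and width.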

SequentialBlock block = new SequentialBlock();

block.add(niNBlock(96, new Shape(11, 11), new Shape(4, 4), new Shape(0, 0)))
     .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
     .add(niNBlock(256, new Shape(5, 5), new Shape(1, 1), new Shape(2, 2)))
     .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
     .add(niNBlock(384, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
     .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
     .add(Dropout.builder().optRate(0.5f).build())
     // There are 10 label classes
     .add(niNBlock(10, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
     // The global average pooling layer automatically sets the window shape
     // to the height and width of the input
     .add(Pool.globalAvgPool2dBlock())
     // Transform the four-dimensional output into two-dimensional output
     // with a shape of (batch size, 10)
     .add(Blocks.batchFlattenBlock());
Sequential(
    Sequential(
            Conv2d(Uninitialized)
            Lambda()
            Conv2d(Uninitialized)
            Lambda()
            Conv2d(Uninitialized)
            Lambda()
    )
    Lambda()
    Sequential(
            Conv2d(Uninitialized)
            Lambda()
            Conv2d(Uninitialized)
            Lambda()
            Conv2d(Uninitialized)
            Lambda()
    )
    Lambda()
    Sequential(
            Conv2d(Uninitialized)
            Lambda()
            Conv2d(Uninitialized)
            Lambda()
            Conv2d(Uninitialized)
            Lambda()
    )
    Lambda()
    Dropout()
    Sequential(
            Conv2d(Uninitialized)
            Lambda()
            Conv2d(Uninitialized)
            Lambda()
            Conv2d(Uninitialized)
            Lambda()
    )
    Lambda()
    Lambda()
)

We create a data example to see the output shape of each block.

float lr = 0.1f;
Model model = Model.newInstance("cnn");
model.setBlock(block);

Loss loss = Loss.softmaxCrossEntropyLoss();

Tracker lrt = Tracker.fixed(lr);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss).optOptimizer(sgd) // Optimizer (loss function)
                .addEvaluator(new Accuracy()) // Model Accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);

NDManager manager = NDManager.newBaseManager();
NDArray X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 224, 224));
trainer.initialize(X.getShape());

Shape currentShape = X.getShape();

for (int i = 0; i < block.getChildren().size(); i++) {

    Shape[] newShape = block.getChildren().get(i).getValue().getOutputShapes(manager, new Shape[]{currentShape});
    currentShape = newShape[0];
    System.out.println(block.getChildren().get(i).getKey() + " layer output : " + currentShape);
}
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: 4 GPUs.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.7.0 in 0.125 ms.
01SequentialBlock layer output : (1, 96, 54, 54)
02LambdaBlock layer output : (1, 96, 26, 26)
03SequentialBlock layer output : (1, 256, 26, 26)
04LambdaBlock layer output : (1, 256, 12, 12)
05SequentialBlock layer output : (1, 384, 12, 12)
06LambdaBlock layer output : (1, 384, 5, 5)
07Dropout layer output : (1, 384, 5, 5)
08SequentialBlock layer output : (1, 10, 5, 5)
09LambdaBlock layer output : (1, 10)
10LambdaBlock layer output : (1, 10)
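
The spatial sizes printed above can be reproduced by hand with the usual output-size formula \(\lfloor (n + 2p - k)/s \rfloor + 1\) for input size \(n\), window size \(k\), padding \(p\), and stride \(s\). The sketch below is plain Java with hypothetical names (NinShapeDemo, outSize, trace); it assumes square inputs and that pooling rounds down, which matches the logged shapes.

```java
// Sketch only: reproduces the height/width arithmetic, not the DJL shape inference.
final class NinShapeDemo {

    // Output size along one spatial axis; integer division rounds down.
    static int outSize(int in, int kernel, int stride, int padding) {
        return (in + 2 * padding - kernel) / stride + 1;
    }

    // Sizes after each conv/pool stage of the model defined in this section.
    // The 1x1 convolutions inside each NiN block leave the size unchanged.
    static int[] trace(int inputSize) {
        int[] sizes = new int[6];
        sizes[0] = outSize(inputSize, 11, 4, 0); // NiN block 1 (11x11, stride 4)
        sizes[1] = outSize(sizes[0], 3, 2, 0);   // max pooling (3x3, stride 2)
        sizes[2] = outSize(sizes[1], 5, 1, 2);   // NiN block 2 (5x5, padding 2)
        sizes[3] = outSize(sizes[2], 3, 2, 0);   // max pooling
        sizes[4] = outSize(sizes[3], 3, 1, 1);   // NiN block 3 (3x3, padding 1)
        sizes[5] = outSize(sizes[4], 3, 2, 0);   // max pooling
        return sizes;
    }

    public static void main(String[] args) {
        // prints [54, 26, 26, 12, 12, 5]
        System.out.println(java.util.Arrays.toString(trace(224)));
    }
}
```

The final NiN block (3x3, padding 1, stride 1) keeps the 5x5 size, and global average pooling then collapses it to a single value per channel.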

7.3.3. Data Acquisition and Training

As before, we use Fashion-MNIST to train the model. NiN’s training is similar to that for AlexNet and VGG, but it often uses a larger learning rate.

int batchSize = 128;
int numEpochs = 10;
double[] trainLoss;
double[] testAccuracy;
double[] epochCount;
double[] trainAccuracy;

epochCount = new double[numEpochs];

for (int i = 0; i < epochCount.length; i++) {
    epochCount[i] = i+1;
}

FashionMnist trainIter = FashionMnist.builder()
                        .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
                        .optUsage(Dataset.Usage.TRAIN)
                        .setSampling(batchSize, true)
                        .build();

FashionMnist testIter = FashionMnist.builder()
                        .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
                        .optUsage(Dataset.Usage.TEST)
                        .setSampling(batchSize, true)
                        .build();

trainIter.prepare();
testIter.prepare();
Map<String, double[]> evaluatorMetrics = new HashMap<>();
double avgTrainTimePerEpoch = 0;
Training.trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics, avgTrainTimePerEpoch);
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 1 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 2 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 3 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 4 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 5 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 6 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 7 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 8 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 9 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 10 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

System.out.printf("loss %.3f,", trainLoss[numEpochs - 1]);
System.out.printf(" train acc %.3f,", trainAccuracy[numEpochs - 1]);
System.out.printf(" test acc %.3f\n", testAccuracy[numEpochs - 1]);
System.out.printf("%.1f examples/sec", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
System.out.println();
loss 2.303, train acc 0.100, test acc 0.100
Infinity examples/sec
https://d2l-java-resources.s3.amazonaws.com/img/nin-plot.png

Fig. 7.3.2 Training loss, training accuracy, and test accuracy of NiN per epoch.

String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
Arrays.fill(lossLabel, trainLoss.length, trainLoss.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc");

Table data = Table.create("Data").addColumns(
            DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
            DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),
            StringColumn.create("lossLabel", lossLabel)
);

render(LinePlot.create("", data, "epoch", "metrics", "lossLabel"),"text/html");

7.3.4. Summary

  • NiN uses blocks consisting of a convolutional layer and multiple \(1\times 1\) convolutional layers. These blocks can be used within the convolutional stack to allow for more per-pixel nonlinearity.

  • NiN removes the fully-connected layers and replaces them with global average pooling (i.e., averaging over all locations) after reducing the number of channels to the desired number of outputs (e.g., 10 for Fashion-MNIST).

  • Removing the dense layers reduces overfitting. NiN has dramatically fewer parameters.

  • The NiN design influenced many subsequent convolutional neural network designs.
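
To put a number on the parameter savings, the sketch below counts the weights of the model built in this section. It is plain Java with hypothetical names (NinParamCount and its methods) and assumes a single input channel and a bias term on every convolution. For scale only, it also counts one hypothetical dense layer mapping the \(384 \times 5 \times 5\) representation to 4096 units; that layer is not part of NiN.

```java
// Sketch only: hand-counted parameters, assuming every convolution has a bias.
final class NinParamCount {

    // Parameters of a k x k convolution: weights plus one bias per output channel.
    static long convParams(int inChannels, int outChannels, int k) {
        return (long) k * k * inChannels * outChannels + outChannels;
    }

    // One NiN block: a k x k convolution followed by two 1x1 convolutions.
    static long ninBlockParams(int inChannels, int outChannels, int k) {
        return convParams(inChannels, outChannels, k)
                + 2 * convParams(outChannels, outChannels, 1);
    }

    // The four NiN blocks of this section (pooling and dropout have no parameters).
    static long ninTotal() {
        return ninBlockParams(1, 96, 11)
                + ninBlockParams(96, 256, 5)
                + ninBlockParams(256, 384, 3)
                + ninBlockParams(384, 10, 3);
    }

    // Hypothetical comparison: a dense layer from `inputs` features to `units` outputs.
    static long denseParams(long inputs, long units) {
        return inputs * units + units;
    }

    public static void main(String[] args) {
        System.out.println("NiN total:             " + ninTotal());                      // 1992166
        System.out.println("One dense 9600 -> 4096: " + denseParams(384 * 5 * 5, 4096)); // 39325696
    }
}
```

Under these assumptions the whole network has roughly 2 million parameters, about one twentieth of what that single dense layer alone would need.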

7.3.5. Exercises

  1. Tune the hyper-parameters to improve the classification accuracy.

  2. Why are there two \(1\times 1\) convolutional layers in the NiN block? Remove one of them, and then observe and analyze the experimental phenomena.

  3. Calculate the resource usage for NiN:

    • What is the number of parameters?

    • What is the amount of computation?

    • What is the amount of memory needed during training?

    • What is the amount of memory needed during inference?

  4. What are possible problems with reducing the \(384 \times 5 \times 5\) representation to a \(10 \times 5 \times 5\) representation in one step?
