
7.1. Deep Convolutional Neural Networks (AlexNet)

Although convolutional neural networks were well known in the computer vision and machine learning communities following the introduction of LeNet, they did not immediately dominate the field. Although LeNet achieved good results on early small datasets, the performance and feasibility of training convolutional networks on larger, more realistic datasets had yet to be established. In fact, for much of the intervening time between the early 1990s and the watershed results of 2012, neural networks were often surpassed by other machine learning methods, such as support vector machines.

For computer vision, this comparison is perhaps not fair. That is, although the inputs to convolutional networks consist of raw or lightly-processed (e.g., by centering) pixel values, practitioners would never feed raw pixels into traditional models. Instead, typical computer vision pipelines consisted of manually engineered feature extraction stages. Rather than learning the features, the features were crafted. Most of the progress came from having more clever ideas for features, while the learning algorithm was often relegated to an afterthought.

Although some neural network accelerators were available in the 1990s, they were not yet sufficiently powerful to make deep multichannel, multilayer convolutional neural networks with a large number of parameters feasible. Moreover, datasets were still relatively small. Added to these obstacles, key tricks for training neural networks, including parameter initialization heuristics, clever variants of stochastic gradient descent, non-squashing activation functions, and effective regularization techniques, were still missing.

Thus, rather than training end-to-end (pixel to classification) systems, classical pipelines looked more like this:

  1. Obtain an interesting dataset. In the early days, these datasets required expensive sensors (at the time, 1 megapixel images were state of the art).

  2. Preprocess the dataset with hand-crafted features based on some knowledge of optics, geometry, other analytic tools, and occasionally on the serendipitous discoveries of lucky graduate students.

  3. Feed the data through a standard set of feature extractors such as SIFT (the Scale-Invariant Feature Transform), SURF (the Speeded-Up Robust Features), or any number of other hand-tuned pipelines.

  4. Dump the resulting representations into your favorite classifier, likely a linear model or kernel method, and train it.

If you spoke to machine learning researchers, they believed that machine learning was both important and beautiful. Elegant theories proved the properties of various classifiers. The field of machine learning was thriving, rigorous and eminently useful. However, if you spoke to a computer vision researcher, you’d hear a very different story. The dirty truth of image recognition, they’d tell you, is that features, not learning algorithms, drove progress. Computer vision researchers justifiably believed that a slightly bigger or cleaner dataset or a slightly improved feature-extraction pipeline mattered far more to the final accuracy than any learning algorithm.

7.1.1. Learning Feature Representation

Another way to cast the state of affairs is that the most important part of the pipeline was the representation. And up until 2012 the representation was calculated mechanically. In fact, engineering a new set of feature functions, improving results, and writing up the method was a prominent genre of paper. SIFT, SURF, HOG, Bags of visual words and similar feature extractors ruled the roost.

Another group of researchers, including Yann LeCun, Geoff Hinton, Yoshua Bengio, Andrew Ng, Shun-ichi Amari, and Juergen Schmidhuber, had different plans. They believed that features themselves ought to be learned. Moreover, they believed that to be reasonably complex, the features ought to be hierarchically composed with multiple jointly learned layers, each with learnable parameters. In the case of an image, the lowest layers might come to detect edges, colors, and textures. Indeed, [Krizhevsky et al., 2012] proposed a new variant of a convolutional neural network which achieved excellent performance in the ImageNet challenge.

Interestingly, in the lowest layers of the network, the model learned feature extractors that resembled some traditional filters. Fig. 7.1.1 is reproduced from this paper and describes lower-level image descriptors.

https://github.com/d2l-ai/d2l-en/blob/master/img/filters.png?raw=true

Fig. 7.1.1 Image filters learned by the first layer of AlexNet

Higher layers in the network might build upon these representations to represent larger structures, like eyes, noses, blades of grass, etc. Even higher layers might represent whole objects like people, airplanes, dogs, or frisbees. Ultimately, the final hidden state learns a compact representation of the image that summarizes its contents such that data belonging to different categories can be separated easily.

While the breakthrough for many-layered convolutional networks did not come until 2012, a core group of researchers had dedicated themselves to this idea for many years, attempting to learn hierarchical representations of visual data. The breakthrough can be attributed to two key missing ingredients: data and hardware.

7.1.1.1. Missing Ingredient - Data

Deep models with many layers require large amounts of data in order to enter the regime where they significantly outperform traditional methods based on convex optimization (e.g., linear and kernel methods). However, given the limited storage capacity of computers, the relative expense of sensors, and the comparatively tighter research budgets of the 1990s, most research relied on tiny datasets. Numerous papers addressed the UCI collection of datasets, many of which contained only hundreds or (a few) thousands of images captured in unnatural settings with low resolution.

In 2009, the ImageNet dataset was released, challenging researchers to learn models from 1 million examples, 1,000 each from 1,000 distinct categories of objects. The researchers who introduced this dataset, led by Fei-Fei Li, leveraged Google Image Search to prefilter large candidate sets for each category and employed the Amazon Mechanical Turk crowdsourcing pipeline to confirm, for each image, whether it belonged to the associated category. This scale was unprecedented. The associated competition, dubbed the ImageNet Challenge, pushed computer vision and machine learning research forward, challenging researchers to identify which models performed best at a greater scale than academics had previously considered.

7.1.1.2. Missing Ingredient - Hardware

Deep learning models are voracious consumers of compute cycles. Training can take hundreds of epochs, and each iteration requires passing data through many layers of computationally expensive linear algebra operations. This is one of the main reasons why, in the 1990s and early 2000s, simpler algorithms based on efficiently optimized convex objectives were preferred.

Graphical processing units (GPUs) proved to be a game changer in making deep learning feasible. These chips had long been developed for accelerating graphics processing to benefit computer games. In particular, they were optimized for high-throughput 4x4 matrix-vector products, which are needed for many computer graphics tasks. Fortunately, this math is strikingly similar to that required to calculate convolutional layers. Around that time, NVIDIA and ATI had begun optimizing GPUs for general compute operations, going as far as to market them as General Purpose GPUs (GPGPU).

To provide some intuition, consider the cores of a modern microprocessor (CPU). Each of the cores is fairly powerful, running at a high clock frequency and sporting large caches (up to several MB of L3). Each core is well-suited to executing a wide range of instructions, with branch predictors, a deep pipeline, and other bells and whistles that enable it to run a large variety of programs. This apparent strength, however, is also its Achilles heel: general purpose cores are very expensive to build. They require lots of chip area, a sophisticated support structure (memory interfaces, caching logic between cores, high speed interconnects, etc.), and they are comparatively bad at any single task. Modern laptops have up to 4 cores, and even high end servers rarely exceed 64 cores, simply because it is not cost effective.

By comparison, GPUs consist of 100-1000 small processing elements (the details differ somewhat between NVIDIA, ATI, ARM and other chip vendors), often grouped into larger groups (NVIDIA calls them warps). While each core is relatively weak, sometimes even running at sub-1GHz clock frequency, it is the total number of such cores that makes GPUs orders of magnitude faster than CPUs. For instance, NVIDIA’s latest Volta generation offers up to 120 TFlops per chip for specialized instructions (and up to 24 TFlops for more general purpose ones), while floating point performance of CPUs has not exceeded 1 TFlop to date. The reason this is possible is actually quite simple: first, power consumption tends to grow quadratically with clock frequency. Hence, for the power budget of a CPU core that runs 4x faster (a typical number), you can use 16 GPU cores at 1/4 the speed, which yields 16 x 1/4 = 4x the performance. Furthermore, GPU cores are much simpler (in fact, for a long time they were not even able to execute general purpose code), which makes them more energy efficient. Lastly, many operations in deep learning require high memory bandwidth. Again, GPUs shine here with buses that are at least 10x as wide as many CPUs.
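To make that back-of-the-envelope argument explicit, under the stated assumption that power grows quadratically with clock frequency \(f\): a core clocked at \(4f\) draws roughly \((4f)^2 / f^2 = 16\) times the power of a core clocked at \(f\), so the power budget of one fast CPU core pays for 16 slow GPU cores. Their combined throughput is then \(16 \times \frac{1}{4} = 4\) times that of the single fast core.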

Back to 2012. A major breakthrough came when Alex Krizhevsky and Ilya Sutskever implemented a deep convolutional neural network that could run on GPU hardware. They realized that the computational bottlenecks in CNNs (convolutions and matrix multiplications) are all operations that could be parallelized in hardware. Using two NVIDIA GTX 580s with 3GB of memory, they implemented fast convolutions. The code cuda-convnet was good enough that for several years it was the industry standard and powered the first couple years of the deep learning boom.

7.1.2. AlexNet

AlexNet was introduced in 2012, named after Alex Krizhevsky, the first author of the breakthrough ImageNet classification paper [Krizhevsky et al., 2012]. AlexNet, which employed an 8-layer convolutional neural network, won the ImageNet Large Scale Visual Recognition Challenge 2012 by a phenomenally large margin. This network proved, for the first time, that the features obtained by learning can transcend manually-designed features, breaking the previous paradigm in computer vision. The architectures of AlexNet and LeNet are very similar, as Fig. 7.1.2 illustrates. Note that we provide a slightly streamlined version of AlexNet, removing some of the design quirks that were needed in 2012 to make the model fit on two small GPUs.

https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/alexnet.svg

Fig. 7.1.2 LeNet (left) and AlexNet (right)

The design philosophies of AlexNet and LeNet are very similar, but there are also significant differences. First, AlexNet is much deeper than the comparatively small LeNet5. AlexNet consists of eight layers: five convolutional layers, two fully-connected hidden layers, and one fully-connected output layer. Second, AlexNet used the ReLU instead of the sigmoid as its activation function. Let us delve into the details below.

7.1.2.1. Architecture

In AlexNet’s first layer, the convolution window shape is \(11\times11\). Since most images in ImageNet are more than ten times higher and wider than the MNIST images, objects in ImageNet data tend to occupy more pixels. Consequently, a larger convolution window is needed to capture the object. The convolution window shape in the second layer is reduced to \(5\times5\), followed by \(3\times3\). In addition, after the first, second, and fifth convolutional layers, the network adds maximum pooling layers with a window shape of \(3\times3\) and a stride of 2. Moreover, AlexNet has ten times more convolution channels than LeNet.
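As a sanity check on these choices, we can track the spatial dimensions with the usual formula for a convolutional or pooling layer: for input size \(n\), kernel size \(k\), padding \(p\), and stride \(s\), the output size is \(\lfloor (n + 2p - k)/s \rfloor + 1\). For a \(224 \times 224\) input, the first \(11 \times 11\) convolution with stride 4 and no padding yields \(\lfloor (224 - 11)/4 \rfloor + 1 = 54\), and the subsequent \(3 \times 3\) max pooling with stride 2 yields \(\lfloor (54 - 3)/2 \rfloor + 1 = 26\), which matches the layer output shapes printed further below.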

After the last convolutional layer are two fully-connected layers with 4096 outputs each. These two huge fully-connected layers account for nearly 1 GB of model parameters. Due to the limited memory in early GPUs, the original AlexNet used a dual data stream design, so that each of its two GPUs could be responsible for storing and computing only its half of the model. Fortunately, GPU memory is comparatively abundant now, so we rarely need to break up models across GPUs these days (our version of the AlexNet model deviates from the original paper in this aspect).

7.1.2.2. Activation Functions

In addition, AlexNet changed the sigmoid activation function to the simpler ReLU activation function. On the one hand, the computation of the ReLU activation function is simpler; for example, it does not have the exponentiation operation found in the sigmoid activation function. On the other hand, the ReLU activation function makes model training easier under different parameter initialization methods. This is because, when the output of the sigmoid activation function is very close to 0 or 1, the gradient in these regions is almost 0, so that backpropagation cannot continue to update some of the model parameters. In contrast, the gradient of the ReLU activation function in the positive interval is always 1. Therefore, if the model parameters are not properly initialized, the sigmoid function may obtain a gradient of almost 0 in the positive interval, so that the model cannot be effectively trained.
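In formulas: with the sigmoid \(\sigma(x) = 1/(1 + e^{-x})\), the gradient is \(\sigma'(x) = \sigma(x)(1 - \sigma(x))\), which approaches 0 whenever \(\sigma(x)\) is close to 0 or 1, i.e., for inputs of large magnitude. For the ReLU, \(\mathrm{ReLU}(x) = \max(0, x)\), the gradient is 1 for \(x > 0\) and 0 for \(x < 0\), so gradients in the positive interval never vanish.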

7.1.2.3. Capacity Control and Preprocessing

AlexNet controls the model complexity of the fully-connected layers with dropout (Section 4.6), while LeNet only uses weight decay. To augment the data even further, the training loop of AlexNet added a great deal of image augmentation, such as flipping, clipping, and color changes. This makes the model more robust, and the larger effective sample size reduces overfitting. We will discuss data augmentation in greater detail in the section on image augmentation.
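For illustration only (it is not used in the training pipeline later in this section, which applies only Resize and ToTensor), here is a minimal sketch of how a flip-based augmentation could be attached to the Fashion-MNIST dataset builder in DJL. It assumes that the RandomFlipLeftRight transform in ai.djl.modality.cv.transform is available in your DJL version; everything else mirrors the dataset code used later.

import ai.djl.basicdataset.cv.classification.FashionMnist;
import ai.djl.modality.cv.transform.RandomFlipLeftRight;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.training.dataset.Dataset;

// Sketch: random horizontal flips are applied before resizing and tensor
// conversion. This augmented dataset is not used in the training run below.
FashionMnist augmentedTrainIter = FashionMnist.builder()
        .addTransform(new RandomFlipLeftRight()) // randomly flip each image left-right
        .addTransform(new Resize(224))
        .addTransform(new ToTensor())
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(128, true) // batch size 128, shuffled
        .build();
augmentedTrainIter.prepare();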

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Training.java
%load ../utils/Accumulator.java
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;
NDManager manager = NDManager.newBaseManager();

SequentialBlock block = new SequentialBlock();
// Here, we use a larger 11 x 11 window to capture objects. At the same time,
// we use a stride of 4 to greatly reduce the height and width of the output.
// Here, the number of output channels is much larger than that in LeNet
block
        .add(Conv2d.builder()
                .setKernelShape(new Shape(11, 11))
                .optStride(new Shape(4, 4))
                .setFilters(96).build())
        .add(Activation::relu)
        .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
        // Make the convolution window smaller, set padding to 2 for consistent
        // height and width across the input and output, and increase the
        // number of output channels
        .add(Conv2d.builder()
                .setKernelShape(new Shape(5, 5))
                .optPadding(new Shape(2, 2))
                .setFilters(256).build())
        .add(Activation::relu)
        .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
        // Use three successive convolutional layers and a smaller convolution
        // window. Except for the final convolutional layer, the number of
        // output channels is further increased. Pooling layers are not used to
        // reduce the height and width of input after the first two
        // convolutional layers
        .add(Conv2d.builder()
                .setKernelShape(new Shape(3, 3))
                .optPadding(new Shape(1, 1))
                .setFilters(384).build())
        .add(Activation::relu)
        .add(Conv2d.builder()
                .setKernelShape(new Shape(3, 3))
                .optPadding(new Shape(1, 1))
                .setFilters(384).build())
        .add(Activation::relu)
        .add(Conv2d.builder()
                .setKernelShape(new Shape(3, 3))
                .optPadding(new Shape(1, 1))
                .setFilters(256).build())
        .add(Activation::relu)
        .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
        // Here, the number of outputs of the fully connected layer is several
        // times larger than that in LeNet. Use the dropout layer to mitigate
        // overfitting
        .add(Blocks.batchFlattenBlock())
        .add(Linear
                .builder()
                .setUnits(4096)
                .build())
        .add(Activation::relu)
        .add(Dropout
                .builder()
                .optRate(0.5f)
                .build())
        .add(Linear
                .builder()
                .setUnits(4096)
                .build())
        .add(Activation::relu)
        .add(Dropout
                .builder()
                .optRate(0.5f)
                .build())
        // Output layer. Since we are using Fashion-MNIST, the number of
        // classes is 10, instead of 1000 as in the paper
        .add(Linear.builder().setUnits(10).build());
SequentialBlock {
    Conv2d
    LambdaBlock
    maxPool2d
    Conv2d
    LambdaBlock
    maxPool2d
    Conv2d
    LambdaBlock
    Conv2d
    LambdaBlock
    Conv2d
    LambdaBlock
    maxPool2d
    batchFlatten
    Linear
    LambdaBlock
    Dropout
    Linear
    LambdaBlock
    Dropout
    Linear
}

We construct a single-channel data instance with both height and width of 224 to observe the output shape of each layer. It matches our diagram above.

float lr = 0.01f;

Model model = Model.newInstance("cnn");
model.setBlock(block);

Loss loss = Loss.softmaxCrossEntropyLoss();

Tracker lrt = Tracker.fixed(lr);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss).optOptimizer(sgd) // Optimizer (loss function)
                .addEvaluator(new Accuracy()) // Model Accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);

NDArray X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 224, 224));
trainer.initialize(X.getShape());

Shape currentShape = X.getShape();

for (int i = 0; i < block.getChildren().size(); i++) {
    Shape[] newShape = block.getChildren().get(i).getValue().getOutputShapes(new Shape[]{currentShape});
    currentShape = newShape[0];
    System.out.println(block.getChildren().get(i).getKey() + " layer output : " + currentShape);
}
INFO Training on: 4 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.070 ms.
01Conv2d layer output : (1, 96, 54, 54)
02LambdaBlock layer output : (1, 96, 54, 54)
03LambdaBlock layer output : (1, 96, 26, 26)
04Conv2d layer output : (1, 256, 26, 26)
05LambdaBlock layer output : (1, 256, 26, 26)
06LambdaBlock layer output : (1, 256, 12, 12)
07Conv2d layer output : (1, 384, 12, 12)
08LambdaBlock layer output : (1, 384, 12, 12)
09Conv2d layer output : (1, 384, 12, 12)
10LambdaBlock layer output : (1, 384, 12, 12)
11Conv2d layer output : (1, 256, 12, 12)
12LambdaBlock layer output : (1, 256, 12, 12)
13LambdaBlock layer output : (1, 256, 5, 5)
14LambdaBlock layer output : (1, 6400)
15Linear layer output : (1, 4096)
16LambdaBlock layer output : (1, 4096)
17Dropout layer output : (1, 4096)
18Linear layer output : (1, 4096)
19LambdaBlock layer output : (1, 4096)
20Dropout layer output : (1, 4096)
21Linear layer output : (1, 10)

7.1.3. Reading the Dataset

Although AlexNet uses ImageNet in the paper, we use Fashion-MNIST here since training an ImageNet model to convergence could take hours or days even on a modern GPU. One of the problems with applying AlexNet directly on Fashion-MNIST is that our images are of lower resolution (\(28 \times 28\) pixels) than ImageNet images. To make things work, we upsample them to \(224 \times 224\) (generally not a smart practice, but we do it here to be faithful to the AlexNet architecture). We perform this resizing with the Resize transform when building the Fashion-MNIST dataset.

int batchSize = 128;
int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

double[] trainLoss;
double[] testAccuracy;
double[] epochCount;
double[] trainAccuracy;

epochCount = new double[numEpochs];

for (int i = 0; i < epochCount.length; i++) {
    epochCount[i] = (i + 1);
}

FashionMnist trainIter = FashionMnist.builder()
        .addTransform(new Resize(224))
        .addTransform(new ToTensor())
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

FashionMnist testIter = FashionMnist.builder()
        .addTransform(new Resize(224))
        .addTransform(new ToTensor())
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

trainIter.prepare();
testIter.prepare();

7.1.4. Training

Now, we can start training AlexNet. Compared to LeNet in the previous section, the main change here is the use of a smaller learning rate and much slower training due to the deeper and wider network, the higher image resolution and the more costly convolutions.

Map<String, double[]> evaluatorMetrics = new HashMap<>();
double avgTrainTimePerEpoch = Training.trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics);
Training:    100% |████████████████████████████████████████| Accuracy: 0.62, SoftmaxCrossEntropyLoss: 1.15
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.62, SoftmaxCrossEntropyLoss: 1.15
INFO Validate: Accuracy: 0.77, SoftmaxCrossEntropyLoss: 0.63
Training:    100% |████████████████████████████████████████| Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.46
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.46
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.43
Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.34
Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.32
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.32
INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34
Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.29
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.29
INFO Validate: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.27
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.27
INFO Validate: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
Training:    100% |████████████████████████████████████████| Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.25
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.25
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.26
Training:    100% |████████████████████████████████████████| Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.23
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.23
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
Training:    100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
INFO Validate: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.27
Training:    100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.26
trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

System.out.printf("loss %.3f,", trainLoss[numEpochs - 1]);
System.out.printf(" train acc %.3f,", trainAccuracy[numEpochs - 1]);
System.out.printf(" test acc %.3f\n", testAccuracy[numEpochs - 1]);
System.out.printf("%.1f examples/sec", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
System.out.println();
loss 0.209, train acc 0.923, test acc 0.908
2479.3 examples/sec
https://d2l-java-resources.s3.amazonaws.com/img/chapter_convolution-modern-cnn-AlexNet.png

Fig. 7.1.3 Training loss, training accuracy, and test accuracy of AlexNet on Fashion-MNIST.

String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc");

Table data = Table.create("Data").addColumns(
            DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
            DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),
            StringColumn.create("lossLabel", lossLabel)
);

render(LinePlot.create("", data, "epoch", "metrics", "lossLabel"),"text/html");

7.1.5. Summary

  • AlexNet has a similar structure to that of LeNet, but uses more convolutional layers and a larger parameter space to fit the large-scale dataset ImageNet.

  • Today AlexNet has been surpassed by much more effective architectures, but it was a key step from the shallow networks of the past to the deep networks that are used nowadays.

  • Although it seems that there are only a few more lines in AlexNet’s implementation than in LeNet, it took the academic community many years to embrace this conceptual change and take advantage of its excellent experimental results. This was also due to the lack of efficient computational tools.

  • Dropout, ReLU and preprocessing were the other key steps in achieving excellent performance in computer vision tasks.

7.1.6. Exercises

  1. Try increasing the number of epochs. Compared with LeNet, how are the results different? Why?

  2. AlexNet may be too complex for the Fashion-MNIST dataset.

    • Try to simplify the model to make the training faster, while ensuring that the accuracy does not drop significantly.

    • Can you design a better model that works directly on \(28 \times 28\) images?

  3. Modify the batch size, and observe the changes in accuracy and GPU memory.

  4. Rooflines:

    • What is the dominant part for the memory footprint of AlexNet?

    • What is the dominant part for computation in AlexNet?

    • How about memory bandwidth when computing the results?

  5. Apply dropout and ReLU to LeNet5. Does it improve the results? How about preprocessing?