
4.3. Concise Implementation of Multilayer Perceptron

As you might expect, by relying on the DJL library, we can implement MLPs even more concisely. Let's set up the relevant libraries first.

%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.7.0-SNAPSHOT
%maven ai.djl:model-zoo:0.7.0-SNAPSHOT
%maven ai.djl:basicdataset:0.7.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26

%maven ai.djl.mxnet:mxnet-engine:0.7.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-b

%%loadFromPOM
<dependency>
    <groupId>tech.tablesaw</groupId>
    <artifactId>tablesaw-jsplot</artifactId>
    <version>0.30.4</version>
</dependency>
%load ../utils/plot-utils.ipynb

import ai.djl.*;
import ai.djl.metric.*;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.ndarray.index.*;
import ai.djl.nn.*;
import ai.djl.nn.core.*;
import ai.djl.training.*;
import ai.djl.training.initializer.*;
import ai.djl.training.loss.*;
import ai.djl.training.listener.*;
import ai.djl.training.evaluator.*;
import ai.djl.training.optimizer.*;
import ai.djl.training.tracker.*;
import ai.djl.training.dataset.*;
import ai.djl.util.*;
import java.util.*;
import ai.djl.basicdataset.FashionMnist;

import tech.tablesaw.api.*;
import tech.tablesaw.plotly.api.*;
import tech.tablesaw.plotly.components.*;
import tech.tablesaw.plotly.Plot;
import tech.tablesaw.plotly.components.Figure;
import org.apache.commons.lang3.ArrayUtils;

4.3.1. The Model

Compared with our concise implementation of softmax regression (sec_softmax_gluon), the only difference is that we add two Linear (fully-connected) layers where previously we added one. The first is our hidden layer, which contains 256 hidden units and applies the ReLU activation function. The second is our output layer.

SequentialBlock net = new SequentialBlock();
net.add(Blocks.batchFlattenBlock(784));          // flatten 28x28 images into 784-long vectors
net.add(Linear.builder().setUnits(256).build()); // hidden layer with 256 units
net.add(Activation::relu);                       // ReLU activation
net.add(Linear.builder().setUnits(10).build());  // output layer with 10 units (one per class)
net.setInitializer(new NormalInitializer());     // initialize weights from a normal distribution

Note that DJL, as usual, automatically infers the missing input dimensions to each layer.
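
To make the inference concrete, here is a minimal sketch (illustrative only; it assumes the same DJL 0.7.0-style API used above, and in this notebook the trainer performs the real initialization below). After initializing the block with an input shape, the first Linear layer's weight comes out with shape (256, 784): the 784 input features were inferred, not specified by us.

// Illustrative sketch: inspect the parameter shapes that DJL inferred.
try (NDManager manager = NDManager.newBaseManager()) {
    net.initialize(manager, DataType.FLOAT32, new Shape(1, 784));
    for (Pair<String, Parameter> param : net.getParameters()) {
        // e.g. the first Linear layer's weight prints as (256, 784)
        System.out.println(param.getKey() + ": " + param.getValue().getArray().getShape());
    }
}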

The training loop is exactly the same as when we implemented softmax regression. This modularity enables us to separate matters concerning the model architecture from orthogonal considerations.

int batchSize = 256;
int numEpochs = 10;
double[] trainLoss = new double[numEpochs];
double[] trainAccuracy = new double[numEpochs];
double[] testAccuracy = new double[numEpochs];
double[] epochCount = new double[numEpochs];

FashionMnist trainIter = FashionMnist.builder()
                            .optUsage(Dataset.Usage.TRAIN)
                            .setSampling(batchSize, true)
                            .build();


FashionMnist testIter = FashionMnist.builder()
                            .optUsage(Dataset.Usage.TEST)
                            .setSampling(batchSize, true)
                            .build();

trainIter.prepare();
testIter.prepare();

for (int i = 0; i < epochCount.length; i++) {
    epochCount[i] = i + 1;
}

Map<String, double[]> evaluatorMetrics = new HashMap<>();
// SGD with a fixed learning rate of 0.5
Tracker lrt = Tracker.fixed(0.5f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

Loss loss = Loss.softmaxCrossEntropyLoss();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
                .optOptimizer(sgd) // Optimizer
                .addEvaluator(new Accuracy()) // Model accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

try (Model model = Model.newInstance("mlp")) {
    model.setBlock(net);

    try (Trainer trainer = model.newTrainer(config)) {
        trainer.initialize(new Shape(1, 784));
        trainer.setMetrics(new Metrics());

        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

        // Collect each evaluator's per-epoch results from the metrics
        Metrics metrics = trainer.getMetrics();
        trainer.getEvaluators().stream()
                .forEach(evaluator -> {
                    evaluatorMetrics.put("train_epoch_" + evaluator.getName(),
                            metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                    evaluatorMetrics.put("validate_epoch_" + evaluator.getName(),
                            metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                });
    }
}
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: 4 GPUs.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.7.0 in 0.458 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.12, SoftmaxCrossEntropyLoss: 2.64
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 1 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.12, SoftmaxCrossEntropyLoss: 2.64
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.32
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 2 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.32
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 3 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Training:    100% |████████████████████████████████████████| Accuracy: 0.11, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 4 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.11, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 5 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 6 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 7 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.11, SoftmaxCrossEntropyLoss: 2.32
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 8 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.11, SoftmaxCrossEntropyLoss: 2.32
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.31
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 9 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Epoch 10 finished.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - forward P50: 0.347 ms, P90: 0.394 ms
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - training-metrics P50: 0.013 ms, P90: 0.015 ms
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - backward P50: 0.645 ms, P90: 0.781 ms
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - step P50: 1.221 ms, P90: 1.377 ms
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - epoch P50: 29.024 s, P90: 51.097 s
trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

Arrays.fill(lossLabel, 0, trainLoss.length, "test acc");
Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                trainLoss.length + testAccuracy.length + trainAccuracy.length, "train loss");

Table data = Table.create("Data").addColumns(
            DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
            DoubleColumn.create("loss", ArrayUtils.addAll(testAccuracy , ArrayUtils.addAll(trainAccuracy, trainLoss))),
            StringColumn.create("lossLabel", lossLabel)
);

render(LinePlot.create("", data, "epochCount", "loss", "lossLabel"),"text/html");
Fig. 4.3.1 Line plot of training loss, training accuracy, and test accuracy over the epochs (image: https://d2l-java-resources.s3.amazonaws.com/img/chapter_multilayer-perceptrons_mlp-djl_output1.png)

4.3.2. Exercises

  1. Try adding different numbers of hidden layers. What setting (keeping other parameters and hyperparameters constant) works best? A sketch to build on appears after this list.

  2. Try out different activation functions. Which ones work best?

  3. Try different schemes for initializing the weights. What method works best?
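
As a starting point for these exercises, here is a hedged sketch that varies all three knobs at once. It reuses the DJL API from this section; the extra layer width of 128, the tanh activation, and the Xavier initializer are arbitrary illustrative choices, not tuned recommendations.

// Illustrative sketch for the exercises; all choices here are assumptions to experiment with.
SequentialBlock deeperNet = new SequentialBlock();
deeperNet.add(Blocks.batchFlattenBlock(784));
deeperNet.add(Linear.builder().setUnits(256).build());
deeperNet.add(Activation::tanh);                        // exercise 2: a different activation
deeperNet.add(Linear.builder().setUnits(128).build());  // exercise 1: an extra hidden layer
deeperNet.add(Activation::tanh);
deeperNet.add(Linear.builder().setUnits(10).build());
deeperNet.setInitializer(new XavierInitializer());      // exercise 3: a different initialization scheme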