
4.3. Concise Implementation of Multilayer Perceptron

As you might expect, by relying on the DJL library, we can implement MLPs even more concisely. Let's set up the relevant libraries first.

%load ../utils/djl-imports
%load ../utils/plot-utils
import ai.djl.metric.*;
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;

4.3.1. The Model

As compared to our concise implementation of softmax regression implementation (Section 3.7), the only difference is that we add two Linear (fully-connected) layers (previously, we added one). The first is our hidden layer, which contains 256 hidden units and applies the ReLU activation function. The second is our output layer.

SequentialBlock net = new SequentialBlock();
// Flatten each 1x28x28 image into a vector of 784 values
net.add(Blocks.batchFlattenBlock(784));
// Hidden layer: 256 units followed by the ReLU activation
net.add(Linear.builder().setUnits(256).build());
net.add(Activation::relu);
// Output layer: one unit per class
net.add(Linear.builder().setUnits(10).build());
// Draw all weights from a normal distribution
net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);

Note that DJL, as usual, automatically infers the missing input dimensions to each layer.
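To see this inference at work, we can initialize the block by hand and inspect the resulting parameter shapes. This is a minimal, optional sketch, assuming the djl-imports header above covers NDManager, DataType, and the initializer classes (ai.djl.util.Pair is imported explicitly in case it is not). Since each Linear stores its weight as (output units, input features), the first layer's weight comes out as (256, 784) even though we never specified the input dimension:

import ai.djl.util.Pair;

NDManager manager = NDManager.newBaseManager();
// Initialization triggers shape inference from the input shape alone
net.initialize(manager, DataType.FLOAT32, new Shape(1, 784));
for (Pair<String, Parameter> param : net.getParameters()) {
    System.out.println(param.getKey() + ": " + param.getValue().getArray().getShape());
}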

The training loop is exactly the same as the one we used for softmax regression. This modularity lets us separate concerns about the model architecture from orthogonal considerations such as the optimizer, the loss function, and the dataset.

int batchSize = 256;
int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

double[] trainLoss = new double[numEpochs];
double[] trainAccuracy = new double[numEpochs];
double[] testAccuracy = new double[numEpochs];
double[] epochCount = new double[numEpochs];

// Load the Fashion-MNIST training and test splits with random sampling
FashionMnist trainIter = FashionMnist.builder()
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

FashionMnist testIter = FashionMnist.builder()
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

trainIter.prepare();
testIter.prepare();

// 1-based epoch indices for the x-axis of the plot below
for (int i = 0; i < epochCount.length; i++) {
    epochCount[i] = i + 1;
}

Map<String, double[]> evaluatorMetrics = new HashMap<>();

// Minibatch SGD with a fixed learning rate of 0.5
Tracker lrt = Tracker.fixed(0.5f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

Loss loss = Loss.softmaxCrossEntropyLoss();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // Optimizer
        .optDevices(Engine.getInstance().getDevices(1)) // single GPU
        .addEvaluator(new Accuracy()) // Model Accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

try (Model model = Model.newInstance("mlp")) {
    model.setBlock(net);

    try (Trainer trainer = model.newTrainer(config)) {
        trainer.initialize(new Shape(1, 784));
        trainer.setMetrics(new Metrics());

        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

        // Collect per-epoch results from the evaluators
        Metrics metrics = trainer.getMetrics();
        trainer.getEvaluators()
                .forEach(evaluator -> {
                    evaluatorMetrics.put("train_epoch_" + evaluator.getName(),
                            metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                    evaluatorMetrics.put("validate_epoch_" + evaluator.getName(),
                            metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                });
    }
}
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.061 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.71, SoftmaxCrossEntropyLoss: 0.78
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.71, SoftmaxCrossEntropyLoss: 0.78
INFO Validate: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.57
Training:    100% |████████████████████████████████████████| Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.48
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.48
INFO Validate: Accuracy: 0.81, SoftmaxCrossEntropyLoss: 0.53
Training:    100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
INFO Validate: Accuracy: 0.80, SoftmaxCrossEntropyLoss: 0.58
Training:    100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.40
Training:    100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
INFO Validate: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.49
Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.44
Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33
INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
INFO Validate: Accuracy: 0.81, SoftmaxCrossEntropyLoss: 0.47
Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.31
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.31
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.43
Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
INFO forward P50: 0.314 ms, P90: 0.375 ms
INFO training-metrics P50: 0.018 ms, P90: 0.022 ms
INFO backward P50: 0.616 ms, P90: 0.652 ms
INFO step P50: 0.901 ms, P90: 0.947 ms
INFO epoch P50: 1.265 s, P90: 1.672 s
trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

// Label each segment to match the column order below:
// test accuracy first, then train accuracy, then train loss
Arrays.fill(lossLabel, 0, testAccuracy.length, "test acc");
Arrays.fill(lossLabel, testAccuracy.length, testAccuracy.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, testAccuracy.length + trainAccuracy.length, lossLabel.length, "train loss");

Table data = Table.create("Data").addColumns(
            DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
            DoubleColumn.create("loss", ArrayUtils.addAll(testAccuracy , ArrayUtils.addAll(trainAccuracy, trainLoss))),
            StringColumn.create("lossLabel", lossLabel)
);

render(LinePlot.create("", data, "epochCount", "loss", "lossLabel"),"text/html");

4.3.2. Exercises

  1. Try adding different numbers of hidden layers (a configurable sketch follows this list). Which setting works best, keeping all other hyperparameters constant?

  2. Try out different activation functions. Which one works best?

  3. Try different schemes for initializing the weights. Which method works best?
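As a starting point for these exercises, here is a hedged sketch of a helper that rebuilds the network above with a configurable number of hidden layers, activation function, and weight initializer. The buildMlp name and its parameters are our own for illustration, not part of DJL; only DJL calls already used in this section are assumed, along with the djl-imports header covering the Initializer classes.

import java.util.function.Function;

// Hypothetical helper: same architecture as above, but with the depth,
// activation, and initializer exposed as parameters
SequentialBlock buildMlp(int numHiddenLayers, int hiddenUnits,
        Function<NDList, NDList> activation, Initializer initializer) {
    SequentialBlock block = new SequentialBlock();
    block.add(Blocks.batchFlattenBlock(784));
    for (int i = 0; i < numHiddenLayers; i++) {
        block.add(Linear.builder().setUnits(hiddenUnits).build());
        block.add(activation);
    }
    block.add(Linear.builder().setUnits(10).build());
    block.setInitializer(initializer, Parameter.Type.WEIGHT);
    return block;
}

// For example: two tanh hidden layers with Xavier initialization
SequentialBlock candidate = buildMlp(2, 256, Activation::tanh, new XavierInitializer());

Any block built this way can be dropped into the training loop above in place of net, which keeps the comparisons across architectures, activations, and initializers down to a one-line change.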