Run this notebook online:\ |Binder| or Colab: |Colab|

.. |Binder| image:: https://mybinder.org/badge_logo.svg
   :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_multilayer-perceptrons/mlp-djl.ipynb
.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_multilayer-perceptrons/mlp-djl.ipynb

.. _sec_mlp_djl:

Concise Implementation of Multilayer Perceptrons
================================================

As you might expect, by relying on the DJL library, we can implement
MLPs even more concisely. Let's set up the relevant libraries first.

.. code:: java

    %load ../utils/djl-imports
    %load ../utils/plot-utils

.. code:: java

    import ai.djl.metric.*;
    import ai.djl.basicdataset.cv.classification.*;
    import org.apache.commons.lang3.ArrayUtils;

The Model
---------

Compared to our concise implementation of softmax regression
(:numref:`sec_softmax_djl`), the only difference is that we add *two*
``Linear`` (fully-connected) layers (previously, we added *one*). The
first is our hidden layer, which contains 256 hidden units and applies
the ReLU activation function. The second is our output layer.

.. code:: java

    SequentialBlock net = new SequentialBlock();
    net.add(Blocks.batchFlattenBlock(784));
    net.add(Linear.builder().setUnits(256).build());
    net.add(Activation::relu);
    net.add(Linear.builder().setUnits(10).build());
    net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);

Note that DJL, as usual, automatically infers the missing input
dimensions to each layer. The training loop is *exactly* the same as
when we implemented softmax regression. This modularity enables us to
separate matters concerning the model architecture from orthogonal
considerations.

.. code:: java

    int batchSize = 256;
    int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

    double[] trainLoss;
    double[] trainAccuracy;
    double[] testAccuracy;
    double[] epochCount;

    trainLoss = new double[numEpochs];
    trainAccuracy = new double[numEpochs];
    testAccuracy = new double[numEpochs];
    epochCount = new double[numEpochs];

    FashionMnist trainIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TRAIN)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();

    FashionMnist testIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TEST)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();

    trainIter.prepare();
    testIter.prepare();

    for (int i = 0; i < epochCount.length; i++) {
        epochCount[i] = i + 1;
    }

    // Typed as Map<String, double[]> so the per-epoch metric arrays can be
    // retrieved later without casts.
    Map<String, double[]> evaluatorMetrics = new HashMap<>();
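DJL's shape inference happens lazily, once a block sees a concrete
input shape. If you want to inspect what was inferred, the following is
a minimal sketch (not part of the original notebook): it builds a
throwaway copy of the architecture, named ``probe`` here purely for
illustration, initializes it with a ``(1, 784)`` input, and prints each
parameter's shape.

.. code:: java

    // Illustrative sketch only: trigger shape inference on a throwaway copy
    // of the network and print the shapes DJL inferred for each parameter.
    SequentialBlock probe = new SequentialBlock();
    probe.add(Blocks.batchFlattenBlock(784));
    probe.add(Linear.builder().setUnits(256).build());
    probe.add(Activation::relu);
    probe.add(Linear.builder().setUnits(10).build());

    try (NDManager manager = NDManager.newBaseManager()) {
        probe.initialize(manager, DataType.FLOAT32, new Shape(1, 784));
        // Expected: (256, 784) weight and (256) bias for the hidden layer,
        // then (10, 256) weight and (10) bias for the output layer.
        probe.getParameters().forEach(pair ->
            System.out.println(pair.getKey() + ": " + pair.getValue().getArray().getShape()));
    }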
.. code:: java

    Tracker lrt = Tracker.fixed(0.5f);
    Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

    Loss loss = Loss.softmaxCrossEntropyLoss();

    DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
            .optOptimizer(sgd) // Optimizer
            .optDevices(Engine.getInstance().getDevices(1)) // single GPU
            .addEvaluator(new Accuracy()) // Model Accuracy
            .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

    try (Model model = Model.newInstance("mlp")) {
        model.setBlock(net);

        try (Trainer trainer = model.newTrainer(config)) {

            trainer.initialize(new Shape(1, 784));
            trainer.setMetrics(new Metrics());

            EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

            // Collect the per-epoch results recorded by the evaluators
            Metrics metrics = trainer.getMetrics();

            trainer.getEvaluators().stream()
                .forEach(evaluator -> {
                    evaluatorMetrics.put("train_epoch_" + evaluator.getName(),
                            metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                    evaluatorMetrics.put("validate_epoch_" + evaluator.getName(),
                            metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                });
        }
    }

.. parsed-literal::
    :class: output

    INFO Training on: 1 GPUs.
    INFO Load MXNet Engine Version 1.9.0 in 0.061 ms.

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.71, SoftmaxCrossEntropyLoss: 0.78
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 1 finished.
    INFO Train: Accuracy: 0.71, SoftmaxCrossEntropyLoss: 0.78
    INFO Validate: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.57

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.48
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 2 finished.
    INFO Train: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.48
    INFO Validate: Accuracy: 0.81, SoftmaxCrossEntropyLoss: 0.53

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 3 finished.
    INFO Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
    INFO Validate: Accuracy: 0.80, SoftmaxCrossEntropyLoss: 0.58

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 4 finished.
    INFO Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
    INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.40

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 5 finished.
    INFO Train: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
    INFO Validate: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.49

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 6 finished.
    INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
    INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.44
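With training done, you may want to keep the learned parameters around.
The following is a hedged sketch (not part of the original notebook):
it uses DJL's ``Model.save``, would have to run *inside* the
``try (Model model = ...)`` block above while the model is still open,
and the output directory ``build/mlp`` is an arbitrary choice.

.. code:: java

    import java.nio.file.*;

    // Hedged sketch: persist the trained parameters to disk. This must run
    // inside the try (Model model = ...) block above, before the model is
    // closed. "build/mlp" is an arbitrary directory chosen for illustration.
    Path modelDir = Paths.get("build/mlp");
    Files.createDirectories(modelDir);
    model.setProperty("Epoch", String.valueOf(numEpochs));
    model.save(modelDir, "mlp");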
.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 7 finished.
    INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.33
    INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 8 finished.
    INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
    INFO Validate: Accuracy: 0.81, SoftmaxCrossEntropyLoss: 0.47

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.31
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 9 finished.
    INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.31
    INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.43

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 10 finished.
    INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
    INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.39
    INFO forward P50: 0.314 ms, P90: 0.375 ms
    INFO training-metrics P50: 0.018 ms, P90: 0.022 ms
    INFO backward P50: 0.616 ms, P90: 0.652 ms
    INFO step P50: 0.901 ms, P90: 0.947 ms
    INFO epoch P50: 1.265 s, P90: 1.672 s

Finally, we extract the per-epoch loss and accuracy values collected
above and plot them against the epoch count.

.. code:: java

    trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
    trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
    testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

    String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

    Arrays.fill(lossLabel, 0, trainLoss.length, "test acc");
    Arrays.fill(lossLabel, trainAccuracy.length,
            trainLoss.length + trainAccuracy.length, "train acc");
    Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
            trainLoss.length + testAccuracy.length + trainAccuracy.length, "train loss");

    Table data = Table.create("Data").addColumns(
        DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount,
                ArrayUtils.addAll(epochCount, epochCount))),
        DoubleColumn.create("loss", ArrayUtils.addAll(testAccuracy,
                ArrayUtils.addAll(trainAccuracy, trainLoss))),
        StringColumn.create("lossLabel", lossLabel)
    );

    render(LinePlot.create("", data, "epochCount", "loss", "lossLabel"), "text/html");

(In the rendered notebook, this cell displays an interactive line plot
of test accuracy, train accuracy, and train loss against the epoch
count.)
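Alongside the plot, a plain-text summary of the final epoch can serve
as a quick sanity check; this small addition is not in the original
notebook.

.. code:: java

    // Print the last epoch's metrics as a compact text summary of the run.
    System.out.printf("final epoch: train loss %.3f, train acc %.3f, test acc %.3f%n",
            trainLoss[numEpochs - 1], trainAccuracy[numEpochs - 1], testAccuracy[numEpochs - 1]);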
Exercises
---------

1. Try adding different numbers of hidden layers. What setting (keeping
   other parameters and hyperparameters constant) works best?
2. Try out different activation functions. Which ones work best? (See
   the sketch after this list for one starting point.)
3. Try different schemes for initializing the weights. What method
   works best?
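As a starting point for exercises 2 and 3, here is one possible
variation (illustrative, not a prescribed solution): swap the ReLU
activation for tanh and the normal initializer for Xavier
initialization, then rerun the training loop above with this block in
place of ``net``.

.. code:: java

    // One possible variation for exercises 2 and 3: tanh in place of ReLU,
    // Xavier initialization in place of the normal initializer.
    SequentialBlock tanhNet = new SequentialBlock();
    tanhNet.add(Blocks.batchFlattenBlock(784));
    tanhNet.add(Linear.builder().setUnits(256).build());
    tanhNet.add(Activation::tanh);
    tanhNet.add(Linear.builder().setUnits(10).build());
    tanhNet.setInitializer(new XavierInitializer(), Parameter.Type.WEIGHT);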