
3.5. The Image Classification Dataset

In sec_naive_bayes, we trained a naive Bayes classifier using the MNIST dataset, introduced in 1998 [LeCun et al., 1998]. While MNIST had a good run as a benchmark dataset, even simple models by today’s standards achieve classification accuracy over 95%, making it unsuitable for distinguishing between stronger and weaker models. Today, MNIST serves more as a sanity check than as a benchmark. To up the ante just a bit, we will focus our discussion in the coming sections on the qualitatively similar but comparatively complex Fashion-MNIST dataset [Xiao et al., 2017], released in 2017.

%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.7.0-SNAPSHOT
%maven ai.djl:basicdataset:0.7.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26

%maven ai.djl.mxnet:mxnet-engine:0.7.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-b
import ai.djl.ndarray.*;
import ai.djl.training.dataset.*;
import ai.djl.basicdataset.FashionMnist;
%load ../utils/StopWatch.java

3.5.1. Getting the Dataset

Just as with MNIST, DJL makes it easy to download and load the Fashion-MNIST dataset into memory via the FashionMnist class contained in ai.djl.basicdataset. We briefly work through the mechanics of loading and exploring the dataset below. Please refer to sec_naive_bayes for more details on loading data.

Let us first define the getDataset() function, which obtains and reads the Fashion-MNIST dataset. It returns the training set or the validation set depending on the usage argument (Dataset.Usage.TRAIN for training and Dataset.Usage.TEST for validation). You can then call getData(manager) on the dataset to get the corresponding iterator. The function also takes batchSize and randomShuffle, which set the size of each batch and whether to shuffle the data randomly, respectively.

import java.io.IOException;
import ai.djl.translate.TranslateException;

// Saved in the FashionMnistUtils class for later use
public ArrayDataset getDataset(Dataset.Usage usage,
                               int batchSize,
                               boolean randomShuffle) throws IOException, TranslateException {
    FashionMnist fashionMnist = FashionMnist.builder()
                                            .optUsage(usage)
                                            .setSampling(batchSize, randomShuffle)
                                            .build();
    // Download the data if needed and load it into memory
    fashionMnist.prepare();
    return fashionMnist;
}
int batchSize = 256;
boolean randomShuffle = true;

ArrayDataset mnistTrain = getDataset(Dataset.Usage.TRAIN, batchSize, randomShuffle);
ArrayDataset mnistTest = getDataset(Dataset.Usage.TEST, batchSize, randomShuffle);

NDManager manager = NDManager.newBaseManager();

Fashion-MNIST consists of images from 10 categories, each represented by 6,000 images in the training set and 1,000 in the test set. Consequently, the training set and the test set contain 60,000 and 10,000 images, respectively.

System.out.println(mnistTrain.size());
System.out.println(mnistTest.size());
60000
10000

The images in Fashion-MNIST are associated with the following categories: t-shirt, trouser, pullover, dress, coat, sandal, shirt, sneaker, bag and ankle boot. The following functions convert numeric label indices to their names in text.

// Saved in the FashionMnistUtils class for later use
// Convert an array of numeric label indices to their text names
public String[] getFashionMnistLabels(int[] labelIndices) {
    String[] textLabels = {"t-shirt", "trouser", "pullover", "dress", "coat",
                           "sandal", "shirt", "sneaker", "bag", "ankle boot"};
    String[] convertedLabels = new String[labelIndices.length];
    for (int i = 0; i < labelIndices.length; i++) {
        convertedLabels[i] = textLabels[labelIndices[i]];
    }
    return convertedLabels;
}

// Convert a single numeric label index to its text name
public String getFashionMnistLabel(int labelIndex) {
    String[] textLabels = {"t-shirt", "trouser", "pullover", "dress", "coat",
                           "sandal", "shirt", "sneaker", "bag", "ankle boot"};
    return textLabels[labelIndex];
}
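
As a quick check of the index-to-name mapping, the following standalone sketch repeats the lookup outside the notebook. The class name LabelLookupDemo and the method toLabels are hypothetical names chosen for this illustration; they are not part of DJL or of the utilities saved above.

```java
import java.util.Arrays;

// Hypothetical standalone demo class; not part of DJL or the book's utilities.
public class LabelLookupDemo {
    // Same label order as in the functions above
    static final String[] TEXT_LABELS = {"t-shirt", "trouser", "pullover", "dress", "coat",
                                         "sandal", "shirt", "sneaker", "bag", "ankle boot"};

    // Map each numeric label index to its text name
    static String[] toLabels(int[] labelIndices) {
        String[] converted = new String[labelIndices.length];
        for (int i = 0; i < labelIndices.length; i++) {
            converted[i] = TEXT_LABELS[labelIndices[i]];
        }
        return converted;
    }

    public static void main(String[] args) {
        // prints [ankle boot, t-shirt, dress]
        System.out.println(Arrays.toString(toLabels(new int[] {9, 0, 3})));
    }
}
```

An index outside the range 0 to 9 raises an ArrayIndexOutOfBoundsException, which is usually the desired behavior for a label that should never occur.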

We can now create a function to visualize these examples. Do not worry too much about the specifics of the visualization; it is only meant to help build intuition for the data. We read a number of examples, scale each pixel intensity from the range 0-255 to the range 0-1, draw every pixel as a shade of gray, and display each image together with its label in an external window.

import java.awt.image.BufferedImage;
import java.awt.Graphics2D;
import java.awt.Graphics;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Component;
import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.JLabel;
import javax.swing.BoxLayout;
public class ImagePanel extends JPanel {
    int SCALE;
    BufferedImage img;

    public ImagePanel() {
        this.SCALE = 1;
    }
    public ImagePanel(int scale, BufferedImage img) {
        this.SCALE = scale;
        this.img = img;
    }
    @Override
    protected void paintComponent(Graphics g) {
        super.paintComponent(g); // let Swing paint the background first
        Graphics2D g2d = (Graphics2D) g;
        g2d.scale(SCALE, SCALE);
        g2d.drawImage(this.img, 0, 0, this);
    }
}

public class Container extends JPanel {
    public Container(String label) {
        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
        JLabel l = new JLabel(label, JLabel.CENTER);
        l.setAlignmentX(Component.CENTER_ALIGNMENT);
        add(l);
    }
    public Container(String trueLabel, String predLabel) {
        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
        JLabel l = new JLabel(trueLabel, JLabel.CENTER);
        l.setAlignmentX(Component.CENTER_ALIGNMENT);
        add(l);
        JLabel l2 = new JLabel(predLabel, JLabel.CENTER);
        l2.setAlignmentX(Component.CENTER_ALIGNMENT);
        add(l2);
    }
}
import ai.djl.translate.TranslateException;

// Saved in the FashionMnistUtils class for later use
public void showImages(ArrayDataset dataset,
                       int number, int WIDTH, int HEIGHT, int SCALE,
                       NDManager manager)
    throws IOException, TranslateException {
    // Plot a list of images
    JFrame frame = new JFrame("Fashion Mnist");
    for (int record = 0; record < number; record++) {
        NDArray X = dataset.get(manager, record).getData().get(0).squeeze(-1);
        int y = (int)dataset.get(manager, record).getLabels().get(0).getFloat();
        BufferedImage img = new BufferedImage(WIDTH, HEIGHT, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = (Graphics2D) img.getGraphics();
        for(int i = 0; i < WIDTH; i++) {
            for(int j = 0; j < HEIGHT; j++) {
                float c = X.getFloat(j, i) / 255;  // scale down to between 0 and 1
                g.setColor(new Color(c, c, c)); // set as a gray color
                g.fillRect(i, j, 1, 1);
            }
        }
        JPanel panel = new ImagePanel(SCALE, img);
        panel.setPreferredSize(new Dimension(WIDTH * SCALE, HEIGHT * SCALE));
        JPanel container = new Container(getFashionMnistLabel(y));
        container.add(panel);
        frame.getContentPane().add(container);
    }
    frame.getContentPane().setLayout(new FlowLayout());
    frame.pack();
    frame.setVisible(true);
}
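
The nested loop in showImages turns each pixel intensity into a gray color and paints it one pixel at a time. To experiment with that step without opening a window, the sketch below uses a different technique: it writes the raw 0-255 intensities directly into the raster of a TYPE_BYTE_GRAY image. The class name GrayImageDemo and the method toGrayImage are hypothetical, and the input is a plain int array rather than an NDArray.

```java
import java.awt.image.BufferedImage;

// Hypothetical standalone demo; the names are illustrative only.
public class GrayImageDemo {
    // Build a grayscale image from intensities in [0, 255], indexed as pixels[row][col]
    static BufferedImage toGrayImage(int[][] pixels) {
        int height = pixels.length;
        int width = pixels[0].length;
        BufferedImage img = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
        for (int row = 0; row < height; row++) {
            for (int col = 0; col < width; col++) {
                // x is the column and y is the row; band 0 is the single gray channel
                img.getRaster().setSample(col, row, 0, pixels[row][col]);
            }
        }
        return img;
    }

    public static void main(String[] args) {
        int[][] pixels = {{0, 128}, {255, 64}};
        BufferedImage img = toGrayImage(pixels);
        // Read back the pixel in row 0, column 1
        System.out.println(img.getRaster().getSample(1, 0, 0));
    }
}
```

Note the index order: the array is indexed by row and then column, while the raster expects the x (column) coordinate first. This mirrors the X.getFloat(j, i) call in showImages.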

Here are the images and their corresponding labels (in text) for the first few examples in the training dataset.

final int SCALE = 4;
final int WIDTH = 28;
final int HEIGHT = 28;

/* Uncomment the following line and run to display images.
   It will open in another window. */
// showImages(mnistTrain, 18, WIDTH, HEIGHT, SCALE, manager);
https://d2l-java-resources.s3.amazonaws.com/img/fashion_mnist_labels.png

Fig. 3.5.1 Fashion-MNIST labels.

3.5.2. Reading a Minibatch

To make our life easier when reading from the training and test sets, we use the iterator returned by getData(manager). Recall that at each iteration, it reads a minibatch of batchSize examples. We then obtain X and y by calling getData() and getLabels() on each batch, respectively.
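
To get a feel for how many minibatches one pass over the data produces, the arithmetic can be sketched in plain Java. The BatchMath class below is a hypothetical helper, not a DJL API, and it assumes that the final partial batch is kept rather than dropped.

```java
// Hypothetical helper class for illustration; not a DJL API.
public class BatchMath {
    // Number of minibatches per pass, assuming the final partial batch is kept
    static int numBatches(int numExamples, int batchSize) {
        return (numExamples + batchSize - 1) / batchSize; // ceiling division
    }

    // Number of examples in the final minibatch
    static int lastBatchSize(int numExamples, int batchSize) {
        int remainder = numExamples % batchSize;
        return remainder == 0 ? batchSize : remainder;
    }

    public static void main(String[] args) {
        System.out.println(numBatches(60000, 256));    // prints 235
        System.out.println(lastBatchSize(60000, 256)); // prints 96
    }
}
```

With 60,000 training examples and a batch size of 256, one pass yields 234 full batches plus a final batch of 96 examples. With a batch size of 1, the same pass takes 60,000 iterations.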

Note: During training, reading data can be a significant performance bottleneck, especially when our model is simple or when our computer is fast.

Let us look at the time it takes to read the training data.
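
The StopWatch helper loaded from ../utils/StopWatch.java is not listed in this section. A minimal stand-in, assuming it only needs start() and a stop() that returns the elapsed seconds, might look like the sketch below. The class name SimpleStopWatch is hypothetical, and the actual helper may differ.

```java
// Hypothetical stand-in; the real StopWatch in ../utils/StopWatch.java may differ.
public class SimpleStopWatch {
    private long startNanos;

    // Record the current time as the starting point
    public void start() {
        startNanos = System.nanoTime();
    }

    // Return the seconds elapsed since the last call to start()
    public double stop() {
        return (System.nanoTime() - startNanos) / 1e9;
    }

    public static void main(String[] args) {
        SimpleStopWatch stopWatch = new SimpleStopWatch();
        stopWatch.start();
        long sum = 0;
        for (int i = 0; i < 1_000_000; i++) {
            sum += i; // some work to time
        }
        System.out.printf("sum=%d, %.4f sec%n", sum, stopWatch.stop());
    }
}
```

System.nanoTime() is used rather than System.currentTimeMillis() because it is intended for measuring elapsed time and is not affected by changes to the system clock.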

StopWatch stopWatch = new StopWatch();
stopWatch.start();
for (Batch batch : mnistTrain.getData(manager)) {
    // Fetch the features and labels of each minibatch; we only time the read
    NDArray X = batch.getData().head();
    NDArray y = batch.getLabels().head();
}
System.out.printf("%.2f sec\n", stopWatch.stop());
22.41 sec

We are now ready to work with the Fashion-MNIST dataset in the sections that follow.

3.5.3. Summary

  • Fashion-MNIST is an apparel classification dataset consisting of images representing 10 categories.

  • We will use this dataset in subsequent sections and chapters to evaluate various classification algorithms.

  • We store the shape of each image with a height of \(h\) pixels and a width of \(w\) pixels as \(h \times w\) or (h, w).

  • Data iterators are a key component for efficient performance. Rely on well-implemented iterators that exploit multi-threading to avoid slowing down your training loop.

3.5.4. Exercises

  1. Does reducing the batchSize (for instance, to 1) affect read performance?

  2. Use the DJL documentation to see which other datasets are available in ai.djl.basicdataset.