
3.6. Implementation of Softmax Regression from Scratch

Just as we implemented linear regression from scratch, we believe that multiclass logistic (softmax) regression is similarly fundamental and you ought to know the gory details of how to implement it yourself. As with linear regression, after doing things by hand we will breeze through an implementation in DJL for comparison. To begin, let us import the familiar packages.

%load ../utils/djl-imports
%load ../utils/plot-utils.ipynb
%load ../utils/
%load ../utils/
%load ../utils/


We will work with the Fashion-MNIST dataset, just introduced in Section 3.5, setting up an iterator with batch size \(256\). For the training set, we also enable random shuffling so that the examples are drawn in random order when batches are formed.

int batchSize = 256;
boolean randomShuffle = true;

// get training and validation dataset
FashionMnist trainingSet = FashionMnist.builder()
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize, randomShuffle)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

FashionMnist validationSet = FashionMnist.builder()
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, false)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

3.6.1. Initializing Model Parameters

As in our linear regression example, each example here will be represented by a fixed-length vector. Each example in the raw data is a \(28 \times 28\) image. In this section, we will flatten each image, treating them as \(784\)-long 1D vectors. In the future, we will talk about more sophisticated strategies for exploiting the spatial structure in images, but for now we treat each pixel location as just another feature.

Recall that in softmax regression, we have as many outputs as there are categories. Because our dataset has \(10\) categories, our network will have an output dimension of \(10\). Consequently, our weights will constitute a \(784 \times 10\) matrix and the biases will constitute a \(1 \times 10\) vector. As with linear regression, we will initialize our weights \(W\) with Gaussian noise and our biases to take the initial value \(0\).

int numInputs = 784;
int numOutputs = 10;

NDManager manager = NDManager.newBaseManager();
NDArray W = manager.randomNormal(0, 0.01f, new Shape(numInputs, numOutputs), DataType.FLOAT32);
NDArray b = manager.zeros(new Shape(numOutputs), DataType.FLOAT32);
NDList params = new NDList(W, b);

3.6.2. The Softmax

Before implementing the softmax regression model, let us briefly review how operators such as sum() work along specific dimensions in an NDArray. Given a matrix X, we can sum over all elements (the default) or only over elements along one axis: new int[]{0} sums down each column and new int[]{1} sums across each row. We wrap the axis in an int array because we can specify multiple axes as well. For example, calling sum() with new int[]{0, 1} sums over both the rows and the columns, which for a 2D array is simply the total of all its elements. Note that if X is an array with shape (\(2\), \(3\)) and we sum down the columns (X.sum(new int[]{0})), the result will be a 1D vector with shape (\(3\),). If we want to keep the number of axes of the original array (resulting in a 2D array with shape (\(1\), \(3\))) rather than collapsing the dimension that we summed over, we can pass true as the second argument to sum().

NDArray X = manager.create(new int[][]{{1, 2, 3}, {4, 5, 6}});
System.out.println(X.sum(new int[]{0}, true));
System.out.println(X.sum(new int[]{1}, true));
System.out.println(X.sum(new int[]{0, 1}, true));
ND: (1, 3) gpu(0) int32
[[ 5,  7,  9],
]

ND: (2, 1) gpu(0) int32
[[ 6],
 [15],
]

ND: (1, 1) gpu(0) int32
[[21],
]

We are now ready to implement the softmax function. Recall that softmax consists of three steps: First, we exponentiate each term (using exp()). Then, we sum over each row (we have one row per example in the batch) to get the normalization constant for each example. Finally, we divide each row by its normalization constant, ensuring that the result sums to \(1\). Before looking at the code, let us recall how this looks expressed as an equation:

\[\mathrm{softmax}(\mathbf{X})_{ij} = \frac{\exp(X_{ij})}{\sum_k \exp(X_{ik})}. \tag{3.6.1}\]

The denominator, or normalization constant, is also sometimes called the partition function (and its logarithm is called the log-partition function). The origins of that name are in statistical physics where a related equation models the distribution over an ensemble of particles.

public NDArray softmax(NDArray X) {
    NDArray Xexp = X.exp();
    NDArray partition = Xexp.sum(new int[]{1}, true);
    return Xexp.div(partition); // The broadcast mechanism is applied here
}

As you can see, for any random input, we turn each element into a non-negative number. Moreover, each row sums up to \(1\), as is required for a probability distribution. Note that while this looks correct mathematically, our implementation is a bit sloppy: it takes no precautions against numerical overflow or underflow caused by very large or very small elements of the matrix, unlike what we did in the naive Bayes section (sec_naive_bayes).

NDArray X = manager.randomNormal(new Shape(2, 5));
NDArray Xprob = softmax(X);
System.out.println(Xprob);
System.out.println(Xprob.sum(new int[]{1}));
ND: (2, 5) gpu(0) float32
[[0.1406, 0.117 , 0.5391, 0.0491, 0.1541],
 [0.204 , 0.0605, 0.0759, 0.5691, 0.0905],
]

ND: (2) gpu(0) float32
[1.    , 1.    ]
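One common precaution against the overflow mentioned above is to subtract each row's maximum before exponentiating; the softmax value is unchanged because the shift cancels between numerator and denominator. The following is a minimal sketch of that idea in plain Java arrays (so that it runs without DJL). The class name StableSoftmaxSketch and its method are hypothetical and not part of DJL or the book's utilities.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; not part of DJL or the d2l utilities.
final class StableSoftmaxSketch {
    /** Row-wise softmax that subtracts each row's maximum before exponentiating. */
    static double[][] softmax(double[][] x) {
        double[][] out = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            double max = Arrays.stream(x[i]).max().orElse(0.0);
            double[] row = new double[x[i].length];
            double partition = 0.0;
            for (int j = 0; j < row.length; j++) {
                row[j] = Math.exp(x[i][j] - max); // the largest term becomes exp(0) = 1
                partition += row[j];
            }
            for (int j = 0; j < row.length; j++) {
                row[j] /= partition; // normalize so the row sums to 1
            }
            out[i] = row;
        }
        return out;
    }

    public static void main(String[] args) {
        // Math.exp(1000.0) overflows to Infinity, so a direct implementation
        // would produce NaN for the first row; the shifted version does not.
        double[][] probs = softmax(new double[][]{{1000.0, 1001.0, 1002.0}, {1.0, 2.0, 3.0}});
        System.out.println(Arrays.deepToString(probs));
    }
}
```

Both rows in the usage example differ only by a constant shift, so they should map to the same probabilities.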

3.6.3. The Model

Now that we have defined the softmax operation, we can implement the softmax regression model. The code below defines the forward pass through the network. Note that we use the reshape() function to flatten each original image in the batch into a vector of length numInputs before passing the data through our model.

// We need to wrap `net()` in a class so that we can reference the method
// and pass it as a parameter to a function or save it in a variable
public class Net {
    public static NDArray net(NDArray X) {
        NDArray currentW = params.get(0);
        NDArray currentB = params.get(1);
        return softmax(X.reshape(new Shape(-1, numInputs)).dot(currentW).add(currentB));
    }
}
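To make the shapes in the forward pass concrete, here is a small sketch in plain Java arrays that flattens one image and computes the logits \(\mathbf{x}\mathbf{W} + \mathbf{b}\) before the softmax is applied. It is only an illustration of the arithmetic; the class FlattenAffineSketch and its methods are hypothetical and do not use DJL.

```java
// Hypothetical helper for illustration only; not part of DJL or the d2l utilities.
final class FlattenAffineSketch {
    /** Flattens a 2D image row by row into a 1D feature vector. */
    static float[] flatten(float[][] image) {
        int h = image.length;
        int w = image[0].length;
        float[] flat = new float[h * w];
        for (int r = 0; r < h; r++) {
            System.arraycopy(image[r], 0, flat, r * w, w);
        }
        return flat;
    }

    /** Computes the logits o = xW + b for one example; W has shape (numInputs, numOutputs). */
    static float[] logits(float[] x, float[][] W, float[] b) {
        float[] o = b.clone();
        for (int i = 0; i < x.length; i++) {
            for (int j = 0; j < o.length; j++) {
                o[j] += x[i] * W[i][j];
            }
        }
        return o;
    }

    public static void main(String[] args) {
        float[] x = flatten(new float[28][28]);   // 28 x 28 image -> 784 features
        float[][] W = new float[784][10];         // one column of weights per class
        float[] o = logits(x, W, new float[10]);  // one logit per class
        System.out.println(x.length + " features, " + o.length + " logits");
    }
}
```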

3.6.4. The Loss Function

Next, we need to implement the cross-entropy loss function, introduced in Section 3.4. This may be the most common loss function in all of deep learning because, at the moment, classification problems far outnumber regression problems.

Recall that cross-entropy takes the negative log-likelihood of the predicted probability assigned to the true label, \(-\log P(y \mid x)\). Rather than iterating over the predictions with a Java for loop (which tends to be inefficient), we can use the NDArray get() function in conjunction with NDIndex to select the appropriate terms from the matrix of softmax entries. This is known as the pick() operator in frameworks such as MXNet; PyTorch offers gather() for the same purpose. Below, we illustrate the usage of NDIndex on a toy example, with \(3\) categories and \(2\) examples.

In the index string ":, {}", the ":" selects every row, and the {} placeholder is filled in by manager.create(new int[]{0, 2}), an NDArray holding the values 0 and 2 that selects the 0th and 2nd elements of each row.

Note: when using NDIndex in this way, the NDArray that supplies the indices must be of type int or long. You can use the toType() function to change the type of an NDArray, as will be shown below.

NDArray yHat = manager.create(new float[][]{{0.1f, 0.3f, 0.6f}, {0.3f, 0.2f, 0.5f}});
yHat.get(new NDIndex(":, {}", manager.create(new int[]{0, 2})));
ND: (2, 2) gpu(0) float32
[[0.1, 0.6],
 [0.3, 0.5],
]

Now we can implement the cross-entropy loss function efficiently with just one line of code.

// Cross-entropy only cares about the target class's probability,
// so we pick the column given by the label in each row
public class LossFunction {
    public static NDArray crossEntropy(NDArray yHat, NDArray y) {
        // Here, y is not guaranteed to be of datatype int or long,
        // and in our case we know it is a float32.
        // We must first convert it to int or long (here we choose int)
        // before we can use it with NDIndex to "pick" indices.
        // toType() also takes a boolean for returning a copy of the existing NDArray,
        // but we don't want that so we pass in `false`.
        NDIndex pickIndex = new NDIndex()
                 .addAllDim(Math.floorMod(-1, yHat.getShape().dimension()))
                 .addPickDim(y.toType(DataType.INT32, false));
        return yHat.get(pickIndex).log().neg();
    }
}
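To spell out what the loss computes for each example, the following plain-Java sketch uses an explicit loop over the toy predictions from above, with the labels assumed to be \(0\) and \(2\). It is meant only as a reference for the arithmetic, not as a replacement for the vectorized version; the class CrossEntropySketch is hypothetical and does not use DJL.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; not part of DJL or the d2l utilities.
final class CrossEntropySketch {
    /** Returns -log(yHat[i][y[i]]) for every example i. */
    static double[] crossEntropy(double[][] yHat, int[] y) {
        double[] loss = new double[y.length];
        for (int i = 0; i < y.length; i++) {
            // Only the probability assigned to the true class enters the loss
            loss[i] = -Math.log(yHat[i][y[i]]);
        }
        return loss;
    }

    public static void main(String[] args) {
        double[][] yHat = {{0.1, 0.3, 0.6}, {0.3, 0.2, 0.5}};
        int[] y = {0, 2}; // assumed labels for the two toy examples
        // First loss is -log(0.1), second is -log(0.5)
        System.out.println(Arrays.toString(crossEntropy(yHat, y)));
    }
}
```

The first example is penalized more heavily because the model assigned only \(0.1\) to its true class.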

3.6.5. Classification Accuracy

Given the predicted probability distribution yHat, we typically choose the class with highest predicted probability whenever we must output a hard prediction. Indeed, many applications require that we make a choice. Gmail must categorize an email into Primary, Social, Updates, or Forums. It might estimate probabilities internally, but at the end of the day it has to choose one among the categories.

When predictions are consistent with the actual category y, they are correct. The classification accuracy is the fraction of all predictions that are correct. Although it can be difficult to optimize accuracy directly (it is not differentiable), it is often the performance metric that we care most about, and we will nearly always report it when training classifiers.

To compute accuracy we do the following: First, we execute yHat.argMax(1), where 1 is the axis, to gather the predicted classes (given by the indices of the largest entries in each row). The result has the same shape as the variable y. Now we just need to check how frequently the two match. Since the equality function eq() is datatype-sensitive (e.g., an int32 and a float32 are never equal), we also need to convert both to the same type (we pick int32). The result is an NDArray containing entries of 0 (false) and 1 (true). We then sum the number of true entries and convert the result to a float. Finally, we get the mean by dividing by the number of data points.

// Saved in the utils for later use
public float accuracy(NDArray yHat, NDArray y) {
    // Check size of 1st dimension greater than 1
    // to see if we have multiple samples
    if (yHat.getShape().size(1) > 1) {
        // Argmax gets index of maximum args for given axis 1
        // Convert yHat to same dataType as y (int32)
        // Sum up number of true entries
        return yHat.argMax(1).toType(DataType.INT32, false).eq(y.toType(DataType.INT32, false))
            .sum().toType(DataType.FLOAT32, false).getFloat();
    }
    return yHat.toType(DataType.INT32, false).eq(y.toType(DataType.INT32, false))
        .sum().toType(DataType.FLOAT32, false).getFloat();
}

We will continue to use the variable yHat defined in the pick() example as the predicted probability distribution, and define y as the corresponding labels. We can see that the first example’s prediction category is \(2\) (the largest element of the row is \(0.6\) with an index of \(2\)), which is inconsistent with the actual label, \(0\). The second example’s prediction category is \(2\) (the largest element of the row is \(0.5\) with an index of \(2\)), which is consistent with the actual label, \(2\). Therefore, the classification accuracy rate for these two examples is \(0.5\).

NDArray y = manager.create(new int[]{0,2});
accuracy(yHat, y) / y.size();
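The same calculation can be traced by hand with plain Java arrays. The sketch below takes the arg max of each row, counts the matches with the labels, and divides by the number of examples; the class AccuracySketch and its methods are hypothetical and do not use DJL.

```java
// Hypothetical helper for illustration only; not part of DJL or the d2l utilities.
final class AccuracySketch {
    /** Index of the largest entry in a row (the first one wins on ties). */
    static int argMax(double[] row) {
        int best = 0;
        for (int j = 1; j < row.length; j++) {
            if (row[j] > row[best]) {
                best = j;
            }
        }
        return best;
    }

    /** Number of rows whose arg max equals the label. */
    static int numCorrect(double[][] yHat, int[] y) {
        int correct = 0;
        for (int i = 0; i < y.length; i++) {
            if (argMax(yHat[i]) == y[i]) {
                correct++;
            }
        }
        return correct;
    }

    public static void main(String[] args) {
        double[][] yHat = {{0.1, 0.3, 0.6}, {0.3, 0.2, 0.5}};
        int[] y = {0, 2};
        // Both rows predict class 2, so only the second example is correct
        System.out.println((double) numCorrect(yHat, y) / y.length);
    }
}
```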

Similarly, we can evaluate the accuracy for model net on the dataset (accessed via dataIterator).

import java.util.function.UnaryOperator;
import java.util.function.BinaryOperator;
// Saved in the utils for future use
public float evaluateAccuracy(UnaryOperator<NDArray> net, Iterable<Batch> dataIterator) {
    Accumulator metric = new Accumulator(2);  // numCorrectExamples, numExamples
    // Note: only the first batch of the iterator is evaluated here
    Batch batch = dataIterator.iterator().next();
    NDArray X = batch.getData().head();
    NDArray y = batch.getLabels().head();
    metric.add(new float[]{accuracy(net.apply(X), y), (float)y.size()});
    batch.close();  // Release the resources held by the batch

    return metric.get(0) / metric.get(1);
}

Here Accumulator is a utility class to accumulate sums over multiple numbers.

import java.util.Arrays;

// Saved in utils for future use
/* Sum a list of numbers over time */
public class Accumulator {
    float[] data;

    public Accumulator(int n) {
        data = new float[n];
    }

    /* Adds a set of numbers to the array */
    public void add(float[] args) {
        for (int i = 0; i < args.length; i++) {
            data[i] += args[i];
        }
    }

    /* Resets the array */
    public void reset() {
        Arrays.fill(data, 0f);
    }

    /* Returns the data point at the given index */
    public float get(int index) {
        return data[index];
    }
}

Because we initialized the net model with random weights, the accuracy of this model should be close to random guessing, i.e., \(0.1\) for \(10\) classes.

evaluateAccuracy(Net::net, validationSet.getData(manager));
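As a rough sanity check, suppose the freshly initialized model outputs a nearly uniform distribution over the classes (an assumption that holds only when the initial weights are small; the symbol \(q\) for the number of classes is introduced here purely for illustration). Then both the expected accuracy and the expected cross-entropy loss before training follow directly:

\[\hat{y}_j \approx \frac{1}{q} \;\text{ for all } j \quad\Longrightarrow\quad P\big(\operatorname{argmax}_j \hat{y}_j = y\big) \approx \frac{1}{q} = 0.1, \qquad -\log \hat{y}_y \approx \log q = \log 10 \approx 2.303.\]

An initial loss far from this value would suggest checking the initialization or the input scaling.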

3.6.6. Model Training

The training loop for softmax regression should look strikingly familiar if you read through our implementation of linear regression in Section 3.2. Here we refactor the implementation to make it reusable. First, we define a function to train for one data epoch. Note that updater() is a general function to update the model parameters, which accepts the batch size as an argument. Currently, it is a wrapper of Training.sgd().

public static interface ParamConsumer {
     void accept(NDList params, float lr, int batchSize);
}

public float[] trainEpochCh3(UnaryOperator<NDArray> net, Iterable<Batch> trainIter, BinaryOperator<NDArray> loss, ParamConsumer updater) {
    Accumulator metric = new Accumulator(3); // trainLossSum, trainAccSum, numExamples

    // Attach Gradients
    for (NDArray param : params) {
        param.setRequiresGradient(true);
    }

    for (Batch batch : trainIter) {
        NDArray X = batch.getData().head();
        NDArray y = batch.getLabels().head();
        X = X.reshape(new Shape(-1, numInputs));

        try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
            // Minibatch loss in X and y
            NDArray yHat = net.apply(X);
            NDArray l = loss.apply(yHat, y);
            gc.backward(l);  // Compute gradient on l with respect to w and b
            metric.add(new float[]{l.sum().toType(DataType.FLOAT32, false).getFloat(),
                                   accuracy(yHat, y),
                                   (float)y.size()});
        }
        updater.accept(params, lr, batch.getSize());  // Update parameters using their gradient

        batch.close();  // Release the resources held by the batch
    }
    // Return trainLoss, trainAccuracy
    return new float[]{metric.get(0) / metric.get(2), metric.get(1) / metric.get(2)};
}

Before showing the implementation of the training function, we define a utility class that plots data as an animation. Again, it aims to simplify the code in later chapters.

import tech.tablesaw.api.Row;
import tech.tablesaw.columns.Column;

// Saved in utils
/* Animates a graph with real-time data. */
class Animator {
    private String id; // Id reference of graph(for updating graph)
    private Table data; // Data Points

    public Animator() {
        id = "";

        // Incrementally plot data
        data = Table.create("Data")
        .addColumns(
            FloatColumn.create("epoch", new float[]{}),
            FloatColumn.create("value", new float[]{}),
            StringColumn.create("metric", new String[]{})
        );
    }

    // Add a single metric to the table
    public void add(float epoch, float value, String metric) {
        Row newRow = data.appendRow();
        newRow.setFloat("epoch", epoch);
        newRow.setFloat("value", value);
        newRow.setString("metric", metric);
    }

    // Add accuracy, train accuracy, and train loss metrics for a given epoch
    // Then plot it on the graph
    public void add(float epoch, float accuracy, float trainAcc, float trainLoss) {
        add(epoch, trainLoss, "train loss");
        add(epoch, trainAcc, "train accuracy");
        add(epoch, accuracy, "test accuracy");
    }

    // Display the graph (or update it if it is already displayed)
    public void show() {
        if (id.equals("")) {
            id = display(LinePlot.create("", data, "epoch", "value", "metric"));
            return;
        }
        update();
    }

    // Update the graph
    public void update() {
        updateDisplay(id, LinePlot.create("", data, "epoch", "value", "metric"));
    }

    // Returns the column at the given index
    // if it is a float column
    // Otherwise returns null
    public float[] getY(Table t, int index) {
        Column c = t.column(index);
        if (c.type() == ColumnType.FLOAT) {
            float[] newArray = new float[c.size()];
            // Copy element by element: System.arraycopy cannot copy boxed Floats into a float[]
            for (int i = 0; i < c.size(); i++) {
                newArray[i] = (Float) c.get(i);
            }
            return newArray;
        }
        return null;
    }
}

The training function then runs multiple epochs and reports the training progress after each one.

public void trainCh3(UnaryOperator<NDArray> net, Dataset trainDataset, Dataset testDataset,
                     BinaryOperator<NDArray> loss, int numEpochs, ParamConsumer updater)
                                                            throws IOException, TranslateException {
    Animator animator = new Animator();
    for (int i = 1; i <= numEpochs; i++) {
        float[] trainMetrics = trainEpochCh3(net, trainDataset.getData(manager), loss, updater);
        float accuracy = evaluateAccuracy(net, testDataset.getData(manager));
        float trainAccuracy = trainMetrics[1];
        float trainLoss = trainMetrics[0];

        animator.add(i, accuracy, trainAccuracy, trainLoss);
        System.out.printf("Epoch %d: Test Accuracy: %f\n", i, accuracy);
        System.out.printf("Train Accuracy: %f\n", trainAccuracy);
        System.out.printf("Train Loss: %f\n", trainLoss);
    }
}

Again, we use minibatch stochastic gradient descent to optimize the loss function of the model. Note that the number of epochs (numEpochs) and the learning rate (lr) are both adjustable hyper-parameters. By changing their values, we may be able to increase the classification accuracy of the model. In practice we will want to split our data three ways into training, validation, and test data, using the validation data to choose the best values of our hyper-parameters.

int numEpochs = 5;
float lr = 0.1f;

public class Updater {
    public static void updater(NDList params, float lr, int batchSize) {
        Training.sgd(params, lr, batchSize);
    }
}

trainCh3(Net::net, trainingSet, validationSet, LossFunction::crossEntropy, numEpochs, Updater::updater);
Epoch 1: Test Accuracy: 0.820313
Train Accuracy: 0.750000
Train Loss: 0.783057
Epoch 2: Test Accuracy: 0.816406
Train Accuracy: 0.814083
Train Loss: 0.570253
Epoch 3: Test Accuracy: 0.828125
Train Accuracy: 0.825600
Train Loss: 0.525173
Epoch 4: Test Accuracy: 0.839844
Train Accuracy: 0.831117
Train Loss: 0.501283
Epoch 5: Test Accuracy: 0.859375
Train Accuracy: 0.836683
Train Loss: 0.485033
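The body of Training.sgd() is not shown in this section. Assuming it follows the usual minibatch rule, in which each parameter moves against its summed gradient, scaled by the learning rate and divided by the batch size, the update can be sketched with plain Java arrays as follows. The class SgdSketch and its method are hypothetical and are not the DJL or book implementation.

```java
import java.util.Arrays;

// Hypothetical helper for illustration only; not the implementation of Training.sgd().
final class SgdSketch {
    /** In-place update: param[i] <- param[i] - lr * grad[i] / batchSize. */
    static void sgd(float[] param, float[] grad, float lr, int batchSize) {
        for (int i = 0; i < param.length; i++) {
            // grad is assumed to be summed over the minibatch, hence the division
            param[i] -= lr * grad[i] / batchSize;
        }
    }

    public static void main(String[] args) {
        float[] w = {1.0f, 2.0f};
        float[] grad = {2.56f, -2.56f};
        sgd(w, grad, 0.1f, 256); // each entry moves by 0.1 * 2.56 / 256 = 0.001
        System.out.println(Arrays.toString(w));
    }
}
```

Dividing by the batch size keeps the step size comparable when the batch size changes, which is one reason the updater receives it as an argument.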

3.6.7. Prediction

Now that training is complete, our model is ready to classify some images. Given a series of images, we will compare their actual labels (first line of text output) and the model predictions (second line of text output).

// Saved in the FashionMnistUtils class for later use
// Number should be < batchSize for this function to work properly
public BufferedImage predictCh3(UnaryOperator<NDArray> net, ArrayDataset dataset, int number, NDManager manager)
    throws IOException, TranslateException {
    int[] predLabels = new int[number];

    // Predict on the first batch only, then keep the first `number` predictions
    Batch batch = dataset.getData(manager).iterator().next();
    NDArray X = batch.getData().head();
    int[] yHat = net.apply(X).argMax(1).toType(DataType.INT32, false).toIntArray();
    for (int i = 0; i < number; i++) {
        predLabels[i] = yHat[i];
    }

    return FashionMnistUtils.showImages(dataset, predLabels, 28, 28, 4, manager);
}

predictCh3(Net::net, validationSet, 6, manager)

3.6.8. Summary

With softmax regression, we can train models for multi-category classification. The training loop is very similar to that in linear regression: retrieve and read data, define models and loss functions, then train models using optimization algorithms. As you will soon find out, most common deep learning models have similar training procedures.

3.6.9. Exercises

  1. In this section, we directly implemented the softmax function based on the mathematical definition of the softmax operation. What problems might this cause (hint: try to calculate the size of \(\exp(50)\))?

  2. The function crossEntropy() in this section is implemented according to the definition of the cross-entropy loss function. What could be the problem with this implementation (hint: consider the domain of the logarithm)?

  3. What solutions can you think of to fix the two problems above?

  4. Is it always a good idea to return the most likely label? E.g., would you do this for medical diagnosis?

  5. Assume that we want to use softmax regression to predict the next word based on some features. What are some problems that might arise from a large vocabulary?