Run this notebook online:\ |Binder| or Colab: |Colab|
.. |Binder| image:: https://mybinder.org/badge_logo.svg
:target: https://mybinder.org/v2/gh/aws-samples/d2l-java/master?filepath=chapter_linear-networks/linear-regression-scratch.ipynb
.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/aws-samples/d2l-java/blob/colab/chapter_linear-networks/linear-regression-scratch.ipynb
.. _sec_linear_scratch:
Linear Regression Implementation from Scratch
=============================================
Now that you understand the key ideas behind linear regression, we can
begin to work through a hands-on implementation in code. In this
section, we will implement the entire method from scratch, including the
data pipeline, the model, the loss function, and the gradient descent
optimizer. While modern deep learning frameworks can automate nearly all
of this work, implementing things from scratch is the only way to make sure
that you really know what you are doing. Moreover, when it comes time to
customize models, defining our own layers, loss functions, etc.,
understanding how things work under the hood will prove handy. In this
section, we will rely only on ``NDArray`` and ``GradientCollector``.
Afterwards, we will introduce a more compact implementation, taking
advantage of DJL's bells and whistles. To start off, we import the few
required packages.
.. code:: java
%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/
%maven ai.djl:api:0.7.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26
%maven ai.djl.mxnet:mxnet-engine:0.7.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-b
.. code:: java
%%loadFromPOM
<dependency>
<groupId>tech.tablesaw</groupId>
<artifactId>tablesaw-jsplot</artifactId>
<version>0.38.1</version>
</dependency>
.. code:: java
%load ../utils/plot-utils
.. code:: java
import ai.djl.Device;
import ai.djl.ndarray.*;
import ai.djl.ndarray.types.*;
import ai.djl.ndarray.index.*;
import tech.tablesaw.api.*;
import tech.tablesaw.plotly.api.*;
import tech.tablesaw.plotly.components.*;
Generating the Dataset
----------------------
To keep things simple, we will construct an artificial dataset according
to a linear model with additive noise. Our task will be to recover this
model's parameters using the finite set of examples contained in our
dataset. We will keep the data low-dimensional so we can visualize it
easily. In the following code snippet, we generate a dataset containing
:math:`1000` examples, each consisting of :math:`2` features sampled
from a standard normal distribution. Thus our synthetic dataset will be
a matrix :math:`\mathbf{X}\in \mathbb{R}^{1000 \times 2}`.
The true parameters generating our data will be
:math:`\mathbf{w} = [2, -3.4]^\top` and :math:`b = 4.2` and our
synthetic labels will be assigned according to the following linear
model with noise term :math:`\epsilon`:
.. math:: \mathbf{y}= \mathbf{X} \mathbf{w} + b + \mathbf\epsilon.
You could think of :math:`\epsilon` as capturing potential measurement
errors on the features and labels. We will assume that the standard
assumptions hold and thus that :math:`\epsilon` obeys a normal
distribution with mean of :math:`0`. To make our problem easy, we will
set its standard deviation to :math:`0.01`. The following code generates
our synthetic dataset:
.. code:: java
class DataPoints {
private NDArray X, y;
public DataPoints(NDArray X, NDArray y) {
this.X = X;
this.y = y;
}
public NDArray getX() {
return X;
}
public NDArray getY() {
return y;
}
}
// Generate y = X w + b + noise
public DataPoints syntheticData(NDManager manager, NDArray w, float b, int numExamples) {
NDArray X = manager.randomNormal(new Shape(numExamples, w.size()));
NDArray y = X.dot(w).add(b);
// Add noise
y = y.add(manager.randomNormal(0, 0.01f, y.getShape(), DataType.FLOAT32, Device.defaultDevice()));
return new DataPoints(X, y);
}
NDManager manager = NDManager.newBaseManager();
NDArray trueW = manager.create(new float[]{2, -3.4f});
float trueB = 4.2f;
DataPoints dp = syntheticData(manager, trueW, trueB, 1000);
NDArray features = dp.getX();
NDArray labels = dp.getY();
Note that each row in ``features`` consists of a 2-dimensional data
point and that each row in ``labels`` consists of a 1-dimensional target
value (a scalar).
.. code:: java
System.out.printf("features: [%f, %f]\n", features.get(0).getFloat(0), features.get(0).getFloat(1));
System.out.println("label: " + labels.getFloat(0));
.. parsed-literal::
:class: output
features: [0.292537, -0.718359]
label: 7.2342157
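As a sanity check on the generating model, we can recompute the noise-free label of the first example by hand. The minimal plain-Java sketch below (no DJL; the class name ``FirstExampleCheck`` is hypothetical) plugs the printed feature values into :math:`y = 2x_1 - 3.4x_2 + 4.2`. The result, about 7.2275, differs from the printed label 7.2342 by roughly 0.0067, which is consistent with noise of standard deviation 0.01. The feature values are the ones printed above for this particular run and will differ on yours.

.. code:: java

    class FirstExampleCheck {
        // True parameters used by syntheticData above
        static final double[] TRUE_W = {2.0, -3.4};
        static final double TRUE_B = 4.2;

        // Noise-free label y = x . w + b for a single example
        static double noiseFreeLabel(double[] x) {
            double y = TRUE_B;
            for (int j = 0; j < x.length; j++) {
                y += x[j] * TRUE_W[j];
            }
            return y;
        }

        public static void main(String[] args) {
            double[] x = {0.292537, -0.718359}; // first row of features (from the run above)
            double label = 7.2342157;            // first label (from the run above)
            double clean = noiseFreeLabel(x);
            System.out.printf("noise-free label: %.4f, noise: %.4f%n", clean, label - clean);
        }
    }
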
By generating a scatter plot using the second feature ``features[:, 1]``
and ``labels``, we can clearly observe the linear correlation between
the two.
.. code:: java
float[] X = features.get(new NDIndex(":, 1")).toFloatArray();
float[] y = labels.toFloatArray();
Table data = Table.create("Data")
.addColumns(
FloatColumn.create("X", X),
FloatColumn.create("y", y)
);
ScatterPlot.create("Synthetic Data", data, "X", "y");
.. figure:: https://d2l-java-resources.s3.amazonaws.com/img/chapter_linear-networks_linear-regression-scratch_output1.png
Scatterplot
Reading the Dataset
-------------------
Recall that training models consists of making multiple passes over the
dataset, grabbing one minibatch of examples at a time, and using them to
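update our model. DJL's ``ArrayDataset`` takes care of this: it wraps
in-memory arrays and serves them in minibatches, optionally in random
order.
In the following code, we build an ``ArrayDataset`` by setting the
``features``, the ``labels``, and the sampling behavior.
``setSampling(batchSize, false)`` sets the minibatch size and disables
random sampling, so minibatches are read in order; passing ``true``
instead would shuffle the examples. Calling ``dataset.getData(manager)``
then yields minibatches of size ``batchSize``, each consisting of
features and labels.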
.. code:: java
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.dataset.Batch;
int batchSize = 10;
ArrayDataset dataset = new ArrayDataset.Builder()
.setData(features) // Set the Features
.optLabels(labels) // Set the Labels
.setSampling(batchSize, false) // Batch size 10; false = sequential (non-random) sampling
.build();
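To see what random minibatch sampling amounts to, here is a minimal plain-Java sketch (no DJL; ``MinibatchSketch`` and ``batchIndices`` are hypothetical names) that optionally shuffles the example indices and then cuts them into consecutive batches. It illustrates the idea only and is not how DJL's samplers are implemented.

.. code:: java

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.Random;

    class MinibatchSketch {
        // Returns the indices 0..numExamples-1 split into consecutive batches of at most
        // batchSize indices, shuffled first when shuffle is true. In this sketch the last
        // batch is simply shorter when batchSize does not divide numExamples.
        static List<int[]> batchIndices(int numExamples, int batchSize, boolean shuffle, long seed) {
            List<Integer> order = new ArrayList<>();
            for (int i = 0; i < numExamples; i++) {
                order.add(i);
            }
            if (shuffle) {
                Collections.shuffle(order, new Random(seed)); // random permutation of the indices
            }
            List<int[]> batches = new ArrayList<>();
            for (int start = 0; start < numExamples; start += batchSize) {
                int end = Math.min(start + batchSize, numExamples);
                int[] batch = new int[end - start];
                for (int k = start; k < end; k++) {
                    batch[k - start] = order.get(k);
                }
                batches.add(batch);
            }
            return batches;
        }

        public static void main(String[] args) {
            List<int[]> batches = batchIndices(1000, 10, true, 42L);
            System.out.println("number of batches: " + batches.size());
            System.out.println("first batch: " + java.util.Arrays.toString(batches.get(0)));
        }
    }

Each index appears in exactly one batch per pass, so a training loop would gather the rows of ``features`` and ``labels`` at those indices.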
In general, note that we want to use reasonably sized minibatches to
take advantage of the GPU hardware, which excels at parallelizing
operations. Because each example can be fed through our models in
parallel and the gradient of the loss function for each example can also
be taken in parallel, GPUs allow us to process hundreds of examples in
scarcely more time than it might take to process just a single example.
To build some intuition, let us read and print the first small batch of
data examples. The shape of the features in each minibatch tells us both
the minibatch size and the number of input features. Likewise, our
minibatch of labels will have a shape given by ``batchSize``.
.. code:: java
for (Batch batch : dataset.getData(manager)) {
// Call head() to get the first NDArray
NDArray X = batch.getData().head();
NDArray y = batch.getLabels().head();
System.out.println(X);
System.out.println(y);
// Don't forget to close the batch!
batch.close();
break;
}
.. parsed-literal::
:class: output
ND: (10, 2) gpu(0) float32
[[ 0.2925, -0.7184],
[ 0.1 , -0.3932],
[ 2.547 , -0.0034],
[ 0.0083, -0.251 ],
[ 0.129 , 0.3728],
[ 1.0822, -0.665 ],
[ 0.5434, -0.7168],
[-1.4913, 1.4805],
[ 0.1374, -1.2208],
[ 0.3072, 1.1135],
]
ND: (10) gpu(0) float32
[ 7.2342, 5.7411, 9.3138, 5.0536, 3.1772, 8.6284, 7.7434, -3.808 , 8.6185, 1.0259]
As we keep iterating, we obtain distinct minibatches successively
until all the data has been exhausted (try this). While ``ArrayDataset``
is convenient for didactic purposes, it has limitations that might get
us in trouble on real problems: it requires that we load all data in
memory, and, when sampling randomly, it performs lots of random memory
access. For larger datasets, DJL provides other dataset implementations
that can load data stored in files.
Initializing Model Parameters
-----------------------------
Before we can begin optimizing our model's parameters by gradient
descent, we need to have some parameters in the first place. In the
following code, we initialize weights by sampling random numbers from a
normal distribution with mean 0 and a standard deviation of
:math:`0.01`, setting the bias :math:`b` to :math:`0`.
.. code:: java
NDArray w = manager.randomNormal(0, 0.01f, new Shape(2, 1), DataType.FLOAT32, Device.defaultDevice());
NDArray b = manager.zeros(new Shape(1));
NDList params = new NDList(w, b);
Now that we have initialized our parameters, our next task is to update
them until they fit our data sufficiently well. Each update requires
taking the gradient (a multi-dimensional derivative) of our loss
function with respect to the parameters. Given this gradient, we can
update each parameter in the direction that reduces the loss.
Since nobody wants to compute gradients explicitly (this is tedious and
error prone), we use automatic differentiation to compute the gradient.
See :numref:`sec_gradcollector` for more details. Recall from the
autograd chapter that in order for ``GradientCollector`` to know that it
should store a gradient for our parameters, we need to invoke the
``attachGradient()`` function, allocating memory to store the gradients
that we plan to take.
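To see what ``GradientCollector`` saves us from, here is the gradient worked out by hand for this model. For one example with prediction :math:`\hat{y} = \mathbf{x}^\top \mathbf{w} + b` and the squared loss from the previous section, :math:`l = \frac{1}{2}(\hat{y} - y)^2`, the chain rule gives :math:`\partial l / \partial \mathbf{w} = (\hat{y} - y)\mathbf{x}` and :math:`\partial l / \partial b = \hat{y} - y`. The plain-Java sketch below (no DJL; ``ManualGradient`` is a hypothetical name) implements these formulas and checks them against central finite differences, a common way to test a hand-derived gradient.

.. code:: java

    class ManualGradient {
        // Squared loss 0.5 * (x . w + b - y)^2 for one example
        static double loss(double[] x, double y, double[] w, double b) {
            double yHat = b;
            for (int j = 0; j < x.length; j++) {
                yHat += x[j] * w[j];
            }
            double diff = yHat - y;
            return 0.5 * diff * diff;
        }

        // Analytic gradient, returned as {dl/dw_0, ..., dl/dw_(d-1), dl/db}
        static double[] gradient(double[] x, double y, double[] w, double b) {
            double yHat = b;
            for (int j = 0; j < x.length; j++) {
                yHat += x[j] * w[j];
            }
            double diff = yHat - y;
            double[] g = new double[x.length + 1];
            for (int j = 0; j < x.length; j++) {
                g[j] = diff * x[j];
            }
            g[x.length] = diff;
            return g;
        }

        // Central finite-difference approximation of the same gradient, with step h
        static double[] numericGradient(double[] x, double y, double[] w, double b, double h) {
            double[] g = new double[x.length + 1];
            for (int j = 0; j < x.length; j++) {
                double[] wPlus = w.clone();
                double[] wMinus = w.clone();
                wPlus[j] += h;
                wMinus[j] -= h;
                g[j] = (loss(x, y, wPlus, b) - loss(x, y, wMinus, b)) / (2 * h);
            }
            g[x.length] = (loss(x, y, w, b + h) - loss(x, y, w, b - h)) / (2 * h);
            return g;
        }
    }

Writing and checking such formulas by hand quickly becomes impractical for larger models, which is exactly the work automatic differentiation takes over.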
Defining the Model
------------------
Next, we must define our model, relating its inputs and parameters to
its outputs. Recall that to calculate the output of the linear model, we
simply take the matrix-vector dot product of the examples
:math:`\mathbf{X}` and the model's weights :math:`\mathbf{w}`, and add the offset
:math:`b` to each example. Note that below ``X.dot(w)`` is a vector and
``b`` is a scalar. Recall that when we add a vector and a scalar, the
scalar is added to each component of the vector.
.. code:: java
// Saved in Training.java for later use
public NDArray linreg(NDArray X, NDArray w, NDArray b) {
return X.dot(w).add(b);
}
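Broadcasting is what lets ``add(b)`` apply the single bias to every row. Written out with plain Java arrays (a sketch, not DJL code; ``PlainLinreg`` is a hypothetical name), the same computation makes the per-example loop explicit: each output is the dot product of one row of :math:`\mathbf{X}` with :math:`\mathbf{w}`, plus the same :math:`b`.

.. code:: java

    class PlainLinreg {
        // X has shape (n, d) and w has length d; returns the n predictions X w + b
        static double[] linreg(double[][] X, double[] w, double b) {
            double[] yHat = new double[X.length];
            for (int i = 0; i < X.length; i++) {
                double sum = 0.0;
                for (int j = 0; j < w.length; j++) {
                    sum += X[i][j] * w[j];
                }
                yHat[i] = sum + b; // the same scalar bias is added to every example
            }
            return yHat;
        }
    }
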
Defining the Loss Function
--------------------------
Since updating our model requires taking the gradient of our loss
function, we ought to define the loss function first. Here we will use
the squared loss function as described in the previous section. In the
implementation, we need to reshape the true value ``y`` to the shape of
the predicted value ``yHat``. The result returned by the following
function also has the same shape as ``yHat``.
.. code:: java
// Saved in Training.java for later use
public NDArray squaredLoss(NDArray yHat, NDArray y) {
// Reshape y to match yHat, then compute (yHat - y)^2 / 2 element-wise
NDArray diff = yHat.sub(y.reshape(yHat.getShape()));
return diff.mul(diff).div(2);
}
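For a quick numeric check of :math:`l = \frac{1}{2}(\hat{y} - y)^2`: predictions :math:`[2.5, 0.0]` against targets :math:`[3.0, -1.0]` give per-example losses :math:`[0.125, 0.5]`, since :math:`(-0.5)^2/2 = 0.125` and :math:`1^2/2 = 0.5`. The plain-Java sketch below (``PlainSquaredLoss`` is a hypothetical name) computes exactly that on 1-D arrays of equal length.

.. code:: java

    class PlainSquaredLoss {
        // Element-wise squared loss 0.5 * (yHat_i - y_i)^2; both arrays must have the same length
        static double[] squaredLoss(double[] yHat, double[] y) {
            if (yHat.length != y.length) {
                throw new IllegalArgumentException("yHat and y must have the same length");
            }
            double[] l = new double[y.length];
            for (int i = 0; i < y.length; i++) {
                double diff = yHat[i] - y[i];
                l[i] = diff * diff / 2;
            }
            return l;
        }
    }
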
Defining the Optimization Algorithm
-----------------------------------
As we discussed in the previous section, linear regression has a
closed-form solution. However, this is not a book about linear
regression, it is a book about deep learning. Since none of the other
models that this book introduces can be solved analytically, we will
take this opportunity to introduce your first working example of
stochastic gradient descent (SGD).
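For comparison, the closed-form solution is easy to write down in the special case of a single feature: the least-squares slope is :math:`w = \sum_i (x_i - \bar{x})(y_i - \bar{y}) / \sum_i (x_i - \bar{x})^2` and the intercept is :math:`b = \bar{y} - w\bar{x}`. The plain-Java sketch below (one feature only; ``ClosedFormLine`` is a hypothetical name) implements these formulas. With several features, the analogous solution requires solving the normal equations.

.. code:: java

    class ClosedFormLine {
        // Least-squares fit of y = w * x + b for a single feature; returns {w, b}
        static double[] fit(double[] x, double[] y) {
            int n = x.length;
            double xMean = 0.0, yMean = 0.0;
            for (int i = 0; i < n; i++) {
                xMean += x[i];
                yMean += y[i];
            }
            xMean /= n;
            yMean /= n;
            double cov = 0.0, var = 0.0;
            for (int i = 0; i < n; i++) {
                cov += (x[i] - xMean) * (y[i] - yMean);
                var += (x[i] - xMean) * (x[i] - xMean);
            }
            double w = cov / var;
            return new double[]{w, yMean - w * xMean};
        }
    }

For example, the points :math:`(0, 1), (1, 3), (2, 5), (3, 7)` lie exactly on :math:`y = 2x + 1`, and the fit returns :math:`w = 2` and :math:`b = 1`.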
At each step, using one batch randomly drawn from our dataset, we will
estimate the gradient of the loss with respect to our parameters. Next,
we will update our parameters (a small amount) in the direction that
reduces the loss. Recall from :numref:`sec_gradcollector` that after
we call ``backward()``, each parameter (``param``) will have its
gradient stored in ``param.getGradient()``. The following code applies
the SGD update, given a set of parameters, a learning rate, and a batch
size. The size of the update step is determined by the learning rate
``lr``. Because our loss is calculated as a sum over the batch of
examples, we normalize our step size by the batch size (``batchSize``),
so that the magnitude of a typical step size does not depend heavily on
our choice of the batch size.
.. code:: java
// Saved in Training.java for later use
public static void sgd(NDList params, float lr, int batchSize) {
for (int i = 0; i < params.size(); i++) {
NDArray param = params.get(i);
// Update param
// param = param - param.gradient * lr / batchSize
param.subi(param.getGradient().mul(lr).div(batchSize));
}
}
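To make the update rule concrete: if a parameter currently equals 1.0, the gradient summed over a minibatch of 10 examples is 5.0, and the learning rate is 0.03, the step is :math:`0.03 \times 5.0 / 10 = 0.015`, so the parameter becomes 0.985. The plain-Java sketch below (``PlainSgd`` is a hypothetical name) applies the same in-place update as ``sgd`` above; because plain arrays have no ``getGradient()``, the summed gradients are passed in explicitly.

.. code:: java

    class PlainSgd {
        // In-place update params[i] -= lr * grads[i] / batchSize,
        // where grads holds gradients summed over the minibatch
        static void sgd(double[] params, double[] grads, double lr, int batchSize) {
            for (int i = 0; i < params.length; i++) {
                params[i] -= lr * grads[i] / batchSize;
            }
        }
    }
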
Training
--------
Now that we have all of the parts in place, we are ready to implement
the main training loop. It is crucial that you understand this code
because you will see nearly identical training loops over and over again
throughout your career in deep learning.
In each iteration, we will grab a minibatch of examples, first passing
it through our model to obtain a set of predictions. After calculating
the loss, we call the ``backward()`` function to initiate the backwards
pass through the network, storing the gradients with respect to each
parameter in its corresponding ``gradient`` attribute. Technically, since
``NDArray`` is an interface for each engine's implementation, there is
no standard ``gradient`` attribute, but we can safely assume that we can
access the gradients, however they are stored, with ``getGradient()``.
Finally, we will call the optimization algorithm ``sgd`` to update the
model parameters. Since we previously set the batch size ``batchSize``
to :math:`10`, the loss ``l`` for each minibatch has shape
(:math:`10`, :math:`1`).
In summary, we will execute the following loop:
- Initialize parameters :math:`(\mathbf{w}, b)`
- Repeat until done

  - Compute gradient
    :math:`\mathbf{g} \leftarrow \partial_{(\mathbf{w},b)} \frac{1}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} l(\mathbf{x}^{(i)}, y^{(i)}, \mathbf{w}, b)`
  - Update parameters
    :math:`(\mathbf{w}, b) \leftarrow (\mathbf{w}, b) - \eta \mathbf{g}`
In the code below, ``l`` is a vector of the losses for each example in
the minibatch.
In each epoch (a pass through the data), we iterate through the
entire dataset (using the ``dataset.getData()`` function), passing once
through every example in the training dataset (assuming the number of
examples is divisible by the batch size). The number of epochs
``numEpochs`` and the learning rate ``lr`` are both hyper-parameters,
which we set here to :math:`3` and :math:`0.03`, respectively.
Unfortunately, setting hyper-parameters is tricky and requires some
adjustment by trial and error. We elide these details for now but revisit
them later in :numref:`chap_optimization`.
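A toy experiment hints at why the learning rate needs care. Consider gradient descent on the one-dimensional quadratic :math:`f(w) = \frac{1}{2}(w - w^*)^2`, whose gradient is :math:`w - w^*`. Each step :math:`w \leftarrow w - \eta (w - w^*)` multiplies the distance to the optimum by :math:`1 - \eta`, so with :math:`\eta = 0.5` the error halves every step, with :math:`\eta = 0.03` it shrinks slowly, and with :math:`\eta > 2` it grows and the iterates diverge. The sketch below runs this iteration; it uses a toy objective rather than our regression loss, and ``LearningRateToy`` is a hypothetical name.

.. code:: java

    class LearningRateToy {
        // Run `steps` gradient-descent updates on f(w) = 0.5 * (w - target)^2 starting from w0
        static double descend(double w0, double target, double lr, int steps) {
            double w = w0;
            for (int t = 0; t < steps; t++) {
                double grad = w - target; // f'(w)
                w -= lr * grad;
            }
            return w;
        }

        public static void main(String[] args) {
            for (double lr : new double[]{0.03, 0.5, 1.9, 2.5}) {
                double w = descend(0.0, 4.0, lr, 10);
                System.out.printf("lr = %.2f -> |w - w*| after 10 steps = %.4g%n", lr, Math.abs(w - 4.0));
            }
        }
    }
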
Note: We can replace ``linreg`` and ``squaredLoss`` with any net or loss
function respectively and still keep the same training structure shown
here.
.. code:: java
import ai.djl.training.GradientCollector;
import ai.djl.engine.Engine;
float lr = 0.03f; // Learning rate
int numEpochs = 3; // Number of epochs (full passes over the data)
// Attach Gradients
for (NDArray param : params) {
param.attachGradient();
}
for (int epoch = 0; epoch < numEpochs; epoch++) {
// Assuming the number of examples can be divided by the batch size, all
// the examples in the training dataset are used once in one epoch
// iteration. The features and labels of minibatch examples are given by X
// and y respectively.
for (Batch batch : dataset.getData(manager)) {
NDArray X = batch.getData().head();
NDArray y = batch.getLabels().head();
try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
// Minibatch loss in X and y
NDArray l = squaredLoss(linreg(X, params.get(0), params.get(1)), y);
gc.backward(l); // Compute gradient on l with respect to w and b
}
sgd(params, lr, batchSize); // Update parameters using their gradient
batch.close();
}
NDArray trainL = squaredLoss(linreg(features, params.get(0), params.get(1)), labels);
System.out.printf("epoch %d, loss %f\n", epoch + 1, trainL.mean().getFloat());
}
.. parsed-literal::
:class: output
epoch 1, loss 0.042579
epoch 2, loss 0.000161
epoch 3, loss 0.000052
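For readers who want to run the whole pipeline without DJL, here is a compact plain-Java sketch of the same training loop: synthetic data with :math:`\mathbf{w} = [2, -3.4]^\top`, :math:`b = 4.2`, and noise of standard deviation 0.01, minibatches of 10, learning rate 0.03, and 3 epochs. It is an illustration under these assumptions, not DJL's implementation: the gradients :math:`\partial l/\partial \mathbf{w} = (\hat{y} - y)\mathbf{x}` and :math:`\partial l/\partial b = \hat{y} - y` are written out by hand instead of being collected by ``GradientCollector``, the data is shuffled each epoch with a fixed seed, and ``PlainTrainingLoop`` is a hypothetical name. The recovered parameters will not match the DJL run digit for digit because the random numbers differ.

.. code:: java

    import java.util.Random;

    class PlainTrainingLoop {
        // Train y = X w + b with minibatch SGD on the squared loss; returns {w_0, ..., w_(d-1), b}
        static double[] train(double[][] X, double[] y, double lr, int batchSize, int numEpochs, long seed) {
            int n = X.length, d = X[0].length;
            Random rng = new Random(seed);
            double[] w = new double[d];
            for (int j = 0; j < d; j++) {
                w[j] = 0.01 * rng.nextGaussian(); // w ~ N(0, 0.01^2)
            }
            double b = 0.0;
            int[] order = new int[n];
            for (int i = 0; i < n; i++) order[i] = i;
            for (int epoch = 0; epoch < numEpochs; epoch++) {
                // Fisher-Yates shuffle of the example order
                for (int i = n - 1; i > 0; i--) {
                    int k = rng.nextInt(i + 1);
                    int tmp = order[i]; order[i] = order[k]; order[k] = tmp;
                }
                for (int start = 0; start < n; start += batchSize) {
                    int end = Math.min(start + batchSize, n);
                    double[] gw = new double[d];
                    double gb = 0.0;
                    for (int t = start; t < end; t++) {
                        int i = order[t];
                        double yHat = b;
                        for (int j = 0; j < d; j++) yHat += X[i][j] * w[j];
                        double diff = yHat - y[i]; // derivative of 0.5 * diff^2 w.r.t. yHat
                        for (int j = 0; j < d; j++) gw[j] += diff * X[i][j];
                        gb += diff;
                    }
                    int m = end - start; // actual batch size (the last batch may be smaller)
                    for (int j = 0; j < d; j++) w[j] -= lr * gw[j] / m;
                    b -= lr * gb / m;
                }
            }
            double[] result = new double[d + 1];
            System.arraycopy(w, 0, result, 0, d);
            result[d] = b;
            return result;
        }

        // Generate the synthetic dataset of this section and train on it; returns {w_0, w_1, b}
        static double[] demo(long seed) {
            Random rng = new Random(seed);
            double[] trueW = {2.0, -3.4};
            double trueB = 4.2;
            int n = 1000;
            double[][] X = new double[n][trueW.length];
            double[] y = new double[n];
            for (int i = 0; i < n; i++) {
                double yi = trueB;
                for (int j = 0; j < trueW.length; j++) {
                    X[i][j] = rng.nextGaussian();
                    yi += X[i][j] * trueW[j];
                }
                y[i] = yi + 0.01 * rng.nextGaussian(); // noise with standard deviation 0.01
            }
            return train(X, y, 0.03, 10, 3, seed + 1);
        }

        public static void main(String[] args) {
            double[] p = demo(0L);
            System.out.printf("w = [%.4f, %.4f], b = %.4f%n", p[0], p[1], p[2]);
        }
    }
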
In this case, because we synthesized the data ourselves, we know
precisely what the true parameters are. Thus, we can evaluate our
success in training by comparing the true parameters with those that we
learned through our training loop. Indeed they turn out to be very close
to each other.
.. code:: java
float[] wError = trueW.sub(params.get(0).reshape(trueW.getShape())).toFloatArray();
System.out.printf("Error in estimating w: [%f %f]\n", wError[0], wError[1]);
System.out.printf("Error in estimating b: %f\n", trueB - params.get(1).getFloat());
.. parsed-literal::
:class: output
Error in estimating w: [-0.000233 -0.000601]
Error in estimating b: 0.000912
Note that we should not take it for granted that we are able to recover
the parameters accurately. This only happens for a special category of
problems: strongly convex optimization problems with "enough" data to
ensure that the noisy samples allow us to recover the underlying
dependency. For most problems this is *not* the case. In fact, the
parameters of a deep network are rarely the same (or even close) between
two different runs, unless all conditions are identical, including the
order in which the data is traversed. However, in machine learning, we
are typically less concerned with recovering true underlying parameters,
and more concerned with parameters that lead to accurate prediction.
Fortunately, even on difficult optimization problems, stochastic
gradient descent can often find remarkably good solutions, owing partly
to the fact that, for deep networks, there exist many configurations of
the parameters that lead to accurate prediction.
Summary
-------
We saw how a linear regression model can be implemented and optimized from scratch,
using just ``NDArray`` and ``GradientCollector``, without any need for
defining layers, fancy optimizers, etc. This only scratches the surface
of what is possible. In the following sections, we will describe
additional models based on the concepts that we have just introduced and
learn how to implement them more concisely.
Exercises
---------
1. What would happen if we were to initialize the weights
:math:`\mathbf{w} = 0`? Would the algorithm still work?
2. Assume that you are Georg Simon Ohm trying to come up
with a model between voltage and current. Can you use
``GradientCollector`` to learn the parameters of your model?
3. Can you use Planck's Law to determine
the temperature of an object using spectral energy density?
4. What are the problems you might encounter if you wanted to extend
``GradientCollector`` to second derivatives? How would you fix them?
5. Why is the ``reshape()`` function needed in the ``squaredLoss()``
function?
6. Experiment using different learning rates to find out how fast the
loss function value drops.
7. If the number of examples cannot be divided by the batch size, what
happens to the ``dataset.getData()`` function's behavior?