
8.3. Gradient Descent

In this section we introduce the basic concepts underlying gradient descent. The treatment is brief by necessity; see e.g., [Boyd & Vandenberghe, 2004] for an in-depth introduction to convex optimization. Although gradient descent is rarely used directly in deep learning, understanding it is key to understanding stochastic gradient descent algorithms. For instance, the optimization might diverge due to an overly large learning rate. This phenomenon can already be seen in gradient descent. Likewise, preconditioning is a common technique in gradient descent and carries over to more advanced algorithms. Let us start with a simple special case.

8.3.1. Gradient Descent in One Dimension

Gradient descent in one dimension is an excellent example to explain why the gradient descent algorithm may reduce the value of the objective function. Consider some continuously differentiable real-valued function \(f: \mathbb{R} \rightarrow \mathbb{R}\). Using a Taylor expansion (see the section on single-variable calculus) we obtain that

(8.3.1)\[f(x + \epsilon) = f(x) + \epsilon f'(x) + \mathcal{O}(\epsilon^2).\]

That is, to first order \(f(x+\epsilon)\) is determined by the function value \(f(x)\) and the first derivative \(f'(x)\) at \(x\). It is reasonable to assume that for small \(\epsilon\), moving in the direction of the negative gradient will decrease \(f\). To keep things simple we pick a fixed step size \(\eta > 0\) and choose \(\epsilon = -\eta f'(x)\). Plugging this into the Taylor expansion above we get

(8.3.2)\[f(x - \eta f'(x)) = f(x) - \eta f'^2(x) + \mathcal{O}(\eta^2 f'^2(x)).\]

If the derivative does not vanish, i.e., \(f'(x) \neq 0\), we make progress since \(\eta f'^2(x)>0\). Moreover, we can always choose \(\eta\) small enough for the higher order terms to become irrelevant. Hence we arrive at

(8.3.3)\[f(x - \eta f'(x)) \lessapprox f(x).\]

This means that, if we use

(8.3.4)\[x \leftarrow x - \eta f'(x)\]

to iterate \(x\), the value of \(f(x)\) may decrease. Therefore, in gradient descent we first choose an initial value \(x\) and a constant \(\eta > 0\) and then use them to update \(x\) repeatedly until a stop condition is reached, for example, when the magnitude of the gradient \(|f'(x)|\) is small enough or the number of iterations has reached a certain value.
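The two stop conditions mentioned above can be combined in one loop. The following is a minimal sketch rather than the notebook's own code: the class `GdStopSketch`, the method `minimize`, and the parameters `tol` and `maxSteps` are names assumed here for illustration, and the quadratic \(f(x) = x^2\) with \(f'(x) = 2x\) serves only as a test objective.

```java
import java.util.function.DoubleUnaryOperator;

final class GdStopSketch {
    /** Outcome of a run: final iterate, updates performed, and whether the gradient test passed. */
    static final class Result {
        final double x;
        final int steps;
        final boolean converged;

        Result(double x, int steps, boolean converged) {
            this.x = x;
            this.steps = steps;
            this.converged = converged;
        }
    }

    /**
     * Repeats x <- x - eta * f'(x) until |f'(x)| < tol or maxSteps updates were made.
     * tol and maxSteps are illustrative parameters, not part of the original notebook.
     */
    static Result minimize(DoubleUnaryOperator gradf, double x0, double eta,
                           double tol, int maxSteps) {
        double x = x0;
        for (int step = 0; step < maxSteps; step++) {
            double g = gradf.applyAsDouble(x);
            if (Math.abs(g) < tol) {
                return new Result(x, step, true); // gradient small enough: stop early
            }
            x -= eta * g; // gradient descent update
        }
        // Step budget exhausted: report whether the last iterate happens to pass the test
        return new Result(x, maxSteps, Math.abs(gradf.applyAsDouble(x)) < tol);
    }

    public static void main(String[] args) {
        // Assumed test objective f(x) = x^2, so f'(x) = 2x
        Result r = minimize(x -> 2 * x, 10.0, 0.2, 1e-3, 1000);
        System.out.printf("steps: %d, x: %f, converged: %b%n", r.steps, r.x, r.converged);
    }
}
```

With \(\eta = 0.2\) and a tolerance of \(10^{-3}\) on \(|f'(x)|\), the loop stops on the gradient test well before the step limit. With a learning rate that makes the iterates grow, only the step limit ends the loop, which is why the two conditions are useful together.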

For simplicity we choose the objective function \(f(x)=x^2\) to illustrate how to implement gradient descent. Although we know that \(x=0\) minimizes \(f(x)\), we still use this simple function to observe how \(x\) changes. As always, we begin by importing all required modules.

%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.7.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26

%maven ai.djl.mxnet:mxnet-engine:0.7.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-b
%load ../utils/plot-utils
%load ../utils/Functions.java
import ai.djl.ndarray.*;
import java.lang.Math;
import tech.tablesaw.plotly.traces.ScatterTrace;
Function<Float, Float> f = x -> x * x; // Objective Function
Function<Float, Float> gradf = x -> 2 * x; // Its Derivative

NDManager manager = NDManager.newBaseManager();

Next, we use \(x=10\) as the initial value and assume \(\eta=0.2\). Applying 10 iterations of gradient descent to \(x\), we can see that the value of \(x\) eventually approaches the optimal solution.

public float[] gd(float eta) {
    float x = 10f;
    float[] results = new float[11];
    results[0] = x;

    for (int i = 0; i < 10; i++) {
        x -= eta * gradf.apply(x);
        results[i + 1] = x;
    }
    System.out.printf("epoch 10, x: %f\n", x);
    return results;
}

float[] res = gd(0.2f);
epoch 10, x: 0.060466

The progress of optimizing over \(x\) can be plotted as follows.

/* Saved in GradDescUtils.java */
public void plotGD(float[] x, float[] y, float[] segment, Function<Float, Float> func,
                                 int width, int height) {
    // Function Line
    ScatterTrace trace = ScatterTrace.builder(Functions.floatToDoubleArray(x),
                                              Functions.floatToDoubleArray(y))
        .mode(ScatterTrace.Mode.LINE)
        .build();

    // GD Line
    ScatterTrace trace2 = ScatterTrace.builder(Functions.floatToDoubleArray(segment),
                                               Functions.floatToDoubleArray(Functions.callFunc(segment, func)))
        .mode(ScatterTrace.Mode.LINE)
        .build();

    // GD Points
    ScatterTrace trace3 = ScatterTrace.builder(Functions.floatToDoubleArray(segment),
                                               Functions.floatToDoubleArray(Functions.callFunc(segment, func)))
        .build();

    Layout layout = Layout.builder()
        .height(height)
        .width(width)
        .showLegend(false)
        .build();

    display(new Figure(layout, trace, trace2, trace3));
}
/* Saved in GradDescUtils.java */
public void showTrace(float[] res) {
    float n = 0;
    for (int i = 0; i < res.length; i++) {
        if (Math.abs(res[i]) > n) {
            n = Math.abs(res[i]);
        }
    }
    NDArray fLineND = manager.arange(-n, n, 0.01f);
    float[] fLine = fLineND.toFloatArray();
    plotGD(fLine, Functions.callFunc(fLine, f), res, f, 500, 400);
}

showTrace(res);

8.3.1.1. Learning Rate

The learning rate \(\eta\) can be set by the algorithm designer. If we use a learning rate that is too small, it will cause \(x\) to update very slowly, requiring more iterations to get a better solution. To show what happens in such a case, consider the progress in the same optimization problem for \(\eta = 0.05\). As we can see, even after 10 steps we are still very far from the optimal solution.

showTrace(gd(0.05f));
epoch 10, x: 3.486785

Conversely, if we use an excessively high learning rate, \(\left|\eta f'(x)\right|\) might be too large for the first-order Taylor expansion to remain accurate. That is, the term \(\mathcal{O}(\eta^2 f'^2(x))\) in (8.3.2) might become significant. In this case, we cannot guarantee that the iteration of \(x\) will be able to lower the value of \(f(x)\). For example, when we set the learning rate to \(\eta=1.1\), \(x\) overshoots the optimal solution \(x=0\) and gradually diverges.

showTrace(gd(1.1f));
epoch 10, x: 61.917389
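For this particular objective the behavior of all three learning rates can be checked by hand. A short derivation, assuming \(f(x) = x^2\) and the initial value \(x_0 = 10\) used above (the iteration index \(k\) is notation introduced here):

\[x_{k+1} = x_k - \eta \cdot 2 x_k = (1 - 2\eta)\, x_k \quad\Longrightarrow\quad x_k = (1 - 2\eta)^k x_0.\]

The iterates shrink exactly when \(|1 - 2\eta| < 1\), that is, for \(0 < \eta < 1\). With \(k = 10\), \(\eta = 0.2\) gives \(10 \cdot 0.6^{10} \approx 0.060466\), \(\eta = 0.05\) gives \(10 \cdot 0.9^{10} \approx 3.486784\), and \(\eta = 1.1\) gives \(10 \cdot (-1.2)^{10} \approx 61.917364\). These agree with the printed values up to single-precision rounding.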

8.3.1.2. Local Minima

To illustrate what happens for nonconvex functions consider the case of \(f(x) = x \cdot \cos(c x)\) for some constant \(c\). This function has infinitely many local minima. Depending on our choice of learning rate and depending on how well conditioned the problem is, we may end up with one of many solutions. The example below illustrates how an (unrealistically) high learning rate will lead to a poor local minimum.

float c = (float)(0.15f * Math.PI);

Function<Float, Float> f = x -> x * (float)Math.cos(c * x);

Function<Float, Float> gradf = x -> (float)(Math.cos(c * x) - c * x * Math.sin(c * x));

showTrace(gd(2));
epoch 10, x: -1.528166

8.3.2. Multivariate Gradient Descent

Now that we have a better intuition of the univariate case, let us consider the situation where \(\mathbf{x} \in \mathbb{R}^d\). That is, the objective function \(f: \mathbb{R}^d \to \mathbb{R}\) maps vectors into scalars. Correspondingly its gradient is multivariate, too. It is a vector consisting of \(d\) partial derivatives:

(8.3.5)\[\nabla f(\mathbf{x}) = \bigg[\frac{\partial f(\mathbf{x})}{\partial x_1}, \frac{\partial f(\mathbf{x})}{\partial x_2}, \ldots, \frac{\partial f(\mathbf{x})}{\partial x_d}\bigg]^\top.\]

Each partial derivative element \(\partial f(\mathbf{x})/\partial x_i\) in the gradient indicates the rate of change of \(f\) at \(\mathbf{x}\) with respect to the input \(x_i\). As before in the univariate case we can use the corresponding Taylor approximation for multivariate functions to get some idea of what we should do. In particular, we have that

(8.3.6)\[f(\mathbf{x} + \mathbf{\epsilon}) = f(\mathbf{x}) + \mathbf{\epsilon}^\top \nabla f(\mathbf{x}) + \mathcal{O}(\|\mathbf{\epsilon}\|^2).\]

In other words, up to second order terms in \(\mathbf{\epsilon}\) the direction of steepest descent is given by the negative gradient \(-\nabla f(\mathbf{x})\). Choosing a suitable learning rate \(\eta > 0\) yields the prototypical gradient descent algorithm:

\(\mathbf{x} \leftarrow \mathbf{x} - \eta \nabla f(\mathbf{x}).\)
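For vectors the update is applied to every coordinate at once. A minimal sketch, assuming plain `double[]` arrays instead of DJL `NDArray`s; the class `VectorGdSketch` and its methods `step` and `descend` are hypothetical names, and the gradient \([2x_1, 4x_2]^\top\) in `main` is just an illustrative choice.

```java
import java.util.function.UnaryOperator;

final class VectorGdSketch {
    /** One update x <- x - eta * grad, applied coordinate by coordinate. */
    static double[] step(double[] x, double[] grad, double eta) {
        double[] next = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            next[i] = x[i] - eta * grad[i];
        }
        return next;
    }

    /** Applies a fixed number of gradient descent steps starting from x0. */
    static double[] descend(UnaryOperator<double[]> gradf, double[] x0, double eta, int steps) {
        double[] x = x0.clone(); // leave the caller's array untouched
        for (int k = 0; k < steps; k++) {
            x = step(x, gradf.apply(x), eta);
        }
        return x;
    }

    public static void main(String[] args) {
        // Illustrative gradient of a simple quadratic: [2 * x1, 4 * x2]
        UnaryOperator<double[]> grad = x -> new double[]{2 * x[0], 4 * x[1]};
        double[] x = descend(grad, new double[]{-5, -2}, 0.1, 20);
        System.out.printf("x1 %f, x2 %f%n", x[0], x[1]);
    }
}
```

Keeping `step` separate from `descend` makes it easy to swap in a different update rule later while reusing the same loop.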

To see how the algorithm behaves in practice let us construct an objective function \(f(\mathbf{x})=x_1^2+2x_2^2\) with a two-dimensional vector \(\mathbf{x} = [x_1, x_2]^\top\) as input and a scalar as output. The gradient is given by \(\nabla f(\mathbf{x}) = [2x_1, 4x_2]^\top\). We will observe the trajectory of \(\mathbf{x}\) by gradient descent from the initial position \([-5, -2]\). We need two more helper functions. The first uses an update function and applies it \(20\) times to the initial value. The second helper visualizes the trajectory of \(\mathbf{x}\).

We also create a Weights class to make it easier to store the weight parameters and return them in functions.

/* Saved in GradDescUtils.java */
public class Weights {
    public float x1, x2;
    public Weights(float x1, float x2) {
        this.x1 = x1;
        this.x2 = x2;
    }
}

/* Saved in GradDescUtils.java */
/* Optimize a 2D objective function with a customized trainer. */
public ArrayList<Weights> train2d(Function<Float[], Float[]> trainer, int steps) {
    // s1 and s2 are internal state variables and will
    // be used later in the chapter
    float x1 = -5f, x2 = -2f, s1 = 0f, s2 = 0f;
    ArrayList<Weights> results = new ArrayList<>();
    results.add(new Weights(x1, x2));
    for (int i = 1; i < steps + 1; i++) {
        Float[] step = trainer.apply(new Float[]{x1, x2, s1, s2});
        x1 = step[0];
        x2 = step[1];
        s1 = step[2];
        s2 = step[3];
        results.add(new Weights(x1, x2));
        System.out.printf("epoch %d, x1 %f, x2 %f\n", i, x1, x2);
    }
    return results;
}

import java.util.function.BiFunction;

/* Saved in GradDescUtils.java */
/* Show the trace of 2D variables during optimization. */
public void showTrace2d(BiFunction<Float, Float, Float> f, ArrayList<Weights> results) {
    // TODO: add when tablesaw adds support for contour and meshgrids
}

Next, we observe the trajectory of the optimization variable \(\mathbf{x}\) for learning rate \(\eta = 0.1\). We can see that after 20 steps the value of \(\mathbf{x}\) approaches its minimum at \([0, 0]\). Progress is fairly well-behaved albeit rather slow.

float eta = 0.1f;

BiFunction<Float, Float, Float> f = (x1, x2) -> x1 * x1 + 2 * x2 * x2; // Objective

BiFunction<Float, Float, Float[]> gradf = (x1, x2) -> new Float[]{2 * x1, 4 * x2}; // Gradient

Function<Float[], Float[]> gd = (state) -> {
    Float x1 = state[0];
    Float x2 = state[1];

    Float[] g = gradf.apply(x1, x2); // Compute Gradient
    Float g1 = g[0];
    Float g2 = g[1];

    return new Float[]{x1 - eta * g1, x2 - eta * g2, 0f, 0f}; // Update Variables
};

showTrace2d(f, train2d(gd, 20));
epoch 1, x1 -4.000000, x2 -1.200000
epoch 2, x1 -3.200000, x2 -0.720000
epoch 3, x1 -2.560000, x2 -0.432000
epoch 4, x1 -2.048000, x2 -0.259200
epoch 5, x1 -1.638400, x2 -0.155520
epoch 6, x1 -1.310720, x2 -0.093312
epoch 7, x1 -1.048576, x2 -0.055987
epoch 8, x1 -0.838861, x2 -0.033592
epoch 9, x1 -0.671089, x2 -0.020155
epoch 10, x1 -0.536871, x2 -0.012093
epoch 11, x1 -0.429497, x2 -0.007256
epoch 12, x1 -0.343597, x2 -0.004354
epoch 13, x1 -0.274878, x2 -0.002612
epoch 14, x1 -0.219902, x2 -0.001567
epoch 15, x1 -0.175922, x2 -0.000940
epoch 16, x1 -0.140737, x2 -0.000564
epoch 17, x1 -0.112590, x2 -0.000339
epoch 18, x1 -0.090072, x2 -0.000203
epoch 19, x1 -0.072058, x2 -0.000122
epoch 20, x1 -0.057646, x2 -0.000073
https://d2l-java-resources.s3.amazonaws.com/img/contour_gd.svg

Fig. 8.3.1 Contour Gradient Descent.
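The printed values can be checked by hand because, for this objective, the update decouples into two independent scalar recursions. A short derivation under the settings used above (\(\eta = 0.1\), start at \([-5, -2]\)):

\[x_1 \leftarrow x_1 - \eta \cdot 2 x_1 = (1 - 2\eta)\, x_1 = 0.8\, x_1, \qquad x_2 \leftarrow x_2 - \eta \cdot 4 x_2 = (1 - 4\eta)\, x_2 = 0.6\, x_2.\]

After 20 steps this gives \(x_1 = -5 \cdot 0.8^{20} \approx -0.057646\) and \(x_2 = -2 \cdot 0.6^{20} \approx -0.000073\), matching the last line of the output. The coordinate with the larger curvature shrinks faster, and it also limits the usable learning rate: \(|1 - 4\eta| < 1\) requires \(\eta < 0.5\), whereas \(x_1\) alone would tolerate any \(\eta < 1\).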

8.3.3. Adaptive Methods

As we saw in Section 8.3.1.1, getting the learning rate \(\eta\) “just right” is tricky. If we pick it too small, we make little progress. If we pick it too large, the solution oscillates and in the worst case it might even diverge. What if we could determine \(\eta\) automatically or get rid of having to select a step size at all? Second order methods that look not only at the value and gradient of the objective but also at its curvature can help in this case. While these methods cannot be applied to deep learning directly due to the computational cost, they provide useful intuition into how to design advanced optimization algorithms that mimic many of the desirable properties of the algorithms outlined below.

8.3.3.1. Newton’s Method

Reviewing the Taylor expansion of \(f\) there is no need to stop after the first term. In fact, we can write it as

(8.3.7)\[f(\mathbf{x} + \mathbf{\epsilon}) = f(\mathbf{x}) + \mathbf{\epsilon}^\top \nabla f(\mathbf{x}) + \frac{1}{2} \mathbf{\epsilon}^\top \nabla \nabla^\top f(\mathbf{x}) \mathbf{\epsilon} + \mathcal{O}(\|\mathbf{\epsilon}\|^3).\]

To avoid cumbersome notation we define \(H_f := \nabla \nabla^\top f(\mathbf{x})\) to be the Hessian of \(f\). This is a \(d \times d\) matrix. For small \(d\) and simple problems \(H_f\) is easy to compute. For deep networks, on the other hand, \(H_f\) may be prohibitively large, due to the cost of storing \(\mathcal{O}(d^2)\) entries. Furthermore it may be too expensive to compute via backprop as we would need to apply backprop to the backpropagation call graph. For now let us ignore such considerations and look at what algorithm we’d get.

After all, the minimum of \(f\) satisfies \(\nabla f(\mathbf{x}) = 0\). Taking derivatives of (8.3.7) with regard to \(\mathbf{\epsilon}\) and ignoring higher order terms we arrive at

(8.3.8)\[\nabla f(\mathbf{x}) + H_f \mathbf{\epsilon} = 0 \text{ and hence } \mathbf{\epsilon} = -H_f^{-1} \nabla f(\mathbf{x}).\]

That is, we need to invert the Hessian \(H_f\) as part of the optimization problem.

For \(f(x) = \frac{1}{2} x^2\) we have \(\nabla f(x) = x\) and \(H_f = 1\). Hence for any \(x\) we obtain \(\epsilon = -x\). In other words, a single step is sufficient to converge perfectly without the need for any adjustment! Alas, we got a bit lucky here since the Taylor expansion was exact. Let us see what happens in other problems.
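The same one-step behavior appears in more than one dimension whenever the objective is quadratic. A minimal sketch, assuming the quadratic \(f(\mathbf{x}) = x_1^2 + 2x_2^2\) from Section 8.3.2 with Hessian \(\mathrm{diag}(2, 4)\); the class `Newton2dSketch` and the method `newtonStep` are hypothetical names, and the \(2 \times 2\) system is solved by Cramer's rule instead of forming an explicit inverse.

```java
final class Newton2dSketch {
    /** Solves the 2x2 system H * eps = -g by Cramer's rule and returns eps. */
    static double[] newtonStep(double[][] h, double[] g) {
        double det = h[0][0] * h[1][1] - h[0][1] * h[1][0];
        if (det == 0.0) {
            throw new ArithmeticException("Hessian is singular");
        }
        double e1 = -(g[0] * h[1][1] - h[0][1] * g[1]) / det;
        double e2 = -(h[0][0] * g[1] - h[1][0] * g[0]) / det;
        return new double[]{e1, e2};
    }

    public static void main(String[] args) {
        double[] x = {-5, -2};                 // starting point from Section 8.3.2
        double[] g = {2 * x[0], 4 * x[1]};     // gradient [2 * x1, 4 * x2]
        double[][] h = {{2, 0}, {0, 4}};       // constant Hessian of the quadratic
        double[] eps = newtonStep(h, g);
        System.out.printf("x1 %f, x2 %f%n", x[0] + eps[0], x[1] + eps[1]);
    }
}
```

Because the Hessian rescales each coordinate by its own curvature, a single step lands on the minimum at the origin, whereas plain gradient descent with \(\eta = 0.1\) was still approaching it after 20 steps.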

float c = 0.5f;

Function<Float, Float> f = x -> (float)Math.cosh(c * x); // Objective

Function<Float, Float> gradf = x -> c * (float)Math.sinh(c * x); // Derivative

Function<Float, Float> hessf = x -> c * c * (float)Math.cosh(c * x); // Hessian

// The learning rate eta is passed in; callers use eta = 1 for now
public float[] newton(float eta) {
    float x = 10f;
    float[] results = new float[11];
    results[0] = x;

    for (int i = 0; i < 10; i++) {
        x -= eta * gradf.apply(x) / hessf.apply(x);
        results[i + 1] = x;
    }
    System.out.printf("epoch 10, x: %f\n", x);
    return results;
}

showTrace(newton(1));
epoch 10, x: 0.000000

Now let us see what happens when we have a nonconvex function, such as \(f(x) = x \cos(c x)\) for some constant \(c\). Note that in Newton’s method we end up dividing by the Hessian. This means that if the second derivative is negative we may walk in the direction of increasing \(f\). That is a fatal flaw of the algorithm. Let us see what happens in practice.

c = 0.15f * (float)Math.PI;

Function<Float, Float> f = x -> x * (float)Math.cos(c * x);

Function<Float, Float> gradf = x -> (float)(Math.cos(c * x) - c * x * Math.sin(c * x));

Function<Float, Float> hessf = x -> (float)(-2 * c * Math.sin(c * x) -
                                    x * c * c * Math.cos(c * x));

showTrace(newton(1));
epoch 10, x: 26.834131
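The value printed above is not a minimum. A minimal sketch for checking this, re-implementing the same objective and derivatives in double precision (the class name `NewtonNonconvexSketch` and its static methods are assumptions for illustration): it runs the same ten Newton steps and then inspects the gradient and the second derivative at the final iterate.

```java
final class NewtonNonconvexSketch {
    static final double C = 0.15 * Math.PI;

    /** Objective f(x) = x * cos(c * x). */
    static double f(double x) {
        return x * Math.cos(C * x);
    }

    /** First derivative of f. */
    static double grad(double x) {
        return Math.cos(C * x) - C * x * Math.sin(C * x);
    }

    /** Second derivative of f. */
    static double hess(double x) {
        return -2 * C * Math.sin(C * x) - x * C * C * Math.cos(C * x);
    }

    /** Plain Newton iteration with step size eta, without any safeguard. */
    static double newton(double x0, double eta, int steps) {
        double x = x0;
        for (int i = 0; i < steps; i++) {
            x -= eta * grad(x) / hess(x);
        }
        return x;
    }

    public static void main(String[] args) {
        double x = newton(10.0, 1.0, 10);
        System.out.printf("x: %f, f(x): %f, f'(x): %e, f''(x): %f%n",
                x, f(x), grad(x), hess(x));
    }
}
```

The gradient at the final iterate is close to zero while the second derivative is negative, so the iteration has settled on a local maximum, where \(f\) is larger than at the starting point \(x = 10\).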