11.4. Stochastic Gradient Descent
In this section, we are going to introduce the basic principles of stochastic gradient descent.
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/GradDescUtils.java
11.4.1. Stochastic Gradient Updates
In deep learning, the objective function is usually the average of the loss functions over the examples in the training dataset. Given a training dataset with \(n\) examples, let \(f_i(\mathbf{x})\) denote the loss function with respect to the training example of index \(i\), where \(\mathbf{x}\) is the parameter vector. Then we arrive at the objective function

\[f(\mathbf{x}) = \frac{1}{n} \sum_{i = 1}^n f_i(\mathbf{x}).\]
The gradient of the objective function at \(\mathbf{x}\) is computed as

\[\nabla f(\mathbf{x}) = \frac{1}{n} \sum_{i = 1}^n \nabla f_i(\mathbf{x}).\]
If gradient descent is used, the computational cost of each iteration is \(\mathcal{O}(n)\), which grows linearly with \(n\). Therefore, when the training dataset is large, the cost of gradient descent for each iteration will be very high.
Stochastic gradient descent (SGD) reduces the computational cost at each iteration. At each iteration of stochastic gradient descent, we uniformly sample an index \(i\in\{1,\ldots, n\}\) of a data example at random, and compute the gradient \(\nabla f_i(\mathbf{x})\) to update \(\mathbf{x}\):

\[\mathbf{x} \leftarrow \mathbf{x} - \eta \nabla f_i(\mathbf{x}).\]
Here, \(\eta\) is the learning rate. We can see that the computational cost for each iteration drops from the \(\mathcal{O}(n)\) of gradient descent to the constant \(\mathcal{O}(1)\). We should mention that the stochastic gradient \(\nabla f_i(\mathbf{x})\) is an unbiased estimate of the full gradient \(\nabla f(\mathbf{x})\):

\[E_i[\nabla f_i(\mathbf{x})] = \frac{1}{n} \sum_{i = 1}^n \nabla f_i(\mathbf{x}) = \nabla f(\mathbf{x}).\]
This means that, on average, the stochastic gradient is a good estimate of the gradient.
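To make this claim concrete, the following sketch (not part of the original notebook; the one-dimensional quadratic losses \(f_i(x) = (x - a_i)^2\), the dataset size, and the sample count are illustrative choices) averages many uniformly sampled per-example gradients and compares the result with the full gradient:

import java.util.Random;

// Toy dataset: f_i(x) = (x - a[i])^2, so grad f_i(x) = 2 * (x - a[i]).
int n = 1000;
double[] a = new double[n];
Random rng = new Random(0);
for (int i = 0; i < n; i++) {
    a[i] = rng.nextGaussian();
}
double x = 2.0; // current parameter value

// Full gradient: (1/n) * sum_i 2 * (x - a[i])
double fullGrad = 0;
for (int i = 0; i < n; i++) {
    fullGrad += 2 * (x - a[i]);
}
fullGrad /= n;

// Average of stochastic gradients over many uniformly sampled indices
int samples = 100000;
double avgStochGrad = 0;
for (int s = 0; s < samples; s++) {
    int i = rng.nextInt(n);
    avgStochGrad += 2 * (x - a[i]);
}
avgStochGrad /= samples;

System.out.printf("full gradient: %.4f, averaged stochastic gradient: %.4f%n", fullGrad, avgStochGrad);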
Now, we will compare SGD with gradient descent by adding random noise with a mean of 0 to the gradient, simulating stochastic gradients.
NDManager manager = NDManager.newBaseManager();
// Sample once from a normal distribution
public float getRandomNormal(float mean, float sd) {
return manager.randomNormal(mean, sd, new Shape(1), DataType.FLOAT32).getFloat();
}
float eta = 0.01f;
Supplier<Float> lr = () -> 1f; // Constant Learning Rate
BiFunction<Float, Float, Float> f = (x1, x2) -> x1 * x1 + 2 * x2 * x2; // Objective
BiFunction<Float, Float, Float[]> gradf = (x1, x2) -> new Float[]{2 * x1, 4 * x2}; // Gradient
Function<Float[], Float[]> sgd = (state) -> {
Float x1 = state[0];
Float x2 = state[1];
Float s1 = state[2];
Float s2 = state[3];
Float[] g = gradf.apply(x1, x2);
Float g1 = g[0];
Float g2 = g[1];
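// Simulate a stochastic gradient by perturbing the exact gradient with zero-mean Gaussian noise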
g1 += getRandomNormal(0f, 0.1f);
g2 += getRandomNormal(0f, 0.1f);
Float etaT = eta * lr.get();
return new Float[]{x1 - etaT * g1, x2 - etaT * g2, 0f, 0f};
};
GradDescUtils.showTrace2d(f, GradDescUtils.train2d(sgd, 50));
Tablesaw not supporting for contour and meshgrids, will update soon
As we can see, the trajectory of the variables in SGD is much noisier than the one we observed in gradient descent in the previous section. This is due to the stochastic nature of the gradient. That is, even when we arrive near the minimum, we are still subject to the uncertainty injected by the instantaneous gradient via \(\eta \nabla f_i(\mathbf{x})\). Even after 50 steps the quality is still not so good. Even worse, it will not improve after additional steps (we encourage the reader to experiment with a larger number of steps to confirm this). This leaves us with the only alternative: change the learning rate \(\eta\). However, if we pick it too small, we will not make any meaningful progress initially. On the other hand, if we pick it too large, we will not get a good solution, as seen above. The only way to resolve these conflicting goals is to reduce the learning rate dynamically as optimization progresses.
This is also the reason for adding a learning rate function lr() into the sgd() step function. In the example above any functionality for learning rate scheduling lies dormant, as we set the associated lr function to be constant, i.e., Supplier<Float> lr = () -> 1f;.
11.4.2. Dynamic Learning Rate
Replacing \(\eta\) with a time-dependent learning rate \(\eta(t)\) adds to the complexity of controlling convergence of an optimization algorithm. In particular, we need to figure out how rapidly \(\eta\) should decay. If it is too quick, we will stop optimizing prematurely. If we decrease it too slowly, we waste too much time on optimization. There are a few basic strategies that are used in adjusting \(\eta\) over time (we will discuss more advanced strategies in a later chapter):

\[\begin{aligned}
    \eta(t) & = \eta_i \text{ if } t_i \leq t \leq t_{i+1} && \text{piecewise constant} \\
    \eta(t) & = \eta_0 \cdot e^{-\lambda t} && \text{exponential} \\
    \eta(t) & = \eta_0 \cdot (\beta t + 1)^{-\alpha} && \text{polynomial}
\end{aligned}\]
In the first, piecewise constant, scenario we decrease the learning rate, e.g., whenever progress in optimization has stalled. This is a common strategy for training deep networks (a minimal sketch follows below). Alternatively we could decrease it much more aggressively by an exponential decay. Unfortunately this leads to premature stopping before the algorithm has converged. A popular choice is polynomial decay with \(\alpha = 0.5\). In the case of convex optimization there are a number of proofs which show that this rate is well behaved. Let us see what this looks like in practice.
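The piecewise constant strategy has no code cell of its own in this notebook, so here is a minimal sketch in the same Supplier<Float> style; the drop interval and decay factor below are illustrative choices, not values used elsewhere in this chapter. The next two cells implement the exponential and polynomial schedules from the list above.

int pcCtr = 0;
int dropEvery = 200;  // halve the rate every 200 iterations (illustrative)
float gamma = 0.5f;   // decay factor applied at each drop (illustrative)

Supplier<Float> piecewise = () -> {
    pcCtr += 1;
    return (float) Math.pow(gamma, pcCtr / dropEvery); // integer division counts completed drops
};

// lr = piecewise; // would plug into sgd exactly like the schedules below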
int ctr = 1;
Supplier<Float> exponential = () -> {
ctr += 1;
return (float)Math.exp(-0.1 * ctr);
};
lr = exponential; // Set up learning rate
GradDescUtils.showTrace2d(f, GradDescUtils.train2d(sgd, 1000));
Tablesaw not supporting for contour and meshgrids, will update soon
As expected, the variance in the parameters is significantly reduced. However, this comes at the expense of failing to converge to the optimal solution \(\mathbf{x} = (0, 0)\). Even after 1000 steps we are still very far away from the optimal solution. Indeed, the algorithm fails to converge at all. On the other hand, if we use a polynomial decay, where the learning rate decays with the inverse square root of the number of steps, convergence is good.
int ctr = 1;
Supplier<Float> polynomial = () -> {
ctr += 1;
return (float)Math.pow(1 + 0.1 * ctr, -0.5);
};
lr = polynomial; // Set up learning rate
GradDescUtils.showTrace2d(f, GradDescUtils.train2d(sgd, 1000));
Tablesaw not supporting for contour and meshgrids, will update soon
There exist many more choices for how to set the learning rate. For instance, we could start with a small rate, then rapidly ramp up and then decrease it again, albeit more slowly. We could even alternate between smaller and larger learning rates. There exists a large variety of such schedules. For now let us focus on learning rate schedules for which a comprehensive theoretical analysis is possible, i.e., on learning rates in a convex setting. For general nonconvex problems it is very difficult to obtain meaningful convergence guarantees, since in general minimizing nonlinear nonconvex problems is NP hard. For a survey see e.g., the excellent lecture notes of Tibshirani 2015.
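As an aside, the "start small, ramp up, then decay" idea mentioned above can also be written in the same Supplier<Float> style. The sketch below is only an illustration (warmupSteps and the decay constants are arbitrary choices) and plays no role in the analysis that follows.

int warmCtr = 0;
int warmupSteps = 50; // length of the linear ramp (illustrative)

Supplier<Float> warmupThenDecay = () -> {
    warmCtr += 1;
    if (warmCtr <= warmupSteps) {
        return (float) warmCtr / warmupSteps; // linear ramp up to 1
    }
    return (float) Math.pow(1 + 0.1 * (warmCtr - warmupSteps), -0.5); // then polynomial decay
};

// lr = warmupThenDecay; // could be plugged into sgd like the other schedules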
11.4.3. Convergence Analysis for Convex Objectives
The following is optional and primarily serves to convey more intuition about the problem. We limit ourselves to one of the simplest proofs, as described by [Nesterov & Vial, 2000]. Significantly more advanced proof techniques exist, e.g., whenever the objective function is particularly well behaved. [Hazan et al., 2008] show that for strongly convex functions, i.e., for functions that can be bounded from below by \(\mathbf{x}^\top \mathbf{Q} \mathbf{x}\), it is possible to minimize them in a small number of steps while decreasing the learning rate like \(\eta(t) = \eta_0/(\beta t + 1)\). Unfortunately this case never really occurs in deep learning and we are left with a much more slowly decreasing rate in practice.
Consider the case where

\[\mathbf{w}_{t+1} = \mathbf{w}_{t} - \eta_t \partial_\mathbf{w} l(\mathbf{x}_t, \mathbf{w}_t).\]
In particular, assume that \(\mathbf{x}_t\) is drawn from some distribution \(P(\mathbf{x})\) and that \(l(\mathbf{x}, \mathbf{w})\) is a convex function in \(\mathbf{w}\) for all \(\mathbf{x}\). Lastly, denote by

\[R(\mathbf{w}) = E_{\mathbf{x} \sim P}[l(\mathbf{x}, \mathbf{w})]\]
the expected risk and by \(R^*\) its minimum with regard to \(\mathbf{w}\). Moreover, let \(\mathbf{w}^*\) be the minimizer (we assume that it exists within the domain over which \(\mathbf{w}\) is defined). In this case we can track the distance between the current parameter \(\mathbf{w}_t\) and the risk minimizer \(\mathbf{w}^*\) and see whether it improves over time:

\[\|\mathbf{w}_{t+1} - \mathbf{w}^*\|^2 = \|\mathbf{w}_{t} - \eta_t \partial_\mathbf{w} l(\mathbf{x}_t, \mathbf{w}_t) - \mathbf{w}^*\|^2 = \|\mathbf{w}_{t} - \mathbf{w}^*\|^2 + \eta_t^2 \|\partial_\mathbf{w} l(\mathbf{x}_t, \mathbf{w}_t)\|^2 - 2 \eta_t \left\langle \mathbf{w}_t - \mathbf{w}^*, \partial_\mathbf{w} l(\mathbf{x}_t, \mathbf{w}_t)\right\rangle.\]
The gradient \(\partial_\mathbf{w} l(\mathbf{x}_t, \mathbf{w})\) can be bounded from above by some Lipschitz constant \(L\), hence we have that

\[\eta_t^2 \|\partial_\mathbf{w} l(\mathbf{x}_t, \mathbf{w}_t)\|^2 \leq \eta_t^2 L^2.\]
We are mostly interested in how the distance between \(\mathbf{w}_t\) and \(\mathbf{w}^*\) changes in expectation. In fact, for any specific sequence of steps the distance might well increase, depending on whichever \(\mathbf{x}_t\) we encounter. Hence we need to bound the inner product. By convexity we have that

\[l(\mathbf{x}_t, \mathbf{w}^*) \geq l(\mathbf{x}_t, \mathbf{w}_t) + \left\langle \mathbf{w}^* - \mathbf{w}_t, \partial_{\mathbf{w}} l(\mathbf{x}_t, \mathbf{w}_t) \right\rangle.\]
Using both inequalities and plugging them into the above we obtain a bound on the distance between parameters at time \(t+1\) as follows:

\[\|\mathbf{w}_{t} - \mathbf{w}^*\|^2 - \|\mathbf{w}_{t+1} - \mathbf{w}^*\|^2 \geq 2 \eta_t (l(\mathbf{x}_t, \mathbf{w}_t) - l(\mathbf{x}_t, \mathbf{w}^*)) - \eta_t^2 L^2.\]
This means that we make progress as long as the expected difference between current loss and the optimal loss outweighs \(\eta_t L^2\). Since the former is bound to converge to \(0\) it follows that the learning rate \(\eta_t\) also needs to vanish.
Next we take expectations over this expression. This yields

\[E\left[\|\mathbf{w}_{t} - \mathbf{w}^*\|^2\right] - E\left[\|\mathbf{w}_{t+1} - \mathbf{w}^*\|^2\right] \geq 2 \eta_t \left[E[R(\mathbf{w}_t)] - R^*\right] - \eta_t^2 L^2.\]
The last step involves summing the inequalities for \(t \in \{1, \ldots, T\}\). Since the sum telescopes and by dropping the lower term we obtain

\[\|\mathbf{w}_{0} - \mathbf{w}^*\|^2 \geq 2 \sum_{t=1}^T \eta_t \left[E[R(\mathbf{w}_t)] - R^*\right] - L^2 \sum_{t=1}^T \eta_t^2.\]
Note that we exploited that \(\mathbf{w}_0\) is given and thus the expectation can be dropped. Lastly, define

\[\bar{\mathbf{w}} := \frac{\sum_{t=1}^T \eta_t \mathbf{w}_t}{\sum_{t=1}^T \eta_t}.\]
Then by convexity it follows that

\[\sum_{t=1}^T \eta_t E[R(\mathbf{w}_t)] \geq \left(\sum_{t=1}^T \eta_t\right) \cdot E\left[R(\bar{\mathbf{w}})\right].\]
Plugging this into the above inequality yields the bound

\[E[R(\bar{\mathbf{w}})] - R^* \leq \frac{r^2 + L^2 \sum_{t=1}^T \eta_t^2}{2 \sum_{t=1}^T \eta_t}.\]
Here \(r^2 := \|\mathbf{w}_0 - \mathbf{w}^*\|^2\) is the squared distance between the initial choice of parameters and the risk minimizer. In short, the speed of convergence depends on how rapidly the loss function can change, via the Lipschitz constant \(L\), and how far the initial value is from optimality, via \(r\). Note that the bound is in terms of \(\bar{\mathbf{w}}\) rather than \(\mathbf{w}_T\). This is the case since \(\bar{\mathbf{w}}\) is a smoothed version of the optimization path. Now let us analyze some choices for \(\eta_t\).
Known Time Horizon. Whenever \(r, L\) and \(T\) are known we can pick \(\eta = r/(L \sqrt{T})\). This yields the upper bound \(r L (1 + 1/T)/(2\sqrt{T}) < rL/\sqrt{T}\). That is, we converge with rate \(\mathcal{O}(1/\sqrt{T})\) to the optimal solution.
Unknown Time Horizon. Whenever we want to have a good solution for any time \(T\) we can pick \(\eta_t = \mathcal{O}(1/\sqrt{t})\). This costs us an extra logarithmic factor and it leads to an upper bound of the form \(\mathcal{O}(\log T / \sqrt{T})\).
Note that for strongly convex losses \(l(\mathbf{x}, \mathbf{w}') \geq l(\mathbf{x}, \mathbf{w}) + \langle \mathbf{w}'-\mathbf{w}, \partial_\mathbf{w} l(\mathbf{x}, \mathbf{w}) \rangle + \frac{\lambda}{2} \|\mathbf{w}-\mathbf{w}'\|^2\) we can design even more rapidly converging optimization schedules. In fact, an exponential decay in \(\eta\) leads to a bound of the form \(\mathcal{O}(\log T / T)\).
11.4.4. Stochastic Gradients and Finite Samples
So far we have played a bit fast and loose when it comes to talking about stochastic gradient descent. We posited that we draw instances \(x_i\), typically with labels \(y_i\) from some distribution \(p(x, y)\) and that we use this to update the weights \(w\) in some manner. In particular, for a finite sample size we simply argued that the discrete distribution \(p(x, y) = \frac{1}{n} \sum_{i=1}^n \delta_{x_i}(x) \delta_{y_i}(y)\) allows us to perform SGD over it.
However, this is not really what we did. In the toy examples in the current section we simply added noise to an otherwise non-stochastic gradient, i.e., we pretended to have pairs \((x_i, y_i)\). It turns out that this is justified here (see the exercises for a detailed discussion). More troubling is that in all previous discussions we clearly did not do this. Instead we iterated over all instances exactly once. To see why this is preferable consider the converse, namely that we are sampling \(N\) observations from the discrete distribution with replacement. The probability of choosing a particular element \(i\) in one draw is \(N^{-1}\). Thus the probability of choosing it at least once is

\[P(\text{choose } i) = 1 - P(\text{omit } i) = 1 - (1-N^{-1})^{N} \approx 1-e^{-1} \approx 0.63.\]
A similar reasoning shows that the probability of picking a given sample exactly once is given by \({N \choose 1} N^{-1} (1-N^{-1})^{N-1} = \frac{N}{N-1} (1-N^{-1})^{N} \approx e^{-1} \approx 0.37\). Sampling with replacement thus leads to increased variance and decreased data efficiency relative to sampling without replacement. Hence, in practice we perform the latter (and this is the default choice throughout this book). Lastly, note that repeated passes through the dataset traverse it in a different random order.
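As a quick numerical sanity check of these approximations, the small standalone sketch below evaluates both probabilities; the dataset size N is an arbitrary illustrative choice.

int N = 10000; // dataset size (arbitrary illustrative value)
double pAtLeastOnce = 1 - Math.pow(1 - 1.0 / N, N);                 // ~ 1 - 1/e ~ 0.63
double pExactlyOnce = N * (1.0 / N) * Math.pow(1 - 1.0 / N, N - 1); // ~ 1/e ~ 0.37

System.out.printf("P(at least once) = %.4f, P(exactly once) = %.4f%n", pAtLeastOnce, pExactlyOnce);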
11.4.5. Summary
For convex problems we can prove that for a wide choice of learning rates Stochastic Gradient Descent will converge to the optimal solution.
For deep learning, where the objective functions are generally nonconvex, this is not the case. However, the analysis of convex problems gives us useful insight into how to approach optimization, namely to reduce the learning rate progressively, albeit not too quickly.
Problems occur when the learning rate is too small or too large. In practice a suitable learning rate is often found only after multiple experiments.
The larger the training dataset, the more it costs to compute each iteration of gradient descent, so SGD is preferred in these cases.
Optimality guarantees for SGD are in general not available in nonconvex cases since the number of local minima that require checking might well be exponential.
11.4.6. Exercises
Experiment with different learning rate schedules for SGD and with different numbers of iterations. In particular, plot the distance from the optimal solution \((0, 0)\) as a function of the number of iterations.
Prove that for the function \(f(x_1, x_2) = x_1^2 + 2 x_2^2\) adding normal noise to the gradient is equivalent to minimizing a loss function \(l(\mathbf{x}, \mathbf{w}) = (x_1 - w_1)^2 + 2 (x_2 - w_2)^2\) where \(x\) is drawn from a normal distribution.
Derive mean and variance of the distribution for \(\mathbf{x}\).
Show that this property holds in general for objective functions \(f(\mathbf{x}) = \frac{1}{2} (\mathbf{x} - \mathbf{\mu})^\top Q (\mathbf{x} - \mathbf{\mu})\) for \(Q \succeq 0\).
Compare convergence of SGD when you sample from \(\{(x_1, y_1), \ldots, (x_m, y_m)\}\) with replacement and when you sample without replacement.
How would you change the SGD solver if some gradient (or rather some coordinate associated with it) was consistently larger than all other gradients?
Assume that \(f(x) = x^2 (1 + \sin x)\). How many local minima does \(f\) have? Can you change \(f\) in such a way that to minimize it one needs to evaluate all local minima?