
8.8. RMSProp

One of the key issues in Section 8.7 is that the learning rate decreases at a predefined schedule of effectively \(\mathcal{O}(t^{-\frac{1}{2}})\). While this is generally appropriate for convex problems, it might not be ideal for nonconvex ones, such as those encountered in deep learning. Yet, the coordinate-wise adaptivity of Adagrad is highly desirable as a preconditioner.

[Tieleman & Hinton, 2012] proposed the RMSProp algorithm as a simple fix to decouple rate scheduling from coordinate-adaptive learning rates. The issue is that Adagrad accumulates the squares of the gradient \(\mathbf{g}_t\) into a state vector \(\mathbf{s}_t = \mathbf{s}_{t-1} + \mathbf{g}_t^2\). As a result, \(\mathbf{s}_t\) keeps growing without bound due to the lack of normalization, essentially linearly as the algorithm converges.

One way of fixing this problem would be to use \(\mathbf{s}_t / t\). For reasonable distributions of \(\mathbf{g}_t\) this will converge. Unfortunately, it might take a very long time until the limit behavior starts to matter, since the procedure remembers the full trajectory of values. An alternative is to use a leaky average, as we did in the momentum method, i.e., \(\mathbf{s}_t \leftarrow \gamma \mathbf{s}_{t-1} + (1-\gamma) \mathbf{g}_t^2\) for some parameter \(\gamma > 0\). Keeping all other parts unchanged yields RMSProp.
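To see the difference between the three accumulators described above, the following minimal sketch feeds them a made-up scalar gradient sequence whose magnitude drops from 1.0 to 0.1 halfway through. The class name `AccumulatorComparison`, the sequence, and the choice \(\gamma = 0.9\) are assumptions for illustration only and are not part of the book's utilities.

```java
// Sketch: three ways to accumulate squared gradients on a hypothetical
// scalar gradient sequence (1.0 for 100 steps, then 0.1 for 100 steps).
class AccumulatorComparison {

    /** Returns {Adagrad sum, running mean s_t / t, leaky average}. */
    static double[] accumulate(double[] grads, double gamma) {
        double sum = 0.0;   // Adagrad: s_t = s_{t-1} + g_t^2
        double leaky = 0.0; // leaky average: s_t = gamma * s_{t-1} + (1 - gamma) * g_t^2
        for (double g : grads) {
            sum += g * g;
            leaky = gamma * leaky + (1 - gamma) * g * g;
        }
        return new double[]{sum, sum / grads.length, leaky};
    }

    /** Builds the hypothetical sequence: 100 gradients of 1.0, then 100 of 0.1. */
    static double[] exampleGradients() {
        double[] grads = new double[200];
        for (int t = 0; t < grads.length; t++) {
            grads[t] = t < 100 ? 1.0 : 0.1;
        }
        return grads;
    }

    public static void main(String[] args) {
        double[] r = accumulate(exampleGradients(), 0.9);
        System.out.printf("Adagrad sum: %.4f%n", r[0]);
        System.out.printf("s_t / t:     %.4f%n", r[1]);
        System.out.printf("leaky avg:   %.4f%n", r[2]);
    }
}
```

On this sequence the recent squared gradient is 0.01. The running mean \(\mathbf{s}_t / t\) still sits near 0.505 because it remembers the early large gradients, whereas the leaky average has already settled close to 0.01.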

8.8.1. The Algorithm

Let us write out the equations in detail.

(8.8.1)\[\begin{split}\begin{aligned} \mathbf{s}_t & \leftarrow \gamma \mathbf{s}_{t-1} + (1 - \gamma) \mathbf{g}_t^2, \\ \mathbf{x}_t & \leftarrow \mathbf{x}_{t-1} - \frac{\eta}{\sqrt{\mathbf{s}_t + \epsilon}} \odot \mathbf{g}_t. \end{aligned}\end{split}\]
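As a rough illustration of (8.8.1), the sketch below applies one element-wise update to plain arrays. The class name `RmsPropStep` and the numbers in `main` (two gradient coordinates that differ by four orders of magnitude, \(\eta = 0.1\), \(\gamma = 0.9\), \(\epsilon = 10^{-6}\)) are assumptions chosen for illustration.

```java
// Sketch of a single RMSProp step following (8.8.1), on plain arrays.
class RmsPropStep {

    /** Updates x and s in place; all arrays must have the same length. */
    static void step(double[] x, double[] s, double[] g,
                     double eta, double gamma, double eps) {
        for (int i = 0; i < x.length; i++) {
            // s_t <- gamma * s_{t-1} + (1 - gamma) * g_t^2
            s[i] = gamma * s[i] + (1 - gamma) * g[i] * g[i];
            // x_t <- x_{t-1} - eta / sqrt(s_t + eps) * g_t
            x[i] -= eta / Math.sqrt(s[i] + eps) * g[i];
        }
    }

    public static void main(String[] args) {
        double[] x = {1.0, 1.0};
        double[] s = {0.0, 0.0};
        double[] g = {0.01, 100.0}; // hypothetical gradients of very different scale
        step(x, s, g, 0.1, 0.9, 1e-6);
        System.out.println(java.util.Arrays.toString(x));
    }
}
```

Although the two gradient entries differ by a factor of \(10^4\), both coordinates move by roughly 0.3 in this first step, which is the per-coordinate rescaling at work.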

The constant \(\epsilon > 0\) is typically set to \(10^{-6}\) to ensure that we do not suffer from division by zero or overly large step sizes. Given this expansion we are now free to control the learning rate \(\eta\) independently of the scaling that is applied on a per-coordinate basis. In terms of leaky averages we can apply the same reasoning as previously applied in the case of the momentum method. Expanding the definition of \(\mathbf{s}_t\) yields

(8.8.2)\[\begin{split}\begin{aligned} \mathbf{s}_t & = (1 - \gamma) \mathbf{g}_t^2 + \gamma \mathbf{s}_{t-1} \\ & = (1 - \gamma) \left(\mathbf{g}_t^2 + \gamma \mathbf{g}_{t-1}^2 + \gamma^2 \mathbf{g}_{t-2}^2 + \ldots \right). \end{aligned}\end{split}\]

As before in Section 8.6 we use \(1 + \gamma + \gamma^2 + \ldots = \frac{1}{1-\gamma}\). Hence the sum of weights is normalized to \(1\), and the weight of an observation halves every \(\ln 2 / \ln(1/\gamma)\) steps, which is approximately \(\ln 2 / (1-\gamma)\) for \(\gamma\) close to \(1\). Let us visualize the weights for the past 40 timesteps for various choices of \(\gamma\).

%mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/

%maven ai.djl:api:0.7.0-SNAPSHOT
%maven ai.djl:basicdataset:0.7.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26

%maven ai.djl.mxnet:mxnet-engine:0.7.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-a
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/GradDescUtils.java
%load ../utils/Accumulator.java
%load ../utils/StopWatch.java
%load ../utils/Training.java
%load ../utils/TrainingChapter11.java
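import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;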
NDManager manager = NDManager.newBaseManager();

float[] gammas = new float[]{0.95f, 0.9f, 0.8f, 0.7f};

NDArray timesND = manager.arange(40f);
float[] times = timesND.toFloatArray();
display(GradDescUtils.plotGammas(times, gammas, 600, 400));
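The weights \((1-\gamma)\gamma^k\) from (8.8.2) can also be checked numerically. The following is a small sketch in plain Java, independent of the plotting utilities. The class name `LeakyAverageWeights` and its methods are hypothetical helpers introduced here for illustration.

```java
// Sketch: numeric check of the leaky-average weights (1 - gamma) * gamma^k.
class LeakyAverageWeights {

    /** Weight placed on the squared gradient observed k steps ago. */
    static double weight(double gamma, int k) {
        return (1 - gamma) * Math.pow(gamma, k);
    }

    /** Total weight on the most recent n observations; equals 1 - gamma^n. */
    static double partialSum(double gamma, int n) {
        double total = 0.0;
        for (int k = 0; k < n; k++) {
            total += weight(gamma, k);
        }
        return total;
    }

    /** Number of steps after which a weight has halved: ln 2 / ln(1 / gamma). */
    static double halfLife(double gamma) {
        return Math.log(2) / -Math.log(gamma);
    }

    public static void main(String[] args) {
        for (double gamma : new double[]{0.95, 0.9, 0.8, 0.7}) {
            System.out.printf("gamma=%.2f  weight in last 40 steps=%.4f  half-life=%.2f steps%n",
                    gamma, partialSum(gamma, 40), halfLife(gamma));
        }
    }
}
```

For \(\gamma = 0.9\) the weights halve about every 6.6 steps and the 40 plotted timesteps carry roughly 98.5% of the total weight. For \(\gamma = 0.95\) about 13% of the weight still lies beyond the plotted window.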

8.8.2. Implementation from Scratch

As before we use the quadratic function \(f(\mathbf{x})=0.1x_1^2+2x_2^2\) to observe the trajectory of RMSProp. Recall that in Section 8.7, when we used Adagrad with a learning rate of 0.4, the variables moved only very slowly in the later stages of the algorithm since the learning rate decreased too quickly. Since \(\eta\) is controlled separately this does not happen with RMSProp.

float eta = 0.4f;
float gamma = 0.9f;

Function<Float[], Float[]> rmsProp2d = (state) -> {
    Float x1 = state[0], x2 = state[1], s1 = state[2], s2 = state[3];
    float g1 = 0.2f * x1;
    float g2 = 4 * x2;
    float eps = (float) 1e-6;
    s1 = gamma * s1 + (1 - gamma) * g1 * g1;
    s2 = gamma * s2 + (1 - gamma) * g2 * g2;
    x1 -= eta / (float) Math.sqrt(s1 + eps) * g1;
    x2 -= eta / (float) Math.sqrt(s2 + eps) * g2;
    return new Float[]{x1, x2, s1, s2};
};

BiFunction<Float, Float, Float> f2d = (x1, x2) -> {
    return 0.1f * x1 * x1 + 2 * x2 * x2;
};

GradDescUtils.showTrace2d(f2d, GradDescUtils.train2d(rmsProp2d, 20));
Tablesaw not supporting for contour and meshgrids, will update soon
https://d2l-java-resources.s3.amazonaws.com/img/chapter_optim-rmsprop-gd2d.svg

Fig. 8.8.1 RmsProp Gradient Descent 2D.
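To compare the two optimizers on \(f(\mathbf{x})=0.1x_1^2+2x_2^2\) without the plotting helpers, the sketch below runs Adagrad and RMSProp side by side for 20 steps with \(\eta = 0.4\). The class name `AdagradVsRmsProp` is hypothetical, and the starting point \((-5, -2)\) is an assumption, since the initial state used by `GradDescUtils.train2d` is not shown here.

```java
// Sketch: Adagrad versus RMSProp on f(x) = 0.1 * x1^2 + 2 * x2^2.
class AdagradVsRmsProp {
    static final double ETA = 0.4;
    static final double EPS = 1e-6;

    static double f(double x1, double x2) {
        return 0.1 * x1 * x1 + 2 * x2 * x2;
    }

    /** Runs the given number of steps and returns the final {x1, x2}. */
    static double[] run(boolean useRmsProp, double gamma, int steps) {
        double x1 = -5, x2 = -2; // assumed starting point
        double s1 = 0, s2 = 0;
        for (int t = 0; t < steps; t++) {
            double g1 = 0.2 * x1;
            double g2 = 4 * x2;
            if (useRmsProp) {
                // Leaky average of squared gradients.
                s1 = gamma * s1 + (1 - gamma) * g1 * g1;
                s2 = gamma * s2 + (1 - gamma) * g2 * g2;
            } else {
                // Adagrad: unbounded running sum of squared gradients.
                s1 += g1 * g1;
                s2 += g2 * g2;
            }
            x1 -= ETA / Math.sqrt(s1 + EPS) * g1;
            x2 -= ETA / Math.sqrt(s2 + EPS) * g2;
        }
        return new double[]{x1, x2};
    }

    public static void main(String[] args) {
        double[] ada = run(false, 0.0, 20);
        double[] rms = run(true, 0.9, 20);
        System.out.printf("Adagrad: x1=%.6f x2=%.6f f=%.6f%n", ada[0], ada[1], f(ada[0], ada[1]));
        System.out.printf("RMSProp: x1=%.6f x2=%.6f f=%.6f%n", rms[0], rms[1], f(rms[0], rms[1]));
    }
}
```

Under these assumptions RMSProp ends close to the minimum at the origin after 20 steps, while Adagrad's first coordinate is still far from it because its effective step size keeps shrinking.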

Next, we implement RMSProp to be used in a deep network. This is equally straightforward.

NDList initRmsPropStates(int featureDimension) {
    NDManager manager = NDManager.newBaseManager();
    NDArray sW = manager.zeros(new Shape(featureDimension, 1));
    NDArray sB = manager.zeros(new Shape(1));
    return new NDList(sW, sB);
}

public class Optimization {
    public static void rmsProp(NDList params, NDList states, Map<String, Float> hyperparams) {
        float gamma = hyperparams.get("gamma");
        float eps = (float) 1e-6;
        for (int i = 0; i < params.size(); i++) {
            NDArray param = params.get(i);
            NDArray state = states.get(i);
            // Update parameter and state
            // state = gamma * state + (1 - gamma) * param.gradient^2
            state.muli(gamma).addi(param.getGradient().square().mul(1 - gamma));
            // param -= lr * param.gradient / sqrt(s + eps)
            param.subi(param.getGradient().mul(hyperparams.get("lr")).div(state.add(eps).sqrt()));
        }
    }
}

We set the initial learning rate to 0.01 and the weighting term \(\gamma\) to 0.9. That is, \(\mathbf{s}\) aggregates on average over the past \(1/(1-\gamma) = 10\) observations of the squared gradient.

AirfoilRandomAccess airfoil = TrainingChapter11.getDataCh11(10, 1500);

public TrainingChapter11.LossTime trainRmsProp(float lr, float gamma, int numEpochs)
                    throws IOException, TranslateException {
    int featureDimension = airfoil.getFeatureArraySize();
    Map<String, Float> hyperparams = new HashMap<>();
    hyperparams.put("lr", lr);
    hyperparams.put("gamma", gamma);
    return TrainingChapter11.trainCh11(Optimization::rmsProp,
                                       initRmsPropStates(featureDimension),
                                       hyperparams, airfoil,
                                       featureDimension, numEpochs);
}

trainRmsProp(0.01f, 0.9f, 2);
loss: 0.254, 0.094 sec/epoch

8.8.3. Concise Implementation

Since RMSProp is a rather popular algorithm, it is also available through Optimizer. We create an instance of RmsProp and set its learning rate and the optional rho parameter, which plays the role of \(\gamma\).

Tracker lrt = Tracker.fixed(0.01f);
Optimizer rmsProp = Optimizer.rmsprop().optLearningRateTracker(lrt).optRho(0.9f).build();

TrainingChapter11.trainConciseCh11(rmsProp, airfoil, 2);
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Training on: 4 GPUs.
[IJava-executor-0] INFO ai.djl.training.listener.LoggingTrainingListener - Load MXNet Engine Version 1.7.0 in 0.084 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.67, L2Loss: 0.31
loss: 0.248, 3.270 sec/epoch

8.8.4. Summary

  • RMSProp is very similar to Adagrad insofar as both use the square of the gradient to scale coefficients.

  • RMSProp shares with momentum the leaky averaging. However, RMSProp uses the technique to adjust the coefficient-wise preconditioner.

  • The learning rate needs to be scheduled by the experimenter in practice.

  • The coefficient \(\gamma\) determines how long the history is when adjusting the per-coordinate scale.

8.8.5. Exercises

  1. What happens experimentally if we set \(\gamma = 1\)? Why?

  2. Rotate the optimization problem to minimize \(f(\mathbf{x}) = 0.1 (x_1 + x_2)^2 + 2 (x_1 - x_2)^2\). What happens to the convergence?

  3. Try out what happens to RMSProp on a real machine learning problem, such as training on Fashion-MNIST. Experiment with different choices for adjusting the learning rate.

  4. Would you want to adjust \(\gamma\) as optimization progresses? How sensitive is RMSProp to this?