8.5. Implementation of Recurrent Neural Networks from Scratch
In this section we will implement an RNN from scratch for a character-level language model, following the description in Section 8.4. The model will be trained on H. G. Wells’ The Time Machine. As before, we start by reading the dataset, which was introduced in Section 8.3.
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java
%load ../utils/StopWatch.java
%load ../utils/Accumulator.java
%load ../utils/Animator.java
%load ../utils/Training.java
%load ../utils/timemachine/Vocab.java
%load ../utils/timemachine/RNNModelScratch.java
%load ../utils/timemachine/TimeMachine.java
%load ../utils/timemachine/SeqDataLoader.java
@FunctionalInterface
public interface TriFunction<T, U, V, W> {
public W apply(T t, U u, V v);
}
@FunctionalInterface
public interface QuadFunction<T, U, V, W, R> {
public R apply(T t, U u, V v, W w);
}
@FunctionalInterface
public interface SimpleFunction<T> {
public T apply();
}
@FunctionalInterface
public interface voidFunction<T> {
public void apply(T t);
}
@FunctionalInterface
public interface voidTwoFunction<T, U> {
public void apply(T t, U u);
}
NDManager manager = NDManager.newBaseManager();
int batchSize = 32;
int numSteps = 35;
Pair<List<NDList>, Vocab> timeMachine = SeqDataLoader.loadDataTimeMachine(batchSize, numSteps, false, 10000, manager);
List<NDList> trainIter = timeMachine.getKey();
Vocab vocab = timeMachine.getValue();
8.5.1. One-Hot Encoding
Recall that each token is represented as a numerical index in trainIter. Feeding these indices directly to a neural network might make it hard to learn. We often represent each token as a more expressive feature vector. The easiest representation is called one-hot encoding, which is introduced in subsec_classification-problem.
In a nutshell, we map each index to a different unit vector: assume that the number of different tokens in the vocabulary is \(N\) (vocab.length()) and the token indices range from 0 to \(N-1\). If the index of a token is the integer \(i\), then we create a vector of all 0s with a length of \(N\) and set the element at position \(i\) to 1. This vector is the one-hot vector of the original token. The one-hot vectors with indices 0 and 2 are shown below.
manager.create(new int[] {0, 2}).oneHot(vocab.length())
ND: (2, 29) gpu(0) float32
[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., ... 9 more],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., ... 9 more],
]
The shape of the minibatch that we sample each time is (batch size, number of time steps). The oneHot function transforms such a minibatch into a three-dimensional NDArray whose last dimension equals the vocabulary size (vocab.length()). We often transpose the input first so that the output has shape (number of time steps, batch size, vocabulary size). This lets us loop conveniently over the outermost dimension and update the hidden states of a minibatch, time step by time step.
NDArray X = manager.arange(10).reshape(new Shape(2,5));
X.transpose().oneHot(28).getShape()
(5, 2, 28)
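To make the shape bookkeeping concrete, here is a minimal sketch of the same transpose-then-encode step using plain Java arrays instead of NDArray. The class name OneHotSketch and its helper methods are illustrative assumptions, not part of DJL or of the utilities loaded above.

```java
class OneHotSketch {
    /** Returns a vector of length vocabSize that is 1 at position index and 0 elsewhere. */
    static float[] oneHot(int index, int vocabSize) {
        float[] v = new float[vocabSize];
        v[index] = 1f;
        return v;
    }

    /** Maps a (batchSize, numSteps) minibatch of indices to (numSteps, batchSize, vocabSize). */
    static float[][][] encode(int[][] batch, int vocabSize) {
        int batchSize = batch.length;
        int numSteps = batch[0].length;
        float[][][] out = new float[numSteps][batchSize][];
        for (int t = 0; t < numSteps; t++) {
            for (int b = 0; b < batchSize; b++) {
                // Swapping the loop order here plays the role of the transpose.
                out[t][b] = oneHot(batch[b][t], vocabSize);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[][] batch = {{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}};
        float[][][] encoded = encode(batch, 28);
        // Prints "5, 2, 28": time steps first, then batch, then vocabulary.
        System.out.println(encoded.length + ", " + encoded[0].length + ", " + encoded[0][0].length);
    }
}
```

With time as the outermost dimension, encoded[t] is exactly the (batch size, vocabulary size) slice that a recurrent update consumes at step t.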
8.5.2. Initializing the Model Parameters
Next, we initialize the model parameters for the RNN model. The number of hidden units numHiddens is a tunable hyperparameter. When training language models, the inputs and outputs are from the same vocabulary. Hence, they have the same dimension, which is equal to the vocabulary size.
public static NDList getParams(int vocabSize, int numHiddens, Device device) {
int numOutputs = vocabSize;
int numInputs = vocabSize;
// Hidden layer parameters
NDArray W_xh = normal(new Shape(numInputs, numHiddens), device);
NDArray W_hh = normal(new Shape(numHiddens, numHiddens), device);
NDArray b_h = manager.zeros(new Shape(numHiddens), DataType.FLOAT32, device);
// Output layer parameters
NDArray W_hq = normal(new Shape(numHiddens, numOutputs), device);
NDArray b_q = manager.zeros(new Shape(numOutputs), DataType.FLOAT32, device);
// Attach gradients
NDList params = new NDList(W_xh, W_hh, b_h, W_hq, b_q);
for (NDArray param : params) {
param.setRequiresGradient(true);
}
return params;
}
public static NDArray normal(Shape shape, Device device) {
return manager.randomNormal(0f, 0.01f, shape, DataType.FLOAT32, device);
}
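As a quick sanity check on the shapes used in getParams, the following sketch counts the scalar parameters they imply. The class RnnParamCount is a hypothetical helper written for illustration; the values 29 and 512 are taken from the vocabulary size and numHiddens that appear elsewhere in this section.

```java
class RnnParamCount {
    /** Counts the entries of W_xh, W_hh, b_h, W_hq and b_q for the given sizes. */
    static long countParams(int vocabSize, int numHiddens) {
        long wxh = (long) vocabSize * numHiddens;   // input-to-hidden weights
        long whh = (long) numHiddens * numHiddens;  // hidden-to-hidden weights
        long bh = numHiddens;                       // hidden bias
        long whq = (long) numHiddens * vocabSize;   // hidden-to-output weights
        long bq = vocabSize;                        // output bias
        return wxh + whh + bh + whq + bq;
    }

    public static void main(String[] args) {
        // Prints 292381 for a 29-token vocabulary and 512 hidden units.
        System.out.println(countParams(29, 512));
    }
}
```

Most of the parameters sit in the hidden-to-hidden matrix, whose size grows quadratically with the number of hidden units.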
8.5.3. RNN Model
To define an RNN model, we first need an initRNNState function to return the hidden state at initialization. It returns an NDArray filled with zeros and with a shape of (batch size, number of hidden units).
public static NDList initRNNState(int batchSize, int numHiddens, Device device) {
return new NDList(manager.zeros(new Shape(batchSize, numHiddens), DataType.FLOAT32, device));
}
The following rnn function defines how to compute the hidden state and output at a time step. Note that the RNN model loops through the outermost dimension of inputs so that it updates the hidden states H of a minibatch, time step by time step. The activation function here is \(\tanh\). As described in Section 4.1, the mean value of the \(\tanh\) function is 0 when the elements are uniformly distributed over the real numbers.
public static Pair<NDArray, NDList> rnn(NDArray inputs, NDList state, NDList params) {
// Shape of `inputs`: (`numSteps`, `batchSize`, `vocabSize`)
NDArray W_xh = params.get(0);
NDArray W_hh = params.get(1);
NDArray b_h = params.get(2);
NDArray W_hq = params.get(3);
NDArray b_q = params.get(4);
NDArray H = state.get(0);
NDList outputs = new NDList();
// Shape of `X`: (`batchSize`, `vocabSize`)
NDArray X, Y;
for (int i = 0; i < inputs.size(0); i++) {
X = inputs.get(i);
H = (X.dot(W_xh).add(H.dot(W_hh)).add(b_h)).tanh();
Y = H.dot(W_hq).add(b_q);
outputs.add(Y);
}
return new Pair<>(outputs.size() > 1 ? NDArrays.concat(outputs) : outputs.get(0), new NDList(H));
}
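The recurrence inside the loop above can also be traced by hand on a tiny example. The sketch below recomputes the hidden-state update tanh(X W_xh + H W_hh + b_h) for a single sequence with plain double arrays. The class RnnStepSketch, the three-token vocabulary, and the weight values are all made up for illustration.

```java
import java.util.Arrays;

class RnnStepSketch {
    /** Computes tanh(x * wxh + h * whh + bh) for one example. */
    static double[] step(double[] x, double[] h, double[][] wxh, double[][] whh, double[] bh) {
        double[] next = new double[bh.length];
        for (int j = 0; j < bh.length; j++) {
            double sum = bh[j];
            for (int i = 0; i < x.length; i++) {
                sum += x[i] * wxh[i][j];
            }
            for (int k = 0; k < h.length; k++) {
                sum += h[k] * whh[k][j];
            }
            next[j] = Math.tanh(sum);
        }
        return next;
    }

    public static void main(String[] args) {
        double[][] wxh = {{0.5, -0.5}, {0.1, 0.2}, {0.0, 0.3}}; // (3 inputs, 2 hidden units)
        double[][] whh = {{0.1, 0.0}, {0.0, 0.1}};              // (2 hidden, 2 hidden)
        double[] bh = {0.0, 0.0};
        double[] h = {0.0, 0.0};                                // initial hidden state
        int[] tokens = {0, 2, 1};
        for (int token : tokens) {
            double[] x = new double[3];
            x[token] = 1.0;                                     // one-hot input
            h = step(x, h, wxh, whh, bh);                       // state carried to next step
            System.out.println(Arrays.toString(h));
        }
    }
}
```

Because the input is one-hot, the product with W_xh simply selects one row of that matrix, and the tanh keeps every hidden unit strictly between -1 and 1.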
With all the needed functions being defined, next we create a class to wrap these functions and store parameters for an RNN model implemented from scratch.
/** An RNN Model implemented from scratch. */
public class RNNModelScratch {
public int vocabSize;
public int numHiddens;
public NDList params;
public TriFunction<Integer, Integer, Device, NDList> initState;
public TriFunction<NDArray, NDList, NDList, Pair> forwardFn;
public RNNModelScratch(
int vocabSize,
int numHiddens,
Device device,
TriFunction<Integer, Integer, Device, NDList> getParams,
TriFunction<Integer, Integer, Device, NDList> initRNNState,
TriFunction<NDArray, NDList, NDList, Pair> forwardFn) {
this.vocabSize = vocabSize;
this.numHiddens = numHiddens;
this.params = getParams.apply(vocabSize, numHiddens, device);
this.initState = initRNNState;
this.forwardFn = forwardFn;
}
public Pair forward(NDArray X, NDList state) {
X = X.transpose().oneHot(this.vocabSize);
return this.forwardFn.apply(X, state, this.params);
}
public NDList beginState(int batchSize, Device device) {
return this.initState.apply(batchSize, this.numHiddens, device);
}
}
Let us check whether the outputs have the correct shapes, e.g., to ensure that the dimensionality of the hidden state remains unchanged.
int numHiddens = 512;
TriFunction<Integer, Integer, Device, NDList> getParamsFn = (a, b, c) -> getParams(a, b, c);
TriFunction<Integer, Integer, Device, NDList> initRNNStateFn =
(a, b, c) -> initRNNState(a, b, c);
TriFunction<NDArray, NDList, NDList, Pair> rnnFn = (a, b, c) -> rnn(a, b, c);
NDArray X = manager.arange(10).reshape(new Shape(2, 5));
Device device = manager.getDevice();
RNNModelScratch net =
new RNNModelScratch(
vocab.length(), numHiddens, device, getParamsFn, initRNNStateFn, rnnFn);
NDList state = net.beginState((int) X.getShape().getShape()[0], device);
Pair<NDArray, NDList> pairResult = net.forward(X.toDevice(device, false), state);
NDArray Y = pairResult.getKey();
NDList newState = pairResult.getValue();
System.out.println(Y.getShape());
System.out.println(newState.get(0).getShape());
(10, 29)
(2, 512)
We can see that the output shape is (number of time steps \(\times\) batch size, vocabulary size), while the hidden state shape remains the same, i.e., (batch size, number of hidden units).
8.5.4. Prediction
Let us first define the prediction function to generate new characters following the user-provided prefix, which is a string containing several characters. When looping through these beginning characters in prefix, we keep passing the hidden state to the next time step without generating any output. This is called the warm-up period, during which the model updates itself (e.g., updates the hidden state) but does not make predictions. After the warm-up period, the hidden state is generally better than its initialized value at the beginning. So we generate the predicted characters and emit them.
/** Generate new characters following the `prefix`. */
public static String predictCh8(
String prefix, int numPreds, RNNModelScratch net, Vocab vocab, Device device) {
NDList state = net.beginState(1, device);
List<Integer> outputs = new ArrayList<>();
outputs.add(vocab.getIdx("" + prefix.charAt(0)));
SimpleFunction<NDArray> getInput =
() ->
manager.create(outputs.get(outputs.size() - 1))
.toDevice(device, false)
.reshape(new Shape(1, 1));
for (char c : prefix.substring(1).toCharArray()) { // Warm-up period
state = (NDList) net.forward(getInput.apply(), state).getValue();
outputs.add(vocab.getIdx("" + c));
}
NDArray y;
for (int i = 0; i < numPreds; i++) {
Pair<NDArray, NDList> pair = net.forward(getInput.apply(), state);
y = pair.getKey();
state = pair.getValue();
outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));
}
StringBuilder output = new StringBuilder();
for (int i : outputs) {
output.append(vocab.idxToToken.get(i));
}
return output.toString();
}
Now we can test the predictCh8 function. We specify the prefix as time traveller and have it generate 10 additional characters. Given that we have not trained the network, it will generate nonsensical predictions.
predictCh8("time traveller ", 10, net, vocab, manager.getDevice());
time traveller ks<unk>s<unk>s<unk>s<unk>s
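The control flow of prediction (warm up on the prefix, then feed each prediction back in as the next input) can be isolated from the network itself. The sketch below uses a deliberately trivial stand-in for a trained model: its "hidden state" just remembers the last input, and its scores favor the next letter of a three-letter alphabet. Everything here (GreedyDecodeSketch, VOCAB, updateState, output) is hypothetical and serves only to show the loop structure.

```java
class GreedyDecodeSketch {
    static final String VOCAB = "abc";

    /** Toy state update: the state simply remembers the latest input index. */
    static int updateState(int state, int inputIdx) {
        return inputIdx;
    }

    /** Toy output layer: the highest score goes to the character after the remembered one. */
    static double[] output(int state) {
        double[] scores = new double[VOCAB.length()];
        scores[(state + 1) % VOCAB.length()] = 1.0;
        return scores;
    }

    static int argMax(double[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    static String predict(String prefix, int numPreds) {
        StringBuilder out = new StringBuilder(prefix);
        int state = 0;
        int last = VOCAB.indexOf(prefix.charAt(0));
        // Warm-up: update the state on the prefix and ignore the outputs.
        for (int i = 1; i < prefix.length(); i++) {
            state = updateState(state, last);
            last = VOCAB.indexOf(prefix.charAt(i));
        }
        // Generation: the most likely character becomes the next input.
        for (int i = 0; i < numPreds; i++) {
            state = updateState(state, last);
            last = argMax(output(state));
            out.append(VOCAB.charAt(last));
        }
        return out.toString();
    }

    public static void main(String[] args) {
        System.out.println(predict("ab", 4)); // prints "abcabc"
    }
}
```

Picking the arg max at every step makes generation deterministic, which is also why an untrained or poorly trained model tends to fall into short repeating patterns.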
8.5.5. Gradient Clipping
For a sequence of length \(T\), we compute the gradients over these \(T\) time steps in an iteration, which results in a chain of matrix-products with length \(\mathcal{O}(T)\) during backpropagation. As mentioned in Section 4.8, it might result in numerical instability, e.g., the gradients may either explode or vanish, when \(T\) is large. Therefore, RNN models often need extra help to stabilize the training.
Generally speaking, when solving an optimization problem, we take update steps for the model parameter, say in the vector form \(\mathbf{x}\), in the direction of the negative gradient \(\mathbf{g}\) on a minibatch. For example, with \(\eta > 0\) as the learning rate, in one iteration we update \(\mathbf{x}\) as \(\mathbf{x} - \eta \mathbf{g}\). Let us further assume that the objective function \(f\) is well behaved, say, Lipschitz continuous with constant \(L\). That is to say, for any \(\mathbf{x}\) and \(\mathbf{y}\) we have
\[|f(\mathbf{x}) - f(\mathbf{y})| \leq L \|\mathbf{x} - \mathbf{y}\|.\]
In this case we can safely assume that if we update the parameter vector by \(\eta \mathbf{g}\), then
\[|f(\mathbf{x}) - f(\mathbf{x} - \eta \mathbf{g})| \leq L \eta \|\mathbf{g}\|,\]
which means that we will not observe a change by more than \(L \eta \|\mathbf{g}\|\). This is both a curse and a blessing. On the curse side, it limits the speed of making progress; whereas on the blessing side, it limits the extent to which things can go wrong if we move in the wrong direction.
Sometimes the gradients can be quite large and the optimization algorithm may fail to converge. We could address this by reducing the learning rate \(\eta\). But what if we only rarely get large gradients? In this case such an approach may appear entirely unwarranted. One popular alternative is to clip the gradient \(\mathbf{g}\) by projecting it back onto a ball of a given radius, say \(\theta\), via
\[\mathbf{g} \leftarrow \min\left(1, \frac{\theta}{\|\mathbf{g}\|}\right) \mathbf{g}.\]
By doing so we know that the gradient norm never exceeds \(\theta\) and that the updated gradient is entirely aligned with the original direction of \(\mathbf{g}\). It also has the desirable side-effect of limiting the influence any given minibatch (and within it any given sample) can exert on the parameter vector. This bestows a certain degree of robustness to the model. Gradient clipping provides a quick fix for exploding gradients. While it does not entirely solve the problem, it is one of the many techniques to alleviate it.
Below we define a function to clip the gradients of a model that is implemented from scratch or a model constructed by the high-level APIs. Also note that we compute the gradient norm over all the model parameters.
/** Clip the gradient. */
public static void gradClipping(RNNModelScratch net, int theta, NDManager manager) {
double result = 0;
for (NDArray p : net.params) {
NDArray gradient = p.getGradient();
gradient.attach(manager);
result += gradient.pow(2).sum().getFloat();
}
double norm = Math.sqrt(result);
if (norm > theta) {
for (NDArray param : net.params) {
NDArray gradient = param.getGradient();
gradient.muli(theta / norm);
}
}
}
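The same clipping rule can be checked on numbers small enough to verify by hand. The sketch below applies it to plain double arrays, with the norm taken jointly over all parameter groups as in gradClipping. The class GradClipSketch is an illustrative assumption and does not use DJL.

```java
class GradClipSketch {
    /**
     * Rescales all gradients in place so that their joint L2 norm is at most theta.
     * Returns the norm measured before clipping.
     */
    static double clip(double[][] grads, double theta) {
        double sumSq = 0;
        for (double[] g : grads) {
            for (double v : g) {
                sumSq += v * v;
            }
        }
        double norm = Math.sqrt(sumSq);
        if (norm > theta) {
            double scale = theta / norm;
            for (double[] g : grads) {
                for (int i = 0; i < g.length; i++) {
                    g[i] *= scale; // same direction, shorter length
                }
            }
        }
        return norm;
    }

    public static void main(String[] args) {
        // Two parameter groups whose joint gradient is (3, 4), so the norm is 5.
        double[][] grads = {{3.0}, {4.0}};
        double before = clip(grads, 1.0);
        System.out.println("norm before clipping: " + before);
        System.out.println("clipped: " + grads[0][0] + ", " + grads[1][0]);
    }
}
```

With \(\theta = 1\) the gradient (3, 4) is scaled by 1/5 to approximately (0.6, 0.8), which has unit norm and points in the original direction. A gradient whose norm is already below \(\theta\) is left untouched.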
8.5.6. Training
Before training the model, let us define a function to train the model in one epoch. It differs from how we train the model of Section 3.6 in three places:
1. Different sampling methods for sequential data (random sampling and sequential partitioning) will result in differences in the initialization of hidden states.
2. We clip the gradients before updating the model parameters. This ensures that the model does not diverge even when gradients blow up at some point during the training process.
3. We use perplexity to evaluate the model. As discussed in Section 8.4.4, this ensures that sequences of different length are comparable.
Specifically, when sequential partitioning is used, we initialize the hidden state only at the beginning of each epoch. Since the \(i^\mathrm{th}\) subsequence example in the next minibatch is adjacent to the current \(i^\mathrm{th}\) subsequence example, the hidden state at the end of the current minibatch will be used to initialize the hidden state at the beginning of the next minibatch. In this way, historical information of the sequence stored in the hidden state might flow over adjacent subsequences within an epoch. However, the computation of the hidden state at any point depends on all the previous minibatches in the same epoch, which complicates the gradient computation. To reduce computational cost, we detach the gradient before processing any minibatch so that the gradient computation of the hidden state is always limited to the time steps in one minibatch.
When using random sampling, we need to re-initialize the hidden state for each iteration, since each example is sampled at a random position. As with the trainEpochCh3 function in Section 3.6, updater is a general function to update the model parameters. It can be either the function implemented from scratch or the built-in optimization function in a deep learning framework.
/** Train a model within one epoch. */
public static Pair<Double, Double> trainEpochCh8(
RNNModelScratch net,
List<NDList> trainIter,
Loss loss,
voidTwoFunction<Integer, NDManager> updater,
Device device,
boolean useRandomIter) {
StopWatch watch = new StopWatch();
watch.start();
Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens
try (NDManager childManager = manager.newSubManager()) {
NDList state = null;
for (NDList pair : trainIter) {
NDArray X = pair.get(0).toDevice(device, true);
X.attach(childManager);
NDArray Y = pair.get(1).toDevice(device, true);
Y.attach(childManager);
if (state == null || useRandomIter) {
// Initialize `state` when either it is the first iteration or
// using random sampling
state = net.beginState((int) X.getShape().getShape()[0], device);
} else {
for (NDArray s : state) {
s.stopGradient();
}
}
state.attach(childManager);
NDArray y = Y.transpose().reshape(new Shape(-1));
X = X.toDevice(device, false);
y = y.toDevice(device, false);
try (GradientCollector gc = manager.getEngine().newGradientCollector()) {
Pair<NDArray, NDList> pairResult = net.forward(X, state);
NDArray yHat = pairResult.getKey();
state = pairResult.getValue();
NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();
gc.backward(l);
metric.add(new float[] {l.getFloat() * y.size(), y.size()});
}
gradClipping(net, 1, childManager);
updater.apply(1, childManager); // Since the `mean` function has been invoked
}
}
return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());
}
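The first value returned by trainEpochCh8 is the exponential of the average per-token cross-entropy, i.e., the perplexity. The sketch below reproduces that computation from the probabilities a model assigns to the correct tokens. The class PerplexitySketch is a hypothetical helper, and the probabilities are invented for illustration.

```java
class PerplexitySketch {
    /** probs[i] is the probability the model assigned to the i-th correct token. */
    static double perplexity(double[] probs) {
        double totalLoss = 0;
        for (double p : probs) {
            totalLoss += -Math.log(p); // cross-entropy of one token
        }
        return Math.exp(totalLoss / probs.length);
    }

    public static void main(String[] args) {
        // A model that always guesses uniformly over 29 tokens.
        double[] uniform = {1.0 / 29, 1.0 / 29, 1.0 / 29, 1.0 / 29};
        System.out.println(perplexity(uniform));

        // A model that is always certain and correct.
        double[] perfect = {1.0, 1.0, 1.0};
        System.out.println(perplexity(perfect)); // prints 1.0
    }
}
```

Uniform guessing gives a perplexity equal to the vocabulary size (about 29 here, up to floating-point rounding), while a perfect model reaches the lower bound of 1. Because the loss is averaged per token, the value does not depend on how long the sequences are.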
The training function supports an RNN model implemented either from scratch or using high-level APIs.
/** Train a model. */
public static void trainCh8(
RNNModelScratch net,
List<NDList> trainIter,
Vocab vocab,
int lr,
int numEpochs,
Device device,
boolean useRandomIter) {
SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();
Animator animator = new Animator();
// Initialize
voidTwoFunction<Integer, NDManager> updater =
(batchSize, subManager) -> Training.sgd(net.params, lr, batchSize, subManager);
Function<String, String> predict = (prefix) -> predictCh8(prefix, 50, net, vocab, device);
// Train and predict
double ppl = 0.0;
double speed = 0.0;
for (int epoch = 0; epoch < numEpochs; epoch++) {
Pair<Double, Double> pair =
trainEpochCh8(net, trainIter, loss, updater, device, useRandomIter);
ppl = pair.getKey();
speed = pair.getValue();
if ((epoch + 1) % 10 == 0) {
animator.add(epoch + 1, (float) ppl, "");
animator.show();
}
}
System.out.format(
"perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, device.toString());
System.out.println(predict.apply("time traveller"));
System.out.println(predict.apply("traveller"));
}
Now we can train the RNN model. Since we only use 10000 tokens in the dataset, the model needs more epochs to converge better.
int numEpochs = Integer.getInteger("MAX_EPOCH", 500);
int lr = 1;
trainCh8(net, trainIter, vocab, lr, numEpochs, manager.getDevice(), false);
perplexity: 1.0, 42125.3 tokens/sec on gpu(0)
time traveller came back andfilby s anecdote collapsedthe thing
traveller broce tea ls thoug be than s abe asions if at un
Finally, let us check the results of using the random sampling method.
trainCh8(net, trainIter, vocab, lr, numEpochs, manager.getDevice(), true);
perplexity: 1.1, 42501.5 tokens/sec on gpu(0)
time traveller fol so eat you dan homi frealt ato hesperte m sm
traveller bet her wioke treo damsyon intw b tare arougn e o
While implementing the above RNN model from scratch is instructive, it is not convenient. In the next section we will see how to improve the RNN model, such as how to make it easier to implement and make it run faster.
8.5.7. Summary
We can train an RNN-based character-level language model to generate text following the user-provided text prefix.
A simple RNN language model consists of input encoding, RNN modeling, and output generation.
RNN models need state initialization for training, though random sampling and sequential partitioning initialize the state differently.
When using sequential partitioning, we need to detach the gradient to reduce computational cost.
A warm-up period allows a model to update itself (e.g., obtain a better hidden state than its initialized value) before making any prediction.
Gradient clipping prevents gradient explosion, but it cannot fix vanishing gradients.
8.5.8. Exercises
1. Show that one-hot encoding is equivalent to picking a different embedding for each object.
2. Adjust the hyperparameters (e.g., number of epochs, number of hidden units, number of time steps in a minibatch, and learning rate) to improve the perplexity.
   - How low can you go?
   - Replace one-hot encoding with learnable embeddings. Does this lead to better performance?
   - How well will it work on other books by H. G. Wells, e.g., The War of the Worlds?
3. Modify the prediction function so that it uses sampling rather than picking the most likely next character.
   - What happens?
   - Bias the model towards more likely outputs, e.g., by sampling from \(q(x_t \mid x_{t-1}, \ldots, x_1) \propto P(x_t \mid x_{t-1}, \ldots, x_1)^\alpha\) for \(\alpha > 1\).
4. Run the code in this section without clipping the gradient. What happens?
5. Change sequential partitioning so that it does not separate hidden states from the computational graph. Does the running time change? How about the perplexity?
6. Replace the activation function used in this section with ReLU and repeat the experiments in this section. Do we still need gradient clipping? Why?