Run this notebook online:\ |Binder| or Colab: |Colab|
.. |Binder| image:: https://mybinder.org/badge_logo.svg
:target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_recurrent-modern/gru.ipynb
.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_recurrent-modern/gru.ipynb
.. _sec_gru:
Gated Recurrent Units (GRU)
===========================
In :numref:`sec_bptt`, we discussed how gradients are calculated in
RNNs. In particular we found that long products of matrices can lead to
vanishing or exploding gradients. Let us briefly think about what such
gradient anomalies mean in practice:
- We might encounter a situation where an early observation is highly
significant for predicting all future observations. Consider the
somewhat contrived case where the first observation contains a
checksum and the goal is to discern whether the checksum is correct
at the end of the sequence. In this case, the influence of the first
token is vital. We would like to have some mechanisms for storing
vital early information in a *memory cell*. Without such a mechanism,
we will have to assign a very large gradient to this observation,
since it affects all the subsequent observations.
- We might encounter situations where some tokens carry no pertinent
observation. For instance, when parsing a web page there might be
auxiliary HTML code that is irrelevant for the purpose of assessing
the sentiment conveyed on the page. We would like to have some
mechanism for *skipping* such tokens in the latent state
representation.
- We might encounter situations where there is a logical break between
parts of a sequence. For instance, there might be a transition
between chapters in a book, or a transition between a bear and a bull
market for securities. In this case it would be nice to have a means
of *resetting* our internal state representation.
A number of methods have been proposed to address this. One of the
earliest is long short-term memory :cite:`Hochreiter.Schmidhuber.1997`
which we will discuss in :numref:`sec_lstm`. The gated recurrent unit
(GRU) :cite:`Cho.Van-Merrienboer.Bahdanau.ea.2014` is a slightly more
streamlined variant that often offers comparable performance and is
significantly faster to compute :cite:`Chung.Gulcehre.Cho.ea.2014`.
Due to its simplicity, let us start with the GRU.
Gated Hidden State
------------------
The key distinction between vanilla RNNs and GRUs is that the latter
support gating of the hidden state. This means that we have dedicated
mechanisms for when a hidden state should be *updated* and also when it
should be *reset*. These mechanisms are learned and they address the
concerns listed above. For instance, if the first token is of great
importance we will learn not to update the hidden state after the first
observation. Likewise, we will learn to skip irrelevant temporary
observations. Last, we will learn to reset the latent state whenever
needed. We discuss this in detail below.
.. _fig_gru_1:
Reset Gate and Update Gate
~~~~~~~~~~~~~~~~~~~~~~~~~~
The first things we need to introduce are the *reset gate* and the
*update gate*. We engineer them to be vectors with entries in
:math:`(0, 1)` such that we can perform convex combinations. For
instance, a reset gate would allow us to control how much of the
previous state we might still want to remember. Likewise, an update gate
would allow us to control how much of the new state is just a copy of
the old state.
We begin by engineering these gates. :numref:`fig_gru_1` illustrates
the inputs for both the reset and update gates in a GRU, given the input
of the current time step and the hidden state of the previous time step.
The outputs of the two gates are given by two fully connected layers with a
sigmoid activation function.
|Computing the reset gate and the update gate in a GRU model.|
Mathematically, for a given time step :math:`t`, suppose that the input
is a minibatch :math:`\mathbf{X}_t \in \mathbb{R}^{n \times d}` (number
of examples: :math:`n`, number of inputs: :math:`d`) and the hidden
state of the previous time step is
:math:`\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}` (number of hidden
units: :math:`h`). Then, the reset gate
:math:`\mathbf{R}_t \in \mathbb{R}^{n \times h}` and update gate
:math:`\mathbf{Z}_t \in \mathbb{R}^{n \times h}` are computed as
follows:
.. math::
\begin{aligned}
\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\
\mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),
\end{aligned}
where
:math:`\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}` and
:math:`\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}` are
weight parameters and
:math:`\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}` are
biases. Note that broadcasting (see :numref:`subsec_broadcasting`) is
triggered during the summation. We use sigmoid functions (as introduced
in :numref:`sec_mlp`) to transform input values to the interval
:math:`(0, 1)`.
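To make the gate equations concrete, here is a minimal sketch in plain Java rather than DJL, assuming row-major ``double[][]`` matrices. The class ``GruGatesSketch``, its helper ``gate``, and all numeric values are hypothetical and chosen only for illustration. The same ``gate`` function would compute :math:`\mathbf{Z}_t` when given the update-gate parameters :math:`\mathbf{W}_{xz}, \mathbf{W}_{hz}, \mathbf{b}_z`.

```java
// Hypothetical plain-Java sketch of one GRU gate (reset or update); not DJL code.
class GruGatesSketch {

    // Logistic sigmoid: squashes any real number into (0, 1).
    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    // Matrix product: a is (n x k), b is (k x m), result is (n x m).
    static double[][] matmul(double[][] a, double[][] b) {
        int n = a.length, k = b.length, m = b[0].length;
        double[][] out = new double[n][m];
        for (int i = 0; i < n; i++) {
            for (int p = 0; p < k; p++) {
                for (int j = 0; j < m; j++) {
                    out[i][j] += a[i][p] * b[p][j];
                }
            }
        }
        return out;
    }

    // sigmoid(X Wx + Hprev Wh + b): the bias row b (length h) is broadcast over all n rows.
    static double[][] gate(double[][] x, double[][] wx,
                           double[][] hPrev, double[][] wh, double[] b) {
        double[][] xw = matmul(x, wx);      // n x h
        double[][] hw = matmul(hPrev, wh);  // n x h
        double[][] out = new double[xw.length][b.length];
        for (int i = 0; i < xw.length; i++) {
            for (int j = 0; j < b.length; j++) {
                out[i][j] = sigmoid(xw[i][j] + hw[i][j] + b[j]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Toy sizes: n = 2 examples, d = 3 inputs, h = 2 hidden units.
        double[][] x = {{1, 0, -1}, {0.5, 2, 0}};
        double[][] hPrev = {{0.1, -0.2}, {0.3, 0.0}};
        double[][] wxr = {{0.2, -0.1}, {0.0, 0.3}, {0.5, 0.1}};
        double[][] whr = {{0.4, 0.0}, {-0.3, 0.2}};
        double[] br = {0.0, 0.1};
        double[][] r = gate(x, wxr, hPrev, whr, br);  // reset gate R_t
        System.out.println(r.length + " x " + r[0].length);  // prints "2 x 2"
    }
}
```

Because every entry passes through the sigmoid, each entry of the result lies in :math:`(0, 1)`, which is what lets the gates act as weights in a convex combination.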
.. _fig_gru_2:
Candidate Hidden State
~~~~~~~~~~~~~~~~~~~~~~
Next, let us integrate the reset gate :math:`\mathbf{R}_t` with the
regular latent state updating mechanism in :eq:`rnn_h_with_state`.
It leads to the following *candidate hidden state*
:math:`\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}` at time step
:math:`t`:
.. math:: \tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),
:label: gru_tilde_H
where :math:`\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}` and
:math:`\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}` are weight
parameters, :math:`\mathbf{b}_h \in \mathbb{R}^{1 \times h}` is the
bias, and the symbol :math:`\odot` is the Hadamard (elementwise) product
operator. Here we use a nonlinearity in the form of tanh to ensure that
the values in the candidate hidden state remain in the interval
:math:`(-1, 1)`.
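The following sketch, again in plain Java with hypothetical names (``GruCandidateSketch``, ``candidate``), evaluates :eq:`gru_tilde_H` for a minibatch. It assumes the reset gate :math:`\mathbf{R}_t` has already been computed and shows where the Hadamard product enters: it scales :math:`\mathbf{H}_{t-1}` entry by entry before the multiplication by :math:`\mathbf{W}_{hh}`.

```java
// Hypothetical plain-Java sketch of the GRU candidate hidden state (not DJL code).
class GruCandidateSketch {

    // Matrix product: a is (n x k), b is (k x m), result is (n x m).
    static double[][] matmul(double[][] a, double[][] b) {
        int n = a.length, k = b.length, m = b[0].length;
        double[][] out = new double[n][m];
        for (int i = 0; i < n; i++) {
            for (int p = 0; p < k; p++) {
                for (int j = 0; j < m; j++) {
                    out[i][j] += a[i][p] * b[p][j];
                }
            }
        }
        return out;
    }

    // tanh(X Wxh + (R * Hprev) Whh + bh), where R * Hprev is the elementwise
    // (Hadamard) product and the bias row bh is broadcast over all n rows.
    static double[][] candidate(double[][] x, double[][] wxh, double[][] r,
                                double[][] hPrev, double[][] whh, double[] bh) {
        int n = hPrev.length, h = hPrev[0].length;
        double[][] gated = new double[n][h];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < h; j++) {
                gated[i][j] = r[i][j] * hPrev[i][j];
            }
        }
        double[][] xw = matmul(x, wxh);      // n x h
        double[][] hw = matmul(gated, whh);  // n x h
        double[][] out = new double[n][h];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < h; j++) {
                out[i][j] = Math.tanh(xw[i][j] + hw[i][j] + bh[j]);
            }
        }
        return out;
    }
}
```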
The result is a *candidate* since we still need to incorporate the
action of the update gate. Comparing with :eq:`rnn_h_with_state`,
now the influence of the previous states can be reduced with the
elementwise multiplication of :math:`\mathbf{R}_t` and
:math:`\mathbf{H}_{t-1}` in :eq:`gru_tilde_H`. Whenever the entries
in the reset gate :math:`\mathbf{R}_t` are close to 1, the candidate
hidden state reduces to the vanilla RNN update in :eq:`rnn_h_with_state`. For all entries of
the reset gate :math:`\mathbf{R}_t` that are close to 0, the candidate
hidden state is the result of an MLP with :math:`\mathbf{X}_t` as the
input. Any pre-existing hidden state is thus *reset* to defaults.
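The two extremes of the reset gate can be checked numerically in the scalar case (:math:`n = d = h = 1`). In this hypothetical sketch, ``ResetGateExtremes`` and its methods are illustrative names and the weights are arbitrary: with :math:`r = 1` the candidate matches the vanilla RNN update, and with :math:`r = 0` it no longer depends on the previous state.

```java
// Hypothetical scalar sketch (one example, d = h = 1) of the reset gate's two extremes.
class ResetGateExtremes {

    // GRU candidate for scalars: tanh(x*wxh + (r*hPrev)*whh + bh).
    static double candidate(double x, double r, double hPrev,
                            double wxh, double whh, double bh) {
        return Math.tanh(x * wxh + (r * hPrev) * whh + bh);
    }

    // Vanilla RNN update for scalars: tanh(x*wxh + hPrev*whh + bh).
    static double vanillaRnn(double x, double hPrev,
                             double wxh, double whh, double bh) {
        return Math.tanh(x * wxh + hPrev * whh + bh);
    }

    public static void main(String[] args) {
        double x = 0.8, wxh = 1.5, whh = -2.0, bh = 0.1;
        // r = 1: the candidate coincides with the vanilla RNN update.
        System.out.println(candidate(x, 1.0, 0.6, wxh, whh, bh)
                == vanillaRnn(x, 0.6, wxh, whh, bh));  // true
        // r = 0: two different previous states give the same candidate.
        System.out.println(candidate(x, 0.0, 0.6, wxh, whh, bh)
                == candidate(x, 0.0, -0.9, wxh, whh, bh));  // true
    }
}
```

With :math:`r = 0` only the input term and the bias remain, i.e. the candidate is :math:`\tanh(x w_{xh} + b_h)`, a one-layer MLP applied to the current input.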
:numref:`fig_gru_2` illustrates the computational flow after applying
the reset gate.
|Computing the candidate hidden state in a GRU model.|
Hidden State
~~~~~~~~~~~~
Finally, we need to incorporate the effect of the update gate
:math:`\mathbf{Z}_t`. This determines the extent to which the new hidden
state :math:`\mathbf{H}_t \in \mathbb{R}^{n \times h}` is just the old
state :math:`\mathbf{H}_{t-1}` and by how much the new candidate state
:math:`\tilde{\mathbf{H}}_t` is used. The update gate
:math:`\mathbf{Z}_t` can be used for this purpose, simply by taking
elementwise convex combinations between both :math:`\mathbf{H}_{t-1}`
and :math:`\tilde{\mathbf{H}}_t`. This leads to the final update
equation for the GRU:
.. math:: \mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.
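Since the update is an elementwise convex combination, each entry of :math:`\mathbf{H}_t` lies between the corresponding entries of :math:`\mathbf{H}_{t-1}` and :math:`\tilde{\mathbf{H}}_t`. The plain-Java sketch below (hypothetical class ``GruUpdateSketch``, one example with :math:`h = 3`) applies the rule with three gate values at once: fully on, fully off, and in between.

```java
// Hypothetical plain-Java sketch of the final GRU update for one example.
class GruUpdateSketch {

    // H_t = Z * H_{t-1} + (1 - Z) * H~_t, computed entry by entry.
    static double[] update(double[] z, double[] hPrev, double[] hTilde) {
        double[] h = new double[z.length];
        for (int j = 0; j < z.length; j++) {
            h[j] = z[j] * hPrev[j] + (1.0 - z[j]) * hTilde[j];
        }
        return h;
    }

    public static void main(String[] args) {
        double[] z = {1.0, 0.0, 0.25};
        double[] hPrev = {0.5, 0.5, 0.5};
        double[] hTilde = {-0.5, -0.5, -0.5};
        double[] h = update(z, hPrev, hTilde);
        // Entry 0 keeps the old state, entry 1 takes the candidate,
        // entry 2 mixes them: 0.25 * 0.5 + 0.75 * (-0.5) = -0.25.
        System.out.println(h[0] + " " + h[1] + " " + h[2]);  // prints "0.5 -0.5 -0.25"
    }
}
```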
Whenever the update gate :math:`\mathbf{Z}_t` is close to 1, we simply
retain the old state. In this case the information from
:math:`\mathbf{X}_t` is essentially ignored, effectively skipping time
step :math:`t` in the dependency chain. In contrast, whenever
:math:`\mathbf{Z}_t` is close to 0, the new latent state
:math:`\mathbf{H}_t` approaches the candidate latent state
:math:`\tilde{\mathbf{H}}_t`. These designs can help us cope with the
vanishing gradient problem in RNNs and better capture dependencies for
sequences with large time step distances. For instance, if the update
gate has been close to 1 for all the time steps of an entire
subsequence, the old hidden state at the time step of its beginning will
be easily retained and passed to its end, regardless of the length of
the subsequence.
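The retention argument can be illustrated with a scalar simulation. This hypothetical sketch (``LongRangeRetention`` and ``run`` are illustrative names) holds the update gate at a fixed value and feeds arbitrary candidates :math:`c_t = \sin(t)`; a real GRU computes :math:`z_t` from the data at every step, so the sketch only isolates the arithmetic of the update rule.

```java
// Hypothetical scalar sketch: how the update gate carries a state across many steps.
class LongRangeRetention {

    // Runs h <- z * h + (1 - z) * c_t for the given number of steps,
    // with a fixed gate value z and candidates c_t = candidateScale * sin(t).
    static double run(double h0, double z, int steps, double candidateScale) {
        double h = h0;
        for (int t = 0; t < steps; t++) {
            double candidate = candidateScale * Math.sin(t);
            h = z * h + (1.0 - z) * candidate;
        }
        return h;
    }

    public static void main(String[] args) {
        // Update gate fully on: the initial state survives 1000 steps unchanged.
        System.out.println(run(0.8, 1.0, 1000, 1.0));  // prints 0.8
        // Update gate at 0.5 with all-zero candidates: the state halves every step.
        System.out.println(run(0.8, 0.5, 50, 0.0));    // about 7.1E-16
    }
}
```

With :math:`z = 1` the candidates are multiplied by zero and the initial state passes through unchanged, whereas with :math:`z = 0.5` the contribution of the initial state shrinks geometrically, by a factor of :math:`0.5^{50}` after 50 steps.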
:numref:`fig_gru_3` illustrates the computational flow after the
update gate is in action.
.. _fig_gru_3:
|Computing the hidden state in a GRU model.|
In summary, GRUs have the following two distinguishing features:
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
Implementation from Scratch
---------------------------
To gain a better understanding of the GRU model, let us implement it
from scratch. We begin by reading the time machine dataset that we used
in :numref:`sec_rnn_scratch`. The code for reading the dataset is
given below.
.. |Computing the reset gate and the update gate in a GRU model.| image:: https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/gru-1.svg
.. |Computing the candidate hidden state in a GRU model.| image:: https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/gru-2.svg
.. |Computing the hidden state in a GRU model.| image:: https://raw.githubusercontent.com/d2l-ai/d2l-en/master/img/gru-3.svg
.. code:: java
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java
%load ../utils/StopWatch.java
%load ../utils/Accumulator.java
%load ../utils/Animator.java
%load ../utils/Training.java
%load ../utils/timemachine/Vocab.java
%load ../utils/timemachine/RNNModel.java
%load ../utils/timemachine/RNNModelScratch.java
%load ../utils/timemachine/TimeMachine.java
%load ../utils/timemachine/TimeMachineDataset.java
.. code:: java
NDManager manager = NDManager.newBaseManager();
.. code:: java
int batchSize = 32;
int numSteps = 35;
TimeMachineDataset dataset =
new TimeMachineDataset.Builder()
.setManager(manager)
.setMaxTokens(10000)
.setSampling(batchSize, false)
.setSteps(numSteps)
.build();
dataset.prepare();
Vocab vocab = dataset.getVocab();
Initializing Model Parameters
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The next step is to initialize the model parameters. We draw the weights
from a Gaussian distribution with standard deviation 0.01 and set
the biases to 0. The hyperparameter ``numHiddens`` defines the number of
hidden units. We instantiate all weights and biases relating to the
update gate, the reset gate, the candidate hidden state, and the output
layer.
.. code:: java
public static NDArray normal(Shape shape, Device device) {
return manager.randomNormal(0, 0.01f, shape, DataType.FLOAT32, device);
}
public static NDList three(int numInputs, int numHiddens, Device device) {
return new NDList(
normal(new Shape(numInputs, numHiddens), device),
normal(new Shape(numHiddens, numHiddens), device),
manager.zeros(new Shape(numHiddens), DataType.FLOAT32, device));
}
public static NDList getParams(int vocabSize, int numHiddens, Device device) {
int numInputs = vocabSize;
int numOutputs = vocabSize;
// Update gate parameters
NDList temp = three(numInputs, numHiddens, device);
NDArray W_xz = temp.get(0);
NDArray W_hz = temp.get(1);
NDArray b_z = temp.get(2);
// Reset gate parameters
temp = three(numInputs, numHiddens, device);
NDArray W_xr = temp.get(0);
NDArray W_hr = temp.get(1);
NDArray b_r = temp.get(2);
// Candidate hidden state parameters
temp = three(numInputs, numHiddens, device);
NDArray W_xh = temp.get(0);
NDArray W_hh = temp.get(1);
NDArray b_h = temp.get(2);
// Output layer parameters
NDArray W_hq = normal(new Shape(numHiddens, numOutputs), device);
NDArray b_q = manager.zeros(new Shape(numOutputs), DataType.FLOAT32, device);
// Attach gradients
NDList params = new NDList(W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q);
for (NDArray param : params) {
param.setRequiresGradient(true);
}
return params;
}
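As a sanity check on the shapes above, the following hypothetical helper (``GruParamCount`` is an illustrative name, not part of the utilities loaded earlier) counts the scalars that ``getParams`` allocates: three blocks of :math:`dh + h^2 + h` for the update gate, reset gate, and candidate state, plus :math:`hq + q` for the output layer, where :math:`d = q =` ``vocabSize``. The vocabulary size 28 in ``main`` is only an example value, not necessarily the size of the vocabulary built above.

```java
// Hypothetical helper that counts the scalars allocated by getParams above.
class GruParamCount {

    // Mirrors three(numInputs, numHiddens): W_x* (d x h), W_h* (h x h), b_* (h).
    static long threeCount(long numInputs, long numHiddens) {
        return numInputs * numHiddens + numHiddens * numHiddens + numHiddens;
    }

    // Update gate, reset gate and candidate state each use one three(...) block;
    // the output layer adds W_hq (h x q) and b_q (q), with d = q = vocabSize.
    static long countGruParams(long vocabSize, long numHiddens) {
        long gatesAndCandidate = 3 * threeCount(vocabSize, numHiddens);
        long output = numHiddens * vocabSize + vocabSize;
        return gatesAndCandidate + output;
    }

    public static void main(String[] args) {
        // 28 is only an example vocabulary size.
        System.out.println(countGruParams(28, 256));  // prints 226076
    }
}
```

A vanilla RNN with the same sizes needs only one such block plus the output layer, so the gating triples the number of non-output parameters.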
Defining the Model
~~~~~~~~~~~~~~~~~~
Now we will define the hidden state initialization function
``initGruState``. Just like the state initialization function defined in
:numref:`sec_rnn_scratch`, this function returns a tensor with a shape
(batch size, number of hidden units) whose values are all zeros.
.. code:: java
public static NDList initGruState(int batchSize, int numHiddens, Device device) {
return new NDList(manager.zeros(new Shape(batchSize, numHiddens), DataType.FLOAT32, device));
}
Now we are ready to define the GRU model. Its structure is the same as
that of the basic RNN cell, except that the update equations are more
complex.
.. code:: java
public static Pair<NDArray, NDList> gru(NDArray inputs, NDList state, NDList params) {
NDArray W_xz = params.get(0);
NDArray W_hz = params.get(1);
NDArray b_z = params.get(2);
NDArray W_xr = params.get(3);
NDArray W_hr = params.get(4);
NDArray b_r = params.get(5);
NDArray W_xh = params.get(6);
NDArray W_hh = params.get(7);
NDArray b_h = params.get(8);
NDArray W_hq = params.get(9);
NDArray b_q = params.get(10);
NDArray H = state.get(0);
NDList outputs = new NDList();
NDArray X, Y, Z, R, H_tilda;
for (int i = 0; i < inputs.size(0); i++) {
X = inputs.get(i);
Z = Activation.sigmoid(X.dot(W_xz).add(H.dot(W_hz).add(b_z)));
R = Activation.sigmoid(X.dot(W_xr).add(H.dot(W_hr).add(b_r)));
H_tilda = Activation.tanh(X.dot(W_xh).add(R.mul(H).dot(W_hh).add(b_h)));
H = Z.mul(H).add(Z.mul(-1).add(1).mul(H_tilda));
Y = H.dot(W_hq).add(b_q);
outputs.add(Y);
}
return new Pair<>(outputs.size() > 1 ? NDArrays.concat(outputs) : outputs.get(0), new NDList(H));
}
Training and Prediction
~~~~~~~~~~~~~~~~~~~~~~~
Training and prediction work in exactly the same manner as in
:numref:`sec_rnn_scratch`. After training, we print out the perplexity
on the training set and the predicted sequence following the provided
prefixes "time traveller" and "traveller", respectively.
.. code:: java
int vocabSize = vocab.length();
int numHiddens = 256;
Device device = manager.getDevice();
int numEpochs = Integer.getInteger("MAX_EPOCH", 500);
int lr = 1;
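Functions.TriFunction<Integer, Integer, Device, NDList> getParamsFn = (a, b, c) -> getParams(a, b, c);
Functions.TriFunction<Integer, Integer, Device, NDList> initGruStateFn =
(a, b, c) -> initGruState(a, b, c);
Functions.TriFunction<NDArray, NDList, NDList, Pair<NDArray, NDList>> gruFn = (a, b, c) -> gru(a, b, c);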
RNNModelScratch model =
new RNNModelScratch(vocabSize, numHiddens, device,
getParamsFn, initGruStateFn, gruFn);
TimeMachine.trainCh8(model, dataset, vocab, lr, numEpochs, device, false, manager);
.. parsed-literal::
:class: output
perplexity: 1.0, 14501.5 tokens/sec on gpu(0)
time travelleryou can show black is white by argument said filby
travellerabl the bublict os to intaboo in a balloon and why
Concise Implementation
----------------------
Using the high-level API, we can directly instantiate a GRU model. This
encapsulates all the configuration details that we made explicit above.
The code is significantly faster because it uses compiled operators
rather than issuing each of the individual operations spelled out above
from Java.
.. code:: java
GRU gruLayer = GRU.builder().setNumLayers(1)
.setStateSize(numHiddens).optReturnState(true).optBatchFirst(false).build();
RNNModel modelConcise = new RNNModel(gruLayer, vocab.length());
TimeMachine.trainCh8(modelConcise, dataset, vocab, lr, numEpochs, device, false, manager);
.. parsed-literal::
:class: output
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.8.0 in 0.059 ms.
.. parsed-literal::
:class: output
perplexity: 1.0, 81300.7 tokens/sec on gpu(0)
time traveller with a slight accession ofcheerfulness really thi
travellerrore as so itnee any a fourthene wither other will
Summary
-------
- Gated RNNs can better capture dependencies for sequences with large
time step distances.
- Reset gates help capture short-term dependencies in sequences.
- Update gates help capture long-term dependencies in sequences.
- GRUs contain basic RNNs as their extreme case whenever the reset gate
  is switched on and the update gate is switched off. They can also
  skip subsequences by turning on the update gate.
Exercises
---------
1. Assume that we only want to use the input at time step :math:`t'` to
predict the output at time step :math:`t > t'`. What are the best
values for the reset and update gates for each time step?
2. Adjust the hyperparameters and analyze their influence on running
time, perplexity, and the output sequence.
3. Compare runtime, perplexity, and the output strings for the DJL
``RNN`` and ``GRU`` implementations with each other.
4. What happens if you implement only parts of a GRU, e.g., with only a
reset gate or only an update gate?