
10.4. Self-Attention and Positional Encoding

In deep learning, we often use CNNs or RNNs to encode a sequence. Now, with attention mechanisms in mind, imagine feeding a sequence of tokens into attention pooling so that the same set of tokens acts as queries, keys, and values. Specifically, each query attends to all the key-value pairs and generates one attention output. Since the queries, keys, and values come from the same place, this performs self-attention [Lin et al., 2017][Vaswani et al., 2017], which is also called intra-attention [Cheng et al., 2016][Parikh et al., 2016][Paulus et al., 2017]. In this section, we will discuss sequence encoding using self-attention, including using additional information for the sequence order.

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java

%load ../utils/attention/Chap10Utils.java
%load ../utils/attention/DotProductAttention.java
%load ../utils/attention/MultiHeadAttention.java
%load ../utils/attention/PositionalEncoding.java
NDManager manager = NDManager.newBaseManager();

10.4.1. Self-Attention

Given a sequence of input tokens \(\mathbf{x}_1, \ldots, \mathbf{x}_n\) where any \(\mathbf{x}_i \in \mathbb{R}^d\) (\(1 \leq i \leq n\)), its self-attention outputs a sequence of the same length \(\mathbf{y}_1, \ldots, \mathbf{y}_n\), where

(10.4.1)\[\mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d\]

according to the definition of attention pooling \(f\) introduced earlier in this chapter. Using multi-head attention, the following code snippet computes the self-attention of a tensor with shape (batch size, number of time steps or sequence length in tokens, \(d\)). The output tensor has the same shape.

int numHiddens = 100;
int numHeads = 5;
MultiHeadAttention attention = new MultiHeadAttention(numHiddens, numHeads, 0.5f, false);
int batchSize = 2;
int numQueries = 4;
NDArray validLens = manager.create(new float[] {3, 2});
NDArray X = manager.ones(new Shape(batchSize, numQueries, numHiddens));
ParameterStore ps = new ParameterStore(manager, false);
NDList input = new NDList(X, X, X, validLens);
attention.initialize(manager, DataType.FLOAT32, input.getShapes());
NDList result = attention.forward(ps, input, false);
result.get(0).getShape()
(2, 4, 100)

10.4.2. Comparing CNNs, RNNs, and Self-Attention

Let us compare architectures for mapping a sequence of \(n\) tokens to another sequence of equal length, where each input or output token is represented by a \(d\)-dimensional vector. Specifically, we will consider CNNs, RNNs, and self-attention. We will compare their computational complexity, sequential operations, and maximum path lengths. Note that sequential operations prevent parallel computation, while a shorter path between any combination of sequence positions makes it easier to learn long-range dependencies within the sequence [Hochreiter.Bengio.Frasconi.ea.2001].

Fig. fig_cnn-rnn-self-attention: Comparing CNN (padding tokens are omitted), RNN, and self-attention architectures.

Consider a convolutional layer whose kernel size is \(k\). We will provide more details about sequence processing using CNNs in later chapters. For now, we only need to know that since the sequence length is \(n\) and the numbers of input and output channels are both \(d\), the computational complexity of the convolutional layer is \(\mathcal{O}(knd^2)\). As fig_cnn-rnn-self-attention shows, CNNs are hierarchical, so there are \(\mathcal{O}(1)\) sequential operations and the maximum path length is \(\mathcal{O}(n/k)\). For example, \(\mathbf{x}_1\) and \(\mathbf{x}_5\) are within the receptive field of a two-layer CNN with kernel size 3 in fig_cnn-rnn-self-attention.

When updating the hidden state of RNNs, multiplication of the \(d \times d\) weight matrix and the \(d\)-dimensional hidden state has a computational complexity of \(\mathcal{O}(d^2)\). Since the sequence length is \(n\), the computational complexity of the recurrent layer is \(\mathcal{O}(nd^2)\). According to fig_cnn-rnn-self-attention, there are \(\mathcal{O}(n)\) sequential operations that cannot be parallelized and the maximum path length is also \(\mathcal{O}(n)\).

In self-attention, the queries, keys, and values are all \(n \times d\) matrices. Consider the scaled dot-product attention in (10.2.5), where an \(n \times d\) matrix is multiplied by a \(d \times n\) matrix, and then the resulting \(n \times n\) matrix is multiplied by an \(n \times d\) matrix. As a result, self-attention has \(\mathcal{O}(n^2d)\) computational complexity. As we can see in fig_cnn-rnn-self-attention, each token is directly connected to every other token via self-attention. Therefore, computation can be performed in parallel with \(\mathcal{O}(1)\) sequential operations, and the maximum path length is also \(\mathcal{O}(1)\).

All in all, both CNNs and self-attention enjoy parallel computation and self-attention has the shortest maximum path length. However, the quadratic computational complexity with respect to the sequence length makes self-attention prohibitively slow for very long sequences.
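
To make these asymptotic counts more tangible, the short sketch below plugs illustrative (hypothetical) values of \(n\), \(d\), and \(k\) into the dominant terms \(knd^2\), \(nd^2\), and \(n^2d\) and prints them side by side. Constant factors are ignored, so the numbers only indicate how the three architectures scale, not their actual runtimes.

int seqLen = 1024;  // n: sequence length (illustrative value)
int dim = 512;      // d: token representation dimension (illustrative value)
int kernel = 3;     // k: convolution kernel size (illustrative value)

// Dominant terms of the complexity estimates discussed above
long cnnOps = (long) kernel * seqLen * dim * dim;  // O(k n d^2)
long rnnOps = (long) seqLen * dim * dim;           // O(n d^2)
long attnOps = (long) seqLen * seqLen * dim;       // O(n^2 d)

System.out.println("CNN            ~ " + cnnOps + " ops, O(1) sequential ops, O(n/k) max path");
System.out.println("RNN            ~ " + rnnOps + " ops, O(n) sequential ops, O(n) max path");
System.out.println("Self-attention ~ " + attnOps + " ops, O(1) sequential ops, O(1) max path");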

10.4.3. Positional Encoding

Unlike RNNs that recurrently process tokens of a sequence one by one, self-attention ditches sequential operations in favor of parallel computation. To use the sequence order information, we can inject absolute or relative positional information by adding positional encoding to the input representations. Positional encodings can be either learned or fixed. In the following, we describe a fixed positional encoding based on sine and cosine functions [Vaswani et al., 2017].

Suppose that the input representation \(\mathbf{X} \in \mathbb{R}^{n \times d}\) contains the \(d\)-dimensional embeddings for \(n\) tokens of a sequence. The positional encoding outputs \(\mathbf{X} + \mathbf{P}\) using a positional embedding matrix \(\mathbf{P} \in \mathbb{R}^{n \times d}\) of the same shape, whose element on the \(i^\mathrm{th}\) row and the \((2j)^\mathrm{th}\) or the \((2j + 1)^\mathrm{th}\) column is

(10.4.2)\[\begin{split}\begin{aligned} p_{i, 2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right),\\p_{i, 2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right).\end{aligned}\end{split}\]
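
To see the frequency pattern concretely, consider a small encoding dimension \(d = 4\). Row \(i\) of \(\mathbf{P}\) is then \(\left(\sin(i), \cos(i), \sin\left(\frac{i}{100}\right), \cos\left(\frac{i}{100}\right)\right)\): the first pair of columns (\(j = 0\)) oscillates with every position, while the second pair (\(j = 1\), so \(10000^{2j/d} = 100\)) varies far more slowly. In general, the frequency decreases as the column index grows.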

At first glance, this trigonometric-function design may look odd. Before explaining it, let us first implement it in the following PositionalEncoding class.

public class PositionalEncoding extends AbstractBlock {

    private Dropout dropout;
    public NDArray P;

    public PositionalEncoding(int numHiddens, float dropout, int maxLen, NDManager manager) {
        this.dropout = Dropout.builder().optRate(dropout).build();
        addChildBlock("dropout", this.dropout);

        // Create a long enough `P`
        P = manager.zeros(new Shape(1, maxLen, numHiddens));
        NDArray X =
                manager.arange(maxLen)
                        .reshape(-1, 1)
                        .div(
                                manager.create(10000)
                                        .pow(manager.arange(0, numHiddens, 2).div(numHiddens)));
        // Even-indexed columns get the sine terms; odd-indexed columns get the cosine terms
        P.set(new NDIndex(":, :, {}::{}", 0, 2), X.sin());
        P.set(new NDIndex(":, :, {}::{}", 1, 2), X.cos());
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray X = inputs.get(0);
        // Add the positional embeddings (truncated to the input length) before applying dropout
        X = X.add(P.get(":, :{}, :", X.getShape().get(1)));
        return new NDList(
                dropout.forward(parameterStore, new NDList(X), training, params).get(0));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        try (NDManager sub = manager.newSubManager()) {
            NDArray X = sub.zeros(inputShapes[0], dataType);
            X = X.add(P.get(":, :{}, :", X.getShape().get(1)));
            dropout.initialize(manager, dataType, X.getShape());
        }
    }
}

In the positional embedding matrix \(\mathbf{P}\), rows correspond to positions within a sequence and columns represent different positional encoding dimensions. In the example below, we can see that the \(6^{\mathrm{th}}\) and the \(7^{\mathrm{th}}\) columns of the positional embedding matrix have a higher frequency than the \(8^{\mathrm{th}}\) and the \(9^{\mathrm{th}}\) columns. The offset between the \(6^{\mathrm{th}}\) and the \(7^{\mathrm{th}}\) (same for the \(8^{\mathrm{th}}\) and the \(9^{\mathrm{th}}\)) columns is due to the alternation of sine and cosine functions.

int encodingDim = 32;
int numSteps = 60;
PositionalEncoding posEncoding = new PositionalEncoding(encodingDim, 0, 1000, manager);
input = new NDList(manager.zeros(new Shape(1, numSteps, encodingDim)));
X = posEncoding.forward(ps, input, false).get(0);
NDArray P = posEncoding.P.get(new NDIndex(":, :{}, :", X.getShape().get(1)));

double[][] plotX = new double[4][];
double[][] plotY = new double[4][];
for (int i = 0; i < 4; i++) {
    if (i == 0) {
        plotX[i] = manager.arange(numSteps).toType(DataType.FLOAT64, false).toDoubleArray();
    } else {
        plotX[i] = plotX[i - 1];
    }
    plotY[i] =
            Functions.floatToDoubleArray(
                    P.get(new NDIndex("0, :, {},", i + 6)).toFloatArray());
}


PlotUtils.plot(
        plotX,
        plotY,
        new String[] {"Col6", "Col7", "Col8", "Col9"},
        "Row (position)",
        "");

10.4.3.1. Absolute Positional Information

To see how the monotonically decreasing frequency along the encoding dimension relates to absolute positional information, let us print out the binary representations of \(0, 1, \ldots, 7\). As we can see, the lowest bit, the second-lowest bit, and the third-lowest bit alternate on every number, every two numbers, and every four numbers, respectively.

for (int i = 0; i < 8; i++) {
    System.out.println(i + " in binary is " + Integer.toBinaryString(i));
}
0 in binary is 0
1 in binary is 1
2 in binary is 10
3 in binary is 11
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111

In binary representations, a higher bit has a lower frequency than a lower bit. Similarly, as demonstrated in the heat map below, the positional encoding decreases frequencies along the encoding dimension by using trigonometric functions. Since the outputs are floating-point numbers, such continuous representations are more space-efficient than binary representations.

P = P.get(new NDIndex("0, :, :")).expandDims(0).expandDims(0);
PlotUtils.showHeatmaps(
        P, "Column (encoding dimension)", "Row (position)", new String[] {""}, 500, 700);

10.4.3.2. Relative Positional Information

Besides capturing absolute positional information, the above positional encoding also allows a model to easily learn to attend by relative positions. This is because for any fixed position offset \(\delta\), the positional encoding at position \(i + \delta\) can be represented by a linear projection of that at position \(i\).

This projection can be explained mathematically. Denoting \(\omega_j = 1/10000^{2j/d}\), any pair of \((p_{i, 2j}, p_{i, 2j+1})\) in (10.4.2) can be linearly projected to \((p_{i+\delta, 2j}, p_{i+\delta, 2j+1})\) for any fixed offset \(\delta\):

(10.4.3)\[\begin{split}\begin{aligned} &\begin{bmatrix} \cos(\delta \omega_j) & \sin(\delta \omega_j) \\ -\sin(\delta \omega_j) & \cos(\delta \omega_j) \\ \end{bmatrix} \begin{bmatrix} p_{i, 2j} \\ p_{i, 2j+1} \\ \end{bmatrix}\\ =&\begin{bmatrix} \cos(\delta \omega_j) \sin(i \omega_j) + \sin(\delta \omega_j) \cos(i \omega_j) \\ -\sin(\delta \omega_j) \sin(i \omega_j) + \cos(\delta \omega_j) \cos(i \omega_j) \\ \end{bmatrix}\\ =&\begin{bmatrix} \sin\left((i+\delta) \omega_j\right) \\ \cos\left((i+\delta) \omega_j\right) \\ \end{bmatrix}\\ =& \begin{bmatrix} p_{i+\delta, 2j} \\ p_{i+\delta, 2j+1} \\ \end{bmatrix}, \end{aligned}\end{split}\]

where the \(2\times 2\) projection matrix does not depend on any position index \(i\).
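
As a quick sanity check, the sketch below evaluates both sides of this identity numerically for one arbitrarily chosen combination of \(i\), \(j\), \(d\), and \(\delta\) (the values are illustrative, not taken from the text); the two printed pairs should agree up to floating-point error.

int encDim = 32;   // d: encoding dimension (illustrative)
int col = 4;       // j: index of the (sin, cos) column pair (illustrative)
int pos = 10;      // i: position (illustrative)
int offset = 7;    // delta: fixed position offset (illustrative)

double omega = 1.0 / Math.pow(10000, 2.0 * col / encDim);
double sinI = Math.sin(pos * omega);  // p_{i, 2j}
double cosI = Math.cos(pos * omega);  // p_{i, 2j+1}

// Apply the 2x2 rotation, which depends only on offset and omega, not on pos
double proj0 = Math.cos(offset * omega) * sinI + Math.sin(offset * omega) * cosI;
double proj1 = -Math.sin(offset * omega) * sinI + Math.cos(offset * omega) * cosI;

System.out.println("projected from position i:    " + proj0 + ", " + proj1);
System.out.println("encoding at position i+delta: "
        + Math.sin((pos + offset) * omega) + ", " + Math.cos((pos + offset) * omega));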

10.4.4. Summary

  • In self-attention, the queries, keys, and values all come from the same place.

  • Both CNNs and self-attention enjoy parallel computation and self-attention has the shortest maximum path length. However, the quadratic computational complexity with respect to the sequence length makes self-attention prohibitively slow for very long sequences.

  • To use the sequence order information, we can inject absolute or relative positional information by adding positional encoding to the input representations.

10.4.5. Exercises

  1. Suppose that we design a deep architecture to represent a sequence by stacking self-attention layers with positional encoding. What could be issues?

  2. Can you design a learnable positional encoding method?