
10.3. Multi-Head Attention

In practice, given the same set of queries, keys, and values we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence. Thus, it may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.

To this end, instead of performing a single attention pooling, queries, keys, and values can be transformed with $h$ independently learned linear projections. Then these $h$ projected queries, keys, and values are fed into attention pooling in parallel. In the end, $h$ attention pooling outputs are concatenated and transformed with another learned linear projection to produce the final output. This design is called multi-head attention, where each of the $h$ attention pooling outputs is a head [Vaswani et al., 2017]. Using fully-connected layers to perform learnable linear transformations, Fig. 10.3.1 describes multi-head attention.

Fig. 10.3.1: Multi-head attention, where multiple heads are concatenated and then linearly transformed.

10.3.1. Model

Before providing the implementation of multi-head attention, let us formalize this model mathematically. Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$, and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as

$$\mathbf{h}_i = f(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}, \mathbf{W}_i^{(v)}\mathbf{v}) \in \mathbb{R}^{p_v}, \tag{10.3.1}$$

where $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$ are learnable parameters and $f$ is attention pooling, such as additive attention or scaled dot-product attention in Section 10.2. The multi-head attention output is another linear transformation via learnable parameters $\mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}$ of the concatenation of $h$ heads:

$$\mathbf{W}_o \begin{bmatrix}\mathbf{h}_1\\ \vdots\\ \mathbf{h}_h\end{bmatrix} \in \mathbb{R}^{p_o}. \tag{10.3.2}$$

Based on this design, each head may attend to different parts of the input. More sophisticated functions than the simple weighted average can be expressed.
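For concreteness, take $h = 5$ heads with $p_q = p_k = p_v = 20$ and $p_o = 100$, which matches the numHeads and numHiddens used in the experiment later in this section. The dimension bookkeeping of (10.3.2) then reads

$$\mathbf{W}_o \underbrace{\begin{bmatrix}\mathbf{h}_1\\ \vdots\\ \mathbf{h}_5\end{bmatrix}}_{\in\, \mathbb{R}^{5 \cdot 20} \,=\, \mathbb{R}^{100}} \in \mathbb{R}^{100}, \qquad \mathbf{W}_o \in \mathbb{R}^{100 \times 100},$$

so concatenating the heads multiplies the feature dimension by $h$, and $\mathbf{W}_o$ maps the result back to $p_o$ dimensions.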

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java

%load ../utils/attention/Chap10Utils.java
%load ../utils/attention/DotProductAttention.java
%load ../utils/attention/MultiHeadAttention.java
%load ../utils/attention/PositionalEncoding.java
NDManager manager = NDManager.newBaseManager();

To allow for parallel computation of multiple heads, the MultiHeadAttention class below uses the two transposition functions defined next. Specifically, the transposeOutput function reverses the operation of the transposeQkv function.

public static NDArray transposeQkv(NDArray X, int numHeads) {
    // Shape of input `X`:
    // (`batchSize`, no. of queries or key-value pairs, `numHiddens`).
    // Shape of output `X`:
    // (`batchSize`, no. of queries or key-value pairs, `numHeads`,
    // `numHiddens` / `numHeads`)
    X = X.reshape(X.getShape().get(0), X.getShape().get(1), numHeads, -1);

    // Shape of output `X`:
    // (`batchSize`, `numHeads`, no. of queries or key-value pairs,
    // `numHiddens` / `numHeads`)
    X = X.transpose(0, 2, 1, 3);

    // Shape of `output`:
    // (`batchSize` * `numHeads`, no. of queries or key-value pairs,
    // `numHiddens` / `numHeads`)
    return X.reshape(-1, X.getShape().get(2), X.getShape().get(3));
}

public static NDArray transposeOutput(NDArray X, int numHeads) {
    // Reverse the operation of `transposeQkv`:
    // (`batchSize` * `numHeads`, no. of queries, `numHiddens` / `numHeads`)
    // -> (`batchSize`, no. of queries, `numHiddens`)
    X = X.reshape(-1, numHeads, X.getShape().get(1), X.getShape().get(2));
    X = X.transpose(0, 2, 1, 3);
    return X.reshape(X.getShape().get(0), X.getShape().get(1), -1);
}
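As a quick sanity check (not part of the original implementation), the following sketch runs a dummy tensor through both functions. It reuses the manager created above, and the shape values are chosen to match the experiment at the end of this section.

// Illustrative shape check: `transposeOutput` undoes `transposeQkv`.
// Assumes `manager`, `transposeQkv`, and `transposeOutput` defined in the cells above.
NDArray testX = manager.ones(new Shape(2, 4, 100)); // (`batchSize`, no. of queries, `numHiddens`)
NDArray testXT = transposeQkv(testX, 5);
System.out.println(testXT.getShape());                     // (10, 4, 20)
System.out.println(transposeOutput(testXT, 5).getShape()); // (2, 4, 100)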

10.3.2. Implementation

In our implementation, we choose the scaled dot-product attention for each head of the multi-head attention. To avoid significant growth of computational cost and parameterization cost, we set $p_q = p_k = p_v = p_o / h$. Note that $h$ heads can be computed in parallel if we set the number of outputs of the linear transformations for the query, key, and value to $p_q h = p_k h = p_v h = p_o$. In the following implementation, $p_o$ is specified via the argument numHiddens.

public static class MultiHeadAttention extends AbstractBlock {

    private int numHeads;
    public DotProductAttention attention;
    private Linear W_k;
    private Linear W_q;
    private Linear W_v;
    private Linear W_o;

    public MultiHeadAttention(int numHiddens, int numHeads, float dropout, boolean useBias) {
        this.numHeads = numHeads;

        attention = new DotProductAttention(dropout);

        W_q = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
        addChildBlock("W_q", W_q);

        W_k = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
        addChildBlock("W_k", W_k);

        W_v = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
        addChildBlock("W_v", W_v);

        W_o = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
        addChildBlock("W_o", W_o);

        Dropout dropout1 = Dropout.builder().optRate(dropout).build();
        addChildBlock("dropout", dropout1);
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore ps,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        // Shape of `queries`, `keys`, or `values`:
        // (`batchSize`, no. of queries or key-value pairs, `numHiddens`)
        // Shape of `validLens`:
        // (`batchSize`,) or (`batchSize`, no. of queries)
        // After transposing, shape of output `queries`, `keys`, or `values`:
        // (`batchSize` * `numHeads`, no. of queries or key-value pairs,
        // `numHiddens` / `numHeads`)
        NDArray queries = inputs.get(0);
        NDArray keys = inputs.get(1);
        NDArray values = inputs.get(2);
        NDArray validLens = inputs.get(3);
        // On axis 0, copy the first item (scalar or vector) for
        // `numHeads` times, then copy the next item, and so on
        validLens = validLens.repeat(0, numHeads);

        queries =
                transposeQkv(
                        W_q.forward(ps, new NDList(queries), training, params).get(0),
                        numHeads);
        keys =
                transposeQkv(
                        W_k.forward(ps, new NDList(keys), training, params).get(0), numHeads);
        values =
                transposeQkv(
                        W_v.forward(ps, new NDList(values), training, params).get(0), numHeads);

        // Shape of `output`: (`batchSize` * `numHeads`, no. of queries,
        // `numHiddens` / `numHeads`)
        NDArray output =
                attention
                        .forward(
                                ps,
                                new NDList(queries, keys, values, validLens),
                                training,
                                params)
                        .get(0);

        // Shape of `outputConcat`:
        // (`batchSize`, no. of queries, `numHiddens`)
        NDArray outputConcat = transposeOutput(output, numHeads);
        return new NDList(W_o.forward(ps, new NDList(outputConcat), training, params).get(0));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void initializeChildBlocks(
            NDManager manager, DataType dataType, Shape... inputShapes) {
        try (NDManager sub = manager.newSubManager()) {
            NDArray queries = sub.zeros(inputShapes[0], dataType);
            NDArray keys = sub.zeros(inputShapes[1], dataType);
            NDArray values = sub.zeros(inputShapes[2], dataType);
            NDArray validLens = sub.zeros(inputShapes[3], dataType);
            validLens = validLens.repeat(0, numHeads);

            ParameterStore ps = new ParameterStore(sub, false);

            W_q.initialize(manager, dataType, queries.getShape());
            W_k.initialize(manager, dataType, keys.getShape());
            W_v.initialize(manager, dataType, values.getShape());

            queries =
                    transposeQkv(W_q.forward(ps, new NDList(queries), false).get(0), numHeads);
            keys = transposeQkv(W_k.forward(ps, new NDList(keys), false).get(0), numHeads);
            values = transposeQkv(W_v.forward(ps, new NDList(values), false).get(0), numHeads);

            NDList list = new NDList(queries, keys, values, validLens);
            attention.initialize(sub, dataType, list.getShapes());
            NDArray output = attention.forward(ps, list, false).head();
            NDArray outputConcat = transposeOutput(output, numHeads);

            W_o.initialize(manager, dataType, outputConcat.getShape());
        }
    }
}

Let us test our implemented MultiHeadAttention class using a toy example where keys and values are the same. The shape of the multi-head attention output is then (batchSize, numQueries, numHiddens).

int numHiddens = 100;
int numHeads = 5;
MultiHeadAttention attention = new MultiHeadAttention(numHiddens, numHeads, 0.5f, false);
int batchSize = 2;
int numQueries = 4;
int numKvpairs = 6;
NDArray validLens = manager.create(new float[] {3, 2});
NDArray X = manager.ones(new Shape(batchSize, numQueries, numHiddens));
NDArray Y = manager.ones(new Shape(batchSize, numKvpairs, numHiddens));

ParameterStore ps = new ParameterStore(manager, false);
NDList input = new NDList(X, Y, Y, validLens);
attention.initialize(manager, DataType.FLOAT32, input.getShapes());
NDList result = attention.forward(ps, input, false);
result.get(0).getShape();
(2, 4, 100)

10.3.3. Summary

  • Multi-head attention combines knowledge of the same attention pooling via different representation subspaces of queries, keys, and values.

  • To compute multiple heads of multi-head attention in parallel, proper tensor manipulation is needed.

10.3.4. Exercises

  1. Visualize attention weights of multiple heads in this experiment.

  2. Suppose that we have a trained model based on multi-head attention and we want to prune the least important attention heads to increase prediction speed. How can we design experiments to measure the importance of an attention head?