
10.3. Multi-Head Attention

In practice, given the same set of queries, keys, and values, we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range and longer-range) within a sequence. Thus, it may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.

To this end, instead of performing a single attention pooling, queries, keys, and values can be transformed with \(h\) independently learned linear projections. Then these \(h\) projected queries, keys, and values are fed into attention pooling in parallel. In the end, the \(h\) attention pooling outputs are concatenated and transformed with another learned linear projection to produce the final output. This design is called multi-head attention, where each of the \(h\) attention pooling outputs is a head [Vaswani et al., 2017]. The figure below shows multi-head attention, where the learnable linear transformations are implemented with fully-connected layers.

Multi-head attention, where multiple heads are concatenated and then linearly transformed.

10.3.1. Model

Before providing the implementation of multi-head attention, let us formalize this model mathematically. Given a query \(\mathbf{q} \in \mathbb{R}^{d_q}\), a key \(\mathbf{k} \in \mathbb{R}^{d_k}\), and a value \(\mathbf{v} \in \mathbb{R}^{d_v}\), each attention head \(\mathbf{h}_i\) (\(i = 1, \ldots, h\)) is computed as

(10.3.1)\[\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},\]

where the learnable parameters are \(\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}\), \(\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}\), and \(\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}\), and \(f\) is attention pooling, such as the additive attention or the scaled dot-product attention from Section 10.2. The multi-head attention output is another linear transformation via learnable parameters \(\mathbf W_o\in\mathbb R^{p_o\times h p_v}\) applied to the concatenation of the \(h\) heads:

(10.3.2)\[\begin{split}\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.\end{split}\]

Based on this design, each head may attend to different parts of the input, which allows functions more sophisticated than a simple weighted average to be expressed.
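
As a concrete dimension check (a worked example whose numbers are chosen to match the test at the end of this section, not values fixed by the model itself): suppose \(d_q = d_k = d_v = p_o = 100\) and \(h = 5\), with \(p_q = p_k = p_v = p_o / h = 20\). Then each \(\mathbf W_i^{(q)}\) is a \(20 \times 100\) matrix, each head lies in \(\mathbb R^{20}\), the concatenation of the five heads lies in \(\mathbb R^{100}\), and \(\mathbf W_o \in \mathbb R^{100 \times 100}\) maps it to the final output:

\[\mathbf h_i \in \mathbb R^{20}, \qquad \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_5\end{bmatrix} \in \mathbb R^{100}, \qquad \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_5\end{bmatrix} \in \mathbb R^{100}.\]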

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java

%load ../utils/attention/Chap10Utils.java
%load ../utils/attention/DotProductAttention.java
%load ../utils/attention/MultiHeadAttention.java
%load ../utils/attention/PositionalEncoding.java
NDManager manager = NDManager.newBaseManager();

To allow the computation of multiple heads in parallel, the MultiHeadAttention class below uses the two transposition functions defined next. Specifically, the transposeOutput function reverses the operation of the transposeQkv function.

public static NDArray transposeQkv(NDArray X, int numHeads) {
    // Shape of input `X`:
    // (`batchSize`, no. of queries or key-value pairs, `numHiddens`).
    // Shape of output `X`:
    // (`batchSize`, no. of queries or key-value pairs, `numHeads`,
    // `numHiddens` / `numHeads`)
    X = X.reshape(X.getShape().get(0), X.getShape().get(1), numHeads, -1);

    // Shape of output `X`:
    // (`batchSize`, `numHeads`, no. of queries or key-value pairs,
    // `numHiddens` / `numHeads`)
    X = X.transpose(0, 2, 1, 3);

    // Shape of `output`:
    // (`batchSize` * `numHeads`, no. of queries or key-value pairs,
    // `numHiddens` / `numHeads`)
    return X.reshape(-1, X.getShape().get(2), X.getShape().get(3));
}

public static NDArray transposeOutput(NDArray X, int numHeads) {
    // Reverse the operation of `transposeQkv`:
    // (`batchSize` * `numHeads`, no. of queries, `numHiddens` / `numHeads`)
    // -> (`batchSize`, no. of queries, `numHiddens`)
    X = X.reshape(-1, numHeads, X.getShape().get(1), X.getShape().get(2));
    X = X.transpose(0, 2, 1, 3);
    return X.reshape(X.getShape().get(0), X.getShape().get(1), -1);
}
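
As a quick sanity check (a minimal sketch, assuming the two functions above have been run in this notebook and using the manager created earlier; the sizes are purely illustrative), we can verify that transposeOutput undoes transposeQkv:

// Illustrative shapes: batchSize=2, 4 queries, numHiddens=100, 5 heads
NDArray x = manager.ones(new Shape(2, 4, 100));
NDArray xQkv = transposeQkv(x, 5);
System.out.println(xQkv.getShape());                     // (10, 4, 20)
System.out.println(transposeOutput(xQkv, 5).getShape()); // (2, 4, 100)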

10.3.2. Implementation

In our implementation, we choose the scaled dot-product attention for each head of the multi-head attention. To avoid significant growth of computational cost and parameterization, we set \(p_q = p_k = p_v = p_o / h\). Note that the \(h\) heads can be computed in parallel if we set the number of outputs of the linear transformations for the query, key, and value to \(p_q h = p_k h = p_v h = p_o\). In the following implementation, \(p_o\) is specified via the argument numHiddens.
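
As a quick parameter-count check of this choice: with \(p_q = p_o / h\), the \(h\) query projections together contain \(h \cdot (p_o / h) \cdot d_q = p_o d_q\) parameters, exactly the size of a single projection from \(d_q\) to \(p_o\) dimensions, and the same holds for the key and value projections. This is why each of the three Linear layers below can realize all \(h\) per-head projections with a single matrix multiplication whose output size is numHiddens.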

public static class MultiHeadAttention extends AbstractBlock {

    private int numHeads;
    public DotProductAttention attention;
    private Linear W_k;
    private Linear W_q;
    private Linear W_v;
    private Linear W_o;

    public MultiHeadAttention(int numHiddens, int numHeads, float dropout, boolean useBias) {
        this.numHeads = numHeads;

        attention = new DotProductAttention(dropout);

        W_q = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
        addChildBlock("W_q", W_q);

        W_k = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
        addChildBlock("W_k", W_k);

        W_v = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
        addChildBlock("W_v", W_v);

        W_o = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
        addChildBlock("W_o", W_o);

        Dropout dropout1 = Dropout.builder().optRate(dropout).build();
        addChildBlock("dropout", dropout1);
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore ps,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        // Shape of `queries`, `keys`, or `values`:
        // (`batchSize`, no. of queries or key-value pairs, `numHiddens`)
        // Shape of `validLens`:
        // (`batchSize`,) or (`batchSize`, no. of queries)
        // After transposing, shape of output `queries`, `keys`, or `values`:
        // (`batchSize` * `numHeads`, no. of queries or key-value pairs,
        // `numHiddens` / `numHeads`)
        NDArray queries = inputs.get(0);
        NDArray keys = inputs.get(1);
        NDArray values = inputs.get(2);
        NDArray validLens = inputs.get(3);
        // On axis 0, copy the first item (scalar or vector) for
        // `numHeads` times, then copy the next item, and so on
        validLens = validLens.repeat(0, numHeads);

        queries =
                transposeQkv(
                        W_q.forward(ps, new NDList(queries), training, params).get(0),
                        numHeads);
        keys =
                transposeQkv(
                        W_k.forward(ps, new NDList(keys), training, params).get(0), numHeads);
        values =
                transposeQkv(
                        W_v.forward(ps, new NDList(values), training, params).get(0), numHeads);

        // Shape of `output`: (`batchSize` * `numHeads`, no. of queries,
        // `numHiddens` / `numHeads`)
        NDArray output =
                attention
                        .forward(
                                ps,
                                new NDList(queries, keys, values, validLens),
                                training,
                                params)
                        .get(0);

        // Shape of `outputConcat`:
        // (`batchSize`, no. of queries, `numHiddens`)
        NDArray outputConcat = transposeOutput(output, numHeads);
        return new NDList(W_o.forward(ps, new NDList(outputConcat), training, params).get(0));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void initializeChildBlocks(
            NDManager manager, DataType dataType, Shape... inputShapes) {
        try (NDManager sub = manager.newSubManager()) {
            NDArray queries = sub.zeros(inputShapes[0], dataType);
            NDArray keys = sub.zeros(inputShapes[1], dataType);
            NDArray values = sub.zeros(inputShapes[2], dataType);
            NDArray validLens = sub.zeros(inputShapes[3], dataType);
            validLens = validLens.repeat(0, numHeads);

            ParameterStore ps = new ParameterStore(sub, false);

            W_q.initialize(manager, dataType, queries.getShape());
            W_k.initialize(manager, dataType, keys.getShape());
            W_v.initialize(manager, dataType, values.getShape());

            queries =
                    transposeQkv(W_q.forward(ps, new NDList(queries), false).get(0), numHeads);
            keys = transposeQkv(W_k.forward(ps, new NDList(keys), false).get(0), numHeads);
            values = transposeQkv(W_v.forward(ps, new NDList(values), false).get(0), numHeads);

            NDList list = new NDList(queries, keys, values, validLens);
            attention.initialize(sub, dataType, list.getShapes());
            NDArray output = attention.forward(ps, list, false).head();
            NDArray outputConcat = Chap10Utils.transposeOutput(output, numHeads);

            W_o.initialize(manager, dataType, outputConcat.getShape());
        }
    }
}

Let us test our implemented MultiHeadAttention class using a toy example where keys and values are the same. As a result, the shape of the multi-head attention output is (batchSize, numQueries, numHiddens).

int numHiddens = 100;
int numHeads = 5;
MultiHeadAttention attention = new MultiHeadAttention(numHiddens, numHeads, 0.5f, false);
int batchSize = 2;
int numQueries = 4;
int numKvpairs = 6;
NDArray validLens = manager.create(new float[] {3, 2});
NDArray X = manager.ones(new Shape(batchSize, numQueries, numHiddens));
NDArray Y = manager.ones(new Shape(batchSize, numKvpairs, numHiddens));

ParameterStore ps = new ParameterStore(manager, false);
NDList input = new NDList(X, Y, Y, validLens);
attention.initialize(manager, DataType.FLOAT32, input.getShapes());
NDList result = attention.forward(ps, input, false);
result.get(0).getShape();
(2, 4, 100)

10.3.3. Summary

  • Multi-head attention combines knowledge of the same attention pooling via different representation subspaces of queries, keys, and values.

  • Based on appropriate tensor manipulation, multi-head attention can be computed in parallel.

10.3.4. Exercises

  1. Visualize the attention weights of each of the multiple heads in this experiment.

  2. Suppose that we have a trained model based on multi-head attention and we want to prune the least important attention heads to increase prediction speed. How can we design experiments to measure the importance of an attention head?