
10.2. Attention Scoring Functions

In the section on Nadaraya-Watson kernel regression (sec_nadaraya-waston), we used a Gaussian kernel to model interactions between queries and keys. The exponent of the Gaussian kernel in eq_nadaraya-waston-gaussian can be treated as an attention scoring function (scoring function for short), whose output is fed into a softmax operation. As a result, we obtain a probability distribution (the attention weights) over the values that are paired with the keys. In the end, the output of attention pooling is simply a weighted sum of the values based on these attention weights.
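
For reference, the attention weights produced by the Gaussian kernel there have the form

\[\alpha(x, x_i) = \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) = \frac{\exp\left(-\frac{1}{2}(x - x_i)^2\right)}{\sum_{j=1}^n \exp\left(-\frac{1}{2}(x - x_j)^2\right)},\]

so the exponent \(-\frac{1}{2}(x - x_i)^2\) is exactly the quantity that plays the role of the scoring function before the softmax is applied.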

At a high level, we can use the above algorithm to instantiate the framework of attention mechanisms in Section 10.1.2. The figure below (fig_attention_output) illustrates how the output of attention pooling can be computed as a weighted sum of values, where \(a\) denotes the attention scoring function. Since the attention weights form a probability distribution, the weighted sum is essentially a weighted average.

Figure (fig_attention_output): Computing the output of attention pooling as a weighted sum of values.

Mathematically, suppose that we have a query \(\mathbf{q} \in \mathbb{R}^q\) and \(m\) key-value pairs \((\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)\), where any \(\mathbf{k}_i \in \mathbb{R}^k\) and any \(\mathbf{v}_i \in \mathbb{R}^v\). The attention pooling \(f\) is instantiated as a weighted sum of the values:

(10.2.2)\[f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,\]


where the attention weight (a scalar) for the query \(\mathbf{q}\) and key \(\mathbf{k}_i\) is computed by applying the softmax operation to the output of an attention scoring function \(a\), which maps the two vectors to a scalar:

(10.2.3)\[\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.\]

As we can see, different choices of the attention scoring function \(a\) lead to different attention pooling operations. In this section, we introduce two popular scoring functions that we will later use to develop more sophisticated attention mechanisms.

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java
NDManager manager = NDManager.newBaseManager();
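
To make the generic recipe in (10.2.2) and (10.2.3) concrete before turning to specific scoring functions, here is a minimal sketch that scores one toy query against three keys, normalizes the scores with softmax, and forms the weighted sum of the values. The plain dot product is used only as a stand-in scoring function, and all shapes and numbers are illustrative assumptions.

// Toy illustration of (10.2.2)-(10.2.3): scores -> softmax weights -> weighted sum.
// The dot product below merely stands in for a generic scoring function a(q, k).
NDArray toyQuery = manager.create(new float[][] {{1.0f, 0.5f}});               // 1 query, shape (1, 2)
NDArray toyKeys = manager.create(new float[][] {{1, 0}, {0, 1}, {1, 1}});      // 3 keys, shape (3, 2)
NDArray toyValues = manager.create(new float[][] {{0, 10}, {10, 0}, {5, 5}});  // 3 values, shape (3, 2)

NDArray toyScores = toyQuery.matMul(toyKeys.transpose());   // a(q, k_i) for each key, shape (1, 3)
NDArray toyWeights = toyScores.softmax(-1);                 // attention weights, rows sum to 1
toyWeights.matMul(toyValues);                               // attention pooling output, shape (1, 2)

With these numbers the third key receives the largest score, so its value contributes most to the weighted sum.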

10.2.1. Masked Softmax Operation

As mentioned above, a softmax operation is used to output a probability distribution as attention weights. In some cases, however, not all the values should be fed into attention pooling. For instance, for efficient minibatch processing in Section 9.5, some text sequences are padded with special tokens that carry no meaning. To get attention pooling over only meaningful tokens as values, we can specify a valid sequence length (in number of tokens) so that positions beyond it are filtered out when computing softmax. In this way, we can implement such a masked softmax operation in the following maskedSoftmax function, where any value beyond the valid length is masked as zero.

public static NDArray maskedSoftmax(NDArray X, NDArray validLens) {
    /* Perform softmax operation by masking elements on the last axis. */
    // `X`: 3D NDArray, `validLens`: 1D or 2D NDArray
    if (validLens == null) {
        return X.softmax(-1);
    }

    Shape shape = X.getShape();
    if (validLens.getShape().dimension() == 1) {
        validLens = validLens.repeat(shape.get(1));
    } else {
        validLens = validLens.reshape(-1);
    }
    // On the last axis, replace masked elements with a very large negative
    // value, whose exponentiation outputs 0
    X = X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))
            .sequenceMask(validLens, (float) -1E6);
    return X.softmax(-1).reshape(shape);
}

To demonstrate how this function works, consider a minibatch of two \(2 \times 4\) matrix examples, where the valid lengths for these two examples are 2 and 3, respectively. As a result of the masked softmax operation, values beyond the valid lengths are all masked as zero.

maskedSoftmax(
        manager.randomUniform(0, 1, new Shape(2, 2, 4)),
        manager.create(new float[] {2, 3}));
ND: (2, 2, 4) gpu(0) float32
[[[0.4549, 0.5451, 0.    , 0.    ],
  [0.6175, 0.3825, 0.    , 0.    ],
 ],
 [[0.294 , 0.3069, 0.3992, 0.    ],
  [0.3747, 0.2626, 0.3627, 0.    ],
 ],
]

Similarly, we can also use a two-dimensional tensor to specify a valid length for every row in each matrix example.

maskedSoftmax(
        manager.randomUniform(0, 1, new Shape(2, 2, 4)),
        manager.create(new float[][] {{1, 3}, {2, 4}}));
ND: (2, 2, 4) gpu(0) float32
[[[1.    , 0.    , 0.    , 0.    ],
  [0.2777, 0.4156, 0.3067, 0.    ],
 ],
 [[0.3441, 0.6559, 0.    , 0.    ],
  [0.2544, 0.2482, 0.2036, 0.2939],
 ],
]

10.2.2. Additive Attention

In general, when queries and keys are vectors of different lengths, we can use additive attention as the scoring function. Given a query \(\mathbf{q} \in \mathbb{R}^q\) and a key \(\mathbf{k} \in \mathbb{R}^k\), the additive attention (additive attention) scoring function is

(10.2.4)\[a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},\]

where the learnable parameters are \(\mathbf W_q\in\mathbb R^{h\times q}\), \(\mathbf W_k\in\mathbb R^{h\times k}\), and \(\mathbf w_v\in\mathbb R^{h}\). As shown in (10.2.4), the query and the key are combined and fed into an MLP with a single hidden layer, whose number of hidden units \(h\) is a hyperparameter. By using \(\tanh\) as the activation function and disabling bias terms, we implement additive attention in the following.
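
Before wrapping this computation into a full block, a minimal NDArray-level sketch of (10.2.4) for a single query and a single key may help. The weight matrices below are random placeholders standing in for learned parameters, and the dimensions are illustrative assumptions.

// Minimal sketch of the additive score (10.2.4) for one query and one key.
// W_q, W_k, w_v are random placeholders here, not learned parameters.
int h = 8;                                               // the hyperparameter h (hidden units)
NDArray qVec = manager.randomNormal(new Shape(20, 1));   // query as a column vector, q = 20
NDArray kVec = manager.randomNormal(new Shape(2, 1));    // key as a column vector, k = 2
NDArray Wq = manager.randomNormal(new Shape(h, 20));
NDArray Wk = manager.randomNormal(new Shape(h, 2));
NDArray wv = manager.randomNormal(new Shape(1, h));

// a(q, k) = w_v^T tanh(W_q q + W_k k): a single scalar score, shape (1, 1)
NDArray score = wv.matMul(Wq.matMul(qVec).add(Wk.matMul(kVec)).tanh());

The AdditiveAttention block below performs the same computation for whole minibatches of queries and keys by relying on broadcasting, and additionally applies masked softmax and dropout.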

/* Additive attention. */
public static class AdditiveAttention extends AbstractBlock {

    private Linear W_k;
    private Linear W_q;
    private Linear W_v;
    private Dropout dropout;
    public NDArray attentionWeights;

    public AdditiveAttention(int numHiddens, float dropout) {
        W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();
        addChildBlock("W_k", W_k);

        W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();
        addChildBlock("W_q", W_q);

        W_v = Linear.builder().setUnits(1).optBias(false).build();
        addChildBlock("W_v", W_v);

        this.dropout = Dropout.builder().optRate(dropout).build();
        addChildBlock("dropout", this.dropout);
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore ps,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        // Shape of `attentionWeights` computed below:
        // (`batchSize`, no. of queries, no. of key-value pairs)
        NDArray queries = inputs.get(0);
        NDArray keys = inputs.get(1);
        NDArray values = inputs.get(2);
        NDArray validLens = inputs.get(3);

        queries = W_q.forward(ps, new NDList(queries), training, params).head();
        keys = W_k.forward(ps, new NDList(keys), training, params).head();
        // After dimension expansion, shape of `queries`: (`batchSize`, no. of
        // queries, 1, `numHiddens`) and shape of `keys`: (`batchSize`, 1,
        // no. of key-value pairs, `numHiddens`). Sum them up with
        // broadcasting
        NDArray features = queries.expandDims(2).add(keys.expandDims(1));
        features = features.tanh();
        // There is only one output of `this.W_v`, so we remove the last
        // one-dimensional entry from the shape. Shape of `scores`:
        // (`batchSize`, no. of queries, no. of key-value pairs)
        NDArray result = W_v.forward(ps, new NDList(features), training, params).head();
        NDArray scores = result.squeeze(-1);
        attentionWeights = maskedSoftmax(scores, validLens);
        // Shape of `values`: (`batchSize`, no. of key-value pairs, value dimension)
        NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);
        return new NDList(list.head().batchDot(values));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void initializeChildBlocks(
            NDManager manager, DataType dataType, Shape... inputShapes) {
        W_q.initialize(manager, dataType, inputShapes[0]);
        W_k.initialize(manager, dataType, inputShapes[1]);
        long[] q = W_q.getOutputShapes(new Shape[] {inputShapes[0]})[0].getShape();
        long[] k = W_k.getOutputShapes(new Shape[] {inputShapes[1]})[0].getShape();
        long w = Math.max(q[q.length - 2], k[k.length - 2]);
        long h = Math.max(q[q.length - 1], k[k.length - 1]);
        long[] shape = new long[] {2, 1, w, h};
        W_v.initialize(manager, dataType, new Shape(shape));
        long[] dropoutShape = new long[shape.length - 1];
        System.arraycopy(shape, 0, dropoutShape, 0, dropoutShape.length);
        dropout.initialize(manager, dataType, new Shape(dropoutShape));
    }
}

Let us demonstrate the above AdditiveAttention class with a toy example, where the shapes of queries, keys, and values are (batch size, number of steps or sequence length in tokens, feature size), concretely (\(2\), \(1\), \(20\)), (\(2\), \(10\), \(2\)), and (\(2\), \(10\), \(4\)), respectively. The attention pooling output has a shape of (batch size, number of steps for queries, feature size for values).

NDArray queries = manager.randomNormal(0, 1, new Shape(2, 1, 20), DataType.FLOAT32);
NDArray keys = manager.ones(new Shape(2, 10, 2));
// The two value matrices in the `values` minibatch are identical
NDArray values = manager.arange(40f).reshape(1, 10, 4).repeat(0, 2);
NDArray validLens = manager.create(new float[] {2, 6});

AdditiveAttention attention = new AdditiveAttention(8, 0.1f);
NDList input = new NDList(queries, keys, values, validLens);
ParameterStore ps = new ParameterStore(manager, false);
attention.initialize(manager, DataType.FLOAT32, input.getShapes());
attention.forward(ps, input, false).head();
ND: (2, 1, 4) gpu(0) float32
[[[ 2.,  3.,  4.,  5.],
 ],
 [[10., 11., 12., 13.],
 ],
]

Although additive attention contains learnable parameters, since every key is the same in this example, the attention weights are uniform, determined by the specified valid lengths.

PlotUtils.showHeatmaps(
            attention.attentionWeights.reshape(1, 1, 2, 10),
            "Keys",
            "Queries",
            new String[] {""},
            500,
            700);

10.2.3. Scaled Dot-Product Attention

A more computationally efficient scoring function can be obtained with the dot product. However, the dot product operation requires that both the query and the key have the same vector length \(d\). Assume that all the elements of the query and the key are independent random variables with zero mean and unit variance. Then the dot product of the two vectors has zero mean and a variance of \(d\). To ensure that the variance of the dot product remains 1 regardless of the vector length, the scaled dot-product attention (scaled dot-product attention) scoring function divides the dot product by \(\sqrt{d}\):

(10.2.5)\[a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} /\sqrt{d}\]
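
To see where the factor \(\sqrt{d}\) comes from, note that under the independence and unit-variance assumptions above,

\[\mathrm{E}[\mathbf{q}^\top \mathbf{k}] = \sum_{i=1}^d \mathrm{E}[q_i]\,\mathrm{E}[k_i] = 0, \qquad \mathrm{Var}[\mathbf{q}^\top \mathbf{k}] = \sum_{i=1}^d \mathrm{Var}[q_i k_i] = \sum_{i=1}^d \mathrm{E}[q_i^2]\,\mathrm{E}[k_i^2] = d,\]

so dividing the score by \(\sqrt{d}\) rescales its variance back to 1 for any vector length \(d\).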

In practice, we often think in minibatches for efficiency, such as computing attention for \(n\) queries and \(m\) key-value pairs, where queries and keys are of length \(d\) and values are of length \(v\). The scaled dot-product attention of queries \(\mathbf Q\in\mathbb R^{n\times d}\), keys \(\mathbf K\in\mathbb R^{m\times d}\), and values \(\mathbf V\in\mathbb R^{m\times v}\) is

(10.2.6)\[\mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.\]

In the following implementation of scaled dot-product attention, we use dropout for model regularization.

/* Scaled dot product attention. */
public static class DotProductAttention extends AbstractBlock {

    private Dropout dropout;
    public NDArray attentionWeights;

    public DotProductAttention(float dropout) {
        this.dropout = Dropout.builder().optRate(dropout).build();
        this.addChildBlock("dropout", this.dropout);
        this.dropout.setInitializer(new UniformInitializer(0.07f), Parameter.Type.WEIGHT);
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore ps,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        // Shape of `queries`: (`batchSize`, no. of queries, `d`)
        // Shape of `keys`: (`batchSize`, no. of key-value pairs, `d`)
        // Shape of `values`: (`batchSize`, no. of key-value pairs, value
        // dimension)
        // Shape of `valid_lens`: (`batchSize`,) or (`batchSize`, no. of queries)
        NDArray queries = inputs.get(0);
        NDArray keys = inputs.get(1);
        NDArray values = inputs.get(2);
        NDArray validLens = inputs.get(3);

        long d = queries.getShape().get(queries.getShape().dimension() - 1);
        // Swap the last two dimensions of `keys`, perform batchDot,
        // and scale the scores by 1 / sqrt(d)
        NDArray scores = queries.batchDot(keys.swapAxes(1, 2)).div(Math.sqrt(d));
        attentionWeights = maskedSoftmax(scores, validLens);
        NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);
        return new NDList(list.head().batchDot(values));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void initializeChildBlocks(
            NDManager manager, DataType dataType, Shape... inputShapes) {
        try (NDManager sub = manager.newSubManager()) {
            NDArray queries = sub.zeros(inputShapes[0], dataType);
            NDArray keys = sub.zeros(inputShapes[1], dataType);
            NDArray scores = queries.batchDot(keys.swapAxes(1, 2));
            dropout.initialize(manager, dataType, scores.getShape());
        }
    }
}

To demonstrate the above DotProductAttention class, we use the same keys, values, and valid lengths as in the earlier toy example for additive attention. For the dot product operation, we make the feature size of the queries the same as that of the keys.

queries = manager.randomNormal(0, 1, new Shape(2, 1, 2), DataType.FLOAT32);
DotProductAttention productAttention = new DotProductAttention(0.5f);
input = new NDList(queries, keys, values, validLens);
productAttention.initialize(manager, DataType.FLOAT32, input.getShapes());
productAttention.forward(ps, input, false).head();
ND: (2, 1, 4) gpu(0) float32
[[[ 2.,  3.,  4.,  5.],
 ],
 [[10., 11., 12., 13.],
 ],
]

Same as in the additive attention demonstration, since the keys contain identical elements that cannot be differentiated by any query, uniform attention weights are obtained.

PlotUtils.showHeatmaps(
        productAttention.attentionWeights.reshape(1, 1, 2, 10),
        "Keys",
        "Queries",
        new String[] {""},
        500,
        700);

10.2.4. Summary

  • The output of attention pooling can be computed as a weighted average of the values, and different choices of the attention scoring function lead to different attention pooling operations.

  • When queries and keys are vectors of different lengths, we can use the additive attention scoring function. When they have the same length, the scaled dot-product attention scoring function is more computationally efficient.

10.2.5. Exercises

  1. Modify the keys in the toy example and visualize the attention weights. Do additive attention and scaled dot-product attention still output the same attention weights? Why or why not?

  2. Using only matrix multiplications, can you design a new scoring function for queries and keys with different vector lengths?

  3. When queries and keys have the same vector length, is vector summation a better design than the dot product for the scoring function? Why or why not?