Run this notebook online:\ |Binder| or Colab: |Colab|

.. |Binder| image:: https://mybinder.org/badge_logo.svg
   :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_attention-mechanisms/attention-scoring-functions.ipynb

.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_attention-mechanisms/attention-scoring-functions.ipynb

.. _sec_attention-scoring-functions:

注意力评分函数
==============

在 :numref:`sec_nadaraya-waston` 中,我们使用高斯核来对查询和键之间的关系建模。可以将 :eq:`eq_nadaraya-waston-gaussian` 中的高斯核的指数部分视为 *注意力评分函数(attention scoring function)*\ ,简称 *评分函数(scoring function)*\ ,然后把这个函数的输出结果输入到 softmax 函数中进行运算。通过上述步骤,我们将得到与键对应的值的概率分布(即注意力权重)。最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。

从宏观来看,可以使用上述算法来实现 :numref:`fig_qkv` 中的注意力机制框架。:numref:`fig_attention_output` 说明了如何将注意力汇聚的输出计算成为值的加权和,其中 :math:`a` 表示注意力评分函数。由于注意力权重是概率分布,因此加权和本质上是加权平均值。

|计算注意力汇聚的输出为值的加权和。|

.. _fig_attention_output:

用数学语言描述,假设有一个查询 :math:`\mathbf{q} \in \mathbb{R}^q` 和 :math:`m` 个键值对 :math:`(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)`\ ,其中 :math:`\mathbf{k}_i \in \mathbb{R}^k`\ ,\ :math:`\mathbf{v}_i \in \mathbb{R}^v`\ 。注意力汇聚函数 :math:`f` 就被表示成值的加权和:

.. math::
   :label: eq_attn-pooling

   f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,

其中查询 :math:`\mathbf{q}` 和键 :math:`\mathbf{k}_i` 的注意力权重(标量)是通过注意力评分函数 :math:`a` 将两个向量映射成标量,再经过 softmax 运算得到的:

.. math::
   :label: eq_attn-scoring-alpha

   \alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.

正如我们所看到的,选择不同的注意力评分函数 :math:`a` 会导致不同的注意力汇聚操作。在本节中,我们将介绍两个流行的评分函数,稍后将用它们来实现更复杂的注意力机制。

.. |计算注意力汇聚的输出为值的加权和。| image:: https://zh-v2.d2l.ai/_images/attention-output.svg

.. code:: java

    %load ../utils/djl-imports
    %load ../utils/plot-utils
    %load ../utils/Functions.java
    %load ../utils/PlotUtils.java

.. code:: java

    NDManager manager = NDManager.newBaseManager();

遮蔽softmax操作
---------------

正如上面提到的,softmax 运算用于输出一个概率分布作为注意力权重。在某些情况下,并非所有的值都应该被纳入到注意力汇聚中。例如,为了在 :numref:`sec_machine_translation` 中高效处理小批量数据集,某些文本序列被填充了没有意义的特殊词元。为了仅将有意义的词元作为值去获取注意力汇聚,可以指定一个有效序列长度(即词元的个数),以便在计算 softmax 时过滤掉超出指定范围的位置。通过这种方式,我们可以在下面的 ``maskedSoftmax`` 函数中实现这样的 *遮蔽 softmax 操作(masked softmax operation)*\ ,其中任何超出有效长度的位置都被遮蔽并置为0。


.. code:: java

    public static NDArray maskedSoftmax(NDArray X, NDArray validLens) {
        /* Perform softmax operation by masking elements on the last axis. */
        // `X`: 3D NDArray, `validLens`: 1D or 2D NDArray
        if (validLens == null) {
            return X.softmax(-1);
        }
        Shape shape = X.getShape();
        if (validLens.getShape().dimension() == 1) {
            validLens = validLens.repeat(shape.get(1));
        } else {
            validLens = validLens.reshape(-1);
        }
        // On the last axis, replace masked elements with a very large negative
        // value, whose exponentiation outputs 0
        X = X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))
                .sequenceMask(validLens, (float) -1E6);
        return X.softmax(-1).reshape(shape);
    }

为了演示此函数是如何工作的,考虑由两个 :math:`2 \times 4` 矩阵表示的样本,这两个样本的有效长度分别为2和3。经过遮蔽 softmax 操作,超出有效长度的值都被遮蔽为0。

.. code:: java

    maskedSoftmax(
            manager.randomUniform(0, 1, new Shape(2, 2, 4)),
            manager.create(new float[] {2, 3}));

.. parsed-literal::
    :class: output

    ND: (2, 2, 4) gpu(0) float32
    [[[0.4549, 0.5451, 0.    , 0.    ],
      [0.6175, 0.3825, 0.    , 0.    ],
     ],
     [[0.294 , 0.3069, 0.3992, 0.    ],
      [0.3747, 0.2626, 0.3627, 0.    ],
     ],
    ]

同样,我们也可以使用二维张量为矩阵样本中的每一行指定有效长度。

.. code:: java

    maskedSoftmax(
            manager.randomUniform(0, 1, new Shape(2, 2, 4)),
            manager.create(new float[][] {{1, 3}, {2, 4}}));

.. parsed-literal::
    :class: output

    ND: (2, 2, 4) gpu(0) float32
    [[[1.    , 0.    , 0.    , 0.    ],
      [0.2777, 0.4156, 0.3067, 0.    ],
     ],
     [[0.3441, 0.6559, 0.    , 0.    ],
      [0.2544, 0.2482, 0.2036, 0.2939],
     ],
    ]

.. _subsec_additive-attention:

加性注意力
----------

一般来说,当查询和键是不同长度的矢量时,可以使用加性注意力作为评分函数。给定查询 :math:`\mathbf{q} \in \mathbb{R}^q` 和键 :math:`\mathbf{k} \in \mathbb{R}^k`\ ,\ *加性注意力(additive attention)* 的评分函数为

.. math::
   :label: eq_additive-attn

   a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},

其中可学习的参数是 :math:`\mathbf W_q\in\mathbb R^{h\times q}`\ 、\ :math:`\mathbf W_k\in\mathbb R^{h\times k}` 和 :math:`\mathbf w_v\in\mathbb R^{h}`\ 。如 :eq:`eq_additive-attn` 所示,将查询和键连接起来后输入到一个多层感知机(MLP)中,感知机包含一个隐藏层,其隐藏单元数是一个超参数 :math:`h`\ 。通过使用 :math:`\tanh` 作为激活函数,并且禁用偏置项,我们将在下面实现加性注意力。

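
在给出完整的 DJL 实现之前,不妨先把 :eq:`eq_additive-attn` 按标量运算逐项展开。下面是一个不依赖 DJL、只针对单个查询-键对的最小示意实现,其中 ``matVec`` 和 ``additiveScore`` 等名称以及各矩阵的取值都是为了说明而假设的,并不属于本书的代码库:

.. code:: java

    // 示意代码:按 :eq:`eq_additive-attn` 计算 a(q, k) = w_v^T tanh(W_q q + W_k k)
    // 这里使用普通的 double 数组,与下面基于 DJL 的实现无关
    public static double[] matVec(double[][] m, double[] v) {
        // 计算矩阵-向量乘积 m v
        double[] out = new double[m.length];
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < v.length; j++) {
                out[i] += m[i][j] * v[j];
            }
        }
        return out;
    }

    public static double additiveScore(
            double[][] Wq, double[][] Wk, double[] wv, double[] q, double[] k) {
        double[] hq = matVec(Wq, q); // W_q q,长度为 h
        double[] hk = matVec(Wk, k); // W_k k,长度为 h
        double score = 0;
        for (int i = 0; i < wv.length; i++) {
            score += wv[i] * Math.tanh(hq[i] + hk[i]); // w_v^T tanh(...)
        }
        return score;
    }

    // 一个 h = 2 的小例子:查询长度为 3,键长度为 2,两者长度可以不同
    double[][] Wq = {{0.1, 0.2, 0.3}, {0.0, -0.1, 0.4}};
    double[][] Wk = {{0.5, -0.2}, {0.3, 0.1}};
    double[] wv = {1.0, -1.0};
    System.out.println(additiveScore(Wq, Wk, wv, new double[] {1, 0, 1}, new double[] {0.5, 2}));

可以看到,不同长度的查询和键分别经 :math:`\mathbf W_q` 和 :math:`\mathbf W_k` 映射到同一个隐藏维度 :math:`h` 之后再相加,这正是加性注意力能够处理不同长度矢量的原因。下面给出基于 DJL 的完整实现。
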

.. code:: java

    /* Additive attention. */
    public static class AdditiveAttention extends AbstractBlock {

        private Linear W_k;
        private Linear W_q;
        private Linear W_v;
        private Dropout dropout;
        public NDArray attentionWeights;

        public AdditiveAttention(int numHiddens, float dropout) {
            W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();
            addChildBlock("W_k", W_k);

            W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();
            addChildBlock("W_q", W_q);

            W_v = Linear.builder().setUnits(1).optBias(false).build();
            addChildBlock("W_v", W_v);

            this.dropout = Dropout.builder().optRate(dropout).build();
            addChildBlock("dropout", this.dropout);
        }

        @Override
        protected NDList forwardInternal(
                ParameterStore ps,
                NDList inputs,
                boolean training,
                PairList<String, Object> params) {
            // Shape of the output `queries` and `attentionWeights`:
            // (no. of queries, no. of key-value pairs)
            NDArray queries = inputs.get(0);
            NDArray keys = inputs.get(1);
            NDArray values = inputs.get(2);
            NDArray validLens = inputs.get(3);

            queries = W_q.forward(ps, new NDList(queries), training, params).head();
            keys = W_k.forward(ps, new NDList(keys), training, params).head();
            // After dimension expansion, shape of `queries`: (`batchSize`, no. of
            // queries, 1, `numHiddens`) and shape of `keys`: (`batchSize`, 1,
            // no. of key-value pairs, `numHiddens`). Sum them up with
            // broadcasting
            NDArray features = queries.expandDims(2).add(keys.expandDims(1));
            features = features.tanh();
            // There is only one output of `this.W_v`, so we remove the last
            // one-dimensional entry from the shape. Shape of `scores`:
            // (`batchSize`, no. of queries, no. of key-value pairs)
            NDArray result = W_v.forward(ps, new NDList(features), training, params).head();
            NDArray scores = result.squeeze(-1);
            attentionWeights = maskedSoftmax(scores, validLens);
            // Shape of `values`: (`batchSize`, no. of key-value pairs, value dimension)
            NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);
            return new NDList(list.head().batchDot(values));
        }

        @Override
        public Shape[] getOutputShapes(Shape[] inputShapes) {
            throw new UnsupportedOperationException("Not implemented");
        }

        @Override
        public void initializeChildBlocks(
                NDManager manager, DataType dataType, Shape... inputShapes) {
            W_q.initialize(manager, dataType, inputShapes[0]);
            W_k.initialize(manager, dataType, inputShapes[1]);
            long[] q = W_q.getOutputShapes(new Shape[] {inputShapes[0]})[0].getShape();
            long[] k = W_k.getOutputShapes(new Shape[] {inputShapes[1]})[0].getShape();
            long w = Math.max(q[q.length - 2], k[k.length - 2]);
            long h = Math.max(q[q.length - 1], k[k.length - 1]);
            long[] shape = new long[] {2, 1, w, h};
            W_v.initialize(manager, dataType, new Shape(shape));
            long[] dropoutShape = new long[shape.length - 1];
            System.arraycopy(shape, 0, dropoutShape, 0, dropoutShape.length);
            dropout.initialize(manager, dataType, new Shape(dropoutShape));
        }
    }

让我们用一个小例子来演示上面的 ``AdditiveAttention`` 类,其中查询、键和值的形状为(批量大小、步数或词元序列长度、特征大小),实际输出为 (:math:`2`, :math:`1`, :math:`20`)、(:math:`2`, :math:`10`, :math:`2`) 和 (:math:`2`, :math:`10`, :math:`4`)。注意力汇聚输出的形状为(批量大小、查询的步数、值的维度)。

.. code:: java

    NDArray queries = manager.randomNormal(0, 1, new Shape(2, 1, 20), DataType.FLOAT32);
    NDArray keys = manager.ones(new Shape(2, 10, 2));
    // The two value matrices in the `values` minibatch are identical
    NDArray values = manager.arange(40f).reshape(1, 10, 4).repeat(0, 2);
    NDArray validLens = manager.create(new float[] {2, 6});

    AdditiveAttention attention = new AdditiveAttention(8, 0.1f);
    NDList input = new NDList(queries, keys, values, validLens);
    ParameterStore ps = new ParameterStore(manager, false);
    attention.initialize(manager, DataType.FLOAT32, input.getShapes());
    attention.forward(ps, input, false).head();

.. parsed-literal::
    :class: output

    ND: (2, 1, 4) gpu(0) float32
    [[[ 2.,  3.,  4.,  5.],
     ],
     [[10., 11., 12., 13.],
     ],
    ]

尽管加性注意力包含了可学习的参数,但由于本例子中每个键都是相同的,所以注意力权重是均匀的,由指定的有效长度决定。

.. code:: java

    PlotUtils.showHeatmaps(
            attention.attentionWeights.reshape(1, 1, 2, 10),
            "Keys",
            "Queries",
            new String[] {""},
            500,
            700);

.. raw:: html

    <!-- 此处为注意力权重的交互式热图输出 -->


缩放点积注意力
--------------

使用点积可以得到计算效率更高的评分函数,但是点积操作要求查询和键具有相同的长度 :math:`d`\ 。假设查询和键的所有元素都是相互独立的随机变量,并且都满足均值为0、方差为1,那么两个向量的点积的均值为0,方差为 :math:`d`\ 。为了确保无论向量长度如何,点积的方差都是1,我们将点积除以 :math:`\sqrt{d}`\ ,得到 *缩放点积注意力(scaled dot-product attention)* 评分函数:

.. math:: a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k}  /\sqrt{d}.

在实践中,我们通常从小批量的角度来考虑提高效率,例如基于 :math:`n` 个查询和 :math:`m` 个键值对计算注意力,其中查询和键的长度为 :math:`d`\ ,值的长度为 :math:`v`\ 。查询 :math:`\mathbf Q\in\mathbb R^{n\times d}`\ 、键 :math:`\mathbf K\in\mathbb R^{m\times d}` 和值 :math:`\mathbf V\in\mathbb R^{m\times v}` 的缩放点积注意力是

.. math::
   :label: eq_softmax_QK_V

   \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.

在下面的缩放点积注意力的实现中,我们使用了 dropout 进行模型正则化。

.. code:: java

    /* Scaled dot product attention. */
    public static class DotProductAttention extends AbstractBlock {

        private Dropout dropout;
        public NDArray attentionWeights;

        public DotProductAttention(float dropout) {
            this.dropout = Dropout.builder().optRate(dropout).build();
            this.addChildBlock("dropout", this.dropout);
            this.dropout.setInitializer(new UniformInitializer(0.07f), Parameter.Type.WEIGHT);
        }

        @Override
        protected NDList forwardInternal(
                ParameterStore ps,
                NDList inputs,
                boolean training,
                PairList<String, Object> params) {
            // Shape of `queries`: (`batchSize`, no. of queries, `d`)
            // Shape of `keys`: (`batchSize`, no. of key-value pairs, `d`)
            // Shape of `values`: (`batchSize`, no. of key-value pairs, value
            // dimension)
            // Shape of `validLens`: (`batchSize`,) or (`batchSize`, no. of queries)
            NDArray queries = inputs.get(0);
            NDArray keys = inputs.get(1);
            NDArray values = inputs.get(2);
            NDArray validLens = inputs.get(3);

            Long d = queries.getShape().get(queries.getShape().dimension() - 1);
            // Swap the last two dimensions of `keys` and perform batchDot,
            // then scale the scores by sqrt(d)
            NDArray scores = queries.batchDot(keys.swapAxes(1, 2)).div(Math.sqrt(d));
            attentionWeights = maskedSoftmax(scores, validLens);
            NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);
            return new NDList(list.head().batchDot(values));
        }

        @Override
        public Shape[] getOutputShapes(Shape[] inputShapes) {
            throw new UnsupportedOperationException("Not implemented");
        }

        @Override
        public void initializeChildBlocks(
                NDManager manager, DataType dataType, Shape... inputShapes) {
            try (NDManager sub = manager.newSubManager()) {
                NDArray queries = sub.zeros(inputShapes[0], dataType);
                NDArray keys = sub.zeros(inputShapes[1], dataType);
                NDArray scores = queries.batchDot(keys.swapAxes(1, 2));
                dropout.initialize(manager, dataType, scores.getShape());
            }
        }
    }

为了演示上述的 ``DotProductAttention`` 类,我们使用了与先前加性注意力例子中相同的键、值和有效长度。对于点积操作,令查询的特征维度与键的特征维度大小相同。

.. code:: java

    queries = manager.randomNormal(0, 1, new Shape(2, 1, 2), DataType.FLOAT32);
    DotProductAttention productAttention = new DotProductAttention(0.5f);
    input = new NDList(queries, keys, values, validLens);
    productAttention.initialize(manager, DataType.FLOAT32, input.getShapes());
    productAttention.forward(ps, input, false).head();

.. parsed-literal::
    :class: output

    ND: (2, 1, 4) gpu(0) float32
    [[[ 2.,  3.,  4.,  5.],
     ],
     [[10., 11., 12., 13.],
     ],
    ]

与加性注意力演示相同,由于键包含的是相同的元素,而这些元素无法通过任何查询进行区分,因此获得了均匀的注意力权重。

.. code:: java

    PlotUtils.showHeatmaps(
            productAttention.attentionWeights.reshape(1, 1, 2, 10),
            "Keys",
            "Queries",
            new String[] {""},
            500,
            700);

.. raw:: html

    <!-- 此处为注意力权重的交互式热图输出 -->

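
作为补充,我们可以用一个小的随机模拟来粗略验证本节开头的结论:当查询和键的元素相互独立、均值为0且方差为1时,点积 :math:`\mathbf q^\top \mathbf k` 的方差约为 :math:`d`\ ,除以 :math:`\sqrt{d}` 之后方差约为1。下面只是一个示意性的小实验,复用了前面创建的 ``manager``\ ,其中 ``numSamples`` 和 ``d`` 等变量名是为了说明而假设的:

.. code:: java

    // 示意代码:用随机样本估计点积在缩放前后的方差
    int numSamples = 10000;
    int d = 64;
    NDArray q = manager.randomNormal(0, 1, new Shape(numSamples, d), DataType.FLOAT32);
    NDArray k = manager.randomNormal(0, 1, new Shape(numSamples, d), DataType.FLOAT32);
    // 逐样本计算点积 q·k,得到形状为 (numSamples,) 的向量
    NDArray dot = q.mul(k).sum(new int[] {1});
    NDArray scaled = dot.div(Math.sqrt(d));
    // 用 E[x^2] - (E[x])^2 估计方差
    float varDot = dot.pow(2).mean().getFloat()
            - (float) Math.pow(dot.mean().getFloat(), 2);
    float varScaled = scaled.pow(2).mean().getFloat()
            - (float) Math.pow(scaled.mean().getFloat(), 2);
    System.out.println("Var(q^T k) ≈ " + varDot + ",理论值约为 " + d);
    System.out.println("Var(q^T k / sqrt(d)) ≈ " + varScaled + ",理论值约为 1");

如果把 ``d`` 改大或改小,可以看到缩放后的方差始终接近1。
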

小结
----

-  可以将注意力汇聚的输出计算为值的加权平均,选择不同的注意力评分函数会带来不同的注意力汇聚操作。
-  当查询和键是不同长度的矢量时,可以使用加性注意力评分函数;当它们的长度相同时,使用缩放点积注意力评分函数的计算效率更高。

练习
----

1. 修改小例子中的键,并且可视化注意力权重。加性注意力和缩放点积注意力是否仍然产生相同的结果?为什么?
2. 只使用矩阵乘法,您能否为具有不同矢量长度的查询和键设计新的评分函数?
3. 当查询和键具有相同的矢量长度时,矢量求和作为评分函数是否比点积更好?为什么?