Run this notebook online: or Colab:
10.2. 注意力评分函数¶
在 sec_nadaraya-waston
中,我们使用高斯核来对查询和键之间的关系建模。可以将
eq_nadaraya-waston-gaussian
中的高斯核的指数部分视为
注意力评分函数(attention scoring function),简称
评分函数(scoring function),然后把这个函数的输出结果输入到 softmax
函数中进行运算。通过上述步骤,我们将得到与键对应的值的概率分布(即注意力权重)。最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。
从宏观来看,可以使用上述算法来实现 Section 10.1.2
中的注意力机制框架。:numref:fig_attention_output
说明了如何将注意力汇聚的输出计算成为值的加权和,其中 𝑎
表示注意力评分函数。由于注意力权重是概率分布,因此加权和其本质上是加权平均值。
.. _fig_attention_output:
Mathematically, suppose that we have a query \(\mathbf{q} \in \mathbb{R}^q\) and \(m\) key-value pairs \((\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)\), where any \(\mathbf{k}_i \in \mathbb{R}^k\) and any \(\mathbf{v}_i \in \mathbb{R}^v\). The attention pooling \(f\) is instantiated as a weighted sum of the values:
用数学语言描述,假设有一个查询 \(\mathbf{q} \in \mathbb{R}^q\) 和 \(m\) 个键值对 \((\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)\),其中 \(\mathbf{k}_i \in \mathbb{R}^k\),\(\mathbf{v}_i \in \mathbb{R}^v\)。注意力汇聚函数 𝑓 就被表示成值的加权和:
其中查询 \(\mathbf{q}\) 和键 \(\mathbf{k}_i\) 的注意力权重(标量)是通过注意力评分函数 \(a\) 将两个向量映射成标量,再经过 softmax 运算得到的:
正如我们所看到的,选择不同的注意力评分函数 \(a\) 会导致不同的注意力汇聚操作。在本节中,我们将介绍两个流行的评分函数,稍后将用他们来实现更复杂的注意力机制。
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java
NDManager manager = NDManager.newBaseManager();
10.2.1. 遮蔽softmax操作¶
正如上面提到的,softmax
运算用于输出一个概率分布作为注意力权重。在某些情况下,并非所有的值都应该被纳入到注意力汇聚中。例如,为了在
Section 9.5
中高效处理小批量数据集,某些文本序列被填充了没有意义的特殊词元。为了仅将有意义的词元作为值去获取注意力汇聚,可以指定一个有效序列长度(即词元的个数),以便在计算
softmax 时过滤掉超出指定范围的位置。通过这种方式,我们可以在下面的
maskedSoftmax
函数中实现这样的 遮蔽 softmax 操作(masked softmax
operation),其中任何超出有效长度的位置都被遮蔽并置为0。
public static NDArray maskedSoftmax(NDArray X, NDArray validLens) {
/* Perform softmax operation by masking elements on the last axis. */
// `X`: 3D NDArray, `validLens`: 1D or 2D NDArray
if (validLens == null) {
return X.softmax(-1);
}
Shape shape = X.getShape();
if (validLens.getShape().dimension() == 1) {
validLens = validLens.repeat(shape.get(1));
} else {
validLens = validLens.reshape(-1);
}
// On the last axis, replace masked elements with a very large negative
// value, whose exponentiation outputs 0
X = X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))
.sequenceMask(validLens, (float) -1E6);
return X.softmax(-1).reshape(shape);
}
为了演示此函数是如何工作的,考虑由两个 \(2 \times 4\) 矩阵表示的样本,这两个样本的有效长度分别为2和3。经过遮蔽 softmax 操作,超出有效长度的值都被遮蔽为0。
maskedSoftmax(
manager.randomUniform(0, 1, new Shape(2, 2, 4)),
manager.create(new float[] {2, 3}));
ND: (2, 2, 4) gpu(0) float32
[[[0.4549, 0.5451, 0. , 0. ],
[0.6175, 0.3825, 0. , 0. ],
],
[[0.294 , 0.3069, 0.3992, 0. ],
[0.3747, 0.2626, 0.3627, 0. ],
],
]
同样,我们也可以使用二维张量为矩阵样本中的每一行指定有效长度。
maskedSoftmax(
manager.randomUniform(0, 1, new Shape(2, 2, 4)),
manager.create(new float[][] {{1, 3}, {2, 4}}));
ND: (2, 2, 4) gpu(0) float32
[[[1. , 0. , 0. , 0. ],
[0.2777, 0.4156, 0.3067, 0. ],
],
[[0.3441, 0.6559, 0. , 0. ],
[0.2544, 0.2482, 0.2036, 0.2939],
],
]
10.2.2. 加性注意力¶
一般来说,当查询和键是不同长度的矢量时,可以使用加性注意力作为评分函数。给定查询 \(\mathbf{q} \in \mathbb{R}^q\) 和键 \(\mathbf{k} \in \mathbb{R}^k\),加性注意力(additive attention) 的评分函数为
其中可学习的参数是 \(\mathbf W_q\in\mathbb R^{h\times q}\), \(\mathbf W_k\in\mathbb R^{h\times k}\), and \(\mathbf w_v\in\mathbb R^{h}\) 。如 (10.2.4) 所示,将查询和键连接起来后输入到一个多层感知机(MLP)中,感知机包含一个隐藏层,其隐藏单元数是一个超参数 \(h\)。通过使用 \(\tanh\) 作为激活函数,并且禁用偏置项,我们将在下面实现加性注意力。
/* Additive attention. */
public static class AdditiveAttention extends AbstractBlock {
private Linear W_k;
private Linear W_q;
private Linear W_v;
private Dropout dropout;
public NDArray attentionWeights;
public AdditiveAttention(int numHiddens, float dropout) {
W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();
addChildBlock("W_k", W_k);
W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();
addChildBlock("W_q", W_q);
W_v = Linear.builder().setUnits(1).optBias(false).build();
addChildBlock("W_v", W_v);
this.dropout = Dropout.builder().optRate(dropout).build();
addChildBlock("dropout", this.dropout);
}
@Override
protected NDList forwardInternal(
ParameterStore ps,
NDList inputs,
boolean training,
PairList<String, Object> params) {
// Shape of the output `queries` and `attentionWeights`:
// (no. of queries, no. of key-value pairs)
NDArray queries = inputs.get(0);
NDArray keys = inputs.get(1);
NDArray values = inputs.get(2);
NDArray validLens = inputs.get(3);
queries = W_q.forward(ps, new NDList(queries), training, params).head();
keys = W_k.forward(ps, new NDList(keys), training, params).head();
// After dimension expansion, shape of `queries`: (`batchSize`, no. of
// queries, 1, `numHiddens`) and shape of `keys`: (`batchSize`, 1,
// no. of key-value pairs, `numHiddens`). Sum them up with
// broadcasting
NDArray features = queries.expandDims(2).add(keys.expandDims(1));
features = features.tanh();
// There is only one output of `this.W_v`, so we remove the last
// one-dimensional entry from the shape. Shape of `scores`:
// (`batchSize`, no. of queries, no. of key-value pairs)
NDArray result = W_v.forward(ps, new NDList(features), training, params).head();
NDArray scores = result.squeeze(-1);
attentionWeights = maskedSoftmax(scores, validLens);
// Shape of `values`: (`batchSize`, no. of key-value pairs, value dimension)
NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);
return new NDList(list.head().batchDot(values));
}
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public void initializeChildBlocks(
NDManager manager, DataType dataType, Shape... inputShapes) {
W_q.initialize(manager, dataType, inputShapes[0]);
W_k.initialize(manager, dataType, inputShapes[1]);
long[] q = W_q.getOutputShapes(new Shape[] {inputShapes[0]})[0].getShape();
long[] k = W_k.getOutputShapes(new Shape[] {inputShapes[1]})[0].getShape();
long w = Math.max(q[q.length - 2], k[k.length - 2]);
long h = Math.max(q[q.length - 1], k[k.length - 1]);
long[] shape = new long[] {2, 1, w, h};
W_v.initialize(manager, dataType, new Shape(shape));
long[] dropoutShape = new long[shape.length - 1];
System.arraycopy(shape, 0, dropoutShape, 0, dropoutShape.length);
dropout.initialize(manager, dataType, new Shape(dropoutShape));
}
}
让我们用一个小例子来演示上面的 AdditiveAttention
类,其中查询、键和值的形状为(批量大小、步数或词元序列长度、特征大小),实际输出为
(\(2\), \(1\), \(20\)), (\(2\), \(10\), \(2\))
和 (\(2\), \(10\),
\(4\))。注意力汇聚输出的形状为(批量大小、查询的步数、值的维度)。
NDArray queries = manager.randomNormal(0, 1, new Shape(2, 1, 20), DataType.FLOAT32);
NDArray keys = manager.ones(new Shape(2, 10, 2));
// The two value matrices in the `values` minibatch are identical
NDArray values = manager.arange(40f).reshape(1, 10, 4).repeat(0, 2);
NDArray validLens = manager.create(new float[] {2, 6});
AdditiveAttention attention = new AdditiveAttention(8, 0.1f);
NDList input = new NDList(queries, keys, values, validLens);
ParameterStore ps = new ParameterStore(manager, false);
attention.initialize(manager, DataType.FLOAT32, input.getShapes());
attention.forward(ps, input, false).head();
ND: (2, 1, 4) gpu(0) float32
[[[ 2., 3., 4., 5.],
],
[[10., 11., 12., 13.],
],
]
尽管加性注意力包含了可学习的参数,但由于本例子中每个键都是相同的,所以注意力权重是均匀的,由指定的有效长度决定。
PlotUtils.showHeatmaps(
attention.attentionWeights.reshape(1, 1, 2, 10),
"Keys",
"Queries",
new String[] {""},
500,
700);
10.2.3. 缩放点积注意力¶
使用点积可以得到计算效率更高的评分函数。但是点积操作要求查询和键具有相同的长度 \(d\)。假设查询和键的所有元素都是独立的随机变量,并且都满足均值为0和方差为。那么两个向量的点积的均值为0,方差为 \(d\)。为确保无论向量长度如何,点积的方差在不考虑向量长度的情况下仍然是1,则可以使用 缩放点积注意力(scaled dot-product attention) 评分函数:
将点积除以 \(\sqrt{d}\)。在实践中,我们通常从小批量的角度来考虑提高效率,例如基于 \(n\) 个查询和 \(m\) 个键-值对计算注意力,其中查询和键的长度为 \(d\),值的长度为 \(v\)。查询 \(\mathbf Q\in\mathbb R^{n\times d}\),键 \(\mathbf K\in\mathbb R^{m\times d}\) 和值 \(\mathbf V\in\mathbb R^{m\times v}\) 的缩放点积注意力是
在下面的缩放点积注意力的实现中,我们使用了 dropout 进行模型正则化。
/* Scaled dot product attention. */
public static class DotProductAttention extends AbstractBlock {
private Dropout dropout;
public NDArray attentionWeights;
public DotProductAttention(float dropout) {
this.dropout = Dropout.builder().optRate(dropout).build();
this.addChildBlock("dropout", this.dropout);
this.dropout.setInitializer(new UniformInitializer(0.07f), Parameter.Type.WEIGHT);
}
@Override
protected NDList forwardInternal(
ParameterStore ps,
NDList inputs,
boolean training,
PairList<String, Object> params) {
// Shape of `queries`: (`batchSize`, no. of queries, `d`)
// Shape of `keys`: (`batchSize`, no. of key-value pairs, `d`)
// Shape of `values`: (`batchSize`, no. of key-value pairs, value
// dimension)
// Shape of `valid_lens`: (`batchSize`,) or (`batchSize`, no. of queries)
NDArray queries = inputs.get(0);
NDArray keys = inputs.get(1);
NDArray values = inputs.get(2);
NDArray validLens = inputs.get(3);
Long d = queries.getShape().get(queries.getShape().dimension() - 1);
// Swap the last two dimensions of `keys` and perform batchDot
NDArray scores = queries.batchDot(keys.swapAxes(1, 2)).div(Math.sqrt(2));
attentionWeights = maskedSoftmax(scores, validLens);
NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);
return new NDList(list.head().batchDot(values));
}
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public void initializeChildBlocks(
NDManager manager, DataType dataType, Shape... inputShapes) {
try (NDManager sub = manager.newSubManager()) {
NDArray queries = sub.zeros(inputShapes[0], dataType);
NDArray keys = sub.zeros(inputShapes[1], dataType);
NDArray scores = queries.batchDot(keys.swapAxes(1, 2));
dropout.initialize(manager, dataType, scores.getShape());
}
}
}
为了演示上述的 DotProductAttention
类,我们使用了与先前加性注意力例子中相同的键、值和有效长度。对于点积操作,令查询的特征维度与键的特征维度大小相同。
queries = manager.randomNormal(0, 1, new Shape(2, 1, 2), DataType.FLOAT32);
DotProductAttention productAttention = new DotProductAttention(0.5f);
input = new NDList(queries, keys, values, validLens);
productAttention.initialize(manager, DataType.FLOAT32, input.getShapes());
productAttention.forward(ps, input, false).head();
ND: (2, 1, 4) gpu(0) float32
[[[ 2., 3., 4., 5.],
],
[[10., 11., 12., 13.],
],
]
与加性注意力演示相同,由于键
包含的是相同的元素,而这些元素无法通过任何查询进行区分,因此获得了均匀的注意力权重。
PlotUtils.showHeatmaps(
productAttention.attentionWeights.reshape(1, 1, 2, 10),
"Keys",
"Queries",
new String[] {""},
500,
700);
10.2.4. 小结¶
可以将注意力汇聚的输出计算作为值的加权平均,选择不同的注意力评分函数会带来不同的注意力汇聚操作。
当查询和键是不同长度的矢量时,可以使用可加性注意力评分函数。当它们的长度相同时,使用缩放的“点-积”注意力评分函数的计算效率更高。
10.2.5. 练习¶
修改小例子中的键,并且可视化注意力权重。可加性注意力和缩放的“点-积”注意力是否仍然产生相同的结果?为什么?
只使用矩阵乘法,您能否为具有不同矢量长度的查询和键设计新的评分函数?
当查询和键具有相同的矢量长度时,矢量求和作为评分函数是否比“点-积”更好?为什么?