Run this notebook online: or Colab:
10.3. 多头注意力¶
在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,例如捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖)。因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。
为此,与使用单独一个注意力汇聚不同,我们可以用独立学习得到的 \(h\) 组不同的 线性投影(linear projections)来变换查询、键和值。然后,这 \(h\) 组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 \(h\) 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为 多头注意力,其中 \(h\) 个注意力汇聚输出中的每一个输出都被称作一个 头(head) [Vaswani et al., 2017]。 图 Section 10.3 展示了使用全连接层来实现可学习的线性变换的多头注意力。
10.3.1. 模型¶
在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询 \(\mathbf{q} \in \mathbb{R}^{d_q}\)、键 \(\mathbf{k} \in \mathbb{R}^{d_k}\) 和值 \(\mathbf{v} \in \mathbb{R}^{d_v}\),每个注意力头 \(\mathbf{h}_i\) (\(i = 1, \ldots, h\)) 的计算方法为:
其中,可学习的参数包括 \(\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}\)、\(\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}\) 和 \(\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}\),以及注意力汇聚的函数 \(f\)。\(f\) 可以是 Section 10.2 中的加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 \(h\) 个头连结后的结果,因此其可学习参数是 \(\mathbf W_o\in\mathbb R^{p_o\times h p_v}\):
基于这种设计,每个头都可能会关注输入的不同部分。可以表示比简单加权平均值更复杂的函数。
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java
%load ../utils/attention/Chap10Utils.java
%load ../utils/attention/DotProductAttention.java
%load ../utils/attention/MultiHeadAttention.java
%load ../utils/attention/PositionalEncoding.java
NDManager manager = NDManager.newBaseManager();
为了允许并行计算多个头,下面的MultiHeadAttention
类使用两个如下定义的转置函数。具体来说,transposeOutput
函数颠倒了
transposeQkv
函数的操作。
public static NDArray transposeQkv(NDArray X, int numHeads) {
// Shape of input `X`:
// (`batchSize`, no. of queries or key-value pairs, `numHiddens`).
// Shape of output `X`:
// (`batchSize`, no. of queries or key-value pairs, `numHeads`,
// `numHiddens` / `numHeads`)
X = X.reshape(X.getShape().get(0), X.getShape().get(1), numHeads, -1);
// Shape of output `X`:
// (`batchSize`, `numHeads`, no. of queries or key-value pairs,
// `numHiddens` / `numHeads`)
X = X.transpose(0, 2, 1, 3);
// Shape of `output`:
// (`batchSize` * `numHeads`, no. of queries or key-value pairs,
// `numHiddens` / `numHeads`)
return X.reshape(-1, X.getShape().get(2), X.getShape().get(3));
}
public static NDArray transposeOutput(NDArray X, int numHeads) {
X = X.reshape(-1, numHeads, X.getShape().get(1), X.getShape().get(2));
X = X.transpose(0, 2, 1, 3);
return X.reshape(X.getShape().get(0), X.getShape().get(1), -1);
}
10.3.2. 实现¶
在实现过程中,我们选择缩放点积注意力作为每一个注意力头。为了避免计算成本和参数数量的大幅增长,我们设定
\(p_q = p_k = p_v = p_o / h\)。值得注意的是,如果我们将查询、键和值的线性变换的输出数量设置为
\(p_q h = p_k h = p_v h = p_o\),则可以并行计算 \(h\)
个头。在下面的实现中,\(p_o\) 是通过参数 numHiddens
指定的。
public static class MultiHeadAttention extends AbstractBlock {
private int numHeads;
public DotProductAttention attention;
private Linear W_k;
private Linear W_q;
private Linear W_v;
private Linear W_o;
public MultiHeadAttention(int numHiddens, int numHeads, float dropout, boolean useBias) {
this.numHeads = numHeads;
attention = new DotProductAttention(dropout);
W_q = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
addChildBlock("W_q", W_q);
W_k = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
addChildBlock("W_k", W_k);
W_v = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
addChildBlock("W_v", W_v);
W_o = Linear.builder().setUnits(numHiddens).optBias(useBias).build();
addChildBlock("W_o", W_o);
Dropout dropout1 = Dropout.builder().optRate(dropout).build();
addChildBlock("dropout", dropout1);
}
@Override
protected NDList forwardInternal(
ParameterStore ps,
NDList inputs,
boolean training,
PairList<String, Object> params) {
// Shape of `queries`, `keys`, or `values`:
// (`batchSize`, no. of queries or key-value pairs, `numHiddens`)
// Shape of `validLens`:
// (`batchSize`,) or (`batchSize`, no. of queries)
// After transposing, shape of output `queries`, `keys`, or `values`:
// (`batchSize` * `numHeads`, no. of queries or key-value pairs,
// `numHiddens` / `numHeads`)
NDArray queries = inputs.get(0);
NDArray keys = inputs.get(1);
NDArray values = inputs.get(2);
NDArray validLens = inputs.get(3);
// On axis 0, copy the first item (scalar or vector) for
// `numHeads` times, then copy the next item, and so on
validLens = validLens.repeat(0, numHeads);
queries =
transposeQkv(
W_q.forward(ps, new NDList(queries), training, params).get(0),
numHeads);
keys =
transposeQkv(
W_k.forward(ps, new NDList(keys), training, params).get(0), numHeads);
values =
transposeQkv(
W_v.forward(ps, new NDList(values), training, params).get(0), numHeads);
// Shape of `output`: (`batchSize` * `numHeads`, no. of queries,
// `numHiddens` / `numHeads`)
NDArray output =
attention
.forward(
ps,
new NDList(queries, keys, values, validLens),
training,
params)
.get(0);
// Shape of `outputConcat`:
// (`batchSize`, no. of queries, `numHiddens`)
NDArray outputConcat = transposeOutput(output, numHeads);
return new NDList(W_o.forward(ps, new NDList(outputConcat), training, params).get(0));
}
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("Not implemented");
}
@Override
public void initializeChildBlocks(
NDManager manager, DataType dataType, Shape... inputShapes) {
try (NDManager sub = manager.newSubManager()) {
NDArray queries = sub.zeros(inputShapes[0], dataType);
NDArray keys = sub.zeros(inputShapes[1], dataType);
NDArray values = sub.zeros(inputShapes[2], dataType);
NDArray validLens = sub.zeros(inputShapes[3], dataType);
validLens = validLens.repeat(0, numHeads);
ParameterStore ps = new ParameterStore(sub, false);
W_q.initialize(manager, dataType, queries.getShape());
W_k.initialize(manager, dataType, keys.getShape());
W_v.initialize(manager, dataType, values.getShape());
queries =
transposeQkv(W_q.forward(ps, new NDList(queries), false).get(0), numHeads);
keys = transposeQkv(W_k.forward(ps, new NDList(keys), false).get(0), numHeads);
values = transposeQkv(W_v.forward(ps, new NDList(values), false).get(0), numHeads);
NDList list = new NDList(queries, keys, values, validLens);
attention.initialize(sub, dataType, list.getShapes());
NDArray output = attention.forward(ps, list, false).head();
NDArray outputConcat = Chap10Utils.transposeOutput(output, numHeads);
W_o.initialize(manager, dataType, outputConcat.getShape());
}
}
}
让我们使用键和值相同的小例子来测试我们编写的 MultiHeadAttention
类。多头注意力输出的形状是 (batchSize
, numQueries
,
numHiddens
)。
int numHiddens = 100;
int numHeads = 5;
MultiHeadAttention attention = new MultiHeadAttention(numHiddens, numHeads, 0.5f, false);
int batchSize = 2;
int numQueries = 4;
int numKvpairs = 6;
NDArray validLens = manager.create(new float[] {3, 2});
NDArray X = manager.ones(new Shape(batchSize, numQueries, numHiddens));
NDArray Y = manager.ones(new Shape(batchSize, numKvpairs, numHiddens));
ParameterStore ps = new ParameterStore(manager, false);
NDList input = new NDList(X, Y, Y, validLens);
attention.initialize(manager, DataType.FLOAT32, input.getShapes());
NDList result = attention.forward(ps, input, false);
result.get(0).getShape();
(2, 4, 100)
10.3.3. 小结¶
多头注意力融合了来自于相同的注意力汇聚产生的不同的知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
基于适当的张量操作,可以实现多头注意力的并行计算。
10.3.4. 练习¶
分别可视化这个实验中的多个头的注意力权重。
假设我们已经拥有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。应该如何设计实验来衡量注意力头的重要性?