Run this notebook online in SageMaker Studio Lab or Colab.

# 10.3. 多头注意力¶

## 10.3.1. 模型¶

(10.3.1)$\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},$

(10.3.2)$\begin{split}\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.\end{split}$

// Notebook magic: pulls in the shared DJL import list (ai.djl.ndarray, nn, etc.).
%load ../utils/djl-imports


// Top-level NDManager owning every NDArray created in this notebook.
NDManager manager = NDManager.newBaseManager();


/**
 * Reshapes queries/keys/values for parallel multi-head attention.
 *
 * Input shape:  (batchSize, no. of queries or key-value pairs, numHiddens).
 * Output shape: (batchSize * numHeads, no. of queries or key-value pairs,
 *               numHiddens / numHeads).
 */
public static NDArray transposeQkv(NDArray X, int numHeads) {
    // Split the hidden axis into (numHeads, numHiddens / numHeads):
    // (batchSize, seqLen, numHeads, numHiddens / numHeads)
    Shape shape = X.getShape();
    X = X.reshape(shape.get(0), shape.get(1), numHeads, -1);

    // Move the head axis next to the batch axis:
    // (batchSize, numHeads, seqLen, numHiddens / numHeads)
    X = X.transpose(0, 2, 1, 3);

    // Fold batch and head axes together so every head is processed
    // as an independent batch element:
    // (batchSize * numHeads, seqLen, numHiddens / numHeads)
    Shape swapped = X.getShape();
    return X.reshape(-1, swapped.get(2), swapped.get(3));
}

/**
 * Reverses the transformation of {@code transposeQkv}: merges the per-head
 * outputs back into a single hidden dimension.
 *
 * Input shape:  (batchSize * numHeads, no. of queries, numHiddens / numHeads).
 * Output shape: (batchSize, no. of queries, numHiddens).
 */
public static NDArray transposeOutput(NDArray X, int numHeads) {
    // Un-fold the batch axis: (batchSize, numHeads, seqLen, headDim)
    Shape shape = X.getShape();
    X = X.reshape(-1, numHeads, shape.get(1), shape.get(2));
    // Swap head and sequence axes: (batchSize, seqLen, numHeads, headDim)
    X = X.transpose(0, 2, 1, 3);
    // Concatenate the heads along the last axis: (batchSize, seqLen, numHiddens)
    Shape swapped = X.getShape();
    return X.reshape(swapped.get(0), swapped.get(1), -1);
}


## 10.3.2. 实现¶

public static class MultiHeadAttention extends AbstractBlock {

public DotProductAttention attention;
private Linear W_k;
private Linear W_q;
private Linear W_v;
private Linear W_o;

attention = new DotProductAttention(dropout);

W_q = Linear.builder().setUnits(numHiddens).optBias(useBias).build();

W_k = Linear.builder().setUnits(numHiddens).optBias(useBias).build();

W_v = Linear.builder().setUnits(numHiddens).optBias(useBias).build();

W_o = Linear.builder().setUnits(numHiddens).optBias(useBias).build();

Dropout dropout1 = Dropout.builder().optRate(dropout).build();
}

@Override
protected NDList forwardInternal(
ParameterStore ps,
NDList inputs,
boolean training,
PairList<String, Object> params) {
// Shape of queries, keys, or values:
// (batchSize, no. of queries or key-value pairs, numHiddens)
// Shape of validLens:
// (batchSize,) or (batchSize, no. of queries)
// After transposing, shape of output queries, keys, or values:
// (batchSize * numHeads, no. of queries or key-value pairs,
// numHiddens / numHeads)
NDArray queries = inputs.get(0);
NDArray keys = inputs.get(1);
NDArray values = inputs.get(2);
NDArray validLens = inputs.get(3);
// On axis 0, copy the first item (scalar or vector) for
// numHeads times, then copy the next item, and so on

queries =
transposeQkv(
W_q.forward(ps, new NDList(queries), training, params).get(0),
keys =
transposeQkv(
W_k.forward(ps, new NDList(keys), training, params).get(0), numHeads);
values =
transposeQkv(
W_v.forward(ps, new NDList(values), training, params).get(0), numHeads);

// Shape of output: (batchSize * numHeads, no. of queries,
// numHiddens / numHeads)
NDArray output =
attention
.forward(
ps,
new NDList(queries, keys, values, validLens),
training,
params)
.get(0);

// Shape of outputConcat:
// (batchSize, no. of queries, numHiddens)
return new NDList(W_o.forward(ps, new NDList(outputConcat), training, params).get(0));
}

@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public void initializeChildBlocks(
NDManager manager, DataType dataType, Shape... inputShapes) {
try (NDManager sub = manager.newSubManager()) {
NDArray queries = sub.zeros(inputShapes[0], dataType);
NDArray keys = sub.zeros(inputShapes[1], dataType);
NDArray values = sub.zeros(inputShapes[2], dataType);
NDArray validLens = sub.zeros(inputShapes[3], dataType);

ParameterStore ps = new ParameterStore(sub, false);

W_q.initialize(manager, dataType, queries.getShape());
W_k.initialize(manager, dataType, keys.getShape());
W_v.initialize(manager, dataType, values.getShape());

queries =
keys = transposeQkv(W_k.forward(ps, new NDList(keys), false).get(0), numHeads);
values = transposeQkv(W_v.forward(ps, new NDList(values), false).get(0), numHeads);

NDList list = new NDList(queries, keys, values, validLens);
attention.initialize(sub, dataType, list.getShapes());
NDArray output = attention.forward(ps, list, false).head();

W_o.initialize(manager, dataType, outputConcat.getShape());
}
}
}


// Demo: a 5-head attention block with 100 hidden units.
// NOTE(review): the extracted source used `attention` without ever creating
// it; the numHeads declaration and construction are restored here.
int numHiddens = 100;
int numHeads = 5;
MultiHeadAttention attention = new MultiHeadAttention(numHiddens, numHeads, 0.5f, false);

int batchSize = 2;
int numQueries = 4;
int numKvpairs = 6;
// Per-example valid lengths: the first example attends to 3 kv-pairs,
// the second to 2.
NDArray validLens = manager.create(new float[] {3, 2});
NDArray X = manager.ones(new Shape(batchSize, numQueries, numHiddens));
NDArray Y = manager.ones(new Shape(batchSize, numKvpairs, numHiddens));

ParameterStore ps = new ParameterStore(manager, false);
NDList input = new NDList(X, Y, Y, validLens);
attention.initialize(manager, DataType.FLOAT32, input.getShapes());
NDList result = attention.forward(ps, input, false);
// Expect (batchSize, numQueries, numHiddens) = (2, 4, 100)
result.get(0).getShape();

(2, 4, 100)


## 10.3.3. 小结¶

• 多头注意力融合了来自于多个注意力汇聚的不同知识，这些知识的不同来源于相同的查询、键和值的不同的子空间表示。

• 基于适当的张量操作，可以实现多头注意力的并行计算。

## 10.3.4. 练习¶

1. 分别可视化这个实验中的多个头的注意力权重。

2. 假设我们已经拥有一个完成训练的基于多头注意力的模型，现在希望修剪最不重要的注意力头以提高预测速度。应该如何设计实验来衡量注意力头的重要性？