Run this notebook online: Binder or Colab.

# 10.2. 注意力评分函数¶

在上一节（Nadaraya-Watson核回归）中，我们使用高斯核来对查询和键之间的关系建模。可以将其中高斯核的指数部分视为 注意力评分函数（attention scoring function），简称 评分函数（scoring function），然后把这个函数的输出结果输入到 softmax 函数中进行运算。通过上述步骤，我们将得到与键对应的值的概率分布（即注意力权重）。最后，注意力汇聚的输出就是基于这些注意力权重的值的加权和。

.. _fig_attention_output:

用数学语言描述，假设有一个查询 $\mathbf{q} \in \mathbb{R}^q$ 和 $m$ 个“键－值”对 $(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)$，其中任意 $\mathbf{k}_i \in \mathbb{R}^k$、任意 $\mathbf{v}_i \in \mathbb{R}^v$。注意力汇聚函数 $f$ 就被表示成值的加权和：

(10.2.2)$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i \in \mathbb{R}^v,$

(10.2.3)$\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.$

%load ../utils/djl-imports

// Root NDManager that owns every NDArray allocated in this notebook.
NDManager manager = NDManager.newBaseManager();


## 10.2.1. 遮蔽softmax操作¶

public static NDArray maskedSoftmax(NDArray X, NDArray validLens) {
/* Perform softmax operation by masking elements on the last axis. */
// X: 3D NDArray, validLens: 1D or 2D NDArray
if (validLens == null) {
return X.softmax(-1);
}

Shape shape = X.getShape();
if (validLens.getShape().dimension() == 1) {
validLens = validLens.repeat(shape.get(1));
} else {
validLens = validLens.reshape(-1);
}
// On the last axis, replace masked elements with a very large negative
// value, whose exponentiation outputs 0
X = X.reshape(new Shape(-1, shape.get(shape.dimension() - 1)))
return X.softmax(-1).reshape(shape);
}


maskedSoftmax(
manager.randomUniform(0, 1, new Shape(2, 2, 4)),
manager.create(new float[] {2, 3}));

ND: (2, 2, 4) gpu(0) float32
[[[0.4549, 0.5451, 0.    , 0.    ],
[0.6175, 0.3825, 0.    , 0.    ],
],
[[0.294 , 0.3069, 0.3992, 0.    ],
[0.3747, 0.2626, 0.3627, 0.    ],
],
]


maskedSoftmax(
manager.randomUniform(0, 1, new Shape(2, 2, 4)),
manager.create(new float[][] {{1, 3}, {2, 4}}));

ND: (2, 2, 4) gpu(0) float32
[[[1.    , 0.    , 0.    , 0.    ],
[0.2777, 0.4156, 0.3067, 0.    ],
],
[[0.3441, 0.6559, 0.    , 0.    ],
[0.2544, 0.2482, 0.2036, 0.2939],
],
]


## 10.2.2. 加性注意力¶

(10.2.4)$a(\mathbf q, \mathbf k) = \mathbf w_v^\top \text{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},$

/* Additive attention. */
public static class AdditiveAttention extends AbstractBlock {

private Linear W_k;
private Linear W_q;
private Linear W_v;
private Dropout dropout;
public NDArray attentionWeights;

public AdditiveAttention(int numHiddens, float dropout) {
W_k = Linear.builder().setUnits(numHiddens).optBias(false).build();

W_q = Linear.builder().setUnits(numHiddens).optBias(false).build();

W_v = Linear.builder().setUnits(1).optBias(false).build();

this.dropout = Dropout.builder().optRate(dropout).build();
}

@Override
protected NDList forwardInternal(
ParameterStore ps,
NDList inputs,
boolean training,
PairList<String, Object> params) {
// Shape of the output queries and attentionWeights:
// (no. of queries, no. of key-value pairs)
NDArray queries = inputs.get(0);
NDArray keys = inputs.get(1);
NDArray values = inputs.get(2);
NDArray validLens = inputs.get(3);

queries = W_q.forward(ps, new NDList(queries), training, params).head();
keys = W_k.forward(ps, new NDList(keys), training, params).head();
// After dimension expansion, shape of queries: (batchSize, no. of
// queries, 1, numHiddens) and shape of keys: (batchSize, 1,
// no. of key-value pairs, numHiddens). Sum them up with
features = features.tanh();
// There is only one output of this.W_v, so we remove the last
// one-dimensional entry from the shape. Shape of scores:
// (batchSize, no. of queries, no. of key-value pairs)
NDArray result = W_v.forward(ps, new NDList(features), training, params).head();
NDArray scores = result.squeeze(-1);
// Shape of values: (batchSize, no. of key-value pairs, value dimension)
NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);
}

@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public void initializeChildBlocks(
NDManager manager, DataType dataType, Shape... inputShapes) {
W_q.initialize(manager, dataType, inputShapes[0]);
W_k.initialize(manager, dataType, inputShapes[1]);
long[] q = W_q.getOutputShapes(new Shape[] {inputShapes[0]})[0].getShape();
long[] k = W_k.getOutputShapes(new Shape[] {inputShapes[1]})[0].getShape();
long w = Math.max(q[q.length - 2], k[k.length - 2]);
long h = Math.max(q[q.length - 1], k[k.length - 1]);
long[] shape = new long[] {2, 1, w, h};
W_v.initialize(manager, dataType, new Shape(shape));
long[] dropoutShape = new long[shape.length - 1];
System.arraycopy(shape, 0, dropoutShape, 0, dropoutShape.length);
dropout.initialize(manager, dataType, new Shape(dropoutShape));
}
}


// Build a small example: 2 batches, 1 query of dim 20, 10 key-value pairs.
NDArray queries = manager.randomNormal(0, 1, new Shape(2, 1, 20), DataType.FLOAT32);
NDArray keys = manager.ones(new Shape(2, 10, 2));
// The two value matrices in the values minibatch are identical
NDArray values = manager.arange(40f).reshape(1, 10, 4).repeat(0, 2);
NDArray validLens = manager.create(new float[] {2, 6});

// `attention` must be declared before use; run one forward pass to produce
// the (2, 1, 4) output shown below and populate attention.attentionWeights.
AdditiveAttention attention = new AdditiveAttention(8, 0.1f);
NDList input = new NDList(queries, keys, values, validLens);
ParameterStore ps = new ParameterStore(manager, false);
attention.initialize(manager, DataType.FLOAT32, input.getShapes());
attention.forward(ps, input, false).head();

ND: (2, 1, 4) gpu(0) float32
[[[ 2.,  3.,  4.,  5.],
],
[[10., 11., 12., 13.],
],
]


// Visualize the additive-attention weights as a heatmap: reshaped to
// (1, 1, 2, 10), rows are the two queries (one per batch item) and the
// 10 columns are the keys.
PlotUtils.showHeatmaps(
attention.attentionWeights.reshape(1, 1, 2, 10),
"Keys",
"Queries",
new String[] {""},
500,
700);


## 10.2.3. 缩放点积注意力¶

(10.2.5)$a(\mathbf q, \mathbf k) = \mathbf{q}^\top \mathbf{k} /\sqrt{d}$

(10.2.6)$\mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.$

/* Scaled dot product attention. */
public static class DotProductAttention extends AbstractBlock {

private Dropout dropout;
public NDArray attentionWeights;

public DotProductAttention(float dropout) {
this.dropout = Dropout.builder().optRate(dropout).build();
this.dropout.setInitializer(new UniformInitializer(0.07f), Parameter.Type.WEIGHT);
}

@Override
protected NDList forwardInternal(
ParameterStore ps,
NDList inputs,
boolean training,
PairList<String, Object> params) {
// Shape of queries: (batchSize, no. of queries, d)
// Shape of keys: (batchSize, no. of key-value pairs, d)
// Shape of values: (batchSize, no. of key-value pairs, value
// dimension)
// Shape of valid_lens: (batchSize,) or (batchSize, no. of queries)
NDArray queries = inputs.get(0);
NDArray keys = inputs.get(1);
NDArray values = inputs.get(2);
NDArray validLens = inputs.get(3);

Long d = queries.getShape().get(queries.getShape().dimension() - 1);
// Swap the last two dimensions of keys and perform batchDot
NDArray scores = queries.batchDot(keys.swapAxes(1, 2)).div(Math.sqrt(2));
NDList list = dropout.forward(ps, new NDList(attentionWeights), training, params);
}

@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public void initializeChildBlocks(
NDManager manager, DataType dataType, Shape... inputShapes) {
try (NDManager sub = manager.newSubManager()) {
NDArray queries = sub.zeros(inputShapes[0], dataType);
NDArray keys = sub.zeros(inputShapes[1], dataType);
NDArray scores = queries.batchDot(keys.swapAxes(1, 2));
dropout.initialize(manager, dataType, scores.getShape());
}
}
}


// Reuse keys/values/validLens from above; queries now have dim 2 to match
// the keys, as required by the dot product. The forward pass produces the
// (2, 1, 4) output shown below and populates productAttention.attentionWeights.
queries = manager.randomNormal(0, 1, new Shape(2, 1, 2), DataType.FLOAT32);
DotProductAttention productAttention = new DotProductAttention(0.5f);
input = new NDList(queries, keys, values, validLens);
productAttention.initialize(manager, DataType.FLOAT32, input.getShapes());
productAttention.forward(ps, input, false).head();

ND: (2, 1, 4) gpu(0) float32
[[[ 2.,  3.,  4.,  5.],
],
[[10., 11., 12., 13.],
],
]


// Visualize the scaled dot-product attention weights as a heatmap:
// reshaped to (1, 1, 2, 10), rows are the two queries and the 10 columns
// are the keys. Since the keys are all identical, the unmasked weights
// within each row are uniform.
PlotUtils.showHeatmaps(
productAttention.attentionWeights.reshape(1, 1, 2, 10),
"Keys",
"Queries",
new String[] {""},
500,
700);


## 10.2.4. 小结¶

• 可以将注意力汇聚的输出计算作为值的加权平均，选择不同的注意力评分函数会带来不同的注意力汇聚操作。

• 当查询和键是不同长度的矢量时，可以使用可加性注意力评分函数。当它们的长度相同时，使用缩放的“点－积”注意力评分函数的计算效率更高。

## 10.2.5. 练习¶

1. 修改小例子中的键，并且可视化注意力权重。可加性注意力和缩放的“点－积”注意力是否仍然产生相同的结果？为什么？

2. 只使用矩阵乘法，您能否为具有不同矢量长度的查询和键设计新的评分函数？

3. 当查询和键具有相同的矢量长度时，矢量求和作为评分函数是否比“点－积”更好？为什么？