Run this notebook online: or Colab:
注意力汇聚:Nadaraya-Watson 核回归¶
在知道了 Section 10.1.2 框架下的注意力机制的主要成分。回顾一下,查询(自主提示)和键(非自主提示)之间的交互形成了 注意力汇聚(attention pooling)。注意力汇聚有选择地聚合了值(感官输入)以生成最终的输出。在本节中,我们将介绍注意力汇聚的更多细节,以便从宏观上了解注意力机制在实践中的运作方式。具体来说,1964 年提出的 Nadaraya-Watson 核回归模型是一个简单但完整的例子,可以用于演示具有注意力机制的机器学习。
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/Animator.java
%load ../utils/PlotUtils.java
NDManager manager = NDManager.newBaseManager();
生成数据集¶
简单起见,考虑下面这个回归问题:给定的成对的 “输入-输出” 数据集 \(\{(x_1, y_1), \ldots, (x_n, y_n)\}\),如何学习 \(f\) 来预测任意新输入 \(x\) 的输出 \(\hat{y} = f(x)\) ?
根据下面的非线性函数生成一个人工数据集,其中加入的噪声项为 \(\epsilon\):
其中 \(\epsilon\) 服从均值为 0 和标准差为 0.5 的正态分布。我们生成了 50 个训练样本和 50 个测试样本。为了更好地可视化之后的注意力模式,输入的训练样本将进行排序。
int nTrain = 50; // No. of training examples
NDArray xTrain = manager.randomUniform(0, 1, new Shape(nTrain)).mul(5).sort(); // Training inputs
Function<NDArray, NDArray> f = x -> x.sin().mul(2).add(x.pow(0.8));
NDArray yTrain =
f.apply(xTrain)
.add(
manager.randomNormal(
0f,
0.5f,
new Shape(nTrain),
DataType.FLOAT32)); // Training outputs
NDArray xTest = manager.arange(0f, 5f, 0.1f); // Testing examples
NDArray yTruth = f.apply(xTest); // Ground-truth outputs for the testing examples
int nTest = (int) xTest.getShape().get(0); // No. of testing examples
System.out.println(nTest);
50
下面的函数将绘制所有的训练样本(样本由圆圈表示)、不带噪声项的真实数据生成函数
f
(标记为“Truth”)和学习得到的预测函数(标记为“Pred”)。
public static Figure plot(
NDArray yHat,
String trace1Name,
String trace2Name,
String xLabel,
String yLabel,
int width,
int height) {
ScatterTrace trace =
ScatterTrace.builder(
Functions.floatToDoubleArray(xTest.toFloatArray()),
Functions.floatToDoubleArray(yTruth.toFloatArray()))
.mode(ScatterTrace.Mode.LINE)
.name(trace1Name)
.build();
ScatterTrace trace2 =
ScatterTrace.builder(
Functions.floatToDoubleArray(xTest.toFloatArray()),
Functions.floatToDoubleArray(yHat.toFloatArray()))
.mode(ScatterTrace.Mode.LINE)
.name(trace2Name)
.build();
ScatterTrace trace3 =
ScatterTrace.builder(
Functions.floatToDoubleArray(xTrain.toFloatArray()),
Functions.floatToDoubleArray(yTrain.toFloatArray()))
.mode(ScatterTrace.Mode.MARKERS)
.marker(Marker.builder().symbol(Symbol.CIRCLE).size(15).opacity(.5).build())
.build();
Layout layout =
Layout.builder()
.height(height)
.width(width)
.showLegend(true)
.xAxis(Axis.builder().title(xLabel).domain(0, 5).build())
.yAxis(Axis.builder().title(yLabel).domain(-1, 5).build())
.build();
return new Figure(layout, trace, trace2, trace3);
}
平均汇聚¶
先使用可能是这个世界上“最愚蠢”的估计器来解决回归问题:基于平均汇聚来计算所有训练样本输出值的平均值:
如下图所示,这个估计器确实不够聪明。
NDArray yHat = yTrain.mean().tile(nTest);
plot(yHat, "Truth", "Pred", "x", "y", 700, 500);
非参数注意力汇聚¶
显然,平均汇聚忽略了输入 \(x_i\) 。于是 [Nadaraya.1964] 和 Waston [Watson.1964] 提出了一个更好的想法,根据输入的位置对输出 \(y_i\) 进行加权:
其中 \(K\) 是 核(kernel)。公式 () 中所描述的估计器被称为 Nadaraya-Watson 核回归(Nadaraya-Watson kernel regression)。在这里我们不会深入讨论核函数的细节。回想一下 Section 10.1.2 中的注意力机制框架,我们可以从注意力机制的角度重写 () 成为一个更加通用的 注意力汇聚 公式:
其中 \(x\) 是查询,\((x_i, y_i)\) 是键值对。比较 () 和 (),注意力汇聚是 \(y_i\) 的加权平均。将查询 (x) 和键 \(x_i\) 之间的关系建模为 注意力权重(attetnion weight) \(\alpha(x, x_i)\) ,如 () 所示,这个权重将被分配给每一个对应值 \(y_i\)。对于任何查询,模型在所有键值对上的注意力权重都是一个有效的概率分布:它们是非负数的,并且总和为1。
为了更好地理解注意力汇聚,仅考虑一个 高斯核(Gaussian kernel),其定义为:
在 ()中,如果一个键 \(x_i\) 越是接近给定的查询 \(x\), 那么分配给这个键对应值 \(y_i\) 的注意力权重就会越大, 也就是 获得了更多的注意力。
值得注意的是,Nadaraya-Watson
核回归是一个非参数模型。因此,:eqref:eq_nadaraya-watson-gaussian
是 非参数的注意力汇聚
的例子。接下来,我们将基于这个非参数的注意力汇聚模型来绘制预测结果。结果是预测线是平滑的,并且比平均汇聚产生的线更接近真实。
// Shape of `xRepeat`: (`nTest`, `nTrain`), where each row contains the
// same testing inputs (i.e., same queries)
NDArray xRepeat = xTest.repeat(nTrain).reshape(new Shape(-1, nTrain));
// Note that `xTrain` contains the keys. Shape of `attention_weights`:
// (`nTest`, `nTrain`), where each row contains attention weights to be
// assigned among the values (`yTrain`) given each query
NDArray attentionWeights = xRepeat.sub(xTrain).pow(2).div(2).mul(-1).softmax(-1);
// Each element of `yHat` is weighted average of values, where weights are
// attention weights
yHat = attentionWeights.dot(yTrain);
plot(yHat, "Truth", "Pred", "x", "y", 700, 500);
现在,让我们来观察注意力的权重。这里测试数据的输入相当于查询,而训练数据的输入相当于键。因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近,注意力汇聚的注意力权重就越高。
PlotUtils.showHeatmaps(
attentionWeights.expandDims(0).expandDims(0),
"Sorted training inputs",
"Sorted testing inputs",
new String[] {""},
500,
700);
带参数注意力汇聚¶
非参数的 Nadaraya-Watson 核回归具有 一致性(consistency) 的优点:如果有足够的数据,此模型会收敛到最优结果。尽管如此,我们还是可以轻松地将可学习的参数集成到注意力汇聚中。
例如,与 () 略有不同,在下面的查询 \(x_i\) 和键 \(x_i\) 之间的距离乘以可学习参数 \(w\):
在本节的余下部分,我们将通过训练这个模型 () 来学习注意力汇聚的参数。
批量矩阵乘法¶
为了更有效地计算小批量数据的注意力,我们可以利用深度学习开发框架中提供的批量矩阵乘法。
假设第一个小批量数据包含 \(n\) 个矩阵 \(\mathbf{X}_1, \ldots, \mathbf{X}_n\),形状为 \(a\times b\),第二个小批量包含 \(n\) 个矩阵 \(\mathbf{Y}_1, \ldots, \mathbf{Y}_n\),形状为 \(b\times c\)。它们的批量矩阵乘法得到 \(n\) 个矩阵 \(\mathbf{X}_1\mathbf{Y}_1, \ldots, \mathbf{X}_n\mathbf{Y}_n\) ,形状为 \(a\times c\)。因此,假定两个张量的形状分别是 (\(n\), \(a\), \(b\)) 和 (\(n\), \(b\), \(c\)),它们的批量矩阵乘法输出的形状为 (\(n\), \(a\), \(c\))。
NDArray X = manager.ones(new Shape(2, 1, 4));
NDArray Y = manager.ones(new Shape(2, 4, 6));
X.batchDot(Y).getShape()
(2, 1, 6)
在注意力机制的背景中,我们可以使用小批量矩阵乘法来计算小批量数据中的加权平均值。
NDArray weights = manager.ones(new Shape(2, 10)).mul(0.1);
NDArray values = manager.arange(20f).reshape(new Shape(2, 10));
weights.expandDims(1).batchDot(values.expandDims(-1))
ND: (2, 1, 1) gpu(0) float32
[[[ 4.5],
],
[[14.5],
],
]
定义模型¶
基于 () 中的带参数的注意力汇聚,使用小批量矩阵乘法,定义 Nadaraya-Watson 核回归的带参数版本为:
public class NWKernelRegression extends AbstractBlock {
private Parameter w;
public NDArray attentionWeights;
public NWKernelRegression() {
w = Parameter.builder().optShape(new Shape(1)).setName("w").optInitializer(new UniformInitializer()).build();
addParameter(w);
}
@Override
protected NDList forwardInternal(
ParameterStore parameterStore,
NDList inputs,
boolean training,
PairList<String, Object> params) {
// Shape of the output `queries` and `attentionWeights`:
// (no. of queries, no. of key-value pairs)
NDArray queries = inputs.get(0);
NDArray keys = inputs.get(1);
NDArray values = inputs.get(2);
queries =
queries.repeat(keys.getShape().get(1))
.reshape(new Shape(-1, keys.getShape().get(1)));
this.attentionWeights =
queries.sub(keys).mul(this.w.getArray()).pow(2).div(2).mul(-1).softmax(-1);
// Shape of `values`: (no. of queries, no. of key-value pairs)
return new NDList(
attentionWeights.expandDims(1).batchDot(values.expandDims(-1)).reshape(-1));
}
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
throw new UnsupportedOperationException("Not implemented");
}
}
训练¶
接下来,将训练数据集转换为键和值用于训练注意力模型。在带参数的注意力汇聚模型中,任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算,从而得到其对应的预测输出。
// Shape of `xTile`: (`nTrain`, `nTrain`), where each column contains the
// same training inputs
NDArray xTile = xTrain.tile(new long[] {nTrain, 1});
// Shape of `Y_tile`: (`nTrain`, `nTrain`), where each column contains the
// same training outputs
NDArray yTile = yTrain.tile(new long[] {nTrain, 1});
// Shape of `keys`: ('nTrain', 'nTrain' - 1)
NDArray keys =
xTile.get(new NDIndex().addBooleanIndex(manager.eye(nTrain).mul(-1).add(1))).reshape(new Shape(nTrain, -1));
// Shape of `values`: ('nTrain', 'nTrain' - 1)
values = yTile.get(new NDIndex().addBooleanIndex(manager.eye(nTrain).mul(-1).add(1))).reshape(new Shape(nTrain, -1));
训练带参数的注意力汇聚模型时使用平方损失函数和随机梯度下降。
NWKernelRegression net = new NWKernelRegression();
Loss loss = Loss.l2Loss();
Tracker lrt =
Tracker.fixed(0.5f * nTrain); // Since we are using sgd, to be able to put the right
// scale, we need to multiply by batchSize
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
DefaultTrainingConfig config =
new DefaultTrainingConfig(loss)
.optOptimizer(sgd) // Optimizer (loss function)
.addTrainingListeners(TrainingListener.Defaults.logging()); // Logging
Model model = Model.newInstance("");
model.setBlock(net);
Trainer trainer = model.newTrainer(config);
Animator animator = new Animator();
ParameterStore ps = new ParameterStore(manager, false);
// Create trainer and animator
for (int epoch = 0; epoch < 5; epoch++) {
try (GradientCollector gc = trainer.newGradientCollector()) {
NDArray result = net.forward(ps, new NDList(xTrain, keys, values), true).get(0);
NDArray l = trainer.getLoss().evaluate(new NDList(yTrain), new NDList(result));
gc.backward(l);
animator.add(epoch + 1, (float) l.getFloat(), "Loss");
animator.show();
}
trainer.step();
}
INFO Training on: 4 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.061 ms.
训练完带参数的注意力汇聚模型后,我们发现,在尝试拟合带噪声的训练数据时,预测结果绘制的线不如之前非参数模型的线平滑。
// Shape of `keys`: (`nTest`, `nTrain`), where each column contains the same
// training inputs (i.e., same keys)
keys = xTrain.tile(new long[] {nTest, 1});
// Shape of `value`: (`nTest`, `nTrain`)
values = yTrain.tile(new long[] {nTest, 1});
yHat = net.forward(ps, new NDList(xTest, keys, values), true).get(0);
plot(yHat, "Truth", "Pred", "x", "y", 700, 500);
与非参数的注意力汇聚模型相比,带参数的模型加入可学习的参数后,在输出结果的绘制图上,曲线在注意力权重较大的区域变得更不平滑。
PlotUtils.showHeatmaps(
net.attentionWeights.expandDims(0).expandDims(0),
"Sorted training inputs",
"Sorted testing inputs",
new String[] {""},
500,
700);
小结¶
Nadaraya-Watson 核回归是具有注意力机制的机器学习的一个例子。
Nadaraya-Watson 核回归的注意力汇聚是对训练数据中输出的加权平均。从注意力的角度来看,分配给每个值的注意力权重取决于将值所对应的键和查询作为输入的函数。
注意力汇聚可以分为非参数型和带参数型。
练习¶
增加训练数据的样本数量。能否得到更好的非参数的 Nadaraya-Watson 核回归模型?
在带参数的注意力汇聚的实验中学习得到的参数 \(w\) 的价值是什么?为什么在可视化注意力权重时,它会使加权区域更加尖锐?
如何将超参数添加到非参数的 Nadaraya-Watson 核回归中以实现更好地预测结果?
为本节的核回归设计一个新的带参数的注意力汇聚模型。训练这个新模型并可视化其注意力权重。