Run this notebook online:\ |Binder| or Colab: |Colab| .. |Binder| image:: https://mybinder.org/badge_logo.svg :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_attention-mechanisms/multihead-attention.ipynb .. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_attention-mechanisms/multihead-attention.ipynb .. _sec_multihead-attention: .. _fig_multi-head-attention: 多头注意力 ========== 在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,例如捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依赖)。因此,允许注意力机制组合使用查询、键和值的不同 子空间表示(representation subspaces)可能是有益的。 为此,与使用单独一个注意力汇聚不同,我们可以用独立学习得到的 :math:`h` 组不同的 线性投影(linear projections)来变换查询、键和值。然后,这 :math:`h` 组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这 :math:`h` 个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为 多头注意力,其中 :math:`h` 个注意力汇聚输出中的每一个输出都被称作一个 *头(head)* :cite:`Vaswani.Shazeer.Parmar.ea.2017`\ 。 图 :numref:`fig_multi-head-attention` 展示了使用全连接层来实现可学习的线性变换的多头注意力。 |多头注意力,多个头连结然后线性变换。| 模型 ---- 在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询 :math:`\mathbf{q} \in \mathbb{R}^{d_q}`\ 、键 :math:`\mathbf{k} \in \mathbb{R}^{d_k}` 和值 :math:`\mathbf{v} \in \mathbb{R}^{d_v}`\ ,每个注意力头 :math:`\mathbf{h}_i` (:math:`i = 1, \ldots, h`) 的计算方法为: .. math:: \mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v}, 其中,可学习的参数包括 :math:`\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}`\ 、\ :math:`\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}` 和 :math:`\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}`\ ,以及注意力汇聚的函数 :math:`f`\ 。\ :math:`f` 可以是 :numref:`sec_attention-scoring-functions` 中的加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着 :math:`h` 个头连结后的结果,因此其可学习参数是 :math:`\mathbf W_o\in\mathbb R^{p_o\times h p_v}`\ : .. math:: \mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}. 基于这种设计,每个头都可能会关注输入的不同部分。可以表示比简单加权平均值更复杂的函数。 .. |多头注意力,多个头连结然后线性变换。| image:: https://zh-v2.d2l.ai/_images/multi-head-attention.svg .. code:: java %load ../utils/djl-imports %load ../utils/plot-utils %load ../utils/Functions.java %load ../utils/PlotUtils.java %load ../utils/attention/Chap10Utils.java %load ../utils/attention/DotProductAttention.java %load ../utils/attention/MultiHeadAttention.java %load ../utils/attention/PositionalEncoding.java .. code:: java NDManager manager = NDManager.newBaseManager(); 为了允许并行计算多个头,下面的\ ``MultiHeadAttention`` 类使用两个如下定义的转置函数。具体来说,\ ``transposeOutput`` 函数颠倒了 ``transposeQkv`` 函数的操作。 .. code:: java public static NDArray transposeQkv(NDArray X, int numHeads) { // Shape of input `X`: // (`batchSize`, no. of queries or key-value pairs, `numHiddens`). // Shape of output `X`: // (`batchSize`, no. of queries or key-value pairs, `numHeads`, // `numHiddens` / `numHeads`) X = X.reshape(X.getShape().get(0), X.getShape().get(1), numHeads, -1); // Shape of output `X`: // (`batchSize`, `numHeads`, no. of queries or key-value pairs, // `numHiddens` / `numHeads`) X = X.transpose(0, 2, 1, 3); // Shape of `output`: // (`batchSize` * `numHeads`, no. of queries or key-value pairs, // `numHiddens` / `numHeads`) return X.reshape(-1, X.getShape().get(2), X.getShape().get(3)); } public static NDArray transposeOutput(NDArray X, int numHeads) { X = X.reshape(-1, numHeads, X.getShape().get(1), X.getShape().get(2)); X = X.transpose(0, 2, 1, 3); return X.reshape(X.getShape().get(0), X.getShape().get(1), -1); } 实现 ---- 在实现过程中,我们选择缩放点积注意力作为每一个注意力头。为了避免计算成本和参数数量的大幅增长,我们设定 :math:`p_q = p_k = p_v = p_o / h`\ 。值得注意的是,如果我们将查询、键和值的线性变换的输出数量设置为 :math:`p_q h = p_k h = p_v h = p_o`\ ,则可以并行计算 :math:`h` 个头。在下面的实现中,\ :math:`p_o` 是通过参数 ``numHiddens`` 指定的。 .. code:: java public static class MultiHeadAttention extends AbstractBlock { private int numHeads; public DotProductAttention attention; private Linear W_k; private Linear W_q; private Linear W_v; private Linear W_o; public MultiHeadAttention(int numHiddens, int numHeads, float dropout, boolean useBias) { this.numHeads = numHeads; attention = new DotProductAttention(dropout); W_q = Linear.builder().setUnits(numHiddens).optBias(useBias).build(); addChildBlock("W_q", W_q); W_k = Linear.builder().setUnits(numHiddens).optBias(useBias).build(); addChildBlock("W_k", W_k); W_v = Linear.builder().setUnits(numHiddens).optBias(useBias).build(); addChildBlock("W_v", W_v); W_o = Linear.builder().setUnits(numHiddens).optBias(useBias).build(); addChildBlock("W_o", W_o); Dropout dropout1 = Dropout.builder().optRate(dropout).build(); addChildBlock("dropout", dropout1); } @Override protected NDList forwardInternal( ParameterStore ps, NDList inputs, boolean training, PairList params) { // Shape of `queries`, `keys`, or `values`: // (`batchSize`, no. of queries or key-value pairs, `numHiddens`) // Shape of `validLens`: // (`batchSize`,) or (`batchSize`, no. of queries) // After transposing, shape of output `queries`, `keys`, or `values`: // (`batchSize` * `numHeads`, no. of queries or key-value pairs, // `numHiddens` / `numHeads`) NDArray queries = inputs.get(0); NDArray keys = inputs.get(1); NDArray values = inputs.get(2); NDArray validLens = inputs.get(3); // On axis 0, copy the first item (scalar or vector) for // `numHeads` times, then copy the next item, and so on validLens = validLens.repeat(0, numHeads); queries = transposeQkv( W_q.forward(ps, new NDList(queries), training, params).get(0), numHeads); keys = transposeQkv( W_k.forward(ps, new NDList(keys), training, params).get(0), numHeads); values = transposeQkv( W_v.forward(ps, new NDList(values), training, params).get(0), numHeads); // Shape of `output`: (`batchSize` * `numHeads`, no. of queries, // `numHiddens` / `numHeads`) NDArray output = attention .forward( ps, new NDList(queries, keys, values, validLens), training, params) .get(0); // Shape of `outputConcat`: // (`batchSize`, no. of queries, `numHiddens`) NDArray outputConcat = transposeOutput(output, numHeads); return new NDList(W_o.forward(ps, new NDList(outputConcat), training, params).get(0)); } @Override public Shape[] getOutputShapes(Shape[] inputShapes) { throw new UnsupportedOperationException("Not implemented"); } @Override public void initializeChildBlocks( NDManager manager, DataType dataType, Shape... inputShapes) { try (NDManager sub = manager.newSubManager()) { NDArray queries = sub.zeros(inputShapes[0], dataType); NDArray keys = sub.zeros(inputShapes[1], dataType); NDArray values = sub.zeros(inputShapes[2], dataType); NDArray validLens = sub.zeros(inputShapes[3], dataType); validLens = validLens.repeat(0, numHeads); ParameterStore ps = new ParameterStore(sub, false); W_q.initialize(manager, dataType, queries.getShape()); W_k.initialize(manager, dataType, keys.getShape()); W_v.initialize(manager, dataType, values.getShape()); queries = transposeQkv(W_q.forward(ps, new NDList(queries), false).get(0), numHeads); keys = transposeQkv(W_k.forward(ps, new NDList(keys), false).get(0), numHeads); values = transposeQkv(W_v.forward(ps, new NDList(values), false).get(0), numHeads); NDList list = new NDList(queries, keys, values, validLens); attention.initialize(sub, dataType, list.getShapes()); NDArray output = attention.forward(ps, list, false).head(); NDArray outputConcat = Chap10Utils.transposeOutput(output, numHeads); W_o.initialize(manager, dataType, outputConcat.getShape()); } } } 让我们使用键和值相同的小例子来测试我们编写的 ``MultiHeadAttention`` 类。多头注意力输出的形状是 (``batchSize``, ``numQueries``, ``numHiddens``)。 .. code:: java int numHiddens = 100; int numHeads = 5; MultiHeadAttention attention = new MultiHeadAttention(numHiddens, numHeads, 0.5f, false); .. code:: java int batchSize = 2; int numQueries = 4; int numKvpairs = 6; NDArray validLens = manager.create(new float[] {3, 2}); NDArray X = manager.ones(new Shape(batchSize, numQueries, numHiddens)); NDArray Y = manager.ones(new Shape(batchSize, numKvpairs, numHiddens)); ParameterStore ps = new ParameterStore(manager, false); NDList input = new NDList(X, Y, Y, validLens); attention.initialize(manager, DataType.FLOAT32, input.getShapes()); NDList result = attention.forward(ps, input, false); result.get(0).getShape(); .. parsed-literal:: :class: output (2, 4, 100) 小结 ---- - 多头注意力融合了来自于相同的注意力汇聚产生的不同的知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。 - 基于适当的张量操作,可以实现多头注意力的并行计算。 练习 ---- 1. 分别可视化这个实验中的多个头的注意力权重。 2. 假设我们已经拥有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。应该如何设计实验来衡量注意力头的重要性?