Run this notebook online:Binder or Colab: Colab

10.4. 自注意力和位置编码

在深度学习中,我们经常使用卷积神经网络(CNN)或循环神经网络(RNN)对序列进行编码。现在想象一下,有了注意力机制之后,我们将词元序列输入注意力池化中,以便同一组词元同时充当查询、键和值。具体来说,每个查询都会关注所有的键-值对并生成一个注意力输出。由于查询、键和值来自同一组输入,因此被称为 自注意力(self-attention [Lin.Feng.Santos.ea.2017][Vaswani et al., 2017]),也被称为 内部注意力(intra-attention [Cheng.Dong.Lapata.2016][Parikh et al., 2016][Paulus.Xiong.Socher.2017])。在本节中,我们将讨论使用自注意力进行序列编码,包括使用序列的顺序作为补充信息。

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java

%load ../utils/attention/Chap10Utils.java
%load ../utils/attention/DotProductAttention.java
%load ../utils/attention/MultiHeadAttention.java
%load ../utils/attention/PositionalEncoding.java
NDManager manager = NDManager.newBaseManager();

10.4.1. 自注意力

给定一个由词元组成的输入序列 \(\mathbf{x}_1, \ldots, \mathbf{x}_n\),其中任意 \(\mathbf{x}_i \in \mathbb{R}^d\) (\(1 \leq i \leq n\))。该序列的自注意力输出为一个长度相同的序列 \(\mathbf{y}_1, \ldots, \mathbf{y}_n\),其中:

(10.4.1)\[\mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d\]

根据 () 中定义的注意力池化函数 \(f\) 。下面的代码片段是基于多头注意力对一个 NDArray 完成自注意力的计算,NDArray 的形状为 (批量大小, 时间步的数目或词元序列的长度 \(d\)) 。输出与输入的 NDArray 形状相同。

int numHiddens = 100;
int numHeads = 5;
MultiHeadAttention attention = new MultiHeadAttention(numHiddens, numHeads, 0.5f, false);
int batchSize = 2;
int numQueries = 4;
NDArray validLens = manager.create(new float[] {3, 2});
NDArray X = manager.ones(new Shape(batchSize, numQueries, numHiddens));
ParameterStore ps = new ParameterStore(manager, false);
NDList input = new NDList(X, X, X, validLens);
attention.initialize(manager, DataType.FLOAT32, input.getShapes());
NDList result = attention.forward(ps, input, false);
result.get(0).getShape()
(2, 4, 100)

10.4.2. 比较卷积神经网络、循环神经网络和自注意力

让我们比较下面几个架构,目标都是将由 \(n\) 个词元组成的序列映射到另一个长度相等的序列,其中的每个输入词元或输出词元都由 \(d\) 维向量表示。具体来说,我们将比较的是卷积神经网络、循环神经网络和自注意力这几个架构的计算复杂性、顺序操作和最大路径长度。请注意,顺序操作会妨碍并行计算,而任意的序列位置组合之间的路径越短,则能更轻松地学习序列中的远距离依赖关系 [Hochreiter.Bengio.Frasconi.ea.2001]

比较卷积神经网络(填充词元被忽略)、循环神经网络和自注意力三种架构。

考虑一个卷积核大小为 \(k\) 的卷积层。我们将在后面的章节中提供关于使用卷积神经网络处理序列的更多详细信息。目前,我们只需要知道,由于序列长度是 \(n\) ,输入和输出的通道数量都是 \(d\),所以卷积层的计算复杂度为 \(\mathcal{O}(knd^2)\) 。如 Section 10.4.2 所示,卷积神经网络是分层的,因此为有 \(\mathcal{O}(1)\) 个顺序操作,最大路径长度为 \(\mathcal{O}(n/k)\)。例如,\(\mathbf{x}_1\)\(\mathbf{x}_5\) 处于 Section 10.4.2 中卷积核大小为3的双层卷积神经网络的感受野内。

当更新循环神经网络的隐藏状态时,\(d \times d\) 权重矩阵和 \(d\) 维隐藏状态的乘法计算复杂度为 \(\mathcal{O}(d^2)\) 。由于序列长度为 \(n\),因此循环神经网络层的计算复杂度为 \(\mathcal{O}(nd^2)\)。根据 Section 10.4.2,有 \(\mathcal{O}(n)\) 个顺序操作无法并行化,最大路径长度也是 \(\mathcal{O}(n)\)

在自注意力中,查询、键和值都是 \(n \times d\) 矩阵。考虑 (10.2.6) 中缩放的”点-积“注意力,其中 \(n \times d\) 矩阵乘以 \(d \times n\) 矩阵,然后输出的 \(n \times n\) 矩阵乘以 \(n \times d\) 矩阵。因此,自注意力具有 \(\mathcal{O}(n^2d)\) 计算复杂性。正如我们在 Section 10.4.2 中看到的那样,每个词元都通过自注意力直接连接到任何其他词元。因此,有 \(\mathcal{O}(1)\) 个顺序操作可以并行计算,最大路径长度也是 \(\mathcal{O}(1)\)

总而言之,卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

10.4.3. 位置编码

在处理词元序列时,循环神经网络是逐个的重复地处理词元的,而自注意力则因为并行计算而放弃了顺序操作。为了使用序列的顺序信息,我们通过在输入表示中添加 位置编码(positional encoding) 来注入绝对的或相对的位置信息。位置编码可以通过学习得到也可以直接固定得到。接下来,我们描述的是基于正弦函数和余弦函数的固定位置编码 [Vaswani et al., 2017]

假设输入表示 \(\mathbf{X} \in \mathbb{R}^{n \times d}\) 包含一个序列中 \(n\) 个词元的 \(d\) 维嵌入表示。位置编码使用相同形状的位置嵌入矩阵 \(\mathbf{X} + \mathbf{P}\) 输出 \(\mathbf{P} \in \mathbb{R}^{n \times d}\),矩阵第 \(i^\mathrm{th}\) 行、第 \((2j)^\mathrm{th}\) 列和 \((2j + 1)^\mathrm{th}\) 列上的元素为:

(10.4.2)\[\begin{split}\begin{aligned} p_{i, 2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right),\\p_{i, 2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right).\end{aligned}\end{split}\]

乍一看,这种基于三角函数的设计看起来很奇怪。在解释这个设计之前,让我们先在下面的 PositionalEncoding 类中实现它。

public class PositionalEncoding extends AbstractBlock {

    private Dropout dropout;
    public NDArray P;

    public PositionalEncoding(int numHiddens, float dropout, int maxLen, NDManager manager) {
        this.dropout = Dropout.builder().optRate(dropout).build();
        addChildBlock("dropout", this.dropout);

        // Create a long enough `P`
        P = manager.zeros(new Shape(1, maxLen, numHiddens));
        NDArray X =
                manager.arange(maxLen)
                        .reshape(-1, 1)
                        .div(
                                manager.create(10000)
                                        .pow(manager.arange(0, numHiddens, 2).div(numHiddens)));
        P.set(new NDIndex(":, :, {}::{}", 0, 2), X.sin());
        P.set(new NDIndex(":, :, {}::{}", 1, 2), X.cos());
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray X = inputs.get(0);
        X = X.add(P.get(":, :{}, :", X.getShape().get(1)));
        return new NDList(
                dropout.forward(parameterStore, new NDList(X), training, params).get(0));
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        throw new UnsupportedOperationException("Not implemented");
    }

    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        try (NDManager sub = manager.newSubManager()) {
            NDArray X = sub.zeros(inputShapes[0], dataType);
            X = X.add(P.get(":, :{}, :", X.getShape().get(1)));
            dropout.initialize(manager, dataType, X.getShape());
        }
    }
}

在位置嵌入矩阵 \(\mathbf{P}\) 中,行代表词元在序列中的位置,列代表位置编码的不同维度。在下面的例子中,我们可以看到位置嵌入矩阵的第 6 列和第 7列的频率高于第 8 列和第 9 列。第 6 列和第 7 列之间的偏移量(第 8 列和 第 9 列相同)是由于正弦函数和余弦函数的交替。

int encodingDim = 32;
int numSteps = 60;
PositionalEncoding posEncoding = new PositionalEncoding(encodingDim, 0, 1000, manager);
input = new NDList(manager.zeros(new Shape(1, numSteps, encodingDim)));
X = posEncoding.forward(ps, input, false).get(0);
NDArray P = posEncoding.P.get(new NDIndex(":, :{}, :", X.getShape().get(1)));

double[][] plotX = new double[4][];
double[][] plotY = new double[4][];
for (int i = 0; i < 4; i++) {
    if (i == 0) {
        plotX[i] = manager.arange(numSteps).toType(DataType.FLOAT64, false).toDoubleArray();
    } else {
        plotX[i] = plotX[i - 1];
    }
    plotY[i] =
            Functions.floatToDoubleArray(
                    P.get(new NDIndex("0, :, {},", i + 6)).toFloatArray());
}


PlotUtils.plot(
        plotX,
        plotY,
        new String[] {"Col6", "Col7", "Col8", "Col9"},
        "Row (position)",
        "");

10.4.3.1. 绝对位置信息

为了明白沿着编码维度单调降低的频率与绝对位置信息的关系,让我们打印出 \(0, 1, \ldots, 7\) 的二进制表示形式。正如我们所看到的,每个数字、每两个数字和每四个数字上的比特值在第一个最低位、第二个最低位和第三个最低位上分别交替。

for (int i = 0; i < 8; i++) {
    System.out.println(i + " in binary is " + Integer.toBinaryString(i));
}
0 in binary is 0
1 in binary is 1
2 in binary is 10
3 in binary is 11
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111

在二进制表示中,较高比特位的交替频率低于较低比特位,与下面的热图所示相似,只是位置编码通过使用三角函数在编码维度上降低频率。由于输出是浮点数,因此此类连续表示比二进制表示法更节省空间。

P = P.get(new NDIndex("0, :, :")).expandDims(0).expandDims(0);
PlotUtils.showHeatmaps(
        P, "Column (encoding dimension)", "Row (position)", new String[] {""}, 500, 700);

10.4.3.2. 相对位置信息

除了捕获绝对位置信息之外,上述的位置编码还允许模型学习得到输入序列中相对位置信息。这是因为对于任何确定的位置偏移 \(\delta\),位置 \(i + \delta\) 处的位置编码可以线性投影位置 \(i\) 处的位置编码来表示。

这种投影的数学解释是,令 \(\omega_j = 1/10000^{2j/d}\),对于任何确定的位置偏移 \(\delta\),:eqref:eq_positional-encoding-def 中的任何一对 \((p_{i, 2j}, p_{i, 2j+1})\) 都可以线性投影到 \((p_{i+\delta, 2j}, p_{i+\delta, 2j+1})\)

(10.4.3)\[\begin{split}\begin{aligned} &\begin{bmatrix} \cos(\delta \omega_j) & \sin(\delta \omega_j) \\ -\sin(\delta \omega_j) & \cos(\delta \omega_j) \\ \end{bmatrix} \begin{bmatrix} p_{i, 2j} \\ p_{i, 2j+1} \\ \end{bmatrix}\\ =&\begin{bmatrix} \cos(\delta \omega_j) \sin(i \omega_j) + \sin(\delta \omega_j) \cos(i \omega_j) \\ -\sin(\delta \omega_j) \sin(i \omega_j) + \cos(\delta \omega_j) \cos(i \omega_j) \\ \end{bmatrix}\\ =&\begin{bmatrix} \sin\left((i+\delta) \omega_j\right) \\ \cos\left((i+\delta) \omega_j\right) \\ \end{bmatrix}\\ =& \begin{bmatrix} p_{i+\delta, 2j} \\ p_{i+\delta, 2j+1} \\ \end{bmatrix}, \end{aligned}\end{split}\]

\(2\times 2\) 投影矩阵不依赖于任何位置的索引 \(i\)

10.4.4. 小结

  • 在自注意力中,查询、键和值都来自同一组输入。

  • 卷积神经网络和自注意力都拥有并行计算的优势,而且自注意力的最大路径长度最短。但是因为其计算复杂度是关于序列长度的二次方,所以在很长的序列中计算会非常慢。

  • 为了使用序列的顺序信息,我们可以通过在输入表示中添加位置编码来注入绝对的或相对的位置信息。

10.4.5. 练习

  1. 假设我们设计一个深度架构,通过堆叠基于位置编码的自注意力层来表示序列。可能会存在什么问题?

  2. 你能设计一种可学习的位置编码方法吗?