Run this notebook online:Binder or Colab: Colab

9.1. 门控循环单元(GRU)

Section 8.7中, 我们讨论了如何在循环神经网络中计算梯度, 以及矩阵连续乘积可以导致梯度消失或梯度爆炸的问题。 下面我们简单思考一下这种梯度异常在实践中的意义:

  • 我们可能会遇到这样的情况:早期观测值对预测所有未来观测值具有非常重要的意义。 考虑一个极端情况,其中第一个观测值包含一个校验和, 目标是在序列的末尾辨别校验和是否正确。 在这种情况下,第一个词元的影响至关重要。 我们希望有某些机制能够在一个记忆元里存储重要的早期信息。 如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度, 因为它会影响所有后续的观测值。

  • 我们可能会遇到这样的情况:一些词元没有相关的观测值。 例如,在对网页内容进行情感分析时, 可能有一些辅助HTML代码与网页传达的情绪无关。 我们希望有一些机制来跳过隐状态表示中的此类词元。

  • 我们可能会遇到这样的情况:序列的各个部分之间存在逻辑中断。 例如,书的章节之间可能会有过渡存在, 或者证券的熊市和牛市之间可能会有过渡存在。 在这种情况下,最好有一种方法来重置我们的内部状态表示。

在学术界已经提出了许多方法来解决这类问题。 其中最早的方法是”长短期记忆”(long-short-term memory,LSTM) [Hochreiter & Schmidhuber, 1997], 我们将在 Section 9.2中讨论。 门控循环单元(gated recurrent unit,GRU) [Cho et al., 2014] 是一个稍微简化的变体,通常能够提供同等的效果, 并且计算 [Chung et al., 2014]的速度明显更快。 由于门控循环单元更简单,我们从它开始解读。

9.1.1. 门控隐状态

门控循环单元与普通的循环神经网络之间的关键区别在于: 后者支持隐状态的门控。 这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。 这些机制是可学习的,并且能够解决了上面列出的问题。 例如,如果第一个词元非常重要, 模型将学会在第一次观测之后不更新隐状态。 同样,模型也可以学会跳过不相关的临时观测。 最后,模型还将学会在需要的时候重置隐状态。 下面我们将详细讨论各类门控。

9.1.1.1. 重置门和更新门

我们首先介绍重置门(reset gate)和更新门(update gate)。 我们把它们设计成\((0, 1)\)区间中的向量, 这样我们就可以进行凸组合。 重置门允许我们控制“可能还想记住”的过去状态的数量; 更新门将允许我们控制新状态中有多少个是旧状态的副本。

我们从构造这些门控开始。 fig_gru_1 描述了门控循环单元中的重置门和更新门的输入, 输入是由当前时间步的输入和前一时间步的隐状态给出。 两个门的输出是由使用sigmoid激活函数的两个全连接层给出。

在门控循环单元模型中计算重置门和更新门 .. _fig_gru_1:

我们来看一下门控循环单元的数学表达。 对于给定的时间步\(t\),假设输入是一个小批量 \(\mathbf{X}_t \in \mathbb{R}^{n \times d}\) (样本个数:\(n\),输入个数:\(d\)), 上一个时间步的隐状态是 \(\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}\) (隐藏单元个数:\(h\))。 那么,重置门\(\mathbf{R}_t \in \mathbb{R}^{n \times h}\)和 更新门\(\mathbf{Z}_t \in \mathbb{R}^{n \times h}\)的计算如下所示:

(9.1.1)\[\begin{split}\begin{aligned} \mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\ \mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z), \end{aligned}\end{split}\]

其中\(\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}\)\(\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}\)是权重参数, \(\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}\)是偏置参数。 请注意,在求和过程中会触发广播机制 (请参阅 subsec_broadcasting)。 我们使用sigmoid函数(如 Section 4.1中介绍的) 将输入值转换到区间\((0, 1)\)

9.1.1.2. 候选隐状态

接下来,让我们将重置门\(\mathbf{R}_t\)(8.4.5) 中的常规隐状态更新机制集成, 得到在时间步\(t\)候选隐状态(candidate hidden state) \(\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}\)

(9.1.2)\[\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),\]

其中\(\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}\)\(\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}\)是权重参数, \(\mathbf{b}_h \in \mathbb{R}^{1 \times h}\)是偏置项, 符号\(\odot\)是Hadamard积(按元素乘积)运算符。 在这里,我们使用tanh非线性激活函数来确保候选隐状态中的值保持在区间\((-1, 1)\)中。

(8.4.5)相比, (9.1.2)中的\(\mathbf{R}_t\)\(\mathbf{H}_{t-1}\) 的元素相乘可以减少以往状态的影响。 每当重置门\(\mathbf{R}_t\)中的项接近\(1\)时, 我们恢复一个如 (8.4.5)中的普通的循环神经网络。 对于重置门\(\mathbf{R}_t\)中所有接近\(0\)的项, 候选隐状态是以\(\mathbf{X}_t\)作为输入的多层感知机的结果。 因此,任何预先存在的隐状态都会被重置为默认值。

fig_gru_2说明了应用重置门之后的计算流程。

在门控循环单元模型中计算候选隐状态 .. _fig_gru_2:

9.1.1.3. 隐状态

上述的计算结果只是候选隐状态,我们仍然需要结合更新门\(\mathbf{Z}_t\)的效果。 这一步确定新的隐状态\(\mathbf{H}_t \in \mathbb{R}^{n \times h}\) 在多大程度上来自旧的状态\(\mathbf{H}_{t-1}\)和 新的候选状态\(\tilde{\mathbf{H}}_t\)。 更新门\(\mathbf{Z}_t\)仅需要在 \(\mathbf{H}_{t-1}\)\(\tilde{\mathbf{H}}_t\) 之间进行按元素的凸组合就可以实现这个目标。 这就得出了门控循环单元的最终更新公式:

(9.1.3)\[\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.\]

每当更新门\(\mathbf{Z}_t\)接近\(1\)时,模型就倾向只保留旧状态。 此时,来自\(\mathbf{X}_t\)的信息基本上被忽略, 从而有效地跳过了依赖链条中的时间步\(t\)。 相反,当\(\mathbf{Z}_t\)接近\(0\)时, 新的隐状态\(\mathbf{H}_t\)就会接近候选隐状态\(\tilde{\mathbf{H}}_t\)。 这些设计可以帮助我们处理循环神经网络中的梯度消失问题, 并更好地捕获时间步距离很长的序列的依赖关系。 例如,如果整个子序列的所有时间步的更新门都接近于\(1\), 则无论序列的长度如何,在序列起始时间步的旧隐状态都将很容易保留并传递到序列结束。

fig_gru_3说明了更新门起作用后的计算流。

计算门控循环单元模型中的隐状态 .. _fig_gru_3:

总之,门控循环单元具有以下两个显著特征:

  • 重置门有助于捕获序列中的短期依赖关系。

  • 更新门有助于捕获序列中的长期依赖关系。

9.1.2. 从零开始实现

为了更好地理解门控循环单元模型,我们从零开始实现它。 首先,我们读取 Section 8.5中使用的时间机器数据集:

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java

%load ../utils/StopWatch.java
%load ../utils/Accumulator.java
%load ../utils/Animator.java
%load ../utils/Training.java
%load ../utils/timemachine/Vocab.java
%load ../utils/timemachine/RNNModel.java
%load ../utils/timemachine/RNNModelScratch.java
%load ../utils/timemachine/TimeMachine.java
%load ../utils/timemachine/TimeMachineDataset.java
NDManager manager = NDManager.newBaseManager();
int batchSize = 32;
int numSteps = 35;

TimeMachineDataset dataset =
        new TimeMachineDataset.Builder()
                .setManager(manager)
                .setMaxTokens(10000)
                .setSampling(batchSize, false)
                .setSteps(numSteps)
                .build();
dataset.prepare();
Vocab vocab = dataset.getVocab();

9.1.2.1. 初始化模型参数

下一步是初始化模型参数。 我们从标准差为\(0.01\)的高斯分布中提取权重, 并将偏置项设为\(0\),超参数numHiddens定义隐藏单元的数量, 实例化与更新门、重置门、候选隐状态和输出层相关的所有权重和偏置。

public static NDArray normal(Shape shape, Device device) {
    return manager.randomNormal(0, 0.01f, shape, DataType.FLOAT32, device);
}

public static NDList three(int numInputs, int numHiddens, Device device) {
    return new NDList(
            normal(new Shape(numInputs, numHiddens), device),
            normal(new Shape(numHiddens, numHiddens), device),
            manager.zeros(new Shape(numHiddens), DataType.FLOAT32, device));
}

public static NDList getParams(int vocabSize, int numHiddens, Device device) {
    int numInputs = vocabSize;
    int numOutputs = vocabSize;

    // Update gate parameters
    NDList temp = three(numInputs, numHiddens, device);
    NDArray W_xz = temp.get(0);
    NDArray W_hz = temp.get(1);
    NDArray b_z = temp.get(2);

    // Reset gate parameters
    temp = three(numInputs, numHiddens, device);
    NDArray W_xr = temp.get(0);
    NDArray W_hr = temp.get(1);
    NDArray b_r = temp.get(2);

    // Candidate hidden state parameters
    temp = three(numInputs, numHiddens, device);
    NDArray W_xh = temp.get(0);
    NDArray W_hh = temp.get(1);
    NDArray b_h = temp.get(2);

    // Output layer parameters
    NDArray W_hq = normal(new Shape(numHiddens, numOutputs), device);
    NDArray b_q = manager.zeros(new Shape(numOutputs), DataType.FLOAT32, device);

    // Attach gradients
    NDList params = new NDList(W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q);
    for (NDArray param : params) {
        param.setRequiresGradient(true);
    }
    return params;
}

9.1.2.2. 定义模型

现在我们将定义隐状态的初始化函数 initGruState()。 与 Section 8.5中定义的initRnnState()函数一样, 此函数返回一个形状为(批量大小,隐藏单元个数)的 NDArray,NDArray 的值全部为零。

public static NDList initGruState(int batchSize, int numHiddens, Device device) {
    return new NDList(manager.zeros(new Shape(batchSize, numHiddens), DataType.FLOAT32, device));
}

现在我们准备定义门控循环单元模型, 模型的架构与基本的循环神经网络单元是相同的, 只是权重更新公式更为复杂。

public static Pair<NDArray, NDList> gru(NDArray inputs, NDList state, NDList params) {
    NDArray W_xz = params.get(0);
    NDArray W_hz = params.get(1);
    NDArray b_z = params.get(2);

    NDArray W_xr = params.get(3);
    NDArray W_hr = params.get(4);
    NDArray b_r = params.get(5);

    NDArray W_xh = params.get(6);
    NDArray W_hh = params.get(7);
    NDArray b_h = params.get(8);

    NDArray W_hq = params.get(9);
    NDArray b_q  = params.get(10);

    NDArray H = state.get(0);
    NDList outputs = new NDList();
    NDArray X, Y, Z, R, H_tilda;
    for (int i = 0; i < inputs.size(0); i++) {
        X = inputs.get(i);
        Z = Activation.sigmoid(X.dot(W_xz).add(H.dot(W_hz).add(b_z)));
        R = Activation.sigmoid(X.dot(W_xr).add(H.dot(W_hr).add(b_r)));
        H_tilda = Activation.tanh(X.dot(W_xh).add(R.mul(H).dot(W_hh).add(b_h)));
        H = Z.mul(H).add(Z.mul(-1).add(1).mul(H_tilda));
        Y = H.dot(W_hq).add(b_q);
        outputs.add(Y);
    }
    return new Pair(outputs.size() > 1 ? NDArrays.concat(outputs) : outputs.get(0), new NDList(H));
}

9.1.2.3. 训练与预测

训练和预测的工作方式与 Section 8.5完全相同。 训练结束后,我们分别打印输出训练集的困惑度, 以及前缀“time traveler”和“traveler”的预测序列上的困惑度。

int vocabSize = vocab.length();
int numHiddens = 256;
Device device = manager.getDevice();
int numEpochs = Integer.getInteger("MAX_EPOCH", 500);

int lr = 1;

Functions.TriFunction<Integer, Integer, Device, NDList> getParamsFn = (a, b, c) -> getParams(a, b, c);
Functions.TriFunction<Integer, Integer, Device, NDList> initGruStateFn =
        (a, b, c) -> initGruState(a, b, c);
Functions.TriFunction<NDArray, NDList, NDList, Pair<NDArray, NDList>> gruFn = (a, b, c) -> gru(a, b, c);

RNNModelScratch model =
        new RNNModelScratch(vocabSize, numHiddens, device,
                getParamsFn, initGruStateFn, gruFn);
TimeMachine.trainCh8(model, dataset, vocab, lr, numEpochs, device, false, manager);
perplexity: 1.0, 14074.5 tokens/sec on gpu(0)
time traveller the walk touscens of ppatinal of shayline so gonk
travellerim therely eremt you dire sithtat rouly on the fir

9.1.3. 简洁实现

高级API包含了前文介绍的所有配置细节, 所以我们可以直接实例化门控循环单元模型。 这段代码的运行速度要快得多, 因为它使用的是编译好的运算符而不是单个的 NDArray 运算来处理之前阐述的许多细节。

GRU gruLayer = GRU.builder().setNumLayers(1)
        .setStateSize(numHiddens).optReturnState(true).optBatchFirst(false).build();
RNNModel modelConcise = new RNNModel(gruLayer,vocab.length());
TimeMachine.trainCh8(modelConcise, dataset, vocab, lr, numEpochs, device, false, manager);
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.8.0 in 0.066 ms.
perplexity: 1.0, 83413.0 tokens/sec on gpu(0)
time travellerit s against reason said filby an a gume back diff
travelleryou cand and hend his arathohit ies and have be ex

9.1.4. 小结

  • 门控循环神经网络可以更好地捕获时间步距离很长的序列上的依赖关系。

  • 重置门有助于捕获序列中的短期依赖关系。

  • 更新门有助于捕获序列中的长期依赖关系。

  • 重置门打开时,门控循环单元包含基本循环神经网络;更新门打开时,门控循环单元可以跳过子序列。

9.1.5. 练习

  1. 假设我们只想使用时间步\(t'\)的输入来预测时间步\(t > t'\)的输出。对于每个时间步,重置门和更新门的最佳值是什么?

  2. 调整和分析超参数对运行时间、困惑度和输出顺序的影响。

  3. 比较rnn.RNNrnn.GRU的不同实现对运行时间、困惑度和输出字符串的影响。

  4. 如果仅仅实现门控循环单元的一部分,例如,只有一个重置门或一个更新门会怎样?