
# 9.2. Long Short-Term Memory (LSTM)

## 9.2.1. Gated Memory Cell

### 9.2.1.1. Input Gate, Forget Gate, and Output Gate


(9.2.1)$$\begin{aligned} \mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\ \mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ \mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o), \end{aligned}$$

### 9.2.1.2. Candidate Memory Cell

(9.2.2)$\tilde{\mathbf{C}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c),$


### 9.2.1.3. Memory Cell

(9.2.3)$\mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t.$

Fig. 9.2.1 Computing the memory cell in an LSTM model

### 9.2.1.4. Hidden State

(9.2.4)$\mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t).$
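Equations (9.2.1)–(9.2.4) can be traced end to end with a minimal scalar sketch in plain Java. All weights and biases below are made-up illustrative values, not trained parameters:

```java
public class LstmStepSketch {
    public static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // One LSTM step for a scalar input x and previous state (h, c).
    // Returns {H_t, C_t}. All coefficients are arbitrary illustrative values.
    public static double[] lstmStep(double x, double h, double c) {
        double i = sigmoid(0.5 * x + 0.1 * h);        // input gate, eq. (9.2.1)
        double f = sigmoid(0.4 * x + 0.2 * h);        // forget gate, eq. (9.2.1)
        double o = sigmoid(0.3 * x + 0.1 * h);        // output gate, eq. (9.2.1)
        double cTilda = Math.tanh(0.6 * x + 0.2 * h); // candidate cell, eq. (9.2.2)
        double cNew = f * c + i * cTilda;             // memory cell, eq. (9.2.3)
        double hNew = o * Math.tanh(cNew);            // hidden state, eq. (9.2.4)
        return new double[] {hNew, cNew};
    }

    public static void main(String[] args) {
        double[] state = {0.0, 0.0};
        for (double x : new double[] {1.0, -1.0, 0.5}) {
            state = lstmStep(x, state[0], state[1]);
        }
        System.out.printf("H=%.4f C=%.4f%n", state[0], state[1]);
    }
}
```

Note that because the hidden state is a sigmoid output multiplied by a tanh, it always stays within $(-1, 1)$, while the memory cell itself is not bounded this way.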

lstm_3 provides a graphical illustration of the data flow.

.. _lstm_3:

## 9.2.2. Implementation from Scratch

%load ../utils/djl-imports


NDManager manager = NDManager.newBaseManager();

int batchSize = 32;
int numSteps = 35;

TimeMachineDataset dataset =
        new TimeMachineDataset.Builder()
                .setManager(manager)
                .setMaxTokens(10000)
                .setSampling(batchSize, false)
                .setSteps(numSteps)
                .build();
dataset.prepare();
Vocab vocab = dataset.getVocab();


### 9.2.2.1. Initializing Model Parameters

public static NDList getLSTMParams(int vocabSize, int numHiddens, Device device) {
    int numInputs = vocabSize;
    int numOutputs = vocabSize;

    // Input gate parameters
    NDList temp = three(numInputs, numHiddens, device);
    NDArray W_xi = temp.get(0);
    NDArray W_hi = temp.get(1);
    NDArray b_i = temp.get(2);

    // Forget gate parameters
    temp = three(numInputs, numHiddens, device);
    NDArray W_xf = temp.get(0);
    NDArray W_hf = temp.get(1);
    NDArray b_f = temp.get(2);

    // Output gate parameters
    temp = three(numInputs, numHiddens, device);
    NDArray W_xo = temp.get(0);
    NDArray W_ho = temp.get(1);
    NDArray b_o = temp.get(2);

    // Candidate memory cell parameters
    temp = three(numInputs, numHiddens, device);
    NDArray W_xc = temp.get(0);
    NDArray W_hc = temp.get(1);
    NDArray b_c = temp.get(2);

    // Output layer parameters
    NDArray W_hq = normal(new Shape(numHiddens, numOutputs), device);
    NDArray b_q = manager.zeros(new Shape(numOutputs), DataType.FLOAT32, device);

    NDList params =
            new NDList(
                    W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq,
                    b_q);
    // Attach a gradient to every parameter
    for (NDArray param : params) {
        param.setRequiresGradient(true);
    }
    return params;
}

public static NDArray normal(Shape shape, Device device) {
    return manager.randomNormal(0, 0.01f, shape, DataType.FLOAT32, device);
}

public static NDList three(int numInputs, int numHiddens, Device device) {
    return new NDList(
            normal(new Shape(numInputs, numHiddens), device),
            normal(new Shape(numHiddens, numHiddens), device),
            manager.zeros(new Shape(numHiddens), DataType.FLOAT32, device));
}


### 9.2.2.2. Defining the Model

public static NDList initLSTMState(int batchSize, int numHiddens, Device device) {
    return new NDList(
            manager.zeros(new Shape(batchSize, numHiddens), DataType.FLOAT32, device),
            manager.zeros(new Shape(batchSize, numHiddens), DataType.FLOAT32, device));
}


public static Pair<NDArray, NDList> lstm(NDArray inputs, NDList state, NDList params) {
    NDArray W_xi = params.get(0);
    NDArray W_hi = params.get(1);
    NDArray b_i = params.get(2);

    NDArray W_xf = params.get(3);
    NDArray W_hf = params.get(4);
    NDArray b_f = params.get(5);

    NDArray W_xo = params.get(6);
    NDArray W_ho = params.get(7);
    NDArray b_o = params.get(8);

    NDArray W_xc = params.get(9);
    NDArray W_hc = params.get(10);
    NDArray b_c = params.get(11);

    NDArray W_hq = params.get(12);
    NDArray b_q = params.get(13);

    NDArray H = state.get(0);
    NDArray C = state.get(1);
    NDList outputs = new NDList();
    NDArray X, Y, I, F, O, C_tilda;
    for (int i = 0; i < inputs.size(0); i++) {
        X = inputs.get(i);
        // Eq. (9.2.1): input, forget, and output gates
        I = Activation.sigmoid(X.dot(W_xi).add(H.dot(W_hi)).add(b_i));
        F = Activation.sigmoid(X.dot(W_xf).add(H.dot(W_hf)).add(b_f));
        O = Activation.sigmoid(X.dot(W_xo).add(H.dot(W_ho)).add(b_o));
        // Eq. (9.2.2): candidate memory cell
        C_tilda = Activation.tanh(X.dot(W_xc).add(H.dot(W_hc)).add(b_c));
        // Eq. (9.2.3): memory cell
        C = F.mul(C).add(I.mul(C_tilda));
        // Eq. (9.2.4): hidden state
        H = O.mul(Activation.tanh(C));
        Y = H.dot(W_hq).add(b_q);
        outputs.add(Y);
    }
    return new Pair<>(
            outputs.size() > 1 ? NDArrays.concat(outputs) : outputs.get(0), new NDList(H, C));
}


### 9.2.2.3. Training and Prediction

int vocabSize = vocab.length();
int numHiddens = 256;
Device device = manager.getDevice();
int numEpochs = Integer.getInteger("MAX_EPOCH", 500);

int lr = 1;

Functions.TriFunction<Integer, Integer, Device, NDList> getParamsFn =
        (a, b, c) -> getLSTMParams(a, b, c);
Functions.TriFunction<Integer, Integer, Device, NDList> initLSTMStateFn =
        (a, b, c) -> initLSTMState(a, b, c);
Functions.TriFunction<NDArray, NDList, NDList, Pair<NDArray, NDList>> lstmFn =
        (a, b, c) -> lstm(a, b, c);

RNNModelScratch model =
        new RNNModelScratch(
                vocabSize, numHiddens, device, getParamsFn, initLSTMStateFn, lstmFn);
TimeMachine.trainCh8(model, dataset, vocab, lr, numEpochs, device, false, manager);

perplexity: 1.1, 12066.0 tokens/sec on gpu(0)
time travellerit s against reason said filby but you willnever c
travellerype thing that by madeateryienclery is it huss ge


## 9.2.3. Concise Implementation

LSTM lstmLayer =
        LSTM.builder()
                .setNumLayers(1)
                .setStateSize(numHiddens)
                .optReturnState(true)
                .optBatchFirst(false)
                .build();
RNNModel modelConcise = new RNNModel(lstmLayer, vocab.length());
TimeMachine.trainCh8(modelConcise, dataset, vocab, lr, numEpochs, device, false, manager);

INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.062 ms.

perplexity: 1.1, 79980.0 tokens/sec on gpu(0)
time traveller file some abeeimenthon this tore proner arspowsis
traveller fores yound at at sughtare mede a soit and said t


## 9.2.4. Summary

• An LSTM has three types of gates: an input gate, a forget gate, and an output gate.

• The hidden-layer output of an LSTM includes both the hidden state and the memory cell. Only the hidden state is passed on to the output layer; the memory cell is entirely internal.

• LSTMs can alleviate vanishing and exploding gradients.
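The last point can be made concrete with a plain-Java sketch: in the additive update (9.2.3), a forget gate near 1 and an input gate near 0 let the memory cell carry a value across many steps nearly unchanged, whereas a plain RNN that multiplies by a recurrent weight of 0.5 each step loses it exponentially fast. The gate and weight values below are hand-picked for illustration:

```java
public class MemoryRetentionSketch {
    public static void main(String[] args) {
        double c = 1.0;  // LSTM memory cell
        double h = 1.0;  // plain RNN hidden value
        double f = 0.99; // forget gate near 1
        double i = 0.0;  // input gate near 0: ignore new input entirely
        double w = 0.5;  // recurrent weight of a plain RNN
        for (int t = 0; t < 100; t++) {
            c = f * c + i * 0.0; // eq. (9.2.3) with the candidate gated off
            h = w * h;           // plain RNN: repeated multiplication
        }
        // After 100 steps the cell keeps ~37% of its value; the RNN keeps ~8e-31.
        System.out.printf("LSTM cell: %.4f, plain RNN: %.2e%n", c, h);
    }
}
```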

## 9.2.5. Exercises

1. Adjust the hyperparameters and analyze their influence on running time, perplexity, and the output sequence.

2. How would you need to change the model to generate proper words rather than sequences of characters?

3. Compare the computational cost of GRUs, LSTMs, and regular RNNs for a given hidden dimension. Pay special attention to the cost of training and inference.

4. Since the candidate memory cell already ensures that its values lie in $(-1,1)$ by using the $\tanh$ function, why does the hidden state need to apply $\tanh$ again to ensure that the output values lie in $(-1,1)$?

5. Implement an LSTM model that predicts based on time series rather than character sequences.