Run this notebook online: or Colab:
8.6. 递归神经网络的简洁实现¶
而:numref:sec_rnn_scratch
对于了解rnn是如何实现的很有指导意义,
这既不方便也不快捷。 本节将展示如何更有效地实现相同的语言模型
使用高级API提供的函数 一个深入学习的框架。
我们像以前一样从读取时间机器数据集开始。
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/PlotUtils.java
%load ../utils/Accumulator.java
%load ../utils/Animator.java
%load ../utils/Functions.java
%load ../utils/StopWatch.java
%load ../utils/Training.java
%load ../utils/timemachine/Vocab.java
%load ../utils/timemachine/RNNModelScratch.java
%load ../utils/timemachine/TimeMachine.java
import ai.djl.training.dataset.Record;
NDManager manager = NDManager.newBaseManager();
8.6.1. 在DJL中创建数据集¶
在 DJL
中,处理数据集的理想而简洁的方法是使用内置的数据集,这些数据集可以轻松地围绕现有的
Ndarray
,或者创建从 RandomAccessDataset
类扩展而来的自己的数据集。对于这一部分,我们将实现我们自己的。有关在 DJL
中创建自己的数据集的更多信息,请参阅:https://djl.ai/docs/development/how_to_use_dataset.html
我们对 TimeMachineDataset
简洁地实现将替换先前创建的
SeqDataLoader
类。使用 DJL
格式的数据集,将允许我们使用已经内置的函数,这样我们就不必从头开始实现大多数功能。我们必须实现一个生成器、一个包含将数据保存到TimeMachineDataset对象的过程的prepare函数,以及一个get函数。
public static class TimeMachineDataset extends RandomAccessDataset {
private Vocab vocab;
private NDArray data;
private NDArray labels;
private int numSteps;
private int maxTokens;
private int batchSize;
private NDManager manager;
private boolean prepared;
public TimeMachineDataset(Builder builder) {
super(builder);
this.numSteps = builder.numSteps;
this.maxTokens = builder.maxTokens;
this.batchSize = builder.getSampler().getBatchSize();
this.manager = builder.manager;
this.data = this.manager.create(new Shape(0,35), DataType.INT32);
this.labels = this.manager.create(new Shape(0,35), DataType.INT32);
this.prepared = false;
}
@Override
public Record get(NDManager manager, long index) throws IOException {
NDArray X = data.get(new NDIndex("{}", index));
NDArray Y = labels.get(new NDIndex("{}", index));
return new Record(new NDList(X), new NDList(Y));
}
@Override
protected long availableSize() {
return data.getShape().get(0);
}
@Override
public void prepare(Progress progress) throws IOException, TranslateException {
if (prepared) {
return;
}
Pair<List<Integer>, Vocab> corpusVocabPair = null;
try {
corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens);
} catch (Exception e) {
e.printStackTrace(); // 在tokenize()函数期间,异常可能来自未知的token类型。
}
List<Integer> corpus = corpusVocabPair.getKey();
this.vocab = corpusVocabPair.getValue();
// 从一个随机偏移量(包括'numSteps-1')开始到分区a
// 序列
int offset = new Random().nextInt(numSteps);
int numTokens = ((int) ((corpus.size() - offset - 1) / batchSize)) * batchSize;
NDArray Xs =
manager.create(
corpus.subList(offset, offset + numTokens).stream()
.mapToInt(Integer::intValue)
.toArray());
NDArray Ys =
manager.create(
corpus.subList(offset + 1, offset + 1 + numTokens).stream()
.mapToInt(Integer::intValue)
.toArray());
Xs = Xs.reshape(new Shape(batchSize, -1));
Ys = Ys.reshape(new Shape(batchSize, -1));
int numBatches = (int) Xs.getShape().get(1) / numSteps;
NDList xNDList = new NDList();
NDList yNDList = new NDList();
for (int i = 0; i < numSteps * numBatches; i += numSteps) {
NDArray X = Xs.get(new NDIndex(":, {}:{}", i, i + numSteps));
NDArray Y = Ys.get(new NDIndex(":, {}:{}", i, i + numSteps));
xNDList.add(X);
yNDList.add(Y);
}
this.data = NDArrays.concat(xNDList);
xNDList.close();
this.labels = NDArrays.concat(yNDList);
yNDList.close();
this.prepared = true;
}
public Vocab getVocab() {
return this.vocab;
}
public static final class Builder extends BaseBuilder<Builder> {
int numSteps;
int maxTokens;
NDManager manager;
@Override
protected Builder self() { return this; }
public Builder setSteps(int steps) {
this.numSteps = steps;
return this;
}
public Builder setMaxTokens(int maxTokens) {
this.maxTokens = maxTokens;
return this;
}
public Builder setManager(NDManager manager) {
this.manager = manager;
return this;
}
public TimeMachineDataset build() throws IOException, TranslateException {
TimeMachineDataset dataset = new TimeMachineDataset(this);
return dataset;
}
}
}
因此,我们将更新上一节中函数 predictCh8
, trainCh8
,
trainEpochCh8
, 和 gradClipping
的代码,以包含数据集逻辑,并允许函数从DJL接受 AbstractBlock
而不是只接受 RNNModelScratch
.
/** 在`prefix`后面生成新字符。 */
public static String predictCh8(
String prefix,
int numPreds,
Object net,
Vocab vocab,
Device device,
NDManager manager) {
List<Integer> outputs = new ArrayList<>();
outputs.add(vocab.getIdx("" + prefix.charAt(0)));
Functions.SimpleFunction<NDArray> getInput =
() ->
manager.create(outputs.get(outputs.size() - 1))
.toDevice(device, false)
.reshape(new Shape(1, 1));
if (net instanceof RNNModelScratch) {
RNNModelScratch castedNet = (RNNModelScratch) net;
NDList state = castedNet.beginState(1, device);
for (char c : prefix.substring(1).toCharArray()) { // 预热期
state = (NDList) castedNet.forward(getInput.apply(), state).getValue();
outputs.add(vocab.getIdx("" + c));
}
NDArray y;
for (int i = 0; i < numPreds; i++) {
Pair<NDArray, NDList> pair = castedNet.forward(getInput.apply(), state);
y = pair.getKey();
state = pair.getValue();
outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));
}
} else {
AbstractBlock castedNet = (AbstractBlock) net;
NDList state = null;
for (char c : prefix.substring(1).toCharArray()) { // 预热期
if (state == null) {
// Begin state
state =
castedNet
.forward(
new ParameterStore(manager, false),
new NDList(getInput.apply()),
false)
.subNDList(1);
} else {
state =
castedNet
.forward(
new ParameterStore(manager, false),
new NDList(getInput.apply()).addAll(state),
false)
.subNDList(1);
}
outputs.add(vocab.getIdx("" + c));
}
NDArray y;
for (int i = 0; i < numPreds; i++) {
NDList pair =
castedNet.forward(
new ParameterStore(manager, false),
new NDList(getInput.apply()).addAll(state),
false);
y = pair.get(0);
state = pair.subNDList(1);
outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));
}
}
StringBuilder output = new StringBuilder();
for (int i : outputs) {
output.append(vocab.idxToToken.get(i));
}
return output.toString();
}
/** 训练一个模型 */
public static void trainCh8(
Object net,
RandomAccessDataset dataset,
Vocab vocab,
int lr,
int numEpochs,
Device device,
boolean useRandomIter,
NDManager manager)
throws IOException, TranslateException {
SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();
Animator animator = new Animator();
Functions.voidTwoFunction<Integer, NDManager> updater;
if (net instanceof RNNModelScratch) {
RNNModelScratch castedNet = (RNNModelScratch) net;
updater =
(batchSize, subManager) ->
Training.sgd(castedNet.params, lr, batchSize, subManager);
} else {
// 已初始化网络
AbstractBlock castedNet = (AbstractBlock) net;
Model model = Model.newInstance("model");
model.setBlock(castedNet);
Tracker lrt = Tracker.fixed(lr);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
DefaultTrainingConfig config =
new DefaultTrainingConfig(loss)
.optOptimizer(sgd) // 优化器(损失函数)
.optInitializer(
new NormalInitializer(0.01f),
Parameter.Type.WEIGHT) // 设置初始值设定项
.optDevices(Engine.getInstance().getDevices(1)) // 设置所需的GPU数量
.addEvaluator(new Accuracy()) // 模型精度
.addTrainingListeners(TrainingListener.Defaults.logging()); // 日志
Trainer trainer = model.newTrainer(config);
updater = (batchSize, subManager) -> trainer.step();
}
Function<String, String> predict =
(prefix) -> predictCh8(prefix, 50, net, vocab, device, manager);
// 训练和预测
double ppl = 0.0;
double speed = 0.0;
for (int epoch = 0; epoch < numEpochs; epoch++) {
Pair<Double, Double> pair =
trainEpochCh8(net, dataset, loss, updater, device, useRandomIter, manager);
ppl = pair.getKey();
speed = pair.getValue();
if ((epoch + 1) % 10 == 0) {
animator.add(epoch + 1, (float) ppl, "ppl");
animator.show();
}
}
System.out.format(
"perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, device.toString());
System.out.println(predict.apply("time traveller"));
System.out.println(predict.apply("traveller"));
}
/** 在一个epoch内训练一个模型 */
public static Pair<Double, Double> trainEpochCh8(
Object net,
RandomAccessDataset dataset,
Loss loss,
Functions.voidTwoFunction<Integer, NDManager> updater,
Device device,
boolean useRandomIter,
NDManager manager)
throws IOException, TranslateException {
StopWatch watch = new StopWatch();
watch.start();
Accumulator metric = new Accumulator(2); // 训练损失总和,tokens
try (NDManager childManager = manager.newSubManager()) {
NDList state = null;
for (Batch batch : dataset.getData(childManager)) {
NDArray X = batch.getData().head().toDevice(device, true);
NDArray Y = batch.getLabels().head().toDevice(device, true);
if (state == null || useRandomIter) {
// 在第一次迭代或
// 使用随机抽样
if (net instanceof RNNModelScratch) {
state =
((RNNModelScratch) net)
.beginState((int) X.getShape().getShape()[0], device);
}
} else {
for (NDArray s : state) {
s.stopGradient();
}
}
if (state != null) {
state.attach(childManager);
}
NDArray y = Y.transpose().reshape(new Shape(-1));
X = X.toDevice(device, false);
y = y.toDevice(device, false);
try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
NDArray yHat;
if (net instanceof RNNModelScratch) {
Pair<NDArray, NDList> pairResult = ((RNNModelScratch) net).forward(X, state);
yHat = pairResult.getKey();
state = pairResult.getValue();
} else {
NDList pairResult;
if (state == null) {
// 开始状态
pairResult =
((AbstractBlock) net)
.forward(
new ParameterStore(manager, false),
new NDList(X),
true);
} else {
pairResult =
((AbstractBlock) net)
.forward(
new ParameterStore(manager, false),
new NDList(X).addAll(state),
true);
}
yHat = pairResult.get(0);
state = pairResult.subNDList(1);
}
NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();
gc.backward(l);
metric.add(new float[] {l.getFloat() * y.size(), y.size()});
}
gradClipping(net, 1, childManager);
updater.apply(1, childManager); // 因为已经调用了“mean”函数
}
}
return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());
}
/** 修剪梯度 */
public static void gradClipping(Object net, int theta, NDManager manager) {
double result = 0;
NDList params;
if (net instanceof RNNModelScratch) {
params = ((RNNModelScratch) net).params;
} else {
params = new NDList();
for (Pair<String, Parameter> pair : ((AbstractBlock) net).getParameters()) {
params.add(pair.getValue().getArray());
}
}
for (NDArray p : params) {
NDArray gradient = p.getGradient().stopGradient();
gradient.attach(manager);
result += gradient.pow(2).sum().getFloat();
}
double norm = Math.sqrt(result);
if (norm > theta) {
for (NDArray param : params) {
NDArray gradient = param.getGradient();
gradient.muli(theta / norm);
}
}
}
现在,我们将利用刚刚创建的数据集并分配所需的参数。
int batchSize = 32;
int numSteps = 35;
TimeMachineDataset dataset = new TimeMachineDataset.Builder()
.setManager(manager).setMaxTokens(10000).setSampling(batchSize, false)
.setSteps(numSteps).build();
dataset.prepare();
Vocab vocab = dataset.getVocab();
8.6.2. 定义模型¶
高级API提供递归神经网络的实现。
我们构造了一个包含单个隐层和256个隐单元的递归神经网络层rnn_layer
。
事实上,我们甚至还没有讨论多层的含义——这将发生在
Section 9.3.
现在,只需说多个层相当于一个RNN层的输出被用作下一个 RNN 层的输入即可。
int numHiddens = 256;
RNN rnnLayer = RNN.builder().setNumLayers(1)
.setStateSize(numHiddens).optReturnState(true).optBatchFirst(false).build();
初始化隐藏状态很简单。 我们调用成员函数 beginState
(在DJL中,我们不必在以后第一次运行 forward
时运行 beginState
来指定结果状态,因为在我们第一次运行forward
,
时,DJL会运行此逻辑,但出于演示目的,我们将在此处创建此逻辑)。
这将返回一个列表 (state
) 包含 初始隐藏状态
对于minibatch中的每个示例, 是谁的形状
(隐藏层的数量、批次大小、隐藏单元的数量)。 对于某些型号 待稍后介绍
(例如,长-短期记忆), 这样一份名单也不例外 包含其他信息。
public static NDList beginState(int batchSize, int numLayers, int numHiddens) {
return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));
}
NDList state = beginState(batchSize, 1, numHiddens);
System.out.println(state.size());
System.out.println(state.get(0).getShape());
1
(1, 32, 256)
具有隐藏状态和输入, 我们可以用 更新后的隐藏状态。 应该强调的是 the
“output” (Y
) of rnnLayer
不 涉及输出层的计算: 指 每个
时间步骤的隐藏状态, 它们可以作为输入 到后续的输出层。
此外, (stateNew
) 返回的更新的隐藏状态 rnnLayer
指隐藏状态
在小批量的最后时间步。 它可用于初始化
一个epoch下一个小批量的隐藏状态 在顺序分区中。 对于多个隐藏层,
将存储每个层的隐藏状态 在这个变量(stateNew
)中。 对于某些型号
待稍后介绍 (例如,长-短期记忆), 这个变量也是 包含其他信息。
NDArray X = manager.randomUniform (0, 1,new Shape(numSteps, batchSize, vocab.length()));
NDList input = new NDList(X, state.get(0));
rnnLayer.initialize(manager, DataType.FLOAT32, input.getShapes());
NDList forwardOutput = rnnLayer.forward(new ParameterStore(manager, false), input, false);
NDArray Y = forwardOutput.get(0);
NDArray stateNew = forwardOutput.get(1);
System.out.println(Y.getShape());
System.out.println(stateNew.getShape());
(35, 32, 256)
(1, 32, 256)
类似于 Section 8.5, 我们定义了一个 RNNModel
类
对于完整的RNN模型。 注意 rnnLayer
只包含隐藏的重复层,我们需要创建一个单独的输出层。
public class RNNModel extends AbstractBlock {
private RNN rnnLayer;
private Linear dense;
private int vocabSize;
public RNNModel(RNN rnnLayer, int vocabSize) {
this.rnnLayer = rnnLayer;
this.addChildBlock("rnn", rnnLayer);
this.vocabSize = vocabSize;
this.dense = Linear.builder().setUnits(vocabSize).build();
this.addChildBlock("linear", dense);
}
@Override
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
NDArray X = inputs.get(0).transpose().oneHot(vocabSize);
inputs.set(0, X);
NDList result = rnnLayer.forward(parameterStore, inputs, training);
NDArray Y = result.get(0);
NDArray state = result.get(1);
int shapeLength = Y.getShape().dimension();
NDList output = dense.forward(parameterStore, new NDList(Y
.reshape(new Shape(-1, Y.getShape().get(shapeLength-1)))), training);
return new NDList(output.get(0), state);
}
@Override
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
Shape shape = rnnLayer.getOutputShapes(new Shape[]{inputShapes[0]})[0];
dense.initialize(manager, dataType, new Shape(vocabSize, shape.get(shape.dimension() - 1)));
}
/* 我们不会实现它,因为我们不会使用它,但它是抽象块的一部分 */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[0];
}
}
8.6.3. 训练与预测¶
在训练模型之前,让我们使用具有随机权重的模型进行预测。
Device device = manager.getDevice();
RNNModel net = new RNNModel(rnnLayer, vocab.length());
net.initialize(manager, DataType.FLOAT32, X.getShape());
predictCh8("time traveller", 10, net, vocab, device, manager);
time travellermgmmmmmm
很明显,这种模式根本不起作用。
接下来,我们使用在:numref:sec_rnn_scratch
中定义的相同超参数调用
trainCh8
,并使用高级API训练我们的模型。
int numEpochs = Integer.getInteger("MAX_EPOCH", 500);
int lr = 1;
trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.071 ms.
perplexity: 1.2, 82972.4 tokens/sec on gpu(0)
time traveller but now you begin this thatwell intithe fire with
travellerit s all have a realexistencethere is however a te
与上一节相比,由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。
8.6.4. 总结¶
深度学习框架的高级API提供了RNN层的实现。
高级API的RNN层返回输出和更新的隐藏状态,其中输出不涉及输出层计算。
与从头开始使用其实现相比,使用高级API可以更快地进行RNN训练。
8.6.5. 练习¶
尝试使用高级API,你能使循环神经网络模型过拟合吗?
如果在RNN模型中增加隐藏层的数量,会发生什么情况?你能使模型工作吗?
使用RNN实现 Section 8.1 的自回归模型。