Run this notebook online: or Colab:
8.6. 递归神经网络的简洁实现¶
而:numref:sec_rnn_scratch
对于了解rnn是如何实现的很有指导意义,
这既不方便也不快捷。 本节将展示如何更有效地实现相同的语言模型
使用高级API提供的函数 一个深入学习的框架。
我们像以前一样从读取时间机器数据集开始。
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/PlotUtils.java
%load ../utils/Accumulator.java
%load ../utils/Animator.java
%load ../utils/Functions.java
%load ../utils/StopWatch.java
%load ../utils/Training.java
%load ../utils/timemachine/Vocab.java
%load ../utils/timemachine/RNNModelScratch.java
%load ../utils/timemachine/TimeMachine.java
import ai.djl.training.dataset.Record;
NDManager manager = NDManager.newBaseManager();
8.6.1. 在DJL中创建数据集¶
在 DJL
中,处理数据集的理想而简洁的方法是使用内置的数据集,这些数据集可以轻松地围绕现有的
Ndarray
,或者创建从 RandomAccessDataset
类扩展而来的自己的数据集。对于这一部分,我们将实现我们自己的。有关在 DJL
中创建自己的数据集的更多信息,请参阅:https://djl.ai/docs/development/how_to_use_dataset.html
我们对 TimeMachineDataset
简洁地实现将替换先前创建的
SeqDataLoader
类。使用 DJL
格式的数据集,将允许我们使用已经内置的函数,这样我们就不必从头开始实现大多数功能。我们必须实现一个生成器、一个包含将数据保存到TimeMachineDataset对象的过程的prepare函数,以及一个get函数。
public static class TimeMachineDataset extends RandomAccessDataset {
private Vocab vocab;
private NDArray data;
private NDArray labels;
private int numSteps;
private int maxTokens;
private int batchSize;
private NDManager manager;
private boolean prepared;
public TimeMachineDataset(Builder builder) {
super(builder);
this.numSteps = builder.numSteps;
this.maxTokens = builder.maxTokens;
this.batchSize = builder.getSampler().getBatchSize();
this.manager = builder.manager;
this.data = this.manager.create(new Shape(0,35), DataType.INT32);
this.labels = this.manager.create(new Shape(0,35), DataType.INT32);
this.prepared = false;
}
@Override
public Record get(NDManager manager, long index) throws IOException {
NDArray X = data.get(new NDIndex("{}", index));
NDArray Y = labels.get(new NDIndex("{}", index));
return new Record(new NDList(X), new NDList(Y));
}
@Override
protected long availableSize() {
return data.getShape().get(0);
}
@Override
public void prepare(Progress progress) throws IOException, TranslateException {
if (prepared) {
return;
}
Pair<List<Integer>, Vocab> corpusVocabPair = null;
try {
corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens);
} catch (Exception e) {
e.printStackTrace(); // 在tokenize()函数期间,异常可能来自未知的token类型。
}
List<Integer> corpus = corpusVocabPair.getKey();
this.vocab = corpusVocabPair.getValue();
// 从一个随机偏移量(包括'numSteps-1')开始到分区a
// 序列
int offset = new Random().nextInt(numSteps);
int numTokens = ((int) ((corpus.size() - offset - 1) / batchSize)) * batchSize;
NDArray Xs =
manager.create(
corpus.subList(offset, offset + numTokens).stream()
.mapToInt(Integer::intValue)
.toArray());
NDArray Ys =
manager.create(
corpus.subList(offset + 1, offset + 1 + numTokens).stream()
.mapToInt(Integer::intValue)
.toArray());
Xs = Xs.reshape(new Shape(batchSize, -1));
Ys = Ys.reshape(new Shape(batchSize, -1));
int numBatches = (int) Xs.getShape().get(1) / numSteps;
NDList xNDList = new NDList();
NDList yNDList = new NDList();
for (int i = 0; i < numSteps * numBatches; i += numSteps) {
NDArray X = Xs.get(new NDIndex(":, {}:{}", i, i + numSteps));
NDArray Y = Ys.get(new NDIndex(":, {}:{}", i, i + numSteps));
xNDList.add(X);
yNDList.add(Y);
}
this.data = NDArrays.concat(xNDList);
xNDList.close();
this.labels = NDArrays.concat(yNDList);
yNDList.close();
this.prepared = true;
}
public Vocab getVocab() {
return this.vocab;
}
public static final class Builder extends BaseBuilder<Builder> {
int numSteps;
int maxTokens;
NDManager manager;
@Override
protected Builder self() { return this; }
public Builder setSteps(int steps) {
this.numSteps = steps;
return this;
}
public Builder setMaxTokens(int maxTokens) {
this.maxTokens = maxTokens;
return this;
}
public Builder setManager(NDManager manager) {
this.manager = manager;
return this;
}
public TimeMachineDataset build() throws IOException, TranslateException {
TimeMachineDataset dataset = new TimeMachineDataset(this);
return dataset;
}
}
}
因此,我们将更新上一节中函数 predictCh8
, trainCh8
,
trainEpochCh8
, 和 gradClipping
的代码,以包含数据集逻辑,并允许函数从DJL接受 AbstractBlock
而不是只接受 RNNModelScratch
.
/** 在`prefix`后面生成新字符。 */
public static String predictCh8(
String prefix,
int numPreds,
Object net,
Vocab vocab,
Device device,
NDManager manager) {
List<Integer> outputs = new ArrayList<>();
outputs.add(vocab.getIdx("" + prefix.charAt(0)));
Functions.SimpleFunction<NDArray> getInput =
() ->
manager.create(outputs.get(outputs.size() - 1))
.toDevice(device, false)
.reshape(new Shape(1, 1));
if (net instanceof RNNModelScratch) {
RNNModelScratch castedNet = (RNNModelScratch) net;
NDList state = castedNet.beginState(1, device);
for (char c : prefix.substring(1).toCharArray()) { // 预热期
state = (NDList) castedNet.forward(getInput.apply(), state).getValue();
outputs.add(vocab.getIdx("" + c));
}
NDArray y;
for (int i = 0; i < numPreds; i++) {
Pair<NDArray, NDList> pair = castedNet.forward(getInput.apply(), state);
y = pair.getKey();
state = pair.getValue();
outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));
}
} else {
AbstractBlock castedNet = (AbstractBlock) net;
NDList state = null;
for (char c : prefix.substring(1).toCharArray()) { // 预热期
if (state == null) {
// Begin state
state =
castedNet
.forward(
new ParameterStore(manager, false),
new NDList(getInput.apply()),
false)
.subNDList(1);
} else {
state =
castedNet
.forward(
new ParameterStore(manager, false),
new NDList(getInput.apply()).addAll(state),
false)
.subNDList(1);
}
outputs.add(vocab.getIdx("" + c));
}
NDArray y;
for (int i = 0; i < numPreds; i++) {
NDList pair =
castedNet.forward(
new ParameterStore(manager, false),
new NDList(getInput.apply()).addAll(state),
false);
y = pair.get(0);
state = pair.subNDList(1);
outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));
}
}
StringBuilder output = new StringBuilder();
for (int i : outputs) {
output.append(vocab.idxToToken.get(i));
}
return output.toString();
}
/** 训练一个模型 */
public static void trainCh8(
Object net,
RandomAccessDataset dataset,
Vocab vocab,
int lr,
int numEpochs,
Device device,
boolean useRandomIter,
NDManager manager)
throws IOException, TranslateException {
SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();
Animator animator = new Animator();
Functions.voidTwoFunction<Integer, NDManager> updater;
if (net instanceof RNNModelScratch) {
RNNModelScratch castedNet = (RNNModelScratch) net;
updater =
(batchSize, subManager) ->
Training.sgd(castedNet.params, lr, batchSize, subManager);
} else {
// 已初始化网络
AbstractBlock castedNet = (AbstractBlock) net;
Model model = Model.newInstance("model");
model.setBlock(castedNet);
Tracker lrt = Tracker.fixed(lr);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
DefaultTrainingConfig config =
new DefaultTrainingConfig(loss)
.optOptimizer(sgd) // 优化器(损失函数)
.optInitializer(
new NormalInitializer(0.01f),
Parameter.Type.WEIGHT) // 设置初始值设定项
.optDevices(Engine.getInstance().getDevices(1)) // 设置所需的GPU数量
.addEvaluator(new Accuracy()) // 模型精度
.addTrainingListeners(TrainingListener.Defaults.logging()); // 日志
Trainer trainer = model.newTrainer(config);
updater = (batchSize, subManager) -> trainer.step();
}
Function<String, String> predict =
(prefix) -> predictCh8(prefix, 50, net, vocab, device, manager);
// 训练和预测
double ppl = 0.0;
double speed = 0.0;
for (int epoch = 0; epoch < numEpochs; epoch++) {
Pair<Double, Double> pair =
trainEpochCh8(net, dataset, loss, updater, device, useRandomIter, manager);
ppl = pair.getKey();
speed = pair.getValue();
if ((epoch + 1) % 10 == 0) {
animator.add(epoch + 1, (float) ppl, "ppl");
animator.show();
}
}
System.out.format(
"perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, device.toString());
System.out.println(predict.apply("time traveller"));
System.out.println(predict.apply("traveller"));
}
/** 在一个epoch内训练一个模型 */
public static Pair<Double, Double> trainEpochCh8(
Object net,
RandomAccessDataset dataset,
Loss loss,
Functions.voidTwoFunction<Integer, NDManager> updater,
Device device,
boolean useRandomIter,
NDManager manager)
throws IOException, TranslateException {
StopWatch watch = new StopWatch();
watch.start();
Accumulator metric = new Accumulator(2); // 训练损失总和,tokens
try (NDManager childManager = manager.newSubManager()) {
NDList state = null;
for (Batch batch : dataset.getData(childManager)) {
NDArray X = batch.getData().head().toDevice(device, true);
NDArray Y = batch.getLabels().head().toDevice(device, true);
if (state == null || useRandomIter) {
// 在第一次迭代或
// 使用随机抽样
if (net instanceof RNNModelScratch) {
state =
((RNNModelScratch) net)
.beginState((int) X.getShape().getShape()[0], device);
}
} else {
for (NDArray s : state) {
s.stopGradient();
}
}
if (state != null) {
state.attach(childManager);
}
NDArray y = Y.transpose().reshape(new Shape(-1));
X = X.toDevice(device, false);
y = y.toDevice(device, false);
try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
NDArray yHat;
if (net instanceof RNNModelScratch) {
Pair<NDArray, NDList> pairResult = ((RNNModelScratch) net).forward(X, state);
yHat = pairResult.getKey();
state = pairResult.getValue();
} else {
NDList pairResult;
if (state == null) {
// 开始状态
pairResult =
((AbstractBlock) net)
.forward(
new ParameterStore(manager, false),
new NDList(X),
true);
} else {
pairResult =
((AbstractBlock) net)
.forward(
new ParameterStore(manager, false),
new NDList(X).addAll(state),
true);
}
yHat = pairResult.get(0);
state = pairResult.subNDList(1);
}
NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();
gc.backward(l);
metric.add(new float[] {l.getFloat() * y.size(), y.size()});
}
gradClipping(net, 1, childManager);
updater.apply(1, childManager); // 因为已经调用了“mean”函数
}
}
return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());
}
/** 修剪梯度 */
public static void gradClipping(Object net, int theta, NDManager manager) {
double result = 0;
NDList params;
if (net instanceof RNNModelScratch) {
params = ((RNNModelScratch) net).params;
} else {
params = new NDList();
for (Pair<String, Parameter> pair : ((AbstractBlock) net).getParameters()) {
params.add(pair.getValue().getArray());
}
}
for (NDArray p : params) {
NDArray gradient = p.getGradient().stopGradient();
gradient.attach(manager);
result += gradient.pow(2).sum().getFloat();
}
double norm = Math.sqrt(result);
if (norm > theta) {
for (NDArray param : params) {
NDArray gradient = param.getGradient();
gradient.muli(theta / norm);
}
}
}
现在,我们将利用刚刚创建的数据集并分配所需的参数。
int batchSize = 32;
int numSteps = 35;
TimeMachineDataset dataset = new TimeMachineDataset.Builder()
.setManager(manager).setMaxTokens(10000).setSampling(batchSize, false)
.setSteps(numSteps).build();
dataset.prepare();
Vocab vocab = dataset.getVocab();
8.6.2. 定义模型¶
高级API提供递归神经网络的实现。
我们构造了一个包含单个隐层和256个隐单元的递归神经网络层rnn_layer
。
事实上,我们甚至还没有讨论多层的含义——这将发生在
Section 9.3.
现在,只需说多个层相当于一个RNN层的输出被用作下一个 RNN 层的输入即可。
int numHiddens = 256;
RNN rnnLayer = RNN.builder().setNumLayers(1)
.setStateSize(numHiddens).optReturnState(true).optBatchFirst(false).build();
初始化隐藏状态很简单。 我们调用成员函数 beginState
(在DJL中,我们不必在以后第一次运行 forward
时运行 beginState
来指定结果状态,因为在我们第一次运行forward
,
时,DJL会运行此逻辑,但出于演示目的,我们将在此处创建此逻辑)。
这将返回一个列表 (state
) 包含 初始隐藏状态
对于minibatch中的每个示例, 是谁的形状
(隐藏层的数量、批次大小、隐藏单元的数量)。 对于某些型号 待稍后介绍
(例如,长-短期记忆), 这样一份名单也不例外 包含其他信息。
public static NDList beginState(int batchSize, int numLayers, int numHiddens) {
return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));
}
NDList state = beginState(batchSize, 1, numHiddens);
System.out.println(state.size());
System.out.println(state.get(0).getShape());
1
(1, 32, 256)
具有隐藏状态和输入, 我们可以用 更新后的隐藏状态。 应该强调的是 the
“output” (Y
) of rnnLayer
不 涉及输出层的计算: 指 每个
时间步骤的隐藏状态, 它们可以作为输入 到后续的输出层。
此外, (stateNew
) 返回的更新的隐藏状态 rnnLayer
指隐藏状态
在小批量的最后时间步。 它可用于初始化
一个epoch下一个小批量的隐藏状态 在顺序分区中。 对于多个隐藏层,
将存储每个层的隐藏状态 在这个变量(stateNew
)中。 对于某些型号
待稍后介绍 (例如,长-短期记忆), 这个变量也是 包含其他信息。
NDArray X = manager.randomUniform (0, 1,new Shape(numSteps, batchSize, vocab.length()));
NDList input = new NDList(X, state.get(0));
rnnLayer.initialize(manager, DataType.FLOAT32, input.getShapes());
NDList forwardOutput = rnnLayer.forward(new ParameterStore(manager, false), input, false);
NDArray Y = forwardOutput.get(0);
NDArray stateNew = forwardOutput.get(1);
System.out.println(Y.getShape());
System.out.println(stateNew.getShape());
(35, 32, 256)
(1, 32, 256)
类似于 Section 8.5, 我们定义了一个 RNNModel
类
对于完整的RNN模型。 注意 rnnLayer
只包含隐藏的重复层,我们需要创建一个单独的输出层。
public class RNNModel extends AbstractBlock {
private RNN rnnLayer;
private Linear dense;
private int vocabSize;
public RNNModel(RNN rnnLayer, int vocabSize) {
this.rnnLayer = rnnLayer;
this.addChildBlock("rnn", rnnLayer);
this.vocabSize = vocabSize;
this.dense = Linear.builder().setUnits(vocabSize).build();
this.addChildBlock("linear", dense);
}
@Override
protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
NDArray X = inputs.get(0).transpose().oneHot(vocabSize);
inputs.set(0, X);
NDList result = rnnLayer.forward(parameterStore, inputs, training);
NDArray Y = result.get(0);
NDArray state = result.get(1);
int shapeLength = Y.getShape().dimension();
NDList output = dense.forward(parameterStore, new NDList(Y
.reshape(new Shape(-1, Y.getShape().get(shapeLength-1)))), training);
return new NDList(output.get(0), state);
}
@Override
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
Shape shape = rnnLayer.getOutputShapes(new Shape[]{inputShapes[0]})[0];
dense.initialize(manager, dataType, new Shape(vocabSize, shape.get(shape.dimension() - 1)));
}
/* 我们不会实现它,因为我们不会使用它,但它是抽象块的一部分 */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
return new Shape[0];
}
}
8.6.3. 训练与预测¶
在训练模型之前,让我们使用具有随机权重的模型进行预测。
Device device = manager.getDevice();
RNNModel net = new RNNModel(rnnLayer, vocab.length());
net.initialize(manager, DataType.FLOAT32, X.getShape());
predictCh8("time traveller", 10, net, vocab, device, manager);
time travellermgmmmmmm
很明显,这种模式根本不起作用。
接下来,我们使用在:numref:sec_rnn_scratch
中定义的相同超参数调用
trainCh8
,并使用高级API训练我们的模型。
int numEpochs = Integer.getInteger("MAX_EPOCH", 500);
int lr = 1;
trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.8.0 in 0.083 ms.
perplexity: 1.2, 79130.5 tokens/sec on gpu(0)
time traveller for and pentive traveller for and praceesent time
travellerit would attery count my think so murmured the psy
与上一节相比,由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。
8.6.4. 总结¶
深度学习框架的高级API提供了RNN层的实现。
高级API的RNN层返回输出和更新的隐藏状态,其中输出不涉及输出层计算。
与从头开始使用其实现相比,使用高级API可以更快地进行RNN训练。
8.6.5. 练习¶
尝试使用高级API,你能使循环神经网络模型过拟合吗?
如果在RNN模型中增加隐藏层的数量,会发生什么情况?你能使模型工作吗?
使用RNN实现 Section 8.1 的自回归模型。