Run this notebook online:\ |Binder| or Colab: |Colab|

.. |Binder| image:: https://mybinder.org/badge_logo.svg
   :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_recurrent-neural-networks/rnn-concise.ipynb
.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
   :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_recurrent-neural-networks/rnn-concise.ipynb

.. _sec_rnn-concise:

Concise Implementation of Recurrent Neural Networks
====================================================

While :numref:`sec_rnn_scratch` was instructive to see how RNNs are
implemented, it is neither convenient nor fast. This section shows how to
implement the same language model more efficiently using the functions
provided by the high-level APIs of a deep learning framework. As before, we
begin by reading the time machine dataset.

.. code:: java

    %load ../utils/djl-imports
    %load ../utils/plot-utils
    %load ../utils/PlotUtils.java

    %load ../utils/Accumulator.java
    %load ../utils/Animator.java
    %load ../utils/Functions.java
    %load ../utils/StopWatch.java
    %load ../utils/Training.java

    %load ../utils/timemachine/Vocab.java
    %load ../utils/timemachine/RNNModelScratch.java
    %load ../utils/timemachine/TimeMachine.java

.. code:: java

    import ai.djl.training.dataset.Record;

.. code:: java

    NDManager manager = NDManager.newBaseManager();

Creating a Dataset in DJL
-------------------------

In DJL, the ideal and concise way of dealing with datasets is to use the
built-in datasets, which can easily wrap existing ``NDArray``\ s, or to create
your own dataset extending from the ``RandomAccessDataset`` class. For this
section, we will implement our own. For more information on creating your own
dataset in DJL, see: https://djl.ai/docs/development/how_to_use_dataset.html

Our concise implementation, the ``TimeMachineDataset``, will replace the
``SeqDataLoader`` class created earlier. A dataset in DJL format lets us use
the already built-in functions, so we do not have to implement most of the
functionality from scratch. We have to implement a builder, a ``prepare``
function that contains the process of saving the data to the
``TimeMachineDataset`` object, and a ``get`` function.

.. code:: java

    public static class TimeMachineDataset extends RandomAccessDataset {

        private Vocab vocab;
        private NDArray data;
        private NDArray labels;
        private int numSteps;
        private int maxTokens;
        private int batchSize;
        private NDManager manager;
        private boolean prepared;

        public TimeMachineDataset(Builder builder) {
            super(builder);
            this.numSteps = builder.numSteps;
            this.maxTokens = builder.maxTokens;
            this.batchSize = builder.getSampler().getBatchSize();
            this.manager = builder.manager;
            this.data = this.manager.create(new Shape(0, 35), DataType.INT32);
            this.labels = this.manager.create(new Shape(0, 35), DataType.INT32);
            this.prepared = false;
        }

        @Override
        public Record get(NDManager manager, long index) throws IOException {
            NDArray X = data.get(new NDIndex("{}", index));
            NDArray Y = labels.get(new NDIndex("{}", index));
            return new Record(new NDList(X), new NDList(Y));
        }

        @Override
        protected long availableSize() {
            return data.getShape().get(0);
        }

        @Override
        public void prepare(Progress progress) throws IOException, TranslateException {
            if (prepared) {
                return;
            }

            Pair<List<Integer>, Vocab> corpusVocabPair = null;
            try {
                corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens);
            } catch (Exception e) {
                e.printStackTrace(); // Exception can be from unknown token type during the tokenize() function.
            }
            List<Integer> corpus = corpusVocabPair.getKey();
            this.vocab = corpusVocabPair.getValue();

            // Start with a random offset (inclusive of `numSteps - 1`) to partition a sequence
            int offset = new Random().nextInt(numSteps);
            int numTokens = ((int) ((corpus.size() - offset - 1) / batchSize)) * batchSize;
            NDArray Xs =
                    manager.create(
                            corpus.subList(offset, offset + numTokens).stream()
                                    .mapToInt(Integer::intValue)
                                    .toArray());
            NDArray Ys =
                    manager.create(
                            corpus.subList(offset + 1, offset + 1 + numTokens).stream()
                                    .mapToInt(Integer::intValue)
                                    .toArray());
            Xs = Xs.reshape(new Shape(batchSize, -1));
            Ys = Ys.reshape(new Shape(batchSize, -1));
            int numBatches = (int) Xs.getShape().get(1) / numSteps;

            // Slice `Xs` and `Ys` into minibatches of `numSteps` time steps each
            NDList xNDList = new NDList();
            NDList yNDList = new NDList();
            for (int i = 0; i < numSteps * numBatches; i += numSteps) {
                NDArray X = Xs.get(new NDIndex(":, {}:{}", i, i + numSteps));
                NDArray Y = Ys.get(new NDIndex(":, {}:{}", i, i + numSteps));
                xNDList.add(X);
                yNDList.add(Y);
            }
            this.data = NDArrays.concat(xNDList);
            xNDList.close();
            this.labels = NDArrays.concat(yNDList);
            yNDList.close();
            this.prepared = true;
        }

        public Vocab getVocab() {
            return this.vocab;
        }

        public static final class Builder extends BaseBuilder<Builder> {

            int numSteps;
            int maxTokens;
            NDManager manager;

            @Override
            protected Builder self() {
                return this;
            }

            public Builder setSteps(int steps) {
                this.numSteps = steps;
                return this;
            }

            public Builder setMaxTokens(int maxTokens) {
                this.maxTokens = maxTokens;
                return this;
            }

            public Builder setManager(NDManager manager) {
                this.manager = manager;
                return this;
            }

            public TimeMachineDataset build() throws IOException, TranslateException {
                TimeMachineDataset dataset = new TimeMachineDataset(this);
                return dataset;
            }
        }
    }

Consequently, we will update the code from the previous section for the
functions ``predictCh8``, ``trainCh8``, ``trainEpochCh8``, and ``gradClipping``
to include the dataset logic and to allow the functions to accept an
``AbstractBlock`` from DJL instead of just accepting ``RNNModelScratch``.

.. code:: java

    /** Generate new characters following the `prefix`. */
    public static String predictCh8(
            String prefix, int numPreds, Object net, Vocab vocab, Device device, NDManager manager) {
        List<Integer> outputs = new ArrayList<>();
        outputs.add(vocab.getIdx("" + prefix.charAt(0)));
        Functions.SimpleFunction<NDArray> getInput =
                () ->
                        manager.create(outputs.get(outputs.size() - 1))
                                .toDevice(device, false)
                                .reshape(new Shape(1, 1));

        if (net instanceof RNNModelScratch) {
            RNNModelScratch castedNet = (RNNModelScratch) net;
            NDList state = castedNet.beginState(1, device);
            for (char c : prefix.substring(1).toCharArray()) { // Warm-up period
                state = (NDList) castedNet.forward(getInput.apply(), state).getValue();
                outputs.add(vocab.getIdx("" + c));
            }

            NDArray y;
            for (int i = 0; i < numPreds; i++) {
                Pair<NDArray, NDList> pair = castedNet.forward(getInput.apply(), state);
                y = pair.getKey();
                state = pair.getValue();

                outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));
            }
        } else {
            AbstractBlock castedNet = (AbstractBlock) net;
            NDList state = null;
            for (char c : prefix.substring(1).toCharArray()) { // Warm-up period
                if (state == null) {
                    // Begin state
                    state =
                            castedNet
                                    .forward(
                                            new ParameterStore(manager, false),
                                            new NDList(getInput.apply()),
                                            false)
                                    .subNDList(1);
                } else {
                    state =
                            castedNet
                                    .forward(
                                            new ParameterStore(manager, false),
                                            new NDList(getInput.apply()).addAll(state),
                                            false)
                                    .subNDList(1);
                }
                outputs.add(vocab.getIdx("" + c));
            }

            NDArray y;
            for (int i = 0; i < numPreds; i++) {
                NDList pair =
                        castedNet.forward(
                                new ParameterStore(manager, false),
                                new NDList(getInput.apply()).addAll(state),
                                false);
                y = pair.get(0);
                state = pair.subNDList(1);

                outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));
            }
        }

        StringBuilder output = new StringBuilder();
        for (int i : outputs) {
            output.append(vocab.idxToToken.get(i));
        }
        return output.toString();
    }
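
A brief note on the ``AbstractBlock`` branch above: the forward pass of a DJL
block returns a single ``NDList``, and the convention this code relies on is
that index 0 holds the output while the remaining entries hold the recurrent
state, which is why ``subNDList(1)`` splits the state off so it can be fed back
on the next step. The following is only a minimal sketch of that split; the
arrays and shapes below are made up purely for illustration and are not part of
the model.

.. code:: java

    try (NDManager demo = NDManager.newBaseManager()) {
        // Stand-in for the NDList returned by a recurrent block's forward pass:
        // index 0 = per-step output, index 1 onwards = hidden state.
        NDList forwardResult = new NDList(
                demo.zeros(new Shape(1, 28)),       // hypothetical output logits
                demo.zeros(new Shape(1, 1, 256)));  // hypothetical hidden state
        NDArray output = forwardResult.get(0);      // what argMax decoding consumes
        NDList state = forwardResult.subNDList(1);  // what is fed back on the next step
        System.out.println(output.getShape() + " " + state.get(0).getShape());
    }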

.. code:: java

    /** Train a model. */
    public static void trainCh8(
            Object net,
            RandomAccessDataset dataset,
            Vocab vocab,
            int lr,
            int numEpochs,
            Device device,
            boolean useRandomIter,
            NDManager manager)
            throws IOException, TranslateException {
        SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();
        Animator animator = new Animator();

        Functions.voidTwoFunction<Integer, NDManager> updater;
        if (net instanceof RNNModelScratch) {
            RNNModelScratch castedNet = (RNNModelScratch) net;
            updater =
                    (batchSize, subManager) ->
                            Training.sgd(castedNet.params, lr, batchSize, subManager);
        } else {
            // Already initialized net
            AbstractBlock castedNet = (AbstractBlock) net;
            Model model = Model.newInstance("model");
            model.setBlock(castedNet);

            Tracker lrt = Tracker.fixed(lr);
            Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

            DefaultTrainingConfig config =
                    new DefaultTrainingConfig(loss)
                            .optOptimizer(sgd) // Optimizer (loss function)
                            .optInitializer(
                                    new NormalInitializer(0.01f),
                                    Parameter.Type.WEIGHT) // setting the initializer
                            .optDevices(Engine.getInstance().getDevices(1)) // setting the number of GPUs needed
                            .addEvaluator(new Accuracy()) // Model Accuracy
                            .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

            Trainer trainer = model.newTrainer(config);
            updater = (batchSize, subManager) -> trainer.step();
        }

        Function<String, String> predict =
                (prefix) -> predictCh8(prefix, 50, net, vocab, device, manager);

        // Train and predict
        double ppl = 0.0;
        double speed = 0.0;
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            Pair<Double, Double> pair =
                    trainEpochCh8(net, dataset, loss, updater, device, useRandomIter, manager);
            ppl = pair.getKey();
            speed = pair.getValue();
            if ((epoch + 1) % 10 == 0) {
                animator.add(epoch + 1, (float) ppl, "ppl");
                animator.show();
            }
        }
        System.out.format(
                "perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, device.toString());
        System.out.println(predict.apply("time traveller"));
        System.out.println(predict.apply("traveller"));
    }
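
The ``trainEpochCh8`` function below accumulates the summed cross-entropy loss
and the number of processed tokens in an ``Accumulator``, and reports the
perplexity as the exponential of the average per-token loss. As a minimal
sketch of that bookkeeping, the numbers here are made up purely for
illustration:

.. code:: java

    // Hypothetical per-batch values: {meanLoss * numTokens, numTokens}
    Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens
    metric.add(new float[] {1225.0f, 1120.0f});
    metric.add(new float[] {1190.0f, 1120.0f});

    // Perplexity = exp(total loss / total number of tokens)
    double perplexity = Math.exp(metric.get(0) / metric.get(1));
    System.out.println(perplexity);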

.. code:: java

    /** Train a model within one epoch. */
    public static Pair<Double, Double> trainEpochCh8(
            Object net,
            RandomAccessDataset dataset,
            Loss loss,
            Functions.voidTwoFunction<Integer, NDManager> updater,
            Device device,
            boolean useRandomIter,
            NDManager manager)
            throws IOException, TranslateException {
        StopWatch watch = new StopWatch();
        watch.start();
        Accumulator metric = new Accumulator(2); // Sum of training loss, no. of tokens
        try (NDManager childManager = manager.newSubManager()) {
            NDList state = null;
            for (Batch batch : dataset.getData(childManager)) {
                NDArray X = batch.getData().head().toDevice(device, true);
                NDArray Y = batch.getLabels().head().toDevice(device, true);
                if (state == null || useRandomIter) {
                    // Initialize `state` when either it is the first iteration or
                    // using random sampling
                    if (net instanceof RNNModelScratch) {
                        state =
                                ((RNNModelScratch) net)
                                        .beginState((int) X.getShape().getShape()[0], device);
                    }
                } else {
                    for (NDArray s : state) {
                        s.stopGradient();
                    }
                }
                if (state != null) {
                    state.attach(childManager);
                }

                NDArray y = Y.transpose().reshape(new Shape(-1));
                X = X.toDevice(device, false);
                y = y.toDevice(device, false);
                try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
                    NDArray yHat;
                    if (net instanceof RNNModelScratch) {
                        Pair<NDArray, NDList> pairResult = ((RNNModelScratch) net).forward(X, state);
                        yHat = pairResult.getKey();
                        state = pairResult.getValue();
                    } else {
                        NDList pairResult;
                        if (state == null) {
                            // Begin state
                            pairResult =
                                    ((AbstractBlock) net)
                                            .forward(
                                                    new ParameterStore(manager, false),
                                                    new NDList(X),
                                                    true);
                        } else {
                            pairResult =
                                    ((AbstractBlock) net)
                                            .forward(
                                                    new ParameterStore(manager, false),
                                                    new NDList(X).addAll(state),
                                                    true);
                        }
                        yHat = pairResult.get(0);
                        state = pairResult.subNDList(1);
                    }

                    NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();
                    gc.backward(l);
                    metric.add(new float[] {l.getFloat() * y.size(), y.size()});
                }
                gradClipping(net, 1, childManager);
                updater.apply(1, childManager); // Since the `mean` function has been invoked
            }
        }
        return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());
    }

.. code:: java

    /** Clip the gradient. */
    public static void gradClipping(Object net, int theta, NDManager manager) {
        double result = 0;
        NDList params;
        if (net instanceof RNNModelScratch) {
            params = ((RNNModelScratch) net).params;
        } else {
            params = new NDList();
            for (Pair<String, Parameter> pair : ((AbstractBlock) net).getParameters()) {
                params.add(pair.getValue().getArray());
            }
        }
        for (NDArray p : params) {
            NDArray gradient = p.getGradient().stopGradient();
            gradient.attach(manager);
            result += gradient.pow(2).sum().getFloat();
        }
        double norm = Math.sqrt(result);
        if (norm > theta) {
            for (NDArray param : params) {
                NDArray gradient = param.getGradient();
                gradient.muli(theta / norm);
            }
        }
    }

Now we will leverage the dataset that we just created and assign the required
parameters.

.. code:: java

    int batchSize = 32;
    int numSteps = 35;

    TimeMachineDataset dataset =
            new TimeMachineDataset.Builder()
                    .setManager(manager)
                    .setMaxTokens(10000)
                    .setSampling(batchSize, false)
                    .setSteps(numSteps)
                    .build();
    dataset.prepare();
    Vocab vocab = dataset.getVocab();

Defining the Model
------------------

High-level APIs provide implementations of recurrent neural networks. We
construct the recurrent neural network layer ``rnnLayer`` with a single hidden
layer and 256 hidden units. In fact, we have not even discussed yet what it
means to have multiple layers; this will happen in :numref:`sec_deep_rnn`. For
now, suffice it to say that multiple layers simply amount to the output of one
RNN layer being used as the input to the next RNN layer.

.. code:: java

    int numHiddens = 256;
    RNN rnnLayer =
            RNN.builder()
                    .setNumLayers(1)
                    .setStateSize(numHiddens)
                    .optReturnState(true)
                    .optBatchFirst(false)
                    .build();

Initializing the hidden state is straightforward. We invoke the ``beginState``
function (in DJL we do not have to run ``beginState`` to specify the resulting
state before the first call to ``forward``, since DJL runs this logic itself
the first time we call ``forward``, but we create it here for demonstration
purposes). It returns a list (``state``) that contains an initial hidden state
for each example in the minibatch, whose shape is (number of hidden layers,
batch size, number of hidden units). For some models to be introduced later
(e.g., long short-term memory), such a list also contains other information.

.. code:: java

    public static NDList beginState(int batchSize, int numLayers, int numHiddens) {
        return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));
    }

    NDList state = beginState(batchSize, 1, numHiddens);
    System.out.println(state.size());
    System.out.println(state.get(0).getShape());

.. parsed-literal::
    :class: output

    1
    (1, 32, 256)

With a hidden state and an input, we can compute the output with the updated
hidden state. It should be emphasized that the "output" (``Y``) of ``rnnLayer``
does *not* involve the computation of output layers: it refers to the hidden
state at *each* time step, and these hidden states can be used as the input to
the subsequent output layer. Besides, the updated hidden state (``stateNew``)
returned by ``rnnLayer`` refers to the hidden state at the *last* time step of
the minibatch. It can be used to initialize the hidden state for the next
minibatch within an epoch in sequential partitioning. For multiple hidden
layers, the hidden state of each layer will be stored in this variable
(``stateNew``). For some models to be introduced later (e.g., long short-term
memory), this variable also contains other information.

.. code:: java

    NDArray X = manager.randomUniform(0, 1, new Shape(numSteps, batchSize, vocab.length()));

    NDList input = new NDList(X, state.get(0));
    rnnLayer.initialize(manager, DataType.FLOAT32, input.getShapes());
    NDList forwardOutput = rnnLayer.forward(new ParameterStore(manager, false), input, false);
    NDArray Y = forwardOutput.get(0);
    NDArray stateNew = forwardOutput.get(1);

    System.out.println(Y.getShape());
    System.out.println(stateNew.getShape());

.. parsed-literal::
    :class: output

    (35, 32, 256)
    (1, 32, 256)

Similar to :numref:`sec_rnn_scratch`, we define an ``RNNModel`` class for a
complete RNN model. Note that ``rnnLayer`` only contains the hidden recurrent
layers, so we need to create a separate output layer.

.. code:: java

    public class RNNModel extends AbstractBlock {

        private RNN rnnLayer;
        private Linear dense;
        private int vocabSize;

        public RNNModel(RNN rnnLayer, int vocabSize) {
            this.rnnLayer = rnnLayer;
            this.addChildBlock("rnn", rnnLayer);
            this.vocabSize = vocabSize;
            this.dense = Linear.builder().setUnits(vocabSize).build();
            this.addChildBlock("linear", dense);
        }

        @Override
        protected NDList forwardInternal(
                ParameterStore parameterStore,
                NDList inputs,
                boolean training,
                PairList<String, Object> params) {
            NDArray X = inputs.get(0).transpose().oneHot(vocabSize);
            inputs.set(0, X);
            NDList result = rnnLayer.forward(parameterStore, inputs, training);
            NDArray Y = result.get(0);
            NDArray state = result.get(1);

            int shapeLength = Y.getShape().dimension();
            NDList output =
                    dense.forward(
                            parameterStore,
                            new NDList(Y.reshape(new Shape(-1, Y.getShape().get(shapeLength - 1)))),
                            training);
            return new NDList(output.get(0), state);
        }

        @Override
        public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
            Shape shape = rnnLayer.getOutputShapes(new Shape[] {inputShapes[0]})[0];
            dense.initialize(manager, dataType, new Shape(vocabSize, shape.get(shape.dimension() - 1)));
        }

        /* We won't implement this since we won't be using it, but it's required
         * as part of an AbstractBlock */
        @Override
        public Shape[] getOutputShapes(Shape[] inputShapes) {
            return new Shape[0];
        }
    }

Training and Predicting
-----------------------

Before training the model, let us make a prediction with a model that has
random weights.

.. code:: java

    Device device = manager.getDevice();
    RNNModel net = new RNNModel(rnnLayer, vocab.length());
    net.initialize(manager, DataType.FLOAT32, X.getShape());
    predictCh8("time traveller", 10, net, vocab, device, manager);

.. parsed-literal::
    :class: output

    time travellermgmmmmmm

As is quite obvious, this model does not work at all. Next, we call ``trainCh8``
with the same hyperparameters defined in :numref:`sec_rnn_scratch` and train
our model with high-level APIs.

.. code:: java

    int numEpochs = Integer.getInteger("MAX_EPOCH", 500);

    int lr = 1;
    trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);

.. parsed-literal::
    :class: output

    INFO Training on: 1 GPUs.
    INFO Load MXNet Engine Version 1.9.0 in 0.071 ms.

.. parsed-literal::
    :class: output

    perplexity: 1.2, 82972.4 tokens/sec on gpu(0)
    time traveller but now you begin this thatwell intithe fire with
    travellerit s all have a realexistencethere is however a te

Compared with the last section, this model reaches a lower perplexity within a
shorter period of time, because the code is more optimized by the high-level
APIs of the deep learning framework.

Summary
-------

- High-level APIs of the deep learning framework provide an implementation of
  the RNN layer.
- The RNN layer of the high-level APIs returns the output and the updated
  hidden state, where the output does not involve output layer computation.
- Using high-level APIs leads to faster RNN training than using the
  implementation from scratch.

Exercises
---------

1. Using the high-level APIs, can you make the RNN model overfit?
2. What happens if you increase the number of hidden layers in the RNN model?
   Can you make the model work?
3. Implement the autoregressive model of :numref:`sec_sequence` using an RNN.