Run this notebook online:Binder or Colab: Colab

8.6. 递归神经网络的简洁实现

而:numref:sec_rnn_scratch 对于了解rnn是如何实现的很有指导意义, 这既不方便也不快捷。 本节将展示如何更有效地实现相同的语言模型 使用高级API提供的函数 一个深入学习的框架。 我们像以前一样从读取时间机器数据集开始。

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/PlotUtils.java

%load ../utils/Accumulator.java
%load ../utils/Animator.java
%load ../utils/Functions.java
%load ../utils/StopWatch.java
%load ../utils/Training.java
%load ../utils/timemachine/Vocab.java
%load ../utils/timemachine/RNNModelScratch.java
%load ../utils/timemachine/TimeMachine.java
import ai.djl.training.dataset.Record;
NDManager manager = NDManager.newBaseManager();

8.6.1. 在DJL中创建数据集

在 DJL 中,处理数据集的理想而简洁的方法是使用内置的数据集,这些数据集可以轻松地围绕现有的 Ndarray,或者创建从 RandomAccessDataset 类扩展而来的自己的数据集。对于这一部分,我们将实现我们自己的。有关在 DJL 中创建自己的数据集的更多信息,请参阅:https://djl.ai/docs/development/how_to_use_dataset.html

我们对 TimeMachineDataset 简洁地实现将替换先前创建的 SeqDataLoader 类。使用 DJL 格式的数据集,将允许我们使用已经内置的函数,这样我们就不必从头开始实现大多数功能。我们必须实现一个生成器、一个包含将数据保存到TimeMachineDataset对象的过程的prepare函数,以及一个get函数。

public static class TimeMachineDataset extends RandomAccessDataset {

    private Vocab vocab;
    private NDArray data;
    private NDArray labels;
    private int numSteps;
    private int maxTokens;
    private int batchSize;
    private NDManager manager;
    private boolean prepared;

    public TimeMachineDataset(Builder builder) {
        super(builder);
        this.numSteps = builder.numSteps;
        this.maxTokens = builder.maxTokens;
        this.batchSize = builder.getSampler().getBatchSize();
        this.manager = builder.manager;
        this.data = this.manager.create(new Shape(0,35), DataType.INT32);
        this.labels = this.manager.create(new Shape(0,35), DataType.INT32);
        this.prepared = false;
    }

    @Override
    public Record get(NDManager manager, long index) throws IOException {
        NDArray X = data.get(new NDIndex("{}", index));
        NDArray Y = labels.get(new NDIndex("{}", index));
        return new Record(new NDList(X), new NDList(Y));
    }

    @Override
    protected long availableSize() {
        return data.getShape().get(0);
    }

    @Override
    public void prepare(Progress progress) throws IOException, TranslateException {
        if (prepared) {
            return;
        }

        Pair<List<Integer>, Vocab> corpusVocabPair = null;
        try {
            corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens);
        } catch (Exception e) {
            e.printStackTrace(); // 在tokenize()函数期间,异常可能来自未知的token类型。
        }
        List<Integer> corpus = corpusVocabPair.getKey();
        this.vocab = corpusVocabPair.getValue();

        // 从一个随机偏移量(包括'numSteps-1')开始到分区a
        // 序列
        int offset = new Random().nextInt(numSteps);
        int numTokens = ((int) ((corpus.size() - offset - 1) / batchSize)) * batchSize;
        NDArray Xs =
                manager.create(
                        corpus.subList(offset, offset + numTokens).stream()
                                .mapToInt(Integer::intValue)
                                .toArray());
        NDArray Ys =
                manager.create(
                        corpus.subList(offset + 1, offset + 1 + numTokens).stream()
                                .mapToInt(Integer::intValue)
                                .toArray());
        Xs = Xs.reshape(new Shape(batchSize, -1));
        Ys = Ys.reshape(new Shape(batchSize, -1));
        int numBatches = (int) Xs.getShape().get(1) / numSteps;

        NDList xNDList = new NDList();
        NDList yNDList = new NDList();
        for (int i = 0; i < numSteps * numBatches; i += numSteps) {
            NDArray X = Xs.get(new NDIndex(":, {}:{}", i, i + numSteps));
            NDArray Y = Ys.get(new NDIndex(":, {}:{}", i, i + numSteps));
            xNDList.add(X);
            yNDList.add(Y);
        }
        this.data = NDArrays.concat(xNDList);
        xNDList.close();
        this.labels = NDArrays.concat(yNDList);
        yNDList.close();
        this.prepared = true;
    }

    public Vocab getVocab() {
        return this.vocab;
    }

    public static final class Builder extends BaseBuilder<Builder> {
        int numSteps;
        int maxTokens;
        NDManager manager;

        @Override
        protected Builder self() { return this; }

        public Builder setSteps(int steps) {
            this.numSteps = steps;
            return this;
        }

        public Builder setMaxTokens(int maxTokens) {
            this.maxTokens = maxTokens;
            return this;
        }

        public Builder setManager(NDManager manager) {
            this.manager = manager;
            return this;
        }

        public TimeMachineDataset build() throws IOException, TranslateException {
            TimeMachineDataset dataset = new TimeMachineDataset(this);
            return dataset;
        }
    }
}

因此,我们将更新上一节中函数 predictCh8, trainCh8, trainEpochCh8, 和 gradClipping 的代码,以包含数据集逻辑,并允许函数从DJL接受 AbstractBlock 而不是只接受 RNNModelScratch.

/** 在`prefix`后面生成新字符。 */
public static String predictCh8(
        String prefix,
        int numPreds,
        Object net,
        Vocab vocab,
        Device device,
        NDManager manager) {

    List<Integer> outputs = new ArrayList<>();
    outputs.add(vocab.getIdx("" + prefix.charAt(0)));
    Functions.SimpleFunction<NDArray> getInput =
            () ->
                    manager.create(outputs.get(outputs.size() - 1))
                            .toDevice(device, false)
                            .reshape(new Shape(1, 1));

    if (net instanceof RNNModelScratch) {
        RNNModelScratch castedNet = (RNNModelScratch) net;
        NDList state = castedNet.beginState(1, device);

        for (char c : prefix.substring(1).toCharArray()) { // 预热期
            state = (NDList) castedNet.forward(getInput.apply(), state).getValue();
            outputs.add(vocab.getIdx("" + c));
        }

        NDArray y;
        for (int i = 0; i < numPreds; i++) {
            Pair<NDArray, NDList> pair = castedNet.forward(getInput.apply(), state);
            y = pair.getKey();
            state = pair.getValue();

            outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));
        }
    } else {
        AbstractBlock castedNet = (AbstractBlock) net;
        NDList state = null;
        for (char c : prefix.substring(1).toCharArray()) { // 预热期
            if (state == null) {
                // Begin state
                state =
                        castedNet
                                .forward(
                                        new ParameterStore(manager, false),
                                        new NDList(getInput.apply()),
                                        false)
                                .subNDList(1);
            } else {
                state =
                        castedNet
                                .forward(
                                        new ParameterStore(manager, false),
                                        new NDList(getInput.apply()).addAll(state),
                                        false)
                                .subNDList(1);
            }
            outputs.add(vocab.getIdx("" + c));
        }

        NDArray y;
        for (int i = 0; i < numPreds; i++) {
            NDList pair =
                    castedNet.forward(
                            new ParameterStore(manager, false),
                            new NDList(getInput.apply()).addAll(state),
                            false);
            y = pair.get(0);
            state = pair.subNDList(1);

            outputs.add((int) y.argMax(1).reshape(new Shape(1)).getLong(0L));
        }
    }

    StringBuilder output = new StringBuilder();
    for (int i : outputs) {
        output.append(vocab.idxToToken.get(i));
    }
    return output.toString();
}
/** 训练一个模型 */
public static void trainCh8(
        Object net,
        RandomAccessDataset dataset,
        Vocab vocab,
        int lr,
        int numEpochs,
        Device device,
        boolean useRandomIter,
        NDManager manager)
        throws IOException, TranslateException {
    SoftmaxCrossEntropyLoss loss = new SoftmaxCrossEntropyLoss();
    Animator animator = new Animator();

    Functions.voidTwoFunction<Integer, NDManager> updater;
    if (net instanceof RNNModelScratch) {
        RNNModelScratch castedNet = (RNNModelScratch) net;
        updater =
                (batchSize, subManager) ->
                        Training.sgd(castedNet.params, lr, batchSize, subManager);
    } else {
        // 已初始化网络
        AbstractBlock castedNet = (AbstractBlock) net;
        Model model = Model.newInstance("model");
        model.setBlock(castedNet);

        Tracker lrt = Tracker.fixed(lr);
        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

        DefaultTrainingConfig config =
                new DefaultTrainingConfig(loss)
                        .optOptimizer(sgd) // 优化器(损失函数)
                        .optInitializer(
                                new NormalInitializer(0.01f),
                                Parameter.Type.WEIGHT) // 设置初始值设定项
                        .optDevices(Engine.getInstance().getDevices(1)) // 设置所需的GPU数量
                        .addEvaluator(new Accuracy()) // 模型精度
                        .addTrainingListeners(TrainingListener.Defaults.logging()); // 日志

        Trainer trainer = model.newTrainer(config);
        updater = (batchSize, subManager) -> trainer.step();
    }

    Function<String, String> predict =
            (prefix) -> predictCh8(prefix, 50, net, vocab, device, manager);
    // 训练和预测
    double ppl = 0.0;
    double speed = 0.0;
    for (int epoch = 0; epoch < numEpochs; epoch++) {
        Pair<Double, Double> pair =
                trainEpochCh8(net, dataset, loss, updater, device, useRandomIter, manager);
        ppl = pair.getKey();
        speed = pair.getValue();
        if ((epoch + 1) % 10 == 0) {
           animator.add(epoch + 1, (float) ppl, "ppl");
           animator.show();
        }
    }
    System.out.format(
            "perplexity: %.1f, %.1f tokens/sec on %s%n", ppl, speed, device.toString());
    System.out.println(predict.apply("time traveller"));
    System.out.println(predict.apply("traveller"));
}
/** 在一个epoch内训练一个模型 */
public static Pair<Double, Double> trainEpochCh8(
        Object net,
        RandomAccessDataset dataset,
        Loss loss,
        Functions.voidTwoFunction<Integer, NDManager> updater,
        Device device,
        boolean useRandomIter,
        NDManager manager)
        throws IOException, TranslateException {
    StopWatch watch = new StopWatch();
    watch.start();
    Accumulator metric = new Accumulator(2); // 训练损失总和,tokens

    try (NDManager childManager = manager.newSubManager()) {
        NDList state = null;
        for (Batch batch : dataset.getData(childManager)) {
            NDArray X = batch.getData().head().toDevice(device, true);
            NDArray Y = batch.getLabels().head().toDevice(device, true);
            if (state == null || useRandomIter) {
                // 在第一次迭代或
                // 使用随机抽样
                if (net instanceof RNNModelScratch) {
                    state =
                            ((RNNModelScratch) net)
                                    .beginState((int) X.getShape().getShape()[0], device);
                }
            } else {
                for (NDArray s : state) {
                    s.stopGradient();
                }
            }
            if (state != null) {
                state.attach(childManager);
            }

            NDArray y = Y.transpose().reshape(new Shape(-1));
            X = X.toDevice(device, false);
            y = y.toDevice(device, false);
            try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
                NDArray yHat;
                if (net instanceof RNNModelScratch) {
                    Pair<NDArray, NDList> pairResult = ((RNNModelScratch) net).forward(X, state);
                    yHat = pairResult.getKey();
                    state = pairResult.getValue();
                } else {
                    NDList pairResult;
                    if (state == null) {
                        // 开始状态
                        pairResult =
                                ((AbstractBlock) net)
                                        .forward(
                                                new ParameterStore(manager, false),
                                                new NDList(X),
                                                true);
                    } else {
                        pairResult =
                                ((AbstractBlock) net)
                                        .forward(
                                                new ParameterStore(manager, false),
                                                new NDList(X).addAll(state),
                                                true);
                    }
                    yHat = pairResult.get(0);
                    state = pairResult.subNDList(1);
                }

                NDArray l = loss.evaluate(new NDList(y), new NDList(yHat)).mean();
                gc.backward(l);
                metric.add(new float[] {l.getFloat() * y.size(), y.size()});
            }
            gradClipping(net, 1, childManager);
            updater.apply(1, childManager); // 因为已经调用了“mean”函数
        }
    }
    return new Pair<>(Math.exp(metric.get(0) / metric.get(1)), metric.get(1) / watch.stop());
}
/** 修剪梯度 */
public static void gradClipping(Object net, int theta, NDManager manager) {
    double result = 0;
    NDList params;
    if (net instanceof RNNModelScratch) {
        params = ((RNNModelScratch) net).params;
    } else {
        params = new NDList();
        for (Pair<String, Parameter> pair : ((AbstractBlock) net).getParameters()) {
            params.add(pair.getValue().getArray());
        }
    }
    for (NDArray p : params) {
        NDArray gradient = p.getGradient().stopGradient();
        gradient.attach(manager);
        result += gradient.pow(2).sum().getFloat();
    }
    double norm = Math.sqrt(result);
    if (norm > theta) {
        for (NDArray param : params) {
            NDArray gradient = param.getGradient();
            gradient.muli(theta / norm);
        }
    }
}

现在,我们将利用刚刚创建的数据集并分配所需的参数。

int batchSize = 32;
int numSteps = 35;

TimeMachineDataset dataset = new TimeMachineDataset.Builder()
        .setManager(manager).setMaxTokens(10000).setSampling(batchSize, false)
        .setSteps(numSteps).build();
dataset.prepare();
Vocab vocab = dataset.getVocab();

8.6.2. 定义模型

高级API提供递归神经网络的实现。 我们构造了一个包含单个隐层和256个隐单元的递归神经网络层rnn_layer。 事实上,我们甚至还没有讨论多层的含义——这将发生在 Section 9.3. 现在,只需说多个层相当于一个RNN层的输出被用作下一个 RNN 层的输入即可。

int numHiddens = 256;
RNN rnnLayer = RNN.builder().setNumLayers(1)
        .setStateSize(numHiddens).optReturnState(true).optBatchFirst(false).build();

初始化隐藏状态很简单。 我们调用成员函数 beginState (在DJL中,我们不必在以后第一次运行 forward 时运行 beginState 来指定结果状态,因为在我们第一次运行forward, 时,DJL会运行此逻辑,但出于演示目的,我们将在此处创建此逻辑)。 这将返回一个列表 (state) 包含 初始隐藏状态 对于minibatch中的每个示例, 是谁的形状 (隐藏层的数量、批次大小、隐藏单元的数量)。 对于某些型号 待稍后介绍 (例如,长-短期记忆), 这样一份名单也不例外 包含其他信息。

public static NDList beginState(int batchSize, int numLayers, int numHiddens) {
    return new NDList(manager.zeros(new Shape(numLayers, batchSize, numHiddens)));
}

NDList state = beginState(batchSize, 1, numHiddens);
System.out.println(state.size());
System.out.println(state.get(0).getShape());
1
(1, 32, 256)

具有隐藏状态和输入, 我们可以用 更新后的隐藏状态。 应该强调的是 the “output” (Y) of rnnLayer 涉及输出层的计算: 指 每个 时间步骤的隐藏状态, 它们可以作为输入 到后续的输出层。

此外, (stateNew) 返回的更新的隐藏状态 rnnLayer 指隐藏状态 在小批量的最后时间步。 它可用于初始化 一个epoch下一个小批量的隐藏状态 在顺序分区中。 对于多个隐藏层, 将存储每个层的隐藏状态 在这个变量(stateNew)中。 对于某些型号 待稍后介绍 (例如,长-短期记忆), 这个变量也是 包含其他信息。

NDArray X = manager.randomUniform (0, 1,new Shape(numSteps, batchSize, vocab.length()));

NDList input = new NDList(X, state.get(0));
rnnLayer.initialize(manager, DataType.FLOAT32, input.getShapes());
NDList forwardOutput = rnnLayer.forward(new ParameterStore(manager, false), input, false);
NDArray Y = forwardOutput.get(0);
NDArray stateNew = forwardOutput.get(1);

System.out.println(Y.getShape());
System.out.println(stateNew.getShape());
(35, 32, 256)
(1, 32, 256)

类似于 Section 8.5, 我们定义了一个 RNNModel 类 对于完整的RNN模型。 注意 rnnLayer 只包含隐藏的重复层,我们需要创建一个单独的输出层。

public class RNNModel extends AbstractBlock {

    private RNN rnnLayer;
    private Linear dense;
    private int vocabSize;

    public RNNModel(RNN rnnLayer, int vocabSize) {
        this.rnnLayer = rnnLayer;
        this.addChildBlock("rnn", rnnLayer);
        this.vocabSize = vocabSize;
        this.dense = Linear.builder().setUnits(vocabSize).build();
        this.addChildBlock("linear", dense);
    }

    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray X = inputs.get(0).transpose().oneHot(vocabSize);
        inputs.set(0, X);
        NDList result = rnnLayer.forward(parameterStore, inputs, training);
        NDArray Y = result.get(0);
        NDArray state = result.get(1);

        int shapeLength = Y.getShape().dimension();
        NDList output = dense.forward(parameterStore, new NDList(Y
                .reshape(new Shape(-1, Y.getShape().get(shapeLength-1)))), training);
        return new NDList(output.get(0), state);
    }

    @Override
    public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        Shape shape = rnnLayer.getOutputShapes(new Shape[]{inputShapes[0]})[0];
        dense.initialize(manager, dataType, new Shape(vocabSize, shape.get(shape.dimension() - 1)));
    }

    /* 我们不会实现它,因为我们不会使用它,但它是抽象块的一部分  */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[0];
    }
}

8.6.3. 训练与预测

在训练模型之前,让我们使用具有随机权重的模型进行预测。

Device device = manager.getDevice();
RNNModel net = new RNNModel(rnnLayer, vocab.length());
net.initialize(manager, DataType.FLOAT32, X.getShape());
predictCh8("time traveller", 10, net, vocab, device, manager);
time travellermgmmmmmm

很明显,这种模式根本不起作用。 接下来,我们使用在:numref:sec_rnn_scratch 中定义的相同超参数调用 trainCh8,并使用高级API训练我们的模型。

int numEpochs = Integer.getInteger("MAX_EPOCH", 500);

int lr = 1;
trainCh8((Object) net, dataset, vocab, lr, numEpochs, device, false, manager);
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.071 ms.
perplexity: 1.2, 82972.4 tokens/sec on gpu(0)
time traveller but now you begin this thatwell intithe fire with
travellerit s all have a realexistencethere is however a te

与上一节相比,由于深度学习框架的高级API对代码进行了更多的优化, 该模型在较短的时间内达到了较低的困惑度。

8.6.4. 总结

  • 深度学习框架的高级API提供了RNN层的实现。

  • 高级API的RNN层返回输出和更新的隐藏状态,其中输出不涉及输出层计算。

  • 与从头开始使用其实现相比,使用高级API可以更快地进行RNN训练。

8.6.5. 练习

  1. 尝试使用高级API,你能使循环神经网络模型过拟合吗?

  2. 如果在RNN模型中增加隐藏层的数量,会发生什么情况?你能使模型工作吗?

  3. 使用RNN实现 Section 8.1 的自回归模型。