Run this notebook online:\ |Binder| or Colab: |Colab| .. |Binder| image:: https://mybinder.org/badge_logo.svg :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_optimization/lr-scheduler.ipynb .. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_optimization/lr-scheduler.ipynb .. _sec_scheduler: 学习率调度器 ============ 到目前为止,我们主要关注如何更新权重向量的优化算法,而不是它们的更新速率。 然而,调整学习率通常与实际算法同样重要,有如下几方面需要考虑: - 首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要(有关详细信息,请参见 :numref:`sec_momentum`\ )。直观地说,这是最不敏感与最敏感方向的变化量的比率。 - 其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。 :numref:`sec_minibatch_sgd`\ 比较详细地讨论了这一点,在 :numref:`sec_sgd`\ 中我们则分析了性能保证。简而言之,我们希望速率衰减,但要比\ :math:`\mathcal{O}(t^{-\frac{1}{2}})`\ 慢,这样能成为解决凸问题的不错选择。 - 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式(详情请参阅 :numref:`sec_numerical_stability`\ ),又关系到它们最初的演变方式。这被戏称为\ *预热*\ (warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。 - 最后,还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围,我们建议读者阅读 :cite:`Izmailov.Podoprikhin.Garipov.ea.2018`\ 来了解个中细节。例如,如何通过对整个路径参数求平均值来获得更好的解。 鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。 在本章中,我们将梳理不同的调度策略对准确性的影响,并展示如何通过\ *学习率调度器*\ (learning rate scheduler)来有效管理。 一个简单的问题 -------------- 我们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用\ ``relu``\ 而不是\ ``sigmoid``\ ,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。 由于大多数代码都是标准的,我们只介绍基础知识,而不做进一步的详细讨论。如果需要,请参阅 :numref:`chap_cnn`\ 进行复习。 .. code:: java %load ../utils/djl-imports %load ../utils/plot-utils %load ../utils/Functions.java %load ../utils/GradDescUtils.java %load ../utils/Accumulator.java %load ../utils/StopWatch.java %load ../utils/Training.java %load ../utils/TrainingChapter11.java .. code:: java import ai.djl.basicdataset.cv.classification.*; import org.apache.commons.lang3.ArrayUtils; .. code:: java SequentialBlock net = new SequentialBlock(); net.add(Conv2d.builder() .setKernelShape(new Shape(5, 5)) .optPadding(new Shape(2, 2)) .setFilters(1) .build()); net.add(Activation.reluBlock()); net.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2))); net.add(Conv2d.builder() .setKernelShape(new Shape(5, 5)) .setFilters(1) .build()); net.add(Blocks.batchFlattenBlock()); net.add(Activation.reluBlock()); net.add(Linear.builder().setUnits(120).build()); net.add(Activation.reluBlock()); net.add(Linear.builder().setUnits(84).build()); net.add(Activation.reluBlock()); net.add(Linear.builder().setUnits(10).build()); .. parsed-literal:: :class: output SequentialBlock { Conv2d ReLU maxPool2d Conv2d batchFlatten ReLU Linear ReLU Linear ReLU Linear } .. code:: java int batchSize = 256; RandomAccessDataset trainDataset = FashionMnist.builder() .optUsage(Dataset.Usage.TRAIN) .setSampling(batchSize, false) .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE)) .build(); RandomAccessDataset testDataset = FashionMnist.builder() .optUsage(Dataset.Usage.TEST) .setSampling(batchSize, false) .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE)) .build(); .. code:: java double[] trainLoss; double[] testAccuracy; double[] epochCount; double[] trainAccuracy; public static void train(RandomAccessDataset trainIter, RandomAccessDataset testIter, int numEpochs, Trainer trainer) throws IOException, TranslateException { epochCount = new double[numEpochs]; for (int i = 0; i < epochCount.length; i++) { epochCount[i] = (i + 1); } double avgTrainTimePerEpoch = 0; Map evaluatorMetrics = new HashMap<>(); trainer.setMetrics(new Metrics()); EasyTrain.fit(trainer, numEpochs, trainIter, testIter); Metrics metrics = trainer.getMetrics(); trainer.getEvaluators().stream() .forEach(evaluator -> { evaluatorMetrics.put("train_epoch_" + evaluator.getName(), metrics.getMetric("train_epoch_" + evaluator.getName()).stream() .mapToDouble(x -> x.getValue().doubleValue()).toArray()); evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream() .mapToDouble(x -> x.getValue().doubleValue()).toArray()); }); avgTrainTimePerEpoch = metrics.mean("epoch"); trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss"); trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy"); testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy"); System.out.printf("loss %.3f," , trainLoss[numEpochs-1]); System.out.printf(" train acc %.3f," , trainAccuracy[numEpochs-1]); System.out.printf(" test acc %.3f\n" , testAccuracy[numEpochs-1]); System.out.printf("%.1f examples/sec \n", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9))); } 让我们来看看如果使用默认设置,调用此算法会发生什么。 例如设学习率为\ :math:`0.3`\ 并训练\ :math:`30`\ 次迭代。 留意在超过了某点、测试准确度方面的进展停滞时,训练准确度将如何继续提高。 两条曲线之间的间隙表示过拟合。 .. code:: java float lr = 0.3f; int numEpochs = Integer.getInteger("MAX_EPOCH", 10); Model model = Model.newInstance("Modern LeNet"); model.setBlock(net); Loss loss = Loss.softmaxCrossEntropyLoss(); Tracker lrt = Tracker.fixed(lr); Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build(); DefaultTrainingConfig config = new DefaultTrainingConfig(loss) .optOptimizer(sgd) // Optimizer .addEvaluator(new Accuracy()) // Model Accuracy .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging Trainer trainer = model.newTrainer(config); trainer.initialize(new Shape(1, 1, 28, 28)); train(trainDataset, testDataset, numEpochs, trainer); .. parsed-literal:: :class: output INFO Training on: 4 GPUs. INFO Load MXNet Engine Version 1.9.0 in 0.057 ms. .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 1 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 2 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 3 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 4 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 5 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 6 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 7 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 8 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 9 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 10 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output loss 2.304, train acc 0.100, test acc 0.100 8559.3 examples/sec .. code:: java public void plotMetrics() { String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length]; Arrays.fill(lossLabel, 0, trainLoss.length, "train loss"); Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc"); Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length, trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc"); Table data = Table.create("Data").addColumns( DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))), DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))), StringColumn.create("lossLabel", lossLabel) ); display(LinePlot.create("", data, "epoch", "metrics", "lossLabel")); } plotMetrics(); .. raw:: html
学习率调度器 ------------ 我们可以在每个迭代轮数(甚至在每个小批量)之后向下调整学习率。 例如,以动态的方式来响应优化的进展情况。 更通常而言,我们应该定义一个调度器。 当调用更新次数时,它将返回学习率的适当值。 让我们定义一个简单的方法,将学习率设置为\ :math:`\eta = \eta_0 (t + 1)^{-\frac{1}{2}}`\ 。 .. code:: java public class SquareRootTracker { float lr; public SquareRootTracker() { this(0.1f); } public SquareRootTracker(float learningRate) { this.lr = learningRate; } public float getNewLearningRate(int numUpdate) { return lr * (float) Math.pow(numUpdate + 1, -0.5); } } 让我们在一系列值上绘制它的行为。 .. code:: java public Figure plotLearningRate(int[] epochs, float[] learningRates) { String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length]; Arrays.fill(lossLabel, 0, trainLoss.length, "train loss"); Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc"); Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length, trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc"); Table data = Table.create("Data").addColumns( IntColumn.create("epoch", epochs), DoubleColumn.create("learning rate", learningRates) ); return LinePlot.create("Learning Rate vs. Epoch", data, "epoch", "learning rate"); } .. code:: java SquareRootTracker tracker = new SquareRootTracker(); int[] epochs = new int[numEpochs]; float[] learningRates = new float[numEpochs]; for (int i = 0; i < numEpochs; i++) { epochs[i] = i; learningRates[i] = tracker.getNewLearningRate(i); } plotLearningRate(epochs, learningRates); .. raw:: html
现在让我们来看看这对在Fashion-MNIST数据集上的训练有何影响。 我们只是提供调度器作为训练算法的额外参数。 这比以前好一些:曲线比以前更加平滑,并且过拟合更小了。 遗憾的是,关于为什么在理论上某些策略会导致较轻的过拟合,有一些观点认为,较小的步长将导致参数更接近零,因此更简单。 但是,这并不能完全解释这种现象,因为我们并没有真正地提前停止,而只是轻柔地降低了学习率。 策略 ---- 虽然我们不可能涵盖所有类型的学习率调度器,但我们会尝试在下面简要概述常用的策略:多项式衰减和分段常数表。 此外,余弦学习率调度在实践中的一些问题上运行效果很好。 在某些问题上,最好在使用较高的学习率之前预热优化器。 多因子调度器 ~~~~~~~~~~~~ 多项式衰减的一种替代方案是乘法衰减,即\ :math:`\eta_{t+1} \leftarrow \eta_t \cdot \alpha`\ 其中\ :math:`\alpha \in (0, 1)`\ 。为了防止学习率衰减超出合理的下限,更新方程经常修改为\ :math:`\eta_{t+1} \leftarrow \mathop{\mathrm{max}}(\eta_{\mathrm{min}}, \eta_t \cdot \alpha)`\ 。 .. code:: java public class DemoFactorTracker { float baseLr; float stopFactorLr; float factor; public DemoFactorTracker(float factor, float stopFactorLr, float baseLr) { this.factor = factor; this.stopFactorLr = stopFactorLr; this.baseLr = baseLr; } public DemoFactorTracker() { this(1f, (float) 1e-7, 0.1f); } public float getNewLearningRate(int numUpdate) { return lr * (float) Math.pow(numUpdate + 1, -0.5); } } .. code:: java DemoFactorTracker tracker = new DemoFactorTracker(0.9f, (float) 1e-2, 2); numEpochs = 50; int[] epochs = new int[numEpochs]; float[] learningRates = new float[numEpochs]; for (int i = 0; i < numEpochs; i++) { epochs[i] = i; learningRates[i] = tracker.getNewLearningRate(i); } plotLearningRate(epochs, learningRates); .. raw:: html
接下来,我们将使用内置的调度器,但在这里仅解释它们的功能。 多因子调度器 ~~~~~~~~~~~~ 训练深度网络的常见策略之一是保持分段稳定的学习率,并且每隔一段时间就一定程度学习率降低。 具体地说,给定一组降低学习率的时间,例如\ :math:`s = \{5, 10, 20\}`\ 每当\ :math:`t \in s`\ 时降低\ :math:`\eta_{t+1} \leftarrow \eta_t \cdot \alpha`\ 。 假设每步中的值减半,我们可以按如下方式实现这一点。 .. code:: java MultiFactorTracker tracker = Tracker.multiFactor() .setSteps(new int[]{5, 30}) .optFactor(0.5f) .setBaseValue(0.5f) .build(); numEpochs = 10; int[] epochs = new int[numEpochs]; float[] learningRates = new float[numEpochs]; for (int i = 0; i < numEpochs; i++) { epochs[i] = i; learningRates[i] = tracker.getNewValue(i); } plotLearningRate(epochs, learningRates); .. raw:: html
这种分段恒定学习率调度背后的直觉是,让优化持续进行,直到权重向量的分布达到一个驻点。 此时,我们才将学习率降低,以获得更高质量的代理来达到一个良好的局部最小值。 下面的例子展示了如何使用这种方法产生更好的解决方案。 .. code:: java int numEpochs = Integer.getInteger("MAX_EPOCH", 10); Model model = Model.newInstance("Modern LeNet"); model.setBlock(net); Loss loss = Loss.softmaxCrossEntropyLoss(); Optimizer sgd = Optimizer.sgd().setLearningRateTracker(tracker).build(); DefaultTrainingConfig config = new DefaultTrainingConfig(loss) .optOptimizer(sgd) // Optimizer .addEvaluator(new Accuracy()) // Model Accuracy .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging Trainer trainer = model.newTrainer(config); trainer.initialize(new Shape(1, 1, 28, 28)); train(trainDataset, testDataset, numEpochs, trainer); plotMetrics(); .. parsed-literal:: :class: output INFO Training on: 4 GPUs. INFO Load MXNet Engine Version 1.9.0 in 0.020 ms. .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 1 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 2 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 3 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 4 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 5 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 6 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 7 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 8 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 9 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 10 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output loss 2.303, train acc 0.100, test acc 0.100 10782.6 examples/sec .. raw:: html
.. code:: java String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length]; Arrays.fill(lossLabel, 0, trainLoss.length, "train loss"); Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc"); Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length, trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc"); Table data = Table.create("Data").addColumns( DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))), DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))), StringColumn.create("lossLabel", lossLabel) ); LinePlot.create("", data, "epoch", "metrics", "lossLabel"); .. raw:: html
余弦调度器 ~~~~~~~~~~ 余弦调度器是 :cite:`Loshchilov.Hutter.2016`\ 提出的一种启发式算法。 它所依据的观点是:我们可能不想在一开始就太大地降低学习率,而且可能希望最终能用非常小的学习率来“改进”解决方案。 这产生了一个类似于余弦的调度,函数形式如下所示,学习率的值在\ :math:`t \in [0, T]`\ 之间。 .. math:: \eta_t = \eta_T + \frac{\eta_0 - \eta_T}{2} \left(1 + \cos(\pi t/T)\right) 这里\ :math:`\eta_0`\ 是初始学习率,\ :math:`\eta_T`\ 是当\ :math:`T`\ 时的目标学习率。 此外,对于\ :math:`t > T`\ ,我们只需将值固定到\ :math:`\eta_T`\ 而不再增加它。 在下面的示例中,我们设置了最大更新步数\ :math:`T = 20`\ 。 .. code:: java public class DemoCosineTracker { float baseLr; float finalLr; int maxUpdate; public DemoCosineTracker() { this(0.5f, 0.01f, 20); } public DemoCosineTracker(float baseLr, float finalLr, int maxUpdate) { this.baseLr = baseLr; this.finalLr = finalLr; this.maxUpdate = maxUpdate; } public float getNewLearningRate(int numUpdate) { if (numUpdate > maxUpdate) { return finalLr; } // Scale the curve to smoothly transition float step = (baseLr - finalLr) / 2 * (1 + (float) Math.cos(Math.PI * numUpdate / maxUpdate)); return finalLr + step; } } .. code:: java DemoCosineTracker tracker = new DemoCosineTracker(0.5f, 0.01f, 20); int[] epochs = new int[numEpochs]; float[] learningRates = new float[numEpochs]; for (int i = 0; i < numEpochs; i++) { epochs[i] = i; learningRates[i] = tracker.getNewLearningRate(i); } plotLearningRate(epochs, learningRates); .. raw:: html
在计算机视觉中,这个调度可以引出改进的结果。 但请注意,如下所示,这种改进并不能保证。 .. code:: java CosineTracker cosineTracker = Tracker.cosine() .setBaseValue(0.5f) .optFinalValue(0.01f) .setMaxUpdates(20) .build(); Loss loss = Loss.softmaxCrossEntropyLoss(); Optimizer sgd = Optimizer.sgd().setLearningRateTracker(cosineTracker).build(); DefaultTrainingConfig config = new DefaultTrainingConfig(loss) .optOptimizer(sgd) // Optimizer .addEvaluator(new Accuracy()) // Model Accuracy .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging Trainer trainer = model.newTrainer(config); trainer.initialize(new Shape(1, 1, 28, 28)); train(trainDataset, testDataset, numEpochs, trainer); .. parsed-literal:: :class: output INFO Training on: 4 GPUs. INFO Load MXNet Engine Version 1.9.0 in 0.022 ms. .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 1 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 2 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 3 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 4 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 5 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 6 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 7 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 8 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 9 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 10 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output loss 2.303, train acc 0.096, test acc 0.100 10674.0 examples/sec 预热 ~~~~ 在某些情况下,初始化参数不足以得到良好的解。 这对于某些高级网络设计来说尤其棘手,可能导致不稳定的优化结果。 对此,一方面,我们可以选择一个足够小的学习率, 从而防止一开始发散,然而这样进展太缓慢。 另一方面,较高的学习率最初就会导致发散。 解决这种困境的一个相当简单的解决方法是使用预热期,在此期间学习率将增加至初始最大值,然后冷却直到优化过程结束。 为了简单起见,通常使用线性递增。 这引出了如下表所示的时间表。 .. code:: java public class CosineWarmupTracker { float baseLr; float finalLr; int maxUpdate; int warmUpSteps; float warmUpBeginValue; float warmUpFinalValue; public CosineWarmupTracker() { this(0.5f, 0.01f, 20, 5); } public CosineWarmupTracker(float baseLr, float finalLr, int maxUpdate, int warmUpSteps) { this.baseLr = baseLr; this.finalLr = finalLr; this.maxUpdate = maxUpdate; this.warmUpSteps = 5; this.warmUpBeginValue = 0f; } public float getNewLearningRate(int numUpdate) { if (numUpdate <= warmUpSteps) { return getWarmUpValue(numUpdate); } if (numUpdate > maxUpdate) { return finalLr; } // Scale the cosine curve to fit smoothly with the warmup steps float step = (baseLr - finalLr) / 2 * (1 + (float) Math.cos(Math.PI * (numUpdate - warmUpSteps) / (maxUpdate - warmUpSteps))); return finalLr + step; } public float getWarmUpValue(int numUpdate) { // Linear warmup return warmUpBeginValue + (baseLr - warmUpBeginValue) * numUpdate / warmUpSteps; } } .. code:: java CosineWarmupTracker tracker = new CosineWarmupTracker(0.5f, 0.01f, 20, 5); int[] epochs = new int[numEpochs]; float[] learningRates = new float[numEpochs]; for (int i = 0; i < numEpochs; i++) { epochs[i] = i; learningRates[i] = tracker.getNewLearningRate(i); } plotLearningRate(epochs, learningRates); .. raw:: html
注意,观察前5个迭代轮数的性能,网络最初收敛得更好。 .. code:: java CosineTracker cosineTracker = Tracker.cosine() .setBaseValue(0.5f) .optFinalValue(0.01f) .setMaxUpdates(15) .build(); WarmUpTracker warmupCosine = Tracker.warmUp() .optWarmUpSteps(5) .setMainTracker(cosineTracker) .build(); Loss loss = Loss.softmaxCrossEntropyLoss(); Optimizer sgd = Optimizer.sgd().setLearningRateTracker(warmupCosine).build(); DefaultTrainingConfig config = new DefaultTrainingConfig(loss) .optOptimizer(sgd) // Optimizer .addEvaluator(new Accuracy()) // Model Accuracy .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging Trainer trainer = model.newTrainer(config); trainer.initialize(new Shape(1, 1, 28, 28)); train(trainDataset, testDataset, numEpochs, trainer); plotMetrics(); .. parsed-literal:: :class: output INFO Training on: 4 GPUs. INFO Load MXNet Engine Version 1.9.0 in 0.020 ms. .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 1 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 2 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 3 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 4 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 5 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 6 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 7 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 8 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 9 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output Training: 100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 Validating: 100% |████████████████████████████████████████| .. parsed-literal:: :class: output INFO Epoch 10 finished. INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30 .. parsed-literal:: :class: output loss 2.303, train acc 0.096, test acc 0.100 10864.1 examples/sec .. raw:: html
预热可以应用于任何调度器,而不仅仅是余弦。 有关学习率调度的更多实验和更详细讨论,请参阅 :cite:`Gotmare.Keskar.Xiong.ea.2018`\ 。 其中,这篇论文的点睛之笔的发现:预热阶段限制了非常深的网络中参数的发散量。 这在直觉上是有道理的:在网络中那些一开始花费最多时间取得进展的部分,随机初始化会产生巨大的发散。 小结 ---- - 在训练期间逐步降低学习率可以提高准确性,并且减少模型的过拟合。 - 在实验中,每当进展趋于稳定时就降低学习率,这是很有效的。从本质上说,这可以确保我们有效地收敛到一个适当的解,也只有这样才能通过降低学习率来减小参数的固有方差。 - 余弦调度器在某些计算机视觉问题中很受欢迎。 - 优化之前的预热期可以防止发散。 - 优化在深度学习中有多种用途。对于同样的训练误差而言,选择不同的优化算法和学习率调度,除了最大限度地减少训练时间,可以导致测试集上不同的泛化和过拟合量。 练习 ---- 1. 试验给定固定学习率的优化行为。这种情况下你可以获得的最佳模型是什么? 2. 如果你改变学习率下降的指数,收敛性会如何改变?在实验中方便起见,使用\ ``PolyScheduler``\ 。 3. 将余弦调度器应用于大型计算机视觉问题,例如训练ImageNet数据集。与其他调度器相比,它如何影响性能? 4. 预热应该持续多长时间? 5. 你能把优化和采样联系起来吗?首先,在随机梯度朗之万动力学上使用 :cite:`Welling.Teh.2011`\ 的结果。