Run this notebook online:Binder or Colab: Colab

11.11. 学习率调度器

到目前为止,我们主要关注如何更新权重向量的优化算法,而不是它们的更新速率。 然而,调整学习率通常与实际算法同样重要,有如下几方面需要考虑:

  • 首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果。我们之前看到问题的条件数很重要(有关详细信息,请参见 Section 11.6)。直观地说,这是最不敏感与最敏感方向的变化量的比率。

  • 其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。 Section 11.5比较详细地讨论了这一点,在 Section 11.4中我们则分析了性能保证。简而言之,我们希望速率衰减,但要比\(\mathcal{O}(t^{-\frac{1}{2}})\)慢,这样能成为解决凸问题的不错选择。

  • 另一个同样重要的方面是初始化。这既涉及参数最初的设置方式(详情请参阅 Section 4.8),又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的。

  • 最后,还有许多优化变体可以执行周期性学习率调整。这超出了本章的范围,我们建议读者阅读 [Izmailov et al., 2018]来了解个中细节。例如,如何通过对整个路径参数求平均值来获得更好的解。

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。 在本章中,我们将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

11.11.1. 一个简单的问题

我们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。 由于大多数代码都是标准的,我们只介绍基础知识,而不做进一步的详细讨论。如果需要,请参阅 Section 6进行复习。

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/GradDescUtils.java
%load ../utils/Accumulator.java
%load ../utils/StopWatch.java

%load ../utils/Training.java
%load ../utils/TrainingChapter11.java
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;
SequentialBlock net = new SequentialBlock();

net.add(Conv2d.builder()
        .setKernelShape(new Shape(5, 5))
        .optPadding(new Shape(2, 2))
        .setFilters(1)
        .build());
net.add(Activation.reluBlock());
net.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));
net.add(Conv2d.builder()
        .setKernelShape(new Shape(5, 5))
        .setFilters(1)
        .build());
net.add(Blocks.batchFlattenBlock());
net.add(Activation.reluBlock());
net.add(Linear.builder().setUnits(120).build());
net.add(Activation.reluBlock());
net.add(Linear.builder().setUnits(84).build());
net.add(Activation.reluBlock());
net.add(Linear.builder().setUnits(10).build());
SequentialBlock {
    Conv2d
    ReLU
    maxPool2d
    Conv2d
    batchFlatten
    ReLU
    Linear
    ReLU
    Linear
    ReLU
    Linear
}
int batchSize = 256;
RandomAccessDataset trainDataset = FashionMnist.builder()
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize, false)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

RandomAccessDataset testDataset = FashionMnist.builder()
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, false)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();
double[] trainLoss;
double[] testAccuracy;
double[] epochCount;
double[] trainAccuracy;

public static void train(RandomAccessDataset trainIter, RandomAccessDataset testIter,
                             int numEpochs, Trainer trainer) throws IOException, TranslateException {
    epochCount = new double[numEpochs];

    for (int i = 0; i < epochCount.length; i++) {
        epochCount[i] = (i + 1);
    }

    double avgTrainTimePerEpoch = 0;
    Map<String, double[]> evaluatorMetrics = new HashMap<>();

    trainer.setMetrics(new Metrics());

    EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

    Metrics metrics = trainer.getMetrics();

    trainer.getEvaluators().stream()
            .forEach(evaluator -> {
                evaluatorMetrics.put("train_epoch_" + evaluator.getName(), metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());
            });

    avgTrainTimePerEpoch = metrics.mean("epoch");

    trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
    trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
    testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

    System.out.printf("loss %.3f," , trainLoss[numEpochs-1]);
    System.out.printf(" train acc %.3f," , trainAccuracy[numEpochs-1]);
    System.out.printf(" test acc %.3f\n" , testAccuracy[numEpochs-1]);
    System.out.printf("%.1f examples/sec \n", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
}

让我们来看看如果使用默认设置,调用此算法会发生什么。 例如设学习率为\(0.3\)并训练\(30\)次迭代。 留意在超过了某点、测试准确度方面的进展停滞时,训练准确度将如何继续提高。 两条曲线之间的间隙表示过拟合。

float lr = 0.3f;
int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

Model model = Model.newInstance("Modern LeNet");
model.setBlock(net);

Loss loss = Loss.softmaxCrossEntropyLoss();
Tracker lrt = Tracker.fixed(lr);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // Optimizer
        .addEvaluator(new Accuracy()) // Model Accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(1, 1, 28, 28));

train(trainDataset, testDataset, numEpochs, trainer);
INFO Training on: 4 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.057 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
loss 2.304, train acc 0.100, test acc 0.100
8559.3 examples/sec
public void plotMetrics() {
    String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

    Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
    Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc");
    Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                    trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc");

    Table data = Table.create("Data").addColumns(
        DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
        DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),
        StringColumn.create("lossLabel", lossLabel)
    );

    display(LinePlot.create("", data, "epoch", "metrics", "lossLabel"));
}

plotMetrics();

11.11.2. 学习率调度器

我们可以在每个迭代轮数(甚至在每个小批量)之后向下调整学习率。 例如,以动态的方式来响应优化的进展情况。 更通常而言,我们应该定义一个调度器。 当调用更新次数时,它将返回学习率的适当值。 让我们定义一个简单的方法,将学习率设置为\(\eta = \eta_0 (t + 1)^{-\frac{1}{2}}\)

public class SquareRootTracker {
    float lr;
    public SquareRootTracker() {
        this(0.1f);
    }
    public SquareRootTracker(float learningRate) {
        this.lr = learningRate;
    }
    public float getNewLearningRate(int numUpdate) {
        return lr * (float) Math.pow(numUpdate + 1, -0.5);
    }
}

让我们在一系列值上绘制它的行为。

public Figure plotLearningRate(int[] epochs, float[] learningRates) {

    String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

    Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
    Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc");
    Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                    trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc");

    Table data = Table.create("Data").addColumns(
                IntColumn.create("epoch", epochs),
                DoubleColumn.create("learning rate", learningRates)
    );

    return LinePlot.create("Learning Rate vs. Epoch", data, "epoch", "learning rate");
}
SquareRootTracker tracker = new SquareRootTracker();

int[] epochs = new int[numEpochs];
float[] learningRates = new float[numEpochs];
for (int i = 0; i < numEpochs; i++) {
    epochs[i] = i;
    learningRates[i] = tracker.getNewLearningRate(i);
}

plotLearningRate(epochs, learningRates);

现在让我们来看看这对在Fashion-MNIST数据集上的训练有何影响。 我们只是提供调度器作为训练算法的额外参数。

这比以前好一些:曲线比以前更加平滑,并且过拟合更小了。 遗憾的是,关于为什么在理论上某些策略会导致较轻的过拟合,有一些观点认为,较小的步长将导致参数更接近零,因此更简单。 但是,这并不能完全解释这种现象,因为我们并没有真正地提前停止,而只是轻柔地降低了学习率。

11.11.3. 策略

虽然我们不可能涵盖所有类型的学习率调度器,但我们会尝试在下面简要概述常用的策略:多项式衰减和分段常数表。 此外,余弦学习率调度在实践中的一些问题上运行效果很好。 在某些问题上,最好在使用较高的学习率之前预热优化器。

11.11.3.1. 多因子调度器

多项式衰减的一种替代方案是乘法衰减,即\(\eta_{t+1} \leftarrow \eta_t \cdot \alpha\)其中\(\alpha \in (0, 1)\)。为了防止学习率衰减超出合理的下限,更新方程经常修改为\(\eta_{t+1} \leftarrow \mathop{\mathrm{max}}(\eta_{\mathrm{min}}, \eta_t \cdot \alpha)\)

public class DemoFactorTracker {
    float baseLr;
    float stopFactorLr;
    float factor;
    public DemoFactorTracker(float factor, float stopFactorLr, float baseLr) {
        this.factor = factor;
        this.stopFactorLr = stopFactorLr;
        this.baseLr = baseLr;
    }
    public DemoFactorTracker() {
        this(1f, (float) 1e-7, 0.1f);
    }
    public float getNewLearningRate(int numUpdate) {
        return lr * (float) Math.pow(numUpdate + 1, -0.5);
    }
}
DemoFactorTracker tracker = new DemoFactorTracker(0.9f, (float) 1e-2, 2);

numEpochs = 50;
int[] epochs = new int[numEpochs];
float[] learningRates = new float[numEpochs];
for (int i = 0; i < numEpochs; i++) {
    epochs[i] = i;
    learningRates[i] = tracker.getNewLearningRate(i);
}

plotLearningRate(epochs, learningRates);

接下来,我们将使用内置的调度器,但在这里仅解释它们的功能。

11.11.3.2. 多因子调度器

训练深度网络的常见策略之一是保持分段稳定的学习率,并且每隔一段时间就一定程度学习率降低。 具体地说,给定一组降低学习率的时间,例如\(s = \{5, 10, 20\}\)每当\(t \in s\)时降低\(\eta_{t+1} \leftarrow \eta_t \cdot \alpha\)。 假设每步中的值减半,我们可以按如下方式实现这一点。

MultiFactorTracker tracker = Tracker.multiFactor()
        .setSteps(new int[]{5, 30})
        .optFactor(0.5f)
        .setBaseValue(0.5f)
        .build();

numEpochs = 10;
int[] epochs = new int[numEpochs];
float[] learningRates = new float[numEpochs];
for (int i = 0; i < numEpochs; i++) {
    epochs[i] = i;
    learningRates[i] = tracker.getNewValue(i);
}

plotLearningRate(epochs, learningRates);

这种分段恒定学习率调度背后的直觉是,让优化持续进行,直到权重向量的分布达到一个驻点。 此时,我们才将学习率降低,以获得更高质量的代理来达到一个良好的局部最小值。 下面的例子展示了如何使用这种方法产生更好的解决方案。

int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

Model model = Model.newInstance("Modern LeNet");
model.setBlock(net);

Loss loss = Loss.softmaxCrossEntropyLoss();
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(tracker).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // Optimizer
        .addEvaluator(new Accuracy()) // Model Accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(1, 1, 28, 28));

train(trainDataset, testDataset, numEpochs, trainer);
plotMetrics();
INFO Training on: 4 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.020 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
loss 2.303, train acc 0.100, test acc 0.100
10782.6 examples/sec
String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc");

Table data = Table.create("Data").addColumns(
            DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
            DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),
            StringColumn.create("lossLabel", lossLabel)
);

LinePlot.create("", data, "epoch", "metrics", "lossLabel");

11.11.3.3. 余弦调度器

余弦调度器是 [Loshchilov & Hutter, 2016]提出的一种启发式算法。 它所依据的观点是:我们可能不想在一开始就太大地降低学习率,而且可能希望最终能用非常小的学习率来“改进”解决方案。 这产生了一个类似于余弦的调度,函数形式如下所示,学习率的值在\(t \in [0, T]\)之间。

(11.11.1)\[\eta_t = \eta_T + \frac{\eta_0 - \eta_T}{2} \left(1 + \cos(\pi t/T)\right)\]

这里\(\eta_0\)是初始学习率,\(\eta_T\)是当\(T\)时的目标学习率。 此外,对于\(t > T\),我们只需将值固定到\(\eta_T\)而不再增加它。 在下面的示例中,我们设置了最大更新步数\(T = 20\)

public class DemoCosineTracker {
    float baseLr;
    float finalLr;
    int maxUpdate;
    public DemoCosineTracker() {
        this(0.5f, 0.01f, 20);
    }
    public DemoCosineTracker(float baseLr, float finalLr, int maxUpdate) {
        this.baseLr = baseLr;
        this.finalLr = finalLr;
        this.maxUpdate = maxUpdate;
    }
    public float getNewLearningRate(int numUpdate) {
        if (numUpdate > maxUpdate) {
            return finalLr;
        }
        // Scale the curve to smoothly transition
        float step = (baseLr - finalLr) / 2 * (1 + (float) Math.cos(Math.PI * numUpdate / maxUpdate));
        return finalLr + step;
    }
}
DemoCosineTracker tracker = new DemoCosineTracker(0.5f, 0.01f, 20);

int[] epochs = new int[numEpochs];
float[] learningRates = new float[numEpochs];
for (int i = 0; i < numEpochs; i++) {
    epochs[i] = i;
    learningRates[i] = tracker.getNewLearningRate(i);
}

plotLearningRate(epochs, learningRates);

在计算机视觉中,这个调度可以引出改进的结果。 但请注意,如下所示,这种改进并不能保证。

CosineTracker cosineTracker = Tracker.cosine()
                            .setBaseValue(0.5f)
                            .optFinalValue(0.01f)
                            .setMaxUpdates(20)
                            .build();

Loss loss = Loss.softmaxCrossEntropyLoss();
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(cosineTracker).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // Optimizer
        .addEvaluator(new Accuracy()) // Model Accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(1, 1, 28, 28));

train(trainDataset, testDataset, numEpochs, trainer);
INFO Training on: 4 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.022 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
loss 2.303, train acc 0.096, test acc 0.100
10674.0 examples/sec

11.11.3.4. 预热

在某些情况下,初始化参数不足以得到良好的解。 这对于某些高级网络设计来说尤其棘手,可能导致不稳定的优化结果。 对此,一方面,我们可以选择一个足够小的学习率, 从而防止一开始发散,然而这样进展太缓慢。 另一方面,较高的学习率最初就会导致发散。

解决这种困境的一个相当简单的解决方法是使用预热期,在此期间学习率将增加至初始最大值,然后冷却直到优化过程结束。 为了简单起见,通常使用线性递增。 这引出了如下表所示的时间表。

public class CosineWarmupTracker {
    float baseLr;
    float finalLr;
    int maxUpdate;
    int warmUpSteps;
    float warmUpBeginValue;
    float warmUpFinalValue;

    public CosineWarmupTracker() {
        this(0.5f, 0.01f, 20, 5);
    }

    public CosineWarmupTracker(float baseLr, float finalLr, int maxUpdate, int warmUpSteps) {
        this.baseLr = baseLr;
        this.finalLr = finalLr;
        this.maxUpdate = maxUpdate;
        this.warmUpSteps = 5;
        this.warmUpBeginValue = 0f;
    }

    public float getNewLearningRate(int numUpdate) {
        if (numUpdate <= warmUpSteps) {
            return getWarmUpValue(numUpdate);
        }
        if (numUpdate > maxUpdate) {
            return finalLr;
        }
        // Scale the cosine curve to fit smoothly with the warmup steps
        float step = (baseLr - finalLr) / 2 * (1 +
            (float) Math.cos(Math.PI * (numUpdate - warmUpSteps) / (maxUpdate - warmUpSteps)));
        return finalLr + step;
    }

    public float getWarmUpValue(int numUpdate) {
        // Linear warmup
        return warmUpBeginValue + (baseLr - warmUpBeginValue) * numUpdate / warmUpSteps;
    }
}
CosineWarmupTracker tracker = new CosineWarmupTracker(0.5f, 0.01f, 20, 5);

int[] epochs = new int[numEpochs];
float[] learningRates = new float[numEpochs];
for (int i = 0; i < numEpochs; i++) {
    epochs[i] = i;
    learningRates[i] = tracker.getNewLearningRate(i);
}

plotLearningRate(epochs, learningRates);

注意,观察前5个迭代轮数的性能,网络最初收敛得更好。

CosineTracker cosineTracker = Tracker.cosine()
        .setBaseValue(0.5f)
        .optFinalValue(0.01f)
        .setMaxUpdates(15)
        .build();

WarmUpTracker warmupCosine = Tracker.warmUp()
        .optWarmUpSteps(5)
        .setMainTracker(cosineTracker)
        .build();

Loss loss = Loss.softmaxCrossEntropyLoss();
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(warmupCosine).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // Optimizer
        .addEvaluator(new Accuracy()) // Model Accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);
trainer.initialize(new Shape(1, 1, 28, 28));

train(trainDataset, testDataset, numEpochs, trainer);
plotMetrics();
INFO Training on: 4 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.020 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
INFO Validate: Accuracy: 0.10, SoftmaxCrossEntropyLoss: 2.30
loss 2.303, train acc 0.096, test acc 0.100
10864.1 examples/sec

预热可以应用于任何调度器,而不仅仅是余弦。 有关学习率调度的更多实验和更详细讨论,请参阅 [Gotmare et al., 2018]。 其中,这篇论文的点睛之笔的发现:预热阶段限制了非常深的网络中参数的发散量。 这在直觉上是有道理的:在网络中那些一开始花费最多时间取得进展的部分,随机初始化会产生巨大的发散。

11.11.4. 小结

  • 在训练期间逐步降低学习率可以提高准确性,并且减少模型的过拟合。

  • 在实验中,每当进展趋于稳定时就降低学习率,这是很有效的。从本质上说,这可以确保我们有效地收敛到一个适当的解,也只有这样才能通过降低学习率来减小参数的固有方差。

  • 余弦调度器在某些计算机视觉问题中很受欢迎。

  • 优化之前的预热期可以防止发散。

  • 优化在深度学习中有多种用途。对于同样的训练误差而言,选择不同的优化算法和学习率调度,除了最大限度地减少训练时间,可以导致测试集上不同的泛化和过拟合量。

11.11.5. 练习

  1. 试验给定固定学习率的优化行为。这种情况下你可以获得的最佳模型是什么?

  2. 如果你改变学习率下降的指数,收敛性会如何改变?在实验中方便起见,使用PolyScheduler

  3. 将余弦调度器应用于大型计算机视觉问题,例如训练ImageNet数据集。与其他调度器相比,它如何影响性能?

  4. 预热应该持续多长时间?

  5. 你能把优化和采样联系起来吗?首先,在随机梯度朗之万动力学上使用 [Welling & Teh, 2011]的结果。