Run this notebook online: or Colab:
11.5. 小批量随机梯度下降¶
到目前为止,我们在基于梯度的学习方法中遇到了两个极端情况: Section 11.3中使用完整数据集来计算梯度并更新参数, Section 11.4中一次处理一个训练样本来取得进展。 二者各有利弊:每当数据非常相似时,梯度下降并不是非常“数据高效”。 而由于CPU和GPU无法充分利用向量化,随机梯度下降并不特别“计算高效”。 这暗示了两者之间可能有折中方案,这便涉及到小批量随机梯度下降(minibatch gradient descent)。
11.5.1. 向量化和缓存¶
使用小批量的决策的核心是计算效率。 当考虑与多个GPU和多台服务器并行处理时,这一点最容易被理解。在这种情况下,我们需要向每个GPU发送至少一张图像。 有了每台服务器8个GPU和16台服务器,我们就能得到大小为128的小批量。
当涉及到单个GPU甚至CPU时,事情会更微妙一些: 这些设备有多种类型的内存、通常情况下多种类型的计算单元以及在它们之间不同的带宽限制。 例如,一个CPU有少量寄存器(register),L1和L2缓存,以及L3缓存(在不同的处理器内核之间共享)。 随着缓存的大小的增加,它们的延迟也在增加,同时带宽在减少。 可以说,处理器能够执行的操作远比主内存接口所能提供的多得多。
首先,具有16个内核和AVX-512向量化的2GHz CPU每秒可处理高达\(2 \cdot 10^9 \cdot 16 \cdot 32 = 10^{12}\)个字节。 同时,GPU的性能很容易超过该数字100倍。 而另一方面,中端服务器处理器的带宽可能不超过100Gb/s,即不到处理器满负荷所需的十分之一。 更糟糕的是,并非所有的内存入口都是相等的:内存接口通常为64位或更宽(例如,在最多384位的GPU上)。 因此读取单个字节会导致由于更宽的存取而产生的代价。
其次,第一次存取的额外开销很大,而按序存取(sequential access)或突发读取(burst read)相对开销较小。 有关更深入的讨论,请参阅此维基百科文章。
减轻这些限制的方法是使用足够快的CPU缓存层次结构来为处理器提供数据。 这是深度学习中批量处理背后的推动力。 举一个简单的例子:矩阵-矩阵乘法。 比如\(\mathbf{A} = \mathbf{B}\mathbf{C}\),我们有很多方法来计算\(\mathbf{A}\)。例如,我们可以尝试以下方法:
我们可以计算\(\mathbf{A}_{ij} = \mathbf{B}_{i,:} \mathbf{C}_{:,j}^\top\),也就是说,我们可以通过点积进行逐元素计算。
我们可以计算\(\mathbf{A}_{:,j} = \mathbf{B} \mathbf{C}_{:,j}^\top\),也就是说,我们可以一次计算一列。同样,我们可以一次计算\(\mathbf{A}\)一行\(\mathbf{A}_{i,:}\)。
我们可以简单地计算\(\mathbf{A} = \mathbf{B} \mathbf{C}\)。
我们可以将\(\mathbf{B}\)和\(\mathbf{C}\)分成较小的区块矩阵,然后一次计算\(\mathbf{A}\)的一个区块。
如果我们使用第一个选择,每次我们计算一个元素\(\mathbf{A}_{ij}\)时,都需要将一行和一列向量复制到CPU中。 更糟糕的是,由于矩阵元素是按顺序对齐的,因此当从内存中读取它们时,我们需要访问两个向量中许多不相交的位置。 第二种选择相对更有利:我们能够在遍历\(\mathbf{B}\)的同时,将列向量\(\mathbf{C}_{:,j}\)保留在CPU缓存中。 它将内存带宽需求减半,相应地提高了访问速度。 第三种选择表面上是最可取的,然而大多数矩阵可能不能完全放入缓存中。 第四种选择提供了一个实践上很有用的方案:我们可以将矩阵的区块移到缓存中然后在本地将它们相乘。 让我们来看看这些操作在实践中的效率如何。
除了计算效率之外,Python和深度学习框架本身带来的额外开销也是相当大的。 回想一下,每次我们执行代码时,Python解释器都会向深度学习框架发送一个命令,要求将其插入到计算图中并在调度过程中处理它。 这样的额外开销可能是非常不利的。 总而言之,我们最好用向量化(和矩阵)。
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/StopWatch.java
%load ../utils/Training.java
%load ../utils/Accumulator.java
import ai.djl.basicdataset.tabular.*;
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;
NDManager manager = NDManager.newBaseManager();
StopWatch stopWatch = new StopWatch();
NDArray A = manager.zeros(new Shape(256, 256));
NDArray B = manager.randomNormal(new Shape(256, 256));
NDArray C = manager.randomNormal(new Shape(256, 256));
按元素分配只需遍历分别为\(\mathbf{B}\)和\(\mathbf{C}\)的所有行和列,即可将该值分配给\(\mathbf{A}\)。
// 逐元素计算A=BC
stopWatch.start();
for (int i = 0; i < 256; i++) {
for (int j = 0; j < 256; j++) {
A.set(new NDIndex(i, j),
B.get(new NDIndex(String.format("%d, :", i)))
.dot(C.get(new NDIndex(String.format(":, %d", j)))));
}
}
stopWatch.stop();
41.33436005
更快的策略是执行按列分配。
// 逐列计算A=BC
stopWatch.start();
for (int j = 0; j < 256; j++) {
A.set(new NDIndex(String.format(":, %d", j)), B.dot(C.get(new NDIndex(String.format(":, %d", j)))));
}
stopWatch.stop();
0.174655732
最有效的方法是在一个区块中执行整个操作。让我们看看它们各自的操作速度是多少。
// 一次性计算A=BC
stopWatch.start();
A = B.dot(C);
stopWatch.stop();
// Multiply and add count as separate operations (fused in practice)
float[] gigaflops = new float[stopWatch.getTimes().size()];
for (int i = 0; i < stopWatch.getTimes().size(); i++) {
gigaflops[i] = (float)(2 / stopWatch.getTimes().get(i));
}
String.format("Performance in Gigaflops: element %.3f, column %.3f, full %.3f", gigaflops[0], gigaflops[1], gigaflops[2]);
Performance in Gigaflops: element 0.048, column 11.451, full 48.620
11.5.2. 小批量¶
之前我们会理所当然地读取数据的小批量,而不是观测单个数据来更新参数,现在简要解释一下原因。 处理单个观测值需要我们执行许多单一矩阵-矢量(甚至矢量-矢量)乘法,这耗费相当大,而且对应深度学习框架也要巨大的开销。 这既适用于计算梯度以更新参数时,也适用于用神经网络预测。 也就是说,每当我们执行\(\mathbf{w} \leftarrow \mathbf{w} - \eta_t \mathbf{g}_t\)时,消耗巨大。其中
。
我们可以通过将其应用于一个小批量观测值来提高此操作的计算效率。 也就是说,我们将梯度\(\mathbf{g}_t\)替换为一个小批量而不是单个观测值
让我们看看这对\(\mathbf{g}_t\)的统计属性有什么影响:由于\(\mathbf{x}_t\)和小批量\(\mathcal{B}_t\)的所有元素都是从训练集中随机抽出的,因此梯度的预期保持不变。 另一方面,方差显著降低。 由于小批量梯度由正在被平均计算的\(b := |\mathcal{B}_t|\)个独立梯度组成,其标准差降低了\(b^{-\frac{1}{2}}\)。 这本身就是一件好事,因为这意味着更新与完整的梯度更接近了。
天真点说,这表明选择大型的小批量\(\mathcal{B}_t\)将是普遍可用的。 然而,经过一段时间后,与计算代价的线性增长相比,标准差的额外减少是微乎其微的。 在实践中我们选择一个足够大的小批量,它可以提供良好的计算效率同时仍适合GPU的内存。 下面,我们来看看这些高效的代码。 在里面我们执行相同的矩阵-矩阵乘法,但是这次我们将其一次性分为64列的“小批量”。
stopWatch.start();
for (int j = 0; j < 256; j+=64) {
A.set(new NDIndex(String.format(":, %d:%d", j, j + 64)),
B.dot(C.get(new NDIndex(String.format(":, %d:%d", j, j + 64)))));
}
stopWatch.stop();
String.format("Performance in Gigaflops: block %.3f\n", 2 / stopWatch.getTimes().get(3));
Performance in Gigaflops: block 41.325
显而易见,小批量上的计算基本上与完整矩阵一样有效。 需要注意的是,在 Section 7.5中,我们使用了一种在很大程度上取决于小批量中的方差的正则化。 随着后者增加,方差会减少,随之而来的是批量规范化带来的噪声注入的好处。 关于实例,请参阅 [Ioffe, 2017],了解有关如何重新缩放并计算适当项目。
11.5.3. 读取数据集¶
让我们来看看如何从数据中有效地生成小批量。 下面我们使用NASA开发的测试机翼的数据集不同飞行器产生的噪声来比较这些优化算法。 为方便起见,我们只使用前\(1,500\)样本。 数据已作预处理:我们移除了均值并将方差重新缩放到每个坐标为\(1\)。
NDManager manager = NDManager.newBaseManager();
public AirfoilRandomAccess getDataCh11(int batchSize, int n) throws IOException, TranslateException {
// Load data
AirfoilRandomAccess airfoil = AirfoilRandomAccess.builder()
.optUsage(Dataset.Usage.TRAIN)
.setSampling(batchSize, true)
.optNormalize(true)
.optLimit(n)
.build();
return airfoil;
}
11.5.4. 从零开始实现¶
Section 3.2一节中已经实现过小批量随机梯度下降算法。
我们在这里将它的输入参数变得更加通用,主要是为了方便本章后面介绍的其他优化算法也可以使用同样的输入。
具体来说,我们添加了一个状态输入states
并将超参数放在字典hyperparams
中。
此外,我们将在训练函数里对各个小批量样本的损失求平均,因此优化算法中的梯度不需要除以批量大小。
public class Optimization {
public static void sgd(NDList params, NDList states, Map<String, Float> hyperparams) {
for (int i = 0; i < params.size(); i++) {
NDArray param = params.get(i);
// Update param
// param = param - param.gradient * lr
param.subi(param.getGradient().mul(hyperparams.get("lr")));
}
}
}
下面实现一个通用的训练函数,以方便本章后面介绍的其他优化算法使用。 它初始化了一个线性回归模型,然后可以使用小批量随机梯度下降以及后续小节介绍的其他算法来训练模型。
public static float evaluateLoss(Iterable<Batch> dataIterator, NDArray w, NDArray b) {
Accumulator metric = new Accumulator(2); // sumLoss, numExamples
for (Batch batch : dataIterator) {
NDArray X = batch.getData().head();
NDArray y = batch.getLabels().head();
NDArray yHat = Training.linreg(X, w, b);
float lossSum = Training.squaredLoss(yHat, y).sum().getFloat();
metric.add(new float[]{lossSum, (float) y.size()});
batch.close();
}
return metric.get(0) / metric.get(1);
}
public static class LossTime {
public float[] loss;
public float[] time;
public LossTime(float[] loss, float[] time) {
this.loss = loss;
this.time = time;
}
}
public void plotLossEpoch(float[] loss, float[] epoch) {
Table data = Table.create("data")
.addColumns(
DoubleColumn.create("epoch", Functions.floatToDoubleArray(epoch)),
DoubleColumn.create("loss", Functions.floatToDoubleArray(loss))
);
display(LinePlot.create("loss vs. epoch", data, "epoch", "loss"));
}
public float[] arrayListToFloat (ArrayList<Double> arrayList) {
float[] ret = new float[arrayList.size()];
for (int i = 0; i < arrayList.size(); i++) {
ret[i] = arrayList.get(i).floatValue();
}
return ret;
}
@FunctionalInterface
public static interface TrainerConsumer {
void train(NDList params, NDList states, Map<String, Float> hyperparams);
}
public static LossTime trainCh11(TrainerConsumer trainer, NDList states, Map<String, Float> hyperparams,
AirfoilRandomAccess dataset,
int featureDim, int numEpochs) throws IOException, TranslateException {
NDManager manager = NDManager.newBaseManager();
NDArray w = manager.randomNormal(0, 0.01f, new Shape(featureDim, 1), DataType.FLOAT32);
NDArray b = manager.zeros(new Shape(1));
w.setRequiresGradient(true);
b.setRequiresGradient(true);
NDList params = new NDList(w, b);
int n = 0;
StopWatch stopWatch = new StopWatch();
stopWatch.start();
float lastLoss = -1;
ArrayList<Double> loss = new ArrayList<>();
ArrayList<Double> epoch = new ArrayList<>();
for (int i = 0; i < numEpochs; i++) {
for (Batch batch : dataset.getData(manager)) {
int len = (int) dataset.size() / batch.getSize(); // number of batches
NDArray X = batch.getData().head();
NDArray y = batch.getLabels().head();
NDArray l;
try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
NDArray yHat = Training.linreg(X, params.get(0), params.get(1));
l = Training.squaredLoss(yHat, y).mean();
gc.backward(l);
}
trainer.train(params, states, hyperparams);
n += X.getShape().get(0);
if (n % 200 == 0) {
stopWatch.stop();
lastLoss = evaluateLoss(dataset.getData(manager), params.get(0), params.get(1));
loss.add((double) lastLoss);
double lastEpoch = 1.0 * n / X.getShape().get(0) / len;
epoch.add(lastEpoch);
stopWatch.start();
}
batch.close();
}
}
float[] lossArray = arrayListToFloat(loss);
float[] epochArray = arrayListToFloat(epoch);
plotLossEpoch(lossArray, epochArray);
System.out.printf("loss: %.3f, %.3f sec/epoch\n", lastLoss, stopWatch.avg());
float[] timeArray = arrayListToFloat(stopWatch.cumsum());
return new LossTime(lossArray, timeArray);
}
让我们来看看批量梯度下降的优化是如何进行的。 这可以通过将小批量设置为1500(即样本总数)来实现。 因此,模型参数每个迭代轮数只迭代一次。
public static LossTime trainSgd(float lr, int batchSize, int numEpochs) throws IOException, TranslateException {
AirfoilRandomAccess dataset = getDataCh11(batchSize, 1500);
int featureDim = dataset.getColumnNames().size();
Map<String, Float> hyperparams = new HashMap<>();
hyperparams.put("lr", lr);
return trainCh11(Optimization::sgd, new NDList(), hyperparams, dataset, featureDim, numEpochs);
}
LossTime gdRes = trainSgd(1f, 1500, 10);
loss: 0.251, 0.692 sec/epoch
当批量大小为1时,优化使用的是随机梯度下降。 为了简化实现,我们选择了很小的学习率。 在随机梯度下降的实验中,每当一个样本被处理,模型参数都会更新。 在这个例子中,这相当于每个迭代轮数有1500次更新。 可以看到,目标函数值的下降在1个迭代轮数后就变得较为平缓。 尽管两个例子在一个迭代轮数内都处理了1500个样本,但实验中随机梯度下降的一个迭代轮数耗时更多。 这是因为随机梯度下降更频繁地更新了参数,而且一次处理单个观测值效率较低。
LossTime sgdRes = trainSgd(0.005f, 1, 2);
loss: 0.248, 0.270 sec/epoch
最后,当批量大小等于100时,我们使用小批量随机梯度下降进行优化。 每个迭代轮数所需的时间比随机梯度下降和批量梯度下降所需的时间短。
LossTime mini1Res = trainSgd(0.4f, 100, 2);
loss: 0.247, 0.044 sec/epoch
将批量大小减少到10,每个迭代轮数的时间都会增加,因为每批的工作负载有了更低的执行效率。
LossTime mini2Res = trainSgd(0.05f, 10, 2);
loss: 0.243, 0.064 sec/epoch
现在我们可以比较前四个实验的时间与损失。 可以看出,尽管在处理的样本数方面,随机梯度下降的收敛速度快于梯度下降,但与梯度下降相比,它需要更多的时间来达到同样的损失,因为逐个样本来计算梯度并不那么有效。 小批量随机梯度下降能够平衡收敛速度和计算效率。 大小为10的小批量比随机梯度下降更有效; 大小为100的小批量在运行时间上甚至优于梯度下降。
public String[] getTypeArray(LossTime lossTime, String name) {
String[] type = new String[lossTime.time.length];
for (int i = 0; i < type.length; i++) {
type[i] = name;
}
return type;
}
// Converts a float array to a log scale
float[] convertLogScale(float[] array) {
float[] newArray = new float[array.length];
for (int i = 0; i < array.length; i++) {
newArray[i] = (float) Math.log10(array[i]);
}
return newArray;
}
float[] time = ArrayUtils.addAll(ArrayUtils.addAll(gdRes.time, sgdRes.time),
ArrayUtils.addAll(mini1Res.time, mini2Res.time));
float[] loss = ArrayUtils.addAll(ArrayUtils.addAll(gdRes.loss, sgdRes.loss),
ArrayUtils.addAll(mini1Res.loss, mini2Res.loss));
String[] type = ArrayUtils.addAll(ArrayUtils.addAll(getTypeArray(gdRes, "gd"),
getTypeArray(sgdRes, "sgd")),
ArrayUtils.addAll(getTypeArray(mini1Res, "batch size = 100"),
getTypeArray(mini1Res, "batch size = 10")));
Table data = Table.create("data")
.addColumns(
DoubleColumn.create("log time (sec)", Functions.floatToDoubleArray(convertLogScale(time))),
DoubleColumn.create("loss", Functions.floatToDoubleArray(loss)),
StringColumn.create("type", type)
);
LinePlot.create("loss vs. time", data, "log time (sec)", "loss", "type");
11.5.5. 简洁实现¶
下面用DJL 自带算法实现一个通用的训练函数,我们将在本章中其它小结使用它。
public void trainConciseCh11(Optimizer sgd, AirfoilRandomAccess dataset,
int numEpochs) throws IOException, TranslateException {
// Initialization
NDManager manager = NDManager.newBaseManager();
SequentialBlock net = new SequentialBlock();
Linear linear = Linear.builder().setUnits(1).build();
net.add(linear);
net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
Model model = Model.newInstance("concise implementation");
model.setBlock(net);
Loss loss = Loss.l2Loss();
DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
.optOptimizer(sgd)
.addEvaluator(new Accuracy()) // Model Accuracy
.addTrainingListeners(TrainingListener.Defaults.logging()); // Logging
Trainer trainer = model.newTrainer(config);
int n = 0;
StopWatch stopWatch = new StopWatch();
stopWatch.start();
trainer.initialize(new Shape(10, 5));
Metrics metrics = new Metrics();
trainer.setMetrics(metrics);
float lastLoss = -1;
ArrayList<Double> lossArray = new ArrayList<>();
ArrayList<Double> epochArray = new ArrayList<>();
for (Batch batch : trainer.iterateDataset(dataset)) {
int len = (int) dataset.size() / batch.getSize(); // number of batches
NDArray X = batch.getData().head();
EasyTrain.trainBatch(trainer, batch);
trainer.step();
n += X.getShape().get(0);
if (n % 200 == 0) {
stopWatch.stop();
stopWatch.stop();
lastLoss = evaluateLoss(dataset.getData(manager), linear.getParameters().get(0).getValue().getArray()
.reshape(new Shape(dataset.getColumnNames().size(), 1)),
linear.getParameters().get(1).getValue().getArray());
lossArray.add((double) lastLoss);
double lastEpoch = 1.0 * n / X.getShape().get(0) / len;
epochArray.add(lastEpoch);
stopWatch.start();
}
batch.close();
}
plotLossEpoch(arrayListToFloat(lossArray), arrayListToFloat(epochArray));
System.out.printf("loss: %.3f, %.3f sec/epoch\n", lastLoss, stopWatch.avg());
}
下面使用这个训练函数,复现之前的实验。
AirfoilRandomAccess airfoilDataset = getDataCh11(10, 1500);
Tracker lrt = Tracker.fixed(0.05f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
trainConciseCh11(sgd, airfoilDataset, 2);
INFO Training on: 4 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.078 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 1.00, L2Loss: 0.29
loss: 0.244, 1.730 sec/epoch
11.5.6. 小结¶
由于减少了深度学习框架的额外开销,使用更好的内存方位以及CPU和GPU上的缓存,向量化使代码更加高效。
随机梯度下降的“统计效率”与大批量一次处理数据的“计算效率”之间存在权衡。小批量随机梯度下降提供了两全其美的答案:计算和统计效率。
在小批量随机梯度下降中,我们处理通过训练数据的随机排列获得的批量数据(即每个观测值只处理一次,但按随机顺序)。
在训练期间降低学习率有助于训练。
一般来说,小批量随机梯度下降比随机梯度下降和梯度下降的速度快,收敛风险较小。
11.5.7. 练习¶
修改批量大小和学习率,并观察目标函数值的下降率以及每个迭代轮数消耗的时间。
将小批量随机梯度下降与实际从训练集中取样替换的变体进行比较。会看出什么?
一个邪恶的精灵在没通知你的情况下复制了你的数据集(即每个观测发生两次,你的数据集增加到原始大小的两倍,但没有人告诉你)。随机梯度下降、小批量随机梯度下降和梯度下降的表现将如何变化?