Run this notebook online: or Colab:

# 11.5. 小批量随机梯度下降¶

## 11.5.1. 向量化和缓存¶

1. 我们可以计算$$\mathbf{A}_{ij} = \mathbf{B}_{i,:} \mathbf{C}_{:,j}^\top$$，也就是说，我们可以通过点积进行逐元素计算。

2. 我们可以计算$$\mathbf{A}_{:,j} = \mathbf{B} \mathbf{C}_{:,j}^\top$$，也就是说，我们可以一次计算一列。同样，我们可以一次计算$$\mathbf{A}$$的一行$$\mathbf{A}_{i,:}$$。

3. 我们可以简单地计算$$\mathbf{A} = \mathbf{B} \mathbf{C}$$

4. 我们可以将$$\mathbf{B}$$和$$\mathbf{C}$$分成较小的区块矩阵，然后一次计算$$\mathbf{A}$$的一个区块。

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/StopWatch.java
%load ../utils/Training.java
%load ../utils/Accumulator.java

import java.util.List;

import org.apache.commons.lang3.ArrayUtils;

import ai.djl.basicdataset.cv.classification.*;
import ai.djl.basicdataset.tabular.*;

NDManager manager = NDManager.newBaseManager();
StopWatch stopWatch = new StopWatch();
NDArray A = manager.zeros(new Shape(256, 256));
NDArray B = manager.randomNormal(new Shape(256, 256));
NDArray C = manager.randomNormal(new Shape(256, 256));


// 逐元素计算A=BC
stopWatch.start();
for (int i = 0; i < 256; i++) {
for (int j = 0; j < 256; j++) {
A.set(new NDIndex(i, j),
B.get(new NDIndex(String.format("%d, :", i)))
.dot(C.get(new NDIndex(String.format(":, %d", j)))));
}
}
stopWatch.stop();

41.33436005


// 逐列计算A=BC
stopWatch.start();
for (int j = 0; j < 256; j++) {
A.set(new NDIndex(String.format(":, %d", j)), B.dot(C.get(new NDIndex(String.format(":, %d", j)))));
}
stopWatch.stop();

0.174655732


// 一次性计算A=BC
stopWatch.start();
A = B.dot(C);
stopWatch.stop();

// Multiply and add count as separate operations (fused in practice)
float[] gigaflops = new float[stopWatch.getTimes().size()];
for (int i = 0; i < stopWatch.getTimes().size(); i++) {
gigaflops[i] = (float)(2 / stopWatch.getTimes().get(i));
}
String.format("Performance in Gigaflops: element %.3f, column %.3f, full %.3f", gigaflops[0], gigaflops[1], gigaflops[2]);

Performance in Gigaflops: element 0.048, column 11.451, full 48.620


## 11.5.2. 小批量¶

(11.5.1)$\mathbf{g}_t = \partial_{\mathbf{w}} f(\mathbf{x}_{t}, \mathbf{w})$

(11.5.2)$\mathbf{g}_t = \partial_{\mathbf{w}} \frac{1}{|\mathcal{B}_t|} \sum_{i \in \mathcal{B}_t} f(\mathbf{x}_{i}, \mathbf{w})$

stopWatch.start();
for (int j = 0; j < 256; j+=64) {
A.set(new NDIndex(String.format(":, %d:%d", j, j + 64)),
B.dot(C.get(new NDIndex(String.format(":, %d:%d", j, j + 64)))));
}
stopWatch.stop();

String.format("Performance in Gigaflops: block %.3f\n", 2 / stopWatch.getTimes().get(3));

Performance in Gigaflops: block 41.325


## 11.5.3. 读取数据集¶

NDManager manager = NDManager.newBaseManager();

public AirfoilRandomAccess getDataCh11(int batchSize, int n) throws IOException, TranslateException {
// Load data
AirfoilRandomAccess airfoil = AirfoilRandomAccess.builder()
.optUsage(Dataset.Usage.TRAIN)
.setSampling(batchSize, true)
.optNormalize(true)
.optLimit(n)
.build();
return airfoil;
}


## 11.5.4. 从零开始实现¶

Section 3.2一节中已经实现过小批量随机梯度下降算法。 我们在这里将它的输入参数变得更加通用，主要是为了方便本章后面介绍的其他优化算法也可以使用同样的输入。 具体来说，我们添加了一个状态输入states并将超参数放在字典hyperparams中。 此外，我们将在训练函数里对各个小批量样本的损失求平均，因此优化算法中的梯度不需要除以批量大小。

public class Optimization {
public static void sgd(NDList params, NDList states, Map<String, Float> hyperparams) {
for (int i = 0; i < params.size(); i++) {
NDArray param = params.get(i);
// Update param
// param = param - param.gradient * lr
param.subi(param.getGradient().mul(hyperparams.get("lr")));
}
}
}


public static float evaluateLoss(Iterable<Batch> dataIterator, NDArray w, NDArray b) {
Accumulator metric = new Accumulator(2);  // sumLoss, numExamples

for (Batch batch : dataIterator) {
NDArray X = batch.getData().head();
NDArray y = batch.getLabels().head();
NDArray yHat = Training.linreg(X, w, b);
float lossSum = Training.squaredLoss(yHat, y).sum().getFloat();

metric.add(new float[]{lossSum, (float) y.size()});
batch.close();
}
return metric.get(0) / metric.get(1);
}

public static class LossTime {
public float[] loss;
public float[] time;

public LossTime(float[] loss, float[] time) {
this.loss = loss;
this.time = time;
}
}

public void plotLossEpoch(float[] loss, float[] epoch) {
Table data = Table.create("data")
.addColumns(
DoubleColumn.create("epoch", Functions.floatToDoubleArray(epoch)),
DoubleColumn.create("loss", Functions.floatToDoubleArray(loss))
);
display(LinePlot.create("loss vs. epoch", data, "epoch", "loss"));
}

public float[] arrayListToFloat (ArrayList<Double> arrayList) {
float[] ret = new float[arrayList.size()];

for (int i = 0; i < arrayList.size(); i++) {
ret[i] = arrayList.get(i).floatValue();
}
return ret;
}

@FunctionalInterface
public static interface TrainerConsumer {
void train(NDList params, NDList states, Map<String, Float> hyperparams);

}

public static LossTime trainCh11(TrainerConsumer trainer, NDList states, Map<String, Float> hyperparams,
AirfoilRandomAccess dataset,
int featureDim, int numEpochs) throws IOException, TranslateException {
NDManager manager = NDManager.newBaseManager();
NDArray w = manager.randomNormal(0, 0.01f, new Shape(featureDim, 1), DataType.FLOAT32);
NDArray b = manager.zeros(new Shape(1));

w.setRequiresGradient(true);
b.setRequiresGradient(true);

NDList params = new NDList(w, b);
int n = 0;
StopWatch stopWatch = new StopWatch();
stopWatch.start();

float lastLoss = -1;
ArrayList<Double> loss = new ArrayList<>();
ArrayList<Double> epoch = new ArrayList<>();

for (int i = 0; i < numEpochs; i++) {
for (Batch batch : dataset.getData(manager)) {
int len = (int) dataset.size() / batch.getSize();  // number of batches
NDArray X = batch.getData().head();
NDArray y = batch.getLabels().head();

NDArray l;
try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
NDArray yHat = Training.linreg(X, params.get(0), params.get(1));
l = Training.squaredLoss(yHat, y).mean();
gc.backward(l);
}

trainer.train(params, states, hyperparams);
n += X.getShape().get(0);

if (n % 200 == 0) {
stopWatch.stop();
lastLoss = evaluateLoss(dataset.getData(manager), params.get(0), params.get(1));
loss.add((double) lastLoss);
double lastEpoch = 1.0 * n / X.getShape().get(0) / len;
epoch.add(lastEpoch);
stopWatch.start();
}

batch.close();
}
}
float[] lossArray = arrayListToFloat(loss);
float[] epochArray = arrayListToFloat(epoch);
plotLossEpoch(lossArray, epochArray);
System.out.printf("loss: %.3f, %.3f sec/epoch\n", lastLoss, stopWatch.avg());
float[] timeArray = arrayListToFloat(stopWatch.cumsum());
return new LossTime(lossArray, timeArray);
}


public static LossTime trainSgd(float lr, int batchSize, int numEpochs) throws IOException, TranslateException {
AirfoilRandomAccess dataset = getDataCh11(batchSize, 1500);
int featureDim = dataset.getColumnNames().size();

Map<String, Float> hyperparams = new HashMap<>();
hyperparams.put("lr", lr);

return trainCh11(Optimization::sgd, new NDList(), hyperparams, dataset, featureDim, numEpochs);
}

LossTime gdRes = trainSgd(1f, 1500, 10);

loss: 0.251, 0.692 sec/epoch


LossTime sgdRes = trainSgd(0.005f, 1, 2);

loss: 0.248, 0.270 sec/epoch


LossTime mini1Res = trainSgd(0.4f, 100, 2);

loss: 0.247, 0.044 sec/epoch


LossTime mini2Res = trainSgd(0.05f, 10, 2);

loss: 0.243, 0.064 sec/epoch


public String[] getTypeArray(LossTime lossTime, String name) {
String[] type = new String[lossTime.time.length];
for (int i = 0; i < type.length; i++) {
type[i] = name;
}
return type;
}

// Converts a float array to a log scale
float[] convertLogScale(float[] array) {
float[] newArray = new float[array.length];
for (int i = 0; i < array.length; i++) {
newArray[i] = (float) Math.log10(array[i]);
}
return newArray;
}

float[] time = ArrayUtils.addAll(ArrayUtils.addAll(gdRes.time, sgdRes.time),
ArrayUtils.addAll(mini1Res.time, mini2Res.time));
float[] loss = ArrayUtils.addAll(ArrayUtils.addAll(gdRes.loss, sgdRes.loss),
ArrayUtils.addAll(mini1Res.loss, mini2Res.loss));
String[] type = ArrayUtils.addAll(ArrayUtils.addAll(getTypeArray(gdRes, "gd"),
getTypeArray(sgdRes, "sgd")),
ArrayUtils.addAll(getTypeArray(mini1Res, "batch size = 100"),
getTypeArray(mini1Res, "batch size = 10")));
Table data = Table.create("data")
.addColumns(
DoubleColumn.create("log time (sec)", Functions.floatToDoubleArray(convertLogScale(time))),
DoubleColumn.create("loss", Functions.floatToDoubleArray(loss)),
StringColumn.create("type", type)
);
LinePlot.create("loss vs. time", data, "log time (sec)", "loss", "type");


## 11.5.5. 简洁实现¶

public void trainConciseCh11(Optimizer sgd, AirfoilRandomAccess dataset,
int numEpochs) throws IOException, TranslateException {
// Initialization
NDManager manager = NDManager.newBaseManager();

SequentialBlock net = new SequentialBlock();
Linear linear = Linear.builder().setUnits(1).build();
net.add(linear);
net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);

Model model = Model.newInstance("concise implementation");
model.setBlock(net);

Loss loss = Loss.l2Loss();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
.optOptimizer(sgd)
.addEvaluator(new Accuracy()) // Model Accuracy
.addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);

int n = 0;
StopWatch stopWatch = new StopWatch();
stopWatch.start();

trainer.initialize(new Shape(10, 5));

Metrics metrics = new Metrics();
trainer.setMetrics(metrics);

float lastLoss = -1;

ArrayList<Double> lossArray = new ArrayList<>();
ArrayList<Double> epochArray = new ArrayList<>();

for (Batch batch : trainer.iterateDataset(dataset)) {
int len = (int) dataset.size() / batch.getSize();  // number of batches

NDArray X = batch.getData().head();
EasyTrain.trainBatch(trainer, batch);
trainer.step();

n += X.getShape().get(0);

if (n % 200 == 0) {
stopWatch.stop();
stopWatch.stop();
lastLoss = evaluateLoss(dataset.getData(manager), linear.getParameters().get(0).getValue().getArray()
.reshape(new Shape(dataset.getColumnNames().size(), 1)),
linear.getParameters().get(1).getValue().getArray());

lossArray.add((double) lastLoss);
double lastEpoch = 1.0 * n / X.getShape().get(0) / len;
epochArray.add(lastEpoch);
stopWatch.start();
}
batch.close();
}
plotLossEpoch(arrayListToFloat(lossArray), arrayListToFloat(epochArray));

System.out.printf("loss: %.3f, %.3f sec/epoch\n", lastLoss, stopWatch.avg());
}


AirfoilRandomAccess airfoilDataset = getDataCh11(10, 1500);

Tracker lrt = Tracker.fixed(0.05f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

trainConciseCh11(sgd, airfoilDataset, 2);

INFO Training on: 4 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.078 ms.

Training:    100% |████████████████████████████████████████| Accuracy: 1.00, L2Loss: 0.29

loss: 0.244, 1.730 sec/epoch


## 11.5.6. 小结¶

• 由于减少了深度学习框架的额外开销，使用更好的内存定位以及CPU和GPU上的缓存，向量化使代码更加高效。

• 随机梯度下降的“统计效率”与大批量一次处理数据的“计算效率”之间存在权衡。小批量随机梯度下降提供了两全其美的答案：计算和统计效率。

• 在小批量随机梯度下降中，我们处理通过训练数据的随机排列获得的批量数据（即每个观测值只处理一次，但按随机顺序）。

• 在训练期间降低学习率有助于训练。

• 一般来说，小批量随机梯度下降比随机梯度下降和梯度下降的速度快，收敛风险较小。

## 11.5.7. 练习¶

1. 修改批量大小和学习率，并观察目标函数值的下降率以及每个迭代轮数消耗的时间。

2. 将小批量随机梯度下降与实际上从训练集中有放回取样的变体进行比较。会看出什么？

3. 一个邪恶的精灵在没通知你的情况下复制了你的数据集（即每个观测发生两次，你的数据集增加到原始大小的两倍，但没有人告诉你）。随机梯度下降、小批量随机梯度下降和梯度下降的表现将如何变化？