Run this notebook online in Binder or Colab.

4.4. 模型选择、欠拟合和过拟合¶

4.4.1. 训练误差和泛化误差¶

4.4.1.2. 模型复杂性¶

1. 可调整参数的数量。当可调整参数的数量（有时称为自由度）很大时，模型往往更容易过拟合。

2. 参数采用的值。当权重的取值范围较大时，模型可能更容易过拟合。

3. 训练样本的数量。即使你的模型很简单，也很容易过拟合只包含一两个样本的数据集。而过拟合一个有数百万个样本的数据集则需要一个极其灵活的模型。

4.4.3. 欠拟合还是过拟合？¶

4.4.3.1. 模型复杂性¶

(4.4.1)$\hat{y}= \sum_{i=0}^d x^i w_i$

（图 fig_capacity_vs_error：模型复杂度对欠拟合和过拟合的影响）

4.4.4. 多项式回归¶

%load ../utils/djl-imports

import ai.djl.metric.*;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.special.Gamma;


4.4.4.1. 生成数据集¶

(4.4.2)$y = 5 + 1.2x - 3.4\frac{x^2}{2!} + 5.6 \frac{x^3}{3!} + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.1^2).$

// utility functions for shuffle data
// Exchange the elements at positions i and j of a 1-D NDArray, in place.
public void swap(NDArray arr, int i, int j) {
    float valueAtI = arr.getFloat(i);
    float valueAtJ = arr.getFloat(j);
    arr.set(new NDIndex(i), valueAtJ);
    arr.set(new NDIndex(j), valueAtI);
}

// In-place Fisher-Yates shuffle of a 1-D NDArray; returns the same array.
public NDArray shuffle(NDArray arr) {
    int size = (int) arr.size();

    Random rnd = RandomUtils.RANDOM;

    for (int i = size - 1; i > 0; --i) {
        // FIX: j must be drawn from [0, i] INCLUSIVE. The original rnd.nextInt(i)
        // excluded i, which is Sattolo's algorithm (only cyclic permutations) and
        // yields a biased shuffle rather than a uniform one.
        swap(arr, i, rnd.nextInt(i + 1));
    }
    return arr;
}

int maxDegree = 20; // Maximum degree of the polynomial
// Training and test dataset sizes
int nTrain = 100;
int nTest = 100;

NDManager manager = NDManager.newBaseManager();
// True weight vector: only the leading four coefficients are non-zero,
// the remaining maxDegree - 4 entries stay at zero.
NDArray trueW = manager.zeros(new Shape(maxDegree));
float[] leadingCoeffs = {5f, 1.2f, -3.4f, 5.6f};
for (int idx = 0; idx < leadingCoeffs.length; idx++) {
    trueW.set(new NDIndex(idx), leadingCoeffs[idx]);
}

// Draw random inputs and shuffle them so train/test splits are random.
NDArray features = manager.randomNormal(new Shape(nTrain + nTest, 1));
features = shuffle(features);

// Column i holds x^i: broadcast the exponent row vector over the feature column.
NDArray polyFeatures = features.pow(manager.arange(maxDegree).reshape(1, -1));

// Rescale column i by 1/i! (Gamma(i + 1) == i!) to keep high-order terms bounded.
for (int degree = 0; degree < maxDegree; degree++) {
    NDArray scaled = polyFeatures.get(":, " + degree).div(Gamma.gamma(degree + 1));
    polyFeatures.set(new NDIndex(":, " + degree), scaled);
}

// Shape of labels: (n_train + n_test,)
NDArray labels = polyFeatures.dot(trueW);
// Add Gaussian observation noise epsilon ~ N(0, 0.1^2).
labels = labels.add(manager.randomNormal(0, 0.1f, labels.getShape(), DataType.FLOAT32));


System.out.println("features: " + features.get(":2"));
System.out.println("polyFeatures: " + polyFeatures.get(":2"));
System.out.println("labels: " + labels.get(":2"));

features: ND: (2, 1) gpu(0) float32
[[-0.8061],
[ 1.2403],
]

polyFeatures: ND: (2, 20) gpu(0) float32
[[ 1.00000000e+00, -8.06136370e-01,  3.24927926e-01, -8.73120725e-02,  1.75963584e-02, -2.83701299e-03,  3.81169870e-04, -4.38964162e-05,  4.42331202e-06, -3.96199198e-07,  3.19390594e-08, -2.34065789e-09,  1.57240790e-10, -9.75057795e-12,  5.61449758e-13, -3.01736683e-14,  1.52025579e-15, -7.20902065e-17,  3.22858540e-18, -1.36983152e-19],
[ 1.00000000e+00,  1.24028385e+00,  7.69151986e-01,  3.17988932e-01,  9.85991359e-02,  2.44581830e-02,  5.05584804e-03,  8.95812409e-04,  1.38882708e-04,  1.91393319e-05,  2.37382051e-06,  2.67655565e-07,  2.76640701e-08,  2.63933075e-09,  2.33822822e-10,  1.93337759e-11,  1.49871066e-12,  1.09342745e-13,  7.53422423e-15,  4.91819834e-16],
]

labels: ND: (2) gpu(0) float32
[2.5422, 5.4919]


4.4.4.2. 对模型进行训练和测试¶

// Record the loss curves every `logInterval` epochs.
int logInterval = 20;
// Number of training epochs; overridable via the MAX_EPOCH system property.
int numEpochs = Integer.getInteger("MAX_EPOCH", 400);

/**
 * Wraps feature and label NDArrays in an ArrayDataset with the given batching.
 *
 * @param features the input features
 * @param labels the target labels
 * @param batchSize number of examples per batch
 * @param shuffle whether batches are sampled in random order
 * @return the configured ArrayDataset
 */
public ArrayDataset loadArray(NDArray features, NDArray labels, int batchSize, boolean shuffle) {
    ArrayDataset.Builder builder = new ArrayDataset.Builder();
    builder.setData(features);
    builder.optLabels(labels);
    builder.setSampling(batchSize, shuffle);
    return builder.build();
}

// Loss curves recorded by train() and consumed by the plotting cells below.
double[] trainLoss;
double[] testLoss;
double[] epochCount;



// NOTE(review): declared to hold the learned weight vector, but never assigned
// in the visible code — confirm whether train() is meant to populate it.
NDArray weight = null;

public void train(NDArray trainFeatures, NDArray testFeatures, NDArray trainLabels, NDArray testLabels, int nDegree)
throws IOException, TranslateException {

Loss l2Loss = Loss.l2Loss();
NDManager manager = NDManager.newBaseManager();
Tracker lrt = Tracker.fixed(0.01f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
DefaultTrainingConfig config = new DefaultTrainingConfig(l2Loss)
.optDevices(manager.getEngine().getDevices(1)) // single GPU
.optOptimizer(sgd) // Optimizer (loss function)

Model model = Model.newInstance("mlp");
SequentialBlock net = new SequentialBlock();
// Switch off the bias since we already catered for it in the polynomial
// features
Linear linearBlock = Linear.builder().optBias(false).setUnits(1).build();

model.setBlock(net);
Trainer trainer = model.newTrainer(config);

int batchSize = Math.min(10, (int) trainLabels.getShape().get(0));

ArrayDataset trainIter = loadArray(trainFeatures, trainLabels, batchSize, true);
ArrayDataset testIter = loadArray(testFeatures, testLabels, batchSize, true);

trainer.initialize(new Shape(1, nDegree));
System.out.println("Start Training...");
for (int epoch = 1; epoch <= numEpochs; epoch++) {

// Iterate over dataset
for (Batch batch : trainer.iterateDataset(trainIter)) {
// Update loss and evaulator
EasyTrain.trainBatch(trainer, batch);

// Update parameters
trainer.step();

batch.close();
}
// reset training and validation evaluators at end of epoch

for (Batch batch : trainer.iterateDataset(testIter)) {
// Update loss and evaulator
EasyTrain.validateBatch(trainer, batch);

batch.close();
}

trainer.notifyListeners(listener -> listener.onEpoch(trainer));
if (epoch % logInterval == 0) {
}
}
System.out.println("Training complete...");
model.close();
}


4.4.4.3. 三阶多项式函数拟合(正态)¶

// Select the first 4 columns of the polynomial features, i.e. 1, x, x^2/2!, x^3/3!
int nDegree = 4;
train(polyFeatures.get("0:" + nTrain + ", 0:" + nDegree),
polyFeatures.get(nTrain + ": , 0:" + nDegree),
labels.get(":" + nTrain),
labels.get(nTrain + ":"), nDegree);

Start Training...
Training complete...

// Build a long-format table (epoch, loss, series-label) and plot both curves.
String[] lossLabel = new String[trainLoss.length + testLoss.length];

Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
Arrays.fill(lossLabel, trainLoss.length, trainLoss.length + testLoss.length, "test loss");

// FIX: the Table construction was truncated (only the trailing StringColumn
// fragment remained), leaving `data` undefined. Reconstructed: the epoch axis
// is repeated once per series, and the two loss arrays are concatenated.
Table data = Table.create("Data").addColumns(
    DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, epochCount)),
    DoubleColumn.create("loss", ArrayUtils.addAll(trainLoss, testLoss)),
    StringColumn.create("lossLabel", lossLabel)
);
Figure figure = LinePlot.create("Normal", data, "epochCount", "loss", "lossLabel");
// set Y axis to log scale
Axis yAxis = Axis.builder()
    .type(Axis.Type.LOG)
    .build();
Layout layout = Layout.builder("Normal")
    .yAxis(yAxis)
    .build();
figure.setLayout(layout);
render(figure,"text/html");


4.4.4.4. 线性函数拟合(欠拟合)¶

// Select only the first 2 columns of the polynomial features, i.e. 1, x
int nDegree = 2;
train(polyFeatures.get("0:" + nTrain + ", 0:" + nDegree),
polyFeatures.get(nTrain + ": , 0:" + nDegree),
labels.get(":" + nTrain),
labels.get(nTrain + ":"), nDegree);

Start Training...
Training complete...

// Build a long-format table (epoch, loss, series-label) and plot both curves.
String[] lossLabel = new String[trainLoss.length + testLoss.length];

Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
Arrays.fill(lossLabel, trainLoss.length, trainLoss.length + testLoss.length, "test loss");

// FIX: the Table construction was truncated (only the trailing StringColumn
// fragment remained), leaving `data` undefined. Reconstructed: the epoch axis
// is repeated once per series, and the two loss arrays are concatenated.
Table data = Table.create("Data").addColumns(
    DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, epochCount)),
    DoubleColumn.create("loss", ArrayUtils.addAll(trainLoss, testLoss)),
    StringColumn.create("lossLabel", lossLabel)
);
Figure figure = LinePlot.create("Underfitting", data, "epochCount", "loss", "lossLabel");
// set Y axis to log scale
Axis yAxis = Axis.builder()
    .type(Axis.Type.LOG)
    .build();
Layout layout = Layout.builder("Underfitting")
    .yAxis(yAxis)
    .build();
figure.setLayout(layout);
render(figure,"text/html");


4.4.4.5. 高阶多项式函数拟合(过拟合)¶

// Use all columns of the polynomial features (degree 20).
// Train longer so the overfitting behaviour becomes visible.
numEpochs = 1500;

train(polyFeatures.get("0:" + nTrain + ", 0:" + maxDegree),
polyFeatures.get(nTrain + ": , 0:" + maxDegree),
labels.get(":" + nTrain),
labels.get(nTrain + ":"), maxDegree);

Start Training...
Training complete...

// Build a long-format table (epoch, loss, series-label) and plot both curves.
String[] lossLabel = new String[trainLoss.length + testLoss.length];

Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
Arrays.fill(lossLabel, trainLoss.length, trainLoss.length + testLoss.length, "test loss");

// FIX: the Table construction was truncated (only the trailing StringColumn
// fragment remained), leaving `data` undefined. Reconstructed: the epoch axis
// is repeated once per series, and the two loss arrays are concatenated.
Table data = Table.create("Data").addColumns(
    DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, epochCount)),
    DoubleColumn.create("loss", ArrayUtils.addAll(trainLoss, testLoss)),
    StringColumn.create("lossLabel", lossLabel)
);

Figure figure = LinePlot.create("Overfitting", data, "epochCount", "loss", "lossLabel");
// set Y axis to log scale
Axis yAxis = Axis.builder()
    .type(Axis.Type.LOG)
    .build();
Layout layout = Layout.builder("Overfitting")
    .yAxis(yAxis)
    .build();
figure.setLayout(layout);
render(figure,"text/html");


4.4.5. 小结¶

• 由于不能基于训练误差来估计泛化误差，因此简单地最小化训练误差并不一定意味着泛化误差的减小。机器学习模型需要注意防止过拟合，来使得泛化误差最小。

• 验证集可以用于模型选择，但不能过于随意地使用它。

• 欠拟合是指模型无法继续减少训练误差。过拟合是指训练误差远小于验证误差。

• 我们应该选择一个复杂度适当的模型，避免使用数量不足的训练样本。

4.4.6. 练习¶

1. 你能准确地解出这个多项式回归问题吗？提示：使用线性代数。

2. 考虑多项式的模型选择：

1. 绘制训练损失与模型复杂度（多项式的阶数）的关系图。你观察到了什么？需要多少阶的多项式才能将训练损失减少到0?

2. 在这种情况下绘制测试的损失图。

3. 生成同样的图，作为数据量的函数。

3. 如果你不对多项式特征$x^i$进行标准化($1/i!$)，会发生什么事情？你能用其他方法解决这个问题吗？

4. 你能期待看到泛化误差为零吗？