Run this notebook online:\ |Binder| or Colab: |Colab|
.. |Binder| image:: https://mybinder.org/badge_logo.svg
:target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_multilayer-perceptrons/weight-decay.ipynb
.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_multilayer-perceptrons/weight-decay.ipynb
.. _sec_weight_decay:
权重衰减
========
我们已经描述了过拟合的问题,现在我们可以介绍一些正则化模型的技术。我们总是可以通过去收集更多的训练数据来缓解过拟合。但这可能成本很高而且耗时.或者完全超出我们的控制,在短期内不可能做到。假设已经拥有尽可能多的高质量数据,现在我们将重点放在正则化技术上。
回想一下,在多项式回归的例子(
:numref:`sec_model_selection`\ )中,我们可以通过调整拟合多项式的阶数来限制模型的容量。实际上,限制特征的数量是缓解过拟合的一种常用技术。然而,简单地丢弃特征对于这项工作来说可能过于生硬。我们继续思考多项式回归的例子,考虑高维输入可能发生的情况。多项式对多变量数据的自然扩展称为\ *单项式*\ (monomials),也可以说是变量幂的乘积。单项式的阶数是幂的和。例如,\ :math:`x_1^2 x_2`\ 和\ :math:`x_3 x_5^2`\ 都是3次单项式。
注意,随着阶数\ :math:`d`\ 的增长,带有阶数\ :math:`d`\ 的项数迅速增加。给定\ :math:`k`\ 个变量,阶数\ :math:`d`\ (即\ :math:`k`\ 多选\ :math:`d`\ )的个数为\ :math:`{k - 1 + d} \choose {k - 1}`\ 。即使是阶数上的微小变化,比如从\ :math:`2`\ 到\ :math:`3`\ ,也会显著增加我们模型的复杂性。因此,我们经常需要一个更细粒度的工具来调整函数的复杂性。
范数与权重衰减
--------------
在之前的章节,我们已经描述了\ :math:`L_2`\ 范数和\ :math:`L_1`\ 范数,它们是\ :math:`L_p`\ 范数的特殊情况。
([STRIKEOUT:权重衰减是最广泛使用的正则化的技术之一])
在训练参数化机器学习模型时,\ *权重衰减*\ (通常称为\ :math:`L_2`\ 正则化)是最广泛使用的正则化的技术之一。这项技术是基于一个基本直觉,即在所有函数\ :math:`f`\ 中,函数\ :math:`f = 0`\ (所有输入都得到值\ :math:`0`\ )在某种意义上是最简单的,我们可以通过函数与零的距离来衡量函数的复杂度。但是我们应该如何精确地测量一个函数和零之间的距离呢?没有一个正确的答案。事实上,整个数学分支,包括函数分析和巴拿赫空间理论,都在致力于回答这个问题。
一种简单的方法是通过线性函数\ :math:`f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}`\ 中的权重向量的某个范数来度量其复杂性,例如\ :math:`\| \mathbf{w} \|^2`\ 。要保证权重向量比较小,最常用方法是将其范数作为惩罚项加到最小化损失的问题中。将原来的训练目标\ *最小化训练标签上的预测损失*\ ,调整为\ *最小化预测损失和惩罚项之和*\ 。
现在,如果我们的权重向量增长的太大,我们的学习算法可能会更集中于最小化权重范数\ :math:`\| \mathbf{w} \|^2`\ 。这正是我们想要的。让我们回顾一下
:numref:`sec_linear_regression`
中的线性回归例子。我们的损失由下式给出:
.. math:: L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2.
回想一下,\ :math:`\mathbf{x}^{(i)}`\ 是样本\ :math:`i`\ 的特征,\ :math:`y^{(i)}`\ 是样本\ :math:`i`\ 的标签。\ :math:`(\mathbf{w}, b)`\ 是权重和偏置参数。为了惩罚权重向量的大小,我们必须以某种方式在损失函数中添加\ :math:`\| \mathbf{w} \|^2`\ ,但是模型应该如何平衡这个新的额外惩罚的损失?实际上,我们通过\ *正则化常数*\ :math:`\lambda`\ 来描述这种权衡,这是一个非负超参数,我们使用验证数据拟合:
.. math:: L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2,
对于\ :math:`\lambda = 0`\ ,我们恢复了原来的损失函数。对于\ :math:`\lambda > 0`\ ,我们限制\ :math:`\| \mathbf{w} \|`\ 的大小。我们仍然除以\ :math:`2`\ :当我们取一个二次函数的导数时,\ :math:`2`\ 和\ :math:`1/2`\ 会抵消,以确保更新表达式看起来既漂亮又简单。聪明的读者可能会想知道为什么我们使用平方范数而不是标准范数(即欧几里得距离)。我们这样做是为了便于计算。通过平方\ :math:`L_2`\ 范数,我们去掉平方根,留下权重向量每个分量的平方和。这使得惩罚的导数很容易计算:导数的和等于和的导数。
此外,你可能会问为什么我们首先使用\ :math:`L_2`\ 范数,而不是\ :math:`L_1`\ 范数。事实上,这些选择在整个统计领域中都是有效的和受欢迎的。\ :math:`L_2`\ 正则化线性模型构成经典的\ *岭回归*\ (ridge
regression)算法,\ :math:`L_1`\ 正则化线性回归是统计学中类似的基本模型,通常被称为\ *套索回归*\ (lasso
regression)。
使用\ :math:`L_2`\ 范数的一个原因是它对权重向量的大分量施加了巨大的惩罚。这使得我们的学习算法偏向于在大量特征上均匀分布权重的模型。在实践中,这可能使它们对单个变量中的观测误差更为鲁棒。相比之下,\ :math:`L_1`\ 惩罚会导致模型将其他权重清除为零而将权重集中在一小部分特征上。这称为\ *特征选择*\ (feature
selection),这可能是其他场景下需要的。
使用与 :eq:`eq_linreg_batch_update`
中的相同符号,\ :math:`L_2`\ 正则化回归的小批量随机梯度下降更新如下式:
.. math::
\begin{aligned}
\mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right).
\end{aligned}
根据之前章节所讲的,我们根据估计值与观测值之间的差异来更新\ :math:`\mathbf{w}`\ 。然而,我们同时也在试图将\ :math:`\mathbf{w}`\ 的大小缩小到零。这就是为什么这种方法有时被称为\ *权重衰减*\ 。我们仅考虑惩罚项,优化算法在训练的每一步\ *衰减*\ 权重。与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。较小的\ :math:`\lambda`\ 值对应较少约束的\ :math:`\mathbf{w}`\ ,而较大的\ :math:`\lambda`\ 值对\ :math:`\mathbf{w}`\ 的约束更大。
是否对相应的偏置\ :math:`b^2`\ 进行惩罚在不同的实现中会有所不同。在神经网络的不同层中也会有所不同。通常,我们不正则化网络输出层的偏置项。
高维线性回归
------------
我们通过一个简单的例子来说明演示权重衰减。
.. code:: java
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/DataPoints.java
%load ../utils/Training.java
%load ../utils/Accumulator.java
.. code:: java
import org.apache.commons.lang3.ArrayUtils;
.. code:: java
int nTrain = 20;
int nTest = 100;
int numInputs = 200;
int batchSize = 5;
float trueB = 0.05f;
NDManager manager = NDManager.newBaseManager();
NDArray trueW = manager.ones(new Shape(numInputs, 1));
trueW = trueW.mul(0.01);
public ArrayDataset loadArray(NDArray features, NDArray labels, int batchSize, boolean shuffle) {
return new ArrayDataset.Builder()
.setData(features) // set the features
.optLabels(labels) // set the labels
.setSampling(batchSize, shuffle) // set the batch size and random sampling
.build();
}
DataPoints trainData = DataPoints.syntheticData(manager, trueW, trueB, nTrain);
ArrayDataset trainIter = loadArray(trainData.getX(), trainData.getY(), batchSize, true);
DataPoints testData = DataPoints.syntheticData(manager, trueW, trueB, nTest);
ArrayDataset testIter = loadArray(testData.getX(), testData.getY(), batchSize, false);
首先,我们像以前一样生成一些数据,生成公式如下:
.. math::
y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where }
\epsilon \sim \mathcal{N}(0, 0.01^2).
我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。为了使过拟合的效果更加明显,我们可以将问题的维数增加到\ :math:`d = 200`\ ,并使用一个只包含20个样本的小训练集。
从零开始实现
------------
在下面,我们将从头开始实现权重衰减,只需将\ :math:`L_2`\ 的平方惩罚添加到原始目标函数中。
初始化模型参数
~~~~~~~~~~~~~~
首先,我们将定义一个函数来随机初始化我们的模型参数。
.. code:: java
public class InitParams{
private NDArray w;
private NDArray b;
private NDList l;
public NDArray getW(){
return this.w;
}
public NDArray getB(){
return this.b;
}
public InitParams(){
NDManager manager = NDManager.newBaseManager();
w = manager.randomNormal(0, 1.0f, new Shape(numInputs, 1), DataType.FLOAT32);
b = manager.zeros(new Shape(1));
w.setRequiresGradient(true);
b.setRequiresGradient(true);
}
}
定义\ :math:`L_2`\ 范数惩罚
~~~~~~~~~~~~~~~~~~~~~~~~~~~
实现这一惩罚最方便的方法是对所有项求平方后并将它们求和。
.. code:: java
public NDArray l2Penalty(NDArray w){
return ((w.pow(2)).sum()).div(2);
}
.. code:: java
Loss l2loss = Loss.l2Loss();
定义训练代码实现
~~~~~~~~~~~~~~~~
下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。从
:numref:`chap_linear`
以来,线性网络和平方损失没有变化,所以我们通过\ ``Training.linreg()``\ 和\ ``Training.squaredLoss()``\ 导入它们。唯一的变化是损失现在包括了惩罚项。
.. code:: java
double[] trainLoss;
double[] testLoss;
double[] epochCount;
public void train(float lambd) throws IOException, TranslateException {
InitParams initParams = new InitParams();
NDList params = new NDList(initParams.getW(), initParams.getB());
int numEpochs = Integer.getInteger("MAX_EPOCH", 100);
float lr = 0.003f;
trainLoss = new double[(numEpochs/5)];
testLoss = new double[(numEpochs/5)];
epochCount = new double[(numEpochs/5)];
for(int epoch = 1; epoch <= numEpochs; epoch++){
for(Batch batch : trainIter.getData(manager)){
NDArray X = batch.getData().head();
NDArray y = batch.getLabels().head();
NDArray w = params.get(0);
NDArray b = params.get(1);
try (GradientCollector gc = Engine.getInstance().newGradientCollector()) {
// The L2 norm penalty term has been added, and broadcasting
// makes `l2Penalty(w)` a vector whose length is `batch_size`
NDArray l = Training.squaredLoss(Training.linreg(X, w, b), y).add(l2Penalty(w).mul(lambd));
gc.backward(l); // Compute gradient on l with respect to w and b
}
batch.close();
Training.sgd(params, lr, batchSize); // Update parameters using their gradient
}
if(epoch % 5 == 0){
NDArray testL = Training.squaredLoss(Training.linreg(testData.getX(), params.get(0), params.get(1)), testData.getY());
NDArray trainL = Training.squaredLoss(Training.linreg(trainData.getX(), params.get(0), params.get(1)), trainData.getY());
epochCount[epoch/5 - 1] = epoch;
trainLoss[epoch/5 -1] = trainL.mean().log10().getFloat();
testLoss[epoch/5 -1] = testL.mean().log10().getFloat();
}
}
System.out.println("l1 norm of w: " + params.get(0).abs().sum());
}
忽略正则化直接训练
~~~~~~~~~~~~~~~~~~
我们现在用\ ``lambd = 0``\ 禁用权重衰减后运行这个代码。注意,这里训练误差有了减少,但测试误差没有减少。这意味着出现了严重的过拟合。这是过拟合的一个典型例子。
.. code:: java
train(0f);
String[] lossLabel = new String[trainLoss.length + testLoss.length];
Arrays.fill(lossLabel, 0, testLoss.length, "test");
Arrays.fill(lossLabel, testLoss.length, trainLoss.length + testLoss.length, "train");
Table data = Table.create("Data").addColumns(
DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, epochCount)),
DoubleColumn.create("loss", ArrayUtils.addAll(testLoss, trainLoss)),
StringColumn.create("lossLabel", lossLabel)
);
render(LinePlot.create("", data, "epochCount", "loss", "lossLabel"),"text/html");
.. parsed-literal::
:class: output
l1 norm of w: ND: () gpu(0) float32
161.0816
.. raw:: html