Run this notebook online via Binder or Colab.

# 11.3. 梯度下降¶

## 11.3.1. 一维梯度下降¶

(11.3.1)$f(x + \epsilon) = f(x) + \epsilon f'(x) + \mathcal{O}(\epsilon^2).$

(11.3.2)$f(x - \eta f'(x)) = f(x) - \eta f'^2(x) + \mathcal{O}(\eta^2 f'^2(x)).$

(11.3.3)$f(x - \eta f'(x)) \lessapprox f(x).$

(11.3.4)$x \leftarrow x - \eta f'(x)$

%load ../utils/djl-imports

Function<Float, Float> f = x -> x * x; // Objective function: f(x) = x^2
Function<Float, Float> gradf = x -> 2 * x; // Its derivative: f'(x) = 2x

// Manager that owns the NDArrays created for plotting below.
NDManager manager = NDManager.newBaseManager();


public float[] gd(float eta) {
    // Run 10 steps of gradient descent on f starting from x = 10 with
    // learning rate eta, printing the final iterate.
    // Returns 11 values: the start point followed by the 10 iterates.
    float x = 10f;
    float[] results = new float[11];
    results[0] = x;

    for (int i = 0; i < 10; i++) {
        // Bug fix: the update x <- x - eta * f'(x) was missing, so x never
        // changed. With it, x = 10 * (1 - 2*eta)^10, matching the printed
        // output 0.060466 for eta = 0.2.
        x -= eta * gradf.apply(x);
        results[i + 1] = x;
    }
    System.out.printf("epoch 10, x: %f\n", x);
    return results;
}

float[] res = gd(0.2f);

epoch 10, x: 0.060466


/* Saved in GradDescUtils.java */
/* Plot the objective curve (x, y) together with the gradient-descent
 * trajectory `segment`, drawn both as a connecting line and as markers. */
public void plotGD(float[] x, float[] y, float[] segment, Function<Float, Float> func,
        int width, int height) {
    // Convert each float[] to double[] once and reuse the results.
    double[] curveX = Functions.floatToDoubleArray(x);
    double[] curveY = Functions.floatToDoubleArray(y);
    double[] pathX = Functions.floatToDoubleArray(segment);
    double[] pathY = Functions.floatToDoubleArray(Functions.callFunc(segment, func));

    // The objective function itself, as a line.
    ScatterTrace objectiveTrace = ScatterTrace.builder(curveX, curveY)
            .mode(ScatterTrace.Mode.LINE)
            .build();

    // The descent path, as a line connecting successive iterates.
    ScatterTrace pathTrace = ScatterTrace.builder(pathX, pathY)
            .mode(ScatterTrace.Mode.LINE)
            .build();

    // The individual iterates, as markers (default scatter mode).
    ScatterTrace pointsTrace = ScatterTrace.builder(pathX, pathY)
            .build();

    Layout layout = Layout.builder()
            .height(height)
            .width(width)
            .showLegend(false)
            .build();

    display(new Figure(layout, objectiveTrace, pathTrace, pointsTrace));
}

/* Saved in GradDescUtils.java */
/* Plot the 1D objective f over a symmetric range wide enough to contain
 * every iterate in `res`, overlaid with the descent trajectory. */
public void showTrace(float[] res) {
    // Half-width of the plotting window: the largest |iterate| seen.
    float bound = 0;
    for (float value : res) {
        bound = Math.max(bound, Math.abs(value));
    }
    float[] fLine = manager.arange(-bound, bound, 0.01f).toFloatArray();
    plotGD(fLine, Functions.callFunc(fLine, f), res, f, 500, 400);
}

showTrace(res);


### 11.3.1.1. 学习率¶

showTrace(gd(0.05f));

epoch 10, x: 3.486785


showTrace(gd(1.1f));

epoch 10, x: 61.917389


### 11.3.1.2. 局部最小值¶

float c = (float)(0.15f * Math.PI);

// Nonconvex objective with multiple local minima: f(x) = x * cos(c * x)
Function<Float, Float> f = x -> x * (float)Math.cos(c * x);

// Product rule: f'(x) = cos(c x) - c x sin(c x)
Function<Float, Float> gradf = x -> (float)(Math.cos(c * x) - c * x * Math.sin(c * x));

showTrace(gd(2));

epoch 10, x: -1.528166


## 11.3.2. 多元梯度下降¶

(11.3.5)$\nabla f(\mathbf{x}) = \bigg[\frac{\partial f(\mathbf{x})}{\partial x_1}, \frac{\partial f(\mathbf{x})}{\partial x_2}, \ldots, \frac{\partial f(\mathbf{x})}{\partial x_d}\bigg]^\top.$

(11.3.6)$f(\mathbf{x} + \boldsymbol{\epsilon}) = f(\mathbf{x}) + \boldsymbol{\epsilon}^\top \nabla f(\mathbf{x}) + \mathcal{O}(\|\boldsymbol{\epsilon}\|^2).$

(11.3.7)$\mathbf{x} \leftarrow \mathbf{x} - \eta \nabla f(\mathbf{x}).$

/* Saved in GradDescUtils.java */
/* Holds one (x1, x2) point along a 2D optimization trajectory. */
public class Weights {
    public float x1;
    public float x2;

    public Weights(float x1, float x2) {
        this.x1 = x1;
        this.x2 = x2;
    }
}

/* Optimize a 2D objective function with a customized trainer. */
/* `trainer` maps the packed state {x1, x2, s1, s2} to its updated values;
 * returns the full trajectory (start point plus one Weights per step). */
public ArrayList<Weights> train2d(Function<Float[], Float[]> trainer, int steps) {
    // s1 and s2 are internal state variables that will be used later
    // (e.g. by momentum-style optimizers); plain gradient descent ignores them.
    float x1 = -5f, x2 = -2f, s1 = 0f, s2 = 0f;
    ArrayList<Weights> results = new ArrayList<>();
    // Bug fix: `results` was allocated but never populated, so the method
    // always returned an empty list and nothing could be plotted.
    results.add(new Weights(x1, x2));
    for (int i = 1; i < steps + 1; i++) {
        Float[] step = trainer.apply(new Float[]{x1, x2, s1, s2});
        x1 = step[0];
        x2 = step[1];
        s1 = step[2];
        s2 = step[3];
        results.add(new Weights(x1, x2));
        System.out.printf("epoch %d, x1 %f, x2 %f\n", i, x1, x2);
    }
    return results;
}

import java.util.function.BiFunction;

/* Show the trace of 2D variables during optimization. */
public void showTrace2d(BiFunction<Float, Float, Float> f, ArrayList<Weights> results) {
// Intentionally a no-op in this notebook: nothing is rendered.
// NOTE(review): presumably a contour plot of f with the trajectory from
// `results` would go here (cf. Fig. 11.3.1) — confirm against GradDescUtils.java.
}


float eta = 0.1f; // Learning rate

BiFunction<Float, Float, Float> f = (x1, x2) -> x1 * x1 + 2 * x2 * x2; // Objective: f(x1, x2) = x1^2 + 2*x2^2

BiFunction<Float, Float, Float[]> gradf = (x1, x2) -> new Float[]{2 * x1, 4 * x2}; // Gradient: [2*x1, 4*x2]

// One step of 2D gradient descent over the packed state {x1, x2, s1, s2}.
Function<Float[], Float[]> gd = (state) -> {
    Float x1 = state[0];
    Float x2 = state[1];

    // Bug fix: `g` was referenced but never defined — evaluate the gradient
    // at the current point before unpacking it.
    Float[] g = gradf.apply(x1, x2);
    Float g1 = g[0];
    Float g2 = g[1];

    // s1/s2 are unused by plain gradient descent, so return them as 0.
    return new Float[]{x1 - eta * g1, x2 - eta * g2, 0f, 0f}; // Update Variables
};

showTrace2d(f, train2d(gd, 20));

epoch 1, x1 -4.000000, x2 -1.200000
epoch 2, x1 -3.200000, x2 -0.720000
epoch 3, x1 -2.560000, x2 -0.432000
epoch 4, x1 -2.048000, x2 -0.259200
epoch 5, x1 -1.638400, x2 -0.155520
epoch 6, x1 -1.310720, x2 -0.093312
epoch 7, x1 -1.048576, x2 -0.055987
epoch 8, x1 -0.838861, x2 -0.033592
epoch 9, x1 -0.671089, x2 -0.020155
epoch 10, x1 -0.536871, x2 -0.012093
epoch 11, x1 -0.429497, x2 -0.007256
epoch 12, x1 -0.343597, x2 -0.004354
epoch 13, x1 -0.274878, x2 -0.002612
epoch 14, x1 -0.219902, x2 -0.001567
epoch 15, x1 -0.175922, x2 -0.000940
epoch 16, x1 -0.140737, x2 -0.000564
epoch 17, x1 -0.112590, x2 -0.000339
epoch 18, x1 -0.090072, x2 -0.000203
epoch 19, x1 -0.072058, x2 -0.000122
epoch 20, x1 -0.057646, x2 -0.000073


Fig. 11.3.1 Trace of the optimization variables during gradient descent on the 2D objective (figure not rendered in this export).

## 11.3.3. 自适应方法¶

### 11.3.3.1. 牛顿法¶

(11.3.8)$f(\mathbf{x} + \boldsymbol{\epsilon}) = f(\mathbf{x}) + \boldsymbol{\epsilon}^\top \nabla f(\mathbf{x}) + \frac{1}{2} \boldsymbol{\epsilon}^\top \nabla^2 f(\mathbf{x}) \boldsymbol{\epsilon} + \mathcal{O}(\|\boldsymbol{\epsilon}\|^3).$

(11.3.9)$\nabla f(\mathbf{x}) + \mathbf{H} \boldsymbol{\epsilon} = 0 \text{ and hence } \boldsymbol{\epsilon} = -\mathbf{H}^{-1} \nabla f(\mathbf{x}).$

float c = 0.5f;

Function<Float, Float> f = x -> (float)Math.cosh(c * x); // Objective: convex f(x) = cosh(c x)

Function<Float, Float> gradf = x -> c * (float)Math.sinh(c * x); // Derivative: c * sinh(c x)

Function<Float, Float> hessf = x -> c * c * (float)Math.cosh(c * x); // Hessian: c^2 * cosh(c x)

// Hide learning rate for now
/* Run 10 iterations of (damped) Newton's method from x = 10:
 * x <- x - eta * f'(x) / f''(x). Returns the start point plus all iterates. */
public float[] newton(float eta) {
    float x = 10f;
    float[] trajectory = new float[11];
    trajectory[0] = x;

    for (int epoch = 1; epoch <= 10; epoch++) {
        x = x - eta * gradf.apply(x) / hessf.apply(x);
        trajectory[epoch] = x;
    }
    System.out.printf("epoch 10, x: %f\n", x);
    return trajectory;
}

showTrace(newton(1));

epoch 10, x: 0.000000


// Revisit the nonconvex objective f(x) = x * cos(c * x) with Newton's method.
c = 0.15f * (float)Math.PI;

Function<Float, Float> f = x -> x * (float)Math.cos(c * x);

// f'(x) = cos(c x) - c x sin(c x)
Function<Float, Float> gradf = x -> (float)(Math.cos(c * x) - c * x * Math.sin(c * x));

// f''(x) = -2 c sin(c x) - c^2 x cos(c x); can be negative, so plain Newton may diverge.
Function<Float, Float> hessf = x -> (float)(-2 * c * Math.sin(c * x) -
x * c * c * Math.cos(c * x));

showTrace(newton(1));

epoch 10, x: 26.834131