
4.3. Concise Implementation of Multilayer Perceptrons

As you might expect, we can rely on DJL to implement multilayer perceptrons even more concisely.

%load ../utils/djl-imports
%load ../utils/plot-utils
import ai.djl.metric.*;
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;

4.3.1. Model

Compared with our concise implementation of softmax regression (:numref:sec_softmax_concise), the only difference is that we add two fully-connected layers (previously, we added only one). The first is our hidden layer, which contains 256 hidden units and applies the ReLU activation function. The second is our output layer.

SequentialBlock net = new SequentialBlock();
// Flatten each 28 x 28 image into a 784-dimensional vector
net.add(Blocks.batchFlattenBlock(784));
// Hidden layer: 256 units followed by the ReLU activation
net.add(Linear.builder().setUnits(256).build());
net.add(Activation::relu);
// Output layer: 10 units, one per Fashion-MNIST class
net.add(Linear.builder().setUnits(10).build());
// Initialize the weights from a normal distribution
net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);

The implementation of the training loop is exactly the same as when we implemented softmax regression. This modularity enables us to separate matters concerning the model architecture from orthogonal considerations.

int batchSize = 256;
int numEpochs = Integer.getInteger("MAX_EPOCH", 10);
double[] trainLoss;
double[] testAccuracy;
double[] epochCount;
double[] trainAccuracy;

trainLoss = new double[numEpochs];
trainAccuracy = new double[numEpochs];
testAccuracy = new double[numEpochs];
epochCount = new double[numEpochs];

// Fashion-MNIST training and test datasets, sampled in shuffled mini-batches of size batchSize
FashionMnist trainIter = FashionMnist.builder()
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

FashionMnist testIter = FashionMnist.builder()
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

trainIter.prepare();
testIter.prepare();

for(int i = 0; i < epochCount.length; i++) {
    epochCount[i] = (i + 1);
}

Map<String, double[]> evaluatorMetrics = new HashMap<>();

// SGD optimizer with a fixed learning rate of 0.5
Tracker lrt = Tracker.fixed(0.5f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

Loss loss = Loss.softmaxCrossEntropyLoss();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
                .optOptimizer(sgd) // Optimizer (stochastic gradient descent)
                .optDevices(Engine.getInstance().getDevices(1)) // single GPU
                .addEvaluator(new Accuracy()) // Model Accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

try (Model model = Model.newInstance("mlp")) {
    model.setBlock(net);

    try (Trainer trainer = model.newTrainer(config)) {

        trainer.initialize(new Shape(1, 784));
        trainer.setMetrics(new Metrics());

        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);
        // collect results from the evaluators
        Metrics metrics = trainer.getMetrics();

        trainer.getEvaluators().stream()
                .forEach(evaluator -> {
                    evaluatorMetrics.put("train_epoch_" + evaluator.getName(),
                            metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                    evaluatorMetrics.put("validate_epoch_" + evaluator.getName(),
                            metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                });
    }
}
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.060 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.70, SoftmaxCrossEntropyLoss: 0.81
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.70, SoftmaxCrossEntropyLoss: 0.81
INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.73
Training:    100% |████████████████████████████████████████| Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
INFO Validate: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.47
Training:    100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45
Training:    100% |████████████████████████████████████████| Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.39
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.39
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.43
Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.43
Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.34
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.34
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.44
Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.31
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.31
INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.37
Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.42
INFO forward P50: 0.327 ms, P90: 0.382 ms
INFO training-metrics P50: 0.018 ms, P90: 0.022 ms
INFO backward P50: 0.623 ms, P90: 0.663 ms
INFO step P50: 0.911 ms, P90: 0.996 ms
INFO epoch P50: 1.290 s, P90: 1.641 s
trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

Arrays.fill(lossLabel, 0, trainLoss.length, "test acc");
Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                trainLoss.length + testAccuracy.length + trainAccuracy.length, "train loss");

Table data = Table.create("Data").addColumns(
            DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
            DoubleColumn.create("loss", ArrayUtils.addAll(testAccuracy , ArrayUtils.addAll(trainAccuracy, trainLoss))),
            StringColumn.create("lossLabel", lossLabel)
);

render(LinePlot.create("", data, "epochCount", "loss", "lossLabel"),"text/html");
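
If you prefer a numeric summary to the chart, the collected metric arrays can also be printed directly. The snippet below is a small optional check that uses only the variables already defined above.

// Print the final-epoch values that the line plot ends at
System.out.printf("epoch %d: train loss %.3f, train acc %.3f, test acc %.3f%n",
        numEpochs, trainLoss[numEpochs - 1], trainAccuracy[numEpochs - 1], testAccuracy[numEpochs - 1]);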

4.3.2. Exercises

  1. Try adding different numbers of hidden layers (you may also modify the learning rate). What configuration works best? (A starting-point sketch follows this list.)

  2. Try out different activation functions. Which one works best?

  3. Try different schemes for initializing the weights. What method works best?
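
As a starting point for these exercises, here is a minimal sketch, assuming the same imports and the same training code as above. The extra hidden layer, the tanh activations, the XavierInitializer, and the 0.1 learning rate are illustrative choices, not tuned recommendations.

// Illustrative variant for the exercises (values are not tuned):
// an extra hidden layer, tanh activations, Xavier initialization,
// and a smaller fixed learning rate.
SequentialBlock net2 = new SequentialBlock();
net2.add(Blocks.batchFlattenBlock(784));
net2.add(Linear.builder().setUnits(256).build());
net2.add(Activation::tanh);
net2.add(Linear.builder().setUnits(128).build());
net2.add(Activation::tanh);
net2.add(Linear.builder().setUnits(10).build());
net2.setInitializer(new XavierInitializer(), Parameter.Type.WEIGHT);

Tracker lrt2 = Tracker.fixed(0.1f);
Optimizer sgd2 = Optimizer.sgd().setLearningRateTracker(lrt2).build();
// Reuse the DefaultTrainingConfig / Trainer / EasyTrain.fit code from above,
// substituting net2 and sgd2, and compare the resulting curves.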