4.3. Concise Implementation of Multilayer Perceptrons
As you might expect, by relying on DJL we can implement multilayer perceptrons even more concisely.
%load ../utils/djl-imports
%load ../utils/plot-utils
import ai.djl.metric.*;
import ai.djl.basicdataset.cv.classification.*;
import org.apache.commons.lang3.ArrayUtils;
4.3.1. Model
Compared with our concise implementation of softmax regression (:numref:sec_softmax_concise), the only difference is that we add two fully-connected layers (previously we added one). The first is the hidden layer; it contains 256 hidden units and applies the ReLU activation function. The second is the output layer.
SequentialBlock net = new SequentialBlock();
net.add(Blocks.batchFlattenBlock(784)); // flatten the 28 x 28 images into vectors of length 784
net.add(Linear.builder().setUnits(256).build()); // hidden layer with 256 units
net.add(Activation::relu); // ReLU activation
net.add(Linear.builder().setUnits(10).build()); // output layer with 10 units
net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT); // initialize weights from a normal distribution
The training loop is exactly the same as when we implemented softmax regression. This modularity enables us to separate matters concerning the model architecture from orthogonal considerations.
int batchSize = 256;
int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

// Arrays used to collect per-epoch metrics for plotting
double[] trainLoss = new double[numEpochs];
double[] trainAccuracy = new double[numEpochs];
double[] testAccuracy = new double[numEpochs];
double[] epochCount = new double[numEpochs];
FashionMnist trainIter = FashionMnist.builder()
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

FashionMnist testIter = FashionMnist.builder()
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, true)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();
trainIter.prepare();
testIter.prepare();
for (int i = 0; i < epochCount.length; i++) {
    epochCount[i] = i + 1;
}
Map<String, double[]> evaluatorMetrics = new HashMap<>();
Tracker lrt = Tracker.fixed(0.5f); // fixed learning rate of 0.5
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

Loss loss = Loss.softmaxCrossEntropyLoss();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // Optimizer (stochastic gradient descent)
        .optDevices(Engine.getInstance().getDevices(1)) // single GPU
        .addEvaluator(new Accuracy()) // Model Accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging
try (Model model = Model.newInstance("mlp")) {
    model.setBlock(net);

    try (Trainer trainer = model.newTrainer(config)) {
        trainer.initialize(new Shape(1, 784));
        trainer.setMetrics(new Metrics());

        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

        // collect results from evaluators
        Metrics metrics = trainer.getMetrics();

        trainer.getEvaluators().stream()
            .forEach(evaluator -> {
                evaluatorMetrics.put("train_epoch_" + evaluator.getName(), metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                    .mapToDouble(x -> x.getValue().doubleValue()).toArray());
            });
    }
}
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.060 ms.
Training: 100% |████████████████████████████████████████| Accuracy: 0.70, SoftmaxCrossEntropyLoss: 0.81
Validating: 100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.70, SoftmaxCrossEntropyLoss: 0.81
INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.73
Training: 100% |████████████████████████████████████████| Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
Validating: 100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
INFO Validate: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.47
Training: 100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
Validating: 100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45
Training: 100% |████████████████████████████████████████| Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.39
Validating: 100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.39
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
Training: 100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
Validating: 100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.43
Training: 100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
Validating: 100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.43
Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.34
Validating: 100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.34
INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.44
Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
Validating: 100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37
Training: 100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.31
Validating: 100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.31
INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.37
Training: 100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Validating: 100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.42
INFO forward P50: 0.327 ms, P90: 0.382 ms
INFO training-metrics P50: 0.018 ms, P90: 0.022 ms
INFO backward P50: 0.623 ms, P90: 0.663 ms
INFO step P50: 0.911 ms, P90: 0.996 ms
INFO epoch P50: 1.290 s, P90: 1.641 s
trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");
String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

// Label the three segments in the same order the metric arrays are concatenated below
Arrays.fill(lossLabel, 0, trainLoss.length, "test acc");
Arrays.fill(lossLabel, trainAccuracy.length, trainLoss.length + trainAccuracy.length, "train acc");
Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
        trainLoss.length + testAccuracy.length + trainAccuracy.length, "train loss");
Table data = Table.create("Data").addColumns(
    DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
    DoubleColumn.create("loss", ArrayUtils.addAll(testAccuracy, ArrayUtils.addAll(trainAccuracy, trainLoss))),
    StringColumn.create("lossLabel", lossLabel)
);

render(LinePlot.create("", data, "epochCount", "loss", "lossLabel"), "text/html");
4.3.2. Exercises
Try adding different numbers of hidden layers (you may also modify the learning rate). What configuration works best?
Try out different activation functions. Which one works best?
Try different schemes for initializing the weights. What method works best?
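As one possible starting point for these exercises, here is a minimal sketch, assuming the same djl-imports header used above. It adds a second hidden layer, swaps ReLU for tanh, and switches the weight initializer to Xavier; the layer widths (256 and 128) and the specific choices of activation and initializer are illustrative assumptions rather than tuned recommendations, and the rest of the training code in this section can be reused unchanged. To vary the learning rate, change the value passed to Tracker.fixed before rebuilding the optimizer.
// A sketch for the exercises (assumed choices: tanh activation, Xavier init, 256/128 hidden units)
SequentialBlock deeperNet = new SequentialBlock();
deeperNet.add(Blocks.batchFlattenBlock(784));
deeperNet.add(Linear.builder().setUnits(256).build()); // first hidden layer
deeperNet.add(Activation::tanh);                       // exercise 2: try another activation
deeperNet.add(Linear.builder().setUnits(128).build()); // exercise 1: an extra hidden layer
deeperNet.add(Activation::tanh);
deeperNet.add(Linear.builder().setUnits(10).build());  // output layer
// exercise 3: a different weight-initialization scheme (DJL's default Xavier settings assumed)
deeperNet.setInitializer(new XavierInitializer(), Parameter.Type.WEIGHT);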