Run this notebook online:\ |Binder| or Colab: |Colab|

.. |Binder| image:: https://mybinder.org/badge_logo.svg
    :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_multilayer-perceptrons/mlp-djl.ipynb

.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
    :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_multilayer-perceptrons/mlp-djl.ipynb

.. _sec_mlp_concise:

Concise Implementation of Multilayer Perceptrons
================================================

As you might expect, we can use ``DJL`` to implement multilayer perceptrons more concisely.

.. code:: java

    %load ../utils/djl-imports
    %load ../utils/plot-utils

.. code:: java

    import ai.djl.metric.*;
    import ai.djl.basicdataset.cv.classification.*;
    import org.apache.commons.lang3.ArrayUtils;

Model
-----

Compared with the concise implementation of softmax regression (:numref:`sec_softmax_concise`), the only difference is that we add two fully connected layers (previously we added only one). The first is the hidden layer, which contains 256 hidden units and applies the ReLU activation function. The second is the output layer.

.. code:: java

    SequentialBlock net = new SequentialBlock();
    net.add(Blocks.batchFlattenBlock(784));            // flatten each 28x28 image into 784 features
    net.add(Linear.builder().setUnits(256).build());   // hidden layer with 256 units
    net.add(Activation::relu);
    net.add(Linear.builder().setUnits(10).build());    // output layer, one unit per class
    net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);

The training procedure is exactly the same as when we implemented softmax regression. This modular design lets us keep everything related to the model architecture separate from the rest of the training code.

.. code:: java

    int batchSize = 256;
    int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

    double[] trainLoss;
    double[] testAccuracy;
    double[] epochCount;
    double[] trainAccuracy;

    trainLoss = new double[numEpochs];
    trainAccuracy = new double[numEpochs];
    testAccuracy = new double[numEpochs];
    epochCount = new double[numEpochs];

    FashionMnist trainIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TRAIN)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();

    FashionMnist testIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TEST)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();

    trainIter.prepare();
    testIter.prepare();

    for (int i = 0; i < epochCount.length; i++) {
        epochCount[i] = (i + 1);
    }

    // Per-epoch metric values, keyed by metric name
    Map<String, double[]> evaluatorMetrics = new HashMap<>();

.. code:: java

    Tracker lrt = Tracker.fixed(0.5f);
    Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

    Loss loss = Loss.softmaxCrossEntropyLoss();

    DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
            .optOptimizer(sgd) // Optimizer
            .optDevices(Engine.getInstance().getDevices(1)) // use at most one GPU
            .addEvaluator(new Accuracy()) // Model accuracy
            .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

    try (Model model = Model.newInstance("mlp")) {
        model.setBlock(net);

        try (Trainer trainer = model.newTrainer(config)) {
            trainer.initialize(new Shape(1, 784));
            trainer.setMetrics(new Metrics());

            EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

            // Collect results from evaluators
            Metrics metrics = trainer.getMetrics();

            trainer.getEvaluators().stream()
                    .forEach(evaluator -> {
                        evaluatorMetrics.put("train_epoch_" + evaluator.getName(),
                                metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
                                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                        evaluatorMetrics.put("validate_epoch_" + evaluator.getName(),
                                metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());
                    });
        }
    }

.. parsed-literal::
    :class: output

    INFO Training on: 1 GPUs.
    INFO Load MXNet Engine Version 1.9.0 in 0.060 ms.

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.70, SoftmaxCrossEntropyLoss: 0.81
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 1 finished.
    INFO Train: Accuracy: 0.70, SoftmaxCrossEntropyLoss: 0.81
    INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.73

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 2 finished.
    INFO Train: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.49
    INFO Validate: Accuracy: 0.83, SoftmaxCrossEntropyLoss: 0.47

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 3 finished.
    INFO Train: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42
    INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.45

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.39
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 4 finished.
    INFO Train: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.39
    INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.42

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 5 finished.
    INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.36
    INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.43

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 6 finished.
    INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.35
    INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.43

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.34
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 7 finished.
    INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.34
    INFO Validate: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.44

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 8 finished.
    INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.32
    INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.37

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.31
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 9 finished.
    INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.31
    INFO Validate: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.37

.. parsed-literal::
    :class: output

    Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
    Validating:  100% |████████████████████████████████████████|

.. parsed-literal::
    :class: output

    INFO Epoch 10 finished.
    INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
    INFO Validate: Accuracy: 0.86, SoftmaxCrossEntropyLoss: 0.42
    INFO forward P50: 0.327 ms, P90: 0.382 ms
    INFO training-metrics P50: 0.018 ms, P90: 0.022 ms
    INFO backward P50: 0.623 ms, P90: 0.663 ms
    INFO step P50: 0.911 ms, P90: 0.996 ms
    INFO epoch P50: 1.290 s, P90: 1.641 s

.. code:: java

    trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
    trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
    testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

    String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

    // Labels follow the same order as the concatenated values below:
    // test accuracy, then train accuracy, then train loss.
    Arrays.fill(lossLabel, 0, testAccuracy.length, "test acc");
    Arrays.fill(lossLabel, testAccuracy.length, testAccuracy.length + trainAccuracy.length, "train acc");
    Arrays.fill(lossLabel, testAccuracy.length + trainAccuracy.length,
                    testAccuracy.length + trainAccuracy.length + trainLoss.length, "train loss");

    Table data = Table.create("Data").addColumns(
        DoubleColumn.create("epochCount", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
        DoubleColumn.create("loss", ArrayUtils.addAll(testAccuracy, ArrayUtils.addAll(trainAccuracy, trainLoss))),
        StringColumn.create("lossLabel", lossLabel)
    );

    render(LinePlot.create("", data, "epochCount", "loss", "lossLabel"), "text/html");

.. raw:: html

    <!-- Line plot of "test acc", "train acc", and "train loss" against epochCount (rendered output not preserved) -->

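To make concrete what the network trained above computes for one flattened image, the following plain-Java sketch applies a 784-to-256 fully connected layer, ReLU, and a 256-to-10 fully connected layer. It is only an illustration under stated assumptions, not how ``DJL`` implements these layers: the class ``MlpForwardSketch`` and all of its methods are hypothetical names, the weights are drawn from a normal distribution with an assumed standard deviation of 0.01, the biases are zero, and the input is random rather than a real Fashion-MNIST image.

```java
import java.util.Random;

// Plain-Java illustration only; all names here are hypothetical and not part of DJL.
final class MlpForwardSketch {

    /** Element-wise ReLU: max(0, x). */
    static double[] relu(double[] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.max(0.0, x[i]);
        }
        return out;
    }

    /** Fully connected layer y = W x + b, where W has shape (units, inputs). */
    static double[] linear(double[][] w, double[] b, double[] x) {
        double[] y = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double sum = b[i];
            for (int j = 0; j < x.length; j++) {
                sum += w[i][j] * x[j];
            }
            y[i] = sum;
        }
        return y;
    }

    /** Weights sampled from a zero-mean normal distribution with standard deviation sigma. */
    static double[][] randomWeights(int units, int inputs, double sigma, Random rng) {
        double[][] w = new double[units][inputs];
        for (int i = 0; i < units; i++) {
            for (int j = 0; j < inputs; j++) {
                w[i][j] = sigma * rng.nextGaussian();
            }
        }
        return w;
    }

    /** Hidden layer, ReLU, then output layer; returns unnormalized class scores (logits). */
    static double[] forward(double[][] w1, double[] b1, double[][] w2, double[] b2, double[] x) {
        double[] hidden = relu(linear(w1, b1, x));
        return linear(w2, b2, hidden);
    }

    public static void main(String[] args) {
        Random rng = new Random(42);
        double sigma = 0.01; // assumed value for this sketch

        double[][] w1 = randomWeights(256, 784, sigma, rng);
        double[][] w2 = randomWeights(10, 256, sigma, rng);

        // Stand-in for one flattened 28x28 image with pixel values in [0, 1).
        double[] x = new double[784];
        for (int i = 0; i < x.length; i++) {
            x[i] = rng.nextDouble();
        }

        double[] logits = forward(w1, new double[256], w2, new double[10], x);
        System.out.println("number of logits: " + logits.length);
    }
}
```

The output has one score per class. During training, the softmax cross-entropy loss turns these scores into a probability distribution and compares it with the label, so the sketch stops at the logits.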
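Exercises
---------

1. Try adding different numbers of hidden layers (you can also change the learning rate). Which setting works best?
2. Try different activation functions. Which one works best?
3. Try different schemes for initializing the weights. Which method works best?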