
12.4. Concise Implementation for Multiple GPUs

Implementing parallelism from scratch for every new model is no fun. Moreover, there is significant benefit in optimizing synchronization tools for high performance. In the following we will show how to do this using the high-level API of the deep learning framework. The math and the algorithms are the same as in Section 12.3. Quite unsurprisingly, you will need at least two GPUs to run the code of this section.

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Training.java
import ai.djl.basicdataset.cv.classification.*;
import ai.djl.metric.*;
import org.apache.commons.lang3.ArrayUtils;

12.4.1. A Toy Network

Let us use a slightly more meaningful network than the LeNet of Section 12.3, one that is still sufficiently easy and fast to train: a ResNet-18 variant [He et al., 2016a]. Since the input images are tiny, we modify the architecture slightly. In particular, the difference from Section 7.6 is that we use a smaller convolution kernel, stride, and padding at the beginning, and we remove the maximum pooling layer.

class Residual extends AbstractBlock {

    private static final byte VERSION = 2;

    public ParallelBlock block;

    public Residual(int numChannels, boolean use1x1Conv, Shape strideShape) {
        super(VERSION);

        SequentialBlock b1;
        SequentialBlock conv1x1;

        b1 = new SequentialBlock();

        b1.add(Conv2d.builder()
                .setFilters(numChannels)
                .setKernelShape(new Shape(3, 3))
                .optPadding(new Shape(1, 1))
                .optStride(strideShape)
                .build())
                .add(BatchNorm.builder().build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setFilters(numChannels)
                        .setKernelShape(new Shape(3, 3))
                        .optPadding(new Shape(1, 1))
                        .build())
                .add(BatchNorm.builder().build());

        if (use1x1Conv) {
            conv1x1 = new SequentialBlock();
            conv1x1.add(Conv2d.builder()
                    .setFilters(numChannels)
                    .setKernelShape(new Shape(1, 1))
                    .optStride(strideShape)
                    .build());
        } else {
            conv1x1 = new SequentialBlock();
            conv1x1.add(Blocks.identityBlock());
        }

        block = addChildBlock("residualBlock", new ParallelBlock(
                list -> {
                    NDList unit = list.get(0);
                    NDList parallel = list.get(1);
                    return new NDList(
                            unit.singletonOrThrow()
                                    .add(parallel.singletonOrThrow())
                                    .getNDArrayInternal()
                                    .relu());
                },
                Arrays.asList(b1, conv1x1)));
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        return block.forward(parameterStore, inputs, training);
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputs) {
        Shape[] current = inputs;
        for (Block child : block.getChildren().values()) {
            current = child.getOutputShapes(current);
        }
        return current;
    }

    @Override
    protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
        block.initialize(manager, dataType, inputShapes);
    }
}
// Returns a stage of `numResiduals` residual blocks. For every stage except the
// first, the initial residual block uses stride 2 (halving height and width) and a
// 1x1 convolution on the shortcut to match the channel count.
public SequentialBlock resnetBlock(int numChannels, int numResiduals, boolean isFirstBlock) {
    SequentialBlock blk = new SequentialBlock();
    for (int i = 0; i < numResiduals; i++) {
        if (i == 0 && !isFirstBlock) {
            blk.add(new Residual(numChannels, true, new Shape(2, 2)));
        } else {
            blk.add(new Residual(numChannels, false, new Shape(1, 1)));
        }
    }
    return blk;
}

int numClass = 10;
// This model uses a smaller convolution kernel, stride, and padding and
// removes the maximum pooling layer
SequentialBlock net = new SequentialBlock();
net
    .add(
            Conv2d.builder()
                    .setFilters(64)
                    .setKernelShape(new Shape(3, 3))
                    .optPadding(new Shape(1, 1))
                    .build())
    .add(BatchNorm.builder().build())
    .add(Activation::relu)
    .add(resnetBlock(64, 2, true))
    .add(resnetBlock(128, 2, false))
    .add(resnetBlock(256, 2, false))
    .add(resnetBlock(512, 2, false))
    .add(Pool.globalAvgPool2dBlock())
    .add(Linear.builder().setUnits(numClass).build());
SequentialBlock {
    Conv2d
    BatchNorm
    LambdaBlock
    SequentialBlock {
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
    }
    SequentialBlock {
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    Conv2d
                            }
                    }
            }
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
    }
    SequentialBlock {
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    Conv2d
                            }
                    }
            }
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
    }
    SequentialBlock {
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    Conv2d
                            }
                    }
            }
            Residual {
                    residualBlock {
                            SequentialBlock {
                                    Conv2d
                                    BatchNorm
                                    LambdaBlock
                                    Conv2d
                                    BatchNorm
                            }
                            SequentialBlock {
                                    identity
                            }
                    }
            }
    }
    globalAvgPool2d
    Linear
}

12.4.2. Network Initialization

The setInitializer() function allows us to initialize parameters on a device of our choice (below it is configured through optInitializer() in the training configuration). For a refresher on initialization methods see Section 4.8. What is particularly convenient is that it also lets us initialize the network on multiple devices simultaneously. Let us try how this works in practice.

Model model = Model.newInstance("training-multiple-gpus-1");
model.setBlock(net);

Loss loss = Loss.softmaxCrossEntropyLoss();

Tracker lrt = Tracker.fixed(0.1f);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
        .optOptimizer(sgd) // Optimizer (loss function)
        .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // setting the initializer
        .optDevices(Engine.getInstance().getDevices(1)) // setting the number of GPUs needed
        .addEvaluator(new Accuracy()) // Model Accuracy
        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

Trainer trainer = model.newTrainer(config);
INFO Training on: 1 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.052 ms.

Using the split() function introduced in Section 12.3, we can divide a minibatch of data and copy the portions to multiple devices. The network instance automatically uses the appropriate GPU to compute the value of the forward propagation. Here we generate \(4\) observations and split them over the GPUs.

NDManager manager = NDManager.newBaseManager();
NDArray X = manager.randomUniform(0f, 1.0f, new Shape(4, 1, 28, 28));
trainer.initialize(X.getShape());

NDList[] res = Batchifier.STACK.split(new NDList(X), 4, true);

ParameterStore parameterStore = new ParameterStore(manager, true);

System.out.println(net.forward(parameterStore, new NDList(res[0]), false).singletonOrThrow());
System.out.println(net.forward(parameterStore, new NDList(res[1]), false).singletonOrThrow());
System.out.println(net.forward(parameterStore, new NDList(res[2]), false).singletonOrThrow());
System.out.println(net.forward(parameterStore, new NDList(res[3]), false).singletonOrThrow());
ND: (1, 10) gpu(0) float32
[[-2.53076792e-07,  2.19176854e-06, -2.05096558e-06, -2.80443487e-07, -1.65612937e-06,  5.92275399e-07, -4.38029275e-07,  1.43108821e-07,  1.86682854e-07,  8.35030505e-07],
]

ND: (1, 10) gpu(0) float32
[[-3.17955994e-07,  1.94063477e-06, -1.82914255e-06,  1.36083145e-09, -1.45861077e-06,  4.11562326e-07, -8.99586439e-07,  1.97685665e-07,  2.77768578e-07,  6.80656115e-07],
]

ND: (1, 10) gpu(0) float32
[[-1.82850158e-07,  2.26233874e-06, -2.24626365e-06,  8.68596715e-08, -1.29084265e-06,  9.33801005e-07, -1.04999901e-06,  1.76022922e-07,  3.97307645e-08,  9.49504113e-07],
]

ND: (1, 10) gpu(0) float32
[[-1.78178539e-07,  1.59132321e-06, -2.00916884e-06, -2.30666600e-07, -1.31331467e-06,  5.71873784e-07, -4.02916669e-07,  1.11762461e-07,  3.40592749e-07,  8.89963815e-07],
]

As soon as data passes through the network, the corresponding parameters are initialized on the device that the data passed through. That is, initialization happens on a per-device basis. Since we picked GPU devices for the computation, the network is initialized only on those GPUs and not on the CPU. In fact, the parameters do not even exist on the CPU. We can verify this by printing out the parameters and observing any errors that might arise.

net.getChildren().values().get(0).getParameters().get("weight").getArray().get(new NDIndex("0:1"));
ND: (1, 1, 3, 3) gpu(0) float32
[[[[ 0.0053, -0.0018, -0.0141],
   [-0.0094, -0.0146,  0.0094],
   [ 0.002 ,  0.0189,  0.0014],
  ],
 ],
]
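
As an additional, hedged check (this sketch is not part of the original notebook), we can also ask the ParameterStore for a copy of the same weight on an explicit device; since the store was created with copy=true, getValue() hands back an array living on the requested device. The snippet assumes at least one GPU is visible; on a CPU-only machine you would pass Device.cpu() instead.

// Fetch the first convolution's weight on GPU 0 through the parameter store
Parameter convWeight = net.getChildren().values().get(0).getParameters().get("weight");
NDArray weightOnGpu = parameterStore.getValue(convWeight, Device.gpu(0), false);
System.out.println(weightOnGpu.getDevice()); // expected to print gpu(0)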

12.4.3. Training

As before, the training code needs to perform several basic functions for efficient parallelism:

  • Network parameters need to be initialized across all devices.

  • While iterating over the dataset, minibatches are to be divided across all devices.

  • We compute the loss and its gradient in parallel across devices.

  • Gradients are aggregated and parameters are updated accordingly.

In the end we compute the accuracy (again in parallel) to report the final performance of the network. Aside from splitting and aggregating the data, the training routine is quite similar to the implementations in previous chapters.
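
Under the hood, the EasyTrain.fit call used in the train() function below performs exactly these steps for every minibatch. As a rough, illustrative sketch only (assuming a prepared trainIter dataset like the one built inside train()), the inner loop of a single epoch boils down to the following; EasyTrain.fit additionally handles validation and listener callbacks, so this is not a replacement for it.

for (Batch batch : trainer.iterateDataset(trainIter)) {
    // Split the batch across trainer.getDevices(), run the forward pass,
    // compute the loss, and backpropagate the gradients on every device.
    EasyTrain.trainBatch(trainer, batch);
    // Aggregate the gradients and apply the optimizer update to the parameters.
    trainer.step();
    batch.close();
}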

int numEpochs = Integer.getInteger("MAX_EPOCH", 10);

double[] testAccuracy;
double[] epochCount;

epochCount = new double[numEpochs];

for (int i = 0; i < epochCount.length; i++) {
    epochCount[i] = (i + 1);
}

Map<String, double[]> evaluatorMetrics = new HashMap<>();
double avgTrainTimePerEpoch = 0;
public void train(int numEpochs, Trainer trainer, int batchSize) throws IOException, TranslateException {

    FashionMnist trainIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TRAIN)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();
    FashionMnist testIter = FashionMnist.builder()
            .optUsage(Dataset.Usage.TEST)
            .setSampling(batchSize, true)
            .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
            .build();

    trainIter.prepare();
    testIter.prepare();

    Map<String, double[]> evaluatorMetrics = new HashMap<>();
    double avgTrainTime = 0;

    trainer.setMetrics(new Metrics());

    EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

    Metrics metrics = trainer.getMetrics();

    trainer.getEvaluators().stream()
            .forEach(evaluator -> {
                evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());
            });

    avgTrainTime = metrics.mean("epoch");
    testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");
    System.out.printf("test acc %.2f\n" , testAccuracy[numEpochs-1]);
    System.out.println(avgTrainTime / Math.pow(10, 9) + " sec/epoch \n");
}

12.4.4. Experiments

Let us see how this works in practice. As a warm-up we train the network on a single GPU.

Table data = null;
// We will check if we have at least 1 GPU available. If yes, we run the training on 1 GPU.
if (Engine.getInstance().getGpuCount() >= 1) {
    train(numEpochs, trainer, 256);

    data = Table.create("Data");
    data = data.addColumns(
            DoubleColumn.create("X", epochCount),
            DoubleColumn.create("testAccuracy", testAccuracy)
    );
}
Training:    100% |████████████████████████████████████████| Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.61
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 0.61
INFO Validate: Accuracy: 0.82, SoftmaxCrossEntropyLoss: 0.48
Training:    100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26
INFO Validate: Accuracy: 0.84, SoftmaxCrossEntropyLoss: 0.48
Training:    100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.25
Training:    100% |████████████████████████████████████████| Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.94, SoftmaxCrossEntropyLoss: 0.17
INFO Validate: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.30
Training:    100% |████████████████████████████████████████| Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.25
Training:    100% |████████████████████████████████████████| Accuracy: 0.96, SoftmaxCrossEntropyLoss: 0.11
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.96, SoftmaxCrossEntropyLoss: 0.11
INFO Validate: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.28
Training:    100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09
INFO Validate: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.33
Training:    100% |████████████████████████████████████████| Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.07
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.07
INFO Validate: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.38
Training:    100% |████████████████████████████████████████| Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.05
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.05
INFO Validate: Accuracy: 0.78, SoftmaxCrossEntropyLoss: 1.03
Training:    100% |████████████████████████████████████████| Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.06
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.98, SoftmaxCrossEntropyLoss: 0.06
INFO Validate: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.33
test acc 0.92
20.1823143983 sec/epoch
// The following code requires at least one GPU device
// render(LinePlot.create("", data, "X", "testAccuracy"), "text/html");
https://d2l-java-resources.s3.amazonaws.com/img/training-with-1-gpu.png

Fig. 12.4.1 Test accuracy when training on 1 GPU.

Table data = Table.create("Data");

// We will check if we have more than 1 GPU available. If yes, we run the training on 2 GPU.
if (Engine.getInstance().getGpuCount() > 1) {

    X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 28, 28));

    Model model = Model.newInstance("training-multiple-gpus-2");
    model.setBlock(net);

    loss = Loss.softmaxCrossEntropyLoss();

    Tracker lrt = Tracker.fixed(0.2f);
    Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

    DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
                .optOptimizer(sgd) // Optimizer (loss function)
                .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // setting the initializer
                .optDevices(Engine.getInstance().getDevices(2)) // setting the number of GPUs needed
                .addEvaluator(new Accuracy()) // Model Accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

    Trainer trainer = model.newTrainer(config);

    trainer.initialize(X.getShape());

    Map<String, double[]> evaluatorMetrics = new HashMap<>();
    double avgTrainTimePerEpoch = 0;

    train(numEpochs, trainer, 512);

    data = data.addColumns(
        DoubleColumn.create("X", epochCount),
        DoubleColumn.create("testAccuracy", testAccuracy)
    );
}
INFO Training on: 2 GPUs.
INFO Load MXNet Engine Version 1.9.0 in 0.019 ms.
Training:    100% |████████████████████████████████████████| Accuracy: 0.56, SoftmaxCrossEntropyLoss: 1.40
Validating:  100% |████████████████████████████████████████|
INFO Epoch 1 finished.
INFO Train: Accuracy: 0.57, SoftmaxCrossEntropyLoss: 1.38
INFO Validate: Accuracy: 0.52, SoftmaxCrossEntropyLoss: 1.33
Training:    100% |████████████████████████████████████████| Accuracy: 0.80, SoftmaxCrossEntropyLoss: 0.53
Validating:  100% |████████████████████████████████████████|
INFO Epoch 2 finished.
INFO Train: Accuracy: 0.80, SoftmaxCrossEntropyLoss: 0.53
INFO Validate: Accuracy: 0.72, SoftmaxCrossEntropyLoss: 0.83
Training:    100% |████████████████████████████████████████| Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.40
Validating:  100% |████████████████████████████████████████|
INFO Epoch 3 finished.
INFO Train: Accuracy: 0.85, SoftmaxCrossEntropyLoss: 0.40
INFO Validate: Accuracy: 0.72, SoftmaxCrossEntropyLoss: 0.82
Training:    100% |████████████████████████████████████████| Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34
Validating:  100% |████████████████████████████████████████|
INFO Epoch 4 finished.
INFO Train: Accuracy: 0.87, SoftmaxCrossEntropyLoss: 0.34
INFO Validate: Accuracy: 0.76, SoftmaxCrossEntropyLoss: 0.66
Training:    100% |████████████████████████████████████████| Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.31
Validating:  100% |████████████████████████████████████████|
INFO Epoch 5 finished.
INFO Train: Accuracy: 0.88, SoftmaxCrossEntropyLoss: 0.31
INFO Validate: Accuracy: 0.73, SoftmaxCrossEntropyLoss: 0.85
Training:    100% |████████████████████████████████████████| Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.28
Validating:  100% |████████████████████████████████████████|
INFO Epoch 6 finished.
INFO Train: Accuracy: 0.89, SoftmaxCrossEntropyLoss: 0.28
INFO Validate: Accuracy: 0.81, SoftmaxCrossEntropyLoss: 0.51
Training:    100% |████████████████████████████████████████| Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26
Validating:  100% |████████████████████████████████████████|
INFO Epoch 7 finished.
INFO Train: Accuracy: 0.90, SoftmaxCrossEntropyLoss: 0.26
INFO Validate: Accuracy: 0.70, SoftmaxCrossEntropyLoss: 0.83
Training:    100% |████████████████████████████████████████| Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
Validating:  100% |████████████████████████████████████████|
INFO Epoch 8 finished.
INFO Train: Accuracy: 0.91, SoftmaxCrossEntropyLoss: 0.24
INFO Validate: Accuracy: 0.75, SoftmaxCrossEntropyLoss: 0.73
Training:    100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
Validating:  100% |████████████████████████████████████████|
INFO Epoch 9 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.22
INFO Validate: Accuracy: 0.77, SoftmaxCrossEntropyLoss: 0.65
Training:    100% |████████████████████████████████████████| Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
Validating:  100% |████████████████████████████████████████|
INFO Epoch 10 finished.
INFO Train: Accuracy: 0.92, SoftmaxCrossEntropyLoss: 0.21
INFO Validate: Accuracy: 0.68, SoftmaxCrossEntropyLoss: 0.99
test acc 0.68
14.5093982196 sec/epoch
// The following code requires at least two GPU devices
// render(LinePlot.create("", data, "X", "testAccuracy"), "text/html");
https://d2l-java-resources.s3.amazonaws.com/img/training-with-2-gpu.png

Fig. 12.4.2 Test accuracy when training on 2 GPUs.

12.4.5. Summary

  • DJL provides primitives for initializing models across multiple devices by accepting a list of devices.

  • Networks are automatically evaluated on whichever device the data can be found on.

  • Networks need to be initialized on each device before attempting to access the parameters on that device; otherwise an error is encountered.

  • The optimization algorithm automatically aggregates over multiple GPUs.

12.4.6. Exercises

  1. This section uses ResNet-18. Try different numbers of epochs, batch sizes, and learning rates, and use more GPUs for the computation. What happens if you attempt this with \(8\) GPUs (e.g., on an AWS p2.16xlarge instance)?

  2. Sometimes, different devices provide different computing power; we could use the GPUs and the CPU at the same time. How should we divide the work, and why?