Run this notebook online:Binder or Colab: Colab

3.5. 图像分类数据集

目前广泛使用的图像分类数据集之一是 MNIST 数据集 [LeCun et al., 1998]。虽然它是很不错的基准数据集,但按今天的标准,即使是简单的模型也能达到95%以上的分类准确率,因此不适合区分强模型和弱模型。如今,MNIST更像是一个健全检查,而不是一个基准。 为了提高难度,我们将在接下来的章节中讨论在2017年发布的性质相似但相对复杂的Fashion-MNIST数据集 [Xiao et al., 2017]

%load ../utils/djl-imports
%load ../utils/StopWatch.java
import ai.djl.basicdataset.cv.classification.*;
import java.awt.image.BufferedImage;
import java.awt.Graphics2D;
import java.awt.Graphics;
import java.awt.Color;
import java.awt.Dimension;
import java.awt.FlowLayout;
import java.awt.Component;
import javax.swing.*;

3.5.1. 读取数据集

就像MNIST,我们可以使用 DJL ai.djl.basicdataset 中的 FashionMnist 类来下载并读取到内存中。

int batchSize = 256;
boolean randomShuffle = true;

FashionMnist mnistTrain = FashionMnist.builder()
        .optUsage(Dataset.Usage.TRAIN)
        .setSampling(batchSize, randomShuffle)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

FashionMnist mnistTest = FashionMnist.builder()
        .optUsage(Dataset.Usage.TEST)
        .setSampling(batchSize, randomShuffle)
        .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
        .build();

mnistTrain.prepare();
mnistTest.prepare();

NDManager manager = NDManager.newBaseManager();

Fashion-MNIST 由 10 个类别的图像组成,每个类别由训练数据集中的 6000 张图像和测试数据集中的 1000 张图像组成。测试数据集(test dataset)不会用于训练,只用于评估模型性能。训练集和测试集分别包含 60000 和 10000 张图像。

System.out.println(mnistTrain.size());
System.out.println(mnistTest.size());
60000
10000

Fashion-MNIST中包含的10个类别分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数用于在数字标签索引及其文本名称之间进行转换。

// Saved in the FashionMnist class for later use
public String[] getFashionMnistLabels(int[] labelIndices) {
    String[] textLabels = {"t-shirt", "trouser", "pullover", "dress", "coat",
                   "sandal", "shirt", "sneaker", "bag", "ankle boot"};
    String[] convertedLabels = new String[labelIndices.length];
    for (int i = 0; i < labelIndices.length; i++) {
        convertedLabels[i] = textLabels[labelIndices[i]];
    }
    return convertedLabels;
}

public String getFashionMnistLabel(int labelIndice) {
    String[] textLabels = {"t-shirt", "trouser", "pullover", "dress", "coat",
                   "sandal", "shirt", "sneaker", "bag", "ankle boot"};
    return textLabels[labelIndice];
}

我们现在可以创建一个函数来可视化这些样本。

下面的代码只是为了帮助直观地理解数据, 你不需要太关注可视化的细节。我们读取了许多数据点并将它们的 RGB 值从 0-255 转换为 0-1 之间。然后我们将颜色以灰度的形式,将其与标签一起显示出来。

public class ImagePanel extends JPanel {
    int SCALE;
    BufferedImage img;

    public ImagePanel() {
        this.SCALE = 1;
    }
    public ImagePanel(int scale, BufferedImage img) {
        this.SCALE = scale;
        this.img = img;
    }
    @Override
    protected void paintComponent(Graphics g) {
        Graphics2D g2d = (Graphics2D)g;
        g2d.scale(SCALE, SCALE);
        g2d.drawImage(this.img, 0, 0, this);
    }
}

public class Container extends JPanel {
    public Container(String label) {
        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
        JLabel l = new JLabel(label, JLabel.CENTER);
        l.setAlignmentX(Component.CENTER_ALIGNMENT);
        add(l);
    }
    public Container(String trueLabel, String predLabel) {
        setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
        JLabel l = new JLabel(trueLabel, JLabel.CENTER);
        l.setAlignmentX(Component.CENTER_ALIGNMENT);
        add(l);
        JLabel l2 = new JLabel(predLabel, JLabel.CENTER);
        l2.setAlignmentX(Component.CENTER_ALIGNMENT);
        add(l2);
    }
}
// Saved in the FashionMnistUtils class for later use
public void showImages(ArrayDataset dataset,
                       int number, int WIDTH, int HEIGHT, int SCALE,
                       NDManager manager)
    throws IOException, TranslateException {
    // Plot a list of images
    JFrame frame = new JFrame("Fashion Mnist");
    for (int record = 0; record < number; record++) {
        NDArray X = dataset.get(manager, record).getData().get(0).squeeze(-1);
        int y = (int)dataset.get(manager, record).getLabels().get(0).getFloat();
        BufferedImage img = new BufferedImage(WIDTH, HEIGHT, BufferedImage.TYPE_BYTE_GRAY);
        Graphics2D g = (Graphics2D) img.getGraphics();
        for(int i = 0; i < WIDTH; i++) {
            for(int j = 0; j < HEIGHT; j++) {
                float c = X.getFloat(j, i) / 255;  // scale down to between 0 and 1
                g.setColor(new Color(c, c, c)); // set as a gray color
                g.fillRect(i, j, 1, 1);
            }
        }
        JPanel panel = new ImagePanel(SCALE, img);
        panel.setPreferredSize(new Dimension(WIDTH * SCALE, HEIGHT * SCALE));
        JPanel container = new Container(getFashionMnistLabel(y));
        container.add(panel);
        frame.getContentPane().add(container);
    }
    frame.getContentPane().setLayout(new FlowLayout());
    frame.pack();
    frame.setVisible(true);
}

以下是训练数据集中前几个样本的图像及其相应的标签(文本形式)。

final int SCALE = 4;
final int WIDTH = 28;
final int HEIGHT = 28;

/* Uncomment the following line and run to display images.
   It will open in another window. */
// showImages(mnistTrain, 18, WIDTH, HEIGHT, SCALE, manager);
https://d2l-java-resources.s3.amazonaws.com/img/fashion_mnist_labels.png

Fig. 3.5.1 Fashion Mnist labels.

3.5.2. 读取小批量

为了使我们在读取训练集和测试集时更容易,我们使用 getData(manager)。回顾一下,在每次迭代中,getData(manager) 每次都会读取一小批量数据,大小为 batchSize。我们可以用 getData()getLabels() 来得到xy

让我们看一下读取训练数据所需的时间。

StopWatch stopWatch = new StopWatch();
stopWatch.start();
for (Batch batch : mnistTrain.getData(manager)) {
    NDArray x = batch.getData().head();
    NDArray y = batch.getLabels().head();
}
System.out.println(String.format("%.2f sec", stopWatch.stop()));
20.32 sec

我们现在已经准备好在下面的章节中使用Fashion-MNIST数据集。

3.5.3. 小结

  • Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成。我们将在后续章节中使用此数据集来评估各种分类算法。

  • 我们将高度\(h\)像素,宽度\(w\)像素图像的形状记为\(h \times w\)或(\(h\), \(w\))。

  • 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。

3.5.4. 练习

  1. 减少 batchSize(如减少到 1)是否会影响读取性能?

  2. 数据迭代器的性能非常重要。你认为当前的实现足够快吗?探索各种选择来改进它。

  3. 查阅框架的在线API文档。还有哪些其他数据集可用?