Run this notebook online:\ |Binder| or Colab: |Colab| .. |Binder| image:: https://mybinder.org/badge_logo.svg :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_linear-networks/image-classification-dataset.ipynb .. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_linear-networks/image-classification-dataset.ipynb .. _sec_fashion_mnist: 图像分类数据集 ============== 目前广泛使用的图像分类数据集之一是 MNIST 数据集 :cite:`LeCun.Bottou.Bengio.ea.1998`\ 。虽然它是很不错的基准数据集,但按今天的标准,即使是简单的模型也能达到95%以上的分类准确率,因此不适合区分强模型和弱模型。如今,MNIST更像是一个健全检查,而不是一个基准。 为了提高难度,我们将在接下来的章节中讨论在2017年发布的性质相似但相对复杂的Fashion-MNIST数据集 :cite:`Xiao.Rasul.Vollgraf.2017`\ 。 .. code:: java %load ../utils/djl-imports %load ../utils/StopWatch.java %load ../utils/ImageUtils.java .. code:: java import ai.djl.basicdataset.cv.classification.*; import ai.djl.training.dataset.Record; import java.awt.image.BufferedImage; import java.awt.Graphics2D; import java.awt.Color; 读取数据集 ---------- 就像\ ``MNIST``\ ,我们可以使用 DJL ``ai.djl.basicdataset`` 中的 ``FashionMnist`` 类来下载并读取到内存中。 .. code:: java int batchSize = 256; boolean randomShuffle = true; FashionMnist mnistTrain = FashionMnist.builder() .optUsage(Dataset.Usage.TRAIN) .setSampling(batchSize, randomShuffle) .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE)) .build(); FashionMnist mnistTest = FashionMnist.builder() .optUsage(Dataset.Usage.TEST) .setSampling(batchSize, randomShuffle) .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE)) .build(); mnistTrain.prepare(); mnistTest.prepare(); NDManager manager = NDManager.newBaseManager(); Fashion-MNIST 由 10 个类别的图像组成,每个类别由训练数据集中的 6000 张图像和测试数据集中的 1000 张图像组成。\ *测试数据集*\ (test dataset)不会用于训练,只用于评估模型性能。训练集和测试集分别包含 60000 和 10000 张图像。 .. code:: java System.out.println(mnistTrain.size()); System.out.println(mnistTest.size()); .. parsed-literal:: :class: output 60000 10000 Fashion-MNIST中包含的10个类别分别为t-shirt(T恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和ankle boot(短靴)。以下函数用于在数字标签索引及其文本名称之间进行转换。 .. code:: java // Saved in the FashionMnist class for later use public String[] getFashionMnistLabels(int[] labelIndices) { String[] textLabels = {"t-shirt", "trouser", "pullover", "dress", "coat", "sandal", "shirt", "sneaker", "bag", "ankle boot"}; String[] convertedLabels = new String[labelIndices.length]; for (int i = 0; i < labelIndices.length; i++) { convertedLabels[i] = textLabels[labelIndices[i]]; } return convertedLabels; } public String getFashionMnistLabel(int labelIndice) { String[] textLabels = {"t-shirt", "trouser", "pullover", "dress", "coat", "sandal", "shirt", "sneaker", "bag", "ankle boot"}; return textLabels[labelIndice]; } 我们现在可以创建一个函数来可视化这些样本。 下面的代码只是为了帮助直观地理解数据, 你不需要太关注可视化的细节。我们读取了许多数据点并将它们的 ``RGB`` 值从 0-255 转换为 0-1 之间。然后我们将颜色以灰度的形式,将其与标签一起显示出来。 .. code:: java // Saved in the FashionMnistUtils class for later use public static BufferedImage showImages( ArrayDataset dataset, int number, int width, int height, int scale, NDManager manager) { BufferedImage[] images = new BufferedImage[number]; String[] labels = new String[number]; for (int i = 0; i < number; i++) { Record record = dataset.get(manager, i); NDArray array = record.getData().get(0).squeeze(-1); int y = (int) record.getLabels().get(0).getFloat(); images[i] = toImage(array, width, height); labels[i] = getFashionMnistLabel(y); } int w = images[0].getWidth() * scale; int h = images[0].getHeight() * scale; return ImageUtils.showImages(images, labels, w, h); } private static BufferedImage toImage(NDArray array, int width, int height) { System.setProperty("apple.awt.UIElement", "true"); BufferedImage img = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY); Graphics2D g = (Graphics2D) img.getGraphics(); for (int i = 0; i < width; i++) { for (int j = 0; j < height; j++) { float c = array.getFloat(j, i) / 255; // scale down to between 0 and 1 g.setColor(new Color(c, c, c)); // set as a gray color g.fillRect(i, j, 1, 1); } } g.dispose(); return img; } 以下是训练数据集中前几个样本的图像及其相应的标签(文本形式)。 .. code:: java final int SCALE = 4; final int WIDTH = 28; final int HEIGHT = 28; showImages(mnistTrain, 6, WIDTH, HEIGHT, SCALE, manager) .. figure:: output_image-classification-dataset_89a587_13_0.png 读取小批量 ---------- 为了使我们在读取训练集和测试集时更容易,我们使用 ``getData(manager)``\ 。回顾一下,在每次迭代中,\ ``getData(manager)`` 每次都会读取一小批量数据,大小为 ``batchSize``\ 。我们可以用 ``getData()`` 和 ``getLabels()`` 来得到\ ``x``\ 和\ ``y``\ 。 让我们看一下读取训练数据所需的时间。 .. code:: java StopWatch stopWatch = new StopWatch(); stopWatch.start(); for (Batch batch : mnistTrain.getData(manager)) { NDArray x = batch.getData().head(); NDArray y = batch.getLabels().head(); } System.out.println(String.format("%.2f sec", stopWatch.stop())); .. parsed-literal:: :class: output 0.20 sec 我们现在已经准备好在下面的章节中使用Fashion-MNIST数据集。 小结 ---- - Fashion-MNIST是一个服装分类数据集,由10个类别的图像组成。我们将在后续章节中使用此数据集来评估各种分类算法。 - 我们将高度\ :math:`h`\ 像素,宽度\ :math:`w`\ 像素图像的形状记为\ :math:`h \times w`\ 或(\ :math:`h`, :math:`w`)。 - 数据迭代器是获得更高性能的关键组件。依靠实现良好的数据迭代器,利用高性能计算来避免减慢训练过程。 练习 ---- 1. 减少 ``batchSize``\ (如减少到 1)是否会影响读取性能? 2. 数据迭代器的性能非常重要。你认为当前的实现足够快吗?探索各种选择来改进它。 3. 查阅框架的在线API文档。还有哪些其他数据集可用?