Run this notebook online:\ |Binder| or Colab: |Colab| .. |Binder| image:: https://mybinder.org/badge_logo.svg :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_recurrent-modern/machine-translation-and-dataset.ipynb .. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_recurrent-modern/machine-translation-and-dataset.ipynb .. _sec_machine_translation: 机器翻译与数据集 ================ 语言模型是自然语言处理的关键, 而\ *机器翻译*\ 是语言模型最成功的基准测试。 因为机器翻译正是将输入序列转换成输出序列的 *序列转换模型*\ (sequence transduction)的核心问题。 序列转换模型在各类现代人工智能应用中发挥着至关重要的作用, 因此我们将其做为本章剩余部分和 :numref:`chap_attention`\ 的重点。 为此,本节将介绍机器翻译问题及其后文需要使用的数据集。 *机器翻译*\ (machine translation)指的是 将序列从一种语言自动翻译成另一种语言。 事实上,这个研究领域可以追溯到数字计算机发明后不久的20世纪40年代, 特别是在第二次世界大战中使用计算机破解语言编码。 几十年来,在使用神经网络进行端到端学习的兴起之前, 统计学方法在这一领域一直占据主导地位 :cite:`Brown.Cocke.Della-Pietra.ea.1988,Brown.Cocke.Della-Pietra.ea.1990`\ 。 因为\ *统计机器翻译*\ (statisticalmachine translation)涉及了 翻译模型和语言模型等组成部分的统计分析, 因此基于神经网络的方法通常被称为 *神经机器翻译*\ (neuralmachine translation), 用于将两种翻译模型区分开来。 本书的关注点是神经网络机器翻译方法,强调的是端到端的学习。 与 :numref:`sec_language_model`\ 中的语料库 是单一语言的语言模型问题存在不同, 机器翻译的数据集是由源语言和目标语言的文本序列对组成的。 因此,我们需要一种完全不同的方法来预处理机器翻译数据集, 而不是复用语言模型的预处理程序。 下面,我们看一下如何将预处理后的数据加载到小批量中用于训练。 .. code:: java %load ../utils/djl-imports %load ../utils/plot-utils %load ../utils/Functions.java %load ../utils/timemachine/Vocab.java %load ../utils/timemachine/RNNModel.java %load ../utils/timemachine/RNNModelScratch.java %load ../utils/timemachine/TimeMachine.java %load ../utils/timemachine/TimeMachineDataset.java .. code:: java import java.nio.charset.*; import java.util.zip.*; import java.util.stream.*; .. code:: java NDManager manager = NDManager.newBaseManager(); 下载和预处理数据集 ------------------ 首先,下载一个由\ `Tatoeba项目的双语句子对 `__ 组成的“英-法”数据集,数据集中的每一行都是制表符分隔的文本序列对, 序列对由英文文本序列和翻译后的法语文本序列组成。 请注意,每个文本序列可以是一个句子, 也可以是包含多个句子的一个段落。 在这个将英语翻译成法语的机器翻译问题中, 英语是\ *源语言*\ (source language), 法语是\ *目标语言*\ (target language)。 .. code:: java public static String readDataNMT() throws IOException { DownloadUtils.download( "http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip", "fra-eng.zip"); ZipFile zipFile = new ZipFile(new File("fra-eng.zip")); Enumeration entries = zipFile.entries(); while (entries.hasMoreElements()) { ZipEntry entry = entries.nextElement(); if (entry.getName().contains("fra.txt")) { InputStream stream = zipFile.getInputStream(entry); return new String(stream.readAllBytes(), StandardCharsets.UTF_8); } } return null; } String rawText = readDataNMT(); System.out.println(rawText.substring(0, 75)); .. parsed-literal:: :class: output Go. Va ! Hi. Salut ! Run! Cours ! Run! Courez ! Who? Qui ? Wow! Ça alors ! 下载数据集后,原始文本数据需要经过几个预处理步骤。 例如,我们用空格代替\ *不间断空格*\ (non-breaking space), 使用小写字母替换大写字母,并在单词和标点符号之间插入空格。 .. code:: java public static String preprocessNMT(String text) { // 使用空格替换不间断空格 // 使用小写字母替换大写字母 text = text.replace('\u202f', ' ').replaceAll("\\xa0", " ").toLowerCase(); // 在单词和标点符号之间插入空格 StringBuilder out = new StringBuilder(); Character currChar; for (int i = 0; i < text.length(); i++) { currChar = text.charAt(i); if (i > 0 && noSpace(currChar, text.charAt(i - 1))) { out.append(' '); } out.append(currChar); } return out.toString(); } public static boolean noSpace(Character currChar, Character prevChar) { /* Preprocess the English-French dataset. */ return new HashSet<>(Arrays.asList(',', '.', '!', '?')).contains(currChar) && prevChar != ' '; } String text = preprocessNMT(rawText); System.out.println(text.substring(0, 80)); .. parsed-literal:: :class: output go . va ! hi . salut ! run ! cours ! run ! courez ! who ? qui ? wow ! ça alors ! 词元化 ------ 与 :numref:`sec_language_model`\ 中的字符级词元化不同, 在机器翻译中,我们更喜欢单词级词元化 (最先进的模型可能使用更高级的词元化技术)。 下面的\ ``tokenize_nmt``\ 函数对前\ ``num_examples``\ 个文本序列对进行词元, 其中每个词元要么是一个词,要么是一个标点符号。 此函数返回两个词元列表:\ ``source``\ 和\ ``target``\ : ``source[i]``\ 是源语言(这里是英语)第\ :math:`i`\ 个文本序列的词元列表, ``target[i]``\ 是目标语言(这里是法语)第\ :math:`i`\ 个文本序列的词元列表。 .. code:: java public static Pair, ArrayList> tokenizeNMT( String text, Integer numExamples) { ArrayList source = new ArrayList<>(); ArrayList target = new ArrayList<>(); int i = 0; for (String line : text.split("\n")) { if (numExamples != null && i > numExamples) { break; } String[] parts = line.split("\t"); if (parts.length == 2) { source.add(parts[0].split(" ")); target.add(parts[1].split(" ")); } i += 1; } return new Pair<>(source, target); } Pair, ArrayList> pair = tokenizeNMT(text.toString(), null); ArrayList source = pair.getKey(); ArrayList target = pair.getValue(); for (String[] subArr : source.subList(0, 6)) { System.out.println(Arrays.toString(subArr)); } for (String[] subArr : target.subList(0, 6)) { System.out.println(Arrays.toString(subArr)); } .. parsed-literal:: :class: output [go, .] [hi, .] [run, !] [run, !] [who, ?] [wow, !] [va, !] [salut, !] [cours, !] [courez, !] [qui, ?] [ça, alors, !] 让我们绘制每个文本序列所包含的词元数量的直方图。 在这个简单的“英-法”数据集中,大多数文本序列的词元数量少于\ :math:`20`\ 个。 .. code:: java double[] y1 = new double[source.size()]; for (int i = 0; i < source.size(); i++) y1[i] = source.get(i).length; double[] y2 = new double[target.size()]; for (int i = 0; i < target.size(); i++) y2[i] = target.get(i).length; HistogramTrace trace1 = HistogramTrace.builder(y1).opacity(.75).name("source").nBinsX(20).build(); HistogramTrace trace2 = HistogramTrace.builder(y2).opacity(.75).name("target").nBinsX(20).build(); Layout layout = Layout.builder().barMode(Layout.BarMode.GROUP).build(); new Figure(layout, trace1, trace2); .. raw:: html
词表 ---- 由于机器翻译数据集由语言对组成, 因此我们可以分别为源语言和目标语言构建两个词表。 使用单词级词元化时,词表大小将明显大于使用字符级词元化时的词表大小。 为了缓解这一问题,这里我们将出现次数少于2次的低频率词元 视为相同的未知(“”)词元。 除此之外,我们还指定了额外的特定词元, 例如在小批量时用于将序列填充到相同长度的填充词元(“”), 以及序列的开始词元(“”)和结束词元(“”)。 这些特殊词元在自然语言处理任务中比较常用。 .. code:: java Vocab srcVocab = new Vocab( source.stream().toArray(String[][]::new), 2, new String[] {"", "", ""}); System.out.println(srcVocab.length()); .. parsed-literal:: :class: output 10012 .. _subsec_mt_data_loading: 加载数据集 ---------- 回想一下,语言模型中的序列样本都有一个固定的长度, 无论这个样本是一个句子的一部分还是跨越了多个句子的一个片断。 这个固定长度是由 :numref:`sec_language_model`\ 中的 ``numSteps``\ (时间步数或词元数量)参数指定的。 在机器翻译中,每个样本都是由源和目标组成的文本序列对, 其中的每个文本序列可能具有不同的长度。 为了提高计算效率,我们仍然可以通过\ *截断*\ (truncation)和 *填充*\ (padding)方式实现一次只处理一个小批量的文本序列。 假设同一个小批量中的每个序列都应该具有相同的长度\ ``numSteps``\ , 那么如果文本序列的词元数目少于\ ``numSteps``\ 时, 我们将继续在其末尾添加特定的“”词元, 直到其长度达到\ ``numSteps``\ ; 反之,我们将截断文本序列时,只取其前\ ``numSteps`` 个词元, 并且丢弃剩余的词元。这样,每个文本序列将具有相同的长度, 以便以相同形状的小批量进行加载。 如前所述,下面的\ ``truncatePad``\ 函数将截断或填充文本序列。 .. code:: java public static int[] truncatePad(Integer[] integerLine, int numSteps, int paddingToken) { // 截断或填充文本序列 int[] line = Arrays.stream(integerLine).mapToInt(i -> i).toArray(); if (line.length > numSteps) { return Arrays.copyOfRange(line, 0, numSteps); } int[] paddingTokenArr = new int[numSteps - line.length]; // Pad Arrays.fill(paddingTokenArr, paddingToken); return IntStream.concat(Arrays.stream(line), Arrays.stream(paddingTokenArr)).toArray(); } int[] result = truncatePad(srcVocab.getIdxs(source.get(0)), 10, srcVocab.getIdx("")); System.out.println(Arrays.toString(result)); .. parsed-literal:: :class: output [47, 4, 1, 1, 1, 1, 1, 1, 1, 1] 现在我们定义一个函数,可以将文本序列 转换成小批量数据集用于训练。 我们将特定的“”词元添加到所有序列的末尾, 用于表示序列的结束。 当模型通过一个词元接一个词元地生成序列进行预测时, 生成的“”词元说明完成了序列输出工作。 此外,我们还记录了每个文本序列的长度, 统计长度时排除了填充词元, 在稍后将要介绍的一些模型会需要这个长度信息。 .. code:: java public static Pair buildArrayNMT( List lines, Vocab vocab, int numSteps) { // 将机器翻译的文本序列转换成小批量 List linesIntArr = new ArrayList<>(); for (String[] strings : lines) { linesIntArr.add(vocab.getIdxs(strings)); } for (int i = 0; i < linesIntArr.size(); i++) { List temp = new ArrayList<>(Arrays.asList(linesIntArr.get(i))); temp.add(vocab.getIdx("")); linesIntArr.set(i, temp.toArray(new Integer[0])); } NDManager manager = NDManager.newBaseManager(); NDArray arr = manager.create(new Shape(linesIntArr.size(), numSteps), DataType.INT32); int row = 0; for (Integer[] line : linesIntArr) { NDArray rowArr = manager.create(truncatePad(line, numSteps, vocab.getIdx(""))); arr.set(new NDIndex("{}:", row), rowArr); row += 1; } NDArray validLen = arr.neq(vocab.getIdx("")).sum(new int[] {1}); return new Pair<>(arr, validLen); } 训练模型 -------- 最后,我们定义\ ``loadDataNMT``\ 函数来返回数据迭代器, 以及源语言和目标语言的两种词表。 .. code:: java public static Pair> loadDataNMT( int batchSize, int numSteps, int numExamples) throws IOException { // 返回翻译数据集的迭代器和词表 String text = preprocessNMT(readDataNMT()); Pair, ArrayList> pair = tokenizeNMT(text, numExamples); ArrayList source = pair.getKey(); ArrayList target = pair.getValue(); Vocab srcVocab = new Vocab( source.toArray(String[][]::new), 2, new String[] {"", "", ""}); Vocab tgtVocab = new Vocab( target.toArray(String[][]::new), 2, new String[] {"", "", ""}); Pair pairArr = buildArrayNMT(source, srcVocab, numSteps); NDArray srcArr = pairArr.getKey(); NDArray srcValidLen = pairArr.getValue(); pairArr = buildArrayNMT(target, tgtVocab, numSteps); NDArray tgtArr = pairArr.getKey(); NDArray tgtValidLen = pairArr.getValue(); ArrayDataset dataset = new ArrayDataset.Builder() .setData(srcArr, srcValidLen) .optLabels(tgtArr, tgtValidLen) .setSampling(batchSize, true) .build(); return new Pair<>(dataset, new Pair<>(srcVocab, tgtVocab)); } 下面我们读出“英语-法语”数据集中的第一个小批量数据。 .. code:: java Pair> output = loadDataNMT(2, 8, 600); ArrayDataset dataset = output.getKey(); srcVocab = output.getValue().getKey(); Vocab tgtVocab = output.getValue().getValue(); Batch batch = dataset.getData(manager).iterator().next(); NDArray X = batch.getData().get(0); NDArray xValidLen = batch.getData().get(1); NDArray Y = batch.getLabels().get(0); NDArray yValidLen = batch.getLabels().get(1); System.out.println(X); System.out.println(xValidLen); System.out.println(Y); System.out.println(yValidLen); .. parsed-literal:: :class: output ND: (2, 8) gpu(0) int32 [[163, 34, 5, 3, 1, 1, 1, 1], [ 7, 64, 4, 3, 1, 1, 1, 1], ] ND: (2) gpu(0) int64 [ 4, 4] ND: (2, 8) gpu(0) int32 [[ 0, 5, 3, 1, 1, 1, 1, 1], [ 6, 7, 135, 4, 3, 1, 1, 1], ] ND: (2) gpu(0) int64 [ 3, 5] 小结 ---- - 机器翻译指的是将文本序列从一种语言自动翻译成另一种语言。 - 使用单词级词元化时的词表大小,将明显大于使用字符级词元化时的词表大小。为了缓解这一问题,我们可以将低频词元视为相同的未知词元。 - 通过截断和填充文本序列,可以保证所有的文本序列都具有相同的长度,以便以小批量的方式加载。 练习 ---- 1. 在\ ``load_data_nmt``\ 函数中尝试不同的\ ``num_examples``\ 参数值。这对源语言和目标语言的词表大小有何影响? 2. 某些语言(例如中文和日语)的文本没有单词边界指示符(例如空格)。对于这种情况,单词级词元化仍然是个好主意吗?为什么?