Run this notebook online:\ |Binder| or Colab: |Colab| .. |Binder| image:: https://mybinder.org/badge_logo.svg :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_recurrent-modern/machine-translation-and-dataset.ipynb .. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_recurrent-modern/machine-translation-and-dataset.ipynb .. _sec_machine_translation: 机器翻译与数据集 ================ 语言模型是自然语言处理的关键, 而\ *机器翻译*\ 是语言模型最成功的基准测试。 因为机器翻译正是将输入序列转换成输出序列的 *序列转换模型*\ (sequence transduction)的核心问题。 序列转换模型在各类现代人工智能应用中发挥着至关重要的作用, 因此我们将其做为本章剩余部分和 :numref:`chap_attention`\ 的重点。 为此,本节将介绍机器翻译问题及其后文需要使用的数据集。 *机器翻译*\ (machine translation)指的是 将序列从一种语言自动翻译成另一种语言。 事实上,这个研究领域可以追溯到数字计算机发明后不久的20世纪40年代, 特别是在第二次世界大战中使用计算机破解语言编码。 几十年来,在使用神经网络进行端到端学习的兴起之前, 统计学方法在这一领域一直占据主导地位 :cite:`Brown.Cocke.Della-Pietra.ea.1988,Brown.Cocke.Della-Pietra.ea.1990`\ 。 因为\ *统计机器翻译*\ (statisticalmachine translation)涉及了 翻译模型和语言模型等组成部分的统计分析, 因此基于神经网络的方法通常被称为 *神经机器翻译*\ (neuralmachine translation), 用于将两种翻译模型区分开来。 本书的关注点是神经网络机器翻译方法,强调的是端到端的学习。 与 :numref:`sec_language_model`\ 中的语料库 是单一语言的语言模型问题存在不同, 机器翻译的数据集是由源语言和目标语言的文本序列对组成的。 因此,我们需要一种完全不同的方法来预处理机器翻译数据集, 而不是复用语言模型的预处理程序。 下面,我们看一下如何将预处理后的数据加载到小批量中用于训练。 .. code:: java %load ../utils/djl-imports %load ../utils/plot-utils %load ../utils/Functions.java %load ../utils/timemachine/Vocab.java %load ../utils/timemachine/RNNModel.java %load ../utils/timemachine/RNNModelScratch.java %load ../utils/timemachine/TimeMachine.java %load ../utils/timemachine/TimeMachineDataset.java .. code:: java import java.nio.charset.*; import java.util.zip.*; import java.util.stream.*; .. code:: java NDManager manager = NDManager.newBaseManager(); 下载和预处理数据集 ------------------ 首先,下载一个由\ `Tatoeba项目的双语句子对 `__ 组成的“英-法”数据集,数据集中的每一行都是制表符分隔的文本序列对, 序列对由英文文本序列和翻译后的法语文本序列组成。 请注意,每个文本序列可以是一个句子, 也可以是包含多个句子的一个段落。 在这个将英语翻译成法语的机器翻译问题中, 英语是\ *源语言*\ (source language), 法语是\ *目标语言*\ (target language)。 .. code:: java public static String readDataNMT() throws IOException { DownloadUtils.download( "http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip", "fra-eng.zip"); ZipFile zipFile = new ZipFile(new File("fra-eng.zip")); Enumeration entries = zipFile.entries(); while (entries.hasMoreElements()) { ZipEntry entry = entries.nextElement(); if (entry.getName().contains("fra.txt")) { InputStream stream = zipFile.getInputStream(entry); return new String(stream.readAllBytes(), StandardCharsets.UTF_8); } } return null; } String rawText = readDataNMT(); System.out.println(rawText.substring(0, 75)); .. parsed-literal:: :class: output Go. Va ! Hi. Salut ! Run! Cours ! Run! Courez ! Who? Qui ? Wow! Ça alors ! 下载数据集后,原始文本数据需要经过几个预处理步骤。 例如,我们用空格代替\ *不间断空格*\ (non-breaking space), 使用小写字母替换大写字母,并在单词和标点符号之间插入空格。 .. code:: java public static String preprocessNMT(String text) { // 使用空格替换不间断空格 // 使用小写字母替换大写字母 text = text.replace('\u202f', ' ').replaceAll("\\xa0", " ").toLowerCase(); // 在单词和标点符号之间插入空格 StringBuilder out = new StringBuilder(); Character currChar; for (int i = 0; i < text.length(); i++) { currChar = text.charAt(i); if (i > 0 && noSpace(currChar, text.charAt(i - 1))) { out.append(' '); } out.append(currChar); } return out.toString(); } public static boolean noSpace(Character currChar, Character prevChar) { /* Preprocess the English-French dataset. */ return new HashSet<>(Arrays.asList(',', '.', '!', '?')).contains(currChar) && prevChar != ' '; } String text = preprocessNMT(rawText); System.out.println(text.substring(0, 80)); .. parsed-literal:: :class: output go . va ! hi . salut ! run ! cours ! run ! courez ! who ? qui ? wow ! ça alors ! 词元化 ------ 与 :numref:`sec_language_model`\ 中的字符级词元化不同, 在机器翻译中,我们更喜欢单词级词元化 (最先进的模型可能使用更高级的词元化技术)。 下面的\ ``tokenize_nmt``\ 函数对前\ ``num_examples``\ 个文本序列对进行词元, 其中每个词元要么是一个词,要么是一个标点符号。 此函数返回两个词元列表:\ ``source``\ 和\ ``target``\ : ``source[i]``\ 是源语言(这里是英语)第\ :math:`i`\ 个文本序列的词元列表, ``target[i]``\ 是目标语言(这里是法语)第\ :math:`i`\ 个文本序列的词元列表。 .. code:: java public static Pair, ArrayList> tokenizeNMT( String text, Integer numExamples) { ArrayList source = new ArrayList<>(); ArrayList target = new ArrayList<>(); int i = 0; for (String line : text.split("\n")) { if (numExamples != null && i > numExamples) { break; } String[] parts = line.split("\t"); if (parts.length == 2) { source.add(parts[0].split(" ")); target.add(parts[1].split(" ")); } i += 1; } return new Pair<>(source, target); } Pair, ArrayList> pair = tokenizeNMT(text.toString(), null); ArrayList source = pair.getKey(); ArrayList target = pair.getValue(); for (String[] subArr : source.subList(0, 6)) { System.out.println(Arrays.toString(subArr)); } for (String[] subArr : target.subList(0, 6)) { System.out.println(Arrays.toString(subArr)); } .. parsed-literal:: :class: output [go, .] [hi, .] [run, !] [run, !] [who, ?] [wow, !] [va, !] [salut, !] [cours, !] [courez, !] [qui, ?] [ça, alors, !] 让我们绘制每个文本序列所包含的词元数量的直方图。 在这个简单的“英-法”数据集中,大多数文本序列的词元数量少于\ :math:`20`\ 个。 .. code:: java double[] y1 = new double[source.size()]; for (int i = 0; i < source.size(); i++) y1[i] = source.get(i).length; double[] y2 = new double[target.size()]; for (int i = 0; i < target.size(); i++) y2[i] = target.get(i).length; HistogramTrace trace1 = HistogramTrace.builder(y1).opacity(.75).name("source").nBinsX(20).build(); HistogramTrace trace2 = HistogramTrace.builder(y2).opacity(.75).name("target").nBinsX(20).build(); Layout layout = Layout.builder().barMode(Layout.BarMode.GROUP).build(); new Figure(layout, trace1, trace2); .. raw:: html