Run this notebook online:\ |Binder| or Colab: |Colab|
.. |Binder| image:: https://mybinder.org/badge_logo.svg
:target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_recurrent-modern/machine-translation-and-dataset.ipynb
.. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg
:target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_recurrent-modern/machine-translation-and-dataset.ipynb
.. _sec_machine_translation:
机器翻译与数据集
================
语言模型是自然语言处理的关键,
而\ *机器翻译*\ 是语言模型最成功的基准测试。
因为机器翻译正是将输入序列转换成输出序列的 *序列转换模型*\ (sequence
transduction)的核心问题。
序列转换模型在各类现代人工智能应用中发挥着至关重要的作用,
因此我们将其做为本章剩余部分和 :numref:`chap_attention`\ 的重点。
为此,本节将介绍机器翻译问题及其后文需要使用的数据集。
*机器翻译*\ (machine translation)指的是
将序列从一种语言自动翻译成另一种语言。
事实上,这个研究领域可以追溯到数字计算机发明后不久的20世纪40年代,
特别是在第二次世界大战中使用计算机破解语言编码。
几十年来,在使用神经网络进行端到端学习的兴起之前,
统计学方法在这一领域一直占据主导地位
:cite:`Brown.Cocke.Della-Pietra.ea.1988,Brown.Cocke.Della-Pietra.ea.1990`\ 。
因为\ *统计机器翻译*\ (statisticalmachine translation)涉及了
翻译模型和语言模型等组成部分的统计分析,
因此基于神经网络的方法通常被称为 *神经机器翻译*\ (neuralmachine
translation), 用于将两种翻译模型区分开来。
本书的关注点是神经网络机器翻译方法,强调的是端到端的学习。 与
:numref:`sec_language_model`\ 中的语料库
是单一语言的语言模型问题存在不同,
机器翻译的数据集是由源语言和目标语言的文本序列对组成的。
因此,我们需要一种完全不同的方法来预处理机器翻译数据集,
而不是复用语言模型的预处理程序。
下面,我们看一下如何将预处理后的数据加载到小批量中用于训练。
.. code:: java
%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/timemachine/Vocab.java
%load ../utils/timemachine/RNNModel.java
%load ../utils/timemachine/RNNModelScratch.java
%load ../utils/timemachine/TimeMachine.java
%load ../utils/timemachine/TimeMachineDataset.java
.. code:: java
import java.nio.charset.*;
import java.util.zip.*;
import java.util.stream.*;
.. code:: java
NDManager manager = NDManager.newBaseManager();
下载和预处理数据集
------------------
首先,下载一个由\ `Tatoeba项目的双语句子对 `__
组成的“英-法”数据集,数据集中的每一行都是制表符分隔的文本序列对,
序列对由英文文本序列和翻译后的法语文本序列组成。
请注意,每个文本序列可以是一个句子, 也可以是包含多个句子的一个段落。
在这个将英语翻译成法语的机器翻译问题中, 英语是\ *源语言*\ (source
language), 法语是\ *目标语言*\ (target language)。
.. code:: java
public static String readDataNMT() throws IOException {
DownloadUtils.download(
"http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip", "fra-eng.zip");
ZipFile zipFile = new ZipFile(new File("fra-eng.zip"));
Enumeration extends ZipEntry> entries = zipFile.entries();
while (entries.hasMoreElements()) {
ZipEntry entry = entries.nextElement();
if (entry.getName().contains("fra.txt")) {
InputStream stream = zipFile.getInputStream(entry);
return new String(stream.readAllBytes(), StandardCharsets.UTF_8);
}
}
return null;
}
String rawText = readDataNMT();
System.out.println(rawText.substring(0, 75));
.. parsed-literal::
:class: output
Go. Va !
Hi. Salut !
Run! Cours !
Run! Courez !
Who? Qui ?
Wow! Ça alors !
下载数据集后,原始文本数据需要经过几个预处理步骤。
例如,我们用空格代替\ *不间断空格*\ (non-breaking space),
使用小写字母替换大写字母,并在单词和标点符号之间插入空格。
.. code:: java
public static String preprocessNMT(String text) {
// 使用空格替换不间断空格
// 使用小写字母替换大写字母
text = text.replace('\u202f', ' ').replaceAll("\\xa0", " ").toLowerCase();
// 在单词和标点符号之间插入空格
StringBuilder out = new StringBuilder();
Character currChar;
for (int i = 0; i < text.length(); i++) {
currChar = text.charAt(i);
if (i > 0 && noSpace(currChar, text.charAt(i - 1))) {
out.append(' ');
}
out.append(currChar);
}
return out.toString();
}
public static boolean noSpace(Character currChar, Character prevChar) {
/* Preprocess the English-French dataset. */
return new HashSet<>(Arrays.asList(',', '.', '!', '?')).contains(currChar)
&& prevChar != ' ';
}
String text = preprocessNMT(rawText);
System.out.println(text.substring(0, 80));
.. parsed-literal::
:class: output
go . va !
hi . salut !
run ! cours !
run ! courez !
who ? qui ?
wow ! ça alors !
词元化
------
与 :numref:`sec_language_model`\ 中的字符级词元化不同,
在机器翻译中,我们更喜欢单词级词元化
(最先进的模型可能使用更高级的词元化技术)。
下面的\ ``tokenize_nmt``\ 函数对前\ ``num_examples``\ 个文本序列对进行词元,
其中每个词元要么是一个词,要么是一个标点符号。
此函数返回两个词元列表:\ ``source``\ 和\ ``target``\ :
``source[i]``\ 是源语言(这里是英语)第\ :math:`i`\ 个文本序列的词元列表,
``target[i]``\ 是目标语言(这里是法语)第\ :math:`i`\ 个文本序列的词元列表。
.. code:: java
public static Pair, ArrayList> tokenizeNMT(
String text, Integer numExamples) {
ArrayList source = new ArrayList<>();
ArrayList target = new ArrayList<>();
int i = 0;
for (String line : text.split("\n")) {
if (numExamples != null && i > numExamples) {
break;
}
String[] parts = line.split("\t");
if (parts.length == 2) {
source.add(parts[0].split(" "));
target.add(parts[1].split(" "));
}
i += 1;
}
return new Pair<>(source, target);
}
Pair, ArrayList> pair = tokenizeNMT(text.toString(), null);
ArrayList source = pair.getKey();
ArrayList target = pair.getValue();
for (String[] subArr : source.subList(0, 6)) {
System.out.println(Arrays.toString(subArr));
}
for (String[] subArr : target.subList(0, 6)) {
System.out.println(Arrays.toString(subArr));
}
.. parsed-literal::
:class: output
[go, .]
[hi, .]
[run, !]
[run, !]
[who, ?]
[wow, !]
[va, !]
[salut, !]
[cours, !]
[courez, !]
[qui, ?]
[ça, alors, !]
让我们绘制每个文本序列所包含的词元数量的直方图。
在这个简单的“英-法”数据集中,大多数文本序列的词元数量少于\ :math:`20`\ 个。
.. code:: java
double[] y1 = new double[source.size()];
for (int i = 0; i < source.size(); i++) y1[i] = source.get(i).length;
double[] y2 = new double[target.size()];
for (int i = 0; i < target.size(); i++) y2[i] = target.get(i).length;
HistogramTrace trace1 =
HistogramTrace.builder(y1).opacity(.75).name("source").nBinsX(20).build();
HistogramTrace trace2 =
HistogramTrace.builder(y2).opacity(.75).name("target").nBinsX(20).build();
Layout layout = Layout.builder().barMode(Layout.BarMode.GROUP).build();
new Figure(layout, trace1, trace2);
.. raw:: html