Run this notebook online:\ |Binder| or Colab: |Colab| .. |Binder| image:: https://mybinder.org/badge_logo.svg :target: https://mybinder.org/v2/gh/deepjavalibrary/d2l-java/master?filepath=chapter_natural-language-processing-pretraining/word-embedding-dataset.ipynb .. |Colab| image:: https://colab.research.google.com/assets/colab-badge.svg :target: https://colab.research.google.com/github/deepjavalibrary/d2l-java/blob/colab/chapter_natural-language-processing-pretraining/word-embedding-dataset.ipynb .. _sec_word2vec_data: 用于预训练词嵌入的数据集 ======================== 现在我们已经了解了word2vec模型的技术细节和大致的训练方法,让我们来看看它们的实现。具体地说,我们将以 :numref:`sec_word2vec`\ 的跳元模型和 :numref:`sec_approx_train`\ 的负采样为例。在本节中,我们从用于预训练词嵌入模型的数据集开始:数据的原始格式将被转换为可以在训练期间迭代的小批量。 .. code:: java %load ../utils/djl-imports %load ../utils/plot-utils %load ../utils/Functions.java %load ../utils/PlotUtils.java %load ../utils/StopWatch.java %load ../utils/Accumulator.java %load ../utils/Animator.java %load ../utils/Training.java %load ../utils/timemachine/Vocab.java .. code:: java import java.util.stream.*; import org.apache.commons.math3.distribution.EnumeratedDistribution; .. code:: java NDManager manager = NDManager.newBaseManager(); 正在读取数据集 -------------- 我们在这里使用的数据集是\ `Penn Tree Bank(PTB) `__\ 。该语料库取自“华尔街日报”的文章,分为训练集、验证集和测试集。在原始格式中,文本文件的每一行表示由空格分隔的一句话。在这里,我们将每个单词视为一个词元。 .. code:: java public static String[][] readPTB() throws IOException { String ptbURL = "http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip"; InputStream input = new URL(ptbURL).openStream(); ZipUtils.unzip(input, Paths.get("./")); ArrayList lines = new ArrayList<>(); File file = new File("./ptb/ptb.train.txt"); Scanner myReader = new Scanner(file); while (myReader.hasNextLine()) { lines.add(myReader.nextLine()); } String[][] tokens = new String[lines.size()][]; for (int i = 0; i < lines.size(); i++) { tokens[i] = lines.get(i).trim().split(" "); } return tokens; } .. code:: java String[][] sentences = readPTB(); System.out.println("# sentences: " + sentences.length); .. parsed-literal:: :class: output # sentences: 42068 在读取训练集之后,我们为语料库构建了一个词表,其中出现次数少于10次的任何单词都将由“”词元替换。请注意,原始数据集还包含表示稀有(未知)单词的“”词元。 .. code:: java Vocab vocab = new Vocab(sentences, 10, new String[] {}); System.out.println(vocab.length()); .. parsed-literal:: :class: output 6719 下采样 ------ 文本数据通常有“the”、“a”和“in”等高频词:它们在非常大的语料库中甚至可能出现数十亿次。然而,这些词经常在上下文窗口中与许多不同的词共同出现,提供的有用信息很少。例如,考虑上下文窗口中的词“chip”:直观地说,它与低频单词“intel”的共现比与高频单词“a”的共现在训练中更有用。此外,大量(高频)单词的训练速度很慢。因此,当训练词嵌入模型时,可以对高频单词进行\ *下采样* :cite:`Mikolov.Sutskever.Chen.ea.2013`\ 。具体地说,数据集中的每个词\ :math:`w_i`\ 将有概率地被丢弃 .. math:: P(w_i) = \max\left(1 - \sqrt{\frac{t}{f(w_i)}}, 0\right), 其中\ :math:`f(w_i)`\ 是\ :math:`w_i`\ 的词数与数据集中的总词数的比率,常量\ :math:`t`\ 是超参数(在实验中为\ :math:`10^{-4}`\ )。我们可以看到,只有当相对比率\ :math:`f(w_i) > t`\ 时,(高频)词\ :math:`w_i`\ 才能被丢弃,且该词的相对比率越高,被丢弃的概率就越大。 .. code:: java public static boolean keep(String token, LinkedHashMap counter, int numTokens) { // Return True if to keep this token during subsampling return new Random().nextFloat() < Math.sqrt(1e-4 / counter.get(token) * numTokens); } public static String[][] subSampling(String[][] sentences, Vocab vocab) { for (int i = 0; i < sentences.length; i++) { for (int j = 0; j < sentences[i].length; j++) { sentences[i][j] = vocab.idxToToken.get(vocab.getIdx(sentences[i][j])); } } // Count the frequency for each word LinkedHashMap counter = vocab.countCorpus2D(sentences); int numTokens = 0; for (Integer value : counter.values()) { numTokens += value; } // Now do the subsampling String[][] output = new String[sentences.length][]; for (int i = 0; i < sentences.length; i++) { ArrayList tks = new ArrayList<>(); for (int j = 0; j < sentences[i].length; j++) { String tk = sentences[i][j]; if (keep(sentences[i][j], counter, numTokens)) { tks.add(tk); } } output[i] = tks.toArray(new String[tks.size()]); } return output; } String[][] subsampled = subSampling(sentences, vocab); 下面的代码片段绘制了下采样前后每句话的词元数量的直方图。正如预期的那样,下采样通过删除高频词来显著缩短句子,这将使训练加速。 .. code:: java double[] y1 = new double[sentences.length]; for (int i = 0; i < sentences.length; i++) y1[i] = sentences[i].length; double[] y2 = new double[subsampled.length]; for (int i = 0; i < subsampled.length; i++) y2[i] = subsampled[i].length; HistogramTrace trace1 = HistogramTrace.builder(y1).opacity(.75).name("origin").nBinsX(20).build(); HistogramTrace trace2 = HistogramTrace.builder(y2).opacity(.75).name("subsampled").nBinsX(20).build(); Layout layout = Layout.builder() .barMode(Layout.BarMode.GROUP) .showLegend(true) .xAxis(Axis.builder().title("# tokens per sentence").build()) .yAxis(Axis.builder().title("count").build()) .build(); new Figure(layout, trace1, trace2); .. raw:: html
对于单个词元,高频词“the”的采样率不到1/20。 .. code:: java public static String compareCounts(String token, String[][] sentences, String[][] subsampled) { int beforeCount = 0; for (int i = 0; i < sentences.length; i++) { for (int j = 0; j < sentences[i].length; j++) { if (sentences[i][j].equals(token)) beforeCount += 1; } } int afterCount = 0; for (int i = 0; i < subsampled.length; i++) { for (int j = 0; j < subsampled[i].length; j++) { if (subsampled[i][j].equals(token)) afterCount += 1; } } return "# of \"the\": before=" + beforeCount + ", after=" + afterCount; } System.out.println(compareCounts("the", sentences, subsampled)); .. parsed-literal:: :class: output # of "the": before=50770, after=2111 相比之下,低频词“join”则被完全保留。 .. code:: java System.out.println(compareCounts("join", sentences, subsampled)); .. parsed-literal:: :class: output # of "the": before=45, after=45 在下采样之后,我们将词元映射到它们在语料库中的索引。 .. code:: java Integer[][] corpus = new Integer[subsampled.length][]; for (int i = 0; i < subsampled.length; i++) { corpus[i] = vocab.getIdxs(subsampled[i]); } for (int i = 0; i < 3; i++) { System.out.println(Arrays.toString(corpus[i])); } .. parsed-literal:: :class: output [] [71, 2115, 5] [5277, 3054, 1580] 中心词和上下文词的提取 ---------------------- 下面的\ ``get_centers_and_contexts``\ 函数从\ ``corpus``\ 中提取所有中心词及其上下文词。它随机采样1到\ ``max_window_size``\ 之间的整数作为上下文窗口。对于任一中心词,与其距离不超过采样上下文窗口大小的词为其上下文词。 .. code:: java public static Pair, ArrayList>> getCentersAndContext( Integer[][] corpus, int maxWindowSize) { ArrayList centers = new ArrayList<>(); ArrayList> contexts = new ArrayList<>(); for (Integer[] line : corpus) { // Each sentence needs at least 2 words to form a "central target word // - context word" pair if (line.length < 2) { continue; } centers.addAll(Arrays.asList(line)); for (int i = 0; i < line.length; i++) { // Context window centered at i int windowSize = new Random().nextInt(maxWindowSize - 1) + 1; List indices = IntStream.range( Math.max(0, i - windowSize), Math.min(line.length, i + 1 + windowSize)) .boxed() .collect(Collectors.toList()); // Exclude the central target word from the context words indices.remove(indices.indexOf(i)); ArrayList context = new ArrayList<>(); for (Integer idx : indices) { context.add(line[idx]); } contexts.add(context); } } return new Pair<>(centers, contexts); } 接下来,我们创建一个人工数据集,分别包含7个和3个单词的两个句子。设置最大上下文窗口大小为2,并打印所有中心词及其上下文词。 .. code:: java Integer[][] tinyDataset = new Integer[][] { IntStream.range(0, 7) .boxed() .collect(Collectors.toList()) .toArray(new Integer[] {}), IntStream.range(7, 10) .boxed() .collect(Collectors.toList()) .toArray(new Integer[] {}) }; System.out.println("dataset " + Arrays.deepToString(tinyDataset)); Pair, ArrayList>> centerContextPair = getCentersAndContext(tinyDataset, 2); for (int i = 0; i < centerContextPair.getValue().size(); i++) { System.out.println( "Center " + centerContextPair.getKey().get(i) + " has contexts" + centerContextPair.getValue().get(i)); } .. parsed-literal:: :class: output dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]] Center 0 has contexts[1] Center 1 has contexts[0, 2] Center 2 has contexts[1, 3] Center 3 has contexts[2, 4] Center 4 has contexts[3, 5] Center 5 has contexts[4, 6] Center 6 has contexts[5] Center 7 has contexts[8] Center 8 has contexts[7, 9] Center 9 has contexts[8] 在PTB数据集上进行训练时,我们将最大上下文窗口大小设置为5。下面提取数据集中的所有中心词及其上下文词。 .. code:: java centerContextPair = getCentersAndContext(corpus, 5); ArrayList allCenters = centerContextPair.getKey(); ArrayList> allContexts = centerContextPair.getValue(); System.out.println("中心词-上下文词对”的数量:" + allCenters.size()); .. parsed-literal:: :class: output 中心词-上下文词对”的数量:352849 负采样 ------ 我们使用负采样进行近似训练。为了根据预定义的分布对噪声词进行采样,我们定义以下\ ``RandomGenerator``\ 类,其中(可能未规范化的)采样分布通过变量\ ``samplingWeights``\ 传递。 .. code:: java public class RandomGenerator { /* Draw a random int in [0, n] according to n sampling weights. */ private List population; private List samplingWeights; private List candidates; private List> pmf; private int i; public RandomGenerator(List samplingWeights) { this.population = IntStream.range(0, samplingWeights.size()).boxed().collect(Collectors.toList()); this.samplingWeights = samplingWeights; this.candidates = new ArrayList<>(); this.i = 0; this.pmf = new ArrayList<>(); for (int i = 0; i < samplingWeights.size(); i++) { this.pmf.add(new org.apache.commons.math3.util.Pair(this.population.get(i), this.samplingWeights.get(i).doubleValue())); } } public Integer draw() { if (this.i == this.candidates.size()) { this.candidates = Arrays.asList((Integer[]) new EnumeratedDistribution(this.pmf).sample(10000, new Integer[] {})); this.i = 0; } this.i += 1; return this.candidates.get(this.i - 1); } } 例如,我们可以在索引1、2和3中绘制10个随机变量\ :math:`X`\ ,采样概率为\ :math:`P(X=1)=2/9, P(X=2)=3/9`\ 和\ :math:`P(X=3)=4/9`\ ,如下所示。 .. code:: java RandomGenerator generator = new RandomGenerator(Arrays.asList(new Double[] {2.0, 3.0, 4.0})); Integer[] generatorOutput = new Integer[10]; for (int i = 0; i < 10; i++) { generatorOutput[i] = generator.draw(); } System.out.println(Arrays.toString(generatorOutput)); .. parsed-literal:: :class: output [2, 1, 2, 2, 1, 1, 1, 2, 1, 0] 对于一对中心词和上下文词,我们随机抽取了\ ``K``\ 个(实验中为5个)噪声词。根据word2vec论文中的建议,将噪声词\ :math:`w`\ 的采样概率\ :math:`P(w)`\ 设置为其在字典中的相对频率,其幂为0.75 :cite:`Mikolov.Sutskever.Chen.ea.2013`\ 。 .. code:: java public static ArrayList> getNegatives( ArrayList> allContexts, Integer[][] corpus, int K) { LinkedHashMap counter = Vocab.countCorpus2D(corpus); ArrayList samplingWeights = new ArrayList<>(); for (Map.Entry entry : counter.entrySet()) { samplingWeights.add(Math.pow(entry.getValue(), .75)); } ArrayList> allNegatives = new ArrayList<>(); RandomGenerator generator = new RandomGenerator(samplingWeights); for (ArrayList contexts : allContexts) { ArrayList negatives = new ArrayList<>(); while (negatives.size() < contexts.size() * K) { Integer neg = generator.draw(); // Noise words cannot be context words if (!contexts.contains(neg)) { negatives.add(neg); } } allNegatives.add(negatives); } return allNegatives; } ArrayList> allNegatives = getNegatives(allContexts, corpus, 5); .. _subsec_word2vec-minibatch-loading: 小批量加载训练实例 ------------------ 在提取所有中心词及其上下文词和采样噪声词后,将它们转换成小批量的样本,在训练过程中可以迭代加载。 在小批量中,\ :math:`i^\mathrm{th}`\ 个样本包括中心词及其\ :math:`n_i`\ 个上下文词和\ :math:`m_i`\ 个噪声词。由于上下文窗口大小不同,\ :math:`n_i+m_i`\ 对于不同的\ :math:`i`\ 是不同的。因此,对于每个样本,我们在\ ``contexts_negatives``\ 个变量中将其上下文词和噪声词连结起来,并填充零,直到连结长度达到\ :math:`\max_i n_i+m_i`\ (``max_len``)。为了在计算损失时排除填充,我们定义了掩码变量\ ``masks``\ 。在\ ``masks``\ 中的元素和\ ``contexts_negatives``\ 中的元素之间存在一一对应关系,其中\ ``masks``\ 中的0(否则为1)对应于\ ``contexts_negatives``\ 中的填充。 为了区分正反例,我们在\ ``contexts_negatives``\ 中通过一个\ ``labels``\ 变量将上下文词与噪声词分开。类似于\ ``masks``\ ,在\ ``labels``\ 中的元素和\ ``contexts_negatives``\ 中的元素之间也存在一一对应关系,其中\ ``labels``\ 中的1(否则为0)对应于\ ``contexts_negatives``\ 中的上下文词的正例。 上述思想在下面的\ ``batchify``\ 函数中实现。其输入\ ``data``\ 是长度等于批量大小的列表,其中每个元素是由中心词\ ``center``\ 、其上下文词\ ``context``\ 和其噪声词\ ``negative``\ 组成的样本。此函数返回一个可以在训练期间加载用于计算的小批量,例如包括掩码变量。 .. code:: java public static NDList batchifyData(NDList[] data) { NDList centers = new NDList(); NDList contextsNegatives = new NDList(); NDList masks = new NDList(); NDList labels = new NDList(); long maxLen = 0; for (NDList ndList : data) { // center, context, negative = ndList maxLen = Math.max( maxLen, ndList.get(1).countNonzero().getLong() + ndList.get(2).countNonzero().getLong()); } for (NDList ndList : data) { // center, context, negative = ndList NDArray center = ndList.get(0); NDArray context = ndList.get(1); NDArray negative = ndList.get(2); int count = 0; for (int i = 0; i < context.size(); i++) { // If a 0 is found, we want to stop adding these // values to NDArray if (context.get(i).getInt() == 0) { break; } contextsNegatives.add(context.get(i).reshape(1)); masks.add(manager.create(1).reshape(1)); labels.add(manager.create(1).reshape(1)); count += 1; } for (int i = 0; i < negative.size(); i++) { // If a 0 is found, we want to stop adding these // values to NDArray if (negative.get(i).getInt() == 0) { break; } contextsNegatives.add(negative.get(i).reshape(1)); masks.add(manager.create(1).reshape(1)); labels.add(manager.create(0).reshape(1)); count += 1; } // Fill with zeroes remaining array while (count != maxLen) { contextsNegatives.add(manager.create(0).reshape(1)); masks.add(manager.create(0).reshape(1)); labels.add(manager.create(0).reshape(1)); count += 1; } // Add this NDArrays to output NDArrays centers.add(center.reshape(1)); } return new NDList( NDArrays.concat(centers).reshape(data.length, -1), NDArrays.concat(contextsNegatives).reshape(data.length, -1), NDArrays.concat(masks).reshape(data.length, -1), NDArrays.concat(labels).reshape(data.length, -1)); } 让我们使用一个小批量的两个样本来测试此函数。 .. code:: java NDList x1 = new NDList( manager.create(new int[] {1}), manager.create(new int[] {2, 2}), manager.create(new int[] {3, 3, 3, 3})); NDList x2 = new NDList( manager.create(new int[] {1}), manager.create(new int[] {2, 2, 2}), manager.create(new int[] {3, 3})); NDList batchedData = batchifyData(new NDList[] {x1, x2}); String[] names = new String[] {"centers", "contexts_negatives", "masks", "labels"}; for (int i = 0; i < batchedData.size(); i++) { System.out.println(names[i] + " shape: " + batchedData.get(i)); } .. parsed-literal:: :class: output centers shape: ND: (2, 1) gpu(0) int32 [[ 1], [ 1], ] contexts_negatives shape: ND: (2, 6) gpu(0) int32 [[ 2, 2, 3, 3, 3, 3], [ 2, 2, 2, 3, 3, 0], ] masks shape: ND: (2, 6) gpu(0) int32 [[ 1, 1, 1, 1, 1, 1], [ 1, 1, 1, 1, 1, 0], ] labels shape: ND: (2, 6) gpu(0) int32 [[ 1, 1, 0, 0, 0, 0], [ 1, 1, 1, 0, 0, 0], ] 整合代码 -------- 最后,我们定义了读取PTB数据集并返回数据迭代器和词表的\ ``load_data_ptb``\ 函数。 .. code:: java public static NDList convertNDArray(Object[] data, NDManager manager) { ArrayList centers = (ArrayList) data[0]; ArrayList> contexts = (ArrayList>) data[1]; ArrayList> negatives = (ArrayList>) data[2]; // Create centers NDArray NDArray centersNDArray = manager.create(centers.stream().mapToInt(i -> i).toArray()); // Create contexts NDArray int maxLen = 0; for (ArrayList context : contexts) { maxLen = Math.max(maxLen, context.size()); } // Fill arrays with 0s to all have same lengths and be able to create NDArray for (ArrayList context : contexts) { while (context.size() != maxLen) { context.add(0); } } NDArray contextsNDArray = manager.create( contexts.stream() .map(u -> u.stream().mapToInt(i -> i).toArray()) .toArray(int[][]::new)); // Create negatives NDArray maxLen = 0; for (ArrayList negative : negatives) { maxLen = Math.max(maxLen, negative.size()); } // Fill arrays with 0s to all have same lengths and be able to create NDArray for (ArrayList negative : negatives) { while (negative.size() != maxLen) { negative.add(0); } } NDArray negativesNDArray = manager.create( negatives.stream() .map(u -> u.stream().mapToInt(i -> i).toArray()) .toArray(int[][]::new)); return new NDList(centersNDArray, contextsNDArray, negativesNDArray); } public static Pair loadDataPTB( int batchSize, int maxWindowSize, int numNoiseWords, NDManager manager) throws IOException, TranslateException { String[][] sentences = readPTB(); Vocab vocab = new Vocab(sentences, 10, new String[] {}); String[][] subSampled = subSampling(sentences, vocab); Integer[][] corpus = new Integer[subSampled.length][]; for (int i = 0; i < subSampled.length; i++) { corpus[i] = vocab.getIdxs(subSampled[i]); } Pair, ArrayList>> pair = getCentersAndContext(corpus, maxWindowSize); ArrayList> negatives = getNegatives(pair.getValue(), corpus, numNoiseWords); NDList ndArrays = convertNDArray(new Object[] {pair.getKey(), pair.getValue(), negatives}, manager); ArrayDataset dataset = new ArrayDataset.Builder() .setData(ndArrays.get(0), ndArrays.get(1), ndArrays.get(2)) .optDataBatchifier( new Batchifier() { @Override public NDList batchify(NDList[] ndLists) { return batchifyData(ndLists); } @Override public NDList[] unbatchify(NDList ndList) { return new NDList[0]; } }) .setSampling(batchSize, true) .build(); return new Pair<>(dataset, vocab); } 让我们打印数据迭代器的第一个小批量。 .. code:: java Pair datasetVocab = loadDataPTB(512, 5, 5, manager); ArrayDataset dataset = datasetVocab.getKey(); vocab = datasetVocab.getValue(); Batch batch = dataset.getData(manager).iterator().next(); for (int i = 0; i < batch.getData().size(); i++) { System.out.println(names[i] + " shape: " + batch.getData().get(i).getShape()); } .. parsed-literal:: :class: output centers shape: (512, 1) contexts_negatives shape: (512, 48) masks shape: (512, 48) labels shape: (512, 48) 小结 ---- - 高频词在训练中可能不是那么有用。我们可以对他们进行下采样,以便在训练中加快速度。 - 为了提高计算效率,我们以小批量方式加载样本。我们可以定义其他变量来区分填充标记和非填充标记,以及正例和负例。 练习 ---- 1. 如果不使用下采样,本节中代码的运行时间会发生什么变化? 2. ``RandomGenerator``\ 类缓存\ ``k``\ 个随机采样结果。将\ ``k``\ 设置为其他值,看看它如何影响数据加载速度。 3. 本节代码中的哪些其他超参数可能会影响数据加载速度?