
14.3. The Dataset for Pretraining Word Embeddings

Now that we know the technical details of the word2vec models and approximate training methods, let us walk through their implementations. Specifically, we will take the skip-gram model in Section 14.1 and negative sampling in Section 14.2 as an example. In this section, we begin with the dataset for pretraining the word embedding model: the original format of the data will be transformed into minibatches that can be iterated over during training.

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/Functions.java
%load ../utils/PlotUtils.java

%load ../utils/StopWatch.java
%load ../utils/Accumulator.java
%load ../utils/Animator.java
%load ../utils/Training.java
%load ../utils/timemachine/Vocab.java
import java.util.stream.*;
import org.apache.commons.math3.distribution.EnumeratedDistribution;
NDManager manager = NDManager.newBaseManager();

14.3.1. Reading the Dataset

The dataset that we use here is Penn Tree Bank (PTB). This corpus is sampled from Wall Street Journal articles and is split into training, validation, and test sets. In the original format, each line of the text file represents a sentence of words separated by spaces. Here we treat each word as a token.

public static String[][] readPTB() throws IOException {
    // Download and extract the PTB dataset
    String ptbURL = "http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip";
    InputStream input = new URL(ptbURL).openStream();
    ZipUtils.unzip(input, Paths.get("./"));

    // Read the training set, one sentence per line
    ArrayList<String> lines = new ArrayList<>();
    File file = new File("./ptb/ptb.train.txt");
    Scanner myReader = new Scanner(file);
    while (myReader.hasNextLine()) {
        lines.add(myReader.nextLine());
    }
    myReader.close();
    // Split each sentence into space-separated word tokens
    String[][] tokens = new String[lines.size()][];
    for (int i = 0; i < lines.size(); i++) {
        tokens[i] = lines.get(i).trim().split(" ");
    }
    return tokens;
}
String[][] sentences = readPTB();
System.out.println("# sentences: " + sentences.length);
# sentences: 42068

After reading the training set, we build a vocabulary for the corpus, where any word that appears less than 10 times is replaced by the "<unk>" token. Note that the original dataset also contains "<unk>" tokens that represent rare (unknown) words.

Vocab vocab = new Vocab(sentences, 10, new String[] {});
System.out.println(vocab.length());
6719
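As a quick sanity check, the following minimal sketch round-trips a common word through the vocabulary using the getIdx and idxToToken members of the Vocab class loaded above (the printed index depends on the vocabulary just built):

// Map a word to its index and back (sketch)
int theIdx = vocab.getIdx("the");
System.out.println("index of \"the\": " + theIdx);
System.out.println("token at that index: " + vocab.idxToToken.get(theIdx));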

14.3.2. Subsampling

Text data typically have high-frequency words such as "the", "a", and "in": they may even occur billions of times in very large corpora. However, these words often co-occur with many different words in context windows, providing little useful signal. For instance, consider the word "chip" in a context window: intuitively, its co-occurrence with the low-frequency word "intel" is more useful in training than its co-occurrence with the high-frequency word "a". Moreover, training with vast amounts of (high-frequency) words is slow. Thus, when training word embedding models, high-frequency words can be subsampled [Mikolov et al., 2013b]. Specifically, each word \(w_i\) in the dataset will be discarded with probability

(14.3.1)\[P(w_i) = \max\left(1 - \sqrt{\frac{t}{f(w_i)}}, 0\right),\]

where \(f(w_i)\) is the ratio of the number of occurrences of word \(w_i\) to the total number of words in the dataset, and the constant \(t\) is a hyperparameter (\(10^{-4}\) in the experiment). We can see that only when the relative frequency \(f(w_i) > t\) can the (high-frequency) word \(w_i\) be discarded, and the higher the relative frequency of the word, the greater the probability of being discarded.
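To make the effect of (14.3.1) concrete, the following minimal sketch (with hypothetical relative frequencies) evaluates the drop probability for a few values of \(f(w)\): a word at or below the threshold \(t\) is never dropped, while a very frequent word is almost always dropped.

// Drop probability P(w) = max(1 - sqrt(t / f(w)), 0) for hypothetical
// relative frequencies f(w)
double t = 1e-4;
double[] freqs = {1e-2, 1e-3, 1e-4, 1e-5};
for (double f : freqs) {
    System.out.printf("f(w)=%.0e -> P(drop)=%.3f%n", f, Math.max(1 - Math.sqrt(t / f), 0));
}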

public static boolean keep(String token, LinkedHashMap<?, Integer> counter, int numTokens) {
    // Return true if this token should be kept during subsampling, i.e.,
    // with probability 1 - P(w) = sqrt(t / f(w)), capped at 1
    return new Random().nextFloat() < Math.sqrt(1e-4 / counter.get(token) * numTokens);
}

public static String[][] subSampling(String[][] sentences, Vocab vocab) {
    for (int i = 0; i < sentences.length; i++) {
        for (int j = 0; j < sentences[i].length; j++) {
            sentences[i][j] = vocab.idxToToken.get(vocab.getIdx(sentences[i][j]));
        }
    }
    // Count the frequency for each word
    LinkedHashMap<?, Integer> counter = vocab.countCorpus2D(sentences);
    int numTokens = 0;
    for (Integer value : counter.values()) {
        numTokens += value;
    }

    // Now do the subsampling
    String[][] output = new String[sentences.length][];
    for (int i = 0; i < sentences.length; i++) {
        ArrayList<String> tks = new ArrayList<>();
        for (int j = 0; j < sentences[i].length; j++) {
            String tk = sentences[i][j];
            if (keep(sentences[i][j], counter, numTokens)) {
                tks.add(tk);
            }
        }
        output[i] = tks.toArray(new String[tks.size()]);
    }

    return output;
}

String[][] subsampled = subSampling(sentences, vocab);
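As a rough check, we can compare the total number of tokens before and after (a sketch; the exact count after subsampling varies from run to run, since tokens are dropped at random):

// Total token counts before and after subsampling
int before = Arrays.stream(sentences).mapToInt(s -> s.length).sum();
int after = Arrays.stream(subsampled).mapToInt(s -> s.length).sum();
System.out.println("tokens: before=" + before + ", after=" + after);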

The following code snippet plots the histogram of the number of tokens per sentence before and after subsampling. As expected, subsampling significantly shortens sentences by dropping high-frequency words, which will lead to speedup in training.

double[] y1 = new double[sentences.length];
for (int i = 0; i < sentences.length; i++) y1[i] = sentences[i].length;
double[] y2 = new double[subsampled.length];
for (int i = 0; i < subsampled.length; i++) y2[i] = subsampled[i].length;

HistogramTrace trace1 =
        HistogramTrace.builder(y1).opacity(.75).name("origin").nBinsX(20).build();
HistogramTrace trace2 =
        HistogramTrace.builder(y2).opacity(.75).name("subsampled").nBinsX(20).build();

Layout layout =
        Layout.builder()
                .barMode(Layout.BarMode.GROUP)
                .showLegend(true)
                .xAxis(Axis.builder().title("# tokens per sentence").build())
                .yAxis(Axis.builder().title("count").build())
                .build();
new Figure(layout, trace1, trace2);

For individual tokens, the sampling rate of the high-frequency word "the" is less than 1/20.

public static String compareCounts(String token, String[][] sentences, String[][] subsampled) {
    int beforeCount = 0;
    for (int i = 0; i < sentences.length; i++) {
        for (int j = 0; j < sentences[i].length; j++) {
            if (sentences[i][j].equals(token)) beforeCount += 1;
        }
    }

    int afterCount = 0;
    for (int i = 0; i < subsampled.length; i++) {
        for (int j = 0; j < subsampled[i].length; j++) {
            if (subsampled[i][j].equals(token)) afterCount += 1;
        }
    }

    return "# of \"the\": before=" + beforeCount + ", after=" + afterCount;
}

System.out.println(compareCounts("the", sentences, subsampled));
# of "the": before=50770, after=2111

In contrast, the low-frequency word "join" is completely kept.

System.out.println(compareCounts("join", sentences, subsampled));
# of "the": before=45, after=45

After subsampling, we map tokens to their indices in the corpus.

Integer[][] corpus = new Integer[subsampled.length][];
for (int i = 0; i < subsampled.length; i++) {
    corpus[i] = vocab.getIdxs(subsampled[i]);
}
for (int i = 0; i < 3; i++) {
    System.out.println(Arrays.toString(corpus[i]));
}
[]
[71, 2115, 5]
[5277, 3054, 1580]
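To confirm the round trip (a small sketch reusing the idxToToken mapping from above), we can map the indices of the third sentence back to tokens:

// Map the indices of corpus[2] back to their tokens
for (int idx : corpus[2]) {
    System.out.print(vocab.idxToToken.get(idx) + " ");
}
System.out.println();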

14.3.3. Extracting Center Words and Context Words

The following getCentersAndContext function extracts all the center words and their context words from corpus. It uniformly samples an integer between 1 and maxWindowSize at random as the context window size. For any center word, those words whose distance from it does not exceed the sampled context window size are its context words.

public static Pair<ArrayList<Integer>, ArrayList<ArrayList<Integer>>> getCentersAndContext(
        Integer[][] corpus, int maxWindowSize) {
    ArrayList<Integer> centers = new ArrayList<>();
    ArrayList<ArrayList<Integer>> contexts = new ArrayList<>();

    for (Integer[] line : corpus) {
        // Each sentence needs at least 2 words to form a "central target word
        // - context word" pair
        if (line.length < 2) {
            continue;
        }
        centers.addAll(Arrays.asList(line));
        for (int i = 0; i < line.length; i++) { // Context window centered at i
            int windowSize = new Random().nextInt(maxWindowSize) + 1; // Uniform in [1, maxWindowSize]
            List<Integer> indices =
                    IntStream.range(
                                    Math.max(0, i - windowSize),
                                    Math.min(line.length, i + 1 + windowSize))
                            .boxed()
                            .collect(Collectors.toList());
            // Exclude the central target word from the context words
            indices.remove(indices.indexOf(i));
            ArrayList<Integer> context = new ArrayList<>();
            for (Integer idx : indices) {
                context.add(line[idx]);
            }
            contexts.add(context);
        }
    }
    return new Pair<>(centers, contexts);
}

Next, we create an artificial dataset containing two sentences of 7 and 3 words, respectively. Let the maximum context window size be 2 and print all the center words and their context words.

Integer[][] tinyDataset =
        new Integer[][] {
            IntStream.range(0, 7)
                    .boxed()
                    .collect(Collectors.toList())
                    .toArray(new Integer[] {}),
            IntStream.range(7, 10)
                    .boxed()
                    .collect(Collectors.toList())
                    .toArray(new Integer[] {})
        };

System.out.println("dataset " + Arrays.deepToString(tinyDataset));
Pair<ArrayList<Integer>, ArrayList<ArrayList<Integer>>> centerContextPair =
        getCentersAndContext(tinyDataset, 2);
for (int i = 0; i < centerContextPair.getValue().size(); i++) {
    System.out.println(
            "Center "
                    + centerContextPair.getKey().get(i)
                    + " has contexts"
                    + centerContextPair.getValue().get(i));
}
dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
Center 0 has contexts[1]
Center 1 has contexts[0, 2]
Center 2 has contexts[1, 3]
Center 3 has contexts[2, 4]
Center 4 has contexts[3, 5]
Center 5 has contexts[4, 6]
Center 6 has contexts[5]
Center 7 has contexts[8]
Center 8 has contexts[7, 9]
Center 9 has contexts[8]

When training on the PTB dataset, we set the maximum context window size to 5. The following extracts all the center words and their context words in the dataset.

centerContextPair = getCentersAndContext(corpus, 5);
ArrayList<Integer> allCenters = centerContextPair.getKey();
ArrayList<ArrayList<Integer>> allContexts = centerContextPair.getValue();
System.out.println("中心词-上下文词对”的数量:" + allCenters.size());
中心词-上下文词对”的数量:352849

14.3.4. Negative Sampling

We use negative sampling for approximate training. To sample noise words according to a predefined distribution, we define the following RandomGenerator class, where the (possibly unnormalized) sampling distribution is passed via the samplingWeights argument.

public class RandomGenerator {
    /* Draw a random int in [0, n-1] according to n sampling weights. */

    private List<Integer> population;
    private List<Double> samplingWeights;
    private List<Integer> candidates;
    private List<org.apache.commons.math3.util.Pair<Integer, Double>> pmf;
    private int i;

    public RandomGenerator(List<Double> samplingWeights) {
        this.population =
                IntStream.range(0, samplingWeights.size()).boxed().collect(Collectors.toList());
        this.samplingWeights = samplingWeights;
        this.candidates = new ArrayList<>();
        this.i = 0;

        this.pmf = new ArrayList<>();
        for (int i = 0; i < samplingWeights.size(); i++) {
            this.pmf.add(new org.apache.commons.math3.util.Pair<>(this.population.get(i), this.samplingWeights.get(i)));
        }
    }

    public Integer draw() {
        if (this.i == this.candidates.size()) {
            this.candidates =
                    Arrays.asList(new EnumeratedDistribution<>(this.pmf).sample(10000, new Integer[] {}));
            this.i = 0;
        }
        this.i += 1;
        return this.candidates.get(this.i - 1);
    }
}

For example, we can draw 10 random variables \(X\) among the indices 0, 1, and 2 with sampling probabilities \(P(X=0)=2/9\), \(P(X=1)=3/9\), and \(P(X=2)=4/9\) as follows.

RandomGenerator generator =
        new RandomGenerator(Arrays.asList(new Double[] {2.0, 3.0, 4.0}));
Integer[] generatorOutput = new Integer[10];
for (int i = 0; i < 10; i++) {
    generatorOutput[i] = generator.draw();
}
System.out.println(Arrays.toString(generatorOutput));
[2, 1, 2, 2, 1, 1, 1, 2, 1, 0]
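To see that the generator indeed follows the weights, a quick empirical check (a sketch; the frequencies will fluctuate around 2/9, 3/9, and 4/9 across runs) counts a large number of draws:

// Empirically estimate the sampling probabilities from 9000 draws
int[] counts = new int[3];
int numDraws = 9000;
for (int i = 0; i < numDraws; i++) {
    counts[generator.draw()]++;
}
for (int k = 0; k < 3; k++) {
    System.out.println("P(X=" + k + ") ~ " + (double) counts[k] / numDraws);
}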

For a pair of center word and context word, we randomly sample K (5 in the experiment) noise words. According to the suggestions in the word2vec paper, the sampling probability \(P(w)\) of a noise word \(w\) is set to its relative frequency in the dictionary raised to the power of 0.75 [Mikolov et al., 2013b].

public static ArrayList<ArrayList<Integer>> getNegatives(
        ArrayList<ArrayList<Integer>> allContexts, Integer[][] corpus, int K) {
    LinkedHashMap<?, Integer> counter = Vocab.countCorpus2D(corpus);
    ArrayList<Double> samplingWeights = new ArrayList<>();
    for (Map.Entry<?, Integer> entry : counter.entrySet()) {
        samplingWeights.add(Math.pow(entry.getValue(), .75));
    }
    ArrayList<ArrayList<Integer>> allNegatives = new ArrayList<>();
    RandomGenerator generator = new RandomGenerator(samplingWeights);
    for (ArrayList<Integer> contexts : allContexts) {
        ArrayList<Integer> negatives = new ArrayList<>();
        while (negatives.size() < contexts.size() * K) {
            Integer neg = generator.draw();
            // Noise words cannot be context words
            if (!contexts.contains(neg)) {
                negatives.add(neg);
            }
        }
        allNegatives.add(negatives);
    }
    return allNegatives;
}

ArrayList<ArrayList<Integer>> allNegatives = getNegatives(allContexts, corpus, 5);
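As a quick sanity check (a sketch), every context list should come with exactly K = 5 sampled noise words per context word:

// Each context word gets exactly K = 5 noise words
System.out.println(allContexts.get(0).size() * 5 == allNegatives.get(0).size()); // true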

14.3.5. Loading Training Examples in Minibatches

After all the center words together with their context words and sampled noise words are extracted, they will be transformed into minibatches of examples that can be iteratively loaded during training.

In a minibatch, the \(i^\mathrm{th}\) example includes a center word and its \(n_i\) context words and \(m_i\) noise words. Due to varying context window sizes, \(n_i+m_i\) varies for different \(i\). Thus, for each example we concatenate its context words and noise words in the contexts_negatives variable, and pad zeros until the concatenation length reaches \(\max_i n_i+m_i\) (maxLen). To exclude paddings in the calculation of the loss, we define a mask variable masks. There is a one-to-one correspondence between elements in masks and elements in contexts_negatives, where zeros (otherwise ones) in masks correspond to paddings in contexts_negatives.

To distinguish between positive and negative examples, we separate context words from noise words in contexts_negatives via a labels variable. Similar to masks, there is also a one-to-one correspondence between elements in labels and elements in contexts_negatives, where ones (otherwise zeros) in labels correspond to the context words (positive examples) in contexts_negatives.
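The point of masks becomes apparent when a per-position loss is averaged. The following minimal sketch (with a hypothetical per-position loss vector rawLoss) shows how the mask would zero out padded positions before normalizing:

// Hypothetical per-position loss for one example with maxLen = 6,
// of which the last two positions are padding
NDArray rawLoss = manager.create(new float[] {0.9f, 0.2f, 0.4f, 0.1f, 0.5f, 0.3f});
NDArray mask = manager.create(new float[] {1, 1, 1, 1, 0, 0});
// Padded positions contribute zero; normalize by the number of real entries
NDArray maskedMean = rawLoss.mul(mask).sum().div(mask.sum());
System.out.println(maskedMean);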

The above ideas are implemented in the batchifyData function below. Its input data is a list whose length equals the batch size, where each element is an example consisting of the center word center, its context words context, and its noise words negative. This function returns a minibatch that can be loaded for calculations during training, including, e.g., the mask variable.

public static NDList batchifyData(NDList[] data) {
    NDList centers = new NDList();
    NDList contextsNegatives = new NDList();
    NDList masks = new NDList();
    NDList labels = new NDList();

    long maxLen = 0;
    for (NDList ndList : data) { // center, context, negative = ndList
        maxLen =
                Math.max(
                        maxLen,
                        ndList.get(1).countNonzero().getLong()
                                + ndList.get(2).countNonzero().getLong());
    }
    for (NDList ndList : data) { // center, context, negative = ndList
        NDArray center = ndList.get(0);
        NDArray context = ndList.get(1);
        NDArray negative = ndList.get(2);

        int count = 0;
        for (int i = 0; i < context.size(); i++) {
            // A zero marks padding, so stop adding
            // context values once one is found
            if (context.get(i).getInt() == 0) {
                break;
            }
            contextsNegatives.add(context.get(i).reshape(1));
            masks.add(manager.create(1).reshape(1));
            labels.add(manager.create(1).reshape(1));
            count += 1;
        }
        for (int i = 0; i < negative.size(); i++) {
            // A zero marks padding, so stop adding
            // negative values once one is found
            if (negative.get(i).getInt() == 0) {
                break;
            }
            contextsNegatives.add(negative.get(i).reshape(1));
            masks.add(manager.create(1).reshape(1));
            labels.add(manager.create(0).reshape(1));
            count += 1;
        }
        // Pad the remaining positions with zeros up to maxLen
        while (count != maxLen) {
            contextsNegatives.add(manager.create(0).reshape(1));
            masks.add(manager.create(0).reshape(1));
            labels.add(manager.create(0).reshape(1));
            count += 1;
        }

        // Add the center word of this example to the output
        centers.add(center.reshape(1));
    }
    return new NDList(
            NDArrays.concat(centers).reshape(data.length, -1),
            NDArrays.concat(contextsNegatives).reshape(data.length, -1),
            NDArrays.concat(masks).reshape(data.length, -1),
            NDArrays.concat(labels).reshape(data.length, -1));
}

Let's test this function using a minibatch of two examples.

NDList x1 =
        new NDList(
                manager.create(new int[] {1}),
                manager.create(new int[] {2, 2}),
                manager.create(new int[] {3, 3, 3, 3}));
NDList x2 =
        new NDList(
                manager.create(new int[] {1}),
                manager.create(new int[] {2, 2, 2}),
                manager.create(new int[] {3, 3}));

NDList batchedData = batchifyData(new NDList[] {x1, x2});
String[] names = new String[] {"centers", "contexts_negatives", "masks", "labels"};
for (int i = 0; i < batchedData.size(); i++) {
    System.out.println(names[i] + " shape: " + batchedData.get(i));
}
centers shape: ND: (2, 1) gpu(0) int32
[[ 1],
 [ 1],
]

contexts_negatives shape: ND: (2, 6) gpu(0) int32
[[ 2,  2,  3,  3,  3,  3],
 [ 2,  2,  2,  3,  3,  0],
]

masks shape: ND: (2, 6) gpu(0) int32
[[ 1,  1,  1,  1,  1,  1],
 [ 1,  1,  1,  1,  1,  0],
]

labels shape: ND: (2, 6) gpu(0) int32
[[ 1,  1,  0,  0,  0,  0],
 [ 1,  1,  1,  0,  0,  0],
]
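As a small follow-up check (a sketch using the NDArray sum over axis 1), the row sums of masks recover the true unpadded lengths \(n_i+m_i\) of the two examples:

// Row sums of the mask give the number of real (unpadded) entries per example
System.out.println(batchedData.get(2).sum(new int[] {1}));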

14.3.6. Putting All Things Together

Last, we define the loadDataPTB function that reads the PTB dataset and returns the data iterator and the vocabulary.

public static NDList convertNDArray(Object[] data, NDManager manager) {
    ArrayList<Integer> centers = (ArrayList<Integer>) data[0];
    ArrayList<ArrayList<Integer>> contexts = (ArrayList<ArrayList<Integer>>) data[1];
    ArrayList<ArrayList<Integer>> negatives = (ArrayList<ArrayList<Integer>>) data[2];

    // Create centers NDArray
    NDArray centersNDArray = manager.create(centers.stream().mapToInt(i -> i).toArray());

    // Create contexts NDArray
    int maxLen = 0;
    for (ArrayList<Integer> context : contexts) {
        maxLen = Math.max(maxLen, context.size());
    }
    // Fill arrays with 0s to all have same lengths and be able to create NDArray
    for (ArrayList<Integer> context : contexts) {
        while (context.size() != maxLen) {
            context.add(0);
        }
    }
    NDArray contextsNDArray =
            manager.create(
                    contexts.stream()
                            .map(u -> u.stream().mapToInt(i -> i).toArray())
                            .toArray(int[][]::new));

    // Create negatives NDArray
    maxLen = 0;
    for (ArrayList<Integer> negative : negatives) {
        maxLen = Math.max(maxLen, negative.size());
    }
    // Fill arrays with 0s to all have same lengths and be able to create NDArray
    for (ArrayList<Integer> negative : negatives) {
        while (negative.size() != maxLen) {
            negative.add(0);
        }
    }
    NDArray negativesNDArray =
            manager.create(
                    negatives.stream()
                            .map(u -> u.stream().mapToInt(i -> i).toArray())
                            .toArray(int[][]::new));

    return new NDList(centersNDArray, contextsNDArray, negativesNDArray);
}

public static Pair<ArrayDataset, Vocab> loadDataPTB(
        int batchSize, int maxWindowSize, int numNoiseWords, NDManager manager)
        throws IOException, TranslateException {
    String[][] sentences = readPTB();
    Vocab vocab = new Vocab(sentences, 10, new String[] {});
    String[][] subSampled = subSampling(sentences, vocab);
    Integer[][] corpus = new Integer[subSampled.length][];
    for (int i = 0; i < subSampled.length; i++) {
        corpus[i] = vocab.getIdxs(subSampled[i]);
    }
    Pair<ArrayList<Integer>, ArrayList<ArrayList<Integer>>> pair =
            getCentersAndContext(corpus, maxWindowSize);
    ArrayList<ArrayList<Integer>> negatives =
            getNegatives(pair.getValue(), corpus, numNoiseWords);

    NDList ndArrays =
            convertNDArray(new Object[] {pair.getKey(), pair.getValue(), negatives}, manager);
    ArrayDataset dataset =
            new ArrayDataset.Builder()
                    .setData(ndArrays.get(0), ndArrays.get(1), ndArrays.get(2))
                    .optDataBatchifier(
                            new Batchifier() {
                                @Override
                                public NDList batchify(NDList[] ndLists) {
                                    return batchifyData(ndLists);
                                }

                                @Override
                                public NDList[] unbatchify(NDList ndList) {
                                    return new NDList[0];
                                }
                            })
                    .setSampling(batchSize, true)
                    .build();

    return new Pair<>(dataset, vocab);
}

Let's print the first minibatch of the data iterator.

Pair<ArrayDataset, Vocab> datasetVocab = loadDataPTB(512, 5, 5, manager);
ArrayDataset dataset = datasetVocab.getKey();
vocab = datasetVocab.getValue();

Batch batch = dataset.getData(manager).iterator().next();
for (int i = 0; i < batch.getData().size(); i++) {
    System.out.println(names[i] + " shape: " + batch.getData().get(i).getShape());
}
centers shape: (512, 1)
contexts_negatives shape: (512, 48)
masks shape: (512, 48)
labels shape: (512, 48)

14.3.7. Summary

  • High-frequency words may not be so useful in training. We can subsample them for speedup in training.

  • For computational efficiency, we load examples in minibatches. We can define other variables to distinguish padding tokens from non-padding tokens, and positive examples from negative ones.

14.3.8. Exercises

  1. How does the running time of the code in this section change if subsampling is not used?

  2. The RandomGenerator class caches k random sampling results (k = 10000 in the code above). Set k to other values and see how it affects the data loading speed.

  3. What other hyperparameters in the code of this section may affect the data loading speed?