
8.3. Language Models and the Dataset

In Section 8.2, we saw how to map text data into tokens, where these tokens can be viewed as a sequence of discrete observations, such as words or characters. Assume that the tokens in a text sequence of length \(T\) are in turn \(x_1, x_2, \ldots, x_T\). Then, in the text sequence, \(x_t\) (\(1 \leq t \leq T\)) can be considered as the observation or label at time step \(t\). Given such a text sequence, the goal of a language model is to estimate the joint probability of the sequence

(8.3.1)\[P(x_1, x_2, \ldots, x_T).\]

Language models are incredibly useful. For instance, an ideal language model would be able to generate natural text just on its own, simply by drawing one token at a time \(x_t \sim P(x_t \mid x_{t-1}, \ldots, x_1)\). Quite unlike the monkey using a typewriter, all text emerging from such a model would pass as natural language, e.g., English text. Furthermore, it would be sufficient for generating a meaningful dialog, simply by conditioning the text on previous dialog fragments. Clearly we are still very far from designing such a system, since it would need to understand the text rather than just generate grammatically sensible content.

Nonetheless, language models are of great service even in their limited form. For instance, the phrases “to recognize speech” and “to wreck a nice beach” sound very similar. This can cause ambiguity in speech recognition, which is easily resolved through a language model that rejects the second translation as outlandish. Likewise, in a document summarization algorithm it is worthwhile knowing that “dog bites man” is much more frequent than “man bites dog”, or that “I want to eat grandma” is a rather disturbing statement, whereas “I want to eat, grandma” is much more benign.

8.3.1. Learning a Language Model

The obvious question is how we should model a document, or even a sequence of tokens. Suppose that we tokenize text data at the word level. We can take recourse to the analysis we applied to sequence models in Section 8.1. Let us start by applying basic probability rules:

(8.3.2)\[P(x_1, x_2, \ldots, x_T) = \prod_{t=1}^T P(x_t \mid x_1, \ldots, x_{t-1}).\]

For example, the probability of a text sequence containing four words would be given as:

(8.3.3)\[P(\text{deep}, \text{learning}, \text{is}, \text{fun}) = P(\text{deep}) P(\text{learning} \mid \text{deep}) P(\text{is} \mid \text{deep}, \text{learning}) P(\text{fun} \mid \text{deep}, \text{learning}, \text{is}).\]

In order to compute the language model, we need to calculate the probability of words and the conditional probability of a word given the previous few words. Such probabilities are essentially language model parameters.

Here, we assume that the training dataset is a large text corpus, such as all Wikipedia entries, Project Gutenberg, and all text posted on the Web. The probability of words can be calculated from the relative word frequency of a given word in the training dataset. For example, the estimate \(\hat{P}(\text{deep})\) can be calculated as the probability of any sentence starting with the word “deep”. A slightly less accurate approach would be to count all occurrences of the word “deep” and divide it by the total number of words in the corpus. This works fairly well, particularly for frequent words. Moving on, we could attempt to estimate

(8.3.4)\[\hat{P}(\text{learning} \mid \text{deep}) = \frac{n(\text{deep, learning})}{n(\text{deep})},\]

where \(n(x)\) and \(n(x, x')\) are the number of occurrences of singletons and consecutive word pairs, respectively. Unfortunately, estimating the probability of a word pair is somewhat more difficult, since the occurrences of “deep learning” are a lot less frequent. In particular, for some unusual word combinations it may be tricky to find enough occurrences to get accurate estimates. Things take a turn for the worse for three-word combinations and beyond. There will be many plausible three-word combinations that we likely will not see in our dataset. Unless we provide some solution to assign such word combinations a nonzero count, we will not be able to use them in a language model. If the dataset is small or if the words are very rare, we might not find even a single one of them.
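
To make (8.3.4) concrete, here is a minimal sketch, not part of the book's utilities, that estimates \(\hat{P}(\text{learning} \mid \text{deep})\) from raw counts; the tiny word list and names such as toyCorpus are made up purely for illustration.

// A hypothetical toy corpus; in practice this would be a large tokenized dataset
List<String> toyCorpus = Arrays.asList(
    "deep", "learning", "is", "fun", "and", "deep", "learning", "is", "useful");
// Count singletons n(x) and consecutive word pairs n(x, x')
Map<String, Integer> unigramCount = new HashMap<>();
Map<String, Integer> bigramCount = new HashMap<>();
for (int i = 0; i < toyCorpus.size(); i++) {
    unigramCount.merge(toyCorpus.get(i), 1, Integer::sum);
    if (i + 1 < toyCorpus.size()) {
        bigramCount.merge(toyCorpus.get(i) + " " + toyCorpus.get(i + 1), 1, Integer::sum);
    }
}
int nDeep = unigramCount.getOrDefault("deep", 0);                  // n(deep)
int nDeepLearning = bigramCount.getOrDefault("deep learning", 0);  // n(deep, learning)
double pHat = nDeep > 0 ? (double) nDeepLearning / nDeep : 0.0;
System.out.println("P(learning | deep) estimate: " + pHat);        // 2 / 2 = 1.0 on this toy corpus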

A common strategy is to perform some form of Laplace smoothing. The solution is to add a small constant to all counts. Denote by \(n\) the total number of words in the training set and \(m\) the number of unique words. This solution helps with singletons, e.g., via

(8.3.5)\[\begin{split}\begin{aligned} \hat{P}(x) & = \frac{n(x) + \epsilon_1/m}{n + \epsilon_1}, \\ \hat{P}(x' \mid x) & = \frac{n(x, x') + \epsilon_2 \hat{P}(x')}{n(x) + \epsilon_2}, \\ \hat{P}(x'' \mid x,x') & = \frac{n(x, x',x'') + \epsilon_3 \hat{P}(x'')}{n(x, x') + \epsilon_3}. \end{aligned}\end{split}\]

Here \(\epsilon_1,\epsilon_2\), and \(\epsilon_3\) are hyperparameters. Take \(\epsilon_1\) as an example: when \(\epsilon_1 = 0\), no smoothing is applied; when \(\epsilon_1\) approaches positive infinity, \(\hat{P}(x)\) approaches the uniform probability \(1/m\). The above is a rather primitive variant of what other techniques can accomplish [Wood et al., 2011].
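
To spell out the unigram case of (8.3.5), the sketch below reuses the toy counts from the snippet above; the value \(\epsilon_1 = 0.1\) is an arbitrary choice for illustration only.

double epsilon1 = 0.1;
long nTotal = toyCorpus.size();       // total number of words in the toy corpus
long mUnique = unigramCount.size();   // number of unique words
// Smoothed unigram estimates following (8.3.5)
double pDeep = (unigramCount.getOrDefault("deep", 0) + epsilon1 / mUnique) / (nTotal + epsilon1);
double pUnseen = (0 + epsilon1 / mUnique) / (nTotal + epsilon1);  // an unseen word still gets some mass
System.out.println("P(deep) = " + pDeep);
System.out.println("P(unseen word) = " + pUnseen);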

Unfortunately, models like this get unwieldy rather quickly for the following reasons. First, we need to store all counts. Second, this entirely ignores the meaning of the words. For instance, “cat” and “feline” should occur in related contexts. It is quite difficult to adjust such models to additional contexts, whereas deep learning based language models are well suited to take this into account. Last, long word sequences are almost certain to be novel, hence a model that simply counts the frequency of previously seen word sequences is bound to perform poorly there.

8.3.2. Markov Models and \(n\)-grams

Before we discuss solutions involving deep learning, we need some more terminology and concepts. Recall our discussion of Markov Models in Section 8.1. Let us apply this to language modeling. A distribution over sequences satisfies the Markov property of first order if \(P(x_{t+1} \mid x_t, \ldots, x_1) = P(x_{t+1} \mid x_t)\). Higher orders correspond to longer dependencies. This leads to a number of approximations that we could apply to model a sequence:

(8.3.6)\[\begin{split}\begin{aligned} P(x_1, x_2, x_3, x_4) &= P(x_1) P(x_2) P(x_3) P(x_4),\\ P(x_1, x_2, x_3, x_4) &= P(x_1) P(x_2 \mid x_1) P(x_3 \mid x_2) P(x_4 \mid x_3),\\ P(x_1, x_2, x_3, x_4) &= P(x_1) P(x_2 \mid x_1) P(x_3 \mid x_1, x_2) P(x_4 \mid x_2, x_3). \end{aligned}\end{split}\]

The probability formulae that involve one, two, and three variables are typically referred to as unigram, bigram, and trigram models, respectively. In the following, we will learn how to design better models.
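
As a small illustration of the second line of (8.3.6), the following sketch scores a four-word sequence under a first-order Markov (bigram) model, again reusing the hypothetical toy counts from above and ignoring smoothing.

String[] seq = {"deep", "learning", "is", "fun"};
// P(x_1) estimated by the relative frequency of the first word
double seqProb = (double) unigramCount.getOrDefault(seq[0], 0) / toyCorpus.size();
for (int t = 1; t < seq.length; t++) {
    int nPrev = unigramCount.getOrDefault(seq[t - 1], 0);
    int nPair = bigramCount.getOrDefault(seq[t - 1] + " " + seq[t], 0);
    // P(x_t | x_{t-1}) estimated by n(x_{t-1}, x_t) / n(x_{t-1})
    seqProb *= nPrev > 0 ? (double) nPair / nPrev : 0.0;
}
System.out.println("Bigram-model probability of the sequence: " + seqProb);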

8.3.3. Natural Language Statistics

Let us see how this works on real data. We construct a vocabulary based on the time machine dataset as introduced in Section 8.2 and print the top 10 most frequent words.

%load ../utils/djl-imports
%load ../utils/plot-utils
%load ../utils/PlotUtils.java

%load ../utils/Accumulator.java
%load ../utils/Animator.java
%load ../utils/Functions.java
%load ../utils/StopWatch.java
%load ../utils/Training.java
%load ../utils/timemachine/Vocab.java
%load ../utils/timemachine/RNNModelScratch.java
%load ../utils/timemachine/TimeMachine.java
NDManager manager = NDManager.newBaseManager();
String[][] tokens = TimeMachine.tokenize(TimeMachine.readTimeMachine(), "word");
// Since each text line is not necessarily a sentence or a paragraph, we
// concatenate all text lines
List<String> corpus = new ArrayList<>();
for (int i = 0; i < tokens.length; i++) {
    for (int j = 0; j < tokens[i].length; j++) {
        if (!tokens[i][j].equals("")) {
            corpus.add(tokens[i][j]);
        }
    }
}

Vocab vocab = new Vocab(new String[][] {corpus.toArray(new String[0])}, -1, new String[0]);
for (int i = 0; i < 10; i++) {
    Map.Entry<String, Integer> token = vocab.tokenFreqs.get(i);
    System.out.println(token.getKey() + ": " + token.getValue());
}
the: 2261
i: 1267
and: 1245
of: 1155
a: 816
to: 695
was: 552
in: 541
that: 443
my: 440

As we can see, the most popular words are actually quite boring to look at. They are often referred to as stop words and thus filtered out. Nonetheless, they still carry meaning and we will still use them. Besides, it is quite clear that the word frequency decays rather rapidly. The \(10^{\mathrm{th}}\) most frequent word is less than \(1/5\) as common as the most popular one. To get a better idea, we plot the word frequency.

int n = vocab.tokenFreqs.size();
double[] freqs = new double[n];
double[] x = new double[n];
for (int i = 0; i < n; i++) {
    freqs[i] = (double) vocab.tokenFreqs.get(i).getValue();
    x[i] = (double) i;
}

PlotUtils.plotLogScale(new double[][] {x}, new double[][] {freqs}, new String[] {""},
                       "token: x", "frequency: n(x)");

We are on to something quite fundamental here: the word frequency decays rapidly in a well-defined way. After dealing with the first few words as exceptions, all the remaining words roughly follow a straight line on a log-log plot. This means that words satisfy Zipf’s law, which states that the frequency \(n_i\) of the \(i^\mathrm{th}\) most frequent word is:

(8.3.7)\[n_i \propto \frac{1}{i^\alpha},\]

which is equivalent to

(8.3.8)\[\log n_i = -\alpha \log i + c,\]

where \(\alpha\) is the exponent that characterizes the distribution and \(c\) is a constant. This should already give us pause if we want to model words by count statistics and smoothing. After all, we will significantly overestimate the frequency of the tail, also known as the infrequent words. But what about the other word combinations, such as bigrams, trigrams, and beyond? Let us see whether the bigram frequency behaves in the same manner as the unigram frequency.

String[] bigramTokens = new String[corpus.size()-1];
for (int i = 0; i < bigramTokens.length; i++) {
    bigramTokens[i] = corpus.get(i) + " " + corpus.get(i+1);
}
Vocab bigramVocab = new Vocab(new String[][] {bigramTokens}, -1, new String[0]);
for (int i = 0; i < 10; i++) {
    Map.Entry<String, Integer> token = bigramVocab.tokenFreqs.get(i);
    System.out.println(token.getKey() + ": " + token.getValue());
}
of the: 309
in the: 169
i had: 130
i was: 112
and the: 109
the time: 102
it was: 99
to the: 85
as i: 78
of a: 73

One thing is notable here. Out of the ten most frequent word pairs, nine consist of two stop words and only one is relevant to the actual book: “the time”. Furthermore, let us see whether the trigram frequency behaves in the same manner.

String[] trigramTokens = new String[corpus.size()-2];
for (int i = 0; i < trigramTokens.length; i++) {
    trigramTokens[i] = corpus.get(i) + " " + corpus.get(i+1) + " " + corpus.get(i+2);
}
Vocab trigramVocab = new Vocab(new String[][] {trigramTokens}, -1, new String[0]);
for (int i = 0; i < 10; i++) {
    Map.Entry<String, Integer> token = trigramVocab.tokenFreqs.get(i);
    System.out.println(token.getKey() + ": " + token.getValue());
}
the time traveller: 59
the time machine: 30
  : 26
the medical man: 24
it seemed to: 16
it was a: 15
here and there: 15
seemed to me: 14
i did not: 14
i saw the: 13

Last, let us visualize the token frequency among these three models: unigrams, bigrams, and trigrams.

n = bigramVocab.tokenFreqs.size();
double[] bigramFreqs = new double[n];
double[] bigramX = new double[n];
for (int i = 0; i < n; i++) {
    bigramFreqs[i] = (double) bigramVocab.tokenFreqs.get(i).getValue();
    bigramX[i] = (double) i;
}

n = trigramVocab.tokenFreqs.size();
double[] trigramFreqs = new double[n];
double[] trigramX = new double[n];
for (int i = 0; i < n; i++) {
    trigramFreqs[i] = (double) trigramVocab.tokenFreqs.get(i).getValue();
    trigramX[i] = (double) i;
}

PlotUtils.plotLogScale(new double[][] {x, bigramX, trigramX}, new double[][] {freqs, bigramFreqs, trigramFreqs},
                       new String[] {"unigram", "bigram", "trigram"}, "token: x", "frequency: n(x)");

This figure is quite exciting for a number of reasons. First, beyond unigram words, sequences of words also appear to follow Zipf’s law, albeit with a smaller exponent \(\alpha\) in (8.3.7), depending on the sequence length. Second, the number of distinct \(n\)-grams is not that large. This gives us hope that there is quite a lot of structure in language. Third, many \(n\)-grams occur very rarely, which makes Laplace smoothing rather unsuitable for language modeling. Instead, we will use deep learning based models.
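
As a rough check of the first point, the sketch below estimates the exponent \(\alpha\) of (8.3.8) by an ordinary least-squares fit of \(\log n_i\) against \(\log i\) over the tail of each curve; skipping the first ten ranks and the fitting procedure itself are our own illustrative choices, not something the book prescribes.

// Least-squares slope of ys against xs
java.util.function.BiFunction<double[], double[], Double> fitSlope = (xs, ys) -> {
    double meanX = Arrays.stream(xs).average().getAsDouble();
    double meanY = Arrays.stream(ys).average().getAsDouble();
    double num = 0;
    double den = 0;
    for (int i = 0; i < xs.length; i++) {
        num += (xs[i] - meanX) * (ys[i] - meanY);
        den += (xs[i] - meanX) * (xs[i] - meanX);
    }
    return num / den;
};
// Fit log-frequency against log-rank, treating the most frequent tokens as exceptions
java.util.function.Function<double[], Double> zipfAlpha = f -> {
    int start = 10;
    double[] logRank = new double[f.length - start];
    double[] logFreq = new double[f.length - start];
    for (int i = start; i < f.length; i++) {
        logRank[i - start] = Math.log(i + 1.0);  // 1-based rank
        logFreq[i - start] = Math.log(f[i]);
    }
    return -fitSlope.apply(logRank, logFreq);    // the slope in (8.3.8) is -alpha
};
System.out.println("alpha (unigrams): " + zipfAlpha.apply(freqs));
System.out.println("alpha (bigrams): " + zipfAlpha.apply(bigramFreqs));
System.out.println("alpha (trigrams): " + zipfAlpha.apply(trigramFreqs));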

8.3.4. Reading Long Sequence Data

Since sequence data are by their very nature sequential, we need to address the issue of processing them. We did so in a rather ad hoc manner in Section 8.1. When sequences get too long to be processed by models all at once, we may wish to split such sequences for reading. Now let us describe general strategies. Before introducing the model, let us assume that we will use a neural network to train a language model, where the network processes a minibatch of sequences with predefined length, say \(n\) time steps, at a time. Now the question is how to read minibatches of features and labels at random.

To begin with, since a text sequence can be arbitrarily long, such as the entire The Time Machine book, we can partition such a long sequence into subsequences with the same number of time steps. When training our neural network, a minibatch of such subsequences will be fed into the model. Suppose that the network processes a subsequence of \(n\) time steps at a time. fig_timemachine_5gram shows all the different ways to obtain subsequences from an original text sequence, where \(n=5\) and a token at each time step corresponds to a character. Note that we have quite some freedom since we could pick an arbitrary offset that indicates the initial position.

Fig. fig_timemachine_5gram: Different offsets lead to different subsequences when splitting up text.

Hence, which one should we pick from fig_timemachine_5gram? In fact, all of them are equally good. However, if we pick just one offset, there is limited coverage of all the possible subsequences for training our network. Therefore, we can start with a random offset to partition a sequence to get both coverage and randomness. In the following, we describe how to accomplish this for both random sampling and sequential partitioning strategies.

8.3.4.1. Random Sampling

In random sampling, each example is a subsequence arbitrarily captured from the original long sequence. The subsequences from two adjacent random minibatches during iteration are not necessarily adjacent on the original sequence. For language modeling, the target is to predict the next token based on the tokens we have seen so far; hence the labels are the original sequence, shifted by one token.

The following code randomly generates a minibatch from the data each time. Here, the argument batchSize specifies the number of subsequence examples in each minibatch and numSteps is the predefined number of time steps in each subsequence.

/**
 * Generate a minibatch of subsequences using random sampling.
 */
public ArrayList<NDList>
        seqDataIterRandom(List<Integer> corpus, int batchSize, int numSteps, NDManager manager) {
    // Start with a random offset (inclusive of `numSteps - 1`) to partition a
    // sequence
    corpus = corpus.subList(new Random().nextInt(numSteps), corpus.size());
    // Subtract 1 since we need to account for labels
    int numSubseqs = (corpus.size() - 1) / numSteps;
    // The starting indices for subsequences of length `numSteps`
    List<Integer> initialIndices = new ArrayList<>();
    for (int i = 0; i < numSubseqs * numSteps; i += numSteps) {
        initialIndices.add(i);
    }
    // In random sampling, the subsequences from two adjacent random
    // minibatches during iteration are not necessarily adjacent on the
    // original sequence
    Collections.shuffle(initialIndices);

    int numBatches = numSubseqs / batchSize;

    ArrayList<NDList> pairs = new ArrayList<NDList>();
    for (int i = 0; i < batchSize * numBatches; i += batchSize) {
        // Here, `initialIndices` contains randomized starting indices for
        // subsequences
        List<Integer> initialIndicesPerBatch = initialIndices.subList(i, i + batchSize);

        NDArray xNDArray = manager.create(new Shape(batchSize, numSteps), DataType.INT32);
        NDArray yNDArray = manager.create(new Shape(batchSize, numSteps), DataType.INT32);
        for (int j = 0; j < initialIndicesPerBatch.size(); j++) {
            ArrayList<Integer> X = data(initialIndicesPerBatch.get(j), corpus, numSteps);
            xNDArray.set(new NDIndex(j), manager.create(X.stream().mapToInt(Integer::intValue).toArray()));
            ArrayList<Integer> Y = data(initialIndicesPerBatch.get(j) + 1, corpus, numSteps);
            yNDArray.set(new NDIndex(j), manager.create(Y.stream().mapToInt(Integer::intValue).toArray()));
        }
        NDList pair = new NDList();
        pair.add(xNDArray);
        pair.add(yNDArray);
        pairs.add(pair);
    }
    return pairs;
}

ArrayList<Integer> data(int pos, List<Integer> corpus, int numSteps) {
    // Return a sequence of length `numSteps` starting from `pos`
    return new ArrayList<Integer>(corpus.subList(pos, pos + numSteps));
}

Let us manually generate a sequence from 0 to 34. We assume that the batch size and the number of time steps are 2 and 5, respectively. This means that we can generate \(\lfloor (35 - 1) / 5 \rfloor = 6\) feature-label subsequence pairs. With a minibatch size of 2, we only get 3 minibatches.

List<Integer> mySeq = new ArrayList<>();
for (int i = 0; i < 35; i++) {
    mySeq.add(i);
}

for (NDList pair : seqDataIterRandom(mySeq, 2, 5, manager)) {
    System.out.println("X:\n" + pair.get(0).toDebugString(50, 50, 50, 50, true));
    System.out.println("Y:\n" + pair.get(1).toDebugString(50, 50, 50, 50, true));
}
X:
ND: (2, 5) gpu(0) int32
[[16, 17, 18, 19, 20],
 [26, 27, 28, 29, 30],
]

Y:
ND: (2, 5) gpu(0) int32
[[17, 18, 19, 20, 21],
 [27, 28, 29, 30, 31],
]

X:
ND: (2, 5) gpu(0) int32
[[21, 22, 23, 24, 25],
 [ 6,  7,  8,  9, 10],
]

Y:
ND: (2, 5) gpu(0) int32
[[22, 23, 24, 25, 26],
 [ 7,  8,  9, 10, 11],
]

X:
ND: (2, 5) gpu(0) int32
[[ 1,  2,  3,  4,  5],
 [11, 12, 13, 14, 15],
]

Y:
ND: (2, 5) gpu(0) int32
[[ 2,  3,  4,  5,  6],
 [12, 13, 14, 15, 16],
]

8.3.4.2. Sequential Partitioning

In addition to random sampling of the original sequence, we can also ensure that the subsequences from two adjacent minibatches during iteration are adjacent on the original sequence. This strategy preserves the order of split subsequences when iterating over minibatches, hence is called sequential partitioning.

/**
 * Generate a minibatch of subsequences using sequential partitioning.
 */
public ArrayList<NDList> seqDataIterSequential(List<Integer> corpus, int batchSize, int numSteps,
                                               NDManager manager) {
    // Start with a random offset to partition a sequence
    int offset = new Random().nextInt(numSteps);
    int numTokens = ((corpus.size() - offset - 1) / batchSize) * batchSize;

    NDArray Xs = manager.create(
        corpus.subList(offset, offset + numTokens).stream().mapToInt(Integer::intValue).toArray());
    NDArray Ys = manager.create(
        corpus.subList(offset + 1, offset + 1 + numTokens).stream().mapToInt(Integer::intValue).toArray());
    Xs = Xs.reshape(new Shape(batchSize, -1));
    Ys = Ys.reshape(new Shape(batchSize, -1));
    int numBatches = (int) Xs.getShape().get(1) / numSteps;


    ArrayList<NDList> pairs = new ArrayList<NDList>();
    for (int i = 0; i < numSteps * numBatches; i += numSteps) {
        NDArray X = Xs.get(new NDIndex(":, {}:{}", i, i + numSteps));
        NDArray Y = Ys.get(new NDIndex(":, {}:{}", i, i + numSteps));
        NDList pair = new NDList();
        pair.add(X);
        pair.add(Y);
        pairs.add(pair);
    }
    return pairs;
}

Using the same settings, let us print features X and labels Y for each minibatch of subsequences read by sequential partitioning. Note that the subsequences from two adjacent minibatches during iteration are indeed adjacent on the original sequence.

for (NDList pair : seqDataIterSequential(mySeq, 2, 5, manager)) {
    System.out.println("X:\n" + pair.get(0).toDebugString(10, 10, 10, 10, true));
    System.out.println("Y:\n" + pair.get(1).toDebugString(10, 10, 10, 10, true));
}
X:
ND: (2, 5) gpu(0) int32
[[ 4,  5,  6,  7,  8],
 [19, 20, 21, 22, 23],
]

Y:
ND: (2, 5) gpu(0) int32
[[ 5,  6,  7,  8,  9],
 [20, 21, 22, 23, 24],
]

X:
ND: (2, 5) gpu(0) int32
[[ 9, 10, 11, 12, 13],
 [24, 25, 26, 27, 28],
]

Y:
ND: (2, 5) gpu(0) int32
[[10, 11, 12, 13, 14],
 [25, 26, 27, 28, 29],
]

X:
ND: (2, 5) gpu(0) int32
[[14, 15, 16, 17, 18],
 [29, 30, 31, 32, 33],
]

Y:
ND: (2, 5) gpu(0) int32
[[15, 16, 17, 18, 19],
 [30, 31, 32, 33, 34],
]

Now we wrap the above two sampling functions into a class so that we can use it as a data iterator later.

public class SeqDataLoader implements Iterable<NDList> {
    public ArrayList<NDList> dataIter;
    public List<Integer> corpus;
    public Vocab vocab;
    public int batchSize;
    public int numSteps;

    /**
     * An iterator to load sequence data.
     */
    @SuppressWarnings("unchecked")
    public SeqDataLoader(int batchSize, int numSteps, boolean useRandomIter, int maxTokens) throws IOException, Exception {
        Pair<List<Integer>, Vocab> corpusVocabPair = TimeMachine.loadCorpusTimeMachine(maxTokens);
        this.corpus = corpusVocabPair.getKey();
        this.vocab = corpusVocabPair.getValue();

        this.batchSize = batchSize;
        this.numSteps = numSteps;
        if (useRandomIter) {
            dataIter = seqDataIterRandom(corpus, batchSize, numSteps, manager);
        } else {
            dataIter = seqDataIterSequential(corpus, batchSize, numSteps, manager);
        }
    }

    @Override
    public Iterator<NDList> iterator() {
        return dataIter.iterator();
    }
}

Last, we define a function loadDataTimeMachine that returns both the data iterator and the vocabulary.

/**
 * Return the iterator and the vocabulary of the time machine dataset.
 */
public Pair<ArrayList<NDList>, Vocab> loadDataTimeMachine(int batchSize, int numSteps, boolean useRandomIter, int maxTokens) throws IOException, Exception {
    SeqDataLoader seqData = new SeqDataLoader(batchSize, numSteps, useRandomIter, maxTokens);
    return new Pair(seqData.dataIter, seqData.vocab); // ArrayList<NDList>, Vocab
}
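
As a quick sanity check, we can call this function once and inspect what it returns; the batch size of 32, the 35 time steps, and the limit of 10000 tokens below are arbitrary illustrative values.

Pair<ArrayList<NDList>, Vocab> dataVocabPair = loadDataTimeMachine(32, 35, false, 10000);
ArrayList<NDList> dataIter = dataVocabPair.getKey();
System.out.println("number of minibatches: " + dataIter.size());
System.out.println("shape of X in the first minibatch: " + dataIter.get(0).get(0).getShape());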

8.3.5. Summary

  • Language models are key to natural language processing.

  • \(n\)-grams provide a convenient model for dealing with long sequences by truncating the dependence.

  • Long sequences suffer from the problem that they occur very rarely or never.

  • Zipf’s law governs the word distribution for not only unigrams but also the other \(n\)-grams.

  • There is a lot of structure but not enough frequency to deal with infrequent word combinations efficiently via Laplace smoothing.

  • The main choices for reading long sequences are random sampling and sequential partitioning. The latter can ensure that the subsequences from two adjacent minibatches during iteration are adjacent on the original sequence.

8.3.6. Exercises

  1. Suppose there are \(100,000\) words in the training dataset. How many word frequencies and multi-word adjacent frequencies does a four-gram need to store?

  2. How would you model a dialogue?

  3. Estimate the exponent of Zipf’s law for unigrams, bigrams, and trigrams.

  4. What other methods can you think of for reading long sequence data?

  5. Consider the random offset that we use for reading long sequences.

    1. Why is it a good idea to have a random offset?

    2. Does it really lead to a perfectly uniform distribution over the sequences on the document?

    3. What would you have to do to make things even more uniform?

  6. If we want a sequence example to be a complete sentence, what kind of problem does this introduce in minibatch sampling? How can we fix the problem?