
14.5. Subword Embedding

Many English words have internal structure and are formed according to regular patterns. For example, we can deduce the relationship between “dog”, “dogs”, and “dogcatcher” from their spelling. All these words share the root “dog”, but they add different endings to change its meaning. Moreover, this association extends to other words. For example, the relationship between “dog” and “dogs” is just like the relationship between “cat” and “cats”, and the relationship between “boy” and “boyfriend” is just like the relationship between “girl” and “girlfriend”. This characteristic is not unique to English. In French and Spanish, many verbs have more than 40 different forms depending on the context. In Finnish, a noun may have more than 15 forms. Morphology, an important branch of linguistics, studies the internal structure and formation of words.

14.5.1. fastText

In word2vec, we did not directly use morphological information. In both the skip-gram model and the continuous bag-of-words model, words with different forms are represented by different vectors. For example, “dog” and “dogs” are represented by two different vectors, and the relationship between these two vectors is not directly represented in the model. To address this, fastText [Bojanowski et al., 2017] proposes subword embedding, which introduces morphological information into the skip-gram model of word2vec.

In fastText, each central word is represented as a collection of subwords. Below we use the word “where” as an example to understand how subwords are formed. First, we add the special characters “<” and “>” at the beginning and end of the word to distinguish the subwords used as prefixes and suffixes. Then, we treat the word as a sequence of characters to extract the \(n\)-grams. For example, when \(n=3\), we can get all subwords with a length of \(3\):

\[\textrm{"<wh"}, \ \textrm{"whe"}, \ \textrm{"her"}, \ \textrm{"ere"}, \ \textrm{"re>"}, \tag{14.5.1}\]

and the special subword \(\textrm{"<where>"}\).
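The construction above can be sketched in plain Java. This is a minimal, standalone illustration that does not depend on the notebook's DJL imports; the class name SubwordNgrams and the method ngrams are hypothetical names chosen for this example, and only a single n-gram length n is handled here.

```java
import java.util.ArrayList;
import java.util.List;

class SubwordNgrams {
    // Wraps the word in the boundary markers "<" and ">", collects every
    // character n-gram of length n, and appends the special subword "<word>"
    static List<String> ngrams(String word, int n) {
        String wrapped = "<" + word + ">";
        List<String> result = new ArrayList<>();
        for (int i = 0; i + n <= wrapped.length(); i++) {
            result.add(wrapped.substring(i, i + n));
        }
        result.add(wrapped);
        return result;
    }

    public static void main(String[] args) {
        // Prints [<wh, whe, her, ere, re>, <where>]
        System.out.println(ngrams("where", 3));
    }
}
```

Because of the markers, the prefix "<wh" and the suffix "re>" are different subwords from any n-gram taken from the middle of a word.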

In fastText, for a word \(w\), we denote by \(\mathcal{G}_w\) the union of all its subwords of length \(3\) to \(6\) and its special subword. The dictionary is then the union of the subword sets of all words. Let \(\mathbf{z}_g\) be the vector of the subword \(g\) in the dictionary. Then, the central word vector \(\mathbf{u}_w\) for the word \(w\) in the skip-gram model can be expressed as

\[\mathbf{u}_w = \sum_{g\in\mathcal{G}_w} \mathbf{z}_g. \tag{14.5.2}\]
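The following standalone sketch builds \(\mathcal{G}_w\) for subword lengths \(3\) to \(6\) and computes \(\mathbf{u}_w\) as the sum of the subword vectors. The class FastTextCenterVector, its methods, and the randomly filled vectors in main are hypothetical and only illustrate the formula; in an actual model the vectors \(\mathbf{z}_g\) would be learned during training.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;

class FastTextCenterVector {
    // G_w: all n-grams of "<w>" with minN <= n <= maxN, plus the special subword "<w>".
    // A set is used because G_w is a union, so each subword is stored only once.
    static Set<String> subwords(String word, int minN, int maxN) {
        String wrapped = "<" + word + ">";
        Set<String> g = new LinkedHashSet<>();
        for (int n = minN; n <= maxN; n++) {
            for (int i = 0; i + n <= wrapped.length(); i++) {
                g.add(wrapped.substring(i, i + n));
            }
        }
        g.add(wrapped);
        return g;
    }

    // u_w = sum of z_g over g in G_w, with subword lengths 3 to 6 as in the text.
    // Assumes z holds a vector of length dim for every subword in G_w.
    static float[] centerVector(String word, Map<String, float[]> z, int dim) {
        float[] u = new float[dim];
        for (String g : subwords(word, 3, 6)) {
            float[] zg = z.get(g);
            for (int k = 0; k < dim; k++) {
                u[k] += zg[k];
            }
        }
        return u;
    }

    public static void main(String[] args) {
        int dim = 4;
        Random rng = new Random(0);
        // Hypothetical subword vectors z_g, filled with random values for illustration
        Map<String, float[]> z = new HashMap<>();
        Set<String> g = subwords("where", 3, 6);
        for (String s : g) {
            float[] v = new float[dim];
            for (int k = 0; k < dim; k++) {
                v[k] = (float) rng.nextGaussian();
            }
            z.put(s, v);
        }
        System.out.println(g.size() + " subwords: " + g);
        System.out.println("u_where = " + Arrays.toString(centerVector("where", z, dim)));
    }
}
```

For “where”, \(\mathcal{G}_w\) has 15 elements: 14 n-grams of lengths 3 to 6 taken from “<where>”, plus the special subword “<where>” itself.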

The rest of the fastText process is the same as in the skip-gram model, so we do not repeat it here. Compared with the skip-gram model, the dictionary in fastText is larger, resulting in more model parameters. Also, computing the vector of a word requires summing all of its subword vectors, which increases the computational cost. In return, we can obtain better vectors for rare, morphologically complex words, and even for words that do not appear in the dictionary, because such words share subwords with other words of similar structure.
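Both effects can be seen on a tiny vocabulary. The standalone sketch below, with hypothetical class and method names, counts how many subwords four words produce and lists which subwords of the unseen word “dogcatcher” already occur in the dictionary. It assumes, as a simplification, that subwords missing from the dictionary are skipped, so that an unseen word's vector would be the sum of the vectors of its known subwords.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

class SubwordDictionary {
    // G_w as in the text: n-grams of "<w>" of length 3 to 6, plus "<w>" itself
    static Set<String> subwords(String word) {
        String wrapped = "<" + word + ">";
        Set<String> g = new LinkedHashSet<>();
        for (int n = 3; n <= 6; n++) {
            for (int i = 0; i + n <= wrapped.length(); i++) {
                g.add(wrapped.substring(i, i + n));
            }
        }
        g.add(wrapped);
        return g;
    }

    // The subword dictionary is the union of G_w over all words
    static Set<String> buildDictionary(List<String> words) {
        Set<String> dictionary = new LinkedHashSet<>();
        for (String w : words) {
            dictionary.addAll(subwords(w));
        }
        return dictionary;
    }

    // Subwords of a (possibly unseen) word that already exist in the dictionary
    static Set<String> knownSubwords(String word, Set<String> dictionary) {
        Set<String> known = new LinkedHashSet<>(subwords(word));
        known.retainAll(dictionary);
        return known;
    }

    public static void main(String[] args) {
        List<String> words = List.of("dog", "dogs", "cat", "cats");
        Set<String> dictionary = buildDictionary(words);
        // Prints "4 words, 26 subwords"
        System.out.println(words.size() + " words, " + dictionary.size() + " subwords");
        // "dogcatcher" is not a training word; prints [<do, dog, cat, <dog]
        System.out.println(knownSubwords("dogcatcher", dictionary));
    }
}
```

Four words already yield 26 subwords, each of which would need its own vector, while the unseen “dogcatcher” still shares four subwords with the training words.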

14.5.2. Byte Pair Encoding

In fastText, all the extracted subwords must have one of the specified lengths, such as \(3\) to \(6\), so the vocabulary size cannot be fixed in advance. To allow for variable-length subwords in a fixed-size vocabulary, we can apply a compression algorithm called byte pair encoding (BPE) to extract subwords [Sennrich et al., 2015].

Byte pair encoding performs a statistical analysis of the training dataset to discover common symbols within a word, such as consecutive characters of arbitrary length. Starting from symbols of length \(1\), byte pair encoding iteratively merges the most frequent pair of consecutive symbols to produce new, longer symbols. Note that for efficiency, pairs crossing word boundaries are not considered. Finally, we can use such symbols as subwords to segment words. Byte pair encoding and its variants have been used for input representations in popular natural language processing pretraining models such as GPT-2 [Radford et al., 2019] and RoBERTa [Liu et al., 2019]. In the following, we illustrate how byte pair encoding works.

First, we initialize the vocabulary of symbols as all the English lowercase characters, a special end-of-word symbol '_', and a special unknown symbol '[UNK]'.

%load ../utils/djl-imports
%load ../utils/Functions.java
import java.util.stream.*;
NDManager manager = NDManager.newBaseManager();
String[] symbols =
        new String[] {
            "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p",
            "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "_", "[UNK]"
        };

Since we do not consider symbol pairs that cross word boundaries, we only need a dictionary rawTokenFreqs that maps words to their frequencies (number of occurrences) in a dataset. Note that the special symbol '_' is appended to each word so that we can easily recover a word sequence (e.g., “a taller man”) from a sequence of output symbols (e.g., “a_ tall er_ man_”). Since we start the merging process from a vocabulary of only single characters and special symbols, a space is inserted between every pair of consecutive characters within each word (the keys of the dictionary tokenFreqs). In other words, the space is the delimiter between symbols within a word.

HashMap<String, Integer> rawTokenFreqs = new HashMap<>();
rawTokenFreqs.put("fast_", 4);
rawTokenFreqs.put("faster_", 3);
rawTokenFreqs.put("tall_", 5);
rawTokenFreqs.put("taller_", 4);

HashMap<String, Integer> tokenFreqs = new HashMap<>();
for (Map.Entry<String, Integer> e : rawTokenFreqs.entrySet()) {
    String token = e.getKey();
    tokenFreqs.put(String.join(" ", token.split("")), rawTokenFreqs.get(token));
}

tokenFreqs
{f a s t e r _=3, t a l l e r _=4, f a s t _=4, t a l l _=5}
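The end-of-word symbol '_' is what makes segmented output reversible. The standalone sketch below, with the hypothetical name EndOfWordDecoder, assumes that '_' never occurs inside a word: removing the spaces between symbols glues each word back together, and every '_' then marks a word boundary.

```java
class EndOfWordDecoder {
    // Rebuilds a word sequence from space-separated output symbols: dropping the
    // spaces joins the symbols of each word, and each "_" becomes a word separator
    static String decode(String outputSymbols) {
        return outputSymbols.replace(" ", "").replace("_", " ").trim();
    }

    public static void main(String[] args) {
        // Prints "a taller man"
        System.out.println(decode("a_ tall er_ man_"));
    }
}
```

Without the '_' symbol, the output “a tall er man” would not tell us whether “tall” and “er” belong to one word or to two.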

We define the following getMaxFreqPair function that returns the most frequent pair of consecutive symbols within a word, where words come from keys of the input dictionary tokenFreqs.

public static Pair<String, String> getMaxFreqPair(HashMap<String, Integer> tokenFreqs) {
    // Key of `pairs` is a pair of two consecutive symbols; value is its total frequency
    HashMap<Pair<String, String>, Integer> pairs = new HashMap<>();
    for (Map.Entry<String, Integer> e : tokenFreqs.entrySet()) {
        String token = e.getKey();
        Integer freq = e.getValue();
        String[] symbols = token.split(" ");
        for (int i = 0; i < symbols.length - 1; i++) {
            Pair<String, String> pair = new Pair<>(symbols[i], symbols[i + 1]);
            // Each occurrence contributes the frequency of the word it appears in
            pairs.put(pair, pairs.getOrDefault(pair, 0) + freq);
        }
    }
    // Return the key of `pairs` with the max value (the first one visited in case of ties)
    int max = 0;
    Pair<String, String> maxFreqPair = null;
    for (Map.Entry<Pair<String, String>, Integer> pair : pairs.entrySet()) {
        if (max < pair.getValue()) {
            max = pair.getValue();
            maxFreqPair = pair.getKey();
        }
    }
    return maxFreqPair;
}
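To see which pairs compete in the first iteration, the standalone sketch below counts every pair of consecutive symbols in the toy dataset. Unlike getMaxFreqPair, it keys each pair as a hypothetical string "x y" instead of a DJL Pair (unambiguous because symbols contain no spaces) and stores the counts in a TreeMap so the result is sorted; the class name PairCounts is also hypothetical.

```java
import java.util.Collections;
import java.util.Map;
import java.util.TreeMap;

class PairCounts {
    // Total frequency of every pair of consecutive symbols, keyed as "x y"
    static Map<String, Integer> countPairs(Map<String, Integer> tokenFreqs) {
        Map<String, Integer> pairs = new TreeMap<>(); // sorted for readable output
        for (Map.Entry<String, Integer> e : tokenFreqs.entrySet()) {
            String[] symbols = e.getKey().split(" ");
            for (int i = 0; i + 1 < symbols.length; i++) {
                pairs.merge(symbols[i] + " " + symbols[i + 1], e.getValue(), Integer::sum);
            }
        }
        return pairs;
    }

    public static void main(String[] args) {
        Map<String, Integer> tokenFreqs = Map.of(
                "f a s t _", 4, "f a s t e r _", 3, "t a l l _", 5, "t a l l e r _", 4);
        Map<String, Integer> pairs = countPairs(tokenFreqs);
        int max = Collections.max(pairs.values());
        // Print every pair that reaches the maximum frequency
        pairs.forEach((pair, count) -> {
            if (count == max) {
                System.out.println("(" + pair.replace(" ", ", ") + "): " + count);
            }
        });
    }
}
```

For the toy dataset this prints three tied pairs, (a, l), (l, l), and (t, a), each with frequency 9. Because getMaxFreqPair keeps the first maximum it meets while iterating over a HashMap, which of these pairs is merged first depends on that iteration order rather than on the data alone.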

As a greedy approach based on frequency of consecutive symbols, byte pair encoding will use the following mergeSymbols function to merge the most frequent pair of consecutive symbols to produce new symbols.

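public static Pair<HashMap<String, Integer>, String[]> mergeSymbols(
        Pair<String, String> maxFreqPair, HashMap<String, Integer> tokenFreqs) {
    // The new symbol is the concatenation of the two symbols in the pair
    String merged = maxFreqPair.getKey() + maxFreqPair.getValue();

    HashMap<String, Integer> newTokenFreqs = new HashMap<>();
    for (Map.Entry<String, Integer> e : tokenFreqs.entrySet()) {
        // Compare whole symbols rather than calling String.replace on the joined
        // token, which could match across symbol boundaries (e.g., "ca b" contains "a b")
        String[] parts = e.getKey().split(" ");
        List<String> newParts = new ArrayList<>();
        int i = 0;
        while (i < parts.length) {
            if (i + 1 < parts.length
                    && parts[i].equals(maxFreqPair.getKey())
                    && parts[i + 1].equals(maxFreqPair.getValue())) {
                newParts.add(merged); // Merge the pair into one symbol
                i += 2;
            } else {
                newParts.add(parts[i]);
                i += 1;
            }
        }
        newTokenFreqs.put(String.join(" ", newParts), e.getValue());
    }
    return new Pair<>(newTokenFreqs, new String[] {merged});
}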

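Now we iteratively perform the byte pair encoding algorithm over the keys of the dictionary tokenFreqs. In the first iteration, three pairs of consecutive symbols tie for the highest frequency: ('t', 'a'), ('a', 'l'), and ('l', 'l'), each with a total frequency of 9 (5 from 'tall_' and 4 from 'taller_'). getMaxFreqPair returns whichever of them the HashMap iteration visits first; in the run shown here this is ('l', 'l'), so byte pair encoding merges them into the new symbol 'll'. In the second iteration, byte pair encoding merges 'a' and 'll' into another new symbol 'all'.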

int numMerges = 10;
for (int i = 0; i < numMerges; i++) {
    Pair<String, String> maxFreqPair = getMaxFreqPair(tokenFreqs);
    Pair<HashMap<String, Integer>, String[]> pair =
            mergeSymbols(maxFreqPair, tokenFreqs);
    tokenFreqs = pair.getKey();
    symbols =
            Stream.concat(Arrays.stream(symbols), Arrays.stream(pair.getValue()))
                    .toArray(String[]::new);
    System.out.println(
            "merge #"
                    + (i + 1)
                    + ": ("
                    + maxFreqPair.getKey()
                    + ", "
                    + maxFreqPair.getValue()
                    + ")");
}
merge #1: (l, l)
merge #2: (a, ll)
merge #3: (t, all)
merge #4: (s, t)
merge #5: (a, st)
merge #6: (f, ast)
merge #7: (e, r)
merge #8: (er, _)
merge #9: (tall, _)
merge #10: (tall, er_)

After 10 iterations of byte pair encoding, the array symbols contains 10 more symbols, each obtained by merging two existing symbols.

Arrays.toString(symbols)
[a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, _, [UNK], ll, all, tall, st, ast, fast, er, er_, tall_, taller_]

For the same dataset specified in the keys of the dictionary rawTokenFreqs, each word is now segmented into the subwords “fast”, “_”, “er_”, “tall_”, and “taller_” as a result of the byte pair encoding algorithm. For instance, “faster_” is segmented as “fast er_”, while “taller_” has been merged into a single symbol “taller_”.

tokenFreqs.keySet()
[fast _, tall_, taller_, fast er_]

Note that the result of byte pair encoding depends on the dataset being used. We can also use the subwords learned from one dataset to segment words of another dataset. As a greedy approach, the following segmentBPE function tries to break words into the longest possible subwords from the input argument symbols.

public static List<String> segmentBPE(String[] tokens, String[] symbols) {
    // Put the symbols in a set for fast membership tests
    Set<String> symbolSet = new HashSet<>(Arrays.asList(symbols));
    List<String> outputs = new ArrayList<>();
    for (String token : tokens) {
        int start = 0;
        int end = token.length();
        List<String> curOutput = new ArrayList<>();
        // Segment token with the longest possible subwords from symbols
        while (start < token.length() && start < end) {
            String candidate = token.substring(start, end);
            if (symbolSet.contains(candidate)) {
                // Accept the longest match and continue right after it
                curOutput.add(candidate);
                start = end;
                end = token.length();
            } else {
                // Try a shorter candidate starting at the same position
                end -= 1;
            }
        }
        // If part of the token could not be matched, mark it as unknown
        if (start < token.length()) {
            curOutput.add("[UNK]");
        }
        outputs.add(String.join(" ", curOutput));
    }
    return outputs;
}

In the following, we use the subwords in the array symbols, which were learned from the aforementioned dataset, to segment tokens that represent another dataset.

String[] tokens = new String[] {"tallest_", "fatter_"};
System.out.println(segmentBPE(tokens, symbols));
[tall e st _, f a t t er_]

14.5.3. Summary

  • FastText proposes a subword embedding method. Based on the skip-gram model in word2vec, it represents the central word vector as the sum of the subword vectors of the word.

  • Subword embedding utilizes the principles of morphology, which usually improves the quality of representations of uncommon words.

  • Byte pair encoding performs a statistical analysis of the training dataset to discover common symbols within a word. As a greedy approach, byte pair encoding iteratively merges the most frequent pair of consecutive symbols.

14.5.4. Exercises

  1. When there are too many subwords (for example, there are \(26^6 \approx 3\times 10^8\) possible subwords of length 6 over the 26 English lowercase letters), what problems arise? Can you think of any methods to solve them? Hint: Refer to the end of section 3.2 of the fastText paper [Bojanowski et al., 2017].

  2. How can you design a subword embedding model based on the continuous bag-of-words model?

  3. To get a vocabulary of size \(m\), how many merging operations are needed when the initial symbol vocabulary size is \(n\)?

  4. How can we extend the idea of byte pair encoding to extract phrases?