14.5. Subword Embedding
English words usually have internal structures and formation methods. For example, we can deduce the relationship between “dog”, “dogs”, and “dogcatcher” by their spelling. All these words have the same root, “dog”, but they use different suffixes to change the meaning of the word. Moreover, this association can be extended to other words. For example, the relationship between “dog” and “dogs” is just like the relationship between “cat” and “cats”. The relationship between “boy” and “boyfriend” is just like the relationship between “girl” and “girlfriend”. This characteristic is not unique to English. In French and Spanish, a lot of verbs can have more than 40 different forms depending on the context. In Finnish, a noun may have more than 15 forms. In fact, morphology, which is an important branch of linguistics, studies the internal structure and formation of words.
14.5.1. fastText
In word2vec, we do not directly use morphological information. In both the skip-gram model and the continuous bag-of-words model, we use different vectors to represent words with different forms. For example, “dog” and “dogs” are represented by two different vectors, while the relationship between these two vectors is not directly represented in the model. In view of this, fastText [Bojanowski et al., 2017] proposes a subword embedding method, thereby introducing morphological information into the skip-gram model of word2vec.
In fastText, each central word is represented as a collection of subwords. Below we use the word “where” as an example to understand how subwords are formed. First, we add the special characters “<” and “>” at the beginning and end of the word to distinguish the subwords used as prefixes and suffixes. Then, we treat the word as a sequence of characters to extract the \(n\)-grams. For example, when \(n=3\), we can get all subwords with a length of \(3\): \(\textrm{"<wh"}\), \(\textrm{"whe"}\), \(\textrm{"her"}\), \(\textrm{"ere"}\), \(\textrm{"re>"}\), and the special subword \(\textrm{"<where>"}\).
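To make the construction concrete, the following short sketch (our own illustration, not part of the original fastText implementation) extracts all character \(n\)-grams of a given length from a word after adding the boundary markers:
import java.util.*;

// Illustrative sketch: extract all character n-grams of length n from a word
// after adding the boundary markers "<" and ">"
public static List<String> extractSubwords(String word, int n) {
    String marked = "<" + word + ">";
    List<String> subwords = new ArrayList<>();
    for (int i = 0; i + n <= marked.length(); i++) {
        subwords.add(marked.substring(i, i + n));
    }
    return subwords;
}

System.out.println(extractSubwords("where", 3)); // [<wh, whe, her, ere, re>]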
In fastText, for a word \(w\), we denote by \(\mathcal{G}_w\) the union of all its subwords of length \(3\) to \(6\) and its special subword. Thus, the dictionary is the union of the collections of subwords of all words. Assume the vector of the subword \(g\) in the dictionary is \(\mathbf{z}_g\). Then, the central word vector \(\mathbf{u}_w\) for the word \(w\) in the skip-gram model can be expressed as

\[\mathbf{u}_w = \sum_{g\in\mathcal{G}_w} \mathbf{z}_g.\]
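As a minimal sketch of this summation (our own illustration; the subword vectors are assumed to come from a hypothetical lookup table, one row per subword in \(\mathcal{G}_w\)):
// Illustrative sketch only: the central word vector u_w is the sum of the
// vectors z_g of all subwords g in G_w. Each row of `subwordVecs` is one z_g.
public static float[] wordVector(float[][] subwordVecs) {
    int dim = subwordVecs[0].length;
    float[] uw = new float[dim];
    for (float[] zg : subwordVecs) {
        for (int i = 0; i < dim; i++) {
            uw[i] += zg[i]; // accumulate z_g into u_w
        }
    }
    return uw;
}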
The rest of the fastText process is consistent with the skip-gram model, so it is not repeated here. As we can see, compared with the skip-gram model, the dictionary in fastText is larger, resulting in more model parameters. Also, the vector of one word requires the summation of all its subword vectors, which results in higher computational complexity. However, we can obtain better vectors for uncommon or complex words, even words that do not exist in the dictionary, because such words share subwords with other words of similar structure.
14.5.2. Byte Pair Encoding
In fastText, all the extracted subwords have to be of the specified lengths, such as \(3\) to \(6\), thus the vocabulary size cannot be predefined. To allow for variable-length subwords in a fixed-size vocabulary, we can apply a compression algorithm called byte pair encoding (BPE) to extract subwords [Sennrich et al., 2015].
Byte pair encoding performs a statistical analysis of the training dataset to discover common symbols within a word, such as consecutive characters of arbitrary length. Starting from symbols of length \(1\), byte pair encoding iteratively merges the most frequent pair of consecutive symbols to produce new longer symbols. Note that for efficiency, pairs crossing word boundaries are not considered. In the end, we can use such symbols as subwords to segment words. Byte pair encoding and its variants have been used for input representations in popular natural language processing pretraining models such as GPT-2 [Radford et al., 2019] and RoBERTa [Liu et al., 2019]. In the following, we will illustrate how byte pair encoding works.
First, we initialize the vocabulary of symbols as all the English lowercase characters, a special end-of-word symbol '_', and a special unknown symbol '[UNK]'.
%load ../utils/djl-imports
%load ../utils/Functions.java
import java.util.stream.*;
NDManager manager = NDManager.newBaseManager();
String[] symbols =
        new String[] {
            "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p",
            "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "_", "[UNK]"
        };
Since we do not consider symbol pairs that cross boundaries of words, we only need a dictionary rawTokenFreqs that maps words to their frequencies (number of occurrences) in a dataset. Note that the special symbol '_' is appended to each word so that we can easily recover a word sequence (e.g., “a taller man”) from a sequence of output symbols (e.g., “a_ tall er_ man”). Since we start the merging process from a vocabulary of only single characters and special symbols, a space is inserted between every pair of consecutive characters within each word (keys of the dictionary tokenFreqs). In other words, space is the delimiter between symbols within a word.
HashMap<String, Integer> rawTokenFreqs = new HashMap<>();
rawTokenFreqs.put("fast_", 4);
rawTokenFreqs.put("faster_", 3);
rawTokenFreqs.put("tall_", 5);
rawTokenFreqs.put("taller_", 4);

HashMap<String, Integer> tokenFreqs = new HashMap<>();
for (Map.Entry<String, Integer> e : rawTokenFreqs.entrySet()) {
    String token = e.getKey();
    tokenFreqs.put(String.join(" ", token.split("")), rawTokenFreqs.get(token));
}
tokenFreqs
{f a s t e r _=3, t a l l e r _=4, f a s t _=4, t a l l _=5}
We define the following getMaxFreqPair function that returns the most frequent pair of consecutive symbols within a word, where words come from keys of the input dictionary tokenFreqs.
public static Pair<String, String> getMaxFreqPair(HashMap<String, Integer> tokenFreqs) {
    // Key of 'pairs' is a tuple of two consecutive symbols; its value is the
    // frequency of that pair summed over all words
    HashMap<Pair<String, String>, Integer> pairs = new HashMap<>();
    for (Map.Entry<String, Integer> e : tokenFreqs.entrySet()) {
        String token = e.getKey();
        Integer freq = e.getValue();
        String[] symbols = token.split(" ");
        for (int i = 0; i < symbols.length - 1; i++) {
            pairs.put(
                new Pair<>(symbols[i], symbols[i + 1]),
                pairs.getOrDefault(new Pair<>(symbols[i], symbols[i + 1]), 0) + freq);
        }
    }
    // Return the key of 'pairs' with the max value
    int max = 0;
    Pair<String, String> maxFreqPair = null;
    for (Map.Entry<Pair<String, String>, Integer> pair : pairs.entrySet()) {
        if (max < pair.getValue()) {
            max = pair.getValue();
            maxFreqPair = pair.getKey();
        }
    }
    return maxFreqPair;
}
As a greedy approach based on the frequency of consecutive symbols, byte pair encoding will use the following mergeSymbols function to merge the most frequent pair of consecutive symbols to produce new symbols.
public static Pair<HashMap<String, Integer>, String[]> mergeSymbols(
        Pair<String, String> maxFreqPair, HashMap<String, Integer> tokenFreqs) {
    ArrayList<String> symbols = new ArrayList<>();
    symbols.add(maxFreqPair.getKey() + maxFreqPair.getValue());

    HashMap<String, Integer> newTokenFreqs = new HashMap<>();
    for (Map.Entry<String, Integer> e : tokenFreqs.entrySet()) {
        String token = e.getKey();
        String newToken =
                token.replace(
                        maxFreqPair.getKey() + " " + maxFreqPair.getValue(),
                        maxFreqPair.getKey() + maxFreqPair.getValue());
        newTokenFreqs.put(newToken, tokenFreqs.get(token));
    }
    return new Pair<>(newTokenFreqs, symbols.toArray(new String[symbols.size()]));
}
Now we iteratively perform the byte pair encoding algorithm over the keys of the dictionary tokenFreqs. In the first iteration, the most frequent pair of consecutive symbols are 'l' and 'l', thus byte pair encoding merges them to produce a new symbol 'll'. In the second iteration, byte pair encoding continues to merge 'a' and 'll' to result in another new symbol 'all'.
int numMerges = 10;
for (int i = 0; i < numMerges; i++) {
    Pair<String, String> maxFreqPair = getMaxFreqPair(tokenFreqs);
    Pair<HashMap<String, Integer>, String[]> pair = mergeSymbols(maxFreqPair, tokenFreqs);
    tokenFreqs = pair.getKey();
    symbols =
            Stream.concat(Arrays.stream(symbols), Arrays.stream(pair.getValue()))
                    .toArray(String[]::new);
    System.out.println(
            "merge #" + (i + 1) + ": (" + maxFreqPair.getKey() + ", " + maxFreqPair.getValue() + ")");
}
merge #1: (l, l)
merge #2: (a, ll)
merge #3: (t, all)
merge #4: (s, t)
merge #5: (a, st)
merge #6: (f, ast)
merge #7: (e, r)
merge #8: (er, _)
merge #9: (tall, _)
merge #10: (tall, er_)
After 10 iterations of byte pair encoding, we can see that the list symbols now contains 10 more symbols that are iteratively merged from other symbols.
Arrays.toString(symbols)
[a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z, _, [UNK], ll, all, tall, st, ast, fast, er, er_, tall_, taller_]
For the same dataset specified in the keys of the dictionary rawTokenFreqs, each word in the dataset is now segmented by the subwords “fast”, “_”, “er_”, “tall_”, and “taller_” as a result of the byte pair encoding algorithm. For instance, the words “faster_” and “taller_” are segmented as “fast er_” and “taller_”, respectively.
tokenFreqs.keySet()
[fast _, tall_, taller_, fast er_]
Note that the result of byte pair encoding depends on the dataset being used. We can also use the subwords learned from one dataset to segment words of another dataset. As a greedy approach, the following segmentBPE function tries to break words into the longest possible subwords from the input argument symbols.
public static List<String> segmentBPE(String[] tokens, String[] symbols) {
    List<String> outputs = new ArrayList<>();
    for (String token : tokens) {
        int start = 0;
        int end = token.length();
        ArrayList<String> curOutput = new ArrayList<>();
        // Segment token with the longest possible subwords from symbols
        while (start < token.length() && start < end) {
            if (Arrays.asList(symbols).contains(token.substring(start, end))) {
                curOutput.add(token.substring(start, end));
                start = end;
                end = token.length();
            } else {
                end -= 1;
            }
        }
        // If part of the token cannot be matched by any symbol, mark it as unknown
        if (start < token.length()) {
            curOutput.add("[UNK]");
        }
        outputs.add(String.join(" ", curOutput));
    }
    return outputs;
}
In the following, we use the subwords in the list symbols, which is learned from the aforementioned dataset, to segment tokens that represent another dataset.
String[] tokens = new String[] {"tallest_", "fatter_"};
System.out.println(segmentBPE(tokens, symbols));
[tall e st _, f a t t er_]
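As a quick additional check (our own hypothetical example, not from the original notebook), a token containing a character that never appears in symbols triggers the unknown-symbol branch:
// Hypothetical token with a character ('2') outside the learned symbols
System.out.println(segmentBPE(new String[] {"tall2_"}, symbols)); // expected: [tall [UNK]]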
14.5.3. Summary
FastText proposes a subword embedding method. Based on the skip-gram model in word2vec, it represents the central word vector as the sum of the subword vectors of the word.
Subword embedding utilizes the principles of morphology, which usually improves the quality of representations of uncommon words.
Byte pair encoding performs a statistical analysis of the training dataset to discover common symbols within a word. As a greedy approach, byte pair encoding iteratively merges the most frequent pair of consecutive symbols.
14.5.4. Exercises
When there are too many subwords (for example, subwords of length \(6\) in English already yield about \(3\times 10^8\) combinations), what problems arise? Can you think of any methods to solve them? Hint: refer to the end of Section 3.2 of the fastText paper [Bojanowski et al., 2017].
How can you design a subword embedding model based on the continuous bag-of-words model?
To get a vocabulary of size \(m\), how many merging operations are needed when the initial symbol vocabulary size is \(n\)?
How can we extend the idea of byte pair encoding to extract phrases?