
9.5. Machine Translation and the Dataset

We have used RNNs to design language models, which are key to natural language processing. Another flagship benchmark is machine translation, a central problem domain for sequence transduction models that transform input sequences into output sequences. Because they play a crucial role in many modern AI applications, sequence transduction models will be the focus of the remainder of this chapter and of Section 10. To this end, this section introduces the machine translation problem and the dataset that we will use later.

Machine translation refers to the automatic translation of a sequence from one language to another. In fact, this field may date back to the 1940s, soon after digital computers were invented, especially considering the use of computers for cracking codes in World War II. For decades, statistical approaches dominated this field [Brown.Cocke.Della-Pietra.ea.1988][Brown.Cocke.Della-Pietra.ea.1990] before the rise of end-to-end learning using neural networks. The latter is often called neural machine translation to distinguish it from statistical machine translation, which involves statistical analysis in components such as the translation model and the language model.

Emphasizing end-to-end learning, this book will focus on neural machine translation methods. Unlike the language modeling problem in Section 8.3, whose corpus is in a single language, machine translation datasets are composed of pairs of text sequences in the source language and the target language, respectively. Thus, instead of reusing the preprocessing routine for language modeling, we need a different way to preprocess machine translation datasets. In the following, we show how to load the preprocessed data into minibatches for training.

%mavenRepo snapshots

%maven ai.djl:api:0.11.0-SNAPSHOT
%maven org.slf4j:slf4j-api:1.7.26
%maven org.slf4j:slf4j-simple:1.7.26

%maven ai.djl.mxnet:mxnet-engine:0.11.0-SNAPSHOT
%maven ai.djl.mxnet:mxnet-native-auto:1.7.0-backport
%load ../utils/plot-utils
%load ../utils/
%load ../utils/
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;
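import tech.tablesaw.plotly.components.Figure;
import tech.tablesaw.plotly.components.Layout;
import tech.tablesaw.plotly.traces.HistogramTrace;

import java.io.*;
import java.net.URL;
import java.nio.file.*;
import java.util.*;
import java.util.zip.*;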
NDManager manager = NDManager.newBaseManager(Functions.tryGpu(0));

9.5.1. Downloading and Preprocessing the Dataset

To begin with, we download an English-French dataset that consists of bilingual sentence pairs from the Tatoeba Project. Each line in the dataset is a tab-delimited pair of an English text sequence and the translated French text sequence. Note that each text sequence can be just one sentence or a paragraph of multiple sentences. In this machine translation problem where English is translated into French, English is the source language and French is the target language.

public static StringBuilder readDataNMT() throws IOException {
    File file = new File("./");
    if (!file.exists()) {
        // Download the zip archive once and cache it locally
        try (InputStream inputStream = new URL("").openStream()) {
            Files.copy(
                    inputStream, Paths.get("./"), StandardCopyOption.REPLACE_EXISTING);
        }
    }

    // Locate the English-French text file inside the archive
    ZipFile zipFile = new ZipFile(file);
    Enumeration<? extends ZipEntry> entries = zipFile.entries();
    InputStream stream = null;
    while (entries.hasMoreElements()) {
        ZipEntry entry = entries.nextElement();
        if (entry.getName().contains("fra.txt")) {
            stream = zipFile.getInputStream(entry);
            break;
        }
    }
    if (stream == null) {
        zipFile.close();
        throw new FileNotFoundException("fra.txt not found in the archive");
    }

    // Decode explicitly as UTF-8: the file contains non-ASCII characters
    String[] lines;
    try (BufferedReader in = new BufferedReader(
            new InputStreamReader(stream, java.nio.charset.StandardCharsets.UTF_8))) {
        lines = in.lines().toArray(String[]::new);
    } finally {
        zipFile.close();
    }
    StringBuilder output = new StringBuilder();
    for (String line : lines) {
        output.append(line).append('\n');
    }
    return output;
}

StringBuilder rawText = readDataNMT();
System.out.println(rawText.substring(0, 75));
Go.         Va !
Hi.         Salut !
Run!        Cours !
Run!        Courez !
Who?        Qui ?
Wow!        Ça alors !
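The dataset file is UTF-8 encoded and contains multi-byte characters such as `Ç` and the narrow no-break space U+202F, which appears before `!` in lines like `Cours !`. The minimal sketch below (the class `CharsetDemo` and its `decode` helper are hypothetical, for illustration only) shows why the charset used to decode the bytes matters: decoding UTF-8 bytes as ISO-8859-1 turns each such character into several junk characters, which would then leak into the tokens.

```java
import java.nio.charset.StandardCharsets;

public class CharsetDemo {
    // Hypothetical helper: decode bytes as UTF-8, or (incorrectly) as ISO-8859-1
    public static String decode(byte[] bytes, boolean asUtf8) {
        return new String(bytes, asUtf8 ? StandardCharsets.UTF_8 : StandardCharsets.ISO_8859_1);
    }

    public static void main(String[] args) {
        // One line shaped like the dataset: English, a tab, then French.
        // U+202F (narrow no-break space) is encoded as three bytes in UTF-8.
        String line = "Run!\tCours\u202f!";
        byte[] utf8Bytes = line.getBytes(StandardCharsets.UTF_8);

        String good = decode(utf8Bytes, true);
        String bad = decode(utf8Bytes, false);
        System.out.println(good.length()); // 12
        System.out.println(bad.length());  // 14: one character became three

        // Splitting the correctly decoded line on the tab yields the sentence pair
        String[] pair = good.split("\t");
        System.out.println(pair[0]);                         // Run!
        System.out.println(pair[1].replace('\u202f', ' '));  // Cours !
    }
}
```

Passing an explicit charset to `InputStreamReader` (or `new String(bytes, charset)`) avoids depending on the JVM's platform default.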

After downloading the dataset, we proceed with several preprocessing steps for the raw text data. For instance, we replace non-breaking spaces with ordinary spaces, convert uppercase letters to lowercase, and insert a space between words and punctuation marks.

public static StringBuilder preprocessNMT(String text) {
    // Replace non-breaking spaces with spaces, and convert uppercase letters to
    // lowercase ones
    text = text.replace('\u202f', ' ').replaceAll("\\xa0", " ").toLowerCase();

    // Insert a space between words and punctuation marks
    StringBuilder out = new StringBuilder();
    Character currChar;
    for (int i = 0; i < text.length(); i++) {
        currChar = text.charAt(i);
        if (i > 0 && noSpace(currChar, text.charAt(i - 1))) {
            out.append(' ');
        }
        out.append(currChar);
    }
    return out;
}

public static boolean noSpace(Character currChar, Character prevChar) {
    // True if currChar is a punctuation mark not already preceded by a space
    return new HashSet<>(Arrays.asList(',', '.', '!', '?')).contains(currChar)
            && prevChar != ' ';
}

StringBuilder text = preprocessNMT(rawText.toString());
System.out.println(text.substring(0, 80));
go .        va !
hi .        salut !
run !       cours !
run !       courez !
who ?       qui ?
wow !       ça alors !
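The rule implemented by `noSpace` can also be stated as one regular expression: insert a space before `,`, `.`, `!`, or `?` whenever a preceding character exists and is not a plain space. The sketch below swaps in this regex-based technique so the rule can be tested on a few strings in isolation; it is not the notebook's implementation, and the class name `PunctuationSpacing` is hypothetical. It lowercases with `Locale.ROOT` to avoid locale-dependent case mapping, a small deviation from the plain `toLowerCase()` call in `preprocessNMT`.

```java
import java.util.Locale;
import java.util.regex.Pattern;

public class PunctuationSpacing {
    // Matches , . ! or ? when a preceding character exists and is not a plain space
    private static final Pattern PUNCT_AFTER_NON_SPACE = Pattern.compile("(?<=[^ ])([,.!?])");

    // Regex-based sketch of the same normalization as preprocessNMT
    public static String preprocess(String text) {
        String normalized = text
                .replace('\u202f', ' ')   // narrow no-break space -> space
                .replace('\u00a0', ' ')   // no-break space -> space
                .toLowerCase(Locale.ROOT);
        // "$1" re-inserts the matched punctuation mark after the new space
        return PUNCT_AFTER_NON_SPACE.matcher(normalized).replaceAll(" $1");
    }

    public static void main(String[] args) {
        System.out.println(preprocess("Go.\tVa !"));           // go .	va !
        System.out.println(preprocess("Run!\tCours\u202f!"));  // run !	cours !
        System.out.println(preprocess("Wait, what?"));         // wait , what ?
    }
}
```

Because the lookbehind inspects the original input rather than the partially rewritten output, consecutive marks such as `?!` each get their own space, just as in the character loop.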

9.5.2. Tokenization

Unlike the character-level tokenization in Section 8.3, for machine translation we prefer word-level tokenization here (state-of-the-art models may use more advanced tokenization techniques). The following tokenizeNMT function tokenizes the first numExamples text sequence pairs (or all of them if numExamples is null), where each token is either a word or a punctuation mark. It returns a pair of lists of token arrays: source and target. Specifically, source.get(i) holds the tokens of the \(i^\mathrm{th}\) text sequence in the source language (English here) and target.get(i) holds those in the target language (French here).

public static Pair<ArrayList<String[]>, ArrayList<String[]>> tokenizeNMT(
        String text, Integer numExamples) {
    ArrayList<String[]> source = new ArrayList<>();
    ArrayList<String[]> target = new ArrayList<>();

    int i = 0;
    for (String line : text.split("\n")) {
        // Stop after the first numExamples lines (null means use every line)
        if (numExamples != null && i >= numExamples) {
            break;
        }
        String[] parts = line.split("\t");
        // Keep only well-formed "source<TAB>target" lines
        if (parts.length == 2) {
            source.add(parts[0].split(" "));
            target.add(parts[1].split(" "));
        }
        i += 1;
    }
    return new Pair<>(source, target);
}

Pair<ArrayList<String[]>, ArrayList<String[]>> pair = tokenizeNMT(text.toString(), null);
ArrayList<String[]> source = pair.getKey();
ArrayList<String[]> target = pair.getValue();
for (String[] subArr : source.subList(0, 6)) System.out.println(Arrays.toString(subArr));
[go, .]
[hi, .]
[run, !]
[run, !]
[who, ?]
[wow, !]
for (String[] subArr : target.subList(0, 6)) System.out.println(Arrays.toString(subArr));
[va, !]
[salut, !]
[cours, !]
[courez, !]
[qui, ?]
[ça, alors, !]
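To make the line-counting rule concrete, here is a minimal self-contained sketch that mirrors the tokenization logic on a tiny string. The names `TokenizeSketch` and `Tokenized` are hypothetical; the sketch returns `List<List<String>>` instead of `ArrayList<String[]>` so results print readably, and it assumes the "first `numExamples` lines" reading of the limit, where lines without exactly one tab are skipped but still counted.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class TokenizeSketch {
    // Token lists for the source and target sides
    public static final class Tokenized {
        public final List<List<String>> source = new ArrayList<>();
        public final List<List<String>> target = new ArrayList<>();
    }

    // Tokenizes at most numExamples lines (null = all lines)
    public static Tokenized tokenize(String text, Integer numExamples) {
        Tokenized result = new Tokenized();
        String[] lines = text.split("\n");
        for (int i = 0; i < lines.length; i++) {
            if (numExamples != null && i >= numExamples) {
                break;
            }
            String[] parts = lines[i].split("\t");
            // Malformed lines are skipped, but they still count toward numExamples
            if (parts.length == 2) {
                result.source.add(Arrays.asList(parts[0].split(" ")));
                result.target.add(Arrays.asList(parts[1].split(" ")));
            }
        }
        return result;
    }

    public static void main(String[] args) {
        String text = "go .\tva !\nhi .\tsalut !\nrun !\tcours !";
        Tokenized t = tokenize(text, 2);
        System.out.println(t.source); // [[go, .], [hi, .]]
        System.out.println(t.target); // [[va, !], [salut, !]]
    }
}
```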

Let us plot the histogram of the number of tokens per text sequence. In this simple English-French dataset, most of the text sequences have fewer than 20 tokens.

double[] y1 = new double[source.size()];
for (int i = 0; i < source.size(); i++) y1[i] = source.get(i).length;
double[] y2 = new double[target.size()];
for (int i = 0; i < target.size(); i++) y2[i] = target.get(i).length;

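HistogramTrace trace1 =
        HistogramTrace.builder(y1).opacity(.75).name("source").nBinsX(20).build();
HistogramTrace trace2 =
        HistogramTrace.builder(y2).opacity(.75).name("target").nBinsX(20).build();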

Layout layout = Layout.builder().barMode(Layout.BarMode.GROUP).build();
new Figure(layout, trace1, trace2);
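A histogram shows the overall shape, but the claim that most sequences have fewer than 20 tokens can also be checked numerically. The sketch below (the class `LengthStats` is hypothetical) counts sequences per token length and computes the fraction shorter than a threshold; in the notebook one could pass the `source` or `target` lists built above, for example `LengthStats.fractionShorterThan(source, 20)`.

```java
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;

public class LengthStats {
    // Maps each token length to the number of sequences with that length
    public static TreeMap<Integer, Integer> lengthCounts(List<String[]> sequences) {
        TreeMap<Integer, Integer> counts = new TreeMap<>();
        for (String[] seq : sequences) {
            counts.merge(seq.length, 1, Integer::sum);
        }
        return counts;
    }

    // Fraction of sequences with strictly fewer than maxTokens tokens
    public static double fractionShorterThan(List<String[]> sequences, int maxTokens) {
        if (sequences.isEmpty()) {
            return 0.0;
        }
        long shorter = sequences.stream().filter(s -> s.length < maxTokens).count();
        return (double) shorter / sequences.size();
    }

    public static void main(String[] args) {
        List<String[]> toy = Arrays.asList(
                new String[] {"go", "."},
                new String[] {"hi", "."},
                new String[] {"who", "are", "you", "?"});
        System.out.println(lengthCounts(toy)); // {2=2, 4=1}
        System.out.println(fractionShorterThan(toy, 3));
    }
}
```

A `TreeMap` keeps the lengths sorted, so printing it reads like a text version of the histogram.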