Transformer-XL – Combining Transformers and RNNs Into a State-of-the-art Language Model

Language modeling has become an important NLP technique thanks to the ability to apply it to various NLP tasks, such as machine translation and topic classification. Today, there are two leading architectures for language modeling – Recurrent Neural Networks (RNNs) and Transformers. While the former handles the input tokens – words or characters – one by one to learn the relationship between them, the latter receives a segment of tokens and learns the dependencies between them all at once using an attention mechanism.

Though both architectures have reached impressive achievements, their main limitation is capturing long-term dependencies, e.g. using important words from the beginning of a document to predict words in a later part. A new paper by Google and Carnegie Mellon University, “Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context”, combines these two approaches. The new model uses the Transformer’s attention modules on each segment of input data and a recurrence mechanism to learn dependencies between consecutive segments.

Transformer-XL achieves state-of-the-art (SOTA) results on multiple language-modeling benchmarks, such as enwik8 and text8 (character-level) and WikiText-103 (word-level), while being significantly faster (300x-1800x) during inference than the previous SOTA Transformer architecture.

Background – Transformers

Recurrent Neural Networks (RNNs) are a popular approach for language modeling, as they capture dependencies between words well, especially when using modules such as LSTMs. However, RNNs tend to be slow, and their ability to learn long-term dependencies is still limited due to vanishing gradients.

Transformers, invented in 2017, introduced a new approach – attention modules. Instead of processing tokens one by one, attention modules receive a segment of tokens and learn the dependencies between all of them at once using three learned weight matrices – Query, Key and Value – that form an Attention Head. The Transformer network consists of multiple layers, each with several Attention Heads (and additional layers), used to learn different relationships between tokens.
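To make the Query/Key/Value mechanics concrete, here is a minimal numpy sketch of a single Attention Head. It is a simplified illustration with arbitrary dimensions, not code from the paper or its repository:

```python
import numpy as np

def attention_head(X, W_q, W_k, W_v):
    """One attention head over a segment of token embeddings X (n_tokens x d_model)."""
    Q = X @ W_q                                      # Query vectors
    K = X @ W_k                                      # Key vectors
    V = X @ W_v                                      # Value vectors
    scores = Q @ K.T / np.sqrt(K.shape[-1])          # token-to-token compatibility scores
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)   # softmax: the weight of each token
    return weights @ V                               # weighted sum of the Value vectors

# Toy example: 4 tokens, model dimension 8, head dimension 4
rng = np.random.default_rng(0)
X = rng.normal(size=(4, 8))
W_q, W_k, W_v = (rng.normal(size=(8, 4)) for _ in range(3))
out = attention_head(X, W_q, W_k, W_v)               # shape: (4, 4)
```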

As in many NLP models, the input tokens are first embedded into vectors. Because the attention module processes all tokens concurrently, the model also needs information about the order of the tokens, a step called Positional Encoding. In general, this step uses a sinusoidal function that generates a vector according to the token’s position, without any learned parameters.
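A minimal sketch of such a sinusoidal encoding, following the common formulation from the original Transformer paper (the dimensions below are arbitrary, chosen only for illustration):

```python
import numpy as np

def sinusoidal_encoding(n_positions, d_model):
    """Fixed (non-learned) positional encodings: even dimensions use sine, odd use cosine."""
    positions = np.arange(n_positions)[:, None]            # (n_positions, 1)
    dims = np.arange(0, d_model, 2)[None, :]               # (1, d_model / 2)
    angles = positions / np.power(10000.0, dims / d_model)
    enc = np.zeros((n_positions, d_model))
    enc[:, 0::2] = np.sin(angles)
    enc[:, 1::2] = np.cos(angles)
    return enc

# Each row is combined with the embedding of the token at that position.
pe = sinusoidal_encoding(n_positions=512, d_model=128)     # shape: (512, 128)
```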

An example of a single Attention Head applied to a single token (E1). Its output is calculated using its Query vector and the Key and Value vectors of all tokens (in the chart only one additional token, E2, is shown) – the Query and the Keys define the weight of each token, and the output is the weighted sum of all Value vectors.

Note: An in-depth review of Transformers can be found in Jay Alammar’s great blog post.

While the original Transformer was used for machine translation (with an encoder-decoder mechanism), Al-Rfou et al. presented an architecture for language modeling. Its goal is to predict a character in a segment based on its previous characters; for example, it predicts character x_n using x_1, …, x_(n-1), while the characters to its right are masked (see image below). This 64-layer Transformer model is limited to relatively short inputs of 512 characters, so it splits the input into segments and learns from each one separately. To process longer inputs at evaluation time, it predicts one character at a time, shifting the input by one position in each step.

Training and Evaluation of the vanilla Transformer language model. Source: Transformer-XL
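As a rough illustration of the scheme above, the following hypothetical helpers (the names are mine, not from the paper’s code) enumerate what the vanilla model sees during training versus evaluation:

```python
def vanilla_train_segments(text, segment_len=512):
    """Training: the corpus is split into independent, non-overlapping segments."""
    for start in range(0, len(text) - segment_len + 1, segment_len):
        yield text[start:start + segment_len]

def vanilla_eval_contexts(text, segment_len=512):
    """Evaluation: the window slides ONE character per step and only the last
    position is predicted, which is what makes evaluation so slow."""
    for i in range(segment_len, len(text)):
        yield text[i - segment_len:i], text[i]

# Tiny example with a 4-character "segment" for readability:
list(vanilla_train_segments("abcdefgh", segment_len=4))  # ['abcd', 'efgh']
list(vanilla_eval_contexts("abcdefgh", segment_len=4))   # [('abcd', 'e'), ('bcde', 'f'), ...]
```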

This model outperforms RNN models on popular benchmarks (enwik8 and text8); however, it still suffers from two shortcomings:

  1. Limited context-dependency – The maximum dependency distance between characters is limited to the length of the input. For example, the model can’t “use” a word that appeared several sentences ago.   
  2. Context fragmentation – For texts that are longer than 512 characters, every segment of that size is trained separately from scratch. Therefore, there is no context (dependencies) at all for the first tokens of each segment or between segments. This leads to inefficient training and can hurt model performance.


How Transformer-XL works

Transformer-XL heavily relies on the vanilla Transformer (Al-Rfou et al.) but introduces two innovative techniques – a Recurrence Mechanism and Relative Positional Encoding – to overcome the vanilla model’s shortcomings. An additional advantage over the vanilla Transformer is that it can be used for both word-level and character-level language modeling.

Recurrence Mechanism

The goal of the recurrence mechanism is to enable long-term dependencies by using information from previous segments. Similarly to the vanilla version, Transformer-XL processes the first segment of tokens but keeps the outputs of the hidden layers. When the following segment is processed, each hidden layer receives two inputs:

  1. The output of the previous hidden layer of that segment, as in the vanilla version (the grey arrows in the chart below).
  2. The output of the previous hidden layer from the previous segment (the green arrows) that allows the model to create long-term dependencies.

Technically, the two inputs are concatenated and then used to calculate the Key and the Value matrices of the (current Head of the current layer of the) current segment. This addition provides the network with more information about the weight (importance) of each token; the Query matrix, in contrast, is still computed from the current segment only.

Training and Evaluation of the Transformer-XL language model. Source: Transformer-XL
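A simplified numpy sketch of this idea (single head, no causal mask, no relative positional encoding; names and dimensions are illustrative, not the paper’s code):

```python
import numpy as np

def xl_attention_layer(h_current, h_memory, W_q, W_k, W_v):
    """One self-attention layer with segment-level recurrence (heavily simplified)."""
    # Keys and Values are computed from the cached states of the previous segment
    # concatenated with the current segment...
    h_extended = np.concatenate([h_memory, h_current], axis=0)
    K = h_extended @ W_k
    V = h_extended @ W_v
    # ...while Queries come from the current segment only.
    Q = h_current @ W_q
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

# In training, the cached memory is treated as a constant (no gradients flow back into
# previous segments); in the real model each layer caches the previous layer's outputs.
rng = np.random.default_rng(0)
d = 8
h_prev_segment = rng.normal(size=(4, d))    # hidden states cached from segment t-1
h_curr_segment = rng.normal(size=(4, d))    # hidden states of segment t
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
out = xl_attention_layer(h_curr_segment, h_prev_segment, W_q, W_k, W_v)  # shape: (4, 8)
```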

The concept can be expanded to incorporate longer dependencies by using information from several previous segments in the same way (within the limits of GPU memory); the memory length can even be increased only at evaluation time.

Another advantage of the recurrence mechanism is its speed in evaluation – In each step, it can advance by an entire segment (and not by one token as in the vanilla version) and use the previous segments’ data to predict the current segment tokens.
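Continuing the sketch above (with the same caveats), evaluation can walk over the text one whole segment at a time while carrying the cached states forward:

```python
def evaluate_by_segments(embeddings, segment_len, W_q, W_k, W_v):
    """Process the sequence segment by segment, reusing the previous segment as memory."""
    d = embeddings.shape[-1]
    memory = np.zeros((0, d))                    # empty memory before the first segment
    outputs = []
    for start in range(0, len(embeddings), segment_len):
        segment = embeddings[start:start + segment_len]
        outputs.append(xl_attention_layer(segment, memory, W_q, W_k, W_v))
        # For simplicity we cache the raw input states here; the real model caches
        # each layer's hidden outputs from the previous segment.
        memory = segment
    return np.concatenate(outputs, axis=0)

# One forward pass per segment, instead of one per token as in the vanilla evaluation.
long_sequence = rng.normal(size=(16, d))
out = evaluate_by_segments(long_sequence, segment_len=4, W_q=W_q, W_k=W_k, W_v=W_v)
```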

Relative Positional Encoding

The recurrence mechanism also introduces a new challenge – the original positional encoding handles each segment separately and, as a result, tokens from different segments have the same positional encoding. For example, the first token of the first segment and the first token of the second segment will have the same encoding, although their positions and importance are different (the token from the first segment is probably less important). This ambiguity can mislead the network.

Instead, the paper presents a new positional encoding that is part of each attention module, as opposed to encoding position only before the first layer, and that is based on the relative distance between tokens rather than their absolute position. Technically, it expands the simple multiplication of the Attention Head’s score (Q_i · K_j) into four parts (written out as a single formula after the list):

  1. Content weight – the original score, now without the original (absolute) positional encoding.
  2. Positional bias with respect to the current content (Qi). It uses a similar sinusoidal function that receives the distance between tokens (e.g. i-j), instead of the absolute position of the current token.
  3. A learned global content bias – The model adds a learned vector that adjusts the importance of the other token’s content (K_j).
  4. A learned global positional bias – Another learned vector that adjusts the importance based only on the distance between the tokens (e.g. recent words are probably more important than a word from a previous paragraph).
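In the paper’s notation (E_x are token embeddings, R_{i-j} is the sinusoidal relative-position encoding, W_{k,E} and W_{k,R} are separate key projections for content and position, and u, v are the two learned global bias vectors), the expanded score is:

```latex
A^{\mathrm{rel}}_{i,j} =
  \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,E}\, E_{x_j}}_{\text{(1) content weight}}
+ \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,R}\, R_{i-j}}_{\text{(2) content-dependent positional bias}}
+ \underbrace{u^{\top} W_{k,E}\, E_{x_j}}_{\text{(3) global content bias}}
+ \underbrace{v^{\top} W_{k,R}\, R_{i-j}}_{\text{(4) global positional bias}}
```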

Results

The authors evaluated the model’s performance on word-level and character-level datasets and compared it to other prominent models (RNNs and Transformers). Transformer-XL achieved state-of-the-art (SOTA) results on several different benchmarks:

  1. On WikiText-103, a large word-level dataset, the 18-layer Transformer-XL (257M parameters) reached perplexity of 18.3 compared to Baevski & Auli, the former SOTA that reached 20.5.
  2. On enwik8, a character-level dataset, the 12-layer Transformer-XL reached 1.06 bits per character (bpc), a similar result to the previous SOTA by Al-Rfou et al. that used six times more parameters. The 24-layer Transformer-XL achieved a new SOTA with 0.99 bpc.
  3. Interestingly, the model also achieves SOTA results on a dataset with only short-term dependencies – One Billion Word with only individual sentences – and on a small dataset – Penn Treebank with only 1M tokens. This suggests that the model might also be effective in these scenarios.

The benefits of the recurrence mechanism and the relative positional encoding can be seen in the following chart. It compares the perplexity scores of models without the recurrence or the new encoding for different context lengths (number of previous tokens used in the attention head). The full Transformer-XL significantly outperforms the others and is able to exploit longer-term dependencies. In addition, it’s capable of capturing longer dependencies than RNNs (80% longer).

Transformer-XL ablation study. Source: Transformer-XL

Lastly, as mentioned before, the model is also significantly faster during inference than the vanilla Transformer, especially for longer contexts. For example, for a context length of 800 characters it’s 363 times faster, and for 3,800 characters it’s 1,874 times faster.

Implementation details

The model is open-source and is implemented in both TensorFlow and PyTorch (including pre-trained models). Training duration for each dataset wasn’t specified.

Conclusion

Transformer-XL presents state-of-the-art results for language modeling on several different datasets (big/small, characters/words, etc.). Its combination of two prominent concepts of deep learning – recurrence and attention – allows the model to learn long-term dependencies and might be effective in other areas of deep learning that require that capability, such as audio analysis (e.g. speech data with 16k samples per second).

This model hasn’t been tested yet on NLP tasks like sentiment analysis or question answering, and it remains an open question how much benefit this strong language model will bring compared to other Transformer-based models such as BERT.

