It's no secret that multi-head self-attention is expensive -- the \(O(n^2)\) complexity with respect to sequence length means that allowing vanilla transformers to attend to long sequences quickly becomes intractable. Over the past two years the NLP community has developed a veritable zoo of methods to combat this problematic complexity, but in this post we'll focus on six promising approaches. You can click on any of the links below to jump directly to that section of the post.

  1. Sparse Transformers
  2. Adaptive Span Transformers
  3. Transformer-XL
  4. Compressive Transformers
  5. Reformer
  6. Routing Transformer

Time and Memory Complexity of Dense Multi-Head Attention

Multi-head attention scales poorly with sequence length for two reasons.  The first is that the number of FLOPs required to compute the attention matrix scales with the square of the sequence length, resulting in a computational complexity of \(O(hdn^2)\) for a self-attention operation on a single sequence, where \(h\) is the number of attention heads, \(d\) is the dimension of keys and queries, and \(n\) is the length of our sequence.

Equally problematically, the memory complexity of the dot-product self-attention operation also scales with the square of the sequence length.  The memory complexity to compute the attention matrix is \(O(hdn + hn²)\) – the first term being the memory required to store keys and queries, and the second term referring to the scalar attention values produced by each head.

Let's substitute in some concrete numbers from BERT-Base to get a sense for which terms dominate.  BERT-Base uses a sequence length of 512, a hidden size of 768, and 12 heads, which means that each head has dimension 64 (768 / 12).  In this setting, 393,216 floats (~1.5MB) (12 heads * 64 head size * 512 sequence length) are required to store the keys and queries, while the memory required to store the scalar attention values for all heads works out to 3,145,728 floats (12 * 512 * 512) or ~12MB of device memory – roughly 8 times as much memory as storing the keys at a mere 512 token context size.

Since activations must be cached during training to allow for gradient computation (barring the use of activation re-computation strategies like gradient checkpointing), storing just these attention matrices for all 12 layers of BERT-Base requires about 150MB of memory per example.  At sequence length 1024 this becomes ~600MB, and at sequence length 2048 we're already up to ~2.4GB of memory per example for the attention matrices alone. This means smaller batch sizes and poorer parallelism at training time, further hindering our ability to train models that leverage long context lengths.
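If you'd like to play with these numbers yourself, a quick back-of-the-envelope script (assuming float32 activations and counting only the attention matrices, nothing else) reproduces the rough figures above:

```python
# Rough memory estimate for the attention matrices of a BERT-Base-sized model
# (12 layers, 12 heads), assuming float32 activations and ignoring all other
# buffers. Not a substitute for a real profiler.
def attention_matrix_bytes(seq_len, n_heads=12, n_layers=12, bytes_per_float=4):
    floats_per_layer = n_heads * seq_len * seq_len  # one scalar per (head, query, key) triple
    return n_layers * floats_per_layer * bytes_per_float

for seq_len in (512, 1024, 2048):
    print(f"seq_len={seq_len}: ~{attention_matrix_bytes(seq_len) / 2**20:.0f}MB per example")
```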

Sparse Transformers

"Generating Long Sequences with Sparse Transformers" by Rewon Child, Scott Gray, Alec Radford, and Ilya Sutskever addresses the problematic \(O(n²)\) term in the time and memory complexity of self-attention via a factorization approach.

Factorized Attention

In a typical self-attention operation, each term in the input sequence attends to all other terms in the input sequence, resulting in an attention pattern shown below:

Typical self-attention connectivity pattern in an auto-regressive setting. Dark blue squares represent the "queries", while light blue squares represent the "keys".

The benefit of typical self-attention is that this high connectivity makes it easy for information to flow between tokens – only a single layer is necessary to aggregate information from any two tokens.  But if we relax this constraint, and require only that information can flow between any two tokens within two layers, we can dramatically reduce our complexity with respect to sequence length. The Sparse Transformer achieves this goal by writing custom kernels that leverage fixed attention patterns.

The fixed variant of the Sparse Transformer. Dark blue squares represent queries, medium-light blue squares represent the key indices attended to by odd layers, and the lightest blue squares represent the key indices attended to by even layers. 

Half of the heads attend only to terms in a short, local context, while the other half attend to predesignated indices spread evenly throughout the sequence.  

By routing information through these aggregation indices the network is still able to pass information from distant tokens and make use of long-term context, while reducing time and memory complexities to \(O(n\sqrt n)\). Importantly, it only requires two layers for any token to incorporate information from any other token.
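As a rough illustration (not the paper's block-sparse CUDA kernels), here's what the two halves of the "fixed" pattern might look like when written out as dense boolean masks, with an illustrative block size and number of aggregation columns:

```python
import numpy as np

def fixed_sparse_masks(seq_len, block=128, c=8):
    """Toy version of the Sparse Transformer 'fixed' pattern as dense masks.

    `local` lets each query attend to earlier tokens within its own block;
    `strided` lets it attend to the last `c` positions of every block
    (the pre-designated aggregation indices). Both respect causal ordering.
    """
    i = np.arange(seq_len)[:, None]  # query positions
    j = np.arange(seq_len)[None, :]  # key positions
    causal = j <= i
    local = causal & (j // block == i // block)
    strided = causal & (j % block >= block - c)
    return local, strided

local_mask, strided_mask = fixed_sparse_masks(seq_len=1024)
```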

Empirical Results

Importantly, the factorized attention structure doesn't seem to negatively impact language modeling performance, leading to bits per character that were (surprisingly) marginally better than dense attention on enwiki8 and allowing efficient attention over context sizes up to 12,288 tokens.

It's conceivable that the Sparse Transformer attention structure works in part because these attention patterns aren't all that dissimilar from real learned dense attention patterns.  In "What Does BERT Look At? An Analysis of BERT’s Attention" by Kevin Clark, Urvashi Khandelwal, Omer Levy, and Christopher D. Manning, the authors probe the patterns learned by dense attention in an effort to gain intuition for what functions attention performs in transformer models. They find heads that attend to the token immediately previous (similar to the local attention pattern in sparse attention) as well as heads that attend to specific aggregation tokens like [SEP] and periods. So perhaps the inductive biases encoded in the Sparse Transformer's attention patterns are useful rather than detrimental.

Examples of BERT's learned attention patterns.

To employ the fixed attention kernels in your own projects, check out OpenAI's blocksparse library and the accompanying examples the authors have released as open source.

Adaptive Span Transformers

Sainbayar Sukhbaatar, Edouard Grave, Piotr Bojanowski, and Armand Joulin take a different approach to the complexity problem in their work "Adaptive Attention Span in Transformers".  They make a similar observation to the authors of "What Does BERT Look At?" and note that while dense attention allows each head to attend over the full context, many attention heads specialize to only consider local context while others consider the entire available sequence.  They propose leveraging this observation by using a variant of self-attention that allows the model to select its own context size.

Adaptive Masking

The Adaptive Span Transformer accomplishes this by masking the sequence such that the contribution of tokens outside of a learned, per-head context falls off quickly to zero.  The mask (\(M\)) is applied to the attention weights within the softmax operation to zero out the contributions of distant tokens to the current hidden state.  Here \(x\) is the distance between the current token and the token being attended to, \(z\) is the learned span, and a hyperparameter \(R\) controls how quickly the mask ramps from one down to zero.

$$ M = min(max(\frac{1}{R}(R + z - x), 0), 1)$$

The soft-masking function employed by "Adaptive Attention Span in Transformers". Graphic from the Facebook AI blog post accompanying the paper.

An \(\ell_1\) penalty is applied to the learned z values in order to encourage the model to only use additional context where beneficial.  
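Here's a minimal sketch of the masking function itself (parameter names are mine; in the actual model \(z\) is learned per head and the masked attention weights are re-normalized):

```python
import torch

def adaptive_span_mask(z, R, max_distance):
    # m_z(x) = clamp((R + z - x) / R, 0, 1), where x is the distance between
    # the current token and the token being attended to. z is the learned span
    # and R controls how quickly the mask ramps from 1 down to 0.
    x = torch.arange(max_distance, dtype=torch.float32)
    return torch.clamp((R + z - x) / R, min=0.0, max=1.0)

# A head with z=20 and R=8 fully attends to the last 20 tokens, down-weights
# tokens 20-28 positions back, and ignores anything further away.
mask = adaptive_span_mask(z=20.0, R=8.0, max_distance=64)
```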

Attention Introspection and Empirical Results

With these constraints, most heads elect to focus on <100 characters of context, with only a few select heads (primarily in the later layers of the network) opting to pay the \(\ell_1\) penalty in order to attend to a context of >1000 characters.

Along with clever caching, this penalty on long-term context allows the adaptive-span transformer to use attention spans of up to 8k characters for select heads while still keeping the overall computational cost of the model low. In addition, performance on benchmarks remains high, reaching 0.98 bits per character on enwiki8 and 1.07 bits per character on the text8 dataset.

However, the variable span sizes aren't ideal in terms of parallelism, since we typically want dense, uniformly sized matrices for best performance. Although this method allows a dramatic reduction in the number of FLOPs necessary to compute the forward pass at prediction time, the authors only provide vague performance estimates, stating that the adaptive span implementation allowed for processing context lengths up to 8192 tokens at similar rates to a fixed context size model with 2048 tokens of context.

Facebook AI Research has also open sourced their work – code and pretrained models are available at github.com/facebookresearch/adaptive-spans.

Transformer-XL

Rather than attempting to make the dense attention operation cheaper,
Zihang Dai, Zhilin Yang, Yiming Yang, Jaime Carbonell, Quoc V. Le, and Ruslan Salakhutdinov opted to take inspiration from RNNs and introduce a recurrence mechanism alongside the self-attention mechanism in transformers.  Their work, "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", introduces two novel concepts – a component that feeds the hidden states of previous "segments" as inputs to the current segment's layers, and a relative position encoding scheme to facilitate this strategy.

Segment Recurrence

With a standard transformer that has fixed context size, handling long inputs requires splitting the input up into chunks (segments) and processing each individually. However, this approach has the limitation that information from prior segments cannot flow to the current token.  This independence is somewhat beneficial in that it allows for the segments to be batch processed efficiently, but if your goal is long-term coherence this is a major limiting factor.

Token attention structure in an auto-regressive transformer with dense attention from the Transformer XL paper.

Transformer-XL overcomes this limitation by enforcing that the segments be processed in series.  After the first segment, tokens in subsequent segments always have an immediate context size of 512 tokens, as the previous segment's activations are passed as context to the subsequent segment's attention operations.  This means that information from up to \(N \times L\) tokens away – where \(N\) is the context size and \(L\) is the number of layers – can be propagated to a given token.  Assuming a context size of 640 and a model with 16 layers, the Transformer-XL can theoretically incorporate signal from up to 10,240 tokens away.

Token attention patterns from the Transformer-XL from the original paper.

In order to avoid having to store activations from all previous segments, the authors stop gradients from flowing back into previous segments.
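In pseudo-PyTorch, the recurrence amounts to something like the following (a simplified sketch with a hypothetical `attention_layer` that takes separate query and key/value inputs, and ignoring the relative position encodings discussed next):

```python
import torch

def transformer_xl_layer_step(attention_layer, hidden, memory):
    # `memory` holds this layer's hidden states from the previous segment.
    # Detaching it means gradients never flow back into earlier segments,
    # so only the activations themselves need to be kept around.
    context = torch.cat([memory.detach(), hidden], dim=1)
    new_hidden = attention_layer(queries=hidden, keys_and_values=context)
    new_memory = hidden.detach()  # cached as context for the next segment
    return new_hidden, new_memory
```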

Incorporating Relative Position

Transformer-XL also introduces a novel position encoding scheme they deem "relative position encodings".  Rather than simply treating the network's inputs as a sum of content and absolute position embeddings, each layer's attention operation is broken up into a portion that attends based on content and a portion that attends based on relative position – for the 512th token in a chunk to attend to the 511th, the embedding corresponding to relative position -1 is used.

To make the use of relative position encodings tractable, they break up the operation that produces attention weights from the keys and queries. For a typical dense-attention operation, the pre-softmax attention weights can be decomposed as follows:
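(Reproduced from the Transformer-XL paper's formulation, where \(W_q\) and \(W_k\) are the query and key projection matrices.)

$$A_{i,j}^{\mathrm{abs}} = \underbrace{E_{x_i}^\top W_q^\top W_k E_{x_j}}_{(a)} + \underbrace{E_{x_i}^\top W_q^\top W_k U_j}_{(b)} + \underbrace{U_i^\top W_q^\top W_k E_{x_j}}_{(c)} + \underbrace{U_i^\top W_q^\top W_k U_j}_{(d)}$$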

In the equation above, \(E_{x_i}\) is the content-based embedding of the token at location \(i\), and \(U_j\) is the position embedding for token \(j\).

\((a)\) relates the query's content with the key's content

\((b)\) relates the query's content with the key's position

\((c)\) relates the query's position with the key's content

\((d)\) relates the query's position with the key's position

When using relative position embeddings, the authors modify the equation as follows:
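(Again following the paper's notation, where \(W_{k,E}\) and \(W_{k,R}\) are separate key projections for content and relative position.)

$$A_{i,j}^{\mathrm{rel}} = \underbrace{E_{x_i}^\top W_q^\top W_{k,E} E_{x_j}}_{(a)} + \underbrace{E_{x_i}^\top W_q^\top W_{k,R} R_{i-j}}_{(b)} + \underbrace{u^\top W_{k,E} E_{x_j}}_{(c)} + \underbrace{v^\top W_{k,R} R_{i-j}}_{(d)}$$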

In \((b)\) and \((d)\), \(U_j\) has been replaced with its relative position counterpart, \(R_{i-j}\).

For the terms that include the query's position, the query-side position embedding \(U_i\) is replaced with two new learned vectors, \(u\) and \(v\).  These vectors can be interpreted as biases that don't depend on the specifics of the query – the bias in \((c)\) encourages attending to some content terms more than others, and the bias in \((d)\) encourages attending to some relative positions more than others.  This substitution is motivated by the fact that the relative position of the query with respect to itself remains constant, so a query-specific position term is unnecessary.

Attention Introspection and Empirical Results

For the Transformer-XL model to make use of such long-term context, at least one head from each layer would have to make use of the full context of its attention span.  A plot of average attention weights shows that there are heads from every layer that attend broadly to prior positions.

A plot of average attention weights from the Transformer-XL paper

In addition, the Transformer-XL paper measures the impact of effective context length on perplexity and finds that increasing context length leads to better perplexity scores up to a context length of ~900 tokens – further evidence that the recurrence mechanism is useful in practice and not merely in theory.

See the kimiyoung/transformer-xl repository on GitHub for source code, or check out the HuggingFace implementation to start using Transformer-XL for your own side project.

Compressive Transformers

The next model on our list, the Compressive Transformer, builds on the Transformer-XL architecture, extending it with a compressed memory to incorporate even longer sequence lengths. In the work "Compressive Transformers for Long-Range Sequence Modelling", Jack W. Rae, Anna Potapenko, Siddhant M. Jayakumar, and Timothy P. Lillicrap from DeepMind detail a model architecture capable of attending to sequences as long as full-length books.

Compressive Transformer Attention

Diagram of the Compressed Memory in "Compressive Transformers For Long Range Sequence Modeling"

Following Transformer-XL's lead, the sequence can attend to a set of stored activations from previous segments. In addition, in the same multi-head attention operation, tokens in the current segment can attend to a second set of states stored in "compressed memory".

At each time step, the oldest compressed memories are discarded and the compressed memory is shifted back a single index.  Then, the oldest \(n\) states from the normal memory segment undergo compression and are shifted into the newly open slot in the compressed memory.

A gif from the DeepMind blog illustrates this process nicely:

Gradual compression of past memory into compressed memory – graphic courtesy of the DeepMind blog

The DeepMind team tried a variety of compressive operations (including baselines like mean pooling, max pooling, and learned convolutions), but settled on training a secondary network to reconstruct the content-based attention matrix of the memory being compressed.  

In other words, they learn a function, \(f_c\), that compresses the \(n\) oldest memory states into a single compressed memory state by minimizing the difference between attention over the compressed memory (\(C_{-1} = f_c(M_{\text{old}})\)) and attention over the states in normal memory being compressed:

$$\left\lVert \sigma\left((XW^Q)(M_{0..n} W^K)^\top\right)(M_{0..n} W^V) - \sigma\left((XW^Q)(C_{-1}W^K)^\top\right)(C_{-1}W^V) \right\rVert_2$$

Rather than training this compressive operation jointly with the main language model, they opt to update the compressive network in a separate optimization loop, as making the attention states easily compressible is counter-productive to reducing the language modeling loss.
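A rough single-head sketch of this attention-reconstruction objective might look like the following (mean squared error stands in for the paper's \(\ell_2\) norm, and all parameter names are mine):

```python
import torch
import torch.nn.functional as F

def attention_reconstruction_loss(x, old_mem, compressed_mem, W_Q, W_K, W_V):
    # The compressed memories should support (approximately) the same attention
    # output as the raw memories they replace. This loss is meant to update only
    # the compression network that produced `compressed_mem`; in a full
    # implementation the attention projections would also be frozen here.
    def attend(queries, mem):
        scores = (queries @ W_Q) @ (mem @ W_K).transpose(-2, -1)
        return F.softmax(scores, dim=-1) @ (mem @ W_V)

    target = attend(x.detach(), old_mem.detach())
    reconstruction = attend(x.detach(), compressed_mem)
    return F.mse_loss(reconstruction, target)
```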

Empirical Results

For their experiments, they use a compressed memory size of 512, a memory size of 512, a window size of 512, and a compression rate of 2 – meaning the 2 oldest memory states are compressed to a single state during the compression step. Using these settings they achieve a new state of the art test perplexity of 17.1 on WikiText-103.

As the gains from exploiting longer sequence lengths are typically long tail, they look specifically at perplexity bucketed by token rarity and note that gains are especially notable on the rarest tokens:

Perplexity bucketed by word frequency. Graphic courtesy of the DeepMind blog

Although their source code is not yet public, DeepMind has open sourced PG-19, the dataset they developed while working on the Compressive Transformer. PG-19 is a Project Gutenberg derivative intended to further research into long-term attention.

Reformer

Next up we have a work by Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya entitled "Reformer: The Efficient Transformer". The Reformer takes a different tack on increasing sequence length – rather than introducing recurrence mechanisms or a compressive memory, they opt to narrow the scope of each token's attention by using locality sensitive hashing techniques.

Locality sensitive hashing is a family of methods that map high dimensional vectors to a set of discrete values (buckets / clusters). It's most commonly used as a method for approximate nearest neighbor search.

The Reformer authors use a single projection for both the keys and queries of the attention operation, and use a random rotation-based locality sensitive hashing method to group the shared keys / queries into buckets of at most a few hundred tokens. An illustration of the hashing method is below:

Angular LSH illustration from the Reformer paper. The diagram illustrates a setup with 3 rounds of hashing with 4 buckets. The vectors in the lower illustration map to the same bucket in all three hashes because they were close in the input space, while the upper portion illustrates vectors that map to distinct buckets in the first and last hashes.

They compute attention matrices within each bucket and then take a weighted sum of the corresponding values.  Because they attend only to elements within a given bucket, this can reduce the overall complexity of the attention operation from \(O(n^2)\) to \(O(n \log{n})\) if bucket size is selected appropriately. Because the bucketing process is stochastic and based on random rotations, they compute several hashes to ensure that tokens which have similar shared key-query embeddings end up in the same bucket with high probability.
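A small NumPy sketch of the angular LSH step is below (the bucket count and number of rounds are illustrative; the real implementation then sorts tokens by bucket and chunks them so the per-bucket attention stays dense):

```python
import numpy as np

def angular_lsh_buckets(vectors, n_buckets, n_rounds=4, seed=0):
    # Project onto random rotations and take the argmax over [xR ; -xR] to
    # assign each vector one of n_buckets buckets per hashing round.
    # Vectors with small angular distance collide with high probability.
    rng = np.random.default_rng(seed)
    dim = vectors.shape[-1]
    all_buckets = []
    for _ in range(n_rounds):
        R = rng.standard_normal((dim, n_buckets // 2))
        projections = vectors @ R
        all_buckets.append(
            np.argmax(np.concatenate([projections, -projections], axis=-1), axis=-1)
        )
    return np.stack(all_buckets)  # shape: (n_rounds, n_vectors)
```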

They additionally employ techniques introduced in "The Reversible Residual Network: Backpropagation Without Storing Activations" to keep train-time memory consumption under control. Reversible residual layers use a clever architectural structure that allows layer inputs to be easily re-constructed from layer outputs, trading extra computation for a memory complexity that is constant in network depth.

With the locality sensitive hashing trick to reduce computational cost, and the reversible residuals to reduce memory consumption, the Reformer architecture is able to process sequences up to 64,000 tokens long on a single accelerator.

Although the reported score of 1.05 bits per character on enwiki-8 lags behind some of the other models we've looked at in the course of this blog post, the Reformer is a refreshingly unique take on a mechanism to incorporate long term context and I'm looking forward to seeing how the approach scales up.

If you're interested in exploring Reformer architecture in more detail, take a look at my recent blog post "A Deep Dive into the Reformer" on the subject.  An open source implementation of the Reformer is available as an example in the google/jax Github repository.  A PyTorch version maintained by Phil Wang is also available.

Routing Transformer

I originally intended to cut my list here, but I've included one final paper at the suggestion of Aran Komatsuzaki. A second paper submitted to ICLR 2020, "Efficient Content-Based Sparse Attention with Routing Transformers" by Aurko Roy, Mohammad Taghi Saffar, David Grangier, and Ashish Vaswani, shares some similarities with the aforementioned Reformer. They frame the problem as one of routing, and aim to learn to select sparse clusters of tokens, \(S_i\), as a function of the content, \(x\).

The authors illustrate their approach in the diagram below.  Rather than solely attending to local elements or every nth element to increase sparsity, they learn clusters (denoted by color in figure \((c)\)) within which to attend. Importantly these clusters are a function of the content of each key and query, not just their absolute or relative positions.

Comparison of Routing Attention with Local and Strided Attention, from the Routing Transformer paper.

Routing Attention

After ensuring each key and query vector has unit magnitude, they project key and query values using a shared matrix of random orthogonal weights of shape \((D_k, D_k)\), where \(D_k\) is the hidden dimension of the keys and queries.

$$ R = \begin{bmatrix} Q, K \end{bmatrix} \begin{bmatrix} W_R \\ W_R \end{bmatrix}$$

The vectors in \(R\) are then grouped into clusters, \(C\), according to a set of k-means centroids.  The centroids are learned separately from the gradient descent process, using one application of the k-means update rule per batch.

Within a given cluster, \(C_k\), they compute a new set of contextual embeddings using a typical weighted sum of values, where each attention weight, \(A_{ij}\), is computed using typical dot-product self-attention.

$$ X_i^{\prime} = \sum_{j \in C_k}{A_{ij}V_j} $$

Because attention patterns in dense attention are often dominated by a few key elements, and because the cluster assignment process should group keys and queries with high attention weights into the same cluster, the authors argue that this preserves the key information that would have informed \(X_i^{\prime}\) had an expensive dense operation been applied.

Finally, they choose a number of clusters close to \(\sqrt{n}\), so that the overall complexity of their sparse content-based attention mechanism becomes \(O(n\sqrt{n})\).  To make the whole process easily parallelizable and deal with matrices of uniform size, the authors use the top-k terms closest to each centroid in place of the true k-means cluster assignments.
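Putting the pieces together, a very rough single-head sketch of routing attention might look like the following (no causal masking, cluster assignment by the query vectors only, and a fixed per-cluster budget standing in for the paper's top-k trick):

```python
import torch
import torch.nn.functional as F

def routing_attention(q, k, v, centroids, budget):
    # Unit-normalize queries and keys, assign each position to its nearest
    # k-means centroid, then attend only among positions routed to the same
    # cluster. `budget` caps cluster size so every cluster has a uniform shape.
    q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
    assignments = (q @ centroids.t()).argmax(dim=-1)   # (n,) cluster ids
    out = torch.zeros_like(v)
    for c in range(centroids.shape[0]):
        idx = (assignments == c).nonzero(as_tuple=True)[0][:budget]
        if idx.numel() == 0:
            continue
        attn = F.softmax(q[idx] @ k[idx].t(), dim=-1)  # dense attention within the cluster
        out[idx] = attn @ v[idx]
    return out
```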

In addition to the content-based routing attention, the Routing Transformer also performs local attention over a context window of size 256.

Empirical Results

The Routing Transformer's gains in computational efficiency also lead to perplexity gains on Wikitext-103, a word-level language modeling benchmark, where they outperform the Transformer-XL model described previously by a significant margin.

Wikitext-103 test set perplexity, from the Routing Transformer paper.

On enwiki-8, the Routing Transformer also performs quite well, although their results lag marginally behind the Adaptive Span Transformer.

Test set bits per character on the enwiki-8 character level language modeling benchmark.

I originally couldn't find an implementation of the Routing Transformer, but Aurko Roy was kind enough to point me to a zip of their source that was released as part of the ICLR review process.

Other Approaches To Long Term Context in Transformers

If you're interested in other approaches to incorporating long-term context in transformers, you might also enjoy reading:

Did I miss a paper you think should have been included?  Send me your suggestions on twitter!