LLMs with fast, infinite-length conversations
Paper: Efficient Streaming Language Models with Attention Sinks
Summary by Adrian Wilkins-Caruana and Tyler Neylon
When I use OpenAI’s GPT models, I occasionally opt for GPT-3.5 instead of the bigger, more capable GPT-4. When I do, it’s almost always because I want my result quickly, and GPT-3.5, being a smaller model, streams its output to my browser much faster than GPT-4 can. Behind the scenes at OpenAI, there’s undoubtedly a complex suite of software engineered to deliver GPT results as fast as possible. But at some point, AI engineers run up against the fundamental limits of their models. In the case of transformer models (like GPT-3.5 and GPT-4), there’s a tradeoff between a model that generates the best possible results and one that generates results quickly enough for users’ patience and servers’ processing capacity. Today’s paper is about a breakthrough in how transformers are engineered, one that lets them stream their results more quickly and efficiently without sacrificing their language modeling capabilities.
The figure below shows two things that can go wrong with a transformer model running on a text window of size T, where T is large (that is, an input with many tokens). The blue-and-gray matrix is a symbolic visualization of the model’s attention mask, showing which internal token states depend on which others. The red square denotes the output token, which, within each attention layer, may attend to all of the previous (T-1) tokens. Similarly, the dark blue squares denote the internal token vectors; each one is modified every time it passes through an attention layer, and each may also depend on all previous token vectors. That adds up to O(T^2) dependencies in the computation of a single output token, which is a lot when T is large.
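To make the O(T^2) figure concrete, here is a minimal sketch that builds a causal attention mask for a single layer and counts how many (query, key) pairs it allows. It uses nested Python lists as a stand-in for the tensors a real model would use; the function names are illustrative, not from the paper.

```python
def causal_mask(T):
    """Return a T x T lower-triangular mask: mask[i][j] is True when
    token i is allowed to attend to token j (only j <= i)."""
    return [[j <= i for j in range(T)] for i in range(T)]


def count_dependencies(mask):
    """Count the (query, key) pairs the mask allows in one attention layer."""
    return sum(sum(row) for row in mask)


for T in (8, 64, 512):
    deps = count_dependencies(causal_mask(T))
    # A causal mask allows T * (T + 1) / 2 pairs, which grows as O(T^2).
    print(T, deps, T * (T + 1) // 2)
```

Every layer repeats this pattern, so doubling T roughly quadruples the attention work per layer.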
That O(T^2) translates to high memory and computational costs, but it’s not the only problem. If the model were only trained on inputs up to length L < T, then we’d also expect the quality of the output token to be poor, as indicated by a high perplexity (PPL) value. For long text windows, we want both efficiency and output quality to improve.
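Perplexity is the exponential of the average negative log-likelihood the model assigns to the tokens that actually occur, so lower is better. The sketch below computes it from a list of per-token probabilities; the probabilities are made up for illustration and are not data from the paper.

```python
import math


def perplexity(token_probs):
    """exp(mean negative log-likelihood) of the probabilities the model
    assigned to the tokens that actually occurred."""
    nll = [-math.log(p) for p in token_probs]
    return math.exp(sum(nll) / len(nll))


confident = [0.5, 0.5, 0.5, 0.5]   # hypothetical: model is fairly sure
confused = [0.01, 0.02, 0.01, 0.05]  # hypothetical: model is mostly guessing

print(perplexity(confident))  # ≈ 2.0
print(perplexity(confused))   # ≈ 56.2
```

A perplexity of about 2 means the model is, on average, about as uncertain as a fair coin flip between two tokens; a perplexity in the dozens means it is spreading its bets over many more candidates.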
Clever AI engineers have come up with two ways around these problems, but each has its own drawbacks. One method, called “window attention” and shown on the left of the figure below, computes attention only over a window of the L most recent tokens. Each time a new output token is called for, the model re-uses the already-computed token vectors (the cached keys and values of those L tokens) and only needs to perform new calculations for the output token, so each new output requires O(L) time (thus O(TL) time for the entire process). This improves efficiency but not quality: once the text grows longer than L tokens, the earliest tokens fall out of the window, and the model’s perplexity shoots up. Another approach, called “sliding window with re-computation” and shown on the right, rebuilds the attention’s key-value store from scratch for each window, which requires O(L^2) time for each output token (thus O(TL^2) for the entire process). Because this uses the model exactly as it was trained, the quality is good, but it remains slow.
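The gap between the two approaches can be seen with a simple counting model. The sketch below counts attention scores computed per layer while generating T tokens, under the simplifying assumption that attention-score count is a fair proxy for cost (it ignores MLPs, projections, and hardware effects). The function names are hypothetical.

```python
def window_attention_cost(T, L):
    """Attention scores per layer when each new token's query attends only
    to the cached keys of (at most) the L most recent tokens."""
    return sum(min(t, L) for t in range(1, T + 1))


def sliding_window_recompute_cost(T, L):
    """Attention scores per layer when, for every new token, the last
    (at most) L tokens are re-encoded from scratch with a causal mask."""
    total = 0
    for t in range(1, T + 1):
        w = min(t, L)               # current window length
        total += w * (w + 1) // 2   # full causal attention inside the window
    return total


T, L = 10_000, 1_000
print(window_attention_cost(T, L))          # grows like T * L
print(sliding_window_recompute_cost(T, L))  # grows like T * L^2 / 2
```

With these numbers the re-computation approach does roughly L/2 ≈ 500 times more attention work, which is the O(TL) versus O(TL^2) gap in concrete form.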
The contribution of this paper stemmed from a curious observation by the authors, Xiao et al.: a surprisingly large share of attention is allocated to the initial tokens, regardless of their relevance to the language modeling task. This phenomenon is visualized in the figure below, which shows the strength of the attention weights (red is stronger, blue is weaker) in various layers of the transformer; the nth row shows how much the nth token vector depends on each of the previous token vectors. The first two layers (left and middle) exhibit the “local modeling phenomenon,” where weights are stronger for nearby tokens (reds near the diagonal). But in the subsequent layers (one of which is depicted on the right), much of the attention is concentrated on the first token (the left column). This explains the poor behavior of window attention, which evicts the first token: the model can no longer attend to the tokens at the beginning of the sequence that it leans on so heavily!
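The paper links this behavior to softmax: attention weights must sum to 1, so a head with nothing useful to attend to still has to put its weight somewhere, and the always-visible first token becomes the dumping ground. The sketch below uses hypothetical logits for one query to show what happens when that token is evicted: the leftover weight is forced onto the remaining tokens, and the head’s output changes drastically.

```python
import math


def softmax(scores):
    """Numerically stable softmax: weights are positive and sum to 1."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical attention logits for one query in a later layer: the first
# token gets a large score (a "sink"); the others are small and similar.
scores = [6.0, 0.5, 0.2, 0.4, 0.1, 0.3]

with_sink = softmax(scores)         # first token soaks up most of the weight
without_sink = softmax(scores[1:])  # window attention has evicted token 0

print([round(w, 3) for w in with_sink])
print([round(w, 3) for w in without_sink])
```

With the sink present, each ordinary token gets well under 1% of the attention; without it, each gets roughly 16% to 24%, a distribution the model never saw during training.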
The authors’ solution to this problem, which they call StreamingLLM, is to keep these initial tokens around as “attention sinks.” It’s essentially the same approach as window attention, except that it always retains the first few token vectors of the sequence instead of evicting them. This way, the model can still attend to these initial tokens. The attention sink method is shown below. Notably, it keeps window attention’s efficiency (a fixed-size cache and O(TL) time overall), but it also achieves low perplexity!
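The cache policy itself is simple to sketch. Below is a minimal, hypothetical `SinkKVCache` class that keeps the first `num_sinks` entries forever plus a rolling window of the most recent `window` entries. Strings stand in for the per-layer key/value tensors a real model would store, and the class name and parameters are illustrative assumptions, not the paper’s code.

```python
from collections import deque


class SinkKVCache:
    """Keeps the first `num_sinks` entries forever plus the `window` most
    recent entries; everything in between is evicted."""

    def __init__(self, num_sinks, window):
        self.num_sinks = num_sinks
        self.sinks = []                     # initial tokens, never evicted
        self.recent = deque(maxlen=window)  # a full deque drops its oldest item

    def append(self, kv):
        """Add the key/value entry for a newly processed token."""
        if len(self.sinks) < self.num_sinks:
            self.sinks.append(kv)
        else:
            self.recent.append(kv)

    def visible(self):
        """Entries the next query is allowed to attend to."""
        return self.sinks + list(self.recent)


cache = SinkKVCache(num_sinks=4, window=8)
for t in range(20):
    cache.append(f"tok{t}")
print(cache.visible())
# ['tok0', 'tok1', 'tok2', 'tok3', 'tok12', 'tok13', ..., 'tok19']
```

The cache never holds more than `num_sinks + window` entries, however long the stream runs, which is why memory stays fixed while the sinks remain available to every new query.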
When retaining the first four tokens as attention sinks, the authors show that existing pre-trained models like Llama 2 can stably model sequences of up to four million tokens, and predict new tokens up to 22 times faster than the sliding window with re-computation approach! (It’s faster because there are fewer token dependencies: O(TL) is much less work than O(TL^2).) The authors also pre-trained their own language models to use only a single token as the attention sink: they include a special, learnable token at the beginning of every training example that serves as the designated attention sink.
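On the data side, the single-sink setup amounts to prepending one reserved token to every training sequence. The sketch below assumes a hypothetical token ID (`SINK_TOKEN_ID = 0`); in a real model that ID would need its own row in the embedding table, learned during pre-training like any other token.

```python
SINK_TOKEN_ID = 0  # hypothetical ID reserved for the learnable sink token


def add_sink_token(examples, sink_id=SINK_TOKEN_ID):
    """Prepend the dedicated sink token to every tokenized training example."""
    return [[sink_id] + list(ex) for ex in examples]


batch = [[17, 42, 5], [8, 9]]  # hypothetical tokenized examples
print(add_sink_token(batch))  # [[0, 17, 42, 5], [0, 8, 9]]
```

Because the sink token appears at position 0 in every example, the model can learn to park unneeded attention on it, so at inference time that one token is the only one that has to be kept ahead of the rolling window.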
I want to be clear that the attention sink method doesn’t allow a model to, for example, refer to page 5 of a book while it’s modeling page 100. The model can only look at the text inside its attention window (plus the few sink tokens at the very start). What’s new about this approach is that LLMs can now read or write much longer sequences very quickly while maintaining high output quality. Previously you’d have to give up either speed or quality.
Historically, a limited context length has been one of the bottlenecks of LLMs, but valiant researchers are doing their part to vanquish it! As language models become more pervasive by the day, the attention sink approach can make them a lot faster and more usable. Thanks, valiant researchers!