Summary by Adrian Wilkins-Caruana
[Links to papers at the bottom]
According to Google’s own technical report, their new Gemini model, Gemini 1.5 Pro, is an impressive beast. It boasts multimodality, super-efficient inference, and class-leading language modeling abilities. But its most impressive feature (to me) is its absolutely bonkers context length of up to 10M tokens! That’s enough to hold the entirety of the 1,440-page book "War and Peace" (587,287 words) more than ten times over. Unfortunately, the report doesn’t disclose exactly how the Gemini team achieved this impressive feat. They are no doubt concealing several cleverly engineered technical solutions, but (and this is entirely speculation on my part) perhaps some of these solutions are a series of techniques developed by Hao Liu, a PhD student at UC Berkeley.
Last August, Liu published a new algorithm (paper [1]) for computing the attention and feedforward steps in Transformers using blocks. A single block is just a subset of either the query vectors or the key-value vector pairs. Instead of computing the interactions between all query, key, and value vectors at once and feeding the result into the subsequent feedforward projection, the data is split into two groups of blocks: one group for the queries and another for the keys and values. The output of the attention and feedforward steps is then computed block by block, as shown here:
Liu calls this method a Blockwise Parallel Transformer (BPT), since the blocks can all be processed in parallel. But the main benefit of BPT isn’t faster computation via parallelism; it’s that each blockwise computation needs significantly less memory than its non-BPT equivalent. Normally, if an LLM accepts an input sequence of length s, then multiplying all the query vectors in matrix Q by all the key vectors in matrix K gives a large s×s matrix. Traditionally, that entire matrix is stored in memory at once because each row must be normalized via softmax to get the weights used by the attention mechanism, and the softmax of a row depends on every entry in that row. BPT avoids this by keeping a running maximum and running sum for each row as it works through the key blocks, so only a small tile of the matrix needs to exist at any one time. The next image shows one aspect of BPT: how it uses a smaller query-key product to reduce memory needs.
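To make the blockwise idea concrete, here is a minimal pure-Python sketch, assuming single-head attention with no masking. It is not Liu’s code, and the function names (`naive_attention`, `blockwise_attention`) and block sizes are hypothetical. The blockwise version loops over query blocks (outer loop) and key-value blocks (inner loop), keeping a running maximum and running sum of exponentials for each query row (the "online softmax" trick), so only a `q_block` × `kv_block` tile of scores exists at a time. It should still match vanilla attention up to floating-point rounding.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def naive_attention(Q, K, V):
    """Vanilla attention: builds the full s x s score matrix before the softmax."""
    d = len(Q[0])
    scores = [[dot(q, k) / math.sqrt(d) for k in K] for q in Q]  # s x s held in memory
    out = []
    for row in scores:
        m = max(row)
        w = [math.exp(x - m) for x in row]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))])
    return out


def blockwise_attention(Q, K, V, q_block, kv_block):
    """Same result, but only a q_block x kv_block tile of scores exists at a time."""
    d, dv = len(Q[0]), len(V[0])
    out = []
    for qs in range(0, len(Q), q_block):              # outer loop: query blocks
        Qb = Q[qs:qs + q_block]
        m = [-math.inf] * len(Qb)                     # running max score per query row
        l = [0.0] * len(Qb)                           # running sum of exp(score - m)
        acc = [[0.0] * dv for _ in Qb]                # running unnormalized output
        for ks in range(0, len(K), kv_block):         # inner loop: key-value blocks
            Kb, Vb = K[ks:ks + kv_block], V[ks:ks + kv_block]
            tile = [[dot(q, k) / math.sqrt(d) for k in Kb] for q in Qb]
            for i, row in enumerate(tile):
                new_m = max(m[i], max(row))
                c = math.exp(m[i] - new_m)            # rescales old stats (0.0 on first block)
                p = [math.exp(x - new_m) for x in row]
                l[i] = l[i] * c + sum(p)
                acc[i] = [a * c + sum(pj * v[j] for pj, v in zip(p, Vb))
                          for j, a in enumerate(acc[i])]
                m[i] = new_m
        out.extend([a / l[i] for a in acc[i]] for i in range(len(Qb)))
    return out


random.seed(0)
s, d = 10, 4
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(s)] for _ in range(3))
ref = naive_attention(Q, K, V)
blk = blockwise_attention(Q, K, V, q_block=3, kv_block=4)
print(max(abs(a - b) for r1, r2 in zip(ref, blk) for a, b in zip(r1, r2)))
```

The printed value is the largest element-wise difference between the two outputs, which should be at the level of floating-point rounding. Block sizes that don’t divide the sequence length evenly are handled by the shorter final slices.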
With that one idea, blockwise computation already reduces memory usage from O(s²) to O(s). But BPT does more: it also computes, in blocks, the output of the feedforward network (FFN) that follows the attention module. The FFN’s hidden layer is normally four times wider than its input, so it uses up 4x as much memory as the FFN’s output. Holding onto that data would be the next memory bottleneck (after the s×s Q-K product). BPT reduces memory usage here, too, by only requiring us to hold a single (query-based) block in memory at a time. In other words, the FFN only ever runs on a short subsequence at a time, which requires less memory. This shrinks the memory bottleneck further still, which is how BPT manages to be up to 4x more memory-efficient even when compared against Flash Attention, another method of reducing attention’s memory needs.
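The blockwise FFN works because a position-wise FFN treats every token independently, so running it on one query block at a time gives exactly the same outputs while only one block’s hidden layer is alive. Here is a minimal sketch, assuming a two-layer FFN with a GELU activation and no biases; the names (`ffn`, `blockwise_ffn`) are hypothetical, and "hidden activations held" is just a count of numbers used as a stand-in for memory.

```python
import math
import random


def gelu(x):
    # tanh approximation of GELU (an assumed choice of activation)
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def ffn(X, W1, W2):
    """Two-layer FFN applied row by row. Returns (outputs, hidden activations held)."""
    H = [[gelu(sum(x[k] * W1[k][j] for k in range(len(x)))) for j in range(len(W1[0]))]
         for x in X]                                   # len(X) x 4d hidden layer
    Y = [[sum(h[k] * W2[k][j] for k in range(len(h))) for j in range(len(W2[0]))]
         for h in H]                                   # len(X) x d output
    return Y, len(H) * len(H[0])


def blockwise_ffn(X, W1, W2, block):
    """Run the FFN one query block at a time; only one block's hidden layer is alive."""
    out, peak_hidden = [], 0
    for start in range(0, len(X), block):
        Y, held = ffn(X[start:start + block], W1, W2)
        out.extend(Y)
        peak_hidden = max(peak_hidden, held)
    return out, peak_hidden


random.seed(0)
s, d = 32, 4
hidden = 4 * d                                         # hidden layer 4x wider than input
X = [[random.gauss(0, 1) for _ in range(d)] for _ in range(s)]
W1 = [[random.gauss(0, 0.5) for _ in range(hidden)] for _ in range(d)]
W2 = [[random.gauss(0, 0.5) for _ in range(d)] for _ in range(hidden)]

full, full_hidden = ffn(X, W1, W2)
blocked, block_hidden = blockwise_ffn(X, W1, W2, block=8)
print(full == blocked, full_hidden, block_hidden)      # True 512 128
```

Because each row goes through identical arithmetic either way, the outputs are bit-for-bit equal, while the peak hidden-layer size drops from 32 × 16 to 8 × 16 numbers.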
Despite how efficiently BPT uses memory, its memory requirement still scales linearly with the input sequence length, since the activations for the whole sequence still have to fit on a single device. So even though BPT is very efficient, it still supports an impressive but limited context length. BPT alone couldn’t possibly account for Gemini 1.5 Pro’s gargantuan context length.
A short three months later, Liu followed up the BPT paper with another BPT-based paper. It extends the BPT approach with something called Ring Attention (paper [2]), so that each device’s memory requirement scales linearly with the size of the query, key, and value blocks rather than with the length of the input sequence. The idea behind Ring Attention is that BPT’s parallel block computations can be distributed among several devices (think GPUs or TPUs), where each device computes one iteration of the outer loop from the figure above.
The main problem with distributing BPT across devices this way is that, even though each device needs only its own query block, it still needs access to all of the key-value blocks, and storing all of them on every device would undo the memory savings. To overcome this challenge, the devices are arranged in a ring-like topology, in which each device exchanges data only with its immediate neighbors: it receives from the previous device and sends to the next one, and the last device connects back to the first.
BPT’s fused attention-and-feedforward computation then proceeds as follows: each device holds the data for one of the query blocks in the outer loop, while the key-value blocks are passed from device to device around the ring. This is really efficient because, while a device is computing blockwise attention, it can simultaneously send the key-value block it currently holds to the next device and receive a new key-value block from the previous device. As long as each blockwise computation takes at least as long as transferring a key-value block, the communication is effectively hidden behind the computation.
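This schedule can be simulated on a single machine. Below is a minimal pure-Python sketch under several simplifying assumptions: the "devices" are just list entries, there is one attention head and no causal mask, the feedforward step is left out, the sequence length divides evenly by the device count, and the transfer happens after each compute step instead of overlapping with it. All names (`ring_attention`, `update`, `n_devices`) are hypothetical and not taken from Liu’s code. Each device keeps a running softmax state for its own query rows, folds in whichever key-value block it currently holds, then passes that block along the ring.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def full_attention(Q, K, V):
    """Reference: vanilla softmax attention over the whole sequence."""
    out = []
    for q in Q:
        scores = [dot(q, k) / math.sqrt(len(q)) for k in K]
        m = max(scores)
        w = [math.exp(x - m) for x in scores]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / z for j in range(len(V[0]))])
    return out


def update(state, q, Kb, Vb):
    """Fold one key-value block into one query row's running (max, sum, output) state."""
    m, l, acc = state
    scores = [dot(q, k) / math.sqrt(len(q)) for k in Kb]
    new_m = max(m, max(scores))
    c = math.exp(m - new_m)                        # 0.0 on the first block (m = -inf)
    p = [math.exp(x - new_m) for x in scores]
    acc = [a * c + sum(pi * v[j] for pi, v in zip(p, Vb)) for j, a in enumerate(acc)]
    return new_m, l * c + sum(p), acc


def ring_attention(Q, K, V, n_devices):
    """Simulate Ring Attention: device i owns query block i while KV blocks rotate."""
    b = len(Q) // n_devices
    assert b * n_devices == len(Q), "sequence must split evenly across devices"
    q_blocks = [Q[i * b:(i + 1) * b] for i in range(n_devices)]
    kv_blocks = [(K[i * b:(i + 1) * b], V[i * b:(i + 1) * b]) for i in range(n_devices)]
    init = (-math.inf, 0.0, [0.0] * len(V[0]))
    states = [[init] * b for _ in range(n_devices)]
    for _ in range(n_devices):
        # Compute: each device attends its query block to the KV block it holds.
        for dev in range(n_devices):
            Kb, Vb = kv_blocks[dev]
            states[dev] = [update(st, q, Kb, Vb) for st, q in zip(states[dev], q_blocks[dev])]
        # Communicate: send the KV block to device dev + 1, receive from dev - 1.
        # (On real hardware this transfer overlaps with the compute step above.)
        kv_blocks = [kv_blocks[(dev - 1) % n_devices] for dev in range(n_devices)]
    return [[a / l for a in acc] for dev_states in states for (_, l, acc) in dev_states]


random.seed(1)
s, d, n = 12, 4, 4
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(s)] for _ in range(3))
ref, ring = full_attention(Q, K, V), ring_attention(Q, K, V, n_devices=n)
print(max(abs(a - b) for r1, r2 in zip(ref, ring) for a, b in zip(r1, r2)))
```

After `n_devices` steps, device `i` has held every key-value block exactly once (block `(i - t) % n_devices` at step `t`), so its output matches vanilla attention up to floating-point rounding, while no device ever stored more than one key-value block.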
Ring Attention is depicted in the figure below. The non-faded boxes represent computations (solid outlines) and data (dashed outlines) stored on one of the hosts, while the faded boxes represent data for other hosts. As you can see, the depicted host is responsible for all the attention and feedforward computations for a single query block, while the keys and values travel around the ring.
Why Ring Attention computes exactly the same result as a vanilla Transformer is actually quite subtle. Indeed, it’s not obvious that blockwise self-attention and feedforward computations are equivalent to their vanilla counterparts, but they are! That’s because the final result of the blockwise computations doesn’t depend on the order in which they are computed: each block’s partial attention result can be rescaled and combined with the others in any order. It’s similar to how 4+5 gives the same result as 5+4: it doesn’t matter which “block” comes first.
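The "4+5 = 5+4" analogy can be checked directly. In this minimal sketch, each block of keys produces a partial result `(max score, sum of exponentials, exponentially weighted value sum)`, and a `merge` function rescales two partial results to a shared maximum before adding them. The function names, the precomputed scores, and the use of scalar values instead of vectors are all simplifying assumptions for illustration.

```python
import functools
import itertools
import math


def partial(scores, values):
    """Softmax statistics for one block: (max score, sum of exps, exp-weighted value sum)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return m, sum(w), sum(wi * v for wi, v in zip(w, values))


def merge(a, b):
    """Combine two partial results by rescaling both to a shared maximum."""
    (ma, la, xa), (mb, lb, xb) = a, b
    m = max(ma, mb)
    ca, cb = math.exp(ma - m), math.exp(mb - m)
    return m, la * ca + lb * cb, xa * ca + xb * cb


def finish(state):
    """Turn a (possibly merged) partial result into the softmax-weighted average."""
    _, l, x = state
    return x / l


scores = [0.3, 2.1, -1.0, 0.7, 1.5, -0.2]   # hypothetical q.k scores for one query
values = [1.0, -2.0, 0.5, 3.0, 0.0, 4.0]    # hypothetical scalar values
blocks = [partial(scores[i:i + 2], values[i:i + 2]) for i in range(0, 6, 2)]

reference = finish(partial(scores, values))  # all keys at once
results = [finish(functools.reduce(merge, order)) for order in itertools.permutations(blocks)]
print(all(abs(r - reference) < 1e-12 for r in results))   # True
```

All six orderings of the three blocks give the same answer as processing every key at once (up to rounding), which is what lets each device fold in key-value blocks in whatever order they arrive around the ring.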
Ring Attention takes BPT’s memory efficiency to a whole new level, achieving context lengths 2–3 orders of magnitude longer than the competition. Intuitively, you could say that Ring Attention can work with a context window proportional to the number of devices in the ring. The figure below shows the maximum context size for various Transformer architectures on a TPUv4-1024 pod slice, which is made up of hundreds of TPUv4 chips with 32GB of memory each. The green bar represents the maximum context length for a vanilla Transformer, the yellow for a Transformer with Flash Attention, lavender for BPT, and red for BPT with Ring Attention.
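The "proportional to the number of devices" intuition is simple arithmetic: if one device’s memory caps it at a block of some fixed size, a ring of n devices can cover n such blocks. This sketch uses a hypothetical 32,768-token block per device and ignores model weights, optimizer state, and other overheads, so the numbers are illustrative only.

```python
def ring_context_length(n_devices, block_tokens):
    """Longest sequence the ring can cover if each device owns one block of block_tokens."""
    return n_devices * block_tokens


BLOCK = 32_768  # hypothetical per-device block size, fixed by one device's memory
for n in (1, 8, 64, 512):
    # The per-device block (and so per-device memory) stays the same; only n grows.
    print(f"{n:>4} devices -> {ring_context_length(n, BLOCK):>12,} tokens")
```

With these assumed numbers, going from 1 to 512 devices takes the context from 32,768 to 16,777,216 tokens without any single device needing more memory.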
As if that wasn’t enough, Liu and coauthor Wilson Yan wrote a new paper that extends BPT with Ring Attention even further (paper [3]). In their paper, they present some ways to train models that understand text and video sequences. They also present solutions to various technical challenges that arise when doing such training, such as what to do when text and video sequences have different lengths or how to apply different weights to language and vision inputs. This article is getting quite long, but I encourage you to give their paper a read if you want to learn more.
Gemini 1.5 Pro is an undeniably impressive feat of engineering, requiring vast amounts of resources. But it’s important to remember the human element behind advancements like Gemini. I don’t know of an official link between Hao Liu and Gemini (maybe they’re not really linked), but we can still appreciate the significance of Liu’s contributions in ushering in a new era of LLM context lengths. Individual researchers like Liu can make significant contributions, even seemingly single-handedly developing groundbreaking algorithms that propel entire fields forward. Even the most opaque and large-scale technological achievements are ultimately driven by personal ingenuity.
—
Hi, it’s still me, Adrian. I hope you enjoyed today’s summary as much as I enjoyed writing it. Actually, I lie — mostly I enjoyed reading Hao Liu’s papers. As someone who just submitted their PhD thesis, let me say that what Liu has achieved over the span of their PhD so far is nothing short of incredible. If you share this view, you should definitely give these papers a thorough read:
[1] Blockwise Parallel Transformer for Large Context Models (This is the original BPT paper.)
[2] Ring Attention with Blockwise Transformers for Near-Infinite Context (This is the one that extends BPT with Ring Attention.)
[3] World Model on Million-Length Video and Language with Blockwise Ring Attention (This is the one that does the crazy multimodal text-video training.)