Paper: MEGABYTE: Predicting Million-byte Sequences with Multiscale Transformers
Summary by Adrian Wilkins-Caruana
By now, you’ve probably heard of transformers, the neural-network architecture introduced by researchers at Google in 2017. Transformers are really good at modeling short sequences, on the scale of several paragraphs, like a multi-turn ChatGPT conversation. But they become memory-limited on longer sequences, like entire documents. This is the problem that researchers from Meta are trying to address with their new, transformer-inspired architecture, MegaByte. As its name suggests, MegaByte can predict sequences of one million bytes. To give you a sense of what that means: at roughly six bytes per English word, a million-byte sequence is on the order of 170k words, several master’s theses’ worth of text, and much longer than a typical ChatGPT conversation or GPT-3 response.
MegaByte and transformers have key architectural differences (which I’ll discuss shortly), with one of the biggest being what each model considers an “element” of its input. Transformers like GPT-3 convert their input sequences of plain text into sequences of “tokens.” Each token is a short sequence of characters, like “the,” “ized,” or “today.” GPT-3’s vocabulary consists of about 50k tokens, and transformers typically process and predict these tokens directly. MegaByte, on the other hand, predicts the individual bytes that make up the characters in the text. For example, in ASCII, the text “mega” is encoded as the following sequence of integer-valued bytes: “109 101 103 97.” MegaByte’s vocabulary therefore consists of just the 256 possible values of an 8-bit byte.
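Here’s what that byte-level view looks like in practice; a couple of lines of Python are enough to see it:

```python
# Encode text as raw bytes: each ASCII character maps to one integer in 0-255.
text = "mega"
byte_ids = list(text.encode("ascii"))
print(byte_ids)  # [109, 101, 103, 97]

# No tokenizer needed: the model's "vocabulary" is every possible byte value.
vocab_size = 256
```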
Tokenization is a language-specific heuristic, and it adds complexity to LLMs. So why do transformers use tokens instead of bytes?
Consider what happens when you run a sentence through OpenAI’s interactive tokenizer: the same text becomes a token sequence several times shorter than its byte sequence. Because the attention mechanism scales quadratically with sequence length, a byte-level transformer would need far more memory to model a text than a token-level one. Tokenization is what makes it feasible to model multi-paragraph texts. But attention still scales quadratically, so transformers remain limited by long sequences of tokens.
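To make the gap concrete, here’s a minimal sketch using OpenAI’s tiktoken library (my choice of the cl100k_base encoding and the example sentence are illustrative, not from the paper):

```python
import tiktoken  # OpenAI's open-source tokenizer library: pip install tiktoken

text = "Transformers are really good at modeling short sequences."

# Token-level view: a GPT-style tokenizer packs roughly 4 characters per token.
enc = tiktoken.get_encoding("cl100k_base")
tokens = enc.encode(text)

# Byte-level view: one element per byte of UTF-8 text.
byte_ids = list(text.encode("utf-8"))

print(len(tokens), len(byte_ids))  # roughly 10 tokens vs. 57 bytes
# With O(n^2) attention, the byte-level sequence costs ~(57/10)^2 ≈ 30x more.
```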
MegaByte can model text as bytes instead of tokens because its memory complexity doesn’t scale quadratically with sequence length. It does this by splitting the input sequence into “patches.” Patching is a technique borrowed from vision transformers (ViTs), where the pixels (which are essentially bytes) in an image (which may contain millions of bytes) are chunked into patches before being modeled by the transformer. But unlike ViTs, which typically only classify images, MegaByte needs to generate new text at its output, and this is where its unique architecture comes into play.
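The patching step itself is simple; here’s a minimal PyTorch sketch (the patch size and variable names are my own, not the paper’s code):

```python
import torch

seq_len, patch_size = 1_000_000, 8            # 1M bytes, 8 bytes per patch
byte_ids = torch.randint(0, 256, (seq_len,))  # stand-in for a real byte sequence

# Chunk the flat byte sequence into (num_patches, patch_size).
patches = byte_ids.view(seq_len // patch_size, patch_size)
print(patches.shape)  # torch.Size([125000, 8])

# Embedding each byte and concatenating within a patch gives one "patch
# embedding" per row, so the global model sees a sequence 8x shorter.
```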
Instead of modeling tokens, or even individual bytes directly, MegaByte’s input is a sequence of embedded patches; a “global” transformer models this sequence. A smaller “local” model then predicts the bytes within each patch. Each model learns to predict the next “thing” in its own sequence. In the paper’s architecture diagram, for example, the global model begins with a start-of-sequence token “___” at h-global-in and predicts “_meg” at h-global-out (where “_meg” is a four-character sequence and “_” marks the beginning of the output). The local model then predicts “mega” at h-local-out given “_meg” at h-local-in.
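To make the global/local split concrete, here’s a heavily simplified PyTorch sketch of the two-level structure. All module choices and dimensions are my own illustrative assumptions: the real model uses causal transformer decoders, offsets the inputs for next-byte prediction, and projects the global outputs rather than simply adding them.

```python
import torch
import torch.nn as nn

V, P, D = 256, 4, 128  # byte vocab, patch size, embedding dim (illustrative)

byte_embed = nn.Embedding(V, D)
# Stand-ins for the paper's global and local transformer decoders.
global_model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=P * D, nhead=8, batch_first=True),
    num_layers=2,
)
local_model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=D, nhead=8, batch_first=True),
    num_layers=2,
)
to_logits = nn.Linear(D, V)

bytes_in = torch.randint(0, V, (1, 1024))   # (batch=1, seq_len=1024)
x = byte_embed(bytes_in)                     # (1, 1024, D)

# Concatenate the byte embeddings within each patch into one patch embedding.
patches = x.view(1, 1024 // P, P * D)        # (1, 256 patches, P*D)
h_global = global_model(patches)             # one contextual state per patch

# Each patch's global state conditions a local model over that patch's bytes.
# All 256 patches are processed in parallel by folding them into the batch dim.
h_local_in = x.view(-1, P, D) + h_global.view(-1, P, D)
logits = to_logits(local_model(h_local_in))  # next-byte predictions per patch
print(logits.shape)                          # torch.Size([256, 4, 256])
```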
MegaByte’s memory complexity depends on both the patch length and the total sequence length: for a sequence of N bytes split into patches of length P, the global model’s attention costs O((N/P)^2), while the N/P local models together cost O(N·P). This means we can tune the patch length for optimal memory complexity: choosing P ≈ N^(1/3) balances the two terms at O(N^(4/3)), or roughly O(n^1.33), which is comfortably sub-quadratic. When it comes to computational complexity, MegaByte is also much better than quadratic-scaling transformers, and it even undercuts a linearly-scaling transformer variant for sequences up to about 1M bytes.
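You can sanity-check that scaling claim with a back-of-the-envelope script (constant factors ignored; only the growth rates matter):

```python
# Attention cost (in arbitrary units) for a sequence of n bytes.
def full_transformer(n: int) -> float:
    return n ** 2                    # quadratic self-attention

def megabyte(n: int) -> float:
    p = round(n ** (1 / 3))          # patch length that balances both terms
    return (n / p) ** 2 + n * p      # global attention + all local attentions

for n in [10_000, 100_000, 1_000_000]:
    print(f"n={n:>9,}  full={full_transformer(n):.1e}  megabyte={megabyte(n):.1e}")
# At n = 1M: 1.0e12 vs. 2.0e8, a ~5,000x gap that keeps widening with n.
```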
In addition to MegaByte’s sub-quadratic scaling, another key benefit of its architecture is that the local model can decode all of the patches in parallel, unlike standard transformers, which generate tokens strictly one after another. To evaluate MegaByte’s performance, the authors were less concerned with absolute performance (i.e., MegaByte vs. GPT-3) and more focused on MegaByte’s learning and training efficiency compared to other models (i.e., controlling for the number of training iterations and the computational training budget).
Compared with competing architectures, MegaByte is pretty good at modeling text, images, and audio (better on some benchmarks, worse on others). But the point of this paper isn’t to blow the other models out of the water. Instead, it shows that it’s viable to push beyond the quadratic-scaling boundaries of transformers, enabling new applications such as letting models analyze entire books, albums of images, or full-length songs.