[The original DeepMind article]

The International Mathematical Olympiad (IMO) is an annual competition where rising stars in mathematics represent their countries in a test of their abilities. One of the contestants in this year’s IMO was a peculiar stateless entrant, backed by the folks at Google DeepMind. The entry consisted of the joint efforts of two AI systems: AlphaProof and AlphaGeometry 2. The system scored a total of 28 out of 42 points, or, put another way, it got perfect scores on the four (out of six) problems it managed to solve, and none for the others. It placed 58th out of 609 contestants, which this year equated to a silver medal.

First, some groundwork: Regular human contestants have just two 4.5-hour blocks to submit solutions, but Google DeepMind’s systems took longer than that. AlphaGeometry 2 solved one problem within minutes, but AlphaProof took *three days* to solve two algebra problems and one number theory problem. The systems weren’t able to solve the other two problems, which covered combinatorics. So, this AI participant wasn’t *really* a contestant, but I don’t think that takes away from Google DeepMind’s achievement.

As its name suggests, AlphaProof proves mathematical statements, and it does this using a combination of two things. The first is a software tool called *Lean*, which verifies whether a proof of a mathematical statement is correct. Proving something in Lean is kind of like finding a path from your house to the grocery store. The proof itself is a series of *moves* or *proof steps*, like “turn left” and “stay on Biscayne.” Lean checks each step to make sure it’s valid, which is important because, depending on the state of the problem, not all moves are valid. The second part of AlphaProof is AlphaZero, the same game-playing AI that Google DeepMind previously used to master chess, shogi, and Go. Within AlphaProof, AlphaZero works much as it did for those games: by searching over possible moves and evaluating which paths show promise.
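To make the “moves” idea concrete, here’s a tiny Lean 4 example of my own (not from DeepMind’s system), where each tactic line is one step that Lean checks:

```lean
-- Each tactic is a "move"; Lean verifies every step before accepting the proof.
example (a b : Nat) : a + b = b + a := by
  rw [Nat.add_comm]  -- rewrite using commutativity; the goal then closes by reflexivity
```

If a step were invalid, say, rewriting with a lemma that doesn’t match the goal, Lean would reject the proof at that exact line.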

But doing mathematical proofs is quite different from playing chess or Go. The researchers needed to train the system on the types of problems that AlphaProof would encounter in the IMO. So they fine-tuned a Gemini LLM to translate natural language problem statements into formal ones. Then, using a dataset of ~100k formal problem statements, the researchers trained AlphaProof to prove or disprove them by searching over possible proof steps (i.e., *moves*) using Lean. When Lean says that the series of steps satisfies the original problem statement, it’s done! Each proof that AlphaProof finds reinforces it, enhancing its ability to solve problems. The figure below shows this process.

AlphaGeometry 2 takes a slightly different approach. The Google DeepMind researchers describe it as a “neuro-symbolic” hybrid system since it uses a language model (based on Gemini) and a symbolic engine. This approach starts by representing the geometric problem as a graph (or network) of geometric symbols (e.g., a point, line, angle, circle, segment, etc.). Then, a language model suggests some next steps (e.g., construct D: midpoint BC). Next, using some primitive geometric relationships like the behavior of parallel lines or midpoints, the symbolic engine searches over possible next steps (kind of like AlphaZero) to see if it can reach the goal of the proof. For instance, in the example below, the language model suggests a construction that helps prove that the triangle is isosceles where, in the right frame, the first blue statement represents the construction, and the subsequent ones represent the subsequent symbolic deductions.

TWIST! What I described just above was actually AlphaGeometry (the original, which was announced in January 2024), not its successor. AlphaGeometry 2 includes three main enhancements that help it solve much more challenging problems:

- It’s trained on an order of magnitude more synthetic data.
- Its symbolic engine is two orders of magnitude faster, allowing it to expand its search for solutions.
- It uses a novel *knowledge sharing* mechanism, which lets the deduction engine share information from disparate steps/constructions and their subsequent deductions, and — if helpful — combine these disparate steps.

When I was studying math in school, I often found myself using WolframAlpha to check my answers or my work. I see Google DeepMind’s system as an extension of this idea, where mathematicians can use tools like this to help them solve and verify solutions to problems. To some extent, this is already happening!

This year, Australian mathematician Terence Tao presented a talk on “machine assisted proofs,” where he said that he suspects it’s ~20x harder to write formal proofs (ones computers can verify) than informal ones (ones that mathematicians write, publish, and peer-review). However, he also said that AI integration could change this, potentially tipping the balance in favor of formal proofs, which would have a dramatic impact on the field of mathematics. If we extrapolate the rate of progress made by AlphaGeometry over this year alone, it seems that the balance is definitely shifting, and potentially quite rapidly.

When I hear people say “AI is going to take all our jobs,” what I think they mean is that LLMs like ChatGPT will automate more and more tasks to the point where many tasks don’t require a human anymore. There’s a bit of hand-waving involved in that inference, but I think it’s pretty fair. But LLMs are just text generators, and most jobs involve more than just pressing keys on a keyboard. So, how can we use LLMs to, say, automate a job that involves manipulating spreadsheets? How would it read the spreadsheet, let alone use it to answer questions about the data? Today’s summary is about a new method that lets an LLM do exactly that.

Let’s pretend that *we* are AI engineers and our job is to make an LLM manipulate a spreadsheet. How might we do this? One way might be to describe each cell using text. Let’s use this approach to describe a very important spreadsheet of mine:

We can describe this spreadsheet like this:

*The following text describes a spreadsheet for tracking foods and their ratings. Here’s the data:*

*Cell A1. Text: “Food”. Formula: None. Formatting: Bold, and centered.*
*Cell A2. Text: “Ice cream”. Formula: None. Formatting: None.*
*Cell A3. Text: “Milk & cookies”. Formula: None. Formatting: None.*
…

We could then ask the LLM to do some work for us, like “Please tell me how to calculate the average rating,” and it might say:

`Cell B6. Text: None. Formula: "=Average(C2:C5)". Formatting: "Numeric, two decimal places"`

There are, however, two main issues with this naive approach and others like it: It’s unnecessarily verbose, and its index-first structure isn’t ideal for an LLM. The next bit of this summary explores some clever techniques — developed by Microsoft researchers — to fix these problems.

When you make a spreadsheet, do you use whitespace/empty cells to delineate particular tables or separate different kinds of info? So do I! But it turns out that this whitespace is really unhelpful for an LLM, since it adds a lot of useless, distracting information to a text-encoded spreadsheet. So the researchers came up with a technique called *structural anchors,* a heuristic-based algorithm that essentially draws boxes around useful information in a spreadsheet. The method then extracts the cells inside these anchors (and a little bit outside the anchors, just in case the structure isn’t perfect), and remaps the addresses so that they make sense without the whitespace.
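Here’s a heavily simplified sketch of the address-remapping part of that idea (the real method detects table boundaries heuristically and keeps a small margin; this version just drops empty rows and columns):

```python
# Drop empty rows/columns and re-address the remaining cells densely,
# so the encoding contains no whitespace-only structure.
def compact(cells):
    """cells: dict mapping (row, col) -> value; gaps represent empty cells."""
    rows = sorted({r for r, _ in cells})
    cols = sorted({c for _, c in cells})
    row_map = {r: i for i, r in enumerate(rows)}
    col_map = {c: i for i, c in enumerate(cols)}
    return {(row_map[r], col_map[c]): v for (r, c), v in cells.items()}

# A sparse sheet where rows 1-4 and column 1 are entirely empty.
sparse = {(0, 0): "Food", (0, 2): "Rating", (5, 0): "Ice cream", (5, 2): 9}
print(compact(sparse))
```

The remapped addresses “make sense without the whitespace” because every empty row and column between the populated cells simply disappears.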

Continuing on the theme of “things about spreadsheets that humans like but LLMs don’t” are the 2d matrix format, our repetition of some values (e.g., “Dessert” in my spreadsheet), and that we sometimes scatter useful bits of info at seemingly random places. The researchers found that LLMs, being the language-lovers that they are, much prefer a dictionary-like format. So, the researchers created an inverse index–based translation method that flips a spreadsheet on its head. It uses the *values* of the cells as the primary keys, not the cell indexes. The dictionary values are lists of cell indexes, so the encoding can easily represent repeated values. My spreadsheet above might look like this:

```
{
  "Food": ["A1"],
  "Category": ["B1"],
  "Rating": ["C1"],
  "Ice cream": ["A2"],
  "Dessert": ["B2", "B3"],
  "9": ["C2"],
  ...
}
```
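A minimal sketch of building that inverse index (the address-to-value input representation is my own; the paper’s encoding carries more detail):

```python
# Inverse index: cell values become keys; repeated values share one entry.
def invert(cells):
    """cells: dict mapping address -> value, e.g. {"B2": "Dessert"}."""
    index = {}
    for address, value in cells.items():
        index.setdefault(str(value), []).append(address)
    return index

cells = {"A1": "Food", "B2": "Dessert", "B3": "Dessert", "C2": 9}
print(invert(cells))
```

Note how “Dessert” appears once as a key with two addresses, so repetition costs almost nothing in tokens.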

The researchers realized one more way they could better represent the spreadsheet for an LLM. To understand this trick, keep in mind that spreadsheet-aware LLMs don’t actually have to do any computing on their own, because they can produce commands that are executed by the spreadsheet software. In other words, they simply ask the spreadsheet to perform the calculations, just as a human would. Because of this, the LLM doesn’t need to know the specific numeric values of the cells — it just needs to know what format they are (for example, an integer). So the encoding can represent numeric cells — like integers, floating point numbers, percentages, dates, etc. — using some text that describes their format. In my example above, the encoding represents the Rating values in the dictionary format like this: `"IntNum: C2:C5"`.
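A toy sketch of that aggregation for a single column (the `IntNum` token follows the example above; real spreadsheets would need fancier range detection than this contiguous-run check):

```python
# Replace a contiguous run of integer cells in one column with a single
# "IntNum: <start>:<end>" type descriptor.
def aggregate_int_column(col, rows_to_values):
    """rows_to_values: dict mapping row number -> cell value."""
    rows = sorted(rows_to_values)
    if rows and rows == list(range(rows[0], rows[-1] + 1)) \
            and all(isinstance(v, int) for v in rows_to_values.values()):
        return f"IntNum: {col}{rows[0]}:{col}{rows[-1]}"
    return None  # not a contiguous integer run; keep the cells as-is

print(aggregate_int_column("C", {2: 9, 3: 8, 4: 10, 5: 7}))
```

Four numeric cells collapse into one short descriptor, which is where the token savings come from.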

Overall, these three tricks reduce the number of tokens needed to represent the spreadsheet by 25x compared to one of the naive encoding methods the researchers considered. They tested their approach on a Spreadsheet QA task, and found that regardless of which base LLM they used (e.g., Llama 3, Mistral 3, GPT-4, etc.), LLMs that use these techniques equaled or outperformed the existing spreadsheet-analyzing technique, called TableSense-CNN. The GPT-4 model had the best F1 score on this benchmark, on average scoring 9 percentage points higher (76%) than TableSense-CNN (67%).

The researchers also conducted ablation experiments, individually excluding one of their three techniques (anchoring, inverse-index/dictionary encoding, and data type aggregation). While the first two techniques tended to improve the F1 scores with GPT-4, surprisingly, the last technique actually made the F1 scores slightly *worse;* the best score achieved on their benchmark was 79%, using GPT-4 without aggregation. The researchers hypothesize that this might be because the data types are a bit too abstract for the LLM. Nonetheless, they suggest that the aggregation could be necessary for some models that have a limited context length, since the aggregation reduces the number of tokens the model needs to represent the spreadsheet significantly.

I think it’s remarkable that an LLM can work with spreadsheets so well considering that spreadsheets are fundamentally designed for humans, not computers. It seems to me that a spreadsheet is really an unsuitable tool for an LLM to use for solving problems. Still, the techniques presented in this paper could be really helpful to humans when we use spreadsheet software. For example, we could use it to ask questions like “How can I forecast next month’s expenses?” or “When will we break even?” Given that this paper comes from Microsoft researchers, I’m sure that such features are coming to Excel soon!

In a 2007 essay called *The Origin of Circuits*, Alan Bellows tells a fascinating story about an experiment that Dr. Adrian Thompson conducted in the 1990s. Thompson took a 10 x 10 array of logic cells on an FPGA (a reconfigurable chip) and tried to see if he could evolve a program encoded by these gates to reliably distinguish between signals of two different audio frequencies. For sophisticated hardware of the time (like dedicated signal processors), this was a trivial task, but Thompson found that even this tiny array could evolve logic configurations that reliably detected the signals.

Aside from this main result, there are two other remarkable things about Thompson’s experiment. First, the logic configuration was updated iteratively in an evolutionary manner that’s quite similar to evolutionary machine learning algorithms. The other is that, upon investigating the most successful configuration, Thompson noticed a section of the array that was logically disconnected from the array’s output yet, without it, the array couldn’t reliably classify the signals. This means that the disconnected logic section was influencing the classification through some mechanism other than digital logic, and that the evolutionary algorithm seemed to account for the effects of this mechanism as it updated the program. (It turns out some of the logically-disconnected gates were influencing the voltage of other nearby gates via magnetic flux.) I highly recommend giving Bellows’s essay a read if you haven’t before.

With three decades of hindsight, we can see that Thompson’s array of logic gates was an example of a *physical neural network*, or PNN. PNNs are neural-like networks that aren’t built from silicon chips (though they could be) but instead from components that harness other physical phenomena, like light or sound. In a sense, PNNs offer an alternative paradigm of machine learning, one which isn’t necessarily constrained by the limitations of digital logic. That is to say, PNNs can let us harness various physical phenomena to solve problems with machine learning. Today’s summary explores training PNNs, i.e., the different ways that PNNs’ parameters can be updated to solve particular problems.

Before exploring how PNNs are trained, let me quickly describe the concept of *backpropagation* (BP), the workhorse of traditional, digital neural network training. When an NN makes a prediction (sometimes called a *forward pass*) and we know what that prediction should have been, we can calculate the error in the network’s prediction. We can then propagate the error backwards through the network (sometimes called a *backward pass*), updating the network’s parameters so that it’s more likely to give a more accurate answer the next time it runs on similar inputs.
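Here’s the whole idea in miniature, using a one-parameter “network” (all values are toy numbers):

```python
# Minimal numeric illustration of backprop: prediction y = w * x,
# squared-error loss, gradient descent on w.
w, x, target, lr = 0.5, 2.0, 3.0, 0.1
for _ in range(20):
    y = w * x                      # forward pass
    grad_w = 2 * (y - target) * x  # backward pass: dLoss/dw
    w -= lr * grad_w               # update step
print(w, (w * x - target) ** 2)
```

Each iteration runs a forward pass, computes the gradient of the loss with respect to `w`, and nudges `w` downhill; after a few steps `w` approaches 1.5, the value that makes the prediction match the target.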

One way of training a PNN, called *in silico training*, mirrors BP quite closely. In silico training involves digitally simulating and optimizing physical parameters (θ) using a digital twin, which is an emulation of the physical hardware within a computer environment. Similar to BP in traditional neural networks, in silico training uses these digital models to compute gradients and update weights, which you can then apply to the physical system through some other process. This approach benefits from the rapid, cost-effective iteration and testing of PNN architectures, but it might not work well if the digital twin isn’t perfect. This means that the entire process is simulated, and the physical model is only given the learned weights at the end.

Another training approach, called *Physics-aware BP training* (PAT), is a hybrid of in situ (meaning not simulated) and in silico methods. In PAT, the physical system handles the forward pass, while the backward pass is performed by differentiating a digital model that is an approximation of the physical system. This means the info for the forward pass will still be precise while maintaining the versatility of performing the backward pass on a computer. You still need an accurate digital twin to effectively model the backward pass, and the larger and more complicated the PNN is, the harder it is to make an accurate model of it.

Both of the methods described above are, in a sense, cheating, since they aim to train PNNs using conventional digital NN techniques. But there’s good reason for this, since BP has been shown to be much more effective than other techniques for training digital NNs. But, as we’ve seen with in silico training and PAT, it can be hard to accurately back propagate error signals through a PNN. Are there any other ways? Here are two:

**Feedback alignment (FA)** is an alternative to BP whereby some of the terms in BP’s weight-update rule are replaced by fixed random matrices. This means that, unlike BP, we don’t need to know exactly what the weights were in the forward pass to know how to update them. Surprisingly, the network still learns, because the forward weights tend to adapt to align with the random feedback.

**Physical local learning** uses a concept called *local learning* to train the weights in each block or layer independently (i.e., without any BP). There are lots of different ways that local learning could be achieved, but they typically all try to define some objective function using the layer’s activations, one that indicates whether the activations are doing something useful, like compression or providing useful information for the next layer/block. Geoffrey Hinton’s forward-forward technique is one example of local learning, and one study has already used it to train an optical NN with a contrastive-based approach.
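As a toy illustration of feedback alignment, here’s a tiny two-layer linear network (all numbers made up) where the backward pass for the first layer uses a fixed random vector instead of the output weights:

```python
# Feedback alignment sketch: the backward pass uses a fixed random vector B
# in place of the transpose of the forward weights W2.
import random
random.seed(0)

x = [1.0, 0.5]                     # single toy training input
t = 1.0                            # toy target
W1 = [[0.1, 0.0], [0.0, 0.1]]      # hidden-layer weights
W2 = [0.1, 0.1]                    # output-layer weights
B = [random.uniform(-1, 1), random.uniform(-1, 1)]  # fixed random feedback
lr = 0.05

def forward():
    h = [sum(W1[i][j] * x[j] for j in range(2)) for i in range(2)]
    y = sum(W2[i] * h[i] for i in range(2))
    return h, y

_, y = forward()
loss_start = (y - t) ** 2
for _ in range(200):
    h, y = forward()
    e = y - t
    for i in range(2):             # exact gradient for the output weights
        W2[i] -= lr * e * h[i]
    for i in range(2):             # FA step: B replaces W2 in the backward pass
        for j in range(2):
            W1[i][j] -= lr * e * B[i] * x[j]
_, y = forward()
loss_end = (y - t) ** 2
print(loss_start, loss_end)
```

Even though `B` has nothing to do with `W2`, the loss still drops on this toy problem, which is the surprising core of FA.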

These methods aren’t quite as effective as BP, though, so other studies try to reproduce BP without the need for a digital twin; they essentially encode the BP algorithm directly into the physical system. For this to work, the system needs to utilize a physical process that is linear and reciprocal, i.e., one where signals propagate through the system the same way in both directions. Two examples are waves propagating through a linear medium in a photonic system, and a peculiar electrical device called a memristor crossbar array. There are also other techniques for in situ training, such as one called *continual learning* that updates the model’s parameters as it’s used.

At the end of their review, the authors arrived at three qualities that would be great for a PNN to have, although nothing meets all three (yet):

- They don’t depend on the model used.
- They give a speed or efficiency advantage over regular NNs.
- They are more resilient to noise.

But a PNN doesn’t need to have all three of these qualities to be useful. It just means a bit more effort might need to go into developing them, since we can’t yet say for certain things like, “Oh, this particular learning algorithm works best for this kind of PNN or this kind of application.” Given the current pace of AI developments, the possibility of realizing all three at once could be on the horizon, which could open the doors to an entirely new domain of AI.

Last week we discussed a new method for resolving the 3d structure of a scene from two perspectives or photos of it. I mentioned how our brains can do this too, using the images from each of our eyes to perceive things in 3d. But what if you close one eye; do you lose your ability to see in 3d? The answer is a resounding “No!” Our brains can still perceive a lot of 3d info from a single, unmoving image. In fact, we do this all the time when we look at photos and use things like occlusions and shadows to infer depth and scale. This is why we find optical illusions like the Penrose Triangle perplexing.

If our brains can perceive the depth of objects using information in an image, then computers probably can, too. This process, known as monocular depth estimation (MDE), has been an active area of machine learning research for some time now. Before discussing that research, let’s learn a bit more about MDE. The figure below shows relative depth estimates on some images predicted by a model called Depth Anything V2. The redness/blueness of the results indicate parts of the image that are closest/farthest from the camera, respectively.

Depth Anything V2 is an impressive model. But the story of its success isn’t just “more data” or “bigger model.” To appreciate it, we first need to review how we got here:

- MiDaS is an impressive MDE method that broke onto the scene in 2020. It predicts the relative depth of pixels in an image, and was trained using supervised learning on a dataset of over a million images with depth labels. When it was released, it was the state of the art for MDE.
- The MiDaS team made incremental improvements, and its third iteration was the state of the art until this year. Despite its success, it struggles to predict depth on images that are different from its training data (that is, zero-shot prediction).
- In January, the researchers behind Depth Anything V2 released Depth Anything (V1). While this model introduced several innovations to help zero-shot generalization (like augmentation and a special loss term), its real innovation — and the one that ultimately helped improve its zero-shot generalization — was its use of *unlabeled* training images. (I’ll explain how this is possible shortly!)
- Finally, we now have Depth Anything V2, a short six months after V1. In terms of data, V2 took a drastic approach, ditching labeled data entirely for synthetic data! As we’ll see, it improves upon V1 in a number of ways, including fine-grained details, accuracy, and its ability to not be fooled by confusing surfaces like windows and mirrors.

From this MDE timeline, it’s clear that training data has played a pivotal role in MDE iteration. If you think about this for a moment, it kind of makes sense. The depth “labels” — which are generated from a number of depth-sensing sources like RGBD cameras (yep, the “D” is for “depth”) or LIDAR — can have a lot of issues, like not being as high-resolution as their corresponding image, or they might be noisy or just not very accurate. So MDE models trained on this data will be fundamentally limited by it.

In Depth Anything V1, the researchers came up with a way to use unlabeled images (i.e., regular images) as training data. To do this, they first trained the best model they could on the labeled dataset, then used this model to predict the depth on unlabeled images, and used these predictions as *pseudo labels*. This is called a student-teacher approach, where the big model trained on the labeled data is the teacher, and its knowledge (i.e., predictions) are used to teach the student model, which is nice to have because it’s much smaller (and so more convenient to work with) than the teacher model. The reason I’m taking time to specifically mention this aspect of V1 (and not any of the other clever tricks they invented) is because it’s crucial for V2. The figure below shows a rough schematic of the student-teacher process, where the solid lines indicate the flow of labeled data/images, and dashed lines represent the flow for unlabeled ones. (The “semantic preservation” is one of the tricks I mentioned, and it prevents the encoder from varying too much from when it was trained on labeled data).
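The student-teacher idea can be sketched like this, with stand-in models (a fixed function as the “teacher” and a single scalar parameter as the “student”; the real models are, of course, large neural networks):

```python
# Pseudo-labeling sketch: the teacher labels unlabeled inputs, and the
# student trains on those predictions as if they were real labels.
def teacher(x):          # stand-in for the big model trained on labeled data
    return 0.5 * x

student_w = 0.0          # stand-in student: predicts student_w * x
unlabeled = [1.0, 2.0, 3.0, 4.0]
lr = 0.05
for _ in range(100):
    for x in unlabeled:
        pseudo_label = teacher(x)          # the teacher supplies the "label"
        error = student_w * x - pseudo_label
        student_w -= lr * 2 * error * x    # ordinary supervised update
print(student_w)
```

The student never sees a real label, only the teacher’s predictions, yet it converges to the teacher’s behavior on the unlabeled data.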

So, Depth Anything V1 used unlabeled data to address the zero-shot generalization problem of MiDaS. V2 goes further to improve the accuracy and robustness of the predictions by extensively using synthetic data, which you can think of as renderings of 3d scenes — like from a video game — where the depth can be measured exactly using the information in the 3d model. While it sounds insane, the choice to completely ditch labeled data actually makes a lot of sense. If the labeled data isn’t accurate enough, just replace it with highly accurate synthetic data, right? Well, there are two very good reasons *not* to do this: synthetic data is typically quite different from real imagery (a large domain shift), and it’s far less diverse. But combining synthetic data with both the student-teacher approach and unlabeled real images alleviated these two problems and yielded a model that’s both accurate and has good zero-shot generalization.

I think the best way to demonstrate why using only synthetic data is so helpful is to take another look at the pictures of the bridge and the room above. In the bridge image, there's a lot of fine-grained detail that a depth sensor might not be able to capture. And, in the room image, notice how the depth indicates the window, not the objects that you can see through the window. This is the kind of depth detail that’s really difficult to accurately capture in the real world.

“But Adrian!” I hear you screaming, “Surely there must be *some* value in all that labeled data.” Well, the Depth Anything V2 authors thought that too. So, when training the student with the pseudo labels on unlabeled images, they tried including a little bit of labeled data — the highest-quality labeled data they had. But they found that mixing in just 5% labeled data (keeping 95% synthetic) seemed to harm performance, particularly for fine-grained details. You might need to squint, but the figure below shows that the model that used synthetic data only (middle column) is definitely superior.

I really like the authors’ approach in the Depth Anything V2 paper, and I think their model’s results are quite impressive. It’s a very “outside-the-box” idea, and I’m curious to see if researchers in other domains can use similar ideas to make other AI breakthroughs. I highly recommend taking a look at the paper’s webpage to see more examples of what they accomplished.

I’ll leave you with one more Depth Anything V2 example below, one that yet again shows how impressive it is while also being a bit of a contradiction to the authors' claims of robustness. Specifically, they claim that the model is quite robust to domain shift (which it certainly is), but they use these depth-prediction results on drawings and paintings (shown below), among others, as an exemplar of this idea. But what *is* the correct result in this case? Should it be a realistic depth (as Depth Anything V2 predicts) or just the depth of a planar surface, since the paper or canvas is presumably flat? I think it should be flat, since this would be more in line with predicting windows instead of what’s shown through the window. What do you think?

Your brain is amazing. Whenever you open your eyes and look around, you experience your surroundings in 3d despite only having two 2d views of it, one from each eye. This 3d environment that you perceive is so good that, without much effort at all, you can accurately judge things like how hard and in what direction you need to throw a ball so that it gets to a specific person, or how far it is between your car and a red traffic light in the distance.

To give you an idea of why it’s amazing that your brain can do this, let’s quickly break down the process of *multi-view stereo* (MVS) reconstruction, which is computer-speak for “How do I turn the two flat images from these two cameras into a 3d model of what they saw?” First, a computer needs to identify and match the same parts of the images. Then, using information about where each image was taken and other details like the parameters of the camera’s lens, along with some complicated mathematics, it can reconstruct where in 3d-space each of those parts of the image must have been. Believe it or not, that’s actually an *oversimplification*. The figure below, taken from the paper of a popular MVS technique called COLMAP, shows a breakdown of an actual MVS pipeline.

The approach taken by COLMAP and other MVS techniques — that is, using mathematics and algorithms — seems perfectly sensible to me, and it works quite well. But is this what our brains are doing when they see the world? Maybe. Or maybe our brains operate more like a new machine learning-based approach called DUSt3R.

Like COLMAP, DUSt3R turns images into a 3d point cloud, but in a completely different way. Here’s how it works: First, the model extracts small patches from two images, and then separately encodes them using the same Vision Transformer (ViT) encoder. Then, two ViT decoders share information about these patches via cross-attention to generate one feature vector for each patch in each image. Finally, a “head” (a fully-connected layer) predicts the 3d positions {x, y, z} of each pixel in each image, as well as a confidence value that indicates how confident the network is about each pixel’s prediction. Importantly, the 3d positions predicted by the head for the second image are in *the same* coordinate space as the one from the first image.

The authors used supervised learning to train their model, which means they needed the corresponding 3d locations of pixels in image pairs of the same scene. Their training dataset consisted of about 8 million examples of this kind of data, and contained both indoor and outdoor images, as well as images of objects. Then, to optimize the model’s parameters, they used a regression loss, which is just the average error, or distance, of where DUSt3R thinks the pixels are versus where they actually were. These errors were each scaled by the confidence value that DUSt3R predicts, which is helpful because sometimes it’s really hard to know the exact location of particular pixels, like ones in the sky or in reflections.
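In code, a confidence-weighted regression loss of this flavor might look like the following. The log-confidence penalty, which stops the model from simply driving every confidence to zero, is the usual construction for such losses; the exact form and the `alpha` value here are my assumptions, not necessarily the paper’s:

```python
import math

# Confidence-weighted regression: per-pixel 3d error scaled by the predicted
# confidence, plus a penalty that discourages uniformly low confidence.
def conf_weighted_loss(preds, gts, confs, alpha=0.2):
    total = 0.0
    for p, g, c in zip(preds, gts, confs):
        err = math.dist(p, g)            # distance between predicted and true 3d point
        total += c * err - alpha * math.log(c)
    return total / len(preds)

preds = [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]  # predicted 3d points (toy values)
gts   = [[0.0, 0.0, 1.5], [1.0, 1.0, 1.0]]  # "ground-truth" 3d points
confs = [0.5, 1.0]                          # low confidence on the first pixel
print(conf_weighted_loss(preds, gts, confs))
```

A hard pixel (like one in the sky or a reflection) can be given low confidence, shrinking its error term at the cost of a small log penalty.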

Compared to about a dozen other MVS methods (some neural network-based, others more traditional), DUSt3R performed the best in terms of absolute relative error. But the more impressive result is DUSt3R’s zero-shot prediction accuracy (its accuracy on datasets it wasn’t trained on), where it was almost as good or sometimes even *better* than non-zero-shot neural approaches. Note that the traditional approaches should also be considered zero-shot, since they weren’t designed for any specific dataset — but DUSt3R still seems to outperform these approaches more often than not. And remember: DUSt3R doesn’t need information about the cameras’ poses or their intrinsic parameters either!

The DUSt3R researchers recently followed up their method with an extension they call MASt3R. MASt3R improves on DUSt3R’s approach by emphasizing pixel matching: matching each pixel in one input image with the pixel of the same point (in the 3d scene) in the other input image. Here’s an example of pixel matching between two input images:

The figure below shows MASt3R’s architecture, which adds an additional head onto the ViT’s decoder. From the image patches, this head generates a vector of features for each pixel in each image; this new vector-per-pixel data provides the additional info MASt3R uses to match pixels across the input images. And MASt3R’s loss function is the same as DUSt3R’s confidence-weighted regression loss, but with an additional loss term that penalizes the model for every pixel that it incorrectly matches.

Without going too deep into the details of MASt3R, this extension (pixelwise matching) adds a lot of additional complexity (efficient pixel-matching is non-obvious), but the authors introduce clever algorithms for solving this and other related problems. All this effort is worth it, though, since MASt3R is both more accurate and more robust to viewpoint and illumination changes than DUSt3R. Also, aside from the main purpose of predicting point clouds, these models’ results can be used for things like camera calibration, inferring the cameras’ pose, depth estimation, and dense 3d reconstruction.

Being neural network-based, DUSt3R and MASt3R share afflictions similar to neural networks in other domains. Despite their impressive zero-shot performance, these approaches might need to be retrained to work effectively in contexts that differ substantially from their training data, such as in underwater or aerial imagery, or imagery with very wide or very long lenses. Traditional MVS approaches would be more robust to these sorts of changes, provided their models can be adjusted for these settings.

In either case, there’s no one-model-fits-all approach to MVS, much like how our brains are particularly well adapted to MVS from our two eyes, but would need to be “retrained” if our vision was suddenly inverted or if the shape and composition of our eyes suddenly changed in some way. This paper hasn’t brought us any closer to understanding how our brains do MVS, but it has shown that there’s more than one process that can achieve it. Maybe our brains work like DUSt3R, or maybe they have their own method that’s still a mystery.

LLMs like GPT-4 and Gemini act like they know everything. They do know a lot of stuff, but certainly not everything. LLMs also speak confidently, rarely saying things like “I don’t know the answer to your question” or “I’m not sure.” This leads to a widely known problem: LLMs sometimes hallucinate or confabulate information to fill in gaps in their knowledge. This is frustrating and it severely limits their ability to be used in real-world situations, since it’s really hard to fact-check everything an LLM says to make sure it isn’t just making stuff up. But a new paper published in *Nature* proposes a statistical way to detect hallucinations.

The key piece of information that makes confabulation-checking possible is that, when an LLM generates text, it also produces a probability for each token it generates, given the tokens that came before. Using these token probabilities, there’s actually a simple approach to detecting hallucinations. For example, we could use the following process to detect whether an LLM’s answer to the question “Where is the Eiffel Tower?” is a hallucination:

1. Ask the LLM to generate many different answers to this question.
2. Aggregate the probabilities of the tokens within each answer to get a probability for the whole answer.
3. Combine the probabilities over the answers into a single value, called *predictive entropy*, which is the conditional entropy of possible answers to the question.

When the predictive entropy is low, it means that the distribution of answers is heavily concentrated on a small number of answers. When it’s high, it means many answers are roughly equally likely.
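As a toy illustration of those three steps (the token log-probabilities below are made up, and real implementations often length-normalize sequence log-probabilities, a detail omitted here):

```python
def sequence_logprob(token_logprobs):
    # Step 2: aggregate the token probabilities within one answer
    # (log-probabilities add, i.e., probabilities multiply).
    return sum(token_logprobs)

def predictive_entropy(sampled_answers):
    # Step 3: a Monte Carlo estimate of the entropy over answers,
    # -E[log p(answer | question)], averaged over the sampled answers.
    logps = [sequence_logprob(a) for a in sampled_answers]
    return -sum(logps) / len(logps)

# Hypothetical token log-probs for three sampled answers to
# "Where is the Eiffel Tower?"
samples = [
    [-0.1, -0.2],        # e.g., "Paris"
    [-0.1, -0.3],        # e.g., "It's Paris"
    [-2.0, -1.5, -0.5],  # e.g., "Rome"
]
print(round(predictive_entropy(samples), 3))  # → 1.567
```

Here two of the three answers are high-probability, so the estimated entropy stays modest; if all three answers were equally unlikely, it would be higher.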

But there’s a big problem with this approach, as you can see in the figure below. There are many different ways an LLM can answer a question, and sometimes the aggregated probabilities of correct answers are lower than those of incorrect answers. In the figure, for example, “France’s capital Paris” has a lower probability than the hallucination “Rome.” This can result in a misleadingly high naive predictive entropy, even when the space of possible answers is in fact skewed toward a small number of answers that, in this example, are *not* hallucinations.

You can probably see the issue with naive entropy in this example: Even though the answer “France’s capital Paris” is correct, it’s just not the way one would typically answer the question. The authors’ solution to this problem is to calculate the entropy across the semantic categories of the LLM’s responses rather than the responses themselves. Semantic equivalence is a relation that holds when two sentences mean the same thing — this idea can be extended to group any number of outputs from the model into these categories. There are two things we need in order to do this: First, a way to find these semantic categories and know which sequence belongs to which category, and second, a new way to combine the sequence-probability values into a final semantic entropy.

To solve the first problem (grouping together similar answers), the authors use special LLMs that determine whether two sentences are semantically equivalent (i.e., mean the same thing). These LLMs can be specialized for this task, such as DeBERTa-Large-MNLI, or general-purpose LLMs like GPT-3.5 that can predict when one bit of text implies another, given suitable prompts.

Once this model has determined which answers belong to the same semantic group, their sequence probabilities need to be combined. Mathematically, this involves summing the probabilities of all the sequences in a given group. Then, the probabilities of all the semantic groups are combined into a value that the authors call *semantic entropy*, similar to how individual sequence probabilities are combined in naive entropy. The authors also have another way to calculate semantic entropy when they can’t access the sequence probabilities: They assume each sequence has equal probability, and thus each semantic cluster’s probability is proportional to how many sequences were generated in that cluster. They call this value *discrete semantic entropy*. This approach makes sense if you have many sample points, because you expect to see repeats of answers in proportion to how likely they are. (But, if all the probabilities are very small, then you only ever see each answer once, and the regular semantic entropy value is clearly more accurate.)
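Here’s a simplified sketch of both quantities. The `same_meaning` function stands in for the semantic-equivalence model described above, and the greedy clustering and the normalization over sampled answers are my own simplifications rather than the paper’s exact estimator:

```python
import math

def cluster(samples, same_meaning):
    """Greedily group sampled answers into semantic clusters.
    samples: list of (answer_text, sequence_logprob) pairs."""
    reps, clusters = [], []
    for text, logp in samples:
        for i, rep in enumerate(reps):
            if same_meaning(text, rep):
                clusters[i].append(logp)
                break
        else:  # no existing cluster matched: start a new one
            reps.append(text)
            clusters.append([logp])
    return clusters

def semantic_entropy(samples, same_meaning):
    # A cluster's probability is the sum of its members' sequence
    # probabilities; normalize over the samples, then take the entropy.
    masses = [sum(math.exp(lp) for lp in c)
              for c in cluster(samples, same_meaning)]
    total = sum(masses)
    return -sum(m / total * math.log(m / total) for m in masses)

def discrete_semantic_entropy(samples, same_meaning):
    # Variant for when sequence probabilities are unavailable: each
    # cluster's probability is proportional to its member count.
    sizes = [len(c) for c in cluster(samples, same_meaning)]
    n = sum(sizes)
    return -sum(s / n * math.log(s / n) for s in sizes)

# Toy equivalence check standing in for the NLI model.
same = lambda a, b: ("Paris" in a) == ("Paris" in b)
samples = [("Paris", math.log(0.4)),
           ("France's capital Paris", math.log(0.3)),
           ("Rome", math.log(0.3))]
print(semantic_entropy(samples, same))           # clusters: 0.7 vs. 0.3
print(discrete_semantic_entropy(samples, same))  # cluster sizes: 2 vs. 1
```

Note how the two Paris answers collapse into one cluster with combined probability 0.7, so the semantic entropy is low even though no single Paris phrasing dominates.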

To use the semantic entropy approach, an LLM generates many answers to a question, and the authors choose an answer from the highest-probability semantic cluster as the final answer. The figure below compares their two methods: semantic entropy (blue) against naive entropy (green) and other baselines (red and yellow) on QA and math tasks. The *AUROC* metric captures how well the model can distinguish between hallucinations and non-hallucinations, where a value of 0.5 indicates the result from random chance. The *AURAC* metric, which the authors invented, tracks (but isn’t identical to) the accuracy of answers given by a model which can say “I don’t know” when it notices itself hallucinating. In other words, the AUROC score measures how good a system is at *noticing* hallucinations, and the AURAC score measures how good the system is *after filtering out* hallucinations. In both cases, higher is better, and the two scores may be quite different because the AURAC score depends heavily on the accuracy of the underlying model, such as LLaMA 2 or Falcon.
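For intuition about the AUROC metric: it can be computed as a rank statistic, the probability that a randomly chosen hallucination receives a higher uncertainty score than a randomly chosen non-hallucination. A minimal sketch with made-up scores:

```python
def auroc(scores, labels):
    # Probability that a random positive (label 1, a hallucination)
    # scores higher than a random negative (label 0); ties count as half.
    pos = [s for s, l in zip(scores, labels) if l == 1]
    neg = [s for s, l in zip(scores, labels) if l == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Perfect separation gives 1.0; random scoring hovers around 0.5.
print(auroc([0.9, 0.8, 0.3, 0.2], [1, 1, 0, 0]))  # → 1.0
```

So a detector with AUROC 0.8 ranks a hallucination above a correct answer 80% of the time, regardless of where you place the “refuse to answer” threshold.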

Overall, I think the semantic entropy approach is fantastic. It uses information and models that are already available (token probabilities and semantic equivalence models) and clever statistical tricks to detect confabulations, without having to train specialized models like the baselines do. In fact, this points to something a bit deeper: The red-colored baseline in the figure above is from a paper titled “Language Models (Mostly) Know What They Know.” The title indicates that LLMs actually have some awareness of when they’re making stuff up. The result from this paper adds more weight to that idea. To me, this suggests that hallucinations and confabulations are a consequence of the way we’re training LLMs. Maybe we just need to find a way to train LLMs to know when to say, “Sorry, I don’t know.”

Black holes are extreme objects that demonstrate some of the complex and bizarre ways that the universe works. For example, imagine you were observing an orbiting planet transit across the far side of a black hole. At the moment that you, the black hole, and the planet were in perfect alignment, the planet would look to you like a squished ring around the black hole. This is because of a strange phenomenon called *gravitational lensing*, which happens because the black hole’s gravitational field distorts spacetime; i.e., light no longer travels in a straight line! This distortion can be quite counterintuitive because the mental model of spacetime that we use on a daily basis (that light travels in straight lines) is a simplification of the more complicated reality that’s better described by Einstein’s general relativity.

As someone who thinks about machine learning frequently, I find that my brain also constructs similarly helpful-yet-simplified models of machine learning-related concepts, like transformers. For example, when I see the phrase “transformers are next-token predictors,” my brain often reads this as “transformers predict the next word in a sentence, kind of like how a person might anticipate a sentence’s continuation.” Unsurprisingly, mental models like these, while helpful, are imprecise. But what *is* perhaps a surprise is that, mathematically, this particular mental model is imprecise in a very similar way to how our mental model of spacetime doesn’t reflect the reality described by general relativity. 🤯

Researchers from Google DeepMind have investigated how information propagates in decoder-only transformers. They found that transformers definitely *don’t* continue sentences the way a person might, because of a subtle detail of how transformers work that my simplified mental model (the finish-a-sentence model) doesn’t capture.

Here’s a simple illustration of the researchers’ experiments. Imagine asking a transformer what it “thinks” after showing it a sequence of digits. The sequence contains only ones, except for the last digit, which is zero. Here’s what they found: As the number of ones increases, the transformer’s last-token representation (i.e., what it “thinks”) converges. In other words, the representations of the strings 10 and 110 are very different, but the representations of strings with 1,000 ones and 1,001 ones are basically the same. The figure below shows this phenomenon, where the color and proximity of the curved lines illustrate how these representations converge as the sequence length increases.
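You can reproduce the flavor of this convergence with a toy model (my own illustration, not the paper’s setup): a single layer of uniform attention, where the last token’s representation is simply the mean of the embeddings of all the tokens it attends to.

```python
import numpy as np

# Toy embeddings for the two digit tokens (an arbitrary choice).
emb = {"0": np.array([1.0, 0.0]), "1": np.array([0.0, 1.0])}

def last_repr(seq):
    # Uniform attention: the last token's representation is the mean of
    # every token embedding in the sequence.
    return np.mean([emb[c] for c in seq], axis=0)

# Short sequences differ a lot; long ones are nearly indistinguishable.
d_short = np.linalg.norm(last_repr("10") - last_repr("110"))
d_long = np.linalg.norm(last_repr("1" * 1000 + "0") - last_repr("1" * 1001 + "0"))
print(f"short: {d_short:.4f}   long: {d_long:.8f}")
```

The lone zero’s contribution shrinks like 1/n as the ones pile up, so the representations of the long sequences become nearly indistinguishable; that shrinking contribution is the essence of squashing.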

Why does this happen? At a high level, it’s because, as the sequence length increases, there are exponentially more paths for the information of earlier tokens in the sequence to propagate to the last token’s representation than there are for the last token’s own information to propagate there. This phenomenon is called *squashing* and, when hardly any of the information from the last token makes it to the final layer, it’s said to have been “over-squashed.” The figure below shows how this happens: The red token has fewer paths to travel to the final token than the blue token does, which demonstrates that earlier tokens tend to squash later ones.

This phenomenon can be formalized mathematically by analyzing how sensitive the output of the last token’s representation is to changes in each of its input tokens. When we do this, there are three possible cases:

1. The output might be equally sensitive to changes in early and late tokens,
2. The output might be more sensitive to early tokens and less sensitive to later tokens, or
3. The output might be less sensitive to early tokens and more sensitive to later tokens.

Based on the examples we’ve looked at, can you tell which of the above three cases the transformer exhibits? If you said the second case, you’re correct! That is, the output is *more* sensitive to changes in earlier tokens and *less* sensitive to changes in later tokens. (If you’re following along with the paper, this is stated in Theorem 5.1 and proved in Appendix Theorem B.5.)

The reason the transformer exhibits this behavior has to do with the number of paths between the input tokens and the output representation. This idea of studying paths between two points is related to a mathematical concept called *curvature*. You can think of the general concept as categorizing the local shape of a space as either spherical (nearby parallel lines converge), Euclidean (nearby parallel lines remain parallel), or hyperbolic (nearby parallel lines diverge). Intuitively, this concept gives you a sense of whether the space around a point is expanding or contracting. One specific kind of curvature called Gaussian curvature describes how curved a 2d manifold (a curved surface) is at a particular point — its value, shown in parentheses below, indicates the shape at the points shown by the black dots.

(Fun fact: Ricci curvature — another measure of the curvature of manifolds — also appears in Einstein’s field equations, where it corresponds to the curvature of spacetime due to energy and momentum! I’m not going to pretend like I understand these equations, but if you’d like to learn more you can read about the connection between neural networks and spacetime in this awesome blog post.)

When we talk about graphs — like the paths in the transformer — we’re talking about another specific kind of curvature called the *balanced Forman curvature* (as defined in this paper). This concept is another way to mathematically capture the behavior of information flow that biases the last token toward receiving more information from earlier tokens and less from later tokens. In fact, the over-squashing phenomenon was first observed in another kind of neural network architecture called graph neural networks, where over-squashing can occur between nodes in a graph that have a small number of paths between them, like in the tree depicted below.

Theory aside, what are the consequences of this analysis? A practical consequence is that a sequence can get so long that it’s literally *impossible* for later tokens’ information to propagate to the last token’s output, since the information has been over-squashed to the point that the precision of floating-point numbers can’t capture it anymore. It also means that — and this is my favorite takeaway from the paper — transformers are really bad at counting. As a transformer counts more numbers, it will eventually forget what the most recently counted numbers were, and they’ll essentially all blur together to the point that it doesn’t know what number should come next! 😂

While this is fascinating, the above discussion is a bit of a simplification of reality since it ignores attention weights. These weights *also* influence the curvature of information propagation and whether over-squashing will occur. In practice, transformers might learn to pay more attention to the tokens that appear later in the sequence.

The researchers also point to a simple way to reduce over-squashing: insert additional tokens, like a comma every three digits in the ones-and-zero example I described. This strategy helps to keep the representations more distant. More permanent solutions might involve modifying the transformer’s architecture, like how this paper adds more edges to the GNN where the Ricci curvature is the most negative. In practice, you probably don’t need to understand differential geometry (I certainly don’t!) in order to use transformers for something practical, but this paper is a helpful reminder that sometimes things aren’t as simple as our brains might pretend they are.

If you’ve ever used Google’s Gemini model, you might have noticed that sometimes it shows you the model’s “drafts,” which make it clear that Gemini isn’t a simple language-in, language-out LLM. It has a more complicated inference procedure that involves drafting before generating its final output. Processes like this drafting trick are a common way to improve an LLM’s accuracy and utility, but sometimes these approaches don’t improve the accuracy enough, or they make the model substantially less efficient. But a new approach called buffer of thoughts (BoT) is an accurate and efficient way to augment an LLM’s output that mimics the way people retrieve and refine their thoughts.

You might be familiar with some of the simpler augmentation processes, such as few-shot prompting or chain-of-thought (CoT) prompting. These approaches use a single query to prompt the LLM; sometimes this makes the LLM more accurate, but it’s not a silver bullet. More sophisticated processes use multiple queries to build a graph or tree of thoughts, iteratively refining and pruning the graph before reaching a final output. We covered the tree of thoughts (ToT) paper last year; give it a read if you haven’t already (link). The figure below shows these single- and multi-query processes, as well as the new BoT approach.

The BoT process has four main components:

1. The *problem distiller* extracts the essential parameters and variables from the problem, and outlines the objectives of the input task and any corresponding constraints.
2. The *meta buffer* is a buffer of problem-solving templates. There are six problem-specific templates for tasks — such as text comprehension, mathematical reasoning, or programming — and three general-purpose templates.
3. The *instantiated reasoning* process uses thought retrieval to select the template that best matches the input query, then uses both the template and the distilled information to generate a solution to the query.
4. The *thought distillation and update* process uses any new information gained from problems that have been solved using BoT to update the thought templates. This helps the BoT process apply successful problem-solving techniques to new problems.

To better understand how BoT actually works, here’s an example from the paper that we’ve summarized for solving the following mathematical problem:

On average, a shop sells 20 shirts every day for a profit of 40 yuan each. The shop wants to expand sales, increase profits, and reduce inventory. An investigation found that for every 1 yuan that their shirts decreased in price, on average they’d sell 2 more shirts per day. How much should the price of each shirt be reduced to average a 1,200 yuan daily profit?
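If you want to check the arithmetic yourself (this is my own worked solution, not BoT’s output): let x be the price reduction in yuan, so the daily profit is (40 − x)(20 + 2x), and setting that equal to 1,200 yields a quadratic.

```python
import math

# Profit per shirt: 40 - x; shirts sold per day: 20 + 2x.
# (40 - x)(20 + 2x) = 1200  =>  x**2 - 30x + 200 = 0
a, b, c = 1, -30, 200
disc = math.sqrt(b * b - 4 * a * c)
roots = sorted(((-b - disc) / (2 * a), (-b + disc) / (2 * a)))
print(roots)  # → [10.0, 20.0]
```

Both roots hit the 1,200-yuan target, but since the shop also wants to expand sales and reduce inventory, the intended answer is the larger reduction of 20 yuan.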

According to the results in the paper, the BoT process outperformed seven other problem solving techniques on 10 different tasks, such as multi-step arithmetic, sonnet writing, and checkmate-in-one (chess). BoT is also much more efficient than competing methods that use multiple queries, and is less affected by the inherent randomness of LLM outputs, achieving a higher problem-solving success rate than other methods.

The authors say that BoT often struggles on tasks that require human creativity since templates aren’t as helpful for these problems. They also found that it is quite sensitive to the quality of the underlying LLM. I also have some thoughts about the BoT process, so it’s ✨Speculation Time!✨ I suspect that BoT’s performance would be highly sensitive to the types of tasks it’s used to solve and the library of thought templates in its thought buffer. This isn’t a dig at the BoT method, since I think it’s reasonable to expect that the templates need to be curated and expanded to make the BoT a more general problem solver. I also suspect that human curation of these templates would be much more effective than automatic curation via the proposed thought-distillation approach, but that obviously isn’t very scalable.

Finally, I think we might be reaching the limit of how much we can improve an LLM’s problem-solving ability by augmenting its generation process with techniques such as CoT, ToT, and BoT. The BoT method is essentially a way to retrieve instructions for solving particular problems. But this isn’t really a general solution to using LLMs to solve problems, and ultimately, it isn’t a replacement for good LLMs and neural network innovation.

There’s a longstanding relationship between chess and AI. Chess played a key role in demonstrating machines’ ability to tackle complex reasoning tasks when Deep Blue, IBM’s chess engine, defeated world champion Garry Kasparov in 1997. Since then, there have been several state-of-the-art chess engines, such as AlphaZero and Stockfish, which are powered by search.

What exactly does “search” imply in the context of chess engines? If you play chess, you already know that experienced players can think several moves ahead: “If I play this, my opponent will most likely play this, which will then let me play that.” Well, that’s exactly what search allows these chess engines to do — analyze future moves based on a potential move.

Novice- (like myself) and amateur-level players, on the other hand, play chess “in the moment,” one move at a time. Interestingly, several chess champions are purported to have said that they only think one move ahead — hmm, curious. With the vast amount of data at our disposal these days and the availability of large models, imagine if we could train a chess agent that only thinks one move ahead. Would it be as effective as the current state-of-the-art chess engines? Does training on a large dataset using a model with tons of parameters make a search component unnecessary? This is precisely what this week’s paper investigates.

The researchers trained their chess engine differently than is typical. They trained three models, each of which predicts one type of score:

Action value: the expected chance of winning after making a given move

State value: the expected chance of winning from a given board position

Behavioral cloning: the move that the oracle engine would choose in a given position

They trained these models using standard supervised learning, which is notable because most chess engines are trained using reinforcement learning algorithms.

For brevity's sake, I'll focus on the action-value model as the researchers found it to be ~6% better than the state-value model and ~13% better than the behavioral-cloning model. To train the action-value model, the authors curated a large dataset of chess games in an unconventional manner. Rather than hand-crafting features, they generated data in a somewhat automated fashion using a chess engine. They downloaded ten million games from Lichess (an online chess platform), and then extracted chess states (information about which pieces are where, whose turn it is to play, etc.) from these games. Finally, they used Stockfish 16 to estimate action-values for every legal action at a given state. Using this strategy, they were able to amass 15.32B action-value estimates to use as the training set. They curated the test set in a similar manner but with fewer games and a different time frame.

They also did something interesting with the data representation they fed into the model. They represented board states in FEN (Forsyth-Edwards Notation) format as fixed-length ASCII strings of 77 characters, padded with periods if required. They then tokenized these FEN strings by designating each character as a unique token. They stored actions in UCI (Universal Chess Interface) notation (e.g., e2e4 represents the popular opening where a pawn is moved from square e2 to e4). The researchers tokenized actions by determining all possible legal moves, sorting them alphabetically, and then using the index of the action in question as its token. Afterwards, they concatenated the tokenized board states and actions to form the model inputs during training, a deviation from typical practice in other papers, where teams of researchers use some form of numerical representation as input.
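Here’s a rough sketch of that encoding (my own reconstruction from the description above; for simplicity it enumerates all from/to square pairs and ignores promotion moves, which the paper’s full action set would include):

```python
FEN_LEN = 77

def tokenize_fen(fen: str) -> list[str]:
    # Pad the FEN string to a fixed 77 characters with periods, then
    # treat every character as its own token.
    return list(fen.ljust(FEN_LEN, "."))

# All from-square/to-square pairs in UCI notation, sorted alphabetically.
SQUARES = [f + r for f in "abcdefgh" for r in "12345678"]
ALL_MOVES = sorted(a + b for a in SQUARES for b in SQUARES if a != b)
MOVE_TO_TOKEN = {m: i for i, m in enumerate(ALL_MOVES)}

start = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
print(len(tokenize_fen(start)))  # → 77
print(MOVE_TO_TOKEN["e2e4"])     # index of the classic king's-pawn opening
```

The fixed 77-character state plus the action token is what gives the transformer its context size of 79 mentioned below.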

The authors of this paper derived model targets by grouping action values into K discrete bins (K=128), which turns the task into a classification problem. During training, they provided board states as inputs and action values as targets to a decoder-only transformer model with a context size of 79 and an output size of K. After training, they obtained a chess model which, when shown a board state and all legal actions, computes the confidence score of each legal move one at a time, then picks the action that provides the best win rate by taking the argmax of the confidence scores. Their training setup basically distilled the knowledge of Stockfish 16 (termed an Oracle) into a transformer model.
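A sketch of how the targets and move selection might look (the helper names and the uniform bin boundaries are my own assumptions; the paper’s exact discretization may differ):

```python
K = 128  # number of discrete value bins

def value_to_bin(win_prob: float, k: int = K) -> int:
    # Map an oracle-estimated win probability in [0, 1] to one of K
    # class labels -- the classification target during training.
    return min(int(win_prob * k), k - 1)

def pick_move(legal_moves, predicted_bins):
    # At play time: score every legal move one at a time, then take the
    # argmax over the predicted value bins.
    return max(zip(predicted_bins, legal_moves))[1]

print(value_to_bin(0.0), value_to_bin(0.5), value_to_bin(1.0))  # → 0 64 127
print(pick_move(["e2e4", "d2d4", "g1f3"], [90, 101, 77]))       # → d2d4
```

Note that binning also explains the indecisiveness discussed below: once several winning moves saturate the top bin, the argmax has no way to prefer the fastest checkmate.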

The authors trained three action-value models of different parameter sizes (9M, 136M, and 270M). Overall, the 270M-parameter model performed best. While playing against bots on Lichess, it had an Elo rating of 2299, outperforming GPT-3.5-turbo-instruct (1755); while playing against humans, it had an Elo rating of 2895! Basically, it played at grandmaster level.

Via ablation experiments, the authors also noted that training with more data led to better model performance. Their findings imply that, by using a large model with the attention mechanism and a vast amount of data, you can model complex tasks in a supervised learning setup, something that previously required some combination of reinforcement learning and heuristics. This also shows that transformer models can do more than next-token prediction and can be used for other tasks, like classification. While it’s quite impressive to see this model play chess at grandmaster level without analyzing several moves ahead, keep in mind that it was trained for and performed at that level when playing *blitz* chess (a fast game that allows little reaction time). In blitz chess, there really isn’t time to think several moves ahead, so a chess engine devoid of a search strategy may be particularly well suited for this scenario.

The authors don’t claim that their model is perfect. They note that the predictors are often indecisive in the face of overwhelming victory (where more than one move significantly strengthens one’s hand for a checkmate) since many actions may end up with a maximum bin value. This causes the chess agent to play randomly rather than commit to a move that will lead to a checkmate in the fewest number of steps. This critique of the paper highlighted this issue, among others, as a potential red flag and argued that the agent can’t be said to play at a grandmaster level if its endgame is weak.

Personally, I feel the findings in this paper are quite impressive, though perhaps the researchers could have gotten better results by structuring the modeling task as a regression problem rather than a classification problem. But the fact that they attained grandmaster-level chess without using search is a big milestone, one that could potentially lead to chess models that are computationally simpler than the current state-of-the-art ones.

In philosophy of mind, *qualia* are instances of subjective, conscious experience. Some examples are our perception of the *pain* of a headache, the *taste* of wine, and the *redness* of a sunset. Due to the subjective nature of these experiences, we actually don’t know whether two people experience these qualia the same way — you may experience the redness of a sunset in a different way than I do, for example. Do neural networks experience qualia too? In other words, would two neural networks learn to experience the redness of an evening sky in similar ways? A group of researchers have been wondering this, too, and they have a new hypothesis: the Platonic Representation Hypothesis (PRH).

Minyoung Huh and coauthors think that neural networks do in fact experience the world similarly. The figure below states the hypothesis, and includes a diagram that conveys the idea behind it. In the figure, images (X) and text (Y) are projections of a common underlying reality (Z). The PRH conjectures that representation learning algorithms will converge on a shared representation of Z, and making models larger, as well as making data and tasks more diverse, seems to drive this convergence.

On the face of it, the PRH seems like a plausible claim, but what evidence do the researchers have for this hypothetical behavior? Here are some facts that the researchers cite:

**Different models, with different architectures and objectives, can have aligned representations**. For example, one study found that the layers from a model trained on the ImageNet dataset could be “stitched” together with the layers from a separate model trained on the Places-365 dataset (recognizing types of places from images), and the resulting stitched model still had good performance. This suggests the layers are data-independent and compatible with each other. Other studies show that this kind of compatibility also applies to other neural network components, like individual neurons; that is, you can find a neuron per model such that this neuron activates on seeing the same feature in the input image, no matter which model you use.

**Alignment increases with scale and performance**. The researchers cited a few papers that add weight to this claim, but they also conducted an experiment of their own: They evaluated how well 78 different vision models — trained with varying architectures, training objectives, and datasets — transfer to the Visual Task Adaptation Benchmark (VTAB, which is designed to test if a visual model can perform well on tasks it wasn’t specifically trained for). They found that models that transfer well to VTAB had very similar representations, while no such similarity was found among the models that couldn’t adapt to VTAB.

**Representations are converging across modalities**. This claim is also supported by several studies, but the researchers conducted their own experiments to determine whether models are indeed learning an increasingly modality-agnostic representation of the world. Using a dataset of paired images and captions, they found that the better a language model is at language modeling, the more its representations aligned with DINOv2, a vision model. These results are shown in the figure below.

**Models are increasingly aligning to brains**. They cited some studies to support this claim, though it’s by far the weakest of their five claims.

**Alignment predicts performance on downstream tasks**. The researchers found that there’s a correlation between how well language models align to DINOv2 and how well they perform on downstream tasks, like commonsense reasoning and mathematical problem solving.

Next, the researchers investigated why such representational alignment might be occurring. They present three further ideas:

**The Multitask Scaling Hypothesis**. Each training datapoint and objective (task) places an additional constraint on the model. As data and tasks scale, the volume of representations that satisfy these constraints must grow proportionately smaller.

**The Capacity Hypothesis**. Larger models and better learning objectives should be better at arriving at optimal solutions to problems. This idea sounds interesting, but it doesn’t really explain *why* the representations of optimal models should be similar.

**The Simplicity Bias Hypothesis**. All neural networks, even unnecessarily large ones and ones that lack regularization, tend to arrive at the simplest representations. (We actually touched on this topic last week; have a read if you missed it.)

The paper goes on to discuss what kinds of representations are being converged to, and some implications of that convergence. But, before we finish up, I think it’s important to mention some counterexamples and limitations of their research:

Representations of modality-specific concepts, such as visually experiencing the beauty of a total solar eclipse, can’t be learned solely from other modalities.

Not all representations are presently converging.

The demographic bias of people creating AI models might also accidentally bias them toward similar representations.

The level of measured alignment might actually be quite small. For example, the maximum measured alignment in the DINOv2 figure above is 0.16 on a scale of 0 to 1. The researchers aren’t sure whether this is indicative of peak alignment or not.

After reading this paper, I’m not convinced that neural networks’ representations are converging. But I do think that the idea is interesting and plausible, and this paper introduces lots of different avenues for further exploration. For example, I’d love to see more investigation into methods of measuring alignment and a more comprehensive analysis of what kinds of representations align, how much alignment varies across different kinds of representations, and how alignment changes as models change.

A well-known rule in statistical machine learning is that a statistical model shouldn’t have more parameters than the number of samples that were used to train it. That’s because the model will have enough parameters to fit each of the samples exactly, and so it will be less likely to generalize to unseen data. But this rule is seemingly contradicted by modern deep neural networks like Llama 3 and Stable Diffusion — models that have hundreds of billions or even *trillions* of parameters. Why can models like these generalize well to unseen data even when their training data size is smaller than their parameter counts? This week’s Learn & Burn will cover this strange phenomenon, known as *double descent*, rather than our typical focus on a single research paper.

Double descent is a phenomenon where a model can continue to generalize well to unseen data even when it has many more parameters than training data samples. The figure below demonstrates this using a polynomial fitting example. The main thing being compared is the parameter count, or degree, of the polynomial, and whether it is less than, approximately equal to, or greater than the number of training samples.

Let’s look at these cases one by one:

- On the left, the degree-1 polynomial doesn’t fit the data well because it doesn’t have enough parameters to fit the nonlinear training data.

- In the middle, it *looks* like the fitting problem is solved, since the degree-10 polynomial fits the 11 data points precisely. But a precise fit isn’t ideal, since any unseen data drawn from the same distribution as the training data probably won’t fall exactly on this polynomial’s curve. This problem is called *overfitting*.

- Finally, on the right, the degree-30 polynomial seems to fit the data quite well, despite having substantially more parameters than the training data that it’s being fit to.

We need to include some regularization in the final curve’s optimization; otherwise, it would look “bumpier” as it contorts to pass exactly through every point. As we’ll see below, researchers suspect that this kind of regularization (intuitively, a preference for learning simpler patterns) is central to the double descent phenomenon.

(Side note: About a year ago, Yann LeCun, the Chief AI Scientist at Meta, gave a brief explanation of the double descent phenomenon in this fireside chat. He said that the double descent phenomenon can be observed with polynomial fitting, too. A curious viewer then asked Claude, an LLM like ChatGPT, to write some code that demonstrates double descent with polynomial fitting. The figure you saw above is the one that Claude’s code generated!)
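If you want to poke at this yourself, here’s a minimal numpy sketch in the same spirit (this is my own illustration, not Claude’s actual code): it fits degree-1, degree-10, and degree-30 polynomials to 11 noisy points using the minimum-norm least-squares solution, which acts as a mild implicit regularizer in the overparameterized case.

```python
import numpy as np

rng = np.random.default_rng(0)

# 11 noisy training points from a smooth underlying function.
x_train = np.linspace(-1, 1, 11)
y_train = np.sin(np.pi * x_train) + 0.1 * rng.standard_normal(11)

def fit_poly(x, y, degree):
    # Minimum-norm least squares: for degree >= len(x) - 1 there are many
    # interpolating polynomials, and pinv picks the one with the smallest
    # coefficient norm -- a mild implicit regularizer.
    X = np.vander(x, degree + 1, increasing=True)  # features 1, x, x^2, ...
    return np.linalg.pinv(X) @ y

def predict(coeffs, x):
    return np.vander(x, len(coeffs), increasing=True) @ coeffs

for degree in (1, 10, 30):
    coeffs = fit_poly(x_train, y_train, degree)
    train_err = np.mean((predict(coeffs, x_train) - y_train) ** 2)
    print(f"degree {degree:2d}: train MSE = {train_err:.2e}")
```

The degree-1 fit underfits (large training error), while the degree-10 and degree-30 fits both pass through every training point; how well the degree-30 fit behaves *between* the points depends on that implicit regularization.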

Here’s another figure (from Wikipedia) that I really like that helps explain double descent. Like the sequence of plots above, the x-axis spans, from left to right, regimes where the number of parameters is less than, equal to, and greater than the number of data points — this time for a two-layer neural network. But now the y-axis shows the training and test errors. We can see where the double descent phenomenon gets its name: As the number of parameters increases, the test error descends, then rises as it approaches the interpolation threshold, and then descends a second time beyond it.

Why do neural networks behave this way? We still don’t know the precise answer, but researchers have established that the data’s signal-to-noise ratio (SNR) and the amount of regularization used during training are central to the phenomenon. The figure below shows how these characteristics influence double descent. The panels show results on datasets with high (left) and low (right) SNR, and the colors show results of models trained with different levels of regularization (low regularization = blue, high regularization = yellow). We can see that, without regularization, the double descent phenomenon occurs regardless of the SNR. But when regularization *is* used, its optimal value — indicated by a test error that doesn’t increase at the interpolation threshold N/n = 1 — is slightly different in the high- and low-SNR cases.

*Image source: Fig 3 of https://arxiv.org/pdf/1908.05355*

Double descent and related strange phenomena that arise when we train neural networks — like grokking — still aren’t entirely understood. Researchers are still trying to develop solid theoretical explanations for why they happen. So far, we understand small pieces of the double-descent puzzle, such as:

- Poor generalization is most likely at the interpolation threshold.

- Models with optimal test error typically lie beyond the interpolation threshold.

- The behavior of double descent depends on the SNR in the data and the amount of regularization (as in the figure above).

If we’re lucky, the next time I talk about double descent on Learn & Burn will be when someone properly cracks the double-descent problem wide open!

In 2021, Google DeepMind announced AlphaFold 2, their latest deep learning–based protein-folding algorithm. Protein folding is the task of predicting the 3d coordinates of the heavy atoms in a given protein from some basic information about that protein, such as its primary amino acid sequence. Protein folding methods that preceded AlphaFold 2 — including the original AlphaFold — were decent, but AlphaFold 2 blew them out of the water, achieving error rates 3x smaller than the next best method. Now AlphaFold 3, which Google DeepMind has developed in collaboration with Isomorphic Labs, goes beyond proteins to analyze a broad spectrum of biomolecules.

Given an input list of molecules, AlphaFold 3 can determine their joint 3d structure to reveal how they all fit together. In addition to proteins, AlphaFold 3 can model other large molecules like DNA and RNA, as well as smaller ones known as ligands. For example, the figure below shows the structure of a protein (blue) bound to a double helix of DNA (pink), and how this structure compares to the ground-truth, experimentally measured structure (gray).

Architecturally, AlphaFold 3 is very similar to AlphaFold 2. Each method has two main components: one for generating representations of the molecules, and another for predicting their structure. AlphaFold 3’s representation method, called Pairformer, is a simpler version of the one in AlphaFold 2. Both of these methods work like a transformer: The attention mechanism operates on the chemical structure of the biomolecule, and a gating mechanism generates representations for pairs of atoms in the molecule, similar to the causal attention mask in a transformer. The figure below shows an example of the pair representation, and how the elements of the representation correspond to atoms in a graph representation of a molecule.

The next step AlphaFold 3 performs is determining the actual 3d coordinates of each atom in the joint biomolecular structure. Unlike AlphaFold 2, which used a complicated structure-prediction module that needed carefully tuned parameters to ensure that its predictions were plausible, AlphaFold 3 uses a diffusion model to directly predict the 3d coordinates. This works kind of like a text-conditioned, image-generating diffusion model, except that AlphaFold 3 uses the pairwise representations to condition the denoising of the atoms’ coordinates.

While a diffusion model offers many benefits — like not needing to enforce global rotational and translational invariances during generation — it also has drawbacks. For example, the researchers found that the model would hallucinate plausible chemical structures where they shouldn’t exist. To counteract this, they used predicted structures from AlphaFold-Multimer v2.3 — an extension of AlphaFold 2 for protein complex structure prediction — to enrich AlphaFold 3’s training data. This effectively taught AlphaFold 3 to mimic the non-hallucination behavior of its predecessor.

This new AlphaFold model is a tremendous leap forward in terms of its predictive accuracy, but it’s also a one-stop shop for many biomolecular modeling tasks. For example, AlphaFold 3 achieves much higher accuracy on protein-nucleic acid interactions than nucleic acid–specific predictors. It’s a similar story for protein-ligand interactions and antibody-antigen prediction. The method also demonstrates that deep learning methods are highly effective at modeling a variety of biomolecular interactions and will help us better understand how the most complex processes in our bodies work, like drug interactions, hormone production, and the health-preserving process of DNA repair.

What is a neural network? An archetype is a multilayer perceptron (MLP). MLPs are probably the most widely used architectural component of the NN variants we’re familiar with, like transformers and CNNs. But have you ever considered why this is the case? What if there was something else — maybe even something better — that we could use instead of MLPs? Today’s summary explores a new kind of NN building block called Kolmogorov-Arnold Networks, or KANs. As we’ll see, KANs are a fascinating new way to think about and construct NNs, and they may also offer some insight into why MLPs are so ubiquitous in today’s NN models.

According to Liu et al., KANs were inspired by the Kolmogorov-Arnold Representation Theorem, a mathematical theorem that says it’s possible to write a complicated, multivariate function using a bunch of simpler, univariate ones. If this sounds familiar, that’s because it’s very similar to the universal approximation theorem that underpins MLPs! MLPs work by combining input signals with learnable weights and a fixed activation function, and KANs work by combining inputs with learnable activation functions. The network representations in the figure below show this distinction. We can see that the MLP has different weights (indicated by the red and blue edges) and a fixed activation function (in this case, SiLU), whereas the KAN has much more complex edges that are summed together.

In a KAN, the activation functions are special piecewise functions called B-splines. If you’ve ever used the Bezier curve tool in Adobe Photoshop or a similar tool, then you might already have an intuitive idea about how B-splines can model an arbitrary function. The figure below shows a curve that is a weighted sum of several B-splines. The shape of the splines (and thus the overall curve) is controlled by the position of the anchor points. During training, the positions of these points are optimized so that the resulting splines act as useful activation functions.
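To make the spline idea concrete, here’s a small numpy sketch (my own illustration, not code from the KAN paper) that evaluates a curve as a weighted sum of B-spline basis functions via the Cox-de Boor recursion, and demonstrates the locality property: nudging one coefficient only changes the curve where that coefficient’s basis function is nonzero.

```python
import numpy as np

def bspline_basis(i, k, t, knots):
    # Cox-de Boor recursion: value of the i-th degree-k B-spline basis at t.
    if k == 0:
        return np.where((knots[i] <= t) & (t < knots[i + 1]), 1.0, 0.0)
    result = 0.0
    denom = knots[i + k] - knots[i]
    if denom > 0:
        result = (t - knots[i]) / denom * bspline_basis(i, k - 1, t, knots)
    denom = knots[i + k + 1] - knots[i + 1]
    if denom > 0:
        result = result + (knots[i + k + 1] - t) / denom \
            * bspline_basis(i + 1, k - 1, t, knots)
    return result

def spline_activation(coeffs, t, degree=2):
    # A "learnable activation": a weighted sum of B-spline basis functions,
    # where the coefficients play the role of the trainable anchor points.
    knots = np.linspace(0, 1, len(coeffs) + degree + 1)  # uniform knot vector
    return sum(c * bspline_basis(i, degree, t, knots)
               for i, c in enumerate(coeffs))

coeffs = np.array([0.0, 1.0, 0.5, -0.5, 1.0, 0.0])
t = np.linspace(0.25, 0.75, 200)
base = spline_activation(coeffs, t)

# Nudge one coefficient: the curve changes only where that basis is nonzero.
bumped_coeffs = coeffs.copy()
bumped_coeffs[2] += 0.5
changed = np.abs(spline_activation(bumped_coeffs, t) - base) > 1e-12
```

This locality is exactly the property discussed below: editing one anchor point leaves the rest of the curve untouched, unlike editing a weight in an MLP.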

Using splines as the basis of the “neurons” in a KAN yields an NN with useful properties. For starters, KANs are more interpretable than MLPs, since visual representations of their learned activations are much easier to understand than the raw weights of an MLP layer. Additionally, parameters can be added to a partially trained KAN to improve its accuracy (after some further training to fit the new parameters). Unlike MLPs — where added parameters can affect the whole NN — each additional KAN parameter only affects two splines (if the splines are quadratic), leaving the others unaffected. To use another Photoshop analogy, this is like drawing a rough outline of an object with a small number of Bezier curves, and then adding more points later to refine it.

Another benefit is that sometimes a simpler KAN can be better than a more complex one. For example, the researchers used two different KANs to fit a two-parameter function. To describe these networks, they used the notation [*n_0, n_1, …, n_L*], where *n_0* is the number of inputs or parameters in the data, *n_i* is the number of nodes or activation functions in layer *i*, and *L* is the total number of layers. For a certain two-parameter function, the authors tried a [2, 5, 1] KAN and a [2, 1, 1] KAN. The figure below shows the expression for the function *f(x,y)* that the KAN is fitting, and the error levels (RMSE) over training time. You can see that the [2, 1, 1] KAN achieves a lower test/training loss overall at the marked interpolation threshold.

The reason for the stepped nature of the loss in these training results is that, during training, the researchers progressively added more splines to each of the KAN’s activation functions, improving their accuracy. This is indicated by the “grid” information printed on the plots. It makes intuitive sense that the [2, 1, 1] KAN should perform better since it only takes two activations to fit the data in this case: one that sums a spline each for sin(πx) and y^2, and another for exp. The figure below shows what the activation functions would look like in this [2, 1, 1] KAN.

Unfortunately, this paper only demonstrates the ability of KANs on toy datasets like the one above. The authors didn’t run tests on even a simple real-world dataset, like MNIST. One reason might be that training a KAN is about 10x slower than training an MLP with the same number of parameters — though this may just be due to an inefficient training implementation that could improve as more researchers iterate on the training algorithm.

Also, while KANs seem to offer several advantages over MLPs, they are limited in several ways. For example, a KAN’s activations are only defined on a bounded range of input values, since splines don’t extend infinitely across the input’s domain. And, while KANs may be interpretable in toy examples, I’m not convinced that being able to observe the activations in a more complex KAN — say, one with more than a dozen layers — would really be that helpful. And of course there’s the argument that MLPs are already universal approximators, so an MLP can also represent a KAN. So I won’t hold my breath waiting for KANs to revolutionize NNs. But they’re cool nonetheless, and they may have some niche mathematical applications, such as symbolic regression or knot theory. If that sounds interesting to you, then I highly recommend you give this paper a read!

Just a few weeks ago, we covered some papers from Hao Liu — a PhD student at UC Berkeley — that describe how to modify an LLM’s architecture so that it can process million-token sequences. In that article, I speculated that these techniques might be how Gemini 1.5 Pro — a new LLM from Google — achieved its long-context processing abilities. Well, now Google has released a research paper of their own describing how an LLM can process *infinite*-length contexts using an interesting idea called compressive memory. The authors call their technique *Infini-attention*.

The figure below shows how the “attention” part of Infini-attention works. It begins by chunking the input sequence into segments, and then computing attention on a segment-by-segment basis. The green blocks represent a “memory” that’s updated with information from each segment that is then accessible by all subsequent segments.

As you can see in the next figure, Infini-attention is actually a combination of two types of attention: vanilla attention (purple) and compressive memory plus linear attention (green). Due to the segmentation, the purple vanilla-attention blocks compute the interactions between the queries, keys, and values of within-segment tokens only (Q_s and {KV}_s). The compressive memory, on the other hand, remembers the keys and values from previous segments ({KV}_{s−1}). The memory is built up iteratively, starting with the first segment. So, by querying this memory with the current queries Q_s, the output projection from each segment can still be aware of the keys and values of all tokens in all preceding segments.

These two kinds of attention are similar but differ slightly. Here’s how the linear attention (green) and vanilla attention (purple) are computed:

There are two things to note: First, each attention uses a different non-linearity: ELU+1 (nicknamed σ) for linear attention, and softmax for vanilla attention. Second, the vanilla attention requires quadratic space to compute the softmax, whereas the linear attention needs only linear space. I’ll discuss the implications of these differences in a moment.
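Here’s a toy numpy sketch of the two computations (my own simplification, ignoring batching, attention heads, and the learned gating that combines the two outputs):

```python
import numpy as np

def elu_plus_one(x):
    # The paper's sigma: ELU(x) + 1, which keeps all activations positive.
    return np.where(x > 0, x + 1.0, np.exp(x))

def vanilla_attention(Q, K, V):
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)          # (n, n): quadratic in sequence length
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)     # softmax over keys
    return w @ V

def linear_attention(Q, K, V):
    sq, sk = elu_plus_one(Q), elu_plus_one(K)
    M = sk.T @ V                 # (d, d_v) summary -- size independent of n
    z = sk.sum(axis=0)           # (d,) normalization term
    return (sq @ M) / (sq @ z)[:, None]

rng = np.random.default_rng(0)
n, d = 8, 4
Q, K, V = rng.standard_normal((3, n, d))
out_vanilla = vanilla_attention(Q, K, V)   # needs the full n x n score matrix
out_linear = linear_attention(Q, K, V)     # never materializes an n x n matrix
```

Note how the linear version collapses all keys and values into the fixed-size summary `M` before the queries ever touch them, which is exactly what makes the compressive memory below possible.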

The compressive memory is a special kind of memory called an *associative memory*. At its most basic level, this memory is just a matrix of values that changes every time it’s updated with new information — which, in the case of Infini-attention, are the keys and values from segments. The memory’s contents for the ith segment can be written as follows:

The subscripts on the K’s and V’s indicate which segment they’re from. The cool thing about this memory is that we can retrieve any given value *v* as long as we have the corresponding key *k*, where *v* and *k* are row vectors that have been previously stored in the memory. If we pretend that all of the K’s and V’s in the above equation have only a single row, and if we ignore the non-linearities for a moment, then we can retrieve V_2 from the memory using K_2 like this:

The way I see it, structuring the memory this way has three main benefits: First, the size of the memory stays the same no matter how many segment’s worth of keys and values are stored in it — it’s a d_k × d_v matrix. Second, the form of M is ready to be queried as-is: ignoring non-linearities, QM is equivalent to the definition of A_{linear}.

The third cool thing about this memory is that storing keys/values can be made slightly more efficient by modifying the storage process. The regular way to store new keys and values is to simply add them to the previous memory state M_{i-1}, like this:

But with Infini-attention, we can *retrieve* the existing value from the memory and update the memory with only the difference between the new value and the one that was already stored, like this:

The authors call this a “linear + delta” memory update, and it helps stop the memory from getting cluttered with too many values.
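Here’s a hedged numpy sketch of such an associative memory (the class and its names are mine, and I’m ignoring per-head details): it supports the plain linear update, the “linear + delta” variant, and retrieval.

```python
import numpy as np

def sigma(x):
    return np.where(x > 0, x + 1.0, np.exp(x))  # ELU + 1

class CompressiveMemory:
    # Fixed-size associative memory: a d_k x d_v matrix M plus a normalizer z.
    def __init__(self, d_k, d_v):
        self.M = np.zeros((d_k, d_v))
        self.z = np.zeros(d_k)

    def retrieve(self, Q):
        sq = sigma(Q)
        return (sq @ self.M) / (sq @ self.z + 1e-8)[:, None]

    def update_linear(self, K, V):
        # Plain update: just add the new key/value contributions.
        sk = sigma(K)
        self.M += sk.T @ V
        self.z += sk.sum(axis=0)

    def update_delta(self, K, V):
        # "Linear + delta": store only what the memory doesn't already hold.
        sk = sigma(K)
        retrieved = (sk @ self.M) / (sk @ self.z + 1e-8)[:, None]
        self.M += sk.T @ (V - retrieved)
        self.z += sk.sum(axis=0)

rng = np.random.default_rng(1)
mem = CompressiveMemory(d_k=4, d_v=3)
K = rng.standard_normal((1, 4))
V = rng.standard_normal((1, 3))
mem.update_linear(K, V)
recovered = mem.retrieve(K)   # querying with the stored key returns ~V
```

With the delta update, re-storing a key/value pair the memory already knows leaves `M` essentially untouched, which is the anti-clutter property the authors describe.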

The concept of combining compressive memory with linear attention is definitely neat. But it left me thinking: If the compressive memory with linear attention can cover the entire context length, then what’s the point of continuing to use the regular attention?

The reason they did so is that these two types of “attention” aren’t really equivalent. The linear attention idea isn’t new — it’s commonly used in compute-constrained settings or when very long contexts are necessary — but it’s not as effective as vanilla attention. That’s why vanilla attention hasn’t gone away; for example, the best open-source LLMs, like Llama 3, use vanilla attention. So, I guess Infini-attention is trying to have the best of both worlds: It uses vanilla attention for short snippets of a long context, and linear attention to cover the rest of the context. Ultimately, though, I suspect Infini-attention probably isn’t going to work as well as full-context vanilla attention.

In the rest of the paper, the Google researchers present some experimental results that show that Infini-attention works really well at the “passkey task” — which is essentially a needle-in-a-haystack challenge — after the model has been fine-tuned a little bit on that task. They also show that the Infini-attention method works better than other long-context techniques, and that the “linear + delta” approach works slightly better than the linear-only approach. I was hoping that this paper was going to reveal a silver bullet for long-context attention, but in reality it seems like another band-aid approach to me. But it’s great to see that people are coming up with new and — in the case of the associative-memory approach — very clever techniques.

In general, the more parameters an LLM has, the better it performs. The best-performing open-weight LLMs have hundreds of billions of parameters, but, oddly, not all of these parameters are useful. For a long time, AI researchers have known that some of a model’s parameters are much more important than others. Researchers even have a technique called “pruning,” which they use to remove some of these unhelpful parameters and reduce the size of the model without affecting its predictive performance very much. But there often wasn’t much rhyme or reason as to which parameters were useless — that is, until a recent research paper from Gromov et al., which found that *entire layers* of parameters in an LLM’s network can be pruned!

I know what you’re thinking: How can an entire layer be removed from an LLM — surely that’ll have a huge impact on its accuracy, right? Actually, it’s not quite that straightforward. The researchers found that certain layers can be pruned with minimal impact on the LLM’s predictive performance. In fact, layers can continue to be pruned, up to a point, before the LLM’s performance falls off a cliff.

Also, the predictive performance that’s lost by pruning layers can be restored with a tiny amount of fine-tuning of the pruned LLM. The figure below shows the predictive performance (y-axes) of the Llama-2-70B model against the fraction of the model’s layers that have been dropped. The top two plots show the model’s accuracy on two question-answering benchmarks, while the bottom plot shows the validation loss. The dark-blue trace shows the pruned model’s performance, while the light-blue trace shows the performance with “healing,” which is the post-pruning fine-tuning.

Before continuing, it’s worth taking a moment to consider why a pruned LLM behaves this way. I often think of the parameters in an LLM working together in perfect harmony. Disrupting this harmony — say, by deleting an entire layer! — would be very damaging for the network since any errors introduced would cascade down subsequent layers. This intuition is a decent model, so long as a layer’s output is significantly different from its input.

But, what if a layer’s output *wasn’t* significantly different? In that case, removing the layer shouldn’t affect the network very much, since it didn’t do much to begin with. In fact, with typical transformer architectures, we *do* expect layers to have outputs similar to their inputs, because each layer *adds* a delta to its input — that is, the output of a transformer layer is always an adjustment to (something added to) its input. And, since each layer is adjusting its input, we might also expect earlier layers to have more impact, since their changes have a compounding effect on the later layers.
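A toy numpy example makes this concrete (purely illustrative, not a real transformer): a “network” of three residual layers where the middle layer adds only a tiny delta, so deleting it barely moves the final output.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16

def layer(x, W, scale):
    # A toy residual layer: output = input + a (scaled) learned adjustment.
    return x + scale * np.tanh(x @ W)

W1, W2, W3 = rng.standard_normal((3, d, d)) / np.sqrt(d)
x = rng.standard_normal(d)

# Full "network": three residual layers; the middle one adds only a tiny delta.
full = layer(layer(layer(x, W1, 1.0), W2, 0.01), W3, 1.0)
# Pruned network: the near-inert middle layer removed entirely.
pruned = layer(layer(x, W1, 1.0), W3, 1.0)

rel_change = np.linalg.norm(full - pruned) / np.linalg.norm(full)
print(f"relative output change after pruning: {rel_change:.4f}")
```

Removing the near-inert layer changes the output by only a few percent, whereas removing a layer that contributes a full-strength delta would change it drastically.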

So, suppose that layers can be removed from an LLM without hurting its performance. This should indicate that the original LLM contains some layers that aren’t very useful (i.e., their output is typically quite similar to their input). That’s what these researchers found! The figure below shows how much a given layer’s output changes from its input. They measured it using a metric called Shifted Rescaled Angular Distance, which is close to 1 when the change is large and close to 0 when the change is small. (In case you’re curious, this distance is a rescaled version of the angle between the representation vectors before and after the layers being measured.) The y-axes in these plots indicate the number of consecutive layers that were pruned — so the bottom row represents the full architecture, while higher rows represent heavily pruned versions of the model. I’ll describe in a moment which layers were removed.

Across all model sizes (and other non-Llama LLMs that aren’t pictured), there seems to be a trend that the deeper layers in a network tend to contribute less than shallower layers. This means that many layers can be pruned without harming the LLM much. So, based on these results, the researchers devised the following strategy to prune layers from an LLM:

1. Choose how many layers you want to prune, *n*.

2. Compute the similarity (angular distance) between the inputs of all pairs of layers that are exactly *n* layers apart.

3. Select the pair with the highest similarity (lowest angular distance) and prune the *n* layers between them.

4. Optionally, heal the network with some fine-tuning.
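Here’s a small numpy sketch of steps 1–3 (my own illustration on synthetic “layer inputs,” using the angular-distance idea described above):

```python
import numpy as np

def angular_distance(x, y):
    # Rescaled angle between two representation vectors, in [0, 1]:
    # ~0 when the vectors point the same way, ~1 when they are opposite.
    cos = np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
    return np.arccos(np.clip(cos, -1.0, 1.0)) / np.pi

def choose_block_to_prune(layer_inputs, n):
    # Find the start index l minimizing the distance between the input of
    # layer l and the input of layer l + n (i.e., the block's overall effect).
    dists = [angular_distance(layer_inputs[l], layer_inputs[l + n])
             for l in range(len(layer_inputs) - n)]
    best = int(np.argmin(dists))
    return best, dists[best]

# Synthetic "layer inputs": most layers shift the representation a lot,
# but layers 3 and 4 barely change it.
rng = np.random.default_rng(0)
x = rng.standard_normal(32)
layer_inputs = [x]
for step in (1.0, 1.0, 1.0, 0.01, 0.01, 1.0):
    x = x + step * rng.standard_normal(32)
    layer_inputs.append(x)

start, dist = choose_block_to_prune(layer_inputs, n=2)
print(start, round(dist, 4))   # picks the block of near-inert layers
```

The selection correctly lands on the block whose layers barely change the representation, which is exactly the block the strategy says to prune.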

This strategy is quite simple, but it requires a lot of work before pruning to determine the angular distances between layer inputs; it also requires the user to load and run the entire unpruned model. This might be prohibitive for some users, since they might be pruning precisely because they don’t have the resources to run an unpruned LLM. So, the researchers devised an even simpler pruning strategy: Decide how many layers you want to prune — call this number *n* — and remove the last *n* layers *before* the final layer (which the researchers noticed is always a useful layer), and then heal with fine-tuning. In the figure above you can see that removing the final layer is never a good idea; it has much lower similarity to the layers that precede it (it’s blue while the preceding layers are yellow).

The figure below compares the quality of these approaches. Each graph plots both the simple pruning strategy (in red) and the more complex similarity-based strategy (in blue). While the similarity-based approach tends to preserve more accuracy than the simple approach, the difference mostly vanishes when healing is used, as shown in the right column. So, if you plan to heal, then either approach is suitable.

By applying this pruning approach along with the latest model-quantization techniques, Llama-2-70B — which normally spans 140 GB of memory and consumes 30 billion floating-point operations (FLOPs) per token — can run with significantly fewer resources: 17.5 GB of memory and 15 billion FLOPs per token. This makes it possible to run the model on consumer computers, not just big, beefy datacenter machines. But it also leaves me wondering why the models we train end up with layers that contribute so little to the result. Pruning is great, but wouldn’t it be better if we could train equivalent LLMs that didn’t have unnecessary layers in the first place?

Have you ever right-clicked a webpage and selected “View Page Source”? If so, then you’ve glimpsed the world of frontend web development — the source code that tells your browser how the things on the webpage should look. Unless you’re a frontend developer, you’d probably have a hard time looking at a page design and mapping it to its source code. But could an AI do that? Today’s paper explores whether multimodal AIs like GPT-4 or Gemini Pro Vision can generate a webpage’s source code from an image of a page design. If it *is* possible, then this could become a component of webpage-building AI tools.

Design2Code is a framework based on the above premise — it auto-generates the code for a webpage based on an image of what that page should look like. The framework includes a dataset of webpage screenshots and their corresponding source code. The dataset contains a diverse range of webpages, including blogs, company/organization webpages, product pages, and news pages. Unlike other datasets of this kind that are typically generated synthetically, Design2Code’s dataset is sourced from real-world webpages. Here are a few examples from it:

The Design2Code dataset only contains 484 examples because it’s not meant for training models but for evaluating how well a model can generate webpages. The Design2Code framework can score a webpage-generating AI along 5 axes. To help generate these scores, the framework divides both the input image and the output webpage into *blocks* — rectangular sections of the image and the corresponding regions of the generated webpage. The blocks come in pairs (one from the image, one from the generated webpage), and it’s good if the blocks in a pair are similar to each other. Within that context, the authors measured these forms of similarity between the image and the generated page:

- Color: The perceptual difference between colors in the reference image and the generated webpage.

- Position: How closely the coordinates of blocks on the reference image and the generated page match.

- Block-match: Overall, how closely the set of blocks in the generated page matches the set of blocks in the image. (This criterion can help punish dropped or hallucinated blocks.)

- Text: How similar the text is between matched blocks from each webpage.

- CLIP: How similar the generated webpage is to its reference image overall, computed using image embeddings (a semantic vectorization) of screenshots of both.

The first four of these metrics are computed on a block-by-block basis. To do this, the authors needed a way to determine which blocks in the reference image and the generated webpage correspond to each other, even when the two blocks aren’t exactly the same. For example, if there is an “About us” block in the reference page, the algorithm might match it with the “About” block in the generated page. For this matching, the authors use a fancy but standard algorithm (the Jonker-Volgenant algorithm).
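Here’s a tiny illustration of that matching step (my own toy example; I use brute force for clarity, whereas in practice you’d use a linear-assignment solver such as `scipy.optimize.linear_sum_assignment`, which implements a modified Jonker-Volgenant algorithm):

```python
from itertools import permutations

import numpy as np

def match_blocks(cost):
    # Min-cost one-to-one matching between reference and generated blocks.
    # Brute force over permutations, fine for a handful of blocks.
    n = cost.shape[0]
    best = min(permutations(range(n)),
               key=lambda p: sum(cost[i, p[i]] for i in range(n)))
    return list(best)

# Toy cost matrix: rows = reference blocks, columns = generated blocks.
# Entries could blend text distance, position offset, and color difference.
cost = np.array([
    [0.1, 0.9, 0.8],   # "About us"    vs ("About", "News", "Contact")
    [0.7, 0.2, 0.9],   # "Latest news"
    [0.8, 0.9, 0.3],   # "Contact"
])
print(match_blocks(cost))  # -> [0, 1, 2]: each block pairs with its lookalike
```

Once blocks are paired this way, the per-block color, position, and text scores can be computed on matched pairs, and unmatched blocks count against the block-match score.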

The figure below shows a radar chart comparing the performance of four different webpage-generating models: GPT-4V (Vision); Gemini Pro Vision; WebSight, which is a model trained on a synthetic dataset of webpage-code data; and Design2Code, which we’ll discuss in more detail shortly.

The authors then used the Design2Code benchmarking data and these metrics to evaluate how well these models generate webpages. GPT-4V consistently scores the best or close to the best across all the metrics. Even when the researchers tried various prompting techniques (e.g., direct, text-augmented, and self-revision prompting), GPT-4V came out on top. This was true even when people evaluated the generated webpages — they preferred GPT-4V over Gemini and other open-source models.

The quantitative aspects of Design2Code are really important for making incremental improvements in automatic webpage generation. Yet, on their own they left me feeling like these models aren’t quite up to the task yet. For example, even GPT-4V has a block-match score of only 78%! (Intuitively, this means something like this: The webpages made by GPT-4V either had only 78% of the blocks they should have, or that the image contained only 78% of the blocks that were in the created page; the 22% discrepancy is bad either way.) However, this is where the most intriguing aspect of the paper comes in: The researchers had people compare an original webpage to a webpage that was generated by GPT-4V. They then asked these people: Can the AI-generated webpage replace the original webpage? And is the reference webpage or AI generation better?

Amazingly, 49% of respondents considered the AI webpage to be interchangeable with the original, while 64% said they *preferred* the AI webpage to the original! I find this fascinating because, even though Design2Code provides really useful metrics for scoring webpage reproduction, none of them capture whether the generated webpage is functionally *good enough*. This is an aspect that would be great to see in more machine learning papers, especially ones that explore practical applications of AI. It’s a reminder that an AI model doesn’t need to score 100% on the relevant quantitative metrics for it to be useful. Sometimes, the bar can be much lower than that, and other times the metrics might not capture progress on fundamental questions like, Could this AI do a human’s job?

One final thought: As someone who shivers in fear when I hear the words “HTML” or “CSS,” I was really hoping that Design2Code would mean I’d never have to write a line of webpage code again in my life. Unfortunately, the reality is that webpage generation remains a challenging task for AIs. But one takeaway from the paper is that webpages become much harder to generate (according to the 5 axes outlined above) as the total number of tags increases. Unsurprisingly, this means that simpler webpages are easier to reproduce than more complex ones. So, if you have a simple design in mind and you just can’t be bothered to code it up, there’s a decent chance that an AI might be able to do it for you. However, if your design is complicated, then it might be best to leave it to the pros!

By now, you've probably heard of diffusion models. They’re those neural networks that transform an array of useless values, like random noise, into an array of meaningful data. Diffusion models are most famous for generating images and videos that — you guessed it — are arrays of useful data. You know what else are just arrays of useful data? Neural network parameters! So, can diffusion models be used to turn noise into useful network parameters? That’s the question that Wang et al. try to answer in their recent paper: Neural Network Diffusion.

How is it even possible to diffuse the parameters of a neural network (NN)? Architecturally, the setup is pretty much identical to how diffusion models generate images or video. The authors (Wang et al.) call their method *p-diff* (for “parameter diffusion”). The setup, shown in the figure below, involves a parameter autoencoder (upper left) and a latent diffusion model, a.k.a. LDM (upper right). Once the autoencoder is trained (more on this later), its decoder can be used to generate network parameters from the diffused latent representations (lower half).

Instead of training on web-scale datasets of images and video, the researchers compiled their own datasets of neural network parameters. Each dataset consists of very minor variations to a subset of a single NN’s parameters. To acquire these parameter variations, the researchers trained a model from scratch, and then — in the last epoch of training — froze the non-subset weights and continued to train the subset of weights that the diffusion model will be able to generate. Checkpoints of the subset weights slowly changed during this final period of training, and these are what the researchers used to train the LDM and autoencoder.

Training then proceeded as normal: First, the autoencoder was trained to encode latent representations of the input parameters, and the decoder had to decode them to recreate the input parameters by minimizing the mean-squared reconstruction error. Then, the LDM was trained to remove noise from noisy latent representations of the input parameters from the dataset. This all means that this trained autoencoder and LDM can only diffuse parameters for a single model, not several different models.

You might be wondering, “Are the different variations of parameter subsets really diverse enough to train p-diff to generate distinct representations? Essentially, is p-diff just memorizing the input data?” One way to check is to compare how well models with diffused parameters perform on tests compared to regularly trained models. Unsurprisingly, models with diffused parameters perform better when there are more parameter variants in the training set. But, given enough training data, models with diffused parameters perform about on par with their regularly trained counterparts, not reliably better.

That doesn’t really answer our question though. Maybe p-diff has just learned to copy the parameters of the model that performs best in its training data. The authors wondered this too, so they devised a similarity metric to compare two different models. The method compares the agreement between the two models’ predictions: It’s 100% when they precisely agree on the classification for examples in a test dataset, and 0% when they disagree on every example (and in between for partial agreement).
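As a quick sketch (my own illustration; the paper may define the metric slightly differently), this agreement-based similarity is just the fraction of test examples on which two models' predicted labels match:

```python
def prediction_similarity(preds_a, preds_b):
    """Fraction of test examples on which two models' predicted
    labels agree: 1.0 for identical predictions, 0.0 for total
    disagreement, and in between for partial agreement."""
    if len(preds_a) != len(preds_b):
        raise ValueError("prediction lists must be the same length")
    agreements = sum(a == b for a, b in zip(preds_a, preds_b))
    return agreements / len(preds_a)

# e.g., two models agreeing on 3 of 4 test examples:
# prediction_similarity([0, 1, 2, 1], [0, 1, 0, 1]) -> 0.75
```

A low score between two p-diff models would suggest genuinely distinct behavior rather than memorized copies of one good checkpoint.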

The figure below compares original models with p-diff models. In the bar chart, the mean similarity among original models (yellow) is much higher than the mean similarity among p-diff models (orange); yet the similarity among pairs of original-to-p-diff models is a little bit higher than p-diff models alone (pink). This indicates that p-diff models are more distinct from each other than from original models, but not by much.

The upper scatter plot compares the accuracy of individual models from the bar chart (y-axis) with their similarity to a single baseline model (x-axis). Original models (blue) all achieve similar accuracy and are very similar to each other, while p-diff models vary in performance, sometimes doing much better or worse. Finally, the lower scatter plot shows a t-SNE plot of the models’ parameters. Since the p-diff points and original points form their own clusters, the takeaway is that there’s something distinct about each set of parameters (since t-SNE can distinguish between them), but it’s not entirely clear *what*, as t-SNE is a bit of a black-box.

Personally, I don’t see methods like p-diff being used for practical applications any time soon — NNs are already black-boxes, and applying a second NN that changes the original’s parameters in a non-obvious way seems a bit dubious. But, as a research exercise, p-diff is absolutely fascinating! I’ve always imagined a NN’s parameters and its inputs as being tightly yet delicately coupled, such that, aside from further training, changes to some or all of the parameters would throw the whole NN out of whack. But with p-diff, that isn’t the case at all. I’d love to see future research delve deeper into the differences between p-diff models and vanilla ones — that would help me understand what p-diff learns from the parameter subsets in its training data.

[Links to papers at the bottom]

According to Google’s own technical report, their new Gemini model, Gemini 1.5 Pro, is an impressive beast. It boasts multimodality, super-efficient inference, and class-leading language modeling abilities. But its most impressive feature (to me) is its absolutely bonkers context length of up to 10M tokens! That’s ten times the entirety of *War and Peace*, a 1,440-page, 587,287-word book. Unfortunately, the report doesn’t disclose exactly how the Gemini team achieved this impressive feat. No doubt they’re concealing several cleverly engineered technical solutions, but — and this is entirely speculation on my part — perhaps some of those solutions build on a series of techniques developed by Hao Liu, a PhD student at UC Berkeley.

Last August, Liu published a new algorithm (paper [1]) for computing the attention and feedforward steps in Transformers using *blocks*. A single block is just a subset of either query vectors or key-value vector pairs. Instead of computing the interactions between all query, key, and value vectors at once and using the result in the subsequent feedforward projection, the data is chunked into two groups of blocks: one group of blocks for the queries and another for the keys and values. Then, the outputs of the attention and feedforward steps are computed on a block-by-block basis, as shown here:

Liu calls this method a Blockwise Parallel Transformer (BPT), since the blockwise computations can all be computed in parallel. But the main benefit of BPT isn’t faster computation via parallelism — it’s that the memory consumption of each blockwise computation is significantly smaller than a non-BPT equivalent. Normally, if an LLM accepts an input sequence of length *s*, then the result of multiplying all query vectors in matrix Q by all key vectors in matrix K will give us a large s×s matrix. Traditionally, that entire matrix is stored in memory at once because each row must be normalized via softmax so that we can find the weights used by the attention mechanism. The next image shows one aspect of BPT: how it uses a smaller query-key product to reduce memory needs.
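To make the idea concrete, here’s a toy, pure-Python sketch (my own illustration, not the paper’s code) of the core trick for a single query: the key-value pairs are processed one block at a time while running softmax statistics are maintained, so the full row of the *s*×*s* score matrix never has to exist in memory, yet the result matches the all-at-once computation exactly:

```python
import math

def attention_row(q, keys, values):
    """Softmax-attention output for one query, computed all at once:
    the entire score row (one row of the s x s matrix) is held in memory."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def attention_row_blockwise(q, keys, values, block=2):
    """Same result, but only one key-value block's scores are ever
    alive at a time, thanks to running (online) softmax statistics."""
    dim = len(values[0])
    out = [0.0] * dim   # running, unnormalized output
    m = -math.inf       # running max score (for numerical stability)
    denom = 0.0         # running softmax normalizer
    for start in range(0, len(keys), block):
        kb = keys[start:start + block]
        vb = values[start:start + block]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in kb]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)       # rescale old accumulators
        out = [o * scale for o in out]
        denom *= scale
        for s, v in zip(scores, vb):
            w = math.exp(s - m_new)
            denom += w
            out = [o + w * vd for o, vd in zip(out, v)]
        m = m_new
    return [o / denom for o in out]
```

The same rescaling trick underlies Flash Attention too; BPT’s contribution is applying blockwise computation to the feedforward step as well, and organizing everything around query blocks.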

With that one idea, blockwise computation already cuts memory usage down from O(*s*²). But BPT does more: It also computes, in blocks, the output of the feedforward network (FFN) that follows the attention module. Normally the hidden layer of the FFN, which is larger than its input, uses 4x as much memory as its output. Holding onto that data would be the next memory bottleneck (after the *s*×*s* Q-K matrix product). BPT reduces memory usage here, too, by only requiring us to hold a single (query-based) block in memory at a time. In other words, the FFN is only ever computed on a small subsequence at a time, requiring less memory. This makes the memory bottleneck smaller still, which is how BPT can handle sequences up to 4x longer than even *Flash Attention*, another method of reducing attention’s memory needs.
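The FFN half of the idea is even simpler, because the FFN is applied position-by-position: processing one block of positions at a time gives bit-identical output while bounding the size of the hidden activations that are alive at any moment. A minimal sketch (my own, with made-up tiny weights):

```python
def relu(v):
    return [max(0.0, x) for x in v]

def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def ffn_full(X, W1, W2):
    """Two-layer FFN over the whole sequence X at once: conceptually,
    all s x hidden activations coexist in memory."""
    return [matvec(W2, relu(matvec(W1, x))) for x in X]

def ffn_blockwise(X, W1, W2, block=2):
    """Identical output, but only one block's hidden activations
    (block x hidden, instead of s x hidden) are alive at a time."""
    out = []
    for start in range(0, len(X), block):
        out.extend(ffn_full(X[start:start + block], W1, W2))
    return out
```

Because each position’s FFN output depends only on that position, the block boundaries have no effect on the result at all.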

Despite how efficiently BPT uses memory, its memory requirement still scales linearly with the input sequence length. So even though it’s quite efficient, it would still have an impressive but limited context length. BPT alone couldn’t possibly account for Gemini 1.5 Pro’s gargantuan context length.

Just three months later, Liu followed up the BPT paper with another BPT-based paper. By extending the BPT approach with something called *Ring Attention* (paper [2]), the method’s memory requirements scale linearly with the size of the query, key, and value blocks, rather than the length of the input sequence. The idea behind Ring Attention is that the BPT method’s parallel block computations can be distributed among several devices (think GPUs or TPUs), where each device computes one element of the outer loop from the figure above.

The main problem with distributing BPT across devices this way is that, even though each device needs only its own query block, every device still needs access to *all* the key-value blocks; holding every key-value block on a single device would defeat the memory savings. To overcome this challenge, the devices are arranged into a ring topology: each device can share data with a single *previous* device and a single *following* device, connected in a ring.

BPT’s fused attention-and-feedforward computation then proceeds as follows: First, each device holds the data for one of the query blocks in the outer loop, while key-value blocks are rotated between devices in a ring. This is really efficient since, while a device is performing its blockwise attention, it can simultaneously send the key-value block it’s currently holding to the following device and receive a new key-value block from the previous device.
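A toy sketch of the rotation schedule (my own, and heavily simplified — real implementations overlap this communication with computation): device *i* starts with key-value block *i*, and at every step each device passes its current block to the following device and receives one from the previous device:

```python
def ring_schedule(n_devices):
    """held[step][dev] = index of the key-value block that device
    `dev` holds at `step`. Device dev starts with block dev; each
    step, blocks shift one position around the ring."""
    return [[(dev - step) % n_devices for dev in range(n_devices)]
            for step in range(n_devices)]

# With 4 devices: after 4 steps, every device has seen every
# key-value block exactly once, while only ever holding one block.
```

Each device accumulates its query block’s attention output across these steps using the same running-softmax trick as in the blockwise computation, so holding one key-value block at a time is enough.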

Ring Attention is depicted in the figure below. The solid boxes represent computations (solid outlines) and data (dashed outlines) that are stored on one of the hosts, while the faded boxes represent data for other devices. As you can see, the device depicted is responsible for all the attention and feedforward computations for a single query block, while the keys and values traverse the devices in the ring.

The reason why Ring Attention computes the same result as a vanilla Transformer is actually really subtle. It’s not obvious that blockwise self-attention and feedforward computations are equivalent to their vanilla counterparts — but they are! That’s because the final result of the blockwise computations is invariant to the order in which they were computed. It’s similar to how 4+5 gives the same result as 5+4: It doesn’t matter which “block” comes first.
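You can see this order-invariance in miniature with a softmax-weighted average, which is what attention computes per query. In this toy demo (my own, with single-element “blocks” for simplicity), every possible processing order of the blocks yields the same answer, because the accumulated numerator and denominator are just sums:

```python
import itertools
import math

scores = [0.3, -1.2, 2.0, 0.5]   # one query's attention scores
values = [1.0, 2.0, 3.0, 4.0]    # scalar values, for simplicity

def softmax_average(order):
    """Accumulate the softmax-weighted average element by element,
    in the given processing order."""
    num = den = 0.0
    for i in order:
        w = math.exp(scores[i])
        num += w * values[i]
        den += w
    return num / den

# All 24 orderings agree (up to floating-point rounding), which is
# why blockwise attention matches the vanilla computation.
results = [softmax_average(p) for p in itertools.permutations(range(4))]
```

Sums are commutative and associative, so it doesn’t matter which device’s block arrives first — exactly the 4+5 = 5+4 point above.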

Ring Attention takes BPT’s memory efficiency to a whole new level, achieving context lengths 2–3 orders of magnitude longer than the competition. Intuitively, you could say that Ring Attention can work with a context window proportional to the number of devices in the ring. The figure below shows the maximum context size for various Transformer architectures on TPUv4-1024 hardware (32GB of memory per chip). The green bar represents the maximum context length for a vanilla Transformer, the yellow for a Transformer with Flash Attention, lavender for BPT, and the red for BPT with Ring Attention.

As if that wasn’t enough, Liu and coauthor Wilson Yan wrote a new paper that extends BPT with Ring Attention even further (paper [3]). In their paper, they present some ways to train models that understand text *and* video sequences. They also present solutions to various technical challenges that arise when doing such training, such as what to do when text and video sequences have different lengths or how to apply different weights to language and vision inputs. This article is getting quite long, but I encourage you to give their paper a read if you want to learn more.

Gemini 1.5 Pro is an undeniably impressive feat of engineering, requiring vast amounts of resources. But it’s important to remember the human element behind advancements like Gemini. I don’t know of an official link between Hao Liu and Gemini — maybe they're not really linked — but we can still appreciate the significance of Liu’s contributions in achieving a new era in LLM context lengths. Individual researchers like Liu can make significant contributions, even seemingly single-handedly developing groundbreaking algorithms that propel entire fields forward. Even the most opaque and large-scale technological achievements are ultimately driven by personal ingenuity.

—

Hi, it’s still me, Adrian. I hope you enjoyed today’s summary as much as I enjoyed writing it. Actually, I lie — mostly I enjoyed reading Hao Liu’s papers. As someone who just submitted their PhD thesis, let me say that what Liu has achieved over the span of their PhD so far is nothing short of incredible. If you share this view, you should definitely give these papers a thorough read:

[1] Blockwise Parallel Transformer for Large Context Models (This is the original BPT paper).

[2] Ring Attention with Blockwise Transformers for Near-Infinite Context (This is the one that extends BPT with Ring Attention)

[3] World Model on Million-Length Video and Language with Blockwise Ring Attention (This is the one that does the crazy multimodal text-video training).

If you’re a regular ChatGPT user, you’ve probably noticed that, capable as ChatGPT may be, it struggles with some tasks. For example, if you asked it to compose a rap song for you, it might spit out some impressive rhymes in a typical verse-chorus song structure, but not an actual *song*. Wouldn’t it be cool if LLMs like ChatGPT could generate songs in a more complete sense, with information about tempo, chords, melodies, structure, and motifs? A new model called ChatMusician can do just that, as the open-source research community Multimodal Art Projection (M-A-P) announced in this week’s paper.

When I think of audio-based neural networks (NN), I immediately think of bespoke NN architectures that are tailored to address the specific challenges of audio signals. But ChatMusician doesn’t use a bespoke architecture; it’s just a fine-tuned version of Llama 2. Despite its standard design, ChatMusician can learn about music by leveraging a music-specific format called ABC notation. This format is a text-based, shorthand way to write a piece of music. The figure below (from Wikipedia) shows ABC notation and the corresponding staff notation for the same song.
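For a flavor of the format, here’s a tiny illustrative fragment of ABC notation (my own toy example, not one from the paper): a header giving the tune’s index, title, meter, default note length, and key, followed by a single repeated bar-pair of melody.

```
X:1
T:Toy Example
M:4/4
L:1/8
K:G
|: G2 A2 B2 c2 | d4 g4 :|
```

Because it’s all plain text, an LLM can read and write ABC with ordinary next-token prediction — no audio-specific architecture required.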

To teach ChatMusician about music, the M-A-P researchers curated their own dataset to continue training Llama 2. They included the following categories of text in their training corpus:

General music corpora, which are text documents containing music terminology.

Instruction and chat data to help Llama 2 learn how to chat and answer questions.

Music knowledge and music summaries, which are summaries of the metadata from 2 million songs from YouTube, and music knowledge QA pairs that are generated from these summaries using LLMs.

Math and code data, which the researchers think will aid ChatMusician with symbolic reasoning of music scores.

In addition, the researchers also curated a music theory dataset called MusicTheoryBench. To do that, they hired a college-level music teacher to create 372 multiple-choice questions (each with 4 choices) about music knowledge and reasoning. (If a question included music notation, the authors converted it into ABC notation.) The figure below shows examples of these knowledge (top) and reasoning (bottom) questions.

The researchers used these questions to evaluate several ChatMusician variants’ musical understanding against GPT3.5, GPT4, and Llama 2. I think there are two main takeaways from their results, which you can see below: First, GPT4 knows a lot about music, but ChatMusician’s music-specific training makes it more effective at musical reasoning than its base model, Llama 2. Second, music *reasoning* is hard: All the models score about as well as random guessing, though the ChatMusician models perform marginally better than that.

Beyond quantitative results, the researchers also demonstrate how ChatMusician can generate music. The figure below shows the ABC notation (top) and the corresponding staff notation (bottom) of a song created by ChatMusician. We can see some of the key features of ABC notation, such as the “|:” and “:|” repetition symbols (blue), as well as repetition and motifs in the colored sections. (Red blocks are one motif, yellow are another, and green blocks represent variation on the preceding motif. I think these colors represent analysis by a human of the model-created song.)

The researchers conducted several other experiments, such as analyzing the compression ratio of ABC notation relative to other music data formats (e.g., MIDI and WAV), exploring few-shot learning with GPT4 on music knowledge and reasoning, and a study where participants heard model-generated music head-to-head against an actual song. In that last experiment, people preferred ChatMusician’s music over the actual song 76% of the time, versus GPT4’s 44%.

Despite the impressive results of their research, I think this represents just the tip of the iceberg for musical LLMs. The approach they used to craft MusicTheoryBench could be extended to create more training data to make an even better musical LLM. These models could open up a new kind of laboratory for musical experimentation. Perhaps one day musicians, chatting with future music-generation models, will have powerful new tools at their disposal.

Years ago when I graduated from university, my first job was to design parts of logic chips that implement data-compression algorithms. This meant that I needed to understand the compression algorithms themselves *and* needed to modify them so that I could minimize the number of logic gates they used. (If you’re familiar with FPGAs, I was trying to minimize lookup table usage.)

One of these compression algorithms was based on a small neural network (NN). Using what was cutting-edge research at the time, I figured out how to replace the NN’s floating-point arithmetic with fixed-point, integer-based arithmetic. This was a boon for minimizing logic usage, since integer-based arithmetic is far simpler to implement than floating-point arithmetic. While I was quite impressed with my achievements at the time, these kinds of NN optimizations have come a *long* way since then. Today’s paper presents one such example: BitNet b1.58. (I’ll explain where this magic number 1.58 comes from below.)

BitNet b1.58 is a peculiar NN proposed by GeneralAI, a Microsoft-backed research lab based in Beijing. It builds upon an LLM called BitNet, previous work from the same lab that aims to make NN computations more efficient by reducing the number of bits used in each operation. The bit-reduction techniques fall on a spectrum:

On one end, vanilla NNs typically use 32-bit floating-point (or fp32) values, which is the standard data format for doing math with real numbers.

With some loss in precision, fp16 can be used instead of fp32, which yields some computation and memory benefits.

Fixed-point based methods typically use 8-bit integers (int8). This approach can often be substantially more efficient than fp16 (especially for hardware optimized for integer arithmetic), but with substantial added complexity and loss in precision.

Finally, BitNet takes bit-reduction to the extreme. It uses 1-bit weights — that is, the weights are either +1 or –1.

Before I describe BitNet b1.58, it’s worth understanding BitNet a bit more. A NN with 1-bit weights requires completely rethinking how a NN functions, down to the individual math operations it uses. For starters, BitNet only uses 1-bit weights for the *fully-connected* layers in its transformer architecture, but these layers account for the vast majority of the LLM’s computational requirements. BitNet’s replacement for a fully-connected component is called a *BitLinear* module. BitLinear’s activations and intermediate results are stored at a higher, 8-bit precision. The image below shows the layer, with *β* and *γ* being additional values that BitLinear uses to dequantize the accumulated result (which I’ll discuss next) into the 8-bit range.

NNs are mostly comprised of matrix-multiplication operations, which can be implemented using several multiply-accumulate steps: Multiply a weight and an input value, accumulate the result, and repeat. The genius behind using weights that are either +1 or –1 is that the multiplication step can be simplified to an addition (or subtraction) instead: Given the value of the weight, add (or subtract) the input from the accumulator.

This is where BitNet b1.58 comes in. The “b1.58” part of its name is derived from how many values the weights can take: Instead of just +1 and –1, BitNet b1.58 adds a third value, 0. So, instead of using 1-bit weights like BitNet, it uses log2(3) ≈ 1.58-bit weights. (The math works like this: If we could have ~1.58 bits, then those bits could store 2^1.58 ≈ 3 different values.) BitNet’s accumulate-only property is also retained: Instead of only adding (or subtracting), there’s now a third option: “do nothing with this input.” One downside of the BitNet approach is that the LLM must be trained from scratch: An existing model can’t simply be converted to a 1-bit or 1.58-bit NN, whereas fp32 models can be converted to fp16 or even int8 without full retraining.
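Here’s a minimal sketch (my own illustration, not the paper’s kernel) of what a multiplication-free matrix-vector product looks like with ternary weights: every “multiply” collapses to an add, a subtract, or a skip.

```python
def ternary_matvec(W, x):
    """Matrix-vector product where every weight is -1, 0, or +1:
    no multiplications needed, only adds, subtracts, and skips."""
    out = []
    for row in W:
        acc = 0.0
        for w, xi in zip(row, x):
            if w == 1:
                acc += xi      # weight +1: add the input
            elif w == -1:
                acc -= xi      # weight -1: subtract the input
            # weight 0: skip this input entirely
        out.append(acc)
    return out

# ternary_matvec([[1, -1, 0], [0, 1, 1]], [2.0, 3.0, 4.0]) -> [-1.0, 7.0]
```

In hardware terms this is the whole appeal: adders are far cheaper than multipliers, and the 0 weights prune work away entirely.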

By this point, you might be asking yourself, “If 1-bit weights were the whole point of BitNet, what does adding a third weight value cost — and gain — in computational efficiency?” To answer this question, I would really love to show you some experimental results comparing BitNet to BitNet b1.58, but unfortunately the paper doesn’t include such a comparison. Still, I think we can intuit the following:

The size of the weights in memory will at least double, since two bits are required to store a ternary value without compression (i.e., you can’t actually address 1.58 bits).

The additional algorithmic complexity induced by the 0 weight should have negligible effect on runtime compared to the original BitNet.

The language-modeling accuracy should improve compared to BitNet, since the weights can be more precise.

The results shown in the paper compare BitNet b1.58 to Llama (the original, not Llama 2). Across seven different tasks and three different model sizes (700M, 1.3B, and 3B), BitNet b1.58 performs within ~1% of Llama. Also, the figure below shows the latency and memory required for Llama and BitNet b1.58, and highlights how much more efficient the new approach is on an equal-parameter basis. Interestingly, the authors also report efficiency improvements for 13B- and 70B-sized models, but they don’t discuss the comparative accuracy of these model variants.

The final, and perhaps most important part of optimized NNs is the hardware that they run on. The creators of BitNet b1.58 made their comparisons using GPUs, so I think we can assume that BitNet models run more efficiently on GPUs than on CPUs. I’d say this provides a fair comparison, since specialized hardware for running 1-bit NNs doesn’t exist yet. But the authors note that hardware specifically designed for 1-bit NNs would yield significantly better performance.

Ultimately, the way we run NNs is up to chip designers. Anecdotally, I’d say there has been a convergence on fp16 (and related 16-bit floating point formats) for NN computations, especially in GPUs and Google’s TPUs. But, regardless of whether these hypothetical efficiency gains are realized with specialized hardware, the BitNet b1.58 approach could be a very effective way to run LLMs on consumer devices that don’t have a lot of computing power, like phones and smartwatches.