You can prune entire layers of LLMs (and they still work)
[Paper: The Unreasonable Ineffectiveness of the Deeper Layers]
Summary by Adrian Wilkins-Caruana
In general, the more parameters an LLM has, the better it performs. The best-performing open-weight LLMs have hundreds of billions of parameters, but, oddly, not all of these parameters are useful. For a long time, AI researchers have known that some of a model’s parameters are much more important than others. Researchers even have a technique called “pruning,” which they use to remove some of these unhelpful parameters and shrink the model without affecting its predictive performance very much. But there often wasn’t much rhyme or reason as to which parameters were useless; at least, not until a recent research paper from Gromov et al. found that entire layers of parameters in an LLM’s network can be pruned!
I know what you’re thinking: How can an entire layer be removed from an LLM — surely that’ll have a huge impact on its accuracy, right? Actually, it’s not quite that straightforward. The researchers found that certain layers can be pruned with minimal impact on the LLM’s predictive performance. In fact, layers can continue to be pruned, up to a point, before the LLM’s performance falls off a cliff.
Also, the predictive performance that’s lost by pruning layers can be restored with a tiny amount of fine-tuning of the pruned LLM. The figure below shows the predictive performance (y-axes) of the Llama-2-70B model against the fraction of the model’s layers that have been dropped. The top two plots show the model’s accuracy on two question-answering benchmarks, while the bottom plot shows the validation loss. The dark-blue trace shows the pruned model’s performance, while the light-blue trace shows the performance with “healing,” which is the post-pruning fine-tuning.
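For context, “healing” is just a brief, parameter-efficient fine-tune of the pruned model. Here’s a minimal sketch of what such a step might look like using Hugging Face’s transformers and peft libraries; the checkpoint name, LoRA hyperparameters, and training setup are illustrative assumptions, not the paper’s exact recipe.

```python
# Minimal sketch of "healing" a pruned model with LoRA-style fine-tuning.
# The checkpoint name and hyperparameters are illustrative, not the paper's.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# ... prune a block of layers here (see the sketches later in this piece) ...

lora_config = LoraConfig(
    r=8,                                  # low-rank adapter dimension
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # attention projections in Llama-style models
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only a tiny fraction of weights gets updated

# A short fine-tune on a small general-text corpus would follow here,
# e.g. with transformers.Trainer or a simple training loop.
```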
Before continuing, it’s worth taking a moment to consider why a pruned LLM behaves this way. I often think of the parameters in an LLM working together in perfect harmony. Disrupting this harmony — say, by deleting an entire layer! — would be very damaging for the network since any errors introduced would cascade down subsequent layers. This intuition is a decent model, so long as a layer’s output is significantly different from its input.
But what if a layer’s output isn’t significantly different from its input? In that case, removing the layer shouldn’t affect the network very much, since it didn’t do much to begin with. In fact, in a typical transformer architecture we do expect a layer’s output to resemble its input, because each layer only adds a delta to the residual stream: the output of a transformer layer is its input plus an adjustment computed from that input. And, since each layer adjusts its input, we might also expect earlier layers to have more impact, since their changes compound through all of the later layers.
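To make the “each layer adds a delta” point concrete, here’s a schematic pre-norm transformer block in PyTorch. This is a generic sketch rather than Llama’s exact architecture (Llama uses RMSNorm, rotary position embeddings, and a gated MLP), but the residual structure is the same: the block’s output is its input plus whatever the attention and MLP sublayers contribute.

```python
import torch
from torch import nn

class PreNormTransformerBlock(nn.Module):
    """Schematic transformer block: the output is the input plus two residual deltas."""

    def __init__(self, d_model: int = 512, n_heads: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # delta from self-attention
        x = x + self.mlp(self.norm2(x))                    # delta from the MLP
        return x  # if both deltas are small, the output stays close to the input
```

If those deltas are small relative to the residual stream, deleting the block barely changes what the later layers see, which is exactly the situation the paper exploits.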
So, suppose that layers can be removed from an LLM without hurting its performance. That would indicate that the original LLM contains some layers that aren’t very useful (i.e., their output is typically quite similar to their input). That’s exactly what these researchers found! The figure below shows how much a given layer’s output differs from its input. They measured this using a metric called Shifted Rescaled Angular Distance, which is close to 1 when the change is large and close to 0 when the change is small. (In case you’re curious, this distance is a rescaled version of the angle between the representation vectors measured before and after the layers in question.) The y-axes in these plots indicate the number of consecutive layers that were pruned: the bottom row represents the full architecture, while higher rows represent heavily pruned versions of the model. I’ll describe in a moment which layers were removed.
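As a rough sketch of what such a measurement could look like in code (my reading of the description above, not necessarily the paper’s exact implementation): take the hidden state entering a block of layers and the hidden state leaving it, and compute the angle between them, rescaled to lie in [0, 1].

```python
import torch
import torch.nn.functional as F

def angular_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Rescaled angle between hidden states: ~0 if nearly parallel, ~1 if pointing opposite ways.

    x and y have shape (..., d_model), e.g. a token's representation before and
    after a block of layers; the caller typically averages the result over tokens.
    """
    cos = F.cosine_similarity(x, y, dim=-1).clamp(-1.0, 1.0)
    return torch.arccos(cos) / torch.pi
```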
Across all model sizes (and other non-Llama LLMs that aren’t pictured), there seems to be a consistent trend: the deeper layers in a network tend to contribute less than the shallower layers. This means that many layers can be pruned without harming the LLM much. So, based on these results, the researchers devised the following strategy to prune layers from an LLM (a code sketch follows the list):
Choose how many layers you want to prune, n.
Compute the similarity (angular distance) between the inputs of all pairs of layers that are exactly n layers apart.
Select the pair with the highest similarity (lowest angular distance) and prune the n layers that sit between those two measurement points.
Optionally, heal the network with some fine-tuning.
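Here’s a sketch of steps 2 and 3 under one set of assumptions: the per-layer inputs have already been collected on a small sample of text (for example via output_hidden_states=True in Hugging Face transformers), and angular_distance is the helper from the earlier sketch. This is my reconstruction of the procedure, not the authors’ code.

```python
import torch

@torch.no_grad()
def choose_block_to_prune(hidden_states: list[torch.Tensor], n: int) -> int:
    """Return the start index l* of the n consecutive layers whose removal should hurt least.

    hidden_states[l] is the input to layer l on some sample text, with shape
    (num_tokens, d_model); the final entry is the output of the last layer.
    """
    num_layers = len(hidden_states) - 1
    best_l, best_dist = 0, float("inf")
    for l in range(num_layers - n + 1):
        # Distance between what flows into layer l and what would flow into layer l + n:
        # if it's tiny, layers l ... l+n-1 barely changed the residual stream.
        dist = angular_distance(hidden_states[l], hidden_states[l + n]).mean().item()
        if dist < best_dist:
            best_l, best_dist = l, dist
    return best_l

# Dropping the chosen block from a Hugging Face Llama-style model might then look like:
# l_star = choose_block_to_prune(hidden_states, n)
# model.model.layers = torch.nn.ModuleList(
#     [layer for i, layer in enumerate(model.model.layers) if not (l_star <= i < l_star + n)]
# )
```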
This strategy is quite simple, but it requires a fair amount of work before any pruning happens: you have to compute the angular distances between layer inputs, which means loading and running the entire unpruned model. That might be prohibitive for some users, since a lack of resources to run the unpruned LLM may be exactly why they want to prune it. So, the researchers devised an even simpler pruning strategy: decide how many layers you want to prune (call this number n), remove the n layers immediately before the final layer (which the researchers noticed is always a useful layer), and then heal with fine-tuning. In the figure above you can see that removing the final layer is never a good idea: it has much lower similarity to the layers that precede it (it’s blue while the preceding layers are yellow).
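The simpler heuristic needs no forward passes at all. Here’s a sketch, again assuming a Hugging Face Llama-style model where the decoder blocks live in model.model.layers; the function name is my own.

```python
import torch

def prune_last_n_keep_final(model, n: int):
    """Drop the n layers immediately before the final layer, keeping the final layer itself."""
    layers = model.model.layers
    keep = [
        layer
        for i, layer in enumerate(layers)
        if i < len(layers) - n - 1 or i == len(layers) - 1
    ]
    model.model.layers = torch.nn.ModuleList(keep)
    model.config.num_hidden_layers = len(keep)  # keep the config consistent with the new depth
    return model
```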
The figure below compares the quality of these approaches. Each graph plots both the simple pruning strategy (in red) and the more complex similarity-based strategy (in blue). While the similarity-based approach tends to preserve more accuracy than the simple approach, the difference mostly vanishes when healing is used, as shown in the right column. So, if you plan to heal, then either approach is suitable.
By combining this pruning approach with the latest model-quantization techniques, Llama-2-70B, which spans 140 GB of memory and consumes 30 billion floating-point operations (FLOPs) per token, can run with significantly fewer resources: 17.5 GB of memory and 15 billion FLOPs per token. This makes it possible to run the model on consumer computers, not just big, beefy datacenter machines. But it also leaves me wondering why the models we train have layers that don’t contribute much to the result. Pruning is great, but wouldn’t it be better if we could train equivalent LLMs that didn’t have unnecessary layers in the first place?
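For what it’s worth, the memory numbers are consistent with a simple back-of-the-envelope calculation (my arithmetic, assuming 4-bit quantization and dropping roughly half the layers; the exact recipe may differ):

```python
params = 70e9                          # Llama-2-70B
fp16_memory_gb = params * 2 / 1e9      # 16-bit weights: ~140 GB
int4_memory_gb = params * 0.5 / 1e9    # 4-bit quantization: ~35 GB
pruned_int4_gb = int4_memory_gb * 0.5  # drop ~half the layers: ~17.5 GB
# FLOPs per token would also fall by about half when half the layers are removed.
print(fp16_memory_gb, int4_memory_gb, pruned_int4_gb)  # 140.0 35.0 17.5
```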