We Found An Neuron in GPT-2

Written in collaboration with Joseph Miller. See the discussion of this post over on LessWrong.

We started out with the question: How does GPT-2 know when to use the word an over a? The choice depends on whether the word that comes after starts with a vowel or not, but GPT-2 is only capable of predicting one word at a time.

We still don’t have a full answer, but we did find a single MLP neuron in GPT-2 Large that is crucial for predicting the token “ an”. And we also found that the weights of this neuron correspond with the embedding of the “ an” token, which led us to find other neurons that predict a specific token.

Discovering the neuron

Choosing the prompt

It was surprisingly hard to think of a prompt where GPT-2 would output the token “ an” (the leading space is part of the token) as the top prediction. In fact, we gave up on GPT-2 Small and switched to GPT-2 Large. As we’ll see later, even GPT-2 Large systematically under-predicts “ an” in favor of “ a”. This may be because smaller language models lean on the higher frequency of “ a” to make a best guess. The prompt we finally found that gave a high (64%) probability for “ an” was:

“I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked”

The first sentence was necessary to push the model towards an indefinite article — without it the model would make other predictions such as “[picked] up”.

Before we proceed, here’s a quick overview of the transformer architecture. Each attention block and MLP reads its input from the residual stream and adds its output back to the residual stream.

Logit Lens

Using a technique known as logit lens, we took the logits from the residual stream between each layer and plotted the difference between logit(‘ an’) and logit(‘ a’). We found a big spike after Layer 31’s MLP.
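As a rough illustration of the logit lens idea, here is a minimal sketch on made-up numbers rather than on GPT-2 itself. The residual vectors, the two unembedding directions, and the helper name `logit_diff_by_layer` are all assumptions for the example, and the final layer norm that a real logit lens would apply before unembedding is left out for simplicity.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def logit_diff_by_layer(residuals, unembed, tok_a, tok_b):
    """For each residual-stream snapshot, return logit(tok_a) - logit(tok_b)."""
    return [dot(r, unembed[tok_a]) - dot(r, unembed[tok_b]) for r in residuals]

# Toy 3-dimensional residual stream saved after each of four layers
# (hypothetical values, not taken from GPT-2).
residuals = [
    [0.0, 1.0, 0.0],
    [0.5, 1.0, 0.0],
    [0.5, 1.0, 0.5],
    [3.0, 1.0, 0.5],  # a late layer writes strongly along the " an" direction
]
# Hypothetical unembedding directions for the two tokens.
unembed = {" an": [1.0, 0.0, 0.0], " a": [0.0, 1.0, 0.0]}

diffs = logit_diff_by_layer(residuals, unembed, " an", " a")

# Index of the snapshot with the largest jump relative to the previous one.
spike = max(range(1, len(diffs)), key=lambda i: diffs[i] - diffs[i - 1])
print(diffs, spike)
```

In this toy, the jump shows up at the last snapshot, which plays the role that the output of Layer 31’s MLP plays in our plot.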

Activation Patching by the Layer

Activation patching is a technique introduced by Meng et al. (2022) to analyze the significance of a single layer in a transformer. First, we saved the activation of each layer when running the original prompt through the model. We call this the “clean activation”.

We then ran a corrupted prompt through the model: “I climbed up the pear tree and picked a pear. I climbed up the lemon tree and picked”. By replacing the word ‘apple’ with ‘lemon’, we induce the model to predict the token ‘ a’ instead of ‘ an’.

With the model predicting " a" over " an", we can replace a layer’s corrupted activation with its clean activation to see how much the model shifts towards the " an" token, which indicates that layer’s significance to predicting " an". We repeat this process over all the layers of the model.
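The patching loop can be sketched on a toy residual model. This is an illustrative stand-in, not the code we ran: the three scalar "layers", the inputs 2.0 (clean) and 0.5 (corrupted), and the use of the final residual value as the metric are all assumptions made to keep the example small.

```python
def run(layers, x, patch=None):
    """Run a toy residual model; `patch` maps layer index -> output to substitute."""
    patch = patch or {}
    resid, outputs = x, []
    for i, layer in enumerate(layers):
        out = patch[i] if i in patch else layer(resid)
        outputs.append(out)
        resid = resid + out  # each layer adds its output to the residual stream
    return resid, outputs

# Hypothetical layers acting on a scalar residual stream.
layers = [
    lambda r: 0.5 * r,
    lambda r: 1.0 if r > 2.0 else 0.0,  # only fires on the clean input
    lambda r: 0.1 * r,
]

clean_metric, clean_outputs = run(layers, 2.0)   # stands in for the clean prompt
corrupt_metric, _ = run(layers, 0.5)             # stands in for the corrupted prompt

# Patch one layer at a time and measure the fraction of the gap recovered.
recovered = []
for i in range(len(layers)):
    patched_metric, _ = run(layers, 0.5, patch={i: clean_outputs[i]})
    recovered.append((patched_metric - corrupt_metric) / (clean_metric - corrupt_metric))

most_important = max(range(len(layers)), key=lambda i: recovered[i])
print(recovered, most_important)
```

In the real experiment the metric is the logit difference between " an" and " a", and the patched activations are vectors at a token position rather than scalars.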

We’re mostly going to ignore attention for the rest of this post, but these results indicate that Layer 26 is where " picked" starts thinking a lot about " apple", which is obviously required to predict " an".

The two MLP layers that stand out are Layer 0 and Layer 31. We already know that Layer 0’s MLP is generally important for GPT-2 to function [1] (although we’re not sure why attention in Layer 0 is important). The effect of Layer 31 is more interesting. Our results suggest that Layer 31’s MLP plays a significant role in predicting the ‘ an’ token. (See this comment if you’re confused how this result fits with the logit lens above.)

Finding 1: We can discover predictive neurons by activation patching individual neurons

Activation patching has been used to investigate transformers by the layer, but can we push this technique further and apply it to individual neurons? Since each MLP in a transformer only has one hidden layer, each neuron’s activation does not affect any other neuron in the MLP. So we should be able to patch individual neurons, because they are independent from each other in the same sense that the attention heads in a single layer are independent from each other.

We run neuron-wise activation patching for Layer 31’s MLP in a similar fashion to the layer-wise patching above. We reintroduce the clean activation of each neuron in the MLP when running the corrupted prompt through the model, and look at how much restoring each neuron contributes to the logit difference between " a" and " an".
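Here is a minimal sketch of neuron-wise patching on a toy one-hidden-layer MLP. The weights, the two inputs, and the metric (first output coordinate minus second, standing in for the logit difference) are hypothetical. Unlike GPT-2, nothing sits downstream of this toy MLP, so it only shows the mechanics of swapping in one neuron's clean activation.

```python
def relu(z):
    return max(0.0, z)

def hidden_acts(x, w_in):
    """One activation per neuron; each depends only on the MLP input."""
    return [relu(sum(w * xi for w, xi in zip(row, x))) for row in w_in]

def metric(hidden, w_out):
    """Output coordinate 0 minus coordinate 1 (stand-in for a logit difference)."""
    out = [sum(h * w_out[j][d] for j, h in enumerate(hidden)) for d in range(2)]
    return out[0] - out[1]

# Hypothetical weights: w_in[j] reads the input, w_out[j] is what neuron j writes.
w_in = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
w_out = [[0.2, 0.0], [0.0, 0.3], [2.0, -1.0]]

clean_hidden = hidden_acts([1.0, 0.0], w_in)    # stands in for the clean prompt
corrupt_hidden = hidden_acts([0.0, 1.0], w_in)  # stands in for the corrupted prompt
clean_m = metric(clean_hidden, w_out)
corrupt_m = metric(corrupt_hidden, w_out)

recovered = []
for j in range(len(w_in)):
    patched = list(corrupt_hidden)
    patched[j] = clean_hidden[j]  # restore only neuron j's clean activation
    recovered.append((metric(patched, w_out) - corrupt_m) / (clean_m - corrupt_m))

best_neuron = max(range(len(recovered)), key=lambda j: recovered[j])
print(recovered, best_neuron)
```

Because each hidden activation is computed from the MLP input alone, overwriting one entry of the hidden vector leaves the other neurons untouched, which is the independence property the argument above relies on.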

We see that patching Neuron 892 recovers 50% of the clean prompt’s logit difference, while patching the whole layer actually does slightly worse at 49%.

Finding 2: The activation of the “an-neuron” correlates with the “ an” token being predicted.

Neuroscope Layer 31 Neuron 892 Maximum Activating Examples

Neuroscope is an online tool that shows the top activating examples in a large dataset for each neuron in GPT-2. When we look at Layer 31 Neuron 892, we see that the neuron maximally activates on tokens where the subsequent token is " an".

But Neuroscope only shows us the top 20 most activating examples. Would there be a trend for a wider range of activations?

Testing the neuron on a larger dataset

To check for a trend, we ran the pile-10k dataset through the model. This is a diverse set of about 10 million tokens taken from The Pile, split into prompts of 1,024 tokens. We plotted the proportion of " an" predictions across the range of neuron activations:

We see that the proportion of " an" predictions increases as the neuron’s activation increases, to the point where " an" is always the top prediction. The trend is somewhat noisy, which suggests that there might be other mechanisms in the model that also contribute towards the " an" prediction. Or maybe when the " an" logit increases, other logits increase at the same time.
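The binning behind a plot like this can be sketched as follows. The activations and top-prediction flags below are invented for illustration, and the helper name `proportion_by_bin` and the bin width are assumptions, not the settings used for the real plot.

```python
import math
from collections import defaultdict

def proportion_by_bin(activations, is_an_top, bin_width=1.0):
    """Group positions by neuron activation; return the share predicting " an" per bin."""
    counts = defaultdict(lambda: [0, 0])  # bin start -> [an_count, total]
    for act, hit in zip(activations, is_an_top):
        start = math.floor(act / bin_width) * bin_width
        counts[start][0] += int(hit)
        counts[start][1] += 1
    return {start: an / total for start, (an, total) in sorted(counts.items())}

# Hypothetical per-token data: the neuron's activation at each position, and
# whether " an" was the model's top prediction at that position.
activations = [0.2, 0.7, 1.1, 1.8, 2.1, 2.2, 2.5, 2.9, 3.5, 3.6]
is_an_top = [False, False, False, True, True, True, True, False, True, True]

props = proportion_by_bin(activations, is_an_top)
print(props)
```

On real data, the activations would come from Layer 31 Neuron 892 at every token position and the flags from the model's argmax prediction at that position.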

Note that the model only predicted “ an” 1,500 times, even though it actually occurred 12,000 times in the dataset. No wonder it was so hard to find a good prompt!

The neuron’s output weights have a high dot product with the “ an” token

How does the neuron influence the model’s output? Well, the neuron’s output weights have a high dot product with the embedding for the token “ an”. We call this the congruence of the neuron with the token. Compared to other random tokens like " any" and " had", the neuron’s congruence with “ an” is very high:

Congruence Illustration
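Congruence as defined above is just a dot product, which can be sketched in a few lines. The three-dimensional token vectors and the neuron's output weights below are made-up values chosen for illustration, not weights from GPT-2.

```python
def congruence(neuron_out, token_vec):
    """Dot product of a neuron's output weights with a token's embedding vector."""
    return sum(w * e for w, e in zip(neuron_out, token_vec))

# Hypothetical token vectors and neuron output weights.
embeddings = {
    " an": [0.9, 0.1, 0.0],
    " any": [0.2, 0.5, 0.1],
    " had": [-0.1, 0.3, 0.6],
}
neuron_out = [2.0, 0.0, -0.5]

scores = {tok: congruence(neuron_out, vec) for tok, vec in embeddings.items()}

# Rank tokens from most to least congruent with the neuron.
ranked = sorted(scores, key=scores.get, reverse=True)
print(ranked)
```

Running the same computation over the whole vocabulary and sorting the scores is one way to look for tokens that stand out for a given neuron.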

In fact, when we calculate the neuron’s congruence with all of the tokens, there are a few clear outliers: