If you've been online lately, you'll know that generative pre-trained transformer (GPT1) models have taken the internet by storm. If you haven't given it a shot yet, ask ChatGPT for advice on a topic that you're an expert in — if you're reading this article, that might be related to programming or machine learning. The thing that astounds me most about GPT-4 is that it seems to have learned something deeper about the world than the average of the data it was trained on. I often find the answers from GPT to be not only more accurate than StackOverflow or Reddit, but also kinder and more targeted to my level of understanding2.
In this post, I'll dive into the details of why transformers work so well and why the most powerful machine learning models in the world have been built with them over the past few years. This article is written for programmers and engineers who have gone through Neural Networks: Zero to Hero or have an equivalent understanding. I wrote it as I was diving into these topics, so I'm not an expert either. If I've missed important details or got something wrong, please let me know in the comments.
The beginning
Until 2017, state-of-the-art machine learning models for text generation used recurrent neural networks (RNNs). RNNs worked fairly well for text generation and translation tasks, but they were running into limitations.
RNNs fundamentally work by processing a sequence step by step, encoding the entire context into a single fixed-length vector. This meant that as sequences got longer and models became deeper, they became harder and harder to train. They suffered from vanishing and exploding gradients, and were finicky to tune. They couldn't track dependencies over long sequences, because only so much information about earlier tokens fits into the fixed-length context vector. Additionally, they couldn't be trained with nearly as much parallelism as transformers, because the model couldn't begin predicting token n + 1 until it had finished processing token n. Each sequence had to be processed token by painstaking token.
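To make that concrete, here's a minimal sketch (not taken from any of the papers discussed here) of a vanilla RNN consuming a sequence in PyTorch. The dimensions are made up for illustration; the important part is the loop, which can't be parallelized because each step needs the previous hidden state:

```python
import torch
import torch.nn as nn

# Toy dimensions, chosen only for illustration.
vocab_size, embed_dim, hidden_size, seq_len = 1000, 64, 128, 32

embed = nn.Embedding(vocab_size, embed_dim)
cell = nn.RNNCell(embed_dim, hidden_size)

tokens = torch.randint(0, vocab_size, (seq_len,))
h = torch.zeros(1, hidden_size)  # the single fixed-length context vector

# Each step depends on the previous hidden state, so this loop
# cannot be parallelized across the sequence.
for t in range(seq_len):
    x_t = embed(tokens[t]).unsqueeze(0)  # (1, embed_dim)
    h = cell(x_t, h)                     # all context squeezed into h
```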
Transformers
In 2017, transformer models were introduced in Attention is All You Need. Instead of compressing everything into a single context vector, they relied on the attention mechanism. Attention avoids encoding all of the information from the sequence into a single vector by letting the model look at every token in the sequence simultaneously before trying to predict the next token. The cost of attention is that for every new token you generate, you need to look at the output of every earlier token, so computing a sequence of N tokens requires on the order of N² operations instead of N.
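Here's a hedged sketch of that computation in PyTorch, with a single head and no masking or learned projections, just to show where the N² comes from; the shapes are arbitrary choices for illustration:

```python
import math
import torch
import torch.nn.functional as F

seq_len, d_model = 16, 64          # toy sizes for illustration
q = torch.randn(seq_len, d_model)  # queries, one per token
k = torch.randn(seq_len, d_model)  # keys, one per token
v = torch.randn(seq_len, d_model)  # values, one per token

# Every token's query is compared against every token's key,
# which is where the N^2 term comes from.
scores = q @ k.T / math.sqrt(d_model)   # (seq_len, seq_len)
weights = F.softmax(scores, dim=-1)     # attention weights over the sequence
out = weights @ v                       # (seq_len, d_model)
```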
The original transformers were designed for translation, so they had two stacks of layers: encoders and decoders. Encoder layers used self-attention over the tokens of the input sequence, while decoder layers attended both to the already-generated part of the output sequence and, through encoder-decoder attention, to the encoder's output.
This is where the key-query distinction comes from in attention. In the attention formulation of decoder-only transformers, one tensor is named K for keys and another Q for queries, but the two are structurally identical; they may as well be called A and B. The naming is a historical artifact from older transformer models, especially ones focused on converting one sequence into another, like translation models. Those earlier models had encoder-decoder attention, where keys and values come from the encoder but queries come from the decoder. We still call the tensors K and Q as a historical homage.
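To see why the naming is arbitrary in the decoder-only case, note that keys and queries are just two different linear projections of the same hidden states; a small illustrative sketch, with dimensions invented for this example:

```python
import torch
import torch.nn as nn

d_model = 64
x = torch.randn(16, d_model)      # hidden states for 16 tokens

# In self-attention, keys and queries are both projections of x.
# Swap the two Linear layers and nothing about the structure changes.
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)

q, k = W_q(x), W_k(x)             # two structurally identical (16, d_model) tensors
```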
In 2018, Liu et al. published Generating Wikipedia by Summarizing Long Sequences with the revelation that transformers in their decoder-only configuration could be a powerful tool for generating text. This showed that transformers could be used without encoder layers, which means a model can generate text based purely on a sequence prefix, or even from an empty string.
Outrageously large neural networks
One of the most important factors in the capacity of a model to learn is the number of trainable parameters it can use. One of the most amazing discoveries of the latest era of deep learning is that even with a fixed amount of training data, increasing the size of a model will increase its ability to learn. This is counterintuitive! Naively, you would expect models to just memorize the training data if they were large enough and the training data was limited. Instead, we see that larger models actually memorize less, and are quicker to generalize from memorizing the training data to learning more powerful abstractions that underlie the data.3
Transformer models were some of the first models that made sense to train and run at this scale. Earlier architectures were hard to scale to billions of parameters because they took too long to train, and because running inference with them at scale wasn't economical.
With auto-regressive decoder-only transformers4, models can be trained on every single element of a sequence in parallel, thanks to the masking mechanism. An attention mask ensures that when predicting token N, only tokens before N are inspected, so the transformer can't cheat and peek at elements further ahead in the sequence.
This means the trainer compares the predicted token to the correct token at every position of the sequence at the same time, so a transformer gets batch size * sequence length training signals from a single batch, all computed in parallel. An RNN can also be trained to predict every token, but it has to compute those predictions sequentially, step by step.
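A rough sketch of both ideas, assuming toy shapes and PyTorch's built-in cross_entropy; this isn't the exact training loop of any particular GPT, just the mechanism:

```python
import torch
import torch.nn.functional as F

seq_len = 8

# Causal mask: position i may only attend to positions <= i.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

scores = torch.randn(seq_len, seq_len)              # raw attention scores (toy values)
scores = scores.masked_fill(~mask, float("-inf"))   # hide the future
weights = torch.softmax(scores, dim=-1)             # future tokens get zero weight

# Because of the mask, the model's output at every position is a valid
# next-token prediction, so the loss is computed over all positions at once.
vocab_size = 100
logits = torch.randn(seq_len, vocab_size)            # stand-in for the model's output
targets = torch.randint(0, vocab_size, (seq_len,))   # the sequence shifted by one
loss = F.cross_entropy(logits, targets)              # seq_len training signals per sequence
```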
Regularization
With so many parameters, it would be easy for these models to memorize the training input, so we employ regularization techniques to reduce overfitting. Traditionally, regularization meant penalizing weights in proportion to their magnitude (L1 regularization) or their squared magnitude (L2 regularization). This meant that a model would only learn extremely high or low weights if they contributed significantly to reducing the loss. GPT-style models rely less on explicit L1 or L2 penalties (though weight decay, a close cousin of L2, is still common) and lean mainly on dropout and layer normalization.
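For reference, an L2 penalty is just the sum of squared weights added to the loss; here's a hedged sketch, where the lambda value and the AdamW weight-decay setting are arbitrary examples rather than anyone's recommended hyperparameters:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)        # stand-in for any model
task_loss = torch.tensor(0.0)    # placeholder for the actual prediction loss

# L2 penalty: large weights are only worth keeping if they earn their keep.
l2_lambda = 1e-4
l2_penalty = sum((p ** 2).sum() for p in model.parameters())
loss = task_loss + l2_lambda * l2_penalty

# In practice this is usually expressed as weight decay in the optimizer instead:
optim = torch.optim.AdamW(model.parameters(), weight_decay=0.01)
```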
Dropout is a layer inserted into a neural network that blocks out a random portion of the connections during a mini-batch run; in transformer training, a typical rate is around 10-20%. This incentivizes the network not to rely too heavily on any individual connection between layers. For an entire run of the mini-batch, the transformer has a fixed set of connections completely erased, and those connections don't receive gradient updates during back-propagation either. On the next mini-batch, the connections are restored, and another random subset is picked to be blocked out.
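PyTorch's built-in dropout layer shows this behavior directly; the rate here is just an example:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.2)   # drop 20% of activations during training

x = torch.ones(4, 8)
drop.train()
print(drop(x))   # ~20% of entries zeroed, survivors scaled by 1/0.8

drop.eval()
print(drop(x))   # at inference time dropout is a no-op
```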
Normalization is another regularization technique that encourages models to represent the data with minimal overfitting, and that also helps improve the speed of model training. Like L1 and L2 regularization, it keeps the values flowing through the network in a well-behaved range, but instead of punishing individual weights for being too large, it rescales the output of each layer to have a mean of zero and a variance of one.
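Here's what that looks like with PyTorch's LayerNorm; the learned scale and shift it adds on top are initialized to 1 and 0, so right after initialization the output is simply the normalized input:

```python
import torch
import torch.nn as nn

x = torch.randn(4, 64) * 5 + 3   # activations with an arbitrary mean and variance

# Normalize each row to mean 0 and variance 1, then apply a learned scale and shift.
ln = nn.LayerNorm(64)
y = ln(x)

print(y.mean(dim=-1))   # ~0 for every token
print(y.std(dim=-1))    # ~1 for every token
```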
Residual streams
One critical factor in allowing transformers to stack so deeply is the residual stream. The residual stream lets information from the previous layer flow around an attention layer when that layer would rather ignore it. In combination with regularization techniques, attention layers are actually encouraged to leave information alone in the residual stream if it is not directly useful for reducing the loss on the mini-batch.
Here's an intuition-primer for why a residual stream might be useful in a multi-layer transformer: an attention layer late in the model might need to attend to both the original token embedding and the output of an earlier attention layer in order to predict the pronouns of a person in a sentence. Without a residual stream, that information would be forced to travel through every intermediate attention layer, which would make it harder for those layers to learn higher-level concepts, since they would have to spend capacity carrying forward basic token embeddings.
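Here's a sketch of a pre-norm transformer block, a common arrangement (though details vary between models), to show that the residual stream is literally just a running sum that each sublayer adds its output to. The sizes and the use of nn.MultiheadAttention are illustrative choices, and the causal mask is omitted for brevity:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """One transformer block: each sublayer reads from the residual
    stream and adds its result back, rather than replacing it."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),  # the 4x expansion the quote below mentions
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        h = self.ln1(x)
        a, _ = self.attn(h, h, h, need_weights=False)  # causal mask omitted for brevity
        x = x + a                      # write attention output into the residual stream
        x = x + self.mlp(self.ln2(x))  # write MLP output into the residual stream
        return x

block = Block(d_model=64, n_heads=4)
x = torch.randn(2, 16, 64)   # (batch, seq, d_model)
y = block(x)                 # same shape; the stream just accumulates each layer's edits
```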
The residual stream ends up being an information-flow bottleneck during training, because it has far lower dimensionality than the computational parts of the network that read from and write to it. Anthropic explains this better than I can (from the excellent A Mathematical Framework for Transformer Circuits):
There are generally far more "computational dimensions" (such as neurons and attention head result dimensions) than the residual stream has dimensions to move information. Just a single MLP layer typically has four times more neurons than the residual stream has dimensions. So, for example, at layer 25 of a 50 layer transformer, the residual stream has 100 times more neurons before it than it has dimensions, trying to communicate with 100 times as many neurons as it has dimensions after it, somehow communicating in superposition!
Pre-training and general intelligence
Lastly, and most importantly, GPTs are pre-trained. A well-resourced organization like OpenAI or Google can spend hundreds of millions of dollars training GPT models that can be used in a mind-bogglingly large number of contexts. With earlier models, it was common to take a base model like VGG, fine-tune it on task-specific data (for example, to tell you whether a picture was a hot dog), and then stick it into a SaaS product where it was a little cog in a big machine. Often, users wouldn't know that a machine learning model was involved in the product at all. The only indication that it existed would be that Google would understand your query even if it didn't have any matching words.
On the other hand, GPTs can perform tasks that they learn in context. This means they can take on novel tasks without any additional training, and can complete a huge variety of tasks with simple natural-language prompting. Of course, you can fine-tune a GPT, but many users get significant utility out of careful prompting and creative thinking alone. Like VGG or BERT, GPTs can be used as components inside larger products, but they can also chat, write programs, or play Minecraft. These many use cases, and the fact that individuals and businesses are willing to pay for them, mean that it makes sense for AI companies to invest enormous sums into training these models and then earn that investment back, token by token.
Hopefully this gives you an idea of why transformers have become ubiquitous over the past two years, and some insight into the decisions that went into their design. In my next article, I'll focus on some of the ways that transformer models have been improved since GPT-3, including instruction tuning, RLHF, and FlashAttention.
What concepts in large language models and GPTs have you found confusing?
As far as I know, all LLMs these days are GPTs, and I think GPT is a better term to explain how they work.
GPT has never roasted me because it thought my question was the same as a related but different question from 8 years ago.
Belkin et al. have a hypothesis on why this happens, but as far as I can tell the mechanism for this isn't well understood.
Yann LeCun says that these are really “auto-regressive encoder-decoder” models, but I’m not sure why — it seems like the encoder and cross-attention are missing!