The Attention Revolution
In 2017, Google researchers published "Attention Is All You Need." The paper introduced the Transformer architecture, which became the foundation of today's large language models. Before transformers, sequence models relied on recurrence (RNNs, LSTMs), processing words one at a time. Transformers process the entire sequence at once using attention.
The Problem with Recurrence
RNNs process sequences like this:
$$
\begin{aligned}
h_1 &= f(x_1, h_0) \\
h_2 &= f(x_2, h_1) \\
&\;\;\vdots \\
h_{1000} &= f(x_{1000}, h_{999}) \qquad \leftarrow \text{must wait for all previous steps!}
\end{aligned}
$$
This is slow (can't parallelize) and forgetful (information from early steps gets diluted). Attention solves both problems.
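A minimal NumPy sketch makes the sequential dependency concrete. It assumes a vanilla RNN with f(x_t, h_{t-1}) = tanh(x_t W_x + h_{t-1} W_h); `rnn_forward`, `W_x`, and `W_h` are illustrative names, not from any library. Because each iteration needs the previous hidden state, the loop cannot be vectorized across time steps.

```python
import numpy as np

def rnn_forward(xs, W_x, W_h, h0):
    """Vanilla RNN: h_t = tanh(x_t @ W_x + h_{t-1} @ W_h).

    xs: (T, d_in) inputs; returns all hidden states, shape (T, d_h).
    """
    h = h0
    states = []
    for x_t in xs:  # strictly sequential: step t needs h from step t-1
        h = np.tanh(x_t @ W_x + h @ W_h)
        states.append(h)
    return np.stack(states)

rng = np.random.default_rng(0)
T, d_in, d_h = 1000, 8, 16
xs = rng.normal(size=(T, d_in))
W_x = rng.normal(scale=0.1, size=(d_in, d_h))
W_h = rng.normal(scale=0.1, size=(d_h, d_h))
hs = rnn_forward(xs, W_x, W_h, np.zeros(d_h))
print(hs.shape)  # (1000, 16)
```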
Attention: The Core Idea
Attention answers the question: "When processing this word, which other words should I pay attention to?"
Attention Example
Consider the sentence: "The animal didn't cross the street because it was too tired."
What does "it" refer to? The model needs to attend to "animal":
Attention weights for "it": the highest attention falls on "animal" (75%).
In a recurrent model, "it" only has direct access to the immediately preceding hidden state, so information about "animal" must survive several intermediate steps. With attention, the representation of "it" can look directly at "animal", even though five words separate them.
Query, Key, Value
Attention is implemented using three projections of each input: Query, Key, and Value. Think of it like a database lookup:
- Query (Q): "What am I looking for?"
- Key (K): "What do I contain?"
- Value (V): "What information do I provide?"
QKV intuition: each Query is matched against every Key, and the match scores determine which Values are retrieved.
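To make the database analogy concrete, here is a small sketch contrasting an exact dictionary lookup with the "soft" lookup attention performs. The keys, values, and query are made-up toy numbers, and `soft_lookup` is a hypothetical helper, not a library function.

```python
import numpy as np

# Hard lookup: the query must match one key exactly; one value comes back.
table = {"animal": 10.0, "street": 20.0, "tired": 30.0}
print(table["animal"])  # 10.0

# Soft lookup: the query is compared with every key, and the result is a
# blend of all values, weighted by how well each key matches.
keys = np.array([[1.0, 0.0],    # "animal"
                 [0.0, 1.0],    # "street"
                 [0.5, 0.5]])   # "tired"
values = np.array([10.0, 20.0, 30.0])

def soft_lookup(query, keys, values):
    scores = keys @ query                     # similarity of the query to each key
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                  # softmax: weights sum to 1
    return weights @ values, weights

result, weights = soft_lookup(np.array([3.0, 0.0]), keys, values)
print(weights.round(3), round(float(result), 2))  # mostly "animal", a little of the rest
```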
The Attention Formula
$$
\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
$$
Breaking this down:
- QK^T: Compute similarity between every Query and every Key (dot product)
- / √d_k: Scale by the square root of the key dimension (prevents softmax saturation)
- softmax: Convert to probabilities (sum to 1)
- · V: Weighted sum of Values based on attention scores
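The formula translates almost line for line into NumPy. This is a minimal single-head sketch without masking or batching; the function names `softmax` and `attention` and the toy sizes are our own choices.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention.

    Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v) -> output (n_q, d_v)
    """
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # similarity of every query to every key
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V, weights          # weighted sum of values

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 16))                        # 5 tokens, d_model = 16
W_q, W_k, W_v = (rng.normal(size=(16, 8)) for _ in range(3))
out, w = attention(X @ W_q, X @ W_k, X @ W_v)       # self-attention: Q, K, V all from X
print(out.shape, w.shape)  # (5, 8) (5, 5)
```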
Multi-Head Attention
Different words might relate to each other in different ways. "It" might relate to "animal" grammatically, but also relate to "tired" semantically. Multi-head attention runs multiple attention operations in parallel, each learning different types of relationships.
Illustration (8 attention heads): each head can learn a different relationship type, for example:
- Subject-verb
- Coreference
- Modifier-head
- Positional
- Syntactic
- Semantic
- Rare patterns
Heads specialize on their own during training: some come to track grammar, others meaning.
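$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O
$$
where
$$
\text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)
$$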
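A minimal NumPy sketch of multi-head attention with h heads of size d_model / h. Instead of separate per-head matrices W_i^Q, W_i^K, W_i^V, it uses one d_model × d_model matrix each and splits the projected result into heads, which is equivalent to stacking the per-head matrices side by side and is a common way to implement it. All names and sizes are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    """X: (T, d_model); all weight matrices: (d_model, d_model)."""
    T, d_model = X.shape
    d_head = d_model // n_heads

    def split(M):  # (T, d_model) -> (n_heads, T, d_head)
        return M.reshape(T, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)     # (n_heads, T, T)
    heads = softmax(scores) @ V                             # (n_heads, T, d_head)
    concat = heads.transpose(1, 0, 2).reshape(T, d_model)   # Concat(head_1, ..., head_h)
    return concat @ W_o

rng = np.random.default_rng(0)
d_model, n_heads, T = 64, 8, 10
W_q, W_k, W_v, W_o = (rng.normal(scale=d_model**-0.5, size=(d_model, d_model))
                      for _ in range(4))
X = rng.normal(size=(T, d_model))
Y = multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads)
print(Y.shape)  # (10, 64)
```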
The Transformer Block
A transformer is built by stacking blocks with identical structure (each block has its own weights). Each block contains:
Key Components
- Multi-Head Attention: Allows tokens to communicate with each other
- Feed-Forward Network: Processes each token independently (applies non-linearity)
- Layer Norm: Normalizes activations for stable training
- Residual Connections: Skip connections that help gradients flow
Positional Encoding
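Attention itself is position-agnostic: it has no notion of where words sit in the sentence. Without extra information, the model sees "Dog bites man" and "Man bites dog" as the same bag of words, and each word gets the same output vector in both sentences, because attention treats its input as an unordered set. We need to inject position information.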
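The claim can be checked numerically: without position information, self-attention is permutation-equivariant, so reordering the input tokens only reorders the outputs, and each word's output vector is unchanged. The sketch uses random vectors as stand-ins for the embeddings of "dog", "bites", and "man" (an assumption for illustration).

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

rng = np.random.default_rng(0)
emb = {w: rng.normal(size=8) for w in ["dog", "bites", "man"]}  # toy embeddings
W_q, W_k, W_v = (rng.normal(size=(8, 8)) for _ in range(3))

out_a = self_attention(np.stack([emb["dog"], emb["bites"], emb["man"]]), W_q, W_k, W_v)
out_b = self_attention(np.stack([emb["man"], emb["bites"], emb["dog"]]), W_q, W_k, W_v)

# "dog" is token 0 in the first sentence and token 2 in the second,
# yet its output vector is the same: word order is invisible.
print(np.allclose(out_a[0], out_b[2]))  # True
```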
Sinusoidal Position Encoding
The original transformer uses sine and cosine functions of different frequencies:
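$$
\begin{aligned}
PE_{(pos,\,2i)} &= \sin\!\left(pos / 10000^{2i/d_{\text{model}}}\right) \\
PE_{(pos,\,2i+1)} &= \cos\!\left(pos / 10000^{2i/d_{\text{model}}}\right)
\end{aligned}
$$
This creates a unique "fingerprint" for each position that the model can learn to interpret. Many models, GPT-2 among them, use learned position embeddings instead: just another embedding layer indexed by position.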
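A minimal NumPy sketch of the sinusoidal encoding, assuming an even d_model. Even columns hold the sines and odd columns the cosines, following the sine/cosine definition; the function name `sinusoidal_positions` is our own.

```python
import numpy as np

def sinusoidal_positions(max_len, d_model):
    """Return a (max_len, d_model) matrix; row pos is the encoding of position pos."""
    pos = np.arange(max_len)[:, None]                  # (max_len, 1)
    two_i = np.arange(0, d_model, 2)[None, :]          # 2i = 0, 2, 4, ...
    angles = pos / np.power(10000.0, two_i / d_model)  # (max_len, d_model / 2)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)   # PE(pos, 2i)
    pe[:, 1::2] = np.cos(angles)   # PE(pos, 2i+1)
    return pe

pe = sinusoidal_positions(max_len=50, d_model=16)
print(pe.shape)     # (50, 16)
print(pe[0, :4])    # position 0: [0. 1. 0. 1.]
# In a model, these rows are added to the token embeddings: x = token_emb + pe[:T]
```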
Scaled Dot-Product Attention: The Math
Let's walk through the attention formula one step at a time, since each component plays a distinct role in how transformers work.
Step 1: Computing Similarity Scores
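First, we compute how much each Query matches each Key using a dot product:
$$
\text{scores} = Q K^\top, \qquad Q, K \in \mathbb{R}^{\text{seq\_len} \times d_k}, \quad \text{scores} \in \mathbb{R}^{\text{seq\_len} \times \text{seq\_len}}
$$
Each element scores[i][j] tells us how much token i should attend to token j. But raw dot products grow with dimension: if the entries of a query and a key are independent with zero mean and unit variance, their dot product has variance d_k. Large scores push softmax into saturation, where nearly all the weight lands on one position and gradients become tiny.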
Step 2: Scaling by √d_k
We divide by the square root of the key dimension, which keeps the scores at roughly unit variance:
$$
\text{scaled} = \frac{Q K^\top}{\sqrt{d_k}}
$$
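A quick simulation shows why the scaling is needed, assuming query and key entries drawn independently from a standard normal distribution: the variance of q·k grows roughly like d_k, and dividing by √d_k brings it back to about 1. The sample count of 5,000 is an arbitrary choice.

```python
import numpy as np

rng = np.random.default_rng(0)
results = {}
for d_k in [4, 64, 512]:
    q = rng.normal(size=(5_000, d_k))    # 5,000 random queries, entries ~ N(0, 1)
    k = rng.normal(size=(5_000, d_k))    # 5,000 random keys
    dots = (q * k).sum(axis=1)           # row-wise dot products q.k
    scaled = dots / np.sqrt(d_k)
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:3d}  var(q.k)={dots.var():6.1f}  var(q.k/sqrt(d_k))={scaled.var():.2f}")
```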
Step 3: Softmax (Converting to Probabilities)
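Softmax converts the scaled scores to probabilities that sum to 1 for each query position:
$$
\text{attention\_weights}[i][j] = \frac{\exp\left(\text{scores}[i][j] / \sqrt{d_k}\right)}{\sum_{j'} \exp\left(\text{scores}[i][j'] / \sqrt{d_k}\right)}
$$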
Each row of attention_weights tells us the probability distribution over all positions for a given query. Position i distributes its "attention budget" across positions 1 through seq_len.
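A row-wise softmax is short to write, but exponentiating large scores overflows in floating point. A common remedy, sketched below, subtracts each row's maximum first; the result is unchanged because softmax is invariant to adding the same constant to every score in a row. The function names are our own.

```python
import numpy as np

def naive_softmax(s):
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(s):
    # softmax(s) == softmax(s - c) for any per-row constant c, so shift by the row max:
    # the largest exponent becomes exp(0) = 1 and nothing overflows.
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.array([[1.0, 2.0, 3.0],
                   [1000.0, 1001.0, 1002.0]])   # exp(1000) overflows float64

with np.errstate(over="ignore", invalid="ignore"):
    print(naive_softmax(scores))    # second row becomes nan (inf / inf)
print(stable_softmax(scores))       # both rows ≈ [0.090, 0.245, 0.665]
```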
Step 4: Weighted Aggregation
Finally, we multiply the attention weights by the Value matrix. Each output position is a weighted combination of all values, where the weights come from the attention scores.
Numerical Example
Let's trace attention through a tiny example with 3 tokens and d_k = 4:
Query (for "cat"):
q₂ = [0.5, -0.3, 0.8, 0.1]
Keys:
k₁ = [0.7, -0.2, 0.4, 0.3] (for "The")
k₂ = [0.1, 0.6, -0.5, 0.2] (for "cat")
k₃ = [0.3, -0.4, 0.9, 0.7] (for "sat")
Step 1: Dot products
q₂·k₁ = 0.5×0.7 + (-0.3)×(-0.2) + 0.8×0.4 + 0.1×0.3 = 0.76
q₂·k₂ = 0.5×0.1 + (-0.3)×0.6 + 0.8×(-0.5) + 0.1×0.2 = -0.51
q₂·k₃ = 0.5×0.3 + (-0.3)×(-0.4) + 0.8×0.9 + 0.1×0.7 = 1.06
Step 2: Scale by √4 = 2
scaled = [0.38, -0.255, 0.53]
Step 3: Softmax
exp(0.38) = 1.4623, exp(-0.255) = 0.7749, exp(0.53) = 1.6989
sum = 3.9361
attention = [0.3715, 0.1969, 0.4316]
Interpretation: Token 2 ("cat") attends 43.2% to "sat", 37.2% to "The", and 19.7% to itself. (No causal mask is applied in this example; with the causal mask described next, "cat" could not attend to "sat".)
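The arithmetic can be checked with a few lines of NumPy. The query vector is the one implied by the dot-product expansions in Step 1.

```python
import numpy as np

q2 = np.array([0.5, -0.3, 0.8, 0.1])           # query for "cat"
K = np.array([[0.7, -0.2,  0.4, 0.3],          # k1: "The"
              [0.1,  0.6, -0.5, 0.2],          # k2: "cat"
              [0.3, -0.4,  0.9, 0.7]])         # k3: "sat"

dots = K @ q2                                  # Step 1: q2 . k_j for each key
scaled = dots / np.sqrt(q2.size)               # Step 2: divide by sqrt(d_k) = 2
weights = np.exp(scaled) / np.exp(scaled).sum()  # Step 3: softmax
print([round(float(d), 2) for d in dots])      # [0.76, -0.51, 1.06]
print([round(float(w), 4) for w in weights])   # [0.3715, 0.1969, 0.4316]
```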
Causal Masking: Preventing Cheating
In autoregressive language models (like GPT), we must prevent tokens from "seeing the future." During training, all positions are computed simultaneously for efficiency, but token i should only attend to tokens 1 through i, not to tokens after it.
The Causal Mask
Before applying softmax, we set all "future" attention scores to negative infinity:
$$
\text{masked\_scores}[i][j] = \begin{cases} \text{scores}[i][j] & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}
$$
This ensures that after softmax, positions after i receive 0 attention:
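Causal mask (5 tokens; ✓ = can attend to, ✗ = masked out with -∞):
| Query \ Key | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|
| 1 | ✓ | ✗ | ✗ | ✗ | ✗ |
| 2 | ✓ | ✓ | ✗ | ✗ | ✗ |
| 3 | ✓ | ✓ | ✓ | ✗ | ✗ |
| 4 | ✓ | ✓ | ✓ | ✓ | ✗ |
| 5 | ✓ | ✓ | ✓ | ✓ | ✓ |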
This lower-triangular mask is called the causal mask or autoregressive mask. Without it, the model could "cheat" by looking at future tokens during training.
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
This creates an upper-triangular mask where positions above the diagonal are True (to be masked).
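The same mask can drive a complete causal attention computation. This NumPy sketch builds it with `np.triu(..., k=1)`, the NumPy counterpart of the PyTorch line above, and sets masked scores to -inf before the softmax so their weights become exactly 0. `causal_attention` is a hypothetical helper name.

```python
import numpy as np

def causal_attention(Q, K, V):
    T, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)   # True above the diagonal = future
    scores = np.where(mask, -np.inf, scores)           # masked_scores[i][j] = -inf if j > i
    scores -= scores.max(axis=-1, keepdims=True)       # stable softmax; the diagonal is never masked
    weights = np.exp(scores)                           # exp(-inf) = 0
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(5, 8)) for _ in range(3))
out, w = causal_attention(Q, K, V)
print(np.round(w, 2))   # lower-triangular: row i has zeros after column i
```

Changing the values of later tokens leaves the outputs of earlier tokens untouched, which is exactly the "no peeking" guarantee the mask provides.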
Layer Normalization & Residual Connections
Two critical components that make deep transformers trainable: layer normalization stabilizes activations, and residual connections allow gradients to flow.
Layer Normalization (LayerNorm)
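LayerNorm normalizes the activations of each token independently:
$$
\text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta
$$
where μ and σ² are the mean and variance computed over the d_model dimensions of each token, ε is a small constant for numerical stability, and γ and β are learnable parameters that allow the network to undo the normalization if needed.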
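A minimal NumPy sketch of LayerNorm over the last (d_model) axis. The default ε = 1e-5 is an assumption that matches PyTorch's `nn.LayerNorm` default; γ and β are passed in explicitly rather than learned.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """x: (..., d_model). Each token (row) is normalized independently."""
    mu = x.mean(axis=-1, keepdims=True)       # per-token mean
    var = x.var(axis=-1, keepdims=True)       # per-token (biased) variance
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta               # learnable scale and shift

rng = np.random.default_rng(0)
d_model = 16
x = rng.normal(loc=5.0, scale=3.0, size=(4, d_model))   # 4 tokens, off-center
y = layer_norm(x, gamma=np.ones(d_model), beta=np.zeros(d_model))
print(y.mean(axis=-1).round(6), y.std(axis=-1).round(3))  # ~0 and ~1 per token
```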
Residual (Skip) Connections
Residual connections add the input directly to the output of a sublayer:
$$
\text{output} = x + \text{Sublayer}(x)
$$
This means the sublayer only needs to learn the residual β the difference between the output and input. This has two crucial benefits:
- Gradient flow: Gradients can skip the sublayer entirely, ensuring they reach earlier layers even when the sublayer's gradient is small
- Identity initialization: If Sublayer(x) is close to 0 at the start of training, then output ≈ x, so the network starts close to an identity function and gradually learns transformations
Full Transformer Block
Combining everything, a complete transformer block (GPT-style) looks like this:
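$$
\begin{aligned}
x_1 &= x_0 + \text{MultiHeadAttention}(\text{LayerNorm}(x_0)) \\
x_2 &= x_1 + \text{FeedForward}(\text{LayerNorm}(x_1))
\end{aligned}
$$
This is called Pre-LayerNorm (LayerNorm before the sublayer). The original transformer paper used Post-LayerNorm, which normalizes after the residual addition, computing LayerNorm(x + Sublayer(x)); Pre-LN tends to train more stably, especially in deep models.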
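The difference between the two orderings is easiest to see side by side. In this sketch, `toy_sublayer` is a hypothetical stand-in for attention or the FFN, and `layer_norm` omits the learnable γ and β for brevity; both are simplifying assumptions.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)      # gamma = 1, beta = 0 for brevity

def pre_ln(x, sublayer):
    # Pre-LN: normalize the sublayer's input; the residual path x is left untouched.
    return x + sublayer(layer_norm(x))

def post_ln(x, sublayer):
    # Post-LN (original transformer): normalize after adding the residual.
    return layer_norm(x + sublayer(x))

def block(x, attn, ffn, residual=pre_ln):
    x = residual(x, attn)    # attention sublayer
    return residual(x, ffn)  # feed-forward sublayer

def zero_sublayer(h):
    return np.zeros_like(h)  # a sublayer that contributes nothing

rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(8, 8))

def toy_sublayer(h):
    return np.tanh(h @ W)    # hypothetical stand-in for attention or the FFN

x = rng.normal(size=(3, 8))
print(np.allclose(block(x, zero_sublayer, zero_sublayer), x))                    # True
print(np.allclose(block(x, zero_sublayer, zero_sublayer, residual=post_ln), x))  # False
```

With Pre-LN, a block whose sublayers output zeros is exactly the identity, matching the identity-initialization property of residual connections; with Post-LN, the residual stream itself is renormalized after every sublayer.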
The Feed-Forward Network (FFN)
After attention allows tokens to communicate, each token is processed independently by the feed-forward network (also called the MLP):
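$$
\text{FFN}(x) = \text{GELU}(x W_1 + b_1)\, W_2 + b_2
$$
where W₁ ∈ ℝ^(d_model × d_ff) and W₂ ∈ ℝ^(d_ff × d_model). The FFN has two linear transformations with a GELU activation in between (the original transformer used ReLU; GPT models use GELU).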
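A minimal NumPy sketch of the FFN with d_ff = 4 × d_model. It uses the tanh approximation of GELU that GPT-2's code uses (the exact GELU is x·Φ(x), with Φ the standard normal CDF); the initialization scale of 0.02 is an arbitrary choice here.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU(x) = x * Phi(x)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def ffn(x, W1, b1, W2, b2):
    """x: (T, d_model) -> (T, d_model). Applied to every token independently."""
    return gelu(x @ W1 + b1) @ W2 + b2   # expand to d_ff, nonlinearity, project back

rng = np.random.default_rng(0)
d_model = 64
d_ff = 4 * d_model
W1 = rng.normal(scale=0.02, size=(d_model, d_ff))
b1 = np.zeros(d_ff)
W2 = rng.normal(scale=0.02, size=(d_ff, d_model))
b2 = np.zeros(d_model)

x = rng.normal(size=(10, d_model))
y = ffn(x, W1, b1, W2, b2)
print(y.shape)  # (10, 64)
```

Applying it to a single token gives the same result as applying it to the whole sequence, since no information moves between positions.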
Why d_ff = 4 × d_model?
The inner dimension d_ff is typically 4× the model dimension, a convention inherited from the original transformer (d_model = 512, d_ff = 2048). For example:
- GPT-2 Small (d=768): d_ff = 3072
- GPT-3 (d=12288): d_ff = 49152
- LLaMA 2 70B (d=8192): d_ff = 28672 (3.5×, with the SwiGLU variant)
Most of each block's parameters are in the FFN: its two matrices hold about 8d² parameters versus about 4d² for the attention projections, roughly 2/3 of the block. Attention lets tokens share information; the FFN is widely thought to be where much of the model's knowledge is stored and processed.
SwiGLU: The Modern FFN
Modern transformers (like LLaMA) use SwiGLU instead of the standard FFN:
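$$
\text{FFN}_{\text{SwiGLU}}(x) = \left(\text{SiLU}(x W_1) \odot x W_3\right) W_2
$$
where SiLU(x) = x · σ(x) is the sigmoid linear unit and ⊙ denotes element-wise multiplication. This gating mechanism (loosely similar to LSTM gates) lets the network selectively pass information through, and it has been shown to improve performance over ReLU and GELU FFNs.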
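A minimal NumPy sketch of a SwiGLU FFN, with weight matrices named W1, W3, W2 as in LLaMA's reference code and no biases (LLaMA's FFN has none). Because SwiGLU has three matrices instead of two, the hidden size is often set near 8/3 × d_model so the parameter count matches a 4× FFN; here it is rounded up to a multiple of 8 for illustration.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))      # x * sigmoid(x)

def swiglu_ffn(x, W1, W3, W2):
    """x: (T, d_model). W1, W3: (d_model, d_ff); W2: (d_ff, d_model). No biases."""
    gate = silu(x @ W1)                 # gating branch
    up = x @ W3                         # linear branch
    return (gate * up) @ W2             # element-wise gate, then project back

rng = np.random.default_rng(0)
d_model = 96
d_ff = 8 * ((8 * d_model // 3 + 7) // 8)   # ~8/3 * d_model, rounded up to a multiple of 8
W1, W3 = (rng.normal(scale=0.02, size=(d_model, d_ff)) for _ in range(2))
W2 = rng.normal(scale=0.02, size=(d_ff, d_model))

x = rng.normal(size=(10, d_model))
print(swiglu_ffn(x, W1, W3, W2).shape, d_ff)   # (10, 96) 256
```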
Inference: The KV Cache
During text generation, the model produces one token at a time autoregressively. Without optimization, generating token N means recomputing the Keys and Values of all N-1 previous tokens in every layer, even though they have not changed. This is extremely wasteful!
The Key Insight
In causal attention, the Key and Value matrices for tokens 1 through N-1 don't change when we add token N. They've already been computed! We can cache them.
KV Cache Comparison
Without KV cache:
Token 2: Compute K₁, V₁, K₂, V₂, Q₂
Token 3: Compute K₁, V₁, K₂, V₂, K₃, V₃, Q₃
Token N: Compute all K, V, Q
Total: O(N²) recomputations
With KV cache:
Token 2: Load K₁, V₁; compute K₂, V₂, Q₂
Token 3: Load K₁, V₁, K₂, V₂; compute K₃, V₃, Q₃
Token N: Load K₁ … V_{N-1}; compute K_N, V_N, Q_N
Total: O(N) new computations
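A minimal single-head NumPy sketch of decoding with a KV cache, assuming fixed projection matrices and one new input vector per step. `KVCache` and `decode_step` are hypothetical names. The final check confirms that the cached, step-by-step outputs match full causal attention recomputed from scratch.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

class KVCache:
    def __init__(self):
        self.keys, self.values = [], []   # one (d_k,) vector per past token

def decode_step(x_t, cache, W_q, W_k, W_v):
    """Process ONE new token x_t of shape (d_model,) using cached K, V."""
    q = x_t @ W_q
    cache.keys.append(x_t @ W_k)          # only the new token's K and V are computed
    cache.values.append(x_t @ W_v)
    K, V = np.stack(cache.keys), np.stack(cache.values)   # (t, d_k): past + current
    w = softmax(q @ K.T / np.sqrt(K.shape[-1]))
    return w @ V

def full_causal_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v   # recompute everything
    T = X.shape[0]
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    scores[np.triu(np.ones((T, T), dtype=bool), k=1)] = -np.inf
    return softmax(scores) @ V

rng = np.random.default_rng(0)
d_model, d_k, T = 16, 8, 6
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))
X = rng.normal(size=(T, d_model))

cache = KVCache()
incremental = np.stack([decode_step(X[t], cache, W_q, W_k, W_v) for t in range(T)])
print(np.allclose(incremental, full_causal_attention(X, W_q, W_k, W_v)))  # True
```

The cache grows by one K and one V entry per generated token, which is where the memory cost comes from.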
Memory Cost
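The KV cache stores 2 matrices (K and V) per layer, each of size seq_len × d_head × n_heads, so for one sequence:
$$
\text{KV cache bytes} = 2 \times n_{\text{layers}} \times \text{seq\_len} \times n_{\text{heads}} \times d_{\text{head}} \times \text{bytes per value}
$$
For GPT-3 (96 layers, 96 heads, 128 dim, 2048 seq len), assuming 16-bit values (2 bytes each):
2 × 96 × 2048 × 96 × 128 × 2 bytes = 9,663,676,416 bytes ≈ 9.7 GB (exactly 9 GiB) per sequence.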
The Full GPT Architecture
Putting all the pieces together, here's the complete forward pass through a GPT model:
GPT Forward Pass
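x = Embedding[token_ids]                       (V × d table → T × d)
x = x + PositionEmbed[pos]                     (T × d → T × d)
for each transformer block:
    a = x + MultiHeadAttention(LayerNorm(x))   (causal; T × d → T × d)
    x = a + FeedForward(LayerNorm(a))          (T × d → T × d)
x = LayerNorm(x)
logits = x · W_out                             (T × d → T × V)
P(next_token) = softmax(logits[-1])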
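To tie the pieces together, here is a compact NumPy sketch of this forward pass with randomly initialized weights. The toy sizes, the `init` scheme, single-head attention, the omission of biases and LayerNorm parameters, and tying `W_out` to the token embedding (as GPT-2 does) are all simplifying assumptions; a real GPT uses multi-head attention and loads trained weights.

```python
import numpy as np

rng = np.random.default_rng(0)
V_SIZE, T_MAX, D, N_LAYERS = 100, 32, 32, 2   # toy sizes (hypothetical)

def init(*shape):
    return rng.normal(scale=0.02, size=shape)

def layer_norm(x, eps=1e-5):
    mu, var = x.mean(-1, keepdims=True), x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)          # gamma = 1, beta = 0

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def causal_self_attention(x, p):
    T = x.shape[0]
    q, k, v = x @ p["W_q"], x @ p["W_k"], x @ p["W_v"]
    scores = q @ k.T / np.sqrt(D)
    scores[np.triu(np.ones((T, T), dtype=bool), k=1)] = -np.inf
    return softmax(scores) @ v @ p["W_o"]

params = {
    "tok_emb": init(V_SIZE, D),                   # Embedding: V x d
    "pos_emb": init(T_MAX, D),                    # PositionEmbed: T_max x d
    "blocks": [
        {"W_q": init(D, D), "W_k": init(D, D), "W_v": init(D, D), "W_o": init(D, D),
         "W1": init(D, 4 * D), "W2": init(4 * D, D)}
        for _ in range(N_LAYERS)
    ],
}

def gpt_forward(token_ids):
    T = len(token_ids)
    x = params["tok_emb"][token_ids]              # (T, d)
    x = x + params["pos_emb"][np.arange(T)]       # (T, d)
    for p in params["blocks"]:
        a = x + causal_self_attention(layer_norm(x), p)
        x = a + gelu(layer_norm(a) @ p["W1"]) @ p["W2"]
    x = layer_norm(x)
    logits = x @ params["tok_emb"].T              # tied output head: (T, V)
    return logits, softmax(logits[-1])            # P(next_token)

logits, p_next = gpt_forward(np.array([5, 17, 42, 8]))
print(logits.shape, p_next.shape, round(float(p_next.sum()), 6))  # (4, 100) (100,) 1.0
print("greedy next token:", int(p_next.argmax()))
```

Because attention is causal, changing later token ids leaves the logits at earlier positions unchanged.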
Parameter Count
Let's count the parameters for a transformer block with d_model = d, n_heads = h, and inner dimension d_ff = 4d:
| Component | Parameters | Example (d=768) |
|---|---|---|
| Q, K, V projections (each) | d² + d | 768² + 768 = 590,592 |
| Output projection | d² + d | 590,592 |
| FFN: W₁ | d × 4d + 4d | 768 × 3072 + 3072 = 2,362,368 |
| FFN: W₂ | 4d × d + d | 3072 × 768 + 768 = 2,360,064 |
| LayerNorm × 2 | 2 × 2d | 3,072 |
| Per Block Total | 12d² + 13d ≈ 12d² | 7,087,872 |
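For GPT-2 Small (12 layers, d=768): ~85M in transformer blocks + ~38.6M in token embeddings (50,257 × 768) + ~0.8M in position embeddings (1,024 × 768) ≈ 124M total parameters. The output head shares its weights with the token embedding, so it adds no extra parameters. GPT-3 uses 96 layers with d=12,288: 96 × 12d² ≈ 174B parameters in the blocks, for ~175B in total.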
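The per-block formula and the GPT-2 Small total can be reproduced with plain arithmetic. The vocabulary size (50,257) and context length (1,024) are GPT-2's published values; the count assumes biases on every projection, tied input/output embeddings, and one final LayerNorm, as in GPT-2.

```python
def block_params(d, ff_mult=4):
    attn = 4 * (d * d + d)                                          # Q, K, V, output (+ biases)
    ffn = (d * ff_mult * d + ff_mult * d) + (ff_mult * d * d + d)   # W1 and W2 (+ biases)
    norms = 2 * 2 * d                                               # two LayerNorms: gamma, beta
    return attn + ffn + norms

def gpt2_params(n_layers, d, vocab=50_257, n_ctx=1_024):
    blocks = n_layers * block_params(d)
    embeddings = vocab * d + n_ctx * d     # token + position embeddings
    final_norm = 2 * d
    return blocks, embeddings, blocks + embeddings + final_norm   # output head is tied

per_block = block_params(768)
blocks, emb, total = gpt2_params(n_layers=12, d=768)
print(per_block)                  # 7087872
print(blocks, emb, total)         # 85054464 39383808 124439808
ffn_share = (8 * 768**2) / (12 * 768**2)
print(round(ffn_share, 3))        # 0.667: the FFN's share of the block's weight matrices
```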