The Problem with Single Attention
A single attention mechanism has limitations:
- It produces a single set of attention weights per token, so it can emphasize only one pattern of relationships at a time
- Different words need different types of attention
- Language has many simultaneous relationships (syntax, semantics, coreference)
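For example, when processing the pronoun "it" in a sentence about a cat that sat and was tired, the model should:
- Attend to "cat" (coreference: what "it" refers to)
- Attend to "tired" (predicate: describing the state)
- Attend to "sat" (action: what was happening)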
One attention head can't capture all these relationships optimally!
The Solution: Multiple Heads
Multi-head attention runs several attention mechanisms (heads) in parallel, each with its own learned projections:
Figure: Multi-Head Concept (parallel heads focusing on subject-verb, pronoun reference, adjective-noun, and semantic-similarity relationships, among others)
Because every head has its own projections, each can learn to focus on a different type of relationship!
How It Works
Step 1: Split Into Heads
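Instead of one set of Q, K, V projections, we have h sets (one per head). In practice, each projection is usually stored as a single d_model × d_model matrix whose output is split into h slices of width d_k = d_model / h.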
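The split itself is just slicing. Below is a minimal sketch in plain Python (nested lists, no external libraries); the function name `split_heads` is hypothetical, and array libraries would do the same thing by reshaping the last axis to (h, d_k).

```python
def split_heads(x, num_heads):
    """Split a (seq_len x d_model) matrix into num_heads matrices of shape (seq_len x d_k).

    Head i receives columns [i*d_k, (i+1)*d_k) of every row, which is what
    reshaping the last axis to (num_heads, d_k) does in array libraries.
    """
    d_model = len(x[0])
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    d_k = d_model // num_heads
    return [[row[i * d_k:(i + 1) * d_k] for row in x] for i in range(num_heads)]


# Example: 2 tokens, d_model = 4, split into 2 heads with d_k = 2
q = [[1, 2, 3, 4],
     [5, 6, 7, 8]]
heads = split_heads(q, num_heads=2)
print(heads[0])  # [[1, 2], [5, 6]]
print(heads[1])  # [[3, 4], [7, 8]]
```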
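Step 2: Compute Attention Per Head
Each head runs scaled dot-product attention independently on its own slices: head_i = softmax(Q_i K_iᵀ / √d_k) V_i.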
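A minimal plain-Python sketch of scaled dot-product attention for one head, assuming that head's Q, K and V slices are given as lists of rows of length d_k. The names `softmax` and `attention_one_head` are illustrative, not from any library, and masking and batching are omitted.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_one_head(q, k, v):
    """Scaled dot-product attention for a single head: softmax(q k^T / sqrt(d_k)) v."""
    d_k = len(q[0])
    out = []
    for q_row in q:
        # Similarity of this query with every key, scaled by sqrt(d_k)
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in k]
        weights = softmax(scores)
        # Weighted average of the value rows
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))])
    return out


# Both keys match the query equally, so the output is the mean of the value rows
print(attention_one_head([[1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]]))
# [[2.0, 3.0]]
```

In multi-head attention this function would be called once per head, on that head's slices of Q, K and V.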
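Step 3: Concatenate and Project
The h head outputs, each of width d_k, are concatenated back to width h × d_k = d_model and multiplied by an output matrix W_O.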
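A sketch of the recombination step, again in plain Python with hypothetical helper names (`concat_heads`, `matmul`). W_O is set to the identity here only so the printed result is easy to check; in a trained model it is a learned d_model × d_model matrix.

```python
def concat_heads(head_outputs):
    """Concatenate per-head outputs (each seq_len x d_k) into seq_len x (h * d_k)."""
    seq_len = len(head_outputs[0])
    return [[x for head in head_outputs for x in head[t]] for t in range(seq_len)]


def matmul(a, b):
    """Plain matrix product of an (n x m) and an (m x p) list-of-lists matrix."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


# Two heads, 2 tokens, d_k = 2  ->  concatenated width h * d_k = 4 = d_model
head_outputs = [[[1, 2], [5, 6]],   # head 0
                [[3, 4], [7, 8]]]   # head 1
concat = concat_heads(head_outputs)
print(concat)  # [[1, 2, 3, 4], [5, 6, 7, 8]]

# Identity W_O (d_model x d_model), used only to keep the example checkable
W_O = [[1 if i == j else 0 for j in range(4)] for i in range(4)]
print(matmul(concat, W_O))  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```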
Complete Implementation
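Putting the three steps together, here is a minimal, dependency-free sketch of multi-head self-attention for a single sequence. It is an illustration under stated assumptions, not a reference implementation: the class name `MultiHeadAttention`, the random Gaussian initialization, and the omission of batching, masking, biases and dropout are all choices made here for brevity. A real model would use an array library (for example PyTorch's `torch.nn.MultiheadAttention`).

```python
import math
import random


def matmul(a, b):
    """(n x m) @ (m x p) for list-of-lists matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


class MultiHeadAttention:
    """Minimal multi-head self-attention (hypothetical; no batching, masking, bias or dropout)."""

    def __init__(self, d_model, num_heads, seed=0):
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        rng = random.Random(seed)
        scale = 1 / math.sqrt(d_model)

        def init():
            return [[rng.gauss(0, scale) for _ in range(d_model)] for _ in range(d_model)]

        # One d_model x d_model matrix per projection; head i uses its own d_k-wide column slice
        self.W_Q, self.W_K, self.W_V, self.W_O = init(), init(), init(), init()

    def _split(self, x):
        d = self.d_k
        return [[row[i * d:(i + 1) * d] for row in x] for i in range(self.num_heads)]

    def _attend(self, q, k, v):
        out = []
        for q_row in q:
            scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(self.d_k)
                      for k_row in k]
            w = softmax(scores)
            out.append([sum(wi * v_row[j] for wi, v_row in zip(w, v))
                        for j in range(self.d_k)])
        return out

    def forward(self, x):
        """x: seq_len x d_model list of lists -> seq_len x d_model output."""
        # Step 1: project, then split into heads
        qs = self._split(matmul(x, self.W_Q))
        ks = self._split(matmul(x, self.W_K))
        vs = self._split(matmul(x, self.W_V))
        # Step 2: attention independently in each head
        heads = [self._attend(q, k, v) for q, k, v in zip(qs, ks, vs)]
        # Step 3: concatenate along the feature axis and project with W_O
        concat = [[val for head in heads for val in head[t]] for t in range(len(x))]
        return matmul(concat, self.W_O)


mha = MultiHeadAttention(d_model=8, num_heads=2)
rng_x = random.Random(1)
x = [[rng_x.uniform(-1, 1) for _ in range(8)] for _ in range(3)]
y = mha.forward(x)
print(len(y), len(y[0]))  # 3 8
```

The output has the same shape as the input, which is what lets multi-head attention be stacked inside residual Transformer layers.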
Why Multiple Heads Help
Different Heads Learn Different Things
Analyses of trained Transformer models have found that individual heads often specialize, for example:
- Positional heads: Attend to adjacent tokens (syntax)
- Coreference heads: Track pronoun references
- Semantic heads: Group related meanings
- Rare word heads: Focus on unusual tokens
- Delimiter heads: Watch for punctuation, separators
Computational Efficiency
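Multi-head attention costs about the same as single-head attention with full dimensionality. Each of the h heads works with d_k = d_model / h dimensions, so computing the h score matrices costs h × seq² × d_k = seq² × d_model, exactly what one full-width head would cost, and the Q, K, V and output projections have the same size in both cases.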
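The arithmetic can be checked directly. This sketch counts multiply-adds for the two seq × seq products in the attention core (Q_i K_iᵀ and weights × V_i), using the hypothetical helper names `attention_flops` and `projection_flops`. It ignores the softmax itself, whose cost grows with h × seq² but is small next to the matrix products.

```python
def attention_flops(seq_len, d_model, num_heads):
    """Approximate multiply-add count of the attention core across all heads.

    Projections (x W_Q, x W_K, x W_V, concat W_O) are counted separately
    because their cost does not depend on num_heads.
    """
    d_k = d_model // num_heads
    scores = num_heads * seq_len * seq_len * d_k        # Q_i K_i^T for every head
    weighted_sum = num_heads * seq_len * seq_len * d_k  # softmax weights @ V_i
    return scores + weighted_sum


def projection_flops(seq_len, d_model):
    """Four d_model x d_model projections applied to seq_len tokens."""
    return 4 * seq_len * d_model * d_model


for h in (1, 8, 16):
    print(h, attention_flops(seq_len=100, d_model=512, num_heads=h))
# every line prints the same count: 2 * 100 * 100 * 512 = 10240000
```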
Common Configurations
| Model | d_model | Heads | d_k per head (d_model / heads) |
|---|---|---|---|
| BERT-Base | 768 | 12 | 64 |
| GPT-2 (small) | 768 | 12 | 64 |
| GPT-3 (175B) | 12288 | 96 | 128 |
| LLaMA-2 (7B) | 4096 | 32 | 128 |
Exercises
Exercise 1: Head Dimensions
If d_model = 1024 and we want 16 heads, what should d_k be?
Exercise 2: Output Shape
Multi-head attention takes input of shape (batch=2, seq=100, d_model=512) with 8 heads. What is the shape of the output?
Exercise 3: Why Split?
Why not just use single-head attention with the full d_model dimension? What do we gain by splitting into multiple heads?
Knowledge Check Quiz
Test Your Understanding
Question 1: What is the primary advantage of using multiple attention heads instead of a single attention mechanism?
Answer: Multiple heads can capture different types of relationships simultaneously (e.g., syntactic, semantic, coreference) that a single head cannot optimally represent all at once.
Question 2: If a model has d_model = 768 and num_heads = 12, what is d_k (the dimension per head)?
Answer: d_k = 768 / 12 = 64. Each head operates on 64-dimensional vectors.
Question 3: Why doesn't multi-head attention increase computational cost compared to single-head with the same total dimension?
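Answer: Each head works on lower-dimensional vectors (d_k = d_model / h), so the h heads together do the same work as one full-width head. Computing the attention scores and weighted sums costs O(seq² × d_model) in total, and the Q, K, V and output projections cost O(seq × d_model²), regardless of head count.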
Question 4: What operation combines the outputs from all attention heads back into a single representation?
Answer: The outputs are concatenated along the feature (last) axis, giving width h × d_k = d_model, then passed through a final linear projection W_O (output weight matrix) to produce the final d_model-sized representation.
Question 5: Name two types of linguistic relationships that different attention heads might specialize in capturing.
Answer: (Any two of) positional/adjacent token relationships, pronoun coreference, subject-verb agreement, semantic similarity, adjective-noun modification, delimiter/punctuation tracking.