Lesson 3 of 10 in Level 03

Multi-Head Attention

Why one attention head isn't enough, and how running several attention heads in parallel produces richer representations.

The Problem with Single Attention

A single attention mechanism has limitations:

Example: In "The cat sat on the mat because it was tired", the word "it" needs to:
• Attend to "cat" (coreference - what "it" refers to)
• Attend to "tired" (predicate - describing the state)
• Attend to "sat" (action - what was happening)

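One attention head can't capture all these relationships optimally! A single head computes one softmax distribution per token, so its output for "it" is one weighted average of the other tokens. Spreading weight across "cat", "tired", and "sat" blends the three signals together instead of keeping them separate.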

The Solution: Multiple Heads

Multi-head attention runs several attention mechanisms (heads) in parallel, each with its own learned Q, K, V projections:

Multi-Head Concept (illustrative roles):
• Head 1: subject-verb
• Head 2: pronoun reference
• Head 3: adjective-noun
• Head 4: semantic similarity
• ... and more!

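These roles are not assigned by hand: each head has its own learned projections, and during training different heads tend to pick up different types of relationships.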

How It Works

Step 1: Split Into Heads

Instead of one set of Q, K, V projections, we have h sets (one per head):

```python
# For each head i:
Q_i = X @ W_Q_i  # (batch, seq, d_k)
K_i = X @ W_K_i  # (batch, seq, d_k)
V_i = X @ W_V_i  # (batch, seq, d_v)

# Where d_k = d_model / h
# Typical: d_model = 512, h = 8, so d_k = 64
```
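To make the shapes concrete, here is a minimal NumPy sketch of Step 1. It assumes small illustrative sizes (d_model = 8, h = 2) and random weights in place of learned ones; the per-head weight lists `W_Q`, `W_K`, `W_V` are hypothetical names chosen to mirror the notation above.

```python
import numpy as np

# Illustrative sizes (assumptions, much smaller than a real model)
batch, seq, d_model, h = 2, 5, 8, 2
d_k = d_model // h  # 4 features per head
d_v = d_k

rng = np.random.default_rng(0)
X = rng.standard_normal((batch, seq, d_model))

# One (d_model, d_k) projection matrix per head for each role (random stand-ins)
W_Q = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_K = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
W_V = [rng.standard_normal((d_model, d_v)) for _ in range(h)]

# Project X separately for every head
Q = [X @ W_Q[i] for i in range(h)]  # each (batch, seq, d_k)
K = [X @ W_K[i] for i in range(h)]  # each (batch, seq, d_k)
V = [X @ W_V[i] for i in range(h)]  # each (batch, seq, d_v)

print(Q[0].shape, V[1].shape)  # (2, 5, 4) (2, 5, 4)
```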

Step 2: Compute Attention Per Head

```python
# For each head independently (softmax over the key dimension):
head_i_output = softmax(Q_i @ K_i.transpose(-2, -1) / sqrt(d_k), dim=-1) @ V_i
# Each head produces (batch, seq, d_v) where d_v = d_model / h
```
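A minimal NumPy sketch of the per-head computation, assuming random inputs in place of real projections. The helper names `softmax` and `attention_head` are illustrative, not from any library; the softmax runs over the last axis, so each query's weights over the keys sum to 1.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_head(Q_i, K_i, V_i):
    """Scaled dot-product attention for a single head.

    Q_i, K_i: (batch, seq, d_k); V_i: (batch, seq, d_v)
    """
    d_k = Q_i.shape[-1]
    scores = Q_i @ np.swapaxes(K_i, -2, -1) / np.sqrt(d_k)  # (batch, seq, seq)
    weights = softmax(scores, axis=-1)  # each row sums to 1 over the keys
    return weights @ V_i, weights  # (batch, seq, d_v), (batch, seq, seq)


rng = np.random.default_rng(0)
batch, seq, d_k = 2, 5, 4
Q_i, K_i, V_i = (rng.standard_normal((batch, seq, d_k)) for _ in range(3))
head_i_output, weights = attention_head(Q_i, K_i, V_i)
print(head_i_output.shape)  # (2, 5, 4)
```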

Step 3: Concatenate and Project

```python
# Concatenate all heads along the last (feature) axis: h * d_v = d_model
concat_output = torch.cat([head_1, head_2, ..., head_h], dim=-1)  # (batch, seq, d_model)

# Final linear projection
output = concat_output @ W_O  # (batch, seq, d_model)
# W_O is (d_model, d_model) - learned output projection
```
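The concatenation can be sketched in NumPy as follows, assuming random stand-ins for the head outputs and for `W_O`. The point is the bookkeeping: head i lands in feature columns `i*d_v` through `(i+1)*d_v - 1` of the concatenated tensor.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, d_model, h = 2, 5, 8, 2
d_v = d_model // h

# Stand-ins for the per-head outputs of Step 2 (random here, for shapes only)
heads = [rng.standard_normal((batch, seq, d_v)) for _ in range(h)]

# Concatenate along the last (feature) axis: h * d_v = d_model features
concat_output = np.concatenate(heads, axis=-1)  # (batch, seq, d_model)

# Output projection (random stand-in for a learned matrix), shape (d_model, d_model)
W_O = rng.standard_normal((d_model, d_model))
output = concat_output @ W_O  # (batch, seq, d_model)

print(concat_output.shape, output.shape)  # (2, 5, 8) (2, 5, 8)
```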

Complete Implementation

```python
import math

import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()  # required so nn.Linear submodules are registered
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 64

        # Separate projections for each head would be inefficient.
        # Instead, project to d_model once and split the result into heads.
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, X):
        batch_size, seq_len, _ = X.shape

        # 1. Linear projections
        Q = self.W_Q(X)  # (batch, seq, d_model)
        K = self.W_K(X)  # (batch, seq, d_model)
        V = self.W_V(X)  # (batch, seq, d_model)

        # 2. Split into heads
        # Reshape from (batch, seq, d_model) to (batch, num_heads, seq, d_k)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # 3. Scaled dot-product attention for each head (batched over heads)
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)  # (batch, num_heads, seq, seq)
        attn_weights = torch.softmax(scores, dim=-1)
        head_outputs = attn_weights @ V  # (batch, num_heads, seq, d_k)

        # 4. Concatenate heads
        # Transpose back to (batch, seq, num_heads, d_k), then merge the last two dims
        head_outputs = head_outputs.transpose(1, 2).contiguous()
        concat = head_outputs.view(batch_size, seq_len, self.d_model)

        # 5. Final projection
        output = self.W_O(concat)
        return output, attn_weights
```
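The comment about avoiding separate per-head projections can be checked numerically. This NumPy sketch uses random weights, omits the bias, and picks small sizes for illustration; it writes the fused projection as `X @ W_Q` with a (d_model, d_model) matrix. (`nn.Linear` actually stores its weight transposed and adds a bias, which does not change the argument.) Projecting once and reshaping gives the same per-head queries as slicing the fused matrix into h column blocks and projecting separately.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, d_model, num_heads = 2, 5, 8, 2
d_k = d_model // num_heads

X = rng.standard_normal((batch, seq, d_model))
W_Q = rng.standard_normal((d_model, d_model))  # one fused projection (random stand-in)

# Fused: project once, then split into heads (mirrors .view(...).transpose(1, 2))
Q_fused = (X @ W_Q).reshape(batch, seq, num_heads, d_k).transpose(0, 2, 1, 3)
# -> (batch, num_heads, seq, d_k)

# Per-head: head i uses columns i*d_k:(i+1)*d_k of the fused matrix
Q_per_head = np.stack(
    [X @ W_Q[:, i * d_k:(i + 1) * d_k] for i in range(num_heads)], axis=1
)  # (batch, num_heads, seq, d_k)

print(np.allclose(Q_fused, Q_per_head))  # True
```

So the fused layer is just h per-head projections packed side by side; one large matrix multiply is usually faster on hardware than h small ones.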

Why Multiple Heads Help

Different Heads Learn Different Things

Research has shown that different attention heads specialize:

  • Positional heads: Attend to adjacent tokens (syntax)
  • Coreference heads: Track pronoun references
  • Semantic heads: Group related meanings
  • Rare word heads: Focus on unusual tokens
  • Delimiter heads: Watch for punctuation, separators
Interesting Finding: Some heads are "unused" or redundant. In large models, you can often prune (remove) 20-40% of heads with minimal impact!
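One way to see what "removing a head" means: because the concatenated head outputs are multiplied by W_O, the layer output is a sum of per-head contributions, so pruning head i amounts to dropping its term. A minimal NumPy sketch with random stand-ins; the `head_mask` list and the `combine` helper are hypothetical illustrations, not a specific library's pruning API.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq, d_model, h = 2, 5, 8, 4
d_v = d_model // h

heads = [rng.standard_normal((batch, seq, d_v)) for _ in range(h)]  # stand-in head outputs
W_O = rng.standard_normal((d_model, d_model))  # stand-in output projection


def combine(heads, W_O, head_mask):
    """Concatenate (masked) head outputs and apply W_O.

    head_mask[i] = 0.0 removes head i's contribution; 1.0 keeps it.
    """
    masked = [m * out for m, out in zip(head_mask, heads)]
    return np.concatenate(masked, axis=-1) @ W_O


full = combine(heads, W_O, [1.0, 1.0, 1.0, 1.0])
pruned = combine(heads, W_O, [1.0, 0.0, 1.0, 1.0])  # prune head 1

# Each head feeds only its own d_v-row slice of W_O, so the difference
# between the two outputs is exactly head 1's contribution.
head1_term = heads[1] @ W_O[d_v:2 * d_v]
print(np.allclose(full - pruned, head1_term))  # True
```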

Computational Efficiency

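Multi-head attention has roughly the same computational cost as single-head attention with full dimensionality. The Q, K, V, and output projections are d_model × d_model either way, and the attention-score work is split evenly across heads: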

```text
# Single head with d_model = 512
Q, K, V each: (seq, 512)
Q @ K^T: O(seq² × 512)

# 8 heads with d_k = 64
Each head: Q, K, V are (seq, 64)
Per head Q @ K^T: O(seq² × 64)
8 heads total: O(seq² × 64 × 8) = O(seq² × 512)

# Same total computation! But more expressive.
```
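Key Insight: Splitting into heads doesn't increase the FLOP count, because each head works with lower-dimensional vectors. What it adds is expressiveness: each token gets h separate attention distributions instead of one, so the layer can attend to different positions for different reasons at the same time. (Memory for the attention weights does grow with h, since every head stores its own seq × seq matrix.)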
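A small pure-Python count makes the arithmetic explicit. It assumes seq = 1024 and d_model = 512 as example values and counts only the multiply-adds in Q @ K^T (the projections are the same size in both cases). The second helper counts entries in the attention-weight tensor, which, unlike the FLOPs, grows with the number of heads because each head keeps its own seq × seq matrix.

```python
def score_macs(seq, d_model, num_heads):
    """Multiply-adds for Q @ K^T, summed over all heads (projections excluded)."""
    d_k = d_model // num_heads
    per_head = seq * seq * d_k  # seq * seq dot products, each of length d_k
    return num_heads * per_head


def attn_weight_entries(seq, num_heads):
    """Entries in the attention-weight tensor: one seq x seq matrix per head."""
    return num_heads * seq * seq


seq, d_model = 1024, 512  # example values
single = score_macs(seq, d_model, num_heads=1)
multi = score_macs(seq, d_model, num_heads=8)
print(single, multi)  # 536870912 536870912
print(attn_weight_entries(seq, 1), attn_weight_entries(seq, 8))  # 1048576 8388608
```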

Common Configurations

| Model | d_model | Heads | d_k per head |
|---|---|---|---|
| BERT-Base | 768 | 12 | 64 |
| GPT-2 (small) | 768 | 12 | 64 |
| GPT-3 (175B) | 12288 | 96 | 128 |
| LLaMA-2 (7B) | 4096 | 32 | 128 |
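A quick sanity check of the table: d_k is simply d_model divided by the head count, and d_model must divide evenly for the reshape into heads to work. The dictionary below only restates the table's numbers.

```python
# (d_model, num_heads) pairs restated from the table above
configs = {
    "BERT-Base": (768, 12),
    "GPT-2": (768, 12),
    "GPT-3": (12288, 96),
    "LLaMA-2": (4096, 32),
}

d_k_per_model = {}
for name, (d_model, num_heads) in configs.items():
    # Heads must split d_model evenly, otherwise the reshape into heads fails
    assert d_model % num_heads == 0, f"{name}: d_model not divisible by num_heads"
    d_k_per_model[name] = d_model // num_heads
    print(f"{name}: d_k = {d_model} / {num_heads} = {d_k_per_model[name]}")
```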

Exercises

Exercise 1: Head Dimensions

If d_model = 1024 and we want 16 heads, what should d_k be?

Exercise 2: Output Shape

Multi-head attention takes input of shape (batch=2, seq=100, d_model=512) with 8 heads. What is the shape of the output?

Exercise 3: Why Split?

Why not just use single-head attention with the full d_model dimension? What do we gain by splitting into multiple heads?

Knowledge Check Quiz

Test Your Understanding

Question 1: What is the primary advantage of using multiple attention heads instead of a single attention mechanism?

Answer: Multiple heads can capture different types of relationships simultaneously (e.g., syntactic, semantic, coreference) that a single head cannot optimally represent all at once.

Question 2: If a model has d_model = 768 and num_heads = 12, what is d_k (the dimension per head)?

Answer: d_k = 768 / 12 = 64. Each head operates on 64-dimensional vectors.

Question 3: Why doesn't multi-head attention increase computational cost compared to single-head with the same total dimension?

Answer: The computation is split across heads—each head uses lower-dimensional vectors (d_k = d_model / h). The total FLOPs remain the same: O(seq² × d_model) regardless of head count.

Question 4: What operation combines the outputs from all attention heads back into a single representation?

Answer: The outputs are concatenated along the dimension axis, then passed through a final linear projection W_O (output weight matrix) to produce the final d_model-sized representation.

Question 5: Name two types of linguistic relationships that different attention heads might specialize in capturing.

Answer: (Any two of) positional/adjacent token relationships, pronoun coreference, subject-verb agreement, semantic similarity, adjective-noun modification, delimiter/punctuation tracking.