🚧 Lesson 2 of 10 in Level 03

Self-Attention Deep Dive

Understanding how tokens attend to each other. The complete mechanism with examples.

The Intuition

Self-attention allows each token to "look at" all other tokens in the sequence and decide which ones are relevant.

Key Idea: For each word, compute a new representation as a weighted average of the representations of all words in the sentence, itself included. The weights depend on how relevant each word is to the current word.

Example Sentence

"The cat sat on the mat because it was tired."

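When processing the word "it", what should it attend to? Mostly "cat": "it" refers to the cat, so the model needs information from that token to interpret "it" correctly.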

Self-attention learns how to compute these relevance scores during training; nobody hand-codes them.

The Three Vectors: Q, K, V

For each token, we create three vectors:

Query, Key, Value

  • Query (Q): "What am I looking for?"
  • Key (K): "What do I contain?"
  • Value (V): "What information do I provide?"
```python
# For each token embedding x:
Q = x @ W_Q  # What this token is looking for
K = x @ W_K  # What this token offers
V = x @ W_V  # The actual information
# W_Q, W_K, W_V are learned weight matrices
```
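The sketch below shows the three projections for a single token in NumPy. The sizes and the random weight matrices are assumptions chosen for illustration. In a trained model, W_Q, W_K and W_V are learned parameters, not random draws.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k, d_v = 4, 2, 2  # toy sizes (assumed for illustration)

# Hypothetical weights; a real model learns these during training
W_Q = rng.standard_normal((d_model, d_k))
W_K = rng.standard_normal((d_model, d_k))
W_V = rng.standard_normal((d_model, d_v))

x = np.array([0.0, 1.0, 0.0, 0.0])  # one (one-hot) token embedding

q = x @ W_Q  # query, shape (d_k,)
k = x @ W_K  # key, shape (d_k,)
v = x @ W_V  # value, shape (d_v,)
print(q.shape, k.shape, v.shape)
```

Because x is one-hot in this toy setup, each projection simply selects one row of its weight matrix. Real embeddings are dense, so every row of W contributes.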

Analogy: Database Lookup

Database Analogy:
• Query = Your search query
• Key = Database index/key
• Value = The actual data/content

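Attention = Match the query against every key, then retrieve a blend of the corresponding values. A database returns only the exact match. Attention is a soft lookup: every key matches to some degree, and the result is a weighted mix of all the values.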
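The sketch below makes the analogy concrete. It contrasts a database-style "hard" lookup, which returns only the best-matching value, with the soft lookup that attention performs. The function name `lookup` and the toy keys, values and query are hypothetical.

```python
import numpy as np

def lookup(query, keys, values):
    scores = keys @ query                    # similarity of the query to each key
    hard = values[np.argmax(scores)]         # database-style: only the best match
    weights = np.exp(scores - scores.max())  # attention-style: softmax over all keys
    weights /= weights.sum()
    soft = weights @ values                  # weighted blend of every value
    return hard, soft, weights

keys = np.array([[1.0, 0.0],
                 [0.0, 1.0]])
values = np.array([10.0, 20.0])
query = np.array([2.0, 0.0])                 # matches the first key best

hard, soft, weights = lookup(query, keys, values)
print(hard, round(float(soft), 2), weights.round(2))
```

The hard lookup returns 10.0. The soft lookup returns about 11.19: mostly the first value, with some of the second mixed in.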

Computing Attention Scores

Step 1: Dot Product (Similarity)

For each pair of tokens, compute how well the Query matches the Key:

```python
# Attention scores (before scaling and softmax)
scores = Q @ K.T
# For sequence length n, this creates an n×n matrix
# scores[i, j] = raw (unnormalized) score for how much token i should attend to token j
```
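A quick NumPy check of the shapes and of what one entry means. The sequence length and key size here are toy values assumed for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d_k = 5, 4                      # toy sequence length and key dimension (assumed)
Q = rng.standard_normal((n, d_k))  # one query row per token
K = rng.standard_normal((n, d_k))  # one key row per token

scores = Q @ K.T                   # (n, d_k) @ (d_k, n) -> (n, n)

# Entry [i, j] is the dot product of query i with key j
i, j = 1, 3
print(scores.shape, np.isclose(scores[i, j], Q[i] @ K[j]))
```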

Step 2: Scale

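Divide by √d_k so the scores do not grow with the key dimension. If the components of a query and a key are roughly independent with unit variance, their dot product has variance d_k. Large scores push softmax toward a nearly one-hot output, where gradients become very small.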

```python
import math

scores = scores / math.sqrt(d_k)
# d_k = dimension of keys (typically 64 or 128)
# Scaling keeps softmax out of its saturated regime, which helps gradients flow
```
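To see why the scaling matters, the sketch below estimates the variance of q · k for random vectors of several sizes, before and after dividing by √d_k. The standard-normal components are an assumption for illustration, and the helper `dot_product_variance` is hypothetical.

```python
import numpy as np

def dot_product_variance(d_k, n_samples=5_000, seed=0):
    """Sample variance of q·k, raw and scaled by 1/sqrt(d_k).

    Assumes q and k have independent standard-normal components.
    """
    rng = np.random.default_rng(seed)
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = (q * k).sum(axis=1)  # one dot product per sampled (q, k) pair
    return dots.var(), (dots / np.sqrt(d_k)).var()

for d_k in (4, 64, 512):
    raw, scaled = dot_product_variance(d_k)
    print(f"d_k={d_k:4d}  var(raw)={raw:8.1f}  var(scaled)={scaled:.2f}")
```

Under these assumptions, the raw variance grows roughly like d_k, while the scaled variance stays close to 1 for every size.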

Step 3: Softmax

Convert each row of scores into a probability distribution (non-negative weights that sum to 1):

```python
from scipy.special import softmax

attention_weights = softmax(scores, axis=-1)
# Each row sums to 1
# attention_weights[i, j] = share of token i's attention that goes to token j
```
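If no library softmax is at hand, here is a minimal NumPy version. It uses the common trick of subtracting each row's maximum first. Softmax is unchanged by adding a constant to every entry of a row, and the shift keeps `np.exp` from overflowing on large scores.

```python
import numpy as np

def softmax(scores, axis=-1):
    # Shift so the largest entry in each row is 0; the result is mathematically identical
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

probs = softmax(np.array([[1000.0, 1000.0],
                          [0.0, np.log(3.0)]]))
print(probs)
```

A naive `np.exp(1000.0)` overflows to infinity and would produce NaNs. The shifted version returns [0.5, 0.5] for the first row and [0.25, 0.75] for the second.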

Step 4: Weighted Sum

Multiply attention weights by Values:

```python
output = attention_weights @ V
# For each token: weighted average of all value vectors,
# based on that token's row of attention weights
```
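For a single query token, the matrix product is just a weighted average of the value rows. The weights below are hypothetical numbers chosen to sum to 1.

```python
import numpy as np

V = np.array([[1.0, 0.0],    # value vector of token 0
              [0.0, 1.0],    # value vector of token 1
              [1.0, 1.0]])   # value vector of token 2
w = np.array([0.5, 0.25, 0.25])  # hypothetical attention weights for one token

out = w @ V  # matrix form
# The same result written out as an explicit weighted sum
out_loop = sum(w[j] * V[j] for j in range(len(w)))
print(out, out_loop)
```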

Complete Example

Sentence: "The cat sat"

3 tokens: ["The", "cat", "sat"]

Step 1: Compute Q, K, V
```
# Embeddings (simplified, d_model=4)
The: [1, 0, 0, 0]
cat: [0, 1, 0, 0]
sat: [0, 0, 1, 0]

# After projection (d_k=2)
Q_The = [0.5, 0.2]
Q_cat = [0.1, 0.8]
Q_sat = [0.3, 0.4]

K_The = [0.6, 0.1]
K_cat = [0.2, 0.9]
K_sat = [0.4, 0.3]

V_The = [1, 0]
V_cat = [0, 1]
V_sat = [1, 1]
```
Step 2: Attention Scores (Q @ K^T)
```
        The    cat    sat
The   [0.32   0.28   0.26]
cat   [0.14   0.74   0.28]
sat   [0.22   0.42   0.24]
```
Step 3: After Scaling by √d_k = √2 and Softmax
```
        The    cat    sat
The   [0.34   0.33   0.33]
cat   [0.28   0.42   0.30]
sat   [0.32   0.36   0.32]
```

Notice: "cat" pays the most attention to itself (0.42), then to "sat" (0.30), and the least to "The" (0.28).
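The sketch below recomputes the worked example in NumPy from the Q, K and V values listed in Step 1. It includes the √d_k scaling and the final weighted sum (Step 4), which the example does not show. The small `softmax` helper is defined locally rather than taken from a library.

```python
import numpy as np

# Rows are "The", "cat", "sat" (projected vectors from Step 1, d_k = d_v = 2)
Q = np.array([[0.5, 0.2], [0.1, 0.8], [0.3, 0.4]])
K = np.array([[0.6, 0.1], [0.2, 0.9], [0.4, 0.3]])
V = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability shift
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

d_k = K.shape[-1]
raw_scores = Q @ K.T                          # unscaled dot products
weights = softmax(raw_scores / np.sqrt(d_k))  # scale, then softmax each row
output = weights @ V                          # weighted average of value vectors

np.set_printoptions(precision=2, suppress=True)
print(raw_scores)
print(weights)
print(output)
```

The output rows come out near [0.67, 0.66] for "The", [0.58, 0.72] for "cat" and [0.64, 0.68] for "sat". Each token's new vector leans toward the values of the tokens it attends to most.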

Matrix Form

In practice, we compute attention for the entire sequence at once:

```python
import math
import torch

def self_attention(X, W_Q, W_K, W_V):
    # Input: X (batch_size, seq_len, d_model)
    # Step 1: Project to Q, K, V
    Q = X @ W_Q  # (batch, seq, d_k)
    K = X @ W_K  # (batch, seq, d_k)
    V = X @ W_V  # (batch, seq, d_v)
    d_k = Q.size(-1)
    # Steps 2-4: Scaled dot-product attention
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (batch, seq, seq)
    attn_weights = torch.softmax(scores, dim=-1)       # (batch, seq, seq)
    output = attn_weights @ V                          # (batch, seq, d_v)
    # Final output has the same sequence length as the input
    return output
```
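Key Property: On its own (before any positional encodings are added to the inputs), self-attention is permutation-equivariant: if you reorder the input tokens, the outputs are reordered the same way. It also handles variable-length sequences, because none of the weight matrices depends on the sequence length.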
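Both properties can be checked numerically. The sketch below is a NumPy version of single-sequence self-attention with no batch dimension and no positional encoding; these are simplifying assumptions. The random weights and the permutation are arbitrary illustrative choices.

```python
import numpy as np

def self_attention(X, W_Q, W_K, W_V):
    # X: (seq_len, d_model); no batch dimension, no positional encoding (assumed)
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
d_model, d_k = 8, 4
W_Q, W_K, W_V = (rng.standard_normal((d_model, d_k)) for _ in range(3))

X = rng.standard_normal((5, d_model))
perm = np.array([3, 0, 4, 1, 2])

out = self_attention(X, W_Q, W_K, W_V)
out_perm = self_attention(X[perm], W_Q, W_K, W_V)
# Reordering the inputs reorders the outputs in the same way
print(np.allclose(out_perm, out[perm]))

# The same weights work for a sequence of a different length
print(self_attention(rng.standard_normal((12, d_model)), W_Q, W_K, W_V).shape)
```

Adding positional encodings to X before the projections is what breaks this symmetry and lets the model use word order.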

Why This Works

Properties of Self-Attention

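1. Long-range dependencies: Any token can directly attend to any other token, regardless of distance. The path between two tokens is a single step, so gradients do not have to pass through many recurrent steps, which is where vanishing gradients typically arise in RNNs.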

2. Parallel computation: An RNN must process tokens one after another, but all attention scores can be computed at once as matrix multiplications, which run efficiently on GPUs.

3. Interpretable (to a degree): We can visualize attention weights to see which tokens the model is "looking at," although the weights alone do not fully explain the model's predictions.

4. Content-based: Attention depends on the actual content of tokens, not just their position.

Exercises

Exercise 1: Attention Shape

For a sequence of length 100 with d_k = 64 (take d_v = 64 as well, and ignore the batch dimension), what are the shapes of:
- Q, K, V matrices?
- The attention score matrix?
- The output?

Exercise 2: Attention Weights

If attention_weights[5, 10] = 0.3, what does this mean in plain English?

Exercise 3: Complexity

What is the computational complexity of self-attention in terms of sequence length n? Why might this be a problem for very long sequences?