🚧 Lesson 7 of 10 in Level 03

Masked Attention

Causal masking for autoregressive generation. Preventing lookahead during training.

The Causal Mask

Decoder-only models need to prevent tokens from attending to future positions:

# Causal mask (look-ahead mask)
# For sequence length 4:
mask = [[1, 0, 0, 0],   # Position 0: can see position 0
        [1, 1, 0, 0],   # Position 1: can see positions 0-1
        [1, 1, 1, 0],   # Position 2: can see positions 0-2
        [1, 1, 1, 1]]   # Position 3: can see positions 0-3

Why? During training, we show the full target sequence. Without masking, the model could cheat by looking at future tokens. During generation, future tokens don't exist yet!
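
To make the "cheating" concrete, here is a minimal sketch in plain Python (standard library only, not PyTorch, so the arithmetic stays visible). The helper `attend` and the toy vectors are hypothetical and exist only for this illustration. It changes the last token of a 3-token sequence and checks whether the outputs at the two earlier positions change.

```python
import math

def attend(queries, keys, values, causal):
    """Single-head scaled dot-product attention over lists of vectors."""
    d_k = len(keys[0])
    outputs = []
    for i, q in enumerate(queries):
        # With causal=True, position i may only look at positions 0..i.
        visible = range(i + 1) if causal else range(len(keys))
        scores = [sum(a * b for a, b in zip(q, keys[j])) / math.sqrt(d_k)
                  for j in visible]
        peak = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        out = [sum(w * values[j][c] for w, j in zip(weights, visible))
               for c in range(len(values[0]))]
        outputs.append(out)
    return outputs

# Toy self-attention: the same vectors serve as Q, K, and V.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
x_changed = [[1.0, 0.0], [0.0, 1.0], [5.0, -3.0]]  # only the last token differs

# Compare the outputs at positions 0 and 1 before and after the change.
masked_same = (attend(x, x, x, True)[:2]
               == attend(x_changed, x_changed, x_changed, True)[:2])
unmasked_same = (attend(x, x, x, False)[:2]
                 == attend(x_changed, x_changed, x_changed, False)[:2])
print(masked_same, unmasked_same)  # True False
```

With the mask, earlier positions are untouched by the future token. Without it, information from the future leaks backward, which is exactly what the model could exploit during training.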

Implementation

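import math
import torch
import torch.nn.functional as F

# Create causal mask (lower triangular)
def create_causal_mask(seq_len):
    return torch.tril(torch.ones(seq_len, seq_len))

# Apply to attention scores; Q and K have shape (..., seq_len, d_k)
d_k = Q.size(-1)
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
mask = create_causal_mask(Q.size(-2))
scores = scores.masked_fill(mask == 0, float('-inf'))  # exp(-inf) = 0 after softmax
attn_weights = F.softmax(scores, dim=-1)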
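
Why does filling with -inf work? Here is a small sketch of the same mask-then-softmax step in plain Python (standard library only). The function names `causal_mask` and `masked_softmax` are made up for this example. Because exp(-inf) is 0, masked positions get a weight of exactly zero and the remaining weights still sum to 1.

```python
import math

def causal_mask(seq_len):
    """Lower-triangular mask: 1 where query i may attend to key j (j <= i)."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax(scores, mask):
    """Row-wise softmax after setting masked-out scores to -inf."""
    weights = []
    for score_row, mask_row in zip(scores, mask):
        filled = [s if m == 1 else float("-inf")
                  for s, m in zip(score_row, mask_row)]
        peak = max(filled)  # finite, because the diagonal is never masked
        exps = [math.exp(s - peak) for s in filled]  # exp(-inf) == 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Uniform scores make the effect easy to read: row i spreads its weight
# evenly over the i + 1 positions it is allowed to see.
scores = [[0.0] * 4 for _ in range(4)]
weights = masked_softmax(scores, causal_mask(4))
for row in weights:
    print([round(w, 2) for w in row])
# [1.0, 0.0, 0.0, 0.0]
# [0.5, 0.5, 0.0, 0.0]
# [0.33, 0.33, 0.33, 0.0]
# [0.25, 0.25, 0.25, 0.25]
```

Multiplying the weights by the mask after softmax would not be equivalent: the rows would no longer sum to 1.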

Training vs Inference

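During training, we process the whole sequence in one forward pass, and the mask stops each position from seeing later ones. During inference, we generate one token at a time: the newest token's query attends only to tokens that already exist, so there are no future positions to mask. (If the whole prefix is re-run at every step without a cache, the mask is still needed for the earlier positions.)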
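
The two modes compute the same thing, which a short sketch can check. This is plain Python with hypothetical helper names (`attend_masked`, `attend_prefix`) and toy vectors, with learned projections left out for brevity. Row t of the masked full-sequence pass should match unmasked attention over the prefix 0..t.

```python
import math

def softmax(xs):
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(p * q for p, q in zip(a, b))

def attend_masked(x):
    """Training-style: score every query against every key, then mask."""
    n, scale = len(x), math.sqrt(len(x[0]))
    outputs = []
    for i in range(n):
        scores = [dot(x[i], x[j]) / scale if j <= i else float("-inf")
                  for j in range(n)]
        weights = softmax(scores)
        outputs.append([sum(w * x[j][c] for j, w in enumerate(weights))
                        for c in range(len(x[0]))])
    return outputs

def attend_prefix(x, t):
    """Inference-style: token t is the newest one, so only x[0..t] exists."""
    prefix, scale = x[: t + 1], math.sqrt(len(x[0]))
    weights = softmax([dot(x[t], k) / scale for k in prefix])
    return [sum(w * k[c] for w, k in zip(weights, prefix))
            for c in range(len(x[0]))]

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # toy token vectors
full = attend_masked(x)
matches = all(
    math.isclose(a, b, abs_tol=1e-12)
    for t in range(len(x))
    for a, b in zip(full[t], attend_prefix(x, t))
)
print(matches)  # True
```

This equivalence is why one masked pass over a length-n sequence yields n training examples at once.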

Want to go deeper? Check out the Causal Masking Deep Dive for interactive visualizations, complete PyTorch implementations, and hands-on exercises.

Practical Examples

Example 1: Text Generation with Causal Masking

When generating text token-by-token, each new token can only attend to itself and to previously generated tokens:

# Generating "Hello world"
step_1: ["Hello"]                → attend to ["Hello"] only
step_2: ["Hello", " world"]      → "world" attends to ["Hello", " world"]
step_3: ["Hello", " world", "!"] → "!" attends to all previous tokens

# The causal mask enforces this during training
# so the model learns to predict without peeking ahead

Example 2: GPT-2 Attention Pattern Visualization

Here's how attention weights might look with causal masking for a 5-token sequence:

# Attention weight matrix (5x5) with causal mask applied (illustrative values)
# Rows = queries (current position), Columns = keys (positions attended to)

        t0    t1    t2    t3    t4
     ┌──────────────────────────────┐
t0   │ 1.00  0.00  0.00  0.00  0.00 │ ← Position 0: only sees itself
t1   │ 0.42  0.58  0.00  0.00  0.00 │ ← Position 1: sees t0, t1
t2   │ 0.23  0.31  0.46  0.00  0.00 │ ← Position 2: sees t0, t1, t2
t3   │ 0.17  0.17  0.25  0.41  0.00 │ ← Position 3: sees t0-t3
t4   │ 0.15  0.15  0.16  0.23  0.31 │ ← Position 4: sees all positions
     └──────────────────────────────┘

# Upper triangle is zero (masked out)
# Each row sums to 1, because softmax normalizes over the visible positions
# Lower triangle shows learned attention patterns

Example 3: Efficient KV Cache with Causal Attention

During inference, we cache Key and Value tensors to avoid recomputation:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    # Single-head version for clarity; this __init__ is an assumed minimal setup
    def __init__(self, n_embd):
        super().__init__()
        self.n_embd = n_embd
        self.qkv = nn.Linear(n_embd, 3 * n_embd)

    def forward(self, x, past_kv=None):
        B, T, C = x.size()

        # Compute Q, K, V for current tokens
        q, k, v = self.qkv(x).split(self.n_embd, dim=2)

        # Prepend cached keys and values if provided
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        S = k.size(1)  # total key length = cached + new

        attn = (q @ k.transpose(-2, -1)) / math.sqrt(k.size(-1))

        # New query i sits at absolute position (S - T) + i, so it may attend
        # to keys j <= (S - T) + i. With no cache, S == T and this is the usual
        # square lower-triangular mask used in training.
        causal_mask = torch.tril(
            torch.ones(T, S, dtype=torch.bool, device=x.device), diagonal=S - T
        )
        attn = attn.masked_fill(~causal_mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)

        return attn @ v, (k, v)  # Return output and updated cache
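
One detail of cached attention deserves a closer look. When the T new queries are the last T of S total positions, the mask is rectangular and every row must be shifted right by the number of cached positions. The sketch below builds that mask in plain Python; `cached_causal_mask` is a hypothetical helper name. In PyTorch the same shape can be produced with `torch.tril(torch.ones(T, S), diagonal=S - T)`.

```python
def cached_causal_mask(num_new, num_total):
    """Mask of shape (num_new, num_total) for queries that are the LAST
    num_new positions of a sequence with num_total keys."""
    offset = num_total - num_new  # number of cached (already seen) positions
    return [[1 if j <= offset + i else 0 for j in range(num_total)]
            for i in range(num_new)]

# One new token, three cached: the single query sees everything.
print(cached_causal_mask(1, 4))  # [[1, 1, 1, 1]]

# Two new tokens, three cached: only the first new query is blocked
# from the second new key.
for row in cached_causal_mask(2, 5):
    print(row)
# [1, 1, 1, 1, 0]
# [1, 1, 1, 1, 1]
```

An unshifted lower-triangular mask of shape (T, S) would instead hide most of the cached keys from the new queries, which defeats the purpose of the cache.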
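Key Insight: The KV cache removes redundant work at each generation step. Keys and values are projected only for the new token (O(1) projections instead of O(n)), and attention for that step costs O(n) instead of the O(n²) needed to recompute the whole prefix. Causal masking ensures we never attend to "future" positions that haven't been generated yet.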
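
To put rough numbers on the saving, here is a simple counting model in plain Python. It is a sketch under stated assumptions, not a benchmark: `generation_cost` is a hypothetical helper, and the uncached variant is assumed to re-project the whole prefix and recompute full causal attention at every step.

```python
def generation_cost(num_tokens, use_cache):
    """Count K/V projections and query-key dot products over a generation."""
    projections = 0
    dot_products = 0
    for step in range(1, num_tokens + 1):  # step = current sequence length
        if use_cache:
            projections += 1       # project only the newest token
            dot_products += step   # one query against `step` keys
        else:
            projections += step    # re-project the whole prefix
            # Full causal attention over the prefix: row i has i + 1 entries.
            dot_products += step * (step + 1) // 2
    return projections, dot_products

print(generation_cost(4, use_cache=True))   # (4, 10)
print(generation_cost(4, use_cache=False))  # (10, 20)
```

The gap widens quickly with sequence length, at the price of the memory needed to hold the cached keys and values.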