The Causal Mask
Decoder-only models need to prevent tokens from attending to future positions:
# Causal mask (look-ahead mask)
# For sequence length 4:
mask = [[1, 0, 0, 0],  # Position 0: can see position 0
        [1, 1, 0, 0],  # Position 1: can see positions 0-1
        [1, 1, 1, 0],  # Position 2: can see positions 0-2
        [1, 1, 1, 1]]  # Position 3: can see positions 0-3
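Why? During training, the model sees the full target sequence and predicts the next token at every position in parallel. Without masking, position i could attend to token i+1, which is exactly the token it is being trained to predict, so the model could cheat.
During generation, future tokens don't exist yet, so a model trained without the mask would rely on information it never has at inference time.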
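The leak can be observed directly. The following is a minimal sketch, assuming a toy single-head attention with identity projections (Q = K = V = x); the helper name `attention` is hypothetical and chosen only for this illustration. It changes the last token and checks whether the outputs at earlier positions move.

```python
import math
import torch

def attention(x, causal):
    """Toy single-head self-attention with identity projections (Q = K = V = x)."""
    seq_len, d_k = x.shape
    scores = x @ x.T / math.sqrt(d_k)
    if causal:
        mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
        scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ x

torch.manual_seed(0)
x = torch.randn(4, 8)
x_changed = x.clone()
x_changed[3] = torch.randn(8)  # alter only the last (future) token

# With the mask, positions 0-2 are unaffected by the change at position 3.
causal_same = torch.allclose(attention(x, True)[:3], attention(x_changed, True)[:3])
# Without the mask, the change leaks into the earlier positions.
full_same = torch.allclose(attention(x, False)[:3], attention(x_changed, False)[:3])
print("causal outputs unchanged:", causal_same)
print("unmasked outputs unchanged:", full_same)
```

With the mask, positions 0-2 give the same output whatever token sits at position 3; without it, they do not.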
Implementation
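import math
import torch

# Create causal mask
def create_causal_mask(seq_len):
    # Lower triangular: entry (i, j) is 1 when j <= i
    return torch.tril(torch.ones(seq_len, seq_len))

# Apply to attention scores; Q and K have shape (..., seq_len, d_k)
d_k = Q.size(-1)
mask = create_causal_mask(Q.size(-2))
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = torch.softmax(scores, dim=-1)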
Training vs Inference
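During training, we process the whole sequence at once, and the mask stops each position from attending to later ones. During inference, we generate one token at a time: the newest token is the only query, and every key it can see belongs to the current or an earlier position, so there is no future position left to mask. A mask is still needed whenever several tokens go through in a single pass, for example when encoding the prompt.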
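The two modes compute the same thing for the newest position. As a minimal sketch, assuming random Q, K, and V tensors in place of real projections, the last row of masked full-sequence attention can be compared with a single unmasked query over all keys:

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_k = 5, 8
Q, K, V = (torch.randn(seq_len, d_k) for _ in range(3))

# Training-style: all positions at once, with a causal mask.
scores = Q @ K.T / math.sqrt(d_k)
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
full = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1) @ V

# Inference-style: only the newest query, attending to every key so far, no mask.
q_last = Q[-1:]
step = torch.softmax(q_last @ K.T / math.sqrt(d_k), dim=-1) @ V

matches = torch.allclose(full[-1:], step, atol=1e-6)
print("last position matches:", matches)
```

The last row of the causal mask is all ones, which is why the single-query case needs no explicit mask.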
Want to go deeper? Check out the Causal Masking Deep Dive for interactive visualizations, complete PyTorch implementations, and hands-on exercises.
Practical Examples
Example 1: Text Generation with Causal Masking
When generating text token-by-token, each new token can only attend to previously generated tokens:
# Generating "Hello world!"
step_1: ["Hello"]                → "Hello" attends to ["Hello"] only
step_2: ["Hello", " world"]      → " world" attends to ["Hello", " world"]
step_3: ["Hello", " world", "!"] → "!" attends to ["Hello", " world", "!"]
# The causal mask enforces this during training
# so the model learns to predict without peeking ahead
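The same visibility pattern can be read straight off a lower-triangular mask. This is a small sketch, assuming the three-token split used in the steps above (a real tokenizer may split the text differently):

```python
import torch

tokens = ["Hello", " world", "!"]  # assumed tokenization, for illustration
mask = torch.tril(torch.ones(len(tokens), len(tokens), dtype=torch.bool))

visible = {}
for i, tok in enumerate(tokens):
    # Row i of the mask marks which positions token i may attend to.
    visible[tok] = [tokens[j] for j in range(len(tokens)) if mask[i, j]]
    print(f"{tok!r} attends to {visible[tok]}")
```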
Example 2: GPT-2 Attention Pattern Visualization
Here's how attention weights look with causal masking for a 5-token sequence:
# Attention weight matrix (5x5) with causal mask applied (illustrative values)
# Rows = queries (current position), Columns = keys (positions attended to)
       t0    t1    t2    t3    t4
    ┌──────────────────────────────┐
t0  │ 1.00  0.00  0.00  0.00  0.00 │ ← Position 0: only sees itself
t1  │ 0.40  0.60  0.00  0.00  0.00 │ ← Position 1: sees t0, t1
t2  │ 0.25  0.30  0.45  0.00  0.00 │ ← Position 2: sees t0, t1, t2
t3  │ 0.15  0.15  0.25  0.45  0.00 │ ← Position 3: sees t0-t3
t4  │ 0.15  0.15  0.15  0.25  0.30 │ ← Position 4: sees all positions
    └──────────────────────────────┘
# Upper triangle is zero (masked out)
# Each row sums to 1, because softmax normalizes over the visible positions
# Lower triangle shows learned attention patterns
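A matrix of this kind can be produced in a few lines. The sketch below uses random scores as a stand-in for Q @ K^T / sqrt(d_k), so the numbers are not GPT-2's; it only checks two structural properties of a causally masked softmax: the upper triangle is exactly zero, and every row sums to 1.

```python
import torch

torch.manual_seed(0)
seq_len = 5
scores = torch.randn(seq_len, seq_len)  # stand-in for real attention scores

mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
weights = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)

row_sums = weights.sum(dim=-1)             # softmax normalizes each row
upper = torch.triu(weights, diagonal=1)    # strictly-upper part, should be all zeros
print(weights)
print("row sums:", row_sums)
```

Position 0 can only attend to itself, so its single visible weight is 1.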
Example 3: Efficient KV Cache with Causal Attention
During inference, we cache Key and Value tensors to avoid recomputation:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    # Assumed setup (the original shows only forward()): single head, fused QKV projection
    def __init__(self, n_embd):
        super().__init__()
        self.n_embd = n_embd
        self.qkv = nn.Linear(n_embd, 3 * n_embd)

    def forward(self, x, past_kv=None):
        B, T, C = x.size()
        # Compute Q, K, V for current tokens
        q, k, v = self.qkv(x).split(self.n_embd, dim=2)
        # Prepend cached keys/values if provided
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        S = k.size(1)  # total key length = cached + new
        attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Causal mask: the i-th new query sits at absolute position S - T + i
        # and may attend to keys j <= S - T + i. With no cache, S == T and this
        # reduces to the ordinary lower-triangular mask.
        causal_mask = torch.tril(
            torch.ones(T, S, dtype=torch.bool, device=x.device), diagonal=S - T
        )
        attn = attn.masked_fill(~causal_mask, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        return attn @ v, (k, v)  # Return output and updated cache
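Key Insight: With a KV cache, each step computes Q, K, and V for the new token only instead of re-encoding the whole prefix. Attention for that step still scans all n cached keys, so the per-step attention cost is O(n) rather than the O(n²) of recomputing the full sequence.
Causal masking is what makes the cache valid: earlier positions never attend to later ones, so their keys and values stay the same as new tokens are appended.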
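One way to check a cache implementation is to compare token-by-token decoding against a single full pass. The sketch below is a self-contained illustration, not a production layer: `TinyCausalAttention` is a hypothetical single-head module written for this check, and the sizes are arbitrary.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalAttention(nn.Module):
    """Hypothetical single-head layer, for illustration only."""
    def __init__(self, n_embd):
        super().__init__()
        self.n_embd = n_embd
        self.qkv = nn.Linear(n_embd, 3 * n_embd)

    def forward(self, x, past_kv=None):
        T = x.size(1)
        q, k, v = self.qkv(x).split(self.n_embd, dim=2)
        if past_kv is not None:
            k = torch.cat([past_kv[0], k], dim=1)
            v = torch.cat([past_kv[1], v], dim=1)
        S = k.size(1)
        att = q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))
        # Query i sits at absolute position S - T + i, hence the diagonal offset.
        mask = torch.tril(torch.ones(T, S, dtype=torch.bool), diagonal=S - T)
        att = F.softmax(att.masked_fill(~mask, float("-inf")), dim=-1)
        return att @ v, (k, v)

torch.manual_seed(0)
layer = TinyCausalAttention(n_embd=16).eval()
x = torch.randn(1, 6, 16)

with torch.no_grad():
    full_out, _ = layer(x)              # all 6 positions in one pass
    cache, steps = None, []
    for t in range(6):                  # one token at a time, reusing the cache
        out, cache = layer(x[:, t:t + 1], past_kv=cache)
        steps.append(out)
    cached_out = torch.cat(steps, dim=1)

max_diff = (full_out - cached_out).abs().max().item()
print(f"max difference: {max_diff:.2e}")
```

The two outputs should agree up to floating-point error. A mask built without the `diagonal=S - T` offset would break this check, because each new query would then see only the first cached key.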