Understanding why autoregressive models need masks, with interactive visualizations and PyTorch implementations.
Why Causal Masking Matters
When training a language model to predict the next token, we have a fundamental constraint: the model must not see future tokens. This is analogous to how humans generate text: we write one word at a time, without knowing what we'll write next.
The Core Principle: At position i, the model can only attend to positions j where j ≤ i. This is called causal or autoregressive attention.
Visualizing the Causal Mask
For a sequence of 5 tokens, the causal mask looks like this:
Causal Mask Matrix (rows: query position i; columns: key position j)

        t₀   t₁   t₂   t₃   t₄
  t₀    1    0    0    0    0
  t₁    1    1    0    0    0
  t₂    1    1    1    0    0
  t₃    1    1    1    1    0
  t₄    1    1    1    1    1

1 = allowed, 0 = blocked
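The matrix above can be built in one call. A minimal sketch, assuming the same convention (row = querying position i, column = attended position j); causal_mask is a name chosen here, not a library function. torch.tril keeps the lower triangle including the diagonal:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Lower-triangular boolean matrix: entry (i, j) is True when j <= i,
    # i.e. query position i may attend to key position j.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

mask = causal_mask(5)
print(mask.int())  # the same 1/0 pattern as the table above
```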
Mathematical Formulation
The causal mask is applied to the attention scores before the softmax:
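$$M_{ij} = \begin{cases} 1 & \text{if } j \le i \\ 0 & \text{otherwise} \end{cases}$$

$$S_{ij} = \begin{cases} \dfrac{(QK^\top)_{ij}}{\sqrt{d_k}} & \text{if } M_{ij} = 1 \\ -\infty & \text{if } M_{ij} = 0 \end{cases} \qquad \mathrm{Attention}(Q, K, V) = \mathrm{softmax}(S)\,V$$

The softmax is taken row-wise (over j). Positions where $M_{ij} = 0$ are set to $-\infty$ before the softmax; since $e^{-\infty} = 0$, their attention weight is exactly 0. The mask cannot simply be multiplied element-wise into $QK^\top$: a score zeroed out that way still contributes $e^0 = 1$ to the softmax, so future tokens would keep a nonzero weight.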
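The masked softmax can be written directly in PyTorch. This is a minimal sketch, not the article's own implementation: causal_attention is a name chosen here, and Q, K, V are assumed to have shape (..., seq_len, d_k). masked_fill writes -∞ wherever the causal matrix is 0, exactly as in the definition of S:

```python
import math
import torch

def causal_attention(Q, K, V):
    """Scaled dot-product attention with a causal mask (manual sketch).

    Q, K, V: tensors of shape (..., seq_len, d_k).
    Returns the attention output and the attention weights.
    """
    d_k = Q.size(-1)
    seq_len = Q.size(-2)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (..., seq_len, seq_len)
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=Q.device))
    # Blocked positions (j > i) become -inf, so softmax gives them weight 0.
    scores = scores.masked_fill(~allowed, float("-inf"))
    weights = torch.softmax(scores, dim=-1)             # row-wise softmax over j
    return weights @ V, weights

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 5, 8) for _ in range(3))
out, weights = causal_attention(Q, K, V)
print(torch.triu(weights, diagonal=1).abs().max())  # weights above the diagonal are exactly 0
print(weights.sum(dim=-1))                          # each row sums to 1 (up to rounding)
```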
Method 3: Efficient Implementation with torch.nn.functional.scaled_dot_product_attention
import torch.nn.functional as F

def efficient_causal_attention(Q, K, V, dropout_p=0.0):
    """
    Efficient causal attention using PyTorch's fused implementation.
    Q, K, V: tensors of shape (batch, heads, seq_len, head_dim).
    Requires PyTorch 2.0+; may dispatch to a FlashAttention kernel.
    """
    # is_causal=True builds and applies the causal mask internally
    output = F.scaled_dot_product_attention(
        Q, K, V,
        attn_mask=None,       # must be None when is_causal=True
        dropout_p=dropout_p,  # dropout is always applied if > 0, so pass 0.0 at inference
        is_causal=True,       # enable causal masking
    )
    return output
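Performance Tip: In PyTorch 2.0+, scaled_dot_product_attention with is_causal=True can dispatch to a fused kernel (FlashAttention or a memory-efficient attention kernel) when the device, dtype, and input shapes allow it, and otherwise falls back to a plain math implementation. The fused kernels avoid materializing the full seq_len × seq_len attention matrix, which saves memory and is usually faster on long sequences.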
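As a sanity check, the fused call should agree with the explicit mask-then-softmax computation up to floating-point tolerance. A sketch, assuming PyTorch 2.0 or later on CPU in float32; the tensor sizes are arbitrary:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 6, 16
Q = torch.randn(batch, heads, seq_len, head_dim)
K = torch.randn(batch, heads, seq_len, head_dim)
V = torch.randn(batch, heads, seq_len, head_dim)

# Fused path: PyTorch builds and applies the causal mask internally.
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# Reference path: explicit scores, -inf above the diagonal, softmax, weighted sum.
scores = Q @ K.transpose(-2, -1) / math.sqrt(head_dim)
blocked = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
reference = torch.softmax(scores.masked_fill(blocked, float("-inf")), dim=-1) @ V

print(torch.allclose(fused, reference, atol=1e-5))
```

Note the opposite boolean conventions: here `blocked` is True where attention is hidden, whereas a boolean attn_mask passed to scaled_dot_product_attention uses True for positions that take part in attention.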
Attention Pattern Visualization
Let's visualize what the attention pattern might look like for the sentence "The cat sat on the mat" (first five tokens, illustrative weights):
Example Attention Weights (Causal)

  Query \ Key   The    cat    sat    on     the
  The           1.0    0      0      0      0
  cat           0.4    0.6    0      0      0
  sat           0.2    0.3    0.5    0      0
  on            0.1    0.15   0.25   0.5    0
  the           0.08   0.12   0.2    0.3    0.3
Each row shows how one token distributes its attention over itself and the tokens before it. Every row sums to 1, and all weights above the diagonal are 0, which produces the triangular pattern.
Training vs. Inference
During Training
We use teacher forcing with causal masking:
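import torch.nn.functional as F

# Training: process the entire sequence at once with the causal mask
def train_step(model, input_ids, target_ids):
    """
    input_ids:  [The, cat, sat, on, the, mat]
    target_ids: [cat, sat, on, the, mat, <eos>]   (inputs shifted left by one)
    Both are (batch, seq_len) LongTensors; model returns (batch, seq_len, vocab_size) logits.
    """
    # Forward pass; the model's causal mask lets each position see only itself and earlier positions
    logits = model(input_ids)
    # Compute the loss for all positions simultaneously
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # (batch * seq_len, vocab_size)
        target_ids.reshape(-1),               # (batch * seq_len,)
    )
    return loss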
This is efficient because we compute all positions in parallel while maintaining the autoregressive property.
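To see the autoregressive property concretely, here is a sketch with a hypothetical one-layer, single-head model (TinyCausalLM is defined here for illustration; it is not from the article or from PyTorch). Editing the token at one position leaves the logits at every earlier position unchanged, which is what makes it safe to score all positions in one parallel pass:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyCausalLM(nn.Module):
    """Hypothetical one-layer, single-head model, used only to illustrate masking."""

    def __init__(self, vocab_size=20, d_model=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):                  # (batch, seq_len)
        x = self.embed(input_ids)                  # (batch, seq_len, d_model)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        h = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(h)                         # (batch, seq_len, vocab_size)

torch.manual_seed(0)
model = TinyCausalLM().eval()
ids_a = torch.tensor([[1, 2, 3, 4, 5, 6]])
ids_b = ids_a.clone()
ids_b[0, 4] = 7                                    # edit only the token at position 4

with torch.no_grad():
    logits_a, logits_b = model(ids_a), model(ids_b)

# Positions 0-3 cannot see position 4, so their logits are unchanged;
# positions 4-5 can, so theirs differ.
print(torch.allclose(logits_a[:, :4], logits_b[:, :4]))
print(torch.allclose(logits_a[:, 4:], logits_b[:, 4:]))
```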
During Inference
We generate one token at a time:
import torch

# Inference: generate tokens one at a time
# tokenize, detokenize and EOS are placeholders for your tokenizer
@torch.no_grad()
def generate(model, prompt, max_new_tokens=50):
    tokens = tokenize(prompt)  # e.g. [The, cat] as a list of token ids
    for _ in range(max_new_tokens):
        # Forward pass over the whole prefix; the model still applies its causal mask
        logits = model(torch.tensor([tokens]))  # (1, seq_len, vocab_size)
        # Greedy choice: only the last position's prediction is needed
        next_token = logits[0, -1].argmax().item()
        # Append and continue
        tokens.append(next_token)
        if next_token == EOS:
            break
    return detokenize(tokens)
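Key Insight: During inference, the newest token has no future tokens to hide. When the model is re-run over the whole prefix, as in the loop above, it still applies its causal mask so that every earlier position is computed exactly as during training. The KV cache optimization stores the keys and values of previous positions so they are not recomputed; each step then processes only the newest token, whose query may attend to every cached key, so that step needs no explicit mask.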
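A minimal sketch of that equivalence, assuming random tensors stand in for the per-token query/key/value projections a real model would compute (the list-based cache is a simplification, not any library's KV cache API). Decoding one query at a time against the growing cache reproduces the full causal pass:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d = 6, 8
q_all = torch.randn(1, seq_len, d)
k_all = torch.randn(1, seq_len, d)
v_all = torch.randn(1, seq_len, d)

# Full-sequence pass, as in training: one call with the causal mask.
full = F.scaled_dot_product_attention(q_all, k_all, v_all, is_causal=True)

# Incremental decoding with a minimal KV cache: at step t only the newest
# query is used, and it may attend to every cached key, so no mask is passed.
k_cache, v_cache, steps = [], [], []
for t in range(seq_len):
    k_cache.append(k_all[:, t : t + 1])
    v_cache.append(v_all[:, t : t + 1])
    K = torch.cat(k_cache, dim=1)                  # (1, t + 1, d)
    V = torch.cat(v_cache, dim=1)                  # (1, t + 1, d)
    q_t = q_all[:, t : t + 1]                      # (1, 1, d)
    steps.append(F.scaled_dot_product_attention(q_t, K, V))
incremental = torch.cat(steps, dim=1)              # (1, seq_len, d)

print(torch.allclose(full, incremental, atol=1e-5))
```

Leaving out is_causal=True in the decode step is deliberate: as we read the PyTorch documentation, for a non-square (1 × (t+1)) score matrix the built-in causal mask is aligned to the top-left, which would let the single query see only the first key.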
Advanced: Other Mask Types
Padding Mask
When batching sequences of different lengths, we pad shorter sequences. The padding mask prevents attention to padding tokens:
import torch

def create_padding_mask(seq, pad_token_id=0):
    """
    Create a mask for padding tokens.
    Args:
        seq: (batch, seq_len) tensor of token ids
    Returns:
        mask: (batch, 1, 1, seq_len) boolean mask, broadcastable over heads and query positions
    """
    # True for real tokens, False for padding
    mask = (seq != pad_token_id).unsqueeze(1).unsqueeze(2)
    return mask  # Shape: (batch, 1, 1, seq_len)

# Combined with the causal mask
def create_combined_mask(seq, pad_token_id=0):
    seq_len = seq.size(1)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq.device))
    padding = (seq != pad_token_id).unsqueeze(1).expand(-1, seq_len, -1)  # (batch, seq_len, seq_len)
    return causal & padding  # True only where both conditions hold; shape (batch, seq_len, seq_len)
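A sketch of how such a combined boolean mask might be fed to F.scaled_dot_product_attention, assuming right-padded batches and Q, K, V of shape (batch, heads, seq_len, head_dim); the token ids and sizes are arbitrary. The padding part is kept at shape (batch, 1, 1, seq_len) and combined with the causal matrix by broadcasting, and a boolean attn_mask uses True for positions that take part in attention:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
pad_token_id = 0
# Two right-padded sequences of true length 4 and 2.
seq = torch.tensor([[5, 3, 8, 2],
                    [7, 4, 0, 0]])
batch, seq_len = seq.shape
heads, head_dim = 2, 8

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # (seq_len, seq_len)
not_pad = (seq != pad_token_id)[:, None, None, :]                    # (batch, 1, 1, seq_len)
mask = causal & not_pad                                              # (batch, 1, seq_len, seq_len)

Q, K, V = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)

# Replace the keys/values at the second sequence's padded slots with new noise:
# since no query may attend to them, that sequence's output must not change.
K2, V2 = K.clone(), V.clone()
K2[1, :, 2:] = torch.randn(heads, 2, head_dim)
V2[1, :, 2:] = torch.randn(heads, 2, head_dim)
out2 = F.scaled_dot_product_attention(Q, K2, V2, attn_mask=mask)
print(torch.allclose(out[1], out2[1]))
```

With left padding instead, the first query rows of a padded sequence would have every key blocked, leaving the softmax undefined (typically NaN); right padding avoids this for causal models as long as each sequence starts with a real token.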
Sliding Window (Local) Attention
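Some models restrict each token to a fixed-size window of the most recent tokens, which caps the number of keys each query can attend to and reduces computation: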
import torch

def create_sliding_window_mask(seq_len, window_size=256):
    """Causal mask where each position attends to itself and the previous window_size - 1 tokens."""
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size + 1)
        mask[i, start:i + 1] = 1
    return mask
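The Python loop above runs once per row. A vectorized sketch that should produce the same pattern (as a boolean matrix) builds it from position differences; sliding_window_mask is a name chosen here for illustration:

```python
import torch

def sliding_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    """True where query i may attend to key j, i.e. 0 <= i - j < window_size."""
    i = torch.arange(seq_len).unsqueeze(1)   # query positions, shape (seq_len, 1)
    j = torch.arange(seq_len).unsqueeze(0)   # key positions, shape (1, seq_len)
    offset = i - j                           # how far back key j lies from query i
    return (offset >= 0) & (offset < window_size)

print(sliding_window_mask(5, 2).int())       # each row keeps at most 2 ones: itself and one predecessor
# When window_size >= seq_len, this reduces to the ordinary causal mask.
print(torch.equal(sliding_window_mask(6, 6), torch.tril(torch.ones(6, 6, dtype=torch.bool))))
```

A dense mask like this only controls which positions are visible; the actual compute savings come from attention kernels that skip the blocked region.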
Interactive Exercises
Exercise 1: Mask Shape
For a sequence of length 10 with batch size 4 and 8 attention heads, what is the shape of the causal mask?
(10, 10) - The mask is the same for all batches and heads
(4, 8, 10, 10)
(4, 10, 10)
Exercise 2: Attention Pattern
In a 4-token sequence, which positions can the token at position 2 (0-indexed) attend to?
Positions 0, 1, and 2
Positions 0 and 1 only
Positions 0, 1, 2, and 3
Exercise 3: Implementation
Complete the following code to apply the causal mask: