📚 Supplementary Lesson for Level 03

Causal Masking Deep Dive

Understanding why autoregressive models need masks, with interactive visualizations and PyTorch implementations.

Why Causal Masking Matters

When training a language model to predict the next token, we have a fundamental constraint: the model must not see future tokens. This is analogous to how humans generate text: we write one word at a time, without knowing what we'll write next.

The Core Principle: At position i, the model can only attend to positions j where j ≤ i. This is called causal or autoregressive attention.
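The principle can be checked numerically. The sketch below is a minimal illustration, not a full attention layer: it assumes a single head with Q = K = V = x and no learned projections, and the helper name `causal_attention` is made up for this example. It replaces the tokens at positions 3 and 4 and confirms that the outputs at positions 0 to 2 do not change.

```python
import math
import torch

torch.manual_seed(0)

def causal_attention(x):
    """Single-head causal self-attention with Q = K = V = x (no learned weights)."""
    seq_len, d = x.shape
    scores = x @ x.T / math.sqrt(d)
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    scores = scores.masked_fill(~allowed, float("-inf"))
    return torch.softmax(scores, dim=-1) @ x

x = torch.randn(5, 8)
x_changed = x.clone()
x_changed[3:] = torch.randn(2, 8)   # replace the "future" tokens at positions 3 and 4

out = causal_attention(x)
out_changed = causal_attention(x_changed)

# Positions 0..2 never attend to positions 3..4, so their outputs are unchanged.
prefix_unchanged = torch.allclose(out[:3], out_changed[:3])
suffix_changed = not torch.allclose(out[3:], out_changed[3:])
print(prefix_unchanged, suffix_changed)
```

Without the mask, changing any token would change every output row, and the model could read the answer it is supposed to predict.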

Visualizing the Causal Mask

For a sequence of 5 tokens, the causal mask looks like this:

Causal Mask Matrix (rows: query position i, columns: key position j)

      t₀   t₁   t₂   t₃   t₄
t₀    1    0    0    0    0
t₁    1    1    0    0    0
t₂    1    1    1    0    0
t₃    1    1    1    1    0
t₄    1    1    1    1    1

1 = Allowed (green), 0 = Blocked (red)

Mathematical Formulation

The causal mask is applied to the attention scores before the softmax:

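$$M_{ij} = \begin{cases} 1 & \text{if } j \le i \\ 0 & \text{otherwise} \end{cases}$$

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left( \text{mask}\!\left( \frac{QK^\top}{\sqrt{d_k}},\, M \right) \right) V$$

Here mask(S, M) leaves S_{ij} unchanged when M_{ij} = 1 and replaces it with -∞ when M_{ij} = 0. After the softmax, those positions receive an attention weight of exactly 0. Note that the mask is not applied by element-wise multiplication (S ⊙ M): that would set blocked scores to 0 rather than -∞, and since exp(0) = 1 the blocked positions would still receive attention.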
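A small numeric check makes the role of -∞ concrete. The sketch below uses made-up scores for a single query that may see keys 0 and 1 but not key 2, and compares filling the blocked score with -∞ against multiplying the scores by a 0/1 mask.

```python
import torch

scores = torch.tensor([2.0, 1.0, 3.0])      # scores for keys j = 0, 1, 2
mask = torch.tensor([1.0, 1.0, 0.0])        # the query may see j = 0 and j = 1 only

# Filling with -inf: exp(-inf) = 0, so the blocked key gets zero weight.
filled = torch.softmax(scores.masked_fill(mask == 0, float("-inf")), dim=-1)

# Multiplying by the 0/1 mask: the blocked score becomes 0, and exp(0) = 1,
# so the blocked key still receives a positive weight.
multiplied = torch.softmax(scores * mask, dim=-1)

print(filled)       # weight on j = 2 is exactly 0
print(multiplied)   # weight on j = 2 is still positive
```

This is why implementations fill blocked scores with -∞ (or a very large negative number) instead of zeroing them.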

PyTorch Implementation

Method 1: Using torch.tril

    import torch
    import torch.nn as nn
    import math

    def create_causal_mask(seq_len):
        """Create a causal (lower-triangular) mask."""
        # torch.tril creates a lower triangular matrix
        mask = torch.tril(torch.ones(seq_len, seq_len))
        return mask  # Shape: (seq_len, seq_len)

    # Example usage
    seq_len = 5
    mask = create_causal_mask(seq_len)
    print("Causal mask:")
    print(mask)

    # Output:
    # tensor([[1., 0., 0., 0., 0.],
    #         [1., 1., 0., 0., 0.],
    #         [1., 1., 1., 0., 0.],
    #         [1., 1., 1., 1., 0.],
    #         [1., 1., 1., 1., 1.]])

Method 2: Causal Attention in Practice

    import torch.nn.functional as F

    def causal_self_attention(Q, K, V, dropout_p=0.0, training=True):
        """
        Compute causal self-attention.

        Args:
            Q, K, V: Tensors of shape (batch, seq_len, d_k)
            dropout_p: Dropout probability for the attention weights
            training: Dropout is applied only when this is True

        Returns:
            output: Attention output of shape (batch, seq_len, d_k)
            attn_weights: Attention weights of shape (batch, seq_len, seq_len)
        """
        batch_size, seq_len, d_k = Q.shape

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        # scores shape: (batch, seq_len, seq_len)

        # Create the causal mask on the same device as Q (True = allowed)
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=Q.device)
        )

        # Apply mask: set future positions to -inf.
        # The (seq_len, seq_len) mask broadcasts over the batch dimension.
        scores = scores.masked_fill(~causal_mask, float('-inf'))

        # Softmax to get attention weights
        attn_weights = torch.softmax(scores, dim=-1)

        # Apply dropout
        attn_weights = F.dropout(attn_weights, p=dropout_p, training=training)

        # Apply attention to values
        output = torch.matmul(attn_weights, V)

        return output, attn_weights

Method 3: Efficient Implementation with torch.nn.functional.scaled_dot_product_attention

    import torch.nn.functional as F

    def efficient_causal_attention(Q, K, V, dropout_p=0.0):
        """
        Efficient causal attention using PyTorch's optimized implementation.
        Available in PyTorch 2.0+ with Flash Attention support.
        """
        # is_causal=True automatically creates and applies the causal mask
        output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=None,       # No custom mask
            dropout_p=dropout_p,  # Applied whenever > 0; pass 0.0 at evaluation time
            is_causal=True        # Enable causal masking
        )
        return output

Performance Tip: In PyTorch 2.0 and later, scaled_dot_product_attention with is_causal=True can dispatch to a fused kernel such as Flash Attention when the device, dtype, and input shapes support it; otherwise it falls back to a standard implementation. The fused kernels avoid materializing the full (seq_len, seq_len) attention matrix, which is where the speedups and memory savings come from.
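To see that the fused call computes the same thing as the explicit version, the two can be compared on random inputs. This is a sketch under stated assumptions: PyTorch 2.0 or newer, CPU tensors, and a (batch, heads, seq_len, d_k) layout chosen for illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, d_k = 2, 4, 6, 16
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# Manual path: scale, mask future positions with -inf, softmax, weight the values.
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
manual = torch.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1) @ V

# Fused path: the causal mask is built internally.
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

max_diff = (manual - fused).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The two results should agree up to floating-point rounding; small differences are expected because the kernels accumulate in a different order.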

Attention Pattern Visualization

Let's visualize what the attention pattern looks like for the first five tokens of the sentence "The cat sat on the mat":

Example Attention Weights (Causal)

        The    cat    sat    on     the
The     1.0    0      0      0      0
cat     0.4    0.6    0      0      0
sat     0.2    0.3    0.5    0      0
on      0.1    0.15   0.25   0.5    0
the     0.08   0.12   0.2    0.3    0.3

Each row shows how the current token distributes its attention over itself and the tokens before it, so every row sums to 1. Notice the triangular pattern: every weight above the diagonal is 0.
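Both properties of the table, rows that sum to 1 and no weight on future tokens, are easy to verify. The sketch below simply copies the example weights into a tensor and checks them.

```python
import torch

# Example weights copied from the table (rows = queries, columns = keys).
weights = torch.tensor([
    [1.00, 0.00, 0.00, 0.00, 0.00],
    [0.40, 0.60, 0.00, 0.00, 0.00],
    [0.20, 0.30, 0.50, 0.00, 0.00],
    [0.10, 0.15, 0.25, 0.50, 0.00],
    [0.08, 0.12, 0.20, 0.30, 0.30],
])

# Each row is a probability distribution over the visible tokens.
rows_sum_to_one = torch.allclose(weights.sum(dim=-1), torch.ones(5))

# Everything strictly above the diagonal (the future) has zero weight.
no_future_weight = bool((torch.triu(weights, diagonal=1) == 0).all())

print(rows_sum_to_one, no_future_weight)
```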

Training vs. Inference

During Training

We use teacher forcing with causal masking:

    import torch.nn.functional as F

    # Training: Process entire sequence at once with causal mask
    def train_step(model, input_ids, target_ids):
        """
        input_ids:  [The, cat, sat, on, the, mat]
        target_ids: [cat, sat, on, the, mat, <EOS>]
        """
        # Forward pass with causal mask: each position sees only itself
        # and earlier positions. logits shape: (batch, seq_len, vocab_size)
        logits = model(input_ids)

        # Compute loss for all positions simultaneously.
        # cross_entropy expects (N, vocab_size) logits and (N,) targets.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1)
        )
        return loss

This is efficient because we compute all positions in parallel while maintaining the autoregressive property.
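The shift between inputs and targets is the part that is easiest to get wrong, so here is a minimal sketch of it. The token ids are invented, the final target is assumed to be an end-of-sequence token, and random logits stand in for the output of a real model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical token ids for [The, cat, sat, on, the, mat, <EOS>]
sequence = torch.tensor([[5, 9, 12, 7, 5, 14, 1]])

input_ids = sequence[:, :-1]    # [The, cat, sat, on, the, mat]
target_ids = sequence[:, 1:]    # [cat, sat, on, the, mat, <EOS>]

vocab_size = 20
# Stand-in for model(input_ids): random logits of shape (batch, seq_len, vocab_size).
logits = torch.randn(1, input_ids.size(1), vocab_size)

# One cross-entropy term per position, all computed in a single call.
per_position = F.cross_entropy(
    logits.reshape(-1, vocab_size), target_ids.reshape(-1), reduction="none"
)
loss = per_position.mean()
print(per_position.shape, loss.item())
```

A six-token input yields six training signals from one forward pass, which is the parallelism the causal mask makes safe.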

During Inference

We generate one token at a time:

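    # Inference: Generate tokens one at a time
    # tokenize, detokenize, and EOS are placeholders for your tokenizer's
    # encode/decode functions and its end-of-sequence token id.
    @torch.no_grad()
    def generate(model, prompt, max_length=50):
        tokens = tokenize(prompt)  # ids for [The, cat], shape (1, prompt_len)

        for _ in range(max_length):
            # Forward pass over everything generated so far;
            # the last position has no future tokens to hide
            logits = model(tokens)  # (1, cur_len, vocab_size)

            # Get next token prediction (greedy); only the last position is used
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)

            # Append and continue
            tokens = torch.cat([tokens, next_token], dim=1)

            if next_token.item() == EOS:
                break

        return detokenize(tokens)

Key Insight: During inference, the newest position may attend to every token generated so far, and future tokens don't exist yet, so its prediction needs no extra masking. The model still applies the causal mask to earlier positions whenever the whole prefix is re-run, which keeps their hidden states identical to what the model saw in training. The KV cache optimization stores previous key/value pairs so that only the new token has to be processed at each step.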
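The KV cache idea can be sketched for a single attention head. This is an illustration under simplifying assumptions: random tensors stand in for the projected queries, keys, and values of a real model, and the cache is a plain Python list. The check is that processing one query per step against the cached keys and values reproduces full causal attention.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_k = 6, 8
Q = torch.randn(seq_len, d_k)
K = torch.randn(seq_len, d_k)
V = torch.randn(seq_len, d_k)

# Reference: full causal attention over the whole sequence at once.
scores = Q @ K.T / math.sqrt(d_k)
allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
full = torch.softmax(scores.masked_fill(~allowed, float("-inf")), dim=-1) @ V

# Incremental: at step t only the new query is processed; keys and values
# from earlier steps are read from the cache instead of being recomputed.
k_cache, v_cache, steps = [], [], []
for t in range(seq_len):
    k_cache.append(K[t])
    v_cache.append(V[t])
    keys = torch.stack(k_cache)       # (t + 1, d_k)
    values = torch.stack(v_cache)     # (t + 1, d_k)
    # No mask needed: the cache holds only positions 0..t.
    weights = torch.softmax(Q[t] @ keys.T / math.sqrt(d_k), dim=-1)
    steps.append(weights @ values)
incremental = torch.stack(steps)

matches = torch.allclose(full, incremental, atol=1e-5)
print(matches)
```

Each step costs work proportional to the current length instead of its square, at the price of keeping the cached keys and values in memory.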

Advanced: Other Mask Types

Padding Mask

When batching sequences of different lengths, we pad shorter sequences. The padding mask prevents attention to padding tokens:

    def create_padding_mask(seq, pad_token_id=0):
        """
        Create a mask for padding tokens.

        Args:
            seq: (batch, seq_len) tensor of token ids
        Returns:
            mask: (batch, 1, 1, seq_len) boolean mask
        """
        # True for real tokens, False for padding
        mask = (seq != pad_token_id).unsqueeze(1).unsqueeze(2)
        return mask  # Shape: (batch, 1, 1, seq_len)

    # Combined with causal mask
    def create_combined_mask(seq, pad_token_id=0):
        seq_len = seq.size(1)
        # Boolean causal mask, created on the same device as the input
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq.device)
        )
        # (batch, 1, seq_len): hides padded key positions from every query
        padding = (seq != pad_token_id).unsqueeze(1)
        # Both conditions must be satisfied; result: (batch, seq_len, seq_len)
        return causal & padding
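A worked example on a small batch shows what the combined mask does. The sketch below assumes right-padded sequences and a padding id of 0; the helper `combined_mask` is a compact stand-in written for this example, and uniform scores are used so the resulting weights are easy to read.

```python
import torch

def combined_mask(seq, pad_token_id=0):
    """Boolean (batch, seq_len, seq_len) mask: True where attention is allowed."""
    seq_len = seq.size(1)
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=seq.device)
    )
    not_pad = (seq != pad_token_id).unsqueeze(1)   # (batch, 1, seq_len), key columns
    return causal & not_pad

# Two right-padded sequences; the second has two padding tokens.
batch = torch.tensor([[7, 3, 9, 4],
                      [5, 8, 0, 0]])
mask = combined_mask(batch)
print(mask[1].int())

# With uniform scores, each query spreads its weight over the allowed keys only.
scores = torch.zeros(2, 4, 4)
weights = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
print(weights[1, 3])   # the last query of sequence 2 ignores both padded keys
```

Rows that belong to padded query positions still produce outputs here; those positions are normally excluded from the loss. With left padding a row can end up fully masked, which makes the softmax return NaN, so that case needs separate handling.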

Sliding Window (Local) Attention

Some models use a fixed-size window to reduce computation:

    def create_sliding_window_mask(seq_len, window_size=256):
        """Create a mask that only attends to the last window_size tokens."""
        mask = torch.zeros(seq_len, seq_len)
        for i in range(seq_len):
            start = max(0, i - window_size + 1)
            mask[i, start:i+1] = 1
        return mask
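The same mask can be built without a Python loop by intersecting two triangular masks. The sketch below is one possible variant, with function names invented for this example; it compares the loop-free version against the row-by-row construction on a small case.

```python
import torch

def sliding_window_mask_loop(seq_len, window_size):
    """Reference version: fill each row's window with an explicit loop."""
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size + 1)
        mask[i, start:i + 1] = 1
    return mask

def sliding_window_mask_banded(seq_len, window_size):
    """Loop-free version: intersect two triangular masks."""
    ones = torch.ones(seq_len, seq_len)
    # tril keeps j <= i; triu with a negative diagonal keeps j >= i - window_size + 1.
    return torch.triu(torch.tril(ones), diagonal=-(window_size - 1))

loop_mask = sliding_window_mask_loop(seq_len=5, window_size=3)
banded_mask = sliding_window_mask_banded(seq_len=5, window_size=3)

print(banded_mask)
print(torch.equal(loop_mask, banded_mask))
```

With a window of 3, the first three queries see 1, 2, and 3 positions, and every later query sees exactly 3, so the cost per query stops growing with sequence length.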

Interactive Exercises

Exercise 1: Mask Shape

For a sequence of length 10 with batch size 4 and 8 attention heads, what is the shape of the causal mask?

(10, 10)
(4, 8, 10, 10)
(4, 10, 10)

Answer: (10, 10). The mask is the same for all batches and heads, so it broadcasts over both dimensions.
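Broadcasting is what lets a single two-dimensional mask serve every batch element and head. A minimal sketch with the shapes from this exercise, using random scores as a stand-in for real attention scores:

```python
import torch

batch, heads, seq_len = 4, 8, 10
scores = torch.randn(batch, heads, seq_len, seq_len)

# One (seq_len, seq_len) mask, shared by every batch element and head.
allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# masked_fill broadcasts the 2-D mask over the leading (batch, heads) dimensions.
masked = scores.masked_fill(~allowed, float("-inf"))
weights = torch.softmax(masked, dim=-1)

print(allowed.shape, weights.shape)
```

Expanding the mask to (4, 8, 10, 10) would give the same result while using 32 times the memory.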

Exercise 2: Attention Pattern

In a 4-token sequence, which positions can the token at position 2 (0-indexed) attend to?

Positions 0, 1, and 2
Positions 0 and 1 only
Positions 0, 1, 2, and 3

Answer: Positions 0, 1, and 2. A token attends to every position j ≤ i, which includes itself.

Exercise 3: Implementation

Complete the following code to apply the causal mask:

    def apply_causal_mask(scores):
        """
        scores: (batch, seq_len, seq_len)
        Apply causal mask to scores before softmax.
        """
        seq_len = scores.size(-1)

        # Create causal mask
        mask = torch.____(torch.ones(seq_len, seq_len))

        # Apply mask (set future positions to -inf)
        scores = scores.masked_fill(mask == ___, float('___'))

        return scores

Answer: tril, 0, -inf

Further Reading