📚 Supplementary Lesson for Level 03
Level 03 • Supplement

Training Transformers at Scale

Advanced optimization techniques, learning rate schedules, and distributed training strategies.

The Training Stability Problem

Training large transformers is notoriously difficult. Unlike smaller models or RNNs, transformers face unique challenges:

Key Challenges:
  • No recurrent structure: Gradients can flow directly from output to any input position
  • Attention sharpness: Attention weights can become extremely sharp early in training
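  • Layer normalization position: Pre-LN and Post-LN place normalization differently relative to the residual connection, which changes how gradients flow through depth (Post-LN is generally more sensitive to warmup and learning rate)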
  • Deep stacks: 96+ layers amplify any instability
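To make the Pre-LN vs Post-LN distinction concrete, here is a minimal sketch of the two residual-block orderings in PyTorch. The class names and the feed-forward stand-in sublayer are illustrative assumptions, not taken from a specific codebase. Attention is omitted so that the only difference between the blocks is where LayerNorm sits.

```python
import torch
import torch.nn as nn

def feed_forward(d_model):
    # Stand-in sublayer (hypothetical); a real block would also contain self-attention
    return nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model))

class PostLNBlock(nn.Module):
    """Post-LN (original Transformer): LayerNorm is applied after the residual add."""
    def __init__(self, d_model):
        super().__init__()
        self.ff = feed_forward(d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # The residual stream itself is renormalized at every layer
        return self.norm(x + self.ff(x))

class PreLNBlock(nn.Module):
    """Pre-LN: LayerNorm is applied only to the sublayer input."""
    def __init__(self, d_model):
        super().__init__()
        self.ff = feed_forward(d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # The residual path is a plain identity, giving gradients a direct path to earlier layers
        return x + self.ff(self.norm(x))

x = torch.randn(2, 5, 64)  # (batch, sequence, d_model)
post_out = PostLNBlock(64)(x)
pre_out = PreLNBlock(64)(x)
print(post_out.shape, pre_out.shape)
```

Because the Pre-LN identity path is never normalized, deep Pre-LN stacks usually add one final LayerNorm after the last block.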

Understanding the Loss Landscape

Transformer loss landscapes are complex and non-convex. The optimization process must navigate:

Learning Rate Schedules

The Warmup Phase

Warmup is critical for transformer training. Starting at the full learning rate can destabilize the first steps of training, so the learning rate is instead ramped up linearly:

LR_warmup(step) = base_LR × (step / warmup_steps), for step < warmup_steps

Cosine Annealing with Warmup

The most common schedule for LLM training:

import math
import torch

class CosineWarmupScheduler:
    """
    Cosine annealing with linear warmup.

    Args:
        optimizer: PyTorch optimizer
        warmup_steps: Number of linear warmup steps
        max_steps: Total training steps
        base_lr: Peak learning rate, reached at the end of warmup
        min_lr_ratio: Minimum learning rate as a fraction of base_lr
    """
    def __init__(self, optimizer, warmup_steps, max_steps, base_lr=1e-4, min_lr_ratio=0.1):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.max_steps = max_steps
        self.base_lr = base_lr
        self.min_lr = base_lr * min_lr_ratio
        self.current_step = 0
        # Apply the step-0 learning rate now, so the very first optimizer step is
        # also warmed up instead of running at the optimizer's initial lr
        self._set_lr(self._lr_at(0))

    def _lr_at(self, step):
        if step < self.warmup_steps:
            # Linear warmup
            return self.base_lr * (step / self.warmup_steps)
        # Cosine decay; progress is clamped so steps past max_steps stay at min_lr
        progress = (step - self.warmup_steps) / max(1, self.max_steps - self.warmup_steps)
        progress = min(progress, 1.0)
        return self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

    def _set_lr(self, lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def step(self):
        self.current_step += 1
        lr = self._lr_at(self.current_step)
        self._set_lr(lr)
        return lr

    def get_lr(self):
        return self.optimizer.param_groups[0]['lr']

# Usage example (a small placeholder model stands in for a real transformer)
model = torch.nn.Linear(512, 512)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = CosineWarmupScheduler(
    optimizer,
    warmup_steps=2000,
    max_steps=100000,
    base_lr=1e-4,
    min_lr_ratio=0.1
)

# Training loop
for step in range(100000):
    # ... forward pass, loss.backward(), optimizer.step(), optimizer.zero_grad() ...
    scheduler.step()

Learning Rate Schedule Visualization

Warmup (0-20% of steps in this illustration) followed by cosine decay (20-100%).

[Figure: learning rate vs. training step. Legend: Warmup Phase, Cosine Decay, Minimum LR]
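The schedule shape (20% linear warmup, then cosine decay to the minimum LR) can be reproduced numerically with a pure-Python version of the same rule used by CosineWarmupScheduler. This sketch assumes 1,000 total steps to match the illustration. The function name lr_at is hypothetical.

```python
import math

def lr_at(step, warmup_steps, max_steps, base_lr, min_lr_ratio=0.1):
    """Linear warmup to base_lr, then cosine decay to base_lr * min_lr_ratio."""
    min_lr = base_lr * min_lr_ratio
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    # Clamp so steps beyond max_steps stay at min_lr
    progress = min((step - warmup_steps) / max(1, max_steps - warmup_steps), 1.0)
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

# Coarse text rendering of the schedule: one row per 10% of training
max_steps, warmup_steps, base_lr = 1000, 200, 1e-4
for step in range(0, max_steps + 1, 100):
    lr = lr_at(step, warmup_steps, max_steps, base_lr)
    print(f"step {step:4d}  lr {lr:.2e}  " + "#" * round(40 * lr / base_lr))
```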

Other Schedules

Schedule              | Formula                                       | Best For
Linear Decay          | lr = base × (1 - step/max_steps)              | Simple baselines
Polynomial            | lr = base × (1 - step/max_steps)^p            | Smoother decay
Inverse Sqrt          | lr = base × sqrt(warmup/step), after warmup   | Original Transformer
Constant + Cooldown   | Flat, then a linear drop at the end           | Fine-tuning
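The formulas in the table translate directly into small functions. This is a sketch under a few assumptions that are not stated in the table: a default polynomial power of 2.0, a linear warmup in front of the inverse-sqrt decay (so that its peak equals base at step == warmup), and a hypothetical cooldown_steps parameter with the cooldown ending at zero. The linear and polynomial forms are shown without warmup, as written in the table.

```python
import math

def linear_decay(step, base, max_steps):
    return base * (1 - step / max_steps)

def polynomial_decay(step, base, max_steps, power=2.0):
    return base * (1 - step / max_steps) ** power

def inverse_sqrt(step, base, warmup):
    # Linear warmup, then lr proportional to 1/sqrt(step); equals base at step == warmup
    if step < warmup:
        return base * step / warmup
    return base * math.sqrt(warmup / step)

def constant_then_cooldown(step, base, max_steps, cooldown_steps):
    # Flat at base, then a linear drop to zero over the final cooldown_steps (assumed shape)
    cooldown_start = max_steps - cooldown_steps
    if step < cooldown_start:
        return base
    return base * (max_steps - step) / cooldown_steps

# Compare the schedules at a few points in a 10k-step run
max_steps, warmup, base = 10_000, 1_000, 1e-4
for step in (500, 1_000, 4_000, 9_000):
    print(step,
          f"{linear_decay(step, base, max_steps):.2e}",
          f"{polynomial_decay(step, base, max_steps):.2e}",
          f"{inverse_sqrt(step, base, warmup):.2e}",
          f"{constant_then_cooldown(step, base, max_steps, cooldown_steps=2_000):.2e}")
```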

Optimizer Deep Dive

Adam vs AdamW

AdamW decouples weight decay from the adaptive gradient update, which often leads to better generalization:

import torch

model = torch.nn.Linear(512, 512)  # placeholder model

# Adam with weight_decay (coupled L2 regularization):
#   the decay term is added to the gradient, so it is also rescaled
#   by Adam's per-parameter adaptive denominator.
# AdamW (decoupled weight decay):
#   the decay is applied to the parameters directly, separately from
#   the adaptive gradient update. This gives more consistent regularization.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.95),  # β1: momentum (first moment), β2: second-moment (variance) decay
    eps=1e-8,           # Numerical stability
    weight_decay=0.1    # Decoupled weight decay
)
Why AdamW? In standard Adam, L2 regularization gets multiplied by the adaptive learning rate (1/√v), which varies per parameter. AdamW applies weight decay directly to parameters, providing more consistent regularization across all parameters.
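The difference can be seen in a single scalar update. The sketch below is written from the standard update rules rather than PyTorch's source. adam_l2_step folds wd × p into the gradient, which is what torch.optim.Adam's weight_decay does. adamw_step subtracts lr × wd × p separately, which is algebraically the same as AdamW's shrink-then-step update. On the first step with a large gradient, Adam's normalization turns the update into roughly ±lr, so the L2 term has almost no effect. AdamW still shrinks the weight by lr × wd × p.

```python
import math

def adam_moments(g, m, v, t, b1, b2):
    """Update Adam's moment estimates and return bias-corrected values."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return m, v, m_hat, v_hat

def adam_l2_step(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.1):
    # Coupled: the L2 term joins the gradient and is then adaptively rescaled
    g = g + wd * p
    m, v, m_hat, v_hat = adam_moments(g, m, v, t, b1, b2)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

def adamw_step(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.1):
    # Decoupled: the adaptive update uses the raw gradient; decay shrinks p directly
    m, v, m_hat, v_hat = adam_moments(g, m, v, t, b1, b2)
    return p - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * p), m, v

# A parameter with a large gradient, first step (t=1), zero-initialized moments
p, g = 1.0, 100.0
decay_effect_adam = adam_l2_step(p, g, 0.0, 0.0, 1, wd=0.0)[0] - adam_l2_step(p, g, 0.0, 0.0, 1, wd=0.1)[0]
decay_effect_adamw = adamw_step(p, g, 0.0, 0.0, 1, wd=0.0)[0] - adamw_step(p, g, 0.0, 0.0, 1, wd=0.1)[0]
print(f"extra shrinkage from weight decay  Adam+L2: {decay_effect_adam:.2e}  AdamW: {decay_effect_adamw:.2e}")
```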

Hyperparameter Choices

β1 (First Moment Decay)

Controls momentum. Typical values: 0.9 to 0.99. Higher values smooth updates more.

β2 (Second Moment Decay)

Controls the running average of squared gradients, which sets each parameter's effective step size. Typical values: 0.95 to 0.999. Lower values adapt faster to recent gradients.
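"Lower values adapt faster" can be quantified. An exponential moving average with decay β averages over roughly 1/(1 - β) recent steps. Started at 0 and fed a constant signal of 1, it equals 1 - β^n after n updates. The sketch below uses that formula. The function name steps_to_adapt and the 90% threshold are illustrative choices.

```python
import math

def steps_to_adapt(beta, fraction=0.9):
    """Steps for an EMA (started at 0) of a constant signal of 1 to reach `fraction`.

    After n updates the EMA equals 1 - beta**n, so solve beta**n <= 1 - fraction.
    """
    return math.ceil(math.log(1 - fraction) / math.log(beta))

for beta in (0.9, 0.95, 0.99, 0.999):
    print(f"beta={beta}: ~{1 / (1 - beta):.0f}-step averaging window, "
          f"{steps_to_adapt(beta)} steps to reach 90% of a new level")
```

With β2 = 0.95 the second-moment estimate catches up with a change in gradient scale in a few dozen steps. With β2 = 0.999 it takes a few thousand.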

ε (Epsilon)

Numerical stability term. Usually 1e-8. Prevents division by zero.

Modern Variants

Optimizer    | Key Feature                  | When to Use
AdamW        | Decoupled weight decay       | Standard choice for transformers
Lion         | Sign-based, lower memory     | Memory-constrained training
Adafactor    | Factorized second moments    | Very large models
8-bit Adam   | Quantized optimizer states   | Memory optimization
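Lion's "sign-based, lower memory" label is easiest to see in its update rule. Below is a pure-Python sketch of one Lion step, written from the update described in the Lion paper (Chen et al., 2023) with its suggested betas (0.9, 0.99). It is not a library API, and the function name lion_step is hypothetical. Only one state (momentum) is kept per parameter, compared with two (m and v) for AdamW. Every coordinate moves by exactly lr, regardless of gradient size.

```python
def lion_step(params, grads, momentum, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
    """One Lion update over flat lists of floats (illustrative sketch)."""
    new_params, new_momentum = [], []
    for p, g, m in zip(params, grads, momentum):
        c = beta1 * m + (1 - beta1) * g                   # interpolate momentum and gradient
        sign = (c > 0) - (c < 0)                          # -1, 0, or +1
        new_params.append(p - lr * (sign + wd * p))       # decoupled weight decay, as in AdamW
        new_momentum.append(beta2 * m + (1 - beta2) * g)  # the single stored state
    return new_params, new_momentum

# Gradients that differ by six orders of magnitude produce the same-sized step
params, mom = lion_step([1.0, 1.0], [1000.0, 0.001], [0.0, 0.0], lr=0.01)
print(params, mom)
```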

Gradient Clipping

Gradient clipping is essential for stability: it caps the gradient norm so that an occasional spike (an exploding gradient) cannot derail training:

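import torch

model = torch.nn.Linear(512, 512)  # placeholder model
loss = model(torch.randn(8, 512)).pow(2).mean()  # placeholder loss
loss.backward()

# Gradient clipping by global norm: after backward(), before optimizer.step()
max_grad_norm = 1.0
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

# Alternative (less common): clamp each gradient element to [-1.0, 1.0]
# torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)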
Best Practice: Always use gradient clipping when training transformers. A max_norm of 1.0 is a good starting point. Monitor gradient norms during training: they should stabilize, not grow indefinitely.
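Clipping by norm rescales all gradients together by one factor, computed from their combined L2 norm. The direction of the update is preserved. Clipping by value instead clamps each element separately and can change the direction. The pure-Python sketch below mirrors that rule. It assumes gradients are given as flat lists and uses the small 1e-6 term that PyTorch adds to the denominator. The function name clip_by_global_norm is hypothetical.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient vectors so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(x * x for g in grads for x in g))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    clipped = [[x * scale for x in g] for g in grads]
    return clipped, total_norm, scale

# Two parameters whose gradients together have norm sqrt(6^2 + 8^2) = 10
clipped, norm, scale = clip_by_global_norm([[6.0], [8.0]], max_norm=1.0)
print(norm, scale, clipped)
```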

Monitoring Gradient Norms

# Track gradient norms for debugging
import torch
from torch.utils.tensorboard import SummaryWriter

# Placeholder setup so the loop runs as written; substitute your own model and data
model = torch.nn.Linear(512, 512)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = [torch.randn(8, 512) for _ in range(10)]
max_grad_norm = 1.0

writer = SummaryWriter()

for step, batch in enumerate(dataloader):
    loss = model(batch).pow(2).mean()  # placeholder loss
    loss.backward()

    # clip_grad_norm_ returns the total gradient norm measured before clipping
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    writer.add_scalar('grad_norm/before_clip', total_norm.item(), step)

    optimizer.step()
    optimizer.zero_grad()

writer.close()

Mixed Precision Training

Using FP16 or BF16 reduces memory and speeds up training on modern GPUs:

import torch

# Assumes model, criterion, optimizer, dataloader, and max_grad_norm are already defined
# GradScaler scales the loss so small FP16 gradients don't underflow to zero
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in mixed precision
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        outputs = model(batch.input_ids)
        loss = criterion(outputs, batch.labels)

    # Backward pass on the scaled loss
    scaler.scale(loss).backward()

    # Unscale first so clipping applies to the true gradients
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    # Optimizer step (skipped if gradients contain inf/NaN), then adjust the scale
    scaler.step(optimizer)
    scaler.update()
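BF16 vs FP16: BF16 (Brain Float 16) keeps FP32's 8 exponent bits, so it covers roughly the same range as FP32. It has only 7 mantissa bits (FP16 has 10), so it is less precise. Because of the wider range, it is more stable than FP16 and usually does not need loss scaling. Use BF16 when available (NVIDIA Ampere GPUs and newer).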
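The range difference can be checked without a GPU. Python's struct module packs IEEE half precision with the 'e' format. BF16 is approximated here by keeping the top 16 bits of the FP32 encoding. That is truncation. Hardware conversions typically round to nearest even, so individual BF16 values may differ slightly from what a GPU produces. The helper names are hypothetical.

```python
import struct

def to_fp16(x):
    """Round-trip a Python float through IEEE 754 half precision (FP16)."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def to_bf16(x):
    """Approximate BF16 by keeping the top 16 bits of the FP32 encoding (truncation)."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    return struct.unpack('<f', struct.pack('<I', bits & 0xFFFF0000))[0]

# A tiny gradient: FP16 underflows to zero, BF16 keeps it (coarsely)
print(to_fp16(1e-8), to_bf16(1e-8))

# A large value: FP16 (max about 65504) cannot represent it, BF16 just loses precision
try:
    to_fp16(70000.0)
except OverflowError:
    print("FP16 overflow")
print(to_bf16(70000.0))
```

The underflow case is what GradScaler works around for FP16. Multiplying the loss moves small gradients back into FP16's representable range.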

Distributed Training Basics

Data Parallelism (DDP)

Each GPU holds a full copy of the model and processes a different slice of the data; gradients are averaged across GPUs during the backward pass:

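import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Initialize process group (torchrun sets LOCAL_RANK for each process)
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# Create model and wrap with DDP (MyTransformer and dataset are placeholders)
model = MyTransformer(...).to(local_rank)
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# DistributedSampler gives each process a different shard of the data
# (call sampler.set_epoch(epoch) at the start of each epoch when training for several)
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)

# Training loop (same as single-GPU)
for batch in dataloader:
    loss = model(batch)
    loss.backward()  # Gradients are automatically all-reduced (averaged) across GPUs here
    optimizer.step()
    optimizer.zero_grad()

dist.destroy_process_group()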

Gradient Accumulation

Simulate larger batch sizes when memory is limited:

gradient_accumulation_steps = 4

def train_step(model, batch, optimizer, step):
    loss = model(batch)
    # Scale the loss so the accumulated gradient is an average, not a sum
    loss = loss / gradient_accumulation_steps
    loss.backward()  # Gradients add up in each parameter's .grad across calls

    # Only update the weights every N micro-batches
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
Effective Batch Size = per_device_batch_size × num_devices × gradient_accumulation_steps
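Dividing each micro-batch loss by gradient_accumulation_steps is what makes the accumulated gradient equal to the gradient of one large batch. This holds when the micro-batches are equal-sized and the loss is a mean. The same reasoning explains why averaging gradients across devices in DDP matches a larger batch. A quick check uses a one-parameter least-squares model in pure Python. The data, the weight, and the helper grad_mse are made up for illustration.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over a batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [0.5, -1.0, 2.0, 1.5, -0.5, 3.0, 0.0, 1.0]
ys = [1.0, -2.5, 3.5, 2.0, -1.0, 6.5, 0.5, 2.5]
w = 0.3
accum_steps = 4
micro = len(xs) // accum_steps  # micro-batch size 2

full_batch_grad = grad_mse(w, xs, ys)

# Accumulate micro-batch gradients, each scaled by 1/accum_steps like loss / gradient_accumulation_steps
accumulated = 0.0
for i in range(accum_steps):
    sl = slice(i * micro, (i + 1) * micro)
    accumulated += grad_mse(w, xs[sl], ys[sl]) / accum_steps

print(full_batch_grad, accumulated)
```

Without the division, the accumulated gradient would be accum_steps times too large. That effectively multiplies the learning rate.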

Interactive Exercises

Exercise 1: Learning Rate Schedule

For a model with base_lr=1e-4, warmup_steps=1000, and max_steps=10000, what is the learning rate at step 500?

5e-5
1e-4
2.5e-5

Exercise 2: Gradient Clipping

If gradients have a norm of 5.0 and max_grad_norm=1.0, what is the scaling factor applied by clip_grad_norm_?

5.0
0.2

Exercise 3: Effective Batch Size

You have 4 GPUs, each processing batch_size=8, with gradient_accumulation_steps=2. What is the effective batch size?

32
64

Exercise 4: Complete the Training Loop

Fill in the missing pieces:

for batch in dataloader:
    optimizer._____()  # Clear previous gradients
    with torch.autocast(device_type='cuda', dtype=torch.float16):  # Mixed precision context
        loss = model(batch)
    ____.scale(loss).backward()  # Scale and backprop
    # Unscale for clipping
    scaler._____(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.____(optimizer)  # Optimizer step
    scaler._____()  # Update scale

Answers: zero_grad, scaler, unscale_, step, update

Training Checklist

Before Starting Training

  • Set appropriate learning rate (typically 1e-4 to 6e-4)
  • Configure warmup (typically 1-5% of total steps)
  • Enable gradient clipping (max_norm=1.0)
  • Set up mixed precision (BF16 preferred over FP16)
  • Configure weight decay (typically 0.01-0.1)
  • Set up logging (loss, learning rate, gradient norms)
  • Validate on a small batch first
  • Set up checkpointing
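One common way to "validate on a small batch first" is to check that the model can drive the loss close to zero on a single fixed batch. If it cannot, there is likely a bug in the model, the loss, or the optimizer wiring. The sketch below uses a placeholder linear model and random data. These are assumptions for illustration, not part of the lesson's setup. The model has more parameters than the batch has examples, so an exact fit exists.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(16, 1)                     # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=0.0)
x, y = torch.randn(8, 16), torch.randn(8, 1)       # one fixed small batch

initial_loss = None
for step in range(500):
    loss = torch.nn.functional.mse_loss(model(x), y)
    if initial_loss is None:
        initial_loss = loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():
    final_loss = torch.nn.functional.mse_loss(model(x), y).item()
print(f"initial {initial_loss:.4f} -> final {final_loss:.6f}")
```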

Further Reading