import math

import torch


class CosineWarmupScheduler:
    """
    Cosine annealing with linear warmup.

    Args:
        optimizer: PyTorch optimizer
        warmup_steps: Number of warmup steps
        max_steps: Total training steps
        base_lr: Peak learning rate after warmup
        min_lr_ratio: Minimum learning rate as a fraction of base_lr
    """

    def __init__(self, optimizer, warmup_steps, max_steps, base_lr=1e-4, min_lr_ratio=0.1):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.max_steps = max_steps
        self.base_lr = base_lr
        self.min_lr = base_lr * min_lr_ratio
        self.current_step = 0

    def step(self):
        self.current_step += 1
        if self.current_step < self.warmup_steps:
            # Linear warmup
            lr = self.base_lr * (self.current_step / self.warmup_steps)
        else:
            # Cosine decay; progress is clamped so the LR stays at min_lr past max_steps
            decay_steps = max(1, self.max_steps - self.warmup_steps)
            progress = min(1.0, (self.current_step - self.warmup_steps) / decay_steps)
            lr = self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def get_lr(self):
        return self.optimizer.param_groups[0]['lr']

# Usage example (assumes `model` is an existing torch.nn.Module)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = CosineWarmupScheduler(
    optimizer,
    warmup_steps=2000,
    max_steps=100000,
    base_lr=1e-4,
    min_lr_ratio=0.1,
)

# Training loop
for step in range(100000):
    # ... training code ...
    scheduler.step()

Learning Rate Schedule Visualization
Warmup over the first 20% of training, followed by cosine decay from 20% to 100%.
[Figure: learning rate over training progress. Legend: Warmup Phase, Cosine Decay, Minimum LR]
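The curve described above can be reproduced numerically. The following is a minimal sketch, assuming a 1-indexed step counter; `cosine_warmup_lr` is a hypothetical stateless helper, and clamping the decay progress at 1.0 is an assumption of this sketch so that the rate holds at the minimum once `max_steps` is passed.

```python
import math


def cosine_warmup_lr(step, warmup_steps, max_steps, base_lr=1e-4, min_lr_ratio=0.1):
    """Learning rate at a 1-indexed step (hypothetical stateless helper)."""
    min_lr = base_lr * min_lr_ratio
    if step < warmup_steps:
        # Linear warmup from near zero up to base_lr
        return base_lr * step / warmup_steps
    # Cosine decay from base_lr down to min_lr
    decay_steps = max(1, max_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))


# Sample a schedule with a 20% warmup, matching the figure's proportions
for s in (1, 10, 20, 60, 100):
    print(s, cosine_warmup_lr(s, warmup_steps=20, max_steps=100))
```

The peak is reached at the end of warmup, the midpoint of the decay sits halfway between `base_lr` and the minimum, and the final step lands on `base_lr * min_lr_ratio`.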
Other Schedules
| Schedule | Formula | Best For |
|---|---|---|
| Linear Decay | lr = base × (1 - step/max_steps) | Simple baselines |
| Polynomial | lr = base × (1 - step/max_steps)^p | Smoother decay |
| Inverse Sqrt | lr = base × sqrt(warmup/step) | Original Transformer |
| Constant + Cooldown | Flat then linear drop | Fine-tuning |
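The formulas in the table translate directly into small functions. This is a rough sketch rather than a library API: the function names, the `power` default, the linear warmup in front of the inverse square root decay, and the `cooldown_frac` parameter are all assumptions made for illustration.

```python
import math


def linear_decay(step, max_steps, base_lr):
    # lr = base × (1 - step/max_steps)
    return base_lr * (1 - step / max_steps)


def polynomial_decay(step, max_steps, base_lr, power=2.0):
    # lr = base × (1 - step/max_steps)^p
    return base_lr * (1 - step / max_steps) ** power


def inverse_sqrt(step, warmup_steps, base_lr):
    # Assumed linear warmup, then lr = base × sqrt(warmup/step)
    step = max(step, 1)
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr * math.sqrt(warmup_steps / step)


def constant_with_cooldown(step, max_steps, base_lr, cooldown_frac=0.2):
    # Flat at base_lr, then a linear drop to zero over the last cooldown_frac of training
    cooldown_start = max_steps * (1 - cooldown_frac)
    if step < cooldown_start:
        return base_lr
    return base_lr * (max_steps - step) / (max_steps - cooldown_start)
```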
Optimizer Deep Dive
Adam vs AdamW
AdamW decouples weight decay from the gradient-based update, which often leads to better generalization:
# Adam (L2 regularization - coupled)
# The weight_decay term is added to the gradient, so it also passes through
# the moment estimates and is rescaled by the adaptive learning rate
# AdamW (decoupled weight decay)
# weight_decay shrinks the parameters directly, outside the adaptive update
# This leads to more uniform regularization
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.95),   # β1: momentum, β2: variance
    eps=1e-8,            # Numerical stability
    weight_decay=0.1,    # Decoupled weight decay (not an L2 penalty)
)
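Why AdamW? In standard Adam, the L2 penalty is added to the gradient, so it is scaled by the same adaptive factor (1/√v) as the rest of the update, and that factor varies per parameter. Parameters with a history of large gradients are therefore regularized less than parameters with small gradients. AdamW applies weight decay directly to the parameters, scaled only by the learning rate, which gives more consistent regularization across all parameters.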
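A single update on one scalar weight makes the difference concrete. This is a simplified illustration, not a full optimizer: `first_step` is a hypothetical helper that covers only the first step from zero-initialized moments, where the bias-corrected moments reduce to the gradient and its square.

```python
import math


def first_step(weight, grad, lr=0.1, wd=0.1, eps=1e-8, decoupled=True):
    """Weight after one Adam-style step starting from zero moments.

    At step 1 the bias-corrected moments are m_hat = g and v_hat = g**2.
    """
    # Coupled L2 folds the decay term into the gradient before normalization
    g = grad if decoupled else grad + wd * weight
    adaptive = g / (math.sqrt(g * g) + eps)  # m_hat / (sqrt(v_hat) + eps)
    new_weight = weight - lr * adaptive
    if decoupled:
        # AdamW: decay is applied outside the adaptive term
        new_weight -= lr * wd * weight
    return new_weight


for grad in (0.001, 10.0):
    print(grad, first_step(1.0, grad, decoupled=False), first_step(1.0, grad, decoupled=True))
```

In the coupled case the decay term is normalized together with the gradient, so in this one-step example it has no visible effect on the step size. In the decoupled case each weight shrinks by an extra `lr * wd * weight` regardless of its gradient magnitude.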
Hyperparameter Choices
β1 (First Moment Decay)
Controls momentum. Typical values: 0.9 to 0.99. Higher values smooth updates more.
β2 (Second Moment Decay)
Controls adaptive learning rate. Typical values: 0.95 to 0.999. Lower values adapt faster to recent gradients.
ε (Epsilon)
Numerical stability term. Usually 1e-8. Prevents division by zero.
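One way to build intuition for the decay rates is to treat each moment estimate as an exponential moving average whose memory spans roughly 1/(1 - β) steps. The sketch below uses hypothetical helper names and plain Python floats; it is meant as an approximation aid, not as the optimizer's actual code.

```python
def effective_window(beta):
    """Approximate number of recent steps an EMA with decay `beta` averages over."""
    return 1.0 / (1.0 - beta)


def bias_corrected_ema(values, beta):
    """Exponential moving average with Adam-style bias correction.

    `values` must be non-empty.
    """
    avg = 0.0
    for x in values:
        avg = beta * avg + (1 - beta) * x
    # Dividing by (1 - beta^t) undoes the bias from starting at zero
    return avg / (1 - beta ** len(values))


for beta in (0.9, 0.95, 0.999):
    print(beta, effective_window(beta))
```

By this rule of thumb, β2 = 0.95 tracks roughly the last 20 steps of squared gradients, while β2 = 0.999 averages over about 1000, which is why the lower value reacts faster.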
Modern Variants
| Optimizer | Key Feature | When to Use |
|---|---|---|
| AdamW | Decoupled weight decay | Standard choice for transformers |
| Lion | Sign-based, lower memory | Memory-constrained training |
| Adafactor | Factorized second moments | Very large models |
| 8-bit Adam | Quantized optimizer states | Memory optimization |
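The memory differences in the table can be estimated with simple arithmetic. The figures below are assumptions for a back-of-the-envelope sketch: FP32 states for AdamW and Lion, one byte per state value for 8-bit Adam with quantization metadata ignored, and a row-plus-column factorization for Adafactor's second moment. Real implementations may differ.

```python
# Assumed bytes of optimizer state per model parameter
BYTES_PER_PARAM = {
    "AdamW": 2 * 4,       # two FP32 states (first and second moment)
    "Lion": 1 * 4,        # one FP32 state (momentum only)
    "8-bit Adam": 2 * 1,  # two 8-bit states, ignoring quantization metadata
}


def optimizer_state_gib(num_params, optimizer):
    """Approximate optimizer state size in GiB under the assumptions above."""
    return num_params * BYTES_PER_PARAM[optimizer] / 2**30


def adafactor_second_moment_entries(rows, cols):
    """Entries kept for a rows x cols weight matrix: one row vector plus one column vector."""
    return rows + cols


for name in BYTES_PER_PARAM:
    print(name, optimizer_state_gib(7_000_000_000, name))
print(adafactor_second_moment_entries(4096, 4096), "vs", 4096 * 4096)
```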
Gradient Clipping
Gradient clipping is essential for training stability because it keeps exploding gradients from derailing a run:
# Gradient clipping by norm
max_grad_norm = 1.0
# After backward pass, before optimizer step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
# Alternative: clip by value (less common)
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
Best Practice: Always use gradient clipping when training transformers. A max_norm of 1.0 is a good starting point. Monitor gradient norms during training; they should stabilize, not grow indefinitely.
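To see what clipping by norm does, here is a minimal pure-Python sketch over a flat list of gradient values. `clip_by_global_norm` is a hypothetical helper; the small `eps` in the denominator mirrors the scaling rule that `clip_grad_norm_` uses, but this is an illustration rather than the library's code.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale gradient values so that their global L2 norm is at most max_norm.

    Returns the clipped values and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Never scale up: gradients already within the limit are left unchanged
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(norm_before, clipped)
```

All values are multiplied by the same factor, so the gradient direction is preserved and only its length changes. Clipping by value, in contrast, clamps each element independently and can change the direction.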
Monitoring Gradient Norms
# Track gradient norms for debugging
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
max_grad_norm = 1.0

for step, batch in enumerate(dataloader):
    loss = model(batch)  # assumes the model returns a scalar loss
    loss.backward()
    # Calculate the global L2 gradient norm before clipping
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.detach().norm(2).item() ** 2
    total_norm = total_norm ** 0.5
    writer.add_scalar('grad_norm/before_clip', total_norm, step)
    # clip_grad_norm_ also returns this pre-clip norm
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()

Mixed Precision Training
Using FP16 or BF16 reduces memory and speeds up training on modern GPUs:
import torch

# Initialize gradient scaler for loss scaling (needed for FP16, not BF16).
# Newer PyTorch releases spell this torch.amp.GradScaler('cuda').
scaler = torch.cuda.amp.GradScaler()

for batch in dataloader:
    optimizer.zero_grad()
    # Forward pass in mixed precision (torch.autocast accepts device_type;
    # torch.cuda.amp.autocast does not)
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        outputs = model(batch.input_ids)
        loss = criterion(outputs, batch.labels)
    # Backward pass with scaling
    scaler.scale(loss).backward()
    # Gradient clipping (unscale first so the threshold applies to the true gradients)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # Optimizer step (skipped by the scaler if gradients contain inf or NaN)
    scaler.step(optimizer)
    scaler.update()

BF16 vs FP16: BF16 (Brain Float 16) has the same exponent range as FP32 but fewer mantissa bits, so it trades precision for range. It is more stable than FP16 and typically doesn't require loss scaling. Use BF16 when available (Ampere GPUs and newer).
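The range and precision trade-off follows from how each format splits its bits. The sketch below derives the largest finite value and the spacing of values near 1.0 from the exponent and mantissa widths (5 and 10 for FP16, 8 and 7 for BF16, 8 and 23 for FP32). The helper names are hypothetical.

```python
def max_finite(exp_bits, mantissa_bits):
    """Largest finite value of an IEEE-style binary format."""
    emax = 2 ** (exp_bits - 1) - 1
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** emax


def machine_epsilon(mantissa_bits):
    """Gap between 1.0 and the next representable value."""
    return 2.0 ** -mantissa_bits


# (exponent bits, mantissa bits) per format
FORMATS = {"FP16": (5, 10), "BF16": (8, 7), "FP32": (8, 23)}
for name, (e, m) in FORMATS.items():
    print(name, max_finite(e, m), machine_epsilon(m))
```

FP16 tops out at 65504, which is why large activations or scaled losses can overflow, while BF16 reaches roughly the same maximum as FP32 at the cost of a coarser spacing between neighboring values.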
Distributed Training Basics
Data Parallelism (DDP)
Each GPU processes a different batch, and gradients are averaged across GPUs:
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize process group (assumes a torchrun launch, which sets LOCAL_RANK)
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

# Create model and wrap with DDP
model = MyTransformer(...).to(local_rank)
model = DDP(model, device_ids=[local_rank])

# Training loop (same as single-GPU)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()  # Gradients are automatically synchronized during backward
    optimizer.step()

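Why averaging works can be checked on a toy problem. The sketch below simulates two ranks in plain Python, with a scalar linear model and a mean squared error loss. `all_reduce_mean` is a hypothetical stand-in for the averaging that DDP performs, and equal shard sizes are assumed.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Stand-in for the gradient averaging across ranks."""
    return sum(values) / len(values)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5

# Two simulated ranks, each holding an equal-sized shard of the batch
shards = [(xs[:2], ys[:2]), (xs[2:], ys[2:])]
local_grads = [mse_grad(w, sx, sy) for sx, sy in shards]
synced_grad = all_reduce_mean(local_grads)
full_grad = mse_grad(w, xs, ys)
print(local_grads, synced_grad, full_grad)
```

With equal shard sizes and a mean-reduced loss, the averaged per-rank gradient matches the gradient of the full batch, so every rank applies the same update a single large-batch run would.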
Gradient Accumulation
Simulate larger batch sizes when memory is limited:
gradient_accumulation_steps = 4

def train_step(model, batch, optimizer, step):
    loss = model(batch)
    # Scale loss so the accumulated gradient is an average, not a sum
    loss = loss / gradient_accumulation_steps
    loss.backward()
    # Only update every N steps
    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
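Two bookkeeping questions come up with accumulation: what batch size is being simulated, and how many optimizer updates actually happen. The sketch below uses hypothetical helpers to replay the `(step + 1) % N == 0` rule without any model. The `world_size` factor is an assumption that applies only when accumulation is combined with data parallelism.

```python
def effective_batch_size(micro_batch_size, accumulation_steps, world_size=1):
    """Number of samples that contribute to each optimizer update."""
    return micro_batch_size * accumulation_steps * world_size


def count_updates(num_micro_batches, accumulation_steps):
    """Replay the (step + 1) % N == 0 rule and count optimizer updates.

    Returns (updates, pending), where pending is the number of trailing
    micro-batches whose gradients were accumulated but never applied.
    """
    updates = 0
    pending = 0
    for step in range(num_micro_batches):
        pending += 1
        if (step + 1) % accumulation_steps == 0:
            updates += 1
            pending = 0
    return updates, pending


print(effective_batch_size(8, 4))
print(count_updates(10, 4))
```

When the number of micro-batches is not a multiple of the accumulation steps, the trailing gradients are left unapplied under this rule. One common remedy is to also step on the last batch of an epoch.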