The Modern LLM Training Pipeline
Training a production LLM isn't a single step — it's a pipeline of stages, each building on the previous. Understanding this pipeline is crucial for understanding how models like GPT-4 and Claude are created.
The Three-Stage Pipeline
Pre-training: The Foundation
Pre-training is where the model learns the vast majority of its knowledge. It's trained to predict the next token on enormous datasets scraped from the internet.
The Objective: Next Token Prediction
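$$J(\theta) = -\frac{1}{T}\sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})$$
This is the cross-entropy loss: we want the model to assign high probability to the actual next token x_t given all previous tokens x_{<t}. Minimizing this loss means getting better at prediction.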
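Below is a minimal sketch of this loss in plain Python. It assumes the model has already produced logits over a toy 4-token vocabulary at each position. The logits and token ids are made up for illustration.

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def next_token_loss(step_logits, targets):
    """Average cross-entropy: -1/T * sum_t log P(x_t | x_<t).

    step_logits[t] holds the model's logits at position t (already
    conditioned on the prefix); targets[t] is the true next token id.
    """
    total = 0.0
    for logits, tok in zip(step_logits, targets):
        total -= log_softmax(logits)[tok]
    return total / len(targets)

# Toy vocabulary of 4 tokens, 3 prediction steps (illustrative numbers).
step_logits = [
    [2.0, 0.5, 0.1, -1.0],
    [0.0, 3.0, 0.0, 0.0],
    [1.0, 1.0, 1.0, 1.0],   # uniform: the model has no preference
]
targets = [0, 1, 2]
print(next_token_loss(step_logits, targets))
```

A model that is uniform over V tokens pays log V per token. Confident, correct predictions push the loss toward zero.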
Training at Scale
GPT-3's training run involved:
- 300 billion tokens (Common Crawl, WebText, Books, Wikipedia)
- 175 billion parameters
- 3.14 × 10²³ FLOPs of compute
- ~$4.6 million in estimated compute costs (at the time)
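These figures are tied together by a widely used rule of thumb. Training a dense transformer costs about 6 FLOPs per parameter per token: roughly 2 for the forward pass and 4 for the backward pass. Here is a quick check that treats the 6ND approximation as an estimate rather than an exact count:

```python
def training_flops(n_params: float, n_tokens: float) -> float:
    # C ≈ 6·N·D: ~2 FLOPs/param/token forward, ~4 backward.
    return 6.0 * n_params * n_tokens

gpt3 = training_flops(175e9, 300e9)
print(f"{gpt3:.2e} FLOPs")  # 3.15e+23, close to the reported 3.14 × 10²³
```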
Optimization: How Weights Update
Training means adjusting billions of weights to minimize loss. We use optimization algorithms that compute which direction to adjust each weight.
Gradient Descent
$$\theta \leftarrow \theta - \alpha \, \nabla J(\theta)$$
Where:
- θ (theta): The model's parameters (weights)
- α (alpha): Learning rate — how big of a step to take
- ∇J(θ): Gradient — which direction reduces loss
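To make the update rule concrete, here is plain gradient descent on the toy one-parameter loss J(θ) = (θ − 3)², whose gradient is 2(θ − 3). The loss, starting point, and learning rate are illustrative choices, not values from any real model:

```python
def grad_J(theta: float) -> float:
    # Gradient of the toy loss J(theta) = (theta - 3)^2.
    return 2.0 * (theta - 3.0)

def gradient_descent(theta: float, alpha: float, steps: int) -> float:
    for _ in range(steps):
        theta = theta - alpha * grad_J(theta)   # θ ← θ − α ∇J(θ)
    return theta

theta = gradient_descent(theta=0.0, alpha=0.1, steps=100)
print(theta)  # converges toward the minimum at θ = 3
```

Each step multiplies the distance to the minimum by (1 − 2α). With α = 0.1 that factor is 0.8, so the loop converges. With α above 1.0 the factor's magnitude exceeds 1 and the same loop diverges, which is why the learning rate matters.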
Modern Optimizers
Optimizer Comparison
| Optimizer | Description |
|---|---|
| SGD | Basic gradient descent. Simple but slow. |
| Momentum | Accelerates in consistent directions. |
| Adam | Adaptive learning rates per parameter. Most popular. |
| AdamW | Adam with decoupled weight decay. The standard choice for LLMs. |
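The Adam update, where g_t = ∇J(θ_{t-1}) is the gradient at step t:
$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t \qquad \text{(momentum)}$$
$$v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \qquad \text{(second moment)}$$
$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t} \qquad \text{(bias correction)}$$
$$\theta_t = \theta_{t-1} - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
AdamW additionally subtracts a weight-decay term αλθ_{t-1} directly from the weights, instead of adding λθ to the gradient.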
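The update translates almost line for line into code. This sketch performs one AdamW step over a flat list of parameters in plain Python. The defaults β₁ = 0.9, β₂ = 0.999, ε = 1e-8 are the commonly used values. The weight-decay coefficient of 0.01 is an illustrative assumption.

```python
import math

def adamw_step(theta, grad, m, v, t, alpha=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update over lists of floats; t is the 1-based step count."""
    new_theta, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(theta, grad, m, v):
        m_i = beta1 * m_i + (1 - beta1) * g          # momentum
        v_i = beta2 * v_i + (1 - beta2) * g * g      # second moment
        m_hat = m_i / (1 - beta1 ** t)               # bias correction
        v_hat = v_i / (1 - beta2 ** t)
        p = p - alpha * weight_decay * p             # decoupled weight decay (AdamW)
        p = p - alpha * m_hat / (math.sqrt(v_hat) + eps)
        new_theta.append(p)
        new_m.append(m_i)
        new_v.append(v_i)
    return new_theta, new_m, new_v

theta, m, v = [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]
theta, m, v = adamw_step(theta, [4.0, 0.004], m, v, t=1, alpha=0.1)
print(theta)  # both weights move by ~0.1 despite a 1000x difference in gradient size
```

On the first step the bias-corrected ratio m̂/√v̂ is ±1. Every parameter therefore moves by about α regardless of its gradient's scale. This per-parameter normalization is what "adaptive learning rates" refers to.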
Backpropagation: Computing Gradients
How do we compute ∇J(θ) for billions of parameters? Backpropagation — the chain rule applied to computational graphs.
The Chain Rule
If y = f(g(x)), then:
$$\frac{dy}{dx} = \frac{dy}{dg} \cdot \frac{dg}{dx} = f'(g(x)) \, g'(x)$$
Neural networks are compositions of functions, so we apply the chain rule repeatedly, working backwards from the output to compute gradients for every parameter.
Computational Graph
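A neural network is a computational graph. The forward pass computes values from the input to the loss. The backward pass traverses the same graph in reverse, multiplying one local derivative at each step to obtain the gradient of every parameter:
Forward: x, W₁ → z₁ → h → ŷ → L
Backward: ∂L/∂ŷ → ∂L/∂h → ∂L/∂z₁ → ∂L/∂W₁, so ∂L/∂W₁ = (∂L/∂ŷ) · (∂ŷ/∂h) · (∂h/∂z₁) · (∂z₁/∂W₁)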
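Here is a small worked backward pass. It assumes a toy network ŷ = w₂·tanh(w₁·x) with squared-error loss L = (ŷ − y)². The weights are scalars so every derivative is easy to read. The chain-rule gradient is compared against an independent finite-difference estimate:

```python
import math

def forward(w1, w2, x, y):
    z1 = w1 * x            # linear layer
    h = math.tanh(z1)      # activation
    y_hat = w2 * h         # output layer
    loss = (y_hat - y) ** 2
    return z1, h, y_hat, loss

def backward(w1, w2, x, y):
    z1, h, y_hat, _ = forward(w1, w2, x, y)
    dL_dyhat = 2 * (y_hat - y)           # ∂L/∂ŷ
    dL_dw2 = dL_dyhat * h                # ∂ŷ/∂w2 = h
    dL_dh = dL_dyhat * w2                # ∂ŷ/∂h = w2
    dL_dz1 = dL_dh * (1 - h ** 2)        # ∂h/∂z1 = 1 - tanh²(z1)
    dL_dw1 = dL_dz1 * x                  # ∂z1/∂w1 = x
    return dL_dw1, dL_dw2

def numerical_grad(w1, w2, x, y, eps=1e-6):
    # Central differences: an independent check on the chain-rule result.
    g1 = (forward(w1 + eps, w2, x, y)[3] - forward(w1 - eps, w2, x, y)[3]) / (2 * eps)
    g2 = (forward(w1, w2 + eps, x, y)[3] - forward(w1, w2 - eps, x, y)[3]) / (2 * eps)
    return g1, g2

print(backward(0.5, -1.2, x=2.0, y=0.3))
print(numerical_grad(0.5, -1.2, x=2.0, y=0.3))   # the two should agree closely
```

Frameworks automate exactly this bookkeeping. They record the forward graph, then apply each node's local derivative in reverse order.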
RLHF: Learning from Human Preferences
Pre-trained models predict text but aren't helpful assistants. RLHF (Reinforcement Learning from Human Feedback) aligns them with human preferences.
The RLHF Pipeline
RLHF Steps
1. Collect comparisons: Humans rank multiple model outputs for the same prompt.
2. Train a reward model: Learn to predict human preferences from the rankings.
3. Optimize the policy: Use RL to maximize reward while staying close to the SFT model.
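Step 2 is commonly trained with a pairwise (Bradley-Terry style) loss. Take a prompt with a preferred response y_w and a rejected response y_l. The reward model r is pushed to score y_w above y_l by minimizing −log σ(r(y_w) − r(y_l)). Here is a minimal sketch of that loss, with made-up scores standing in for a real reward model's outputs:

```python
import math

def pairwise_reward_loss(r_chosen: float, r_rejected: float) -> float:
    """-log sigmoid(r_chosen - r_rejected), computed stably."""
    margin = r_chosen - r_rejected
    # -log(sigmoid(m)) = log(1 + exp(-m)); branch so exp never overflows.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))

print(pairwise_reward_loss(2.0, -1.0))  # small loss: ranking agrees with humans
print(pairwise_reward_loss(-1.0, 2.0))  # large loss: ranking disagrees
```

Only the difference of scores matters. The reward model learns a relative scale, which fits the fact that humans supplied rankings rather than absolute ratings.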
Why RLHF Works
- Direct optimization: Instead of imitating humans, optimize for what humans prefer
- Multi-objective: Balance helpfulness, harmlessness, honesty
- Generalization: Reward model generalizes to unseen prompts
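The RL step typically uses PPO (Proximal Policy Optimization), which maximizes a clipped surrogate objective:
$$L^{\text{CLIP}}(\theta) = \mathbb{E}\left[\min\left(r(\theta)\, A,\ \operatorname{clip}\left(r(\theta),\, 1-\epsilon,\, 1+\epsilon\right) A\right)\right]$$
where r(θ) = π_θ(a|s) / π_old(a|s) is the probability ratio between the new and old policies, A is the advantage estimate, and ε (often 0.2) limits how far a single update can move the policy.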
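The effect of clipping is easiest to see numerically. This sketch computes the PPO surrogate averaged over samples, which is the quantity being maximized. The log-probabilities and advantages are illustrative numbers, not real rollout data.

```python
import math

def ppo_clip_objective(logp_new, logp_old, advantages, eps=0.2):
    """Mean of min(r*A, clip(r, 1-eps, 1+eps)*A) over samples."""
    total = 0.0
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)            # r(θ) = π_θ / π_old
        clipped = min(max(ratio, 1 - eps), 1 + eps)  # clip(r, 1-ε, 1+ε)
        total += min(ratio * adv, clipped * adv)
    return total / len(advantages)

# A sample whose probability rose 50% (ratio 1.5) with positive advantage:
# the objective counts it as if the ratio were only 1.2, removing the
# incentive to push that probability even higher in this update.
print(ppo_clip_objective([math.log(1.5)], [0.0], [1.0]))
```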
Training Challenges at Scale
Training billion-parameter models isn't just "bigger batch size." It introduces unique challenges:
1. Memory
GPT-3's 175B parameters need 700GB just to store (4 bytes each). During training, we need:
- Parameters (4 bytes each)
- Gradients (4 bytes each)
- Optimizer states (8+ bytes each for Adam)
- Activations (for backprop)
Total: ~2-4TB of GPU memory! We use techniques like:
- Mixed precision: Compute in FP16 or BF16 instead of FP32 (typically keeping an FP32 master copy of the weights)
- Gradient checkpointing: Recompute activations instead of storing
- Model parallelism: Split model across GPUs
- ZeRO: Partition optimizer states across GPUs
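The memory figures above follow from simple per-parameter accounting. The rough calculator below assumes FP32 parameters and gradients plus Adam's two FP32 state tensors, for 16 bytes per parameter. It deliberately leaves out activations, which depend on batch size and sequence length. The even split across GPUs is an idealized model of full ZeRO-style partitioning.

```python
def training_state_bytes(n_params: float,
                         param_bytes: int = 4,       # FP32 weights
                         grad_bytes: int = 4,        # FP32 gradients
                         optim_bytes: int = 8) -> float:  # Adam m and v, FP32 each
    return n_params * (param_bytes + grad_bytes + optim_bytes)

def zero_sharded_per_gpu(total_bytes: float, n_gpus: int) -> float:
    # Fully partitioning the states (ZeRO-3 style) divides them evenly.
    return total_bytes / n_gpus

n = 175e9  # GPT-3
total = training_state_bytes(n)
print(f"weights only:        {n * 4 / 1e12:.2f} TB")        # 0.70 TB
print(f"weights+grads+Adam:  {total / 1e12:.2f} TB")          # 2.80 TB
print(f"per GPU over 64 GPUs: {zero_sharded_per_gpu(total, 64) / 1e9:.2f} GB")
```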
2. Stability
Deep networks are hard to train. Solutions include:
- Layer normalization: Normalize activations
- Residual connections: Skip connections help gradients flow
- Learning rate warmup: Start small, gradually increase
- Gradient clipping: Prevent exploding gradients
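Two of these stability techniques are short enough to show directly. This sketch implements global-norm gradient clipping and a linear-warmup learning-rate schedule. The max norm of 1.0, the peak rate, and the warmup length are illustrative values. The cosine decay after warmup is an assumption, since many schedules are used in practice.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Scale all gradients together so their combined L2 norm is <= max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return grads, total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm

def lr_schedule(step, peak_lr=3e-4, warmup_steps=2000, total_steps=100_000):
    """Linear warmup to peak_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * min(progress, 1.0)))

clipped, norm = clip_by_global_norm([3.0, 4.0])
print(norm, clipped)   # norm 5.0 is scaled down to 1.0; direction is preserved
```

Clipping the global norm, rather than each value separately, keeps the update direction intact and only shrinks its length.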
3. Convergence
Training takes weeks on thousands of GPUs. Any crash is expensive. Solutions:
- Checkpointing: Save model state frequently
- Fault tolerance: Handle GPU failures gracefully
- Monitoring: Watch loss curves, gradient norms, learning rates
Data: The Fuel for Training
The quality of training data is arguably the single most important factor in LLM performance. Models are limited by what they see during training — garbage in, garbage out.
What Data Do LLMs Train On?
Typical Pre-training Data Composition
| Data Source | Proportion | Quality | Key Properties |
|---|---|---|---|
| Web crawl (CommonCrawl) | 60-80% | Variable | Huge scale, lots of noise, many languages |
| Curated web (Wikipedia, etc.) | 5-15% | High | Well-structured, factual, multi-language |
| Books | 5-15% | Very High | Long-form reasoning, diverse knowledge |
| Code (GitHub) | 5-20% | High | Programming knowledge, logical reasoning |
| Scientific papers | 1-5% | Very High | Specialized knowledge, reasoning |
| Synthetic data | 0-30% | Varies | Generated by other models, targeted skills |
Data Curation Pipeline
Raw web data is messy. A typical pipeline includes:
- Language filtering: Remove non-target languages using fast classifiers
- Quality filtering: Remove low-quality content (spam, machine-translated text, boilerplate)
- Deduplication: Remove exact and near-duplicate documents (crucial for preventing memorization)
- PII removal: Remove personally identifiable information (email addresses, phone numbers)
- Toxic content filtering: Remove harmful, abusive, or explicit content
- Quality scoring: Rank documents by quality metrics (perplexity under a reference model, etc.)
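Deduplication illustrates how these steps work in code. In this minimal sketch, exact duplicates are caught by hashing normalized text. Near-duplicates are caught by the Jaccard similarity of word 3-gram shingles. The 0.8 threshold and 3-word shingles are illustrative assumptions. The pairwise comparison is O(n²); production pipelines typically approximate Jaccard with MinHash so they can scale to billions of documents.

```python
import hashlib

def normalize(text: str) -> str:
    return " ".join(text.lower().split())

def shingles(text: str, n: int = 3) -> set:
    words = normalize(text).split()
    return {tuple(words[i:i + n]) for i in range(max(1, len(words) - n + 1))}

def jaccard(a: set, b: set) -> float:
    return len(a & b) / len(a | b) if a | b else 1.0

def deduplicate(docs, threshold=0.8):
    seen_hashes, kept, kept_shingles = set(), [], []
    for doc in docs:
        h = hashlib.sha256(normalize(doc).encode()).hexdigest()
        if h in seen_hashes:
            continue                      # exact duplicate (after normalization)
        s = shingles(doc)
        if any(jaccard(s, other) >= threshold for other in kept_shingles):
            continue                      # near-duplicate of a kept document
        seen_hashes.add(h)
        kept.append(doc)
        kept_shingles.append(s)
    return kept

base = "the quick brown fox jumps over the lazy dog near the river bank today"
docs = [base, base.upper(), base.replace("today", "tonight"), "GPUs need fast interconnects"]
print(deduplicate(docs))
```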
Scaling Laws: Predicting Performance
One of the most important discoveries in deep learning is that LLM performance follows predictable scaling laws — smooth curves that relate model size, data size, and compute budget to final loss.
The Power Law
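The key insight: loss decreases as a power law with respect to model size (N), data size (D), and compute budget (C):
$$L(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N} \qquad \text{(loss vs. parameters)}$$
$$L(D) \approx \left(\frac{D_c}{D}\right)^{\alpha_D} \qquad \text{(loss vs. data)}$$
$$L(C) \approx \left(\frac{C_c}{C}\right)^{\alpha_C} \qquad \text{(loss vs. compute)}$$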
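Where α_N ≈ 0.076, α_D ≈ 0.095, α_C ≈ 0.050 (from Kaplan et al., 2020; the later Chinchilla paper by Hoffmann et al. revisited how compute should be split between N and D). Because these exponents are small, each improvement is expensive. 10× more parameters lowers the loss by only about 16% (10^−0.076 ≈ 0.84). Halving the loss through model size alone would take roughly 2^(1/0.076) ≈ 9,000× more parameters.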
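A quick way to see what these exponents imply is to compare two scales of the same power law. The constants N_c, D_c, C_c cancel in that ratio, so they are not needed. The exponents below are the ones quoted above.

```python
ALPHA = {"params": 0.076, "data": 0.095, "compute": 0.050}

def loss_ratio(scale_factor: float, alpha: float) -> float:
    # L(k·X) / L(X) = k^(-alpha) for a pure power law L(X) = (X_c / X)^alpha
    return scale_factor ** (-alpha)

def factor_to_reduce_loss(target_ratio: float, alpha: float) -> float:
    # Solve k^(-alpha) = target_ratio for the scale factor k.
    return target_ratio ** (-1.0 / alpha)

for name, a in ALPHA.items():
    drop = 1 - loss_ratio(10, a)
    k_half = factor_to_reduce_loss(0.5, a)
    print(f"{name:8s} 10x -> loss down {drop:.0%}; halving needs {k_half:,.0f}x")
```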
Practical Implications
Chinchilla-Optimal Model Sizes
| Parameters | Optimal Training Tokens | Compute Budget (FLOPs) | Approx. GPU-Hours |
|---|---|---|---|
| 400M | 8B | 1.9 × 10¹⁹ | ~90 A100-hours |
| 1B | 20B | 1.2 × 10²⁰ | ~600 A100-hours |
| 7B | 140B | 5.9 × 10²¹ | ~29,000 A100-hours |
| 70B | 1.4T | 5.9 × 10²³ | ~2.9M A100-hours |
| 175B | 3.5T | 3.7 × 10²⁴ | ~18M A100-hours |
Token counts follow the Chinchilla rule of thumb of about 20 tokens per parameter. Compute uses the standard C ≈ 6ND estimate.
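The allocation can also be run in reverse. Given a compute budget C, the rule of thumb D ≈ 20N together with C ≈ 6ND gives C ≈ 120N², so N ≈ √(C / 120) and D ≈ 20N. The sketch below relies on exactly those two approximations; the actual Chinchilla fits are more nuanced.

```python
import math

TOKENS_PER_PARAM = 20  # Chinchilla rule of thumb

def compute_optimal(compute_flops: float):
    # C = 6 * N * D and D = 20 * N  =>  C = 120 * N^2
    n_params = math.sqrt(compute_flops / (6 * TOKENS_PER_PARAM))
    n_tokens = TOKENS_PER_PARAM * n_params
    return n_params, n_tokens

# A budget of about 5.9e23 FLOPs (6 × 70B × 1.4T):
n, d = compute_optimal(5.9e23)
print(f"N ≈ {n / 1e9:.0f}B params, D ≈ {d / 1e12:.1f}T tokens")  # the 70B / 1.4T row
```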
Distributed Training: Training on Many GPUs
Modern LLMs require thousands of GPUs working together. A training run that would take 100 years on one GPU can be done in weeks on thousands of GPUs. But distributing training introduces significant engineering challenges.
Data Parallelism
The simplest approach: each GPU holds a complete copy of the model, but processes different data batches. Gradients are averaged across all GPUs before each update.
Limitation: Each GPU must hold the entire model. DeepSpeed ZeRO addresses this by partitioning optimizer states, gradients, and parameters across GPUs.
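Averaging works because the loss is a mean over examples. The average of equal-sized shard gradients therefore equals the full-batch gradient, and every replica applies the same update. The single-process simulation below assumes a toy linear model y ≈ w·x with squared error. Four hypothetical "GPUs" each see an equal slice of the batch.

```python
def grad_mse(w, xs, ys):
    # d/dw of mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def data_parallel_grad(w, xs, ys, n_gpus):
    shard = len(xs) // n_gpus   # assumes the batch divides evenly
    local = [grad_mse(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
             for i in range(n_gpus)]        # each "GPU" computes its own gradient
    return sum(local) / n_gpus              # all-reduce (average) across GPUs

xs = [float(i) for i in range(8)]
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5
print(grad_mse(w, xs, ys), data_parallel_grad(w, xs, ys, n_gpus=4))
```

If shards had unequal sizes, the average would need to be weighted by shard size to stay exact.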
Model Parallelism
Parallelism Strategies
- Tensor parallelism: Split each layer's weight matrices across GPUs. E.g., split an 8192×8192 weight matrix across 4 GPUs, each holding an 8192×2048 slice.
- Pipeline parallelism: Place different layers on different GPUs. GPU 1 handles layers 1-12, GPU 2 handles 13-24, etc. Requires careful scheduling (e.g., micro-batching) to minimize pipeline bubbles, where GPUs sit idle waiting for work.
- ZeRO (sharded data parallelism): Partition optimizer states, gradients, and parameters across GPUs so that no single GPU holds the full model state. ZeRO-3 can train 1T+ parameter models.
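The tensor-parallel split can be checked at a small scale. Cut W column-wise into k blocks, let each "GPU" multiply the same input by its own block, and concatenate the partial outputs. The result reproduces the full product. This pure-Python sketch uses a 4×4 matrix over 2 hypothetical devices in place of 8192×8192 over 4.

```python
def row_times_matrix(x, W):
    # y = x @ W for a row vector x and a matrix W given as a list of rows.
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def column_shards(W, k):
    # Device r gets columns [r*c, (r+1)*c) of every row.
    c = len(W[0]) // k
    return [[row[r * c:(r + 1) * c] for row in W] for r in range(k)]

def tensor_parallel_matmul(x, W, k):
    partials = [row_times_matrix(x, shard) for shard in column_shards(W, k)]  # one per device
    return [v for part in partials for v in part]                            # concatenate outputs

W = [[float(4 * i + j) for j in range(4)] for i in range(4)]
x = [1.0, -1.0, 0.5, 2.0]
print(row_times_matrix(x, W) == tensor_parallel_matmul(x, W, k=2))  # True
```

The final concatenation corresponds to an all-gather across devices. Each device stores only its slice of W.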
Communication Overhead
Distributed training requires GPUs to communicate — sharing gradients, synchronizing parameters, passing activations. The communication cost depends on the interconnect:
| Interconnect | Bandwidth | Latency | Use Case |
|---|---|---|---|
| PCIe 4.0 (x16) | ~32 GB/s | ~1μs | Within a single server |
| NVLink | ~300 GB/s | ~0.5μs | GPU-to-GPU (same node) |
| InfiniBand (HDR) | ~200 Gb/s (~25 GB/s) | ~0.6μs | Between nodes |
| Ethernet | ~100 Gb/s (~12.5 GB/s) | ~5μs | Slower inter-node |
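The following back-of-the-envelope estimate shows why the interconnect matters. It times one gradient all-reduce using the standard ring algorithm, in which each GPU sends and receives about 2(p − 1)/p times the gradient size. This bandwidth-only model ignores latency and any overlap with computation. The setup of 7B parameters, BF16 gradients, and 8 GPUs is an illustrative assumption.

```python
def ring_allreduce_seconds(n_bytes: float, n_gpus: int, bandwidth_gb_s: float) -> float:
    # Each GPU transfers 2 * (p - 1) / p * n_bytes in a ring all-reduce.
    traffic = 2 * (n_gpus - 1) / n_gpus * n_bytes
    return traffic / (bandwidth_gb_s * 1e9)

grad_bytes = 7e9 * 2   # 7B parameters, 2-byte (BF16) gradients = 14 GB
links = [("NVLink", 300),            # GB/s
         ("InfiniBand HDR", 25),     # 200 Gb/s ≈ 25 GB/s
         ("Ethernet 100G", 12.5)]    # 100 Gb/s ≈ 12.5 GB/s
for name, gb_s in links:
    print(f"{name:15s} {ring_allreduce_seconds(grad_bytes, 8, gb_s):6.2f} s per step")
```

Even this simple model implies roughly a 24× gap between NVLink and 100G Ethernet. That gap is why bandwidth-heavy schemes like tensor parallelism are usually kept within a node.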
Fine-Tuning: Adapting the Base Model
Pre-training produces a "base model" that predicts the next token. But we want models that follow instructions, answer questions, and behave safely. Fine-tuning adapts the pre-trained model to specific tasks.
Supervised Fine-Tuning (SFT)
The simplest form: continue training the model on high-quality instruction-response pairs.
SFT Training Example
Input:
<|user|> Explain quantum computing in simple terms.<|assistant|>
Target:
Quantum computing uses quantum mechanical phenomena like superposition and entanglement to process information. Unlike classical bits that are 0 or 1, quantum bits (qubits) can be in multiple states simultaneously...<|end|>
SFT data is expensive — each example requires human writing. Typical datasets: 10K-100K high-quality conversations.
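In practice the SFT loss is usually computed only on the target (assistant) tokens. The model then learns to produce responses rather than to reproduce prompts. This sketch shows that masking. It assumes per-token negative log-likelihoods have already been computed, and the token counts are hypothetical.

```python
def masked_sft_loss(token_nll, is_target):
    """Average NLL over target tokens only; prompt tokens are ignored."""
    picked = [nll for nll, keep in zip(token_nll, is_target) if keep]
    if not picked:
        raise ValueError("example has no target tokens")
    return sum(picked) / len(picked)

# Hypothetical example: 5 prompt tokens (<|user|> ... <|assistant|>), then 4 response tokens.
token_nll = [4.1, 3.7, 5.2, 2.9, 0.3, 1.2, 0.8, 0.5, 0.1]
is_target = [False] * 5 + [True] * 4
print(masked_sft_loss(token_nll, is_target))  # averages only the last 4 values
```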
Parameter-Efficient Fine-Tuning (PEFT)
Full fine-tuning updates all parameters, which is expensive and can cause catastrophic forgetting. PEFT methods update only a small number of parameters:
LoRA: Low-Rank Adaptation
Instead of updating the full weight matrix W ∈ ℝ^(d×d), LoRA freezes W and learns a low-rank update built from two small matrices:
$$W' = W + \Delta W = W + BA$$
where B ∈ ℝ^(d×r), A ∈ ℝ^(r×d), r ≪ d
For d = 4096 and r = 8, we reduce learnable parameters from d² ≈ 16.8M to 2dr ≈ 65.5K, a 256× reduction!
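Here is a minimal LoRA layer in plain Python. It assumes the common initialization in which A is small and random and B starts at zero, so the adapted layer initially behaves exactly like the frozen one. Shapes follow the text (W is d×d, B is d×r, A is r×d), with toy sizes d = 6 and r = 2. The identity matrix is only a stand-in for real pretrained weights.

```python
import random

def matvec(M, x):
    # M is a list of rows; returns M @ x.
    return [sum(m_ij * x_j for m_ij, x_j in zip(row, x)) for row in M]

class LoRALinear:
    def __init__(self, W, r, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(W), len(W[0])
        self.W = W                                                    # frozen
        self.A = [[rng.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]  # r × d_in, trainable
        self.B = [[0.0] * r for _ in range(d_out)]                    # d_out × r, trainable, zero-init

    def forward(self, x):
        # W'x = Wx + B(Ax): the update BA is never materialized as a d×d matrix.
        base = matvec(self.W, x)
        delta = matvec(self.B, matvec(self.A, x))
        return [b + dl for b, dl in zip(base, delta)]

    def trainable_params(self):
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])

d, r = 6, 2
W = [[float(i == j) for j in range(d)] for i in range(d)]
layer = LoRALinear(W, r)
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(layer.forward(x) == matvec(W, x))                          # True: B = 0 at init
print(layer.trainable_params(), "trainable vs", d * d, "frozen")  # 2dr vs d²
```

After training, BA can be added into W once, so inference costs nothing extra.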
| Method | Trainable Params | Memory | Quality |
|---|---|---|---|
| Full Fine-Tuning | 100% | Very High | Best |
| LoRA (r=8) | ~0.1% | Low | Near full |
| LoRA (r=16) | ~0.2% | Low | Comparable |
| QLoRA (4-bit) | ~0.1% | Very Low | Good |
| Prompt Tuning | ~0.01% | Minimal | Decent |
Direct Preference Optimization (DPO)
While RLHF was the original alignment method, a newer approach called Direct Preference Optimization has become increasingly popular. DPO eliminates the need for a separate reward model and RL training loop.
The DPO Insight
RLHF requires training a reward model, then optimizing the policy with PPO. DPO shows that you can directly optimize the policy using the preference data, skipping the reward model entirely.
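$$\mathcal{L}_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$$
Where x is the prompt, y_w is the preferred response, y_l is the dispreferred response, π_θ is the policy being trained, π_ref is the reference (SFT) model, σ is the logistic sigmoid, and β controls deviation from the reference.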
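The DPO loss for one preference pair needs only sequence log-probabilities under the policy and the frozen reference model. It is a logistic loss on the difference of their log-ratios. In this minimal sketch, the log-probability values and β = 0.1 are illustrative assumptions.

```python
import math

def dpo_loss(pi_logp_w, pi_logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """-log sigmoid(beta * [log(pi/ref)(y_w) - log(pi/ref)(y_l)]) for one pair."""
    margin = beta * ((pi_logp_w - ref_logp_w) - (pi_logp_l - ref_logp_l))
    # -log(sigmoid(m)) = log(1 + exp(-m)), written stably for both signs of m.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))

# At initialization the policy equals the reference: margin 0, loss log 2.
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
# Raising the preferred response's probability relative to the reference lowers the loss.
print(dpo_loss(-10.0, -15.0, -12.0, -15.0))
```

The term β·log(π_θ/π_ref) acts as an implicit reward, which is why no separate reward model is needed.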
RLHF vs DPO Comparison
| Aspect | RLHF | DPO |
|---|---|---|
| Requires reward model | Yes | No |
| Requires RL optimization | Yes (PPO) | No |
| Training stability | Can be unstable | Stable (classification loss) |
| Implementation complexity | High (4 models needed) | Low (2 models needed) |
| GPU memory | Very high | Moderate |
| Performance | Strong | Comparable or better |
DPO has become a widely used approach for alignment training. Models like LLaMA 3, Mistral, and many others use DPO or related methods (IPO, KTO), sometimes in combination with other techniques.
Evaluation: How Do We Know If It's Good?
Measuring LLM quality is surprisingly hard. No single metric captures everything.
Language Modeling Metrics
Core Evaluation Metrics
| Metric | What It Measures | How It Works |
|---|---|---|
| Perplexity | Next-token prediction quality | exp(average negative log-likelihood per token). Lower is better; only comparable across models that share a tokenizer and test set. |
| HellaSwag | Common-sense reasoning | Choose the correct sentence completion from 4 options. |
| MMLU | Multi-task knowledge | 57 subjects, from math to law to medicine. 4-way multiple choice. |
| HumanEval | Code generation | Generate Python functions from docstrings; scored by whether the code passes unit tests. |
| GSM8K | Math reasoning | ~8,500 grade-school math word problems, often solved with chain-of-thought prompting. |
| MT-Bench | Chat quality | Multi-turn conversations judged by GPT-4 or humans. |
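Perplexity is the exponential of the average next-token cross-entropy. This short sketch computes it from per-token log-probabilities, using hypothetical values. A useful sanity check: a model that spreads probability uniformly over a vocabulary of V tokens has perplexity exactly V.

```python
import math

def perplexity(token_logprobs):
    """exp(-mean log p) over the evaluated tokens (natural log)."""
    avg_nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(avg_nll)

# ≈ 4: the inverse geometric mean of 0.5, 0.25, 0.125
print(perplexity([math.log(0.5), math.log(0.25), math.log(0.125)]))
vocab = 50_000
print(perplexity([math.log(1 / vocab)] * 10))  # uniform guessing: ≈ 50000
```

Perplexity can be read as the effective number of tokens the model is choosing between at each step. This reading also explains the tokenizer caveat: different vocabularies split the same text into different numbers of tokens.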
The Evaluation Problem
- Contamination: Test data may appear in training data, inflating scores
- Gaming benchmarks: Models can overfit to specific benchmarks without true improvement
- Instruction following vs. knowledge: A model may know the answer but not follow format
- Safety and alignment: Benchmarks don't capture whether a model is safe and helpful
- Generalization: Performance on benchmarks doesn't guarantee real-world performance
Inference Optimization: Making Models Fast
A trained model is useless if it's too slow or expensive to run. Optimizing inference (running the model) is crucial for deployment.
Quantization
Reducing the precision of model weights from 32-bit or 16-bit floats to lower precision:
Quantization Levels
| Precision | Bits per Param | Weight Memory (70B model) | Quality Loss |
|---|---|---|---|
| FP32 | 32 | 280 GB | None (baseline) |
| FP16/BF16 | 16 | 140 GB | Negligible |
| INT8 | 8 | 70 GB | Very small |
| INT4 (GPTQ) | 4 | 35 GB | Small, acceptable |
| INT3 | 3 | 26 GB | Noticeable |
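The memory column is simply parameters × bits / 8. The sketch below also implements the simplest quantization scheme, symmetric per-tensor INT8: one scale for the whole tensor, with values rounded to integers in [−127, 127]. Methods such as GPTQ are considerably more sophisticated, so treat this only as an illustration of where the rounding error comes from. The weight values are made up.

```python
def weight_memory_gb(n_params: float, bits: int) -> float:
    return n_params * bits / 8 / 1e9

def quantize_int8(weights):
    """Symmetric per-tensor quantization: w ≈ scale * q, q an int in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale

def dequantize(q, scale):
    return [scale * v for v in q]

w = [0.12, -0.5, 0.033, 0.9, -0.77]
q, s = quantize_int8(w)
err = max(abs(a - b) for a, b in zip(w, dequantize(q, s)))
print(q, f"max error {err:.4f} (at most scale/2 = {s / 2:.4f})")
print(weight_memory_gb(70e9, 4))   # 35.0 GB, matching the INT4 row
```

A single outlier weight inflates the scale and coarsens every other value. This is one reason practical schemes use per-channel or per-group scales.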
Other Inference Techniques
- Flash Attention: Memory-efficient attention that avoids materializing the full attention matrix. 2-4× faster attention.
- Speculative Decoding: Use a small draft model to guess tokens, then verify with the large model. Can double throughput.
- Batching: Process multiple requests together on the same GPU. Essential for throughput.
- Continuous Batching: Dynamically add and remove requests from a batch as they complete. Maximizes GPU utilization.
- vLLM / PagedAttention: Manage KV cache memory like virtual memory with paging. Reduces KV cache memory waste from 60-80% to under 4%.
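Here is a toy sketch of greedy speculative decoding. The draft model proposes k tokens and the large model checks them. The longest agreeing prefix is accepted, plus one token from the large model. Both "models" below are hypothetical deterministic next-token functions over integers. In a real system the k checks happen in a single batched forward pass of the large model, which is where the speedup comes from. With greedy verification, the output is identical to decoding with the large model alone.

```python
def greedy_decode(model, prompt, n_new):
    seq = list(prompt)
    for _ in range(n_new):
        seq.append(model(seq))
    return seq

def speculative_decode(target, draft, prompt, n_new, k=4):
    seq = list(prompt)
    goal = len(prompt) + n_new
    while len(seq) < goal:
        # 1. The draft proposes k tokens autoregressively (cheap).
        proposal = []
        for _ in range(k):
            proposal.append(draft(seq + proposal))
        # 2. The target verifies each position (one batched pass in practice).
        accepted = []
        for tok in proposal:
            expected = target(seq + accepted)
            if tok != expected:
                accepted.append(expected)   # replace the first mismatch with the target's token
                break
            accepted.append(tok)
        else:
            accepted.append(target(seq + accepted))  # all k accepted: one bonus token
        seq.extend(accepted)
    return seq[:goal]

def target(seq):
    return seq[-1] + 1                      # hypothetical large model: counts up by 1

def draft(seq):
    # Hypothetical draft: agrees with the target except before multiples of 5.
    return seq[-1] + (2 if (seq[-1] + 1) % 5 == 0 else 1)

print(speculative_decode(target, draft, [0], 12) == greedy_decode(target, [0], 12))  # True
```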