The Problem: Internal Covariate Shift
As weights change during training, the distribution of inputs to each layer also changes. This makes training unstable and slow.
Internal Covariate Shift: The change in the distribution of network activations
due to the change in network parameters during training.
Each layer must continuously adapt to new input distributions, making learning difficult.
The Solution: Batch Normalization
Normalize each feature over the mini-batch to mean 0 and variance 1, then apply a learned scale and shift:
# Batch Normalization (training)
μ = mean(x over batch)
σ² = variance(x over batch)
x̂ = (x - μ) / √(σ² + ε) # Normalize; ε is a small constant (e.g. 1e-5) that avoids division by zero
y = γ·x̂ + β # Scale and shift
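γ (gamma) and β (beta) are learned per-feature parameters; they let the network undo the normalization (γ = √(σ² + ε), β = μ) if that turns out to be optimal.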
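Here is a minimal pure-Python sketch of the training-mode computation above. It assumes the batch is a list of N examples with D features each. The function name `batch_norm_train` and the default `eps=1e-5` are illustrative choices, not a library API. Statistics are computed per feature, i.e. down each column of the batch.

```python
import math

def batch_norm_train(batch, gamma, beta, eps=1e-5):
    """Training-mode batch norm (illustrative sketch, not a library API).

    batch: N examples, each a list of D features.
    gamma, beta: learned per-feature scale and shift (length D).
    Returns (output batch, per-feature mean, per-feature variance).
    """
    n, d = len(batch), len(batch[0])
    # Statistics per feature, computed across the batch dimension
    mean = [sum(row[j] for row in batch) / n for j in range(d)]
    var = [sum((row[j] - mean[j]) ** 2 for row in batch) / n for j in range(d)]
    out = []
    for row in batch:
        # Normalize, then scale and shift with the learned parameters
        x_hat = [(row[j] - mean[j]) / math.sqrt(var[j] + eps) for j in range(d)]
        out.append([gamma[j] * x_hat[j] + beta[j] for j in range(d)])
    return out, mean, var

batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
y, mu, var = batch_norm_train(batch, gamma=[1.0, 1.0], beta=[0.0, 0.0])
print(mu)   # [2.5, 25.0]
print(var)  # [1.25, 125.0]
```

The second feature is exactly 10× the first, so both columns normalize to almost the same values. Only ε keeps them from being identical.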
At Test Time
Use running averages of μ and σ² collected during training instead of the current batch's statistics:
# Batch Normalization (inference)
x̂ = (x - μ_running) / √(σ²_running + ε)
y = γ·x̂ + β
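Frameworks differ in how they gather the running averages. This sketch assumes an exponential moving average with `momentum=0.1`, using the update `running = (1 - momentum) * running + momentum * batch_stat`. That is the convention PyTorch's `BatchNorm1d` uses by default. The class `BatchNorm1D` itself is hypothetical, and γ and β are held fixed instead of trained. PyTorch stores the unbiased batch variance in its running estimate; this sketch keeps the biased one for brevity.

```python
import math

class BatchNorm1D:
    """Hypothetical minimal batch norm over D features with running statistics."""

    def __init__(self, d, momentum=0.1, eps=1e-5):
        self.gamma = [1.0] * d             # learned scale (kept fixed in this sketch)
        self.beta = [0.0] * d              # learned shift (kept fixed in this sketch)
        self.running_mean = [0.0] * d
        self.running_var = [1.0] * d
        self.momentum = momentum           # assumed EMA weight given to the newest batch
        self.eps = eps

    def _normalize(self, row, mean, var):
        return [
            self.gamma[j] * (row[j] - mean[j]) / math.sqrt(var[j] + self.eps) + self.beta[j]
            for j in range(len(row))
        ]

    def forward(self, batch, training):
        d = len(batch[0])
        if training:
            n = len(batch)
            mean = [sum(r[j] for r in batch) / n for j in range(d)]
            var = [sum((r[j] - mean[j]) ** 2 for r in batch) / n for j in range(d)]
            # Exponential moving average of the batch statistics, kept for inference
            m = self.momentum
            self.running_mean = [(1 - m) * rm + m * bm for rm, bm in zip(self.running_mean, mean)]
            self.running_var = [(1 - m) * rv + m * bv for rv, bv in zip(self.running_var, var)]
        else:
            # Inference: use the stored statistics, never those of the current batch
            mean, var = self.running_mean, self.running_var
        return [self._normalize(r, mean, var) for r in batch]

bn = BatchNorm1D(d=1)
for _ in range(200):
    bn.forward([[4.0], [6.0]], training=True)   # every batch: mean 5, variance 1
print(bn.running_mean, bn.running_var)          # close to [5.0] and [1.0]
print(bn.forward([[7.0]], training=False))      # close to [[2.0]], even for a batch of one
```

The `training` flag is the only switch between the two formulas. At inference, an example's output no longer depends on which other examples share its batch.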
Benefits of BatchNorm
Why It Works
- Stabilizes training: Reduces internal covariate shift (the original motivation for the method)
- Allows higher learning rates: Gradients are less prone to exploding or vanishing
- Reduces sensitivity to initialization: Less dependent on well-scaled starting weights
- Acts as a regularizer: Noise from per-batch statistics helps prevent overfitting
Key Insight: Later analysis suggests BatchNorm's main benefit is that it makes the optimization landscape smoother,
allowing larger step sizes and faster convergence.
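One concrete reason for the lower sensitivity to initialization: ignoring ε, normalizing over the batch cancels any positive rescaling of a layer's pre-activations. Multiplying the weights W by any a > 0 (with no bias term, as is usual before BatchNorm) therefore leaves BN(W·x) unchanged. The check below illustrates this for a single feature with made-up values (`pre_acts` is hypothetical). ε makes the invariance only approximate, and it breaks down once the scaled variance becomes comparable to ε.

```python
import math

def normalize(values, eps=1e-5):
    """Batch-normalize one feature across a batch (gamma = 1, beta = 0)."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return [(v - mean) / math.sqrt(var + eps) for v in values]

pre_acts = [0.3, -1.2, 2.0, 0.7]        # hypothetical values of W·x for one feature
reference = normalize(pre_acts)
for scale in (0.5, 10.0, 100.0):         # same weights, different overall scale
    out = normalize([scale * v for v in pre_acts])
    max_diff = max(abs(a - b) for a, b in zip(out, reference))
    print(f"scale={scale}: max difference from scale 1 = {max_diff:.1e}")
```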
Layer Normalization
For sequence models (Transformers, RNNs), we use Layer Normalization instead:
# Layer Normalization
μ = mean(x over features) # Mean across the feature dimension, separately for each example (token)
σ² = variance(x over features)
x̂ = (x - μ) / √(σ² + ε)
y = γ·x̂ + β
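A pure-Python sketch of the LayerNorm formula follows. It assumes each token is a plain list of D floats; `layer_norm` and the example sequence are illustrative, not a library API. The statistics run over the features of a single token, so no batch is needed at all.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Layer norm for one example, e.g. one token's D-dimensional hidden vector.

    Mean and variance are taken over this example's own D features,
    so the result never depends on other examples in the batch.
    """
    d = len(x)
    mean = sum(x) / d
    var = sum((v - mean) ** 2 for v in x) / d
    return [g * (v - mean) / math.sqrt(var + eps) + b for v, g, b in zip(x, gamma, beta)]

# A sequence of 3 tokens with D = 4; each token is normalized independently
sequence = [[2.0, 4.0, 6.0, 8.0], [1.0, 1.0, 1.0, 5.0], [-3.0, 0.0, 3.0, 0.0]]
gamma, beta = [1.0] * 4, [0.0] * 4
normed = [layer_norm(tok, gamma, beta) for tok in sequence]
for tok in normed:
    print([round(v, 3) for v in tok])
```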
BatchNorm vs LayerNorm
| Aspect | BatchNorm | LayerNorm |
|---|---|---|
| Normalizes across | Batch dimension | Feature dimension |
| Works with batch size 1? | Not reliably (batch statistics are degenerate or too noisy) | Yes |
| Used in | CNNs | Transformers, RNNs |
Transformers use LayerNorm because:
- Sequences have variable lengths, and padding tokens would distort batch statistics
- Training and inference behave the same (no running statistics to track)
- No dependency on batch size
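To make the contrast concrete, the sketch below normalizes the same example `x` inside two different batches. It uses pure Python with γ = 1 and β = 0, and the helper names are hypothetical. With BatchNorm, the output depends on its batchmates, and a batch of one collapses to zeros. With LayerNorm, the output is identical every time.

```python
import math

EPS = 1e-5

def batch_norm(batch):
    """Normalize each feature across the batch (training-mode statistics, gamma=1, beta=0)."""
    n, d = len(batch), len(batch[0])
    mean = [sum(r[j] for r in batch) / n for j in range(d)]
    var = [sum((r[j] - mean[j]) ** 2 for r in batch) / n for j in range(d)]
    return [[(r[j] - mean[j]) / math.sqrt(var[j] + EPS) for j in range(d)] for r in batch]

def layer_norm(batch):
    """Normalize each example across its own features (gamma=1, beta=0)."""
    out = []
    for r in batch:
        mean = sum(r) / len(r)
        var = sum((v - mean) ** 2 for v in r) / len(r)
        out.append([(v - mean) / math.sqrt(var + EPS) for v in r])
    return out

x = [1.0, 5.0, 3.0]                      # the example we follow
batch_a = [x, [0.0, 0.0, 0.0]]
batch_b = [x, [10.0, -4.0, 7.0]]

# BatchNorm: x's output changes with its batchmates (roughly [1, 1, 1] vs [-1, 1, -1])
print(batch_norm(batch_a)[0], batch_norm(batch_b)[0])
# LayerNorm: x's output is identical in both batches and in a batch of one
print(layer_norm(batch_a)[0] == layer_norm(batch_b)[0] == layer_norm([x])[0])  # True
# BatchNorm on a batch of one: every feature collapses to 0
print(batch_norm([x]))                   # [[0.0, 0.0, 0.0]]
```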
Pre-Norm vs Post-Norm
Where to put normalization in transformer blocks?
# Post-Norm (original Transformer): normalize after the residual add
x = LayerNorm(x + Attention(x))
x = LayerNorm(x + FFN(x))
# Pre-Norm (modern, more stable): normalize only the sublayer input
x = x + Attention(LayerNorm(x))
x = x + FFN(LayerNorm(x))
Pre-Norm is now standard: it makes training more stable, especially for very deep networks.
GPT-2, GPT-3, LLaMA, and most modern models use pre-norm.
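Below is a pure-Python sketch of both block layouts. It assumes `attention` and `ffn` are arbitrary functions that map a vector to a vector of the same size. The stand-in `zero_sublayer` is hypothetical and chosen only to expose the residual path. If the sublayers contribute nothing, the pre-norm block passes its input through unchanged, while the post-norm block still renormalizes it. In pre-norm the residual stream x is never itself normalized, which is one common explanation for its better stability in deep stacks.

```python
import math

def layer_norm(x, eps=1e-5):
    """Parameter-free layer norm over one vector (gamma=1, beta=0 for brevity)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def add(a, b):
    return [u + v for u, v in zip(a, b)]

def post_norm_block(x, attention, ffn):
    # Original Transformer: residual add first, then normalize
    x = layer_norm(add(x, attention(x)))
    x = layer_norm(add(x, ffn(x)))
    return x

def pre_norm_block(x, attention, ffn):
    # Pre-norm: normalize the sublayer input; the residual path is left untouched
    x = add(x, attention(layer_norm(x)))
    x = add(x, ffn(layer_norm(x)))
    return x

def zero_sublayer(v):
    """Hypothetical stand-in sublayer that always outputs zeros."""
    return [0.0] * len(v)

x = [1.0, 2.0, 3.0, 10.0]
print(pre_norm_block(x, zero_sublayer, zero_sublayer))   # [1.0, 2.0, 3.0, 10.0]
print(post_norm_block(x, zero_sublayer, zero_sublayer))  # renormalized: mean 0, variance close to 1
```

Because a pre-norm stack never normalizes the residual stream itself, pre-norm models typically add one final LayerNorm after the last block; GPT-2, for example, does this.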