CS336 Notes: Lecture 9 - Scaling Laws 1
Scaling laws turn model building from guesswork into engineering.
This lecture from Stanford CS336 covers how loss depends on data, parameters, and compute. The core finding: on log-log axes, loss often looks like a straight line. That pattern is predictive enough to plan billion-dollar training runs.
What Scaling Laws Are
Scaling laws predict how performance changes as you scale data, parameters, or total compute.
The key empirical pattern is simple. Plot test loss against dataset size (or parameters, or compute) on log-log axes. You often see a straight line. That suggests a power law:
loss ≈ constant + c · n^(−α)
Here n is dataset size and α > 0.
Why we care: Training a frontier model is expensive. Instead of guessing, we train many smaller models, fit a curve, and use it to choose architectures, hyperparameters, data mixtures, and the split between model size and training tokens before spending the full budget.
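A minimal sketch of that workflow, with made-up numbers: fit the three-parameter power-law form to a few cheap runs and extrapolate. The token counts and losses below are illustrative, not measurements from the lecture.

```python
import numpy as np
from scipy.optimize import curve_fit

def power_law(n, C, A, alpha):
    # Irreducible loss plus a power-law reducible term.
    return C + A * n ** (-alpha)

# Hypothetical (tokens, final loss) pairs from cheap runs.
tokens = np.array([1e8, 3e8, 1e9, 3e9, 1e10])
loss = np.array([4.10, 3.72, 3.41, 3.18, 3.01])

params, _ = curve_fit(power_law, tokens, loss, p0=[2.0, 10.0, 0.1], maxfev=10_000)
C, A, alpha = params
print(f"fit: C={C:.2f}, A={A:.2f}, alpha={alpha:.3f}")
print(f"predicted loss at 1e11 tokens: {power_law(1e11, *params):.2f}")
```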
Classical Roots
Scaling is an old idea. Classical learning theory studies how error depends on sample size. With VC dimension or Rademacher complexity, you get upper bounds like excess risk ~ 1/√n for n samples. In non-parametric density estimation with smoothness β, L2 error can scale like n^(−β/(2β+1)).
These are worst-case bounds. They're often loose.
Modern scaling work flips the approach. Instead of conservative bounds, we measure real loss curves for real networks, fit simple formulas, and use them as engineering tools.
Early empirical work already looked a lot like today's. A 1993 Bell Labs line of work argued that fully training on huge datasets is costly, so instead train smaller systems, fit a curve of "irreducible error + decaying term," and extrapolate. Banko and Brill (2001) showed smooth gains with more data in early NLP. Hestness et al. (2017) studied translation, speech, and vision and found power-law data scaling with three regions: near-random at tiny scale, a wide power-law region, and flattening toward irreducible error.
That picture still frames how people think about LLM scaling.
When Scaling Fails
Scaling is most reliable for training loss and held-out data from the same distribution.
But scaling can fail, or reverse, on out-of-distribution or adversarial tasks. The "Inverse Scaling Prize" collected tasks where larger models do worse, often because the task punishes behaviors that scaling strengthens (like copying).
The rule: scaling laws describe training-like regimes well. Far outside that regime, behavior gets messy.
Data Scaling Laws
A data scaling law maps dataset size n to reducible error:
excess error ≈ A · n^(−α)
On log-log axes, this is a straight line with slope −α.
Across many domains and models, loss is monotone and log-log linear over wide ranges.
Why the Slope is Small
A simple parametric example: estimate the mean of a Gaussian with variance σ² from n samples. Squared error scales like 1/n. Taking logs gives a straight line with slope −1.
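A quick simulation confirms the picture (numbers are illustrative): regress log squared error on log n and the slope comes out near −1.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma = 1.0
ns = np.array([10, 100, 1000, 10000])
errors = []
for n in ns:
    trials = rng.normal(0.0, sigma, size=(2000, n))   # 2000 repeats per sample size
    errors.append(np.mean(trials.mean(axis=1) ** 2))  # avg squared error of the sample mean

slope = np.polyfit(np.log(ns), np.log(errors), 1)[0]
print(f"estimated slope: {slope:.2f}")  # close to -1, i.e. error ~ sigma^2 / n
```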
Deep nets usually have much smaller exponents. Reported slopes:
- About −0.13 for machine translation
- About −0.3 for speech
- About −0.095 for language modeling
These imply far slower decay than the 1/n or 1/√n rates of simple parametric problems.
The intuition: neural nets behave like non-parametric estimators in a high-dimensional space. If inputs live in a D-dimensional region, even a simple "bin the space and average locally" estimator gives rates like n^(−1/D). When D is large, the slope is tiny.
Read the slope as difficulty: a slope closer to zero (smaller |α|) suggests a higher effective dimension.
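A back-of-the-envelope version of that reading, using the rough n^(−1/D) heuristic from above (a heuristic, not a fitted model):

```python
# Under the rough n**(-1/D) heuristic, the reported slopes imply an
# "effective dimension" D ≈ 1/alpha. Illustrative only.
slopes = {"translation": 0.13, "speech": 0.30, "language modeling": 0.095}
for task, alpha in slopes.items():
    print(f"{task}: alpha={alpha} -> effective dimension ≈ {1 / alpha:.0f}")
# translation ≈ 8, speech ≈ 3, language modeling ≈ 11
```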
Uses of Data Scaling Laws
Comparing data mixtures: Changing the data mix often shifts the curve up or down without changing the slope much. The offset tells you how good the mixture is at every scale. Compare mixtures using small runs, choose the better offset, expect the advantage to persist.
Multi-epoch training: If fresh data is limited, repeat data. Repetition helps with diminishing returns. After a few epochs, gains shrink sharply. Replace raw token count with "effective" sample size that grows more slowly as repetition increases.
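A minimal sketch of that idea, with a made-up saturating form and constants (not the lecture's fitted formula): repetition adds effective tokens, but each extra epoch adds less.

```python
import math

def effective_tokens(unique_tokens: float, epochs: float, decay: float = 15.0) -> float:
    """Diminishing-returns model of repeated data (illustrative form only):
    each extra pass over the data adds exponentially less new value.
    `decay` controls how quickly repetition stops helping."""
    repeats = max(epochs - 1.0, 0.0)
    return unique_tokens * (1.0 + decay * (1.0 - math.exp(-repeats / decay)))

# With 100B unique tokens, early repeats are nearly as good as fresh data,
# but heavy repetition falls far short of the raw token count seen.
for epochs in (1, 4, 40):
    eff = effective_tokens(100e9, epochs)
    print(f"{epochs:>2} epochs: {epochs * 100:>5.0f}B tokens seen -> {eff / 1e9:.0f}B effective")
```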
Repeat high-quality data or add lower-quality data: At trillion-token scale, you choose between repeating clean sources and adding noisier new sources. Each source mixture has its own scaling curve. Use those curves to plan repetition and growth.
Model Scaling
Once you understand data scaling, ask how performance changes with parameters and compute.
Comparing architectures: Kaplan-style comparisons train different architectures across compute levels. The transformer curve sits well below the LSTM curve, with a roughly constant efficiency gap: for the same loss, LSTMs need far more compute.
Google's architecture sweeps compared many variants to a transformer baseline. A few variants (GLU, mixture-of-experts) beat the baseline. Many others add overhead without improving scaling.
Optimizer choice: Scaling curves quantify optimizer differences. Across scales, Adam often sits below SGD, behaving like a near-constant-factor improvement in compute efficiency.
Depth versus width: Very shallow transformers do poorly. Once depth clears a small threshold, there's a broad band of depth/width ratios with similar performance. Curves share slopes and differ by small offsets. Tune aspect ratio on small models and trust it at larger ones.
Not all parameters are equal: Embedding parameters can bend parameter-loss plots. Counting only non-embedding parameters often restores straight lines. For mixture-of-experts, use "equivalent dense" count since not all parameters are active per token.
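A rough sketch of why the distinction matters, using the common 12 · layers · d_model² approximation for non-embedding parameters (counts here ignore biases, layer norms, and positional embeddings):

```python
def transformer_params(n_layers: int, d_model: int, vocab: int) -> dict:
    """Rough parameter counts for a standard decoder-only transformer
    (attention ≈ 4*d^2 per layer, MLP with 4x expansion ≈ 8*d^2 per layer)."""
    non_embedding = 12 * n_layers * d_model**2
    embedding = vocab * d_model  # tied input/output embedding
    return {"non_embedding": non_embedding, "embedding": embedding}

# For a small model, embeddings are a large fraction of the total,
# which is why they can bend parameter-vs-loss plots at the low end.
small = transformer_params(n_layers=12, d_model=768, vocab=50_000)
big = transformer_params(n_layers=48, d_model=6144, vocab=50_000)
for name, p in (("small", small), ("big", big)):
    frac = p["embedding"] / (p["embedding"] + p["non_embedding"])
    print(f"{name}: non-embedding={p['non_embedding']:.2e}, embedding fraction={frac:.0%}")
```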
Batch Size and Learning Rate
Critical batch size: Increasing batch size reduces gradient noise. At first, that improves efficiency. Past a critical batch size, extra samples mostly reduce noise that's already small. Progress per FLOP drops.
Two regimes: below the critical batch size, doubling the batch roughly halves the number of steps needed to reach a given loss. Above it, returns diminish quickly.
As you push to lower loss, critical batch size grows. Large runs commonly increase batch size as training progresses.
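A minimal sketch of the two regimes, assuming the hyperbolic steps-versus-examples tradeoff used in large-batch training analyses; the batch sizes, b_crit, and s_min below are hypothetical.

```python
def steps_and_examples(batch_size: float, b_crit: float, s_min: float):
    """Steps and examples needed to hit a target loss at a given batch size, under
    the tradeoff (S/S_min - 1)(E/E_min - 1) = 1, where b_crit = E_min / S_min.
    b_crit and s_min are assumed known for the target loss."""
    steps = s_min * (1.0 + b_crit / batch_size)
    examples = batch_size * steps
    return steps, examples

# Below b_crit, bigger batches cut steps almost 1:1;
# far above it, extra examples buy almost no step reduction.
b_crit, s_min = 1_000_000, 10_000  # hypothetical values (tokens per batch, steps)
for batch in (125_000, 250_000, 500_000, 1_000_000, 4_000_000):
    steps, examples = steps_and_examples(batch, b_crit, s_min)
    print(f"batch={batch:>9,}: steps={steps:>9,.0f}, tokens={examples:.2e}")
```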
Learning rate versus width: With standard parameterizations, wider models need smaller learning rates to stay stable. Optimal learning rate shrinks as width grows.
Scale-aware parameterizations (muP): Adjust initialization and layer scaling so the same learning rate works across widths. Reduces retuning, though transfer isn't perfect.
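A schematic sketch of the difference, assuming a muP-style rule where Adam learning rates for hidden matrices shrink like 1/width relative to a tuned base width (real muP also rescales initialization and output multipliers; the constants here are made up):

```python
def hidden_layer_lr(tuned_lr: float, base_width: int, width: int) -> float:
    # muP-style rule of thumb for Adam on hidden matrices: the learning rate
    # tuned at base_width is rescaled by base_width/width at larger widths,
    # instead of being re-swept from scratch at every size.
    return tuned_lr * base_width / width

for width in (256, 1024, 4096, 16384):
    print(f"width={width:>6}: hidden-layer lr ≈ {hidden_layer_lr(3e-3, 256, width):.1e}")
```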
Joint Data-Model Scaling
Under a fixed compute budget, data and model size trade off.
A tiny model with huge data can't absorb it. A huge model with tiny data wastes capacity and ends up undertrained.
A common joint form:
loss(n, N) ≈ C + A · n^(−α) + B · N^(−β)
n is tokens, N is parameters, C is irreducible error.
These surfaces fit real measurements well and predict larger runs from smaller ones.
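A minimal sketch of using such a surface for planning, with made-up constants (not fitted values from any paper) and the common compute ≈ 6 · N · n approximation:

```python
import numpy as np

def joint_loss(n, N, C=1.7, A=500.0, alpha=0.3, B=200.0, beta=0.3):
    # Fitted surface: irreducible error plus data and parameter terms.
    return C + A * n ** (-alpha) + B * N ** (-beta)

budget = 1e23                     # total training FLOPs
N_grid = np.logspace(9, 12, 400)  # candidate parameter counts
n_grid = budget / (6.0 * N_grid)  # tokens implied by the budget at each N
losses = joint_loss(n_grid, N_grid)

best = int(np.argmin(losses))
print(f"best split: N ≈ {N_grid[best]:.1e} params, n ≈ {n_grid[best]:.1e} tokens "
      f"(≈ {n_grid[best] / N_grid[best]:.0f} tokens per parameter)")
```

With real fitted constants, the same scan is how you pick the model size and token count to carry into the full run.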
Chinchilla: Compute-Optimal Training
Chinchilla revisited the compute-optimal tradeoff for dense LMs using three methods.
Method 1: Lower envelope. Train many runs, plot loss versus compute, take the best at each compute point. Read off N and n along that envelope. Fit how optimal N and n scale with FLOPs. Both scale roughly like FLOPs^(1/2), implying optimal n is proportional to optimal N. The constant works out to about 20 tokens per parameter.
Method 2: IsoFLOP sweeps. For each compute budget, sweep model size and adjust tokens to keep FLOPs fixed. Loss as a function of N is U-shaped. Fit the minimum to get optimal N and n. Matches Method 1.
Method 3: Surface fitting. Fit the joint surface and solve for compute-optimal point. The original paper's Method 3 had a fitting bug. Corrected, it lines up with Methods 1 and 2.
The Chinchilla rule holds: for compute-optimal dense LM training, use roughly 20 tokens per parameter.
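Turning the rule of thumb into numbers, using the standard compute ≈ 6 · N · n approximation (the 20:1 ratio is only a rule of thumb):

```python
import math

def chinchilla_optimal(flops: float, tokens_per_param: float = 20.0):
    # With compute ≈ 6 * N * n and n ≈ 20 * N, a budget C gives N ≈ sqrt(C / 120).
    n_params = math.sqrt(flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

for flops in (1e21, 1e23, 1e25):
    N, n = chinchilla_optimal(flops)
    print(f"{flops:.0e} FLOPs -> {N:.1e} params, {n:.1e} tokens")
```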
Training-Optimal vs Product-Optimal
Chinchilla optimizes training loss for a fixed training compute budget. That's training-optimal.
Products care about inference cost, which scales with parameters (and context length). If inference dominates, you prefer smaller models trained on more tokens per parameter, even if training costs more.
This explains the shift:
- GPT-3: undertrained by Chinchilla standards (about 2 tokens per parameter).
- Chinchilla: about 20 tokens per parameter.
- Newer systems: far higher ratios, trading extra training compute for cheaper inference at a given quality.
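A toy accounting of that tradeoff. The configurations are hypothetical and simply assumed to reach comparable quality; training cost ≈ 6 · N · tokens, inference cost ≈ 2 · N per generated token.

```python
def lifetime_flops(n_params: float, train_tokens: float, inference_tokens: float) -> float:
    # Total compute over the model's life: one training run plus all serving.
    return 6.0 * n_params * train_tokens + 2.0 * n_params * inference_tokens

configs = {
    "Chinchilla-style (70B params, 1.4T tokens)": (70e9, 1.4e12),
    "overtrained (30B params, 6T tokens)": (30e9, 6.0e12),
}
for served in (1e12, 1e14):  # total tokens served over the model's lifetime
    print(f"--- {served:.0e} inference tokens ---")
    for name, (N, n) in configs.items():
        print(f"{name}: {lifetime_flops(N, n, served):.2e} FLOPs total")
```

At low serving volume the training-optimal model wins; once inference dominates, the smaller, heavily trained model is cheaper overall.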
Limitations
Pretraining loss vs downstream tasks: Scaling laws are cleanest for cross-entropy and held-out perplexity. Downstream benchmarks can behave less smoothly. A system that scales well in perplexity may lag on capabilities.
Out-of-distribution behavior: On adversarial or OOD tasks, scaling can fail or reverse.
The Core Lesson
Data scaling explains how loss drops with more data and why the slope is small in high effective dimension.
Model scaling lets us compare architectures, optimizers, and depth/width choices at small scale and trust those comparisons at large scale.
Joint scaling, especially Chinchilla, tells us how to split a fixed compute budget between parameters and tokens.
Batch size and learning rate follow their own predictable patterns. Scale-aware parameterizations reduce retuning.
Together, these tools let us test cheaply, fit simple curves, and spend large budgets with fewer surprises.