CS336 Notes: Lecture 7 - Parallelism 1
You cannot train a frontier language model on one GPU. The model doesn't fit. The data doesn't fit. The compute takes too long.
This lecture from Stanford CS336 covers how to split work across many GPUs. The core constraint: three resources trade off against each other: compute (FLOPs), memory, and communication bandwidth. Batch size is a hidden fourth resource that some parallelism strategies consume.
Why Distributed Training Exists
GPUs keep getting faster, but model needs grow faster still.
Big models strain one GPU in two ways:
- Parameters: billions of weights don't fit in one device's memory.
- Activations: long sequences and deep networks create large intermediate tensors.
The new unit of compute is the data center, not one GPU. We want near-linear scaling: double the GPUs, roughly double the capacity and throughput.
Hardware Hierarchy
A typical GPU cluster has layers:
- Inside a node: multiple GPUs (often 8) connected by very fast links (NVLink/NVSwitch).
- Across nodes: network fabric (often InfiniBand), slower and higher-latency than NVLink.
- Past a few hundred GPUs: traffic crosses even slower switch tiers.
This hierarchy decides which parallel strategies work where. Chatty algorithms work inside a node where bandwidth is high. Across nodes, you must be more careful.
Five Collectives You Need to Know
These primitives show up everywhere in distributed training:
Allreduce: Each rank starts with data. Reduce (e.g., sum) across ranks. Every rank receives the same reduced result. Bandwidth cost: about 2x the data size.
Broadcast: One rank has the input. Copy to all ranks. Cost: about 1x the data size.
Reduce: Many ranks send data. Only one rank receives the reduced result.
Allgather: Each rank starts with one chunk. Every rank receives the concatenation of all chunks.
Reducescatter: Each rank starts with a chunk. Reduce across ranks. Split the reduced result so each rank keeps only its shard.
Key identity: allreduce = reducescatter + allgather. In bandwidth-limited regimes, they have the same optimal cost. This identity is the backbone of ZeRO and FSDP.
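A minimal sketch of the identity, assuming a toy single-process setup where each "rank" is just an entry in a Python list. The function names mirror the collectives above but are illustrative stand-ins, not a real communication library:

```python
# Toy single-process simulation: "ranks" are entries in a Python list.

def allreduce(per_rank):
    """Every rank ends with the elementwise sum of all ranks' vectors."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

def reducescatter(per_rank):
    """Rank r ends with only shard r of the elementwise sum."""
    world = len(per_rank)
    total = [sum(vals) for vals in zip(*per_rank)]
    shard = len(total) // world  # assumes length divisible by world size
    return [total[r * shard:(r + 1) * shard] for r in range(world)]

def allgather(per_rank):
    """Every rank ends with the concatenation of all ranks' chunks."""
    full = [x for chunk in per_rank for x in chunk]
    return [list(full) for _ in per_rank]

data = [[1, 2, 3, 4], [10, 20, 30, 40]]  # 2 ranks, 4 elements each
via_allreduce = allreduce(data)
via_two_step = allgather(reducescatter(data))
print(via_allreduce[0])               # [11, 22, 33, 44]
print(via_allreduce == via_two_step)  # True
```

The simulation only checks that the results match. It says nothing about bandwidth, which depends on how the collectives are implemented on real links.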
Three Types of Parallelism
Data parallelism: Copy the full model on each GPU. Split the batch across GPUs. Synchronize gradients.
Model parallelism: Split parameters across GPUs. Exchange activations instead of weights.
Activation parallelism: Split activations so no GPU must hold all intermediate tensors. Critical for long sequences and deep networks.
Data Parallelism: Simple but Memory-Heavy
Naive data parallel SGD works like this:
- Global batch size B, M GPUs.
- Each GPU processes B/M examples.
- Each GPU runs forward/backward, computes gradients.
- Allreduce gradients so every GPU has the full gradient.
- Each GPU applies the update locally.
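The steps above can be sketched on a toy problem. This assumes a one-parameter model (y = w * x) with a mean-squared-error loss, and simulates the GPUs as slices of a list. The helper names are hypothetical:

```python
def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def data_parallel_step(w, xs, ys, num_gpus, lr):
    per_gpu = len(xs) // num_gpus  # assumes the batch divides evenly
    # Each "GPU" computes a gradient on its B/M slice of the batch.
    local_grads = [
        grad_mse(w, xs[r * per_gpu:(r + 1) * per_gpu], ys[r * per_gpu:(r + 1) * per_gpu])
        for r in range(num_gpus)
    ]
    # Allreduce (sum), then divide by M so every replica sees the mean gradient.
    global_grad = sum(local_grads) / num_gpus
    # Every replica applies the identical update, so weights stay in sync.
    return w - lr * global_grad

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]
w_single = 0.5 - 0.01 * grad_mse(0.5, xs, ys)          # one GPU, full batch
w_dp = data_parallel_step(0.5, xs, ys, num_gpus=4, lr=0.01)
print(abs(w_single - w_dp) < 1e-12)  # True
```

With equal-sized slices, averaging the per-GPU gradients reproduces the full-batch gradient, which is why the replicas never drift apart.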
Why it's appealing: Compute scales well. Communication per step is about 2x the number of parameters. With enough batch, communication hides under compute.
The memory problem: Every GPU stores a full copy of parameters, gradients, and optimizer state.
Optimizer state dominates. With Adam-style training, per parameter you store: bf16 parameter (2 bytes), bf16 gradient (2 bytes), fp32 master weight (4 bytes), first moment (4 bytes), second moment (4 bytes). Total: about 16 bytes per parameter, not 2.
A 7.5B-parameter model needs about 7.5B × 16 bytes = 120 GB of parameter-related memory. In naive data parallel on 64 accelerators, every device holds its own copy of that 120 GB, so most of the memory in the cluster is redundant copies.
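The arithmetic behind these numbers, as a small sketch. The byte counts come from the breakdown above. GB is taken as 10^9 bytes, which is an assumption about the units:

```python
BYTES_PER_PARAM = {
    "bf16 parameter": 2,
    "bf16 gradient": 2,
    "fp32 master weight": 4,
    "fp32 first moment": 4,
    "fp32 second moment": 4,
}

def replicated_memory_gb(num_params):
    """Parameter-related memory on EACH GPU under naive data parallel (GB = 1e9 bytes)."""
    return num_params * sum(BYTES_PER_PARAM.values()) / 1e9

per_gpu = replicated_memory_gb(7.5e9)
print(sum(BYTES_PER_PARAM.values()))  # 16
print(per_gpu)                        # 120.0
print(per_gpu * 64)                   # 7680.0 GB held in total across 64 replicas
```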
Core question: Do we need the full optimizer state on every GPU? If not, we can shard it.
ZeRO Stage 1: Shard Optimizer State
ZeRO means "Zero Redundancy Optimizer."
Stage 1 idea: Parameters and gradients stay replicated. Optimizer state is sharded across GPUs.
How it runs:
- Each GPU computes gradients on its share of the batch.
- Use reducescatter on gradients so the summed gradient shard lands on the "owner" GPU.
- Each GPU updates only its shard (it has the needed optimizer state).
- Use allgather to rebuild full updated parameters on all GPUs.
Communication: reducescatter + allgather has the same bandwidth cost as one allreduce. Stage 1 adds almost no bandwidth cost versus naive data parallel.
Benefit: Optimizer state memory shrinks by the number of GPUs.
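A toy simulation of one Stage 1 step, under loud assumptions: SGD with momentum stands in for Adam (one state value per parameter instead of two), ranks are list entries, and gradients are summed rather than averaged. The function names are hypothetical:

```python
def momentum_update(params, grads, momentum, lr=0.1, beta=0.9):
    """SGD with momentum, a stand-in for Adam: one state value per parameter."""
    new_m = [beta * m + g for m, g in zip(momentum, grads)]
    new_p = [p - lr * m for p, m in zip(params, new_m)]
    return new_p, new_m

def zero1_step(params, local_grads, momentum_shards):
    world = len(local_grads)
    shard = len(params) // world  # assumes an even split
    new_shards, new_momentum = [], []
    for r in range(world):
        lo, hi = r * shard, (r + 1) * shard
        # Reducescatter: rank r receives only its shard of the summed gradient.
        grad_shard = [sum(g[i] for g in local_grads) for i in range(lo, hi)]
        # Rank r updates only the parameters it owns, using only its state shard.
        p, m = momentum_update(params[lo:hi], grad_shard, momentum_shards[r])
        new_shards.append(p)
        new_momentum.append(m)
    # Allgather: every rank rebuilds the full updated parameter vector.
    full = [x for s in new_shards for x in s]
    return full, new_momentum

params = [1.0, 2.0, 3.0, 4.0]
local_grads = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]  # one list per rank
sharded, state = zero1_step(params, local_grads, [[0.0, 0.0], [0.0, 0.0]])

# Reference: allreduce the gradient, then update with unsharded state.
summed = [a + b for a, b in zip(*local_grads)]
reference, _ = momentum_update(params, summed, [0.0] * 4)
print(sharded == reference)  # True
print([len(s) for s in state])  # [2, 2]: each rank keeps half the optimizer state
```

The sharded step produces the same parameters as the replicated one, while each rank stores only 1/world of the optimizer state.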
ZeRO Stage 2: Also Shard Gradients
Stage 2 keeps Stage 1 and also shards gradients.
During backward, as each layer's gradients are produced, immediately reducescatter them to the owners. Free non-owned gradient buffers right away.
After backward: each GPU holds only its gradient shard and optimizer-state shard. Parameters are still replicated. Updates happen on shards, then allgather restores full parameters.
Communication: still about 2x parameters per step, with more frequent smaller messages.
Benefit: Optimizer state and gradients both shrink by the number of GPUs.
ZeRO Stage 3 / FSDP: Shard Everything
Stage 3 shards parameters, gradients, and optimizer state. No GPU holds the full model at once.
Per-layer pattern:
Forward:
- Allgather that layer's parameter shards so each GPU can run the layer.
- Run forward on local data.
- Free the full parameters for that layer.
Backward:
- Allgather parameters again if needed.
- Compute gradients.
- Reducescatter gradients so each owner keeps only its shard.
- Owners update their shards (in practice this happens in the optimizer step, once backward has finished).
- Free full parameters and temporary gradients.
Communication per step: about 3x parameters (vs 2x for simpler data parallel).
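One way to see where the 3x comes from is to count transfers for the per-layer pattern. This is a sketch that assumes no prefetching, layers processed strictly one at a time, and hypothetical layer sizes:

```python
def fsdp_step_traffic(layer_sizes):
    """Count parameter-sized transfers for one step of the per-layer pattern."""
    traffic = {"allgather": 0, "reducescatter": 0}
    resident_full = 0  # unsharded parameters currently materialized
    peak_full = 0
    # Forward: gather a layer, run it, free it.
    for size in layer_sizes:
        traffic["allgather"] += size
        resident_full += size
        peak_full = max(peak_full, resident_full)
        resident_full -= size
    # Backward: gather again, compute gradients, reducescatter them, free.
    for size in reversed(layer_sizes):
        traffic["allgather"] += size
        resident_full += size
        peak_full = max(peak_full, resident_full)
        traffic["reducescatter"] += size
        resident_full -= size
    return traffic, peak_full

layers = [100, 200, 300, 400]  # hypothetical parameter counts per layer
traffic, peak = fsdp_step_traffic(layers)
total_params = sum(layers)
print(sum(traffic.values()) / total_params)  # 3.0
print(peak)                                  # 400
```

Two allgathers plus one reducescatter give three parameter-sized transfers, while only the largest single layer is ever fully materialized. With prefetching of layer k+1, the peak would be roughly two layers instead of one.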
Why it can still work well: Communication can overlap compute. While the GPU computes layer k, the network brings in parameters for layer k+1.
Result: Near-minimum memory per GPU. Extra communication that's often worth it to fit larger models.
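To compare the three stages side by side, here is a small calculator. It assumes the 16-bytes-per-parameter breakdown from the data parallelism section (2 for parameters, 2 for gradients, 12 for optimizer state including the fp32 master weight) and GB = 10^9 bytes. Stage 0 means naive data parallel:

```python
def zero_memory_gb(num_params, num_gpus, stage):
    """Per-GPU parameter-related memory in GB (1e9 bytes) for ZeRO stage 0-3."""
    params, grads, optim = 2.0, 2.0, 12.0  # bytes per parameter
    if stage >= 1:
        optim /= num_gpus   # Stage 1: shard optimizer state
    if stage >= 2:
        grads /= num_gpus   # Stage 2: also shard gradients
    if stage >= 3:
        params /= num_gpus  # Stage 3: also shard parameters
    return num_params * (params + grads + optim) / 1e9

for stage in range(4):
    print(stage, round(zero_memory_gb(7.5e9, 64, stage), 1))
# 0 120.0
# 1 31.4
# 2 16.6
# 3 1.9
```

For the 7.5B model on 64 GPUs, most of the saving already arrives at Stage 1, because optimizer state is 12 of the 16 bytes.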
Limits of Data Parallelism
Even with ZeRO/FSDP, data parallelism hits walls:
- Batch size is finite. You can't have more replicas than the global batch size.
- Past a critical batch size, bigger batch gives diminishing returns.
- Activations don't automatically shrink. Stage 1 and 2 do nothing for activation memory. Stage 3 shards parameters, gradients, and optimizer state, but activations can still dominate for long sequences.
When batch can't grow and activations are large, you need model parallelism.
Pipeline Parallelism: Split by Depth
Split the network into stages. GPU 0 gets early layers, GPU 1 gets middle layers, GPU 2 gets later layers.
Forward: Activations flow stage to stage. Backward: Gradients flow back from last stage to first.
Naive problem: bubbles. With one microbatch, most GPUs sit idle. Utilization is roughly 1 / number of stages.
Microbatching fixes this. Split the batch into M microbatches (here M counts microbatches, not GPUs as in the data parallel section). As GPU 0 sends microbatch 0 forward, it starts microbatch 1. The pipeline fills, keeping more GPUs busy.
Bubble fraction scales like (S - 1) / M, where S is the number of stages. To keep bubbles small, M must be much larger than S.
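A quick calculator for the bubble formula. The utilization expression M / (M + S - 1) is an added assumption: it holds for a simple fill-and-drain schedule where every stage takes the same time per microbatch:

```python
def bubble_fraction(stages, microbatches):
    """Idle time relative to useful compute time: (S - 1) / M."""
    return (stages - 1) / microbatches

def utilization(stages, microbatches):
    """Share of the step each stage spends computing: M / (M + S - 1)."""
    return microbatches / (microbatches + stages - 1)

print(utilization(4, 1))       # 0.25, i.e. 1 / number of stages
print(bubble_fraction(4, 4))   # 0.75
print(bubble_fraction(4, 32))  # 0.09375
```

With one microbatch, utilization collapses to 1 / S, matching the naive case. With 32 microbatches on 4 stages, the bubble drops below 10% of compute time.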
Tradeoffs:
- Pros: Strong memory scaling. Each GPU stores only its layers. Communication is point-to-point activations, which works over slower links.
- Cons: Consumes batch size through microbatches. Scheduling is tricky. Efficient zero-bubble schedules are hard to build and maintain.
Tensor Parallelism: Split by Width
Most transformer compute is in matmuls. Tensor parallel splits those matmuls across GPUs.
Shard large weight matrices. Each GPU computes a partial result. Use collectives (allreduce/allgather/reducescatter) to assemble the right activations.
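A sketch of the two standard ways to shard a matmul, simulated in plain Python with two "GPUs". The matrices are made-up examples. List concatenation and elementwise addition stand in for allgather and allreduce:

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

X = [[1, 2, 3, 4], [5, 6, 7, 8]]  # activations: 2 tokens x 4 features
W = [[1, 0, 2, 1], [0, 1, 1, 2], [3, 1, 0, 0], [1, 2, 1, 1]]  # 4 x 4 weight
full = matmul(X, W)

# Column parallel: each "GPU" owns half of W's columns and produces half of
# the output features; concatenating along features plays the role of allgather.
W_cols = [[row[:2] for row in W], [row[2:] for row in W]]
parts = [matmul(X, w) for w in W_cols]
col_parallel = [p0 + p1 for p0, p1 in zip(*parts)]

# Row parallel: each "GPU" owns half of W's rows and the matching slice of X's
# features; the partial products are summed, which plays the role of allreduce.
X_cols = [[row[:2] for row in X], [row[2:] for row in X]]
W_rows = [W[:2], W[2:]]
partials = [matmul(x, w) for x, w in zip(X_cols, W_rows)]
row_parallel = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(*partials)]

print(col_parallel == full, row_parallel == full)  # True True
```

Both shardings reproduce the unsharded product. They differ in which collective is needed and where it sits, which is why every tensor-parallel layer carries a communication point.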
Placement matters. Tensor parallel creates communication points every layer. It needs very high bandwidth. It usually stays inside a fast node (e.g., 8 GPUs on NVLink/NVSwitch).
Empirical pattern: scaling TP from 1 to 8 inside a node costs a modest 10-12% of throughput. TP 16 or 32, which has to cross nodes, can cost 40-65% because of the slower links.
Pros: Doesn't consume batch size. Matches matmul-heavy workloads well. Works well on fast intra-node links.
Cons: Bandwidth hungry. Sensitive to hardware topology.
Activation Memory and Sequence Parallelism
Activation memory grows during forward and shrinks during backward. Peak memory often occurs mid-backward.
Even with tensor or pipeline parallel, some activation terms remain large. A rough per-layer view: one term scales like sequence × batch × hidden (MLP and pointwise ops), another scales like sequence² × batch (attention).
Tensor parallel reduces many terms, but not all. Pointwise ops like LayerNorm and Dropout can stay effectively unsharded in naive TP.
Sequence parallelism fixes this. Shard along the sequence dimension for pointwise work. Each GPU handles about sequence/TP tokens. When a full sequence view is needed, use allgather/reducescatter.
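Why this is safe for pointwise ops: they act on each token independently, so splitting the sequence changes nothing about the result. A toy check with a LayerNorm that has no learned scale or bias (a simplification), with ranks simulated as list slices:

```python
import math

def layernorm(token, eps=1e-5):
    """Normalize one token's feature vector (no learned scale or bias)."""
    mean = sum(token) / len(token)
    var = sum((x - mean) ** 2 for x in token) / len(token)
    return [(x - mean) / math.sqrt(var + eps) for x in token]

sequence = [[1.0, 2.0, 3.0], [4.0, 6.0, 8.0], [0.0, 1.0, 5.0], [2.0, 2.0, 9.0]]
full = [layernorm(tok) for tok in sequence]

tp = 2
per_rank = len(sequence) // tp  # assumes the sequence divides evenly
# Each rank normalizes only its sequence/TP tokens ...
shards = [
    [layernorm(tok) for tok in sequence[r * per_rank:(r + 1) * per_rank]]
    for r in range(tp)
]
# ... and an allgather along the sequence dimension restores the full view.
gathered = [tok for shard in shards for tok in shard]
print(gathered == full)  # True
```

Each rank only ever holds activations for its own tokens during the pointwise work, which is where the memory saving comes from.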
Combined with recomputation: Flash-style attention avoids storing full attention matrices. With tensor + sequence parallel + recomputation, activation memory per device drops significantly.
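To put rough numbers on this, the sketch below uses a per-layer activation formula that is not in these notes: the coefficients 34, 10, 24, and 5 are taken from the Megatron-LM analysis of activation recomputation (16-bit activations, s = sequence length, b = microbatch size, h = hidden size, a = attention heads, t = tensor-parallel degree). Treat them as an assumption. The example shape is GPT-3-like and purely illustrative:

```python
def activation_bytes_per_layer(s, b, h, a, t=1, seq_parallel=False, recompute_attention=False):
    """Approximate activation bytes for one transformer layer (assumed coefficients)."""
    # Attention-matrix term: 5*a*s/h times s*b*h, i.e. 5 * a * s^2 * b overall.
    attention_term = 0.0 if recompute_attention else 5.0 * a * s / h
    if seq_parallel:
        # Tensor + sequence parallel: every term is divided by t.
        return s * b * h * (34.0 + attention_term) / t
    # Tensor parallel only: 10 of the 34 units (pointwise ops) stay unsharded.
    return s * b * h * (10.0 + 24.0 / t + attention_term / t)

shape = dict(s=2048, b=1, h=12288, a=96)  # hypothetical GPT-3-like layer
baseline = activation_bytes_per_layer(**shape)
tp_only = activation_bytes_per_layer(**shape, t=8)
tp_sp = activation_bytes_per_layer(**shape, t=8, seq_parallel=True)
tp_sp_rc = activation_bytes_per_layer(**shape, t=8, seq_parallel=True, recompute_attention=True)

for name, value in [("baseline", baseline), ("TP=8", tp_only),
                    ("TP=8 + SP", tp_sp), ("TP=8 + SP + recompute", tp_sp_rc)]:
    print(f"{name}: {value / 1e9:.3f} GB")
```

Under these assumptions, the unsharded pointwise term (10 × s × b × h) is what tensor parallel alone leaves behind. Sequence parallelism removes it, and recomputation removes the sequence² term.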
Rule of Thumb for Mixing Parallelism
1. Make the model fit.
Use tensor parallel inside each node (often up to the number of GPUs per node, e.g. TP = 8). If it still doesn't fit, add FSDP/Stage 3 across nodes and/or pipeline parallel.
2. Scale throughput with data parallelism.
Once memory fits, use more data-parallel replicas to increase total FLOPs. It works over slower links and is easiest to reason about.
3. Treat batch size as a hard budget.
Pipeline and data parallel both consume batch size (microbatches and replicas). If memory limits the per-step batch but you want fewer synchronizations, use gradient accumulation: run several microbatches locally, then synchronize gradients once.
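A sketch of the batch-budget bookkeeping. It assumes the common convention that world size = TP × PP × DP and that each data-parallel replica splits its share of the global batch into microbatches. The function and the example numbers are hypothetical:

```python
def plan_layout(world_size, tp, pp, global_batch, micro_batch_size):
    """Derive data-parallel degree and microbatches per replica from a GPU and batch budget."""
    if world_size % (tp * pp):
        raise ValueError("world_size must be divisible by tp * pp")
    dp = world_size // (tp * pp)
    if global_batch % (dp * micro_batch_size):
        raise ValueError("global_batch must be divisible by dp * micro_batch_size")
    microbatches = global_batch // (dp * micro_batch_size)
    return {
        "dp": dp,
        "microbatches_per_replica": microbatches,
        "bubble_fraction": (pp - 1) / microbatches,
    }

print(plan_layout(world_size=512, tp=8, pp=4, global_batch=1024, micro_batch_size=2))
# {'dp': 16, 'microbatches_per_replica': 32, 'bubble_fraction': 0.09375}
```

Doubling the data-parallel degree at a fixed global batch halves the microbatches per replica and doubles the pipeline bubble. That is the sense in which both strategies draw on the same budget.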
Real-World Examples
Megatron-style training: Mix data, tensor, and pipeline parallelism. As models grow from 1B toward 1T parameters, tensor parallel caps around 8, pipeline parallel increases with depth, data parallel adjusts to balance batch and compute. Achieves roughly 40-50% of theoretical GPU FLOPs.
DeepSeek: Combines ZeRO Stage 1, tensor parallel, sequence parallel, and pipeline parallel. Newer versions add expert parallelism for MoE.
Llama 3: 8-way tensor parallel inside nodes. Pipeline and data parallel across racks. Context parallel for long-context phases.
Smaller runs (e.g., OLMo 7B): Often use simpler setups like FSDP without heavy model parallelism, because the model fits more easily.
In real clusters, GPU failures happen. Fault-tolerant infrastructure matters.
The Core Lesson
Frontier-scale training is a distributed systems problem.
There's no single best parallel strategy. You mix data, model, and activation parallelism to match memory limits, network topology, and batch-size budget.
The allreduce vs reducescatter + allgather equivalence is key for designing sharded optimizers.
ZeRO and FSDP remove redundant memory copies without exploding bandwidth.
Pipeline and tensor parallel split the model by layers and by matmul width.
Sequence parallelism and recomputation attack the remaining activation bottleneck.
Think in data centers, not GPUs.