CS336 Notes: Lecture 4 - Mixture of Experts
If you want the best model for a fixed FLOP budget, you probably want an MoE design. This lecture explains why, and what makes it hard.
The core idea: replace each dense MLP with many smaller MLPs called experts, plus a router that picks a few experts per token. You get more parameters without raising per-step FLOPs. More capacity to store knowledge, same compute per forward pass.
Ten Claims About MoE
MoE beats dense at fixed training FLOPs. Across many papers and scales, well-designed MoE models reach lower loss and better benchmarks than dense models at the same compute.
MoE adds parameters without adding proportional compute. If each expert matches the dense MLP size and you route to one expert per token, FLOPs per forward pass match the dense model. Total parameters increase.
The hard parts are systems and optimization, not the forward pass. Routing is discrete. Load can skew. Training can get unstable without guardrails.
Most strong MoE LLMs use token-choice top-k routing. Each token scores experts and picks the top k. This dominates in practice.
Fine-grained experts matter more than shared experts. Splitting the MLP into many small pieces increases capacity. Shared experts can help but aren't consistently necessary.
Routers are simple because they have to be. A linear layer plus softmax plus top-k. The learning signal is weak, so extra router complexity usually doesn't help.
Load balancing is essential. Without it, training collapses to a few experts. The rest never learn.
MoE creates expert parallelism. Experts are sharded across devices. Tokens move to the devices hosting their experts, then move back. This forces all-to-all communication.
Token dropping makes inference non-deterministic. If an expert is overloaded, some systems drop the overflow. Another request in the same batch can change which tokens get dropped.
Stability tricks include float32 router softmax, z-loss on router logits, and large supervised fine-tuning sets.
Why MoE Matters
MoE powers many frontier systems: GPT-like systems (as hinted by leaks), Grok models, DeepSeek, and Llama 4. By 2025, MoE's advantage over dense models at fixed training FLOPs is clear in practice.
What MoE Actually Is
"Mixture of experts" is a misleading name. It doesn't mean a hand-designed "coding expert" or "English expert." MoE is an architectural pattern inside the network, usually inside the MLP blocks.
A standard dense transformer layer has self-attention and a dense MLP.
A sparse MoE layer keeps attention the same and changes only the MLP:
- Many small MLPs, called experts.
- A router that looks at a token's hidden state and picks a small number of experts (top-k).
- Only those experts run for that token.
- Their outputs are combined (often a weighted sum) and added back into the residual stream.
One big MLP becomes a router plus many small MLPs.
FLOPs, Parameters, and Why MoE Helps
If each expert matches the dense MLP size and you route to one expert per token (k = 1):
- FLOPs per forward pass match the dense model.
- Total parameters increase because you have many MLP copies.
If extra parameters help store facts and patterns, MoE is appealing: more parameters for almost the same FLOPs per token.
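A rough back-of-the-envelope comparison makes the trade-off concrete. The sizes below (hidden size 4096, 4x expansion, 64 experts, k = 1) are illustrative assumptions, not any particular model's configuration.

```python
# Rough parameter/FLOP comparison: dense MLP vs. MoE with k = 1.
# All sizes are illustrative assumptions, not a specific model.
d_model = 4096
d_ff = 4 * d_model          # dense MLP expansion width
num_experts = 64            # each expert the same size as the dense MLP
k = 1                       # experts active per token

# A simple two-matrix MLP: up-projection + down-projection.
dense_mlp_params = 2 * d_model * d_ff

# MoE: every expert is a full copy of the dense MLP.
moe_total_params = num_experts * dense_mlp_params
moe_active_params = k * dense_mlp_params     # what one token actually touches

# FLOPs per token are roughly 2 * active params (one multiply-accumulate each).
dense_flops = 2 * dense_mlp_params
moe_flops = 2 * moe_active_params

print(f"dense MLP params {dense_mlp_params/1e6:.0f}M, MoE total {moe_total_params/1e9:.1f}B")
print(f"FLOPs per token: dense {dense_flops/1e6:.0f}M vs MoE {moe_flops/1e6:.0f}M")
# 64x the MLP parameters at the same per-token FLOPs (router cost ignored).
```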
Classic results (Fedus et al. 2022) show that adding more experts lowers loss at fixed FLOPs and improves downstream performance. Newer work like AI2's OLMo-MoE reports the same pattern at modern scales.
The practical story: if you can manage the complexity, MoE is cost-effective.
Why MoE Isn't Universal
MoE looks simple on slides and messy in production.
Infrastructure complexity. The biggest gains show up at large scale, where multi-node training and complex parallelism are already required. Before that, the added complexity may not be worth it.
Discrete routing. Gradient descent prefers smooth choices. Top-k expert selection is not differentiable, so naive training can be unstable and can leave many experts unused.
Systems issues. Tokens must move to the devices that host their experts, then come back. Experts must stay balanced to avoid bottlenecks.
Where MoE Lives
In most modern MoE LLMs, MoE is applied to MLP blocks. Attention stays dense.
MoE attention is possible. Some work tried sparsely routed attention. It tends to be harder to train and more unstable. Most large systems avoid it.
Routing Design
Three styles:
Token choice: each token scores experts and picks top-k. This dominates in practice.
Expert choice: each expert scores tokens and picks its top-k tokens. This balances load by construction but is less intuitive from the token's perspective of "which expert fits this token."
Global assignment: solve a global matching problem that balances load. Elegant and usually too expensive at scale.
Modern high-performance MoE models mostly use token-choice top-k routing: Google MoE variants, OLMo-MoE, DeepSeek, Qwen, Mixtral, Llama MoE.
How Top-k Token Routing Works
For each token you have a hidden state x (the residual input to the MLP).
The router has parameters e_i, one vector per expert. These are separate from the expert MLP weights.
A typical router does:
- Score experts: score_i = x · e_i.
- Normalize: softmax over experts to get s_i.
- Select: choose the top-k experts by s_i. Gates g_i equal s_i for selected experts, zero otherwise.
- Compute: run the selected experts' MLPs on x, multiply outputs by g_i, sum them, and add back to the residual stream.
If k = 1, you pick one expert. If k = 2, you get redundancy and exploration but roughly double the MLP compute.
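A minimal PyTorch sketch of token-choice top-k routing, assuming simple two-matrix experts and ignoring load balancing, capacity limits, and kernel efficiency; the sizes and names (`num_experts`, `top_k`) are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TopKMoELayer(nn.Module):
    """Token-choice top-k MoE MLP: score, softmax, select, combine."""
    def __init__(self, d_model=512, d_expert=1024, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        # Router: one vector e_i per expert (a plain linear layer, no bias).
        self.router = nn.Linear(d_model, num_experts, bias=False)
        # Each expert is a small two-matrix MLP.
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_expert), nn.GELU(),
                          nn.Linear(d_expert, d_model))
            for _ in range(num_experts)
        )

    def forward(self, x):                      # x: (num_tokens, d_model)
        scores = self.router(x)                # score_i = x . e_i
        probs = F.softmax(scores, dim=-1)      # s_i (often computed in float32)
        gates, idx = probs.topk(self.top_k, dim=-1)   # g_i and chosen experts
        out = torch.zeros_like(x)
        for slot in range(self.top_k):
            for e, expert in enumerate(self.experts):
                mask = idx[:, slot] == e       # tokens whose slot-th pick is expert e
                if mask.any():
                    out[mask] += gates[mask, slot, None] * expert(x[mask])
        return out                             # added back to the residual stream

x = torch.randn(16, 512)
y = TopKMoELayer()(x)
print(y.shape)   # torch.Size([16, 512])
```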
Choice of k
k is a hyperparameter.
Earlier work argued for k = 2 so the router can compare experts and avoid locking in too early. Larger k increases FLOPs and communication. With fine-grained experts, each expert is smaller, so increasing k can still be compute-feasible.
Why Routers Are Simple
Routers are usually just a linear layer plus softmax and top-k.
Reasons:
- Router compute must stay small or it eats the savings.
- The learning signal is weak and indirect, so extra router complexity often doesn't help.
- Simple routers are easier to stabilize and debug.
Routing Experiments That Didn't Win
Hash routing: routing tokens by hash, ignoring semantics, can still beat dense. Even crude partitioning across many MLPs helps.
RL routing: tried early because routing is discrete. Expensive and not better than simpler methods. Faded out.
Stochastic routing: add noise to router logits to encourage exploration. Helped less than balancing heuristics. Mostly dropped.
Fine-Grained and Shared Experts
A baseline approach is to copy the whole MLP into multiple experts. But that grows parameters quickly.
Fine-grained experts split the MLP expansion dimension (often 4× hidden size) into many narrower pieces. Each expert is smaller, so you can have many more experts without blowing up compute.
Shared experts are one or a few experts that run for every token. They can capture patterns useful everywhere.
DeepSeek popularized using many fine-grained experts plus shared experts. Ablations in DeepSeek and OLMo-MoE show that increasing fine-grained experts improves loss and benchmarks, while shared experts help in some settings but aren't consistently necessary.
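A quick illustrative calculation (with made-up sizes) of how fine-grained experts grow total capacity while the active compute per token stays fixed:

```python
# Fine-grained experts: split the 4x MLP into narrower slices.
# All sizes here are illustrative assumptions, not any specific model.
d_model = 4096
d_ff_dense = 4 * d_model               # the dense MLP's expansion width

granularity = 8                        # split each dense-sized expert into 8 slices
d_ff_expert = d_ff_dense // granularity
num_experts = 64                       # many small experts
k = 8                                  # active experts per token

expert_params = 2 * d_model * d_ff_expert
active_params = k * expert_params      # matches one dense MLP when k == granularity
total_params = num_experts * expert_params

print(f"active params per token: {active_params/1e6:.0f}M "
      f"(dense MLP: {2*d_model*d_ff_dense/1e6:.0f}M)")
print(f"total MoE params per layer: {total_params/1e9:.1f}B")
# Same active compute as the dense MLP, 8x the stored parameters, and far more
# possible routing combinations than 8 dense-sized experts would give.
```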
Common Modern Configurations
Early Google systems (GShard, Switch, ST-MoE) used 8-16 experts per layer with 1-2 active.
DeepSeek and other labs pushed to many more experts per layer: dozens of fine-grained experts, often with at least one shared expert and multiple active experts per token.
The pattern: slice the MLP into many small experts to increase capacity while controlling compute.
Training Challenges
You can't activate all experts per token without blowing up FLOPs. Routing must stay sparse.
Top-k selection is not differentiable. You can't backprop through the choice.
Approaches explored: RL, stochastic exploration, heuristic load balancing losses.
In practice, the standard recipe wins: token-choice top-k routing plus explicit balancing.
Auxiliary Load Balancing Loss
Without balancing, training collapses to a few experts. Most tokens route to them, they become good at everything, and the rest never learn.
The goal is to spread traffic so experts learn and no bottleneck forms.
A common balancing loss (Switch Transformer style) uses:
- f(i): fraction of tokens sent to expert i after top-k.
- p(i): fraction of router probability mass for expert i before top-k.
The auxiliary loss is a dot product of f and p across experts. It discourages experts that already receive many tokens from also having high router probability. Variants apply this per expert and per device.
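A minimal sketch of this balancing loss, assuming top-1 routing for simplicity; the coefficient `alpha` and the multiplication by the number of experts follow the Switch Transformer formulation.

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, num_experts, alpha=0.01):
    """Switch-style auxiliary loss: alpha * N * sum_i f_i * p_i (top-1 case)."""
    probs = F.softmax(router_logits, dim=-1)           # (num_tokens, num_experts)
    top1 = probs.argmax(dim=-1)                        # expert chosen per token
    # f(i): fraction of tokens actually routed to expert i (after top-k).
    f = F.one_hot(top1, num_experts).float().mean(dim=0)
    # p(i): average router probability mass on expert i (before top-k).
    p = probs.mean(dim=0)
    # Minimized when both are uniform at 1/N; gradients flow through p only.
    return alpha * num_experts * torch.sum(f * p)

logits = torch.randn(1024, 8)
print(load_balancing_loss(logits, num_experts=8))
```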
OLMo-MoE ablations show that without balancing, a couple of experts can take half the tokens while others go unused, hurting validation loss and wasting parameters.
DeepSeek v3's Bias-Based Balancing
DeepSeek v3 proposes balancing without an explicit auxiliary loss.
It keeps a bias b_i per expert. After each batch, it measures token counts per expert:
- If an expert is underused, increase b_i.
- If overused, decrease b_i.
During routing, add b_i to the expert's affinity score when selecting the top-k, but don't include b_i in the final gating weights. This nudges tokens toward underused experts.
DeepSeek v3 still adds a small sequence-wise balancing loss so individual sequences stay balanced (which matters at inference), not just whole training batches.
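A rough sketch of the bias-update idea, assuming a fixed update step `gamma` and sigmoid affinities; the exact score form and update schedule differ in the actual DeepSeek v3 recipe.

```python
import torch

num_experts, top_k, gamma = 8, 2, 0.001
bias = torch.zeros(num_experts)           # b_i, one scalar per expert

def route(affinity):                      # affinity: (num_tokens, num_experts)
    # The bias only influences which experts are selected...
    _, idx = (affinity + bias).topk(top_k, dim=-1)
    # ...while the gating weights use the unbiased affinities.
    gates = torch.gather(affinity, -1, idx)
    return idx, gates

def update_bias(idx):
    # Count how many tokens each expert received in this batch.
    counts = torch.bincount(idx.flatten(), minlength=num_experts).float()
    target = counts.mean()
    # Nudge underused experts up, overused experts down.
    bias.add_(gamma * torch.sign(target - counts))

affinity = torch.sigmoid(torch.randn(1024, num_experts))
idx, gates = route(affinity)
update_bias(idx)
print(bias)
```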
Expert Parallelism
MoE adds expert parallelism.
Each device holds one or more experts. After routing, tokens are sent to the devices that host their chosen experts, experts run, and outputs are sent back for combination.
This requires all-to-all communication and motivates per-expert capacity limits. Communication cost is the main trade-off against the compute and memory gains.
Modern kernels and libraries can fuse many small expert matmuls into larger sparse operations so small experts still run efficiently.
Token Dropping and Randomness
Experts and devices have capacity limits per batch. If an expert is overloaded, some systems drop the overflow tokens for that expert, effectively skipping that expert for those tokens.
This makes inference non-deterministic even at temperature zero because dropping depends on what else is in the batch. Another request in the same batch can change which tokens get dropped.
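A toy sketch of capacity-based dropping, assuming a simple `capacity_factor` formulation; real systems do this inside fused dispatch kernels.

```python
import torch

def drop_overflow(expert_idx, num_experts, capacity_factor=1.25):
    """Return a keep-mask: tokens beyond an expert's capacity are dropped."""
    num_tokens = expert_idx.numel()
    capacity = int(capacity_factor * num_tokens / num_experts)
    keep = torch.zeros(num_tokens, dtype=torch.bool)
    for e in range(num_experts):
        slots = (expert_idx == e).nonzero(as_tuple=True)[0]
        keep[slots[:capacity]] = True   # first `capacity` tokens fit, the rest are dropped
    return keep

expert_idx = torch.randint(0, 8, (1024,))      # top-1 assignments for 1024 tokens
keep = drop_overflow(expert_idx, num_experts=8)
print(f"dropped {(~keep).sum().item()} of {keep.numel()} tokens")
# Because capacity depends on the whole batch, the same token can be kept or
# dropped depending on which other requests share its batch.
```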
Stability, Fine-Tuning, and Overfitting
MoE training and fine-tuning can be unstable.
Softmax instability: router softmax can be numerically touchy, so systems often compute it in float32.
Z-loss: adding a penalty on the log-softmax normalizer keeps router logits from blowing up and reduces loss spikes.
Fine-tuning overfitting: MoE's extra parameters can overfit small fine-tuning sets. Mitigations include alternating dense and MoE layers, fine-tuning only dense layers, and using large supervised fine-tuning sets. DeepSeek uses very large SFT data to reduce overfitting.
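A minimal version of the router z-loss mentioned above; the coefficient is an assumed placeholder.

```python
import torch

def router_z_loss(router_logits, coeff=1e-3):
    """Penalize the squared log-normalizer so router logits stay small."""
    z = torch.logsumexp(router_logits, dim=-1)    # log of the softmax denominator
    return coeff * (z ** 2).mean()

print(router_z_loss(torch.randn(1024, 8)))
```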
Upcycling: Dense to MoE
Upcycling is a cheap path to MoE:
- Start with a trained dense model.
- Copy each MLP into multiple experts (optionally perturb weights).
- Add a router.
- Train further as MoE.
You keep the dense model's knowledge and get a much larger total-parameter model at similar active compute, with limited extra training, provided training stays stable. MiniCPM and Qwen report strong results from this approach.
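A sketch of the weight-copying step; `upcycle_mlp` is a hypothetical helper and the sizes are illustrative.

```python
import copy
import torch
import torch.nn as nn

def upcycle_mlp(dense_mlp: nn.Sequential, d_model: int, num_experts: int = 8,
                noise_std: float = 0.01):
    """Turn a trained dense MLP into MoE components by copying it into each expert."""
    router = nn.Linear(d_model, num_experts, bias=False)   # fresh, randomly initialized
    experts = nn.ModuleList()
    for _ in range(num_experts):
        expert = copy.deepcopy(dense_mlp)                   # start from the dense weights
        with torch.no_grad():
            for p in expert.parameters():                   # optional small perturbation
                p.add_(noise_std * torch.randn_like(p))     # so experts can diverge
        experts.append(expert)
    return router, experts

dense = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))
router, experts = upcycle_mlp(dense, d_model=512)
print(len(experts), router.weight.shape)
```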
DeepSeek Architecture Evolution
DeepSeek MoE v1
- ~16B total parameters, ~2.8B active.
- Per layer: 64 fine-grained experts + 2 shared experts, several active per token.
- Routing: token-choice top-k with softmax before top-k and weighted sum of expert outputs.
- Balancing: auxiliary load balancing per expert and per device.
DeepSeek v2
- ~236B total, ~21B active.
- Core MoE stays similar.
- Adds device top-m routing: pick a small set of devices per token, then pick experts within those devices, cutting communication.
- Adds communication balancing losses.
DeepSeek v3
- ~671B total, ~37B active.
- Core MoE still similar.
- Routing tweaks: compute affinities with a sigmoid instead of a softmax, normalize the selected gates so they sum to one, and keep device top-m routing for efficiency.
- Balancing: bias-based online per-expert balancing plus a sequence-wise auxiliary loss.
The pattern: v1 to v3 is mostly routing, balancing, and systems tuning, not new layer types.
MLA in DeepSeek v3
MLA reduces KV cache memory without increasing FLOPs.
Standard attention caches full keys K and values V per token.
MLA compresses each token's hidden state h_t into a smaller latent c_t:
c_t = W_c h_t
Cache c_t. When needed, up-project:
K_t = W_k_up c_t
V_t = W_v_up c_t
Naively this adds extra matmuls, but W_k_up can be absorbed into the query projection (and the value up-projection into the output projection), so total matmuls stay about the same.
RoPE complicates this merging because it rotates Q and K between projections. MLA is designed so RoPE still works, typically by applying it only to a small decoupled set of dimensions that aren't compressed.
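A back-of-the-envelope cache comparison using made-up sizes (layer count, head count, and latent width are assumptions, not DeepSeek's exact numbers):

```python
# Per-token KV cache: standard multi-head attention vs. an MLA-style latent cache.
# All sizes are illustrative assumptions.
n_layers = 60
n_heads = 32
d_head = 128
d_latent = 512        # compressed latent c_t per token
d_rope = 64           # small decoupled RoPE dimensions kept uncompressed

# Standard attention caches full K and V per head per layer.
standard_per_token = n_layers * 2 * n_heads * d_head

# MLA caches only the latent (plus the small RoPE part) per layer.
mla_per_token = n_layers * (d_latent + d_rope)

print(f"standard: {standard_per_token:,} values/token")
print(f"MLA:      {mla_per_token:,} values/token "
      f"(~{standard_per_token / mla_per_token:.0f}x smaller)")
```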
Multi-Token Prediction in DeepSeek v3
Standard training predicts the next token.
Multi-token prediction adds extra supervision by using each position's hidden state to predict further tokens ahead through an extra head or small layer.
DeepSeek v3 uses a one-step-ahead extra prediction head. This gives more learning signal per sequence and can improve efficiency.
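A rough sketch of the extra-supervision idea, assuming a single additional prediction head and an assumed loss weight; this simplifies away the structure of DeepSeek v3's actual MTP module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

vocab, d_model = 1000, 256
main_head = nn.Linear(d_model, vocab)      # predicts token t+1 from position t
mtp_head = nn.Linear(d_model, vocab)       # extra head: predicts token t+2

hidden = torch.randn(4, 32, d_model)       # (batch, seq, d_model) from the trunk
tokens = torch.randint(0, vocab, (4, 34))  # inputs plus the lookahead targets

# Standard next-token loss at every position.
loss_main = F.cross_entropy(main_head(hidden).flatten(0, 1),
                            tokens[:, 1:33].flatten())
# Additional one-step-ahead loss: the same hidden states also predict t+2.
loss_mtp = F.cross_entropy(mtp_head(hidden).flatten(0, 1),
                           tokens[:, 2:34].flatten())
loss = loss_main + 0.3 * loss_mtp          # weighted extra signal (weight is assumed)
print(loss.item())
```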
What's Next
MoE replaces each dense MLP with many small experts and a router that picks a few experts per token. At fixed training FLOPs, MoE tends to beat dense models on loss and benchmarks.
The winning recipe is simple on paper: token-choice top-k routing, many fine-grained experts, and strong balancing. The difficulty is making it train stably and run efficiently at scale.
DeepSeek's trajectory shows you can scale to hundreds of billions of parameters with mostly incremental changes to routing, balancing, attention memory, and training losses.
Next lecture: GPUs. Memory hierarchy, arithmetic intensity, kernel optimization, and why FlashAttention matters.