CS336 Notes: Lecture 10 - Inference
Inference is not a footnote to training. It's the main engineering problem.
Training is expensive but rare. Inference runs nonstop. At the token volumes companies serve, small efficiency gains matter more than training costs.
This lecture from Stanford CS336 covers why inference is hard and how to make it faster. The core constraint: inference is memory-bound, not compute-bound. The KV cache dominates everything.
Where Inference Shows Up
Inference means: given a fixed model and a prompt, generate response tokens.
You see it in chat, code completion, batch processing, evaluation, test-time reasoning, and RL sampling. Inference underlies evaluation, fine-tuning, and reasoning chains, not just demos.
Three Metrics That Matter
TTFT (Time-To-First-Token): Time from prompt arrival to first output token. Mostly prefill time. What users notice first.
Per-token latency: Time between output tokens after the first. What users feel as "smoothness."
Throughput: Total tokens per second across requests. What you maximize for batch jobs or high volume.
High throughput doesn't mean low latency. You can generate many tokens per second overall while some requests stall. Batching and scheduling determine the tradeoff.
Why Inference Is Harder Than Training
In training: the full sequence is known. You parallelize across sequence length. You run large matmuls that saturate the GPU.
In autoregressive inference: you generate one token at a time. Each token depends on all previous tokens. You can't parallelize future tokens within a single sequence.
The result: each step runs smaller, less efficient computations. GPUs sit underused unless you increase batch size or run many sequences at once. Memory traffic becomes the bottleneck, especially in attention.
Arithmetic Intensity: The Key Metric
Arithmetic intensity = FLOPs / bytes moved to and from memory.
- High intensity: compute-limited (ALUs bottleneck).
- Low intensity: memory-limited (bandwidth bottleneck).
Matrix multiply example: X is B × D, W is D × F. FLOPs ≈ 2 × B × D × F. When D and F are much larger than B (typical), intensity ≈ B.
- Large B: compute-limited.
- B = 1 (matrix-vector): intensity ≈ 1 FLOP/byte, memory-limited.
On an H100, the threshold is roughly 295 FLOPs/byte. Below that: memory-limited. Above: compute-limited.
Autoregressive generation is many small matvecs. Intensity is low. Memory bandwidth dominates.
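A back-of-the-envelope check of these numbers, assuming BF16 (2 bytes per element) and rough H100 peak specs (~989 TFLOP/s BF16, ~3.35 TB/s HBM bandwidth); the matrix sizes are illustrative:

```python
# Arithmetic intensity of X (B x D) @ W (D x F) in BF16,
# compared against an approximate H100 roofline.

def matmul_intensity(B, D, F, bytes_per_elem=2):
    flops = 2 * B * D * F                                    # multiply-accumulates
    bytes_moved = bytes_per_elem * (B * D + D * F + B * F)   # read X, read W, write Y
    return flops / bytes_moved

H100_FLOPS = 989e12        # approx. BF16 peak
H100_BW = 3.35e12          # approx. HBM bandwidth, bytes/s
ridge = H100_FLOPS / H100_BW   # ~295 FLOPs/byte

for B in (1, 16, 256, 1024):
    ai = matmul_intensity(B, D=8192, F=8192)
    regime = "compute-limited" if ai > ridge else "memory-limited"
    print(f"B={B:5d}  intensity={ai:7.1f} FLOPs/byte  ({regime})")
```

At B = 1 the intensity is about 1 FLOP/byte; it grows roughly linearly with B until it crosses the ~295 FLOPs/byte ridge.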
Prefill vs Generation
Prefill (encode prompt): Run on all prompt tokens. Parallel over sequence length. High arithmetic intensity. Often compute-limited and efficient.
Generation (decode response): Add one token at a time. Each step conditions on full history. Small compute per step, heavy memory reads. Usually memory-limited and the main bottleneck.
The KV Cache Problem
Without caching, each new token recomputes keys and values for every previous token and redoes attention over the whole prefix: O(T²) work per step. Too slow.
KV cache fixes this: Store keys and values for every past token, layer, and head. Each new token computes only its own Q/K/V and attends against the cached K/V. Attention cost drops to O(T) per token.
The cost: heavy memory traffic. During generation, attention repeatedly reads large cached tensors from HBM to do little math. That makes it memory-limited.
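A minimal single-head decode step with a KV cache (NumPy, illustrative shapes; a sketch of the idea, not the lecture's code):

```python
import numpy as np

def decode_step(x, Wq, Wk, Wv, k_cache, v_cache):
    """One generation step for a single head.

    x: (d,) hidden state of the new token
    k_cache, v_cache: (t, d) keys/values for the t tokens seen so far
    """
    q = x @ Wq                     # (d,)  small matvec
    k = x @ Wk                     # (d,)
    v = x @ Wv                     # (d,)

    # Append to the cache: this is the state that grows with sequence length.
    k_cache = np.vstack([k_cache, k[None, :]])   # (t+1, d)
    v_cache = np.vstack([v_cache, v[None, :]])   # (t+1, d)

    # Attention over the whole cache: tiny FLOP count, but the entire
    # (t+1, d) cache streams in from HBM every step -> memory-limited.
    scores = k_cache @ q / np.sqrt(q.shape[0])   # (t+1,)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    out = weights @ v_cache                      # (d,)
    return out, k_cache, v_cache
```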
Why Batching Helps MLP But Not Attention
MLP during generation: Arithmetic intensity ≈ B (batch size). Larger B moves toward compute-limited. Batching helps.
Attention during generation: Intensity stays near 1 and doesn't improve with B. Each sequence has its own KV cache, so unlike the shared MLP weights, the KV bytes moved grow with B just as the FLOPs do. B cancels out.
During generation, MLP improves with B. Attention stays memory-limited regardless.
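A rough per-step byte count makes the asymmetry concrete; the dimensions below are illustrative and the accounting ignores softmax and output projections:

```python
# Per-decode-step arithmetic intensity vs batch size B (BF16 = 2 bytes/element).
D, F, T = 8192, 28672, 4096   # hidden dim, MLP dim, tokens already in context

def mlp_intensity(B):
    flops = 2 * B * D * F                    # one matvec per sequence, weights shared
    bytes_ = 2 * (D * F + B * D + B * F)     # weights read once for the whole batch
    return flops / bytes_

def attn_intensity(B):
    flops = 4 * B * T * D                    # scores + weighted sum, per sequence
    bytes_ = 2 * (2 * B * T * D)             # each sequence streams its own K and V
    return flops / bytes_

for B in (1, 16, 256):
    print(f"B={B:3d}  MLP: {mlp_intensity(B):6.1f}  attention: {attn_intensity(B):.2f}")
# MLP intensity grows roughly linearly with B; attention stays near 1 no matter what.
```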
Throughput vs Latency Tradeoff
Example with LLaMA 2 13B on H100:
- B = 1: per-token latency around 8 ms, throughput 100+ tokens/sec.
- B = 16: KV cache grows 16×, per-step latency increases, throughput improves a lot.
- As B grows: latency rises, throughput rises with diminishing returns.
- Eventually KV cache exceeds GPU memory (B = 256 can exceed 80 GB).
Small B: low latency, low throughput. Large B: high throughput, higher latency, more memory pressure.
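A quick size estimate shows why memory pressure bites, assuming LLaMA-2-13B-like shapes (40 layers, 40 heads, head dim 128) and BF16; the exact numbers depend on context length:

```python
# Rough KV cache size: 2 (K and V) x layers x heads x head_dim x bytes, per token.
layers, heads, head_dim, bytes_per_elem = 40, 40, 128, 2   # LLaMA-2-13B-like, BF16

kv_bytes_per_token = 2 * layers * heads * head_dim * bytes_per_elem   # ~0.8 MB

for B, seq_len in [(1, 2048), (16, 2048), (256, 2048)]:
    gb = kv_bytes_per_token * seq_len * B / 1e9
    print(f"B={B:3d}, seq={seq_len}: KV cache ≈ {gb:6.1f} GB")
# B=256 at 2k context is already hundreds of GB -- far beyond one 80 GB H100.
```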
Simple scaling: Run M copies on M GPUs. Latency per request stays about the same. Total throughput scales with M. Easy because inference doesn't need cross-GPU sync per token.
Architectural Tricks to Shrink KV Cache
Because generation is dominated by KV cache movement, many designs aim to reduce it.
Grouped-Query Attention (GQA): Standard MHA has N query heads and N KV heads. GQA reduces KV heads to K < N. Multiple query heads share one KV head. KV cache shrinks by about N/K. Models like LLaMA 3 use GQA. Too few KV heads can hurt accuracy.
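A minimal sketch of the grouping idea, where each block of query heads reads one shared cached KV head (NumPy, illustrative shapes):

```python
import numpy as np

def grouped_query_attention(q, k, v, n_q_heads, n_kv_heads):
    """q: (n_q_heads, d), k/v: (t, n_kv_heads, d).

    Each group of n_q_heads // n_kv_heads query heads attends to the same
    cached KV head, so the KV cache is n_q_heads / n_kv_heads times smaller
    than standard MHA."""
    group = n_q_heads // n_kv_heads
    outs = []
    for h in range(n_q_heads):
        kv = h // group                                      # shared KV head index
        scores = k[:, kv, :] @ q[h] / np.sqrt(q.shape[-1])   # (t,)
        w = np.exp(scores - scores.max())
        w /= w.sum()
        outs.append(w @ v[:, kv, :])                         # (d,)
    return np.stack(outs)                                    # (n_q_heads, d)
```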
Multi-Head Latent Attention (MLA): Used in DeepSeek V2. Project K/V into a smaller latent space before caching. Example: compress 16,000-d KV space to 512. KV cache shrinks dramatically.
Cross-Layer Attention (CLA): Share KV projections across multiple layers. KV cache is shared, reducing size and traffic.
Local attention with occasional global layers: Local attention keeps a sliding window (size K). KV cache grows with K, not total length. Pure local loses long-range dependencies. Fix: many local layers plus periodic full-attention layers (e.g., every 6th).
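A sketch of the sliding-window eviction policy that bounds the cache for local layers (a global-attention layer would simply skip the eviction):

```python
from collections import deque

class SlidingWindowKVCache:
    """Keeps KV for at most `window` recent tokens, so the cache is O(window),
    not O(sequence length)."""
    def __init__(self, window):
        self.window = window
        self.keys, self.values = deque(), deque()

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)
        if len(self.keys) > self.window:   # evict the oldest token's KV
            self.keys.popleft()
            self.values.popleft()
```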
Beyond Transformers
Full attention scales KV state with sequence length. Other model families keep state size constant.
State-space models (SSMs): Maintain fixed-size state updated per token. Linear time, constant memory per step. Early work struggled with associative recall. Later models (Mamba) perform well on language at modest scale.
Linear attention: Approximates softmax attention with kernel features. Reorders computation to maintain a running state. Each step uses constant memory and linear time. Some large systems use mostly linear and local attention with a few full-attention layers for quality.
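A sketch of the running-state recurrence behind linear attention, using an elu+1 feature map as in some of this line of work; dimensions are illustrative:

```python
import numpy as np

def phi(x):
    # A simple positive feature map: elu(x) + 1.
    return np.where(x > 0, x + 1.0, np.exp(x))

class LinearAttentionState:
    """Maintains S = sum_i phi(k_i) v_i^T and z = sum_i phi(k_i).
    Each step is O(d^2) time and memory -- independent of sequence length."""
    def __init__(self, d):
        self.S = np.zeros((d, d))
        self.z = np.zeros(d)

    def step(self, q, k, v):
        fk = phi(k)
        self.S += np.outer(fk, v)                     # accumulate key-value products
        self.z += fk                                  # accumulate normalizer
        fq = phi(q)
        return (fq @ self.S) / (fq @ self.z + 1e-6)   # output for this token
```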
Diffusion for language: Generate the whole sequence in parallel, then refine over steps. Each step updates all tokens at once. Can saturate GPUs on long sequences. Recent text diffusion shows high tokens-per-second rates. Open questions: can they match transformers across tasks? How many refinement steps for quality?
Quantization and Pruning
Quantization: Lower precision (FP32 → BF16 → FP8 → INT8 → INT4). Fewer bytes per weight reduces memory traffic. Helps models and KV caches fit in memory.
Challenges: too much quantization hurts accuracy. Some weights have outliers needing higher precision.
Examples: LLM.int8() keeps outlier features in higher precision. AWQ uses activation statistics from a calibration set to find the most salient weight channels, protects them by rescaling, and quantizes weights to INT4, reaching ~3× speedup.
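As a toy illustration of the basic mechanism (not LLM.int8() or AWQ themselves), a per-channel absmax INT8 quantizer looks like this:

```python
import numpy as np

def quantize_int8(W):
    """Per-output-channel absmax quantization: W ~ scale[:, None] * W_int8.
    Stores 1 byte per weight plus one scale per row (vs 2 bytes for BF16)."""
    scale = np.abs(W).max(axis=1) / 127.0 + 1e-12
    W_int8 = np.round(W / scale[:, None]).astype(np.int8)
    return W_int8, scale

def dequantize(W_int8, scale):
    return W_int8.astype(np.float32) * scale[:, None]

W = np.random.randn(4096, 4096).astype(np.float32)
W_q, s = quantize_int8(W)
err = np.abs(W - dequantize(W_q, s)).mean()
print(W.nbytes, W_q.nbytes, err)   # 4x fewer bytes than FP32, small mean error
```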
Pruning + distillation: Remove less important layers, heads, or units. The pruned model starts worse. Distillation trains it to match the large model's outputs. A 15B model pruned to 8B can keep benchmark quality close.
Both trade fidelity for speed. They work best when quality loss is small.
Speculative Decoding
Key observation: checking tokens in parallel is cheaper than generating them one by one.
Setup: Target model Q (large, accurate, expensive). Draft model P (smaller, cheaper, similar behavior).
Algorithm:
- From current context, P proposes K tokens autoregressively.
- Q evaluates those K tokens in one parallel pass.
- Accept each proposed token with probability min(1, Q(token)/P(token)).
- On rejection, sample a corrected token from the residual distribution: Q minus P, clipped at 0 and renormalized.
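A sketch of the accept/reject step for one drafted token, following the standard speculative sampling rule; p and q are the draft and target distributions over the vocabulary:

```python
import numpy as np

def accept_or_correct(token, p, q, rng=None):
    """One accept/reject step of speculative sampling.

    p, q: draft- and target-model probability vectors over the vocabulary.
    Repeating this over the K drafted tokens yields exact samples from q."""
    rng = rng or np.random.default_rng()
    if rng.random() < min(1.0, q[token] / p[token]):
        return token, True                       # keep the draft token
    # Rejection: resample from the residual distribution max(q - p, 0), renormalized.
    residual = np.maximum(q - p, 0.0)
    residual /= residual.sum()
    return int(rng.choice(len(q), p=residual)), False
```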
Property: Output is an exact sample from Q. You get the same distribution as running Q alone.
Performance: If P is much faster and close to Q, many tokens are accepted. Speedups around 2× or more are common.
Extensions: multi-branch drafts (Medusa), drafts guided by target signals (EAGLE).
Serving Real Traffic
Production serving is messy. Requests arrive unpredictably. Prompts and generations vary in length. Many requests share prefixes.
Dynamic batching: Each step generates one token for every active sequence. The scheduler adds new requests and removes finished ones. Batches refresh continuously.
Selective batching: Attention needs each sequence's own KV history, so mixed lengths are awkward. MLP work is easier to batch: flatten token vectors across sequences and run shared matmuls.
KV cache fragmentation: You don't know sequence lengths in advance. Allocating one contiguous buffer per request wastes memory and creates holes when requests finish.
PagedAttention (vLLM): Treat KV cache like OS virtual memory. Split KV memory into fixed-size pages. Each sequence's KV is a list of pages, not one contiguous block. Allocate pages wherever there's free space. Reclaim when sequences finish.
Shared prefixes and copy-on-write: Many sequences share the same prefix. Store shared KV pages once with a reference count. When sequences diverge, copy-on-write creates new pages only for the divergent part.
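A toy page-table allocator sketches the bookkeeping (fixed-size pages, a free list, reference counts for shared prefixes); vLLM's real implementation is far more involved:

```python
class PagedKVAllocator:
    """Toy paged KV management: fixed-size pages, a free list, and per-page
    reference counts so shared prefixes are stored once."""
    def __init__(self, num_pages, page_size):
        self.page_size = page_size
        self.free = list(range(num_pages))
        self.refcount = [0] * num_pages
        self.pages = {}      # seq_id -> list of page ids
        self.length = {}     # seq_id -> number of tokens stored

    def append_token(self, seq_id):
        pages = self.pages.setdefault(seq_id, [])
        n = self.length.get(seq_id, 0)
        if n % self.page_size == 0:          # last page is full (or no pages yet)
            page = self.free.pop()           # any free page works; no contiguity needed
            self.refcount[page] = 1
            pages.append(page)
        self.length[seq_id] = n + 1

    def fork(self, parent_id, child_id):
        # Share the parent's pages (e.g., a common prompt prefix).
        # True copy-on-write would copy only the page where the sequences diverge.
        self.pages[child_id] = list(self.pages[parent_id])
        self.length[child_id] = self.length[parent_id]
        for page in self.pages[child_id]:
            self.refcount[page] += 1

    def free_sequence(self, seq_id):
        for page in self.pages.pop(seq_id, []):
            self.refcount[page] -= 1
            if self.refcount[page] == 0:
                self.free.append(page)
        self.length.pop(seq_id, None)
```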
The Core Lesson
Training is mostly compute-limited and benefits from sequence parallelism.
Autoregressive inference is dominated by memory bandwidth. Generation is sequential. KV cache traffic is large.
Ways to make inference faster:
- Shrink or share the KV cache (GQA, MLA, CLA, local attention).
- Use alternatives that avoid full attention or long KV state (SSMs, linear attention, diffusion).
- Cut memory with quantization and pruning.
- Use speculative decoding to shift work to a small model while preserving large-model outputs.
- Build serving systems that handle variable traffic with dynamic batching and paged KV.
The guiding rule: reduce memory movement per useful FLOP. Design as if memory, not compute, is the constraint.