CS336 Notes: Lecture 2 - PyTorch and Resource Accounting
Before you train a model, you need to know if it will fit. This lecture teaches you to estimate compute and memory before you run anything.
The core formula: one training step costs about 6 × (tokens) × (parameters) floating point operations. Everything else follows from that.
Six Things You Need to Know
Tensors are PyTorch's basic data type. A tensor's memory is (number of elements) × (bytes per element). Different tensors can share the same underlying storage through views.
Floating point types trade precision for speed. float32 has good precision and range. float16 is faster but can overflow or underflow. bfloat16 keeps float32's range with lower precision. fp8 is even faster but harder to use safely.
The common pattern: store parameters and optimizer state in float32, run most forward and backward compute in bfloat16.
GPUs do the heavy compute. You must move tensors to the GPU. GPU speed is measured in FLOP/s. MFU (model FLOP utilization) measures how much of peak compute you're actually using.
Memory is spent four ways: parameters, activations, gradients, and optimizer state. Adam is expensive because it keeps two extra buffers per parameter.
Matmul dominates cost. A matmul of (B × D) by (D × K) costs about 2 × B × D × K FLOPs. With N tokens and P parameters, the forward pass is about 2 × N × P and the backward pass about 4 × N × P, so a full step is roughly 6 × N × P.
Napkin Math: Can You Train It?
Two questions you should answer before starting:
Question 1: How long to train a 70B parameter transformer on 15T tokens using 1,024 H100s?
Compute sketch:
- Total FLOPs ≈ 6 × (parameters) × (tokens).
- Choose a peak FLOP/s per H100 for the dtype you're using.
- Apply an MFU guess (50% is reasonable for well-tuned setups).
- Multiply by number of GPUs and seconds per day to get FLOPs per day.
- Training days ≈ total FLOPs ÷ FLOPs per day.
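The compute sketch turns into a few lines of arithmetic. The hardware numbers are assumptions, not measurements. The peak figure used here is half of NVIDIA's 1,979 TFLOP/s bf16 number for the H100 SXM, because that spec number is quoted with sparsity. The 50% MFU is the guess from the list above.

```python
# Napkin math for Question 1. Hardware numbers are assumptions, not measurements.
num_params = 70e9              # P
num_tokens = 15e12             # N
num_gpus = 1024
h100_peak_flops = 1979e12 / 2  # assumed dense bf16 peak per H100, FLOP/s
mfu = 0.5                      # assumed model FLOP utilization
seconds_per_day = 60 * 60 * 24

total_flops = 6 * num_params * num_tokens
flops_per_day = h100_peak_flops * mfu * num_gpus * seconds_per_day
training_days = total_flops / flops_per_day
print(f"total ≈ {total_flops:.1e} FLOPs, ≈ {training_days:.0f} days")
```

Under these assumptions the result is about 144 days. Halving the MFU doubles it, which is why the MFU guess matters as much as the GPU count.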
Question 2: What's the largest dense model trainable on 8 H100s with AdamW, no memory optimizations?
Memory sketch:
- Each H100 has 80 GB HBM.
- AdamW in float32 is roughly 16 bytes per parameter: 4 for the parameter, 4 for the gradient, and 8 for the two optimizer moments. This ignores activations.
- Max parameters ≈ total GPU memory ÷ bytes per parameter.
This gives a rough upper bound (around 40B parameters on 8 H100s) before activations, sequence length, and other overheads eat into it.
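The memory sketch works the same way. This assumes everything is stored in float32, with AdamW's two moment buffers as the only optimizer state, and no activations counted.

```python
# Napkin math for Question 2 (float32 everywhere, activations ignored).
num_gpus = 8
hbm_bytes_per_gpu = 80e9
# parameter + gradient + AdamW first moment + AdamW second moment, 4 bytes each
bytes_per_param = 4 + 4 + (4 + 4)

max_params = num_gpus * hbm_bytes_per_gpu / bytes_per_param
print(f"max parameters ≈ {max_params / 1e9:.0f}B")
```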
Tensors and Memory
Tensors hold parameters, gradients, optimizer state, data (token IDs, embeddings), and activations.
Memory per tensor = elements × bytes per element.
Example: A 4 × 8 float32 tensor has 32 elements. float32 is 4 bytes. Memory = 32 × 4 = 128 bytes.
Big weight matrices can take gigabytes.
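The per-tensor rule is simple enough to write as a helper. This is a plain-Python sketch with a hand-written table of byte sizes. On a real PyTorch tensor the same number is `x.numel() * x.element_size()`. The 16384 × 16384 matrix is a hypothetical size chosen to show the gigabyte scale.

```python
import math

# Bytes per element for the dtypes discussed in these notes.
BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2, "fp8": 1}

def tensor_bytes(shape, dtype):
    """Memory for a dense tensor: number of elements × bytes per element."""
    return math.prod(shape) * BYTES_PER_ELEMENT[dtype]

print(tensor_bytes((4, 8), "float32"))                  # 128
print(tensor_bytes((16384, 16384), "float32") / 2**30)  # 1.0 (GiB)
```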
Floating Point Types
float32
32 bits: 1 sign, 8 exponent, 23 fraction. Good precision and range. Common for parameter storage and optimizer state.
float16
16 bits: 1 sign, 5 exponent, 10 fraction. Half the memory of float32, often faster. Limited range: small values can underflow to zero, and values above 65,504 (the largest finite float16) overflow. Example: 1e-8 rounds to zero. Can be unstable for large models.
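Python's standard `struct` module can pack IEEE half precision (format `"e"`), so the float16 failure modes can be checked without a GPU. This is a stand-in for a real float16 tensor. One difference: `struct` raises `OverflowError` on out-of-range values, while a float16 tensor in PyTorch would hold inf instead.

```python
import struct

def to_float16(x: float) -> float:
    """Round x to the nearest IEEE float16 value using the struct 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

print(to_float16(1e-8))     # 0.0: underflows
print(to_float16(65504.0))  # 65504.0: largest finite float16
print(to_float16(1.0001))   # 1.0: only 10 fraction bits of precision
try:
    to_float16(1e5)
except OverflowError:
    print("1e5 overflows float16")  # a float16 tensor would store inf here
```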
bfloat16
16 bits: 1 sign, 8 exponent, 7 fraction. Same exponent width as float32, so the range is similar, but precision is much lower. Much less prone to underflow and overflow than float16. Good for forward and backward passes.
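Python has no bfloat16 type, so this sketch emulates one. bfloat16 is the top 16 bits of a float32, so it rounds a float32's bit pattern to those bits using round-to-nearest-even. It assumes finite inputs within float32 range and does no NaN handling. The values are the same ones that broke float16.

```python
import struct

def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping the top 16 bits of x's float32 encoding.

    Rounds to nearest, ties to even. Finite inputs within float32 range only.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                     # lowest bit that survives truncation
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

print(to_bfloat16(1e-8))    # stays nonzero (float16 gives 0.0)
print(to_bfloat16(1e5))     # 99840.0: in range, but coarse
print(to_bfloat16(1.0001))  # 1.0: only 7 fraction bits
```

The range survives and the precision does not: 1e5 comes back as 99840.0, which is within about 0.2% of the input.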
fp8
8-bit float in two common variants: E4M3 (4 exponent, 3 fraction bits, more precision) and E5M2 (5 exponent, 2 fraction bits, more range). Supported on H100 tensor cores. Can boost speed and cut memory but is harder to train safely. Often used in advanced mixed-precision setups.
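The range and precision trade-offs follow directly from the bit counts. This sketch computes them for IEEE-style binary formats, where the all-ones exponent is reserved for inf/NaN. That rule holds for float32, float16, bfloat16, and fp8 E5M2. It does not hold for the E4M3 variant used on H100: that variant reuses most of the top exponent for finite values, so its maximum is 448 rather than what this formula gives. E4M3 is therefore left out of the table.

```python
def max_finite(exp_bits: int, frac_bits: int) -> float:
    """Largest finite value of an IEEE-style binary float format."""
    bias = 2 ** (exp_bits - 1) - 1
    max_exponent = (2 ** exp_bits - 2) - bias  # all-ones exponent reserved
    return (2 - 2.0 ** -frac_bits) * 2.0 ** max_exponent

def min_normal(exp_bits: int) -> float:
    """Smallest positive normal value (smaller values go subnormal, then to zero)."""
    bias = 2 ** (exp_bits - 1) - 1
    return 2.0 ** (1 - bias)

FORMATS = {  # name: (exponent bits, fraction bits)
    "float32": (8, 23),
    "float16": (5, 10),
    "bfloat16": (8, 7),
    "fp8 E5M2": (5, 2),
}

for name, (e, f) in FORMATS.items():
    # eps = gap between 1.0 and the next representable value (relative precision)
    print(f"{name:9s} max≈{max_finite(e, f):.3e}  "
          f"min_normal≈{min_normal(e):.3e}  eps={2.0 ** -f:.1e}")
```

bfloat16 and float32 share an exponent width, so their ranges nearly match. float16 and E5M2 share theirs, so both top out near 6 × 10⁴.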
Mixed Precision Training
The pattern: use different dtypes in different parts of training.
- bfloat16 or fp8 for most matmuls in forward and backward.
- float32 for master parameters and optimizer accumulators.
- Sometimes float32 for sensitive ops (often attention-related).
Benefits: less memory for activations and intermediates, higher throughput on tensor cores.
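Risks: instability if precision is too low in the wrong places. float16 usually needs loss scaling to keep small gradients from underflowing. bfloat16's wider range usually makes loss scaling unnecessary, but dtype choices still need care.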
In practice, training usually needs float32 somewhere to stay stable. Inference can use more aggressive quantization later (down to int4 in some deployments).
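Two reasons for the float32 parts of the pattern can be reproduced with scalar arithmetic. Here `struct` float16 and float32 round-trips stand in for tensor dtypes, and the update size, step count, and scale factor are illustrative choices. Small updates vanish when weights live only in low precision. A tiny gradient underflows in float16 unless it is scaled up first. Scaling the loss by S scales every gradient by S through the chain rule.

```python
import struct

def to_float16(x: float) -> float:
    """Round-trip through IEEE float16 (stand-in for a low-precision dtype)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_float32(x: float) -> float:
    """Round-trip through IEEE float32 (stand-in for float32 master weights)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

update, steps = 1e-4, 100

# Weights kept only in float16: every small update is rounded away.
w_half = to_float16(1.0)
for _ in range(steps):
    w_half = to_float16(w_half + update)

# float32 master copy: the updates accumulate.
w_master = to_float32(1.0)
for _ in range(steps):
    w_master = to_float32(w_master + update)

# Loss scaling: cast the scaled gradient to float16, then unscale in higher precision.
grad, scale = 1e-8, 2.0 ** 16
unscaled = to_float16(grad)                   # 0.0: underflow
recovered = to_float16(grad * scale) / scale  # close to 1e-8

print(w_half, w_master, unscaled, recovered)
```

`w_half` stays at 1.0 while `w_master` reaches about 1.01. This is why the optimizer update lands in a float32 copy.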
CPUs, GPUs, and Tensor Locations
PyTorch creates tensors on CPU by default. Training large models on CPU is too slow.
You must either move tensors to GPU with x.to(device) or create them on GPU.
GPU HBM is fast. Moving data between CPU RAM and GPU HBM is expensive.
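A rough sense of "expensive" comes from dividing bytes by bandwidth. Both bandwidths are assumed ballpark figures: about 3.35 TB/s for H100 SXM HBM3, and about 64 GB/s per direction for a PCIe Gen5 x16 host link. Real transfers also pay latency and synchronization costs that this ignores.

```python
# Assumed ballpark bandwidths, not measurements.
HBM_BYTES_PER_SEC = 3.35e12   # on-GPU memory
PCIE_BYTES_PER_SEC = 64e9     # CPU RAM to GPU HBM

def transfer_ms(num_bytes: float, bytes_per_sec: float) -> float:
    return num_bytes / bytes_per_sec * 1e3

one_gib = 2 ** 30
hbm_ms = transfer_ms(one_gib, HBM_BYTES_PER_SEC)
pcie_ms = transfer_ms(one_gib, PCIE_BYTES_PER_SEC)
print(f"1 GiB from HBM ≈ {hbm_ms:.2f} ms, 1 GiB host to device ≈ {pcie_ms:.1f} ms")
```

Under these assumptions the host-to-device copy is roughly 50× slower. That is why `x.to(device)` belongs outside inner loops.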
Rules of thumb:
- Always know where your tensors live (CPU vs GPU).
- Avoid unnecessary transfers.
- Use tools that report tensor devices when debugging.
Tensors as Storage and Views
A tensor is a view into storage, not raw memory itself.
- Storage is a 1D array of values.
- Tensor metadata holds shape and strides.
- Strides tell you how to walk the storage for each dimension.
Many tensors can share the same storage (slices, transposes, views). Indexing, transpose, and view often don't copy data. Mutating one view can change others that share storage.
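Contiguous vs non-contiguous: Contiguous tensors are laid out in row-major order. Transpose and some slicing produce non-contiguous views. Some ops, such as view(), need a compatible (often contiguous) layout. Calling contiguous() on a non-contiguous tensor copies the data into a fresh row-major buffer, and on an already contiguous tensor it returns the tensor itself.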
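The storage/stride model fits in a small class. `ToyTensor` is a hypothetical teaching type, not PyTorch's implementation. Its metadata mirrors what `x.stride()`, `x.t()`, `x.is_contiguous()`, and `x.contiguous()` expose on real tensors. The sketch shows three things. A transpose only swaps metadata. Writes through one view are visible through another. contiguous() is where the copy happens.

```python
import itertools
from typing import List, Sequence, Tuple

def row_major_strides(shape: Sequence[int]) -> List[int]:
    """Strides (in elements) of a contiguous row-major tensor; the last dim has stride 1."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return strides[::-1]

class ToyTensor:
    """Hypothetical toy tensor: a shared 1D storage list plus shape/stride metadata."""

    def __init__(self, storage: list, shape: Sequence[int], strides: Sequence[int], offset: int = 0):
        self.storage = storage  # shared between views, never copied by them
        self.shape = list(shape)
        self.strides = list(strides)
        self.offset = offset

    def _pos(self, idx: Tuple[int, ...]) -> int:
        # Walk the storage: offset + sum(index_k * stride_k)
        return self.offset + sum(i * s for i, s in zip(idx, self.strides))

    def __getitem__(self, idx: Tuple[int, ...]):
        return self.storage[self._pos(idx)]

    def __setitem__(self, idx: Tuple[int, ...], value) -> None:
        self.storage[self._pos(idx)] = value

    def t(self) -> "ToyTensor":
        # 2D transpose: reverse shape and strides over the same storage, so no data moves.
        return ToyTensor(self.storage, self.shape[::-1], self.strides[::-1], self.offset)

    def is_contiguous(self) -> bool:
        # Simplified check (PyTorch also ignores the strides of size-1 dims).
        return self.strides == row_major_strides(self.shape)

    def contiguous(self) -> "ToyTensor":
        if self.is_contiguous():
            return self  # already row-major: no copy
        # Copy elements into fresh storage in row-major order.
        data = [self[idx] for idx in itertools.product(*(range(n) for n in self.shape))]
        return ToyTensor(data, self.shape, row_major_strides(self.shape))

storage = list(range(6))
x = ToyTensor(storage, (2, 3), row_major_strides((2, 3)))  # strides [3, 1]
y = x.t()                 # shape [3, 2], strides [1, 3], same storage
y[0, 1] = 100             # writes storage[3] ...
print(x[1, 0])            # ... so x sees 100 as well
print(y.is_contiguous())  # False
z = y.contiguous()        # copies into new row-major storage
print(z.strides, z.storage is storage)  # [2, 1] False
```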
Memory rule: Views are cheap. Elementwise ops and some reshapes allocate new memory.
Matrix Multiplication Costs
Matmul is the main workhorse.
Basic case: (B × D) × (D × K) → (B × K).
Each output element does D multiplies and D adds. There are B × K outputs.
Total FLOPs = 2 × B × D × K.
Language models often use (batch, sequence, hidden) shapes. PyTorch supports batched matmuls, applying the same weights across tokens and batches.
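The 2 × B × D × K count can be checked by brute force. This sketch runs a naive triple loop and counts one multiply and one add per inner step, accumulating from zero, which is the convention behind the formula. The second helper covers the batched (batch, sequence, hidden) case. The same weights are applied to every token, so the token count simply multiplies the cost. The sizes in the usage are hypothetical.

```python
def naive_matmul(a, b):
    """(B × D) @ (D × K) with explicit loops; returns (result, FLOP count)."""
    B, D, K = len(a), len(b), len(b[0])
    out = [[0.0] * K for _ in range(B)]
    flops = 0
    for i in range(B):
        for k in range(K):
            acc = 0.0
            for d in range(D):
                acc += a[i][d] * b[d][k]  # one multiply and one add
                flops += 2
            out[i][k] = acc
    return out, flops

def linear_flops(batch: int, seq: int, d_in: int, d_out: int) -> int:
    """(batch, seq, d_in) @ (d_in, d_out): the same weights applied to every token."""
    return 2 * batch * seq * d_in * d_out

a = [[1.0] * 3 for _ in range(2)]  # B=2, D=3
b = [[1.0] * 4 for _ in range(3)]  # D=3, K=4
_, counted = naive_matmul(a, b)
print(counted, 2 * 2 * 3 * 4)      # 48 48
print(f"{linear_flops(8, 1024, 4096, 4096):.2e}")  # hypothetical sizes
```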
Einops for Clarity
Indexing by position (like -2, -1) is easy to mess up. Einops lets you name dimensions:
- Einsum keeps the dimensions named in the output and sums over dimensions that appear in the inputs but not in the output.
- Rearrange helps pack and unpack composite dimensions like (heads × head_dim).
This makes tensor code easier to read and reason about.
FLOPs, FLOP/s, and Peak Numbers
Definitions:
- FLOP: one floating point add or multiply.
- FLOPs: total operations.
- FLOP/s: operations per second.
GPU specs list peak FLOP/s by dtype. H100 has much higher peak for fp16/bf16/fp8 tensor cores than for float32. Peak numbers assume ideal conditions. Dense models often hit about half of peak.
Example scale: 8 H100s running for a week at peak dense bf16 throughput do a few × 10²¹ FLOPs. That is a lot, but frontier training runs are far larger: GPT-3 took about 3 × 10²³ FLOPs.
MFU: Model FLOP Utilization
MFU measures how efficiently you use the GPU:
MFU = (model FLOPs / wall-clock time) ÷ (GPU peak FLOP/s)
How to estimate:
- Count model FLOPs by summing matmuls using 2 × product of dimensions.
- Measure step time.
- Actual FLOP/s = FLOPs ÷ time.
- Divide by peak FLOP/s for the dtype.
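The estimation steps above are a single division once model FLOPs and step time are known. The run below is hypothetical. It assumes a 1B-parameter model, 2²⁰ tokens per step, a 12.7 s step time, one GPU, and the same assumed dense bf16 H100 peak used in the napkin math. It uses 6 × N × P for model FLOPs because a training step includes forward and backward.

```python
def mfu(model_flops_per_step: float, step_seconds: float,
        peak_flops_per_sec: float, num_gpus: int = 1) -> float:
    """Achieved model FLOP/s divided by total peak FLOP/s."""
    actual_flops_per_sec = model_flops_per_step / step_seconds
    return actual_flops_per_sec / (peak_flops_per_sec * num_gpus)

# Hypothetical run; all numbers are illustrative assumptions.
params = 1e9
tokens_per_step = 2 ** 20
model_flops = 6 * params * tokens_per_step  # forward + backward
step_seconds = 12.7                          # "measured" step time
peak = 1979e12 / 2                           # assumed dense bf16 H100 peak

utilization = mfu(model_flops, step_seconds, peak)
print(f"MFU ≈ {utilization:.2f}")
```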
Interpretation:
- MFU around 0.5 or higher is strong for real systems.
- MFU like 0.05 usually means the GPU is waiting (small batches, overhead, poor input pipeline).
- MFU rises when large matmuls dominate.
Forward and Backward Cost
Two-layer model example:
- X: (B × D)
- W1: (D × D)
- W2: (D × K)
Forward:
- H1 = X × W1
- H2 = H1 × W2
- Loss = scalar function of H2
Forward FLOPs:
- X × W1: 2 × B × D × D
- H1 × W2: 2 × B × D × K
Total forward ≈ 2 × B × (D² + D × K), roughly 2 × (data points) × (parameters).
Backward:
- Grad W2: H1ᵀ × grad_H2, cost ≈ 2 × B × D × K
- Grad H1: grad_H2 × W2ᵀ, cost ≈ 2 × B × D × K
- Grad W1 = Xᵀ × grad_H1 and grad X = grad_H1 × W1ᵀ: two more matmuls, each ≈ 2 × B × D × D
Total backward ≈ 4 × B × (D² + D × K), about twice forward.
Summary:
- Forward ≈ 2 × N × P
- Backward ≈ 4 × N × P
- One step ≈ 6 × N × P
where N is tokens (data points) and P is parameters.
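The per-matmul tally for the two-layer example can be written out and checked against the 2/4/6 rule. Each backward term is annotated with the gradient it produces. The sizes are arbitrary. Like the lecture's count, this includes grad X, which a deeper network needs to pass gradients to earlier layers.

```python
def mm_flops(m: int, n: int, p: int) -> int:
    """FLOPs for an (m × n) @ (n × p) matmul."""
    return 2 * m * n * p

def two_layer_flops(B: int, D: int, K: int):
    forward = mm_flops(B, D, D) + mm_flops(B, D, K)  # X @ W1, H1 @ W2
    backward = (
        mm_flops(D, B, K)    # grad_W2 = H1ᵀ @ grad_H2
        + mm_flops(B, K, D)  # grad_H1 = grad_H2 @ W2ᵀ
        + mm_flops(D, B, D)  # grad_W1 = Xᵀ @ grad_H1
        + mm_flops(B, D, D)  # grad_X  = grad_H1 @ W1ᵀ
    )
    return forward, backward

B, D, K = 1024, 256, 64          # arbitrary example sizes
params = D * D + D * K           # P
fwd, bwd = two_layer_flops(B, D, K)
print(fwd / (B * params), bwd / (B * params), (fwd + bwd) / (B * params))  # 2.0 4.0 6.0
```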
Memory Budget
Four buckets:
- Parameters: learnable weights.
- Activations: intermediate values saved for backward.
- Gradients: same shapes as parameters.
- Optimizer state: extra buffers, size depends on optimizer.
Deep linear model with hidden size D and L layers:
- Parameters ≈ L × D² (plus a small head).
- Activations ≈ batch_size × sequence_length × D × L.
- Gradients ≈ parameters.
- Optimizer state ≈ parameters × (1× for AdaGrad-style, 2× for Adam-style).
Total bytes ≈ (parameters + activations + gradients + optimizer state) × bytes per value, assuming every bucket uses the same dtype (4 bytes for float32).
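The four buckets for the deep linear model translate into a small calculator. The sizes below (D = 4096, 32 layers, batch 8, sequence 2048, Adam-style state, float32) are hypothetical, picked so the arithmetic lands on powers of two.

```python
def memory_bytes(D: int, L: int, batch_size: int, seq_len: int,
                 optimizer_factor: int, bytes_per_value: int = 4) -> int:
    """Memory for the deep linear model; optimizer_factor is 1 (AdaGrad-style) or 2 (Adam-style)."""
    params = L * D * D
    activations = batch_size * seq_len * D * L
    grads = params
    optimizer_state = optimizer_factor * params
    return (params + activations + grads + optimizer_state) * bytes_per_value

total = memory_bytes(D=4096, L=32, batch_size=8, seq_len=2048, optimizer_factor=2)
print(total / 2**30)  # 16.0 (GiB)
```

With these sizes activations make up half the total. That is why batch size and sequence length matter as much as parameter count.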
Optimizers and Their State
SGD
Update is proportional to the gradient. No extra state.
Momentum
Keeps a running average of gradients. One extra tensor per parameter.
AdaGrad
Keeps a running sum of squared gradients per parameter. Scales updates down where accumulated squared gradients are large. One extra tensor per parameter.
RMSProp
Like AdaGrad but uses an exponential moving average. One extra tensor per parameter.
Adam
Combines momentum and RMSProp ideas. Keeps two moving averages per parameter: first moment and second moment. Two extra tensors per parameter.
Memory: SGD adds no state. Momentum, AdaGrad, and RMSProp add one extra tensor per parameter. Adam adds two.
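The state difference shows up in how many buffers each optimizer allocates. This is a plain-Python sketch over a flat list of parameters, not `torch.optim`. The AdaGrad update divides by sqrt(sum of squared gradients + 1e-5). Adam is the standard bias-corrected form without AdamW's weight decay. The learning rates and betas are illustrative.

```python
import math

class AdaGradSketch:
    """One extra buffer per parameter: running sum of squared gradients."""

    def __init__(self, lr: float = 0.1, eps: float = 1e-5):
        self.lr, self.eps = lr, eps
        self.g2 = None

    def step(self, params: list, grads: list) -> None:
        if self.g2 is None:
            self.g2 = [0.0] * len(params)
        for i, g in enumerate(grads):
            self.g2[i] += g * g
            params[i] -= self.lr * g / math.sqrt(self.g2[i] + self.eps)

class AdamSketch:
    """Two extra buffers per parameter: first moment m and second moment v."""

    def __init__(self, lr: float = 0.1, betas=(0.9, 0.999), eps: float = 1e-8):
        self.lr, self.betas, self.eps = lr, betas, eps
        self.m = self.v = None
        self.t = 0

    def step(self, params: list, grads: list) -> None:
        if self.m is None:
            self.m, self.v = [0.0] * len(params), [0.0] * len(params)
        self.t += 1
        b1, b2 = self.betas
        for i, g in enumerate(grads):
            self.m[i] = b1 * self.m[i] + (1 - b1) * g
            self.v[i] = b2 * self.v[i] + (1 - b2) * g * g
            m_hat = self.m[i] / (1 - b1 ** self.t)  # bias correction
            v_hat = self.v[i] / (1 - b2 ** self.t)
            params[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

p_ada, p_adam = [1.0, -1.0], [1.0, -1.0]
ada, adam = AdaGradSketch(), AdamSketch()
ada.step(p_ada, [2.0, -2.0])
adam.step(p_adam, [2.0, -2.0])
print(p_ada, p_adam)  # both move each parameter by about lr toward zero
```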
Parameter Initialization
Naive: W ~ N(0, 1). Output variance grows with fan-in, so large D can blow up activations.
Safer: Scale by 1/√fan_in. For a D × D matrix, fan_in = D, so W = N(0, 1) / √D. Keeps outputs stable as D changes.
This connects to Xavier-style initialization. Sometimes the normal distribution is truncated (samples outside [-3, 3] are redrawn) to avoid rare extremes.
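The fan-in effect can be measured directly. This sketch draws x ~ N(0, 1) of size D and a D × D weight matrix with the standard library's `random`, then compares the output spread under naive and 1/√D scaling. The truncated sampler redraws out-of-range values rather than clipping them. Like the `a` and `b` arguments of `torch.nn.init.trunc_normal_`, its bounds are absolute values, not multiples of the standard deviation. D = 256 and the seed are arbitrary choices.

```python
import math
import random
import statistics

def output_std(D: int, scale: float, rng: random.Random) -> float:
    """Std over D outputs of y = x @ W, with x ~ N(0,1)^D and W entries ~ N(0,1) * scale."""
    x = [rng.gauss(0.0, 1.0) for _ in range(D)]
    ys = [sum(xi * rng.gauss(0.0, 1.0) * scale for xi in x) for _ in range(D)]
    return statistics.pstdev(ys)

def trunc_normal(rng: random.Random, std: float, a: float = -3.0, b: float = 3.0) -> float:
    """Sample N(0, std²), redrawing any value outside [a, b] (truncation, not clipping)."""
    while True:
        v = rng.gauss(0.0, std)
        if a <= v <= b:
            return v

rng = random.Random(0)
D = 256
naive = output_std(D, 1.0, rng)                 # grows like √D (≈ 16 here)
scaled = output_std(D, 1 / math.sqrt(D), rng)   # stays near 1
print(f"naive std ≈ {naive:.1f}, scaled std ≈ {scaled:.2f}")
```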
Randomness and Data Loading
Randomness comes from initialization, shuffling, dropout, and other regularizers.
Practice:
- Seed PyTorch, NumPy, and anything else you use.
- Use separate generators if you want fine control.
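Separate generators keep one source of randomness from shifting another. This sketch uses `random.Random` instances. In PyTorch the counterparts are `torch.manual_seed` and `torch.Generator` objects passed via `generator=`, and in NumPy `np.random.default_rng(seed)`. The helper functions are hypothetical stand-ins for initialization and data shuffling.

```python
import random

def init_weights(rng: random.Random, n: int) -> list:
    return [rng.gauss(0.0, 1.0) for _ in range(n)]

def epoch_order(rng: random.Random, n: int) -> list:
    order = list(range(n))
    rng.shuffle(order)
    return order

# One generator per source of randomness.
init_rng = random.Random(1234)
data_rng = random.Random(5678)
weights = init_weights(init_rng, 4)
order = epoch_order(data_rng, 10)

# With a single shared generator, shuffling first silently changes the initial weights.
shared = random.Random(1234)
epoch_order(shared, 10)
weights_shared = init_weights(shared, 4)
print(weights == weights_shared)  # False
```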
Language model data:
- Token IDs are integers.
- Store long token streams in binary or NumPy arrays.
- For huge datasets, don't load everything into RAM.
- Use memory-mapped arrays (mmap) and sample slices on demand.
- Wrap this in a PyTorch Dataset and DataLoader.
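The memory-mapped sampling idea can be shown with the standard library alone. PyTorch code would typically use `np.memmap` plus a Dataset. This sketch stores token IDs as uint16, which assumes a vocabulary under 65,536 tokens. It maps the file with `mmap` and reads only the sampled windows. Targets are the inputs shifted by one token. The file name and the `range(1000)` corpus are placeholders.

```python
import mmap
import os
import random
import tempfile
from array import array

def write_tokens(path: str, token_ids) -> None:
    """Write token IDs as a flat binary file of uint16 values."""
    with open(path, "wb") as f:
        array("H", token_ids).tofile(f)

def get_batch(path: str, batch_size: int, context_length: int, rng: random.Random):
    """Sample (inputs, next-token targets) without loading the whole file into RAM."""
    itemsize = array("H").itemsize  # bytes per token
    with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        num_tokens = len(mm) // itemsize
        xs, ys = [], []
        for _ in range(batch_size):
            start = rng.randrange(num_tokens - context_length)
            # Only the pages backing this slice are touched.
            window = array("H", mm[start * itemsize:(start + context_length + 1) * itemsize])
            xs.append(window[:-1].tolist())
            ys.append(window[1:].tolist())
    return xs, ys

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tokens.bin")
    write_tokens(path, range(1000))  # placeholder for a tokenized corpus
    xs, ys = get_batch(path, batch_size=4, context_length=8, rng=random.Random(0))
print(xs[0], ys[0])
```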
Training Loop and Checkpointing
Typical loop:
- Get a batch.
- Move it to GPU.
- Forward pass to get loss.
- Backward pass (loss.backward()).
- Optimizer step (optimizer.step()).
- Zero grads (optimizer.zero_grad()).
Checkpointing:
- Save model state dict.
- Save optimizer state dict.
- Save step and epoch counters.
- Load on restart and resume.
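The loop and checkpoint steps can be exercised end to end without torch. In PyTorch the pieces would be `model.state_dict()`, `optimizer.state_dict()`, `torch.save`, and `torch.load`. This sketch instead uses a hypothetical one-weight model fit with SGD plus momentum, hand-written gradients, and `pickle`. There is no device-transfer step. It also saves the RNG state, which goes beyond the checklist above, so the resumed run sees the same batches. The check is that a run interrupted and resumed matches an uninterrupted one exactly.

```python
import os
import pickle
import random
import tempfile

def new_state() -> dict:
    # Hypothetical toy problem: learn w in y = 3x with SGD + momentum.
    return {"model": {"w": 0.0}, "optimizer": {"momentum": 0.0}, "step": 0}

def train(state: dict, rng: random.Random, num_steps: int,
          lr: float = 0.05, beta: float = 0.9, batch_size: int = 8) -> dict:
    for _ in range(num_steps):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(batch_size)]  # get a batch
        ys = [3.0 * x for x in xs]
        w = state["model"]["w"]
        preds = [w * x for x in xs]                                # forward
        # backward: d(mean squared error)/dw, written by hand
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / batch_size
        m = beta * state["optimizer"]["momentum"] + grad           # optimizer step
        state["optimizer"]["momentum"] = m
        state["model"]["w"] = w - lr * m
        state["step"] += 1
        # grad is a fresh local each step, which plays the role of zero_grad()
    return state

def save_checkpoint(path: str, state: dict, rng: random.Random) -> None:
    with open(path, "wb") as f:
        pickle.dump({"state": state, "rng_state": rng.getstate()}, f)

def load_checkpoint(path: str):
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    rng = random.Random()
    rng.setstate(ckpt["rng_state"])
    return ckpt["state"], rng

straight = train(new_state(), random.Random(0), 20)  # uninterrupted run

with tempfile.TemporaryDirectory() as tmp:            # same run, resumed halfway
    ckpt_path = os.path.join(tmp, "ckpt.pkl")
    rng = random.Random(0)
    first_half = train(new_state(), rng, 10)
    save_checkpoint(ckpt_path, first_half, rng)
    state, rng = load_checkpoint(ckpt_path)
    resumed = train(state, rng, 10)

print(straight == resumed)  # True: weights, optimizer state, and step all match
```

Dropping the optimizer entry from the checkpoint would reset the momentum on resume. The resumed run would then diverge from the uninterrupted one, which is the reason the checklist saves optimizer state.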
What's Next
Resource accounting is how you know a training run is feasible before you start it. The formulas are simple: 6 × N × P for compute, (params + activations + grads + optimizer) × bytes for memory.
Next lecture: architecture choices. Pre-norm vs post-norm, RMSNorm, SwiGLU, RoPE, and the design decisions that show up in modern LLMs.