Why this matters
Mixed precision lets you train large NLP models faster and fit bigger batches on the same GPU by using lower-precision math where it is safe, while keeping critical parts in full precision.
- Fine-tune large Transformers without out-of-memory errors.
- Increase throughput (tokens/sec) and reduce training time and cost.
- Fit larger batch sizes for more stable gradients.
- Serve inference faster with reduced latency and memory.
Concept explained simply
Neural networks do lots of multiplications and additions. Full precision (FP32) is accurate but heavy. Mixed precision uses faster, smaller numbers (like FP16 or BF16) for most compute, but keeps a high-precision copy of model weights and some operations in FP32 to stay numerically stable.
- FP32: 32-bit float. Accurate, more memory and compute.
- FP16: 16-bit float. Smaller, faster; narrower range than FP32.
- BF16: 16-bit float with the same exponent range as FP32 (wider than FP16) but fewer precision bits. Often more stable than FP16 for training when the hardware supports it.
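The range difference is easy to see by inspecting each dtype's numeric limits. A minimal PyTorch sketch using torch.finfo (FP16 tops out near 65,504, while BF16 covers roughly the same range as FP32 at lower precision):
# Compare numeric limits of FP32, FP16, and BF16
import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15s} max={info.max:.3e}  smallest normal={info.tiny:.3e}  eps={info.eps:.3e}")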
Two core helpers in modern frameworks:
- Autocast: Runs selected ops in lower precision automatically, others in FP32.
- Loss Scaling: Multiplies the loss by a scale so small gradients do not underflow in FP16; then unscales before the optimizer step.
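A toy illustration of why the scale is needed: values smaller than FP16 can represent silently become zero, while a scaled copy survives the cast (2**16 here is just an example factor).
# Why loss scaling matters: tiny values underflow to zero in FP16
import torch

g = torch.tensor(1e-8)            # a very small gradient-sized value
print(g.half())                   # tensor(0., dtype=torch.float16) -- underflowed
print((g * 2**16).half())         # scaled copy survives the cast to FP16
# GradScaler applies a scale like this to the loss, then divides the
# full-precision gradients by the same factor before the optimizer step.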
Mental model
Imagine water flowing through pipes. FP32 is a wide pipe—heavy and costly to pump. FP16/BF16 are narrower pipes—cheaper to push water through. Autocast routes most water through narrow pipes for efficiency, but keeps critical valves wide to avoid leaks (numerical issues). Loss scaling is like increasing water pressure so the flow isn’t lost in tiny channels.
Worked examples
Example 1: Fine-tuning a Transformer with autocast + loss scaling
- Enable autocast in the forward pass.
- Scale the loss before backward.
- Unscale, step the optimizer, update the scaler.
# PyTorch-style training step (model, loader, optimizer, loss_fn defined elsewhere)
import torch
from torch.cuda.amp import autocast, GradScaler

model.train()
scaler = GradScaler()
for batch in loader:
    optimizer.zero_grad()
    with autocast():
        outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"])
        loss = loss_fn(outputs, batch["labels"])  # computed inside the autocast region
    scaler.scale(loss).backward()  # scale the loss so small gradients survive FP16
    scaler.step(optimizer)         # unscales gradients; skips the step if inf/NaN is found
    scaler.update()                # adjusts the scale factor for the next iteration
Why this works
Autocast chooses FP16/BF16 where safe, FP32 where needed. The scaler prevents tiny gradients from becoming zero in FP16.
Example 2: Faster inference using autocast (no scaler)
- Switch model to eval mode.
- Wrap only the forward pass with autocast.
model.eval()
with torch.no_grad():
    with autocast():
        logits = model(input_ids, attention_mask=mask)
    probs = torch.softmax(logits, dim=-1)  # post-processing outside the autocast block
Outcome
Lower latency and memory use than full FP32, with similar predictions in most cases.
Example 3: Custom op that must stay in FP32
Some numerically sensitive steps (e.g., reductions, certain normalizations, some custom kernels) should run in FP32. You can ensure this by explicitly casting to float32 for that block.
with autocast():
    x = model.backbone(hidden_states)
    # Ensure FP32 for a numerically sensitive calculation
    x_fp32 = x.float()
    stable_val = torch.logsumexp(x_fp32, dim=-1)
Tip
Autocast policies already keep many sensitive ops in FP32. Only override if you wrote a custom operation and observed instability.
Choosing FP16 vs BF16
- Use BF16 when your hardware supports it and you notice FP16 instability; its wider exponent range typically makes training more robust.
- Use FP16 when BF16 is unavailable; keep loss scaling enabled.
- If both are unstable for your setup, fall back to FP32 for specific layers or the whole model.
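A sketch of how this choice typically looks in PyTorch, assuming a CUDA device (model, input_ids, and mask as in Example 2; torch.cuda.is_bf16_supported reports hardware support):
# Pick BF16 if the GPU supports it, otherwise FP16 with loss scaling
import torch
from torch.cuda.amp import GradScaler

use_bf16 = torch.cuda.is_bf16_supported()
amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
scaler = GradScaler(enabled=not use_bf16)   # BF16 usually does not need loss scaling

with torch.autocast(device_type="cuda", dtype=amp_dtype):
    outputs = model(input_ids, attention_mask=mask)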
Practical settings and heuristics
- Batch size: Mixed precision often lets you increase batch size by 1.5–2x. Start modestly and grow it until you approach GPU memory limits.
- Optimizer: Adam/AdamW are commonly used; optimizer states typically stay in FP32 for stability.
- Loss scaler: Start with dynamic scaling (the default). Occasional skipped steps are normal; if overflows are frequent, lower the learning rate or switch to BF16.
- Monitoring: Watch for NaNs/inf in loss or gradients. If they appear, temporarily run a few steps in FP32 to isolate the cause.
- Inference: No scaler needed—just autocast.
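For the monitoring point above, a minimal check you can drop into the training loop from Example 1 (loss and scaler as defined there):
# Detect non-finite losses early and watch the scaler's behavior
if not torch.isfinite(loss):
    print("non-finite loss detected:", loss.item())
print("current loss scale:", scaler.get_scale())   # a steadily shrinking scale hints at instability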
Quick debugging checklist
- Did you put the forward pass inside autocast?
- Are you scaling the loss and calling scaler.step + scaler.update during training?
- Did you avoid using the scaler in eval/inference?
- Are custom ops stable? Consider forcing FP32 locally.
Who this is for
NLP engineers and ML practitioners who train or serve Transformer-based models and want faster, more memory-efficient runs.
Prerequisites
- Comfort with model training loops (e.g., PyTorch-style).
- Basic understanding of floating-point numbers and numerical stability.
- Familiarity with your GPU/accelerator capabilities.
Learning path
- Understand FP32 vs FP16/BF16 and why loss scaling is needed.
- Wrap forward passes with autocast; add GradScaler to the training step.
- Validate stability with a short run; compare throughput and memory.
- Scale batch size or sequence length if memory allows.
- Harden corner cases: custom ops, evaluation pipelines, and export.
Common mistakes and self-check
- Using scaler during inference. Self-check: Ensure no scaler usage when model.eval() and no_grad() are active.
- Forgetting to call scaler.update(). Self-check: Confirm the update is called once per step after scaler.step().
- Putting the optimizer step inside autocast. Self-check: Optimizer math should be outside autocast and receive unscaled grads.
- Assuming perfect accuracy parity every time. Self-check: Compare validation metrics across FP32 vs mixed precision; small differences can be normal.
- Ignoring overflow warnings. Self-check: If frequent, try BF16 (if supported) or reduce LR; confirm the scaler reduces the scale automatically.
Practical projects
- Speed-up baseline: Convert a small BERT fine-tune script from FP32 to mixed precision. Measure tokens/sec and memory usage before/after (see the measurement sketch after this list).
- Batch-size scaling: Increment batch size until near OOM. Record throughput, stability, and validation accuracy.
- Custom layer hardening: Add a numerically sensitive custom op and ensure stability by forcing FP32 locally if needed.
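For the speed-up baseline, a rough measurement sketch (CUDA assumed; train_step is a hypothetical helper wrapping the training step from Example 1, and token counts include padding):
# Rough throughput and peak-memory measurement around a training run
import time
import torch

torch.cuda.reset_peak_memory_stats()
start = time.perf_counter()
total_tokens = 0
for batch in loader:
    total_tokens += batch["input_ids"].numel()   # padded token count, good enough for comparison
    train_step(batch)                            # your FP32 or mixed-precision step
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f"tokens/sec: {total_tokens / elapsed:,.0f}")
print(f"peak memory: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")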
Exercises you can do now
Exercise 1 — Convert a training step to mixed precision
Take a standard training loop and integrate autocast + GradScaler correctly.
- Forward pass inside autocast.
- Scale the loss, backward, step, and update the scaler.
- Ensure optimizer.zero_grad() is called appropriately.
Starter template
# Fill in the missing pieces (marked ???)
model.train()
scaler = GradScaler()
for batch in loader:
    optimizer.zero_grad()
    with autocast():
        outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"])
        loss = loss_fn(outputs, batch["labels"])
    scaled_loss = ???      # scale the loss with the scaler
    scaled_loss.backward()
    ???                    # step the optimizer through the scaler
    ???                    # let the scaler adjust its scale factor
Exercise 2 — Plan memory and batch size
Suppose your FP32 run uses ~14 GB for batch size 16. Estimate a reasonable new batch size with mixed precision.
- Rule of thumb: Mixed precision can free ~30–50% memory depending on model and optimizer. Start with 1.5x and adjust.
- Show your estimate and a safety margin you would try first.
Estimation helper
If 14 GB at batch 16, try batch ~24 first (1.5x), then test ~28–32 if memory allows. Always verify by running a few steps.
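The same arithmetic as a tiny sketch (illustrative only; the 35% saving is an assumption you should verify on your own model):
# Back-of-the-envelope batch-size estimate
fp32_mem_gb, fp32_batch = 14.0, 16
per_sample_gb = fp32_mem_gb / fp32_batch          # ~0.88 GB per sample in FP32
assumed_saving = 0.35                             # assume ~35% of memory freed by mixed precision
est_batch = int(fp32_batch / (1 - assumed_saving))
print(f"~{per_sample_gb:.2f} GB/sample in FP32; try batch {est_batch} first, then probe 28-32")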
Completion checklist
- I wrapped the forward pass with autocast.
- I used GradScaler for training, not inference.
- I validated loss decreases without NaNs.
- I measured throughput and memory before/after.
Mini challenge
Train a small text classifier twice: once in FP32 and once in mixed precision. Keep all hyperparameters identical. Report:
- Tokens/sec (or samples/sec)
- Peak GPU memory
- Validation accuracy
What a good answer includes
- At least a 1.2–1.8x throughput improvement on many setups.
- Memory reduction enabling larger batch size or longer sequence length.
- Comparable validation accuracy; explain any deviation and mitigations (e.g., BF16 or smaller LR).
Next steps
- Try BF16 if your hardware supports it; compare stability vs FP16.
- Apply mixed precision to your inference endpoints and measure latency and cost.
- Combine with gradient accumulation and activation (gradient) checkpointing to fit even larger context windows.
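Gradient accumulation composes naturally with the scaler; a sketch reusing the names from Example 1, assuming four micro-batches per optimizer update:
# Mixed precision + gradient accumulation (4 micro-batches per update)
accum_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(loader):
    with autocast():
        outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"])
        loss = loss_fn(outputs, batch["labels"]) / accum_steps   # average over micro-batches
    scaler.scale(loss).backward()
    if (i + 1) % accum_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()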