Who this is for
This lesson is for MLOps Engineers and ML practitioners who need to train models faster by using multiple GPUs or machines, and to integrate training runs into reliable batch pipelines.
Prerequisites
- Comfortable with Python and a deep learning framework (PyTorch or TensorFlow).
- Basic understanding of SGD, batches, epochs, loss, and evaluation metrics.
- You’ve run single-GPU training and can save/load checkpoints.
Why this matters
Real-world models and datasets are big. As an MLOps Engineer, you’ll be asked to:
- Cut training time from days to hours to hit release deadlines.
- Train models on multiple GPUs/nodes without sacrificing accuracy.
- Make jobs resilient to preemptions and integrate them into nightly or weekly batch pipelines.
- Scale hyperparameter searches efficiently across compute budgets.
Typical on-the-job tasks
- Move a team’s training job from 1 GPU to 8 GPUs with the same results.
- Choose between data parallel and model parallel for a large model.
- Set up synchronous training with all-reduce and tune batch size, learning rate, and gradient accumulation.
- Ensure reproducibility with correct seeding and dataset sharding.
- Add checkpointing that survives node restarts or spot preemptions.
Concept explained simply
Distributed training means splitting the work of training across multiple workers (usually GPUs) to finish faster or fit bigger models.
- Data parallelism: Each worker has a full copy of the model and sees a different shard of the batch. Workers compute gradients locally, then synchronize (average) them. Think: many identical chefs making parts of the same order, then agreeing on the final seasoning.
- Model parallelism: The model itself is split across workers. Useful when the model doesn’t fit in one device’s memory. Think: one chef handles dough, another handles sauce—parts of the same pizza live in different places.
- Synchronization patterns:
- All-reduce (peer-to-peer): Workers share and average gradients directly. Fast on GPUs with high-speed interconnects.
- Parameter server: Workers send gradients to servers that update weights. Flexible but can bottleneck at servers.
- Synchronous vs asynchronous:
- Synchronous: Everyone waits at a barrier to average gradients—more stable, deterministic.
- Asynchronous: No global barrier—faster in some cases but can be noisier and harder to reproduce.
Mental model
Picture a round-table meeting after each training step: each participant (GPU) reports its gradients, everyone averages what they heard (all-reduce), and everyone updates their notes (weights) the same way. Bigger tables (more GPUs) mean more voices—great for speed—but coordinating takes time. Your job is to set the table so coordination cost doesn’t eat the speed-up.
Core building blocks
- World size: total number of workers (e.g., 8 GPUs).
- Rank: unique ID of each worker (0 to world_size-1).
- Backend: communication library (e.g., NCCL for GPU, Gloo for CPU).
- Collectives: ops like all-reduce, broadcast, gather.
- Batch size: per-GPU batch × world size = global batch. A common starting point is to scale the learning rate roughly linearly with the global batch, then tune.
- Gradient accumulation: virtually increases batch size without more GPUs by accumulating gradients over N steps before an optimizer step.
- Sharding the dataset: each worker sees a unique subset per epoch (use distributed samplers).
- Seeding and determinism: set seeds and control randomness for reproducibility (a minimal sketch follows after this list).
- Checkpointing: save model, optimizer, scheduler, and progress (epoch/step, RNG state).
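Seeding in a distributed job is easy to get subtly wrong, so here is a minimal sketch of one common convention (the helper name and base seed are illustrative, not a fixed recipe):
import random
import numpy as np
import torch

def set_seed(rank, base_seed=42):
    # Offsetting by rank decorrelates augmentation and dropout across workers;
    # DDP still broadcasts rank 0's initial weights when the model is wrapped.
    seed = base_seed + rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds the CUDA RNGs
    # For stricter determinism (at some speed cost), consider:
    # torch.use_deterministic_algorithms(True)
    # torch.backends.cudnn.benchmark = False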
Worked examples
Example 1 — Scale single-GPU training to 8 GPUs with DDP (PyTorch)
# Key ideas: init the process group, pin one device per rank, shard data with DistributedSampler, wrap the model in DDP
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank, world_size, num_epochs=10):
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)
    model = MyModel().cuda(rank)  # MyModel and MyDataset stand in for your own model and dataset
    model = DDP(model, device_ids=[rank], output_device=rank)
    dataset = MyDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4, pin_memory=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4 * world_size)  # linear scaling starting point
    criterion = torch.nn.CrossEntropyLoss()  # example loss; substitute your own
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # reshuffle per epoch across workers
        for inputs, targets in loader:
            inputs = inputs.cuda(rank, non_blocking=True)
            targets = targets.cuda(rank, non_blocking=True)
            with torch.cuda.amp.autocast():
                loss = criterion(model(inputs), targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
    dist.destroy_process_group()
Notes: NCCL is the fastest backend for NVIDIA GPUs. Call sampler.set_epoch(epoch) every epoch so shuffling differs across epochs. Treat linear LR scaling as a starting point and tune from there.
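Example 1 defines train(rank, world_size) but does not launch it. One way to run it on a single node is to spawn one process per GPU; a minimal sketch, assuming the function above lives in the same file:
import os
import torch
import torch.multiprocessing as mp

def main():
    world_size = torch.cuda.device_count()  # e.g., 8 on an 8-GPU node
    os.environ.setdefault("MASTER_ADDR", "localhost")  # rendezvous address used by init_method="env://"
    os.environ.setdefault("MASTER_PORT", "29500")  # any free port
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)  # calls train(rank, world_size) in each process

if __name__ == "__main__":
    main()
torchrun --nproc_per_node=8 your_script.py is the other common launcher; with torchrun the rank, local rank, and world size arrive as environment variables instead of function arguments.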
Example 2 — All-reduce vs Parameter Server
- All-reduce shines when: GPU-to-GPU links are fast (NVLink/InfiniBand), model size is moderate, and you want simple, fully synchronous updates.
- Parameter server helps when: you need elasticity (workers come/go), have heterogeneous hardware, or want partial asynchrony. Beware server bottlenecks.
- In most GPU clusters training vision/NLP models, all-reduce (e.g., PyTorch DDP) is the practical default.
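To make the all-reduce pattern concrete, the illustrative sketch below averages gradients by hand with the collective API. DDP performs the equivalent reduction for you (overlapped with backward), so this is for understanding rather than production use:
import torch.distributed as dist

def average_gradients(model):
    # Sum each gradient tensor across all ranks, then divide to get the mean.
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= world_size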
Example 3 — Batch size and learning rate
You have per-GPU batch 32 on 1 GPU, LR=3e-4. Moving to 8 GPUs with the same per-GPU batch gives global batch 256 (32×8). A common rule: scale LR linearly to 8×3e-4=2.4e-3, then run a short warmup and tune. If loss becomes unstable, back off LR or add gradient accumulation to adjust the effective batch.
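If memory forces a smaller per-GPU batch, gradient accumulation preserves the effective batch. A minimal sketch, reusing the loader, model, criterion, optimizer, scaler, and rank names from Example 1:
accumulation_steps = 2  # e.g., per-GPU batch 16 x 8 GPUs x 2 steps = effective batch 256
optimizer.zero_grad(set_to_none=True)
for step, (inputs, targets) in enumerate(loader):
    inputs = inputs.cuda(rank, non_blocking=True)
    targets = targets.cuda(rank, non_blocking=True)
    with torch.cuda.amp.autocast():
        # Divide the loss so the accumulated gradient matches one large batch.
        loss = criterion(model(inputs), targets) / accumulation_steps
    scaler.scale(loss).backward()
    if (step + 1) % accumulation_steps == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
With DDP you can additionally wrap the non-stepping iterations in model.no_sync() so gradients are only all-reduced on the step that actually updates the weights.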
Example 4 — Checkpointing for fault tolerance
- Save every N steps: model weights, optimizer, LR scheduler, epoch/step, and RNG states.
- On restart: load checkpoint and resume sampler state to avoid repeating data.
- For multi-worker: only rank 0 writes checkpoints to avoid conflicts; others skip writing.
# Pseudocode for safe checkpointing (assumes random, numpy as np, and torch are already imported)
if rank == 0 and (global_step % save_every == 0):
    state = {
        "model": model.module.state_dict(),  # .module because the model is wrapped in DDP
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
        "step": global_step,
        "rng": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.random.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all(),
        },
    }
    torch.save(state, ckpt_path)
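The matching resume path might look like the sketch below (ckpt_path and the surrounding variables carry the same illustrative names as above). DistributedSampler has no built-in mid-epoch resume, so a common compromise is to restart from the saved epoch boundary:
# Every rank loads the same checkpoint onto its own GPU.
state = torch.load(ckpt_path, map_location=f"cuda:{rank}")
model.module.load_state_dict(state["model"])  # load into .module because the model is DDP-wrapped
optimizer.load_state_dict(state["optimizer"])
scheduler.load_state_dict(state["scheduler"])
start_epoch, global_step = state["epoch"], state["step"]
random.setstate(state["rng"]["python"])
np.random.set_state(state["rng"]["numpy"])
torch.random.set_rng_state(state["rng"]["torch"])
torch.cuda.set_rng_state_all(state["rng"]["cuda"])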
Step-by-step: Plan a distributed run
- Decide the pattern: If the model fits on one GPU, start with data parallel (DDP). If not, consider model/tensor sharding.
- Pick synchronization: Prefer synchronous all-reduce for stability; try mixed precision to recover throughput.
- Set batch and LR: Choose per-GPU batch for good GPU utilization; compute global batch; start with linear LR scaling and warmup.
- Shard data: Use a distributed sampler so each worker sees unique examples each epoch.
- Seed and log: Set seeds, log world size/rank, and record all hyperparameters.
- Add checkpoints: Save regularly from rank 0; verify resume works before long runs.
- Dry run: Run a short, small-epoch job to confirm convergence and throughput.
Common mistakes and self-check
- Forgetting DistributedSampler: leads to overlapping data and skewed metrics. Self-check: count unique sample IDs per worker per epoch (a sketch follows after this list).
- Wrong LR after scaling: instability or divergence. Self-check: plot loss vs steps; compare to 1-GPU baseline.
- No seed or inconsistent seeding: irreproducible results. Self-check: re-run short jobs; results should match closely.
- Every rank saving checkpoints: corrupted files. Self-check: ensure only rank 0 writes.
- Communication bottlenecks: small batches and too many syncs. Self-check: profile step time, check all-reduce percentage; try larger per-GPU batch/mixed precision.
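For the DistributedSampler self-check above, one quick way to count unique samples per worker is to gather every rank’s indices and look for overlap; a rough sketch:
# Rough self-check: collect this rank's sample indices and compare across ranks.
local_indices = list(sampler)  # DistributedSampler yields this rank's dataset indices for the current epoch
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, local_indices)  # with NCCL, make sure torch.cuda.set_device(rank) was called first
if rank == 0:
    flat = [i for shard in gathered for i in shard]
    print("duplicates across workers:", len(flat) - len(set(flat)))
Expect a handful of duplicates at the tail when the dataset size is not divisible by the world size, because DistributedSampler pads the last shard by repeating indices.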
Exercises
- Exercise 1: Choose strategy and scale hyperparameters.
- Exercise 2: Fix a broken PyTorch DDP snippet.
Exercise 1 — Instructions
Scenario: You have 2 nodes, each with 4 GPUs (total 8 GPUs). The model fits in one GPU. On 1 GPU you trained with per-GPU batch=32 and LR=3e-4. Move to 8 GPUs using synchronous training.
- Pick data or model parallel (and why).
- Choose all-reduce or parameter server (and why).
- Compute new global batch, and propose a new LR using linear scaling.
- If you must keep per-GPU batch at 16 due to memory, how can you preserve the effective batch?
Exercise 2 — Instructions
Fix the code so it works correctly for 4 GPUs on a single node.
def train(rank):
    dist.init_process_group("nccl")
    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    dataset = MyDataset()
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    for x, y in loader:
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
Exercise checklist
- Your plan states the chosen parallelism and sync method.
- You computed global batch and proposed an LR with a short rationale.
- You used gradient accumulation correctly if needed.
- Your fixed code sets device per rank and uses a DistributedSampler.
Practical projects
- Speed-up project: Take an existing 1-GPU training script. Add DDP, mixed precision, and checkpoints. Show a 4–8× speed-up with stable accuracy.
- Resilience project: Simulate a mid-epoch failure and prove your job resumes correctly without reprocessing the same data.
- Throughput tuning: Profile step time. Experiment with per-GPU batch, gradient accumulation, and number of data loader workers to maximize samples/sec.
Next steps
- Go deeper into tensor/model parallelism for very large models.
- Automate distributed runs in your batch pipeline scheduler with retries and metrics export.
- Explore distributed hyperparameter tuning (population-based or Bayesian) with resource-aware scheduling.
Mini challenge
Your 8-GPU training shows good throughput but validation accuracy is worse than the 1-GPU baseline. Propose three changes you’ll try (in order) to recover accuracy, and explain what each change targets.