Distributed Training Basics

Learn Distributed Training Basics for free with explanations, exercises, and a quick test (for MLOps Engineers).

Published: January 4, 2026 | Updated: January 4, 2026

Who this is for

This lesson is for MLOps Engineers and ML practitioners who need to train models faster by using multiple GPUs or machines, and to integrate training runs into reliable batch pipelines.

Prerequisites

  • Comfortable with Python and a deep learning framework (PyTorch or TensorFlow).
  • Basic understanding of SGD, batches, epochs, loss, and evaluation metrics.
  • You’ve run single-GPU training and can save/load checkpoints.

Why this matters

Real-world models and datasets are big. As an MLOps Engineer, you’ll be asked to:

  • Cut training time from days to hours to hit release deadlines.
  • Train models on multiple GPUs/nodes without sacrificing accuracy.
  • Make jobs resilient to preemptions and integrate them into nightly or weekly batch pipelines.
  • Scale hyperparameter searches efficiently across compute budgets.
Typical on-the-job tasks
  • Move a team’s training job from 1 GPU to 8 GPUs with the same results.
  • Choose between data parallel and model parallel for a large model.
  • Set up synchronous training with all-reduce and tune batch size, learning rate, and gradient accumulation.
  • Ensure reproducibility with correct seeding and dataset sharding.
  • Add checkpointing that survives node restarts or spot preemptions.

Concept explained simply

Distributed training means splitting the work of training across multiple workers (usually GPUs) to finish faster or fit bigger models.

  • Data parallelism: Each worker has a full copy of the model and sees a different shard of the batch. Workers compute gradients locally, then synchronize (average) them. Think: many identical chefs making parts of the same order, then agreeing on the final seasoning.
  • Model parallelism: The model itself is split across workers. Useful when the model doesn’t fit in one device’s memory. Think: one chef handles dough, another handles sauce—parts of the same pizza live in different places.
  • Synchronization patterns:
    • All-reduce (peer-to-peer): Workers share and average gradients directly. Fast on GPUs with high-speed interconnects (see the gradient-averaging sketch after this list).
    • Parameter server: Workers send gradients to servers that update weights. Flexible but can bottleneck at servers.
  • Synchronous vs asynchronous:
    • Synchronous: Everyone waits at a barrier to average gradients—more stable, deterministic.
    • Asynchronous: No global barrier—faster in some cases but can be noisier and harder to reproduce.
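
Below is a minimal sketch of what all-reduce-based data parallelism does under the hood (DDP automates this). It assumes the process group is already initialized and that model, optimizer, and loss exist on each rank.

# Sketch: manual gradient averaging with all-reduce (DDP does this for you)
import torch.distributed as dist

loss.backward()  # each rank computes gradients on its own data shard
world_size = dist.get_world_size()
for param in model.parameters():
    if param.grad is not None:
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)  # sum gradients across ranks
        param.grad /= world_size                           # turn the sum into an average
optimizer.step()  # every rank applies the same averaged update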

Mental model

Picture a round-table meeting after each training step: each participant (GPU) speaks (their gradients), everyone averages what they heard (all-reduce), and they all update their notes (weights) the same way. Bigger tables (more GPUs) mean more voices—great for speed—but coordinating takes time. Your job is to set the table so coordination cost doesn’t eat the speed-up.

Core building blocks

  • World size: total number of workers (e.g., 8 GPUs).
  • Rank: unique ID of each worker (0 to world_size-1).
  • Backend: communication library (e.g., NCCL for GPU, Gloo for CPU).
  • Collectives: ops like all-reduce, broadcast, gather.
  • Batch size: per-GPU batch × world size = global batch. Often scale learning rate roughly linearly with global batch (then tune).
  • Gradient accumulation: virtually increases batch size without more GPUs by accumulating gradients over N steps before an optimizer step (see the sketch after this list).
  • Sharding the dataset: each worker sees a unique subset per epoch (use distributed samplers).
  • Seeding and determinism: set seeds and control randomness for reproducibility.
  • Checkpointing: save model, optimizer, scheduler, and progress (epoch/step, RNG state).
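
Gradient accumulation is a small change to the training loop. A minimal sketch, assuming model, criterion, optimizer, and loader already exist:

# Sketch: gradient accumulation over accum_steps micro-batches
accum_steps = 4  # effective batch = per-GPU batch x accum_steps x world size
optimizer.zero_grad(set_to_none=True)
for step, (inputs, targets) in enumerate(loader):
    loss = criterion(model(inputs), targets)
    (loss / accum_steps).backward()  # scale so the accumulated gradient is an average
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
# With DDP, you can wrap the non-final micro-batches in model.no_sync() to skip redundant all-reduces.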

Worked examples

Example 1 — Scale single-GPU training to 8 GPUs with DDP (PyTorch)
# Key ideas: init process group, set device per rank, use DistributedSampler, wrap model
import torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank, world_size):
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

    model = MyModel().cuda(rank)
    model = DDP(model, device_ids=[rank], output_device=rank)

    dataset = MyDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4, pin_memory=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4 * world_size)  # linear scaling starting point
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # reshuffle per epoch across workers
        for batch in loader:
            inputs, targets = batch[0].cuda(rank, non_blocking=True), batch[1].cuda(rank, non_blocking=True)
            with torch.cuda.amp.autocast():
                loss = criterion(model(inputs), targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

    dist.destroy_process_group()

Notes: NCCL is the preferred backend for multi-GPU training. Call sampler.set_epoch(epoch) at the start of each epoch so shuffling differs across epochs and workers. Start with linear LR scaling, then tune.

Example 2 — All-reduce vs Parameter Server
  • All-reduce shines when: GPU-to-GPU links are fast (NVLink/InfiniBand), model size is moderate, and you want simple, fully synchronous updates.
  • Parameter server helps when: you need elasticity (workers come/go), have heterogeneous hardware, or want partial asynchrony. Beware server bottlenecks.
  • In most GPU clusters training vision/NLP models, all-reduce (e.g., PyTorch DDP) is the practical default.
Example 3 — Batch size and learning rate

You have per-GPU batch 32 on 1 GPU, LR=3e-4. Moving to 8 GPUs with the same per-GPU batch gives global batch 256 (32×8). A common rule: scale LR linearly to 8×3e-4 = 2.4e-3, then run a short warmup and tune. If loss becomes unstable, back off LR or add gradient accumulation to adjust the effective batch.
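
One way to implement the warmup is a simple LR lambda; the warmup length and model below are illustrative placeholders, not recommendations.

# Sketch: linear warmup to the scaled LR (values are illustrative)
import torch

scaled_lr = 3e-4 * 8  # 2.4e-3, linear scaling starting point
warmup_steps = 500
optimizer = torch.optim.AdamW(model.parameters(), lr=scaled_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps)
)
# call scheduler.step() once per optimizer step; after warmup the LR stays at scaled_lr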

Example 4 — Checkpointing for fault tolerance
  • Save every N steps: model weights, optimizer, LR scheduler, epoch/step, and RNG states.
  • On restart: load checkpoint and resume sampler state to avoid repeating data (see the resume sketch below).
  • For multi-worker: only rank 0 writes checkpoints to avoid conflicts; others skip writing.
# Pseudocode for safe checkpointing
if rank == 0 and (global_step % save_every == 0):
    state = {
        "model": model.module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch,
        "step": global_step,
        "rng": {
            "python": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.random.get_rng_state(),
            "cuda": torch.cuda.get_rng_state_all(),
        }
    }
    torch.save(state, ckpt_path)
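
The matching resume path might look like the sketch below; ckpt_path and the restored objects mirror the save pseudocode above.

# Pseudocode for resuming from the checkpoint above
import os, random
import numpy as np
import torch

if os.path.exists(ckpt_path):
    state = torch.load(ckpt_path, map_location="cpu")
    model.module.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    scheduler.load_state_dict(state["scheduler"])
    start_epoch, global_step = state["epoch"], state["step"]
    random.setstate(state["rng"]["python"])
    np.random.set_state(state["rng"]["numpy"])
    torch.random.set_rng_state(state["rng"]["torch"])
    torch.cuda.set_rng_state_all(state["rng"]["cuda"])
# Every rank loads the same file; call sampler.set_epoch(start_epoch) so data isn't repeated.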

Step-by-step: Plan a distributed run

  1. Decide the pattern: If the model fits on one GPU, start with data parallel (DDP). If not, consider model/tensor sharding.
  2. Pick synchronization: Prefer synchronous all-reduce for stability; try mixed precision to recover throughput.
  3. Set batch and LR: Choose per-GPU batch for good GPU utilization; compute global batch; start with linear LR scaling and warmup.
  4. Shard data: Use a distributed sampler so each worker sees unique examples each epoch.
  5. Seed and log: Set seeds, log world size/rank, and record all hyperparameters.
  6. Add checkpoints: Save regularly from rank 0; verify resume works before long runs.
  7. Dry run: Run a short, small-epoch job to confirm convergence and throughput (a minimal launch sketch follows).
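
As a concrete starting point for the dry run, here is a minimal launch sketch using torchrun; the script name and the body of main() are placeholders.

# Sketch: launching and bootstrapping a DDP run
# Launch on one node with 8 GPUs:  torchrun --nproc_per_node=8 train.py
import os
import torch
import torch.distributed as dist

def main():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE in the environment
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    rank, world_size = dist.get_rank(), dist.get_world_size()
    print(f"rank {rank}/{world_size} using local GPU {local_rank}")
    # ... build model, sampler, loader, and train as in Example 1 ...
    dist.destroy_process_group()

if __name__ == "__main__":
    main()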

Common mistakes and self-check

  • Forgetting DistributedSampler: leads to overlapping data and skewed metrics. Self-check: count unique sample IDs per worker per epoch (see the sampler check after this list).
  • Wrong LR after scaling: instability or divergence. Self-check: plot loss vs steps; compare to 1-GPU baseline.
  • No seed or inconsistent seeding: irreproducible results. Self-check: re-run short jobs; results should match closely.
  • Every rank saving checkpoints: corrupted files. Self-check: ensure only rank 0 writes.
  • Communication bottlenecks: small batches and too many syncs. Self-check: profile step time, check all-reduce percentage; try larger per-GPU batch/mixed precision.
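
For the first self-check (unique samples per worker), a quick offline sanity check might look like this; the stand-in dataset uses indices as sample IDs.

# Sketch: verify that DistributedSampler shards don't overlap (no GPUs needed)
from torch.utils.data.distributed import DistributedSampler

dataset = range(1000)  # stand-in dataset; indices act as sample IDs
world_size = 8
shards = []
for rank in range(world_size):
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0)
    sampler.set_epoch(0)
    shards.append(set(iter(sampler)))

all_ids = set().union(*shards)
duplicates = sum(len(s) for s in shards) - len(all_ids)
print(f"unique samples: {len(all_ids)}, duplicated assignments: {duplicates}")  # expect 1000 and 0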

Exercises

  • Exercise 1: Choose strategy and scale hyperparameters.
  • Exercise 2: Fix a broken PyTorch DDP snippet.
Exercise 1 — Instructions

Scenario: You have 2 nodes, each with 4 GPUs (total 8 GPUs). The model fits in one GPU. On 1 GPU you trained with per-GPU batch=32 and LR=3e-4. Move to 8 GPUs using synchronous training.

  • Pick data or model parallel (and why).
  • Choose all-reduce or parameter server (and why).
  • Compute new global batch, and propose a new LR using linear scaling.
  • If you must keep per-GPU batch at 16 due to memory, how can you preserve the effective batch?
Exercise 2 — Instructions

Fix the code so it works correctly for 4 GPUs on a single node.

def train(rank):
    dist.init_process_group("nccl")
    model = Net()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    dataset = MyDataset()
    loader = DataLoader(dataset, batch_size=64, shuffle=True)
    for x, y in loader:
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Exercise checklist

  • Your plan states the chosen parallelism and sync method.
  • You computed global batch and proposed an LR with a short rationale.
  • You used gradient accumulation correctly if needed.
  • Your fixed code sets device per rank and uses a DistributedSampler.

Practical projects

  • Speed-up project: Take an existing 1-GPU training script. Add DDP, mixed precision, and checkpoints. Show a 4–8× speed-up with stable accuracy.
  • Resilience project: Simulate a mid-epoch failure and prove your job resumes correctly without reprocessing the same data.
  • Throughput tuning: Profile step time. Experiment with per-GPU batch, gradient accumulation, and number of data loader workers to maximize samples/sec.
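
For the throughput-tuning project, a rough samples/sec measurement can be as simple as the sketch below; it assumes loader and a train_step function exist and uses wall-clock timing per process.

# Sketch: rough samples/sec measurement (assumes loader and train_step exist)
import time
import torch

warmup, measured, samples = 10, 50, 0
for step, batch in enumerate(loader):
    if step == warmup:
        torch.cuda.synchronize()  # exclude warmup time from the measurement
        start = time.perf_counter()
    train_step(batch)
    if step >= warmup:
        samples += batch[0].size(0)
    if step == warmup + measured:
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
        elapsed = time.perf_counter() - start
        print(f"{samples / elapsed:.1f} samples/sec on this process")
        break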

Next steps

  • Go deeper into tensor/model parallelism for very large models.
  • Automate distributed runs in your batch pipeline scheduler with retries and metrics export.
  • Explore distributed hyperparameter tuning (population-based or Bayesian) with resource-aware scheduling.

Mini challenge

Your 8-GPU training shows good throughput but validation accuracy is worse than the 1-GPU baseline. Propose three changes you’ll try (in order) to recover accuracy, and explain what each change targets.

Distributed Training Basics — Quick Test

Test your knowledge with 6 questions. Pass with 70% or higher.
