Why this matters
Real products rarely need a single vision output. A driver-assistance camera needs detection, lane segmentation, and depth. An ecommerce system wants product classification, box detection, and quality checks. Multi-task learning (MTL) lets you train one model to solve several tasks at once, often with better generalization and lower latency than running many separate models.
- Ship faster: one shared encoder, multiple heads.
- Run cheaper: less compute and memory at inference.
- Learn more robust features: tasks help regularize each other.
Note: Everyone can take the Quick Test on this page. Only logged-in users will have their progress saved.
Who this is for
- Computer Vision Engineers moving from single-task models to production-ready multi-task systems.
- ML practitioners who want efficient deployment and better generalization.
- Students building multi-output models for projects and competitions.
Prerequisites
- Comfort with CNNs or ViT backbones and common vision heads (classification, detection, segmentation).
- Basic PyTorch or TensorFlow training loops and loss functions (CE, BCE, L1/L2, focal loss).
- Understanding of evaluation metrics (accuracy/F1, mAP, IoU).
Concept explained simply
Idea: Share one feature extractor (encoder/backbone). Add a small head per task. Train with a weighted sum of task losses, masking losses for tasks without labels.
Mental model: Think of a tree. The trunk (shared encoder) learns general visual features. Each branch (task head) specializes to a task (classification, detection, segmentation, depth). You control how much each branch influences the trunk via loss weights and sampling.
Deep dive: Hard vs. soft parameter sharing
- Hard sharing: Single shared encoder + task-specific heads. Most common and efficient.
- Soft sharing: Separate encoders that exchange information (e.g., cross-stitch units, adapters, FiLM). Useful when tasks conflict strongly.
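As a minimal sketch of soft sharing, a cross-stitch unit learns a small mixing matrix that blends features from two task-specific streams at one layer. The `CrossStitch` module below is an illustrative assumption (following the cross-stitch idea), not a specific library API:

```python
import torch
import torch.nn as nn

class CrossStitch(nn.Module):
    """Learned linear mixing of two task streams at one layer (cross-stitch idea)."""
    def __init__(self):
        super().__init__()
        # Initialize near identity so each stream starts mostly independent
        self.alpha = nn.Parameter(torch.tensor([[0.9, 0.1], [0.1, 0.9]]))

    def forward(self, xa, xb):
        # xa, xb: feature maps of identical shape from task A and task B encoders
        ya = self.alpha[0, 0] * xa + self.alpha[0, 1] * xb
        yb = self.alpha[1, 0] * xa + self.alpha[1, 1] * xb
        return ya, yb

xa = torch.randn(2, 64, 32, 32)
xb = torch.randn(2, 64, 32, 32)
ya, yb = CrossStitch()(xa, xb)
```

Training learns how much the streams share: the mixing weights drift toward identity when tasks conflict and toward uniform when sharing helps.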
Core building blocks
1) Shared encoder
- Choose a backbone (ResNet, EfficientNet, MobileNet, Swin/ViT) sized for your latency target.
- Optionally add FPN or multi-scale features if detection/segmentation are included.
2) Task heads
- Classification: Global pooling + MLP.
- Detection: Conv head on feature maps (anchor-based or anchor-free).
- Segmentation: Upsampling/decoder (e.g., simple bilinear + convs) to full/half-res masks.
- Others: Keypoints, depth, normals, OCR. All are just different heads.
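The encoder-plus-heads pattern can be sketched end to end with a toy PyTorch model. The module names, channel sizes, and the 5-channel detection output (4 box coordinates + objectness per cell) are illustrative assumptions, not a production design:

```python
import torch
import torch.nn as nn

class TinyMTL(nn.Module):
    """Toy hard-sharing model: one shared encoder, three task heads."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.encoder = nn.Sequential(  # shared trunk, overall stride 8
            nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
        )
        self.cls_head = nn.Linear(128, num_classes)  # global pool + FC
        self.det_head = nn.Conv2d(128, 5, 1)         # 4 box coords + objectness per cell
        self.seg_head = nn.Sequential(               # light decoder back to input resolution
            nn.Conv2d(128, 32, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=8, mode="bilinear", align_corners=False),
            nn.Conv2d(32, 1, 1),
        )

    def forward(self, x):
        f = self.encoder(x)
        cls = self.cls_head(f.mean(dim=(2, 3)))  # (B, num_classes)
        det = self.det_head(f)                   # (B, 5, H/8, W/8)
        seg = self.seg_head(f)                   # (B, 1, H, W) logits
        return cls, det, seg

cls, det, seg = TinyMTL()(torch.randn(2, 3, 64, 64))
```

The key design point: all heads read the same feature map, so head cost, not encoder cost, dominates the marginal price of adding a task.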
3) Loss design
- Total loss: L = Σ w_t * L_t, only over tasks with labels for that sample (mask others).
- Weighting options: fixed weights; dynamic schemes (uncertainty weighting, GradNorm, DWA, PCGrad/GradVac for gradient conflicts).
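Uncertainty weighting can be sketched as a small module holding one learnable log-variance per task; the total is Σ exp(-s_t)·L_t + s_t over the tasks that actually have labels in the batch. The class below is a simplified illustration of that scheme, not a library API:

```python
import torch
import torch.nn as nn

class UncertaintyWeighting(nn.Module):
    """Learnable log-variance s_t per task; total = sum(exp(-s_t)*L_t + s_t)."""
    def __init__(self, tasks):
        super().__init__()
        self.log_vars = nn.ParameterDict({t: nn.Parameter(torch.zeros(())) for t in tasks})

    def forward(self, losses):
        # losses: dict task name -> scalar loss; unlabeled tasks are simply absent (masked)
        total = torch.zeros(())
        for t, loss in losses.items():
            s = self.log_vars[t]
            total = total + torch.exp(-s) * loss + s
        return total

weigher = UncertaintyWeighting(["cls", "det", "seg"])
# det has no labels in this batch, so its term is omitted entirely
total = weigher({"cls": torch.tensor(0.8), "seg": torch.tensor(1.2)})
```

Because the s_t are optimized jointly with the model, noisy or hard tasks get down-weighted automatically while the +s_t term prevents weights from collapsing to zero.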
4) Batching and sampling
- Mixed batches: samples may have labels for some tasks and not others. Always mask missing labels.
- Sampling strategies: proportionate to task dataset sizes, or temperature-smoothed to avoid starving small tasks.
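Temperature smoothing can be written as p_t ∝ n_t^(1/T): T=1 recovers proportional sampling, larger T pushes toward uniform. A minimal sketch (function name and sizes are illustrative):

```python
import numpy as np

def task_sampling_probs(dataset_sizes, temperature=2.0):
    """p_t proportional to n_t**(1/T); T=1 proportional, larger T more uniform."""
    sizes = np.asarray(dataset_sizes, dtype=np.float64)
    smoothed = sizes ** (1.0 / temperature)
    return smoothed / smoothed.sum()

# Illustrative: one large detection set, two smaller task sets
probs = task_sampling_probs([100_000, 10_000, 1_000], temperature=2.0)
```

With T=2 the smallest task's share rises well above its proportional ~0.9%, so it is no longer starved during joint training.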
5) Training schedule
- Warm up with uniform or hand-balanced weights; gradually adjust with a dynamic method.
- Consider task-specific learning rates or head warmup if one task lags badly.
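Task-specific learning rates are straightforward with PyTorch optimizer parameter groups. The modules below are toy stand-ins for a shared encoder and two heads; the learning-rate values are illustrative assumptions:

```python
import torch
import torch.nn as nn

# Toy stand-ins for a shared encoder and two task heads
encoder = nn.Linear(16, 16)
cls_head = nn.Linear(16, 10)
seg_head = nn.Linear(16, 2)   # suppose segmentation is the lagging task

optimizer = torch.optim.AdamW([
    {"params": encoder.parameters(), "lr": 1e-4},   # shared trunk: conservative LR
    {"params": cls_head.parameters(), "lr": 1e-3},
    {"params": seg_head.parameters(), "lr": 3e-3},  # boost the lagging head
])
```

Keeping the trunk's learning rate low protects features the other tasks rely on while the lagging head catches up.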
6) Evaluation
- Report per-task metrics. Also track latency and memory.
- Prefer Pareto comparisons over a single aggregated score.
Tip: Handling conflicting tasks
- Try task-specific BatchNorm or separate normalization stats.
- Use gradient balancing (e.g., PCGrad) when one task improves while another regresses.
- Consider soft sharing or lightweight adapters for hard conflicts.
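Task-specific BatchNorm can be as simple as keeping one BN module per task and selecting by task name in the forward pass. The `TaskBatchNorm2d` class below is an illustrative sketch, not a library API:

```python
import torch
import torch.nn as nn

class TaskBatchNorm2d(nn.Module):
    """One BatchNorm2d per task; forward picks running stats by task name."""
    def __init__(self, num_features, tasks):
        super().__init__()
        self.bns = nn.ModuleDict({t: nn.BatchNorm2d(num_features) for t in tasks})

    def forward(self, x, task):
        return self.bns[task](x)

bn = TaskBatchNorm2d(8, ["det", "seg"])
y = bn(torch.randn(4, 8, 16, 16), task="seg")
```

Each task then normalizes with its own running statistics, which helps when task inputs come from visually different distributions (e.g., scans vs. photos).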
Worked examples
Example 1: Driver assistance
- Tasks: Vehicle detection (mAP), lane segmentation (mIoU), depth (AbsRel).
- Backbone: Mobile-friendly CNN with FPN (P3–P5).
- Heads: Detection head on P3–P5; Seg head on fused P3; Depth head on P3.
- Loss: 0.4*L_det + 0.4*L_seg + 0.2*L_depth. Mask if depth not labeled.
- Why: Detection and lanes share edges/structure; depth helps geometry.
Example 2: Ecommerce product photos
- Tasks: Category classification, box detection, binary background mask.
- Backbone: ResNet-50 (shared).
- Heads: Classification: GAP + FC to 1,000 classes; Detection: anchor-free head; Segmentation: light decoder + sigmoid mask.
- Training: Uncertainty weighting to auto-balance; oversample rare classes.
Example 3: Document understanding
- Tasks: Text region detection, layout segmentation, page classification.
- Backbone: Swin-Tiny with FPN.
- Heads: Detection head for boxes; Seg head for regions (title/body/table); Classification head for page type.
- Trick: Task-specific BatchNorm because scanned vs. photo documents have different statistics.
Simple multi-task training loop (pseudo-code)
# For each batch with mixed annotations
optimizer.zero_grad()
features = encoder(images)
loss = 0.0
# Masking: add a task's term only when this batch carries labels for it
if has_cls: loss += w_cls * CE(cls_head(pool(features)), y_cls)
if has_det: loss += w_det * det_loss(det_head(fpn(features)), y_det)
if has_seg: loss += w_seg * dice_bce(seg_head(features), y_seg)
# Optionally adapt w_cls/w_det/w_seg with uncertainty weighting or GradNorm
loss.backward()
optimizer.step()
Common mistakes and how to self-check
- Forgetting loss masking: Training explodes or drifts. Self-check: Verify each task loss only uses samples with labels.
- Unbalanced gradients: One task dominates. Self-check: Log per-task gradient norms on the shared encoder.
- Architectural mismatch: Heavy decoder for tiny images. Self-check: Profile latency and memory per head.
- Mixed augmentations misaligned: Applying CutMix/MixUp without adjusting boxes or masks corrupts detection and segmentation targets. Self-check: Ensure bbox/mask transforms stay consistent with image transforms.
- One metric stagnates: Could be under-weighted loss. Self-check: Temporarily up-weight lagging task and see if it moves.
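The gradient-norm self-check above can be sketched with toy modules: run each task's loss separately and measure the gradient it alone induces on the shared trunk. The modules and losses here are illustrative stand-ins:

```python
import torch
import torch.nn as nn

# Toy stand-ins for a shared encoder and two task heads
encoder = nn.Linear(16, 32)
heads = {"cls": nn.Linear(32, 10), "seg": nn.Linear(32, 1)}
x = torch.randn(4, 16)

grad_norms = {}
for task, head in heads.items():
    encoder.zero_grad()
    loss = head(encoder(x)).pow(2).mean()  # stand-in for the real task loss
    loss.backward()
    # Norm of the gradient this task alone puts on the shared trunk
    grad_norms[task] = torch.cat(
        [p.grad.flatten() for p in encoder.parameters()]
    ).norm().item()
# If one task's norm dominates by orders of magnitude, rebalance loss weights
```

Logging these norms every few hundred steps makes dominance visible long before the per-task metrics diverge.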
Practical projects
- Street-scene tri-task: Train one model for detection, semantic segmentation, and depth on a small subset of urban scenes. Compare to three separate models for latency and accuracy.
- Retail shelf analysis: One model for product detection, shelf segmentation, and out-of-stock classification. Evaluate per-task metrics and total inference time.
- Document trinity: Text block detection, region segmentation, and document type classification on a curated set of scanned pages.
Learning path
- Before this: Single-task detection/segmentation heads, loss functions, metrics.
- This lesson: How to combine tasks with shared encoders, weighting, masking, and evaluation.
- Next: Advanced task balancing (uncertainty, GradNorm, PCGrad), soft-sharing (cross-stitch, adapters), and multi-dataset training strategies.
Hands-on exercises
Do these in order. Then take the Quick Test at the bottom of the page.
Exercise 1 — Balance multi-task losses
You have a batch with three tasks: classification (L_cls=0.8), detection (L_det=2.4), segmentation (L_seg=1.2). Use fixed weights w_cls=0.5, w_det=0.2, w_seg=0.3. Compute the total loss. If detection labels are missing for half the batch, should you change the formula?
- Write the scalar total loss.
- Explain how masking works for missing labels.
Hint
Weighted sum over tasks that have labels in the current batch. Missing labels mean you drop that term for those samples.
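To check your approach without revealing the exercise's answer, here is the same masked weighted sum with different, illustrative numbers:

```python
# Masked weighted sum with illustrative numbers (NOT the exercise's values)
weights = {"cls": 0.5, "det": 0.2, "seg": 0.3}
losses = {"cls": 1.0, "seg": 2.0}  # det has no labels in this batch: masked out
total = sum(weights[t] * losses[t] for t in losses)
# total == 0.5 * 1.0 + 0.3 * 2.0 == 1.1
```

When only part of the batch lacks a task's labels, apply the mask per sample inside that task's loss instead of dropping the whole term.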
Exercise 2 — Design a minimal multi-task model
Design an MTL architecture for 512×512 images with tasks: 1) 100-class classification, 2) 1-class object detection (box + objectness), 3) binary segmentation. Specify:
- Backbone and where features are taken from.
- Head outputs (shapes and activations) at inference.
- Losses for each task.
Hint
Use a shared CNN with global pooled features for classification, a conv head on a stride-8/16 map for detection, and an upsampling decoder for segmentation.
Exercise 3 — Debug conflicting tasks
After 10 epochs: detection mAP rises steadily, segmentation mIoU is flat, classification accuracy decreases after epoch 4. Propose three concrete changes to fix this.
- Two training changes (weights/sampling/augs).
- One architectural change.
Hint
Adjust loss weights, use dynamic balancing, task-specific BN or lighter head for the dominant task, check augmentation consistency.
Exercise checklist
- I can compute weighted multi-task loss with masking for missing labels.
- I can specify shared backbone + per-task heads and their outputs.
- I can diagnose negative transfer and propose fixes.
Mini challenge
You must ship a mobile model that does person detection, instance segmentation, and keypoints at 30 FPS. Choose a backbone and heads, define a loss weighting plan for the first 5 epochs and after, and list two latency-saving tricks you will use. Keep your answer under 10 lines.
Next steps
- Implement a small MTL prototype with two tasks and log per-task gradient norms.
- Try uncertainty weighting vs. fixed weights and compare per-task metrics.
- When ready, take the Quick Test below.