Why this matters
In computer vision, positive examples are often rare: pedestrians vs. background, cancer vs. healthy tissue, product defects vs. normal items. Standard cross-entropy can get dominated by the many easy negatives, leading to poor recall on rare but critical positives. Focal Loss counteracts this by down-weighting easy examples and focusing the model on hard, misclassified ones.
- Object detection: millions of background anchors vs. few object anchors
- Medical imaging: rare pathology among many normal scans
- Long-tail classification: many classes with very few samples
Who this is for
- Computer Vision Engineers and ML practitioners battling class imbalance
- Students implementing detection or imbalanced classification models
- Anyone tuning losses to improve recall on rare classes
Prerequisites
- Basic probability and cross-entropy understanding
- Familiarity with logits vs. probabilities (sigmoid/softmax)
- Hands-on with a deep learning framework (e.g., PyTorch or TensorFlow)
Concept explained simply
Cross-entropy treats all examples equally. When most samples are easy negatives, the model learns to be great at negatives and can ignore rare positives. Focal Loss modifies cross-entropy by multiplying it with a factor that shrinks the loss for well-classified examples and keeps it large for misclassified ones.
Binary focal loss (one common form): FL = - α_t · (1 - p_t)^γ · log(p_t)
- p_t is the model probability of the true class
- α_t is a class weight (can be per-class)
- γ (gamma) ≥ 0 controls how strongly easy examples are down-weighted
Mental model
Imagine wearing a focus lens that dims easy examples. The higher the gamma, the darker the easy examples become, so your model concentrates on the hard ones.
Key components
- Gamma (γ): 0 reduces to weighted cross-entropy; 1–3 is common. Higher γ focuses more on hard examples.
- Alpha (α): balances classes. Use a scalar (binary) or a per-class vector (multi-class). Treat as tunable.
- Logits vs probabilities: Implement focal loss on logits for numerical stability (a short illustration follows this list).
- Calibration: Focal loss may affect probability calibration. If you need calibrated probabilities, consider post-calibration.
- Compatibility: Works with both classification and detection heads (anchor/objectness classification).
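A quick illustration of the logits-vs-probabilities point above (a minimal sketch with a made-up confident negative; the exact numbers are only illustrative):

import torch
import torch.nn.functional as F

# A confident negative: large positive logit, true label 0.
logit = torch.tensor([20.0])
target = torch.tensor([0.0])

# Probability-space loss: sigmoid(20) rounds to exactly 1.0 in float32,
# so log(1 - p) becomes log(0) and the loss is infinite.
p = torch.sigmoid(logit)
naive = -torch.log(1 - p)                                    # tensor([inf])

# Logit-space loss stays finite because it is computed stably inside.
stable = F.binary_cross_entropy_with_logits(logit, target)   # ~20.0

print(naive.item(), stable.item())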
Worked examples
Example 1: Binary rare positives (intuitive numbers)
Suppose a positive sample with model probability p=0.9:
- Cross-entropy: -log(0.9) ≈ 0.105
- Focal (α=0.25, γ=2): 0.25 × (1-0.9)^2 × 0.105 ≈ 0.000263
Now a harder positive with p=0.6:
- Cross-entropy: -log(0.6) ≈ 0.511
- Focal (α=0.25, γ=2): 0.25 × (0.4)^2 × 0.511 ≈ 0.0204
Focal Loss massively down-weights the easy example and highlights the hard one.
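These numbers can be reproduced with a few lines of Python (standard library only; the helper names here are just for illustration):

import math

def ce(p):                             # cross-entropy for a positive example
    return -math.log(p)

def focal(p, alpha=0.25, gamma=2.0):   # binary focal loss for a positive example
    return alpha * (1 - p) ** gamma * ce(p)

for p in (0.9, 0.6):
    print(f"p={p}: CE={ce(p):.3f}, focal={focal(p):.6f}")
# p=0.9: CE=0.105, focal=0.000263
# p=0.6: CE=0.511, focal=0.020433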
Example 2: Multi-class with a rare class
For classes [A,B,C] with frequencies [70%,25%,5%], set per-class α = [0.3, 0.4, 0.9] (illustrative). Use FL per sample as:
FL = - α[class] × (1 - p_true)^γ × log(p_true)
Rare class C gets a larger α to avoid being ignored.
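A small sketch of this per-sample computation, using the illustrative α vector above and made-up logits for two samples (one easy sample of class A, one uncertain sample of the rare class C):

import torch
import torch.nn.functional as F

alpha = torch.tensor([0.3, 0.4, 0.9])        # per-class weights for [A, B, C] (illustrative)
gamma = 2.0

logits = torch.tensor([[2.0, 0.5, -1.0],     # sample predicted confidently as A
                       [0.2, 0.1, 0.0]])     # uncertain sample
targets = torch.tensor([0, 2])               # true classes: A, C (rare)

probs = F.softmax(logits, dim=-1)
p_true = probs[torch.arange(2), targets]
fl = -alpha[targets] * (1 - p_true).pow(gamma) * p_true.log()
print(fl)   # the rare-class sample receives a much larger loss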
Example 3: Object detection anchors
In anchor-based detection, there can be 100k anchors per image, and most are background. With cross-entropy, easy negatives dominate the loss. Focal loss (γ≈2) reduces the weight of easy background anchors so the model learns from hard negatives and true positives, boosting recall and mAP.
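A rough, self-contained simulation of this effect (the anchor counts and logits are made up, and α is omitted to isolate the modulating factor):

import torch
import torch.nn.functional as F

# 10,000 easy background anchors (label 0, confident negative logits)
# and 10 hard positive anchors (label 1, uncertain logits).
logits = torch.cat([torch.full((10_000,), -4.0), torch.zeros(10)])
targets = torch.cat([torch.zeros(10_000), torch.ones(10)])

ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
pt = torch.where(targets == 1, torch.sigmoid(logits), 1 - torch.sigmoid(logits))
focal = (1 - pt).pow(2.0) * ce               # modulating factor only (alpha omitted)

print("CE:    negatives contribute", ce[:10_000].sum().item(),
      "vs positives", ce[10_000:].sum().item())
print("Focal: negatives contribute", focal[:10_000].sum().item(),
      "vs positives", focal[10_000:].sum().item())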
Implementation snippets
PyTorch: Binary focal loss on logits
import torch
import torch.nn.functional as F

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0, reduction='mean'):
    # logits, targets: shape [N]; targets in {0,1}
    bce = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction='none')
    p = torch.sigmoid(logits)
    pt = torch.where(targets == 1, p, 1 - p)
    at = torch.where(targets == 1, torch.full_like(pt, alpha), torch.full_like(pt, 1 - alpha))
    loss = at * (1 - pt).pow(gamma) * bce
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    return loss

# Example usage
logits = torch.tensor([2.0, -1.0, 0.0, 4.0])
targets = torch.tensor([1, 0, 1, 1])
print(focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0).item())
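As an optional sanity check, and assuming a reasonably recent torchvision is installed, the same α-balanced form is available as torchvision.ops.sigmoid_focal_loss, so its output should closely match the function above:

# Optional cross-check (assumes torchvision is available).
from torchvision.ops import sigmoid_focal_loss
ref = sigmoid_focal_loss(logits, targets.float(), alpha=0.25, gamma=2.0, reduction='mean')
print(ref.item())   # expected to be ~0.0152, matching focal_loss_binary above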
PyTorch: Multi-class focal loss on logits
def focal_loss_multiclass(logits, targets, alpha=None, gamma=2.0, reduction='mean'):
    # logits: [N, C], targets: [N] with class indices
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    # Gather p_t and log(p_t) for the true class of each sample
    idx = torch.arange(targets.size(0), device=logits.device)
    pt = probs[idx, targets]
    logpt = log_probs[idx, targets]
    # Alpha handling: None, scalar, or per-class list/tuple/tensor
    if alpha is None:
        at = torch.ones_like(pt)
    else:
        if isinstance(alpha, (list, tuple)):
            alpha_t = torch.tensor(alpha, dtype=logits.dtype, device=logits.device)
        else:
            alpha_t = alpha
        at = alpha_t[targets] if torch.is_tensor(alpha_t) else torch.full_like(pt, float(alpha_t))
    loss = -at * (1 - pt).pow(gamma) * logpt
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    return loss
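A usage sketch mirroring the binary example, with the illustrative per-class α from Example 2 and made-up logits:

# Example usage
logits_mc = torch.tensor([[2.0, 0.5, -1.0],
                          [0.2, 0.1, 0.0]])
targets_mc = torch.tensor([0, 2])
print(focal_loss_multiclass(logits_mc, targets_mc, alpha=[0.3, 0.4, 0.9], gamma=2.0).item())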
How to tune focal loss
- Start with γ=2 and the default α (0.25 for binary, or a per-class vector for multi-class), then sweep γ over {0, 1, 2, 3} and compare class-wise recall and F1.
- Verify that γ=0 reproduces your weighted cross-entropy before trusting other settings.
- Treat α as a tunable knob alongside γ: a larger α on rare classes typically raises their recall at some cost in precision.
- Once the loss behaves, work through the integration checklist below; a minimal training-step sketch follows it.
Integration checklist
- Compute on logits for numerical stability
- Use per-class α for multi-class imbalance
- Log key metrics by class (precision, recall, F1)
- Inspect loss breakdown for positives vs. negatives
- Validate probability calibration if needed (e.g., temperature scaling)
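A minimal training-step sketch showing where the loss plugs in; the model, optimizer, and batch below are placeholders, not a recommended setup:

import torch

model = torch.nn.Linear(16, 1)                        # stand-in binary classifier head
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(features, targets, alpha=0.25, gamma=2.0):
    logits = model(features).squeeze(-1)              # raw logits, no sigmoid here
    loss = focal_loss_binary(logits, targets, alpha=alpha, gamma=gamma)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# One fake batch to show the call shape
features = torch.randn(8, 16)
targets = torch.randint(0, 2, (8,))
print(train_step(features, targets))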
Exercises
Do these now. They mirror the graded exercises below.
- ex1: Implement binary focal loss on logits and compute the loss for logits [2.0, -1.0, 0.0, 4.0] and targets [1,0,1,1] with α=0.25, γ=2 (mean reduction).
- ex2: For a positive example, compare cross-entropy vs. focal loss (α=0.25, γ=2) when p=0.9 and when p=0.6. Which example is down-weighted more by focal loss, and why?
- Self-check: Your ex1 mean loss should be approximately 0.0152.
- Self-check: In ex2, the easy example should be down-weighted far more by focal loss.
Common mistakes
- Setting γ too high too early: can slow learning or cause vanishing gradients. Start at 2.
- Forgetting per-class α in multi-class: rare classes still get ignored.
- Implementing on probabilities instead of logits: can be numerically unstable.
- Ignoring calibration: focal loss can skew probabilities; calibrate if thresholds matter.
- Relying only on accuracy: use recall, F1, PR-AUC, and class-wise metrics.
How to self-check
- Log average focal loss separately for positives and negatives; positives should contribute meaningfully.
- Plot PR curves per class; improvements should show in the low-recall region.
- Sanity test: set γ=0 and verify it matches your weighted cross-entropy.
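A sketch of that γ=0 sanity test, reusing focal_loss_binary and the imports from the implementation snippets above:

# Sanity test: with gamma=0, focal loss should equal alpha-weighted cross-entropy.
logits = torch.randn(100)
targets = torch.randint(0, 2, (100,)).float()

fl_gamma0 = focal_loss_binary(logits, targets, alpha=0.25, gamma=0.0)

bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
weights = torch.where(targets == 1, torch.tensor(0.25), torch.tensor(0.75))
weighted_ce = (weights * bce).mean()

assert torch.allclose(fl_gamma0, weighted_ce)
print("gamma=0 matches weighted cross-entropy")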
Practical projects
- Build an imbalanced binary classifier (e.g., rare defect vs. normal). Compare weighted cross-entropy vs. focal loss across γ in {0,1,2,3}.
- Train a small object detector on a custom dataset with few positives per image. Evaluate mAP with and without focal loss on the classification head.
- Long-tail multi-class classification: assign per-class α from inverse frequency and tune γ to maximize macro-F1.
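One simple way to derive a per-class α from class counts (a sketch with hypothetical counts; the normalization scale is a tunable choice):

import torch

class_counts = torch.tensor([7000., 2500., 500.])        # hypothetical counts for [A, B, C]
inv_freq = 1.0 / class_counts
alpha = inv_freq / inv_freq.sum() * len(class_counts)    # normalize so the mean weight is ~1
print(alpha)   # the rare class receives the largest weight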
Learning path
- Review cross-entropy and class weighting.
- Implement binary focal loss on logits; verify γ=0 matches weighted CE.
- Add per-class α for multi-class imbalance.
- Apply focal loss to detection heads; measure gains in recall and mAP.
- Explore complementary techniques: class-balanced sampling, re-weighting, and data augmentation.
Next steps
- Try class-balanced reweighting based on the effective number of samples (a short sketch follows this list).
- Experiment with hard-negative mining or sampling strategies.
- Evaluate calibration methods if you deploy threshold-based decisions.
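A sketch of the class-balanced reweighting mentioned in the first bullet, based on the effective-number-of-samples idea; β is a hyperparameter (values close to 1, such as 0.999, are common), and the counts are hypothetical:

import torch

class_counts = torch.tensor([7000., 2500., 500.])        # hypothetical counts
beta = 0.999
effective_num = (1.0 - beta ** class_counts) / (1.0 - beta)
weights = 1.0 / effective_num
weights = weights / weights.sum() * len(class_counts)    # normalize around 1
print(weights)   # can serve as per-class alpha in the focal loss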
Mini challenge
Take a trained classifier that under-detects a rare class. Switch to focal loss (γ=2), set a per-class α to boost the rare class, and re-train for 5 epochs. Can you raise recall by 5+ points without collapsing precision? Document γ, α, metrics before/after.