Class Imbalance Losses: Focal Loss

Learn Class Imbalance Losses: Focal Loss for free with explanations, exercises, and a quick test (for Computer Vision Engineers).

Published: January 5, 2026 | Updated: January 5, 2026

Why this matters

In computer vision, positive examples are often rare: pedestrians vs. background, cancer vs. healthy tissue, product defects vs. normal items. Standard cross-entropy can get dominated by the many easy negatives, leading to poor recall on rare but critical positives. Focal Loss counteracts this by down-weighting easy examples and focusing the model on hard, misclassified ones.

  • Object detection: millions of background anchors vs. few object anchors
  • Medical imaging: rare pathology among many normal scans
  • Long-tail classification: many classes with very few samples

Who this is for

  • Computer Vision Engineers and ML practitioners battling class imbalance
  • Students implementing detection or imbalanced classification models
  • Anyone tuning losses to improve recall on rare classes

Prerequisites

  • Basic probability and cross-entropy understanding
  • Familiarity with logits vs. probabilities (sigmoid/softmax)
  • Hands-on with a deep learning framework (e.g., PyTorch or TensorFlow)

Concept explained simply

Cross-entropy treats all examples equally. When most samples are easy negatives, the model learns to be great at negatives and can ignore rare positives. Focal Loss modifies cross-entropy by multiplying it with a factor that shrinks the loss for well-classified examples and keeps it large for misclassified ones.

Binary focal loss (one common form): FL = - α_t · (1 - p_t)^γ · log(p_t)

  • p_t is the model probability of the true class
  • α_t is a class weight (can be per-class)
  • γ (gamma) ≥ 0 controls how strongly easy examples are down-weighted
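
As a quick sanity check on the formula, the short sketch below (plain Python, with p_t values chosen only for illustration) prints the modulating factor (1 - p_t)^γ for an easy and a hard example across a few γ values; γ=0 leaves both untouched, while larger γ shrinks the easy example much faster.

# Modulating factor (1 - p_t)**gamma for an easy vs. a hard example
for gamma in (0, 1, 2, 5):
    easy = (1 - 0.9) ** gamma  # well-classified example, p_t = 0.9
    hard = (1 - 0.6) ** gamma  # poorly classified example, p_t = 0.6
    print(f"gamma={gamma}: easy factor={easy:.5f}, hard factor={hard:.5f}")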

Mental model

Imagine wearing a focus lens that dims easy examples. The higher the gamma, the darker the easy examples become, so your model concentrates on the hard ones.

Key components

  • Gamma (γ): 0 reduces to weighted cross-entropy; 1–3 is common. Higher γ focuses more on hard examples.
  • Alpha (α): balances classes. Use a scalar (binary) or a per-class vector (multi-class). Treat as tunable.
  • Logits vs probabilities: Implement focal loss on logits for numerical stability (see the sketch after this list).
  • Calibration: Focal loss may affect probability calibration. If you need calibrated probabilities, consider post-calibration.
  • Compatibility: Works with both classification and detection heads (anchor/objectness classification).
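
To illustrate the "logits vs. probabilities" point, here is a minimal sketch; the extreme logit value of -200 is chosen purely to trigger float32 underflow. Computing cross-entropy from sigmoid probabilities blows up, while the logits-based version stays finite.

import torch
import torch.nn.functional as F

x = torch.tensor([-200.0])  # extreme logit; true class is 1
t = torch.tensor([1.0])

# Probability-space CE: sigmoid underflows to 0 in float32, so log(0) -> -inf
unstable = -torch.log(torch.sigmoid(x))
# Logit-space CE: uses the log-sum-exp trick internally and stays finite
stable = F.binary_cross_entropy_with_logits(x, t)
print(unstable.item(), stable.item())  # inf vs. 200.0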

Worked examples

Example 1: Binary rare positives (intuitive numbers)

Suppose a positive sample with model probability p=0.9:

  • Cross-entropy: -log(0.9) ≈ 0.105
  • Focal (α=0.25, γ=2): 0.25 × (1-0.9)^2 × 0.105 ≈ 0.000263

Now a harder positive with p=0.6:

  • Cross-entropy: -log(0.6) ≈ 0.511
  • Focal (α=0.25, γ=2): 0.25 × (0.4)^2 × 0.511 ≈ 0.0204

Focal Loss massively down-weights the easy example and highlights the hard one.
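
The arithmetic above is easy to verify in plain Python (same α and γ as in the example):

import math

def fl(p_t, alpha=0.25, gamma=2.0):
    # FL = -alpha_t * (1 - p_t)**gamma * log(p_t)
    return -alpha * (1 - p_t) ** gamma * math.log(p_t)

for p in (0.9, 0.6):
    print(f"p={p}: CE={-math.log(p):.4f}, FL={fl(p):.6f}")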

Example 2: Multi-class with a rare class

For classes [A,B,C] with frequencies [70%,25%,5%], set per-class α = [0.3, 0.4, 0.9] (illustrative). Use FL per sample as:

FL = - α[class] × (1 - p_true)^γ × log(p_true)

Rare class C gets a larger α to avoid being ignored.
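
A minimal sketch of that per-sample computation, assuming we already have the softmax probability of the true class for each sample (the probabilities below are made up for illustration):

import math

alpha = {"A": 0.3, "B": 0.4, "C": 0.9}  # illustrative per-class weights from above
gamma = 2.0

# (true class, softmax probability assigned to the true class) -- made-up values
samples = [("A", 0.95), ("B", 0.70), ("C", 0.30)]

for cls, p_true in samples:
    fl = -alpha[cls] * (1 - p_true) ** gamma * math.log(p_true)
    print(f"class {cls}: p_true={p_true}, FL={fl:.5f}")

The rare, poorly classified class C contributes by far the largest loss, which is exactly the intended behavior.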

Example 3: Object detection anchors

In anchor-based detection, there can be 100k anchors per image, and most are background. With cross-entropy, easy negatives dominate the loss. Focal loss (γ≈2) reduces the weight of easy background anchors so the model learns from hard negatives and true positives, boosting recall and mAP.
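
To see this aggregate effect numerically, the toy simulation below (anchor counts and probabilities are made up, with α_t = 1 - α for negatives, matching the binary implementation later) compares the total loss contributed by easy background anchors versus a handful of hard positives:

import math

alpha, gamma = 0.25, 2.0

def ce(p_t):
    return -math.log(p_t)

def fl(p_t, alpha_t):
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

# Toy numbers: 100,000 easy background anchors (p_t = 0.99) vs. 50 hard positives (p_t = 0.3)
n_easy, p_easy = 100_000, 0.99
n_hard, p_hard = 50, 0.3

print("CE total  easy:", n_easy * ce(p_easy), " hard:", n_hard * ce(p_hard))
print("FL total  easy:", n_easy * fl(p_easy, 1 - alpha), " hard:", n_hard * fl(p_hard, alpha))

Running this shows the easy background dominating the total loss under plain cross-entropy, and the hard positives dominating under focal loss.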

Implementation snippets

PyTorch: Binary focal loss on logits

import torch
import torch.nn.functional as F

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0, reduction='mean'):
    # logits, targets: shape [N]; targets in {0,1}
    bce = F.binary_cross_entropy_with_logits(logits, targets.float(), reduction='none')
    p = torch.sigmoid(logits)
    pt = torch.where(targets == 1, p, 1 - p)
    at = torch.where(targets == 1, torch.full_like(pt, alpha), torch.full_like(pt, 1 - alpha))
    loss = at * (1 - pt).pow(gamma) * bce
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    return loss

# Example usage
logits = torch.tensor([2.0, -1.0, 0.0, 4.0])
targets = torch.tensor([1, 0, 1, 1])
print(focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0).item())  # ≈ 0.0152

PyTorch: Multi-class focal loss on logits

def focal_loss_multiclass(logits, targets, alpha=None, gamma=2.0, reduction='mean'):
    # logits: [N, C], targets: [N] with class indices
    log_probs = F.log_softmax(logits, dim=-1)
    probs = log_probs.exp()
    # Gather p_t and log(p_t)
    pt = probs[torch.arange(targets.size(0)), targets]
    logpt = log_probs[torch.arange(targets.size(0)), targets]
    # Alpha handling
    if alpha is None:
        at = torch.ones_like(pt)
    else:
        if isinstance(alpha, (list, tuple)):
            alpha_t = torch.tensor(alpha, dtype=logits.dtype, device=logits.device)
        else:
            alpha_t = alpha
        at = alpha_t[targets] if torch.is_tensor(alpha_t) else torch.full_like(pt, float(alpha_t))
    loss = - at * (1 - pt).pow(gamma) * logpt
    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    return loss
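
An illustrative call, reusing the imports from the binary snippet (the logits are made up; α mirrors the per-class vector from Example 2):

# Example usage
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 0.3],
                       [-2.0, -1.0, 3.0]])
targets = torch.tensor([0, 2, 2])
print(focal_loss_multiclass(logits, targets, alpha=[0.3, 0.4, 0.9], gamma=2.0).item())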

How to tune focal loss

Step 1: Start with γ=2 and α tuned to rebalance the rare class (binary) or a per-class α vector (multi-class).
Step 2: Monitor recall, precision, and PR-AUC or mAP. If recall on the rare class is still low, try increasing γ to 3.
Step 3: If optimization becomes unstable or gradients vanish, lower γ to 1 or reduce α magnitude; also consider a smaller learning rate.
Step 4: For detection, combine with positive/negative sampling and ignore regions as usual.
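
Before committing to a full training run, it can help to see how γ changes the loss magnitude on a fixed batch; this sketch reuses focal_loss_binary and the toy logits/targets from the implementation section:

logits = torch.tensor([2.0, -1.0, 0.0, 4.0])
targets = torch.tensor([1, 0, 1, 1])
for gamma in (0.0, 1.0, 2.0, 3.0):
    loss = focal_loss_binary(logits, targets, alpha=0.25, gamma=gamma)
    print(f"gamma={gamma}: mean loss={loss.item():.4f}")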

Integration checklist

  • Compute on logits for numerical stability
  • Use per-class α for multi-class imbalance
  • Log key metrics by class (precision, recall, F1)
  • Inspect the loss breakdown for positives vs. negatives (see the sketch after this checklist)
  • Validate probability calibration if needed (e.g., temperature scaling)
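
For the loss-breakdown item, a minimal sketch (reusing focal_loss_binary and the toy batch from above) logs the mean focal loss for positives and negatives separately:

per_sample = focal_loss_binary(logits, targets, reduction='none')
pos_mean = per_sample[targets == 1].mean().item()
neg_mean = per_sample[targets == 0].mean().item()
print(f"mean focal loss  positives: {pos_mean:.5f}  negatives: {neg_mean:.5f}")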

Exercises

Do these now. They mirror the graded exercises below.

  1. ex1: Implement binary focal loss on logits and compute the loss for logits [2.0, -1.0, 0.0, 4.0] and targets [1,0,1,1] with α=0.25, γ=2 (mean reduction).
  2. ex2: For a positive example, compare cross-entropy vs. focal loss (α=0.25, γ=2) when p=0.9 and when p=0.6. Which loses more weight under focal loss and why?
  • Self-check: Your ex1 mean loss should be approximately 0.0152.
  • Self-check: In ex2, the easy example should be down-weighted far more by focal loss.

Common mistakes

  • Setting γ too high too early: can slow learning or cause vanishing gradients. Start at 2.
  • Forgetting per-class α in multi-class: rare classes still get ignored.
  • Implementing on probabilities instead of logits: can be numerically unstable.
  • Ignoring calibration: focal loss can skew probabilities; calibrate if thresholds matter.
  • Relying only on accuracy: use recall, F1, PR-AUC, and class-wise metrics.

How to self-check

  • Log average focal loss separately for positives and negatives; positives should contribute meaningfully.
  • Plot PR curves per class; improvements should show in the low-recall region.
  • Sanity test: set γ=0 and verify it matches your weighted cross-entropy.
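
A minimal version of that γ=0 sanity test, comparing focal_loss_binary against a hand-rolled weighted cross-entropy with the same α convention (α for positives, 1 - α for negatives):

import torch
import torch.nn.functional as F

logits = torch.randn(8)
targets = torch.randint(0, 2, (8,))

fl0 = focal_loss_binary(logits, targets, alpha=0.25, gamma=0.0)

weights = torch.where(targets == 1, torch.tensor(0.25), torch.tensor(0.75))
wce = (weights * F.binary_cross_entropy_with_logits(
    logits, targets.float(), reduction='none')).mean()

print(torch.allclose(fl0, wce))  # expected: True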

Practical projects

  • Build an imbalanced binary classifier (e.g., rare defect vs. normal). Compare weighted cross-entropy vs. focal loss across γ in {0,1,2,3}.
  • Train a small object detector on a custom dataset with few positives per image. Evaluate mAP with and without focal loss on the classification head.
  • Long-tail multi-class classification: assign per-class α from inverse frequency and tune γ to maximize macro-F1.

Learning path

  1. Review cross-entropy and class weighting.
  2. Implement binary focal loss on logits; verify γ=0 matches weighted CE.
  3. Add per-class α for multi-class imbalance.
  4. Apply focal loss to detection heads; measure gains in recall and mAP.
  5. Explore complementary techniques: class-balanced sampling, re-weighting, and data augmentation.

Next steps

  • Try class-balanced reweighting based on the effective number of samples (a sketch follows this list).
  • Experiment with hard-negative mining or sampling strategies.
  • Evaluate calibration methods if you deploy threshold-based decisions.
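
For the first item, here is a minimal sketch of effective-number reweighting following the common class-balanced-loss formulation; β and the class counts below are illustrative, and the returned weights can be plugged in as a per-class α:

import torch

def class_balanced_weights(class_counts, beta=0.999):
    # Effective number of samples per class: E_n = (1 - beta**n) / (1 - beta)
    counts = torch.tensor(class_counts, dtype=torch.float64)
    effective_num = (1.0 - torch.pow(beta, counts)) / (1.0 - beta)
    weights = 1.0 / effective_num
    return weights / weights.sum() * len(class_counts)  # normalize so the weights average to 1

print(class_balanced_weights([7000, 2500, 500]))  # the rare class gets the largest weight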

Mini challenge

Take a trained classifier that under-detects a rare class. Switch to focal loss (γ=2), set a per-class α to boost the rare class, and re-train for 5 epochs. Can you raise recall by 5+ points without collapsing precision? Document γ, α, metrics before/after.

Practice Exercises

2 exercises to complete

Instructions

Write a function focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0, reduction='mean') that operates on logits and targets in {0,1}. Use BCE-with-logits for the base loss. Compute the mean focal loss for logits [2.0, -1.0, 0.0, 4.0] and targets [1, 0, 1, 1].

Expected Output
Approximately 0.0152 (mean loss, small numerical tolerance acceptable)

Class Imbalance Losses: Focal Loss — Quick Test

Test your knowledge with 8 questions. Pass with 70% or higher.
