Transfer Learning: Standing on Giant Shoulders

Pretrained ResNet on a Small Dataset → Medical Imaging Analysis

Why Transfer Learning?

When your dataset is small (hundreds to a few thousand images), training a deep network from scratch is risky: you’ll overfit and waste compute. Transfer learning lets you reuse rich, generic visual features learned from massive datasets (e.g., ImageNet) and adapt them to your task with modest compute and data. In practice:

  • Start with a pretrained ResNet (residual network) that already “knows” edges, textures, parts, and shapes.
  • Replace the final classification layer and fine‑tune a small portion (or all) of the network on your dataset.
  • Achieve strong accuracy in hours instead of days, with far fewer images.

In this module, you’ll do two hands‑on builds:

  1. Toy Problem: Fine‑tune a pretrained ResNet on a tiny 2‑class image dataset (e.g., “cats vs dogs (mini)” or any small ImageFolder dataset).
  2. Real‑World Application: Adapt a pretrained ResNet to chest X‑ray pneumonia classification (medical imaging), including class imbalance handling, AUROC evaluation, and Grad‑CAM explanations.

Theory Deep Dive

1. Definitions & Setup

Let \mathcal{D}_S and \mathcal{D}_T be source and target domains with corresponding tasks \mathcal{T}_S and \mathcal{T}_T. Classic inductive transfer learning assumes labeled data for \mathcal{T}_T and transfers knowledge from \mathcal{T}_S (e.g., ImageNet classification) to improve performance on \mathcal{T}_T.

  • Feature extractor transfer: Use the pretrained backbone to compute features f_\theta(x), freeze most layers, and train a small head g_\phi on top.
  • Fine‑tuning: Start from pretrained \theta, then update a subset (or all) of \theta with a lower learning rate to adapt to \mathcal{D}_T.

The overall objective is often cross‑entropy:

\displaystyle \mathcal{L}_{CE} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} y^{(i)}_c \log \hat{p}(y=c \mid x^{(i)}),

with optional weight decay (L2) regularization:

\displaystyle \mathcal{L} = \mathcal{L}_{CE} + \lambda \lVert \Theta \rVert_2^2.

When classes are imbalanced (common in medical imaging), use class weights w_c:

\displaystyle \mathcal{L}_{wCE} = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} w_c \, y^{(i)}_c \log \hat{p}(y=c \mid x^{(i)}).
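
As a quick illustration, the weighted loss maps directly onto PyTorch’s built‑in cross‑entropy; here is a minimal sketch with made‑up class counts:

import torch, torch.nn as nn

# Hypothetical counts for a 2-class problem: 900 negatives vs 100 positives
counts = torch.tensor([900.0, 100.0])
w = counts.sum() / (len(counts) * counts)     # inverse-frequency weights w_c -> approx [0.56, 5.0]

criterion = nn.CrossEntropyLoss(weight=w)     # weighted cross-entropy
logits = torch.randn(8, 2)                    # fake batch of logits
labels = torch.randint(0, 2, (8,))
print(criterion(logits, labels))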

2. Why ResNet for transfer?

Residual connections help train very deep CNNs by learning residual mappings:

\displaystyle \mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x}.

Shortcut paths preserve gradient flow and stabilize optimization, producing robust mid‑level features (edges → textures → object parts) that transfer well across tasks and domains. In practice, ResNet‑18/34 are great for small datasets (fewer parameters), while ResNet‑50/101 help when you have more data/compute.
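
For intuition, here is a simplified residual block in PyTorch (a sketch; torchvision’s BasicBlock additionally handles strides and channel‑changing shortcuts):

import torch, torch.nn as nn, torch.nn.functional as F

class TinyResidualBlock(nn.Module):
    """y = F(x) + x, assuming the shapes of F(x) and x already match."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)   # shortcut connection adds the input back

print(TinyResidualBlock(64)(torch.randn(1, 64, 56, 56)).shape)  # torch.Size([1, 64, 56, 56])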

3. What should I freeze/unfreeze?

  • Stage 1 (feature extractor): Freeze all convolutional blocks. Train only a new linear head for a few epochs to quickly learn class boundaries.
  • Stage 2 (discriminative fine‑tuning): Unfreeze the last block (e.g., layer4 in ResNet) and the head. Use a smaller learning rate for backbone layers, slightly larger for the head. Optionally unfreeze more layers if the domain gap is large (e.g., natural images → medical X‑rays).

Tip: When unfreezing BatchNorm layers, prefer keeping them in eval mode (use running stats) when minibatches are small/shifted; otherwise, mismatched batch statistics can degrade performance.
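
One way to implement this tip is a small helper (not a torchvision utility; just a sketch) that forces BatchNorm layers back into eval mode; call it right after model.train() at the start of each fine‑tuning epoch:

import torch.nn as nn

def set_bn_eval(model):
    # Keep BatchNorm layers using their pretrained running statistics
    # instead of re-estimating them from small, possibly shifted batches.
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()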

4. Data augmentation that actually helps

For small datasets, augmentation is regularization (a combined pipeline sketch follows the list):

  • Geometric: random crops/resizes, flips, slight rotations (\pm 10^\circ), mild affine transforms.
  • Photometric: brightness/contrast jitter, light color jitter (avoid harsh shifts for medical images), Cutout/RandomErasing for robustness.
  • Medical caveat: keep augmentations physically plausible (e.g., don’t flip left/right in tasks with laterality importance; avoid extreme color distortions in grayscale X‑rays).
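
Putting these ideas together, one possible train‑time pipeline looks like this (values are illustrative; tune them per dataset, and drop the flip/saturation jitter where the medical caveat applies):

from torchvision import transforms

train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.85, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.25),   # Cutout-style occlusion; must come after ToTensor
])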

5. Optimization & schedules

  • Use AdamW or SGD+momentum.
  • Warmup + cosine decay or OneCycle schedules stabilize fine‑tuning (see the sketch below).
  • Employ early stopping using validation loss/AUROC to avoid overfitting.
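
A minimal sketch of warmup followed by cosine decay, composed from built‑in PyTorch schedulers (stand‑in model and epoch counts):

import torch, torch.nn as nn

model = nn.Linear(10, 2)                      # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# 5 warmup epochs (LR ramps from 10% to 100%), then cosine decay over 45 epochs
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=5)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=45)
scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [warmup, cosine], milestones=[5])

for epoch in range(50):
    # ... train one epoch here ...
    scheduler.step()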

6. Evaluation metrics beyond accuracy

For imbalanced problems, accuracy can mislead. Prefer:

  • AUROC (area under ROC) and AUPRC.
  • Sensitivity/Recall and Specificity, often tuned via threshold \tau to optimize Youden’s J (see the sketch below):
    J(\tau) = \text{TPR}(\tau) - \text{FPR}(\tau).
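
A small sketch of the threshold search with scikit-learn, using toy labels and probabilities (sensitivity = TPR, specificity = 1 − FPR):

import numpy as np
from sklearn.metrics import roc_curve, confusion_matrix

y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])                    # toy labels
y_prob = np.array([0.1, 0.4, 0.8, 0.6, 0.9, 0.3, 0.55, 0.2])   # toy positive-class probabilities

fpr, tpr, thr = roc_curve(y_true, y_prob)
tau = thr[np.argmax(tpr - fpr)]                                 # threshold maximizing Youden's J

tn, fp, fn, tp = confusion_matrix(y_true, y_prob >= tau).ravel()
print('tau', tau, 'sensitivity', tp / (tp + fn), 'specificity', tn / (tn + fp))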

7. When transfer may not help

  • Large domain shift with very different low‑level statistics (e.g., MRIs vs ImageNet photos).
  • Very large labeled target dataset: training from scratch (or self‑supervised pretraining on in‑domain data) may match or beat ImageNet transfer.
  • Label space mismatch where features don’t align with target semantics.

Toy Problem – Pretrained ResNet on a Tiny 2‑Class Dataset

Data Snapshot

We’ll fine‑tune ResNet‑18 on a small ImageFolder dataset (e.g., data/animals/{train,val,test}/{cat,dog}), but you can swap in any small two‑class set.

Your actual counts will differ; this is a template of what to expect after running the counting code below.

Split   Cat   Dog   Total
Train   100   100   200
Val      20    20    40
Test     50    50   100

Discussion: Balanced classes simplify optimization. If your dataset is imbalanced, we’ll show how to compute class weights to mitigate bias.

Step 1: Imports, config, and seeding

Environment: Python 3.10+, PyTorch 2.x, TorchVision 0.15+.

Load PyTorch and utilities. We’ll use sklearn for metrics and switch to GPU automatically if available.

import os, math, random, copy, time
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

Sets seeds for reproducibility and chooses GPU if available.

Step 2: Paths and basic transforms

Augment the train set; keep validation deterministic. Normalize to ImageNet stats.

data_dir = Path('data/animals')  # change to your dataset root
img_size = 224

mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

val_tfms = transforms.Compose([
    transforms.Resize(int(img_size*1.15)),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

Step 3: Datasets and loaders

Wrap datasets with loaders for efficient batching.

train_ds = datasets.ImageFolder(data_dir/'train', transform=train_tfms)
val_ds   = datasets.ImageFolder(data_dir/'val',   transform=val_tfms)
test_ds  = datasets.ImageFolder(data_dir/'test',  transform=val_tfms)

class_names = train_ds.classes
num_classes = len(class_names)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False, num_workers=4, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

Step 4: Count samples per class

Shows actual per‑class counts; use these to judge imbalance.

from collections import Counter

cnt_train = Counter([y for _, y in train_ds.samples])
cnt_val   = Counter([y for _, y in val_ds.samples])
cnt_test  = Counter([y for _, y in test_ds.samples])
print('Train counts:', {class_names[k]: v for k, v in cnt_train.items()})
print('Val counts  :', {class_names[k]: v for k, v in cnt_val.items()})
print('Test counts :', {class_names[k]: v for k, v in cnt_test.items()})

Step 5: Build a pretrained ResNet‑18 (feature extractor stage)

Load pretrained weights, freeze the convolutional backbone, and attach a new classification head.

weights = models.ResNet18_Weights.IMAGENET1K_V1
backbone = models.resnet18(weights=weights)

# Freeze all conv layers first
for p in backbone.parameters():
    p.requires_grad = False

in_features = backbone.fc.in_features
backbone.fc = nn.Linear(in_features, num_classes)
backbone = backbone.to(device)

Step 6: Loss, optimizer, and evaluation helpers

Optimize the head with a modest LR and cosine schedule.

def accuracy(logits, y):
    preds = logits.argmax(dim=1)
    return (preds == y).float().mean().item()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(backbone.fc.parameters(), lr=3e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

Step 7: Training loop (stage 1: head only)

Train only the head for a few epochs; keep the best model by validation loss.

def run_epoch(model, loader, train=True):
    model.train() if train else model.eval()
    epoch_loss, epoch_acc = 0.0, 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.set_grad_enabled(train):
            logits = model(x)
            loss = criterion(logits, y)
        if train:
            optimizer.zero_grad(); loss.backward(); optimizer.step()
        epoch_loss += loss.item() * y.size(0)
        epoch_acc  += accuracy(logits, y) * y.size(0)
    return epoch_loss/len(loader.dataset), epoch_acc/len(loader.dataset)

best_w = copy.deepcopy(backbone.state_dict())
best_val = float('inf')
for epoch in range(5):
    tr_loss, tr_acc = run_epoch(backbone, train_loader, True)
    val_loss, val_acc = run_epoch(backbone, val_loader, False)
    scheduler.step()
    if val_loss < best_val:
        best_val = val_loss; best_w = copy.deepcopy(backbone.state_dict())
    print(f"[E{epoch+1}] train loss {tr_loss:.3f} acc {tr_acc:.3f} | val loss {val_loss:.3f} acc {val_acc:.3f}")

backbone.load_state_dict(best_w)

Step 8: Fine‑tune last block (stage 2)

Discriminative fine‑tuning: last block + head with smaller LR for backbone.

# Unfreeze layer4; for very small batches, consider keeping its BatchNorm layers
# in eval mode (see the BatchNorm tip in the theory section)
for name, module in backbone.named_modules():
    if name.startswith('layer4'):
        for p in module.parameters():
            p.requires_grad = True

# Two parameter groups: smaller LR for backbone, larger for head
params = [
    {"params": [p for n, p in backbone.named_parameters() if p.requires_grad and not n.startswith('fc.')], "lr": 5e-5},
    {"params": backbone.fc.parameters(), "lr": 5e-4}
]
optimizer = torch.optim.AdamW(params, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

best_w = copy.deepcopy(backbone.state_dict()); best_val = float('inf')
for epoch in range(5):
    tr_loss, tr_acc = run_epoch(backbone, train_loader, True)
    val_loss, val_acc = run_epoch(backbone, val_loader, False)
    scheduler.step()
    if val_loss < best_val:
        best_val = val_loss; best_w = copy.deepcopy(backbone.state_dict())
    print(f"[FT{epoch+1}] train loss {tr_loss:.3f} acc {tr_acc:.3f} | val loss {val_loss:.3f} acc {val_acc:.3f}")

backbone.load_state_dict(best_w)

Step 9: Test evaluation

Standard confusion matrix and precision/recall/F1 view.

from sklearn.metrics import classification_report, confusion_matrix

backbone.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        logits = backbone(x).cpu()
        all_preds.append(logits.argmax(1))
        all_labels.append(y)

preds = torch.cat(all_preds).numpy()
labels = torch.cat(all_labels).numpy()
print(confusion_matrix(labels, preds))
print(classification_report(labels, preds, target_names=class_names))

Step 10: Save the model

Persist your fine‑tuned weights for later use.

torch.save(backbone.state_dict(), 'resnet18_toy_small.pth')

Quick Reference: Full Code

# Toy TL with ResNet18: feature extractor → fine-tune last block
from pathlib import Path; import copy, torch, torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

img_size=224; mean=[0.485,0.456,0.406]; std=[0.229,0.224,0.225]
train_tfms=transforms.Compose([transforms.RandomResizedCrop(img_size,(0.8,1.0)),transforms.RandomHorizontalFlip(),transforms.ColorJitter(0.1,0.1,0.1),transforms.ToTensor(),transforms.Normalize(mean,std)])
val_tfms=transforms.Compose([transforms.Resize(int(img_size*1.15)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize(mean,std)])

data_dir=Path('data/animals')
train_ds=datasets.ImageFolder(data_dir/'train',train_tfms)
val_ds=datasets.ImageFolder(data_dir/'val',val_tfms)
test_ds=datasets.ImageFolder(data_dir/'test',val_tfms)
class_names=train_ds.classes; num_classes=len(class_names)
train_loader=DataLoader(train_ds,32,True,num_workers=4,pin_memory=True)
val_loader=DataLoader(val_ds,64,False,num_workers=4,pin_memory=True)
test_loader=DataLoader(test_ds,64,False,num_workers=4,pin_memory=True)

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weights=models.ResNet18_Weights.IMAGENET1K_V1
m=models.resnet18(weights=weights); [setattr(p,'requires_grad',False) for p in m.parameters()]
m.fc=nn.Linear(m.fc.in_features,num_classes); m=m.to(device)
crit=nn.CrossEntropyLoss(); opt=torch.optim.AdamW(m.fc.parameters(),3e-3,weight_decay=1e-4)
sch=torch.optim.lr_scheduler.CosineAnnealingLR(opt,T_max=10)

acc=lambda z,y:(z.argmax(1)==y).float().mean().item()

def run_epoch(model,ldr,train=True):
    model.train() if train else model.eval(); L=A=0
    for x,y in ldr:
        x,y=x.to(device),y.to(device)
        with torch.set_grad_enabled(train):
            out=model(x); loss=crit(out,y)
        if train: opt.zero_grad(); loss.backward(); opt.step()
        L+=loss.item()*y.size(0); A+=acc(out,y)*y.size(0)
    return L/len(ldr.dataset),A/len(ldr.dataset)

best=m.state_dict(); bestv=1e9
for _ in range(5):
    trl,tra=run_epoch(m,train_loader,True); vll,vla=run_epoch(m,val_loader,False); sch.step()
    if vll<bestv: bestv=vll; best=copy.deepcopy(m.state_dict())

m.load_state_dict(best)
for n,mod in m.named_modules():
    if n.startswith('layer4'):
        for p in mod.parameters(): p.requires_grad=True
opt=torch.optim.AdamW([
    {"params":[p for n,p in m.named_parameters() if p.requires_grad and not n.startswith('fc.')],"lr":5e-5},
    {"params":m.fc.parameters(),"lr":5e-4}],weight_decay=1e-4)
sch=torch.optim.lr_scheduler.CosineAnnealingLR(opt,T_max=10)
for _ in range(5):
    trl,tra=run_epoch(m,train_loader,True); vll,vla=run_epoch(m,val_loader,False); sch.step()

# test
m.eval()
P=[]; Y=[]
with torch.no_grad():
    for x,y in test_loader:
        P.append(m(x.to(device)).argmax(1).cpu()); Y.append(y)
from sklearn.metrics import classification_report, confusion_matrix
p=torch.cat(P).numpy(); y=torch.cat(Y).numpy()
print(confusion_matrix(y,p)); print(classification_report(y,p,target_names=class_names))

Real‑World Application — Medical Imaging (Chest X‑ray Pneumonia)

We’ll adapt a pretrained ResNet‑50 to classify NORMAL vs PNEUMONIA chest X‑rays. This section introduces class‑imbalance weighting, AUROC, threshold tuning, Grad‑CAM visualization, and model export.

Important: Medical data demands careful handling — ensure patient‑level splits (no leakage), de‑identification (no PHI in filenames/overlays), and appropriate institutional approvals. This tutorial is educational and not a clinical device.
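
If your copy of the data is not already split per patient, one way to build leakage‑free splits is to group images by patient ID. A sketch using scikit-learn's GroupShuffleSplit with a hypothetical get_patient_id() and toy file names (adapt the parsing to your dataset):

from sklearn.model_selection import GroupShuffleSplit

def get_patient_id(path):
    # Hypothetical: assumes filenames look like 'p001_view1.png'
    return path.split('_')[0]

paths  = ['p001_view1.png', 'p001_view2.png', 'p002_view1.png', 'p003_view1.png']  # toy
labels = [1, 1, 0, 0]                                                              # toy
groups = [get_patient_id(p) for p in paths]

splitter = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
train_idx, test_idx = next(splitter.split(paths, labels, groups=groups))
# All images from the same patient land on one side of the split.
print(train_idx, test_idx)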

Data Assumptions & Snapshots

Directory layout (common for public CXRs like the Pneumonia dataset):

chest_xray/
  train/{NORMAL,PNEUMONIA}/...
  val/{NORMAL,PNEUMONIA}/...
  test/{NORMAL,PNEUMONIA}/...

Example snapshot after counting:

Split   NORMAL   PNEUMONIA   Total
Train     1341        3875    5216
Val          8           8      16
Test       234         390     624

The train set is heavily imbalanced toward PNEUMONIA. We’ll compute class weights and focus on AUROC and sensitivity/specificity.

Step 1: Imports & config

Standard imports plus ROC utilities.

import os, copy, math, time
from pathlib import Path
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms, models
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report

# Reuse device and seeds if continuing from the toy section
SEED=42; torch.manual_seed(SEED)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Step 2: Transforms for grayscale X‑rays → 3‑channel

Convert grayscale to 3‑channel; use mild, plausible augmentations.

img_size=224
mean=[0.485,0.456,0.406]; std=[0.229,0.224,0.225]

def to_3ch(img):
    # img is PIL Image (grayscale). Convert to 3-ch by duplicating.
    return img.convert('RGB')

train_tfms = transforms.Compose([
    transforms.Lambda(to_3ch),
    transforms.RandomResizedCrop(img_size, scale=(0.85,1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=7),
    transforms.ToTensor(),
    transforms.Normalize(mean,std)
])
val_tfms = transforms.Compose([
    transforms.Lambda(to_3ch),
    transforms.Resize(int(img_size*1.15)),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean,std)
])

Step 3: Datasets, counts, and class weights

Compute class weights to address imbalance in the loss.

root=Path('chest_xray')
train_ds=datasets.ImageFolder(root/'train',train_tfms)
val_ds  =datasets.ImageFolder(root/'val',  val_tfms)
test_ds =datasets.ImageFolder(root/'test', val_tfms)

class_names=train_ds.classes  # ['NORMAL','PNEUMONIA']
num_classes=len(class_names)

# Count per class in train set
counts=np.bincount([y for _,y in train_ds.samples], minlength=num_classes)
print('Train counts:', dict(zip(class_names, counts.tolist())))

# Class weights inversely proportional to frequency
class_weights = torch.tensor((counts.sum() / (num_classes*counts)), dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

Step 4: (Optional) Balanced sampling

Combats imbalance during mini‑batch formation.

# WeightedRandomSampler so that each batch is more balanced
sample_weights = [class_weights[y].item() for _, y in train_ds.samples]
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

train_loader=DataLoader(train_ds,batch_size=32,sampler=sampler,num_workers=4,pin_memory=True)
val_loader  =DataLoader(val_ds,  batch_size=64,shuffle=False,num_workers=4,pin_memory=True)
test_loader =DataLoader(test_ds, batch_size=64,shuffle=False,num_workers=4,pin_memory=True)

Step 5: Build pretrained ResNet‑50

Start with feature‑extractor stage to stabilize the head.

weights=models.ResNet50_Weights.IMAGENET1K_V2
model=models.resnet50(weights=weights)

# Freeze backbone initially
for p in model.parameters():
    p.requires_grad=False

model.fc=nn.Linear(model.fc.in_features, num_classes)
model=model.to(device)

Step 6: Optimizer, scheduler, helper for AUROC

Tracks loss/accuracy and AUROC each epoch.

optimizer=torch.optim.AdamW(model.fc.parameters(), lr=2e-3, weight_decay=1e-4)
scheduler=torch.optim.lr_scheduler.OneCycleLR(optimizer,max_lr=2e-3,epochs=5,steps_per_epoch=len(train_loader))

def epoch_loop(model, loader, train=True):
    model.train() if train else model.eval()
    total_loss, total_correct, total = 0.0, 0, 0
    all_probs, all_labels = [], []
    for x,y in loader:
        x,y=x.to(device),y.to(device)
        with torch.set_grad_enabled(train):
            logits=model(x)
            loss=criterion(logits,y)
        if train:
            optimizer.zero_grad(); loss.backward(); optimizer.step(); scheduler.step()
        total_loss += loss.item()*y.size(0)
        total_correct += (logits.argmax(1)==y).sum().item()
        total += y.size(0)
        all_probs.append(F.softmax(logits,1)[:,1].detach().cpu())
        all_labels.append(y.detach().cpu())
    probs=torch.cat(all_probs).numpy(); labels=torch.cat(all_labels).numpy()
    try:
        auc=roc_auc_score(labels, probs)
    except Exception:
        auc=float('nan')
    return total_loss/total, total_correct/total, auc

Step 7: Train head, then fine‑tune deeper

Stabilize with head training, then fine‑tune last block with smaller LR.

best=copy.deepcopy(model.state_dict()); best_val=float('inf')
for epoch in range(5):
    trL,trA,trAUC=epoch_loop(model, train_loader, True)
    vaL,vaA,vaAUC=epoch_loop(model, val_loader, False)
    if vaL<best_val: best_val=vaL; best=copy.deepcopy(model.state_dict())
    print(f"[Head E{epoch+1}] train L {trL:.3f} A {trA:.3f} AUC {trAUC:.3f} | val L {vaL:.3f} A {vaA:.3f} AUC {vaAUC:.3f}")

model.load_state_dict(best)

# Unfreeze layer4 + fc; discriminative LR
for name, module in model.named_modules():
    if name.startswith('layer4'):
        for p in module.parameters(): p.requires_grad=True

params=[
    {"params":[p for n,p in model.named_parameters() if p.requires_grad and not n.startswith('fc.')],"lr":1e-4},
    {"params":model.fc.parameters(),"lr":5e-4}
]
optimizer=torch.optim.SGD(params, momentum=0.9, weight_decay=1e-4)
# epoch_loop steps the scheduler once per batch, so size the cosine schedule in batches
scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10*len(train_loader))

best=copy.deepcopy(model.state_dict()); best_val=float('inf')
for epoch in range(10):
    trL,trA,trAUC=epoch_loop(model, train_loader, True)
    vaL,vaA,vaAUC=epoch_loop(model, val_loader, False)
    if vaL<best_val: best_val=vaL; best=copy.deepcopy(model.state_dict())
    print(f"[FT E{epoch+1}] train L {trL:.3f} A {trA:.3f} AUC {trAUC:.3f} | val L {vaL:.3f} A {vaA:.3f} AUC {vaAUC:.3f}")

model.load_state_dict(best)

Step 8: Threshold tuning and report

Picks a probability threshold that balances TPR/FPR on validation, then evaluates on test.

# Compute AUROC, then pick threshold via Youden's J on the validation set
model.eval()
val_probs, val_labels = [], []
with torch.no_grad():
    for x,y in val_loader:
        p=F.softmax(model(x.to(device)),1)[:,1].cpu().numpy()
        val_probs.append(p); val_labels.append(y.numpy())
val_probs=np.concatenate(val_probs); val_labels=np.concatenate(val_labels)

fpr,tpr,thr=roc_curve(val_labels, val_probs)
youden = tpr - fpr
best_idx = np.argmax(youden)
threshold = thr[best_idx]
print('Chosen threshold:', float(threshold))

# Test evaluation using chosen threshold
te_probs, te_labels = [], []
with torch.no_grad():
    for x,y in test_loader:
        te_probs.append(F.softmax(model(x.to(device)),1)[:,1].cpu().numpy())
        te_labels.append(y.numpy())

test_probs=np.concatenate(te_probs); test_labels=np.concatenate(te_labels)
auc=roc_auc_score(test_labels, test_probs)
print('Test AUROC:', auc)

preds=(test_probs>=threshold).astype(int)
print('Confusion matrix:\n', confusion_matrix(test_labels, preds))
print(classification_report(test_labels, preds, target_names=class_names))

Step 9: Grad‑CAM (layer4) for explainability

Generates a simple Grad‑CAM heatmap showing salient regions.

# Minimal Grad-CAM to visualize regions influencing the prediction
import torch.nn.functional as F
import cv2, numpy as np
from PIL import Image

# Register hooks on layer4
acts = []; grads = []

def save_acts(m, i, o):
    acts.append(o.detach())

def save_grads(m, gi, go):
    grads.append(go[0].detach())

handle_f = model.layer4.register_forward_hook(save_acts)
handle_b = model.layer4.register_full_backward_hook(save_grads)

# Prepare one image from test set
(img_path, label) = test_ds.samples[0]
img = Image.open(img_path).convert('RGB')
input_tensor = val_tfms(img).unsqueeze(0).to(device)

model.eval(); acts.clear(); grads.clear()
logits = model(input_tensor)
prob = F.softmax(logits,1)[0,1]
prob.backward()  # target: class index 1 (e.g., PNEUMONIA)

A = acts[-1][0]          # [C,H,W]
G = grads[-1][0]         # [C,H,W]
weights = G.mean(dim=(1,2), keepdim=True)  # GAP over grads
cam = F.relu((weights*A).sum(0)).cpu().numpy()
cam = (cam - cam.min()) / (cam.max()-cam.min() + 1e-6)

# Overlay heatmap (upsample the 7x7 CAM to the input resolution first)
img_np = np.array(img.resize((img_size, img_size)))
cam_resized = cv2.resize(cam, (img_size, img_size))
heatmap = cv2.applyColorMap((cam_resized*255).astype(np.uint8), cv2.COLORMAP_JET)
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
overlay = (0.5*heatmap + 0.5*img_np).astype(np.uint8)
Image.fromarray(overlay).save('gradcam_overlay.png')

handle_f.remove(); handle_b.remove()
print('Saved Grad-CAM to gradcam_overlay.png')

Step 10: Save/export model

Persist and export for inference pipelines.

torch.save(model.state_dict(), 'resnet50_cxr_pneumonia.pth')

# (Optional) Export to ONNX for deployment
x = torch.randn(1,3,img_size,img_size).to(device)
torch.onnx.export(model, x, 'resnet50_cxr.onnx', input_names=['input'], output_names=['logits'], opset_version=13)
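
As a sanity check on the exported graph, you can run it with onnxruntime (assuming the onnxruntime package is installed); a minimal sketch:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('resnet50_cxr.onnx', providers=['CPUExecutionProvider'])
dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
logits = sess.run(['logits'], {'input': dummy})[0]   # names match the export call above
print('ONNX logits shape:', logits.shape)            # expect (1, 2)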

Quick Reference: Chest X‑ray Pneumonia Classification (Full Code)

# CXR TL with ResNet50: class weights, AUROC, Grad-CAM
from pathlib import Path; import copy, torch, numpy as np
import torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import datasets, transforms, models
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix, classification_report

device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img_size=224; mean=[0.485,0.456,0.406]; std=[0.229,0.224,0.225]

root=Path('chest_xray')
to3=lambda im: im.convert('RGB')
tr=transforms.Compose([transforms.Lambda(to3),transforms.RandomResizedCrop(img_size,(0.85,1.0)),transforms.RandomHorizontalFlip(),transforms.RandomRotation(7),transforms.ToTensor(),transforms.Normalize(mean,std)])
va=transforms.Compose([transforms.Lambda(to3),transforms.Resize(int(img_size*1.15)),transforms.CenterCrop(img_size),transforms.ToTensor(),transforms.Normalize(mean,std)])
train_ds=datasets.ImageFolder(root/'train',tr); val_ds=datasets.ImageFolder(root/'val',va); test_ds=datasets.ImageFolder(root/'test',va)
cls=train_ds.classes; K=len(cls)

cnt=np.bincount([y for _,y in train_ds.samples], minlength=K)
W=torch.tensor((cnt.sum()/(K*cnt)),dtype=torch.float32)
crit=nn.CrossEntropyLoss(weight=W.to(device))

sw=[W[y].item() for _,y in train_ds.samples]
train_loader=DataLoader(train_ds,32,sampler=WeightedRandomSampler(sw,len(sw),True),num_workers=4,pin_memory=True)
val_loader=DataLoader(val_ds,64,False,num_workers=4,pin_memory=True)
test_loader=DataLoader(test_ds,64,False,num_workers=4,pin_memory=True)

weights=models.ResNet50_Weights.IMAGENET1K_V2
m=models.resnet50(weights=weights); [setattr(p,'requires_grad',False) for p in m.parameters()]
m.fc=nn.Linear(m.fc.in_features,K); m=m.to(device)
opt=torch.optim.AdamW(m.fc.parameters(),2e-3,weight_decay=1e-4)
sch=torch.optim.lr_scheduler.OneCycleLR(opt,2e-3,epochs=5,steps_per_epoch=len(train_loader))

def loop(model,ldr,train):
    model.train() if train else model.eval(); L=0; C=0; N=0; P=[]; Y=[]
    for x,y in ldr:
        x,y=x.to(device),y.to(device)
        with torch.set_grad_enabled(train):
            z=model(x); loss=crit(z,y)
        if train: opt.zero_grad(); loss.backward(); opt.step(); sch.step()
        L+=loss.item()*y.size(0); C+=(z.argmax(1)==y).sum().item(); N+=y.size(0)
        P.append(F.softmax(z,1)[:,1].detach().cpu()); Y.append(y.detach().cpu())
    from sklearn.metrics import roc_auc_score
    try: auc=roc_auc_score(torch.cat(Y).numpy(), torch.cat(P).numpy())
    except: auc=float('nan')
    return L/N, C/N, auc

best=m.state_dict(); bestv=1e9
for _ in range(5):
    trL,trA,trU=loop(m,train_loader,True); vaL,vaA,vaU=loop(m,val_loader,False)
    if vaL<bestv: bestv=vaL; best=copy.deepcopy(m.state_dict())

m.load_state_dict(best)
for n,mod in m.named_modules():
    if n.startswith('layer4'):
        for p in mod.parameters(): p.requires_grad=True
opt=torch.optim.SGD([
    {"params":[p for n,p in m.named_parameters() if p.requires_grad and not n.startswith('fc.')],"lr":1e-4},
    {"params":m.fc.parameters(),"lr":5e-4}], momentum=0.9, weight_decay=1e-4)
sch=torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10*len(train_loader))  # loop() steps per batch
for _ in range(10):
    trL,trA,trU=loop(m,train_loader,True); vaL,vaA,vaU=loop(m,val_loader,False)

# threshold via Youden's J on val
import numpy as np
m.eval(); P=[]; Y=[]
with torch.no_grad():
    for x,y in val_loader:
        P.append(F.softmax(m(x.to(device)),1)[:,1].cpu().numpy()); Y.append(y.numpy())
P=np.concatenate(P); Y=np.concatenate(Y)
from sklearn.metrics import roc_curve, roc_auc_score, confusion_matrix, classification_report
fpr,tpr,thr=roc_curve(Y,P); th=thr[(tpr-fpr).argmax()]

# test
TP=[]; TY=[]
with torch.no_grad():
    for x,y in test_loader:
        TP.append(F.softmax(m(x.to(device)),1)[:,1].cpu().numpy()); TY.append(y.numpy())
TP=np.concatenate(TP); TY=np.concatenate(TY)
print('Test AUROC:', roc_auc_score(TY,TP))
pr=(TP>=th).astype(int)
print(confusion_matrix(TY,pr)); print(classification_report(TY,pr,target_names=cls))

Strengths & Limitations

Strengths

  • Data‑efficient: Strong results with small labeled datasets by reusing rich pretrained features.
  • Compute‑efficient: Converges faster than training from scratch.
  • Versatile: Works across many tasks (classification, detection, segmentation) with simple head swaps.

Limitations

  • Domain shift risk: ImageNet features may misalign with medical modalities (X‑ray/CT/MRI), requiring careful fine‑tuning or in‑domain pretraining.
  • Hidden biases: Pretrained models inherit biases from source data; can propagate spurious correlations (e.g., markers on X‑rays).
  • BatchNorm pitfalls: Small batches and differing statistics can harm performance if BN layers are freely updated.

Final Notes

You learned how to: (1) load a pretrained ResNet, (2) train a new classification head on small data, (3) fine‑tune deeper layers with discriminative learning rates, (4) handle class imbalance with weighted loss or balanced sampling, (5) evaluate with AUROC and tune thresholds via Youden’s J, and (6) generate Grad‑CAM explanations.

These practices let you ship useful models quickly and responsibly—even with limited data.

Next Steps for You:

Try other backbones & self‑supervision: Compare ResNet‑18/34/50 vs EfficientNet; experiment with self‑supervised backbones (e.g., SimCLR/MoCo) pre‑trained on in‑domain data.

Beyond classification: Extend to segmentation (e.g., U‑Net with ResNet encoder) or object detection for localizing pneumonia regions; add calibration (temperature scaling) for reliable probabilities.
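
For the calibration idea, temperature scaling fits a single scalar T on held‑out validation logits and divides logits by T at inference time; a minimal sketch (assumes you have collected val_logits and val_labels tensors):

import torch, torch.nn as nn

def fit_temperature(val_logits, val_labels, max_iter=100):
    # Learn one temperature T > 0 minimizing NLL on the validation set
    log_T = nn.Parameter(torch.zeros(1))                     # T = exp(log_T), starts at 1.0
    optimizer = torch.optim.LBFGS([log_T], lr=0.1, max_iter=max_iter)
    nll = nn.CrossEntropyLoss()

    def closure():
        optimizer.zero_grad()
        loss = nll(val_logits / log_T.exp(), val_labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_T.exp().item()

# Usage sketch: T = fit_temperature(val_logits, val_labels)
#               calibrated_probs = torch.softmax(test_logits / T, dim=1)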
