DR Detection: EfficientNet-B4
Automated Diabetic Retinopathy (DR) severity grading from retinal fundus images, fine-tuned on the APTOS 2019 dataset. This is the primary model of a comparative study (vs. a ResNet-50 baseline).
Source code: https://github.com/anish030803/Computer-Vison-
Model Details
- Architecture: EfficientNet-B4 (via timm: tf_efficientnet_b4_ns, NoisyStudent pretrained)
- Custom Head: Global Average Pool → BatchNorm → Dropout(0.4) → Linear(1792→256, ReLU) → Dropout(0.3) → Linear(256→5)
- Input: 380×380 RGB fundus images
- Output: 5-class probabilities
- Parameters: 18M
- Preprocessing: Ben Graham's method (resize → circular crop → Gaussian blur subtraction → normalize)
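The custom head listed above can be sketched in PyTorch as follows. This is a minimal illustration, assuming the backbone emits a 1792-channel feature map; the class name `DRHead` and the 12×12 spatial size of the dummy input are hypothetical, and the repository's `build_efficientnet` may wire the layers differently.

```python
import torch
from torch import nn


class DRHead(nn.Module):
    """Hypothetical head: GAP -> BatchNorm -> Dropout -> Linear+ReLU -> Dropout -> Linear."""

    def __init__(self, in_features: int = 1792, hidden: int = 256, num_classes: int = 5):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pool over H and W
        self.layers = nn.Sequential(
            nn.BatchNorm1d(in_features),
            nn.Dropout(0.4),
            nn.Linear(in_features, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        pooled = self.pool(features).flatten(1)  # (N, C, H, W) -> (N, C)
        return self.layers(pooled)               # raw logits, shape (N, num_classes)


head = DRHead().eval()  # eval mode: dropout off, BatchNorm uses running statistics
with torch.no_grad():
    logits = head(torch.randn(2, 1792, 12, 12))  # dummy backbone output (assumed shape)
head_params = sum(p.numel() for p in head.parameters())
```

The head returns logits rather than probabilities, so a softmax is applied at inference time, as in the Usage section.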
Severity Grades
| Grade | Label | Description |
|---|---|---|
| 0 | No DR | No visible retinopathy |
| 1 | Mild NPDR | Microaneurysms only |
| 2 | Moderate NPDR | More than just microaneurysms |
| 3 | Severe NPDR | Extensive intraretinal hemorrhages |
| 4 | Proliferative DR | Neovascularization or vitreous hemorrhage |
Performance
5-Fold Stratified Cross-Validation (10% held-out test set):
| Metric | Mean ± Std | Held-Out Test |
|---|---|---|
| Quadratic Weighted Kappa (QWK) | 0.8187 ± 0.0253 | 0.8076 |
| Accuracy | 0.7355 ± 0.0270 | 0.7361 |
| Macro F1 | 0.5849 ± 0.0326 | n/a |
| Severe DR Recall | 0.5847 ± 0.0522 | 0.5714 |
| Proliferative DR Recall | 0.4318 ± 0.0844 | 0.4091 |
vs. ResNet-50 baseline: EfficientNet-B4 improves QWK by ~0.06 (0.7281 → 0.7884 single run, 0.8187 in 5-fold CV).
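For readers unfamiliar with the headline metric, the following is a minimal pure-Python sketch of Quadratic Weighted Kappa. It is not the evaluation code used for the numbers above; the function name and the handling of the degenerate case are assumptions made for illustration.

```python
def quadratic_weighted_kappa(y_true, y_pred, num_classes=5):
    """QWK = 1 - sum(w * observed) / sum(w * expected), with w = (i - j)^2 / (K - 1)^2."""
    n = len(y_true)
    k = num_classes

    # Confusion matrix: rows are true grades, columns are predicted grades.
    observed = [[0] * k for _ in range(k)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1

    hist_true = [sum(row) for row in observed]
    hist_pred = [sum(observed[i][j] for i in range(k)) for j in range(k)]

    numerator = 0.0
    denominator = 0.0
    for i in range(k):
        for j in range(k):
            weight = (i - j) ** 2 / (k - 1) ** 2
            expected = hist_true[i] * hist_pred[j] / n  # count expected by chance
            numerator += weight * observed[i][j]
            denominator += weight * expected

    if denominator == 0:  # all labels and predictions in one class (assumed convention)
        return 1.0
    return 1.0 - numerator / denominator


perfect = quadratic_weighted_kappa([0, 1, 2, 3, 4], [0, 1, 2, 3, 4])
off_by_one = quadratic_weighted_kappa([0, 1, 2, 3, 4], [0, 1, 2, 3, 3])
off_by_four = quadratic_weighted_kappa([0, 1, 2, 3, 4], [0, 1, 2, 3, 0])
```

Because the penalty grows with the squared distance between grades, mistaking Proliferative DR for Severe NPDR costs far less than mistaking it for No DR, which is why QWK suits ordinal severity grading.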
Training
- Hardware: NVIDIA A100 80GB (Northeastern Explorer HPC)
- Mixed precision: BF16
- Two-phase strategy:
  - Phase 1 (warmup): Frozen backbone, train classification head only, 20 epochs, LR=1e-3 with linear warmup → cosine annealing
  - Phase 2 (fine-tuning): Top 30-50% of backbone unfrozen, 25 epochs, LR=1e-5 with cosine annealing + warm restarts
- Loss: Class-weighted cross-entropy with label smoothing (0.1); class weights auto-computed from inverse frequency
- Augmentation: Horizontal/vertical flips, rotation ±36°, zoom 90-110%, brightness/contrast ±10%, MixUp (α=0.2)
- Regularization: Dropout, weight decay (1e-4), gradient clipping (max norm 1.0), early stopping (patience=5-7 on val_qwk)
- Optimizer: AdamW (β=[0.9, 0.999])
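The Phase 1 schedule (linear warmup followed by cosine annealing) can be sketched as a plain function of the epoch index. The 3-epoch warmup length, the zero floor, and per-epoch stepping are assumptions chosen for illustration; the model card does not state them, and the warm restarts used in Phase 2 are not covered here.

```python
import math


def warmup_cosine_lr(step, total_steps, warmup_steps, base_lr=1e-3, min_lr=0.0):
    """Learning rate at `step`: linear ramp up to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Linear warmup: reaches base_lr on the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Phase 1 from the list above: 20 epochs at a peak LR of 1e-3 (warmup length assumed).
schedule = [warmup_cosine_lr(epoch, total_steps=20, warmup_steps=3) for epoch in range(20)]
```

In a PyTorch training loop the same shape is typically obtained by wrapping such a function in `torch.optim.lr_scheduler.LambdaLR`, with the function returning a multiplier of the base rate instead of an absolute value.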
Dataset
APTOS 2019 Blindness Detection (Kaggle):
- Original: 3,662 training images
- After cleaning: 2,681 images
- Class distribution: 49% No DR, 10% Mild, 28% Moderate, 5% Severe, 8% Proliferative
- Imbalance ratio: 9.5x
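The class distribution above is what drives the inverse-frequency class weights mentioned under Training. A minimal sketch follows, assuming the common "balanced" normalisation w_c = 1 / (K · f_c); the exact formula used in the repository is not stated in this card, and the fractions are the rounded percentages from the list rather than exact counts.

```python
# Rounded class fractions taken from the distribution listed above.
class_fractions = {
    "No DR": 0.49,
    "Mild NPDR": 0.10,
    "Moderate NPDR": 0.28,
    "Severe NPDR": 0.05,
    "Proliferative DR": 0.08,
}


def inverse_frequency_weights(fractions):
    """Weight each class by 1 / (K * fraction), so a uniform dataset gives weight 1.0."""
    k = len(fractions)
    return {name: 1.0 / (k * frac) for name, frac in fractions.items()}


weights = inverse_frequency_weights(class_fractions)
weight_ratio = max(weights.values()) / min(weights.values())
```

With these rounded percentages the largest-to-smallest weight ratio comes out near 9.8, in line with the 9.5x imbalance quoted above, which presumably derives from exact image counts.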
Cleaning pipeline (5 passes): file integrity, duplicates (pHash hamming < 3), quality (sharpness/brightness/contrast 2nd percentile cutoffs), resolution (min 256x256), label verification.
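The duplicate pass can be illustrated with a small sketch that compares perceptual hashes by Hamming distance. It assumes the hashes have already been computed and stored as integers (the hashing step itself is not shown), and the function names and example values are hypothetical.

```python
def hamming_distance(hash_a: int, hash_b: int) -> int:
    """Number of bit positions in which two integer hashes differ."""
    return bin(hash_a ^ hash_b).count("1")


def find_near_duplicates(hashes: dict, threshold: int = 3) -> list:
    """Return filename pairs whose hashes differ in fewer than `threshold` bits."""
    names = sorted(hashes)
    pairs = []
    for index, first in enumerate(names):
        for second in names[index + 1:]:
            if hamming_distance(hashes[first], hashes[second]) < threshold:
                pairs.append((first, second))
    return pairs


# Toy hashes: a.png and b.png differ in one bit, c.png differs from both in many bits.
example_hashes = {"a.png": 0b1011, "b.png": 0b1001, "c.png": 0xFFFF}
duplicates = find_near_duplicates(example_hashes)
```

The pairwise loop is quadratic in the number of images, which is acceptable at the scale of a few thousand files.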
Usage
import torch
import cv2
import numpy as np
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
# Download checkpoint
ckpt_path = hf_hub_download(repo_id="anishanish383/dr-detection-efficientnet-b4", filename="best.pt")
# Load model (requires source code from https://github.com/anish030803/Computer-Vison-)
from src.models.efficientnet import build_efficientnet
from src.utils.config import load_config
from src.utils.checkpoint import load_checkpoint
config = load_config("configs/train_efficientnet.yaml")
model = build_efficientnet(config)
load_checkpoint(ckpt_path, model)
model.eval()
# Preprocess image
from src.data.preprocessing import ben_graham_preprocess
img = cv2.imread("fundus.png")
preprocessed = ben_graham_preprocess(img, target_size=380)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
normalized = (preprocessed - mean) / std
tensor = torch.from_numpy(normalized.transpose(2, 0, 1)).float().unsqueeze(0)
# Predict
with torch.no_grad():
    logits = model(tensor)
    probs = F.softmax(logits, dim=-1)[0]
class_names = ["No DR", "Mild NPDR", "Moderate NPDR", "Severe NPDR", "Proliferative DR"]
print(f"Prediction: {class_names[probs.argmax().item()]} ({probs.max().item():.2%})")
Limitations
⚠️ This model is for research and educational use only; it is not approved for clinical diagnosis.
- Severe DR Recall (0.58) and Proliferative DR Recall (0.43) are below clinical thresholds (≥0.80). The model misses too many severe/sight-threatening cases.
- Trained on a single dataset (APTOS 2019, India). Performance may degrade on different populations, camera types, or imaging conditions.
- Class imbalance (9.5x) makes rare-class detection difficult.
- Not validated on external datasets (Messidor-2, EyePACS, etc.).
Suitable for: screening triage, research, educational demos. NOT suitable for: autonomous clinical diagnosis, replacing ophthalmologist review.
Citation
If you use this model, please cite the original works:
@inproceedings{tan2019efficientnet,
  title={EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks},
  author={Tan, Mingxing and Le, Quoc V.},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2019}
}
@misc{aptos2019,
  title={APTOS 2019 Blindness Detection},
  author={Asia Pacific Tele-Ophthalmology Society},
  year={2019},
  publisher={Kaggle}
}
@misc{graham2015,
  title={Kaggle Diabetic Retinopathy Detection Competition Report},
  author={Graham, Ben},
  year={2015},
  publisher={University of Warwick}
}
License
MIT