DR Detection: EfficientNet-B4

Automated diabetic retinopathy (DR) severity grading from retinal fundus images, fine-tuned on the APTOS 2019 dataset. This is the primary model of a comparative study against a ResNet-50 baseline.

Source code: https://github.com/anish030803/Computer-Vison-

Model Details

  • Architecture: EfficientNet-B4 (via timm: tf_efficientnet_b4_ns, NoisyStudent-pretrained weights)
  • Custom Head: Global Average Pool → BatchNorm → Dropout(0.4) → Linear(1792→256, ReLU) → Dropout(0.3) → Linear(256→5)
  • Input: 380×380 RGB fundus images
  • Output: 5-class probabilities
  • Parameters: 18M
  • Preprocessing: Ben Graham's method (resize → circular crop → Gaussian blur subtraction → normalize)
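
The custom head listed above can be sketched in PyTorch as follows. This assumes the backbone returns an unpooled 1792-channel feature map (for example, a timm backbone created with `num_classes=0, global_pool=""`); the class name `DRHead` is hypothetical, and the repository's `build_efficientnet` may arrange the layers differently.

```python
import torch
import torch.nn as nn


class DRHead(nn.Module):
    """Hypothetical head: GAP -> BN -> Dropout(0.4) -> FC(256) + ReLU -> Dropout(0.3) -> FC(5)."""

    def __init__(self, in_features: int = 1792, hidden: int = 256, num_classes: int = 5):
        super().__init__()
        self.layers = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),          # global average pool: (B, C, H, W) -> (B, C, 1, 1)
            nn.Flatten(),                     # -> (B, C)
            nn.BatchNorm1d(in_features),
            nn.Dropout(0.4),
            nn.Linear(in_features, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden, num_classes),   # raw logits; softmax is applied at inference
        )

    def forward(self, feature_map: torch.Tensor) -> torch.Tensor:
        return self.layers(feature_map)


head = DRHead()
features = torch.randn(2, 1792, 12, 12)  # dummy backbone output for a batch of 2
logits = head(features)
print(logits.shape)  # torch.Size([2, 5])
```

Keeping the head as a separate module makes it easy to train it alone while the backbone is frozen (Phase 1 in Training).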

Severity Grades

Grade | Label | Description
0 | No DR | No visible retinopathy
1 | Mild NPDR | Microaneurysms only
2 | Moderate NPDR | More than just microaneurysms
3 | Severe NPDR | Extensive intraretinal hemorrhages
4 | Proliferative DR | Neovascularization or vitreous hemorrhage

Performance

5-fold stratified cross-validation, with 10% of the data held out as a separate test set:

Metric | 5-fold CV (mean ± std) | Held-out test
Quadratic Weighted Kappa (QWK) | 0.8187 ± 0.0253 | 0.8076
Accuracy | 0.7355 ± 0.0270 | 0.7361
Macro F1 | 0.5849 ± 0.0326 | n/a
Severe DR Recall | 0.5847 ± 0.0522 | 0.5714
Proliferative DR Recall | 0.4318 ± 0.0844 | 0.4091

Compared with the ResNet-50 baseline, EfficientNet-B4 improves single-run QWK by ~0.06 (0.7281 → 0.7884) and reaches a mean QWK of 0.8187 in 5-fold CV.
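
QWK, the headline metric above, penalizes a misgrade by the squared distance between the true and predicted grade, so calling a Proliferative DR case "No DR" costs 16 times more than an off-by-one error. The NumPy sketch below implements the standard definition, κ = 1 − Σ w_ij O_ij / Σ w_ij E_ij with w_ij = (i − j)² / (K − 1)², plus per-grade recall as used for the Severe and Proliferative rows. The function names are illustrative; `sklearn.metrics.cohen_kappa_score(y_true, y_pred, weights="quadratic")` computes the same kappa.

```python
import numpy as np


def quadratic_weighted_kappa(y_true, y_pred, num_classes: int = 5) -> float:
    """Cohen's kappa with quadratic weights for ordinal grades 0..num_classes-1."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)

    # Observed matrix O[i, j] = number of samples with true grade i and predicted grade j
    observed = np.zeros((num_classes, num_classes), dtype=float)
    np.add.at(observed, (y_true, y_pred), 1)

    # Expected matrix under independence: outer product of the marginals, divided by N
    expected = np.outer(observed.sum(axis=1), observed.sum(axis=0)) / observed.sum()

    # Quadratic disagreement weights w[i, j] = (i - j)^2 / (K - 1)^2
    grades = np.arange(num_classes)
    weights = (grades[:, None] - grades[None, :]) ** 2 / (num_classes - 1) ** 2

    return 1.0 - (weights * observed).sum() / (weights * expected).sum()


def per_class_recall(y_true, y_pred, num_classes: int = 5) -> np.ndarray:
    """Recall (sensitivity) per grade: TP / (TP + FN); NaN if a grade is absent."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    recalls = np.full(num_classes, np.nan)
    for k in range(num_classes):
        support = y_true == k
        if support.any():
            recalls[k] = (y_pred[support] == k).mean()
    return recalls


y_true = [0, 0, 2, 3, 4, 4]
y_pred = [0, 1, 2, 2, 4, 3]
print(quadratic_weighted_kappa(y_true, y_pred), per_class_recall(y_true, y_pred))
```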

Training

  • Hardware: NVIDIA A100 80GB (Northeastern Explorer HPC)
  • Mixed precision: BF16
  • Two-phase strategy:
    • Phase 1 (warmup): Frozen backbone, train classification head only, 20 epochs, LR=1e-3 with linear warmup → cosine annealing
    • Phase 2 (fine-tuning): Top 30-50% of backbone unfrozen, 25 epochs, LR=1e-5 with cosine annealing + warm restarts
  • Loss: Class-weighted cross-entropy with label smoothing (0.1); class weights auto-computed from inverse class frequency
  • Augmentation: Horizontal/vertical flips, rotation ±36°, zoom 90-110%, brightness/contrast ±10%, MixUp (α=0.2)
  • Regularization: Dropout, weight decay (1e-4), gradient clipping (max norm 1.0), early stopping (patience=5-7 on val_qwk)
  • Optimizer: AdamW (β1=0.9, β2=0.999)
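
The loss and MixUp bullets above can be combined as in this PyTorch sketch. `nn.CrossEntropyLoss(weight=..., label_smoothing=...)` is standard PyTorch (`label_smoothing` needs PyTorch 1.10 or later); the weight normalization w_k = N / (K · n_k) (the same convention as scikit-learn's "balanced" mode) and the helper names are assumptions, not taken from the repository.

```python
import torch
import torch.nn as nn


def inverse_frequency_weights(labels: torch.Tensor, num_classes: int = 5) -> torch.Tensor:
    """w_k = N / (K * n_k): rarer grades receive proportionally larger weights."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    return labels.numel() / (num_classes * counts.clamp(min=1))


def mixup(images: torch.Tensor, targets: torch.Tensor, alpha: float = 0.2):
    """Blend each sample with a shuffled partner; lam ~ Beta(alpha, alpha)."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(images.size(0))
    mixed = lam * images + (1.0 - lam) * images[perm]
    return mixed, targets, targets[perm], lam


def mixup_loss(criterion, logits, targets_a, targets_b, lam: float) -> torch.Tensor:
    # The MixUp loss applies the same convex combination to the two label sets
    return lam * criterion(logits, targets_a) + (1.0 - lam) * criterion(logits, targets_b)


# Toy label counts taken from the class percentages in Dataset (49/10/28/5/8)
train_labels = torch.tensor([0] * 49 + [1] * 10 + [2] * 28 + [3] * 5 + [4] * 8)
w = inverse_frequency_weights(train_labels)
criterion = nn.CrossEntropyLoss(weight=w, label_smoothing=0.1)

images = torch.randn(8, 3, 32, 32)                # small dummy batch
targets = torch.randint(0, 5, (8,))
mixed, y_a, y_b, lam = mixup(images, targets)
logits = torch.randn(8, 5, requires_grad=True)    # stand-in for model(mixed)
loss = mixup_loss(criterion, logits, y_a, y_b, lam)
loss.backward()
```

With these counts the Severe weight is 9.8 times the No DR weight, which is how the rare grades are pushed up in the loss.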
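
The two-phase strategy can be sketched with standard PyTorch utilities (`AdamW`, `LambdaLR`, `CosineAnnealingWarmRestarts`, `clip_grad_norm_`). The tiny stand-in model, the 2-epoch warmup, the restart period `T_0=5`, and treating the last half of the blocks as the "top" of the backbone are assumptions for illustration; BF16 autocast and early stopping are omitted for brevity.

```python
import math

import torch
import torch.nn as nn

# Hypothetical stand-in model: 4 "backbone" blocks plus a 5-way head
backbone = nn.Sequential(*[nn.Sequential(nn.Linear(16, 16), nn.ReLU()) for _ in range(4)])
head = nn.Linear(16, 5)
model = nn.Sequential(backbone, head)


def set_backbone_trainable(backbone: nn.Sequential, fraction: float) -> None:
    """Freeze every backbone block, then unfreeze the last `fraction` of them."""
    blocks = list(backbone.children())
    n_unfrozen = round(len(blocks) * fraction)
    for i, block in enumerate(blocks):
        for p in block.parameters():
            p.requires_grad = i >= len(blocks) - n_unfrozen


def warmup_cosine(warmup_epochs: int, total_epochs: int):
    """LR multiplier per epoch: linear warmup to 1.0, then cosine decay to 0."""
    def factor(epoch: int) -> float:
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return factor


def run_phase(optimizer, scheduler, epochs, batch, criterion):
    x, y = batch
    for _ in range(epochs):  # one dummy batch per "epoch"
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
        optimizer.step()
        scheduler.step()  # epoch-level LR schedule


criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
batch = (torch.randn(8, 16), torch.randint(0, 5, (8,)))

# Phase 1: frozen backbone, head only, LR 1e-3, 20 epochs of warmup + cosine annealing
set_backbone_trainable(backbone, fraction=0.0)
opt1 = torch.optim.AdamW(head.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-4)
sched1 = torch.optim.lr_scheduler.LambdaLR(opt1, lr_lambda=warmup_cosine(2, 20))
run_phase(opt1, sched1, 20, batch, criterion)

# Phase 2: unfreeze the top half of the backbone (card: 30-50%), LR 1e-5, 25 epochs
set_backbone_trainable(backbone, fraction=0.5)
opt2 = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-5, betas=(0.9, 0.999), weight_decay=1e-4
)
sched2 = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt2, T_0=5)  # restart every 5 epochs
run_phase(opt2, sched2, 25, batch, criterion)
```

A new optimizer is built for Phase 2 so that it only tracks the parameters that are actually trainable at that point.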

Dataset

APTOS 2019 Blindness Detection (Kaggle):

  • Original: 3,662 training images
  • After cleaning: 2,681 images
  • Class distribution: 49% No DR, 10% Mild, 28% Moderate, 5% Severe, 8% Proliferative
  • Imbalance ratio: 9.5x

Cleaning pipeline (5 passes): file integrity, duplicates (pHash Hamming distance < 3), quality (2nd-percentile cutoffs on sharpness, brightness, and contrast), resolution (minimum 256×256), and label verification.

Usage

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

# Helpers from the project repository (https://github.com/anish030803/Computer-Vison-);
# run from the repository root so that `src` and `configs/` resolve
from src.data.preprocessing import ben_graham_preprocess
from src.models.efficientnet import build_efficientnet
from src.utils.checkpoint import load_checkpoint
from src.utils.config import load_config

# Download checkpoint from the Hugging Face Hub
ckpt_path = hf_hub_download(repo_id="anishanish383/dr-detection-efficientnet-b4", filename="best.pt")

# Build the model from the training config and load the weights
config = load_config("configs/train_efficientnet.yaml")
model = build_efficientnet(config)
load_checkpoint(ckpt_path, model)
model.eval()

# Read and preprocess the image (note: OpenCV loads images in BGR channel order)
img = cv2.imread("fundus.png")
if img is None:
    raise FileNotFoundError("Could not read fundus.png")
preprocessed = ben_graham_preprocess(img, target_size=380)

# ImageNet mean/std normalization; assumes `preprocessed` is float in [0, 1]
# (divide by 255 first if it returns uint8 values in [0, 255])
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
normalized = (preprocessed - mean) / std
# HWC -> CHW, then add a batch dimension
tensor = torch.from_numpy(normalized.transpose(2, 0, 1)).float().unsqueeze(0)

# Predict
with torch.no_grad():
    logits = model(tensor)
    probs = F.softmax(logits, dim=-1)[0]

class_names = ["No DR", "Mild NPDR", "Moderate NPDR", "Severe NPDR", "Proliferative DR"]
pred = probs.argmax().item()
print(f"Prediction: {class_names[pred]} ({probs[pred].item():.2%})")

Limitations

⚠️ This model is for research and educational use only; it is not approved for clinical diagnosis.

  • Severe DR Recall (0.58) and Proliferative DR Recall (0.43) are below clinical thresholds (≥0.80). The model misses too many severe, sight-threatening cases.
  • Trained on a single dataset (APTOS 2019, India). Performance may degrade on different populations, camera types, or imaging conditions.
  • Class imbalance (9.5x) makes rare-class detection difficult.
  • Not validated on external datasets (Messidor-2, EyePACS, etc.).

Suitable for: screening triage, research, educational demos. NOT suitable for: autonomous clinical diagnosis, replacing ophthalmologist review.

Citation

If you use this model, please cite the original works:

@inproceedings{tan2019efficientnet,
  title={EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks},
  author={Tan, Mingxing and Le, Quoc V},
  booktitle={Proceedings of the 36th International Conference on Machine Learning (ICML)},
  year={2019}
}

@misc{aptos2019,
  title={APTOS 2019 Blindness Detection},
  author={Asia Pacific Tele-Ophthalmology Society},
  year={2019},
  publisher={Kaggle}
}

@misc{graham2015,
  title={Kaggle Diabetic Retinopathy Detection Competition Report},
  author={Graham, Ben},
  year={2015},
  publisher={University of Warwick}
}

License

MIT
