# MPD²-Router: Mask-aware Multi-expert Prior-regularized Dual-head Deferral Router for Glaucoma Screening and Diagnosis
We address learning-to-defer (L2D) in clinical glaucoma screening, where a frozen image classifier (Swin-V2) must collaborate with a roster of twelve human experts under three real-world constraints: (i) experts are not uniformly available across cases; (ii) consultations carry asymmetric tier costs; and (iii) deployment caps the deferral rate at ρ ≤ 0.7. We introduce MPD²-Router, a routing policy regularised by a three-level badness-ranked prior (global → support family → local group) and trained with two complementary divergences: GSDP for group-conditional structure and Rank-JS for sample-level peakiness control. Hard deployment constraints are enforced via an augmented Lagrangian on the deferral rate and average tier cost. Hyperparameters are chosen by a constraint-aware Bayesian optimiser with multivariate TPE and a Hyperband pruner, scalarised via a five-objective augmented Chebyshev criterion with adaptive utopia tracking. On REFUGE (in-distribution) and CHAKSU / ORIGA (out-of-distribution), MPD²-Router improves Matthews correlation and AUPRC over the AI-only baseline while keeping deferral within budget and concentrating queries on the experts with the strongest test-time evidence.
- Method
- Pipeline overview
- Repository layout
- Installation
- Reproducing the submission
- Datasets
- Module reference
- Data schema
- Configuration reference
- Reproducibility
- Citation
- License
## Method

For each level of the hierarchy we compute per-expert badness
b_j = c_fn · FNR_j + c_fp · FPR_j + tier_cost_j
with Beta-smoothed FNR/FPR (α = β = 1). Levels are then mixed via Dirichlet pseudo-counts:
- Global — p_global = (1 − u_g) · softmax(−τ · b) + u_g · U
- Family — p_family = λ_f · p_raw_f + (1 − λ_f) · p_global, with λ_f = n_f / (n_f + n0_family)
- Group — p_group = λ_g · p_raw_g + λ_f · p_family + λ_glo · p_global, with λ_g = n_g / (n_g + n0_group) (see the sketch below)
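A minimal NumPy sketch of this mixing (illustrative only: the shipped implementations are `compute_global_prior`, `compute_family_prior`, and `compute_group_prior` in §5 of `router_training.py`, and the split of the residual group-level mass between family and global below is an assumption of the sketch):

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def mixed_priors(b, p_raw_f, p_raw_g, n_f, n_g,
                 u_g=0.35, tau=1.0, n0_family=25, n0_group=30):
    """Three-level badness-ranked prior mixing (illustrative sketch)."""
    k = len(b)
    U = np.full(k, 1.0 / k)                        # uniform fallback

    # Global: soft ranking of experts by (negative) badness, mixed with uniform.
    p_global = (1 - u_g) * softmax(-tau * np.asarray(b)) + u_g * U

    # Family: shrink raw family frequencies towards the global prior.
    lam_f = n_f / (n_f + n0_family)
    p_family = lam_f * np.asarray(p_raw_f) + (1 - lam_f) * p_global

    # Group: shrink raw group frequencies; the residual mass is split between
    # family and global priors (the exact split is an assumption of this sketch).
    lam_g = n_g / (n_g + n0_group)
    rest = 1 - lam_g
    p_group = (lam_g * np.asarray(p_raw_g)
               + rest * lam_f * p_family
               + rest * (1 - lam_f) * p_global)
    return p_group / p_group.sum()
```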
After mixing, group-level entries are clipped by `_adaptive_clip_max(k_active, anchors)` to prevent collapse onto a single expert. The truncated geometric top-1 mass

g_{k,1}(ρ) = (1 − ρ) / (1 − ρ^k)

is exposed as the natural reference for the cap, controlled by two intuitive HPO knobs (`clip_ceiling`, `clip_slack`); see `src/m2p/adaptive_hpo.py:99`.
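For intuition, a small sketch of the geometric reference and the cap it suggests; the real logic lives in `clip_max_from_geom` / `build_geometric_clip_anchors` (see the module reference), and the interaction of `clip_ceiling` and `clip_slack` is simplified here:

```python
def truncated_geometric_top_mass(k: int, rho: float) -> float:
    """Top-1 mass g_{k,1}(rho) of a truncated geometric distribution over k experts."""
    return (1.0 - rho) / (1.0 - rho ** k)

def clip_max(k_active: int, rho: float,
             clip_ceiling: float = 0.35, clip_slack: float = 0.03) -> float:
    """Illustrative cap: geometric reference plus slack, never above the ceiling."""
    return min(clip_ceiling, truncated_geometric_top_mass(k_active, rho) + clip_slack)

# With rho = 0.8: k=5 -> ~0.30, k=7 -> ~0.25, k=12 -> ~0.21 before slack and ceiling.
```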
Both routing regularisers are implemented in `src/m2p/router_training.py` (§6, Routing losses).
- GSDP — group-conditional KL or JS divergence between the empirical routing distribution q̄_g = Σ_i d_i q_i / Σ_i d_i and the precomputed group prior p_g.
- Rank-JS — per-sample JS divergence against the truncated geometric rank prior, with a `mode='any_excess'` gate that activates only when the sample's sorted prefix sums exceed the geometric reference (i.e. the routing is too peaky); see the sketch below.
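A minimal PyTorch sketch of the `any_excess` gate; the shipped `rank_majorization_js_loss` may differ in masking and normalisation, and all names here are illustrative:

```python
import torch

def js_div(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Jensen-Shannon divergence along the last axis."""
    m = 0.5 * (p + q)
    kl = lambda a, b: (a * ((a + eps) / (b + eps)).log()).sum(-1)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def rank_js_any_excess(q_route: torch.Tensor, rho: float = 0.8) -> torch.Tensor:
    """q_route: (B, K) routing probabilities over the K available actions.

    Penalises samples whose sorted prefix mass exceeds the truncated
    geometric reference at any rank (the 'any_excess' gate)."""
    _, K = q_route.shape
    geo = (1 - rho) * rho ** torch.arange(K, dtype=q_route.dtype)
    geo = geo / geo.sum()                          # truncated geometric rank prior
    q_sorted, _ = q_route.sort(dim=-1, descending=True)
    # Gate on prefixes 1..K-1 (the full sum is 1 for both distributions).
    excess = (q_sorted.cumsum(-1)[:, :-1] > geo.cumsum(-1)[:-1] + 1e-6).any(-1)
    js = js_div(q_sorted, geo.expand_as(q_sorted))
    return (js * excess.to(js.dtype)).mean()
```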
`m2p.AugLag` (re-exported from `router_training.py`, §7) enforces

- E[defer] ≤ ρ (`max_deferral_rate`)
- E[cost] ≤ C (`max_avg_cost`, optional)

via the standard penalty λ · g + ½ μ · max(0, g)² on each constraint, with λ ← max(0, λ + lr_λ · g) updated once per epoch.
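A compact sketch of that penalty and dual update (illustrative, not the shipped `AugLag` class):

```python
import torch

class AugLagSketch:
    """Augmented-Lagrangian handling of E[defer] <= rho (illustrative)."""

    def __init__(self, rho: float = 0.7, mu: float = 25.0, lr_lambda: float = 0.10):
        self.rho, self.mu, self.lr_lambda = rho, mu, lr_lambda
        self.lam = 0.0                                    # dual variable

    def penalty(self, defer_rate: torch.Tensor) -> torch.Tensor:
        """Added to the training loss each batch; g > 0 means over budget."""
        g = defer_rate - self.rho
        return self.lam * g + 0.5 * self.mu * torch.clamp(g, min=0.0) ** 2

    def epoch_update(self, defer_rate: float) -> None:
        """lambda <- max(0, lambda + lr_lambda * g), once per epoch."""
        g = defer_rate - self.rho
        self.lam = max(0.0, self.lam + self.lr_lambda * g)
```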
src/m2p/adaptive_hpo.py runs Optuna with multivariate TPE seeded by a known-good notebook config, a Hyperband pruner keyed on the lower bound of the warmup search range, and an augmented Chebyshev scalarisation over (clinical, MCC, AUPRC, tier_soft, es_violation) with an EMA-blended utopia point. A constraint-aware selection score then prefers checkpoints that meet the deferral budget even when raw clinical cost ties.
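The scalarisation itself is the standard augmented Chebyshev form; a sketch under assumed sign conventions (maximised metrics negated so smaller is better, `rho_aug` small):

```python
import numpy as np

def augmented_chebyshev(f, utopia, w, rho_aug=0.05):
    """Scalarise objective vector f against an (EMA-tracked) utopia point.

    f, utopia, w : arrays over (clinical, MCC, AUPRC, tier_soft, es_violation),
    with maximised metrics negated so that smaller is always better."""
    d = np.asarray(w) * (np.asarray(f) - np.asarray(utopia))
    return d.max() + rho_aug * d.sum()
```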
## Pipeline overview

```
┌──────────────────────────┐
│ Fundus images (REFUGE, │
│ ORIGA, CHAKSU) │
└──────────────┬───────────┘
│ m2p.feature_extraction (Swin-V2 frozen)
▼
┌──────────────────────────┐
│ Pooled embeddings, │ m2p.ood
│ hidden states, logits │──► MSP / Energy / kNN /
└──────────────┬───────────┘ ViM / Mahalanobis ───┐
│ │
│ vim_risk_z, quality_risk │
▼ ◄─────────────────────────────────────┘
┌──────────────────────────┐
│ data/final_dataset3.csv │
└──────────────┬───────────┘
│ m2p.grouping (Stage 1–3)
▼
┌──────────────────────────┐
│ support_family, │
│ subcluster, group_id │
└──────────────┬───────────┘
│ m2p.adaptive_hpo (TPE + Hyperband)
▼
┌──────────────────────────┐
│ best_params │
└──────────────┬───────────┘
│ m2p.training.train_l2d_multi_expert
▼
┌──────────────────────────┐
│ MPD²-Router │
│ • risk_enc + struct_enc │
│ • OVA expert head │
│ • augmented Lagrangian │
│ • GSDP + Rank-JS reg. │
└──────────────────────────┘
```
## Repository layout

```
M2P/
├── README.md
├── LICENSE ← MIT
├── CITATION.cff ← citation metadata
├── pyproject.toml ← installable package + console scripts
├── requirements.txt ← pinned transitive deps
├── configs/
│ └── best_hpo.yaml ← submission config (HPO trial 50)
├── data/
│ └── final_dataset3.csv ← 3195-row OOD-scored manifest
├── src/m2p/
│ ├── __init__.py
│ ├── feature_extraction.py ← frozen Swin-V2 wrapper
│ ├── ood.py ← KNN / ViM / Mahalanobis + logit scores
│ ├── grouping.py ← three-stage hierarchical bucketing
│ ├── router_training.py ← FULL training pipeline in one file:
│ │ configs, costs, dataset, model,
│ │ priors, losses, AugLag,
│ │ training loop, evaluation
│ └── adaptive_hpo.py ← consolidated Bayesian HPO
│ (v2 + geometric_clip lineages)
├── scripts/
│ ├── extract_features.py ← Swin-V2 feature extraction
│ ├── run_grouping.py ← three-stage bucketing CLI
│ ├── run_hpo.py ← constraint-aware Bayesian study
│ └── train_router.py ← final training + test eval
└── docs/ ← (extended notes / figures)
```
## Installation

```bash
git clone https://github.com/Lyra-stellAI/M2P
cd M2P
python -m venv .venv && source .venv/bin/activate
pip install -e ".[yaml]"
```

A CUDA-enabled GPU is recommended for Swin-V2 feature extraction and router training; the OOD detectors and grouping pipeline run comfortably on CPU.
## Reproducing the submission

The submitted system can be reproduced from the shipped manifest in four steps; each step writes its outputs to disk so later steps can be resumed independently.
Required only if you need to re-derive the OOD scores; the shipped data/final_dataset3.csv already carries vim_risk_z, maha_risk, etc.
```bash
python scripts/extract_features.py \
    --csv data/manifests/refuge.csv \
    --out_dir data/features/refuge
```

```bash
python scripts/run_grouping.py \
    --csv data/final_dataset3.csv \
    --emb data/features/all_pooled_last.npy \
    --out_dir grouping_output
```

This writes `grouping_output/final_dataset3_grouped.csv` (used by the next two steps) and four summary CSVs (support family, family clustering, final groups, fold log).
```bash
python scripts/run_hpo.py \
    --csv grouping_output/final_dataset3_grouped.csv \
    --n_trials 80 \
    --storage sqlite:///hpo.db \
    --ledger hpo_ledger.jsonl
```

Skip this step if you want to reproduce the submission exactly using `configs/best_hpo.yaml`.
```bash
python scripts/train_router.py \
    --csv grouping_output/final_dataset3_grouped.csv \
    --config configs/best_hpo.yaml \
    --out_dir runs/mpd2-router-001
```

Outputs:

- `runs/mpd2-router-001/router.pt` — best-checkpoint state dict.
- `runs/mpd2-router-001/test_metrics.json` — final test-split metrics.
- Console: per-action and per-dataset breakdown via `print_eval_report` and `print_eval_report_dataset`.
## Datasets

| Dataset | Total | Train | Val | Test | Glaucoma rate | Role |
|---|---|---|---|---|---|---|
| REFUGE | 1200 | 400 | 400 | 400 | 10.0 % | In-distribution reference cohort |
| CHAKSU | 1345 | 686 | 323 | 336 | 14.0 % | Near out-of-distribution |
| ORIGA | 650 | 325 | 162 | 163 | 25.8 % | Far out-of-distribution |
| Total | 3195 | 1411 | 885 | 899 | 14.9 % | |
Twelve human experts annotate subsets of the data:
- REFUGE experts (7): `y_refuge_expert_1` … `y_refuge_expert_7`
- CHAKSU experts (5): `y_chaksu_expert_1` … `y_chaksu_expert_5`
The per-row `m_experts` column encodes availability (12-bit), and `m_actions = [1] + m_experts` (13-bit) prepends the always-available AI action.
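The shipped CSV already stores both masks; the sketch below shows how they relate to the expert label columns:

```python
import numpy as np
import pandas as pd

df = pd.read_csv("data/final_dataset3.csv")
expert_cols = ([f"y_refuge_expert_{i}" for i in range(1, 8)]
               + [f"y_chaksu_expert_{i}" for i in range(1, 6)])

# An expert is "available" on a row iff their label column is not NaN.
m_experts = df[expert_cols].notna().to_numpy(dtype=int)        # (N, 12)
m_actions = np.hstack([np.ones((len(df), 1), dtype=int),        # AI slot is always on
                       m_experts])                              # (N, 13)
```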
## Module reference

### m2p.feature_extraction

| Symbol | Purpose |
|---|---|
| `SWINV2_HF_NAME` | HuggingFace model id (`pamixsun/swinv2_tiny_for_glaucoma_classification`). |
| `load_swinv2()` | Build the frozen classifier and image processor. |
| `extract_features(...)` | Returns `{logits, pooled_last, hidden, labels, …}` for an image manifest. |
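A hedged usage sketch; the exact return signature of `load_swinv2` and the argument names of `extract_features` are assumptions, and `scripts/extract_features.py` remains the canonical entry point:

```python
from m2p.feature_extraction import load_swinv2, extract_features

model, processor = load_swinv2()            # frozen Swin-V2 + image processor (assumed)
feats = extract_features(model, processor,
                         manifest_csv="data/manifests/refuge.csv")  # argument name assumed
pooled = feats["pooled_last"]               # embeddings consumed by grouping / the router
```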
### m2p.ood

| Symbol | Purpose |
|---|---|
| `score_msp` / `score_maxlogit` / `score_entropy` / `score_energy` / `score_energy_react` | Logit-based detectors (sign-flipped so higher = more OOD). |
| `KNN_OOD(k=10)` | Distance to the k-th cosine-normalised train neighbour. |
| `ViM_OOD(pca_dim=256)` | Virtual-logit Matching with PCA-residual norm. |
| `Mahalanobis_OOD(layers=(2,3,4))` | Multi-layer Mahalanobis with Ledoit-Wolf shrinkage. |
| `ood_metrics(id, ood, tau=99.0)` | AUROC / AUPR_OOD / FPR@τ. |
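For reference, the logit-based energy detector follows the standard energy-score formula; a self-contained NumPy sketch (the shipped `score_energy` may differ in temperature handling):

```python
import numpy as np

def score_energy(logits: np.ndarray, T: float = 1.0) -> np.ndarray:
    """Energy score E(x) = -T * logsumexp(logits / T); higher = more OOD."""
    z = logits / T
    m = z.max(axis=1, keepdims=True)                     # stabilise the logsumexp
    lse = m.squeeze(1) + np.log(np.exp(z - m).sum(axis=1))
    return -T * lse
```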
### m2p.grouping

| Function | Purpose |
|---|---|
| `build_support_families(...)` | Stage 1 — Hamming-clustered expert masks + micro-family absorption. |
| `choose_k_and_fit(...)` | Stage 2 — silhouette × log₂ K on train embeddings. |
| `_merge_starved_subclusters(...)` | Post-hoc merging of train-starved sub-clusters. |
| `fold_small_groups(...)` | Stage 3 — count + train-fraction folding with audit log. |
| `group_dataset(...)` | End-to-end driver writing `final_dataset3_grouped.csv`. |
### m2p.router_training

The full router-training pipeline lives in a single ~1700-line file organised into nine clearly labelled sections:
| Section | Symbols |
|---|---|
| §1 Configs | TrainConfig, PriorRegConfig, ALConfig |
| §2 Costs / objective | TIER_COST, EXPERT_TIER, action_costs, expert_clinical_cost, ai_expected_clinical_cost, per_action_clinical_cost, l2d_objective |
| §3 Dataset | L2DDataset, prob1_from_logits_np |
| §4 Model | StructuralRisk, MLPBlock, OVAExpertHeadFixed, Router |
| §5 Hierarchical priors | parse_action_mask, prepare_expert_mask, compute_global_prior, compute_family_prior, compute_group_prior, build_prior_tensor, build_all_priors |
| §6 Routing losses | gsdp_loss, rank_majorization_js_loss, combined_routing_loss, truncated_geometric_prior, default_rho_by_k, helpers |
| §7 Augmented Lagrangian | AugLag |
| §8 Evaluation | evaluate, diagnostic_checks, print_eval_report, print_eval_report_dataset |
| §9 Training loop | train_l2d_multi_expert |
Everything is re-exported from the package root, so `from m2p import Router, TrainConfig, train_l2d_multi_expert` works without importing the file explicitly.
### m2p.adaptive_hpo

This module is the single, consolidated HPO driver that supersedes and
merges the earlier adaptive_hpo_v2 and adaptive_hpo_geometric_clip
prototypes; both selection-score strategies are exposed via
selection_score_mode.
| Symbol | Purpose |
|---|---|
| `SELECTION_MODE_AVG`, `SELECTION_MODE_WEIGHTED` | Two checkpoint-selection strategies merged from the ancestor variants. |
| `default_rho_by_k`, `clip_max_from_geom`, `build_geometric_clip_anchors` | Geometric anti-collapse cap utilities. |
| `SearchSpace` | 16-dim declarative search space. |
| `sample_config`, `params_to_configs` | Sampling + materialisation into `(train_cfg, mcfg, al_cfg)`. |
| `ObjectiveWeights`, `UtopiaTracker` | Augmented Chebyshev scalarisation + adaptive utopia. |
| `TrialExecutor`, `make_objective`, `run_hpo`, `run_hpo_from_notebook` | HPO drivers. |
| `pareto_front`, `importance_analysis`, `retrain_with_best`, `print_study_summary` | Post-hoc analysis. |
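A hedged driver sketch; the keyword names are assumptions, and `scripts/run_hpo.py` is the supported entry point:

```python
from m2p.adaptive_hpo import run_hpo, pareto_front

study = run_hpo(csv_path="grouping_output/final_dataset3_grouped.csv",  # name assumed
                n_trials=80,
                storage="sqlite:///hpo.db",
                ledger_path="hpo_ledger.jsonl")                          # name assumed
front = pareto_front(study)  # post-hoc non-dominated trials
```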
## Data schema

`data/final_dataset3.csv` — 3195 rows × 36 columns.
| Column | Type | Description |
|---|---|---|
| `global_id` | str | Stable identifier across datasets. |
| `dataset` | `{refuge,origa,chaksu}` | Source cohort. |
| `y_true` | `{0,1}` | Glaucoma label. |
| `y_*_expert_*` | float / NaN | 12 expert label columns; NaN means the expert did not annotate this image. |
| `m_experts` | list[bool] | 12-element availability mask. |
| `m_actions` | list[int] | 13-element action mask, with the AI slot prepended. |
| `split` | `{train,val,test}` | Split assignment (per-dataset). |
| `is_ood` | `{0,1}` | 1 for ORIGA / CHAKSU. |
| `logit_0`, `logit_1` | float | Frozen Swin-V2 logits. |
| `prob_0`, `prob_1`, `pred`, `confidence`, `uncertainty` | float | Softmax-derived. |
| `ViM`, `Mahalanobis`, `vim_risk`, `maha_risk`, `vim_risk_z` | float | OOD scores; `*_risk` are z-standardised against REFUGE-train. |
| `quality_score`, `quality_risk` | float | Image-quality risk. |
| `vCDR`, `hCDR`, `aCDR` | float | Cup-disc-ratio structural features. |
After m2p.grouping.group_dataset(...) the grouped CSV additionally carries exact_mask, support_family, hamming_to_family_rep, subcluster, group_id.
## Configuration reference

Defaults from `m2p.adaptive_hpo.run_hpo`. Use these as the starting point unless ablating.
| Group | Knob | Default | Range / notes |
|---|---|---|---|
| Train | `lr` | 1e-4 | log-uniform [3e-5, 3e-4] |
| | `gamma_tier` | 1.0 | log-uniform [0.05, 1.5] |
| | `warmup_epochs` | 15 | int [8, 20] |
| | `epochs` | 150 | fixed |
| | `patience` | 18 | fixed (min_delta = 1e-4) |
| | `c_fn` / `c_fp` | 2.0 / 1.5 | study-design constant |
| Prior reg. | `tau_bad` | 1.0 | [0.30, 2.00] |
| | `w_gsdp` | 0.30 | [0.03, 1.00] |
| | `w_rank_js` | 0.30 | [0.03, 1.00] |
| | `global_uniform_mix` | 0.35 | [0.05, 0.40] |
| | `family_uniform_mix` | 0.30 | [0.10, 0.55] |
| | `group_uniform_mix` | 0.30 | [0.10, 0.55] |
| | `family_n0` / `group_n0` | 25 / 30 | Dirichlet pseudo-counts |
| | `clip_ceiling` | 0.35 | [0.30, 0.38] |
| | `clip_slack` | 0.03 | [0.00, 0.06] |
| AugLag | `mu` | 25.0 | [8, 40] |
| | `lr_lambda` | 0.10 | log-uniform [0.02, 0.30] |
| | `max_deferral_rate` | 0.70 | hard deployment constraint |
| Grouping | `support_distance_threshold` | 0.40 | Hamming agglomerative cutoff |
| | `min_family_size` | 10 | absorbs micro-families |
| | `min_train_per_cluster` | 20 | KMeans hard floor |
| | `min_group_train` | 20 | Stage-3 absolute fold threshold |
| | `min_train_frac` | 0.20 | Stage-3 fraction fold threshold |
## Reproducibility

- Determinism. Each HPO trial seeds `torch.manual_seed`, `numpy.random.seed`, and `torch.cuda.manual_seed_all` with `seed + trial.number`, and DataLoader shuffling uses a per-trial generator.
- Seed config. `run_hpo` enqueues a known-good seed configuration so the very first trial reproduces the notebook's manual baseline.
- Trial ledger. Every trial appends a JSON line with the full `(params, train_cfg, mcfg, al_cfg, metrics, scalar_score, utopia, elapsed_s)` to `hpo_ledger.jsonl`. Together with `optuna.create_study(storage="sqlite:///hpo.db", load_if_exists=True)` this lets you resume / re-analyse studies.
- Submission config. `configs/best_hpo.yaml` materialises the exact parameters of HPO trial 50, including the geometric clip anchors `{5: 0.3623, 7: 0.3263, 12: 0.3006}`.
- Early NaN guards. Both the per-batch loss and the final metric vector are NaN-checked; offending trials are pruned rather than silently corrupting the study.
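A minimal sketch of the per-trial seeding described above (the helper name is illustrative):

```python
import numpy as np
import torch

def seed_trial(seed: int, trial_number: int) -> torch.Generator:
    s = seed + trial_number
    torch.manual_seed(s)
    np.random.seed(s)
    torch.cuda.manual_seed_all(s)               # no-op without CUDA
    return torch.Generator().manual_seed(s)     # pass as DataLoader(..., generator=...)
```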
## Citation

If you use this code or build on the methodology, please cite this repository; a CITATION.cff is provided for GitHub's "Cite this repository" link.
## License

This project is released under the MIT License.