landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Curriculum learning support for progressive training difficulty.
|
|
2
|
+
|
|
3
|
+
Implements a schedule that controls which training samples are used
|
|
4
|
+
at different stages of training, starting with easy examples (small
|
|
5
|
+
displacements) and gradually introducing harder ones.
|
|
6
|
+
|
|
7
|
+
Usage in training loop::
|
|
8
|
+
|
|
9
|
+
curriculum = TrainingCurriculum(
|
|
10
|
+
total_steps=100000,
|
|
11
|
+
warmup_fraction=0.1, # first 10% easy only
|
|
12
|
+
full_difficulty_at=0.5, # full dataset by 50%
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# In training loop:
|
|
16
|
+
difficulty = curriculum.get_difficulty(global_step)
|
|
17
|
+
# Use difficulty to filter/weight samples
|
|
18
|
+
|
|
19
|
+
Or as a dataset wrapper (NOTE: ``CurriculumDataset`` is not defined in this module)::
|
|
20
|
+
|
|
21
|
+
dataset = CurriculumDataset(
|
|
22
|
+
base_dataset=SyntheticPairDataset(data_dir),
|
|
23
|
+
metadata_path=Path(data_dir) / "metadata.json",
|
|
24
|
+
total_steps=100000,
|
|
25
|
+
)
|
|
26
|
+
# Call dataset.set_step(global_step) each iteration
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import json
|
|
32
|
+
import math
|
|
33
|
+
from pathlib import Path
|
|
34
|
+
|
|
35
|
+
import numpy as np
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class TrainingCurriculum:
    """Schedule that maps a training step to a difficulty level in [0, 1].

    Difficulty 0 = easiest (smallest displacements, lowest intensity).
    Difficulty 1 = full dataset (all difficulties).

    The schedule uses a cosine ramp:
    - During warmup: difficulty = 0 (easy only)
    - warmup -> full_difficulty: cosine ramp from 0 -> 1
    - After full_difficulty: difficulty = 1 (full dataset)

    Args:
        total_steps: Total number of training steps the schedule spans.
        warmup_fraction: Fraction of ``total_steps`` spent at difficulty 0.
        full_difficulty_at: Fraction of ``total_steps`` at which difficulty
            reaches 1.
    """

    def __init__(
        self,
        total_steps: int,
        warmup_fraction: float = 0.1,
        full_difficulty_at: float = 0.5,
    ):
        self.total_steps = total_steps
        # Both boundaries are truncated to whole steps.
        self.warmup_steps = int(total_steps * warmup_fraction)
        self.full_steps = int(total_steps * full_difficulty_at)

    def get_difficulty(self, step: int) -> float:
        """Get difficulty level [0, 1] for the given training step."""
        if step < self.warmup_steps:
            return 0.0
        if step >= self.full_steps:
            return 1.0
        # max(1, ...) keeps the division safe for a degenerate ramp window.
        ramp_len = max(1, self.full_steps - self.warmup_steps)
        progress = (step - self.warmup_steps) / ramp_len
        return 0.5 * (1 - math.cos(math.pi * progress))

    def should_include(
        self,
        step: int,
        sample_difficulty: float,
        rng: np.random.Generator | None = None,
    ) -> bool:
        """Whether to include a sample of the given difficulty at this step.

        Samples at or below the current difficulty are always included.
        Harder samples are included with probability
        ``max(0, 1 - 5 * overshoot)``, where ``overshoot`` is how far the
        sample lies above the current difficulty.

        Args:
            step: Current training step.
            sample_difficulty: Difficulty of the sample [0, 1].
            rng: Random number generator for stochastic inclusion. When
                omitted, a fresh ``np.random.default_rng()`` is used for the
                draw.

        Returns:
            True if sample should be used.
        """
        curr_difficulty = self.get_difficulty(step)
        if sample_difficulty <= curr_difficulty:
            return True

        # Stochastic inclusion for samples slightly above threshold.
        overshoot = sample_difficulty - curr_difficulty
        include_prob = max(0.0, 1.0 - overshoot * 5)  # drops off quickly

        if rng is None:
            # With probability 0 the outcome is fixed (a draw from [0, 1) is
            # never below 0), so skip building a throwaway generator.
            if include_prob <= 0.0:
                return False
            rng = np.random.default_rng()
        # A caller-supplied generator always draws exactly once so that
        # seeded runs consume the same random stream as before.
        return bool(rng.random() < include_prob)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ProcedureCurriculum:
    """Procedure-aware curriculum that adjusts per-procedure weights.

    Some procedures are inherently harder (e.g., orthognathic with large
    deformations). This curriculum increases their weight over training.

    Args:
        total_steps: Total number of training steps; forwarded to the
            underlying ``TrainingCurriculum``.
        procedure_difficulty: Optional mapping of procedure name to a
            difficulty in [0, 1]. When omitted (or empty), a private copy of
            ``DEFAULT_PROCEDURE_DIFFICULTY`` is used.
        warmup_fraction: Warmup fraction forwarded to the underlying
            ``TrainingCurriculum``.
    """

    # Difficulty ranking (0=easiest, 1=hardest)
    DEFAULT_PROCEDURE_DIFFICULTY = {
        "blepharoplasty": 0.3,  # small, localized changes
        "rhinoplasty": 0.5,  # moderate, central face
        "rhytidectomy": 0.7,  # large, affects face shape
        "orthognathic": 0.9,  # largest deformations
    }

    def __init__(
        self,
        total_steps: int,
        procedure_difficulty: dict[str, float] | None = None,
        warmup_fraction: float = 0.1,
    ):
        self.curriculum = TrainingCurriculum(total_steps, warmup_fraction)
        if procedure_difficulty:
            self.proc_difficulty = procedure_difficulty
        else:
            # Copy the class-level defaults so that mutating this instance's
            # table cannot leak into other instances or the class itself.
            self.proc_difficulty = dict(self.DEFAULT_PROCEDURE_DIFFICULTY)

    def get_weight(self, step: int, procedure: str) -> float:
        """Get sampling weight for a procedure at the given step.

        Procedures missing from the difficulty table are treated as
        difficulty 0.5.

        Returns a value in [0.1, 1.0] — never fully excludes any procedure.
        """
        difficulty = self.get_difficulty(step)
        proc_diff = self.proc_difficulty.get(procedure, 0.5)

        if proc_diff <= difficulty:
            return 1.0
        # Reduce weight for too-hard procedures: lose 2x the gap, floor 0.1.
        return max(0.1, 1.0 - (proc_diff - difficulty) * 2)

    def get_difficulty(self, step: int) -> float:
        """Return the underlying schedule's difficulty for ``step``."""
        return self.curriculum.get_difficulty(step)

    def get_procedure_weights(self, step: int) -> dict[str, float]:
        """Get all procedure weights at the given step."""
        return {proc: self.get_weight(step, proc) for proc in self.proc_difficulty}
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def compute_sample_difficulty(
    metadata_path: str | Path,
    displacement_model_path: str | Path | None = None,
) -> dict[str, float]:
    """Compute difficulty scores for each sample in the dataset.

    Difficulty is the sum of three terms, clamped to [0, 1]:
    1. A per-procedure base score (unlisted procedures score 0.5)
    2. A source-type bonus (real > synthetic; unlisted sources add 0)
    3. An intensity term, ``clamp(intensity / 1.5, 0, 1) * 0.2``

    Args:
        metadata_path: Path to a JSON file whose top-level ``"pairs"`` object
            maps sample prefix to a dict that may carry ``"procedure"``,
            ``"source"`` and ``"intensity"`` entries.
        displacement_model_path: Accepted for interface compatibility; not
            used by this function.

    Returns:
        Dict mapping sample prefix to difficulty score [0, 1].
    """
    with open(metadata_path, encoding="utf-8") as f:
        meta = json.load(f)

    pairs = meta.get("pairs", {})
    difficulties = {}

    proc_base = {
        "blepharoplasty": 0.2,
        "rhinoplasty": 0.4,
        "rhytidectomy": 0.6,
        "orthognathic": 0.8,
        "unknown": 0.5,
    }

    source_bonus = {
        "synthetic": 0.0,
        "synthetic_v3": 0.1,  # realistic displacements slightly harder
        "real": 0.2,  # real data hardest
        "augmented": 0.0,
    }

    for prefix, info in pairs.items():
        proc = info.get("procedure", "unknown")
        source = info.get("source", "synthetic")
        intensity = info.get("intensity")
        if intensity is None:
            # A missing key and an explicit JSON null both mean "default".
            intensity = 1.0

        # Combine factors
        base = proc_base.get(proc, 0.5)
        src = source_bonus.get(source, 0.0)
        # Intensity scaling (higher intensity = harder). The lower clamp keeps
        # a negative intensity from dragging the score below the base value.
        int_factor = min(1.0, max(0.0, float(intensity) / 1.5)) * 0.2

        difficulties[prefix] = min(1.0, max(0.0, base + src + int_factor))

    return difficulties
|
landmarkdiff/data.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
1
|
+
"""Reusable data loading utilities for LandmarkDiff training and evaluation.
|
|
2
|
+
|
|
3
|
+
Provides PyTorch Dataset implementations for loading synthetic training pairs,
|
|
4
|
+
manifest-based datasets, and evaluation datasets. Extracted from the training
|
|
5
|
+
script for reuse across training, evaluation, and testing pipelines.
|
|
6
|
+
|
|
7
|
+
Usage::
|
|
8
|
+
|
|
9
|
+
from landmarkdiff.data import SurgicalPairDataset, create_dataloader
|
|
10
|
+
|
|
11
|
+
dataset = SurgicalPairDataset("data/training_combined", resolution=512)
|
|
12
|
+
loader = create_dataloader(dataset, batch_size=4, num_workers=4)
|
|
13
|
+
|
|
14
|
+
for batch in loader:
|
|
15
|
+
input_img = batch["input"] # (B, 3, H, W) RGB [0,1]
|
|
16
|
+
target_img = batch["target"] # (B, 3, H, W) RGB [0,1]
|
|
17
|
+
conditioning = batch["conditioning"] # (B, 3, H, W) RGB [0,1]
|
|
18
|
+
mask = batch["mask"] # (B, 1, H, W) [0,1]
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import csv
|
|
24
|
+
import json
|
|
25
|
+
import logging
|
|
26
|
+
from collections.abc import Callable
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
|
|
29
|
+
import cv2
|
|
30
|
+
import numpy as np
|
|
31
|
+
import torch
|
|
32
|
+
from torch.utils.data import DataLoader, Dataset, Sampler, WeightedRandomSampler
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ---------------------------------------------------------------------------
|
|
38
|
+
# Core dataset
|
|
39
|
+
# ---------------------------------------------------------------------------
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class SurgicalPairDataset(Dataset):
    """Dataset for loading surgical before/after training pairs.

    Each sample has four components:
    - input: original face image (before surgery)
    - target: modified face image (after surgery)
    - conditioning: 3-channel landmark mesh visualization
    - mask: surgical region mask (soft float)

    Supports loading from a flat directory of ``{prefix}_input.png`` files
    or from a manifest CSV.

    Args:
        data_dir: Directory containing training pair images.
        resolution: Target image resolution (square).
        manifest_path: Optional CSV with columns [prefix, procedure, ...].
            If None, auto-discovers pairs from ``*_input.png`` files.
        transform: Optional callable for custom augmentation. Receives and
            returns a dict with numpy arrays.

    Raises:
        FileNotFoundError: If no ``*_input.png`` pair is found.
    """

    def __init__(
        self,
        data_dir: str | Path,
        resolution: int = 512,
        manifest_path: str | Path | None = None,
        transform: Callable[[dict], dict] | None = None,
    ):
        self.data_dir = Path(data_dir)
        self.resolution = resolution
        self.transform = transform

        # Discover pairs
        if manifest_path is not None:
            self.pairs, self.metadata = self._load_manifest(Path(manifest_path))
        else:
            self.pairs = sorted(self.data_dir.glob("*_input.png"))
            self.metadata = self._load_metadata()

        if not self.pairs:
            raise FileNotFoundError(f"No training pairs found in {data_dir}")

        logger.info("Loaded %d training pairs from %s", len(self.pairs), data_dir)

    def _load_manifest(self, path: Path) -> tuple[list[Path], dict[str, dict]]:
        """Load pairs from a manifest CSV.

        The prefix is read from the ``prefix`` column, falling back to
        ``name``. Rows whose ``{prefix}_input.png`` is missing from
        ``data_dir`` are skipped.

        Returns:
            ``(pairs, metadata)``: the input-image paths that exist, and the
            full CSV row of each kept sample keyed by prefix.
        """
        pairs = []
        metadata = {}
        skipped = 0
        # newline="" is what the csv module asks for so that newlines embedded
        # in quoted fields are parsed correctly.
        with open(path, newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                prefix = row.get("prefix", row.get("name", ""))
                input_path = self.data_dir / f"{prefix}_input.png"
                if input_path.exists():
                    pairs.append(input_path)
                    metadata[prefix] = dict(row)
                else:
                    skipped += 1
        if skipped:
            logger.warning(
                "Skipped %d manifest rows from %s with no matching input image",
                skipped,
                path,
            )
        return pairs, metadata

    def _load_metadata(self) -> dict[str, dict]:
        """Load metadata from metadata.json if present.

        Returns the file's top-level ``"pairs"`` object, or an empty dict when
        the file is absent, unreadable, malformed, or not shaped as expected.
        """
        meta_path = self.data_dir / "metadata.json"
        if not meta_path.exists():
            return {}
        try:
            with open(meta_path, encoding="utf-8") as f:
                data = json.load(f)
        except (ValueError, OSError):
            # ValueError covers both JSONDecodeError and undecodable bytes.
            logger.debug("Failed to load metadata from %s", meta_path)
            return {}
        # Metadata is optional, so an unexpected JSON shape degrades to "none"
        # instead of raising AttributeError.
        result = data.get("pairs", {}) if isinstance(data, dict) else {}
        return result if isinstance(result, dict) else {}

    def get_procedure(self, idx: int) -> str:
        """Get the surgical procedure type for a sample ("unknown" if unset)."""
        prefix = self._prefix(idx)
        info = self.metadata.get(prefix, {})
        proc: str = info.get("procedure", "unknown")
        return proc

    def get_procedures(self) -> list[str]:
        """Get procedure types for all samples."""
        return [self.get_procedure(i) for i in range(len(self))]

    def _prefix(self, idx: int) -> str:
        """Return the sample prefix: the input filename stem minus ``_input``."""
        stem = self.pairs[idx].stem
        # Strip only the trailing marker. A blanket str.replace would also
        # mangle prefixes that themselves contain "_input".
        if stem.endswith("_input"):
            return stem[: -len("_input")]
        return stem

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        """Load one sample and return it as a dict of tensors plus labels.

        Keys: ``input``, ``target``, ``conditioning``, ``mask``,
        ``procedure`` and ``idx``.
        """
        prefix = self._prefix(idx)

        # Load images as BGR uint8
        input_bgr = self._load_image(f"{prefix}_input.png")
        target_bgr = self._load_image(f"{prefix}_target.png")
        cond_bgr = self._load_image(f"{prefix}_conditioning.png")
        mask_arr = self._load_mask(f"{prefix}_mask.png")

        sample = {
            "input_image": input_bgr,
            "target_image": target_bgr,
            "conditioning": cond_bgr,
            "mask": mask_arr,
            "procedure": self.get_procedure(idx),
            "idx": idx,
        }

        # Apply custom transform
        if self.transform is not None:
            sample = self.transform(sample)

        # Convert to tensors
        return {
            "input": bgr_to_tensor(sample["input_image"]),
            "target": bgr_to_tensor(sample["target_image"]),
            "conditioning": bgr_to_tensor(sample["conditioning"]),
            "mask": mask_to_tensor(sample["mask"]),
            "procedure": sample["procedure"],
            "idx": sample["idx"],
        }

    def _load_image(self, filename: str) -> np.ndarray:
        """Load an image as BGR uint8, resized to resolution.

        An unreadable file is logged and replaced by an all-zero image.
        """
        path = self.data_dir / filename
        img = cv2.imread(str(path))
        if img is None:
            logger.warning("Failed to load %s, using blank", path)
            return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
        if img.shape[:2] != (self.resolution, self.resolution):
            img = cv2.resize(img, (self.resolution, self.resolution))
        return img

    def _load_mask(self, filename: str) -> np.ndarray:
        """Load a mask as float32 [0,1], resized to resolution.

        A missing or unreadable mask file yields an all-ones mask.
        """
        path = self.data_dir / filename
        if not path.exists():
            return np.ones((self.resolution, self.resolution), dtype=np.float32)
        mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            return np.ones((self.resolution, self.resolution), dtype=np.float32)
        if mask.shape[:2] != (self.resolution, self.resolution):
            mask = cv2.resize(mask, (self.resolution, self.resolution))
        return mask.astype(np.float32) / 255.0
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
# ---------------------------------------------------------------------------
|
|
187
|
+
# Evaluation dataset (input + ground truth)
|
|
188
|
+
# ---------------------------------------------------------------------------
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class EvalPairDataset(Dataset):
    """Dataset for evaluation: loads input/target pairs with procedure labels.

    Pairs are discovered from ``*_input.png`` files in ``data_dir``. Procedure
    labels come from the ``"pairs"`` object of an optional ``metadata.json``
    in the same directory.

    Args:
        data_dir: Directory with evaluation pairs.
        resolution: Target resolution.
    """

    def __init__(self, data_dir: str | Path, resolution: int = 512):
        self.data_dir = Path(data_dir)
        self.resolution = resolution
        self.pairs = sorted(self.data_dir.glob("*_input.png"))

        # Load metadata
        meta_path = self.data_dir / "metadata.json"
        self._meta = {}
        if meta_path.exists():
            try:
                with open(meta_path, encoding="utf-8") as f:
                    loaded = json.load(f)
            except (ValueError, OSError):
                # ValueError covers both JSONDecodeError and undecodable bytes.
                logger.debug("Failed to load metadata from %s", meta_path)
            else:
                # Metadata is optional, so an unexpected JSON shape is ignored
                # instead of raising AttributeError.
                pairs_meta = loaded.get("pairs", {}) if isinstance(loaded, dict) else {}
                if isinstance(pairs_meta, dict):
                    self._meta = pairs_meta

    def __len__(self) -> int:
        return len(self.pairs)

    def _prefix(self, idx: int) -> str:
        """Return the sample prefix: the input filename stem minus ``_input``."""
        stem = self.pairs[idx].stem
        # Strip only the trailing marker. A blanket str.replace would also
        # mangle prefixes that themselves contain "_input".
        if stem.endswith("_input"):
            return stem[: -len("_input")]
        return stem

    def __getitem__(self, idx: int) -> dict:
        """Return ``input``/``target`` tensors with ``procedure`` and ``prefix``."""
        prefix = self._prefix(idx)

        input_img = self._load(f"{prefix}_input.png")
        target_img = self._load(f"{prefix}_target.png")

        info = self._meta.get(prefix, {})
        procedure = info.get("procedure", "unknown")

        return {
            "input": bgr_to_tensor(input_img),
            "target": bgr_to_tensor(target_img),
            "procedure": procedure,
            "prefix": prefix,
        }

    def _load(self, filename: str) -> np.ndarray:
        """Load an image as BGR uint8, resized to resolution.

        An unreadable file is logged and replaced by an all-zero image.
        """
        path = self.data_dir / filename
        img = cv2.imread(str(path))
        if img is None:
            # Logged so a blank stand-in does not silently skew evaluation.
            logger.warning("Failed to load %s, using blank", path)
            return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
        if img.shape[:2] != (self.resolution, self.resolution):
            img = cv2.resize(img, (self.resolution, self.resolution))
        return img
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# ---------------------------------------------------------------------------
|
|
244
|
+
# Conversion utilities
|
|
245
|
+
# ---------------------------------------------------------------------------
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def bgr_to_tensor(bgr: np.ndarray) -> torch.Tensor:
    """Turn an (H, W, C) BGR uint8 image into a (C, H, W) RGB tensor in [0, 1].

    The channel axis is reversed (BGR -> RGB), values are converted to
    float32 and divided by 255, and the result is permuted to channel-first.
    """
    channels_reversed = bgr[:, :, ::-1]
    scaled = channels_reversed.astype(np.float32) / 255.0
    # The reversed view is made contiguous before it is handed to torch.
    height_width_channel = torch.from_numpy(np.ascontiguousarray(scaled))
    return height_width_channel.permute(2, 0, 1)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def tensor_to_bgr(tensor: torch.Tensor) -> np.ndarray:
    """Convert RGB [0,1] tensor (C, H, W) to BGR uint8 image (H, W, C).

    Values are clamped to [0, 1], scaled by 255 and rounded to the nearest
    integer, so an image converted uint8 -> tensor -> uint8 comes back
    unchanged.

    Args:
        tensor: Channel-first RGB tensor. It is detached, moved to the CPU
            and cast to float32 before conversion.

    Returns:
        C-contiguous uint8 array with channels in BGR order.
    """
    # Casting to float32 first lets half/bfloat16 tensors reach .numpy().
    rgb = tensor.detach().cpu().float().clamp(0, 1).permute(1, 2, 0).numpy()
    # Round rather than truncate: a bare astype(uint8) maps 254.99 to 254.
    bgr = np.rint(rgb[:, :, ::-1] * 255.0).astype(np.uint8)
    return np.ascontiguousarray(bgr)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def mask_to_tensor(mask: np.ndarray) -> torch.Tensor:
    """Convert a mask array (H, W) to a tensor (1, H, W).

    A 3-D input is reduced to its first channel. The array's dtype is kept
    as-is; callers in this module pass float32 masks.

    Args:
        mask: 2-D mask, or 3-D (H, W, C) array whose channel 0 is the mask.

    Returns:
        Tensor with a leading singleton channel dimension.
    """
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    # torch.from_numpy rejects negatively-strided arrays (e.g. a mask flipped
    # with [::-1] by an augmentation transform). ascontiguousarray copies only
    # when needed, matching what bgr_to_tensor does for images.
    return torch.from_numpy(np.ascontiguousarray(mask)).unsqueeze(0)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
# ---------------------------------------------------------------------------
|
|
269
|
+
# Samplers
|
|
270
|
+
# ---------------------------------------------------------------------------
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def create_procedure_sampler(
    dataset: SurgicalPairDataset,
    balance_procedures: bool = True,
) -> Sampler | None:
    """Create a weighted sampler that balances procedure types.

    Each sample is weighted by ``total / (num_procedures * count)``, where
    ``count`` is the number of samples sharing its procedure, so every
    procedure receives the same total weight.

    Args:
        dataset: Dataset exposing ``get_procedures()`` and ``__len__``.
        balance_procedures: Set to False to disable balancing.

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(dataset)`` samples with
        replacement, or None if balancing is disabled or all procedures are
        the same.
    """
    # Function-scope import: this block does not assume the file's import
    # section already provides Counter.
    from collections import Counter

    if not balance_procedures:
        return None

    procedures = dataset.get_procedures()
    # One pass over the labels instead of a list.count() per procedure.
    counts = Counter(procedures)

    if len(counts) <= 1:
        return None

    total = len(procedures)
    num_procs = len(counts)

    # Weight inversely proportional to count
    weights = [total / (num_procs * counts[proc]) for proc in procedures]

    return WeightedRandomSampler(
        weights=weights,
        num_samples=len(dataset),
        replacement=True,
    )
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
# ---------------------------------------------------------------------------
|
|
308
|
+
# DataLoader factory
|
|
309
|
+
# ---------------------------------------------------------------------------
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def create_dataloader(
    dataset: Dataset,
    batch_size: int = 4,
    num_workers: int = 4,
    shuffle: bool = True,
    sampler: Sampler | None = None,
    pin_memory: bool = True,
    drop_last: bool = True,
    persistent_workers: bool = False,
) -> DataLoader:
    """Build a training-oriented DataLoader around ``dataset``.

    Args:
        dataset: PyTorch Dataset.
        batch_size: Batch size.
        num_workers: Number of data loading workers.
        shuffle: Shuffle data (ignored if sampler is provided).
        sampler: Custom sampler (e.g., from create_procedure_sampler).
        pin_memory: Pin memory for faster GPU transfer; only honoured when
            CUDA is available.
        drop_last: Drop last incomplete batch.
        persistent_workers: Keep workers alive between epochs; only honoured
            when ``num_workers`` is positive.

    Returns:
        Configured DataLoader.
    """
    # A sampler dictates the ordering itself, so shuffling is switched off
    # whenever one is supplied (the two options are mutually exclusive).
    effective_shuffle = shuffle if sampler is None else False

    loader_options = {
        "batch_size": batch_size,
        "shuffle": effective_shuffle,
        "sampler": sampler,
        "num_workers": num_workers,
        "pin_memory": pin_memory and torch.cuda.is_available(),
        "drop_last": drop_last,
        "persistent_workers": persistent_workers and num_workers > 0,
    }
    return DataLoader(dataset, **loader_options)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
# ---------------------------------------------------------------------------
|
|
353
|
+
# Multi-directory dataset
|
|
354
|
+
# ---------------------------------------------------------------------------
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class CombinedDataset(Dataset):
    """Combine multiple SurgicalPairDatasets into one.

    Useful for combining synthetic v1, v2, v3 data and real pairs. Samples
    are indexed in the order the datasets are given: all of the first
    dataset, then all of the second, and so on.

    Args:
        datasets: List of SurgicalPairDataset instances.
    """

    def __init__(self, datasets: list[SurgicalPairDataset]):
        self.datasets = datasets
        # Running totals: entry i is the number of samples in datasets[0..i].
        self._cumulative_sizes = []
        total = 0
        for ds in datasets:
            total += len(ds)
            self._cumulative_sizes.append(total)

    def __len__(self) -> int:
        return self._cumulative_sizes[-1] if self._cumulative_sizes else 0

    def _locate(self, idx: int) -> tuple[int, int]:
        """Map a global index to ``(dataset_idx, local_idx)``.

        Raises:
            IndexError: If ``idx`` is negative or not below ``len(self)``.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"CombinedDataset index {idx} out of range [0, {len(self)})")
        start = 0  # global index of the current dataset's first sample
        for dataset_idx, end in enumerate(self._cumulative_sizes):
            if idx < end:
                return dataset_idx, idx - start
            start = end
        # Unreachable after the bounds check; kept so the failure mode stays
        # an IndexError should the invariant ever be broken.
        raise IndexError(f"CombinedDataset index {idx} out of range [0, {len(self)})")

    def __getitem__(self, idx: int) -> dict:
        """Return the sample at global index ``idx`` from its owning dataset."""
        dataset_idx, local_idx = self._locate(idx)
        return self.datasets[dataset_idx][local_idx]

    def get_procedure(self, idx: int) -> str:
        """Return the procedure label of the sample at global index ``idx``."""
        dataset_idx, local_idx = self._locate(idx)
        return self.datasets[dataset_idx].get_procedure(local_idx)

    def get_procedures(self) -> list[str]:
        """Return procedure labels for every sample, in global index order."""
        procs = []
        for ds in self.datasets:
            procs.extend(ds.get_procedures())
        return procs
|