landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,191 @@
1
+ """Curriculum learning support for progressive training difficulty.
2
+
3
+ Implements a schedule that controls which training samples are used
4
+ at different stages of training, starting with easy examples (small
5
+ displacements) and gradually introducing harder ones.
6
+
7
+ Usage in training loop::
8
+
9
+ curriculum = TrainingCurriculum(
10
+ total_steps=100000,
11
+ warmup_fraction=0.1, # first 10% easy only
12
+ full_difficulty_at=0.5, # full dataset by 50%
13
+ )
14
+
15
+ # In training loop:
16
+ difficulty = curriculum.get_difficulty(global_step)
17
+ # Use difficulty to filter/weight samples
18
+
19
+ Or as a dataset wrapper::
20
+
21
+ dataset = CurriculumDataset(
22
+ base_dataset=SyntheticPairDataset(data_dir),
23
+ metadata_path=Path(data_dir) / "metadata.json",
24
+ total_steps=100000,
25
+ )
26
+ # Call dataset.set_step(global_step) each iteration
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ import json
32
+ import math
33
+ from pathlib import Path
34
+
35
+ import numpy as np
36
+
37
+
38
class TrainingCurriculum:
    """Map a global training step onto a difficulty level in [0, 1].

    A difficulty of 0 means only the easiest samples (smallest displacements,
    lowest intensity); a difficulty of 1 means the full dataset.

    The schedule is a cosine ramp:

    * ``step < warmup_steps``: 0.0 (easy only)
    * ``warmup_steps <= step < full_steps``: ``0.5 * (1 - cos(pi * progress))``
      where ``progress`` goes linearly from 0 to 1 across the ramp
    * ``step >= full_steps``: 1.0 (full dataset)

    Args:
        total_steps: Total number of training steps.
        warmup_fraction: Fraction of ``total_steps`` spent at difficulty 0.
        full_difficulty_at: Fraction of ``total_steps`` at which difficulty
            reaches 1.

    Both step thresholds are ``total_steps * fraction`` truncated to ``int``.
    """

    def __init__(
        self,
        total_steps: int,
        warmup_fraction: float = 0.1,
        full_difficulty_at: float = 0.5,
    ):
        warmup_end = int(total_steps * warmup_fraction)
        ramp_end = int(total_steps * full_difficulty_at)
        self.total_steps = total_steps
        self.warmup_steps = warmup_end
        self.full_steps = ramp_end

    def get_difficulty(self, step: int) -> float:
        """Return the curriculum difficulty in [0, 1] reached at ``step``."""
        lo, hi = self.warmup_steps, self.full_steps
        if step < lo:
            return 0.0
        if step >= hi:
            return 1.0
        # Guard against a zero-length ramp when the two thresholds coincide.
        frac = (step - lo) / max(1, hi - lo)
        return (1 - math.cos(math.pi * frac)) * 0.5

    def should_include(
        self,
        step: int,
        sample_difficulty: float,
        rng: np.random.Generator | None = None,
    ) -> bool:
        """Decide whether a sample of a given difficulty is used at ``step``.

        Samples at or below the current difficulty are always included. Harder
        samples are included at random, with a probability that falls linearly
        from 1 to 0 as the overshoot grows from 0 to 0.2.

        Args:
            step: Current training step.
            sample_difficulty: Difficulty of the sample, in [0, 1].
            rng: Random generator for the stochastic decision. A fresh
                unseeded ``np.random.default_rng()`` is used when omitted.

        Returns:
            True if the sample should be used.
        """
        threshold = self.get_difficulty(step)
        if sample_difficulty <= threshold:
            return True
        generator = np.random.default_rng() if rng is None else rng
        prob = max(0, 1.0 - (sample_difficulty - threshold) * 5)
        return generator.random() < prob
96
+
97
+
98
class ProcedureCurriculum:
    """Procedure-aware curriculum that adjusts per-procedure sampling weights.

    Some procedures are inherently harder (e.g. orthognathic, with large
    deformations). Each procedure has a difficulty score. While the global
    curriculum difficulty is below that score, the procedure is down-weighted.
    Once the curriculum catches up, the procedure is sampled at full weight.

    Args:
        total_steps: Total number of training steps.
        procedure_difficulty: Optional mapping from procedure name to
            difficulty in [0, 1]. A falsy value (``None`` or an empty dict)
            falls back to a copy of ``DEFAULT_PROCEDURE_DIFFICULTY``.
        warmup_fraction: Fraction of training spent at difficulty 0.
        full_difficulty_at: Fraction of training at which the curriculum
            difficulty reaches 1 (default 0.5, the value used before this
            parameter existed).
    """

    # Difficulty ranking (0=easiest, 1=hardest)
    DEFAULT_PROCEDURE_DIFFICULTY = {
        "blepharoplasty": 0.3,  # small, localized changes
        "rhinoplasty": 0.5,  # moderate, central face
        "rhytidectomy": 0.7,  # large, affects face shape
        "orthognathic": 0.9,  # largest deformations
    }

    def __init__(
        self,
        total_steps: int,
        procedure_difficulty: dict[str, float] | None = None,
        warmup_fraction: float = 0.1,
        full_difficulty_at: float = 0.5,
    ):
        self.curriculum = TrainingCurriculum(total_steps, warmup_fraction, full_difficulty_at)
        # Copy the class-level default. Otherwise mutating one instance's table
        # would silently rewrite DEFAULT_PROCEDURE_DIFFICULTY for every instance.
        self.proc_difficulty = procedure_difficulty or dict(self.DEFAULT_PROCEDURE_DIFFICULTY)

    def get_weight(self, step: int, procedure: str) -> float:
        """Get the sampling weight for a procedure at the given step.

        Returns 1.0 once the curriculum difficulty has reached the procedure's
        difficulty. Otherwise returns ``1 - 2 * gap``, floored at 0.1. Unknown
        procedures are treated as difficulty 0.5.

        Returns a value in [0.1, 1.0] — never fully excludes any procedure.
        """
        difficulty = self.get_difficulty(step)
        proc_diff = self.proc_difficulty.get(procedure, 0.5)

        if proc_diff <= difficulty:
            return 1.0
        # Reduce weight for too-hard procedures
        return max(0.1, 1.0 - (proc_diff - difficulty) * 2)

    def get_difficulty(self, step: int) -> float:
        """Return the underlying curriculum difficulty in [0, 1] at ``step``."""
        return self.curriculum.get_difficulty(step)

    def get_procedure_weights(self, step: int) -> dict[str, float]:
        """Get the weights of all configured procedures at the given step."""
        return {proc: self.get_weight(step, proc) for proc in self.proc_difficulty}
141
+
142
+
143
def compute_sample_difficulty(
    metadata_path: str | Path,
    displacement_model_path: str | Path | None = None,
) -> dict[str, float]:
    """Compute a difficulty score in [0, 1] for each sample in the dataset.

    The score is the sum of three components, capped at 1.0:

    1. Procedure base difficulty: blepharoplasty 0.2, rhinoplasty 0.4,
       rhytidectomy 0.6, orthognathic 0.8; unknown or unlisted procedures 0.5.
    2. Source type (real > synthetic): ``real`` +0.2, ``synthetic_v3`` +0.1,
       everything else +0.0.
    3. Displacement intensity: ``clamp(intensity / 1.5, 0, 1) * 0.2``. A
       missing or null intensity counts as 1.0.

    Args:
        metadata_path: Path to a JSON file whose ``"pairs"`` entry maps each
            sample prefix to an info dict with optional ``procedure``,
            ``source`` and ``intensity`` keys.
        displacement_model_path: Currently unused; accepted for API
            compatibility.

    Returns:
        Dict mapping sample prefix to difficulty score in [0, 1].

    Raises:
        OSError: If the metadata file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON text is UTF-8 by specification; don't depend on the locale.
    with open(metadata_path, encoding="utf-8") as f:
        meta = json.load(f)

    pairs = meta.get("pairs", {})
    difficulties: dict[str, float] = {}

    proc_base = {
        "blepharoplasty": 0.2,
        "rhinoplasty": 0.4,
        "rhytidectomy": 0.6,
        "orthognathic": 0.8,
        "unknown": 0.5,
    }

    source_bonus = {
        "synthetic": 0.0,
        "synthetic_v3": 0.1,  # realistic displacements slightly harder
        "real": 0.2,  # real data hardest
        "augmented": 0.0,
    }

    for prefix, info in pairs.items():
        proc = info.get("procedure", "unknown")
        source = info.get("source", "synthetic")
        intensity = info.get("intensity", 1.0)
        if intensity is None:  # JSON null -> same as missing
            intensity = 1.0

        # Combine factors
        base = proc_base.get(proc, 0.5)
        src = source_bonus.get(source, 0.0)
        # Higher intensity = harder. Clamp to [0, 1] before scaling so a
        # negative intensity cannot push the score below zero.
        int_factor = min(1.0, max(0.0, intensity / 1.5)) * 0.2

        difficulties[prefix] = min(1.0, base + src + int_factor)

    return difficulties
landmarkdiff/data.py ADDED
@@ -0,0 +1,405 @@
1
+ """Reusable data loading utilities for LandmarkDiff training and evaluation.
2
+
3
+ Provides PyTorch Dataset implementations for loading synthetic training pairs,
4
+ manifest-based datasets, and evaluation datasets. Extracted from the training
5
+ script for reuse across training, evaluation, and testing pipelines.
6
+
7
+ Usage::
8
+
9
+ from landmarkdiff.data import SurgicalPairDataset, create_dataloader
10
+
11
+ dataset = SurgicalPairDataset("data/training_combined", resolution=512)
12
+ loader = create_dataloader(dataset, batch_size=4, num_workers=4)
13
+
14
+ for batch in loader:
15
+ input_img = batch["input"] # (B, 3, H, W) RGB [0,1]
16
+ target_img = batch["target"] # (B, 3, H, W) RGB [0,1]
17
+ conditioning = batch["conditioning"] # (B, 3, H, W) RGB [0,1]
18
+ mask = batch["mask"] # (B, 1, H, W) [0,1]
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import csv
24
+ import json
25
+ import logging
26
+ from collections.abc import Callable
27
+ from pathlib import Path
28
+
29
+ import cv2
30
+ import numpy as np
31
+ import torch
32
+ from torch.utils.data import DataLoader, Dataset, Sampler, WeightedRandomSampler
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ # ---------------------------------------------------------------------------
38
+ # Core dataset
39
+ # ---------------------------------------------------------------------------
40
+
41
+
42
class SurgicalPairDataset(Dataset):
    """Dataset for loading surgical before/after training pairs.

    Each sample has four components:

    - input: original face image (before surgery)
    - target: modified face image (after surgery)
    - conditioning: 3-channel landmark mesh visualization
    - mask: surgical region mask (soft float)

    Files are read from ``data_dir`` as ``{prefix}_input.png``,
    ``{prefix}_target.png``, ``{prefix}_conditioning.png`` and optionally
    ``{prefix}_mask.png``. Pairs are either auto-discovered from
    ``*_input.png`` files or listed in a manifest CSV.

    Args:
        data_dir: Directory containing training pair images.
        resolution: Target image resolution (square).
        manifest_path: Optional CSV with columns [prefix, procedure, ...]
            (a ``name`` column is accepted in place of ``prefix``). If None,
            auto-discovers pairs from ``*_input.png`` files and reads per-pair
            metadata from ``data_dir/metadata.json`` when present.
        transform: Optional callable for custom augmentation. Receives and
            returns a dict with numpy arrays, under the keys ``input_image``,
            ``target_image``, ``conditioning``, ``mask``, ``procedure`` and ``idx``.

    Raises:
        FileNotFoundError: If no training pairs are found.
    """

    # Filename suffix (before ".png") that marks the "before" image of a pair.
    _INPUT_SUFFIX = "_input"

    def __init__(
        self,
        data_dir: str | Path,
        resolution: int = 512,
        manifest_path: str | Path | None = None,
        transform: Callable[[dict], dict] | None = None,
    ):
        self.data_dir = Path(data_dir)
        self.resolution = resolution
        self.transform = transform

        # Discover pairs: manifest order, or sorted glob for a stable index order.
        if manifest_path is not None:
            self.pairs, self.metadata = self._load_manifest(Path(manifest_path))
        else:
            self.pairs = sorted(self.data_dir.glob("*_input.png"))
            self.metadata = self._load_metadata()

        if not self.pairs:
            raise FileNotFoundError(f"No training pairs found in {data_dir}")

        logger.info("Loaded %d training pairs from %s", len(self.pairs), data_dir)

    def _load_manifest(self, path: Path) -> tuple[list[Path], dict[str, dict]]:
        """Load pairs from a manifest CSV.

        The prefix comes from the ``prefix`` column, falling back to ``name``.
        Rows with an empty prefix, or whose ``{prefix}_input.png`` is not in
        ``data_dir``, are skipped.

        Returns:
            ``(pairs, metadata)``: the input image paths in manifest order, and
            a mapping from prefix to the CSV row as a dict.
        """
        pairs: list[Path] = []
        metadata: dict[str, dict] = {}
        # newline="" is what the csv module requires for correct line handling.
        # "utf-8-sig" also strips a leading BOM (common in spreadsheet exports),
        # which would otherwise corrupt the first header and hide "prefix".
        with open(path, newline="", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                prefix = row.get("prefix", row.get("name", ""))
                if not prefix:
                    continue
                input_path = self.data_dir / f"{prefix}{self._INPUT_SUFFIX}.png"
                if input_path.exists():
                    pairs.append(input_path)
                    metadata[prefix] = dict(row)
        return pairs, metadata

    def _load_metadata(self) -> dict[str, dict]:
        """Load the ``"pairs"`` mapping from ``data_dir/metadata.json``.

        Returns an empty dict if the file is missing, unreadable, not valid
        UTF-8 JSON, or not a JSON object.
        """
        meta_path = self.data_dir / "metadata.json"
        if not meta_path.exists():
            return {}
        try:
            with open(meta_path, encoding="utf-8") as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError):
            logger.debug("Failed to load metadata from %s", meta_path)
            return {}
        if not isinstance(data, dict):
            logger.debug("Ignoring non-object metadata in %s", meta_path)
            return {}
        result: dict = data.get("pairs", {})
        return result

    def get_procedure(self, idx: int) -> str:
        """Get the surgical procedure type for a sample ("unknown" if absent)."""
        prefix = self._prefix(idx)
        info = self.metadata.get(prefix, {})
        proc: str = info.get("procedure", "unknown")
        return proc

    def get_procedures(self) -> list[str]:
        """Get procedure types for all samples, in index order."""
        return [self.get_procedure(i) for i in range(len(self))]

    def _prefix(self, idx: int) -> str:
        """Return the pair prefix of sample ``idx`` (input stem minus ``_input``)."""
        stem = self.pairs[idx].stem
        suffix = self._INPUT_SUFFIX
        # Strip only the trailing suffix. str.replace would also remove an
        # "_input" that is part of the prefix itself and load the wrong files.
        return stem[: -len(suffix)] if stem.endswith(suffix) else stem

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> dict:
        """Load sample ``idx`` and return it as tensors.

        Returns:
            Dict with ``input``, ``target`` and ``conditioning`` as RGB [0, 1]
            ``(C, H, W)`` tensors, ``mask`` as a ``(1, H, W)`` tensor, plus the
            ``procedure`` label and the sample ``idx``.
        """
        prefix = self._prefix(idx)

        # Load images as BGR uint8
        input_bgr = self._load_image(f"{prefix}_input.png")
        target_bgr = self._load_image(f"{prefix}_target.png")
        cond_bgr = self._load_image(f"{prefix}_conditioning.png")
        mask_arr = self._load_mask(f"{prefix}_mask.png")

        sample = {
            "input_image": input_bgr,
            "target_image": target_bgr,
            "conditioning": cond_bgr,
            "mask": mask_arr,
            "procedure": self.get_procedure(idx),
            "idx": idx,
        }

        # Apply custom transform
        if self.transform is not None:
            sample = self.transform(sample)

        # Convert to tensors
        return {
            "input": bgr_to_tensor(sample["input_image"]),
            "target": bgr_to_tensor(sample["target_image"]),
            "conditioning": bgr_to_tensor(sample["conditioning"]),
            "mask": mask_to_tensor(sample["mask"]),
            "procedure": sample["procedure"],
            "idx": sample["idx"],
        }

    def _load_image(self, filename: str) -> np.ndarray:
        """Load an image as BGR uint8, resized to resolution.

        An unreadable or missing file is logged and replaced with a black image.
        """
        path = self.data_dir / filename
        img = cv2.imread(str(path))
        if img is None:
            logger.warning("Failed to load %s, using blank", path)
            return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
        if img.shape[:2] != (self.resolution, self.resolution):
            img = cv2.resize(img, (self.resolution, self.resolution))
        return img

    def _load_mask(self, filename: str) -> np.ndarray:
        """Load a mask as float32 [0,1], resized to resolution.

        A missing or unreadable mask yields an all-ones mask.
        """
        path = self.data_dir / filename
        if not path.exists():
            return np.ones((self.resolution, self.resolution), dtype=np.float32)
        mask = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            return np.ones((self.resolution, self.resolution), dtype=np.float32)
        mask = cv2.resize(mask, (self.resolution, self.resolution))
        return mask.astype(np.float32) / 255.0
184
+
185
+
186
+ # ---------------------------------------------------------------------------
187
+ # Evaluation dataset (input + ground truth)
188
+ # ---------------------------------------------------------------------------
189
+
190
+
191
class EvalPairDataset(Dataset):
    """Dataset for evaluation: loads input/target pairs with procedure labels.

    Pairs are discovered from ``*_input.png`` files in ``data_dir`` (sorted by
    path), and the matching ``{prefix}_target.png`` is loaded alongside.
    Procedure labels come from the optional ``metadata.json`` ``"pairs"``
    mapping and default to ``"unknown"``.

    Args:
        data_dir: Directory with evaluation pairs.
        resolution: Target resolution (square). Images are resized to it, and
            missing or unreadable images become black frames.
    """

    # Filename suffix (before ".png") that marks the "before" image of a pair.
    _INPUT_SUFFIX = "_input"

    def __init__(self, data_dir: str | Path, resolution: int = 512):
        self.data_dir = Path(data_dir)
        self.resolution = resolution
        self.pairs = sorted(self.data_dir.glob("*_input.png"))

        # Load metadata (best effort: any problem leaves it empty).
        meta_path = self.data_dir / "metadata.json"
        self._meta = {}
        if meta_path.exists():
            try:
                with open(meta_path, encoding="utf-8") as f:
                    data = json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError, OSError):
                logger.debug("Failed to load metadata from %s", meta_path)
            else:
                if isinstance(data, dict):
                    self._meta = data.get("pairs", {})
                else:
                    logger.debug("Ignoring non-object metadata in %s", meta_path)

    def __len__(self) -> int:
        return len(self.pairs)

    def _prefix(self, idx: int) -> str:
        """Return the pair prefix of sample ``idx`` (input stem minus ``_input``)."""
        stem = self.pairs[idx].stem
        suffix = self._INPUT_SUFFIX
        # Strip only the trailing suffix; str.replace would also remove an
        # "_input" embedded in the prefix and load the wrong target image.
        return stem[: -len(suffix)] if stem.endswith(suffix) else stem

    def __getitem__(self, idx: int) -> dict:
        """Return ``input``/``target`` RGB tensors plus ``procedure`` and ``prefix``."""
        prefix = self._prefix(idx)

        input_img = self._load(f"{prefix}_input.png")
        target_img = self._load(f"{prefix}_target.png")

        info = self._meta.get(prefix, {})
        procedure = info.get("procedure", "unknown")

        return {
            "input": bgr_to_tensor(input_img),
            "target": bgr_to_tensor(target_img),
            "procedure": procedure,
            "prefix": prefix,
        }

    def _load(self, filename: str) -> np.ndarray:
        """Load an image as BGR uint8 resized to resolution; black if unreadable."""
        path = self.data_dir / filename
        img = cv2.imread(str(path))
        if img is None:
            logger.warning("Failed to load %s, using blank", path)
            return np.zeros((self.resolution, self.resolution, 3), dtype=np.uint8)
        if img.shape[:2] != (self.resolution, self.resolution):
            img = cv2.resize(img, (self.resolution, self.resolution))
        return img
241
+
242
+
243
+ # ---------------------------------------------------------------------------
244
+ # Conversion utilities
245
+ # ---------------------------------------------------------------------------
246
+
247
+
248
def bgr_to_tensor(bgr: np.ndarray) -> torch.Tensor:
    """Turn a BGR image ``(H, W, C)`` into an RGB float32 tensor ``(C, H, W)``.

    The channel order is reversed, values are divided by 255 (so uint8 input
    lands in [0, 1]), and the channel axis is moved to the front. The input
    array is not modified.
    """
    scaled = np.flip(bgr, axis=2).astype(np.float32)
    scaled /= 255.0
    hwc = torch.from_numpy(np.ascontiguousarray(scaled))
    return hwc.permute(2, 0, 1)
252
+
253
+
254
def tensor_to_bgr(tensor: torch.Tensor) -> np.ndarray:
    """Convert an RGB [0,1] tensor (C, H, W) to a contiguous BGR uint8 image (H, W, C).

    Values are clamped to [0, 1] and scaled to [0, 255] with round-to-nearest.
    The tensor may be on any device and have any floating dtype, including
    float16/bfloat16, which ``Tensor.numpy()`` cannot handle directly.
    """
    rgb = tensor.detach().cpu().to(torch.float32).clamp(0, 1).permute(1, 2, 0).numpy()
    # Round rather than truncate. x/255*255 in float32 can land a hair below
    # the integer, and a bare astype(uint8) would then drop a whole level,
    # breaking the bgr_to_tensor -> tensor_to_bgr round trip.
    bgr = np.rint(rgb[:, :, ::-1] * 255.0).astype(np.uint8)
    return np.ascontiguousarray(bgr)
259
+
260
+
261
def mask_to_tensor(mask: np.ndarray) -> torch.Tensor:
    """Wrap a mask array as a tensor with a leading singleton axis.

    A 3-D ``(H, W, C)`` mask is first reduced to its channel 0, so an
    ``(H, W)`` or ``(H, W, C)`` mask becomes ``(1, H, W)``. The result comes
    from ``torch.from_numpy``, so it keeps the array's dtype (float32 for
    masks produced by this module) and shares its memory.
    """
    single = mask[..., 0] if mask.ndim == 3 else mask
    return torch.from_numpy(single)[None]
266
+
267
+
268
+ # ---------------------------------------------------------------------------
269
+ # Samplers
270
+ # ---------------------------------------------------------------------------
271
+
272
+
273
def create_procedure_sampler(
    dataset: SurgicalPairDataset,
    balance_procedures: bool = True,
    generator: torch.Generator | None = None,
) -> Sampler | None:
    """Create a weighted sampler that balances procedure types.

    Each sample gets weight ``total / (n_procedures * count[procedure])``, so
    every procedure has the same expected share of draws. Sampling is with
    replacement and yields ``len(dataset)`` indices per epoch.

    Args:
        dataset: A dataset exposing ``get_procedures()`` and ``__len__``
            (e.g. SurgicalPairDataset).
        balance_procedures: If False, no sampler is built.
        generator: Optional ``torch.Generator`` passed to the sampler for
            reproducible draws. Default None keeps the previous behavior.

    Returns:
        A ``WeightedRandomSampler``, or None if balancing is disabled or all
        procedures are the same.
    """
    from collections import Counter

    if not balance_procedures:
        return None

    procedures = dataset.get_procedures()
    # One O(n) pass instead of a list.count() scan per distinct procedure.
    counts = Counter(procedures)
    if len(counts) <= 1:
        return None

    n_procs = len(counts)
    total = len(procedures)
    # Weight inversely proportional to count
    weights = [total / (n_procs * counts[proc]) for proc in procedures]

    return WeightedRandomSampler(
        weights=weights,
        num_samples=len(dataset),
        replacement=True,
        generator=generator,
    )
305
+
306
+
307
+ # ---------------------------------------------------------------------------
308
+ # DataLoader factory
309
+ # ---------------------------------------------------------------------------
310
+
311
+
312
def create_dataloader(
    dataset: Dataset,
    batch_size: int = 4,
    num_workers: int = 4,
    shuffle: bool = True,
    sampler: Sampler | None = None,
    pin_memory: bool = True,
    drop_last: bool = True,
    persistent_workers: bool = False,
) -> DataLoader:
    """Build a training ``DataLoader``, adjusting options that would conflict.

    Args:
        dataset: PyTorch Dataset.
        batch_size: Batch size.
        num_workers: Number of data loading worker processes.
        shuffle: Shuffle data; forced off when ``sampler`` is given.
        sampler: Custom sampler (e.g. from create_procedure_sampler).
        pin_memory: Request pinned host memory; honored only when CUDA is
            available.
        drop_last: Drop the last incomplete batch.
        persistent_workers: Keep workers alive between epochs; honored only
            when ``num_workers > 0``.

    Returns:
        Configured DataLoader.
    """
    # DataLoader refuses shuffle=True together with an explicit sampler.
    effective_shuffle = shuffle if sampler is None else False
    # Pinned memory only helps host->GPU copies, so skip it without CUDA.
    effective_pin = pin_memory and torch.cuda.is_available()
    # DataLoader rejects persistent_workers=True when num_workers == 0.
    keep_alive = persistent_workers and num_workers > 0

    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": effective_shuffle,
        "sampler": sampler,
        "num_workers": num_workers,
        "pin_memory": effective_pin,
        "drop_last": drop_last,
        "persistent_workers": keep_alive,
    }
    return DataLoader(dataset, **loader_kwargs)
350
+
351
+
352
+ # ---------------------------------------------------------------------------
353
+ # Multi-directory dataset
354
+ # ---------------------------------------------------------------------------
355
+
356
+
357
class CombinedDataset(Dataset):
    """Combine multiple SurgicalPairDatasets into one.

    Useful for combining synthetic v1, v2, v3 data and real pairs. Global
    indices run through the datasets in the given order. Empty datasets are
    allowed and simply contribute no indices.

    Args:
        datasets: List of SurgicalPairDataset instances. The list is copied, so
            later changes to the caller's list do not affect this dataset.
    """

    def __init__(self, datasets: list[SurgicalPairDataset]):
        # Copy so the precomputed offsets can't desync from a mutated caller list.
        self.datasets = list(datasets)
        # _cumulative_sizes[i] = total number of samples in datasets[0..i].
        self._cumulative_sizes: list[int] = []
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self._cumulative_sizes.append(total)

    def __len__(self) -> int:
        return self._cumulative_sizes[-1] if self._cumulative_sizes else 0

    def _locate(self, idx: int) -> tuple[int, int]:
        """Map a global index to ``(dataset_index, local_index)``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        from bisect import bisect_right

        if idx < 0 or idx >= len(self):
            raise IndexError(f"CombinedDataset index {idx} out of range [0, {len(self)})")
        # First dataset whose cumulative end is strictly greater than idx.
        # Empty datasets repeat the previous cumulative value and are skipped.
        dataset_idx = bisect_right(self._cumulative_sizes, idx)
        start = self._cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
        return dataset_idx, idx - start

    def __getitem__(self, idx: int) -> dict:
        """Return the sample at global index ``idx`` from its source dataset."""
        dataset_idx, local_idx = self._locate(idx)
        return self.datasets[dataset_idx][local_idx]

    def get_procedure(self, idx: int) -> str:
        """Return the procedure label of the sample at global index ``idx``."""
        dataset_idx, local_idx = self._locate(idx)
        return self.datasets[dataset_idx].get_procedure(local_idx)

    def get_procedures(self) -> list[str]:
        """Return procedure labels for all samples, in global index order."""
        procs: list[str] = []
        for ds in self.datasets:
            procs.extend(ds.get_procedures())
        return procs