landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,415 @@
1
+ """Evaluation metrics suite.
2
+
3
+ All metrics stratified by Fitzpatrick skin type (I-VI) using ITA-based thresholding.
4
+ Primary metrics: FID, LPIPS, NME, ArcFace identity similarity.
5
+ Secondary: SSIM (relaxed target >0.80).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+ from typing import Any
12
+
13
+ import numpy as np
14
+
15
+ try:
16
+ import cv2
17
+ except ImportError:
18
+ cv2 = None # type: ignore[assignment]
19
+
20
+
21
@dataclass
class EvalMetrics:
    """Computed evaluation metrics for a batch of generated images.

    The scalar fields hold batch-level values. The ``*_by_fitzpatrick`` and
    ``*_by_procedure`` dictionaries hold the same metrics keyed by group label.
    """

    fid: float = 0.0
    lpips: float = 0.0
    nme: float = 0.0  # Normalized Mean landmark Error
    identity_sim: float = 0.0  # ArcFace cosine similarity
    ssim: float = 0.0

    # Per-Fitzpatrick breakdown (all metrics stratified)
    fid_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    nme_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    lpips_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    ssim_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    identity_sim_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    count_by_fitzpatrick: dict[str, int] = field(default_factory=dict)

    # Per-procedure breakdown
    nme_by_procedure: dict[str, float] = field(default_factory=dict)
    lpips_by_procedure: dict[str, float] = field(default_factory=dict)
    ssim_by_procedure: dict[str, float] = field(default_factory=dict)

    def summary(self) -> str:
        """Return a multi-line, human-readable report of the stored metrics.

        The Fitzpatrick sections are appended only when the corresponding
        dictionaries contain data.
        """
        lines = [
            f"FID: {self.fid:.2f}",
            f"LPIPS: {self.lpips:.4f}",
            f"NME: {self.nme:.4f}",
            f"Identity Sim: {self.identity_sim:.4f}",
            f"SSIM: {self.ssim:.4f}",
        ]
        if self.count_by_fitzpatrick:
            lines.append("\nBy Fitzpatrick Type:")
            for ftype in sorted(self.count_by_fitzpatrick):
                n = self.count_by_fitzpatrick[ftype]
                parts = [f" Type {ftype} (n={n}):"]
                if ftype in self.lpips_by_fitzpatrick:
                    parts.append(f"LPIPS={self.lpips_by_fitzpatrick[ftype]:.4f}")
                if ftype in self.ssim_by_fitzpatrick:
                    parts.append(f"SSIM={self.ssim_by_fitzpatrick[ftype]:.4f}")
                if ftype in self.nme_by_fitzpatrick:
                    parts.append(f"NME={self.nme_by_fitzpatrick[ftype]:.4f}")
                if ftype in self.identity_sim_by_fitzpatrick:
                    parts.append(f"ID={self.identity_sim_by_fitzpatrick[ftype]:.4f}")
                lines.append(" ".join(parts))
        if self.fid_by_fitzpatrick:
            lines.append("\nFID by Fitzpatrick:")
            for k, v in sorted(self.fid_by_fitzpatrick.items()):
                lines.append(f" Type {k}: {v:.2f}")
        return "\n".join(lines)

    def to_dict(self) -> dict:
        """Convert to flat dictionary for JSON/CSV export.

        Keys are ``fitz_<type>_<metric>`` for every Fitzpatrick type present in
        ``count_by_fitzpatrick`` and ``proc_<name>_<metric>`` for every
        procedure present in any per-procedure dictionary. A metric that is
        missing for a group is exported as ``0.0``.
        """
        d = {
            "fid": self.fid,
            "lpips": self.lpips,
            "nme": self.nme,
            "identity_sim": self.identity_sim,
            "ssim": self.ssim,
        }
        for ftype in sorted(self.count_by_fitzpatrick):
            prefix = f"fitz_{ftype}"
            d[f"{prefix}_count"] = self.count_by_fitzpatrick.get(ftype, 0)
            d[f"{prefix}_lpips"] = self.lpips_by_fitzpatrick.get(ftype, 0.0)
            d[f"{prefix}_ssim"] = self.ssim_by_fitzpatrick.get(ftype, 0.0)
            d[f"{prefix}_nme"] = self.nme_by_fitzpatrick.get(ftype, 0.0)
            d[f"{prefix}_identity"] = self.identity_sim_by_fitzpatrick.get(ftype, 0.0)

        # Use the union of keys: NME is only populated when landmarks were
        # supplied, so keying on nme_by_procedure alone would drop the
        # LPIPS/SSIM breakdown for landmark-free evaluations.
        procedures = (
            set(self.nme_by_procedure)
            | set(self.lpips_by_procedure)
            | set(self.ssim_by_procedure)
        )
        for proc in sorted(procedures):
            d[f"proc_{proc}_nme"] = self.nme_by_procedure.get(proc, 0.0)
            d[f"proc_{proc}_lpips"] = self.lpips_by_procedure.get(proc, 0.0)
            d[f"proc_{proc}_ssim"] = self.ssim_by_procedure.get(proc, 0.0)
        return d
93
+
94
+
95
def classify_fitzpatrick_ita(image: np.ndarray) -> str:
    """Map an image to a Fitzpatrick type label ("I".."VI") via its ITA.

    The Individual Typology Angle is computed as
    ``atan2(L - 50, b)`` in degrees, where L and b are the mean lightness and
    yellow-blue channels of the central half of the image after conversion
    with ``cv2.COLOR_BGR2LAB``.

    Lower bounds per type (Chardon et al., 1991):
        I: ITA > 55, II: > 41, III: > 28, IV: > 10, V: > -30, VI: otherwise.

    Args:
        image: Image passed to ``cv2.cvtColor`` as BGR. The channel rescaling
            below assumes OpenCV's 8-bit LAB encoding -- TODO confirm callers
            always pass uint8 data.

    Returns:
        One of "I", "II", "III", "IV", "V", "VI".

    Raises:
        ImportError: If OpenCV could not be imported by this module.
    """
    if cv2 is None:
        raise ImportError("opencv-python is required for Fitzpatrick classification")

    lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # Restrict to the middle half in each dimension to avoid background pixels.
    height, width = image.shape[:2]
    rows = slice(height // 4, 3 * height // 4)
    cols = slice(width // 4, 3 * width // 4)
    roi = lab_image[rows, cols]

    lightness = roi[:, :, 0].mean() * 100 / 255  # rescale to 0-100
    yellow_blue = roi[:, :, 2].mean() - 128  # re-center around 0

    if abs(yellow_blue) < 1e-6:
        yellow_blue = 1e-6

    ita = np.arctan2(lightness - 50, yellow_blue) * (180 / np.pi)

    # Ordered from lightest to darkest; the first bound exceeded wins.
    lower_bounds = ((55, "I"), (41, "II"), (28, "III"), (10, "IV"), (-30, "V"))
    for bound, label in lower_bounds:
        if ita > bound:
            return label
    return "VI"
137
+
138
+
139
def compute_nme(
    pred_landmarks: np.ndarray,
    target_landmarks: np.ndarray,
    left_eye_idx: int = 33,
    right_eye_idx: int = 263,
) -> float:
    """Return the mean landmark error divided by the inter-ocular distance.

    The inter-ocular distance is measured on ``target_landmarks`` between the
    two eye indices. Distances below 1.0 are replaced by 1.0 before dividing.
    NOTE(review): that floor presumably targets pixel coordinates; with
    coordinates normalised to [0, 1] it would always apply -- confirm with
    callers.

    Args:
        pred_landmarks: (N, 2) predicted landmark positions.
        target_landmarks: (N, 2) ground truth positions.
        left_eye_idx: MediaPipe index for left eye center.
        right_eye_idx: MediaPipe index for right eye center.

    Returns:
        NME value (lower is better).
    """
    eye_offset = target_landmarks[left_eye_idx] - target_landmarks[right_eye_idx]
    eye_span = np.linalg.norm(eye_offset)
    normalizer = 1.0 if eye_span < 1.0 else eye_span

    per_point_error = np.linalg.norm(pred_landmarks - target_landmarks, axis=1)
    mean_error = np.mean(per_point_error)
    return float(mean_error / normalizer)
164
+
165
+
166
+ def compute_ssim(
167
+ pred: np.ndarray,
168
+ target: np.ndarray,
169
+ ) -> float:
170
+ """Compute Structural Similarity Index (SSIM).
171
+
172
+ Uses scikit-image's windowed SSIM (Wang et al. 2004) for proper
173
+ per-window computation with 11x11 Gaussian kernel.
174
+ """
175
+ try:
176
+ from skimage.metrics import structural_similarity
177
+
178
+ # Convert to grayscale if color, or compute per-channel
179
+ if pred.ndim == 3 and pred.shape[2] == 3:
180
+ return float(structural_similarity(pred, target, channel_axis=2, data_range=255))
181
+ else:
182
+ return float(structural_similarity(pred, target, data_range=255))
183
+ except ImportError:
184
+ # Fallback: simple global SSIM (not publication-quality)
185
+ pred_f = pred.astype(np.float64)
186
+ target_f = target.astype(np.float64)
187
+
188
+ mu_p = np.mean(pred_f)
189
+ mu_t = np.mean(target_f)
190
+ sigma_p = np.std(pred_f)
191
+ sigma_t = np.std(target_f)
192
+ sigma_pt = np.mean((pred_f - mu_p) * (target_f - mu_t))
193
+
194
+ C1 = (0.01 * 255) ** 2
195
+ C2 = (0.03 * 255) ** 2
196
+
197
+ ssim_val = ((2 * mu_p * mu_t + C1) * (2 * sigma_pt + C2)) / (
198
+ (mu_p**2 + mu_t**2 + C1) * (sigma_p**2 + sigma_t**2 + C2)
199
+ )
200
+ return float(ssim_val)
201
+
202
+
203
+ _LPIPS_FN = None
204
+ _ARCFACE_APP = None
205
+
206
+
207
def _get_lpips_fn() -> Any:
    """Return the module-wide LPIPS model, building it on first use.

    The model is created with the "alex" backbone, switched to eval mode and
    stored in ``_LPIPS_FN`` so later calls reuse the same instance.
    """
    global _LPIPS_FN
    if _LPIPS_FN is not None:
        return _LPIPS_FN

    # Imported lazily so this module loads without the lpips package.
    import lpips

    _LPIPS_FN = lpips.LPIPS(net="alex", verbose=False)
    _LPIPS_FN.eval()
    return _LPIPS_FN
216
+
217
+
218
def compute_lpips(
    pred: np.ndarray,
    target: np.ndarray,
) -> float:
    """Return the LPIPS perceptual distance between two images.

    Each image is converted to a float32 tensor, divided by 255, moved to
    channel-first layout with a leading batch axis and mapped to [-1, 1]
    before being scored by the shared LPIPS model.

    Args:
        pred: Predicted image; must be 3-dimensional with channels last.
            Presumably 0-255 valued -- TODO confirm.
        target: Reference image in the same layout as ``pred``.

    Returns:
        LPIPS score (lower = more similar), or NaN when ``lpips`` or ``torch``
        cannot be imported.
    """
    try:
        import lpips  # noqa: F401
        import torch
    except ImportError:
        return float("nan")

    model = _get_lpips_fn()

    with torch.no_grad():
        inputs = []
        for image in (pred, target):
            unit_range = torch.from_numpy(image.astype(np.float32) / 255.0)
            batched = unit_range.permute(2, 0, 1).unsqueeze(0)
            inputs.append(batched * 2 - 1)  # LPIPS expects [-1, 1]
        distance = model(inputs[0], inputs[1])
    return float(distance.item())
241
+
242
+
243
def compute_fid(
    real_dir: str,
    generated_dir: str,
) -> float:
    """Return the FID between a directory of real and one of generated images.

    Delegates to ``torch_fidelity.calculate_metrics`` and enables CUDA
    whenever ``torch.cuda.is_available()`` reports a device.

    Args:
        real_dir: Path to directory of real images.
        generated_dir: Path to directory of generated images.

    Returns:
        FID score (lower = more similar distributions).

    Raises:
        ImportError: If ``torch_fidelity`` is not installed.
    """
    try:
        from torch_fidelity import calculate_metrics
    except ImportError:
        raise ImportError(
            "torch-fidelity is required for FID. Install with: pip install torch-fidelity"
        ) from None

    import torch

    use_cuda = torch.cuda.is_available()
    options = {
        "input1": generated_dir,
        "input2": real_dir,
        "cuda": use_cuda,
        "fid": True,
        "verbose": False,
    }
    result = calculate_metrics(**options)
    return float(result["frechet_inception_distance"])
275
+
276
+
277
def compute_identity_similarity(
    pred: np.ndarray,
    target: np.ndarray,
) -> float:
    """Compute ArcFace identity cosine similarity between two face images.

    The embedding of the first detected face in each image is compared. The
    cosine similarity is clipped to [0, 1], where 1 means identical identity.

    The SSIM of the two images is returned instead when InsightFace is not
    installed, when no face is detected in either image, or when the ArcFace
    pipeline raises. The last case is logged as a warning, so a silently
    substituted metric can be spotted in the logs.

    Args:
        pred: Predicted face image.
        target: Reference face image.

    Returns:
        Similarity score as a Python float.
    """
    global _ARCFACE_APP

    try:
        from insightface.app import FaceAnalysis
    except ImportError:
        # Documented fallback: SSIM-based proxy when InsightFace is unavailable.
        return compute_ssim(pred, target)

    import logging

    try:
        if _ARCFACE_APP is None:
            app = FaceAnalysis(
                name="buffalo_l",
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            app.prepare(ctx_id=-1, det_size=(320, 320))
            # Cache only after prepare() succeeded, so a failed initialisation
            # is retried on the next call instead of leaving an unprepared
            # model behind.
            _ARCFACE_APP = app
        app = _ARCFACE_APP

        # NOTE(review): 3-channel input is passed through unchanged and any
        # other channel count goes through RGB2BGR -- confirm the intended
        # channel order for non-3-channel inputs.
        pred_bgr = pred if pred.shape[2] == 3 else cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        target_bgr = target if target.shape[2] == 3 else cv2.cvtColor(target, cv2.COLOR_RGB2BGR)

        pred_faces = app.get(pred_bgr)
        target_faces = app.get(target_bgr)

        if pred_faces and target_faces:
            pred_emb = pred_faces[0].embedding
            target_emb = target_faces[0].embedding
            # The epsilon keeps the division finite for zero-norm embeddings.
            sim = np.dot(pred_emb, target_emb) / (
                np.linalg.norm(pred_emb) * np.linalg.norm(target_emb) + 1e-8
            )
            return float(np.clip(sim, 0, 1))
    except Exception:
        # Best-effort metric: keep evaluating, but leave a trace in the logs.
        logging.getLogger(__name__).warning(
            "ArcFace identity similarity failed; falling back to SSIM proxy",
            exc_info=True,
        )

    # Fallback: SSIM-based proxy
    return compute_ssim(pred, target)
316
+
317
+
318
def _nan_skipping_mean(values: list[float], empty_value: float = 0.0) -> float:
    """Average *values* ignoring NaNs, without emitting NumPy warnings."""
    if not values:
        return empty_value
    arr = np.asarray(values, dtype=np.float64)
    valid = arr[~np.isnan(arr)]
    if valid.size == 0:
        # Every score is NaN (e.g. LPIPS when lpips is not installed).
        # np.nanmean would return NaN too, but with a RuntimeWarning.
        return float("nan")
    return float(valid.mean())


def _scores_at(scores: list[float], indices: list[int]) -> list[float]:
    """Pick the scores at *indices*, skipping indices beyond the list."""
    return [scores[i] for i in indices if i < len(scores)]


def evaluate_batch(
    predictions: list[np.ndarray],
    targets: list[np.ndarray],
    pred_landmarks: list[np.ndarray] | None = None,
    target_landmarks: list[np.ndarray] | None = None,
    procedures: list[str] | None = None,
    compute_identity: bool = False,
) -> EvalMetrics:
    """Evaluate a batch of predicted vs target images.

    Computes SSIM and LPIPS for every pair, NME when both landmark lists are
    given, and identity similarity when requested. Results are averaged over
    the batch (NaN scores are ignored) and stratified by Fitzpatrick skin
    type of the target image and by procedure. FID is not computed here; the
    returned ``fid`` keeps its default value.

    Args:
        predictions: List of predicted BGR images.
        targets: List of target BGR images.
        pred_landmarks: Optional list of (N, 2) predicted landmark arrays.
        target_landmarks: Optional list of (N, 2) target landmark arrays.
        procedures: Optional list of procedure names for per-procedure breakdown.
        compute_identity: Whether to compute ArcFace identity similarity (slow).

    Returns:
        EvalMetrics with all computed values. A metric with no scores is
        reported as 0.0; a metric whose scores are all NaN is reported as NaN.
    """
    ssim_scores: list[float] = []
    lpips_scores: list[float] = []
    nme_scores: list[float] = []
    identity_scores: list[float] = []
    fitz_groups: dict[str, list[int]] = {}
    proc_groups: dict[str, list[int]] = {}

    has_landmarks = pred_landmarks is not None and target_landmarks is not None

    for i, prediction in enumerate(predictions):
        target = targets[i]
        ssim_scores.append(compute_ssim(prediction, target))
        lpips_scores.append(compute_lpips(prediction, target))

        if has_landmarks:
            nme_scores.append(compute_nme(pred_landmarks[i], target_landmarks[i]))

        if compute_identity:
            identity_scores.append(compute_identity_similarity(prediction, target))

        # Fitzpatrick classification is best-effort: an image that cannot be
        # classified is simply left out of the stratified breakdown.
        if cv2 is not None:
            try:
                fitz = classify_fitzpatrick_ita(target)
                fitz_groups.setdefault(fitz, []).append(i)
            except Exception:
                pass

        # Procedure grouping
        if procedures is not None and i < len(procedures):
            proc_groups.setdefault(procedures[i], []).append(i)

    metrics = EvalMetrics(
        ssim=_nan_skipping_mean(ssim_scores),
        lpips=_nan_skipping_mean(lpips_scores),
        nme=_nan_skipping_mean(nme_scores),
        identity_sim=_nan_skipping_mean(identity_scores),
    )

    # Full Fitzpatrick stratification for ALL metrics
    for ftype, indices in fitz_groups.items():
        metrics.count_by_fitzpatrick[ftype] = len(indices)
        fitz_targets = (
            (lpips_scores, metrics.lpips_by_fitzpatrick),
            (ssim_scores, metrics.ssim_by_fitzpatrick),
            (nme_scores, metrics.nme_by_fitzpatrick),
            (identity_scores, metrics.identity_sim_by_fitzpatrick),
        )
        for scores, destination in fitz_targets:
            group = _scores_at(scores, indices)
            if group:
                destination[ftype] = _nan_skipping_mean(group)

    # Per-procedure breakdown
    for proc, indices in proc_groups.items():
        proc_targets = (
            (lpips_scores, metrics.lpips_by_procedure),
            (ssim_scores, metrics.ssim_by_procedure),
            (nme_scores, metrics.nme_by_procedure),
        )
        for scores, destination in proc_targets:
            group = _scores_at(scores, indices)
            if group:
                destination[proc] = _nan_skipping_mean(group)

    return metrics
@@ -0,0 +1,231 @@
1
+ """Local experiment tracker for training reproducibility.
2
+
3
+ Tracks all training runs with their configs, metrics, and results.
4
+ Each experiment gets a unique ID and timestamp.
5
+
6
+ Usage::
7
+
8
+ tracker = ExperimentTracker("experiments/")
9
+
10
+ # Start a new experiment
11
+ exp_id = tracker.start(
12
+ name="phaseA_v2",
13
+ config={
14
+ "phase": "A", "lr": 1e-5, "batch": 4,
15
+ "steps": 100000, "data": "training_combined",
16
+ },
17
+ )
18
+
19
+ # Log metrics during training
20
+ tracker.log_metric(exp_id, step=1000, loss=0.045, ssim=0.82)
21
+
22
+ # Record final results
23
+ tracker.finish(exp_id, results={"fid": 42.3, "ssim": 0.87})
24
+
25
+ # List all experiments
26
+ tracker.list_experiments()
27
+
28
+ # Compare experiments
29
+ tracker.compare(["exp_001", "exp_002"])
30
+ """
31
+
32
+ from __future__ import annotations
33
+
34
+ import json
35
+ import logging
36
+ import os
37
+ import socket
38
+ import time
39
+ from datetime import datetime
40
+ from pathlib import Path
41
+
42
+ logger = logging.getLogger(__name__)
43
+
44
+
45
class ExperimentTracker:
    """Simple file-based experiment tracker.

    All experiment records and the ID counter live in ``index.json`` inside
    the experiments directory. Each experiment additionally gets its own
    ``<exp_id>_metrics.jsonl`` file with one JSON object per logged step.
    """

    def __init__(self, experiments_dir: str = "experiments"):
        """Create the directory if needed and load any existing index."""
        self.dir = Path(experiments_dir)
        self.dir.mkdir(parents=True, exist_ok=True)
        self._index_path = self.dir / "index.json"
        self._index = self._load_index()

    def _load_index(self) -> dict:
        """Read ``index.json``, or return an empty index when it is absent."""
        if self._index_path.exists():
            with open(self._index_path, encoding="utf-8") as f:
                return json.load(f)
        return {"experiments": {}, "counter": 0}

    def _save_index(self) -> None:
        """Persist the in-memory index to ``index.json``.

        The data is written to a sibling temporary file and then moved into
        place, so an interrupted write cannot leave a truncated index behind.
        """
        tmp_path = self._index_path.with_name(self._index_path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self._index, f, indent=2)
        os.replace(tmp_path, self._index_path)

    def start(
        self,
        name: str,
        config: dict,
        tags: list[str] | None = None,
    ) -> str:
        """Start a new experiment. Returns experiment ID.

        The record captures the start time, hostname and the
        ``SLURM_JOB_ID`` / ``CUDA_VISIBLE_DEVICES`` environment variables, and
        an empty metrics file is created for it.
        """
        self._index["counter"] += 1
        exp_id = f"exp_{self._index['counter']:03d}"

        exp = {
            "id": exp_id,
            "name": name,
            "config": config,
            "tags": tags or [],
            "status": "running",
            "started_at": datetime.now().isoformat(),
            "finished_at": None,
            "hostname": socket.gethostname(),
            "slurm_job_id": os.environ.get("SLURM_JOB_ID"),
            "gpu": os.environ.get("CUDA_VISIBLE_DEVICES"),
            "results": {},
            "metrics_file": f"{exp_id}_metrics.jsonl",
        }

        self._index["experiments"][exp_id] = exp
        self._save_index()

        # Create metrics log file
        metrics_path = self.dir / str(exp["metrics_file"])
        metrics_path.touch()

        logger.info("Experiment started: %s (%s)", exp_id, name)
        return exp_id

    def log_metric(self, exp_id: str, step: int | None = None, **metrics) -> None:
        """Append one metrics entry (timestamp, step, values) for a training step.

        Unknown experiment IDs are ignored, with a warning in the log.
        """
        exp = self._index["experiments"].get(exp_id)
        if not exp:
            logger.warning("log_metric: unknown experiment %s; entry dropped", exp_id)
            return

        entry = {
            "timestamp": time.time(),
            "step": step,
            **metrics,
        }

        metrics_path = self.dir / str(exp["metrics_file"])
        with open(metrics_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")

    def finish(
        self,
        exp_id: str,
        results: dict | None = None,
        status: str = "completed",
    ) -> None:
        """Mark experiment as finished and store its final results.

        Unknown experiment IDs are ignored, with a warning in the log. An
        empty or missing ``results`` leaves the stored results untouched.
        """
        exp = self._index["experiments"].get(exp_id)
        if not exp:
            logger.warning("finish: unknown experiment %s; nothing updated", exp_id)
            return

        exp["status"] = status
        exp["finished_at"] = datetime.now().isoformat()
        if results:
            exp["results"] = results

        self._save_index()
        logger.info("Experiment %s %s", exp_id, status)

    def get_metrics(self, exp_id: str) -> list[dict]:
        """Load all logged metrics for an experiment.

        Returns an empty list when the experiment or its metrics file does not
        exist. Blank lines in the file are skipped.
        """
        exp = self._index["experiments"].get(exp_id)
        if not exp:
            return []

        metrics_path = self.dir / str(exp["metrics_file"])
        if not metrics_path.exists():
            return []

        with open(metrics_path, encoding="utf-8") as f:
            return [json.loads(line) for line in map(str.strip, f) if line]

    def list_experiments(self) -> list[dict]:
        """List all experiments with summary info, ordered by experiment ID.

        Each summary contains id, name, status, start time (truncated to
        seconds) and tags, plus any of fid/ssim/lpips/nme found in the results.
        """
        # Sort shorter IDs first so that "exp_999" precedes "exp_1000"; plain
        # string ordering would put the four-digit ID first.
        ordered = sorted(
            self._index["experiments"].items(),
            key=lambda item: (len(item[0]), item[0]),
        )

        experiments = []
        for exp_id, exp in ordered:
            summary = {
                "id": exp_id,
                "name": exp["name"],
                "status": exp["status"],
                "started": exp["started_at"][:19],
                "tags": exp.get("tags", []),
            }
            if exp["results"]:
                for key in ["fid", "ssim", "lpips", "nme"]:
                    if key in exp["results"]:
                        summary[key] = exp["results"][key]
            experiments.append(summary)
        return experiments

    def compare(self, exp_ids: list[str]) -> dict:
        """Compare multiple experiments by their results.

        Returns a mapping from each known ID to its name, config and results;
        unknown IDs are left out.
        """
        comparison = {}
        for exp_id in exp_ids:
            exp = self._index["experiments"].get(exp_id)
            if exp:
                comparison[exp_id] = {
                    "name": exp["name"],
                    "config": exp["config"],
                    "results": exp["results"],
                }
        return comparison

    @staticmethod
    def _format_score(value: object, spec: str = "") -> str:
        """Format a numeric result with *spec*; render anything else via ``str``."""
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return str(value)
        return format(value, spec)

    def print_summary(self) -> None:
        """Log a summary table of all experiments (ID, name, status, FID, SSIM, LPIPS)."""
        experiments = self.list_experiments()
        if not experiments:
            logger.info("No experiments found.")
            return

        # Header
        logger.info(
            "%s %s %s %s %s %s",
            "ID".ljust(10),
            "Name".ljust(20),
            "Status".ljust(12),
            "FID".rjust(6),
            "SSIM".rjust(6),
            "LPIPS".rjust(6),
        )
        logger.info("-" * 70)

        for exp in experiments:
            # Missing metrics show as "--"; non-numeric values (e.g. null in a
            # hand-edited index) are printed as-is instead of raising.
            fid = self._format_score(exp["fid"]) if "fid" in exp else "--"
            ssim = self._format_score(exp["ssim"], ".4f") if "ssim" in exp else "--"
            lpips = self._format_score(exp["lpips"], ".4f") if "lpips" in exp else "--"
            logger.info(
                "%s %s %s %s %s %s",
                exp["id"].ljust(10),
                exp["name"].ljust(20),
                exp["status"].ljust(12),
                fid.rjust(6),
                ssim.rjust(6),
                lpips.rjust(6),
            )

    def get_best(self, metric: str = "fid", lower_is_better: bool = True) -> str | None:
        """Get the experiment ID with the best value for a given metric.

        Only experiments whose status is "completed" and whose results contain
        ``metric`` are considered. Returns ``None`` when there is no candidate.
        """
        best_id = None
        best_val = float("inf") if lower_is_better else float("-inf")

        for exp_id, exp in self._index["experiments"].items():
            if exp["status"] != "completed":
                continue
            val = exp["results"].get(metric)
            if val is None:
                continue
            improved = val < best_val if lower_is_better else val > best_val
            if improved:
                best_val = val
                best_id = exp_id

        return best_id