landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,415 @@
|
|
|
1
|
+
"""Evaluation metrics suite.
|
|
2
|
+
|
|
3
|
+
All metrics stratified by Fitzpatrick skin type (I-VI) using ITA-based thresholding.
|
|
4
|
+
Primary metrics: FID, LPIPS, NME, ArcFace identity similarity.
|
|
5
|
+
Secondary: SSIM (relaxed target >0.80).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import cv2
|
|
17
|
+
except ImportError:
|
|
18
|
+
cv2 = None # type: ignore[assignment]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class EvalMetrics:
    """Aggregate evaluation metrics for a batch of generated images.

    Holds the five headline scores plus optional per-Fitzpatrick-type and
    per-procedure breakdowns. Every breakdown is a mapping from group label
    to the score for that group and defaults to an empty dict.
    """

    # Headline scores.
    fid: float = 0.0
    lpips: float = 0.0
    nme: float = 0.0  # Normalized Mean landmark Error
    identity_sim: float = 0.0  # ArcFace cosine similarity
    ssim: float = 0.0

    # Breakdown keyed by Fitzpatrick type label.
    fid_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    nme_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    lpips_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    ssim_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    identity_sim_by_fitzpatrick: dict[str, float] = field(default_factory=dict)
    count_by_fitzpatrick: dict[str, int] = field(default_factory=dict)

    # Breakdown keyed by procedure name.
    nme_by_procedure: dict[str, float] = field(default_factory=dict)
    lpips_by_procedure: dict[str, float] = field(default_factory=dict)
    ssim_by_procedure: dict[str, float] = field(default_factory=dict)

    def summary(self) -> str:
        """Return a multi-line, human-readable report of the metrics.

        The headline scores always appear. A per-type section is added when
        ``count_by_fitzpatrick`` is non-empty (only metrics present for a
        type are printed), and a FID-per-type section is added when
        ``fid_by_fitzpatrick`` is non-empty.
        """
        report = [
            f"FID: {self.fid:.2f}",
            f"LPIPS: {self.lpips:.4f}",
            f"NME: {self.nme:.4f}",
            f"Identity Sim: {self.identity_sim:.4f}",
            f"SSIM: {self.ssim:.4f}",
        ]

        if self.count_by_fitzpatrick:
            report.append("\nBy Fitzpatrick Type:")
            # (label, table) pairs in the order they are printed on each row.
            labelled_tables = (
                ("LPIPS", self.lpips_by_fitzpatrick),
                ("SSIM", self.ssim_by_fitzpatrick),
                ("NME", self.nme_by_fitzpatrick),
                ("ID", self.identity_sim_by_fitzpatrick),
            )
            for skin_type in sorted(self.count_by_fitzpatrick):
                count = self.count_by_fitzpatrick[skin_type]
                cells = [f" Type {skin_type} (n={count}):"]
                cells.extend(
                    f"{label}={table[skin_type]:.4f}"
                    for label, table in labelled_tables
                    if skin_type in table
                )
                report.append(" ".join(cells))

        if self.fid_by_fitzpatrick:
            report.append("\nFID by Fitzpatrick:")
            report.extend(
                f" Type {skin_type}: {score:.2f}"
                for skin_type, score in sorted(self.fid_by_fitzpatrick.items())
            )

        return "\n".join(report)

    def to_dict(self) -> dict:
        """Flatten the metrics into a single-level dict for JSON/CSV export.

        Per-type keys are ``fitz_<type>_<metric>`` for every type present in
        ``count_by_fitzpatrick``; per-procedure keys are
        ``proc_<procedure>_<metric>`` for every procedure present in
        ``nme_by_procedure``. Missing group values are exported as zero.
        """
        flat = {
            name: getattr(self, name)
            for name in ("fid", "lpips", "nme", "identity_sim", "ssim")
        }

        for skin_type in sorted(self.count_by_fitzpatrick):
            stem = f"fitz_{skin_type}"
            flat.update(
                {
                    f"{stem}_count": self.count_by_fitzpatrick.get(skin_type, 0),
                    f"{stem}_lpips": self.lpips_by_fitzpatrick.get(skin_type, 0.0),
                    f"{stem}_ssim": self.ssim_by_fitzpatrick.get(skin_type, 0.0),
                    f"{stem}_nme": self.nme_by_fitzpatrick.get(skin_type, 0.0),
                    f"{stem}_identity": self.identity_sim_by_fitzpatrick.get(skin_type, 0.0),
                }
            )

        for procedure in sorted(self.nme_by_procedure):
            flat.update(
                {
                    f"proc_{procedure}_nme": self.nme_by_procedure.get(procedure, 0.0),
                    f"proc_{procedure}_lpips": self.lpips_by_procedure.get(procedure, 0.0),
                    f"proc_{procedure}_ssim": self.ssim_by_procedure.get(procedure, 0.0),
                }
            )

        return flat
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def classify_fitzpatrick_ita(image: np.ndarray) -> str:
    """Classify Fitzpatrick skin type from the Individual Typology Angle (ITA).

    The image is converted with ``cv2.COLOR_BGR2LAB`` and the central half
    of the frame (middle 50% in each dimension) is averaged to avoid the
    background. The angle is computed as::

        ITA = arctan2(L - 50, b) * (180 / pi)

    where L is rescaled to 0-100 and b is re-centred around 0. Note that
    the two-argument ``arctan2`` is used rather than ``arctan((L - 50) / b)``.

    Thresholds from Chardon et al. (1991):
    - ITA > 55: Type I (very light)
    - 41 < ITA <= 55: Type II (light)
    - 28 < ITA <= 41: Type III (intermediate)
    - 10 < ITA <= 28: Type IV (tan)
    - -30 < ITA <= 10: Type V (brown)
    - ITA <= -30: Type VI (dark)

    Args:
        image: Image array passed to ``cv2.cvtColor`` as BGR.

    Returns:
        One of ``"I"`` .. ``"VI"``.

    Raises:
        ImportError: If OpenCV could not be imported.
    """
    if cv2 is None:
        raise ImportError("opencv-python is required for Fitzpatrick classification")

    lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # Keep only the middle of the frame, where the face is expected to be.
    height, width = image.shape[:2]
    row_window = slice(height // 4, 3 * height // 4)
    col_window = slice(width // 4, 3 * width // 4)
    patch = lab_image[row_window, col_window]

    lightness = patch[:, :, 0].mean() * 100 / 255  # scale to 0-100
    yellow_blue = patch[:, :, 2].mean() - 128  # center around 0

    # Nudge an (almost) zero b component to a tiny positive value.
    if abs(yellow_blue) < 1e-6:
        yellow_blue = 1e-6

    ita = np.arctan2(lightness - 50, yellow_blue) * (180 / np.pi)

    # Exclusive lower bounds, checked from lightest to darkest.
    lower_bounds = ((55, "I"), (41, "II"), (28, "III"), (10, "IV"), (-30, "V"))
    for bound, label in lower_bounds:
        if ita > bound:
            return label
    return "VI"
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def compute_nme(
|
|
140
|
+
pred_landmarks: np.ndarray,
|
|
141
|
+
target_landmarks: np.ndarray,
|
|
142
|
+
left_eye_idx: int = 33,
|
|
143
|
+
right_eye_idx: int = 263,
|
|
144
|
+
) -> float:
|
|
145
|
+
"""Compute Normalized Mean Error for landmarks.
|
|
146
|
+
|
|
147
|
+
Normalized by inter-ocular distance.
|
|
148
|
+
|
|
149
|
+
Args:
|
|
150
|
+
pred_landmarks: (N, 2) predicted landmark positions.
|
|
151
|
+
target_landmarks: (N, 2) ground truth positions.
|
|
152
|
+
left_eye_idx: MediaPipe index for left eye center.
|
|
153
|
+
right_eye_idx: MediaPipe index for right eye center.
|
|
154
|
+
|
|
155
|
+
Returns:
|
|
156
|
+
NME value (lower is better).
|
|
157
|
+
"""
|
|
158
|
+
iod = np.linalg.norm(target_landmarks[left_eye_idx] - target_landmarks[right_eye_idx])
|
|
159
|
+
if iod < 1.0:
|
|
160
|
+
iod = 1.0
|
|
161
|
+
|
|
162
|
+
distances = np.linalg.norm(pred_landmarks - target_landmarks, axis=1)
|
|
163
|
+
return float(np.mean(distances) / iod)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def compute_ssim(
|
|
167
|
+
pred: np.ndarray,
|
|
168
|
+
target: np.ndarray,
|
|
169
|
+
) -> float:
|
|
170
|
+
"""Compute Structural Similarity Index (SSIM).
|
|
171
|
+
|
|
172
|
+
Uses scikit-image's windowed SSIM (Wang et al. 2004) for proper
|
|
173
|
+
per-window computation with 11x11 Gaussian kernel.
|
|
174
|
+
"""
|
|
175
|
+
try:
|
|
176
|
+
from skimage.metrics import structural_similarity
|
|
177
|
+
|
|
178
|
+
# Convert to grayscale if color, or compute per-channel
|
|
179
|
+
if pred.ndim == 3 and pred.shape[2] == 3:
|
|
180
|
+
return float(structural_similarity(pred, target, channel_axis=2, data_range=255))
|
|
181
|
+
else:
|
|
182
|
+
return float(structural_similarity(pred, target, data_range=255))
|
|
183
|
+
except ImportError:
|
|
184
|
+
# Fallback: simple global SSIM (not publication-quality)
|
|
185
|
+
pred_f = pred.astype(np.float64)
|
|
186
|
+
target_f = target.astype(np.float64)
|
|
187
|
+
|
|
188
|
+
mu_p = np.mean(pred_f)
|
|
189
|
+
mu_t = np.mean(target_f)
|
|
190
|
+
sigma_p = np.std(pred_f)
|
|
191
|
+
sigma_t = np.std(target_f)
|
|
192
|
+
sigma_pt = np.mean((pred_f - mu_p) * (target_f - mu_t))
|
|
193
|
+
|
|
194
|
+
C1 = (0.01 * 255) ** 2
|
|
195
|
+
C2 = (0.03 * 255) ** 2
|
|
196
|
+
|
|
197
|
+
ssim_val = ((2 * mu_p * mu_t + C1) * (2 * sigma_pt + C2)) / (
|
|
198
|
+
(mu_p**2 + mu_t**2 + C1) * (sigma_p**2 + sigma_t**2 + C2)
|
|
199
|
+
)
|
|
200
|
+
return float(ssim_val)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# Module-level caches for lazily created models; both start out unset.
_LPIPS_FN = None
_ARCFACE_APP = None


def _get_lpips_fn() -> Any:
    """Return the shared LPIPS model, creating it on first use.

    The model is built as ``lpips.LPIPS(net="alex", verbose=False)``,
    stored in the module-level ``_LPIPS_FN`` and switched to eval mode.
    Later calls return the stored object without importing ``lpips`` again.
    """
    global _LPIPS_FN
    if _LPIPS_FN is not None:
        return _LPIPS_FN

    # Imported lazily so the module loads without the optional dependency.
    import lpips

    _LPIPS_FN = lpips.LPIPS(net="alex", verbose=False)
    _LPIPS_FN.eval()
    return _LPIPS_FN
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def compute_lpips(
    pred: np.ndarray,
    target: np.ndarray,
) -> float:
    """Compute the LPIPS perceptual distance between two images.

    Each image is divided by 255, rearranged from (H, W, C) to a
    (1, C, H, W) tensor and rescaled to [-1, 1] before being passed to the
    shared LPIPS model. Channel order is passed through unchanged.

    Args:
        pred: Predicted image, indexed as (H, W, C).
        target: Reference image, indexed as (H, W, C).

    Returns:
        LPIPS score (lower = more similar), or NaN when ``lpips`` or
        ``torch`` cannot be imported.
    """
    try:
        import lpips  # noqa: F401
        import torch
    except ImportError:
        return float("nan")

    model = _get_lpips_fn()

    def _as_model_input(image: np.ndarray) -> torch.Tensor:
        unit_range = image.astype(np.float32) / 255.0
        batch = torch.from_numpy(unit_range).permute(2, 0, 1).unsqueeze(0)
        return batch * 2 - 1  # LPIPS expects [-1, 1]

    with torch.no_grad():
        distance = model(_as_model_input(pred), _as_model_input(target))
    return float(distance.item())
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def compute_fid(
    real_dir: str,
    generated_dir: str,
) -> float:
    """Compute FID between a directory of real and one of generated images.

    Delegates to ``torch_fidelity.calculate_metrics`` and enables CUDA when
    ``torch.cuda.is_available()`` reports a device.

    Args:
        real_dir: Path to directory of real images.
        generated_dir: Path to directory of generated images.

    Returns:
        FID score (lower = more similar distributions).

    Raises:
        ImportError: If ``torch_fidelity`` cannot be imported.
    """
    try:
        from torch_fidelity import calculate_metrics
    except ImportError:
        # Re-raise with install instructions and without the original context.
        raise ImportError(
            "torch-fidelity is required for FID. Install with: pip install torch-fidelity"
        ) from None

    import torch

    use_gpu = torch.cuda.is_available()
    result = calculate_metrics(
        input1=generated_dir,
        input2=real_dir,
        cuda=use_gpu,
        fid=True,
        verbose=False,
    )
    fid_value = result["frechet_inception_distance"]
    return float(fid_value)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def compute_identity_similarity(
    pred: np.ndarray,
    target: np.ndarray,
) -> float:
    """Compute ArcFace identity cosine similarity between two face images.

    Uses a lazily created InsightFace ``FaceAnalysis`` app (``buffalo_l``)
    that is cached in the module-level ``_ARCFACE_APP``. The embedding of
    the first detected face in each image is compared and the cosine
    similarity is clipped to [0, 1].

    Falls back to ``compute_ssim(pred, target)`` as a proxy when
    InsightFace is unavailable, when no face is detected in either image,
    or when the ArcFace path fails for any other reason. Failures are
    logged rather than silently discarded, because the returned number is
    then no longer an identity score.

    Args:
        pred: Predicted face image, indexed as (H, W, C).
        target: Reference face image, indexed as (H, W, C).

    Returns:
        Cosine similarity in [0, 1] where 1 = identical identity, or the
        SSIM proxy value on fallback.
    """
    import logging

    global _ARCFACE_APP
    log = logging.getLogger(__name__)

    try:
        from insightface.app import FaceAnalysis

        if _ARCFACE_APP is None:
            app = FaceAnalysis(
                name="buffalo_l",
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            app.prepare(ctx_id=-1, det_size=(320, 320))
            # Cache only after prepare() succeeded, so a failed setup is
            # retried on the next call instead of leaving a broken singleton.
            _ARCFACE_APP = app
        app = _ARCFACE_APP

        # NOTE(review): three-channel inputs are passed through untouched;
        # anything else goes through COLOR_RGB2BGR. Confirm which channel
        # layouts callers really supply.
        pred_bgr = pred if pred.shape[2] == 3 else cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        target_bgr = target if target.shape[2] == 3 else cv2.cvtColor(target, cv2.COLOR_RGB2BGR)

        pred_faces = app.get(pred_bgr)
        target_faces = app.get(target_bgr)

        if pred_faces and target_faces:
            pred_emb = pred_faces[0].embedding
            target_emb = target_faces[0].embedding
            # Small epsilon keeps the division defined for zero-norm embeddings.
            sim = np.dot(pred_emb, target_emb) / (
                np.linalg.norm(pred_emb) * np.linalg.norm(target_emb) + 1e-8
            )
            return float(np.clip(sim, 0, 1))
        log.debug("No face detected in one or both images; using SSIM proxy")
    except ImportError:
        # Optional dependency missing: the documented best-effort fallback.
        log.debug("insightface is unavailable; using SSIM proxy", exc_info=True)
    except Exception:
        # Still best-effort, but make the unexpected failure visible.
        log.warning("ArcFace identity similarity failed; using SSIM proxy", exc_info=True)

    # Fallback: SSIM-based proxy
    return compute_ssim(pred, target)
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def evaluate_batch(
    predictions: list[np.ndarray],
    targets: list[np.ndarray],
    pred_landmarks: list[np.ndarray] | None = None,
    target_landmarks: list[np.ndarray] | None = None,
    procedures: list[str] | None = None,
    compute_identity: bool = False,
) -> EvalMetrics:
    """Evaluate a batch of predicted images against their targets.

    SSIM and LPIPS are computed for every pair. NME is computed only when
    both landmark lists are given, and identity similarity only when
    ``compute_identity`` is true. Scores are averaged with ``np.nanmean``
    over the whole batch, per Fitzpatrick type (classified on the target
    image, when OpenCV is available) and per procedure.

    FID is not computed here; the ``fid`` fields keep their defaults.

    Args:
        predictions: List of predicted BGR images.
        targets: List of target BGR images.
        pred_landmarks: Optional list of (N, 2) predicted landmark arrays.
        target_landmarks: Optional list of (N, 2) target landmark arrays.
        procedures: Optional list of procedure names for per-procedure breakdown.
        compute_identity: Whether to compute ArcFace identity similarity (slow).

    Returns:
        EvalMetrics with all computed values.
    """
    with_landmarks = pred_landmarks is not None and target_landmarks is not None

    ssim_scores = []
    lpips_scores = []
    nme_scores = []
    identity_scores = []
    members_by_type: dict[str, list[int]] = {}
    members_by_procedure: dict[str, list[int]] = {}

    for idx, prediction in enumerate(predictions):
        target = targets[idx]

        ssim_scores.append(compute_ssim(prediction, target))
        lpips_scores.append(compute_lpips(prediction, target))

        if with_landmarks:
            nme_scores.append(compute_nme(pred_landmarks[idx], target_landmarks[idx]))

        if compute_identity:
            identity_scores.append(compute_identity_similarity(prediction, target))

        # Skin-type grouping is best-effort: a failed classification just
        # leaves the sample out of the per-type breakdown.
        if cv2 is not None:
            try:
                skin_type = classify_fitzpatrick_ita(target)
                members_by_type.setdefault(skin_type, []).append(idx)
            except Exception:
                pass

        if procedures is not None and idx < len(procedures):
            members_by_procedure.setdefault(procedures[idx], []).append(idx)

    def _batch_mean(scores: list) -> float:
        # NaN-aware mean over the whole batch; 0.0 when nothing was scored.
        return float(np.nanmean(scores)) if scores else 0.0

    def _group_mean(scores: list, members: list):
        # NaN-aware mean over the group's samples; None when none were scored.
        picked = [scores[i] for i in members if i < len(scores)]
        return float(np.nanmean(picked)) if picked else None

    metrics = EvalMetrics(
        ssim=_batch_mean(ssim_scores),
        lpips=_batch_mean(lpips_scores),
        nme=_batch_mean(nme_scores),
        identity_sim=_batch_mean(identity_scores),
    )

    # Per-Fitzpatrick breakdown of every per-sample metric.
    for skin_type, members in members_by_type.items():
        metrics.count_by_fitzpatrick[skin_type] = len(members)
        for scores, table in (
            (lpips_scores, metrics.lpips_by_fitzpatrick),
            (ssim_scores, metrics.ssim_by_fitzpatrick),
            (nme_scores, metrics.nme_by_fitzpatrick),
            (identity_scores, metrics.identity_sim_by_fitzpatrick),
        ):
            value = _group_mean(scores, members)
            if value is not None:
                table[skin_type] = value

    # Per-procedure breakdown.
    for procedure, members in members_by_procedure.items():
        for scores, table in (
            (lpips_scores, metrics.lpips_by_procedure),
            (ssim_scores, metrics.ssim_by_procedure),
            (nme_scores, metrics.nme_by_procedure),
        ):
            value = _group_mean(scores, members)
            if value is not None:
                table[procedure] = value

    return metrics
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
"""Local experiment tracker for training reproducibility.
|
|
2
|
+
|
|
3
|
+
Tracks all training runs with their configs, metrics, and results.
|
|
4
|
+
Each experiment gets a unique ID and timestamp.
|
|
5
|
+
|
|
6
|
+
Usage::
|
|
7
|
+
|
|
8
|
+
tracker = ExperimentTracker("experiments/")
|
|
9
|
+
|
|
10
|
+
# Start a new experiment
|
|
11
|
+
exp_id = tracker.start(
|
|
12
|
+
name="phaseA_v2",
|
|
13
|
+
config={
|
|
14
|
+
"phase": "A", "lr": 1e-5, "batch": 4,
|
|
15
|
+
"steps": 100000, "data": "training_combined",
|
|
16
|
+
},
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Log metrics during training
|
|
20
|
+
tracker.log_metric(exp_id, step=1000, loss=0.045, ssim=0.82)
|
|
21
|
+
|
|
22
|
+
# Record final results
|
|
23
|
+
tracker.finish(exp_id, results={"fid": 42.3, "ssim": 0.87})
|
|
24
|
+
|
|
25
|
+
# List all experiments
|
|
26
|
+
tracker.list_experiments()
|
|
27
|
+
|
|
28
|
+
# Compare experiments
|
|
29
|
+
tracker.compare(["exp_001", "exp_002"])
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from __future__ import annotations
|
|
33
|
+
|
|
34
|
+
import json
|
|
35
|
+
import logging
|
|
36
|
+
import os
|
|
37
|
+
import socket
|
|
38
|
+
import time
|
|
39
|
+
from datetime import datetime
|
|
40
|
+
from pathlib import Path
|
|
41
|
+
|
|
42
|
+
logger = logging.getLogger(__name__)


class ExperimentTracker:
    """Simple file-based experiment tracker.

    All experiments are recorded in ``index.json`` inside the experiments
    directory. Each experiment additionally gets a ``<id>_metrics.jsonl``
    file holding one JSON object per logged step.
    """

    def __init__(self, experiments_dir: str = "experiments"):
        """Open (creating it if needed) the tracker rooted at *experiments_dir*."""
        self.dir = Path(experiments_dir)
        self.dir.mkdir(parents=True, exist_ok=True)
        self._index_path = self.dir / "index.json"
        self._index = self._load_index()

    def _load_index(self) -> dict:
        """Read the index file, or return a fresh index when none exists."""
        if self._index_path.exists():
            with open(self._index_path, encoding="utf-8") as f:
                index = json.load(f)
            # Tolerate an index that lacks one of the two expected keys.
            index.setdefault("experiments", {})
            index.setdefault("counter", 0)
            return index
        return {"experiments": {}, "counter": 0}

    def _save_index(self) -> None:
        """Persist the index to disk.

        The data is written to a temporary sibling file and then moved over
        the real index, so an interrupted write cannot leave a truncated
        ``index.json`` behind.
        """
        tmp_path = self._index_path.with_name(self._index_path.name + ".tmp")
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self._index, f, indent=2)
        os.replace(tmp_path, self._index_path)

    @staticmethod
    def _format_cell(summary: dict, key: str, spec: str) -> str:
        """Format ``summary[key]`` with *spec*; ``"--"`` when the key is absent.

        Values that do not support the format spec (``None``, strings, ...)
        are rendered with ``str()`` instead of raising.
        """
        if key not in summary:
            return "--"
        value = summary[key]
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            return str(value)

    def start(
        self,
        name: str,
        config: dict,
        tags: list[str] | None = None,
    ) -> str:
        """Start a new experiment. Returns experiment ID.

        Args:
            name: Human-readable experiment name.
            config: Configuration to store with the experiment; it is
                written to the JSON index.
            tags: Optional list of tags.

        Returns:
            The new ID, ``exp_<counter>`` zero-padded to three digits.
        """
        self._index["counter"] += 1
        exp_id = f"exp_{self._index['counter']:03d}"

        exp = {
            "id": exp_id,
            "name": name,
            "config": config,
            "tags": tags or [],
            "status": "running",
            "started_at": datetime.now().isoformat(),
            "finished_at": None,
            "hostname": socket.gethostname(),
            "slurm_job_id": os.environ.get("SLURM_JOB_ID"),
            "gpu": os.environ.get("CUDA_VISIBLE_DEVICES"),
            "results": {},
            "metrics_file": f"{exp_id}_metrics.jsonl",
        }

        self._index["experiments"][exp_id] = exp
        self._save_index()

        # Create metrics log file
        metrics_path = self.dir / str(exp["metrics_file"])
        metrics_path.touch()

        logger.info("Experiment started: %s (%s)", exp_id, name)
        return exp_id

    def log_metric(self, exp_id: str, step: int | None = None, **metrics) -> None:
        """Append one metrics record for a training step.

        The record holds the current ``time.time()`` timestamp, *step* and
        every keyword argument. Unknown experiment IDs are ignored, with a
        warning.
        """
        exp = self._index["experiments"].get(exp_id)
        if not exp:
            logger.warning("log_metric: unknown experiment %s; nothing logged", exp_id)
            return

        entry = {
            "timestamp": time.time(),
            "step": step,
            **metrics,
        }

        metrics_path = self.dir / str(exp["metrics_file"])
        with open(metrics_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")

    def finish(
        self,
        exp_id: str,
        results: dict | None = None,
        status: str = "completed",
    ) -> None:
        """Mark an experiment as finished and store its final results.

        Args:
            exp_id: Experiment to update; unknown IDs are ignored, with a
                warning.
            results: Final results; when empty or ``None`` the stored
                results are left untouched.
            status: Status string to record.
        """
        exp = self._index["experiments"].get(exp_id)
        if not exp:
            logger.warning("finish: unknown experiment %s; nothing updated", exp_id)
            return

        exp["status"] = status
        exp["finished_at"] = datetime.now().isoformat()
        if results:
            exp["results"] = results

        self._save_index()
        logger.info("Experiment %s %s", exp_id, status)

    def get_metrics(self, exp_id: str) -> list[dict]:
        """Load all logged metrics for an experiment.

        Blank lines are skipped. Lines that are not valid JSON (for
        example a record truncated by a crash mid-write) are skipped with a
        warning instead of aborting the whole load.

        Returns:
            The parsed records in file order; an empty list when the
            experiment or its metrics file does not exist.
        """
        exp = self._index["experiments"].get(exp_id)
        if not exp:
            return []

        metrics_path = self.dir / str(exp["metrics_file"])
        if not metrics_path.exists():
            return []

        entries = []
        with open(metrics_path, encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    entries.append(json.loads(line))
                except json.JSONDecodeError:
                    logger.warning(
                        "Skipping malformed metrics line %d in %s", line_no, metrics_path
                    )
        return entries

    def list_experiments(self) -> list[dict]:
        """List all experiments with summary info.

        Each summary holds id, name, status, the first 19 characters of
        the start timestamp and the tags, plus whichever of
        ``fid``/``ssim``/``lpips``/``nme`` are present in the results.
        """
        experiments = []
        for exp_id, exp in sorted(self._index["experiments"].items()):
            summary = {
                "id": exp_id,
                "name": exp["name"],
                "status": exp["status"],
                "started": exp["started_at"][:19],
                "tags": exp.get("tags", []),
            }
            results = exp["results"]
            if results:
                for key in ("fid", "ssim", "lpips", "nme"):
                    if key in results:
                        summary[key] = results[key]
            experiments.append(summary)
        return experiments

    def compare(self, exp_ids: list[str]) -> dict:
        """Compare multiple experiments by their results.

        Returns a mapping from each known ID to its name, config and
        results; unknown IDs are left out.
        """
        comparison = {}
        for exp_id in exp_ids:
            exp = self._index["experiments"].get(exp_id)
            if exp:
                comparison[exp_id] = {
                    "name": exp["name"],
                    "config": exp["config"],
                    "results": exp["results"],
                }
        return comparison

    def print_summary(self) -> None:
        """Log a summary table of all experiments at INFO level."""
        experiments = self.list_experiments()
        if not experiments:
            logger.info("No experiments found.")
            return

        row_format = "%s %s %s %s %s %s"

        # Header
        logger.info(
            row_format,
            "ID".ljust(10),
            "Name".ljust(20),
            "Status".ljust(12),
            "FID".rjust(6),
            "SSIM".rjust(6),
            "LPIPS".rjust(6),
        )
        logger.info("-" * 70)

        for exp in experiments:
            # _format_cell never raises, so one odd result value (None, a
            # string, ...) cannot break the whole table.
            fid = self._format_cell(exp, "fid", "")
            ssim = self._format_cell(exp, "ssim", ".4f")
            lpips = self._format_cell(exp, "lpips", ".4f")
            logger.info(
                row_format,
                exp["id"].ljust(10),
                exp["name"].ljust(20),
                exp["status"].ljust(12),
                fid.rjust(6),
                ssim.rjust(6),
                lpips.rjust(6),
            )

    def get_best(self, metric: str = "fid", lower_is_better: bool = True) -> str | None:
        """Get the experiment ID with the best value for a given metric.

        Only experiments whose status is ``"completed"`` and whose results
        contain a non-``None`` value for *metric* are considered. Values
        that cannot be compared numerically are skipped.

        Returns:
            The best experiment's ID, or ``None`` when no experiment
            qualifies.
        """
        best_id = None
        best_val = float("inf") if lower_is_better else float("-inf")

        for exp_id, exp in self._index["experiments"].items():
            if exp["status"] != "completed":
                continue
            val = exp["results"].get(metric)
            if val is None:
                continue
            try:
                improved = val < best_val if lower_is_better else val > best_val
            except TypeError:
                # Non-numeric result (e.g. a string): cannot be ranked.
                continue
            if improved:
                best_val = val
                best_id = exp_id

        return best_id
|