landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/masking.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
1
|
+
"""Surgical mask generation with morphological dilation and Gaussian feathering.
|
|
2
|
+
|
|
3
|
+
Procedural masks (not SAM2) -- deterministic, no model dependency.
|
|
4
|
+
Feathered boundaries prevent visible seams in ControlNet inpainting.
|
|
5
|
+
Supports clinical edge cases (vitiligo preservation, keloid softening).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from typing import TYPE_CHECKING
|
|
11
|
+
|
|
12
|
+
import cv2
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
from landmarkdiff.landmarks import FaceLandmarks
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from landmarkdiff.clinical import ClinicalFlags
|
|
19
|
+
|
|
20
|
+
# Tunables for the boundary-noise and feathering steps of generate_surgical_mask.
_BOUNDARY_KERNEL_SIZE = 5  # side length (px) of the square kernel that extracts the boundary band
_BOUNDARY_NOISE_MAX = 4  # exclusive upper bound of the random integer noise levels
_BOUNDARY_NOISE_SCALE = 64  # multiplier that maps noise levels onto uint8 intensities
_GAUSSIAN_KERNEL_FACTOR = 6  # Gaussian kernel size = this factor * feather_sigma (forced odd)

# Per-procedure mask settings, keyed by procedure name.
#   landmark_indices: row indices into ``face.landmarks`` whose convex hull seeds the mask
#                     (presumably MediaPipe Face Mesh indices -- confirm against landmarks.py)
#   dilation_px:      radius (px) of the elliptical dilation applied to the hull
#   feather_sigma:    Gaussian sigma used to feather the mask edge
MASK_CONFIG: dict[str, dict] = {
    "rhinoplasty": {
        "landmark_indices": [
            1, 2, 4, 5, 6, 19, 94, 141, 168, 195,
            197, 236, 240, 274, 275, 278, 279, 294, 326, 327,
            360, 363, 370, 456, 460,
        ],
        "dilation_px": 30,
        "feather_sigma": 15.0,
    },
    "blepharoplasty": {
        "landmark_indices": [
            33, 7, 163, 144, 145, 153, 154, 155, 157, 158,
            159, 160, 161, 246, 362, 382, 381, 380, 374, 373,
            390, 249, 263, 466, 388, 387, 386, 385, 384, 398,
        ],
        "dilation_px": 15,
        "feather_sigma": 10.0,
    },
    "rhytidectomy": {
        "landmark_indices": [
            10, 21, 54, 58, 67, 93, 103, 109, 127, 132,
            136, 150, 162, 172, 176, 187, 207, 213, 234, 284,
            297, 323, 332, 338, 356, 361, 365, 379, 389, 397,
            400, 427, 454,
        ],
        "dilation_px": 40,
        "feather_sigma": 20.0,
    },
    "orthognathic": {
        "landmark_indices": [
            0, 17, 18, 36, 37, 39, 40, 57, 61, 78,
            80, 81, 82, 84, 87, 88, 91, 95, 146, 167,
            169, 170, 175, 181, 191, 200, 201, 202, 204, 208,
            211, 212, 214,
        ],
        "dilation_px": 35,
        "feather_sigma": 18.0,
    },
    "brow_lift": {
        "landmark_indices": [
            70, 63, 105, 66, 107,  # left brow
            300, 293, 334, 296, 336,  # right brow
            9, 8, 10,  # forehead midline
            109, 67, 103,  # upper face left
            338, 297, 332,  # upper face right
        ],
        "dilation_px": 25,
        "feather_sigma": 15.0,
    },
    "mentoplasty": {
        "landmark_indices": [148, 149, 150, 152, 171, 175, 176, 377],
        "dilation_px": 25,
        "feather_sigma": 12.0,
    },
}
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def generate_surgical_mask(
    face: FaceLandmarks,
    procedure: str,
    width: int | None = None,
    height: int | None = None,
    clinical_flags: ClinicalFlags | None = None,
    image: np.ndarray | None = None,
    seed: int | None = None,
) -> np.ndarray:
    """Generate a feathered surgical mask for a procedure.

    Pipeline:
    1. Fill the convex hull of the procedure-specific landmarks.
    2. Dilate the hull with an elliptical kernel of radius ``dilation_px``.
    3. Add uniform random noise in a thin band around the mask boundary so
       the edge is not perfectly clean (this is plain integer noise, not
       Perlin noise).
    4. Gaussian-blur the result to obtain a smooth alpha gradient.
    5. Optionally apply clinical adjustments (vitiligo, keloid).

    Args:
        face: Extracted facial landmarks. ``face.landmarks[:, :2]`` is scaled
            by the mask width/height, so it is assumed to hold normalized
            x/y coordinates.
        procedure: Procedure name; must be a key of ``MASK_CONFIG``
            (e.g. "rhinoplasty").
        width: Mask width (defaults to face.image_width).
        height: Mask height (defaults to face.image_height).
        clinical_flags: Optional clinical edge-case flags (vitiligo, keloid).
        image: Original BGR image. The vitiligo adjustment is skipped when
            this is None, even if ``clinical_flags.vitiligo`` is set.
        seed: Optional seed for the boundary noise. ``None`` (the default)
            draws fresh entropy on every call; pass an int to make the mask
            reproducible.

    Returns:
        Float32 mask array [0.0-1.0] with feathered boundaries.

    Raises:
        ValueError: If ``procedure`` is not a key of ``MASK_CONFIG``.
    """
    if procedure not in MASK_CONFIG:
        raise ValueError(f"Unknown procedure: {procedure}. Choose from {list(MASK_CONFIG)}")

    config = MASK_CONFIG[procedure]
    w = width or face.image_width
    h = height or face.image_height

    # Scale landmark x/y to pixel coordinates and pick the procedure's points.
    coords = face.landmarks[:, :2].copy()
    coords[:, 0] *= w
    coords[:, 1] *= h
    pts = coords[config["landmark_indices"]].astype(np.int32)

    # Binary (0/255) mask from the convex hull of the selected landmarks.
    binary = np.zeros((h, w), dtype=np.uint8)
    hull = cv2.convexHull(pts)
    cv2.fillConvexPoly(binary, hull, 255)

    # Morphological dilation grows the hull by ``dilation_px`` in every direction.
    dilation = config["dilation_px"]
    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE,
        (2 * dilation + 1, 2 * dilation + 1),
    )
    dilated = cv2.dilate(binary, kernel)

    # Boundary band = (mask grown slightly) minus (mask shrunk slightly).
    edge_kernel = np.ones((_BOUNDARY_KERNEL_SIZE, _BOUNDARY_KERNEL_SIZE), np.uint8)
    boundary = cv2.subtract(
        cv2.dilate(dilated, edge_kernel),
        cv2.erode(dilated, edge_kernel),
    )

    # Noise levels are drawn from [0, _BOUNDARY_NOISE_MAX); the upper bound is
    # exclusive, so level * _BOUNDARY_NOISE_SCALE stays within uint8 range.
    rng = np.random.default_rng(seed)
    noise = rng.integers(0, _BOUNDARY_NOISE_MAX, size=(h, w), dtype=np.uint8)
    # ``boundary`` is 0 or 255, so the AND keeps the noise only inside the band.
    noise_boundary = cv2.bitwise_and(boundary, noise * _BOUNDARY_NOISE_SCALE)
    # cv2.add saturates at 255 for uint8 input, so no extra clipping is needed.
    dilated = cv2.add(dilated, noise_boundary)

    # Gaussian feathering
    sigma = config["feather_sigma"]
    ksize = int(_GAUSSIAN_KERNEL_FACTOR * sigma) | 1  # ensure odd
    feathered = cv2.GaussianBlur(
        dilated.astype(np.float32) / 255.0,
        (ksize, ksize),
        sigma,
    )

    mask = np.clip(feathered, 0.0, 1.0)

    # Clinical edge case adjustments
    if clinical_flags is not None:
        # Vitiligo: reduce mask over depigmented patches to preserve them
        if clinical_flags.vitiligo and image is not None:
            from landmarkdiff.clinical import adjust_mask_for_vitiligo, detect_vitiligo_patches

            patches = detect_vitiligo_patches(image, face)
            mask = adjust_mask_for_vitiligo(mask, patches)

        # Keloid: soften transitions in keloid-prone regions
        if clinical_flags.keloid_prone and clinical_flags.keloid_regions:
            from landmarkdiff.clinical import adjust_mask_for_keloid, get_keloid_exclusion_mask

            keloid_mask = get_keloid_exclusion_mask(
                face,
                clinical_flags.keloid_regions,
                w,
                h,
            )
            mask = adjust_mask_for_keloid(mask, keloid_mask)

    return mask
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def mask_to_3channel(mask: np.ndarray) -> np.ndarray:
    """Replicate a single-channel mask into three identical channels.

    The copies are stacked along a new trailing axis, so an ``(H, W)`` mask
    becomes ``(H, W, 3)``, ready to be multiplied against a colour image.
    """
    channels = (mask,) * 3
    return np.stack(channels, axis=-1)
|
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
"""Metrics aggregation across checkpoints, experiments, and procedures.
|
|
2
|
+
|
|
3
|
+
Collects evaluation results from multiple sources and computes aggregate
|
|
4
|
+
statistics, confidence intervals, and significance tests for paper reporting.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from landmarkdiff.metrics_agg import MetricsAggregator
|
|
8
|
+
|
|
9
|
+
agg = MetricsAggregator()
|
|
10
|
+
agg.add("baseline", "rhinoplasty", {"ssim": 0.82, "lpips": 0.18})
|
|
11
|
+
agg.add("ours", "rhinoplasty", {"ssim": 0.91, "lpips": 0.09})
|
|
12
|
+
print(agg.summary_table())
|
|
13
|
+
print(agg.improvement_over("baseline"))
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import json
|
|
19
|
+
import math
|
|
20
|
+
from dataclasses import dataclass, field
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
class MetricRecord:
    """One evaluation result: a set of metric values for an experiment/procedure pair.

    Attributes:
        experiment: Name of the experiment the result belongs to.
        procedure: Name of the procedure the result was measured on.
        metrics: Mapping of metric name to its value.
        checkpoint_step: Optional training step of the evaluated checkpoint.
        metadata: Free-form extra information; a fresh dict per record.
    """

    experiment: str
    procedure: str
    metrics: dict[str, float]
    checkpoint_step: int | None = None
    # default_factory gives every record its own dict instead of a shared one
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class MetricsAggregator:
    """Aggregate and analyze evaluation metrics.

    Records are kept in insertion order in ``self.records``. Each record
    carries an experiment name, a procedure name and a dict of metric values;
    adding several records for the same experiment/procedure pair is what
    makes the standard deviation and confidence interval meaningful.
    """

    # Direction of improvement per metric (True = larger is better).
    # Metrics not listed here are treated as higher-is-better.
    HIGHER_BETTER = {
        "ssim": True,
        "psnr": True,
        "identity_sim": True,
        "lpips": False,
        "fid": False,
        "nme": False,
    }

    def __init__(self) -> None:
        self.records: list[MetricRecord] = []

    def add(
        self,
        experiment: str,
        procedure: str,
        metrics: dict[str, float],
        checkpoint_step: int | None = None,
        **metadata: Any,
    ) -> None:
        """Add a single evaluation record.

        Any extra keyword arguments are stored as the record's metadata.
        """
        self.records.append(
            MetricRecord(
                experiment=experiment,
                procedure=procedure,
                metrics=metrics,
                checkpoint_step=checkpoint_step,
                metadata=metadata,
            )
        )

    def add_batch(
        self,
        experiment: str,
        records: list[dict[str, Any]],
    ) -> None:
        """Add multiple records for an experiment.

        Each record dict should have 'procedure' and metric keys. A missing
        'procedure' defaults to "all"; only int/float values are kept as
        metrics, everything else in the dict is ignored.
        """
        for rec in records:
            proc = rec.get("procedure", "all")
            metrics = {
                k: v for k, v in rec.items() if k != "procedure" and isinstance(v, (int, float))
            }
            self.add(experiment, proc, metrics)

    @property
    def experiments(self) -> list[str]:
        """Unique experiment names in insertion order."""
        # dict.fromkeys de-duplicates while preserving first-seen order
        return list(dict.fromkeys(r.experiment for r in self.records))

    @property
    def procedures(self) -> list[str]:
        """Unique procedure names in insertion order."""
        return list(dict.fromkeys(r.procedure for r in self.records))

    @property
    def metric_names(self) -> list[str]:
        """All unique metric names, sorted alphabetically."""
        names: set[str] = set()
        for r in self.records:
            names.update(r.metrics.keys())
        return sorted(names)

    def filter(
        self,
        experiment: str | None = None,
        procedure: str | None = None,
    ) -> list[MetricRecord]:
        """Filter records by experiment and/or procedure.

        A ``None`` criterion matches everything. With both criteria ``None``
        the internal record list itself is returned (not a copy).
        """
        results = self.records
        if experiment is not None:
            results = [r for r in results if r.experiment == experiment]
        if procedure is not None:
            results = [r for r in results if r.procedure == procedure]
        return results

    def _values(
        self,
        experiment: str,
        metric: str,
        procedure: str | None = None,
    ) -> list[float]:
        """Collect the values of ``metric`` from all matching records that report it."""
        recs = self.filter(experiment=experiment, procedure=procedure)
        return [r.metrics[metric] for r in recs if metric in r.metrics]

    def mean(
        self,
        experiment: str,
        metric: str,
        procedure: str | None = None,
    ) -> float:
        """Compute mean of a metric for an experiment (NaN when there are no values)."""
        vals = self._values(experiment, metric, procedure)
        if not vals:
            return float("nan")
        return sum(vals) / len(vals)

    def std(
        self,
        experiment: str,
        metric: str,
        procedure: str | None = None,
    ) -> float:
        """Compute the sample standard deviation (n - 1 denominator) of a metric.

        Returns 0.0 when fewer than two values are available.
        """
        vals = self._values(experiment, metric, procedure)
        if len(vals) < 2:
            return 0.0
        m = sum(vals) / len(vals)
        var = sum((v - m) ** 2 for v in vals) / (len(vals) - 1)
        return math.sqrt(var)

    def ci_95(
        self,
        experiment: str,
        metric: str,
        procedure: str | None = None,
    ) -> tuple[float, float]:
        """Compute 95% confidence interval (mean +/- 1.96*SE).

        Returns (NaN, NaN) without values and the degenerate interval
        (mean, mean) for a single value.
        """
        vals = self._values(experiment, metric, procedure)
        if not vals:
            return (float("nan"), float("nan"))
        n = len(vals)
        m = sum(vals) / n
        if n < 2:
            return (m, m)
        var = sum((v - m) ** 2 for v in vals) / (n - 1)
        se = math.sqrt(var / n)
        return (m - 1.96 * se, m + 1.96 * se)

    def improvement_over(
        self,
        baseline: str,
        metric: str | None = None,
    ) -> dict[str, dict[str, float]]:
        """Compute relative improvement of all experiments over a baseline.

        The sign is normalized with ``HIGHER_BETTER`` so that a positive
        percentage always means "better than the baseline". Metrics whose
        baseline or experiment mean is missing, or whose baseline mean is 0,
        are omitted.

        Returns:
            {experiment: {metric: relative_improvement_pct}}, percentages
            rounded to two decimals; the baseline itself is not included.
        """
        metrics = [metric] if metric else self.metric_names
        result: dict[str, dict[str, float]] = {}

        for exp in self.experiments:
            if exp == baseline:
                continue
            improvements: dict[str, float] = {}
            for m in metrics:
                base_val = self.mean(baseline, m)
                exp_val = self.mean(exp, m)
                # base_val == 0 would divide by zero below
                if math.isnan(base_val) or math.isnan(exp_val) or base_val == 0:
                    continue

                higher_better = self.HIGHER_BETTER.get(m, True)
                if higher_better:
                    pct = (exp_val - base_val) / abs(base_val) * 100
                else:
                    pct = (base_val - exp_val) / abs(base_val) * 100
                improvements[m] = round(pct, 2)

            result[exp] = improvements

        return result

    def best_experiment(
        self,
        metric: str,
        procedure: str | None = None,
    ) -> str | None:
        """Find the experiment with the best mean for a metric.

        Returns None when no experiment reports the metric. On ties the
        experiment added first wins (strict comparison).
        """
        higher_better = self.HIGHER_BETTER.get(metric, True)
        best_exp = None
        best_val = float("-inf") if higher_better else float("inf")

        for exp in self.experiments:
            val = self.mean(exp, metric, procedure)
            if math.isnan(val):
                continue
            if (higher_better and val > best_val) or (not higher_better and val < best_val):
                best_val = val
                best_exp = exp

        return best_exp

    def summary_table(
        self,
        metrics: list[str] | None = None,
        procedure: str | None = None,
        include_std: bool = False,
    ) -> str:
        """Generate a text summary table.

        Args:
            metrics: Metrics to include. None = all.
            procedure: Filter by procedure. None = aggregate.
            include_std: Show mean +/- std.

        Returns:
            Formatted text table with one row per experiment; every cell is
            16 characters wide and missing values are shown as "--".
        """
        metrics = metrics or self.metric_names

        # Header
        cols = ["Experiment", *metrics]
        header = " | ".join(f"{c:>16s}" for c in cols)
        lines = [header, "-" * len(header)]

        for exp in self.experiments:
            parts = [f"{exp:>16s}"]
            for m in metrics:
                val = self.mean(exp, m, procedure)
                if math.isnan(val):
                    parts.append(f"{'--':>16s}")
                elif include_std:
                    s = self.std(exp, m, procedure)
                    # 8 (mean) + 1 (sign) + 7 (std) = 16, matching the other cells
                    parts.append(f"{val:>8.4f}±{s:<7.4f}")
                else:
                    parts.append(f"{val:>16.4f}")
            lines.append(" | ".join(parts))

        return "\n".join(lines)

    def to_json(self, path: str | Path | None = None) -> str:
        """Export all records as JSON.

        Args:
            path: Optional file path to write to; missing parent directories
                are created.

        Returns:
            JSON string.
        """
        data = {
            "experiments": self.experiments,
            "procedures": self.procedures,
            "metrics": self.metric_names,
            "records": [
                {
                    "experiment": r.experiment,
                    "procedure": r.procedure,
                    "metrics": r.metrics,
                    "checkpoint_step": r.checkpoint_step,
                    "metadata": r.metadata,
                }
                for r in self.records
            ],
        }
        j = json.dumps(data, indent=2)

        if path is not None:
            target = Path(path)
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_text(j, encoding="utf-8")

        return j

    @classmethod
    def from_json(cls, path: str | Path) -> MetricsAggregator:
        """Load aggregator from a JSON file produced by ``to_json``.

        Only the "records" list is read. Each record's metadata is restored
        as well, so a ``to_json`` / ``from_json`` round trip is lossless.
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)

        agg = cls()
        for rec in data.get("records", []):
            # Build the record directly instead of via add(**metadata) so that
            # metadata keys can never collide with add()'s parameter names.
            agg.records.append(
                MetricRecord(
                    experiment=rec["experiment"],
                    procedure=rec["procedure"],
                    metrics=rec["metrics"],
                    checkpoint_step=rec.get("checkpoint_step"),
                    metadata=dict(rec.get("metadata") or {}),
                )
            )
        return agg
|