landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,316 @@
1
+ """Surgical mask generation with morphological dilation and Gaussian feathering.
2
+
3
+ Procedural masks (not SAM2) -- no model dependency; deterministic apart from the random boundary noise.
4
+ Feathered boundaries prevent visible seams in ControlNet inpainting.
5
+ Supports clinical edge cases (vitiligo preservation, keloid softening).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TYPE_CHECKING
11
+
12
+ import cv2
13
+ import numpy as np
14
+
15
+ from landmarkdiff.landmarks import FaceLandmarks
16
+
17
+ if TYPE_CHECKING:
18
+ from landmarkdiff.clinical import ClinicalFlags
19
+
20
# Boundary noise parameters for seam prevention.
# These are consumed by generate_surgical_mask() below.
_BOUNDARY_KERNEL_SIZE = 5  # px; side of the square kernel used to dilate/erode when extracting the boundary band
_BOUNDARY_NOISE_MAX = 4  # exclusive upper bound of the random integer noise level (levels 0 .. MAX-1)
_BOUNDARY_NOISE_SCALE = 64  # multiplier that turns a noise level into a uint8 mask intensity
_GAUSSIAN_KERNEL_FACTOR = 6  # Gaussian kernel size = factor * sigma, then forced to an odd number
25
+
26
# Per-procedure mask settings.
#
# Each entry provides:
#   landmark_indices - rows of the face landmark array whose convex hull seeds the mask
#   dilation_px      - radius (px) of the elliptical dilation applied to the hull
#   feather_sigma    - sigma of the Gaussian blur used to feather the mask edge
MASK_CONFIG: dict[str, dict] = {
    "rhinoplasty": {
        "landmark_indices": [
            1, 2, 4, 5, 6, 19, 94, 141, 168, 195,
            197, 236, 240, 274, 275, 278, 279, 294, 326, 327,
            360, 363, 370, 456, 460,
        ],
        "dilation_px": 30,
        "feather_sigma": 15.0,
    },
    "blepharoplasty": {
        "landmark_indices": [
            33, 7, 163, 144, 145, 153, 154, 155, 157, 158,
            159, 160, 161, 246, 362, 382, 381, 380, 374, 373,
            390, 249, 263, 466, 388, 387, 386, 385, 384, 398,
        ],
        "dilation_px": 15,
        "feather_sigma": 10.0,
    },
    "rhytidectomy": {
        "landmark_indices": [
            10, 21, 54, 58, 67, 93, 103, 109, 127, 132,
            136, 150, 162, 172, 176, 187, 207, 213, 234, 284,
            297, 323, 332, 338, 356, 361, 365, 379, 389, 397,
            400, 427, 454,
        ],
        "dilation_px": 40,
        "feather_sigma": 20.0,
    },
    "orthognathic": {
        "landmark_indices": [
            0, 17, 18, 36, 37, 39, 40, 57, 61, 78,
            80, 81, 82, 84, 87, 88, 91, 95, 146, 167,
            169, 170, 175, 181, 191, 200, 201, 202, 204, 208,
            211, 212, 214,
        ],
        "dilation_px": 35,
        "feather_sigma": 18.0,
    },
    "brow_lift": {
        "landmark_indices": [
            70, 63, 105, 66, 107,  # left brow
            300, 293, 334, 296, 336,  # right brow
            9, 8, 10,  # forehead midline
            109, 67, 103,  # upper face left
            338, 297, 332,  # upper face right
        ],
        "dilation_px": 25,
        "feather_sigma": 15.0,
    },
    "mentoplasty": {
        "landmark_indices": [148, 149, 150, 152, 171, 175, 176, 377],
        "dilation_px": 25,
        "feather_sigma": 12.0,
    },
}
213
+
214
+
215
def generate_surgical_mask(
    face: FaceLandmarks,
    procedure: str,
    width: int | None = None,
    height: int | None = None,
    clinical_flags: ClinicalFlags | None = None,
    image: np.ndarray | None = None,
    *,
    rng: np.random.Generator | int | None = None,
) -> np.ndarray:
    """Generate a feathered surgical mask for a procedure.

    Pipeline (in execution order):
        1. Fill the convex hull of the procedure-specific landmarks.
        2. Dilate the hull with an elliptical kernel of radius ``dilation_px``.
        3. Add random integer noise on a thin band around the mask edge so the
           blurred boundary is not perfectly regular (seam prevention).
        4. Gaussian-feather the result into a smooth alpha gradient.
        5. Optionally apply clinical adjustments (vitiligo, keloid).

    Args:
        face: Extracted facial landmarks. ``face.landmarks[:, :2]`` is scaled
            by the mask width/height, so it is assumed to hold normalized
            (x, y) coordinates -- TODO confirm against FaceLandmarks.
        procedure: Procedure name; must be a key of ``MASK_CONFIG``
            (e.g. "rhinoplasty").
        width: Mask width in pixels (falls back to ``face.image_width`` when
            omitted or 0).
        height: Mask height in pixels (falls back to ``face.image_height``
            when omitted or 0).
        clinical_flags: Optional clinical edge-case flags (vitiligo, keloid).
        image: Original image passed to the vitiligo patch detector. The
            vitiligo adjustment is skipped when this is None, even if
            ``clinical_flags.vitiligo`` is set.
        rng: Seed or ``numpy.random.Generator`` for the boundary noise. Pass a
            fixed seed for reproducible masks; the default (None) draws fresh
            OS entropy on every call, as before.

    Returns:
        Mask with values in [0.0, 1.0]; a float32 array of shape
        (height, width) before any clinical adjustment is applied.

    Raises:
        ValueError: If ``procedure`` is not a key of ``MASK_CONFIG``.
    """
    if procedure not in MASK_CONFIG:
        raise ValueError(f"Unknown procedure: {procedure}. Choose from {list(MASK_CONFIG)}")

    config = MASK_CONFIG[procedure]
    w = width or face.image_width
    h = height or face.image_height

    # Get pixel coordinates of procedure landmarks
    coords = face.landmarks[:, :2].copy()
    coords[:, 0] *= w
    coords[:, 1] *= h
    pts = coords[config["landmark_indices"]].astype(np.int32)

    # Create binary mask from convex hull
    binary = np.zeros((h, w), dtype=np.uint8)
    hull = cv2.convexHull(pts)
    cv2.fillConvexPoly(binary, hull, 255)

    # Morphological dilation
    dilation = config["dilation_px"]
    kernel = cv2.getStructuringElement(
        cv2.MORPH_ELLIPSE,
        (2 * dilation + 1, 2 * dilation + 1),
    )
    dilated = cv2.dilate(binary, kernel)

    # Boundary band = (mask grown by the small kernel) - (mask shrunk by it).
    boundary_kernel = np.ones((_BOUNDARY_KERNEL_SIZE, _BOUNDARY_KERNEL_SIZE), np.uint8)
    boundary = cv2.subtract(
        cv2.dilate(dilated, boundary_kernel),
        cv2.erode(dilated, boundary_kernel),
    )

    # Random intensity on the boundary band only. default_rng() returns a
    # Generator unchanged and seeds a new one from an int or None.
    noise_rng = np.random.default_rng(rng)
    noise = noise_rng.integers(0, _BOUNDARY_NOISE_MAX, size=(h, w), dtype=np.uint8)
    noise_boundary = cv2.bitwise_and(boundary, noise * _BOUNDARY_NOISE_SCALE)
    # cv2.add saturates uint8 at 255, so no extra clipping is required.
    dilated = cv2.add(dilated, noise_boundary)

    # Gaussian feathering
    sigma = config["feather_sigma"]
    ksize = int(_GAUSSIAN_KERNEL_FACTOR * sigma) | 1  # ensure odd
    feathered = cv2.GaussianBlur(
        dilated.astype(np.float32) / 255.0,
        (ksize, ksize),
        sigma,
    )

    mask = np.clip(feathered, 0.0, 1.0)

    # Clinical edge case adjustments
    if clinical_flags is not None:
        # Vitiligo: reduce mask over depigmented patches to preserve them
        if clinical_flags.vitiligo and image is not None:
            # Imported lazily, as in the module header where ClinicalFlags is
            # only imported under TYPE_CHECKING.
            from landmarkdiff.clinical import adjust_mask_for_vitiligo, detect_vitiligo_patches

            patches = detect_vitiligo_patches(image, face)
            mask = adjust_mask_for_vitiligo(mask, patches)

        # Keloid: soften transitions in keloid-prone regions
        if clinical_flags.keloid_prone and clinical_flags.keloid_regions:
            from landmarkdiff.clinical import adjust_mask_for_keloid, get_keloid_exclusion_mask

            keloid_mask = get_keloid_exclusion_mask(
                face,
                clinical_flags.keloid_regions,
                w,
                h,
            )
            mask = adjust_mask_for_keloid(mask, keloid_mask)

    return mask
312
+
313
+
314
def mask_to_3channel(mask: np.ndarray) -> np.ndarray:
    """Replicate a single-channel mask along a new trailing axis of size 3.

    The result has shape ``mask.shape + (3,)`` so it can be multiplied
    element-wise with a 3-channel image during compositing.
    """
    channels = (mask,) * 3
    return np.stack(channels, axis=-1)
@@ -0,0 +1,313 @@
1
+ """Metrics aggregation across checkpoints, experiments, and procedures.
2
+
3
+ Collects evaluation results from multiple sources and computes aggregate
4
+ statistics, confidence intervals, and significance tests for paper reporting.
5
+
6
+ Usage:
7
+ from landmarkdiff.metrics_agg import MetricsAggregator
8
+
9
+ agg = MetricsAggregator()
10
+ agg.add("baseline", "rhinoplasty", {"ssim": 0.82, "lpips": 0.18})
11
+ agg.add("ours", "rhinoplasty", {"ssim": 0.91, "lpips": 0.09})
12
+ print(agg.summary_table())
13
+ print(agg.improvement_over("baseline"))
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import math
20
+ from dataclasses import dataclass, field
21
+ from pathlib import Path
22
+ from typing import Any
23
+
24
+
25
@dataclass
class MetricRecord:
    """A single evaluation record: metric values for one experiment/procedure."""

    # Name of the experiment the metrics belong to.
    experiment: str
    # Procedure the metrics were measured on (add_batch uses "all" when absent).
    procedure: str
    # Mapping of metric name -> value.
    metrics: dict[str, float]
    # Optional training step of the evaluated checkpoint.
    checkpoint_step: int | None = None
    # Free-form extra information attached by the caller.
    metadata: dict[str, Any] = field(default_factory=dict)


class MetricsAggregator:
    """Aggregate and analyze evaluation metrics.

    Supports multiple experiments, procedures, and per-sample results
    for computing confidence intervals and significance.
    """

    # Direction of improvement for known metrics. Metrics missing from this
    # table are treated as higher-is-better by the methods below.
    HIGHER_BETTER = {
        "ssim": True,
        "psnr": True,
        "identity_sim": True,
        "lpips": False,
        "fid": False,
        "nme": False,
    }

    def __init__(self) -> None:
        self.records: list[MetricRecord] = []

    def add(
        self,
        experiment: str,
        procedure: str,
        metrics: dict[str, float],
        checkpoint_step: int | None = None,
        **metadata: Any,
    ) -> None:
        """Add a single evaluation record.

        Args:
            experiment: Experiment name.
            procedure: Procedure name.
            metrics: Metric name -> value. The mapping is copied, so mutating
                the caller's dict afterwards does not alter the stored record.
            checkpoint_step: Optional checkpoint step.
            **metadata: Arbitrary extra fields stored on the record.
        """
        self.records.append(
            MetricRecord(
                experiment=experiment,
                procedure=procedure,
                metrics=dict(metrics),
                checkpoint_step=checkpoint_step,
                metadata=metadata,
            )
        )

    def add_batch(
        self,
        experiment: str,
        records: list[dict[str, Any]],
    ) -> None:
        """Add multiple records for an experiment.

        Each record dict should have 'procedure' and metric keys. A missing
        'procedure' defaults to "all"; every other key whose value is an
        int or float is taken as a metric, and non-numeric values are ignored.
        """
        for rec in records:
            proc = rec.get("procedure", "all")
            metrics = {
                k: v for k, v in rec.items() if k != "procedure" and isinstance(v, (int, float))
            }
            self.add(experiment, proc, metrics)

    @property
    def experiments(self) -> list[str]:
        """Unique experiment names in insertion order."""
        # dict preserves insertion order, giving an ordered de-duplication.
        return list(dict.fromkeys(r.experiment for r in self.records))

    @property
    def procedures(self) -> list[str]:
        """Unique procedure names in insertion order."""
        return list(dict.fromkeys(r.procedure for r in self.records))

    @property
    def metric_names(self) -> list[str]:
        """All unique metric names, sorted alphabetically."""
        names: set[str] = set()
        for r in self.records:
            names.update(r.metrics)
        return sorted(names)

    def filter(
        self,
        experiment: str | None = None,
        procedure: str | None = None,
    ) -> list[MetricRecord]:
        """Filter records by experiment and/or procedure.

        With no filter given, the internal record list itself is returned.
        """
        results = self.records
        if experiment is not None:
            results = [r for r in results if r.experiment == experiment]
        if procedure is not None:
            results = [r for r in results if r.procedure == procedure]
        return results

    def _values(
        self,
        experiment: str,
        metric: str,
        procedure: str | None = None,
    ) -> list[float]:
        """Collect the values of ``metric`` from the matching records that report it."""
        recs = self.filter(experiment=experiment, procedure=procedure)
        return [r.metrics[metric] for r in recs if metric in r.metrics]

    def mean(
        self,
        experiment: str,
        metric: str,
        procedure: str | None = None,
    ) -> float:
        """Compute mean of a metric for an experiment (NaN when there are no values)."""
        vals = self._values(experiment, metric, procedure)
        if not vals:
            return float("nan")
        return sum(vals) / len(vals)

    def std(
        self,
        experiment: str,
        metric: str,
        procedure: str | None = None,
    ) -> float:
        """Compute the sample standard deviation (n - 1 denominator) of a metric.

        Returns 0.0 when fewer than two values are available.
        """
        vals = self._values(experiment, metric, procedure)
        if len(vals) < 2:
            return 0.0
        m = sum(vals) / len(vals)
        var = sum((v - m) ** 2 for v in vals) / (len(vals) - 1)
        return math.sqrt(var)

    def ci_95(
        self,
        experiment: str,
        metric: str,
        procedure: str | None = None,
    ) -> tuple[float, float]:
        """Compute 95% confidence interval (mean +/- 1.96*SE).

        Returns (NaN, NaN) with no values and (mean, mean) with a single value.
        """
        vals = self._values(experiment, metric, procedure)
        if not vals:
            return (float("nan"), float("nan"))
        n = len(vals)
        m = sum(vals) / n
        if n < 2:
            return (m, m)
        var = sum((v - m) ** 2 for v in vals) / (n - 1)
        se = math.sqrt(var / n)
        return (m - 1.96 * se, m + 1.96 * se)

    def improvement_over(
        self,
        baseline: str,
        metric: str | None = None,
    ) -> dict[str, dict[str, float]]:
        """Compute relative improvement of all experiments over a baseline.

        The sign is normalized with HIGHER_BETTER so that a positive number
        always means "better than baseline". Metrics whose baseline or
        experiment mean is NaN, or whose baseline mean is 0, are omitted.

        Returns:
            {experiment: {metric: relative_improvement_pct}}, percentages
            rounded to 2 decimals.
        """
        metrics = [metric] if metric else self.metric_names
        result: dict[str, dict[str, float]] = {}

        for exp in self.experiments:
            if exp == baseline:
                continue
            improvements: dict[str, float] = {}
            for m in metrics:
                base_val = self.mean(baseline, m)
                exp_val = self.mean(exp, m)
                if math.isnan(base_val) or math.isnan(exp_val) or base_val == 0:
                    continue

                higher_better = self.HIGHER_BETTER.get(m, True)
                if higher_better:
                    pct = (exp_val - base_val) / abs(base_val) * 100
                else:
                    pct = (base_val - exp_val) / abs(base_val) * 100
                improvements[m] = round(pct, 2)

            result[exp] = improvements

        return result

    def best_experiment(
        self,
        metric: str,
        procedure: str | None = None,
    ) -> str | None:
        """Find the experiment with the best mean for a metric.

        Returns None when no experiment has a value for the metric. On ties
        the first experiment (insertion order) wins.
        """
        higher_better = self.HIGHER_BETTER.get(metric, True)
        best_exp = None
        best_val = float("-inf") if higher_better else float("inf")

        for exp in self.experiments:
            val = self.mean(exp, metric, procedure)
            if math.isnan(val):
                continue
            if (higher_better and val > best_val) or (not higher_better and val < best_val):
                best_val = val
                best_exp = exp

        return best_exp

    def summary_table(
        self,
        metrics: list[str] | None = None,
        procedure: str | None = None,
        include_std: bool = False,
    ) -> str:
        """Generate a text summary table.

        Args:
            metrics: Metrics to include. None = all.
            procedure: Filter by procedure. None = aggregate.
            include_std: Show mean +/- std.

        Returns:
            Formatted text table with 16-character-wide columns.
        """
        metrics = metrics or self.metric_names
        exps = self.experiments

        # Header
        cols = ["Experiment"] + metrics
        header = " | ".join(f"{c:>16s}" for c in cols)
        lines = [header, "-" * len(header)]

        for exp in exps:
            parts = [f"{exp:>16s}"]
            for m in metrics:
                val = self.mean(exp, m, procedure)
                if math.isnan(val):
                    parts.append(f"{'--':>16s}")
                elif include_std:
                    s = self.std(exp, m, procedure)
                    # 8 + 1 + 7 = 16 characters, matching the header cells.
                    parts.append(f"{val:>8.4f}±{s:<7.4f}")
                else:
                    parts.append(f"{val:>16.4f}")
            lines.append(" | ".join(parts))

        return "\n".join(lines)

    def to_json(self, path: str | Path | None = None) -> str:
        """Export all records as JSON.

        Args:
            path: Optional file path to write to; parent directories are
                created as needed.

        Returns:
            JSON string.
        """
        data = {
            "experiments": self.experiments,
            "procedures": self.procedures,
            "metrics": self.metric_names,
            "records": [
                {
                    "experiment": r.experiment,
                    "procedure": r.procedure,
                    "metrics": r.metrics,
                    "checkpoint_step": r.checkpoint_step,
                    "metadata": r.metadata,
                }
                for r in self.records
            ],
        }
        j = json.dumps(data, indent=2)

        if path is not None:
            target = Path(path)
            target.parent.mkdir(parents=True, exist_ok=True)
            target.write_text(j, encoding="utf-8")

        return j

    @classmethod
    def from_json(cls, path: str | Path) -> MetricsAggregator:
        """Load aggregator from a JSON file written by ``to_json``.

        Only the "records" list is read; each record's metadata is restored
        along with its metrics and checkpoint step.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))

        agg = cls()
        for rec in data.get("records", []):
            # Build the record directly: metadata keys come from a file and
            # could collide with add()'s own parameter names if expanded
            # as keyword arguments.
            agg.records.append(
                MetricRecord(
                    experiment=rec["experiment"],
                    procedure=rec["procedure"],
                    metrics=dict(rec["metrics"]),
                    checkpoint_step=rec.get("checkpoint_step"),
                    metadata=dict(rec.get("metadata") or {}),
                )
            )
        return agg