landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,745 @@
1
+ """Data-driven surgical displacement extraction and modeling.
2
+
3
+ Extracts real landmark displacements from before/after surgery image pairs,
4
+ classifies procedures based on regional displacement patterns, and fits
5
+ per-procedure statistical models that can replace the hand-tuned RBF
6
+ displacement vectors in ``manipulation.py``.
7
+
8
+ Typical usage::
9
+
10
+ from landmarkdiff.displacement_model import (
11
+ extract_displacements,
12
+ extract_from_directory,
13
+ DisplacementModel,
14
+ )
15
+
16
+ # Single pair
17
+ result = extract_displacements(before_img, after_img)
18
+
19
+ # Batch from directory
20
+ all_displacements = extract_from_directory("data/surgery_pairs/")
21
+
22
+ # Fit model
23
+ model = DisplacementModel()
24
+ model.fit(all_displacements)
25
+ model.save("displacement_model.npz")
26
+
27
+ # Generate displacement field
28
+ field = model.get_displacement_field("rhinoplasty", intensity=0.7)
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import json
34
+ import logging
35
+ from pathlib import Path
36
+
37
+ import cv2
38
+ import numpy as np
39
+
40
+ from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks
41
+ from landmarkdiff.manipulation import PROCEDURE_LANDMARKS
42
+
43
# Module logger; records are namespaced under this module's import path.
logger = logging.getLogger(__name__)

# MediaPipe Face Mesh landmark count with iris refinement:
# 468 face-mesh points plus 10 iris points.
NUM_LANDMARKS = 468 + 10

# Every procedure name defined in ``manipulation.PROCEDURE_LANDMARKS``,
# in that mapping's key order.
PROCEDURES = [*PROCEDURE_LANDMARKS]
50
+
51
+
52
+ # ---------------------------------------------------------------------------
53
+ # Helpers
54
+ # ---------------------------------------------------------------------------
55
+
56
+
57
+ def _normalized_coords_2d(face: FaceLandmarks) -> np.ndarray:
58
+ """Extract (478, 2) normalized [0, 1] coordinates from a FaceLandmarks object.
59
+
60
+ ``FaceLandmarks.landmarks`` is (478, 3) with (x, y, z) in normalized space.
61
+ We take only the x, y columns.
62
+ """
63
+ return face.landmarks[:, :2].copy()
64
+
65
+
66
+ def _compute_alignment_quality(
67
+ landmarks_before: np.ndarray,
68
+ landmarks_after: np.ndarray,
69
+ ) -> float:
70
+ """Estimate alignment quality between two landmark sets.
71
+
72
+ Uses a Procrustes-style analysis on landmarks that should *not* move during
73
+ surgery (forehead, temples, ears) to measure how well the faces are aligned.
74
+ A score of 1.0 means perfect alignment; lower values indicate pose/scale
75
+ mismatches that contaminate the displacement signal.
76
+
77
+ Args:
78
+ landmarks_before: (478, 2) normalized coordinates.
79
+ landmarks_after: (478, 2) normalized coordinates.
80
+
81
+ Returns:
82
+ Quality score in [0, 1].
83
+ """
84
+ # Stable landmarks: forehead, temple region, outer face oval
85
+ # These should exhibit near-zero displacement after surgery.
86
+ stable_indices = [
87
+ 10,
88
+ 109,
89
+ 67,
90
+ 103,
91
+ 54,
92
+ 21,
93
+ 162,
94
+ 127, # left forehead/temple
95
+ 338,
96
+ 297,
97
+ 332,
98
+ 284,
99
+ 251,
100
+ 389,
101
+ 356,
102
+ 454, # right forehead/temple
103
+ 234,
104
+ 93, # outer cheek anchors
105
+ ]
106
+ stable_indices = [i for i in stable_indices if i < NUM_LANDMARKS]
107
+
108
+ before_stable = landmarks_before[stable_indices]
109
+ after_stable = landmarks_after[stable_indices]
110
+
111
+ # RMS displacement on stable points
112
+ diffs = after_stable - before_stable
113
+ rms = np.sqrt(np.mean(np.sum(diffs**2, axis=1)))
114
+
115
+ # Map RMS to quality: 0 displacement -> 1.0, rms >= 0.05 (5% of image) -> 0.0
116
+ quality = float(np.clip(1.0 - rms / 0.05, 0.0, 1.0))
117
+ return quality
118
+
119
+
120
+ # ---------------------------------------------------------------------------
121
+ # Procedure classification
122
+ # ---------------------------------------------------------------------------
123
+
124
+
125
def classify_procedure(displacements: np.ndarray) -> str:
    """Guess the surgical procedure from a per-landmark displacement field.

    For every procedure in ``PROCEDURE_LANDMARKS``, the displacement
    magnitude is averaged over that procedure's landmark indices. Indices
    past the end of the array are ignored, and procedures left with no
    indices are skipped. The procedure with the largest average wins. On a
    tie, the procedure that comes first in the mapping's iteration order is
    kept.

    Args:
        displacements: (478, 2) array of (after - before) vectors in
            normalized coordinate space.

    Returns:
        The winning procedure name (one of ``PROCEDURES``), or ``"unknown"``
        when the best regional mean magnitude is below 0.002 (about one
        pixel at 512x512).
    """
    per_landmark = np.linalg.norm(displacements, axis=1)
    count = len(per_landmark)

    winner, top_score = "unknown", 0.0
    for name, region in PROCEDURE_LANDMARKS.items():
        in_range = [idx for idx in region if idx < count]
        if in_range:
            region_score = float(per_landmark[in_range].mean())
            if region_score > top_score:
                winner, top_score = name, region_score

    if top_score >= 0.002:
        return winner

    logger.debug(
        "No significant displacement detected (best=%.5f). Classified as 'unknown'.",
        top_score,
    )
    return "unknown"
168
+
169
+
170
+ # ---------------------------------------------------------------------------
171
+ # Single-pair extraction
172
+ # ---------------------------------------------------------------------------
173
+
174
+
175
def extract_displacements(
    before_img: np.ndarray,
    after_img: np.ndarray,
    min_detection_confidence: float = 0.5,
) -> dict | None:
    """Measure how every landmark moved between a before and an after photo.

    Landmarks are detected on the before image first and then on the after
    image. If either detection fails, a warning is logged and the function
    stops early. On success, the per-landmark displacement is computed and
    the procedure is classified from it. Alignment quality is scored on the
    stable landmarks.

    Args:
        before_img: Pre-surgery BGR image as a numpy array.
        after_img: Post-surgery BGR image as a numpy array.
        min_detection_confidence: Detection confidence threshold forwarded
            to ``extract_landmarks`` (default 0.5).

    Returns:
        ``None`` if a face could not be detected in either image. Otherwise
        a dict with:
          - ``landmarks_before`` / ``landmarks_after``: (478, 2) normalized
            coordinates
          - ``displacements``: (478, 2) after-minus-before vectors
          - ``magnitude``: (478,) Euclidean length of each displacement
          - ``procedure``: name from ``classify_procedure`` or ``"unknown"``
          - ``quality_score``: alignment quality in [0, 1]
    """
    stages = (
        (before_img, "Face detection failed on before image."),
        (after_img, "Face detection failed on after image."),
    )
    faces = []
    for image, failure_msg in stages:
        face = extract_landmarks(image, min_detection_confidence=min_detection_confidence)
        if face is None:
            logger.warning(failure_msg)
            return None
        faces.append(face)

    coords_before, coords_after = (_normalized_coords_2d(f) for f in faces)
    delta = coords_after - coords_before

    return {
        "landmarks_before": coords_before,
        "landmarks_after": coords_after,
        "displacements": delta,
        "magnitude": np.linalg.norm(delta, axis=1),
        "procedure": classify_procedure(delta),
        "quality_score": _compute_alignment_quality(coords_before, coords_after),
    }
236
+
237
+
238
+ # ---------------------------------------------------------------------------
239
+ # Batch extraction from directory
240
+ # ---------------------------------------------------------------------------
241
+
242
+
243
def extract_from_directory(
    pairs_dir: str | Path,
    min_detection_confidence: float = 0.5,
    min_quality: float = 0.0,
) -> list[dict]:
    """Batch-extract displacements from a directory of before/after image pairs.

    Supports two naming conventions (stems matched case-insensitively):
      - ``<name>_before.{png,jpg,...}`` / ``<name>_after.{png,jpg,...}``
      - ``<name>_input.{png,jpg,...}`` / ``<name>_target.{png,jpg,...}``

    Pair discovery is deterministic:
      - If both conventions exist for the same ``<name>``, the
        ``_before``/``_after`` pair is used.
      - If several images share a stem (e.g. ``a_before.png`` and
        ``a_before.jpg``), the first in sorted path order is used and the
        others are logged and ignored.

    Args:
        pairs_dir: Path to directory containing image pairs.
        min_detection_confidence: Passed to ``extract_displacements``.
        min_quality: Minimum alignment quality score to include a pair
            in the results (default 0.0 = include all).

    Returns:
        List of displacement dictionaries (same format as
        ``extract_displacements``) in ascending ``pair_name`` order, each
        augmented with:
          - ``pair_name``: lower-cased shared stem (e.g. ``"patient_001"``)
          - ``before_path``: path to the before image
          - ``after_path``: path to the after image
        Pairs whose images fail to load, whose face detection fails, or
        whose quality is below ``min_quality`` are skipped with a log
        message.

    Raises:
        FileNotFoundError: If ``pairs_dir`` is not an existing directory.
    """
    pairs_dir = Path(pairs_dir)
    if not pairs_dir.is_dir():
        raise FileNotFoundError(f"Directory not found: {pairs_dir}")

    pairs = _find_image_pairs(pairs_dir)
    if not pairs:
        logger.warning("No image pairs found in %s", pairs_dir)
        return []

    logger.info("Found %d image pairs in %s", len(pairs), pairs_dir)

    results: list[dict] = []
    for pair_name, before_path, after_path in pairs:
        logger.info("Processing pair: %s", pair_name)

        # cv2.imread returns None (rather than raising) on unreadable files.
        before_img = cv2.imread(str(before_path))
        if before_img is None:
            logger.warning("Failed to load before image: %s", before_path)
            continue

        after_img = cv2.imread(str(after_path))
        if after_img is None:
            logger.warning("Failed to load after image: %s", after_path)
            continue

        result = extract_displacements(
            before_img, after_img, min_detection_confidence=min_detection_confidence
        )
        if result is None:
            logger.warning("Skipping pair %s: face detection failed.", pair_name)
            continue

        if result["quality_score"] < min_quality:
            logger.info(
                "Skipping pair %s: quality %.3f < threshold %.3f",
                pair_name,
                result["quality_score"],
                min_quality,
            )
            continue

        result["pair_name"] = pair_name
        result["before_path"] = str(before_path)
        result["after_path"] = str(after_path)
        results.append(result)

    logger.info(
        "Successfully extracted %d / %d pairs (%.0f%%)",
        len(results),
        len(pairs),
        100.0 * len(results) / max(len(pairs), 1),
    )
    return results


def _find_image_pairs(pairs_dir: Path) -> list[tuple[str, Path, Path]]:
    """Return ``(pair_name, before_path, after_path)`` tuples sorted by name."""
    image_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".webp"}

    # Index images by lower-cased stem. Iterating in sorted order makes the
    # choice among same-stem files independent of filesystem listing order.
    files_by_stem: dict[str, Path] = {}
    for f in sorted(pairs_dir.iterdir()):
        if not f.is_file() or f.suffix.lower() not in image_extensions:
            continue
        stem = f.stem.lower()
        if stem in files_by_stem:
            logger.warning(
                "Ignoring %s: another image with the same stem (%s) was already found.",
                f,
                files_by_stem[stem],
            )
            continue
        files_by_stem[stem] = f

    # Conventions are tried in priority order; the first match for a name wins.
    pairs: dict[str, tuple[Path, Path]] = {}
    for before_suffix, after_suffix in (("_before", "_after"), ("_input", "_target")):
        for stem in sorted(files_by_stem):
            if not stem.endswith(before_suffix):
                continue
            base = stem[: -len(before_suffix)]
            after_stem = base + after_suffix
            if after_stem in files_by_stem and base not in pairs:
                pairs[base] = (files_by_stem[stem], files_by_stem[after_stem])

    return [(name, before, after) for name, (before, after) in sorted(pairs.items())]
348
+
349
+
350
+ # ---------------------------------------------------------------------------
351
+ # Displacement model
352
+ # ---------------------------------------------------------------------------
353
+
354
+
355
class DisplacementModel:
    """Statistical model of per-procedure surgical displacements.

    Aggregates displacement vectors from multiple before/after pairs and
    computes per-procedure, per-landmark statistics (mean, std, min, max,
    median and mean magnitude). Can then generate displacement fields for
    use in the conditioning pipeline, replacing hand-tuned RBF vectors.

    Attributes:
        procedures: List of procedure names the model has data for.
        stats: Nested dict ``{procedure: {stat_name: array}}``.
        n_samples: Dict ``{procedure: int}`` sample counts.
    """

    def __init__(self) -> None:
        # Populated by fit() or load(); empty until then.
        self.stats: dict[str, dict[str, np.ndarray]] = {}
        self.n_samples: dict[str, int] = {}
        self._fitted = False

    @property
    def procedures(self) -> list[str]:
        """Return list of procedures the model has been fitted on."""
        return list(self.stats.keys())

    @property
    def fitted(self) -> bool:
        """Whether the model has been fitted (via ``fit`` or ``load``)."""
        return self._fitted

    def fit(self, displacement_list: list[dict]) -> None:
        """Fit the model from a list of extracted displacement dictionaries.

        Groups displacements by classified procedure and computes per-landmark
        statistics for each group. Entries without ``"displacements"`` or whose
        displacements are not shaped (478, 2) are skipped with a warning.
        Entries without ``"procedure"`` are grouped under ``"unknown"``.
        Previous statistics are replaced only once at least one valid entry
        has been found.

        Args:
            displacement_list: List of dicts as returned by
                ``extract_displacements`` or ``extract_from_directory``.
                Each must contain ``"displacements"`` ((478, 2) array-like)
                and ``"procedure"`` (str) keys.

        Raises:
            ValueError: If ``displacement_list`` is empty or contains no
                valid displacement data.
        """
        if not displacement_list:
            raise ValueError("displacement_list is empty.")

        # Group displacement arrays by their procedure label.
        procedure_groups: dict[str, list[np.ndarray]] = {}
        for entry in displacement_list:
            proc = entry.get("procedure", "unknown")
            disp = entry.get("displacements")
            if disp is None:
                logger.warning("Skipping entry without 'displacements' key.")
                continue
            # Accept array-likes (e.g. nested lists) as well as ndarrays.
            disp = np.asarray(disp)
            if disp.shape != (NUM_LANDMARKS, 2):
                logger.warning(
                    "Skipping entry with unexpected shape %s (expected (%d, 2)).",
                    disp.shape,
                    NUM_LANDMARKS,
                )
                continue
            procedure_groups.setdefault(proc, []).append(disp)

        if not procedure_groups:
            raise ValueError("No valid displacement data found in displacement_list.")

        # Rebuild statistics from scratch so a re-fit never mixes in old results.
        self.stats = {}
        self.n_samples = {}

        for proc, disp_arrays in procedure_groups.items():
            stacked = np.stack(disp_arrays, axis=0)  # (N, 478, 2)
            n = stacked.shape[0]

            self.stats[proc] = {
                "mean": np.mean(stacked, axis=0),  # (478, 2)
                "std": np.std(stacked, axis=0),  # (478, 2)
                "min": np.min(stacked, axis=0),  # (478, 2)
                "max": np.max(stacked, axis=0),  # (478, 2)
                "median": np.median(stacked, axis=0),  # (478, 2)
                "mean_magnitude": np.mean(  # (478,)
                    np.linalg.norm(stacked, axis=2), axis=0
                ),
            }
            self.n_samples[proc] = n
            logger.info(
                "Fitted procedure '%s': %d samples, mean magnitude=%.5f",
                proc,
                n,
                float(np.mean(self.stats[proc]["mean_magnitude"])),
            )

        self._fitted = True

    def get_displacement_field(
        self,
        procedure: str,
        intensity: float = 1.0,
        noise_scale: float = 0.0,
        rng: np.random.Generator | None = None,
    ) -> np.ndarray:
        """Generate a displacement field for a given procedure and intensity.

        Returns the mean displacement scaled by ``intensity``. Optionally,
        Gaussian noise is added, scaled by the per-landmark std.

        Args:
            procedure: Procedure name (must exist in the fitted model).
            intensity: Scaling factor for the mean displacement. 1.0 = average
                observed displacement; 0.5 = half intensity; etc.
            noise_scale: If > 0, adds Gaussian noise with this many standard
                deviations of variation. 0.0 = deterministic mean field.
            rng: NumPy random generator for reproducible noise. If ``None``
                and ``noise_scale > 0``, uses ``np.random.default_rng()``.

        Returns:
            (478, 2) float32 displacement field in normalized coordinate space.

        Raises:
            RuntimeError: If the model has not been fitted.
            KeyError: If the procedure is not in the model.
        """
        if not self._fitted:
            raise RuntimeError("Model has not been fitted. Call fit() first.")

        if procedure not in self.stats:
            available = ", ".join(self.procedures)
            raise KeyError(f"Procedure '{procedure}' not in model. Available: {available}")

        proc_stats = self.stats[procedure]
        # Multiplication produces a new array, so the stored mean is untouched.
        field = proc_stats["mean"] * intensity

        if noise_scale > 0:
            if rng is None:
                rng = np.random.default_rng()
            noise = rng.normal(
                loc=0.0,
                scale=proc_stats["std"] * noise_scale,
            )
            field += noise

        return field.astype(np.float32)

    def get_summary(self, procedure: str | None = None) -> dict:
        """Get a human-readable summary of the model statistics.

        Args:
            procedure: If provided, return summary for that one procedure
                (an unknown name yields an empty ``"procedures"`` dict).
                If ``None``, return summaries for all procedures.

        Returns:
            ``{"fitted": False}`` for an unfitted model. Otherwise,
            ``{"fitted": True, "procedures": {name: {...}}}`` with sample
            count, global mean/max of the mean magnitude, and the top-10
            landmarks by mean magnitude.
        """
        if not self._fitted:
            return {"fitted": False}

        procs = [procedure] if procedure is not None else self.procedures
        summary = {"fitted": True, "procedures": {}}

        for proc in procs:
            if proc not in self.stats:
                continue
            s = self.stats[proc]
            summary["procedures"][proc] = {
                "n_samples": self.n_samples[proc],
                "global_mean_magnitude": float(np.mean(s["mean_magnitude"])),
                "global_max_magnitude": float(np.max(s["mean_magnitude"])),
                "top_landmarks": _top_k_landmarks(s["mean_magnitude"], k=10),
            }

        return summary

    def save(self, path: str | Path) -> None:
        """Save the fitted model to disk as a ``.npz`` file.

        The file contains:
          - Per-procedure stat arrays keyed as ``{procedure}__{stat_name}``
          - A ``__metadata__`` uint8 array holding UTF-8 JSON with the
            procedure list, sample counts and landmark count

        Args:
            path: Output file path. ``.npz`` is appended when the name does
                not already end in ``.npz`` (e.g. ``model.v2`` ->
                ``model.v2.npz``); an existing suffix is never replaced.

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if not self._fitted:
            raise RuntimeError("Model has not been fitted. Call fit() first.")

        path = Path(path)
        if path.suffix != ".npz":
            # Append rather than replace: with_suffix() would turn "model.v2"
            # into "model.npz" and could overwrite an unrelated file.
            path = path.with_name(path.name + ".npz")

        arrays: dict[str, np.ndarray] = {}
        for proc, proc_stats in self.stats.items():
            for stat_name, arr in proc_stats.items():
                arrays[f"{proc}__{stat_name}"] = arr

        # Store metadata as a JSON string encoded to bytes, so the archive
        # can be read back with allow_pickle=False.
        metadata = {
            "procedures": self.procedures,
            "n_samples": self.n_samples,
            "num_landmarks": NUM_LANDMARKS,
        }
        arrays["__metadata__"] = np.frombuffer(json.dumps(metadata).encode("utf-8"), dtype=np.uint8)

        np.savez_compressed(str(path), **arrays)
        logger.info("Saved displacement model to %s", path)

    @classmethod
    def load(cls, path: str | Path) -> DisplacementModel:
        """Load a fitted model from a ``.npz`` file.

        Supports two formats:
          1. ``save()`` format: keys like ``{proc}__{stat}`` with ``__metadata__``
          2. ``extract_displacements.py`` format: keys like ``{proc}_{stat}``
             with a ``procedures`` array

        Args:
            path: Path to the ``.npz`` file.

        Returns:
            A fitted ``DisplacementModel`` instance.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file is not an ``.npz`` archive, its format is
                unrecognized, or it contains no usable procedure statistics.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Model file not found: {path}")

        data = np.load(str(path), allow_pickle=False)
        if not hasattr(data, "files"):
            # np.load returns a bare ndarray (not an archive) for .npy files.
            raise ValueError(f"Expected an .npz archive at {path}, found a single array.")

        model = cls()
        # The NpzFile keeps the archive open until closed; the context manager
        # releases the handle whether parsing succeeds or fails.
        with data:
            file_keys = list(data.files)

            # Format 1: save() format with __metadata__
            if "__metadata__" in file_keys:
                meta_bytes = data["__metadata__"].tobytes()
                metadata = json.loads(meta_bytes.decode("utf-8"))
                model.n_samples = {k: int(v) for k, v in metadata["n_samples"].items()}

                for proc in metadata["procedures"]:
                    prefix = f"{proc}__"
                    model.stats[proc] = {
                        key[len(prefix) :]: data[key]
                        for key in file_keys
                        if key.startswith(prefix)
                    }

            # Format 2: extract_displacements.py format with procedures array
            elif "procedures" in file_keys:
                procedures = [str(p) for p in data["procedures"]]
                # Map from extraction script key names to DisplacementModel stat names
                stat_map = {
                    "mean": "mean",
                    "std": "std",
                    "median": "median",
                    "min": "min",
                    "max": "max",
                    "mag_mean": "mean_magnitude",
                    "mag_std": "std_magnitude",
                    "count": "_count",
                }
                for proc in procedures:
                    model.stats[proc] = {}
                    for ext_key, model_key in stat_map.items():
                        npz_key = f"{proc}_{ext_key}"
                        if npz_key not in file_keys:
                            continue
                        arr = data[npz_key]
                        if model_key == "_count":
                            # .item() accepts 0-d and 1-element arrays alike;
                            # int() on a 1-element 1-d array is deprecated.
                            model.n_samples[proc] = int(arr.item())
                        else:
                            model.stats[proc][model_key] = arr

                    # Ensure count is set
                    model.n_samples.setdefault(proc, 0)

            else:
                raise ValueError(f"Unrecognized displacement model format. Keys: {file_keys[:10]}")

        # Validate loaded model is not empty
        if not model.stats:
            raise ValueError(
                f"Displacement model at {path} contains no procedure data. "
                f"File may be corrupted or empty. Keys found: {file_keys[:10]}"
            )
        for proc, stats in model.stats.items():
            if not stats:
                raise ValueError(
                    f"Displacement model at {path} has no statistics for "
                    f"procedure '{proc}'. File may be corrupted."
                )

        model._fitted = True
        logger.info(
            "Loaded displacement model from %s (%d procedures, %s samples)",
            path,
            len(model.procedures),
            model.n_samples,
        )
        return model
660
+
661
+
662
+ # ---------------------------------------------------------------------------
663
+ # Utilities
664
+ # ---------------------------------------------------------------------------
665
+
666
+
667
+ def _top_k_landmarks(
668
+ magnitudes: np.ndarray,
669
+ k: int = 10,
670
+ ) -> list[dict]:
671
+ """Return the top-k landmarks by mean displacement magnitude.
672
+
673
+ Args:
674
+ magnitudes: (478,) array of per-landmark magnitudes.
675
+ k: Number of top landmarks to return.
676
+
677
+ Returns:
678
+ List of dicts with ``index`` and ``magnitude`` keys, sorted
679
+ descending by magnitude.
680
+ """
681
+ top_indices = np.argsort(magnitudes)[::-1][:k]
682
+ return [{"index": int(idx), "magnitude": float(magnitudes[idx])} for idx in top_indices]
683
+
684
+
685
def visualize_displacements(
    before_img: np.ndarray,
    result: dict,
    scale: float = 10.0,
    arrow_color: tuple[int, int, int] = (0, 255, 0),
    thickness: int = 1,
) -> np.ndarray:
    """Draw displacement arrows on the before image for visual inspection.

    For each landmark, an arrow is drawn from its "before" pixel position
    along its displacement vector, magnified by ``scale``. Arrows that round
    down to zero length in pixels are skipped. The procedure label and
    quality score from ``result`` are written near the top-left corner.

    Args:
        before_img: BGR image (copied; the input is not modified).
        result: Displacement dict from ``extract_displacements``. Must
            contain ``landmarks_before`` and ``displacements`` arrays of
            shape (N, 2) in normalized coordinates. ``procedure`` and
            ``quality_score`` are optional.
        scale: Arrow length multiplier (displacements are small in
            normalized space, so scale up for visibility).
        arrow_color: BGR color for arrows.
        thickness: Arrow line thickness.

    Returns:
        Annotated BGR image.
    """
    canvas = before_img.copy()
    h, w = canvas.shape[:2]

    coords_before = result["landmarks_before"]
    displacements = result["displacements"]
    # Draw the landmarks actually present rather than assuming the full
    # NUM_LANDMARKS mesh, so shorter landmark sets don't raise IndexError.
    n_points = min(len(coords_before), len(displacements))

    for i in range(n_points):
        # Normalized [0, 1] coordinates -> integer pixel coordinates (truncated).
        bx = int(coords_before[i, 0] * w)
        by = int(coords_before[i, 1] * h)
        dx = int(displacements[i, 0] * w * scale)
        dy = int(displacements[i, 1] * h * scale)

        # Integer offsets: the arrow is shorter than one pixel only when both
        # components truncate to zero (noise floor).
        if dx == 0 and dy == 0:
            continue

        cv2.arrowedLine(
            canvas,
            (bx, by),
            (bx + dx, by + dy),
            arrow_color,
            thickness,
            tipLength=0.3,
        )

    # Add procedure label and quality score
    proc = result.get("procedure", "unknown")
    quality = result.get("quality_score", 0.0)
    label = f"{proc} (quality={quality:.2f})"
    cv2.putText(
        canvas,
        label,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (255, 255, 255),
        2,
    )

    return canvas