landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,745 @@
|
|
|
1
|
+
"""Data-driven surgical displacement extraction and modeling.
|
|
2
|
+
|
|
3
|
+
Extracts real landmark displacements from before/after surgery image pairs,
|
|
4
|
+
classifies procedures based on regional displacement patterns, and fits
|
|
5
|
+
per-procedure statistical models that can replace the hand-tuned RBF
|
|
6
|
+
displacement vectors in ``manipulation.py``.
|
|
7
|
+
|
|
8
|
+
Typical usage::
|
|
9
|
+
|
|
10
|
+
from landmarkdiff.displacement_model import (
|
|
11
|
+
extract_displacements,
|
|
12
|
+
extract_from_directory,
|
|
13
|
+
DisplacementModel,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
# Single pair
|
|
17
|
+
result = extract_displacements(before_img, after_img)
|
|
18
|
+
|
|
19
|
+
# Batch from directory
|
|
20
|
+
all_displacements = extract_from_directory("data/surgery_pairs/")
|
|
21
|
+
|
|
22
|
+
# Fit model
|
|
23
|
+
model = DisplacementModel()
|
|
24
|
+
model.fit(all_displacements)
|
|
25
|
+
model.save("displacement_model.npz")
|
|
26
|
+
|
|
27
|
+
# Generate displacement field
|
|
28
|
+
field = model.get_displacement_field("rhinoplasty", intensity=0.7)
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
import json
|
|
34
|
+
import logging
|
|
35
|
+
from pathlib import Path
|
|
36
|
+
|
|
37
|
+
import cv2
|
|
38
|
+
import numpy as np
|
|
39
|
+
|
|
40
|
+
from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks
|
|
41
|
+
from landmarkdiff.manipulation import PROCEDURE_LANDMARKS
|
|
42
|
+
|
|
43
|
+
logger = logging.getLogger(__name__)

# MediaPipe Face Mesh landmark count: 468 face-surface points plus 10 iris points.
NUM_LANDMARKS = 468 + 10

# Names of every procedure that has a landmark region in ``PROCEDURE_LANDMARKS``,
# in that mapping's iteration order.
PROCEDURES = [*PROCEDURE_LANDMARKS]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# ---------------------------------------------------------------------------
|
|
53
|
+
# Helpers
|
|
54
|
+
# ---------------------------------------------------------------------------
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _normalized_coords_2d(face: FaceLandmarks) -> np.ndarray:
    """Return a fresh ``(N, 2)`` array with the x/y columns of ``face.landmarks``.

    ``face.landmarks`` holds ``(x, y, z)`` rows in normalized space (``(478, 3)``
    for a full MediaPipe mesh); the depth column is discarded. The returned
    array is an independent copy, so callers may modify it freely.
    """
    xy_view = face.landmarks[:, 0:2]
    return xy_view.copy()
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _compute_alignment_quality(
    landmarks_before: np.ndarray,
    landmarks_after: np.ndarray,
    max_rms: float = 0.05,
) -> float:
    """Estimate how well two landmark sets are aligned.

    Measures the root-mean-square displacement of a fixed set of landmarks
    (forehead, temples, outer cheeks) that surgery should *not* move. No
    rigid/Procrustes alignment is applied first, so any pose, translation or
    scale difference between the two images shows up directly as displacement
    and lowers the score. Such mismatches would otherwise contaminate the
    per-landmark displacement signal.

    The RMS is mapped linearly onto the score: zero displacement gives 1.0 and
    an RMS of ``max_rms`` or more gives 0.0.

    Args:
        landmarks_before: ``(N, 2)`` normalized coordinates (N is normally 478).
        landmarks_after: ``(N, 2)`` normalized coordinates.
        max_rms: RMS displacement, in normalized units, at which the score
            reaches 0. Default 0.05, i.e. 5% of the image size.

    Returns:
        Quality score in [0, 1]. Returns 0.0 when none of the stable landmarks
        are present in the inputs or when the coordinates are non-finite.

    Raises:
        ValueError: If ``max_rms`` is not positive.
    """
    if max_rms <= 0:
        raise ValueError(f"max_rms must be positive, got {max_rms}")

    # Landmarks expected to show near-zero displacement after surgery.
    stable_indices = (
        10, 109, 67, 103, 54, 21, 162, 127,  # left forehead/temple
        338, 297, 332, 284, 251, 389, 356, 454,  # right forehead/temple
        234, 93,  # outer cheek anchors
    )

    # Filter against the arrays actually supplied (not a global landmark
    # count) so shorter landmark sets, e.g. 468 points without iris, work.
    n_points = min(len(landmarks_before), len(landmarks_after))
    valid = [i for i in stable_indices if i < n_points]
    if not valid:
        return 0.0

    diffs = np.asarray(landmarks_after)[valid] - np.asarray(landmarks_before)[valid]
    rms = float(np.sqrt(np.mean(np.sum(diffs**2, axis=1))))
    if not np.isfinite(rms):
        # NaN/inf coordinates: alignment cannot be trusted. Without this guard a
        # NaN score would slip past ``quality_score < min_quality`` filters.
        return 0.0

    return float(np.clip(1.0 - rms / max_rms, 0.0, 1.0))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# ---------------------------------------------------------------------------
|
|
121
|
+
# Procedure classification
|
|
122
|
+
# ---------------------------------------------------------------------------
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def classify_procedure(
    displacements: np.ndarray,
    min_score: float = 0.002,
    procedure_landmarks: dict | None = None,
) -> str:
    """Classify which surgical procedure was performed from displacement vectors.

    Computes the mean displacement magnitude within each procedure's landmark
    region and returns the procedure with the highest regional activity.

    Args:
        displacements: ``(N, 2)`` displacement vectors (after - before) in
            normalized coordinate space (N is normally 478).
        min_score: Minimum winning regional mean magnitude. Below this the
            result is ``"unknown"``. The default 0.002 is about 1 pixel at
            512x512.
        procedure_landmarks: Mapping ``{procedure: landmark indices}`` to
            score. Defaults to ``PROCEDURE_LANDMARKS``.

    Returns:
        Name of the most active procedure (a key of ``procedure_landmarks``),
        or ``"unknown"`` if no region shows significant displacement. On ties
        the procedure appearing first in the mapping wins.
    """
    if procedure_landmarks is None:
        procedure_landmarks = PROCEDURE_LANDMARKS

    magnitudes = np.linalg.norm(displacements, axis=1)
    n_points = len(magnitudes)

    best_procedure = "unknown"
    best_score = 0.0

    for procedure, indices in procedure_landmarks.items():
        # Ignore indices beyond the available landmarks (e.g. iris points
        # when only 468 landmarks are present).
        valid_indices = [i for i in indices if i < n_points]
        if not valid_indices:
            continue

        # Mean (not sum) keeps regions of different sizes comparable.
        score = float(np.mean(magnitudes[valid_indices]))

        # Strict ">" keeps the earliest procedure on ties; NaN scores never win.
        if score > best_score:
            best_score = score
            best_procedure = procedure

    if best_score < min_score:
        logger.debug(
            "No significant displacement detected (best=%.5f). Classified as 'unknown'.",
            best_score,
        )
        return "unknown"

    return best_procedure
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# ---------------------------------------------------------------------------
|
|
171
|
+
# Single-pair extraction
|
|
172
|
+
# ---------------------------------------------------------------------------
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def extract_displacements(
    before_img: np.ndarray,
    after_img: np.ndarray,
    min_detection_confidence: float = 0.5,
) -> dict | None:
    """Measure per-landmark movement between a pre- and post-surgery photo.

    Landmarks are detected in both images with ``extract_landmarks``, their
    normalized 2-D positions are differenced, the procedure is inferred with
    ``classify_procedure`` and alignment is scored with
    ``_compute_alignment_quality``.

    Args:
        before_img: Pre-surgery BGR image as numpy array.
        after_img: Post-surgery BGR image as numpy array.
        min_detection_confidence: Minimum face detection confidence passed to
            ``extract_landmarks`` (default 0.5).

    Returns:
        Dictionary with keys:

        - ``landmarks_before``: (478, 2) normalized coordinates
        - ``landmarks_after``: (478, 2) normalized coordinates
        - ``displacements``: (478, 2) displacement vectors (after - before)
        - ``magnitude``: (478,) per-landmark displacement magnitudes
        - ``procedure``: classified procedure name or ``"unknown"``
        - ``quality_score``: float in [0, 1] indicating alignment quality

        ``None`` if face detection fails on either image. The before image is
        processed first; if it fails, the after image is never examined.
    """
    coords = []
    for image, failure_message in (
        (before_img, "Face detection failed on before image."),
        (after_img, "Face detection failed on after image."),
    ):
        face = extract_landmarks(image, min_detection_confidence=min_detection_confidence)
        if face is None:
            logger.warning(failure_message)
            return None
        coords.append(_normalized_coords_2d(face))

    start, end = coords
    delta = end - start

    # Dict values are evaluated in order: magnitude, procedure, quality.
    return {
        "landmarks_before": start,
        "landmarks_after": end,
        "displacements": delta,
        "magnitude": np.linalg.norm(delta, axis=1),
        "procedure": classify_procedure(delta),
        "quality_score": _compute_alignment_quality(start, end),
    }
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
# ---------------------------------------------------------------------------
|
|
239
|
+
# Batch extraction from directory
|
|
240
|
+
# ---------------------------------------------------------------------------
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _find_image_pairs(pairs_dir: Path) -> list[tuple[str, Path, Path]]:
    """Match before/after image files in *pairs_dir*, sorted by base name.

    Recognizes ``<name>_before`` / ``<name>_after`` and ``<name>_input`` /
    ``<name>_target`` stems (case-insensitive) with a known image extension.
    The result is independent of filesystem iteration order:

    - among files sharing a stem (e.g. ``a_before.png`` and ``a_before.jpg``),
      the first in sorted path order is used;
    - when a base name matches both conventions, the ``_before``/``_after``
      pair is kept.

    Returns:
        ``(base_name, before_path, after_path)`` tuples. ``base_name`` is
        lower-cased; the paths keep their original case.
    """
    image_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif", ".webp"}

    # Index image files by lower-cased stem, visiting them in sorted order so
    # stem collisions resolve the same way on every filesystem.
    files_by_stem: dict[str, Path] = {}
    for f in sorted(pairs_dir.iterdir()):
        if not (f.is_file() and f.suffix.lower() in image_extensions):
            continue
        stem_lower = f.stem.lower()
        if stem_lower in files_by_stem:
            logger.warning(
                "Ignoring %s: image stem already used by %s",
                f,
                files_by_stem[stem_lower],
            )
            continue
        files_by_stem[stem_lower] = f

    # Try the naming conventions in priority order.
    pairs: dict[str, tuple[Path, Path]] = {}
    for before_suffix, after_suffix in (("_before", "_after"), ("_input", "_target")):
        for stem_lower, before_path in files_by_stem.items():
            if not stem_lower.endswith(before_suffix):
                continue
            base = stem_lower[: -len(before_suffix)]
            after_path = files_by_stem.get(base + after_suffix)
            if after_path is not None and base not in pairs:
                pairs[base] = (before_path, after_path)

    return [(base, before, after) for base, (before, after) in sorted(pairs.items())]


def extract_from_directory(
    pairs_dir: str | Path,
    min_detection_confidence: float = 0.5,
    min_quality: float = 0.0,
) -> list[dict]:
    """Batch-extract displacements from a directory of before/after image pairs.

    Supports two naming conventions (matched case-insensitively):

    - ``<name>_before.{png,jpg,...}`` / ``<name>_after.{png,jpg,...}``
    - ``<name>_input.{png,jpg,...}`` / ``<name>_target.{png,jpg,...}``

    Pairs are processed in sorted order of their (lower-cased) base name.

    Args:
        pairs_dir: Path to directory containing image pairs.
        min_detection_confidence: Passed to ``extract_displacements``.
        min_quality: Minimum alignment quality score to include a pair
            in the results (default 0.0 = include all).

    Returns:
        List of displacement dictionaries (same format as
        ``extract_displacements``), each augmented with:

        - ``pair_name``: lower-cased base stem of the pair (e.g. ``"patient_001"``)
        - ``before_path``: path to the before image
        - ``after_path``: path to the after image

        Pairs whose images cannot be read, where face detection fails, or
        whose quality is below ``min_quality`` are skipped with a log message.

    Raises:
        FileNotFoundError: If ``pairs_dir`` is not an existing directory.
    """
    pairs_dir = Path(pairs_dir)
    if not pairs_dir.is_dir():
        raise FileNotFoundError(f"Directory not found: {pairs_dir}")

    pairs = _find_image_pairs(pairs_dir)
    if not pairs:
        logger.warning("No image pairs found in %s", pairs_dir)
        return []

    logger.info("Found %d image pairs in %s", len(pairs), pairs_dir)

    results: list[dict] = []
    for pair_name, before_path, after_path in pairs:
        logger.info("Processing pair: %s", pair_name)

        before_img = cv2.imread(str(before_path))
        if before_img is None:
            logger.warning("Failed to load before image: %s", before_path)
            continue

        after_img = cv2.imread(str(after_path))
        if after_img is None:
            logger.warning("Failed to load after image: %s", after_path)
            continue

        result = extract_displacements(
            before_img, after_img, min_detection_confidence=min_detection_confidence
        )
        if result is None:
            logger.warning("Skipping pair %s: face detection failed.", pair_name)
            continue

        if result["quality_score"] < min_quality:
            logger.info(
                "Skipping pair %s: quality %.3f < threshold %.3f",
                pair_name,
                result["quality_score"],
                min_quality,
            )
            continue

        result["pair_name"] = pair_name
        result["before_path"] = str(before_path)
        result["after_path"] = str(after_path)
        results.append(result)

    logger.info(
        "Successfully extracted %d / %d pairs (%.0f%%)",
        len(results),
        len(pairs),
        100.0 * len(results) / max(len(pairs), 1),
    )
    return results
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
# ---------------------------------------------------------------------------
|
|
351
|
+
# Displacement model
|
|
352
|
+
# ---------------------------------------------------------------------------
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
class DisplacementModel:
    """Statistical model of per-procedure surgical displacements.

    Aggregates displacement vectors from multiple before/after pairs and
    computes per-procedure, per-landmark statistics (mean, std, min, max,
    median, mean magnitude). It can then generate displacement fields for use
    in the conditioning pipeline, replacing hand-tuned RBF vectors.

    Attributes:
        procedures: List of procedure names the model has data for.
        stats: Nested dict ``{procedure: {stat_name: array}}``.
        n_samples: Dict ``{procedure: int}`` sample counts.
    """

    def __init__(self) -> None:
        self.stats: dict[str, dict[str, np.ndarray]] = {}
        self.n_samples: dict[str, int] = {}
        self._fitted = False

    @property
    def procedures(self) -> list[str]:
        """Return the procedures the model holds statistics for."""
        return list(self.stats.keys())

    @property
    def fitted(self) -> bool:
        """Whether the model has been fitted (or loaded from disk)."""
        return self._fitted

    def fit(self, displacement_list: list[dict]) -> None:
        """Fit the model from a list of extracted displacement dictionaries.

        Groups displacements by their ``"procedure"`` label and computes
        per-landmark statistics for each group. Entries without displacements
        or with an unexpected shape are skipped with a warning. Existing
        statistics are replaced only if fitting succeeds.

        Args:
            displacement_list: List of dicts as returned by
                ``extract_displacements`` or ``extract_from_directory``.
                Each must contain ``"displacements"`` (478, 2), as an array or
                array-like, and should contain ``"procedure"`` (str). A missing
                procedure is grouped as ``"unknown"``.

        Raises:
            ValueError: If ``displacement_list`` is empty or contains no
                valid displacement data.
        """
        if not displacement_list:
            raise ValueError("displacement_list is empty.")

        procedure_groups: dict[str, list[np.ndarray]] = {}
        for entry in displacement_list:
            proc = entry.get("procedure", "unknown")
            disp = entry.get("displacements")
            if disp is None:
                logger.warning("Skipping entry without 'displacements' key.")
                continue
            # Accept array-likes (e.g. nested lists from JSON); ndarrays pass
            # through unchanged.
            disp = np.asarray(disp)
            if disp.shape != (NUM_LANDMARKS, 2):
                logger.warning(
                    "Skipping entry with unexpected shape %s (expected (%d, 2)).",
                    disp.shape,
                    NUM_LANDMARKS,
                )
                continue
            procedure_groups.setdefault(proc, []).append(disp)

        if not procedure_groups:
            raise ValueError("No valid displacement data found in displacement_list.")

        # Build into fresh containers and swap them in at the end, so a failure
        # part-way through leaves any previous fit intact.
        stats: dict[str, dict[str, np.ndarray]] = {}
        n_samples: dict[str, int] = {}

        for proc, disp_arrays in procedure_groups.items():
            stacked = np.stack(disp_arrays, axis=0)  # (N, 478, 2)
            n = stacked.shape[0]

            stats[proc] = {
                "mean": np.mean(stacked, axis=0),  # (478, 2)
                "std": np.std(stacked, axis=0),  # (478, 2)
                "min": np.min(stacked, axis=0),  # (478, 2)
                "max": np.max(stacked, axis=0),  # (478, 2)
                "median": np.median(stacked, axis=0),  # (478, 2)
                # Mean of per-sample magnitudes (not magnitude of the mean).
                "mean_magnitude": np.mean(np.linalg.norm(stacked, axis=2), axis=0),  # (478,)
            }
            n_samples[proc] = n
            logger.info(
                "Fitted procedure '%s': %d samples, mean magnitude=%.5f",
                proc,
                n,
                float(np.mean(stats[proc]["mean_magnitude"])),
            )

        self.stats = stats
        self.n_samples = n_samples
        self._fitted = True

    def get_displacement_field(
        self,
        procedure: str,
        intensity: float = 1.0,
        noise_scale: float = 0.0,
        rng: np.random.Generator | None = None,
    ) -> np.ndarray:
        """Generate a displacement field for a given procedure and intensity.

        Returns the mean displacement scaled by ``intensity``. Gaussian noise
        can optionally be added, scaled by the per-landmark std. The stored
        statistics are never modified.

        Args:
            procedure: Procedure name (must exist in the fitted model).
            intensity: Scaling factor for the mean displacement. 1.0 = average
                observed displacement; 0.5 = half intensity; etc.
            noise_scale: If > 0, adds Gaussian noise with this many standard
                deviations of variation. 0.0 = deterministic mean field.
            rng: NumPy random generator for reproducible noise. If ``None``
                and ``noise_scale > 0``, uses ``np.random.default_rng()``.

        Returns:
            (478, 2) float32 displacement field in normalized coordinate space.

        Raises:
            RuntimeError: If the model has not been fitted.
            KeyError: If the procedure is not in the model, or if noise is
                requested but the procedure has no ``"std"`` statistic.
        """
        if not self._fitted:
            raise RuntimeError("Model has not been fitted. Call fit() first.")

        if procedure not in self.stats:
            available = ", ".join(self.procedures)
            raise KeyError(f"Procedure '{procedure}' not in model. Available: {available}")

        proc_stats = self.stats[procedure]
        # Work in float64 so integer-typed means (possible with loaded files)
        # can still receive float noise; the multiplication allocates a new
        # array, so the stored mean is untouched.
        field = np.asarray(proc_stats["mean"], dtype=np.float64) * intensity

        if noise_scale > 0:
            if "std" not in proc_stats:
                raise KeyError(
                    f"Procedure '{procedure}' has no 'std' statistic; cannot add noise."
                )
            if rng is None:
                rng = np.random.default_rng()
            noise = rng.normal(
                loc=0.0,
                scale=proc_stats["std"] * noise_scale,
            )
            field = field + noise

        return field.astype(np.float32)

    def get_summary(self, procedure: str | None = None) -> dict:
        """Get a human-readable summary of the model statistics.

        Args:
            procedure: If provided, return summary for one procedure.
                If ``None`` (or empty), return summaries for all procedures.
                Unknown procedure names are silently omitted.

        Returns:
            ``{"fitted": False}`` for an unfitted model. Otherwise
            ``{"fitted": True, "procedures": {name: {...}}}``, where each entry
            has ``n_samples``, ``global_mean_magnitude``,
            ``global_max_magnitude`` and ``top_landmarks``.
        """
        if not self._fitted:
            return {"fitted": False}

        procs = [procedure] if procedure else self.procedures
        summary = {"fitted": True, "procedures": {}}

        for proc in procs:
            if proc not in self.stats:
                continue
            s = self.stats[proc]
            summary["procedures"][proc] = {
                "n_samples": self.n_samples[proc],
                "global_mean_magnitude": float(np.mean(s["mean_magnitude"])),
                "global_max_magnitude": float(np.max(s["mean_magnitude"])),
                "top_landmarks": _top_k_landmarks(s["mean_magnitude"], k=10),
            }

        return summary

    @staticmethod
    def _npz_path(path: str | Path) -> Path:
        """Return *path* with ``.npz`` appended unless it already ends in ``.npz``.

        ``Path.with_suffix`` would *replace* an existing suffix (``model.v2``
        becoming ``model.npz``). Appending keeps the full name. The check is
        case-sensitive, like the one ``np.savez_compressed`` applies to
        string paths, so the logged path matches the file actually written.
        """
        path = Path(path)
        if path.suffix == ".npz":
            return path
        return path.with_name(path.name + ".npz")

    def save(self, path: str | Path) -> None:
        """Save the fitted model to disk as a ``.npz`` file.

        The file contains:

        - Per-procedure stat arrays keyed as ``{procedure}__{stat_name}``
        - A JSON metadata string (UTF-8 bytes as a uint8 array) under
          ``__metadata__`` with the procedure list, sample counts and
          landmark count

        Args:
            path: Output file path. ``.npz`` is appended if the name does not
                already end with it (``model.v2`` -> ``model.v2.npz``).

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        if not self._fitted:
            raise RuntimeError("Model has not been fitted. Call fit() first.")

        path = self._npz_path(path)

        arrays: dict[str, np.ndarray] = {
            f"{proc}__{stat_name}": arr
            for proc, proc_stats in self.stats.items()
            for stat_name, arr in proc_stats.items()
        }

        metadata = {
            "procedures": self.procedures,
            "n_samples": self.n_samples,
            "num_landmarks": NUM_LANDMARKS,
        }
        arrays["__metadata__"] = np.frombuffer(json.dumps(metadata).encode("utf-8"), dtype=np.uint8)

        np.savez_compressed(str(path), **arrays)
        logger.info("Saved displacement model to %s", path)

    @classmethod
    def load(cls, path: str | Path) -> DisplacementModel:
        """Load a fitted model from a ``.npz`` file.

        Supports two formats:

        1. ``save()`` format: keys like ``{proc}__{stat}`` with ``__metadata__``
        2. ``extract_displacements.py`` format: keys like ``{proc}_{stat}``
           with a ``procedures`` array

        Args:
            path: Path to the ``.npz`` file.

        Returns:
            A fitted ``DisplacementModel`` instance.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the format is unrecognized or the file holds no
                procedure statistics.
        """
        path = Path(path)
        if not path.exists():
            raise FileNotFoundError(f"Model file not found: {path}")

        model = cls()

        # ``np.load`` on an .npz returns a lazily-read NpzFile that keeps the
        # underlying file open; the ``with`` block closes it once every array
        # has been materialized (indexing an NpzFile returns a full ndarray).
        with np.load(str(path), allow_pickle=False) as data:
            files = list(data.files)

            if "__metadata__" in files:
                # Format 1: written by save().
                metadata = json.loads(data["__metadata__"].tobytes().decode("utf-8"))
                model.n_samples = {k: int(v) for k, v in metadata["n_samples"].items()}

                for proc in metadata["procedures"]:
                    prefix = f"{proc}__"
                    model.stats[proc] = {
                        key[len(prefix) :]: data[key] for key in files if key.startswith(prefix)
                    }

            elif "procedures" in files:
                # Format 2: extraction-script output with a procedures array.
                procedures = [str(p) for p in data["procedures"]]
                # Map from extraction script key names to DisplacementModel stat names.
                stat_map = {
                    "mean": "mean",
                    "std": "std",
                    "median": "median",
                    "min": "min",
                    "max": "max",
                    "mag_mean": "mean_magnitude",
                    "mag_std": "std_magnitude",
                    "count": "_count",
                }
                for proc in procedures:
                    model.stats[proc] = {}
                    for ext_key, model_key in stat_map.items():
                        npz_key = f"{proc}_{ext_key}"
                        if npz_key not in files:
                            continue
                        arr = data[npz_key]
                        if model_key == "_count":
                            # ``.item()`` handles 0-d and single-element arrays;
                            # ``int()`` on a 1-element 1-d array is deprecated.
                            model.n_samples[proc] = int(np.asarray(arr).item())
                        else:
                            model.stats[proc][model_key] = arr

                    model.n_samples.setdefault(proc, 0)

            else:
                raise ValueError(f"Unrecognized displacement model format. Keys: {files[:10]}")

        # Validate loaded model is not empty.
        if not model.stats:
            raise ValueError(
                f"Displacement model at {path} contains no procedure data. "
                f"File may be corrupted or empty. Keys found: {files[:10]}"
            )
        for proc, stats in model.stats.items():
            if not stats:
                raise ValueError(
                    f"Displacement model at {path} has no statistics for "
                    f"procedure '{proc}'. File may be corrupted."
                )

        model._fitted = True
        logger.info(
            "Loaded displacement model from %s (%d procedures, %s samples)",
            path,
            len(model.procedures),
            model.n_samples,
        )
        return model
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
# ---------------------------------------------------------------------------
|
|
663
|
+
# Utilities
|
|
664
|
+
# ---------------------------------------------------------------------------
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
def _top_k_landmarks(
    magnitudes: np.ndarray,
    k: int = 10,
) -> list[dict]:
    """Return the ``k`` landmarks with the largest mean displacement magnitude.

    Ties are broken by ascending landmark index, so the output is
    deterministic, and NaN magnitudes are ranked last rather than first.

    Args:
        magnitudes: ``(N,)`` array of per-landmark magnitudes (N is normally 478).
        k: Number of landmarks to return. Values ``<= 0`` yield an empty list;
            values larger than N return every landmark.

    Returns:
        List of ``{"index": int, "magnitude": float}`` dicts sorted
        descending by magnitude.
    """
    if k <= 0:
        return []

    # float64 so negation is safe even for unsigned-integer input.
    values = np.asarray(magnitudes, dtype=np.float64)
    # Stable sort of the negated values: largest first, ties keep index order,
    # and NaN (which argsort always places last) stays at the end.
    order = np.argsort(-values, kind="stable")[:k]
    return [{"index": int(idx), "magnitude": float(values[idx])} for idx in order]
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
def visualize_displacements(
    before_img: np.ndarray,
    result: dict,
    scale: float = 10.0,
    arrow_color: tuple[int, int, int] = (0, 255, 0),
    thickness: int = 1,
) -> np.ndarray:
    """Draw displacement arrows on the before image for visual inspection.

    Each landmark gets an arrow from its before position along its scaled
    displacement. Arrows shorter than 1 pixel after scaling are skipped. The
    procedure name and quality score are written in the top-left corner.

    Args:
        before_img: BGR image (copied; the input is not modified).
        result: Displacement dict from ``extract_displacements``. Must contain
            ``landmarks_before`` and ``displacements``; ``procedure`` and
            ``quality_score`` are optional.
        scale: Arrow length multiplier (displacements are small in
            normalized space, so scale up for visibility).
        arrow_color: BGR color for arrows.
        thickness: Arrow line thickness.

    Returns:
        Annotated BGR image.
    """
    canvas = before_img.copy()
    h, w = canvas.shape[:2]

    coords_before = result["landmarks_before"]
    displacements = result["displacements"]
    # Iterate over the landmarks actually present instead of assuming 478,
    # so results with fewer points (e.g. 468 without iris) do not IndexError.
    n_points = min(len(coords_before), len(displacements))

    for i in range(n_points):
        bx = int(coords_before[i, 0] * w)
        by = int(coords_before[i, 1] * h)
        dx = int(displacements[i, 0] * w * scale)
        dy = int(displacements[i, 1] * h * scale)

        # Only draw if the scaled arrow is at least one pixel long.
        mag = np.sqrt(dx**2 + dy**2)
        if mag < 1.0:
            continue

        cv2.arrowedLine(
            canvas,
            (bx, by),
            (bx + dx, by + dy),
            arrow_color,
            thickness,
            tipLength=0.3,
        )

    # Add procedure label and quality score.
    proc = result.get("procedure", "unknown")
    quality = result.get("quality_score", 0.0)
    label = f"{proc} (quality={quality:.2f})"
    cv2.putText(
        canvas,
        label,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.8,
        (255, 255, 255),
        2,
    )

    return canvas
|