landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/losses.py ADDED
@@ -0,0 +1,348 @@
1
+ """4-term loss function module for ControlNet fine-tuning.
2
+
3
+ L_total = L_diffusion + w_landmark * L_landmark
4
+ + w_identity * L_identity + w_perceptual * L_perceptual
5
+
6
+ Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
7
+ rubbery TPS warps — it would penalize realism.
8
+
9
+ Phase B (FEM/clinical data): All 4 terms enabled.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+
20
@dataclass(frozen=True)
class LossWeights:
    """Immutable multipliers applied to each term of the combined loss.

    The field order is also the positional constructor order:
    ``LossWeights(diffusion, landmark, identity, perceptual)``.
    """

    # Primary epsilon-prediction MSE term.
    diffusion: float = 1e0
    # Auxiliary landmark-distance term.
    landmark: float = 1e-1
    # Auxiliary identity (embedding cosine) term.
    identity: float = 1e-1
    # Auxiliary perceptual term.
    perceptual: float = 5e-2
28
+
29
+
30
class DiffusionLoss:
    """Mean-squared error between predicted and true noise (epsilon objective).

    This is the primary training signal and is computed in every phase.
    """

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
    ) -> torch.Tensor:
        # Mean reduction over every element -> scalar tensor.
        loss = F.mse_loss(noise_pred, noise_target, reduction="mean")
        return loss
39
+
40
+
41
class LandmarkLoss:
    """Mean per-landmark Euclidean error, optionally masked and IOD-normalized.

    Meant to be evaluated inside the surgical mask only. Landmarks have to be
    re-extracted from the generated image, which is too slow for every
    training step, so this term is normally evaluated at eval time.
    """

    def __call__(
        self,
        pred_landmarks: torch.Tensor,  # (B, N, 2)
        target_landmarks: torch.Tensor,  # (B, N, 2)
        mask: torch.Tensor | None = None,  # (B, N) binary
        iod: torch.Tensor | None = None,  # (B,) inter-ocular distance
    ) -> torch.Tensor:
        """Return the batch-mean landmark error as a scalar tensor.

        Args:
            pred_landmarks: predicted points, (B, N, 2).
            target_landmarks: reference points, same shape.
            mask: optional (B, N) weights. Only masked points are averaged,
                and the per-sample denominator is floored at 1 so an
                all-zero mask yields 0 instead of a division by zero.
            iod: optional (B,) per-sample divisor, floored at 1.0.
                NOTE(review): the 1.0 floor suggests pixel-unit coordinates.
                With normalized [0, 1] coordinates every IOD is < 1, so it
                would be replaced by 1.0 and the scaling would have no effect.
        """
        per_point = (pred_landmarks - target_landmarks).norm(dim=-1)  # (B, N)

        if mask is None:
            per_sample = per_point.mean(dim=-1)
        else:
            denom = mask.sum(dim=-1).clamp(min=1)
            per_sample = (per_point * mask).sum(dim=-1) / denom

        if iod is not None:
            per_sample = per_sample / iod.clamp(min=1.0)

        return per_sample.mean()
69
+
70
+
71
class IdentityLoss:
    """ArcFace cosine-distance identity loss with procedure-dependent cropping.

    Uses the InsightFace ``buffalo_l`` model pack (via ``FaceAnalysis``) for
    512-dim identity embeddings. If InsightFace cannot be loaded, flattened
    pixel vectors are used as the "embedding" instead.

    Region compared per procedure (see ``_procedure_crop``):

    - rhinoplasty: top 2/3 of the image (upper face)
    - blepharoplasty: full image
    - rhytidectomy: top 3/4 of the image (above the jawline)
    - orthognathic: loss disabled (always 0)
    - any other value: full image

    Crops are resized to 112x112 and mapped from [0, 1] to [-1, 1] before
    embedding extraction. Recognition models expect this small input size and
    give meaningless embeddings at full resolution (e.g. 1024x1024).

    Note:
        Embedding extraction runs under ``torch.no_grad()``. On the InsightFace
        path it also goes through NumPy. This loss therefore contributes no
        gradient to ``pred_image`` and is only useful as a metric. Use a
        differentiable backbone when a training signal is needed.
    """

    def __init__(self, device: torch.device | None = None):
        self._model = None
        # Stored for API compatibility; the model is loaded on the input's device.
        self._device = device
        self._app = None  # InsightFace FaceAnalysis instance once loaded
        self._has_arcface: bool | None = None  # None = not checked yet

    def _ensure_loaded(self, device: torch.device) -> None:
        """Lazily initialise InsightFace on first use (attempted only once).

        Sets ``self._has_arcface`` to True on success. On any failure (missing
        package, model download or execution-provider error, ...) it switches
        to the pixel-level fallback and emits a ``RuntimeWarning`` instead of
        failing silently.
        """
        if self._has_arcface is not None:
            return
        try:
            from insightface.app import FaceAnalysis

            app = FaceAnalysis(
                name="buffalo_l",
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            # ctx_id: the CUDA device index (0 if unspecified), -1 otherwise.
            if device.type == "cuda":
                ctx_id = device.index if device.index is not None else 0
            else:
                ctx_id = -1
            app.prepare(ctx_id=ctx_id, det_size=(320, 320))
        except Exception as exc:  # best-effort: any failure -> pixel fallback
            import warnings

            warnings.warn(
                f"IdentityLoss: InsightFace ArcFace unavailable ({exc!r}); "
                "falling back to pixel-level cosine similarity.",
                RuntimeWarning,
                stacklevel=2,
            )
            self._has_arcface = False
        else:
            self._app = app
            self._has_arcface = True

    @torch.no_grad()
    def _extract_embedding(
        self, image_tensor: torch.Tensor
    ) -> tuple[torch.Tensor, list[bool]]:
        """Extract per-sample identity embeddings.

        Args:
            image_tensor: (B, 3, 112, 112) images in [-1, 1].

        Returns:
            ``(embeddings, valid)``. On the InsightFace path ``embeddings`` is
            (B, 512); on the fallback path it is (B, 3*H*W) flattened pixels.
            ``valid[i]`` is False when no face was detected in sample ``i``;
            that sample's embedding row is then all zeros.
        """
        if not self._has_arcface:
            # Fallback: raw pixels act as the embedding; every sample is valid.
            return image_tensor.flatten(1), [True] * image_tensor.shape[0]

        if image_tensor.shape[0] == 0:
            # torch.stack([]) would raise; an empty batch has no embeddings.
            return image_tensor.new_zeros((0, 512)), []

        import numpy as np

        # [-1, 1] -> [0, 255] uint8, (B, H, W, C). Cast to float32 first:
        # NumPy has no bfloat16, so ``.numpy()`` on a bf16 tensor would raise.
        images_u8 = (
            ((image_tensor.float() + 1) / 2 * 255)
            .clamp(0, 255)
            .permute(0, 2, 3, 1)
            .cpu()
            .numpy()
            .astype(np.uint8)
        )

        embeddings: list[torch.Tensor] = []
        valid: list[bool] = []
        for img_rgb in images_u8:
            # Convert RGB -> BGR for InsightFace.
            img_bgr = np.ascontiguousarray(img_rgb[:, :, ::-1])
            faces = self._app.get(img_bgr)
            emb = getattr(faces[0], "embedding", None) if faces else None
            if emb is None:
                embeddings.append(torch.zeros(512))
                valid.append(False)
            else:
                embeddings.append(torch.from_numpy(np.asarray(emb, dtype=np.float32)))
                valid.append(True)

        return torch.stack(embeddings).to(image_tensor.device), valid

    def __call__(
        self,
        pred_image: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        target_image: torch.Tensor,
        procedure: str = "rhinoplasty",
    ) -> torch.Tensor:
        """Return mean ``1 - cos(emb(pred), emb(target))`` over valid samples.

        Samples where a face is not detected in either image are excluded.
        Returns a scalar 0 when ``procedure == "orthognathic"`` or when no
        sample is valid.
        """
        device = pred_image.device
        if procedure == "orthognathic":
            return torch.tensor(0.0, device=device)

        self._ensure_loaded(device)

        pred_norm = self._prepare(pred_image, procedure)
        target_norm = self._prepare(target_image, procedure)

        pred_emb, pred_valid = self._extract_embedding(pred_norm)
        target_emb, target_valid = self._extract_embedding(target_norm)

        valid_indices = [
            i for i, (p_ok, t_ok) in enumerate(zip(pred_valid, target_valid)) if p_ok and t_ok
        ]
        if not valid_indices:
            return torch.tensor(0.0, device=device)
        idx = torch.tensor(valid_indices, device=pred_emb.device, dtype=torch.long)

        # Select only valid rows: the zero placeholder rows of undetected faces
        # would otherwise count as cos = 0, i.e. a spurious loss of 1.
        pred_sel = F.normalize(pred_emb[idx].float(), dim=1)
        target_sel = F.normalize(target_emb[idx].float(), dim=1)

        cosine_sim = (pred_sel * target_sel).sum(dim=1)
        return (1 - cosine_sim).mean()

    def _prepare(self, image: torch.Tensor, procedure: str) -> torch.Tensor:
        """Crop for ``procedure``, resize to 112x112 and map [0, 1] -> [-1, 1]."""
        crop = self._procedure_crop(image, procedure)
        resized = F.interpolate(crop, size=(112, 112), mode="bilinear", align_corners=False)
        return resized * 2 - 1

    def _procedure_crop(
        self,
        image: torch.Tensor,
        procedure: str,
    ) -> torch.Tensor:
        """Return the rows of ``image`` (B, C, H, W) compared for ``procedure``.

        Rhinoplasty keeps the top 2/3 of rows (forehead to nose tip).
        Rhytidectomy keeps the top 3/4 (above the jawline). Any other
        procedure, including blepharoplasty, uses the full image.
        """
        h = image.shape[-2]
        if procedure == "rhinoplasty":
            return image[:, :, : h * 2 // 3, :]
        if procedure == "rhytidectomy":
            return image[:, :, : h * 3 // 4, :]
        return image
215
+
216
+
217
class PerceptualLoss:
    """LPIPS perceptual loss restricted to the region OUTSIDE the surgical mask.

    Inputs are expected in [0, 1] (the VAE output range). They are mapped to
    the [-1, 1] range LPIPS expects. If the ``lpips`` package cannot be
    imported, an L1 loss on the same masked, rescaled images is used instead.
    """

    # Kernel size of the min-pool that erodes the outside-mask region. This
    # drops pixels within ``_ERODE_KERNEL // 2`` px of the surgical region.
    _ERODE_KERNEL = 5

    def __init__(self):
        # None: not loaded yet; "unavailable": lpips missing; else the model.
        self._lpips = None

    def _ensure_loaded(self, device: torch.device) -> None:
        """Lazily build a frozen AlexNet LPIPS model on ``device`` (once)."""
        if self._lpips is not None:
            return
        try:
            import lpips

            model = lpips.LPIPS(net="alex").to(device)
        except ImportError:
            self._lpips = "unavailable"
            return
        model.eval()
        for p in model.parameters():
            p.requires_grad_(False)
        self._lpips = model

    def _erode(self, region: torch.Tensor) -> torch.Tensor:
        """Erode ``region`` (B, 1, H, W) by min-pooling, computed as -maxpool(-x).

        Padding goes through max_pool2d's implicit -inf padding, so image
        borders are not eroded. Maps smaller than the kernel are returned
        unchanged.
        """
        k = self._ERODE_KERNEL
        if region.shape[-1] < k or region.shape[-2] < k:
            return region
        return -F.max_pool2d(-region, kernel_size=k, stride=1, padding=k // 2)

    def __call__(
        self,
        pred: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        target: torch.Tensor,
        mask: torch.Tensor,  # (B, 1, H, W) surgical mask [0, 1]
    ) -> torch.Tensor:
        """Return the scalar perceptual distance outside the surgical mask.

        Args:
            pred: generated images, (B, 3, H, W) in [0, 1].
            target: reference images, same shape as ``pred``.
            mask: surgical-region mask, (B, 1, h, w) with values in [0, 1].
                Bool masks are accepted. A mask at a different resolution is
                resized to (H, W) with nearest-neighbour interpolation. A mask
                covering at least 99% of pixels is treated as "no mask", and
                the loss is then computed on the full image.

        Returns:
            Scalar float32 tensor: mean LPIPS distance, or mean L1 when LPIPS
            is unavailable.
        """
        self._ensure_loaded(pred.device)

        # LPIPS expects [-1, 1]. Work in float32 so half-precision inputs do
        # not clash with the float32 LPIPS weights.
        pred_norm = pred.float() * 2 - 1
        target_norm = target.float() * 2 - 1

        # ``1 - mask`` is not defined for bool tensors; cast to the working dtype.
        mask = mask.to(dtype=pred_norm.dtype, device=pred_norm.device)
        if mask.shape[-2:] != pred_norm.shape[-2:]:
            mask = F.interpolate(mask, size=tuple(pred_norm.shape[-2:]), mode="nearest")

        has_mask = mask.sum() < mask.numel() * 0.99
        if has_mask:
            outside_mask = self._erode(1 - mask)
            pred_norm = pred_norm * outside_mask
            target_norm = target_norm * outside_mask

        if isinstance(self._lpips, str):  # "unavailable": L1 fallback
            return F.l1_loss(pred_norm, target_norm)

        return self._lpips(pred_norm, target_norm).mean()
272
+
273
+
274
class CombinedLoss:
    """Weighted sum of the diffusion, landmark, identity and perceptual terms.

    ``phase='A'`` (synthetic TPS data) uses the diffusion term only.
    ``phase='B'`` also adds each auxiliary term whose inputs are passed as
    keyword arguments to ``__call__``.

    ``use_differentiable_arcface=True`` selects ``ArcFaceLoss`` from
    ``landmarkdiff.arcface_torch`` for the identity term. That is a
    PyTorch-native ArcFace backbone that provides a real gradient signal.
    The default ONNX-based ``IdentityLoss`` produces zero gradients (DA2-03).
    """

    def __init__(
        self,
        weights: LossWeights | None = None,
        phase: str = "A",
        use_differentiable_arcface: bool = False,
        arcface_weights_path: str | None = None,
    ):
        self.weights = weights or LossWeights()
        self.phase = phase
        self.diffusion_loss = DiffusionLoss()
        self.landmark_loss = LandmarkLoss()
        self.perceptual_loss = PerceptualLoss()
        self.identity_loss = self._build_identity_loss(
            use_differentiable_arcface, arcface_weights_path
        )

    @staticmethod
    def _build_identity_loss(differentiable: bool, weights_path: str | None):
        """Return the identity-loss callable selected by the constructor flags."""
        if not differentiable:
            return IdentityLoss()
        # Imported lazily: only needed when the differentiable backbone is requested.
        from landmarkdiff.arcface_torch import ArcFaceLoss

        return ArcFaceLoss(weights_path=weights_path)

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Compute the weighted loss terms and their sum.

        Keyword arguments read in Phase B:
            pred_landmarks, target_landmarks (+ optional landmark_mask, iod)
                -> "landmark"
            pred_image, target_image (+ optional procedure, default
                "rhinoplasty") -> "identity"
            pred_image, target_image, mask -> "perceptual"

        Returns:
            Dict with keys in the order "diffusion", "total", then any computed
            auxiliary terms. Every value is already multiplied by its weight.
            "total" is the sum of all the others.
        """
        w = self.weights
        diffusion = w.diffusion * self.diffusion_loss(noise_pred, noise_target)
        # "total" is inserted right after "diffusion" and updated as terms are
        # added, so the key order is stable for callers that iterate the dict.
        terms: dict[str, torch.Tensor] = {"diffusion": diffusion, "total": diffusion}

        def add(name: str, value: torch.Tensor) -> None:
            terms[name] = value
            terms["total"] = terms["total"] + value

        if self.phase != "B":
            return terms

        if "pred_landmarks" in kwargs and "target_landmarks" in kwargs:
            add(
                "landmark",
                w.landmark
                * self.landmark_loss(
                    kwargs["pred_landmarks"],
                    kwargs["target_landmarks"],
                    kwargs.get("landmark_mask"),
                    kwargs.get("iod"),
                ),
            )

        if "pred_image" not in kwargs or "target_image" not in kwargs:
            return terms

        pred_image = kwargs["pred_image"]
        target_image = kwargs["target_image"]
        procedure = kwargs.get("procedure", "rhinoplasty")
        add("identity", w.identity * self.identity_loss(pred_image, target_image, procedure))

        if "mask" in kwargs:
            add(
                "perceptual",
                w.perceptual * self.perceptual_loss(pred_image, target_image, kwargs["mask"]),
            )

        return terms