landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/losses.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
"""4-term loss function module for ControlNet fine-tuning.
|
|
2
|
+
|
|
3
|
+
L_total = L_diffusion + w_landmark * L_landmark
|
|
4
|
+
+ w_identity * L_identity + w_perceptual * L_perceptual
|
|
5
|
+
|
|
6
|
+
Phase A (synthetic TPS data): L_diffusion ONLY. No perceptual loss against
|
|
7
|
+
rubbery TPS warps — it would penalize realism.
|
|
8
|
+
|
|
9
|
+
Phase B (FEM/clinical data): All 4 terms enabled.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
import torch.nn.functional as F
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True)
class LossWeights:
    """Scalar multipliers applied to each term of the combined loss.

    The instance is frozen, so a weight set cannot be mutated after it has
    been created; build a new ``LossWeights`` to change a value.
    """

    # Multiplier for the epsilon-prediction (diffusion) term.
    diffusion: float = 1.0
    # Multiplier for the landmark-distance term.
    landmark: float = 0.1
    # Multiplier for the identity (face-embedding) term.
    identity: float = 0.1
    # Multiplier for the perceptual term.
    perceptual: float = 0.05
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DiffusionLoss:
    """Mean-squared error between predicted and target noise.

    This is the standard epsilon-prediction objective and acts as the
    primary training signal.
    """

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
    ) -> torch.Tensor:
        """Return the scalar MSE averaged over every element of the inputs."""
        mse = F.mse_loss(noise_pred, noise_target, reduction="mean")
        return mse
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LandmarkLoss:
    """Mean Euclidean landmark error, optionally masked and IOD-normalized.

    When a mask is supplied only the landmarks it selects contribute (the
    surgical region). The predicted landmarks are expected to come from
    re-extracting landmarks on the generated image, which is done at
    evaluation time rather than on every training step for speed.
    """

    def __call__(
        self,
        pred_landmarks: torch.Tensor,  # (B, N, 2)
        target_landmarks: torch.Tensor,  # (B, N, 2)
        mask: torch.Tensor | None = None,  # (B, N) binary
        iod: torch.Tensor | None = None,  # (B,) inter-ocular distance
    ) -> torch.Tensor:
        """Return the batch-averaged landmark error as a scalar tensor.

        Args:
            pred_landmarks: Predicted landmark coordinates.
            target_landmarks: Ground-truth landmark coordinates.
            mask: Optional per-landmark selector; landmarks with a zero entry
                are ignored when averaging.
            iod: Optional per-sample inter-ocular distance used as divisor.
                Values below 1.0 are clamped to 1.0 before dividing.
        """
        # Per-landmark Euclidean distance, shape (B, N).
        per_point = torch.norm(pred_landmarks - target_landmarks, dim=-1)

        if mask is None:
            per_sample = per_point.mean(dim=-1)
        else:
            # Clamp the divisor so a sample with no selected landmark yields
            # 0 instead of dividing by zero.
            selected = mask.sum(dim=-1).clamp(min=1)
            per_sample = (per_point * mask).sum(dim=-1) / selected

        if iod is not None:
            per_sample = per_sample / iod.clamp(min=1.0)

        return per_sample.mean()
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class IdentityLoss:
    """ArcFace cosine similarity loss with procedure-dependent masking.

    Uses the InsightFace ``buffalo_l`` model pack for 512-dim identity
    embeddings. Falls back to pixel-level cosine similarity if InsightFace
    cannot be loaded.

    - Full face for blepharoplasty
    - Upper-face crop for rhinoplasty
    - Disabled for orthognathic

    Images are resized to 112x112 and rescaled to [-1, 1] before embedding
    extraction; the original author's note warns that the recognizer
    "outputs garbage" for large inputs such as 1024x1024.

    Embedding extraction runs under ``torch.no_grad()``, so the value
    returned by this loss carries no gradient back to the input images.
    """

    def __init__(self, device: torch.device | None = None):
        """Create the loss without loading any model.

        Args:
            device: Optional device on which to run the InsightFace backend.
                When ``None``, the device of the first predicted image batch
                is used instead. A device string such as ``"cuda:0"`` is
                accepted as well.
        """
        self._model = None
        # InsightFace FaceAnalysis handle; set by _ensure_loaded on success.
        self._app = None
        # Normalise to torch.device so `.type` / `.index` are always available.
        self._device = torch.device(device) if device is not None else None
        self._has_arcface: bool | None = None  # None = not checked yet

    def _ensure_loaded(self, device: torch.device) -> None:
        """Lazy-load the ArcFace model on first use.

        The outcome is cached in ``self._has_arcface``, so loading is only
        attempted once per instance, whether it succeeds or fails.
        """
        if self._has_arcface is not None:
            return
        try:
            from insightface.app import FaceAnalysis

            self._app = FaceAnalysis(
                name="buffalo_l",
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            # ctx_id: CUDA device index (0 when unspecified), or -1 for CPU.
            if device.type != "cuda":
                ctx_id = -1
            elif device.index is not None:
                ctx_id = device.index
            else:
                ctx_id = 0
            self._app.prepare(ctx_id=ctx_id, det_size=(320, 320))
            self._has_arcface = True
        except Exception:
            # Deliberate best-effort: any failure (package missing, model
            # files unavailable, runtime error) selects the pixel fallback.
            # Log it once so the degraded mode is visible to the user.
            import logging

            logging.getLogger(__name__).warning(
                "InsightFace ArcFace could not be loaded; IdentityLoss will "
                "use pixel-level cosine similarity instead.",
                exc_info=True,
            )
            self._app = None
            self._has_arcface = False

    @torch.no_grad()
    def _extract_embedding(
        self, image_tensor: torch.Tensor
    ) -> tuple[torch.Tensor, list[bool]]:
        """Extract identity embeddings from a batch of images.

        Args:
            image_tensor: (B, 3, 112, 112) in [-1, 1]

        Returns:
            A pair ``(embeddings, valid)``. ``embeddings`` is (B, 512) when
            ArcFace is active, or the flattened pixels (B, D) in fallback
            mode. ``valid[i]`` is False when no face embedding was obtained
            for sample ``i``; that row is then all zeros. In fallback mode
            every sample is valid.
        """
        if not self._has_arcface:
            # Fallback: pixel-level features
            return image_tensor.flatten(1), [True] * image_tensor.shape[0]

        import numpy as np

        embeddings = []
        valid_mask = []
        for i in range(image_tensor.shape[0]):
            # Convert to uint8 BGR for InsightFace. Work in float32 so
            # bfloat16 inputs can be exported to NumPy, and round before the
            # cast so values such as 254.9999 do not truncate to 254.
            img = image_tensor[i].float().permute(1, 2, 0)
            img = ((img + 1) / 2 * 255).round().clamp(0, 255)
            img_np = img.cpu().numpy().astype(np.uint8)
            img_bgr = img_np[:, :, ::-1].copy()

            faces = self._app.get(img_bgr)
            embedding = getattr(faces[0], "embedding", None) if faces else None
            if embedding is not None:
                embeddings.append(torch.from_numpy(embedding))
                valid_mask.append(True)
            else:
                embeddings.append(torch.zeros(512))
                valid_mask.append(False)

        return torch.stack(embeddings).to(image_tensor.device), valid_mask

    def __call__(
        self,
        pred_image: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        target_image: torch.Tensor,
        procedure: str = "rhinoplasty",
    ) -> torch.Tensor:
        """Return ``mean(1 - cosine_similarity)`` over usable samples.

        Args:
            pred_image: Generated images, (B, 3, H, W) in [0, 1].
            target_image: Reference images with the same layout.
            procedure: Selects the crop used for comparison; the value
                ``"orthognathic"`` disables the loss.

        Returns:
            A scalar tensor. It is 0 when the procedure is
            ``"orthognathic"`` or when no sample has a valid embedding for
            both the predicted and the target image.
        """
        if procedure == "orthognathic":
            return torch.tensor(0.0, device=pred_image.device)

        # Honour the device given to the constructor; otherwise follow the
        # input batch.
        load_device = self._device if self._device is not None else pred_image.device
        self._ensure_loaded(load_device)

        # Crop based on procedure
        pred_crop = self._procedure_crop(pred_image, procedure)
        target_crop = self._procedure_crop(target_image, procedure)

        # Resize to 112x112 for ArcFace
        pred_112 = F.interpolate(
            pred_crop, size=(112, 112), mode="bilinear", align_corners=False
        )
        target_112 = F.interpolate(
            target_crop, size=(112, 112), mode="bilinear", align_corners=False
        )

        # Normalize to [-1, 1]
        pred_norm = pred_112 * 2 - 1
        target_norm = target_112 * 2 - 1

        # Extract embeddings (ArcFace or fallback)
        pred_emb, pred_valid = self._extract_embedding(pred_norm)
        target_emb, target_valid = self._extract_embedding(target_norm)

        # Only compute loss for samples where both faces were detected
        valid_indices = [
            i for i, (p, t) in enumerate(zip(pred_valid, target_valid)) if p and t
        ]
        if not valid_indices:
            return torch.tensor(0.0, device=pred_image.device)

        valid_idx_t = torch.tensor(
            valid_indices, device=pred_image.device, dtype=torch.long
        )

        # Select ONLY valid embeddings before normalization so the all-zero
        # placeholder rows never reach the division.
        pred_valid_emb = F.normalize(pred_emb[valid_idx_t].float(), dim=1)
        target_valid_emb = F.normalize(target_emb[valid_idx_t].float(), dim=1)

        cosine_sim = (pred_valid_emb * target_valid_emb).sum(dim=1)
        return (1 - cosine_sim).mean()

    def _procedure_crop(
        self,
        image: torch.Tensor,
        procedure: str,
    ) -> torch.Tensor:
        """Crop image based on procedure for identity comparison.

        Only the height axis is cropped; unknown procedures return the image
        unchanged.
        """
        _, _, h, _ = image.shape

        if procedure == "rhinoplasty":
            # Upper face crop (forehead to nose tip)
            return image[:, :, : h * 2 // 3, :]
        if procedure == "rhytidectomy":
            # Upper face (above jawline)
            return image[:, :, : h * 3 // 4, :]
        # Blepharoplasty and anything else: full face
        return image
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
class PerceptualLoss:
    """LPIPS perceptual loss on regions OUTSIDE surgical mask only.

    LPIPS expects [-1, 1] input while the images passed in are in [0, 1],
    so ``(x * 2) - 1`` is applied on every call. If the ``lpips`` package
    cannot be imported, a plain L1 loss on the same (masked, rescaled)
    tensors is used instead.
    """

    def __init__(self):
        # None = not loaded yet; the string "unavailable" = lpips is missing.
        self._lpips = None

    def _ensure_loaded(self, device: torch.device) -> None:
        """Lazy-load a frozen AlexNet LPIPS model onto ``device``."""
        if self._lpips is not None:
            return
        try:
            import lpips

            self._lpips = lpips.LPIPS(net="alex").to(device)
            self._lpips.eval()
            for p in self._lpips.parameters():
                p.requires_grad_(False)
        except ImportError:
            self._lpips = "unavailable"

    def __call__(
        self,
        pred: torch.Tensor,  # (B, 3, H, W) in [0, 1]
        target: torch.Tensor,
        mask: torch.Tensor,  # (B, 1, H, W) surgical mask [0, 1]
    ) -> torch.Tensor:
        """Return the perceptual distance outside the surgical region.

        Args:
            pred: Generated images in [0, 1].
            target: Reference images in [0, 1].
            mask: Surgical mask, 1 inside the surgical region. Boolean and
                integer masks are accepted and converted to ``pred``'s dtype.

        Returns:
            A scalar tensor: mean LPIPS, or mean L1 in fallback mode.
        """
        self._ensure_loaded(pred.device)

        # Normalize to [-1, 1] for LPIPS
        pred_norm = pred * 2 - 1
        target_norm = target * 2 - 1

        # `1 - mask` is undefined for bool tensors, so use the image dtype.
        mask = mask.to(dtype=pred.dtype)

        # When mask is (almost) all-ones (no mask file available), compute on
        # the full image. Otherwise invert mask to get loss OUTSIDE the
        # surgical region only. The coverage is taken as a float32 mean
        # because summing a large half-precision mask overflows to inf, which
        # would silently disable masking.
        has_mask = bool(mask.float().mean() < 0.99)
        if has_mask:
            outside_mask = 1 - mask
            erode_kernel = 5
            height, width = outside_mask.shape[-2], outside_mask.shape[-1]
            if width >= erode_kernel and height >= erode_kernel:
                # Morphological erosion expressed as a negated max-pool: it
                # shrinks the kept region away from the surgical boundary.
                outside_mask = -F.max_pool2d(
                    -outside_mask,
                    kernel_size=erode_kernel,
                    stride=1,
                    padding=erode_kernel // 2,
                )
            pred_norm = pred_norm * outside_mask
            target_norm = target_norm * outside_mask

        if self._lpips == "unavailable":
            # Fallback: simple L1 loss
            return F.l1_loss(pred_norm, target_norm)

        return self._lpips(pred_norm, target_norm).mean()
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class CombinedLoss:
    """Weighted sum of the diffusion, landmark, identity and perceptual terms.

    ``phase='A'`` evaluates the diffusion term only. ``phase='B'``
    additionally evaluates each auxiliary term whose inputs are supplied as
    keyword arguments.

    For Phase B, pass ``use_differentiable_arcface=True`` to use the
    PyTorch-native ArcFace backbone (``arcface_torch.py``) that provides
    actual gradient signal. The default ONNX-based IdentityLoss produces
    zero gradients (DA2-03).
    """

    def __init__(
        self,
        weights: LossWeights | None = None,
        phase: str = "A",
        use_differentiable_arcface: bool = False,
        arcface_weights_path: str | None = None,
    ):
        """Build the individual loss terms.

        Args:
            weights: Per-term multipliers; defaults to ``LossWeights()``.
            phase: ``"B"`` enables the auxiliary terms; any other value
                leaves only the diffusion term active.
            use_differentiable_arcface: Use ``ArcFaceLoss`` from
                ``landmarkdiff.arcface_torch`` instead of ``IdentityLoss``.
            arcface_weights_path: Forwarded to ``ArcFaceLoss`` as
                ``weights_path``; unused otherwise.
        """
        self.weights = weights or LossWeights()
        self.phase = phase
        self.diffusion_loss = DiffusionLoss()
        self.landmark_loss = LandmarkLoss()
        self.perceptual_loss = PerceptualLoss()

        # Identity term: ONNX-based by default, differentiable on request.
        if not use_differentiable_arcface:
            self.identity_loss = IdentityLoss()
        else:
            # Imported lazily so the module is only needed when requested.
            from landmarkdiff.arcface_torch import ArcFaceLoss

            self.identity_loss = ArcFaceLoss(weights_path=arcface_weights_path)

    def __call__(
        self,
        noise_pred: torch.Tensor,
        noise_target: torch.Tensor,
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """Evaluate the active terms and return them keyed by name.

        Recognised keyword arguments (Phase B only): ``pred_landmarks``,
        ``target_landmarks``, ``landmark_mask``, ``iod``, ``pred_image``,
        ``target_image``, ``procedure`` and ``mask``.

        Returns:
            A dict holding each evaluated term, already multiplied by its
            weight, plus ``"total"`` with their sum.
        """
        w = self.weights

        # The diffusion term is always present and seeds the running total.
        diffusion_term = w.diffusion * self.diffusion_loss(noise_pred, noise_target)
        losses = {"diffusion": diffusion_term, "total": diffusion_term}

        if self.phase != "B":
            return losses

        has_landmarks = "pred_landmarks" in kwargs and "target_landmarks" in kwargs
        has_images = "pred_image" in kwargs and "target_image" in kwargs

        if has_landmarks:
            landmark_term = w.landmark * self.landmark_loss(
                kwargs["pred_landmarks"],
                kwargs["target_landmarks"],
                kwargs.get("landmark_mask"),
                kwargs.get("iod"),
            )
            losses["landmark"] = landmark_term
            losses["total"] = losses["total"] + landmark_term

        if has_images:
            identity_term = w.identity * self.identity_loss(
                kwargs["pred_image"],
                kwargs["target_image"],
                kwargs.get("procedure", "rhinoplasty"),
            )
            losses["identity"] = identity_term
            losses["total"] = losses["total"] + identity_term

        # The perceptual term additionally needs the surgical mask.
        if has_images and "mask" in kwargs:
            perceptual_term = w.perceptual * self.perceptual_loss(
                kwargs["pred_image"],
                kwargs["target_image"],
                kwargs["mask"],
            )
            losses["perceptual"] = perceptual_term
            losses["total"] = losses["total"] + perceptual_term

        return losses
|