landmarkdiff-0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,754 @@
1
"""Inference pipeline for surgical outcome prediction.

Five modes (selected with ``LandmarkDiffPipeline(mode=...)``):
1. ``controlnet``: CrucibleAI/ControlNetMediaPipeFace + SD1.5 (requires HF auth + GPU)
2. ``controlnet_fast``: ControlNet + LCM-LoRA, 4-step sampling (CPU-viable)
3. ``controlnet_ip``: ControlNet with identity preservation via IP-Adapter face embeddings
4. ``img2img``: SD1.5 img2img with mask compositing (runs on MPS, no auth needed)
5. ``tps``: Pure geometric warp -- no diffusion model, instant results

Supports MPS (Apple Silicon), CUDA, and CPU backends.
"""
11
+
12
+ from __future__ import annotations
13
+
14
+ import logging
15
+ import os
16
+ import sys
17
+ from pathlib import Path
18
+ from typing import TYPE_CHECKING
19
+
20
+ import cv2
21
+ import numpy as np
22
+ import torch
23
+ from PIL import Image
24
+
25
+ from landmarkdiff.landmarks import FaceLandmarks, extract_landmarks, render_landmark_image
26
+ from landmarkdiff.manipulation import apply_procedure_preset
27
+ from landmarkdiff.masking import generate_surgical_mask, mask_to_3channel
28
+ from landmarkdiff.synthetic.tps_warp import warp_image_tps
29
+
30
+ if TYPE_CHECKING:
31
+ from landmarkdiff.clinical import ClinicalFlags
32
+
33
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
34
+
35
+
36
def get_device() -> torch.device:
    """Return the preferred torch device.

    Apple MPS wins over CUDA, and CUDA wins over CPU.
    """
    for backend, is_available in (
        ("mps", torch.backends.mps.is_available),
        ("cuda", torch.cuda.is_available),
    ):
        if is_available():
            return torch.device(backend)
    return torch.device("cpu")
42
+
43
+
44
def numpy_to_pil(arr: np.ndarray) -> Image.Image:
    """Convert an OpenCV-style array into a PIL image.

    2-D arrays become single-channel ``"L"`` images. For anything else the
    last axis is reversed (BGR -> RGB) and copied into a contiguous buffer
    before handing it to PIL.
    """
    if len(arr.shape) != 2:
        rgb = arr[:, :, ::-1].copy()
        return Image.fromarray(rgb)
    return Image.fromarray(arr, mode="L")
48
+
49
+
50
def pil_to_numpy(img: Image.Image) -> np.ndarray:
    """Convert a PIL image into a NumPy array in OpenCV (BGR) channel order.

    Only 3-channel results are channel-swapped. Grayscale, RGBA and other
    layouts are returned exactly as ``np.array`` produced them.
    """
    pixels = np.array(img)
    is_three_channel = len(pixels.shape) == 3 and pixels.shape[2] == 3
    if not is_three_channel:
        return pixels
    return pixels[:, :, ::-1].copy()
55
+
56
+
57
# Every procedure prompt is "<shared prefix><procedure description><shared suffix>";
# only the anatomy-focused middle part differs between procedures.
_PROMPT_PREFIX = "clinical photograph, patient face, "
_PROMPT_SUFFIX = (
    "realistic skin pores and texture, sharp focus, studio lighting, "
    "DSLR quality, natural skin color"
)
_PROCEDURE_DESCRIPTIONS: dict[str, str] = {
    "rhinoplasty": "natural refined nose, smooth nasal bridge, ",
    "blepharoplasty": "natural eyelids, smooth periorbital area, ",
    "rhytidectomy": "defined jawline, smooth facial contour, ",
    "orthognathic": "balanced jaw and chin proportions, ",
    "brow_lift": "elevated brow position, smooth forehead, ",
    "mentoplasty": "refined chin contour, balanced lower face, ",
}

# Positive text prompt for each supported procedure, keyed by procedure name.
PROCEDURE_PROMPTS: dict[str, str] = {
    name: f"{_PROMPT_PREFIX}{description}{_PROMPT_SUFFIX}"
    for name, description in _PROCEDURE_DESCRIPTIONS.items()
}

# Negative prompt shared by every diffusion mode.
NEGATIVE_PROMPT = ", ".join(
    (
        # non-photographic styles
        "painting", "drawing", "illustration", "cartoon", "anime", "render", "3d", "cgi",
        # anatomy / quality failures
        "blurry", "distorted", "deformed", "disfigured", "bad anatomy", "bad proportions",
        "extra limbs", "mutated", "poorly drawn face", "ugly", "low quality", "low resolution",
        # overlays and artifacts
        "watermark", "text", "signature", "duplicate", "artifact", "noise", "overexposed",
        # over-smoothed or over-processed skin
        "plastic skin", "waxy", "smooth skin", "airbrushed", "oversaturated",
    )
)

# --- Skin-tone matching ------------------------------------------------------
# Mask alpha above which a pixel contributes to the LAB statistics transfer.
_SKIN_TONE_MASK_THRESHOLD = 0.3
# Added to standard deviations so the normalisation never divides by zero.
_STD_EPSILON = 1e-6

# --- Resolution / intensity --------------------------------------------------
# Working resolution: every pipeline resizes inputs to this square size.
_SD15_RESOLUTION = 512
# UI intensity (0-100) divided by this gives displacement-model scale (0-2).
_INTENSITY_UI_TO_MODEL = 50.0

# --- Face-view classification (degrees) --------------------------------------
_YAW_FRONTAL_MAX = 15
_YAW_THREE_QUARTER_MAX = 45
_YAW_WARNING_THRESHOLD = 30
# Multiplier turning the pitch ratio (in [-1, 1]) into pseudo-degrees.
_PITCH_SCALE = 45
112
+
113
+
114
def mask_composite(
    warped: np.ndarray,
    original: np.ndarray,
    mask: np.ndarray,
    use_laplacian: bool = True,
) -> np.ndarray:
    """Composite ``warped`` into ``original`` using ONLY the mask region.

    The warped image is first colour-matched to the original inside the mask
    (per-channel LAB mean/std transfer, see ``_match_skin_tone``). Blending
    then uses ``landmarkdiff.postprocess.laplacian_pyramid_blend`` when
    ``use_laplacian`` is true. If that import or call raises, a plain
    per-pixel alpha blend is used instead.

    Args:
        warped: Image to paste in. Assumed to be uint8 BGR with the same
            height/width as ``original`` (not checked here).
        original: Background image.
        mask: Blend weights, either in [0, 1] or on a 0-255 scale. The 0-255
            scale is detected by ``max() > 1``.
        use_laplacian: Try Laplacian pyramid blending before the alpha blend.

    Returns:
        The composited image. It is uint8 on the alpha-blend path; on the
        Laplacian path it is whatever ``laplacian_pyramid_blend`` returns.
    """
    mask_f = mask.astype(np.float32)
    if mask_f.max() > 1.0:
        mask_f = mask_f / 255.0
    # Keep the blend convex even if the mask carries stray out-of-range values.
    mask_f = np.clip(mask_f, 0.0, 1.0)

    # Match colour of the warped region to the original skin tone in LAB space.
    corrected = _match_skin_tone(warped, original, mask_f)

    if use_laplacian:
        try:
            from landmarkdiff.postprocess import laplacian_pyramid_blend

            return laplacian_pyramid_blend(corrected, original, mask_f)
        except Exception:
            # Best-effort: any failure degrades to the alpha blend below.
            logger.debug("Laplacian blend failed, using alpha blend", exc_info=True)

    # Fallback: simple alpha blend.
    mask_3ch = mask_to_3channel(mask_f)
    blended = corrected.astype(np.float32) * mask_3ch + original.astype(np.float32) * (
        1.0 - mask_3ch
    )
    # Round (rather than truncate) so the result is not biased half a level dark.
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
148
+
149
+
150
def _match_skin_tone(source: np.ndarray, target: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Match ``source`` colour statistics to ``target`` inside ``mask``.

    Works in OpenCV's 8-bit LAB space. For each of L, a and b, the source
    values at masked pixels are shifted and scaled so their mean and standard
    deviation equal the target's over the same pixels. Pixels at or below
    ``_SKIN_TONE_MASK_THRESHOLD`` are returned unchanged from ``source``.

    Args:
        source: Image to recolour (expected uint8 BGR).
        target: Reference image (expected uint8 BGR, same size as ``source``).
        mask: 2-D weight map. Pixels with weight above
            ``_SKIN_TONE_MASK_THRESHOLD`` define both the statistics and the
            recoloured region.

    Returns:
        The recoloured uint8 BGR image, or ``source`` itself when no pixel
        passes the threshold.
    """
    mask_bool = mask > _SKIN_TONE_MASK_THRESHOLD
    if not np.any(mask_bool):
        return source

    src_lab = cv2.cvtColor(source, cv2.COLOR_BGR2LAB).astype(np.float32)
    tgt_lab = cv2.cvtColor(target, cv2.COLOR_BGR2LAB).astype(np.float32)

    # Match each LAB channel's statistics in the mask region.
    for ch in range(3):
        src_vals = src_lab[..., ch][mask_bool]
        tgt_vals = tgt_lab[..., ch][mask_bool]

        src_mean, src_std = np.mean(src_vals), np.std(src_vals) + _STD_EPSILON
        tgt_mean, tgt_std = np.mean(tgt_vals), np.std(tgt_vals) + _STD_EPSILON

        # src_lab[..., ch] is a view, so this writes back into src_lab.
        src_lab[..., ch][mask_bool] = (src_vals - src_mean) * (tgt_std / src_std) + tgt_mean

    # Round before the uint8 cast: truncation would bias a/b (centred at 128)
    # and L downward.
    matched_lab = np.clip(np.rint(src_lab), 0, 255).astype(np.uint8)
    result = cv2.cvtColor(matched_lab, cv2.COLOR_LAB2BGR)
    # The 8-bit BGR->LAB->BGR round trip is lossy; restore untouched pixels exactly.
    result[~mask_bool] = source[~mask_bool]
    return result
180
+
181
+
182
class LandmarkDiffPipeline:
    """End-to-end pipeline: image -> landmarks -> manipulate -> generate.

    Modes:
        - 'controlnet': CrucibleAI/ControlNetMediaPipeFace + SD1.5 (30 steps)
        - 'controlnet_fast': ControlNet + LCM-LoRA (4 steps, CPU-viable)
        - 'controlnet_ip': ControlNet + IP-Adapter for identity preservation
        - 'img2img': SD1.5 img2img with mask compositing
        - 'tps': Pure geometric TPS warp (no diffusion, instant)

    Any other ``mode`` string behaves like ``'img2img'``; a warning is logged
    when the pipeline is constructed. Call :meth:`load` once, then
    :meth:`generate` per image.
    """

    # Modes backed by a ControlNet diffusion pipeline.
    _CONTROLNET_MODES = ("controlnet", "controlnet_ip", "controlnet_fast")
    # Every mode that load()/generate() handle explicitly.
    _SUPPORTED_MODES = ("controlnet", "controlnet_fast", "controlnet_ip", "img2img", "tps")

    # Default IP-Adapter model for SD1.5 face identity
    IP_ADAPTER_REPO = "h94/IP-Adapter"
    IP_ADAPTER_SUBFOLDER = "models"
    IP_ADAPTER_WEIGHT_NAME = "ip-adapter-plus-face_sd15.bin"
    IP_ADAPTER_SCALE_DEFAULT = 0.6

    # LCM-LoRA for fast inference (2-4 steps instead of 30)
    LCM_LORA_REPO = "latent-consistency/lcm-lora-sdv1-5"

    # Base SD1.5 weights used when no base_model_id is supplied.
    _DEFAULT_BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"

    def __init__(
        self,
        mode: str = "img2img",
        controlnet_id: str = "CrucibleAI/ControlNetMediaPipeFace",
        controlnet_checkpoint: str | None = None,
        base_model_id: str | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        ip_adapter_scale: float = IP_ADAPTER_SCALE_DEFAULT,
        clinical_flags: ClinicalFlags | None = None,
        displacement_model_path: str | None = None,
    ):
        """Configure the pipeline. No diffusion weights are loaded here.

        Args:
            mode: One of the modes listed on the class.
            controlnet_id: Hub repo of the ControlNet; its ``diffusion_sd15``
                subfolder is loaded.
            controlnet_checkpoint: Local fine-tuned ControlNet directory. It
                takes precedence over ``controlnet_id``, and a
                ``controlnet_ema`` sub-directory is preferred when present.
            base_model_id: SD1.5 base model (default
                ``runwayml/stable-diffusion-v1-5``).
            device: Torch device; ``get_device()`` is used when omitted.
            dtype: Weight dtype. It is ignored on MPS, which always uses fp32.
                Otherwise the default is fp16 on CUDA and fp32 elsewhere.
            ip_adapter_scale: IP-Adapter strength for ``controlnet_ip`` mode.
            clinical_flags: Default flags used when ``generate`` gets none.
            displacement_model_path: Optional data-driven displacement model.
                Load failures are logged and preset manipulation is used.
        """
        if mode not in self._SUPPORTED_MODES:
            # load()/generate() fall through to img2img for unknown modes;
            # surface that instead of letting a typo go unnoticed.
            logger.warning("Unknown mode %r; it will behave like 'img2img'", mode)
        self.mode = mode
        self.device = device if device is not None else get_device()
        self.ip_adapter_scale = ip_adapter_scale
        self.clinical_flags = clinical_flags
        self.controlnet_checkpoint = controlnet_checkpoint

        # Load displacement model for data-driven manipulation (best-effort).
        self._displacement_model = None
        if displacement_model_path:
            try:
                from landmarkdiff.displacement_model import DisplacementModel

                self._displacement_model = DisplacementModel.load(displacement_model_path)
                logger.info("Displacement model loaded: %s", self._displacement_model.procedures)
            except Exception as e:
                logger.warning("Failed to load displacement model: %s", e)

        if self.device.type == "mps":
            # MPS always runs in fp32, regardless of the requested dtype
            # (presumably to avoid fp16 numerical problems there -- confirm).
            self.dtype = torch.float32
        elif dtype is not None:
            self.dtype = dtype
        else:
            self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32

        self.base_model_id = base_model_id or self._DEFAULT_BASE_MODEL_ID
        self.controlnet_id = controlnet_id
        self._pipe = None
        self._ip_adapter_loaded = False
        self._lcm_loaded = False

    @staticmethod
    def _hf_load_kwargs() -> dict:
        """Return extra hub-loading kwargs; ``HF_HUB_OFFLINE=1`` forces local files."""
        if os.environ.get("HF_HUB_OFFLINE", "0") == "1":
            return {"local_files_only": True}
        return {}

    def load(self) -> None:
        """Load the weights needed by ``self.mode`` (no-op for ``'tps'``)."""
        if self.mode == "tps":
            logger.info("TPS mode -- no model to load")
            return
        if self.mode in self._CONTROLNET_MODES:
            self._load_controlnet()
            if self.mode == "controlnet_fast":
                self._load_lcm_lora()
            elif self.mode == "controlnet_ip":
                self._load_ip_adapter()
        else:
            self._load_img2img()

    def _load_controlnet(self) -> None:
        """Build the SD1.5 ControlNet pipeline (DPM++ 2M Karras scheduler)."""
        from diffusers import (
            ControlNetModel,
            DPMSolverMultistepScheduler,
            StableDiffusionControlNetPipeline,
        )

        _kw = self._hf_load_kwargs()

        if self.controlnet_checkpoint:
            # Load fine-tuned ControlNet from a local checkpoint; support both a
            # direct path and the training-checkpoint layout.
            ckpt_path = Path(self.controlnet_checkpoint)
            if (ckpt_path / "controlnet_ema").exists():
                ckpt_path = ckpt_path / "controlnet_ema"
            logger.info("Loading fine-tuned ControlNet from %s", ckpt_path)
            controlnet = ControlNetModel.from_pretrained(
                str(ckpt_path),
                torch_dtype=self.dtype,
            )
        else:
            logger.info("Loading ControlNet from %s", self.controlnet_id)
            controlnet = ControlNetModel.from_pretrained(
                self.controlnet_id,
                subfolder="diffusion_sd15",
                torch_dtype=self.dtype,
                **_kw,
            )
        logger.info("Loading base model from %s", self.base_model_id)
        self._pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.base_model_id,
            controlnet=controlnet,
            torch_dtype=self.dtype,
            safety_checker=None,
            requires_safety_checker=False,
            **_kw,
        )
        # DPM++ 2M Karras -- produces more photorealistic output than UniPC
        self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self._pipe.scheduler.config,
            algorithm_type="dpmsolver++",
            use_karras_sigmas=True,
        )
        # FP32 VAE decode -- prevents color banding artifacts on skin tones
        if hasattr(self._pipe, "vae") and self._pipe.vae is not None:
            self._pipe.vae.config.force_upcast = True
        self._apply_device_optimizations()

    def _load_lcm_lora(self) -> None:
        """Load LCM-LoRA for fast 4-step inference.

        Loads the LoRA weights from ``LCM_LORA_REPO`` and swaps the scheduler
        for ``LCMScheduler``. On failure the standard scheduler is kept and
        ``_lcm_loaded`` stays False.

        Raises:
            RuntimeError: If the base pipeline has not been loaded yet.
        """
        if self._pipe is None:
            raise RuntimeError("Base pipeline must be loaded before LCM-LoRA")
        try:
            from diffusers import LCMScheduler

            logger.info("Loading LCM-LoRA from %s", self.LCM_LORA_REPO)
            self._pipe.load_lora_weights(self.LCM_LORA_REPO, **self._hf_load_kwargs())
            self._pipe.scheduler = LCMScheduler.from_config(self._pipe.scheduler.config)
            self._lcm_loaded = True
            logger.info("LCM-LoRA loaded -- 4-step inference enabled")
        except Exception as e:
            logger.warning("LCM-LoRA load failed: %s", e)
            logger.warning("Falling back to standard scheduler (30 steps)")
            self._lcm_loaded = False

    def _load_ip_adapter(self) -> None:
        """Load IP-Adapter for identity-preserving generation.

        Loads ``IP_ADAPTER_WEIGHT_NAME`` (the SD1.5 "plus-face" adapter) from
        ``IP_ADAPTER_REPO``/``IP_ADAPTER_SUBFOLDER`` and applies
        ``ip_adapter_scale``. On failure, generation continues ControlNet-only.

        Raises:
            RuntimeError: If the base pipeline has not been loaded yet.
        """
        if self._pipe is None:
            raise RuntimeError("Base pipeline must be loaded before IP-Adapter")
        try:
            logger.info("Loading IP-Adapter (%s)", self.IP_ADAPTER_WEIGHT_NAME)
            self._pipe.load_ip_adapter(
                self.IP_ADAPTER_REPO,
                subfolder=self.IP_ADAPTER_SUBFOLDER,
                weight_name=self.IP_ADAPTER_WEIGHT_NAME,
                **self._hf_load_kwargs(),
            )
            self._pipe.set_ip_adapter_scale(self.ip_adapter_scale)
            self._ip_adapter_loaded = True
            logger.info("IP-Adapter loaded (scale=%s)", self.ip_adapter_scale)
        except Exception as e:
            logger.warning("IP-Adapter load failed: %s", e)
            logger.warning("Falling back to ControlNet-only mode")
            self._ip_adapter_loaded = False

    def _load_img2img(self) -> None:
        """Build the SD1.5 img2img pipeline with a DPM-Solver scheduler."""
        from diffusers import (
            DPMSolverMultistepScheduler,
            StableDiffusionImg2ImgPipeline,
        )

        logger.info("Loading SD1.5 img2img from %s", self.base_model_id)
        self._pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            self.base_model_id,
            torch_dtype=self.dtype,
            safety_checker=None,
            requires_safety_checker=False,
            **self._hf_load_kwargs(),
        )
        self._pipe.scheduler = DPMSolverMultistepScheduler.from_config(self._pipe.scheduler.config)
        self._apply_device_optimizations()

    def _apply_device_optimizations(self) -> None:
        """Place the loaded diffusers pipeline on ``self.device``."""
        if self.device.type == "mps":
            self._pipe = self._pipe.to(self.device)
            self._pipe.enable_attention_slicing()
        elif self.device.type == "cuda":
            # Honour an explicit device index (e.g. cuda:1) for offloading.
            offload_kw = {} if self.device.index is None else {"gpu_id": self.device.index}
            try:
                self._pipe.enable_model_cpu_offload(**offload_kw)
            except Exception:
                logger.debug(
                    "Model CPU offload unavailable; moving pipeline to %s",
                    self.device,
                    exc_info=True,
                )
                self._pipe = self._pipe.to(self.device)
        else:
            # CPU-offload hooks shuttle weights onto an accelerator; with no
            # accelerator they cannot work, so just keep the pipeline on-device.
            self._pipe = self._pipe.to(self.device)
        logger.info("Pipeline loaded on %s (%s)", self.device, self.dtype)

    @property
    def is_loaded(self) -> bool:
        """True once ``load()`` has built a pipeline, or always in ``'tps'`` mode."""
        return self._pipe is not None or self.mode == "tps"

    def generate(
        self,
        image: np.ndarray,
        procedure: str = "rhinoplasty",
        intensity: float = 50.0,
        num_inference_steps: int = 30,
        guidance_scale: float = 9.0,
        controlnet_conditioning_scale: float = 0.9,
        strength: float = 0.5,
        seed: int | None = None,
        clinical_flags: ClinicalFlags | None = None,
        postprocess: bool = True,
        use_gfpgan: bool = False,
    ) -> dict:
        """Predict a post-procedure appearance for ``image``.

        The input is resized to the working resolution. The pipeline then
        extracts landmarks, manipulates them (displacement model when
        available, otherwise a preset), TPS-warps the image, and optionally
        refines it with the configured diffusion model and post-processing.

        Args:
            image: Input image in OpenCV channel order (assumed uint8 BGR).
            procedure: Procedure key, e.g. ``"rhinoplasty"``.
            intensity: UI intensity; 50 maps to 1.0x mean displacement.
            num_inference_steps: Diffusion steps (capped at 4 when LCM is active).
            guidance_scale: CFG scale (capped at 1.5 when LCM is active).
            controlnet_conditioning_scale: ControlNet strength (ControlNet modes).
            strength: img2img denoising strength (img2img mode).
            seed: Seeds both landmark noise and the diffusion generator.
            clinical_flags: Overrides the pipeline's default flags.
            postprocess: Run ``full_postprocess`` (ignored in ``'tps'`` mode).
            use_gfpgan: Despite the name, selects ``restore_mode="codeformer"``
                and enables Real-ESRGAN in post-processing.

        Returns:
            Dict with keys ``output``, ``output_raw``, ``output_tps``,
            ``input``, ``landmarks_original``, ``landmarks_manipulated``,
            ``conditioning``, ``mask``, ``procedure``, ``intensity``,
            ``device``, ``mode``, ``view_info``, ``ip_adapter_active``,
            ``lcm_active``, ``identity_check``, ``restore_used`` and
            ``manipulation_mode``.

        Raises:
            RuntimeError: If ``load()`` was not called, or CUDA ran out of memory.
            ValueError: If no face is detected.
        """
        if not self.is_loaded:
            raise RuntimeError("Pipeline not loaded. Call .load() first.")

        flags = clinical_flags or self.clinical_flags
        res = _SD15_RESOLUTION
        image_512 = cv2.resize(image, (res, res))

        face = extract_landmarks(image_512)
        if face is None:
            raise ValueError("No face detected in image.")

        # Estimate face view angle for multi-view awareness.
        view_info = estimate_face_view(face)

        manipulated, manipulation_mode = self._manipulate_landmarks(
            face, procedure, intensity, res, seed, flags
        )
        landmark_img = render_landmark_image(manipulated, res, res)
        mask = generate_surgical_mask(
            face,
            procedure,
            res,
            res,
            clinical_flags=flags,
        )

        generator = None
        if seed is not None:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        prompt = PROCEDURE_PROMPTS.get(procedure, "a photo of a person's face")

        # Step 1: TPS geometric warp (always computed -- the geometric baseline).
        tps_warped = warp_image_tps(image_512, face.pixel_coords, manipulated.pixel_coords)

        if self.mode == "tps":
            raw_output = tps_warped
        else:
            try:
                if self.mode in self._CONTROLNET_MODES:
                    if self._lcm_loaded:
                        # LCM works best with very few steps and cfg around 1-2.
                        num_inference_steps = min(num_inference_steps, 4)
                        guidance_scale = min(guidance_scale, 1.5)
                    ip_image = numpy_to_pil(image_512) if self._ip_adapter_loaded else None
                    raw_output = self._generate_controlnet(
                        image_512,
                        landmark_img,
                        prompt,
                        num_inference_steps,
                        guidance_scale,
                        controlnet_conditioning_scale,
                        generator,
                        ip_adapter_image=ip_image,
                    )
                else:
                    # img2img refines the TPS warp rather than the raw input.
                    raw_output = self._generate_img2img(
                        tps_warped,
                        mask,
                        prompt,
                        num_inference_steps,
                        guidance_scale,
                        strength,
                        generator,
                    )
            except torch.cuda.OutOfMemoryError as exc:
                torch.cuda.empty_cache()
                raise RuntimeError(
                    "GPU out of memory during inference. Try reducing "
                    "num_inference_steps or switching to mode='tps' for CPU-only."
                ) from exc

        # Step 2: Post-processing for photorealism (neural + classical pipeline).
        identity_check = None
        restore_used = "none"
        if postprocess and self.mode != "tps":
            from landmarkdiff.postprocess import full_postprocess

            pp_result = full_postprocess(
                generated=raw_output,
                original=image_512,
                mask=mask,
                restore_mode="codeformer" if use_gfpgan else "none",
                use_realesrgan=use_gfpgan,
                use_laplacian_blend=True,
                sharpen_strength=0.25,
                verify_identity=True,
            )
            composited = pp_result["image"]
            identity_check = pp_result["identity_check"]
            restore_used = pp_result["restore_used"]
        else:
            composited = mask_composite(raw_output, image_512, mask)

        return {
            "output": composited,
            "output_raw": raw_output,
            "output_tps": tps_warped,
            "input": image_512,
            "landmarks_original": face,
            "landmarks_manipulated": manipulated,
            "conditioning": landmark_img,
            "mask": mask,
            "procedure": procedure,
            "intensity": intensity,
            "device": str(self.device),
            "mode": self.mode,
            "view_info": view_info,
            "ip_adapter_active": self._ip_adapter_loaded,
            "lcm_active": self._lcm_loaded,
            "identity_check": identity_check,
            "restore_used": restore_used,
            "manipulation_mode": manipulation_mode,
        }

    def _manipulate_landmarks(
        self,
        face: FaceLandmarks,
        procedure: str,
        intensity: float,
        res: int,
        seed: int | None,
        flags: ClinicalFlags | None,
    ) -> tuple[FaceLandmarks, str]:
        """Return ``(manipulated_landmarks, manipulation_mode)`` for ``face``.

        Uses the displacement model when it covers ``procedure``. If it is
        absent or raises, falls back to ``apply_procedure_preset``.
        """
        model = self._displacement_model
        if model is not None and procedure in model.procedures:
            try:
                # default_rng(None) draws fresh OS entropy, same as default_rng().
                rng = np.random.default_rng(seed)
                # Map UI intensity (0-100) to displacement model intensity (0-2).
                dm_intensity = intensity / _INTENSITY_UI_TO_MODEL  # 50 -> 1.0x mean displacement
                displacement = model.get_displacement_field(
                    procedure,
                    intensity=dm_intensity,
                    noise_scale=0.3,
                    rng=rng,
                )
                new_lm = face.landmarks.copy()
                n = min(len(new_lm), len(displacement))
                new_lm[:n, 0] += displacement[:n, 0]
                new_lm[:n, 1] += displacement[:n, 1]
                # Keep x/y inside the image with a small margin.
                new_lm[:, 0] = np.clip(new_lm[:, 0], 0.01, 0.99)
                new_lm[:, 1] = np.clip(new_lm[:, 1], 0.01, 0.99)
                manipulated = FaceLandmarks(
                    landmarks=new_lm,
                    image_width=res,
                    image_height=res,
                    confidence=face.confidence,
                )
                return manipulated, "displacement_model"
            except Exception as exc:
                logger.warning("Displacement model failed, falling back to preset: %s", exc)

        manipulated = apply_procedure_preset(
            face,
            procedure,
            intensity,
            image_size=res,
            clinical_flags=flags,
        )
        return manipulated, "preset"

    def _generate_controlnet(
        self,
        image: np.ndarray,
        conditioning: np.ndarray,
        prompt: str,
        steps: int,
        cfg: float,
        cn_scale: float,
        generator: torch.Generator | None,
        ip_adapter_image: Image.Image | None = None,
    ) -> np.ndarray:
        """Run the ControlNet pipeline on the landmark rendering.

        ``image`` is accepted for signature symmetry but not used. Only
        ``conditioning`` (and the optional IP-Adapter image) reach the model.
        """
        kwargs = dict(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            image=numpy_to_pil(conditioning),  # control conditioning only
            num_inference_steps=steps,
            guidance_scale=cfg,
            controlnet_conditioning_scale=cn_scale,
            generator=generator,
        )
        if ip_adapter_image is not None and self._ip_adapter_loaded:
            kwargs["ip_adapter_image"] = ip_adapter_image
        result = self._pipe(**kwargs)
        return pil_to_numpy(result.images[0])

    def _generate_img2img(
        self,
        image: np.ndarray,
        mask: np.ndarray,
        prompt: str,
        steps: int,
        cfg: float,
        strength: float,
        generator: torch.Generator | None,
    ) -> np.ndarray:
        """Run SD1.5 img2img on ``image``.

        ``mask`` is not passed to the model; compositing happens afterwards.
        """
        result = self._pipe(
            prompt=prompt,
            negative_prompt=NEGATIVE_PROMPT,
            image=numpy_to_pil(image),
            num_inference_steps=steps,
            guidance_scale=cfg,
            strength=strength,
            generator=generator,
        )
        return pil_to_numpy(result.images[0])
615
+
616
+
617
def estimate_face_view(face: FaceLandmarks) -> dict:
    """Roughly estimate head orientation from a handful of landmarks.

    Yaw comes from the asymmetry between the nose-tip-to-ear distances
    (indices 1, 234 and 454). Pitch comes from the forehead-to-nose versus
    nose-to-chin lengths (indices 10, 1 and 152), linearly scaled by
    ``_PITCH_SCALE``, so it is a heuristic rather than a calibrated angle.

    Returns:
        Dict with ``yaw`` and ``pitch`` (degrees, rounded to 0.1), ``view``
        (``"frontal"``, ``"three_quarter"`` or ``"profile"``), ``is_frontal``
        and ``warning`` (a message when |yaw| exceeds
        ``_YAW_WARNING_THRESHOLD``, else ``None``).
    """
    pts = face.pixel_coords
    # MediaPipe landmark indices for key anatomical points.
    nose = pts[1]  # nose tip
    ear_l = pts[234]  # left tragion (ear)
    ear_r = pts[454]  # right tragion (ear)
    brow_mid = pts[10]  # forehead center
    chin_pt = pts[152]  # chin center

    # Yaw: a centred nose gives equal ear distances and therefore 0 degrees.
    d_left = np.linalg.norm(nose - ear_l)
    d_right = np.linalg.norm(nose - ear_r)
    span = d_left + d_right
    if span >= 1.0:
        asym = (d_right - d_left) / span
        yaw = float(np.arcsin(np.clip(asym, -1, 1)) * 180 / np.pi)
    else:
        yaw = 0.0

    # Pitch: relative lengths of the upper and lower halves of the face.
    up = np.linalg.norm(brow_mid - nose)
    down = np.linalg.norm(nose - chin_pt)
    if up + down >= 1.0:
        pitch = float((down - up) / (up + down) * _PITCH_SCALE)
    else:
        pitch = 0.0

    magnitude = abs(yaw)
    frontal = magnitude < _YAW_FRONTAL_MAX
    if frontal:
        label = "frontal"
    elif magnitude < _YAW_THREE_QUARTER_MAX:
        label = "three_quarter"
    else:
        label = "profile"

    warning = None
    if magnitude > _YAW_WARNING_THRESHOLD:
        warning = "Side-view detected: results may be less accurate"

    return {
        "yaw": round(yaw, 1),
        "pitch": round(pitch, 1),
        "view": label,
        "is_frontal": frontal,
        "warning": warning,
    }
670
+
671
+
672
def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
    """Convert a mask on either a [0, 1] or a 0-255 scale into a uint8 image."""
    mask_f = np.asarray(mask, dtype=np.float32)
    # Same scale heuristic as mask_composite: max <= 1 means a [0, 1] mask.
    # Multiplying an already 0-255 uint8 mask by 255 would wrap around.
    if mask_f.size and mask_f.max() <= 1.0:
        mask_f = mask_f * 255.0
    return np.clip(np.rint(mask_f), 0, 255).astype(np.uint8)


def run_inference(
    image_path: str,
    procedure: str = "rhinoplasty",
    intensity: float = 50.0,
    output_dir: str = "scripts/inference_output",
    seed: int = 42,
    mode: str = "img2img",
    ip_adapter_scale: float = 0.6,
    controlnet_checkpoint: str | None = None,
    displacement_model_path: str | None = None,
) -> None:
    """Run the pipeline on one image file and save the results as PNGs.

    Writes ``input.png``, ``output.png``, ``output_raw.png``,
    ``output_tps.png``, ``conditioning.png``, ``mask.png`` and
    ``comparison.png`` (input | TPS | final, side by side) into
    ``output_dir``, creating it if needed. Failed writes are logged.

    Exits the process with status 1 if ``image_path`` cannot be read.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    image = cv2.imread(image_path)
    if image is None:
        logger.error("Could not load %s", image_path)
        sys.exit(1)

    pipe = LandmarkDiffPipeline(
        mode=mode,
        ip_adapter_scale=ip_adapter_scale,
        controlnet_checkpoint=controlnet_checkpoint,
        displacement_model_path=displacement_model_path,
    )
    pipe.load()

    logger.info("Generating %s prediction (intensity=%s, mode=%s)", procedure, intensity, mode)
    result = pipe.generate(image, procedure=procedure, intensity=intensity, seed=seed)

    def _save(name: str, img: np.ndarray) -> None:
        target = out / name
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(str(target), img):
            logger.warning("Failed to write %s", target)

    _save("input.png", result["input"])
    _save("output.png", result["output"])
    _save("output_raw.png", result["output_raw"])
    _save("output_tps.png", result["output_tps"])
    _save("conditioning.png", result["conditioning"])
    _save("mask.png", _mask_to_uint8(result["mask"]))

    comparison = np.hstack([result["input"], result["output_tps"], result["output"]])
    _save("comparison.png", comparison)

    view = result.get("view_info", {})
    if view.get("warning"):
        logger.warning("%s", view["warning"])
    logger.info("Face view: %s (yaw=%s)", view.get("view", "unknown"), view.get("yaw", 0))
    logger.info("Results saved to %s/", out)
717
+
718
+
719
if __name__ == "__main__":
    import argparse

    # Command-line entry point: python -m landmarkdiff.inference IMAGE [options]
    cli = argparse.ArgumentParser(description="LandmarkDiff inference")
    cli.add_argument("image", help="Path to face image")
    cli.add_argument("--procedure", default="rhinoplasty")
    cli.add_argument("--intensity", type=float, default=50.0)
    cli.add_argument("--output", default="scripts/inference_output")
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument(
        "--mode",
        default="img2img",
        choices=["img2img", "controlnet", "controlnet_ip", "controlnet_fast", "tps"],
    )
    cli.add_argument("--ip-adapter-scale", type=float, default=0.6)
    cli.add_argument("--checkpoint", default=None, help="Path to fine-tuned ControlNet checkpoint")
    cli.add_argument(
        "--displacement-model",
        default=None,
        help="Path to displacement_model.npz for data-driven manipulation",
    )
    opts = cli.parse_args()

    run_inference(
        opts.image,
        procedure=opts.procedure,
        intensity=opts.intensity,
        output_dir=opts.output,
        seed=opts.seed,
        mode=opts.mode,
        ip_adapter_scale=opts.ip_adapter_scale,
        controlnet_checkpoint=opts.checkpoint,
        displacement_model_path=opts.displacement_model,
    )