landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,947 @@
1
+ """Neural face verification, distortion detection, and restoration pipeline.
2
+
3
+ End-to-end system that:
4
+ 1. Detects face distortions (blur, beauty filters, compression, warping, etc.)
5
+ 2. Classifies distortion type and severity using no-reference quality metrics
6
+ 3. Restores faces using cascaded neural networks (CodeFormer → GFPGAN → Real-ESRGAN)
7
+ 4. Verifies output identity matches input via ArcFace embeddings
8
+ 5. Scores output realism using learned perceptual metrics
9
+
10
+ Designed for:
11
+ - Cleaning scraped training data (reject/fix bad images before pair generation)
12
+ - Post-diffusion quality gate (ensure generated faces pass realism threshold)
13
+ - Filter removal (undo Snapchat/Instagram beauty filters for clinical use)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ from dataclasses import dataclass, field
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+ import cv2
24
+ import numpy as np
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Data structures
30
+ # ---------------------------------------------------------------------------
31
+
32
+
33
@dataclass
class DistortionReport:
    """Per-image distortion analysis: one score per detector plus a verdict."""

    # Overall quality on a 0-100 scale; higher means better.
    quality_score: float = 0.0

    # Per-detector scores on a 0-1 scale; higher means more distorted.
    blur_score: float = 0.0  # Laplacian variance-based
    noise_score: float = 0.0  # High-frequency energy ratio
    compression_score: float = 0.0  # JPEG block artifact detection
    oversmooth_score: float = 0.0  # Beauty filter / airbrushed detection
    color_cast_score: float = 0.0  # Unnatural color shift
    geometric_distort: float = 0.0  # Face proportion anomalies
    lighting_score: float = 0.0  # Over/under exposure

    # Classification of the dominant problem.
    primary_distortion: str = "none"
    severity: str = "none"  # one of: none, mild, moderate, severe
    is_usable: bool = True  # Worth restoring (True) versus rejecting (False)

    # Free-form extra data; a fresh dict is created for every instance.
    details: dict = field(default_factory=dict)

    def summary(self) -> str:
        """Return the report as a newline-separated, human-readable string."""
        headline = (
            f"Quality Score: {self.quality_score:.1f}/100",
            f"Primary Issue: {self.primary_distortion} ({self.severity})",
            f"Usable: {self.is_usable}",
            "",
            "Distortion Breakdown:",
        )
        per_detector = (
            ("Blur", self.blur_score),
            ("Noise", self.noise_score),
            ("Compression", self.compression_score),
            ("Oversmooth", self.oversmooth_score),
            ("Color Cast", self.color_cast_score),
            ("Geometric", self.geometric_distort),
            ("Lighting", self.lighting_score),
        )
        rows = [f" {label}: {value:.3f}" for label, value in per_detector]
        return "\n".join([*headline, *rows])
73
+
74
+
75
@dataclass
class RestorationResult:
    """Outcome of running the face restoration pipeline on one image."""

    restored: np.ndarray  # Restored BGR image
    original: np.ndarray  # Original BGR image
    distortion_report: DistortionReport  # Pre-restoration analysis
    post_quality_score: float = 0.0  # Quality after restoration
    identity_similarity: float = 0.0  # ArcFace cosine sim (original vs restored)
    identity_preserved: bool = True  # Whether identity check passed
    restoration_stages: list[str] = field(default_factory=list)  # Which nets ran
    improvement: float = 0.0  # quality_after - quality_before (may be negative)

    def summary(self) -> str:
        """Return the result as a newline-separated, human-readable string.

        The improvement is printed with an explicit sign, so a quality drop
        reads ``-3.5`` rather than the malformed ``+-3.5``.
        """
        lines = [
            f"Pre-restoration: {self.distortion_report.quality_score:.1f}/100",
            f"Post-restoration: {self.post_quality_score:.1f}/100",
            # The "+" format flag emits "+" or "-" as appropriate.
            f"Improvement: {self.improvement:+.1f}",
            f"Identity Sim: {self.identity_similarity:.3f}",
            f"Identity OK: {self.identity_preserved}",
            f"Stages Used: {' → '.join(self.restoration_stages) or 'none'}",
        ]
        return "\n".join(lines)
98
+
99
+
100
@dataclass
class BatchVerificationReport:
    """Aggregate counters and averages for a batch verification run."""

    total: int = 0
    passed: int = 0  # Good quality, no fix needed
    restored: int = 0  # Fixed and now usable
    rejected: int = 0  # Too distorted to salvage
    identity_failures: int = 0  # Restoration changed identity
    avg_quality_before: float = 0.0
    avg_quality_after: float = 0.0
    avg_identity_sim: float = 0.0
    # Primary distortion label -> number of images; fresh dict per instance.
    distortion_counts: dict[str, int] = field(default_factory=dict)

    def summary(self) -> str:
        """Return the batch statistics as a newline-separated string.

        Distortion types are listed most frequent first; types with equal
        counts keep their insertion order.
        """
        by_frequency = sorted(
            self.distortion_counts.items(),
            key=lambda entry: entry[1],
            reverse=True,
        )
        text = [
            f"Total Images: {self.total}",
            f" Passed (good): {self.passed}",
            f" Restored: {self.restored}",
            f" Rejected: {self.rejected}",
            f" Identity Fail: {self.identity_failures}",
            f"Avg Quality Before: {self.avg_quality_before:.1f}",
            f"Avg Quality After: {self.avg_quality_after:.1f}",
            f"Avg Identity Sim: {self.avg_identity_sim:.3f}",
            "",
            "Distortion Breakdown:",
        ]
        text.extend(f" {name}: {occurrences}" for name, occurrences in by_frequency)
        return "\n".join(text)
133
+
134
+
135
+ # ---------------------------------------------------------------------------
136
+ # Distortion Detection (classical + neural)
137
+ # ---------------------------------------------------------------------------
138
+
139
+
140
def detect_blur(image: np.ndarray) -> float:
    """Estimate how blurry an image is on a 0-1 scale (1 = very blurry).

    Two sharpness cues are combined: the variance of the Laplacian response
    (weight 0.6) and the mean Sobel gradient magnitude (weight 0.4). Each cue
    is normalised against a fixed reference value and inverted, so a weak
    response yields a high blur score.

    Args:
        image: A 3-dimensional array (converted with ``cv2.COLOR_BGR2GRAY``)
            or an already single-channel image, which is used as-is.

    Returns:
        Blur score clipped to [0, 1].
    """
    if image.ndim == 3:
        mono = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        mono = image

    # Primary cue: spread of the second-derivative response.
    laplacian_variance = cv2.Laplacian(mono, cv2.CV_64F).var()

    # Secondary cue: average first-derivative magnitude.
    dx = cv2.Sobel(mono, cv2.CV_64F, 1, 0, ksize=3)
    dy = cv2.Sobel(mono, cv2.CV_64F, 0, 1, ksize=3)
    mean_gradient = np.sqrt(dx**2 + dy**2).mean()

    # A Laplacian variance of 800+ or a mean gradient of 50+ counts as
    # fully sharp for the respective cue.
    lap_blur = 1.0 - min(laplacian_variance / 800.0, 1.0)
    grad_blur = 1.0 - min(mean_gradient / 50.0, 1.0)

    combined = 0.6 * lap_blur + 0.4 * grad_blur
    return float(np.clip(combined, 0, 1))
161
+
162
+
163
def detect_noise(image: np.ndarray) -> float:
    """Estimate the image noise level on a 0-1 scale (1 = very noisy).

    The noise standard deviation is approximated from the median absolute
    Laplacian response, scaled by 1.4826 (the usual MAD-to-sigma factor).

    Args:
        image: A 3-dimensional array (converted with ``cv2.COLOR_BGR2GRAY``)
            or an already single-channel image.

    Returns:
        Estimated sigma divided by 25 and clipped to [0, 1].
    """
    if image.ndim == 3:
        mono = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        mono = image

    response = cv2.Laplacian(mono.astype(np.float64), cv2.CV_64F)
    median_abs = np.median(np.abs(response))
    sigma = median_abs * 1.4826  # robust estimate of the standard deviation

    # A sigma of 25 or more saturates the score.
    normalised = sigma / 25.0
    return float(np.clip(normalised, 0, 1))
177
+
178
+
179
def detect_compression_artifacts(image: np.ndarray) -> float:
    """Score JPEG-style 8x8 blocking artifacts on a 0-1 scale.

    Compares the mean absolute neighbour difference sampled at every eighth
    column/row boundary with the mean over all positions. A ratio above 1
    means the block grid stands out from the rest of the image.

    Args:
        image: A 3-dimensional array (converted with ``cv2.COLOR_BGR2GRAY``)
            or an already single-channel image.

    Returns:
        0.0 for images smaller than 16 pixels on either side or essentially
        flat images; otherwise ``(ratio - 1) / 0.8`` clipped to [0, 1].
    """
    if image.ndim == 3:
        mono = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        mono = image
    rows, cols = mono.shape

    if rows < 16 or cols < 16:
        return 0.0

    pixels = mono.astype(np.float64)

    # Absolute differences between horizontally / vertically adjacent pixels.
    horizontal = np.abs(np.diff(pixels, axis=1))
    vertical = np.abs(np.diff(pixels, axis=0))

    # Differences that straddle an 8-pixel block edge.
    horizontal_edges = horizontal[:, 7::8]
    vertical_edges = vertical[7::8, :]
    h_boundary = horizontal_edges.mean() if horizontal_edges.size > 0 else 0
    v_boundary = vertical_edges.mean() if vertical_edges.size > 0 else 0

    h_overall = horizontal.mean()
    v_overall = vertical.mean()
    if h_overall < 1e-6 or v_overall < 1e-6:
        # Flat image: there is no meaningful ratio to compute.
        return 0.0

    # Boundary-to-overall energy ratio per direction, then averaged.
    h_ratio = h_boundary / (h_overall + 1e-6)
    v_ratio = v_boundary / (v_overall + 1e-6)
    mean_ratio = (h_ratio + v_ratio) / 2.0

    # A ratio of 1.8 or more saturates the score.
    return float(np.clip((mean_ratio - 1.0) / 0.8, 0, 1))
213
+
214
+
215
def detect_oversmoothing(image: np.ndarray) -> float:
    """Score beauty-filter style oversmoothing on a 0-1 scale.

    Texture energy is the variance of a Gaussian high-pass of the central
    region; low texture gives a high score. When Canny edges are still
    present (density above 0.02) the score is boosted by 1.3, since edges
    without texture are characteristic of a smoothing filter.

    Args:
        image: A 3-dimensional array (converted with ``cv2.COLOR_BGR2GRAY``)
            or an already single-channel image.

    Returns:
        0.0 for images smaller than 8 pixels on either side, otherwise the
        oversmoothing score clipped to [0, 1].
    """
    if image.ndim == 3:
        mono = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        mono = image
    rows, cols = mono.shape

    if rows < 8 or cols < 8:
        return 0.0  # Too small to analyze

    # Central half of the image in each dimension, to avoid background.
    center = mono[rows // 4 : 3 * rows // 4, cols // 4 : 3 * cols // 4]
    center_f = center.astype(np.float64)

    # Texture: what is left after removing the low-frequency component.
    low_freq = cv2.GaussianBlur(center_f, (0, 0), 2.0)
    texture_energy = np.var(center_f - low_freq)

    # Edge density: fraction of pixels flagged by Canny.
    edge_map = cv2.Canny(center, 50, 150)
    edge_density = np.mean(edge_map > 0)

    # Texture energy of 30 or more counts as fully textured.
    score = 1.0 - min(texture_energy / 30.0, 1.0)
    if edge_density > 0.02:
        score = score * 1.3

    return float(np.clip(score, 0, 1))
248
+
249
+
250
def detect_color_cast(image: np.ndarray) -> float:
    """Score unnatural colour casts on a 0-1 scale using LAB statistics.

    Looks at the central region of the image in OpenCV's 8-bit LAB encoding,
    where 128 is the neutral value of the A and B channels. The score mixes
    the mean offset of A/B from neutral with a penalty for an unusually
    narrow A/B spread.

    Args:
        image: BGR image accepted by ``cv2.COLOR_BGR2LAB``.

    Returns:
        Colour cast score clipped to [0, 1].
    """
    rows, cols = image.shape[:2]
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # Central half of the image in each dimension.
    center = lab[rows // 4 : 3 * rows // 4, cols // 4 : 3 * cols // 4]
    a_plane = center[:, :, 1]  # green-red axis
    b_plane = center[:, :, 2]  # blue-yellow axis

    # Relative offset of each channel mean from neutral (128).
    a_offset = abs(a_plane.mean() - 128) / 128.0
    b_offset = abs(b_plane.mean() - 128) / 128.0

    # A combined A/B standard deviation below 30 is treated as suspiciously
    # uniform colour.
    spread = a_plane.std() + b_plane.std()
    narrowness = max(0, 1.0 - spread / 30.0)

    total = 0.5 * (a_offset + b_offset) + 0.3 * narrowness
    return float(np.clip(total, 0, 1))
279
+
280
+
281
def detect_geometric_distortion(image: np.ndarray) -> float:
    """Score geometric face distortion (warping filters, lens distortion).

    Face landmarks from ``landmarkdiff.landmarks.extract_landmarks`` are used
    to compare a few facial proportions against approximate anatomical norms.

    Args:
        image: Face image passed straight to ``extract_landmarks``.

    Returns:
        * 0.0 when the landmark module cannot be imported (check skipped).
        * 0.5 when no face is found, fewer than 478 landmarks are returned,
          or the measured distances are degenerate (< 1 pixel).
        * Otherwise a weighted deviation score clipped to [0, 1].
    """
    try:
        from landmarkdiff.landmarks import extract_landmarks
    except ImportError:
        return 0.0

    face = extract_landmarks(image)
    if face is None:
        return 0.5  # Can't detect face = possibly distorted

    coords = face.pixel_coords
    if len(coords) < 478:
        return 0.5  # Incomplete landmark set

    # Landmark indices for the key points used below.
    left_eye = coords[33]
    right_eye = coords[263]
    nose_tip = coords[1]
    chin = coords[152]
    forehead = coords[10]

    iod = np.linalg.norm(left_eye - right_eye)  # inter-ocular distance
    face_height = np.linalg.norm(forehead - chin)
    nose_to_chin = np.linalg.norm(nose_tip - chin)

    if iod < 1.0 or face_height < 1.0:
        return 0.5

    # Approximate anatomical norms:
    #   face_height / iod          ~ 2.5-3.5
    #   nose_to_chin / face_height ~ 0.3-0.45
    height_ratio = face_height / iod
    lower_ratio = nose_to_chin / face_height

    # Only the part of the deviation outside the tolerated band is scored.
    height_dev = max(0, abs(height_ratio - 3.0) - 0.5) / 1.5
    lower_dev = max(0, abs(lower_ratio - 0.38) - 0.08) / 0.15

    # Eye symmetry: vertical offset between the eyes relative to their
    # distance; up to 0.05 is tolerated.
    eye_tilt = abs(left_eye[1] - right_eye[1]) / (iod + 1e-6)
    tilt_dev = max(0, eye_tilt - 0.05) / 0.15

    score = 0.4 * height_dev + 0.3 * lower_dev + 0.3 * tilt_dev
    return float(np.clip(score, 0, 1))
332
+
333
+
334
def detect_lighting_issues(image: np.ndarray) -> float:
    """Score exposure problems on a 0-1 scale from the LAB lightness channel.

    Combines the fraction of near-white pixels (L > 245), the fraction of
    near-black pixels (L < 10), each amplified by 5, with a penalty for a
    low-entropy lightness histogram.

    Args:
        image: BGR image accepted by ``cv2.COLOR_BGR2LAB``.

    Returns:
        0.0 when the histogram is empty, otherwise the score clipped to
        [0, 1].
    """
    lightness = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)[:, :, 0]

    # Normalised 256-bin histogram of the lightness channel.
    counts = cv2.calcHist([lightness], [0], None, [256], [0, 256]).flatten()
    total_count = counts.sum()
    if total_count < 1e-10:
        return 0.0
    probs = counts / total_count

    # Shannon entropy in bits over the occupied bins; a concentrated
    # histogram (low entropy) is treated as potentially problematic.
    occupied = probs[probs > 0]
    entropy = -np.sum(occupied * np.log2(occupied + 1e-10))
    flatness_penalty = max(0, 1.0 - entropy / 7.0)

    # Clipping at either end of the range.
    blown_out = np.mean(lightness > 245) * 5  # Fraction near white
    crushed = np.mean(lightness < 10) * 5  # Fraction near black

    combined = 0.4 * blown_out + 0.4 * crushed + 0.2 * flatness_penalty
    return float(np.clip(combined, 0, 1))
359
+
360
+
361
def analyze_distortions(image: np.ndarray) -> DistortionReport:
    """Run every distortion detector and fold the results into one report.

    Args:
        image: Face image handed unchanged to each ``detect_*`` function.

    Returns:
        A ``DistortionReport`` holding the individual scores, a 0-100 quality
        score (100 = no distortion detected), the highest-scoring distortion
        with its severity band, and a usability flag. ``details`` carries the
        same per-detector scores keyed by name.
    """
    # Detectors run in this fixed order; the dict order is also the
    # summation order of the weighted total below.
    scores = {
        "blur": detect_blur(image),
        "noise": detect_noise(image),
        "compression": detect_compression_artifacts(image),
        "oversmooth": detect_oversmoothing(image),
        "color_cast": detect_color_cast(image),
        "geometric": detect_geometric_distortion(image),
        "lighting": detect_lighting_issues(image),
    }
    weights = {
        "blur": 0.25,
        "noise": 0.15,
        "compression": 0.10,
        "oversmooth": 0.20,
        "color_cast": 0.10,
        "geometric": 0.10,
        "lighting": 0.10,
    }

    # Weighted distortion in [0, 1], inverted and rescaled so 100 = perfect.
    penalty = sum(weights[name] * scores[name] for name in scores)
    quality = (1.0 - penalty) * 100.0

    # The dominant distortion is the detector with the highest score.
    primary = max(scores, key=scores.get)
    worst = scores[primary]

    # Severity bands by the dominant score: <0.15 none, <0.35 mild,
    # <0.60 moderate, otherwise severe.
    severity = "severe"
    for upper_bound, label in ((0.15, "none"), (0.35, "mild"), (0.60, "moderate")):
        if worst < upper_bound:
            severity = label
            break
    if severity == "none":
        primary = "none"

    # Usable when quality exceeds 25 and the geometric score stays below 0.7.
    usable = quality > 25 and scores["geometric"] < 0.7

    return DistortionReport(
        quality_score=quality,
        blur_score=scores["blur"],
        noise_score=scores["noise"],
        compression_score=scores["compression"],
        oversmooth_score=scores["oversmooth"],
        color_cast_score=scores["color_cast"],
        geometric_distort=scores["geometric"],
        lighting_score=scores["lighting"],
        primary_distortion=primary,
        severity=severity,
        is_usable=usable,
        details=scores,
    )
427
+
428
+
429
+ # ---------------------------------------------------------------------------
430
+ # Neural Face Quality Scoring (no-reference)
431
+ # ---------------------------------------------------------------------------
432
+
433
_FACE_QUALITY_NET = None


def _get_face_quality_scorer() -> Any:
    """Return the cached FaceXLib quality model, creating it on first use.

    The model is built with ``init_assessment_model("hypernet")`` and stored
    in the module-level ``_FACE_QUALITY_NET``.

    Returns:
        The assessment model, or ``None`` when FaceXLib cannot be imported or
        the model cannot be initialised. No fallback model is created here;
        callers are expected to handle ``None``. A failed attempt is not
        cached, so the next call tries again.
    """
    global _FACE_QUALITY_NET
    if _FACE_QUALITY_NET is not None:
        return _FACE_QUALITY_NET

    try:
        from facexlib.assessment import init_assessment_model

        _FACE_QUALITY_NET = init_assessment_model("hypernet")
    except Exception:
        # Best-effort: the scorer is optional, but leave a trace so a broken
        # installation can be told apart from a missing one.
        logger.debug("FaceXLib quality scorer unavailable", exc_info=True)
        return None

    return _FACE_QUALITY_NET
454
+
455
+
456
def neural_quality_score(image: np.ndarray) -> float:
    """Score face quality on a 0-100 scale (higher = better).

    Uses the FaceXLib assessment model when it is available. If the model is
    missing or inference raises, falls back to the classical composite score
    from ``analyze_distortions``. There is no other fallback.

    Args:
        image: BGR image with values in the 0-255 range (it is divided by
            255 before being converted to a tensor).

    Returns:
        The model output multiplied by 100 and clipped to [0, 100], or the
        classical ``quality_score`` on fallback.
    """
    scorer = _get_face_quality_scorer()
    if scorer is not None:
        try:
            import torch
            from facexlib.utils import img2tensor

            img_t = img2tensor(image / 255.0, bgr2rgb=True, float32=True)
            img_t = img_t.unsqueeze(0)  # add a leading batch dimension
            if torch.cuda.is_available():
                img_t = img_t.cuda()
                scorer = scorer.cuda()
            with torch.no_grad():
                score = scorer(img_t).item()
            return float(np.clip(score * 100, 0, 100))
        except Exception:
            # Best-effort: fall through to the classical score, but record
            # why, since the two scorers are different measures.
            logger.debug(
                "Neural quality scoring failed; using classical score",
                exc_info=True,
            )

    # Fallback: composite classical score
    return analyze_distortions(image).quality_score
484
+
485
+
486
+ # ---------------------------------------------------------------------------
487
+ # Neural Face Restoration (cascaded)
488
+ # ---------------------------------------------------------------------------
489
+
490
+
491
def restore_face(
    image: np.ndarray,
    distortion: DistortionReport | None = None,
    mode: str = "auto",
    codeformer_fidelity: float = 0.7,
) -> tuple[np.ndarray, list[str]]:
    """Apply a cascade of classical and neural fixes to a face image.

    Stages, in order, each applied only when its condition holds:

    1. ``color_correction`` when ``color_cast_score > 0.25``.
    2. ``lighting_fix`` when ``lighting_score > 0.35``.
    3. Neural restoration. With ``mode='auto'`` it runs only if blur (> 0.2),
       oversmoothing (> 0.25), noise (> 0.25) or compression (> 0.2) is
       detected, and then behaves like ``'codeformer'``. ``'codeformer'``
       and ``'all'`` try CodeFormer and fall back to GFPGAN when CodeFormer
       yields nothing; ``'gfpgan'`` tries GFPGAN only.
    4. ``realesrgan`` when either side of the image is below 400 pixels.
    5. ``sharpen`` when the result still has a blur score above 0.3.

    Geometric distortion is not addressed by any stage.

    Args:
        image: BGR face image to restore; it is copied, not modified.
        distortion: Pre-computed distortion report (computed if None).
        mode: 'auto', 'codeformer', 'gfpgan' or 'all'. Any other value skips
            the neural restoration stage.
        codeformer_fidelity: Fidelity value forwarded to CodeFormer.

    Returns:
        Tuple of (restored BGR image, names of the stages that were applied).
    """
    report = analyze_distortions(image) if distortion is None else distortion

    output = image.copy()
    applied = []

    # Classical corrections first.
    if report.color_cast_score > 0.25:
        output = _fix_color_cast(output)
        applied.append("color_correction")

    if report.lighting_score > 0.35:
        output = _fix_lighting(output)
        applied.append("lighting_fix")

    # Resolve 'auto' to a concrete network choice, or leave it as 'auto'
    # (which matches no branch below) when the face looks clean enough.
    selected = mode
    if selected == "auto":
        if (
            report.blur_score > 0.2
            or report.oversmooth_score > 0.25
            or report.noise_score > 0.25
            or report.compression_score > 0.2
        ):
            selected = "codeformer"

    if selected in ("codeformer", "all"):
        candidate = _try_codeformer(output, fidelity=codeformer_fidelity)
        stage_name = "codeformer"
        if candidate is None:
            # CodeFormer produced nothing: fall back to GFPGAN.
            candidate = _try_gfpgan(output)
            stage_name = "gfpgan"
        if candidate is not None:
            output = candidate
            applied.append(stage_name)
    elif selected == "gfpgan":
        candidate = _try_gfpgan(output)
        if candidate is not None:
            output = candidate
            applied.append("gfpgan")

    # Low-resolution inputs get a Real-ESRGAN pass.
    if min(output.shape[:2]) < 400:
        upscaled = _try_realesrgan(output)
        if upscaled is not None:
            output = upscaled
            applied.append("realesrgan")

    # Mild sharpening if the result is still soft.
    if detect_blur(output) > 0.3:
        from landmarkdiff.postprocess import frequency_aware_sharpen

        output = frequency_aware_sharpen(output, strength=0.3)
        applied.append("sharpen")

    return output, applied
578
+
579
+
580
def _try_codeformer(image: np.ndarray, fidelity: float = 0.7) -> np.ndarray | None:
    """Run CodeFormer restoration on a best-effort basis.

    Args:
        image: Image forwarded to ``restore_face_codeformer``.
        fidelity: Fidelity value forwarded to ``restore_face_codeformer``.

    Returns:
        The restored image, or ``None`` when the helper cannot be imported,
        raises, or hands back the very same array object it was given (which
        is treated as "nothing was done").
    """
    try:
        from landmarkdiff.postprocess import restore_face_codeformer

        restored = restore_face_codeformer(image, fidelity=fidelity)
    except Exception:
        # Best-effort stage: never propagate, but leave a trace for debugging.
        logger.debug("CodeFormer restoration unavailable or failed", exc_info=True)
        return None

    # Identity check (not equality): the same object back means no-op.
    return restored if restored is not image else None
591
+
592
+
593
def _try_gfpgan(image: np.ndarray) -> np.ndarray | None:
    """Run GFPGAN restoration on a best-effort basis.

    Args:
        image: Image forwarded to ``restore_face_gfpgan``.

    Returns:
        The restored image, or ``None`` when the helper cannot be imported,
        raises, or hands back the very same array object it was given (which
        is treated as "nothing was done").
    """
    try:
        from landmarkdiff.postprocess import restore_face_gfpgan

        restored = restore_face_gfpgan(image)
    except Exception:
        # Best-effort stage: never propagate, but leave a trace for debugging.
        logger.debug("GFPGAN restoration unavailable or failed", exc_info=True)
        return None

    # Identity check (not equality): the same object back means no-op.
    return restored if restored is not image else None
604
+
605
+
606
_FV_REALESRGAN = None


def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
    """Enhance an image with Real-ESRGAN on a best-effort basis.

    The x4 RRDBNet model is created once and cached in the module-level
    ``_FV_REALESRGAN``. The image is enhanced with ``outscale=2`` and the
    result is then resized to exactly 512x512 with Lanczos interpolation,
    regardless of the input size or aspect ratio.

    Args:
        image: Image passed to ``RealESRGANer.enhance``.

    Returns:
        The 512x512 enhanced image, or ``None`` when the required packages
        are missing or any step raises.
    """
    global _FV_REALESRGAN
    try:
        import torch
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        if _FV_REALESRGAN is None:
            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=4,
            )
            _FV_REALESRGAN = RealESRGANer(
                scale=4,
                model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                model=model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                # Half precision only when CUDA is available.
                half=torch.cuda.is_available(),
            )
        enhanced, _ = _FV_REALESRGAN.enhance(image, outscale=2)

        # Force 512x512 for pipeline consistency.
        return cv2.resize(enhanced, (512, 512), interpolation=cv2.INTER_LANCZOS4)
    except Exception:
        # Best-effort stage: never propagate, but leave a trace for debugging.
        logger.debug("Real-ESRGAN enhancement unavailable or failed", exc_info=True)
        return None
643
+
644
+
645
def _fix_color_cast(image: np.ndarray) -> np.ndarray:
    """Reduce a colour cast by pulling the LAB A/B channel means toward 128.

    Each of the A and B channels is shifted by 60% of the distance between
    its mean and the neutral value 128, so some of the original tone is kept.

    Args:
        image: BGR image accepted by ``cv2.COLOR_BGR2LAB``.

    Returns:
        Colour-corrected BGR image.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    for ch in (1, 2):
        channel = lab[:, :, ch]
        # Partial shift only, to preserve natural skin tone.
        shift = (128.0 - channel.mean()) * 0.6
        lab[:, :, ch] = np.clip(channel + shift, 0, 255)

    # Round before the integer cast: a bare astype(uint8) truncates, which
    # would bias every shifted A/B value downward by up to one level.
    lab_u8 = np.rint(lab).astype(np.uint8)
    return cv2.cvtColor(lab_u8, cv2.COLOR_LAB2BGR)
658
+
659
+
660
def _fix_lighting(image: np.ndarray) -> np.ndarray:
    """Even out exposure by applying CLAHE to the LAB lightness channel.

    Args:
        image: BGR image accepted by ``cv2.COLOR_BGR2LAB``.

    Returns:
        BGR image whose lightness channel was equalised (clip limit 2.0,
        8x8 tile grid); the A and B channels are left untouched.
    """
    converted = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)

    # Only the luminance plane is equalised.
    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lightness = converted[:, :, 0]
    converted[:, :, 0] = equalizer.apply(lightness)

    return cv2.cvtColor(converted, cv2.COLOR_LAB2BGR)
669
+
670
+
671
+ # ---------------------------------------------------------------------------
672
+ # ArcFace Identity Verification
673
+ # ---------------------------------------------------------------------------
674
+
675
_ARCFACE_APP = None


def _get_arcface() -> Any:
    """Return the cached InsightFace ``FaceAnalysis`` app, creating it once.

    The app uses the ``buffalo_l`` model pack with a 320x320 detection size
    and is stored in the module-level ``_ARCFACE_APP``.

    Returns:
        The prepared app, or ``None`` when torch/InsightFace cannot be
        imported or initialisation raises. A failed attempt is not cached,
        so the next call tries again.
    """
    global _ARCFACE_APP
    if _ARCFACE_APP is not None:
        return _ARCFACE_APP

    try:
        import torch
        from insightface.app import FaceAnalysis

        app = FaceAnalysis(
            name="buffalo_l",
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        # ctx_id 0 when CUDA is available, -1 otherwise.
        ctx_id = 0 if torch.cuda.is_available() else -1
        app.prepare(ctx_id=ctx_id, det_size=(320, 320))
    except Exception:
        # Best-effort: identity checks are optional, but leave a trace so a
        # broken installation is distinguishable from a missing one.
        logger.debug("ArcFace (InsightFace) unavailable", exc_info=True)
        return None

    _ARCFACE_APP = app
    return app
698
+
699
+
700
def get_face_embedding(image: np.ndarray) -> np.ndarray | None:
    """Extract the ArcFace embedding of the first detected face.

    Args:
        image: Image passed to the InsightFace app's ``get`` method.

    Returns:
        The embedding of the first face returned by the detector, or ``None``
        when InsightFace is unavailable, no face is detected, the embedding
        norm is below 1e-6, or inference raises.
    """
    app = _get_arcface()
    if app is None:
        return None

    try:
        faces = app.get(image)
        if not faces:
            return None
        emb = faces[0].embedding
        if np.linalg.norm(emb) < 1e-6:
            logger.warning("ArcFace returned near-zero embedding (occluded face?)")
            return None
        return emb
    except Exception:
        # Best-effort: callers treat None as "could not verify". Record the
        # cause so a skipped identity check can be diagnosed.
        logger.debug("ArcFace embedding extraction failed", exc_info=True)
        return None
720
+
721
+
722
def verify_identity(
    original: np.ndarray,
    restored: np.ndarray,
    threshold: float = 0.6,
) -> tuple[float, bool]:
    """Check that two images show the same person via ArcFace embeddings.

    Args:
        original: Reference image.
        restored: Image to compare against the reference.
        threshold: Minimum cosine similarity that counts as a match.

    Returns:
        ``(cosine_similarity, passed)``. When either embedding cannot be
        computed the result is ``(-1.0, True)``: the check is skipped and
        the pair is assumed to match.
    """
    reference = get_face_embedding(original)
    candidate = get_face_embedding(restored)

    if reference is None or candidate is None:
        return -1.0, True  # Can't verify — assume OK

    # Cosine similarity; the epsilon guards against division by zero.
    scale = np.linalg.norm(reference) * np.linalg.norm(candidate) + 1e-8
    cosine = float(np.dot(reference, candidate) / scale)
    cosine = float(np.clip(cosine, -1, 1))
    return cosine, cosine >= threshold
743
+
744
+
745
+ # ---------------------------------------------------------------------------
746
+ # Full Verification + Restoration Pipeline
747
+ # ---------------------------------------------------------------------------
748
+
749
+
750
def verify_and_restore(
    image: np.ndarray,
    quality_threshold: float = 60.0,
    identity_threshold: float = 0.6,
    restore_mode: str = "auto",
    codeformer_fidelity: float = 0.7,
) -> RestorationResult:
    """Analyze an image, restore it if needed, and verify identity.

    Flow:
        1. Run ``analyze_distortions`` on the input.
        2. If quality meets the threshold and severity is none/mild, return
           the image unchanged (no stages, similarity 1.0).
        3. If the report marks the image unusable, return it unchanged with
           the single stage ``"rejected"`` and identity marked as failed.
        4. Otherwise run ``restore_face``, re-score quality with
           ``neural_quality_score`` and compare identity with
           ``verify_identity``.

    Args:
        image: BGR face image.
        quality_threshold: Min quality to skip restoration (0-100).
        identity_threshold: Min ArcFace similarity to pass.
        restore_mode: 'auto', 'codeformer', 'gfpgan', 'all'.
        codeformer_fidelity: Fidelity value forwarded to CodeFormer.

    Returns:
        RestorationResult with the (possibly restored) image and metrics.
        ``improvement`` is the post-restoration score minus the report's
        quality score.
    """
    report = analyze_distortions(image)

    def _untouched(similarity: float, preserved: bool, stages: list[str]) -> RestorationResult:
        # Result for the two paths that hand the input back unmodified.
        return RestorationResult(
            restored=image.copy(),
            original=image.copy(),
            distortion_report=report,
            post_quality_score=report.quality_score,
            identity_similarity=similarity,
            identity_preserved=preserved,
            restoration_stages=stages,
            improvement=0.0,
        )

    # Good enough already: nothing to do.
    if report.quality_score >= quality_threshold and report.severity in ("none", "mild"):
        return _untouched(1.0, True, [])

    # Too distorted to salvage.
    if not report.is_usable:
        return _untouched(0.0, False, ["rejected"])

    fixed, applied_stages = restore_face(
        image,
        distortion=report,
        mode=restore_mode,
        codeformer_fidelity=codeformer_fidelity,
    )

    quality_after = neural_quality_score(fixed)
    similarity, same_person = verify_identity(image, fixed, threshold=identity_threshold)

    return RestorationResult(
        restored=fixed,
        original=image.copy(),
        distortion_report=report,
        post_quality_score=quality_after,
        identity_similarity=similarity,
        identity_preserved=same_person,
        restoration_stages=applied_stages,
        improvement=quality_after - report.quality_score,
    )
829
+
830
+
831
+ # ---------------------------------------------------------------------------
832
+ # Batch Processing
833
+ # ---------------------------------------------------------------------------
834
+
835
+
836
def verify_batch(
    image_dir: str,
    output_dir: str | None = None,
    quality_threshold: float = 60.0,
    identity_threshold: float = 0.6,
    restore_mode: str = "auto",
    save_rejected: bool = False,
    extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp", ".bmp"),
) -> BatchVerificationReport:
    """Process a directory of face images: analyze, restore, verify, sort.

    Outputs:
        - {output_dir}/passed/ — images that needed no restoration stage
        - {output_dir}/restored/ — restored images that kept identity
        - {output_dir}/rejected/ — unusable or identity-changed images
          (only if save_rejected=True)
        - {output_dir}/report.txt — batch verification report

    Every readable image is resized to 512x512 before processing. Files that
    cannot be read are counted as rejected.

    Args:
        image_dir: Directory of face images to process (not recursive).
        output_dir: Where to save results (default: {image_dir}_verified/).
        quality_threshold: Min quality to pass without restoration.
        identity_threshold: Min identity similarity after restoration.
        restore_mode: 'auto', 'codeformer', 'gfpgan', 'all'.
        save_rejected: Whether to copy rejected images to rejected/ subdir.
        extensions: Lower-case file extensions to process.

    Returns:
        BatchVerificationReport with summary statistics.
        ``avg_identity_sim`` averages only similarities that were actually
        measured; results where identity could not be verified are excluded.
    """
    image_path = Path(image_dir)
    if output_dir is None:
        out_path = image_path.parent / f"{image_path.name}_verified"
    else:
        out_path = Path(output_dir)

    # Create output dirs
    passed_dir = out_path / "passed"
    restored_dir = out_path / "restored"
    rejected_dir = out_path / "rejected"
    passed_dir.mkdir(parents=True, exist_ok=True)
    restored_dir.mkdir(parents=True, exist_ok=True)
    if save_rejected:
        rejected_dir.mkdir(parents=True, exist_ok=True)

    # Find all images (sorted for a deterministic processing order)
    image_files = sorted(
        f for f in image_path.iterdir() if f.suffix.lower() in extensions and f.is_file()
    )

    report = BatchVerificationReport(total=len(image_files))
    quality_before = []
    quality_after = []
    identity_sims = []

    for i, img_file in enumerate(image_files):
        # Progress log on the first image and every 50th thereafter.
        if (i + 1) % 50 == 0 or i == 0:
            logger.info("Processing %d/%d: %s", i + 1, len(image_files), img_file.name)

        image = cv2.imread(str(img_file))
        if image is None:
            report.rejected += 1
            continue

        # Resize to 512x512 for consistency
        image = cv2.resize(image, (512, 512))

        result = verify_and_restore(
            image,
            quality_threshold=quality_threshold,
            identity_threshold=identity_threshold,
            restore_mode=restore_mode,
        )

        quality_before.append(result.distortion_report.quality_score)
        quality_after.append(result.post_quality_score)

        # Track distortion types
        dist_type = result.distortion_report.primary_distortion
        report.distortion_counts[dist_type] = report.distortion_counts.get(dist_type, 0) + 1

        if not result.distortion_report.is_usable or "rejected" in result.restoration_stages:
            report.rejected += 1
            if save_rejected:
                cv2.imwrite(str(rejected_dir / img_file.name), image)
        elif not result.restoration_stages:
            # Passed without restoration
            report.passed += 1
            cv2.imwrite(str(passed_dir / img_file.name), image)
        elif result.identity_preserved:
            report.restored += 1
            cv2.imwrite(str(restored_dir / img_file.name), result.restored)
            # verify_identity reports -1.0 (with passed=True) when it could
            # not compute embeddings; that sentinel is not a measurement and
            # must not drag the average down.
            if result.identity_similarity > -1.0:
                identity_sims.append(result.identity_similarity)
        else:
            report.identity_failures += 1
            if save_rejected:
                cv2.imwrite(str(rejected_dir / img_file.name), image)

    # Compute averages (0.0 when there is nothing to average)
    report.avg_quality_before = float(np.mean(quality_before)) if quality_before else 0.0
    report.avg_quality_after = float(np.mean(quality_after)) if quality_after else 0.0
    report.avg_identity_sim = float(np.mean(identity_sims)) if identity_sims else 0.0

    # Save report; explicit encoding keeps the file independent of the locale.
    report_text = report.summary()
    (out_path / "report.txt").write_text(report_text, encoding="utf-8")
    logger.info("\n%s", report_text)
    logger.info("Results saved to %s/", out_path)

    return report