landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,947 @@
|
|
|
1
|
+
"""Neural face verification, distortion detection, and restoration pipeline.
|
|
2
|
+
|
|
3
|
+
End-to-end system that:
|
|
4
|
+
1. Detects face distortions (blur, beauty filters, compression, warping, etc.)
|
|
5
|
+
2. Classifies distortion type and severity using no-reference quality metrics
|
|
6
|
+
3. Restores faces using cascaded neural networks (CodeFormer → GFPGAN → Real-ESRGAN)
|
|
7
|
+
4. Verifies output identity matches input via ArcFace embeddings
|
|
8
|
+
5. Scores output realism using learned perceptual metrics
|
|
9
|
+
|
|
10
|
+
Designed for:
|
|
11
|
+
- Cleaning scraped training data (reject/fix bad images before pair generation)
|
|
12
|
+
- Post-diffusion quality gate (ensure generated faces pass realism threshold)
|
|
13
|
+
- Filter removal (undo Snapchat/Instagram beauty filters for clinical use)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import logging
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
from pathlib import Path
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
import cv2
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
# ---------------------------------------------------------------------------
|
|
29
|
+
# Data structures
|
|
30
|
+
# ---------------------------------------------------------------------------
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class DistortionReport:
    """Per-image distortion measurements and the classification derived from them.

    Instances are plain value holders; ``summary`` renders them as text.
    """

    # Composite quality on a 0-100 scale (larger is better)
    quality_score: float = 0.0

    # Per-distortion scores on a 0-1 scale (larger means more distorted)
    blur_score: float = 0.0  # Laplacian variance-based
    noise_score: float = 0.0  # High-freq energy ratio
    compression_score: float = 0.0  # JPEG block artifact detection
    oversmooth_score: float = 0.0  # Beauty filter / airbrushed detection
    color_cast_score: float = 0.0  # Unnatural color shift
    geometric_distort: float = 0.0  # Face proportion anomalies
    lighting_score: float = 0.0  # Over/under exposure

    # Classification
    primary_distortion: str = "none"
    severity: str = "none"  # one of: none, mild, moderate, severe
    is_usable: bool = True  # Whether image is worth restoring vs rejecting

    # Free-form extra data attached by the producer of the report
    details: dict = field(default_factory=dict)

    def summary(self) -> str:
        """Return a newline-joined, human-readable rendering of the report."""
        rendered = [
            f"Quality Score: {self.quality_score:.1f}/100",
            f"Primary Issue: {self.primary_distortion} ({self.severity})",
            f"Usable: {self.is_usable}",
            "",
            "Distortion Breakdown:",
        ]
        # One row per individual score, in a fixed display order.
        breakdown = (
            ("Blur", self.blur_score),
            ("Noise", self.noise_score),
            ("Compression", self.compression_score),
            ("Oversmooth", self.oversmooth_score),
            ("Color Cast", self.color_cast_score),
            ("Geometric", self.geometric_distort),
            ("Lighting", self.lighting_score),
        )
        rendered.extend(f" {label}: {value:.3f}" for label, value in breakdown)
        return "\n".join(rendered)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class RestorationResult:
    """Outcome of one pass through the face restoration pipeline.

    Bundles the input and output images with the pre-restoration
    distortion report and the post-restoration quality/identity metrics.
    """

    restored: np.ndarray  # Restored BGR image
    original: np.ndarray  # Original BGR image
    distortion_report: DistortionReport  # Pre-restoration analysis
    post_quality_score: float = 0.0  # Quality after restoration
    identity_similarity: float = 0.0  # ArcFace cosine sim (original vs restored)
    identity_preserved: bool = True  # Whether identity check passed
    restoration_stages: list[str] = field(default_factory=list)  # Which nets ran
    improvement: float = 0.0  # quality_after - quality_before

    def summary(self) -> str:
        """Return a newline-joined, human-readable rendering of the result.

        The improvement is printed with an explicit sign, so a quality
        regression reads ``-5.0`` rather than the malformed ``+-5.0``.
        """
        stage_text = " → ".join(self.restoration_stages) or "none"
        lines = [
            f"Pre-restoration: {self.distortion_report.quality_score:.1f}/100",
            f"Post-restoration: {self.post_quality_score:.1f}/100",
            # The "+" format flag emits the correct sign for both directions.
            f"Improvement: {self.improvement:+.1f}",
            f"Identity Sim: {self.identity_similarity:.3f}",
            f"Identity OK: {self.identity_preserved}",
            f"Stages Used: {stage_text}",
        ]
        return "\n".join(lines)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@dataclass
class BatchVerificationReport:
    """Aggregate counters and averages for a batch verification/restoration run."""

    total: int = 0
    passed: int = 0  # Good quality, no fix needed
    restored: int = 0  # Fixed and now usable
    rejected: int = 0  # Too distorted to salvage
    identity_failures: int = 0  # Restoration changed identity
    avg_quality_before: float = 0.0
    avg_quality_after: float = 0.0
    avg_identity_sim: float = 0.0
    distortion_counts: dict[str, int] = field(default_factory=dict)

    def summary(self) -> str:
        """Return a newline-joined, human-readable rendering of the batch stats.

        Distortion types are listed from most to least frequent; types with
        equal counts keep their insertion order.
        """
        by_frequency = sorted(
            self.distortion_counts.items(),
            key=lambda entry: entry[1],
            reverse=True,
        )
        return "\n".join(
            [
                f"Total Images: {self.total}",
                f" Passed (good): {self.passed}",
                f" Restored: {self.restored}",
                f" Rejected: {self.rejected}",
                f" Identity Fail: {self.identity_failures}",
                f"Avg Quality Before: {self.avg_quality_before:.1f}",
                f"Avg Quality After: {self.avg_quality_after:.1f}",
                f"Avg Identity Sim: {self.avg_identity_sim:.3f}",
                "",
                "Distortion Breakdown:",
                *(f" {dist_type}: {count}" for dist_type, count in by_frequency),
            ]
        )
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# ---------------------------------------------------------------------------
|
|
136
|
+
# Distortion Detection (classical + neural)
|
|
137
|
+
# ---------------------------------------------------------------------------
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def detect_blur(image: np.ndarray) -> float:
    """Estimate blurriness on a 0-1 scale (1 = very blurry).

    Two sharpness cues are blended: the variance of the Laplacian
    (weight 0.6) and the mean Sobel gradient magnitude (weight 0.4).
    Each cue is divided by a fixed reference (800 and 50 respectively),
    capped at 1, and inverted so that low sharpness yields a high score.

    Args:
        image: 3-dimensional input is converted with ``COLOR_BGR2GRAY``;
            anything else is used as-is as the grayscale plane.

    Returns:
        Blur score clipped to ``[0, 1]``.
    """
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    # Primary cue: spread of the second-derivative response.
    laplacian_variance = cv2.Laplacian(gray, cv2.CV_64F).var()

    # Secondary cue: average first-derivative magnitude.
    sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
    mean_gradient = np.sqrt(sobel_x**2 + sobel_y**2).mean()

    laplacian_term = 1.0 - min(laplacian_variance / 800.0, 1.0)
    gradient_term = 1.0 - min(mean_gradient / 50.0, 1.0)

    blended = 0.6 * laplacian_term + 0.4 * gradient_term
    return float(np.clip(blended, 0, 1))
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def detect_noise(image: np.ndarray) -> float:
    """Estimate the noise level on a 0-1 scale (1 = very noisy).

    The estimate is the median of the absolute Laplacian response,
    multiplied by 1.4826 (the usual MAD-to-standard-deviation factor),
    then divided by 25 and clipped.

    Args:
        image: 3-dimensional input is converted with ``COLOR_BGR2GRAY``;
            anything else is used as-is as the grayscale plane.

    Returns:
        Noise score clipped to ``[0, 1]``.
    """
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    response = cv2.Laplacian(gray.astype(np.float64), cv2.CV_64F)
    # The median is used as a robust scale estimate of the response.
    sigma_est = np.median(np.abs(response)) * 1.4826

    return float(np.clip(sigma_est / 25.0, 0, 1))
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def detect_compression_artifacts(image: np.ndarray) -> float:
    """Score JPEG-style 8x8 blocking on a 0-1 scale (1 = strong blocking).

    Absolute neighbour differences sampled at every 8th position (the
    block seams) are compared with the mean absolute difference over the
    whole image, separately for the horizontal and vertical directions.
    Note that the reference mean includes the seam positions themselves.

    Args:
        image: 3-dimensional input is converted with ``COLOR_BGR2GRAY``;
            anything else is used as-is as the grayscale plane.

    Returns:
        ``0.0`` for images smaller than 16 pixels on either side or with
        (near-)zero variation in either direction; otherwise
        ``(mean_ratio - 1) / 0.8`` clipped to ``[0, 1]``.
    """
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    rows, cols = gray.shape
    if rows < 16 or cols < 16:
        return 0.0

    pixels = gray.astype(np.float64)

    # Absolute step between horizontally / vertically adjacent pixels.
    h_diff = np.abs(np.diff(pixels, axis=1))
    v_diff = np.abs(np.diff(pixels, axis=0))

    # Steps that straddle an 8-pixel block seam.
    h_seams = h_diff[:, 7::8]
    v_seams = v_diff[7::8, :]

    h_boundary = h_seams.mean() if h_seams.size > 0 else 0
    h_interior = h_diff.mean()
    v_boundary = v_seams.mean() if v_seams.size > 0 else 0
    v_interior = v_diff.mean()

    # Flat images carry no usable signal in one or both directions.
    if h_interior < 1e-6 or v_interior < 1e-6:
        return 0.0

    # A ratio above 1 means seams are stronger than the average step.
    h_ratio = h_boundary / (h_interior + 1e-6)
    v_ratio = v_boundary / (v_interior + 1e-6)
    artifact_ratio = (h_ratio + v_ratio) / 2.0

    return float(np.clip((artifact_ratio - 1.0) / 0.8, 0, 1))
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def detect_oversmoothing(image: np.ndarray) -> float:
    """Score airbrushed / beauty-filtered skin on a 0-1 scale (1 = oversmoothed).

    Works on the central half of the image. Texture is measured as the
    variance of a high-pass residual (image minus a sigma-2 Gaussian blur),
    and the score is ``1 - min(texture / 30, 1)``. When Canny edges cover
    more than 2% of the region the score is amplified by 1.3, because
    edges without texture suggest a filter rather than a uniformly soft image.

    Args:
        image: 3-dimensional input is converted with ``COLOR_BGR2GRAY``;
            anything else is used as-is as the grayscale plane.

    Returns:
        ``0.0`` when either side is shorter than 8 pixels, otherwise the
        score clipped to ``[0, 1]``.
    """
    if image.ndim == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image

    rows, cols = gray.shape
    if rows < 8 or cols < 8:
        return 0.0  # Too small to analyze

    # Central crop, to stay away from the image borders / background.
    center = gray[rows // 4 : 3 * rows // 4, cols // 4 : 3 * cols // 4]
    center_f = center.astype(np.float64)

    # Texture energy: variance of what a Gaussian low-pass removes.
    residual = center_f - cv2.GaussianBlur(center_f, (0, 0), 2.0)
    texture_energy = np.var(residual)

    # Edge energy: fraction of pixels marked by Canny.
    edge_density = np.mean(cv2.Canny(center, 50, 150) > 0)

    smooth_score = 1.0 - min(texture_energy / 30.0, 1.0)

    # Edges present but little texture: boost the score.
    gain = 1.3 if edge_density > 0.02 else 1.0

    return float(np.clip(smooth_score * gain, 0, 1))
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def detect_color_cast(image: np.ndarray) -> float:
    """Score unnatural color cast on a 0-1 scale (1 = strong cast).

    The central half of the image is converted to LAB and two cues are
    combined:

    * how far the mean A and B channel values sit from 128 (the neutral
      point of OpenCV's 8-bit LAB encoding), weighted 0.5;
    * how narrow the A/B distributions are (``1 - (std_a + std_b) / 30``,
      floored at 0), weighted 0.3.

    Args:
        image: BGR image (passed to ``cv2.cvtColor`` with ``COLOR_BGR2LAB``).

    Returns:
        ``0.0`` when the image is too small for the central crop to contain
        any pixel, otherwise the score clipped to ``[0, 1]``.
    """
    h, w = image.shape[:2]

    # With fewer than 2 rows or columns the central crop below is empty and
    # mean()/std() would yield NaN, which would leak into the quality score.
    if h < 2 or w < 2:
        return 0.0

    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # Sample face center region
    roi = lab[h // 4 : 3 * h // 4, w // 4 : 3 * w // 4]

    # A channel: green-red axis (neutral ~128)
    # B channel: blue-yellow axis (neutral ~128)
    a_mean = roi[:, :, 1].mean()
    b_mean = roi[:, :, 2].mean()

    # Deviation from neutral, normalized by the half-range
    a_dev = abs(a_mean - 128) / 128.0
    b_dev = abs(b_mean - 128) / 128.0

    # Also check whether the chroma distribution is unusually narrow
    a_std = roi[:, :, 1].std()
    b_std = roi[:, :, 2].std()
    narrow_color = max(0, 1.0 - (a_std + b_std) / 30.0)

    score = 0.5 * (a_dev + b_dev) + 0.3 * narrow_color
    return float(np.clip(score, 0, 1))
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def detect_geometric_distortion(image: np.ndarray) -> float:
    """Score geometric face distortion on a 0-1 scale (1 = strongly distorted).

    Face landmarks are obtained from ``landmarkdiff.landmarks.extract_landmarks``
    and three proportions are checked against fixed tolerance bands:
    face height over inter-ocular distance, nose-to-chin over face height,
    and the vertical offset between the eyes relative to their distance.

    Args:
        image: Face image handed unchanged to ``extract_landmarks``.

    Returns:
        * ``0.0`` when the landmark module cannot be imported;
        * ``0.5`` when no face is found, fewer than 478 landmarks are
          returned, or the measured distances are below one pixel;
        * otherwise the weighted deviation score clipped to ``[0, 1]``.
    """
    try:
        from landmarkdiff.landmarks import extract_landmarks
    except ImportError:
        return 0.0

    # Landmark indices; presumably MediaPipe FaceMesh numbering, as the
    # module docs mention MediaPipe -- verify against extract_landmarks.
    expected_landmarks = 478
    left_eye_idx, right_eye_idx = 33, 263
    nose_tip_idx, chin_idx, forehead_idx = 1, 152, 10

    face = extract_landmarks(image)
    if face is None:
        return 0.5  # Can't detect face = possibly distorted

    coords = face.pixel_coords
    if len(coords) < expected_landmarks:
        return 0.5  # Incomplete landmark set

    left_eye = coords[left_eye_idx]
    right_eye = coords[right_eye_idx]
    nose_tip = coords[nose_tip_idx]
    chin = coords[chin_idx]
    forehead = coords[forehead_idx]

    iod = np.linalg.norm(left_eye - right_eye)
    face_height = np.linalg.norm(forehead - chin)
    nose_to_chin = np.linalg.norm(nose_tip - chin)

    # Degenerate geometry would make the ratios below meaningless.
    if iod < 1.0 or face_height < 1.0:
        return 0.5

    # Tolerance bands used below:
    #   face_height / iod          within 3.0 +/- 0.5
    #   nose_to_chin / face_height within 0.38 +/- 0.08
    height_ratio = face_height / iod
    lower_ratio = nose_to_chin / face_height

    # Only the part of the deviation outside the band is penalized.
    height_dev = max(0, abs(height_ratio - 3.0) - 0.5) / 1.5
    lower_dev = max(0, abs(lower_ratio - 0.38) - 0.08) / 0.15

    # Eye symmetry check (vertical alignment)
    eye_tilt = abs(left_eye[1] - right_eye[1]) / (iod + 1e-6)
    tilt_dev = max(0, eye_tilt - 0.05) / 0.15

    score = 0.4 * height_dev + 0.3 * lower_dev + 0.3 * tilt_dev
    return float(np.clip(score, 0, 1))
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def detect_lighting_issues(image: np.ndarray) -> float:
    """Score exposure problems on a 0-1 scale (1 = badly lit).

    Three cues from the LAB lightness channel are combined: the fraction
    of pixels above 245 and the fraction below 10 (each multiplied by 5,
    weight 0.4), and how far the histogram entropy falls short of 7 bits
    (weight 0.2).

    Args:
        image: BGR image (passed to ``cv2.cvtColor`` with ``COLOR_BGR2LAB``).

    Returns:
        ``0.0`` when the lightness histogram is empty, otherwise the score
        clipped to ``[0, 1]``.
    """
    lightness = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)[:, :, 0]

    # Clipping at either end of the range.
    overexposed = np.mean(lightness > 245) * 5  # Fraction near white
    underexposed = np.mean(lightness < 10) * 5  # Fraction near black

    counts = cv2.calcHist([lightness], [0], None, [256], [0, 256]).flatten()
    total = counts.sum()
    if total < 1e-10:
        return 0.0
    probs = counts / total

    # Shannon entropy over the occupied bins; low entropy means the
    # lightness values are concentrated in few levels.
    occupied = probs[probs > 0]
    entropy = -np.sum(occupied * np.log2(occupied + 1e-10))
    entropy_score = max(0, 1.0 - entropy / 7.0)

    score = 0.4 * overexposed + 0.4 * underexposed + 0.2 * entropy_score
    return float(np.clip(score, 0, 1))
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def analyze_distortions(image: np.ndarray) -> DistortionReport:
    """Run every distortion detector and fold the results into one report.

    The quality score is ``100 * (1 - weighted_sum)`` of the seven
    detector scores. The detector with the highest score becomes the
    primary distortion, and its value determines the severity bucket.

    Args:
        image: Face image passed unchanged to each detector.

    Returns:
        A populated ``DistortionReport``; ``details`` holds the raw
        per-detector scores keyed by distortion name.
    """
    scores = {
        "blur": detect_blur(image),
        "noise": detect_noise(image),
        "compression": detect_compression_artifacts(image),
        "oversmooth": detect_oversmoothing(image),
        "color_cast": detect_color_cast(image),
        "geometric": detect_geometric_distortion(image),
        "lighting": detect_lighting_issues(image),
    }

    # Weighted combination, inverted so that 100 means "no distortion".
    penalty = (
        0.25 * scores["blur"]
        + 0.15 * scores["noise"]
        + 0.10 * scores["compression"]
        + 0.20 * scores["oversmooth"]
        + 0.10 * scores["color_cast"]
        + 0.10 * scores["geometric"]
        + 0.10 * scores["lighting"]
    )
    quality = (1.0 - penalty) * 100.0

    # Dominant distortion: first entry with the largest score.
    primary, primary_val = max(scores.items(), key=lambda entry: entry[1])

    # Severity buckets by upper bound; anything at or above 0.60 is severe.
    severity = "severe"
    for upper_bound, label in ((0.15, "none"), (0.35, "mild"), (0.60, "moderate")):
        if primary_val < upper_bound:
            severity = label
            break
    if severity == "none":
        primary = "none"

    # Usable when quality exceeds 25 and geometric distortion stays below 0.7.
    is_usable = quality > 25 and scores["geometric"] < 0.7

    return DistortionReport(
        quality_score=quality,
        blur_score=scores["blur"],
        noise_score=scores["noise"],
        compression_score=scores["compression"],
        oversmooth_score=scores["oversmooth"],
        color_cast_score=scores["color_cast"],
        geometric_distort=scores["geometric"],
        lighting_score=scores["lighting"],
        primary_distortion=primary,
        severity=severity,
        is_usable=is_usable,
        details=scores,
    )
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
# ---------------------------------------------------------------------------
|
|
430
|
+
# Neural Face Quality Scoring (no-reference)
|
|
431
|
+
# ---------------------------------------------------------------------------
|
|
432
|
+
|
|
433
|
+
_FACE_QUALITY_NET = None


def _get_face_quality_scorer() -> Any:
    """Return the cached FaceXLib quality model, creating it on first use.

    The model is built with ``init_assessment_model("hypernet")`` and kept
    in the module-level ``_FACE_QUALITY_NET``. Loading is best-effort: any
    failure (FaceXLib missing, weights unavailable, ...) is logged at debug
    level and ``None`` is returned, leaving the choice of fallback to the
    caller. A failed load is not cached, so the next call tries again.

    Returns:
        The assessment model, or ``None`` when it could not be created.
    """
    global _FACE_QUALITY_NET
    if _FACE_QUALITY_NET is None:
        try:
            from facexlib.assessment import init_assessment_model

            _FACE_QUALITY_NET = init_assessment_model("hypernet")
        except Exception:
            # Deliberately non-fatal, but leave a trace for diagnosis.
            logger.debug("FaceXLib quality scorer unavailable", exc_info=True)
    return _FACE_QUALITY_NET
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
def neural_quality_score(image: np.ndarray) -> float:
    """Score face quality on a 0-100 scale (higher = better).

    Uses the FaceXLib assessment model when it can be loaded; if the model
    is unavailable or inference raises, falls back to the classical
    composite score from ``analyze_distortions``.

    Args:
        image: BGR image; it is divided by 255 before being converted to a
            tensor, so 0-255 pixel values are assumed.

    Returns:
        Quality score in ``[0, 100]``.
    """
    scorer = _get_face_quality_scorer()
    if scorer is not None:
        try:
            import torch
            from facexlib.utils import img2tensor

            img_t = img2tensor(image / 255.0, bgr2rgb=True, float32=True)
            img_t = img_t.unsqueeze(0)  # add a leading batch dimension
            if torch.cuda.is_available():
                img_t = img_t.cuda()
                scorer = scorer.cuda()
            with torch.no_grad():
                score = scorer(img_t).item()
            # The x100 scaling assumes the model emits 0-1 -- TODO confirm.
            return float(np.clip(score * 100, 0, 100))
        except Exception:
            # Best-effort: record why the neural path failed, then fall back.
            logger.debug(
                "Neural quality scoring failed; using classical fallback",
                exc_info=True,
            )

    # Fallback: composite classical score
    return analyze_distortions(image).quality_score
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
# ---------------------------------------------------------------------------
|
|
487
|
+
# Neural Face Restoration (cascaded)
|
|
488
|
+
# ---------------------------------------------------------------------------
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def restore_face(
    image: np.ndarray,
    distortion: DistortionReport | None = None,
    mode: str = "auto",
    codeformer_fidelity: float = 0.7,
) -> tuple[np.ndarray, list[str]]:
    """Apply the cascaded restoration steps selected by the distortion report.

    Steps, in order:

    0. Classical color correction when ``color_cast_score > 0.25``.
    1. Classical lighting fix when ``lighting_score > 0.35``.
    2. Neural restoration: CodeFormer, with GFPGAN as fallback when
       CodeFormer returns nothing; or GFPGAN alone in ``'gfpgan'`` mode.
    3. Real-ESRGAN enhancement when either side is under 400 pixels.
    4. Mild sharpening when the result still scores above 0.3 for blur.

    Geometric distortion is not addressed by any step.

    Args:
        image: BGR face image to restore; it is copied, not modified.
        distortion: Pre-computed distortion report (computed if None).
        mode: ``'auto'`` picks CodeFormer when the blur, oversmooth, noise
            or compression score exceeds its threshold and otherwise skips
            neural restoration. ``'codeformer'`` and ``'gfpgan'`` force that
            network. ``'all'`` currently follows the same path as
            ``'codeformer'``.
        codeformer_fidelity: CodeFormer quality-fidelity tradeoff.

    Returns:
        Tuple of (restored BGR image, list of stages applied).
    """
    if distortion is None:
        distortion = analyze_distortions(image)

    current = image.copy()
    stages = []

    # Step 0: color cast first (classical, cheap).
    if distortion.color_cast_score > 0.25:
        current = _fix_color_cast(current)
        stages.append("color_correction")

    # Step 1: lighting (classical).
    if distortion.lighting_score > 0.35:
        current = _fix_lighting(current)
        stages.append("lighting_fix")

    # Step 2: resolve 'auto' into a concrete network choice, if one is needed.
    selected = mode
    if selected == "auto":
        wants_neural = (
            distortion.blur_score > 0.2
            or distortion.oversmooth_score > 0.25
            or distortion.noise_score > 0.25
            or distortion.compression_score > 0.2
        )
        if wants_neural:
            selected = "codeformer"

    neural_output = None
    neural_stage = None
    if selected in ("codeformer", "all"):
        neural_output = _try_codeformer(current, fidelity=codeformer_fidelity)
        neural_stage = "codeformer"
        if neural_output is None:
            # CodeFormer produced nothing; fall back to GFPGAN.
            neural_output = _try_gfpgan(current)
            neural_stage = "gfpgan"
    elif selected == "gfpgan":
        neural_output = _try_gfpgan(current)
        neural_stage = "gfpgan"

    if neural_output is not None:
        current = neural_output
        stages.append(neural_stage)

    # Step 3: Real-ESRGAN for low-resolution results.
    height, width = current.shape[:2]
    if height < 400 or width < 400:
        upscaled = _try_realesrgan(current)
        if upscaled is not None:
            current = upscaled
            stages.append("realesrgan")

    # Step 4: mild sharpening if the result is still soft.
    if detect_blur(current) > 0.3:
        from landmarkdiff.postprocess import frequency_aware_sharpen

        current = frequency_aware_sharpen(current, strength=0.3)
        stages.append("sharpen")

    return current, stages
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def _try_codeformer(image: np.ndarray, fidelity: float = 0.7) -> np.ndarray | None:
    """Run CodeFormer restoration, returning None when it yields nothing.

    Args:
        image: Image handed to ``restore_face_codeformer``.
        fidelity: Forwarded as the ``fidelity`` keyword.

    Returns:
        The restored image, or ``None`` when the helper cannot be imported,
        raises, or returns the very same object it was given.
    """
    try:
        from landmarkdiff.postprocess import restore_face_codeformer

        restored = restore_face_codeformer(image, fidelity=fidelity)
    except Exception:
        # Best-effort stage: keep the pipeline going, but leave a trace.
        logger.debug("CodeFormer restoration unavailable or failed", exc_info=True)
        return None

    # Getting the input object back is treated as "nothing was restored".
    return restored if restored is not image else None
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _try_gfpgan(image: np.ndarray) -> np.ndarray | None:
    """Run GFPGAN restoration, returning None when it yields nothing.

    Args:
        image: Image handed to ``restore_face_gfpgan``.

    Returns:
        The restored image, or ``None`` when the helper cannot be imported,
        raises, or returns the very same object it was given.
    """
    try:
        from landmarkdiff.postprocess import restore_face_gfpgan

        restored = restore_face_gfpgan(image)
    except Exception:
        # Best-effort stage: keep the pipeline going, but leave a trace.
        logger.debug("GFPGAN restoration unavailable or failed", exc_info=True)
        return None

    # Getting the input object back is treated as "nothing was restored".
    return restored if restored is not image else None
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
_FV_REALESRGAN = None


def _try_realesrgan(image: np.ndarray) -> np.ndarray | None:
    """Enhance an image with Real-ESRGAN and resize the result to 512x512.

    The upsampler (RealESRGAN_x4plus weights, requested with ``outscale=2``)
    is created on first use and cached in the module-level ``_FV_REALESRGAN``.
    The enhanced image is then resized to a fixed 512x512 with Lanczos
    interpolation, which does not preserve the aspect ratio of non-square
    input.

    Args:
        image: Image passed to ``RealESRGANer.enhance``.

    Returns:
        The 512x512 enhanced image, or ``None`` when the dependencies are
        missing or any step raises.
    """
    global _FV_REALESRGAN
    try:
        import torch
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan import RealESRGANer

        if _FV_REALESRGAN is None:
            model = RRDBNet(
                num_in_ch=3,
                num_out_ch=3,
                num_feat=64,
                num_block=23,
                num_grow_ch=32,
                scale=4,
            )
            _FV_REALESRGAN = RealESRGANer(
                scale=4,
                model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                model=model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                # Half precision only when a CUDA device is available.
                half=torch.cuda.is_available(),
            )
        enhanced, _ = _FV_REALESRGAN.enhance(image, outscale=2)

        # Downsample to 512x512 for pipeline consistency
        return cv2.resize(enhanced, (512, 512), interpolation=cv2.INTER_LANCZOS4)
    except Exception:
        # Best-effort stage: keep the pipeline going, but leave a trace.
        logger.debug("Real-ESRGAN enhancement unavailable or failed", exc_info=True)
    return None
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
def _fix_color_cast(image: np.ndarray) -> np.ndarray:
    """Pull the A and B chroma channels part-way toward neutral in LAB space.

    Each chroma channel is shifted by 60% of the distance between its mean
    and 128, then clipped to ``[0, 255]``. The partial shift keeps some of
    the original tone instead of forcing the mean to exactly neutral.

    Args:
        image: BGR image (converted with ``COLOR_BGR2LAB``).

    Returns:
        The corrected image converted back to BGR as ``uint8``.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    for chroma_idx in (1, 2):
        plane = lab[:, :, chroma_idx]
        offset = (128.0 - plane.mean()) * 0.6
        lab[:, :, chroma_idx] = np.clip(plane + offset, 0, 255)

    corrected = lab.astype(np.uint8)
    return cv2.cvtColor(corrected, cv2.COLOR_LAB2BGR)
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
def _fix_lighting(image: np.ndarray) -> np.ndarray:
    """Equalize exposure by applying CLAHE to the LAB lightness channel.

    Only the lightness channel is modified; the A and B chroma channels
    pass through untouched.

    Args:
        image: BGR image (converted with ``COLOR_BGR2LAB``).

    Returns:
        The adjusted image converted back to BGR.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)

    equalizer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    lightness = lab[:, :, 0]
    lab[:, :, 0] = equalizer.apply(lightness)

    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
# ---------------------------------------------------------------------------
|
|
672
|
+
# ArcFace Identity Verification
|
|
673
|
+
# ---------------------------------------------------------------------------
|
|
674
|
+
|
|
675
|
+
_ARCFACE_APP = None


def _get_arcface() -> Any:
    """Return the cached InsightFace ``FaceAnalysis`` app, creating it on first use.

    The app is built from the ``buffalo_l`` model pack with CUDA and CPU
    execution providers listed, prepared with a 320x320 detection size, and
    stored in the module-level ``_ARCFACE_APP``. ``ctx_id`` is 0 when torch
    reports CUDA as available and -1 otherwise.

    Returns:
        The prepared app, or ``None`` when torch/InsightFace cannot be
        imported or initialization raises. The failure is logged at debug
        level and is not cached, so the next call tries again.
    """
    global _ARCFACE_APP
    if _ARCFACE_APP is not None:
        return _ARCFACE_APP

    try:
        import torch
        from insightface.app import FaceAnalysis

        app = FaceAnalysis(
            name="buffalo_l",
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        ctx_id = 0 if torch.cuda.is_available() else -1
        app.prepare(ctx_id=ctx_id, det_size=(320, 320))
    except Exception:
        # Best-effort: identity checks are skipped by callers when this is None.
        logger.debug("ArcFace (InsightFace) model unavailable", exc_info=True)
        return None

    _ARCFACE_APP = app
    return app
|
|
698
|
+
|
|
699
|
+
|
|
700
|
+
def get_face_embedding(image: np.ndarray) -> np.ndarray | None:
    """Extract the ArcFace embedding of the first detected face.

    Args:
        image: Image passed to the InsightFace app's ``get`` method.

    Returns:
        The embedding of the first face returned by the detector, or
        ``None`` when the model is unavailable, no face is found, the
        embedding is (near-)zero, or detection raises.
    """
    app = _get_arcface()
    if app is None:
        return None

    try:
        faces = app.get(image)
    except Exception:
        # Best-effort: treat a detector failure like "no face found".
        logger.debug("ArcFace face detection failed", exc_info=True)
        return None

    if not faces:
        return None

    emb = faces[0].embedding
    # A near-zero vector cannot be normalized into a meaningful direction.
    if np.linalg.norm(emb) < 1e-6:
        logger.warning("ArcFace returned near-zero embedding (occluded face?)")
        return None
    return emb
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def verify_identity(
    original: np.ndarray,
    restored: np.ndarray,
    threshold: float = 0.6,
) -> tuple[float, bool]:
    """Check that ``restored`` still depicts the same person as ``original``.

    Both images are embedded with ``get_face_embedding`` and compared by
    cosine similarity, clipped to [-1, 1].

    Args:
        original: Reference face image.
        restored: Face image to compare against the reference.
        threshold: Minimum similarity (inclusive) counted as a match.

    Returns:
        ``(similarity, passed)``. ``passed`` is ``similarity >= threshold``.
        When either embedding is unavailable the result is ``(-1.0, True)``:
        verification is impossible and the pair is assumed to be fine.
    """
    reference_vec = get_face_embedding(original)
    candidate_vec = get_face_embedding(restored)

    embeddings_available = reference_vec is not None and candidate_vec is not None
    if not embeddings_available:
        # Nothing to compare: report the sentinel and let the image through.
        return -1.0, True

    norm_product = np.linalg.norm(reference_vec) * np.linalg.norm(candidate_vec)
    # The small epsilon keeps the division defined for degenerate norms.
    raw_cosine = float(np.dot(reference_vec, candidate_vec) / (norm_product + 1e-8))
    cosine = float(np.clip(raw_cosine, -1, 1))
    passed = cosine >= threshold
    return cosine, passed
|
|
743
|
+
|
|
744
|
+
|
|
745
|
+
# ---------------------------------------------------------------------------
|
|
746
|
+
# Full Verification + Restoration Pipeline
|
|
747
|
+
# ---------------------------------------------------------------------------
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
def verify_and_restore(
    image: np.ndarray,
    quality_threshold: float = 60.0,
    identity_threshold: float = 0.6,
    restore_mode: str = "auto",
    codeformer_fidelity: float = 0.7,
) -> RestorationResult:
    """Run the whole face-verifier pipeline on a single image.

    Stages, in order:
        1. ``analyze_distortions`` produces a distortion report.
        2. If the reported quality reaches ``quality_threshold`` and the
           severity is "none" or "mild", the image is returned untouched.
        3. If the report marks the image as not usable, it is returned
           untouched with ``restoration_stages == ["rejected"]``.
        4. Otherwise ``restore_face`` is applied, the result is re-scored with
           ``neural_quality_score`` and checked with ``verify_identity``.

    Args:
        image: BGR face image.
        quality_threshold: Min quality to skip restoration (0-100).
        identity_threshold: Min ArcFace similarity to pass (0-1).
        restore_mode: 'auto', 'codeformer', 'gfpgan', 'all'.
        codeformer_fidelity: CodeFormer quality-fidelity balance.

    Returns:
        RestorationResult holding the (possibly restored) image, a copy of
        the input, the distortion report and the quality/identity metrics.
    """
    report = analyze_distortions(image)

    def _untouched(similarity: float, preserved: bool, stages: list[str]) -> RestorationResult:
        # Result for the paths that hand the input back without restoring it.
        return RestorationResult(
            restored=image.copy(),
            original=image.copy(),
            distortion_report=report,
            post_quality_score=report.quality_score,
            identity_similarity=similarity,
            identity_preserved=preserved,
            restoration_stages=stages,
            improvement=0.0,
        )

    # Good enough already: skip restoration entirely.
    meets_quality = report.quality_score >= quality_threshold
    if meets_quality and report.severity in ("none", "mild"):
        return _untouched(1.0, True, [])

    # Too distorted to salvage.
    if not report.is_usable:
        return _untouched(0.0, False, ["rejected"])

    # Neural restoration.
    restored, stages = restore_face(
        image,
        distortion=report,
        mode=restore_mode,
        codeformer_fidelity=codeformer_fidelity,
    )

    # Re-score the output, then confirm the identity survived restoration.
    post_quality = neural_quality_score(restored)
    sim, id_ok = verify_identity(image, restored, threshold=identity_threshold)

    return RestorationResult(
        restored=restored,
        original=image.copy(),
        distortion_report=report,
        post_quality_score=post_quality,
        identity_similarity=sim,
        identity_preserved=id_ok,
        restoration_stages=stages,
        improvement=post_quality - report.quality_score,
    )
|
|
829
|
+
|
|
830
|
+
|
|
831
|
+
# ---------------------------------------------------------------------------
|
|
832
|
+
# Batch Processing
|
|
833
|
+
# ---------------------------------------------------------------------------
|
|
834
|
+
|
|
835
|
+
|
|
836
|
+
def verify_batch(
    image_dir: str,
    output_dir: str | None = None,
    quality_threshold: float = 60.0,
    identity_threshold: float = 0.6,
    restore_mode: str = "auto",
    save_rejected: bool = False,
    extensions: tuple[str, ...] = (".jpg", ".jpeg", ".png", ".webp", ".bmp"),
) -> BatchVerificationReport:
    """Process a directory of face images: analyze, restore, verify, sort.

    Every matching file directly inside ``image_dir`` is read, resized to
    512x512 and passed through ``verify_and_restore``.

    Outputs:
        - {output_dir}/passed/     — good images (no fix needed)
        - {output_dir}/restored/   — fixed images
        - {output_dir}/rejected/   — too distorted to use, or identity lost
          during restoration (if save_rejected=True)
        - {output_dir}/report.txt  — batch verification report (UTF-8)

    Args:
        image_dir: Directory of face images to process.
        output_dir: Where to save results (default: {image_dir}_verified/).
        quality_threshold: Min quality to pass without restoration.
        identity_threshold: Min identity similarity after restoration.
        restore_mode: 'auto', 'codeformer', 'gfpgan', 'all'.
        save_rejected: Whether to copy rejected images to rejected/ subdir.
        extensions: File extensions to process (matched case-insensitively).

    Returns:
        BatchVerificationReport with summary statistics. ``avg_identity_sim``
        averages only restorations whose identity could actually be measured.
    """
    image_path = Path(image_dir)
    if output_dir is None:
        out_path = image_path.parent / f"{image_path.name}_verified"
    else:
        out_path = Path(output_dir)

    # Create output dirs
    passed_dir = out_path / "passed"
    restored_dir = out_path / "restored"
    rejected_dir = out_path / "rejected"
    passed_dir.mkdir(parents=True, exist_ok=True)
    restored_dir.mkdir(parents=True, exist_ok=True)
    if save_rejected:
        rejected_dir.mkdir(parents=True, exist_ok=True)

    # Suffixes are lowercased before matching, so lowercase the filter too;
    # otherwise an entry such as ".JPG" could never match anything.
    wanted_suffixes = {ext.lower() for ext in extensions}
    image_files = sorted(
        f for f in image_path.iterdir() if f.suffix.lower() in wanted_suffixes and f.is_file()
    )

    def _save(target_dir: Path, name: str, pixels: np.ndarray) -> None:
        # cv2.imwrite signals failure through its return value, not an exception.
        destination = target_dir / name
        if not cv2.imwrite(str(destination), pixels):
            logger.warning("Failed to write image: %s", destination)

    # verify_identity reports this similarity (with passed=True) when no
    # embedding could be computed; it is a marker, not a measurement.
    unverified_similarity = -1.0

    report = BatchVerificationReport(total=len(image_files))
    quality_before = []
    quality_after = []
    identity_sims = []

    for i, img_file in enumerate(image_files):
        if (i + 1) % 50 == 0 or i == 0:
            logger.info("Processing %d/%d: %s", i + 1, len(image_files), img_file.name)

        image = cv2.imread(str(img_file))
        if image is None:
            logger.warning("Could not read image, counting as rejected: %s", img_file)
            report.rejected += 1
            continue

        # Resize to 512x512 for consistency
        image = cv2.resize(image, (512, 512))

        # Run verification + restoration
        result = verify_and_restore(
            image,
            quality_threshold=quality_threshold,
            identity_threshold=identity_threshold,
            restore_mode=restore_mode,
        )

        quality_before.append(result.distortion_report.quality_score)
        quality_after.append(result.post_quality_score)

        # Track distortion types
        dist_type = result.distortion_report.primary_distortion
        report.distortion_counts[dist_type] = report.distortion_counts.get(dist_type, 0) + 1

        if not result.distortion_report.is_usable or "rejected" in result.restoration_stages:
            report.rejected += 1
            if save_rejected:
                _save(rejected_dir, img_file.name, image)
        elif not result.restoration_stages:
            # Passed without restoration
            report.passed += 1
            _save(passed_dir, img_file.name, image)
        elif result.identity_preserved:
            # Restored
            report.restored += 1
            _save(restored_dir, img_file.name, result.restored)
            # Keep the "could not verify" marker out of the average so it does
            # not drag avg_identity_sim towards -1 when ArcFace is unavailable.
            if result.identity_similarity != unverified_similarity:
                identity_sims.append(result.identity_similarity)
        else:
            # Restoration changed the identity: keep only the original, and
            # only when rejected images were asked for.
            report.identity_failures += 1
            if save_rejected:
                _save(rejected_dir, img_file.name, image)

    # Compute averages
    report.avg_quality_before = float(np.mean(quality_before)) if quality_before else 0.0
    report.avg_quality_after = float(np.mean(quality_after)) if quality_after else 0.0
    report.avg_identity_sim = float(np.mean(identity_sims)) if identity_sims else 0.0

    # Save report; explicit UTF-8 so the text does not depend on the locale.
    report_text = report.summary()
    (out_path / "report.txt").write_text(report_text, encoding="utf-8")
    logger.info("\n%s", report_text)
    logger.info("Results saved to %s/", out_path)

    return report
|