landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/safety.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
"""Clinical safety validation for responsible deployment.
|
|
2
|
+
|
|
3
|
+
Implements safety checks for surgical outcome predictions:
|
|
4
|
+
1. Identity preservation: verify output preserves patient identity
|
|
5
|
+
2. Anatomical plausibility: check landmark displacements are realistic
|
|
6
|
+
3. Out-of-distribution detection: flag unusual inputs
|
|
7
|
+
4. Watermarking: mark AI-generated images
|
|
8
|
+
5. Consent metadata: embed provenance information
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
from landmarkdiff.safety import SafetyValidator
|
|
12
|
+
|
|
13
|
+
validator = SafetyValidator()
|
|
14
|
+
result = validator.validate(
|
|
15
|
+
input_image=image,
|
|
16
|
+
output_image=generated,
|
|
17
|
+
landmarks_original=face.landmarks,
|
|
18
|
+
landmarks_manipulated=manip.landmarks,
|
|
19
|
+
procedure="rhinoplasty",
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
if not result.passed:
|
|
23
|
+
print(f"Safety check failed: {result.failures}")
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
from dataclasses import dataclass, field
|
|
29
|
+
|
|
30
|
+
import cv2
|
|
31
|
+
import numpy as np
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class SafetyResult:
    """Outcome of running the safety checks on a single prediction.

    Attributes:
        passed: Starts True and becomes False once any failure is recorded.
        failures: Failure messages, in the order they were recorded.
        warnings: Warning messages; these never change ``passed``.
        checks: Check name -> True (passed) or False (failed).
        details: Diagnostic values recorded by the individual checks.
    """

    passed: bool = True
    failures: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    checks: dict[str, bool] = field(default_factory=dict)
    details: dict[str, object] = field(default_factory=dict)

    def __repr__(self) -> str:
        parts = (
            f"passed={self.passed}",
            f"failures={self.failures}",
            f"warnings={self.warnings}",
            f"checks={self.checks}",
            f"details={self.details}",
        )
        return "SafetyResult(" + ", ".join(parts) + ")"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, SafetyResult):
            mine = (self.passed, self.failures, self.warnings, self.checks, self.details)
            theirs = (other.passed, other.failures, other.warnings, other.checks, other.details)
            return mine == theirs
        return NotImplemented

    def add_failure(self, name: str, message: str) -> None:
        """Mark check ``name`` as failed and flip the overall verdict."""
        self.passed = False
        self.failures.append(message)
        self.checks[name] = False

    def add_warning(self, name: str, message: str) -> None:
        """Record a non-fatal warning.

        ``name`` is accepted for symmetry with the other helpers but is not stored.
        """
        self.warnings.append(message)

    def add_pass(self, name: str) -> None:
        """Mark check ``name`` as passed."""
        self.checks[name] = True

    def summary(self) -> str:
        """Render a multi-line report: verdict, one line per check, then warnings."""
        header = "Safety: " + ("PASS" if self.passed else "FAIL")
        check_lines = [f" [{'OK' if ok else 'FAIL'}] {name}" for name, ok in self.checks.items()]
        warning_lines = [f" [WARN] {w}" for w in self.warnings]
        return "\n".join([header, *check_lines, *warning_lines])
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class SafetyValidator:
    """Clinical safety validation for surgical predictions.

    :meth:`validate` runs face-confidence, identity-preservation,
    anatomical-plausibility, output-quality and basic out-of-distribution
    checks and collects the outcome in a :class:`SafetyResult`.
    :meth:`apply_watermark` and :meth:`embed_metadata` label generated images
    as AI output.
    """

    # Region names (looked up in landmarkdiff.landmarks.LANDMARK_REGIONS) where
    # each procedure is expected to concentrate its landmark displacement.
    # Procedures not listed here skip the region check (recorded as a pass).
    _EXPECTED_REGIONS: dict[str, tuple[str, ...]] = {
        "rhinoplasty": ("nose",),
        "blepharoplasty": ("eye_left", "eye_right"),
        "rhytidectomy": ("jawline",),
        "orthognathic": ("jawline", "lips"),
    }

    def __init__(
        self,
        identity_threshold: float = 0.6,
        max_displacement_fraction: float = 0.05,
        min_face_confidence: float = 0.5,
        max_yaw_degrees: float = 45.0,
        watermark_enabled: bool = True,
        watermark_text: str = "AI-GENERATED PREDICTION",
    ):
        """Store the thresholds used by the checks.

        Args:
            identity_threshold: Minimum identity similarity for the
                ``identity`` check to pass.
            max_displacement_fraction: Maximum allowed per-landmark
                displacement, in the same units as the landmarks.
            min_face_confidence: Minimum face-detection confidence.
            max_yaw_degrees: Stored on the instance. NOTE(review): no check in
                this class reads it yet.
            watermark_enabled: When False, :meth:`apply_watermark` returns its
                input unchanged.
            watermark_text: Default text drawn by :meth:`apply_watermark`.
        """
        self.identity_threshold = identity_threshold
        self.max_displacement_fraction = max_displacement_fraction
        self.min_face_confidence = min_face_confidence
        self.max_yaw_degrees = max_yaw_degrees
        self.watermark_enabled = watermark_enabled
        self.watermark_text = watermark_text

    def validate(
        self,
        input_image: np.ndarray,
        output_image: np.ndarray,
        landmarks_original: np.ndarray | None = None,
        landmarks_manipulated: np.ndarray | None = None,
        procedure: str | None = None,
        face_confidence: float = 1.0,
    ) -> SafetyResult:
        """Run all safety checks on a prediction.

        Args:
            input_image: Original patient image (BGR, uint8).
            output_image: Generated prediction (BGR, uint8).
            landmarks_original: Original landmarks (N, 2-3), normalized [0, 1].
            landmarks_manipulated: Manipulated landmarks (N, 2-3), normalized [0, 1].
            procedure: Surgical procedure name (matched case-insensitively).
            face_confidence: MediaPipe face detection confidence.

        Returns:
            SafetyResult with all check results. The identity check is
            best-effort: if it cannot run, only a warning is recorded. The
            anatomical checks run only when both landmark sets are given.
        """
        result = SafetyResult()

        # 1. Face detection confidence
        self._check_face_confidence(result, face_confidence)

        # 2. Identity preservation
        self._check_identity(result, input_image, output_image)

        # 3. Anatomical plausibility
        if landmarks_original is not None and landmarks_manipulated is not None:
            self._check_anatomical_plausibility(
                result, landmarks_original, landmarks_manipulated, procedure
            )

        # 4. Output quality
        self._check_output_quality(result, output_image)

        # 5. OOD detection (basic)
        self._check_ood(result, input_image)

        return result

    def _check_face_confidence(self, result: SafetyResult, confidence: float) -> None:
        """Fail when face-detection confidence is below ``min_face_confidence``."""
        if confidence < self.min_face_confidence:
            result.add_failure(
                "face_confidence",
                f"Face detection confidence {confidence:.2f} below threshold "
                f"{self.min_face_confidence}",
            )
        else:
            result.add_pass("face_confidence")
        result.details["face_confidence"] = confidence

    def _check_identity(
        self,
        result: SafetyResult,
        input_image: np.ndarray,
        output_image: np.ndarray,
    ) -> None:
        """Check identity preservation using ArcFace similarity.

        Uses ``landmarkdiff.evaluation.compute_identity_similarity``. Any error
        while importing or running it is downgraded to a warning, and no
        ``identity`` entry is recorded in ``result.checks`` in that case.
        """
        try:
            from landmarkdiff.evaluation import compute_identity_similarity

            sim = float(compute_identity_similarity(output_image, input_image))
        except Exception as e:  # deliberate best-effort: never crash validation here
            result.add_warning("identity", f"Identity check failed: {e}")
            return

        result.details["identity_similarity"] = sim
        if sim < self.identity_threshold:
            result.add_failure(
                "identity",
                f"Identity similarity {sim:.3f} below threshold {self.identity_threshold}",
            )
        else:
            result.add_pass("identity")

    def _check_anatomical_plausibility(
        self,
        result: SafetyResult,
        landmarks_orig: np.ndarray,
        landmarks_manip: np.ndarray,
        procedure: str | None,
    ) -> None:
        """Check that landmark displacements are anatomically plausible.

        Fails when the landmark counts differ, or when any landmark moves by
        more than ``max_displacement_fraction``. Movement is the Euclidean
        distance over the first two coordinates. When ``procedure`` is given,
        it also checks where the displacement is concentrated. An empty
        landmark set is skipped with a warning.
        """
        # Accept any array-like (e.g. nested lists), not only ndarrays.
        orig_all = np.asarray(landmarks_orig)
        manip_all = np.asarray(landmarks_manip)

        if len(orig_all) != len(manip_all):
            result.add_failure(
                "anatomical",
                f"Landmark count mismatch: {len(orig_all)} vs {len(manip_all)}",
            )
            return

        if len(orig_all) == 0:
            # max() over an empty array raises; there is nothing to measure.
            result.add_warning("anatomical", "No landmarks provided; anatomical check skipped")
            return

        orig = orig_all[:, :2]  # (N, 2), normalized [0, 1]
        manip = manip_all[:, :2]
        displacements = np.linalg.norm(manip - orig, axis=1)

        max_disp = float(displacements.max())
        mean_disp = float(displacements.mean())
        result.details["max_displacement"] = max_disp
        result.details["mean_displacement"] = mean_disp

        if max_disp > self.max_displacement_fraction:
            result.add_failure(
                "anatomical_magnitude",
                f"Maximum displacement {max_disp:.4f} exceeds threshold "
                f"{self.max_displacement_fraction}",
            )
        else:
            result.add_pass("anatomical_magnitude")

        if procedure:
            self._check_procedure_regions(result, orig, manip, displacements, procedure)

    def _check_procedure_regions(
        self,
        result: SafetyResult,
        orig: np.ndarray,
        manip: np.ndarray,
        displacements: np.ndarray,
        procedure: str,
    ) -> None:
        """Warn when displacement is concentrated outside the procedure's regions.

        Emits a warning (not a failure) when the mean displacement outside the
        expected regions is more than twice the mean inside them and exceeds
        0.005. Unknown procedures, unresolvable regions, or landmark sets that
        fall entirely inside or outside the expected regions count as a pass.
        """
        from landmarkdiff.landmarks import LANDMARK_REGIONS

        expected = self._EXPECTED_REGIONS.get(procedure.lower(), ())
        if not expected:
            result.add_pass("procedure_region")
            return

        expected_indices: set = set()
        for region in expected:
            if region in LANDMARK_REGIONS:
                expected_indices.update(LANDMARK_REGIONS[region])

        if not expected_indices:
            result.add_pass("procedure_region")
            return

        # True where the landmark index belongs to one of the expected regions.
        n = min(len(displacements), len(orig))
        expected_mask = np.isin(np.arange(n), list(expected_indices))

        # Comparing means needs landmarks on both sides of the split.
        if not expected_mask.any() or expected_mask.all():
            result.add_pass("procedure_region")
            return

        expected_disp = float(displacements[expected_mask].mean())
        unexpected_disp = float(displacements[~expected_mask].mean())
        result.details["expected_region_disp"] = expected_disp
        result.details["unexpected_region_disp"] = unexpected_disp

        # Expected regions should carry more displacement than the rest of the face.
        if unexpected_disp > expected_disp * 2 and unexpected_disp > 0.005:
            result.add_warning(
                "procedure_region",
                f"{procedure}: unexpected regions displaced more than expected "
                f"({unexpected_disp:.4f} vs {expected_disp:.4f})",
            )
        else:
            result.add_pass("procedure_region")

    def _check_output_quality(self, result: SafetyResult, output: np.ndarray) -> None:
        """Check output image quality (not blank, not corrupted).

        Fails on a missing/empty image or a mean pixel value below 5 or above
        250. Warns (but still passes) when the standard deviation is below 10.
        """
        if output is None or output.size == 0:
            result.add_failure("output_quality", "Output image is empty")
            return

        # Check for blank/black images
        mean_val = float(output.mean())
        if mean_val < 5:
            result.add_failure("output_quality", f"Output is nearly black (mean={mean_val:.1f})")
            return
        if mean_val > 250:
            result.add_failure("output_quality", f"Output is nearly white (mean={mean_val:.1f})")
            return

        # Very low variance suggests a flat, uniform image.
        std_val = float(output.std())
        if std_val < 10:
            result.add_warning(
                "output_quality",
                f"Output has very low variance (std={std_val:.1f}), may be uniform",
            )

        result.add_pass("output_quality")
        result.details["output_mean"] = mean_val
        result.details["output_std"] = std_val

    def _check_ood(self, result: SafetyResult, image: np.ndarray) -> None:
        """Basic out-of-distribution detection.

        Checks image properties against expected ranges for face photos and
        only ever emits warnings. A missing or empty input image is reported
        as a warning instead of raising.
        """
        if image is None or image.size == 0:
            result.add_warning("ood", "Input image is empty; OOD checks skipped")
            return

        h, w = image.shape[:2]

        # Resolution check
        if min(h, w) < 128:
            result.add_warning("ood", f"Image resolution too low: {w}x{h}")

        # Aspect ratio (faces should be roughly square after preprocessing)
        aspect = max(h, w) / max(min(h, w), 1)
        if aspect > 3.0:
            result.add_warning("ood", f"Unusual aspect ratio: {aspect:.1f}")

        # Color distribution: BGR channel order, face photos skew red over blue.
        if len(image.shape) == 3 and image.shape[2] == 3:
            mean_b, mean_g, mean_r = image.mean(axis=(0, 1))
            if mean_b > mean_r * 1.5:
                result.add_warning("ood", "Image appears very blue (not typical face photo)")

        result.add_pass("ood_basic")

    def apply_watermark(
        self,
        image: np.ndarray,
        text: str | None = None,
        opacity: float = 0.3,
    ) -> np.ndarray:
        """Apply a text watermark to the output image.

        Blends a black bar along the bottom edge with weight ``opacity``
        (clamped to [0, 1]) and draws ``text`` (default: ``watermark_text``)
        centred on it in white. Returns a modified copy. When watermarking is
        disabled, the input object itself is returned.
        """
        if not self.watermark_enabled:
            return image

        text = text or self.watermark_text
        opacity = min(max(float(opacity), 0.0), 1.0)
        result = image.copy()
        h, w = result.shape[:2]

        # Scale font and stroke with image width.
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = max(0.3, w / 1500)
        thickness = max(1, int(w / 500))

        text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
        x = (w - text_size[0]) // 2
        y = h - 10

        # Semi-transparent background bar behind the text.
        bar_y1 = y - text_size[1] - 10
        bar_y2 = h
        overlay = result.copy()
        cv2.rectangle(overlay, (0, bar_y1), (w, bar_y2), (0, 0, 0), -1)
        cv2.addWeighted(overlay, opacity, result, 1 - opacity, 0, result)

        # White text
        cv2.putText(result, text, (x, y), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)

        return result

    def embed_metadata(
        self,
        image_path: str,
        procedure: str,
        intensity: float,
        model_version: str = "0.3.0",
    ) -> None:
        """Write provenance metadata to a JSON sidecar next to the image.

        The sidecar path is ``image_path`` with its suffix replaced by
        ``.meta.json`` (e.g. ``out.png`` -> ``out.meta.json``). The image file
        itself is not modified.

        Args:
            image_path: Path of the generated image.
            procedure: Surgical procedure name.
            intensity: Manipulation intensity. NumPy scalars are converted to
                plain Python numbers so they serialize.
            model_version: Version string to record. NOTE(review): the default
                "0.3.0" does not match this package release (0.2.3); confirm.
        """
        import json
        from pathlib import Path

        # json cannot serialize NumPy scalars (e.g. np.float32); unwrap them.
        if isinstance(intensity, np.generic):
            intensity = intensity.item()

        meta = {
            "generator": "LandmarkDiff",
            "version": model_version,
            "procedure": procedure,
            "intensity": intensity,
            "disclaimer": "AI-generated surgical prediction for visualization only. "
            "Not a guarantee of surgical outcome.",
        }

        # Save as sidecar JSON (PNG doesn't have easy EXIF support)
        meta_path = Path(image_path).with_suffix(".meta.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Synthetic data generation for ControlNet fine-tuning.
|
|
2
|
+
|
|
3
|
+
Modules:
|
|
4
|
+
- pair_generator: Generate training pairs from face images
|
|
5
|
+
- augmentation: Clinical degradation augmentations
|
|
6
|
+
- tps_warp: TPS warping with rigid region preservation
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
|
|
10
|
+
from landmarkdiff.synthetic.pair_generator import (
|
|
11
|
+
TrainingPair,
|
|
12
|
+
generate_pair,
|
|
13
|
+
generate_pairs_from_directory,
|
|
14
|
+
)
|
|
15
|
+
from landmarkdiff.synthetic.tps_warp import warp_image_tps
|
|
16
|
+
|
|
17
|
+
# Public API of ``landmarkdiff.synthetic``: the names re-exported from the
# augmentation, pair_generator and tps_warp submodules by the imports above.
__all__ = [
    "TrainingPair",
    "apply_clinical_augmentation",
    "generate_pair",
    "generate_pairs_from_directory",
    "warp_image_tps",
]
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
"""Clinical degradation augmentations.
|
|
2
|
+
|
|
3
|
+
Degrades clean FFHQ/CelebA-HQ to match real clinical photo distribution.
|
|
4
|
+
Applied from day 1 - domain gap prevention, not afterthought.
|
|
5
|
+
3-5 random augmentations per sample.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from collections.abc import Callable
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
|
|
13
|
+
import cv2
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True)
class AugmentationConfig:
    """Configuration for a single augmentation.

    Immutable (``frozen=True``) record pairing an augmentation function with
    the probability that it is selected for a sample.
    """

    # Identifier for the augmentation (used for readability only in this module).
    name: str
    # Callable taking (image, rng) and returning the augmented image.
    fn: Callable[[np.ndarray, np.random.Generator], np.ndarray]
    # Selection probability; apply_clinical_augmentation picks this entry when
    # rng.random() < probability.
    probability: float
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Simulate point-source clinical lighting from a random direction.

    A light source is placed at a uniformly random pixel position. Brightness
    falls off linearly with distance from it, scaled by a random intensity in
    [0.3, 0.7), and the resulting gain is clipped to [0.3, 1.0].

    Args:
        image: uint8 image, either (H, W) or (H, W, C) with any channel count.
        rng: Random generator; draws light x, light y, then intensity.

    Returns:
        uint8 image with the same shape as ``image``.
    """
    h, w = image.shape[:2]

    # Random light source position and falloff strength
    lx = rng.uniform(0, w)
    ly = rng.uniform(0, h)
    intensity = rng.uniform(0.3, 0.7)

    # Distance-based falloff
    y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
    dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
    max_dist = np.sqrt(w**2 + h**2)
    light_map = np.clip(1.0 - (dist / max_dist) * intensity, 0.3, 1.0)

    # Broadcast the 2-D gain over however many channels the image has. The
    # previous fixed 3-channel stack turned grayscale input into 3 channels
    # and crashed on 4-channel input.
    if image.ndim == 3:
        light_map = light_map[..., np.newaxis]

    return np.clip(image.astype(np.float32) * light_map, 0, 255).astype(np.uint8)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def color_temperature_jitter(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Jitter color temperature +/- 2000K equivalent.

    Draws a shift in [-0.15, 0.15). A positive shift warms the image: red is
    scaled by (1 + |shift|) and blue by (1 - |shift| / 2). Otherwise the two
    gains are swapped to cool the image. Channel order is BGR.
    """
    shift = rng.uniform(-0.15, 0.15)
    magnitude = abs(shift)
    boost = 1 + magnitude
    cut = 1 - magnitude * 0.5

    warmer = shift > 0
    red_gain = boost if warmer else cut
    blue_gain = cut if warmer else boost

    tinted = image.astype(np.float32)
    tinted[:, :, 2] *= red_gain  # red (BGR)
    tinted[:, :, 0] *= blue_gain  # blue
    return np.clip(tinted, 0, 255).astype(np.uint8)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def green_fluorescent_cast(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Add green fluorescent lighting cast (common in clinical settings).

    With a strength drawn from [0.05, 0.15), the green channel is scaled up by
    (1 + strength), and blue and red are each scaled down by
    (1 - 0.3 * strength). Channel order is BGR.
    """
    strength = rng.uniform(0.05, 0.15)
    damp = 1 - strength * 0.3

    tinted = image.astype(np.float32)
    tinted[:, :, 1] *= 1 + strength  # green boost
    tinted[:, :, 0] *= damp  # blue reduction
    tinted[:, :, 2] *= damp  # red reduction
    return np.clip(tinted, 0, 255).astype(np.uint8)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def jpeg_compression(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Simulate JPEG compression artifacts (quality 40-84).

    Round-trips the image through an in-memory JPEG encode/decode at a random
    quality. ``int(rng.uniform(40, 85))`` yields 40..84 inclusive.

    Args:
        image: uint8 image, grayscale (H, W) or color (H, W, C).
        rng: Random generator used to draw the quality.

    Returns:
        The decoded image. Grayscale input stays grayscale. If encoding or
        decoding fails, an unmodified copy of ``image`` is returned.
    """
    quality = int(rng.uniform(40, 85))
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    ok, encoded = cv2.imencode(".jpg", image, encode_param)
    if not ok:
        # Best-effort augmentation: skip instead of crashing in imdecode.
        return image.copy()

    # IMREAD_COLOR would silently expand a grayscale input to 3 channels.
    flag = cv2.IMREAD_GRAYSCALE if image.ndim == 2 else cv2.IMREAD_COLOR
    decoded = cv2.imdecode(encoded, flag)
    return image.copy() if decoded is None else decoded
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Add Gaussian sensor noise (sigma 5-25).

    Draws sigma from [5, 25), then per-element zero-mean normal noise of that
    sigma with the image's shape. The sum is clipped to the uint8 range.
    """
    sigma = rng.uniform(5, 25)
    noisy = image.astype(np.float32)
    noisy += rng.normal(0, sigma, image.shape).astype(np.float32)
    return np.clip(noisy, 0, 255).astype(np.uint8)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Apply barrel/pincushion distortion simulating phone camera lens.

    Builds a pinhole camera with focal length max(w, h) centred on the image
    and a single random radial coefficient k1 in [-0.2, 0.2). The image is
    resampled through OpenCV's undistort/rectify map with reflected borders.
    """
    height, width = image.shape[:2]
    k1 = rng.uniform(-0.2, 0.2)

    focal = max(width, height)
    intrinsics = np.array(
        [[focal, 0, width / 2], [0, focal, height / 2], [0, 0, 1]],
        dtype=np.float64,
    )
    radial = np.array([k1, 0, 0, 0, 0], dtype=np.float64)

    map_x, map_y = cv2.initUndistortRectifyMap(
        intrinsics, radial, None, intrinsics, (width, height), cv2.CV_32FC1
    )
    return cv2.remap(image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Slight motion blur (common in handheld clinical photos).

    A horizontal line kernel of random length (3-6 px) is rotated by a random
    angle in [0, 180) degrees, normalized to sum to 1, and applied with
    ``cv2.filter2D``.

    Args:
        image: Input image.
        rng: Random generator; draws the kernel size, then the angle.

    Returns:
        The blurred image (``filter2D`` with ``ddepth=-1`` keeps the depth).
    """
    size = int(rng.uniform(3, 7))
    angle = rng.uniform(0, 180)

    kernel = np.zeros((size, size))
    kernel[size // 2, :] = 1.0 / size

    # cv2.filter2D's default anchor is (size // 2, size // 2), a pixel on the
    # line built above. Rotating about that pixel keeps the line through the
    # anchor, so the kernel blurs without translating the image. Rotating
    # about (size / 2, size / 2) instead is half a pixel off in OpenCV's pixel
    # coordinates. That moved the line off-centre and could push part of it
    # outside the kernel.
    center = float(size // 2)
    M = cv2.getRotationMatrix2D((center, center), angle, 1)
    kernel = cv2.warpAffine(kernel, M, (size, size))

    ksum = kernel.sum()
    if ksum > 0:
        kernel = kernel / ksum
    else:
        # Defensive fallback: never apply an all-zero kernel (it would black
        # out the image); use the identity kernel instead.
        kernel = np.zeros_like(kernel)
        kernel[size // 2, size // 2] = 1.0

    return cv2.filter2D(image, -1, kernel)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Add lens vignetting (darkened corners).

    The gain is 1 - strength * (d / d_max)^2, where d is the distance from the
    image centre, d_max is the centre-to-corner distance, and strength is
    drawn from [0.3, 0.7). The gain is clipped to [0.3, 1.0].

    Args:
        image: uint8 image, either (H, W) or (H, W, C) with any channel count.
        rng: Random generator used to draw the strength.

    Returns:
        uint8 image with the same shape as ``image``.
    """
    h, w = image.shape[:2]
    strength = rng.uniform(0.3, 0.7)

    y, x = np.mgrid[0:h, 0:w].astype(np.float32)
    cx, cy = w / 2, h / 2
    dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
    max_dist = np.sqrt(cx**2 + cy**2)

    mask = np.clip(1 - strength * (dist / max_dist) ** 2, 0.3, 1.0)

    # Broadcast over any channel count instead of hard-coding 3 channels.
    # The old hard-coded stack changed grayscale shape and crashed on BGRA.
    if image.ndim == 3:
        mask = mask[..., np.newaxis]

    return np.clip(image.astype(np.float32) * mask, 0, 255).astype(np.uint8)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# Augmentation pool with probabilities from the spec.
# apply_clinical_augmentation draws each entry independently with its own
# probability, then pads/trims the selection to the requested count.
AUGMENTATION_POOL: list[AugmentationConfig] = [
    AugmentationConfig(name="point_source_lighting", fn=point_source_lighting, probability=0.40),
    AugmentationConfig(name="color_temperature", fn=color_temperature_jitter, probability=0.60),
    AugmentationConfig(name="green_fluorescent", fn=green_fluorescent_cast, probability=0.25),
    AugmentationConfig(name="jpeg_compression", fn=jpeg_compression, probability=0.30),
    AugmentationConfig(name="sensor_noise", fn=gaussian_sensor_noise, probability=0.40),
    AugmentationConfig(name="barrel_distortion", fn=barrel_distortion, probability=0.30),
    AugmentationConfig(name="motion_blur", fn=motion_blur, probability=0.20),
    AugmentationConfig(name="vignette", fn=vignette, probability=0.25),
]
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def apply_clinical_augmentation(
    image: np.ndarray,
    min_augmentations: int = 3,
    max_augmentations: int = 5,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Apply random clinical degradation augmentations to an image.

    Each pool entry is selected with its own probability. If fewer than
    ``min_augmentations`` are chosen, randomly picked unselected entries fill
    the gap. If more than ``max_augmentations`` are chosen, a random subset is
    kept. The selection is applied in random order to a copy of ``image``.
    """
    if not rng:
        rng = np.random.default_rng()

    # One rng.random() draw per pool entry, in pool order.
    chosen = [cfg for cfg in AUGMENTATION_POOL if rng.random() < cfg.probability]

    shortfall = min_augmentations - len(chosen)
    if shortfall > 0:
        leftovers = [cfg for cfg in AUGMENTATION_POOL if cfg not in chosen]
        rng.shuffle(leftovers)
        chosen.extend(leftovers[:shortfall])

    if len(chosen) > max_augmentations:
        rng.shuffle(chosen)
        del chosen[max_augmentations:]

    # Randomize application order.
    rng.shuffle(chosen)
    augmented = image.copy()
    for cfg in chosen:
        augmented = cfg.fn(augmented, rng)
    return augmented
|