landmarkdiff-0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/safety.py ADDED
@@ -0,0 +1,395 @@
1
+ """Clinical safety validation for responsible deployment.
2
+
3
+ Implements safety checks for surgical outcome predictions:
4
+ 1. Identity preservation: verify output preserves patient identity
5
+ 2. Anatomical plausibility: check landmark displacements are realistic
6
+ 3. Out-of-distribution detection: flag unusual inputs
7
+ 4. Watermarking: mark AI-generated images
8
+ 5. Consent metadata: embed provenance information
9
+
10
+ Usage:
11
+ from landmarkdiff.safety import SafetyValidator
12
+
13
+ validator = SafetyValidator()
14
+ result = validator.validate(
15
+ input_image=image,
16
+ output_image=generated,
17
+ landmarks_original=face.landmarks,
18
+ landmarks_manipulated=manip.landmarks,
19
+ procedure="rhinoplasty",
20
+ )
21
+
22
+ if not result.passed:
23
+ print(f"Safety check failed: {result.failures}")
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ from dataclasses import dataclass, field
29
+
30
+ import cv2
31
+ import numpy as np
32
+
33
+
34
@dataclass
class SafetyResult:
    """Aggregated outcome of a safety validation run.

    The validator records results through ``add_failure``, ``add_warning``
    and ``add_pass``; ``summary`` renders them as a short text report.
    """

    # False as soon as any failure has been recorded.
    passed: bool = True
    # Failure messages, in the order they were recorded.
    failures: list[str] = field(default_factory=list)
    # Warning messages, in the order they were recorded.
    warnings: list[str] = field(default_factory=list)
    # Check name -> whether that check passed (warnings do not add entries).
    checks: dict[str, bool] = field(default_factory=dict)
    # Free-form measured values keyed by name.
    details: dict[str, object] = field(default_factory=dict)

    def __repr__(self) -> str:
        rendered = ", ".join(
            f"{label}={value}"
            for label, value in (
                ("passed", self.passed),
                ("failures", self.failures),
                ("warnings", self.warnings),
                ("checks", self.checks),
                ("details", self.details),
            )
        )
        return f"SafetyResult({rendered})"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, SafetyResult):
            mine = (self.passed, self.failures, self.warnings, self.checks, self.details)
            theirs = (other.passed, other.failures, other.warnings, other.checks, other.details)
            return mine == theirs
        return NotImplemented

    def add_failure(self, name: str, message: str) -> None:
        """Record a failed check: mark the result failed and store the message."""
        self.passed = False
        self.failures.append(message)
        self.checks[name] = False

    def add_warning(self, name: str, message: str) -> None:
        """Record a warning message; ``name`` is accepted but not stored."""
        self.warnings.append(message)

    def add_pass(self, name: str) -> None:
        """Record that the check ``name`` passed."""
        self.checks[name] = True

    def summary(self) -> str:
        """Return a multi-line report: overall verdict, each check, each warning."""
        header = "Safety: " + ("PASS" if self.passed else "FAIL")
        check_rows = [f" [{'OK' if ok else 'FAIL'}] {name}" for name, ok in self.checks.items()]
        warning_rows = [f" [WARN] {w}" for w in self.warnings]
        return "\n".join([header, *check_rows, *warning_rows])
82
+
83
+
84
class SafetyValidator:
    """Clinical safety validation for surgical predictions.

    ``validate`` runs the following checks and collects their outcome in a
    ``SafetyResult``:

    - face-detection confidence;
    - identity similarity;
    - landmark-displacement plausibility;
    - output sanity;
    - basic out-of-distribution heuristics.

    ``apply_watermark`` marks a prediction as AI-generated, and
    ``embed_metadata`` writes a provenance sidecar file.

    Args:
        identity_threshold: Minimum similarity returned by
            ``landmarkdiff.evaluation.compute_identity_similarity`` for the
            identity check to pass.
        max_displacement_fraction: Maximum allowed per-landmark displacement,
            in the same normalized units as the landmarks.
        min_face_confidence: Minimum face-detection confidence.
        max_yaw_degrees: Stored on the instance; no check in this class reads it.
        watermark_enabled: When False, ``apply_watermark`` returns its input unchanged.
        watermark_text: Default text drawn by ``apply_watermark``.
    """

    # Landmark regions expected to move for each known procedure. Procedures
    # not listed here skip the region-concentration check.
    _EXPECTED_REGIONS: dict[str, tuple[str, ...]] = {
        "rhinoplasty": ("nose",),
        "blepharoplasty": ("eye_left", "eye_right"),
        "rhytidectomy": ("jawline",),
        "orthognathic": ("jawline", "lips"),
    }

    def __init__(
        self,
        identity_threshold: float = 0.6,
        max_displacement_fraction: float = 0.05,
        min_face_confidence: float = 0.5,
        max_yaw_degrees: float = 45.0,
        watermark_enabled: bool = True,
        watermark_text: str = "AI-GENERATED PREDICTION",
    ):
        self.identity_threshold = identity_threshold
        self.max_displacement_fraction = max_displacement_fraction
        self.min_face_confidence = min_face_confidence
        self.max_yaw_degrees = max_yaw_degrees
        self.watermark_enabled = watermark_enabled
        self.watermark_text = watermark_text

    def validate(
        self,
        input_image: np.ndarray,
        output_image: np.ndarray,
        landmarks_original: np.ndarray | None = None,
        landmarks_manipulated: np.ndarray | None = None,
        procedure: str | None = None,
        face_confidence: float = 1.0,
    ) -> SafetyResult:
        """Run all safety checks on a prediction.

        Args:
            input_image: Original patient image (BGR, uint8).
            output_image: Generated prediction (BGR, uint8).
            landmarks_original: Original landmarks (N, 2-3), normalized [0, 1].
            landmarks_manipulated: Manipulated landmarks (N, 2-3), normalized [0, 1].
            procedure: Surgical procedure name.
            face_confidence: MediaPipe face detection confidence.

        Returns:
            SafetyResult with all check results. The anatomical check only
            runs when both landmark arrays are provided.
        """
        result = SafetyResult()

        # 1. Face detection confidence
        self._check_face_confidence(result, face_confidence)

        # 2. Identity preservation
        self._check_identity(result, input_image, output_image)

        # 3. Anatomical plausibility
        if landmarks_original is not None and landmarks_manipulated is not None:
            self._check_anatomical_plausibility(
                result, landmarks_original, landmarks_manipulated, procedure
            )

        # 4. Output quality
        self._check_output_quality(result, output_image)

        # 5. OOD detection (basic)
        self._check_ood(result, input_image)

        return result

    def _check_face_confidence(self, result: SafetyResult, confidence: float) -> None:
        """Fail when the face-detection confidence is below ``min_face_confidence``."""
        # Written as ``not (x >= t)`` rather than ``x < t`` so that a NaN
        # confidence fails instead of silently passing.
        if not confidence >= self.min_face_confidence:
            result.add_failure(
                "face_confidence",
                f"Face detection confidence {confidence:.2f} below threshold "
                f"{self.min_face_confidence}",
            )
        else:
            result.add_pass("face_confidence")
        result.details["face_confidence"] = confidence

    def _check_identity(
        self,
        result: SafetyResult,
        input_image: np.ndarray,
        output_image: np.ndarray,
    ) -> None:
        """Check identity preservation using ArcFace similarity.

        The check is best-effort. If the similarity cannot be computed (import
        error, missing model, unexpected return value), a warning is recorded
        and no ``identity`` check entry is written.
        """
        try:
            from landmarkdiff.evaluation import compute_identity_similarity

            sim = float(compute_identity_similarity(output_image, input_image))
        except Exception as e:  # deliberate: identity scoring is optional
            result.add_warning("identity", f"Identity check failed: {e}")
            return

        result.details["identity_similarity"] = sim
        # ``not (sim >= t)`` makes a NaN similarity fail rather than pass.
        if not sim >= self.identity_threshold:
            result.add_failure(
                "identity",
                f"Identity similarity {sim:.3f} below threshold {self.identity_threshold}",
            )
        else:
            result.add_pass("identity")

    def _check_anatomical_plausibility(
        self,
        result: SafetyResult,
        landmarks_orig: np.ndarray,
        landmarks_manip: np.ndarray,
        procedure: str | None,
    ) -> None:
        """Check that landmark displacements are anatomically plausible.

        Displacement is the Euclidean distance between corresponding landmarks
        using only the first two coordinates. Outcomes:

        - mismatched landmark counts fail ``anatomical``;
        - empty landmark sets produce a warning;
        - non-finite displacements fail ``anatomical_magnitude``;
        - otherwise the maximum displacement is compared with
          ``max_displacement_fraction``.
        """
        if len(landmarks_orig) != len(landmarks_manip):
            result.add_failure(
                "anatomical",
                f"Landmark count mismatch: {len(landmarks_orig)} vs {len(landmarks_manip)}",
            )
            return

        if len(landmarks_orig) == 0:
            # max()/mean() of an empty array would raise; nothing to assess.
            result.add_warning(
                "anatomical",
                "No landmarks provided; anatomical plausibility not checked",
            )
            return

        orig = np.asarray(landmarks_orig)[:, :2]
        manip = np.asarray(landmarks_manip)[:, :2]
        displacements = np.linalg.norm(manip - orig, axis=1)

        # NaN compares False against any threshold, so check explicitly.
        if not np.all(np.isfinite(displacements)):
            result.add_failure(
                "anatomical_magnitude",
                "Landmark displacements contain non-finite values (NaN/inf)",
            )
            return

        max_disp = float(displacements.max())
        mean_disp = float(displacements.mean())
        result.details["max_displacement"] = max_disp
        result.details["mean_displacement"] = mean_disp

        if max_disp > self.max_displacement_fraction:
            result.add_failure(
                "anatomical_magnitude",
                f"Maximum displacement {max_disp:.4f} exceeds threshold "
                f"{self.max_displacement_fraction}",
            )
        else:
            result.add_pass("anatomical_magnitude")

        if procedure:
            self._check_procedure_regions(result, orig, manip, displacements, procedure)

    def _check_procedure_regions(
        self,
        result: SafetyResult,
        orig: np.ndarray,
        manip: np.ndarray,
        displacements: np.ndarray,
        procedure: str,
    ) -> None:
        """Verify displacement is concentrated in expected anatomical regions.

        Expected region indices come from
        ``landmarkdiff.landmarks.LANDMARK_REGIONS``. The check compares the
        mean displacement inside the expected regions with the mean outside
        them. A warning (not a failure) is recorded when the outside mean is
        more than twice the inside mean and above 0.005. Every other outcome
        records a pass for ``procedure_region``.
        """
        from landmarkdiff.landmarks import LANDMARK_REGIONS

        expected = self._EXPECTED_REGIONS.get(procedure, ())
        if not expected:
            result.add_pass("procedure_region")
            return

        expected_indices = set()
        for region in expected:
            if region in LANDMARK_REGIONS:
                expected_indices.update(LANDMARK_REGIONS[region])

        if not expected_indices:
            result.add_pass("procedure_region")
            return

        n = min(len(displacements), len(orig))
        # Explicit bool dtype so ``~expected_mask`` is a logical NOT.
        expected_mask = np.array([i in expected_indices for i in range(n)], dtype=bool)

        if expected_mask.sum() > 0 and (~expected_mask).sum() > 0:
            expected_disp = displacements[expected_mask].mean()
            unexpected_disp = displacements[~expected_mask].mean()
            result.details["expected_region_disp"] = float(expected_disp)
            result.details["unexpected_region_disp"] = float(unexpected_disp)

            if unexpected_disp > expected_disp * 2 and unexpected_disp > 0.005:
                result.add_warning(
                    "procedure_region",
                    f"{procedure}: unexpected regions displaced more than expected "
                    f"({unexpected_disp:.4f} vs {expected_disp:.4f})",
                )
            else:
                result.add_pass("procedure_region")
        else:
            result.add_pass("procedure_region")

    def _check_output_quality(self, result: SafetyResult, output: np.ndarray) -> None:
        """Check output image quality (not blank, not corrupted).

        The check fails for:

        - a missing or empty output;
        - a non-finite mean;
        - a mean below 5 (nearly black) or above 250 (nearly white).

        A standard deviation below 10 only produces a warning.
        """
        if output is None or output.size == 0:
            result.add_failure("output_quality", "Output image is empty")
            return

        mean_val = float(output.mean())
        # A NaN mean would slip past both range comparisons below.
        if not np.isfinite(mean_val):
            result.add_failure("output_quality", "Output contains non-finite values")
            return
        if mean_val < 5:
            result.add_failure("output_quality", f"Output is nearly black (mean={mean_val:.1f})")
            return
        if mean_val > 250:
            result.add_failure("output_quality", f"Output is nearly white (mean={mean_val:.1f})")
            return

        std_val = float(output.std())
        if std_val < 10:
            result.add_warning(
                "output_quality",
                f"Output has very low variance (std={std_val:.1f}), may be uniform",
            )

        result.add_pass("output_quality")
        result.details["output_mean"] = mean_val
        result.details["output_std"] = std_val

    def _check_ood(self, result: SafetyResult, image: np.ndarray) -> None:
        """Basic out-of-distribution detection.

        Checks image properties against expected ranges for face photos. The
        heuristics are:

        - the shorter side is under 128 px;
        - the aspect ratio is above 3;
        - for 3-channel images, the mean of channel 0 exceeds 1.5x the mean
          of channel 2.

        They only add warnings; ``ood_basic`` is always recorded as passed.
        """
        h, w = image.shape[:2]

        if min(h, w) < 128:
            result.add_warning("ood", f"Image resolution too low: {w}x{h}")

        # Faces should be roughly square after preprocessing.
        aspect = max(h, w) / max(min(h, w), 1)
        if aspect > 3.0:
            result.add_warning("ood", f"Unusual aspect ratio: {aspect:.1f}")

        if len(image.shape) == 3 and image.shape[2] == 3:
            mean_b, mean_g, mean_r = image.mean(axis=(0, 1))
            # Face images typically have red channel > blue channel (BGR order).
            if mean_b > mean_r * 1.5:
                result.add_warning("ood", "Image appears very blue (not typical face photo)")

        result.add_pass("ood_basic")

    def apply_watermark(
        self,
        image: np.ndarray,
        text: str | None = None,
        opacity: float = 0.3,
    ) -> np.ndarray:
        """Apply a text watermark to the output image.

        On a copy of ``image``:

        - a black bar is blended across the bottom at ``opacity``;
        - ``text`` (default ``self.watermark_text``) is drawn in white on it,
          horizontally centred when it fits.

        Font scale and stroke thickness grow with image width.

        Returns:
            The watermarked copy. When ``watermark_enabled`` is False,
            ``image`` itself is returned (not a copy).
        """
        if not self.watermark_enabled:
            return image

        text = text or self.watermark_text
        result = image.copy()
        h, w = result.shape[:2]

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = max(0.3, w / 1500)
        thickness = max(1, int(w / 500))

        text_w, text_h = cv2.getTextSize(text, font, font_scale, thickness)[0]
        # When the text is wider than the image, start it at the left edge so
        # its beginning stays readable instead of being clipped on both sides.
        x = max(0, (w - text_w) // 2)
        y = h - 10

        bar_y1 = max(0, y - text_h - 10)
        overlay = result.copy()
        cv2.rectangle(overlay, (0, bar_y1), (w, h), (0, 0, 0), -1)
        # Outside the bar, overlay equals result, so the blend only darkens the bar.
        cv2.addWeighted(overlay, opacity, result, 1 - opacity, 0, result)

        cv2.putText(result, text, (x, y), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)

        return result

    def embed_metadata(
        self,
        image_path: str,
        procedure: str,
        intensity: float,
        model_version: str = "0.3.0",
    ) -> None:
        """Write provenance metadata for a generated image to a sidecar JSON file.

        Nothing is embedded in the image file itself. The metadata is written
        to the image path with its last suffix replaced by ``.meta.json``
        (``out.png`` -> ``out.meta.json``). An existing file at that path is
        overwritten.

        Args:
            image_path: Path of the generated image.
            procedure: Surgical procedure name, recorded as ``"procedure"``.
            intensity: Value recorded as ``"intensity"``.
            model_version: Version string recorded as ``"version"``.
        """
        import json
        from pathlib import Path

        # NOTE(review): the default "0.3.0" does not match the packaged release
        # (0.2.3); pass the real version explicitly for accurate provenance.
        meta = {
            "generator": "LandmarkDiff",
            "version": model_version,
            "procedure": procedure,
            "intensity": intensity,
            "disclaimer": "AI-generated surgical prediction for visualization only. "
            "Not a guarantee of surgical outcome.",
        }

        meta_path = Path(image_path).with_suffix(".meta.json")
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)
landmarkdiff/synthetic/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ """Synthetic data generation for ControlNet fine-tuning.
2
+
3
+ Modules:
4
+ - pair_generator: Generate training pairs from face images
5
+ - augmentation: Clinical degradation augmentations
6
+ - tps_warp: TPS warping with rigid region preservation
7
+ """
8
+
9
+ from landmarkdiff.synthetic.augmentation import apply_clinical_augmentation
10
+ from landmarkdiff.synthetic.pair_generator import (
11
+ TrainingPair,
12
+ generate_pair,
13
+ generate_pairs_from_directory,
14
+ )
15
+ from landmarkdiff.synthetic.tps_warp import warp_image_tps
16
+
17
# Public API of the ``landmarkdiff.synthetic`` package: the names imported
# above from the augmentation, pair_generator and tps_warp submodules.
__all__ = [
    "TrainingPair",
    "apply_clinical_augmentation",
    "generate_pair",
    "generate_pairs_from_directory",
    "warp_image_tps",
]
landmarkdiff/synthetic/augmentation.py ADDED
@@ -0,0 +1,188 @@
1
+ """Clinical degradation augmentations.
2
+
3
+ Degrades clean FFHQ/CelebA-HQ to match real clinical photo distribution.
4
+ Applied from day 1 - domain gap prevention, not afterthought.
5
+ 3-5 random augmentations per sample.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from collections.abc import Callable
11
+ from dataclasses import dataclass
12
+
13
+ import cv2
14
+ import numpy as np
15
+
16
+
17
@dataclass(frozen=True)
class AugmentationConfig:
    """Configuration for a single augmentation.

    Immutable (``frozen=True``) record pairing an augmentation callable with
    the probability that it is selected for a sample.
    """

    # Identifier for the augmentation (e.g. "jpeg_compression").
    name: str
    # Callable taking (image, rng) and returning the augmented image.
    fn: Callable[[np.ndarray, np.random.Generator], np.ndarray]
    # Per-sample selection probability: the selection step picks this entry
    # when ``rng.random() < probability``.
    probability: float
24
+
25
+
26
def point_source_lighting(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Simulate point-source clinical lighting from a random direction.

    A light position is drawn uniformly inside the frame. Brightness falls off
    linearly with distance from it, scaled by a random intensity in
    [0.3, 0.7). The per-pixel gain is clipped to [0.3, 1.0], so the image is
    only ever darkened.

    Args:
        image: ``(H, W)`` or ``(H, W, C)`` image; the gain is broadcast over
            all channels.
        rng: Random generator. It is consumed in order: light x, light y,
            intensity.

    Returns:
        The relit image as uint8, with the same shape as ``image``.
    """
    h, w = image.shape[:2]

    # Draw order kept stable so seeded runs reproduce earlier outputs.
    lx = rng.uniform(0, w)
    ly = rng.uniform(0, h)
    intensity = rng.uniform(0.3, 0.7)

    # Distance-based falloff, normalized by the image diagonal.
    y_grid, x_grid = np.mgrid[0:h, 0:w].astype(np.float32)
    dist = np.sqrt((x_grid - lx) ** 2 + (y_grid - ly) ** 2)
    max_dist = np.sqrt(w**2 + h**2)
    light_map = np.clip(1.0 - (dist / max_dist) * intensity, 0.3, 1.0)

    # Broadcast one gain map over every channel instead of stacking exactly
    # three copies. This also supports grayscale and 4-channel input.
    if image.ndim == 3:
        light_map = light_map[..., np.newaxis]

    return np.clip(image.astype(np.float32) * light_map, 0, 255).astype(np.uint8)
45
+
46
+
47
def color_temperature_jitter(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Jitter color temperature +/- 2000K equivalent.

    A shift is drawn from [-0.15, 0.15) and applied to the red and blue
    channels of a BGR image:

    - positive shift (warmer): red gains the shift and blue loses half of it;
    - otherwise (cooler): blue gains the shift's magnitude and red loses half.
    """
    shift = rng.uniform(-0.15, 0.15)
    amount = abs(shift)
    warmer = shift > 0

    red_gain = 1 + amount if warmer else 1 - amount * 0.5
    blue_gain = 1 - amount * 0.5 if warmer else 1 + amount

    tinted = image.astype(np.float32)
    tinted[:, :, 0] *= blue_gain  # B in BGR order
    tinted[:, :, 2] *= red_gain  # R in BGR order
    return np.clip(tinted, 0, 255).astype(np.uint8)
62
+
63
+
64
def green_fluorescent_cast(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Add green fluorescent lighting cast (common in clinical settings).

    A strength ``k`` is drawn from [0.05, 0.15). The green channel is scaled
    by ``1 + k``, and blue and red (BGR order) are each scaled by
    ``1 - 0.3 * k``.
    """
    boost = rng.uniform(0.05, 0.15)
    cast = image.astype(np.float32)
    per_channel_gain = (
        (1, 1 + boost),  # green up
        (0, 1 - boost * 0.3),  # blue slightly down
        (2, 1 - boost * 0.3),  # red slightly down
    )
    for channel, gain in per_channel_gain:
        cast[:, :, channel] *= gain
    return np.clip(cast, 0, 255).astype(np.uint8)
72
+
73
+
74
def jpeg_compression(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Simulate JPEG compression artifacts.

    The image is round-tripped through an in-memory JPEG at quality
    ``int(rng.uniform(40, 85))``, i.e. 40-84. The decoded image keeps the
    channel layout that was encoded, so grayscale stays grayscale.

    If OpenCV reports an encoding failure or cannot decode the buffer, a copy
    of the input is returned unchanged. One failed augmentation therefore
    does not hand ``None`` to the next step of the pipeline.
    """
    quality = int(rng.uniform(40, 85))
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    ok, encoded = cv2.imencode(".jpg", image, encode_param)
    if not ok:
        return image.copy()

    # IMREAD_UNCHANGED keeps the encoded channel count; IMREAD_COLOR would
    # silently expand grayscale input to three channels.
    decoded = cv2.imdecode(encoded, cv2.IMREAD_UNCHANGED)
    if decoded is None:
        return image.copy()
    if decoded.ndim == 2 and image.ndim == 3:
        # (H, W, 1) input decodes as (H, W); restore the channel axis.
        decoded = decoded[..., np.newaxis]
    return decoded
80
+
81
+
82
def gaussian_sensor_noise(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Add Gaussian sensor noise (sigma 5-25).

    Sigma is drawn from [5, 25). Zero-mean noise with that sigma is sampled
    once per element, added in float32, and the sum is clipped to [0, 255]
    before being cast back to uint8.
    """
    sigma = rng.uniform(5, 25)
    noisy = image.astype(np.float32)
    noisy += rng.normal(0, sigma, image.shape).astype(np.float32)
    return np.clip(noisy, 0, 255).astype(np.uint8)
87
+
88
+
89
def barrel_distortion(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Apply barrel/pincushion distortion simulating phone camera lens.

    A single radial coefficient ``k1`` is drawn from [-0.2, 0.2); its sign
    selects barrel vs. pincushion. The pinhole camera model has its focal
    length equal to the longer image side and its principal point at the
    image centre. The image is resampled bilinearly, with reflected borders.
    """
    height, width = image.shape[:2]
    k1 = rng.uniform(-0.2, 0.2)

    focal = max(width, height)
    intrinsics = np.array(
        [
            [focal, 0, width / 2],
            [0, focal, height / 2],
            [0, 0, 1],
        ],
        dtype=np.float64,
    )
    # Only the first radial term is non-zero (k1, k2, p1, p2, k3 layout).
    distortion = np.zeros(5, dtype=np.float64)
    distortion[0] = k1

    map_x, map_y = cv2.initUndistortRectifyMap(
        intrinsics, distortion, None, intrinsics, (width, height), cv2.CV_32FC1
    )
    return cv2.remap(image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
104
+
105
+
106
def motion_blur(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Slight motion blur (common in handheld clinical photos).

    A normalized horizontal line kernel is rotated by a random angle in
    [0, 180) degrees and convolved with the image. The kernel size is drawn
    as ``int(rng.uniform(3, 7))`` and rounded up to odd, giving 3, 5 or 7.

    The kernel has an odd size and is rotated about its centre pixel
    ``(size // 2, size // 2)``. Its centroid therefore coincides with the
    default ``filter2D`` anchor, so the blur does not translate the image
    (which would move content relative to landmarks computed beforehand).
    """
    size = int(rng.uniform(3, 7))
    if size % 2 == 0:
        size += 1  # even kernels have no centre pixel and shift the image
    angle = rng.uniform(0, 180)

    center = size // 2
    kernel = np.zeros((size, size))
    kernel[center, :] = 1.0 / size

    # Pixel i sits at coordinate i, so the centre pixel is at (center, center),
    # not (size / 2, size / 2).
    rotation = cv2.getRotationMatrix2D((float(center), float(center)), angle, 1)
    kernel = cv2.warpAffine(kernel, rotation, (size, size))
    ksum = kernel.sum()
    if ksum > 0:
        kernel = kernel / ksum
    else:
        # Defensive: fall back to an identity kernel if rotation lost all mass.
        kernel = np.zeros_like(kernel)
        kernel[center, center] = 1.0

    return cv2.filter2D(image, -1, kernel)
125
+
126
+
127
def vignette(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Add lens vignetting (darkened corners).

    The gain is ``1 - strength * (d / d_max) ** 2``, clipped to [0.3, 1.0], where:

    - ``strength`` is drawn from [0.3, 0.7);
    - ``d`` is the distance from the image centre;
    - ``d_max`` is the centre-to-corner distance.

    Args:
        image: ``(H, W)`` or ``(H, W, C)`` image; the gain is broadcast over
            all channels.
        rng: Random generator (one draw, for ``strength``).

    Returns:
        The vignetted image as uint8, with the same shape as ``image``.
    """
    h, w = image.shape[:2]
    strength = rng.uniform(0.3, 0.7)

    y, x = np.mgrid[0:h, 0:w].astype(np.float32)
    cx, cy = w / 2, h / 2
    dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
    max_dist = np.sqrt(cx**2 + cy**2)

    mask = np.clip(1 - strength * (dist / max_dist) ** 2, 0.3, 1.0)
    # Broadcast over channels instead of stacking exactly three copies, so
    # grayscale and 4-channel images work too.
    if image.ndim == 3:
        mask = mask[..., np.newaxis]

    return np.clip(image.astype(np.float32) * mask, 0, 255).astype(np.uint8)
142
+
143
+
144
# Augmentation pool with per-sample selection probabilities from the spec.
# Order matters: apply_clinical_augmentation draws one rng.random() per entry
# in this order, so reordering changes seeded results.
AUGMENTATION_POOL: list[AugmentationConfig] = [
    AugmentationConfig(
        name="point_source_lighting", fn=point_source_lighting, probability=0.40
    ),
    AugmentationConfig(
        name="color_temperature", fn=color_temperature_jitter, probability=0.60
    ),
    AugmentationConfig(
        name="green_fluorescent", fn=green_fluorescent_cast, probability=0.25
    ),
    AugmentationConfig(name="jpeg_compression", fn=jpeg_compression, probability=0.30),
    AugmentationConfig(name="sensor_noise", fn=gaussian_sensor_noise, probability=0.40),
    AugmentationConfig(name="barrel_distortion", fn=barrel_distortion, probability=0.30),
    AugmentationConfig(name="motion_blur", fn=motion_blur, probability=0.20),
    AugmentationConfig(name="vignette", fn=vignette, probability=0.25),
]
155
+
156
+
157
def apply_clinical_augmentation(
    image: np.ndarray,
    min_augmentations: int = 3,
    max_augmentations: int = 5,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Apply random clinical degradation augmentations to an image.

    Selection happens in three steps:

    1. Each entry of ``AUGMENTATION_POOL`` is selected independently with its
       own probability.
    2. The selection is topped up with randomly chosen unselected entries
       until it reaches ``min_augmentations``, as far as the pool allows.
    3. It is then trimmed at random to ``max_augmentations``.

    The chosen augmentations run in random order on a copy of ``image``.

    Args:
        image: Input image; it is not modified.
        min_augmentations: Lower bound on the number of augmentations applied
            (capped by the pool size).
        max_augmentations: Upper bound on the number of augmentations applied.
        rng: Random generator; a fresh ``np.random.default_rng()`` when None.

    Returns:
        The augmented image.

    Raises:
        ValueError: If either bound is negative or ``min_augmentations``
            exceeds ``max_augmentations``.
    """
    # Negative bounds would otherwise turn the trim below into a nonsensical
    # slice such as ``selected[:-1]``.
    if min_augmentations < 0 or max_augmentations < 0:
        raise ValueError(
            "min_augmentations and max_augmentations must be non-negative, got "
            f"{min_augmentations} and {max_augmentations}"
        )
    if min_augmentations > max_augmentations:
        raise ValueError(
            f"min_augmentations ({min_augmentations}) must not exceed "
            f"max_augmentations ({max_augmentations})"
        )
    if rng is None:
        rng = np.random.default_rng()

    # One rng.random() draw per pool entry, in pool order.
    selected = [aug for aug in AUGMENTATION_POOL if rng.random() < aug.probability]

    # Top up with randomly chosen augmentations that were not selected.
    if len(selected) < min_augmentations:
        remaining = [aug for aug in AUGMENTATION_POOL if aug not in selected]
        rng.shuffle(remaining)
        selected.extend(remaining[: min_augmentations - len(selected)])

    # Trim at random to the upper bound.
    if len(selected) > max_augmentations:
        rng.shuffle(selected)
        selected = selected[:max_augmentations]

    # Apply in random order.
    rng.shuffle(selected)
    result = image.copy()
    for aug in selected:
        result = aug.fn(result, rng)

    return result