landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/audit.py ADDED
@@ -0,0 +1,338 @@
1
+ """Clinical audit report generator for regulatory compliance.
2
+
3
+ Generates structured HTML reports summarizing safety validation results,
4
+ model performance, and Fitzpatrick equity analysis for clinical review.
5
+
6
+ Reports include:
7
+ - Safety validation pass/fail summary per patient
8
+ - Aggregate statistics by procedure and Fitzpatrick type
9
+ - Flagged cases for manual review
10
+ - Model version and configuration provenance
11
+
12
+ Usage:
13
+ from landmarkdiff.audit import AuditReporter, AuditCase
14
+
15
+ reporter = AuditReporter(model_version="0.3.2")
16
+ reporter.add_case(AuditCase(
17
+ case_id="P001",
18
+ procedure="rhinoplasty",
19
+ safety_passed=True,
20
+ identity_sim=0.87,
21
+ fitzpatrick_type="III",
22
+ ))
23
+ reporter.generate_report("audit_report.html")
24
+ """
25
+
26
+ from __future__ import annotations
27
+
28
+ import json
29
+ from dataclasses import dataclass, field
30
+ from datetime import datetime, timezone
31
+ from pathlib import Path
32
+ from typing import Any
33
+
34
+
35
@dataclass
class AuditCase:
    """One patient case to be included in an audit report.

    ``case_id``, ``procedure`` and ``safety_passed`` are required; all other
    fields fall back to the defaults below.
    """

    # --- required identification / outcome ---
    case_id: str
    procedure: str
    safety_passed: bool

    # --- optional scores and labels ---
    identity_sim: float = 0.0
    intensity: float = 65.0
    fitzpatrick_type: str = ""

    # --- per-case findings (fresh containers per instance) ---
    warnings: list[str] = field(default_factory=list)
    failures: list[str] = field(default_factory=list)
    metrics: dict[str, float] = field(default_factory=dict)

    # ISO-8601 creation time, taken in UTC when the case object is built.
    timestamp: str = field(
        default_factory=lambda: datetime.now(tz=timezone.utc).isoformat()
    )
49
+
50
+
51
@dataclass
class AuditSummary:
    """Aggregate numbers describing a whole audit run.

    A default-constructed instance (all zeros, empty breakdowns) represents
    an audit with no cases.
    """

    # --- case counts ---
    total_cases: int = 0
    passed_cases: int = 0
    failed_cases: int = 0
    flagged_cases: int = 0

    # --- overall rates ---
    pass_rate: float = 0.0
    mean_identity_sim: float = 0.0

    # --- breakdowns keyed by group name (fresh dict per instance) ---
    by_procedure: dict[str, dict[str, Any]] = field(default_factory=dict)
    by_fitzpatrick: dict[str, dict[str, Any]] = field(default_factory=dict)
63
+
64
+
65
class AuditReporter:
    """Generate clinical audit reports from safety validation results.

    Cases are accumulated with :meth:`add_case` / :meth:`add_cases` and can be
    exported as JSON (:meth:`to_json`) or as a standalone HTML document
    (:meth:`generate_report`).

    Args:
        model_version: Model version string for provenance.
        report_title: Title for generated reports.
    """

    # Pass-rate thresholds and the colours used to render them in the tables.
    _GOOD_RATE = 0.95
    _WARN_RATE = 0.8
    _GREEN = "#28a745"
    _AMBER = "#ffc107"
    _RED = "#dc3545"

    def __init__(
        self,
        model_version: str = "0.3.2",
        report_title: str = "LandmarkDiff Clinical Audit Report",
    ) -> None:
        self.model_version = model_version
        self.report_title = report_title
        self.cases: list[AuditCase] = []

    def add_case(self, case: AuditCase) -> None:
        """Add a case to the audit report."""
        self.cases.append(case)

    def add_cases(self, cases: list[AuditCase]) -> None:
        """Add multiple cases."""
        self.cases.extend(cases)

    def clear(self) -> None:
        """Clear all cases."""
        self.cases.clear()

    @staticmethod
    def _group_stats(cases: list[AuditCase], key: Any) -> dict[str, dict[str, Any]]:
        """Aggregate cases into per-group statistics.

        Args:
            cases: Cases to aggregate.
            key: Callable mapping a case to its group name.

        Returns:
            Mapping of group name to a dict with the keys ``total``,
            ``passed``, ``pass_rate`` and ``mean_identity_sim`` (in that
            order). Only ``identity_sim`` values greater than zero contribute
            to the mean; a group without any yields ``0.0``.
        """
        buckets: dict[str, dict[str, Any]] = {}
        sims: dict[str, list[float]] = {}
        for case in cases:
            name = key(case)
            bucket = buckets.setdefault(name, {"total": 0, "passed": 0})
            bucket["total"] += 1
            if case.safety_passed:
                bucket["passed"] += 1
            if case.identity_sim > 0:
                sims.setdefault(name, []).append(case.identity_sim)

        for name, bucket in buckets.items():
            values = sims.get(name, [])
            # Every bucket holds at least one case, so ``total`` is never 0.
            bucket["pass_rate"] = bucket["passed"] / bucket["total"]
            bucket["mean_identity_sim"] = sum(values) / len(values) if values else 0.0
        return buckets

    def compute_summary(self) -> AuditSummary:
        """Compute aggregate statistics from all cases.

        Returns:
            An ``AuditSummary``; a default (all-zero) one when no cases have
            been added.
        """
        if not self.cases:
            return AuditSummary()

        total = len(self.cases)
        passed = sum(1 for c in self.cases if c.safety_passed)

        # Cases with a non-positive similarity are excluded from the mean.
        id_sims = [c.identity_sim for c in self.cases if c.identity_sim > 0]

        return AuditSummary(
            total_cases=total,
            passed_cases=passed,
            failed_cases=total - passed,
            flagged_cases=len(self.flagged_cases()),
            pass_rate=passed / total,
            mean_identity_sim=sum(id_sims) / len(id_sims) if id_sims else 0.0,
            by_procedure=self._group_stats(self.cases, lambda c: c.procedure),
            by_fitzpatrick=self._group_stats(
                self.cases, lambda c: c.fitzpatrick_type or "Unknown"
            ),
        )

    def flagged_cases(self) -> list[AuditCase]:
        """Return cases that need manual review (failed or have warnings)."""
        return [c for c in self.cases if not c.safety_passed or c.warnings]

    @staticmethod
    def _rounded(stats: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``stats`` with float values rounded to 4 places."""
        return {k: round(v, 4) if isinstance(v, float) else v for k, v in stats.items()}

    def to_json(self) -> str:
        """Export audit data as JSON.

        Returns:
            An indented JSON document with report metadata, the summary, the
            per-procedure and per-Fitzpatrick breakdowns, and every case.
        """
        summary = self.compute_summary()
        data = {
            "report_title": self.report_title,
            "model_version": self.model_version,
            "generated_at": datetime.now(timezone.utc).isoformat(),
            "summary": {
                "total_cases": summary.total_cases,
                "passed_cases": summary.passed_cases,
                "failed_cases": summary.failed_cases,
                "flagged_cases": summary.flagged_cases,
                "pass_rate": round(summary.pass_rate, 4),
                "mean_identity_sim": round(summary.mean_identity_sim, 4),
            },
            "by_procedure": {
                name: self._rounded(stats) for name, stats in summary.by_procedure.items()
            },
            "by_fitzpatrick": {
                name: self._rounded(stats) for name, stats in summary.by_fitzpatrick.items()
            },
            "cases": [
                {
                    "case_id": c.case_id,
                    "procedure": c.procedure,
                    "safety_passed": c.safety_passed,
                    "identity_sim": round(c.identity_sim, 4),
                    "intensity": c.intensity,
                    "fitzpatrick_type": c.fitzpatrick_type,
                    "warnings": c.warnings,
                    "failures": c.failures,
                    "metrics": {k: round(v, 4) for k, v in c.metrics.items()},
                    "timestamp": c.timestamp,
                }
                for c in self.cases
            ],
        }
        return json.dumps(data, indent=2)

    def generate_report(self, output_path: str | Path) -> Path:
        """Generate an HTML audit report.

        Missing parent directories of ``output_path`` are created.

        Args:
            output_path: Path to save the HTML report.

        Returns:
            Path to the generated report.
        """
        destination = Path(output_path)
        destination.parent.mkdir(parents=True, exist_ok=True)

        document = self._render_html(self.compute_summary())

        # The document declares <meta charset="utf-8">, so it must be written
        # as UTF-8 regardless of the platform's locale encoding.
        destination.write_text(document, encoding="utf-8")
        return destination

    @classmethod
    def _rate_color(cls, rate: float) -> str:
        """Map a pass rate to the colour used for it in the tables."""
        if rate >= cls._GOOD_RATE:
            return cls._GREEN
        if rate >= cls._WARN_RATE:
            return cls._AMBER
        return cls._RED

    @classmethod
    def _stats_rows(cls, groups: dict[str, dict[str, Any]], title_case: bool) -> str:
        """Render one ``<tr>`` per group, sorted by group name.

        Args:
            groups: Per-group statistics as produced by ``_group_stats``.
            title_case: Whether to title-case the group name for display.
        """
        from html import escape

        rows = []
        for name, stats in sorted(groups.items()):
            label = escape(name.title() if title_case else name)
            rate = stats["pass_rate"]
            rows.append(
                "<tr>"
                f"<td>{label}</td>"
                f"<td>{stats['total']}</td>"
                f"<td>{stats['passed']}</td>"
                f'<td style="color:{cls._rate_color(rate)};font-weight:bold">{rate:.1%}</td>'
                f"<td>{stats['mean_identity_sim']:.4f}</td>"
                "</tr>\n"
            )
        return "".join(rows)

    def _render_html(self, summary: AuditSummary) -> str:
        """Render the audit report as HTML.

        All caller-supplied text (title, model version, case ids, procedure
        and Fitzpatrick labels, warnings and failures) is HTML-escaped before
        being placed in the document.

        Args:
            summary: Aggregate statistics to display; the flagged-case table
                is built from ``self.cases``.

        Returns:
            The complete HTML document as a string.
        """
        from html import escape

        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
        all_passed = summary.failed_cases == 0
        status = "PASS" if all_passed else "REQUIRES REVIEW"
        status_color = self._GREEN if all_passed else self._RED

        title = escape(self.report_title)
        version = escape(str(self.model_version))

        proc_rows = self._stats_rows(summary.by_procedure, title_case=True)
        fitz_rows = self._stats_rows(summary.by_fitzpatrick, title_case=False)

        # Flagged cases: failures first, then warnings prefixed with "WARN: ".
        flagged_parts = []
        for c in self.flagged_cases():
            issues = "; ".join(c.failures + [f"WARN: {w}" for w in c.warnings])
            bg = "#fff3cd" if c.safety_passed else "#f8d7da"
            flagged_parts.append(
                f'<tr style="background:{bg}">'
                f"<td>{escape(str(c.case_id))}</td>"
                f"<td>{escape(c.procedure.title())}</td>"
                f"<td>{escape(c.fitzpatrick_type)}</td>"
                f"<td>{c.identity_sim:.4f}</td>"
                f"<td>{'WARN' if c.safety_passed else 'FAIL'}</td>"
                f"<td>{escape(issues)}</td>"
                "</tr>\n"
            )
        flagged_rows = "".join(flagged_parts)

        if flagged_rows:
            flagged_heading = "<h2>Flagged Cases (Require Review)</h2>"
            flagged_body = (
                "<table><tr><th>Case ID</th><th>Procedure</th><th>Fitzpatrick</th>"
                "<th>ID Sim</th><th>Status</th><th>Issues</th></tr>"
                + flagged_rows
                + "</table>"
            )
        else:
            flagged_heading = ""
            flagged_body = "<p>No flagged cases.</p>"

        return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>{title}</title>
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
       max-width: 1100px; margin: 0 auto; padding: 20px; color: #333; }}
h1 {{ border-bottom: 3px solid #333; padding-bottom: 10px; }}
h2 {{ color: #555; margin-top: 30px; border-bottom: 1px solid #ddd; padding-bottom: 5px; }}
table {{ border-collapse: collapse; width: 100%; margin: 15px 0; }}
th, td {{ border: 1px solid #ddd; padding: 8px 12px; text-align: left; }}
th {{ background: #f8f9fa; font-weight: 600; }}
tr:hover {{ background: #f5f5f5; }}
.status {{ display: inline-block; padding: 4px 12px; border-radius: 4px;
          color: white; font-weight: bold; font-size: 18px; }}
.summary-grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
                gap: 15px; margin: 20px 0; }}
.summary-card {{ background: #f8f9fa; border-radius: 8px; padding: 15px; text-align: center; }}
.summary-card .value {{ font-size: 28px; font-weight: bold; color: #333; }}
.summary-card .label {{ font-size: 12px; color: #888; text-transform: uppercase; }}
.disclaimer {{ background: #fff3cd; border: 1px solid #ffc107; border-radius: 4px;
              padding: 12px; margin: 20px 0; font-size: 13px; }}
footer {{ margin-top: 40px; padding-top: 15px; border-top: 1px solid #ddd;
         font-size: 12px; color: #999; }}
</style>
</head>
<body>
<h1>{title}</h1>
<p>Generated: {now} &nbsp;|&nbsp; Model version: <code>{version}</code>
&nbsp;|&nbsp; Overall status: <span class="status" style="background:{status_color}">{status}</span></p>

<div class="disclaimer">
<strong>Disclaimer:</strong> This report is for research and development purposes only.
LandmarkDiff predictions are AI-generated visualizations and do not constitute medical advice
or guarantee surgical outcomes. All predictions should be reviewed by qualified clinical professionals.
</div>

<h2>Summary</h2>
<div class="summary-grid">
<div class="summary-card"><div class="value">{summary.total_cases}</div><div class="label">Total Cases</div></div>
<div class="summary-card"><div class="value" style="color:#28a745">{summary.passed_cases}</div><div class="label">Passed</div></div>
<div class="summary-card"><div class="value" style="color:#dc3545">{summary.failed_cases}</div><div class="label">Failed</div></div>
<div class="summary-card"><div class="value" style="color:#ffc107">{summary.flagged_cases}</div><div class="label">Flagged</div></div>
<div class="summary-card"><div class="value">{summary.pass_rate:.1%}</div><div class="label">Pass Rate</div></div>
<div class="summary-card"><div class="value">{summary.mean_identity_sim:.4f}</div><div class="label">Mean ID Sim</div></div>
</div>

<h2>Performance by Procedure</h2>
<table>
<tr><th>Procedure</th><th>Total</th><th>Passed</th><th>Pass Rate</th><th>Mean ID Sim</th></tr>
{proc_rows}</table>

<h2>Equity Analysis by Fitzpatrick Type</h2>
<table>
<tr><th>Fitzpatrick Type</th><th>Total</th><th>Passed</th><th>Pass Rate</th><th>Mean ID Sim</th></tr>
{fitz_rows}</table>

{flagged_heading}
{flagged_body}

<footer>
LandmarkDiff v{version} &mdash; Clinical Audit Report &mdash;
For research use only. Not FDA approved.
</footer>
</body>
</html>"""
@@ -0,0 +1,293 @@
1
+ """Training data augmentation pipeline for LandmarkDiff.
2
+
3
+ Provides domain-specific augmentations that maintain landmark consistency:
4
+ - Geometric: flip, rotation, affine (landmarks co-transformed)
5
+ - Photometric: color jitter, brightness, contrast (applied to images only)
6
+ - Skin-tone augmentation: ITA-space perturbation for Fitzpatrick balance
7
+ - Conditioning augmentation: noise injection, dropout for robustness
8
+
9
+ All augmentations preserve the correspondence between:
10
+ input_image ↔ conditioning_image ↔ target_image ↔ mask
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from dataclasses import dataclass
16
+
17
+ import cv2
18
+ import numpy as np
19
+
20
+
21
@dataclass
class AugmentationConfig:
    """Tunable parameters for the training augmentation pipeline."""

    # ---- Geometric ----
    random_flip: bool = True
    random_rotation_deg: float = 5.0
    random_scale: tuple[float, float] = (0.95, 1.05)
    # Maximum shift expressed as a fraction of the image size.
    random_translate: float = 0.02

    # ---- Photometric (images only, not conditioning) ----
    brightness_range: tuple[float, float] = (0.9, 1.1)
    contrast_range: tuple[float, float] = (0.9, 1.1)
    saturation_range: tuple[float, float] = (0.9, 1.1)
    # Maximum hue shift in degrees.
    hue_shift_range: float = 5.0

    # ---- Conditioning augmentation ----
    conditioning_dropout_prob: float = 0.1
    conditioning_noise_std: float = 0.02

    # ---- Skin-tone augmentation ----
    # Standard deviation of the ITA angle noise.
    ita_perturbation_std: float = 3.0

    # Seed for the default random generator; ``None`` leaves it unseeded.
    seed: int | None = None
45
+
46
+
47
def augment_training_sample(
    input_image: np.ndarray,
    target_image: np.ndarray,
    conditioning: np.ndarray,
    mask: np.ndarray,
    landmarks_src: np.ndarray | None = None,
    landmarks_dst: np.ndarray | None = None,
    config: AugmentationConfig | None = None,
    rng: np.random.Generator | None = None,
) -> dict[str, np.ndarray]:
    """Apply consistent augmentations to a training sample.

    All spatial transforms are applied to images AND landmarks together
    so correspondence is preserved. Photometric changes use the same random
    factors for the input and target images and leave the conditioning image
    and mask untouched. The inputs themselves are not modified.

    Args:
        input_image: (H, W, 3) original face image (uint8 BGR).
        target_image: (H, W, 3) target face image (uint8 BGR).
        conditioning: (H, W, 3) conditioning image (uint8).
        mask: (H, W) or (H, W, 1) float32 mask.
        landmarks_src: (N, 2) normalized [0,1] source landmark coords.
        landmarks_dst: (N, 2) normalized [0,1] target landmark coords.
        config: Augmentation parameters. Defaults to ``AugmentationConfig()``.
        rng: Random generator for reproducibility. Defaults to a generator
            seeded with ``config.seed``.

    Returns:
        Dict with the keys ``input_image``, ``target_image``,
        ``conditioning`` and ``mask``, plus ``landmarks_src`` /
        ``landmarks_dst`` when the corresponding input was given.
    """
    if config is None:
        config = AugmentationConfig()
    if rng is None:
        rng = np.random.default_rng(config.seed)

    h, w = input_image.shape[:2]
    out_input = input_image.copy()
    out_target = target_image.copy()
    out_cond = conditioning.copy()
    out_mask = mask.copy()
    out_lm_src = landmarks_src.copy() if landmarks_src is not None else None
    out_lm_dst = landmarks_dst.copy() if landmarks_dst is not None else None

    # --- Geometric augmentations (applied to all) ---

    # Random horizontal flip
    if config.random_flip and rng.random() < 0.5:
        out_input = np.ascontiguousarray(out_input[:, ::-1])
        out_target = np.ascontiguousarray(out_target[:, ::-1])
        out_cond = np.ascontiguousarray(out_cond[:, ::-1])
        # Reversing axis 1 works for both (H, W) and (H, W, 1) masks.
        out_mask = np.ascontiguousarray(out_mask[:, ::-1])
        if out_lm_src is not None:
            out_lm_src[:, 0] = 1.0 - out_lm_src[:, 0]
        if out_lm_dst is not None:
            out_lm_dst[:, 0] = 1.0 - out_lm_dst[:, 0]

    # Random rotation + scale + translate. Translation alone must also
    # trigger the warp; otherwise a translate-only config would be a no-op.
    wants_affine = (
        config.random_rotation_deg > 0
        or tuple(config.random_scale) != (1.0, 1.0)
        or config.random_translate > 0
    )
    if wants_affine:
        angle = rng.uniform(-config.random_rotation_deg, config.random_rotation_deg)
        scale = rng.uniform(config.random_scale[0], config.random_scale[1])
        tx = rng.uniform(-config.random_translate, config.random_translate) * w
        ty = rng.uniform(-config.random_translate, config.random_translate) * h

        center = (w / 2, h / 2)
        M = cv2.getRotationMatrix2D(center, angle, scale)
        M[0, 2] += tx
        M[1, 2] += ty

        # Photos are padded by reflection; conditioning and mask are padded
        # with zeros.
        out_input = cv2.warpAffine(out_input, M, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        out_target = cv2.warpAffine(out_target, M, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        out_cond = cv2.warpAffine(
            out_cond, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0
        )
        mask_was_2d = out_mask.ndim == 2
        mask_2d = out_mask if mask_was_2d else out_mask[:, :, 0]
        mask_2d = cv2.warpAffine(mask_2d, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0)
        # Restore the trailing channel axis if the caller supplied one.
        out_mask = mask_2d if mask_was_2d else mask_2d[:, :, np.newaxis]

        # Transform landmarks
        if out_lm_src is not None:
            out_lm_src = _transform_landmarks(out_lm_src, M, w, h)
        if out_lm_dst is not None:
            out_lm_dst = _transform_landmarks(out_lm_dst, M, w, h)

    # --- Photometric augmentations (images only, not conditioning/mask) ---

    # Brightness
    b_factor = rng.uniform(config.brightness_range[0], config.brightness_range[1])
    out_input = np.clip(out_input.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)
    out_target = np.clip(out_target.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)

    # Contrast (scaled around each image's own mean)
    c_factor = rng.uniform(config.contrast_range[0], config.contrast_range[1])
    mean_in = out_input.mean()
    mean_tgt = out_target.mean()
    out_input = np.clip(
        (out_input.astype(np.float32) - mean_in) * c_factor + mean_in, 0, 255
    ).astype(np.uint8)
    out_target = np.clip(
        (out_target.astype(np.float32) - mean_tgt) * c_factor + mean_tgt, 0, 255
    ).astype(np.uint8)

    # Saturation (in HSV space)
    s_factor = rng.uniform(config.saturation_range[0], config.saturation_range[1])
    if abs(s_factor - 1.0) > 1e-4:
        out_input = _adjust_saturation(out_input, s_factor)
        out_target = _adjust_saturation(out_target, s_factor)

    # Hue shift
    if config.hue_shift_range > 0:
        hue_delta = rng.uniform(-config.hue_shift_range, config.hue_shift_range)
        if abs(hue_delta) > 0.1:
            out_input = _shift_hue(out_input, hue_delta)
            out_target = _shift_hue(out_target, hue_delta)

    # --- Conditioning augmentation ---

    # Conditioning dropout (replace with zeros to learn unconditional)
    dropped = (
        config.conditioning_dropout_prob > 0
        and rng.random() < config.conditioning_dropout_prob
    )
    if dropped:
        out_cond = np.zeros_like(out_cond)
    elif config.conditioning_noise_std > 0:
        # Noise is only added to a conditioning image that is kept: adding it
        # to a dropped one would leave it non-zero and defeat the dropout.
        noise = rng.normal(0, config.conditioning_noise_std * 255, out_cond.shape)
        out_cond = np.clip(out_cond.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    result = {
        "input_image": out_input,
        "target_image": out_target,
        "conditioning": out_cond,
        "mask": out_mask,
    }
    if out_lm_src is not None:
        result["landmarks_src"] = out_lm_src
    if out_lm_dst is not None:
        result["landmarks_dst"] = out_lm_dst

    return result
184
+
185
+
186
+ def _transform_landmarks(landmarks: np.ndarray, M: np.ndarray, w: int, h: int) -> np.ndarray:
187
+ """Transform normalized landmarks with an affine matrix."""
188
+ # Convert to pixel coords
189
+ px = landmarks.copy()
190
+ px[:, 0] *= w
191
+ px[:, 1] *= h
192
+
193
+ # Apply affine transform
194
+ ones = np.ones((px.shape[0], 1))
195
+ px_h = np.hstack([px, ones]) # (N, 3)
196
+ transformed = (M @ px_h.T).T # (N, 2)
197
+
198
+ # Back to normalized
199
+ transformed[:, 0] /= w
200
+ transformed[:, 1] /= h
201
+ return np.clip(transformed, 0.0, 1.0)
202
+
203
+
204
def _adjust_saturation(img: np.ndarray, factor: float) -> np.ndarray:
    """Scale the saturation channel of a BGR image by ``factor``.

    The image is converted to HSV, the second channel is multiplied by
    ``factor`` and limited to [0, 255], and the result is converted back.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    scaled = hsv[:, :, 1].astype(np.float32) * factor
    hsv[:, :, 1] = np.clip(scaled, 0, 255).astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
209
+
210
+
211
def _shift_hue(img: np.ndarray, delta_deg: float) -> np.ndarray:
    """Rotate the hue of a BGR image by ``delta_deg`` degrees.

    The shift is halved before being applied because the hue channel is
    handled on a 0-180 scale here; values wrap around modulo 180.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    hue = hsv[:, :, 0].astype(np.float32)
    # OpenCV hue range is [0, 180]
    hsv[:, :, 0] = ((hue + delta_deg / 2) % 180).astype(np.uint8)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
217
+
218
+
219
def augment_skin_tone(
    image: np.ndarray,
    ita_delta: float = 0.0,
) -> np.ndarray:
    """Shift the skin tone of an image in L*a*b* space.

    Intended to help balance Fitzpatrick representation in training by
    simulating different skin tones from existing samples.

    Args:
        image: (H, W, 3) BGR uint8 image.
        ita_delta: ITA angle shift (positive = lighter, negative = darker).

    Returns:
        Augmented image with shifted skin tone.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # ITA = arctan((L-50)/b), so an ITA shift is approximated by moving L
    # (channel 0) and, more gently and in the opposite direction, b
    # (channel 2) for a more natural tone change.
    lightness_offset = 0.5 * ita_delta
    yellow_offset = -0.15 * ita_delta
    for channel, offset in ((0, lightness_offset), (2, yellow_offset)):
        lab[:, :, channel] = np.clip(lab[:, :, channel] + offset, 0, 255)

    return cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
247
+
248
+
249
class FitzpatrickBalancer:
    """Oversample underrepresented Fitzpatrick types during training.

    Maintains per-type counts and generates sampling weights to ensure
    equitable training across all skin types.
    """

    # Upper bound on a single sample's raw weight, for stability.
    _MAX_WEIGHT = 5.0
    # Target fraction assumed for a type missing from the target distribution.
    _DEFAULT_TARGET = 1 / 6

    def __init__(self, target_distribution: dict[str, float] | None = None):
        """Initialize balancer.

        Args:
            target_distribution: Target fraction per type. Defaults to uniform
                over types I-VI (also used when an empty mapping is passed).
        """
        self.target = target_distribution or {
            "I": 1 / 6,
            "II": 1 / 6,
            "III": 1 / 6,
            "IV": 1 / 6,
            "V": 1 / 6,
            "VI": 1 / 6,
        }
        self._counts: dict[str, int] = {}

    def register_sample(self, fitz_type: str) -> None:
        """Register a sample's Fitzpatrick type."""
        self._counts[fitz_type] = self._counts.get(fitz_type, 0) + 1

    def _raw_weight(self, fitz_type: str, total: int) -> float:
        """Return the capped target/actual frequency ratio for one type."""
        # A type that was never registered is treated as seen once.
        freq = self._counts.get(fitz_type, 1) / total
        target_freq = self.target.get(fitz_type, self._DEFAULT_TARGET)
        return min(target_freq / max(freq, 1e-6), self._MAX_WEIGHT)

    def get_sampling_weights(self, fitz_types: list[str]) -> np.ndarray:
        """Compute sampling weights for a list of samples.

        Weights are proportional to target frequency divided by observed
        frequency (capped), so underrepresented types get upsampled.

        Args:
            fitz_types: Fitzpatrick type of each sample, in sample order.

        Returns:
            float64 array of the same length that sums to 1. An empty input
            yields an empty array. If every raw weight is zero (for example
            all relevant target fractions are 0) the result is uniform rather
            than NaN.
        """
        total = sum(self._counts.values()) or 1
        raw = np.array(
            [self._raw_weight(ft, total) for ft in fitz_types], dtype=np.float64
        )
        if raw.size == 0:
            return raw

        weight_sum = raw.sum()
        if weight_sum <= 0:
            # Nothing to normalize by; fall back to sampling uniformly.
            return np.full(raw.shape, 1.0 / raw.size)
        return raw / weight_sum  # normalize to probability distribution