landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
landmarkdiff/audit.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
"""Clinical audit report generator for regulatory compliance.
|
|
2
|
+
|
|
3
|
+
Generates structured HTML reports summarizing safety validation results,
|
|
4
|
+
model performance, and Fitzpatrick equity analysis for clinical review.
|
|
5
|
+
|
|
6
|
+
Reports include:
|
|
7
|
+
- Safety validation pass/fail summary per patient
|
|
8
|
+
- Aggregate statistics by procedure and Fitzpatrick type
|
|
9
|
+
- Flagged cases for manual review
|
|
10
|
+
- Model version and configuration provenance
|
|
11
|
+
|
|
12
|
+
Usage:
|
|
13
|
+
from landmarkdiff.audit import AuditReporter, AuditCase
|
|
14
|
+
|
|
15
|
+
reporter = AuditReporter(model_version="0.3.2")
|
|
16
|
+
reporter.add_case(AuditCase(
|
|
17
|
+
case_id="P001",
|
|
18
|
+
procedure="rhinoplasty",
|
|
19
|
+
safety_passed=True,
|
|
20
|
+
identity_sim=0.87,
|
|
21
|
+
fitzpatrick_type="III",
|
|
22
|
+
))
|
|
23
|
+
reporter.generate_report("audit_report.html")
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import json
|
|
29
|
+
from dataclasses import dataclass, field
|
|
30
|
+
from datetime import datetime, timezone
|
|
31
|
+
from pathlib import Path
|
|
32
|
+
from typing import Any
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class AuditCase:
    """A single patient case entry for audit reporting.

    Holds the safety-validation outcome for one prediction. ``warnings``
    and ``failures`` carry human-readable issue descriptions; a case with
    ``safety_passed=True`` but non-empty ``warnings`` is still flagged for
    manual review.
    """

    case_id: str                     # unique case/patient identifier
    procedure: str                   # procedure name, e.g. "rhinoplasty"
    safety_passed: bool              # overall pass/fail from safety validation
    identity_sim: float = 0.0        # identity similarity; 0.0 means "not measured"
    intensity: float = 65.0          # manipulation intensity used for the prediction
    fitzpatrick_type: str = ""       # Fitzpatrick skin type ("I".."VI"); "" if unknown
    warnings: list[str] = field(default_factory=list)        # non-fatal issues
    failures: list[str] = field(default_factory=list)        # fatal validation failures
    metrics: dict[str, float] = field(default_factory=dict)  # arbitrary numeric metrics
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())


@dataclass
class AuditSummary:
    """Aggregate statistics for an audit report."""

    total_cases: int = 0
    passed_cases: int = 0
    failed_cases: int = 0
    flagged_cases: int = 0           # failed OR passed-with-warnings
    pass_rate: float = 0.0
    mean_identity_sim: float = 0.0   # mean over cases with identity_sim > 0 only
    by_procedure: dict[str, dict[str, Any]] = field(default_factory=dict)
    by_fitzpatrick: dict[str, dict[str, Any]] = field(default_factory=dict)


class AuditReporter:
    """Generate clinical audit reports from safety validation results.

    Args:
        model_version: Model version string for provenance.
        report_title: Title for generated reports.
    """

    def __init__(
        self,
        model_version: str = "0.3.2",
        report_title: str = "LandmarkDiff Clinical Audit Report",
    ) -> None:
        self.model_version = model_version
        self.report_title = report_title
        self.cases: list[AuditCase] = []

    def add_case(self, case: AuditCase) -> None:
        """Add a case to the audit report."""
        self.cases.append(case)

    def add_cases(self, cases: list[AuditCase]) -> None:
        """Add multiple cases."""
        self.cases.extend(cases)

    def clear(self) -> None:
        """Clear all cases."""
        self.cases.clear()

    @staticmethod
    def _group_stats(cases: list[AuditCase], key_fn) -> dict[str, dict[str, Any]]:
        """Aggregate pass counts and identity similarity grouped by ``key_fn(case)``.

        Returns ``{key: {"total", "passed", "pass_rate", "mean_identity_sim"}}``.
        Cases with ``identity_sim == 0`` are excluded from the mean (treated as
        "not measured", matching the top-level summary behavior).
        """
        groups: dict[str, dict[str, Any]] = {}
        for case in cases:
            stats = groups.setdefault(
                key_fn(case), {"total": 0, "passed": 0, "id_sims": []}
            )
            stats["total"] += 1
            if case.safety_passed:
                stats["passed"] += 1
            if case.identity_sim > 0:
                stats["id_sims"].append(case.identity_sim)

        for stats in groups.values():
            stats["pass_rate"] = stats["passed"] / max(stats["total"], 1)
            sims = stats.pop("id_sims")
            stats["mean_identity_sim"] = sum(sims) / len(sims) if sims else 0.0
        return groups

    def compute_summary(self) -> AuditSummary:
        """Compute aggregate statistics from all cases."""
        if not self.cases:
            return AuditSummary()

        total = len(self.cases)
        passed = sum(1 for c in self.cases if c.safety_passed)
        flagged = sum(1 for c in self.cases if not c.safety_passed or c.warnings)

        id_sims = [c.identity_sim for c in self.cases if c.identity_sim > 0]
        mean_id = sum(id_sims) / len(id_sims) if id_sims else 0.0

        return AuditSummary(
            total_cases=total,
            passed_cases=passed,
            failed_cases=total - passed,
            flagged_cases=flagged,
            pass_rate=passed / total,  # total > 0 guaranteed by the early return
            mean_identity_sim=mean_id,
            by_procedure=self._group_stats(self.cases, lambda c: c.procedure),
            by_fitzpatrick=self._group_stats(
                self.cases, lambda c: c.fitzpatrick_type or "Unknown"
            ),
        )

    def flagged_cases(self) -> list[AuditCase]:
        """Return cases that need manual review (failed or have warnings)."""
        return [c for c in self.cases if not c.safety_passed or c.warnings]

    def to_json(self) -> str:
        """Export audit data (summary plus per-case detail) as a JSON string."""
        summary = self.compute_summary()

        def _rounded(stats: dict[str, Any]) -> dict[str, Any]:
            # Round float stats to 4 places for stable, readable output.
            return {
                k: round(v, 4) if isinstance(v, float) else v for k, v in stats.items()
            }

        data = {
            "report_title": self.report_title,
            "model_version": self.model_version,
            "generated_at": datetime.now(timezone.utc).isoformat(),
            "summary": {
                "total_cases": summary.total_cases,
                "passed_cases": summary.passed_cases,
                "failed_cases": summary.failed_cases,
                "flagged_cases": summary.flagged_cases,
                "pass_rate": round(summary.pass_rate, 4),
                "mean_identity_sim": round(summary.mean_identity_sim, 4),
            },
            "by_procedure": {k: _rounded(v) for k, v in summary.by_procedure.items()},
            "by_fitzpatrick": {k: _rounded(v) for k, v in summary.by_fitzpatrick.items()},
            "cases": [
                {
                    "case_id": c.case_id,
                    "procedure": c.procedure,
                    "safety_passed": c.safety_passed,
                    "identity_sim": round(c.identity_sim, 4),
                    "intensity": c.intensity,
                    "fitzpatrick_type": c.fitzpatrick_type,
                    "warnings": c.warnings,
                    "failures": c.failures,
                    "metrics": {k: round(v, 4) for k, v in c.metrics.items()},
                    "timestamp": c.timestamp,
                }
                for c in self.cases
            ],
        }
        return json.dumps(data, indent=2)

    def generate_report(self, output_path: str | Path) -> Path:
        """Generate an HTML audit report.

        Args:
            output_path: Path to save the HTML report.

        Returns:
            Path to the generated report.
        """
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        summary = self.compute_summary()
        html = self._render_html(summary)

        # BUG FIX: write as UTF-8 explicitly. The document declares
        # charset=utf-8 and contains non-ASCII characters (em dashes in the
        # footer), so relying on the locale default encoding could raise
        # UnicodeEncodeError or emit a mislabeled file on some systems.
        output_path.write_text(html, encoding="utf-8")
        return output_path

    @staticmethod
    def _rate_color(rate: float) -> str:
        """Traffic-light color for a pass rate (green >= 95%, amber >= 80%, else red)."""
        return "#28a745" if rate >= 0.95 else "#ffc107" if rate >= 0.8 else "#dc3545"

    def _render_html(self, summary: AuditSummary) -> str:
        """Render the audit report as HTML.

        All free-text values that end up in markup (case IDs, procedure names,
        issue descriptions) come from external data and are HTML-escaped so
        they cannot inject markup into the report.
        """
        from html import escape

        now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
        status = "PASS" if summary.failed_cases == 0 else "REQUIRES REVIEW"
        status_color = "#28a745" if summary.failed_cases == 0 else "#dc3545"
        title = escape(self.report_title)
        version = escape(self.model_version)

        # Build procedure rows
        proc_rows = ""
        for proc, stats in sorted(summary.by_procedure.items()):
            rate = stats["pass_rate"]
            rate_color = self._rate_color(rate)
            proc_rows += (
                f"<tr>"
                f"<td>{escape(proc.title())}</td>"
                f"<td>{stats['total']}</td>"
                f"<td>{stats['passed']}</td>"
                f'<td style="color:{rate_color};font-weight:bold">{rate:.1%}</td>'
                f"<td>{stats['mean_identity_sim']:.4f}</td>"
                f"</tr>\n"
            )

        # Build Fitzpatrick rows
        fitz_rows = ""
        for ft, stats in sorted(summary.by_fitzpatrick.items()):
            rate = stats["pass_rate"]
            rate_color = self._rate_color(rate)
            fitz_rows += (
                f"<tr>"
                f"<td>{escape(ft)}</td>"
                f"<td>{stats['total']}</td>"
                f"<td>{stats['passed']}</td>"
                f'<td style="color:{rate_color};font-weight:bold">{rate:.1%}</td>'
                f"<td>{stats['mean_identity_sim']:.4f}</td>"
                f"</tr>\n"
            )

        # Build flagged cases (failed cases in red, pass-with-warnings in amber)
        flagged = self.flagged_cases()
        flagged_rows = ""
        for c in flagged:
            issues = "; ".join(c.failures + [f"WARN: {w}" for w in c.warnings])
            bg = "#fff3cd" if c.safety_passed else "#f8d7da"
            flagged_rows += (
                f'<tr style="background:{bg}">'
                f"<td>{escape(c.case_id)}</td>"
                f"<td>{escape(c.procedure.title())}</td>"
                f"<td>{escape(c.fitzpatrick_type)}</td>"
                f"<td>{c.identity_sim:.4f}</td>"
                f"<td>{'WARN' if c.safety_passed else 'FAIL'}</td>"
                f"<td>{escape(issues)}</td>"
                f"</tr>\n"
            )

        return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<title>{title}</title>
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
       max-width: 1100px; margin: 0 auto; padding: 20px; color: #333; }}
h1 {{ border-bottom: 3px solid #333; padding-bottom: 10px; }}
h2 {{ color: #555; margin-top: 30px; border-bottom: 1px solid #ddd; padding-bottom: 5px; }}
table {{ border-collapse: collapse; width: 100%; margin: 15px 0; }}
th, td {{ border: 1px solid #ddd; padding: 8px 12px; text-align: left; }}
th {{ background: #f8f9fa; font-weight: 600; }}
tr:hover {{ background: #f5f5f5; }}
.status {{ display: inline-block; padding: 4px 12px; border-radius: 4px;
          color: white; font-weight: bold; font-size: 18px; }}
.summary-grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr));
                gap: 15px; margin: 20px 0; }}
.summary-card {{ background: #f8f9fa; border-radius: 8px; padding: 15px; text-align: center; }}
.summary-card .value {{ font-size: 28px; font-weight: bold; color: #333; }}
.summary-card .label {{ font-size: 12px; color: #888; text-transform: uppercase; }}
.disclaimer {{ background: #fff3cd; border: 1px solid #ffc107; border-radius: 4px;
              padding: 12px; margin: 20px 0; font-size: 13px; }}
footer {{ margin-top: 40px; padding-top: 15px; border-top: 1px solid #ddd;
         font-size: 12px; color: #999; }}
</style>
</head>
<body>
<h1>{title}</h1>
<p>Generated: {now} | Model version: <code>{version}</code>
| Overall status: <span class="status" style="background:{status_color}">{status}</span></p>

<div class="disclaimer">
<strong>Disclaimer:</strong> This report is for research and development purposes only.
LandmarkDiff predictions are AI-generated visualizations and do not constitute medical advice
or guarantee surgical outcomes. All predictions should be reviewed by qualified clinical professionals.
</div>

<h2>Summary</h2>
<div class="summary-grid">
<div class="summary-card"><div class="value">{summary.total_cases}</div><div class="label">Total Cases</div></div>
<div class="summary-card"><div class="value" style="color:#28a745">{summary.passed_cases}</div><div class="label">Passed</div></div>
<div class="summary-card"><div class="value" style="color:#dc3545">{summary.failed_cases}</div><div class="label">Failed</div></div>
<div class="summary-card"><div class="value" style="color:#ffc107">{summary.flagged_cases}</div><div class="label">Flagged</div></div>
<div class="summary-card"><div class="value">{summary.pass_rate:.1%}</div><div class="label">Pass Rate</div></div>
<div class="summary-card"><div class="value">{summary.mean_identity_sim:.4f}</div><div class="label">Mean ID Sim</div></div>
</div>

<h2>Performance by Procedure</h2>
<table>
<tr><th>Procedure</th><th>Total</th><th>Passed</th><th>Pass Rate</th><th>Mean ID Sim</th></tr>
{proc_rows}</table>

<h2>Equity Analysis by Fitzpatrick Type</h2>
<table>
<tr><th>Fitzpatrick Type</th><th>Total</th><th>Passed</th><th>Pass Rate</th><th>Mean ID Sim</th></tr>
{fitz_rows}</table>

{"<h2>Flagged Cases (Require Review)</h2>" if flagged_rows else ""}
{"<table><tr><th>Case ID</th><th>Procedure</th><th>Fitzpatrick</th><th>ID Sim</th><th>Status</th><th>Issues</th></tr>" + flagged_rows + "</table>" if flagged_rows else "<p>No flagged cases.</p>"}

<footer>
LandmarkDiff v{version} — Clinical Audit Report —
For research use only. Not FDA approved.
</footer>
</body>
</html>"""
|
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""Training data augmentation pipeline for LandmarkDiff.
|
|
2
|
+
|
|
3
|
+
Provides domain-specific augmentations that maintain landmark consistency:
|
|
4
|
+
- Geometric: flip, rotation, affine (landmarks co-transformed)
|
|
5
|
+
- Photometric: color jitter, brightness, contrast (applied to images only)
|
|
6
|
+
- Skin-tone augmentation: ITA-space perturbation for Fitzpatrick balance
|
|
7
|
+
- Conditioning augmentation: noise injection, dropout for robustness
|
|
8
|
+
|
|
9
|
+
All augmentations preserve the correspondence between:
|
|
10
|
+
input_image ↔ conditioning_image ↔ target_image ↔ mask
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
|
|
17
|
+
import cv2
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class AugmentationConfig:
    """Augmentation parameters."""

    # Geometric (applied consistently to images, conditioning, mask, landmarks)
    random_flip: bool = True
    random_rotation_deg: float = 5.0
    random_scale: tuple[float, float] = (0.95, 1.05)
    random_translate: float = 0.02  # fraction of image size

    # Photometric (images only, not conditioning)
    brightness_range: tuple[float, float] = (0.9, 1.1)
    contrast_range: tuple[float, float] = (0.9, 1.1)
    saturation_range: tuple[float, float] = (0.9, 1.1)
    hue_shift_range: float = 5.0  # degrees

    # Conditioning augmentation
    conditioning_dropout_prob: float = 0.1
    conditioning_noise_std: float = 0.02

    # Skin-tone augmentation
    ita_perturbation_std: float = 3.0  # ITA angle noise

    seed: int | None = None


def augment_training_sample(
    input_image: np.ndarray,
    target_image: np.ndarray,
    conditioning: np.ndarray,
    mask: np.ndarray,
    landmarks_src: np.ndarray | None = None,
    landmarks_dst: np.ndarray | None = None,
    config: AugmentationConfig | None = None,
    rng: np.random.Generator | None = None,
) -> dict[str, np.ndarray]:
    """Apply consistent augmentations to a training sample.

    All spatial transforms are applied to images AND landmarks together
    so correspondence is preserved.

    Args:
        input_image: (H, W, 3) original face image (uint8 BGR).
        target_image: (H, W, 3) target face image (uint8 BGR).
        conditioning: (H, W, 3) conditioning image (uint8).
        mask: (H, W) or (H, W, 1) float32 mask.
        landmarks_src: (N, 2) normalized [0,1] source landmark coords.
        landmarks_dst: (N, 2) normalized [0,1] target landmark coords.
        config: Augmentation parameters.
        rng: Random generator for reproducibility.

    Returns:
        Dict with augmented versions of all inputs (landmark keys present
        only when the corresponding landmarks were supplied).
    """
    if config is None:
        config = AugmentationConfig()
    if rng is None:
        rng = np.random.default_rng(config.seed)

    h, w = input_image.shape[:2]
    # Work on copies so the caller's arrays are never mutated.
    out_input = input_image.copy()
    out_target = target_image.copy()
    out_cond = conditioning.copy()
    out_mask = mask.copy()
    out_lm_src = landmarks_src.copy() if landmarks_src is not None else None
    out_lm_dst = landmarks_dst.copy() if landmarks_dst is not None else None

    # --- Geometric augmentations (applied to all) ---

    # Random horizontal flip
    if config.random_flip and rng.random() < 0.5:
        out_input = np.ascontiguousarray(out_input[:, ::-1])
        out_target = np.ascontiguousarray(out_target[:, ::-1])
        out_cond = np.ascontiguousarray(out_cond[:, ::-1])
        out_mask = np.ascontiguousarray(
            out_mask[:, ::-1] if out_mask.ndim == 2 else out_mask[:, ::-1, :]
        )
        # Mirror the normalized x-coordinate to match the flipped pixels.
        if out_lm_src is not None:
            out_lm_src[:, 0] = 1.0 - out_lm_src[:, 0]
        if out_lm_dst is not None:
            out_lm_dst[:, 0] = 1.0 - out_lm_dst[:, 0]

    # Random rotation + scale + translate
    # BUG FIX: include random_translate in the trigger condition. Previously
    # a config with rotation/scale disabled but translate > 0 silently
    # skipped the translation entirely.
    if (
        config.random_rotation_deg > 0
        or config.random_scale != (1.0, 1.0)
        or config.random_translate > 0
    ):
        angle = rng.uniform(-config.random_rotation_deg, config.random_rotation_deg)
        scale = rng.uniform(config.random_scale[0], config.random_scale[1])
        tx = rng.uniform(-config.random_translate, config.random_translate) * w
        ty = rng.uniform(-config.random_translate, config.random_translate) * h

        center = (w / 2, h / 2)
        M = cv2.getRotationMatrix2D(center, angle, scale)
        M[0, 2] += tx
        M[1, 2] += ty

        # Reflect-pad face images; zero-pad conditioning and mask so
        # revealed borders stay "off" rather than mirroring content.
        out_input = cv2.warpAffine(out_input, M, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        out_target = cv2.warpAffine(out_target, M, (w, h), borderMode=cv2.BORDER_REFLECT_101)
        out_cond = cv2.warpAffine(
            out_cond, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0
        )
        mask_2d = out_mask if out_mask.ndim == 2 else out_mask[:, :, 0]
        mask_2d = cv2.warpAffine(mask_2d, M, (w, h), borderMode=cv2.BORDER_CONSTANT, borderValue=0)
        out_mask = mask_2d if out_mask.ndim == 2 else mask_2d[:, :, np.newaxis]

        # Co-transform landmarks with the same affine matrix.
        if out_lm_src is not None:
            out_lm_src = _transform_landmarks(out_lm_src, M, w, h)
        if out_lm_dst is not None:
            out_lm_dst = _transform_landmarks(out_lm_dst, M, w, h)

    # --- Photometric augmentations (images only, not conditioning/mask) ---

    # Brightness (same factor for input and target to keep them comparable)
    b_factor = rng.uniform(config.brightness_range[0], config.brightness_range[1])
    out_input = np.clip(out_input.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)
    out_target = np.clip(out_target.astype(np.float32) * b_factor, 0, 255).astype(np.uint8)

    # Contrast (scaled around each image's own mean)
    c_factor = rng.uniform(config.contrast_range[0], config.contrast_range[1])
    mean_in = out_input.mean()
    mean_tgt = out_target.mean()
    out_input = np.clip(
        (out_input.astype(np.float32) - mean_in) * c_factor + mean_in, 0, 255
    ).astype(np.uint8)
    out_target = np.clip(
        (out_target.astype(np.float32) - mean_tgt) * c_factor + mean_tgt, 0, 255
    ).astype(np.uint8)

    # Saturation (in HSV space); skip a no-op conversion round-trip
    s_factor = rng.uniform(config.saturation_range[0], config.saturation_range[1])
    if abs(s_factor - 1.0) > 1e-4:
        out_input = _adjust_saturation(out_input, s_factor)
        out_target = _adjust_saturation(out_target, s_factor)

    # Hue shift
    if config.hue_shift_range > 0:
        hue_delta = rng.uniform(-config.hue_shift_range, config.hue_shift_range)
        if abs(hue_delta) > 0.1:
            out_input = _shift_hue(out_input, hue_delta)
            out_target = _shift_hue(out_target, hue_delta)

    # --- Conditioning augmentation ---

    # Conditioning dropout (replace with zeros to learn unconditional)
    if config.conditioning_dropout_prob > 0 and rng.random() < config.conditioning_dropout_prob:
        out_cond = np.zeros_like(out_cond)

    # Conditioning noise (std expressed as a fraction of full scale 255)
    if config.conditioning_noise_std > 0:
        noise = rng.normal(0, config.conditioning_noise_std * 255, out_cond.shape)
        out_cond = np.clip(out_cond.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    result = {
        "input_image": out_input,
        "target_image": out_target,
        "conditioning": out_cond,
        "mask": out_mask,
    }
    if out_lm_src is not None:
        result["landmarks_src"] = out_lm_src
    if out_lm_dst is not None:
        result["landmarks_dst"] = out_lm_dst

    return result
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _transform_landmarks(landmarks: np.ndarray, M: np.ndarray, w: int, h: int) -> np.ndarray:
|
|
187
|
+
"""Transform normalized landmarks with an affine matrix."""
|
|
188
|
+
# Convert to pixel coords
|
|
189
|
+
px = landmarks.copy()
|
|
190
|
+
px[:, 0] *= w
|
|
191
|
+
px[:, 1] *= h
|
|
192
|
+
|
|
193
|
+
# Apply affine transform
|
|
194
|
+
ones = np.ones((px.shape[0], 1))
|
|
195
|
+
px_h = np.hstack([px, ones]) # (N, 3)
|
|
196
|
+
transformed = (M @ px_h.T).T # (N, 2)
|
|
197
|
+
|
|
198
|
+
# Back to normalized
|
|
199
|
+
transformed[:, 0] /= w
|
|
200
|
+
transformed[:, 1] /= h
|
|
201
|
+
return np.clip(transformed, 0.0, 1.0)
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _adjust_saturation(img: np.ndarray, factor: float) -> np.ndarray:
    """Scale the saturation channel of a BGR uint8 image by ``factor``."""
    as_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    # Scale S in float, clamp to the valid 8-bit range, then convert back.
    as_hsv[..., 1] = np.clip(as_hsv[..., 1] * factor, 0, 255)
    return cv2.cvtColor(as_hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def _shift_hue(img: np.ndarray, delta_deg: float) -> np.ndarray:
    """Rotate the hue of a BGR uint8 image by ``delta_deg`` degrees."""
    as_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV).astype(np.float32)
    # OpenCV packs hue into [0, 180) (half-degree units), so halve the
    # shift and wrap around the circle.
    as_hsv[..., 0] = np.mod(as_hsv[..., 0] + delta_deg * 0.5, 180)
    return cv2.cvtColor(as_hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def augment_skin_tone(
    image: np.ndarray,
    ita_delta: float = 0.0,
) -> np.ndarray:
    """Augment skin tone by shifting in L*a*b* space.

    This helps balance Fitzpatrick representation in training by
    simulating different skin tones from existing samples.

    Args:
        image: (H, W, 3) BGR uint8 image.
        ita_delta: ITA angle shift (positive = lighter, negative = darker).

    Returns:
        Augmented image with shifted skin tone.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)

    # ITA = arctan((L-50)/b): an ITA shift maps (approximately) onto the
    # lightness channel, so shift L proportionally.
    lightness_shift = 0.5 * ita_delta
    lab[..., 0] = np.clip(lab[..., 0] + lightness_shift, 0, 255)

    # Nudge the b* channel in the opposite direction for a more natural
    # tone change (lighter skin tends slightly less yellow).
    yellow_shift = -0.15 * ita_delta
    lab[..., 2] = np.clip(lab[..., 2] + yellow_shift, 0, 255)

    return cv2.cvtColor(lab.astype(np.uint8), cv2.COLOR_LAB2BGR)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class FitzpatrickBalancer:
    """Oversample underrepresented Fitzpatrick types during training.

    Tracks how many samples of each skin type have been seen and converts
    the counts into sampling weights, so types below their target share
    are drawn more often.
    """

    def __init__(self, target_distribution: dict[str, float] | None = None):
        """Initialize balancer.

        Args:
            target_distribution: Target fraction per type. Defaults to uniform.
        """
        uniform = dict.fromkeys(("I", "II", "III", "IV", "V", "VI"), 1 / 6)
        self.target = target_distribution or uniform
        self._counts: dict[str, int] = {}

    def register_sample(self, fitz_type: str) -> None:
        """Register a sample's Fitzpatrick type."""
        seen = self._counts.get(fitz_type, 0)
        self._counts[fitz_type] = seen + 1

    def get_sampling_weights(self, fitz_types: list[str]) -> np.ndarray:
        """Compute sampling weights for a list of samples.

        Returns weights inversely proportional to type frequency,
        so underrepresented types get upsampled.
        """
        total = sum(self._counts.values()) or 1

        def raw_weight(ft: str) -> float:
            # Ratio of target share to observed share, capped for stability.
            observed = self._counts.get(ft, 1) / total
            desired = self.target.get(ft, 1 / 6)
            return min(desired / max(observed, 1e-6), 5.0)

        raw = np.array([raw_weight(ft) for ft in fitz_types], dtype=np.float64)
        # Normalize so the weights form a probability distribution.
        return raw / raw.sum()
|