landmarkdiff 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. landmarkdiff/__init__.py +40 -0
  2. landmarkdiff/__main__.py +207 -0
  3. landmarkdiff/api_client.py +316 -0
  4. landmarkdiff/arcface_torch.py +583 -0
  5. landmarkdiff/audit.py +338 -0
  6. landmarkdiff/augmentation.py +293 -0
  7. landmarkdiff/benchmark.py +213 -0
  8. landmarkdiff/checkpoint_manager.py +361 -0
  9. landmarkdiff/cli.py +252 -0
  10. landmarkdiff/clinical.py +223 -0
  11. landmarkdiff/conditioning.py +278 -0
  12. landmarkdiff/config.py +358 -0
  13. landmarkdiff/curriculum.py +191 -0
  14. landmarkdiff/data.py +405 -0
  15. landmarkdiff/data_version.py +301 -0
  16. landmarkdiff/displacement_model.py +745 -0
  17. landmarkdiff/ensemble.py +330 -0
  18. landmarkdiff/evaluation.py +415 -0
  19. landmarkdiff/experiment_tracker.py +231 -0
  20. landmarkdiff/face_verifier.py +947 -0
  21. landmarkdiff/fid.py +244 -0
  22. landmarkdiff/hyperparam.py +347 -0
  23. landmarkdiff/inference.py +754 -0
  24. landmarkdiff/landmarks.py +432 -0
  25. landmarkdiff/log.py +90 -0
  26. landmarkdiff/losses.py +348 -0
  27. landmarkdiff/manipulation.py +651 -0
  28. landmarkdiff/masking.py +316 -0
  29. landmarkdiff/metrics_agg.py +313 -0
  30. landmarkdiff/metrics_viz.py +464 -0
  31. landmarkdiff/model_registry.py +362 -0
  32. landmarkdiff/morphometry.py +342 -0
  33. landmarkdiff/postprocess.py +600 -0
  34. landmarkdiff/py.typed +0 -0
  35. landmarkdiff/safety.py +395 -0
  36. landmarkdiff/synthetic/__init__.py +23 -0
  37. landmarkdiff/synthetic/augmentation.py +188 -0
  38. landmarkdiff/synthetic/pair_generator.py +208 -0
  39. landmarkdiff/synthetic/tps_warp.py +273 -0
  40. landmarkdiff/validation.py +324 -0
  41. landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
  42. landmarkdiff-0.2.3.dist-info/RECORD +46 -0
  43. landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
  44. landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
  45. landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
  46. landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,464 @@
1
+ """Publication-quality metrics visualization for LandmarkDiff.
2
+
3
+ Generates figures suitable for MICCAI/medical imaging papers:
4
+ - Bar charts comparing procedures and methods
5
+ - Radar plots for multi-metric comparison
6
+ - Box plots for per-sample distributions
7
+ - Heatmaps for Fitzpatrick equity analysis
8
+ - Table formatters for LaTeX
9
+
10
+ Usage:
11
+ from landmarkdiff.metrics_viz import MetricsVisualizer
12
+
13
+ viz = MetricsVisualizer(output_dir="paper/figures")
14
+
15
+ # Bar chart comparing procedures
16
+ viz.procedure_comparison(metrics_by_procedure)
17
+
18
+ # Radar plot for ablation study
19
+ viz.radar_plot(experiments)
20
+
21
+ # Equity heatmap
22
+ viz.fitzpatrick_heatmap(metrics_by_type)
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ from pathlib import Path
28
+ from typing import Any
29
+
30
+ import numpy as np
31
+
32
+
33
class MetricsVisualizer:
    """Generate publication-quality figures from evaluation metrics.

    Plotting methods import matplotlib lazily (inside ``_get_plt``), so the
    class can be constructed and ``to_latex_table`` used without matplotlib
    installed.

    Args:
        output_dir: Directory to save generated figures. Created (with
            parents) on construction if it does not exist.
        dpi: Resolution for saved figures.
        style: Matplotlib style preset. If it is not available in the
            installed matplotlib, ``"seaborn-v0_8"`` and then ``"default"``
            are tried instead.
    """

    # Color palette (colorblind-safe, MICCAI-friendly)
    COLORS = {
        "rhinoplasty": "#4C72B0",
        "blepharoplasty": "#55A868",
        "rhytidectomy": "#C44E52",
        "orthognathic": "#8172B2",
        "baseline": "#CCB974",
        "ours": "#4C72B0",
    }

    # Display labels; also the set of metrics picked up by auto-detection.
    METRIC_LABELS = {
        "ssim": "SSIM",
        "lpips": "LPIPS",
        "fid": "FID",
        "nme": "NME",
        "identity_sim": "ID Sim.",
        "psnr": "PSNR (dB)",
    }

    # Metric orientation; metrics not listed are treated as higher-is-better.
    METRIC_HIGHER_BETTER = {
        "ssim": True,
        "lpips": False,
        "fid": False,
        "nme": False,
        "identity_sim": True,
        "psnr": True,
    }

    def __init__(
        self,
        output_dir: str | Path = "figures",
        dpi: int = 300,
        style: str = "seaborn-v0_8-whitegrid",
    ) -> None:
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.dpi = dpi
        self.style = style

    def _get_plt(self) -> Any:
        """Import ``matplotlib.pyplot`` and apply the publication style.

        The style and font sizes are applied via ``plt.style.use`` and
        ``plt.rcParams``, i.e. globally for the current process.

        Returns:
            The ``matplotlib.pyplot`` module.
        """
        import sys

        import matplotlib

        # Pick the headless Agg backend only if pyplot has not been imported
        # yet. Calling matplotlib.use() once pyplot is loaded switches the
        # running session's backend and closes all of the caller's open
        # figures (e.g. breaking inline plotting in a notebook).
        if "matplotlib.pyplot" not in sys.modules:
            matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        # Style names differ between matplotlib releases; "default" always
        # exists, so the chain cannot fail.
        for style_name in (self.style, "seaborn-v0_8", "default"):
            try:
                plt.style.use(style_name)
            except OSError:
                continue
            break
        # Publication font sizes
        plt.rcParams.update(
            {
                "font.size": 10,
                "axes.titlesize": 12,
                "axes.labelsize": 11,
                "xtick.labelsize": 9,
                "ytick.labelsize": 9,
                "legend.fontsize": 9,
                "figure.titlesize": 13,
            }
        )
        return plt

    # ------------------------------------------------------------------
    # Procedure comparison bar chart
    # ------------------------------------------------------------------

    def procedure_comparison(
        self,
        metrics_by_procedure: dict[str, dict[str, float]],
        metrics: list[str] | None = None,
        title: str = "Per-Procedure Performance",
        filename: str = "procedure_comparison.pdf",
    ) -> Path:
        """Generate grouped bar chart comparing procedures.

        One subplot is drawn per metric with one bar per procedure. A
        procedure that lacks a metric gets no bar and an "n/a" label rather
        than a zero-height bar that would read as a real score.

        Args:
            metrics_by_procedure: {procedure: {metric: value}}.
            metrics: Which metrics to show. None = every metric that appears
                for at least one procedure and has an entry in
                ``METRIC_LABELS``.
            title: Figure title.
            filename: Output filename, relative to ``output_dir``.

        Returns:
            Path to saved figure.

        Raises:
            ValueError: If there is no metric to plot.
        """
        if metrics is None:
            found: set[str] = set()
            for per_metric in metrics_by_procedure.values():
                found.update(per_metric)
            metrics = sorted(found & set(self.METRIC_LABELS))
        if not metrics:
            raise ValueError("procedure_comparison: no metrics to plot")

        plt = self._get_plt()

        procedures = list(metrics_by_procedure)
        positions = list(range(len(procedures)))
        colors = [self.COLORS.get(p, "#999999") for p in procedures]
        # Tick labels are abbreviated to the first five characters.
        tick_labels = [p[:5].title() for p in procedures]

        # squeeze=False always yields a 2-D array of axes, even for one metric.
        fig, axes = plt.subplots(
            1, len(metrics), figsize=(3 * len(metrics), 4), squeeze=False
        )

        for ax, metric in zip(axes[0], metrics):
            label = self.METRIC_LABELS.get(metric, metric)
            values = [metrics_by_procedure[p].get(metric) for p in procedures]
            heights = [np.nan if v is None else v for v in values]

            bars = ax.bar(
                positions, heights, color=colors, width=0.6, edgecolor="white"
            )
            ax.set_xticks(positions)
            ax.set_xticklabels(tick_labels, rotation=30, ha="right")
            ax.set_ylabel(label)
            ax.set_title(label)

            # Value labels on top of each bar; "n/a" at the baseline if missing.
            for bar, val in zip(bars, values):
                x_center = bar.get_x() + bar.get_width() / 2
                if val is None:
                    ax.text(
                        x_center,
                        0,
                        "n/a",
                        ha="center",
                        va="bottom",
                        fontsize=8,
                        color="gray",
                    )
                else:
                    ax.text(
                        x_center,
                        bar.get_height(),
                        f"{val:.3f}",
                        ha="center",
                        va="bottom",
                        fontsize=8,
                    )

        fig.suptitle(title, fontweight="bold")
        fig.tight_layout()

        out_path = self.output_dir / filename
        fig.savefig(out_path, dpi=self.dpi, bbox_inches="tight")
        plt.close(fig)
        return out_path

    # ------------------------------------------------------------------
    # Radar plot for multi-metric comparison
    # ------------------------------------------------------------------

    @classmethod
    def _radar_scores(
        cls,
        experiments: dict[str, dict[str, float]],
        metrics: list[str],
    ) -> np.ndarray:
        """Scale raw metric values to [0, 1] radar scores.

        Each metric is normalized across experiments, relative to the best
        experiment on that metric. The best value maps to 1 and worse values
        shrink proportionally:

        - higher-is-better: ``value / best`` (``best`` is the maximum)
        - lower-is-better: ``best / value`` (``best`` is the minimum)

        This keeps metrics on very different scales (FID, SSIM, PSNR in dB)
        comparable on one chart. Some cases fall back:

        - Missing or non-finite values score 0.
        - If a metric's best value is not positive, a ratio is undefined. The
          best experiment(s) then score 1 and all others score 0.

        Args:
            experiments: {experiment_name: {metric: value}}.
            metrics: Metric names, one per radar axis.

        Returns:
            Array of shape ``(len(experiments), len(metrics))``. Rows follow
            the iteration order of ``experiments``.
        """
        raw = np.array(
            [
                [np.nan if row.get(m) is None else float(row[m]) for m in metrics]
                for row in experiments.values()
            ],
            dtype=float,
        ).reshape(len(experiments), len(metrics))

        scores = np.zeros_like(raw)
        for j, metric in enumerate(metrics):
            column = raw[:, j]
            valid = np.isfinite(column)
            if not valid.any():
                continue  # nobody reported this metric: all scores stay 0
            present = column[valid]
            higher_better = cls.METRIC_HIGHER_BETTER.get(metric, True)
            best = present.max() if higher_better else present.min()
            if best > 0:
                ratio = present / best if higher_better else best / present
            else:
                ratio = (present == best).astype(float)
            scores[valid, j] = np.clip(ratio, 0.0, 1.0)
        return scores

    def radar_plot(
        self,
        experiments: dict[str, dict[str, float]],
        metrics: list[str] | None = None,
        title: str = "Multi-Metric Comparison",
        filename: str = "radar_plot.pdf",
    ) -> Path:
        """Generate radar/spider plot for comparing experiments.

        Values are rescaled per metric by ``_radar_scores``. On every axis the
        best experiment reaches 1, and lower-is-better metrics are inverted so
        that outward always means better.

        Args:
            experiments: {experiment_name: {metric: value}}.
            metrics: Which metrics to show. None = metrics present in every
                experiment that also have an entry in ``METRIC_LABELS``.
            title: Figure title.
            filename: Output filename, relative to ``output_dir``.

        Returns:
            Path to saved figure.

        Raises:
            ValueError: If there are no experiments or no metrics to plot.
        """
        if not experiments:
            raise ValueError("radar_plot: no experiments to plot")
        if metrics is None:
            shared = set.intersection(*(set(v) for v in experiments.values()))
            metrics = sorted(shared & set(self.METRIC_LABELS))
        if not metrics:
            raise ValueError("radar_plot: no metrics to plot")

        plt = self._get_plt()

        scores = self._radar_scores(experiments, metrics)
        angles = np.linspace(0, 2 * np.pi, len(metrics), endpoint=False).tolist()
        angles += angles[:1]  # Close the polygon

        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={"polar": True})

        # De-duplicate the palette so cycling never repeats a color early.
        palette = list(dict.fromkeys(self.COLORS.values()))
        for i, (name, row) in enumerate(zip(experiments, scores)):
            vals = row.tolist()
            vals += vals[:1]
            color = palette[i % len(palette)]
            ax.plot(angles, vals, "o-", linewidth=2, label=name, color=color)
            ax.fill(angles, vals, alpha=0.15, color=color)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels([self.METRIC_LABELS.get(m, m) for m in metrics])
        ax.set_ylim(0, 1.1)
        ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.0))
        ax.set_title(title, fontweight="bold", pad=20)

        out_path = self.output_dir / filename
        fig.savefig(out_path, dpi=self.dpi, bbox_inches="tight")
        plt.close(fig)
        return out_path

    # ------------------------------------------------------------------
    # Fitzpatrick equity heatmap
    # ------------------------------------------------------------------

    def fitzpatrick_heatmap(
        self,
        metrics_by_type: dict[str, dict[str, float]],
        metric: str = "ssim",
        title: str | None = None,
        filename: str = "fitzpatrick_equity.pdf",
    ) -> Path:
        """Generate heatmap showing metric values across Fitzpatrick types and procedures.

        ``metric`` only selects the colorbar label and the colormap
        orientation; the plotted values come from ``metrics_by_type``. Cells
        with no value are left blank and annotated "--".

        Args:
            metrics_by_type: {fitzpatrick_type: {procedure: value}}.
            metric: Which metric the values represent.
            title: Figure title. None = "<metric label> by Fitzpatrick Type".
            filename: Output filename, relative to ``output_dir``.

        Returns:
            Path to saved figure.

        Raises:
            ValueError: If there are no types or no procedures to plot.
        """
        if not metrics_by_type:
            raise ValueError("fitzpatrick_heatmap: metrics_by_type is empty")
        fitz_types = sorted(metrics_by_type)
        procedures = sorted(set().union(*metrics_by_type.values()))
        if not procedures:
            raise ValueError("fitzpatrick_heatmap: no procedures to plot")

        plt = self._get_plt()

        # Missing (type, procedure) cells stay NaN so they neither appear as a
        # real 0.0 score nor stretch the color scale.
        matrix = np.full((len(fitz_types), len(procedures)), np.nan)
        for i, ft in enumerate(fitz_types):
            row = metrics_by_type[ft]
            for j, proc in enumerate(procedures):
                value = row.get(proc)
                if value is not None:
                    matrix[i, j] = value

        fig, ax = plt.subplots(
            figsize=(max(6, len(procedures) * 1.5), max(4, len(fitz_types) * 0.8))
        )

        cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
        im = ax.imshow(matrix, cmap=cmap, aspect="auto")

        ax.set_xticks(range(len(procedures)))
        ax.set_xticklabels([p.title() for p in procedures], rotation=30, ha="right")
        ax.set_yticks(range(len(fitz_types)))
        ax.set_yticklabels(fitz_types)
        ax.set_ylabel("Fitzpatrick Type")

        # Annotate cells; the text-color threshold is computed once.
        finite = np.isfinite(matrix)
        median = float(np.median(matrix[finite])) if finite.any() else 0.0
        for i, j in np.ndindex(matrix.shape):
            if not finite[i, j]:
                ax.text(j, i, "--", ha="center", va="center", fontsize=9, color="gray")
                continue
            value = matrix[i, j]
            ax.text(
                j,
                i,
                f"{value:.3f}",
                ha="center",
                va="center",
                fontsize=9,
                color="white" if value < median else "black",
            )

        fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))

        if title is None:
            title = f"{self.METRIC_LABELS.get(metric, metric)} by Fitzpatrick Type"
        ax.set_title(title, fontweight="bold")
        fig.tight_layout()

        out_path = self.output_dir / filename
        fig.savefig(out_path, dpi=self.dpi, bbox_inches="tight")
        plt.close(fig)
        return out_path

    # ------------------------------------------------------------------
    # Box plots for per-sample distribution
    # ------------------------------------------------------------------

    def distribution_boxplot(
        self,
        samples_by_group: dict[str, list[float]],
        metric: str = "ssim",
        title: str | None = None,
        filename: str = "distribution.pdf",
    ) -> Path:
        """Generate box plot showing per-sample metric distributions.

        Args:
            samples_by_group: {group_name: [sample_values]}.
            metric: Metric being plotted (used for labels only).
            title: Figure title. None = "<metric label> Distribution".
            filename: Output filename, relative to ``output_dir``.

        Returns:
            Path to saved figure.
        """
        plt = self._get_plt()

        groups = list(samples_by_group)
        data = [samples_by_group[g] for g in groups]
        label = self.METRIC_LABELS.get(metric, metric)

        fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))

        bp = ax.boxplot(
            data,
            patch_artist=True,
            widths=0.6,
            medianprops={"color": "black", "linewidth": 1.5},
        )
        for patch, group in zip(bp["boxes"], groups):
            patch.set_facecolor(self.COLORS.get(group, "#4C72B0"))
            patch.set_alpha(0.7)

        ax.set_xticklabels([g.title() for g in groups], rotation=30, ha="right")
        ax.set_ylabel(label)
        ax.set_title(
            title if title is not None else f"{label} Distribution",
            fontweight="bold",
        )

        # Sample counts along the bottom edge; box k sits at x = k (1-based).
        y_bottom = ax.get_ylim()[0]
        for x, vals in enumerate(data, start=1):
            ax.text(
                x,
                y_bottom,
                f"n={len(vals)}",
                ha="center",
                va="bottom",
                fontsize=8,
                color="gray",
            )

        fig.tight_layout()
        out_path = self.output_dir / filename
        fig.savefig(out_path, dpi=self.dpi, bbox_inches="tight")
        plt.close(fig)
        return out_path

    # ------------------------------------------------------------------
    # LaTeX table formatter
    # ------------------------------------------------------------------

    @staticmethod
    def _escape_latex(text: str) -> str:
        """Escape characters that break a LaTeX tabular cell.

        Backslashes are added before ``&``, ``%``, ``#`` and ``_`` unless
        they are already escaped. Other characters (``$``, ``\\``, braces)
        pass through, so callers can still embed macros.
        """
        import re

        return re.sub(r"(?<!\\)([&%#_])", r"\\\1", text)

    @staticmethod
    def to_latex_table(
        rows: list[dict[str, Any]],
        metrics: list[str],
        caption: str = "Quantitative results",
        label: str = "tab:results",
        highlight_best: bool = True,
    ) -> str:
        """Format metrics as a LaTeX table.

        The output uses ``\\toprule``/``\\midrule``/``\\bottomrule``, which
        require the booktabs package. Missing values (absent, None or NaN)
        render as "--" and are never highlighted. ``caption`` and ``label``
        are inserted verbatim. Row names and metric labels are escaped.

        Args:
            rows: List of dicts with 'name' and metric values.
            metrics: List of metric names to include.
            caption: Table caption.
            label: LaTeX label.
            highlight_best: Bold the best value per column.

        Returns:
            LaTeX table string.
        """
        metric_labels = MetricsVisualizer.METRIC_LABELS
        higher_better = MetricsVisualizer.METRIC_HIGHER_BETTER
        escape = MetricsVisualizer._escape_latex

        def is_missing(value: Any) -> bool:
            # NaN is excluded too: it would poison max()/min() and print "nan".
            return value is None or (
                isinstance(value, (float, np.floating)) and bool(np.isnan(value))
            )

        # Find best values
        best: dict[str, float] = {}
        if highlight_best:
            for m in metrics:
                present = [r[m] for r in rows if not is_missing(r.get(m))]
                if present:
                    best[m] = max(present) if higher_better.get(m, True) else min(present)

        cols = "l" + "c" * len(metrics)
        lines = [
            "\\begin{table}[t]",
            "\\centering",
            f"\\caption{{{caption}}}",
            f"\\label{{{label}}}",
            f"\\begin{{tabular}}{{{cols}}}",
            "\\toprule",
        ]

        # Header
        header = ["Method"]
        for m in metrics:
            name = escape(metric_labels.get(m, m))
            arrow = "$\\uparrow$" if higher_better.get(m, True) else "$\\downarrow$"
            header.append(f"{name} {arrow}")
        lines.append(" & ".join(header) + " \\\\")
        lines.append("\\midrule")

        # Data rows
        for row in rows:
            parts = [escape(str(row.get("name", "")))]
            for m in metrics:
                val = row.get(m)
                if is_missing(val):
                    parts.append("--")
                    continue
                fmt = ".4f" if abs(val) < 10 else ".1f"
                val_str = f"{val:{fmt}}"
                if highlight_best and val == best.get(m):
                    val_str = f"\\textbf{{{val_str}}}"
                parts.append(val_str)
            lines.append(" & ".join(parts) + " \\\\")

        lines.extend(
            [
                "\\bottomrule",
                "\\end{tabular}",
                "\\end{table}",
            ]
        )

        return "\n".join(lines)