landmarkdiff 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- landmarkdiff/__init__.py +40 -0
- landmarkdiff/__main__.py +207 -0
- landmarkdiff/api_client.py +316 -0
- landmarkdiff/arcface_torch.py +583 -0
- landmarkdiff/audit.py +338 -0
- landmarkdiff/augmentation.py +293 -0
- landmarkdiff/benchmark.py +213 -0
- landmarkdiff/checkpoint_manager.py +361 -0
- landmarkdiff/cli.py +252 -0
- landmarkdiff/clinical.py +223 -0
- landmarkdiff/conditioning.py +278 -0
- landmarkdiff/config.py +358 -0
- landmarkdiff/curriculum.py +191 -0
- landmarkdiff/data.py +405 -0
- landmarkdiff/data_version.py +301 -0
- landmarkdiff/displacement_model.py +745 -0
- landmarkdiff/ensemble.py +330 -0
- landmarkdiff/evaluation.py +415 -0
- landmarkdiff/experiment_tracker.py +231 -0
- landmarkdiff/face_verifier.py +947 -0
- landmarkdiff/fid.py +244 -0
- landmarkdiff/hyperparam.py +347 -0
- landmarkdiff/inference.py +754 -0
- landmarkdiff/landmarks.py +432 -0
- landmarkdiff/log.py +90 -0
- landmarkdiff/losses.py +348 -0
- landmarkdiff/manipulation.py +651 -0
- landmarkdiff/masking.py +316 -0
- landmarkdiff/metrics_agg.py +313 -0
- landmarkdiff/metrics_viz.py +464 -0
- landmarkdiff/model_registry.py +362 -0
- landmarkdiff/morphometry.py +342 -0
- landmarkdiff/postprocess.py +600 -0
- landmarkdiff/py.typed +0 -0
- landmarkdiff/safety.py +395 -0
- landmarkdiff/synthetic/__init__.py +23 -0
- landmarkdiff/synthetic/augmentation.py +188 -0
- landmarkdiff/synthetic/pair_generator.py +208 -0
- landmarkdiff/synthetic/tps_warp.py +273 -0
- landmarkdiff/validation.py +324 -0
- landmarkdiff-0.2.3.dist-info/METADATA +1173 -0
- landmarkdiff-0.2.3.dist-info/RECORD +46 -0
- landmarkdiff-0.2.3.dist-info/WHEEL +5 -0
- landmarkdiff-0.2.3.dist-info/entry_points.txt +2 -0
- landmarkdiff-0.2.3.dist-info/licenses/LICENSE +21 -0
- landmarkdiff-0.2.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,464 @@
|
|
|
1
|
+
"""Publication-quality metrics visualization for LandmarkDiff.
|
|
2
|
+
|
|
3
|
+
Generates figures suitable for MICCAI/medical imaging papers:
|
|
4
|
+
- Bar charts comparing procedures and methods
|
|
5
|
+
- Radar plots for multi-metric comparison
|
|
6
|
+
- Box plots for per-sample distributions
|
|
7
|
+
- Heatmaps for Fitzpatrick equity analysis
|
|
8
|
+
- Table formatters for LaTeX
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
from landmarkdiff.metrics_viz import MetricsVisualizer
|
|
12
|
+
|
|
13
|
+
viz = MetricsVisualizer(output_dir="paper/figures")
|
|
14
|
+
|
|
15
|
+
# Bar chart comparing procedures
|
|
16
|
+
viz.procedure_comparison(metrics_by_procedure)
|
|
17
|
+
|
|
18
|
+
# Radar plot for ablation study
|
|
19
|
+
viz.radar_plot(experiments)
|
|
20
|
+
|
|
21
|
+
# Equity heatmap
|
|
22
|
+
viz.fitzpatrick_heatmap(metrics_by_type)
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
from __future__ import annotations
|
|
26
|
+
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
from typing import Any
|
|
29
|
+
|
|
30
|
+
import numpy as np
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MetricsVisualizer:
|
|
34
|
+
"""Generate publication-quality figures from evaluation metrics.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
output_dir: Directory to save generated figures.
|
|
38
|
+
dpi: Resolution for saved figures.
|
|
39
|
+
style: Matplotlib style preset.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
# Color palette (colorblind-safe, MICCAI-friendly)
|
|
43
|
+
COLORS = {
|
|
44
|
+
"rhinoplasty": "#4C72B0",
|
|
45
|
+
"blepharoplasty": "#55A868",
|
|
46
|
+
"rhytidectomy": "#C44E52",
|
|
47
|
+
"orthognathic": "#8172B2",
|
|
48
|
+
"baseline": "#CCB974",
|
|
49
|
+
"ours": "#4C72B0",
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
METRIC_LABELS = {
|
|
53
|
+
"ssim": "SSIM",
|
|
54
|
+
"lpips": "LPIPS",
|
|
55
|
+
"fid": "FID",
|
|
56
|
+
"nme": "NME",
|
|
57
|
+
"identity_sim": "ID Sim.",
|
|
58
|
+
"psnr": "PSNR (dB)",
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
METRIC_HIGHER_BETTER = {
|
|
62
|
+
"ssim": True,
|
|
63
|
+
"lpips": False,
|
|
64
|
+
"fid": False,
|
|
65
|
+
"nme": False,
|
|
66
|
+
"identity_sim": True,
|
|
67
|
+
"psnr": True,
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
def __init__(
|
|
71
|
+
self,
|
|
72
|
+
output_dir: str | Path = "figures",
|
|
73
|
+
dpi: int = 300,
|
|
74
|
+
style: str = "seaborn-v0_8-whitegrid",
|
|
75
|
+
) -> None:
|
|
76
|
+
self.output_dir = Path(output_dir)
|
|
77
|
+
self.output_dir.mkdir(parents=True, exist_ok=True)
|
|
78
|
+
self.dpi = dpi
|
|
79
|
+
self.style = style
|
|
80
|
+
|
|
81
|
+
def _get_plt(self) -> Any:
|
|
82
|
+
"""Import matplotlib with configuration."""
|
|
83
|
+
import matplotlib
|
|
84
|
+
|
|
85
|
+
matplotlib.use("Agg")
|
|
86
|
+
import matplotlib.pyplot as plt
|
|
87
|
+
|
|
88
|
+
try:
|
|
89
|
+
plt.style.use(self.style)
|
|
90
|
+
except OSError:
|
|
91
|
+
plt.style.use("seaborn-v0_8")
|
|
92
|
+
# Publication font sizes
|
|
93
|
+
plt.rcParams.update(
|
|
94
|
+
{
|
|
95
|
+
"font.size": 10,
|
|
96
|
+
"axes.titlesize": 12,
|
|
97
|
+
"axes.labelsize": 11,
|
|
98
|
+
"xtick.labelsize": 9,
|
|
99
|
+
"ytick.labelsize": 9,
|
|
100
|
+
"legend.fontsize": 9,
|
|
101
|
+
"figure.titlesize": 13,
|
|
102
|
+
}
|
|
103
|
+
)
|
|
104
|
+
return plt
|
|
105
|
+
|
|
106
|
+
# ------------------------------------------------------------------
|
|
107
|
+
# Procedure comparison bar chart
|
|
108
|
+
# ------------------------------------------------------------------
|
|
109
|
+
|
|
110
|
+
def procedure_comparison(
|
|
111
|
+
self,
|
|
112
|
+
metrics_by_procedure: dict[str, dict[str, float]],
|
|
113
|
+
metrics: list[str] | None = None,
|
|
114
|
+
title: str = "Per-Procedure Performance",
|
|
115
|
+
filename: str = "procedure_comparison.pdf",
|
|
116
|
+
) -> Path:
|
|
117
|
+
"""Generate grouped bar chart comparing procedures.
|
|
118
|
+
|
|
119
|
+
Args:
|
|
120
|
+
metrics_by_procedure: {procedure: {metric: value}}.
|
|
121
|
+
metrics: Which metrics to show. None = auto-detect.
|
|
122
|
+
title: Figure title.
|
|
123
|
+
filename: Output filename.
|
|
124
|
+
|
|
125
|
+
Returns:
|
|
126
|
+
Path to saved figure.
|
|
127
|
+
"""
|
|
128
|
+
plt = self._get_plt()
|
|
129
|
+
|
|
130
|
+
if metrics is None:
|
|
131
|
+
all_metrics: set[str] = set()
|
|
132
|
+
for m in metrics_by_procedure.values():
|
|
133
|
+
all_metrics.update(m.keys())
|
|
134
|
+
metrics = sorted(all_metrics & set(self.METRIC_LABELS.keys()))
|
|
135
|
+
|
|
136
|
+
procedures = list(metrics_by_procedure.keys())
|
|
137
|
+
n_procs = len(procedures)
|
|
138
|
+
n_metrics = len(metrics)
|
|
139
|
+
|
|
140
|
+
fig, axes = plt.subplots(1, n_metrics, figsize=(3 * n_metrics, 4))
|
|
141
|
+
if n_metrics == 1:
|
|
142
|
+
axes = [axes]
|
|
143
|
+
|
|
144
|
+
for ax, metric in zip(axes, metrics):
|
|
145
|
+
values = [metrics_by_procedure[p].get(metric, 0) for p in procedures]
|
|
146
|
+
colors = [self.COLORS.get(p, "#999999") for p in procedures]
|
|
147
|
+
|
|
148
|
+
bars = ax.bar(range(n_procs), values, color=colors, width=0.6, edgecolor="white")
|
|
149
|
+
ax.set_xticks(range(n_procs))
|
|
150
|
+
ax.set_xticklabels(
|
|
151
|
+
[p[:5].title() for p in procedures],
|
|
152
|
+
rotation=30,
|
|
153
|
+
ha="right",
|
|
154
|
+
)
|
|
155
|
+
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
|
156
|
+
ax.set_title(self.METRIC_LABELS.get(metric, metric))
|
|
157
|
+
|
|
158
|
+
# Add value labels on bars
|
|
159
|
+
for bar, val in zip(bars, values):
|
|
160
|
+
ax.text(
|
|
161
|
+
bar.get_x() + bar.get_width() / 2,
|
|
162
|
+
bar.get_height(),
|
|
163
|
+
f"{val:.3f}",
|
|
164
|
+
ha="center",
|
|
165
|
+
va="bottom",
|
|
166
|
+
fontsize=8,
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
fig.suptitle(title, fontweight="bold")
|
|
170
|
+
fig.tight_layout()
|
|
171
|
+
|
|
172
|
+
out_path = self.output_dir / filename
|
|
173
|
+
fig.savefig(out_path, dpi=self.dpi, bbox_inches="tight")
|
|
174
|
+
plt.close(fig)
|
|
175
|
+
return out_path
|
|
176
|
+
|
|
177
|
+
# ------------------------------------------------------------------
|
|
178
|
+
# Radar plot for multi-metric comparison
|
|
179
|
+
# ------------------------------------------------------------------
|
|
180
|
+
|
|
181
|
+
def radar_plot(
|
|
182
|
+
self,
|
|
183
|
+
experiments: dict[str, dict[str, float]],
|
|
184
|
+
metrics: list[str] | None = None,
|
|
185
|
+
title: str = "Multi-Metric Comparison",
|
|
186
|
+
filename: str = "radar_plot.pdf",
|
|
187
|
+
) -> Path:
|
|
188
|
+
"""Generate radar/spider plot for comparing experiments.
|
|
189
|
+
|
|
190
|
+
Args:
|
|
191
|
+
experiments: {experiment_name: {metric: value}}.
|
|
192
|
+
metrics: Which metrics to show.
|
|
193
|
+
title: Figure title.
|
|
194
|
+
filename: Output filename.
|
|
195
|
+
|
|
196
|
+
Returns:
|
|
197
|
+
Path to saved figure.
|
|
198
|
+
"""
|
|
199
|
+
plt = self._get_plt()
|
|
200
|
+
|
|
201
|
+
if metrics is None:
|
|
202
|
+
metrics = sorted(
|
|
203
|
+
set.intersection(*(set(v.keys()) for v in experiments.values()))
|
|
204
|
+
& set(self.METRIC_LABELS.keys())
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
n_metrics = len(metrics)
|
|
208
|
+
angles = np.linspace(0, 2 * np.pi, n_metrics, endpoint=False).tolist()
|
|
209
|
+
angles += angles[:1] # Close the polygon
|
|
210
|
+
|
|
211
|
+
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={"polar": True})
|
|
212
|
+
|
|
213
|
+
colors = list(self.COLORS.values())
|
|
214
|
+
for i, (name, values_dict) in enumerate(experiments.items()):
|
|
215
|
+
raw_values = []
|
|
216
|
+
for m in metrics:
|
|
217
|
+
val = values_dict.get(m, 0)
|
|
218
|
+
# Normalize: for "lower is better" metrics, invert
|
|
219
|
+
if not self.METRIC_HIGHER_BETTER.get(m, True):
|
|
220
|
+
val = 1 - min(val, 1) # Invert so higher = better on plot
|
|
221
|
+
raw_values.append(val)
|
|
222
|
+
|
|
223
|
+
# Normalize to [0, 1] range
|
|
224
|
+
vals = np.array(raw_values)
|
|
225
|
+
vals = vals / max(vals.max(), 1e-10)
|
|
226
|
+
vals = vals.tolist() + vals[:1].tolist()
|
|
227
|
+
|
|
228
|
+
color = colors[i % len(colors)]
|
|
229
|
+
ax.plot(angles, vals, "o-", linewidth=2, label=name, color=color)
|
|
230
|
+
ax.fill(angles, vals, alpha=0.15, color=color)
|
|
231
|
+
|
|
232
|
+
ax.set_xticks(angles[:-1])
|
|
233
|
+
ax.set_xticklabels([self.METRIC_LABELS.get(m, m) for m in metrics])
|
|
234
|
+
ax.set_ylim(0, 1.1)
|
|
235
|
+
ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.0))
|
|
236
|
+
ax.set_title(title, fontweight="bold", pad=20)
|
|
237
|
+
|
|
238
|
+
out_path = self.output_dir / filename
|
|
239
|
+
fig.savefig(out_path, dpi=self.dpi, bbox_inches="tight")
|
|
240
|
+
plt.close(fig)
|
|
241
|
+
return out_path
|
|
242
|
+
|
|
243
|
+
# ------------------------------------------------------------------
|
|
244
|
+
# Fitzpatrick equity heatmap
|
|
245
|
+
# ------------------------------------------------------------------
|
|
246
|
+
|
|
247
|
+
def fitzpatrick_heatmap(
|
|
248
|
+
self,
|
|
249
|
+
metrics_by_type: dict[str, dict[str, float]],
|
|
250
|
+
metric: str = "ssim",
|
|
251
|
+
title: str | None = None,
|
|
252
|
+
filename: str = "fitzpatrick_equity.pdf",
|
|
253
|
+
) -> Path:
|
|
254
|
+
"""Generate heatmap showing metric values across Fitzpatrick types and procedures.
|
|
255
|
+
|
|
256
|
+
Args:
|
|
257
|
+
metrics_by_type: {fitzpatrick_type: {procedure: value}}.
|
|
258
|
+
metric: Which metric to visualize.
|
|
259
|
+
title: Figure title.
|
|
260
|
+
filename: Output filename.
|
|
261
|
+
|
|
262
|
+
Returns:
|
|
263
|
+
Path to saved figure.
|
|
264
|
+
"""
|
|
265
|
+
plt = self._get_plt()
|
|
266
|
+
|
|
267
|
+
fitz_types = sorted(metrics_by_type.keys())
|
|
268
|
+
procedures = sorted(set.union(*(set(v.keys()) for v in metrics_by_type.values())))
|
|
269
|
+
|
|
270
|
+
# Build matrix
|
|
271
|
+
matrix = np.zeros((len(fitz_types), len(procedures)))
|
|
272
|
+
for i, ft in enumerate(fitz_types):
|
|
273
|
+
for j, proc in enumerate(procedures):
|
|
274
|
+
matrix[i, j] = metrics_by_type[ft].get(proc, 0)
|
|
275
|
+
|
|
276
|
+
fig, ax = plt.subplots(
|
|
277
|
+
figsize=(max(6, len(procedures) * 1.5), max(4, len(fitz_types) * 0.8))
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
cmap = "RdYlGn" if self.METRIC_HIGHER_BETTER.get(metric, True) else "RdYlGn_r"
|
|
281
|
+
im = ax.imshow(matrix, cmap=cmap, aspect="auto")
|
|
282
|
+
|
|
283
|
+
ax.set_xticks(range(len(procedures)))
|
|
284
|
+
ax.set_xticklabels([p.title() for p in procedures], rotation=30, ha="right")
|
|
285
|
+
ax.set_yticks(range(len(fitz_types)))
|
|
286
|
+
ax.set_yticklabels(fitz_types)
|
|
287
|
+
ax.set_ylabel("Fitzpatrick Type")
|
|
288
|
+
|
|
289
|
+
# Annotate cells
|
|
290
|
+
for i in range(len(fitz_types)):
|
|
291
|
+
for j in range(len(procedures)):
|
|
292
|
+
ax.text(
|
|
293
|
+
j,
|
|
294
|
+
i,
|
|
295
|
+
f"{matrix[i, j]:.3f}",
|
|
296
|
+
ha="center",
|
|
297
|
+
va="center",
|
|
298
|
+
fontsize=9,
|
|
299
|
+
color="white" if matrix[i, j] < np.median(matrix) else "black",
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
fig.colorbar(im, ax=ax, label=self.METRIC_LABELS.get(metric, metric))
|
|
303
|
+
|
|
304
|
+
if title is None:
|
|
305
|
+
title = f"{self.METRIC_LABELS.get(metric, metric)} by Fitzpatrick Type"
|
|
306
|
+
ax.set_title(title, fontweight="bold")
|
|
307
|
+
fig.tight_layout()
|
|
308
|
+
|
|
309
|
+
out_path = self.output_dir / filename
|
|
310
|
+
fig.savefig(out_path, dpi=self.dpi, bbox_inches="tight")
|
|
311
|
+
plt.close(fig)
|
|
312
|
+
return out_path
|
|
313
|
+
|
|
314
|
+
# ------------------------------------------------------------------
|
|
315
|
+
# Box plots for per-sample distribution
|
|
316
|
+
# ------------------------------------------------------------------
|
|
317
|
+
|
|
318
|
+
def distribution_boxplot(
|
|
319
|
+
self,
|
|
320
|
+
samples_by_group: dict[str, list[float]],
|
|
321
|
+
metric: str = "ssim",
|
|
322
|
+
title: str | None = None,
|
|
323
|
+
filename: str = "distribution.pdf",
|
|
324
|
+
) -> Path:
|
|
325
|
+
"""Generate box plot showing per-sample metric distributions.
|
|
326
|
+
|
|
327
|
+
Args:
|
|
328
|
+
samples_by_group: {group_name: [sample_values]}.
|
|
329
|
+
metric: Metric being plotted.
|
|
330
|
+
title: Figure title.
|
|
331
|
+
filename: Output filename.
|
|
332
|
+
|
|
333
|
+
Returns:
|
|
334
|
+
Path to saved figure.
|
|
335
|
+
"""
|
|
336
|
+
plt = self._get_plt()
|
|
337
|
+
|
|
338
|
+
groups = list(samples_by_group.keys())
|
|
339
|
+
data = [samples_by_group[g] for g in groups]
|
|
340
|
+
|
|
341
|
+
fig, ax = plt.subplots(figsize=(max(6, len(groups) * 1.2), 5))
|
|
342
|
+
|
|
343
|
+
bp = ax.boxplot(
|
|
344
|
+
data,
|
|
345
|
+
patch_artist=True,
|
|
346
|
+
widths=0.6,
|
|
347
|
+
medianprops={"color": "black", "linewidth": 1.5},
|
|
348
|
+
)
|
|
349
|
+
|
|
350
|
+
colors = [self.COLORS.get(g, "#4C72B0") for g in groups]
|
|
351
|
+
for patch, color in zip(bp["boxes"], colors):
|
|
352
|
+
patch.set_facecolor(color)
|
|
353
|
+
patch.set_alpha(0.7)
|
|
354
|
+
|
|
355
|
+
ax.set_xticklabels(
|
|
356
|
+
[g.title() for g in groups],
|
|
357
|
+
rotation=30,
|
|
358
|
+
ha="right",
|
|
359
|
+
)
|
|
360
|
+
ax.set_ylabel(self.METRIC_LABELS.get(metric, metric))
|
|
361
|
+
|
|
362
|
+
if title is None:
|
|
363
|
+
title = f"{self.METRIC_LABELS.get(metric, metric)} Distribution"
|
|
364
|
+
ax.set_title(title, fontweight="bold")
|
|
365
|
+
|
|
366
|
+
# Add sample count annotations
|
|
367
|
+
for i, (_g, vals) in enumerate(zip(groups, data)):
|
|
368
|
+
ax.text(
|
|
369
|
+
i + 1,
|
|
370
|
+
ax.get_ylim()[0],
|
|
371
|
+
f"n={len(vals)}",
|
|
372
|
+
ha="center",
|
|
373
|
+
va="bottom",
|
|
374
|
+
fontsize=8,
|
|
375
|
+
color="gray",
|
|
376
|
+
)
|
|
377
|
+
|
|
378
|
+
fig.tight_layout()
|
|
379
|
+
out_path = self.output_dir / filename
|
|
380
|
+
fig.savefig(out_path, dpi=self.dpi, bbox_inches="tight")
|
|
381
|
+
plt.close(fig)
|
|
382
|
+
return out_path
|
|
383
|
+
|
|
384
|
+
# ------------------------------------------------------------------
|
|
385
|
+
# LaTeX table formatter
|
|
386
|
+
# ------------------------------------------------------------------
|
|
387
|
+
|
|
388
|
+
@staticmethod
|
|
389
|
+
def to_latex_table(
|
|
390
|
+
rows: list[dict[str, Any]],
|
|
391
|
+
metrics: list[str],
|
|
392
|
+
caption: str = "Quantitative results",
|
|
393
|
+
label: str = "tab:results",
|
|
394
|
+
highlight_best: bool = True,
|
|
395
|
+
) -> str:
|
|
396
|
+
"""Format metrics as a LaTeX table.
|
|
397
|
+
|
|
398
|
+
Args:
|
|
399
|
+
rows: List of dicts with 'name' and metric values.
|
|
400
|
+
metrics: List of metric names to include.
|
|
401
|
+
caption: Table caption.
|
|
402
|
+
label: LaTeX label.
|
|
403
|
+
highlight_best: Bold the best value per column.
|
|
404
|
+
|
|
405
|
+
Returns:
|
|
406
|
+
LaTeX table string.
|
|
407
|
+
"""
|
|
408
|
+
metric_labels = MetricsVisualizer.METRIC_LABELS
|
|
409
|
+
higher_better = MetricsVisualizer.METRIC_HIGHER_BETTER
|
|
410
|
+
|
|
411
|
+
# Find best values
|
|
412
|
+
best: dict[str, float] = {}
|
|
413
|
+
if highlight_best:
|
|
414
|
+
for m in metrics:
|
|
415
|
+
vals = [r.get(m) for r in rows if r.get(m) is not None]
|
|
416
|
+
if vals:
|
|
417
|
+
if higher_better.get(m, True):
|
|
418
|
+
best[m] = max(vals)
|
|
419
|
+
else:
|
|
420
|
+
best[m] = min(vals)
|
|
421
|
+
|
|
422
|
+
cols = "l" + "c" * len(metrics)
|
|
423
|
+
lines = [
|
|
424
|
+
"\\begin{table}[t]",
|
|
425
|
+
"\\centering",
|
|
426
|
+
f"\\caption{{{caption}}}",
|
|
427
|
+
f"\\label{{{label}}}",
|
|
428
|
+
f"\\begin{{tabular}}{{{cols}}}",
|
|
429
|
+
"\\toprule",
|
|
430
|
+
]
|
|
431
|
+
|
|
432
|
+
# Header
|
|
433
|
+
header = ["Method"]
|
|
434
|
+
for m in metrics:
|
|
435
|
+
name = metric_labels.get(m, m)
|
|
436
|
+
arrow = "$\\uparrow$" if higher_better.get(m, True) else "$\\downarrow$"
|
|
437
|
+
header.append(f"{name} {arrow}")
|
|
438
|
+
lines.append(" & ".join(header) + " \\\\")
|
|
439
|
+
lines.append("\\midrule")
|
|
440
|
+
|
|
441
|
+
# Data rows
|
|
442
|
+
for row in rows:
|
|
443
|
+
parts = [row.get("name", "").replace("_", "\\_")]
|
|
444
|
+
for m in metrics:
|
|
445
|
+
val = row.get(m)
|
|
446
|
+
if val is None:
|
|
447
|
+
parts.append("--")
|
|
448
|
+
else:
|
|
449
|
+
fmt = ".4f" if abs(val) < 10 else ".1f"
|
|
450
|
+
val_str = f"{val:{fmt}}"
|
|
451
|
+
if highlight_best and val == best.get(m):
|
|
452
|
+
val_str = f"\\textbf{{{val_str}}}"
|
|
453
|
+
parts.append(val_str)
|
|
454
|
+
lines.append(" & ".join(parts) + " \\\\")
|
|
455
|
+
|
|
456
|
+
lines.extend(
|
|
457
|
+
[
|
|
458
|
+
"\\bottomrule",
|
|
459
|
+
"\\end{tabular}",
|
|
460
|
+
"\\end{table}",
|
|
461
|
+
]
|
|
462
|
+
)
|
|
463
|
+
|
|
464
|
+
return "\n".join(lines)
|