wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +1 -1
- wavedl/models/__init__.py +52 -4
- wavedl/models/_timm_utils.py +238 -0
- wavedl/models/caformer.py +270 -0
- wavedl/models/convnext.py +108 -33
- wavedl/models/convnext_v2.py +504 -0
- wavedl/models/densenet.py +5 -5
- wavedl/models/efficientnet.py +30 -13
- wavedl/models/efficientnetv2.py +32 -9
- wavedl/models/fastvit.py +285 -0
- wavedl/models/mamba.py +535 -0
- wavedl/models/maxvit.py +251 -0
- wavedl/models/mobilenetv3.py +35 -12
- wavedl/models/regnet.py +39 -16
- wavedl/models/resnet.py +5 -5
- wavedl/models/resnet3d.py +2 -2
- wavedl/models/swin.py +41 -9
- wavedl/models/tcn.py +25 -5
- wavedl/models/unet.py +1 -1
- wavedl/models/vit.py +6 -6
- wavedl/test.py +7 -3
- wavedl/train.py +57 -23
- wavedl/utils/constraints.py +11 -5
- wavedl/utils/data.py +120 -18
- wavedl/utils/metrics.py +287 -326
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/METADATA +104 -67
- wavedl-1.6.0.dist-info/RECORD +44 -0
- wavedl-1.5.6.dist-info/RECORD +0 -38
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/LICENSE +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/WHEEL +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/entry_points.txt +0 -0
- {wavedl-1.5.6.dist-info → wavedl-1.6.0.dist-info}/top_level.txt +0 -0
wavedl/utils/metrics.py
CHANGED
|
@@ -24,10 +24,10 @@ from sklearn.metrics import r2_score
|
|
|
24
24
|
FIGURE_WIDTH_CM = 19
|
|
25
25
|
FIGURE_WIDTH_INCH = FIGURE_WIDTH_CM / 2.54
|
|
26
26
|
|
|
27
|
-
# Font sizes (
|
|
28
|
-
FONT_SIZE_TEXT =
|
|
29
|
-
FONT_SIZE_TICKS =
|
|
30
|
-
FONT_SIZE_TITLE =
|
|
27
|
+
# Font sizes (publication quality)
|
|
28
|
+
FONT_SIZE_TEXT = 10
|
|
29
|
+
FONT_SIZE_TICKS = 9
|
|
30
|
+
FONT_SIZE_TITLE = 11
|
|
31
31
|
|
|
32
32
|
# DPI for publication (300 for print, 150 for screen)
|
|
33
33
|
FIGURE_DPI = 300
|
|
@@ -40,17 +40,43 @@ COLORS = {
|
|
|
40
40
|
"neutral": "#6B717E", # Slate gray
|
|
41
41
|
"error": "#C73E1D", # Red
|
|
42
42
|
"success": "#3A7D44", # Green
|
|
43
|
+
"scatter": "#96C2D5", # Light steel blue (simulates primary at 50% alpha on white)
|
|
43
44
|
}
|
|
44
45
|
|
|
45
46
|
|
|
47
|
+
def _is_latex_available() -> bool:
    """Report whether a ``latex`` executable can be located.

    Only the ``latex`` command itself is looked up; no other tool is
    probed and LaTeX is never actually run.

    Returns:
        True if ``shutil.which`` resolves ``latex``, False otherwise.
    """
    from shutil import which

    latex_path = which("latex")
    return latex_path is not None
|
|
52
|
+
|
|
53
|
+
|
|
46
54
|
def configure_matplotlib_style():
|
|
47
|
-
"""Configure matplotlib for publication-quality LaTeX-style plots.
|
|
55
|
+
"""Configure matplotlib for publication-quality LaTeX-style plots.
|
|
56
|
+
|
|
57
|
+
Falls back to standard fonts if LaTeX is not installed.
|
|
58
|
+
"""
|
|
59
|
+
use_latex = _is_latex_available()
|
|
60
|
+
|
|
61
|
+
if use_latex:
|
|
62
|
+
latex_settings = {
|
|
63
|
+
"text.usetex": True,
|
|
64
|
+
"font.family": "serif",
|
|
65
|
+
"font.serif": ["Computer Modern Roman"],
|
|
66
|
+
"text.latex.preamble": r"\usepackage{amsmath} \usepackage{amssymb}",
|
|
67
|
+
}
|
|
68
|
+
else:
|
|
69
|
+
# Fallback for systems without LaTeX (e.g., CI runners)
|
|
70
|
+
latex_settings = {
|
|
71
|
+
"text.usetex": False,
|
|
72
|
+
"font.family": "serif",
|
|
73
|
+
"font.serif": ["DejaVu Serif", "Times New Roman", "serif"],
|
|
74
|
+
}
|
|
75
|
+
|
|
48
76
|
plt.rcParams.update(
|
|
49
77
|
{
|
|
50
|
-
# LaTeX
|
|
51
|
-
|
|
52
|
-
"font.serif": ["Times New Roman", "DejaVu Serif", "serif"],
|
|
53
|
-
"mathtext.fontset": "cm", # Computer Modern for math
|
|
78
|
+
# LaTeX/font settings (conditional)
|
|
79
|
+
**latex_settings,
|
|
54
80
|
# Font sizes
|
|
55
81
|
"font.size": FONT_SIZE_TEXT,
|
|
56
82
|
"axes.titlesize": FONT_SIZE_TITLE,
|
|
@@ -62,8 +88,8 @@ def configure_matplotlib_style():
|
|
|
62
88
|
"axes.linewidth": 0.8,
|
|
63
89
|
"grid.linewidth": 0.5,
|
|
64
90
|
"lines.linewidth": 1.5,
|
|
65
|
-
# Grid style
|
|
66
|
-
"grid.
|
|
91
|
+
# Grid style (use light gray instead of alpha for vector compatibility)
|
|
92
|
+
"grid.color": "#CCCCCC",
|
|
67
93
|
"grid.linestyle": ":",
|
|
68
94
|
# Figure settings
|
|
69
95
|
"figure.dpi": FIGURE_DPI,
|
|
@@ -73,6 +99,9 @@ def configure_matplotlib_style():
|
|
|
73
99
|
# Remove top/right spines for cleaner look
|
|
74
100
|
"axes.spines.top": False,
|
|
75
101
|
"axes.spines.right": False,
|
|
102
|
+
# SVG/vector export settings - prevent rasterization
|
|
103
|
+
"svg.fonttype": "none", # Embed fonts as text, not paths
|
|
104
|
+
"image.composite_image": False, # Don't composite images
|
|
76
105
|
}
|
|
77
106
|
)
|
|
78
107
|
|
|
@@ -211,6 +240,83 @@ def calc_per_target_r2(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
|
|
|
211
240
|
return np.array(r2_scores)
|
|
212
241
|
|
|
213
242
|
|
|
243
|
+
# ==============================================================================
|
|
244
|
+
# VISUALIZATION HELPERS (internal)
|
|
245
|
+
# ==============================================================================
|
|
246
|
+
def _prepare_plot_data(
|
|
247
|
+
y_true: np.ndarray,
|
|
248
|
+
y_pred: np.ndarray,
|
|
249
|
+
param_names: list[str] | None = None,
|
|
250
|
+
max_samples: int | None = None,
|
|
251
|
+
) -> tuple[np.ndarray, np.ndarray, list[str], int]:
|
|
252
|
+
"""Prepare data for plotting: reshape, sample, and generate param names."""
|
|
253
|
+
# Handle 1D case
|
|
254
|
+
if y_true.ndim == 1:
|
|
255
|
+
y_true = y_true.reshape(-1, 1)
|
|
256
|
+
y_pred = y_pred.reshape(-1, 1)
|
|
257
|
+
|
|
258
|
+
num_params = y_true.shape[1]
|
|
259
|
+
|
|
260
|
+
# Subsample if needed
|
|
261
|
+
if max_samples and len(y_true) > max_samples:
|
|
262
|
+
indices = np.random.choice(len(y_true), max_samples, replace=False)
|
|
263
|
+
y_true = y_true[indices]
|
|
264
|
+
y_pred = y_pred[indices]
|
|
265
|
+
|
|
266
|
+
# Generate default param names if needed
|
|
267
|
+
if param_names is None or len(param_names) != num_params:
|
|
268
|
+
param_names = [f"P{i}" for i in range(num_params)]
|
|
269
|
+
|
|
270
|
+
return y_true, y_pred, param_names, num_params
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def _create_subplot_grid(
    num_params: int,
    height_ratio: float = 1.0,
    max_cols: int = 4,
) -> tuple[plt.Figure, np.ndarray]:
    """Create a figure holding one subplot cell per parameter.

    The grid uses at most ``max_cols`` columns and as many rows as are
    needed. The figure is ``FIGURE_WIDTH_INCH`` wide; each row is one
    column-width tall, scaled by ``height_ratio``.

    Args:
        num_params: Number of subplots that will be used (must be >= 1).
        height_ratio: Multiplier applied to the figure height.
        max_cols: Maximum number of columns in the grid (must be >= 1).

    Returns:
        ``(fig, axes)`` where ``axes`` is always a flat 1-D array of every
        cell in the grid in row-major order. It can hold more cells than
        ``num_params`` when the last row is only partly filled.

    Raises:
        ValueError: If ``num_params`` or ``max_cols`` is smaller than 1.
    """
    # Reject sizes that would otherwise surface as a ZeroDivisionError below.
    if num_params < 1:
        raise ValueError(f"num_params must be >= 1, got {num_params}")
    if max_cols < 1:
        raise ValueError(f"max_cols must be >= 1, got {max_cols}")

    cols = min(num_params, max_cols)
    rows = (num_params + cols - 1) // cols  # ceiling division
    subplot_size = FIGURE_WIDTH_INCH / cols

    # squeeze=False makes plt.subplots return a 2-D array even for a 1x1
    # grid, so the result type does not depend on num_params.
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * height_ratio),
        squeeze=False,
    )
    return fig, axes.ravel()
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def _add_unified_legend(
    fig: plt.Figure,
    axes: np.ndarray,
    ncol: int = 2,
    y_offset: float = -0.13,
    bottom_margin: float = 0.22,
) -> None:
    """Add a single figure-level legend below the subplots.

    Legend entries are taken from the first axes only, so callers are
    expected to attach labels to the artists of ``axes[0]``.

    Args:
        fig: Figure that receives the legend.
        axes: Flat collection of axes; only ``axes[0]`` is inspected.
        ncol: Number of legend columns.
        y_offset: Vertical anchor of the legend passed as the y component of
            ``bbox_to_anchor``.
        bottom_margin: Value passed to ``fig.subplots_adjust(bottom=...)``
            to reserve room for the legend.

    Returns:
        None. Nothing is drawn when there are no axes or no labelled
        artists on the first axes.
    """
    # Nothing to read labels from.
    if len(axes) == 0:
        return

    handles, labels = axes[0].get_legend_handles_labels()

    # Skip the legend entirely rather than drawing an empty frame and
    # reserving a bottom margin for it.
    if not handles:
        return

    fig.legend(
        handles,
        labels,
        loc="lower center",
        ncol=ncol,
        fontsize=FONT_SIZE_TEXT,
        fancybox=False,
        framealpha=1.0,
        bbox_to_anchor=(0.5, y_offset),
    )
    # NOTE(review): a later tight_layout() call on the same figure recomputes
    # the subplot parameters and may override this margin - confirm at call
    # sites.
    fig.subplots_adjust(bottom=bottom_margin)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _hide_unused_subplots(axes: np.ndarray, num_used: int) -> None:
    """Switch off every grid cell from position ``num_used`` onward.

    Args:
        axes: Flat, indexable collection of axes objects.
        num_used: Number of leading axes that hold a plot; these are left
            untouched.
    """
    position = num_used
    total = len(axes)
    while position < total:
        axes[position].axis("off")
        position += 1
|
|
318
|
+
|
|
319
|
+
|
|
214
320
|
# ==============================================================================
|
|
215
321
|
# VISUALIZATION - SCATTER PLOTS
|
|
216
322
|
# ==============================================================================
|
|
@@ -235,50 +341,25 @@ def plot_scientific_scatter(
|
|
|
235
341
|
Returns:
|
|
236
342
|
Matplotlib Figure object (can be saved or logged to WandB)
|
|
237
343
|
"""
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
# Handle 1D case
|
|
241
|
-
if y_true.ndim == 1:
|
|
242
|
-
y_true = y_true.reshape(-1, 1)
|
|
243
|
-
y_pred = y_pred.reshape(-1, 1)
|
|
244
|
-
|
|
245
|
-
# Downsample for visualization performance
|
|
246
|
-
if len(y_true) > max_samples:
|
|
247
|
-
indices = np.random.choice(len(y_true), max_samples, replace=False)
|
|
248
|
-
y_true = y_true[indices]
|
|
249
|
-
y_pred = y_pred[indices]
|
|
250
|
-
|
|
251
|
-
# Calculate grid dimensions
|
|
252
|
-
cols = min(num_params, 4)
|
|
253
|
-
rows = (num_params + cols - 1) // cols
|
|
254
|
-
|
|
255
|
-
# Calculate figure size (19 cm width)
|
|
256
|
-
subplot_size = FIGURE_WIDTH_INCH / cols
|
|
257
|
-
fig, axes = plt.subplots(
|
|
258
|
-
rows,
|
|
259
|
-
cols,
|
|
260
|
-
figsize=(FIGURE_WIDTH_INCH, subplot_size * rows),
|
|
344
|
+
y_true, y_pred, param_names, num_params = _prepare_plot_data(
|
|
345
|
+
y_true, y_pred, param_names, max_samples
|
|
261
346
|
)
|
|
262
|
-
axes =
|
|
347
|
+
fig, axes = _create_subplot_grid(num_params, height_ratio=1.0)
|
|
263
348
|
|
|
264
349
|
for i in range(num_params):
|
|
265
350
|
ax = axes[i]
|
|
266
351
|
|
|
267
352
|
# Calculate R² for this target
|
|
268
|
-
if len(y_true) >= 2
|
|
269
|
-
r2 = r2_score(y_true[:, i], y_pred[:, i])
|
|
270
|
-
else:
|
|
271
|
-
r2 = float("nan")
|
|
353
|
+
r2 = r2_score(y_true[:, i], y_pred[:, i]) if len(y_true) >= 2 else float("nan")
|
|
272
354
|
|
|
273
|
-
# Scatter plot
|
|
274
|
-
ax.
|
|
355
|
+
# Scatter plot (using plot for vector SVG compatibility)
|
|
356
|
+
ax.plot(
|
|
275
357
|
y_true[:, i],
|
|
276
358
|
y_pred[:, i],
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
rasterized=True,
|
|
359
|
+
"o",
|
|
360
|
+
markersize=5,
|
|
361
|
+
markerfacecolor=COLORS["scatter"],
|
|
362
|
+
markeredgecolor="none",
|
|
282
363
|
label="Data",
|
|
283
364
|
)
|
|
284
365
|
|
|
@@ -292,25 +373,20 @@ def plot_scientific_scatter(
|
|
|
292
373
|
"--",
|
|
293
374
|
color=COLORS["error"],
|
|
294
375
|
lw=1.2,
|
|
295
|
-
alpha=0.8,
|
|
296
376
|
label="Ideal",
|
|
297
377
|
)
|
|
298
378
|
|
|
299
379
|
# Labels and formatting
|
|
300
|
-
|
|
301
|
-
ax.set_title(f"{title}\n$R^2 = {r2:.4f}$")
|
|
380
|
+
ax.set_title(f"{param_names[i]}\n$R^2 = {r2:.4f}$")
|
|
302
381
|
ax.set_xlabel("Ground Truth")
|
|
303
382
|
ax.set_ylabel("Prediction")
|
|
304
383
|
ax.grid(True)
|
|
305
384
|
ax.set_aspect("equal", adjustable="box")
|
|
306
385
|
ax.set_xlim(min_val - margin, max_val + margin)
|
|
307
386
|
ax.set_ylim(min_val - margin, max_val + margin)
|
|
308
|
-
ax.legend(fontsize=
|
|
309
|
-
|
|
310
|
-
# Hide unused subplots
|
|
311
|
-
for i in range(num_params, len(axes)):
|
|
312
|
-
axes[i].axis("off")
|
|
387
|
+
ax.legend(fontsize=FONT_SIZE_TICKS, loc="best", fancybox=False, framealpha=1.0)
|
|
313
388
|
|
|
389
|
+
_hide_unused_subplots(axes, num_params)
|
|
314
390
|
plt.tight_layout()
|
|
315
391
|
return fig
|
|
316
392
|
|
|
@@ -339,75 +415,48 @@ def plot_error_histogram(
|
|
|
339
415
|
Returns:
|
|
340
416
|
Matplotlib Figure object
|
|
341
417
|
"""
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
# Handle 1D case
|
|
345
|
-
if y_true.ndim == 1:
|
|
346
|
-
y_true = y_true.reshape(-1, 1)
|
|
347
|
-
y_pred = y_pred.reshape(-1, 1)
|
|
348
|
-
|
|
349
|
-
# Calculate errors
|
|
350
|
-
errors = y_pred - y_true
|
|
351
|
-
|
|
352
|
-
# Calculate grid dimensions
|
|
353
|
-
cols = min(num_params, 4)
|
|
354
|
-
rows = (num_params + cols - 1) // cols
|
|
355
|
-
|
|
356
|
-
# Calculate figure size
|
|
357
|
-
subplot_size = FIGURE_WIDTH_INCH / cols
|
|
358
|
-
fig, axes = plt.subplots(
|
|
359
|
-
rows,
|
|
360
|
-
cols,
|
|
361
|
-
figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
|
|
418
|
+
y_true, y_pred, param_names, num_params = _prepare_plot_data(
|
|
419
|
+
y_true, y_pred, param_names
|
|
362
420
|
)
|
|
363
|
-
|
|
421
|
+
errors = y_pred - y_true
|
|
422
|
+
fig, axes = _create_subplot_grid(num_params)
|
|
364
423
|
|
|
365
424
|
for i in range(num_params):
|
|
366
425
|
ax = axes[i]
|
|
367
426
|
err = errors[:, i]
|
|
427
|
+
mean_err, std_err, mae = np.mean(err), np.std(err), np.mean(np.abs(err))
|
|
368
428
|
|
|
369
|
-
#
|
|
370
|
-
mean_err = np.mean(err)
|
|
371
|
-
std_err = np.std(err)
|
|
372
|
-
mae = np.mean(np.abs(err))
|
|
373
|
-
|
|
374
|
-
# Histogram
|
|
429
|
+
# Histogram - only label first subplot
|
|
375
430
|
ax.hist(
|
|
376
431
|
err,
|
|
377
432
|
bins=bins,
|
|
378
433
|
color=COLORS["primary"],
|
|
379
|
-
alpha=0.7,
|
|
380
434
|
edgecolor="white",
|
|
381
435
|
linewidth=0.5,
|
|
382
|
-
label="Errors",
|
|
436
|
+
label="Errors" if i == 0 else None,
|
|
383
437
|
)
|
|
384
|
-
|
|
385
|
-
# Vertical line at zero
|
|
386
438
|
ax.axvline(
|
|
387
|
-
x=0,
|
|
439
|
+
x=0,
|
|
440
|
+
color=COLORS["error"],
|
|
441
|
+
linestyle="--",
|
|
442
|
+
lw=1.2,
|
|
443
|
+
label="Zero" if i == 0 else None,
|
|
388
444
|
)
|
|
389
|
-
|
|
390
|
-
# Mean line
|
|
391
445
|
ax.axvline(
|
|
392
446
|
x=mean_err,
|
|
393
447
|
color=COLORS["accent"],
|
|
394
448
|
linestyle="-",
|
|
395
449
|
lw=1.2,
|
|
396
|
-
label=
|
|
450
|
+
label="Mean" if i == 0 else None,
|
|
397
451
|
)
|
|
398
452
|
|
|
399
|
-
|
|
400
|
-
title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
|
|
401
|
-
ax.set_title(f"{title}\nMAE = {mae:.4f}, $\\sigma$ = {std_err:.4f}")
|
|
453
|
+
ax.set_title(f"{param_names[i]}\nMAE = {mae:.4f}, $\\sigma$ = {std_err:.4f}")
|
|
402
454
|
ax.set_xlabel("Prediction Error")
|
|
403
455
|
ax.set_ylabel("Count")
|
|
404
456
|
ax.grid(True, axis="y")
|
|
405
|
-
ax.legend(fontsize=6, loc="best")
|
|
406
|
-
|
|
407
|
-
# Hide unused subplots
|
|
408
|
-
for i in range(num_params, len(axes)):
|
|
409
|
-
axes[i].axis("off")
|
|
410
457
|
|
|
458
|
+
_hide_unused_subplots(axes, num_params)
|
|
459
|
+
_add_unified_legend(fig, axes, ncol=3)
|
|
411
460
|
plt.tight_layout()
|
|
412
461
|
return fig
|
|
413
462
|
|
|
@@ -436,78 +485,46 @@ def plot_residuals(
|
|
|
436
485
|
Returns:
|
|
437
486
|
Matplotlib Figure object
|
|
438
487
|
"""
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
# Handle 1D case
|
|
442
|
-
if y_true.ndim == 1:
|
|
443
|
-
y_true = y_true.reshape(-1, 1)
|
|
444
|
-
y_pred = y_pred.reshape(-1, 1)
|
|
445
|
-
|
|
446
|
-
# Calculate residuals
|
|
447
|
-
residuals = y_pred - y_true
|
|
448
|
-
|
|
449
|
-
# Downsample for visualization
|
|
450
|
-
if len(y_true) > max_samples:
|
|
451
|
-
indices = np.random.choice(len(y_true), max_samples, replace=False)
|
|
452
|
-
y_pred = y_pred[indices]
|
|
453
|
-
residuals = residuals[indices]
|
|
454
|
-
|
|
455
|
-
# Calculate grid dimensions
|
|
456
|
-
cols = min(num_params, 4)
|
|
457
|
-
rows = (num_params + cols - 1) // cols
|
|
458
|
-
|
|
459
|
-
# Calculate figure size
|
|
460
|
-
subplot_size = FIGURE_WIDTH_INCH / cols
|
|
461
|
-
fig, axes = plt.subplots(
|
|
462
|
-
rows,
|
|
463
|
-
cols,
|
|
464
|
-
figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
|
|
488
|
+
y_true, y_pred, param_names, num_params = _prepare_plot_data(
|
|
489
|
+
y_true, y_pred, param_names, max_samples
|
|
465
490
|
)
|
|
466
|
-
|
|
491
|
+
residuals = y_pred - y_true
|
|
492
|
+
fig, axes = _create_subplot_grid(num_params)
|
|
467
493
|
|
|
468
494
|
for i in range(num_params):
|
|
469
495
|
ax = axes[i]
|
|
470
|
-
|
|
471
|
-
# Scatter plot of residuals
|
|
472
|
-
ax.scatter(
|
|
496
|
+
ax.plot(
|
|
473
497
|
y_pred[:, i],
|
|
474
498
|
residuals[:, i],
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
label="Data",
|
|
499
|
+
"o",
|
|
500
|
+
markersize=5,
|
|
501
|
+
markerfacecolor=COLORS["scatter"],
|
|
502
|
+
markeredgecolor="none",
|
|
503
|
+
label="Data" if i == 0 else None,
|
|
481
504
|
)
|
|
482
|
-
|
|
483
|
-
# Zero line
|
|
484
505
|
ax.axhline(
|
|
485
|
-
y=0,
|
|
506
|
+
y=0,
|
|
507
|
+
color=COLORS["error"],
|
|
508
|
+
linestyle="--",
|
|
509
|
+
lw=1.2,
|
|
510
|
+
label="Zero" if i == 0 else None,
|
|
486
511
|
)
|
|
487
|
-
|
|
488
|
-
# Calculate and show mean residual
|
|
489
512
|
mean_res = np.mean(residuals[:, i])
|
|
490
513
|
ax.axhline(
|
|
491
514
|
y=mean_res,
|
|
492
515
|
color=COLORS["accent"],
|
|
493
516
|
linestyle="-",
|
|
494
517
|
lw=1.0,
|
|
495
|
-
|
|
496
|
-
label=f"Mean = {mean_res:.4f}",
|
|
518
|
+
label="Mean" if i == 0 else None,
|
|
497
519
|
)
|
|
498
520
|
|
|
499
|
-
|
|
500
|
-
title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
|
|
501
|
-
ax.set_title(f"{title}")
|
|
521
|
+
ax.set_title(f"{param_names[i]}")
|
|
502
522
|
ax.set_xlabel("Predicted Value")
|
|
503
|
-
ax.set_ylabel("
|
|
523
|
+
ax.set_ylabel("Pred - True")
|
|
504
524
|
ax.grid(True)
|
|
505
|
-
ax.legend(fontsize=6, loc="best")
|
|
506
|
-
|
|
507
|
-
# Hide unused subplots
|
|
508
|
-
for i in range(num_params, len(axes)):
|
|
509
|
-
axes[i].axis("off")
|
|
510
525
|
|
|
526
|
+
_hide_unused_subplots(axes, num_params)
|
|
527
|
+
_add_unified_legend(fig, axes, ncol=3)
|
|
511
528
|
plt.tight_layout()
|
|
512
529
|
return fig
|
|
513
530
|
|
|
@@ -602,7 +619,7 @@ def create_training_curves(
|
|
|
602
619
|
ax1.set_xlabel("Epoch")
|
|
603
620
|
ax1.set_ylabel("Loss")
|
|
604
621
|
ax1.set_yscale("log") # Log scale for loss
|
|
605
|
-
ax1.grid(True
|
|
622
|
+
ax1.grid(True)
|
|
606
623
|
|
|
607
624
|
# Collect all loss values and set clean power of 10 ticks
|
|
608
625
|
all_loss_values = []
|
|
@@ -623,7 +640,6 @@ def create_training_curves(
|
|
|
623
640
|
"--",
|
|
624
641
|
color=COLORS["neutral"],
|
|
625
642
|
linewidth=1.0,
|
|
626
|
-
alpha=0.7,
|
|
627
643
|
label="Learning Rate",
|
|
628
644
|
)
|
|
629
645
|
ax2.set_ylabel("Learning Rate")
|
|
@@ -635,7 +651,14 @@ def create_training_curves(
|
|
|
635
651
|
|
|
636
652
|
# Combined legend
|
|
637
653
|
labels = [l.get_label() for l in lines]
|
|
638
|
-
ax1.legend(
|
|
654
|
+
ax1.legend(
|
|
655
|
+
lines,
|
|
656
|
+
labels,
|
|
657
|
+
loc="best",
|
|
658
|
+
fontsize=FONT_SIZE_TICKS,
|
|
659
|
+
fancybox=False,
|
|
660
|
+
framealpha=1.0,
|
|
661
|
+
)
|
|
639
662
|
|
|
640
663
|
ax1.set_title("Training Curves")
|
|
641
664
|
|
|
@@ -668,97 +691,53 @@ def plot_bland_altman(
|
|
|
668
691
|
Returns:
|
|
669
692
|
Matplotlib Figure object
|
|
670
693
|
"""
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
if y_true.ndim == 1:
|
|
675
|
-
y_true = y_true.reshape(-1, 1)
|
|
676
|
-
y_pred = y_pred.reshape(-1, 1)
|
|
677
|
-
|
|
678
|
-
# Calculate mean and difference
|
|
694
|
+
y_true, y_pred, param_names, num_params = _prepare_plot_data(
|
|
695
|
+
y_true, y_pred, param_names, max_samples
|
|
696
|
+
)
|
|
679
697
|
mean_vals = (y_true + y_pred) / 2
|
|
680
698
|
diff_vals = y_pred - y_true
|
|
681
|
-
|
|
682
|
-
# Downsample for visualization
|
|
683
|
-
if len(y_true) > max_samples:
|
|
684
|
-
indices = np.random.choice(len(y_true), max_samples, replace=False)
|
|
685
|
-
mean_vals = mean_vals[indices]
|
|
686
|
-
diff_vals = diff_vals[indices]
|
|
687
|
-
|
|
688
|
-
# Calculate grid dimensions
|
|
689
|
-
cols = min(num_params, 4)
|
|
690
|
-
rows = (num_params + cols - 1) // cols
|
|
691
|
-
|
|
692
|
-
# Calculate figure size
|
|
693
|
-
subplot_size = FIGURE_WIDTH_INCH / cols
|
|
694
|
-
fig, axes = plt.subplots(
|
|
695
|
-
rows,
|
|
696
|
-
cols,
|
|
697
|
-
figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
|
|
698
|
-
)
|
|
699
|
-
axes = np.array(axes).flatten() if num_params > 1 else [axes]
|
|
699
|
+
fig, axes = _create_subplot_grid(num_params)
|
|
700
700
|
|
|
701
701
|
for i in range(num_params):
|
|
702
702
|
ax = axes[i]
|
|
703
|
-
|
|
704
703
|
mean_diff = np.mean(diff_vals[:, i])
|
|
705
704
|
std_diff = np.std(diff_vals[:, i])
|
|
706
|
-
|
|
707
|
-
# Limits of agreement (95% CI = mean ± 1.96*SD)
|
|
708
705
|
upper_loa = mean_diff + 1.96 * std_diff
|
|
709
706
|
lower_loa = mean_diff - 1.96 * std_diff
|
|
710
707
|
|
|
711
|
-
|
|
712
|
-
ax.scatter(
|
|
708
|
+
ax.plot(
|
|
713
709
|
mean_vals[:, i],
|
|
714
710
|
diff_vals[:, i],
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
711
|
+
"o",
|
|
712
|
+
markersize=5,
|
|
713
|
+
markerfacecolor=COLORS["scatter"],
|
|
714
|
+
markeredgecolor="none",
|
|
715
|
+
label="Data" if i == 0 else None,
|
|
720
716
|
)
|
|
721
|
-
|
|
722
|
-
# Mean difference line
|
|
723
717
|
ax.axhline(
|
|
724
718
|
y=mean_diff,
|
|
725
719
|
color=COLORS["accent"],
|
|
726
720
|
linestyle="-",
|
|
727
721
|
lw=1.2,
|
|
728
|
-
label=
|
|
722
|
+
label="Mean" if i == 0 else None,
|
|
729
723
|
)
|
|
730
|
-
|
|
731
|
-
# Limits of agreement
|
|
732
724
|
ax.axhline(
|
|
733
725
|
y=upper_loa,
|
|
734
726
|
color=COLORS["error"],
|
|
735
727
|
linestyle="--",
|
|
736
728
|
lw=1.0,
|
|
737
|
-
label=
|
|
729
|
+
label=r"$\pm$1.96 SD" if i == 0 else None,
|
|
738
730
|
)
|
|
739
|
-
ax.axhline(
|
|
740
|
-
|
|
741
|
-
color=COLORS["error"],
|
|
742
|
-
linestyle="--",
|
|
743
|
-
lw=1.0,
|
|
744
|
-
label=f"-1.96 SD = {lower_loa:.3f}",
|
|
745
|
-
)
|
|
746
|
-
|
|
747
|
-
# Zero line
|
|
748
|
-
ax.axhline(y=0, color="gray", linestyle=":", lw=0.8, alpha=0.5)
|
|
731
|
+
ax.axhline(y=lower_loa, color=COLORS["error"], linestyle="--", lw=1.0)
|
|
732
|
+
ax.axhline(y=0, color="gray", linestyle=":", lw=0.8)
|
|
749
733
|
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
ax.
|
|
753
|
-
ax.set_xlabel("Mean of True and Predicted")
|
|
754
|
-
ax.set_ylabel("Difference (Pred - True)")
|
|
734
|
+
ax.set_title(f"{param_names[i]}")
|
|
735
|
+
ax.set_xlabel("Mean of True and Pred")
|
|
736
|
+
ax.set_ylabel("Pred - True")
|
|
755
737
|
ax.grid(True)
|
|
756
|
-
ax.legend(fontsize=6, loc="best")
|
|
757
|
-
|
|
758
|
-
# Hide unused subplots
|
|
759
|
-
for i in range(num_params, len(axes)):
|
|
760
|
-
axes[i].axis("off")
|
|
761
738
|
|
|
739
|
+
_hide_unused_subplots(axes, num_params)
|
|
740
|
+
_add_unified_legend(fig, axes, ncol=3)
|
|
762
741
|
plt.tight_layout()
|
|
763
742
|
return fig
|
|
764
743
|
|
|
@@ -828,7 +807,7 @@ def plot_qq(
|
|
|
828
807
|
"Zero variance\n(constant errors)",
|
|
829
808
|
ha="center",
|
|
830
809
|
va="center",
|
|
831
|
-
fontsize=
|
|
810
|
+
fontsize=FONT_SIZE_TEXT,
|
|
832
811
|
transform=ax.transAxes,
|
|
833
812
|
)
|
|
834
813
|
ax.set_title(f"{title}\n(zero variance)")
|
|
@@ -841,15 +820,14 @@ def plot_qq(
|
|
|
841
820
|
# Calculate theoretical quantiles and sample quantiles
|
|
842
821
|
(osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
|
|
843
822
|
|
|
844
|
-
# Scatter plot
|
|
845
|
-
ax.
|
|
823
|
+
# Scatter plot (using plot for vector SVG compatibility)
|
|
824
|
+
ax.plot(
|
|
846
825
|
osm,
|
|
847
826
|
osr,
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
rasterized=True,
|
|
827
|
+
"o",
|
|
828
|
+
markersize=5,
|
|
829
|
+
markerfacecolor=COLORS["scatter"],
|
|
830
|
+
markeredgecolor="none",
|
|
853
831
|
label="Data",
|
|
854
832
|
)
|
|
855
833
|
|
|
@@ -864,7 +842,7 @@ def plot_qq(
|
|
|
864
842
|
ax.set_xlabel("Theoretical Quantiles")
|
|
865
843
|
ax.set_ylabel("Sample Quantiles")
|
|
866
844
|
ax.grid(True)
|
|
867
|
-
ax.legend(fontsize=
|
|
845
|
+
ax.legend(fontsize=FONT_SIZE_TICKS, loc="best", fancybox=False, framealpha=1.0)
|
|
868
846
|
|
|
869
847
|
# Hide unused subplots
|
|
870
848
|
for i in range(num_params, len(axes)):
|
|
@@ -896,6 +874,9 @@ def plot_correlation_heatmap(
|
|
|
896
874
|
Returns:
|
|
897
875
|
Matplotlib Figure object
|
|
898
876
|
"""
|
|
877
|
+
from matplotlib.colors import Normalize
|
|
878
|
+
from matplotlib.patches import Rectangle
|
|
879
|
+
|
|
899
880
|
num_params = y_true.shape[1] if y_true.ndim > 1 else 1
|
|
900
881
|
|
|
901
882
|
if num_params < 2:
|
|
@@ -907,7 +888,7 @@ def plot_correlation_heatmap(
|
|
|
907
888
|
"Correlation heatmap requires\nat least 2 parameters",
|
|
908
889
|
ha="center",
|
|
909
890
|
va="center",
|
|
910
|
-
fontsize=
|
|
891
|
+
fontsize=FONT_SIZE_TEXT,
|
|
911
892
|
)
|
|
912
893
|
ax.axis("off")
|
|
913
894
|
return fig
|
|
@@ -928,12 +909,55 @@ def plot_correlation_heatmap(
|
|
|
928
909
|
fig_size = min(FIGURE_WIDTH_INCH * 0.6, 2 + num_params * 0.6)
|
|
929
910
|
fig, ax = plt.subplots(figsize=(fig_size, fig_size))
|
|
930
911
|
|
|
931
|
-
# Heatmap
|
|
932
|
-
|
|
912
|
+
# Heatmap using Rectangle patches (vector-compatible, no imshow)
|
|
913
|
+
cmap = plt.cm.RdBu_r
|
|
914
|
+
norm = Normalize(vmin=-1, vmax=1)
|
|
933
915
|
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
916
|
+
for i in range(num_params):
|
|
917
|
+
for j in range(num_params):
|
|
918
|
+
color = cmap(norm(corr_matrix[i, j]))
|
|
919
|
+
rect = Rectangle(
|
|
920
|
+
(j - 0.5, i - 0.5), # bottom-left corner
|
|
921
|
+
1,
|
|
922
|
+
1, # width, height
|
|
923
|
+
facecolor=color,
|
|
924
|
+
edgecolor="white",
|
|
925
|
+
linewidth=0.5,
|
|
926
|
+
)
|
|
927
|
+
ax.add_patch(rect)
|
|
928
|
+
|
|
929
|
+
# Set axis limits and aspect
|
|
930
|
+
ax.set_xlim(-0.5, num_params - 0.5)
|
|
931
|
+
ax.set_ylim(num_params - 0.5, -0.5) # Invert y-axis for matrix orientation
|
|
932
|
+
ax.set_aspect("equal")
|
|
933
|
+
|
|
934
|
+
# Vector colorbar using rectangles (no raster gradient)
|
|
935
|
+
# Create a separate axes for colorbar
|
|
936
|
+
from mpl_toolkits.axes_grid1 import make_axes_locatable
|
|
937
|
+
|
|
938
|
+
divider = make_axes_locatable(ax)
|
|
939
|
+
cax = divider.append_axes("right", size="5%", pad=0.1)
|
|
940
|
+
|
|
941
|
+
# Draw colorbar as discrete rectangles (20 segments for smooth gradient)
|
|
942
|
+
n_segments = 20
|
|
943
|
+
for k in range(n_segments):
|
|
944
|
+
val = -1 + 2 * k / (n_segments - 1) # -1 to 1
|
|
945
|
+
color = cmap(norm(val))
|
|
946
|
+
rect = Rectangle(
|
|
947
|
+
(0, val - 1 / n_segments),
|
|
948
|
+
1,
|
|
949
|
+
2 / n_segments,
|
|
950
|
+
facecolor=color,
|
|
951
|
+
edgecolor="none",
|
|
952
|
+
)
|
|
953
|
+
cax.add_patch(rect)
|
|
954
|
+
|
|
955
|
+
cax.set_xlim(0, 1)
|
|
956
|
+
cax.set_ylim(-1, 1)
|
|
957
|
+
cax.set_xticks([])
|
|
958
|
+
cax.set_ylabel("Correlation", fontsize=FONT_SIZE_TEXT)
|
|
959
|
+
cax.yaxis.set_label_position("right")
|
|
960
|
+
cax.yaxis.tick_right()
|
|
937
961
|
|
|
938
962
|
# Labels
|
|
939
963
|
ax.set_xticks(range(num_params))
|
|
@@ -944,14 +968,14 @@ def plot_correlation_heatmap(
|
|
|
944
968
|
# Annotate with values
|
|
945
969
|
for i in range(num_params):
|
|
946
970
|
for j in range(num_params):
|
|
947
|
-
|
|
971
|
+
text_color = "white" if abs(corr_matrix[i, j]) > 0.5 else "black"
|
|
948
972
|
ax.text(
|
|
949
973
|
j,
|
|
950
974
|
i,
|
|
951
975
|
f"{corr_matrix[i, j]:.2f}",
|
|
952
976
|
ha="center",
|
|
953
977
|
va="center",
|
|
954
|
-
color=
|
|
978
|
+
color=text_color,
|
|
955
979
|
fontsize=FONT_SIZE_TICKS,
|
|
956
980
|
)
|
|
957
981
|
|
|
@@ -985,74 +1009,40 @@ def plot_relative_error(
|
|
|
985
1009
|
Returns:
|
|
986
1010
|
Matplotlib Figure object
|
|
987
1011
|
"""
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
if y_true.ndim == 1:
|
|
992
|
-
y_true = y_true.reshape(-1, 1)
|
|
993
|
-
y_pred = y_pred.reshape(-1, 1)
|
|
994
|
-
|
|
995
|
-
# Downsample for visualization
|
|
996
|
-
if len(y_true) > max_samples:
|
|
997
|
-
indices = np.random.choice(len(y_true), max_samples, replace=False)
|
|
998
|
-
y_true = y_true[indices]
|
|
999
|
-
y_pred = y_pred[indices]
|
|
1012
|
+
y_true, y_pred, param_names, num_params = _prepare_plot_data(
|
|
1013
|
+
y_true, y_pred, param_names, max_samples
|
|
1014
|
+
)
|
|
1000
1015
|
|
|
1001
1016
|
# Calculate relative error (avoid division by zero)
|
|
1002
1017
|
with np.errstate(divide="ignore", invalid="ignore"):
|
|
1003
1018
|
rel_error = np.abs((y_pred - y_true) / y_true) * 100
|
|
1004
1019
|
rel_error = np.nan_to_num(rel_error, nan=0.0, posinf=0.0, neginf=0.0)
|
|
1005
1020
|
|
|
1006
|
-
|
|
1007
|
-
cols = min(num_params, 4)
|
|
1008
|
-
rows = (num_params + cols - 1) // cols
|
|
1009
|
-
|
|
1010
|
-
# Calculate figure size
|
|
1011
|
-
subplot_size = FIGURE_WIDTH_INCH / cols
|
|
1012
|
-
fig, axes = plt.subplots(
|
|
1013
|
-
rows,
|
|
1014
|
-
cols,
|
|
1015
|
-
figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
|
|
1016
|
-
)
|
|
1017
|
-
axes = np.array(axes).flatten() if num_params > 1 else [axes]
|
|
1021
|
+
fig, axes = _create_subplot_grid(num_params)
|
|
1018
1022
|
|
|
1019
1023
|
for i in range(num_params):
|
|
1020
1024
|
ax = axes[i]
|
|
1021
|
-
|
|
1022
|
-
# Scatter plot
|
|
1023
|
-
ax.scatter(
|
|
1025
|
+
ax.plot(
|
|
1024
1026
|
y_true[:, i],
|
|
1025
1027
|
rel_error[:, i],
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
rasterized=True,
|
|
1028
|
+
"o",
|
|
1029
|
+
markersize=5,
|
|
1030
|
+
markerfacecolor=COLORS["scatter"],
|
|
1031
|
+
markeredgecolor="none",
|
|
1031
1032
|
label="Data",
|
|
1032
1033
|
)
|
|
1033
|
-
|
|
1034
|
-
# Mean relative error line
|
|
1035
1034
|
mean_rel = np.mean(rel_error[:, i])
|
|
1036
1035
|
ax.axhline(
|
|
1037
|
-
y=mean_rel,
|
|
1038
|
-
color=COLORS["accent"],
|
|
1039
|
-
linestyle="-",
|
|
1040
|
-
lw=1.2,
|
|
1041
|
-
label=f"Mean = {mean_rel:.2f}%",
|
|
1036
|
+
y=mean_rel, color=COLORS["accent"], linestyle="-", lw=1.2, label="Mean"
|
|
1042
1037
|
)
|
|
1043
1038
|
|
|
1044
|
-
|
|
1045
|
-
title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
|
|
1046
|
-
ax.set_title(f"{title}")
|
|
1039
|
+
ax.set_title(f"{param_names[i]}")
|
|
1047
1040
|
ax.set_xlabel("True Value")
|
|
1048
|
-
ax.set_ylabel("Relative Error (
|
|
1041
|
+
ax.set_ylabel("Relative Error (\\%)")
|
|
1049
1042
|
ax.grid(True)
|
|
1050
|
-
ax.legend(fontsize=
|
|
1051
|
-
|
|
1052
|
-
# Hide unused subplots
|
|
1053
|
-
for i in range(num_params, len(axes)):
|
|
1054
|
-
axes[i].axis("off")
|
|
1043
|
+
ax.legend(fontsize=FONT_SIZE_TICKS, loc="best", fancybox=False, framealpha=1.0)
|
|
1055
1044
|
|
|
1045
|
+
_hide_unused_subplots(axes, num_params)
|
|
1056
1046
|
plt.tight_layout()
|
|
1057
1047
|
return fig
|
|
1058
1048
|
|
|
@@ -1096,7 +1086,7 @@ def plot_error_cdf(
|
|
|
1096
1086
|
with np.errstate(divide="ignore", invalid="ignore"):
|
|
1097
1087
|
errors = np.abs((y_pred - y_true) / y_true) * 100
|
|
1098
1088
|
errors = np.nan_to_num(errors, nan=0.0, posinf=0.0, neginf=0.0)
|
|
1099
|
-
xlabel = "Relative Error
|
|
1089
|
+
xlabel = r"Relative Error\;$(\%)$"
|
|
1100
1090
|
else:
|
|
1101
1091
|
errors = np.abs(y_pred - y_true)
|
|
1102
1092
|
xlabel = "Absolute Error"
|
|
@@ -1121,15 +1111,15 @@ def plot_error_cdf(
|
|
|
1121
1111
|
|
|
1122
1112
|
# Find 95th percentile (use np.percentile for accuracy)
|
|
1123
1113
|
p95_val = np.percentile(errors[:, i], 95)
|
|
1124
|
-
ax.axvline(x=p95_val, color=color, linestyle=":"
|
|
1114
|
+
ax.axvline(x=p95_val, color=color, linestyle=":")
|
|
1125
1115
|
|
|
1126
1116
|
# Reference lines
|
|
1127
|
-
ax.axhline(y=95, color="gray", linestyle="--", lw=0.8,
|
|
1117
|
+
ax.axhline(y=95, color="gray", linestyle="--", lw=0.8, label=r"95\%")
|
|
1128
1118
|
|
|
1129
1119
|
ax.set_xlabel(xlabel)
|
|
1130
|
-
ax.set_ylabel("Cumulative Percentage
|
|
1120
|
+
ax.set_ylabel(r"Cumulative Percentage\;$(\%)$")
|
|
1131
1121
|
ax.set_title("Cumulative Error Distribution")
|
|
1132
|
-
ax.legend(fontsize=
|
|
1122
|
+
ax.legend(fontsize=FONT_SIZE_TICKS, loc="best", fancybox=False, framealpha=1.0)
|
|
1133
1123
|
ax.grid(True)
|
|
1134
1124
|
ax.set_ylim(0, 105)
|
|
1135
1125
|
|
|
@@ -1161,67 +1151,38 @@ def plot_prediction_vs_index(
|
|
|
1161
1151
|
Returns:
|
|
1162
1152
|
Matplotlib Figure object
|
|
1163
1153
|
"""
|
|
1164
|
-
|
|
1165
|
-
|
|
1166
|
-
# Handle 1D case
|
|
1167
|
-
if y_true.ndim == 1:
|
|
1168
|
-
y_true = y_true.reshape(-1, 1)
|
|
1169
|
-
y_pred = y_pred.reshape(-1, 1)
|
|
1170
|
-
|
|
1171
|
-
# Limit samples
|
|
1172
|
-
n_samples = min(len(y_true), max_samples)
|
|
1173
|
-
y_true = y_true[:n_samples]
|
|
1174
|
-
y_pred = y_pred[:n_samples]
|
|
1175
|
-
indices = np.arange(n_samples)
|
|
1176
|
-
|
|
1177
|
-
# Calculate grid dimensions
|
|
1178
|
-
cols = min(num_params, 4)
|
|
1179
|
-
rows = (num_params + cols - 1) // cols
|
|
1180
|
-
|
|
1181
|
-
# Calculate figure size
|
|
1182
|
-
subplot_size = FIGURE_WIDTH_INCH / cols
|
|
1183
|
-
fig, axes = plt.subplots(
|
|
1184
|
-
rows,
|
|
1185
|
-
cols,
|
|
1186
|
-
figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
|
|
1154
|
+
y_true, y_pred, param_names, num_params = _prepare_plot_data(
|
|
1155
|
+
y_true, y_pred, param_names, max_samples
|
|
1187
1156
|
)
|
|
1188
|
-
|
|
1157
|
+
indices = np.arange(len(y_true))
|
|
1158
|
+
fig, axes = _create_subplot_grid(num_params)
|
|
1189
1159
|
|
|
1190
1160
|
for i in range(num_params):
|
|
1191
1161
|
ax = axes[i]
|
|
1192
|
-
|
|
1193
|
-
# Plot true and predicted
|
|
1194
1162
|
ax.plot(
|
|
1195
1163
|
indices,
|
|
1196
1164
|
y_true[:, i],
|
|
1197
1165
|
"o",
|
|
1198
|
-
markersize=
|
|
1199
|
-
alpha=0.6,
|
|
1166
|
+
markersize=5,
|
|
1200
1167
|
color=COLORS["primary"],
|
|
1201
|
-
label="True",
|
|
1168
|
+
label="True" if i == 0 else None,
|
|
1202
1169
|
)
|
|
1203
1170
|
ax.plot(
|
|
1204
1171
|
indices,
|
|
1205
1172
|
y_pred[:, i],
|
|
1206
1173
|
"x",
|
|
1207
|
-
markersize=
|
|
1208
|
-
alpha=0.6,
|
|
1174
|
+
markersize=5,
|
|
1209
1175
|
color=COLORS["error"],
|
|
1210
|
-
label="Predicted",
|
|
1176
|
+
label="Predicted" if i == 0 else None,
|
|
1211
1177
|
)
|
|
1212
1178
|
|
|
1213
|
-
|
|
1214
|
-
title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
|
|
1215
|
-
ax.set_title(f"{title}")
|
|
1179
|
+
ax.set_title(f"{param_names[i]}")
|
|
1216
1180
|
ax.set_xlabel("Sample Index")
|
|
1217
1181
|
ax.set_ylabel("Value")
|
|
1218
1182
|
ax.grid(True)
|
|
1219
|
-
ax.legend(fontsize=6, loc="best")
|
|
1220
|
-
|
|
1221
|
-
# Hide unused subplots
|
|
1222
|
-
for i in range(num_params, len(axes)):
|
|
1223
|
-
axes[i].axis("off")
|
|
1224
1183
|
|
|
1184
|
+
_hide_unused_subplots(axes, num_params)
|
|
1185
|
+
_add_unified_legend(fig, axes, ncol=2)
|
|
1225
1186
|
plt.tight_layout()
|
|
1226
1187
|
return fig
|
|
1227
1188
|
|
|
@@ -1290,7 +1251,7 @@ def plot_error_boxplot(
|
|
|
1290
1251
|
|
|
1291
1252
|
# Zero line for signed errors
|
|
1292
1253
|
if not use_relative:
|
|
1293
|
-
ax.axhline(y=0, color=COLORS["error"], linestyle="--", lw=1.0
|
|
1254
|
+
ax.axhline(y=0, color=COLORS["error"], linestyle="--", lw=1.0)
|
|
1294
1255
|
|
|
1295
1256
|
ax.set_ylabel(ylabel)
|
|
1296
1257
|
ax.set_title("Error Distribution by Parameter")
|