wavedl 1.5.6__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wavedl/utils/metrics.py CHANGED
@@ -24,10 +24,10 @@ from sklearn.metrics import r2_score
24
24
  FIGURE_WIDTH_CM = 19
25
25
  FIGURE_WIDTH_INCH = FIGURE_WIDTH_CM / 2.54
26
26
 
27
- # Font sizes (consistent 8pt for publication)
28
- FONT_SIZE_TEXT = 8
29
- FONT_SIZE_TICKS = 8
30
- FONT_SIZE_TITLE = 9
27
+ # Font sizes (publication quality)
28
+ FONT_SIZE_TEXT = 10
29
+ FONT_SIZE_TICKS = 9
30
+ FONT_SIZE_TITLE = 11
31
31
 
32
32
  # DPI for publication (300 for print, 150 for screen)
33
33
  FIGURE_DPI = 300
@@ -40,17 +40,43 @@ COLORS = {
40
40
  "neutral": "#6B717E", # Slate gray
41
41
  "error": "#C73E1D", # Red
42
42
  "success": "#3A7D44", # Green
43
+ "scatter": "#96C2D5", # Light steel blue (simulates primary at 50% alpha on white)
43
44
  }
44
45
 
45
46
 
47
+ def _is_latex_available() -> bool:
48
+ """Check if LaTeX is available for matplotlib rendering."""
49
+ import shutil
50
+
51
+ return shutil.which("latex") is not None
52
+
53
+
46
54
  def configure_matplotlib_style():
47
- """Configure matplotlib for publication-quality LaTeX-style plots."""
55
+ """Configure matplotlib for publication-quality LaTeX-style plots.
56
+
57
+ Falls back to standard fonts if LaTeX is not installed.
58
+ """
59
+ use_latex = _is_latex_available()
60
+
61
+ if use_latex:
62
+ latex_settings = {
63
+ "text.usetex": True,
64
+ "font.family": "serif",
65
+ "font.serif": ["Computer Modern Roman"],
66
+ "text.latex.preamble": r"\usepackage{amsmath} \usepackage{amssymb}",
67
+ }
68
+ else:
69
+ # Fallback for systems without LaTeX (e.g., CI runners)
70
+ latex_settings = {
71
+ "text.usetex": False,
72
+ "font.family": "serif",
73
+ "font.serif": ["DejaVu Serif", "Times New Roman", "serif"],
74
+ }
75
+
48
76
  plt.rcParams.update(
49
77
  {
50
- # LaTeX-style fonts
51
- "font.family": "serif",
52
- "font.serif": ["Times New Roman", "DejaVu Serif", "serif"],
53
- "mathtext.fontset": "cm", # Computer Modern for math
78
+ # LaTeX/font settings (conditional)
79
+ **latex_settings,
54
80
  # Font sizes
55
81
  "font.size": FONT_SIZE_TEXT,
56
82
  "axes.titlesize": FONT_SIZE_TITLE,
@@ -62,8 +88,8 @@ def configure_matplotlib_style():
62
88
  "axes.linewidth": 0.8,
63
89
  "grid.linewidth": 0.5,
64
90
  "lines.linewidth": 1.5,
65
- # Grid style
66
- "grid.alpha": 0.4,
91
+ # Grid style (use light gray instead of alpha for vector compatibility)
92
+ "grid.color": "#CCCCCC",
67
93
  "grid.linestyle": ":",
68
94
  # Figure settings
69
95
  "figure.dpi": FIGURE_DPI,
@@ -73,6 +99,9 @@ def configure_matplotlib_style():
73
99
  # Remove top/right spines for cleaner look
74
100
  "axes.spines.top": False,
75
101
  "axes.spines.right": False,
102
+ # SVG/vector export settings - prevent rasterization
103
+ "svg.fonttype": "none", # Embed fonts as text, not paths
104
+ "image.composite_image": False, # Don't composite images
76
105
  }
77
106
  )
78
107
 
@@ -211,6 +240,83 @@ def calc_per_target_r2(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
211
240
  return np.array(r2_scores)
212
241
 
213
242
 
243
+ # ==============================================================================
244
+ # VISUALIZATION HELPERS (internal)
245
+ # ==============================================================================
246
+ def _prepare_plot_data(
247
+ y_true: np.ndarray,
248
+ y_pred: np.ndarray,
249
+ param_names: list[str] | None = None,
250
+ max_samples: int | None = None,
251
+ ) -> tuple[np.ndarray, np.ndarray, list[str], int]:
252
+ """Prepare data for plotting: reshape, sample, and generate param names."""
253
+ # Handle 1D case
254
+ if y_true.ndim == 1:
255
+ y_true = y_true.reshape(-1, 1)
256
+ y_pred = y_pred.reshape(-1, 1)
257
+
258
+ num_params = y_true.shape[1]
259
+
260
+ # Subsample if needed
261
+ if max_samples and len(y_true) > max_samples:
262
+ indices = np.random.choice(len(y_true), max_samples, replace=False)
263
+ y_true = y_true[indices]
264
+ y_pred = y_pred[indices]
265
+
266
+ # Generate default param names if needed
267
+ if param_names is None or len(param_names) != num_params:
268
+ param_names = [f"P{i}" for i in range(num_params)]
269
+
270
+ return y_true, y_pred, param_names, num_params
271
+
272
+
273
+ def _create_subplot_grid(
274
+ num_params: int,
275
+ height_ratio: float = 1.0,
276
+ max_cols: int = 4,
277
+ ) -> tuple[plt.Figure, np.ndarray]:
278
+ """Create a subplot grid for multi-parameter plots."""
279
+ cols = min(num_params, max_cols)
280
+ rows = (num_params + cols - 1) // cols
281
+ subplot_size = FIGURE_WIDTH_INCH / cols
282
+
283
+ fig, axes = plt.subplots(
284
+ rows,
285
+ cols,
286
+ figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * height_ratio),
287
+ )
288
+ axes = np.array(axes).flatten() if num_params > 1 else [axes]
289
+ return fig, axes
290
+
291
+
292
+ def _add_unified_legend(
293
+ fig: plt.Figure,
294
+ axes: np.ndarray,
295
+ ncol: int = 2,
296
+ y_offset: float = -0.13,
297
+ bottom_margin: float = 0.22,
298
+ ) -> None:
299
+ """Add a unified legend below the figure."""
300
+ handles, labels = axes[0].get_legend_handles_labels()
301
+ fig.legend(
302
+ handles,
303
+ labels,
304
+ loc="lower center",
305
+ ncol=ncol,
306
+ fontsize=FONT_SIZE_TEXT,
307
+ fancybox=False,
308
+ framealpha=1.0,
309
+ bbox_to_anchor=(0.5, y_offset),
310
+ )
311
+ fig.subplots_adjust(bottom=bottom_margin)
312
+
313
+
314
+ def _hide_unused_subplots(axes: np.ndarray, num_used: int) -> None:
315
+ """Hide unused subplots in a grid."""
316
+ for i in range(num_used, len(axes)):
317
+ axes[i].axis("off")
318
+
319
+
214
320
  # ==============================================================================
215
321
  # VISUALIZATION - SCATTER PLOTS
216
322
  # ==============================================================================
@@ -235,50 +341,25 @@ def plot_scientific_scatter(
235
341
  Returns:
236
342
  Matplotlib Figure object (can be saved or logged to WandB)
237
343
  """
238
- num_params = y_true.shape[1] if y_true.ndim > 1 else 1
239
-
240
- # Handle 1D case
241
- if y_true.ndim == 1:
242
- y_true = y_true.reshape(-1, 1)
243
- y_pred = y_pred.reshape(-1, 1)
244
-
245
- # Downsample for visualization performance
246
- if len(y_true) > max_samples:
247
- indices = np.random.choice(len(y_true), max_samples, replace=False)
248
- y_true = y_true[indices]
249
- y_pred = y_pred[indices]
250
-
251
- # Calculate grid dimensions
252
- cols = min(num_params, 4)
253
- rows = (num_params + cols - 1) // cols
254
-
255
- # Calculate figure size (19 cm width)
256
- subplot_size = FIGURE_WIDTH_INCH / cols
257
- fig, axes = plt.subplots(
258
- rows,
259
- cols,
260
- figsize=(FIGURE_WIDTH_INCH, subplot_size * rows),
344
+ y_true, y_pred, param_names, num_params = _prepare_plot_data(
345
+ y_true, y_pred, param_names, max_samples
261
346
  )
262
- axes = np.array(axes).flatten() if num_params > 1 else [axes]
347
+ fig, axes = _create_subplot_grid(num_params, height_ratio=1.0)
263
348
 
264
349
  for i in range(num_params):
265
350
  ax = axes[i]
266
351
 
267
352
  # Calculate R² for this target
268
- if len(y_true) >= 2:
269
- r2 = r2_score(y_true[:, i], y_pred[:, i])
270
- else:
271
- r2 = float("nan")
353
+ r2 = r2_score(y_true[:, i], y_pred[:, i]) if len(y_true) >= 2 else float("nan")
272
354
 
273
- # Scatter plot
274
- ax.scatter(
355
+ # Scatter plot (using plot for vector SVG compatibility)
356
+ ax.plot(
275
357
  y_true[:, i],
276
358
  y_pred[:, i],
277
- alpha=0.5,
278
- s=15,
279
- c=COLORS["primary"],
280
- edgecolors="none",
281
- rasterized=True,
359
+ "o",
360
+ markersize=5,
361
+ markerfacecolor=COLORS["scatter"],
362
+ markeredgecolor="none",
282
363
  label="Data",
283
364
  )
284
365
 
@@ -292,25 +373,20 @@ def plot_scientific_scatter(
292
373
  "--",
293
374
  color=COLORS["error"],
294
375
  lw=1.2,
295
- alpha=0.8,
296
376
  label="Ideal",
297
377
  )
298
378
 
299
379
  # Labels and formatting
300
- title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
301
- ax.set_title(f"{title}\n$R^2 = {r2:.4f}$")
380
+ ax.set_title(f"{param_names[i]}\n$R^2 = {r2:.4f}$")
302
381
  ax.set_xlabel("Ground Truth")
303
382
  ax.set_ylabel("Prediction")
304
383
  ax.grid(True)
305
384
  ax.set_aspect("equal", adjustable="box")
306
385
  ax.set_xlim(min_val - margin, max_val + margin)
307
386
  ax.set_ylim(min_val - margin, max_val + margin)
308
- ax.legend(fontsize=6, loc="best")
309
-
310
- # Hide unused subplots
311
- for i in range(num_params, len(axes)):
312
- axes[i].axis("off")
387
+ ax.legend(fontsize=FONT_SIZE_TICKS, loc="best", fancybox=False, framealpha=1.0)
313
388
 
389
+ _hide_unused_subplots(axes, num_params)
314
390
  plt.tight_layout()
315
391
  return fig
316
392
 
@@ -339,75 +415,48 @@ def plot_error_histogram(
339
415
  Returns:
340
416
  Matplotlib Figure object
341
417
  """
342
- num_params = y_true.shape[1] if y_true.ndim > 1 else 1
343
-
344
- # Handle 1D case
345
- if y_true.ndim == 1:
346
- y_true = y_true.reshape(-1, 1)
347
- y_pred = y_pred.reshape(-1, 1)
348
-
349
- # Calculate errors
350
- errors = y_pred - y_true
351
-
352
- # Calculate grid dimensions
353
- cols = min(num_params, 4)
354
- rows = (num_params + cols - 1) // cols
355
-
356
- # Calculate figure size
357
- subplot_size = FIGURE_WIDTH_INCH / cols
358
- fig, axes = plt.subplots(
359
- rows,
360
- cols,
361
- figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
418
+ y_true, y_pred, param_names, num_params = _prepare_plot_data(
419
+ y_true, y_pred, param_names
362
420
  )
363
- axes = np.array(axes).flatten() if num_params > 1 else [axes]
421
+ errors = y_pred - y_true
422
+ fig, axes = _create_subplot_grid(num_params)
364
423
 
365
424
  for i in range(num_params):
366
425
  ax = axes[i]
367
426
  err = errors[:, i]
427
+ mean_err, std_err, mae = np.mean(err), np.std(err), np.mean(np.abs(err))
368
428
 
369
- # Statistics
370
- mean_err = np.mean(err)
371
- std_err = np.std(err)
372
- mae = np.mean(np.abs(err))
373
-
374
- # Histogram
429
+ # Histogram - only label first subplot
375
430
  ax.hist(
376
431
  err,
377
432
  bins=bins,
378
433
  color=COLORS["primary"],
379
- alpha=0.7,
380
434
  edgecolor="white",
381
435
  linewidth=0.5,
382
- label="Errors",
436
+ label="Errors" if i == 0 else None,
383
437
  )
384
-
385
- # Vertical line at zero
386
438
  ax.axvline(
387
- x=0, color=COLORS["error"], linestyle="--", lw=1.2, alpha=0.8, label="Zero"
439
+ x=0,
440
+ color=COLORS["error"],
441
+ linestyle="--",
442
+ lw=1.2,
443
+ label="Zero" if i == 0 else None,
388
444
  )
389
-
390
- # Mean line
391
445
  ax.axvline(
392
446
  x=mean_err,
393
447
  color=COLORS["accent"],
394
448
  linestyle="-",
395
449
  lw=1.2,
396
- label=f"Mean = {mean_err:.4f}",
450
+ label="Mean" if i == 0 else None,
397
451
  )
398
452
 
399
- # Labels and formatting
400
- title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
401
- ax.set_title(f"{title}\nMAE = {mae:.4f}, $\\sigma$ = {std_err:.4f}")
453
+ ax.set_title(f"{param_names[i]}\nMAE = {mae:.4f}, $\\sigma$ = {std_err:.4f}")
402
454
  ax.set_xlabel("Prediction Error")
403
455
  ax.set_ylabel("Count")
404
456
  ax.grid(True, axis="y")
405
- ax.legend(fontsize=6, loc="best")
406
-
407
- # Hide unused subplots
408
- for i in range(num_params, len(axes)):
409
- axes[i].axis("off")
410
457
 
458
+ _hide_unused_subplots(axes, num_params)
459
+ _add_unified_legend(fig, axes, ncol=3)
411
460
  plt.tight_layout()
412
461
  return fig
413
462
 
@@ -436,78 +485,46 @@ def plot_residuals(
436
485
  Returns:
437
486
  Matplotlib Figure object
438
487
  """
439
- num_params = y_true.shape[1] if y_true.ndim > 1 else 1
440
-
441
- # Handle 1D case
442
- if y_true.ndim == 1:
443
- y_true = y_true.reshape(-1, 1)
444
- y_pred = y_pred.reshape(-1, 1)
445
-
446
- # Calculate residuals
447
- residuals = y_pred - y_true
448
-
449
- # Downsample for visualization
450
- if len(y_true) > max_samples:
451
- indices = np.random.choice(len(y_true), max_samples, replace=False)
452
- y_pred = y_pred[indices]
453
- residuals = residuals[indices]
454
-
455
- # Calculate grid dimensions
456
- cols = min(num_params, 4)
457
- rows = (num_params + cols - 1) // cols
458
-
459
- # Calculate figure size
460
- subplot_size = FIGURE_WIDTH_INCH / cols
461
- fig, axes = plt.subplots(
462
- rows,
463
- cols,
464
- figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
488
+ y_true, y_pred, param_names, num_params = _prepare_plot_data(
489
+ y_true, y_pred, param_names, max_samples
465
490
  )
466
- axes = np.array(axes).flatten() if num_params > 1 else [axes]
491
+ residuals = y_pred - y_true
492
+ fig, axes = _create_subplot_grid(num_params)
467
493
 
468
494
  for i in range(num_params):
469
495
  ax = axes[i]
470
-
471
- # Scatter plot of residuals
472
- ax.scatter(
496
+ ax.plot(
473
497
  y_pred[:, i],
474
498
  residuals[:, i],
475
- alpha=0.5,
476
- s=15,
477
- c=COLORS["primary"],
478
- edgecolors="none",
479
- rasterized=True,
480
- label="Data",
499
+ "o",
500
+ markersize=5,
501
+ markerfacecolor=COLORS["scatter"],
502
+ markeredgecolor="none",
503
+ label="Data" if i == 0 else None,
481
504
  )
482
-
483
- # Zero line
484
505
  ax.axhline(
485
- y=0, color=COLORS["error"], linestyle="--", lw=1.2, alpha=0.8, label="Zero"
506
+ y=0,
507
+ color=COLORS["error"],
508
+ linestyle="--",
509
+ lw=1.2,
510
+ label="Zero" if i == 0 else None,
486
511
  )
487
-
488
- # Calculate and show mean residual
489
512
  mean_res = np.mean(residuals[:, i])
490
513
  ax.axhline(
491
514
  y=mean_res,
492
515
  color=COLORS["accent"],
493
516
  linestyle="-",
494
517
  lw=1.0,
495
- alpha=0.7,
496
- label=f"Mean = {mean_res:.4f}",
518
+ label="Mean" if i == 0 else None,
497
519
  )
498
520
 
499
- # Labels
500
- title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
501
- ax.set_title(f"{title}")
521
+ ax.set_title(f"{param_names[i]}")
502
522
  ax.set_xlabel("Predicted Value")
503
- ax.set_ylabel("Residual (Pred - True)")
523
+ ax.set_ylabel("Pred - True")
504
524
  ax.grid(True)
505
- ax.legend(fontsize=6, loc="best")
506
-
507
- # Hide unused subplots
508
- for i in range(num_params, len(axes)):
509
- axes[i].axis("off")
510
525
 
526
+ _hide_unused_subplots(axes, num_params)
527
+ _add_unified_legend(fig, axes, ncol=3)
511
528
  plt.tight_layout()
512
529
  return fig
513
530
 
@@ -602,7 +619,7 @@ def create_training_curves(
602
619
  ax1.set_xlabel("Epoch")
603
620
  ax1.set_ylabel("Loss")
604
621
  ax1.set_yscale("log") # Log scale for loss
605
- ax1.grid(True, alpha=0.3)
622
+ ax1.grid(True)
606
623
 
607
624
  # Collect all loss values and set clean power of 10 ticks
608
625
  all_loss_values = []
@@ -623,7 +640,6 @@ def create_training_curves(
623
640
  "--",
624
641
  color=COLORS["neutral"],
625
642
  linewidth=1.0,
626
- alpha=0.7,
627
643
  label="Learning Rate",
628
644
  )
629
645
  ax2.set_ylabel("Learning Rate")
@@ -635,7 +651,14 @@ def create_training_curves(
635
651
 
636
652
  # Combined legend
637
653
  labels = [l.get_label() for l in lines]
638
- ax1.legend(lines, labels, loc="best", fontsize=6)
654
+ ax1.legend(
655
+ lines,
656
+ labels,
657
+ loc="best",
658
+ fontsize=FONT_SIZE_TICKS,
659
+ fancybox=False,
660
+ framealpha=1.0,
661
+ )
639
662
 
640
663
  ax1.set_title("Training Curves")
641
664
 
@@ -668,97 +691,53 @@ def plot_bland_altman(
668
691
  Returns:
669
692
  Matplotlib Figure object
670
693
  """
671
- num_params = y_true.shape[1] if y_true.ndim > 1 else 1
672
-
673
- # Handle 1D case
674
- if y_true.ndim == 1:
675
- y_true = y_true.reshape(-1, 1)
676
- y_pred = y_pred.reshape(-1, 1)
677
-
678
- # Calculate mean and difference
694
+ y_true, y_pred, param_names, num_params = _prepare_plot_data(
695
+ y_true, y_pred, param_names, max_samples
696
+ )
679
697
  mean_vals = (y_true + y_pred) / 2
680
698
  diff_vals = y_pred - y_true
681
-
682
- # Downsample for visualization
683
- if len(y_true) > max_samples:
684
- indices = np.random.choice(len(y_true), max_samples, replace=False)
685
- mean_vals = mean_vals[indices]
686
- diff_vals = diff_vals[indices]
687
-
688
- # Calculate grid dimensions
689
- cols = min(num_params, 4)
690
- rows = (num_params + cols - 1) // cols
691
-
692
- # Calculate figure size
693
- subplot_size = FIGURE_WIDTH_INCH / cols
694
- fig, axes = plt.subplots(
695
- rows,
696
- cols,
697
- figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
698
- )
699
- axes = np.array(axes).flatten() if num_params > 1 else [axes]
699
+ fig, axes = _create_subplot_grid(num_params)
700
700
 
701
701
  for i in range(num_params):
702
702
  ax = axes[i]
703
-
704
703
  mean_diff = np.mean(diff_vals[:, i])
705
704
  std_diff = np.std(diff_vals[:, i])
706
-
707
- # Limits of agreement (95% CI = mean ± 1.96*SD)
708
705
  upper_loa = mean_diff + 1.96 * std_diff
709
706
  lower_loa = mean_diff - 1.96 * std_diff
710
707
 
711
- # Scatter plot
712
- ax.scatter(
708
+ ax.plot(
713
709
  mean_vals[:, i],
714
710
  diff_vals[:, i],
715
- alpha=0.5,
716
- s=15,
717
- c=COLORS["primary"],
718
- edgecolors="none",
719
- rasterized=True,
711
+ "o",
712
+ markersize=5,
713
+ markerfacecolor=COLORS["scatter"],
714
+ markeredgecolor="none",
715
+ label="Data" if i == 0 else None,
720
716
  )
721
-
722
- # Mean difference line
723
717
  ax.axhline(
724
718
  y=mean_diff,
725
719
  color=COLORS["accent"],
726
720
  linestyle="-",
727
721
  lw=1.2,
728
- label=f"Mean = {mean_diff:.3f}",
722
+ label="Mean" if i == 0 else None,
729
723
  )
730
-
731
- # Limits of agreement
732
724
  ax.axhline(
733
725
  y=upper_loa,
734
726
  color=COLORS["error"],
735
727
  linestyle="--",
736
728
  lw=1.0,
737
- label=f"+1.96 SD = {upper_loa:.3f}",
729
+ label=r"$\pm$1.96 SD" if i == 0 else None,
738
730
  )
739
- ax.axhline(
740
- y=lower_loa,
741
- color=COLORS["error"],
742
- linestyle="--",
743
- lw=1.0,
744
- label=f"-1.96 SD = {lower_loa:.3f}",
745
- )
746
-
747
- # Zero line
748
- ax.axhline(y=0, color="gray", linestyle=":", lw=0.8, alpha=0.5)
731
+ ax.axhline(y=lower_loa, color=COLORS["error"], linestyle="--", lw=1.0)
732
+ ax.axhline(y=0, color="gray", linestyle=":", lw=0.8)
749
733
 
750
- # Labels
751
- title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
752
- ax.set_title(f"{title}")
753
- ax.set_xlabel("Mean of True and Predicted")
754
- ax.set_ylabel("Difference (Pred - True)")
734
+ ax.set_title(f"{param_names[i]}")
735
+ ax.set_xlabel("Mean of True and Pred")
736
+ ax.set_ylabel("Pred - True")
755
737
  ax.grid(True)
756
- ax.legend(fontsize=6, loc="best")
757
-
758
- # Hide unused subplots
759
- for i in range(num_params, len(axes)):
760
- axes[i].axis("off")
761
738
 
739
+ _hide_unused_subplots(axes, num_params)
740
+ _add_unified_legend(fig, axes, ncol=3)
762
741
  plt.tight_layout()
763
742
  return fig
764
743
 
@@ -828,7 +807,7 @@ def plot_qq(
828
807
  "Zero variance\n(constant errors)",
829
808
  ha="center",
830
809
  va="center",
831
- fontsize=10,
810
+ fontsize=FONT_SIZE_TEXT,
832
811
  transform=ax.transAxes,
833
812
  )
834
813
  ax.set_title(f"{title}\n(zero variance)")
@@ -841,15 +820,14 @@ def plot_qq(
841
820
  # Calculate theoretical quantiles and sample quantiles
842
821
  (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")
843
822
 
844
- # Scatter plot
845
- ax.scatter(
823
+ # Scatter plot (using plot for vector SVG compatibility)
824
+ ax.plot(
846
825
  osm,
847
826
  osr,
848
- alpha=0.5,
849
- s=15,
850
- c=COLORS["primary"],
851
- edgecolors="none",
852
- rasterized=True,
827
+ "o",
828
+ markersize=5,
829
+ markerfacecolor=COLORS["scatter"],
830
+ markeredgecolor="none",
853
831
  label="Data",
854
832
  )
855
833
 
@@ -864,7 +842,7 @@ def plot_qq(
864
842
  ax.set_xlabel("Theoretical Quantiles")
865
843
  ax.set_ylabel("Sample Quantiles")
866
844
  ax.grid(True)
867
- ax.legend(fontsize=6, loc="best")
845
+ ax.legend(fontsize=FONT_SIZE_TICKS, loc="best", fancybox=False, framealpha=1.0)
868
846
 
869
847
  # Hide unused subplots
870
848
  for i in range(num_params, len(axes)):
@@ -896,6 +874,9 @@ def plot_correlation_heatmap(
896
874
  Returns:
897
875
  Matplotlib Figure object
898
876
  """
877
+ from matplotlib.colors import Normalize
878
+ from matplotlib.patches import Rectangle
879
+
899
880
  num_params = y_true.shape[1] if y_true.ndim > 1 else 1
900
881
 
901
882
  if num_params < 2:
@@ -907,7 +888,7 @@ def plot_correlation_heatmap(
907
888
  "Correlation heatmap requires\nat least 2 parameters",
908
889
  ha="center",
909
890
  va="center",
910
- fontsize=10,
891
+ fontsize=FONT_SIZE_TEXT,
911
892
  )
912
893
  ax.axis("off")
913
894
  return fig
@@ -928,12 +909,55 @@ def plot_correlation_heatmap(
928
909
  fig_size = min(FIGURE_WIDTH_INCH * 0.6, 2 + num_params * 0.6)
929
910
  fig, ax = plt.subplots(figsize=(fig_size, fig_size))
930
911
 
931
- # Heatmap
932
- im = ax.imshow(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1, aspect="equal")
912
+ # Heatmap using Rectangle patches (vector-compatible, no imshow)
913
+ cmap = plt.cm.RdBu_r
914
+ norm = Normalize(vmin=-1, vmax=1)
933
915
 
934
- # Colorbar
935
- cbar = plt.colorbar(im, ax=ax, shrink=0.8)
936
- cbar.set_label("Correlation", fontsize=FONT_SIZE_TICKS)
916
+ for i in range(num_params):
917
+ for j in range(num_params):
918
+ color = cmap(norm(corr_matrix[i, j]))
919
+ rect = Rectangle(
920
+ (j - 0.5, i - 0.5), # bottom-left corner
921
+ 1,
922
+ 1, # width, height
923
+ facecolor=color,
924
+ edgecolor="white",
925
+ linewidth=0.5,
926
+ )
927
+ ax.add_patch(rect)
928
+
929
+ # Set axis limits and aspect
930
+ ax.set_xlim(-0.5, num_params - 0.5)
931
+ ax.set_ylim(num_params - 0.5, -0.5) # Invert y-axis for matrix orientation
932
+ ax.set_aspect("equal")
933
+
934
+ # Vector colorbar using rectangles (no raster gradient)
935
+ # Create a separate axes for colorbar
936
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
937
+
938
+ divider = make_axes_locatable(ax)
939
+ cax = divider.append_axes("right", size="5%", pad=0.1)
940
+
941
+ # Draw colorbar as discrete rectangles (20 segments for smooth gradient)
942
+ n_segments = 20
943
+ for k in range(n_segments):
944
+ val = -1 + 2 * k / (n_segments - 1) # -1 to 1
945
+ color = cmap(norm(val))
946
+ rect = Rectangle(
947
+ (0, val - 1 / n_segments),
948
+ 1,
949
+ 2 / n_segments,
950
+ facecolor=color,
951
+ edgecolor="none",
952
+ )
953
+ cax.add_patch(rect)
954
+
955
+ cax.set_xlim(0, 1)
956
+ cax.set_ylim(-1, 1)
957
+ cax.set_xticks([])
958
+ cax.set_ylabel("Correlation", fontsize=FONT_SIZE_TEXT)
959
+ cax.yaxis.set_label_position("right")
960
+ cax.yaxis.tick_right()
937
961
 
938
962
  # Labels
939
963
  ax.set_xticks(range(num_params))
@@ -944,14 +968,14 @@ def plot_correlation_heatmap(
944
968
  # Annotate with values
945
969
  for i in range(num_params):
946
970
  for j in range(num_params):
947
- color = "white" if abs(corr_matrix[i, j]) > 0.5 else "black"
971
+ text_color = "white" if abs(corr_matrix[i, j]) > 0.5 else "black"
948
972
  ax.text(
949
973
  j,
950
974
  i,
951
975
  f"{corr_matrix[i, j]:.2f}",
952
976
  ha="center",
953
977
  va="center",
954
- color=color,
978
+ color=text_color,
955
979
  fontsize=FONT_SIZE_TICKS,
956
980
  )
957
981
 
@@ -985,74 +1009,40 @@ def plot_relative_error(
985
1009
  Returns:
986
1010
  Matplotlib Figure object
987
1011
  """
988
- num_params = y_true.shape[1] if y_true.ndim > 1 else 1
989
-
990
- # Handle 1D case
991
- if y_true.ndim == 1:
992
- y_true = y_true.reshape(-1, 1)
993
- y_pred = y_pred.reshape(-1, 1)
994
-
995
- # Downsample for visualization
996
- if len(y_true) > max_samples:
997
- indices = np.random.choice(len(y_true), max_samples, replace=False)
998
- y_true = y_true[indices]
999
- y_pred = y_pred[indices]
1012
+ y_true, y_pred, param_names, num_params = _prepare_plot_data(
1013
+ y_true, y_pred, param_names, max_samples
1014
+ )
1000
1015
 
1001
1016
  # Calculate relative error (avoid division by zero)
1002
1017
  with np.errstate(divide="ignore", invalid="ignore"):
1003
1018
  rel_error = np.abs((y_pred - y_true) / y_true) * 100
1004
1019
  rel_error = np.nan_to_num(rel_error, nan=0.0, posinf=0.0, neginf=0.0)
1005
1020
 
1006
- # Calculate grid dimensions
1007
- cols = min(num_params, 4)
1008
- rows = (num_params + cols - 1) // cols
1009
-
1010
- # Calculate figure size
1011
- subplot_size = FIGURE_WIDTH_INCH / cols
1012
- fig, axes = plt.subplots(
1013
- rows,
1014
- cols,
1015
- figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
1016
- )
1017
- axes = np.array(axes).flatten() if num_params > 1 else [axes]
1021
+ fig, axes = _create_subplot_grid(num_params)
1018
1022
 
1019
1023
  for i in range(num_params):
1020
1024
  ax = axes[i]
1021
-
1022
- # Scatter plot
1023
- ax.scatter(
1025
+ ax.plot(
1024
1026
  y_true[:, i],
1025
1027
  rel_error[:, i],
1026
- alpha=0.5,
1027
- s=15,
1028
- c=COLORS["primary"],
1029
- edgecolors="none",
1030
- rasterized=True,
1028
+ "o",
1029
+ markersize=5,
1030
+ markerfacecolor=COLORS["scatter"],
1031
+ markeredgecolor="none",
1031
1032
  label="Data",
1032
1033
  )
1033
-
1034
- # Mean relative error line
1035
1034
  mean_rel = np.mean(rel_error[:, i])
1036
1035
  ax.axhline(
1037
- y=mean_rel,
1038
- color=COLORS["accent"],
1039
- linestyle="-",
1040
- lw=1.2,
1041
- label=f"Mean = {mean_rel:.2f}%",
1036
+ y=mean_rel, color=COLORS["accent"], linestyle="-", lw=1.2, label="Mean"
1042
1037
  )
1043
1038
 
1044
- # Labels
1045
- title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
1046
- ax.set_title(f"{title}")
1039
+ ax.set_title(f"{param_names[i]}")
1047
1040
  ax.set_xlabel("True Value")
1048
- ax.set_ylabel("Relative Error (%)")
1041
+ ax.set_ylabel("Relative Error (\\%)")
1049
1042
  ax.grid(True)
1050
- ax.legend(fontsize=6, loc="best")
1051
-
1052
- # Hide unused subplots
1053
- for i in range(num_params, len(axes)):
1054
- axes[i].axis("off")
1043
+ ax.legend(fontsize=FONT_SIZE_TICKS, loc="best", fancybox=False, framealpha=1.0)
1055
1044
 
1045
+ _hide_unused_subplots(axes, num_params)
1056
1046
  plt.tight_layout()
1057
1047
  return fig
1058
1048
 
@@ -1096,7 +1086,7 @@ def plot_error_cdf(
1096
1086
  with np.errstate(divide="ignore", invalid="ignore"):
1097
1087
  errors = np.abs((y_pred - y_true) / y_true) * 100
1098
1088
  errors = np.nan_to_num(errors, nan=0.0, posinf=0.0, neginf=0.0)
1099
- xlabel = "Relative Error (%)"
1089
+ xlabel = r"Relative Error\;$(\%)$"
1100
1090
  else:
1101
1091
  errors = np.abs(y_pred - y_true)
1102
1092
  xlabel = "Absolute Error"
@@ -1121,15 +1111,15 @@ def plot_error_cdf(
1121
1111
 
1122
1112
  # Find 95th percentile (use np.percentile for accuracy)
1123
1113
  p95_val = np.percentile(errors[:, i], 95)
1124
- ax.axvline(x=p95_val, color=color, linestyle=":", alpha=0.5)
1114
+ ax.axvline(x=p95_val, color=color, linestyle=":")
1125
1115
 
1126
1116
  # Reference lines
1127
- ax.axhline(y=95, color="gray", linestyle="--", lw=0.8, alpha=0.7, label="95%")
1117
+ ax.axhline(y=95, color="gray", linestyle="--", lw=0.8, label=r"95\%")
1128
1118
 
1129
1119
  ax.set_xlabel(xlabel)
1130
- ax.set_ylabel("Cumulative Percentage (%)")
1120
+ ax.set_ylabel(r"Cumulative Percentage\;$(\%)$")
1131
1121
  ax.set_title("Cumulative Error Distribution")
1132
- ax.legend(fontsize=6, loc="best")
1122
+ ax.legend(fontsize=FONT_SIZE_TICKS, loc="best", fancybox=False, framealpha=1.0)
1133
1123
  ax.grid(True)
1134
1124
  ax.set_ylim(0, 105)
1135
1125
 
@@ -1161,67 +1151,38 @@ def plot_prediction_vs_index(
1161
1151
  Returns:
1162
1152
  Matplotlib Figure object
1163
1153
  """
1164
- num_params = y_true.shape[1] if y_true.ndim > 1 else 1
1165
-
1166
- # Handle 1D case
1167
- if y_true.ndim == 1:
1168
- y_true = y_true.reshape(-1, 1)
1169
- y_pred = y_pred.reshape(-1, 1)
1170
-
1171
- # Limit samples
1172
- n_samples = min(len(y_true), max_samples)
1173
- y_true = y_true[:n_samples]
1174
- y_pred = y_pred[:n_samples]
1175
- indices = np.arange(n_samples)
1176
-
1177
- # Calculate grid dimensions
1178
- cols = min(num_params, 4)
1179
- rows = (num_params + cols - 1) // cols
1180
-
1181
- # Calculate figure size
1182
- subplot_size = FIGURE_WIDTH_INCH / cols
1183
- fig, axes = plt.subplots(
1184
- rows,
1185
- cols,
1186
- figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
1154
+ y_true, y_pred, param_names, num_params = _prepare_plot_data(
1155
+ y_true, y_pred, param_names, max_samples
1187
1156
  )
1188
- axes = np.array(axes).flatten() if num_params > 1 else [axes]
1157
+ indices = np.arange(len(y_true))
1158
+ fig, axes = _create_subplot_grid(num_params)
1189
1159
 
1190
1160
  for i in range(num_params):
1191
1161
  ax = axes[i]
1192
-
1193
- # Plot true and predicted
1194
1162
  ax.plot(
1195
1163
  indices,
1196
1164
  y_true[:, i],
1197
1165
  "o",
1198
- markersize=3,
1199
- alpha=0.6,
1166
+ markersize=5,
1200
1167
  color=COLORS["primary"],
1201
- label="True",
1168
+ label="True" if i == 0 else None,
1202
1169
  )
1203
1170
  ax.plot(
1204
1171
  indices,
1205
1172
  y_pred[:, i],
1206
1173
  "x",
1207
- markersize=3,
1208
- alpha=0.6,
1174
+ markersize=5,
1209
1175
  color=COLORS["error"],
1210
- label="Predicted",
1176
+ label="Predicted" if i == 0 else None,
1211
1177
  )
1212
1178
 
1213
- # Labels
1214
- title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
1215
- ax.set_title(f"{title}")
1179
+ ax.set_title(f"{param_names[i]}")
1216
1180
  ax.set_xlabel("Sample Index")
1217
1181
  ax.set_ylabel("Value")
1218
1182
  ax.grid(True)
1219
- ax.legend(fontsize=6, loc="best")
1220
-
1221
- # Hide unused subplots
1222
- for i in range(num_params, len(axes)):
1223
- axes[i].axis("off")
1224
1183
 
1184
+ _hide_unused_subplots(axes, num_params)
1185
+ _add_unified_legend(fig, axes, ncol=2)
1225
1186
  plt.tight_layout()
1226
1187
  return fig
1227
1188
 
@@ -1290,7 +1251,7 @@ def plot_error_boxplot(
1290
1251
 
1291
1252
  # Zero line for signed errors
1292
1253
  if not use_relative:
1293
- ax.axhline(y=0, color=COLORS["error"], linestyle="--", lw=1.0, alpha=0.7)
1254
+ ax.axhline(y=0, color=COLORS["error"], linestyle="--", lw=1.0)
1294
1255
 
1295
1256
  ax.set_ylabel(ylabel)
1296
1257
  ax.set_title("Error Distribution by Parameter")