wavedl 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1236 @@
1
+ """
2
+ Scientific Metrics and Visualization Utilities
3
+ ===============================================
4
+
5
+ Provides metric tracking, statistical calculations, and publication-quality
6
+ visualization tools for deep learning experiments.
7
+
8
+ Author: Ductho Le (ductho.le@outlook.com)
9
+ Version: 1.2.0
10
+ """
11
+
12
+ from typing import Any
13
+
14
+ import matplotlib.pyplot as plt
15
+ import numpy as np
16
+ from scipy.stats import pearsonr
17
+ from sklearn.metrics import r2_score
18
+
19
+
20
# ==============================================================================
# PUBLICATION-QUALITY PLOT CONFIGURATION
# ==============================================================================
# These module-level constants are shared by configure_matplotlib_style() and
# every plotting helper in this module.
# Width: 19 cm = 7.48 inches (for two-column journals)
FIGURE_WIDTH_CM = 19
FIGURE_WIDTH_INCH = FIGURE_WIDTH_CM / 2.54

# Font sizes in points (consistent 8pt for publication)
FONT_SIZE_TEXT = 8
FONT_SIZE_TICKS = 8
FONT_SIZE_TITLE = 9

# DPI for publication (300 for print, 150 for screen)
FIGURE_DPI = 300

# Color palette (accessible, print-friendly); keys are referenced by name
# (e.g. COLORS["primary"]) throughout the plotting functions below.
COLORS = {
    "primary": "#2E86AB",  # Steel blue
    "secondary": "#A23B72",  # Raspberry
    "accent": "#F18F01",  # Orange
    "neutral": "#6B717E",  # Slate gray
    "error": "#C73E1D",  # Red
    "success": "#3A7D44",  # Green
}
44
+
45
+
46
def configure_matplotlib_style():
    """Apply a publication-quality, LaTeX-like style to matplotlib globally.

    Mutates ``plt.rcParams`` in place; affects every figure created after
    the call.
    """
    # Typography: serif text with Computer Modern math.
    fonts = {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "DejaVu Serif", "serif"],
        "mathtext.fontset": "cm",  # Computer Modern for math
        "font.size": FONT_SIZE_TEXT,
        "axes.titlesize": FONT_SIZE_TITLE,
        "axes.labelsize": FONT_SIZE_TEXT,
        "xtick.labelsize": FONT_SIZE_TICKS,
        "ytick.labelsize": FONT_SIZE_TICKS,
        "legend.fontsize": FONT_SIZE_TICKS,
    }
    # Line weights and a subtle dotted grid.
    lines_and_grid = {
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.5,
        "lines.linewidth": 1.5,
        "grid.alpha": 0.4,
        "grid.linestyle": ":",
    }
    # Figure rendering and export defaults.
    figure = {
        "figure.dpi": FIGURE_DPI,
        "savefig.dpi": FIGURE_DPI,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.05,
        # Hide top/right spines for a cleaner look
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
    plt.rcParams.update({**fonts, **lines_and_grid, **figure})
78
+
79
+
80
# Apply the publication style as soon as this module is imported so every
# figure created afterwards picks it up (deliberate module-level side effect).
configure_matplotlib_style()
82
+
83
+
84
+ # ==============================================================================
85
+ # METRIC TRACKING
86
+ # ==============================================================================
87
class MetricTracker:
    """
    Maintains the running average of a scalar metric.

    Accumulates a cumulative (weighted) sum and a sample count so the mean
    over everything reported so far can be read at any time.  When no
    samples have been recorded, the average is 0.0 instead of raising
    ZeroDivisionError.  (No internal locking is performed; updates from a
    single thread are assumed.)

    Attributes:
        val: Most recently reported value
        avg: Running average over all samples
        sum: Cumulative weighted sum
        count: Total number of samples seen

    Example:
        tracker = MetricTracker()
        for batch in dataloader:
            loss = compute_loss(batch)
            tracker.update(loss.item(), n=batch_size)
        print(f"Average loss: {tracker.avg}")
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Return the tracker to its freshly-constructed, empty state."""
        self.val: float = 0.0
        self.avg: float = 0.0
        self.sum: float = 0.0
        self.count: float = 0.0

    def update(self, val: float, n: int = 1):
        """
        Record a new observation.

        Args:
            val: Observed value (the mean, if it summarizes n samples)
            n: How many samples the value represents
        """
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        # Guard the empty case so avg stays well-defined.
        self.avg = (self.sum / self.count) if self.count > 0 else 0.0

    def __repr__(self) -> str:
        return (
            f"MetricTracker(val={self.val:.4f}, avg={self.avg:.4f}, count={self.count})"
        )
135
+
136
+
137
+ # ==============================================================================
138
+ # STATISTICAL METRICS
139
+ # ==============================================================================
140
def get_lr(optimizer) -> float:
    """
    Extract the current learning rate from an optimizer.

    Args:
        optimizer: PyTorch optimizer instance

    Returns:
        Learning rate of the first parameter group, or 0.0 if the
        optimizer has no parameter groups at all.
    """
    groups = optimizer.param_groups
    if not groups:
        return 0.0
    return groups[0]["lr"]
153
+
154
+
155
def calc_pearson(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Mean Pearson correlation coefficient across all target columns.

    Columns whose true or predicted values are (near-)constant would make
    the correlation undefined, so they contribute 0.0 instead of NaN.
    Useful for physics-based signal regression reporting.

    Args:
        y_true: Ground truth values of shape (N,) or (N, num_targets)
        y_pred: Predicted values of the same shape

    Returns:
        Mean Pearson correlation across all targets
    """
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    per_target = []
    for col in range(y_true.shape[1]):
        t = y_true[:, col]
        p = y_pred[:, col]
        # Degenerate (near-constant) column: correlation is undefined.
        if min(np.std(t), np.std(p)) < 1e-9:
            per_target.append(0.0)
            continue
        r, _ = pearsonr(t, p)
        # Safety net: pearsonr should not return NaN after the std check.
        per_target.append(r if not np.isnan(r) else 0.0)

    return float(np.mean(per_target))
187
+
188
+
189
def calc_per_target_r2(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """
    Calculate R² score for each target independently.

    A target for which sklearn raises ValueError (e.g. fewer than two
    samples) falls back to a score of 0.0.  The fallback now applies to
    the 1-D input path as well as the multi-target path, which previously
    diverged (1-D inputs would propagate the exception).

    Args:
        y_true: Ground truth values of shape (N,) or (N, num_targets)
        y_pred: Predicted values of the same shape

    Returns:
        Array of R² scores, one per target
    """

    def _safe_r2(t: np.ndarray, p: np.ndarray) -> float:
        # Shared fallback so 1-D and N-D inputs behave consistently.
        try:
            return r2_score(t, p)
        except ValueError:
            return 0.0

    if y_true.ndim == 1:
        return np.array([_safe_r2(y_true, y_pred)])

    return np.array(
        [_safe_r2(y_true[:, i], y_pred[:, i]) for i in range(y_true.shape[1])]
    )
212
+
213
+
214
+ # ==============================================================================
215
+ # VISUALIZATION - SCATTER PLOTS
216
+ # ==============================================================================
217
def plot_scientific_scatter(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 2000,
) -> plt.Figure:
    """
    Generate publication-quality scatter plots comparing predictions to ground truth.

    Creates a grid of scatter plots (up to 4 columns), one per output target,
    each annotated with its R² score and an ideal y=x reference line.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values of the same shape
        param_names: Optional list of parameter names for subplot titles
        max_samples: Maximum samples to plot (randomly downsampled if exceeded)

    Returns:
        Matplotlib Figure object (can be saved or logged to WandB)
    """
    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    # Promote 1D inputs to a single-column 2D shape for uniform indexing
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Downsample for visualization performance.
    # NOTE(review): uses the global NumPy RNG (unseeded), so the plotted
    # subset is not reproducible across calls — confirm this is acceptable.
    if len(y_true) > max_samples:
        indices = np.random.choice(len(y_true), max_samples, replace=False)
        y_true = y_true[indices]
        y_pred = y_pred[indices]

    # Grid layout: at most 4 columns, as many rows as needed (ceil division)
    cols = min(num_params, 4)
    rows = (num_params + cols - 1) // cols

    # Figure sized to the 19 cm journal width; square-ish subplots
    subplot_size = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, subplot_size * rows),
    )
    # Normalize axes to a flat sequence regardless of grid shape
    axes = np.array(axes).flatten() if num_params > 1 else [axes]

    for i in range(num_params):
        ax = axes[i]

        # R² for this target; undefined with fewer than 2 samples
        if len(y_true) >= 2:
            r2 = r2_score(y_true[:, i], y_pred[:, i])
        else:
            r2 = float("nan")

        # Scatter plot (rasterized to keep vector exports small)
        ax.scatter(
            y_true[:, i],
            y_pred[:, i],
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
            label="Data",
        )

        # Ideal y = x diagonal with a 5% margin around the data range.
        # NOTE(review): if a column is constant, margin is 0 and the axis
        # limits collapse to a point — confirm inputs always vary.
        min_val = min(y_true[:, i].min(), y_pred[:, i].min())
        max_val = max(y_true[:, i].max(), y_pred[:, i].max())
        margin = (max_val - min_val) * 0.05
        ax.plot(
            [min_val - margin, max_val + margin],
            [min_val - margin, max_val + margin],
            "--",
            color=COLORS["error"],
            lw=1.2,
            alpha=0.8,
            label="Ideal",
        )

        # Labels and formatting
        title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
        ax.set_title(f"{title}\n$R^2 = {r2:.4f}$")
        ax.set_xlabel("Ground Truth")
        ax.set_ylabel("Prediction")
        ax.grid(True)
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlim(min_val - margin, max_val + margin)
        ax.set_ylim(min_val - margin, max_val + margin)
        ax.legend(fontsize=6, loc="best")

    # Hide unused subplots (grid cells beyond num_params)
    for i in range(num_params, len(axes)):
        axes[i].axis("off")

    plt.tight_layout()
    return fig
316
+
317
+
318
+ # ==============================================================================
319
+ # VISUALIZATION - ERROR HISTOGRAM
320
+ # ==============================================================================
321
def plot_error_histogram(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    bins: int = 50,
) -> plt.Figure:
    """
    Generate publication-quality error distribution histograms.

    Shows the distribution of signed prediction errors (y_pred - y_true) for
    each target, with a zero reference line, a mean-error line, and MAE/σ
    annotations in each subplot title.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values of the same shape
        param_names: Optional list of parameter names for subplot titles
        bins: Number of histogram bins

    Returns:
        Matplotlib Figure object
    """
    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    # Promote 1D inputs to a single-column 2D shape for uniform indexing
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Signed errors: positive means over-prediction
    errors = y_pred - y_true

    # Grid layout: at most 4 columns, as many rows as needed (ceil division)
    cols = min(num_params, 4)
    rows = (num_params + cols - 1) // cols

    # Figure sized to the 19 cm journal width; 0.8 flattens the subplots
    subplot_size = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
    )
    # Normalize axes to a flat sequence regardless of grid shape
    axes = np.array(axes).flatten() if num_params > 1 else [axes]

    for i in range(num_params):
        ax = axes[i]
        err = errors[:, i]

        # Summary statistics for annotations
        mean_err = np.mean(err)
        std_err = np.std(err)
        mae = np.mean(np.abs(err))

        # Histogram of signed errors
        ax.hist(
            err,
            bins=bins,
            color=COLORS["primary"],
            alpha=0.7,
            edgecolor="white",
            linewidth=0.5,
            label="Errors",
        )

        # Zero-error reference line
        ax.axvline(
            x=0, color=COLORS["error"], linestyle="--", lw=1.2, alpha=0.8, label="Zero"
        )

        # Mean error line (distance from zero reveals systematic bias)
        ax.axvline(
            x=mean_err,
            color=COLORS["accent"],
            linestyle="-",
            lw=1.2,
            label=f"Mean = {mean_err:.4f}",
        )

        # Labels and formatting
        title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
        ax.set_title(f"{title}\nMAE = {mae:.4f}, $\\sigma$ = {std_err:.4f}")
        ax.set_xlabel("Prediction Error")
        ax.set_ylabel("Count")
        ax.grid(True, axis="y")
        ax.legend(fontsize=6, loc="best")

    # Hide unused subplots (grid cells beyond num_params)
    for i in range(num_params, len(axes)):
        axes[i].axis("off")

    plt.tight_layout()
    return fig
413
+
414
+
415
+ # ==============================================================================
416
+ # VISUALIZATION - RESIDUAL PLOT
417
+ # ==============================================================================
418
def plot_residuals(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 2000,
) -> plt.Figure:
    """
    Generate publication-quality residual plots.

    Shows residuals (y_pred - y_true) vs predicted values. Useful for
    detecting systematic bias or heteroscedasticity in predictions.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values of the same shape
        param_names: Optional list of parameter names for subplot titles
        max_samples: Maximum samples to plot (randomly downsampled if exceeded)

    Returns:
        Matplotlib Figure object
    """
    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    # Promote 1D inputs to a single-column 2D shape for uniform indexing
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Residuals are computed on the FULL data before any downsampling
    residuals = y_pred - y_true

    # Downsample for visualization; only y_pred and residuals are plotted,
    # so y_true is intentionally left untouched here.
    # NOTE(review): the mean-residual line below is therefore computed on
    # the plotted subsample, not the full dataset — confirm intended.
    if len(y_true) > max_samples:
        indices = np.random.choice(len(y_true), max_samples, replace=False)
        y_pred = y_pred[indices]
        residuals = residuals[indices]

    # Grid layout: at most 4 columns, as many rows as needed (ceil division)
    cols = min(num_params, 4)
    rows = (num_params + cols - 1) // cols

    # Figure sized to the 19 cm journal width; 0.8 flattens the subplots
    subplot_size = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
    )
    # Normalize axes to a flat sequence regardless of grid shape
    axes = np.array(axes).flatten() if num_params > 1 else [axes]

    for i in range(num_params):
        ax = axes[i]

        # Residuals vs predicted value (rasterized for small vector exports)
        ax.scatter(
            y_pred[:, i],
            residuals[:, i],
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
            label="Data",
        )

        # Zero-residual reference line
        ax.axhline(
            y=0, color=COLORS["error"], linestyle="--", lw=1.2, alpha=0.8, label="Zero"
        )

        # Mean residual line (offset from zero indicates systematic bias)
        mean_res = np.mean(residuals[:, i])
        ax.axhline(
            y=mean_res,
            color=COLORS["accent"],
            linestyle="-",
            lw=1.0,
            alpha=0.7,
            label=f"Mean = {mean_res:.4f}",
        )

        # Labels
        title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
        ax.set_title(f"{title}")
        ax.set_xlabel("Predicted Value")
        ax.set_ylabel("Residual (Pred - True)")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Hide unused subplots (grid cells beyond num_params)
    for i in range(num_params, len(axes)):
        axes[i].axis("off")

    plt.tight_layout()
    return fig
513
+
514
+
515
+ # ==============================================================================
516
+ # VISUALIZATION - TRAINING CURVES
517
+ # ==============================================================================
518
def create_training_curves(
    history: list[dict[str, Any]],
    metrics: list[str] | None = None,
    show_lr: bool = True,
) -> plt.Figure:
    """
    Create training curve visualization from history with optional learning rate.

    Plots training and validation loss over epochs on a log-scaled axis. If
    learning rate data is available in history, it's plotted on a secondary
    y-axis (also log-scaled).

    Args:
        history: List of epoch statistics dictionaries. Each dict should contain
            'epoch', 'train_loss', 'val_loss', and optionally 'lr'.
        metrics: Metric names to plot on primary y-axis. Defaults to
            ["train_loss", "val_loss"] when None.
        show_lr: If True and 'lr' is in history, show learning rate on secondary axis

    Returns:
        Matplotlib Figure object
    """
    # Fix: avoid a mutable default argument (shared list across calls, B006);
    # the effective default is unchanged.
    if metrics is None:
        metrics = ["train_loss", "val_loss"]

    # Fall back to 1-based positions when an entry carries no 'epoch' key
    epochs = [h.get("epoch", i + 1) for i, h in enumerate(history)]

    fig, ax1 = plt.subplots(figsize=(FIGURE_WIDTH_INCH * 0.7, FIGURE_WIDTH_INCH * 0.4))

    colors = [
        COLORS["primary"],
        COLORS["secondary"],
        COLORS["accent"],
        COLORS["neutral"],
    ]

    # Plot metrics on primary axis; missing values become NaN (gap in curve)
    lines = []
    for idx, metric in enumerate(metrics):
        values = [h.get(metric, np.nan) for h in history]
        color = colors[idx % len(colors)]
        (line,) = ax1.plot(
            epochs,
            values,
            label=metric.replace("_", " ").title(),
            linewidth=1.5,
            color=color,
        )
        lines.append(line)

    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.set_yscale("log")  # Log scale for loss
    ax1.grid(True, alpha=0.3)

    # Learning-rate overlay only when any history entry actually has 'lr'
    has_lr = show_lr and any("lr" in h for h in history)

    if has_lr:
        # Create secondary y-axis for learning rate
        ax2 = ax1.twinx()
        lr_values = [h.get("lr", np.nan) for h in history]
        (line_lr,) = ax2.plot(
            epochs,
            lr_values,
            "--",
            color=COLORS["neutral"],
            linewidth=1.0,
            alpha=0.7,
            label="Learning Rate",
        )
        ax2.set_ylabel("Learning Rate", color=COLORS["neutral"])
        ax2.tick_params(axis="y", labelcolor=COLORS["neutral"])
        ax2.set_yscale("log")  # Log scale for LR
        lines.append(line_lr)

    # Combined legend across both axes
    labels = [l.get_label() for l in lines]
    ax1.legend(lines, labels, loc="best", fontsize=6)

    ax1.set_title("Training Curves")

    plt.tight_layout()
    return fig
597
+
598
+
599
+ # ==============================================================================
600
+ # VISUALIZATION - BLAND-ALTMAN PLOT
601
+ # ==============================================================================
602
def plot_bland_altman(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 2000,
) -> plt.Figure:
    """
    Generate Bland-Altman plots for method comparison.

    Plots the difference between predictions and ground truth against their
    mean. Includes the mean-difference line and ±1.96*SD limits of agreement.
    Standard for validating agreement in medical/scientific papers.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values of the same shape
        param_names: Optional list of parameter names for subplot titles
        max_samples: Maximum samples to plot (randomly downsampled if exceeded)

    Returns:
        Matplotlib Figure object
    """
    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    # Promote 1D inputs to a single-column 2D shape for uniform indexing
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Bland-Altman coordinates: per-sample mean vs per-sample difference
    mean_vals = (y_true + y_pred) / 2
    diff_vals = y_pred - y_true

    # Downsample for visualization.
    # NOTE(review): mean/SD (and hence the limits of agreement) below are
    # computed from the plotted subsample, not the full data — confirm.
    if len(y_true) > max_samples:
        indices = np.random.choice(len(y_true), max_samples, replace=False)
        mean_vals = mean_vals[indices]
        diff_vals = diff_vals[indices]

    # Grid layout: at most 4 columns, as many rows as needed (ceil division)
    cols = min(num_params, 4)
    rows = (num_params + cols - 1) // cols

    # Figure sized to the 19 cm journal width; 0.8 flattens the subplots
    subplot_size = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
    )
    # Normalize axes to a flat sequence regardless of grid shape
    axes = np.array(axes).flatten() if num_params > 1 else [axes]

    for i in range(num_params):
        ax = axes[i]

        mean_diff = np.mean(diff_vals[:, i])
        std_diff = np.std(diff_vals[:, i])

        # Limits of agreement (95% CI = mean ± 1.96*SD)
        upper_loa = mean_diff + 1.96 * std_diff
        lower_loa = mean_diff - 1.96 * std_diff

        # Scatter of mean vs difference (rasterized for small vector exports)
        ax.scatter(
            mean_vals[:, i],
            diff_vals[:, i],
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
        )

        # Mean difference line (systematic bias)
        ax.axhline(
            y=mean_diff,
            color=COLORS["accent"],
            linestyle="-",
            lw=1.2,
            label=f"Mean = {mean_diff:.3f}",
        )

        # Upper and lower limits of agreement
        ax.axhline(
            y=upper_loa,
            color=COLORS["error"],
            linestyle="--",
            lw=1.0,
            label=f"+1.96 SD = {upper_loa:.3f}",
        )
        ax.axhline(
            y=lower_loa,
            color=COLORS["error"],
            linestyle="--",
            lw=1.0,
            label=f"-1.96 SD = {lower_loa:.3f}",
        )

        # Zero-difference reference line
        ax.axhline(y=0, color="gray", linestyle=":", lw=0.8, alpha=0.5)

        # Labels
        title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
        ax.set_title(f"{title}")
        ax.set_xlabel("Mean of True and Predicted")
        ax.set_ylabel("Difference (Pred - True)")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Hide unused subplots (grid cells beyond num_params)
    for i in range(num_params, len(axes)):
        axes[i].axis("off")

    plt.tight_layout()
    return fig
717
+
718
+
719
+ # ==============================================================================
720
+ # VISUALIZATION - QQ PLOT
721
+ # ==============================================================================
722
def plot_qq(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
) -> plt.Figure:
    """
    Generate Q-Q plots to check if prediction errors are normally distributed.

    Compares the quantiles of the error distribution to a theoretical normal
    distribution. Points on the diagonal indicate normally distributed errors.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values of the same shape
        param_names: Optional list of parameter names for subplot titles

    Returns:
        Matplotlib Figure object
    """
    from scipy import stats

    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    # Promote 1D inputs to a single-column 2D shape for uniform indexing
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Signed errors per target
    errors = y_pred - y_true

    # Grid layout: at most 4 columns, as many rows as needed (ceil division)
    cols = min(num_params, 4)
    rows = (num_params + cols - 1) // cols

    # Figure sized to the 19 cm journal width
    subplot_size = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, subplot_size * rows),
    )
    # Normalize axes to a flat sequence regardless of grid shape
    axes = np.array(axes).flatten() if num_params > 1 else [axes]

    for i in range(num_params):
        ax = axes[i]

        # Standardize errors for the QQ plot.
        # Fix: guard against a zero-variance error column — dividing by a
        # zero std would propagate NaNs into probplot. Constant errors are
        # centered and left unscaled instead.
        err = errors[:, i]
        std = np.std(err)
        standardized = (err - np.mean(err)) / (std if std > 0 else 1.0)

        # Theoretical (osm) vs sample (osr) quantiles, plus fitted line
        (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")

        # Scatter of sample vs theoretical quantiles
        ax.scatter(
            osm,
            osr,
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
            label="Data",
        )

        # Reference line: perfect normality follows this fitted diagonal
        line_x = np.array([osm.min(), osm.max()])
        line_y = slope * line_x + intercept
        ax.plot(line_x, line_y, "--", color=COLORS["error"], lw=1.2, label="Normal")

        # Labels (r from probplot is the fit correlation; r² annotated)
        title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
        ax.set_title(f"{title}\n$R^2 = {r**2:.4f}$")
        ax.set_xlabel("Theoretical Quantiles")
        ax.set_ylabel("Sample Quantiles")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Hide unused subplots (grid cells beyond num_params)
    for i in range(num_params, len(axes)):
        axes[i].axis("off")

    plt.tight_layout()
    return fig
807
+
808
+
809
+ # ==============================================================================
810
+ # VISUALIZATION - CORRELATION HEATMAP
811
+ # ==============================================================================
812
def plot_correlation_heatmap(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
) -> plt.Figure:
    """
    Generate correlation heatmap between prediction errors of the parameters.

    Shows the Pearson correlation between the error vectors of different
    output parameters, useful for understanding multi-output prediction
    relationships.

    Args:
        y_true: Ground truth values of shape (N, num_targets)
        y_pred: Predicted values of shape (N, num_targets)
        param_names: Optional list of parameter names for labels

    Returns:
        Matplotlib Figure object
    """
    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    if num_params < 2:
        # Need at least 2 params for a correlation matrix; return a
        # placeholder figure instead of raising. (This also covers all 1-D
        # inputs, so no 1-D reshape is needed past this point — the
        # original unreachable reshape branch was removed.)
        fig, ax = plt.subplots(figsize=(4, 3))
        ax.text(
            0.5,
            0.5,
            "Correlation heatmap requires\nat least 2 parameters",
            ha="center",
            va="center",
            fontsize=10,
        )
        ax.axis("off")
        return fig

    if param_names is None or len(param_names) != num_params:
        param_names = [f"P{i}" for i in range(num_params)]

    # Correlation of prediction errors between targets (num_params x num_params)
    errors = y_pred - y_true
    corr_matrix = np.corrcoef(errors.T)

    # Square figure, scaled with the number of parameters but capped
    fig_size = min(FIGURE_WIDTH_INCH * 0.6, 2 + num_params * 0.6)
    fig, ax = plt.subplots(figsize=(fig_size, fig_size))

    # Diverging colormap centered at 0, fixed to the full [-1, 1] range
    im = ax.imshow(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1, aspect="equal")

    # Colorbar
    cbar = plt.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label("Correlation", fontsize=FONT_SIZE_TICKS)

    # Axis tick labels: one per parameter
    ax.set_xticks(range(num_params))
    ax.set_yticks(range(num_params))
    ax.set_xticklabels(param_names, rotation=45, ha="right")
    ax.set_yticklabels(param_names)

    # Annotate each cell; white text on strongly-colored (|r| > 0.5) cells
    for i in range(num_params):
        for j in range(num_params):
            color = "white" if abs(corr_matrix[i, j]) > 0.5 else "black"
            ax.text(
                j,
                i,
                f"{corr_matrix[i, j]:.2f}",
                ha="center",
                va="center",
                color=color,
                fontsize=FONT_SIZE_TICKS,
            )

    ax.set_title("Error Correlation Matrix")

    plt.tight_layout()
    return fig
894
+
895
+
896
+ # ==============================================================================
897
+ # VISUALIZATION - RELATIVE ERROR PLOT
898
+ # ==============================================================================
899
def plot_relative_error(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 2000,
) -> plt.Figure:
    """
    Generate relative error plots (percentage error vs true value).

    Useful for detecting scale-dependent bias where errors increase
    proportionally with the magnitude of the true value.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values of the same shape
        param_names: Optional list of parameter names for subplot titles
        max_samples: Maximum samples to plot (randomly downsampled if exceeded)

    Returns:
        Matplotlib Figure object
    """
    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    # Promote 1D inputs to a single-column 2D shape for uniform indexing
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Downsample for visualization
    if len(y_true) > max_samples:
        indices = np.random.choice(len(y_true), max_samples, replace=False)
        y_true = y_true[indices]
        y_pred = y_pred[indices]

    # Relative error in percent. NOTE(review): samples with y_true == 0
    # are mapped to 0% (not excluded), which understates error for those
    # samples and drags the mean down — confirm this is the intended
    # handling of zero-valued targets.
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_error = np.abs((y_pred - y_true) / y_true) * 100
        rel_error = np.nan_to_num(rel_error, nan=0.0, posinf=0.0, neginf=0.0)

    # Grid layout: at most 4 columns, as many rows as needed (ceil division)
    cols = min(num_params, 4)
    rows = (num_params + cols - 1) // cols

    # Figure sized to the 19 cm journal width; 0.8 flattens the subplots
    subplot_size = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
    )
    # Normalize axes to a flat sequence regardless of grid shape
    axes = np.array(axes).flatten() if num_params > 1 else [axes]

    for i in range(num_params):
        ax = axes[i]

        # Relative error vs true value (rasterized for small vector exports)
        ax.scatter(
            y_true[:, i],
            rel_error[:, i],
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
            label="Data",
        )

        # Mean relative error line
        mean_rel = np.mean(rel_error[:, i])
        ax.axhline(
            y=mean_rel,
            color=COLORS["accent"],
            linestyle="-",
            lw=1.2,
            label=f"Mean = {mean_rel:.2f}%",
        )

        # Labels
        title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
        ax.set_title(f"{title}")
        ax.set_xlabel("True Value")
        ax.set_ylabel("Relative Error (%)")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Hide unused subplots (grid cells beyond num_params)
    for i in range(num_params, len(axes)):
        axes[i].axis("off")

    plt.tight_layout()
    return fig
990
+
991
+
992
+ # ==============================================================================
993
+ # VISUALIZATION - CUMULATIVE ERROR DISTRIBUTION (CDF)
994
+ # ==============================================================================
995
def plot_error_cdf(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    use_relative: bool = True,
) -> plt.Figure:
    """
    Generate cumulative distribution function (CDF) of prediction errors.

    Shows what percentage of predictions fall within a given error bound.
    Very useful for reporting: "95% of predictions have error < X%"

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values of the same shape
        param_names: Optional list of parameter names for legend
        use_relative: If True, plot relative error (%), else absolute error

    Returns:
        Matplotlib Figure object
    """
    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    # Promote 1D inputs to a single-column 2D shape for uniform indexing
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    if param_names is None or len(param_names) != num_params:
        param_names = [f"P{i}" for i in range(num_params)]

    # Error metric selection.
    # NOTE(review): in the relative branch, samples with y_true == 0 are
    # mapped to 0% error via nan_to_num — confirm intended.
    if use_relative:
        with np.errstate(divide="ignore", invalid="ignore"):
            errors = np.abs((y_pred - y_true) / y_true) * 100
            errors = np.nan_to_num(errors, nan=0.0, posinf=0.0, neginf=0.0)
        xlabel = "Relative Error (%)"
    else:
        errors = np.abs(y_pred - y_true)
        xlabel = "Absolute Error"

    # Single shared figure: one CDF curve per parameter
    fig, ax = plt.subplots(figsize=(FIGURE_WIDTH_INCH * 0.6, FIGURE_WIDTH_INCH * 0.4))

    colors_list = [
        COLORS["primary"],
        COLORS["secondary"],
        COLORS["accent"],
        COLORS["success"],
        COLORS["neutral"],
    ]

    for i in range(num_params):
        # Empirical CDF: sorted errors vs cumulative fraction
        err = np.sort(errors[:, i])
        cdf = np.arange(1, len(err) + 1) / len(err)

        color = colors_list[i % len(colors_list)]
        ax.plot(err, cdf * 100, label=param_names[i], color=color, lw=1.5)

        # Mark the 95th-percentile error for this parameter
        p95_val = np.percentile(errors[:, i], 95)
        ax.axvline(x=p95_val, color=color, linestyle=":", alpha=0.5)

    # Horizontal 95% reference line
    ax.axhline(y=95, color="gray", linestyle="--", lw=0.8, alpha=0.7, label="95%")

    ax.set_xlabel(xlabel)
    ax.set_ylabel("Cumulative Percentage (%)")
    ax.set_title("Cumulative Error Distribution")
    ax.legend(fontsize=6, loc="best")
    ax.grid(True)
    ax.set_ylim(0, 105)

    plt.tight_layout()
    return fig
1070
+
1071
+
1072
+ # ==============================================================================
1073
+ # VISUALIZATION - PREDICTION VS SAMPLE INDEX
1074
+ # ==============================================================================
1075
def plot_prediction_vs_index(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 500,
) -> plt.Figure:
    """
    Plot true and predicted values against sample index, one panel per target.

    Gives a time-series-style view of the predictions, which makes isolated
    outliers and systematic drift easy to spot.

    Args:
        y_true: Ground truth values of shape (N, num_targets)
        y_pred: Predicted values of shape (N, num_targets)
        param_names: Optional list of parameter names for titles
        max_samples: Maximum samples to show

    Returns:
        Matplotlib Figure object
    """
    num_params = 1 if y_true.ndim == 1 else y_true.shape[1]

    # Promote 1-D inputs to a single-column 2-D layout
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Truncate to the first max_samples rows to keep the plot readable
    n_samples = min(len(y_true), max_samples)
    y_true, y_pred = y_true[:n_samples], y_pred[:n_samples]
    indices = np.arange(n_samples)

    # Grid layout: at most 4 columns, as many rows as needed (ceiling division)
    cols = min(num_params, 4)
    rows = -(-num_params // cols)

    # Figure size scales with the grid so each panel keeps its aspect
    subplot_size = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, subplot_size * rows * 0.8),
    )
    axes = np.array(axes).flatten() if num_params > 1 else [axes]

    marker_style = {"markersize": 3, "alpha": 0.6}

    for idx in range(num_params):
        ax = axes[idx]

        # Ground truth as circles, predictions as crosses
        ax.plot(
            indices,
            y_true[:, idx],
            "o",
            color=COLORS["primary"],
            label="True",
            **marker_style,
        )
        ax.plot(
            indices,
            y_pred[:, idx],
            "x",
            color=COLORS["error"],
            label="Predicted",
            **marker_style,
        )

        # Panel title: supplied name if available, generic fallback otherwise
        if param_names and idx < len(param_names):
            title = param_names[idx]
        else:
            title = f"Param {idx}"
        ax.set_title(f"{title}")
        ax.set_xlabel("Sample Index")
        ax.set_ylabel("Value")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Blank out any leftover grid cells
    for spare in axes[num_params:]:
        spare.axis("off")

    plt.tight_layout()
    return fig
1159
+
1160
+
1161
+ # ==============================================================================
1162
+ # VISUALIZATION - ERROR BOX PLOT
1163
+ # ==============================================================================
1164
def plot_error_boxplot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    use_relative: bool = False,
) -> plt.Figure:
    """
    Generate box plots comparing error distributions across parameters.

    Provides a compact view of error statistics (median, quartiles, outliers)
    for all parameters side-by-side.

    Args:
        y_true: Ground truth values of shape (N, num_targets)
        y_pred: Predicted values of shape (N, num_targets)
        param_names: Optional list of parameter names for x-axis
        use_relative: If True, plot relative error (%), else absolute error

    Returns:
        Matplotlib Figure object
    """
    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    # Handle 1D case
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    if param_names is None or len(param_names) != num_params:
        param_names = [f"P{i}" for i in range(num_params)]

    # Calculate errors
    if use_relative:
        with np.errstate(divide="ignore", invalid="ignore"):
            errors = np.abs((y_pred - y_true) / y_true) * 100
        # Division-by-zero artifacts (true value == 0) are mapped to 0
        errors = np.nan_to_num(errors, nan=0.0, posinf=0.0, neginf=0.0)
        ylabel = "Relative Error (%)"
    else:
        # Signed error so systematic over/under-prediction stays visible
        errors = y_pred - y_true
        ylabel = "Prediction Error"

    # Create figure (width grows with the parameter count, capped)
    fig_width = min(FIGURE_WIDTH_INCH * 0.5, 2 + num_params * 0.8)
    fig, ax = plt.subplots(figsize=(fig_width, FIGURE_WIDTH_INCH * 0.4))

    # Box plot. The ``labels`` kwarg of ``boxplot`` is deprecated since
    # Matplotlib 3.9 (renamed ``tick_labels``); setting the tick labels
    # explicitly below works on all supported Matplotlib versions.
    bp = ax.boxplot(
        [errors[:, i] for i in range(num_params)],
        patch_artist=True,
        showfliers=True,
        flierprops={"marker": "o", "markersize": 3, "alpha": 0.5},
    )
    # boxplot positions boxes at 1..N by default
    ax.set_xticks(np.arange(1, num_params + 1))
    ax.set_xticklabels(param_names)

    # Color the boxes
    for patch in bp["boxes"]:
        patch.set_facecolor(COLORS["primary"])
        patch.set_alpha(0.7)

    # Zero line for signed errors (perfect-prediction reference)
    if not use_relative:
        ax.axhline(y=0, color=COLORS["error"], linestyle="--", lw=1.0, alpha=0.7)

    ax.set_ylabel(ylabel)
    ax.set_title("Error Distribution by Parameter")
    ax.grid(True, axis="y")

    # Rotate labels if needed to avoid overlap
    if num_params > 4:
        ax.tick_params(axis="x", rotation=45)

    plt.tight_layout()
    return fig