wavedl 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wavedl/__init__.py +43 -0
- wavedl/hpo.py +366 -0
- wavedl/models/__init__.py +86 -0
- wavedl/models/_template.py +157 -0
- wavedl/models/base.py +173 -0
- wavedl/models/cnn.py +249 -0
- wavedl/models/convnext.py +425 -0
- wavedl/models/densenet.py +406 -0
- wavedl/models/efficientnet.py +236 -0
- wavedl/models/registry.py +104 -0
- wavedl/models/resnet.py +555 -0
- wavedl/models/unet.py +304 -0
- wavedl/models/vit.py +372 -0
- wavedl/test.py +1069 -0
- wavedl/train.py +1079 -0
- wavedl/utils/__init__.py +151 -0
- wavedl/utils/config.py +269 -0
- wavedl/utils/cross_validation.py +509 -0
- wavedl/utils/data.py +1220 -0
- wavedl/utils/distributed.py +138 -0
- wavedl/utils/losses.py +216 -0
- wavedl/utils/metrics.py +1236 -0
- wavedl/utils/optimizers.py +216 -0
- wavedl/utils/schedulers.py +251 -0
- wavedl-1.2.0.dist-info/LICENSE +21 -0
- wavedl-1.2.0.dist-info/METADATA +991 -0
- wavedl-1.2.0.dist-info/RECORD +30 -0
- wavedl-1.2.0.dist-info/WHEEL +5 -0
- wavedl-1.2.0.dist-info/entry_points.txt +4 -0
- wavedl-1.2.0.dist-info/top_level.txt +1 -0
wavedl/utils/metrics.py
ADDED
|
@@ -0,0 +1,1236 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Scientific Metrics and Visualization Utilities
|
|
3
|
+
===============================================
|
|
4
|
+
|
|
5
|
+
Provides metric tracking, statistical calculations, and publication-quality
|
|
6
|
+
visualization tools for deep learning experiments.
|
|
7
|
+
|
|
8
|
+
Author: Ductho Le (ductho.le@outlook.com)
|
|
9
|
+
Version: 1.1.0
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
import matplotlib.pyplot as plt
|
|
15
|
+
import numpy as np
|
|
16
|
+
from scipy.stats import pearsonr
|
|
17
|
+
from sklearn.metrics import r2_score
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# ==============================================================================
# PUBLICATION-QUALITY PLOT CONFIGURATION
# ==============================================================================
# Width: 19 cm = 7.48 inches (for two-column journals)
FIGURE_WIDTH_CM = 19
FIGURE_WIDTH_INCH = FIGURE_WIDTH_CM / 2.54  # matplotlib figsize is in inches

# Font sizes (consistent 8pt for publication)
FONT_SIZE_TEXT = 8
FONT_SIZE_TICKS = 8
FONT_SIZE_TITLE = 9

# DPI for publication (300 for print, 150 for screen)
FIGURE_DPI = 300

# Color palette (accessible, print-friendly); hex values used throughout the
# plotting helpers below via COLORS["..."] lookups.
COLORS = {
    "primary": "#2E86AB",  # Steel blue
    "secondary": "#A23B72",  # Raspberry
    "accent": "#F18F01",  # Orange
    "neutral": "#6B717E",  # Slate gray
    "error": "#C73E1D",  # Red
    "success": "#3A7D44",  # Green
}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def configure_matplotlib_style():
    """Configure matplotlib for publication-quality LaTeX-style plots."""
    # Typography: serif text with Computer Modern math to match LaTeX output
    font_settings = {
        "font.family": "serif",
        "font.serif": ["Times New Roman", "DejaVu Serif", "serif"],
        "mathtext.fontset": "cm",  # Computer Modern for math
        "font.size": FONT_SIZE_TEXT,
        "axes.titlesize": FONT_SIZE_TITLE,
        "axes.labelsize": FONT_SIZE_TEXT,
        "xtick.labelsize": FONT_SIZE_TICKS,
        "ytick.labelsize": FONT_SIZE_TICKS,
        "legend.fontsize": FONT_SIZE_TICKS,
    }
    # Line weights and subtle dotted grid
    line_settings = {
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.5,
        "lines.linewidth": 1.5,
        "grid.alpha": 0.4,
        "grid.linestyle": ":",
    }
    # Figure export settings and a cleaner spine layout (no top/right spines)
    figure_settings = {
        "figure.dpi": FIGURE_DPI,
        "savefig.dpi": FIGURE_DPI,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.05,
        "axes.spines.top": False,
        "axes.spines.right": False,
    }
    plt.rcParams.update({**font_settings, **line_settings, **figure_settings})
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Apply style on import.
# NOTE: importing this module mutates the process-global matplotlib rcParams
# as a side effect; call plt.rcdefaults() if the default style is needed later.
configure_matplotlib_style()
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# ==============================================================================
|
|
85
|
+
# METRIC TRACKING
|
|
86
|
+
# ==============================================================================
|
|
87
|
+
class MetricTracker:
    """
    Tracks the running (sample-weighted) average of a scalar metric.

    Useful for tracking loss, accuracy, or any scalar metric across batches.
    Handles division-by-zero safely by returning 0.0 when count is zero.

    NOTE(review): accumulation is plain attribute arithmetic with no locking,
    so concurrent updates from multiple threads are not synchronized — guard
    with an external lock if that is required.

    Attributes:
        val: Most recent value passed to update()
        avg: Running weighted average (sum / count)
        sum: Cumulative weighted sum of values
        count: Total number of samples accumulated

    Example:
        tracker = MetricTracker()
        for batch in dataloader:
            loss = compute_loss(batch)
            tracker.update(loss.item(), n=batch_size)
        print(f"Average loss: {tracker.avg}")
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics to initial state."""
        self.val: float = 0.0
        self.avg: float = 0.0
        self.sum: float = 0.0
        self.count: float = 0.0

    def update(self, val: float, n: int = 1):
        """
        Update tracker with new value(s).

        Args:
            val: New value (or mean of values if n > 1)
            n: Number of samples this value represents
        """
        self.val = val
        # Weight the (possibly batch-mean) value by its sample count
        self.sum += val * n
        self.count += n
        # Guard: avg stays 0.0 until at least one sample is recorded
        self.avg = self.sum / self.count if self.count > 0 else 0.0

    def __repr__(self) -> str:
        return (
            f"MetricTracker(val={self.val:.4f}, avg={self.avg:.4f}, count={self.count})"
        )
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# ==============================================================================
|
|
138
|
+
# STATISTICAL METRICS
|
|
139
|
+
# ==============================================================================
|
|
140
|
+
def get_lr(optimizer) -> float:
    """
    Extract the current learning rate from an optimizer.

    Args:
        optimizer: PyTorch optimizer instance

    Returns:
        Learning rate of the first parameter group, or 0.0 when the
        optimizer has no parameter groups.
    """
    groups = optimizer.param_groups
    return groups[0]["lr"] if groups else 0.0
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def calc_pearson(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Average Pearson correlation coefficient across all target columns.

    Near-constant columns (std < 1e-9) contribute 0.0 instead of NaN, so the
    result is always finite. This metric is important for physics-based
    signal regression papers.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values, same shape as y_true

    Returns:
        Mean Pearson correlation across all targets
    """
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    per_target = []
    for col in range(y_true.shape[1]):
        truth, pred = y_true[:, col], y_pred[:, col]
        # Degenerate (near-constant) columns would make pearsonr return NaN
        if np.std(truth) < 1e-9 or np.std(pred) < 1e-9:
            per_target.append(0.0)
            continue
        corr, _ = pearsonr(truth, pred)
        # Safety net: pearsonr should not yield NaN after the std check
        per_target.append(0.0 if np.isnan(corr) else corr)

    return float(np.mean(per_target))
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def calc_per_target_r2(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """
    Compute the R² score independently for every target column.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values, same shape as y_true

    Returns:
        Array of R² scores, one entry per target
    """

    def _safe_r2(truth, pred):
        # r2_score raises ValueError on degenerate input; report 0.0 instead
        try:
            return r2_score(truth, pred)
        except ValueError:
            return 0.0

    if y_true.ndim == 1:
        return np.array([r2_score(y_true, y_pred)])

    return np.array(
        [_safe_r2(y_true[:, col], y_pred[:, col]) for col in range(y_true.shape[1])]
    )
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
# ==============================================================================
|
|
215
|
+
# VISUALIZATION - SCATTER PLOTS
|
|
216
|
+
# ==============================================================================
|
|
217
|
+
def plot_scientific_scatter(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 2000,
) -> plt.Figure:
    """
    Create per-target prediction-vs-truth scatter panels with R² annotations.

    Each panel shows the data, an ideal y = x reference line, and the R²
    score for that target.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values, same shape as y_true
        param_names: Optional list of parameter names for panel titles
        max_samples: Random downsampling cap for plotting performance

    Returns:
        Matplotlib Figure object (can be saved or logged to WandB)
    """
    num_targets = 1 if y_true.ndim == 1 else y_true.shape[1]

    # Promote 1D inputs to a single-column 2D layout
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Randomly thin dense datasets so the figure stays responsive
    if len(y_true) > max_samples:
        keep = np.random.choice(len(y_true), max_samples, replace=False)
        y_true, y_pred = y_true[keep], y_pred[keep]

    # Grid layout: up to 4 columns, as many rows as needed (19 cm width)
    n_cols = min(num_targets, 4)
    n_rows = -(-num_targets // n_cols)  # ceiling division

    panel = FIGURE_WIDTH_INCH / n_cols
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(FIGURE_WIDTH_INCH, panel * n_rows),
    )
    axes = [axes] if num_targets == 1 else np.array(axes).flatten()

    for idx in range(num_targets):
        ax = axes[idx]
        truth, pred = y_true[:, idx], y_pred[:, idx]

        # R² is undefined for fewer than two samples
        r2 = r2_score(truth, pred) if len(y_true) >= 2 else float("nan")

        ax.scatter(
            truth,
            pred,
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
            label="Data",
        )

        # Ideal y = x reference, padded 5% beyond the data range
        lo = min(truth.min(), pred.min())
        hi = max(truth.max(), pred.max())
        pad = (hi - lo) * 0.05
        ax.plot(
            [lo - pad, hi + pad],
            [lo - pad, hi + pad],
            "--",
            color=COLORS["error"],
            lw=1.2,
            alpha=0.8,
            label="Ideal",
        )

        name = param_names[idx] if param_names and idx < len(param_names) else f"Param {idx}"
        ax.set_title(f"{name}\n$R^2 = {r2:.4f}$")
        ax.set_xlabel("Ground Truth")
        ax.set_ylabel("Prediction")
        ax.grid(True)
        ax.set_aspect("equal", adjustable="box")
        ax.set_xlim(lo - pad, hi + pad)
        ax.set_ylim(lo - pad, hi + pad)
        ax.legend(fontsize=6, loc="best")

    # Blank out any leftover grid cells
    for idx in range(num_targets, len(axes)):
        axes[idx].axis("off")

    plt.tight_layout()
    return fig
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
# ==============================================================================
|
|
319
|
+
# VISUALIZATION - ERROR HISTOGRAM
|
|
320
|
+
# ==============================================================================
|
|
321
|
+
def plot_error_histogram(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    bins: int = 50,
) -> plt.Figure:
    """
    Plot per-target histograms of the prediction error (y_pred - y_true).

    Each panel is annotated with the error mean, standard deviation and MAE,
    plus reference lines at zero and at the observed mean error.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values, same shape as y_true
        param_names: Optional list of parameter names for panel titles
        bins: Number of histogram bins

    Returns:
        Matplotlib Figure object
    """
    num_targets = 1 if y_true.ndim == 1 else y_true.shape[1]

    # Promote 1D inputs to a single-column 2D layout
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    deltas = y_pred - y_true

    # Grid layout: up to 4 columns
    n_cols = min(num_targets, 4)
    n_rows = -(-num_targets // n_cols)  # ceiling division

    panel = FIGURE_WIDTH_INCH / n_cols
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(FIGURE_WIDTH_INCH, panel * n_rows * 0.8),
    )
    axes = [axes] if num_targets == 1 else np.array(axes).flatten()

    for idx in range(num_targets):
        ax = axes[idx]
        err = deltas[:, idx]

        # Per-target summary statistics
        mean_err = np.mean(err)
        std_err = np.std(err)
        mae = np.mean(np.abs(err))

        ax.hist(
            err,
            bins=bins,
            color=COLORS["primary"],
            alpha=0.7,
            edgecolor="white",
            linewidth=0.5,
            label="Errors",
        )

        # Reference line at zero error
        ax.axvline(
            x=0, color=COLORS["error"], linestyle="--", lw=1.2, alpha=0.8, label="Zero"
        )

        # Observed mean error
        ax.axvline(
            x=mean_err,
            color=COLORS["accent"],
            linestyle="-",
            lw=1.2,
            label=f"Mean = {mean_err:.4f}",
        )

        name = param_names[idx] if param_names and idx < len(param_names) else f"Param {idx}"
        ax.set_title(f"{name}\nMAE = {mae:.4f}, $\\sigma$ = {std_err:.4f}")
        ax.set_xlabel("Prediction Error")
        ax.set_ylabel("Count")
        ax.grid(True, axis="y")
        ax.legend(fontsize=6, loc="best")

    # Blank out any leftover grid cells
    for idx in range(num_targets, len(axes)):
        axes[idx].axis("off")

    plt.tight_layout()
    return fig
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
# ==============================================================================
|
|
416
|
+
# VISUALIZATION - RESIDUAL PLOT
|
|
417
|
+
# ==============================================================================
|
|
418
|
+
def plot_residuals(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 2000,
) -> plt.Figure:
    """
    Plot residuals (prediction minus truth) against predicted values.

    Useful for detecting systematic bias or heteroscedasticity per target.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values, same shape as y_true
        param_names: Optional list of parameter names for panel titles
        max_samples: Random downsampling cap for plotting performance

    Returns:
        Matplotlib Figure object
    """
    num_targets = 1 if y_true.ndim == 1 else y_true.shape[1]

    # Promote 1D inputs to a single-column 2D layout
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    deltas = y_pred - y_true

    # Thin out large datasets; predictions and residuals share one draw so
    # the plotted pairs stay aligned
    if len(y_true) > max_samples:
        keep = np.random.choice(len(y_true), max_samples, replace=False)
        y_pred = y_pred[keep]
        deltas = deltas[keep]

    # Grid layout: up to 4 columns
    n_cols = min(num_targets, 4)
    n_rows = -(-num_targets // n_cols)  # ceiling division

    panel = FIGURE_WIDTH_INCH / n_cols
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(FIGURE_WIDTH_INCH, panel * n_rows * 0.8),
    )
    axes = [axes] if num_targets == 1 else np.array(axes).flatten()

    for idx in range(num_targets):
        ax = axes[idx]

        ax.scatter(
            y_pred[:, idx],
            deltas[:, idx],
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
            label="Data",
        )

        # Zero-residual reference line
        ax.axhline(
            y=0, color=COLORS["error"], linestyle="--", lw=1.2, alpha=0.8, label="Zero"
        )

        # Mean residual (computed on the plotted subset)
        bias = np.mean(deltas[:, idx])
        ax.axhline(
            y=bias,
            color=COLORS["accent"],
            linestyle="-",
            lw=1.0,
            alpha=0.7,
            label=f"Mean = {bias:.4f}",
        )

        name = param_names[idx] if param_names and idx < len(param_names) else f"Param {idx}"
        ax.set_title(f"{name}")
        ax.set_xlabel("Predicted Value")
        ax.set_ylabel("Residual (Pred - True)")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Blank out any leftover grid cells
    for idx in range(num_targets, len(axes)):
        axes[idx].axis("off")

    plt.tight_layout()
    return fig
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
# ==============================================================================
|
|
516
|
+
# VISUALIZATION - TRAINING CURVES
|
|
517
|
+
# ==============================================================================
|
|
518
|
+
def create_training_curves(
    history: list[dict[str, Any]],
    metrics: list[str] | None = None,
    show_lr: bool = True,
) -> plt.Figure:
    """
    Create training curve visualization from history with optional learning rate.

    Plots the requested metrics over epochs on a log-scaled primary axis. If
    learning-rate data is available in history (and show_lr is True), it is
    plotted on a secondary log-scaled y-axis.

    Args:
        history: List of epoch statistics dictionaries. Each dict should
            contain 'epoch', the metric keys, and optionally 'lr'.
        metrics: Metric names to plot on the primary y-axis. Defaults to
            ["train_loss", "val_loss"]. A None sentinel replaces the previous
            mutable-default list so calls can never share state.
        show_lr: If True and 'lr' is in history, show learning rate on a
            secondary axis.

    Returns:
        Matplotlib Figure object
    """
    # Fix: avoid a mutable default argument (shared list across calls)
    if metrics is None:
        metrics = ["train_loss", "val_loss"]

    # Fall back to the 1-based position when an entry lacks an 'epoch' key
    epochs = [h.get("epoch", i + 1) for i, h in enumerate(history)]

    fig, ax1 = plt.subplots(figsize=(FIGURE_WIDTH_INCH * 0.7, FIGURE_WIDTH_INCH * 0.4))

    colors = [
        COLORS["primary"],
        COLORS["secondary"],
        COLORS["accent"],
        COLORS["neutral"],
    ]

    # Plot metrics on the primary axis; missing values render as gaps (NaN)
    lines = []
    for idx, metric in enumerate(metrics):
        values = [h.get(metric, np.nan) for h in history]
        color = colors[idx % len(colors)]  # cycle palette for > 4 metrics
        (line,) = ax1.plot(
            epochs,
            values,
            label=metric.replace("_", " ").title(),
            linewidth=1.5,
            color=color,
        )
        lines.append(line)

    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.set_yscale("log")  # Log scale for loss
    ax1.grid(True, alpha=0.3)

    # Check if learning rate data exists
    has_lr = show_lr and any("lr" in h for h in history)

    if has_lr:
        # Secondary y-axis for the learning-rate schedule
        ax2 = ax1.twinx()
        lr_values = [h.get("lr", np.nan) for h in history]
        (line_lr,) = ax2.plot(
            epochs,
            lr_values,
            "--",
            color=COLORS["neutral"],
            linewidth=1.0,
            alpha=0.7,
            label="Learning Rate",
        )
        ax2.set_ylabel("Learning Rate", color=COLORS["neutral"])
        ax2.tick_params(axis="y", labelcolor=COLORS["neutral"])
        ax2.set_yscale("log")  # Log scale for LR
        lines.append(line_lr)

    # Single combined legend covering both axes
    labels = [ln.get_label() for ln in lines]
    ax1.legend(lines, labels, loc="best", fontsize=6)

    ax1.set_title("Training Curves")

    plt.tight_layout()
    return fig
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
# ==============================================================================
|
|
600
|
+
# VISUALIZATION - BLAND-ALTMAN PLOT
|
|
601
|
+
# ==============================================================================
|
|
602
|
+
def plot_bland_altman(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 2000,
) -> plt.Figure:
    """
    Generate Bland-Altman plots for method comparison.

    Plots the difference (pred - true) against the pairwise mean, with a
    mean-difference line and ±1.96·SD limits of agreement. Standard for
    validating agreement in medical/scientific papers.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values, same shape as y_true
        param_names: Optional list of parameter names for panel titles
        max_samples: Random downsampling cap for plotting performance

    Returns:
        Matplotlib Figure object
    """
    num_targets = 1 if y_true.ndim == 1 else y_true.shape[1]

    # Promote 1D inputs to a single-column 2D layout
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Pairwise means and differences for each sample
    midpoints = (y_true + y_pred) / 2
    deltas = y_pred - y_true

    # Thin out large datasets; both arrays share one draw so pairs align
    if len(y_true) > max_samples:
        keep = np.random.choice(len(y_true), max_samples, replace=False)
        midpoints = midpoints[keep]
        deltas = deltas[keep]

    # Grid layout: up to 4 columns
    n_cols = min(num_targets, 4)
    n_rows = -(-num_targets // n_cols)  # ceiling division

    panel = FIGURE_WIDTH_INCH / n_cols
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(FIGURE_WIDTH_INCH, panel * n_rows * 0.8),
    )
    axes = [axes] if num_targets == 1 else np.array(axes).flatten()

    for idx in range(num_targets):
        ax = axes[idx]

        bias = np.mean(deltas[:, idx])
        spread = np.std(deltas[:, idx])

        # Limits of agreement (95% CI = mean ± 1.96·SD)
        upper_loa = bias + 1.96 * spread
        lower_loa = bias - 1.96 * spread

        ax.scatter(
            midpoints[:, idx],
            deltas[:, idx],
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
        )

        # Mean difference line
        ax.axhline(
            y=bias,
            color=COLORS["accent"],
            linestyle="-",
            lw=1.2,
            label=f"Mean = {bias:.3f}",
        )

        # Limits of agreement
        ax.axhline(
            y=upper_loa,
            color=COLORS["error"],
            linestyle="--",
            lw=1.0,
            label=f"+1.96 SD = {upper_loa:.3f}",
        )
        ax.axhline(
            y=lower_loa,
            color=COLORS["error"],
            linestyle="--",
            lw=1.0,
            label=f"-1.96 SD = {lower_loa:.3f}",
        )

        # Faint zero reference
        ax.axhline(y=0, color="gray", linestyle=":", lw=0.8, alpha=0.5)

        name = param_names[idx] if param_names and idx < len(param_names) else f"Param {idx}"
        ax.set_title(f"{name}")
        ax.set_xlabel("Mean of True and Predicted")
        ax.set_ylabel("Difference (Pred - True)")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Blank out any leftover grid cells
    for idx in range(num_targets, len(axes)):
        axes[idx].axis("off")

    plt.tight_layout()
    return fig
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
# ==============================================================================
|
|
720
|
+
# VISUALIZATION - QQ PLOT
|
|
721
|
+
# ==============================================================================
|
|
722
|
+
def plot_qq(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
) -> plt.Figure:
    """
    Generate Q-Q plots to check if prediction errors are normally distributed.

    Compares the quantiles of the error distribution to a theoretical normal
    distribution. Points on the diagonal indicate normally distributed errors.

    Args:
        y_true: Ground truth values of shape (N, num_targets) or (N,)
        y_pred: Predicted values, same shape as y_true
        param_names: Optional list of parameter names for titles

    Returns:
        Matplotlib Figure object
    """
    from scipy import stats

    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    # Handle 1D case
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Calculate errors
    errors = y_pred - y_true

    # Grid layout: up to 4 columns
    cols = min(num_params, 4)
    rows = (num_params + cols - 1) // cols

    subplot_size = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, subplot_size * rows),
    )
    axes = np.array(axes).flatten() if num_params > 1 else [axes]

    for i in range(num_params):
        ax = axes[i]

        # Standardize errors for the QQ plot.
        # Fix: guard against a zero standard deviation (constant error
        # vector) — the previous code divided by np.std(err) unconditionally,
        # which filled the plot with NaN/inf for degenerate inputs. With the
        # guard, the centered errors are plotted unscaled instead.
        err = errors[:, i]
        scale = np.std(err)
        if scale < 1e-12:
            scale = 1.0
        standardized = (err - np.mean(err)) / scale

        # Theoretical vs sample quantiles plus least-squares reference fit
        (osm, osr), (slope, intercept, r) = stats.probplot(standardized, dist="norm")

        ax.scatter(
            osm,
            osr,
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
            label="Data",
        )

        # Reference line for a perfectly normal distribution
        line_x = np.array([osm.min(), osm.max()])
        line_y = slope * line_x + intercept
        ax.plot(line_x, line_y, "--", color=COLORS["error"], lw=1.2, label="Normal")

        # Labels
        title = param_names[i] if param_names and i < len(param_names) else f"Param {i}"
        ax.set_title(f"{title}\n$R^2 = {r**2:.4f}$")
        ax.set_xlabel("Theoretical Quantiles")
        ax.set_ylabel("Sample Quantiles")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Hide unused subplots
    for i in range(num_params, len(axes)):
        axes[i].axis("off")

    plt.tight_layout()
    return fig
|
|
807
|
+
|
|
808
|
+
|
|
809
|
+
# ==============================================================================
# VISUALIZATION - CORRELATION HEATMAP
# ==============================================================================
def plot_correlation_heatmap(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
) -> plt.Figure:
    """
    Generate correlation heatmap between prediction *errors* of parameters.

    Shows the Pearson correlation between the prediction errors
    (``y_pred - y_true``) of different output parameters, useful for
    understanding multi-output prediction relationships (e.g. whether two
    targets tend to be over/under-predicted together).

    Args:
        y_true: Ground truth values of shape (N, num_targets)
        y_pred: Predicted values of shape (N, num_targets)
        param_names: Optional list of parameter names for labels

    Returns:
        Matplotlib Figure object (a placeholder figure with a message when
        fewer than 2 parameters are available)
    """
    num_params = y_true.shape[1] if y_true.ndim > 1 else 1

    if num_params < 2:
        # Correlation between parameters is undefined for a single target;
        # return an explanatory placeholder instead of failing.
        fig, ax = plt.subplots(figsize=(4, 3))
        ax.text(
            0.5,
            0.5,
            "Correlation heatmap requires\nat least 2 parameters",
            ha="center",
            va="center",
            fontsize=10,
        )
        ax.axis("off")
        return fig

    # NOTE: any 1D input implies num_params == 1 and was handled by the early
    # return above, so no 1D->2D reshape is needed past this point (the
    # original code carried an unreachable reshape branch here).

    if param_names is None or len(param_names) != num_params:
        param_names = [f"P{i}" for i in range(num_params)]

    # Correlate the signed prediction errors (not the raw values) across
    # parameters; np.corrcoef expects variables in rows, hence the transpose.
    errors = y_pred - y_true
    corr_matrix = np.corrcoef(errors.T)

    # Scale the figure with the number of parameters, capped at 60% of the
    # standard publication figure width.
    fig_size = min(FIGURE_WIDTH_INCH * 0.6, 2 + num_params * 0.6)
    fig, ax = plt.subplots(figsize=(fig_size, fig_size))

    # Diverging colormap centered at zero correlation, fixed [-1, 1] range so
    # colors are comparable across figures.
    im = ax.imshow(corr_matrix, cmap="RdBu_r", vmin=-1, vmax=1, aspect="equal")

    # Colorbar
    cbar = plt.colorbar(im, ax=ax, shrink=0.8)
    cbar.set_label("Correlation", fontsize=FONT_SIZE_TICKS)

    # Axis tick labels: one per parameter.
    ax.set_xticks(range(num_params))
    ax.set_yticks(range(num_params))
    ax.set_xticklabels(param_names, rotation=45, ha="right")
    ax.set_yticklabels(param_names)

    # Annotate each cell with its value; switch to white text on strongly
    # colored (|r| > 0.5) cells for readability.
    for i in range(num_params):
        for j in range(num_params):
            color = "white" if abs(corr_matrix[i, j]) > 0.5 else "black"
            ax.text(
                j,
                i,
                f"{corr_matrix[i, j]:.2f}",
                ha="center",
                va="center",
                color=color,
                fontsize=FONT_SIZE_TICKS,
            )

    ax.set_title("Error Correlation Matrix")

    plt.tight_layout()
    return fig
895
|
+
|
|
896
|
+
# ==============================================================================
# VISUALIZATION - RELATIVE ERROR PLOT
# ==============================================================================
def plot_relative_error(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 2000,
) -> plt.Figure:
    """
    Generate relative error plots (percentage error vs true value).

    Useful for detecting scale-dependent bias where errors increase
    proportionally with the magnitude of the true value.

    Args:
        y_true: Ground truth values of shape (N, num_targets)
        y_pred: Predicted values of shape (N, num_targets)
        param_names: Optional list of parameter names for titles
        max_samples: Maximum samples to plot

    Returns:
        Matplotlib Figure object
    """
    num_params = 1 if y_true.ndim == 1 else y_true.shape[1]

    # Promote 1D inputs to a single-column 2D layout so the plotting loop
    # below can index columns uniformly.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Randomly subsample large datasets to keep the scatter readable.
    if len(y_true) > max_samples:
        keep = np.random.choice(len(y_true), max_samples, replace=False)
        y_true = y_true[keep]
        y_pred = y_pred[keep]

    # Percentage error; NaN/inf from zero-valued truths are zeroed out.
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_error = np.abs((y_pred - y_true) / y_true) * 100
        rel_error = np.nan_to_num(rel_error, nan=0.0, posinf=0.0, neginf=0.0)

    # Grid layout: at most 4 panels per row.
    cols = min(num_params, 4)
    rows = (num_params + cols - 1) // cols

    panel = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, panel * rows * 0.8),
    )
    axes = [axes] if num_params == 1 else np.array(axes).flatten()

    for idx in range(num_params):
        ax = axes[idx]

        # Per-sample relative error against the true value.
        ax.scatter(
            y_true[:, idx],
            rel_error[:, idx],
            alpha=0.5,
            s=15,
            c=COLORS["primary"],
            edgecolors="none",
            rasterized=True,
            label="Data",
        )

        # Horizontal line marking the mean relative error for this parameter.
        mean_rel = np.mean(rel_error[:, idx])
        ax.axhline(
            y=mean_rel,
            color=COLORS["accent"],
            linestyle="-",
            lw=1.2,
            label=f"Mean = {mean_rel:.2f}%",
        )

        # Panel labels.
        name = param_names[idx] if param_names and idx < len(param_names) else f"Param {idx}"
        ax.set_title(f"{name}")
        ax.set_xlabel("True Value")
        ax.set_ylabel("Relative Error (%)")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Blank out any leftover grid cells.
    for spare in axes[num_params:]:
        spare.axis("off")

    plt.tight_layout()
    return fig
+
|
|
992
|
+
# ==============================================================================
# VISUALIZATION - CUMULATIVE ERROR DISTRIBUTION (CDF)
# ==============================================================================
def plot_error_cdf(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    use_relative: bool = True,
) -> plt.Figure:
    """
    Generate cumulative distribution function (CDF) of prediction errors.

    Shows what percentage of predictions fall within a given error bound.
    Very useful for reporting: "95% of predictions have error < X%"

    Args:
        y_true: Ground truth values of shape (N, num_targets)
        y_pred: Predicted values of shape (N, num_targets)
        param_names: Optional list of parameter names for legend
        use_relative: If True, plot relative error (%), else absolute error

    Returns:
        Matplotlib Figure object
    """
    num_params = 1 if y_true.ndim == 1 else y_true.shape[1]

    # Promote 1D inputs to single-column 2D arrays.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    if param_names is None or len(param_names) != num_params:
        param_names = [f"P{i}" for i in range(num_params)]

    # Error metric selection; zero-valued truths yield NaN/inf under the
    # relative form and are zeroed out.
    if use_relative:
        with np.errstate(divide="ignore", invalid="ignore"):
            errors = np.abs((y_pred - y_true) / y_true) * 100
            errors = np.nan_to_num(errors, nan=0.0, posinf=0.0, neginf=0.0)
        xlabel = "Relative Error (%)"
    else:
        errors = np.abs(y_pred - y_true)
        xlabel = "Absolute Error"

    # Single shared axis for all parameters.
    fig, ax = plt.subplots(figsize=(FIGURE_WIDTH_INCH * 0.6, FIGURE_WIDTH_INCH * 0.4))

    palette = [
        COLORS["primary"],
        COLORS["secondary"],
        COLORS["accent"],
        COLORS["success"],
        COLORS["neutral"],
    ]

    for i in range(num_params):
        # Empirical CDF: sorted errors vs cumulative fraction (as percent).
        sorted_err = np.sort(errors[:, i])
        cum_pct = np.arange(1, sorted_err.size + 1) / sorted_err.size * 100

        color = palette[i % len(palette)]
        ax.plot(sorted_err, cum_pct, label=param_names[i], color=color, lw=1.5)

        # Dotted marker at the 95th-percentile error for this parameter
        # (np.percentile used for accuracy).
        p95_val = np.percentile(errors[:, i], 95)
        ax.axvline(x=p95_val, color=color, linestyle=":", alpha=0.5)

    # Horizontal guide at the 95% coverage level.
    ax.axhline(y=95, color="gray", linestyle="--", lw=0.8, alpha=0.7, label="95%")

    ax.set_xlabel(xlabel)
    ax.set_ylabel("Cumulative Percentage (%)")
    ax.set_title("Cumulative Error Distribution")
    ax.legend(fontsize=6, loc="best")
    ax.grid(True)
    ax.set_ylim(0, 105)

    plt.tight_layout()
    return fig
|
|
1072
|
+
# ==============================================================================
# VISUALIZATION - PREDICTION VS SAMPLE INDEX
# ==============================================================================
def plot_prediction_vs_index(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    max_samples: int = 500,
) -> plt.Figure:
    """
    Generate prediction vs sample index plots.

    Shows true and predicted values for each sample in sequence.
    Useful for time-series style visualization and spotting outliers.

    Args:
        y_true: Ground truth values of shape (N, num_targets)
        y_pred: Predicted values of shape (N, num_targets)
        param_names: Optional list of parameter names for titles
        max_samples: Maximum samples to show

    Returns:
        Matplotlib Figure object
    """
    num_params = 1 if y_true.ndim == 1 else y_true.shape[1]

    # Promote 1D inputs to single-column 2D arrays.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    # Truncate (not subsample) so the index axis stays sequential.
    n_show = min(len(y_true), max_samples)
    y_true = y_true[:n_show]
    y_pred = y_pred[:n_show]
    sample_idx = np.arange(n_show)

    # Grid layout: at most 4 panels per row.
    cols = min(num_params, 4)
    rows = (num_params + cols - 1) // cols

    panel = FIGURE_WIDTH_INCH / cols
    fig, axes = plt.subplots(
        rows,
        cols,
        figsize=(FIGURE_WIDTH_INCH, panel * rows * 0.8),
    )
    axes = [axes] if num_params == 1 else np.array(axes).flatten()

    for idx in range(num_params):
        ax = axes[idx]

        # True values as circles, predictions as crosses, overlaid.
        ax.plot(
            sample_idx,
            y_true[:, idx],
            "o",
            markersize=3,
            alpha=0.6,
            color=COLORS["primary"],
            label="True",
        )
        ax.plot(
            sample_idx,
            y_pred[:, idx],
            "x",
            markersize=3,
            alpha=0.6,
            color=COLORS["error"],
            label="Predicted",
        )

        # Panel labels.
        name = param_names[idx] if param_names and idx < len(param_names) else f"Param {idx}"
        ax.set_title(f"{name}")
        ax.set_xlabel("Sample Index")
        ax.set_ylabel("Value")
        ax.grid(True)
        ax.legend(fontsize=6, loc="best")

    # Blank out any leftover grid cells.
    for spare in axes[num_params:]:
        spare.axis("off")

    plt.tight_layout()
    return fig
|
|
1161
|
+
# ==============================================================================
# VISUALIZATION - ERROR BOX PLOT
# ==============================================================================
def plot_error_boxplot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    param_names: list[str] | None = None,
    use_relative: bool = False,
) -> plt.Figure:
    """
    Generate box plots comparing error distributions across parameters.

    Provides a compact view of error statistics (median, quartiles, outliers)
    for all parameters side-by-side.

    Args:
        y_true: Ground truth values of shape (N, num_targets)
        y_pred: Predicted values of shape (N, num_targets)
        param_names: Optional list of parameter names for x-axis
        use_relative: If True, plot relative error (%), else absolute error

    Returns:
        Matplotlib Figure object
    """
    num_params = 1 if y_true.ndim == 1 else y_true.shape[1]

    # Promote 1D inputs to single-column 2D arrays.
    if y_true.ndim == 1:
        y_true = y_true.reshape(-1, 1)
        y_pred = y_pred.reshape(-1, 1)

    if param_names is None or len(param_names) != num_params:
        param_names = [f"P{i}" for i in range(num_params)]

    # Error metric: unsigned percentage when relative, otherwise the signed
    # error so the box plot exposes over/under-prediction bias.
    if use_relative:
        with np.errstate(divide="ignore", invalid="ignore"):
            errors = np.abs((y_pred - y_true) / y_true) * 100
            errors = np.nan_to_num(errors, nan=0.0, posinf=0.0, neginf=0.0)
        ylabel = "Relative Error (%)"
    else:
        errors = y_pred - y_true
        ylabel = "Prediction Error"

    # Width grows with parameter count, capped at half the standard width.
    fig_width = min(FIGURE_WIDTH_INCH * 0.5, 2 + num_params * 0.8)
    fig, ax = plt.subplots(figsize=(fig_width, FIGURE_WIDTH_INCH * 0.4))

    # One box per parameter column.
    per_param = [errors[:, i] for i in range(num_params)]
    bp = ax.boxplot(
        per_param,
        labels=param_names,
        patch_artist=True,
        showfliers=True,
        flierprops={"marker": "o", "markersize": 3, "alpha": 0.5},
    )

    # Fill the boxes with the primary color.
    for box in bp["boxes"]:
        box.set_facecolor(COLORS["primary"])
        box.set_alpha(0.7)

    # Zero reference line only makes sense for signed errors.
    if not use_relative:
        ax.axhline(y=0, color=COLORS["error"], linestyle="--", lw=1.0, alpha=0.7)

    ax.set_ylabel(ylabel)
    ax.set_title("Error Distribution by Parameter")
    ax.grid(True, axis="y")

    # Rotate tick labels when they would otherwise collide.
    if num_params > 4:
        ax.tick_params(axis="x", rotation=45)

    plt.tight_layout()
    return fig