alberta-framework 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alberta_framework/__init__.py +196 -0
- alberta_framework/core/__init__.py +27 -0
- alberta_framework/core/learners.py +530 -0
- alberta_framework/core/normalizers.py +192 -0
- alberta_framework/core/optimizers.py +422 -0
- alberta_framework/core/types.py +198 -0
- alberta_framework/py.typed +0 -0
- alberta_framework/streams/__init__.py +83 -0
- alberta_framework/streams/base.py +70 -0
- alberta_framework/streams/gymnasium.py +655 -0
- alberta_framework/streams/synthetic.py +995 -0
- alberta_framework/utils/__init__.py +113 -0
- alberta_framework/utils/experiments.py +334 -0
- alberta_framework/utils/export.py +509 -0
- alberta_framework/utils/metrics.py +112 -0
- alberta_framework/utils/statistics.py +527 -0
- alberta_framework/utils/timing.py +138 -0
- alberta_framework/utils/visualization.py +571 -0
- alberta_framework-0.1.0.dist-info/METADATA +198 -0
- alberta_framework-0.1.0.dist-info/RECORD +22 -0
- alberta_framework-0.1.0.dist-info/WHEEL +4 -0
- alberta_framework-0.1.0.dist-info/licenses/LICENSE +190 -0
|
@@ -0,0 +1,571 @@
|
|
|
1
|
+
"""Publication-quality visualization utilities.
|
|
2
|
+
|
|
3
|
+
Provides functions for creating figures suitable for academic papers,
|
|
4
|
+
including learning curves, bar plots, heatmaps, and multi-panel figures.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from matplotlib.axes import Axes
|
|
14
|
+
from matplotlib.figure import Figure
|
|
15
|
+
|
|
16
|
+
from alberta_framework.utils.experiments import AggregatedResults
|
|
17
|
+
from alberta_framework.utils.statistics import SignificanceResult
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Default publication style settings.
# ``_current_style`` below starts as a copy of these values. ``set_publication_style``
# overwrites font_size, figure_width, figure_height and use_latex in that copy;
# line_width, marker_size and dpi are only read from it within this module.
_DEFAULT_STYLE = {
    "font_size": 10,  # base font size; tick/legend labels are derived from it
    "figure_width": 3.5,  # Single column width in inches
    "figure_height": 2.8,  # equals 0.8 * figure_width, the auto-height rule used below
    "line_width": 1.5,  # passed as ``linewidth`` to every plotted curve
    "marker_size": 4,
    "dpi": 300,  # used for both on-screen figure.dpi and savefig.dpi
    "use_latex": False,
}

# Mutable copy read by the plotting functions in this module (line widths, font
# sizes); mutated in place by ``set_publication_style``.
_current_style = _DEFAULT_STYLE.copy()
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def set_publication_style(
    font_size: int = 10,
    use_latex: bool = False,
    figure_width: float = 3.5,
    figure_height: float | None = None,
    style: str = "seaborn-v0_8-whitegrid",
) -> None:
    """Set matplotlib style for publication-quality figures.

    Updates the module-level ``_current_style`` (read by the plotting helpers
    in this module) and the global ``matplotlib.rcParams``.

    LaTeX rendering follows ``use_latex`` on every call. A call with
    ``use_latex=False`` after an earlier ``use_latex=True`` call turns
    ``text.usetex`` back off and restores matplotlib's default font family.

    Args:
        font_size: Base font size; axis titles use ``font_size + 1`` and
            tick/legend labels use ``font_size - 1``.
        use_latex: Whether to use LaTeX for text rendering
        figure_width: Default figure width in inches
        figure_height: Default figure height in inches
            (``0.8 * figure_width`` if None)
        style: Matplotlib style to use; ignored if it is not available

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required. Install with: pip install matplotlib"
        ) from err

    # Remember whether a previous call switched LaTeX on, so it can be undone.
    latex_was_enabled = bool(_current_style["use_latex"])

    # Update current style
    _current_style["font_size"] = font_size
    _current_style["figure_width"] = figure_width
    _current_style["use_latex"] = use_latex
    if figure_height is not None:
        _current_style["figure_height"] = figure_height
    else:
        _current_style["figure_height"] = figure_width * 0.8

    if latex_was_enabled and not use_latex:
        # Undo the serif/Computer Modern overrides from the earlier LaTeX call.
        # Done before applying ``style`` so a style that sets fonts still wins.
        plt.rcParams.update(
            {
                "font.family": matplotlib.rcParamsDefault["font.family"],
                "font.serif": matplotlib.rcParamsDefault["font.serif"],
            }
        )

    # Try to use the requested style, fall back to current settings if not available
    try:
        plt.style.use(style)
    except OSError:
        # Style not available, keep current settings
        pass

    # Configure matplotlib
    plt.rcParams.update(
        {
            "font.size": font_size,
            "axes.labelsize": font_size,
            "axes.titlesize": font_size + 1,
            "xtick.labelsize": font_size - 1,
            "ytick.labelsize": font_size - 1,
            "legend.fontsize": font_size - 1,
            "figure.figsize": (
                _current_style["figure_width"],
                _current_style["figure_height"],
            ),
            "figure.dpi": _current_style["dpi"],
            "savefig.dpi": _current_style["dpi"],
            "lines.linewidth": _current_style["line_width"],
            "lines.markersize": _current_style["marker_size"],
            "axes.linewidth": 0.8,
            "grid.linewidth": 0.5,
            "grid.alpha": 0.3,
            # Set in both directions: applying a style only touches the keys the
            # style defines, so an earlier ``True`` would otherwise persist.
            "text.usetex": use_latex,
        }
    )

    if use_latex:
        plt.rcParams.update(
            {
                "font.family": "serif",
                "font.serif": ["Computer Modern Roman"],
            }
        )
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def plot_learning_curves(
    results: dict[str, "AggregatedResults"],
    metric: str = "squared_error",
    show_ci: bool = True,
    log_scale: bool = True,
    window_size: int = 100,
    ax: "Axes | None" = None,
    colors: dict[str, str] | None = None,
    labels: dict[str, str] | None = None,
) -> tuple["Figure", "Axes"]:
    """Draw one smoothed learning curve per configuration.

    Every seed's trace of ``metric`` is first smoothed with a running mean of
    ``window_size``. The across-seed mean and confidence band are then computed
    from those smoothed traces.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        metric: Key into each result's ``metric_arrays`` to plot
        show_ci: Whether to shade the confidence interval around each mean
        log_scale: Whether to use a logarithmic y-axis
        window_size: Window size for the per-seed running-mean smoothing
        ax: Existing axes to draw on; a new figure is created when None
        colors: Optional per-config colour overrides (default: property cycle)
        labels: Optional per-config legend labels (default: config name)

    Returns:
        Tuple of (figure, axes); ``tight_layout`` has been applied to the figure.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib is required. Install with: pip install matplotlib")

    from alberta_framework.utils.metrics import compute_running_mean
    from alberta_framework.utils.statistics import compute_timeseries_statistics

    if ax is not None:
        fig = cast("Figure", ax.figure)
    else:
        fig, ax = plt.subplots()

    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    color_overrides = colors or {}
    label_overrides = labels or {}

    for position, (config_name, aggregated) in enumerate(results.items()):
        curve_color = color_overrides.get(config_name, palette[position % len(palette)])
        curve_label = label_overrides.get(config_name, config_name)

        per_seed = aggregated.metric_arrays[metric]
        # Smooth seed by seed, then aggregate the smoothed traces.
        smoothed_rows = [
            compute_running_mean(per_seed[row], window_size)
            for row in range(per_seed.shape[0])
        ]
        center, band_low, band_high = compute_timeseries_statistics(np.array(smoothed_rows))

        timeline = np.arange(len(center))
        ax.plot(
            timeline,
            center,
            color=curve_color,
            label=curve_label,
            linewidth=_current_style["line_width"],
        )
        if show_ci:
            ax.fill_between(timeline, band_low, band_high, color=curve_color, alpha=0.2)

    ax.set_xlabel("Time Step")
    ax.set_ylabel(_metric_to_label(metric))
    if log_scale:
        ax.set_yscale("log")
    ax.legend(loc="best", framealpha=0.9)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig, ax
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def plot_final_performance_bars(
    results: dict[str, "AggregatedResults"],
    metric: str = "squared_error",
    show_significance: bool = True,
    significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
    ax: "Axes | None" = None,
    colors: dict[str, str] | None = None,
    lower_is_better: bool = True,
) -> tuple["Figure", "Axes"]:
    """Plot final performance as bar chart with error bars.

    Draws one bar per configuration at ``summary[metric].mean``, with a
    ``summary[metric].std`` error bar. The best bar gets a thick gold edge;
    "best" is the lowest mean if ``lower_is_better``, otherwise the highest.
    When significance results are given, each other bar is annotated with
    ``*``/``**``/``***`` from its comparison against the best configuration.
    The y-axis is extended if needed so the markers stay inside the axes.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        metric: Metric to plot
        show_significance: Whether to show significance markers
        significance_results: Pairwise significance test results, keyed by
            (name_a, name_b) in either order
        ax: Existing axes to plot on (creates new figure if None)
        colors: Optional custom colors for each method
        lower_is_better: Whether lower values are better

    Returns:
        Tuple of (figure, axes). With empty ``results`` the axes are returned
        with only the y-label set.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required. Install with: pip install matplotlib"
        ) from err

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = cast("Figure", ax.figure)

    ax.set_ylabel(_metric_to_label(metric))

    names = list(results.keys())
    if not names:
        # Nothing to draw; np.argmin/np.argmax would raise on an empty sequence.
        fig.tight_layout()
        return fig, ax

    means = [results[name].summary[metric].mean for name in names]
    stds = [results[name].summary[metric].std for name in names]

    # Default colors
    default_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    x = np.arange(len(names))
    bar_colors = [
        (colors or {}).get(name, default_colors[i % len(default_colors)])
        for i, name in enumerate(names)
    ]

    bars = ax.bar(
        x, means, yerr=stds, capsize=3, color=bar_colors, edgecolor="black", linewidth=0.5
    )

    # Find best and mark it
    best_idx = int(np.argmin(means)) if lower_is_better else int(np.argmax(means))
    bars[best_idx].set_edgecolor("gold")
    bars[best_idx].set_linewidth(2)

    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=45, ha="right")

    # Add significance markers if provided
    if show_significance and significance_results:
        best_name = names[best_idx]
        # Offset by a fraction of the visible y-range rather than of the tallest
        # bar, so the gap stays positive even when all values are <= 0.
        y_bottom, y_top = ax.get_ylim()
        y_offset = 0.05 * (y_top - y_bottom)

        marker_heights = []
        for i, name in enumerate(names):
            if name == best_name:
                continue

            marker = _get_significance_marker_for_plot(name, best_name, significance_results)
            if marker:
                marker_y = means[i] + stds[i] + y_offset
                ax.annotate(
                    marker,
                    (i, marker_y),
                    ha="center",
                    fontsize=_current_style["font_size"],
                )
                marker_heights.append(marker_y)

        if marker_heights:
            # annotate() clips data-coordinate annotations whose anchor is outside
            # the axes, and the glyphs extend above the anchor: leave headroom.
            needed_top = max(marker_heights) + y_offset
            if needed_top > y_top:
                ax.set_ylim(y_bottom, needed_top)

    fig.tight_layout()
    return fig, ax
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def plot_hyperparameter_heatmap(
    results: dict[str, "AggregatedResults"],
    param1_name: str,
    param1_values: list[Any],
    param2_name: str,
    param2_values: list[Any],
    metric: str = "squared_error",
    name_pattern: str = "{p1}_{p2}",
    ax: "Axes | None" = None,
    cmap: str = "viridis_r",
    lower_is_better: bool = True,
) -> tuple["Figure", "Axes"]:
    """Plot hyperparameter sensitivity heatmap.

    Cell (i, j) shows ``summary[metric].mean`` of the configuration named
    ``name_pattern.format(p1=param1_values[i], p2=param2_values[j])``.
    Configurations missing from ``results`` are stored as NaN and left
    unannotated.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        param1_name: Name of first parameter (y-axis)
        param1_values: Values of first parameter
        param2_name: Name of second parameter (x-axis)
        param2_values: Values of second parameter
        metric: Metric to plot
        name_pattern: Pattern to generate config names (use {p1}, {p2})
        ax: Existing axes to plot on
        cmap: Colormap to use when lower values are better. When
            ``lower_is_better`` is False, its reversed counterpart is used:
            a trailing ``_r`` is removed, or added if absent.
        lower_is_better: Whether lower values are better

    Returns:
        Tuple of (figure, axes)

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required. Install with: pip install matplotlib"
        ) from err

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = cast("Figure", ax.figure)

    # Build heatmap data; missing configurations stay NaN
    data = np.full((len(param1_values), len(param2_values)), np.nan)
    for i, p1 in enumerate(param1_values):
        for j, p2 in enumerate(param2_values):
            name = name_pattern.format(p1=p1, p2=p2)
            if name in results:
                data[i, j] = results[name].summary[metric].mean

    # Only a *trailing* "_r" marks a reversed colormap. Names such as
    # "gist_rainbow" contain "_r" elsewhere and must not be altered.
    if lower_is_better:
        cmap_to_use = cmap
    elif cmap.endswith("_r"):
        cmap_to_use = cmap[: -len("_r")]
    else:
        cmap_to_use = f"{cmap}_r"

    im = ax.imshow(data, cmap=cmap_to_use, aspect="auto")
    ax.set_xticks(np.arange(len(param2_values)))
    ax.set_yticks(np.arange(len(param1_values)))
    ax.set_xticklabels([str(v) for v in param2_values])
    ax.set_yticklabels([str(v) for v in param1_values])
    ax.set_xlabel(param2_name)
    ax.set_ylabel(param1_name)

    # Add colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(_metric_to_label(metric))

    # Add value annotations. Black or white text is chosen from the luminance
    # of the cell's actual colour, so labels stay readable for any colormap
    # and either direction.
    annotation_size = _current_style["font_size"] - 2
    for i, j in np.argwhere(~np.isnan(data)):
        value = data[i, j]
        red, green, blue, _alpha = im.cmap(im.norm(value))
        luminance = 0.299 * red + 0.587 * green + 0.114 * blue
        ax.annotate(
            f"{value:.3f}",
            (int(j), int(i)),
            ha="center",
            va="center",
            color="black" if luminance > 0.5 else "white",
            fontsize=annotation_size,
        )

    fig.tight_layout()
    return fig, ax
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def plot_step_size_evolution(
    results: dict[str, "AggregatedResults"],
    metric: str = "mean_step_size",
    show_ci: bool = True,
    ax: "Axes | None" = None,
    colors: dict[str, str] | None = None,
) -> tuple["Figure", "Axes"]:
    """Plot how a step-size metric evolves over time, one line per config.

    Configurations whose ``metric_arrays`` lack ``metric`` are skipped. They
    still occupy their position in the default colour cycle, so default colours
    match position-based colouring elsewhere in this module. The y-axis is
    always logarithmic.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        metric: Step-size metric to plot
        show_ci: Whether to shade the confidence interval around each mean
        ax: Existing axes to draw on; a new figure is created when None
        colors: Optional per-config colour overrides

    Returns:
        Tuple of (figure, axes)

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib is required. Install with: pip install matplotlib")

    from alberta_framework.utils.statistics import compute_timeseries_statistics

    if ax is not None:
        fig = cast("Figure", ax.figure)
    else:
        fig, ax = plt.subplots()

    cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    chosen = colors or {}

    # Keep each config's original position so the default colour is stable.
    available = (
        (position, config_name, aggregated.metric_arrays[metric])
        for position, (config_name, aggregated) in enumerate(results.items())
        if metric in aggregated.metric_arrays
    )

    for position, config_name, series in available:
        line_color = chosen.get(config_name, cycle[position % len(cycle)])
        center, band_low, band_high = compute_timeseries_statistics(series)
        timeline = np.arange(len(center))

        ax.plot(
            timeline,
            center,
            color=line_color,
            label=config_name,
            linewidth=_current_style["line_width"],
        )
        if show_ci:
            ax.fill_between(timeline, band_low, band_high, color=line_color, alpha=0.2)

    ax.set_xlabel("Time Step")
    ax.set_ylabel("Step Size")
    ax.set_yscale("log")
    ax.legend(loc="best", framealpha=0.9)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig, ax
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def create_comparison_figure(
    results: dict[str, "AggregatedResults"],
    significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
    metric: str = "squared_error",
    step_size_metric: str = "mean_step_size",
) -> "Figure":
    """Build a 7 x 5.6 inch, 2x2 grid of comparison panels.

    Panel layout:
        - top-left: smoothed learning curves of ``metric``
        - top-right: final-performance bars (with significance stars if given)
        - bottom-left: step-size evolution, or a placeholder message when no
          configuration records ``step_size_metric``
        - bottom-right: cumulative ``metric`` over time

    Args:
        results: Dictionary mapping config name to AggregatedResults
        significance_results: Optional pairwise significance test results
        metric: Error metric used by the learning-curve, bar and cumulative panels
        step_size_metric: Metric used by the step-size panel

    Returns:
        The assembled figure (``tight_layout`` applied).

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib is required. Install with: pip install matplotlib")

    fig, axes = plt.subplots(2, 2, figsize=(7, 5.6))
    curves_ax, bars_ax = axes[0, 0], axes[0, 1]
    step_ax, cumulative_ax = axes[1, 0], axes[1, 1]

    plot_learning_curves(results, metric=metric, ax=curves_ax)
    curves_ax.set_title("Learning Curves")

    plot_final_performance_bars(
        results,
        metric=metric,
        significance_results=significance_results,
        ax=bars_ax,
    )
    bars_ax.set_title("Final Performance")

    # The step-size panel only makes sense if some config recorded the metric.
    if any(step_size_metric in agg.metric_arrays for agg in results.values()):
        plot_step_size_evolution(results, metric=step_size_metric, ax=step_ax)
    else:
        step_ax.text(
            0.5,
            0.5,
            "Step-size data\nnot available",
            ha="center",
            va="center",
            transform=step_ax.transAxes,
        )
    step_ax.set_title("Step-Size Evolution")

    _plot_cumulative_error(results, metric=metric, ax=cumulative_ax)
    cumulative_ax.set_title("Cumulative Error")

    fig.tight_layout()
    return fig
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def _plot_cumulative_error(
    results: dict[str, "AggregatedResults"],
    metric: str,
    ax: "Axes",
) -> None:
    """Draw the across-seed mean of the running total of ``metric`` on ``ax``.

    Each seed's trace is cumulatively summed along axis 1 before the mean and
    confidence band (always shaded) are computed. Colours follow the default
    property cycle by position in ``results``. If matplotlib cannot be
    imported, this returns without drawing anything.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        return

    from alberta_framework.utils.statistics import compute_timeseries_statistics

    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    for position, (config_name, aggregated) in enumerate(results.items()):
        shade = palette[position % len(palette)]
        # Running total per seed, then statistics across seeds.
        running_total = np.cumsum(aggregated.metric_arrays[metric], axis=1)
        center, band_low, band_high = compute_timeseries_statistics(running_total)
        timeline = np.arange(len(center))

        ax.plot(
            timeline,
            center,
            color=shade,
            label=config_name,
            linewidth=_current_style["line_width"],
        )
        ax.fill_between(timeline, band_low, band_high, color=shade, alpha=0.2)

    ax.set_xlabel("Time Step")
    ax.set_ylabel("Cumulative Error")
    ax.legend(loc="best", framealpha=0.9)
    ax.grid(True, alpha=0.3)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
def save_figure(
|
|
494
|
+
fig: "Figure",
|
|
495
|
+
filename: str | Path,
|
|
496
|
+
formats: list[str] | None = None,
|
|
497
|
+
dpi: int = 300,
|
|
498
|
+
transparent: bool = False,
|
|
499
|
+
) -> list[Path]:
|
|
500
|
+
"""Save figure to multiple formats.
|
|
501
|
+
|
|
502
|
+
Args:
|
|
503
|
+
fig: Matplotlib figure to save
|
|
504
|
+
filename: Base filename (without extension)
|
|
505
|
+
formats: List of formats to save (default: ["pdf", "png"])
|
|
506
|
+
dpi: Resolution for raster formats
|
|
507
|
+
transparent: Whether to use transparent background
|
|
508
|
+
|
|
509
|
+
Returns:
|
|
510
|
+
List of saved file paths
|
|
511
|
+
"""
|
|
512
|
+
if formats is None:
|
|
513
|
+
formats = ["pdf", "png"]
|
|
514
|
+
|
|
515
|
+
filename = Path(filename)
|
|
516
|
+
filename.parent.mkdir(parents=True, exist_ok=True)
|
|
517
|
+
|
|
518
|
+
saved_paths = []
|
|
519
|
+
for fmt in formats:
|
|
520
|
+
path = filename.with_suffix(f".{fmt}")
|
|
521
|
+
fig.savefig(
|
|
522
|
+
path,
|
|
523
|
+
format=fmt,
|
|
524
|
+
dpi=dpi,
|
|
525
|
+
bbox_inches="tight",
|
|
526
|
+
transparent=transparent,
|
|
527
|
+
)
|
|
528
|
+
saved_paths.append(path)
|
|
529
|
+
|
|
530
|
+
return saved_paths
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
def _metric_to_label(metric: str) -> str:
|
|
534
|
+
"""Convert metric name to human-readable label."""
|
|
535
|
+
labels = {
|
|
536
|
+
"squared_error": "Squared Error",
|
|
537
|
+
"error": "Error",
|
|
538
|
+
"mean_step_size": "Mean Step Size",
|
|
539
|
+
"max_step_size": "Max Step Size",
|
|
540
|
+
"min_step_size": "Min Step Size",
|
|
541
|
+
}
|
|
542
|
+
return labels.get(metric, metric.replace("_", " ").title())
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
def _get_significance_marker_for_plot(
|
|
546
|
+
name: str,
|
|
547
|
+
best_name: str,
|
|
548
|
+
significance_results: dict[tuple[str, str], "SignificanceResult"],
|
|
549
|
+
) -> str:
|
|
550
|
+
"""Get significance marker for plot annotation."""
|
|
551
|
+
key1 = (name, best_name)
|
|
552
|
+
key2 = (best_name, name)
|
|
553
|
+
|
|
554
|
+
if key1 in significance_results:
|
|
555
|
+
result = significance_results[key1]
|
|
556
|
+
elif key2 in significance_results:
|
|
557
|
+
result = significance_results[key2]
|
|
558
|
+
else:
|
|
559
|
+
return ""
|
|
560
|
+
|
|
561
|
+
if not result.significant:
|
|
562
|
+
return ""
|
|
563
|
+
|
|
564
|
+
p = result.p_value
|
|
565
|
+
if p < 0.001:
|
|
566
|
+
return "***"
|
|
567
|
+
elif p < 0.01:
|
|
568
|
+
return "**"
|
|
569
|
+
elif p < 0.05:
|
|
570
|
+
return "*"
|
|
571
|
+
return ""
|