alberta-framework 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,571 @@
1
+ """Publication-quality visualization utilities.
2
+
3
+ Provides functions for creating figures suitable for academic papers,
4
+ including learning curves, bar plots, heatmaps, and multi-panel figures.
5
+ """
6
+
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, Any, cast
9
+
10
+ import numpy as np
11
+
12
+ if TYPE_CHECKING:
13
+ from matplotlib.axes import Axes
14
+ from matplotlib.figure import Figure
15
+
16
+ from alberta_framework.utils.experiments import AggregatedResults
17
+ from alberta_framework.utils.statistics import SignificanceResult
18
+
19
+
20
# Module-wide publication defaults. ``set_publication_style`` writes into the
# working copy below; nothing in this module writes to the defaults themselves.
_DEFAULT_STYLE = dict(
    font_size=10,
    figure_width=3.5,  # single-column width in inches
    figure_height=2.8,  # inches
    line_width=1.5,
    marker_size=4,
    dpi=300,
    use_latex=False,
)

# Mutable working copy read by the plotting helpers in this module.
_current_style = dict(_DEFAULT_STYLE)
32
+
33
+
34
def set_publication_style(
    font_size: int = 10,
    use_latex: bool = False,
    figure_width: float = 3.5,
    figure_height: float | None = None,
    style: str = "seaborn-v0_8-whitegrid",
) -> None:
    """Set matplotlib style for publication-quality figures.

    Updates the module-level ``_current_style`` (read by the plotting
    functions in this module) and matplotlib's global ``rcParams``.

    Args:
        font_size: Base font size. Axes titles use ``font_size + 1``; tick
            labels and legend text use ``font_size - 1``.
        use_latex: Whether to use LaTeX for text rendering. ``False`` now
            explicitly switches ``text.usetex`` off, so an earlier call with
            ``use_latex=True`` does not leak into later figures.
        figure_width: Default figure width in inches.
        figure_height: Default figure height in inches; ``0.8 * figure_width``
            when None.
        style: Matplotlib style sheet to apply. If the style cannot be
            loaded (``OSError``), it is skipped and the rcParams below are
            still applied.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required. Install with: pip install matplotlib"
        ) from err

    # Record the new defaults for the plotting helpers in this module.
    _current_style["font_size"] = font_size
    _current_style["figure_width"] = figure_width
    _current_style["use_latex"] = use_latex
    _current_style["figure_height"] = (
        figure_height if figure_height is not None else figure_width * 0.8
    )

    # Try to use the requested style; fall back to current settings if unavailable.
    try:
        plt.style.use(style)
    except OSError:
        pass  # Deliberate best effort: an unknown style is not fatal.

    plt.rcParams.update({
        "font.size": font_size,
        "axes.labelsize": font_size,
        "axes.titlesize": font_size + 1,
        "xtick.labelsize": font_size - 1,
        "ytick.labelsize": font_size - 1,
        "legend.fontsize": font_size - 1,
        "figure.figsize": (_current_style["figure_width"], _current_style["figure_height"]),
        "figure.dpi": _current_style["dpi"],
        "savefig.dpi": _current_style["dpi"],
        "lines.linewidth": _current_style["line_width"],
        "lines.markersize": _current_style["marker_size"],
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.5,
        "grid.alpha": 0.3,
        # Set unconditionally: rcParams are process-global, so a True left
        # over from an earlier use_latex=True call would otherwise persist.
        "text.usetex": use_latex,
    })

    if use_latex:
        plt.rcParams.update({
            "font.family": "serif",
            "font.serif": ["Computer Modern Roman"],
        })
95
+
96
+
97
def plot_learning_curves(
    results: dict[str, "AggregatedResults"],
    metric: str = "squared_error",
    show_ci: bool = True,
    log_scale: bool = True,
    window_size: int = 100,
    ax: "Axes | None" = None,
    colors: dict[str, str] | None = None,
    labels: dict[str, str] | None = None,
) -> tuple["Figure", "Axes"]:
    """Draw one smoothed learning curve (with optional CI band) per configuration.

    Each seed's series of ``metric`` is smoothed with a running mean of width
    ``window_size`` first. The mean and confidence band are then computed
    across the smoothed seeds.

    Args:
        results: Mapping from configuration name to AggregatedResults.
        metric: Key into each ``metric_arrays`` to plot.
        show_ci: If True, shade the confidence band around each curve.
        log_scale: If True, use a logarithmic y-axis.
        window_size: Running-mean window applied to each seed.
        ax: Axes to draw on; a new figure and axes are created when None.
        colors: Optional per-configuration colour overrides.
        labels: Optional per-configuration legend label overrides.

    Returns:
        The ``(figure, axes)`` pair that was drawn on.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib is required. Install with: pip install matplotlib")

    from alberta_framework.utils.metrics import compute_running_mean
    from alberta_framework.utils.statistics import compute_timeseries_statistics

    if ax is not None:
        fig = cast("Figure", ax.figure)
    else:
        fig, ax = plt.subplots()

    # Fall back to the active property-cycle colours, cycling by position.
    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    color_overrides = colors or {}
    label_overrides = labels or {}

    for position, (config_name, aggregated) in enumerate(results.items()):
        line_color = color_overrides.get(config_name, palette[position % len(palette)])
        legend_label = label_overrides.get(config_name, config_name)

        per_seed = aggregated.metric_arrays[metric]
        seed_count = per_seed.shape[0]
        smoothed_runs = np.array(
            [compute_running_mean(per_seed[seed], window_size) for seed in range(seed_count)]
        )

        center, band_low, band_high = compute_timeseries_statistics(smoothed_runs)
        x_values = np.arange(len(center))

        ax.plot(
            x_values,
            center,
            color=line_color,
            label=legend_label,
            linewidth=_current_style["line_width"],
        )
        if show_ci:
            ax.fill_between(x_values, band_low, band_high, color=line_color, alpha=0.2)

    ax.set_xlabel("Time Step")
    ax.set_ylabel(_metric_to_label(metric))
    if log_scale:
        ax.set_yscale("log")
    ax.legend(loc="best", framealpha=0.9)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig, ax
167
+
168
+
169
def plot_final_performance_bars(
    results: dict[str, "AggregatedResults"],
    metric: str = "squared_error",
    show_significance: bool = True,
    significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
    ax: "Axes | None" = None,
    colors: dict[str, str] | None = None,
    lower_is_better: bool = True,
) -> tuple["Figure", "Axes"]:
    """Plot final performance as bar chart with error bars.

    Each configuration gets one bar of height ``summary[metric].mean`` with a
    ``summary[metric].std`` error bar. The best bar (per ``lower_is_better``;
    NaN means are ignored) gets a thick gold outline.

    When significance results are supplied, every other bar whose comparison
    with the best configuration is significant gets a star marker
    (``*`` p<0.05, ``**`` p<0.01, ``***`` p<0.001) above its error bar. The
    y-axis is extended so those markers stay visible.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        metric: Metric to plot
        show_significance: Whether to show significance markers
        significance_results: Pairwise significance test results
        ax: Existing axes to plot on (creates new figure if None)
        colors: Optional custom colors for each method
        lower_is_better: Whether lower values are better

    Returns:
        Tuple of (figure, axes)

    Raises:
        ImportError: If matplotlib is not installed.
        ValueError: If ``results`` is empty or every mean is NaN.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required. Install with: pip install matplotlib"
        ) from err

    # Fail before creating a figure rather than deep inside argmin.
    if not results:
        raise ValueError("results must contain at least one configuration")

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = cast("Figure", ax.figure)

    names = list(results.keys())
    means = [results[name].summary[metric].mean for name in names]
    stds = [results[name].summary[metric].std for name in names]

    # Default colors
    default_colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    x = np.arange(len(names))
    bar_colors = [
        (colors or {}).get(name, default_colors[i % len(default_colors)])
        for i, name in enumerate(names)
    ]

    bars = ax.bar(
        x, means, yerr=stds, capsize=3, color=bar_colors, edgecolor="black", linewidth=0.5
    )

    # nanarg* so that a NaN mean is never highlighted as "best".
    best_idx = int(np.nanargmin(means) if lower_is_better else np.nanargmax(means))

    bars[best_idx].set_edgecolor("gold")
    bars[best_idx].set_linewidth(2)

    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=45, ha="right")
    ax.set_ylabel(_metric_to_label(metric))

    # Add significance markers if provided
    if show_significance and significance_results:
        best_name = names[best_idx]
        y_max = float(np.nanmax(np.add(means, stds)))
        # abs() keeps the offset pointing upward even when bar tops are negative.
        y_offset = abs(y_max) * 0.05
        highest_marker = None

        for i, name in enumerate(names):
            if name == best_name:
                continue

            marker = _get_significance_marker_for_plot(name, best_name, significance_results)
            if not marker:
                continue

            y_pos = means[i] + stds[i] + y_offset
            ax.annotate(
                marker,
                (i, y_pos),
                ha="center",
                fontsize=_current_style["font_size"],
            )
            highest_marker = y_pos if highest_marker is None else max(highest_marker, y_pos)

        # Annotations anchored in data coordinates are not drawn when the anchor
        # falls outside the axes, and they do not drive autoscaling, so make room.
        if highest_marker is not None:
            _, y_top = ax.get_ylim()
            needed_top = highest_marker + y_offset
            if needed_top > y_top:
                ax.set_ylim(top=needed_top)

    fig.tight_layout()
    return fig, ax
252
+
253
+
254
def plot_hyperparameter_heatmap(
    results: dict[str, "AggregatedResults"],
    param1_name: str,
    param1_values: list[Any],
    param2_name: str,
    param2_values: list[Any],
    metric: str = "squared_error",
    name_pattern: str = "{p1}_{p2}",
    ax: "Axes | None" = None,
    cmap: str = "viridis_r",
    lower_is_better: bool = True,
) -> tuple["Figure", "Axes"]:
    """Plot hyperparameter sensitivity heatmap.

    Cell ``(i, j)`` shows ``results[name].summary[metric].mean``, where
    ``name = name_pattern.format(p1=param1_values[i], p2=param2_values[j])``.
    Configurations missing from ``results`` are left as NaN (blank) and are
    not annotated.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        param1_name: Name of first parameter (y-axis)
        param1_values: Values of first parameter
        param2_name: Name of second parameter (x-axis)
        param2_values: Values of second parameter
        metric: Metric to plot
        name_pattern: Pattern to generate config names (use {p1}, {p2})
        ax: Existing axes to plot on
        cmap: Colormap name. Used as given when ``lower_is_better`` is True.
            Otherwise its reversed counterpart is used: a trailing ``"_r"``
            is removed, or one is appended.
        lower_is_better: Whether lower values are better

    Returns:
        Tuple of (figure, axes)

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise ImportError(
            "matplotlib is required. Install with: pip install matplotlib"
        ) from err

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = cast("Figure", ax.figure)

    # Build heatmap data; cells without a matching config stay NaN.
    data = np.full((len(param1_values), len(param2_values)), np.nan)
    for i, p1 in enumerate(param1_values):
        for j, p2 in enumerate(param2_values):
            name = name_pattern.format(p1=p1, p2=p2)
            if name in results:
                data[i, j] = results[name].summary[metric].mean

    if lower_is_better:
        cmap_to_use = cmap
    elif cmap.endswith("_r"):
        # Strip only a trailing "_r": names like "gist_rainbow" contain "_r" elsewhere.
        cmap_to_use = cmap[:-2]
    else:
        cmap_to_use = f"{cmap}_r"

    im = ax.imshow(data, cmap=cmap_to_use, aspect="auto")
    ax.set_xticks(np.arange(len(param2_values)))
    ax.set_yticks(np.arange(len(param1_values)))
    ax.set_xticklabels([str(v) for v in param2_values])
    ax.set_yticklabels([str(v) for v in param1_values])
    ax.set_xlabel(param2_name)
    ax.set_ylabel(param1_name)

    # Add colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(_metric_to_label(metric))

    # Pick the annotation colour from the rendered cell brightness (ITU-R BT.601
    # weights), so text stays readable whichever colormap/orientation is used.
    for i in range(len(param1_values)):
        for j in range(len(param2_values)):
            value = data[i, j]
            if np.isnan(value):
                continue
            red, green, blue, _alpha = im.cmap(im.norm(value))
            luminance = 0.299 * red + 0.587 * green + 0.114 * blue
            text_color = "white" if luminance < 0.5 else "black"
            ax.annotate(
                f"{value:.3f}",
                (j, i),
                ha="center",
                va="center",
                color=text_color,
                fontsize=_current_style["font_size"] - 2,
            )

    fig.tight_layout()
    return fig, ax
335
+
336
+
337
def plot_step_size_evolution(
    results: dict[str, "AggregatedResults"],
    metric: str = "mean_step_size",
    show_ci: bool = True,
    ax: "Axes | None" = None,
    colors: dict[str, str] | None = None,
) -> tuple["Figure", "Axes"]:
    """Draw how a step-size metric evolves over time for each configuration.

    Configurations whose ``metric_arrays`` lack ``metric`` are skipped. Their
    position in ``results`` still advances the default colour cycle. The
    y-axis is always logarithmic.

    Args:
        results: Mapping from configuration name to AggregatedResults.
        metric: Key into each ``metric_arrays`` holding the step-size series.
        show_ci: If True, shade the confidence band around each curve.
        ax: Axes to draw on; a new figure and axes are created when None.
        colors: Optional per-configuration colour overrides.

    Returns:
        The ``(figure, axes)`` pair that was drawn on.

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib is required. Install with: pip install matplotlib")

    from alberta_framework.utils.statistics import compute_timeseries_statistics

    if ax is not None:
        fig = cast("Figure", ax.figure)
    else:
        fig, ax = plt.subplots()

    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    overrides = colors or {}

    for position, (config_name, aggregated) in enumerate(results.items()):
        series_by_metric = aggregated.metric_arrays
        if metric in series_by_metric:
            line_color = overrides.get(config_name, palette[position % len(palette)])
            center, band_low, band_high = compute_timeseries_statistics(
                series_by_metric[metric]
            )
            x_values = np.arange(len(center))

            ax.plot(
                x_values,
                center,
                color=line_color,
                label=config_name,
                linewidth=_current_style["line_width"],
            )
            if show_ci:
                ax.fill_between(x_values, band_low, band_high, color=line_color, alpha=0.2)

    ax.set_xlabel("Time Step")
    ax.set_ylabel("Step Size")
    ax.set_yscale("log")
    ax.legend(loc="best", framealpha=0.9)
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig, ax
391
+
392
+
393
def create_comparison_figure(
    results: dict[str, "AggregatedResults"],
    significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
    metric: str = "squared_error",
    step_size_metric: str = "mean_step_size",
) -> "Figure":
    """Build a 2x2 overview figure comparing all configurations.

    Layout (row, column):
        (0, 0) learning curves, (0, 1) final-performance bars,
        (1, 0) step-size evolution, or a placeholder note when no
        configuration records ``step_size_metric``, (1, 1) cumulative error.

    Args:
        results: Mapping from configuration name to AggregatedResults.
        significance_results: Optional pairwise test results for the bar panel.
        metric: Error metric used by the learning-curve, bar and cumulative panels.
        step_size_metric: Metric used by the step-size panel.

    Returns:
        The assembled figure (7 x 5.6 inches).

    Raises:
        ImportError: If matplotlib is not installed.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        raise ImportError("matplotlib is required. Install with: pip install matplotlib")

    fig, axes = plt.subplots(2, 2, figsize=(7, 5.6))
    curves_ax, bars_ax = axes[0]
    steps_ax, cumulative_ax = axes[1]

    plot_learning_curves(results, metric=metric, ax=curves_ax)
    curves_ax.set_title("Learning Curves")

    plot_final_performance_bars(
        results,
        metric=metric,
        significance_results=significance_results,
        ax=bars_ax,
    )
    bars_ax.set_title("Final Performance")

    if any(step_size_metric in agg.metric_arrays for agg in results.values()):
        plot_step_size_evolution(results, metric=step_size_metric, ax=steps_ax)
    else:
        # No configuration recorded step sizes: leave a centred note instead.
        steps_ax.text(
            0.5,
            0.5,
            "Step-size data\nnot available",
            ha="center",
            va="center",
            transform=steps_ax.transAxes,
        )
    steps_ax.set_title("Step-Size Evolution")

    _plot_cumulative_error(results, metric=metric, ax=cumulative_ax)
    cumulative_ax.set_title("Cumulative Error")

    fig.tight_layout()
    return fig
458
+
459
+
460
def _plot_cumulative_error(
    results: dict[str, "AggregatedResults"],
    metric: str,
    ax: "Axes",
) -> None:
    """Draw the across-seed mean of the cumulative ``metric`` with its band.

    Each configuration's metric array is cumulatively summed along axis 1
    (the per-seed time axis) before statistics are computed. This is
    best-effort: if matplotlib cannot be imported, the function returns
    without drawing.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        return

    from alberta_framework.utils.statistics import compute_timeseries_statistics

    palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    for position, (config_name, aggregated) in enumerate(results.items()):
        line_color = palette[position % len(palette)]
        running_totals = np.cumsum(aggregated.metric_arrays[metric], axis=1)

        center, band_low, band_high = compute_timeseries_statistics(running_totals)
        x_values = np.arange(len(center))

        ax.plot(
            x_values,
            center,
            color=line_color,
            label=config_name,
            linewidth=_current_style["line_width"],
        )
        ax.fill_between(x_values, band_low, band_high, color=line_color, alpha=0.2)

    ax.set_xlabel("Time Step")
    ax.set_ylabel("Cumulative Error")
    ax.legend(loc="best", framealpha=0.9)
    ax.grid(True, alpha=0.3)
491
+
492
+
493
+ def save_figure(
494
+ fig: "Figure",
495
+ filename: str | Path,
496
+ formats: list[str] | None = None,
497
+ dpi: int = 300,
498
+ transparent: bool = False,
499
+ ) -> list[Path]:
500
+ """Save figure to multiple formats.
501
+
502
+ Args:
503
+ fig: Matplotlib figure to save
504
+ filename: Base filename (without extension)
505
+ formats: List of formats to save (default: ["pdf", "png"])
506
+ dpi: Resolution for raster formats
507
+ transparent: Whether to use transparent background
508
+
509
+ Returns:
510
+ List of saved file paths
511
+ """
512
+ if formats is None:
513
+ formats = ["pdf", "png"]
514
+
515
+ filename = Path(filename)
516
+ filename.parent.mkdir(parents=True, exist_ok=True)
517
+
518
+ saved_paths = []
519
+ for fmt in formats:
520
+ path = filename.with_suffix(f".{fmt}")
521
+ fig.savefig(
522
+ path,
523
+ format=fmt,
524
+ dpi=dpi,
525
+ bbox_inches="tight",
526
+ transparent=transparent,
527
+ )
528
+ saved_paths.append(path)
529
+
530
+ return saved_paths
531
+
532
+
533
def _metric_to_label(metric: str) -> str:
    """Return a human-readable axis label for ``metric``.

    A handful of common metric names map to fixed labels. Any other name is
    shown with underscores replaced by spaces and each word title-cased.
    """
    known_labels = dict(
        squared_error="Squared Error",
        error="Error",
        mean_step_size="Mean Step Size",
        max_step_size="Max Step Size",
        min_step_size="Min Step Size",
    )
    if metric in known_labels:
        return known_labels[metric]
    return " ".join(metric.split("_")).title()
543
+
544
+
545
def _get_significance_marker_for_plot(
    name: str,
    best_name: str,
    significance_results: dict[tuple[str, str], "SignificanceResult"],
) -> str:
    """Return the star marker for the comparison between ``name`` and ``best_name``.

    The pair is looked up as ``(name, best_name)`` first, then reversed. The
    result is ``""`` when no result exists or it is not significant.
    Otherwise the marker is ``"***"`` for p < 0.001, ``"**"`` for p < 0.01,
    ``"*"`` for p < 0.05, and ``""`` for anything larger.
    """
    for key in ((name, best_name), (best_name, name)):
        if key in significance_results:
            result = significance_results[key]
            break
    else:
        return ""

    if not result.significant:
        return ""

    p_value = result.p_value
    for threshold, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if p_value < threshold:
            return stars
    return ""