alberta-framework 0.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,509 @@
+ """Export utilities for experiment results.
+
+ Provides functions for exporting results to CSV, JSON, LaTeX tables,
+ and markdown, suitable for academic publications.
+ """
+
+ import csv
+ import json
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+     from alberta_framework.utils.experiments import AggregatedResults
+     from alberta_framework.utils.statistics import SignificanceResult
+
+
+ def export_to_csv(
+     results: dict[str, "AggregatedResults"],
+     filepath: str | Path,
+     metric: str = "squared_error",
+     include_timeseries: bool = False,
+ ) -> None:
+     """Export results to CSV file.
+
+     Args:
+         results: Dictionary mapping config name to AggregatedResults
+         filepath: Path to output CSV file
+         metric: Metric to export
+         include_timeseries: Whether to include full timeseries (large!)
+     """
+     filepath = Path(filepath)
+     filepath.parent.mkdir(parents=True, exist_ok=True)
+
+     if include_timeseries:
+         _export_timeseries_csv(results, filepath, metric)
+     else:
+         _export_summary_csv(results, filepath, metric)
+
+
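A brief usage sketch of the summary exporter. Here `results` is assumed to be the `dict[str, AggregatedResults]` produced by the framework's experiment-aggregation step, which lives outside this file; the output paths are arbitrary:

    # Summary only: one row per config with mean/std/min/max/n_seeds.
    export_to_csv(results, "out/summary.csv", metric="squared_error")

    # Full per-seed timeseries: one column per (config, seed) pair; can be very large.
    export_to_csv(results, "out/timeseries.csv", metric="squared_error", include_timeseries=True)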
+ def _export_summary_csv(
+     results: dict[str, "AggregatedResults"],
+     filepath: Path,
+     metric: str,
+ ) -> None:
+     """Export summary statistics to CSV."""
+     with open(filepath, "w", newline="") as f:
+         writer = csv.writer(f)
+         writer.writerow(["config", "mean", "std", "min", "max", "n_seeds"])
+
+         for name, agg in results.items():
+             summary = agg.summary[metric]
+             writer.writerow([
+                 name,
+                 f"{summary.mean:.6f}",
+                 f"{summary.std:.6f}",
+                 f"{summary.min:.6f}",
+                 f"{summary.max:.6f}",
+                 summary.n_seeds,
+             ])
+
+
+ def _export_timeseries_csv(
+     results: dict[str, "AggregatedResults"],
+     filepath: Path,
+     metric: str,
+ ) -> None:
+     """Export full timeseries to CSV."""
+     with open(filepath, "w", newline="") as f:
+         writer = csv.writer(f)
+
+         # Determine max steps
+         max_steps = max(agg.metric_arrays[metric].shape[1] for agg in results.values())
+
+         # Header
+         headers = ["step"]
+         for name, agg in results.items():
+             for seed in agg.seeds:
+                 headers.append(f"{name}_seed{seed}")
+         writer.writerow(headers)
+
+         # Data rows
+         for step in range(max_steps):
+             row: list[str | int] = [step]
+             for agg in results.values():
+                 arr = agg.metric_arrays[metric]
+                 n_seeds = arr.shape[0]
+                 n_steps = arr.shape[1]
+                 for seed_idx in range(n_seeds):
+                     if step < n_steps:
+                         row.append(f"{arr[seed_idx, step]:.6f}")
+                     else:
+                         row.append("")
+             writer.writerow(row)
+
+
+ def export_to_json(
+     results: dict[str, "AggregatedResults"],
+     filepath: str | Path,
+     include_timeseries: bool = False,
+ ) -> None:
+     """Export results to JSON file.
+
+     Args:
+         results: Dictionary mapping config name to AggregatedResults
+         filepath: Path to output JSON file
+         include_timeseries: Whether to include full timeseries (large!)
+     """
+     filepath = Path(filepath)
+     filepath.parent.mkdir(parents=True, exist_ok=True)
+
+     from typing import Any
+
+     data: dict[str, Any] = {}
+     for name, agg in results.items():
+         summary_data: dict[str, dict[str, Any]] = {}
+         for metric_name, summary in agg.summary.items():
+             summary_data[metric_name] = {
+                 "mean": summary.mean,
+                 "std": summary.std,
+                 "min": summary.min,
+                 "max": summary.max,
+                 "n_seeds": summary.n_seeds,
+                 "values": summary.values.tolist(),
+             }
+
+         config_data: dict[str, Any] = {
+             "seeds": agg.seeds,
+             "summary": summary_data,
+         }
+
+         if include_timeseries:
+             config_data["timeseries"] = {
+                 metric: arr.tolist() for metric, arr in agg.metric_arrays.items()
+             }
+
+         data[name] = config_data
+
+     with open(filepath, "w") as f:
+         json.dump(data, f, indent=2)
+
+
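For example (same assumed `results` mapping as above; file names are illustrative):

    export_to_json(results, "out/results.json")                        # summaries only
    export_to_json(results, "out/full.json", include_timeseries=True)  # adds per-seed arrays

    # The output is plain JSON, so it round-trips with the standard library:
    with open("out/results.json") as f:
        data = json.load(f)
    # e.g. data["<config name>"]["summary"]["squared_error"]["mean"]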
+ def generate_latex_table(
+     results: dict[str, "AggregatedResults"],
+     significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
+     metric: str = "squared_error",
+     caption: str = "Experimental Results",
+     label: str = "tab:results",
+     metric_label: str = "Error",
+     lower_is_better: bool = True,
+ ) -> str:
+     """Generate a LaTeX table of results.
+
+     Args:
+         results: Dictionary mapping config name to AggregatedResults
+         significance_results: Optional pairwise significance test results
+         metric: Metric to display
+         caption: Table caption
+         label: LaTeX label for the table
+         metric_label: Human-readable name for the metric
+         lower_is_better: Whether lower metric values are better
+
+     Returns:
+         LaTeX table as a string
+     """
+     lines = []
+     lines.append(r"\begin{table}[ht]")
+     lines.append(r"\centering")
+     lines.append(r"\caption{" + caption + "}")
+     lines.append(r"\label{" + label + "}")
+     lines.append(r"\begin{tabular}{lcc}")
+     lines.append(r"\toprule")
+     lines.append(r"Method & " + metric_label + r" & Seeds \\")
+     lines.append(r"\midrule")
+
+     # Find best result
+     summaries = {name: agg.summary[metric] for name, agg in results.items()}
+     if lower_is_better:
+         best_name = min(summaries.keys(), key=lambda k: summaries[k].mean)
+     else:
+         best_name = max(summaries.keys(), key=lambda k: summaries[k].mean)
+
+     for name, agg in results.items():
+         summary = agg.summary[metric]
+         mean_str = f"{summary.mean:.4f}"
+         std_str = f"{summary.std:.4f}"
+
+         # Bold if best
+         if name == best_name:
+             value_str = rf"\textbf{{{mean_str}}} $\pm$ {std_str}"
+         else:
+             value_str = rf"{mean_str} $\pm$ {std_str}"
+
+         # Add significance marker if provided
+         if significance_results:
+             sig_marker = _get_significance_marker(name, best_name, significance_results)
+             value_str += sig_marker
+
+         # Escape underscores in method name
+         escaped_name = name.replace("_", r"\_")
+         lines.append(rf"{escaped_name} & {value_str} & {summary.n_seeds} \\")
+
+     lines.append(r"\bottomrule")
+     lines.append(r"\end{tabular}")
+
+     if significance_results:
+         lines.append(r"\vspace{0.5em}")
+         lines.append(r"\footnotesize{$^*$ $p < 0.05$, $^{**}$ $p < 0.01$, $^{***}$ $p < 0.001$}")
+
+     lines.append(r"\end{table}")
+
+     return "\n".join(lines)
+
+
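A sketch of how the generated table might be wired into a paper build. The `results` mapping is assumed as above, the file name is arbitrary, and the emitted table needs the booktabs package for \toprule/\midrule/\bottomrule:

    table_tex = generate_latex_table(
        results,
        metric="squared_error",
        caption="Tracking error (mean over seeds)",
        label="tab:tracking_error",
        metric_label="Squared error",
    )
    Path("demo_table.tex").write_text(table_tex)
    # In the paper: \usepackage{booktabs} in the preamble, then \input{demo_table.tex}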
+ def _get_significance_marker(
+     name: str,
+     best_name: str,
+     significance_results: dict[tuple[str, str], "SignificanceResult"],
+ ) -> str:
+     """Get significance marker for comparison with best method."""
+     if name == best_name:
+         return ""
+
+     # Find comparison result
+     key1 = (name, best_name)
+     key2 = (best_name, name)
+
+     if key1 in significance_results:
+         result = significance_results[key1]
+     elif key2 in significance_results:
+         result = significance_results[key2]
+     else:
+         return ""
+
+     if not result.significant:
+         return ""
+
+     p = result.p_value
+     if p < 0.001:
+         return r"$^{***}$"
+     elif p < 0.01:
+         return r"$^{**}$"
+     elif p < 0.05:
+         return r"$^{*}$"
+     return ""
+
+
+ def generate_markdown_table(
+     results: dict[str, "AggregatedResults"],
+     significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
+     metric: str = "squared_error",
+     metric_label: str = "Error",
+     lower_is_better: bool = True,
+ ) -> str:
+     """Generate a markdown table of results.
+
+     Args:
+         results: Dictionary mapping config name to AggregatedResults
+         significance_results: Optional pairwise significance test results
+         metric: Metric to display
+         metric_label: Human-readable name for the metric
+         lower_is_better: Whether lower metric values are better
+
+     Returns:
+         Markdown table as a string
+     """
+     lines = []
+     lines.append(f"| Method | {metric_label} (mean ± std) | Seeds |")
+     lines.append("|--------|-------------------------|-------|")
+
+     # Find best result
+     summaries = {name: agg.summary[metric] for name, agg in results.items()}
+     if lower_is_better:
+         best_name = min(summaries.keys(), key=lambda k: summaries[k].mean)
+     else:
+         best_name = max(summaries.keys(), key=lambda k: summaries[k].mean)
+
+     for name, agg in results.items():
+         summary = agg.summary[metric]
+         mean_str = f"{summary.mean:.4f}"
+         std_str = f"{summary.std:.4f}"
+
+         # Bold if best
+         if name == best_name:
+             value_str = f"**{mean_str}** ± {std_str}"
+         else:
+             value_str = f"{mean_str} ± {std_str}"
+
+         # Add significance marker if provided
+         if significance_results:
+             sig_marker = _get_md_significance_marker(name, best_name, significance_results)
+             value_str += sig_marker
+
+         lines.append(f"| {name} | {value_str} | {summary.n_seeds} |")
+
+     if significance_results:
+         lines.append("")
+         lines.append("\\* p < 0.05, \\*\\* p < 0.01, \\*\\*\\* p < 0.001")
+
+     return "\n".join(lines)
+
+
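The markdown variant drops straight into a README or experiment log; a minimal sketch with the same assumed `results`:

    md = generate_markdown_table(results, metric="squared_error", metric_label="Squared error")
    Path("RESULTS.md").write_text(md + "\n")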
+ def _get_md_significance_marker(
+     name: str,
+     best_name: str,
+     significance_results: dict[tuple[str, str], "SignificanceResult"],
+ ) -> str:
+     """Get significance marker for markdown."""
+     if name == best_name:
+         return ""
+
+     key1 = (name, best_name)
+     key2 = (best_name, name)
+
+     if key1 in significance_results:
+         result = significance_results[key1]
+     elif key2 in significance_results:
+         result = significance_results[key2]
+     else:
+         return ""
+
+     if not result.significant:
+         return ""
+
+     p = result.p_value
+     if p < 0.001:
+         return " ***"
+     elif p < 0.01:
+         return " **"
+     elif p < 0.05:
+         return " *"
+     return ""
+
+
+ def generate_significance_table(
+     significance_results: dict[tuple[str, str], "SignificanceResult"],
+     format: str = "latex",
+ ) -> str:
+     """Generate a table of pairwise significance results.
+
+     Args:
+         significance_results: Pairwise significance test results
+         format: Output format ("latex" or "markdown")
+
+     Returns:
+         Formatted table as string
+     """
+     if format == "latex":
+         return _generate_significance_latex(significance_results)
+     else:
+         return _generate_significance_markdown(significance_results)
+
+
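For example, given pairwise test results keyed by `(name_a, name_b)` tuples (called `sig_results` here; producing them is the job of the statistics utilities, not this module):

    print(generate_significance_table(sig_results, format="markdown"))
    tex = generate_significance_table(sig_results, format="latex")
    # Any value other than "latex" currently falls back to the markdown format.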
+ def _generate_significance_latex(
+     significance_results: dict[tuple[str, str], "SignificanceResult"],
+ ) -> str:
+     """Generate LaTeX significance table."""
+     lines = []
+     lines.append(r"\begin{table}[ht]")
+     lines.append(r"\centering")
+     lines.append(r"\caption{Pairwise Significance Tests}")
+     lines.append(r"\begin{tabular}{llcccc}")
+     lines.append(r"\toprule")
+     lines.append(r"Method A & Method B & Statistic & p-value & Effect Size & Sig. \\")
+     lines.append(r"\midrule")
+
+     for (name_a, name_b), result in significance_results.items():
+         sig_str = "Yes" if result.significant else "No"
+         escaped_a = name_a.replace("_", r"\_")
+         escaped_b = name_b.replace("_", r"\_")
+         lines.append(
+             rf"{escaped_a} & {escaped_b} & {result.statistic:.3f} & "
+             rf"{result.p_value:.4f} & {result.effect_size:.3f} & {sig_str} \\"
+         )
+
+     lines.append(r"\bottomrule")
+     lines.append(r"\end{tabular}")
+     lines.append(r"\end{table}")
+
+     return "\n".join(lines)
+
+
+ def _generate_significance_markdown(
+     significance_results: dict[tuple[str, str], "SignificanceResult"],
+ ) -> str:
+     """Generate markdown significance table."""
+     lines = []
+     lines.append("| Method A | Method B | Statistic | p-value | Effect Size | Sig. |")
+     lines.append("|----------|----------|-----------|---------|-------------|------|")
+
+     for (name_a, name_b), result in significance_results.items():
+         sig_str = "Yes" if result.significant else "No"
+         lines.append(
+             f"| {name_a} | {name_b} | {result.statistic:.3f} | "
+             f"{result.p_value:.4f} | {result.effect_size:.3f} | {sig_str} |"
+         )
+
+     return "\n".join(lines)
+
+
+ def save_experiment_report(
+     results: dict[str, "AggregatedResults"],
+     output_dir: str | Path,
+     experiment_name: str,
+     significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
+     metric: str = "squared_error",
+ ) -> dict[str, Path]:
+     """Save a complete experiment report with all artifacts.
+
+     Args:
+         results: Dictionary mapping config name to AggregatedResults
+         output_dir: Directory to save artifacts
+         experiment_name: Name for the experiment (used in filenames)
+         significance_results: Optional pairwise significance test results
+         metric: Primary metric to report
+
+     Returns:
+         Dictionary mapping artifact type to file path
+     """
+     output_dir = Path(output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     artifacts: dict[str, Path] = {}
+
+     # Export summary CSV
+     csv_path = output_dir / f"{experiment_name}_summary.csv"
+     export_to_csv(results, csv_path, metric=metric)
+     artifacts["summary_csv"] = csv_path
+
+     # Export JSON
+     json_path = output_dir / f"{experiment_name}_results.json"
+     export_to_json(results, json_path, include_timeseries=False)
+     artifacts["json"] = json_path
+
+     # Generate LaTeX table
+     latex_path = output_dir / f"{experiment_name}_table.tex"
+     latex_content = generate_latex_table(
+         results,
+         significance_results=significance_results,
+         metric=metric,
+         caption=f"{experiment_name} Results",
+         label=f"tab:{experiment_name}",
+     )
+     with open(latex_path, "w") as f:
+         f.write(latex_content)
+     artifacts["latex_table"] = latex_path
+
+     # Generate markdown table
+     md_path = output_dir / f"{experiment_name}_table.md"
+     md_content = generate_markdown_table(
+         results,
+         significance_results=significance_results,
+         metric=metric,
+     )
+     with open(md_path, "w") as f:
+         f.write(md_content)
+     artifacts["markdown_table"] = md_path
+
+     # If significance results provided, save those too
+     if significance_results:
+         sig_latex_path = output_dir / f"{experiment_name}_significance.tex"
+         sig_latex = generate_significance_table(significance_results, format="latex")
+         with open(sig_latex_path, "w") as f:
+             f.write(sig_latex)
+         artifacts["significance_latex"] = sig_latex_path
+
+         sig_md_path = output_dir / f"{experiment_name}_significance.md"
+         sig_md = generate_significance_table(significance_results, format="markdown")
+         with open(sig_md_path, "w") as f:
+             f.write(sig_md)
+         artifacts["significance_md"] = sig_md_path
+
+     return artifacts
+
+
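A self-contained sketch of the full report pipeline. Because this module only reads three attributes of `AggregatedResults` (`seeds`, `summary`, `metric_arrays`), the snippet fakes that interface with `SimpleNamespace`; the real objects come from `alberta_framework.utils.experiments`, and the helper, method names, and numbers below are purely illustrative:

    import numpy as np
    from types import SimpleNamespace

    def _fake_aggregated(arr: np.ndarray, seeds: list[int]) -> SimpleNamespace:
        # Stand-in exposing only the attributes the exporters above actually use.
        per_seed = arr.mean(axis=1)
        summary = SimpleNamespace(
            mean=float(per_seed.mean()), std=float(per_seed.std()),
            min=float(per_seed.min()), max=float(per_seed.max()),
            n_seeds=len(seeds), values=per_seed,
        )
        return SimpleNamespace(
            seeds=seeds,
            metric_arrays={"squared_error": arr},
            summary={"squared_error": summary},
        )

    rng = np.random.default_rng(0)
    results = {
        name: _fake_aggregated(rng.random((3, 50)), seeds=[0, 1, 2])
        for name in ("method_a", "method_b")
    }
    artifacts = save_experiment_report(results, "report", "demo")
    # artifacts maps "summary_csv", "json", "latex_table", "markdown_table" to the written paths.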
+ def results_to_dataframe(
+     results: dict[str, "AggregatedResults"],
+     metric: str = "squared_error",
+ ) -> "pd.DataFrame":
+     """Convert results to a pandas DataFrame.
+
+     Requires pandas to be installed.
+
+     Args:
+         results: Dictionary mapping config name to AggregatedResults
+         metric: Metric to include
+
+     Returns:
+         DataFrame with results
+     """
+     try:
+         import pandas as pd
+     except ImportError as exc:
+         raise ImportError("pandas is required. Install with: pip install pandas") from exc
+
+     rows = []
+     for name, agg in results.items():
+         summary = agg.summary[metric]
+         rows.append({
+             "method": name,
+             "mean": summary.mean,
+             "std": summary.std,
+             "min": summary.min,
+             "max": summary.max,
+             "n_seeds": summary.n_seeds,
+         })
+
+     return pd.DataFrame(rows)
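With pandas installed, the same `results` mapping converts to a DataFrame for quick sorting and filtering:

    df = results_to_dataframe(results, metric="squared_error")
    print(df.sort_values("mean").to_string(index=False))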
@@ -0,0 +1,112 @@
+ """Metrics and analysis utilities for continual learning experiments.
+
+ Provides functions for computing tracking error, learning curves,
+ and other metrics useful for evaluating continual learners.
+ """
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+
+ def compute_cumulative_error(
+     metrics_history: list[dict[str, float]],
+     error_key: str = "squared_error",
+ ) -> NDArray[np.float64]:
+     """Compute cumulative error over time.
+
+     Args:
+         metrics_history: List of metric dictionaries from learning loop
+         error_key: Key to extract error values
+
+     Returns:
+         Array of cumulative errors at each time step
+     """
+     errors = np.array([m[error_key] for m in metrics_history])
+     return np.cumsum(errors)
+
+
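A tiny worked example; the history format is just the list of per-step metric dicts emitted by a learning loop:

    history = [{"squared_error": e} for e in (1.0, 0.5, 0.25)]
    compute_cumulative_error(history)  # array([1.  , 1.5 , 1.75])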
+ def compute_running_mean(
+     values: NDArray[np.float64] | list[float],
+     window_size: int = 100,
+ ) -> NDArray[np.float64]:
+     """Compute running mean of values.
+
+     Args:
+         values: Array of values
+         window_size: Size of the moving average window
+
+     Returns:
+         Array of running mean values (same length as input, padded at start)
+     """
+     values_arr = np.asarray(values, dtype=np.float64)
+     if values_arr.size == 0:
+         return values_arr
+
+     # Clamp the window so inputs shorter than window_size still get a defined mean
+     window_size = max(1, min(window_size, values_arr.size))
+     cumsum = np.cumsum(np.insert(values_arr, 0, 0.0))
+     running_mean = (cumsum[window_size:] - cumsum[:-window_size]) / window_size
+
+     # Pad the beginning with the first computed mean to keep the input length
+     padding = np.full(window_size - 1, running_mean[0])
+     return np.concatenate([padding, running_mean])
+
+
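For example, with a window of 2 the output keeps the input length because the front is padded with the first computed mean:

    compute_running_mean([1.0, 3.0, 5.0, 7.0], window_size=2)
    # -> array([2., 2., 4., 6.])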
+ def compute_tracking_error(
+     metrics_history: list[dict[str, float]],
+     window_size: int = 100,
+ ) -> NDArray[np.float64]:
+     """Compute tracking error (running mean of squared error).
+
+     This is the key metric for evaluating continual learners:
+     how well can the learner track the non-stationary target?
+
+     Args:
+         metrics_history: List of metric dictionaries from learning loop
+         window_size: Size of the moving average window
+
+     Returns:
+         Array of tracking errors at each time step
+     """
+     errors = np.array([m["squared_error"] for m in metrics_history])
+     return compute_running_mean(errors, window_size)
+
+
+ def extract_metric(
+     metrics_history: list[dict[str, float]],
+     key: str,
+ ) -> NDArray[np.float64]:
+     """Extract a single metric from the history.
+
+     Args:
+         metrics_history: List of metric dictionaries
+         key: Key to extract
+
+     Returns:
+         Array of values for that metric
+     """
+     return np.array([m[key] for m in metrics_history])
+
+
+ def compare_learners(
+     results: dict[str, list[dict[str, float]]],
+     metric: str = "squared_error",
+ ) -> dict[str, dict[str, float]]:
+     """Compare multiple learners on a given metric.
+
+     Args:
+         results: Dictionary mapping learner name to metrics history
+         metric: Metric to compare
+
+     Returns:
+         Dictionary with summary statistics for each learner
+     """
+     summary = {}
+     for name, metrics_history in results.items():
+         values = extract_metric(metrics_history, metric)
+         summary[name] = {
+             "mean": float(np.mean(values)),
+             "std": float(np.std(values)),
+             "cumulative": float(np.sum(values)),
+             "final_100_mean": (
+                 float(np.mean(values[-100:])) if len(values) >= 100 else float(np.mean(values))
+             ),
+         }
+     return summary
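Finally, a small self-contained comparison of two synthetic learners; the names and error values are made up for illustration:

    import numpy as np

    rng = np.random.default_rng(0)
    histories = {
        "method_a": [{"squared_error": float(e)} for e in rng.random(500) * 0.5],
        "method_b": [{"squared_error": float(e)} for e in rng.random(500)],
    }
    stats = compare_learners(histories, metric="squared_error")
    print(stats["method_a"]["mean"], stats["method_b"]["final_100_mean"])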