alberta-framework 0.1.0 (alberta_framework-0.1.0-py3-none-any.whl)
This diff shows the content of a publicly available package version as released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- alberta_framework/__init__.py +196 -0
- alberta_framework/core/__init__.py +27 -0
- alberta_framework/core/learners.py +530 -0
- alberta_framework/core/normalizers.py +192 -0
- alberta_framework/core/optimizers.py +422 -0
- alberta_framework/core/types.py +198 -0
- alberta_framework/py.typed +0 -0
- alberta_framework/streams/__init__.py +83 -0
- alberta_framework/streams/base.py +70 -0
- alberta_framework/streams/gymnasium.py +655 -0
- alberta_framework/streams/synthetic.py +995 -0
- alberta_framework/utils/__init__.py +113 -0
- alberta_framework/utils/experiments.py +334 -0
- alberta_framework/utils/export.py +509 -0
- alberta_framework/utils/metrics.py +112 -0
- alberta_framework/utils/statistics.py +527 -0
- alberta_framework/utils/timing.py +138 -0
- alberta_framework/utils/visualization.py +571 -0
- alberta_framework-0.1.0.dist-info/METADATA +198 -0
- alberta_framework-0.1.0.dist-info/RECORD +22 -0
- alberta_framework-0.1.0.dist-info/WHEEL +4 -0
- alberta_framework-0.1.0.dist-info/licenses/LICENSE +190 -0
@@ -0,0 +1,509 @@ alberta_framework/utils/export.py

```python
"""Export utilities for experiment results.

Provides functions for exporting results to CSV, JSON, LaTeX tables,
and markdown, suitable for academic publications.
"""

import csv
import json
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pandas as pd

    from alberta_framework.utils.experiments import AggregatedResults
    from alberta_framework.utils.statistics import SignificanceResult


def export_to_csv(
    results: dict[str, "AggregatedResults"],
    filepath: str | Path,
    metric: str = "squared_error",
    include_timeseries: bool = False,
) -> None:
    """Export results to CSV file.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        filepath: Path to output CSV file
        metric: Metric to export
        include_timeseries: Whether to include full timeseries (large!)
    """
    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    if include_timeseries:
        _export_timeseries_csv(results, filepath, metric)
    else:
        _export_summary_csv(results, filepath, metric)


def _export_summary_csv(
    results: dict[str, "AggregatedResults"],
    filepath: Path,
    metric: str,
) -> None:
    """Export summary statistics to CSV."""
    with open(filepath, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["config", "mean", "std", "min", "max", "n_seeds"])

        for name, agg in results.items():
            summary = agg.summary[metric]
            writer.writerow([
                name,
                f"{summary.mean:.6f}",
                f"{summary.std:.6f}",
                f"{summary.min:.6f}",
                f"{summary.max:.6f}",
                summary.n_seeds,
            ])


def _export_timeseries_csv(
    results: dict[str, "AggregatedResults"],
    filepath: Path,
    metric: str,
) -> None:
    """Export full timeseries to CSV."""
    with open(filepath, "w", newline="") as f:
        writer = csv.writer(f)

        # Determine max steps
        max_steps = max(agg.metric_arrays[metric].shape[1] for agg in results.values())

        # Header
        headers = ["step"]
        for name, agg in results.items():
            for seed in agg.seeds:
                headers.append(f"{name}_seed{seed}")
        writer.writerow(headers)

        # Data rows
        for step in range(max_steps):
            row: list[str | int] = [step]
            for agg in results.values():
                arr = agg.metric_arrays[metric]
                n_seeds = arr.shape[0]
                n_steps = arr.shape[1]
                for seed_idx in range(n_seeds):
                    if step < n_steps:
                        row.append(f"{arr[seed_idx, step]:.6f}")
                    else:
                        row.append("")
            writer.writerow(row)


def export_to_json(
    results: dict[str, "AggregatedResults"],
    filepath: str | Path,
    include_timeseries: bool = False,
) -> None:
    """Export results to JSON file.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        filepath: Path to output JSON file
        include_timeseries: Whether to include full timeseries (large!)
    """
    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    from typing import Any

    data: dict[str, Any] = {}
    for name, agg in results.items():
        summary_data: dict[str, dict[str, Any]] = {}
        for metric_name, summary in agg.summary.items():
            summary_data[metric_name] = {
                "mean": summary.mean,
                "std": summary.std,
                "min": summary.min,
                "max": summary.max,
                "n_seeds": summary.n_seeds,
                "values": summary.values.tolist(),
            }

        config_data: dict[str, Any] = {
            "seeds": agg.seeds,
            "summary": summary_data,
        }

        if include_timeseries:
            config_data["timeseries"] = {
                metric: arr.tolist() for metric, arr in agg.metric_arrays.items()
            }

        data[name] = config_data

    with open(filepath, "w") as f:
        json.dump(data, f, indent=2)
```
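A minimal usage sketch for the CSV/JSON exporters above. Real `AggregatedResults` objects come from `alberta_framework.utils.experiments`; the `SimpleNamespace` stand-ins, config names, and metric values below are hypothetical and only supply the attributes this module actually reads (`summary`, `seeds`, `metric_arrays`).

```python
import numpy as np
from types import SimpleNamespace

from alberta_framework.utils.export import export_to_csv, export_to_json


def fake_aggregated(seed_errors: np.ndarray) -> SimpleNamespace:
    # Hypothetical stand-in exposing only what export.py reads:
    # .summary[metric] with mean/std/min/max/n_seeds/values, .seeds, .metric_arrays[metric]
    per_seed_means = seed_errors.mean(axis=1)
    stats = SimpleNamespace(
        mean=float(per_seed_means.mean()),
        std=float(per_seed_means.std()),
        min=float(per_seed_means.min()),
        max=float(per_seed_means.max()),
        n_seeds=seed_errors.shape[0],
        values=per_seed_means,  # JSON export calls .tolist() on this
    )
    return SimpleNamespace(
        summary={"squared_error": stats},
        seeds=list(range(seed_errors.shape[0])),
        metric_arrays={"squared_error": seed_errors},  # shape (n_seeds, n_steps)
    )


rng = np.random.default_rng(0)
results = {
    "method_a": fake_aggregated(rng.random((3, 500))),
    "method_b": fake_aggregated(0.5 * rng.random((3, 500))),
}

export_to_csv(results, "out/summary.csv")                        # summary rows only
export_to_csv(results, "out/timeseries.csv", include_timeseries=True)
export_to_json(results, "out/results.json")
```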
alberta_framework/utils/export.py (continued):

```python
def generate_latex_table(
    results: dict[str, "AggregatedResults"],
    significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
    metric: str = "squared_error",
    caption: str = "Experimental Results",
    label: str = "tab:results",
    metric_label: str = "Error",
    lower_is_better: bool = True,
) -> str:
    """Generate a LaTeX table of results.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        significance_results: Optional pairwise significance test results
        metric: Metric to display
        caption: Table caption
        label: LaTeX label for the table
        metric_label: Human-readable name for the metric
        lower_is_better: Whether lower metric values are better

    Returns:
        LaTeX table as a string
    """
    lines = []
    lines.append(r"\begin{table}[ht]")
    lines.append(r"\centering")
    lines.append(r"\caption{" + caption + "}")
    lines.append(r"\label{" + label + "}")
    lines.append(r"\begin{tabular}{lcc}")
    lines.append(r"\toprule")
    lines.append(r"Method & " + metric_label + r" & Seeds \\")
    lines.append(r"\midrule")

    # Find best result
    summaries = {name: agg.summary[metric] for name, agg in results.items()}
    if lower_is_better:
        best_name = min(summaries.keys(), key=lambda k: summaries[k].mean)
    else:
        best_name = max(summaries.keys(), key=lambda k: summaries[k].mean)

    for name, agg in results.items():
        summary = agg.summary[metric]
        mean_str = f"{summary.mean:.4f}"
        std_str = f"{summary.std:.4f}"

        # Bold if best
        if name == best_name:
            value_str = rf"\textbf{{{mean_str}}} $\pm$ {std_str}"
        else:
            value_str = rf"{mean_str} $\pm$ {std_str}"

        # Add significance marker if provided
        if significance_results:
            sig_marker = _get_significance_marker(name, best_name, significance_results)
            value_str += sig_marker

        # Escape underscores in method name
        escaped_name = name.replace("_", r"\_")
        lines.append(rf"{escaped_name} & {value_str} & {summary.n_seeds} \\")

    lines.append(r"\bottomrule")
    lines.append(r"\end{tabular}")

    if significance_results:
        lines.append(r"\vspace{0.5em}")
        lines.append(r"\footnotesize{$^*$ $p < 0.05$, $^{**}$ $p < 0.01$, $^{***}$ $p < 0.001$}")

    lines.append(r"\end{table}")

    return "\n".join(lines)


def _get_significance_marker(
    name: str,
    best_name: str,
    significance_results: dict[tuple[str, str], "SignificanceResult"],
) -> str:
    """Get significance marker for comparison with best method."""
    if name == best_name:
        return ""

    # Find comparison result
    key1 = (name, best_name)
    key2 = (best_name, name)

    if key1 in significance_results:
        result = significance_results[key1]
    elif key2 in significance_results:
        result = significance_results[key2]
    else:
        return ""

    if not result.significant:
        return ""

    p = result.p_value
    if p < 0.001:
        return r"$^{***}$"
    elif p < 0.01:
        return r"$^{**}$"
    elif p < 0.05:
        return r"$^{*}$"
    return ""
```
alberta_framework/utils/export.py (continued):

```python
def generate_markdown_table(
    results: dict[str, "AggregatedResults"],
    significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
    metric: str = "squared_error",
    metric_label: str = "Error",
    lower_is_better: bool = True,
) -> str:
    """Generate a markdown table of results.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        significance_results: Optional pairwise significance test results
        metric: Metric to display
        metric_label: Human-readable name for the metric
        lower_is_better: Whether lower metric values are better

    Returns:
        Markdown table as a string
    """
    lines = []
    lines.append(f"| Method | {metric_label} (mean ± std) | Seeds |")
    lines.append("|--------|-------------------------|-------|")

    # Find best result
    summaries = {name: agg.summary[metric] for name, agg in results.items()}
    if lower_is_better:
        best_name = min(summaries.keys(), key=lambda k: summaries[k].mean)
    else:
        best_name = max(summaries.keys(), key=lambda k: summaries[k].mean)

    for name, agg in results.items():
        summary = agg.summary[metric]
        mean_str = f"{summary.mean:.4f}"
        std_str = f"{summary.std:.4f}"

        # Bold if best
        if name == best_name:
            value_str = f"**{mean_str}** ± {std_str}"
        else:
            value_str = f"{mean_str} ± {std_str}"

        # Add significance marker if provided
        if significance_results:
            sig_marker = _get_md_significance_marker(name, best_name, significance_results)
            value_str += sig_marker

        lines.append(f"| {name} | {value_str} | {summary.n_seeds} |")

    if significance_results:
        lines.append("")
        lines.append("\\* p < 0.05, \\*\\* p < 0.01, \\*\\*\\* p < 0.001")

    return "\n".join(lines)


def _get_md_significance_marker(
    name: str,
    best_name: str,
    significance_results: dict[tuple[str, str], "SignificanceResult"],
) -> str:
    """Get significance marker for markdown."""
    if name == best_name:
        return ""

    key1 = (name, best_name)
    key2 = (best_name, name)

    if key1 in significance_results:
        result = significance_results[key1]
    elif key2 in significance_results:
        result = significance_results[key2]
    else:
        return ""

    if not result.significant:
        return ""

    p = result.p_value
    if p < 0.001:
        return " ***"
    elif p < 0.01:
        return " **"
    elif p < 0.05:
        return " *"
    return ""


def generate_significance_table(
    significance_results: dict[tuple[str, str], "SignificanceResult"],
    format: str = "latex",
) -> str:
    """Generate a table of pairwise significance results.

    Args:
        significance_results: Pairwise significance test results
        format: Output format ("latex" or "markdown")

    Returns:
        Formatted table as string
    """
    if format == "latex":
        return _generate_significance_latex(significance_results)
    else:
        return _generate_significance_markdown(significance_results)


def _generate_significance_latex(
    significance_results: dict[tuple[str, str], "SignificanceResult"],
) -> str:
    """Generate LaTeX significance table."""
    lines = []
    lines.append(r"\begin{table}[ht]")
    lines.append(r"\centering")
    lines.append(r"\caption{Pairwise Significance Tests}")
    lines.append(r"\begin{tabular}{llcccc}")
    lines.append(r"\toprule")
    lines.append(r"Method A & Method B & Statistic & p-value & Effect Size & Sig. \\")
    lines.append(r"\midrule")

    for (name_a, name_b), result in significance_results.items():
        sig_str = "Yes" if result.significant else "No"
        escaped_a = name_a.replace("_", r"\_")
        escaped_b = name_b.replace("_", r"\_")
        lines.append(
            rf"{escaped_a} & {escaped_b} & {result.statistic:.3f} & "
            rf"{result.p_value:.4f} & {result.effect_size:.3f} & {sig_str} \\"
        )

    lines.append(r"\bottomrule")
    lines.append(r"\end{tabular}")
    lines.append(r"\end{table}")

    return "\n".join(lines)


def _generate_significance_markdown(
    significance_results: dict[tuple[str, str], "SignificanceResult"],
) -> str:
    """Generate markdown significance table."""
    lines = []
    lines.append("| Method A | Method B | Statistic | p-value | Effect Size | Sig. |")
    lines.append("|----------|----------|-----------|---------|-------------|------|")

    for (name_a, name_b), result in significance_results.items():
        sig_str = "Yes" if result.significant else "No"
        lines.append(
            f"| {name_a} | {name_b} | {result.statistic:.3f} | "
            f"{result.p_value:.4f} | {result.effect_size:.3f} | {sig_str} |"
        )

    return "\n".join(lines)


def save_experiment_report(
    results: dict[str, "AggregatedResults"],
    output_dir: str | Path,
    experiment_name: str,
    significance_results: dict[tuple[str, str], "SignificanceResult"] | None = None,
    metric: str = "squared_error",
) -> dict[str, Path]:
    """Save a complete experiment report with all artifacts.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        output_dir: Directory to save artifacts
        experiment_name: Name for the experiment (used in filenames)
        significance_results: Optional pairwise significance test results
        metric: Primary metric to report

    Returns:
        Dictionary mapping artifact type to file path
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    artifacts: dict[str, Path] = {}

    # Export summary CSV
    csv_path = output_dir / f"{experiment_name}_summary.csv"
    export_to_csv(results, csv_path, metric=metric)
    artifacts["summary_csv"] = csv_path

    # Export JSON
    json_path = output_dir / f"{experiment_name}_results.json"
    export_to_json(results, json_path, include_timeseries=False)
    artifacts["json"] = json_path

    # Generate LaTeX table
    latex_path = output_dir / f"{experiment_name}_table.tex"
    latex_content = generate_latex_table(
        results,
        significance_results=significance_results,
        metric=metric,
        caption=f"{experiment_name} Results",
        label=f"tab:{experiment_name}",
    )
    with open(latex_path, "w") as f:
        f.write(latex_content)
    artifacts["latex_table"] = latex_path

    # Generate markdown table
    md_path = output_dir / f"{experiment_name}_table.md"
    md_content = generate_markdown_table(
        results,
        significance_results=significance_results,
        metric=metric,
    )
    with open(md_path, "w") as f:
        f.write(md_content)
    artifacts["markdown_table"] = md_path

    # If significance results provided, save those too
    if significance_results:
        sig_latex_path = output_dir / f"{experiment_name}_significance.tex"
        sig_latex = generate_significance_table(significance_results, format="latex")
        with open(sig_latex_path, "w") as f:
            f.write(sig_latex)
        artifacts["significance_latex"] = sig_latex_path

        sig_md_path = output_dir / f"{experiment_name}_significance.md"
        sig_md = generate_significance_table(significance_results, format="markdown")
        with open(sig_md_path, "w") as f:
            f.write(sig_md)
        artifacts["significance_md"] = sig_md_path

    return artifacts
```
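A sketch of the one-call report path, again with the hypothetical `results` stand-ins from the earlier example. When a `significance_results` dict is passed, its keys are `(name_a, name_b)` pairs and its values need the `statistic`, `p_value`, `effect_size`, and `significant` attributes that the table generators read.

```python
from alberta_framework.utils.export import save_experiment_report

artifacts = save_experiment_report(
    results,                      # hypothetical stand-ins from the earlier sketch
    output_dir="reports",
    experiment_name="tracking_demo",
    metric="squared_error",
)
for kind, path in artifacts.items():
    print(f"{kind}: {path}")
# Expected keys: summary_csv, json, latex_table, markdown_table
# (plus significance_latex / significance_md when significance_results is given)
```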
alberta_framework/utils/export.py (continued):

```python
def results_to_dataframe(
    results: dict[str, "AggregatedResults"],
    metric: str = "squared_error",
) -> "pd.DataFrame":
    """Convert results to a pandas DataFrame.

    Requires pandas to be installed.

    Args:
        results: Dictionary mapping config name to AggregatedResults
        metric: Metric to include

    Returns:
        DataFrame with results
    """
    try:
        import pandas as pd
    except ImportError:
        raise ImportError("pandas is required. Install with: pip install pandas")

    rows = []
    for name, agg in results.items():
        summary = agg.summary[metric]
        rows.append({
            "method": name,
            "mean": summary.mean,
            "std": summary.std,
            "min": summary.min,
            "max": summary.max,
            "n_seeds": summary.n_seeds,
        })

    return pd.DataFrame(rows)
```
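For interactive analysis, the same summary can be pulled into pandas (an optional dependency here); a brief sketch with the same hypothetical `results`:

```python
from alberta_framework.utils.export import results_to_dataframe

df = results_to_dataframe(results, metric="squared_error")
print(df.sort_values("mean"))  # columns: method, mean, std, min, max, n_seeds
```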
@@ -0,0 +1,112 @@ alberta_framework/utils/metrics.py

```python
"""Metrics and analysis utilities for continual learning experiments.

Provides functions for computing tracking error, learning curves,
and other metrics useful for evaluating continual learners.
"""

import numpy as np
from numpy.typing import NDArray


def compute_cumulative_error(
    metrics_history: list[dict[str, float]],
    error_key: str = "squared_error",
) -> NDArray[np.float64]:
    """Compute cumulative error over time.

    Args:
        metrics_history: List of metric dictionaries from learning loop
        error_key: Key to extract error values

    Returns:
        Array of cumulative errors at each time step
    """
    errors = np.array([m[error_key] for m in metrics_history])
    return np.cumsum(errors)


def compute_running_mean(
    values: NDArray[np.float64] | list[float],
    window_size: int = 100,
) -> NDArray[np.float64]:
    """Compute running mean of values.

    Args:
        values: Array of values
        window_size: Size of the moving average window

    Returns:
        Array of running mean values (same length as input, padded at start)
    """
    values_arr = np.asarray(values)
    cumsum = np.cumsum(np.insert(values_arr, 0, 0))
    running_mean = (cumsum[window_size:] - cumsum[:-window_size]) / window_size

    # Pad the beginning with the first computed mean
    if len(running_mean) > 0:
        padding = np.full(window_size - 1, running_mean[0])
        return np.concatenate([padding, running_mean])
    return values_arr
```
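A small numeric check of the windowed mean above, with made-up values: for `window_size=3`, each position from index 2 on holds the mean of the last three values, and the first two positions are padded with the first full-window mean.

```python
from alberta_framework.utils.metrics import compute_running_mean

values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(compute_running_mean(values, window_size=3))
# [2. 2. 2. 3. 4.]  -> two padded entries, then means of (1,2,3), (2,3,4), (3,4,5)
```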
alberta_framework/utils/metrics.py (continued):

```python
def compute_tracking_error(
    metrics_history: list[dict[str, float]],
    window_size: int = 100,
) -> NDArray[np.float64]:
    """Compute tracking error (running mean of squared error).

    This is the key metric for evaluating continual learners:
    how well can the learner track the non-stationary target?

    Args:
        metrics_history: List of metric dictionaries from learning loop
        window_size: Size of the moving average window

    Returns:
        Array of tracking errors at each time step
    """
    errors = np.array([m["squared_error"] for m in metrics_history])
    return compute_running_mean(errors, window_size)


def extract_metric(
    metrics_history: list[dict[str, float]],
    key: str,
) -> NDArray[np.float64]:
    """Extract a single metric from the history.

    Args:
        metrics_history: List of metric dictionaries
        key: Key to extract

    Returns:
        Array of values for that metric
    """
    return np.array([m[key] for m in metrics_history])


def compare_learners(
    results: dict[str, list[dict[str, float]]],
    metric: str = "squared_error",
) -> dict[str, dict[str, float]]:
    """Compare multiple learners on a given metric.

    Args:
        results: Dictionary mapping learner name to metrics history
        metric: Metric to compare

    Returns:
        Dictionary with summary statistics for each learner
    """
    summary = {}
    for name, metrics_history in results.items():
        values = extract_metric(metrics_history, metric)
        summary[name] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
            "cumulative": float(np.sum(values)),
            "final_100_mean": (
                float(np.mean(values[-100:])) if len(values) >= 100 else float(np.mean(values))
            ),
        }
    return summary
```