claude-turing 1.5.0 → 2.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +5 -2
- package/commands/export.md +48 -0
- package/commands/lit.md +47 -0
- package/commands/paper.md +44 -0
- package/commands/turing.md +6 -0
- package/package.json +1 -1
- package/src/install.js +2 -1
- package/src/verify.js +3 -0
- package/templates/scripts/__pycache__/draft_paper_sections.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/equivalence_checker.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/export_card.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/export_formats.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/latency_benchmark.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/literature_search.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/draft_paper_sections.py +498 -0
- package/templates/scripts/equivalence_checker.py +158 -0
- package/templates/scripts/export_card.py +183 -0
- package/templates/scripts/export_formats.py +385 -0
- package/templates/scripts/export_model.py +324 -0
- package/templates/scripts/latency_benchmark.py +167 -0
- package/templates/scripts/literature_search.py +421 -0
- package/templates/scripts/scaffold.py +10 -0
@@ -0,0 +1,498 @@
#!/usr/bin/env python3
"""Paper section drafting from experiment logs.

Drafts the mechanical sections of an ML paper directly from experiment
data: experimental setup, results tables, ablation tables, and
hyperparameter appendices. Eliminates transcription errors.

Usage:
    python scripts/draft_paper_sections.py                           # All sections
    python scripts/draft_paper_sections.py --sections setup,results  # Specific sections
    python scripts/draft_paper_sections.py --format latex            # LaTeX output
    python scripts/draft_paper_sections.py --format markdown         # Markdown output
"""

from __future__ import annotations

import argparse
import json
import sys
from datetime import datetime, timezone
from pathlib import Path

import yaml

from scripts.turing_io import load_config, load_experiments


VALID_SECTIONS = ["setup", "results", "ablation", "hyperparameters"]
DEFAULT_FORMAT = "latex"


def load_seed_studies_for_paper(seed_dir: str = "experiments/seed_studies") -> dict[str, dict]:
    """Load all seed studies indexed by experiment ID."""
    path = Path(seed_dir)
    studies = {}
    if not path.exists():
        return studies
    for f in path.glob("*-seeds.yaml"):
        try:
            with open(f) as fh:
                study = yaml.safe_load(fh)
            if study and isinstance(study, dict) and "experiment_id" in study:
                studies[study["experiment_id"]] = study
        except (yaml.YAMLError, OSError):
            continue
    return studies


def load_ablation_studies(ablation_dir: str = "experiments/ablations") -> dict[str, dict]:
    """Load all ablation studies indexed by experiment ID."""
    path = Path(ablation_dir)
    studies = {}
    if not path.exists():
        return studies
    for f in path.glob("*-ablation.yaml"):
        try:
            with open(f) as fh:
                study = yaml.safe_load(fh)
            if study and isinstance(study, dict) and "experiment_id" in study:
                studies[study["experiment_id"]] = study
        except (yaml.YAMLError, OSError):
            continue
    return studies


def get_top_experiments(
    experiments: list[dict],
    metric: str,
    lower_is_better: bool,
    top_k: int = 10,
) -> list[dict]:
    """Get top-K kept experiments by primary metric."""
    kept = [e for e in experiments if e.get("status") == "kept" and e.get("metrics", {}).get(metric) is not None]
    kept.sort(key=lambda e: e["metrics"][metric], reverse=not lower_is_better)
    return kept[:top_k]


def group_by_model_type(experiments: list[dict]) -> dict[str, list[dict]]:
    """Group experiments by model type, keeping best per type."""
    groups: dict[str, list[dict]] = {}
    for exp in experiments:
        mt = exp.get("config", {}).get("model_type", "unknown")
        groups.setdefault(mt, []).append(exp)
    return groups


def draft_setup_section(
    config: dict,
    experiments: list[dict],
    seed_studies: dict[str, dict],
    output_format: str = "latex",
) -> str:
    """Draft the experimental setup section."""
    eval_cfg = config.get("evaluation", {})
    data_cfg = config.get("data", {})
    primary_metric = eval_cfg.get("primary_metric", "accuracy")
    metrics = eval_cfg.get("metrics", [primary_metric])
    lower_is_better = eval_cfg.get("lower_is_better", False)

    task_desc = config.get("task_description", "the classification task")
    data_source = data_cfg.get("source", "the provided dataset")
    split_ratios = data_cfg.get("split_ratios", {})
    random_state = data_cfg.get("random_state", 42)

    # Determine seed study info
    n_seeds = 0
    for study in seed_studies.values():
        n_seeds = max(n_seeds, len(study.get("seeds_run", [])))

    # Build prose
    split_text = ""
    if split_ratios:
        parts = [f"{int(v*100)}\\%" if output_format == "latex" else f"{int(v*100)}%" for v in split_ratios.values()]
        split_names = list(split_ratios.keys())
        split_text = "/".join(parts) + " " + "/".join(split_names) + " split"

    direction = "lower" if lower_is_better else "higher"
    metric_list = ", ".join(metrics)

    if output_format == "latex":
        lines = [
            r"\subsection{Experimental Setup}",
            "",
            f"We evaluate on {data_source} using {metric_list} as evaluation metrics "
            f"({direction} is better for {primary_metric}).",
        ]
        if split_text:
            lines.append(f"Data is partitioned using a {split_text} with random state {random_state}.")
        if n_seeds > 0:
            lines.append(
                f"Results are reported as mean $\\pm$ standard deviation over {n_seeds} random seeds "
                f"to account for seed sensitivity."
            )
        lines.append(f"All experiments use {task_desc} as the target task.")
    else:
        lines = [
            "## Experimental Setup",
            "",
            f"We evaluate on {data_source} using {metric_list} as evaluation metrics "
            f"({direction} is better for {primary_metric}).",
        ]
        if split_text:
            lines.append(f"Data is partitioned using a {split_text} with random state {random_state}.")
        if n_seeds > 0:
            lines.append(
                f"Results are reported as mean +/- standard deviation over {n_seeds} random seeds "
                f"to account for seed sensitivity."
            )
        lines.append(f"All experiments use {task_desc} as the target task.")

    return "\n".join(lines)


def draft_results_table(
    experiments: list[dict],
    metrics: list[str],
    primary_metric: str,
    lower_is_better: bool,
    seed_studies: dict[str, dict],
    output_format: str = "latex",
    dataset_name: str = "the dataset",
) -> str:
    """Draft the results comparison table."""
    # Group by model type, take best per type
    groups = group_by_model_type(experiments)
    rows = []
    for mt, exps in groups.items():
        # Best experiment per model type
        best = exps[0]  # Already sorted by caller
        row = {"model_type": mt, "experiment_id": best.get("experiment_id", "?")}
        for m in metrics:
            val = best.get("metrics", {}).get(m)
            seed = seed_studies.get(best.get("experiment_id", ""))
            if seed and seed.get("metric") == m:
                row[m] = {"value": val, "mean": seed.get("mean"), "std": seed.get("std")}
            else:
                row[m] = {"value": val}
        rows.append(row)

    # Find best value per metric
    best_per_metric = {}
    for m in metrics:
        values = [(r["model_type"], r[m].get("mean") or r[m].get("value")) for r in rows if r[m].get("value") is not None]
        if values:
            if lower_is_better:
                best_per_metric[m] = min(values, key=lambda x: x[1] if x[1] is not None else float("inf"))[0]
            else:
                best_per_metric[m] = max(values, key=lambda x: x[1] if x[1] is not None else float("-inf"))[0]

    if output_format == "latex":
        return _format_results_latex(rows, metrics, best_per_metric, dataset_name)
    else:
        return _format_results_markdown(rows, metrics, best_per_metric, dataset_name)


def _format_results_latex(rows: list[dict], metrics: list[str], best_per: dict, dataset: str) -> str:
    """Format results as LaTeX table."""
    n_cols = len(metrics)
    col_spec = "l" + "c" * n_cols
    metric_headers = " & ".join(m.replace("_", r"\_") for m in metrics)

    lines = [
        r"\begin{table}[h]",
        r"\centering",
        f"\\caption{{Comparison of model architectures on {dataset}.}}",
        r"\label{tab:results}",
        f"\\begin{{tabular}}{{{col_spec}}}",
        r"\toprule",
        f"Model & {metric_headers} \\\\",
        r"\midrule",
    ]

    for row in rows:
        mt = row["model_type"].replace("_", r"\_")
        cells = [mt]
        for m in metrics:
            data = row[m]
            val = data.get("mean") or data.get("value")
            std = data.get("std")
            if val is None:
                cells.append("---")
            elif std:
                cell = f"{val:.3f} $\\pm$ {std:.3f}"
                if best_per.get(m) == row["model_type"]:
                    cell = f"\\textbf{{{cell}}}"
                cells.append(cell)
            else:
                cell = f"{val:.4f}"
                if best_per.get(m) == row["model_type"]:
                    cell = f"\\textbf{{{cell}}}"
                cells.append(cell)
        lines.append(" & ".join(cells) + r" \\")

    lines.extend([
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{table}",
    ])
    return "\n".join(lines)


def _format_results_markdown(rows: list[dict], metrics: list[str], best_per: dict, dataset: str) -> str:
    """Format results as markdown table."""
    header = f"| Model |"
    sep = "|-------|"
    for m in metrics:
        header += f" {m} |"
        sep += f"{'---' * max(len(m) // 3, 1)}--|"

    lines = [
        f"## Results on {dataset}",
        "",
        header,
        sep,
    ]

    for row in rows:
        line = f"| {row['model_type']} |"
        for m in metrics:
            data = row[m]
            val = data.get("mean") or data.get("value")
            std = data.get("std")
            if val is None:
                line += " --- |"
            elif std:
                cell = f"{val:.3f} +/- {std:.3f}"
                if best_per.get(m) == row["model_type"]:
                    cell = f"**{cell}**"
                line += f" {cell} |"
            else:
                cell = f"{val:.4f}"
                if best_per.get(m) == row["model_type"]:
                    cell = f"**{cell}**"
                line += f" {cell} |"
        lines.append(line)

    return "\n".join(lines)


def draft_ablation_table(
    ablation_studies: dict[str, dict],
    output_format: str = "latex",
) -> str:
    """Draft ablation table from ablation study results."""
    if not ablation_studies:
        return "No ablation studies available. Run `/turing:ablate` first."

    # Use the most recent ablation study
    study = list(ablation_studies.values())[-1]
    metric = study.get("metric", "accuracy")
    full_metric = study.get("full_model_metric", 0)
    results = study.get("results", [])

    if not results:
        return "Ablation study has no results."

    if output_format == "latex":
        metric_escaped = metric.replace("_", r"\_")
        lines = [
            r"\begin{table}[h]",
            r"\centering",
            f"\\caption{{Ablation study results ({metric_escaped}).}}",
            r"\label{tab:ablation}",
            r"\begin{tabular}{lcc}",
            r"\toprule",
            f"Configuration & {metric_escaped} & $\\Delta$ from Full \\\\",
            r"\midrule",
            f"Full model & {full_metric:.4f} & --- \\\\",
        ]
        for r in results:
            if r.get("status") == "failed":
                continue
            config = r.get("configuration", "?").replace("_", r"\_")
            val = r.get("metric_value", 0)
            delta = r.get("delta", 0)
            delta_str = f"{delta:+.4f}" if delta is not None else "---"
            lines.append(f"{config} & {val:.4f} & {delta_str} \\\\")
        lines.extend([r"\bottomrule", r"\end{tabular}", r"\end{table}"])
        return "\n".join(lines)
    else:
        lines = [
            f"## Ablation Study ({metric})",
            "",
            f"| Configuration | {metric} | Delta from Full |",
            f"|---------------|{'---' * max(len(metric) // 3, 1)}--|-----------------|",
            f"| Full model | {full_metric:.4f} | --- |",
        ]
        for r in results:
            if r.get("status") == "failed":
                continue
            config = r.get("configuration", "?")
            val = r.get("metric_value", 0)
            delta = r.get("delta", 0)
            delta_str = f"{delta:+.4f}" if delta is not None else "---"
            lines.append(f"| {config} | {val:.4f} | {delta_str} |")
        return "\n".join(lines)


def draft_hyperparameter_table(
    experiments: list[dict],
    output_format: str = "latex",
) -> str:
    """Draft hyperparameter appendix table."""
    groups = group_by_model_type(experiments)

    if output_format == "latex":
        lines = [
            r"\begin{table}[h]",
            r"\centering",
            r"\caption{Hyperparameters for reported models.}",
            r"\label{tab:hyperparams}",
            r"\begin{tabular}{llr}",
            r"\toprule",
            r"Model & Parameter & Value \\",
            r"\midrule",
        ]
        for mt, exps in groups.items():
            best = exps[0]
            hyperparams = best.get("config", {}).get("hyperparams", {})
            first = True
            for param, value in sorted(hyperparams.items()):
                model_col = mt.replace("_", r"\_") if first else ""
                param_escaped = param.replace("_", r"\_")
                lines.append(f"{model_col} & {param_escaped} & {value} \\\\")
                first = False
            lines.append(r"\midrule")
        if lines[-1] == r"\midrule":
            lines.pop()
        lines.extend([r"\bottomrule", r"\end{tabular}", r"\end{table}"])
        return "\n".join(lines)
    else:
        lines = [
            "## Hyperparameters",
            "",
            "| Model | Parameter | Value |",
            "|-------|-----------|-------|",
        ]
        for mt, exps in groups.items():
            best = exps[0]
            hyperparams = best.get("config", {}).get("hyperparams", {})
            for param, value in sorted(hyperparams.items()):
                lines.append(f"| {mt} | {param} | {value} |")
        return "\n".join(lines)


def save_paper_sections(sections: dict[str, str], output_dir: str = "paper/sections") -> list[Path]:
    """Save each section to its own file."""
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    saved = []
    for name, content in sections.items():
        ext = ".tex" if r"\begin" in content else ".md"
        filepath = out_path / f"{name}{ext}"
        with open(filepath, "w") as f:
            f.write(content)
        saved.append(filepath)

    return saved


def draft_paper(
    sections_str: str | None = None,
    output_format: str = DEFAULT_FORMAT,
    config_path: str = "config.yaml",
    log_path: str = "experiments/log.jsonl",
) -> dict:
    """Draft all requested paper sections.

    Args:
        sections_str: Comma-separated section names (setup,results,ablation,hyperparameters).
        output_format: "latex" or "markdown".
        config_path: Path to config.yaml.
        log_path: Path to experiment log.

    Returns:
        Dict with section name -> content mappings.
    """
    if sections_str:
        sections = [s.strip() for s in sections_str.split(",")]
    else:
        sections = VALID_SECTIONS

    config = load_config(config_path)
    eval_cfg = config.get("evaluation", {})
    primary_metric = eval_cfg.get("primary_metric", "accuracy")
    all_metrics = eval_cfg.get("metrics", [primary_metric])
    lower_is_better = eval_cfg.get("lower_is_better", False)

    experiments = load_experiments(log_path)
    top_exps = get_top_experiments(experiments, primary_metric, lower_is_better)

    seed_studies = load_seed_studies_for_paper()
    ablation_studies = load_ablation_studies()

    dataset_name = config.get("data", {}).get("source", "the dataset")

    result = {"format": output_format, "sections": {}}

    for section in sections:
        if section == "setup":
            result["sections"]["setup"] = draft_setup_section(
                config, experiments, seed_studies, output_format,
            )
        elif section == "results":
            result["sections"]["results"] = draft_results_table(
                top_exps, all_metrics, primary_metric, lower_is_better,
                seed_studies, output_format, dataset_name,
            )
        elif section == "ablation":
            result["sections"]["ablation"] = draft_ablation_table(
                ablation_studies, output_format,
            )
        elif section == "hyperparameters":
            result["sections"]["hyperparameters"] = draft_hyperparameter_table(
                top_exps, output_format,
            )

    return result


def main() -> None:
    """CLI entry point."""
    parser = argparse.ArgumentParser(description="Draft paper sections from experiment logs")
    parser.add_argument("--sections", default=None, help="Comma-separated sections: setup,results,ablation,hyperparameters")
    parser.add_argument("--format", default=DEFAULT_FORMAT, dest="output_format", choices=["latex", "markdown"])
    parser.add_argument("--config", default="config.yaml", help="Path to config.yaml")
    parser.add_argument("--log", default="experiments/log.jsonl", help="Path to experiment log")
    parser.add_argument("--output", default="paper/sections", help="Output directory")
    parser.add_argument("--json", action="store_true", help="Output raw JSON")
    args = parser.parse_args()

    result = draft_paper(
        sections_str=args.sections,
        output_format=args.output_format,
        config_path=args.config,
        log_path=args.log,
    )

    # Save sections
    if result.get("sections"):
        saved = save_paper_sections(result["sections"], args.output)
        for f in saved:
            print(f"Saved: {f}", file=sys.stderr)

    if args.json:
        print(json.dumps(result, indent=2, default=str))
    else:
        for name, content in result.get("sections", {}).items():
            print(f"\n{'=' * 60}")
            print(f" {name.upper()}")
            print(f"{'=' * 60}\n")
            print(content)
            print()


if __name__ == "__main__":
    main()
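A minimal sketch of driving the drafting script programmatically rather than through its CLI. It assumes the template has been copied into a project as scripts/draft_paper_sections.py (consistent with the `from scripts.turing_io import ...` import above) and that config.yaml and experiments/log.jsonl exist with the structure the script reads; the import path and file locations are assumptions, not part of this diff.

```python
# Hedged sketch, not part of the package. Assumes the template above lives at
# scripts/draft_paper_sections.py in a scaffolded project, with config.yaml and
# experiments/log.jsonl following the schema the script expects.
from scripts.draft_paper_sections import draft_paper, save_paper_sections

result = draft_paper(sections_str="setup,results", output_format="markdown")
saved = save_paper_sections(result["sections"], output_dir="paper/sections")
for path in saved:
    print(f"wrote {path}")
```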
@@ -0,0 +1,158 @@
#!/usr/bin/env python3
"""Inference equivalence verification for model exports.

Compares outputs between the original model and the exported model
to verify they produce identical (or near-identical) results.

Verdicts:
    equivalent — max delta < 1e-6 (float precision)
    approximately_equivalent — max delta < tolerance (default 1e-5)
    divergent — max delta >= tolerance
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

import numpy as np


FLOAT32_TOLERANCE = 1e-5
QUANTIZED_TOLERANCE = 1e-3
EXACT_TOLERANCE = 1e-6


def compare_outputs(
    original_outputs: list | np.ndarray,
    exported_outputs: list | np.ndarray,
    tolerance: float = FLOAT32_TOLERANCE,
) -> dict:
    """Compare outputs from original and exported models.

    Args:
        original_outputs: Predictions from original model.
        exported_outputs: Predictions from exported model.
        tolerance: Maximum allowed difference.

    Returns:
        Dict with verdict, max_delta, mean_delta, and per-sample details.
    """
    orig = np.array(original_outputs, dtype=np.float64).flatten()
    exported = np.array(exported_outputs, dtype=np.float64).flatten()

    if orig.shape != exported.shape:
        return {
            "verdict": "divergent",
            "reason": f"Shape mismatch: original {orig.shape} vs exported {exported.shape}",
            "max_delta": float("inf"),
            "mean_delta": float("inf"),
            "n_samples": 0,
            "n_divergent": 0,
        }

    deltas = np.abs(orig - exported)
    max_delta = float(np.max(deltas))
    mean_delta = float(np.mean(deltas))
    n_divergent = int(np.sum(deltas >= tolerance))

    if max_delta < EXACT_TOLERANCE:
        verdict = "equivalent"
        reason = f"Exact match (max delta {max_delta:.2e} < {EXACT_TOLERANCE})"
    elif max_delta < tolerance:
        verdict = "approximately_equivalent"
        reason = f"Within tolerance (max delta {max_delta:.2e} < {tolerance})"
    else:
        verdict = "divergent"
        reason = (
            f"Max delta {max_delta:.2e} exceeds tolerance {tolerance}. "
            f"{n_divergent} of {len(orig)} samples diverge."
        )

    return {
        "verdict": verdict,
        "reason": reason,
        "max_delta": round(max_delta, 10),
        "mean_delta": round(mean_delta, 10),
        "n_samples": len(orig),
        "n_divergent": n_divergent,
        "tolerance": tolerance,
    }


def run_equivalence_check(
    original_predict_fn,
    exported_predict_fn,
    test_data: np.ndarray | list,
    tolerance: float = FLOAT32_TOLERANCE,
) -> dict:
    """Run equivalence check by predicting on test data with both models.

    Args:
        original_predict_fn: Callable that takes input data and returns predictions.
        exported_predict_fn: Callable for the exported model.
        test_data: Input data to predict on.
        tolerance: Maximum allowed difference.

    Returns:
        Equivalence check result dict.
    """
    try:
        original_preds = original_predict_fn(test_data)
    except Exception as e:
        return {"verdict": "error", "reason": f"Original model prediction failed: {e}"}

    try:
        exported_preds = exported_predict_fn(test_data)
    except Exception as e:
        return {"verdict": "error", "reason": f"Exported model prediction failed: {e}"}

    return compare_outputs(original_preds, exported_preds, tolerance)


def generate_test_data(n_samples: int = 100, n_features: int = 10, seed: int = 42) -> np.ndarray:
    """Generate random test data for equivalence checking.

    Args:
        n_samples: Number of test samples.
        n_features: Number of features per sample.
        seed: Random seed for reproducibility.

    Returns:
        Array of shape (n_samples, n_features).
    """
    rng = np.random.RandomState(seed)
    return rng.randn(n_samples, n_features).astype(np.float32)


def format_equivalence_report(result: dict) -> str:
    """Format equivalence check result as readable text."""
    verdict = result.get("verdict", "unknown")
    reason = result.get("reason", "")

    verdict_markers = {
        "equivalent": "PASS (exact)",
        "approximately_equivalent": "PASS (approx)",
        "divergent": "FAIL",
        "error": "ERROR",
    }
    marker = verdict_markers.get(verdict, verdict)

    lines = [
        f"## Equivalence Check: {marker}",
        "",
        f"*{reason}*",
        "",
    ]

    if result.get("n_samples"):
        lines.extend([
            f"- **Samples tested:** {result['n_samples']}",
            f"- **Max delta:** {result['max_delta']:.2e}",
            f"- **Mean delta:** {result['mean_delta']:.2e}",
            f"- **Divergent samples:** {result['n_divergent']}",
            f"- **Tolerance:** {result['tolerance']:.0e}",
        ])

    return "\n".join(lines)
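Similarly, a hedged sketch of the equivalence checker in use, with toy predict functions standing in for a real original/exported model pair; the scripts.equivalence_checker import path is again an assumption about where the template lands in a project.

```python
# Hedged sketch, not part of the package. The predict functions below are toy
# stand-ins; a real check would pass the trained model's predict method and the
# exported runtime's predict callable instead.
from scripts.equivalence_checker import (
    format_equivalence_report,
    generate_test_data,
    run_equivalence_check,
)

X = generate_test_data(n_samples=50, n_features=4)


def original_predict(data):
    return data.sum(axis=1)  # stand-in for the original model


def exported_predict(data):
    return data.sum(axis=1) + 5e-6  # stand-in for the exported artifact


result = run_equivalence_check(original_predict, exported_predict, X)
print(format_equivalence_report(result))  # expect "PASS (approx)" for this toy pair
```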