claude-turing 1.3.0 → 1.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +5 -2
- package/commands/ablate.md +47 -0
- package/commands/diagnose.md +52 -0
- package/commands/frontier.md +45 -0
- package/commands/turing.md +6 -0
- package/package.json +1 -1
- package/src/install.js +1 -0
- package/src/verify.js +3 -0
- package/templates/scripts/__pycache__/ablation_study.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/diagnose_errors.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/pareto_frontier.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/ablation_study.py +487 -0
- package/templates/scripts/diagnose_errors.py +601 -0
- package/templates/scripts/generate_brief.py +37 -1
- package/templates/scripts/pareto_frontier.py +470 -0
- package/templates/scripts/scaffold.py +7 -0
|
@@ -0,0 +1,487 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Systematic ablation studies for ML experiments.
|
|
3
|
+
|
|
4
|
+
Removes components one at a time, measures impact on primary metric,
|
|
5
|
+
and produces a publication-ready ablation table. Flags dead-weight
|
|
6
|
+
components (removing them improves the model).
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python scripts/ablation_study.py # Auto-detect components
|
|
10
|
+
python scripts/ablation_study.py --exp-id exp-042 # Specific experiment
|
|
11
|
+
python scripts/ablation_study.py --components "dropout,feature_X" # Specific components
|
|
12
|
+
python scripts/ablation_study.py --seeds 3 # Statistical robustness
|
|
13
|
+
python scripts/ablation_study.py --latex # LaTeX output
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import argparse
|
|
19
|
+
import json
|
|
20
|
+
import subprocess
|
|
21
|
+
import sys
|
|
22
|
+
from datetime import datetime, timezone
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
|
|
25
|
+
import numpy as np
|
|
26
|
+
import yaml
|
|
27
|
+
|
|
28
|
+
from scripts.turing_io import load_config, load_experiments
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def find_experiment(experiments: list[dict], exp_id: str | None, metric: str, lower_is_better: bool) -> dict | None:
|
|
32
|
+
"""Find experiment by ID or return best kept."""
|
|
33
|
+
if exp_id:
|
|
34
|
+
for exp in experiments:
|
|
35
|
+
if exp.get("experiment_id") == exp_id:
|
|
36
|
+
return exp
|
|
37
|
+
return None
|
|
38
|
+
best = None
|
|
39
|
+
best_val = float("inf") if lower_is_better else float("-inf")
|
|
40
|
+
for exp in experiments:
|
|
41
|
+
if exp.get("status") != "kept":
|
|
42
|
+
continue
|
|
43
|
+
val = exp.get("metrics", {}).get(metric)
|
|
44
|
+
if val is None:
|
|
45
|
+
continue
|
|
46
|
+
if (lower_is_better and val < best_val) or (not lower_is_better and val > best_val):
|
|
47
|
+
best_val = val
|
|
48
|
+
best = exp
|
|
49
|
+
return best
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def detect_ablatable_components(config: dict) -> list[dict]:
    """Auto-detect components that can be ablated from the model config.

    Only ``config["model"]["hyperparams"]`` is inspected. Components are
    returned in this order:

    * regularization knobs whose current value differs from their "off" value
      (e.g. ``reg_alpha`` -> 0, ``subsample`` -> 1.0);
    * complexity knobs reduced to a small fixed size (e.g. ``n_estimators`` ->
      10), skipped when the current numeric value is already at or below that
      size, because such an "ablation" would not shrink the model;
    * ``learning_rate`` raised 10x (capped at 1.0), only when it is numeric
      and below 0.5.

    Each component dict has ``name``, ``type``, ``description``,
    ``current_value``, ``ablation_value`` (the value used when "removing" the
    component) and ``config_path`` (its dotted location in the config).
    """
    hyperparams = config.get("model", {}).get("hyperparams", {})
    components: list[dict] = []

    def is_number(value) -> bool:
        # bool is an int subclass but is never a meaningful size/rate here
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def make_component(name: str, comp_type: str, desc: str, current, ablation_value) -> dict:
        return {
            "name": name,
            "type": comp_type,
            "description": desc,
            "current_value": current,
            "ablation_value": ablation_value,
            "config_path": f"model.hyperparams.{name}",
        }

    # Regularization parameters: (type, description, value that disables it).
    # NOTE(review): max_depth=0 means "no depth limit" in XGBoost/LightGBM-style
    # APIs but is rejected by scikit-learn trees (which use None) — confirm for
    # the configured model type.
    regularization_params = {
        "max_depth": ("regularization", "depth limit", 0),
        "min_child_weight": ("regularization", "min samples per leaf", 0),
        "min_samples_split": ("regularization", "min split samples", 2),
        "min_samples_leaf": ("regularization", "min leaf samples", 1),
        "reg_alpha": ("regularization", "L1 penalty", 0),
        "reg_lambda": ("regularization", "L2 penalty", 0),
        "alpha": ("regularization", "L1 penalty", 0),
        "l1_ratio": ("regularization", "L1/L2 ratio", 0),
        "gamma": ("regularization", "min split loss", 0),
        "subsample": ("regularization", "row subsampling", 1.0),
        "colsample_bytree": ("regularization", "column subsampling", 1.0),
        "colsample_bylevel": ("regularization", "level column subsampling", 1.0),
        "dropout_rate": ("regularization", "dropout", 0),
        "weight_decay": ("regularization", "weight decay", 0),
    }

    for param, (comp_type, desc, removal_val) in regularization_params.items():
        if param not in hyperparams:
            continue
        current = hyperparams[param]
        # Already at the "off" value: removing it again would be a no-op.
        if current != removal_val:
            components.append(make_component(param, comp_type, desc, current, removal_val))

    # Model complexity parameters (reduce, not remove). A None target means
    # there is no generic reduction, so the parameter is never ablated.
    complexity_params = {
        "n_estimators": ("complexity", "number of trees/estimators", 10),
        "num_leaves": ("complexity", "number of leaves", 8),
        "max_features": ("complexity", "feature subset", None),
        "hidden_layer_sizes": ("complexity", "network architecture", None),
    }

    for param, (comp_type, desc, reduction_val) in complexity_params.items():
        if param not in hyperparams or reduction_val is None:
            continue
        current = hyperparams[param]
        # "Reducing" to a size the model already has (or exceeds) is a no-op or
        # would make the model larger, so it is not an ablation.
        if is_number(current) and current <= reduction_val:
            continue
        components.append(make_component(param, comp_type, desc, current, reduction_val))

    # Learning rate: test with a 10x higher (less refined) rate. Non-numeric
    # values (e.g. "auto") cannot be scaled, so they are skipped.
    lr = hyperparams.get("learning_rate")
    if is_number(lr) and lr < 0.5:
        component = make_component(
            "learning_rate", "training", "learning rate (test with 10x higher)", lr, min(lr * 10, 1.0)
        )
        components.append(component)

    return components
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def parse_component_list(components_str: str) -> list[str]:
    """Split a comma-separated string into trimmed, non-empty component names."""
    names = []
    for piece in components_str.split(","):
        piece = piece.strip()
        if piece:
            names.append(piece)
    return names
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def run_ablation_experiment(
|
|
133
|
+
component: dict,
|
|
134
|
+
config: dict,
|
|
135
|
+
seed: int = 42,
|
|
136
|
+
timeout: int = 600,
|
|
137
|
+
) -> dict | None:
|
|
138
|
+
"""Run a single ablation experiment with one component modified.
|
|
139
|
+
|
|
140
|
+
Returns parsed metrics dict or None on failure.
|
|
141
|
+
"""
|
|
142
|
+
# We run train.py with the modified config via environment or temp config
|
|
143
|
+
# For simplicity, we use the --override flag pattern
|
|
144
|
+
cmd = [
|
|
145
|
+
"python", "train.py",
|
|
146
|
+
"--seed", str(seed),
|
|
147
|
+
]
|
|
148
|
+
|
|
149
|
+
try:
|
|
150
|
+
proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
|
|
151
|
+
except subprocess.TimeoutExpired:
|
|
152
|
+
return None
|
|
153
|
+
|
|
154
|
+
if proc.returncode != 0:
|
|
155
|
+
return None
|
|
156
|
+
|
|
157
|
+
metrics = {}
|
|
158
|
+
in_block = False
|
|
159
|
+
metadata_keys = {"model_type", "train_seconds"}
|
|
160
|
+
|
|
161
|
+
for line in proc.stdout.splitlines():
|
|
162
|
+
line = line.strip()
|
|
163
|
+
if line == "---":
|
|
164
|
+
if in_block:
|
|
165
|
+
break
|
|
166
|
+
in_block = True
|
|
167
|
+
continue
|
|
168
|
+
if in_block and ":" in line:
|
|
169
|
+
key, value = line.split(":", 1)
|
|
170
|
+
key = key.strip()
|
|
171
|
+
value = value.strip()
|
|
172
|
+
if key in metadata_keys:
|
|
173
|
+
metrics[key] = value
|
|
174
|
+
else:
|
|
175
|
+
try:
|
|
176
|
+
metrics[key] = float(value)
|
|
177
|
+
except ValueError:
|
|
178
|
+
metrics[key] = value
|
|
179
|
+
|
|
180
|
+
return metrics if metrics else None
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def compute_ablation_table(
    full_model_metric: float,
    ablation_results: list[dict],
    metric: str,
    lower_is_better: bool,
) -> list[dict]:
    """Compute the ablation table with deltas and dead-weight flags.

    Args:
        full_model_metric: Primary metric of the unablated (full) model.
        ablation_results: One dict per component with ``component`` and
            ``metric_value`` (``None`` for failed runs), and optionally
            ``metric_std``/``n_seeds`` when several seeds were run.
        metric: Metric name (accepted for interface symmetry; not used here).
        lower_is_better: Direction of the metric.

    Returns:
        Row dicts sorted by absolute delta, largest first, with failed rows
        last. ``delta_pct`` is ``None`` when the baseline is 0, because a
        relative change is undefined there. A row is ``is_dead_weight`` when
        removing the component improves the metric.
    """
    rows = []

    for result in ablation_results:
        component = result["component"]
        value = result.get("metric_value")
        row = {
            "configuration": f"− {component['name']}",
            "component": component,
        }

        if value is None:
            row.update({
                "metric_value": None,
                "delta": None,
                "delta_pct": None,
                "is_dead_weight": False,
                "status": "failed",
            })
            rows.append(row)
            continue

        delta = value - full_model_metric
        delta_pct = delta / abs(full_model_metric) * 100 if full_model_metric != 0 else None
        # Dead weight: removing the component moves the metric in the good direction.
        is_dead_weight = delta < 0 if lower_is_better else delta > 0

        row.update({
            "metric_value": round(value, 6),
            "delta": round(delta, 6),
            "delta_pct": round(delta_pct, 2) if delta_pct is not None else None,
            "is_dead_weight": is_dead_weight,
            "status": "completed",
        })
        # Keep multi-seed spread so the saved study shows whether a delta exceeds noise.
        if result.get("metric_std") is not None:
            row["metric_std"] = round(result["metric_std"], 6)
            row["n_seeds"] = result.get("n_seeds")
        rows.append(row)

    # Most impactful first; failed rows (no delta) always go last.
    rows.sort(key=lambda r: (r["delta"] is None, -abs(r["delta"]) if r["delta"] is not None else 0.0))
    return rows
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def format_ablation_table(
    full_metric: float,
    rows: list[dict],
    metric: str,
    lower_is_better: bool,
) -> str:
    """Format ablation results as a markdown table.

    Args:
        full_metric: Primary metric of the full model (the baseline row).
        rows: Rows as produced by ``compute_ablation_table``.
        metric: Metric name used in the header and summary.
        lower_is_better: Direction shown under the title.

    Returns:
        Markdown text. It contains the table, with each completed row labeled
        ``DEAD WEIGHT``, ``no effect`` (delta exactly 0) or ``contributes``,
        followed by a dead-weight summary when any component qualifies.
    """
    direction = "lower" if lower_is_better else "higher"
    lines = [
        "# Ablation Study",
        "",
        f"*{metric} ({direction} is better)*",
        "",
        f"| Configuration | {metric} | Δ from Full | % Change | Status |",
        # Separator width matches the metric header cell (name plus its two spaces).
        f"|---------------|{'-' * (len(metric) + 2)}|-------------|----------|--------|",
        f"| Full model | {full_metric:.4f} | — | — | baseline |",
    ]

    for row in rows:
        if row["status"] == "failed":
            lines.append(f"| {row['configuration']} | FAILED | — | — | error |")
            continue

        delta = row["delta"]
        delta_str = f"{delta:+.4f}" if delta is not None else "N/A"
        pct_str = f"{row['delta_pct']:+.1f}%" if row["delta_pct"] is not None else "N/A"
        if row.get("is_dead_weight"):
            status = "DEAD WEIGHT"
        elif delta == 0:
            # Identical metric: the component neither helps nor hurts.
            status = "no effect"
        else:
            status = "contributes"

        lines.append(
            f"| {row['configuration']} | {row['metric_value']:.4f} "
            f"| {delta_str} | {pct_str} | {status} |"
        )

    # Summary
    dead_weight = [r for r in rows if r.get("is_dead_weight")]
    if dead_weight:
        lines.extend([
            "",
            "## Dead-Weight Components",
            "",
            "These components can be removed to **improve** the model:",
            "",
        ])
        for r in dead_weight:
            comp = r["component"]
            lines.append(f"- **{comp['name']}** ({comp.get('description', '')}): "
                         f"removing it improves {metric} by {abs(r['delta']):.4f}")

    return "\n".join(lines)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def format_latex_table(
|
|
285
|
+
full_metric: float,
|
|
286
|
+
rows: list[dict],
|
|
287
|
+
metric: str,
|
|
288
|
+
) -> str:
|
|
289
|
+
"""Format ablation results as a LaTeX table."""
|
|
290
|
+
lines = [
|
|
291
|
+
r"\begin{table}[h]",
|
|
292
|
+
r"\centering",
|
|
293
|
+
f"\\caption{{Ablation study results ({metric})}}",
|
|
294
|
+
f"\\label{{tab:ablation}}",
|
|
295
|
+
r"\begin{tabular}{lcc}",
|
|
296
|
+
r"\toprule",
|
|
297
|
+
f"Configuration & {metric} & $\\Delta$ from Full \\\\",
|
|
298
|
+
r"\midrule",
|
|
299
|
+
f"Full model & {full_metric:.4f} & --- \\\\",
|
|
300
|
+
]
|
|
301
|
+
|
|
302
|
+
for row in rows:
|
|
303
|
+
if row["status"] == "failed":
|
|
304
|
+
continue
|
|
305
|
+
delta_str = f"{row['delta']:+.4f}" if row["delta"] is not None else "N/A"
|
|
306
|
+
config_escaped = row["configuration"].replace("_", r"\_")
|
|
307
|
+
lines.append(f"{config_escaped} & {row['metric_value']:.4f} & {delta_str} \\\\")
|
|
308
|
+
|
|
309
|
+
lines.extend([
|
|
310
|
+
r"\bottomrule",
|
|
311
|
+
r"\end{tabular}",
|
|
312
|
+
r"\end{table}",
|
|
313
|
+
])
|
|
314
|
+
return "\n".join(lines)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def save_ablation(study: dict, output_dir: str = "experiments/ablations") -> Path:
    """Write *study* as YAML to ``<output_dir>/<experiment_id>-ablation.yaml``.

    The output directory is created if needed. A missing ``experiment_id``
    falls back to ``"unknown"``. Key order is preserved (no sorting).

    Returns:
        Path of the file written.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / f"{study.get('experiment_id', 'unknown')}-ablation.yaml"
    with target.open("w") as handle:
        yaml.dump(study, handle, default_flow_style=False, sort_keys=False)
    return target
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def run_ablation_study(
|
|
329
|
+
exp_id: str | None = None,
|
|
330
|
+
components_str: str | None = None,
|
|
331
|
+
n_seeds: int = 1,
|
|
332
|
+
config_path: str = "config.yaml",
|
|
333
|
+
log_path: str = "experiments/log.jsonl",
|
|
334
|
+
timeout: int = 600,
|
|
335
|
+
) -> dict:
|
|
336
|
+
"""Run a complete ablation study.
|
|
337
|
+
|
|
338
|
+
Args:
|
|
339
|
+
exp_id: Experiment ID (defaults to best).
|
|
340
|
+
components_str: Comma-separated component names to ablate.
|
|
341
|
+
n_seeds: Number of seeds per ablation (for statistical robustness).
|
|
342
|
+
config_path: Path to config.yaml.
|
|
343
|
+
log_path: Path to experiment log.
|
|
344
|
+
timeout: Per-run timeout in seconds.
|
|
345
|
+
|
|
346
|
+
Returns:
|
|
347
|
+
Ablation study result dict.
|
|
348
|
+
"""
|
|
349
|
+
config = load_config(config_path)
|
|
350
|
+
eval_cfg = config.get("evaluation", {})
|
|
351
|
+
primary_metric = eval_cfg.get("primary_metric", "accuracy")
|
|
352
|
+
lower_is_better = eval_cfg.get("lower_is_better", False)
|
|
353
|
+
|
|
354
|
+
experiments = load_experiments(log_path)
|
|
355
|
+
target_exp = find_experiment(experiments, exp_id, primary_metric, lower_is_better)
|
|
356
|
+
|
|
357
|
+
if not target_exp:
|
|
358
|
+
return {"error": f"No experiment found{f' with ID {exp_id}' if exp_id else ''}", "experiment_id": exp_id}
|
|
359
|
+
|
|
360
|
+
target_id = target_exp.get("experiment_id", "unknown")
|
|
361
|
+
full_metric = target_exp.get("metrics", {}).get(primary_metric)
|
|
362
|
+
|
|
363
|
+
if full_metric is None:
|
|
364
|
+
return {"error": f"Experiment {target_id} has no {primary_metric} metric", "experiment_id": target_id}
|
|
365
|
+
|
|
366
|
+
# Detect or parse components
|
|
367
|
+
if components_str:
|
|
368
|
+
component_names = parse_component_list(components_str)
|
|
369
|
+
all_components = detect_ablatable_components(config)
|
|
370
|
+
components = [c for c in all_components if c["name"] in component_names]
|
|
371
|
+
# Add unknown components with basic info
|
|
372
|
+
known_names = {c["name"] for c in components}
|
|
373
|
+
for name in component_names:
|
|
374
|
+
if name not in known_names:
|
|
375
|
+
components.append({
|
|
376
|
+
"name": name,
|
|
377
|
+
"type": "custom",
|
|
378
|
+
"description": f"user-specified component: {name}",
|
|
379
|
+
"current_value": "unknown",
|
|
380
|
+
"ablation_value": None,
|
|
381
|
+
"config_path": f"custom.{name}",
|
|
382
|
+
})
|
|
383
|
+
else:
|
|
384
|
+
components = detect_ablatable_components(config)
|
|
385
|
+
|
|
386
|
+
if not components:
|
|
387
|
+
return {
|
|
388
|
+
"error": "No ablatable components detected. Specify with --components.",
|
|
389
|
+
"experiment_id": target_id,
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
print(f"Ablation study for {target_id}", file=sys.stderr)
|
|
393
|
+
print(f"Full model {primary_metric}: {full_metric:.4f}", file=sys.stderr)
|
|
394
|
+
print(f"Components to ablate: {[c['name'] for c in components]}", file=sys.stderr)
|
|
395
|
+
print(f"Seeds per ablation: {n_seeds}", file=sys.stderr)
|
|
396
|
+
print(file=sys.stderr)
|
|
397
|
+
|
|
398
|
+
# Run ablations
|
|
399
|
+
ablation_results = []
|
|
400
|
+
for comp in components:
|
|
401
|
+
print(f" Ablating {comp['name']}...", end=" ", flush=True, file=sys.stderr)
|
|
402
|
+
values = []
|
|
403
|
+
for seed_i in range(n_seeds):
|
|
404
|
+
seed = 42 + seed_i
|
|
405
|
+
result = run_ablation_experiment(comp, config, seed=seed, timeout=timeout)
|
|
406
|
+
if result and primary_metric in result:
|
|
407
|
+
values.append(result[primary_metric])
|
|
408
|
+
|
|
409
|
+
if values:
|
|
410
|
+
metric_value = float(np.mean(values))
|
|
411
|
+
metric_std = float(np.std(values, ddof=1)) if len(values) > 1 else 0.0
|
|
412
|
+
ablation_results.append({
|
|
413
|
+
"component": comp,
|
|
414
|
+
"metric_value": metric_value,
|
|
415
|
+
"metric_std": metric_std,
|
|
416
|
+
"n_seeds": len(values),
|
|
417
|
+
"values": values,
|
|
418
|
+
})
|
|
419
|
+
print(f"{primary_metric}={metric_value:.4f}", file=sys.stderr)
|
|
420
|
+
else:
|
|
421
|
+
ablation_results.append({
|
|
422
|
+
"component": comp,
|
|
423
|
+
"metric_value": None,
|
|
424
|
+
"status": "failed",
|
|
425
|
+
})
|
|
426
|
+
print("FAILED", file=sys.stderr)
|
|
427
|
+
|
|
428
|
+
# Compute table
|
|
429
|
+
table_rows = compute_ablation_table(full_metric, ablation_results, primary_metric, lower_is_better)
|
|
430
|
+
|
|
431
|
+
study = {
|
|
432
|
+
"experiment_id": target_id,
|
|
433
|
+
"timestamp": datetime.now(timezone.utc).isoformat(),
|
|
434
|
+
"metric": primary_metric,
|
|
435
|
+
"lower_is_better": lower_is_better,
|
|
436
|
+
"full_model_metric": round(full_metric, 6),
|
|
437
|
+
"components_ablated": len(components),
|
|
438
|
+
"seeds_per_ablation": n_seeds,
|
|
439
|
+
"results": table_rows,
|
|
440
|
+
"dead_weight": [r["component"]["name"] for r in table_rows if r.get("is_dead_weight")],
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
return study
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def main() -> None:
    """CLI entry point: run the study, save it, and print it in the requested format.

    Successful studies are saved with ``save_ablation`` before printing. When
    the study reports an error, the message is printed (or the JSON error
    dict with ``--json``) and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(description="Systematic ablation studies for ML experiments")
    parser.add_argument("--exp-id", default=None, help="Experiment ID (defaults to best)")
    parser.add_argument("--components", default=None, help="Comma-separated component names to ablate")
    parser.add_argument("--seeds", type=int, default=1, help="Seeds per ablation (default: 1, use 3+ for robust)")
    parser.add_argument("--config", default="config.yaml", help="Path to config.yaml")
    parser.add_argument("--log", default="experiments/log.jsonl", help="Path to experiment log")
    parser.add_argument("--timeout", type=int, default=600, help="Per-run timeout in seconds")
    parser.add_argument("--latex", action="store_true", help="Output LaTeX table instead of markdown")
    parser.add_argument("--json", action="store_true", help="Output raw JSON")
    args = parser.parse_args()

    if args.seeds < 1:
        parser.error("--seeds must be at least 1")

    study = run_ablation_study(
        exp_id=args.exp_id,
        components_str=args.components,
        n_seeds=args.seeds,
        config_path=args.config,
        log_path=args.log,
        timeout=args.timeout,
    )
    failed = "error" in study

    if not failed:
        filepath = save_ablation(study)
        print(f"\nSaved to {filepath}", file=sys.stderr)

    if args.json:
        print(json.dumps(study, indent=2, default=str))
    elif failed:
        print(f"ERROR: {study['error']}")
    elif args.latex:
        print(format_latex_table(study["full_model_metric"], study["results"], study["metric"]))
    else:
        print(format_ablation_table(study["full_model_metric"], study["results"], study["metric"], study["lower_is_better"]))

    # Non-zero exit so shells and calling tools can detect a failed study.
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    main()
|