claude-turing 4.1.0 → 4.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +5 -2
- package/commands/counterfactual.md +27 -0
- package/commands/simulate.md +28 -0
- package/commands/turing.md +6 -0
- package/commands/whatif.md +31 -0
- package/package.json +1 -1
- package/src/install.js +1 -0
- package/src/verify.js +3 -0
- package/templates/scripts/__pycache__/counterfactual_explanation.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/experiment_simulator.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/whatif_engine.cpython-314.pyc +0 -0
- package/templates/scripts/counterfactual_explanation.py +485 -0
- package/templates/scripts/experiment_simulator.py +463 -0
- package/templates/scripts/generate_brief.py +64 -0
- package/templates/scripts/scaffold.py +6 -0
- package/templates/scripts/whatif_engine.py +763 -0
|
@@ -0,0 +1,763 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""What-if analysis engine for the autoresearch pipeline.
|
|
3
|
+
|
|
4
|
+
Routes hypothetical questions to existing estimators: scaling laws,
|
|
5
|
+
ablation, sensitivity, ensemble, pruning, and stitch. Returns an
|
|
6
|
+
estimate with confidence without running new experiments.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python scripts/whatif_engine.py "what if I had 2x more data"
|
|
10
|
+
python scripts/whatif_engine.py "what if I removed class 3"
|
|
11
|
+
python scripts/whatif_engine.py "what if I combined exp-031 with exp-042"
|
|
12
|
+
python scripts/whatif_engine.py --json
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import json
|
|
19
|
+
import re
|
|
20
|
+
import sys
|
|
21
|
+
from datetime import datetime, timezone
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
|
|
24
|
+
import yaml
|
|
25
|
+
|
|
26
|
+
from scripts.turing_io import load_config, load_experiments
|
|
27
|
+
|
|
28
|
+
# Default location of the JSONL experiment log read by load_experiments().
DEFAULT_LOG_PATH = "experiments/log.jsonl"

# Question patterns mapped to estimator routes
# Each route entry carries:
#   name       - key into ESTIMATOR_MAP selecting the estimator function
#   patterns   - regexes tried in order against the lowercased question
#   source     - human-readable provenance string surfaced in reports
#   verify_cmd - slash-command the user can run to verify the estimate
#   data_dir   - directory of prior results the estimator reads (None = no files)
ROUTE_PATTERNS = [
    {
        "name": "scaling",
        "patterns": [
            r"(?:more|less|2x|3x|4x|5x|10x|half|double|triple)\s+(?:data|samples|training data|examples)",
            r"(?:data|dataset)\s+(?:size|scaling|increase|decrease)",
            r"(?:scale|scaled)\s+(?:up|down)\s+(?:data|the data)",
        ],
        "source": "scaling law fit from `/turing:scale`",
        "verify_cmd": "/turing:scale --extrapolate",
        "data_dir": "experiments/scaling",
    },
    {
        "name": "ablation",
        "patterns": [
            r"(?:remov\w*|drop\w*|exclud\w*|without|ablat\w*)\s+(?:class|feature|component|column|variable)",
            r"(?:class|feature|component)\s+(?:\d+|[A-Za-z_]+)\s+(?:remov\w*|drop\w*|exclud\w*)",
        ],
        "source": "ablation study from `/turing:ablate`",
        "verify_cmd": "/turing:ablate",
        "data_dir": "experiments/ablations",
    },
    {
        "name": "stitch",
        "patterns": [
            r"(?:combine|stitch|swap|mix)\s+.*(?:from|of)\s+exp-\d+\s+(?:with|and)\s+.*exp-\d+",
            r"(?:pipeline|stage)\s+(?:from|of)\s+exp-\d+",
        ],
        "source": "pipeline composition from `/turing:stitch`",
        "verify_cmd": "/turing:stitch",
        "data_dir": "experiments/cache",
    },
    {
        "name": "sensitivity",
        "patterns": [
            r"(?:different|change|modify|adjust|set)\s+(?:hyperparameter|learning.?rate|lr|depth|estimators|epochs|batch.?size)",
            r"(?:learning.?rate|lr|depth|max_depth|n_estimators|epochs|batch.?size)\s+(?:was|were|to|=|of)\s+[\d.]+",
        ],
        "source": "sensitivity interpolation from `/turing:sensitivity`",
        "verify_cmd": "/turing:sensitivity",
        "data_dir": "experiments/sensitivity",
    },
    {
        "name": "ensemble",
        "patterns": [
            r"(?:ensembl\w*|combine|blend\w*|stack\w*|vot\w*)\s+(?:(?:these|the|top|best)\s+)*(?:models|experiments)",
            r"(?:voting|stacking|blending)\s+(?:of|with|from)",
        ],
        "source": "prediction correlation from `/turing:ensemble`",
        "verify_cmd": "/turing:ensemble",
        "data_dir": "experiments/ensembles",
    },
    {
        "name": "pruning",
        "patterns": [
            r"(?:prune|pruning|sparsity)\s+(?:to|at|of)?\s*\d+",
            r"\d+%?\s*(?:sparsity|sparse|pruned)",
        ],
        "source": "pruning sweep interpolation from `/turing:prune`",
        "verify_cmd": "/turing:prune",
        "data_dir": "experiments/pruning",
    },
    {
        # Budget questions have no cached data directory; the estimator only
        # summarizes the experiment log.
        "name": "budget",
        "patterns": [
            r"(?:spend|spent|budget|allocate|invest)\s+.*(?:budget|remaining).*(?:on|in|for)",
            r"(?:remaining|left)\s+(?:budget|experiments|compute)",
            r"(?:budget)\s+(?:on|for|between)",
        ],
        "source": "budget allocation from `/turing:budget`",
        "verify_cmd": "/turing:budget",
        "data_dir": None,
    },
]
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# --- Question Parsing ---
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def classify_question(question: str) -> dict:
    """Classify a what-if question into an estimator route.

    Args:
        question: Natural language question.

    Returns:
        Route dict with route/source/verify_cmd/data_dir/matched_pattern;
        the "unknown" route (all other fields None) when nothing matches.
    """
    normalized = question.lower().strip()

    for route_spec in ROUTE_PATTERNS:
        # First pattern that hits wins for this route.
        hit = next(
            (p for p in route_spec["patterns"] if re.search(p, normalized)),
            None,
        )
        if hit is not None:
            return {
                "route": route_spec["name"],
                "source": route_spec["source"],
                "verify_cmd": route_spec["verify_cmd"],
                "data_dir": route_spec["data_dir"],
                "matched_pattern": hit,
            }

    return {
        "route": "unknown",
        "source": None,
        "verify_cmd": None,
        "data_dir": None,
        "matched_pattern": None,
    }
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def extract_multiplier(question: str) -> float | None:
|
|
142
|
+
"""Extract a data multiplier from the question (e.g., '2x' -> 2.0)."""
|
|
143
|
+
match = re.search(r"(\d+(?:\.\d+)?)\s*x\s+(?:more|the)?\s*(?:data|samples)", question.lower())
|
|
144
|
+
if match:
|
|
145
|
+
return float(match.group(1))
|
|
146
|
+
|
|
147
|
+
if re.search(r"(?:double|twice)\s+(?:the\s+)?(?:data|samples)", question.lower()):
|
|
148
|
+
return 2.0
|
|
149
|
+
if re.search(r"(?:triple)\s+(?:the\s+)?(?:data|samples)", question.lower()):
|
|
150
|
+
return 3.0
|
|
151
|
+
if re.search(r"(?:half)\s+(?:the\s+)?(?:data|samples)", question.lower()):
|
|
152
|
+
return 0.5
|
|
153
|
+
|
|
154
|
+
return None
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def extract_experiment_ids(question: str) -> list[str]:
    """Return the numeric suffix of every `exp-<digits>` mentioned in the question."""
    return [m.group(1) for m in re.finditer(r"exp-(\d+)", question.lower())]
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def extract_target_value(question: str) -> float | None:
|
|
163
|
+
"""Extract a target numeric value (e.g., sparsity percentage, hyperparameter value)."""
|
|
164
|
+
match = re.search(r"(\d+(?:\.\d+)?)\s*%", question)
|
|
165
|
+
if match:
|
|
166
|
+
return float(match.group(1))
|
|
167
|
+
return None
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# --- Estimators ---
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def estimate_scaling(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    scaling_dir: str = "experiments/scaling",
) -> dict:
    """Estimate metric change from data scaling.

    Uses existing scaling law data if available, otherwise extrapolates
    from experiment history.

    Args:
        question: Natural language what-if question (must contain a multiplier).
        experiments: Parsed experiment log entries.
        primary_metric: Metric name to read from each experiment.
        scaling_dir: Directory holding scaling-law fit YAML files.

    Returns:
        Estimate dict (estimate/current/delta/confidence/...), or an error dict.
    """
    multiplier = extract_multiplier(question)
    if multiplier is None:
        return {"error": "Could not parse data multiplier from question"}

    # Try loading existing scaling data; keep only files carrying a "fit".
    scaling_path = Path(scaling_dir)
    scaling_results = []
    if scaling_path.exists():
        for f in scaling_path.glob("*.yaml"):
            with open(f) as fh:
                data = yaml.safe_load(fh)
            if isinstance(data, dict) and "fit" in data:
                scaling_results.append(data)

    if scaling_results:
        # Use the most recent scaling fit: metric ~ a * N^b + c
        fit = scaling_results[-1].get("fit", {})
        a = fit.get("a", 0)
        b = fit.get("b", 0)
        c = fit.get("c", 0)
        r_squared = fit.get("r_squared", 0)

        current_metric = _get_best_metric(experiments, primary_metric)
        predicted = a * (multiplier ** b) + c

        confidence = _r_squared_to_confidence(r_squared)
        return {
            "estimate": round(predicted, 4),
            "current": current_metric,
            # BUG FIX: use `is not None` — the old truthiness test suppressed
            # the delta when the best metric was legitimately 0.0.
            "delta": round(predicted - current_metric, 4) if current_metric is not None else None,
            "confidence": confidence,
            "confidence_detail": f"R²={r_squared:.3f} on scaling curve",
            "multiplier": multiplier,
        }

    # Fallback: rough extrapolation from experiment count
    current_metric = _get_best_metric(experiments, primary_metric)
    if current_metric is None:
        return {"error": "No experiments with metrics found"}

    # Conservative log-based estimate: ~1% of current metric per data doubling.
    import math
    delta_estimate = current_metric * 0.01 * math.log2(multiplier)
    predicted = current_metric + delta_estimate

    return {
        "estimate": round(predicted, 4),
        "current": current_metric,
        "delta": round(delta_estimate, 4),
        "confidence": "LOW",
        "confidence_detail": "No scaling data — using conservative log extrapolation",
        "multiplier": multiplier,
    }
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def estimate_ablation(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    ablation_dir: str = "experiments/ablations",
) -> dict:
    """Estimate impact of removing a class/feature from ablation data.

    Args:
        question: Natural language what-if question naming a component.
        experiments: Parsed experiment log entries.
        primary_metric: Metric name to read from each experiment.
        ablation_dir: Directory holding ablation study YAML files.

    Returns:
        Estimate dict for a matched component, a LOW-confidence summary when no
        component matches, or an error dict when no data exists.
    """
    ablation_path = Path(ablation_dir)
    if not ablation_path.exists() or not list(ablation_path.glob("*.yaml")):
        return {"error": "No ablation data available. Run `/turing:ablate` first."}

    # Load the most recent ablation study (lexicographic sort of timestamped names).
    ablation_files = sorted(ablation_path.glob("*.yaml"))
    with open(ablation_files[-1]) as f:
        ablation_data = yaml.safe_load(f)

    if not isinstance(ablation_data, dict):
        return {"error": "Malformed ablation data"}

    components = ablation_data.get("components", ablation_data.get("results", []))
    if not components:
        return {"error": "No component ablation results found"}

    current_metric = _get_best_metric(experiments, primary_metric)

    # Find the component mentioned in the question
    q_lower = question.lower()
    matched = None
    for comp in components if isinstance(components, list) else []:
        # ROBUSTNESS: a null component/name value in the YAML no longer crashes.
        name = (comp.get("component", comp.get("name", "")) or "").lower()
        if name and name in q_lower:
            matched = comp
            break

    if matched:
        impact = matched.get("impact", matched.get("metric_delta", 0))
        return {
            # BUG FIX: `is not None` — a legitimate 0.0 best metric previously
            # suppressed the estimate.
            "estimate": round(current_metric + impact, 4) if current_metric is not None else None,
            "current": current_metric,
            "delta": round(impact, 4),
            "component": matched.get("component", matched.get("name")),
            "confidence": "HIGH",
            "confidence_detail": "Direct ablation data available",
        }

    # General summary if no specific match
    comp_list = components if isinstance(components, list) else []
    avg_impact = sum(
        abs(c.get("impact", c.get("metric_delta", 0))) for c in comp_list
    ) / max(len(components) if isinstance(components, list) else 1, 1)

    return {
        "estimate": None,
        "current": current_metric,
        "delta": None,
        "confidence": "LOW",
        "confidence_detail": f"No exact match — average component impact is ±{avg_impact:.4f}",
        "available_components": [
            c.get("component", c.get("name")) for c in comp_list
        ],
    }
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
def estimate_sensitivity(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    sensitivity_dir: str = "experiments/sensitivity",
) -> dict:
    """Estimate hyperparameter change impact from prior sensitivity sweeps."""
    sweep_dir = Path(sensitivity_dir)
    if not sweep_dir.exists() or not list(sweep_dir.glob("*.yaml")):
        return {"error": "No sensitivity data available. Run `/turing:sensitivity` first."}

    # Most recent study wins (lexicographic sort of timestamped filenames).
    latest = sorted(sweep_dir.glob("*.yaml"))[-1]
    with open(latest) as handle:
        study = yaml.safe_load(handle)

    if not isinstance(study, dict):
        return {"error": "Malformed sensitivity data"}

    sensitivities = study.get("sensitivities", [])
    best_metric = _get_best_metric(experiments, primary_metric)

    # Compare underscore-insensitively so "batch size" matches "batch_size".
    normalized_q = question.lower().replace("_", " ")
    for entry in sensitivities:
        param = entry.get("param", "").lower()
        if param and param.replace("_", " ") in normalized_q:
            return {
                "estimate": None,
                "current": best_metric,
                "param": entry.get("param"),
                "sensitivity_level": entry.get("level", "UNKNOWN"),
                "metric_range": [entry.get("metric_min"), entry.get("metric_max")],
                "best_value": entry.get("best_value"),
                "confidence": "MED" if entry.get("level") in ("HIGH", "MED") else "LOW",
                "confidence_detail": f"Sensitivity level: {entry.get('level')}, range: {entry.get('metric_min')}-{entry.get('metric_max')}",
            }

    return {
        "error": "Parameter not found in sensitivity data",
        "available_params": [s.get("param") for s in sensitivities],
    }
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def estimate_ensemble(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    ensemble_dir: str = "experiments/ensembles",
) -> dict:
    """Estimate ensemble improvement from prior ensemble data.

    Prefers a recorded `/turing:ensemble` result; otherwise falls back to a
    conservative rule of thumb (~1.5% relative improvement over the current best).
    """
    ens_path = Path(ensemble_dir)
    current_metric = _get_best_metric(experiments, primary_metric)

    if ens_path.exists() and list(ens_path.glob("*.yaml")):
        # Most recent ensemble report (lexicographic sort of timestamped names).
        ens_files = sorted(ens_path.glob("*.yaml"))
        with open(ens_files[-1]) as f:
            ens_data = yaml.safe_load(f)
        if isinstance(ens_data, dict):
            best_method = ens_data.get("best_method", {})
            ens_metric = best_method.get("metric", best_method.get(primary_metric))
            if ens_metric is not None:
                return {
                    "estimate": round(ens_metric, 4),
                    "current": current_metric,
                    # BUG FIX: `is not None` — a 0.0 current metric previously
                    # suppressed the delta.
                    "delta": round(ens_metric - current_metric, 4) if current_metric is not None else None,
                    "method": best_method.get("method", "unknown"),
                    "confidence": "HIGH",
                    "confidence_detail": "Prior ensemble result available",
                }

    # Conservative estimate: ensembles typically improve 1-3%
    if current_metric is not None:
        delta = current_metric * 0.015
        return {
            "estimate": round(current_metric + delta, 4),
            "current": current_metric,
            "delta": round(delta, 4),
            "confidence": "LOW",
            "confidence_detail": "No prior ensemble data — using typical 1.5% improvement estimate",
        }

    return {"error": "No experiments or ensemble data available"}
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def estimate_pruning(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    pruning_dir: str = "experiments/pruning",
) -> dict:
    """Estimate pruning impact by interpolating a prior pruning sweep.

    Args:
        question: What-if question, ideally containing a target sparsity like "50%".
        experiments: Parsed experiment log entries.
        primary_metric: Metric name to read from each experiment.
        pruning_dir: Directory holding pruning sweep YAML files.

    Returns:
        Interpolated estimate, or a LOW-confidence dict listing available
        sparsities, or an error dict when no data exists.
    """
    target_sparsity = extract_target_value(question)
    prune_path = Path(pruning_dir)

    if not prune_path.exists() or not list(prune_path.glob("*.yaml")):
        return {"error": "No pruning data available. Run `/turing:prune` first."}

    # Most recent sweep (lexicographic sort of timestamped filenames).
    prune_files = sorted(prune_path.glob("*.yaml"))
    with open(prune_files[-1]) as f:
        prune_data = yaml.safe_load(f)

    if not isinstance(prune_data, dict):
        return {"error": "Malformed pruning data"}

    current_metric = _get_best_metric(experiments, primary_metric)
    sweep = prune_data.get("sweep_results", prune_data.get("results", []))

    if target_sparsity is not None and sweep:
        # Interpolate the metric at the requested sparsity from the sweep points.
        sparsities = [s.get("sparsity", 0) for s in sweep]
        metrics = [s.get("metric", s.get(primary_metric, 0)) for s in sweep]

        if sparsities and metrics:
            predicted = _linear_interpolate(sparsities, metrics, target_sparsity)
            if predicted is not None:
                return {
                    "estimate": round(predicted, 4),
                    "current": current_metric,
                    # BUG FIX: `is not None` — a 0.0 current metric previously
                    # suppressed the delta.
                    "delta": round(predicted - current_metric, 4) if current_metric is not None else None,
                    "target_sparsity": target_sparsity,
                    "confidence": "MED",
                    "confidence_detail": f"Interpolated from {len(sweep)} pruning sweep points",
                }

    # Return available data
    return {
        "estimate": None,
        "current": current_metric,
        "confidence": "LOW",
        "confidence_detail": "Could not interpolate — check pruning sweep data",
        "available_sparsities": [s.get("sparsity") for s in sweep] if sweep else [],
    }
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def estimate_stitch(
    question: str,
    experiments: list[dict],
    primary_metric: str,
) -> dict:
    """Estimate pipeline stitch impact.

    Looks up the metrics of the experiments named in the question and produces
    a conservative combined estimate: the better of the two plus 30% of their
    gap. No new experiments are run.
    """
    exp_ids = extract_experiment_ids(question)
    if len(exp_ids) < 2:
        return {"error": "Need at least 2 experiment IDs to estimate stitch (e.g., exp-031 and exp-042)"}

    current_metric = _get_best_metric(experiments, primary_metric)

    # Look up the referenced experiments
    ref_metrics = []
    for eid in exp_ids:
        full_id = f"exp-{eid}"
        for exp in experiments:
            if exp.get("experiment_id") == full_id:
                m = exp.get("metrics", {}).get(primary_metric)
                if m is not None:
                    ref_metrics.append({"id": full_id, "metric": m})
                break

    if len(ref_metrics) < 2:
        return {
            # FIX: was a pointless f-string with no placeholders.
            "error": "Could not find metrics for referenced experiments",
            "found": ref_metrics,
        }

    # Conservative estimate: best of the two + small bonus
    best = max(m["metric"] for m in ref_metrics)
    delta = abs(ref_metrics[0]["metric"] - ref_metrics[1]["metric"]) * 0.3
    predicted = best + delta

    return {
        "estimate": round(predicted, 4),
        "current": current_metric,
        # BUG FIX: `is not None` — a 0.0 current metric previously suppressed the delta.
        "delta": round(predicted - current_metric, 4) if current_metric is not None else None,
        "experiments": ref_metrics,
        "confidence": "LOW",
        "confidence_detail": "Estimated from individual metrics — actual stitch may differ",
    }
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def estimate_budget(
    question: str,
    experiments: list[dict],
    primary_metric: str,
) -> dict:
    """Summarize budget context; a point estimate requires `/turing:simulate`."""
    best_so_far = _get_best_metric(experiments, primary_metric)
    n_total = len(experiments)
    n_kept = len([e for e in experiments if e.get("status") == "kept"])

    return {
        "estimate": None,
        "current": best_so_far,
        "total_experiments": n_total,
        "kept_ratio": round(n_kept / n_total, 2) if n_total > 0 else 0,
        "confidence": "LOW",
        "confidence_detail": "Budget allocation requires simulation — use `/turing:simulate` for prediction",
    }
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
# --- Helpers ---
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
def _get_best_metric(experiments: list[dict], metric: str) -> float | None:
|
|
503
|
+
"""Get the best metric value from all experiments."""
|
|
504
|
+
values = []
|
|
505
|
+
for exp in experiments:
|
|
506
|
+
v = exp.get("metrics", {}).get(metric)
|
|
507
|
+
if v is not None:
|
|
508
|
+
values.append(v)
|
|
509
|
+
return max(values) if values else None
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def _r_squared_to_confidence(r_squared: float) -> str:
|
|
513
|
+
"""Map R² to confidence level."""
|
|
514
|
+
if r_squared >= 0.95:
|
|
515
|
+
return "HIGH"
|
|
516
|
+
elif r_squared >= 0.80:
|
|
517
|
+
return "MED"
|
|
518
|
+
else:
|
|
519
|
+
return "LOW"
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
def _linear_interpolate(xs: list[float], ys: list[float], target_x: float) -> float | None:
|
|
523
|
+
"""Simple linear interpolation."""
|
|
524
|
+
if not xs or not ys or len(xs) != len(ys):
|
|
525
|
+
return None
|
|
526
|
+
|
|
527
|
+
pairs = sorted(zip(xs, ys))
|
|
528
|
+
xs_sorted = [p[0] for p in pairs]
|
|
529
|
+
ys_sorted = [p[1] for p in pairs]
|
|
530
|
+
|
|
531
|
+
if target_x <= xs_sorted[0]:
|
|
532
|
+
return ys_sorted[0]
|
|
533
|
+
if target_x >= xs_sorted[-1]:
|
|
534
|
+
return ys_sorted[-1]
|
|
535
|
+
|
|
536
|
+
for i in range(len(xs_sorted) - 1):
|
|
537
|
+
if xs_sorted[i] <= target_x <= xs_sorted[i + 1]:
|
|
538
|
+
t = (target_x - xs_sorted[i]) / (xs_sorted[i + 1] - xs_sorted[i])
|
|
539
|
+
return ys_sorted[i] + t * (ys_sorted[i + 1] - ys_sorted[i])
|
|
540
|
+
|
|
541
|
+
return None
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
# --- Main Pipeline ---
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
# Dispatch table: route name (from classify_question) -> estimator function.
# All estimators share the (question, experiments, primary_metric) interface;
# most also accept a *_dir keyword that whatif_analysis supplies from the
# route's data_dir.
ESTIMATOR_MAP = {
    "scaling": estimate_scaling,
    "ablation": estimate_ablation,
    "sensitivity": estimate_sensitivity,
    "ensemble": estimate_ensemble,
    "pruning": estimate_pruning,
    "stitch": estimate_stitch,
    "budget": estimate_budget,
}
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def whatif_analysis(
    question: str,
    config_path: str = "config.yaml",
    log_path: str = DEFAULT_LOG_PATH,
) -> dict:
    """Run what-if analysis for a hypothetical question.

    Args:
        question: Natural language what-if question.
        config_path: Path to config.yaml.
        log_path: Path to experiment log.

    Returns:
        What-if analysis result with estimate, confidence, and recommendation.
    """
    config = load_config(config_path)
    eval_cfg = config.get("evaluation", {})
    primary_metric = eval_cfg.get("primary_metric", "accuracy")

    experiments = load_experiments(log_path)

    classification = classify_question(question)
    route = classification["route"]

    if route == "unknown":
        return {
            "question": question,
            "route": "unknown",
            "error": "Cannot classify question — no matching estimator found",
            "available_routes": [r["name"] for r in ROUTE_PATTERNS],
            "suggestion": "Try phrasing as: 'what if I had Nx more data', "
            "'what if I removed <component>', "
            "'what if I changed <hyperparameter>'",
            "generated_at": datetime.now(timezone.utc).isoformat(),
        }

    estimator = ESTIMATOR_MAP.get(route)
    if estimator is None:
        return {"error": f"No estimator for route: {route}"}

    # Common keyword arguments shared by every estimator.
    kwargs = {
        "question": question,
        "experiments": experiments,
        "primary_metric": primary_metric,
    }

    # Pass the route's data directory under the estimator's keyword name.
    # SIMPLIFIED: the original checked route membership twice against the
    # same set of names; one dict lookup is equivalent.
    dir_param_names = {
        "scaling": "scaling_dir",
        "ablation": "ablation_dir",
        "sensitivity": "sensitivity_dir",
        "ensemble": "ensemble_dir",
        "pruning": "pruning_dir",
    }
    data_dir = classification.get("data_dir")
    dir_param = dir_param_names.get(route)
    if data_dir and dir_param:
        kwargs[dir_param] = data_dir

    result = estimator(**kwargs)

    # Build recommendation
    recommendation = _generate_recommendation(result, route, classification)

    return {
        "question": question,
        "route": route,
        "source": classification["source"],
        "verify_cmd": classification["verify_cmd"],
        "result": result,
        "recommendation": recommendation,
        "generated_at": datetime.now(timezone.utc).isoformat(),
    }
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def _generate_recommendation(result: dict, route: str, classification: dict) -> str:
|
|
635
|
+
"""Generate a human-readable recommendation from the result."""
|
|
636
|
+
if "error" in result:
|
|
637
|
+
return f"Cannot estimate — {result['error']}"
|
|
638
|
+
|
|
639
|
+
estimate = result.get("estimate")
|
|
640
|
+
current = result.get("current")
|
|
641
|
+
delta = result.get("delta")
|
|
642
|
+
confidence = result.get("confidence", "UNKNOWN")
|
|
643
|
+
|
|
644
|
+
if estimate is None or current is None:
|
|
645
|
+
return f"Insufficient data for a point estimate. Run `{classification.get('verify_cmd', 'the source command')}` first."
|
|
646
|
+
|
|
647
|
+
if delta is not None:
|
|
648
|
+
if abs(delta) < 0.001:
|
|
649
|
+
return "Marginal gain. Likely not worth the effort."
|
|
650
|
+
elif delta > 0.01:
|
|
651
|
+
return f"Promising (+{delta:.4f}). Worth investigating further."
|
|
652
|
+
elif delta > 0:
|
|
653
|
+
return f"Small gain (+{delta:.4f}). Consider other approaches first."
|
|
654
|
+
else:
|
|
655
|
+
return f"Negative impact ({delta:.4f}). Avoid this direction."
|
|
656
|
+
|
|
657
|
+
return f"Estimate available ({confidence} confidence). Verify with `{classification.get('verify_cmd')}`."
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
# --- Report Formatting ---
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def save_whatif_report(report: dict, output_dir: str = "experiments/whatif") -> Path:
    """Write the report as a timestamped YAML file and return its path."""
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    report_file = target_dir / f"whatif-{stamp}.yaml"
    with open(report_file, "w") as handle:
        yaml.dump(report, handle, default_flow_style=False, sort_keys=False)
    return report_file
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def format_whatif_report(report: dict) -> str:
    """Render a what-if report dict as human-readable markdown."""
    # A top-level error with no estimator result is a terse one-liner.
    if "error" in report and "result" not in report:
        return f"ERROR: {report['error']}"

    out = ["# What-If Analysis", ""]
    out.append(f"**Question:** {report.get('question', 'N/A')}")
    out.append(f"**Route:** {report.get('route', 'unknown')}")
    out.append(f"**Source:** {report.get('source', 'N/A')}")
    out.append("")

    result = report.get("result", {})

    if "error" in result:
        out.append(f"**Error:** {result['error']}")
        if "available_routes" in report:
            out.append("")
            out.append("Available routes: " + ", ".join(report["available_routes"]))
        if "suggestion" in report:
            out.append("")
            out.append(f"**Suggestion:** {report['suggestion']}")
    else:
        for label, key in (("Current best", "current"), ("Estimated", "estimate")):
            value = result.get(key)
            if value is not None:
                out.append(f"**{label}:** {value}")
        delta = result.get("delta")
        if delta is not None:
            prefix = "+" if delta >= 0 else ""
            out.append(f"**Delta:** {prefix}{delta}")
        out.append(f"**Confidence:** {result.get('confidence', 'UNKNOWN')}")

    detail = result.get("confidence_detail")
    if detail:
        out.append(f"**Detail:** {detail}")

    out.append("")
    recommendation = report.get("recommendation")
    if recommendation:
        out.append(f"**Recommendation:** {recommendation}")

    verify = report.get("verify_cmd")
    if verify:
        out.append(f"**To verify:** run `{verify}`")

    out.append("")
    out.append(f"*Generated: {report.get('generated_at', 'N/A')}*")
    return "\n".join(out)
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
# --- CLI ---
|
|
729
|
+
|
|
730
|
+
|
|
731
|
+
def main():
    """CLI entry point: parse arguments, run the analysis, print and save the report."""
    parser = argparse.ArgumentParser(
        description="What-if analysis engine — answer hypotheticals from existing data"
    )
    parser.add_argument("question", nargs="?", help="What-if question to analyze")
    parser.add_argument("--config", default="config.yaml", help="Path to config.yaml")
    parser.add_argument("--log", default=DEFAULT_LOG_PATH, help="Path to experiment log")
    parser.add_argument("--json", action="store_true", help="Output raw JSON")

    args = parser.parse_args()

    # The positional is optional in argparse terms, but the engine needs one.
    if not args.question:
        parser.error("Please provide a what-if question")

    report = whatif_analysis(
        question=args.question,
        config_path=args.config,
        log_path=args.log,
    )

    rendered = json.dumps(report, indent=2) if args.json else format_whatif_report(report)
    print(rendered)

    # Persist the report regardless of output mode; only announce the path
    # in human-readable mode so --json stays machine-parseable.
    saved = save_whatif_report(report)
    if not args.json:
        print(f"\nSaved: {saved}")
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
# Script entry point (see the module docstring for CLI usage examples).
if __name__ == "__main__":
    main()
|