claude-turing 4.1.0 → 4.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,763 @@
1
+ #!/usr/bin/env python3
2
+ """What-if analysis engine for the autoresearch pipeline.
3
+
4
+ Routes hypothetical questions to existing estimators: scaling laws,
5
+ ablation, sensitivity, ensemble, pruning, and stitch. Returns an
6
+ estimate with confidence without running new experiments.
7
+
8
+ Usage:
9
+ python scripts/whatif_engine.py "what if I had 2x more data"
10
+ python scripts/whatif_engine.py "what if I removed class 3"
11
+ python scripts/whatif_engine.py "what if I combined exp-031 with exp-042"
12
+ python scripts/whatif_engine.py --json
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
import argparse
import json
import math
import re
import sys
from datetime import datetime, timezone
from pathlib import Path

import yaml

from scripts.turing_io import load_config, load_experiments
27
+
28
+ DEFAULT_LOG_PATH = "experiments/log.jsonl"
29
+
30
# Question patterns mapped to estimator routes.
# Routes are tried in order; the first pattern hit wins.
ROUTE_PATTERNS = [
    {
        "name": "scaling",
        "patterns": [
            r"(?:more|less|2x|3x|4x|5x|10x|half|double|triple)\s+(?:data|samples|training data|examples)",
            r"(?:data|dataset)\s+(?:size|scaling|increase|decrease)",
            r"(?:scale|scaled)\s+(?:up|down)\s+(?:data|the data)",
        ],
        "source": "scaling law fit from `/turing:scale`",
        "verify_cmd": "/turing:scale --extrapolate",
        "data_dir": "experiments/scaling",
    },
    {
        "name": "ablation",
        "patterns": [
            r"(?:remov\w*|drop\w*|exclud\w*|without|ablat\w*)\s+(?:class|feature|component|column|variable)",
            r"(?:class|feature|component)\s+(?:\d+|[A-Za-z_]+)\s+(?:remov\w*|drop\w*|exclud\w*)",
        ],
        "source": "ablation study from `/turing:ablate`",
        "verify_cmd": "/turing:ablate",
        "data_dir": "experiments/ablations",
    },
    {
        "name": "stitch",
        "patterns": [
            # \w* suffixes so past-tense forms ("combined", "swapped") match,
            # matching the module docstring's usage example.
            r"(?:combin\w*|stitch\w*|swap\w*|mix\w*)\s+.*(?:from|of)\s+exp-\d+\s+(?:with|and)\s+.*exp-\d+",
            # Two exp-IDs joined by with/and also implies a stitch, even
            # without a "from/of" (e.g. "combined exp-031 with exp-042").
            r"(?:combin\w*|stitch\w*|swap\w*|mix\w*)\s+.*exp-\d+\s+(?:with|and)\s+.*exp-\d+",
            r"(?:pipeline|stage)\s+(?:from|of)\s+exp-\d+",
        ],
        "source": "pipeline composition from `/turing:stitch`",
        "verify_cmd": "/turing:stitch",
        "data_dir": "experiments/cache",
    },
    {
        "name": "sensitivity",
        "patterns": [
            r"(?:different|change|modify|adjust|set)\s+(?:hyperparameter|learning.?rate|lr|depth|estimators|epochs|batch.?size)",
            r"(?:learning.?rate|lr|depth|max_depth|n_estimators|epochs|batch.?size)\s+(?:was|were|to|=|of)\s+[\d.]+",
        ],
        "source": "sensitivity interpolation from `/turing:sensitivity`",
        "verify_cmd": "/turing:sensitivity",
        "data_dir": "experiments/sensitivity",
    },
    {
        "name": "ensemble",
        "patterns": [
            r"(?:ensembl\w*|combine|blend\w*|stack\w*|vot\w*)\s+(?:(?:these|the|top|best)\s+)*(?:models|experiments)",
            r"(?:voting|stacking|blending)\s+(?:of|with|from)",
        ],
        "source": "prediction correlation from `/turing:ensemble`",
        "verify_cmd": "/turing:ensemble",
        "data_dir": "experiments/ensembles",
    },
    {
        "name": "pruning",
        "patterns": [
            r"(?:prune|pruning|sparsity)\s+(?:to|at|of)?\s*\d+",
            r"\d+%?\s*(?:sparsity|sparse|pruned)",
        ],
        "source": "pruning sweep interpolation from `/turing:prune`",
        "verify_cmd": "/turing:prune",
        "data_dir": "experiments/pruning",
    },
    {
        "name": "budget",
        "patterns": [
            r"(?:spend|spent|budget|allocate|invest)\s+.*(?:budget|remaining).*(?:on|in|for)",
            r"(?:remaining|left)\s+(?:budget|experiments|compute)",
            r"(?:budget)\s+(?:on|for|between)",
        ],
        "source": "budget allocation from `/turing:budget`",
        "verify_cmd": "/turing:budget",
        "data_dir": None,
    },
]


# --- Question Parsing ---


def classify_question(question: str) -> dict:
    """Classify a what-if question into a route.

    Args:
        question: Natural language question.

    Returns:
        Route dict with name, source, verify_cmd, data_dir, and the
        pattern that matched, or the "unknown" route when nothing matches.
    """
    q_lower = question.lower().strip()

    for route in ROUTE_PATTERNS:
        for pattern in route["patterns"]:
            if re.search(pattern, q_lower):
                return {
                    "route": route["name"],
                    "source": route["source"],
                    "verify_cmd": route["verify_cmd"],
                    "data_dir": route["data_dir"],
                    "matched_pattern": pattern,
                }

    return {
        "route": "unknown",
        "source": None,
        "verify_cmd": None,
        "data_dir": None,
        "matched_pattern": None,
    }
139
+
140
+
141
+ def extract_multiplier(question: str) -> float | None:
142
+ """Extract a data multiplier from the question (e.g., '2x' -> 2.0)."""
143
+ match = re.search(r"(\d+(?:\.\d+)?)\s*x\s+(?:more|the)?\s*(?:data|samples)", question.lower())
144
+ if match:
145
+ return float(match.group(1))
146
+
147
+ if re.search(r"(?:double|twice)\s+(?:the\s+)?(?:data|samples)", question.lower()):
148
+ return 2.0
149
+ if re.search(r"(?:triple)\s+(?:the\s+)?(?:data|samples)", question.lower()):
150
+ return 3.0
151
+ if re.search(r"(?:half)\s+(?:the\s+)?(?:data|samples)", question.lower()):
152
+ return 0.5
153
+
154
+ return None
155
+
156
+
157
def extract_experiment_ids(question: str) -> list[str]:
    """Return the numeric parts of every exp-NNN reference in the question."""
    exp_ref = re.compile(r"exp-(\d+)")
    return exp_ref.findall(question.lower())
160
+
161
+
162
+ def extract_target_value(question: str) -> float | None:
163
+ """Extract a target numeric value (e.g., sparsity percentage, hyperparameter value)."""
164
+ match = re.search(r"(\d+(?:\.\d+)?)\s*%", question)
165
+ if match:
166
+ return float(match.group(1))
167
+ return None
168
+
169
+
170
+ # --- Estimators ---
171
+
172
+
173
def estimate_scaling(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    scaling_dir: str = "experiments/scaling",
) -> dict:
    """Estimate metric change from data scaling.

    Uses existing scaling law data if available, otherwise extrapolates
    from experiment history.

    Args:
        question: Natural language what-if question.
        experiments: Experiment records from the log.
        primary_metric: Metric key read from each experiment's metrics.
        scaling_dir: Directory holding scaling-fit YAML files.

    Returns:
        Estimate dict with estimate/current/delta/confidence keys, or a
        dict with an ``error`` key when no estimate is possible.
    """
    multiplier = extract_multiplier(question)
    if multiplier is None:
        return {"error": "Could not parse data multiplier from question"}

    # Try loading existing scaling data.
    scaling_path = Path(scaling_dir)
    scaling_results = []
    if scaling_path.exists():
        # Sort the glob so "[-1]" below is deterministic — glob order is
        # filesystem-dependent, so the original "most recent" was arbitrary.
        for f in sorted(scaling_path.glob("*.yaml")):
            with open(f) as fh:
                data = yaml.safe_load(fh)
            if isinstance(data, dict) and "fit" in data:
                scaling_results.append(data)

    if scaling_results:
        # Use the most recent scaling fit: metric ~ a * multiplier**b + c.
        fit = scaling_results[-1].get("fit", {})
        a = fit.get("a", 0)
        b = fit.get("b", 0)
        c = fit.get("c", 0)
        r_squared = fit.get("r_squared", 0)

        current_metric = _get_best_metric(experiments, primary_metric)
        predicted = a * (multiplier ** b) + c

        confidence = _r_squared_to_confidence(r_squared)
        return {
            "estimate": round(predicted, 4),
            "current": current_metric,
            # `is not None` so a legitimate 0.0 metric still yields a delta
            # (the original falsy check dropped it).
            "delta": round(predicted - current_metric, 4) if current_metric is not None else None,
            "confidence": confidence,
            "confidence_detail": f"R²={r_squared:.3f} on scaling curve",
            "multiplier": multiplier,
        }

    # Fallback: rough extrapolation from experiment history.
    current_metric = _get_best_metric(experiments, primary_metric)
    if current_metric is None:
        return {"error": "No experiments with metrics found"}

    # Conservative log-based estimate: ~1% of current metric per doubling.
    delta_estimate = current_metric * 0.01 * math.log2(multiplier)
    predicted = current_metric + delta_estimate

    return {
        "estimate": round(predicted, 4),
        "current": current_metric,
        "delta": round(delta_estimate, 4),
        "confidence": "LOW",
        "confidence_detail": "No scaling data — using conservative log extrapolation",
        "multiplier": multiplier,
    }
237
+
238
+
239
def estimate_ablation(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    ablation_dir: str = "experiments/ablations",
) -> dict:
    """Estimate impact of removing a class/feature from ablation data.

    Args:
        question: Natural language what-if question.
        experiments: Experiment records from the log.
        primary_metric: Metric key read from each experiment's metrics.
        ablation_dir: Directory holding ablation study YAML files.

    Returns:
        Estimate dict keyed by estimate/current/delta/confidence, or a
        dict with an ``error`` key when no ablation data exists.
    """
    ablation_path = Path(ablation_dir)
    ablation_files = sorted(ablation_path.glob("*.yaml")) if ablation_path.exists() else []
    if not ablation_files:
        return {"error": "No ablation data available. Run `/turing:ablate` first."}

    # Load the most recent ablation study (sorted, so [-1] is newest by name).
    with open(ablation_files[-1]) as f:
        ablation_data = yaml.safe_load(f)

    if not isinstance(ablation_data, dict):
        return {"error": "Malformed ablation data"}

    components = ablation_data.get("components", ablation_data.get("results", []))
    if not components:
        return {"error": "No component ablation results found"}

    # Normalize once so every loop below is safe for non-list payloads
    # (the original re-checked isinstance three separate times).
    component_list = components if isinstance(components, list) else []

    current_metric = _get_best_metric(experiments, primary_metric)

    # Find the component mentioned in the question.
    q_lower = question.lower()
    matched = None
    for comp in component_list:
        name = comp.get("component", comp.get("name", "")).lower()
        if name and name in q_lower:
            matched = comp
            break

    if matched:
        impact = matched.get("impact", matched.get("metric_delta", 0))
        return {
            # `is not None` so a legitimate 0.0 baseline still gets an
            # estimate (the original falsy check returned None).
            "estimate": round(current_metric + impact, 4) if current_metric is not None else None,
            "current": current_metric,
            "delta": round(impact, 4),
            "component": matched.get("component", matched.get("name")),
            "confidence": "HIGH",
            "confidence_detail": "Direct ablation data available",
        }

    # General summary if no specific component matched.
    avg_impact = sum(
        abs(c.get("impact", c.get("metric_delta", 0))) for c in component_list
    ) / max(len(component_list), 1)

    return {
        "estimate": None,
        "current": current_metric,
        "delta": None,
        "confidence": "LOW",
        "confidence_detail": f"No exact match — average component impact is ±{avg_impact:.4f}",
        "available_components": [
            c.get("component", c.get("name")) for c in component_list
        ],
    }
300
+
301
+
302
def estimate_sensitivity(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    sensitivity_dir: str = "experiments/sensitivity",
) -> dict:
    """Estimate hyperparameter change impact from sensitivity data."""
    sens_path = Path(sensitivity_dir)
    sens_files = sorted(sens_path.glob("*.yaml")) if sens_path.exists() else []
    if not sens_files:
        return {"error": "No sensitivity data available. Run `/turing:sensitivity` first."}

    # Latest study wins; files sort by name.
    with open(sens_files[-1]) as handle:
        sens_data = yaml.safe_load(handle)

    if not isinstance(sens_data, dict):
        return {"error": "Malformed sensitivity data"}

    sensitivities = sens_data.get("sensitivities", [])
    current_metric = _get_best_metric(experiments, primary_metric)

    # Compare with underscores normalized to spaces on both sides.
    normalized_question = question.lower().replace("_", " ")
    for sens in sensitivities:
        param = sens.get("param", "").lower()
        if not param or param.replace("_", " ") not in normalized_question:
            continue
        return {
            "estimate": None,
            "current": current_metric,
            "param": sens.get("param"),
            "sensitivity_level": sens.get("level", "UNKNOWN"),
            "metric_range": [sens.get("metric_min"), sens.get("metric_max")],
            "best_value": sens.get("best_value"),
            "confidence": "MED" if sens.get("level") in ("HIGH", "MED") else "LOW",
            "confidence_detail": f"Sensitivity level: {sens.get('level')}, range: {sens.get('metric_min')}-{sens.get('metric_max')}",
        }

    return {
        "error": "Parameter not found in sensitivity data",
        "available_params": [s.get("param") for s in sensitivities],
    }
342
+
343
+
344
def estimate_ensemble(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    ensemble_dir: str = "experiments/ensembles",
) -> dict:
    """Estimate ensemble improvement from prior ensemble data.

    Args:
        question: Natural language what-if question.
        experiments: Experiment records from the log.
        primary_metric: Metric key read from each experiment's metrics.
        ensemble_dir: Directory holding prior ensemble YAML results.

    Returns:
        Estimate dict; falls back to a typical-improvement heuristic when
        no prior ensemble data exists, or an ``error`` dict when there are
        no experiments at all.
    """
    ens_path = Path(ensemble_dir)
    current_metric = _get_best_metric(experiments, primary_metric)

    ens_files = sorted(ens_path.glob("*.yaml")) if ens_path.exists() else []
    if ens_files:
        with open(ens_files[-1]) as f:
            ens_data = yaml.safe_load(f)
        if isinstance(ens_data, dict):
            best_method = ens_data.get("best_method", {})
            ens_metric = best_method.get("metric", best_method.get(primary_metric))
            if ens_metric is not None:
                return {
                    "estimate": round(ens_metric, 4),
                    "current": current_metric,
                    # `is not None` so a legitimate 0.0 baseline still
                    # produces a delta (original falsy check dropped it).
                    "delta": round(ens_metric - current_metric, 4) if current_metric is not None else None,
                    "method": best_method.get("method", "unknown"),
                    "confidence": "HIGH",
                    "confidence_detail": "Prior ensemble result available",
                }

    # Conservative estimate: ensembles typically improve 1-3%.
    if current_metric is not None:
        delta = current_metric * 0.015
        return {
            "estimate": round(current_metric + delta, 4),
            "current": current_metric,
            "delta": round(delta, 4),
            "confidence": "LOW",
            "confidence_detail": "No prior ensemble data — using typical 1.5% improvement estimate",
        }

    return {"error": "No experiments or ensemble data available"}
383
+
384
+
385
def estimate_pruning(
    question: str,
    experiments: list[dict],
    primary_metric: str,
    pruning_dir: str = "experiments/pruning",
) -> dict:
    """Estimate pruning impact from prior pruning sweep data.

    Args:
        question: Natural language what-if question (target sparsity is
            parsed from a percentage, e.g. '50%').
        experiments: Experiment records from the log.
        primary_metric: Metric key read from each experiment's metrics.
        pruning_dir: Directory holding pruning sweep YAML files.

    Returns:
        Estimate dict interpolated from the sweep, or a low-confidence
        summary / ``error`` dict when interpolation is not possible.
    """
    target_sparsity = extract_target_value(question)
    prune_path = Path(pruning_dir)

    prune_files = sorted(prune_path.glob("*.yaml")) if prune_path.exists() else []
    if not prune_files:
        return {"error": "No pruning data available. Run `/turing:prune` first."}

    # Most recent sweep by filename order.
    with open(prune_files[-1]) as f:
        prune_data = yaml.safe_load(f)

    if not isinstance(prune_data, dict):
        return {"error": "Malformed pruning data"}

    current_metric = _get_best_metric(experiments, primary_metric)
    sweep = prune_data.get("sweep_results", prune_data.get("results", []))

    if target_sparsity is not None and sweep:
        # Interpolate the metric at the requested sparsity from the sweep.
        sparsities = [s.get("sparsity", 0) for s in sweep]
        metrics = [s.get("metric", s.get(primary_metric, 0)) for s in sweep]

        if sparsities and metrics:
            predicted = _linear_interpolate(sparsities, metrics, target_sparsity)
            if predicted is not None:
                return {
                    "estimate": round(predicted, 4),
                    "current": current_metric,
                    # `is not None` so a legitimate 0.0 baseline still
                    # produces a delta (original falsy check dropped it).
                    "delta": round(predicted - current_metric, 4) if current_metric is not None else None,
                    "target_sparsity": target_sparsity,
                    "confidence": "MED",
                    "confidence_detail": f"Interpolated from {len(sweep)} pruning sweep points",
                }

    # Fall back to reporting what data is available.
    return {
        "estimate": None,
        "current": current_metric,
        "confidence": "LOW",
        "confidence_detail": "Could not interpolate — check pruning sweep data",
        "available_sparsities": [s.get("sparsity") for s in sweep] if sweep else [],
    }
433
+
434
+
435
def estimate_stitch(
    question: str,
    experiments: list[dict],
    primary_metric: str,
) -> dict:
    """Estimate pipeline stitch impact from the referenced experiments.

    Args:
        question: Question naming at least two exp-NNN experiment IDs.
        experiments: Experiment records from the log.
        primary_metric: Metric key read from each experiment's metrics.

    Returns:
        Estimate dict based on the referenced experiments' metrics, or an
        ``error`` dict when fewer than two experiments can be resolved.
    """
    exp_ids = extract_experiment_ids(question)
    if len(exp_ids) < 2:
        return {"error": "Need at least 2 experiment IDs to estimate stitch (e.g., exp-031 and exp-042)"}

    current_metric = _get_best_metric(experiments, primary_metric)

    # Look up the referenced experiments in the log.
    ref_metrics = []
    for eid in exp_ids:
        full_id = f"exp-{eid}"
        for exp in experiments:
            if exp.get("experiment_id") == full_id:
                m = exp.get("metrics", {}).get(primary_metric)
                if m is not None:
                    ref_metrics.append({"id": full_id, "metric": m})
                break

    if len(ref_metrics) < 2:
        return {
            # Was an f-string with no placeholders — plain literal now.
            "error": "Could not find metrics for referenced experiments",
            "found": ref_metrics,
        }

    # Conservative estimate: best of the pair plus 30% of their gap.
    best = max(m["metric"] for m in ref_metrics)
    delta = abs(ref_metrics[0]["metric"] - ref_metrics[1]["metric"]) * 0.3
    predicted = best + delta

    return {
        "estimate": round(predicted, 4),
        "current": current_metric,
        # `is not None` so a legitimate 0.0 baseline still produces a
        # delta (original falsy check dropped it).
        "delta": round(predicted - current_metric, 4) if current_metric is not None else None,
        "experiments": ref_metrics,
        "confidence": "LOW",
        "confidence_detail": "Estimated from individual metrics — actual stitch may differ",
    }
477
+
478
+
479
def estimate_budget(
    question: str,
    experiments: list[dict],
    primary_metric: str,
) -> dict:
    """Estimate budget allocation impact (summary only — needs simulation)."""
    best_so_far = _get_best_metric(experiments, primary_metric)
    n_total = len(experiments)
    n_kept = len([e for e in experiments if e.get("status") == "kept"])
    keep_ratio = round(n_kept / n_total, 2) if n_total > 0 else 0

    return {
        "estimate": None,
        "current": best_so_far,
        "total_experiments": n_total,
        "kept_ratio": keep_ratio,
        "confidence": "LOW",
        "confidence_detail": "Budget allocation requires simulation — use `/turing:simulate` for prediction",
    }
497
+
498
+
499
+ # --- Helpers ---
500
+
501
+
502
+ def _get_best_metric(experiments: list[dict], metric: str) -> float | None:
503
+ """Get the best metric value from all experiments."""
504
+ values = []
505
+ for exp in experiments:
506
+ v = exp.get("metrics", {}).get(metric)
507
+ if v is not None:
508
+ values.append(v)
509
+ return max(values) if values else None
510
+
511
+
512
+ def _r_squared_to_confidence(r_squared: float) -> str:
513
+ """Map R² to confidence level."""
514
+ if r_squared >= 0.95:
515
+ return "HIGH"
516
+ elif r_squared >= 0.80:
517
+ return "MED"
518
+ else:
519
+ return "LOW"
520
+
521
+
522
+ def _linear_interpolate(xs: list[float], ys: list[float], target_x: float) -> float | None:
523
+ """Simple linear interpolation."""
524
+ if not xs or not ys or len(xs) != len(ys):
525
+ return None
526
+
527
+ pairs = sorted(zip(xs, ys))
528
+ xs_sorted = [p[0] for p in pairs]
529
+ ys_sorted = [p[1] for p in pairs]
530
+
531
+ if target_x <= xs_sorted[0]:
532
+ return ys_sorted[0]
533
+ if target_x >= xs_sorted[-1]:
534
+ return ys_sorted[-1]
535
+
536
+ for i in range(len(xs_sorted) - 1):
537
+ if xs_sorted[i] <= target_x <= xs_sorted[i + 1]:
538
+ t = (target_x - xs_sorted[i]) / (xs_sorted[i + 1] - xs_sorted[i])
539
+ return ys_sorted[i] + t * (ys_sorted[i + 1] - ys_sorted[i])
540
+
541
+ return None
542
+
543
+
544
+ # --- Main Pipeline ---
545
+
546
+
547
# Route name -> estimator callable. Keys must cover every route name declared
# in ROUTE_PATTERNS (except "unknown") so whatif_analysis can dispatch.
ESTIMATOR_MAP = {
    "scaling": estimate_scaling,
    "ablation": estimate_ablation,
    "sensitivity": estimate_sensitivity,
    "ensemble": estimate_ensemble,
    "pruning": estimate_pruning,
    "stitch": estimate_stitch,
    "budget": estimate_budget,
}
556
+
557
+
558
def whatif_analysis(
    question: str,
    config_path: str = "config.yaml",
    log_path: str = DEFAULT_LOG_PATH,
) -> dict:
    """Run what-if analysis for a hypothetical question.

    Args:
        question: Natural language what-if question.
        config_path: Path to config.yaml.
        log_path: Path to experiment log.

    Returns:
        What-if analysis result with estimate, confidence, and recommendation.
    """
    config = load_config(config_path)
    primary_metric = config.get("evaluation", {}).get("primary_metric", "accuracy")
    experiments = load_experiments(log_path)

    classification = classify_question(question)
    route = classification["route"]

    if route == "unknown":
        return {
            "question": question,
            "route": "unknown",
            "error": "Cannot classify question — no matching estimator found",
            "available_routes": [r["name"] for r in ROUTE_PATTERNS],
            "suggestion": "Try phrasing as: 'what if I had Nx more data', "
            "'what if I removed <component>', "
            "'what if I changed <hyperparameter>'",
            "generated_at": datetime.now(timezone.utc).isoformat(),
        }

    estimator = ESTIMATOR_MAP.get(route)
    if estimator is None:
        return {"error": f"No estimator for route: {route}"}

    kwargs = {
        "question": question,
        "experiments": experiments,
        "primary_metric": primary_metric,
    }

    # Forward the route's data directory under the estimator's keyword name.
    dir_param_names = {
        "scaling": "scaling_dir",
        "ablation": "ablation_dir",
        "sensitivity": "sensitivity_dir",
        "ensemble": "ensemble_dir",
        "pruning": "pruning_dir",
    }
    data_dir = classification.get("data_dir")
    if data_dir and route in dir_param_names:
        kwargs[dir_param_names[route]] = data_dir

    result = estimator(**kwargs)
    recommendation = _generate_recommendation(result, route, classification)

    return {
        "question": question,
        "route": route,
        "source": classification["source"],
        "verify_cmd": classification["verify_cmd"],
        "result": result,
        "recommendation": recommendation,
        "generated_at": datetime.now(timezone.utc).isoformat(),
    }
632
+
633
+
634
+ def _generate_recommendation(result: dict, route: str, classification: dict) -> str:
635
+ """Generate a human-readable recommendation from the result."""
636
+ if "error" in result:
637
+ return f"Cannot estimate — {result['error']}"
638
+
639
+ estimate = result.get("estimate")
640
+ current = result.get("current")
641
+ delta = result.get("delta")
642
+ confidence = result.get("confidence", "UNKNOWN")
643
+
644
+ if estimate is None or current is None:
645
+ return f"Insufficient data for a point estimate. Run `{classification.get('verify_cmd', 'the source command')}` first."
646
+
647
+ if delta is not None:
648
+ if abs(delta) < 0.001:
649
+ return "Marginal gain. Likely not worth the effort."
650
+ elif delta > 0.01:
651
+ return f"Promising (+{delta:.4f}). Worth investigating further."
652
+ elif delta > 0:
653
+ return f"Small gain (+{delta:.4f}). Consider other approaches first."
654
+ else:
655
+ return f"Negative impact ({delta:.4f}). Avoid this direction."
656
+
657
+ return f"Estimate available ({confidence} confidence). Verify with `{classification.get('verify_cmd')}`."
658
+
659
+
660
+ # --- Report Formatting ---
661
+
662
+
663
def save_whatif_report(report: dict, output_dir: str = "experiments/whatif") -> Path:
    """Persist a what-if report as a timestamped YAML file; return its path."""
    destination = Path(output_dir)
    destination.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    target = destination / f"whatif-{stamp}.yaml"
    with open(target, "w") as handle:
        yaml.dump(report, handle, default_flow_style=False, sort_keys=False)
    return target
672
+
673
+
674
def format_whatif_report(report: dict) -> str:
    """Render a what-if report as human-readable markdown."""
    # Top-level failure with no result payload: bare error line.
    if "error" in report and "result" not in report:
        return f"ERROR: {report['error']}"

    out = [
        "# What-If Analysis",
        "",
        f"**Question:** {report.get('question', 'N/A')}",
        f"**Route:** {report.get('route', 'unknown')}",
        f"**Source:** {report.get('source', 'N/A')}",
        "",
    ]

    result = report.get("result", {})

    if "error" in result:
        out.append(f"**Error:** {result['error']}")
        if "available_routes" in report:
            out.extend(["", "Available routes: " + ", ".join(report["available_routes"])])
        if "suggestion" in report:
            out.extend(["", f"**Suggestion:** {report['suggestion']}"])
    else:
        current = result.get("current")
        if current is not None:
            out.append(f"**Current best:** {current}")
        estimate = result.get("estimate")
        if estimate is not None:
            out.append(f"**Estimated:** {estimate}")
        delta = result.get("delta")
        if delta is not None:
            prefix = "+" if delta >= 0 else ""
            out.append(f"**Delta:** {prefix}{delta}")
        out.append(f"**Confidence:** {result.get('confidence', 'UNKNOWN')}")

        detail = result.get("confidence_detail")
        if detail:
            out.append(f"**Detail:** {detail}")

    out.append("")
    recommendation = report.get("recommendation")
    if recommendation:
        out.append(f"**Recommendation:** {recommendation}")

    verify = report.get("verify_cmd")
    if verify:
        out.append(f"**To verify:** run `{verify}`")

    out.extend(["", f"*Generated: {report.get('generated_at', 'N/A')}*"])
    return "\n".join(out)
726
+
727
+
728
+ # --- CLI ---
729
+
730
+
731
def main():
    """CLI entry point: parse args, run the analysis, print, and save."""
    parser = argparse.ArgumentParser(
        description="What-if analysis engine — answer hypotheticals from existing data"
    )
    parser.add_argument("question", nargs="?", help="What-if question to analyze")
    parser.add_argument("--config", default="config.yaml", help="Path to config.yaml")
    parser.add_argument("--log", default=DEFAULT_LOG_PATH, help="Path to experiment log")
    parser.add_argument("--json", action="store_true", help="Output raw JSON")
    args = parser.parse_args()

    if not args.question:
        parser.error("Please provide a what-if question")

    report = whatif_analysis(
        question=args.question,
        config_path=args.config,
        log_path=args.log,
    )

    rendered = json.dumps(report, indent=2) if args.json else format_whatif_report(report)
    print(rendered)

    # Persist the report regardless of output mode.
    saved = save_whatif_report(report)
    if not args.json:
        print(f"\nSaved: {saved}")


if __name__ == "__main__":
    main()