claude-turing 4.1.0 → 4.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +5 -2
- package/commands/counterfactual.md +27 -0
- package/commands/simulate.md +28 -0
- package/commands/turing.md +6 -0
- package/commands/whatif.md +31 -0
- package/package.json +1 -1
- package/src/install.js +1 -0
- package/src/verify.js +3 -0
- package/templates/scripts/__pycache__/counterfactual_explanation.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/experiment_simulator.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/whatif_engine.cpython-314.pyc +0 -0
- package/templates/scripts/counterfactual_explanation.py +485 -0
- package/templates/scripts/experiment_simulator.py +463 -0
- package/templates/scripts/generate_brief.py +64 -0
- package/templates/scripts/scaffold.py +6 -0
- package/templates/scripts/whatif_engine.py +763 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "turing",
|
|
3
|
-
"version": "4.
|
|
4
|
-
"description": "Autonomous ML research harness — the autoresearch loop as a formal protocol.
|
|
3
|
+
"version": "4.2.0",
|
|
4
|
+
"description": "Autonomous ML research harness — the autoresearch loop as a formal protocol. 69 commands, 2 specialized agents, what-if analysis (whatif + counterfactual + simulate), collaboration (onboard + share + review), research communication (cite + present + changelog), experiment archaeology (trend + flashback + archive + annotate + search + template + replay), model surgery (prune + quantize + merge + surgery), feature & training intelligence, model debugging, pre-training intelligence, meta-intelligence, scaling & efficiency, model composition, deep analysis, experiment orchestration, literature + paper, model export, profiling, checkpoints, experiment intelligence, statistical rigor, tree-search, cost-performance, model cards, hypothesis database, novelty guard, anti-cheating, taste-leverage loop. Inspired by Karpathy's autoresearch and the scientific method itself.",
|
|
5
5
|
"author": {
|
|
6
6
|
"name": "pragnition"
|
|
7
7
|
},
|
package/README.md
CHANGED
|
@@ -377,6 +377,9 @@ The index (`hypotheses.yaml`) is the lightweight queue. The detail files (`hypot
|
|
|
377
377
|
| `/turing:onboard [--audience]` | Project onboarding — walkthrough for new collaborators |
|
|
378
378
|
| `/turing:share <exp-ids...>` | Experiment packaging — portable archive with manifest |
|
|
379
379
|
| `/turing:review [--venue]` | Peer review simulation — weaknesses, fix commands, score |
|
|
380
|
+
| `/turing:whatif "<question>"` | What-if analysis — answer hypotheticals from existing experiment data |
|
|
381
|
+
| `/turing:counterfactual <exp-id>` | Counterfactual explanations — minimum input change to flip a prediction |
|
|
382
|
+
| `/turing:simulate [--configs]` | Experiment outcome prediction — pre-filter configs, save budget |
|
|
380
383
|
|
|
381
384
|
And for fully hands-off operation:
|
|
382
385
|
|
|
@@ -561,11 +564,11 @@ Each project gets independent config, data, experiments, models, and agent memor
|
|
|
561
564
|
|
|
562
565
|
## Architecture of Turing Itself
|
|
563
566
|
|
|
564
|
-
|
|
567
|
+
69 commands, 2 agents, 10 config files, 88 template scripts, model registry, artifact contract, cost-performance frontier, model cards, tree-search exploration, statistical rigor, experiment intelligence, performance profiling, smart checkpoints, production model export, literature integration, paper section drafting, experiment orchestration (queue + retry + fork), deep analysis (diff + watch + regress), model composition (ensemble + stitch + warm), scaling & efficiency (scale + budget + distill), meta-intelligence (transfer + audit), pre-training intelligence (sanity + baseline + leak), model debugging (xray + sensitivity + calibrate), feature & training intelligence (feature + curriculum), model surgery (prune + quantize + merge + surgery), experiment archaeology (trend + flashback + archive + annotate + search + template + replay), research communication (cite + present + changelog), collaboration (onboard + share + review), what-if analysis (whatif + counterfactual + simulate), 16 ADRs. See [docs/ARCHITECTURE.md](docs/ARCHITECTURE.md) for the full codemap.
|
|
565
568
|
|
|
566
569
|
```
|
|
567
570
|
turing/
|
|
568
|
-
├── commands/
|
|
571
|
+
├── commands/ 65 skill files (core + taste-leverage + reporting + exploration + statistical rigor + experiment intelligence + performance + deployment + research workflow + orchestration + deep analysis + model composition + scaling & efficiency + meta-intelligence + pre-training intelligence + model debugging + feature & training intelligence + model surgery + experiment archaeology + research communication + what-if analysis)
|
|
569
572
|
├── agents/ 2 agents (researcher: read/write, evaluator: read-only)
|
|
570
573
|
├── config/ 8 files (lifecycle, taxonomy, archetypes, novelty aliases)
|
|
571
574
|
├── templates/ Scaffolded into user projects by /turing:init
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: counterfactual
|
|
3
|
+
description: Input-level counterfactual explanations — find the smallest input change to flip a prediction.
|
|
4
|
+
disable-model-invocation: true
|
|
5
|
+
argument-hint: "<exp-id> --sample <index> [--target <class>]"
|
|
6
|
+
allowed-tools: Read, Bash(*), Grep, Glob
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
What would need to change to flip this prediction? Minimum-change counterfactual for individual predictions.
|
|
10
|
+
|
|
11
|
+
## Steps
|
|
12
|
+
1. `source .venv/bin/activate`
|
|
13
|
+
2. `python scripts/counterfactual_explanation.py $ARGUMENTS`
|
|
14
|
+
3. **Saved:** `experiments/counterfactuals/`
|
|
15
|
+
|
|
16
|
+
## Methods
|
|
17
|
+
- **Greedy perturbation:** change one feature at a time, find minimum flip
|
|
18
|
+
- **Prototype-based:** find nearest training sample from target class
|
|
19
|
+
- Both methods run and the best (smallest distance) is selected
|
|
20
|
+
|
|
21
|
+
## Examples
|
|
22
|
+
```
|
|
23
|
+
/turing:counterfactual exp-042 --sample 1247
|
|
24
|
+
/turing:counterfactual exp-042 --sample 1247 --target 0
|
|
25
|
+
/turing:counterfactual exp-042 --batch-misclassified
|
|
26
|
+
/turing:counterfactual exp-042 --sample 500 --json
|
|
27
|
+
```
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: simulate
|
|
3
|
+
description: Experiment outcome prediction — predict which configs will beat the current best before running them.
|
|
4
|
+
disable-model-invocation: true
|
|
5
|
+
argument-hint: "[--configs configs.yaml] [--top-k 5] [--threshold 0.001]"
|
|
6
|
+
allowed-tools: Read, Bash(*), Grep, Glob
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
Predict outcomes before spending compute. Ranks proposed configs and recommends which to run vs skip.
|
|
10
|
+
|
|
11
|
+
## Steps
|
|
12
|
+
1. `source .venv/bin/activate`
|
|
13
|
+
2. `python scripts/experiment_simulator.py $ARGUMENTS`
|
|
14
|
+
3. **Saved:** `experiments/simulations/`
|
|
15
|
+
|
|
16
|
+
## How it works
|
|
17
|
+
- Builds a surrogate model from experiment history (weighted k-NN)
|
|
18
|
+
- Predicts metric for each proposed config
|
|
19
|
+
- Applies novelty penalty for configs far from training distribution
|
|
20
|
+
- Ranks and filters: only recommend configs predicted to improve
|
|
21
|
+
|
|
22
|
+
## Examples
|
|
23
|
+
```
|
|
24
|
+
/turing:simulate --configs sweep_configs.yaml
|
|
25
|
+
/turing:simulate --configs candidates.yaml --top-k 3
|
|
26
|
+
/turing:simulate --configs proposals.yaml --threshold 0.005
|
|
27
|
+
/turing:simulate --configs sweep.yaml --json
|
|
28
|
+
```
|
package/commands/turing.md
CHANGED
|
@@ -75,6 +75,9 @@ You are the Turing ML research router. Detect the user's intent and route to the
|
|
|
75
75
|
| "search", "find experiment", "query experiments", "which experiments" | `/turing:search` | Query |
|
|
76
76
|
| "template", "recipe", "save config", "reusable config", "starting point" | `/turing:template` | Manage |
|
|
77
77
|
| "replay", "re-run", "revisit", "retry old", "would it work now" | `/turing:replay` | Validate |
|
|
78
|
+
| "what if", "what-if", "hypothetical", "estimate impact", "would it help" | `/turing:whatif` | Analyze |
|
|
79
|
+
| "counterfactual", "flip prediction", "why this prediction", "minimum change", "explanation" | `/turing:counterfactual` | Explain |
|
|
80
|
+
| "simulate", "predict outcome", "pre-filter", "which configs will work", "forecast" | `/turing:simulate` | Predict |
|
|
78
81
|
|
|
79
82
|
## Sub-commands
|
|
80
83
|
|
|
@@ -146,6 +149,9 @@ You are the Turing ML research router. Detect the user's intent and route to the
|
|
|
146
149
|
| `/turing:onboard [--audience]` | Project onboarding: full walkthrough for new collaborators | (inline) |
|
|
147
150
|
| `/turing:share <exp-ids...>` | Experiment packaging: portable archive with manifest and README | (inline) |
|
|
148
151
|
| `/turing:review [--venue]` | Peer review simulation: weaknesses, questions, fix commands, score | (inline) |
|
|
152
|
+
| `/turing:whatif "<question>"` | What-if analysis: route hypotheticals to existing estimators (scaling, ablation, sensitivity, ensemble, pruning) | (inline) |
|
|
153
|
+
| `/turing:counterfactual <exp-id> --sample <index>` | Input-level counterfactual explanations: minimum input change to flip a prediction | (inline) |
|
|
154
|
+
| `/turing:simulate [--configs] [--top-k]` | Experiment outcome prediction: pre-filter configs using surrogate model, save budget | (inline) |
|
|
149
155
|
|
|
150
156
|
## Proactive Detection
|
|
151
157
|
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: whatif
|
|
3
|
+
description: What-if analysis — answer hypotheticals from existing experiment data without running new experiments.
|
|
4
|
+
disable-model-invocation: true
|
|
5
|
+
argument-hint: "\"<question>\" [--json]"
|
|
6
|
+
allowed-tools: Read, Bash(*), Grep, Glob
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
Answer "what if?" questions using existing experiment data. Routes to the right estimator automatically.
|
|
10
|
+
|
|
11
|
+
## Steps
|
|
12
|
+
1. `source .venv/bin/activate`
|
|
13
|
+
2. `python scripts/whatif_engine.py $ARGUMENTS`
|
|
14
|
+
3. **Saved:** `experiments/whatif/`
|
|
15
|
+
|
|
16
|
+
## Supported question types
|
|
17
|
+
- **Data scaling:** "what if I had 2x more data" → scaling law extrapolation
|
|
18
|
+
- **Ablation:** "what if I removed class 3" → ablation study data
|
|
19
|
+
- **Pipeline stitch:** "what if I combined exp-031 with exp-042" → stitch estimation
|
|
20
|
+
- **Hyperparameters:** "what if learning_rate was 0.01" → sensitivity interpolation
|
|
21
|
+
- **Ensemble:** "what if I ensembled the top models" → correlation analysis
|
|
22
|
+
- **Pruning:** "what if I pruned to 50% sparsity" → pruning sweep interpolation
|
|
23
|
+
- **Budget:** "what if I spent my budget on X vs Y" → budget allocation
|
|
24
|
+
|
|
25
|
+
## Examples
|
|
26
|
+
```
|
|
27
|
+
/turing:whatif "what if I had 2x more data"
|
|
28
|
+
/turing:whatif "what if I removed class 3"
|
|
29
|
+
/turing:whatif "what if I combined exp-031 with exp-042"
|
|
30
|
+
/turing:whatif "what if learning_rate was 0.01" --json
|
|
31
|
+
```
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "claude-turing",
|
|
3
|
-
"version": "4.
|
|
3
|
+
"version": "4.2.0",
|
|
4
4
|
"type": "module",
|
|
5
5
|
"description": "Autonomous ML research harness for Claude Code. The autoresearch loop as a formal protocol — iteratively trains, evaluates, and improves ML models with structured experiment tracking, convergence detection, immutable evaluation infrastructure, and safety guardrails.",
|
|
6
6
|
"bin": {
|
package/src/install.js
CHANGED
|
@@ -36,6 +36,7 @@ const SUB_COMMANDS = [
|
|
|
36
36
|
"trend", "flashback", "archive", "annotate", "search", "template", "replay",
|
|
37
37
|
"cite", "present", "changelog",
|
|
38
38
|
"onboard", "share", "review",
|
|
39
|
+
"whatif", "counterfactual", "simulate",
|
|
39
40
|
];
|
|
40
41
|
|
|
41
42
|
export async function install(opts = {}) {
|
package/src/verify.js
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -0,0 +1,485 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Input-level counterfactual explanations for the autoresearch pipeline.
|
|
3
|
+
|
|
4
|
+
For a given prediction, finds the smallest input change that would flip
|
|
5
|
+
the outcome. "This sample was classified as X — what's the minimum change
|
|
6
|
+
to make it Y?" Useful for debugging predictions and regulatory explanations.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python scripts/counterfactual_explanation.py exp-042 --sample 1247
|
|
10
|
+
python scripts/counterfactual_explanation.py exp-042 --sample 1247 --target 0
|
|
11
|
+
python scripts/counterfactual_explanation.py exp-042 --batch-misclassified
|
|
12
|
+
python scripts/counterfactual_explanation.py --json
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import json
|
|
19
|
+
import sys
|
|
20
|
+
from datetime import datetime, timezone
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
import yaml
|
|
25
|
+
|
|
26
|
+
from scripts.turing_io import load_config, load_experiments
|
|
27
|
+
|
|
28
|
+
DEFAULT_LOG_PATH = "experiments/log.jsonl"
|
|
29
|
+
DEFAULT_MAX_ITERATIONS = 100
|
|
30
|
+
DEFAULT_DISTANCE_METRIC = "normalized_l2"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# --- Feature Perturbation ---
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def greedy_perturbation(
    sample: dict[str, float],
    predict_fn,
    target_class: int | str,
    feature_names: list[str],
    feature_ranges: dict[str, tuple[float, float]],
    max_iterations: int = DEFAULT_MAX_ITERATIONS,
    categorical_features: list[str] | None = None,
) -> dict:
    """Search for a counterfactual by perturbing one feature at a time.

    Candidate values for each feature are tried against *predict_fn*; among
    the single-feature edits that reach *target_class*, the candidate with
    the smallest normalized distance to the original sample wins.

    Args:
        sample: Original sample as {feature_name: value}.
        predict_fn: Function(sample_dict) -> (predicted_class, confidence).
        target_class: Desired target class.
        feature_names: Ordered list of feature names.
        feature_ranges: {feature: (min, max)} from training data.
        max_iterations: Maximum perturbation rounds.
        categorical_features: Features treated as categorical (discrete changes).

    Returns:
        Counterfactual result dict with status "already_target", "not_found",
        or "found".
    """
    discrete = set(categorical_features) if categorical_features else set()
    base = dict(sample)
    start_pred, start_conf = predict_fn(sample)

    if str(start_pred) == str(target_class):
        return {
            "status": "already_target",
            "message": f"Sample is already predicted as {target_class}",
        }

    winner: dict | None = None
    winner_dist = float("inf")
    winner_changes: list = []

    # NOTE(review): `base` is never advanced between rounds, so every round
    # re-tests the same single-feature edits — preserved as-is.
    for _round in range(max_iterations):
        progressed = False

        for name in feature_names:
            if name in discrete:
                options = _categorical_candidates(name, base[name], feature_ranges.get(name))
            else:
                options = _numeric_candidates(
                    base[name],
                    feature_ranges.get(name, (0, 1)),
                    n_steps=5,
                )

            for option in options:
                probe = dict(base)
                probe[name] = option
                verdict, _ = predict_fn(probe)

                if str(verdict) != str(target_class):
                    continue
                separation = _compute_distance(sample, probe, feature_ranges)
                if separation < winner_dist:
                    winner_dist = separation
                    winner = dict(probe)
                    winner_changes = _compute_changes(sample, probe, feature_names)
                    progressed = True

        # Stop once a counterfactual exists and a full round brought no improvement.
        if winner is not None and not progressed:
            break

    if winner is None:
        return {
            "status": "not_found",
            "message": f"Could not find counterfactual within {max_iterations} iterations",
            "original_prediction": start_pred,
            "original_confidence": float(start_conf),
        }

    final_pred, final_conf = predict_fn(winner)

    return {
        "status": "found",
        "original_prediction": start_pred,
        "original_confidence": float(start_conf),
        "counterfactual_prediction": final_pred,
        "counterfactual_confidence": float(final_conf),
        "distance": round(float(winner_dist), 4),
        "n_changes": len(winner_changes),
        "changes": winner_changes,
        "counterfactual_sample": winner,
    }
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _numeric_candidates(current: float, value_range: tuple[float, float], n_steps: int = 5) -> list[float]:
|
|
128
|
+
"""Generate candidate values for a numeric feature."""
|
|
129
|
+
low, high = value_range
|
|
130
|
+
step = (high - low) / max(n_steps, 1)
|
|
131
|
+
candidates = []
|
|
132
|
+
for i in range(n_steps + 1):
|
|
133
|
+
val = low + i * step
|
|
134
|
+
if val != current:
|
|
135
|
+
candidates.append(val)
|
|
136
|
+
return candidates
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _categorical_candidates(
|
|
140
|
+
feature: str,
|
|
141
|
+
current_value,
|
|
142
|
+
value_range: tuple | list | None,
|
|
143
|
+
) -> list:
|
|
144
|
+
"""Generate candidate values for a categorical feature."""
|
|
145
|
+
if value_range is None:
|
|
146
|
+
return []
|
|
147
|
+
if isinstance(value_range, (tuple, list)):
|
|
148
|
+
return [v for v in value_range if v != current_value]
|
|
149
|
+
return []
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _compute_distance(
|
|
153
|
+
original: dict[str, float],
|
|
154
|
+
counterfactual: dict[str, float],
|
|
155
|
+
feature_ranges: dict[str, tuple[float, float]],
|
|
156
|
+
) -> float:
|
|
157
|
+
"""Compute normalized L2 distance between original and counterfactual."""
|
|
158
|
+
total = 0.0
|
|
159
|
+
for feat in original:
|
|
160
|
+
orig_val = original[feat]
|
|
161
|
+
cf_val = counterfactual.get(feat, orig_val)
|
|
162
|
+
feat_range = feature_ranges.get(feat, (0, 1))
|
|
163
|
+
|
|
164
|
+
if isinstance(orig_val, str) or (isinstance(feat_range, (tuple, list)) and len(feat_range) > 2):
|
|
165
|
+
# Categorical: 1 if changed, 0 if same
|
|
166
|
+
total += 0.0 if orig_val == cf_val else 1.0
|
|
167
|
+
else:
|
|
168
|
+
low, high = feat_range[0], feat_range[1]
|
|
169
|
+
span = high - low if high != low else 1
|
|
170
|
+
normalized_diff = (cf_val - orig_val) / span
|
|
171
|
+
total += normalized_diff ** 2
|
|
172
|
+
return float(np.sqrt(total))
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _compute_changes(
|
|
176
|
+
original: dict[str, float],
|
|
177
|
+
counterfactual: dict[str, float],
|
|
178
|
+
feature_names: list[str],
|
|
179
|
+
) -> list[dict]:
|
|
180
|
+
"""Compute the list of changed features."""
|
|
181
|
+
changes = []
|
|
182
|
+
for feat in feature_names:
|
|
183
|
+
orig = original.get(feat)
|
|
184
|
+
cf = counterfactual.get(feat)
|
|
185
|
+
if orig != cf:
|
|
186
|
+
change = {
|
|
187
|
+
"feature": feat,
|
|
188
|
+
"original": orig,
|
|
189
|
+
"counterfactual": cf,
|
|
190
|
+
}
|
|
191
|
+
if isinstance(orig, (int, float)) and isinstance(cf, (int, float)):
|
|
192
|
+
change["delta"] = round(cf - orig, 6)
|
|
193
|
+
else:
|
|
194
|
+
change["delta"] = "category_change"
|
|
195
|
+
changes.append(change)
|
|
196
|
+
return changes
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
# --- Prototype-Based Search ---
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def prototype_counterfactual(
    sample: dict[str, float],
    training_data: list[dict[str, float]],
    training_labels: list,
    target_class: int | str,
    feature_names: list[str],
    feature_ranges: dict[str, tuple[float, float]],
) -> dict:
    """Nearest-neighbor counterfactual: closest training sample with the target label.

    Args:
        sample: Original sample.
        training_data: List of training samples as dicts.
        training_labels: Corresponding labels.
        target_class: Desired target class.
        feature_names: Feature names.
        feature_ranges: {feature: (min, max)}.

    Returns:
        Nearest prototype counterfactual result.
    """
    wanted = str(target_class)
    member_indices = [i for i, lab in enumerate(training_labels) if str(lab) == wanted]

    if not member_indices:
        return {
            "status": "not_found",
            "message": f"No training samples found for class {target_class}",
        }

    nearest_idx: int | None = None
    nearest_dist = float("inf")
    for idx in member_indices:
        dist = _compute_distance(sample, training_data[idx], feature_ranges)
        if dist < nearest_dist:
            nearest_dist = dist
            nearest_idx = idx

    if nearest_idx is None:
        return {"status": "not_found", "message": "No valid prototype found"}

    prototype = training_data[nearest_idx]
    diffs = _compute_changes(sample, prototype, feature_names)

    return {
        "status": "found",
        "method": "prototype",
        "prototype_index": nearest_idx,
        "distance": round(float(nearest_dist), 4),
        "n_changes": len(diffs),
        "changes": diffs,
        "counterfactual_sample": prototype,
    }
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
# --- Full Pipeline ---
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def counterfactual_analysis(
    exp_id: str,
    sample_index: int | None = None,
    sample_data: dict[str, float] | None = None,
    target_class: int | str | None = None,
    predict_fn=None,
    training_data: list[dict] | None = None,
    training_labels: list | None = None,
    feature_names: list[str] | None = None,
    feature_ranges: dict[str, tuple[float, float]] | None = None,
    categorical_features: list[str] | None = None,
    batch_misclassified: bool = False,
    config_path: str = "config.yaml",
    log_path: str = DEFAULT_LOG_PATH,
) -> dict:
    """Run counterfactual analysis for one sample or for all misclassified samples.

    Args:
        exp_id: Experiment ID to analyze.
        sample_index: Index of the sample to explain.
        sample_data: Direct sample data (alternative to index).
        target_class: Desired counterfactual class.
        predict_fn: Prediction function (sample_dict) -> (class, confidence).
        training_data: Training data for prototype search.
        training_labels: Training labels for prototype search.
        feature_names: Feature names.
        feature_ranges: Feature value ranges.
        categorical_features: Categorical feature names.
        batch_misclassified: If True, generate for all misclassified samples.
        config_path: Path to config.yaml.
        log_path: Path to experiment log.

    Returns:
        Counterfactual analysis report.
    """
    # Loaded for its side effects/validation; the returned value is unused here.
    load_config(config_path)

    if sample_data is None and sample_index is None and not batch_misclassified:
        return {"error": "Provide --sample <index> or --batch-misclassified"}

    if predict_fn is None:
        return {
            "error": "No prediction function available. "
            "Load the model from the experiment first.",
            "suggestion": f"Run `/turing:counterfactual {exp_id} --sample <index>` "
            "from the experiment directory with train.py available.",
        }

    if feature_names is None:
        return {"error": "Feature names not available. Provide feature_names."}

    if feature_ranges is None:
        feature_ranges = {}

    outcome: dict | list = []

    if batch_misclassified and training_data and training_labels:
        # One greedy counterfactual per misclassified training sample,
        # targeting the true label.
        for idx, (row, truth) in enumerate(zip(training_data, training_labels)):
            predicted, _ = predict_fn(row)
            if str(predicted) == str(truth):
                continue
            cf_result = greedy_perturbation(
                row, predict_fn, truth, feature_names,
                feature_ranges, categorical_features=categorical_features,
            )
            cf_result["sample_index"] = idx
            cf_result["true_label"] = truth
            outcome.append(cf_result)
    elif sample_data is not None:
        if target_class is None:
            predicted, _ = predict_fn(sample_data)
            # Binary assumption: default target is the opposite class.
            target_class = 0 if predicted == 1 else 1

        from_greedy = greedy_perturbation(
            sample_data, predict_fn, target_class, feature_names,
            feature_ranges, categorical_features=categorical_features,
        )

        from_prototype = None
        if training_data and training_labels:
            from_prototype = prototype_counterfactual(
                sample_data, training_data, training_labels,
                target_class, feature_names, feature_ranges,
            )

        outcome = {
            "greedy": from_greedy,
            "prototype": from_prototype,
            "best": _select_best([from_greedy, from_prototype]),
        }

    # NOTE(review): when only `sample_index` is given (no `sample_data`),
    # `outcome` stays empty — the sample is not loaded from the experiment here.
    return {
        "experiment_id": exp_id,
        "sample_index": sample_index,
        "target_class": target_class,
        "results": outcome,
        "generated_at": datetime.now(timezone.utc).isoformat(),
    }
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def _select_best(candidates: list[dict | None]) -> dict | None:
|
|
363
|
+
"""Select the counterfactual with smallest distance."""
|
|
364
|
+
valid = [c for c in candidates if c and c.get("status") == "found"]
|
|
365
|
+
if not valid:
|
|
366
|
+
return None
|
|
367
|
+
return min(valid, key=lambda c: c.get("distance", float("inf")))
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
# --- Report Formatting ---
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def save_counterfactual_report(report: dict, output_dir: str = "experiments/counterfactuals") -> Path:
    """Serialize *report* to a YAML file under *output_dir* and return the path."""
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    exp = report.get("experiment_id", "unknown")
    sample_tag = report.get("sample_index", "batch")
    destination = target_dir / f"{exp}-cf-{sample_tag}.yaml"
    with destination.open("w") as handle:
        yaml.dump(report, handle, default_flow_style=False, sort_keys=False)
    return destination
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def format_counterfactual_report(report: dict) -> str:
    """Render a counterfactual report as human-readable markdown."""
    if "error" in report:
        return f"ERROR: {report['error']}"

    out = [
        "# Counterfactual Explanation",
        "",
        f"**Experiment:** {report.get('experiment_id', 'N/A')}",
        f"**Sample:** {report.get('sample_index', 'N/A')}",
        f"**Target class:** {report.get('target_class', 'N/A')}",
        "",
    ]

    results = report.get("results", {})

    if isinstance(results, dict):
        # Single-sample report: show the winning method, then a change table.
        best = results.get("best")
        if best and best.get("status") == "found":
            out.append(f"**Method:** {best.get('method', 'greedy')}")
            out.append(f"**Distance:** {best.get('distance', 'N/A')}")
            out.append(f"**Changes needed:** {best.get('n_changes', 0)}")
            out.append("")

            rows = best.get("changes", [])
            if rows:
                out.append("| Feature | Original | Counterfactual | Change |")
                out.append("|---------|----------|----------------|--------|")
                for row in rows:
                    shift = row.get("delta", "")
                    if isinstance(shift, float):
                        shift_text = f"{shift:+.4f}"
                    elif isinstance(shift, int):
                        shift_text = f"{shift:+d}"
                    else:
                        shift_text = str(shift)
                    out.append(
                        f"| {row['feature']} | {row['original']} | {row['counterfactual']} | {shift_text} |"
                    )
        else:
            out.append("No counterfactual found within search budget.")

        # Side-by-side summary of both search methods when either succeeded.
        greedy_res = results.get("greedy", {})
        proto_res = results.get("prototype", {})
        if greedy_res.get("status") == "found" or (proto_res and proto_res.get("status") == "found"):
            out.append("")
            out.append("**Method comparison:**")
            if greedy_res.get("status") == "found":
                out.append(f"- Greedy: distance={greedy_res.get('distance')}, changes={greedy_res.get('n_changes')}")
            if proto_res and proto_res.get("status") == "found":
                out.append(f"- Prototype: distance={proto_res.get('distance')}, changes={proto_res.get('n_changes')}")

    elif isinstance(results, list):
        # Batch report: just the success counts.
        out.append(f"**Batch results:** {len(results)} misclassified samples analyzed")
        n_found = sum(1 for r in results if r.get("status") == "found")
        out.append(f"**Counterfactuals found:** {n_found}/{len(results)}")

    out.append("")
    out.append(f"*Generated: {report.get('generated_at', 'N/A')}*")
    return "\n".join(out)
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
# --- CLI ---
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def main():
    """CLI entry point: parse arguments, run the analysis, print and persist the report."""
    cli = argparse.ArgumentParser(
        description="Counterfactual explanations — find minimum input changes to flip predictions"
    )
    cli.add_argument("exp_id", nargs="?", help="Experiment ID")
    cli.add_argument("--sample", type=int, help="Sample index to explain")
    cli.add_argument("--target", help="Target class for counterfactual")
    cli.add_argument("--batch-misclassified", action="store_true",
                     help="Generate counterfactuals for all misclassified samples")
    cli.add_argument("--config", default="config.yaml", help="Path to config.yaml")
    cli.add_argument("--log", default=DEFAULT_LOG_PATH, help="Path to experiment log")
    cli.add_argument("--json", action="store_true", help="Output raw JSON")

    opts = cli.parse_args()

    if not opts.exp_id:
        cli.error("Please provide an experiment ID")

    report = counterfactual_analysis(
        exp_id=opts.exp_id,
        sample_index=opts.sample,
        target_class=opts.target,
        batch_misclassified=opts.batch_misclassified,
        config_path=opts.config,
        log_path=opts.log,
    )

    if opts.json:
        print(json.dumps(report, indent=2, default=str))
    else:
        print(format_counterfactual_report(report))

    # Persist only successful reports; echo the path in human-readable mode.
    if "error" not in report:
        saved_path = save_counterfactual_report(report)
        if not opts.json:
            print(f"\nSaved: {saved_path}")
|
482
|
+
|
|
483
|
+
|
|
484
|
+
if __name__ == "__main__":
|
|
485
|
+
main()
|