claude-turing 4.1.0 → 4.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/.claude-plugin/plugin.json +2 -2
- package/README.md +7 -2
- package/commands/counterfactual.md +27 -0
- package/commands/registry.md +31 -0
- package/commands/simulate.md +28 -0
- package/commands/turing.md +10 -0
- package/commands/update.md +27 -0
- package/commands/whatif.md +31 -0
- package/package.json +1 -1
- package/src/install.js +2 -0
- package/src/verify.js +5 -0
- package/templates/scripts/__pycache__/counterfactual_explanation.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/experiment_simulator.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_brief.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/generate_model_card.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/incremental_update.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/model_lifecycle.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/scaffold.cpython-314.pyc +0 -0
- package/templates/scripts/__pycache__/whatif_engine.cpython-314.pyc +0 -0
- package/templates/scripts/counterfactual_explanation.py +485 -0
- package/templates/scripts/experiment_simulator.py +463 -0
- package/templates/scripts/generate_brief.py +125 -0
- package/templates/scripts/generate_model_card.py +154 -3
- package/templates/scripts/incremental_update.py +586 -0
- package/templates/scripts/model_lifecycle.py +549 -0
- package/templates/scripts/scaffold.py +10 -0
- package/templates/scripts/whatif_engine.py +763 -0
|
@@ -0,0 +1,485 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Input-level counterfactual explanations for the autoresearch pipeline.
|
|
3
|
+
|
|
4
|
+
For a given prediction, finds the smallest input change that would flip
|
|
5
|
+
the outcome. "This sample was classified as X — what's the minimum change
|
|
6
|
+
to make it Y?" Useful for debugging predictions and regulatory explanations.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
python scripts/counterfactual_explanation.py exp-042 --sample 1247
|
|
10
|
+
python scripts/counterfactual_explanation.py exp-042 --sample 1247 --target 0
|
|
11
|
+
python scripts/counterfactual_explanation.py exp-042 --batch-misclassified
|
|
12
|
+
python scripts/counterfactual_explanation.py exp-042 --sample 1247 --json
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import argparse
|
|
18
|
+
import json
|
|
19
|
+
import sys
|
|
20
|
+
from datetime import datetime, timezone
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
import yaml
|
|
25
|
+
|
|
26
|
+
from scripts.turing_io import load_config, load_experiments
|
|
27
|
+
|
|
28
|
+
# Default experiment log location (used as the --log CLI default).
DEFAULT_LOG_PATH: str = "experiments/log.jsonl"
# Default search budget for greedy_perturbation.
DEFAULT_MAX_ITERATIONS: int = 100
# Name of the distance metric; not referenced elsewhere in this script.
DEFAULT_DISTANCE_METRIC: str = "normalized_l2"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# --- Feature Perturbation ---
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def greedy_perturbation(
    sample: dict[str, float],
    predict_fn,
    target_class: int | str,
    feature_names: list[str],
    feature_ranges: dict[str, tuple[float, float]],
    max_iterations: int = DEFAULT_MAX_ITERATIONS,
    categorical_features: list[str] | None = None,
) -> dict:
    """Find a counterfactual by greedily changing one feature at a time.

    Each iteration tries every single-feature change of the current working
    sample: grid points across the range for numeric features, and the other
    listed values for categorical ones. If any change yields ``target_class``,
    the search stops. The result is the flipping change closest to the
    original ``sample``.

    If no change flips the prediction, the change that most lowers the
    model's confidence is applied to the working sample and the search
    continues. This lets counterfactuals that need several feature changes be
    found. The search also stops early when no change lowers the confidence.

    NOTE(review): this assumes ``confidence`` is the model's confidence in
    its (non-target) prediction, so a lower value means closer to the
    decision boundary. TODO confirm against the predict_fn implementations.

    Args:
        sample: Original sample as {feature_name: value}.
        predict_fn: Function(sample_dict) -> (predicted_class, confidence).
        target_class: Desired target class (compared via ``str``).
        feature_names: Ordered list of feature names to perturb.
        feature_ranges: {feature: (min, max)} for numeric features, or a
            list/tuple of allowed values for categorical features.
        max_iterations: Maximum number of greedy steps. Each step evaluates
            every single-feature change of the working sample.
        categorical_features: Features that are categorical (discrete changes).

    Returns:
        Counterfactual result dict. ``status`` is one of:
        - "already_target"
        - "not_found"
        - "found": also includes the distance and the per-feature changes,
          both measured against the original ``sample``.
    """
    categorical = categorical_features or []
    target_key = str(target_class)

    original_pred, original_conf = predict_fn(sample)
    if str(original_pred) == target_key:
        return {
            "status": "already_target",
            "message": f"Sample is already predicted as {target_class}",
        }

    # Working sample that accumulates greedy steps. Distances and changes are
    # always measured against the untouched original ``sample``.
    current = dict(sample)
    current_conf = float(original_conf)

    best_cf = None
    best_distance = float("inf")
    best_pred = best_conf = None

    for _ in range(max_iterations):
        step_sample = None
        step_conf = current_conf

        for feat in feature_names:
            if feat in categorical:
                candidates = _categorical_candidates(feat, current[feat], feature_ranges.get(feat))
            else:
                candidates = _numeric_candidates(
                    current[feat],
                    feature_ranges.get(feat, (0, 1)),
                    n_steps=5,
                )

            for candidate_val in candidates:
                trial = dict(current)
                trial[feat] = candidate_val
                pred, conf = predict_fn(trial)

                if str(pred) == target_key:
                    dist = _compute_distance(sample, trial, feature_ranges)
                    if dist < best_distance:
                        best_distance = dist
                        best_cf = trial
                        best_pred, best_conf = pred, conf
                elif float(conf) < step_conf:
                    # Candidate greedy step: the largest confidence drop so far.
                    step_conf = float(conf)
                    step_sample = trial

        if best_cf is not None:
            break
        if step_sample is None:
            # No single change moves the prediction toward the boundary: stuck.
            break
        current = step_sample
        current_conf = step_conf

    if best_cf is None:
        return {
            "status": "not_found",
            "message": f"Could not find counterfactual within {max_iterations} iterations",
            "original_prediction": original_pred,
            "original_confidence": float(original_conf),
        }

    changes = _compute_changes(sample, best_cf, feature_names)

    return {
        "status": "found",
        "method": "greedy",
        "original_prediction": original_pred,
        "original_confidence": float(original_conf),
        "counterfactual_prediction": best_pred,
        "counterfactual_confidence": float(best_conf),
        "distance": round(float(best_distance), 4),
        "n_changes": len(changes),
        "changes": changes,
        "counterfactual_sample": best_cf,
    }
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _numeric_candidates(current: float, value_range: tuple[float, float], n_steps: int = 5) -> list[float]:
|
|
128
|
+
"""Generate candidate values for a numeric feature."""
|
|
129
|
+
low, high = value_range
|
|
130
|
+
step = (high - low) / max(n_steps, 1)
|
|
131
|
+
candidates = []
|
|
132
|
+
for i in range(n_steps + 1):
|
|
133
|
+
val = low + i * step
|
|
134
|
+
if val != current:
|
|
135
|
+
candidates.append(val)
|
|
136
|
+
return candidates
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _categorical_candidates(
|
|
140
|
+
feature: str,
|
|
141
|
+
current_value,
|
|
142
|
+
value_range: tuple | list | None,
|
|
143
|
+
) -> list:
|
|
144
|
+
"""Generate candidate values for a categorical feature."""
|
|
145
|
+
if value_range is None:
|
|
146
|
+
return []
|
|
147
|
+
if isinstance(value_range, (tuple, list)):
|
|
148
|
+
return [v for v in value_range if v != current_value]
|
|
149
|
+
return []
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _compute_distance(
|
|
153
|
+
original: dict[str, float],
|
|
154
|
+
counterfactual: dict[str, float],
|
|
155
|
+
feature_ranges: dict[str, tuple[float, float]],
|
|
156
|
+
) -> float:
|
|
157
|
+
"""Compute normalized L2 distance between original and counterfactual."""
|
|
158
|
+
total = 0.0
|
|
159
|
+
for feat in original:
|
|
160
|
+
orig_val = original[feat]
|
|
161
|
+
cf_val = counterfactual.get(feat, orig_val)
|
|
162
|
+
feat_range = feature_ranges.get(feat, (0, 1))
|
|
163
|
+
|
|
164
|
+
if isinstance(orig_val, str) or (isinstance(feat_range, (tuple, list)) and len(feat_range) > 2):
|
|
165
|
+
# Categorical: 1 if changed, 0 if same
|
|
166
|
+
total += 0.0 if orig_val == cf_val else 1.0
|
|
167
|
+
else:
|
|
168
|
+
low, high = feat_range[0], feat_range[1]
|
|
169
|
+
span = high - low if high != low else 1
|
|
170
|
+
normalized_diff = (cf_val - orig_val) / span
|
|
171
|
+
total += normalized_diff ** 2
|
|
172
|
+
return float(np.sqrt(total))
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _compute_changes(
|
|
176
|
+
original: dict[str, float],
|
|
177
|
+
counterfactual: dict[str, float],
|
|
178
|
+
feature_names: list[str],
|
|
179
|
+
) -> list[dict]:
|
|
180
|
+
"""Compute the list of changed features."""
|
|
181
|
+
changes = []
|
|
182
|
+
for feat in feature_names:
|
|
183
|
+
orig = original.get(feat)
|
|
184
|
+
cf = counterfactual.get(feat)
|
|
185
|
+
if orig != cf:
|
|
186
|
+
change = {
|
|
187
|
+
"feature": feat,
|
|
188
|
+
"original": orig,
|
|
189
|
+
"counterfactual": cf,
|
|
190
|
+
}
|
|
191
|
+
if isinstance(orig, (int, float)) and isinstance(cf, (int, float)):
|
|
192
|
+
change["delta"] = round(cf - orig, 6)
|
|
193
|
+
else:
|
|
194
|
+
change["delta"] = "category_change"
|
|
195
|
+
changes.append(change)
|
|
196
|
+
return changes
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
# --- Prototype-Based Search ---
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def prototype_counterfactual(
|
|
203
|
+
sample: dict[str, float],
|
|
204
|
+
training_data: list[dict[str, float]],
|
|
205
|
+
training_labels: list,
|
|
206
|
+
target_class: int | str,
|
|
207
|
+
feature_names: list[str],
|
|
208
|
+
feature_ranges: dict[str, tuple[float, float]],
|
|
209
|
+
) -> dict:
|
|
210
|
+
"""Find the nearest training sample from the target class.
|
|
211
|
+
|
|
212
|
+
Args:
|
|
213
|
+
sample: Original sample.
|
|
214
|
+
training_data: List of training samples as dicts.
|
|
215
|
+
training_labels: Corresponding labels.
|
|
216
|
+
target_class: Desired target class.
|
|
217
|
+
feature_names: Feature names.
|
|
218
|
+
feature_ranges: {feature: (min, max)}.
|
|
219
|
+
|
|
220
|
+
Returns:
|
|
221
|
+
Nearest prototype counterfactual result.
|
|
222
|
+
"""
|
|
223
|
+
target_indices = [i for i, label in enumerate(training_labels) if str(label) == str(target_class)]
|
|
224
|
+
|
|
225
|
+
if not target_indices:
|
|
226
|
+
return {
|
|
227
|
+
"status": "not_found",
|
|
228
|
+
"message": f"No training samples found for class {target_class}",
|
|
229
|
+
}
|
|
230
|
+
|
|
231
|
+
best_dist = float("inf")
|
|
232
|
+
best_idx = -1
|
|
233
|
+
|
|
234
|
+
for idx in target_indices:
|
|
235
|
+
dist = _compute_distance(sample, training_data[idx], feature_ranges)
|
|
236
|
+
if dist < best_dist:
|
|
237
|
+
best_dist = dist
|
|
238
|
+
best_idx = idx
|
|
239
|
+
|
|
240
|
+
if best_idx < 0:
|
|
241
|
+
return {"status": "not_found", "message": "No valid prototype found"}
|
|
242
|
+
|
|
243
|
+
prototype = training_data[best_idx]
|
|
244
|
+
changes = _compute_changes(sample, prototype, feature_names)
|
|
245
|
+
|
|
246
|
+
return {
|
|
247
|
+
"status": "found",
|
|
248
|
+
"method": "prototype",
|
|
249
|
+
"prototype_index": best_idx,
|
|
250
|
+
"distance": round(float(best_dist), 4),
|
|
251
|
+
"n_changes": len(changes),
|
|
252
|
+
"changes": changes,
|
|
253
|
+
"counterfactual_sample": prototype,
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
# --- Full Pipeline ---
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def counterfactual_analysis(
    exp_id: str,
    sample_index: int | None = None,
    sample_data: dict[str, float] | None = None,
    target_class: int | str | None = None,
    predict_fn=None,
    training_data: list[dict] | None = None,
    training_labels: list | None = None,
    feature_names: list[str] | None = None,
    feature_ranges: dict[str, tuple[float, float]] | None = None,
    categorical_features: list[str] | None = None,
    batch_misclassified: bool = False,
    config_path: str = "config.yaml",
    log_path: str = DEFAULT_LOG_PATH,
) -> dict:
    """Run counterfactual analysis for one sample or for all misclassified samples.

    Args:
        exp_id: Experiment ID to analyze.
        sample_index: Index of the sample to explain. Only recorded in the
            report; the feature values must be passed via ``sample_data``.
        sample_data: Feature values of the sample to explain.
        target_class: Desired counterfactual class. When omitted in
            single-sample mode, it is chosen as follows:
            - if ``training_labels`` contain exactly two classes: the class
              other than the current prediction;
            - otherwise: 0/1 flipping.
        predict_fn: Prediction function (sample_dict) -> (class, confidence).
        training_data: Training data for prototype search and batch mode.
        training_labels: Training labels aligned with ``training_data``.
        feature_names: Feature names.
        feature_ranges: Feature value ranges.
        categorical_features: Categorical feature names.
        batch_misclassified: If True, generate for all misclassified samples
            in ``training_data``.
        config_path: Path to config.yaml.
        log_path: Path to experiment log. NOTE(review): accepted for interface
            compatibility but not read here.

    Returns:
        On success, a report dict with ``experiment_id``, ``sample_index``,
        ``target_class``, ``results`` and ``generated_at``. ``results`` is:
        - batch mode: a list of greedy results;
        - single-sample mode: a ``{"greedy", "prototype", "best"}`` dict.

        When the inputs cannot produce an analysis, a dict with an ``error``
        key is returned instead (sometimes with a ``suggestion``).
    """
    # NOTE(review): the loaded config is not used below. The call is kept so
    # behavior (including any failure to load the file) is unchanged.
    load_config(config_path)

    if sample_data is None and sample_index is None and not batch_misclassified:
        return {"error": "Provide --sample <index> or --batch-misclassified"}

    if predict_fn is None:
        return {
            "error": "No prediction function available. "
            "Load the model from the experiment first.",
            "suggestion": f"Run `/turing:counterfactual {exp_id} --sample <index>` "
            "from the experiment directory with train.py available.",
        }

    if feature_names is None:
        return {"error": "Feature names not available. Provide feature_names."}

    if feature_ranges is None:
        feature_ranges = {}

    if batch_misclassified and training_data and training_labels:
        results = _batch_counterfactuals(
            predict_fn, training_data, training_labels,
            feature_names, feature_ranges, categorical_features,
        )
    elif sample_data is not None:
        if target_class is None:
            pred, _ = predict_fn(sample_data)
            target_class = _default_target_class(pred, training_labels)
        results = _single_counterfactual(
            sample_data, predict_fn, target_class, training_data, training_labels,
            feature_names, feature_ranges, categorical_features,
        )
    elif batch_misclassified:
        # Previously this fell through to an empty result list, which was
        # rendered as "0 misclassified samples analyzed".
        return {"error": "Batch mode requires training_data and training_labels."}
    else:
        # An index alone cannot be turned into feature values here.
        return {
            "error": f"No feature values available for sample {sample_index}. "
            "Provide sample_data."
        }

    return {
        "experiment_id": exp_id,
        "sample_index": sample_index,
        "target_class": target_class,
        "results": results,
        "generated_at": datetime.now(timezone.utc).isoformat(),
    }


def _default_target_class(pred, training_labels: list | None):
    """Pick a counterfactual target when none was requested: the 'other' binary class."""
    if training_labels:
        # Compare by string form, matching how predictions and targets are compared.
        classes = {str(label): label for label in training_labels}
        if len(classes) == 2 and str(pred) in classes:
            return next(label for key, label in classes.items() if key != str(pred))
    # Fallback assumes 0/1 labels; a string "1" counts as class 1 too.
    return 0 if pred == 1 or str(pred) == "1" else 1


def _batch_counterfactuals(
    predict_fn,
    training_data: list[dict],
    training_labels: list,
    feature_names: list[str],
    feature_ranges: dict,
    categorical_features: list[str] | None,
) -> list[dict]:
    """Greedy counterfactual toward the true label for every misclassified training sample."""
    results = []
    for i, (data, label) in enumerate(zip(training_data, training_labels)):
        pred, _ = predict_fn(data)
        if str(pred) == str(label):
            continue
        cf = greedy_perturbation(
            data, predict_fn, label, feature_names,
            feature_ranges, categorical_features=categorical_features,
        )
        cf["sample_index"] = i
        cf["true_label"] = label
        results.append(cf)
    return results


def _single_counterfactual(
    sample_data: dict,
    predict_fn,
    target_class,
    training_data: list[dict] | None,
    training_labels: list | None,
    feature_names: list[str],
    feature_ranges: dict,
    categorical_features: list[str] | None,
) -> dict:
    """Run greedy search plus prototype search (when training data exists) and keep the closer one."""
    cf_greedy = greedy_perturbation(
        sample_data, predict_fn, target_class, feature_names,
        feature_ranges, categorical_features=categorical_features,
    )

    cf_proto = None
    if training_data and training_labels:
        cf_proto = prototype_counterfactual(
            sample_data, training_data, training_labels,
            target_class, feature_names, feature_ranges,
        )

    return {
        "greedy": cf_greedy,
        "prototype": cf_proto,
        "best": _select_best([cf_greedy, cf_proto]),
    }
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
def _select_best(candidates: list[dict | None]) -> dict | None:
|
|
363
|
+
"""Select the counterfactual with smallest distance."""
|
|
364
|
+
valid = [c for c in candidates if c and c.get("status") == "found"]
|
|
365
|
+
if not valid:
|
|
366
|
+
return None
|
|
367
|
+
return min(valid, key=lambda c: c.get("distance", float("inf")))
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
# --- Report Formatting ---
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
def save_counterfactual_report(report: dict, output_dir: str = "experiments/counterfactuals") -> Path:
    """Save counterfactual report to YAML.

    The file is written to ``<output_dir>/<experiment_id>-cf-<sample_index>.yaml``.
    ``output_dir`` is created if needed. Naming fallbacks:
    - missing or ``None`` ``sample_index`` (as in batch reports, which always
      carry the key with a ``None`` value): ``batch``;
    - missing or ``None`` ``experiment_id``: ``unknown``.

    Returns:
        Path of the written file.
    """
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    exp_id = report.get("experiment_id")
    if exp_id is None:
        exp_id = "unknown"
    # ``report.get(key, default)`` is not enough: the key is present with None.
    sample = report.get("sample_index")
    if sample is None:
        sample = "batch"

    filepath = out_path / f"{exp_id}-cf-{sample}.yaml"
    with open(filepath, "w", encoding="utf-8") as f:
        yaml.dump(report, f, default_flow_style=False, sort_keys=False)
    return filepath
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
def format_counterfactual_report(report: dict) -> str:
    """Format counterfactual report as readable markdown.

    Args:
        report: Report dict as produced by ``counterfactual_analysis``.

    Returns:
        Markdown text, depending on the report type:
        - error reports: ``ERROR: <error>``, followed by the ``suggestion``
          line when one is present;
        - single-sample reports (``results`` is a dict): the best
          counterfactual's change table and a comparison of the methods that
          succeeded;
        - batch reports (``results`` is a list): summary counts.
    """
    import numbers

    if "error" in report:
        text = f"ERROR: {report['error']}"
        suggestion = report.get("suggestion")
        if suggestion:
            text += f"\n{suggestion}"
        return text

    lines = ["# Counterfactual Explanation", ""]
    lines.append(f"**Experiment:** {report.get('experiment_id', 'N/A')}")
    lines.append(f"**Sample:** {report.get('sample_index', 'N/A')}")
    lines.append(f"**Target class:** {report.get('target_class', 'N/A')}")
    lines.append("")

    results = report.get("results", {})

    if isinstance(results, dict):
        # Either method's result may be absent or None (e.g. no training data
        # for the prototype search), so normalise both to dicts.
        greedy = results.get("greedy") or {}
        proto = results.get("prototype") or {}
        best = results.get("best")

        if best and best.get("status") == "found":
            lines.append(f"**Method:** {best.get('method', 'greedy')}")
            lines.append(f"**Distance:** {best.get('distance', 'N/A')}")
            lines.append(f"**Changes needed:** {best.get('n_changes', 0)}")
            lines.append("")

            changes = best.get("changes", [])
            if changes:
                lines.append("| Feature | Original | Counterfactual | Change |")
                lines.append("|---------|----------|----------------|--------|")
                for c in changes:
                    delta = c.get("delta", "")
                    # numbers ABCs also cover NumPy scalar deltas.
                    if isinstance(delta, numbers.Integral):
                        delta_str = f"{int(delta):+d}"
                    elif isinstance(delta, numbers.Real):
                        delta_str = f"{float(delta):+.4f}"
                    else:
                        delta_str = str(delta)
                    lines.append(
                        f"| {c['feature']} | {c['original']} | {c['counterfactual']} | {delta_str} |"
                    )
        elif greedy.get("status") == "already_target":
            # Not a search failure: the model already predicts the target.
            lines.append(greedy.get("message", "Sample is already predicted as the target class."))
        else:
            lines.append("No counterfactual found within search budget.")

        # Show method comparison
        if greedy.get("status") == "found" or proto.get("status") == "found":
            lines.append("")
            lines.append("**Method comparison:**")
            if greedy.get("status") == "found":
                lines.append(f"- Greedy: distance={greedy.get('distance')}, changes={greedy.get('n_changes')}")
            if proto.get("status") == "found":
                lines.append(f"- Prototype: distance={proto.get('distance')}, changes={proto.get('n_changes')}")

    elif isinstance(results, list):
        lines.append(f"**Batch results:** {len(results)} misclassified samples analyzed")
        found = sum(1 for r in results if r.get("status") == "found")
        lines.append(f"**Counterfactuals found:** {found}/{len(results)}")

    lines.append("")
    lines.append(f"*Generated: {report.get('generated_at', 'N/A')}*")
    return "\n".join(lines)
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
# --- CLI ---
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def main():
    """CLI entry point: parse arguments, run the analysis, print and save the report.

    When the analysis returns an error report, the error is printed (as JSON
    with ``--json``, otherwise as text). The process then exits with status 1
    so shell callers can detect the failure. Successful reports are saved as
    YAML via ``save_counterfactual_report``.
    """
    parser = argparse.ArgumentParser(
        description="Counterfactual explanations — find minimum input changes to flip predictions"
    )
    parser.add_argument("exp_id", nargs="?", help="Experiment ID")
    parser.add_argument("--sample", type=int, help="Sample index to explain")
    parser.add_argument("--target", help="Target class for counterfactual")
    parser.add_argument("--batch-misclassified", action="store_true",
                        help="Generate counterfactuals for all misclassified samples")
    parser.add_argument("--config", default="config.yaml", help="Path to config.yaml")
    parser.add_argument("--log", default=DEFAULT_LOG_PATH, help="Path to experiment log")
    parser.add_argument("--json", action="store_true", help="Output raw JSON")

    args = parser.parse_args()

    if not args.exp_id:
        parser.error("Please provide an experiment ID")

    report = counterfactual_analysis(
        exp_id=args.exp_id,
        sample_index=args.sample,
        target_class=args.target,
        batch_misclassified=args.batch_misclassified,
        config_path=args.config,
        log_path=args.log,
    )

    if args.json:
        print(json.dumps(report, indent=2, default=str))
    else:
        print(format_counterfactual_report(report))

    if "error" in report:
        # Previously the process exited 0 even though no analysis ran.
        sys.exit(1)

    saved = save_counterfactual_report(report)
    if not args.json:
        print(f"\nSaved: {saved}")


if __name__ == "__main__":
    main()
|