claude-turing 4.1.0 → 4.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,485 @@
1
+ #!/usr/bin/env python3
2
+ """Input-level counterfactual explanations for the autoresearch pipeline.
3
+
4
+ For a given prediction, finds the smallest input change that would flip
5
+ the outcome. "This sample was classified as X — what's the minimum change
6
+ to make it Y?" Useful for debugging predictions and regulatory explanations.
7
+
8
+ Usage:
9
+ python scripts/counterfactual_explanation.py exp-042 --sample 1247
10
+ python scripts/counterfactual_explanation.py exp-042 --sample 1247 --target 0
11
+ python scripts/counterfactual_explanation.py exp-042 --batch-misclassified
12
+ python scripts/counterfactual_explanation.py exp-042 --sample 1247 --json
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import json
19
+ import sys
20
+ from datetime import datetime, timezone
21
+ from pathlib import Path
22
+
23
+ import numpy as np
24
+ import yaml
25
+
26
+ from scripts.turing_io import load_config, load_experiments
27
+
28
# Fallback experiment-log location (default of ``--log`` and of ``log_path``).
DEFAULT_LOG_PATH: str = "experiments/log.jsonl"
# Default upper bound on greedy search rounds.
DEFAULT_MAX_ITERATIONS: int = 100
# Label of the distance metric; NOTE(review): not referenced by the code in
# this module section -- the distance helper always computes normalized L2.
DEFAULT_DISTANCE_METRIC: str = "normalized_l2"
31
+
32
+
33
+ # --- Feature Perturbation ---
34
+
35
+
36
def greedy_perturbation(
    sample: dict[str, float],
    predict_fn,
    target_class: int | str,
    feature_names: list[str],
    feature_ranges: dict[str, tuple[float, float]],
    max_iterations: int = DEFAULT_MAX_ITERATIONS,
    categorical_features: list[str] | None = None,
) -> dict:
    """Find a counterfactual by greedily changing one feature per round.

    Each round tries every candidate value of every feature on top of the
    current working point. If any trial is predicted as ``target_class``, the
    one closest to the original ``sample`` is returned. Otherwise the working
    point moves to the trial with the lowest reported confidence and the search
    continues, so counterfactuals that need several feature changes can be
    reached. The search stops early when no trial lowers the confidence.

    Args:
        sample: Original sample as {feature_name: value}.
        predict_fn: Function(sample_dict) -> (predicted_class, confidence).
            The confidence is assumed to refer to the returned class, so a
            lower value on a non-target prediction is treated as being closer
            to a decision boundary (heuristic -- confirm for your model).
        target_class: Desired target class (compared via ``str``).
        feature_names: Ordered list of feature names to perturb.
        feature_ranges: {feature: (min, max)} for numeric features, or the
            list of allowed values for categorical features.
        max_iterations: Maximum number of greedy rounds.
        categorical_features: Features that are categorical (discrete changes).

    Returns:
        A dict whose ``status`` is ``"already_target"``, ``"not_found"`` or
        ``"found"``. A found result carries the original and counterfactual
        predictions and confidences, the distance, the list of changes and
        the counterfactual sample.
    """
    categorical = set(categorical_features or ())

    current = dict(sample)
    original_pred, original_conf = predict_fn(sample)

    if str(original_pred) == str(target_class):
        return {
            "status": "already_target",
            "message": f"Sample is already predicted as {target_class}",
        }

    current_conf = float(original_conf)
    best_cf = None
    best_pred = None
    best_conf = None
    best_distance = float("inf")

    for _ in range(max_iterations):
        # Best non-flipping move seen this round; it must strictly lower the
        # confidence of the working point to count as progress.
        step_trial = None
        step_conf = current_conf

        for feat in feature_names:
            if feat in categorical:
                candidates = _categorical_candidates(feat, current[feat], feature_ranges.get(feat))
            else:
                candidates = _numeric_candidates(
                    current[feat],
                    feature_ranges.get(feat, (0, 1)),
                    n_steps=5,
                )

            for candidate_val in candidates:
                trial = dict(current)
                trial[feat] = candidate_val
                pred, conf = predict_fn(trial)

                if str(pred) == str(target_class):
                    # Distance is always measured from the ORIGINAL sample.
                    dist = _compute_distance(sample, trial, feature_ranges)
                    if dist < best_distance:
                        best_distance = dist
                        best_cf = trial
                        best_pred, best_conf = pred, conf
                elif float(conf) < step_conf:
                    step_trial, step_conf = trial, float(conf)

        if best_cf is not None:
            # A flip was found from this working point; one more round from
            # the same point would only repeat the same trials.
            break
        if step_trial is None:
            # Local minimum: no single change makes progress, so further
            # rounds would re-evaluate identical trials.
            break
        current, current_conf = step_trial, step_conf

    if best_cf is None:
        return {
            "status": "not_found",
            "message": f"Could not find counterfactual within {max_iterations} iterations",
            "original_prediction": original_pred,
            "original_confidence": float(original_conf),
        }

    changes = _compute_changes(sample, best_cf, feature_names)

    return {
        "status": "found",
        "original_prediction": original_pred,
        "original_confidence": float(original_conf),
        "counterfactual_prediction": best_pred,
        "counterfactual_confidence": float(best_conf),
        "distance": round(float(best_distance), 4),
        "n_changes": len(changes),
        "changes": changes,
        "counterfactual_sample": best_cf,
    }
125
+
126
+
127
+ def _numeric_candidates(current: float, value_range: tuple[float, float], n_steps: int = 5) -> list[float]:
128
+ """Generate candidate values for a numeric feature."""
129
+ low, high = value_range
130
+ step = (high - low) / max(n_steps, 1)
131
+ candidates = []
132
+ for i in range(n_steps + 1):
133
+ val = low + i * step
134
+ if val != current:
135
+ candidates.append(val)
136
+ return candidates
137
+
138
+
139
+ def _categorical_candidates(
140
+ feature: str,
141
+ current_value,
142
+ value_range: tuple | list | None,
143
+ ) -> list:
144
+ """Generate candidate values for a categorical feature."""
145
+ if value_range is None:
146
+ return []
147
+ if isinstance(value_range, (tuple, list)):
148
+ return [v for v in value_range if v != current_value]
149
+ return []
150
+
151
+
152
+ def _compute_distance(
153
+ original: dict[str, float],
154
+ counterfactual: dict[str, float],
155
+ feature_ranges: dict[str, tuple[float, float]],
156
+ ) -> float:
157
+ """Compute normalized L2 distance between original and counterfactual."""
158
+ total = 0.0
159
+ for feat in original:
160
+ orig_val = original[feat]
161
+ cf_val = counterfactual.get(feat, orig_val)
162
+ feat_range = feature_ranges.get(feat, (0, 1))
163
+
164
+ if isinstance(orig_val, str) or (isinstance(feat_range, (tuple, list)) and len(feat_range) > 2):
165
+ # Categorical: 1 if changed, 0 if same
166
+ total += 0.0 if orig_val == cf_val else 1.0
167
+ else:
168
+ low, high = feat_range[0], feat_range[1]
169
+ span = high - low if high != low else 1
170
+ normalized_diff = (cf_val - orig_val) / span
171
+ total += normalized_diff ** 2
172
+ return float(np.sqrt(total))
173
+
174
+
175
+ def _compute_changes(
176
+ original: dict[str, float],
177
+ counterfactual: dict[str, float],
178
+ feature_names: list[str],
179
+ ) -> list[dict]:
180
+ """Compute the list of changed features."""
181
+ changes = []
182
+ for feat in feature_names:
183
+ orig = original.get(feat)
184
+ cf = counterfactual.get(feat)
185
+ if orig != cf:
186
+ change = {
187
+ "feature": feat,
188
+ "original": orig,
189
+ "counterfactual": cf,
190
+ }
191
+ if isinstance(orig, (int, float)) and isinstance(cf, (int, float)):
192
+ change["delta"] = round(cf - orig, 6)
193
+ else:
194
+ change["delta"] = "category_change"
195
+ changes.append(change)
196
+ return changes
197
+
198
+
199
+ # --- Prototype-Based Search ---
200
+
201
+
202
def prototype_counterfactual(
    sample: dict[str, float],
    training_data: list[dict[str, float]],
    training_labels: list,
    target_class: int | str,
    feature_names: list[str],
    feature_ranges: dict[str, tuple[float, float]],
) -> dict:
    """Use the closest target-class training sample as the counterfactual.

    Args:
        sample: Original sample.
        training_data: Training samples as dicts.
        training_labels: Label for each training sample (same order).
        target_class: Desired class; labels are matched via ``str``.
        feature_names: Features reported in the change list.
        feature_ranges: {feature: (min, max)} used to normalize distances.

    Returns:
        ``{"status": "not_found", ...}`` when no training sample carries the
        target label or none yields a usable distance; otherwise a
        ``"found"`` result with ``method`` set to ``"prototype"``, the
        prototype's index, distance, changes and the prototype itself.
    """
    wanted = str(target_class)
    target_indices = [pos for pos, label in enumerate(training_labels) if str(label) == wanted]

    if not target_indices:
        return {
            "status": "not_found",
            "message": f"No training samples found for class {target_class}",
        }

    nearest_idx, nearest_dist = -1, float("inf")
    for pos in target_indices:
        candidate_dist = _compute_distance(sample, training_data[pos], feature_ranges)
        # Strict comparison keeps the earliest sample on ties.
        if candidate_dist < nearest_dist:
            nearest_idx, nearest_dist = pos, candidate_dist

    if nearest_idx < 0:
        # No distance compared below infinity.
        return {"status": "not_found", "message": "No valid prototype found"}

    prototype = training_data[nearest_idx]
    changes = _compute_changes(sample, prototype, feature_names)

    return {
        "status": "found",
        "method": "prototype",
        "prototype_index": nearest_idx,
        "distance": round(float(nearest_dist), 4),
        "n_changes": len(changes),
        "changes": changes,
        "counterfactual_sample": prototype,
    }
255
+
256
+
257
+ # --- Full Pipeline ---
258
+
259
+
260
def counterfactual_analysis(
    exp_id: str,
    sample_index: int | None = None,
    sample_data: dict[str, float] | None = None,
    target_class: int | str | None = None,
    predict_fn=None,
    training_data: list[dict] | None = None,
    training_labels: list | None = None,
    feature_names: list[str] | None = None,
    feature_ranges: dict[str, tuple[float, float]] | None = None,
    categorical_features: list[str] | None = None,
    batch_misclassified: bool = False,
    config_path: str = "config.yaml",
    log_path: str = DEFAULT_LOG_PATH,
) -> dict:
    """Run counterfactual analysis for one sample or all misclassified ones.

    Args:
        exp_id: Experiment ID to analyze (echoed into the report).
        sample_index: Index of the sample to explain. When ``sample_data`` is
            not given, the sample is taken from ``training_data`` at this
            position (the same indexing batch mode reports).
        sample_data: Direct sample data (takes precedence over the index).
        target_class: Desired counterfactual class. When omitted in
            single-sample mode, a binary 0/1 flip of the prediction is used.
        predict_fn: Prediction function (sample_dict) -> (class, confidence).
        training_data: Training samples (prototype search, batch mode,
            index lookup).
        training_labels: Labels matching ``training_data``.
        feature_names: Feature names.
        feature_ranges: Feature value ranges; defaults to an empty mapping.
        categorical_features: Categorical feature names.
        batch_misclassified: If True and training data is available, explain
            every training sample whose prediction differs from its label,
            using the true label as the target.
        config_path: Path to config.yaml.
        log_path: Path to experiment log (not used by this function).

    Returns:
        ``{"error": ...}`` when inputs are insufficient; otherwise a report
        with ``experiment_id``, ``sample_index``, ``target_class``,
        ``results`` (a list in batch mode, a dict with ``greedy`` /
        ``prototype`` / ``best`` in single-sample mode) and ``generated_at``.
    """
    # NOTE(review): the loaded config is not used below; the call is kept in
    # case load_config validates the file -- confirm and drop it if not.
    load_config(config_path)

    if sample_data is None and sample_index is None and not batch_misclassified:
        return {"error": "Provide --sample <index> or --batch-misclassified"}

    if predict_fn is None:
        return {
            "error": "No prediction function available. "
            "Load the model from the experiment first.",
            "suggestion": f"Run `/turing:counterfactual {exp_id} --sample <index>` "
            "from the experiment directory with train.py available.",
        }

    if feature_names is None:
        return {"error": "Feature names not available. Provide feature_names."}

    if feature_ranges is None:
        feature_ranges = {}

    results: list | dict = []

    if batch_misclassified and training_data and training_labels:
        for i, (data, label) in enumerate(zip(training_data, training_labels)):
            pred, _conf = predict_fn(data)
            if str(pred) != str(label):
                cf = greedy_perturbation(
                    data, predict_fn, label, feature_names,
                    feature_ranges, categorical_features=categorical_features,
                )
                cf["sample_index"] = i
                cf["true_label"] = label
                results.append(cf)
    else:
        if sample_data is None and sample_index is not None and training_data:
            # Resolve the index against the only dataset available here.
            if not 0 <= sample_index < len(training_data):
                return {
                    "error": f"Sample index {sample_index} is out of range "
                    f"(0..{len(training_data) - 1})"
                }
            sample_data = training_data[sample_index]

        if sample_data is None:
            # Report the gap instead of returning an empty result silently.
            if batch_misclassified:
                return {"error": "Batch mode requires training_data and training_labels."}
            return {
                "error": f"Sample {sample_index} could not be resolved. "
                "Provide sample_data or training_data."
            }

        if target_class is None:
            pred, _ = predict_fn(sample_data)
            # Flip to opposite for binary; also accept the label "1" as text.
            target_class = 0 if (pred == 1 or str(pred) == "1") else 1

        # Try greedy perturbation
        cf_greedy = greedy_perturbation(
            sample_data, predict_fn, target_class, feature_names,
            feature_ranges, categorical_features=categorical_features,
        )

        # Try prototype-based if training data available
        cf_proto = None
        if training_data and training_labels:
            cf_proto = prototype_counterfactual(
                sample_data, training_data, training_labels,
                target_class, feature_names, feature_ranges,
            )

        results = {
            "greedy": cf_greedy,
            "prototype": cf_proto,
            "best": _select_best([cf_greedy, cf_proto]),
        }

    return {
        "experiment_id": exp_id,
        "sample_index": sample_index,
        "target_class": target_class,
        "results": results,
        "generated_at": datetime.now(timezone.utc).isoformat(),
    }
360
+
361
+
362
+ def _select_best(candidates: list[dict | None]) -> dict | None:
363
+ """Select the counterfactual with smallest distance."""
364
+ valid = [c for c in candidates if c and c.get("status") == "found"]
365
+ if not valid:
366
+ return None
367
+ return min(valid, key=lambda c: c.get("distance", float("inf")))
368
+
369
+
370
+ # --- Report Formatting ---
371
+
372
+
373
def save_counterfactual_report(report: dict, output_dir: str = "experiments/counterfactuals") -> Path:
    """Save counterfactual report to YAML.

    The file is written to ``<output_dir>/<experiment_id>-cf-<sample>.yaml``,
    creating the directory if needed. ``<sample>`` is the report's
    ``sample_index``, or ``batch`` when that is absent or None (batch
    reports carry ``sample_index: None``). A missing or empty experiment ID
    becomes ``unknown``.

    Args:
        report: Report dict to serialize.
        output_dir: Directory that receives the YAML file.

    Returns:
        Path of the written file.
    """
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    exp_id = report.get("experiment_id") or "unknown"
    sample = report.get("sample_index")
    if sample is None:
        # The key exists with value None in batch mode, so a .get() default
        # alone would never apply.
        sample = "batch"

    filepath = out_path / f"{exp_id}-cf-{sample}.yaml"
    # Explicit encoding keeps the output independent of the platform locale.
    with open(filepath, "w", encoding="utf-8") as f:
        yaml.dump(report, f, default_flow_style=False, sort_keys=False)
    return filepath
383
+
384
+
385
def format_counterfactual_report(report: dict) -> str:
    """Render a counterfactual report as markdown text.

    An error report becomes a single ``ERROR: ...`` line. Otherwise the
    output has a header (experiment, sample, target class), then either the
    best counterfactual with its change table and a method comparison
    (single-sample results, a dict) or batch counts (batch results, a list),
    and finally the generation timestamp.
    """
    if "error" in report:
        return f"ERROR: {report['error']}"

    out = [
        "# Counterfactual Explanation",
        "",
        f"**Experiment:** {report.get('experiment_id', 'N/A')}",
        f"**Sample:** {report.get('sample_index', 'N/A')}",
        f"**Target class:** {report.get('target_class', 'N/A')}",
        "",
    ]

    results = report.get("results", {})

    if isinstance(results, dict):
        best = results.get("best")
        if not best or best.get("status") != "found":
            out.append("No counterfactual found within search budget.")
        else:
            out.extend(
                [
                    f"**Method:** {best.get('method', 'greedy')}",
                    f"**Distance:** {best.get('distance', 'N/A')}",
                    f"**Changes needed:** {best.get('n_changes', 0)}",
                    "",
                ]
            )

            rows = best.get("changes", [])
            if rows:
                out.append("| Feature | Original | Counterfactual | Change |")
                out.append("|---------|----------|----------------|--------|")
                for row in rows:
                    delta = row.get("delta", "")
                    # Floats get 4 decimals, ints a signed integer, anything
                    # else (e.g. "category_change") is shown verbatim.
                    if isinstance(delta, float):
                        delta_text = f"{delta:+.4f}"
                    elif isinstance(delta, int):
                        delta_text = f"{delta:+d}"
                    else:
                        delta_text = str(delta)
                    out.append(
                        f"| {row['feature']} | {row['original']} | {row['counterfactual']} | {delta_text} |"
                    )

        # Show method comparison
        greedy = results.get("greedy", {})
        proto = results.get("prototype", {})
        greedy_found = greedy.get("status") == "found"
        proto_found = bool(proto) and proto.get("status") == "found"
        if greedy_found or proto_found:
            out.append("")
            out.append("**Method comparison:**")
            if greedy_found:
                out.append(f"- Greedy: distance={greedy.get('distance')}, changes={greedy.get('n_changes')}")
            if proto_found:
                out.append(f"- Prototype: distance={proto.get('distance')}, changes={proto.get('n_changes')}")

    elif isinstance(results, list):
        n_found = len([r for r in results if r.get("status") == "found"])
        out.append(f"**Batch results:** {len(results)} misclassified samples analyzed")
        out.append(f"**Counterfactuals found:** {n_found}/{len(results)}")

    out.append("")
    out.append(f"*Generated: {report.get('generated_at', 'N/A')}*")
    return "\n".join(out)
441
+
442
+
443
+ # --- CLI ---
444
+
445
+
446
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the counterfactual script."""
    parser = argparse.ArgumentParser(
        description="Counterfactual explanations — find minimum input changes to flip predictions"
    )
    parser.add_argument("exp_id", nargs="?", help="Experiment ID")
    parser.add_argument("--sample", type=int, help="Sample index to explain")
    parser.add_argument("--target", help="Target class for counterfactual")
    parser.add_argument(
        "--batch-misclassified",
        action="store_true",
        help="Generate counterfactuals for all misclassified samples",
    )
    parser.add_argument("--config", default="config.yaml", help="Path to config.yaml")
    parser.add_argument("--log", default=DEFAULT_LOG_PATH, help="Path to experiment log")
    parser.add_argument("--json", action="store_true", help="Output raw JSON")
    return parser


def main():
    """CLI entry point: run the analysis, print the report, save it on success.

    The experiment ID is required; without it the parser exits with an
    error. The report is printed as JSON with ``--json`` and as markdown
    otherwise. Reports without an ``error`` key are saved, and the saved
    path is announced only in markdown mode so JSON output stays parseable.
    """
    parser = _build_arg_parser()
    opts = parser.parse_args()

    if not opts.exp_id:
        parser.error("Please provide an experiment ID")

    report = counterfactual_analysis(
        exp_id=opts.exp_id,
        sample_index=opts.sample,
        target_class=opts.target,
        batch_misclassified=opts.batch_misclassified,
        config_path=opts.config,
        log_path=opts.log,
    )

    rendered = (
        json.dumps(report, indent=2, default=str)
        if opts.json
        else format_counterfactual_report(report)
    )
    print(rendered)

    if "error" in report:
        return

    saved = save_counterfactual_report(report)
    if not opts.json:
        print(f"\nSaved: {saved}")


if __name__ == "__main__":
    main()