interpkit 0.5.0__tar.gz → 0.6.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. {interpkit-0.5.0 → interpkit-0.6.0}/PKG-INFO +26 -2
  2. {interpkit-0.5.0 → interpkit-0.6.0}/README.md +23 -1
  3. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/__init__.py +19 -0
  4. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/cli/main.py +342 -7
  5. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/enums.py +16 -0
  6. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/inputs.py +26 -8
  7. interpkit-0.6.0/interpkit/core/interventions.py +492 -0
  8. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/model.py +223 -3
  9. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/paths.py +18 -1
  10. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/render.py +176 -0
  11. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/support_matrix.py +8 -0
  12. interpkit-0.6.0/interpkit/core/topk.py +63 -0
  13. interpkit-0.6.0/interpkit/ops/_atp.py +13 -0
  14. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/_hooks.py +40 -1
  15. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/ablate.py +9 -39
  16. interpkit-0.6.0/interpkit/ops/atp.py +230 -0
  17. interpkit-0.6.0/interpkit/ops/eap.py +355 -0
  18. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/find_circuit.py +130 -65
  19. interpkit-0.6.0/interpkit/ops/generate.py +292 -0
  20. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/lens.py +50 -5
  21. interpkit-0.6.0/interpkit/ops/maxact.py +347 -0
  22. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/patch.py +32 -56
  23. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/steer.py +5 -24
  24. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/trace.py +10 -56
  25. interpkit-0.6.0/interpkit/ops/tuned_lens.py +437 -0
  26. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/PKG-INFO +26 -2
  27. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/SOURCES.txt +14 -0
  28. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/requires.txt +3 -0
  29. {interpkit-0.5.0 → interpkit-0.6.0}/pyproject.toml +4 -1
  30. interpkit-0.6.0/tests/test_atp.py +68 -0
  31. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_cli.py +133 -0
  32. interpkit-0.6.0/tests/test_eap.py +138 -0
  33. interpkit-0.6.0/tests/test_generate.py +186 -0
  34. interpkit-0.6.0/tests/test_interventions.py +241 -0
  35. interpkit-0.6.0/tests/test_maxact.py +149 -0
  36. interpkit-0.6.0/tests/test_topk.py +58 -0
  37. interpkit-0.6.0/tests/test_tuned_lens.py +140 -0
  38. interpkit-0.5.0/interpkit/ops/_atp.py +0 -182
  39. {interpkit-0.5.0 → interpkit-0.6.0}/LICENSE +0 -0
  40. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/__main__.py +0 -0
  41. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/cli/__init__.py +0 -0
  42. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/__init__.py +0 -0
  43. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/__init__.py +0 -0
  44. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/blocks.py +0 -0
  45. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/family.py +0 -0
  46. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/heads.py +0 -0
  47. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/layers.py +0 -0
  48. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/names.py +0 -0
  49. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/probe.py +0 -0
  50. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/residual.py +0 -0
  51. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/resolve.py +0 -0
  52. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/tree.py +0 -0
  53. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/types.py +0 -0
  54. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/cache.py +0 -0
  55. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/exceptions.py +0 -0
  56. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/html.py +0 -0
  57. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/loader.py +0 -0
  58. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/plot.py +0 -0
  59. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/registry.py +0 -0
  60. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/theme.py +0 -0
  61. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/tl_compat.py +0 -0
  62. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/__init__.py +0 -0
  63. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/activations.py +0 -0
  64. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/attention.py +0 -0
  65. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/attribute.py +0 -0
  66. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/batch.py +0 -0
  67. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/circuits.py +0 -0
  68. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/diff.py +0 -0
  69. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/dla.py +0 -0
  70. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/heads.py +0 -0
  71. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/inspect.py +0 -0
  72. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/probe.py +0 -0
  73. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/report.py +0 -0
  74. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/sae.py +0 -0
  75. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/scan.py +0 -0
  76. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/dependency_links.txt +0 -0
  77. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/entry_points.txt +0 -0
  78. {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/top_level.txt +0 -0
  79. {interpkit-0.5.0 → interpkit-0.6.0}/setup.cfg +0 -0
  80. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_ablate.py +0 -0
  81. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_activations.py +0 -0
  82. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_archinfo_serialization.py +0 -0
  83. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_architectures.py +0 -0
  84. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_attention.py +0 -0
  85. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_attribute.py +0 -0
  86. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_audit_regressions.py +0 -0
  87. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_cache.py +0 -0
  88. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_cache_invalidation.py +0 -0
  89. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_capabilities.py +0 -0
  90. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_chat.py +0 -0
  91. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_diff.py +0 -0
  92. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_discovery.py +0 -0
  93. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_discovery_units.py +0 -0
  94. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_error_handling.py +0 -0
  95. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_html.py +0 -0
  96. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_inputs.py +0 -0
  97. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_inspect.py +0 -0
  98. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_invariants.py +0 -0
  99. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_lens.py +0 -0
  100. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_load_params.py +0 -0
  101. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_multi_arch.py +0 -0
  102. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_ops.py +0 -0
  103. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_param_variants.py +0 -0
  104. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_patch.py +0 -0
  105. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_phase3_regressions.py +0 -0
  106. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_plot_internals.py +0 -0
  107. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_plots.py +0 -0
  108. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_probe.py +0 -0
  109. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_registry.py +0 -0
  110. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_regressions.py +0 -0
  111. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_render_internals.py +0 -0
  112. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_resolver.py +0 -0
  113. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_resolver_golden.py +0 -0
  114. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_robustness_audit.py +0 -0
  115. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_sae.py +0 -0
  116. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_seq2seq_contract.py +0 -0
  117. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_steer.py +0 -0
  118. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_tl_compat.py +0 -0
  119. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_tl_ops.py +0 -0
  120. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_trace.py +0 -0
  121. {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_validation.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: interpkit
3
- Version: 0.5.0
3
+ Version: 0.6.0
4
4
  Summary: Mech interp for any HuggingFace model.
5
5
  Author: Davide Zani
6
6
  License-Expression: MIT
@@ -34,6 +34,8 @@ Provides-Extra: vision
34
34
  Requires-Dist: torchvision>=0.16; extra == "vision"
35
35
  Provides-Extra: probe
36
36
  Requires-Dist: scikit-learn>=1.3; extra == "probe"
37
+ Provides-Extra: data
38
+ Requires-Dist: datasets>=2.14; extra == "data"
37
39
  Provides-Extra: dev
38
40
  Requires-Dist: pytest>=7.0; extra == "dev"
39
41
  Requires-Dist: pytest-timeout>=2.2; extra == "dev"
@@ -186,7 +188,13 @@ See [examples/10_chat_models.ipynb](examples/10_chat_models.ipynb) for a full wa
186
188
  | **`ov_scores`** | OV circuit analysis — W_OV matrix per head | Transformers |
187
189
  | **`qk_scores`** | QK circuit analysis — W_QK matrix per head | Transformers |
188
190
  | **`composition`** | Q/K/V composition scores between heads in two layers | Transformers |
189
- | **`find_circuit`** | Automated circuit discovery via iterative ablation | Transformers |
191
+ | **`find_circuit`** | Automated circuit discovery via iterative ablation or EAP-based selection with causal verification | Transformers |
192
+ | **`generate`** | Generation with interventions active across every decode step + per-token lens capture | Generative LMs |
193
+ | **`intervene`** | Context manager applying steer/ablate/patch interventions to any op | Any model |
194
+ | **`atp`** | Attribution Patching — first-order patch-effect scores for all modules in 3 passes | Any model |
195
+ | **`eap`** | Edge Attribution Patching — gradient-based component → residual-stream edge scores (EAP-IG via `ig_steps`) | Causal LMs |
196
+ | **`train_tuned_lens`** | Train per-layer tuned-lens translators (Belrose et al. 2023); use via `lens(kind="tuned")` | LMs |
197
+ | **`max_activating`** | Scan a corpus for the examples that most activate a neuron / SAE feature / head | Any model |
190
198
  | **`batch`** | Run any operation over a dataset with result aggregation | Any model |
191
199
 
192
200
  ---
@@ -482,6 +490,20 @@ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jb
482
490
  interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
483
491
  interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
484
492
 
493
+ # Generation-time interventions + per-token lens trajectories
494
+ interpkit generate gpt2 "I feel" --positive " joy" --negative " fear" --at transformer.h.6 --scale 8
495
+ interpkit generate gpt2 "The capital of France is" --capture lens
496
+
497
+ # Gradient-based circuit discovery
498
+ interpkit atp gpt2 --clean "The capital of France is" --corrupted "The capital of Germany is"
499
+ interpkit eap gpt2 --clean "..." --corrupted "..." --ig-steps 5
500
+ interpkit find-circuit gpt2 --clean "..." --corrupted "..." --method eap --threshold 0.3
501
+
502
+ # Tuned lens + max-activating examples
503
+ interpkit train-tuned-lens gpt2 --corpus-file texts.txt --save lens_dir/
504
+ interpkit lens gpt2 "The capital of France is" --tuned-lens lens_dir/
505
+ interpkit maxact gpt2 --at transformer.h.6.mlp --neuron 42 --texts-file corpus.txt
506
+
485
507
  # Chat / instruct models — applies the tokenizer's chat template automatically
486
508
  interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "Write a haiku about cats." --max-new-tokens 64
487
509
  interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "What is 2+2?" --system "You are terse." --show-prompt
@@ -592,6 +614,8 @@ See the [`examples/`](examples/) directory for Jupyter notebooks:
592
614
  | `08_dla_and_circuits` | DLA, head activations, residual decomposition, OV/QK analysis, composition, circuit discovery |
593
615
  | `09_scan_and_batch` | Auto-scan, batch operations, dataset workflows |
594
616
  | `10_chat_models` | Chat-template handling, `model.chat()`, message-list inputs, chat-style steering |
617
+ | `11_generation_interventions` | Steering/ablation active across every decode step, per-token lens trajectories, positional interventions, `model.intervene()` |
618
+ | `12_circuit_discovery_and_lenses` | Attribution Patching, Edge Attribution Patching, EAP-driven `find_circuit`, tuned lens, max-activating examples |
595
619
 
596
620
  ---
597
621
 
@@ -136,7 +136,13 @@ See [examples/10_chat_models.ipynb](examples/10_chat_models.ipynb) for a full wa
136
136
  | **`ov_scores`** | OV circuit analysis — W_OV matrix per head | Transformers |
137
137
  | **`qk_scores`** | QK circuit analysis — W_QK matrix per head | Transformers |
138
138
  | **`composition`** | Q/K/V composition scores between heads in two layers | Transformers |
139
- | **`find_circuit`** | Automated circuit discovery via iterative ablation | Transformers |
139
+ | **`find_circuit`** | Automated circuit discovery via iterative ablation or EAP-based selection with causal verification | Transformers |
140
+ | **`generate`** | Generation with interventions active across every decode step + per-token lens capture | Generative LMs |
141
+ | **`intervene`** | Context manager applying steer/ablate/patch interventions to any op | Any model |
142
+ | **`atp`** | Attribution Patching — first-order patch-effect scores for all modules in 3 passes | Any model |
143
+ | **`eap`** | Edge Attribution Patching — gradient-based component → residual-stream edge scores (EAP-IG via `ig_steps`) | Causal LMs |
144
+ | **`train_tuned_lens`** | Train per-layer tuned-lens translators (Belrose et al. 2023); use via `lens(kind="tuned")` | LMs |
145
+ | **`max_activating`** | Scan a corpus for the examples that most activate a neuron / SAE feature / head | Any model |
140
146
  | **`batch`** | Run any operation over a dataset with result aggregation | Any model |
141
147
 
142
148
  ---
@@ -432,6 +438,20 @@ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jb
432
438
  interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
433
439
  interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
434
440
 
441
+ # Generation-time interventions + per-token lens trajectories
442
+ interpkit generate gpt2 "I feel" --positive " joy" --negative " fear" --at transformer.h.6 --scale 8
443
+ interpkit generate gpt2 "The capital of France is" --capture lens
444
+
445
+ # Gradient-based circuit discovery
446
+ interpkit atp gpt2 --clean "The capital of France is" --corrupted "The capital of Germany is"
447
+ interpkit eap gpt2 --clean "..." --corrupted "..." --ig-steps 5
448
+ interpkit find-circuit gpt2 --clean "..." --corrupted "..." --method eap --threshold 0.3
449
+
450
+ # Tuned lens + max-activating examples
451
+ interpkit train-tuned-lens gpt2 --corpus-file texts.txt --save lens_dir/
452
+ interpkit lens gpt2 "The capital of France is" --tuned-lens lens_dir/
453
+ interpkit maxact gpt2 --at transformer.h.6.mlp --neuron 42 --texts-file corpus.txt
454
+
435
455
  # Chat / instruct models — applies the tokenizer's chat template automatically
436
456
  interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "Write a haiku about cats." --max-new-tokens 64
437
457
  interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "What is 2+2?" --system "You are terse." --show-prompt
@@ -542,6 +562,8 @@ See the [`examples/`](examples/) directory for Jupyter notebooks:
542
562
  | `08_dla_and_circuits` | DLA, head activations, residual decomposition, OV/QK analysis, composition, circuit discovery |
543
563
  | `09_scan_and_batch` | Auto-scan, batch operations, dataset workflows |
544
564
  | `10_chat_models` | Chat-template handling, `model.chat()`, message-list inputs, chat-style steering |
565
+ | `11_generation_interventions` | Steering/ablation active across every decode step, per-token lens trajectories, positional interventions, `model.intervene()` |
566
+ | `12_circuit_discovery_and_lenses` | Attribution Patching, Edge Attribution Patching, EAP-driven `find_circuit`, tuned lens, max-activating examples |
545
567
 
546
568
  ---
547
569
 
@@ -16,6 +16,16 @@ from interpkit.core.exceptions import (
16
16
  OperationNotSupportedForArchitecture,
17
17
  WrongInputType,
18
18
  )
19
+ from interpkit.core.interventions import (
20
+ AblateIntervention,
21
+ CaptureProbe,
22
+ FnIntervention,
23
+ GenerationContext,
24
+ Intervention,
25
+ PatchIntervention,
26
+ SteerIntervention,
27
+ apply_interventions,
28
+ )
19
29
  from interpkit.core.loader import load, load_module
20
30
  from interpkit.core.model import Model
21
31
  from interpkit.core.registry import register
@@ -54,6 +64,15 @@ __all__ = [
54
64
  "LensPipelineMismatch",
55
65
  "OperationNotSupportedForArchitecture",
56
66
  "WrongInputType",
67
+ # Interventions
68
+ "Intervention",
69
+ "SteerIntervention",
70
+ "AblateIntervention",
71
+ "PatchIntervention",
72
+ "FnIntervention",
73
+ "CaptureProbe",
74
+ "GenerationContext",
75
+ "apply_interventions",
57
76
  # Operations
58
77
  "register",
59
78
  "diff",
@@ -236,6 +236,28 @@ def _show_extensive_help() -> None:
236
236
  padding=(0, 2),
237
237
  ))
238
238
 
239
+ console.print()
240
+ console.print(Panel(
241
+ f"[bold {ACCENT}]generate[/bold {ACCENT}] "
242
+ "[dim]interpkit generate gpt2 'I feel' --positive ' joy' --negative ' fear' --at transformer.h.6 --scale 8[/dim]\n\n"
243
+ "Generate text with interventions active across [italic]every[/italic] decode step —"
244
+ " the generation-time counterpart of [bold]steer[/bold] / [bold]ablate[/bold], which"
245
+ " analyse a single forward pass. A steering vector or ablation stays hooked for the"
246
+ " prefill and all KV-cached decode steps, so you can watch a nudged model write.\n\n"
247
+ " Add [bold green]--capture lens[/bold green] to record each generated token's"
248
+ " logit-lens trajectory: which layer first predicted the token the model ended up"
249
+ " emitting.\n\n"
250
+ " [bold]Key options:[/bold]\n"
251
+ " [bold green]--positive / --negative + --at[/bold green] Build a steering vector and apply it while generating.\n"
252
+ " [bold green]--ablate-at / --ablate-method[/bold green] Knock out a module for the whole generation.\n"
253
+ " [bold green]--capture lens|logits[/bold green] Per-token lens trajectory or raw step logits.\n"
254
+ " [bold green]--max-new-tokens N[/bold green] Generation budget (default 64).\n"
255
+ " [bold green]--sample / --temperature / --top-p[/bold green] Standard sampling controls.",
256
+ title="generate",
257
+ border_style=ACCENT_DIM,
258
+ padding=(0, 2),
259
+ ))
260
+
239
261
  # ── Core Operations ───────────────────────────────────────────
240
262
  console.print()
241
263
  console.print(Rule("[bold]Core Operations[/bold]", style=ACCENT))
@@ -285,9 +307,17 @@ def _show_extensive_help() -> None:
285
307
  "interpkit lens gpt2 'The capital of France is'",
286
308
  "Logit lens. After every transformer layer, the hidden state is projected directly into"
287
309
  " vocabulary space so you can see what the model 'thinks' it's predicting at each depth."
288
- " Lets you watch a vague representation sharpen into the final answer layer by layer.",
310
+ " Lets you watch a vague representation sharpen into the final answer layer by layer.\n\n"
311
+ " The raw projection is biased for early layers (their basis isn't aligned with the"
312
+ " unembedding). Train per-layer translators once with"
313
+ f" [bold {ACCENT}]train-tuned-lens[/bold {ACCENT}] and pass"
314
+ " [bold green]--tuned-lens <path>[/bold green] for the unbiased tuned-lens readout"
315
+ " (Belrose et al. 2023).\n"
316
+ " [dim]interpkit train-tuned-lens gpt2 --corpus-file texts.txt --save lens_dir/[/dim]\n"
317
+ " [dim]interpkit lens gpt2 'The capital of France is' --tuned-lens lens_dir/[/dim]",
289
318
  [
290
319
  ("--position N", "Analyse a single token position instead of all positions."),
320
+ ("--tuned-lens PATH", "Apply saved tuned-lens translators instead of the raw projection."),
291
321
  ],
292
322
  ),
293
323
  (
@@ -461,7 +491,9 @@ def _show_extensive_help() -> None:
461
491
  " [bold green]--clean / --corrupted[/bold green] Single clean and corrupted input texts.\n"
462
492
  " [bold green]--clean-file / --corrupted-file[/bold green] Text files with one example per line (paired by line number).\n"
463
493
  " [bold green]--threshold[/bold green] Minimum ablation effect to include (default 0.01).\n"
464
- " [bold green]--method[/bold green] Ablation method: mean (default), zero, resample.\n"
494
+ " [bold green]--method[/bold green] mean (default) · zero · resample (ablation), or eap · eap-ig"
495
+ " (gradient-based selection in a handful of passes — much faster; the circuit is still"
496
+ " verified causally).\n"
465
497
  " [bold green]--metric[/bold green] logit_diff · kl_div · target_prob · l2_prob",
466
498
  title="find-circuit",
467
499
  border_style=ACCENT_DIM,
@@ -469,6 +501,67 @@ def _show_extensive_help() -> None:
469
501
  ))
470
502
  console.print()
471
503
 
504
+ console.print(Panel(
505
+ f"[bold {ACCENT}]atp[/bold {ACCENT}] "
506
+ "[dim]interpkit atp gpt2 --clean 'The capital of France is' --corrupted 'The capital of Germany is'[/dim]\n\n"
507
+ "Attribution Patching (Syed et al. 2023). A first-order gradient approximation of"
508
+ " activation patching: one clean forward, one corrupted forward, and one backward pass"
509
+ " score [italic]every[/italic] module simultaneously — versus one forward per module for"
510
+ " exhaustive tracing. Correlation with true patch effects is typically 0.85–0.95."
511
+ " Use it as the fast first look, then confirm top candidates with"
512
+ f" [bold {ACCENT}]trace[/bold {ACCENT}] or [bold {ACCENT}]patch[/bold {ACCENT}].\n\n"
513
+ " [bold]Key options:[/bold]\n"
514
+ " [bold green]--clean / --corrupted[/bold green] The contrast pair to attribute.\n"
515
+ " [bold green]--top-k[/bold green] Top modules to report by absolute score (0 = all).",
516
+ title="atp",
517
+ border_style=ACCENT_DIM,
518
+ padding=(0, 2),
519
+ ))
520
+ console.print()
521
+
522
+ console.print(Panel(
523
+ f"[bold {ACCENT}]eap[/bold {ACCENT}] "
524
+ "[dim]interpkit eap gpt2 --clean 'The capital of France is' --corrupted 'The capital of Germany is'[/dim]\n\n"
525
+ "Edge Attribution Patching. Where [bold]atp[/bold] scores modules, eap scores"
526
+ " [italic]edges[/italic]: how much each component's clean-vs-corrupted delta matters as it"
527
+ " flows into each downstream residual-stream layer. The edge at a component's own layer is"
528
+ " its total effect; deeper edges show how the effect persists down the stream. Inputs must"
529
+ " tokenize to the same length.\n\n"
530
+ " [bold green]--ig-steps 5[/bold green] switches to EAP-IG: gradients averaged over"
531
+ " embeddings interpolated from corrupted toward clean — more faithful scores when the"
532
+ " corrupted point sits in a saturated region.\n\n"
533
+ " [bold]Key options:[/bold]\n"
534
+ " [bold green]--clean / --corrupted[/bold green] Token-aligned contrast pair.\n"
535
+ " [bold green]--ig-steps[/bold green] EAP-IG interpolation steps (0 = plain EAP).\n"
536
+ " [bold green]--top-k-edges[/bold green] Top edges to report by absolute score (0 = all).",
537
+ title="eap",
538
+ border_style=ACCENT_DIM,
539
+ padding=(0, 2),
540
+ ))
541
+ console.print()
542
+
543
+ console.print(Panel(
544
+ f"[bold {ACCENT}]maxact[/bold {ACCENT}] "
545
+ "[dim]interpkit maxact gpt2 --at transformer.h.6.mlp --neuron 42 --texts-file corpus.txt[/dim]\n\n"
546
+ "Max-activating examples — the feature-browsing workflow: scan a corpus and show the"
547
+ " contexts where one unit fires hardest, with the peak token highlighted. Works for raw"
548
+ " neurons ([bold green]--neuron[/bold green]), SAE features ([bold green]--feature[/bold green]"
549
+ " + [bold green]--sae[/bold green]), and attention heads ([bold green]--head[/bold green])."
550
+ " Streams batched forwards and keeps only the top-k scored contexts, so memory stays flat"
551
+ " however large the corpus.\n\n"
552
+ " HF datasets work too (requires [dim]pip install 'interpkit[data]'[/dim]):\n"
553
+ " [dim]interpkit maxact gpt2 --at transformer.h.6.mlp --neuron 42 --dataset hf:imdb --max-examples 256[/dim]\n\n"
554
+ " [bold]Key options:[/bold]\n"
555
+ " [bold green]--texts-file / --dataset[/bold green] Corpus: one-per-line file, or hf:name[:split[:column]].\n"
556
+ " [bold green]--neuron / --feature / --head[/bold green] Which unit to scan (exactly one).\n"
557
+ " [bold green]--sae[/bold green] SAE repo ID or path (with --feature).\n"
558
+ " [bold green]--top-k / --max-examples[/bold green] How many results / how much corpus.",
559
+ title="maxact",
560
+ border_style=ACCENT_DIM,
561
+ padding=(0, 2),
562
+ ))
563
+ console.print()
564
+
472
565
  console.print(Panel(
473
566
  f"[bold {ACCENT}]features[/bold {ACCENT}] "
474
567
  "[dim]interpkit features gpt2 '...' --at transformer.h.8 --sae jbloom/GPT2-Small-SAEs[/dim]\n\n"
@@ -552,13 +645,15 @@ def main(
552
645
  ("scan", "One-command overview \u2014 DLA, lens, attention, attribution"),
553
646
  ("report", "Generate an interactive HTML report"),
554
647
  ("chat", "Send a message to a chat / instruct model"),
648
+ ("generate", "Generate with steering/ablation active + per-token lens"),
555
649
  ])
556
650
 
557
651
  core_ops = _cmd_table([
558
652
  ("inspect", "Module tree with types, params, roles"),
559
653
  ("dla", "Direct Logit Attribution \u2014 decompose logit by component"),
560
654
  ("trace", "Causal tracing \u2014 module or position-aware"),
561
- ("lens", "Logit lens \u2014 project layers to vocab"),
655
+ ("lens", "Logit lens \u2014 project layers to vocab (--tuned-lens for tuned)"),
656
+ ("train-tuned-lens", "Train per-layer tuned-lens translators"),
562
657
  ("attribute", "Gradient saliency over inputs"),
563
658
  ("patch", "Activation patching at module/head/position"),
564
659
  ])
@@ -574,8 +669,11 @@ def main(
574
669
  ])
575
670
 
576
671
  circuit_ops = _cmd_table([
577
- ("find-circuit", "Automated circuit discovery"),
672
+ ("find-circuit", "Automated circuit discovery (ablation or EAP)"),
673
+ ("atp", "Attribution Patching — score all modules in 3 passes"),
674
+ ("eap", "Edge Attribution Patching — gradient-based edge scores"),
578
675
  ("features", "SAE feature decomposition (single or contrastive)"),
676
+ ("maxact", "Max-activating examples for a neuron / SAE feature / head"),
579
677
  ])
580
678
 
581
679
  layout = Table(show_header=False, box=None, pad_edge=False, padding=0, expand=True)
@@ -726,18 +824,65 @@ def lens(
726
824
  save: str | None = typer.Option(None, "--save", help="Save heatmap to file (e.g. lens.png)"),
727
825
  html_path: str | None = typer.Option(None, "--html", help="Save interactive HTML to file"),
728
826
  position: int | None = typer.Option(None, "--position", help="Single token position to analyse (-1 = last). Omit for all positions."),
827
+ tuned_lens_path: str | None = typer.Option(None, "--tuned-lens", help="Path to a saved tuned lens (switches to kind='tuned'). Train with `interpkit train-tuned-lens`."),
729
828
  device: str | None = typer.Option(None, help="Device"),
730
829
  dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
731
830
  device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
732
831
  ) -> None:
733
- """Logit lens: project each layer's hidden state to vocabulary space."""
832
+ """Logit lens: project each layer's hidden state to vocabulary space.
833
+
834
+ Pass --tuned-lens <path> to apply trained per-layer translators
835
+ (Belrose et al. 2023) for an unbiased early-layer readout.
836
+ """
734
837
  m = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)
838
+ kind = "tuned" if tuned_lens_path is not None else "logit"
735
839
  with console.status(" Running logit lens..."):
736
- result = m.lens(text, save=save, html=html_path, position=position)
840
+ result = m.lens(
841
+ text, save=save, html=html_path, position=position,
842
+ kind=kind, tuned_lens=tuned_lens_path,
843
+ )
737
844
  if _output_format == "json":
738
845
  _json_dump(result if isinstance(result, dict) else {"results": result})
739
846
 
740
847
 
848
+ # ══════════════════════════════════════════════════════════════════
849
+ # train-tuned-lens
850
+ # ══════════════════════════════════════════════════════════════════
851
+
852
+
853
+ @app.command("train-tuned-lens")
854
+ def train_tuned_lens_cmd(
855
+ model_name: str = typer.Argument(..., help="HuggingFace model ID"),
856
+ corpus_file: str = typer.Option(..., "--corpus-file", help="Text file with training sentences, one per line"),
857
+ steps: int = typer.Option(200, "--steps", help="Training steps"),
858
+ batch_size: int = typer.Option(4, "--batch-size", help="Batch size"),
859
+ lr: float = typer.Option(1e-3, "--lr", help="Adam learning rate"),
860
+ max_length: int = typer.Option(64, "--max-length", help="Token truncation length"),
861
+ seed: int = typer.Option(0, "--seed", help="Random seed (deterministic on CPU)"),
862
+ save: str | None = typer.Option(None, "--save", help="Output directory or .safetensors path (default: ~/.cache/interpkit/tuned_lens/<model>/)"),
863
+ device: str | None = typer.Option(None, help="Device"),
864
+ dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
865
+ device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
866
+ ) -> None:
867
+ """Train tuned-lens translators (Belrose et al. 2023) for a model.
868
+
869
+ The model stays frozen; only per-layer affine translators train.
870
+ Use the result with `interpkit lens ... --tuned-lens <path>`.
871
+ """
872
+ from interpkit.core.inputs import read_examples_file
873
+ from interpkit.ops.tuned_lens import default_tuned_lens_dir
874
+
875
+ corpus = read_examples_file(corpus_file)
876
+ m = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)
877
+ out = save if save is not None else str(default_tuned_lens_dir(model_name))
878
+ lens_obj = m.train_tuned_lens(
879
+ corpus, steps=steps, batch_size=batch_size, lr=lr,
880
+ max_length=max_length, seed=seed, save=out,
881
+ )
882
+ if _output_format == "json":
883
+ _json_dump({"saved_to": out, "meta": lens_obj.meta})
884
+
885
+
741
886
  # ══════════════════════════════════════════════════════════════════
742
887
  # attribute
743
888
  # ══════════════════════════════════════════════════════════════════
@@ -1080,7 +1225,7 @@ def find_circuit(
1080
1225
  clean_file: str | None = typer.Option(None, "--clean-file", help="Text file with clean examples, one per line"),
1081
1226
  corrupted_file: str | None = typer.Option(None, "--corrupted-file", help="Text file with corrupted examples, one per line (must match --clean-file line count)"),
1082
1227
  threshold: float = typer.Option(0.01, "--threshold", help="Minimum ablation effect to include in circuit (0-1)"),
1083
- method: str = typer.Option("mean", "--method", help="Ablation method: mean (default), zero, resample"),
1228
+ method: str = typer.Option("mean", "--method", help="Selection method: mean (default), zero, resample (ablation), or eap / eap-ig (gradient-based, much faster)"),
1084
1229
  metric: str = typer.Option("logit_diff", "--metric", help="Effect metric: logit_diff, kl_div, target_prob, l2_prob"),
1085
1230
  device: str | None = typer.Option(None, help="Device"),
1086
1231
  dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
@@ -1190,6 +1335,196 @@ def chat(
1190
1335
  _json_dump({k: v for k, v in result.items() if k not in {"input_ids", "output_ids"}})
1191
1336
 
1192
1337
 
1338
+ # ══════════════════════════════════════════════════════════════════
1339
+ # atp / eap
1340
+ # ══════════════════════════════════════════════════════════════════
1341
+
1342
+
1343
@app.command()
def atp(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    clean: str = typer.Option(..., "--clean", help="Clean input"),
    corrupted: str = typer.Option(..., "--corrupted", help="Corrupted input"),
    top_k: int = typer.Option(20, "--top-k", help="Top modules to report by absolute score. 0 = all."),
    device: str | None = typer.Option(None, help="Device"),
    dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
    device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
) -> None:
    """Attribution Patching: first-order patch-effect scores for every module.

    Three model passes score all modules at once — the fast first look
    before committing to `trace`'s per-module full patching.
    """
    # NOTE: the docstring doubles as this command's --help text (Typer), so
    # it is user-facing output and is kept verbatim.
    model = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)

    # A non-positive --top-k is forwarded as None; the option's help text
    # documents 0 as "all".
    limit: int | None = None
    if top_k > 0:
        limit = top_k

    with console.status(" Computing attribution patching scores..."):
        scores = model.atp(clean, corrupted, top_k=limit)

    # Only the JSON output mode is handled here; nothing else is emitted by
    # this function for other formats.
    if _output_format != "json":
        return
    _json_dump(scores)
1364
+
1365
+
1366
@app.command()
def eap(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    clean: str = typer.Option(..., "--clean", help="Clean input (must tokenize to same length as --corrupted)"),
    corrupted: str = typer.Option(..., "--corrupted", help="Corrupted input"),
    ig_steps: int = typer.Option(0, "--ig-steps", help="EAP-IG interpolation steps (0 = plain EAP; try 5)"),
    top_k_edges: int = typer.Option(30, "--top-k-edges", help="Top edges to report by absolute score. 0 = all."),
    device: str | None = typer.Option(None, help="Device"),
    dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
    device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
) -> None:
    """Edge Attribution Patching: gradient-based edge scores for circuit discovery.

    Scores every (component → residual stream) edge from a handful of
    passes. Pair with `find-circuit --method eap` for a causally
    verified circuit.
    """
    # NOTE: the docstring doubles as this command's --help text (Typer), so
    # it is user-facing output and is kept verbatim.
    model = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)

    # A non-positive --top-k-edges is forwarded as None; the option's help
    # text documents 0 as "all".
    edge_limit: int | None = None
    if top_k_edges > 0:
        edge_limit = top_k_edges

    with console.status(" Computing edge attribution scores..."):
        scores = model.eap(
            clean,
            corrupted,
            ig_steps=ig_steps,
            top_k_edges=edge_limit,
        )

    # Only the JSON output mode is handled here; nothing else is emitted by
    # this function for other formats.
    if _output_format != "json":
        return
    _json_dump(scores)
1389
+
1390
+
1391
+ # ══════════════════════════════════════════════════════════════════
1392
+ # maxact
1393
+ # ══════════════════════════════════════════════════════════════════
1394
+
1395
+
1396
@app.command()
def maxact(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    at: str = typer.Option(..., "--at", help="Module whose activations to scan (e.g. transformer.h.6.mlp)"),
    texts_file: str | None = typer.Option(None, "--texts-file", help="Text file with one example per line"),
    dataset: str | None = typer.Option(None, "--dataset", help="HF dataset spec: hf:name[:split[:column]] (needs interpkit[data] + --max-examples)"),
    neuron: int | None = typer.Option(None, "--neuron", help="Neuron index at the module (raw activation score)"),
    feature: int | None = typer.Option(None, "--feature", help="SAE feature index (requires --sae)"),
    head: int | None = typer.Option(None, "--head", help="Attention head index (pre-projection output norm)"),
    sae: str | None = typer.Option(None, "--sae", help="SAE repo ID or local path (with --feature)"),
    top_k: int = typer.Option(20, "--top-k", help="Top examples to keep"),
    batch_size: int = typer.Option(8, "--batch-size", help="Forward batch size"),
    max_examples: int | None = typer.Option(None, "--max-examples", help="Cap on dataset examples scanned"),
    max_length: int = typer.Option(128, "--max-length", help="Token truncation length"),
    device: str | None = typer.Option(None, help="Device"),
    dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
    device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
) -> None:
    """Find the dataset examples that most activate a neuron / SAE feature / head.

    The feature-browsing workflow: "what does this unit fire on?".
    Streams batched forwards and keeps only the top-k scored contexts.
    """
    # NOTE: the docstring doubles as this command's --help text (Typer), so
    # it is user-facing output and is kept verbatim.
    from interpkit.core.inputs import read_examples_file

    # Exactly one input source is allowed: both-missing and both-present are
    # rejected by the same equality test.
    has_texts_file = texts_file is not None
    has_dataset = dataset is not None
    if has_texts_file == has_dataset:
        raise typer.BadParameter("Provide exactly one of --texts-file or --dataset.")

    # A texts file is read into a list of lines here; a dataset spec is
    # passed through as the raw string.
    source: list[str] | str
    if texts_file is not None:
        source = read_examples_file(texts_file)
    else:
        source = dataset  # type: ignore[assignment]

    model = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)
    hits = model.max_activating(
        source,
        at=at,
        neuron=neuron,
        feature=feature,
        head=head,
        sae=sae,
        top_k=top_k,
        batch_size=batch_size,
        max_examples=max_examples,
        max_length=max_length,
    )

    # Only the JSON output mode is handled here; nothing else is emitted by
    # this function for other formats.
    if _output_format != "json":
        return
    _json_dump(hits)
1436
+
1437
+
1438
+ # ══════════════════════════════════════════════════════════════════
1439
+ # generate
1440
+ # ══════════════════════════════════════════════════════════════════
1441
+
1442
+
1443
@app.command()
def generate(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    prompt: str = typer.Argument(..., help="Prompt text to generate from"),
    max_new_tokens: int = typer.Option(64, "--max-new-tokens", help="Max generation length"),
    positive: str | None = typer.Option(None, "--positive", help="Positive steering text (single example)"),
    negative: str | None = typer.Option(None, "--negative", help="Negative steering text (single example)"),
    positive_file: str | None = typer.Option(None, "--positive-file", help="Text file with positive examples, one per line"),
    negative_file: str | None = typer.Option(None, "--negative-file", help="Text file with negative examples, one per line"),
    at: str | None = typer.Option(None, "--at", help="Module to apply the steering vector at (required with --positive/--negative)"),
    scale: float = typer.Option(2.0, "--scale", help="Steering vector scale factor"),
    ablate_at: str | None = typer.Option(None, "--ablate-at", help="Module to ablate during generation"),
    ablate_method: str = typer.Option("zero", "--ablate-method", help="Ablation method: zero, mean"),
    capture: str | None = typer.Option(None, "--capture", help="Per-token capture: 'lens' (logit-lens trajectory) or 'logits'"),
    sample: bool = typer.Option(False, "--sample/--no-sample", help="Sample (True) or use greedy decoding (False, default)"),
    temperature: float = typer.Option(1.0, "--temperature", help="Sampling temperature (used when --sample)"),
    top_p: float = typer.Option(1.0, "--top-p", help="Nucleus sampling cutoff (used when --sample)"),
    device: str | None = typer.Option(None, help="Device"),
    dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
    device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
) -> None:
    """Generate text with interventions active across every decode step.

    Steering (``--positive`` / ``--negative`` + ``--at``) and ablation
    (``--ablate-at``) stay hooked for the prefill and all KV-cached decode
    steps — the generation-time counterpart of ``steer`` / ``ablate``.
    ``--capture lens`` additionally records each generated token's
    logit-lens trajectory through every block.
    """
    # NOTE: the docstring doubles as this command's --help text (Typer), so
    # it is user-facing output and is kept verbatim.
    from interpkit.core.inputs import read_examples_file
    from interpkit.core.interventions import AblateIntervention, SteerIntervention

    # Resolve and validate every steering input BEFORE loading the model, so
    # a missing flag or unreadable examples file fails immediately instead of
    # after a full model load. Checks and messages run in the same order as
    # before: --at, then positive, then negative.
    steering: tuple[str, str | list[str], str | list[str]] | None = None
    if any([positive, negative, positive_file, negative_file]):
        if at is None:
            raise typer.BadParameter("Steering requires --at (module to apply the vector at).")

        pos_inputs: str | list[str]
        if positive_file:
            # The file takes precedence over the inline text when both are given.
            pos_inputs = read_examples_file(positive_file)
        elif positive:
            pos_inputs = positive
        else:
            raise typer.BadParameter("Provide --positive or --positive-file")

        neg_inputs: str | list[str]
        if negative_file:
            neg_inputs = read_examples_file(negative_file)
        elif negative:
            neg_inputs = negative
        else:
            raise typer.BadParameter("Provide --negative or --negative-file")

        # Carrying `at` inside the tuple keeps it typed as `str` without an
        # `assert`, which would be stripped under `python -O`.
        steering = (at, pos_inputs, neg_inputs)

    m = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)

    # Steering (if requested) is appended first, then ablation.
    interventions: list = []
    if steering is not None:
        steer_at, pos_examples, neg_examples = steering
        vector = m.steer_vector(pos_examples, neg_examples, at=steer_at)
        interventions.append(SteerIntervention(steer_at, vector=vector, scale=scale))

    if ablate_at is not None:
        interventions.append(AblateIntervention(ablate_at, method=ablate_method))

    with console.status(" Generating..."):
        result = m.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            # An empty intervention list is passed as None.
            interventions=interventions or None,
            capture=capture,
            do_sample=sample,
            temperature=temperature,
            top_p=top_p,
        )

    if _output_format == "json":
        # Trim tensors: ids ride along in the Python API but bloat JSON, and
        # per-step logits are (1, vocab) each.
        out = {k: v for k, v in result.items() if k not in {"input_ids", "output_ids"}}
        if "steps" in out:
            out["steps"] = [
                {k: v for k, v in step.items() if k != "logits"}
                for step in out["steps"]
            ]
        _json_dump(out)
1526
+
1527
+
1193
1528
  def run() -> None:
1194
1529
  """CLI entry point that renders interpkit's intentional errors cleanly.
1195
1530
 
@@ -51,6 +51,20 @@ VALID_FIND_CIRCUIT_METHODS = frozenset({
51
51
  "zero",
52
52
  "mean",
53
53
  "resample",
54
+ # Gradient-based selection (phase 2): EAP-ranked components, verified
55
+ # causally with mean ablation.
56
+ "eap",
57
+ "eap-ig",
58
+ })
59
+
60
# Metrics accepted by the EAP / AtP ops. Only logit_diff has an EAP/AtP
# gradient formulation today, so it is the sole member.
VALID_EAP_METRICS = frozenset(("logit_diff",))
64
+
65
# Lens variants. "tuned" refers to the Belrose et al. 2023 trained
# per-block translators (phase 3).
VALID_LENS_KINDS = frozenset(("logit", "tuned"))
55
69
 
56
70
  VALID_IG_METHODS = frozenset({
@@ -99,6 +113,8 @@ __all__ = [
99
113
  "VALID_TRACE_METHODS",
100
114
  "VALID_ABLATE_METHODS",
101
115
  "VALID_FIND_CIRCUIT_METHODS",
116
+ "VALID_EAP_METRICS",
117
+ "VALID_LENS_KINDS",
102
118
  "VALID_IG_METHODS",
103
119
  "VALID_IG_BASELINES",
104
120
  "_validate_enum",
@@ -48,7 +48,10 @@ def warn_if_leading_space_better(
48
48
  This helper detects single-token leading-space variants and surfaces
49
49
  a yellow tip; it is a no-op for tensor / list inputs, missing
50
50
  tokenizers, empty strings, or strings that already begin with
51
- whitespace.
51
+ whitespace. When the plain input splits into *multiple* subword tokens
52
+ (e.g. ``"Hate"`` -> ``['H', 'ate']``) the warning is escalated to a red
53
+ message, because averaging activations over the fragments produces a
54
+ direction that does not represent the word at all.
52
55
 
53
56
  Parameters
54
57
  ----------
@@ -104,13 +107,28 @@ def warn_if_leading_space_better(
104
107
 
105
108
  console = _Console()
106
109
 
107
- console.print(
108
- f" [yellow]{op_label}:[/yellow] {role} input {text!r} tokenizes to "
109
- f"{len(ids_plain)} token(s) {ids_plain}, but "
110
- f"{(' ' + text)!r} is a single token {ids_spaced}. "
111
- f"Consider using {(' ' + text)!r} for a stronger contrast "
112
- f"(BPE leading-space convention)."
113
- )
110
+ spaced = " " + text
111
+ if len(ids_plain) > 1:
112
+ # Egregious case: a single-word steering/contrast term that splits
113
+ # into multiple subword tokens (e.g. "Hate" -> ['H', 'ate']). The op
114
+ # averages activations across those fragments, so the resulting
115
+ # direction does not represent the word at all — almost always a
116
+ # mistake. Escalate from a tip to a red warning so it can't be missed.
117
+ console.print(
118
+ f" [bold red]{op_label}:[/bold red] {role} input {text!r} splits into "
119
+ f"{len(ids_plain)} subword tokens {ids_plain} — {op_label} will average "
120
+ f"meaningless fragments, so the result will not reflect {text!r}. "
121
+ f"{spaced!r} is a single token {ids_spaced}; use it instead "
122
+ f"(BPE leading-space convention)."
123
+ )
124
+ else:
125
+ console.print(
126
+ f" [yellow]{op_label}:[/yellow] {role} input {text!r} tokenizes to "
127
+ f"{len(ids_plain)} token(s) {ids_plain}, but "
128
+ f"{spaced!r} is a single token {ids_spaced}. "
129
+ f"Consider using {spaced!r} for a stronger contrast "
130
+ f"(BPE leading-space convention)."
131
+ )
114
132
  warned_count[0] += 1
115
133
 
116
134