interpkit 0.5.0__tar.gz → 0.6.0__tar.gz
This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- {interpkit-0.5.0 → interpkit-0.6.0}/PKG-INFO +26 -2
- {interpkit-0.5.0 → interpkit-0.6.0}/README.md +23 -1
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/__init__.py +19 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/cli/main.py +342 -7
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/enums.py +16 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/inputs.py +26 -8
- interpkit-0.6.0/interpkit/core/interventions.py +492 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/model.py +223 -3
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/paths.py +18 -1
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/render.py +176 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/support_matrix.py +8 -0
- interpkit-0.6.0/interpkit/core/topk.py +63 -0
- interpkit-0.6.0/interpkit/ops/_atp.py +13 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/_hooks.py +40 -1
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/ablate.py +9 -39
- interpkit-0.6.0/interpkit/ops/atp.py +230 -0
- interpkit-0.6.0/interpkit/ops/eap.py +355 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/find_circuit.py +130 -65
- interpkit-0.6.0/interpkit/ops/generate.py +292 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/lens.py +50 -5
- interpkit-0.6.0/interpkit/ops/maxact.py +347 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/patch.py +32 -56
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/steer.py +5 -24
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/trace.py +10 -56
- interpkit-0.6.0/interpkit/ops/tuned_lens.py +437 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/PKG-INFO +26 -2
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/SOURCES.txt +14 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/requires.txt +3 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/pyproject.toml +4 -1
- interpkit-0.6.0/tests/test_atp.py +68 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_cli.py +133 -0
- interpkit-0.6.0/tests/test_eap.py +138 -0
- interpkit-0.6.0/tests/test_generate.py +186 -0
- interpkit-0.6.0/tests/test_interventions.py +241 -0
- interpkit-0.6.0/tests/test_maxact.py +149 -0
- interpkit-0.6.0/tests/test_topk.py +58 -0
- interpkit-0.6.0/tests/test_tuned_lens.py +140 -0
- interpkit-0.5.0/interpkit/ops/_atp.py +0 -182
- {interpkit-0.5.0 → interpkit-0.6.0}/LICENSE +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/__main__.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/cli/__init__.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/__init__.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/__init__.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/blocks.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/family.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/heads.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/layers.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/names.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/probe.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/residual.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/resolve.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/tree.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/arch/types.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/cache.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/exceptions.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/html.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/loader.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/plot.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/registry.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/theme.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/core/tl_compat.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/__init__.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/activations.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/attention.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/attribute.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/batch.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/circuits.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/diff.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/dla.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/heads.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/inspect.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/probe.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/report.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/sae.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit/ops/scan.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/dependency_links.txt +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/entry_points.txt +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/interpkit.egg-info/top_level.txt +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/setup.cfg +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_ablate.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_activations.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_archinfo_serialization.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_architectures.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_attention.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_attribute.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_audit_regressions.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_cache.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_cache_invalidation.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_capabilities.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_chat.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_diff.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_discovery.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_discovery_units.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_error_handling.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_html.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_inputs.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_inspect.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_invariants.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_lens.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_load_params.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_multi_arch.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_ops.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_param_variants.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_patch.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_phase3_regressions.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_plot_internals.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_plots.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_probe.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_registry.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_regressions.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_render_internals.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_resolver.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_resolver_golden.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_robustness_audit.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_sae.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_seq2seq_contract.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_steer.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_tl_compat.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_tl_ops.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_trace.py +0 -0
- {interpkit-0.5.0 → interpkit-0.6.0}/tests/test_validation.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: interpkit
|
|
3
|
-
Version: 0.5.0
|
|
3
|
+
Version: 0.6.0
|
|
4
4
|
Summary: Mech interp for any HuggingFace model.
|
|
5
5
|
Author: Davide Zani
|
|
6
6
|
License-Expression: MIT
|
|
@@ -34,6 +34,8 @@ Provides-Extra: vision
|
|
|
34
34
|
Requires-Dist: torchvision>=0.16; extra == "vision"
|
|
35
35
|
Provides-Extra: probe
|
|
36
36
|
Requires-Dist: scikit-learn>=1.3; extra == "probe"
|
|
37
|
+
Provides-Extra: data
|
|
38
|
+
Requires-Dist: datasets>=2.14; extra == "data"
|
|
37
39
|
Provides-Extra: dev
|
|
38
40
|
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
39
41
|
Requires-Dist: pytest-timeout>=2.2; extra == "dev"
|
|
@@ -186,7 +188,13 @@ See [examples/10_chat_models.ipynb](examples/10_chat_models.ipynb) for a full wa
|
|
|
186
188
|
| **`ov_scores`** | OV circuit analysis — W_OV matrix per head | Transformers |
|
|
187
189
|
| **`qk_scores`** | QK circuit analysis — W_QK matrix per head | Transformers |
|
|
188
190
|
| **`composition`** | Q/K/V composition scores between heads in two layers | Transformers |
|
|
189
|
-
| **`find_circuit`** | Automated circuit discovery
|
|
191
|
+
| **`find_circuit`** | Automated circuit discovery — iterative ablation or EAP-based selection with causal verification | Transformers |
|
|
192
|
+
| **`generate`** | Generation with interventions active across every decode step + per-token lens capture | Generative LMs |
|
|
193
|
+
| **`intervene`** | Context manager applying steer/ablate/patch interventions to any op | Any model |
|
|
194
|
+
| **`atp`** | Attribution Patching — first-order patch-effect scores for all modules in 3 passes | Any model |
|
|
195
|
+
| **`eap`** | Edge Attribution Patching — gradient-based component → residual-stream edge scores (EAP-IG via `ig_steps`) | Causal LMs |
|
|
196
|
+
| **`train_tuned_lens`** | Train per-layer tuned-lens translators (Belrose et al. 2023); use via `lens(kind="tuned")` | LMs |
|
|
197
|
+
| **`max_activating`** | Scan a corpus for the examples that most activate a neuron / SAE feature / head | Any model |
|
|
190
198
|
| **`batch`** | Run any operation over a dataset with result aggregation | Any model |
|
|
191
199
|
|
|
192
200
|
---
|
|
@@ -482,6 +490,20 @@ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jb
|
|
|
482
490
|
interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
|
|
483
491
|
interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
|
|
484
492
|
|
|
493
|
+
# Generation-time interventions + per-token lens trajectories
|
|
494
|
+
interpkit generate gpt2 "I feel" --positive " joy" --negative " fear" --at transformer.h.6 --scale 8
|
|
495
|
+
interpkit generate gpt2 "The capital of France is" --capture lens
|
|
496
|
+
|
|
497
|
+
# Gradient-based circuit discovery
|
|
498
|
+
interpkit atp gpt2 --clean "The capital of France is" --corrupted "The capital of Germany is"
|
|
499
|
+
interpkit eap gpt2 --clean "..." --corrupted "..." --ig-steps 5
|
|
500
|
+
interpkit find-circuit gpt2 --clean "..." --corrupted "..." --method eap --threshold 0.3
|
|
501
|
+
|
|
502
|
+
# Tuned lens + max-activating examples
|
|
503
|
+
interpkit train-tuned-lens gpt2 --corpus-file texts.txt --save lens_dir/
|
|
504
|
+
interpkit lens gpt2 "The capital of France is" --tuned-lens lens_dir/
|
|
505
|
+
interpkit maxact gpt2 --at transformer.h.6.mlp --neuron 42 --texts-file corpus.txt
|
|
506
|
+
|
|
485
507
|
# Chat / instruct models — applies the tokenizer's chat template automatically
|
|
486
508
|
interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "Write a haiku about cats." --max-new-tokens 64
|
|
487
509
|
interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "What is 2+2?" --system "You are terse." --show-prompt
|
|
@@ -592,6 +614,8 @@ See the [`examples/`](examples/) directory for Jupyter notebooks:
|
|
|
592
614
|
| `08_dla_and_circuits` | DLA, head activations, residual decomposition, OV/QK analysis, composition, circuit discovery |
|
|
593
615
|
| `09_scan_and_batch` | Auto-scan, batch operations, dataset workflows |
|
|
594
616
|
| `10_chat_models` | Chat-template handling, `model.chat()`, message-list inputs, chat-style steering |
|
|
617
|
+
| `11_generation_interventions` | Steering/ablation active across every decode step, per-token lens trajectories, positional interventions, `model.intervene()` |
|
|
618
|
+
| `12_circuit_discovery_and_lenses` | Attribution Patching, Edge Attribution Patching, EAP-driven `find_circuit`, tuned lens, max-activating examples |
|
|
595
619
|
|
|
596
620
|
---
|
|
597
621
|
|
|
@@ -136,7 +136,13 @@ See [examples/10_chat_models.ipynb](examples/10_chat_models.ipynb) for a full wa
|
|
|
136
136
|
| **`ov_scores`** | OV circuit analysis — W_OV matrix per head | Transformers |
|
|
137
137
|
| **`qk_scores`** | QK circuit analysis — W_QK matrix per head | Transformers |
|
|
138
138
|
| **`composition`** | Q/K/V composition scores between heads in two layers | Transformers |
|
|
139
|
-
| **`find_circuit`** | Automated circuit discovery
|
|
139
|
+
| **`find_circuit`** | Automated circuit discovery — iterative ablation or EAP-based selection with causal verification | Transformers |
|
|
140
|
+
| **`generate`** | Generation with interventions active across every decode step + per-token lens capture | Generative LMs |
|
|
141
|
+
| **`intervene`** | Context manager applying steer/ablate/patch interventions to any op | Any model |
|
|
142
|
+
| **`atp`** | Attribution Patching — first-order patch-effect scores for all modules in 3 passes | Any model |
|
|
143
|
+
| **`eap`** | Edge Attribution Patching — gradient-based component → residual-stream edge scores (EAP-IG via `ig_steps`) | Causal LMs |
|
|
144
|
+
| **`train_tuned_lens`** | Train per-layer tuned-lens translators (Belrose et al. 2023); use via `lens(kind="tuned")` | LMs |
|
|
145
|
+
| **`max_activating`** | Scan a corpus for the examples that most activate a neuron / SAE feature / head | Any model |
|
|
140
146
|
| **`batch`** | Run any operation over a dataset with result aggregation | Any model |
|
|
141
147
|
|
|
142
148
|
---
|
|
@@ -432,6 +438,20 @@ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jb
|
|
|
432
438
|
interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
|
|
433
439
|
interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
|
|
434
440
|
|
|
441
|
+
# Generation-time interventions + per-token lens trajectories
|
|
442
|
+
interpkit generate gpt2 "I feel" --positive " joy" --negative " fear" --at transformer.h.6 --scale 8
|
|
443
|
+
interpkit generate gpt2 "The capital of France is" --capture lens
|
|
444
|
+
|
|
445
|
+
# Gradient-based circuit discovery
|
|
446
|
+
interpkit atp gpt2 --clean "The capital of France is" --corrupted "The capital of Germany is"
|
|
447
|
+
interpkit eap gpt2 --clean "..." --corrupted "..." --ig-steps 5
|
|
448
|
+
interpkit find-circuit gpt2 --clean "..." --corrupted "..." --method eap --threshold 0.3
|
|
449
|
+
|
|
450
|
+
# Tuned lens + max-activating examples
|
|
451
|
+
interpkit train-tuned-lens gpt2 --corpus-file texts.txt --save lens_dir/
|
|
452
|
+
interpkit lens gpt2 "The capital of France is" --tuned-lens lens_dir/
|
|
453
|
+
interpkit maxact gpt2 --at transformer.h.6.mlp --neuron 42 --texts-file corpus.txt
|
|
454
|
+
|
|
435
455
|
# Chat / instruct models — applies the tokenizer's chat template automatically
|
|
436
456
|
interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "Write a haiku about cats." --max-new-tokens 64
|
|
437
457
|
interpkit chat HuggingFaceTB/SmolLM2-360M-Instruct "What is 2+2?" --system "You are terse." --show-prompt
|
|
@@ -542,6 +562,8 @@ See the [`examples/`](examples/) directory for Jupyter notebooks:
|
|
|
542
562
|
| `08_dla_and_circuits` | DLA, head activations, residual decomposition, OV/QK analysis, composition, circuit discovery |
|
|
543
563
|
| `09_scan_and_batch` | Auto-scan, batch operations, dataset workflows |
|
|
544
564
|
| `10_chat_models` | Chat-template handling, `model.chat()`, message-list inputs, chat-style steering |
|
|
565
|
+
| `11_generation_interventions` | Steering/ablation active across every decode step, per-token lens trajectories, positional interventions, `model.intervene()` |
|
|
566
|
+
| `12_circuit_discovery_and_lenses` | Attribution Patching, Edge Attribution Patching, EAP-driven `find_circuit`, tuned lens, max-activating examples |
|
|
545
567
|
|
|
546
568
|
---
|
|
547
569
|
|
|
@@ -16,6 +16,16 @@ from interpkit.core.exceptions import (
|
|
|
16
16
|
OperationNotSupportedForArchitecture,
|
|
17
17
|
WrongInputType,
|
|
18
18
|
)
|
|
19
|
+
from interpkit.core.interventions import (
|
|
20
|
+
AblateIntervention,
|
|
21
|
+
CaptureProbe,
|
|
22
|
+
FnIntervention,
|
|
23
|
+
GenerationContext,
|
|
24
|
+
Intervention,
|
|
25
|
+
PatchIntervention,
|
|
26
|
+
SteerIntervention,
|
|
27
|
+
apply_interventions,
|
|
28
|
+
)
|
|
19
29
|
from interpkit.core.loader import load, load_module
|
|
20
30
|
from interpkit.core.model import Model
|
|
21
31
|
from interpkit.core.registry import register
|
|
@@ -54,6 +64,15 @@ __all__ = [
|
|
|
54
64
|
"LensPipelineMismatch",
|
|
55
65
|
"OperationNotSupportedForArchitecture",
|
|
56
66
|
"WrongInputType",
|
|
67
|
+
# Interventions
|
|
68
|
+
"Intervention",
|
|
69
|
+
"SteerIntervention",
|
|
70
|
+
"AblateIntervention",
|
|
71
|
+
"PatchIntervention",
|
|
72
|
+
"FnIntervention",
|
|
73
|
+
"CaptureProbe",
|
|
74
|
+
"GenerationContext",
|
|
75
|
+
"apply_interventions",
|
|
57
76
|
# Operations
|
|
58
77
|
"register",
|
|
59
78
|
"diff",
|
|
@@ -236,6 +236,28 @@ def _show_extensive_help() -> None:
|
|
|
236
236
|
padding=(0, 2),
|
|
237
237
|
))
|
|
238
238
|
|
|
239
|
+
console.print()
|
|
240
|
+
console.print(Panel(
|
|
241
|
+
f"[bold {ACCENT}]generate[/bold {ACCENT}] "
|
|
242
|
+
"[dim]interpkit generate gpt2 'I feel' --positive ' joy' --negative ' fear' --at transformer.h.6 --scale 8[/dim]\n\n"
|
|
243
|
+
"Generate text with interventions active across [italic]every[/italic] decode step —"
|
|
244
|
+
" the generation-time counterpart of [bold]steer[/bold] / [bold]ablate[/bold], which"
|
|
245
|
+
" analyse a single forward pass. A steering vector or ablation stays hooked for the"
|
|
246
|
+
" prefill and all KV-cached decode steps, so you can watch a nudged model write.\n\n"
|
|
247
|
+
" Add [bold green]--capture lens[/bold green] to record each generated token's"
|
|
248
|
+
" logit-lens trajectory: which layer first predicted the token the model ended up"
|
|
249
|
+
" emitting.\n\n"
|
|
250
|
+
" [bold]Key options:[/bold]\n"
|
|
251
|
+
" [bold green]--positive / --negative + --at[/bold green] Build a steering vector and apply it while generating.\n"
|
|
252
|
+
" [bold green]--ablate-at / --ablate-method[/bold green] Knock out a module for the whole generation.\n"
|
|
253
|
+
" [bold green]--capture lens|logits[/bold green] Per-token lens trajectory or raw step logits.\n"
|
|
254
|
+
" [bold green]--max-new-tokens N[/bold green] Generation budget (default 64).\n"
|
|
255
|
+
" [bold green]--sample / --temperature / --top-p[/bold green] Standard sampling controls.",
|
|
256
|
+
title="generate",
|
|
257
|
+
border_style=ACCENT_DIM,
|
|
258
|
+
padding=(0, 2),
|
|
259
|
+
))
|
|
260
|
+
|
|
239
261
|
# ── Core Operations ───────────────────────────────────────────
|
|
240
262
|
console.print()
|
|
241
263
|
console.print(Rule("[bold]Core Operations[/bold]", style=ACCENT))
|
|
@@ -285,9 +307,17 @@ def _show_extensive_help() -> None:
|
|
|
285
307
|
"interpkit lens gpt2 'The capital of France is'",
|
|
286
308
|
"Logit lens. After every transformer layer, the hidden state is projected directly into"
|
|
287
309
|
" vocabulary space so you can see what the model 'thinks' it's predicting at each depth."
|
|
288
|
-
" Lets you watch a vague representation sharpen into the final answer layer by layer
|
|
310
|
+
" Lets you watch a vague representation sharpen into the final answer layer by layer.\n\n"
|
|
311
|
+
" The raw projection is biased for early layers (their basis isn't aligned with the"
|
|
312
|
+
" unembedding). Train per-layer translators once with"
|
|
313
|
+
f" [bold {ACCENT}]train-tuned-lens[/bold {ACCENT}] and pass"
|
|
314
|
+
" [bold green]--tuned-lens <path>[/bold green] for the unbiased tuned-lens readout"
|
|
315
|
+
" (Belrose et al. 2023).\n"
|
|
316
|
+
" [dim]interpkit train-tuned-lens gpt2 --corpus-file texts.txt --save lens_dir/[/dim]\n"
|
|
317
|
+
" [dim]interpkit lens gpt2 'The capital of France is' --tuned-lens lens_dir/[/dim]",
|
|
289
318
|
[
|
|
290
319
|
("--position N", "Analyse a single token position instead of all positions."),
|
|
320
|
+
("--tuned-lens PATH", "Apply saved tuned-lens translators instead of the raw projection."),
|
|
291
321
|
],
|
|
292
322
|
),
|
|
293
323
|
(
|
|
@@ -461,7 +491,9 @@ def _show_extensive_help() -> None:
|
|
|
461
491
|
" [bold green]--clean / --corrupted[/bold green] Single clean and corrupted input texts.\n"
|
|
462
492
|
" [bold green]--clean-file / --corrupted-file[/bold green] Text files with one example per line (paired by line number).\n"
|
|
463
493
|
" [bold green]--threshold[/bold green] Minimum ablation effect to include (default 0.01).\n"
|
|
464
|
-
" [bold green]--method[/bold green]
|
|
494
|
+
" [bold green]--method[/bold green] mean (default) · zero · resample (ablation), or eap · eap-ig"
|
|
495
|
+
" (gradient-based selection in a handful of passes — much faster; the circuit is still"
|
|
496
|
+
" verified causally).\n"
|
|
465
497
|
" [bold green]--metric[/bold green] logit_diff · kl_div · target_prob · l2_prob",
|
|
466
498
|
title="find-circuit",
|
|
467
499
|
border_style=ACCENT_DIM,
|
|
@@ -469,6 +501,67 @@ def _show_extensive_help() -> None:
|
|
|
469
501
|
))
|
|
470
502
|
console.print()
|
|
471
503
|
|
|
504
|
+
console.print(Panel(
|
|
505
|
+
f"[bold {ACCENT}]atp[/bold {ACCENT}] "
|
|
506
|
+
"[dim]interpkit atp gpt2 --clean 'The capital of France is' --corrupted 'The capital of Germany is'[/dim]\n\n"
|
|
507
|
+
"Attribution Patching (Syed et al. 2023). A first-order gradient approximation of"
|
|
508
|
+
" activation patching: one clean forward, one corrupted forward, and one backward pass"
|
|
509
|
+
" score [italic]every[/italic] module simultaneously — versus one forward per module for"
|
|
510
|
+
" exhaustive tracing. Correlation with true patch effects is typically 0.85–0.95."
|
|
511
|
+
" Use it as the fast first look, then confirm top candidates with"
|
|
512
|
+
f" [bold {ACCENT}]trace[/bold {ACCENT}] or [bold {ACCENT}]patch[/bold {ACCENT}].\n\n"
|
|
513
|
+
" [bold]Key options:[/bold]\n"
|
|
514
|
+
" [bold green]--clean / --corrupted[/bold green] The contrast pair to attribute.\n"
|
|
515
|
+
" [bold green]--top-k[/bold green] Top modules to report by absolute score (0 = all).",
|
|
516
|
+
title="atp",
|
|
517
|
+
border_style=ACCENT_DIM,
|
|
518
|
+
padding=(0, 2),
|
|
519
|
+
))
|
|
520
|
+
console.print()
|
|
521
|
+
|
|
522
|
+
console.print(Panel(
|
|
523
|
+
f"[bold {ACCENT}]eap[/bold {ACCENT}] "
|
|
524
|
+
"[dim]interpkit eap gpt2 --clean 'The capital of France is' --corrupted 'The capital of Germany is'[/dim]\n\n"
|
|
525
|
+
"Edge Attribution Patching. Where [bold]atp[/bold] scores modules, eap scores"
|
|
526
|
+
" [italic]edges[/italic]: how much each component's clean-vs-corrupted delta matters as it"
|
|
527
|
+
" flows into each downstream residual-stream layer. The edge at a component's own layer is"
|
|
528
|
+
" its total effect; deeper edges show how the effect persists down the stream. Inputs must"
|
|
529
|
+
" tokenize to the same length.\n\n"
|
|
530
|
+
" [bold green]--ig-steps 5[/bold green] switches to EAP-IG: gradients averaged over"
|
|
531
|
+
" embeddings interpolated from corrupted toward clean — more faithful scores when the"
|
|
532
|
+
" corrupted point sits in a saturated region.\n\n"
|
|
533
|
+
" [bold]Key options:[/bold]\n"
|
|
534
|
+
" [bold green]--clean / --corrupted[/bold green] Token-aligned contrast pair.\n"
|
|
535
|
+
" [bold green]--ig-steps[/bold green] EAP-IG interpolation steps (0 = plain EAP).\n"
|
|
536
|
+
" [bold green]--top-k-edges[/bold green] Top edges to report by absolute score (0 = all).",
|
|
537
|
+
title="eap",
|
|
538
|
+
border_style=ACCENT_DIM,
|
|
539
|
+
padding=(0, 2),
|
|
540
|
+
))
|
|
541
|
+
console.print()
|
|
542
|
+
|
|
543
|
+
console.print(Panel(
|
|
544
|
+
f"[bold {ACCENT}]maxact[/bold {ACCENT}] "
|
|
545
|
+
"[dim]interpkit maxact gpt2 --at transformer.h.6.mlp --neuron 42 --texts-file corpus.txt[/dim]\n\n"
|
|
546
|
+
"Max-activating examples — the feature-browsing workflow: scan a corpus and show the"
|
|
547
|
+
" contexts where one unit fires hardest, with the peak token highlighted. Works for raw"
|
|
548
|
+
" neurons ([bold green]--neuron[/bold green]), SAE features ([bold green]--feature[/bold green]"
|
|
549
|
+
" + [bold green]--sae[/bold green]), and attention heads ([bold green]--head[/bold green])."
|
|
550
|
+
" Streams batched forwards and keeps only the top-k scored contexts, so memory stays flat"
|
|
551
|
+
" however large the corpus.\n\n"
|
|
552
|
+
" HF datasets work too (requires [dim]pip install 'interpkit[data]'[/dim]):\n"
|
|
553
|
+
" [dim]interpkit maxact gpt2 --at transformer.h.6.mlp --neuron 42 --dataset hf:imdb --max-examples 256[/dim]\n\n"
|
|
554
|
+
" [bold]Key options:[/bold]\n"
|
|
555
|
+
" [bold green]--texts-file / --dataset[/bold green] Corpus: one-per-line file, or hf:name[:split[:column]].\n"
|
|
556
|
+
" [bold green]--neuron / --feature / --head[/bold green] Which unit to scan (exactly one).\n"
|
|
557
|
+
" [bold green]--sae[/bold green] SAE repo ID or path (with --feature).\n"
|
|
558
|
+
" [bold green]--top-k / --max-examples[/bold green] How many results / how much corpus.",
|
|
559
|
+
title="maxact",
|
|
560
|
+
border_style=ACCENT_DIM,
|
|
561
|
+
padding=(0, 2),
|
|
562
|
+
))
|
|
563
|
+
console.print()
|
|
564
|
+
|
|
472
565
|
console.print(Panel(
|
|
473
566
|
f"[bold {ACCENT}]features[/bold {ACCENT}] "
|
|
474
567
|
"[dim]interpkit features gpt2 '...' --at transformer.h.8 --sae jbloom/GPT2-Small-SAEs[/dim]\n\n"
|
|
@@ -552,13 +645,15 @@ def main(
|
|
|
552
645
|
("scan", "One-command overview \u2014 DLA, lens, attention, attribution"),
|
|
553
646
|
("report", "Generate an interactive HTML report"),
|
|
554
647
|
("chat", "Send a message to a chat / instruct model"),
|
|
648
|
+
("generate", "Generate with steering/ablation active + per-token lens"),
|
|
555
649
|
])
|
|
556
650
|
|
|
557
651
|
core_ops = _cmd_table([
|
|
558
652
|
("inspect", "Module tree with types, params, roles"),
|
|
559
653
|
("dla", "Direct Logit Attribution \u2014 decompose logit by component"),
|
|
560
654
|
("trace", "Causal tracing \u2014 module or position-aware"),
|
|
561
|
-
("lens", "Logit lens \u2014 project layers to vocab"),
|
|
655
|
+
("lens", "Logit lens \u2014 project layers to vocab (--tuned-lens for tuned)"),
|
|
656
|
+
("train-tuned-lens", "Train per-layer tuned-lens translators"),
|
|
562
657
|
("attribute", "Gradient saliency over inputs"),
|
|
563
658
|
("patch", "Activation patching at module/head/position"),
|
|
564
659
|
])
|
|
@@ -574,8 +669,11 @@ def main(
|
|
|
574
669
|
])
|
|
575
670
|
|
|
576
671
|
circuit_ops = _cmd_table([
|
|
577
|
-
("find-circuit", "Automated circuit discovery"),
|
|
672
|
+
("find-circuit", "Automated circuit discovery (ablation or EAP)"),
|
|
673
|
+
("atp", "Attribution Patching — score all modules in 3 passes"),
|
|
674
|
+
("eap", "Edge Attribution Patching — gradient-based edge scores"),
|
|
578
675
|
("features", "SAE feature decomposition (single or contrastive)"),
|
|
676
|
+
("maxact", "Max-activating examples for a neuron / SAE feature / head"),
|
|
579
677
|
])
|
|
580
678
|
|
|
581
679
|
layout = Table(show_header=False, box=None, pad_edge=False, padding=0, expand=True)
|
|
@@ -726,18 +824,65 @@ def lens(
|
|
|
726
824
|
save: str | None = typer.Option(None, "--save", help="Save heatmap to file (e.g. lens.png)"),
|
|
727
825
|
html_path: str | None = typer.Option(None, "--html", help="Save interactive HTML to file"),
|
|
728
826
|
position: int | None = typer.Option(None, "--position", help="Single token position to analyse (-1 = last). Omit for all positions."),
|
|
827
|
+
tuned_lens_path: str | None = typer.Option(None, "--tuned-lens", help="Path to a saved tuned lens (switches to kind='tuned'). Train with `interpkit train-tuned-lens`."),
|
|
729
828
|
device: str | None = typer.Option(None, help="Device"),
|
|
730
829
|
dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
|
|
731
830
|
device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
|
|
732
831
|
) -> None:
|
|
733
|
-
"""Logit lens: project each layer's hidden state to vocabulary space."""
|
|
832
|
+
"""Logit lens: project each layer's hidden state to vocabulary space.
|
|
833
|
+
|
|
834
|
+
Pass --tuned-lens <path> to apply trained per-layer translators
|
|
835
|
+
(Belrose et al. 2023) for an unbiased early-layer readout.
|
|
836
|
+
"""
|
|
734
837
|
m = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)
|
|
838
|
+
kind = "tuned" if tuned_lens_path is not None else "logit"
|
|
735
839
|
with console.status(" Running logit lens..."):
|
|
736
|
-
result = m.lens(text, save=save, html=html_path, position=position)
|
|
840
|
+
result = m.lens(
|
|
841
|
+
text, save=save, html=html_path, position=position,
|
|
842
|
+
kind=kind, tuned_lens=tuned_lens_path,
|
|
843
|
+
)
|
|
737
844
|
if _output_format == "json":
|
|
738
845
|
_json_dump(result if isinstance(result, dict) else {"results": result})
|
|
739
846
|
|
|
740
847
|
|
|
848
|
+
# ══════════════════════════════════════════════════════════════════
|
|
849
|
+
# train-tuned-lens
|
|
850
|
+
# ══════════════════════════════════════════════════════════════════
|
|
851
|
+
|
|
852
|
+
|
|
853
|
+
@app.command("train-tuned-lens")
def train_tuned_lens_cmd(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    corpus_file: str = typer.Option(..., "--corpus-file", help="Text file with training sentences, one per line"),
    steps: int = typer.Option(200, "--steps", help="Training steps"),
    batch_size: int = typer.Option(4, "--batch-size", help="Batch size"),
    lr: float = typer.Option(1e-3, "--lr", help="Adam learning rate"),
    max_length: int = typer.Option(64, "--max-length", help="Token truncation length"),
    seed: int = typer.Option(0, "--seed", help="Random seed (deterministic on CPU)"),
    save: str | None = typer.Option(None, "--save", help="Output directory or .safetensors path (default: ~/.cache/interpkit/tuned_lens/<model>/)"),
    device: str | None = typer.Option(None, help="Device"),
    dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
    device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
) -> None:
    """Train tuned-lens translators (Belrose et al. 2023) for a model.

    The model stays frozen; only per-layer affine translators train.
    Use the result with `interpkit lens ... --tuned-lens <path>`.
    """
    # NOTE: the docstring above is the command's --help text; keep it verbatim.
    from interpkit.core.inputs import read_examples_file
    from interpkit.ops.tuned_lens import default_tuned_lens_dir

    # Read the corpus before loading the model so a bad path fails cheaply.
    training_lines = read_examples_file(corpus_file)
    model = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)

    # Without an explicit --save, fall back to the per-model default location.
    if save is None:
        destination = str(default_tuned_lens_dir(model_name))
    else:
        destination = save

    training_options = {
        "steps": steps,
        "batch_size": batch_size,
        "lr": lr,
        "max_length": max_length,
        "seed": seed,
        "save": destination,
    }
    trained_lens = model.train_tuned_lens(training_lines, **training_options)

    # Only JSON mode emits a machine-readable summary from this command.
    if _output_format != "json":
        return
    _json_dump({"saved_to": destination, "meta": trained_lens.meta})
|
|
884
|
+
|
|
885
|
+
|
|
741
886
|
# ══════════════════════════════════════════════════════════════════
|
|
742
887
|
# attribute
|
|
743
888
|
# ══════════════════════════════════════════════════════════════════
|
|
@@ -1080,7 +1225,7 @@ def find_circuit(
|
|
|
1080
1225
|
clean_file: str | None = typer.Option(None, "--clean-file", help="Text file with clean examples, one per line"),
|
|
1081
1226
|
corrupted_file: str | None = typer.Option(None, "--corrupted-file", help="Text file with corrupted examples, one per line (must match --clean-file line count)"),
|
|
1082
1227
|
threshold: float = typer.Option(0.01, "--threshold", help="Minimum ablation effect to include in circuit (0-1)"),
|
|
1083
|
-
method: str = typer.Option("mean", "--method", help="
|
|
1228
|
+
method: str = typer.Option("mean", "--method", help="Selection method: mean (default), zero, resample (ablation), or eap / eap-ig (gradient-based, much faster)"),
|
|
1084
1229
|
metric: str = typer.Option("logit_diff", "--metric", help="Effect metric: logit_diff, kl_div, target_prob, l2_prob"),
|
|
1085
1230
|
device: str | None = typer.Option(None, help="Device"),
|
|
1086
1231
|
dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
|
|
@@ -1190,6 +1335,196 @@ def chat(
|
|
|
1190
1335
|
_json_dump({k: v for k, v in result.items() if k not in {"input_ids", "output_ids"}})
|
|
1191
1336
|
|
|
1192
1337
|
|
|
1338
|
+
# ══════════════════════════════════════════════════════════════════
|
|
1339
|
+
# atp / eap
|
|
1340
|
+
# ══════════════════════════════════════════════════════════════════
|
|
1341
|
+
|
|
1342
|
+
|
|
1343
|
+
@app.command()
def atp(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    clean: str = typer.Option(..., "--clean", help="Clean input"),
    corrupted: str = typer.Option(..., "--corrupted", help="Corrupted input"),
    top_k: int = typer.Option(20, "--top-k", help="Top modules to report by absolute score. 0 = all."),
    device: str | None = typer.Option(None, help="Device"),
    dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
    device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
) -> None:
    """Attribution Patching: first-order patch-effect scores for every module.

    Three model passes score all modules at once — the fast first look
    before committing to `trace`'s per-module full patching.
    """
    # NOTE: the docstring above is the command's --help text; keep it verbatim.
    model = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)

    # Any non-positive --top-k is forwarded as None (no cap on reported modules).
    module_limit: int | None = None
    if top_k > 0:
        module_limit = top_k

    with console.status(" Computing attribution patching scores..."):
        scores = model.atp(clean, corrupted, top_k=module_limit)

    # Only JSON mode dumps the raw result from here.
    if _output_format != "json":
        return
    _json_dump(scores)
|
|
1364
|
+
|
|
1365
|
+
|
|
1366
|
+
@app.command()
def eap(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    clean: str = typer.Option(..., "--clean", help="Clean input (must tokenize to same length as --corrupted)"),
    corrupted: str = typer.Option(..., "--corrupted", help="Corrupted input"),
    ig_steps: int = typer.Option(0, "--ig-steps", help="EAP-IG interpolation steps (0 = plain EAP; try 5)"),
    top_k_edges: int = typer.Option(30, "--top-k-edges", help="Top edges to report by absolute score. 0 = all."),
    device: str | None = typer.Option(None, help="Device"),
    dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
    device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
) -> None:
    """Edge Attribution Patching: gradient-based edge scores for circuit discovery.

    Scores every (component → residual stream) edge from a handful of
    passes. Pair with `find-circuit --method eap` for a causally
    verified circuit.
    """
    # NOTE: the docstring above is the command's --help text; keep it verbatim.
    model = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)

    # Any non-positive --top-k-edges is forwarded as None (no cap on reported edges).
    edge_limit: int | None = None
    if top_k_edges > 0:
        edge_limit = top_k_edges

    with console.status(" Computing edge attribution scores..."):
        scores = model.eap(clean, corrupted, ig_steps=ig_steps, top_k_edges=edge_limit)

    # Only JSON mode dumps the raw result from here.
    if _output_format != "json":
        return
    _json_dump(scores)
|
|
1389
|
+
|
|
1390
|
+
|
|
1391
|
+
# ══════════════════════════════════════════════════════════════════
|
|
1392
|
+
# maxact
|
|
1393
|
+
# ══════════════════════════════════════════════════════════════════
|
|
1394
|
+
|
|
1395
|
+
|
|
1396
|
+
@app.command()
def maxact(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    at: str = typer.Option(..., "--at", help="Module whose activations to scan (e.g. transformer.h.6.mlp)"),
    texts_file: str | None = typer.Option(None, "--texts-file", help="Text file with one example per line"),
    dataset: str | None = typer.Option(None, "--dataset", help="HF dataset spec: hf:name[:split[:column]] (needs interpkit[data] + --max-examples)"),
    neuron: int | None = typer.Option(None, "--neuron", help="Neuron index at the module (raw activation score)"),
    feature: int | None = typer.Option(None, "--feature", help="SAE feature index (requires --sae)"),
    head: int | None = typer.Option(None, "--head", help="Attention head index (pre-projection output norm)"),
    sae: str | None = typer.Option(None, "--sae", help="SAE repo ID or local path (with --feature)"),
    top_k: int = typer.Option(20, "--top-k", help="Top examples to keep"),
    batch_size: int = typer.Option(8, "--batch-size", help="Forward batch size"),
    max_examples: int | None = typer.Option(None, "--max-examples", help="Cap on dataset examples scanned"),
    max_length: int = typer.Option(128, "--max-length", help="Token truncation length"),
    device: str | None = typer.Option(None, help="Device"),
    dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
    device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
) -> None:
    """Find the dataset examples that most activate a neuron / SAE feature / head.

    The feature-browsing workflow: "what does this unit fire on?".
    Streams batched forwards and keeps only the top-k scored contexts.
    """
    # NOTE: the docstring above is the command's --help text; keep it verbatim.
    from interpkit.core.inputs import read_examples_file

    # Exactly one data source must be given: both-set and neither-set are errors.
    has_texts_file = texts_file is not None
    has_dataset = dataset is not None
    if has_texts_file == has_dataset:
        raise typer.BadParameter("Provide exactly one of --texts-file or --dataset.")

    # A texts file is read eagerly into a list; a dataset spec is passed through
    # as a string for max_activating to handle.
    data: list[str] | str
    if texts_file is not None:
        data = read_examples_file(texts_file)
    else:
        data = dataset  # type: ignore[assignment]

    model = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)
    result = model.max_activating(
        data,
        at=at,
        neuron=neuron,
        feature=feature,
        head=head,
        sae=sae,
        top_k=top_k,
        batch_size=batch_size,
        max_examples=max_examples,
        max_length=max_length,
    )

    # Only JSON mode dumps the raw result from here.
    if _output_format != "json":
        return
    _json_dump(result)
|
|
1436
|
+
|
|
1437
|
+
|
|
1438
|
+
# ══════════════════════════════════════════════════════════════════
|
|
1439
|
+
# generate
|
|
1440
|
+
# ══════════════════════════════════════════════════════════════════
|
|
1441
|
+
|
|
1442
|
+
|
|
1443
|
+
@app.command()
def generate(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    prompt: str = typer.Argument(..., help="Prompt text to generate from"),
    max_new_tokens: int = typer.Option(64, "--max-new-tokens", help="Max generation length"),
    positive: str | None = typer.Option(None, "--positive", help="Positive steering text (single example)"),
    negative: str | None = typer.Option(None, "--negative", help="Negative steering text (single example)"),
    positive_file: str | None = typer.Option(None, "--positive-file", help="Text file with positive examples, one per line"),
    negative_file: str | None = typer.Option(None, "--negative-file", help="Text file with negative examples, one per line"),
    at: str | None = typer.Option(None, "--at", help="Module to apply the steering vector at (required with --positive/--negative)"),
    scale: float = typer.Option(2.0, "--scale", help="Steering vector scale factor"),
    ablate_at: str | None = typer.Option(None, "--ablate-at", help="Module to ablate during generation"),
    ablate_method: str = typer.Option("zero", "--ablate-method", help="Ablation method: zero, mean"),
    capture: str | None = typer.Option(None, "--capture", help="Per-token capture: 'lens' (logit-lens trajectory) or 'logits'"),
    sample: bool = typer.Option(False, "--sample/--no-sample", help="Sample (True) or use greedy decoding (False, default)"),
    temperature: float = typer.Option(1.0, "--temperature", help="Sampling temperature (used when --sample)"),
    top_p: float = typer.Option(1.0, "--top-p", help="Nucleus sampling cutoff (used when --sample)"),
    device: str | None = typer.Option(None, help="Device"),
    dtype: str | None = typer.Option(None, "--dtype", help="Model dtype: float16, bfloat16, float32, auto"),
    device_map: str | None = typer.Option(None, "--device-map", help="HF device_map (e.g. 'auto')"),
) -> None:
    """Generate text with interventions active across every decode step.

    Steering (``--positive`` / ``--negative`` + ``--at``) and ablation
    (``--ablate-at``) stay hooked for the prefill and all KV-cached decode
    steps — the generation-time counterpart of ``steer`` / ``ablate``.
    ``--capture lens`` additionally records each generated token's
    logit-lens trajectory through every block.
    """
    # NOTE: the docstring above is the command's --help text; keep it verbatim.
    from interpkit.core.inputs import read_examples_file
    from interpkit.core.interventions import AblateIntervention, SteerIntervention

    # Any steering-related flag (empty strings count as "not given") turns
    # steering on and makes --at plus one positive and one negative source
    # mandatory.
    wants_steering = any([positive, negative, positive_file, negative_file])

    # Validate the steering arguments and read the example files BEFORE
    # loading the model, so a missing flag or unreadable file fails fast
    # instead of after a potentially multi-gigabyte model load.
    # steering_request holds (module, positive inputs, negative inputs).
    steering_request: tuple[str, str | list[str], str | list[str]] | None = None
    if wants_steering:
        if at is None:
            raise typer.BadParameter("Steering requires --at (module to apply the vector at).")

        pos_inputs: str | list[str]
        neg_inputs: str | list[str]
        # A file takes precedence over the single-example flag on each side.
        if positive_file:
            pos_inputs = read_examples_file(positive_file)
        elif positive:
            pos_inputs = positive
        else:
            raise typer.BadParameter("Provide --positive or --positive-file")
        if negative_file:
            neg_inputs = read_examples_file(negative_file)
        elif negative:
            neg_inputs = negative
        else:
            raise typer.BadParameter("Provide --negative or --negative-file")

        steering_request = (at, pos_inputs, neg_inputs)

    m = _load_model(model_name, device=device, dtype=dtype, device_map=device_map)

    interventions: list = []
    if steering_request is not None:
        steer_at, pos_inputs, neg_inputs = steering_request
        vector = m.steer_vector(pos_inputs, neg_inputs, at=steer_at)
        interventions.append(SteerIntervention(steer_at, vector=vector, scale=scale))

    if ablate_at is not None:
        interventions.append(AblateIntervention(ablate_at, method=ablate_method))

    with console.status(" Generating..."):
        result = m.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            # An empty intervention list is passed as None.
            interventions=interventions or None,
            capture=capture,
            do_sample=sample,
            temperature=temperature,
            top_p=top_p,
        )

    if _output_format == "json":
        # Trim tensors: ids ride along in the Python API but bloat JSON, and
        # per-step logits are (1, vocab) each.
        out = {k: v for k, v in result.items() if k not in {"input_ids", "output_ids"}}
        if "steps" in out:
            out["steps"] = [
                {k: v for k, v in step.items() if k != "logits"}
                for step in out["steps"]
            ]
        _json_dump(out)
|
|
1526
|
+
|
|
1527
|
+
|
|
1193
1528
|
def run() -> None:
|
|
1194
1529
|
"""CLI entry point that renders interpkit's intentional errors cleanly.
|
|
1195
1530
|
|
|
@@ -51,6 +51,20 @@ VALID_FIND_CIRCUIT_METHODS = frozenset({
|
|
|
51
51
|
"zero",
|
|
52
52
|
"mean",
|
|
53
53
|
"resample",
|
|
54
|
+
# Gradient-based selection (phase 2): EAP-ranked components, verified
|
|
55
|
+
# causally with mean ablation.
|
|
56
|
+
"eap",
|
|
57
|
+
"eap-ig",
|
|
58
|
+
})
|
|
59
|
+
|
|
60
|
+
# Metrics accepted for EAP / AtP scoring. Only logit_diff has an EAP/AtP
# gradient formulation today.
VALID_EAP_METRICS = frozenset({"logit_diff"})
|
|
64
|
+
|
|
65
|
+
# Lens variants: "logit" is the plain logit lens; "tuned" is the Belrose et al.
# 2023 lens with trained per-block translators (phase 3).
VALID_LENS_KINDS = frozenset({"logit", "tuned"})
|
|
55
69
|
|
|
56
70
|
VALID_IG_METHODS = frozenset({
|
|
@@ -99,6 +113,8 @@ __all__ = [
|
|
|
99
113
|
"VALID_TRACE_METHODS",
|
|
100
114
|
"VALID_ABLATE_METHODS",
|
|
101
115
|
"VALID_FIND_CIRCUIT_METHODS",
|
|
116
|
+
"VALID_EAP_METRICS",
|
|
117
|
+
"VALID_LENS_KINDS",
|
|
102
118
|
"VALID_IG_METHODS",
|
|
103
119
|
"VALID_IG_BASELINES",
|
|
104
120
|
"_validate_enum",
|
|
@@ -48,7 +48,10 @@ def warn_if_leading_space_better(
|
|
|
48
48
|
This helper detects single-token leading-space variants and surfaces
|
|
49
49
|
a yellow tip; it is a no-op for tensor / list inputs, missing
|
|
50
50
|
tokenizers, empty strings, or strings that already begin with
|
|
51
|
-
whitespace.
|
|
51
|
+
whitespace. When the plain input splits into *multiple* subword tokens
|
|
52
|
+
(e.g. ``"Hate"`` -> ``['H', 'ate']``) the warning is escalated to a red
|
|
53
|
+
message, because averaging activations over the fragments produces a
|
|
54
|
+
direction that does not represent the word at all.
|
|
52
55
|
|
|
53
56
|
Parameters
|
|
54
57
|
----------
|
|
@@ -104,13 +107,28 @@ def warn_if_leading_space_better(
|
|
|
104
107
|
|
|
105
108
|
console = _Console()
|
|
106
109
|
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
110
|
+
spaced = " " + text
|
|
111
|
+
if len(ids_plain) > 1:
|
|
112
|
+
# Egregious case: a single-word steering/contrast term that splits
|
|
113
|
+
# into multiple subword tokens (e.g. "Hate" -> ['H', 'ate']). The op
|
|
114
|
+
# averages activations across those fragments, so the resulting
|
|
115
|
+
# direction does not represent the word at all — almost always a
|
|
116
|
+
# mistake. Escalate from a tip to a red warning so it can't be missed.
|
|
117
|
+
console.print(
|
|
118
|
+
f" [bold red]{op_label}:[/bold red] {role} input {text!r} splits into "
|
|
119
|
+
f"{len(ids_plain)} subword tokens {ids_plain} — {op_label} will average "
|
|
120
|
+
f"meaningless fragments, so the result will not reflect {text!r}. "
|
|
121
|
+
f"{spaced!r} is a single token {ids_spaced}; use it instead "
|
|
122
|
+
f"(BPE leading-space convention)."
|
|
123
|
+
)
|
|
124
|
+
else:
|
|
125
|
+
console.print(
|
|
126
|
+
f" [yellow]{op_label}:[/yellow] {role} input {text!r} tokenizes to "
|
|
127
|
+
f"{len(ids_plain)} token(s) {ids_plain}, but "
|
|
128
|
+
f"{spaced!r} is a single token {ids_spaced}. "
|
|
129
|
+
f"Consider using {spaced!r} for a stronger contrast "
|
|
130
|
+
f"(BPE leading-space convention)."
|
|
131
|
+
)
|
|
114
132
|
warned_count[0] += 1
|
|
115
133
|
|
|
116
134
|
|