interpkit 0.1.0__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interpkit-0.3.0/PKG-INFO +509 -0
- interpkit-0.3.0/README.md +461 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/__init__.py +14 -2
- interpkit-0.3.0/interpkit/cli/main.py +963 -0
- interpkit-0.3.0/interpkit/core/cache.py +36 -0
- interpkit-0.3.0/interpkit/core/discovery.py +810 -0
- interpkit-0.3.0/interpkit/core/html.py +697 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/core/inputs.py +26 -13
- interpkit-0.3.0/interpkit/core/loader.py +292 -0
- interpkit-0.3.0/interpkit/core/model.py +764 -0
- interpkit-0.3.0/interpkit/core/plot.py +540 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/core/registry.py +18 -4
- interpkit-0.3.0/interpkit/core/render.py +782 -0
- interpkit-0.3.0/interpkit/core/theme.py +33 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/core/tl_compat.py +3 -3
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/ablate.py +41 -8
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/activations.py +11 -7
- interpkit-0.3.0/interpkit/ops/attention.py +365 -0
- interpkit-0.3.0/interpkit/ops/attribute.py +308 -0
- interpkit-0.3.0/interpkit/ops/batch.py +257 -0
- interpkit-0.3.0/interpkit/ops/circuits.py +526 -0
- interpkit-0.3.0/interpkit/ops/diff.py +105 -0
- interpkit-0.3.0/interpkit/ops/dla.py +488 -0
- interpkit-0.3.0/interpkit/ops/find_circuit.py +356 -0
- interpkit-0.3.0/interpkit/ops/heads.py +175 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/inspect.py +2 -2
- interpkit-0.3.0/interpkit/ops/lens.py +243 -0
- interpkit-0.3.0/interpkit/ops/patch.py +261 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/probe.py +70 -26
- interpkit-0.3.0/interpkit/ops/report.py +258 -0
- interpkit-0.3.0/interpkit/ops/sae.py +439 -0
- interpkit-0.3.0/interpkit/ops/scan.py +288 -0
- interpkit-0.3.0/interpkit/ops/steer.py +176 -0
- interpkit-0.3.0/interpkit/ops/trace.py +349 -0
- interpkit-0.3.0/interpkit.egg-info/PKG-INFO +509 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/SOURCES.txt +23 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/requires.txt +14 -1
- {interpkit-0.1.0 → interpkit-0.3.0}/pyproject.toml +46 -6
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_activations.py +3 -1
- interpkit-0.3.0/tests/test_architectures.py +286 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_attention.py +3 -1
- interpkit-0.3.0/tests/test_cli.py +682 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_diff.py +8 -3
- interpkit-0.3.0/tests/test_discovery.py +80 -0
- interpkit-0.3.0/tests/test_discovery_units.py +874 -0
- interpkit-0.3.0/tests/test_error_handling.py +206 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_inspect.py +2 -3
- interpkit-0.3.0/tests/test_invariants.py +108 -0
- interpkit-0.3.0/tests/test_load_params.py +291 -0
- interpkit-0.3.0/tests/test_multi_arch.py +140 -0
- interpkit-0.3.0/tests/test_ops.py +219 -0
- interpkit-0.3.0/tests/test_param_variants.py +140 -0
- interpkit-0.3.0/tests/test_plot_internals.py +586 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_registry.py +0 -1
- interpkit-0.3.0/tests/test_regressions.py +90 -0
- interpkit-0.3.0/tests/test_render_internals.py +674 -0
- interpkit-0.3.0/tests/test_sae.py +219 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_tl_compat.py +0 -2
- interpkit-0.3.0/tests/test_tl_ops.py +244 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_trace.py +1 -3
- interpkit-0.1.0/PKG-INFO +0 -295
- interpkit-0.1.0/README.md +0 -258
- interpkit-0.1.0/interpkit/cli/main.py +0 -337
- interpkit-0.1.0/interpkit/core/discovery.py +0 -228
- interpkit-0.1.0/interpkit/core/html.py +0 -375
- interpkit-0.1.0/interpkit/core/model.py +0 -551
- interpkit-0.1.0/interpkit/core/plot.py +0 -352
- interpkit-0.1.0/interpkit/core/render.py +0 -465
- interpkit-0.1.0/interpkit/ops/attention.py +0 -234
- interpkit-0.1.0/interpkit/ops/attribute.py +0 -206
- interpkit-0.1.0/interpkit/ops/diff.py +0 -79
- interpkit-0.1.0/interpkit/ops/lens.py +0 -151
- interpkit-0.1.0/interpkit/ops/patch.py +0 -112
- interpkit-0.1.0/interpkit/ops/sae.py +0 -212
- interpkit-0.1.0/interpkit/ops/steer.py +0 -118
- interpkit-0.1.0/interpkit/ops/trace.py +0 -182
- interpkit-0.1.0/interpkit.egg-info/PKG-INFO +0 -295
- interpkit-0.1.0/tests/test_discovery.py +0 -61
- interpkit-0.1.0/tests/test_sae.py +0 -115
- {interpkit-0.1.0 → interpkit-0.3.0}/LICENSE +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/cli/__init__.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/core/__init__.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/__init__.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/dependency_links.txt +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/entry_points.txt +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/top_level.txt +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/setup.cfg +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_ablate.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_attribute.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_cache.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_html.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_lens.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_patch.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_plots.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_probe.py +0 -0
- {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_steer.py +0 -0
interpkit-0.3.0/PKG-INFO
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: interpkit
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: Mech interp for any HuggingFace model.
|
|
5
|
+
Author: Davide Zani
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/z4nix/interpkit
|
|
8
|
+
Project-URL: Repository, https://github.com/z4nix/interpkit
|
|
9
|
+
Project-URL: Issues, https://github.com/z4nix/interpkit/issues
|
|
10
|
+
Keywords: mechanistic-interpretability,pytorch,transformers,mech-interp,interpretability
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Requires-Python: >=3.10
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
License-File: LICENSE
|
|
22
|
+
Requires-Dist: torch>=2.1
|
|
23
|
+
Requires-Dist: transformers>=4.36
|
|
24
|
+
Requires-Dist: safetensors>=0.4
|
|
25
|
+
Requires-Dist: rich>=13.0
|
|
26
|
+
Requires-Dist: rich-gradient>=0.3
|
|
27
|
+
Requires-Dist: typer>=0.9
|
|
28
|
+
Requires-Dist: Pillow>=10.0
|
|
29
|
+
Requires-Dist: matplotlib>=3.8
|
|
30
|
+
Requires-Dist: huggingface-hub>=0.20
|
|
31
|
+
Provides-Extra: vision
|
|
32
|
+
Requires-Dist: torchvision>=0.16; extra == "vision"
|
|
33
|
+
Provides-Extra: probe
|
|
34
|
+
Requires-Dist: scikit-learn>=1.3; extra == "probe"
|
|
35
|
+
Provides-Extra: dev
|
|
36
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
37
|
+
Requires-Dist: pytest-timeout>=2.2; extra == "dev"
|
|
38
|
+
Requires-Dist: pytest-cov>=5.0; extra == "dev"
|
|
39
|
+
Requires-Dist: scikit-learn>=1.3; extra == "dev"
|
|
40
|
+
Requires-Dist: torchvision>=0.16; extra == "dev"
|
|
41
|
+
Requires-Dist: ruff>=0.4; extra == "dev"
|
|
42
|
+
Requires-Dist: mypy>=1.8; extra == "dev"
|
|
43
|
+
Provides-Extra: docs
|
|
44
|
+
Requires-Dist: mkdocs>=1.5; extra == "docs"
|
|
45
|
+
Requires-Dist: mkdocs-material>=9.5; extra == "docs"
|
|
46
|
+
Requires-Dist: mkdocstrings[python]>=0.24; extra == "docs"
|
|
47
|
+
Dynamic: license-file
|
|
48
|
+
|
|
49
|
+
<p align="center">
|
|
50
|
+
<img src="assets/logo.svg" alt="InterpKit" width="680">
|
|
51
|
+
</p>
|
|
52
|
+
|
|
53
|
+
[PyPI package](https://pypi.org/project/interpkit/)
|
|
54
|
+
[License: MIT](https://opensource.org/licenses/MIT)
|
|
55
|
+
[Python 3.10+](https://www.python.org/downloads/)
|
|
56
|
+
|
|
57
|
+
---
|
|
58
|
+
|
|
59
|
+
## Why InterpKit?
|
|
60
|
+
|
|
61
|
+
Mechanistic interpretability tooling today is fragmented. Each library supports a narrow set of architectures, and moving to a different model family usually means rewriting hook code from scratch.
|
|
62
|
+
|
|
63
|
+
InterpKit provides a single, consistent interface for mech interp operations across any HuggingFace model — transformers, SSMs, vision models, and more — with zero annotation required.
|
|
64
|
+
|
|
65
|
+
---
|
|
66
|
+
|
|
67
|
+
## Install
|
|
68
|
+
|
|
69
|
+
```bash
|
|
70
|
+
pip install interpkit
|
|
71
|
+
|
|
72
|
+
# For linear probe support:
|
|
73
|
+
pip install interpkit[probe]
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
Or install from source for development:
|
|
77
|
+
|
|
78
|
+
```bash
|
|
79
|
+
git clone https://github.com/z4nix/interpkit.git
|
|
80
|
+
cd interpkit
|
|
81
|
+
pip install -e ".[dev]"
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
---
|
|
85
|
+
|
|
86
|
+
## Quickstart
|
|
87
|
+
|
|
88
|
+
```python
|
|
89
|
+
import interpkit
|
|
90
|
+
|
|
91
|
+
model = interpkit.load("gpt2")
|
|
92
|
+
|
|
93
|
+
# One-command model overview — runs DLA, logit lens, attention, attribution
|
|
94
|
+
# and surfaces the most interesting findings automatically
|
|
95
|
+
model.scan("The capital of France is")
|
|
96
|
+
|
|
97
|
+
# Or run individual operations:
|
|
98
|
+
model.inspect() # module tree
|
|
99
|
+
model.dla("The capital of France is") # direct logit attribution
|
|
100
|
+
model.trace("...Paris...", "...Rome...", top_k=20) # causal tracing
|
|
101
|
+
model.lens("The capital of France is") # logit lens (all positions)
|
|
102
|
+
model.attribute("The capital of France is") # gradient saliency
|
|
103
|
+
model.decompose("The capital of France is") # residual stream decomposition
|
|
104
|
+
```
|
|
105
|
+
|
|
106
|
+
Works the same on any HF architecture:
|
|
107
|
+
|
|
108
|
+
```python
|
|
109
|
+
model = interpkit.load("state-spaces/mamba-370m")
|
|
110
|
+
model = interpkit.load("google/vit-base-patch16-224")
|
|
111
|
+
model = interpkit.load("bert-base-uncased")
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
---
|
|
115
|
+
|
|
116
|
+
## Operations
|
|
117
|
+
|
|
118
|
+
| Operation | What it does | Works on |
|
|
119
|
+
|-----------|-------------|----------|
|
|
120
|
+
| **`scan`** | One-command model overview: runs DLA, lens, attention, attribution and surfaces key findings | LMs |
|
|
121
|
+
| **`dla`** | Direct Logit Attribution — decompose output logits by head and MLP contribution; optionally decompose through an SAE into per-feature attributions | LMs |
|
|
122
|
+
| `inspect` | Module tree with types, param counts, shapes | Any model |
|
|
123
|
+
| `patch` | Activation patching at a module, head, or position | Any model |
|
|
124
|
+
| `trace` | Causal tracing — module-level or position-aware (Meng et al.) heatmap | Any model |
|
|
125
|
+
| `attribute` | Gradient saliency over inputs (returns scores programmatically) | Any model |
|
|
126
|
+
| `lens` | Logit lens — project activations to vocabulary at all positions | LMs (auto-detected) |
|
|
127
|
+
| `activations` | Extract raw activation tensors at any module | Any model |
|
|
128
|
+
| `head_activations` | Decompose attention output into per-head contributions | Transformers |
|
|
129
|
+
| `ablate` | Zero/mean ablate a component and measure effect | Any model |
|
|
130
|
+
| `attention` | Visualize attention patterns per layer/head | Transformers |
|
|
131
|
+
| `steer` | Extract and apply steering vectors | LMs |
|
|
132
|
+
| `probe` | Linear probe on activations | Any model |
|
|
133
|
+
| `diff` | Compare activations between two models | Any model |
|
|
134
|
+
| `features` | SAE feature decomposition | Any model |
|
|
135
|
+
| **`decompose`** | Residual stream decomposition — per-component norms | Transformers |
|
|
136
|
+
| **`ov_scores`** | OV circuit analysis — W_OV matrix per head | Transformers |
|
|
137
|
+
| **`qk_scores`** | QK circuit analysis — W_QK matrix per head | Transformers |
|
|
138
|
+
| **`composition`** | Q/K/V composition scores between heads in two layers | Transformers |
|
|
139
|
+
| **`find_circuit`** | Automated circuit discovery via iterative ablation | Transformers |
|
|
140
|
+
| **`batch`** | Run any operation over a dataset with result aggregation | Any model |
|
|
141
|
+
|
|
142
|
+
---
|
|
143
|
+
|
|
144
|
+
## Scan — One-Command Model Overview
|
|
145
|
+
|
|
146
|
+
The fastest way to understand what a model is doing on an input. Runs DLA, logit lens, attention analysis, and gradient attribution, then surfaces the most interesting findings in a ranked summary:
|
|
147
|
+
|
|
148
|
+
```python
|
|
149
|
+
model.scan("The capital of France is")
|
|
150
|
+
# Output:
|
|
151
|
+
# Predictions: "the" (8.5%), "now" (4.8%), "a" (4.6%)
|
|
152
|
+
# Key Findings (ranked by significance):
|
|
153
|
+
# 1. Top contributor to "the": L11.attn (+204.701)
|
|
154
|
+
# 2. Top attention head: L11.H0 (+149.850)
|
|
155
|
+
# 3. Most salient input token: "is" (score 12.435)
|
|
156
|
+
# 4. Answer "the" first appears at layer 9/12
|
|
157
|
+
|
|
158
|
+
model.scan("The capital of France is", save="scan") # exports scan_dla.png, scan_lens.png, etc.
|
|
159
|
+
```
|
|
160
|
+
|
|
161
|
+
## Direct Logit Attribution (DLA)
|
|
162
|
+
|
|
163
|
+
Answers the fundamental question: *why does the model predict this token?* Decomposes the output logit by component (attention block + MLP per layer) and by individual attention head:
|
|
164
|
+
|
|
165
|
+
```python
|
|
166
|
+
result = model.dla("The capital of France is")
|
|
167
|
+
# result["contributions"] — per-component logit contributions, sorted
|
|
168
|
+
# result["head_contributions"] — per-head breakdown
|
|
169
|
+
# result["target_token"] — the token being attributed
|
|
170
|
+
|
|
171
|
+
# Attribute a specific token
|
|
172
|
+
model.dla("The capital of France is", token="Paris")
|
|
173
|
+
|
|
174
|
+
# Save a bar chart
|
|
175
|
+
model.dla("The capital of France is", save="dla.png")
|
|
176
|
+
|
|
177
|
+
# Feature-level DLA — decompose a component through an SAE
|
|
178
|
+
# to see which individual features drive the prediction
|
|
179
|
+
model.dla(
|
|
180
|
+
"The capital of France is",
|
|
181
|
+
sae="jbloom/GPT2-Small-SAEs-Reformatted",
|
|
182
|
+
sae_at="transformer.h.11.attn",
|
|
183
|
+
)
|
|
184
|
+
# result["feature_contributions"]["features"]
|
|
185
|
+
# — per-feature logit attributions at the specified component
|
|
186
|
+
```
|
|
187
|
+
|
|
188
|
+
## Causal Tracing
|
|
189
|
+
|
|
190
|
+
```python
|
|
191
|
+
# Module-level tracing (default) — rank modules by causal effect
|
|
192
|
+
model.trace("...Paris...", "...Rome...", top_k=20)
|
|
193
|
+
|
|
194
|
+
# Position-aware tracing (Meng et al. 2022) — (layer x position) heatmap
|
|
195
|
+
model.trace("...Paris...", "...Rome...", mode="position", save="trace.png")
|
|
196
|
+
```
|
|
197
|
+
|
|
198
|
+
## Logit Lens
|
|
199
|
+
|
|
200
|
+
By default, the logit lens analyses all token positions, producing the classic (layers × positions) heatmap:
|
|
201
|
+
|
|
202
|
+
```python
|
|
203
|
+
model.lens("The capital of France is") # all positions
|
|
204
|
+
model.lens("The capital of France is", position=-1) # last position only
|
|
205
|
+
model.lens("The capital of France is", save="lens.png") # 2D heatmap export
|
|
206
|
+
```
|
|
207
|
+
|
|
208
|
+
## Attribution
|
|
209
|
+
|
|
210
|
+
Returns scores programmatically (no longer just prints):
|
|
211
|
+
|
|
212
|
+
```python
|
|
213
|
+
result = model.attribute("The capital of France is")
|
|
214
|
+
result["tokens"] # ["The", "capital", "of", "France", "is"]
|
|
215
|
+
result["scores"] # [8.88, 11.15, 7.24, 7.37, 12.43]
|
|
216
|
+
result["target"] # 262
|
|
217
|
+
```
|
|
218
|
+
|
|
219
|
+
## Activation Patching
|
|
220
|
+
|
|
221
|
+
Supports module-level, head-level, and position-level patching:
|
|
222
|
+
|
|
223
|
+
```python
|
|
224
|
+
# Module-level (original)
|
|
225
|
+
model.patch(clean, corrupted, at="transformer.h.8.mlp")
|
|
226
|
+
|
|
227
|
+
# Head-level — patch only attention head 3
|
|
228
|
+
model.patch(clean, corrupted, at="transformer.h.8", head=3)
|
|
229
|
+
|
|
230
|
+
# Position-level — patch only positions 3 and 4
|
|
231
|
+
model.patch(clean, corrupted, at="transformer.h.8", positions=[3, 4])
|
|
232
|
+
|
|
233
|
+
# Combined — patch head 3 at positions 3 and 4
|
|
234
|
+
model.patch(clean, corrupted, at="transformer.h.8", head=3, positions=[3, 4])
|
|
235
|
+
```
|
|
236
|
+
|
|
237
|
+
## Head-Level Activations
|
|
238
|
+
|
|
239
|
+
Decompose an attention module's output into per-head contributions, optionally projected through W_O into residual-stream space:
|
|
240
|
+
|
|
241
|
+
```python
|
|
242
|
+
result = model.head_activations("The capital of France is", at="transformer.h.8")
|
|
243
|
+
result["head_acts"] # tensor (num_heads, batch, seq, d_model)
|
|
244
|
+
result["num_heads"] # 12
|
|
245
|
+
result["head_dim"] # 64
|
|
246
|
+
```
|
|
247
|
+
|
|
248
|
+
## Activations, Ablation, Attention
|
|
249
|
+
|
|
250
|
+
```python
|
|
251
|
+
# Extract raw activations
|
|
252
|
+
act = model.activations("The capital of France is", at="transformer.h.8.mlp")
|
|
253
|
+
acts = model.activations("...", at=["transformer.h.0", "transformer.h.8.mlp"])
|
|
254
|
+
|
|
255
|
+
# Ablation — zero or mean
|
|
256
|
+
result = model.ablate("The capital of France is", at="transformer.h.8.mlp")
|
|
257
|
+
result = model.ablate("...", at="transformer.h.8.mlp", method="mean")
|
|
258
|
+
|
|
259
|
+
# Attention patterns
|
|
260
|
+
model.attention("The capital of France is") # all layers
|
|
261
|
+
model.attention("The capital of France is", layer=8, head=3) # single head
|
|
262
|
+
```
|
|
263
|
+
|
|
264
|
+
## Residual Stream Decomposition
|
|
265
|
+
|
|
266
|
+
Break down the residual stream at any position into contributions from each attention block and MLP:
|
|
267
|
+
|
|
268
|
+
```python
|
|
269
|
+
result = model.decompose("The capital of France is")
|
|
270
|
+
# result["components"] — list of {"name": "L8.attn", "type": "attn", "norm": 8.94, ...}
|
|
271
|
+
# result["residual"] — final residual stream vector
|
|
272
|
+
```
|
|
273
|
+
|
|
274
|
+
## OV / QK Circuit Analysis
|
|
275
|
+
|
|
276
|
+
Analyse the effective weight matrices of attention heads:
|
|
277
|
+
|
|
278
|
+
```python
|
|
279
|
+
# OV circuit: what does each head write to the residual stream?
|
|
280
|
+
model.ov_scores(layer=8)
|
|
281
|
+
# Per-head Frobenius norm, top singular values, approximate rank of W_OV
|
|
282
|
+
|
|
283
|
+
# QK circuit: what does each head attend to?
|
|
284
|
+
model.qk_scores(layer=8)
|
|
285
|
+
|
|
286
|
+
# Composition: how much does head j in layer 4 compose with head i in layer 8?
|
|
287
|
+
model.composition(src_layer=4, dst_layer=8, comp_type="q") # Q-composition
|
|
288
|
+
model.composition(src_layer=4, dst_layer=8, comp_type="k") # K-composition
|
|
289
|
+
model.composition(src_layer=4, dst_layer=8, comp_type="v") # V-composition
|
|
290
|
+
```
|
|
291
|
+
|
|
292
|
+
## Circuit Discovery
|
|
293
|
+
|
|
294
|
+
Automatically find the minimal set of components that explain a behaviour:
|
|
295
|
+
|
|
296
|
+
```python
|
|
297
|
+
circuit = model.find_circuit(
|
|
298
|
+
"The Eiffel Tower is in Paris",
|
|
299
|
+
"The Eiffel Tower is in Rome",
|
|
300
|
+
threshold=0.05,
|
|
301
|
+
)
|
|
302
|
+
# circuit["circuit"] — components in the circuit, sorted by effect
|
|
303
|
+
# circuit["excluded"] — components not in the circuit
|
|
304
|
+
# circuit["verification"] — faithfulness check (how much output is preserved
|
|
305
|
+
# when all non-circuit components are ablated)
|
|
306
|
+
```
|
|
307
|
+
|
|
308
|
+
## Batch / Dataset Operations
|
|
309
|
+
|
|
310
|
+
Run any operation over a dataset of examples with automatic result aggregation:
|
|
311
|
+
|
|
312
|
+
```python
|
|
313
|
+
# Generic batch runner
|
|
314
|
+
results = model.batch("trace", [
|
|
315
|
+
{"clean": "...Paris...", "corrupted": "...Rome..."},
|
|
316
|
+
{"clean": "...Berlin...", "corrupted": "...Madrid..."},
|
|
317
|
+
], op_kwargs={"top_k": 10})
|
|
318
|
+
# results["summary"]["ranked_modules"] — modules ranked by mean effect across examples
|
|
319
|
+
|
|
320
|
+
# Convenience: trace over a dataset
|
|
321
|
+
results = model.trace_batch(dataset, clean_col="clean", corrupted_col="corrupted")
|
|
322
|
+
|
|
323
|
+
# Convenience: DLA over a list of texts
|
|
324
|
+
results = model.dla_batch(["The capital of France is", "The CEO of Apple is"])
|
|
325
|
+
# results["summary"]["ranked_components"] — components ranked by mean contribution
|
|
326
|
+
```
|
|
327
|
+
|
|
328
|
+
## Steering
|
|
329
|
+
|
|
330
|
+
```python
|
|
331
|
+
vector = model.steer_vector("Love", "Hate", at="transformer.h.8")
|
|
332
|
+
model.steer("The weather today is", vector=vector, at="transformer.h.8", scale=2.0)
|
|
333
|
+
```
|
|
334
|
+
|
|
335
|
+
## Linear Probe
|
|
336
|
+
|
|
337
|
+
```python
|
|
338
|
+
result = model.probe(
|
|
339
|
+
texts=["The cat sat", "The dog ran", "A bird flew", "A fish swam"],
|
|
340
|
+
labels=[0, 0, 1, 1],
|
|
341
|
+
at="transformer.h.8",
|
|
342
|
+
)
|
|
343
|
+
print(result["accuracy"])
|
|
344
|
+
```
|
|
345
|
+
|
|
346
|
+
## Model Diff
|
|
347
|
+
|
|
348
|
+
```python
|
|
349
|
+
base = interpkit.load("gpt2")
|
|
350
|
+
finetuned = interpkit.load("my-finetuned-gpt2")
|
|
351
|
+
interpkit.diff(base, finetuned, "The capital of France is")
|
|
352
|
+
```
|
|
353
|
+
|
|
354
|
+
## SAE Features
|
|
355
|
+
|
|
356
|
+
Decompose activations into interpretable features using pre-trained Sparse Autoencoders:
|
|
357
|
+
|
|
358
|
+
```python
|
|
359
|
+
# From HuggingFace
|
|
360
|
+
model.features(
|
|
361
|
+
"The capital of France is",
|
|
362
|
+
at="transformer.h.8",
|
|
363
|
+
sae="jbloom/GPT2-Small-SAEs-Reformatted",
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
# From a local file (.safetensors or .pt)
|
|
367
|
+
model.features(
|
|
368
|
+
"The capital of France is",
|
|
369
|
+
at="transformer.h.8",
|
|
370
|
+
sae="/path/to/sae_weights.safetensors",
|
|
371
|
+
)
|
|
372
|
+
```
|
|
373
|
+
|
|
374
|
+
No SAELens dependency — weights are loaded directly via `safetensors`.
|
|
375
|
+
|
|
376
|
+
## Activation Cache
|
|
377
|
+
|
|
378
|
+
Avoid redundant forward passes when exploring the same input with multiple operations:
|
|
379
|
+
|
|
380
|
+
```python
|
|
381
|
+
model.cache("The capital of France is") # one forward pass, cache all layers
|
|
382
|
+
model.activations("The capital of France is", at="transformer.h.8.mlp") # instant
|
|
383
|
+
model.activations("The capital of France is", at="transformer.h.0.mlp") # instant
|
|
384
|
+
|
|
385
|
+
model.clear_cache() # free memory
|
|
386
|
+
```
|
|
387
|
+
|
|
388
|
+
---
|
|
389
|
+
|
|
390
|
+
## Visualizations
|
|
391
|
+
|
|
392
|
+
Pass `save="path.png"` to export a static matplotlib figure, or `html="path.html"` for an interactive visualization:
|
|
393
|
+
|
|
394
|
+
```python
|
|
395
|
+
model.attention("hello world", layer=0, head=0, save="attention.png")
|
|
396
|
+
model.trace("...Paris...", "...Rome...", save="trace.png")
|
|
397
|
+
model.trace("...Paris...", "...Rome...", mode="position", save="position_trace.png")
|
|
398
|
+
model.lens("The capital of France is", save="lens.png")
|
|
399
|
+
model.steer("The weather is", vector=vector, at="transformer.h.8", save="steer.png")
|
|
400
|
+
model.attribute("The capital of France is", save="attribution.png")
|
|
401
|
+
model.dla("The capital of France is", save="dla.png")
|
|
402
|
+
model.scan("The capital of France is", save="scan")
|
|
403
|
+
interpkit.diff(base, finetuned, "...", save="diff.png")
|
|
404
|
+
|
|
405
|
+
# Interactive HTML — self-contained files with hover tooltips, filters, and sliders
|
|
406
|
+
model.attention("hello world", html="attention.html")
|
|
407
|
+
model.trace("...Paris...", "...Rome...", html="trace.html")
|
|
408
|
+
model.attribute("The capital of France is", html="attribution.html")
|
|
409
|
+
```
|
|
410
|
+
|
|
411
|
+
---
|
|
412
|
+
|
|
413
|
+
## CLI
|
|
414
|
+
|
|
415
|
+
```bash
|
|
416
|
+
interpkit inspect gpt2
|
|
417
|
+
interpkit scan gpt2 "The capital of France is"
|
|
418
|
+
interpkit dla gpt2 "The capital of France is"
|
|
419
|
+
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --top-k 20
|
|
420
|
+
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --mode position --save trace.png
|
|
421
|
+
interpkit lens gpt2 "The capital of France is"
|
|
422
|
+
interpkit lens gpt2 "The capital of France is" --position -1
|
|
423
|
+
interpkit attention gpt2 "The capital of France is" --layer 8 --save attention.png
|
|
424
|
+
interpkit attribute gpt2 "The capital of France is"
|
|
425
|
+
interpkit steer gpt2 "The weather is" --positive Love --negative Hate --at transformer.h.8
|
|
426
|
+
interpkit ablate gpt2 "The capital of France is" --at transformer.h.8.mlp
|
|
427
|
+
interpkit decompose gpt2 "The capital of France is"
|
|
428
|
+
interpkit diff gpt2 my-finetuned-gpt2 "The capital of France is" --save diff.png
|
|
429
|
+
interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jbloom/GPT2-Small-SAEs-Reformatted
|
|
430
|
+
interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
|
|
431
|
+
interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
|
|
432
|
+
|
|
433
|
+
# Interactive HTML output
|
|
434
|
+
interpkit attention gpt2 "hello world" --html attention.html
|
|
435
|
+
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --html trace.html
|
|
436
|
+
interpkit attribute gpt2 "The capital of France is" --html attribution.html
|
|
437
|
+
|
|
438
|
+
# Vision models — auto-preprocessed
|
|
439
|
+
interpkit attribute microsoft/resnet-50 cat.jpg --target 281
|
|
440
|
+
```
|
|
441
|
+
|
|
442
|
+
Run `interpkit` with no arguments for a full command reference.
|
|
443
|
+
|
|
444
|
+
---
|
|
445
|
+
|
|
446
|
+
## TransformerLens interop
|
|
447
|
+
|
|
448
|
+
Already using TransformerLens? Pass your `HookedTransformer` directly into InterpKit — it auto-detects the model and extracts the tokenizer:
|
|
449
|
+
|
|
450
|
+
```python
|
|
451
|
+
from transformer_lens import HookedTransformer
|
|
452
|
+
import interpkit
|
|
453
|
+
|
|
454
|
+
tl_model = HookedTransformer.from_pretrained("gpt2")
|
|
455
|
+
model = interpkit.load(tl_model)
|
|
456
|
+
|
|
457
|
+
# All InterpKit operations work on TL models
|
|
458
|
+
model.scan("The capital of France is")
|
|
459
|
+
model.dla("The capital of France is")
|
|
460
|
+
model.trace("The Eiffel Tower is in Paris", "The Eiffel Tower is in Rome", top_k=20)
|
|
461
|
+
model.attention("The capital of France is", save="attention.png")
|
|
462
|
+
model.steer("The weather is", vector=vector, at="blocks.8", scale=2.0)
|
|
463
|
+
```
|
|
464
|
+
|
|
465
|
+
Translate between native and TL hook point names:
|
|
466
|
+
|
|
467
|
+
```python
|
|
468
|
+
interpkit.to_tl_name("transformer.h.8.mlp") # -> "blocks.8.mlp"
|
|
469
|
+
interpkit.to_native_name("blocks.8.attn", model.arch_info) # -> "transformer.h.8.attn"
|
|
470
|
+
interpkit.list_tl_hooks(tl_model) # -> ["blocks.0.hook_resid_pre", ...]
|
|
471
|
+
```
|
|
472
|
+
|
|
473
|
+
---
|
|
474
|
+
|
|
475
|
+
## Local models
|
|
476
|
+
|
|
477
|
+
```python
|
|
478
|
+
import torch.nn as nn
|
|
479
|
+
import interpkit
|
|
480
|
+
|
|
481
|
+
my_model = MyCustomModel()
|
|
482
|
+
interpkit.register(my_model, layers=["blocks.0", "blocks.1"], output_head="head")
|
|
483
|
+
model = interpkit.load(my_model, tokenizer=my_tokenizer)
|
|
484
|
+
model.trace(input_a, input_b, top_k=10)
|
|
485
|
+
```
|
|
486
|
+
|
|
487
|
+
---
|
|
488
|
+
|
|
489
|
+
## Examples
|
|
490
|
+
|
|
491
|
+
See the [`examples/`](examples/) directory for Jupyter notebooks:
|
|
492
|
+
|
|
493
|
+
| Notebook | Topics |
|
|
494
|
+
|----------|--------|
|
|
495
|
+
| `01_quickstart` | Inspect, scan, DLA, trace, lens, attribution, patching, ablation |
|
|
496
|
+
| `02_attention_patterns` | Per-head heatmaps, layer filtering, HTML export |
|
|
497
|
+
| `03_steering_vectors` | Extract and apply steering vectors at different layers/scales |
|
|
498
|
+
| `04_sae_features` | Sparse Autoencoder feature decomposition |
|
|
499
|
+
| `05_caching_and_probing` | Activation cache, linear probes across layers |
|
|
500
|
+
| `06_model_comparison` | Diff two models, side-by-side tracing and logit lens |
|
|
501
|
+
| `07_vision_models` | ResNet/ViT attribution, ablation, activations |
|
|
502
|
+
| `08_dla_and_circuits` | DLA, head activations, residual decomposition, OV/QK analysis, composition, circuit discovery |
|
|
503
|
+
| `09_scan_and_batch` | Auto-scan, batch operations, dataset workflows |
|
|
504
|
+
|
|
505
|
+
---
|
|
506
|
+
|
|
507
|
+
## License
|
|
508
|
+
|
|
509
|
+
MIT
|