interpkit 0.1.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. interpkit-0.3.0/PKG-INFO +509 -0
  2. interpkit-0.3.0/README.md +461 -0
  3. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/__init__.py +14 -2
  4. interpkit-0.3.0/interpkit/cli/main.py +963 -0
  5. interpkit-0.3.0/interpkit/core/cache.py +36 -0
  6. interpkit-0.3.0/interpkit/core/discovery.py +810 -0
  7. interpkit-0.3.0/interpkit/core/html.py +697 -0
  8. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/core/inputs.py +26 -13
  9. interpkit-0.3.0/interpkit/core/loader.py +292 -0
  10. interpkit-0.3.0/interpkit/core/model.py +764 -0
  11. interpkit-0.3.0/interpkit/core/plot.py +540 -0
  12. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/core/registry.py +18 -4
  13. interpkit-0.3.0/interpkit/core/render.py +782 -0
  14. interpkit-0.3.0/interpkit/core/theme.py +33 -0
  15. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/core/tl_compat.py +3 -3
  16. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/ablate.py +41 -8
  17. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/activations.py +11 -7
  18. interpkit-0.3.0/interpkit/ops/attention.py +365 -0
  19. interpkit-0.3.0/interpkit/ops/attribute.py +308 -0
  20. interpkit-0.3.0/interpkit/ops/batch.py +257 -0
  21. interpkit-0.3.0/interpkit/ops/circuits.py +526 -0
  22. interpkit-0.3.0/interpkit/ops/diff.py +105 -0
  23. interpkit-0.3.0/interpkit/ops/dla.py +488 -0
  24. interpkit-0.3.0/interpkit/ops/find_circuit.py +356 -0
  25. interpkit-0.3.0/interpkit/ops/heads.py +175 -0
  26. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/inspect.py +2 -2
  27. interpkit-0.3.0/interpkit/ops/lens.py +243 -0
  28. interpkit-0.3.0/interpkit/ops/patch.py +261 -0
  29. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/probe.py +70 -26
  30. interpkit-0.3.0/interpkit/ops/report.py +258 -0
  31. interpkit-0.3.0/interpkit/ops/sae.py +439 -0
  32. interpkit-0.3.0/interpkit/ops/scan.py +288 -0
  33. interpkit-0.3.0/interpkit/ops/steer.py +176 -0
  34. interpkit-0.3.0/interpkit/ops/trace.py +349 -0
  35. interpkit-0.3.0/interpkit.egg-info/PKG-INFO +509 -0
  36. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/SOURCES.txt +23 -0
  37. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/requires.txt +14 -1
  38. {interpkit-0.1.0 → interpkit-0.3.0}/pyproject.toml +46 -6
  39. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_activations.py +3 -1
  40. interpkit-0.3.0/tests/test_architectures.py +286 -0
  41. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_attention.py +3 -1
  42. interpkit-0.3.0/tests/test_cli.py +682 -0
  43. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_diff.py +8 -3
  44. interpkit-0.3.0/tests/test_discovery.py +80 -0
  45. interpkit-0.3.0/tests/test_discovery_units.py +874 -0
  46. interpkit-0.3.0/tests/test_error_handling.py +206 -0
  47. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_inspect.py +2 -3
  48. interpkit-0.3.0/tests/test_invariants.py +108 -0
  49. interpkit-0.3.0/tests/test_load_params.py +291 -0
  50. interpkit-0.3.0/tests/test_multi_arch.py +140 -0
  51. interpkit-0.3.0/tests/test_ops.py +219 -0
  52. interpkit-0.3.0/tests/test_param_variants.py +140 -0
  53. interpkit-0.3.0/tests/test_plot_internals.py +586 -0
  54. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_registry.py +0 -1
  55. interpkit-0.3.0/tests/test_regressions.py +90 -0
  56. interpkit-0.3.0/tests/test_render_internals.py +674 -0
  57. interpkit-0.3.0/tests/test_sae.py +219 -0
  58. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_tl_compat.py +0 -2
  59. interpkit-0.3.0/tests/test_tl_ops.py +244 -0
  60. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_trace.py +1 -3
  61. interpkit-0.1.0/PKG-INFO +0 -295
  62. interpkit-0.1.0/README.md +0 -258
  63. interpkit-0.1.0/interpkit/cli/main.py +0 -337
  64. interpkit-0.1.0/interpkit/core/discovery.py +0 -228
  65. interpkit-0.1.0/interpkit/core/html.py +0 -375
  66. interpkit-0.1.0/interpkit/core/model.py +0 -551
  67. interpkit-0.1.0/interpkit/core/plot.py +0 -352
  68. interpkit-0.1.0/interpkit/core/render.py +0 -465
  69. interpkit-0.1.0/interpkit/ops/attention.py +0 -234
  70. interpkit-0.1.0/interpkit/ops/attribute.py +0 -206
  71. interpkit-0.1.0/interpkit/ops/diff.py +0 -79
  72. interpkit-0.1.0/interpkit/ops/lens.py +0 -151
  73. interpkit-0.1.0/interpkit/ops/patch.py +0 -112
  74. interpkit-0.1.0/interpkit/ops/sae.py +0 -212
  75. interpkit-0.1.0/interpkit/ops/steer.py +0 -118
  76. interpkit-0.1.0/interpkit/ops/trace.py +0 -182
  77. interpkit-0.1.0/interpkit.egg-info/PKG-INFO +0 -295
  78. interpkit-0.1.0/tests/test_discovery.py +0 -61
  79. interpkit-0.1.0/tests/test_sae.py +0 -115
  80. {interpkit-0.1.0 → interpkit-0.3.0}/LICENSE +0 -0
  81. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/cli/__init__.py +0 -0
  82. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/core/__init__.py +0 -0
  83. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit/ops/__init__.py +0 -0
  84. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/dependency_links.txt +0 -0
  85. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/entry_points.txt +0 -0
  86. {interpkit-0.1.0 → interpkit-0.3.0}/interpkit.egg-info/top_level.txt +0 -0
  87. {interpkit-0.1.0 → interpkit-0.3.0}/setup.cfg +0 -0
  88. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_ablate.py +0 -0
  89. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_attribute.py +0 -0
  90. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_cache.py +0 -0
  91. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_html.py +0 -0
  92. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_lens.py +0 -0
  93. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_patch.py +0 -0
  94. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_plots.py +0 -0
  95. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_probe.py +0 -0
  96. {interpkit-0.1.0 → interpkit-0.3.0}/tests/test_steer.py +0 -0
@@ -0,0 +1,509 @@
1
+ Metadata-Version: 2.4
2
+ Name: interpkit
3
+ Version: 0.3.0
4
+ Summary: Mech interp for any HuggingFace model.
5
+ Author: Davide Zani
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/z4nix/interpkit
8
+ Project-URL: Repository, https://github.com/z4nix/interpkit
9
+ Project-URL: Issues, https://github.com/z4nix/interpkit/issues
10
+ Keywords: mechanistic-interpretability,pytorch,transformers,mech-interp,interpretability
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Requires-Python: >=3.10
20
+ Description-Content-Type: text/markdown
21
+ License-File: LICENSE
22
+ Requires-Dist: torch>=2.1
23
+ Requires-Dist: transformers>=4.36
24
+ Requires-Dist: safetensors>=0.4
25
+ Requires-Dist: rich>=13.0
26
+ Requires-Dist: rich-gradient>=0.3
27
+ Requires-Dist: typer>=0.9
28
+ Requires-Dist: Pillow>=10.0
29
+ Requires-Dist: matplotlib>=3.8
30
+ Requires-Dist: huggingface-hub>=0.20
31
+ Provides-Extra: vision
32
+ Requires-Dist: torchvision>=0.16; extra == "vision"
33
+ Provides-Extra: probe
34
+ Requires-Dist: scikit-learn>=1.3; extra == "probe"
35
+ Provides-Extra: dev
36
+ Requires-Dist: pytest>=7.0; extra == "dev"
37
+ Requires-Dist: pytest-timeout>=2.2; extra == "dev"
38
+ Requires-Dist: pytest-cov>=5.0; extra == "dev"
39
+ Requires-Dist: scikit-learn>=1.3; extra == "dev"
40
+ Requires-Dist: torchvision>=0.16; extra == "dev"
41
+ Requires-Dist: ruff>=0.4; extra == "dev"
42
+ Requires-Dist: mypy>=1.8; extra == "dev"
43
+ Provides-Extra: docs
44
+ Requires-Dist: mkdocs>=1.5; extra == "docs"
45
+ Requires-Dist: mkdocs-material>=9.5; extra == "docs"
46
+ Requires-Dist: mkdocstrings[python]>=0.24; extra == "docs"
47
+ Dynamic: license-file
48
+
49
+ <p align="center">
50
+ <img src="assets/logo.svg" alt="InterpKit" width="680">
51
+ </p>
52
+
53
+ [![PyPI version](https://img.shields.io/pypi/v/interpkit.svg)](https://pypi.org/project/interpkit/)
54
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
55
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
56
+
57
+ ---
58
+
59
+ ## Why InterpKit?
60
+
61
+ Mechanistic interpretability tooling today is fragmented. Each library supports a narrow set of architectures, and moving to a different model family usually means rewriting hook code from scratch.
62
+
63
+ InterpKit provides a single, consistent interface for mech interp operations across any HuggingFace model — transformers, SSMs, vision models, and more — with no manual annotation required.
64
+
65
+ ---
66
+
67
+ ## Install
68
+
69
+ ```bash
70
+ pip install interpkit
71
+
72
+ # For linear probe support:
73
+ pip install "interpkit[probe]"
74
+ ```
75
+
76
+ Or install from source for development:
77
+
78
+ ```bash
79
+ git clone https://github.com/z4nix/interpkit.git
80
+ cd interpkit
81
+ pip install -e ".[dev]"
82
+ ```
83
+
84
+ ---
85
+
86
+ ## Quickstart
87
+
88
+ ```python
89
+ import interpkit
90
+
91
+ model = interpkit.load("gpt2")
92
+
93
+ # One-command model overview — runs DLA, logit lens, attention, attribution
94
+ # and surfaces the most interesting findings automatically
95
+ model.scan("The capital of France is")
96
+
97
+ # Or run individual operations:
98
+ model.inspect() # module tree
99
+ model.dla("The capital of France is") # direct logit attribution
100
+ model.trace("...Paris...", "...Rome...", top_k=20) # causal tracing
101
+ model.lens("The capital of France is") # logit lens (all positions)
102
+ model.attribute("The capital of France is") # gradient saliency
103
+ model.decompose("The capital of France is") # residual stream decomposition
104
+ ```
105
+
106
+ Works the same on any HF architecture:
107
+
108
+ ```python
109
+ model = interpkit.load("state-spaces/mamba-370m")
110
+ model = interpkit.load("google/vit-base-patch16-224")
111
+ model = interpkit.load("bert-base-uncased")
112
+ ```
113
+
114
+ ---
115
+
116
+ ## Operations
117
+
118
+ | Operation | What it does | Works on |
119
+ |-----------|-------------|----------|
120
+ | **`scan`** | One-command model overview: runs DLA, lens, attention, attribution and surfaces key findings | LMs |
121
+ | **`dla`** | Direct Logit Attribution — decompose output logits by head and MLP contribution; optionally decompose through an SAE into per-feature attributions | LMs |
122
+ | `inspect` | Module tree with types, param counts, shapes | Any model |
123
+ | `patch` | Activation patching at a module, head, or position | Any model |
124
+ | `trace` | Causal tracing — module-level or position-aware (Meng et al.) heatmap | Any model |
125
+ | `attribute` | Gradient saliency over inputs (returns scores programmatically) | Any model |
126
+ | `lens` | Logit lens — project activations to vocabulary at all positions | LMs (auto-detected) |
127
+ | `activations` | Extract raw activation tensors at any module | Any model |
128
+ | `head_activations` | Decompose attention output into per-head contributions | Transformers |
129
+ | `ablate` | Zero/mean ablate a component and measure effect | Any model |
130
+ | `attention` | Visualize attention patterns per layer/head | Transformers |
131
+ | `steer` | Extract and apply steering vectors | LMs |
132
+ | `probe` | Linear probe on activations | Any model |
133
+ | `diff` | Compare activations between two models | Any model |
134
+ | `features` | SAE feature decomposition | Any model |
135
+ | **`decompose`** | Residual stream decomposition — per-component norms | Transformers |
136
+ | **`ov_scores`** | OV circuit analysis — W_OV matrix per head | Transformers |
137
+ | **`qk_scores`** | QK circuit analysis — W_QK matrix per head | Transformers |
138
+ | **`composition`** | Q/K/V composition scores between heads in two layers | Transformers |
139
+ | **`find_circuit`** | Automated circuit discovery via iterative ablation | Transformers |
140
+ | **`batch`** | Run any operation over a dataset with result aggregation | Any model |
141
+
142
+ ---
143
+
144
+ ## Scan — One-Command Model Overview
145
+
146
+ The fastest way to understand what a model is doing on an input. Runs DLA, logit lens, attention analysis, and gradient attribution, then surfaces the most interesting findings in a ranked summary:
147
+
148
+ ```python
149
+ model.scan("The capital of France is")
150
+ # Output:
151
+ # Predictions: "the" (8.5%), "now" (4.8%), "a" (4.6%)
152
+ # Key Findings (ranked by significance):
153
+ # 1. Top contributor to "the": L11.attn (+204.701)
154
+ # 2. Top attention head: L11.H0 (+149.850)
155
+ # 3. Most salient input token: "is" (score 12.435)
156
+ # 4. Answer "the" first appears at layer 9/12
157
+
158
+ model.scan("The capital of France is", save="scan") # exports scan_dla.png, scan_lens.png, etc.
159
+ ```
160
+
161
+ ## Direct Logit Attribution (DLA)
162
+
163
+ Answers the fundamental question: *why does the model predict this token?* Decomposes the output logit by component (attention block + MLP per layer) and by individual attention head:
164
+
165
+ ```python
166
+ result = model.dla("The capital of France is")
167
+ # result["contributions"] — per-component logit contributions, sorted
168
+ # result["head_contributions"] — per-head breakdown
169
+ # result["target_token"] — the token being attributed
170
+
171
+ # Attribute a specific token
172
+ model.dla("The capital of France is", token="Paris")
173
+
174
+ # Save a bar chart
175
+ model.dla("The capital of France is", save="dla.png")
176
+
177
+ # Feature-level DLA — decompose a component through an SAE
178
+ # to see which individual features drive the prediction
179
+ result = model.dla(
180
+ "The capital of France is",
181
+ sae="jbloom/GPT2-Small-SAEs-Reformatted",
182
+ sae_at="transformer.h.11.attn",
183
+ )
184
+ # result["feature_contributions"]["features"]
185
+ # — per-feature logit attributions at the specified component
186
+ ```
187
+
188
+ ## Causal Tracing
189
+
190
+ ```python
191
+ # Module-level tracing (default) — rank modules by causal effect
192
+ model.trace("...Paris...", "...Rome...", top_k=20)
193
+
194
+ # Position-aware tracing (Meng et al. 2022) — (layer x position) heatmap
195
+ model.trace("...Paris...", "...Rome...", mode="position", save="trace.png")
196
+ ```
197
+
198
+ ## Logit Lens
199
+
200
+ Analyses all token positions by default, producing the classic (layers x positions) heatmap:
201
+
202
+ ```python
203
+ model.lens("The capital of France is") # all positions
204
+ model.lens("The capital of France is", position=-1) # last position only
205
+ model.lens("The capital of France is", save="lens.png") # 2D heatmap export
206
+ ```
207
+
208
+ ## Attribution
209
+
210
+ Returns saliency scores programmatically, in addition to printing them:
211
+
212
+ ```python
213
+ result = model.attribute("The capital of France is")
214
+ result["tokens"] # ["The", "capital", "of", "France", "is"]
215
+ result["scores"] # [8.88, 11.15, 7.24, 7.37, 12.43]
216
+ result["target"] # 262
217
+ ```
218
+
219
+ ## Activation Patching
220
+
221
+ Supports module-level, head-level, and position-level patching:
222
+
223
+ ```python
224
+ # Module-level (original)
225
+ model.patch(clean, corrupted, at="transformer.h.8.mlp")
226
+
227
+ # Head-level — patch only attention head 3
228
+ model.patch(clean, corrupted, at="transformer.h.8", head=3)
229
+
230
+ # Position-level — patch only positions 3 and 4
231
+ model.patch(clean, corrupted, at="transformer.h.8", positions=[3, 4])
232
+
233
+ # Combined — patch head 3 at positions 3 and 4
234
+ model.patch(clean, corrupted, at="transformer.h.8", head=3, positions=[3, 4])
235
+ ```
236
+
237
+ ## Head-Level Activations
238
+
239
+ Decompose an attention module's output into per-head contributions, optionally projected through W_O into residual-stream space:
240
+
241
+ ```python
242
+ result = model.head_activations("The capital of France is", at="transformer.h.8")
243
+ result["head_acts"] # tensor (num_heads, batch, seq, d_model)
244
+ result["num_heads"] # 12
245
+ result["head_dim"] # 64
246
+ ```
247
+
248
+ ## Activations, Ablation, Attention
249
+
250
+ ```python
251
+ # Extract raw activations
252
+ act = model.activations("The capital of France is", at="transformer.h.8.mlp")
253
+ acts = model.activations("...", at=["transformer.h.0", "transformer.h.8.mlp"])
254
+
255
+ # Ablation — zero or mean
256
+ result = model.ablate("The capital of France is", at="transformer.h.8.mlp")
257
+ result = model.ablate("...", at="transformer.h.8.mlp", method="mean")
258
+
259
+ # Attention patterns
260
+ model.attention("The capital of France is") # all layers
261
+ model.attention("The capital of France is", layer=8, head=3) # single head
262
+ ```
263
+
264
+ ## Residual Stream Decomposition
265
+
266
+ Break down the residual stream at any position into contributions from each attention block and MLP:
267
+
268
+ ```python
269
+ result = model.decompose("The capital of France is")
270
+ # result["components"] — list of {"name": "L8.attn", "type": "attn", "norm": 8.94, ...}
271
+ # result["residual"] — final residual stream vector
272
+ ```
273
+
274
+ ## OV / QK Circuit Analysis
275
+
276
+ Analyse the effective weight matrices of attention heads:
277
+
278
+ ```python
279
+ # OV circuit: what does each head write to the residual stream?
280
+ model.ov_scores(layer=8)
281
+ # Per-head Frobenius norm, top singular values, approximate rank of W_OV
282
+
283
+ # QK circuit: what does each head attend to?
284
+ model.qk_scores(layer=8)
285
+
286
+ # Composition: how much does head j in layer 4 compose with head i in layer 8?
287
+ model.composition(src_layer=4, dst_layer=8, comp_type="q") # Q-composition
288
+ model.composition(src_layer=4, dst_layer=8, comp_type="k") # K-composition
289
+ model.composition(src_layer=4, dst_layer=8, comp_type="v") # V-composition
290
+ ```
291
+
292
+ ## Circuit Discovery
293
+
294
+ Automatically find the minimal set of components that explain a behaviour:
295
+
296
+ ```python
297
+ circuit = model.find_circuit(
298
+ "The Eiffel Tower is in Paris",
299
+ "The Eiffel Tower is in Rome",
300
+ threshold=0.05,
301
+ )
302
+ # circuit["circuit"] — components in the circuit, sorted by effect
303
+ # circuit["excluded"] — components not in the circuit
304
+ # circuit["verification"] — faithfulness check (how much output is preserved
305
+ # when all non-circuit components are ablated)
306
+ ```
307
+
308
+ ## Batch / Dataset Operations
309
+
310
+ Run any operation over a dataset of examples with automatic result aggregation:
311
+
312
+ ```python
313
+ # Generic batch runner
314
+ results = model.batch("trace", [
315
+ {"clean": "...Paris...", "corrupted": "...Rome..."},
316
+ {"clean": "...Berlin...", "corrupted": "...Madrid..."},
317
+ ], op_kwargs={"top_k": 10})
318
+ # results["summary"]["ranked_modules"] — modules ranked by mean effect across examples
319
+
320
+ # Convenience: trace over a dataset
321
+ results = model.trace_batch(dataset, clean_col="clean", corrupted_col="corrupted")
322
+
323
+ # Convenience: DLA over a list of texts
324
+ results = model.dla_batch(["The capital of France is", "The CEO of Apple is"])
325
+ # results["summary"]["ranked_components"] — components ranked by mean contribution
326
+ ```
327
+
328
+ ## Steering
329
+
330
+ ```python
331
+ vector = model.steer_vector("Love", "Hate", at="transformer.h.8")
332
+ model.steer("The weather today is", vector=vector, at="transformer.h.8", scale=2.0)
333
+ ```
334
+
335
+ ## Linear Probe
336
+
337
+ ```python
338
+ result = model.probe(
339
+ texts=["The cat sat", "The dog ran", "A bird flew", "A fish swam"],
340
+ labels=[0, 0, 1, 1],
341
+ at="transformer.h.8",
342
+ )
343
+ print(result["accuracy"])
344
+ ```
345
+
346
+ ## Model Diff
347
+
348
+ ```python
349
+ base = interpkit.load("gpt2")
350
+ finetuned = interpkit.load("my-finetuned-gpt2")
351
+ interpkit.diff(base, finetuned, "The capital of France is")
352
+ ```
353
+
354
+ ## SAE Features
355
+
356
+ Decompose activations into interpretable features using pre-trained Sparse Autoencoders:
357
+
358
+ ```python
359
+ # From HuggingFace
360
+ model.features(
361
+ "The capital of France is",
362
+ at="transformer.h.8",
363
+ sae="jbloom/GPT2-Small-SAEs-Reformatted",
364
+ )
365
+
366
+ # From a local file (.safetensors or .pt)
367
+ model.features(
368
+ "The capital of France is",
369
+ at="transformer.h.8",
370
+ sae="/path/to/sae_weights.safetensors",
371
+ )
372
+ ```
373
+
374
+ No SAELens dependency — weights are loaded directly via `safetensors`.
375
+
376
+ ## Activation Cache
377
+
378
+ Avoid redundant forward passes when exploring the same input with multiple operations:
379
+
380
+ ```python
381
+ model.cache("The capital of France is") # one forward pass, cache all layers
382
+ model.activations("The capital of France is", at="transformer.h.8.mlp") # instant
383
+ model.activations("The capital of France is", at="transformer.h.0.mlp") # instant
384
+
385
+ model.clear_cache() # free memory
386
+ ```
387
+
388
+ ---
389
+
390
+ ## Visualizations
391
+
392
+ Pass `save="path.png"` to export a static matplotlib figure, or `html="path.html"` for an interactive visualization:
393
+
394
+ ```python
395
+ model.attention("hello world", layer=0, head=0, save="attention.png")
396
+ model.trace("...Paris...", "...Rome...", save="trace.png")
397
+ model.trace("...Paris...", "...Rome...", mode="position", save="position_trace.png")
398
+ model.lens("The capital of France is", save="lens.png")
399
+ model.steer("The weather is", vector=vector, at="transformer.h.8", save="steer.png")
400
+ model.attribute("The capital of France is", save="attribution.png")
401
+ model.dla("The capital of France is", save="dla.png")
402
+ model.scan("The capital of France is", save="scan")
403
+ interpkit.diff(base, finetuned, "...", save="diff.png")
404
+
405
+ # Interactive HTML — self-contained files with hover tooltips, filters, and sliders
406
+ model.attention("hello world", html="attention.html")
407
+ model.trace("...Paris...", "...Rome...", html="trace.html")
408
+ model.attribute("The capital of France is", html="attribution.html")
409
+ ```
410
+
411
+ ---
412
+
413
+ ## CLI
414
+
415
+ ```bash
416
+ interpkit inspect gpt2
417
+ interpkit scan gpt2 "The capital of France is"
418
+ interpkit dla gpt2 "The capital of France is"
419
+ interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --top-k 20
420
+ interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --mode position --save trace.png
421
+ interpkit lens gpt2 "The capital of France is"
422
+ interpkit lens gpt2 "The capital of France is" --position -1
423
+ interpkit attention gpt2 "The capital of France is" --layer 8 --save attention.png
424
+ interpkit attribute gpt2 "The capital of France is"
425
+ interpkit steer gpt2 "The weather is" --positive Love --negative Hate --at transformer.h.8
426
+ interpkit ablate gpt2 "The capital of France is" --at transformer.h.8.mlp
427
+ interpkit decompose gpt2 "The capital of France is"
428
+ interpkit diff gpt2 my-finetuned-gpt2 "The capital of France is" --save diff.png
429
+ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jbloom/GPT2-Small-SAEs-Reformatted
430
+ interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae ./my_sae.safetensors
431
+ interpkit dla gpt2 "The capital of France is" --sae jbloom/GPT2-Small-SAEs-Reformatted --sae-at transformer.h.11.attn
432
+
433
+ # Interactive HTML output
434
+ interpkit attention gpt2 "hello world" --html attention.html
435
+ interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --html trace.html
436
+ interpkit attribute gpt2 "The capital of France is" --html attribution.html
437
+
438
+ # Vision models — auto-preprocessed
439
+ interpkit attribute microsoft/resnet-50 cat.jpg --target 281
440
+ ```
441
+
442
+ Run `interpkit` with no arguments for a full command reference.
443
+
444
+ ---
445
+
446
+ ## TransformerLens interop
447
+
448
+ Already using TransformerLens? Pass your `HookedTransformer` directly into InterpKit — it auto-detects the model and extracts the tokenizer:
449
+
450
+ ```python
451
+ from transformer_lens import HookedTransformer
452
+ import interpkit
453
+
454
+ tl_model = HookedTransformer.from_pretrained("gpt2")
455
+ model = interpkit.load(tl_model)
456
+
457
+ # All InterpKit operations work on TL models
458
+ model.scan("The capital of France is")
459
+ model.dla("The capital of France is")
460
+ model.trace("The Eiffel Tower is in Paris", "The Eiffel Tower is in Rome", top_k=20)
461
+ model.attention("The capital of France is", save="attention.png")
462
+ vector = model.steer_vector("Love", "Hate", at="blocks.8")
+ model.steer("The weather is", vector=vector, at="blocks.8", scale=2.0)
463
+ ```
464
+
465
+ Translate between native and TL hook point names:
466
+
467
+ ```python
468
+ interpkit.to_tl_name("transformer.h.8.mlp") # -> "blocks.8.mlp"
469
+ interpkit.to_native_name("blocks.8.attn", model.arch_info) # -> "transformer.h.8.attn"
470
+ interpkit.list_tl_hooks(tl_model) # -> ["blocks.0.hook_resid_pre", ...]
471
+ ```
472
+
473
+ ---
474
+
475
+ ## Local models
476
+
477
+ ```python
478
+ import torch.nn as nn
479
+ import interpkit
480
+
481
+ my_model = MyCustomModel()
482
+ interpkit.register(my_model, layers=["blocks.0", "blocks.1"], output_head="head")
483
+ model = interpkit.load(my_model, tokenizer=my_tokenizer)
484
+ model.trace(input_a, input_b, top_k=10)
485
+ ```
486
+
487
+ ---
488
+
489
+ ## Examples
490
+
491
+ See the [`examples/`](examples/) directory for Jupyter notebooks:
492
+
493
+ | Notebook | Topics |
494
+ |----------|--------|
495
+ | `01_quickstart` | Inspect, scan, DLA, trace, lens, attribution, patching, ablation |
496
+ | `02_attention_patterns` | Per-head heatmaps, layer filtering, HTML export |
497
+ | `03_steering_vectors` | Extract and apply steering vectors at different layers/scales |
498
+ | `04_sae_features` | Sparse Autoencoder feature decomposition |
499
+ | `05_caching_and_probing` | Activation cache, linear probes across layers |
500
+ | `06_model_comparison` | Diff two models, side-by-side tracing and logit lens |
501
+ | `07_vision_models` | ResNet/ViT attribution, ablation, activations |
502
+ | `08_dla_and_circuits` | DLA, head activations, residual decomposition, OV/QK analysis, composition, circuit discovery |
503
+ | `09_scan_and_batch` | Auto-scan, batch operations, dataset workflows |
504
+
505
+ ---
506
+
507
+ ## License
508
+
509
+ MIT