interpkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interpkit/__init__.py +15 -0
- interpkit/cli/__init__.py +0 -0
- interpkit/cli/main.py +337 -0
- interpkit/core/__init__.py +0 -0
- interpkit/core/discovery.py +228 -0
- interpkit/core/html.py +375 -0
- interpkit/core/inputs.py +117 -0
- interpkit/core/model.py +551 -0
- interpkit/core/plot.py +352 -0
- interpkit/core/registry.py +82 -0
- interpkit/core/render.py +465 -0
- interpkit/core/tl_compat.py +174 -0
- interpkit/ops/__init__.py +0 -0
- interpkit/ops/ablate.py +90 -0
- interpkit/ops/activations.py +67 -0
- interpkit/ops/attention.py +234 -0
- interpkit/ops/attribute.py +206 -0
- interpkit/ops/diff.py +79 -0
- interpkit/ops/inspect.py +14 -0
- interpkit/ops/lens.py +151 -0
- interpkit/ops/patch.py +112 -0
- interpkit/ops/probe.py +128 -0
- interpkit/ops/sae.py +212 -0
- interpkit/ops/steer.py +118 -0
- interpkit/ops/trace.py +182 -0
- interpkit-0.1.0.dist-info/METADATA +295 -0
- interpkit-0.1.0.dist-info/RECORD +31 -0
- interpkit-0.1.0.dist-info/WHEEL +5 -0
- interpkit-0.1.0.dist-info/entry_points.txt +2 -0
- interpkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- interpkit-0.1.0.dist-info/top_level.txt +1 -0
interpkit/ops/steer.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
"""steer — extract and apply steering vectors."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from interpkit.ops.patch import _get_module
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from interpkit.core.model import Model
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def run_steer_vector(
    model: "Model",
    positive: Any,
    negative: Any,
    *,
    at: str,
) -> torch.Tensor:
    """Extract a steering vector: activation(positive) - activation(negative) at module *at*.

    Each input is passed to ``run_activations`` on its own and the resulting
    activation is averaged independently before the two are subtracted.
    Nothing in this function pads or aligns the two inputs.

    Args:
        model: Wrapped model the activations are read from.
        positive: Input whose reduced activation forms the minuend.
        negative: Input whose reduced activation is subtracted.
        at: Name of the module whose activation is read.

    Returns:
        ``reduce(activation(positive)) - reduce(activation(negative))``.
    """
    from interpkit.ops.activations import run_activations

    def _reduce(act: torch.Tensor) -> torch.Tensor:
        """Average one activation according to that tensor's own rank."""
        if act.dim() >= 3:
            # Take the first entry along dim 0, then average over the next dim.
            # NOTE(review): presumably (batch, seq, hidden) -> (hidden,); confirm
            # against what run_activations returns for non-sequence modules.
            return act[0].mean(dim=0)
        if act.dim() == 2:
            return act.mean(dim=0)
        # 0-D / 1-D activations have nothing to average over.
        return act

    pos_act = run_activations(model, positive, at=at, print_stats=False)
    neg_act = run_activations(model, negative, at=at, print_stats=False)

    # Reduce each tensor by its *own* rank so a rank mismatch between the two
    # activations cannot silently apply the wrong reduction to the negative one.
    return _reduce(pos_act) - _reduce(neg_act)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def run_steer(
    model: "Model",
    input_data: Any,
    *,
    vector: torch.Tensor,
    at: str,
    scale: float = 2.0,
    save: str | None = None,
) -> dict[str, Any]:
    """Run inference with and without a steering vector, compare top predictions.

    The input is prepared once and run twice: first unmodified, then with a
    forward hook on module *at* that adds ``scale * vector`` to the module's
    output (or to the first element when the output is a tuple/list).

    Args:
        model: Wrapped model to run.
        input_data: Raw input handed to ``model._prepare``.
        vector: Steering vector added to the module output.
        at: Name of the module to hook.
        scale: Multiplier applied to ``vector`` before it is added.
        save: If given, path passed to ``plot_steer`` as ``save_path``.

    Returns:
        A dict with ``original_logits``, ``steered_logits``, ``original_top``
        and ``steered_top`` (the latter two as produced by ``_top_tokens``).
    """
    from interpkit.core.render import render_steer

    model_input = model._prepare(input_data)

    # 1. Original forward
    original_logits = model._forward(model_input)

    # 2. Steered forward
    target_mod = _get_module(model._model, at)

    def _shift(activation: torch.Tensor) -> torch.Tensor:
        """Add the scaled steering vector, matching the activation's device/dtype."""
        offset = vector.to(activation.device)
        if activation.is_floating_point():
            # Avoid type promotion (e.g. fp16 activation + fp32 vector -> fp32),
            # which would hand downstream layers a dtype they were not run with.
            offset = offset.to(activation.dtype)
        return activation + scale * offset

    def _steer_hook(_mod: torch.nn.Module, _inp: Any, output: Any) -> Any:
        if isinstance(output, torch.Tensor):
            return _shift(output)
        elif isinstance(output, (tuple, list)):
            # Only the first element is steered; the rest pass through unchanged.
            return (_shift(output[0]),) + tuple(output[1:])
        return output

    handle = target_mod.register_forward_hook(_steer_hook)
    try:
        steered_logits = model._forward(model_input)
    finally:
        # Always detach the hook so a failed forward pass cannot leave the
        # model permanently steered.
        handle.remove()

    # Extract top tokens
    original_tokens = _top_tokens(model, original_logits)
    steered_tokens = _top_tokens(model, steered_logits)

    render_steer(original_tokens, steered_tokens, at, scale)

    if save is not None:
        from interpkit.core.plot import plot_steer

        plot_steer(original_tokens, steered_tokens, module_name=at, scale=scale, save_path=save)

    return {
        "original_logits": original_logits,
        "steered_logits": steered_logits,
        "original_top": original_tokens,
        "steered_top": steered_tokens,
    }
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _top_tokens(
    model: "Model",
    logits: torch.Tensor,
    k: int = 10,
) -> list[tuple[str, float]]:
    """Extract top-k predicted tokens from logits.

    A 3-D tensor is indexed as ``[0, -1, :]``, a 2-D tensor as ``[-1, :]``;
    anything else is flattened. The selected row is soft-maxed in float32 and
    the ``k`` most probable entries are returned.

    Args:
        model: Object whose ``_tokenizer`` attribute, when not ``None``, is
            used to decode each token id; otherwise ids are stringified.
        logits: Logits tensor to read predictions from.
        k: Number of entries wanted. Clamped to the number of available
            entries so a small vocabulary does not raise.

    Returns:
        ``(token, probability)`` pairs, most probable first.
    """
    if logits.dim() == 3:
        last_logits = logits[0, -1, :]
    elif logits.dim() == 2:
        last_logits = logits[-1, :]
    else:
        # reshape (not view) so non-contiguous tensors are accepted too.
        last_logits = logits.reshape(-1)

    probs = torch.softmax(last_logits.float(), dim=-1)
    # topk raises when k exceeds the dimension size, so clamp it first.
    top_probs, top_ids = probs.topk(min(k, probs.numel()))

    ids = top_ids.tolist()
    if model._tokenizer is not None:
        tokens = [model._tokenizer.decode([tid]) for tid in ids]
    else:
        tokens = [str(tid) for tid in ids]

    return list(zip(tokens, top_probs.tolist()))
|
interpkit/ops/trace.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""trace — two-phase causal tracing across all modules, ranked by causal effect."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from rich.console import Console
|
|
10
|
+
from rich.progress import Progress
|
|
11
|
+
|
|
12
|
+
from interpkit.ops.patch import _compute_effect, _get_module
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from interpkit.core.model import Model
|
|
16
|
+
|
|
17
|
+
console = Console()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def run_trace(
    model: "Model",
    clean: Any,
    corrupted: Any,
    *,
    top_k: int | None = 20,
    save: str | None = None,
    html: str | None = None,
) -> list[dict[str, Any]]:
    """Two-phase causal tracing.

    Phase 1 (fast proxy): run clean and corrupted forward passes, capture the
    activation norm at every candidate module, rank by absolute norm delta.

    Phase 2 (expensive): for the top-K modules by proxy score, replace the
    module's output in the corrupted run with its cached clean output and
    measure the effect with ``_compute_effect``.

    Args:
        model: Wrapped model to trace.
        clean: Clean input.
        corrupted: Corrupted input; prepared together with ``clean`` via
            ``model._prepare_pair``.
        top_k: Number of modules to patch in phase 2. ``None`` or ``0``
            means every candidate module.
        save: If given, path passed to ``plot_trace`` as ``save_path``.
        html: If given, path the HTML rendering is saved to.

    Returns:
        One dict per patched module with keys ``module``, ``role`` and
        ``effect``, sorted by ``effect`` in descending order.
    """
    from interpkit.core.render import render_trace

    def _first_tensor(output: Any) -> torch.Tensor | None:
        """Return the module output tensor, or None if there is none to use."""
        if isinstance(output, torch.Tensor):
            return output
        # `and output` guards against an empty tuple/list output.
        if isinstance(output, (tuple, list)) and output and isinstance(output[0], torch.Tensor):
            return output[0]
        return None

    def _forward_with_hooks(model_input: Any, hooked: list[tuple[str, Any]]) -> Any:
        """Run one no-grad forward pass with the given (module name, hook) pairs attached."""
        handles = []
        try:
            for mod_name, hook in hooked:
                mod = _get_module(model._model, mod_name)
                handles.append(mod.register_forward_hook(hook))
            with torch.no_grad():
                return model._forward(model_input)
        finally:
            # Detach even when registration or the forward pass raises, so no
            # hook is left behind on the model.
            for handle in handles:
                handle.remove()

    clean_input, corrupted_input = model._prepare_pair(clean, corrupted)

    # Keep only modules that own parameters.
    candidates = [m for m in model.arch_info.modules if m.param_count > 0]
    total_modules = len(candidates)

    if top_k == 0:
        top_k = None

    # ----------------------------------------------------------------
    # Phase 1: fast proxy — activation norm delta
    # ----------------------------------------------------------------
    clean_norms: dict[str, float] = {}
    corrupted_norms: dict[str, float] = {}

    def _make_norm_hook(store: dict[str, float], name: str):
        def hook_fn(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
            t = _first_tensor(output)
            if t is not None:
                store[name] = t.detach().float().norm().item()
        return hook_fn

    # Clean pass
    clean_logits = _forward_with_hooks(
        clean_input,
        [(m.name, _make_norm_hook(clean_norms, m.name)) for m in candidates],
    )

    # Corrupted pass
    corrupted_logits = _forward_with_hooks(
        corrupted_input,
        [(m.name, _make_norm_hook(corrupted_norms, m.name)) for m in candidates],
    )

    # Rank by proxy: absolute norm difference. Modules whose hook recorded
    # nothing count as norm 0.0.
    proxy_scores: list[tuple[str, float]] = [
        (m.name, abs(clean_norms.get(m.name, 0.0) - corrupted_norms.get(m.name, 0.0)))
        for m in candidates
    ]
    proxy_scores.sort(key=lambda x: x[1], reverse=True)

    # Select top-K for expensive phase
    if top_k is not None and top_k < total_modules:
        ranked = proxy_scores[:top_k]
        console.print(
            f"\n Scanning top {top_k} of {total_modules} modules by proxy score."
        )
    else:
        ranked = proxy_scores

    # Ordered and de-duplicated: patching follows proxy rank so runs are
    # reproducible, unlike iterating a set.
    selected_order = list(dict.fromkeys(name for name, _ in ranked))
    selected_names = set(selected_order)

    # ----------------------------------------------------------------
    # Phase 2: full causal patching on selected modules
    # ----------------------------------------------------------------

    # Cache all clean activations in one pass
    clean_cache: dict[str, torch.Tensor] = {}

    def _make_cache_hook(name: str):
        def hook_fn(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
            t = _first_tensor(output)
            if t is not None:
                clean_cache[name] = t.detach().clone()
        return hook_fn

    clean_logits = _forward_with_hooks(
        clean_input,
        [(m.name, _make_cache_hook(m.name)) for m in candidates if m.name in selected_names],
    )

    def _make_patch_hook(cached: torch.Tensor):
        def hook_fn(_mod: torch.nn.Module, _inp: Any, output: Any) -> Any:
            if isinstance(output, torch.Tensor):
                return cached
            elif isinstance(output, (tuple, list)):
                # Swap only the first element; keep the remaining outputs.
                return (cached,) + tuple(output[1:])
            return output
        return hook_fn

    # Patch each selected module one at a time
    results: list[dict[str, Any]] = []
    module_role_map = {m.name: m.role for m in candidates}

    with Progress(console=console, transient=True) as progress:
        task = progress.add_task("Causal tracing", total=len(selected_order))
        for name in selected_order:
            cached = clean_cache.get(name)
            if cached is None:
                # The module produced no tensor output in the clean pass.
                progress.advance(task)
                continue

            patched_logits = _forward_with_hooks(
                corrupted_input, [(name, _make_patch_hook(cached))]
            )

            effect = _compute_effect(clean_logits, corrupted_logits, patched_logits)
            results.append({
                "module": name,
                "role": module_role_map.get(name),
                "effect": effect,
            })
            progress.advance(task)

    results.sort(key=lambda x: x["effect"], reverse=True)

    model_name = model.arch_info.arch_family or "model"
    render_trace(results, model_name, total_modules, top_k)

    if save is not None:
        from interpkit.core.plot import plot_trace

        plot_trace(results, model_name=model_name, save_path=save)

    if html is not None:
        from interpkit.core.html import html_trace as gen_html_trace
        from interpkit.core.html import save_html

        save_html(gen_html_trace(results), html)

    return results
|
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: interpkit
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Mech interp for any HuggingFace model.
|
|
5
|
+
Author: Davide Zani
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/davidezani/InterpKit
|
|
8
|
+
Project-URL: Repository, https://github.com/davidezani/InterpKit
|
|
9
|
+
Project-URL: Issues, https://github.com/davidezani/InterpKit/issues
|
|
10
|
+
Keywords: mechanistic-interpretability,pytorch,transformers,mech-interp,interpretability
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: Programming Language :: Python :: 3
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Requires-Python: >=3.10
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
License-File: LICENSE
|
|
22
|
+
Requires-Dist: torch>=2.1
|
|
23
|
+
Requires-Dist: transformers>=4.36
|
|
24
|
+
Requires-Dist: nnsight>=0.3
|
|
25
|
+
Requires-Dist: rich>=13.0
|
|
26
|
+
Requires-Dist: typer>=0.9
|
|
27
|
+
Requires-Dist: Pillow>=10.0
|
|
28
|
+
Requires-Dist: matplotlib>=3.8
|
|
29
|
+
Requires-Dist: huggingface-hub>=0.20
|
|
30
|
+
Provides-Extra: probe
|
|
31
|
+
Requires-Dist: scikit-learn>=1.3; extra == "probe"
|
|
32
|
+
Provides-Extra: dev
|
|
33
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
34
|
+
Requires-Dist: pytest-timeout>=2.2; extra == "dev"
|
|
35
|
+
Requires-Dist: scikit-learn>=1.3; extra == "dev"
|
|
36
|
+
Dynamic: license-file
|
|
37
|
+
|
|
38
|
+
# InterpKit
|
|
39
|
+
|
|
40
|
+
> Mech interp for any HuggingFace model.
|
|
41
|
+
|
|
42
|
+
[PyPI](https://pypi.org/project/interpkit/)
|
|
43
|
+
[License: MIT](https://opensource.org/licenses/MIT)
|
|
44
|
+
[Python 3.10+](https://www.python.org/downloads/)
|
|
45
|
+
|
|
46
|
+
---
|
|
47
|
+
|
|
48
|
+
## The Problem
|
|
49
|
+
|
|
50
|
+
TransformerLens is excellent — but only works on GPT-style decoder-only transformers. The moment you step outside that (Mamba, SSMs, ViT, CNNs, BERT, T5, MoE models), there is no equivalent tool. You write hook code from scratch every time.
|
|
51
|
+
|
|
52
|
+
InterpKit fills this gap: the same standard mech interp operations, on any HuggingFace model, with no annotation required.
|
|
53
|
+
|
|
54
|
+
---
|
|
55
|
+
|
|
56
|
+
## Install
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
pip install interpkit
|
|
60
|
+
|
|
61
|
+
# For linear probe support:
|
|
62
|
+
pip install interpkit[probe]
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
Or install from source for development:
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
git clone https://github.com/davidezani/InterpKit.git
|
|
69
|
+
cd InterpKit
|
|
70
|
+
pip install -e ".[dev]"
|
|
71
|
+
```
|
|
72
|
+
|
|
73
|
+
---
|
|
74
|
+
|
|
75
|
+
## Quickstart
|
|
76
|
+
|
|
77
|
+
```python
|
|
78
|
+
import interpkit
|
|
79
|
+
|
|
80
|
+
model = interpkit.load("gpt2")
|
|
81
|
+
|
|
82
|
+
model.inspect() # module tree with roles, params, shapes
|
|
83
|
+
model.trace("...Paris...", "...Rome...", top_k=20) # causal tracing
|
|
84
|
+
model.patch("...Paris...", "...Rome...", at="transformer.h.8.mlp")
|
|
85
|
+
model.lens("The capital of France is") # logit lens
|
|
86
|
+
model.attribute("The capital of France is") # gradient saliency
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
Works the same on any HF architecture:
|
|
90
|
+
|
|
91
|
+
```python
|
|
92
|
+
model = interpkit.load("state-spaces/mamba-370m")
|
|
93
|
+
model = interpkit.load("google/vit-base-patch16-224")
|
|
94
|
+
model = interpkit.load("bert-base-uncased")
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
---
|
|
98
|
+
|
|
99
|
+
## Operations
|
|
100
|
+
|
|
101
|
+
| Operation | What it does | Works on |
|
|
102
|
+
|-----------|-------------|----------|
|
|
103
|
+
| `inspect` | Module tree with types, param counts, shapes | Any model |
|
|
104
|
+
| `patch` | Activation patching at a named module | Any model |
|
|
105
|
+
| `trace` | Causal tracing across modules, ranked by effect | Any model |
|
|
106
|
+
| `attribute` | Gradient saliency over inputs | Any model |
|
|
107
|
+
| `lens` | Logit lens — project activations to vocabulary | LMs (auto-detected) |
|
|
108
|
+
| `activations` | Extract raw activation tensors at any module | Any model |
|
|
109
|
+
| `ablate` | Zero/mean ablate a component and measure effect | Any model |
|
|
110
|
+
| `attention` | Visualize attention patterns per layer/head | Transformers |
|
|
111
|
+
| `steer` | Extract and apply steering vectors | Any model |
|
|
112
|
+
| `probe` | Linear probe on activations | Any model |
|
|
113
|
+
| `diff` | Compare activations between two models | Any model |
|
|
114
|
+
| `features` | SAE feature decomposition | Any model |
|
|
115
|
+
|
|
116
|
+
---
|
|
117
|
+
|
|
118
|
+
## Activations, Ablation, Attention
|
|
119
|
+
|
|
120
|
+
```python
|
|
121
|
+
# Extract raw activations
|
|
122
|
+
act = model.activations("The capital of France is", at="transformer.h.8.mlp")
|
|
123
|
+
acts = model.activations("...", at=["transformer.h.0", "transformer.h.8.mlp"])
|
|
124
|
+
|
|
125
|
+
# Ablation — zero or mean
|
|
126
|
+
result = model.ablate("The capital of France is", at="transformer.h.8.mlp")
|
|
127
|
+
result = model.ablate("...", at="transformer.h.8.mlp", method="mean")
|
|
128
|
+
|
|
129
|
+
# Attention patterns
|
|
130
|
+
model.attention("The capital of France is") # all layers
|
|
131
|
+
model.attention("The capital of France is", layer=8, head=3) # single head
|
|
132
|
+
```
|
|
133
|
+
|
|
134
|
+
## Steering
|
|
135
|
+
|
|
136
|
+
```python
|
|
137
|
+
# 1. Extract a steering vector
|
|
138
|
+
vector = model.steer_vector("Love", "Hate", at="transformer.h.8")
|
|
139
|
+
|
|
140
|
+
# 2. Apply during inference — side-by-side comparison
|
|
141
|
+
model.steer("The weather today is", vector=vector, at="transformer.h.8", scale=2.0)
|
|
142
|
+
```
|
|
143
|
+
|
|
144
|
+
## Linear Probe
|
|
145
|
+
|
|
146
|
+
```python
|
|
147
|
+
result = model.probe(
|
|
148
|
+
texts=["The cat sat", "The dog ran", "A bird flew", "A fish swam"],
|
|
149
|
+
labels=[0, 0, 1, 1],
|
|
150
|
+
at="transformer.h.8",
|
|
151
|
+
)
|
|
152
|
+
print(result["accuracy"])
|
|
153
|
+
```
|
|
154
|
+
|
|
155
|
+
## Model Diff
|
|
156
|
+
|
|
157
|
+
```python
|
|
158
|
+
base = interpkit.load("gpt2")
|
|
159
|
+
finetuned = interpkit.load("my-finetuned-gpt2")
|
|
160
|
+
interpkit.diff(base, finetuned, "The capital of France is")
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
## SAE Features
|
|
164
|
+
|
|
165
|
+
Decompose activations into interpretable features using pre-trained Sparse Autoencoders from HuggingFace:
|
|
166
|
+
|
|
167
|
+
```python
|
|
168
|
+
model.features(
|
|
169
|
+
"The capital of France is",
|
|
170
|
+
at="transformer.h.8",
|
|
171
|
+
sae="jbloom/GPT2-Small-SAEs-Reformatted",
|
|
172
|
+
)
|
|
173
|
+
```
|
|
174
|
+
|
|
175
|
+
No SAELens dependency — weights are loaded directly via `safetensors`.
|
|
176
|
+
|
|
177
|
+
## Activation Cache
|
|
178
|
+
|
|
179
|
+
Avoid redundant forward passes when exploring the same input with multiple operations:
|
|
180
|
+
|
|
181
|
+
```python
|
|
182
|
+
model.cache("The capital of France is") # one forward pass, cache all layers
|
|
183
|
+
model.activations("The capital of France is", at="transformer.h.8.mlp") # instant
|
|
184
|
+
model.activations("The capital of France is", at="transformer.h.0.mlp") # instant
|
|
185
|
+
|
|
186
|
+
model.clear_cache() # free memory
|
|
187
|
+
```
|
|
188
|
+
|
|
189
|
+
---
|
|
190
|
+
|
|
191
|
+
## Visualizations
|
|
192
|
+
|
|
193
|
+
Pass `save="path.png"` to export a static matplotlib figure, or `html="path.html"` for an interactive visualization:
|
|
194
|
+
|
|
195
|
+
```python
|
|
196
|
+
model.attention("hello world", layer=0, head=0, save="attention.png")
|
|
197
|
+
model.trace("...Paris...", "...Rome...", save="trace.png")
|
|
198
|
+
model.lens("The capital of France is", save="lens.png")
|
|
199
|
+
model.steer("The weather is", vector=vector, at="transformer.h.8", save="steer.png")
|
|
200
|
+
model.attribute("The capital of France is", save="attribution.png")
|
|
201
|
+
interpkit.diff(base, finetuned, "...", save="diff.png")
|
|
202
|
+
|
|
203
|
+
# Interactive HTML — self-contained files with hover tooltips, filters, and sliders
|
|
204
|
+
model.attention("hello world", html="attention.html")
|
|
205
|
+
model.trace("...Paris...", "...Rome...", html="trace.html")
|
|
206
|
+
model.attribute("The capital of France is", html="attribution.html")
|
|
207
|
+
```
|
|
208
|
+
|
|
209
|
+
---
|
|
210
|
+
|
|
211
|
+
## CLI
|
|
212
|
+
|
|
213
|
+
```bash
|
|
214
|
+
interpkit inspect gpt2
|
|
215
|
+
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --top-k 20
|
|
216
|
+
interpkit lens gpt2 "The capital of France is"
|
|
217
|
+
interpkit attention gpt2 "The capital of France is" --layer 8 --save attention.png
|
|
218
|
+
interpkit steer gpt2 "The weather is" --positive Love --negative Hate --at transformer.h.8
|
|
219
|
+
interpkit ablate gpt2 "The capital of France is" --at transformer.h.8.mlp
|
|
220
|
+
interpkit diff gpt2 my-finetuned-gpt2 "The capital of France is" --save diff.png
|
|
221
|
+
interpkit features gpt2 "The capital of France is" --at transformer.h.8 --sae jbloom/GPT2-Small-SAEs-Reformatted
|
|
222
|
+
|
|
223
|
+
# Interactive HTML output
|
|
224
|
+
interpkit attention gpt2 "hello world" --html attention.html
|
|
225
|
+
interpkit trace gpt2 --clean "...Paris..." --corrupted "...Rome..." --html trace.html
|
|
226
|
+
interpkit attribute gpt2 "The capital of France is" --html attribution.html
|
|
227
|
+
|
|
228
|
+
# Vision models — auto-preprocessed
|
|
229
|
+
interpkit attribute microsoft/resnet-50 cat.jpg --target 281
|
|
230
|
+
```
|
|
231
|
+
|
|
232
|
+
Run `interpkit` with no arguments for a full command reference.
|
|
233
|
+
|
|
234
|
+
---
|
|
235
|
+
|
|
236
|
+
## TransformerLens interop
|
|
237
|
+
|
|
238
|
+
Already using TransformerLens? Pass your `HookedTransformer` directly into InterpKit — it auto-detects the model and extracts the tokenizer:
|
|
239
|
+
|
|
240
|
+
```python
|
|
241
|
+
from transformer_lens import HookedTransformer
|
|
242
|
+
import interpkit
|
|
243
|
+
|
|
244
|
+
tl_model = HookedTransformer.from_pretrained("gpt2")
|
|
245
|
+
model = interpkit.load(tl_model)
|
|
246
|
+
|
|
247
|
+
# All InterpKit operations work on TL models
|
|
248
|
+
model.trace("The Eiffel Tower is in Paris", "The Eiffel Tower is in Rome", top_k=20)
|
|
249
|
+
model.attention("The capital of France is", save="attention.png")
|
|
250
|
+
model.steer("The weather is", vector=vector, at="blocks.8", scale=2.0)
|
|
251
|
+
```
|
|
252
|
+
|
|
253
|
+
Translate between native and TL hook point names:
|
|
254
|
+
|
|
255
|
+
```python
|
|
256
|
+
interpkit.to_tl_name("transformer.h.8.mlp") # -> "blocks.8.mlp"
|
|
257
|
+
interpkit.to_native_name("blocks.8.attn", model.arch_info) # -> "transformer.h.8.attn"
|
|
258
|
+
interpkit.list_tl_hooks(tl_model) # -> ["blocks.0.hook_resid_pre", ...]
|
|
259
|
+
```
|
|
260
|
+
|
|
261
|
+
---
|
|
262
|
+
|
|
263
|
+
## Local models
|
|
264
|
+
|
|
265
|
+
```python
|
|
266
|
+
import torch.nn as nn
|
|
267
|
+
import interpkit
|
|
268
|
+
|
|
269
|
+
my_model = MyCustomModel()
|
|
270
|
+
interpkit.register(my_model, layers=["blocks.0", "blocks.1"], output_head="head")
|
|
271
|
+
model = interpkit.load(my_model, tokenizer=my_tokenizer)
|
|
272
|
+
model.trace(input_a, input_b, top_k=10)
|
|
273
|
+
```
|
|
274
|
+
|
|
275
|
+
---
|
|
276
|
+
|
|
277
|
+
## Examples
|
|
278
|
+
|
|
279
|
+
See the [`examples/`](examples/) directory for Jupyter notebooks:
|
|
280
|
+
|
|
281
|
+
| Notebook | Topics |
|
|
282
|
+
|----------|--------|
|
|
283
|
+
| `01_quickstart` | Inspect, trace, lens, attribution, patching, ablation |
|
|
284
|
+
| `02_attention_patterns` | Per-head heatmaps, layer filtering, HTML export |
|
|
285
|
+
| `03_steering_vectors` | Extract and apply steering vectors at different layers/scales |
|
|
286
|
+
| `04_sae_features` | Sparse Autoencoder feature decomposition |
|
|
287
|
+
| `05_caching_and_probing` | Activation cache, linear probes across layers |
|
|
288
|
+
| `06_model_comparison` | Diff two models, side-by-side tracing and logit lens |
|
|
289
|
+
| `07_vision_models` | ResNet/ViT attribution, ablation, activations |
|
|
290
|
+
|
|
291
|
+
---
|
|
292
|
+
|
|
293
|
+
## License
|
|
294
|
+
|
|
295
|
+
MIT
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
interpkit/__init__.py,sha256=cTks5G9HHVi0VFvqKnTrI5lhI-6b6ECuLtk81R20uhU,542
|
|
2
|
+
interpkit/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
|
+
interpkit/cli/main.py,sha256=hsAA33eodAsi0wAawbAyWRRGiiQCseCOJOIhje-KFng,17428
|
|
4
|
+
interpkit/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
|
+
interpkit/core/discovery.py,sha256=hs9nFHSO7Tr7nkAPZSGYi8CEsZ0DLgmPBDl_cJLudVs,7285
|
|
6
|
+
interpkit/core/html.py,sha256=VnIszOvjMKl3erCIiE4pIsQSPg6_qOvlXgtRH1IniUQ,12317
|
|
7
|
+
interpkit/core/inputs.py,sha256=gsF5ljcCV18s3caglwiTz4TSQ6w6kphGrAYQkDgSTV4,4137
|
|
8
|
+
interpkit/core/model.py,sha256=JBZToZ_pd71zES9pl8N3hCNm-KK5ALX0Irys20loUn0,18308
|
|
9
|
+
interpkit/core/plot.py,sha256=B9V4_R-2zWPIFwaL96uqBbrSz3OCC_vKeZcm2wnAyQ0,12051
|
|
10
|
+
interpkit/core/registry.py,sha256=dUJPiVMomk5QSjarKrw-yvNyteuukkjga823yzNN_YM,2162
|
|
11
|
+
interpkit/core/render.py,sha256=REAGSelRybnf-7HYCb7-YCXUKTd5QOmxHv3FKtehnWY,15914
|
|
12
|
+
interpkit/core/tl_compat.py,sha256=44ffZ7__JrgKf4e5UmaQ3fZLIAqjmw1VhjqqA_rD-AQ,6416
|
|
13
|
+
interpkit/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
+
interpkit/ops/ablate.py,sha256=CYgZzuIeP8iS2gK2ax5zZi9jDbG6dtldsMcb7S8B6oo,2828
|
|
15
|
+
interpkit/ops/activations.py,sha256=B4OP3cX23eUSP4oYWle7pxb6AzxVh3-dwBDFWtyvVag,1965
|
|
16
|
+
interpkit/ops/attention.py,sha256=qZxZvMoHmcpr47TbJInTdqrFZSk2PTlgyvdXN4T5bQs,7958
|
|
17
|
+
interpkit/ops/attribute.py,sha256=utqAijEhGpYKpMjufuNgrF2ouLdkwlXGkjAX8_Qubp8,7027
|
|
18
|
+
interpkit/ops/diff.py,sha256=SuC6qEgBgoeq5aM51onA7FHTrYMFHJdDtoULVcgotbU,2516
|
|
19
|
+
interpkit/ops/inspect.py,sha256=VUOamp5ePIIlUHZVwhBLmkr78hCv-aOmCwDyDyGn3KQ,356
|
|
20
|
+
interpkit/ops/lens.py,sha256=f3m8_pxXslntUAAPkftX7xVqlYqweIB2cCInoBs0p64,4782
|
|
21
|
+
interpkit/ops/patch.py,sha256=eiW_2NM73TErsvA63IrxOKhcB5SohF0q-jE0Y6Uckek,3730
|
|
22
|
+
interpkit/ops/probe.py,sha256=A4dTKxrUtilchXK-wd831Ry1bFAbtCXyZNIRh-INPWc,3852
|
|
23
|
+
interpkit/ops/sae.py,sha256=mSIfvzxODvEnHkB1X98hmxP9dCgbrLFeYF6duKEYQyc,6275
|
|
24
|
+
interpkit/ops/steer.py,sha256=1zyG_OHl-5pT9mziklPJsV3Z7nsiIRP9TubuKyv1wcM,3407
|
|
25
|
+
interpkit/ops/trace.py,sha256=kcLhYCyshbmL4B9BcLP4RVedpsHvf0TQAA9KRBMNKQM,6249
|
|
26
|
+
interpkit-0.1.0.dist-info/licenses/LICENSE,sha256=_kFMCpDgee4UrUPtrYy3fB-h9SwKJ_saecsNbQTNXIY,1068
|
|
27
|
+
interpkit-0.1.0.dist-info/METADATA,sha256=DI_IG1xxM8olIv5wi9HN1YbBk81K-Gw8rmJH7MNo1Bg,9714
|
|
28
|
+
interpkit-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
29
|
+
interpkit-0.1.0.dist-info/entry_points.txt,sha256=80W12z7dMTFc6DAU_vVU0MQ0xfhYSSFIlwzl6sIdBVg,53
|
|
30
|
+
interpkit-0.1.0.dist-info/top_level.txt,sha256=41VDlyHKdt6ePGIEKnazC4wsDpj1udXamgEP5haogk4,10
|
|
31
|
+
interpkit-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Davide Zani
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
interpkit
|