interpkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interpkit/__init__.py +15 -0
- interpkit/cli/__init__.py +0 -0
- interpkit/cli/main.py +337 -0
- interpkit/core/__init__.py +0 -0
- interpkit/core/discovery.py +228 -0
- interpkit/core/html.py +375 -0
- interpkit/core/inputs.py +117 -0
- interpkit/core/model.py +551 -0
- interpkit/core/plot.py +352 -0
- interpkit/core/registry.py +82 -0
- interpkit/core/render.py +465 -0
- interpkit/core/tl_compat.py +174 -0
- interpkit/ops/__init__.py +0 -0
- interpkit/ops/ablate.py +90 -0
- interpkit/ops/activations.py +67 -0
- interpkit/ops/attention.py +234 -0
- interpkit/ops/attribute.py +206 -0
- interpkit/ops/diff.py +79 -0
- interpkit/ops/inspect.py +14 -0
- interpkit/ops/lens.py +151 -0
- interpkit/ops/patch.py +112 -0
- interpkit/ops/probe.py +128 -0
- interpkit/ops/sae.py +212 -0
- interpkit/ops/steer.py +118 -0
- interpkit/ops/trace.py +182 -0
- interpkit-0.1.0.dist-info/METADATA +295 -0
- interpkit-0.1.0.dist-info/RECORD +31 -0
- interpkit-0.1.0.dist-info/WHEEL +5 -0
- interpkit-0.1.0.dist-info/entry_points.txt +2 -0
- interpkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- interpkit-0.1.0.dist-info/top_level.txt +1 -0
interpkit/ops/ablate.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""ablate — zero or mean ablate a module and measure the effect on output."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from interpkit.ops.patch import _get_module
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from interpkit.core.model import Model
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def run_ablate(
    model: "Model",
    input_data: Any,
    *,
    at: str,
    method: str = "zero",
) -> dict[str, Any]:
    """Ablate module *at* and measure the effect on output logits.

    Runs one clean forward pass and one forward pass in which a forward hook
    on *at* overwrites the module's tensor output (or the first element of a
    tuple/list output). It then scores the difference with
    ``_compute_ablation_effect``. The result dict is passed to
    ``render_ablate`` and returned.

    Parameters
    ----------
    model:
        Wrapped model. ``model._prepare`` builds the input and
        ``model._forward`` runs the network.
    input_data:
        Raw input accepted by ``model._prepare``.
    at:
        Dotted name of the module to ablate.
    method:
        ``"zero"`` replaces the module output with zeros.
        ``"mean"`` replaces it with the mean activation across the sequence
        dimension (dim ``-2``) for outputs with 3+ dims. Lower-rank outputs
        are averaged over their last dimension instead.

    Returns
    -------
    dict
        Keys ``module``, ``method``, ``effect`` (float), ``clean_logits``
        and ``ablated_logits``.

    Raises
    ------
    ValueError
        If *method* is not ``"zero"`` or ``"mean"``.
    """
    from interpkit.core.render import render_ablate

    # Validate before any forward pass. Raising from inside the hook would
    # abort the ablated pass half-way through.
    if method not in ("zero", "mean"):
        raise ValueError(f"Unknown ablation method: {method!r}. Use 'zero' or 'mean'.")

    model_input = model._prepare(input_data)
    target_mod = _get_module(model._model, at)

    def _ablate_hook(_mod: torch.nn.Module, _inp: Any, output: Any) -> Any:
        if isinstance(output, torch.Tensor):
            t = output
        elif isinstance(output, (tuple, list)) and output and isinstance(output[0], torch.Tensor):
            t = output[0]
        else:
            # No tensor we know how to ablate; leave the output unchanged.
            return output

        if method == "zero":
            replacement = torch.zeros_like(t)
        elif t.dim() >= 3:
            replacement = t.mean(dim=-2, keepdim=True).expand_as(t)
        else:
            replacement = t.mean(dim=-1, keepdim=True).expand_as(t)

        if isinstance(output, torch.Tensor):
            return replacement
        return (replacement,) + tuple(output[1:])

    # Neither pass needs gradients (matches run_activations / run_attention).
    with torch.no_grad():
        # 1. Clean forward: get the baseline logits.
        clean_logits = model._forward(model_input)

        # 2. Ablated forward. The hook is removed even if the forward raises,
        #    so the module is never left permanently ablated.
        handle = target_mod.register_forward_hook(_ablate_hook)
        try:
            ablated_logits = model._forward(model_input)
        finally:
            handle.remove()

    effect = _compute_ablation_effect(clean_logits, ablated_logits)

    result = {
        "module": at,
        "method": method,
        "effect": effect,
        "clean_logits": clean_logits,
        "ablated_logits": ablated_logits,
    }
    render_ablate(result)
    return result
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _compute_ablation_effect(clean: torch.Tensor, ablated: torch.Tensor) -> float:
    """Measure how much ablation changed the output (0 = no change, 1 = max change).

    Both tensors are flattened to ``(-1, size_of_last_dim)``, and only the
    final row (e.g. the last position) is compared. The score is
    ``1 - cosine_similarity`` between the softmax distributions of the two
    rows. Probabilities are non-negative, so the cosine lies in ``[0, 1]``.
    The result is clamped to that range to absorb float rounding.
    """
    # reshape (not view) so non-contiguous logits (sliced/transposed) work.
    clean_last = clean.reshape(-1, clean.shape[-1])[-1:].float()
    ablated_last = ablated.reshape(-1, ablated.shape[-1])[-1:].float()

    clean_probs = torch.softmax(clean_last, dim=-1)
    ablated_probs = torch.softmax(ablated_last, dim=-1)

    cosine_sim = torch.nn.functional.cosine_similarity(clean_probs, ablated_probs, dim=-1)
    return min(1.0, max(0.0, 1.0 - cosine_sim.item()))
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""activations — extract raw activation tensors at any named module."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from interpkit.ops.patch import _get_module
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from interpkit.core.model import Model
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def run_activations(
    model: "Model",
    input_data: Any,
    *,
    at: str | list[str],
    print_stats: bool = True,
) -> dict[str, torch.Tensor] | torch.Tensor:
    """Extract activations at one or more named modules.

    Returns a single tensor if *at* is a string, or a dict if *at* is a list.

    Activations come from ``model._get_cached`` when it returns a hit.
    Otherwise a forward hook is attached to each requested module and one
    no-grad forward pass is run. Each module's tensor output (or the first
    element of a tuple/list output) is stored as a detached clone. Modules
    that produce no tensor are absent from the returned dict.

    Raises
    ------
    RuntimeError
        If *at* is a single name whose module produced no tensor output.
    """
    model_input = model._prepare(input_data)
    single = isinstance(at, str)
    module_names = [at] if single else list(at)

    # Check activation cache first (pass prepared input to avoid re-tokenizing)
    cached = model._get_cached(input_data, module_names, _prepared_input=model_input)
    if cached is not None:
        cache = cached
    else:
        cache = {}

        def _make_hook(name: str):
            def hook_fn(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
                if isinstance(output, torch.Tensor):
                    t = output
                elif isinstance(output, (tuple, list)) and output and isinstance(output[0], torch.Tensor):
                    t = output[0]
                else:
                    return
                cache[name] = t.detach().clone()
            return hook_fn

        hooks = []
        # Registration and the forward pass share one try/finally. Hooks are
        # therefore removed even if a later module name is invalid or the
        # forward raises.
        try:
            for name in module_names:
                mod = _get_module(model._model, name)
                hooks.append(mod.register_forward_hook(_make_hook(name)))

            with torch.no_grad():
                model._forward(model_input)
        finally:
            for h in hooks:
                h.remove()

    if print_stats:
        from interpkit.core.render import render_activations

        render_activations(cache)

    if single:
        if at not in cache:
            raise RuntimeError(f"Module '{at}' produced no tensor output.")
        return cache[at]

    return cache
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""attention — capture and display attention patterns for transformer models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
from rich.console import Console
|
|
11
|
+
|
|
12
|
+
from interpkit.ops.patch import _get_module
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from interpkit.core.model import Model
|
|
16
|
+
|
|
17
|
+
# Module-level rich console; run_attention prints its "not available" /
# "could not compute" warnings through it.
console = Console()
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def run_attention(
    model: "Model",
    input_data: Any,
    *,
    layer: int | None = None,
    head: int | None = None,
    save: str | None = None,
    html: str | None = None,
) -> list[dict[str, Any]] | None:
    """Capture attention weights and display a summary.

    Computes attention weights manually from Q/K projections via hooks,
    since modern transformer implementations (SDPA, FlashAttention) don't
    return attention weights.

    Parameters
    ----------
    layer:
        Only hook attention modules whose name contains ``.<layer>.``.
        Modules whose name has no numeric segment are always kept.
    head:
        Only report this head index.
    save:
        If given, plot the attention pattern(s) to this path.
    html:
        If given, write an HTML view of the results to this path.

    Returns
    -------
    list of dict or None
        One entry per (layer, head), with keys ``layer``, ``head``,
        ``top_pairs``, ``entropy`` and ``weights``. Returns ``None`` when no
        attention modules are detected or no weights could be computed.
    """
    from interpkit.core.render import render_attention

    arch = model.arch_info
    attn_modules = [m for m in arch.modules if m.role == "attention"]

    if not attn_modules:
        console.print(
            "\n [yellow]attention not available:[/yellow] no attention modules detected"
            f" for {arch.arch_family or 'this model'}.\n"
        )
        return None

    model_input = model._prepare(input_data)

    tokens = None
    if model._tokenizer is not None and isinstance(input_data, str):
        encoded = model._tokenizer(input_data, return_tensors="pt")
        token_ids = encoded["input_ids"][0].tolist()
        tokens = model._tokenizer.convert_ids_to_tokens(token_ids)

    # Layer index = first ".<digits>." segment of a module name.
    layer_re = re.compile(r"\.(\d+)\.")

    # attention-module name -> {projection child name -> captured output}
    qk_cache: dict[str, dict[str, torch.Tensor]] = {}

    def _make_qk_hook(name: str, attn_name: str):
        def hook_fn(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
            if isinstance(output, torch.Tensor):
                t = output
            elif isinstance(output, (tuple, list)) and output and isinstance(output[0], torch.Tensor):
                t = output[0]
            else:
                # Skip outputs that carry no tensor instead of crashing the
                # forward pass on ``.detach()``.
                return
            qk_cache.setdefault(attn_name, {})[name] = t.detach()
        return hook_fn

    hooks = []
    try:
        for mod_info in attn_modules:
            if layer is not None:
                layer_match = layer_re.search(mod_info.name)
                if layer_match and int(layer_match.group(1)) != layer:
                    continue

            attn_mod = _get_module(model._model, mod_info.name)

            # Hook the Q/K (or fused QKV) projection submodules.
            for child_name, child_mod in attn_mod.named_modules():
                lowered = child_name.lower()
                is_qkv = any(p in lowered for p in ("c_attn", "qkv", "q_proj", "query"))
                is_k = any(p in lowered for p in ("k_proj", "key"))
                if is_qkv or is_k:
                    hooks.append(
                        child_mod.register_forward_hook(_make_qk_hook(child_name, mod_info.name))
                    )

        with torch.no_grad():
            model._forward(model_input)
    finally:
        # Always detach hooks, even if registration or the forward raises.
        for h in hooks:
            h.remove()

    # Compute attention weights from cached Q/K
    results: list[dict[str, Any]] = []

    for attn_name, projections in qk_cache.items():
        layer_match = layer_re.search(attn_name)
        layer_idx = int(layer_match.group(1)) if layer_match else 0

        # NOTE(review): falls back to 12 heads when arch info has no head count.
        attn_weights = _compute_attention_from_projections(
            projections, arch.num_attention_heads or 12
        )
        if attn_weights is None:
            continue

        num_heads = attn_weights.shape[0]
        for head_idx in range(num_heads):
            if head is not None and head_idx != head:
                continue

            head_attn = attn_weights[head_idx]
            results.append({
                "layer": layer_idx,
                "head": head_idx,
                "top_pairs": _get_top_pairs(head_attn, k=5),
                "entropy": _attention_entropy(head_attn),
                "weights": head_attn,
            })

    if not results:
        console.print(
            "\n [yellow]attention:[/yellow] could not compute attention weights.\n"
        )
        return None

    model_name = arch.arch_family or "model"
    render_attention(results, tokens, model_name)

    if save is not None:
        from interpkit.core.plot import plot_attention, plot_attention_multi

        if layer is not None and head is not None and len(results) == 1:
            plot_attention(results[0]["weights"], tokens, layer=results[0]["layer"],
                           head=results[0]["head"], save_path=save)
        else:
            plot_attention_multi(results, tokens, save_path=save)

    if html is not None:
        from interpkit.core.html import html_attention as gen_html_attention
        from interpkit.core.html import save_html

        # Tensors are converted to nested lists for the HTML generator.
        serializable = []
        for r in results:
            entry = {**r}
            w = r.get("weights")
            if isinstance(w, torch.Tensor):
                entry["weights"] = w.tolist()
            serializable.append(entry)
        save_html(gen_html_attention(serializable, tokens), html)

    return results
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _compute_attention_from_projections(
    projections: dict[str, torch.Tensor],
    num_heads: int,
) -> torch.Tensor | None:
    """Compute attention weights from captured QKV or Q/K projections.

    *projections* maps projection-submodule names to their captured outputs.
    Names are matched case-insensitively, the same way ``run_attention``
    lower-cases names when choosing which submodules to hook. Two layouts
    are recognised:

    * a fused projection whose name contains ``c_attn`` or ``qkv`` (GPT-2
      style), split into equal thirds ``[Q, K, V]`` along the last dim;
    * separate projections named ``q_proj``/``query`` and ``k_proj``/``key``.

    3-D tensors are treated as batched, and only batch element 0 is used.
    Returns ``None`` if neither layout is present.
    """
    def _drop_batch(t: torch.Tensor) -> torch.Tensor:
        return t[0] if t.dim() == 3 else t

    # GPT-2 style: c_attn produces [Q, K, V] concatenated
    for key, tensor in projections.items():
        lowered = key.lower()
        if "c_attn" in lowered or "qkv" in lowered:
            fused = _drop_batch(tensor)
            hidden = fused.shape[-1] // 3
            q, k, _v = fused.split(hidden, dim=-1)
            return _qk_to_attention(q, k, num_heads)

    # Separate Q and K projections
    q_tensor = None
    k_tensor = None
    for key, tensor in projections.items():
        lowered = key.lower()
        if "q_proj" in lowered or "query" in lowered:
            q_tensor = tensor
        elif "k_proj" in lowered or "key" in lowered:
            k_tensor = tensor

    if q_tensor is None or k_tensor is None:
        return None
    return _qk_to_attention(_drop_batch(q_tensor), _drop_batch(k_tensor), num_heads)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _qk_to_attention(
    q: torch.Tensor, k: torch.Tensor, num_heads: int,
) -> torch.Tensor:
    """Compute causal attention weights from Q and K tensors.

    q: (seq_len, num_heads * head_dim)
    k: (seq_len, num_kv_heads * head_dim). num_kv_heads may be smaller than
       num_heads (grouped-/multi-query attention). Each K head is then shared
       by ``num_heads // num_kv_heads`` consecutive query heads.
    Returns: (num_heads, seq_len, seq_len) float32 softmax weights, with
    future positions masked out.

    Only the raw projections are used. Any transform the attention module
    applies after projecting (e.g. rotary position embeddings) is not
    reproduced here.

    Raises ValueError if the widths cannot be split into heads.
    """
    seq_len, hidden = q.shape
    head_dim = hidden // num_heads
    if head_dim == 0 or hidden % num_heads:
        raise ValueError(f"query width {hidden} cannot be split into {num_heads} heads")
    k_len, k_width = k.shape
    if k_width % head_dim:
        raise ValueError(f"key width {k_width} is not a multiple of head_dim {head_dim}")
    num_kv_heads = k_width // head_dim

    # Compute in float32 so softmax is accurate for fp16/bf16 models. reshape
    # (not view) tolerates non-contiguous slices of a fused QKV output.
    q = q.float().reshape(seq_len, num_heads, head_dim).transpose(0, 1)  # (heads, seq, head_dim)
    k = k.float().reshape(k_len, num_kv_heads, head_dim).transpose(0, 1)  # (kv_heads, seq, head_dim)

    if num_kv_heads != num_heads:
        if num_kv_heads == 0 or num_heads % num_kv_heads:
            raise ValueError(
                f"{num_heads} query heads cannot share {num_kv_heads} key heads"
            )
        # Query head h uses key head h // group_size.
        k = k.repeat_interleave(num_heads // num_kv_heads, dim=0)

    scale = head_dim ** 0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) / scale  # (heads, seq, seq)

    # Apply causal mask
    causal_mask = torch.triu(
        torch.ones(seq_len, k_len, dtype=torch.bool, device=q.device), diagonal=1
    )
    scores.masked_fill_(causal_mask.unsqueeze(0), float("-inf"))

    return F.softmax(scores, dim=-1)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _get_top_pairs(
    attn: torch.Tensor, k: int = 5,
) -> list[tuple[int, int, float]]:
    """Find the top-*k* entries of a 2-D attention matrix.

    Returns ``(row, col, score)`` tuples in descending score order. *row*
    indexes the first dimension (the query position, for matrices built as
    ``q @ k^T``) and *col* the last dimension (the attended position).
    At most ``attn.numel()`` pairs are returned; ``k <= 0`` yields ``[]``.
    """
    # reshape (not view) so non-contiguous matrices (e.g. transposed) work.
    flat = attn.reshape(-1)
    n = max(0, min(k, flat.numel()))
    topk_vals, topk_idxs = flat.topk(n)
    num_cols = attn.shape[-1]
    pairs = []
    for val, idx in zip(topk_vals.tolist(), topk_idxs.tolist()):
        row, col = divmod(idx, num_cols)
        pairs.append((row, col, val))
    return pairs
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def _attention_entropy(attn: torch.Tensor) -> float:
    """Return the average Shannon entropy (in nats) of the rows of *attn*.

    Each slice along the last dimension is treated as a probability
    distribution. A tiny epsilon inside the log keeps zero entries finite.
    """
    eps = 1e-10
    p_log_p = attn * torch.log(attn + eps)
    row_entropy = p_log_p.sum(dim=-1).neg()
    return row_entropy.mean().item()
|
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""attribute — gradient saliency over input tokens or pixels."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from interpkit.core.model import Model
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def run_attribute(
    model: "Model",
    input_data: Any,
    *,
    target: int | None = None,
    save: str | None = None,
    html: str | None = None,
) -> None:
    """Compute gradient-based saliency and render the result.

    Dispatch depends on the input kind:

    * ``str`` with an image extension: image attribution. A heatmap is
      written to *save*, or to a default path; *html* is not used.
    * any other ``str``: token attribution for text, optionally saved to
      *save* / *html*.
    * anything else: per-feature attribution on the prepared tensor input;
      *save* and *html* are not used.
    """
    if not isinstance(input_data, str):
        _attribute_tensor(model, input_data, target=target)
    elif _is_image_path(input_data):
        _attribute_image(model, input_data, target=target, save=save)
    else:
        _attribute_text(model, input_data, target=target, save=save, html=html)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _attribute_text(model: "Model", text: str, *, target: int | None, save: str | None = None, html: str | None = None) -> None:
    """Render gradient saliency over the tokens of *text*.

    The token embeddings are computed once and detached into a leaf tensor
    that requires grad. A forward hook substitutes that tensor for the
    embedding layer's output during the model's forward pass. The gradient
    of the target logit (at the last position for 3-D logits) with respect
    to that leaf is taken with ``torch.autograd.grad``, so no ``.grad`` is
    accumulated on the model's parameters. Each token's score is the L2
    norm of its gradient.

    Parameters
    ----------
    target:
        Vocabulary index to attribute. Defaults to the argmax logit.
    save:
        Optional path for an attribution plot.
    html:
        Optional path for an HTML attribution view.

    Raises
    ------
    ValueError
        If the model has no tokenizer.
    RuntimeError
        If no embedding layer is found, or no gradient reaches the embeddings.
    """
    from interpkit.core.render import render_attribution_tokens

    if model._tokenizer is None:
        raise ValueError("No tokenizer available for text attribution.")

    encoded = model._tokenizer(text, return_tensors="pt")
    input_ids = encoded["input_ids"].to(model._device)

    # Get embedding layer
    embed_layer = _find_embedding(model._model)
    if embed_layer is None:
        raise RuntimeError("Could not find embedding layer for gradient attribution.")

    # Leaf tensor we differentiate with respect to.
    embeddings = embed_layer(input_ids).detach().requires_grad_(True)

    # A forward hook that returns a value replaces the module's output.
    # Unlike assigning ``embed_layer.forward``, this leaves no stray instance
    # attribute behind on the module.
    def _substitute_embeddings(_mod: torch.nn.Module, _inp: Any, _out: Any) -> torch.Tensor:
        return embeddings

    handle = embed_layer.register_forward_hook(_substitute_embeddings)
    try:
        model_kwargs = {k: v.to(model._device) for k, v in encoded.items() if k != "input_ids"}
        out = model._model(input_ids, **model_kwargs)
    finally:
        handle.remove()

    logits = out.logits if hasattr(out, "logits") else (out[0] if isinstance(out, (tuple, list)) else out)

    # Pick target: last-position argmax if not specified
    logits_last = logits[0, -1, :] if logits.dim() == 3 else logits[0]
    if target is None:
        target = logits_last.argmax().item()

    # allow_unused=True turns "embeddings not in the graph" into a None
    # gradient, so we can raise our own clearer error.
    (embed_grad,) = torch.autograd.grad(logits_last[target], embeddings, allow_unused=True)
    if embed_grad is None:
        raise RuntimeError("Gradient computation failed — no gradients on embeddings.")

    # Per-token importance: L2 norm of gradient over the embedding dimension
    token_scores = embed_grad[0].norm(dim=-1).tolist()  # (seq_len,)

    tokens = model._tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    render_attribution_tokens(tokens, token_scores)

    if save is not None:
        from interpkit.core.plot import plot_attribution

        plot_attribution(tokens, token_scores, save_path=save)

    if html is not None:
        from interpkit.core.html import html_attribution as gen_html_attribution
        from interpkit.core.html import save_html

        save_html(gen_html_attribution(tokens, token_scores), html)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _attribute_image(model: "Model", image_path: str, *, target: int | None, save: str | None = None) -> None:
    """Compute pixel-gradient saliency for the image at *image_path* and save a heatmap.

    The image is loaded with ``_load_image``, using the model's image
    processor and device. The pixel tensor is detached into a fresh leaf that
    requires grad. The gradient of the target logit with respect to the
    pixels is taken with ``torch.autograd.grad``, so no ``.grad`` is left on
    model parameters or on the loader's tensor. The heatmap is rendered to
    *save*, or to ``attribution_heatmap.png`` when *save* is ``None``.

    Raises
    ------
    RuntimeError
        If no gradient reaches the pixel values.
    """
    from interpkit.core.inputs import _load_image
    from interpkit.core.render import render_attribution_heatmap

    processed = _load_image(
        image_path,
        image_processor=model._image_processor,
        device=model._device,
    )

    if isinstance(processed, dict):
        pixel_key = "pixel_values" if "pixel_values" in processed else list(processed.keys())[0]
        pixel_values = processed[pixel_key].detach().requires_grad_(True)
        out = model._model(**{**processed, pixel_key: pixel_values})
    else:
        pixel_values = processed.detach().requires_grad_(True)
        out = model._model(pixel_values)

    logits = out.logits if hasattr(out, "logits") else (out[0] if isinstance(out, (tuple, list)) else out)

    # Use the first batch element's logits when the output is batched.
    logits_flat = logits[0] if logits.dim() > 1 else logits

    if target is None:
        target = logits_flat.argmax().item()

    (pixel_grad,) = torch.autograd.grad(logits_flat[target], pixel_values, allow_unused=True)
    if pixel_grad is None:
        raise RuntimeError("Gradient computation failed — no gradients on pixel values.")

    out_path = save or "attribution_heatmap.png"
    render_attribution_heatmap(pixel_grad[0], output_path=out_path)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _attribute_tensor(model: "Model", tensor_input: Any, *, target: int | None) -> None:
    """Render per-feature gradient saliency for a raw tensor (or dict) input.

    The input is prepared with ``model._prepare``. For a dict input, the
    first floating-point tensor entry is the one differentiated. That tensor
    is detached into a fresh leaf, so the caller's tensor is never flagged
    ``requires_grad``. The gradient of the target entry of the flattened
    logits is taken with ``torch.autograd.grad``, so parameter ``.grad``
    values are untouched. For gradients with more than one dimension, the
    score of each flattened feature is its L2 norm across dim 0; otherwise
    it is the absolute gradient.

    Raises
    ------
    ValueError
        If a dict input has no floating-point tensor.
    RuntimeError
        If no gradient reaches the input.
    """
    from interpkit.core.render import render_attribution_tokens

    inp = model._prepare(tensor_input)

    if isinstance(inp, dict):
        # Pick the first floating-point tensor entry
        grad_key = next(
            (k for k, v in inp.items() if isinstance(v, torch.Tensor) and v.is_floating_point()),
            None,
        )
        if grad_key is None:
            raise ValueError("No floating-point tensor found in input dict.")
        grad_tensor = inp[grad_key].detach().requires_grad_(True)
        out = model._model(**{**inp, grad_key: grad_tensor})
    else:
        grad_tensor = inp.detach().requires_grad_(True)
        out = model._model(grad_tensor)

    logits = out.logits if hasattr(out, "logits") else (out[0] if isinstance(out, (tuple, list)) else out)
    logits_flat = logits.reshape(-1)

    if target is None:
        target = logits_flat.argmax().item()

    (input_grad,) = torch.autograd.grad(logits_flat[target], grad_tensor, allow_unused=True)
    if input_grad is None:
        raise RuntimeError("Gradient computation failed.")

    # Flatten to feature importance
    grad = input_grad.detach().float()
    if grad.dim() > 1:
        feature_scores = grad.reshape(grad.shape[0], -1).norm(dim=0).tolist()
    else:
        feature_scores = grad.abs().tolist()

    labels = [f"feat_{i}" for i in range(len(feature_scores))]
    render_attribution_tokens(labels, feature_scores)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _find_embedding(model: torch.nn.Module) -> torch.nn.Module | None:
    """Find the token embedding layer of *model*.

    Candidates are tried in this order of preference:

    1. ``model.get_input_embeddings()``, when the model defines it and it
       returns a module.
    2. Among ``nn.Embedding`` modules with more than 1000 entries, the one
       with the most entries. Token embeddings are usually much larger than
       position/type embeddings. Ties go to the first in module order.
    3. The first ``nn.Embedding`` in module order.

    Returns ``None`` if none of these yields a module.
    """
    getter = getattr(model, "get_input_embeddings", None)
    if callable(getter):
        try:
            candidate = getter()
        except (NotImplementedError, AttributeError):
            # Best-effort only; fall back to scanning the module tree.
            candidate = None
        if isinstance(candidate, torch.nn.Module):
            return candidate

    embeddings = [m for m in model.modules() if isinstance(m, torch.nn.Embedding)]
    if not embeddings:
        return None

    large = [m for m in embeddings if m.num_embeddings > 1000]
    if large:
        # max() keeps the first of equally sized candidates.
        return max(large, key=lambda m: m.num_embeddings)

    # Fallback: first embedding
    return embeddings[0]
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _is_image_path(s: str) -> bool:
    """Return True when *s* ends in a recognised raster-image extension (case-insensitive)."""
    import os

    _stem, extension = os.path.splitext(s)
    image_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")
    return extension.lower() in image_extensions
|
interpkit/ops/diff.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""diff — compare activations between two models on the same input."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from interpkit.core.model import Model
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def run_diff(
    model_a: "Model",
    model_b: "Model",
    input_data: Any,
    *,
    save: str | None = None,
) -> list[dict[str, Any]]:
    """Compare activations between two models at all discovered layers.

    Layer names come from ``arch_info.layer_names``, falling back to all
    modules with parameters. Only names present in both models are compared.
    For each shared module, both activations are flattened and scored with
    cosine distance ``1 - cos(a, b)``, clamped at 0. Model B's activation is
    moved to model A's activation device first, so the two models may live
    on different devices.

    Returns a list of dicts sorted by cosine distance (highest change first).
    Each dict has ``module`` and ``distance``. When the two activations have
    different shapes, the entry also carries ``shape_mismatch: True``. If
    their element counts differ, both are truncated to the shorter flattened
    length before comparison, so that distance compares misaligned elements.
    """
    from interpkit.core.render import render_diff
    from interpkit.ops.activations import run_activations

    # Find shared layer-like modules
    layers_a = set(model_a.arch_info.layer_names or [])
    layers_b = set(model_b.arch_info.layer_names or [])

    # If no layers detected, fall back to all named modules
    if not layers_a:
        layers_a = {m.name for m in model_a.arch_info.modules if m.param_count > 0}
    if not layers_b:
        layers_b = {m.name for m in model_b.arch_info.modules if m.param_count > 0}

    shared_layers = sorted(layers_a & layers_b)

    if not shared_layers:
        from rich.console import Console
        Console().print("\n [yellow]diff:[/yellow] no shared modules found between the two models.\n")
        return []

    acts_a = run_activations(model_a, input_data, at=shared_layers, print_stats=False)
    acts_b = run_activations(model_b, input_data, at=shared_layers, print_stats=False)

    results: list[dict[str, Any]] = []
    for name in shared_layers:
        if name not in acts_a or name not in acts_b:
            continue

        act_a, act_b = acts_a[name], acts_b[name]
        a = act_a.float().reshape(-1)
        # Align devices: the two models may not share one.
        b = act_b.float().reshape(-1).to(a.device)

        if a.numel() != b.numel():
            min_size = min(a.numel(), b.numel())
            a = a[:min_size]
            b = b[:min_size]

        cosine_sim = torch.nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0), dim=-1)
        # Float rounding can push cos slightly above 1; keep distance >= 0.
        entry: dict[str, Any] = {
            "module": name,
            "distance": max(0.0, 1.0 - cosine_sim.item()),
        }
        if act_a.shape != act_b.shape:
            entry["shape_mismatch"] = True
        results.append(entry)

    results.sort(key=lambda r: r["distance"], reverse=True)

    model_a_name = model_a.arch_info.arch_family or "model_a"
    model_b_name = model_b.arch_info.arch_family or "model_b"
    render_diff(results, model_a_name, model_b_name)

    if save is not None:
        from interpkit.core.plot import plot_diff

        plot_diff(results, model_a_name=model_a_name, model_b_name=model_b_name, save_path=save)

    return results
|
interpkit/ops/inspect.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""inspect — module tree with types, param counts, output shapes, and detected roles."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from interpkit.core.model import Model
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def run_inspect(model: "Model") -> None:
    """Render the architecture summary held in ``model.arch_info``.

    Rendering is delegated to ``interpkit.core.render.render_inspect``.
    """
    from interpkit.core.render import render_inspect as _render

    arch_info = model.arch_info
    _render(arch_info)
|