interpkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interpkit/__init__.py +15 -0
- interpkit/cli/__init__.py +0 -0
- interpkit/cli/main.py +337 -0
- interpkit/core/__init__.py +0 -0
- interpkit/core/discovery.py +228 -0
- interpkit/core/html.py +375 -0
- interpkit/core/inputs.py +117 -0
- interpkit/core/model.py +551 -0
- interpkit/core/plot.py +352 -0
- interpkit/core/registry.py +82 -0
- interpkit/core/render.py +465 -0
- interpkit/core/tl_compat.py +174 -0
- interpkit/ops/__init__.py +0 -0
- interpkit/ops/ablate.py +90 -0
- interpkit/ops/activations.py +67 -0
- interpkit/ops/attention.py +234 -0
- interpkit/ops/attribute.py +206 -0
- interpkit/ops/diff.py +79 -0
- interpkit/ops/inspect.py +14 -0
- interpkit/ops/lens.py +151 -0
- interpkit/ops/patch.py +112 -0
- interpkit/ops/probe.py +128 -0
- interpkit/ops/sae.py +212 -0
- interpkit/ops/steer.py +118 -0
- interpkit/ops/trace.py +182 -0
- interpkit-0.1.0.dist-info/METADATA +295 -0
- interpkit-0.1.0.dist-info/RECORD +31 -0
- interpkit-0.1.0.dist-info/WHEEL +5 -0
- interpkit-0.1.0.dist-info/entry_points.txt +2 -0
- interpkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- interpkit-0.1.0.dist-info/top_level.txt +1 -0
interpkit/ops/lens.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""lens — logit lens: project each layer's output to vocabulary space."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
|
|
10
|
+
from interpkit.ops.patch import _get_module
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from interpkit.core.model import Model
|
|
14
|
+
|
|
15
|
+
# Module-level rich console; run_lens uses it to print "lens not available" notices.
console = Console()
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def run_lens(model: "Model", text: Any, *, save: str | None = None) -> list[dict[str, Any]] | None:
    """Project each layer's hidden state through the unembedding matrix.

    A forward hook captures the output of every module named in
    ``model.arch_info.layer_names``.  For each captured layer the hidden
    state at the last position is (optionally) passed through the final norm
    found by ``_find_final_norm`` and multiplied by the transposed
    unembedding weight; the softmax of that product gives the layer's
    next-token distribution.

    Only works for language models with a detectable output head.

    Args:
        model: Wrapped model; ``arch_info``, ``_model``, ``_tokenizer``,
            ``_prepare`` and ``_forward`` are used.
        text: Raw input, handed to ``model._prepare``.
        save: Optional path.  When given, the predictions are additionally
            passed to ``plot_lens(..., save_path=save)``.

    Returns:
        One dict per captured layer, in ``layer_names`` order, with the keys
        ``layer_name``, ``top1_token``, ``top1_prob``, ``top5_tokens`` and
        ``top5_probs`` (fewer than five entries when the vocabulary is
        smaller than five).  ``None`` is returned, after printing a notice,
        when the lens cannot be computed.
    """
    from interpkit.core.render import render_lens

    arch = model.arch_info

    if not arch.is_language_model or arch.unembedding_name is None:
        console.print(
            f"\n  [yellow]lens not available:[/yellow] no unembedding matrix detected"
            f" for {arch.arch_family or 'this model'}.\n"
        )
        return None

    if not arch.layer_names:
        console.print(
            "\n  [yellow]lens not available:[/yellow] no layer structure detected.\n"
        )
        return None

    if model._tokenizer is None:
        console.print(
            "\n  [yellow]lens not available:[/yellow] no tokenizer loaded.\n"
        )
        return None

    text_input = model._prepare(text)

    # Unembedding weight; used below as ``hidden @ W.T``.
    unembed_mod = _get_module(model._model, arch.unembedding_name)
    unembed_weight = unembed_mod.weight

    # Hidden states captured at the output of each layer, keyed by layer name.
    layer_outputs: dict[str, torch.Tensor] = {}

    def _make_hook(name: str):
        def hook_fn(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
            # Layers may return a bare tensor or a tuple/list whose first
            # element is the hidden state; anything else is ignored.
            if isinstance(output, torch.Tensor):
                t = output
            elif isinstance(output, (tuple, list)) and output and isinstance(output[0], torch.Tensor):
                t = output[0]
            else:
                t = None
            if t is not None:
                layer_outputs[name] = t.detach()
        return hook_fn

    hooks = []
    try:
        for layer_name in arch.layer_names:
            try:
                mod = _get_module(model._model, layer_name)
                hooks.append(mod.register_forward_hook(_make_hook(layer_name)))
            except AttributeError:
                # Layer path does not resolve on this model — skip it.
                continue

        with torch.no_grad():
            model._forward(text_input)
    finally:
        # Always detach the hooks, even when the forward pass raises, so the
        # wrapped model is never left with stale hooks attached.
        for h in hooks:
            h.remove()

    if not layer_outputs:
        console.print("\n  [yellow]lens:[/yellow] no layer outputs captured.\n")
        return None

    # Apply layer norm if the model has a final layer norm before the head
    # (common pattern: ln_f, model.norm, etc.)
    final_norm = _find_final_norm(model._model, arch)

    predictions: list[dict[str, Any]] = []

    for layer_name in arch.layer_names:
        if layer_name not in layer_outputs:
            continue

        hidden = layer_outputs[layer_name].float()

        # Take the last token position
        if hidden.dim() == 3:
            hidden = hidden[:, -1, :]
        elif hidden.dim() == 2:
            hidden = hidden[-1:, :]

        if final_norm is not None:
            hidden = final_norm(hidden)

        # Project through unembedding: logits = hidden @ W^T
        logits = hidden @ unembed_weight.float().T
        probs = torch.softmax(logits, dim=-1)

        # Only the first row is reported; clamp k so tiny vocabularies work.
        k = min(5, probs.shape[-1])
        top5_probs, top5_ids = probs[0].topk(k)
        top5_tokens = [model._tokenizer.decode([tid]) for tid in top5_ids.tolist()]
        top5_probs_list = top5_probs.tolist()

        predictions.append({
            "layer_name": layer_name,
            "top1_token": top5_tokens[0],
            "top1_prob": top5_probs_list[0],
            "top5_tokens": top5_tokens,
            "top5_probs": top5_probs_list,
        })

    model_name = arch.arch_family or "model"
    render_lens(predictions, model_name)

    if save is not None:
        from interpkit.core.plot import plot_lens

        plot_lens(predictions, save_path=save)

    return predictions
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _find_final_norm(model: torch.nn.Module, arch: Any) -> torch.nn.Module | None:
    """Locate the normalisation module applied just before the LM head.

    A fixed list of well-known module paths is tried first (matched
    case-insensitively against the names from ``named_modules``).  If none
    matches, the first ``torch.nn.LayerNorm`` whose path contains at most one
    dot is returned.  ``arch`` is accepted but not consulted.

    Returns:
        The matching module, or ``None`` when nothing qualifies.
    """
    import re

    known_paths = re.compile(
        r"^(model\.norm|transformer\.ln_f|gpt_neox\.final_layer_norm|"
        r"model\.final_layernorm|backbone\.norm_f)$",
        re.IGNORECASE,
    )
    exact_hit = next(
        (module for path, module in model.named_modules() if known_paths.match(path)),
        None,
    )
    if exact_hit is not None:
        return exact_hit

    # Generic fallback: look for a top-level norm module
    for path, module in model.named_modules():
        too_deep = path.count(".") > 1
        if too_deep or not isinstance(module, torch.nn.LayerNorm):
            continue
        if "norm" in path.lower() or "Norm" in type(module).__name__:
            return module

    return None
|
interpkit/ops/patch.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""patch — activation patching at a named module between clean and corrupted inputs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from interpkit.core.model import Model
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _get_module(model: torch.nn.Module, name: str) -> torch.nn.Module:
    """Resolve a dotted attribute path (e.g. ``"a.b.0"``) starting at *model*.

    Each dot-separated segment is looked up with ``getattr`` on the result of
    the previous lookup, so a segment that does not exist raises
    ``AttributeError``.
    """
    node = model
    remaining = name
    while True:
        segment, dot, remaining = remaining.partition(".")
        node = getattr(node, segment)
        if not dot:
            # No separator left: ``segment`` was the final path component.
            return node
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def run_patch(
    model: "Model",
    clean: Any,
    corrupted: Any,
    *,
    at: str,
) -> dict[str, Any]:
    """Patch the output of module *at* from the clean run into the corrupted run.

    Three forward passes are executed: a clean pass that caches the output of
    *at*, a plain corrupted pass (baseline), and a corrupted pass in which the
    output of *at* is replaced by the cached clean activation.

    Args:
        model: Wrapped model; ``_prepare_pair``, ``_forward`` and ``_model``
            are used.
        clean: Clean input, prepared together with *corrupted* by
            ``model._prepare_pair``.
        corrupted: Corrupted input.
        at: Dotted path of the module to patch, resolved with ``_get_module``.

    Returns:
        A dict with ``module``, ``effect`` — a normalised scalar in [0, 1]
        measuring how much the patched corrupted run's output shifted toward
        the clean output — and the three outputs ``clean_logits``,
        ``corrupted_logits`` and ``patched_logits``.

    Raises:
        RuntimeError: If *at* produced no tensor output during the clean pass.
    """
    from interpkit.core.render import render_patch

    clean_input, corrupted_input = model._prepare_pair(clean, corrupted)

    # 1. Clean forward — cache the target module's output
    cached_activation: list[torch.Tensor] = []

    target_mod = _get_module(model._model, at)

    def _cache_hook(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
        if isinstance(output, torch.Tensor):
            cached_activation.append(output.detach().clone())
        elif isinstance(output, (tuple, list)) and output and isinstance(output[0], torch.Tensor):
            # Tuple/list outputs: the first element is treated as the activation.
            cached_activation.append(output[0].detach().clone())

    handle = target_mod.register_forward_hook(_cache_hook)
    try:
        clean_logits = model._forward(clean_input)
    finally:
        # Remove the hook even if the forward pass raises.
        handle.remove()

    if not cached_activation:
        raise RuntimeError(f"Module '{at}' produced no tensor output during clean forward pass.")

    # 2. Corrupted forward (baseline)
    corrupted_logits = model._forward(corrupted_input)

    # 3. Patched forward — replace target module's output with cached clean activation
    def _patch_hook(_mod: torch.nn.Module, _inp: Any, output: Any) -> Any:
        if isinstance(output, torch.Tensor):
            return cached_activation[0]
        elif isinstance(output, (tuple, list)):
            # Swap only the first element; the rest of the output is kept.
            return (cached_activation[0],) + tuple(output[1:])
        return output

    handle = target_mod.register_forward_hook(_patch_hook)
    try:
        patched_logits = model._forward(corrupted_input)
    finally:
        handle.remove()

    # 4. Compute normalised effect
    effect = _compute_effect(clean_logits, corrupted_logits, patched_logits)

    result = {
        "module": at,
        "effect": effect,
        "clean_logits": clean_logits,
        "corrupted_logits": corrupted_logits,
        "patched_logits": patched_logits,
    }
    render_patch(result)
    return result
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _compute_effect(
    clean: torch.Tensor,
    corrupted: torch.Tensor,
    patched: torch.Tensor,
) -> float:
    """Normalised patching effect: 0 = patched == corrupted, 1 = patched == clean.

    Each logits tensor is flattened to rows of size ``shape[-1]``; only the
    last row is kept and turned into a probability distribution.  The effect
    is ``1 - ||patched - clean|| / ||corrupted - clean||`` (L2 norm between
    the distributions), clamped to ``[0, 1]``.

    Returns:
        The clamped effect, or ``0.0`` when the clean and corrupted
        distributions are (numerically) identical.
    """

    def _last_row_probs(logits: torch.Tensor) -> torch.Tensor:
        # ``reshape`` rather than ``view`` so non-contiguous logits work.
        rows = logits.reshape(-1, logits.shape[-1]).float()
        # The last row is taken per tensor, so inputs with different numbers
        # of rows never broadcast against each other.
        return torch.softmax(rows[-1:], dim=-1)

    clean_probs = _last_row_probs(clean)
    corrupted_probs = _last_row_probs(corrupted)
    patched_probs = _last_row_probs(patched)

    # Distance metric: L2 norm between probability distributions.
    dist_corrupted_clean = torch.norm(corrupted_probs - clean_probs).item()
    dist_patched_clean = torch.norm(patched_probs - clean_probs).item()

    # Nothing to recover when the corruption did not change the output.
    if dist_corrupted_clean < 1e-8:
        return 0.0

    effect = 1.0 - (dist_patched_clean / dist_corrupted_clean)
    return max(0.0, min(1.0, effect))
|
interpkit/ops/probe.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""probe — train a linear probe on activations to test linear separability."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from interpkit.core.model import Model
|
|
12
|
+
|
|
13
|
+
# Module-level rich console (none of the functions visible in this module reference it).
console = Console()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def run_probe(
    model: "Model",
    texts: list[str],
    labels: list[int],
    *,
    at: str,
) -> dict[str, Any]:
    """Train a linear probe on activations at module *at*.

    One feature vector is extracted per text with ``run_activations`` (the
    last position for 2-D/3-D activations, the flattened tensor otherwise).
    The probe uses LogisticRegression from scikit-learn and falls back to a
    simple torch-based probe if sklearn is not installed.

    Args:
        model: Wrapped model passed through to ``run_activations``.
        texts: Inputs to extract activations for.
        labels: One class label per text.
        at: Module path whose activations are probed.

    Returns:
        The probe result dict (``accuracy``, ``train_accuracy``,
        ``top_features``) with ``module`` set to *at*.

    Raises:
        ValueError: If *texts* and *labels* differ in length or are empty.
    """
    # Validate before any heavy work so bad input fails fast and clearly.
    if len(texts) != len(labels):
        raise ValueError(f"texts ({len(texts)}) and labels ({len(labels)}) must have the same length.")
    if not texts:
        raise ValueError("run_probe needs at least one text/label pair.")

    import numpy as np

    from interpkit.core.render import render_probe
    from interpkit.ops.activations import run_activations

    # Extract activations for all texts
    features = []
    for text in texts:
        act = run_activations(model, text, at=at, print_stats=False)
        # Take last-token hidden state for sequence models
        if act.dim() == 3:
            vec = act[0, -1, :]
        elif act.dim() == 2:
            vec = act[-1, :]
        else:
            # ``reshape`` also handles non-contiguous activations.
            vec = act.reshape(-1)
        # ``detach`` so tensors that still track gradients can be converted.
        features.append(vec.detach().cpu().float().numpy())

    X = np.stack(features)
    y = np.array(labels)

    try:
        result = _probe_sklearn(X, y)
    except ImportError:
        result = _probe_torch(X, y)

    result["module"] = at
    render_probe(result)
    return result
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _probe_sklearn(X: Any, y: Any) -> dict[str, Any]:
    """Fit a scikit-learn logistic-regression probe on ``(X, y)``.

    With at least 10 samples a cross-validated accuracy is computed.  The
    number of folds is capped at 5 and at the size of the smallest class, so
    every stratified fold can contain every class; cross-validation is
    skipped when fewer than 2 folds are possible.  A second classifier is
    always fitted on the full data to obtain the training accuracy and the
    weights.

    Returns:
        A dict with ``accuracy`` (cross-validated when available, otherwise
        the training accuracy), ``train_accuracy`` and ``top_features`` — up
        to 20 ``(feature_index, weight)`` pairs ordered by decreasing
        absolute weight.

    Raises:
        ImportError: If scikit-learn is not installed.
    """
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score

    import numpy as np

    n_samples = len(y)

    accuracy = None
    if n_samples >= 10:
        _, class_counts = np.unique(np.asarray(y), return_counts=True)
        # Stratified K-fold needs n_splits <= members of the smallest class.
        cv_folds = min(5, int(class_counts.min()))
        if cv_folds >= 2:
            clf = LogisticRegression(max_iter=1000, solver="lbfgs")
            scores = cross_val_score(clf, X, y, cv=cv_folds, scoring="accuracy")
            accuracy = float(scores.mean())

    # Train on full data for feature analysis
    clf = LogisticRegression(max_iter=1000, solver="lbfgs")
    clf.fit(X, y)
    train_accuracy = float(clf.score(X, y))

    # Top features by weight magnitude (first row of a 2-D coefficient matrix).
    weights = clf.coef_[0] if clf.coef_.ndim == 2 else clf.coef_
    ascending = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    top_indices = ascending[::-1][:20]
    top_features = [(int(i), float(weights[i])) for i in top_indices]

    return {
        "accuracy": accuracy if accuracy is not None else train_accuracy,
        "train_accuracy": train_accuracy,
        "top_features": top_features,
    }
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _probe_torch(X: Any, y: Any) -> dict[str, Any]:
    """Fallback probe using pure PyTorch when sklearn is not available.

    Trains a single ``torch.nn.Linear`` layer with Adam (500 full-batch
    steps, cross-entropy loss) and evaluates it on the training data.  Label
    values are first mapped onto ``0..n_classes-1``, so arbitrary integer
    labels (e.g. ``{3, 7}``) are accepted.

    Returns:
        A dict with ``accuracy`` and ``train_accuracy`` (both the training
        accuracy — no held-out evaluation is done here) and ``top_features``,
        up to 20 ``(feature_index, weight)`` pairs ordered by decreasing
        absolute weight.
    """
    import numpy as np

    X_t = torch.as_tensor(X, dtype=torch.float32)

    # CrossEntropyLoss needs targets in [0, n_classes); remap the raw labels.
    classes, y_idx = np.unique(np.asarray(y), return_inverse=True)
    y_t = torch.as_tensor(y_idx.reshape(-1), dtype=torch.long)

    n_features = X_t.shape[1]
    n_classes = len(classes)

    linear = torch.nn.Linear(n_features, n_classes)
    optimizer = torch.optim.Adam(linear.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()

    linear.train()
    for _ in range(500):
        optimizer.zero_grad()
        logits = linear(X_t)
        loss = criterion(logits, y_t)
        loss.backward()
        optimizer.step()

    linear.eval()
    with torch.no_grad():
        preds = linear(X_t).argmax(dim=-1)
        train_accuracy = float((preds == y_t).float().mean().item())

    # Binary case: first class row; otherwise the mean over class rows.
    weights = linear.weight.detach()[0].numpy() if n_classes == 2 else linear.weight.detach().mean(dim=0).numpy()
    ascending = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    top_indices = ascending[::-1][:20]
    top_features = [(int(i), float(weights[i])) for i in top_indices]

    return {
        "accuracy": train_accuracy,
        "train_accuracy": train_accuracy,
        "top_features": top_features,
    }
|
interpkit/ops/sae.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""sae — load pre-trained Sparse Autoencoders and decompose activations into features."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
from interpkit.ops.patch import _get_module
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from interpkit.core.model import Model
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class SAE:
    """A loaded Sparse Autoencoder with weights ready for inference.

    ``encode`` computes ``relu((x - b_dec) @ W_enc + b_enc)`` and ``decode``
    computes ``features @ W_dec + b_dec``.  When ``d_in`` / ``d_sae`` are left
    at their default of 0 they are inferred from a 2-D ``W_enc`` so that an
    instance built directly from tensors reports the right dimensions.
    """

    # Encoder weight; inputs are right-multiplied by it, so it is (d_in, d_sae).
    W_enc: torch.Tensor
    # Decoder weight; features are right-multiplied by it.
    W_dec: torch.Tensor
    # Encoder bias, added after the encoder matmul.
    b_enc: torch.Tensor
    # Decoder bias; subtracted from the input in ``encode``, added in ``decode``.
    b_dec: torch.Tensor
    # Input width (rows of W_enc); 0 means "infer from W_enc".
    d_in: int = 0
    # Number of SAE features (columns of W_enc); 0 means "infer from W_enc".
    d_sae: int = 0
    # Free-form configuration attached to the SAE.
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Fill in ``d_in`` / ``d_sae`` from ``W_enc`` when they were not given."""
        shape = getattr(self.W_enc, "shape", None)
        if shape is not None and len(shape) == 2:
            if self.d_in == 0:
                self.d_in = int(shape[0])
            if self.d_sae == 0:
                self.d_sae = int(shape[1])

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return the non-negative feature activations for *x*."""
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """Map feature activations back to the input space."""
        return features @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (features, reconstruction)."""
        features = self.encode(x)
        x_hat = self.decode(features)
        return features, x_hat
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def load_sae(hf_id: str, *, device: str | torch.device = "cpu") -> SAE:
    """Download and load a Sparse Autoencoder from HuggingFace.

    Expects the repo to contain ``sae_weights.safetensors`` (or ``.pt``)
    and optionally ``cfg.json`` with metadata.

    Args:
        hf_id: HuggingFace repository id to download from.
        device: Device the four weight tensors are moved to.

    Returns:
        An ``SAE`` whose tensors are float32 on *device*, with ``d_in`` and
        ``d_sae`` taken from the two dimensions of ``W_enc`` and ``metadata``
        set to the parsed ``cfg.json`` (empty when unavailable).

    Raises:
        KeyError: If the downloaded weights lack any of ``W_enc``, ``W_dec``,
            ``b_enc`` or ``b_dec``.
        FileNotFoundError: Propagated from ``_download_weights`` when the
            repository or the weight file cannot be found.
    """
    device = torch.device(device)

    # _download_weights tries safetensors first and falls back to .pt.
    weights = _download_weights(hf_id)

    required_keys = {"W_enc", "W_dec", "b_enc", "b_dec"}
    missing = required_keys - set(weights.keys())
    if missing:
        raise KeyError(
            f"SAE weights from {hf_id!r} are missing keys: {missing}. "
            f"Found keys: {list(weights.keys())}. "
            f"interpkit expects the SAELens format (W_enc, W_dec, b_enc, b_dec)."
        )

    # Normalise every tensor to float32 on the requested device.
    W_enc, W_dec, b_enc, b_dec = (
        weights[key].to(device).float()
        for key in ("W_enc", "W_dec", "b_enc", "b_dec")
    )

    metadata = _download_config(hf_id)

    return SAE(
        W_enc=W_enc,
        W_dec=W_dec,
        b_enc=b_enc,
        b_dec=b_dec,
        d_in=W_enc.shape[0],
        d_sae=W_enc.shape[1],
        metadata=metadata,
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def load_sae_from_tensors(
    W_enc: torch.Tensor,
    W_dec: torch.Tensor,
    b_enc: torch.Tensor,
    b_dec: torch.Tensor,
    *,
    metadata: dict[str, Any] | None = None,
) -> SAE:
    """Wrap already-loaded weight tensors in an ``SAE`` (handy in tests).

    ``d_in`` and ``d_sae`` are read from the first and second dimension of
    ``W_enc``; a missing or empty *metadata* becomes a fresh empty dict.
    """
    input_width = W_enc.shape[0]
    feature_count = W_enc.shape[1]
    extra = metadata if metadata else {}
    return SAE(
        W_enc=W_enc,
        W_dec=W_dec,
        b_enc=b_enc,
        b_dec=b_dec,
        d_in=input_width,
        d_sae=feature_count,
        metadata=extra,
    )
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def run_features(
    model: "Model",
    input_data: Any,
    *,
    at: str,
    sae: SAE,
    top_k: int = 20,
    print_results: bool = True,
) -> dict[str, Any]:
    """Decompose activations at *at* through the SAE and return top features.

    Args:
        model: Wrapped model passed through to ``run_activations``.
        input_data: Input passed through to ``run_activations``.
        at: Module path whose activations are decomposed.
        sae: The sparse autoencoder to apply.
        top_k: Maximum number of features to report.
        print_results: When true, the result is rendered via
            ``render_features``.

    Returns:
        A dict with ``module``, ``top_features`` (``(index, mean_activation)``
        pairs), ``reconstruction_error``, ``sparsity``,
        ``num_active_features``, ``total_features`` and
        ``feature_activations``.

    Raises:
        TypeError: If ``run_activations`` does not return a tensor.
        ValueError: If the activation width differs from ``sae.d_in``.
    """
    from interpkit.ops.activations import run_activations

    act = run_activations(model, input_data, at=at, print_stats=False)
    if not isinstance(act, torch.Tensor):
        raise TypeError(f"Expected tensor from activations, got {type(act).__name__}")

    # Flatten to 2D: one row per position. ``reshape`` (not ``view``) so
    # non-contiguous activations are accepted.
    if act.dim() == 1:
        flat = act.unsqueeze(0).float()
    else:
        flat = act.reshape(-1, act.shape[-1]).float()

    if flat.shape[-1] != sae.d_in:
        raise ValueError(
            f"Activation dim ({flat.shape[-1]}) does not match SAE input dim ({sae.d_in}). "
            f"Make sure the SAE was trained on the same layer."
        )

    # The SAE may live on a different device than the model (load_sae
    # defaults to CPU); move the activations to the SAE's device.
    flat = flat.to(sae.W_enc.device)

    features, x_hat = sae.forward(flat)

    # Reconstruction error (mean L2 across positions)
    recon_error = (flat - x_hat).norm(dim=-1).mean().item()

    # Sparsity: fraction of features that are zero
    sparsity = (features == 0).float().mean().item()

    # Top-K features by mean activation (across all positions); k is bounded
    # by the actual feature width so topk cannot be asked for too many.
    mean_activations = features.mean(dim=0)
    k = min(top_k, sae.d_sae, mean_activations.shape[-1])
    topk_vals, topk_idxs = mean_activations.topk(k)

    top_features = [
        (idx.item(), val.item())
        for idx, val in zip(topk_idxs, topk_vals)
    ]

    result = {
        "module": at,
        "top_features": top_features,
        "reconstruction_error": recon_error,
        "sparsity": sparsity,
        "num_active_features": int((mean_activations > 0).sum().item()),
        "total_features": sae.d_sae,
        "feature_activations": features.detach(),
    }

    if print_results:
        from interpkit.core.render import render_features

        render_features(result)

    return result
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _download_weights(hf_id: str) -> dict[str, torch.Tensor]:
    """Download SAE weights from HuggingFace.

    ``sae_weights.safetensors`` is tried first; if that file is missing the
    function falls back to ``sae_weights.pt`` (loaded on CPU with
    ``weights_only=True``).

    Raises:
        FileNotFoundError: If the repository does not exist, or if it
            contains neither weight file.
    """
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError

    # Try safetensors first
    try:
        path = hf_hub_download(hf_id, filename="sae_weights.safetensors")
        from safetensors.torch import load_file

        return load_file(path)
    except (EntryNotFoundError, FileNotFoundError):
        # File not in the repo — fall through to the .pt variant.
        pass
    except RepositoryNotFoundError as err:
        # Chain the original error so the hub's diagnostics are not lost.
        raise FileNotFoundError(
            f"HuggingFace repository {hf_id!r} not found. "
            f"Check the repo ID and your network/auth settings."
        ) from err

    # Fall back to .pt
    try:
        path = hf_hub_download(hf_id, filename="sae_weights.pt")
        return torch.load(path, map_location="cpu", weights_only=True)
    except (EntryNotFoundError, FileNotFoundError):
        pass

    raise FileNotFoundError(
        f"Could not find sae_weights.safetensors or sae_weights.pt in {hf_id!r}. "
        f"The HF repo should contain one of these files."
    )
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _download_config(hf_id: str) -> dict[str, Any]:
    """Download SAE config from HuggingFace (optional).

    Returns the parsed ``cfg.json`` from the repository, or an empty dict
    when the file cannot be downloaded, read or parsed.
    """
    from huggingface_hub import hf_hub_download

    try:
        path = hf_hub_download(hf_id, filename="cfg.json")
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding.
        return json.loads(Path(path).read_text(encoding="utf-8"))
    except Exception:
        # Deliberately best-effort: the config is optional metadata, so any
        # failure degrades to "no metadata" rather than aborting the load.
        return {}
|