interpkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
interpkit/ops/lens.py ADDED
@@ -0,0 +1,151 @@
1
+ """lens — logit lens: project each layer's output to vocabulary space."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import torch
8
+ from rich.console import Console
9
+
10
+ from interpkit.ops.patch import _get_module
11
+
12
+ if TYPE_CHECKING:
13
+ from interpkit.core.model import Model
14
+
15
# Module-level Rich console; run_lens prints its "lens not available" warnings through it.
console = Console()
16
+
17
+
18
def run_lens(model: "Model", text: Any, *, save: str | None = None) -> list[dict[str, Any]] | None:
    """Project each layer's hidden state through the unembedding matrix (logit lens).

    For each module in ``model.arch_info.layer_names``, the forward output at
    the last token position is processed in two steps:

    1. It is passed through the model's final norm, when ``_find_final_norm``
       locates one.
    2. It is then passed through the unembedding head, and the top-5 tokens
       are recorded.

    Only works for language models with a detectable output head, a detected
    layer structure and a loaded tokenizer. Otherwise a warning is printed
    and ``None`` is returned.

    Args:
        model: The interpkit ``Model`` wrapper.
        text: Input passed to ``model._prepare``.
        save: Optional path. When set, the predictions are also plotted with
            ``plot_lens``.

    Returns:
        One dict per captured layer, in ``layer_names`` order, with keys
        ``layer_name``, ``top1_token``, ``top1_prob``, ``top5_tokens`` and
        ``top5_probs``. Returns ``None`` if the lens is unavailable or no
        layer output was captured.
    """
    from interpkit.core.render import render_lens

    arch = model.arch_info

    if not arch.is_language_model or arch.unembedding_name is None:
        console.print(
            f"\n [yellow]lens not available:[/yellow] no unembedding matrix detected"
            f" for {arch.arch_family or 'this model'}.\n"
        )
        return None

    if not arch.layer_names:
        console.print(
            "\n [yellow]lens not available:[/yellow] no layer structure detected.\n"
        )
        return None

    if model._tokenizer is None:
        console.print(
            "\n [yellow]lens not available:[/yellow] no tokenizer loaded.\n"
        )
        return None

    text_input = model._prepare(text)

    # Unembedding head. The weight is (vocab_size, hidden_size), so cast it once
    # here rather than once per layer inside the loop below.
    unembed_mod = _get_module(model._model, arch.unembedding_name)
    unembed_weight = unembed_mod.weight.detach().float()
    # Include the head's bias when it has one, so the lens logits match the
    # real output head.
    unembed_bias = getattr(unembed_mod, "bias", None)
    if isinstance(unembed_bias, torch.Tensor):
        unembed_bias = unembed_bias.detach().float()
    else:
        unembed_bias = None

    # Capture hidden states at the output of each layer.
    layer_outputs: dict[str, torch.Tensor] = {}

    def _make_hook(name: str):
        def hook_fn(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
            # The hidden state is taken to be the output itself or, for
            # tuple/list outputs, their first element when it is a tensor.
            if isinstance(output, torch.Tensor):
                captured = output
            elif (
                isinstance(output, (tuple, list))
                and len(output) > 0
                and isinstance(output[0], torch.Tensor)
            ):
                captured = output[0]
            else:
                return
            layer_outputs[name] = captured.detach()

        return hook_fn

    hooks = []
    try:
        for layer_name in arch.layer_names:
            try:
                mod = _get_module(model._model, layer_name)
            except AttributeError:
                # Name does not resolve on this model instance; skip it.
                continue
            hooks.append(mod.register_forward_hook(_make_hook(layer_name)))

        with torch.no_grad():
            model._forward(text_input)
    finally:
        # Always detach the hooks, even if the forward pass raised. A leaked
        # hook would keep firing on every later forward of this model.
        for h in hooks:
            h.remove()

    if not layer_outputs:
        console.print("\n [yellow]lens:[/yellow] no layer outputs captured.\n")
        return None

    # Apply the final norm if the model has one before the head
    # (common pattern: ln_f, model.norm, etc.).
    final_norm = _find_final_norm(model._model, arch)
    # Guard against vocabularies smaller than the top-5 we report.
    k = min(5, unembed_weight.shape[0])

    predictions: list[dict[str, Any]] = []

    with torch.no_grad():
        for layer_name in arch.layer_names:
            if layer_name not in layer_outputs:
                continue

            hidden = layer_outputs[layer_name]

            # Take the last token position.
            if hidden.dim() == 3:
                hidden = hidden[:, -1, :]  # (batch, hidden)
            elif hidden.dim() == 2:
                hidden = hidden[-1:, :]

            # No-op unless layers and head sit on different devices. Assumes
            # the final norm lives on the same device as the head.
            hidden = hidden.to(unembed_weight.device)

            if final_norm is not None:
                # Run the norm in the activation's own dtype. Its parameters
                # share the model's dtype, and feeding a float32 input to a
                # half-precision norm can fail with a dtype mismatch. Cast to
                # float only afterwards.
                hidden = final_norm(hidden)

            # Project through unembedding: logits = hidden @ W^T (+ b)
            logits = hidden.float() @ unembed_weight.T  # (batch, vocab)
            if unembed_bias is not None:
                logits = logits + unembed_bias
            probs = torch.softmax(logits, dim=-1)

            top_probs, top_ids = probs[0].topk(k)
            top_tokens = [model._tokenizer.decode([tid]) for tid in top_ids.tolist()]
            top_probs_list = top_probs.tolist()

            predictions.append({
                "layer_name": layer_name,
                "top1_token": top_tokens[0],
                "top1_prob": top_probs_list[0],
                "top5_tokens": top_tokens,
                "top5_probs": top_probs_list,
            })

    model_name = arch.arch_family or "model"
    render_lens(predictions, model_name)

    if save is not None:
        from interpkit.core.plot import plot_lens

        plot_lens(predictions, save_path=save)

    return predictions
129
+
130
+
131
+ def _find_final_norm(model: torch.nn.Module, arch: Any) -> torch.nn.Module | None:
132
+ """Try to find the final layer norm applied before the LM head."""
133
+ import re
134
+
135
+ norm_pattern = re.compile(
136
+ r"^(model\.norm|transformer\.ln_f|gpt_neox\.final_layer_norm|"
137
+ r"model\.final_layernorm|backbone\.norm_f)$",
138
+ re.IGNORECASE,
139
+ )
140
+ for name, mod in model.named_modules():
141
+ if norm_pattern.match(name):
142
+ return mod
143
+
144
+ # Generic fallback: look for a top-level norm module
145
+ for name, mod in model.named_modules():
146
+ if name.count(".") <= 1 and isinstance(mod, (torch.nn.LayerNorm,)):
147
+ type_name = type(mod).__name__
148
+ if "norm" in name.lower() or "Norm" in type_name:
149
+ return mod
150
+
151
+ return None
interpkit/ops/patch.py ADDED
@@ -0,0 +1,112 @@
1
+ """patch — activation patching at a named module between clean and corrupted inputs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import torch
8
+
9
+ if TYPE_CHECKING:
10
+ from interpkit.core.model import Model
11
+
12
+
13
+ def _get_module(model: torch.nn.Module, name: str) -> torch.nn.Module:
14
+ parts = name.split(".")
15
+ mod = model
16
+ for part in parts:
17
+ mod = getattr(mod, part)
18
+ return mod
19
+
20
+
21
def run_patch(
    model: "Model",
    clean: Any,
    corrupted: Any,
    *,
    at: str,
) -> dict[str, Any]:
    """Patch the output of module *at* from the clean run into the corrupted run.

    Three forward passes are made:

    1. A clean pass that caches the module's output.
    2. An unpatched corrupted pass, used as the baseline.
    3. A corrupted pass in which the module's output is replaced by the
       cached clean activation.

    Returns a dict with ``effect`` — a normalised scalar in [0, 1] measuring how
    much the patched corrupted run's output shifted toward the clean output —
    plus ``module``, ``clean_logits``, ``corrupted_logits`` and ``patched_logits``.

    Raises:
        AttributeError: if *at* does not resolve to a submodule.
        RuntimeError: if the module produced no tensor output during the clean pass.
    """
    from interpkit.core.render import render_patch

    clean_input, corrupted_input = model._prepare_pair(clean, corrupted)

    target_mod = _get_module(model._model, at)

    # Filled by the clean-pass hook. If the module runs more than once per
    # forward, only the first output is used for patching.
    cached_activation: list[torch.Tensor] = []

    def _cache_hook(_mod: torch.nn.Module, _inp: Any, output: Any) -> None:
        if isinstance(output, torch.Tensor):
            cached_activation.append(output.detach().clone())
        elif (
            isinstance(output, (tuple, list))
            and len(output) > 0
            and isinstance(output[0], torch.Tensor)
        ):
            cached_activation.append(output[0].detach().clone())

    def _patch_hook(_mod: torch.nn.Module, _inp: Any, output: Any) -> Any:
        # A forward hook's non-None return value replaces the module's output.
        if isinstance(output, torch.Tensor):
            return cached_activation[0]
        if isinstance(output, (tuple, list)) and len(output) > 0:
            return (cached_activation[0],) + tuple(output[1:])
        return output

    # Patching only needs inference; skip building autograd graphs.
    with torch.no_grad():
        # 1. Clean forward — cache the target module's output.
        handle = target_mod.register_forward_hook(_cache_hook)
        try:
            clean_logits = model._forward(clean_input)
        finally:
            # Remove even on failure so the hook cannot leak into later runs.
            handle.remove()

        if not cached_activation:
            raise RuntimeError(f"Module '{at}' produced no tensor output during clean forward pass.")

        # 2. Corrupted forward (baseline).
        corrupted_logits = model._forward(corrupted_input)

        # 3. Patched forward — replace target module's output with cached clean activation.
        handle = target_mod.register_forward_hook(_patch_hook)
        try:
            patched_logits = model._forward(corrupted_input)
        finally:
            handle.remove()

    # 4. Compute normalised effect.
    effect = _compute_effect(clean_logits, corrupted_logits, patched_logits)

    result = {
        "module": at,
        "effect": effect,
        "clean_logits": clean_logits,
        "corrupted_logits": corrupted_logits,
        "patched_logits": patched_logits,
    }
    render_patch(result)
    return result
82
+
83
+
84
+ def _compute_effect(
85
+ clean: torch.Tensor,
86
+ corrupted: torch.Tensor,
87
+ patched: torch.Tensor,
88
+ ) -> float:
89
+ """Normalised patching effect: 0 = patched == corrupted, 1 = patched == clean."""
90
+ # Use KL divergence on the last-token logits as the distance metric
91
+ clean_flat = clean.view(-1, clean.shape[-1]).float()
92
+ corrupted_flat = corrupted.view(-1, corrupted.shape[-1]).float()
93
+ patched_flat = patched.view(-1, patched.shape[-1]).float()
94
+
95
+ # Take last position for sequence models
96
+ if clean_flat.shape[0] > 1:
97
+ clean_flat = clean_flat[-1:]
98
+ corrupted_flat = corrupted_flat[-1:]
99
+ patched_flat = patched_flat[-1:]
100
+
101
+ clean_probs = torch.softmax(clean_flat, dim=-1)
102
+ corrupted_probs = torch.softmax(corrupted_flat, dim=-1)
103
+ patched_probs = torch.softmax(patched_flat, dim=-1)
104
+
105
+ dist_corrupted_clean = torch.norm(corrupted_probs - clean_probs).item()
106
+ dist_patched_clean = torch.norm(patched_probs - clean_probs).item()
107
+
108
+ if dist_corrupted_clean < 1e-8:
109
+ return 0.0
110
+
111
+ effect = 1.0 - (dist_patched_clean / dist_corrupted_clean)
112
+ return max(0.0, min(1.0, effect))
interpkit/ops/probe.py ADDED
@@ -0,0 +1,128 @@
1
+ """probe — train a linear probe on activations to test linear separability."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import torch
8
+ from rich.console import Console
9
+
10
+ if TYPE_CHECKING:
11
+ from interpkit.core.model import Model
12
+
13
# Module-level Rich console; no function in this module currently prints through it.
console = Console()
14
+
15
+
16
def run_probe(
    model: "Model",
    texts: list[str],
    labels: list[int],
    *,
    at: str,
) -> dict[str, Any]:
    """Train a linear probe on activations at module *at*.

    For each text, the activation at *at* is reduced to one vector:

    - 3-D activations: the last position of batch item 0.
    - 2-D activations: the last row.
    - Otherwise: the flattened tensor.

    A probe is then fit to predict *labels*. It uses LogisticRegression from
    scikit-learn and falls back to a simple torch-based probe if sklearn is
    not installed.

    Args:
        model: The interpkit ``Model`` wrapper.
        texts: One input per example.
        labels: Class label per example, same length as *texts*.
        at: Dotted name of the module whose activations are probed.

    Returns:
        The probe result (``accuracy``, ``train_accuracy``, ``top_features``)
        with ``module`` set to *at*.

    Raises:
        ValueError: if *texts* and *labels* differ in length or are empty.
        TypeError: if the activations at *at* are not a tensor.
    """
    import numpy as np

    from interpkit.core.render import render_probe
    from interpkit.ops.activations import run_activations

    if len(texts) != len(labels):
        raise ValueError(f"texts ({len(texts)}) and labels ({len(labels)}) must have the same length.")
    if len(texts) == 0:
        raise ValueError("At least one text/label pair is required to train a probe.")

    # Extract one feature vector per text.
    features = []
    for text in texts:
        act = run_activations(model, text, at=at, print_stats=False)
        if not isinstance(act, torch.Tensor):
            raise TypeError(f"Expected tensor from activations, got {type(act).__name__}")
        # Take last-token hidden state for sequence models.
        if act.dim() == 3:
            vec = act[0, -1, :]  # (hidden,)
        elif act.dim() == 2:
            vec = act[-1, :]
        else:
            vec = act.reshape(-1)
        # detach() so grad-tracking activations can still be converted to numpy.
        features.append(vec.detach().cpu().float().numpy())

    X = np.stack(features)
    y = np.array(labels)

    try:
        result = _probe_sklearn(X, y)
    except ImportError:
        result = _probe_torch(X, y)

    result["module"] = at
    render_probe(result)
    return result
60
+
61
+
62
+ def _probe_sklearn(X: Any, y: Any) -> dict[str, Any]:
63
+ from sklearn.linear_model import LogisticRegression
64
+ from sklearn.model_selection import cross_val_score
65
+
66
+ n_samples = len(y)
67
+
68
+ if n_samples >= 10:
69
+ cv_folds = min(5, n_samples)
70
+ clf = LogisticRegression(max_iter=1000, solver="lbfgs")
71
+ scores = cross_val_score(clf, X, y, cv=cv_folds, scoring="accuracy")
72
+ accuracy = float(scores.mean())
73
+ else:
74
+ accuracy = None
75
+
76
+ # Train on full data for feature analysis
77
+ clf = LogisticRegression(max_iter=1000, solver="lbfgs")
78
+ clf.fit(X, y)
79
+ train_accuracy = float(clf.score(X, y))
80
+
81
+ # Top features by weight magnitude
82
+ weights = clf.coef_[0] if clf.coef_.ndim == 2 else clf.coef_
83
+ top_indices = list(reversed(sorted(range(len(weights)), key=lambda i: abs(weights[i]))))[:20]
84
+ top_features = [(int(i), float(weights[i])) for i in top_indices]
85
+
86
+ return {
87
+ "accuracy": accuracy if accuracy is not None else train_accuracy,
88
+ "train_accuracy": train_accuracy,
89
+ "top_features": top_features,
90
+ }
91
+
92
+
93
+ def _probe_torch(X: Any, y: Any) -> dict[str, Any]:
94
+ """Fallback probe using pure PyTorch when sklearn is not available."""
95
+ import numpy as np
96
+
97
+ X_t = torch.tensor(X, dtype=torch.float32)
98
+ y_t = torch.tensor(y, dtype=torch.long)
99
+
100
+ n_features = X_t.shape[1]
101
+ n_classes = len(set(y))
102
+
103
+ linear = torch.nn.Linear(n_features, n_classes)
104
+ optimizer = torch.optim.Adam(linear.parameters(), lr=0.01)
105
+ criterion = torch.nn.CrossEntropyLoss()
106
+
107
+ linear.train()
108
+ for _ in range(500):
109
+ optimizer.zero_grad()
110
+ logits = linear(X_t)
111
+ loss = criterion(logits, y_t)
112
+ loss.backward()
113
+ optimizer.step()
114
+
115
+ linear.eval()
116
+ with torch.no_grad():
117
+ preds = linear(X_t).argmax(dim=-1)
118
+ train_accuracy = float((preds == y_t).float().mean().item())
119
+
120
+ weights = linear.weight.detach()[0].numpy() if n_classes == 2 else linear.weight.detach().mean(dim=0).numpy()
121
+ top_indices = list(reversed(sorted(range(len(weights)), key=lambda i: abs(weights[i]))))[:20]
122
+ top_features = [(int(i), float(weights[i])) for i in top_indices]
123
+
124
+ return {
125
+ "accuracy": train_accuracy,
126
+ "train_accuracy": train_accuracy,
127
+ "top_features": top_features,
128
+ }
interpkit/ops/sae.py ADDED
@@ -0,0 +1,212 @@
1
+ """sae — load pre-trained Sparse Autoencoders and decompose activations into features."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from dataclasses import dataclass, field
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ import torch
11
+
12
+ from interpkit.ops.patch import _get_module
13
+
14
+ if TYPE_CHECKING:
15
+ from interpkit.core.model import Model
16
+
17
+
18
@dataclass
class SAE:
    """A loaded Sparse Autoencoder with weights ready for inference.

    Layout implied by ``encode``/``decode``:

    - ``W_enc`` maps the last dimension of the input (``d_in``) to ``d_sae``
      features.
    - ``W_dec`` maps features back to ``d_in``.
    - ``b_dec`` is subtracted before encoding and added back after decoding.

    ``d_in``/``d_sae`` left at their default ``0`` are inferred from the first
    two dimensions of ``W_enc``.
    """

    W_enc: torch.Tensor
    W_dec: torch.Tensor
    b_enc: torch.Tensor
    b_dec: torch.Tensor
    d_in: int = 0  # 0 -> inferred from W_enc.shape[0]
    d_sae: int = 0  # 0 -> inferred from W_enc.shape[1]
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Directly constructed SAEs would otherwise report d_in == 0 and fail
        # any dimension check against real activations.
        shape = getattr(self.W_enc, "shape", ())
        if len(shape) >= 2:
            if not self.d_in:
                self.d_in = int(shape[0])
            if not self.d_sae:
                self.d_sae = int(shape[1])

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return ReLU feature activations ``relu((x - b_dec) @ W_enc + b_enc)``."""
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """Reconstruct inputs as ``features @ W_dec + b_dec``."""
        return features @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (features, reconstruction)."""
        features = self.encode(x)
        x_hat = self.decode(features)
        return features, x_hat
41
+
42
+
43
def load_sae(hf_id: str, *, device: str | torch.device = "cpu") -> SAE:
    """Download and load a Sparse Autoencoder from HuggingFace.

    Expects the repo to contain ``sae_weights.safetensors`` (or ``.pt``) and,
    optionally, ``cfg.json`` with metadata. The weights must use the layout
    ``encode``/``decode`` rely on:

    - ``W_enc``: ``(d_in, d_sae)``
    - ``W_dec``: ``(d_sae, d_in)``
    - ``b_enc``: last dim ``d_sae``
    - ``b_dec``: last dim ``d_in``

    Args:
        hf_id: HuggingFace repository id.
        device: Device the weights are moved to.

    Returns:
        An ``SAE`` with float32 weights on *device*. ``metadata`` is the parsed
        ``cfg.json``, or ``{}`` when it is unavailable.

    Raises:
        KeyError: if any of ``W_enc``, ``W_dec``, ``b_enc``, ``b_dec`` is missing.
        ValueError: if the weight shapes are inconsistent with each other.
        FileNotFoundError: if the repo or its weight file cannot be found.
    """
    device = torch.device(device)

    # Tries safetensors first, falls back to .pt.
    weights = _download_weights(hf_id)

    required_keys = {"W_enc", "W_dec", "b_enc", "b_dec"}
    missing = required_keys - set(weights.keys())
    if missing:
        raise KeyError(
            f"SAE weights from {hf_id!r} are missing keys: {missing}. "
            f"Found keys: {list(weights.keys())}. "
            f"interpkit expects the SAELens format (W_enc, W_dec, b_enc, b_dec)."
        )

    # Move and cast in one step per tensor.
    W_enc = weights["W_enc"].to(device=device, dtype=torch.float32)
    W_dec = weights["W_dec"].to(device=device, dtype=torch.float32)
    b_enc = weights["b_enc"].to(device=device, dtype=torch.float32)
    b_dec = weights["b_dec"].to(device=device, dtype=torch.float32)

    # Validate shapes now so mismatches surface here, not as matmul errors later.
    if W_enc.dim() != 2:
        raise ValueError(
            f"SAE weights from {hf_id!r}: W_enc must be 2-D (d_in, d_sae), "
            f"got shape {tuple(W_enc.shape)}."
        )
    d_in, d_sae = int(W_enc.shape[0]), int(W_enc.shape[1])
    problems = []
    if tuple(W_dec.shape) != (d_sae, d_in):
        problems.append(f"W_dec {tuple(W_dec.shape)} (expected {(d_sae, d_in)})")
    if tuple(b_enc.shape[-1:]) != (d_sae,):
        problems.append(f"b_enc {tuple(b_enc.shape)} (expected last dim {d_sae})")
    if tuple(b_dec.shape[-1:]) != (d_in,):
        problems.append(f"b_dec {tuple(b_dec.shape)} (expected last dim {d_in})")
    if problems:
        raise ValueError(
            f"SAE weights from {hf_id!r} have inconsistent shapes: "
            + "; ".join(problems)
            + "."
        )

    metadata = _download_config(hf_id)

    return SAE(
        W_enc=W_enc,
        W_dec=W_dec,
        b_enc=b_enc,
        b_dec=b_dec,
        d_in=d_in,
        d_sae=d_sae,
        metadata=metadata,
    )
84
+
85
+
86
def load_sae_from_tensors(
    W_enc: torch.Tensor,
    W_dec: torch.Tensor,
    b_enc: torch.Tensor,
    b_dec: torch.Tensor,
    *,
    metadata: dict[str, Any] | None = None,
) -> SAE:
    """Build an ``SAE`` straight from in-memory weight tensors (handy in tests).

    The tensors are used as given — no copy, cast or device move.
    ``d_in``/``d_sae`` are read from the first two dimensions of ``W_enc``. A
    falsy *metadata* becomes an empty dict.
    """
    enc_shape = W_enc.shape
    meta = metadata or {}
    return SAE(
        W_enc=W_enc,
        W_dec=W_dec,
        b_enc=b_enc,
        b_dec=b_dec,
        d_in=enc_shape[0],
        d_sae=enc_shape[1],
        metadata=meta,
    )
104
+
105
+
106
def run_features(
    model: "Model",
    input_data: Any,
    *,
    at: str,
    sae: SAE,
    top_k: int = 20,
    print_results: bool = True,
) -> dict[str, Any]:
    """Decompose activations at *at* through the SAE and return top features.

    The activation tensor is flattened to ``(rows, d_in)``, with every
    position treated as a row, and encoded/decoded by *sae*.

    Args:
        model: The interpkit ``Model`` wrapper.
        input_data: Input passed to ``run_activations``.
        at: Dotted name of the module whose output is decomposed.
        sae: The sparse autoencoder to apply.
        top_k: Maximum number of features reported (clamped to ``[0, d_sae]``).
        print_results: Render the result with ``render_features`` when True.

    Returns:
        Dict with the following keys:

        - ``module``: *at*.
        - ``top_features``: ``(feature_index, mean_activation)`` pairs, highest
          mean first.
        - ``reconstruction_error``: mean L2 error per row.
        - ``sparsity``: fraction of zero feature entries.
        - ``num_active_features``: features with positive mean activation.
        - ``total_features``: ``sae.d_sae``.
        - ``feature_activations``: the ``(rows, d_sae)`` feature tensor.

    Raises:
        TypeError: if the activations are not a tensor.
        ValueError: if the activation width differs from ``sae.d_in``.
    """
    from interpkit.ops.activations import run_activations

    act = run_activations(model, input_data, at=at, print_stats=False)
    if not isinstance(act, torch.Tensor):
        raise TypeError(f"Expected tensor from activations, got {type(act).__name__}")

    # Flatten to 2D: (batch * seq, d_model). reshape copes with non-contiguous tensors.
    flat = act.unsqueeze(0) if act.dim() == 1 else act.reshape(-1, act.shape[-1])

    if flat.shape[-1] != sae.d_in:
        raise ValueError(
            f"Activation dim ({flat.shape[-1]}) does not match SAE input dim ({sae.d_in}). "
            f"Make sure the SAE was trained on the same layer."
        )

    # Bring activations to the SAE's device/dtype. The model may be on a GPU
    # while load_sae defaults the SAE to CPU.
    flat = flat.detach().to(device=sae.W_enc.device, dtype=sae.W_enc.dtype)

    with torch.no_grad():
        features, x_hat = sae.forward(flat)

        # Reconstruction error (mean L2 across positions).
        recon_error = (flat - x_hat).norm(dim=-1).mean().item()

        # Sparsity: fraction of feature entries that are zero.
        sparsity = (features == 0).float().mean().item()

        # Top-K features by mean activation (across all positions).
        mean_activations = features.mean(dim=0)
        k = max(0, min(top_k, sae.d_sae))
        topk_vals, topk_idxs = mean_activations.topk(k)

    top_features = [
        (idx.item(), val.item())
        for idx, val in zip(topk_idxs, topk_vals)
    ]

    result = {
        "module": at,
        "top_features": top_features,
        "reconstruction_error": recon_error,
        "sparsity": sparsity,
        "num_active_features": int((mean_activations > 0).sum().item()),
        "total_features": sae.d_sae,
        "feature_activations": features.detach(),
    }

    if print_results:
        from interpkit.core.render import render_features

        render_features(result)

    return result
170
+
171
+
172
def _download_weights(hf_id: str) -> dict[str, torch.Tensor]:
    """Download SAE weights from HuggingFace.

    Tries ``sae_weights.safetensors`` first and falls back to ``sae_weights.pt``.
    The ``.pt`` file is loaded with ``weights_only=True``, so arbitrary pickled
    objects are refused. Tensors are loaded onto the CPU.

    Raises:
        FileNotFoundError: if the repository cannot be found or accessed, or
            contains neither weights file. The original hub error is chained as
            ``__cause__``.
    """
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError

    # Try safetensors first.
    try:
        path = hf_hub_download(hf_id, filename="sae_weights.safetensors")
    except (EntryNotFoundError, FileNotFoundError):
        pass
    except RepositoryNotFoundError as err:
        raise FileNotFoundError(
            f"HuggingFace repository {hf_id!r} not found. "
            f"Check the repo ID and your network/auth settings."
        ) from err
    else:
        from safetensors.torch import load_file

        return load_file(path)

    # Fall back to .pt.
    try:
        path = hf_hub_download(hf_id, filename="sae_weights.pt")
    except (EntryNotFoundError, FileNotFoundError):
        pass
    else:
        return torch.load(path, map_location="cpu", weights_only=True)

    raise FileNotFoundError(
        f"Could not find sae_weights.safetensors or sae_weights.pt in {hf_id!r}. "
        f"The HF repo should contain one of these files."
    )
202
+
203
+
204
def _download_config(hf_id: str) -> dict[str, Any]:
    """Best-effort download of the optional SAE ``cfg.json`` from HuggingFace.

    Returns the parsed JSON object. Returns ``{}`` in the following cases:

    - The file is absent (silently).
    - The download fails for another reason (with a warning).
    - The file cannot be decoded or parsed (with a warning).
    - The JSON is not an object.

    The config is optional, so no error is raised.
    """
    import warnings

    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import EntryNotFoundError

    try:
        path = hf_hub_download(hf_id, filename="cfg.json")
    except (EntryNotFoundError, FileNotFoundError):
        # No cfg.json in the repo: perfectly fine, metadata is optional.
        return {}
    except Exception as err:  # network/auth problems must not break loading
        warnings.warn(f"Could not download cfg.json for SAE {hf_id!r}: {err}", stacklevel=2)
        return {}

    try:
        # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
        config = json.loads(Path(path).read_text(encoding="utf-8"))
    except (OSError, ValueError) as err:
        warnings.warn(f"Ignoring unreadable cfg.json for SAE {hf_id!r}: {err}", stacklevel=2)
        return {}

    return config if isinstance(config, dict) else {}