interpkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interpkit/__init__.py +15 -0
- interpkit/cli/__init__.py +0 -0
- interpkit/cli/main.py +337 -0
- interpkit/core/__init__.py +0 -0
- interpkit/core/discovery.py +228 -0
- interpkit/core/html.py +375 -0
- interpkit/core/inputs.py +117 -0
- interpkit/core/model.py +551 -0
- interpkit/core/plot.py +352 -0
- interpkit/core/registry.py +82 -0
- interpkit/core/render.py +465 -0
- interpkit/core/tl_compat.py +174 -0
- interpkit/ops/__init__.py +0 -0
- interpkit/ops/ablate.py +90 -0
- interpkit/ops/activations.py +67 -0
- interpkit/ops/attention.py +234 -0
- interpkit/ops/attribute.py +206 -0
- interpkit/ops/diff.py +79 -0
- interpkit/ops/inspect.py +14 -0
- interpkit/ops/lens.py +151 -0
- interpkit/ops/patch.py +112 -0
- interpkit/ops/probe.py +128 -0
- interpkit/ops/sae.py +212 -0
- interpkit/ops/steer.py +118 -0
- interpkit/ops/trace.py +182 -0
- interpkit-0.1.0.dist-info/METADATA +295 -0
- interpkit-0.1.0.dist-info/RECORD +31 -0
- interpkit-0.1.0.dist-info/WHEEL +5 -0
- interpkit-0.1.0.dist-info/entry_points.txt +2 -0
- interpkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- interpkit-0.1.0.dist-info/top_level.txt +1 -0
interpkit/core/render.py
ADDED
|
@@ -0,0 +1,465 @@
|
|
|
1
|
+
"""Terminal rendering — rich tables, trees, unicode bar charts, heatmap export."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
from rich.table import Table
|
|
10
|
+
from rich.tree import Tree
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from interpkit.core.discovery import ModelArchInfo, ModuleInfo
|
|
14
|
+
|
|
15
|
+
# Shared module-level Rich console; the render_* functions in this module print through it.
console = Console()
|
|
16
|
+
|
|
17
|
+
# Short display tag for each detected module role, keyed by role name.
# NOTE(review): the values are wrapped in square brackets, which Rich console
# markup can parse as style tags -- confirm they are escaped wherever they are
# printed with markup enabled.
_ROLE_TAGS = {
    role: f"[{abbrev}]"
    for role, abbrev in (
        ("attention", "attn"),
        ("mlp", "mlp"),
        ("head", "head"),
        ("norm", "norm"),
        ("embed", "embed"),
    )
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# ------------------------------------------------------------------
|
|
27
|
+
# Inspect rendering
|
|
28
|
+
# ------------------------------------------------------------------
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def render_inspect(arch_info: "ModelArchInfo") -> None:
    """Print a one-line architecture summary followed by a table of modules.

    Args:
        arch_info: Discovery result. The code reads ``arch_family``,
            ``num_layers``, ``hidden_size``, ``vocab_size`` and ``modules``;
            each module must expose ``name``, ``type_name``, ``param_count``,
            ``output_shape`` and ``role``.
    """
    # Local import keeps this block self-contained (rich is already a dependency).
    from rich.markup import escape

    header_parts = []
    if arch_info.arch_family:
        header_parts.append(arch_info.arch_family)
    if arch_info.num_layers is not None:
        header_parts.append(f"{arch_info.num_layers} layers")
    if arch_info.hidden_size is not None:
        header_parts.append(f"hidden={arch_info.hidden_size}")
    if arch_info.vocab_size is not None:
        header_parts.append(f"vocab={arch_info.vocab_size}")

    total_params = sum(m.param_count for m in arch_info.modules)
    header_parts.append(f"{_format_params(total_params)} params total")

    console.print(f"\n[bold]{' | '.join(header_parts)}[/bold]")

    table = Table(show_header=True, header_style="bold", show_lines=False, pad_edge=False)
    table.add_column("Module", style="cyan", no_wrap=True)
    table.add_column("Type", style="dim")
    table.add_column("Params", justify="right")
    table.add_column("Output Shape", style="dim")
    table.add_column("Role", style="bold yellow")

    for m in arch_info.modules:
        # Role tags look like "[attn]". Rich treats bracketed lowercase text as
        # a markup tag and would render nothing, so escape it to show literally.
        role_tag = escape(_ROLE_TAGS.get(m.role or "", ""))
        shape_str = str(m.output_shape) if m.output_shape else ""
        # Modules without parameters get an empty cell rather than "0".
        param_str = _format_params(m.param_count) if m.param_count > 0 else ""
        table.add_row(m.name, m.type_name, param_str, shape_str, role_tag)

    console.print(table)
    console.print()
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ------------------------------------------------------------------
|
|
66
|
+
# Causal trace rendering
|
|
67
|
+
# ------------------------------------------------------------------
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def render_trace(
    results: list[dict[str, Any]],
    model_name: str,
    total_modules: int,
    top_k: int | None,
) -> None:
    """Print a ranked bar chart of causal tracing results.

    Args:
        results: One dict per scanned module with keys ``"module"`` and
            ``"effect"`` and an optional ``"role"``. The first entry is
            reported as the top component, so the list is presumably sorted
            by descending effect -- verify against the caller.
        model_name: Name shown in the title.
        total_modules: Number of modules available for scanning.
        top_k: ``None`` when every module was scanned; otherwise the title
            reports how many of ``total_modules`` are shown.
    """
    # Local import keeps this block self-contained (rich is already a dependency).
    from rich.markup import escape

    scanned = len(results)
    if top_k is not None:
        title = f"Causal Trace: {model_name} (top {scanned} of {total_modules} modules)"
    else:
        title = f"Causal Trace: {model_name} ({total_modules} modules)"

    console.print(f"\n[bold]{title}[/bold]")

    if not results:
        console.print(" No significant causal effects found.")
        return

    # results is non-empty from here on, so no further emptiness checks are needed.
    max_effect = max(r["effect"] for r in results)
    bar_width = 30

    table = Table(show_header=False, show_lines=False, pad_edge=False, box=None)
    table.add_column("Module", style="cyan", no_wrap=True, min_width=35)
    table.add_column("Role", style="yellow", min_width=8)
    table.add_column("Bar", no_wrap=True)
    table.add_column("Effect", justify="right", style="bold")

    for r in results:
        # Bars are scaled relative to the largest effect in this result set.
        fill = int(bar_width * r["effect"] / max_effect) if max_effect > 0 else 0
        bar = "█" * fill
        # Role tags look like "[attn]"; escape them so Rich prints them
        # literally instead of consuming them as markup tags.
        role = escape(_ROLE_TAGS.get(r.get("role") or "", ""))
        table.add_row(r["module"], role, f"[green]{bar}[/green]", f"{r['effect']:.3f}")

    console.print(table)

    best = results[0]
    console.print(f"\n Top component: [bold cyan]{best['module']}[/bold cyan] (effect: {best['effect']:.3f})")

    if top_k is not None and scanned < total_modules:
        console.print(
            f" Run with --top-k 0 to scan all {total_modules} modules.\n"
        )
    else:
        console.print()
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# ------------------------------------------------------------------
|
|
119
|
+
# Logit lens rendering
|
|
120
|
+
# ------------------------------------------------------------------
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def render_lens(
    predictions: list[dict[str, Any]],
    model_name: str,
) -> None:
    """Print logit lens top predictions per layer.

    Args:
        predictions: One dict per layer with keys ``"layer_name"``,
            ``"top1_token"``, ``"top1_prob"``, ``"top5_tokens"`` and
            ``"top5_probs"``.
        model_name: Name shown in the title.
    """
    # Local import keeps this block self-contained (rich is already a dependency).
    from rich.markup import escape

    console.print(f"\n[bold]Logit Lens: {model_name}[/bold]")

    table = Table(show_header=True, header_style="bold", show_lines=False)
    table.add_column("Layer", style="cyan")
    table.add_column("Top-1 Token", style="bold")
    table.add_column("Prob", justify="right")
    table.add_column("Top-5 Tokens", style="dim")

    for pred in predictions:
        # Tokens are arbitrary vocabulary strings (e.g. "[/INST]", "[pad]").
        # Unescaped, Rich would read them as markup tags and either drop the
        # text or raise MarkupError.
        top5_str = ", ".join(
            escape(f"{tok} ({prob:.2f})")
            for tok, prob in zip(pred["top5_tokens"], pred["top5_probs"])
        )
        top1 = pred["top1_token"]
        table.add_row(
            pred["layer_name"],
            escape(top1) if isinstance(top1, str) else top1,
            f"{pred['top1_prob']:.3f}",
            top5_str,
        )

    console.print(table)
    console.print()
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# ------------------------------------------------------------------
|
|
152
|
+
# Attribution rendering
|
|
153
|
+
# ------------------------------------------------------------------
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def render_attribution_tokens(
    tokens: list[str],
    scores: list[float],
) -> None:
    """Print tokens coloured by attribution score, then a ranked top-10 list.

    Colour is chosen from ``abs(score)`` relative to the largest absolute
    score: above 0.7 bold red, above 0.4 yellow, above 0.15 dim, otherwise
    unstyled.

    Args:
        tokens: Token strings, paired positionally with ``scores``.
        scores: Attribution score per token; nothing but a notice is printed
            when this is empty.
    """
    # Local import keeps this block self-contained (rich is already a dependency).
    from rich.markup import escape

    console.print("\n[bold]Attribution (gradient saliency)[/bold]")

    if not scores:
        console.print(" No attribution scores computed.")
        return

    max_score = max(abs(s) for s in scores)

    parts: list[str] = []
    for tok, score in zip(tokens, scores):
        intensity = abs(score) / max_score if max_score > 0 else 0
        # Tokens may contain bracketed text such as "[/INST]" that Rich would
        # parse as markup (dropping it or raising MarkupError); escape first.
        safe_tok = escape(tok)
        if intensity > 0.7:
            parts.append(f"[bold red]{safe_tok}[/bold red]")
        elif intensity > 0.4:
            parts.append(f"[yellow]{safe_tok}[/yellow]")
        elif intensity > 0.15:
            parts.append(f"[dim]{safe_tok}[/dim]")
        else:
            parts.append(safe_tok)

    console.print(" " + "".join(parts))

    # Also show ranked list
    ranked = sorted(zip(tokens, scores), key=lambda x: abs(x[1]), reverse=True)
    console.print()
    for tok, score in ranked[:10]:
        bar_len = int(20 * abs(score) / max_score) if max_score > 0 else 0
        bar = "█" * bar_len
        # Pad before escaping so the added backslashes do not affect alignment.
        label = escape(f"{tok:>15s}")
        console.print(f" {label} [green]{bar}[/green] {score:.4f}")

    console.print()
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def render_attribution_heatmap(
    attribution: torch.Tensor,
    output_path: str = "attribution_heatmap.png",
) -> None:
    """Save a vision attribution heatmap to a file.

    Args:
        attribution: Attribution tensor. A 3-D tensor is averaged over axis 0;
            a 4-D tensor has its first element taken and is then averaged
            over axis 0 (presumably (C, H, W) and (N, C, H, W) layouts --
            confirm with the caller). Other ranks are passed to ``imshow``
            unchanged.
        output_path: Destination image path handed to ``Figure.savefig``.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Cast to float32 before converting: NumPy has no bfloat16 dtype, so
    # ``Tensor.numpy()`` raises for bfloat16 attributions otherwise.
    attr_np = attribution.detach().float().cpu().numpy()

    # Collapse channel dim if present
    if attr_np.ndim == 3:
        attr_np = attr_np.mean(axis=0)
    elif attr_np.ndim == 4:
        attr_np = attr_np[0].mean(axis=0)

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    # Magnitude only: the sign of the attribution is discarded for display.
    im = ax.imshow(np.abs(attr_np), cmap="hot", interpolation="bilinear")
    ax.set_title("Gradient Attribution")
    ax.axis("off")
    fig.colorbar(im, ax=ax, fraction=0.046)
    fig.savefig(output_path, bbox_inches="tight", dpi=150)
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)

    console.print(f"\n Attribution heatmap saved to [bold]{output_path}[/bold]\n")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
# ------------------------------------------------------------------
|
|
222
|
+
# Patch rendering
|
|
223
|
+
# ------------------------------------------------------------------
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def render_patch(result: dict[str, Any]) -> None:
    """Report the outcome of one activation patch on the console.

    Args:
        result: Mapping with the patched ``"module"`` name and its numeric
            ``"effect"``.
    """
    heading = f"\n[bold]Activation Patch at: {result['module']}[/bold]"
    console.print(heading)

    detail = f" Normalised effect: [bold]{result['effect']:.4f}[/bold]"
    console.print(detail)

    # Trailing blank line separates this report from later output.
    console.print()
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# ------------------------------------------------------------------
|
|
234
|
+
# Activations rendering
|
|
235
|
+
# ------------------------------------------------------------------
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def render_activations(cache: dict[str, torch.Tensor]) -> None:
    """Print a summary table of extracted activations.

    For every cached tensor the table shows its shape plus norm, mean,
    standard deviation, minimum and maximum computed on a float32 copy.

    Args:
        cache: Mapping of module name to the activation tensor captured there.
    """
    console.print("\n[bold]Activations[/bold]")

    table = Table(show_header=True, header_style="bold", show_lines=False)
    table.add_column("Module", style="cyan", no_wrap=True)
    table.add_column("Shape", style="dim")
    table.add_column("Norm", justify="right")
    table.add_column("Mean", justify="right")
    table.add_column("Std", justify="right")
    table.add_column("Min", justify="right")
    table.add_column("Max", justify="right")

    for name, tensor in cache.items():
        shape_str = str(tuple(tensor.shape))
        if tensor.numel() == 0:
            # min()/max() raise on zero-element tensors; show placeholders
            # instead of aborting the whole table.
            table.add_row(name, shape_str, "-", "-", "-", "-", "-")
            continue

        # Statistics are computed in float32 regardless of the stored dtype.
        t = tensor.float()
        table.add_row(
            name,
            shape_str,
            f"{t.norm():.3f}",
            f"{t.mean():.4f}",
            f"{t.std():.4f}",
            f"{t.min():.4f}",
            f"{t.max():.4f}",
        )

    console.print(table)
    console.print()
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
# ------------------------------------------------------------------
|
|
268
|
+
# Ablation rendering
|
|
269
|
+
# ------------------------------------------------------------------
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def render_ablate(result: dict[str, Any]) -> None:
    """Report the outcome of one ablation on the console.

    Args:
        result: Mapping with the ablated ``"module"`` name, its numeric
            ``"effect"`` and an optional ``"method"`` (shown as ``zero`` when
            the key is absent).
    """
    ablation_kind = result.get("method", "zero")

    heading = f"\n[bold]Ablation ({ablation_kind}) at: {result['module']}[/bold]"
    console.print(heading)

    detail = f" Effect on output: [bold]{result['effect']:.4f}[/bold]"
    console.print(detail)

    # Trailing blank line separates this report from later output.
    console.print()
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
# ------------------------------------------------------------------
|
|
281
|
+
# Attention rendering
|
|
282
|
+
# ------------------------------------------------------------------
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def render_attention(
    attention_data: list[dict[str, Any]],
    tokens: list[str] | None,
    model_name: str,
) -> None:
    """Print attention summary per layer/head.

    Args:
        attention_data: One dict per head with keys ``"layer"`` and
            ``"head"``, plus optional ``"top_pairs"`` (triples of source
            index, target index and score; only the first three are shown)
            and ``"entropy"`` (shown as 0.00 when absent).
        tokens: Token strings used to label positions. When ``None``/empty or
            when an index is past the end, the numeric index is shown instead.
        model_name: Name shown in the title.
    """
    # Local import keeps this block self-contained (rich is already a dependency).
    from rich.markup import escape

    def _label(position: Any) -> str:
        """Return the escaped token text for *position*, or the index itself."""
        if tokens and position < len(tokens):
            # Tokens can contain bracketed text that Rich would otherwise
            # interpret as markup inside the table cell.
            return escape(str(tokens[position]))
        return str(position)

    console.print(f"\n[bold]Attention Patterns: {model_name}[/bold]")

    if not attention_data:
        console.print(" No attention data captured.")
        return

    table = Table(show_header=True, header_style="bold", show_lines=False)
    table.add_column("Layer", style="cyan")
    table.add_column("Head", style="cyan", justify="right")
    table.add_column("Top Attention", style="dim")
    table.add_column("Entropy", justify="right")

    for entry in attention_data:
        top_attn_parts = [
            f"{_label(src)}->{_label(tgt)} ({score:.2f})"
            for src, tgt, score in entry.get("top_pairs", [])[:3]
        ]
        table.add_row(
            str(entry["layer"]),
            str(entry["head"]),
            ", ".join(top_attn_parts),
            f"{entry.get('entropy', 0.0):.2f}",
        )

    console.print(table)
    console.print()
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
# ------------------------------------------------------------------
|
|
321
|
+
# Steering rendering
|
|
322
|
+
# ------------------------------------------------------------------
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
def render_steer(
    original_tokens: list[tuple[str, float]],
    steered_tokens: list[tuple[str, float]],
    module_name: str,
    scale: float,
) -> None:
    """Print side-by-side comparison of top tokens with and without steering.

    At most ten ranks are shown. When one list is shorter than the other its
    cells are left blank for the missing ranks.

    Args:
        original_tokens: ``(token, probability)`` pairs from the unsteered run.
        steered_tokens: ``(token, probability)`` pairs from the steered run.
        module_name: Module where the steering vector was applied.
        scale: Steering scale shown in the title.
    """
    # Local import keeps this block self-contained (rich is already a dependency).
    from rich.markup import escape

    def _cells(ranked: list[tuple[str, float]], rank: int) -> tuple[Any, str]:
        """Return the (token, probability) cell contents for *rank*, blank if absent."""
        if rank >= len(ranked):
            return "", ""
        token, prob = ranked[rank][0], ranked[rank][1]
        # Tokens are arbitrary vocabulary strings; escape them so Rich does
        # not interpret bracketed text as markup inside the table cell.
        safe_token = escape(token) if isinstance(token, str) else token
        return safe_token, f"{prob:.3f}"

    console.print(f"\n[bold]Steering at: {module_name} (scale={scale})[/bold]")

    table = Table(show_header=True, header_style="bold", show_lines=False)
    table.add_column("Rank", justify="right", style="dim")
    table.add_column("Original Token", style="cyan")
    table.add_column("Prob", justify="right")
    table.add_column("Steered Token", style="green")
    table.add_column("Prob", justify="right")

    n = max(len(original_tokens), len(steered_tokens))
    for i in range(min(n, 10)):
        orig_tok, orig_prob = _cells(original_tokens, i)
        steer_tok, steer_prob = _cells(steered_tokens, i)
        table.add_row(str(i + 1), orig_tok, orig_prob, steer_tok, steer_prob)

    console.print(table)
    console.print()
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
# ------------------------------------------------------------------
|
|
354
|
+
# Probe rendering
|
|
355
|
+
# ------------------------------------------------------------------
|
|
356
|
+
|
|
357
|
+
|
|
358
|
+
def render_probe(result: dict[str, Any]) -> None:
    """Print probe results: accuracy and top features.

    Args:
        result: Mapping with ``"module"`` and ``"accuracy"``, plus optional
            ``"train_accuracy"`` and ``"top_features"`` (pairs of integer
            dimension index and weight; only the first ten are shown).
    """
    console.print(f"\n[bold]Linear Probe at: {result['module']}[/bold]")
    console.print(f" Accuracy: [bold]{result['accuracy']:.3f}[/bold]")

    if result.get("train_accuracy") is not None:
        console.print(f" Train accuracy: {result['train_accuracy']:.3f}")

    top_features = result.get("top_features")
    if top_features:
        console.print("\n Top features by weight magnitude:")
        shown = top_features[:10]
        # Scale bars by the largest magnitude actually displayed rather than
        # assuming the first entry is the largest; this keeps every bar within
        # 20 cells and avoids all-zero bars when the first weight is 0.
        largest = max(abs(weight) for _, weight in shown)
        for idx, weight in shown:
            bar_len = int(20 * abs(weight) / largest) if largest > 0 else 0
            bar = "█" * bar_len
            console.print(f" dim {idx:>5d} [green]{bar}[/green] {weight:.4f}")

    console.print()
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
# ------------------------------------------------------------------
|
|
377
|
+
# Diff rendering
|
|
378
|
+
# ------------------------------------------------------------------
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def render_diff(
    results: list[dict[str, Any]],
    model_a_name: str,
    model_b_name: str,
) -> None:
    """Show, as a bar chart, how far each module's activations differ between two models.

    Args:
        results: One dict per module with keys ``"module"`` and
            ``"distance"``. The first entry is reported as the most changed
            module.
        model_a_name: Name of the first model, shown in the title.
        model_b_name: Name of the second model, shown in the title.
    """
    console.print(f"\n[bold]Model Diff: {model_a_name} vs {model_b_name}[/bold]")

    # Guard clause: nothing to chart.
    if not results:
        console.print(" No differences computed.")
        return

    largest = max(entry["distance"] for entry in results)
    width = 30

    def _bar(distance: Any) -> str:
        """Return the green bar markup for one distance, scaled to ``largest``."""
        cells = int(width * distance / largest) if largest > 0 else 0
        return f"[green]{'█' * cells}[/green]"

    grid = Table(show_header=False, show_lines=False, pad_edge=False, box=None)
    grid.add_column("Module", style="cyan", no_wrap=True, min_width=35)
    grid.add_column("Bar", no_wrap=True)
    grid.add_column("Cosine Dist", justify="right", style="bold")

    for entry in results:
        bar_markup = _bar(entry["distance"])
        grid.add_row(entry["module"], bar_markup, f"{entry['distance']:.4f}")

    console.print(grid)

    leader = results[0]
    console.print(
        f"\n Most changed: [bold cyan]{leader['module']}[/bold cyan] (distance: {leader['distance']:.4f})"
    )

    console.print()
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
# ------------------------------------------------------------------
|
|
416
|
+
# SAE features rendering
|
|
417
|
+
# ------------------------------------------------------------------
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def render_features(result: dict[str, Any]) -> None:
    """Show a sparse-autoencoder decomposition: summary line plus a table of top features.

    Args:
        result: Mapping with ``"module"``, ``"num_active_features"``,
            ``"total_features"``, ``"sparsity"`` and
            ``"reconstruction_error"``, plus optional ``"top_features"``
            (pairs of feature index and activation value).
    """
    console.print(f"\n[bold]SAE Features at: {result['module']}[/bold]")

    summary = (
        f" Active features: [bold]{result['num_active_features']}[/bold] / {result['total_features']} "
        f"| Sparsity: {result['sparsity']:.2%} "
        f"| Reconstruction error: {result['reconstruction_error']:.4f}"
    )
    console.print(summary)

    features = result.get("top_features", [])
    if not features:
        console.print(" No active features found.")
        console.print()
        return

    # Bars are scaled against the largest absolute activation in the list.
    peak = max(abs(activation) for _, activation in features)
    width = 25

    grid = Table(show_header=True, header_style="bold", show_lines=False)
    grid.add_column("Rank", justify="right", style="dim")
    grid.add_column("Feature", style="cyan", justify="right")
    grid.add_column("Activation", justify="right")
    grid.add_column("Bar", no_wrap=True)

    for position, (feature_id, activation) in enumerate(features, start=1):
        cells = int(width * abs(activation) / peak) if peak > 0 else 0
        grid.add_row(
            str(position),
            str(feature_id),
            f"{activation:.4f}",
            f"[green]{'█' * cells}[/green]",
        )

    console.print(grid)
    console.print()
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
# ------------------------------------------------------------------
|
|
454
|
+
# Helpers
|
|
455
|
+
# ------------------------------------------------------------------
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def _format_params(n: int) -> str:
    """Format a parameter count with a K/M/B suffix and one decimal place.

    Counts below 1000 are returned unchanged as a plain integer string.

    Args:
        n: Parameter count.

    Returns:
        A compact string such as ``"999"``, ``"1.5K"``, ``"124.4M"`` or
        ``"7.0B"``.
    """
    if n < 1_000:
        return str(n)

    text = ""
    suffix = ""
    for suffix, scale in (("K", 1_000), ("M", 1_000_000), ("B", 1_000_000_000)):
        text = f"{n / scale:.1f}"
        # Rounding to one decimal can push a value up to "1000.0" (e.g.
        # 999_999 -> "1000.0K"); in that case move on to the next unit so the
        # result reads "1.0M". "B" is the largest unit, so it is used as-is.
        if float(text) < 1_000:
            break
    return f"{text}{suffix}"
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
"""TransformerLens interop — bidirectional name translation between native and TL hook names."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from interpkit.core.discovery import ModelArchInfo
|
|
10
|
+
|
|
11
|
+
# ── TL canonical hook names ──────────────────────────────────────
|
|
12
|
+
#
|
|
13
|
+
# TL standardizes all transformer architectures to:
|
|
14
|
+
# blocks.{N}.hook_resid_pre
|
|
15
|
+
# blocks.{N}.hook_resid_post
|
|
16
|
+
# blocks.{N}.attn.hook_q / hook_k / hook_v / hook_z / hook_result
|
|
17
|
+
# blocks.{N}.attn.hook_pattern
|
|
18
|
+
# blocks.{N}.hook_attn_out
|
|
19
|
+
# blocks.{N}.hook_mlp_out
|
|
20
|
+
# blocks.{N}.mlp.hook_pre / hook_post
|
|
21
|
+
# blocks.{N}.ln1 / ln2
|
|
22
|
+
#
|
|
23
|
+
# These map to native names like:
|
|
24
|
+
# transformer.h.{N} -> blocks.{N}
|
|
25
|
+
# transformer.h.{N}.attn -> blocks.{N}.attn
|
|
26
|
+
# transformer.h.{N}.mlp -> blocks.{N}.mlp
|
|
27
|
+
# model.layers.{N}.self_attn -> blocks.{N}.attn
|
|
28
|
+
# model.layers.{N}.mlp -> blocks.{N}.mlp
|
|
29
|
+
|
|
30
|
+
# Pattern for extracting the layer index and component from native names.
#   prefix - everything before the first index (e.g. "transformer.h")
#   idx    - the first purely numeric component, introduced by "." or "["
#   suffix - whatever follows the index and its trailing "." / "]" delimiters
# The lookahead requires the digits to be a complete component (followed by
# ".", "]" or end of string), so a name such as "features.2d_conv" is not
# mistaken for layer 2 with suffix "d_conv".
_NATIVE_LAYER_RE = re.compile(
    r"^(?P<prefix>.+?)[.\[](?P<idx>\d+)(?=[.\]]|$)[.\]]*(?P<suffix>.*)$"
)
|
|
34
|
+
|
|
35
|
+
# Native component suffix -> TL suffix. Each entry pairs a case-insensitive
# regex, anchored at the end of the dotted native suffix, with the TL suffix
# to emit. Entries are tried in list order.
_COMPONENT_TO_TL: list[tuple[re.Pattern, str]] = [
    (re.compile(native_regex, re.I), tl_suffix)
    for native_regex, tl_suffix in (
        (r"\.self_attn$|\.attn$|\.attention$", ".attn"),
        (r"\.mlp$|\.ffn$|\.feed_forward$", ".mlp"),
        (r"\.ln_?1$|\.input_layernorm$", ".ln1"),
        (r"\.ln_?2$|\.post_attention_layernorm$", ".ln2"),
        (r"\.self_attn\.q_proj$|\.attn\.q_proj$", ".attn.hook_q"),
        (r"\.self_attn\.k_proj$|\.attn\.k_proj$", ".attn.hook_k"),
        (r"\.self_attn\.v_proj$|\.attn\.v_proj$", ".attn.hook_v"),
        (r"\.self_attn\.o_proj$|\.attn\.c_proj$|\.attn\.out_proj$", ".attn.hook_result"),
    )
]
|
|
46
|
+
|
|
47
|
+
# TL suffix -> candidate native component names, in order of preference.
# Written as a mapping for readability and flattened to the list of
# (tl_suffix, candidates) pairs that callers iterate over.
_TL_TO_COMPONENT: list[tuple[str, list[str]]] = list(
    {
        ".attn": ["attn", "self_attn", "attention"],
        ".mlp": ["mlp", "ffn", "feed_forward"],
        ".ln1": ["ln_1", "ln1", "input_layernorm"],
        ".ln2": ["ln_2", "ln2", "post_attention_layernorm"],
        ".attn.hook_q": ["attn.q_proj", "self_attn.q_proj"],
        ".attn.hook_k": ["attn.k_proj", "self_attn.k_proj"],
        ".attn.hook_v": ["attn.v_proj", "self_attn.v_proj"],
        ".attn.hook_result": ["attn.c_proj", "attn.out_proj", "self_attn.o_proj"],
    }.items()
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def to_tl_name(native_name: str, arch_info: "ModelArchInfo | None" = None) -> str:
    """Map a native PyTorch module name onto its TransformerLens-style name.

    The first numeric component of ``native_name`` becomes the block index.
    A recognised component suffix is replaced by its TL equivalent; any other
    suffix is kept verbatim under ``blocks.{N}``. Names with no numeric
    component are returned unchanged. ``arch_info`` is accepted but not read.

    Examples::

        to_tl_name("transformer.h.8.mlp") -> "blocks.8.mlp"
        to_tl_name("transformer.h.8.attn") -> "blocks.8.attn"
        to_tl_name("model.layers.3.self_attn.q_proj") -> "blocks.3.attn.hook_q"
    """
    parsed = _NATIVE_LAYER_RE.match(native_name)
    if parsed is None:
        return native_name

    block_name = f"blocks.{parsed.group('idx')}"
    component = parsed.group("suffix").lstrip(".")

    # Bare layer reference (e.g. "transformer.h.8").
    if not component:
        return block_name

    # The mapping patterns expect a leading "." on the suffix.
    dotted_component = f".{component}"
    mapped = next(
        (tl_suffix for pattern, tl_suffix in _COMPONENT_TO_TL if pattern.search(dotted_component)),
        None,
    )
    if mapped is not None:
        return f"{block_name}{mapped}"

    # Unknown component: keep it as-is beneath the block.
    return f"{block_name}.{component}"
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def to_native_name(
    tl_name: str,
    arch_info: "ModelArchInfo | None" = None,
) -> str:
    """Map a TransformerLens-style hook name back to a likely native module name.

    Only names of the form ``blocks.{N}`` or ``blocks.{N}.<suffix>`` are
    translated; anything else is returned unchanged. The native layer prefix
    comes from ``_infer_native_prefix(arch_info)``. For a known TL suffix the
    first candidate present in ``arch_info.modules`` is used, falling back to
    the first candidate in the table. For an unknown suffix every ``hook_``
    substring is removed.

    Examples::

        to_native_name("blocks.8.mlp", arch_info) -> "transformer.h.8.mlp"
        to_native_name("blocks.3.attn.hook_q", arch_info) -> "transformer.h.3.attn.q_proj"
    """
    parsed = re.match(r"^blocks\.(\d+)(?:\.(.+))?$", tl_name)
    if parsed is None:
        return tl_name

    layer_index = parsed.group(1)
    tl_suffix = parsed.group(2) or ""

    layer_path = f"{_infer_native_prefix(arch_info)}.{layer_index}"
    if not tl_suffix:
        return layer_path

    # Find the candidate list for this exact TL suffix, if there is one.
    candidates = next(
        (
            natives
            for tl_component, natives in _TL_TO_COMPONENT
            if tl_component.lstrip(".") == tl_suffix
        ),
        None,
    )

    if candidates is None:
        # Unknown suffix: drop the TL-specific "hook_" marker and keep the rest.
        stripped = re.sub(r"hook_", "", tl_suffix)
        return f"{layer_path}.{stripped}"

    if arch_info is not None:
        # Prefer a candidate that exists in the loaded model.
        known_names = {module.name for module in arch_info.modules}
        for native in candidates:
            resolved = f"{layer_path}.{native}"
            if resolved in known_names:
                return resolved

    return f"{layer_path}.{candidates[0]}"
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def list_tl_hooks(model: Any) -> list[str]:
    """List all TL hook point names on a HookedTransformer.

    The names come from ``model.hook_dict`` when that attribute exists and is
    not ``None``. Otherwise ``model.named_modules()`` is scanned for modules
    whose class is named ``HookPoint``.

    Args:
        model: Any object; it does not have to be a HookedTransformer.

    Returns:
        The sorted hook names, or an empty list when the model exposes
        neither ``hook_dict`` nor a callable ``named_modules``.
    """
    hook_dict = getattr(model, "hook_dict", None)
    if hook_dict is not None:
        return sorted(hook_dict.keys())

    # Fallback: look for HookPoint modules. Guard the attribute so objects
    # that are not torch modules yield [] (as documented) instead of raising
    # AttributeError.
    named_modules = getattr(model, "named_modules", None)
    if not callable(named_modules):
        return []

    # Match on the class name so TransformerLens does not have to be imported.
    return sorted(
        name for name, mod in named_modules() if type(mod).__name__ == "HookPoint"
    )
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def _infer_native_prefix(arch_info: "ModelArchInfo | None") -> str:
    """Work out the dotted path under which indexed layers live.

    Returns ``"blocks"`` when ``arch_info`` is ``None`` or nothing matches.
    Otherwise the first entry of ``arch_info.layer_names`` is tried, then
    each module name in order; the first name ending in ``.<digits>`` wins
    and is returned with that trailing index removed (e.g. ``transformer.h``).
    """
    default_prefix = "blocks"
    if arch_info is None:
        return default_prefix

    def _strip_index(name: str) -> "str | None":
        """Return *name* without its trailing ``.<digits>``, or None if it has none."""
        hit = re.match(r"^(.+?)\.\d+$", name)
        return hit.group(1) if hit else None

    # Preferred source: the first recorded layer name.
    if arch_info.layer_names:
        from_layers = _strip_index(arch_info.layer_names[0])
        if from_layers:
            return from_layers

    # Otherwise take the first module whose name ends in an index.
    for module in arch_info.modules:
        from_module = _strip_index(module.name)
        if from_module:
            return from_module

    return default_prefix
|
|
File without changes
|