interpkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
interpkit/core/plot.py ADDED
@@ -0,0 +1,352 @@
1
+ """Matplotlib visualizations — publication-quality figures for mech interp results."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import torch
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.colors as mcolors
12
+ from rich.console import Console
13
+
14
# Shared rich console used to report where each figure was written.
console = Console()

# ── Shared style ─────────────────────────────────────────────────

# Colour roles of the dark theme; every plot in this module reads from here.
_PALETTE = dict(
    bg="#1a1a2e",
    surface="#16213e",
    primary="#0f3460",
    accent="#e94560",
    text="#eaeaea",
    muted="#888888",
    grid="#2a2a4a",
)

# Applied at import time: this updates matplotlib's process-wide rcParams,
# so figures created after importing this module pick up the theme.
plt.rcParams.update(
    {
        # rcParams colour keys, each resolved through its palette role.
        **{
            rc_key: _PALETTE[role]
            for rc_key, role in (
                ("figure.facecolor", "bg"),
                ("axes.facecolor", "surface"),
                ("axes.edgecolor", "grid"),
                ("axes.labelcolor", "text"),
                ("text.color", "text"),
                ("xtick.color", "muted"),
                ("ytick.color", "muted"),
                ("grid.color", "grid"),
            )
        },
        "grid.alpha": 0.3,
        "font.family": "monospace",
        "font.size": 10,
    }
)
41
+
42
+
43
def _save_and_show(fig: plt.Figure, path: str | None, default_name: str) -> str:
    """Save *fig* to disk, close it, and print where it went.

    Despite the name, nothing is displayed interactively: the figure is
    written to a file and then closed.

    Parameters
    ----------
    fig:
        The figure to write.
    path:
        Destination file. When falsy (``None`` or ``""``), *default_name*
        is used instead.
    default_name:
        Fallback file name, resolved against the current working directory.

    Returns
    -------
    str
        The path that was passed to ``fig.savefig``.

    Raises
    ------
    OSError
        Propagated from directory creation or from ``savefig`` when the
        destination cannot be written.
    """
    out = path or default_name
    try:
        # A nested ``save_path`` whose directory does not exist yet would
        # otherwise make ``savefig`` fail with ``FileNotFoundError``.
        Path(out).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out, bbox_inches="tight", dpi=150, facecolor=fig.get_facecolor())
    finally:
        # Close even when saving fails, so a failed save does not leave the
        # figure open in pyplot.
        plt.close(fig)
    console.print(f" Saved to [bold]{out}[/bold]")
    return out
49
+
50
+
51
+ # ── Attention heatmap ────────────────────────────────────────────
52
+
53
+
54
def plot_attention(
    weights: torch.Tensor,
    tokens: list[str] | None = None,
    layer: int = 0,
    head: int = 0,
    save_path: str | None = None,
) -> str:
    """Render one attention head as a heatmap image and save it.

    Parameters
    ----------
    weights:
        Attention matrix, described by the original author as
        ``(seq_len, seq_len)``. Its first dimension drives the figure size.
        Rows are labelled as queries and columns as keys.
    tokens:
        Optional tick labels. At most ``seq_len`` are used and each is cut
        to 12 characters.
    layer, head:
        Used only for the title and the default file name.
    save_path:
        Output file; defaults to ``attention_L{layer}_H{head}.png``.

    Returns
    -------
    str
        The value returned by ``_save_and_show`` (the saved path).
    """
    matrix = weights.detach().cpu().float().numpy()
    n_positions = matrix.shape[0]

    # Square canvas that grows with the sequence but never drops below 4 in.
    side = max(4, n_positions * 0.6)

    colormap = mcolors.LinearSegmentedColormap.from_list(
        "interpkit", ["#1a1a2e", "#0f3460", "#e94560", "#ffdd57"]
    )

    fig, ax = plt.subplots(figsize=(side, side))
    # The colour scale is pinned to [0, 1] rather than to the data range.
    heat = ax.imshow(matrix, cmap=colormap, aspect="equal", vmin=0, vmax=1)

    if tokens:
        tick_labels = [tok[:12] for tok in tokens[:n_positions]]
        tick_positions = range(len(tick_labels))
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha="right", fontsize=8)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tick_labels, fontsize=8)

    ax.set_xlabel("Key (attends to)", fontsize=9)
    ax.set_ylabel("Query (from)", fontsize=9)
    ax.set_title(f"Attention — Layer {layer}, Head {head}", fontsize=11, fontweight="bold")

    colorbar = fig.colorbar(heat, ax=ax, fraction=0.046, pad=0.04)
    colorbar.ax.tick_params(labelsize=8)
    colorbar.set_label("Attention weight", fontsize=9)

    fig.tight_layout()
    return _save_and_show(fig, save_path, f"attention_L{layer}_H{head}.png")
93
+
94
+
95
def plot_attention_multi(
    attention_data: list[dict[str, Any]],
    tokens: list[str] | None = None,
    save_path: str | None = None,
) -> str:
    """Draw a layers-by-heads grid of small attention heatmaps.

    Parameters
    ----------
    attention_data:
        One dict per head; the keys read here are ``"layer"``, ``"head"``
        and ``"weights"`` (a tensor rendered with ``imshow``).
    tokens:
        Accepted but not used by this function.
    save_path:
        Output file; defaults to ``attention_grid.png``.

    Returns
    -------
    str
        The saved path, or ``""`` when *attention_data* is empty.
    """
    if not attention_data:
        return ""

    # Count the entries of each layer in a single pass.
    entries_per_layer: dict[Any, int] = {}
    for item in attention_data:
        key = item["layer"]
        entries_per_layer[key] = entries_per_layer.get(key, 0) + 1

    layers = sorted(entries_per_layer)
    n_layers = len(layers)
    n_heads = min(max(entries_per_layer.values()), 8)  # cap grid width

    fig, axes = plt.subplots(
        n_layers,
        n_heads,
        figsize=(n_heads * 2.5, n_layers * 2.5),
        squeeze=False,
    )

    colormap = mcolors.LinearSegmentedColormap.from_list(
        "interpkit", ["#1a1a2e", "#0f3460", "#e94560", "#ffdd57"]
    )

    # Hide every cell first; only the cells that receive data are re-enabled.
    for cell in axes.flat:
        cell.axis("off")

    row_of = {layer: idx for idx, layer in enumerate(layers)}
    next_column: dict[int, int] = {}

    for item in attention_data:
        layer = item["layer"]
        col = next_column.get(layer, 0)
        next_column[layer] = col + 1

        # Entries beyond the column cap are skipped, not wrapped.
        if col >= n_heads:
            continue

        cell = axes[row_of[layer]][col]
        matrix = item["weights"].detach().cpu().float().numpy()
        cell.imshow(matrix, cmap=colormap, aspect="equal", vmin=0, vmax=1)
        cell.set_title(f"L{layer} H{item['head']}", fontsize=7, pad=2)
        cell.axis("on")
        cell.set_xticks([])
        cell.set_yticks([])
        cell.spines[:].set_color(_PALETTE["grid"])

    fig.suptitle("Attention Patterns", fontsize=13, fontweight="bold", y=0.98)
    # Reserve the top 5% of the canvas for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    return _save_and_show(fig, save_path, "attention_grid.png")
149
+
150
+
151
+ # ── Causal trace bar chart ───────────────────────────────────────
152
+
153
+
154
def plot_trace(
    results: list[dict[str, Any]],
    model_name: str = "",
    save_path: str | None = None,
) -> str:
    """Draw a horizontal bar chart of causal tracing results.

    Parameters
    ----------
    results:
        One dict per module; the keys read here are ``"module"`` (a dotted
        name, shortened to its last two components for the label) and
        ``"effect"`` (a number). Only the first 25 entries are drawn.
        Presumably the caller passes them sorted by importance; confirm
        against the caller.
    model_name:
        Appended to the title when non-empty.
    save_path:
        Output file; defaults to ``causal_trace.png``.

    Returns
    -------
    str
        The saved path, or ``""`` when *results* is empty.
    """
    if not results:
        return ""

    # barh puts index 0 at the bottom, so reverse to keep results[0] on top.
    shown = list(reversed(results[:25]))
    labels = [".".join(r["module"].split(".")[-2:]) for r in shown]
    effects = [r["effect"] for r in shown]

    # Computed once; the original re-evaluated max() for every bar.
    peak = max(effects)
    low = min(0, min(effects))
    high = max(0, peak)
    label_pad = max(high, -low) * 0.02
    if low == high:
        # Every effect is exactly zero: fall back to a unit axis rather than
        # asking matplotlib for the degenerate range (0, 0).
        x_limits = (0, 1)
    else:
        # Include the negative side so negative effects are not clipped.
        x_limits = (low * 1.15, high * 1.15)

    fig, ax = plt.subplots(figsize=(8, max(3, len(shown) * 0.35)))

    # Every bar that ties for the largest effect gets the accent colour.
    colors = [
        _PALETTE["accent"] if e == peak else _PALETTE["primary"]
        for e in effects
    ]
    bars = ax.barh(range(len(labels)), effects, color=colors, height=0.7, edgecolor="none")

    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=8)
    ax.set_xlabel("Patching Effect", fontsize=10)
    title = f"Causal Trace: {model_name}" if model_name else "Causal Trace"
    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.set_xlim(*x_limits)
    ax.grid(axis="x", alpha=0.2)

    for bar, val in zip(bars, effects):
        # Put the value label just past the bar tip, on the side away from 0.
        points_left = val < 0
        ax.text(
            val - label_pad if points_left else val + label_pad,
            bar.get_y() + bar.get_height() / 2,
            f"{val:.3f}",
            va="center",
            ha="right" if points_left else "left",
            fontsize=7,
            color=_PALETTE["muted"],
        )

    fig.tight_layout()
    return _save_and_show(fig, save_path, "causal_trace.png")
189
+
190
+
191
+ # ── Logit lens heatmap ───────────────────────────────────────────
192
+
193
+
194
def plot_lens(
    predictions: list[dict[str, Any]],
    save_path: str | None = None,
) -> str:
    """Draw a one-column heatmap of each layer's top-1 prediction.

    Each row is one layer: the cell colour encodes the top-1 probability and
    the cell text shows the token together with that probability.

    Parameters
    ----------
    predictions:
        One dict per layer; the keys read here are ``"layer_name"``,
        ``"top1_prob"`` and ``"top1_token"``.
    save_path:
        Output file; defaults to ``logit_lens.png``.

    Returns
    -------
    str
        The saved path, or ``""`` when *predictions* is empty.
    """
    if not predictions:
        return ""

    layer_names: list[Any] = []
    confidences: list[Any] = []
    top_tokens: list[Any] = []
    for pred in predictions:
        layer_names.append(pred["layer_name"])
        confidences.append(pred["top1_prob"])
        top_tokens.append(pred["top1_token"])

    fig, ax = plt.subplots(figsize=(6, max(3, len(layer_names) * 0.4)))

    colormap = mcolors.LinearSegmentedColormap.from_list(
        "interpkit_lens", ["#1a1a2e", "#0f3460", "#28a745", "#ffdd57"]
    )

    # Reshape the probabilities into an (n_layers, 1) column for imshow.
    column = np.array(confidences).reshape(-1, 1)
    heat = ax.imshow(column, cmap=colormap, aspect=0.3, vmin=0, vmax=1)

    ax.set_yticks(range(len(layer_names)))
    ax.set_yticklabels(layer_names, fontsize=8)
    ax.set_xticks([])

    for row, (tok, prob) in enumerate(zip(top_tokens, confidences)):
        # Dark text on the bright (high-probability) end of the colormap,
        # light text otherwise.
        if prob > 0.5:
            label_colour = _PALETTE["bg"]
        else:
            label_colour = _PALETTE["text"]
        ax.text(
            0,
            row,
            f" {tok} ({prob:.2f})",
            ha="center",
            va="center",
            fontsize=8,
            fontweight="bold",
            color=label_colour,
        )

    ax.set_title("Logit Lens — Top-1 per Layer", fontsize=11, fontweight="bold")

    colorbar = fig.colorbar(heat, ax=ax, fraction=0.05, pad=0.04)
    colorbar.set_label("Probability", fontsize=9)
    colorbar.ax.tick_params(labelsize=8)

    fig.tight_layout()
    return _save_and_show(fig, save_path, "logit_lens.png")
233
+
234
+
235
+ # ── Steering comparison ─────────────────────────────────────────
236
+
237
+
238
def plot_steer(
    original_tokens: list[tuple[str, float]],
    steered_tokens: list[tuple[str, float]],
    module_name: str = "",
    scale: float = 1.0,
    save_path: str | None = None,
) -> str:
    """Plot original and steered top-token probabilities side by side.

    Parameters
    ----------
    original_tokens, steered_tokens:
        ``(token, probability)`` pairs. At most the first 10 of each are
        drawn, truncated to the shorter of the two lists.
    module_name:
        Shown in the figure title.
    scale:
        Shown in the title of the steered panel.
    save_path:
        Output file; defaults to ``steering.png``.

    Returns
    -------
    str
        The saved path, or ``""`` when either list is empty.
    """
    n = min(10, len(original_tokens), len(steered_tokens))
    if n == 0:
        return ""

    before = original_tokens[:n]
    after = steered_tokens[:n]
    labels_orig = [pair[0].strip() for pair in before]
    probs_orig = [pair[1] for pair in before]
    labels_steer = [pair[0].strip() for pair in after]
    probs_steer = [pair[1] for pair in after]

    # Both panels share one x-range so bar lengths are directly comparable.
    x_upper = max(probs_orig + probs_steer) * 1.2

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, max(3, n * 0.4)), sharey=False)

    panels = (
        (ax1, labels_orig, probs_orig, "#0f3460", "Original"),
        (ax2, labels_steer, probs_steer, _PALETTE["accent"], f"Steered (scale={scale})"),
    )
    for axis, tick_labels, probs, colour, heading in panels:
        axis.barh(range(n), probs, color=colour, height=0.6)
        axis.set_yticks(range(n))
        axis.set_yticklabels(tick_labels, fontsize=9)
        axis.set_xlabel("Probability", fontsize=9)
        axis.set_title(heading, fontsize=11, fontweight="bold")
        # Inverted so the first token in each list is the top bar.
        axis.invert_yaxis()
        axis.set_xlim(0, x_upper)
        axis.grid(axis="x", alpha=0.2)

    fig.suptitle(f"Steering at {module_name}", fontsize=12, fontweight="bold", y=1.0)
    fig.tight_layout()
    return _save_and_show(fig, save_path, "steering.png")
280
+
281
+
282
+ # ── Diff bar chart ───────────────────────────────────────────────
283
+
284
+
285
def plot_diff(
    results: list[dict[str, Any]],
    model_a_name: str = "A",
    model_b_name: str = "B",
    save_path: str | None = None,
) -> str:
    """Draw a horizontal bar chart of per-module distance between two models.

    Parameters
    ----------
    results:
        One dict per module; the keys read here are ``"module"`` (a dotted
        name, shortened to its last two components for the label) and
        ``"distance"`` (a number, labelled "Cosine Distance" on the axis).
        Only the first 25 entries are drawn.
    model_a_name, model_b_name:
        Shown in the title.
    save_path:
        Output file; defaults to ``model_diff.png``.

    Returns
    -------
    str
        The saved path, or ``""`` when *results* is empty.
    """
    if not results:
        return ""

    # barh puts index 0 at the bottom, so reverse to keep results[0] on top.
    shown = list(reversed(results[:25]))
    labels = [".".join(r["module"].split(".")[-2:]) for r in shown]
    distances = [r["distance"] for r in shown]

    # Computed once; the original re-evaluated max() for every bar.
    peak = max(distances)
    # Identical models give all-zero distances. Fall back to a unit axis
    # rather than asking matplotlib for the degenerate range (0, 0).
    x_upper = peak * 1.15 if peak > 0 else 1
    label_pad = max(peak, 0) * 0.02

    fig, ax = plt.subplots(figsize=(8, max(3, len(shown) * 0.35)))

    # Every bar that ties for the largest distance gets the accent colour.
    colors = [
        _PALETTE["accent"] if d == peak else _PALETTE["primary"]
        for d in distances
    ]
    bars = ax.barh(range(len(labels)), distances, color=colors, height=0.7)

    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=8)
    ax.set_xlabel("Cosine Distance", fontsize=10)
    ax.set_title(f"Model Diff: {model_a_name} vs {model_b_name}", fontsize=12, fontweight="bold")
    ax.set_xlim(0, x_upper)
    ax.grid(axis="x", alpha=0.2)

    for bar, val in zip(bars, distances):
        ax.text(
            bar.get_width() + label_pad,
            bar.get_y() + bar.get_height() / 2,
            f"{val:.4f}",
            va="center",
            fontsize=7,
            color=_PALETTE["muted"],
        )

    try:
        fig.tight_layout()
    except ValueError:
        # Best-effort, as in the original: a ValueError from tight_layout is
        # ignored and the figure is saved with whatever layout it has.
        pass
    return _save_and_show(fig, save_path, "model_diff.png")
324
+
325
+
326
+ # ── Attribution (text) bar chart ─────────────────────────────────
327
+
328
+
329
def plot_attribution(
    tokens: list[str],
    scores: list[float],
    save_path: str | None = None,
) -> str:
    """Draw a bar chart of per-token attribution magnitudes.

    Bars show ``abs(score)``. A bar is accented when its magnitude exceeds
    half of the largest magnitude.

    Parameters
    ----------
    tokens:
        Tick labels, one per bar; also sets the number of bar positions and
        the figure width.
    scores:
        Attribution scores; only their absolute values are plotted.
    save_path:
        Output file; defaults to ``attribution.png``.

    Returns
    -------
    str
        The saved path, or ``""`` when *scores* is empty.
    """
    if not scores:
        return ""

    fig, ax = plt.subplots(figsize=(max(4, len(tokens) * 0.6), 4))

    magnitudes = [abs(s) for s in scores]
    largest = max(magnitudes)

    bar_colours = []
    for magnitude in magnitudes:
        # Relative strength in [0, 1]; defined as 0 when every score is 0.
        relative = magnitude / largest if largest > 0 else 0
        if relative > 0.5:
            bar_colours.append(_PALETTE["accent"])
        else:
            bar_colours.append("#0f3460")

    positions = range(len(tokens))
    ax.bar(positions, magnitudes, color=bar_colours, width=0.7)
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=9)
    ax.set_ylabel("Attribution Score", fontsize=10)
    ax.set_title("Gradient Attribution", fontsize=12, fontweight="bold")
    ax.grid(axis="y", alpha=0.2)

    fig.tight_layout()
    return _save_and_show(fig, save_path, "attribution.png")
@@ -0,0 +1,82 @@
1
+ """Manual annotation registry for custom nn.Module models.
2
+
3
+ Usage::
4
+
5
+ import interpkit
6
+
7
+ interpkit.register(
8
+ my_model,
9
+ layers=["blocks.0", "blocks.1", "blocks.2"],
10
+ output_head="head",
11
+ )
12
+
13
+ model = interpkit.load(my_model, tokenizer=my_tokenizer)
14
+ model.trace(tensor_a, tensor_b)
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import weakref
20
+ from dataclasses import dataclass, field
21
+ from typing import TYPE_CHECKING
22
+
23
+ if TYPE_CHECKING:
24
+ import torch.nn as nn
25
+
26
+ _REGISTRY: dict[int, "Registration"] = {}
27
+
28
+
29
+ @dataclass
30
+ class Registration:
31
+ layers: list[str] = field(default_factory=list)
32
+ output_head: str | None = None
33
+ attention_modules: list[str] = field(default_factory=list)
34
+ mlp_modules: list[str] = field(default_factory=list)
35
+
36
+
37
def register(
    model: nn.Module,
    *,
    layers: list[str] | None = None,
    output_head: str | None = None,
    attention_modules: list[str] | None = None,
    mlp_modules: list[str] | None = None,
) -> None:
    """Annotate a custom ``nn.Module`` so interpkit knows its structure.

    The annotation is stored in the module-level registry under
    ``id(model)``; registering the same instance again replaces it.

    Parameters
    ----------
    model:
        The model instance to annotate.
    layers:
        Ordered list of module names that constitute the repeated layer blocks
        (e.g. ``["blocks.0", "blocks.1", ...]``).
    output_head:
        Module name of the output / LM head.
    attention_modules:
        Module names that should be treated as attention.
    mlp_modules:
        Module names that should be treated as MLPs.
    """
    # Copy the sequences so later mutation of the caller's lists cannot
    # silently change the stored registration.
    reg = Registration(
        layers=list(layers or ()),
        output_head=output_head,
        attention_modules=list(attention_modules or ()),
        mlp_modules=list(mlp_modules or ()),
    )
    model_id = id(model)
    _REGISTRY[model_id] = reg

    try:
        # Drop the entry once the model is garbage-collected. The callback
        # holds only the integer id, so it does not keep the model alive.
        weakref.finalize(model, _REGISTRY.pop, model_id, None)
    except TypeError:
        # The object does not support weak references, so no cleanup can be
        # scheduled and the entry stays in the registry.
        pass
78
+
79
+
80
def get_registration(model: nn.Module) -> Registration | None:
    """Look up the annotation stored for *model*.

    Returns
    -------
    Registration | None
        The entry stored under ``id(model)``, or ``None`` when there is none.
    """
    key = id(model)
    return _REGISTRY.get(key)