interpkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interpkit/__init__.py +15 -0
- interpkit/cli/__init__.py +0 -0
- interpkit/cli/main.py +337 -0
- interpkit/core/__init__.py +0 -0
- interpkit/core/discovery.py +228 -0
- interpkit/core/html.py +375 -0
- interpkit/core/inputs.py +117 -0
- interpkit/core/model.py +551 -0
- interpkit/core/plot.py +352 -0
- interpkit/core/registry.py +82 -0
- interpkit/core/render.py +465 -0
- interpkit/core/tl_compat.py +174 -0
- interpkit/ops/__init__.py +0 -0
- interpkit/ops/ablate.py +90 -0
- interpkit/ops/activations.py +67 -0
- interpkit/ops/attention.py +234 -0
- interpkit/ops/attribute.py +206 -0
- interpkit/ops/diff.py +79 -0
- interpkit/ops/inspect.py +14 -0
- interpkit/ops/lens.py +151 -0
- interpkit/ops/patch.py +112 -0
- interpkit/ops/probe.py +128 -0
- interpkit/ops/sae.py +212 -0
- interpkit/ops/steer.py +118 -0
- interpkit/ops/trace.py +182 -0
- interpkit-0.1.0.dist-info/METADATA +295 -0
- interpkit-0.1.0.dist-info/RECORD +31 -0
- interpkit-0.1.0.dist-info/WHEEL +5 -0
- interpkit-0.1.0.dist-info/entry_points.txt +2 -0
- interpkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- interpkit-0.1.0.dist-info/top_level.txt +1 -0
interpkit/core/plot.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
1
|
+
"""Matplotlib visualizations — publication-quality figures for mech interp results."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
import numpy as np
|
|
10
|
+
import matplotlib.pyplot as plt
|
|
11
|
+
import matplotlib.colors as mcolors
|
|
12
|
+
from rich.console import Console
|
|
13
|
+
|
|
14
|
+
console = Console()


# ── Shared style ─────────────────────────────────────────────────

# Dark colour scheme shared by every figure in this module.
_PALETTE = dict(
    bg="#1a1a2e",
    surface="#16213e",
    primary="#0f3460",
    accent="#e94560",
    text="#eaeaea",
    muted="#888888",
    grid="#2a2a4a",
)

# NOTE: ``plt.rcParams`` is matplotlib's process-wide configuration, so this
# update runs at import time and restyles every figure created afterwards in
# the same process, not only the ones produced by this module.
plt.rcParams.update(
    {
        # Backgrounds
        "figure.facecolor": _PALETTE["bg"],
        "axes.facecolor": _PALETTE["surface"],
        # Axes frame, labels and general text
        "axes.edgecolor": _PALETTE["grid"],
        "axes.labelcolor": _PALETTE["text"],
        "text.color": _PALETTE["text"],
        # Tick marks / tick labels
        "xtick.color": _PALETTE["muted"],
        "ytick.color": _PALETTE["muted"],
        # Grid lines
        "grid.color": _PALETTE["grid"],
        "grid.alpha": 0.3,
        # Typography
        "font.family": "monospace",
        "font.size": 10,
    }
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _save_and_show(fig: plt.Figure, path: str | None, default_name: str) -> str:
    """Save *fig*, close it, and return the destination it was written to.

    Despite the name, nothing is displayed: the figure is written with
    ``savefig`` and then closed so it does not accumulate in pyplot's
    figure registry.

    Args:
        fig: The figure to save.
        path: Destination; when falsy (``None`` or ``""``) *default_name*
            is used instead.
        default_name: Fallback file name, relative to the working directory.

    Returns:
        The destination that was passed to ``savefig``.
    """
    out = path or default_name
    try:
        # Create missing parent directories for path-like destinations so
        # e.g. ``figures/trace.png`` works; other targets (file objects) are
        # handed to savefig untouched.
        if isinstance(out, (str, Path)):
            Path(out).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out, bbox_inches="tight", dpi=150, facecolor=fig.get_facecolor())
    finally:
        # Release the figure even when saving fails; otherwise each failed
        # call would leave a live figure registered with pyplot.
        plt.close(fig)
    console.print(f" Saved to [bold]{out}[/bold]")
    return out
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# ── Attention heatmap ────────────────────────────────────────────
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def plot_attention(
    weights: torch.Tensor,
    tokens: list[str] | None = None,
    layer: int = 0,
    head: int = 0,
    save_path: str | None = None,
) -> str:
    """Plot a single attention head as a heatmap and save it.

    Args:
        weights: ``(seq_len, seq_len)`` attention matrix for one head; rows
            are drawn as queries and columns as keys. Leading singleton
            dimensions (e.g. ``(1, 1, seq_len, seq_len)``) are squeezed off.
        tokens: Optional tick labels, truncated to 12 characters each and to
            the first ``seq_len`` entries.
        layer: Layer index shown in the title and default file name.
        head: Head index shown in the title and default file name.
        save_path: Output path; defaults to ``attention_L{layer}_H{head}.png``.

    Returns:
        The path the figure was saved to.

    Raises:
        ValueError: If *weights* is not a 2-D matrix after squeezing.
    """
    attn = weights.detach().cpu().float().numpy()
    # Tolerate batched / per-head views of a single head, e.g. (1, 1, q, k).
    while attn.ndim > 2 and attn.shape[0] == 1:
        attn = attn[0]
    if attn.ndim != 2:
        # Validate before creating the figure so bad input cannot leak one.
        raise ValueError(
            "plot_attention expects a 2-D (seq_len, seq_len) attention matrix, "
            f"got shape {tuple(attn.shape)}"
        )
    seq_len = attn.shape[0]

    # ~0.6in per token keeps labels legible for short prompts, but the side is
    # capped: uncapped, long sequences produce enormous images, and at dpi=150
    # anything past ~437in exceeds Agg's 2**16-pixel-per-side limit.
    max_side_in = 24.0
    side = min(max(4, seq_len * 0.6), max_side_in)
    fig, ax = plt.subplots(figsize=(side, side))

    cmap = mcolors.LinearSegmentedColormap.from_list(
        "interpkit", ["#1a1a2e", "#0f3460", "#e94560", "#ffdd57"]
    )
    im = ax.imshow(attn, cmap=cmap, aspect="equal", vmin=0, vmax=1)

    if tokens:
        labels = [t[:12] for t in tokens[:seq_len]]
        ticks = range(len(labels))
        ax.set_xticks(ticks)
        ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=8)
        ax.set_yticks(ticks)
        ax.set_yticklabels(labels, fontsize=8)

    ax.set_xlabel("Key (attends to)", fontsize=9)
    ax.set_ylabel("Query (from)", fontsize=9)
    ax.set_title(f"Attention — Layer {layer}, Head {head}", fontsize=11, fontweight="bold")

    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.ax.tick_params(labelsize=8)
    cbar.set_label("Attention weight", fontsize=9)

    fig.tight_layout()
    return _save_and_show(fig, save_path, f"attention_L{layer}_H{head}.png")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def plot_attention_multi(
    attention_data: list[dict[str, Any]],
    tokens: list[str] | None = None,
    save_path: str | None = None,
) -> str:
    """Plot a grid of attention heads: one row per layer, one column per head.

    Each entry of *attention_data* supplies ``"layer"``, ``"head"`` and a
    ``"weights"`` tensor. Within a layer, heads fill columns left to right in
    the order they appear. At most 8 columns are drawn, and any further heads
    of a layer are skipped. *tokens* is currently unused.

    Returns:
        The saved path, or ``""`` when *attention_data* is empty.
    """
    if not attention_data:
        return ""

    # One counting pass; the layer with the most heads sets the grid width.
    per_layer: dict[Any, int] = {}
    for item in attention_data:
        per_layer[item["layer"]] = per_layer.get(item["layer"], 0) + 1
    row_of = {layer_id: r for r, layer_id in enumerate(sorted(per_layer))}
    n_rows = len(row_of)
    n_cols = min(max(per_layer.values()), 8)  # cap grid width

    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(n_cols * 2.5, n_rows * 2.5),
        squeeze=False,
    )
    cmap = mcolors.LinearSegmentedColormap.from_list(
        "interpkit", ["#1a1a2e", "#0f3460", "#e94560", "#ffdd57"]
    )

    # Hide every cell up front; cells that receive a head are re-enabled below.
    for cell in axes.flat:
        cell.axis("off")

    next_col: dict[Any, int] = {}
    for item in attention_data:
        layer_id = item["layer"]
        col = next_col.get(layer_id, 0)
        next_col[layer_id] = col + 1
        if col >= n_cols:
            continue

        cell = axes[row_of[layer_id], col]
        cell.imshow(
            item["weights"].detach().cpu().float().numpy(),
            cmap=cmap,
            aspect="equal",
            vmin=0,
            vmax=1,
        )
        cell.set_title(f"L{layer_id} H{item['head']}", fontsize=7, pad=2)
        cell.axis("on")
        cell.set_xticks([])
        cell.set_yticks([])
        cell.spines[:].set_color(_PALETTE["grid"])

    fig.suptitle("Attention Patterns", fontsize=13, fontweight="bold", y=0.98)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    return _save_and_show(fig, save_path, "attention_grid.png")
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# ── Causal trace bar chart ───────────────────────────────────────
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def plot_trace(
    results: list[dict[str, Any]],
    model_name: str = "",
    save_path: str | None = None,
) -> str:
    """Horizontal bar chart of causal tracing results.

    Args:
        results: Entries with a dotted ``"module"`` name and a numeric
            ``"effect"``. Only the first 25 are drawn, with ``results[0]`` at
            the top (presumably callers pass them sorted by effect; TODO confirm).
        model_name: Optional name appended to the title.
        save_path: Output path; defaults to ``causal_trace.png``.

    Returns:
        The saved path, or ``""`` when *results* is empty.
    """
    if not results:
        return ""

    top = results[:25]
    # barh draws index 0 at the bottom, so reverse to put results[0] on top.
    ordered = list(reversed(top))
    # Keep the last two dotted components, e.g. "model.layers.3.mlp" -> "3.mlp".
    labels = [".".join(r["module"].split(".")[-2:]) for r in ordered]
    effects = [r["effect"] for r in ordered]

    best = max(effects)
    # Patching effects can be zero or negative. Derive the axis from the full
    # range so negative bars are not clipped and an all-zero chart does not
    # get a degenerate (0, 0) x-range. For all-positive input this reduces to
    # the original (0, 1.15 * max) limits and 2% label offset.
    lo = min(0, min(effects))
    hi = max(0, best)
    span = (hi - lo) or 1.0
    offset = span * 0.02

    fig, ax = plt.subplots(figsize=(8, max(3, len(top) * 0.35)))

    colors = [_PALETTE["accent"] if e == best else "#0f3460" for e in effects]
    bars = ax.barh(range(len(labels)), effects, color=colors, height=0.7, edgecolor="none")

    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=8)
    ax.set_xlabel("Patching Effect", fontsize=10)
    title = f"Causal Trace: {model_name}" if model_name else "Causal Trace"
    ax.set_title(title, fontsize=12, fontweight="bold")
    ax.set_xlim(lo - span * 0.15 if lo < 0 else 0, hi + span * 0.15)
    if lo < 0:
        # Reference line separating positive from negative effects.
        ax.axvline(0, color=_PALETTE["muted"], linewidth=0.8)
    ax.grid(axis="x", alpha=0.2)

    for bar, val in zip(bars, effects):
        # Bars start at 0, so the width is the signed end position; place the
        # value label just past the end on the bar's own side.
        end = bar.get_width()
        y = bar.get_y() + bar.get_height() / 2
        if val >= 0:
            ax.text(end + offset, y, f"{val:.3f}", va="center",
                    fontsize=7, color=_PALETTE["muted"])
        else:
            ax.text(end - offset, y, f"{val:.3f}", va="center", ha="right",
                    fontsize=7, color=_PALETTE["muted"])

    fig.tight_layout()
    return _save_and_show(fig, save_path, "causal_trace.png")
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# ── Logit lens heatmap ───────────────────────────────────────────
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def plot_lens(
    predictions: list[dict[str, Any]],
    save_path: str | None = None,
) -> str:
    """Render logit-lens results as a one-column heatmap.

    Each entry of *predictions* becomes one row labelled with its
    ``"layer_name"``. The row is coloured by ``"top1_prob"`` on a fixed 0-1
    scale and annotated with ``"top1_token"`` and that probability.

    Returns:
        The saved path, or ``""`` when *predictions* is empty.
    """
    if not predictions:
        return ""

    row_names = [pred["layer_name"] for pred in predictions]
    confidences = [pred["top1_prob"] for pred in predictions]
    top_tokens = [pred["top1_token"] for pred in predictions]

    height = max(3, len(row_names) * 0.4)
    fig, ax = plt.subplots(figsize=(6, height))

    lens_cmap = mcolors.LinearSegmentedColormap.from_list(
        "interpkit_lens", ["#1a1a2e", "#0f3460", "#28a745", "#ffdd57"]
    )

    # Reshape the probabilities into an (n_rows, 1) image, i.e. a single column.
    column = np.array(confidences).reshape(-1, 1)
    im = ax.imshow(column, cmap=lens_cmap, aspect=0.3, vmin=0, vmax=1)

    ax.set_yticks(range(len(row_names)))
    ax.set_yticklabels(row_names, fontsize=8)
    ax.set_xticks([])

    for row, (tok, conf) in enumerate(zip(top_tokens, confidences)):
        # Bright cells (green/yellow half of the colormap) get dark text.
        if conf > 0.5:
            ink = _PALETTE["bg"]
        else:
            ink = _PALETTE["text"]
        ax.text(
            0,
            row,
            f" {tok} ({conf:.2f})",
            ha="center",
            va="center",
            fontsize=8,
            fontweight="bold",
            color=ink,
        )

    ax.set_title("Logit Lens — Top-1 per Layer", fontsize=11, fontweight="bold")

    cbar = fig.colorbar(im, ax=ax, fraction=0.05, pad=0.04)
    cbar.set_label("Probability", fontsize=9)
    cbar.ax.tick_params(labelsize=8)

    fig.tight_layout()
    return _save_and_show(fig, save_path, "logit_lens.png")
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# ── Steering comparison ─────────────────────────────────────────
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def plot_steer(
    original_tokens: list[tuple[str, float]],
    steered_tokens: list[tuple[str, float]],
    module_name: str = "",
    scale: float = 1.0,
    save_path: str | None = None,
) -> str:
    """Side-by-side bar charts of top-token probabilities before/after steering.

    Args:
        original_tokens: ``(token, probability)`` pairs without steering,
            drawn top to bottom in the given order.
        steered_tokens: ``(token, probability)`` pairs with steering applied.
        module_name: Module the steering was applied at; included in the
            figure title when non-empty.
        scale: Steering scale, shown in the right panel's title.
        save_path: Output path; defaults to ``steering.png``.

    Returns:
        The saved path, or ``""`` when either list is empty. Each panel
        shows at most 10 pairs, and never more than the shorter list.
    """
    n = min(10, len(original_tokens), len(steered_tokens))
    if n == 0:
        return ""

    def _label(tok: str) -> str:
        # ``str.strip`` turns whitespace-only tokens (" ", "\n") into blank
        # tick labels; fall back to their repr so they stay distinguishable.
        return tok.strip() or repr(tok)

    labels_orig = [_label(t[0]) for t in original_tokens[:n]]
    probs_orig = [t[1] for t in original_tokens[:n]]
    labels_steer = [_label(t[0]) for t in steered_tokens[:n]]
    probs_steer = [t[1] for t in steered_tokens[:n]]

    # One shared x-range so the two panels are directly comparable. When every
    # probability is 0, use a unit range instead of the degenerate (0, 0).
    peak = max(probs_orig + probs_steer)
    x_max = peak * 1.2 if peak > 0 else 1.0

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, max(3, n * 0.4)), sharey=False)

    panels = (
        (ax1, labels_orig, probs_orig, "#0f3460", "Original"),
        (ax2, labels_steer, probs_steer, _PALETTE["accent"], f"Steered (scale={scale})"),
    )
    for ax, labels, probs, color, title in panels:
        ax.barh(range(n), probs, color=color, height=0.6)
        ax.set_yticks(range(n))
        ax.set_yticklabels(labels, fontsize=9)
        ax.set_xlabel("Probability", fontsize=9)
        ax.set_title(title, fontsize=11, fontweight="bold")
        ax.invert_yaxis()  # first (highest-ranked) token on top
        ax.set_xlim(0, x_max)
        ax.grid(axis="x", alpha=0.2)

    # Avoid a dangling "Steering at " when no module name was supplied.
    suptitle = f"Steering at {module_name}" if module_name else "Steering"
    fig.suptitle(suptitle, fontsize=12, fontweight="bold", y=1.0)
    fig.tight_layout()
    return _save_and_show(fig, save_path, "steering.png")
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
# ── Diff bar chart ───────────────────────────────────────────────
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def plot_diff(
    results: list[dict[str, Any]],
    model_a_name: str = "A",
    model_b_name: str = "B",
    save_path: str | None = None,
) -> str:
    """Horizontal bar chart of per-layer activation distance between two models.

    Args:
        results: Entries with a dotted ``"module"`` name and a numeric
            ``"distance"``. Only the first 25 are drawn, with ``results[0]``
            at the top.
        model_a_name: Name of the first model, shown in the title.
        model_b_name: Name of the second model, shown in the title.
        save_path: Output path; defaults to ``model_diff.png``.

    Returns:
        The saved path, or ``""`` when *results* is empty.
    """
    if not results:
        return ""

    top = results[:25]
    # barh draws index 0 at the bottom, so reverse to put results[0] on top.
    ordered = list(reversed(top))
    labels = [".".join(r["module"].split(".")[-2:]) for r in ordered]
    distances = [r["distance"] for r in ordered]

    peak = max(distances)
    # Identical models give all-zero distances. Use a unit reference instead
    # of the degenerate (0, 0) x-range; otherwise keep the original scaling.
    ref = peak if peak > 0 else 1.0
    offset = ref * 0.02

    fig, ax = plt.subplots(figsize=(8, max(3, len(top) * 0.35)))

    colors = [_PALETTE["accent"] if d == peak else "#0f3460" for d in distances]
    bars = ax.barh(range(len(labels)), distances, color=colors, height=0.7)

    ax.set_yticks(range(len(labels)))
    ax.set_yticklabels(labels, fontsize=8)
    ax.set_xlabel("Cosine Distance", fontsize=10)
    ax.set_title(f"Model Diff: {model_a_name} vs {model_b_name}", fontsize=12, fontweight="bold")
    ax.set_xlim(0, ref * 1.15)
    ax.grid(axis="x", alpha=0.2)

    for bar, val in zip(bars, distances):
        ax.text(bar.get_width() + offset, bar.get_y() + bar.get_height() / 2,
                f"{val:.4f}", va="center", fontsize=7, color=_PALETTE["muted"])

    try:
        fig.tight_layout()
    except ValueError as exc:
        # Layout is cosmetic (savefig also uses bbox_inches="tight"), so keep
        # the best-effort behaviour, but surface the failure instead of
        # silently swallowing it.
        import warnings

        warnings.warn(
            f"plot_diff: tight_layout failed ({exc}); saving without it",
            stacklevel=2,
        )
    return _save_and_show(fig, save_path, "model_diff.png")
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
# ── Attribution (text) bar chart ─────────────────────────────────
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def plot_attribution(
    tokens: list[str],
    scores: list[float],
    save_path: str | None = None,
) -> str:
    """Bar chart of per-token attribution magnitudes.

    Each bar is ``abs(score)`` for the matching token, so the sign of a score
    is not shown. Bars whose magnitude exceeds half of the largest magnitude
    are highlighted in the accent colour.

    Args:
        tokens: Token strings used as x tick labels, one per score.
        scores: Attribution score per token.
        save_path: Output path; defaults to ``attribution.png``.

    Returns:
        The saved path, or ``""`` when *scores* is empty.

    Raises:
        ValueError: If ``len(tokens) != len(scores)``.
    """
    if not scores:
        return ""
    if len(tokens) != len(scores):
        # Without this check matplotlib either fails with an opaque
        # broadcasting error or, when one side has length 1, can silently
        # broadcast it into a misleading chart. Checked before creating the
        # figure so nothing is leaked.
        raise ValueError(
            f"plot_attribution: got {len(tokens)} tokens but {len(scores)} scores"
        )

    magnitudes = [abs(s) for s in scores]
    peak = max(magnitudes)
    # Highlight tokens carrying more than half of the peak magnitude; with all
    # scores zero nothing is highlighted.
    colors = [
        _PALETTE["accent"] if peak > 0 and m / peak > 0.5 else "#0f3460"
        for m in magnitudes
    ]

    fig, ax = plt.subplots(figsize=(max(4, len(tokens) * 0.6), 4))
    ax.bar(range(len(tokens)), magnitudes, color=colors, width=0.7)
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=45, ha="right", fontsize=9)
    ax.set_ylabel("Attribution Score", fontsize=10)
    ax.set_title("Gradient Attribution", fontsize=12, fontweight="bold")
    ax.grid(axis="y", alpha=0.2)

    fig.tight_layout()
    return _save_and_show(fig, save_path, "attribution.png")
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Manual annotation registry for custom nn.Module models.
|
|
2
|
+
|
|
3
|
+
Usage::
|
|
4
|
+
|
|
5
|
+
import interpkit
|
|
6
|
+
|
|
7
|
+
interpkit.register(
|
|
8
|
+
my_model,
|
|
9
|
+
layers=["blocks.0", "blocks.1", "blocks.2"],
|
|
10
|
+
output_head="head",
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
model = interpkit.load(my_model, tokenizer=my_tokenizer)
|
|
14
|
+
model.trace(tensor_a, tensor_b)
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import weakref
|
|
20
|
+
from dataclasses import dataclass, field
|
|
21
|
+
from typing import TYPE_CHECKING
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
import torch.nn as nn
|
|
25
|
+
|
|
26
|
+
# Module-level map from ``id(model)`` to that model's Registration.
# ``register`` installs a ``weakref.finalize`` hook that removes the entry when
# the model is garbage-collected. For objects that cannot be weakly referenced
# no hook is installed, so their entry stays until overwritten.
_REGISTRY: dict[int, "Registration"] = {}
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class Registration:
    """Structural annotations attached to one model by ``register``.

    Every list field defaults to a new empty list per instance.
    """

    # Ordered module names of the repeated layer blocks, e.g. ["blocks.0", ...].
    layers: list[str] = field(default_factory=list)
    # Module name of the output / LM head, if one was given.
    output_head: str | None = None
    # Module names to treat as attention.
    attention_modules: list[str] = field(default_factory=list)
    # Module names to treat as MLPs.
    mlp_modules: list[str] = field(default_factory=list)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def register(
    model: nn.Module,
    *,
    layers: list[str] | None = None,
    output_head: str | None = None,
    attention_modules: list[str] | None = None,
    mlp_modules: list[str] | None = None,
) -> None:
    """Annotate a custom ``nn.Module`` so interpkit knows its structure.

    Calling this again for the same model replaces its earlier annotations.

    Parameters
    ----------
    model:
        The model instance to annotate.
    layers:
        Ordered list of module names that constitute the repeated layer blocks
        (e.g. ``["blocks.0", "blocks.1", ...]``).
    output_head:
        Module name of the output / LM head.
    attention_modules:
        Module names that should be treated as attention.
    mlp_modules:
        Module names that should be treated as MLPs.
    """
    key = id(model)
    _REGISTRY[key] = Registration(
        layers=layers or [],
        output_head=output_head,
        attention_modules=attention_modules or [],
        mlp_modules=mlp_modules or [],
    )

    def _forget() -> None:
        # Resolve the registry at call time and tolerate a missing entry.
        _REGISTRY.pop(key, None)

    try:
        # Drop the entry automatically once the model is garbage-collected,
        # so a recycled id() can never pick up stale annotations.
        weakref.finalize(model, _forget)
    except TypeError:
        # The object does not support weak references; keep the entry.
        pass
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_registration(model: nn.Module) -> Registration | None:
    """Look up the annotations attached to *model* via ``register``.

    Returns ``None`` when *model* has never been registered.
    """
    key = id(model)
    return _REGISTRY.get(key, None)
|