interpkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interpkit/__init__.py +15 -0
- interpkit/cli/__init__.py +0 -0
- interpkit/cli/main.py +337 -0
- interpkit/core/__init__.py +0 -0
- interpkit/core/discovery.py +228 -0
- interpkit/core/html.py +375 -0
- interpkit/core/inputs.py +117 -0
- interpkit/core/model.py +551 -0
- interpkit/core/plot.py +352 -0
- interpkit/core/registry.py +82 -0
- interpkit/core/render.py +465 -0
- interpkit/core/tl_compat.py +174 -0
- interpkit/ops/__init__.py +0 -0
- interpkit/ops/ablate.py +90 -0
- interpkit/ops/activations.py +67 -0
- interpkit/ops/attention.py +234 -0
- interpkit/ops/attribute.py +206 -0
- interpkit/ops/diff.py +79 -0
- interpkit/ops/inspect.py +14 -0
- interpkit/ops/lens.py +151 -0
- interpkit/ops/patch.py +112 -0
- interpkit/ops/probe.py +128 -0
- interpkit/ops/sae.py +212 -0
- interpkit/ops/steer.py +118 -0
- interpkit/ops/trace.py +182 -0
- interpkit-0.1.0.dist-info/METADATA +295 -0
- interpkit-0.1.0.dist-info/RECORD +31 -0
- interpkit-0.1.0.dist-info/WHEEL +5 -0
- interpkit-0.1.0.dist-info/entry_points.txt +2 -0
- interpkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- interpkit-0.1.0.dist-info/top_level.txt +1 -0
interpkit/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""interpkit — mech interp for any HuggingFace model."""
|
|
2
|
+
|
|
3
|
+
from interpkit.core.model import load
|
|
4
|
+
from interpkit.core.registry import register
|
|
5
|
+
from interpkit.core.tl_compat import list_tl_hooks, to_native_name, to_tl_name
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def diff(model_a, model_b, input_data, *, save=None):
    """Compare activations between two models on the same input.

    Thin package-level wrapper around ``interpkit.ops.diff.run_diff``. The
    import happens at call time rather than when ``interpkit`` is imported.

    Parameters
    ----------
    model_a, model_b:
        The two models to compare (presumably objects returned by ``load`` —
        confirm against ``run_diff``).
    input_data:
        Input fed to both models.
    save:
        Optional output path, forwarded unchanged as ``save=``.

    Returns
    -------
    Whatever ``run_diff`` returns.
    """
    from interpkit.ops.diff import run_diff as _run_diff

    result = _run_diff(model_a, model_b, input_data, save=save)
    return result
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Public package API: ``load`` (core.model), ``register`` (core.registry), the
# module-level ``diff`` wrapper defined in this file, and the hook-name helpers
# imported from ``core.tl_compat``.
__all__ = ["load", "register", "diff", "to_tl_name", "to_native_name", "list_tl_hooks"]
|
|
File without changes
|
interpkit/cli/main.py
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""CLI entry point — Typer app with all interpkit commands."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import typer
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
from rich.panel import Panel
|
|
10
|
+
from rich.table import Table
|
|
11
|
+
from rich.text import Text
|
|
12
|
+
|
|
13
|
+
# Root Typer application; every ``@app.command()`` below registers a subcommand
# and the ``@app.callback`` renders the overview when no subcommand is given.
app = typer.Typer(
    name="interpkit",
    help="Mech interp for any HuggingFace model.",
    rich_markup_mode="rich",
    add_completion=False,
    no_args_is_help=False,
)

# Shared Rich console used for spinners, tables and panels across commands.
console = Console()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _load_model(model_name: str, device: str | None = None):
    """Load *model_name* via ``interpkit.core.model.load`` behind a Rich spinner.

    The import is deferred until a command actually needs a model. ``device``
    is passed through unchanged, including ``None``.
    """
    from interpkit.core.model import load as load_model

    spinner_text = f"Loading {model_name}..."
    with console.status(spinner_text):
        loaded = load_model(model_name, device=device)
    return loaded
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# ══════════════════════════════════════════════════════════════════
|
|
31
|
+
# help — rich overview panel
|
|
32
|
+
# ══════════════════════════════════════════════════════════════════
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@app.callback(invoke_without_command=True)
def main(ctx: typer.Context) -> None:
    """Mech interp for any HuggingFace model."""
    # NOTE: Typer shows the docstring above as help text; keep it user-facing.
    # When a subcommand was given, let it run without printing the overview.
    if ctx.invoked_subcommand is not None:
        return

    logo = r"""
 ___ _ _ ___ _
|_ _|_ __ | |_ ___ _ __ _ __| |/ (_) |_
| || '_ \| __/ _ \ '__| '_ \ ' /| | __|
| || | | | || __/ | | |_) | . \| | |_
|___|_| |_|\__\___|_| | .__/|_|\_\_|\__|
|_|
"""
    console.print(f"[bold cyan]{logo}[/bold cyan]", highlight=False)

    table = Table(
        show_header=True, header_style="bold", show_lines=False,
        pad_edge=True, expand=True,
    )
    table.add_column("Command", style="cyan", no_wrap=True)
    table.add_column("Description")
    table.add_column("Example", style="dim")

    # Rows with an empty command column act as section headers / spacers.
    rows = [
        ("", "[bold]Core Operations[/bold]", ""),
        ("inspect", "Module tree with types, params, roles", "interpkit inspect gpt2"),
        ("patch", "Activation patching at a module", "interpkit patch gpt2 --clean '...' --corrupted '...' --at transformer.h.8.mlp"),
        ("trace", "Causal tracing — rank modules by effect", "interpkit trace gpt2 --clean '...' --corrupted '...'"),
        ("lens", "Logit lens — project layers to vocab", "interpkit lens gpt2 'The capital of France is'"),
        ("attribute", "Gradient saliency over inputs", "interpkit attribute gpt2 'The capital of France is'"),
        ("", "", ""),
        ("", "[bold]Analysis Operations[/bold]", ""),
        ("activations", "Extract raw activation tensors", "interpkit activations gpt2 '...' --at transformer.h.8"),
        ("ablate", "Zero/mean ablate a component", "interpkit ablate gpt2 '...' --at transformer.h.8.mlp"),
        ("attention", "Visualize attention patterns", "interpkit attention gpt2 '...' --layer 8"),
        ("steer", "Apply a steering vector", "interpkit steer gpt2 '...' --positive Love --negative Hate --at transformer.h.8"),
        ("probe", "Linear probe on activations", "interpkit probe gpt2 --at transformer.h.8 --data data.json"),
        ("diff", "Compare two models' activations", "interpkit diff gpt2 my-finetuned-gpt2 '...'"),
        ("", "", ""),
        ("", "[bold]Advanced[/bold]", ""),
        ("features", "SAE feature decomposition", "interpkit features gpt2 '...' --at transformer.h.8 --sae jbloom/..."),
    ]

    for cmd, desc, example in rows:
        table.add_row(cmd, desc, example)

    panel = Panel(
        table,
        title="[bold cyan]Commands[/bold cyan]",
        subtitle="[dim]Mech interp for any HuggingFace model.[/dim]",
        border_style="cyan",
        padding=(1, 2),
    )
    console.print()
    console.print(panel)

    # Only some commands expose export options. Name them explicitly so users
    # are not told that e.g. ``inspect`` or ``lens`` accept ``--html``: --save
    # exists on trace/lens/attribute/attention/steer/diff, and --html only on
    # trace/attribute/attention.
    save_hint = Text.assemble(
        (" Tip: ", "bold"),
        ("trace, lens, attribute, attention, steer and diff accept ", ""),
        ("--save path.png", "bold green"),
        (" to export a matplotlib figure; trace, attribute and attention also accept ", ""),
        ("--html path.html", "bold green"),
        (" for interactive visualizations.\n", ""),
    )
    console.print(save_hint)
    console.print(" Run [bold cyan]interpkit <command> --help[/bold cyan] for detailed usage.\n")
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# ══════════════════════════════════════════════════════════════════
|
|
105
|
+
# inspect
|
|
106
|
+
# ══════════════════════════════════════════════════════════════════
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@app.command()
def inspect(
    model_name: str = typer.Argument(..., help="HuggingFace model ID (e.g. gpt2, microsoft/resnet-50)"),
    device: Optional[str] = typer.Option(None, help="Device (cpu, cuda, mps). Auto-detected if omitted."),
) -> None:
    """Print the model's module tree with types, param counts, and detected roles."""
    # The docstring above is the command's --help text; the work is delegated
    # entirely to the loaded model wrapper.
    _load_model(model_name, device=device).inspect()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
# ══════════════════════════════════════════════════════════════════
|
|
120
|
+
# patch
|
|
121
|
+
# ══════════════════════════════════════════════════════════════════
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@app.command()
def patch(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    clean: str = typer.Option(..., "--clean", help="Clean input (text string or image path)"),
    corrupted: str = typer.Option(..., "--corrupted", help="Corrupted input (text string or image path)"),
    at: str = typer.Option(..., "--at", help="Module name to patch (e.g. transformer.h.8.mlp)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Activation patching: swap one module's output from clean into corrupted run."""
    # Single patch site given by --at; both inputs are passed through verbatim.
    patched_model = _load_model(model_name, device=device)
    patched_model.patch(clean, corrupted, at=at)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
# ══════════════════════════════════════════════════════════════════
|
|
138
|
+
# trace
|
|
139
|
+
# ══════════════════════════════════════════════════════════════════
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@app.command()
def trace(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    clean: str = typer.Option(..., "--clean", help="Clean input"),
    corrupted: str = typer.Option(..., "--corrupted", help="Corrupted input"),
    top_k: int = typer.Option(20, "--top-k", help="Scan top-K modules by proxy score. 0 = scan all."),
    save: Optional[str] = typer.Option(None, "--save", help="Save bar chart to file (e.g. trace.png)"),
    html_path: Optional[str] = typer.Option(None, "--html", help="Save interactive HTML to file (e.g. trace.html)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Causal tracing: rank modules by how much patching them restores clean output."""
    # Any non-positive --top-k becomes None, which the library receives as
    # "no limit" (the help text documents 0 as "scan all").
    scan_limit: int | None = None
    if top_k > 0:
        scan_limit = top_k
    traced = _load_model(model_name, device=device)
    traced.trace(clean, corrupted, top_k=scan_limit, save=save, html=html_path)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# ══════════════════════════════════════════════════════════════════
|
|
159
|
+
# lens
|
|
160
|
+
# ══════════════════════════════════════════════════════════════════
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
@app.command()
def lens(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    text: str = typer.Argument(..., help="Input text"),
    save: Optional[str] = typer.Option(None, "--save", help="Save heatmap to file (e.g. lens.png)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Logit lens: project each layer's hidden state to vocabulary space."""
    # Load, then run the lens in one expression; --save is forwarded as-is.
    _load_model(model_name, device=device).lens(text, save=save)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# ══════════════════════════════════════════════════════════════════
|
|
176
|
+
# attribute
|
|
177
|
+
# ══════════════════════════════════════════════════════════════════
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
@app.command()
def attribute(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text or image path"),
    target: Optional[int] = typer.Option(None, "--target", help="Target class/token index for attribution"),
    save: Optional[str] = typer.Option(None, "--save", help="Save figure to file (e.g. attribution.png)"),
    html_path: Optional[str] = typer.Option(None, "--html", help="Save interactive HTML to file (e.g. attribution.html)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Gradient saliency over input tokens or pixels."""
    # --html is exposed as ``html_path`` locally but the library keyword is ``html``.
    attributed = _load_model(model_name, device=device)
    attributed.attribute(input_data, target=target, save=save, html=html_path)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# ══════════════════════════════════════════════════════════════════
|
|
195
|
+
# activations
|
|
196
|
+
# ══════════════════════════════════════════════════════════════════
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@app.command()
def activations(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text or image path"),
    at: str = typer.Option(..., "--at", help="Module name(s) to extract, comma-separated"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Extract and display activation statistics at named modules."""
    # Split the comma-separated --at value and drop blank entries, so inputs
    # like "a,b," or "a,,b" don't request a module named "".
    modules = [part.strip() for part in at.split(",") if part.strip()]
    if not modules:
        # Reject before loading so a bad --at fails fast without loading weights.
        raise typer.BadParameter("expected at least one module name", param_hint="--at")

    m = _load_model(model_name, device=device)
    if len(modules) == 1:
        # A single name is passed as a plain string rather than a one-element list.
        m.activations(input_data, at=modules[0])
    else:
        m.activations(input_data, at=modules)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# ══════════════════════════════════════════════════════════════════
|
|
216
|
+
# ablate
|
|
217
|
+
# ══════════════════════════════════════════════════════════════════
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@app.command()
def ablate(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text or image path"),
    at: str = typer.Option(..., "--at", help="Module name to ablate"),
    method: str = typer.Option("zero", "--method", help="Ablation method: zero or mean"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Zero or mean ablate a module and measure the effect on output."""
    # --method is forwarded as-is; it is not validated at the CLI layer.
    ablated = _load_model(model_name, device=device)
    ablated.ablate(input_data, method=method, at=at)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# ══════════════════════════════════════════════════════════════════
|
|
234
|
+
# attention
|
|
235
|
+
# ══════════════════════════════════════════════════════════════════
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@app.command()
def attention(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text"),
    layer: Optional[int] = typer.Option(None, "--layer", help="Specific layer index"),
    head: Optional[int] = typer.Option(None, "--head", help="Specific head index"),
    save: Optional[str] = typer.Option(None, "--save", help="Save heatmap to file (e.g. attention.png)"),
    html_path: Optional[str] = typer.Option(None, "--html", help="Save interactive HTML to file (e.g. attention.html)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Show attention patterns for transformer models."""
    # --layer / --head default to None and are passed through unchanged.
    _load_model(model_name, device=device).attention(
        input_data, layer=layer, head=head, save=save, html=html_path
    )
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
# ══════════════════════════════════════════════════════════════════
|
|
254
|
+
# steer
|
|
255
|
+
# ══════════════════════════════════════════════════════════════════
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
@app.command()
def steer(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text to steer"),
    positive: str = typer.Option(..., "--positive", help="Positive direction text"),
    negative: str = typer.Option(..., "--negative", help="Negative direction text"),
    at: str = typer.Option(..., "--at", help="Module name to apply steering at"),
    scale: float = typer.Option(2.0, "--scale", help="Steering vector scale factor"),
    save: Optional[str] = typer.Option(None, "--save", help="Save comparison chart to file"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Extract a steering vector and apply it during inference."""
    # Two steps at the same module ``at``: build a direction from the
    # positive/negative prompts, then run ``input_data`` with it applied.
    loaded = _load_model(model_name, device=device)
    direction = loaded.steer_vector(positive, negative, at=at)
    loaded.steer(input_data, vector=direction, at=at, scale=scale, save=save)
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
# ══════════════════════════════════════════════════════════════════
|
|
276
|
+
# probe
|
|
277
|
+
# ══════════════════════════════════════════════════════════════════
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
@app.command()
def probe(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    at: str = typer.Option(..., "--at", help="Module name to probe"),
    data: str = typer.Option(..., "--data", help="JSON file with {texts: [...], labels: [...]}"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Train a linear probe on activations to test linear separability."""
    import json
    from pathlib import Path

    # Parse and validate the dataset *before* loading the model, so a bad file
    # fails fast with a CLI usage error instead of a raw traceback.
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale's default encoding.
        probe_data = json.loads(Path(data).read_text(encoding="utf-8"))
    except OSError as exc:
        raise typer.BadParameter(f"cannot read {data!r}: {exc}", param_hint="--data") from exc
    except ValueError as exc:  # covers json.JSONDecodeError and UnicodeDecodeError
        raise typer.BadParameter(f"{data!r} is not valid UTF-8 JSON: {exc}", param_hint="--data") from exc

    if not isinstance(probe_data, dict) or "texts" not in probe_data or "labels" not in probe_data:
        raise typer.BadParameter(
            'expected a JSON object with "texts" and "labels" keys', param_hint="--data"
        )
    texts, labels = probe_data["texts"], probe_data["labels"]
    if not isinstance(texts, list) or not isinstance(labels, list):
        raise typer.BadParameter('"texts" and "labels" must both be JSON arrays', param_hint="--data")
    if len(texts) != len(labels):
        raise typer.BadParameter(
            f'"texts" has {len(texts)} entries but "labels" has {len(labels)}', param_hint="--data"
        )

    m = _load_model(model_name, device=device)
    m.probe(texts=texts, labels=labels, at=at)
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# ══════════════════════════════════════════════════════════════════
|
|
297
|
+
# diff
|
|
298
|
+
# ══════════════════════════════════════════════════════════════════
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@app.command()
def diff(
    model_a_name: str = typer.Argument(..., help="First model (e.g. gpt2)"),
    model_b_name: str = typer.Argument(..., help="Second model (e.g. my-finetuned-gpt2)"),
    input_data: str = typer.Argument(..., help="Input text to compare on"),
    save: Optional[str] = typer.Option(None, "--save", help="Save bar chart to file"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Compare activations between two models on the same input."""
    # Imported under an alias so it can't be confused with this CLI command,
    # which has the same name.
    from interpkit import diff as compare_models

    # Load sequentially (A first, then B) on the same device setting.
    first, second = (_load_model(name, device=device) for name in (model_a_name, model_b_name))
    compare_models(first, second, input_data, save=save)
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
# ══════════════════════════════════════════════════════════════════
|
|
318
|
+
# features (SAE)
|
|
319
|
+
# ══════════════════════════════════════════════════════════════════
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@app.command()
def features(
    model_name: str = typer.Argument(..., help="HuggingFace model ID (e.g. gpt2)"),
    input_data: str = typer.Argument(..., help="Input text"),
    at: str = typer.Option(..., "--at", help="Module name to decompose (e.g. transformer.h.8)"),
    sae: str = typer.Option(..., "--sae", help="HuggingFace repo ID for the SAE weights"),
    top_k: int = typer.Option(20, "--top-k", help="Number of top features to display"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Decompose activations through a Sparse Autoencoder into interpretable features."""
    # The SAE repo ID and --top-k are forwarded untouched to the model wrapper.
    _load_model(model_name, device=device).features(
        input_data, at=at, sae=sae, top_k=top_k
    )
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
if __name__ == "__main__":
    # Run the Typer app when this module is executed directly
    # (e.g. ``python -m interpkit.cli.main``).
    app()
|
|
File without changes
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""Auto-discover model structure from HF config, module name heuristics, and forward pass."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# ---------------------------------------------------------------------------
|
|
14
|
+
# Heuristic patterns for semantic module role detection
|
|
15
|
+
# ---------------------------------------------------------------------------
|
|
16
|
+
|
|
17
|
+
# Each pattern below is searched against a dotted module name as yielded by
# ``named_modules()``. A keyword must begin a dotted component -- ``(^|\.)`` --
# and be followed by ``.`` or a word boundary. So ``attn`` matches ``h.0.attn``
# and ``h.0.attn.c_proj`` but not ``h.0.c_attn``. Because the match may occur
# at any component, every descendant of a matching module also matches.
# Precedence between the patterns is decided by ``_classify_role``.
#
# NOTE(review): the trailing ``(\.|\b)`` also rejects suffixed names such as
# ``embed_tokens``, ``embeddings`` or ``input_layernorm`` (the next character is
# a word character, so there is no boundary). Confirm this is intended.

# Attention blocks.
_ATTENTION_PATTERNS = re.compile(
    r"(^|\.)(self_attn|attn|attention|mha|multi_head_attention)(\.|\b)", re.IGNORECASE
)
# Feed-forward / MLP blocks. ``fc[_\d]`` only matches a component that is
# exactly ``fc`` plus one digit or underscore (``fc1``, ``fc2``); names like
# ``fc_in`` or ``fc10`` do not match this alternative.
_MLP_PATTERNS = re.compile(
    r"(^|\.)(mlp|ffn|feed_forward|dense|fc[_\d]|intermediate)(\.|\b)", re.IGNORECASE
)
# Output heads (LM heads, classifiers, QA heads).
_HEAD_PATTERNS = re.compile(
    r"(^|\.)(lm_head|head|classifier|output_projection|qa_outputs)(\.|\b)", re.IGNORECASE
)
# Normalization layers (LayerNorm / RMSNorm / GPT-2-style ``ln_f`` / ``ln_1``).
_NORM_PATTERNS = re.compile(
    r"(^|\.)(layer_?norm|rms_?norm|norm|ln_f|ln_\d)(\.|\b)", re.IGNORECASE
)
# Token / position embeddings.
_EMBED_PATTERNS = re.compile(
    r"(^|\.)(embed|wte|wpe|embedding|token_embedding|position_embedding)(\.|\b)",
    re.IGNORECASE,
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class ModuleInfo:
    """Discovered information about a single named module.

    ``discover`` creates one of these per entry of ``model.named_modules()``
    (excluding the root module).
    """

    name: str  # dotted path as yielded by ``named_modules()``
    type_name: str  # class name of the module, e.g. "Linear"
    param_count: int  # as filled by ``discover``: parameters registered directly on this module, children excluded
    output_shape: tuple[int, ...] | None = None  # recorded during discover's optional dummy forward pass; None if not captured
    role: str | None = None  # "attention", "mlp", "head", "norm", "embed", or None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class ModelArchInfo:
    """Aggregated architecture information for a model.

    In ``discover``, the config-derived fields come from the model's HF
    ``config`` (when present). The structural fields come from module-name
    heuristics.
    """

    arch_family: str | None = None  # e.g. "GPT2LMHeadModel", "MambaForCausalLM"
    num_layers: int | None = None  # from config (num_hidden_layers / n_layer / num_layers / n_layers)
    hidden_size: int | None = None  # from config (hidden_size / n_embd / d_model)
    num_attention_heads: int | None = None  # from config (num_attention_heads / n_head / num_heads)
    vocab_size: int | None = None  # config.vocab_size, if any
    has_lm_head: bool = False  # True when a head-like module with a ``weight`` was found
    output_head_name: str | None = None  # discover sets this and ``unembedding_name`` from the same lookup
    unembedding_name: str | None = None
    modules: list[ModuleInfo] = field(default_factory=list)  # one entry per named submodule
    layer_names: list[str] = field(default_factory=list)  # detected repeated-block stack, in index order
    is_tl_model: bool = False  # NOTE(review): not set by ``discover``; presumably set elsewhere -- confirm

    @property
    def is_language_model(self) -> bool:
        """True when ``has_lm_head`` is set and ``unembedding_name`` is known.

        NOTE(review): the head heuristic also matches ``classifier`` modules,
        so this can be True for non-LM models built by ``discover``.
        """
        return self.has_lm_head and self.unembedding_name is not None
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _classify_role(name: str) -> str | None:
    """Guess a module's semantic role from its dotted name.

    Patterns are tried in priority order -- head, attention, mlp, norm,
    embed -- and the first hit wins. For example, ``attention.output.dense``
    is labelled "attention" even though ``dense`` is an MLP keyword. Returns
    None when no pattern matches.
    """
    priority = (
        (_HEAD_PATTERNS, "head"),
        (_ATTENTION_PATTERNS, "attention"),
        (_MLP_PATTERNS, "mlp"),
        (_NORM_PATTERNS, "norm"),
        (_EMBED_PATTERNS, "embed"),
    )
    return next((role for pattern, role in priority if pattern.search(name)), None)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _count_params(module: nn.Module) -> int:
|
|
82
|
+
return sum(p.numel() for p in module.parameters(recurse=False))
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _parse_hf_config(model: nn.Module) -> dict[str, Any]:
|
|
86
|
+
"""Extract architecture metadata from an HF model's config, if present."""
|
|
87
|
+
config = getattr(model, "config", None)
|
|
88
|
+
if config is None:
|
|
89
|
+
return {}
|
|
90
|
+
info: dict[str, Any] = {}
|
|
91
|
+
info["arch_family"] = type(model).__name__
|
|
92
|
+
|
|
93
|
+
for attr in ("num_hidden_layers", "n_layer", "num_layers", "n_layers"):
|
|
94
|
+
val = getattr(config, attr, None)
|
|
95
|
+
if val is not None:
|
|
96
|
+
info["num_layers"] = val
|
|
97
|
+
break
|
|
98
|
+
|
|
99
|
+
for attr in ("hidden_size", "n_embd", "d_model"):
|
|
100
|
+
val = getattr(config, attr, None)
|
|
101
|
+
if val is not None:
|
|
102
|
+
info["hidden_size"] = val
|
|
103
|
+
break
|
|
104
|
+
|
|
105
|
+
for attr in ("num_attention_heads", "n_head", "num_heads"):
|
|
106
|
+
val = getattr(config, attr, None)
|
|
107
|
+
if val is not None:
|
|
108
|
+
info["num_attention_heads"] = val
|
|
109
|
+
break
|
|
110
|
+
|
|
111
|
+
info["vocab_size"] = getattr(config, "vocab_size", None)
|
|
112
|
+
return info
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _find_unembedding(model: nn.Module) -> str | None:
    """Return the name of the first head-like module that has a ``weight``.

    Modules are scanned in ``named_modules()`` order. A module qualifies when
    its dotted name matches ``_HEAD_PATTERNS`` and it has a ``weight``
    attribute. Returns None when nothing qualifies.

    NOTE(review): the pattern also matches descendants of a head (e.g.
    ``classifier.dense``), so the first hit is not necessarily the final
    vocabulary/class projection.
    """
    qualifying = (
        qualified_name
        for qualified_name, submodule in model.named_modules()
        if _HEAD_PATTERNS.search(qualified_name) and hasattr(submodule, "weight")
    )
    return next(qualifying, None)
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _detect_layers(modules: list[ModuleInfo]) -> list[str]:
|
|
124
|
+
"""Identify repeated structural blocks that look like transformer/SSM layers.
|
|
125
|
+
|
|
126
|
+
Strategy: find modules whose names follow a pattern like ``something.N``
|
|
127
|
+
where N is a sequential integer, and whose siblings have identical structure.
|
|
128
|
+
We pick the longest such group.
|
|
129
|
+
"""
|
|
130
|
+
pattern = re.compile(r"^(.+)\.(\d+)$")
|
|
131
|
+
groups: dict[str, list[str]] = {}
|
|
132
|
+
for m in modules:
|
|
133
|
+
match = pattern.match(m.name)
|
|
134
|
+
if match:
|
|
135
|
+
prefix = match.group(1)
|
|
136
|
+
groups.setdefault(prefix, []).append(m.name)
|
|
137
|
+
|
|
138
|
+
if not groups:
|
|
139
|
+
return []
|
|
140
|
+
|
|
141
|
+
best_prefix = max(groups, key=lambda k: len(groups[k]))
|
|
142
|
+
layers = sorted(groups[best_prefix], key=lambda n: int(n.rsplit(".", 1)[-1]))
|
|
143
|
+
return layers
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _first_tensor_shape(output: Any) -> tuple[int, ...] | None:
    """Return the shape a forward hook should record for *output*, or None.

    The rules are:

    - A bare tensor gives its own shape.
    - A non-empty tuple/list, or a dict (HF ``ModelOutput`` objects are
      ``OrderedDict`` subclasses), gives the shape of its first element/value
      when that is a tensor.
    - Anything else yields None.
    """
    if isinstance(output, torch.Tensor):
        return tuple(output.shape)
    if isinstance(output, (tuple, list)):
        first = output[0] if len(output) > 0 else None
    elif isinstance(output, dict):
        first = next(iter(output.values()), None)
    else:
        return None
    return tuple(first.shape) if isinstance(first, torch.Tensor) else None


def _capture_output_shapes(model: nn.Module, dummy_input: Any) -> dict[str, tuple[int, ...]]:
    """Run one no-grad forward pass and record each named submodule's output shape."""
    shapes: dict[str, tuple[int, ...]] = {}

    def _make_hook(mod_name: str):
        def hook_fn(_mod: nn.Module, _inp: Any, output: Any) -> None:
            shape = _first_tensor_shape(output)
            if shape is not None:
                shapes[mod_name] = shape
        return hook_fn

    hooks = []
    try:
        # Register inside ``try`` so hooks added before any failure (during
        # registration or the forward pass) are always removed.
        for name, mod in model.named_modules():
            if name == "":
                continue
            hooks.append(mod.register_forward_hook(_make_hook(name)))

        with torch.no_grad():
            if isinstance(dummy_input, dict):
                model(**dummy_input)
            elif isinstance(dummy_input, (tuple, list)):
                model(*dummy_input)
            else:
                model(dummy_input)
    finally:
        for h in hooks:
            h.remove()
    return shapes


def discover(
    model: nn.Module,
    dummy_input: Any | None = None,
) -> ModelArchInfo:
    """Run full auto-discovery on a model.

    Parameters
    ----------
    model:
        Any ``nn.Module``, optionally with an HF ``.config`` attribute.
    dummy_input:
        If provided, used for a forward pass to capture output shapes.
        Can be a tensor, dict of tensors, or tuple of tensors.

    Returns
    -------
    ModelArchInfo
        Contains:

        - config-derived metadata (None fields when there is no ``config``);
        - one ``ModuleInfo`` per named submodule, root excluded;
        - the detected layer stack;
        - the head/unembedding module name, if one was found.
    """
    hf_meta = _parse_hf_config(model)

    # One record per named submodule; the root module ("") is skipped.
    module_infos: list[ModuleInfo] = [
        ModuleInfo(
            name=name,
            type_name=type(mod).__name__,
            param_count=_count_params(mod),
            role=_classify_role(name),
        )
        for name, mod in model.named_modules()
        if name != ""
    ]

    # Output shapes are only known when a dummy forward pass is possible.
    if dummy_input is not None:
        shapes = _capture_output_shapes(model, dummy_input)
        for info in module_infos:
            info.output_shape = shapes.get(info.name)

    unembed_name = _find_unembedding(model)

    return ModelArchInfo(
        arch_family=hf_meta.get("arch_family"),
        num_layers=hf_meta.get("num_layers"),
        hidden_size=hf_meta.get("hidden_size"),
        num_attention_heads=hf_meta.get("num_attention_heads"),
        vocab_size=hf_meta.get("vocab_size"),
        has_lm_head=unembed_name is not None,
        output_head_name=unembed_name,
        unembedding_name=unembed_name,
        modules=module_infos,
        layer_names=_detect_layers(module_infos),
    )
|