interpkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
interpkit/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ """interpkit — mech interp for any HuggingFace model."""
2
+
3
+ from interpkit.core.model import load
4
+ from interpkit.core.registry import register
5
+ from interpkit.core.tl_compat import list_tl_hooks, to_native_name, to_tl_name
6
+
7
+
8
def diff(model_a, model_b, input_data, *, save=None):
    """Compare activations between two models on the same input."""
    # Imported lazily so that `import interpkit` stays lightweight.
    from interpkit.ops.diff import run_diff

    result = run_diff(model_a, model_b, input_data, save=save)
    return result


# Public API surface of the package.
__all__ = ["load", "register", "diff", "to_tl_name", "to_native_name", "list_tl_hooks"]
File without changes
interpkit/cli/main.py ADDED
@@ -0,0 +1,337 @@
1
+ """CLI entry point — Typer app with all interpkit commands."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Optional
6
+
7
+ import typer
8
+ from rich.console import Console
9
+ from rich.panel import Panel
10
+ from rich.table import Table
11
+ from rich.text import Text
12
+
13
# Top-level Typer application; each command below registers itself on it.
app = typer.Typer(
    name="interpkit",
    help="Mech interp for any HuggingFace model.",
    add_completion=False,
    no_args_is_help=False,
    rich_markup_mode="rich",
)

# Shared Rich console used by every command for output.
console = Console()
21
+
22
+
23
def _load_model(model_name: str, device: str | None = None):
    """Load *model_name* behind a Rich status spinner.

    The heavy project import happens lazily so `interpkit --help` stays fast.
    """
    from interpkit.core.model import load

    status_msg = f"Loading {model_name}..."
    with console.status(status_msg):
        return load(model_name, device=device)
28
+
29
+
30
+ # ══════════════════════════════════════════════════════════════════
31
+ # help — rich overview panel
32
+ # ══════════════════════════════════════════════════════════════════
33
+
34
+
35
@app.callback(invoke_without_command=True)
def main(ctx: typer.Context) -> None:
    """Mech interp for any HuggingFace model."""
    # A subcommand was given: let Typer dispatch to it instead of showing help.
    if ctx.invoked_subcommand is not None:
        return

    banner = r"""
 ___       _                  _  ___ _
|_ _|_ __ | |_ ___ _ __ _ __ | |/ (_) |_
 | || '_ \| __/ _ \ '__| '_ \| ' /| | __|
 | || | | | ||  __/ |  | |_) | . \| | |_
|___|_| |_|\__\___|_|  | .__/|_|\_\_|\__|
                       |_|
"""
    console.print(f"[bold cyan]{banner}[/bold cyan]", highlight=False)

    # Overview table: one row per command, grouped by section headers.
    cmd_table = Table(
        show_header=True,
        header_style="bold",
        show_lines=False,
        pad_edge=True,
        expand=True,
    )
    cmd_table.add_column("Command", style="cyan", no_wrap=True)
    cmd_table.add_column("Description")
    cmd_table.add_column("Example", style="dim")

    # Empty-command rows act as section headers / spacers.
    command_rows = [
        ("", "[bold]Core Operations[/bold]", ""),
        ("inspect", "Module tree with types, params, roles", "interpkit inspect gpt2"),
        ("patch", "Activation patching at a module", "interpkit patch gpt2 --clean '...' --corrupted '...' --at transformer.h.8.mlp"),
        ("trace", "Causal tracing — rank modules by effect", "interpkit trace gpt2 --clean '...' --corrupted '...'"),
        ("lens", "Logit lens — project layers to vocab", "interpkit lens gpt2 'The capital of France is'"),
        ("attribute", "Gradient saliency over inputs", "interpkit attribute gpt2 'The capital of France is'"),
        ("", "", ""),
        ("", "[bold]Analysis Operations[/bold]", ""),
        ("activations", "Extract raw activation tensors", "interpkit activations gpt2 '...' --at transformer.h.8"),
        ("ablate", "Zero/mean ablate a component", "interpkit ablate gpt2 '...' --at transformer.h.8.mlp"),
        ("attention", "Visualize attention patterns", "interpkit attention gpt2 '...' --layer 8"),
        ("steer", "Apply a steering vector", "interpkit steer gpt2 '...' --positive Love --negative Hate --at transformer.h.8"),
        ("probe", "Linear probe on activations", "interpkit probe gpt2 --at transformer.h.8 --data data.json"),
        ("diff", "Compare two models' activations", "interpkit diff gpt2 my-finetuned-gpt2 '...'"),
        ("", "", ""),
        ("", "[bold]Advanced[/bold]", ""),
        ("features", "SAE feature decomposition", "interpkit features gpt2 '...' --at transformer.h.8 --sae jbloom/..."),
    ]
    for cmd, desc, example in command_rows:
        cmd_table.add_row(cmd, desc, example)

    overview = Panel(
        cmd_table,
        title="[bold cyan]Commands[/bold cyan]",
        subtitle="[dim]Mech interp for any HuggingFace model.[/dim]",
        border_style="cyan",
        padding=(1, 2),
    )
    console.print()
    console.print(overview)

    tip = Text.assemble(
        (" Tip: ", "bold"),
        ("Most commands accept ", ""),
        ("--save path.png", "bold green"),
        (" to export a matplotlib figure and ", ""),
        ("--html path.html", "bold green"),
        (" for interactive visualizations.\n", ""),
    )
    console.print(tip)
    console.print(" Run [bold cyan]interpkit <command> --help[/bold cyan] for detailed usage.\n")
102
+
103
+
104
+ # ══════════════════════════════════════════════════════════════════
105
+ # inspect
106
+ # ══════════════════════════════════════════════════════════════════
107
+
108
+
109
@app.command()
def inspect(  # noqa: shadows builtins.inspect intentionally — the CLI command name comes from the function name
    model_name: str = typer.Argument(..., help="HuggingFace model ID (e.g. gpt2, microsoft/resnet-50)"),
    device: Optional[str] = typer.Option(None, help="Device (cpu, cuda, mps). Auto-detected if omitted."),
) -> None:
    """Print the model's module tree with types, param counts, and detected roles."""
    model = _load_model(model_name, device=device)
    model.inspect()
117
+
118
+
119
+ # ══════════════════════════════════════════════════════════════════
120
+ # patch
121
+ # ══════════════════════════════════════════════════════════════════
122
+
123
+
124
@app.command()
def patch(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    clean: str = typer.Option(..., "--clean", help="Clean input (text string or image path)"),
    corrupted: str = typer.Option(..., "--corrupted", help="Corrupted input (text string or image path)"),
    at: str = typer.Option(..., "--at", help="Module name to patch (e.g. transformer.h.8.mlp)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Activation patching: swap one module's output from clean into corrupted run."""
    model = _load_model(model_name, device=device)
    model.patch(clean, corrupted, at=at)
135
+
136
+
137
+ # ══════════════════════════════════════════════════════════════════
138
+ # trace
139
+ # ══════════════════════════════════════════════════════════════════
140
+
141
+
142
@app.command()
def trace(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    clean: str = typer.Option(..., "--clean", help="Clean input"),
    corrupted: str = typer.Option(..., "--corrupted", help="Corrupted input"),
    top_k: int = typer.Option(20, "--top-k", help="Scan top-K modules by proxy score. 0 = scan all."),
    save: Optional[str] = typer.Option(None, "--save", help="Save bar chart to file (e.g. trace.png)"),
    html_path: Optional[str] = typer.Option(None, "--html", help="Save interactive HTML to file (e.g. trace.html)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Causal tracing: rank modules by how much patching them restores clean output."""
    # --top-k 0 (or negative) disables the pre-filter; the API spells that None.
    limit: int | None = top_k if top_k > 0 else None
    model = _load_model(model_name, device=device)
    model.trace(clean, corrupted, top_k=limit, save=save, html=html_path)
156
+
157
+
158
+ # ══════════════════════════════════════════════════════════════════
159
+ # lens
160
+ # ══════════════════════════════════════════════════════════════════
161
+
162
+
163
@app.command()
def lens(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    text: str = typer.Argument(..., help="Input text"),
    save: Optional[str] = typer.Option(None, "--save", help="Save heatmap to file (e.g. lens.png)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Logit lens: project each layer's hidden state to vocabulary space."""
    model = _load_model(model_name, device=device)
    model.lens(text, save=save)
173
+
174
+
175
+ # ══════════════════════════════════════════════════════════════════
176
+ # attribute
177
+ # ══════════════════════════════════════════════════════════════════
178
+
179
+
180
@app.command()
def attribute(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text or image path"),
    target: Optional[int] = typer.Option(None, "--target", help="Target class/token index for attribution"),
    save: Optional[str] = typer.Option(None, "--save", help="Save figure to file (e.g. attribution.png)"),
    html_path: Optional[str] = typer.Option(None, "--html", help="Save interactive HTML to file (e.g. attribution.html)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Gradient saliency over input tokens or pixels."""
    model = _load_model(model_name, device=device)
    model.attribute(input_data, target=target, save=save, html=html_path)
192
+
193
+
194
+ # ══════════════════════════════════════════════════════════════════
195
+ # activations
196
+ # ══════════════════════════════════════════════════════════════════
197
+
198
+
199
@app.command()
def activations(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text or image path"),
    at: str = typer.Option(..., "--at", help="Module name(s) to extract, comma-separated"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Extract and display activation statistics at named modules."""
    model = _load_model(model_name, device=device)
    names = [part.strip() for part in at.split(",")]
    # A single name is passed as a bare string; multiple names go as a list.
    target = names[0] if len(names) == 1 else names
    model.activations(input_data, at=target)
213
+
214
+
215
+ # ══════════════════════════════════════════════════════════════════
216
+ # ablate
217
+ # ══════════════════════════════════════════════════════════════════
218
+
219
+
220
@app.command()
def ablate(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text or image path"),
    at: str = typer.Option(..., "--at", help="Module name to ablate"),
    method: str = typer.Option("zero", "--method", help="Ablation method: zero or mean"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Zero or mean ablate a module and measure the effect on output."""
    model = _load_model(model_name, device=device)
    model.ablate(input_data, at=at, method=method)
231
+
232
+
233
+ # ══════════════════════════════════════════════════════════════════
234
+ # attention
235
+ # ══════════════════════════════════════════════════════════════════
236
+
237
+
238
@app.command()
def attention(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text"),
    layer: Optional[int] = typer.Option(None, "--layer", help="Specific layer index"),
    head: Optional[int] = typer.Option(None, "--head", help="Specific head index"),
    save: Optional[str] = typer.Option(None, "--save", help="Save heatmap to file (e.g. attention.png)"),
    html_path: Optional[str] = typer.Option(None, "--html", help="Save interactive HTML to file (e.g. attention.html)"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Show attention patterns for transformer models."""
    model = _load_model(model_name, device=device)
    model.attention(input_data, layer=layer, head=head, save=save, html=html_path)
251
+
252
+
253
+ # ══════════════════════════════════════════════════════════════════
254
+ # steer
255
+ # ══════════════════════════════════════════════════════════════════
256
+
257
+
258
@app.command()
def steer(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    input_data: str = typer.Argument(..., help="Input text to steer"),
    positive: str = typer.Option(..., "--positive", help="Positive direction text"),
    negative: str = typer.Option(..., "--negative", help="Negative direction text"),
    at: str = typer.Option(..., "--at", help="Module name to apply steering at"),
    scale: float = typer.Option(2.0, "--scale", help="Steering vector scale factor"),
    save: Optional[str] = typer.Option(None, "--save", help="Save comparison chart to file"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Extract a steering vector and apply it during inference."""
    model = _load_model(model_name, device=device)
    # Derive the steering direction from the positive/negative contrast pair,
    # then re-run generation with it injected at the chosen module.
    direction = model.steer_vector(positive, negative, at=at)
    model.steer(input_data, vector=direction, at=at, scale=scale, save=save)
273
+
274
+
275
+ # ══════════════════════════════════════════════════════════════════
276
+ # probe
277
+ # ══════════════════════════════════════════════════════════════════
278
+
279
+
280
@app.command()
def probe(
    model_name: str = typer.Argument(..., help="HuggingFace model ID"),
    at: str = typer.Option(..., "--at", help="Module name to probe"),
    data: str = typer.Option(..., "--data", help="JSON file with {texts: [...], labels: [...]}"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Train a linear probe on activations to test linear separability."""
    import json
    from pathlib import Path

    # Dataset file must contain "texts" and "labels" keys (KeyError otherwise).
    payload = json.loads(Path(data).read_text())
    model = _load_model(model_name, device=device)
    model.probe(texts=payload["texts"], labels=payload["labels"], at=at)
294
+
295
+
296
+ # ══════════════════════════════════════════════════════════════════
297
+ # diff
298
+ # ══════════════════════════════════════════════════════════════════
299
+
300
+
301
@app.command()
def diff(
    model_a_name: str = typer.Argument(..., help="First model (e.g. gpt2)"),
    model_b_name: str = typer.Argument(..., help="Second model (e.g. my-finetuned-gpt2)"),
    input_data: str = typer.Argument(..., help="Input text to compare on"),
    save: Optional[str] = typer.Option(None, "--save", help="Save bar chart to file"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Compare activations between two models on the same input."""
    # Imported here (not at module top) to avoid a circular import at startup.
    import interpkit

    first = _load_model(model_a_name, device=device)
    second = _load_model(model_b_name, device=device)
    interpkit.diff(first, second, input_data, save=save)
315
+
316
+
317
+ # ══════════════════════════════════════════════════════════════════
318
+ # features (SAE)
319
+ # ══════════════════════════════════════════════════════════════════
320
+
321
+
322
@app.command()
def features(
    model_name: str = typer.Argument(..., help="HuggingFace model ID (e.g. gpt2)"),
    input_data: str = typer.Argument(..., help="Input text"),
    at: str = typer.Option(..., "--at", help="Module name to decompose (e.g. transformer.h.8)"),
    sae: str = typer.Option(..., "--sae", help="HuggingFace repo ID for the SAE weights"),
    top_k: int = typer.Option(20, "--top-k", help="Number of top features to display"),
    device: Optional[str] = typer.Option(None, help="Device"),
) -> None:
    """Decompose activations through a Sparse Autoencoder into interpretable features."""
    model = _load_model(model_name, device=device)
    model.features(input_data, at=at, sae=sae, top_k=top_k)
334
+
335
+
336
+ if __name__ == "__main__":
337
+ app()
File without changes
@@ -0,0 +1,228 @@
1
+ """Auto-discover model structure from HF config, module name heuristics, and forward pass."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from dataclasses import dataclass, field
7
+ from typing import Any
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+
13
+ # ---------------------------------------------------------------------------
14
+ # Heuristic patterns for semantic module role detection
15
+ # ---------------------------------------------------------------------------
16
+
17
+ _ATTENTION_PATTERNS = re.compile(
18
+ r"(^|\.)(self_attn|attn|attention|mha|multi_head_attention)(\.|\b)", re.IGNORECASE
19
+ )
20
+ _MLP_PATTERNS = re.compile(
21
+ r"(^|\.)(mlp|ffn|feed_forward|dense|fc[_\d]|intermediate)(\.|\b)", re.IGNORECASE
22
+ )
23
+ _HEAD_PATTERNS = re.compile(
24
+ r"(^|\.)(lm_head|head|classifier|output_projection|qa_outputs)(\.|\b)", re.IGNORECASE
25
+ )
26
+ _NORM_PATTERNS = re.compile(
27
+ r"(^|\.)(layer_?norm|rms_?norm|norm|ln_f|ln_\d)(\.|\b)", re.IGNORECASE
28
+ )
29
+ _EMBED_PATTERNS = re.compile(
30
+ r"(^|\.)(embed|wte|wpe|embedding|token_embedding|position_embedding)(\.|\b)",
31
+ re.IGNORECASE,
32
+ )
33
+
34
+
35
@dataclass
class ModuleInfo:
    """Discovered information about a single named module."""

    # Dotted module path as reported by ``nn.Module.named_modules()``.
    name: str
    # Class name of the module, e.g. "Linear".
    type_name: str
    # Parameter count for this module (discovery fills it non-recursively).
    param_count: int
    # Output tensor shape captured during the discovery forward pass, if any.
    output_shape: tuple[int, ...] | None = None
    # Semantic role: "attention", "mlp", "head", "norm", "embed", or None.
    role: str | None = None
44
+
45
+
46
@dataclass
class ModelArchInfo:
    """Aggregated architecture information for a model."""

    # Model class name, e.g. "GPT2LMHeadModel" or "MambaForCausalLM".
    arch_family: str | None = None
    num_layers: int | None = None
    hidden_size: int | None = None
    num_attention_heads: int | None = None
    vocab_size: int | None = None
    # True when an unembedding / LM-head-like module was located.
    has_lm_head: bool = False
    output_head_name: str | None = None
    unembedding_name: str | None = None
    # Per-module discovery results, in named_modules() order.
    modules: list[ModuleInfo] = field(default_factory=list)
    # Names of the repeated "layer" blocks, sorted by index.
    layer_names: list[str] = field(default_factory=list)
    # NOTE(review): presumably marks TransformerLens-wrapped models — set by
    # callers elsewhere; this module never writes it.
    is_tl_model: bool = False

    @property
    def is_language_model(self) -> bool:
        """True when the model both has an LM head and we know its name."""
        return self.unembedding_name is not None and self.has_lm_head
65
+
66
+
67
def _classify_role(name: str) -> str | None:
    """Map a module name to a semantic role via the heuristic regexes.

    Patterns are checked in a fixed priority order (head first, embed last),
    so a name matching several patterns gets the highest-priority role.
    Returns None when no pattern matches.
    """
    checks = (
        (_HEAD_PATTERNS, "head"),
        (_ATTENTION_PATTERNS, "attention"),
        (_MLP_PATTERNS, "mlp"),
        (_NORM_PATTERNS, "norm"),
        (_EMBED_PATTERNS, "embed"),
    )
    for pattern, role in checks:
        if pattern.search(name):
            return role
    return None
79
+
80
+
81
def _count_params(module: nn.Module) -> int:
    """Count parameters owned directly by *module* (children excluded)."""
    total = 0
    for param in module.parameters(recurse=False):
        total += param.numel()
    return total
83
+
84
+
85
def _parse_hf_config(model: nn.Module) -> dict[str, Any]:
    """Extract architecture metadata from an HF model's config, if present.

    Returns an empty dict when the model has no ``config`` attribute.
    Different HF architectures spell the same hyperparameter differently,
    so each field probes a sequence of candidate attribute names in order
    and keeps the first non-None value.
    """
    config = getattr(model, "config", None)
    if config is None:
        return {}

    info: dict[str, Any] = {"arch_family": type(model).__name__}

    aliases = {
        "num_layers": ("num_hidden_layers", "n_layer", "num_layers", "n_layers"),
        "hidden_size": ("hidden_size", "n_embd", "d_model"),
        "num_attention_heads": ("num_attention_heads", "n_head", "num_heads"),
    }
    for key, candidates in aliases.items():
        for attr in candidates:
            val = getattr(config, attr, None)
            if val is not None:
                info[key] = val
                break

    # vocab_size is always recorded, even as None, mirroring .get() consumers.
    info["vocab_size"] = getattr(config, "vocab_size", None)
    return info
113
+
114
+
115
def _find_unembedding(model: nn.Module) -> str | None:
    """Try to find the unembedding / LM head weight matrix.

    Heuristic: the first module (in ``named_modules()`` order) whose name
    matches the head patterns and which exposes a ``weight`` attribute.
    Returns None when no such module exists.
    """
    candidates = (
        name
        for name, module in model.named_modules()
        if _HEAD_PATTERNS.search(name) and hasattr(module, "weight")
    )
    return next(candidates, None)
121
+
122
+
123
def _detect_layers(modules: list[ModuleInfo]) -> list[str]:
    """Identify repeated structural blocks that look like transformer/SSM layers.

    Groups module names of the form ``prefix.N`` (N an integer) by prefix,
    picks the most populous group, and returns its names sorted numerically
    by index. Returns an empty list when no indexed group exists.
    """
    indexed = re.compile(r"^(.+)\.(\d+)$")
    by_prefix: dict[str, list[str]] = {}
    for info in modules:
        hit = indexed.match(info.name)
        if hit is None:
            continue
        by_prefix.setdefault(hit.group(1), []).append(info.name)

    if not by_prefix:
        return []

    # Ties resolve to the first prefix encountered (dict insertion order).
    winner = max(by_prefix, key=lambda prefix: len(by_prefix[prefix]))
    return sorted(by_prefix[winner], key=lambda n: int(n.rsplit(".", 1)[-1]))
144
+
145
+
146
def discover(
    model: nn.Module,
    dummy_input: Any | None = None,
) -> ModelArchInfo:
    """Run full auto-discovery on a model.

    Parameters
    ----------
    model:
        Any ``nn.Module``, optionally with an HF ``.config`` attribute.
    dummy_input:
        If provided, used for a forward pass to capture output shapes.
        Can be a tensor, dict of tensors, or tuple of tensors.
    """
    hf_meta = _parse_hf_config(model)

    # One ModuleInfo per named submodule; the root module (name "") is skipped.
    infos: list[ModuleInfo] = [
        ModuleInfo(
            name=name,
            type_name=type(mod).__name__,
            param_count=_count_params(mod),
            role=_classify_role(name),
        )
        for name, mod in model.named_modules()
        if name != ""
    ]

    # Optionally run a single forward pass with hooks to record output shapes.
    if dummy_input is not None:
        shapes: dict[str, tuple[int, ...]] = {}
        handles = []

        def _shape_hook(mod_name: str):
            def hook_fn(_mod: nn.Module, _inp: Any, output: Any) -> None:
                # Tensor outputs are recorded directly; for tuple/list outputs
                # fall back to the first element when it is a tensor.
                if isinstance(output, torch.Tensor):
                    shapes[mod_name] = tuple(output.shape)
                elif isinstance(output, (tuple, list)) and output:
                    head = output[0]
                    if isinstance(head, torch.Tensor):
                        shapes[mod_name] = tuple(head.shape)
            return hook_fn

        for name, mod in model.named_modules():
            if name != "":
                handles.append(mod.register_forward_hook(_shape_hook(name)))

        try:
            with torch.no_grad():
                if isinstance(dummy_input, dict):
                    model(**dummy_input)
                elif isinstance(dummy_input, (tuple, list)):
                    model(*dummy_input)
                else:
                    model(dummy_input)
        finally:
            # Always detach the hooks, even when the forward pass raises.
            for handle in handles:
                handle.remove()

        for info in infos:
            info.output_shape = shapes.get(info.name)

    # Locate the unembedding / LM head; its presence marks a language model.
    unembed_name = _find_unembedding(model)

    return ModelArchInfo(
        arch_family=hf_meta.get("arch_family"),
        num_layers=hf_meta.get("num_layers"),
        hidden_size=hf_meta.get("hidden_size"),
        num_attention_heads=hf_meta.get("num_attention_heads"),
        vocab_size=hf_meta.get("vocab_size"),
        has_lm_head=unembed_name is not None,
        output_head_name=unembed_name,
        unembedding_name=unembed_name,
        modules=infos,
        layer_names=_detect_layers(infos),
    )