interpkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,465 @@
1
+ """Terminal rendering — rich tables, trees, unicode bar charts, heatmap export."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import torch
8
+ from rich.console import Console
9
+ from rich.table import Table
10
+ from rich.tree import Tree
11
+
12
+ if TYPE_CHECKING:
13
+ from interpkit.core.discovery import ModelArchInfo, ModuleInfo
14
+
15
# Shared Rich console; every render_* function in this module prints through it.
console = Console()

# Short label shown in the "Role" column for each detected module role.
# NOTE(review): the labels are bracketed, so they may be parsed as Rich markup
# unless escaped before printing.
_ROLE_TAGS = dict(
    attention="[attn]",
    mlp="[mlp]",
    head="[head]",
    norm="[norm]",
    embed="[embed]",
)
24
+
25
+
26
+ # ------------------------------------------------------------------
27
+ # Inspect rendering
28
+ # ------------------------------------------------------------------
29
+
30
+
31
def render_inspect(arch_info: "ModelArchInfo") -> None:
    """Print a module tree with types, param counts, and detected roles.

    Prints a bold header built from whichever of ``arch_family``,
    ``num_layers``, ``hidden_size`` and ``vocab_size`` are set, plus the
    summed ``param_count`` of all modules. It then prints a table with one
    row per entry of ``arch_info.modules``.

    Args:
        arch_info: Discovered architecture. Each module must expose ``name``,
            ``type_name``, ``param_count``, ``output_shape`` and ``role``.
    """
    from rich.markup import escape

    header_parts: list[str] = []
    if arch_info.arch_family:
        header_parts.append(arch_info.arch_family)
    if arch_info.num_layers is not None:
        header_parts.append(f"{arch_info.num_layers} layers")
    if arch_info.hidden_size is not None:
        header_parts.append(f"hidden={arch_info.hidden_size}")
    if arch_info.vocab_size is not None:
        header_parts.append(f"vocab={arch_info.vocab_size}")

    # NOTE(review): if ``modules`` holds both parents and their children and
    # ``param_count`` is recursive, this total double-counts — confirm.
    total_params = sum(m.param_count for m in arch_info.modules)
    header_parts.append(f"{_format_params(total_params)} params total")

    console.print(f"\n[bold]{escape(' | '.join(header_parts))}[/bold]")

    table = Table(show_header=True, header_style="bold", show_lines=False, pad_edge=False)
    table.add_column("Module", style="cyan", no_wrap=True)
    table.add_column("Type", style="dim")
    table.add_column("Params", justify="right")
    table.add_column("Output Shape", style="dim")
    table.add_column("Role", style="bold yellow")

    for m in arch_info.modules:
        role_tag = _ROLE_TAGS.get(m.role or "", "")
        shape_str = str(m.output_shape) if m.output_shape else ""
        param_str = _format_params(m.param_count) if m.param_count > 0 else ""
        # Role tags look like "[attn]". Unescaped, Rich parses them as style
        # tags and the Role column comes out blank, so escape every
        # free-form cell.
        table.add_row(
            escape(m.name),
            escape(m.type_name),
            param_str,
            escape(shape_str),
            escape(role_tag),
        )

    console.print(table)
    console.print()
63
+
64
+
65
+ # ------------------------------------------------------------------
66
+ # Causal trace rendering
67
+ # ------------------------------------------------------------------
68
+
69
+
70
def render_trace(
    results: list[dict[str, Any]],
    model_name: str,
    total_modules: int,
    top_k: int | None,
) -> None:
    """Print a ranked bar chart of causal tracing results.

    Args:
        results: One dict per scanned module with keys ``"module"`` and
            ``"effect"`` and an optional ``"role"``. Rows are printed in the
            given order, and the first entry is reported as the top component.
        model_name: Name shown in the title.
        total_modules: Number of candidate modules, shown in the title.
        top_k: Requested scan limit, or ``None`` when all modules were
            scanned. When set and fewer than ``total_modules`` results were
            given, a hint to rerun with ``--top-k 0`` is printed.
    """
    import math

    from rich.markup import escape

    scanned = len(results)
    if top_k is not None:
        title = f"Causal Trace: {model_name} (top {scanned} of {total_modules} modules)"
    else:
        title = f"Causal Trace: {model_name} ({total_modules} modules)"

    console.print(f"\n[bold]{escape(title)}[/bold]")

    if not results:
        console.print(" No significant causal effects found.")
        return

    # Bars are scaled by magnitude, so strongly negative effects stay
    # visible; they are drawn in red. Non-finite effects (NaN/inf) are left
    # out of the scale and get no bar, because int() would raise on them.
    finite_magnitudes = [abs(r["effect"]) for r in results if math.isfinite(r["effect"])]
    max_effect = max(finite_magnitudes, default=0.0)
    bar_width = 30

    table = Table(show_header=False, show_lines=False, pad_edge=False, box=None)
    table.add_column("Module", style="cyan", no_wrap=True, min_width=35)
    table.add_column("Role", style="yellow", min_width=8)
    table.add_column("Bar", no_wrap=True)
    table.add_column("Effect", justify="right", style="bold")

    for r in results:
        effect = r["effect"]
        if max_effect > 0 and math.isfinite(effect):
            fill = int(bar_width * abs(effect) / max_effect)
        else:
            fill = 0
        colour = "red" if effect < 0 else "green"
        bar = "█" * fill
        role = _ROLE_TAGS.get(r.get("role") or "", "")
        # Role tags like "[attn]" would otherwise be swallowed as Rich markup.
        table.add_row(escape(r["module"]), escape(role), f"[{colour}]{bar}[/{colour}]", f"{effect:.3f}")

    console.print(table)

    best = results[0]
    console.print(f"\n Top component: [bold cyan]{escape(best['module'])}[/bold cyan] (effect: {best['effect']:.3f})")

    if top_k is not None and scanned < total_modules:
        console.print(
            f" Run with --top-k 0 to scan all {total_modules} modules.\n"
        )
    else:
        console.print()
116
+
117
+
118
+ # ------------------------------------------------------------------
119
+ # Logit lens rendering
120
+ # ------------------------------------------------------------------
121
+
122
+
123
def render_lens(
    predictions: list[dict[str, Any]],
    model_name: str,
) -> None:
    """Print logit lens top predictions per layer.

    Args:
        predictions: One dict per layer with keys ``"layer_name"``,
            ``"top1_token"``, ``"top1_prob"``, ``"top5_tokens"`` and
            ``"top5_probs"``. The last two are paired position by position.
        model_name: Name shown in the title.
    """
    from rich.markup import escape

    console.print(f"\n[bold]Logit Lens: {escape(model_name)}[/bold]")

    table = Table(show_header=True, header_style="bold", show_lines=False)
    table.add_column("Layer", style="cyan")
    table.add_column("Top-1 Token", style="bold")
    table.add_column("Prob", justify="right")
    table.add_column("Top-5 Tokens", style="dim")

    for pred in predictions:
        # Decoded tokens can look like Rich markup: "[/INST]" raises a
        # MarkupError, and lowercase "[xyz]" disappears. Escape every token.
        top5_str = ", ".join(
            f"{escape(str(tok))} ({prob:.2f})"
            for tok, prob in zip(pred["top5_tokens"], pred["top5_probs"])
        )
        table.add_row(
            escape(str(pred["layer_name"])),
            escape(str(pred["top1_token"])),
            f"{pred['top1_prob']:.3f}",
            top5_str,
        )

    console.print(table)
    console.print()
149
+
150
+
151
+ # ------------------------------------------------------------------
152
+ # Attribution rendering
153
+ # ------------------------------------------------------------------
154
+
155
+
156
def render_attribution_tokens(
    tokens: list[str],
    scores: list[float],
) -> None:
    """Print tokens coloured by attribution score (terminal).

    Tokens are printed inline and styled by their score magnitude relative to
    the largest magnitude: > 0.7 bold red, > 0.4 yellow, > 0.15 dim,
    otherwise plain. The ten highest-magnitude tokens then follow, each with
    a bar of up to 20 cells.

    Args:
        tokens: Token strings, paired position by position with *scores*.
            Surplus entries in either list are ignored.
        scores: Attribution score per token. Only the magnitude affects
            styling and ranking.
    """
    from rich.markup import escape

    console.print("\n[bold]Attribution (gradient saliency)[/bold]")

    if not scores:
        console.print(" No attribution scores computed.")
        return

    max_score = max(abs(s) for s in scores)

    parts: list[str] = []
    for tok, score in zip(tokens, scores):
        intensity = abs(score) / max_score if max_score > 0 else 0
        # Escape so tokens such as "[/INST]" or "[mask]" print literally
        # instead of being parsed (or rejected) as Rich markup.
        safe_tok = escape(tok)
        if intensity > 0.7:
            parts.append(f"[bold red]{safe_tok}[/bold red]")
        elif intensity > 0.4:
            parts.append(f"[yellow]{safe_tok}[/yellow]")
        elif intensity > 0.15:
            parts.append(f"[dim]{safe_tok}[/dim]")
        else:
            parts.append(safe_tok)

    console.print(" " + "".join(parts))

    # Also show ranked list
    ranked = sorted(zip(tokens, scores), key=lambda x: abs(x[1]), reverse=True)
    console.print()
    for tok, score in ranked[:10]:
        bar_len = int(20 * abs(score) / max_score) if max_score > 0 else 0
        bar = "█" * bar_len
        # Pad before escaping. Rich does not print the inserted backslashes,
        # so they must not count toward the 15-character column.
        label = escape(f"{tok:>15s}")
        console.print(f" {label} [green]{bar}[/green] {score:.4f}")

    console.print()
192
+
193
+
194
def render_attribution_heatmap(
    attribution: torch.Tensor,
    output_path: str = "attribution_heatmap.png",
) -> None:
    """Save a vision attribution heatmap to a file.

    The attribution is reduced to a 2-D map before plotting:
    - a 3-D tensor is averaged over its first axis;
    - a 4-D tensor takes its first element along axis 0, then averages over
      the next axis.
    The absolute value of the map is drawn with the ``"hot"`` colormap and
    saved at 150 dpi.

    Args:
        attribution: Attribution tensor, presumably ``(H, W)``,
            ``(C, H, W)`` or ``(B, C, H, W)`` — TODO confirm against callers.
        output_path: Destination path passed to ``Figure.savefig``.

    Raises:
        ValueError: If the attribution does not reduce to a 2-D map.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # .float() first: NumPy has no bfloat16, so .numpy() on a bf16 tensor fails.
    attr_np = attribution.detach().float().cpu().numpy()

    # Collapse channel dim if present
    if attr_np.ndim == 3:
        attr_np = attr_np.mean(axis=0)
    elif attr_np.ndim == 4:
        attr_np = attr_np[0].mean(axis=0)

    if attr_np.ndim != 2:
        raise ValueError(
            f"Expected an attribution with 2, 3 or 4 dims, got shape {tuple(attribution.shape)}"
        )

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    try:
        im = ax.imshow(np.abs(attr_np), cmap="hot", interpolation="bilinear")
        ax.set_title("Gradient Attribution")
        ax.axis("off")
        fig.colorbar(im, ax=ax, fraction=0.046)
        fig.savefig(output_path, bbox_inches="tight", dpi=150)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    console.print(f"\n Attribution heatmap saved to [bold]{output_path}[/bold]\n")
219
+
220
+
221
+ # ------------------------------------------------------------------
222
+ # Patch rendering
223
+ # ------------------------------------------------------------------
224
+
225
+
226
def render_patch(result: dict[str, Any]) -> None:
    """Report the outcome of one activation patch.

    Args:
        result: Mapping with the patched ``"module"`` name and its numeric
            ``"effect"``.
    """
    header = f"\n[bold]Activation Patch at: {result['module']}[/bold]"
    console.print(header)
    effect_line = f" Normalised effect: [bold]{result['effect']:.4f}[/bold]"
    console.print(effect_line)
    console.print()
231
+
232
+
233
+ # ------------------------------------------------------------------
234
+ # Activations rendering
235
+ # ------------------------------------------------------------------
236
+
237
+
238
def render_activations(cache: dict[str, torch.Tensor]) -> None:
    """Print a summary table of extracted activations.

    Each cache entry becomes one row: the shape, plus L2 norm, mean,
    standard deviation, min and max computed in float32.

    Args:
        cache: Mapping from module name to the activation tensor captured there.
    """
    from rich.markup import escape

    console.print("\n[bold]Activations[/bold]")

    table = Table(show_header=True, header_style="bold", show_lines=False)
    table.add_column("Module", style="cyan", no_wrap=True)
    table.add_column("Shape", style="dim")
    table.add_column("Norm", justify="right")
    table.add_column("Mean", justify="right")
    table.add_column("Std", justify="right")
    table.add_column("Min", justify="right")
    table.add_column("Max", justify="right")

    for name, tensor in cache.items():
        table.add_row(escape(name), *_format_tensor_stats(tensor))

    console.print(table)
    console.print()


def _format_tensor_stats(tensor: torch.Tensor) -> tuple[str, str, str, str, str, str]:
    """Return display strings ``(shape, norm, mean, std, min, max)`` for *tensor*.

    Statistics are computed on a detached float32 copy. ``min``/``max`` raise
    on an empty tensor, so all statistics are shown as ``"-"`` in that case.
    """
    shape = str(tuple(tensor.shape))
    t = tensor.detach().float()
    if t.numel() == 0:
        return (shape, "-", "-", "-", "-", "-")
    return (
        shape,
        f"{t.norm().item():.3f}",
        f"{t.mean().item():.4f}",
        f"{t.std().item():.4f}",
        f"{t.min().item():.4f}",
        f"{t.max().item():.4f}",
    )
265
+
266
+
267
+ # ------------------------------------------------------------------
268
+ # Ablation rendering
269
+ # ------------------------------------------------------------------
270
+
271
+
272
def render_ablate(result: dict[str, Any]) -> None:
    """Report the outcome of one ablation.

    Args:
        result: Mapping with the ablated ``"module"`` name, its numeric
            ``"effect"`` and an optional ``"method"`` label (``"zero"`` when
            absent).
    """
    method_name = result.get("method", "zero")
    target = result["module"]
    console.print(f"\n[bold]Ablation ({method_name}) at: {target}[/bold]")
    summary = f" Effect on output: [bold]{result['effect']:.4f}[/bold]"
    console.print(summary)
    console.print()
278
+
279
+
280
+ # ------------------------------------------------------------------
281
+ # Attention rendering
282
+ # ------------------------------------------------------------------
283
+
284
+
285
def render_attention(
    attention_data: list[dict[str, Any]],
    tokens: list[str] | None,
    model_name: str,
) -> None:
    """Print attention summary per layer/head.

    Args:
        attention_data: One dict per (layer, head) with keys ``"layer"`` and
            ``"head"``, plus two optional keys:
            - ``"top_pairs"``: ``(src, tgt, score)`` triples, of which the
              first three are shown;
            - ``"entropy"``: shown as 0.00 when absent.
        tokens: Token strings used to label positions. Positions outside
            ``tokens`` (or all positions, when it is ``None`` or empty) are
            shown as their numeric index.
        model_name: Name shown in the title.
    """
    from rich.markup import escape

    console.print(f"\n[bold]Attention Patterns: {escape(model_name)}[/bold]")

    if not attention_data:
        console.print(" No attention data captured.")
        return

    def _label(pos: int) -> str:
        # Only non-negative, in-range positions map to a token; a negative
        # index would otherwise silently wrap to the end of the sequence.
        if tokens and 0 <= pos < len(tokens):
            return tokens[pos]
        return str(pos)

    table = Table(show_header=True, header_style="bold", show_lines=False)
    table.add_column("Layer", style="cyan")
    table.add_column("Head", style="cyan", justify="right")
    table.add_column("Top Attention", style="dim")
    table.add_column("Entropy", justify="right")

    for entry in attention_data:
        top_attn_parts = [
            f"{_label(src)}->{_label(tgt)} ({score:.2f})"
            for src, tgt, score in entry.get("top_pairs", [])[:3]
        ]
        table.add_row(
            str(entry["layer"]),
            str(entry["head"]),
            # Tokens such as "[/INST]" would otherwise be parsed as Rich markup.
            escape(", ".join(top_attn_parts)),
            f"{entry.get('entropy', 0.0):.2f}",
        )

    console.print(table)
    console.print()
318
+
319
+
320
+ # ------------------------------------------------------------------
321
+ # Steering rendering
322
+ # ------------------------------------------------------------------
323
+
324
+
325
def render_steer(
    original_tokens: list[tuple[str, float]],
    steered_tokens: list[tuple[str, float]],
    module_name: str,
    scale: float,
    max_rows: int = 10,
) -> None:
    """Print side-by-side comparison of top tokens with and without steering.

    Args:
        original_tokens: ``(token, probability)`` pairs without steering, in
            rank order.
        steered_tokens: ``(token, probability)`` pairs with steering, in rank
            order. The shorter list is padded with blank cells.
        module_name: Module the steering was applied at (title only).
        scale: Steering scale (title only).
        max_rows: Maximum number of ranks to show. Defaults to 10, the
            previously hard-coded limit.
    """
    from rich.markup import escape

    console.print(f"\n[bold]Steering at: {escape(module_name)} (scale={scale})[/bold]")

    table = Table(show_header=True, header_style="bold", show_lines=False)
    table.add_column("Rank", justify="right", style="dim")
    table.add_column("Original Token", style="cyan")
    table.add_column("Prob", justify="right")
    table.add_column("Steered Token", style="green")
    table.add_column("Prob", justify="right")

    def _cells(ranked: list[tuple[str, float]], i: int) -> tuple[str, str]:
        # Blank cells past the end of a list. Tokens are escaped so text like
        # "[/s]" is not parsed as Rich markup.
        if i >= len(ranked):
            return "", ""
        token, prob = ranked[i][0], ranked[i][1]
        return escape(str(token)), f"{prob:.3f}"

    n_rows = min(max(len(original_tokens), len(steered_tokens)), max_rows)
    for i in range(n_rows):
        table.add_row(str(i + 1), *_cells(original_tokens, i), *_cells(steered_tokens, i))

    console.print(table)
    console.print()
351
+
352
+
353
+ # ------------------------------------------------------------------
354
+ # Probe rendering
355
+ # ------------------------------------------------------------------
356
+
357
+
358
def render_probe(result: dict[str, Any]) -> None:
    """Print probe results — accuracy and top features.

    Args:
        result: Mapping with these keys:
            - ``"module"`` and ``"accuracy"`` (required);
            - ``"train_accuracy"`` (optional; skipped when ``None``);
            - ``"top_features"`` (optional list of ``(dim_index, weight)``
              pairs; the first ten are shown).
    """
    from rich.markup import escape

    console.print(f"\n[bold]Linear Probe at: {escape(str(result['module']))}[/bold]")
    console.print(f" Accuracy: [bold]{result['accuracy']:.3f}[/bold]")

    if result.get("train_accuracy") is not None:
        console.print(f" Train accuracy: {result['train_accuracy']:.3f}")

    top_features = result.get("top_features")
    if top_features:
        shown = top_features[:10]
        # Scale against the largest magnitude actually shown, not the first
        # entry. This keeps bars within 20 cells if the list is not sorted by
        # |weight|, and avoids blank bars when the first weight is 0.
        max_weight = max(abs(weight) for _, weight in shown)
        console.print("\n Top features by weight magnitude:")
        for idx, weight in shown:
            bar_len = int(20 * abs(weight) / max_weight) if max_weight > 0 else 0
            bar = "█" * bar_len
            console.print(f" dim {idx:>5d} [green]{bar}[/green] {weight:.4f}")

    console.print()
374
+
375
+
376
+ # ------------------------------------------------------------------
377
+ # Diff rendering
378
+ # ------------------------------------------------------------------
379
+
380
+
381
def render_diff(
    results: list[dict[str, Any]],
    model_a_name: str,
    model_b_name: str,
) -> None:
    """Print per-module activation distance between two models.

    Args:
        results: One dict per module with keys ``"module"`` and
            ``"distance"``. Rows are printed in the given order, and the
            first entry is reported as the most changed module.
        model_a_name: Name of the first model (title only).
        model_b_name: Name of the second model (title only).
    """
    import math

    from rich.markup import escape

    console.print(f"\n[bold]Model Diff: {escape(model_a_name)} vs {escape(model_b_name)}[/bold]")

    if not results:
        console.print(" No differences computed.")
        return

    # Only finite distances set the bar scale. A NaN/inf distance gets no bar
    # instead of crashing int(). NOTE(review): NaN can presumably arise from
    # a cosine distance on an all-zero activation — confirm with the caller.
    finite = [r["distance"] for r in results if math.isfinite(r["distance"])]
    max_dist = max(finite, default=0.0)
    bar_width = 30

    table = Table(show_header=False, show_lines=False, pad_edge=False, box=None)
    table.add_column("Module", style="cyan", no_wrap=True, min_width=35)
    table.add_column("Bar", no_wrap=True)
    table.add_column("Cosine Dist", justify="right", style="bold")

    for r in results:
        distance = r["distance"]
        if max_dist > 0 and math.isfinite(distance):
            fill = int(bar_width * distance / max_dist)
        else:
            fill = 0
        bar = "█" * fill
        table.add_row(escape(r["module"]), f"[green]{bar}[/green]", f"{distance:.4f}")

    console.print(table)

    best = results[0]
    console.print(f"\n Most changed: [bold cyan]{escape(best['module'])}[/bold cyan] (distance: {best['distance']:.4f})")

    console.print()
413
+
414
+
415
+ # ------------------------------------------------------------------
416
+ # SAE features rendering
417
+ # ------------------------------------------------------------------
418
+
419
+
420
def render_features(result: dict[str, Any]) -> None:
    """Show an SAE feature decomposition: summary line plus ranked features.

    Args:
        result: Mapping with these keys:
            - ``"module"``, ``"num_active_features"``, ``"total_features"``,
              ``"sparsity"`` (fraction, shown as a percentage) and
              ``"reconstruction_error"``;
            - optionally ``"top_features"``: ``(feature_index, activation)``
              pairs, all of which are tabulated.
    """
    module = result["module"]
    console.print(f"\n[bold]SAE Features at: {module}[/bold]")

    active = result["num_active_features"]
    total = result["total_features"]
    sparsity = result["sparsity"]
    recon = result["reconstruction_error"]
    console.print(
        f" Active features: [bold]{active}[/bold] / {total} "
        f"| Sparsity: {sparsity:.2%} "
        f"| Reconstruction error: {recon:.4f}"
    )

    top = result.get("top_features", [])
    if not top:
        console.print(" No active features found.")
        console.print()
        return

    # Largest magnitude sets the full bar length.
    peak = max(abs(value) for _, value in top)

    table = Table(show_header=True, header_style="bold", show_lines=False)
    table.add_column("Rank", justify="right", style="dim")
    table.add_column("Feature", style="cyan", justify="right")
    table.add_column("Activation", justify="right")
    table.add_column("Bar", no_wrap=True)

    bar_width = 25
    for rank, (feature_idx, value) in enumerate(top, start=1):
        fill = 0
        if peak > 0:
            fill = int(bar_width * abs(value) / peak)
        table.add_row(str(rank), str(feature_idx), f"{value:.4f}", "[green]" + "█" * fill + "[/green]")

    console.print(table)
    console.print()
451
+
452
+
453
+ # ------------------------------------------------------------------
454
+ # Helpers
455
+ # ------------------------------------------------------------------
456
+
457
+
458
def _format_params(n: int) -> str:
    """Format a parameter count compactly, e.g. ``124_439_808 -> "124.4M"``.

    Counts below 1,000 are returned verbatim. Larger counts get one decimal
    and a K/M/B suffix. A value that would round to ``1000.0`` in one unit is
    promoted to the next larger unit, so ``999_999`` gives ``"1.0M"`` rather
    than ``"1000.0K"``.
    """
    units = ((1_000_000_000, "B"), (1_000_000, "M"), (1_000, "K"))
    for position, (divisor, suffix) in enumerate(units):
        if n >= divisor:
            text = f"{n / divisor:.1f}"
            # Rounding can carry into the next unit (999.96K -> "1000.0").
            if position > 0 and float(text) >= 1000:
                bigger_divisor, bigger_suffix = units[position - 1]
                return f"{n / bigger_divisor:.1f}{bigger_suffix}"
            return f"{text}{suffix}"
    return str(n)
@@ -0,0 +1,174 @@
1
+ """TransformerLens interop — bidirectional name translation between native and TL hook names."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ if TYPE_CHECKING:
9
+ from interpkit.core.discovery import ModelArchInfo
10
+
11
+ # ── TL canonical hook names ──────────────────────────────────────
12
+ #
13
+ # TL standardizes all transformer architectures to:
14
+ # blocks.{N}.hook_resid_pre
15
+ # blocks.{N}.hook_resid_post
16
+ # blocks.{N}.attn.hook_q / hook_k / hook_v / hook_z / hook_result
17
+ # blocks.{N}.attn.hook_pattern
18
+ # blocks.{N}.hook_attn_out
19
+ # blocks.{N}.hook_mlp_out
20
+ # blocks.{N}.mlp.hook_pre / hook_post
21
+ # blocks.{N}.ln1 / ln2
22
+ #
23
+ # These map to native names like:
24
+ # transformer.h.{N} -> blocks.{N}
25
+ # transformer.h.{N}.attn -> blocks.{N}.attn
26
+ # transformer.h.{N}.mlp -> blocks.{N}.mlp
27
+ # model.layers.{N}.self_attn -> blocks.{N}.attn
28
+ # model.layers.{N}.mlp -> blocks.{N}.mlp
29
+
30
+ # Patterns for extracting layer index and component from native names
31
# Split a native module name at the first "." or "[" (after the first
# character) that is followed by digits:
#   - ``prefix``: everything before it;
#   - ``idx``: the digits;
#   - ``suffix``: the remainder, after any run of "." / "]" following the
#     digits.
# Accepts both "h.8.mlp" and "layers[8].mlp" spellings.
_NATIVE_LAYER_RE = re.compile(
    r"^(?P<prefix>.+?)"
    r"[.\[](?P<idx>\d+)"
    r"[.\]]*"
    r"(?P<suffix>.*)$"
)
34
+
35
+ # Common native -> TL component mappings
36
# (case-insensitive regex over a "."-prefixed native suffix, TL suffix).
# to_tl_name returns on the first entry that matches, so order matters.
_COMPONENT_TO_TL: list[tuple[re.Pattern, str]] = [
    (re.compile(native_pattern, re.I), tl_suffix)
    for native_pattern, tl_suffix in (
        (r"\.self_attn$|\.attn$|\.attention$", ".attn"),
        (r"\.mlp$|\.ffn$|\.feed_forward$", ".mlp"),
        (r"\.ln_?1$|\.input_layernorm$", ".ln1"),
        (r"\.ln_?2$|\.post_attention_layernorm$", ".ln2"),
        (r"\.self_attn\.q_proj$|\.attn\.q_proj$", ".attn.hook_q"),
        (r"\.self_attn\.k_proj$|\.attn\.k_proj$", ".attn.hook_k"),
        (r"\.self_attn\.v_proj$|\.attn\.v_proj$", ".attn.hook_v"),
        (r"\.self_attn\.o_proj$|\.attn\.c_proj$|\.attn\.out_proj$", ".attn.hook_result"),
    )
]
46
+
47
+ # TL hook -> native component suffix patterns
48
# (TL suffix, native candidate suffixes). to_native_name picks the first
# candidate present in the model's module tree, falling back to the first
# one listed.
_TL_TO_COMPONENT: list[tuple[str, list[str]]] = [
    (tl_suffix, candidates.split())
    for tl_suffix, candidates in (
        (".attn", "attn self_attn attention"),
        (".mlp", "mlp ffn feed_forward"),
        (".ln1", "ln_1 ln1 input_layernorm"),
        (".ln2", "ln_2 ln2 post_attention_layernorm"),
        (".attn.hook_q", "attn.q_proj self_attn.q_proj"),
        (".attn.hook_k", "attn.k_proj self_attn.k_proj"),
        (".attn.hook_v", "attn.v_proj self_attn.v_proj"),
        (".attn.hook_result", "attn.c_proj attn.out_proj self_attn.o_proj"),
    )
]
58
+
59
+
60
def to_tl_name(native_name: str, arch_info: "ModelArchInfo | None" = None) -> str:
    """Translate a native PyTorch module name to the corresponding TL hook name.

    The name is split at its first numeric component (the layer index) by
    ``_NATIVE_LAYER_RE``. The remainder is then looked up in
    ``_COMPONENT_TO_TL``.

    A pattern must match the *whole* remainder. Otherwise a deeper module that
    merely ends in a known component would collapse onto the layer's own
    component — e.g. GPT-Neo's ``attn.attention`` becoming ``attn``.
    Unrecognised remainders are kept verbatim under ``blocks.{N}``. Names
    without a layer index are returned unchanged.

    Args:
        native_name: Dotted module name, e.g. from ``named_modules()``.
        arch_info: Currently unused; accepted for symmetry with
            :func:`to_native_name`.

    Returns:
        The TL-style name.

    Examples::

        to_tl_name("transformer.h.8.mlp") -> "blocks.8.mlp"
        to_tl_name("transformer.h.8.attn") -> "blocks.8.attn"
        to_tl_name("model.layers.3.self_attn.q_proj") -> "blocks.3.attn.hook_q"
        to_tl_name("transformer.h.0.attn.attention") -> "blocks.0.attn.attention"
    """
    m = _NATIVE_LAYER_RE.match(native_name)
    if m is None:
        return native_name

    idx = m.group("idx")
    clean_suffix = m.group("suffix").lstrip(".")

    # Bare layer reference (e.g. "transformer.h.8")
    if not clean_suffix:
        return f"blocks.{idx}"

    # Patterns are written against a "."-prefixed suffix.
    dotted_suffix = f".{clean_suffix}"
    for pattern, tl_suffix in _COMPONENT_TO_TL:
        # fullmatch, not search: the pattern must account for the entire suffix.
        if pattern.fullmatch(dotted_suffix):
            return f"blocks.{idx}{tl_suffix}"

    # Fallback: preserve suffix as-is under blocks.{N}
    return f"blocks.{idx}.{clean_suffix}"
91
+
92
+
93
def to_native_name(
    tl_name: str,
    arch_info: "ModelArchInfo | None" = None,
) -> str:
    """Translate a TL hook name back to the most likely native module name.

    Requires ``arch_info`` from a loaded model to resolve the native layer
    prefix (e.g. ``transformer.h`` vs ``model.layers``). Without it, the
    prefix defaults to ``"blocks"``.

    Resolution rules:
    - Names not of the form ``blocks.{N}[.suffix]`` are returned unchanged.
    - ``blocks.{N}`` and ``blocks.{N}.hook_resid_post`` (the block output)
      map to the native block module.
    - ``hook_attn_out`` / ``hook_mlp_out`` map to the block's attention / MLP
      module.
    - Suffixes listed in ``_TL_TO_COMPONENT`` map to the first candidate
      present in ``arch_info.modules``, or the first candidate listed.
    - Any other suffix is kept with every ``"hook_"`` removed.

    Examples::

        to_native_name("blocks.8.mlp", arch_info) -> "transformer.h.8.mlp"
        to_native_name("blocks.3.attn.hook_q", arch_info) -> "transformer.h.3.attn.q_proj"
        to_native_name("blocks.8.hook_resid_post", arch_info) -> "transformer.h.8"
    """
    # Parse TL name: blocks.{N}.{rest}
    tl_match = re.match(r"^blocks\.(\d+)(?:\.(.+))?$", tl_name)
    if tl_match is None:
        return tl_name

    idx = tl_match.group(1)
    tl_suffix = tl_match.group(2) or ""

    # Determine native layer prefix from arch_info
    prefix = _infer_native_prefix(arch_info)
    block_name = f"{prefix}.{idx}"

    # hook_resid_post is the residual stream leaving the block, i.e. the
    # block module's own output.
    if not tl_suffix or tl_suffix == "hook_resid_post":
        return block_name

    # Block-level outputs of the attention / MLP sublayers resolve to those
    # sublayers' modules.
    tl_suffix = {"hook_attn_out": "attn", "hook_mlp_out": "mlp"}.get(tl_suffix, tl_suffix)

    # Try specific TL -> native mappings
    for tl_component, native_candidates in _TL_TO_COMPONENT:
        if tl_suffix == tl_component.lstrip("."):
            # Pick the first candidate that exists in the module tree, or fall back to first
            if arch_info is not None:
                module_names = {m.name for m in arch_info.modules}
                for candidate in native_candidates:
                    full = f"{block_name}.{candidate}"
                    if full in module_names:
                        return full
            return f"{block_name}.{native_candidates[0]}"

    # Strip TL-specific "hook_" prefixes for unknown suffixes
    clean = tl_suffix.replace("hook_", "")
    return f"{block_name}.{clean}"
137
+
138
+
139
def list_tl_hooks(model: Any) -> list[str]:
    """List all TL hook point names on a model, sorted lexicographically.

    Uses ``model.hook_dict`` when present (as on a HookedTransformer).
    Otherwise it collects the names of sub-modules whose class is named
    ``HookPoint``.

    Returns:
        The sorted hook names. The list is empty when the model has neither
        ``hook_dict`` nor ``named_modules`` (e.g. it is not an
        ``nn.Module``), or when no hook points are found.
    """
    hook_dict = getattr(model, "hook_dict", None)
    if hook_dict is not None:
        return sorted(hook_dict.keys())

    # Fallback: look for HookPoint modules (matched by class name, so
    # TransformerLens need not be importable here).
    named_modules = getattr(model, "named_modules", None)
    if named_modules is None:
        return []
    return sorted(
        name for name, mod in named_modules() if type(mod).__name__ == "HookPoint"
    )
154
+
155
+
156
def _infer_native_prefix(arch_info: "ModelArchInfo | None") -> str:
    """Return the dotted prefix under which the model's layers are numbered.

    Candidate names are checked in this order:
    1. the first entry of ``arch_info.layer_names`` (if any);
    2. every module name in ``arch_info.modules``, in order.
    The first name ending in ``.<digits>`` wins, and the part before that
    ending is returned. Falls back to ``"blocks"`` when ``arch_info`` is
    ``None`` or nothing matches.
    """
    if arch_info is None:
        return "blocks"

    def _candidate_names():
        # Lazy: ``modules`` is only read if the first layer name does not match.
        if arch_info.layer_names:
            yield arch_info.layer_names[0]
        for mod in arch_info.modules:
            yield mod.name

    for name in _candidate_names():
        found = re.match(r"^(.+?)\.\d+$", name)
        if found:
            return found.group(1)

    return "blocks"
File without changes