interpkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,551 @@
1
+ """Universal model wrapper — load any HF model or nn.Module and run mech interp ops."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from interpkit.core.discovery import ModelArchInfo, discover
11
+ from interpkit.core.inputs import prepare_input, prepare_pair
12
+ from interpkit.core.registry import Registration, get_registration
13
+
14
+
15
class Model:
    """Wraps a PyTorch model for mechanistic interpretability operations.

    Created via :func:`interpkit.load` — not instantiated directly.

    The wrapper stores the underlying ``nn.Module``, an optional tokenizer
    and image processor, the discovered architecture description and a
    single-input activation cache.  Each public operation is a thin delegate
    to a ``run_*`` function in ``interpkit.ops``; those modules are imported
    inside the method that needs them.
    """

    def __init__(
        self,
        model: nn.Module,
        *,
        tokenizer: Any | None = None,
        image_processor: Any | None = None,
        arch_info: ModelArchInfo,
        registration: Registration | None = None,
        device: torch.device | str = "cpu",
    ) -> None:
        """Store the wrapped model and its companions; start with an empty cache."""
        self._model = model
        self._tokenizer = tokenizer
        self._image_processor = image_processor
        self.arch_info = arch_info
        self._registration = registration
        self._device = torch.device(device)
        # Activations keyed by module name, valid only for the input whose
        # hash is stored in ``_cache_input_hash``.
        self._cache: dict[str, torch.Tensor] = {}
        self._cache_input_hash: int | None = None

    # ------------------------------------------------------------------
    # Input preparation
    # ------------------------------------------------------------------

    def _prepare(self, raw: str | torch.Tensor | Any) -> dict[str, torch.Tensor] | torch.Tensor:
        """Turn a raw input into model-ready form via ``prepare_input``."""
        return prepare_input(
            raw,
            tokenizer=self._tokenizer,
            image_processor=self._image_processor,
            device=self._device,
        )

    def _prepare_pair(
        self, raw_a: str | torch.Tensor | Any, raw_b: str | torch.Tensor | Any,
    ) -> tuple[dict[str, torch.Tensor] | torch.Tensor, dict[str, torch.Tensor] | torch.Tensor]:
        """Prepare two raw inputs together via ``prepare_pair``."""
        return prepare_pair(
            raw_a, raw_b,
            tokenizer=self._tokenizer,
            image_processor=self._image_processor,
            device=self._device,
        )

    def _forward(self, model_input: dict[str, torch.Tensor] | torch.Tensor) -> torch.Tensor:
        """Run a forward pass and return the output logits / final tensor.

        A dict input is splatted as keyword arguments; anything else is
        passed positionally.  The pass runs under ``torch.no_grad()``.

        Raises
        ------
        TypeError
            If the model output has no ``logits`` attribute and is neither a
            tensor nor a tuple/list.
        """
        with torch.no_grad():
            if isinstance(model_input, dict):
                out = self._model(**model_input)
            else:
                out = self._model(model_input)

        if hasattr(out, "logits"):
            return out.logits
        if isinstance(out, torch.Tensor):
            return out
        if isinstance(out, (tuple, list)):
            # For tuple/list outputs the first element is taken as the result.
            return out[0]
        raise TypeError(f"Unexpected model output type: {type(out).__name__}")

    # ------------------------------------------------------------------
    # Activation cache
    # ------------------------------------------------------------------

    @property
    def cached(self) -> bool:
        """True if the activation cache is populated."""
        return len(self._cache) > 0

    @staticmethod
    def _normalize_module_names(at: str | list[str]) -> list[str]:
        """Return *at* as a list of module names.

        A bare string is one module name, not a sequence of characters, so it
        is wrapped in a single-element list.  Any other iterable is copied
        into a new list.
        """
        if isinstance(at, str):
            return [at]
        return list(at)

    def cache(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        at: str | list[str] | None = None,
    ) -> "Model":
        """Run a forward pass and cache activations for reuse by other operations.

        Parameters
        ----------
        input_data:
            The input to cache activations for.
        at:
            Module name or list of module names to cache.  If None, caches
            all modules with parameters.

        Returns ``self`` for chaining.
        """
        from interpkit.ops.activations import run_activations

        model_input = self._prepare(input_data)
        input_hash = _hash_input(model_input)

        if at is None:
            names = [m.name for m in self.arch_info.modules if m.param_count > 0]
        else:
            # Always hand run_activations a list so a string ``at`` is cached
            # under its full module name instead of its first character.
            names = self._normalize_module_names(at)

        result = run_activations(self, input_data, at=names, print_stats=False)
        if isinstance(result, dict):
            self._cache = result
        elif names:
            # Fallback kept from the original: a bare result is stored under
            # the first requested name.
            self._cache = {names[0]: result}
        else:
            self._cache = {}
        self._cache_input_hash = input_hash
        return self

    def clear_cache(self) -> None:
        """Free cached activation tensors."""
        self._cache.clear()
        self._cache_input_hash = None

    def _get_cached(
        self,
        input_data: str | torch.Tensor | Any,
        module_names: list[str],
        *,
        _prepared_input: dict[str, torch.Tensor] | torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor] | None:
        """Return cached activations if available for this input, else None.

        Pass *_prepared_input* to avoid re-tokenizing when the caller
        already has the prepared input.

        None is returned when the cache is empty, when the input hash differs
        from the cached one, or when any requested module is missing.
        """
        if not self._cache:
            return None

        model_input = _prepared_input if _prepared_input is not None else self._prepare(input_data)
        input_hash = _hash_input(model_input)

        if input_hash != self._cache_input_hash:
            return None

        if all(name in self._cache for name in module_names):
            return {name: self._cache[name] for name in module_names}

        return None

    # ------------------------------------------------------------------
    # Public operations — delegate to ops/
    # ------------------------------------------------------------------

    def inspect(self) -> None:
        """Print the model's module tree with types, param counts, and detected roles."""
        from interpkit.ops.inspect import run_inspect

        run_inspect(self)

    def activations(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        at: str | list[str],
    ) -> dict[str, torch.Tensor] | torch.Tensor:
        """Extract raw activation tensors at one or more named modules.

        Returns a single tensor if *at* is a string, or a dict if *at* is a list.
        """
        from interpkit.ops.activations import run_activations

        return run_activations(self, input_data, at=at)

    def steer_vector(
        self,
        positive: str | torch.Tensor | Any,
        negative: str | torch.Tensor | Any,
        *,
        at: str,
    ) -> torch.Tensor:
        """Extract a steering vector: activation(positive) - activation(negative)."""
        from interpkit.ops.steer import run_steer_vector

        return run_steer_vector(self, positive, negative, at=at)

    def steer(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        vector: torch.Tensor,
        at: str,
        scale: float = 2.0,
        save: str | None = None,
    ) -> dict[str, Any]:
        """Run inference with a steering vector added at module *at*.

        Shows side-by-side comparison of original vs steered top predictions.
        Pass ``save="path.png"`` to export a matplotlib figure.
        """
        from interpkit.ops.steer import run_steer

        return run_steer(self, input_data, vector=vector, at=at, scale=scale, save=save)

    def attention(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        layer: int | None = None,
        head: int | None = None,
        save: str | None = None,
        html: str | None = None,
    ) -> list[dict[str, Any]] | None:
        """Show attention patterns. Returns None for non-transformer models.

        Pass ``save="path.png"`` to export a matplotlib heatmap.
        Pass ``html="path.html"`` to export an interactive HTML page.
        """
        from interpkit.ops.attention import run_attention

        return run_attention(self, input_data, layer=layer, head=head, save=save, html=html)

    def ablate(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        at: str,
        method: str = "zero",
    ) -> dict[str, Any]:
        """Ablate a module (zero or mean) and measure effect on output.

        Returns a dict with ``effect`` (0 = no change, 1 = max change).
        """
        from interpkit.ops.ablate import run_ablate

        return run_ablate(self, input_data, at=at, method=method)

    def patch(
        self,
        clean: str | torch.Tensor | Any,
        corrupted: str | torch.Tensor | Any,
        *,
        at: str,
    ) -> dict[str, Any]:
        """Activation patching: swap a single module's output from clean into corrupted.

        Returns a dict with ``clean_logits``, ``corrupted_logits``, ``patched_logits``,
        and ``effect`` (normalised scalar measuring how much the patch restored clean behaviour).
        """
        from interpkit.ops.patch import run_patch

        return run_patch(self, clean, corrupted, at=at)

    def trace(
        self,
        clean: str | torch.Tensor | Any,
        corrupted: str | torch.Tensor | Any,
        *,
        top_k: int | None = 20,
        save: str | None = None,
        html: str | None = None,
    ) -> list[dict[str, Any]]:
        """Causal tracing: rank modules by how much patching them restores clean output.

        Uses a two-phase approach: fast proxy (activation norm delta) to shortlist,
        then full patch-and-measure on the top-k candidates.
        Pass ``save="path.png"`` to export a matplotlib bar chart.
        Pass ``html="path.html"`` to export an interactive HTML page.
        """
        from interpkit.ops.trace import run_trace

        return run_trace(self, clean, corrupted, top_k=top_k, save=save, html=html)

    def lens(
        self,
        text: str | torch.Tensor | Any,
        *,
        save: str | None = None,
    ) -> list[dict[str, Any]] | None:
        """Logit lens: project each layer's output to vocabulary space.

        Only available for language models with a detectable unembedding matrix.
        Pass ``save="path.png"`` to export a matplotlib heatmap.
        """
        from interpkit.ops.lens import run_lens

        return run_lens(self, text, save=save)

    def probe(
        self,
        texts: list[str],
        labels: list[int],
        *,
        at: str,
    ) -> dict[str, Any]:
        """Train a linear probe on activations at module *at*.

        Returns accuracy, top features by weight magnitude.
        Requires scikit-learn (``pip install interpkit[probe]``), falls back to
        a torch-based probe otherwise.
        """
        from interpkit.ops.probe import run_probe

        return run_probe(self, texts, labels, at=at)

    def features(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        at: str,
        sae: str | Any,
        top_k: int = 20,
    ) -> dict[str, Any]:
        """Decompose activations at *at* through a Sparse Autoencoder.

        Parameters
        ----------
        sae:
            Either a HuggingFace repo ID (``"jbloom/GPT2-Small-SAEs-Reformatted"``)
            or a pre-loaded :class:`interpkit.ops.sae.SAE` object.

        Raises
        ------
        TypeError
            If *sae* is neither a string nor an ``SAE`` instance.
        """
        from interpkit.ops.sae import SAE as SAEClass
        from interpkit.ops.sae import load_sae, run_features

        if isinstance(sae, str):
            # A string is treated as a repo ID and loaded onto this model's device.
            sae = load_sae(sae, device=self._device)
        elif not isinstance(sae, SAEClass):
            raise TypeError(f"Expected SAE or HF repo ID string, got {type(sae).__name__}")

        return run_features(self, input_data, at=at, sae=sae, top_k=top_k)

    def attribute(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        target: int | None = None,
        save: str | None = None,
        html: str | None = None,
    ) -> None:
        """Gradient saliency over the input.

        For NLP: prints coloured tokens by importance.
        For vision: saves a heatmap image.
        Pass ``save="path.png"`` to export a matplotlib figure.
        Pass ``html="path.html"`` to export an interactive HTML page.
        """
        from interpkit.ops.attribute import run_attribute

        run_attribute(self, input_data, target=target, save=save, html=html)
347
+
348
+
349
+ # ======================================================================
350
+ # Top-level loader
351
+ # ======================================================================
352
+
353
+
354
def load(
    model_or_name: str | nn.Module,
    *,
    tokenizer: Any | None = None,
    image_processor: Any | None = None,
    device: str | torch.device | None = None,
) -> Model:
    """Load a model for mechanistic interpretability.

    Parameters
    ----------
    model_or_name:
        A HuggingFace model ID (``"gpt2"``, ``"microsoft/resnet-50"``)
        or an existing ``nn.Module`` instance.
    tokenizer:
        An explicit tokenizer. Auto-loaded for HF models if not provided.
    image_processor:
        An explicit image processor. Auto-loaded for HF vision models if not provided.
    device:
        Device to run on. Defaults to CUDA if available, else CPU.

    Returns
    -------
    Model
        The wrapped model, moved to *device* and switched to eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    is_tl = False

    if isinstance(model_or_name, str):
        model, tokenizer, image_processor = _load_from_hf(
            model_or_name, tokenizer=tokenizer, image_processor=image_processor, device=device
        )
    else:
        model = model_or_name

        # Detect TransformerLens HookedTransformer
        is_tl = _is_hooked_transformer(model)
        if is_tl and tokenizer is None:
            # Borrow the tokenizer attached to the model, if there is one.
            tokenizer = getattr(model, "tokenizer", None)

        model.to(device)

    # Give any tokenizer without a pad token the EOS token as padding.  This
    # is applied on every path so an explicitly passed tokenizer alongside an
    # nn.Module is treated the same as one loaded by name.
    if tokenizer is not None and hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
        eos_token = getattr(tokenizer, "eos_token", None)
        if eos_token is not None:
            tokenizer.pad_token = eos_token

    model.eval()
    registration = get_registration(model)

    # Build a dummy input for shape enumeration
    # TL models accept a raw token tensor, not tokenizer dict kwargs
    if is_tl:
        dummy = torch.tensor([[0]], device=device)
    else:
        dummy = _make_dummy_input(model, tokenizer=tokenizer, image_processor=image_processor, device=device)
    arch_info = discover(model, dummy_input=dummy)
    arch_info.is_tl_model = is_tl

    # Merge manual registration into arch_info
    if registration is not None:
        if registration.layers:
            arch_info.layer_names = registration.layers
        if registration.output_head:
            # The registered output head is used as both the output head and
            # the unembedding, and marks the model as having an LM head.
            arch_info.output_head_name = registration.output_head
            arch_info.unembedding_name = registration.output_head
            arch_info.has_lm_head = True
        for mod_info in arch_info.modules:
            if mod_info.name in registration.attention_modules:
                mod_info.role = "attention"
            elif mod_info.name in registration.mlp_modules:
                mod_info.role = "mlp"

    return Model(
        model,
        tokenizer=tokenizer,
        image_processor=image_processor,
        arch_info=arch_info,
        registration=registration,
        device=device,
    )
433
+
434
+
435
def _load_from_hf(
    name: str,
    *,
    tokenizer: Any | None,
    image_processor: Any | None,
    device: str | torch.device,
) -> tuple[nn.Module, Any | None, Any | None]:
    """Load a HuggingFace model plus, where available, its tokenizer and image processor.

    Parameters
    ----------
    name:
        HuggingFace model ID passed to the ``from_pretrained`` calls.
    tokenizer, image_processor:
        Explicit objects to use; each is auto-loaded only when None.
    device:
        Device the model is moved to.

    Returns
    -------
    tuple
        ``(model, tokenizer, image_processor)``; the last two may be None
        when nothing was supplied and auto-loading failed.
    """
    import copy

    import transformers
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    # The config does not depend on which Auto class is tried, so fetch it
    # once instead of once per attempt.
    try:
        config = AutoConfig.from_pretrained(name)
    except (ValueError, OSError, KeyError):
        config = None

    # Try loading as a causal/seq2seq/masked LM first, then fall back to AutoModel
    model = None
    if config is not None:
        for auto_cls_name in (
            "AutoModelForCausalLM",
            "AutoModelForSeq2SeqLM",
            "AutoModelForMaskedLM",
            "AutoModelForImageClassification",
            "AutoModel",
        ):
            try:
                auto_cls = getattr(transformers, auto_cls_name)
                # Each attempt gets its own copy so a failed attempt cannot
                # leave a modified config behind for the next one.
                model = auto_cls.from_pretrained(name, config=copy.deepcopy(config))
                break
            except (ValueError, OSError, KeyError):
                continue

    if model is None:
        # Last resort; any error raised here propagates to the caller.
        model = AutoModel.from_pretrained(name)

    model = model.to(device)

    if tokenizer is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(name)
        except Exception:
            # Deliberately best-effort: a failure here just means the model
            # is used without a tokenizer.
            pass

    if tokenizer is not None and hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if image_processor is None:
        try:
            from transformers import AutoImageProcessor

            image_processor = AutoImageProcessor.from_pretrained(name)
        except (OSError, KeyError, ImportError, ValueError):
            # Best-effort as well; ValueError is included so an unrecognised
            # image-processor config does not abort loading.
            pass

    return model, tokenizer, image_processor
488
+
489
+
490
def _make_dummy_input(
    model: nn.Module,
    *,
    tokenizer: Any | None,
    image_processor: Any | None,
    device: str | torch.device,
) -> dict[str, torch.Tensor] | torch.Tensor | None:
    """Build a tiny placeholder input for forward-pass shape enumeration.

    Candidates are tried in order, and the first that works is returned:

    1. the tokenizer applied to ``"hello"``;
    2. the image processor applied to a 224x224 grey RGB image;
    3. a random ``(1, 8, hidden)`` tensor, with ``hidden`` read from the
       model config's ``hidden_size`` or ``n_embd``.

    Returns None when none of these can be produced.
    """

    def _on_device(batch: Any) -> dict[str, torch.Tensor]:
        return {key: tensor.to(device) for key, tensor in batch.items()}

    if tokenizer is not None:
        try:
            return _on_device(tokenizer("hello", return_tensors="pt"))
        except Exception:
            pass  # best-effort: fall through to the next candidate

    if image_processor is not None:
        try:
            from PIL import Image

            grey = Image.new("RGB", (224, 224), color=(128, 128, 128))
            return _on_device(image_processor(images=grey, return_tensors="pt"))
        except Exception:
            pass  # best-effort: fall through to the tensor fallback

    # Fallback: try a simple tensor
    cfg = getattr(model, "config", None)
    if cfg is None:
        return None

    width = getattr(cfg, "hidden_size", None) or getattr(cfg, "n_embd", None)
    if not width:
        return None
    return torch.randn(1, 8, width, device=device)
523
+
524
+
525
def _hash_input(model_input: dict[str, torch.Tensor] | torch.Tensor) -> int:
    """Compute a hash of a model input for cache key comparison.

    The digest covers each tensor's dtype, shape and raw bytes, so inputs
    that share the same bytes but differ in shape or dtype get different
    hashes.  Dict inputs are hashed in sorted key order, so insertion order
    does not matter.

    Returns the first 8 bytes of the SHA-256 digest as a little-endian int.
    """
    import hashlib

    def _feed(digest: Any, tensor: torch.Tensor) -> None:
        # detach() lets tensors that require grad be converted to NumPy.
        data = tensor.detach().cpu().contiguous()
        digest.update(str(data.dtype).encode())
        digest.update(repr(tuple(data.shape)).encode())
        # Reinterpreting as bytes first also handles dtypes NumPy lacks,
        # such as bfloat16.
        digest.update(data.reshape(-1).view(torch.uint8).numpy().tobytes())

    h = hashlib.sha256()
    if isinstance(model_input, dict):
        for k in sorted(model_input.keys()):
            h.update(k.encode())
            # Separator so a key cannot run into the bytes that follow it.
            h.update(b"\0")
            _feed(h, model_input[k])
    else:
        _feed(h, model_input)
    return int.from_bytes(h.digest()[:8], "little")
542
+
543
+
544
def _is_hooked_transformer(model: nn.Module) -> bool:
    """Report whether *model* looks like a TransformerLens model.

    The check is done without importing TransformerLens: a model matches if
    its class is named like one of the TransformerLens model classes, or if
    it has both a ``hook_dict`` and a ``cfg`` attribute.
    """
    tl_class_names = ("HookedTransformer", "HookedEncoder", "HookedEncoderDecoder")
    return type(model).__name__ in tl_class_names or (
        hasattr(model, "hook_dict") and hasattr(model, "cfg")
    )