interpkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interpkit/__init__.py +15 -0
- interpkit/cli/__init__.py +0 -0
- interpkit/cli/main.py +337 -0
- interpkit/core/__init__.py +0 -0
- interpkit/core/discovery.py +228 -0
- interpkit/core/html.py +375 -0
- interpkit/core/inputs.py +117 -0
- interpkit/core/model.py +551 -0
- interpkit/core/plot.py +352 -0
- interpkit/core/registry.py +82 -0
- interpkit/core/render.py +465 -0
- interpkit/core/tl_compat.py +174 -0
- interpkit/ops/__init__.py +0 -0
- interpkit/ops/ablate.py +90 -0
- interpkit/ops/activations.py +67 -0
- interpkit/ops/attention.py +234 -0
- interpkit/ops/attribute.py +206 -0
- interpkit/ops/diff.py +79 -0
- interpkit/ops/inspect.py +14 -0
- interpkit/ops/lens.py +151 -0
- interpkit/ops/patch.py +112 -0
- interpkit/ops/probe.py +128 -0
- interpkit/ops/sae.py +212 -0
- interpkit/ops/steer.py +118 -0
- interpkit/ops/trace.py +182 -0
- interpkit-0.1.0.dist-info/METADATA +295 -0
- interpkit-0.1.0.dist-info/RECORD +31 -0
- interpkit-0.1.0.dist-info/WHEEL +5 -0
- interpkit-0.1.0.dist-info/entry_points.txt +2 -0
- interpkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- interpkit-0.1.0.dist-info/top_level.txt +1 -0
interpkit/core/model.py
ADDED
|
@@ -0,0 +1,551 @@
|
|
|
1
|
+
"""Universal model wrapper — load any HF model or nn.Module and run mech interp ops."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
|
|
10
|
+
from interpkit.core.discovery import ModelArchInfo, discover
|
|
11
|
+
from interpkit.core.inputs import prepare_input, prepare_pair
|
|
12
|
+
from interpkit.core.registry import Registration, get_registration
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Model:
    """Wraps a PyTorch model for mechanistic interpretability operations.

    Created via :func:`interpkit.load` — not instantiated directly.

    Holds the underlying module together with its optional tokenizer / image
    processor, the discovered :class:`ModelArchInfo`, an optional manual
    registration, the target device, and a single-input activation cache.
    Public operations are thin wrappers that lazily import and delegate to the
    matching ``interpkit.ops.*`` function.
    """

    def __init__(
        self,
        model: nn.Module,
        *,
        tokenizer: Any | None = None,
        image_processor: Any | None = None,
        arch_info: ModelArchInfo,
        registration: Registration | None = None,
        device: torch.device | str = "cpu",
    ) -> None:
        self._model = model
        self._tokenizer = tokenizer
        self._image_processor = image_processor
        self.arch_info = arch_info
        self._registration = registration
        self._device = torch.device(device)
        # Activation cache: module name -> activation tensor for one input.
        self._cache: dict[str, torch.Tensor] = {}
        # Hash of the prepared input the cache was filled for (None = empty).
        self._cache_input_hash: int | None = None

    # ------------------------------------------------------------------
    # Input preparation
    # ------------------------------------------------------------------

    def _prepare(self, raw: str | torch.Tensor | Any) -> dict[str, torch.Tensor] | torch.Tensor:
        """Convert a raw input into a model-ready input on this model's device.

        Delegates to :func:`interpkit.core.inputs.prepare_input` with this
        model's tokenizer and image processor.
        """
        return prepare_input(
            raw,
            tokenizer=self._tokenizer,
            image_processor=self._image_processor,
            device=self._device,
        )

    def _prepare_pair(
        self, raw_a: str | torch.Tensor | Any, raw_b: str | torch.Tensor | Any,
    ) -> tuple[dict[str, torch.Tensor] | torch.Tensor, dict[str, torch.Tensor] | torch.Tensor]:
        """Prepare two raw inputs together (e.g. clean / corrupted).

        Delegates to :func:`interpkit.core.inputs.prepare_pair`.
        """
        return prepare_pair(
            raw_a, raw_b,
            tokenizer=self._tokenizer,
            image_processor=self._image_processor,
            device=self._device,
        )

    def _forward(self, model_input: dict[str, torch.Tensor] | torch.Tensor) -> torch.Tensor:
        """Run a no-grad forward pass and return the output logits / final tensor.

        Dict inputs are passed as keyword arguments, anything else positionally.
        The returned tensor is, in order of preference: ``out.logits``; ``out``
        itself when it is a tensor; the first element of a non-empty
        tuple/list output.

        Raises
        ------
        TypeError
            If the output matches none of the above (including an empty
            tuple/list).
        """
        with torch.no_grad():
            if isinstance(model_input, dict):
                out = self._model(**model_input)
            else:
                out = self._model(model_input)

        if hasattr(out, "logits"):
            return out.logits
        if isinstance(out, torch.Tensor):
            return out
        # An empty sequence would otherwise surface as an IndexError rather
        # than the documented TypeError.
        if isinstance(out, (tuple, list)) and out:
            return out[0]
        raise TypeError(f"Unexpected model output type: {type(out).__name__}")

    # ------------------------------------------------------------------
    # Activation cache
    # ------------------------------------------------------------------

    @property
    def cached(self) -> bool:
        """True if the activation cache is populated."""
        return len(self._cache) > 0

    def cache(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        at: str | list[str] | None = None,
    ) -> "Model":
        """Run a forward pass and cache activations for reuse by other operations.

        Any previously cached activations are replaced.

        Parameters
        ----------
        input_data:
            The input to cache activations for.
        at:
            Module names to cache. If None, caches all modules with parameters.
            A single module name string is treated as a one-element list.

        Returns ``self`` for chaining.
        """
        from interpkit.ops.activations import run_activations

        model_input = self._prepare(input_data)
        input_hash = _hash_input(model_input)

        if at is None:
            at = [m.name for m in self.arch_info.modules if m.param_count > 0]
        elif isinstance(at, str):
            # A bare string makes run_activations return a single tensor, and the
            # fallback below would then key it by at[0] -- the first *character*.
            at = [at]

        result = run_activations(self, input_data, at=at, print_stats=False)
        self._cache = result if isinstance(result, dict) else {at[0]: result}
        self._cache_input_hash = input_hash
        return self

    def clear_cache(self) -> None:
        """Free cached activation tensors."""
        self._cache.clear()
        self._cache_input_hash = None

    def _get_cached(
        self,
        input_data: str | torch.Tensor | Any,
        module_names: list[str],
        *,
        _prepared_input: dict[str, torch.Tensor] | torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor] | None:
        """Return cached activations if available for this input, else None.

        A hit requires the prepared input to hash identically to the cached
        input AND every name in *module_names* to be present in the cache.

        Pass *_prepared_input* to avoid re-tokenizing when the caller
        already has the prepared input.
        """
        if not self._cache:
            return None

        model_input = _prepared_input if _prepared_input is not None else self._prepare(input_data)
        input_hash = _hash_input(model_input)

        if input_hash != self._cache_input_hash:
            return None

        if all(name in self._cache for name in module_names):
            return {name: self._cache[name] for name in module_names}

        return None

    # ------------------------------------------------------------------
    # Public operations — delegate to ops/
    # ------------------------------------------------------------------

    def inspect(self) -> None:
        """Print the model's module tree with types, param counts, and detected roles."""
        from interpkit.ops.inspect import run_inspect

        run_inspect(self)

    def activations(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        at: str | list[str],
    ) -> dict[str, torch.Tensor] | torch.Tensor:
        """Extract raw activation tensors at one or more named modules.

        Returns a single tensor if *at* is a string, or a dict if *at* is a list.
        """
        from interpkit.ops.activations import run_activations

        return run_activations(self, input_data, at=at)

    def steer_vector(
        self,
        positive: str | torch.Tensor | Any,
        negative: str | torch.Tensor | Any,
        *,
        at: str,
    ) -> torch.Tensor:
        """Extract a steering vector: activation(positive) - activation(negative)."""
        from interpkit.ops.steer import run_steer_vector

        return run_steer_vector(self, positive, negative, at=at)

    def steer(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        vector: torch.Tensor,
        at: str,
        scale: float = 2.0,
        save: str | None = None,
    ) -> dict[str, Any]:
        """Run inference with a steering vector added at module *at*.

        Shows side-by-side comparison of original vs steered top predictions.
        Pass ``save="path.png"`` to export a matplotlib figure.
        """
        from interpkit.ops.steer import run_steer

        return run_steer(self, input_data, vector=vector, at=at, scale=scale, save=save)

    def attention(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        layer: int | None = None,
        head: int | None = None,
        save: str | None = None,
        html: str | None = None,
    ) -> list[dict[str, Any]] | None:
        """Show attention patterns. Returns None for non-transformer models.

        Pass ``save="path.png"`` to export a matplotlib heatmap.
        Pass ``html="path.html"`` to export an interactive HTML page.
        """
        from interpkit.ops.attention import run_attention

        return run_attention(self, input_data, layer=layer, head=head, save=save, html=html)

    def ablate(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        at: str,
        method: str = "zero",
    ) -> dict[str, Any]:
        """Ablate a module (zero or mean) and measure effect on output.

        Returns a dict with ``effect`` (0 = no change, 1 = max change).
        """
        from interpkit.ops.ablate import run_ablate

        return run_ablate(self, input_data, at=at, method=method)

    def patch(
        self,
        clean: str | torch.Tensor | Any,
        corrupted: str | torch.Tensor | Any,
        *,
        at: str,
    ) -> dict[str, Any]:
        """Activation patching: swap a single module's output from clean into corrupted.

        Returns a dict with ``clean_logits``, ``corrupted_logits``, ``patched_logits``,
        and ``effect`` (normalised scalar measuring how much the patch restored clean behaviour).
        """
        from interpkit.ops.patch import run_patch

        return run_patch(self, clean, corrupted, at=at)

    def trace(
        self,
        clean: str | torch.Tensor | Any,
        corrupted: str | torch.Tensor | Any,
        *,
        top_k: int | None = 20,
        save: str | None = None,
        html: str | None = None,
    ) -> list[dict[str, Any]]:
        """Causal tracing: rank modules by how much patching them restores clean output.

        Uses a two-phase approach: fast proxy (activation norm delta) to shortlist,
        then full patch-and-measure on the top-k candidates.
        Pass ``save="path.png"`` to export a matplotlib bar chart.
        Pass ``html="path.html"`` to export an interactive HTML page.
        """
        from interpkit.ops.trace import run_trace

        return run_trace(self, clean, corrupted, top_k=top_k, save=save, html=html)

    def lens(
        self,
        text: str | torch.Tensor | Any,
        *,
        save: str | None = None,
    ) -> list[dict[str, Any]] | None:
        """Logit lens: project each layer's output to vocabulary space.

        Only available for language models with a detectable unembedding matrix.
        Pass ``save="path.png"`` to export a matplotlib heatmap.
        """
        from interpkit.ops.lens import run_lens

        return run_lens(self, text, save=save)

    def probe(
        self,
        texts: list[str],
        labels: list[int],
        *,
        at: str,
    ) -> dict[str, Any]:
        """Train a linear probe on activations at module *at*.

        Returns accuracy, top features by weight magnitude.
        Requires scikit-learn (``pip install interpkit[probe]``), falls back to
        a torch-based probe otherwise.
        """
        from interpkit.ops.probe import run_probe

        return run_probe(self, texts, labels, at=at)

    def features(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        at: str,
        sae: str | Any,
        top_k: int = 20,
    ) -> dict[str, Any]:
        """Decompose activations at *at* through a Sparse Autoencoder.

        Parameters
        ----------
        sae:
            Either a HuggingFace repo ID (``"jbloom/GPT2-Small-SAEs-Reformatted"``)
            or a pre-loaded :class:`interpkit.ops.sae.SAE` object.

        Raises
        ------
        TypeError
            If *sae* is neither a string nor an ``SAE`` instance.
        """
        from interpkit.ops.sae import SAE as SAEClass
        from interpkit.ops.sae import load_sae, run_features

        if isinstance(sae, str):
            sae = load_sae(sae, device=self._device)
        elif not isinstance(sae, SAEClass):
            raise TypeError(f"Expected SAE or HF repo ID string, got {type(sae).__name__}")

        return run_features(self, input_data, at=at, sae=sae, top_k=top_k)

    def attribute(
        self,
        input_data: str | torch.Tensor | Any,
        *,
        target: int | None = None,
        save: str | None = None,
        html: str | None = None,
    ) -> None:
        """Gradient saliency over the input.

        For NLP: prints coloured tokens by importance.
        For vision: saves a heatmap image.
        Pass ``save="path.png"`` to export a matplotlib figure.
        Pass ``html="path.html"`` to export an interactive HTML page.
        """
        from interpkit.ops.attribute import run_attribute

        run_attribute(self, input_data, target=target, save=save, html=html)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
# ======================================================================
|
|
350
|
+
# Top-level loader
|
|
351
|
+
# ======================================================================
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def load(
    model_or_name: str | nn.Module,
    *,
    tokenizer: Any | None = None,
    image_processor: Any | None = None,
    device: str | torch.device | None = None,
) -> Model:
    """Load a model for mechanistic interpretability.

    Parameters
    ----------
    model_or_name:
        A HuggingFace model ID (``"gpt2"``, ``"microsoft/resnet-50"``)
        or an existing ``nn.Module`` instance.
    tokenizer:
        An explicit tokenizer. Auto-loaded for HF models if not provided.
    image_processor:
        An explicit image processor. Auto-loaded for HF vision models if not provided.
    device:
        Device to run on. Defaults to CUDA if available, else CPU.

    Returns
    -------
    Model
        The wrapped model in eval mode. Its ``arch_info`` comes from
        ``discover`` (given a dummy input for shape enumeration) and is
        then overridden by any manual registration found for the model.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # True when the module is a TransformerLens hooked model (affects dummy input).
    is_tl = False

    if isinstance(model_or_name, str):
        # Hub ID / local path: _load_from_hf also resolves the tokenizer and
        # image processor, and moves the model to *device*.
        model, tokenizer, image_processor = _load_from_hf(
            model_or_name, tokenizer=tokenizer, image_processor=image_processor, device=device
        )
    else:
        model = model_or_name

        # Detect TransformerLens HookedTransformer
        if _is_hooked_transformer(model):
            is_tl = True
            # Adopt the TL model's own ``tokenizer`` attribute unless one was supplied.
            if tokenizer is None:
                tl_tok = getattr(model, "tokenizer", None)
                if tl_tok is not None:
                    tokenizer = tl_tok
            # hasattr() also covers tokenizer still being None here.
            if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

        # User-supplied module: move it to the target device.
        model.to(device)

    model.eval()
    registration = get_registration(model)

    # Build a dummy input for shape enumeration
    # TL models accept a raw token tensor, not tokenizer dict kwargs
    if is_tl:
        dummy = torch.tensor([[0]], device=device)
    else:
        dummy = _make_dummy_input(model, tokenizer=tokenizer, image_processor=image_processor, device=device)
    arch_info = discover(model, dummy_input=dummy)
    arch_info.is_tl_model = is_tl

    # Merge manual registration into arch_info
    if registration is not None:
        if registration.layers:
            arch_info.layer_names = registration.layers
        if registration.output_head:
            # The registered output head is also used as the unembedding.
            arch_info.output_head_name = registration.output_head
            arch_info.unembedding_name = registration.output_head
            arch_info.has_lm_head = True
        # Registered attention / MLP module names override discovered roles.
        for mod_info in arch_info.modules:
            if mod_info.name in registration.attention_modules:
                mod_info.role = "attention"
            elif mod_info.name in registration.mlp_modules:
                mod_info.role = "mlp"

    return Model(
        model,
        tokenizer=tokenizer,
        image_processor=image_processor,
        arch_info=arch_info,
        registration=registration,
        device=device,
    )
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _load_from_hf(
    name: str,
    *,
    tokenizer: Any | None,
    image_processor: Any | None,
    device: str | torch.device,
) -> tuple[nn.Module, Any | None, Any | None]:
    """Load a HuggingFace checkpoint plus (best-effort) tokenizer and image processor.

    Task-specific auto classes are tried in order (causal LM, seq2seq LM,
    masked LM, image classification) before the bare ``AutoModel``, so an
    output head is kept whenever the checkpoint supports one.

    Parameters
    ----------
    name:
        HuggingFace repo ID or local path.
    tokenizer, image_processor:
        Used as given when not None; otherwise loaded from *name* if possible.
    device:
        Device the model is moved to.

    Returns
    -------
    tuple
        ``(model, tokenizer, image_processor)``; the last two may be None.

    Raises
    ------
    Exception
        Whatever ``AutoModel.from_pretrained(name)`` raises when no candidate
        class (or the config itself) could be loaded.
    """
    import transformers
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    # The config is the same for every candidate class, so fetch it once. If it
    # cannot be loaded, skip the candidates: the final AutoModel call below
    # then raises the underlying error for the caller.
    try:
        config = AutoConfig.from_pretrained(name)
    except (ValueError, OSError, KeyError):
        config = None

    # Try loading as a causal/seq2seq/masked LM first, then fall back to AutoModel
    model = None
    if config is not None:
        for auto_cls_name in (
            "AutoModelForCausalLM",
            "AutoModelForSeq2SeqLM",
            "AutoModelForMaskedLM",
            "AutoModelForImageClassification",
            "AutoModel",
        ):
            auto_cls = getattr(transformers, auto_cls_name)
            try:
                model = auto_cls.from_pretrained(name, config=config)
            except (ValueError, OSError, KeyError):
                continue
            break

    if model is None:
        model = AutoModel.from_pretrained(name)

    model = model.to(device)

    if tokenizer is None:
        try:
            tokenizer = AutoTokenizer.from_pretrained(name)
        except Exception:
            # Best effort: some checkpoints ship no tokenizer files.
            pass

    # Reuse EOS as the pad token when none is defined. The hasattr guard (as in
    # load()'s TransformerLens branch) tolerates None and user-supplied
    # tokenizers that have no pad_token attribute.
    if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if image_processor is None:
        try:
            from transformers import AutoImageProcessor

            image_processor = AutoImageProcessor.from_pretrained(name)
        except (OSError, KeyError, ValueError, ImportError):
            # Best effort: checkpoints without an image processor raise OSError,
            # and some non-vision configs raise ValueError
            # ("Unrecognized image processor"); neither should abort loading.
            pass

    return model, tokenizer, image_processor
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def _make_dummy_input(
|
|
491
|
+
model: nn.Module,
|
|
492
|
+
*,
|
|
493
|
+
tokenizer: Any | None,
|
|
494
|
+
image_processor: Any | None,
|
|
495
|
+
device: str | torch.device,
|
|
496
|
+
) -> dict[str, torch.Tensor] | torch.Tensor | None:
|
|
497
|
+
"""Create a small dummy input for forward-pass shape enumeration."""
|
|
498
|
+
if tokenizer is not None:
|
|
499
|
+
try:
|
|
500
|
+
encoded = tokenizer("hello", return_tensors="pt")
|
|
501
|
+
return {k: v.to(device) for k, v in encoded.items()}
|
|
502
|
+
except Exception:
|
|
503
|
+
pass
|
|
504
|
+
|
|
505
|
+
if image_processor is not None:
|
|
506
|
+
try:
|
|
507
|
+
from PIL import Image
|
|
508
|
+
|
|
509
|
+
dummy_img = Image.new("RGB", (224, 224), color=(128, 128, 128))
|
|
510
|
+
processed = image_processor(images=dummy_img, return_tensors="pt")
|
|
511
|
+
return {k: v.to(device) for k, v in processed.items()}
|
|
512
|
+
except Exception:
|
|
513
|
+
pass
|
|
514
|
+
|
|
515
|
+
# Fallback: try a simple tensor
|
|
516
|
+
config = getattr(model, "config", None)
|
|
517
|
+
if config is not None:
|
|
518
|
+
hidden = getattr(config, "hidden_size", None) or getattr(config, "n_embd", None)
|
|
519
|
+
if hidden:
|
|
520
|
+
return torch.randn(1, 8, hidden, device=device)
|
|
521
|
+
|
|
522
|
+
return None
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def _hash_input(model_input: dict[str, torch.Tensor] | torch.Tensor) -> int:
|
|
526
|
+
"""Compute a hash of a model input for cache key comparison.
|
|
527
|
+
|
|
528
|
+
Uses the raw byte content of tensors to avoid collisions from inputs
|
|
529
|
+
that happen to share the same sum/shape.
|
|
530
|
+
"""
|
|
531
|
+
import hashlib
|
|
532
|
+
|
|
533
|
+
h = hashlib.sha256()
|
|
534
|
+
if isinstance(model_input, dict):
|
|
535
|
+
for k in sorted(model_input.keys()):
|
|
536
|
+
v = model_input[k]
|
|
537
|
+
h.update(k.encode())
|
|
538
|
+
h.update(v.cpu().contiguous().numpy().tobytes())
|
|
539
|
+
else:
|
|
540
|
+
h.update(model_input.cpu().contiguous().numpy().tobytes())
|
|
541
|
+
return int.from_bytes(h.digest()[:8], "little")
|
|
542
|
+
|
|
543
|
+
|
|
544
|
+
def _is_hooked_transformer(model: nn.Module) -> bool:
|
|
545
|
+
"""Detect a TransformerLens HookedTransformer without importing the library."""
|
|
546
|
+
cls_name = type(model).__name__
|
|
547
|
+
if cls_name in ("HookedTransformer", "HookedEncoder", "HookedEncoderDecoder"):
|
|
548
|
+
return True
|
|
549
|
+
if hasattr(model, "hook_dict") and hasattr(model, "cfg"):
|
|
550
|
+
return True
|
|
551
|
+
return False
|