interpkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- interpkit/__init__.py +15 -0
- interpkit/cli/__init__.py +0 -0
- interpkit/cli/main.py +337 -0
- interpkit/core/__init__.py +0 -0
- interpkit/core/discovery.py +228 -0
- interpkit/core/html.py +375 -0
- interpkit/core/inputs.py +117 -0
- interpkit/core/model.py +551 -0
- interpkit/core/plot.py +352 -0
- interpkit/core/registry.py +82 -0
- interpkit/core/render.py +465 -0
- interpkit/core/tl_compat.py +174 -0
- interpkit/ops/__init__.py +0 -0
- interpkit/ops/ablate.py +90 -0
- interpkit/ops/activations.py +67 -0
- interpkit/ops/attention.py +234 -0
- interpkit/ops/attribute.py +206 -0
- interpkit/ops/diff.py +79 -0
- interpkit/ops/inspect.py +14 -0
- interpkit/ops/lens.py +151 -0
- interpkit/ops/patch.py +112 -0
- interpkit/ops/probe.py +128 -0
- interpkit/ops/sae.py +212 -0
- interpkit/ops/steer.py +118 -0
- interpkit/ops/trace.py +182 -0
- interpkit-0.1.0.dist-info/METADATA +295 -0
- interpkit-0.1.0.dist-info/RECORD +31 -0
- interpkit-0.1.0.dist-info/WHEEL +5 -0
- interpkit-0.1.0.dist-info/entry_points.txt +2 -0
- interpkit-0.1.0.dist-info/licenses/LICENSE +21 -0
- interpkit-0.1.0.dist-info/top_level.txt +1 -0
interpkit/core/html.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
"""Interactive HTML visualization generators — self-contained files with inline JS/CSS."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import html
|
|
6
|
+
import json
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Dark-theme palette interpolated into the generated pages' CSS and inline styles.
_DARK_BG = "#1a1a2e"  # page background
_PANEL_BG = "#16213e"  # background of .panel boxes
_ACCENT = "#0f3460"  # borders, <select> and filter-button backgrounds
_HIGHLIGHT = "#e94560"  # h1 headings, active filter button, bar fills
_TEXT = "#eee"  # primary foreground text
_DIM_TEXT = "#aaa"  # secondary text: labels, subtitles, axis token labels
_GREEN = "#53d769"  # h2 headings and table headers
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _wrap_page(title: str, body: str, *, extra_css: str = "", extra_js: str = "") -> str:
    """Assemble a complete, self-contained dark-themed HTML document.

    Args:
        title: Text for the ``<title>`` element; HTML-escaped here.
        body: Markup placed verbatim inside ``<body>`` (callers escape it).
        extra_css: Extra CSS appended after the shared stylesheet.
        extra_js: Extra script appended after the shared tooltip helpers
            (``showTip(event, text)`` / ``hideTip()``), so it may call them.

    Returns:
        The full HTML page as one string.
    """
    safe_title = html.escape(title)

    # Shared stylesheet; braces are doubled because this is an f-string.
    shared_css = f"""* {{ margin: 0; padding: 0; box-sizing: border-box; }}
body {{
  background: {_DARK_BG};
  color: {_TEXT};
  font-family: 'Segoe UI', system-ui, -apple-system, sans-serif;
  padding: 24px;
  line-height: 1.5;
}}
h1 {{ color: {_HIGHLIGHT}; margin-bottom: 8px; font-size: 1.6em; }}
h2 {{ color: {_GREEN}; margin: 16px 0 8px; font-size: 1.2em; }}
.subtitle {{ color: {_DIM_TEXT}; margin-bottom: 20px; }}
.panel {{
  background: {_PANEL_BG};
  border-radius: 8px;
  padding: 16px;
  margin-bottom: 16px;
  border: 1px solid {_ACCENT};
}}
table {{ border-collapse: collapse; width: 100%; }}
th, td {{ padding: 6px 10px; text-align: left; border-bottom: 1px solid {_ACCENT}; }}
th {{ color: {_GREEN}; font-weight: 600; }}
.heatmap-cell {{
  width: 28px; height: 28px;
  display: inline-block;
  text-align: center;
  font-size: 10px;
  line-height: 28px;
  cursor: pointer;
  border-radius: 2px;
  position: relative;
}}
.tooltip {{
  position: fixed;
  background: #222;
  color: {_TEXT};
  padding: 6px 10px;
  border-radius: 4px;
  font-size: 12px;
  pointer-events: none;
  z-index: 1000;
  display: none;
  box-shadow: 0 2px 8px rgba(0,0,0,.4);
}}
.bar {{
  height: 22px;
  border-radius: 3px;
  display: inline-block;
  vertical-align: middle;
  min-width: 2px;
  transition: width 0.3s;
}}
.controls {{ margin-bottom: 16px; }}
select, input[type=range] {{ margin: 0 8px; }}
select {{
  background: {_ACCENT};
  color: {_TEXT};
  border: none;
  padding: 4px 8px;
  border-radius: 4px;
  cursor: pointer;
}}
label {{ color: {_DIM_TEXT}; font-size: 0.9em; }}
.token {{
  display: inline-block;
  padding: 2px 4px;
  margin: 1px;
  border-radius: 3px;
  cursor: pointer;
  transition: background 0.2s;
}}
.filter-btn {{
  background: {_ACCENT};
  color: {_TEXT};
  border: none;
  padding: 4px 12px;
  border-radius: 4px;
  cursor: pointer;
  margin: 2px;
  font-size: 0.85em;
}}
.filter-btn.active {{ background: {_HIGHLIGHT}; }}"""

    # Tooltip helpers available to every page's extra_js; plain string, so
    # JS braces are written normally.
    tooltip_js = """const tooltip = document.getElementById('tooltip');
function showTip(e, text) {
  tooltip.textContent = text;
  tooltip.style.display = 'block';
  tooltip.style.left = (e.clientX + 12) + 'px';
  tooltip.style.top = (e.clientY - 30) + 'px';
}
function hideTip() { tooltip.style.display = 'none'; }"""

    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>{safe_title}</title>
<style>
{shared_css}
{extra_css}
</style>
</head>
<body>
{body}
<div class="tooltip" id="tooltip"></div>
<script>
{tooltip_js}
{extra_js}
</script>
</body>
</html>"""
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def html_attention(
    attention_data: list[dict[str, Any]],
    tokens: list[str] | None,
) -> str:
    """Generate an interactive attention heatmap HTML page.

    A dropdown selects the layer; each head of that layer is drawn as an
    n×n grid (n = number of rows in its ``weights``). Hovering a cell shows
    ``row_token→col_token: weight``.

    Args:
        attention_data: One dict per head with keys ``"layer"``, ``"head"``,
            ``"weights"`` (a 2-D nested sequence of numbers, rendered as a
            square grid) and optionally ``"entropy"``.
        tokens: Labels for sequence positions. When ``None`` or empty,
            positions are labelled ``"0"``, ``"1"``, ...; positions beyond the
            end of *tokens* fall back to their index.

    Returns:
        A self-contained HTML page.
    """
    if not attention_data:
        return _wrap_page("Attention", "<h1>Attention</h1><p>No attention data.</p>")

    layers: dict[Any, list[dict[str, Any]]] = {}
    for entry in attention_data:
        layers.setdefault(entry["layer"], []).append(entry)

    tok_labels = tokens or [
        str(i) for i in range(max(len(entry.get("weights", [])) for entry in attention_data))
    ]
    # Tokens are shipped raw and escaped in the browser (esc) only where they
    # become markup, so the textContent-based tooltip shows the real token
    # text instead of HTML entities. Escaping "<" keeps any token from
    # closing the surrounding <script> element.
    tok_json = json.dumps([str(t) for t in tok_labels]).replace("<", "\\u003c")

    data_json = json.dumps({
        str(layer): [
            {
                "head": e["head"],
                "weights": [[round(float(w), 4) for w in row] for row in e.get("weights", [])],
                # float() so tensor / numpy scalars serialise like the weights do.
                "entropy": round(float(e.get("entropy", 0.0)), 3),
            }
            for e in heads
        ]
        for layer, heads in sorted(layers.items())
    })
    layer_options = "".join(
        f'<option value="{html.escape(str(layer))}">Layer {html.escape(str(layer))}</option>'
        for layer in sorted(layers)
    )

    body = f"""
<h1>Attention Patterns</h1>
<div class="controls panel">
<label>Layer: <select id="layerSelect" onchange="renderLayer(this.value)">
{layer_options}
</select></label>
</div>
<div id="headsContainer"></div>
"""

    js = f"""
const DATA = {data_json};
const TOKENS = {tok_json};

function esc(s) {{
  return String(s).replace(/[&<>"']/g, c => ({{'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'}})[c]);
}}

function tokLabel(i) {{
  return i < TOKENS.length ? TOKENS[i] : String(i);
}}

function colorForWeight(w) {{
  const r = Math.round(233 * w + 22 * (1-w));
  const g = Math.round(69 * w + 33 * (1-w));
  const b = Math.round(96 * w + 62 * (1-w));
  return `rgb(${{r}},${{g}},${{b}})`;
}}

function renderLayer(layer) {{
  const container = document.getElementById('headsContainer');
  const heads = DATA[layer] || [];
  let out = '';
  for (const h of heads) {{
    const n = h.weights.length;
    let grid = '<div style="display:inline-grid;grid-template-columns:auto ' +
      'repeat(' + n + ', 28px);gap:1px;margin:8px 0">';
    grid += '<div></div>';
    for (let j = 0; j < n; j++) {{
      grid += '<div style="font-size:9px;color:{_DIM_TEXT};text-align:center;overflow:hidden;max-width:28px">' + esc(tokLabel(j)) + '</div>';
    }}
    for (let i = 0; i < n; i++) {{
      grid += '<div style="font-size:9px;color:{_DIM_TEXT};text-align:right;padding-right:4px">' + esc(tokLabel(i)) + '</div>';
      const row = h.weights[i] || [];
      for (let j = 0; j < n; j++) {{
        // Ragged rows render missing cells as 0 instead of throwing.
        const w = typeof row[j] === 'number' ? row[j] : 0;
        const tip = tokLabel(i) + '→' + tokLabel(j) + ': ' + w.toFixed(4);
        // Tooltip text travels through a data-* attribute: esc() protects the
        // attribute, the browser decodes it, and showTip shows the raw text.
        grid += '<div class="heatmap-cell" style="background:' + colorForWeight(w) + '"' +
          ' data-tip="' + esc(tip) + '" onmousemove="showTip(event, this.dataset.tip)"' +
          ' onmouseleave="hideTip()"></div>';
      }}
    }}
    grid += '</div>';
    out += '<div class="panel"><h2>Head ' + esc(h.head) + ' <span style="color:{_DIM_TEXT};font-size:0.8em">(entropy: ' + h.entropy + ')</span></h2>' + grid + '</div>';
  }}
  container.innerHTML = out;
}}

// Start from whatever the dropdown shows so the view and the control agree.
renderLayer(document.getElementById('layerSelect').value);
"""

    return _wrap_page("Attention Patterns", body, extra_js=js)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def html_trace(results: list[dict[str, Any]]) -> str:
    """Generate an interactive causal-trace HTML page.

    One table row per result, in the order given, with a horizontal bar
    scaled to the largest absolute effect among the visible rows. Hovering a
    row shows the exact effect; filter buttons restrict the table to a role.

    Args:
        results: Dicts with ``"module"`` and ``"effect"`` keys and an optional
            ``"role"``.

    Returns:
        A self-contained HTML page.
    """
    if not results:
        return _wrap_page("Causal Trace", "<h1>Causal Trace</h1><p>No results.</p>")

    rows = [
        {
            "module": str(r["module"]),
            # float() so tensor / numpy scalars serialise.
            "effect": round(float(r["effect"]), 4),
            "role": str(r.get("role") or ""),
        }
        for r in results
    ]
    # Escaping "<" keeps a module/role string from closing the <script> element.
    data_json = json.dumps(rows).replace("<", "\\u003c")

    roles = sorted({row["role"] for row in rows if row["role"]})
    # The role reaches JS through a data-* attribute (HTML-escaped here and
    # decoded by the browser) rather than a hand-quoted JS literal, so quotes
    # or markup inside a role name cannot break the handler.
    role_buttons = "".join(
        f'<button class="filter-btn" data-role="{html.escape(role)}" '
        f'onclick="toggleFilter(this.dataset.role)">{html.escape(role)}</button>'
        for role in roles
    )

    body = f"""
<h1>Causal Trace</h1>
<div class="panel controls">
<label>Filter by role:</label>
<button class="filter-btn active" onclick="toggleFilter(null)">All</button>
{role_buttons}
<span style="margin-left:16px;color:{_DIM_TEXT}">Hover rows for details</span>
</div>
<div id="chartContainer" class="panel"></div>
"""

    js = f"""
const TRACE_DATA = {data_json};
let activeFilter = null;  // null means "All"

function esc(s) {{
  return String(s).replace(/[&<>"']/g, c => ({{'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'}})[c]);
}}

function toggleFilter(role) {{
  activeFilter = role;
  document.querySelectorAll('.filter-btn').forEach(b => {{
    // The "All" button is the only one without a data-role attribute.
    const isAll = b.dataset.role === undefined;
    b.classList.toggle('active', role === null ? isAll : b.dataset.role === role);
  }});
  renderChart();
}}

function renderChart() {{
  const filtered = activeFilter === null ? TRACE_DATA : TRACE_DATA.filter(d => d.role === activeFilter);
  // Scale by magnitude so negative effects still get a bar; reduce avoids the
  // argument-count limit of Math.max(...spread) on very long traces.
  const maxEffect = filtered.reduce((m, d) => Math.max(m, Math.abs(d.effect)), 0.001);
  const container = document.getElementById('chartContainer');
  let out = '<table><tr><th>Module</th><th>Role</th><th style="width:50%">Effect</th><th>Value</th></tr>';
  for (const d of filtered) {{
    const pct = (Math.abs(d.effect) / maxEffect * 100).toFixed(1);
    const tip = d.module + ': ' + d.effect.toFixed(4);
    out += '<tr data-tip="' + esc(tip) + '" onmousemove="showTip(event, this.dataset.tip)" onmouseleave="hideTip()">' +
      '<td style="color:{_TEXT};font-family:monospace;font-size:0.85em">' + esc(d.module) + '</td>' +
      '<td style="color:{_DIM_TEXT}">' + esc(d.role || '-') + '</td>' +
      '<td><div class="bar" style="width:' + pct + '%;background:linear-gradient(90deg,{_GREEN},{_HIGHLIGHT})"></div></td>' +
      '<td style="text-align:right;font-weight:600">' + d.effect.toFixed(4) + '</td></tr>';
  }}
  out += '</table>';
  container.innerHTML = out;
}}

renderChart();
"""

    return _wrap_page("Causal Trace", body, extra_js=js)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def html_attribution(
    tokens: list[str],
    scores: list[float],
) -> str:
    """Generate an interactive token attribution HTML page.

    Tokens are coloured by ``|score| / max|score|``. A slider dims tokens
    below the chosen fraction, and a table lists the top 20 tokens by
    absolute score. Hovering a token shows its exact score.

    Args:
        tokens: Token strings; paired with *scores* positionally (extra
            entries on either side are ignored, as with ``zip``).
        scores: One attribution score per token.

    Returns:
        A self-contained HTML page.
    """
    if not tokens or not scores:
        return _wrap_page("Attribution", "<h1>Attribution</h1><p>No data.</p>")

    pairs = [(str(t), float(s)) for t, s in zip(tokens, scores)]
    # Tokens are shipped raw and escaped in the browser (esc) where they become
    # markup; escaping "<" keeps any token from closing the <script> element.
    data_json = json.dumps(
        [{"token": t, "score": round(s, 6)} for t, s in pairs]
    ).replace("<", "\\u003c")

    # Normalise by the scores actually displayed; fall back to 1.0 when they
    # are all zero so the page never divides by zero.
    max_score = max(abs(s) for _, s in pairs) or 1.0
    # json.dumps yields a valid JS literal even for inf/nan (Infinity/NaN).
    max_json = json.dumps(max_score)

    body = f"""
<h1>Attribution (Gradient Saliency)</h1>
<div class="panel controls">
<label>Threshold: <input type="range" id="threshold" min="0" max="100" value="0"
oninput="renderAttribution(this.value / 100)">
<span id="threshVal">0%</span></label>
</div>
<div class="panel">
<h2>Token Coloring</h2>
<div id="tokenContainer" style="margin:12px 0;line-height:2.2"></div>
</div>
<div class="panel">
<h2>Ranked Tokens</h2>
<div id="rankedContainer"></div>
</div>
"""

    js = f"""
const ATTR_DATA = {data_json};
const MAX_SCORE = {max_json};

function esc(s) {{
  return String(s).replace(/[&<>"']/g, c => ({{'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'}})[c]);
}}

function scoreColor(intensity) {{
  if (intensity > 0.7) return 'rgba(233,69,96,0.85)';
  if (intensity > 0.4) return 'rgba(255,193,7,0.65)';
  if (intensity > 0.15) return 'rgba(255,255,255,0.15)';
  return 'transparent';
}}

// The tooltip text goes through a data-* attribute rather than a quoted JS
// literal, so tokens containing apostrophes (e.g. "'s") keep a working tooltip.
function tokenSpan(d, style) {{
  const tip = d.token + ': ' + d.score.toFixed(4);
  return '<span class="token" style="' + style + '" data-tip="' + esc(tip) +
    '" onmousemove="showTip(event, this.dataset.tip)" onmouseleave="hideTip()">' + esc(d.token) + '</span>';
}}

function renderAttribution(threshold) {{
  document.getElementById('threshVal').textContent = (threshold * 100).toFixed(0) + '%';
  const tc = document.getElementById('tokenContainer');
  const rc = document.getElementById('rankedContainer');
  let tokHtml = '';
  for (const d of ATTR_DATA) {{
    const intensity = Math.abs(d.score) / MAX_SCORE;
    tokHtml += intensity < threshold
      ? tokenSpan(d, 'opacity:0.3')
      : tokenSpan(d, 'background:' + scoreColor(intensity));
  }}
  tc.innerHTML = tokHtml;

  const ranked = [...ATTR_DATA].sort((a, b) => Math.abs(b.score) - Math.abs(a.score));
  let rankHtml = '<table><tr><th>Token</th><th style="width:60%">Score</th><th>Value</th></tr>';
  for (const d of ranked.slice(0, 20)) {{
    const intensity = Math.abs(d.score) / MAX_SCORE;
    if (intensity < threshold) continue;
    const pct = (intensity * 100).toFixed(1);
    rankHtml += '<tr><td style="font-weight:600">' + esc(d.token) + '</td>' +
      '<td><div class="bar" style="width:' + pct + '%;background:{_HIGHLIGHT}"></div></td>' +
      '<td style="text-align:right">' + d.score.toFixed(4) + '</td></tr>';
  }}
  rankHtml += '</table>';
  rc.innerHTML = rankHtml;
}}

renderAttribution(0);
"""

    return _wrap_page("Attribution", body, extra_js=js)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def save_html(content: str, path: str) -> None:
    """Write an HTML document to *path* as UTF-8 and print a confirmation.

    Args:
        content: Complete HTML text to store.
        path: Destination file path; an existing file is overwritten.
    """
    from pathlib import Path

    from rich.console import Console

    destination = Path(path)
    destination.write_text(content, encoding="utf-8")

    console = Console()
    console.print(f" Interactive HTML saved to [bold]{path}[/bold]")
|
interpkit/core/inputs.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Universal input loader — text, images, raw tensors."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def prepare_input(
    raw: str | torch.Tensor | Any,
    *,
    tokenizer: Any | None = None,
    image_processor: Any | None = None,
    device: torch.device | str = "cpu",
) -> dict[str, torch.Tensor] | torch.Tensor:
    """Normalise a user-provided input into model-ready tensors.

    Dispatch order:

    1. ``dict`` → a new dict with every tensor value moved to *device*
       (non-tensor values are passed through unchanged).
    2. ``torch.Tensor`` → moved to *device*.
    3. ``str`` that looks like an image path → loaded and preprocessed.
    4. ``str`` ending in ``.pt`` → loaded with ``torch.load`` onto *device*;
       the file may hold a tensor or a dict.
    5. Any other ``str`` → tokenized with *tokenizer*.

    Raises:
        ValueError: A text string is given but *tokenizer* is ``None``.
        TypeError: *raw* is none of the supported types.
    """
    if isinstance(raw, dict):
        return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in raw.items()}

    if isinstance(raw, torch.Tensor):
        return raw.to(device)

    if not isinstance(raw, str):
        raise TypeError(f"Unsupported input type: {type(raw).__name__}")

    if _looks_like_image_path(raw):
        return _load_image(raw, image_processor=image_processor, device=device)

    if raw.endswith(".pt"):
        # map_location already places the result on *device*; weights_only=True
        # restricts unpickling to tensors and plain containers.
        return torch.load(raw, map_location=device, weights_only=True)

    if tokenizer is None:
        raise ValueError(
            "Cannot tokenize string input — no tokenizer available. "
            "Pass a tokenizer when loading the model or provide a torch.Tensor directly."
        )
    encoded = tokenizer(raw, return_tensors="pt")
    return {k: v.to(device) for k, v in encoded.items()}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def prepare_pair(
    raw_a: str | torch.Tensor | Any,
    raw_b: str | torch.Tensor | Any,
    *,
    tokenizer: Any | None = None,
    image_processor: Any | None = None,
    device: torch.device | str = "cpu",
) -> tuple[dict[str, torch.Tensor] | torch.Tensor, dict[str, torch.Tensor] | torch.Tensor]:
    """Prepare two inputs for paired operations (patching, tracing).

    When both inputs are plain text (neither an image path nor a ``.pt``
    path) and a tokenizer is given, they are tokenized in one padded batch,
    so both results share a sequence length, as activation patching requires.
    Each result is sliced back to batch size 1. Otherwise each input goes
    through :func:`prepare_input` independently and no alignment is done.

    If the tokenizer has no pad token but has an EOS token, EOS is used as
    the pad token for this call only; the tokenizer's ``pad_token`` is reset
    to ``None`` afterwards, even if tokenization raises.
    """
    both_text = (
        isinstance(raw_a, str)
        and isinstance(raw_b, str)
        and not _looks_like_image_path(raw_a)
        and not _looks_like_image_path(raw_b)
        and not raw_a.endswith(".pt")
        and not raw_b.endswith(".pt")
    )

    if both_text and tokenizer is not None:
        # padding=True needs a pad token, and Hugging Face tokenizers without
        # one (e.g. GPT-2's) raise ValueError. Borrow EOS temporarily instead
        # of permanently mutating the caller's tokenizer.
        borrow_pad = (
            getattr(tokenizer, "pad_token", None) is None
            and getattr(tokenizer, "eos_token", None) is not None
        )
        if borrow_pad:
            tokenizer.pad_token = tokenizer.eos_token
        try:
            encoded = tokenizer(
                [raw_a, raw_b],
                return_tensors="pt",
                padding=True,
            )
        finally:
            if borrow_pad:
                tokenizer.pad_token = None
        input_a = {k: v[0:1].to(device) for k, v in encoded.items()}
        input_b = {k: v[1:2].to(device) for k, v in encoded.items()}
        return input_a, input_b

    a = prepare_input(raw_a, tokenizer=tokenizer, image_processor=image_processor, device=device)
    b = prepare_input(raw_b, tokenizer=tokenizer, image_processor=image_processor, device=device)
    return a, b
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _looks_like_image_path(s: str) -> bool:
    """Heuristically decide whether the string *s* names an image file.

    Returns True when the extension (case-insensitive) is one of the known
    raster-image types AND either the path exists or *s* does not start with
    ``"/"``. An absolute path that does not exist is therefore treated as
    plain text, while a relative one is treated as an image even if missing.
    NOTE(review): Windows absolute paths (e.g. ``C:\\img.png``) do not start
    with "/", so they take the "relative" branch.
    """
    image_exts = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp"})
    _, suffix = os.path.splitext(s)
    if suffix.lower() not in image_exts:
        return False
    return Path(s).exists() or not s.startswith("/")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _load_image(
    path: str,
    *,
    image_processor: Any | None = None,
    device: torch.device | str = "cpu",
) -> dict[str, torch.Tensor] | torch.Tensor:
    """Load the image at *path* as RGB and turn it into model-ready tensors.

    With *image_processor*, it is called as
    ``image_processor(images=img, return_tensors="pt")`` and every returned
    value is moved to *device*. Without one, a torchvision pipeline runs:
    resize (shorter side to 256), centre-crop 224×224, ``ToTensor`` and
    normalisation with the standard ImageNet mean/std. The result gets a
    leading batch dimension of 1 and is moved to *device*.
    """
    from PIL import Image

    # Image.open is lazy and holds the file handle open; convert() returns an
    # independent, fully loaded copy, so the source can be closed right away.
    with Image.open(path) as src:
        img = src.convert("RGB")

    if image_processor is not None:
        processed = image_processor(images=img, return_tensors="pt")
        return {k: v.to(device) for k, v in processed.items()}

    from torchvision import transforms  # type: ignore[import-untyped]

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(img).unsqueeze(0).to(device)
|