univi 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,399 @@
+ # univi/interpretability.py
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Dict, Any, Optional, Tuple, List, Sequence, Literal, Union
+
+ import numpy as np
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ # -----------------------------
+ # Utilities: dense conversion
+ # -----------------------------
+ def _to_tensor(X, device: str, dtype=torch.float32) -> torch.Tensor:
+     if torch.is_tensor(X):
+         return X.to(device=device, dtype=dtype)
+     X = np.asarray(X)
+     return torch.as_tensor(X, device=device, dtype=dtype)
+
+ @dataclass
+ class TokenMap:
+     """Mapping from fused token positions -> modality + feature identity."""
+     modalities: List[str]
+     # token slices in the concatenated sequence (excluding global CLS if present)
+     slices: Dict[str, Tuple[int, int]]
+     # modality-specific meta returned by tokenizer
+     meta: Dict[str, Any]
+     # whether a global CLS token was prepended
+     has_global_cls: bool
+
+
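+ # Illustrative sketch of what a populated TokenMap can look like. The modality names,
+ # token counts, and meta contents below are made up; real values come from the fused
+ # encoder's tokenizer. _example_token_map is not used anywhere in the library.
+ def _example_token_map() -> TokenMap:
+     return TokenMap(
+         modalities=["rna", "adt"],
+         slices={"rna": (0, 64), "adt": (64, 96)},  # token ranges in the fused sequence (CLS excluded)
+         meta={"rna": {}, "adt": {}},               # tokenizer meta, e.g. per-cell "topk_idx"
+         has_global_cls=True,                       # a global CLS token is prepended before these slices
+     )
+
+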
+ def select_attn_matrix(
+     attn_all: List[torch.Tensor],
+     *,
+     layer: int = -1,
+     reduce_heads: Literal["mean", "none"] = "mean",
+     reduce_batch: Literal["mean", "none"] = "mean",
+ ) -> torch.Tensor:
+     """
+     attn_all[layer] is either (B,T,T) if heads averaged, or (B,H,T,T) if not.
+     Returns:
+         - (T,T) if reduce_batch="mean"
+         - (B,T,T) if reduce_batch="none"
+     """
+     attn = attn_all[layer]
+
+     if attn.dim() == 4:  # (B,H,T,T)
+         if reduce_heads == "mean":
+             attn = attn.mean(dim=1)  # (B,T,T)
+         else:
+             raise ValueError("reduce_heads='none' not supported here (pick a head upstream or add support).")
+
+     if attn.dim() != 3:
+         raise ValueError(f"Expected attn as (B,T,T) after head reduce; got {tuple(attn.shape)}")
+
+     if reduce_batch == "mean":
+         return attn.mean(dim=0)  # (T,T)
+     return attn  # (B,T,T)
+
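+
+ # Minimal usage sketch for select_attn_matrix with a synthetic attention stack; the
+ # shapes are made up, and in practice attn_all comes from the fused encoder.
+ # _example_select_attn_matrix is illustrative only.
+ def _example_select_attn_matrix() -> torch.Tensor:
+     attn_all = [torch.rand(8, 4, 10, 10)]       # one layer of (B,H,T,T) attention weights
+     A = select_attn_matrix(attn_all, layer=-1)  # average heads, then batch -> (T,T)
+     assert A.shape == (10, 10)
+     return A
+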
+ # -----------------------------
+ # Getting tokens + attention from your fused encoder
+ # -----------------------------
+ @torch.no_grad()
+ def fused_encode_with_meta_and_attn(
+     model: nn.Module,
+     x_dict: Dict[str, torch.Tensor],
+     *,
+     return_attn: bool = True,
+     attn_average_heads: bool = True,
+ ) -> Tuple[torch.Tensor, torch.Tensor, TokenMap, Optional[List[torch.Tensor]]]:
+     if not hasattr(model, "fused_encoder") or model.fused_encoder is None:
+         raise ValueError("Model has no fused_encoder. Set fused_encoder_type='multimodal_transformer'.")
+
+     fused = model.fused_encoder
+
+     # The fused encoder returns one of:
+     #   (mu, logvar)
+     #   (mu, logvar, meta)
+     #   (mu, logvar, attn_all)
+     #   (mu, logvar, attn_all, meta)
+     ret = fused(
+         x_dict,
+         return_token_meta=True,
+         return_attn=return_attn,
+         attn_average_heads=attn_average_heads,
+     )
+
+     if return_attn:
+         mu, logvar, attn_all, meta = ret
+     else:
+         mu, logvar, meta = ret
+         attn_all = None
+
+     tokmap = TokenMap(
+         modalities=list(meta["modalities"]),
+         slices=dict(meta["slices"]),  # slices exclude CLS; the CLS shift is handled downstream
+         meta={m: meta.get(m, {}) for m in meta["modalities"]},
+         has_global_cls=bool(getattr(fused, "use_global_cls", False)),
+     )
+     return mu, logvar, tokmap, attn_all
+
+
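+ # Usage sketch: `model` is assumed to be a trained univi model built with
+ # fused_encoder_type='multimodal_transformer', and `x_dict` maps modality name -> (B,F)
+ # tensor, exactly as for training. _example_fused_attention is illustrative only.
+ def _example_fused_attention(model: nn.Module, x_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
+     mu, logvar, tokmap, attn_all = fused_encode_with_meta_and_attn(model, x_dict, return_attn=True)
+     return select_attn_matrix(attn_all, layer=-1)  # last-layer attention over the fused sequence, (T,T)
+
+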
+ # -----------------------------
+ # Map tokens back to feature names
+ # -----------------------------
+ def token_to_feature_names(
+     tokmap: TokenMap,
+     *,
+     var_names_by_mod: Dict[str, Sequence[str]],
+     tokenizer_mode_by_mod: Dict[str, str],
+     patch_size_by_mod: Optional[Dict[str, int]] = None,
+ ) -> Dict[str, List[str]]:
+     """
+     Returns, for each modality, a list of length T_mod giving a human-readable
+     feature name per token position (token order as produced by the tokenizer).
+     """
+     out: Dict[str, List[str]] = {}
+
+     for m in tokmap.modalities:
+         mode = tokenizer_mode_by_mod[m]
+         var_names = list(map(str, var_names_by_mod[m]))
+
+         if mode in ("topk_scalar", "topk_channels"):
+             topk_idx = tokmap.meta[m].get("topk_idx", None)
+             if topk_idx is None:
+                 raise ValueError(f"Missing topk_idx for modality {m}. Ensure return_token_meta=True.")
+             # topk_idx is (B,K). Token names here are positional placeholders because the
+             # feature selected for each token varies per cell; the per-cell mapping to
+             # concrete features happens later.
+             K = int(topk_idx.shape[1])
+             out[m] = [f"{m}:topk_token_{i}" for i in range(K)]
+         elif mode == "patch":
+             if patch_size_by_mod is None or m not in patch_size_by_mod:
+                 raise ValueError(f"patch_size_by_mod is required for patch token naming for modality {m}")
+             P = int(patch_size_by_mod[m])
+             T = (len(var_names) + P - 1) // P
+             names = []
+             for t in range(T):
+                 a = t * P
+                 b = min((t + 1) * P, len(var_names))
+                 names.append(f"{m}:patch[{t}] {var_names[a]}..{var_names[b-1]}")
+             out[m] = names
+         else:
+             raise ValueError(f"Unknown tokenizer mode {mode!r} for modality {m}")
+
+     return out
+
+
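+ # Usage sketch with synthetic inputs: name tokens for a top-k tokenized "rna" modality and
+ # a patch-tokenized "atac" modality. Modality names, variable names, and sizes are made up;
+ # _example_token_names is illustrative only.
+ def _example_token_names() -> Dict[str, List[str]]:
+     tokmap = TokenMap(
+         modalities=["rna", "atac"],
+         slices={"rna": (0, 3), "atac": (3, 5)},
+         meta={"rna": {"topk_idx": torch.tensor([[5, 0, 2]])}, "atac": {}},
+         has_global_cls=False,
+     )
+     return token_to_feature_names(
+         tokmap,
+         var_names_by_mod={"rna": [f"gene_{i}" for i in range(10)], "atac": [f"peak_{i}" for i in range(8)]},
+         tokenizer_mode_by_mod={"rna": "topk_scalar", "atac": "patch"},
+         patch_size_by_mod={"atac": 4},
+     )
+
+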
+ # -----------------------------
+ # Feature-level attribution (Integrated Gradients or grad×input)
+ # -----------------------------
+ def integrated_gradients(
+     f,  # function mapping inputs -> scalar
+     x: torch.Tensor,  # (B,F)
+     baseline: Optional[torch.Tensor] = None,
+     steps: int = 32,
+ ) -> torch.Tensor:
+     """
+     Standard IG: returns an attribution with the same shape as x.
+     """
+     if baseline is None:
+         baseline = torch.zeros_like(x)
+
+     # interpolate between baseline and input
+     alphas = torch.linspace(0.0, 1.0, steps=steps, device=x.device, dtype=x.dtype).view(steps, 1, 1)
+     x0 = baseline.unsqueeze(0)
+     x1 = x.unsqueeze(0)
+     xs = x0 + alphas * (x1 - x0)  # (S,B,F)
+
+     grads = []
+     for s in range(steps):
+         xi = xs[s].detach().clone().requires_grad_(True)
+         y = f(xi)
+         if y.dim() != 0:
+             y = y.sum()
+         g = torch.autograd.grad(y, xi, retain_graph=False, create_graph=False)[0]
+         grads.append(g)
+
+     grads = torch.stack(grads, dim=0)  # (S,B,F)
+     avg_grads = grads.mean(dim=0)  # (B,F)
+     return (x - baseline) * avg_grads  # (B,F)
+
+
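+ # Toy sanity check (illustrative only): for f(x) = sum(x**2) with a zero baseline,
+ # integrated_gradients should recover attributions close to x**2, and the attributions
+ # sum to approximately f(x) - f(baseline) (the IG completeness property).
+ def _example_integrated_gradients() -> torch.Tensor:
+     x = torch.randn(4, 6)
+     attr = integrated_gradients(lambda xi: (xi ** 2).sum(), x, steps=64)
+     return attr  # close to x ** 2 up to the coarseness of the step approximation
+
+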
+ def feature_importance_for_head(
+     model: nn.Module,
+     x_dict: Dict[str, torch.Tensor],
+     *,
+     head_name: str,
+     class_index: int,
+     method: Literal["grad_x_input", "ig"] = "ig",
+     ig_steps: int = 32,
+     per_cell_topk: int = 50,
+     feature_names_by_mod: Optional[Dict[str, Sequence[str]]] = None,
+ ) -> Dict[str, Any]:
+     """
+     Computes feature attributions for predicting `head_name` class `class_index`
+     using the fused latent (best for multimodal interpretability).
+
+     Returns:
+         {modality: {names, scores, indices}} with top features.
+     """
+     device = next(model.parameters()).device
+     model.eval()
+
+     # Clone inputs and ensure grad
+     x_in = {m: x_dict[m].detach().clone().to(device=device) for m in x_dict.keys()}
+     for m in x_in:
+         x_in[m].requires_grad_(True)
+
+     # Define scalar function: logit(class_index) from the fused representation
+     def scalar_from_inputs(x_mod: torch.Tensor, mod: str) -> torch.Tensor:
+         # This wrapper is used for IG per modality, so we rebuild x_dict each time.
+         xd = {k: (x_mod if k == mod else x_in[k]) for k in x_in.keys()}
+
+         # You can either:
+         #   (A) call model.predict_heads on fused z (preferred), or
+         #   (B) call model(...) and read out head logits.
+         #
+         # Easiest robust path: call model(x_dict) and extract head logits.
+         out = model(xd, epoch=0, y=None)
+         if "head_logits" in out and head_name in out["head_logits"]:
+             logits = out["head_logits"][head_name]  # (B,C)
+         else:
+             # fallback: use helper if available
+             probs_or_logits = model.predict_heads(xd, return_probs=False)
+             logits = probs_or_logits[head_name]
+
+         return logits[:, int(class_index)].sum()
+
+     results: Dict[str, Any] = {}
+
+     for mod in x_in.keys():
+         if method == "grad_x_input":
+             out = model(x_in, epoch=0, y=None)
+             if "head_logits" in out and head_name in out["head_logits"]:
+                 logits = out["head_logits"][head_name]
+             else:
+                 logits = model.predict_heads(x_in, return_probs=False)[head_name]
+
+             score = logits[:, int(class_index)].sum()
+             grads = torch.autograd.grad(score, x_in[mod], retain_graph=True)[0]
+             attr = grads * x_in[mod]  # (B,F)
+
+         else:  # IG
+             attr = integrated_gradients(
+                 lambda xm: scalar_from_inputs(xm, mod),
+                 x_in[mod],
+                 baseline=torch.zeros_like(x_in[mod]),
+                 steps=int(ig_steps),
+             )
+
+         # Aggregate across cells (mean abs is usually a good default)
+         attr_agg = attr.detach().abs().mean(dim=0)  # (F,)
+
+         k = min(int(per_cell_topk), int(attr_agg.numel()))
+         vals, idx = torch.topk(attr_agg, k=k, largest=True, sorted=True)
+
+         names = None
+         if feature_names_by_mod is not None and mod in feature_names_by_mod:
+             vn = list(map(str, feature_names_by_mod[mod]))
+             names = [vn[i] for i in idx.detach().cpu().tolist()]
+
+         results[mod] = {
+             "indices": idx.detach().cpu().numpy(),
+             "scores": vals.detach().cpu().numpy(),
+             "names": names,
+         }
+
+     return results
+
+
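+ # Usage sketch: rank features driving one class of a classification head. The head name
+ # "cell_type" and class index 0 are placeholders; `model`, `x_dict`, and the variable-name
+ # mapping are assumed to come from the caller's trained univi setup.
+ def _example_head_importance(
+     model: nn.Module,
+     x_dict: Dict[str, torch.Tensor],
+     var_names_by_mod: Dict[str, Sequence[str]],
+ ) -> Dict[str, Any]:
+     return feature_importance_for_head(
+         model,
+         x_dict,
+         head_name="cell_type",
+         class_index=0,
+         method="ig",
+         ig_steps=16,
+         per_cell_topk=25,
+         feature_names_by_mod=var_names_by_mod,
+     )
+
+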
+ # -----------------------------
+ # Cross-modal token interaction from attention
+ # -----------------------------
+ def top_cross_modal_attention_pairs(
+     attn: torch.Tensor,
+     tokmap: TokenMap,
+     *,
+     mod_a: str,
+     mod_b: str,
+     top_n: int = 50,
+     reduce: Literal["mean", "max"] = "mean",  # NOTE: currently unused
+ ) -> List[Tuple[str, str, float]]:
+     """
+     Takes one attention matrix (T,T) and returns the strongest A->B token pairs.
+     `attn` should already correspond to the fused token sequence *including* the global CLS if present.
+     """
+     if attn.dim() != 2:
+         raise ValueError(f"Expected attn as (T,T), got {tuple(attn.shape)}")
+
+     # Compute slices in the attention matrix.
+     # If a global CLS exists, token indices shift by +1 for everything else.
+     shift = 1 if tokmap.has_global_cls else 0
+     a0, a1 = tokmap.slices[mod_a]
+     b0, b1 = tokmap.slices[mod_b]
+     a0 += shift; a1 += shift
+     b0 += shift; b1 += shift
+
+     sub = attn[a0:a1, b0:b1]  # (Ta,Tb)
+
+     # flatten and take top pairs
+     flat = sub.reshape(-1)
+     k = min(int(top_n), int(flat.numel()))
+     vals, idx = torch.topk(flat, k=k, largest=True, sorted=True)
+
+     Ta = a1 - a0
+     Tb = b1 - b0
+
+     pairs: List[Tuple[str, str, float]] = []
+     for v, ii in zip(vals.detach().cpu().tolist(), idx.detach().cpu().tolist()):
+         ia = ii // Tb
+         ib = ii % Tb
+         pairs.append((f"{mod_a}:token_{ia}", f"{mod_b}:token_{ib}", float(v)))
+
+     return pairs
+
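+
+ # Usage sketch with synthetic inputs: rank rna->adt token pairs from a (T,T) attention map.
+ # Token counts and modality names are made up; real inputs come from the fused encoder.
+ def _example_token_pairs() -> List[Tuple[str, str, float]]:
+     tokmap = TokenMap(
+         modalities=["rna", "adt"],
+         slices={"rna": (0, 6), "adt": (6, 10)},
+         meta={"rna": {}, "adt": {}},
+         has_global_cls=True,
+     )
+     attn = torch.rand(11, 11)  # fused sequence including the global CLS token
+     return top_cross_modal_attention_pairs(attn, tokmap, mod_a="rna", mod_b="adt", top_n=5)
+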
+ @torch.no_grad()
+ def top_cross_modal_feature_pairs_from_attn(
+     attn_all: List[torch.Tensor],
+     tokmap: TokenMap,
+     *,
+     mod_a: str,
+     mod_b: str,
+     var_names_by_mod: Dict[str, Sequence[str]],
+     tokenizer_mode_by_mod: Dict[str, str],
+     layer: int = -1,
+     top_pairs_per_cell: int = 50,
+     top_n: int = 100,
+ ) -> List[Tuple[str, str, float]]:
+     """
+     Returns strongest (feature_a -> feature_b) pairs aggregated across the batch.
+
+     Works best when tokenizer mode for both modalities is topk_* (because tokens map
+     to feature indices via per-cell topk_idx).
+     """
+     mode_a = tokenizer_mode_by_mod[mod_a]
+     mode_b = tokenizer_mode_by_mod[mod_b]
+     if mode_a not in ("topk_scalar", "topk_channels") or mode_b not in ("topk_scalar", "topk_channels"):
+         raise ValueError("This helper currently supports only topk_* tokenizers for both modalities.")
+
+     A = attn_all[layer]
+     if A.dim() == 4:  # (B,H,T,T)
+         A = A.mean(dim=1)  # avg heads -> (B,T,T)
+     if A.dim() != 3:
+         raise ValueError(f"Expected attn as (B,T,T); got {tuple(A.shape)}")
+
+     Bsz, Ttot, _ = A.shape
+
+     shift = 1 if tokmap.has_global_cls else 0
+     a0, a1 = tokmap.slices[mod_a]
+     b0, b1 = tokmap.slices[mod_b]
+     a0 += shift; a1 += shift
+     b0 += shift; b1 += shift
+
+     Ta = a1 - a0
+     Tb = b1 - b0
+
+     sub = A[:, a0:a1, b0:b1]  # (B,Ta,Tb)
+     flat = sub.reshape(Bsz, Ta * Tb)  # (B, Ta*Tb)
+
+     k = min(int(top_pairs_per_cell), int(Ta * Tb))
+     vals, idx = torch.topk(flat, k=k, dim=1, largest=True, sorted=True)  # (B,k)
+
+     ia = idx // Tb  # (B,k) token index within A-slice
+     ib = idx % Tb  # (B,k) token index within B-slice
+
+     topk_a = tokmap.meta[mod_a].get("topk_idx", None)
+     topk_b = tokmap.meta[mod_b].get("topk_idx", None)
+     if topk_a is None or topk_b is None:
+         raise ValueError("Missing topk_idx in tokmap.meta. Ensure return_token_meta=True.")
+
+     # Map token positions -> feature indices (per-cell)
+     fa = torch.gather(topk_a.to(device=ia.device), 1, ia)  # (B,k)
+     fb = torch.gather(topk_b.to(device=ib.device), 1, ib)  # (B,k)
+
+     # Aggregate into a python dict (sparse aggregation)
+     scores: Dict[Tuple[int, int], float] = {}
+     fa_np = fa.detach().cpu().numpy()
+     fb_np = fb.detach().cpu().numpy()
+     v_np = vals.detach().cpu().numpy()
+
+     for b in range(Bsz):
+         for j in range(k):
+             key = (int(fa_np[b, j]), int(fb_np[b, j]))
+             scores[key] = scores.get(key, 0.0) + float(v_np[b, j])
+
+     # Convert to top list with names
+     a_names = list(map(str, var_names_by_mod[mod_a]))
+     b_names = list(map(str, var_names_by_mod[mod_b]))
+
+     items = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[: int(top_n)]
+     out: List[Tuple[str, str, float]] = []
+     for (ia_feat, ib_feat), s in items:
+         out.append((a_names[ia_feat], b_names[ib_feat], float(s)))
+     return out
+
+
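+ # End-to-end usage sketch (illustrative only): assumes a trained univi model whose tokenizers
+ # for both modalities are topk_*; modality names and variable names are supplied by the caller.
+ def _example_cross_modal_feature_pairs(
+     model: nn.Module,
+     x_dict: Dict[str, torch.Tensor],
+     var_names_by_mod: Dict[str, Sequence[str]],
+     tokenizer_mode_by_mod: Dict[str, str],
+     mod_a: str,
+     mod_b: str,
+ ) -> List[Tuple[str, str, float]]:
+     _, _, tokmap, attn_all = fused_encode_with_meta_and_attn(model, x_dict, return_attn=True)
+     return top_cross_modal_feature_pairs_from_attn(
+         attn_all,
+         tokmap,
+         mod_a=mod_a,
+         mod_b=mod_b,
+         var_names_by_mod=var_names_by_mod,
+         tokenizer_mode_by_mod=tokenizer_mode_by_mod,
+         layer=-1,
+         top_pairs_per_cell=25,
+         top_n=50,
+     )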