univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
# univi/interpretability.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Dict, Any, Optional, Tuple, List, Sequence, Literal, Union
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from torch import nn
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# -----------------------------
|
|
14
|
+
# Utilities: dense conversion
|
|
15
|
+
# -----------------------------
|
|
16
|
+
def _to_tensor(X, device: str, dtype=torch.float32) -> torch.Tensor:
    """Convert ``X`` to a dense ``torch.Tensor`` on ``device`` with ``dtype``.

    Accepts:
      * torch tensors (sparse layouts are densified, then moved/cast),
      * objects exposing ``.toarray()`` (e.g. SciPy sparse matrices), densified first,
      * anything ``np.asarray`` understands (lists, numpy arrays, ...).
    """
    if torch.is_tensor(X):
        if X.layout != torch.strided:
            # sparse COO/CSR tensor -> dense, as promised by "dense conversion"
            X = X.to_dense()
        return X.to(device=device, dtype=dtype)
    # np.asarray on a SciPy sparse matrix yields a 0-d object array instead of
    # the values, so densify explicitly before converting.
    if hasattr(X, "toarray"):
        X = X.toarray()
    X = np.asarray(X)
    return torch.as_tensor(X, device=device, dtype=dtype)
|
|
21
|
+
|
|
22
|
+
@dataclass()
class TokenMap:
    """Describe how positions in the fused token sequence map back to the inputs.

    Each modality contributes a run of tokens whose location is recorded in
    ``slices``; ``meta`` carries whatever per-modality metadata the tokenizer
    reported.
    """

    # Modality names reported by the fused encoder.
    modalities: List[str]
    # modality -> (start, end) token range in the concatenated sequence.
    # These indices do NOT count a global CLS token; callers shift them by one
    # when has_global_cls is True.
    slices: Dict[str, Tuple[int, int]]
    # modality -> tokenizer-provided metadata (e.g. "topk_idx" for top-k modes).
    meta: Dict[str, Any]
    # True when a global CLS token was prepended to the fused sequence.
    has_global_cls: bool
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def select_attn_matrix(
    attn_all: List[torch.Tensor],
    *,
    layer: int = -1,
    reduce_heads: Literal["mean", "none"] = "mean",
    reduce_batch: Literal["mean", "none"] = "mean",
) -> torch.Tensor:
    """Pick one layer's attention weights and optionally average them.

    Args:
        attn_all: Per-layer attention tensors. ``attn_all[layer]`` must be
            ``(B, T, T)`` (heads already averaged) or ``(B, H, T, T)``.
        layer: Index into ``attn_all`` (negative indices allowed).
        reduce_heads: ``"mean"`` averages a 4-D tensor over heads. ``"none"``
            is only accepted when the tensor is already 3-D.
        reduce_batch: ``"mean"`` averages over the batch; ``"none"`` keeps it.

    Returns:
        ``(T, T)`` if ``reduce_batch="mean"``, else ``(B, T, T)``.

    Raises:
        ValueError: on an unknown ``reduce_heads``/``reduce_batch`` value,
            ``reduce_heads="none"`` with per-head attention, or a tensor that
            is not 3-D after head reduction.
    """
    # Validate up front: previously an unknown reduce_batch silently returned (B,T,T).
    if reduce_heads not in ("mean", "none"):
        raise ValueError(f"reduce_heads must be 'mean' or 'none'; got {reduce_heads!r}")
    if reduce_batch not in ("mean", "none"):
        raise ValueError(f"reduce_batch must be 'mean' or 'none'; got {reduce_batch!r}")

    attn = attn_all[layer]

    if attn.dim() == 4:  # (B,H,T,T)
        if reduce_heads != "mean":
            raise ValueError("reduce_heads='none' not supported here (pick a head upstream or add support).")
        attn = attn.mean(dim=1)  # (B,T,T)

    if attn.dim() != 3:
        raise ValueError(f"Expected attn as (B,T,T) after head reduce; got {tuple(attn.shape)}")

    # reduce_batch is known to be "mean" or "none" at this point.
    return attn.mean(dim=0) if reduce_batch == "mean" else attn
|
|
61
|
+
|
|
62
|
+
# -----------------------------
|
|
63
|
+
# Getting tokens + attention from your fused encoder
|
|
64
|
+
# -----------------------------
|
|
65
|
+
@torch.no_grad()
def fused_encode_with_meta_and_attn(
    model: nn.Module,
    x_dict: Dict[str, torch.Tensor],
    *,
    return_attn: bool = True,
    attn_average_heads: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, "TokenMap", Optional[List[torch.Tensor]]]:
    """Run ``model.fused_encoder`` and package its token metadata into a TokenMap.

    Args:
        model: Module exposing a non-None ``fused_encoder`` attribute.
        x_dict: Per-modality inputs, forwarded unchanged to the fused encoder.
        return_attn: Ask the encoder for per-layer attention weights.
        attn_average_heads: Forwarded to the encoder (presumably averages
            attention over heads — confirm against the encoder).

    Returns:
        ``(mu, logvar, tokmap, attn_all)``. ``attn_all`` is ``None`` when the
        encoder returned no attention.

    Raises:
        ValueError: if the model has no fused encoder, or the encoder's return
            value does not include a token-meta mapping.
    """
    from collections.abc import Mapping

    fused = getattr(model, "fused_encoder", None)
    if fused is None:
        raise ValueError("Model has no fused_encoder. Set fused_encoder_type='multimodal_transformer'.")

    # The fused encoder may return any of:
    #   (mu, logvar)
    #   (mu, logvar, meta)
    #   (mu, logvar, attn_all)
    #   (mu, logvar, attn_all, meta)
    ret = fused(
        x_dict,
        return_token_meta=True,
        return_attn=return_attn,
        attn_average_heads=attn_average_heads,
    )

    mu, logvar, *extras = ret
    attn_all = None
    meta = None
    # Tell the optional trailing items apart by type rather than by position, so
    # every documented form unpacks: meta is a mapping, attention is a per-layer
    # sequence of tensors.
    for item in extras:
        if isinstance(item, Mapping):
            meta = item
        else:
            attn_all = item
    if meta is None:
        raise ValueError(
            "fused_encoder did not return token meta; expected (mu, logvar, [attn_all,] meta)."
        )

    tokmap = TokenMap(
        modalities=list(meta["modalities"]),
        slices=dict(meta["slices"]),  # slices exclude the global CLS; consumers shift by +1
        meta={m: meta.get(m, {}) for m in meta["modalities"]},
        has_global_cls=bool(getattr(fused, "use_global_cls", False)),
    )
    return mu, logvar, tokmap, attn_all
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# -----------------------------
|
|
106
|
+
# Map tokens back to feature names
|
|
107
|
+
# -----------------------------
|
|
108
|
+
def token_to_feature_names(
    tokmap: "TokenMap",
    *,
    var_names_by_mod: Dict[str, Sequence[str]],
    tokenizer_mode_by_mod: Dict[str, str],
    patch_size_by_mod: Optional[Dict[str, int]] = None,
) -> Dict[str, List[str]]:
    """Give every token position of every modality a human-readable label.

    Args:
        tokmap: Token layout. ``tokmap.modalities`` selects the modalities to
            label; ``tokmap.meta[m]["topk_idx"]`` is read for top-k modes.
        var_names_by_mod: Feature names per modality (used for ``"patch"``).
        tokenizer_mode_by_mod: ``"topk_scalar"``, ``"topk_channels"`` or
            ``"patch"`` per modality.
        patch_size_by_mod: Features per patch; required for ``"patch"`` modalities.

    Returns:
        ``{modality: [label per token position]}`` in tokenizer order:

        * top-k modes: ``"<mod>:topk_token_<i>"`` placeholders. The feature held
          by a top-k token differs per cell, so per-cell mapping must use
          ``topk_idx`` directly.
        * ``"patch"``: ``"<mod>:patch[<t>] <first>..<last>"`` over consecutive
          features; the last patch is shorter when the count is not divisible.

    Raises:
        ValueError: for missing ``topk_idx``, a missing or non-positive patch
            size, or an unknown tokenizer mode.
    """
    out: Dict[str, List[str]] = {}

    for m in tokmap.modalities:
        mode = tokenizer_mode_by_mod[m]

        if mode in ("topk_scalar", "topk_channels"):
            topk_idx = tokmap.meta[m].get("topk_idx", None)
            if topk_idx is None:
                raise ValueError(f"Missing topk_idx for modality {m}. Ensure return_token_meta=True.")
            # topk_idx is normally (B, K); the token axis is the last one, which
            # also handles an unbatched (K,) index.
            K = int(topk_idx.shape[-1])
            out[m] = [f"{m}:topk_token_{i}" for i in range(K)]
        elif mode == "patch":
            if patch_size_by_mod is None or m not in patch_size_by_mod:
                raise ValueError(f"patch_size_by_mod is required for patch token naming for modality {m}")
            P = int(patch_size_by_mod[m])
            if P < 1:
                # P == 0 used to raise ZeroDivisionError; P < 0 silently yielded no names.
                raise ValueError(f"patch size must be >= 1 for modality {m}; got {P}")
            var_names = [str(v) for v in var_names_by_mod[m]]
            n_feat = len(var_names)
            out[m] = [
                f"{m}:patch[{t}] {var_names[a]}..{var_names[min(a + P, n_feat) - 1]}"
                for t, a in enumerate(range(0, n_feat, P))
            ]
        else:
            raise ValueError(f"Unknown tokenizer mode {mode!r} for modality {m}")

    return out
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
# -----------------------------
|
|
152
|
+
# Feature-level attribution (Integrated Gradients or grad×input)
|
|
153
|
+
# -----------------------------
|
|
154
|
+
def integrated_gradients(
    f,  # function mapping inputs -> scalar
    x: torch.Tensor,  # (B,F), or any shape f accepts
    baseline: Optional[torch.Tensor] = None,
    steps: int = 32,
) -> torch.Tensor:
    """Integrated Gradients attribution of ``f`` at ``x`` relative to ``baseline``.

    Computes ``(x - baseline) * mean_s grad f(baseline + alpha_s * (x - baseline))``
    with ``alpha_s = linspace(0, 1, steps)`` (both endpoints included).

    Args:
        f: Callable mapping a tensor shaped like ``x`` to a scalar tensor. A
            non-scalar output is summed before differentiation.
        x: Input to explain. Any shape; the interpolation broadcasts over all dims.
        baseline: Reference input with the same shape as ``x``; zeros when ``None``.
        steps: Number of interpolation points (must be >= 1).

    Returns:
        Attribution tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``steps < 1``.
    """
    steps = int(steps)
    if steps < 1:
        raise ValueError(f"steps must be >= 1; got {steps}")
    if baseline is None:
        baseline = torch.zeros_like(x)

    x_det = x.detach()
    base_det = baseline.detach()
    delta = x_det - base_det
    alphas = torch.linspace(0.0, 1.0, steps=steps, device=x.device, dtype=x.dtype)

    # Interpolate one step at a time: works for any input rank and avoids
    # materialising a (steps, *x.shape) buffer.
    grad_sum = torch.zeros_like(x_det)
    # Gradients are needed even when the caller is inside torch.no_grad().
    with torch.enable_grad():
        for alpha in alphas:
            xi = (base_det + alpha * delta).requires_grad_(True)
            y = f(xi)
            if y.dim() != 0:
                y = y.sum()
            (g,) = torch.autograd.grad(y, xi)
            grad_sum += g

    return (x - baseline) * (grad_sum / steps)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def feature_importance_for_head(
    model: nn.Module,
    x_dict: Dict[str, torch.Tensor],
    *,
    head_name: str,
    class_index: int,
    method: Literal["grad_x_input", "ig"] = "ig",
    ig_steps: int = 32,
    per_cell_topk: int = 50,
    feature_names_by_mod: Optional[Dict[str, Sequence[str]]] = None,
) -> Dict[str, Any]:
    """Attribute the ``class_index`` logit of head ``head_name`` to input features.

    The model runs end-to-end on all modalities, so attributions reflect the
    fused (multimodal) representation. For each modality, per-cell attributions
    are reduced to a feature ranking by mean absolute value over cells.

    Args:
        model: Module called as ``model(x_dict, epoch=0, y=None)``.
            ``out["head_logits"][head_name]`` is used when present; otherwise
            ``model.predict_heads(x_dict, return_probs=False)[head_name]``.
            Logits are indexed as ``logits[:, class_index]``.
        x_dict: Per-modality inputs. They are copied and moved to the device of
            the model's first parameter; the caller's tensors are not modified.
        head_name: Prediction head to explain.
        class_index: Logit column to explain.
        method: ``"ig"`` (Integrated Gradients with a zero baseline) or
            ``"grad_x_input"`` (gradient x input).
        ig_steps: Interpolation steps for IG.
        per_cell_topk: Number of top features kept per modality. The ranking is
            over the cell-aggregated scores, despite the name.
        feature_names_by_mod: Optional feature names used to label the indices.

    Returns:
        ``{modality: {"indices": np.ndarray, "scores": np.ndarray, "names": list or None}}``,
        sorted by descending score.

    Raises:
        ValueError: for an unknown ``method``.
    """
    if method not in ("grad_x_input", "ig"):
        raise ValueError(f"method must be 'grad_x_input' or 'ig'; got {method!r}")

    device = next(model.parameters()).device
    cls_idx = int(class_index)

    # Private grad-enabled copies of the inputs. Integer tensors cannot require
    # grad, so promote them to float.
    x_in: Dict[str, torch.Tensor] = {}
    for m, t in x_dict.items():
        t = t.detach().clone().to(device=device)
        if not torch.is_floating_point(t):
            t = t.float()
        x_in[m] = t.requires_grad_(True)

    def class_score(xd: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Scalar objective: target-class logit summed over the batch.
        out = model(xd, epoch=0, y=None)
        if "head_logits" in out and head_name in out["head_logits"]:
            logits = out["head_logits"][head_name]  # (B,C)
        else:
            # fallback: use helper if available
            logits = model.predict_heads(xd, return_probs=False)[head_name]
        return logits[:, cls_idx].sum()

    # Remember every submodule's train/eval flag so the caller's model is left
    # exactly as it was (the original left it permanently in eval mode).
    prior_modes = [(sub, sub.training) for sub in model.modules()]
    model.eval()
    attrs: Dict[str, torch.Tensor] = {}
    try:
        # Attribution needs autograd even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            if method == "grad_x_input":
                # A single forward/backward pass yields gradients for every modality.
                mods = list(x_in)
                score = class_score(x_in)
                grads = torch.autograd.grad(score, [x_in[m] for m in mods], allow_unused=True)
                for m, g in zip(mods, grads):
                    if g is None:
                        # modality does not influence this head -> zero attribution
                        g = torch.zeros_like(x_in[m])
                    attrs[m] = g * x_in[m]  # (B,F)
            else:  # IG
                for mod in x_in:
                    def scalar_for_mod(xm: torch.Tensor, mod: str = mod) -> torch.Tensor:
                        # Swap in the interpolated input for `mod`, keep the others fixed.
                        xd = {k: (xm if k == mod else x_in[k]) for k in x_in}
                        return class_score(xd)

                    attrs[mod] = integrated_gradients(
                        scalar_for_mod,
                        x_in[mod],
                        baseline=torch.zeros_like(x_in[mod]),
                        steps=int(ig_steps),
                    )
    finally:
        for sub, was_training in prior_modes:
            sub.training = was_training

    results: Dict[str, Any] = {}
    for mod, attr in attrs.items():
        # Aggregate across cells (mean abs is usually a good default)
        attr_agg = attr.detach().abs().mean(dim=0)  # (F,)

        k = min(int(per_cell_topk), int(attr_agg.numel()))
        vals, idx = torch.topk(attr_agg, k=k, largest=True, sorted=True)

        names = None
        if feature_names_by_mod is not None and mod in feature_names_by_mod:
            vn = [str(v) for v in feature_names_by_mod[mod]]
            names = [vn[i] for i in idx.cpu().tolist()]

        results[mod] = {
            "indices": idx.cpu().numpy(),
            "scores": vals.cpu().numpy(),
            "names": names,
        }

    return results
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
# -----------------------------
|
|
275
|
+
# Cross-modal token interaction from attention
|
|
276
|
+
# -----------------------------
|
|
277
|
+
def top_cross_modal_attention_pairs(
    attn: torch.Tensor,
    tokmap: "TokenMap",
    *,
    mod_a: str,
    mod_b: str,
    top_n: int = 50,
    reduce: Literal["mean", "max"] = "mean",
) -> List[Tuple[str, str, float]]:
    """Return the strongest A->B token pairs from an attention matrix.

    Args:
        attn: One attention matrix ``(T, T)``, or a batch ``(B, T, T)`` that is
            first reduced over the batch with ``reduce``. Positions must follow
            the fused token sequence *including* the global CLS, if present.
        tokmap: Token layout. Its ``slices`` exclude the CLS, so they are shifted
            by one when ``tokmap.has_global_cls``.
        mod_a: Modality whose tokens index the rows.
        mod_b: Modality whose tokens index the columns.
        top_n: Maximum number of pairs returned.
        reduce: Batch reduction for 3-D ``attn``: ``"mean"`` or ``"max"``.

    Returns:
        ``[(f"{mod_a}:token_{i}", f"{mod_b}:token_{j}", weight), ...]`` sorted by
        descending weight. ``i`` and ``j`` are positions within each modality's slice.

    Raises:
        ValueError: for an unsupported ``reduce`` or an ``attn`` that is not 2-D/3-D.
    """
    if reduce not in ("mean", "max"):
        raise ValueError(f"reduce must be 'mean' or 'max'; got {reduce!r}")
    if attn.dim() == 3:
        # (B,T,T) -> (T,T): this is what `reduce` is for.
        attn = attn.mean(dim=0) if reduce == "mean" else attn.amax(dim=0)
    if attn.dim() != 2:
        raise ValueError(f"Expected attn as (T,T) or (B,T,T), got {tuple(attn.shape)}")

    # If a global CLS exists, every modality token index shifts by +1.
    shift = 1 if tokmap.has_global_cls else 0
    a0, a1 = tokmap.slices[mod_a]
    b0, b1 = tokmap.slices[mod_b]
    sub = attn[a0 + shift : a1 + shift, b0 + shift : b1 + shift]  # (Ta,Tb)
    Tb = sub.shape[1]

    # flatten and take top pairs
    flat = sub.reshape(-1)
    k = min(int(top_n), int(flat.numel()))
    vals, idx = torch.topk(flat, k=k, largest=True, sorted=True)

    pairs: List[Tuple[str, str, float]] = []
    for v, ii in zip(vals.detach().cpu().tolist(), idx.detach().cpu().tolist()):
        ia, ib = divmod(ii, Tb)
        pairs.append((f"{mod_a}:token_{ia}", f"{mod_b}:token_{ib}", float(v)))
    return pairs
|
|
318
|
+
|
|
319
|
+
@torch.no_grad()
def top_cross_modal_feature_pairs_from_attn(
    attn_all: List[torch.Tensor],
    tokmap: "TokenMap",
    *,
    mod_a: str,
    mod_b: str,
    var_names_by_mod: Dict[str, Sequence[str]],
    tokenizer_mode_by_mod: Dict[str, str],
    layer: int = -1,
    top_pairs_per_cell: int = 50,
    top_n: int = 100,
) -> List[Tuple[str, str, float]]:
    """Return the strongest (feature_a -> feature_b) pairs aggregated across the batch.

    For each cell, the ``top_pairs_per_cell`` largest A->B attention weights are
    kept. Token positions are mapped to feature indices through that cell's
    ``topk_idx``, and weights are summed per feature pair over cells.

    Args:
        attn_all: Per-layer attention; ``attn_all[layer]`` is ``(B,T,T)`` or
            ``(B,H,T,T)`` (heads are averaged). Positions include the global
            CLS when ``tokmap.has_global_cls``.
        tokmap: Token layout; ``tokmap.meta[mod]["topk_idx"]`` must be ``(B, K)``
            (tensor or array-like) for both modalities.
        mod_a: Modality whose tokens index the rows.
        mod_b: Modality whose tokens index the columns.
        var_names_by_mod: Feature names used to label the pairs.
        tokenizer_mode_by_mod: Must be ``topk_scalar``/``topk_channels`` for both.
        layer: Which layer of ``attn_all`` to use.
        top_pairs_per_cell: Token pairs kept per cell before aggregation.
        top_n: Number of aggregated feature pairs returned.

    Returns:
        ``[(name_a, name_b, summed_weight), ...]`` sorted by descending weight.

    Raises:
        ValueError: for non-top-k tokenizers, a bad attention shape, missing
            ``topk_idx``, or a ``topk_idx`` whose batch size differs from the
            attention's.
    """
    topk_modes = ("topk_scalar", "topk_channels")
    if tokenizer_mode_by_mod[mod_a] not in topk_modes or tokenizer_mode_by_mod[mod_b] not in topk_modes:
        raise ValueError("This helper currently supports only topk_* tokenizers for both modalities.")

    A = attn_all[layer]
    if A.dim() == 4:  # (B,H,T,T)
        A = A.mean(dim=1)  # avg heads -> (B,T,T)
    if A.dim() != 3:
        raise ValueError(f"Expected attn as (B,T,T); got {tuple(A.shape)}")
    n_cells = A.shape[0]

    topk_a = tokmap.meta[mod_a].get("topk_idx", None)
    topk_b = tokmap.meta[mod_b].get("topk_idx", None)
    if topk_a is None or topk_b is None:
        raise ValueError("Missing topk_idx in tokmap.meta. Ensure return_token_meta=True.")
    # gather needs int64 indices on the attention's device; accept numpy/int32 input too.
    topk_a = torch.as_tensor(topk_a, device=A.device).long()
    topk_b = torch.as_tensor(topk_b, device=A.device).long()
    if topk_a.shape[0] != n_cells or topk_b.shape[0] != n_cells:
        # gather would otherwise silently read the wrong rows
        raise ValueError(
            f"topk_idx batch sizes ({topk_a.shape[0]}, {topk_b.shape[0]}) do not match attention batch {n_cells}"
        )

    # If a global CLS exists, every modality token index shifts by +1.
    shift = 1 if tokmap.has_global_cls else 0
    a0, a1 = tokmap.slices[mod_a]
    b0, b1 = tokmap.slices[mod_b]
    sub = A[:, a0 + shift : a1 + shift, b0 + shift : b1 + shift]  # (B,Ta,Tb)
    Ta, Tb = sub.shape[1], sub.shape[2]
    flat = sub.reshape(n_cells, Ta * Tb)  # (B, Ta*Tb)

    k = min(int(top_pairs_per_cell), Ta * Tb)
    if k <= 0:
        # empty slice (or nothing requested): avoid an integer division by Tb == 0
        return []
    vals, idx = torch.topk(flat, k=k, dim=1, largest=True, sorted=True)  # (B,k)

    ia = idx // Tb  # (B,k) token index within A-slice
    ib = idx % Tb  # (B,k) token index within B-slice

    # Map token positions -> feature indices (per cell)
    fa = torch.gather(topk_a, 1, ia).reshape(-1).tolist()
    fb = torch.gather(topk_b, 1, ib).reshape(-1).tolist()
    vs = vals.reshape(-1).tolist()

    # Sparse aggregation. Insertion order (cell-major, then rank) is what breaks
    # ties in the stable sort below.
    scores: Dict[Tuple[int, int], float] = {}
    for key, v in zip(zip(fa, fb), vs):
        scores[key] = scores.get(key, 0.0) + v

    a_names = [str(v) for v in var_names_by_mod[mod_a]]
    b_names = [str(v) for v in var_names_by_mod[mod_b]]

    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[: int(top_n)]
    return [(a_names[fa_i], b_names[fb_i], float(s)) for (fa_i, fb_i), s in ranked]
|
|
398
|
+
|
|
399
|
+
|