univi 0.3.4__py3-none-any.whl
This diff represents the contents of a publicly released package version as it appears in its public registry; it is provided for informational purposes only.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/evaluation.py
ADDED
@@ -0,0 +1,632 @@
# univi/evaluation.py

from __future__ import annotations

from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import scipy.sparse as sp
import torch

from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score


# ----------------------------
# Small helpers
# ----------------------------
def _mean_sem(x: np.ndarray) -> Tuple[float, float]:
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return 0.0, 0.0
    if x.size == 1:
        return float(x.mean()), 0.0
    return float(x.mean()), float(x.std(ddof=1) / np.sqrt(x.size))


def _json_safe(obj: Any) -> Any:
    """Convert numpy scalars/arrays into JSON-safe python types."""
    if isinstance(obj, (np.floating, np.integer)):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj
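

# --- Illustrative self-check for the helpers above (editorial sketch, not
# shipped in the wheel). _mean_sem returns (mean, SEM with ddof=1) and
# _json_safe unwraps numpy scalars/arrays so results survive json.dumps.
def _demo_helpers() -> None:
    m, s = _mean_sem(np.array([1.0, 2.0, 3.0]))
    assert m == 2.0 and abs(s - 1.0 / np.sqrt(3.0)) < 1e-12
    assert _json_safe(np.float32(0.5)) == 0.5
    assert _json_safe(np.arange(3)) == [0, 1, 2]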


# ------------------------------------------------------------------
# 1. FOSCTTM (exact, blockwise) + Recall@k (top-k match rate)
# ------------------------------------------------------------------
def compute_foscttm(
    Z1: np.ndarray,
    Z2: np.ndarray,
    metric: str = "euclidean",
    block_size: int = 512,
    return_sem: bool = False,
    return_per_cell: bool = False,
) -> Union[float, Tuple[float, float], Tuple[float, np.ndarray], Tuple[float, float, np.ndarray]]:
    """
    Compute FOSCTTM assuming 1:1 pairing between rows of Z1 and Z2.

    Definition used:
        For each i:
            frac_i = #{j: d(Z1[i], Z2[j]) < d(Z1[i], Z2[i])} / (N-1)
        FOSCTTM = mean_i frac_i

    Computed exactly with blockwise pairwise distances, avoiding NxN
    kneighbors storage.

    Supports metric in {"euclidean", "cosine"}.
    """
    Z1 = np.asarray(Z1, dtype=np.float32)
    Z2 = np.asarray(Z2, dtype=np.float32)

    if Z1.shape != Z2.shape:
        raise ValueError(f"Z1/Z2 must have same shape for FOSCTTM. Got {Z1.shape} vs {Z2.shape}")

    n = int(Z1.shape[0])
    if n <= 1:
        if return_sem and return_per_cell:
            return 0.0, 0.0, np.zeros(n, dtype=np.float32)
        if return_sem:
            return 0.0, 0.0
        if return_per_cell:
            return 0.0, np.zeros(n, dtype=np.float32)
        return 0.0

    metric = str(metric).lower().strip()
    if metric not in {"euclidean", "cosine"}:
        raise ValueError("compute_foscttm currently supports metric in {'euclidean','cosine'}.")

    fos = np.empty(n, dtype=np.float32)

    if metric == "euclidean":
        # squared Euclidean: ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a·b
        Z2_T = Z2.T
        n2 = np.sum(Z2 * Z2, axis=1)  # (n,)
        for i0 in range(0, n, int(block_size)):
            i1 = min(i0 + int(block_size), n)
            A = Z1[i0:i1]
            n1 = np.sum(A * A, axis=1)[:, None]  # (b,1)
            d2 = n1 + n2[None, :] - 2.0 * (A @ Z2_T)  # (b,n)
            true = d2[np.arange(i1 - i0), np.arange(i0, i1)]
            fos[i0:i1] = (d2 < true[:, None]).sum(axis=1) / (n - 1)

    else:  # cosine distance = 1 - cosine_similarity
        Z2_T = Z2.T
        n2 = np.linalg.norm(Z2, axis=1) + 1e-8  # (n,)
        for i0 in range(0, n, int(block_size)):
            i1 = min(i0 + int(block_size), n)
            A = Z1[i0:i1]
            n1 = np.linalg.norm(A, axis=1) + 1e-8  # (b,)
            sim = (A @ Z2_T) / (n1[:, None] * n2[None, :])  # (b,n)
            d = 1.0 - sim
            true = d[np.arange(i1 - i0), np.arange(i0, i1)]
            fos[i0:i1] = (d < true[:, None]).sum(axis=1) / (n - 1)

    m, s = _mean_sem(fos.astype(float))

    if return_sem and return_per_cell:
        return float(m), float(s), fos
    if return_sem:
        return float(m), float(s)
    if return_per_cell:
        return float(m), fos
    return float(m)
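

# --- Illustrative usage sketch for compute_foscttm (editorial; random data,
# not from the package). Perfectly aligned pairs give FOSCTTM near 0, while a
# random pairing gives ~0.5, so a lightly perturbed copy should score near 0.
def _demo_foscttm() -> None:
    rng = np.random.default_rng(0)
    Z1 = rng.normal(size=(200, 16)).astype(np.float32)
    Z2 = Z1 + 0.05 * rng.normal(size=(200, 16)).astype(np.float32)
    mean, sem, per_cell = compute_foscttm(
        Z1, Z2, metric="euclidean", return_sem=True, return_per_cell=True
    )
    assert per_cell.shape == (200,)
    assert mean < 0.05  # near-perfect pairing on lightly perturbed copies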


def compute_match_recall_at_k(
    Z1: np.ndarray,
    Z2: np.ndarray,
    k: int = 10,
    metric: str = "euclidean",
    block_size: int = 512,
    return_sem: bool = False,
    return_per_cell: bool = False,
) -> Union[float, Tuple[float, float], Tuple[float, np.ndarray], Tuple[float, float, np.ndarray]]:
    """
    Recall@k for paired matching:
        hit_i = 1 if true match (i) is among k nearest neighbors of Z1[i] in Z2
        recall@k = mean_i hit_i

    Computed exactly blockwise for metric in {"euclidean","cosine"}.
    """
    Z1 = np.asarray(Z1, dtype=np.float32)
    Z2 = np.asarray(Z2, dtype=np.float32)

    if Z1.shape != Z2.shape:
        raise ValueError(f"Z1/Z2 must have same shape. Got {Z1.shape} vs {Z2.shape}")

    n = int(Z1.shape[0])
    if n == 0:
        raise ValueError("Empty inputs.")
    if n == 1:
        hits = np.array([1.0], dtype=np.float32)
        if return_sem and return_per_cell:
            return 1.0, 0.0, hits
        if return_sem:
            return 1.0, 0.0
        if return_per_cell:
            return 1.0, hits
        return 1.0

    k = int(max(1, min(int(k), n)))
    metric = str(metric).lower().strip()
    if metric not in {"euclidean", "cosine"}:
        raise ValueError("compute_match_recall_at_k currently supports metric in {'euclidean','cosine'}.")

    hits = np.empty(n, dtype=np.float32)

    if metric == "euclidean":
        Z2_T = Z2.T
        n2 = np.sum(Z2 * Z2, axis=1)  # (n,)
        for i0 in range(0, n, int(block_size)):
            i1 = min(i0 + int(block_size), n)
            A = Z1[i0:i1]
            n1 = np.sum(A * A, axis=1)[:, None]  # (b,1)
            d2 = n1 + n2[None, :] - 2.0 * (A @ Z2_T)  # (b,n)
            # indices of k smallest (unordered), then check membership
            topk = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]
            for r in range(i1 - i0):
                hits[i0 + r] = 1.0 if (i0 + r) in topk[r] else 0.0
    else:
        Z2_T = Z2.T
        n2 = np.linalg.norm(Z2, axis=1) + 1e-8
        for i0 in range(0, n, int(block_size)):
            i1 = min(i0 + int(block_size), n)
            A = Z1[i0:i1]
            n1 = np.linalg.norm(A, axis=1) + 1e-8
            sim = (A @ Z2_T) / (n1[:, None] * n2[None, :])
            d = 1.0 - sim
            topk = np.argpartition(d, kth=k - 1, axis=1)[:, :k]
            for r in range(i1 - i0):
                hits[i0 + r] = 1.0 if (i0 + r) in topk[r] else 0.0

    m, s = _mean_sem(hits.astype(float))
    if return_sem and return_per_cell:
        return float(m), float(s), hits
    if return_sem:
        return float(m), float(s)
    if return_per_cell:
        return float(m), hits
    return float(m)
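

# --- Illustrative usage sketch for compute_match_recall_at_k (editorial;
# same fabricated-data idea as above). Recall@k is monotone in k, and nearly
# identical paired embeddings should match almost perfectly.
def _demo_recall_at_k() -> None:
    rng = np.random.default_rng(1)
    Z1 = rng.normal(size=(200, 16)).astype(np.float32)
    Z2 = Z1 + 0.05 * rng.normal(size=(200, 16)).astype(np.float32)
    r1 = compute_match_recall_at_k(Z1, Z2, k=1)
    r10 = compute_match_recall_at_k(Z1, Z2, k=10)
    assert 0.0 <= r1 <= r10 <= 1.0  # adding neighbors can only help
    assert r10 > 0.9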


# ------------------------------------------------------------------
# 2. Modality mixing
# ------------------------------------------------------------------
def compute_modality_mixing(
    Z: np.ndarray,
    modality_labels: np.ndarray,
    k: int = 20,
    metric: str = "euclidean",
    return_sem: bool = False,
    return_per_cell: bool = False,
) -> Union[float, Tuple[float, float], Tuple[float, np.ndarray], Tuple[float, float, np.ndarray]]:
    """
    Mean fraction of kNN neighbors that are from a different modality.
    """
    Z = np.asarray(Z, dtype=np.float32)
    modality_labels = np.asarray(modality_labels)
    if Z.shape[0] != modality_labels.shape[0]:
        raise ValueError("Z and modality_labels must align on n_cells.")

    n = int(Z.shape[0])
    if n <= 1:
        if return_sem and return_per_cell:
            return 0.0, 0.0, np.zeros(n, dtype=np.float32)
        if return_sem:
            return 0.0, 0.0
        if return_per_cell:
            return 0.0, np.zeros(n, dtype=np.float32)
        return 0.0

    metric = str(metric).lower().strip()
    k_eff = int(min(max(int(k), 1), n - 1))

    nn = NearestNeighbors(n_neighbors=k_eff + 1, metric=metric)
    nn.fit(Z)
    neigh_idx = nn.kneighbors(Z, return_distance=False)[:, 1:]  # drop self

    neigh_mods = modality_labels[neigh_idx]
    frac_other = (neigh_mods != modality_labels[:, None]).mean(axis=1).astype(np.float32)

    m, s = _mean_sem(frac_other.astype(float))
    if return_sem and return_per_cell:
        return float(m), float(s), frac_other
    if return_sem:
        return float(m), float(s)
    if return_per_cell:
        return float(m), frac_other
    return float(m)
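

# --- Illustrative usage sketch for compute_modality_mixing (editorial). When
# both "modalities" are drawn from one cloud, roughly half of each cell's
# neighbors come from the other modality; separated clouds would give ~0.
def _demo_modality_mixing() -> None:
    rng = np.random.default_rng(2)
    Z = rng.normal(size=(400, 8)).astype(np.float32)
    labels = np.array(["mod1"] * 200 + ["mod2"] * 200)
    mixed, sem = compute_modality_mixing(Z, labels, k=20, return_sem=True)
    assert 0.3 < mixed < 0.7  # well-mixed => close to 0.5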


# ------------------------------------------------------------------
# 3. Label transfer (kNN) with extra stats (macro/weighted F1)
# ------------------------------------------------------------------
def label_transfer_knn(
    Z_source: np.ndarray,
    labels_source: np.ndarray,
    Z_target: np.ndarray,
    labels_target: Optional[np.ndarray] = None,
    k: int = 15,
    metric: str = "euclidean",
    return_label_order: bool = False,
    return_f1: bool = False,
):
    """
    Backwards-compatible returns:
      - if labels_target is None: (pred_labels, None, empty_cm)
      - if labels_target provided and both flags False: (pred_labels, acc, cm)
      - if return_label_order True: add label_order
      - if return_f1 True: add f1_dict
      - if both True: add both (label_order, f1_dict) in that order
    """
    Z_source = np.asarray(Z_source, dtype=np.float32)
    Z_target = np.asarray(Z_target, dtype=np.float32)
    labels_source = np.asarray(labels_source)
    if labels_target is not None:
        labels_target = np.asarray(labels_target)

    n_source = int(Z_source.shape[0])
    if n_source == 0:
        raise ValueError("Z_source is empty.")

    k_eff = int(min(max(int(k), 1), n_source))
    nn = NearestNeighbors(n_neighbors=k_eff, metric=metric)
    nn.fit(Z_source)
    neigh_idx = nn.kneighbors(Z_target, return_distance=False)

    uniq_src, src_codes = np.unique(labels_source, return_inverse=True)

    pred_codes = np.empty(Z_target.shape[0], dtype=np.int64)
    for i in range(Z_target.shape[0]):
        votes = src_codes[neigh_idx[i]]
        bc = np.bincount(votes, minlength=len(uniq_src))
        pred_codes[i] = int(bc.argmax())

    pred_labels = uniq_src[pred_codes]

    if labels_target is None:
        return pred_labels, None, np.array([])

    label_order = np.unique(np.concatenate([labels_target, pred_labels]))
    acc = float(accuracy_score(labels_target, pred_labels))
    cm = confusion_matrix(labels_target, pred_labels, labels=label_order)

    extras = []
    if return_label_order:
        extras.append(label_order)
    if return_f1:
        extras.append({
            "macro_f1": float(f1_score(labels_target, pred_labels, average="macro")),
            "weighted_f1": float(f1_score(labels_target, pred_labels, average="weighted")),
        })

    if not extras:
        return pred_labels, acc, cm
    return (pred_labels, acc, cm, *extras)
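

# --- Illustrative usage sketch for label_transfer_knn (editorial; fabricated
# two-cluster data). With well-separated clusters shared by source and target,
# kNN voting recovers the target labels and the confusion matrix is diagonal.
def _demo_label_transfer() -> None:
    rng = np.random.default_rng(3)
    centers = np.array([[0.0, 0.0], [8.0, 8.0]], dtype=np.float32)
    src = np.vstack([c + rng.normal(scale=0.5, size=(50, 2)) for c in centers])
    tgt = np.vstack([c + rng.normal(scale=0.5, size=(50, 2)) for c in centers])
    y = np.array(["a"] * 50 + ["b"] * 50)
    pred, acc, cm, order, f1 = label_transfer_knn(
        src, y, tgt, labels_target=y, k=15, return_label_order=True, return_f1=True
    )
    assert acc > 0.95 and f1["macro_f1"] > 0.95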


# ------------------------------------------------------------------
# 4. Reconstruction metrics (continuous; useful for CITE-seq CLR/gaussian)
# ------------------------------------------------------------------
def mse_per_feature(x_true: np.ndarray, x_pred: np.ndarray) -> np.ndarray:
    x_true = np.asarray(x_true)
    x_pred = np.asarray(x_pred)
    if x_true.shape != x_pred.shape:
        raise ValueError("x_true and x_pred must have same shape.")
    return np.mean((x_true - x_pred) ** 2, axis=0)


def pearson_corr_per_feature(x_true: np.ndarray, x_pred: np.ndarray) -> np.ndarray:
    x_true = np.asarray(x_true, dtype=np.float32)
    x_pred = np.asarray(x_pred, dtype=np.float32)
    if x_true.shape != x_pred.shape:
        raise ValueError("x_true and x_pred must have same shape.")

    x_true_c = x_true - x_true.mean(axis=0, keepdims=True)
    x_pred_c = x_pred - x_pred.mean(axis=0, keepdims=True)

    num = (x_true_c * x_pred_c).sum(axis=0)
    denom = np.sqrt((x_true_c ** 2).sum(axis=0) * (x_pred_c ** 2).sum(axis=0)) + 1e-8
    return num / denom


def reconstruction_metrics(x_true: np.ndarray, x_pred: np.ndarray) -> Dict[str, Any]:
    pf_mse = mse_per_feature(x_true, x_pred)
    pf_r = pearson_corr_per_feature(x_true, x_pred)
    return {
        "mse_mean": float(np.mean(pf_mse)),
        "mse_median": float(np.median(pf_mse)),
        "pearson_mean": float(np.mean(pf_r)),
        "pearson_median": float(np.median(pf_r)),
        "mse_per_feature": pf_mse,
        "pearson_per_feature": pf_r,
    }
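

# --- Illustrative usage sketch for reconstruction_metrics (editorial). A
# prediction equal to the truth plus small noise should give near-zero MSE
# and per-feature Pearson r close to 1.
def _demo_reconstruction_metrics() -> None:
    rng = np.random.default_rng(4)
    x = rng.normal(size=(500, 30)).astype(np.float32)
    x_hat = x + 0.01 * rng.normal(size=x.shape).astype(np.float32)
    stats = reconstruction_metrics(x, x_hat)
    assert stats["mse_mean"] < 1e-3
    assert stats["pearson_mean"] > 0.99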


# ------------------------------------------------------------------
# 5. Encoding + cross-modal prediction
# ------------------------------------------------------------------
def encode_adata(
    model,
    adata,
    modality: str,
    device: str = "cpu",
    layer: Optional[str] = None,
    X_key: str = "X",
    batch_size: int = 1024,
    latent: str = "moe_mean",
    random_state: int = 0,
) -> np.ndarray:
    from .data import _get_matrix

    latent = str(latent).lower().strip()
    valid = {"moe_mean", "moe_sample", "modality_mean", "modality_sample"}
    if latent not in valid:
        raise ValueError("latent must be one of %s; got %r" % (sorted(valid), latent))

    def _sample_gaussian(mu: torch.Tensor, logvar: torch.Tensor, gen: torch.Generator) -> torch.Tensor:
        eps = torch.randn(mu.shape, device=mu.device, generator=gen, dtype=mu.dtype)
        return mu + eps * torch.exp(0.5 * logvar)

    model.eval()
    X = _get_matrix(adata, layer=layer, X_key=X_key)
    if sp.issparse(X):
        X = X.toarray()

    dev = torch.device(device)
    gen = torch.Generator(device=dev)
    gen.manual_seed(int(random_state))

    zs = []
    with torch.no_grad():
        for start in range(0, X.shape[0], int(batch_size)):
            end = min(start + int(batch_size), X.shape[0])
            xb = torch.as_tensor(np.asarray(X[start:end]), dtype=torch.float32, device=dev)

            mu_dict, logvar_dict = model.encode_modalities({modality: xb})

            if "modality" in latent:
                mu = mu_dict[modality]
                lv = logvar_dict[modality]
                z = mu if latent.endswith("_mean") else _sample_gaussian(mu, lv, gen)
            else:
                # robust fallback in case a future refactor renames the MoE fuser
                if hasattr(model, "mixture_of_experts"):
                    mu_z, logvar_z = model.mixture_of_experts(mu_dict, logvar_dict)
                elif hasattr(model, "fuse_posteriors"):
                    mu_z, logvar_z = model.fuse_posteriors(mu_dict, logvar_dict)
                else:
                    # single-modality fallback
                    mu_z, logvar_z = mu_dict[modality], logvar_dict[modality]

                z = mu_z if latent.endswith("_mean") else _sample_gaussian(mu_z, logvar_z, gen)

            zs.append(z.detach().cpu().numpy())

    return np.vstack(zs)
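

# --- Illustrative sketch of the encoder contract encode_adata relies on
# (editorial; _StubModel is hypothetical, not the real UniVI model, and this
# assumes anndata is installed and that _get_matrix(adata, layer=None,
# X_key="X") returns adata.X). The model only needs
# encode_modalities({name: tensor}) -> (mu_dict, logvar_dict); the MoE fusers
# (mixture_of_experts / fuse_posteriors) are optional thanks to the
# single-modality fallback above.
def _demo_encode_adata() -> None:
    import anndata as ad
    import torch.nn as nn

    class _StubModel(nn.Module):
        def __init__(self, d_in: int, d_z: int):
            super().__init__()
            self.mu = nn.Linear(d_in, d_z)
            self.logvar = nn.Linear(d_in, d_z)

        def encode_modalities(self, x_dict):
            return (
                {k: self.mu(v) for k, v in x_dict.items()},
                {k: self.logvar(v) for k, v in x_dict.items()},
            )

    X = np.random.default_rng(5).normal(size=(8, 5)).astype(np.float32)
    Z = encode_adata(_StubModel(5, 3), ad.AnnData(X), modality="rna", latent="moe_mean")
    assert Z.shape == (8, 3)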


def cross_modal_predict(
    model,
    adata_src,
    src_mod: str,
    tgt_mod: str,
    device: str = "cpu",
    layer: Optional[str] = None,
    X_key: str = "X",
    batch_size: int = 512,
    use_moe: bool = True,
) -> np.ndarray:
    """
    Encode src_mod, then decode tgt_mod.

    For paired data with ONLY src_mod observed, MoE fusion == src posterior.
    Still, use_moe=False forces the source-only posterior even if the model's
    fusion behavior changes.
    """
    from .data import _get_matrix

    model.eval()
    X = _get_matrix(adata_src, layer=layer, X_key=X_key)
    if sp.issparse(X):
        X = X.toarray()

    dev = torch.device(device)

    preds = []
    with torch.no_grad():
        for start in range(0, X.shape[0], int(batch_size)):
            end = min(start + int(batch_size), X.shape[0])
            xb = torch.as_tensor(np.asarray(X[start:end]), dtype=torch.float32, device=dev)

            mu_dict, logvar_dict = model.encode_modalities({src_mod: xb})

            if use_moe and hasattr(model, "mixture_of_experts"):
                mu_z, _ = model.mixture_of_experts(mu_dict, logvar_dict)
            else:
                mu_z = mu_dict[src_mod]

            xhat_dict = model.decode_modalities(mu_z)
            if tgt_mod not in xhat_dict:
                raise KeyError(f"Target modality {tgt_mod!r} not found. Available: {list(xhat_dict.keys())}")
            preds.append(xhat_dict[tgt_mod].detach().cpu().numpy())

    return np.vstack(preds) if preds else np.zeros((0, 0), dtype=float)


def denoise_adata(
    model,
    adata,
    modality: str,
    device: str = "cpu",
    layer: Optional[str] = None,
    X_key: str = "X",
    batch_size: int = 512,
    out_layer: Optional[str] = None,
    overwrite_X: bool = False,
    dtype: Optional[np.dtype] = np.float32,
) -> np.ndarray:
    X_hat = cross_modal_predict(
        model,
        adata_src=adata,
        src_mod=modality,
        tgt_mod=modality,
        device=device,
        layer=layer,
        X_key=X_key,
        batch_size=batch_size,
        use_moe=True,
    )
    if dtype is not None:
        X_hat = np.asarray(X_hat, dtype=dtype)

    if overwrite_X:
        adata.X = X_hat
    elif out_layer is not None:
        adata.layers[out_layer] = X_hat

    return X_hat
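

# --- Illustrative sketch of the decoder side of the same contract (editorial;
# hypothetical stub as above): decode_modalities(z) must return a dict mapping
# modality names to reconstructions. With src_mod == tgt_mod, denoise_adata is
# cross_modal_predict round-tripped through the latent space.
def _demo_denoise() -> None:
    import anndata as ad
    import torch.nn as nn

    class _StubModel(nn.Module):
        def __init__(self, d_in: int, d_z: int):
            super().__init__()
            self.enc = nn.Linear(d_in, d_z)
            self.dec = nn.Linear(d_z, d_in)

        def encode_modalities(self, x_dict):
            mu = {k: self.enc(v) for k, v in x_dict.items()}
            return mu, {k: torch.zeros_like(v) for k, v in mu.items()}

        def decode_modalities(self, z):
            return {"rna": self.dec(z)}

    X = np.random.default_rng(6).normal(size=(8, 5)).astype(np.float32)
    adata = ad.AnnData(X)
    X_hat = denoise_adata(_StubModel(5, 3), adata, modality="rna", out_layer="denoised")
    assert X_hat.shape == X.shape and "denoised" in adata.layers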


# ------------------------------------------------------------------
# 6. High-level alignment eval (Figure-ready)
# ------------------------------------------------------------------
def evaluate_alignment(
    Z1: Optional[np.ndarray] = None,
    Z2: Optional[np.ndarray] = None,
    model=None,
    adata1=None,
    adata2=None,
    mod1: Optional[str] = None,
    mod2: Optional[str] = None,
    device: str = "cpu",
    layer1: Optional[str] = None,
    layer2: Optional[str] = None,
    X_key1: str = "X",
    X_key2: str = "X",
    batch_size: int = 1024,
    latent: str = "moe_mean",
    latent1: Optional[str] = None,
    latent2: Optional[str] = None,
    random_state: int = 0,
    metric: str = "euclidean",
    k_mixing: int = 20,
    k_transfer: int = 15,
    modality_labels: Optional[np.ndarray] = None,
    labels_source: Optional[np.ndarray] = None,
    labels_target: Optional[np.ndarray] = None,
    recall_ks: Tuple[int, ...] = (1, 5, 10),
    foscttm_block_size: int = 512,
    json_safe: bool = True,
) -> Dict[str, Any]:
    """
    Returns a dict with:
      - foscttm (mean), foscttm_sem
      - recall@k + sem for each k in recall_ks
      - modality_mixing (mean), modality_mixing_sem
      - label transfer: acc, macro/weighted f1 (optional), confusion matrix + label order
    """
    out: Dict[str, Any] = {}

    lat1 = latent if latent1 is None else latent1
    lat2 = latent if latent2 is None else latent2

    if Z1 is None or Z2 is None:
        if model is None or adata1 is None or adata2 is None or mod1 is None or mod2 is None:
            raise ValueError("Provide either (Z1, Z2) or (model, adata1, adata2, mod1, mod2).")

        Z1 = encode_adata(
            model, adata1, modality=mod1, device=device, layer=layer1, X_key=X_key1,
            batch_size=batch_size, latent=lat1, random_state=random_state
        )
        Z2 = encode_adata(
            model, adata2, modality=mod2, device=device, layer=layer2, X_key=X_key2,
            batch_size=batch_size, latent=lat2, random_state=random_state
        )

    Z1 = np.asarray(Z1)
    Z2 = np.asarray(Z2)

    out["n1"] = int(Z1.shape[0])
    out["n2"] = int(Z2.shape[0])
    out["dim"] = int(Z1.shape[1]) if Z1.ndim == 2 else None
    out["latent1"] = lat1
    out["latent2"] = lat2
    out["metric"] = str(metric)

    # FOSCTTM + SEM
    if Z1.shape == Z2.shape and Z1.shape[0] > 1:
        fos_mean, fos_sem = compute_foscttm(
            Z1, Z2, metric=metric, block_size=foscttm_block_size, return_sem=True, return_per_cell=False
        )
        out["foscttm"] = fos_mean
        out["foscttm_sem"] = fos_sem
    else:
        out["foscttm"] = None
        out["foscttm_sem"] = None

    # Recall@k
    if Z1.shape == Z2.shape and Z1.shape[0] > 1:
        for k in recall_ks:
            r_mean, r_sem = compute_match_recall_at_k(
                Z1, Z2, k=int(k), metric=metric, block_size=foscttm_block_size, return_sem=True, return_per_cell=False
            )
            out[f"recall_at_{int(k)}"] = r_mean
            out[f"recall_at_{int(k)}_sem"] = r_sem
    else:
        for k in recall_ks:
            out[f"recall_at_{int(k)}"] = None
            out[f"recall_at_{int(k)}_sem"] = None

    # Modality mixing computed on concatenated embeddings
    Z_concat = None
    if (Z1.ndim == 2 and Z2.ndim == 2 and Z1.shape[1] == Z2.shape[1]):
        Z_concat = np.vstack([Z1, Z2])

    if Z_concat is not None and Z_concat.shape[0] > 1:
        if modality_labels is None:
            modality_labels = np.concatenate([np.repeat("mod1", Z1.shape[0]), np.repeat("mod2", Z2.shape[0])])
        mix_mean, mix_sem = compute_modality_mixing(
            Z_concat, modality_labels=np.asarray(modality_labels),
            k=k_mixing, metric=metric, return_sem=True, return_per_cell=False
        )
        out["modality_mixing"] = mix_mean
        out["modality_mixing_sem"] = mix_sem
        out["k_mixing"] = int(k_mixing)
    else:
        out["modality_mixing"] = None
        out["modality_mixing_sem"] = None
        out["k_mixing"] = int(k_mixing)

    # Label transfer
    if labels_source is not None:
        if labels_target is not None:
            pred, acc, cm, order, f1d = label_transfer_knn(
                Z_source=Z1,
                labels_source=np.asarray(labels_source),
                Z_target=Z2,
                labels_target=np.asarray(labels_target),
                k=k_transfer,
                metric=metric,
                return_label_order=True,
                return_f1=True,
            )
        else:
            # Without ground-truth target labels, label_transfer_knn returns
            # (pred, None, empty_cm) regardless of the extra-output flags, so
            # only unpack three values here.
            pred, acc, cm = label_transfer_knn(
                Z_source=Z1,
                labels_source=np.asarray(labels_source),
                Z_target=Z2,
                labels_target=None,
                k=k_transfer,
                metric=metric,
            )
            order, f1d = None, None
        out["label_transfer_pred"] = pred
        out["label_transfer_acc"] = acc
        out["label_transfer_cm"] = cm
        out["label_transfer_label_order"] = order
        out["label_transfer_f1"] = f1d
        out["k_transfer"] = int(k_transfer)
    else:
        out["label_transfer_pred"] = None
        out["label_transfer_acc"] = None
        out["label_transfer_cm"] = None
        out["label_transfer_label_order"] = None
        out["label_transfer_f1"] = None
        out["k_transfer"] = int(k_transfer)

    if json_safe:
        out = {k: _json_safe(v) for k, v in out.items()}

    return out
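
A minimal end-to-end sketch of evaluate_alignment on precomputed embeddings (editorial illustration; the arrays and labels below are fabricated, not from the package):

    import json
    import numpy as np
    from univi.evaluation import evaluate_alignment

    rng = np.random.default_rng(7)
    Z1 = rng.normal(size=(300, 16)).astype(np.float32)
    Z2 = Z1 + 0.05 * rng.normal(size=(300, 16)).astype(np.float32)
    labels = rng.choice(["T", "B", "NK"], size=300)

    report = evaluate_alignment(
        Z1=Z1, Z2=Z2,
        labels_source=labels, labels_target=labels,
        recall_ks=(1, 5, 10), k_mixing=20, k_transfer=15,
    )
    # json_safe=True (the default) makes the report directly serializable.
    print(json.dumps({k: v for k, v in report.items() if k != "label_transfer_pred"},
                     indent=2))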

univi/hyperparam_optimization/__init__.py
ADDED
@@ -0,0 +1,17 @@
# univi/hyperparam_optimization/__init__.py

from .run_multiome_hparam_search import run_multiome_hparam_search
from .run_citeseq_hparam_search import run_citeseq_hparam_search
from .run_teaseq_hparam_search import run_teaseq_hparam_search
from .run_rna_hparam_search import run_rna_hparam_search
from .run_atac_hparam_search import run_atac_hparam_search
from .run_adt_hparam_search import run_adt_hparam_search

__all__ = [
    "run_multiome_hparam_search",
    "run_citeseq_hparam_search",
    "run_teaseq_hparam_search",
    "run_rna_hparam_search",
    "run_atac_hparam_search",
    "run_adt_hparam_search",
]
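
Because the runner modules are re-exported here, callers can import the search entry points directly from the subpackage (illustrative; the runners' signatures live in their own modules, which this excerpt does not show):

    from univi.hyperparam_optimization import (
        run_citeseq_hparam_search,
        run_multiome_hparam_search,
    )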