univi-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
univi/evaluation.py ADDED
@@ -0,0 +1,632 @@
1
+ # univi/evaluation.py
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import scipy.sparse as sp
9
+ import torch
10
+
11
+ from sklearn.neighbors import NearestNeighbors
12
+ from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
13
+
14
+
15
+ # ----------------------------
16
+ # Small helpers
17
+ # ----------------------------
18
+ def _mean_sem(x: np.ndarray) -> Tuple[float, float]:
19
+ x = np.asarray(x, dtype=float)
20
+ if x.size == 0:
21
+ return 0.0, 0.0
22
+ if x.size == 1:
23
+ return float(x.mean()), 0.0
24
+ return float(x.mean()), float(x.std(ddof=1) / np.sqrt(x.size))
25
+
26
+
27
+ def _json_safe(obj: Any) -> Any:
28
+ """Convert numpy scalars/arrays into JSON-safe python types."""
29
+ if isinstance(obj, (np.floating, np.integer)):
30
+ return obj.item()
31
+ if isinstance(obj, np.ndarray):
32
+ return obj.tolist()
33
+ return obj
34
+
35
+
36
+ # ------------------------------------------------------------------
37
+ # 1. FOSCTTM (exact, blockwise) + Recall@k (top-k match rate)
38
+ # ------------------------------------------------------------------
39
+ def compute_foscttm(
40
+ Z1: np.ndarray,
41
+ Z2: np.ndarray,
42
+ metric: str = "euclidean",
43
+ block_size: int = 512,
44
+ return_sem: bool = False,
45
+ return_per_cell: bool = False,
46
+ ) -> Union[float, Tuple[float, float], Tuple[float, np.ndarray], Tuple[float, float, np.ndarray]]:
47
+ """
48
+ Compute FOSCTTM assuming 1:1 pairing between rows of Z1 and Z2.
49
+
50
+ Definition used:
51
+ For each i:
52
+ frac_i = #{j: d(Z1[i], Z2[j]) < d(Z1[i], Z2[i])} / (N-1)
53
+ FOSCTTM = mean_i frac_i
54
+
55
+ This is computed EXACTLY using blockwise pairwise distance computation to avoid NxN kneighbors storage.
56
+
57
+ Supports metric in {"euclidean", "cosine"}.
58
+ """
59
+ Z1 = np.asarray(Z1, dtype=np.float32)
60
+ Z2 = np.asarray(Z2, dtype=np.float32)
61
+
62
+ if Z1.shape != Z2.shape:
63
+ raise ValueError(f"Z1/Z2 must have same shape for FOSCTTM. Got {Z1.shape} vs {Z2.shape}")
64
+
65
+ n = int(Z1.shape[0])
66
+ if n <= 1:
67
+ out0: Any = 0.0
68
+ if return_sem and return_per_cell:
69
+ return 0.0, 0.0, np.zeros(n, dtype=np.float32)
70
+ if return_sem:
71
+ return 0.0, 0.0
72
+ if return_per_cell:
73
+ return 0.0, np.zeros(n, dtype=np.float32)
74
+ return 0.0
75
+
76
+ metric = str(metric).lower().strip()
77
+ if metric not in {"euclidean", "cosine"}:
78
+ raise ValueError("compute_foscttm currently supports metric in {'euclidean','cosine'}.")
79
+
80
+ fos = np.empty(n, dtype=np.float32)
81
+
82
+ if metric == "euclidean":
83
+ # squared Euclidean: ||a-b||^2 = ||a||^2 + ||b||^2 - 2 a·b
84
+ Z2_T = Z2.T
85
+ n2 = np.sum(Z2 * Z2, axis=1) # (n,)
86
+ for i0 in range(0, n, int(block_size)):
87
+ i1 = min(i0 + int(block_size), n)
88
+ A = Z1[i0:i1]
89
+ n1 = np.sum(A * A, axis=1)[:, None] # (b,1)
90
+ d2 = n1 + n2[None, :] - 2.0 * (A @ Z2_T) # (b,n)
91
+ true = d2[np.arange(i1 - i0), np.arange(i0, i1)]
92
+ fos[i0:i1] = (d2 < true[:, None]).sum(axis=1) / (n - 1)
93
+
94
+ else: # cosine distance = 1 - cosine_similarity
95
+ Z2_T = Z2.T
96
+ n2 = np.linalg.norm(Z2, axis=1) + 1e-8 # (n,)
97
+ for i0 in range(0, n, int(block_size)):
98
+ i1 = min(i0 + int(block_size), n)
99
+ A = Z1[i0:i1]
100
+ n1 = np.linalg.norm(A, axis=1) + 1e-8 # (b,)
101
+ sim = (A @ Z2_T) / (n1[:, None] * n2[None, :]) # (b,n)
102
+ d = 1.0 - sim
103
+ true = d[np.arange(i1 - i0), np.arange(i0, i1)]
104
+ fos[i0:i1] = (d < true[:, None]).sum(axis=1) / (n - 1)
105
+
106
+ m, s = _mean_sem(fos.astype(float))
107
+
108
+ if return_sem and return_per_cell:
109
+ return float(m), float(s), fos
110
+ if return_sem:
111
+ return float(m), float(s)
112
+ if return_per_cell:
113
+ return float(m), fos
114
+ return float(m)
115
+
116
+
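A minimal usage sketch for compute_foscttm on synthetic paired embeddings (the arrays below are stand-ins, not data from the package); the trailing lines simply restate the per-cell definition from the docstring as a brute-force check:

import numpy as np
from univi.evaluation import compute_foscttm

rng = np.random.default_rng(0)
Z1 = rng.normal(size=(200, 16)).astype(np.float32)
Z2 = Z1 + 0.1 * rng.normal(size=Z1.shape).astype(np.float32)  # noisy paired copies

mean_fos, sem_fos = compute_foscttm(Z1, Z2, metric="euclidean", return_sem=True)

# Brute-force check of frac_i = #{j: d(Z1[i], Z2[j]) < d(Z1[i], Z2[i])} / (N - 1) for one cell.
i = 0
d = np.linalg.norm(Z2 - Z1[i], axis=1)
frac_i = float((d < d[i]).sum()) / (len(d) - 1)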
117
+ def compute_match_recall_at_k(
118
+ Z1: np.ndarray,
119
+ Z2: np.ndarray,
120
+ k: int = 10,
121
+ metric: str = "euclidean",
122
+ block_size: int = 512,
123
+ return_sem: bool = False,
124
+ return_per_cell: bool = False,
125
+ ) -> Union[float, Tuple[float, float], Tuple[float, np.ndarray], Tuple[float, float, np.ndarray]]:
126
+ """
127
+ Recall@k for paired matching:
128
+ hit_i = 1 if true match (i) is among k nearest neighbors of Z1[i] in Z2
129
+ recall@k = mean_i hit_i
130
+
131
+ Computed exactly blockwise for metric in {"euclidean","cosine"}.
132
+ """
133
+ Z1 = np.asarray(Z1, dtype=np.float32)
134
+ Z2 = np.asarray(Z2, dtype=np.float32)
135
+
136
+ if Z1.shape != Z2.shape:
137
+ raise ValueError(f"Z1/Z2 must have same shape. Got {Z1.shape} vs {Z2.shape}")
138
+
139
+ n = int(Z1.shape[0])
140
+ if n == 0:
141
+ raise ValueError("Empty inputs.")
142
+ if n == 1:
143
+ hits = np.array([1.0], dtype=np.float32)
144
+ if return_sem and return_per_cell:
145
+ return 1.0, 0.0, hits
146
+ if return_sem:
147
+ return 1.0, 0.0
148
+ if return_per_cell:
149
+ return 1.0, hits
150
+ return 1.0
151
+
152
+ k = int(max(1, min(int(k), n)))
153
+ metric = str(metric).lower().strip()
154
+ if metric not in {"euclidean", "cosine"}:
155
+ raise ValueError("compute_match_recall_at_k currently supports metric in {'euclidean','cosine'}.")
156
+
157
+ hits = np.empty(n, dtype=np.float32)
158
+
159
+ if metric == "euclidean":
160
+ Z2_T = Z2.T
161
+ n2 = np.sum(Z2 * Z2, axis=1) # (n,)
162
+ for i0 in range(0, n, int(block_size)):
163
+ i1 = min(i0 + int(block_size), n)
164
+ A = Z1[i0:i1]
165
+ n1 = np.sum(A * A, axis=1)[:, None] # (b,1)
166
+ d2 = n1 + n2[None, :] - 2.0 * (A @ Z2_T) # (b,n)
167
+ # indices of k smallest (unordered), then check membership
168
+ topk = np.argpartition(d2, kth=k - 1, axis=1)[:, :k]
169
+ for r in range(i1 - i0):
170
+ hits[i0 + r] = 1.0 if (i0 + r) in topk[r] else 0.0
171
+ else:
172
+ Z2_T = Z2.T
173
+ n2 = np.linalg.norm(Z2, axis=1) + 1e-8
174
+ for i0 in range(0, n, int(block_size)):
175
+ i1 = min(i0 + int(block_size), n)
176
+ A = Z1[i0:i1]
177
+ n1 = np.linalg.norm(A, axis=1) + 1e-8
178
+ sim = (A @ Z2_T) / (n1[:, None] * n2[None, :])
179
+ d = 1.0 - sim
180
+ topk = np.argpartition(d, kth=k - 1, axis=1)[:, :k]
181
+ for r in range(i1 - i0):
182
+ hits[i0 + r] = 1.0 if (i0 + r) in topk[r] else 0.0
183
+
184
+ m, s = _mean_sem(hits.astype(float))
185
+ if return_sem and return_per_cell:
186
+ return float(m), float(s), hits
187
+ if return_sem:
188
+ return float(m), float(s)
189
+ if return_per_cell:
190
+ return float(m), hits
191
+ return float(m)
192
+
193
+
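A short sketch of the paired-matching recall on the same kind of synthetic input (illustrative arrays, not package data):

import numpy as np
from univi.evaluation import compute_match_recall_at_k

rng = np.random.default_rng(1)
Z1 = rng.normal(size=(300, 8)).astype(np.float32)
Z2 = Z1 + 0.05 * rng.normal(size=Z1.shape).astype(np.float32)

# Fraction of cells whose true partner lies among the k nearest cross-modal neighbors.
for k in (1, 5, 10):
    print(k, compute_match_recall_at_k(Z1, Z2, k=k, metric="cosine"))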
194
+ # ------------------------------------------------------------------
195
+ # 2. Modality mixing
196
+ # ------------------------------------------------------------------
197
+ def compute_modality_mixing(
198
+ Z: np.ndarray,
199
+ modality_labels: np.ndarray,
200
+ k: int = 20,
201
+ metric: str = "euclidean",
202
+ return_sem: bool = False,
203
+ return_per_cell: bool = False,
204
+ ) -> Union[float, Tuple[float, float], Tuple[float, np.ndarray], Tuple[float, float, np.ndarray]]:
205
+ """
206
+ Mean fraction of kNN neighbors that are from a different modality.
207
+ """
208
+ Z = np.asarray(Z, dtype=np.float32)
209
+ modality_labels = np.asarray(modality_labels)
210
+ if Z.shape[0] != modality_labels.shape[0]:
211
+ raise ValueError("Z and modality_labels must align on n_cells.")
212
+
213
+ n = int(Z.shape[0])
214
+ if n <= 1:
215
+ if return_sem and return_per_cell:
216
+ return 0.0, 0.0, np.zeros(n, dtype=np.float32)
217
+ if return_sem:
218
+ return 0.0, 0.0
219
+ if return_per_cell:
220
+ return 0.0, np.zeros(n, dtype=np.float32)
221
+ return 0.0
222
+
223
+ metric = str(metric).lower().strip()
224
+ k_eff = int(min(max(int(k), 1), n - 1))
225
+
226
+ nn = NearestNeighbors(n_neighbors=k_eff + 1, metric=metric)
227
+ nn.fit(Z)
228
+ neigh_idx = nn.kneighbors(Z, return_distance=False)[:, 1:] # drop self
229
+
230
+ neigh_mods = modality_labels[neigh_idx]
231
+ frac_other = (neigh_mods != modality_labels[:, None]).mean(axis=1).astype(np.float32)
232
+
233
+ m, s = _mean_sem(frac_other.astype(float))
234
+ if return_sem and return_per_cell:
235
+ return float(m), float(s), frac_other
236
+ if return_sem:
237
+ return float(m), float(s)
238
+ if return_per_cell:
239
+ return float(m), frac_other
240
+ return float(m)
241
+
242
+
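A sketch of the mixing metric; Z_rna and Z_adt are hypothetical per-modality embeddings generated here only for illustration:

import numpy as np
from univi.evaluation import compute_modality_mixing

rng = np.random.default_rng(2)
Z_rna = rng.normal(size=(150, 10)).astype(np.float32)
Z_adt = rng.normal(size=(150, 10)).astype(np.float32)

# Stack both modalities and label each row by its modality of origin.
Z = np.vstack([Z_rna, Z_adt])
mods = np.array(["rna"] * len(Z_rna) + ["adt"] * len(Z_adt))

mix, mix_sem = compute_modality_mixing(Z, mods, k=20, return_sem=True)
# With two equally sized, well-mixed modalities the score sits near 0.5;
# values near 0 mean neighbors come almost entirely from the same modality.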
243
+ # ------------------------------------------------------------------
244
+ # 3. Label transfer (kNN) with extra stats (macro/weighted F1)
245
+ # ------------------------------------------------------------------
246
+ def label_transfer_knn(
247
+ Z_source: np.ndarray,
248
+ labels_source: np.ndarray,
249
+ Z_target: np.ndarray,
250
+ labels_target: Optional[np.ndarray] = None,
251
+ k: int = 15,
252
+ metric: str = "euclidean",
253
+ return_label_order: bool = False,
254
+ return_f1: bool = False,
255
+ ):
256
+ """
257
+ Backwards-compatible returns:
258
+ - if labels_target is None: (pred_labels, None, empty_cm)
259
+ - if labels_target provided and both flags False: (pred_labels, acc, cm)
260
+ - if return_label_order True: add label_order
261
+ - if return_f1 True: add f1_dict
262
+ - if both True: add both (label_order, f1_dict) in that order
263
+ """
264
+ Z_source = np.asarray(Z_source, dtype=np.float32)
265
+ Z_target = np.asarray(Z_target, dtype=np.float32)
266
+ labels_source = np.asarray(labels_source)
267
+ if labels_target is not None:
268
+ labels_target = np.asarray(labels_target)
269
+
270
+ n_source = int(Z_source.shape[0])
271
+ if n_source == 0:
272
+ raise ValueError("Z_source is empty.")
273
+
274
+ k_eff = int(min(max(int(k), 1), n_source))
275
+ nn = NearestNeighbors(n_neighbors=k_eff, metric=metric)
276
+ nn.fit(Z_source)
277
+ neigh_idx = nn.kneighbors(Z_target, return_distance=False)
278
+
279
+ uniq_src, src_codes = np.unique(labels_source, return_inverse=True)
280
+
281
+ pred_codes = np.empty(Z_target.shape[0], dtype=np.int64)
282
+ for i in range(Z_target.shape[0]):
283
+ votes = src_codes[neigh_idx[i]]
284
+ bc = np.bincount(votes, minlength=len(uniq_src))
285
+ pred_codes[i] = int(bc.argmax())
286
+
287
+ pred_labels = uniq_src[pred_codes]
288
+
289
+ if labels_target is None:
290
+ return pred_labels, None, np.array([])
291
+
292
+ label_order = np.unique(np.concatenate([labels_target, pred_labels]))
293
+ acc = float(accuracy_score(labels_target, pred_labels))
294
+ cm = confusion_matrix(labels_target, pred_labels, labels=label_order)
295
+
296
+ extras = []
297
+ if return_label_order:
298
+ extras.append(label_order)
299
+ if return_f1:
300
+ extras.append({
301
+ "macro_f1": float(f1_score(labels_target, pred_labels, average="macro")),
302
+ "weighted_f1": float(f1_score(labels_target, pred_labels, average="weighted")),
303
+ })
304
+
305
+ if not extras:
306
+ return pred_labels, acc, cm
307
+ return (pred_labels, acc, cm, *extras)
308
+
309
+
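A sketch of kNN label transfer with both optional extras requested; the source/target arrays and labels are synthetic placeholders:

import numpy as np
from univi.evaluation import label_transfer_knn

rng = np.random.default_rng(3)
Z_src = rng.normal(size=(400, 12)).astype(np.float32)       # e.g. an annotated reference embedding
y_src = rng.choice(["B", "T", "NK"], size=400)               # toy source annotations
Z_tgt = Z_src + 0.1 * rng.normal(size=Z_src.shape).astype(np.float32)
y_tgt = y_src                                                # toy ground truth on the target

pred, acc, cm, order, f1d = label_transfer_knn(
    Z_src, y_src, Z_tgt, labels_target=y_tgt,
    k=15, return_label_order=True, return_f1=True,
)
# pred: transferred labels; cm is indexed by `order`; f1d holds macro/weighted F1.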
310
+ # ------------------------------------------------------------------
311
+ # 4. Reconstruction metrics (continuous; useful for CITE-seq CLR/gaussian)
312
+ # ------------------------------------------------------------------
313
+ def mse_per_feature(x_true: np.ndarray, x_pred: np.ndarray) -> np.ndarray:
314
+ x_true = np.asarray(x_true)
315
+ x_pred = np.asarray(x_pred)
316
+ if x_true.shape != x_pred.shape:
317
+ raise ValueError("x_true and x_pred must have same shape.")
318
+ return np.mean((x_true - x_pred) ** 2, axis=0)
319
+
320
+
321
+ def pearson_corr_per_feature(x_true: np.ndarray, x_pred: np.ndarray) -> np.ndarray:
322
+ x_true = np.asarray(x_true, dtype=np.float32)
323
+ x_pred = np.asarray(x_pred, dtype=np.float32)
324
+ if x_true.shape != x_pred.shape:
325
+ raise ValueError("x_true and x_pred must have same shape.")
326
+
327
+ x_true_c = x_true - x_true.mean(axis=0, keepdims=True)
328
+ x_pred_c = x_pred - x_pred.mean(axis=0, keepdims=True)
329
+
330
+ num = (x_true_c * x_pred_c).sum(axis=0)
331
+ denom = np.sqrt((x_true_c ** 2).sum(axis=0) * (x_pred_c ** 2).sum(axis=0)) + 1e-8
332
+ return num / denom
333
+
334
+
335
+ def reconstruction_metrics(x_true: np.ndarray, x_pred: np.ndarray) -> Dict[str, Any]:
336
+ pf_mse = mse_per_feature(x_true, x_pred)
337
+ pf_r = pearson_corr_per_feature(x_true, x_pred)
338
+ return {
339
+ "mse_mean": float(np.mean(pf_mse)),
340
+ "mse_median": float(np.median(pf_mse)),
341
+ "pearson_mean": float(np.mean(pf_r)),
342
+ "pearson_median": float(np.median(pf_r)),
343
+ "mse_per_feature": pf_mse,
344
+ "pearson_per_feature": pf_r,
345
+ }
346
+
347
+
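A quick sketch of the continuous reconstruction metrics on synthetic values (for example CLR-scale intensities):

import numpy as np
from univi.evaluation import reconstruction_metrics

rng = np.random.default_rng(4)
x_true = rng.normal(size=(500, 30)).astype(np.float32)
x_pred = x_true + 0.2 * rng.normal(size=x_true.shape).astype(np.float32)

m = reconstruction_metrics(x_true, x_pred)
print(m["mse_mean"], m["pearson_median"])   # scalar summaries; per-feature arrays are also returned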
348
+ # ------------------------------------------------------------------
349
+ # 5. Encoding + cross-modal prediction
350
+ # ------------------------------------------------------------------
351
+ def encode_adata(
352
+ model,
353
+ adata,
354
+ modality: str,
355
+ device: str = "cpu",
356
+ layer: Optional[str] = None,
357
+ X_key: str = "X",
358
+ batch_size: int = 1024,
359
+ latent: str = "moe_mean",
360
+ random_state: int = 0,
361
+ ) -> np.ndarray:
362
+ from .data import _get_matrix
363
+
364
+ latent = str(latent).lower().strip()
365
+ valid = {"moe_mean", "moe_sample", "modality_mean", "modality_sample"}
366
+ if latent not in valid:
367
+ raise ValueError("latent must be one of %s; got %r" % (sorted(valid), latent))
368
+
369
+ def _sample_gaussian(mu: torch.Tensor, logvar: torch.Tensor, gen: torch.Generator) -> torch.Tensor:
370
+ eps = torch.randn(mu.shape, device=mu.device, generator=gen, dtype=mu.dtype)
371
+ return mu + eps * torch.exp(0.5 * logvar)
372
+
373
+ model.eval()
374
+ X = _get_matrix(adata, layer=layer, X_key=X_key)
375
+ if sp.issparse(X):
376
+ X = X.toarray()
377
+
378
+ dev = torch.device(device)
379
+ gen = torch.Generator(device=dev)
380
+ gen.manual_seed(int(random_state))
381
+
382
+ zs = []
383
+ with torch.no_grad():
384
+ for start in range(0, X.shape[0], int(batch_size)):
385
+ end = min(start + int(batch_size), X.shape[0])
386
+ xb = torch.as_tensor(np.asarray(X[start:end]), dtype=torch.float32, device=dev)
387
+
388
+ mu_dict, logvar_dict = model.encode_modalities({modality: xb})
389
+
390
+ if "modality" in latent:
391
+ mu = mu_dict[modality]
392
+ lv = logvar_dict[modality]
393
+ z = mu if latent.endswith("_mean") else _sample_gaussian(mu, lv, gen)
394
+ else:
395
+ # robust fallback in case a future refactor renames MoE fuser
396
+ if hasattr(model, "mixture_of_experts"):
397
+ mu_z, logvar_z = model.mixture_of_experts(mu_dict, logvar_dict)
398
+ elif hasattr(model, "fuse_posteriors"):
399
+ mu_z, logvar_z = model.fuse_posteriors(mu_dict, logvar_dict)
400
+ else:
401
+ # single-modality fallback
402
+ mu_z, logvar_z = mu_dict[modality], logvar_dict[modality]
403
+
404
+ z = mu_z if latent.endswith("_mean") else _sample_gaussian(mu_z, logvar_z, gen)
405
+
406
+ zs.append(z.detach().cpu().numpy())
407
+
408
+ return np.vstack(zs)
409
+
410
+
411
+ def cross_modal_predict(
412
+ model,
413
+ adata_src,
414
+ src_mod: str,
415
+ tgt_mod: str,
416
+ device: str = "cpu",
417
+ layer: Optional[str] = None,
418
+ X_key: str = "X",
419
+ batch_size: int = 512,
420
+ use_moe: bool = True,
421
+ ) -> np.ndarray:
422
+ """
423
+ Encode src_mod then decode tgt_mod.
424
+
425
+ For paired data with ONLY src_mod observed, MoE fusion == src posterior.
426
+ Still, use_moe=False can be handy if you want to force src-only even if model changes.
427
+ """
428
+ from .data import _get_matrix
429
+
430
+ model.eval()
431
+ X = _get_matrix(adata_src, layer=layer, X_key=X_key)
432
+ if sp.issparse(X):
433
+ X = X.toarray()
434
+
435
+ dev = torch.device(device)
436
+
437
+ preds = []
438
+ with torch.no_grad():
439
+ for start in range(0, X.shape[0], int(batch_size)):
440
+ end = min(start + int(batch_size), X.shape[0])
441
+ xb = torch.as_tensor(np.asarray(X[start:end]), dtype=torch.float32, device=dev)
442
+
443
+ mu_dict, logvar_dict = model.encode_modalities({src_mod: xb})
444
+
445
+ if use_moe and hasattr(model, "mixture_of_experts"):
446
+ mu_z, _ = model.mixture_of_experts(mu_dict, logvar_dict)
447
+ else:
448
+ mu_z = mu_dict[src_mod]
449
+
450
+ xhat_dict = model.decode_modalities(mu_z)
451
+ if tgt_mod not in xhat_dict:
452
+ raise KeyError(f"Target modality {tgt_mod!r} not found. Available: {list(xhat_dict.keys())}")
453
+ preds.append(xhat_dict[tgt_mod].detach().cpu().numpy())
454
+
455
+ return np.vstack(preds) if preds else np.zeros((0, 0), dtype=float)
456
+
457
+
458
+ def denoise_adata(
459
+ model,
460
+ adata,
461
+ modality: str,
462
+ device: str = "cpu",
463
+ layer: Optional[str] = None,
464
+ X_key: str = "X",
465
+ batch_size: int = 512,
466
+ out_layer: Optional[str] = None,
467
+ overwrite_X: bool = False,
468
+ dtype: Optional[np.dtype] = np.float32,
469
+ ) -> np.ndarray:
470
+ X_hat = cross_modal_predict(
471
+ model,
472
+ adata_src=adata,
473
+ src_mod=modality,
474
+ tgt_mod=modality,
475
+ device=device,
476
+ layer=layer,
477
+ X_key=X_key,
478
+ batch_size=batch_size,
479
+ use_moe=True,
480
+ )
481
+ if dtype is not None:
482
+ X_hat = np.asarray(X_hat, dtype=dtype)
483
+
484
+ if overwrite_X:
485
+ adata.X = X_hat
486
+ elif out_layer is not None:
487
+ adata.layers[out_layer] = X_hat
488
+
489
+ return X_hat
490
+
491
+
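A hedged sketch of the encode / cross-modal / denoise workflow. Here `model` is assumed to be a trained univi model exposing encode_modalities and decode_modalities, `adata_rna` and `adata_adt` are hypothetical paired AnnData objects, and the modality keys "rna" and "adt" are assumptions about how that model was configured:

from univi.evaluation import encode_adata, cross_modal_predict, denoise_adata

# Joint-latent embeddings per modality (MoE posterior mean, deterministic).
Z_rna = encode_adata(model, adata_rna, modality="rna", device="cpu", latent="moe_mean")
Z_adt = encode_adata(model, adata_adt, modality="adt", device="cpu", latent="moe_mean")

# Predict ADT-space reconstructions from RNA input alone.
adt_hat = cross_modal_predict(model, adata_rna, src_mod="rna", tgt_mod="adt")

# Within-modality denoising, written to a new layer instead of overwriting X.
denoise_adata(model, adata_rna, modality="rna", out_layer="denoised")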
492
+ # ------------------------------------------------------------------
493
+ # 6. High-level alignment eval (Figure-ready)
494
+ # ------------------------------------------------------------------
495
+ def evaluate_alignment(
496
+ Z1: Optional[np.ndarray] = None,
497
+ Z2: Optional[np.ndarray] = None,
498
+ model=None,
499
+ adata1=None,
500
+ adata2=None,
501
+ mod1: Optional[str] = None,
502
+ mod2: Optional[str] = None,
503
+ device: str = "cpu",
504
+ layer1: Optional[str] = None,
505
+ layer2: Optional[str] = None,
506
+ X_key1: str = "X",
507
+ X_key2: str = "X",
508
+ batch_size: int = 1024,
509
+ latent: str = "moe_mean",
510
+ latent1: Optional[str] = None,
511
+ latent2: Optional[str] = None,
512
+ random_state: int = 0,
513
+ metric: str = "euclidean",
514
+ k_mixing: int = 20,
515
+ k_transfer: int = 15,
516
+ modality_labels: Optional[np.ndarray] = None,
517
+ labels_source: Optional[np.ndarray] = None,
518
+ labels_target: Optional[np.ndarray] = None,
519
+ recall_ks: Tuple[int, ...] = (1, 5, 10),
520
+ foscttm_block_size: int = 512,
521
+ json_safe: bool = True,
522
+ ) -> Dict[str, Any]:
523
+ """
524
+ Returns a dict with:
525
+ - foscttm (mean), foscttm_sem
526
+ - recall@k + sem for each k in recall_ks
527
+ - modality_mixing (mean), modality_mixing_sem
528
+ - label transfer: acc, macro/weighted f1 (optional), confusion matrix + label order
529
+ """
530
+ out: Dict[str, Any] = {}
531
+
532
+ lat1 = latent if latent1 is None else latent1
533
+ lat2 = latent if latent2 is None else latent2
534
+
535
+ if Z1 is None or Z2 is None:
536
+ if model is None or adata1 is None or adata2 is None or mod1 is None or mod2 is None:
537
+ raise ValueError("Provide either (Z1, Z2) or (model, adata1, adata2, mod1, mod2).")
538
+
539
+ Z1 = encode_adata(
540
+ model, adata1, modality=mod1, device=device, layer=layer1, X_key=X_key1,
541
+ batch_size=batch_size, latent=lat1, random_state=random_state
542
+ )
543
+ Z2 = encode_adata(
544
+ model, adata2, modality=mod2, device=device, layer=layer2, X_key=X_key2,
545
+ batch_size=batch_size, latent=lat2, random_state=random_state
546
+ )
547
+
548
+ Z1 = np.asarray(Z1)
549
+ Z2 = np.asarray(Z2)
550
+
551
+ out["n1"] = int(Z1.shape[0])
552
+ out["n2"] = int(Z2.shape[0])
553
+ out["dim"] = int(Z1.shape[1]) if Z1.ndim == 2 else None
554
+ out["latent1"] = lat1
555
+ out["latent2"] = lat2
556
+ out["metric"] = str(metric)
557
+
558
+ # FOSCTTM + SEM
559
+ if Z1.shape == Z2.shape and Z1.shape[0] > 1:
560
+ fos_mean, fos_sem = compute_foscttm(
561
+ Z1, Z2, metric=metric, block_size=foscttm_block_size, return_sem=True, return_per_cell=False
562
+ )
563
+ out["foscttm"] = fos_mean
564
+ out["foscttm_sem"] = fos_sem
565
+ else:
566
+ out["foscttm"] = None
567
+ out["foscttm_sem"] = None
568
+
569
+ # Recall@k
570
+ if Z1.shape == Z2.shape and Z1.shape[0] > 1:
571
+ for k in recall_ks:
572
+ r_mean, r_sem = compute_match_recall_at_k(
573
+ Z1, Z2, k=int(k), metric=metric, block_size=foscttm_block_size, return_sem=True, return_per_cell=False
574
+ )
575
+ out[f"recall_at_{int(k)}"] = r_mean
576
+ out[f"recall_at_{int(k)}_sem"] = r_sem
577
+ else:
578
+ for k in recall_ks:
579
+ out[f"recall_at_{int(k)}"] = None
580
+ out[f"recall_at_{int(k)}_sem"] = None
581
+
582
+ # Modality mixing computed on concatenated embeddings
583
+ Z_concat = None
584
+ if (Z1.ndim == 2 and Z2.ndim == 2 and Z1.shape[1] == Z2.shape[1]):
585
+ Z_concat = np.vstack([Z1, Z2])
586
+
587
+ if Z_concat is not None and Z_concat.shape[0] > 1:
588
+ if modality_labels is None:
589
+ modality_labels = np.concatenate([np.repeat("mod1", Z1.shape[0]), np.repeat("mod2", Z2.shape[0])])
590
+ mix_mean, mix_sem = compute_modality_mixing(
591
+ Z_concat, modality_labels=np.asarray(modality_labels),
592
+ k=k_mixing, metric=metric, return_sem=True, return_per_cell=False
593
+ )
594
+ out["modality_mixing"] = mix_mean
595
+ out["modality_mixing_sem"] = mix_sem
596
+ out["k_mixing"] = int(k_mixing)
597
+ else:
598
+ out["modality_mixing"] = None
599
+ out["modality_mixing_sem"] = None
600
+ out["k_mixing"] = int(k_mixing)
601
+
602
+ # Label transfer
603
+ if labels_source is not None:
604
+ pred, acc, cm, order, f1d = label_transfer_knn(
605
+ Z_source=Z1,
606
+ labels_source=np.asarray(labels_source),
607
+ Z_target=Z2,
608
+ labels_target=np.asarray(labels_target) if labels_target is not None else None,
609
+ k=k_transfer,
610
+ metric=metric,
611
+ return_label_order=True,
612
+ return_f1=True,
613
+ )
614
+ out["label_transfer_pred"] = pred
615
+ out["label_transfer_acc"] = acc
616
+ out["label_transfer_cm"] = cm
617
+ out["label_transfer_label_order"] = order
618
+ out["label_transfer_f1"] = f1d
619
+ out["k_transfer"] = int(k_transfer)
620
+ else:
621
+ out["label_transfer_pred"] = None
622
+ out["label_transfer_acc"] = None
623
+ out["label_transfer_cm"] = None
624
+ out["label_transfer_label_order"] = None
625
+ out["label_transfer_f1"] = None
626
+ out["k_transfer"] = int(k_transfer)
627
+
628
+ if json_safe:
629
+ out = {k: _json_safe(v) for k, v in out.items()}
630
+
631
+ return out
632
+
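A sketch of the high-level evaluate_alignment entry point on precomputed embeddings (synthetic arrays below; the model/AnnData pathway described in the docstring works the same way but requires a trained model):

import numpy as np
from univi.evaluation import evaluate_alignment

rng = np.random.default_rng(5)
Z1 = rng.normal(size=(250, 16)).astype(np.float32)
Z2 = Z1 + 0.1 * rng.normal(size=Z1.shape).astype(np.float32)

report = evaluate_alignment(Z1=Z1, Z2=Z2, metric="euclidean", recall_ks=(1, 5, 10), k_mixing=20)
# JSON-safe dict: foscttm / recall_at_k (plus SEMs) / modality_mixing, plus label-transfer
# slots that stay None here because no labels were supplied.
print(report["foscttm"], report["recall_at_10"], report["modality_mixing"])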
univi/hyperparam_optimization/__init__.py ADDED
@@ -0,0 +1,17 @@
1
+ # univi/hyperparam_optimization/__init__.py
2
+
3
+ from .run_multiome_hparam_search import run_multiome_hparam_search
4
+ from .run_citeseq_hparam_search import run_citeseq_hparam_search
5
+ from .run_teaseq_hparam_search import run_teaseq_hparam_search
6
+ from .run_rna_hparam_search import run_rna_hparam_search
7
+ from .run_atac_hparam_search import run_atac_hparam_search
8
+ from .run_adt_hparam_search import run_adt_hparam_search
9
+
10
+ __all__ = [
11
+ "run_multiome_hparam_search",
12
+ "run_citeseq_hparam_search",
13
+ "run_teaseq_hparam_search",
14
+ "run_rna_hparam_search",
15
+ "run_atac_hparam_search",
16
+ "run_adt_hparam_search",
17
+ ]
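The re-exports above make the search entry points importable directly from the subpackage; their argument signatures are not part of this diff, so only the import is shown:

from univi.hyperparam_optimization import run_citeseq_hparam_search  # signature defined in its own module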