univi-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
univi/utils/io.py ADDED
@@ -0,0 +1,621 @@
+ # univi/utils/io.py
+ from __future__ import annotations
+
+ from typing import Any, Dict, Optional, Sequence, Union, Mapping, Literal
+ import os
+ import json
+
+ import numpy as np
+ import torch
+ import scipy.sparse as sp
+ import anndata as ad
+
+
+ SplitKey = Literal["train", "val", "test"]
+ SplitMap = Dict[SplitKey, Any]
+
+
+ # =============================================================================
+ # Checkpointing
+ # =============================================================================
+
+ def save_checkpoint(
+     path: str,
+     model_state: Optional[Dict[str, Any]] = None,
+     optimizer_state: Optional[Dict[str, Any]] = None,
+     extra: Optional[Dict[str, Any]] = None,
+     *,
+     model: Optional[torch.nn.Module] = None,
+     optimizer: Optional[torch.optim.Optimizer] = None,
+     trainer_state: Optional[Dict[str, Any]] = None,
+     scaler_state: Optional[Dict[str, Any]] = None,
+     config: Optional[Dict[str, Any]] = None,
+     strict_label_compat: bool = True,
+ ) -> None:
+     if model_state is None and model is not None:
+         model_state = model.state_dict()
+     if optimizer_state is None and optimizer is not None:
+         optimizer_state = optimizer.state_dict()
+
+     if model_state is None:
+         raise ValueError("save_checkpoint requires either model_state=... or model=...")
+
+     payload: Dict[str, Any] = {
+         "format_version": 3,
+         "model_state": model_state,
+     }
+     if optimizer_state is not None:
+         payload["optimizer_state"] = optimizer_state
+     if trainer_state is not None:
+         payload["trainer_state"] = dict(trainer_state)
+     if scaler_state is not None:
+         payload["scaler_state"] = dict(scaler_state)
+     if config is not None:
+         payload["config"] = dict(config)
+     if extra is not None:
+         payload["extra"] = dict(extra)
+
+     # --- classification metadata (legacy + multi-head) ---
+     if model is not None:
+         meta: Dict[str, Any] = {}
+
+         n_label_classes = getattr(model, "n_label_classes", None)
+         label_names = getattr(model, "label_names", None)
+         label_head_name = getattr(model, "label_head_name", None)
+
+         if n_label_classes is not None:
+             meta.setdefault("legacy", {})["n_label_classes"] = int(n_label_classes)
+         if label_names is not None:
+             meta.setdefault("legacy", {})["label_names"] = list(label_names)
+         if label_head_name is not None:
+             meta["label_head_name"] = str(label_head_name)
+
+         class_heads_cfg = getattr(model, "class_heads_cfg", None)
+         head_label_names = getattr(model, "head_label_names", None)
+
+         if isinstance(class_heads_cfg, dict) and len(class_heads_cfg) > 0:
+             meta.setdefault("multi", {})["heads"] = {k: dict(v) for k, v in class_heads_cfg.items()}
+         if isinstance(head_label_names, dict) and len(head_label_names) > 0:
+             meta.setdefault("multi", {})["label_names"] = {k: list(v) for k, v in head_label_names.items()}
+
+         if hasattr(model, "get_classification_meta"):
+             try:
+                 meta = dict(model.get_classification_meta())
+             except Exception:
+                 pass
+
+         if meta:
+             payload["label_meta"] = meta
+             payload["strict_label_compat"] = bool(strict_label_compat)
+
+     os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+     torch.save(payload, path)
+
+
+ def load_checkpoint(path: str, *, map_location: Union[str, torch.device, None] = "cpu") -> Dict[str, Any]:
+     return torch.load(path, map_location=map_location)
+
+
+ def restore_checkpoint(
+     payload_or_path: Union[str, Dict[str, Any]],
+     *,
+     model: torch.nn.Module,
+     optimizer: Optional[torch.optim.Optimizer] = None,
+     scaler: Optional[torch.cuda.amp.GradScaler] = None,
+     map_location: Union[str, torch.device, None] = "cpu",
+     strict: bool = True,
+     restore_label_names: bool = True,
+     enforce_label_compat: bool = True,
+ ) -> Dict[str, Any]:
+     payload = load_checkpoint(payload_or_path, map_location=map_location) if isinstance(payload_or_path, str) else payload_or_path
+
+     if enforce_label_compat:
+         meta = payload.get("label_meta", {}) or {}
+
+         legacy = meta.get("legacy", {}) if isinstance(meta, dict) else {}
+         ckpt_n = legacy.get("n_label_classes", None)
+         model_n = getattr(model, "n_label_classes", None)
+         if ckpt_n is not None and model_n is not None and int(ckpt_n) != int(model_n):
+             raise ValueError(
+                 f"Checkpoint n_label_classes={ckpt_n} does not match model n_label_classes={model_n}. "
+                 "Rebuild the model with the same n_label_classes."
+             )
+
+         multi = meta.get("multi", {}) if isinstance(meta, dict) else {}
+         ckpt_heads = multi.get("heads", None)
+         model_heads = getattr(model, "class_heads_cfg", None)
+         if isinstance(ckpt_heads, dict) and isinstance(model_heads, dict):
+             for hk, hcfg in ckpt_heads.items():
+                 if hk not in model_heads:
+                     raise ValueError(
+                         f"Checkpoint contains head {hk!r} but model does not. "
+                         f"Model heads: {list(model_heads.keys())}"
+                     )
+                 ckpt_c = int(hcfg.get("n_classes", -1))
+                 model_c = int(model_heads[hk].get("n_classes", -1))
+                 if ckpt_c != model_c:
+                     raise ValueError(f"Head {hk!r} n_classes mismatch: checkpoint={ckpt_c}, model={model_c}.")
+
+     model.load_state_dict(payload["model_state"], strict=bool(strict))
+
+     if optimizer is not None and "optimizer_state" in payload:
+         optimizer.load_state_dict(payload["optimizer_state"])
+
+     if scaler is not None and payload.get("scaler_state") is not None:
+         try:
+             scaler.load_state_dict(payload["scaler_state"])
+         except Exception:
+             pass
+
+     if restore_label_names:
+         meta = payload.get("label_meta", {}) or {}
+
+         legacy = meta.get("legacy", {}) if isinstance(meta, dict) else {}
+         label_names = legacy.get("label_names", None)
+         if label_names is not None and hasattr(model, "set_label_names"):
+             try:
+                 model.set_label_names(list(label_names))
+             except Exception:
+                 pass
+
+         multi = meta.get("multi", {}) if isinstance(meta, dict) else {}
+         head_names = multi.get("label_names", None)
+         if isinstance(head_names, dict) and hasattr(model, "set_head_label_names"):
+             for hk, names in head_names.items():
+                 try:
+                     model.set_head_label_names(str(hk), list(names))
+                 except Exception:
+                     pass
+
+     return payload
+
+
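A minimal sketch of the checkpoint round trip above, using a stand-in torch.nn.Linear model and a hypothetical output path (neither comes from the package):

import torch
from univi.utils.io import save_checkpoint, restore_checkpoint

model = torch.nn.Linear(16, 4)                 # stand-in for a trained model
optim = torch.optim.Adam(model.parameters())

# Bundle weights, optimizer state, trainer progress, and config into one payload.
save_checkpoint(
    "runs/demo/ckpt.pt",
    model=model,
    optimizer=optim,
    trainer_state={"epoch": 5},
    config={"latent_dim": 4},
)

# Later: rebuild the same architecture, then restore weights and optimizer state.
model2 = torch.nn.Linear(16, 4)
optim2 = torch.optim.Adam(model2.parameters())
payload = restore_checkpoint("runs/demo/ckpt.pt", model=model2, optimizer=optim2)
print(payload["trainer_state"])                # {'epoch': 5}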
+ # =============================================================================
+ # JSON config helpers
+ # =============================================================================
+
+ def save_config_json(config: Dict[str, Any], path: str) -> None:
+     os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+     with open(path, "w") as f:
+         json.dump(config, f, indent=2, sort_keys=True)
+
+
+ def load_config(config_path: str) -> Dict[str, Any]:
+     with open(config_path) as f:
+         return json.load(f)
+
+
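The JSON helpers are a straightforward round trip; the path below is illustrative:

from univi.utils.io import save_config_json, load_config

cfg = {"latent_dim": 16, "modalities": ["rna", "adt"]}
save_config_json(cfg, "runs/demo/config.json")
assert load_config("runs/demo/config.json") == cfg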
+ # =============================================================================
+ # AnnData split helpers (SAFE + FLEXIBLE)
+ # =============================================================================
+
+ def _is_sequence_of_str(x: Any) -> bool:
+     return isinstance(x, (list, tuple, np.ndarray)) and len(x) > 0 and isinstance(x[0], (str, np.str_))
+
+
+ def _is_sequence_of_int(x: Any) -> bool:
+     return isinstance(x, (list, tuple, np.ndarray)) and len(x) > 0 and isinstance(x[0], (int, np.integer))
+
+
+ def _normalize_split_selector(
+     adata: ad.AnnData,
+     selector: Any,
+     *,
+     name: str,
+ ) -> Union[np.ndarray, Sequence[str], Sequence[int]]:
+     """
+     Convert a selector into one of:
+       - boolean mask (np.ndarray[bool] of length n_obs)
+       - list of obs_names (Sequence[str])
+       - list of integer indices (Sequence[int])
+
+     Rejects AnnData objects or other iterables that could trigger accidental iteration.
+     """
+     if selector is None:
+         return np.zeros(adata.n_obs, dtype=bool)
+
+     # HARD FAIL: AnnData passed where indices were expected
+     if isinstance(selector, ad.AnnData):
+         raise TypeError(
+             f"{name}: expected indices/obs_names/mask, but got AnnData. "
+             "Pass `splits={'train': adata_train, ...}` instead."
+         )
+
+     # bool mask
+     if isinstance(selector, np.ndarray) and selector.dtype == bool:
+         if selector.shape[0] != adata.n_obs:
+             raise ValueError(f"{name}: boolean mask length {selector.shape[0]} != n_obs {adata.n_obs}.")
+         return selector
+
+     # pandas Series of bool (only the import is guarded, so validation errors are not swallowed)
+     try:
+         import pandas as pd  # optional
+     except Exception:
+         pd = None
+     if pd is not None and isinstance(selector, pd.Series) and selector.dtype == bool:
+         v = selector.to_numpy()
+         if v.shape[0] != adata.n_obs:
+             raise ValueError(f"{name}: boolean mask length {v.shape[0]} != n_obs {adata.n_obs}.")
+         return v
+
+     # list/tuple/np array of strings: obs_names
+     if _is_sequence_of_str(selector):
+         return [str(s) for s in list(selector)]
+
+     # list/tuple/np array of ints: indices
+     if _is_sequence_of_int(selector):
+         idx = [int(i) for i in list(selector)]
+         if len(idx) > 0:
+             mx = max(idx)
+             mn = min(idx)
+             if mn < 0 or mx >= adata.n_obs:
+                 raise IndexError(f"{name}: index out of bounds (min={mn}, max={mx}, n_obs={adata.n_obs}).")
+         return idx
+
+     # empty list/tuple
+     if isinstance(selector, (list, tuple)) and len(selector) == 0:
+         return np.zeros(adata.n_obs, dtype=bool)
+
+     raise TypeError(
+         f"{name}: unsupported selector type {type(selector)}. "
+         "Use obs_names (list[str]), indices (list[int]), or boolean mask (np.ndarray[bool]). "
+         "If you already have split AnnData objects, pass `splits=...`."
+     )
+
+
+ def _subset_adata(
+     adata: ad.AnnData,
+     selector: Union[np.ndarray, Sequence[str], Sequence[int]],
+     *,
+     copy: bool,
+ ) -> ad.AnnData:
+     if isinstance(selector, np.ndarray) and selector.dtype == bool:
+         out = adata[selector]
+     elif _is_sequence_of_str(selector):
+         out = adata[list(selector)]
+     else:
+         out = adata[list(selector), :]
+     return out.copy() if copy else out
+
+
+ def _write_h5ad_safe(adata_obj: ad.AnnData, path: str, *, write_backed: bool = False) -> None:
+     os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
+     # write_backed kept for API compatibility; anndata's backed writing patterns vary.
+     adata_obj.write_h5ad(path)
+
+
+ def save_anndata_splits(
+     adata: Optional[ad.AnnData] = None,
+     outdir: str = ".",
+     *,
+     prefix: str = "dataset",
+     # Mode A: split from adata.obs[split_key]
+     split_key: Optional[str] = "split",
+     train_label: str = "train",
+     val_label: str = "val",
+     test_label: str = "test",
+     # Mode B: split from selectors (indices/obs_names/mask)
+     split_map: Optional[Dict[str, Any]] = None,
+     # Mode C: already split objects (no resplitting)
+     splits: Optional[Dict[str, ad.AnnData]] = None,
+     # Behavior
+     copy: bool = False,
+     write_backed: bool = False,
+     save_h5ad: bool = True,
+     save_split_map: bool = True,
+     split_map_name: Optional[str] = None,
+     require_disjoint: bool = True,
+ ) -> Dict[str, ad.AnnData]:
+     """
+     Save train/val/test splits to {outdir}/{prefix}_{train|val|test}.h5ad and optionally a split map JSON.
+
+     You can provide EXACTLY ONE of:
+       1) splits={...} (already split AnnData objects)
+       2) adata + split_map={...} (selectors: obs_names, indices, or boolean masks)
+       3) adata + split_key in adata.obs (labels in .obs)
+
+     Safety: passing AnnData objects inside split_map is rejected with a clear error.
+     """
+     os.makedirs(outdir, exist_ok=True)
+
+     if splits is not None:
+         if adata is not None and split_map is not None:
+             raise ValueError("Provide only one of: splits=, (adata+split_map), or (adata+split_key).")
+         missing = [k for k in ("train", "val", "test") if k not in splits]
+         if missing:
+             raise ValueError(f"splits is missing keys: {missing}. Expected train/val/test.")
+         train = splits["train"]
+         val = splits["val"]
+         test = splits["test"]
+         if require_disjoint:
+             s1 = set(train.obs_names)
+             s2 = set(val.obs_names)
+             s3 = set(test.obs_names)
+             if (s1 & s2) or (s1 & s3) or (s2 & s3):
+                 raise ValueError("splits are not disjoint by obs_names (overlap found).")
+     else:
+         if adata is None:
+             raise ValueError("If splits is not provided, you must provide adata=...")
+
+         if split_map is not None:
+             sel_train = _normalize_split_selector(adata, split_map.get("train", None), name="split_map['train']")
+             sel_val = _normalize_split_selector(adata, split_map.get("val", None), name="split_map['val']")
+             sel_test = _normalize_split_selector(adata, split_map.get("test", None), name="split_map['test']")
+
+             if require_disjoint:
+                 def to_set(sel):
+                     if isinstance(sel, np.ndarray) and sel.dtype == bool:
+                         return set(np.where(sel)[0].tolist())
+                     if _is_sequence_of_str(sel):
+                         return set(map(str, sel))
+                     return set(map(int, sel))
+                 a = to_set(sel_train); b = to_set(sel_val); c = to_set(sel_test)
+                 if (a & b) or (a & c) or (b & c):
+                     raise ValueError("split_map splits overlap (require_disjoint=True).")
+
+             train = _subset_adata(adata, sel_train, copy=copy)
+             val = _subset_adata(adata, sel_val, copy=copy)
+             test = _subset_adata(adata, sel_test, copy=copy)
+
+         else:
+             if split_key is None or split_key not in adata.obs:
+                 raise ValueError(
+                     f"Expected split labels in adata.obs[{split_key!r}], or provide split_map=... or splits=... instead."
+                 )
+             s = adata.obs[split_key].astype(str)
+             train = adata[s == train_label].copy() if copy else adata[s == train_label]
+             val = adata[s == val_label].copy() if copy else adata[s == val_label]
+             test = adata[s == test_label].copy() if copy else adata[s == test_label]
+
+             if require_disjoint:
+                 s1 = set(train.obs_names); s2 = set(val.obs_names); s3 = set(test.obs_names)
+                 if (s1 & s2) or (s1 & s3) or (s2 & s3):
+                     raise ValueError("split_key-derived splits are not disjoint by obs_names.")
+
+     paths = {
+         "train": os.path.join(outdir, f"{prefix}_train.h5ad"),
+         "val": os.path.join(outdir, f"{prefix}_val.h5ad"),
+         "test": os.path.join(outdir, f"{prefix}_test.h5ad"),
+     }
+
+     if save_h5ad:
+         _write_h5ad_safe(train, paths["train"], write_backed=write_backed)
+         _write_h5ad_safe(val, paths["val"], write_backed=write_backed)
+         _write_h5ad_safe(test, paths["test"], write_backed=write_backed)
+
+     if save_split_map:
+         sm = {
+             "train": train.obs_names.tolist(),
+             "val": val.obs_names.tolist(),
+             "test": test.obs_names.tolist(),
+             "prefix": prefix,
+         }
+         if splits is None and split_map is None and adata is not None:
+             sm.update({
+                 "split_key": split_key,
+                 "train_label": train_label,
+                 "val_label": val_label,
+                 "test_label": test_label,
+             })
+
+         fn = split_map_name or f"{prefix}_split_map.json"
+         with open(os.path.join(outdir, fn), "w") as f:
+             json.dump(sm, f, indent=2)
+
+     return {"train": train, "val": val, "test": test}
+
+
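A sketch of the three input modes of save_anndata_splits on a small synthetic AnnData (object names, sizes, and output directories are made up for illustration):

import numpy as np
import anndata as ad
from univi.utils.io import save_anndata_splits

adata = ad.AnnData(X=np.random.rand(10, 5))
adata.obs_names = [f"cell{i}" for i in range(10)]

# Mode A: labels already stored in adata.obs["split"]
adata.obs["split"] = ["train"] * 6 + ["val"] * 2 + ["test"] * 2
save_anndata_splits(adata, "out_a", prefix="demo")

# Mode B: explicit selectors (obs_names, integer indices, or a boolean mask)
save_anndata_splits(
    adata,
    "out_b",
    prefix="demo",
    split_map={
        "train": [f"cell{i}" for i in range(6)],
        "val": [6, 7],
        "test": np.array([False] * 8 + [True] * 2),
    },
)

# Mode C: pre-split AnnData objects; nothing is re-derived
save_anndata_splits(
    outdir="out_c",
    prefix="demo",
    splits={"train": adata[:6].copy(), "val": adata[6:8].copy(), "test": adata[8:].copy()},
)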
+ # =============================================================================
+ # Loading helpers (ADDED)
+ # =============================================================================
+
+ def load_split_map(outdir: str, prefix: str = "dataset", split_map_name: Optional[str] = None) -> Dict[str, Any]:
+     fn = split_map_name or f"{prefix}_split_map.json"
+     path = os.path.join(outdir, fn)
+     with open(path) as f:
+         return json.load(f)
+
+
+ def load_anndata_splits(
+     outdir: str,
+     prefix: str = "dataset",
+     *,
+     backed: Optional[Union[bool, str]] = None,
+ ) -> Dict[str, ad.AnnData]:
+     """
+     Load {prefix}_{train|val|test}.h5ad from outdir.
+
+     backed:
+       - None: normal in-memory load
+       - "r": backed read-only (useful for huge files)
+       - True: alias for "r"
+     """
+     paths = {
+         "train": os.path.join(outdir, f"{prefix}_train.h5ad"),
+         "val": os.path.join(outdir, f"{prefix}_val.h5ad"),
+         "test": os.path.join(outdir, f"{prefix}_test.h5ad"),
+     }
+     missing = [k for k, p in paths.items() if not os.path.exists(p)]
+     if missing:
+         raise FileNotFoundError(f"Missing split files for prefix={prefix!r}: {missing}")
+
+     if backed is True:
+         backed = "r"
+
+     return {k: ad.read_h5ad(p, backed=backed) for k, p in paths.items()}
+
+
+ def subset_anndata_from_split_map(
+     adata: ad.AnnData,
+     split_map: Dict[str, Any],
+     *,
+     copy: bool = False,
+     require_all: bool = True,
+ ) -> Dict[str, ad.AnnData]:
+     """
+     Recreate splits from a loaded split_map JSON (expects obs_names lists in split_map['train'/'val'/'test']).
+     """
+     keys = ("train", "val", "test")
+     if require_all:
+         for k in keys:
+             if k not in split_map:
+                 raise KeyError(f"split_map missing key {k!r}")
+
+     def get_names(k: str):
+         v = split_map.get(k, [])
+         if v is None:
+             v = []
+         if not isinstance(v, (list, tuple)):
+             raise TypeError(f"split_map[{k!r}] must be a list of obs_names.")
+         return [str(x) for x in v]
+
+     train_names = get_names("train")
+     val_names = get_names("val")
+     test_names = get_names("test")
+
+     train = adata[train_names]
+     val = adata[val_names]
+     test = adata[test_names]
+
+     if copy:
+         train = train.copy()
+         val = val.copy()
+         test = test.copy()
+
+     return {"train": train, "val": val, "test": test}
+
+
+ def load_or_recreate_splits(
+     outdir: str,
+     prefix: str,
+     *,
+     adata: Optional[ad.AnnData] = None,
+     backed: Optional[Union[bool, str]] = None,
+     split_map_name: Optional[str] = None,
+     copy: bool = False,
+ ) -> Dict[str, ad.AnnData]:
+     """
+     Convenience:
+       - if {prefix}_train/val/test.h5ad exist: load them
+       - else if split map exists and adata is provided: recreate splits from adata
+     """
+     paths = {
+         "train": os.path.join(outdir, f"{prefix}_train.h5ad"),
+         "val": os.path.join(outdir, f"{prefix}_val.h5ad"),
+         "test": os.path.join(outdir, f"{prefix}_test.h5ad"),
+     }
+     if all(os.path.exists(p) for p in paths.values()):
+         return load_anndata_splits(outdir, prefix=prefix, backed=backed)
+
+     # fallback to split map + adata
+     sm = load_split_map(outdir, prefix=prefix, split_map_name=split_map_name)
+     if adata is None:
+         raise ValueError(
+             f"Split .h5ad files not found for prefix={prefix!r}. "
+             "Provide `adata=...` to recreate from split_map JSON."
+         )
+     return subset_anndata_from_split_map(adata, sm, copy=copy)
+
+
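Continuing the illustrative out_a/demo example from the earlier sketch, the loading side looks like this:

from univi.utils.io import (
    load_anndata_splits,
    load_split_map,
    load_or_recreate_splits,
    subset_anndata_from_split_map,
)

# Direct load; backed="r" keeps the matrices on disk for large files.
splits = load_anndata_splits("out_a", prefix="demo", backed="r")
print({k: v.n_obs for k, v in splits.items()})

# If the per-split .h5ad files are gone but the JSON map survived,
# rebuild the splits from the full object.
sm = load_split_map("out_a", prefix="demo")
rebuilt = subset_anndata_from_split_map(adata, sm, copy=True)

# load_or_recreate_splits tries the .h5ad files first and falls back to the map.
same = load_or_recreate_splits("out_a", "demo", adata=adata)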
+ # =============================================================================
+ # Latent writer
+ # =============================================================================
+
+ def _select_X(adata_obj: ad.AnnData, layer: Optional[str], X_key: str):
+     if X_key != "X":
+         if X_key not in adata_obj.obsm:
+             raise KeyError(f"X_key={X_key!r} not found in adata.obsm. Keys={list(adata_obj.obsm.keys())}")
+         return adata_obj.obsm[X_key]
+     if layer is not None:
+         if layer not in adata_obj.layers:
+             raise KeyError(f"layer={layer!r} not found in adata.layers. Keys={list(adata_obj.layers.keys())}")
+         return adata_obj.layers[layer]
+     return adata_obj.X
+
+
+ @torch.no_grad()
+ def write_univi_latent(
+     model,
+     adata_dict: Dict[str, ad.AnnData],
+     *,
+     obsm_key: str = "X_univi",
+     batch_size: int = 512,
+     device: Optional[Union[str, torch.device]] = None,
+     use_mean: bool = False,
+     epoch: int = 0,
+     y: Optional[Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]] = None,
+     layer: Union[None, str, Mapping[str, Optional[str]]] = None,
+     X_key: Union[str, Mapping[str, str]] = "X",
+     require_paired: bool = True,
+ ) -> np.ndarray:
+     model.eval()
+
+     names = list(adata_dict.keys())
+     if len(names) == 0:
+         raise ValueError("adata_dict is empty.")
+
+     n = int(adata_dict[names[0]].n_obs)
+
+     if require_paired:
+         ref = adata_dict[names[0]].obs_names.values
+         for nm in names[1:]:
+             if adata_dict[nm].n_obs != n:
+                 raise ValueError(f"n_obs mismatch: {nm} has {adata_dict[nm].n_obs} vs {n}")
+             if not np.array_equal(adata_dict[nm].obs_names.values, ref):
+                 raise ValueError(f"obs_names mismatch between {names[0]} and {nm}")
+
+     if device is None:
+         try:
+             device = next(model.parameters()).device
+         except StopIteration:
+             device = "cpu"
+
+     layer_by_mod = dict(layer) if isinstance(layer, dict) else {nm: layer for nm in names}
+     xkey_by_mod = dict(X_key) if isinstance(X_key, dict) else {nm: X_key for nm in names}
+
+     y_t = None
+     if y is not None:
+         if isinstance(y, dict):
+             y_t = {str(k): (v if torch.is_tensor(v) else torch.as_tensor(v)).long() for k, v in y.items()}
+         else:
+             y_t = (y if torch.is_tensor(y) else torch.as_tensor(y)).long()
+
+     zs = []
+     bs = int(batch_size)
+     if bs <= 0:
+         raise ValueError("batch_size must be > 0")
+
+     for start in range(0, n, bs):
+         end = min(n, start + bs)
+         x_dict = {}
+         for nm in names:
+             adata_obj = adata_dict[nm]
+             X = _select_X(adata_obj, layer_by_mod.get(nm, None), xkey_by_mod.get(nm, "X"))
+             xb = X[start:end]
+             if sp.issparse(xb):
+                 xb = xb.toarray()
+             xb = np.asarray(xb)
+             x_dict[nm] = torch.as_tensor(xb, dtype=torch.float32, device=device)
+
+         yb = None
+         if isinstance(y_t, dict):
+             yb = {k: v[start:end].to(device) for k, v in y_t.items()}
+         elif torch.is_tensor(y_t):
+             yb = y_t[start:end].to(device)
+
+         if hasattr(model, "encode_fused"):
+             mu_z, logvar_z, z = model.encode_fused(x_dict, epoch=int(epoch), y=yb, use_mean=bool(use_mean))
+             z_use = z
+         else:
+             out = model(x_dict)
+             z_use = out["mu_z"] if (use_mean and ("mu_z" in out)) else out["z"]
+
+         zs.append(z_use.detach().cpu().numpy())
+
+     Z = np.vstack(zs)
+
+     for nm in names:
+         adata_dict[nm].obsm[obsm_key] = Z
+
+     return Z
+
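write_univi_latent only assumes the model either exposes encode_fused(...) or, from a plain forward call, returns a dict containing "z" (and optionally "mu_z"). A toy stand-in model below illustrates that contract; the real UniVI model class is not part of this diff:

import numpy as np
import torch
import anndata as ad
from univi.utils.io import write_univi_latent

class ToyEncoder(torch.nn.Module):
    # Stand-in: concatenate two modalities and project to an 8-dim latent.
    def __init__(self, d_rna: int, d_adt: int, d_z: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(d_rna + d_adt, d_z)

    def forward(self, x_dict):
        x = torch.cat([x_dict["rna"], x_dict["adt"]], dim=1)
        z = self.proj(x)
        return {"z": z, "mu_z": z}

rna = ad.AnnData(X=np.random.rand(100, 50))
adt = ad.AnnData(X=np.random.rand(100, 10))
adt.obs_names = rna.obs_names  # paired cells: obs_names must match across modalities

model = ToyEncoder(50, 10)
Z = write_univi_latent(model, {"rna": rna, "adt": adt}, obsm_key="X_univi", batch_size=32)
print(Z.shape, rna.obsm["X_univi"].shape)  # (100, 8) for both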
univi/utils/logging.py ADDED
@@ -0,0 +1,16 @@
+ # univi/utils/logging.py
+
+ from __future__ import annotations
+ import logging
+ from typing import Optional
+
+
+ def get_logger(name: str = "univi", level: int = logging.INFO) -> logging.Logger:
+     logger = logging.getLogger(name)
+     if not logger.handlers:
+         handler = logging.StreamHandler()
+         fmt = "[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s"
+         handler.setFormatter(logging.Formatter(fmt))
+         logger.addHandler(handler)
+     logger.setLevel(level)
+     return logger
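Usage is a one-liner; the handler is attached only on the first call, so repeated get_logger calls do not duplicate output:

import logging
from univi.utils.logging import get_logger

log = get_logger()
log.info("starting run")
get_logger(level=logging.DEBUG).debug("same logger instance, now at DEBUG")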
univi/utils/seed.py ADDED
@@ -0,0 +1,18 @@
+ # univi/utils/seed.py
+
+ from __future__ import annotations
+ import random
+ import numpy as np
+ import torch
+
+
+ def set_seed(seed: int = 0, deterministic: bool = False):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+
+     if deterministic:
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
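Typical call at the top of a training script; deterministic=True additionally pins cuDNN to deterministic kernels, trading some speed for reproducibility:

from univi.utils.seed import set_seed

set_seed(42)                      # seeds python, numpy, torch, and all CUDA devices
set_seed(42, deterministic=True)  # also disables cuDNN autotuning/benchmark mode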
univi/utils/stats.py ADDED
@@ -0,0 +1,23 @@
+ # univi/utils/stats.py
+
+ from __future__ import annotations
+ from typing import Dict
+ import numpy as np
+
+
+ def mean_dict(d: Dict[str, float]) -> float:
+     """
+     Simple helper: mean of dictionary values.
+     """
+     if not d:
+         return float("nan")
+     return float(np.mean(list(d.values())))
+
+
+ def zscore(x: np.ndarray, axis: int = 0, eps: float = 1e-8) -> np.ndarray:
+     """
+     Z-score normalization along given axis.
+     """
+     mu = x.mean(axis=axis, keepdims=True)
+     std = x.std(axis=axis, keepdims=True) + eps
+     return (x - mu) / std
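A quick check of the stats helpers (values are illustrative): zscore standardizes each column to roughly zero mean and unit variance, and mean_dict averages a dict of metric values such as per-modality losses.

import numpy as np
from univi.utils.stats import mean_dict, zscore

x = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
z = zscore(x, axis=0)
print(z.mean(axis=0))  # ~[0, 0]
print(z.std(axis=0))   # ~[1, 1]

print(mean_dict({"rna": 0.8, "adt": 0.6}))  # 0.7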