univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/utils/io.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
1
|
+
# univi/utils/io.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any, Dict, Optional, Sequence, Union, Mapping, Literal
|
|
5
|
+
import os
|
|
6
|
+
import json
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
import scipy.sparse as sp
|
|
11
|
+
import anndata as ad
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
SplitKey = Literal["train", "val", "test"]
|
|
15
|
+
SplitMap = Dict[SplitKey, Any]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# =============================================================================
|
|
19
|
+
# Checkpointing
|
|
20
|
+
# =============================================================================
|
|
21
|
+
|
|
22
|
+
def save_checkpoint(
    path: str,
    model_state: Optional[Dict[str, Any]] = None,
    optimizer_state: Optional[Dict[str, Any]] = None,
    extra: Optional[Dict[str, Any]] = None,
    *,
    model: Optional[torch.nn.Module] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    trainer_state: Optional[Dict[str, Any]] = None,
    scaler_state: Optional[Dict[str, Any]] = None,
    config: Optional[Dict[str, Any]] = None,
    strict_label_compat: bool = True,
) -> None:
    """Write a checkpoint payload to ``path`` with ``torch.save``.

    The payload always contains ``format_version`` (3) and ``model_state``.
    ``optimizer_state``, ``trainer_state``, ``scaler_state``, ``config`` and
    ``extra`` are added only when given. When ``model`` is given, label
    metadata read from the model is stored under ``label_meta`` together with
    ``strict_label_compat``.

    Args:
        path: Destination file; its parent directory is created if missing.
        model_state: State dict to store. Falls back to ``model.state_dict()``.
        optimizer_state: Optimizer state dict. Falls back to
            ``optimizer.state_dict()``.
        extra: Free-form mapping stored under ``"extra"`` (shallow-copied).
        model: Source of ``model_state`` (if not given) and of label metadata.
        optimizer: Source of ``optimizer_state`` (if not given).
        trainer_state: Mapping stored under ``"trainer_state"``.
        scaler_state: Mapping stored under ``"scaler_state"``.
        config: Mapping stored under ``"config"``.
        strict_label_compat: Flag stored next to ``label_meta``.

    Raises:
        ValueError: If neither ``model_state`` nor ``model`` is provided.
    """
    import warnings  # local import: only needed on the best-effort fallback path

    if model_state is None and model is not None:
        model_state = model.state_dict()
    if optimizer_state is None and optimizer is not None:
        optimizer_state = optimizer.state_dict()

    if model_state is None:
        raise ValueError("save_checkpoint requires either model_state=... or model=...")

    payload: Dict[str, Any] = {
        "format_version": 3,
        "model_state": model_state,
    }
    if optimizer_state is not None:
        payload["optimizer_state"] = optimizer_state

    # Optional mappings are shallow-copied so later mutation of the caller's
    # dict does not alter the payload being written.
    optional_sections = (
        ("trainer_state", trainer_state),
        ("scaler_state", scaler_state),
        ("config", config),
        ("extra", extra),
    )
    for key, value in optional_sections:
        if value is not None:
            payload[key] = dict(value)

    # --- classification metadata (legacy + multi-head) ---
    if model is not None:
        meta: Dict[str, Any] = {}

        n_label_classes = getattr(model, "n_label_classes", None)
        label_names = getattr(model, "label_names", None)
        label_head_name = getattr(model, "label_head_name", None)

        if n_label_classes is not None:
            meta.setdefault("legacy", {})["n_label_classes"] = int(n_label_classes)
        if label_names is not None:
            meta.setdefault("legacy", {})["label_names"] = list(label_names)
        if label_head_name is not None:
            meta["label_head_name"] = str(label_head_name)

        class_heads_cfg = getattr(model, "class_heads_cfg", None)
        head_label_names = getattr(model, "head_label_names", None)

        if isinstance(class_heads_cfg, dict) and len(class_heads_cfg) > 0:
            meta.setdefault("multi", {})["heads"] = {k: dict(v) for k, v in class_heads_cfg.items()}
        if isinstance(head_label_names, dict) and len(head_label_names) > 0:
            meta.setdefault("multi", {})["label_names"] = {k: list(v) for k, v in head_label_names.items()}

        # A model-provided summary replaces the attribute-derived metadata.
        # Failure is tolerated (the attribute-derived metadata is kept), but
        # it is reported instead of being silently dropped.
        if hasattr(model, "get_classification_meta"):
            try:
                meta = dict(model.get_classification_meta())
            except Exception as exc:
                warnings.warn(
                    f"save_checkpoint: model.get_classification_meta() failed ({exc!r}); "
                    "falling back to attribute-derived label metadata.",
                    RuntimeWarning,
                    stacklevel=2,
                )

        if meta:
            payload["label_meta"] = meta
            payload["strict_label_compat"] = bool(strict_label_compat)

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(payload, path)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def load_checkpoint(path: str, *, map_location: Union[str, torch.device, None] = "cpu") -> Dict[str, Any]:
    """Read a checkpoint payload from ``path`` via ``torch.load``.

    Args:
        path: Checkpoint file to read.
        map_location: Forwarded unchanged to ``torch.load``.

    Returns:
        Whatever object ``torch.load`` deserializes from ``path``.
    """
    # NOTE(security): torch.load is pickle-based; only open checkpoints that
    # come from a trusted source.
    checkpoint = torch.load(path, map_location=map_location)
    return checkpoint
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def restore_checkpoint(
    payload_or_path: Union[str, Dict[str, Any]],
    *,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scaler: Optional[torch.cuda.amp.GradScaler] = None,
    map_location: Union[str, torch.device, None] = "cpu",
    strict: bool = True,
    restore_label_names: bool = True,
    enforce_label_compat: bool = True,
) -> Dict[str, Any]:
    """Load a checkpoint payload into ``model`` (and optionally optimizer/scaler).

    Args:
        payload_or_path: An already-loaded payload dict, or a path (``str`` or
            ``os.PathLike``) that is read with ``load_checkpoint``.
        model: Module receiving ``payload["model_state"]``.
        optimizer: If given and the payload has ``"optimizer_state"``, it is
            loaded into this optimizer.
        scaler: If given and the payload has a non-None ``"scaler_state"``, it
            is loaded on a best-effort basis.
        map_location: Forwarded to ``load_checkpoint`` when a path is given.
        strict: Forwarded to ``model.load_state_dict``.
        restore_label_names: Push stored label names back into the model via
            ``set_label_names`` / ``set_head_label_names`` when the model
            defines them (best effort).
        enforce_label_compat: Compare stored class counts with the model's
            before loading weights.

    Returns:
        The payload dict.

    Raises:
        ValueError: If ``enforce_label_compat`` is set and the stored
            ``n_label_classes`` or per-head ``n_classes`` disagree with the
            model, or the checkpoint has a head the model lacks.
    """
    import warnings  # local import: only used on best-effort fallback paths

    # Path objects must be treated as paths too, not as a payload mapping.
    if isinstance(payload_or_path, (str, os.PathLike)):
        payload = load_checkpoint(payload_or_path, map_location=map_location)
    else:
        payload = payload_or_path

    def _label_sections() -> tuple:
        """Return the (legacy, multi) metadata sections, empty if unavailable."""
        meta = payload.get("label_meta", {}) or {}
        if not isinstance(meta, dict):
            return {}, {}
        return meta.get("legacy", {}), meta.get("multi", {})

    if enforce_label_compat:
        legacy, multi = _label_sections()

        ckpt_n = legacy.get("n_label_classes", None)
        model_n = getattr(model, "n_label_classes", None)
        if ckpt_n is not None and model_n is not None and int(ckpt_n) != int(model_n):
            raise ValueError(
                f"Checkpoint n_label_classes={ckpt_n} does not match model n_label_classes={model_n}. "
                "Rebuild the model with the same n_label_classes."
            )

        ckpt_heads = multi.get("heads", None)
        model_heads = getattr(model, "class_heads_cfg", None)
        if isinstance(ckpt_heads, dict) and isinstance(model_heads, dict):
            for hk, hcfg in ckpt_heads.items():
                if hk not in model_heads:
                    raise ValueError(
                        f"Checkpoint contains head {hk!r} but model does not. "
                        f"Model heads: {list(model_heads.keys())}"
                    )
                ckpt_c = int(hcfg.get("n_classes", -1))
                model_c = int(model_heads[hk].get("n_classes", -1))
                if ckpt_c != model_c:
                    raise ValueError(f"Head {hk!r} n_classes mismatch: checkpoint={ckpt_c}, model={model_c}.")

    model.load_state_dict(payload["model_state"], strict=bool(strict))

    if optimizer is not None and "optimizer_state" in payload:
        optimizer.load_state_dict(payload["optimizer_state"])

    # The remaining steps are best effort: a failure must not abort the
    # restore, but it is surfaced as a warning instead of being hidden.
    if scaler is not None and payload.get("scaler_state") is not None:
        try:
            scaler.load_state_dict(payload["scaler_state"])
        except Exception as exc:
            warnings.warn(
                f"restore_checkpoint: could not restore scaler state ({exc!r}).",
                RuntimeWarning,
                stacklevel=2,
            )

    if restore_label_names:
        legacy, multi = _label_sections()

        label_names = legacy.get("label_names", None)
        if label_names is not None and hasattr(model, "set_label_names"):
            try:
                model.set_label_names(list(label_names))
            except Exception as exc:
                warnings.warn(
                    f"restore_checkpoint: could not restore label names ({exc!r}).",
                    RuntimeWarning,
                    stacklevel=2,
                )

        head_names = multi.get("label_names", None)
        if isinstance(head_names, dict) and hasattr(model, "set_head_label_names"):
            for hk, names in head_names.items():
                try:
                    model.set_head_label_names(str(hk), list(names))
                except Exception as exc:
                    warnings.warn(
                        f"restore_checkpoint: could not restore label names for head {hk!r} ({exc!r}).",
                        RuntimeWarning,
                        stacklevel=2,
                    )

    return payload
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
# =============================================================================
|
|
174
|
+
# JSON config helpers
|
|
175
|
+
# =============================================================================
|
|
176
|
+
|
|
177
|
+
def save_config_json(config: Dict[str, Any], path: str) -> None:
    """Write ``config`` to ``path`` as indented JSON with sorted keys.

    The parent directory of ``path`` is created if it does not exist.

    Raises:
        TypeError: If ``config`` contains values that are not
            JSON-serializable. An existing file at ``path`` is left untouched
            in that case.
    """
    # Serialize before opening the target: opening with "w" truncates, so a
    # serialization failure would otherwise destroy an existing config file.
    text = json.dumps(config, indent=2, sort_keys=True)
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def load_config(config_path: str) -> Dict[str, Any]:
    """Parse the JSON file at ``config_path`` and return its content.

    The file is decoded as UTF-8 rather than with the platform's locale
    encoding, so non-ASCII values read the same on every system.
    """
    with open(config_path, encoding="utf-8") as f:
        return json.load(f)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
# =============================================================================
|
|
189
|
+
# AnnData split helpers (SAFE + FLEXIBLE)
|
|
190
|
+
# =============================================================================
|
|
191
|
+
|
|
192
|
+
def _is_sequence_of_str(x: Any) -> bool:
|
|
193
|
+
return isinstance(x, (list, tuple, np.ndarray)) and len(x) > 0 and isinstance(x[0], (str, np.str_))
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _is_sequence_of_int(x: Any) -> bool:
|
|
197
|
+
return isinstance(x, (list, tuple, np.ndarray)) and len(x) > 0 and isinstance(x[0], (int, np.integer))
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _normalize_split_selector(
    adata: ad.AnnData,
    selector: Any,
    *,
    name: str,
) -> Union[np.ndarray, Sequence[str], Sequence[int]]:
    """
    Convert a selector into one of:
    - boolean mask (np.ndarray[bool] length n_obs)
    - list of obs_names (Sequence[str])
    - list of integer indices (Sequence[int])

    ``None`` and an empty list/tuple both yield an all-False mask. A boolean
    pandas Series and a list/tuple made only of booleans are treated as masks.

    Rejects AnnData objects or other iterables that could trigger accidental iteration.

    Args:
        adata: Object being split; only ``n_obs`` is read here.
        selector: The raw selector supplied by the caller.
        name: Label used as prefix in error messages.

    Raises:
        TypeError: For an AnnData selector or any unsupported selector type.
        ValueError: If a boolean mask does not have length ``n_obs``.
        IndexError: If an integer index is negative or >= ``n_obs``.
    """
    if selector is None:
        return np.zeros(adata.n_obs, dtype=bool)

    # HARD FAIL: AnnData passed where indices were expected
    if isinstance(selector, ad.AnnData):
        raise TypeError(
            f"{name}: expected indices/obs_names/mask, but got AnnData. "
            "Pass `splits={'train': adata_train, ...}` instead."
        )

    def _checked_mask(mask: np.ndarray) -> np.ndarray:
        """Return ``mask`` after verifying it has one entry per observation."""
        if mask.shape[0] != adata.n_obs:
            raise ValueError(f"{name}: boolean mask length {mask.shape[0]} != n_obs {adata.n_obs}.")
        return mask

    # bool mask
    if isinstance(selector, np.ndarray) and selector.dtype == bool:
        return _checked_mask(selector)

    # pandas Series of bool. pandas is optional, so only the import is
    # guarded; the length check must be able to raise to the caller.
    try:
        import pandas as pd
    except ImportError:
        pd = None
    if pd is not None and isinstance(selector, pd.Series) and selector.dtype == bool:
        return _checked_mask(selector.to_numpy())

    # list/tuple of booleans: a mask, not indices. This must be tested before
    # the integer branch because bool is a subclass of int.
    if (
        isinstance(selector, (list, tuple))
        and len(selector) > 0
        and all(isinstance(v, (bool, np.bool_)) for v in selector)
    ):
        return _checked_mask(np.asarray(selector, dtype=bool))

    # list/tuple/np array of strings: obs_names
    if _is_sequence_of_str(selector):
        return [str(s) for s in selector]

    # list/tuple/np array of ints: indices (non-empty by construction)
    if _is_sequence_of_int(selector):
        idx = [int(i) for i in selector]
        mn, mx = min(idx), max(idx)
        if mn < 0 or mx >= adata.n_obs:
            raise IndexError(f"{name}: index out of bounds (min={mn}, max={mx}, n_obs={adata.n_obs}).")
        return idx

    # empty list/tuple
    if isinstance(selector, (list, tuple)) and len(selector) == 0:
        return np.zeros(adata.n_obs, dtype=bool)

    raise TypeError(
        f"{name}: unsupported selector type {type(selector)}. "
        "Use obs_names (list[str]), indices (list[int]), or boolean mask (np.ndarray[bool]). "
        "If you already have split AnnData objects, pass `splits=...`."
    )
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _subset_adata(
    adata: ad.AnnData,
    selector: Union[np.ndarray, Sequence[str], Sequence[int]],
    *,
    copy: bool,
) -> ad.AnnData:
    """Index ``adata`` along observations with a normalized selector.

    Boolean masks are passed through as-is, obs_names as a list, and anything
    else as a list of row indices. The result is copied when ``copy`` is true.
    """
    is_mask = isinstance(selector, np.ndarray) and selector.dtype == bool
    if is_mask:
        subset = adata[selector]
    elif _is_sequence_of_str(selector):
        obs_names = list(selector)
        subset = adata[obs_names]
    else:
        row_indices = list(selector)
        subset = adata[row_indices, :]

    if copy:
        return subset.copy()
    return subset
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def _write_h5ad_safe(adata_obj: ad.AnnData, path: str, *, write_backed: bool = False) -> None:
    """Create the parent directory of ``path`` and write ``adata_obj`` there.

    ``write_backed`` is accepted for API compatibility and is not used.
    """
    parent_dir = os.path.dirname(path)
    if not parent_dir:
        # A bare file name has no directory part; use the current directory.
        parent_dir = "."
    os.makedirs(parent_dir, exist_ok=True)
    adata_obj.write_h5ad(path)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def save_anndata_splits(
    adata: Optional[ad.AnnData] = None,
    outdir: str = ".",
    *,
    prefix: str = "dataset",
    # Mode A: split from adata.obs[split_key]
    split_key: Optional[str] = "split",
    train_label: str = "train",
    val_label: str = "val",
    test_label: str = "test",
    # Mode B: split from selectors (indices/obs_names/mask)
    split_map: Optional[Dict[str, Any]] = None,
    # Mode C: already split objects (no resplitting)
    splits: Optional[Dict[str, ad.AnnData]] = None,
    # Behavior
    copy: bool = False,
    write_backed: bool = False,
    save_h5ad: bool = True,
    save_split_map: bool = True,
    split_map_name: Optional[str] = None,
    require_disjoint: bool = True,
) -> Dict[str, ad.AnnData]:
    """
    Save train/val/test splits to {outdir}/{prefix}_{train|val|test}.h5ad and optionally a split map JSON.

    You can provide EXACTLY ONE of:
    1) splits={...}                      (already split AnnData objects)
    2) adata + split_map={...}           (selectors: obs_names, indices, or boolean masks)
    3) adata + split_key in adata.obs    (labels in .obs)

    Safety: passing AnnData objects inside split_map is rejected with a clear error.

    With ``require_disjoint=True`` the three splits must not share
    observations. For ``split_map`` every selector is converted to row
    positions before comparison, so overlaps are also found when the selectors
    are of different kinds (for example obs_names for one split and indices
    for another).

    Args:
        adata: Full dataset; required unless ``splits`` is given.
        outdir: Output directory (created if missing).
        prefix: File-name prefix for the .h5ad files and the split map.
        split_key: Column of ``adata.obs`` holding split labels (mode 3).
        train_label, val_label, test_label: Label values selecting each split.
        split_map: Selectors under the keys "train"/"val"/"test" (mode 2);
            a missing key yields an empty split.
        splits: Pre-split objects under "train"/"val"/"test" (mode 1).
        copy: Copy the subsets instead of returning views.
        write_backed: Forwarded to the writer helper.
        save_h5ad: Write the three .h5ad files.
        save_split_map: Write the split map JSON (obs_names per split).
        split_map_name: File name of the split map; defaults to
            ``{prefix}_split_map.json``.
        require_disjoint: Raise if the splits overlap.

    Returns:
        ``{"train": ..., "val": ..., "test": ...}``.

    Raises:
        ValueError: On conflicting/missing inputs, missing split keys, a
            missing ``split_key`` column, or overlapping splits.
    """
    os.makedirs(outdir, exist_ok=True)

    def _assert_disjoint_names(tr, va, te, message: str) -> None:
        """Raise ValueError(message) if any two objects share an obs_name."""
        s1 = set(tr.obs_names)
        s2 = set(va.obs_names)
        s3 = set(te.obs_names)
        if (s1 & s2) or (s1 & s3) or (s2 & s3):
            raise ValueError(message)

    if splits is not None:
        if adata is not None and split_map is not None:
            raise ValueError("Provide only one of: splits=, (adata+split_map), or (adata+split_key).")
        missing = [k for k in ("train", "val", "test") if k not in splits]
        if missing:
            raise ValueError(f"splits is missing keys: {missing}. Expected train/val/test.")
        train = splits["train"]
        val = splits["val"]
        test = splits["test"]
        if require_disjoint:
            _assert_disjoint_names(
                train, val, test, "splits are not disjoint by obs_names (overlap found)."
            )
    else:
        if adata is None:
            raise ValueError("If splits is not provided, you must provide adata=...")

        if split_map is not None:
            sel_train = _normalize_split_selector(adata, split_map.get("train", None), name="split_map['train']")
            sel_val = _normalize_split_selector(adata, split_map.get("val", None), name="split_map['val']")
            sel_test = _normalize_split_selector(adata, split_map.get("test", None), name="split_map['test']")

            if require_disjoint:
                def _positions(sel) -> set:
                    """Express a normalized selector as a set of row positions."""
                    if isinstance(sel, np.ndarray) and sel.dtype == bool:
                        return set(np.flatnonzero(sel).tolist())
                    if _is_sequence_of_str(sel):
                        # Names absent from adata map to -1 and are skipped
                        # here; the subsetting below reports them.
                        found = adata.obs_names.get_indexer_for([str(s) for s in sel])
                        return {int(p) for p in found if p >= 0}
                    return {int(i) for i in sel}

                a = _positions(sel_train)
                b = _positions(sel_val)
                c = _positions(sel_test)
                if (a & b) or (a & c) or (b & c):
                    raise ValueError("split_map splits overlap (require_disjoint=True).")

            train = _subset_adata(adata, sel_train, copy=copy)
            val = _subset_adata(adata, sel_val, copy=copy)
            test = _subset_adata(adata, sel_test, copy=copy)

        else:
            if split_key is None or split_key not in adata.obs:
                raise ValueError(
                    f"Expected split labels in adata.obs[{split_key!r}], or provide split_map=..., or splits=...."
                )
            s = adata.obs[split_key].astype(str)
            train = adata[s == train_label].copy() if copy else adata[s == train_label]
            val = adata[s == val_label].copy() if copy else adata[s == val_label]
            test = adata[s == test_label].copy() if copy else adata[s == test_label]

            if require_disjoint:
                _assert_disjoint_names(
                    train, val, test, "split_key-derived splits are not disjoint by obs_names."
                )

    paths = {
        "train": os.path.join(outdir, f"{prefix}_train.h5ad"),
        "val": os.path.join(outdir, f"{prefix}_val.h5ad"),
        "test": os.path.join(outdir, f"{prefix}_test.h5ad"),
    }

    if save_h5ad:
        _write_h5ad_safe(train, paths["train"], write_backed=write_backed)
        _write_h5ad_safe(val, paths["val"], write_backed=write_backed)
        _write_h5ad_safe(test, paths["test"], write_backed=write_backed)

    if save_split_map:
        sm = {
            "train": train.obs_names.tolist(),
            "val": val.obs_names.tolist(),
            "test": test.obs_names.tolist(),
            "prefix": prefix,
        }
        # Record how the split was derived only in split_key mode.
        if splits is None and split_map is None and adata is not None:
            sm.update({
                "split_key": split_key,
                "train_label": train_label,
                "val_label": val_label,
                "test_label": test_label,
            })

        fn = split_map_name or f"{prefix}_split_map.json"
        with open(os.path.join(outdir, fn), "w") as f:
            json.dump(sm, f, indent=2)

    return {"train": train, "val": val, "test": test}
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
# =============================================================================
|
|
409
|
+
# Loading helpers (ADDED)
|
|
410
|
+
# =============================================================================
|
|
411
|
+
|
|
412
|
+
def load_split_map(outdir: str, prefix: str = "dataset", split_map_name: Optional[str] = None) -> Dict[str, Any]:
    """Load the split map JSON stored in ``outdir``.

    Args:
        outdir: Directory containing the split map.
        prefix: Used to build the default file name ``{prefix}_split_map.json``.
        split_map_name: Explicit file name; overrides the default when truthy.

    Returns:
        The parsed JSON content.
    """
    fn = split_map_name or f"{prefix}_split_map.json"
    path = os.path.join(outdir, fn)
    # Decode as UTF-8 explicitly so non-ASCII obs_names do not depend on the
    # platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def load_anndata_splits(
    outdir: str,
    prefix: str = "dataset",
    *,
    backed: Optional[Union[bool, str]] = None,
) -> Dict[str, ad.AnnData]:
    """
    Read the three split files {prefix}_{train|val|test}.h5ad found in outdir.

    backed:
      - None: normal in-memory load
      - "r": backed read-only (useful for huge files)
      - True: alias for "r"

    Raises FileNotFoundError (listing the missing splits) when any of the
    three files does not exist.
    """
    split_files = {
        "train": os.path.join(outdir, f"{prefix}_train.h5ad"),
        "val": os.path.join(outdir, f"{prefix}_val.h5ad"),
        "test": os.path.join(outdir, f"{prefix}_test.h5ad"),
    }

    missing = []
    for part, file_path in split_files.items():
        if not os.path.exists(file_path):
            missing.append(part)
    if missing:
        raise FileNotFoundError(f"Missing split files for prefix={prefix!r}: {missing}")

    # True is shorthand for read-only backed mode.
    mode = "r" if backed is True else backed

    loaded = {}
    for part, file_path in split_files.items():
        loaded[part] = ad.read_h5ad(file_path, backed=mode)
    return loaded
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
def subset_anndata_from_split_map(
    adata: ad.AnnData,
    split_map: Dict[str, Any],
    *,
    copy: bool = False,
    require_all: bool = True,
) -> Dict[str, ad.AnnData]:
    """
    Rebuild the train/val/test subsets of ``adata`` from a loaded split_map
    JSON (expects obs_names lists in split_map['train'/'val'/'test']).

    With ``require_all`` every one of the three keys must be present
    (KeyError otherwise). A missing or None entry is treated as an empty list;
    an entry that is not a list/tuple raises TypeError. Subsets are copied
    when ``copy`` is true.
    """
    split_keys = ("train", "val", "test")

    if require_all:
        absent = next((k for k in split_keys if k not in split_map), None)
        if absent is not None:
            raise KeyError(f"split_map missing key {absent!r}")

    # Validate and normalize all three name lists before touching adata.
    names_by_split = {}
    for k in split_keys:
        raw = split_map.get(k, [])
        if raw is None:
            raw = []
        if not isinstance(raw, (list, tuple)):
            raise TypeError(f"split_map[{k!r}] must be a list of obs_names.")
        names_by_split[k] = [str(item) for item in raw]

    subsets = {k: adata[names] for k, names in names_by_split.items()}
    if copy:
        subsets = {k: view.copy() for k, view in subsets.items()}
    return subsets
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def load_or_recreate_splits(
    outdir: str,
    prefix: str,
    *,
    adata: Optional[ad.AnnData] = None,
    backed: Optional[Union[bool, str]] = None,
    split_map_name: Optional[str] = None,
    copy: bool = False,
) -> Dict[str, ad.AnnData]:
    """
    Convenience loader:
      - when all of {prefix}_train/val/test.h5ad exist in outdir, load them
      - otherwise read the split map JSON and rebuild the splits from adata

    The split map is read before ``adata`` is checked, so a missing split map
    surfaces first. Raises ValueError when the fallback is needed but
    ``adata`` was not provided.
    """
    expected_files = (
        os.path.join(outdir, f"{prefix}_train.h5ad"),
        os.path.join(outdir, f"{prefix}_val.h5ad"),
        os.path.join(outdir, f"{prefix}_test.h5ad"),
    )

    have_all_files = True
    for file_path in expected_files:
        if not os.path.exists(file_path):
            have_all_files = False
            break

    if have_all_files:
        return load_anndata_splits(outdir, prefix=prefix, backed=backed)

    # fallback to split map + adata
    stored_map = load_split_map(outdir, prefix=prefix, split_map_name=split_map_name)
    if adata is not None:
        return subset_anndata_from_split_map(adata, stored_map, copy=copy)

    raise ValueError(
        f"Split .h5ad files not found for prefix={prefix!r}. "
        "Provide `adata=...` to recreate from split_map JSON."
    )
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
# =============================================================================
|
|
521
|
+
# Latent writer
|
|
522
|
+
# =============================================================================
|
|
523
|
+
|
|
524
|
+
def _select_X(adata_obj: ad.AnnData, layer: Optional[str], X_key: str):
|
|
525
|
+
if X_key != "X":
|
|
526
|
+
if X_key not in adata_obj.obsm:
|
|
527
|
+
raise KeyError(f"X_key={X_key!r} not found in adata.obsm. Keys={list(adata_obj.obsm.keys())}")
|
|
528
|
+
return adata_obj.obsm[X_key]
|
|
529
|
+
if layer is not None:
|
|
530
|
+
if layer not in adata_obj.layers:
|
|
531
|
+
raise KeyError(f"layer={layer!r} not found in adata.layers. Keys={list(adata_obj.layers.keys())}")
|
|
532
|
+
return adata_obj.layers[layer]
|
|
533
|
+
return adata_obj.X
|
|
534
|
+
|
|
535
|
+
|
|
536
|
+
@torch.no_grad()
def write_univi_latent(
    model,
    adata_dict: Dict[str, ad.AnnData],
    *,
    obsm_key: str = "X_univi",
    batch_size: int = 512,
    device: Optional[Union[str, torch.device]] = None,
    use_mean: bool = False,
    epoch: int = 0,
    y: Optional[Union[np.ndarray, torch.Tensor, Dict[str, Union[np.ndarray, torch.Tensor]]]] = None,
    layer: Union[None, str, Mapping[str, Optional[str]]] = None,
    X_key: Union[str, Mapping[str, str]] = "X",
    require_paired: bool = True,
) -> np.ndarray:
    """Encode all observations in batches and store the latent matrix in ``.obsm``.

    The model is switched to eval mode. For each batch, one float32 tensor per
    modality is built and passed to ``model.encode_fused`` when the model has
    it (the third returned value is used as the latent), otherwise to
    ``model(x_dict)`` (reading ``"mu_z"`` if ``use_mean`` and present, else
    ``"z"``). The stacked result is written to ``obsm[obsm_key]`` of every
    object in ``adata_dict``.

    Args:
        model: Module providing ``encode_fused`` or a dict-returning forward.
        adata_dict: Modality name -> AnnData. The first entry defines n_obs.
        obsm_key: Target key in ``.obsm``.
        batch_size: Number of observations per batch; must be > 0.
        device: Device for the input tensors; defaults to the device of the
            model's first parameter, or "cpu" if it has none.
        use_mean: Forwarded to ``encode_fused`` / selects ``"mu_z"``.
        epoch: Forwarded to ``encode_fused``.
        y: Optional labels (array/tensor, or dict of them), cast to long and
            sliced per batch; only passed to ``encode_fused``.
        layer: Layer name for all modalities, or a mapping per modality.
        X_key: "X" or an obsm key for all modalities, or a mapping per modality.
        require_paired: Require identical n_obs and obs_names across modalities.

    Returns:
        The latent matrix as a NumPy array with one row per observation.

    Raises:
        ValueError: For an empty ``adata_dict``, unpaired inputs, a
            non-positive ``batch_size``, or zero observations.
        KeyError: If a requested layer or obsm key is missing.
    """
    model.eval()

    names = list(adata_dict.keys())
    if len(names) == 0:
        raise ValueError("adata_dict is empty.")

    n = int(adata_dict[names[0]].n_obs)

    if require_paired:
        ref = adata_dict[names[0]].obs_names.values
        for nm in names[1:]:
            if adata_dict[nm].n_obs != n:
                raise ValueError(f"n_obs mismatch: {nm} has {adata_dict[nm].n_obs} vs {n}")
            if not np.array_equal(adata_dict[nm].obs_names.values, ref):
                raise ValueError(f"obs_names mismatch between {names[0]} and {nm}")

    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = "cpu"

    # Any Mapping (not only dict) is a per-modality specification, matching
    # the parameter annotations; a plain str/None applies to every modality.
    layer_by_mod = dict(layer) if isinstance(layer, Mapping) else {nm: layer for nm in names}
    xkey_by_mod = dict(X_key) if isinstance(X_key, Mapping) else {nm: X_key for nm in names}

    y_t = None
    if y is not None:
        if isinstance(y, dict):
            y_t = {str(k): (v if torch.is_tensor(v) else torch.as_tensor(v)).long() for k, v in y.items()}
        else:
            y_t = (y if torch.is_tensor(y) else torch.as_tensor(y)).long()

    bs = int(batch_size)
    if bs <= 0:
        raise ValueError("batch_size must be > 0")
    if n == 0:
        # Without this, np.vstack below fails with an unrelated-looking message.
        raise ValueError("adata_dict contains no observations (n_obs == 0).")

    # The source matrix of each modality does not change between batches, so
    # it is resolved once instead of once per batch.
    matrix_by_mod = {
        nm: _select_X(adata_dict[nm], layer_by_mod.get(nm, None), xkey_by_mod.get(nm, "X"))
        for nm in names
    }

    zs = []
    for start in range(0, n, bs):
        end = min(n, start + bs)

        x_dict = {}
        for nm in names:
            xb = matrix_by_mod[nm][start:end]
            if sp.issparse(xb):
                xb = xb.toarray()
            xb = np.asarray(xb)
            x_dict[nm] = torch.as_tensor(xb, dtype=torch.float32, device=device)

        yb = None
        if isinstance(y_t, dict):
            yb = {k: v[start:end].to(device) for k, v in y_t.items()}
        elif torch.is_tensor(y_t):
            yb = y_t[start:end].to(device)

        if hasattr(model, "encode_fused"):
            _mu_z, _logvar_z, z_use = model.encode_fused(
                x_dict, epoch=int(epoch), y=yb, use_mean=bool(use_mean)
            )
        else:
            out = model(x_dict)
            z_use = out["mu_z"] if (use_mean and ("mu_z" in out)) else out["z"]

        zs.append(z_use.detach().cpu().numpy())

    Z = np.vstack(zs)

    for nm in names:
        adata_dict[nm].obsm[obsm_key] = Z

    return Z
|
|
621
|
+
|
univi/utils/logging.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# univi/utils/logging.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def get_logger(name: str = "univi", level: int = logging.INFO) -> logging.Logger:
    """Return the logger called ``name``, set to ``level``.

    A stream handler with a timestamp/name/level format is attached only if
    the logger has no handlers yet, so repeated calls do not stack handlers.
    """
    log = logging.getLogger(name)

    already_configured = bool(log.handlers)
    if not already_configured:
        stream = logging.StreamHandler()
        stream.setFormatter(
            logging.Formatter("[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
        )
        log.addHandler(stream)

    log.setLevel(level)
    return log
|
univi/utils/seed.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# univi/utils/seed.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
import random
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def set_seed(seed: int = 0, deterministic: bool = False):
    """Seed Python's ``random``, NumPy and torch (plus all CUDA devices if available).

    With ``deterministic=True`` cuDNN is additionally set to deterministic
    mode and its benchmark mode is turned off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if not deterministic:
        return

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
univi/utils/stats.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# univi/utils/stats.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from typing import Dict
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def mean_dict(d: Dict[str, float]) -> float:
    """
    Average of the values of ``d``; NaN when ``d`` is empty.
    """
    if d:
        values = list(d.values())
        return float(np.mean(values))
    return float("nan")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def zscore(x: np.ndarray, axis: int = 0, eps: float = 1e-8) -> np.ndarray:
    """
    Standardize ``x`` along ``axis``: subtract the mean and divide by
    (standard deviation + ``eps``). ``eps`` keeps the division finite when
    the standard deviation is zero.
    """
    center = x.mean(axis=axis, keepdims=True)
    spread = x.std(axis=axis, keepdims=True)
    return (x - center) / (spread + eps)
|