univi 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
univi/data.py ADDED
@@ -0,0 +1,345 @@
1
+ # univi/data.py
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Dict, Mapping, Optional, Tuple, Union, List, Sequence
6
+
7
+ import os
8
+ import numpy as np
9
+ import scipy.sparse as sp
10
+ import pandas as pd
11
+ import anndata as ad
12
+ from anndata import AnnData
13
+
14
+ import torch
15
+ from torch.utils.data import Dataset
16
+
17
+ from .config import ModalityConfig
18
+
19
# Layer selection: None (use the default matrix), a single layer name applied to
# every modality, or a mapping modality name -> layer name (or None).
LayerSpec = Optional[Union[str, Mapping[str, Optional[str]]]]

# Matrix selection: a single key applied to every modality, or a mapping
# modality name -> key. The literal "X" selects the main matrix; any other
# value is looked up in ``.obsm``.
XKeySpec = Union[str, Mapping[str, str]]

# One vector of integer class labels (array, tensor or plain sequence) ...
_LabelVector = Union[np.ndarray, torch.Tensor, Sequence[int]]

# ... or several named label vectors (one per classification head).
LabelSpec = Union[_LabelVector, Mapping[str, _LabelVector]]
27
+
28
+
29
def _is_categorical_likelihood(lk: Optional[str]) -> bool:
    """Return True when *lk* names a categorical-style likelihood.

    The comparison ignores case and surrounding whitespace. ``None`` and the
    empty string are never categorical.
    """
    if not lk:
        return False
    normalized = lk.lower().strip()
    return normalized in {
        "categorical",
        "cat",
        "ce",
        "cross_entropy",
        "multinomial",
        "softmax",
    }
32
+
33
+
34
def _get_matrix(adata_obj: AnnData, *, layer: Optional[str], X_key: str):
    """Select the data matrix of *adata_obj* described by *layer* / *X_key*.

    Resolution order:
      1. ``X_key != "X"``  -> ``adata_obj.obsm[X_key]`` (``layer`` is ignored)
      2. ``layer`` given   -> ``adata_obj.layers[layer]``
      3. otherwise         -> ``adata_obj.X``

    Raises:
        KeyError: if the requested obsm key or layer does not exist.
    """
    if X_key == "X":
        if layer is None:
            return adata_obj.X
        layers = adata_obj.layers
        if layer not in layers:
            raise KeyError("layer=%r not found in adata.layers. Keys=%s" % (layer, list(layers.keys())))
        return layers[layer]

    # Any non-"X" key refers to an entry of .obsm.
    obsm = adata_obj.obsm
    if X_key not in obsm:
        raise KeyError("X_key=%r not found in adata.obsm. Keys=%s" % (X_key, list(obsm.keys())))
    return obsm[X_key]
46
+
47
+
48
def infer_input_dim(adata_obj: AnnData, *, layer: Optional[str], X_key: str) -> int:
    """Return the number of columns of the matrix selected by *layer* / *X_key*.

    Raises:
        ValueError: if the selected object has no ``shape`` or is not 2-D.
        KeyError: propagated from ``_get_matrix`` for unknown layer / obsm keys.
    """
    matrix = _get_matrix(adata_obj, layer=layer, X_key=X_key)
    is_2d = hasattr(matrix, "shape") and len(matrix.shape) == 2
    if not is_2d:
        raise ValueError("Selected matrix for (layer=%r, X_key=%r) is not 2D." % (layer, X_key))
    n_columns = matrix.shape[1]
    return int(n_columns)
53
+
54
+
55
def align_paired_obs_names(
    adata_dict: Dict[str, AnnData],
    how: str = "intersection",
    require_nonempty: bool = True,
    sort: bool = True,
    copy: bool = True,
) -> Dict[str, AnnData]:
    """Subset every modality to the obs_names shared by all modalities.

    Args:
        adata_dict: modality name -> AnnData.
        how: alignment strategy; only ``"intersection"`` is supported.
        require_nonempty: raise if no obs_name is shared by all modalities.
        sort: sort the shared obs_names before subsetting, so every modality
            ends up in the same order.
        copy: return copies of the subsets instead of the raw slices.

    Returns:
        A new dict with the same keys (same order), each value indexed by the
        shared obs_names.

    Raises:
        ValueError: on an empty dict, an unsupported ``how``, or an empty
            intersection when ``require_nonempty`` is set.
    """
    if not adata_dict:
        raise ValueError("adata_dict is empty")
    if how != "intersection":
        raise ValueError("Unsupported how=%r. Only 'intersection' is supported." % how)

    modality_names = list(adata_dict.keys())

    # Fold the obs_names of all modalities into their common subset.
    common = None
    for key in modality_names:
        current = adata_dict[key].obs_names
        if common is None:
            common = current
        else:
            common = common.intersection(current)

    if common is None:
        common = pd.Index([])

    if require_nonempty and len(common) == 0:
        raise ValueError("No shared obs_names across modalities (intersection is empty).")

    if sort:
        common = common.sort_values()

    aligned: Dict[str, AnnData] = {}
    for key in modality_names:
        subset = adata_dict[key][common, :]
        if copy:
            subset = subset.copy()
        aligned[key] = subset
    return aligned
87
+
88
+
89
def _as_modality_map(
    spec: Union[str, None, Mapping[str, Any]],
    adata_dict: Dict[str, AnnData],
    kind: str,
) -> Dict[str, Any]:
    """Expand a layer / X_key spec into an explicit per-modality dict.

    A non-mapping *spec* is broadcast to every modality. A mapping is copied
    and any modality it does not mention gets the default for *kind*:
    ``None`` when ``kind == "layer"``, otherwise ``"X"``.
    """
    fallback = None if kind == "layer" else "X"

    if not isinstance(spec, Mapping):
        # Scalar spec: same value for every modality.
        return dict.fromkeys(adata_dict, spec)

    resolved = dict(spec)
    for modality in adata_dict:
        resolved.setdefault(modality, fallback)
    return resolved
103
+
104
+
105
class MultiModalDataset(Dataset):
    """
    Multi-modal AnnData-backed torch Dataset.

    Returns (from ``__getitem__``):
      - x_dict: Dict[modality -> 1-D tensor of dtype ``dtype``]
      - (x_dict, y) if labels are provided, where y is:
          * LongTensor scalar (back-compat), OR
          * dict[str -> LongTensor scalar] (multi-head)

    Categorical modality support:
      - If modality_cfgs marks a modality as categorical with input_kind="obs",
        x_dict[modality] is a (1,) tensor holding the value read from
        ``adata.obs[obs_key]`` (the category code for pandas Categoricals).
    """

    def __init__(
        self,
        adata_dict: Dict[str, AnnData],
        layer: LayerSpec = None,
        X_key: XKeySpec = "X",
        paired: bool = True,
        device: Optional[torch.device] = None,
        labels: Optional[LabelSpec] = None,
        dtype: torch.dtype = torch.float32,
        modality_cfgs: Optional[List[ModalityConfig]] = None,
    ):
        """
        Args:
            adata_dict: modality name -> AnnData. Must be non-empty.
            layer: layer name (or per-modality mapping); None selects ``.X``.
            X_key: "X" or an ``.obsm`` key (or per-modality mapping).
            paired: if True, every modality must have the same n_obs and the
                same obs_names in the same order.
            device: optional device that samples and labels are moved to.
            labels: optional label vector, or dict of named label vectors,
                each of length n_cells.
            dtype: dtype of the per-modality sample tensors.
            modality_cfgs: optional modality configs, looked up by ``.name``.

        Raises:
            ValueError: on an empty dict, a paired-shape/order mismatch, or a
                label vector whose length differs from n_cells.
        """
        if not adata_dict:
            raise ValueError("adata_dict is empty")

        self.adata_dict: Dict[str, AnnData] = adata_dict
        self.modalities: List[str] = list(adata_dict.keys())
        self.paired = bool(paired)
        self.device = device
        self.dtype = dtype

        self.layer_by_modality: Dict[str, Optional[str]] = _as_modality_map(layer, adata_dict, kind="layer")
        self.xkey_by_modality: Dict[str, str] = _as_modality_map(X_key, adata_dict, kind="xkey")

        self.mod_cfg_by_name: Dict[str, ModalityConfig] = {}
        if modality_cfgs is not None:
            self.mod_cfg_by_name = {m.name: m for m in modality_cfgs}

        # The first modality defines the dataset length and the reference order.
        first = next(iter(adata_dict.values()))
        self._n_cells: int = int(first.n_obs)
        self._obs_names = first.obs_names

        if self.paired:
            for nm, adata_obj in self.adata_dict.items():
                if int(adata_obj.n_obs) != self._n_cells:
                    raise ValueError(
                        f"Paired dataset requires matching n_obs across modalities. "
                        f"First={self._n_cells}, {nm}={adata_obj.n_obs}"
                    )
                if not np.array_equal(adata_obj.obs_names.values, self._obs_names.values):
                    raise ValueError(
                        "Paired dataset requires identical obs_names order; %r differs. "
                        "Tip: use dataset_from_anndata_dict(..., align_obs=True)." % nm
                    )

        # Labels (optional): either a single vector or a dict of vectors
        self.labels: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None
        if labels is not None:
            if isinstance(labels, Mapping):
                self.labels = {
                    str(hk): self._coerce_labels(hv, f"labels[{hk!r}]")
                    for hk, hv in labels.items()
                }
            else:
                self.labels = self._coerce_labels(labels, "labels")

    def _coerce_labels(self, values: Any, what: str) -> torch.Tensor:
        """Convert one label vector to a flat LongTensor of length n_cells.

        *what* is the name used in the error message. The tensor is moved to
        ``self.device`` when a device was configured.
        """
        t = values if torch.is_tensor(values) else torch.as_tensor(values)
        if t.ndim != 1:
            t = t.reshape(-1)
        n = int(t.shape[0])
        if n != self._n_cells:
            raise ValueError(f"{what} length ({n}) must equal n_cells ({self._n_cells})")
        t = t.long()
        if self.device is not None:
            t = t.to(self.device)
        return t

    @property
    def n_cells(self) -> int:
        """Number of cells (n_obs of the first modality)."""
        return self._n_cells

    @property
    def obs_names(self):
        """obs_names of the first modality."""
        return self._obs_names

    def __len__(self) -> int:
        return self._n_cells

    def _get_X_row(self, adata_obj: AnnData, idx: int, layer: Optional[str], X_key: str) -> np.ndarray:
        """Return row *idx* of the selected matrix as a flat float32 array."""
        X = _get_matrix(adata_obj, layer=layer, X_key=X_key)
        if isinstance(X, pd.DataFrame):
            # ``frame[idx]`` would select a *column* label; rows need .iloc.
            row = X.iloc[idx].to_numpy()
        else:
            row = X[idx]
            if sp.issparse(row):
                row = row.toarray()
        return np.asarray(row).reshape(-1).astype(np.float32, copy=False)

    def _get_obs_label_row(self, adata_obj: AnnData, idx: int, obs_key: str) -> np.ndarray:
        """Return ``adata.obs[obs_key]`` at row *idx* as a (1,) float32 array.

        Categorical columns yield their integer category code (pandas uses
        code -1 for missing values, which is passed through unchanged).
        Non-categorical columns must hold Python/NumPy ints or floats.

        Raises:
            KeyError: if *obs_key* is not a column of ``adata.obs``.
            TypeError: if the value is neither numeric nor categorical.
        """
        if obs_key not in adata_obj.obs:
            raise KeyError(f"obs_key={obs_key!r} not found in adata.obs columns.")

        col = adata_obj.obs[obs_key]

        # isinstance on the dtype replaces the deprecated
        # pd.api.types.is_categorical_dtype helper.
        if isinstance(col.dtype, pd.CategoricalDtype):
            v = int(col.cat.codes.iloc[idx])
            return np.asarray([v], dtype=np.float32)

        v = col.iloc[idx]
        if isinstance(v, (np.integer, int)):
            return np.asarray([int(v)], dtype=np.float32)
        if isinstance(v, (np.floating, float)):
            return np.asarray([float(v)], dtype=np.float32)

        raise TypeError(
            f"adata.obs[{obs_key!r}] must be numeric integer codes (or pandas Categorical). "
            f"Got type {type(v)} at row {idx}. Encode categories to int codes first."
        )

    def __getitem__(self, idx: int):
        """Return ``x_dict`` or ``(x_dict, y)`` for cell *idx* (see class docstring)."""
        x_dict: Dict[str, torch.Tensor] = {}

        for name, adata_obj in self.adata_dict.items():
            mcfg = self.mod_cfg_by_name.get(name, None)

            # Categorical modalities configured with input_kind="obs" are read
            # from an obs column; everything else comes from a matrix row.
            reads_from_obs = (
                mcfg is not None
                and _is_categorical_likelihood(mcfg.likelihood)
                and (mcfg.input_kind or "matrix") == "obs"
            )
            if reads_from_obs:
                if not mcfg.obs_key:
                    raise ValueError(f"Modality {name!r}: input_kind='obs' requires obs_key.")
                row_np = self._get_obs_label_row(adata_obj, idx, obs_key=mcfg.obs_key)
            else:
                layer = self.layer_by_modality.get(name, None)
                xkey = self.xkey_by_modality.get(name, "X")
                row_np = self._get_X_row(adata_obj, idx, layer=layer, X_key=xkey)

            x = torch.from_numpy(row_np).to(dtype=self.dtype)
            if self.device is not None:
                x = x.to(self.device)
            x_dict[name] = x

        if self.labels is None:
            return x_dict

        if isinstance(self.labels, dict):
            y_out: Dict[str, torch.Tensor] = {k: v[idx] for k, v in self.labels.items()}
            return x_dict, y_out

        return x_dict, self.labels[idx]
262
+
263
+
264
def dataset_from_anndata_dict(
    adata_dict: Dict[str, AnnData],
    layer: LayerSpec = None,
    X_key: XKeySpec = "X",
    paired: bool = True,
    align_obs: bool = True,
    labels: Optional[LabelSpec] = None,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
    copy_aligned: bool = True,
    modality_cfgs: Optional[List[ModalityConfig]] = None,
) -> Tuple[MultiModalDataset, Dict[str, AnnData]]:
    """Build a MultiModalDataset, optionally aligning obs_names first.

    When both ``paired`` and ``align_obs`` are set, every modality is first
    restricted to the shared obs_names via ``align_paired_obs_names``
    (``copy_aligned`` controls whether the subsets are copied). ``labels`` is
    forwarded unchanged, so it must already match the cells and order of the
    dict that the dataset is finally built from.

    Returns:
        ``(dataset, adata_dict_used)`` where the second item is the (possibly
        aligned) dict backing the dataset.
    """
    used = adata_dict
    if paired and align_obs:
        used = align_paired_obs_names(adata_dict, how="intersection", copy=copy_aligned)

    dataset = MultiModalDataset(
        adata_dict=used,
        layer=layer,
        X_key=X_key,
        paired=paired,
        device=device,
        labels=labels,
        dtype=dtype,
        modality_cfgs=modality_cfgs,
    )
    return dataset, used
290
+
291
+
292
def load_anndata_dict_from_config(
    modality_cfgs: List[Dict[str, Any]],
    data_root: Optional[str] = None,
) -> Dict[str, AnnData]:
    """Read one ``.h5ad`` file per modality config entry.

    Args:
        modality_cfgs: dicts that each contain ``"name"`` and ``"h5ad_path"``.
        data_root: optional directory that relative ``h5ad_path`` values are
            joined onto; absolute paths are used as given.

    Returns:
        modality name -> loaded AnnData (a repeated name keeps the last file).

    Raises:
        KeyError: if an entry lacks ``"name"`` or ``"h5ad_path"``.
        ValueError: if nothing was loaded.
    """
    loaded: Dict[str, AnnData] = {}

    for entry in modality_cfgs:
        if not ("name" in entry and "h5ad_path" in entry):
            raise KeyError("Each modality config must contain keys: 'name' and 'h5ad_path'.")

        h5ad_path = entry["h5ad_path"]
        is_relative_to_root = data_root is not None and not os.path.isabs(h5ad_path)
        resolved = os.path.join(data_root, h5ad_path) if is_relative_to_root else h5ad_path

        loaded[entry["name"]] = ad.read_h5ad(resolved)

    if len(loaded) == 0:
        raise ValueError("No modalities loaded (empty modality_cfgs?)")

    return loaded
313
+
314
+
315
def collate_multimodal_xy(batch):
    """
    Collate a list of dataset samples into batched tensors.

    - works for [x_dict, ...] or [(x_dict, y), ...]
    - stacks per-modality tensors along a new leading batch dimension
    - supports y as:
        * scalar tensor/int
        * dict[str -> scalar tensor/int]

    Returns:
        ``x`` (dict modality -> stacked tensor) when samples carry no labels,
        otherwise ``(x, y)`` where ``y`` is a LongTensor or a dict of
        LongTensors keyed like the first sample's label dict.

    Raises:
        ValueError: if *batch* is empty (there is nothing to infer the sample
            layout from).
    """
    if len(batch) == 0:
        raise ValueError("Cannot collate an empty batch")

    first = batch[0]
    if isinstance(first, (tuple, list)) and len(first) == 2:
        xs, ys = zip(*batch)

        y0 = ys[0]
        if isinstance(y0, Mapping):
            # Multi-head labels: the first sample defines the set of heads.
            y = {
                str(k): torch.stack(
                    [torch.as_tensor(yy[k], dtype=torch.long) for yy in ys], dim=0
                )
                for k in list(y0.keys())
            }
        else:
            y = torch.stack([torch.as_tensor(yy, dtype=torch.long) for yy in ys], dim=0)
    else:
        xs, y = batch, None

    # The first sample defines which modalities are stacked.
    x = {k: torch.stack([d[k] for d in xs], dim=0) for k in xs[0].keys()}
    return x if y is None else (x, y)
344
+
345
+
univi/diagnostics.py ADDED
@@ -0,0 +1,130 @@
1
+ # univi/diagnostics.py
2
+ from __future__ import annotations
3
+
4
+ from typing import Any, Dict, Optional, List
5
+ import os
6
+ import platform
7
+ import importlib
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ from anndata import AnnData
12
+
13
+ from .data import _get_matrix
14
+ from .utils.io import load_config
15
+
16
+
17
def _safe_version(pkg: str) -> str:
    """Best-effort lookup of an importable package's ``__version__``.

    Returns ``"unknown"`` when the module imports but exposes no
    ``__version__`` and ``"not_installed"`` when importing it (or reading the
    attribute) raises any ``Exception``.
    """
    try:
        module = importlib.import_module(pkg)
        version = getattr(module, "__version__", "unknown")
    except Exception:
        # Deliberately broad: a broken optional dependency must not crash
        # environment reporting.
        version = "not_installed"
    return version
23
+
24
+
25
def collect_environment_info() -> Dict[str, Any]:
    """Return the Python/platform versions plus versions of key packages.

    Keys are ``"python"``, ``"platform"`` and one entry per probed package;
    package values come from ``_safe_version``.
    """
    probed_packages = (
        "numpy",
        "scipy",
        "pandas",
        "anndata",
        "scanpy",
        "torch",
        "sklearn",
        "h5py",
        "matplotlib",
        "seaborn",
    )

    info: Dict[str, Any] = {
        "python": platform.python_version(),
        "platform": platform.platform(),
    }
    for pkg in probed_packages:
        info[pkg] = _safe_version(pkg)
    return info
40
+
41
+
42
def dataset_stats_table(
    adata_dict: Dict[str, AnnData],
    *,
    layer_by: Optional[Dict[str, Optional[str]]] = None,
    xkey_by: Optional[Dict[str, str]] = None,
) -> pd.DataFrame:
    """Summarise each modality as one row of a DataFrame.

    Args:
        adata_dict: modality name -> AnnData.
        layer_by: optional modality -> layer name; missing entries mean no layer.
        xkey_by: optional modality -> X_key; missing entries mean ``"X"``.

    Returns:
        DataFrame with columns ``modality, n_cells, n_features, X_key, layer``
        (``layer`` is ``""`` when no layer is selected). The columns are
        present even when *adata_dict* is empty.
    """
    columns = ["modality", "n_cells", "n_features", "X_key", "layer"]
    rows = []
    for nm, adata in adata_dict.items():
        layer = None if layer_by is None else layer_by.get(nm, None)
        xkey = "X" if xkey_by is None else xkey_by.get(nm, "X")
        X = _get_matrix(adata, layer=layer, X_key=xkey)
        rows.append(
            {
                "modality": nm,
                "n_cells": int(adata.n_obs),
                "n_features": int(X.shape[1]),
                "X_key": xkey,
                "layer": layer if layer is not None else "",
            }
        )
    # Explicit columns keep the schema stable for an empty input.
    return pd.DataFrame(rows, columns=columns)
63
+
64
+
65
def model_hparams_table(cfg: Dict[str, Any]) -> pd.DataFrame:
    """Flatten a curated set of config entries into a long-format table.

    Reads ``cfg["model"]``, ``cfg["training"]`` and
    ``cfg["data"]["modalities"]``; a section that is missing or ``None`` is
    treated as empty. Only the curated keys below are reported, and every
    value is converted with ``str``.

    Returns:
        DataFrame with columns ``section, key, value`` (present even when no
        entry matched). Per-modality rows use the section ``data.<name>``.
    """
    model_keys = (
        "loss_mode",
        "v1_recon",
        "v1_recon_mix",
        "normalize_v1_terms",
        "latent_dim",
        "beta",
        "gamma",
        "hidden_dims_default",
        "dropout",
        "encoder_dropout",
        "decoder_dropout",
        "batchnorm",
        "encoder_batchnorm",
        "decoder_batchnorm",
        "kl_anneal_start",
        "kl_anneal_end",
        "align_anneal_start",
        "align_anneal_end",
    )
    training_keys = (
        "n_epochs",
        "batch_size",
        "lr",
        "weight_decay",
        "device",
        "seed",
        "num_workers",
        "early_stopping",
        "patience",
        "min_delta",
    )
    modality_keys = ("likelihood", "layer", "X_key", "hidden_dims", "encoder_hidden", "decoder_hidden")

    # ``or {}`` also covers sections present but set to None.
    model = cfg.get("model") or {}
    training = cfg.get("training") or {}
    modalities = (cfg.get("data") or {}).get("modalities") or []

    rows = [{"section": "model", "key": k, "value": str(model[k])} for k in model_keys if k in model]
    rows.extend(
        {"section": "training", "key": k, "value": str(training[k])}
        for k in training_keys
        if k in training
    )

    # per-modality entries
    for m in modalities:
        name = m.get("name", "modality")
        rows.extend(
            {"section": f"data.{name}", "key": k, "value": str(m[k])}
            for k in modality_keys
            if k in m
        )

    # Explicit columns keep the schema stable when nothing matched.
    return pd.DataFrame(rows, columns=["section", "key", "value"])
106
+
107
+
108
def export_supplemental_table_s1(
    config_path: str,
    adata_dict: Dict[str, AnnData],
    *,
    out_xlsx: str,
    layer_by: Optional[Dict[str, Optional[str]]] = None,
    xkey_by: Optional[Dict[str, str]] = None,
    extra_metrics: Optional[Dict[str, Any]] = None,
):
    """Write the supplemental workbook to *out_xlsx* using openpyxl.

    Sheets, in order: ``environment`` (one row of versions),
    ``hyperparameters`` (from the config at *config_path*), ``datasets``
    (per-modality stats) and, only when *extra_metrics* is non-empty,
    ``metrics`` (one row). The parent directory of *out_xlsx* is created if
    needed.
    """
    config = load_config(config_path)
    environment_frame = pd.DataFrame([collect_environment_info()])
    hparams_frame = model_hparams_table(config)
    datasets_frame = dataset_stats_table(adata_dict, layer_by=layer_by, xkey_by=xkey_by)

    # A bare filename has no directory part; fall back to the current directory.
    target_dir = os.path.dirname(out_xlsx) or "."
    os.makedirs(target_dir, exist_ok=True)

    fixed_sheets = (
        ("environment", environment_frame),
        ("hyperparameters", hparams_frame),
        ("datasets", datasets_frame),
    )
    with pd.ExcelWriter(out_xlsx, engine="openpyxl") as writer:
        for sheet_name, frame in fixed_sheets:
            frame.to_excel(writer, index=False, sheet_name=sheet_name)
        if extra_metrics:
            metrics_frame = pd.DataFrame([extra_metrics])
            metrics_frame.to_excel(writer, index=False, sheet_name="metrics")