univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/data.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
1
|
+
# univi/data.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Mapping, Optional, Tuple, Union, List, Sequence
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import numpy as np
|
|
9
|
+
import scipy.sparse as sp
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import anndata as ad
|
|
12
|
+
from anndata import AnnData
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
from torch.utils.data import Dataset
|
|
16
|
+
|
|
17
|
+
from .config import ModalityConfig
|
|
18
|
+
|
|
19
|
+
# Layer selection: None (read ``.X``), one layer name applied to every
# modality, or a {modality: layer-or-None} mapping.
LayerSpec = Union[None, str, Mapping[str, Optional[str]]]

# Matrix source: "X" or an ``.obsm`` key, for all modalities or per modality.
XKeySpec = Union[str, Mapping[str, str]]

# One label vector, or {head_name: label vector} for multi-head labels.
_LabelVector = Union[np.ndarray, torch.Tensor, Sequence[int]]
LabelSpec = Union[_LabelVector, Mapping[str, _LabelVector]]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _is_categorical_likelihood(lk: Optional[str]) -> bool:
|
|
30
|
+
lk = (lk or "").lower().strip()
|
|
31
|
+
return lk in ("categorical", "cat", "ce", "cross_entropy", "multinomial", "softmax")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _get_matrix(adata_obj: AnnData, *, layer: Optional[str], X_key: str):
|
|
35
|
+
if X_key != "X":
|
|
36
|
+
if X_key not in adata_obj.obsm:
|
|
37
|
+
raise KeyError("X_key=%r not found in adata.obsm. Keys=%s" % (X_key, list(adata_obj.obsm.keys())))
|
|
38
|
+
return adata_obj.obsm[X_key]
|
|
39
|
+
|
|
40
|
+
if layer is not None:
|
|
41
|
+
if layer not in adata_obj.layers:
|
|
42
|
+
raise KeyError("layer=%r not found in adata.layers. Keys=%s" % (layer, list(adata_obj.layers.keys())))
|
|
43
|
+
return adata_obj.layers[layer]
|
|
44
|
+
|
|
45
|
+
return adata_obj.X
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def infer_input_dim(adata_obj: AnnData, *, layer: Optional[str], X_key: str) -> int:
    """Return the number of columns of the matrix selected by ``layer``/``X_key``.

    The matrix is resolved with ``_get_matrix``.

    Raises:
        KeyError: propagated from ``_get_matrix`` for a missing layer/obsm key.
        ValueError: if the selected object has no ``shape`` or is not 2D.
    """
    mat = _get_matrix(adata_obj, layer=layer, X_key=X_key)
    if hasattr(mat, "shape") and len(mat.shape) == 2:
        return int(mat.shape[1])
    raise ValueError("Selected matrix for (layer=%r, X_key=%r) is not 2D." % (layer, X_key))
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def align_paired_obs_names(
    adata_dict: Dict[str, AnnData],
    how: str = "intersection",
    require_nonempty: bool = True,
    sort: bool = True,
    copy: bool = True,
) -> Dict[str, AnnData]:
    """Subset every modality to the obs_names shared by all of them.

    Args:
        adata_dict: Non-empty mapping modality name -> AnnData.
        how: Only "intersection" is supported.
        require_nonempty: Raise if no obs_name is shared by all modalities.
        sort: Sort the shared names (``Index.sort_values``) before slicing, so
            every output modality uses the same row order.
        copy: Return ``.copy()`` of each slice instead of the slice itself.

    Returns:
        New dict with the same keys (same order), each value row-subset to the
        shared obs_names.

    Raises:
        ValueError: empty ``adata_dict``, unsupported ``how``, or an empty
            intersection while ``require_nonempty`` is True.
    """
    if not adata_dict:
        raise ValueError("adata_dict is empty")
    if how != "intersection":
        raise ValueError("Unsupported how=%r. Only 'intersection' is supported." % how)

    # Fold the obs_names indices together, starting from the first modality.
    modalities = iter(adata_dict.values())
    common = next(modalities).obs_names
    for other in modalities:
        common = common.intersection(other.obs_names)

    if require_nonempty and len(common) == 0:
        raise ValueError("No shared obs_names across modalities (intersection is empty).")

    if sort:
        common = common.sort_values()

    aligned: Dict[str, AnnData] = {}
    for modality, adata_obj in adata_dict.items():
        subset = adata_obj[common, :]
        aligned[modality] = subset.copy() if copy else subset
    return aligned
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _as_modality_map(
|
|
90
|
+
spec: Union[str, None, Mapping[str, Any]],
|
|
91
|
+
adata_dict: Dict[str, AnnData],
|
|
92
|
+
kind: str,
|
|
93
|
+
) -> Dict[str, Any]:
|
|
94
|
+
if isinstance(spec, Mapping):
|
|
95
|
+
out = dict(spec)
|
|
96
|
+
else:
|
|
97
|
+
out = {k: spec for k in adata_dict.keys()}
|
|
98
|
+
|
|
99
|
+
for k in adata_dict.keys():
|
|
100
|
+
if k not in out:
|
|
101
|
+
out[k] = None if kind == "layer" else "X"
|
|
102
|
+
return out
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class MultiModalDataset(Dataset):
    """
    Multi-modal AnnData-backed torch Dataset.

    Item ``idx`` is assembled by reading row ``idx`` from every modality in
    ``adata_dict``. The same positional index is used for all modalities.

    Returns:
      - x_dict: Dict[modality -> tensor of ``dtype``]
      - (x_dict, y) if labels are provided, where y is:
          * LongTensor scalar (back-compat), OR
          * dict[str -> LongTensor scalar] (multi-head)

    Categorical modality support:
      - If modality_cfgs marks a modality as categorical with input_kind="obs",
        x_dict[modality] is a (1,) tensor holding an integer code. The code is
        read from ``adata.obs[obs_key]``: pandas Categorical codes, or an
        int/float column.

    Args:
        adata_dict: Non-empty mapping modality name -> AnnData.
        layer: None, a single layer name, or a per-modality mapping. Modalities
            not mentioned default to None (use ``.X``).
        X_key: "X", an ``.obsm`` key, or a per-modality mapping. Modalities not
            mentioned default to "X". A non-"X" key takes precedence over
            ``layer``.
        paired: If True, all modalities must have the same n_obs and identical
            obs_names order.
        device: Optional device. Labels are moved once at construction; feature
            rows are moved per item in ``__getitem__``.
        labels: Optional label vector, or dict head -> label vector. Each is
            flattened, cast to int64, and must have length n_cells.
        dtype: Torch dtype of the feature tensors.
        modality_cfgs: Optional ModalityConfig list, matched to modalities by
            ``.name``.

    Raises:
        ValueError: empty ``adata_dict``, paired n_obs/obs_names mismatch, or a
            label vector whose length differs from n_cells.

    NOTE(review): n_cells / __len__ come from the FIRST modality. With
    paired=False, __getitem__ still indexes every modality with the same idx,
    so modalities with fewer rows can raise IndexError. Confirm this is intended.
    """

    def __init__(
        self,
        adata_dict: Dict[str, AnnData],
        layer: LayerSpec = None,
        X_key: XKeySpec = "X",
        paired: bool = True,
        device: Optional[torch.device] = None,
        labels: Optional[LabelSpec] = None,
        dtype: torch.dtype = torch.float32,
        modality_cfgs: Optional[List[ModalityConfig]] = None,
    ):
        if not adata_dict:
            raise ValueError("adata_dict is empty")

        self.adata_dict: Dict[str, AnnData] = adata_dict
        self.modalities: List[str] = list(adata_dict.keys())
        self.paired = bool(paired)
        self.device = device
        self.dtype = dtype

        self.layer_by_modality: Dict[str, Optional[str]] = _as_modality_map(layer, adata_dict, kind="layer")
        self.xkey_by_modality: Dict[str, str] = _as_modality_map(X_key, adata_dict, kind="xkey")

        self.mod_cfg_by_name: Dict[str, ModalityConfig] = {}
        if modality_cfgs is not None:
            self.mod_cfg_by_name = {m.name: m for m in modality_cfgs}

        # The first modality defines the dataset length and reference obs order.
        first = next(iter(adata_dict.values()))
        self._n_cells: int = int(first.n_obs)
        self._obs_names = first.obs_names

        if self.paired:
            for nm, adata_obj in self.adata_dict.items():
                if int(adata_obj.n_obs) != self._n_cells:
                    raise ValueError(
                        f"Paired dataset requires matching n_obs across modalities. "
                        f"First={self._n_cells}, {nm}={adata_obj.n_obs}"
                    )
                if not np.array_equal(adata_obj.obs_names.values, self._obs_names.values):
                    raise ValueError(
                        "Paired dataset requires identical obs_names order; %r differs. "
                        "Tip: use dataset_from_anndata_dict(..., align_obs=True)." % nm
                    )

        # Labels (optional): either a single vector or a dict of vectors.
        self.labels: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None
        if labels is not None:
            if isinstance(labels, Mapping):
                self.labels = {
                    str(hk): self._to_label_tensor(hv, what=f"labels[{hk!r}]")
                    for hk, hv in labels.items()
                }
            else:
                self.labels = self._to_label_tensor(labels, what="labels")

    def _to_label_tensor(self, values: Any, *, what: str) -> torch.Tensor:
        """Convert one label vector to a flat int64 tensor of length n_cells.

        The tensor is moved to ``self.device`` when one is set. ``what`` is the
        prefix used in the length-mismatch error message.

        Raises:
            ValueError: if the flattened length differs from n_cells.
        """
        t = values if torch.is_tensor(values) else torch.as_tensor(values)
        if t.ndim != 1:
            t = t.reshape(-1)
        if int(t.shape[0]) != self._n_cells:
            raise ValueError(f"{what} length ({int(t.shape[0])}) must equal n_cells ({self._n_cells})")
        t = t.long()
        if self.device is not None:
            t = t.to(self.device)
        return t

    @property
    def n_cells(self) -> int:
        return self._n_cells

    @property
    def obs_names(self):
        return self._obs_names

    def __len__(self) -> int:
        return self._n_cells

    def _get_X_row(self, adata_obj: AnnData, idx: int, layer: Optional[str], X_key: str) -> np.ndarray:
        """Return row ``idx`` of the selected matrix as a flat float32 numpy array.

        Sparse rows are densified.
        """
        X = _get_matrix(adata_obj, layer=layer, X_key=X_key)
        if isinstance(X, pd.DataFrame):
            # AnnData allows DataFrames in .obsm. For those, X[idx] would be a
            # column-label lookup, not a row selection, so index positionally.
            row = X.iloc[idx].to_numpy()
        else:
            row = X[idx]
        if sp.issparse(row):
            row = row.toarray()
        return np.asarray(row).reshape(-1).astype(np.float32, copy=False)

    def _get_obs_label_row(self, adata_obj: AnnData, idx: int, obs_key: str) -> np.ndarray:
        """Return ``adata_obj.obs[obs_key]`` at row ``idx`` as a (1,) float32 array.

        Categorical columns yield their integer code. pandas uses -1 for missing
        values, and that code is passed through unchanged. Integer and float
        scalars are returned as-is (cast to float32).

        Raises:
            KeyError: if ``obs_key`` is not an obs column.
            TypeError: if the value is neither categorical nor numeric.
        """
        if obs_key not in adata_obj.obs:
            raise KeyError(f"obs_key={obs_key!r} not found in adata.obs columns.")

        col = adata_obj.obs[obs_key]

        # isinstance check replaces the deprecated pd.api.types.is_categorical_dtype.
        if isinstance(col.dtype, pd.CategoricalDtype):
            v = int(col.cat.codes.iloc[idx])
            return np.asarray([v], dtype=np.float32)

        v = col.iloc[idx]
        if isinstance(v, (np.integer, int)):
            return np.asarray([int(v)], dtype=np.float32)
        if isinstance(v, (np.floating, float)):
            return np.asarray([float(v)], dtype=np.float32)

        raise TypeError(
            f"adata.obs[{obs_key!r}] must be numeric integer codes (or pandas Categorical). "
            f"Got type {type(v)} at row {idx}. Encode categories to int codes first."
        )

    def __getitem__(self, idx: int):
        x_dict: Dict[str, torch.Tensor] = {}

        for name, adata_obj in self.adata_dict.items():
            mcfg = self.mod_cfg_by_name.get(name, None)

            if (
                mcfg is not None
                and _is_categorical_likelihood(mcfg.likelihood)
                and (mcfg.input_kind or "matrix") == "obs"
            ):
                if not mcfg.obs_key:
                    raise ValueError(f"Modality {name!r}: input_kind='obs' requires obs_key.")
                row_np = self._get_obs_label_row(adata_obj, idx, obs_key=mcfg.obs_key)
            else:
                layer = self.layer_by_modality.get(name, None)
                xkey = self.xkey_by_modality.get(name, "X")
                row_np = self._get_X_row(adata_obj, idx, layer=layer, X_key=xkey)

            x = torch.from_numpy(row_np).to(dtype=self.dtype)
            if self.device is not None:
                x = x.to(self.device)
            x_dict[name] = x

        if self.labels is None:
            return x_dict

        if isinstance(self.labels, dict):
            y_out: Dict[str, torch.Tensor] = {k: v[idx] for k, v in self.labels.items()}
            return x_dict, y_out

        return x_dict, self.labels[idx]
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
def dataset_from_anndata_dict(
    adata_dict: Dict[str, AnnData],
    layer: LayerSpec = None,
    X_key: XKeySpec = "X",
    paired: bool = True,
    align_obs: bool = True,
    labels: Optional[LabelSpec] = None,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
    copy_aligned: bool = True,
    modality_cfgs: Optional[List[ModalityConfig]] = None,
) -> Tuple[MultiModalDataset, Dict[str, AnnData]]:
    """Build a ``MultiModalDataset``, optionally aligning paired obs first.

    When both ``paired`` and ``align_obs`` are truthy, ``adata_dict`` is first
    replaced by ``align_paired_obs_names(..., how="intersection",
    copy=copy_aligned)``. The remaining arguments are forwarded unchanged to
    ``MultiModalDataset``.

    Returns:
        (dataset, adata_dict), where adata_dict is the (possibly aligned) dict
        the dataset was built from.
    """
    if paired and align_obs:
        adata_dict = align_paired_obs_names(adata_dict, how="intersection", copy=copy_aligned)

    dataset = MultiModalDataset(
        adata_dict=adata_dict,
        layer=layer,
        X_key=X_key,
        paired=paired,
        device=device,
        labels=labels,
        dtype=dtype,
        modality_cfgs=modality_cfgs,
    )
    return dataset, adata_dict
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def load_anndata_dict_from_config(
    modality_cfgs: List[Dict[str, Any]],
    data_root: Optional[str] = None,
) -> Dict[str, AnnData]:
    """Read one ``.h5ad`` file per modality config entry.

    Each entry must provide ``"name"`` and ``"h5ad_path"``. A relative path is
    joined onto ``data_root`` when one is given; absolute paths are used as-is.
    A later entry with the same name replaces an earlier one.

    Raises:
        KeyError: if an entry lacks ``"name"`` or ``"h5ad_path"``.
        ValueError: if nothing was loaded (empty ``modality_cfgs``).
    """
    loaded: Dict[str, AnnData] = {}
    for entry in modality_cfgs:
        if not ("name" in entry and "h5ad_path" in entry):
            raise KeyError("Each modality config must contain keys: 'name' and 'h5ad_path'.")

        modality = entry["name"]
        h5ad_path = entry["h5ad_path"]
        needs_root = data_root is not None and not os.path.isabs(h5ad_path)
        resolved = os.path.join(data_root, h5ad_path) if needs_root else h5ad_path

        loaded[modality] = ad.read_h5ad(resolved)

    if not loaded:
        raise ValueError("No modalities loaded (empty modality_cfgs?)")
    return loaded
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def collate_multimodal_xy(batch):
    """Collate dataset items into batched tensors.

    Accepts either ``[x_dict, ...]`` or ``[(x_dict, y), ...]``. Every
    per-modality tensor is stacked along a new leading batch dimension.

    ``y`` may be a scalar tensor/int (stacked into a LongTensor of shape (B,))
    or a dict head -> scalar. A dict is stacked per head using the keys of the
    first item; output keys are cast to ``str``.

    Returns:
        ``x`` when the batch carries no labels, otherwise ``(x, y)``.
    """
    head = batch[0]
    has_labels = isinstance(head, (tuple, list)) and len(head) == 2

    if not has_labels:
        xs, y = batch, None
    else:
        xs, ys = zip(*batch)
        first_y = ys[0]
        if isinstance(first_y, Mapping):
            y = {
                str(k): torch.stack([torch.as_tensor(item[k], dtype=torch.long) for item in ys], dim=0)
                for k in list(first_y.keys())
            }
        else:
            y = torch.stack([torch.as_tensor(item, dtype=torch.long) for item in ys], dim=0)

    x = {k: torch.stack([sample[k] for sample in xs], dim=0) for k in xs[0].keys()}
    return x if y is None else (x, y)
|
|
344
|
+
|
|
345
|
+
|
univi/diagnostics.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
# univi/diagnostics.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any, Dict, Optional, List
|
|
5
|
+
import os
|
|
6
|
+
import platform
|
|
7
|
+
import importlib
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
from anndata import AnnData
|
|
12
|
+
|
|
13
|
+
from .data import _get_matrix
|
|
14
|
+
from .utils.io import load_config
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _safe_version(pkg: str) -> str:
|
|
18
|
+
try:
|
|
19
|
+
mod = importlib.import_module(pkg)
|
|
20
|
+
return getattr(mod, "__version__", "unknown")
|
|
21
|
+
except Exception:
|
|
22
|
+
return "not_installed"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def collect_environment_info() -> Dict[str, Any]:
    """Return the Python/platform strings plus versions of key packages.

    Package versions come from ``_safe_version``, so missing packages appear as
    "not_installed" rather than raising.
    """
    info: Dict[str, Any] = {
        "python": platform.python_version(),
        "platform": platform.platform(),
    }
    for package in (
        "numpy",
        "scipy",
        "pandas",
        "anndata",
        "scanpy",
        "torch",
        "sklearn",
        "h5py",
        "matplotlib",
        "seaborn",
    ):
        info[package] = _safe_version(package)
    return info
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def dataset_stats_table(
    adata_dict: Dict[str, AnnData],
    *,
    layer_by: Optional[Dict[str, Optional[str]]] = None,
    xkey_by: Optional[Dict[str, str]] = None,
) -> pd.DataFrame:
    """Summarise each modality as one row of a DataFrame.

    Columns: modality, n_cells, n_features, X_key, layer. ``n_features`` is
    the column count of the matrix that ``_get_matrix`` selects for that
    modality. A missing layer is shown as "".
    """
    layers = {} if layer_by is None else layer_by
    xkeys = {} if xkey_by is None else xkey_by

    records = []
    for modality, adata in adata_dict.items():
        layer = layers.get(modality, None)
        xkey = xkeys.get(modality, "X")
        matrix = _get_matrix(adata, layer=layer, X_key=xkey)
        records.append(
            {
                "modality": modality,
                "n_cells": int(adata.n_obs),
                "n_features": int(matrix.shape[1]),
                "X_key": xkey,
                "layer": "" if layer is None else layer,
            }
        )
    return pd.DataFrame(records)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def model_hparams_table(cfg: Dict[str, Any]) -> pd.DataFrame:
    """Flatten a curated subset of a config dict into a long-format table.

    Columns: section, key, value (values are ``str()``-ed). Rows come in this
    order:
      1. selected keys from ``cfg["model"]``;
      2. selected keys from ``cfg["training"]``;
      3. selected keys for each entry of ``cfg["data"]["modalities"]``, under
         section "data.<name>".

    Missing keys are skipped. A section that is present but null (for example
    an empty ``model:`` entry in a YAML config) is treated like a missing
    section instead of raising ``TypeError``.
    """
    model = cfg.get("model") or {}
    training = cfg.get("training") or {}
    data = cfg.get("data") or {}
    modalities = data.get("modalities") or []

    model_keys = [
        "loss_mode",
        "v1_recon",
        "v1_recon_mix",
        "normalize_v1_terms",
        "latent_dim",
        "beta",
        "gamma",
        "hidden_dims_default",
        "dropout",
        "encoder_dropout",
        "decoder_dropout",
        "batchnorm",
        "encoder_batchnorm",
        "decoder_batchnorm",
        "kl_anneal_start",
        "kl_anneal_end",
        "align_anneal_start",
        "align_anneal_end",
    ]
    training_keys = [
        "n_epochs",
        "batch_size",
        "lr",
        "weight_decay",
        "device",
        "seed",
        "num_workers",
        "early_stopping",
        "patience",
        "min_delta",
    ]
    modality_keys = ["likelihood", "layer", "X_key", "hidden_dims", "encoder_hidden", "decoder_hidden"]

    rows = [{"section": "model", "key": k, "value": str(model[k])} for k in model_keys if k in model]
    rows += [{"section": "training", "key": k, "value": str(training[k])} for k in training_keys if k in training]

    # Per-modality entries.
    for m in modalities:
        name = m.get("name", "modality")
        rows += [{"section": f"data.{name}", "key": k, "value": str(m[k])} for k in modality_keys if k in m]

    return pd.DataFrame(rows)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def export_supplemental_table_s1(
    config_path: str,
    adata_dict: Dict[str, AnnData],
    *,
    out_xlsx: str,
    layer_by: Optional[Dict[str, Optional[str]]] = None,
    xkey_by: Optional[Dict[str, str]] = None,
    extra_metrics: Optional[Dict[str, Any]] = None,
):
    """Write Supplemental_Table_S1.xlsx: environment + hparams + dataset stats (+ optional metrics).

    The workbook is written with the openpyxl engine. Sheets, in order:
    "environment", "hyperparameters", "datasets", and "metrics" (only when
    ``extra_metrics`` is non-empty). The parent directory of ``out_xlsx`` is
    created if needed.
    """
    cfg = load_config(config_path)
    sheets = {
        "environment": pd.DataFrame([collect_environment_info()]),
        "hyperparameters": model_hparams_table(cfg),
        "datasets": dataset_stats_table(adata_dict, layer_by=layer_by, xkey_by=xkey_by),
    }

    target_dir = os.path.dirname(out_xlsx) or "."
    os.makedirs(target_dir, exist_ok=True)

    with pd.ExcelWriter(out_xlsx, engine="openpyxl") as writer:
        for sheet_name, frame in sheets.items():
            frame.to_excel(writer, index=False, sheet_name=sheet_name)
        if extra_metrics:
            pd.DataFrame([extra_metrics]).to_excel(writer, index=False, sheet_name="metrics")
|