univi-0.3.4-py3-none-any.whl
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,339 @@
# univi/hyperparam_optimization/common.py

from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Any, List, Optional, Iterable, Tuple

import json
import time

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from univi import UniVIMultiModalVAE, ModalityConfig, UniVIConfig, TrainingConfig
from univi.data import MultiModalDataset
from univi.trainer import UniVITrainer
from univi import evaluation as univi_eval


@dataclass
class SearchResult:
    config_id: int
    hparams: Dict[str, Any]
    best_val_loss: float
    metrics: Dict[str, float]
    runtime_min: float


def iter_hparam_configs(
    space_dict: Dict[str, List[Any]],
    max_configs: int,
    seed: int = 0,
) -> Iterable[Dict[str, Any]]:
    """
    Random sampler over a hyperparameter space dict.
    Each config independently samples one value per key.
    """
    rng = np.random.default_rng(seed)
    keys = list(space_dict.keys())
    for _ in range(max_configs):
        hp = {}
        for k in keys:
            options = space_dict[k]
            idx = rng.integers(len(options))
            hp[k] = options[idx]
        yield hp


def build_modality_configs(
    arch_config_per_mod: Dict[str, Dict[str, Any]],
    likelihood_per_mod: Dict[str, str],
    input_dims: Dict[str, int],
) -> List[ModalityConfig]:
    """
    arch_config_per_mod: e.g. {"rna": {"enc": [...], "dec": [...]}, ...}
    likelihood_per_mod: e.g. {"rna": "nb", "atac": "gaussian"}
    input_dims: e.g. {"rna": rna.n_vars, "atac": atac.n_vars}
    """
    mod_cfgs: List[ModalityConfig] = []
    for mod_name, arch_cfg in arch_config_per_mod.items():
        enc = arch_cfg["enc"]
        dec = arch_cfg["dec"]
        lik = likelihood_per_mod.get(mod_name, "gaussian")
        in_dim = int(input_dims[mod_name])
        mod_cfgs.append(
            ModalityConfig(
                name=mod_name,
                input_dim=in_dim,
                encoder_hidden=list(enc),
                decoder_hidden=list(dec),
                likelihood=lik,
            )
        )
    return mod_cfgs


def make_dataloaders(
    adata_train: Dict[str, "AnnData"],
    adata_val: Dict[str, "AnnData"],
    layer: Optional[str],
    X_key: str,
    batch_size: int,
    num_workers: int,
    device: Optional[str] = None,
) -> Tuple[DataLoader, DataLoader]:
    """
    Thin wrapper that builds a MultiModalDataset and DataLoader for train/val.
    """
    device_obj = None
    if device is not None and device != "cpu":
        device_obj = torch.device(device)

    train_ds = MultiModalDataset(
        adata_dict=adata_train,
        layer=layer,
        X_key=X_key,
        paired=True,
        device=device_obj,
    )
    val_ds = MultiModalDataset(
        adata_dict=adata_val,
        layer=layer,
        X_key=X_key,
        paired=True,
        device=device_obj,
    )

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    return train_loader, val_loader


def train_single_config(
    config_id: int,
    hparams: Dict[str, Any],
    mod_arch_space: Dict[str, List[Dict[str, Any]]],
    modalities: List[str],
    input_dims: Dict[str, int],
    likelihood_per_mod: Dict[str, List[str]],
    adata_train: Dict[str, "AnnData"],
    adata_val: Dict[str, "AnnData"],
    base_train_cfg: TrainingConfig,
    layer: Optional[str],
    X_key: str,
    celltype_key: Optional[str],
    device: str = "cuda",
    multimodal_eval: bool = True,
) -> SearchResult:
    """
    Train one UniVI model for a given hyperparam config and return metrics.

    multimodal_eval:
      - True: compute FOSCTTM, label transfer, and modality mixing where possible.
      - False: report only the validation loss (unimodal case).
    """
    start = time.time()

    # ----- pick architectures and likelihoods -----
    arch_cfg_per_mod: Dict[str, Dict[str, Any]] = {}
    like_cfg_per_mod: Dict[str, str] = {}
    for mod in modalities:
        # the sampled value is already one of the architecture dicts from
        # mod_arch_space[mod], e.g. {"name": ..., "enc": [...], "dec": [...]}
        arch_cfg_per_mod[mod] = hparams[f"{mod}_arch"]

        like_options = likelihood_per_mod.get(mod, ["gaussian"])
        like_cfg_per_mod[mod] = hparams.get(f"{mod}_likelihood", like_options[0])

    # ----- build modality configs -----
    mod_cfgs = build_modality_configs(
        arch_config_per_mod=arch_cfg_per_mod,
        likelihood_per_mod=like_cfg_per_mod,
        input_dims=input_dims,
    )

    # ----- UniVI config -----
    univi_cfg = UniVIConfig(
        latent_dim=hparams["latent_dim"],
        modalities=mod_cfgs,
        beta=hparams["beta"],
        gamma=hparams["gamma"],
        encoder_dropout=hparams["encoder_dropout"],
        decoder_dropout=0.0,
        encoder_batchnorm=True,
        decoder_batchnorm=hparams["decoder_batchnorm"],
        kl_anneal_start=0,
        kl_anneal_end=0,
        align_anneal_start=0,
        align_anneal_end=0,
    )

    model = UniVIMultiModalVAE(univi_cfg)

    # ----- dataloaders -----
    train_loader, val_loader = make_dataloaders(
        adata_train=adata_train,
        adata_val=adata_val,
        layer=layer,
        X_key=X_key,
        batch_size=base_train_cfg.batch_size,
        num_workers=base_train_cfg.num_workers,
        device=device,
    )

    # ----- training config -----
    train_cfg = TrainingConfig(
        n_epochs=base_train_cfg.n_epochs,
        batch_size=base_train_cfg.batch_size,
        lr=hparams["lr"],
        weight_decay=hparams["weight_decay"],
        device=device,
        log_every=base_train_cfg.log_every,
        grad_clip=base_train_cfg.grad_clip,
        num_workers=base_train_cfg.num_workers,
        seed=base_train_cfg.seed,
        early_stopping=base_train_cfg.early_stopping,
        patience=base_train_cfg.patience,
        min_delta=base_train_cfg.min_delta,
    )

    trainer = UniVITrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        train_cfg=train_cfg,
        device=device,
    )

    print("=" * 80)
    print(f"[Config {config_id}] Hyperparameters:")
    print(json.dumps(hparams, indent=2))
    print("=" * 80, flush=True)

    trainer.fit()
    best_val_loss = trainer.best_val_loss

    metrics: Dict[str, float] = {"best_val_loss": float(best_val_loss)}

    # ----- evaluation metrics -----
    if multimodal_eval and len(modalities) >= 2:
        # encode each modality's validation cells into the shared latent space
        Z_val: Dict[str, np.ndarray] = {}
        for mod in modalities:
            # use the same layer that was used for training
            Z = trainer.encode_modality(
                adata_val[mod],
                modality=mod,
                layer=layer,
                X_key=X_key,
                batch_size=1024,
            )
            Z_val[mod] = Z
            print(f" Encoded modality {mod} into latent shape {Z.shape}")

        # pairwise FOSCTTM (fraction of samples closer than the true match;
        # lower is better)
        mods = modalities
        fos_vals = []
        for i in range(len(mods)):
            for j in range(i + 1, len(mods)):
                m1, m2 = mods[i], mods[j]
                fos = univi_eval.compute_foscttm(Z_val[m1], Z_val[m2])
                metrics[f"foscttm_{m1}_vs_{m2}"] = float(fos)
                fos_vals.append(fos)
                print(f" FOSCTTM ({m1} vs {m2}): {fos:.4f}")

        if fos_vals:
            metrics["foscttm_mean"] = float(np.mean(fos_vals))

        # label transfer: use the first modality as reference when celltype_key is given
        if celltype_key is not None and celltype_key in adata_val[modalities[0]].obs:
            ref_mod = modalities[0]
            labels_ref = adata_val[ref_mod].obs[celltype_key].astype(str).values

            for tgt_mod in modalities[1:]:
                labels_tgt = (
                    adata_val[tgt_mod].obs[celltype_key].astype(str).values
                    if celltype_key in adata_val[tgt_mod].obs
                    else None
                )
                _, acc, _ = univi_eval.label_transfer_knn(
                    Z_source=Z_val[ref_mod],
                    labels_source=labels_ref,
                    Z_target=Z_val[tgt_mod],
                    labels_target=labels_tgt,
                    k=15,
                )
                if acc is not None:
                    metrics[f"label_acc_{tgt_mod}_from_{ref_mod}"] = float(acc)
                    print(f" Label transfer ({ref_mod}→{tgt_mod}) accuracy: {acc:.3f}")

        # modality mixing
        Z_joint = np.concatenate(list(Z_val.values()), axis=0)
        modality_labels = np.concatenate(
            [[m] * Z_val[m].shape[0] for m in modalities]
        )
        mix_score = univi_eval.compute_modality_mixing(Z_joint, modality_labels)
        metrics["modality_mixing"] = float(mix_score)
        print(f" Modality mixing score: {mix_score:.4f}")

        # composite score: validation loss plus a heavily weighted FOSCTTM
        # penalty (lower is better for both terms)
        if "foscttm_mean" in metrics:
            comp = best_val_loss + 1000.0 * metrics["foscttm_mean"]
        else:
            comp = best_val_loss
        metrics["composite_score"] = float(comp)
        print(f" Composite score: {comp:.3f}")
    else:
        # unimodal: composite score is just the validation loss
        metrics["composite_score"] = float(best_val_loss)

    runtime_min = (time.time() - start) / 60.0
    print(f"[Config {config_id}] Done in {runtime_min:.2f} min")
    print(
        f" best_val_loss = {best_val_loss:.3f}\n"
        f" composite_score = {metrics['composite_score']:.3f}"
    )

    return SearchResult(
        config_id=config_id,
        hparams=hparams,
        best_val_loss=float(best_val_loss),
        metrics=metrics,
        runtime_min=runtime_min,
    )


def results_to_dataframe(results: List[SearchResult]) -> pd.DataFrame:
    """Flatten SearchResults into a DataFrame sorted by composite_score (best first)."""
    rows = []
    for r in results:
        row = {
            "config_id": r.config_id,
            "runtime_min": r.runtime_min,
        }
        row.update(r.hparams)
        row.update(r.metrics)
        rows.append(row)
    df = pd.DataFrame(rows)
    df = df.sort_values("composite_score", ascending=True).reset_index(drop=True)
    return df
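For orientation, a minimal sketch of what the sampler yields (the toy space below is illustrative, not a recommended search range):

# Toy demonstration of iter_hparam_configs: each yielded config picks one
# value per key independently, so repeats across configs are possible.
from univi.hyperparam_optimization.common import iter_hparam_configs

toy_space = {
    "latent_dim": [16, 32],
    "lr": [1e-3, 5e-4],
    "beta": [1.0, 10.0],
}

for hp in iter_hparam_configs(toy_space, max_configs=3, seed=0):
    print(hp)  # e.g. {'latent_dim': 16, 'lr': 0.001, 'beta': 10.0}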
@@ -0,0 +1,109 @@
# univi/hyperparam_optimization/run_adt_hparam_search.py

from __future__ import annotations
from typing import List, Any, Optional

from anndata import AnnData

from univi.config import TrainingConfig
from .common import (
    iter_hparam_configs,
    train_single_config,
    results_to_dataframe,
)


def run_adt_hparam_search(
    adt_train: AnnData,
    adt_val: AnnData,
    device: str = "cuda",
    layer: Optional[str] = "counts",  # raw ADT counts; for CLR/log1p point at the appropriate layer
    X_key: str = "X",
    max_configs: int = 50,
    seed: int = 0,
    base_train_cfg: Optional[TrainingConfig] = None,
):
    """
    Hyperparameter search for *unimodal* ADT.
    """
    adata_train = {"adt": adt_train}
    adata_val = {"adt": adt_val}
    modalities = ["adt"]

    if base_train_cfg is None:
        base_train_cfg = TrainingConfig(
            n_epochs=80,
            batch_size=256,
            lr=1e-3,
            weight_decay=1e-5,
            device=device,
            log_every=5,
            grad_clip=5.0,
            num_workers=0,
            seed=42,
            early_stopping=True,
            patience=15,
            min_delta=0.0,
        )

    adt_arch_options = [
        {"name": "adt_small2", "enc": [128, 64], "dec": [64, 128]},
        {"name": "adt_med2", "enc": [256, 128], "dec": [128, 256]},
    ]

    mod_arch_space = {"adt": adt_arch_options}
    likelihood_per_mod = {
        "adt": ["nb", "zinb", "gaussian"],
    }

    search_space = {
        "latent_dim": [10, 20, 32, 40, 50, 64],
        "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "gamma": [0.0],  # single modality: no cross-modal alignment term to weight
        "lr": [1e-3, 5e-4],
        "weight_decay": [1e-4, 1e-5],
        "encoder_dropout": [0.0, 0.1],
        "decoder_batchnorm": [False, True],
        "adt_arch": adt_arch_options,
        "adt_likelihood": likelihood_per_mod["adt"],
    }

    input_dims = {"adt": adt_train.n_vars}

    results: List[Any] = []
    best_score = float("inf")
    best_cfg = None
    best_model = None  # best SearchResult so far (metrics, not the fitted model)

    for cfg_id, hp in enumerate(
        iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
        start=1,
    ):
        res = train_single_config(
            config_id=cfg_id,
            hparams=hp,
            mod_arch_space=mod_arch_space,
            modalities=modalities,
            input_dims=input_dims,
            likelihood_per_mod={"adt": likelihood_per_mod["adt"]},
            adata_train=adata_train,
            adata_val=adata_val,
            base_train_cfg=base_train_cfg,
            layer=layer,
            X_key=X_key,
            celltype_key=None,
            device=device,
            multimodal_eval=False,
        )
        results.append(res)

        score = res.metrics["composite_score"]
        if score < best_score:
            best_score = score
            best_model = res
            best_cfg = hp
            print(f"--> New best ADT-only config (id={cfg_id}) with score={score:.3f}")

    df = results_to_dataframe(results)
    return df, best_model, best_cfg
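A hypothetical driver for the ADT-only search (the file path, split column, and tiny budget below are illustrative assumptions, not part of the package):

# Hypothetical driver: load an ADT AnnData carrying raw counts in
# .layers["counts"], split it, and run a small search.
import scanpy as sc
from univi.hyperparam_optimization.run_adt_hparam_search import run_adt_hparam_search

adata = sc.read_h5ad("adt.h5ad")                     # illustrative path
train = adata[adata.obs["split"] == "train"].copy()  # assumes a 'split' column
val = adata[adata.obs["split"] == "val"].copy()

df, best_result, best_cfg = run_adt_hparam_search(
    train, val, device="cpu", max_configs=5, seed=0
)
print(df.head())  # configs sorted by composite_score (lower is better)
print(best_cfg)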
@@ -0,0 +1,109 @@
# univi/hyperparam_optimization/run_atac_hparam_search.py

from __future__ import annotations
from typing import List, Any, Optional

from anndata import AnnData

from univi.config import TrainingConfig
from .common import (
    iter_hparam_configs,
    train_single_config,
    results_to_dataframe,
)


def run_atac_hparam_search(
    atac_train: AnnData,
    atac_val: AnnData,
    device: str = "cuda",
    layer: Optional[str] = "counts",  # raw peak counts; or point at a TF-IDF/LSI layer
    X_key: str = "X",
    max_configs: int = 50,
    seed: int = 0,
    base_train_cfg: Optional[TrainingConfig] = None,
):
    """
    Hyperparameter search for *unimodal* ATAC.
    """
    adata_train = {"atac": atac_train}
    adata_val = {"atac": atac_val}
    modalities = ["atac"]

    if base_train_cfg is None:
        base_train_cfg = TrainingConfig(
            n_epochs=80,
            batch_size=256,
            lr=1e-3,
            weight_decay=1e-5,
            device=device,
            log_every=5,
            grad_clip=5.0,
            num_workers=0,
            seed=42,
            early_stopping=True,
            patience=15,
            min_delta=0.0,
        )

    atac_arch_options = [
        {"name": "atac_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "atac_wide2", "enc": [1024, 512], "dec": [512, 1024]},
    ]

    mod_arch_space = {"atac": atac_arch_options}
    likelihood_per_mod = {
        "atac": ["nb", "poisson", "zinb", "gaussian"],
    }

    search_space = {
        "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
        "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "gamma": [0.0],  # single modality: no cross-modal alignment term to weight
        "lr": [1e-3, 5e-4],
        "weight_decay": [1e-4, 1e-5],
        "encoder_dropout": [0.0, 0.1],
        "decoder_batchnorm": [False, True],
        "atac_arch": atac_arch_options,
        "atac_likelihood": likelihood_per_mod["atac"],
    }

    input_dims = {"atac": atac_train.n_vars}

    results: List[Any] = []
    best_score = float("inf")
    best_cfg = None
    best_model = None  # best SearchResult so far (metrics, not the fitted model)

    for cfg_id, hp in enumerate(
        iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
        start=1,
    ):
        res = train_single_config(
            config_id=cfg_id,
            hparams=hp,
            mod_arch_space=mod_arch_space,
            modalities=modalities,
            input_dims=input_dims,
            likelihood_per_mod={"atac": likelihood_per_mod["atac"]},
            adata_train=adata_train,
            adata_val=adata_val,
            base_train_cfg=base_train_cfg,
            layer=layer,
            X_key=X_key,
            celltype_key=None,
            device=device,
            multimodal_eval=False,
        )
        results.append(res)

        score = res.metrics["composite_score"]
        if score < best_score:
            best_score = score
            best_model = res
            best_cfg = hp
            print(f"--> New best ATAC-only config (id={cfg_id}) with score={score:.3f}")

    df = results_to_dataframe(results)
    return df, best_model, best_cfg
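A sketch of overriding the default 80-epoch budget for a quick smoke run (paths and field values are illustrative; note that train_single_config samples lr and weight_decay per config, so the values in the base config only act as placeholders):

# Sketch: a short smoke-test budget instead of the 80-epoch default.
import scanpy as sc
from univi.config import TrainingConfig
from univi.hyperparam_optimization.run_atac_hparam_search import run_atac_hparam_search

atac_train = sc.read_h5ad("atac_train.h5ad")  # illustrative paths
atac_val = sc.read_h5ad("atac_val.h5ad")

smoke_cfg = TrainingConfig(
    n_epochs=5, batch_size=128, lr=1e-3, weight_decay=1e-5,
    device="cpu", log_every=1, grad_clip=5.0, num_workers=0,
    seed=0, early_stopping=False, patience=5, min_delta=0.0,
)

df, best_result, best_cfg = run_atac_hparam_search(
    atac_train, atac_val, device="cpu", max_configs=3, base_train_cfg=smoke_cfg
)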
@@ -0,0 +1,137 @@
# univi/hyperparam_optimization/run_citeseq_hparam_search.py

from __future__ import annotations
from typing import Any, List, Optional

from anndata import AnnData

from univi.config import TrainingConfig
from .common import (
    iter_hparam_configs,
    train_single_config,
    results_to_dataframe,
)


def run_citeseq_hparam_search(
    rna_train: AnnData,
    adt_train: AnnData,
    rna_val: AnnData,
    adt_val: AnnData,
    celltype_key: Optional[str] = "cell_type",
    device: str = "cuda",
    layer: str = "counts",  # raw counts for NB / ZINB
    X_key: str = "X",
    max_configs: int = 100,
    seed: int = 0,
    base_train_cfg: Optional[TrainingConfig] = None,
):
    """
    Hyperparameter random search for RNA+ADT CITE-seq.

    Assumes:
      - rna_* and adt_* are paired and share obs_names.
      - raw counts for RNA and ADT are stored in .layers[layer] (default 'counts').
    """
    assert rna_train.n_obs == adt_train.n_obs, "train RNA/ADT must be paired"
    assert rna_val.n_obs == adt_val.n_obs, "val RNA/ADT must be paired"

    adata_train = {"rna": rna_train, "adt": adt_train}
    adata_val = {"rna": rna_val, "adt": adt_val}
    modalities = ["rna", "adt"]

    if base_train_cfg is None:
        base_train_cfg = TrainingConfig(
            n_epochs=80,
            batch_size=256,
            lr=1e-3,
            weight_decay=1e-5,
            device=device,
            log_every=5,
            grad_clip=5.0,
            num_workers=0,
            seed=42,
            early_stopping=True,
            patience=15,
            min_delta=0.0,
        )

    rna_arch_options = [
        {"name": "rna_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "rna_wide2", "enc": [1024, 512], "dec": [512, 1024]},
        {"name": "rna_wide3", "enc": [1024, 512, 256], "dec": [256, 512, 1024]},
    ]
    adt_arch_options = [
        {"name": "adt_small2", "enc": [128, 64], "dec": [64, 128]},
        {"name": "adt_med2", "enc": [256, 128], "dec": [128, 256]},
    ]

    mod_arch_space = {
        "rna": rna_arch_options,
        "adt": adt_arch_options,
    }

    likelihood_per_mod = {
        "rna": ["nb", "zinb"],
        "adt": ["nb", "zinb", "gaussian"],  # gaussian if using CLR/log1p etc.
    }

    search_space = {
        "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
        "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "gamma": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "lr": [1e-3, 5e-4],
        "weight_decay": [1e-4, 1e-5],
        "encoder_dropout": [0.0, 0.1],
        "decoder_batchnorm": [False, True],
        "rna_arch": rna_arch_options,
        "adt_arch": adt_arch_options,
        "rna_likelihood": likelihood_per_mod["rna"],
        "adt_likelihood": likelihood_per_mod["adt"],
    }

    input_dims = {
        "rna": rna_train.n_vars,
        "adt": adt_train.n_vars,
    }

    results: List[Any] = []
    best_score = float("inf")
    best_model = None  # best SearchResult so far (metrics, not the fitted model)
    best_cfg = None

    for cfg_id, hp in enumerate(
        iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
        start=1,
    ):
        res = train_single_config(
            config_id=cfg_id,
            hparams=hp,
            mod_arch_space=mod_arch_space,
            modalities=modalities,
            input_dims=input_dims,
            likelihood_per_mod={
                "rna": likelihood_per_mod["rna"],
                "adt": likelihood_per_mod["adt"],
            },
            adata_train=adata_train,
            adata_val=adata_val,
            base_train_cfg=base_train_cfg,
            layer=layer,
            X_key=X_key,
            celltype_key=celltype_key,
            device=device,
            multimodal_eval=True,
        )
        results.append(res)

        score = res.metrics["composite_score"]
        if score < best_score:
            best_score = score
            best_model = res
            best_cfg = hp
            print(f"--> New best CITE-seq config (id={cfg_id}) with score={score:.3f}")

    df = results_to_dataframe(results)
    return df, best_model, best_cfg
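A sketch of running the CITE-seq search and reading back the multimodal metrics (paths are illustrative; the label transfer column is written only when celltype_key is present in the validation .obs):

# Sketch: run a short CITE-seq search, then inspect the multimodal metrics.
import scanpy as sc
from univi.hyperparam_optimization.run_citeseq_hparam_search import run_citeseq_hparam_search

rna_train = sc.read_h5ad("rna_train.h5ad")  # illustrative paths; RNA/ADT paired
adt_train = sc.read_h5ad("adt_train.h5ad")
rna_val = sc.read_h5ad("rna_val.h5ad")
adt_val = sc.read_h5ad("adt_val.h5ad")

df, best_result, best_cfg = run_citeseq_hparam_search(
    rna_train, adt_train, rna_val, adt_val,
    celltype_key="cell_type", device="cuda", max_configs=20,
)
cols = [
    "config_id", "best_val_loss", "foscttm_mean",
    "label_acc_adt_from_rna",  # f"label_acc_{tgt}_from_{ref}" with ref="rna"
    "modality_mixing", "composite_score",
]
print(df[cols].head(10))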