univi-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,339 @@
+ # univi/hyperparam_optimization/common.py
+
+ from __future__ import annotations
+ from dataclasses import dataclass
+ from typing import Dict, Any, List, Optional, Iterable, Tuple
+
+ import json
+ import time
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from torch.utils.data import DataLoader
+ from anndata import AnnData
+
+ from univi import UniVIMultiModalVAE, ModalityConfig, UniVIConfig, TrainingConfig
+ from univi.data import MultiModalDataset
+ from univi.trainer import UniVITrainer
+ from univi import evaluation as univi_eval
+
+
+ @dataclass
+ class SearchResult:
+     config_id: int
+     hparams: Dict[str, Any]
+     best_val_loss: float
+     metrics: Dict[str, float]
+     runtime_min: float
+
+
+ def iter_hparam_configs(
+     space_dict: Dict[str, List[Any]],
+     max_configs: int,
+     seed: int = 0,
+ ) -> Iterable[Dict[str, Any]]:
+     """
+     Random sampler over a hyperparameter space dict.
+     Each config independently samples one value per key.
+     """
+     rng = np.random.default_rng(seed)
+     keys = list(space_dict.keys())
+     for _ in range(max_configs):
+         hp = {}
+         for k in keys:
+             options = space_dict[k]
+             idx = rng.integers(len(options))
+             hp[k] = options[idx]
+         yield hp
+
+
+ def build_modality_configs(
+     arch_config_per_mod: Dict[str, Dict[str, Any]],
+     likelihood_per_mod: Dict[str, str],
+     input_dims: Dict[str, int],
+ ) -> List[ModalityConfig]:
+     """
+     arch_config_per_mod: e.g. {"rna": {"enc": [...], "dec": [...]}, ...}
+     likelihood_per_mod: e.g. {"rna": "nb", "atac": "gaussian"}
+     input_dims: e.g. {"rna": rna.n_vars, "atac": atac.n_vars}
+     """
+     mod_cfgs: List[ModalityConfig] = []
+     for mod_name, arch_cfg in arch_config_per_mod.items():
+         enc = arch_cfg["enc"]
+         dec = arch_cfg["dec"]
+         lik = likelihood_per_mod.get(mod_name, "gaussian")
+         in_dim = int(input_dims[mod_name])
+         mod_cfgs.append(
+             ModalityConfig(
+                 name=mod_name,
+                 input_dim=in_dim,
+                 encoder_hidden=list(enc),
+                 decoder_hidden=list(dec),
+                 likelihood=lik,
+             )
+         )
+     return mod_cfgs
+
+
+ def make_dataloaders(
+     adata_train: Dict[str, AnnData],
+     adata_val: Dict[str, AnnData],
+     layer: Optional[str],
+     X_key: str,
+     batch_size: int,
+     num_workers: int,
+     device: Optional[str] = None,
+ ) -> Tuple[DataLoader, DataLoader]:
+     """
+     Thin wrapper that builds a MultiModalDataset and DataLoader for train and val.
+     """
+     device_obj = None
+     if device is not None and device != "cpu":
+         device_obj = torch.device(device)
+
+     train_ds = MultiModalDataset(
+         adata_dict=adata_train,
+         layer=layer,
+         X_key=X_key,
+         paired=True,
+         device=device_obj,
+     )
+     val_ds = MultiModalDataset(
+         adata_dict=adata_val,
+         layer=layer,
+         X_key=X_key,
+         paired=True,
+         device=device_obj,
+     )
+
+     train_loader = DataLoader(
+         train_ds,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=num_workers,
+     )
+     val_loader = DataLoader(
+         val_ds,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=num_workers,
+     )
+
+     return train_loader, val_loader
+
+
+ def train_single_config(
+     config_id: int,
+     hparams: Dict[str, Any],
+     mod_arch_space: Dict[str, List[Dict[str, Any]]],
+     modalities: List[str],
+     input_dims: Dict[str, int],
+     likelihood_per_mod: Dict[str, List[str]],
+     adata_train: Dict[str, AnnData],
+     adata_val: Dict[str, AnnData],
+     base_train_cfg: TrainingConfig,
+     layer: Optional[str],
+     X_key: str,
+     celltype_key: Optional[str],
+     device: str = "cuda",
+     multimodal_eval: bool = True,
+ ) -> SearchResult:
+     """
+     Train one UniVI model for a given hyperparam config and return metrics.
+
+     multimodal_eval:
+         - True: compute FOSCTTM, label transfer, and modality mixing where possible.
+         - False: report only val_loss (for unimodal runs).
+     """
+     start = time.time()
+
+     # ----- pick architectures -----
+     arch_cfg_per_mod: Dict[str, Dict[str, Any]] = {}
+     like_cfg_per_mod: Dict[str, str] = {}
+     for mod in modalities:
+         arch_choice = hparams[f"{mod}_arch"]
+         if isinstance(arch_choice, str):
+             # allow selecting an architecture from mod_arch_space by its "name"
+             arch_choice = next(
+                 a for a in mod_arch_space[mod] if a["name"] == arch_choice
+             )
+         arch_cfg_per_mod[mod] = arch_choice
+
+         like_options = likelihood_per_mod.get(mod, ["gaussian"])
+         like_choice = hparams.get(f"{mod}_likelihood", like_options[0])
+         like_cfg_per_mod[mod] = like_choice
+
+     # ----- build modality configs -----
+     mod_cfgs = build_modality_configs(
+         arch_config_per_mod=arch_cfg_per_mod,
+         likelihood_per_mod=like_cfg_per_mod,
+         input_dims=input_dims,
+     )
+
+     # ----- UniVI config -----
+     univi_cfg = UniVIConfig(
+         latent_dim=hparams["latent_dim"],
+         modalities=mod_cfgs,
+         beta=hparams["beta"],
+         gamma=hparams["gamma"],
+         encoder_dropout=hparams["encoder_dropout"],
+         decoder_dropout=0.0,
+         encoder_batchnorm=True,
+         decoder_batchnorm=hparams["decoder_batchnorm"],
+         kl_anneal_start=0,
+         kl_anneal_end=0,
+         align_anneal_start=0,
+         align_anneal_end=0,
+     )
+
+     model = UniVIMultiModalVAE(univi_cfg)
+
+     # ----- dataloaders -----
+     train_loader, val_loader = make_dataloaders(
+         adata_train=adata_train,
+         adata_val=adata_val,
+         layer=layer,
+         X_key=X_key,
+         batch_size=base_train_cfg.batch_size,
+         num_workers=base_train_cfg.num_workers,
+         device=device,
+     )
+
+     # ----- training config -----
+     train_cfg = TrainingConfig(
+         n_epochs=base_train_cfg.n_epochs,
+         batch_size=base_train_cfg.batch_size,
+         lr=hparams["lr"],
+         weight_decay=hparams["weight_decay"],
+         device=device,
+         log_every=base_train_cfg.log_every,
+         grad_clip=base_train_cfg.grad_clip,
+         num_workers=base_train_cfg.num_workers,
+         seed=base_train_cfg.seed,
+         early_stopping=base_train_cfg.early_stopping,
+         patience=base_train_cfg.patience,
+         min_delta=base_train_cfg.min_delta,
+     )
+
+     trainer = UniVITrainer(
+         model=model,
+         train_loader=train_loader,
+         val_loader=val_loader,
+         train_cfg=train_cfg,
+         device=device,
+     )
+
+     print("=" * 80)
+     print(f"[Config {config_id}] Hyperparameters:")
+     print(json.dumps(hparams, indent=2))
+     print("=" * 80, flush=True)
+
+     trainer.fit()
+     best_val_loss = trainer.best_val_loss
+
+     metrics: Dict[str, float] = {"best_val_loss": float(best_val_loss)}
+
+     # ----- evaluation metrics -----
+     if multimodal_eval and len(modalities) >= 2:
+         # encode all modalities for the val set
+         Z_val: Dict[str, np.ndarray] = {}
+         for mod in modalities:
+             # use the same layer that was used in training
+             Z = trainer.encode_modality(
+                 adata_val[mod],
+                 modality=mod,
+                 layer=layer,
+                 X_key=X_key,
+                 batch_size=1024,
+             )
+             Z_val[mod] = Z
+             print(f"  Encoded modality {mod} into latent shape {Z.shape}")
+
+         # pairwise FOSCTTM
+         mods = modalities
+         fos_vals = []
+         for i in range(len(mods)):
+             for j in range(i + 1, len(mods)):
+                 m1, m2 = mods[i], mods[j]
+                 fos = univi_eval.compute_foscttm(Z_val[m1], Z_val[m2])
+                 key = f"foscttm_{m1}_vs_{m2}"
+                 metrics[key] = float(fos)
+                 fos_vals.append(fos)
+                 print(f"  FOSCTTM ({m1} vs {m2}): {fos:.4f}")
+
+         if fos_vals:
+             metrics["foscttm_mean"] = float(np.mean(fos_vals))
+
+         # label transfer: use the first modality as "reference" if celltype_key is given
+         if celltype_key is not None and celltype_key in adata_val[modalities[0]].obs:
+             ref_mod = modalities[0]
+             labels_ref = adata_val[ref_mod].obs[celltype_key].astype(str).values
+
+             for tgt_mod in modalities[1:]:
+                 labels_tgt = (
+                     adata_val[tgt_mod].obs[celltype_key].astype(str).values
+                     if celltype_key in adata_val[tgt_mod].obs
+                     else None
+                 )
+                 _, acc, _ = univi_eval.label_transfer_knn(
+                     Z_source=Z_val[ref_mod],
+                     labels_source=labels_ref,
+                     Z_target=Z_val[tgt_mod],
+                     labels_target=labels_tgt,
+                     k=15,
+                 )
+                 if acc is not None:
+                     key = f"label_acc_{tgt_mod}_from_{ref_mod}"
+                     metrics[key] = float(acc)
+                     print(f"  Label transfer ({ref_mod}→{tgt_mod}) accuracy: {acc:.3f}")
+
+         # modality mixing
+         Z_joint = np.concatenate(list(Z_val.values()), axis=0)
+         modality_labels = np.concatenate(
+             [[m] * Z_val[m].shape[0] for m in modalities]
+         )
+         mix_score = univi_eval.compute_modality_mixing(Z_joint, modality_labels)
+         metrics["modality_mixing"] = float(mix_score)
+         print(f"  Modality mixing score: {mix_score:.4f}")
+
+         # composite score: simple weighted combination of loss and alignment
+         if "foscttm_mean" in metrics:
+             comp = best_val_loss + 1000.0 * metrics["foscttm_mean"]
+         else:
+             comp = best_val_loss
+         metrics["composite_score"] = float(comp)
+         print(f"  Composite score: {comp:.3f}")
+     else:
+         # unimodal: composite == val loss
+         metrics["composite_score"] = float(best_val_loss)
+
+     runtime_min = (time.time() - start) / 60.0
+     print(f"[Config {config_id}] Done in {runtime_min:.2f} min")
+     print(
+         f"  best_val_loss   = {best_val_loss:.3f}\n"
+         f"  composite_score = {metrics['composite_score']:.3f}"
+     )
+
+     return SearchResult(
+         config_id=config_id,
+         hparams=hparams,
+         best_val_loss=float(best_val_loss),
+         metrics=metrics,
+         runtime_min=runtime_min,
+     )
+
+
+ def results_to_dataframe(results: List[SearchResult]) -> pd.DataFrame:
+     rows = []
+     for r in results:
+         row = {
+             "config_id": r.config_id,
+             "runtime_min": r.runtime_min,
+         }
+         row.update(r.hparams)
+         row.update(r.metrics)
+         rows.append(row)
+     df = pd.DataFrame(rows)
+     df = df.sort_values("composite_score", ascending=True).reset_index(drop=True)
+     return df
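
To make the sampler's contract concrete, a minimal sketch (standalone apart from this module; the printed values are illustrative):

    space = {"latent_dim": [16, 32, 64], "lr": [1e-3, 5e-4]}
    for hp in iter_hparam_configs(space, max_configs=3, seed=0):
        print(hp)  # one independently sampled value per key, e.g. {"latent_dim": 32, "lr": 0.001}
    # re-running with the same seed reproduces the identical sequence of configs

Because each config is drawn independently with replacement, duplicate configs are possible for small spaces; results_to_dataframe simply lists them as separate rows ranked by composite_score.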
@@ -0,0 +1,110 @@
+ # univi/hyperparam_optimization/run_adt_hparam_search.py
+
+ from __future__ import annotations
+ from typing import List, Optional
+
+ from anndata import AnnData
+
+ from univi.config import TrainingConfig
+ from .common import (
+     SearchResult,
+     iter_hparam_configs,
+     train_single_config,
+     results_to_dataframe,
+ )
+
+
+ def run_adt_hparam_search(
+     adt_train: AnnData,
+     adt_val: AnnData,
+     device: str = "cuda",
+     layer: Optional[str] = "counts",  # raw ADT counts; for CLR/log1p use the appropriate layer
+     X_key: str = "X",
+     max_configs: int = 50,
+     seed: int = 0,
+     base_train_cfg: Optional[TrainingConfig] = None,
+ ):
+     """
+     Hyperparameter search for *unimodal* ADT.
+     """
+
+     adata_train = {"adt": adt_train}
+     adata_val = {"adt": adt_val}
+     modalities = ["adt"]
+
+     if base_train_cfg is None:
+         base_train_cfg = TrainingConfig(
+             n_epochs=80,
+             batch_size=256,
+             lr=1e-3,
+             weight_decay=1e-5,
+             device=device,
+             log_every=5,
+             grad_clip=5.0,
+             num_workers=0,
+             seed=42,
+             early_stopping=True,
+             patience=15,
+             min_delta=0.0,
+         )
+
+     adt_arch_options = [
+         {"name": "adt_small2", "enc": [128, 64], "dec": [64, 128]},
+         {"name": "adt_med2", "enc": [256, 128], "dec": [128, 256]},
+     ]
+
+     mod_arch_space = {"adt": adt_arch_options}
+     likelihood_per_mod = {
+         "adt": ["nb", "zinb", "gaussian"],
+     }
+
+     search_space = {
+         "latent_dim": [10, 20, 32, 40, 50, 64],
+         "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
+         "gamma": [0.0],
+         "lr": [1e-3, 5e-4],
+         "weight_decay": [1e-4, 1e-5],
+         "encoder_dropout": [0.0, 0.1],
+         "decoder_batchnorm": [False, True],
+         "adt_arch": adt_arch_options,
+         "adt_likelihood": likelihood_per_mod["adt"],
+     }
+
+     input_dims = {"adt": adt_train.n_vars}
+
+     results: List[SearchResult] = []
+     best_score = float("inf")
+     best_cfg = None
+     best_result = None
+
+     for cfg_id, hp in enumerate(
+         iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
+         start=1,
+     ):
+         res = train_single_config(
+             config_id=cfg_id,
+             hparams=hp,
+             mod_arch_space=mod_arch_space,
+             modalities=modalities,
+             input_dims=input_dims,
+             likelihood_per_mod={"adt": likelihood_per_mod["adt"]},
+             adata_train=adata_train,
+             adata_val=adata_val,
+             base_train_cfg=base_train_cfg,
+             layer=layer,
+             X_key=X_key,
+             celltype_key=None,
+             device=device,
+             multimodal_eval=False,
+         )
+         results.append(res)
+
+         score = res.metrics["composite_score"]
+         if score < best_score:
+             best_score = score
+             best_result = res
+             best_cfg = hp
+             print(f"--> New best ADT-only config (id={cfg_id}) with score={score:.3f}")
+
+     df = results_to_dataframe(results)
+     return df, best_result, best_cfg
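
A minimal call sketch, assuming `adt_train` / `adt_val` are AnnData objects carrying raw counts in `.layers["counts"]` (the train/val split itself is up to the caller):

    df, best_result, best_cfg = run_adt_hparam_search(
        adt_train=adt_train,
        adt_val=adt_val,
        device="cuda",
        max_configs=20,
    )
    print(df.head())  # rows sorted by composite_score

Since the search passes multimodal_eval=False internally, composite_score here reduces to the best validation loss.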
@@ -0,0 +1,110 @@
+ # univi/hyperparam_optimization/run_atac_hparam_search.py
+
+ from __future__ import annotations
+ from typing import List, Optional
+
+ from anndata import AnnData
+
+ from univi.config import TrainingConfig
+ from .common import (
+     SearchResult,
+     iter_hparam_configs,
+     train_single_config,
+     results_to_dataframe,
+ )
+
+
+ def run_atac_hparam_search(
+     atac_train: AnnData,
+     atac_val: AnnData,
+     device: str = "cuda",
+     layer: Optional[str] = "counts",  # raw peak counts; or switch to a TF-IDF/LSI layer etc.
+     X_key: str = "X",
+     max_configs: int = 50,
+     seed: int = 0,
+     base_train_cfg: Optional[TrainingConfig] = None,
+ ):
+     """
+     Hyperparameter search for *unimodal* ATAC.
+     """
+
+     adata_train = {"atac": atac_train}
+     adata_val = {"atac": atac_val}
+     modalities = ["atac"]
+
+     if base_train_cfg is None:
+         base_train_cfg = TrainingConfig(
+             n_epochs=80,
+             batch_size=256,
+             lr=1e-3,
+             weight_decay=1e-5,
+             device=device,
+             log_every=5,
+             grad_clip=5.0,
+             num_workers=0,
+             seed=42,
+             early_stopping=True,
+             patience=15,
+             min_delta=0.0,
+         )
+
+     atac_arch_options = [
+         {"name": "atac_med2", "enc": [512, 256], "dec": [256, 512]},
+         {"name": "atac_wide2", "enc": [1024, 512], "dec": [512, 1024]},
+     ]
+
+     mod_arch_space = {"atac": atac_arch_options}
+     likelihood_per_mod = {
+         "atac": ["nb", "poisson", "zinb", "gaussian"],
+     }
+
+     search_space = {
+         "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
+         "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
+         "gamma": [0.0],
+         "lr": [1e-3, 5e-4],
+         "weight_decay": [1e-4, 1e-5],
+         "encoder_dropout": [0.0, 0.1],
+         "decoder_batchnorm": [False, True],
+         "atac_arch": atac_arch_options,
+         "atac_likelihood": likelihood_per_mod["atac"],
+     }
+
+     input_dims = {"atac": atac_train.n_vars}
+
+     results: List[SearchResult] = []
+     best_score = float("inf")
+     best_cfg = None
+     best_result = None
+
+     for cfg_id, hp in enumerate(
+         iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
+         start=1,
+     ):
+         res = train_single_config(
+             config_id=cfg_id,
+             hparams=hp,
+             mod_arch_space=mod_arch_space,
+             modalities=modalities,
+             input_dims=input_dims,
+             likelihood_per_mod={"atac": likelihood_per_mod["atac"]},
+             adata_train=adata_train,
+             adata_val=adata_val,
+             base_train_cfg=base_train_cfg,
+             layer=layer,
+             X_key=X_key,
+             celltype_key=None,
+             device=device,
+             multimodal_eval=False,
+         )
+         results.append(res)
+
+         score = res.metrics["composite_score"]
+         if score < best_score:
+             best_score = score
+             best_result = res
+             best_cfg = hp
+             print(f"--> New best ATAC-only config (id={cfg_id}) with score={score:.3f}")
+
+     df = results_to_dataframe(results)
+     return df, best_result, best_cfg
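
Usage mirrors the ADT search; a sketch that also overrides the training defaults (all field values illustrative, using the TrainingConfig imported above from univi.config):

    cfg = TrainingConfig(
        n_epochs=40, batch_size=512, lr=1e-3, weight_decay=1e-5,
        device="cuda", log_every=5, grad_clip=5.0, num_workers=0,
        seed=0, early_stopping=True, patience=10, min_delta=0.0,
    )
    df, best_result, best_cfg = run_atac_hparam_search(
        atac_train=atac_train,
        atac_val=atac_val,
        layer="counts",
        max_configs=30,
        base_train_cfg=cfg,
    )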
@@ -0,0 +1,138 @@
+ # univi/hyperparam_optimization/run_citeseq_hparam_search.py
+
+ from __future__ import annotations
+ from typing import List, Optional
+
+ from anndata import AnnData
+
+ from univi.config import TrainingConfig
+ from .common import (
+     SearchResult,
+     iter_hparam_configs,
+     train_single_config,
+     results_to_dataframe,
+ )
+
+
+ def run_citeseq_hparam_search(
+     rna_train: AnnData,
+     adt_train: AnnData,
+     rna_val: AnnData,
+     adt_val: AnnData,
+     celltype_key: Optional[str] = "cell_type",
+     device: str = "cuda",
+     layer: str = "counts",  # raw counts for NB / ZINB
+     X_key: str = "X",
+     max_configs: int = 100,
+     seed: int = 0,
+     base_train_cfg: Optional[TrainingConfig] = None,
+ ):
+     """
+     Hyperparameter random search for RNA+ADT CITE-seq.
+
+     Assumes:
+     - rna_* and adt_* are paired and share obs_names.
+     - raw counts for RNA and ADT are stored in .layers[layer] (default 'counts').
+     """
+
+     assert rna_train.n_obs == adt_train.n_obs
+     assert rna_val.n_obs == adt_val.n_obs
+
+     adata_train = {"rna": rna_train, "adt": adt_train}
+     adata_val = {"rna": rna_val, "adt": adt_val}
+     modalities = ["rna", "adt"]
+
+     if base_train_cfg is None:
+         base_train_cfg = TrainingConfig(
+             n_epochs=80,
+             batch_size=256,
+             lr=1e-3,
+             weight_decay=1e-5,
+             device=device,
+             log_every=5,
+             grad_clip=5.0,
+             num_workers=0,
+             seed=42,
+             early_stopping=True,
+             patience=15,
+             min_delta=0.0,
+         )
+
+     rna_arch_options = [
+         {"name": "rna_med2", "enc": [512, 256], "dec": [256, 512]},
+         {"name": "rna_wide2", "enc": [1024, 512], "dec": [512, 1024]},
+         {"name": "rna_wide3", "enc": [1024, 512, 256], "dec": [256, 512, 1024]},
+     ]
+     adt_arch_options = [
+         {"name": "adt_small2", "enc": [128, 64], "dec": [64, 128]},
+         {"name": "adt_med2", "enc": [256, 128], "dec": [128, 256]},
+     ]
+
+     mod_arch_space = {
+         "rna": rna_arch_options,
+         "adt": adt_arch_options,
+     }
+
+     likelihood_per_mod = {
+         "rna": ["nb", "zinb"],
+         "adt": ["nb", "zinb", "gaussian"],  # gaussian for CLR/log1p-normalized ADT
+     }
+
+     search_space = {
+         "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
+         "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
+         "gamma": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
+         "lr": [1e-3, 5e-4],
+         "weight_decay": [1e-4, 1e-5],
+         "encoder_dropout": [0.0, 0.1],
+         "decoder_batchnorm": [False, True],
+         "rna_arch": rna_arch_options,
+         "adt_arch": adt_arch_options,
+         "rna_likelihood": likelihood_per_mod["rna"],
+         "adt_likelihood": likelihood_per_mod["adt"],
+     }
+
+     input_dims = {
+         "rna": rna_train.n_vars,
+         "adt": adt_train.n_vars,
+     }
+
+     results: List[SearchResult] = []
+     best_score = float("inf")
+     best_result = None
+     best_cfg = None
+
+     for cfg_id, hp in enumerate(
+         iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
+         start=1,
+     ):
+         res = train_single_config(
+             config_id=cfg_id,
+             hparams=hp,
+             mod_arch_space=mod_arch_space,
+             modalities=modalities,
+             input_dims=input_dims,
+             likelihood_per_mod={
+                 "rna": likelihood_per_mod["rna"],
+                 "adt": likelihood_per_mod["adt"],
+             },
+             adata_train=adata_train,
+             adata_val=adata_val,
+             base_train_cfg=base_train_cfg,
+             layer=layer,
+             X_key=X_key,
+             celltype_key=celltype_key,
+             device=device,
+             multimodal_eval=True,
+         )
+         results.append(res)
+
+         score = res.metrics["composite_score"]
+         if score < best_score:
+             best_score = score
+             best_result = res
+             best_cfg = hp
+             print(f"--> New best CITE-seq config (id={cfg_id}) with score={score:.3f}")
+
+     df = results_to_dataframe(results)
+     return df, best_result, best_cfg
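
A paired-search sketch, assuming RNA/ADT AnnData objects that share obs_names and store raw counts in `.layers["counts"]`; the metric column names follow the keys written by `train_single_config`:

    df, best_result, best_cfg = run_citeseq_hparam_search(
        rna_train=rna_train, adt_train=adt_train,
        rna_val=rna_val, adt_val=adt_val,
        celltype_key="cell_type",
        max_configs=60,
    )
    # for multimodal runs, composite_score = best_val_loss + 1000 * foscttm_mean
    cols = ["config_id", "best_val_loss", "foscttm_rna_vs_adt", "composite_score"]
    print(df[cols].head())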