univi 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,145 @@
1
+ # univi/hyperparam_optimization/run_multiome_hparam_search.py
2
+
3
+ from __future__ import annotations
4
+ from typing import Dict, Any, List, Optional, Tuple
5
+
6
+ import numpy as np
7
+
8
+ from anndata import AnnData
9
+
10
+ from univi.config import TrainingConfig
11
+ from .common import (
12
+ iter_hparam_configs,
13
+ train_single_config,
14
+ results_to_dataframe,
15
+ )
16
+
17
+
18
def run_multiome_hparam_search(
    rna_train: AnnData,
    atac_train: AnnData,
    rna_val: AnnData,
    atac_val: AnnData,
    celltype_key: Optional[str] = "cell_type",
    device: str = "cuda",
    layer: str = "counts",  # raw counts for NB / Poisson / ZINB
    X_key: str = "X",
    max_configs: int = 100,
    seed: int = 0,
    base_train_cfg: Optional[TrainingConfig] = None,
):
    """
    Hyperparameter random search for RNA+ATAC multiome.

    Assumes:
      - rna_train / atac_train are paired and same obs_names
      - rna_val / atac_val are paired and same obs_names
      - raw counts stored in .layers[layer] (default 'counts')

    Only the number of observations is verified here; matching obs_names
    remains the caller's responsibility.

    Parameters
    ----------
    rna_train, atac_train : paired training AnnData objects.
    rna_val, atac_val     : paired validation AnnData objects.
    celltype_key          : forwarded unchanged to ``train_single_config``.
    device                : forwarded to ``train_single_config`` and used
                            for the default ``TrainingConfig``.
    layer, X_key          : forwarded unchanged to ``train_single_config``.
    max_configs, seed     : forwarded to ``iter_hparam_configs``. ``seed``
                            is NOT used for the default ``TrainingConfig``,
                            whose seed is fixed at 42.
    base_train_cfg        : training configuration; when ``None`` a default
                            one is built (80 epochs, batch size 256, early
                            stopping with patience 15).

    Returns
    -------
    (df, best_model, best_cfg)
        ``df`` is ``results_to_dataframe`` applied to every result,
        ``best_model`` is the result object returned by
        ``train_single_config`` for the configuration with the lowest
        ``metrics["composite_score"]`` (a result record, not a re-trained
        model), and ``best_cfg`` is the matching hyperparameter config.
        The last two are ``None`` when no configuration improved on
        ``inf`` (e.g. nothing was sampled).

    Raises
    ------
    ValueError
        If the RNA and ATAC objects of a split differ in ``n_obs``.
    """
    # Explicit raises rather than `assert`: assertions are removed when Python
    # runs with -O, which would let unpaired inputs through silently.
    for split, rna, atac in (
        ("train", rna_train, atac_train),
        ("val", rna_val, atac_val),
    ):
        if rna.n_obs != atac.n_obs:
            raise ValueError(
                f"RNA and ATAC {split} sets must be paired: "
                f"got n_obs={rna.n_obs} (rna) vs n_obs={atac.n_obs} (atac)"
            )

    adata_train = {"rna": rna_train, "atac": atac_train}
    adata_val = {"rna": rna_val, "atac": atac_val}
    modalities = ["rna", "atac"]

    if base_train_cfg is None:
        base_train_cfg = TrainingConfig(
            n_epochs=80,
            batch_size=256,
            lr=1e-3,
            weight_decay=1e-5,
            device=device,
            log_every=5,
            grad_clip=5.0,
            num_workers=0,
            seed=42,
            early_stopping=True,
            patience=15,
            min_delta=0.0,
        )

    # ----- architecture options -----
    rna_arch_options = [
        {"name": "rna_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "rna_wide2", "enc": [1024, 512], "dec": [512, 1024]},
        {"name": "rna_wide3", "enc": [1024, 512, 256], "dec": [256, 512, 1024]},
    ]
    atac_arch_options = [
        {"name": "atac_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "atac_wide2", "enc": [1024, 512], "dec": [512, 1024]},
        {"name": "atac_wide3", "enc": [2048, 1024, 512], "dec": [512, 1024, 2048]},
    ]

    mod_arch_space = {
        "rna": rna_arch_options,
        "atac": atac_arch_options,
    }

    # ----- likelihood options -----
    likelihood_per_mod = {
        # for raw counts
        "rna": ["nb", "zinb"],
        "atac": ["nb", "poisson", "zinb"],
    }

    # ----- hyperparameter search space -----
    # Key order is kept stable on purpose: it may influence the order in which
    # iter_hparam_configs draws values for a given seed.
    search_space = {
        "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
        "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "gamma": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "lr": [1e-3, 5e-4],
        "weight_decay": [1e-4, 1e-5],
        "encoder_dropout": [0.0, 0.1],
        "decoder_batchnorm": [False, True],
        "rna_arch": rna_arch_options,
        "atac_arch": atac_arch_options,
        "rna_likelihood": likelihood_per_mod["rna"],
        "atac_likelihood": likelihood_per_mod["atac"],
    }

    input_dims = {
        "rna": rna_train.n_vars,
        "atac": atac_train.n_vars,
    }

    results: List[Any] = []
    best_score = float("inf")
    best_result = None  # result record of the best run, not a model object
    best_cfg = None

    for cfg_id, hp in enumerate(
        iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
        start=1,
    ):
        res = train_single_config(
            config_id=cfg_id,
            hparams=hp,
            mod_arch_space=mod_arch_space,
            modalities=modalities,
            input_dims=input_dims,
            # Shallow copy: a fresh dict per call that shares the option lists.
            likelihood_per_mod=dict(likelihood_per_mod),
            adata_train=adata_train,
            adata_val=adata_val,
            base_train_cfg=base_train_cfg,
            layer=layer,
            X_key=X_key,
            celltype_key=celltype_key,
            device=device,
            multimodal_eval=True,
        )
        results.append(res)

        # Lower composite score is better. A NaN score never compares as
        # smaller, so it can never become the best configuration.
        score = res.metrics["composite_score"]
        if score < best_score:
            best_score = score
            best_result = res
            best_cfg = hp
            print(f"--> New best config (id={cfg_id}) with score={score:.3f}")

    df = results_to_dataframe(results)
    return df, best_result, best_cfg
@@ -0,0 +1,111 @@
1
+ # univi/hyperparam_optimization/run_rna_hparam_search.py
2
+
3
+ from __future__ import annotations
4
+ from typing import List, Any, Optional
5
+
6
+ from anndata import AnnData
7
+
8
+ from univi.config import TrainingConfig
9
+ from .common import (
10
+ iter_hparam_configs,
11
+ train_single_config,
12
+ results_to_dataframe,
13
+ )
14
+
15
+
16
def run_rna_hparam_search(
    rna_train: AnnData,
    rna_val: AnnData,
    device: str = "cuda",
    layer: Optional[str] = "counts",  # or "log1p" if you want Gaussian/lognormal
    X_key: str = "X",
    max_configs: int = 50,
    seed: int = 0,
    base_train_cfg: Optional[TrainingConfig] = None,
):
    """
    Hyperparameter search for *unimodal* RNA.

    Configurations come from ``iter_hparam_configs`` and each one is handed
    to ``train_single_config`` with ``multimodal_eval=False`` and
    ``celltype_key=None``; ``gamma`` is pinned to 0.0 because there is no
    second modality to align with. The configuration with the lowest
    ``metrics["composite_score"]`` is tracked as the best one.

    Parameters
    ----------
    rna_train, rna_val : training / validation AnnData objects.
    device             : forwarded to ``train_single_config`` and used for
                         the default ``TrainingConfig``.
    layer, X_key       : forwarded unchanged to ``train_single_config``.
    max_configs, seed  : forwarded to ``iter_hparam_configs``.
    base_train_cfg     : training configuration; a default one is built
                         when ``None``.

    Returns
    -------
    (df, best_model, best_cfg)
        The dataframe built by ``results_to_dataframe`` from all results,
        the result object of the best run, and its hyperparameter config
        (the latter two are ``None`` if no run improved on ``inf``).
    """
    if base_train_cfg is None:
        # Fallback training setup shared by every sampled configuration.
        base_train_cfg = TrainingConfig(
            n_epochs=80,
            batch_size=256,
            lr=1e-3,
            weight_decay=1e-5,
            device=device,
            log_every=5,
            grad_clip=5.0,
            num_workers=0,
            seed=42,
            early_stopping=True,
            patience=15,
            min_delta=0.0,
        )

    modalities = ["rna"]
    train_sets = {"rna": rna_train}
    val_sets = {"rna": rna_val}
    input_dims = {"rna": rna_train.n_vars}

    arch_choices = [
        {"name": "rna_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "rna_wide2", "enc": [1024, 512], "dec": [512, 1024]},
    ]
    likelihood_choices = {"rna": ["nb", "zinb", "gaussian", "lognormal"]}
    arch_space = {"rna": arch_choices}

    # Keys stay in their original order; the sampler may depend on it.
    search_space = {
        "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
        "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "gamma": [0.0],  # no cross-modal alignment
        "lr": [1e-3, 5e-4],
        "weight_decay": [1e-4, 1e-5],
        "encoder_dropout": [0.0, 0.1],
        "decoder_batchnorm": [False, True],
        "rna_arch": arch_choices,
        "rna_likelihood": likelihood_choices["rna"],
    }

    all_results: List[Any] = []
    best_score = float("inf")
    best_result = None
    best_hparams = None

    sampled_configs = iter_hparam_configs(
        search_space, max_configs=max_configs, seed=seed
    )
    for cfg_id, hparams in enumerate(sampled_configs, start=1):
        outcome = train_single_config(
            config_id=cfg_id,
            hparams=hparams,
            mod_arch_space=arch_space,
            modalities=modalities,
            input_dims=input_dims,
            likelihood_per_mod=dict(likelihood_choices),
            adata_train=train_sets,
            adata_val=val_sets,
            base_train_cfg=base_train_cfg,
            layer=layer,
            X_key=X_key,
            celltype_key=None,
            device=device,
            multimodal_eval=False,  # unimodal
        )
        all_results.append(outcome)

        # Strict "<" keeps the first of any tied scores and ignores NaN.
        score = outcome.metrics["composite_score"]
        if score < best_score:
            best_score, best_result, best_hparams = score, outcome, hparams
            print(f"--> New best RNA-only config (id={cfg_id}) with score={score:.3f}")

    return results_to_dataframe(all_results), best_result, best_hparams
@@ -0,0 +1,146 @@
1
+ # univi/hyperparam_optimization/run_teaseq_hparam_search.py
2
+
3
+ from __future__ import annotations
4
+ from typing import Dict, Any, List, Optional, Tuple
5
+
6
+ from anndata import AnnData
7
+
8
+ from univi.config import TrainingConfig
9
+ from .common import (
10
+ iter_hparam_configs,
11
+ train_single_config,
12
+ results_to_dataframe,
13
+ )
14
+
15
+
16
def run_teaseq_hparam_search(
    rna_train: AnnData,
    adt_train: AnnData,
    atac_train: AnnData,
    rna_val: AnnData,
    adt_val: AnnData,
    atac_val: AnnData,
    celltype_key: Optional[str] = "cell_type",
    device: str = "cuda",
    layer: str = "counts",
    X_key: str = "X",
    max_configs: int = 100,
    seed: int = 0,
    base_train_cfg: Optional[TrainingConfig] = None,
):
    """
    Hyperparameter random search for TEA-seq (RNA+ADT+ATAC).

    All *_train and *_val are assumed paired with matching obs_names.
    Only the number of observations is verified here; matching obs_names
    remains the caller's responsibility.

    Parameters
    ----------
    rna_train, adt_train, atac_train : paired training AnnData objects.
    rna_val, adt_val, atac_val       : paired validation AnnData objects.
    celltype_key      : forwarded unchanged to ``train_single_config``.
    device            : forwarded to ``train_single_config`` and used for
                        the default ``TrainingConfig``.
    layer, X_key      : forwarded unchanged to ``train_single_config``.
    max_configs, seed : forwarded to ``iter_hparam_configs``. ``seed`` is
                        NOT used for the default ``TrainingConfig``, whose
                        seed is fixed at 42.
    base_train_cfg    : training configuration; a default one is built
                        when ``None``.

    Returns
    -------
    (df, best_model, best_cfg)
        ``df`` is ``results_to_dataframe`` applied to every result,
        ``best_model`` is the result object returned by
        ``train_single_config`` for the configuration with the lowest
        ``metrics["composite_score"]`` (a result record, not a re-trained
        model), and ``best_cfg`` is the matching hyperparameter config.
        The last two are ``None`` when no configuration improved on
        ``inf``.

    Raises
    ------
    ValueError
        If the three modalities of a split do not share the same ``n_obs``.
    """
    adata_train = {"rna": rna_train, "adt": adt_train, "atac": atac_train}
    adata_val = {"rna": rna_val, "adt": adt_val, "atac": atac_val}
    modalities = ["rna", "adt", "atac"]

    # Explicit raises rather than `assert`: assertions are removed when Python
    # runs with -O, which would let unpaired inputs through silently.
    for split, adatas in (("train", adata_train), ("val", adata_val)):
        n_obs_per_mod = {mod: adatas[mod].n_obs for mod in modalities}
        if len(set(n_obs_per_mod.values())) != 1:
            raise ValueError(
                f"RNA, ADT and ATAC {split} sets must be paired: "
                f"got n_obs per modality {n_obs_per_mod}"
            )

    if base_train_cfg is None:
        base_train_cfg = TrainingConfig(
            n_epochs=80,
            batch_size=256,
            lr=1e-3,
            weight_decay=1e-5,
            device=device,
            log_every=5,
            grad_clip=5.0,
            num_workers=0,
            seed=42,
            early_stopping=True,
            patience=15,
            min_delta=0.0,
        )

    # ----- architecture options -----
    rna_arch_options = [
        {"name": "rna_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "rna_wide2", "enc": [1024, 512], "dec": [512, 1024]},
    ]
    adt_arch_options = [
        {"name": "adt_small2", "enc": [128, 64], "dec": [64, 128]},
        {"name": "adt_med2", "enc": [256, 128], "dec": [128, 256]},
    ]
    atac_arch_options = [
        {"name": "atac_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "atac_wide2", "enc": [1024, 512], "dec": [512, 1024]},
    ]

    mod_arch_space = {
        "rna": rna_arch_options,
        "adt": adt_arch_options,
        "atac": atac_arch_options,
    }

    # ----- likelihood options -----
    likelihood_per_mod = {
        "rna": ["nb", "zinb"],
        "adt": ["nb", "zinb", "gaussian"],
        "atac": ["nb", "poisson", "zinb"],
    }

    # ----- hyperparameter search space -----
    # Key order is kept stable on purpose: it may influence the order in which
    # iter_hparam_configs draws values for a given seed.
    search_space = {
        "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
        "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "gamma": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "lr": [1e-3, 5e-4],
        "weight_decay": [1e-4, 1e-5],
        "encoder_dropout": [0.0, 0.1],
        "decoder_batchnorm": [False, True],
        "rna_arch": rna_arch_options,
        "adt_arch": adt_arch_options,
        "atac_arch": atac_arch_options,
        "rna_likelihood": likelihood_per_mod["rna"],
        "adt_likelihood": likelihood_per_mod["adt"],
        "atac_likelihood": likelihood_per_mod["atac"],
    }

    input_dims = {
        "rna": rna_train.n_vars,
        "adt": adt_train.n_vars,
        "atac": atac_train.n_vars,
    }

    results: List[Any] = []
    best_score = float("inf")
    best_result = None  # result record of the best run, not a model object
    best_cfg = None

    for cfg_id, hp in enumerate(
        iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
        start=1,
    ):
        res = train_single_config(
            config_id=cfg_id,
            hparams=hp,
            mod_arch_space=mod_arch_space,
            modalities=modalities,
            input_dims=input_dims,
            # Shallow copy: a fresh dict per call that shares the option lists.
            likelihood_per_mod=dict(likelihood_per_mod),
            adata_train=adata_train,
            adata_val=adata_val,
            base_train_cfg=base_train_cfg,
            layer=layer,
            X_key=X_key,
            celltype_key=celltype_key,
            device=device,
            multimodal_eval=True,
        )
        results.append(res)

        # Lower composite score is better. A NaN score never compares as
        # smaller, so it can never become the best configuration.
        score = res.metrics["composite_score"]
        if score < best_score:
            best_score = score
            best_result = res
            best_cfg = hp
            print(f"--> New best TEA-seq config (id={cfg_id}) with score={score:.3f}")

    df = results_to_dataframe(results)
    return df, best_result, best_cfg