univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# univi/hyperparam_optimization/run_multiome_hparam_search.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from typing import Dict, Any, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from anndata import AnnData
|
|
9
|
+
|
|
10
|
+
from univi.config import TrainingConfig
|
|
11
|
+
from .common import (
|
|
12
|
+
iter_hparam_configs,
|
|
13
|
+
train_single_config,
|
|
14
|
+
results_to_dataframe,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def run_multiome_hparam_search(
    rna_train: AnnData,
    atac_train: AnnData,
    rna_val: AnnData,
    atac_val: AnnData,
    celltype_key: Optional[str] = "cell_type",
    device: str = "cuda",
    layer: str = "counts",  # raw counts for NB / Poisson / ZINB
    X_key: str = "X",
    max_configs: int = 100,
    seed: int = 0,
    base_train_cfg: Optional[TrainingConfig] = None,
) -> Tuple[Any, Any, Any]:
    """
    Hyperparameter random search for RNA+ATAC multiome.

    Every configuration yielded by ``iter_hparam_configs`` is trained with
    ``train_single_config`` (with ``multimodal_eval=True``), and the one with
    the lowest ``metrics["composite_score"]`` is tracked as the best.

    Assumes:
      - rna_train / atac_train are paired and same obs_names
      - rna_val / atac_val are paired and same obs_names
      - raw counts stored in .layers[layer] (default 'counts')

    Only the number of observations is actually verified here; matching
    obs_names remains the caller's responsibility.

    Parameters
    ----------
    rna_train, atac_train
        Training AnnData objects for the two modalities.
    rna_val, atac_val
        Validation AnnData objects for the two modalities.
    celltype_key
        Forwarded unchanged to ``train_single_config``.
    device
        Used for the default ``TrainingConfig`` and forwarded to
        ``train_single_config``.
    layer, X_key
        Forwarded unchanged to ``train_single_config``.
    max_configs, seed
        Forwarded to ``iter_hparam_configs``.
    base_train_cfg
        Training configuration to use; when ``None`` a default one is built.

    Returns
    -------
    tuple
        ``(df, best_model, best_cfg)`` where ``df`` is the output of
        ``results_to_dataframe`` over all results, ``best_model`` is the
        result object returned by ``train_single_config`` for the
        best-scoring configuration, and ``best_cfg`` is that configuration's
        hyperparameters. The last two are ``None`` if no configuration ever
        improved on the initial score of ``inf``.

    Raises
    ------
    ValueError
        If the paired train (or validation) inputs differ in ``n_obs``.
    """
    # Explicit checks rather than `assert`: assertions are removed when Python
    # runs with -O, which would let mismatched inputs through silently.
    if rna_train.n_obs != atac_train.n_obs:
        raise ValueError(
            "rna_train and atac_train must have the same number of observations "
            f"(got {rna_train.n_obs} and {atac_train.n_obs})"
        )
    if rna_val.n_obs != atac_val.n_obs:
        raise ValueError(
            "rna_val and atac_val must have the same number of observations "
            f"(got {rna_val.n_obs} and {atac_val.n_obs})"
        )

    adata_train = {"rna": rna_train, "atac": atac_train}
    adata_val = {"rna": rna_val, "atac": atac_val}
    modalities = ["rna", "atac"]

    if base_train_cfg is None:
        base_train_cfg = TrainingConfig(
            n_epochs=80,
            batch_size=256,
            lr=1e-3,
            weight_decay=1e-5,
            device=device,
            log_every=5,
            grad_clip=5.0,
            num_workers=0,
            seed=42,
            early_stopping=True,
            patience=15,
            min_delta=0.0,
        )

    # ----- architecture options -----
    rna_arch_options = [
        {"name": "rna_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "rna_wide2", "enc": [1024, 512], "dec": [512, 1024]},
        {"name": "rna_wide3", "enc": [1024, 512, 256], "dec": [256, 512, 1024]},
    ]
    atac_arch_options = [
        {"name": "atac_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "atac_wide2", "enc": [1024, 512], "dec": [512, 1024]},
        {"name": "atac_wide3", "enc": [2048, 1024, 512], "dec": [512, 1024, 2048]},
    ]

    mod_arch_space = {
        "rna": rna_arch_options,
        "atac": atac_arch_options,
    }

    # ----- likelihood options -----
    likelihood_per_mod = {
        # for raw counts
        "rna": ["nb", "zinb"],
        "atac": ["nb", "poisson", "zinb"],
    }

    # ----- hyperparameter search space -----
    # Key order is kept as-is in case the sampler depends on it.
    search_space = {
        "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
        "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "gamma": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "lr": [1e-3, 5e-4],
        "weight_decay": [1e-4, 1e-5],
        "encoder_dropout": [0.0, 0.1],
        "decoder_batchnorm": [False, True],
        "rna_arch": rna_arch_options,
        "atac_arch": atac_arch_options,
        "rna_likelihood": likelihood_per_mod["rna"],
        "atac_likelihood": likelihood_per_mod["atac"],
    }

    input_dims = {
        "rna": rna_train.n_vars,
        "atac": atac_train.n_vars,
    }

    results: List[Any] = []
    best_score = float("inf")
    best_model = None
    best_cfg = None

    for cfg_id, hp in enumerate(
        iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
        start=1,
    ):
        res = train_single_config(
            config_id=cfg_id,
            hparams=hp,
            mod_arch_space=mod_arch_space,
            modalities=modalities,
            input_dims=input_dims,
            # Fresh shallow copy per call, as before, so the callee never
            # receives the dict this function keeps.
            likelihood_per_mod=dict(likelihood_per_mod),
            adata_train=adata_train,
            adata_val=adata_val,
            base_train_cfg=base_train_cfg,
            layer=layer,
            X_key=X_key,
            celltype_key=celltype_key,
            device=device,
            multimodal_eval=True,
        )
        results.append(res)

        # Lower composite score is better; a NaN score never compares as
        # smaller, so it can never become the best.
        score = res.metrics["composite_score"]
        if score < best_score:
            best_score = score
            best_model = res  # we'll re-train or save separate; for now keep config
            best_cfg = hp
            print(f"--> New best config (id={cfg_id}) with score={score:.3f}")

    df = results_to_dataframe(results)
    return df, best_model, best_cfg
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# univi/hyperparam_optimization/run_rna_hparam_search.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from typing import List, Any, Optional
|
|
5
|
+
|
|
6
|
+
from anndata import AnnData
|
|
7
|
+
|
|
8
|
+
from univi.config import TrainingConfig
|
|
9
|
+
from .common import (
|
|
10
|
+
iter_hparam_configs,
|
|
11
|
+
train_single_config,
|
|
12
|
+
results_to_dataframe,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def run_rna_hparam_search(
    rna_train: AnnData,
    rna_val: AnnData,
    device: str = "cuda",
    layer: Optional[str] = "counts",  # or "log1p" if you want Gaussian/lognormal
    X_key: str = "X",
    max_configs: int = 50,
    seed: int = 0,
    base_train_cfg: Optional[TrainingConfig] = None,
):
    """
    Hyperparameter search for *unimodal* RNA.

    Each configuration produced by ``iter_hparam_configs`` is trained through
    ``train_single_config`` with ``multimodal_eval=False`` and
    ``celltype_key=None``; the configuration with the smallest
    ``metrics["composite_score"]`` is remembered as the best one.

    Only uses validation loss (no alignment metrics).

    Returns ``(df, best_model, best_cfg)``: the dataframe built by
    ``results_to_dataframe`` from all results, the result object of the
    best-scoring configuration, and that configuration's hyperparameters
    (the latter two stay ``None`` when nothing beat the initial ``inf``).
    """
    # Fall back to the stock training setup when the caller supplies none.
    if base_train_cfg is None:
        default_cfg_kwargs = dict(
            n_epochs=80,
            batch_size=256,
            lr=1e-3,
            weight_decay=1e-5,
            device=device,
            log_every=5,
            grad_clip=5.0,
            num_workers=0,
            seed=42,
            early_stopping=True,
            patience=15,
            min_delta=0.0,
        )
        base_train_cfg = TrainingConfig(**default_cfg_kwargs)

    # Candidate encoder/decoder layouts and likelihoods for the RNA modality.
    rna_arch_options = [
        {"name": "rna_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "rna_wide2", "enc": [1024, 512], "dec": [512, 1024]},
    ]
    rna_likelihoods = ["nb", "zinb", "gaussian", "lognormal"]

    # Key order is kept exactly as before in case the sampler depends on it.
    search_space = {
        "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
        "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "gamma": [0.0],  # no cross-modal alignment
        "lr": [1e-3, 5e-4],
        "weight_decay": [1e-4, 1e-5],
        "encoder_dropout": [0.0, 0.1],
        "decoder_batchnorm": [False, True],
        "rna_arch": rna_arch_options,
        "rna_likelihood": rna_likelihoods,
    }

    # Arguments that are identical for every trained configuration.
    shared_kwargs = dict(
        mod_arch_space={"rna": rna_arch_options},
        modalities=["rna"],
        input_dims={"rna": rna_train.n_vars},
        adata_train={"rna": rna_train},
        adata_val={"rna": rna_val},
        base_train_cfg=base_train_cfg,
        layer=layer,
        X_key=X_key,
        celltype_key=None,
        device=device,
        multimodal_eval=False,  # unimodal
    )

    results: List[Any] = []
    lowest_score = float("inf")
    best_model = None
    best_cfg = None

    candidates = iter_hparam_configs(search_space, max_configs=max_configs, seed=seed)
    for cfg_id, hp in enumerate(candidates, 1):
        outcome = train_single_config(
            config_id=cfg_id,
            hparams=hp,
            # a new dict for each call, as in the original call site
            likelihood_per_mod={"rna": rna_likelihoods},
            **shared_kwargs,
        )
        results.append(outcome)

        score = outcome.metrics["composite_score"]
        if not (score < lowest_score):
            # Not an improvement (also covers NaN scores).
            continue

        lowest_score = score
        best_model = outcome
        best_cfg = hp
        print(f"--> New best RNA-only config (id={cfg_id}) with score={score:.3f}")

    return results_to_dataframe(results), best_model, best_cfg
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
# univi/hyperparam_optimization/run_teaseq_hparam_search.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from typing import Dict, Any, List, Optional, Tuple
|
|
5
|
+
|
|
6
|
+
from anndata import AnnData
|
|
7
|
+
|
|
8
|
+
from univi.config import TrainingConfig
|
|
9
|
+
from .common import (
|
|
10
|
+
iter_hparam_configs,
|
|
11
|
+
train_single_config,
|
|
12
|
+
results_to_dataframe,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def run_teaseq_hparam_search(
    rna_train: AnnData,
    adt_train: AnnData,
    atac_train: AnnData,
    rna_val: AnnData,
    adt_val: AnnData,
    atac_val: AnnData,
    celltype_key: Optional[str] = "cell_type",
    device: str = "cuda",
    layer: str = "counts",
    X_key: str = "X",
    max_configs: int = 100,
    seed: int = 0,
    base_train_cfg: Optional[TrainingConfig] = None,
) -> Tuple[Any, Any, Any]:
    """
    Hyperparameter random search for TEA-seq (RNA+ADT+ATAC).

    All *_train and *_val are assumed paired with matching obs_names.
    Only the number of observations is actually verified here; matching
    obs_names remains the caller's responsibility.

    Every configuration yielded by ``iter_hparam_configs`` is trained with
    ``train_single_config`` (with ``multimodal_eval=True``), and the one with
    the lowest ``metrics["composite_score"]`` is tracked as the best.

    Parameters
    ----------
    rna_train, adt_train, atac_train
        Training AnnData objects, one per modality.
    rna_val, adt_val, atac_val
        Validation AnnData objects, one per modality.
    celltype_key
        Forwarded unchanged to ``train_single_config``.
    device
        Used for the default ``TrainingConfig`` and forwarded to
        ``train_single_config``.
    layer, X_key
        Forwarded unchanged to ``train_single_config``.
    max_configs, seed
        Forwarded to ``iter_hparam_configs``.
    base_train_cfg
        Training configuration to use; when ``None`` a default one is built.

    Returns
    -------
    tuple
        ``(df, best_model, best_cfg)`` where ``df`` is the output of
        ``results_to_dataframe`` over all results, ``best_model`` is the
        result object returned by ``train_single_config`` for the
        best-scoring configuration, and ``best_cfg`` is that configuration's
        hyperparameters. The last two are ``None`` if no configuration ever
        improved on the initial score of ``inf``.

    Raises
    ------
    ValueError
        If the train (or validation) inputs do not all share the same
        ``n_obs``.
    """
    adata_train = {"rna": rna_train, "adt": adt_train, "atac": atac_train}
    adata_val = {"rna": rna_val, "adt": adt_val, "atac": atac_val}
    modalities = ["rna", "adt", "atac"]

    # Explicit checks rather than `assert`: assertions are removed when Python
    # runs with -O, which would let mismatched inputs through silently.
    for split_name, split in (("train", adata_train), ("val", adata_val)):
        sizes = {mod: adata.n_obs for mod, adata in split.items()}
        if len(set(sizes.values())) != 1:
            raise ValueError(
                f"All {split_name} modalities must have the same number of "
                f"observations (got {sizes})"
            )

    if base_train_cfg is None:
        base_train_cfg = TrainingConfig(
            n_epochs=80,
            batch_size=256,
            lr=1e-3,
            weight_decay=1e-5,
            device=device,
            log_every=5,
            grad_clip=5.0,
            num_workers=0,
            seed=42,
            early_stopping=True,
            patience=15,
            min_delta=0.0,
        )

    # ----- architecture options -----
    rna_arch_options = [
        {"name": "rna_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "rna_wide2", "enc": [1024, 512], "dec": [512, 1024]},
    ]
    adt_arch_options = [
        {"name": "adt_small2", "enc": [128, 64], "dec": [64, 128]},
        {"name": "adt_med2", "enc": [256, 128], "dec": [128, 256]},
    ]
    atac_arch_options = [
        {"name": "atac_med2", "enc": [512, 256], "dec": [256, 512]},
        {"name": "atac_wide2", "enc": [1024, 512], "dec": [512, 1024]},
    ]

    mod_arch_space = {
        "rna": rna_arch_options,
        "adt": adt_arch_options,
        "atac": atac_arch_options,
    }

    # ----- likelihood options -----
    likelihood_per_mod = {
        "rna": ["nb", "zinb"],
        "adt": ["nb", "zinb", "gaussian"],
        "atac": ["nb", "poisson", "zinb"],
    }

    # ----- hyperparameter search space -----
    # Key order is kept as-is in case the sampler depends on it.
    search_space = {
        "latent_dim": [10, 20, 32, 40, 50, 64, 82, 120, 160, 200],
        "beta": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "gamma": [0.0, 1.0, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 60.0, 80.0, 100.0, 160.0, 200.0, 300.0],
        "lr": [1e-3, 5e-4],
        "weight_decay": [1e-4, 1e-5],
        "encoder_dropout": [0.0, 0.1],
        "decoder_batchnorm": [False, True],
        "rna_arch": rna_arch_options,
        "adt_arch": adt_arch_options,
        "atac_arch": atac_arch_options,
        "rna_likelihood": likelihood_per_mod["rna"],
        "adt_likelihood": likelihood_per_mod["adt"],
        "atac_likelihood": likelihood_per_mod["atac"],
    }

    input_dims = {
        "rna": rna_train.n_vars,
        "adt": adt_train.n_vars,
        "atac": atac_train.n_vars,
    }

    results: List[Any] = []
    best_score = float("inf")
    best_model = None
    best_cfg = None

    for cfg_id, hp in enumerate(
        iter_hparam_configs(search_space, max_configs=max_configs, seed=seed),
        start=1,
    ):
        res = train_single_config(
            config_id=cfg_id,
            hparams=hp,
            mod_arch_space=mod_arch_space,
            modalities=modalities,
            input_dims=input_dims,
            # Fresh shallow copy per call, as before, so the callee never
            # receives the dict this function keeps.
            likelihood_per_mod=dict(likelihood_per_mod),
            adata_train=adata_train,
            adata_val=adata_val,
            base_train_cfg=base_train_cfg,
            layer=layer,
            X_key=X_key,
            celltype_key=celltype_key,
            device=device,
            multimodal_eval=True,
        )
        results.append(res)

        # Lower composite score is better; a NaN score never compares as
        # smaller, so it can never become the best.
        score = res.metrics["composite_score"]
        if score < best_score:
            best_score = score
            best_model = res  # result object for the best config, not a bare model
            best_cfg = hp
            print(f"--> New best TEA-seq config (id={cfg_id}) with score={score:.3f}")

    df = results_to_dataframe(results)
    return df, best_model, best_cfg
|