univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/objectives.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# univi/objectives.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from typing import Dict
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def kl_diag_gaussians(
    mu_q: torch.Tensor,
    logvar_q: torch.Tensor,
    mu_p: torch.Tensor,
    logvar_p: torch.Tensor,
) -> torch.Tensor:
    """
    Per-sample KL(q || p) between two diagonal Gaussians.

    Each distribution is described by a mean tensor and a log-variance
    tensor. The element-wise divergence is summed over the last dimension,
    so the result has one entry per leading index.
    """
    squared_shift = (mu_q - mu_p) ** 2
    # (sigma_q^2 + (mu_q - mu_p)^2) / sigma_p^2, element-wise.
    scaled = (torch.exp(logvar_q) + squared_shift) / torch.exp(logvar_p)
    elementwise = logvar_p - logvar_q + scaled - 1.0
    return 0.5 * elementwise.sum(dim=-1)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def symmetric_alignment_loss(
    mu_per_mod: Dict[str, torch.Tensor],
) -> torch.Tensor:
    """
    Simple symmetric cross-modal alignment: mean pairwise *squared* L2
    distance between latent means across modalities.

    Parameters
    ----------
    mu_per_mod
        Mapping from modality name to its latent-mean tensor. The squared
        difference is summed over the last dimension.

    Returns
    -------
    torch.Tensor
        One loss value per sample (first dimension of the inputs), averaged
        over every unordered pair of modalities. With a single modality there
        is nothing to align, so zeros are returned with the same dtype and
        device as that modality's tensor.

    Raises
    ------
    ValueError
        If ``mu_per_mod`` is empty.
    """
    mus = list(mu_per_mod.values())
    if not mus:
        raise ValueError("mu_per_mod must contain at least one modality")

    if len(mus) < 2:
        only = mus[0]
        # Match dtype as well as device so the result is interchangeable with
        # the multi-modality branch (e.g. under float64 or half precision).
        return torch.zeros(only.size(0), device=only.device, dtype=only.dtype)

    pair_losses = []
    for idx, mu_i in enumerate(mus):
        # Only visit each unordered pair once.
        for mu_j in mus[idx + 1:]:
            pair_losses.append(((mu_i - mu_j) ** 2).sum(dim=-1))
    return torch.stack(pair_losses, dim=0).mean(dim=0)
|
univi/pipeline.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
# univi/pipeline.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from dataclasses import asdict
|
|
5
|
+
from typing import Any, Dict, Optional, Tuple
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import numpy as np
|
|
9
|
+
import torch
|
|
10
|
+
import anndata as ad
|
|
11
|
+
from anndata import AnnData
|
|
12
|
+
|
|
13
|
+
from .config import UniVIConfig, ModalityConfig, TrainingConfig
|
|
14
|
+
from .models import UniVIMultiModalVAE
|
|
15
|
+
from .data import align_paired_obs_names, infer_input_dim, _get_matrix
|
|
16
|
+
from .utils.io import load_config, load_checkpoint
|
|
17
|
+
from .utils.seed import set_seed
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def load_anndata_dict(cfg: Dict[str, Any], *, data_root: Optional[str] = None) -> Dict[str, AnnData]:
    """Read one ``.h5ad`` file per entry of ``cfg["data"]["modalities"]``.

    Every entry must provide ``"name"`` and ``"h5ad_path"``. A relative path is
    joined onto ``data_root`` when one is supplied; absolute paths (and all
    paths when ``data_root`` is ``None``) are used unchanged.

    Returns a dict mapping modality name to the loaded AnnData object.
    """
    loaded: Dict[str, AnnData] = {}
    for entry in cfg["data"]["modalities"]:
        mod_name = entry["name"]
        h5ad_path = entry["h5ad_path"]
        needs_root = data_root is not None and not os.path.isabs(h5ad_path)
        resolved = os.path.join(data_root, h5ad_path) if needs_root else h5ad_path
        loaded[mod_name] = ad.read_h5ad(resolved)
    return loaded
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def build_univi_from_config(
    cfg: Dict[str, Any],
    adata_dict: Dict[str, AnnData],
) -> Tuple[UniVIMultiModalVAE, UniVIConfig, Dict[str, Optional[str]], Dict[str, str]]:
    """Build UniVI model + selectors from a loaded config and loaded AnnData dict.

    Parameters
    ----------
    cfg
        Parsed configuration. ``cfg["data"]["modalities"]`` is a list of
        per-modality dicts (each with at least ``"name"``); the optional
        ``cfg["model"]`` dict supplies model hyper-parameters, each of which
        falls back to the default written below when absent.
    adata_dict
        Loaded AnnData objects keyed by modality name; used only to infer each
        modality's input dimension.

    Returns
    -------
    tuple
        ``(model, univi_cfg, layer_by, xkey_by)`` where ``layer_by`` and
        ``xkey_by`` map modality name to the configured ``"layer"`` (default
        ``None``) and ``"X_key"`` (default ``"X"``).
    """
    mcfgs = cfg["data"]["modalities"]
    model_cfg = cfg.get("model", {})

    # selectors
    layer_by = {m["name"]: m.get("layer", None) for m in mcfgs}
    xkey_by = {m["name"]: m.get("X_key", "X") for m in mcfgs}

    # Does not depend on the modality, so resolve it once.
    hidden_default = model_cfg.get("hidden_dims_default", [256, 128])

    modalities = []
    for m in mcfgs:
        name = m["name"]
        input_dim = infer_input_dim(adata_dict[name], layer=layer_by[name], X_key=xkey_by[name])

        # Encoder widths: per-modality "encoder_hidden", then "hidden_dims",
        # then the model-wide default.
        enc = m.get("encoder_hidden", m.get("hidden_dims", hidden_default))
        # Decoder widths default to the encoder widths mirrored. (The original
        # looked up "decoder_hidden" twice; the inner lookup was dead code.)
        dec = m.get("decoder_hidden", list(enc)[::-1])
        modalities.append(
            ModalityConfig(
                name=name,
                input_dim=int(input_dim),
                encoder_hidden=list(enc),
                decoder_hidden=list(dec),
                likelihood=m.get("likelihood", "gaussian"),
            )
        )

    # A shared "dropout" key is the fallback for both encoder and decoder.
    shared_dropout = model_cfg.get("dropout", 0.0)

    univi_cfg = UniVIConfig(
        latent_dim=int(model_cfg.get("latent_dim", 32)),
        modalities=modalities,
        beta=float(model_cfg.get("beta", 1.0)),
        gamma=float(model_cfg.get("gamma", 1.0)),
        encoder_dropout=float(model_cfg.get("encoder_dropout", shared_dropout)),
        decoder_dropout=float(model_cfg.get("decoder_dropout", shared_dropout)),
        # Only the encoder falls back to the shared "batchnorm" key; the
        # decoder default is False regardless of it.
        encoder_batchnorm=bool(model_cfg.get("encoder_batchnorm", model_cfg.get("batchnorm", True))),
        decoder_batchnorm=bool(model_cfg.get("decoder_batchnorm", False)),
        kl_anneal_start=int(model_cfg.get("kl_anneal_start", 0)),
        kl_anneal_end=int(model_cfg.get("kl_anneal_end", 0)),
        align_anneal_start=int(model_cfg.get("align_anneal_start", 0)),
        align_anneal_end=int(model_cfg.get("align_anneal_end", 0)),
    )

    model = UniVIMultiModalVAE(
        univi_cfg,
        loss_mode=model_cfg.get("loss_mode", "v2"),
        v1_recon=model_cfg.get("v1_recon", "cross"),
        v1_recon_mix=float(model_cfg.get("v1_recon_mix", 0.0)),
        normalize_v1_terms=bool(model_cfg.get("normalize_v1_terms", True)),
    )
    return model, univi_cfg, layer_by, xkey_by
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@torch.no_grad()
def encode_latents_single_modality(
    model: UniVIMultiModalVAE,
    adata: AnnData,
    modality: str,
    *,
    layer: Optional[str] = None,
    X_key: str = "X",
    batch_size: int = 512,
    device: str = "cpu",
) -> np.ndarray:
    """Encode *one modality only* (for unimodal / independent dataset experiments).

    The model is switched to eval mode and moved to ``device``. The matrix
    selected by ``layer`` / ``X_key`` is then encoded in row batches of
    ``batch_size``, and the latent means returned by
    ``model.mixture_of_experts`` are stacked row-wise.

    Returns
    -------
    np.ndarray
        Latent means with one row per row of the selected matrix.

    Raises
    ------
    ValueError
        If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    model.eval()
    model.to(device)

    X = _get_matrix(adata, layer=layer, X_key=X_key)
    n = X.shape[0]
    zs = []
    for start in range(0, n, batch_size):
        end = min(n, start + batch_size)
        xb = X[start:end]
        # Densify one batch at a time. `toarray()` is the portable way to do
        # this for SciPy sparse containers; the `.A` shorthand is not available
        # on sparse arrays or on recent SciPy sparse matrices, so it is kept
        # only as a fallback for objects that expose nothing else.
        if hasattr(xb, "toarray"):
            xb = xb.toarray()
        elif hasattr(xb, "A"):
            xb = xb.A
        xb = torch.as_tensor(np.asarray(xb), dtype=torch.float32, device=device)
        mu_dict, logvar_dict = model.encode_modalities({modality: xb})
        # Only the mean is kept; the log-variance is discarded.
        mu_z, _ = model.mixture_of_experts(mu_dict, logvar_dict)
        zs.append(mu_z.detach().cpu().numpy())
    return np.vstack(zs)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@torch.no_grad()
def encode_latents_paired(
    model: UniVIMultiModalVAE,
    adata_dict: Dict[str, AnnData],
    *,
    layer_by: Optional[Dict[str, Optional[str]]] = None,
    xkey_by: Optional[Dict[str, str]] = None,
    batch_size: int = 512,
    device: str = "cpu",
    fused: bool = True,
) -> Dict[str, np.ndarray]:
    """Encode paired cells for each modality and optionally the fused MoE latent.

    All AnnData objects must share identical ``obs_names`` in identical order.
    The model is switched to eval mode and moved to ``device``, then rows are
    encoded in batches of ``batch_size``.

    Returns dict with keys for each modality (mu of that encoder) and optionally 'fused'.

    Raises
    ------
    ValueError
        If ``adata_dict`` is empty, ``batch_size`` is not positive, a modality
        is itself named ``"fused"`` while ``fused=True`` (the two outputs would
        share one key), or ``obs_names`` differ between modalities.
    """
    if not adata_dict:
        raise ValueError("adata_dict must contain at least one modality")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    names = list(adata_dict.keys())
    if fused and "fused" in names:
        # Both outputs would be appended to the same list and interleaved.
        raise ValueError("modality name 'fused' collides with the fused output key; use fused=False or rename it")

    model.eval()
    model.to(device)

    n = adata_dict[names[0]].n_obs
    # require paired order
    reference_obs = adata_dict[names[0]].obs_names.values
    for nm in names[1:]:
        if not np.array_equal(adata_dict[nm].obs_names.values, reference_obs):
            raise ValueError(f"obs_names mismatch between {names[0]} and {nm}")

    # Resolve the per-modality selectors once instead of once per batch.
    selectors = {
        nm: (
            None if layer_by is None else layer_by.get(nm, None),
            "X" if xkey_by is None else xkey_by.get(nm, "X"),
        )
        for nm in names
    }

    out = {nm: [] for nm in names}
    if fused:
        out["fused"] = []

    for start in range(0, n, batch_size):
        end = min(n, start + batch_size)
        x_dict = {}
        for nm, adata in adata_dict.items():
            layer, xkey = selectors[nm]
            X = _get_matrix(adata, layer=layer, X_key=xkey)[start:end]
            # Densify one batch at a time. `toarray()` is the portable call for
            # SciPy sparse containers; `.A` is not available on sparse arrays
            # or recent SciPy sparse matrices and is kept only as a fallback.
            if hasattr(X, "toarray"):
                X = X.toarray()
            elif hasattr(X, "A"):
                X = X.A
            x_dict[nm] = torch.as_tensor(np.asarray(X), dtype=torch.float32, device=device)

        mu_dict, logvar_dict = model.encode_modalities(x_dict)
        # per-modality mus
        for nm in names:
            out[nm].append(mu_dict[nm].detach().cpu().numpy())
        if fused:
            mu_z, _ = model.mixture_of_experts(mu_dict, logvar_dict)
            out["fused"].append(mu_z.detach().cpu().numpy())

    return {key: np.vstack(chunks) for key, chunks in out.items()}
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def load_model_and_data(
    config_path: str,
    *,
    checkpoint_path: Optional[str] = None,
    data_root: Optional[str] = None,
    device: str = "cpu",
    align_obs: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, AnnData], UniVIMultiModalVAE, Dict[str, Optional[str]], Dict[str, str]]:
    """Load a config, its AnnData inputs and the matching UniVI model.

    Steps, in order:

    1. load the config from ``config_path`` and seed the RNGs with
       ``cfg["training"]["seed"]`` (default 0, non-deterministic mode);
    2. read every modality's AnnData (relative paths resolved against
       ``data_root``) and, if ``align_obs`` is true, pass them through
       ``align_paired_obs_names``;
    3. build the model from the config;
    4. if ``checkpoint_path`` is given, load its weights non-strictly. The
       state dict is taken from the ``"model_state"`` key, else
       ``"state_dict"``, else the checkpoint object itself;
    5. move the model to ``device``.

    Returns
    -------
    tuple
        ``(cfg, adata_dict, model, layer_by, xkey_by)``.

    Warns
    -----
    RuntimeWarning
        If the checkpoint has missing or unexpected keys. Loading is
        non-strict, so such a mismatch would otherwise pass silently and leave
        part (or all) of the model at its initial weights.
    """
    import warnings

    cfg = load_config(config_path)
    seed = int(cfg.get("training", {}).get("seed", 0))
    set_seed(seed, deterministic=False)

    adata_dict = load_anndata_dict(cfg, data_root=data_root)
    if align_obs:
        adata_dict = align_paired_obs_names(adata_dict)

    model, _univi_cfg, layer_by, xkey_by = build_univi_from_config(cfg, adata_dict)
    if checkpoint_path is not None:
        ck = load_checkpoint(checkpoint_path)
        # common patterns
        state = ck.get("model_state", ck.get("state_dict", ck))
        result = model.load_state_dict(state, strict=False)
        # getattr keeps this tolerant of a load_state_dict override that
        # returns something other than torch's incompatible-keys tuple.
        missing = list(getattr(result, "missing_keys", ()) or ())
        unexpected = list(getattr(result, "unexpected_keys", ()) or ())
        if missing or unexpected:
            warnings.warn(
                f"Checkpoint {checkpoint_path!r} does not fully match the model: "
                f"{len(missing)} missing key(s) {missing[:5]}, "
                f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
                RuntimeWarning,
                stacklevel=2,
            )

    model.to(device)
    return cfg, adata_dict, model, layer_by, xkey_by
|
univi/plotting.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
# univi/plotting.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from typing import Dict, Optional, Sequence, List
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import matplotlib.pyplot as plt
|
|
8
|
+
import scanpy as sc
|
|
9
|
+
import seaborn as sns
|
|
10
|
+
from anndata import AnnData
|
|
11
|
+
import anndata as ad
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def set_style(font_scale: float = 1.25, dpi: int = 150) -> None:
    """Apply readable, manuscript-friendly defaults to matplotlib and scanpy.

    ``font_scale`` multiplies a 10 pt base size, from which title, axis-label,
    tick and legend sizes are derived. ``dpi`` sets the on-screen figure
    resolution; saved figures always use 300 dpi.
    """
    import matplotlib as mpl

    base = 10.0 * float(font_scale)
    figure_dpi = int(dpi)

    rc = {
        "figure.dpi": figure_dpi,
        "savefig.dpi": 300,
        "font.size": base,
    }
    # Text element sizes expressed relative to the base font size.
    relative_sizes = (
        ("axes.titlesize", 1.2),
        ("axes.labelsize", 1.1),
        ("xtick.labelsize", 0.95),
        ("ytick.labelsize", 0.95),
        ("legend.fontsize", 0.95),
    )
    for key, factor in relative_sizes:
        rc[key] = base * factor
    # NOTE(review): font type 42 presumably selects TrueType embedding so text
    # stays editable in exported PDF/PS files; confirm against matplotlib docs.
    rc["pdf.fonttype"] = 42
    rc["ps.fonttype"] = 42

    mpl.rcParams.update(rc)
    sc.settings.set_figure_params(dpi=figure_dpi, dpi_save=300, frameon=False)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def umap_single_adata(
    adata_obj: AnnData,
    obsm_key: str = "X_univi",
    color: Optional[Sequence[str]] = None,
    savepath: Optional[str] = None,
    n_neighbors: int = 30,
    random_state: int = 0,
) -> None:
    """Draw a UMAP of ``adata_obj`` based on the embedding in ``obsm[obsm_key]``.

    Neighbours and UMAP coordinates are computed only when they are missing
    (``"neighbors"`` not in ``uns`` / ``"X_umap"`` not in ``obsm``); existing
    results are reused as-is and ``adata_obj`` is modified in place.

    Parameters
    ----------
    color
        Key(s) to colour by. May be a single string or a sequence of strings;
        ``None`` or an empty sequence draws one uncoloured panel.
    savepath
        If given, the current figure is saved there at 300 dpi and then
        closed. Otherwise the figure is left open for the caller.

    Raises
    ------
    KeyError
        If ``obsm_key`` is not present in ``adata_obj.obsm``.
    """
    if obsm_key not in adata_obj.obsm:
        raise KeyError("Missing obsm[%r]. Available: %s" % (obsm_key, list(adata_obj.obsm.keys())))

    # Compute neighbors/umap if missing
    if "neighbors" not in adata_obj.uns:
        sc.pp.neighbors(adata_obj, use_rep=obsm_key, n_neighbors=int(n_neighbors))
    if "X_umap" not in adata_obj.obsm:
        sc.tl.umap(adata_obj, random_state=int(random_state))

    if color is None:
        color_keys = None
    elif isinstance(color, str):
        # A bare string is itself a sequence; list() would split it into
        # single characters, so wrap it instead.
        color_keys = [color]
    else:
        # An empty list is not a usable colour specification for scanpy
        # (depending on the version it errors or draws nothing), so fall back
        # to None, which requests a single uncoloured panel.
        color_keys = list(color) or None
    sc.pl.umap(adata_obj, color=color_keys, show=False)

    if savepath is not None:
        plt.savefig(savepath, dpi=300, bbox_inches="tight")
        plt.close()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def umap_by_modality(
    adata_dict: Dict[str, AnnData],
    obsm_key: str = "X_univi",
    color: str = "celltype",
    savepath: Optional[str] = None,
    n_neighbors: int = 30,
    random_state: int = 0,
) -> None:
    """
    Concatenate adatas; expects each input adata already has the same obsm_key embedding
    (or you should add it before calling this).

    Each input is copied (the originals are not modified), tagged with its
    modality name in ``obs["univi_modality"]``, and concatenated with an outer
    join. The combined object is then plotted with ``umap_single_adata``,
    coloured by ``color`` and by ``"univi_modality"``.

    Raises
    ------
    ValueError
        If ``adata_dict`` is empty.
    KeyError
        If the concatenated object lacks ``obsm[obsm_key]`` and it cannot be
        rebuilt by stacking the inputs' embeddings; the underlying error is
        attached as ``__cause__``.
    """
    if not adata_dict:
        raise ValueError("adata_dict must contain at least one modality")

    annotated: List[AnnData] = []
    for mod, a in adata_dict.items():
        aa = a.copy()
        aa.obs["univi_modality"] = str(mod)
        annotated.append(aa)

    # NOTE(review): AnnData.concatenate is deprecated in recent anndata
    # releases in favour of anndata.concat. The two differ in how obs/var
    # columns are merged, so the call is left unchanged here; confirm
    # before migrating.
    combined = annotated[0].concatenate(
        *annotated[1:],
        batch_key="univi_source",
        batch_categories=list(adata_dict.keys()),
        index_unique="-",
        join="outer",
    )

    # Carry embeddings if needed (common pitfall: concatenate drops obsm keys)
    if obsm_key not in combined.obsm:
        # try to rebuild from parts
        try:
            Zs = [a.obsm[obsm_key] for a in adata_dict.values()]
            combined.obsm[obsm_key] = np.vstack(Zs)
        except Exception as err:
            # Every failure is reported as the same KeyError callers already
            # expect, but chained so the root cause (missing key, shape
            # mismatch, ...) stays visible in the traceback.
            raise KeyError(
                "combined is missing obsm[%r] after concatenation; add it manually." % obsm_key
            ) from err

    umap_single_adata(
        combined,
        obsm_key=obsm_key,
        color=[color, "univi_modality"],
        savepath=savepath,
        n_neighbors=n_neighbors,
        random_state=random_state,
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def plot_confusion_matrix(
    cm: np.ndarray,
    labels: np.ndarray,
    title: str = "Label transfer (source \u2192 target)",
    savepath: Optional[str] = None,
) -> None:
    """Render ``cm`` as an un-annotated viridis heatmap on a new 6x5 figure.

    ``labels`` is used for the tick labels of both axes; the x axis is
    labelled "Predicted" and the y axis "True". When ``savepath`` is given the
    figure is written there at 300 dpi and closed; otherwise it stays open.
    """
    plt.figure(figsize=(6, 5))

    heatmap_options = {
        "annot": False,
        "xticklabels": labels,
        "yticklabels": labels,
        "cmap": "viridis",
    }
    sns.heatmap(cm, **heatmap_options)

    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(title)
    plt.tight_layout()

    if savepath is None:
        return
    plt.savefig(savepath, dpi=300, bbox_inches="tight")
    plt.close()
|
|
126
|
+
|