univi-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
univi/objectives.py ADDED
@@ -0,0 +1,50 @@
+ # univi/objectives.py
+
+ from __future__ import annotations
+ from typing import Dict
+ import torch
+
+
+ def kl_diag_gaussians(
+     mu_q: torch.Tensor,
+     logvar_q: torch.Tensor,
+     mu_p: torch.Tensor,
+     logvar_p: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     KL(q||p) for diagonal Gaussians, summed over the latent dim, per sample.
+     """
+     var_q = torch.exp(logvar_q)
+     var_p = torch.exp(logvar_p)
+     kl = (
+         logvar_p
+         - logvar_q
+         + (var_q + (mu_q - mu_p) ** 2) / var_p
+         - 1.0
+     )
+     return 0.5 * kl.sum(dim=-1)
+
+
+ def symmetric_alignment_loss(
+     mu_per_mod: Dict[str, torch.Tensor],
+ ) -> torch.Tensor:
+     """
+     Simple symmetric cross-modal alignment: mean pairwise squared L2 distance
+     between latent means across modalities.
+     """
+     names = list(mu_per_mod.keys())
+     if not names:
+         raise ValueError("mu_per_mod must contain at least one modality")
+     if len(names) < 2:
+         # Nothing to align; return per-sample zeros on the right device.
+         ref = mu_per_mod[names[0]]
+         return torch.zeros(ref.size(0), device=ref.device)
+
+     losses = []
+     for i in range(len(names)):
+         for j in range(i + 1, len(names)):
+             mu_i = mu_per_mod[names[i]]
+             mu_j = mu_per_mod[names[j]]
+             losses.append(((mu_i - mu_j) ** 2).sum(dim=-1))
+     stacked = torch.stack(losses, dim=0)
+     return stacked.mean(dim=0)
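
For reference, kl_diag_gaussians implements the standard closed form of the KL divergence between two diagonal Gaussians q = N(mu_q, diag(sigma_q^2)) and p = N(mu_p, diag(sigma_p^2)), summed over latent dimensions d:

    \mathrm{KL}(q \,\|\, p) = \tfrac{1}{2} \sum_{d} \left[ \log \frac{\sigma_{p,d}^{2}}{\sigma_{q,d}^{2}} + \frac{\sigma_{q,d}^{2} + (\mu_{q,d} - \mu_{p,d})^{2}}{\sigma_{p,d}^{2}} - 1 \right]

A quick sanity-check sketch of both objectives (batch size, latent dim, and modality names here are illustrative, not part of the package):

    import torch
    from univi.objectives import kl_diag_gaussians, symmetric_alignment_loss

    # Toy batch of 8 cells with a 32-dim latent space.
    mu_q, logvar_q = torch.randn(8, 32), torch.zeros(8, 32)
    mu_p, logvar_p = torch.zeros(8, 32), torch.zeros(8, 32)
    kl = kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p)  # per-sample KL, shape (8,)

    mus = {"rna": torch.randn(8, 32), "adt": torch.randn(8, 32)}
    align = symmetric_alignment_loss(mus)  # per-sample alignment, shape (8,)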
univi/pipeline.py ADDED
@@ -0,0 +1,194 @@
+ # univi/pipeline.py
+ from __future__ import annotations
+
+ from typing import Any, Dict, Optional, Tuple
+
+ import os
+ import numpy as np
+ import torch
+ import anndata as ad
+ from anndata import AnnData
+
+ from .config import UniVIConfig, ModalityConfig
+ from .models import UniVIMultiModalVAE
+ from .data import align_paired_obs_names, infer_input_dim, _get_matrix
+ from .utils.io import load_config, load_checkpoint
+ from .utils.seed import set_seed
+
+
+ def load_anndata_dict(cfg: Dict[str, Any], *, data_root: Optional[str] = None) -> Dict[str, AnnData]:
+     adata_dict: Dict[str, AnnData] = {}
+     for m in cfg["data"]["modalities"]:
+         name = m["name"]
+         path = m["h5ad_path"]
+         if data_root is not None and not os.path.isabs(path):
+             path = os.path.join(data_root, path)
+         adata_dict[name] = ad.read_h5ad(path)
+     return adata_dict
+
+
+ def build_univi_from_config(
+     cfg: Dict[str, Any],
+     adata_dict: Dict[str, AnnData],
+ ) -> Tuple[UniVIMultiModalVAE, UniVIConfig, Dict[str, Optional[str]], Dict[str, str]]:
+     """Build the UniVI model + per-modality selectors from a loaded config and AnnData dict."""
+     mcfgs = cfg["data"]["modalities"]
+     model_cfg = cfg.get("model", {})
+
+     # Per-modality selectors: which layer / matrix each encoder reads.
+     layer_by = {m["name"]: m.get("layer", None) for m in mcfgs}
+     xkey_by = {m["name"]: m.get("X_key", "X") for m in mcfgs}
+
+     modalities = []
+     for m in mcfgs:
+         name = m["name"]
+         input_dim = infer_input_dim(adata_dict[name], layer=layer_by[name], X_key=xkey_by[name])
+
+         hidden_default = model_cfg.get("hidden_dims_default", [256, 128])
+         enc = m.get("encoder_hidden", m.get("hidden_dims", hidden_default))
+         # Default decoder: mirror of the encoder hidden dims.
+         dec = m.get("decoder_hidden", list(enc)[::-1])
+         modalities.append(
+             ModalityConfig(
+                 name=name,
+                 input_dim=int(input_dim),
+                 encoder_hidden=list(enc),
+                 decoder_hidden=list(dec),
+                 likelihood=m.get("likelihood", "gaussian"),
+             )
+         )
+
+     univi_cfg = UniVIConfig(
+         latent_dim=int(model_cfg.get("latent_dim", 32)),
+         modalities=modalities,
+         beta=float(model_cfg.get("beta", 1.0)),
+         gamma=float(model_cfg.get("gamma", 1.0)),
+         encoder_dropout=float(model_cfg.get("encoder_dropout", model_cfg.get("dropout", 0.0))),
+         decoder_dropout=float(model_cfg.get("decoder_dropout", model_cfg.get("dropout", 0.0))),
+         encoder_batchnorm=bool(model_cfg.get("encoder_batchnorm", model_cfg.get("batchnorm", True))),
+         decoder_batchnorm=bool(model_cfg.get("decoder_batchnorm", False)),
+         kl_anneal_start=int(model_cfg.get("kl_anneal_start", 0)),
+         kl_anneal_end=int(model_cfg.get("kl_anneal_end", 0)),
+         align_anneal_start=int(model_cfg.get("align_anneal_start", 0)),
+         align_anneal_end=int(model_cfg.get("align_anneal_end", 0)),
+     )
+
+     model = UniVIMultiModalVAE(
+         univi_cfg,
+         loss_mode=model_cfg.get("loss_mode", "v2"),
+         v1_recon=model_cfg.get("v1_recon", "cross"),
+         v1_recon_mix=float(model_cfg.get("v1_recon_mix", 0.0)),
+         normalize_v1_terms=bool(model_cfg.get("normalize_v1_terms", True)),
+     )
+     return model, univi_cfg, layer_by, xkey_by
+
+
+ @torch.no_grad()
+ def encode_latents_single_modality(
+     model: UniVIMultiModalVAE,
+     adata: AnnData,
+     modality: str,
+     *,
+     layer: Optional[str] = None,
+     X_key: str = "X",
+     batch_size: int = 512,
+     device: str = "cpu",
+ ) -> np.ndarray:
+     """Encode *one modality only* (for unimodal / independent-dataset experiments)."""
+     model.eval()
+     model.to(device)
+
+     X = _get_matrix(adata, layer=layer, X_key=X_key)
+     # Materialize sparse rows batch by batch instead of densifying the full matrix.
+     n = X.shape[0]
+     zs = []
+     for start in range(0, n, batch_size):
+         end = min(n, start + batch_size)
+         xb = X[start:end]
+         if hasattr(xb, "toarray"):
+             xb = xb.toarray()
+         xb = torch.as_tensor(np.asarray(xb), dtype=torch.float32, device=device)
+         mu_dict, logvar_dict = model.encode_modalities({modality: xb})
+         mu_z, _ = model.mixture_of_experts(mu_dict, logvar_dict)
+         zs.append(mu_z.detach().cpu().numpy())
+     return np.vstack(zs)
+
+
+ @torch.no_grad()
+ def encode_latents_paired(
+     model: UniVIMultiModalVAE,
+     adata_dict: Dict[str, AnnData],
+     *,
+     layer_by: Optional[Dict[str, Optional[str]]] = None,
+     xkey_by: Optional[Dict[str, str]] = None,
+     batch_size: int = 512,
+     device: str = "cpu",
+     fused: bool = True,
+ ) -> Dict[str, np.ndarray]:
+     """Encode paired cells for each modality and, optionally, the fused MoE latent.
+
+     Returns a dict with one key per modality (that encoder's mu) and optionally 'fused'.
+     """
+     model.eval()
+     model.to(device)
+
+     names = list(adata_dict.keys())
+     n = adata_dict[names[0]].n_obs
+     # Paired encoding requires identical cell order across modalities.
+     for nm in names[1:]:
+         if not np.array_equal(adata_dict[nm].obs_names.values, adata_dict[names[0]].obs_names.values):
+             raise ValueError(f"obs_names mismatch between {names[0]} and {nm}")
+
+     out = {nm: [] for nm in names}
+     if fused:
+         out["fused"] = []
+
+     for start in range(0, n, batch_size):
+         end = min(n, start + batch_size)
+         x_dict = {}
+         for nm, adata in adata_dict.items():
+             layer = None if layer_by is None else layer_by.get(nm, None)
+             xkey = "X" if xkey_by is None else xkey_by.get(nm, "X")
+             X = _get_matrix(adata, layer=layer, X_key=xkey)[start:end]
+             if hasattr(X, "toarray"):
+                 X = X.toarray()
+             x_dict[nm] = torch.as_tensor(np.asarray(X), dtype=torch.float32, device=device)
+
+         mu_dict, logvar_dict = model.encode_modalities(x_dict)
+         # Per-modality posterior means.
+         for nm in names:
+             out[nm].append(mu_dict[nm].detach().cpu().numpy())
+         if fused:
+             mu_z, _ = model.mixture_of_experts(mu_dict, logvar_dict)
+             out["fused"].append(mu_z.detach().cpu().numpy())
+
+     for k in list(out.keys()):
+         out[k] = np.vstack(out[k])
+     return out
+
+
+ def load_model_and_data(
+     config_path: str,
+     *,
+     checkpoint_path: Optional[str] = None,
+     data_root: Optional[str] = None,
+     device: str = "cpu",
+     align_obs: bool = True,
+ ) -> Tuple[Dict[str, Any], Dict[str, AnnData], UniVIMultiModalVAE, Dict[str, Optional[str]], Dict[str, str]]:
+     cfg = load_config(config_path)
+     seed = int(cfg.get("training", {}).get("seed", 0))
+     set_seed(seed, deterministic=False)
+
+     adata_dict = load_anndata_dict(cfg, data_root=data_root)
+     if align_obs:
+         adata_dict = align_paired_obs_names(adata_dict)
+
+     model, univi_cfg, layer_by, xkey_by = build_univi_from_config(cfg, adata_dict)
+     if checkpoint_path is not None:
+         ck = load_checkpoint(checkpoint_path)
+         # Checkpoints may store weights under 'model_state' or 'state_dict', or be a bare state dict.
+         state = ck.get("model_state", ck.get("state_dict", ck))
+         model.load_state_dict(state, strict=False)
+
+     model.to(device)
+     return cfg, adata_dict, model, layer_by, xkey_by
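
A minimal end-to-end sketch of the pipeline helpers defined above (the config and checkpoint paths are hypothetical placeholders):

    from univi.pipeline import load_model_and_data, encode_latents_paired

    cfg, adata_dict, model, layer_by, xkey_by = load_model_and_data(
        "configs/example.yaml",           # hypothetical config path
        checkpoint_path="runs/model.pt",  # hypothetical checkpoint path
        device="cpu",
    )
    latents = encode_latents_paired(
        model, adata_dict, layer_by=layer_by, xkey_by=xkey_by, fused=True
    )
    for adata in adata_dict.values():
        adata.obsm["X_univi"] = latents["fused"]  # shared embedding, one row per paired cell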
univi/plotting.py ADDED
@@ -0,0 +1,124 @@
+ # univi/plotting.py
+
+ from __future__ import annotations
+ from typing import Dict, Optional, Sequence, List
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import scanpy as sc
+ import seaborn as sns
+ from anndata import AnnData
+
+
+ def set_style(font_scale: float = 1.25, dpi: int = 150) -> None:
+     """Readable, manuscript-friendly plotting defaults."""
+     import matplotlib as mpl
+
+     base = 10.0 * float(font_scale)
+     mpl.rcParams.update({
+         "figure.dpi": int(dpi),
+         "savefig.dpi": 300,
+         "font.size": base,
+         "axes.titlesize": base * 1.2,
+         "axes.labelsize": base * 1.1,
+         "xtick.labelsize": base * 0.95,
+         "ytick.labelsize": base * 0.95,
+         "legend.fontsize": base * 0.95,
+         "pdf.fonttype": 42,
+         "ps.fonttype": 42,
+     })
+     sc.settings.set_figure_params(dpi=int(dpi), dpi_save=300, frameon=False)
+
+
+ def umap_single_adata(
+     adata_obj: AnnData,
+     obsm_key: str = "X_univi",
+     color: Optional[Sequence[str]] = None,
+     savepath: Optional[str] = None,
+     n_neighbors: int = 30,
+     random_state: int = 0,
+ ) -> None:
+     if obsm_key not in adata_obj.obsm:
+         raise KeyError("Missing obsm[%r]. Available: %s" % (obsm_key, list(adata_obj.obsm.keys())))
+
+     # Compute the neighbor graph / UMAP only if missing (assumes any existing
+     # graph was built on the same representation).
+     if "neighbors" not in adata_obj.uns:
+         sc.pp.neighbors(adata_obj, use_rep=obsm_key, n_neighbors=int(n_neighbors))
+     if "X_umap" not in adata_obj.obsm:
+         sc.tl.umap(adata_obj, random_state=int(random_state))
+
+     sc.pl.umap(adata_obj, color=list(color) if color else None, show=False)
+
+     if savepath is not None:
+         plt.savefig(savepath, dpi=300, bbox_inches="tight")
+     plt.close()
+
+
+ def umap_by_modality(
+     adata_dict: Dict[str, AnnData],
+     obsm_key: str = "X_univi",
+     color: str = "celltype",
+     savepath: Optional[str] = None,
+     n_neighbors: int = 30,
+     random_state: int = 0,
+ ) -> None:
+     """
+     Concatenate the per-modality AnnData objects; each input is expected to
+     already carry the same obsm_key embedding (add it before calling this).
+     """
+     annotated: List[AnnData] = []
+     for mod, a in adata_dict.items():
+         aa = a.copy()
+         aa.obs["univi_modality"] = str(mod)
+         annotated.append(aa)
+
+     combined = annotated[0].concatenate(
+         *annotated[1:],
+         batch_key="univi_source",
+         batch_categories=list(adata_dict.keys()),
+         index_unique="-",
+         join="outer",
+     )
+
+     # Carry embeddings over if needed (common pitfall: concatenate can drop obsm keys).
+     if obsm_key not in combined.obsm:
+         # Rebuild from the parts; row order matches the concatenation order.
+         try:
+             Zs = [adata_dict[m].obsm[obsm_key] for m in adata_dict.keys()]
+             combined.obsm[obsm_key] = np.vstack(Zs)
+         except Exception:
+             raise KeyError("combined is missing obsm[%r] after concatenation; add it manually." % obsm_key)
+
+     umap_single_adata(
+         combined,
+         obsm_key=obsm_key,
+         color=[color, "univi_modality"],
+         savepath=savepath,
+         n_neighbors=n_neighbors,
+         random_state=random_state,
+     )
+
+
+ def plot_confusion_matrix(
+     cm: np.ndarray,
+     labels: np.ndarray,
+     title: str = "Label transfer (source \u2192 target)",
+     savepath: Optional[str] = None,
+ ) -> None:
+     plt.figure(figsize=(6, 5))
+     sns.heatmap(
+         cm,
+         annot=False,
+         xticklabels=labels,
+         yticklabels=labels,
+         cmap="viridis",
+     )
+     plt.xlabel("Predicted")
+     plt.ylabel("True")
+     plt.title(title)
+     plt.tight_layout()
+     if savepath is not None:
+         plt.savefig(savepath, dpi=300, bbox_inches="tight")
+     plt.close()
+
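
A short usage sketch for the plotting helpers, assuming each AnnData already carries the obsm["X_univi"] embedding (e.g. from encode_latents_paired above; the "celltype" obs column is an assumption about your annotation):

    from univi.plotting import set_style, umap_by_modality

    set_style(font_scale=1.25, dpi=150)
    umap_by_modality(
        adata_dict,            # e.g. {"rna": AnnData, "adt": AnnData}
        obsm_key="X_univi",
        color="celltype",      # hypothetical obs column
        savepath="umap_by_modality.png",
    )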