univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
@@ -0,0 +1,908 @@
Metadata-Version: 2.4
Name: univi
Version: 0.3.4
Summary: UniVI: a scalable multi-modal variational autoencoder toolkit for seamless integration and analysis of multimodal single-cell data.
Author-email: "Andrew J. Ashford" <ashforda@ohsu.edu>
License: MIT License

Copyright (c) 2025 Andrew J. Ashford

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the “Software”), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Project-URL: Homepage, https://github.com/Ashford-A/UniVI
Project-URL: Repository, https://github.com/Ashford-A/UniVI
Project-URL: Bug Tracker, https://github.com/Ashford-A/UniVI/issues
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3 :: Only
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.26
Requires-Dist: scipy>=1.11
Requires-Dist: pandas>=2.1
Requires-Dist: anndata>=0.10
Requires-Dist: scanpy>=1.11
Requires-Dist: torch>=2.2
Requires-Dist: scikit-learn>=1.3
Requires-Dist: h5py>=3.10
Requires-Dist: pyyaml>=6.0
Requires-Dist: matplotlib>=3.8
Requires-Dist: seaborn>=0.13
Requires-Dist: igraph>=0.11
Requires-Dist: leidenalg>=0.10
Requires-Dist: tqdm>=4.66
Requires-Dist: openpyxl>=3.1
Provides-Extra: bench
Requires-Dist: harmonypy>=0.0.9; extra == "bench"
Dynamic: license-file

# UniVI

[PyPI](https://pypi.org/project/univi/)
[Python versions](https://pypi.org/project/univi/)

<picture>
  <!-- Dark mode (GitHub supports this; PyPI may ignore <source>) -->
  <source media="(prefers-color-scheme: dark)"
          srcset="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.3.4/assets/figures/univi_overview_dark.png">
  <!-- Light mode / fallback (works on GitHub + PyPI) -->
  <img src="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.3.4/assets/figures/univi_overview_light.png"
       alt="UniVI overview and evaluation roadmap"
       width="100%">
</picture>

**UniVI** is a **multi-modal variational autoencoder (VAE)** framework for aligning and integrating single-cell modalities such as RNA, ADT (CITE-seq), and ATAC.

It’s designed for experiments like:

- **Joint embedding** of paired multimodal data (CITE-seq, Multiome, TEA-seq)
- **Zero-shot projection** of external unimodal cohorts into a paired “bridge” latent
- **Cross-modal reconstruction / imputation** (RNA→ADT, ATAC→RNA, etc.)
- **Denoising** via learned generative decoders
- **Evaluation** (FOSCTTM, modality mixing, label transfer, feature recovery)
- **Optional supervised heads** for harmonized annotation and domain confusion
- **Optional transformer encoders** (per-modality and/or fused multimodal transformer posterior)
- **Token-level hooks** for interpretability (top-k indices; optional attention maps if enabled)

---

## Preprint

If you use UniVI in your work, please cite:

> Ashford AJ, Enright T, Nikolova O, Demir E.
> **Unifying Multimodal Single-Cell Data Using a Mixture of Experts β-Variational Autoencoder-Based Framework.**
> *bioRxiv* (2025). doi: [10.1101/2025.02.28.640429](https://www.biorxiv.org/content/10.1101/2025.02.28.640429v1.full)

```bibtex
@article{Ashford2025UniVI,
  title   = {Unifying Multimodal Single-Cell Data Using a Mixture of Experts β-Variational Autoencoder-Based Framework},
  author  = {Ashford, Andrew J. and Enright, Trevor and Nikolova, Olga and Demir, Emek},
  journal = {bioRxiv},
  year    = {2025},
  doi     = {10.1101/2025.02.28.640429},
  url     = {https://www.biorxiv.org/content/10.1101/2025.02.28.640429v1}
}
```

---

## License

MIT License — see `LICENSE`.

---

## Repository structure

```text
UniVI/
├── README.md                      # Project overview, installation, quickstart
├── LICENSE                        # MIT license text file
├── pyproject.toml                 # Python packaging config (pip / PyPI)
├── assets/                        # Static assets used by README/docs
│   └── figures/                   # Schematic figure(s) for repository front page
├── conda.recipe/                  # Conda build recipe (for conda-build)
│   └── meta.yaml
├── envs/                          # Example conda environments
│   ├── UniVI_working_environment.yml
│   ├── UniVI_working_environment_v2_full.yml
│   ├── UniVI_working_environment_v2_minimal.yml
│   └── univi_env.yml              # Recommended env (CUDA-friendly)
├── data/                          # Small example data notes (datasets are typically external)
│   └── README.md                  # Notes on data sources / formats
├── notebooks/                     # Jupyter notebooks (demos & benchmarks)
│   ├── UniVI_CITE-seq_*.ipynb
│   ├── UniVI_10x_Multiome_*.ipynb
│   └── UniVI_TEA-seq_*.ipynb
├── parameter_files/               # JSON configs for model + training + data selectors
│   ├── defaults_*.json            # Default configs (per experiment)
│   └── params_*.json              # Example “named” configs (RNA, ADT, ATAC, etc.)
├── scripts/                       # Reproducible entry points (revision-friendly)
│   ├── train_univi.py             # Train UniVI from a parameter JSON
│   ├── evaluate_univi.py          # Evaluate trained models (FOSCTTM, label transfer, etc.)
│   ├── benchmark_univi_citeseq.py # CITE-seq-specific benchmarking script
│   ├── run_multiome_hparam_search.py
│   ├── run_frequency_robustness.py         # Composition/frequency mismatch robustness
│   ├── run_do_not_integrate_detection.py   # “Do-not-integrate” unmatched population demo
│   ├── run_benchmarks.py          # Unified wrapper (includes optional Harmony baseline)
│   └── revision_reproduce_all.sh  # One-click: reproduces figures + supplemental tables
└── univi/                         # UniVI Python package (importable as `import univi`)
    ├── __init__.py                # Package exports and __version__
    ├── __main__.py                # Enables: `python -m univi ...`
    ├── cli.py                     # Minimal CLI (e.g., export-s1, encode)
    ├── pipeline.py                # Config-driven model+data loading; latent encoding helpers
    ├── diagnostics.py             # Exports Supplemental_Table_S1.xlsx (env + hparams + dataset stats)
    ├── config.py                  # Config dataclasses (UniVIConfig, ModalityConfig, TrainingConfig)
    ├── data.py                    # Dataset wrappers + matrix selectors (layer/X_key, obsm support)
    ├── evaluation.py              # Metrics (FOSCTTM, mixing, label transfer, feature recovery)
    ├── matching.py                # Modality matching / alignment helpers
    ├── objectives.py              # Losses (ELBO variants, KL/alignment annealing, etc.)
    ├── plotting.py                # Plotting helpers + consistent style defaults
    ├── trainer.py                 # UniVITrainer: training loop, logging, checkpointing
    ├── interpretability.py        # Helpers for transformer token weight interpretability
    ├── figures/                   # Package-internal figure assets (placeholder)
    │   └── .gitkeep
    ├── models/                    # VAE architectures + building blocks
    │   ├── __init__.py
    │   ├── mlp.py                 # Shared MLP building blocks
    │   ├── encoders.py            # Modality encoders (MLP + transformer + fused transformer)
    │   ├── decoders.py            # Likelihood-specific decoders (NB, ZINB, Gaussian, etc.)
    │   ├── transformer.py         # Transformer blocks + encoder (+ optional attn bias support)
    │   ├── tokenizers.py          # Tokenization configs/helpers (top-k / patch)
    │   └── univi.py               # Core UniVI multi-modal VAE
    ├── hyperparam_optimization/   # Hyperparameter search scripts
    │   ├── __init__.py
    │   ├── common.py
    │   ├── run_adt_hparam_search.py
    │   ├── run_atac_hparam_search.py
    │   ├── run_citeseq_hparam_search.py
    │   ├── run_multiome_hparam_search.py
    │   ├── run_rna_hparam_search.py
    │   └── run_teaseq_hparam_search.py
    └── utils/                     # General utilities
        ├── __init__.py
        ├── io.py                  # I/O helpers (AnnData, configs, checkpoints)
        ├── logging.py             # Logging configuration / progress reporting
        ├── seed.py                # Reproducibility helpers (seeding RNGs)
        ├── stats.py               # Small statistical helpers / transforms
        └── torch_utils.py         # PyTorch utilities (device, tensor helpers)
```

---

## Generated outputs

Most entry-point scripts write results into a user-specified output directory (commonly `runs/`), which is not tracked in git.

A typical `runs/` folder produced by `scripts/revision_reproduce_all.sh` looks like:

```text
runs/
└── <run_name>/                        # user-chosen run name (often includes dataset + date)
    ├── checkpoints/                   # model/trainer state for resuming or export
    │   ├── univi_checkpoint.pt        # primary checkpoint (model + optimizer + config, if enabled)
    │   └── best.pt                    # optional: best-val checkpoint (if early stopping enabled)
    ├── eval/                          # evaluation summaries and derived plots
    │   ├── metrics.json               # machine-readable metrics summary
    │   ├── metrics.csv                # flat table for quick comparisons
    │   └── plots/                     # optional: UMAPs, heatmaps, and benchmark figures
    ├── embeddings/                    # optional: exported latents for downstream analysis
    │   ├── mu_z.npy                   # fused mean embedding (cells x latent_dim)
    │   ├── modality_mu/               # per-modality embeddings q(z|x_m)
    │   │   ├── rna.npy
    │   │   ├── adt.npy
    │   │   └── atac.npy
    │   └── obs_names.txt              # row order for embeddings (safe joins)
    ├── reconstructions/               # optional: recon and cross-recon exports
    │   ├── rna_from_rna.npy           # denoised reconstruction
    │   ├── adt_from_adt.npy
    │   ├── adt_from_rna.npy           # cross-modal imputation example
    │   └── rna_from_atac.npy
    ├── robustness/                    # robustness experiments (frequency mismatch, DnI, etc.)
    │   ├── frequency_perturbation_results.csv
    │   ├── frequency_perturbation_plot.png
    │   ├── frequency_perturbation_plot.pdf
    │   ├── do_not_integrate_summary.csv
    │   ├── do_not_integrate_plot.png
    │   └── do_not_integrate_plot.pdf
    ├── benchmarks/                    # baseline comparisons (optionally includes Harmony, etc.)
    │   ├── results.csv
    │   ├── results.png
    │   └── results.pdf
    ├── tables/
    │   └── Supplemental_Table_S1.xlsx # environment + hparams + dataset statistics snapshot
    └── logs/
        ├── train.log                  # training log (stdout/stderr capture)
        └── history.csv                # per-epoch train/val traces (if enabled)
```

(Exact subfolders vary by script and flags; the layout above shows the common outputs across the pipeline.)
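
As a quick illustration of consuming these exports (a minimal sketch; `<run_name>` and the exact files present depend on your run flags), the saved latents can be re-attached to an AnnData using `obs_names.txt` for a safe join:

```python
import numpy as np
import scanpy as sc

run_dir = "runs/<run_name>"
Z = np.load(f"{run_dir}/embeddings/mu_z.npy")                       # (cells, latent_dim)
obs_names = np.loadtxt(f"{run_dir}/embeddings/obs_names.txt", dtype=str)

adata = sc.read_h5ad("path/to/rna_citeseq.h5ad")
adata = adata[obs_names].copy()   # reorder cells to match the embedding rows
adata.obsm["X_univi"] = Z
```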

---

## Installation

### Install via PyPI

```bash
pip install univi
```

> **Note:** UniVI requires `torch`. If `import torch` fails, install PyTorch for your platform/CUDA from:
> [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)

### Development install (from source)

```bash
git clone https://github.com/Ashford-A/UniVI.git
cd UniVI

conda env create -f envs/univi_env.yml
conda activate univi_env

pip install -e .
```

### (Optional) Install via conda / mamba

```bash
conda install -c conda-forge univi
# or
mamba install -c conda-forge univi
```

UniVI is also installable from a custom channel:

```bash
conda install ashford-a::univi
# or
mamba install ashford-a::univi
```

---

## Data expectations (high-level)

UniVI expects **per-modality AnnData** objects with matching cells (paired data, or consistently paired across modalities).

Typical expectations:

* Each modality is an `AnnData` with the same `obs_names` (same cells, same order)
* Raw counts often live in `.layers["counts"]`
* A processed training representation lives in `.X` (or `.obsm["X_*"]` for ATAC LSI)
* Decoder likelihoods should roughly match the training representation:

  * counts-like → `nb` / `zinb` / `poisson`
  * continuous → `gaussian` / `mse`

See `notebooks/` for end-to-end preprocessing examples.
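
For orientation, a minimal preprocessing sketch that satisfies these expectations might look like the following (standard Scanpy calls; your pipeline may differ, and the notebooks remain the authoritative reference):

```python
import scanpy as sc

rna = sc.read_h5ad("path/to/rna_citeseq.h5ad")

# Keep raw counts in a layer for count likelihoods (nb / zinb / poisson) ...
rna.layers["counts"] = rna.X.copy()

# ... and put a processed representation in .X (e.g., for gaussian/mse training)
sc.pp.normalize_total(rna, target_sum=1e4)
sc.pp.log1p(rna)
```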

---

## Training objectives (v1 vs v2/lite)

UniVI supports two main training regimes:

* **UniVI v1 (“paper”)**
  Per-modality posteriors + flexible reconstruction scheme (cross/self/avg) + posterior alignment across modalities.

* **UniVI v2 / lite**
  A fused posterior (precision-weighted MoE/PoE-style by default; optional fused transformer) + per-modality recon + β·KL + γ·alignment.
  Convenient for 3+ modalities and “loosely paired” settings.

You choose via `loss_mode` at model construction (Python) or config JSON (CLI scripts). A schematic form of the v2 objective is sketched below.
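
Schematically (weighting and annealing details live in `univi/objectives.py`; this is a sketch of the terms named above, not the exact implementation), the v2 loss combines per-modality reconstruction with a β-weighted KL term and a γ-weighted alignment penalty:

```math
\mathcal{L}_{\mathrm{v2}}
  \;=\; \sum_{m} \mathbb{E}_{q(z \mid x)}\!\left[-\log p_\theta(x_m \mid z)\right]
  \;+\; \beta \, \mathrm{KL}\!\left(q(z \mid x) \,\|\, p(z)\right)
  \;+\; \gamma \, \mathcal{L}_{\mathrm{align}}
```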

---

## Quickstart (Python / Jupyter)

Below is a minimal paired **CITE-seq (RNA + ADT)** example using `MultiModalDataset` + `UniVITrainer`.

```python
import numpy as np
import scanpy as sc
import torch
from torch.utils.data import DataLoader, Subset

from univi import UniVIMultiModalVAE, ModalityConfig, UniVIConfig, TrainingConfig
from univi.data import MultiModalDataset, align_paired_obs_names
from univi.trainer import UniVITrainer
```

### 1) Load paired AnnData

```python
rna = sc.read_h5ad("path/to/rna_citeseq.h5ad")
adt = sc.read_h5ad("path/to/adt_citeseq.h5ad")

adata_dict = {"rna": rna, "adt": adt}
adata_dict = align_paired_obs_names(adata_dict)  # ensures same obs_names/order
```

### 2) Dataset + dataloaders

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

dataset = MultiModalDataset(
    adata_dict=adata_dict,
    X_key="X",    # uses .X by default
    device=None,  # dataset returns CPU tensors; model moves to GPU
)

n = rna.n_obs
idx = np.arange(n)
rng = np.random.default_rng(0)
rng.shuffle(idx)
split = int(0.8 * n)
train_idx, val_idx = idx[:split], idx[split:]

train_ds = Subset(dataset, train_idx)
val_ds = Subset(dataset, val_idx)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=0)
```

### 3) Config + model

```python
univi_cfg = UniVIConfig(
    latent_dim=40,
    beta=1.5,
    gamma=2.5,
    encoder_dropout=0.1,
    decoder_dropout=0.0,
    modalities=[
        ModalityConfig("rna", rna.n_vars, [512, 256, 128], [128, 256, 512], likelihood="nb"),
        ModalityConfig("adt", adt.n_vars, [128, 64], [64, 128], likelihood="nb"),
    ],
)

train_cfg = TrainingConfig(
    n_epochs=1000,
    batch_size=256,
    lr=1e-3,
    weight_decay=1e-4,
    device=device,
    log_every=20,
    grad_clip=5.0,
    early_stopping=True,
    patience=50,
)

# v1 (paper)
model = UniVIMultiModalVAE(
    univi_cfg,
    loss_mode="v1",
    v1_recon="avg",
    normalize_v1_terms=True,
).to(device)

# Or: v2/lite
# model = UniVIMultiModalVAE(univi_cfg, loss_mode="v2").to(device)
```

### 4) Train

```python
trainer = UniVITrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    train_cfg=train_cfg,
    device=device,
)

history = trainer.fit()
```

---

## Mixed precision (AMP)

AMP (automatic mixed precision) can reduce VRAM usage and speed up training on GPUs by running selected ops in lower precision (fp16 or bf16) while keeping numerically sensitive parts in fp32.

If your trainer supports AMP flags, prefer bf16 where available. If using fp16, gradient scaling is typically used internally to avoid underflow.
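
For reference, this is what a bare-bones bf16 AMP step looks like in plain PyTorch (a generic sketch of the mechanism, not UniVI's trainer internals; assumes the Quickstart `model`/`train_loader` plus an `optimizer` you have constructed):

```python
import torch

# bf16 autocast needs no GradScaler; fp16 typically pairs with one.
for batch in train_loader:
    optimizer.zero_grad(set_to_none=True)
    x_dict = {k: v.to("cuda") for k, v in batch.items()}
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = model(x_dict)   # forward runs selected ops in bf16
        loss = out["loss"]
    loss.backward()           # backward outside autocast; grads stay fp32
    optimizer.step()
```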

---

## Checkpointing and resuming

Training is often run on clusters, so checkpoints are treated as first-class outputs.

Typical checkpoints contain:

* model weights
* optimizer state (for faithful resumption)
* training config / model config (for reproducibility)
* optional AMP scaler state (when using fp16 AMP)

See `univi/utils/io.py` for the exact checkpoint read/write helpers used by the trainer; a generic save/resume pattern is sketched below.
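
In plain PyTorch terms (a minimal sketch of the same idea; the actual on-disk format is defined by `univi/utils/io.py`, and the path here is illustrative), saving and restoring such a checkpoint looks like:

```python
import torch

# Save: bundle everything needed to resume faithfully
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
        "config": univi_cfg,  # model/training config for reproducibility
    },
    "runs/my_run/checkpoints/univi_checkpoint.pt",
)

# Resume
ckpt = torch.load("runs/my_run/checkpoints/univi_checkpoint.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"] + 1
```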

---

## Classification (built-in heads)

UniVI supports **in-model supervised classification heads** (a single “legacy” label head and/or multi-head auxiliary decoders). This is useful for:

* harmonized cell-type annotation (e.g., bridge → projected cohorts)
* batch/tech/patient prediction (sanity checks, confounding)
* adversarial domain confusion via gradient reversal (GRL)
* multi-task setups (e.g., celltype + patient + mutation flags)

### How it works

* Heads are configured via `UniVIConfig.class_heads` using `ClassHeadConfig`.
* Training targets are passed as `y`, a **dict mapping head name → integer class indices** with shape `(B,)`.
* Unlabeled entries should use `ignore_index` (default `-1`) and are masked out automatically.
* Each head can be delayed with `warmup` and weighted with `loss_weight`.
* Set `adversarial=True` for GRL heads (domain confusion).

### 1) Add heads in the config

```python
from univi.config import ClassHeadConfig

univi_cfg = UniVIConfig(
    latent_dim=40,
    beta=1.5,
    gamma=2.5,
    modalities=[
        ModalityConfig("rna", rna.n_vars, [512, 256, 128], [128, 256, 512], likelihood="nb"),
        ModalityConfig("adt", adt.n_vars, [128, 64], [64, 128], likelihood="nb"),
    ],
    class_heads=[
        ClassHeadConfig(
            name="celltype",
            n_classes=int(rna.obs["celltype"].astype("category").cat.categories.size),
            loss_weight=1.0,
            ignore_index=-1,
            from_mu=True,  # classify from mu_z (more stable)
            warmup=0,
        ),
        ClassHeadConfig(
            name="batch",
            n_classes=int(rna.obs["batch"].astype("category").cat.categories.size),
            loss_weight=0.2,
            ignore_index=-1,
            from_mu=True,
            warmup=10,
            adversarial=True,  # GRL head (domain confusion)
            adv_lambda=1.0,
        ),
    ],
)
```

Optional: attach readable label names (for your own decoding later):

```python
model.set_head_label_names("celltype", list(rna.obs["celltype"].astype("category").cat.categories))
model.set_head_label_names("batch", list(rna.obs["batch"].astype("category").cat.categories))
```

### 2) Pass `y` to the model during training

Example pattern (construct labels from arrays aligned to dataset order):

```python
celltype_codes = rna.obs["celltype"].astype("category").cat.codes.to_numpy()
batch_codes = rna.obs["batch"].astype("category").cat.codes.to_numpy()

# batch_idx: integer indices of the cells in the current minibatch (dataset order)
y = {
    "celltype": torch.tensor(celltype_codes[batch_idx], device=device),
    "batch": torch.tensor(batch_codes[batch_idx], device=device),
}

out = model(x_dict, epoch=epoch, y=y)
loss = out["loss"]
loss.backward()
```

When labels are provided, the forward output can include:

* `out["head_logits"]`: dict of logits `(B, n_classes)` per head
* `out["head_losses"]`: mean CE per head (masked by `ignore_index`)
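
From these outputs you can, for example, track a masked training accuracy per head (a small sketch using only the fields described above):

```python
logits = out["head_logits"]["celltype"]  # (B, n_classes)
target = y["celltype"]                   # (B,), with -1 for unlabeled cells
mask = target != -1                      # drop ignore_index entries, as the loss does
if mask.any():
    pred = logits[mask].argmax(dim=1)
    acc = (pred == target[mask]).float().mean().item()
    print(f"celltype head accuracy (labeled cells): {acc:.3f}")
```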

### 3) Predict heads after training

```python
model.eval()
batch = next(iter(val_loader))
x_dict = {k: v.to(device) for k, v in batch.items()}

with torch.no_grad():
    probs = model.predict_heads(x_dict, return_probs=True)

for head_name, P in probs.items():
    print(head_name, P.shape)  # (B, n_classes)
```

To inspect which heads exist + their settings:

```python
meta = model.get_classification_meta()
print(meta)
```

---

## After training: what you can do with a trained UniVI model

UniVI isn’t just “map to latent”. With a trained model you can typically:

* **Encode modality-specific posteriors** `q(z|x_rna)`, `q(z|x_adt)`, …
* **Encode a fused posterior** (MoE/PoE by default; optional fused multimodal transformer posterior)
* **Denoise / reconstruct** inputs via the learned decoders
* **Cross-reconstruct / impute** across modalities (RNA→ADT, ATAC→RNA, etc.)
* **Evaluate alignment** (FOSCTTM, Recall@k, modality mixing, label transfer)
* **Predict supervised targets** via built-in classification heads (if enabled)
* **Inspect uncertainty** via per-modality posterior means/variances
* (Optional) **Inspect transformer token metadata** (top-k indices; attention maps when enabled)

### Fused posterior options

UniVI can produce a fused latent in two ways:

* Default: **precision-weighted MoE/PoE fusion** over per-modality posteriors
* Optional: **fused multimodal transformer posterior** (`fused_encoder_type="multimodal_transformer"`)

In both cases, the standard embedding used for plotting/neighbors is the fused mean:

```python
mu_z, logvar_z, z = model.encode_fused(x_dict, use_mean=True)
```
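
For intuition, precision-weighted fusion of Gaussian per-modality posteriors has the standard closed form below (a schematic of PoE-style weighting; the exact MoE/PoE variant, including any prior expert, is determined by the model config):

```math
\frac{1}{\sigma^{2}} \;=\; \sum_{m} \frac{1}{\sigma_{m}^{2}},
\qquad
\mu \;=\; \sigma^{2} \sum_{m} \frac{\mu_{m}}{\sigma_{m}^{2}}
```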

### 1) Encode embeddings for plotting / neighbors (built-in)

Use `encode_adata` to get either fused (MoE/PoE) or modality-specific latents directly from an AnnData.

```python
import scanpy as sc
import torch
from univi.evaluation import encode_adata

device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()

# Fused latent (MoE/PoE) from a single observed modality
Z_fused = encode_adata(
    model,
    rna,
    modality="rna",
    device=device,
    layer="counts",      # or None to use .X
    latent="moe_mean",   # {"moe_mean","moe_sample","modality_mean","modality_sample"}
)

# Modality-specific latent (projection / diagnostics)
Z_rna = encode_adata(
    model,
    rna,
    modality="rna",
    device=device,
    layer="counts",
    latent="modality_mean",
)

rna.obsm["X_univi_fused"] = Z_fused
rna.obsm["X_univi_rna"] = Z_rna

sc.pp.neighbors(rna, use_rep="X_univi_fused")
sc.tl.umap(rna)
sc.pl.umap(rna, color=["celltype"], frameon=False)
```

### 2) Evaluate paired alignment (FOSCTTM, Recall@k, mixing, label transfer)

`evaluate_alignment` is a figure-ready wrapper. It can take precomputed `Z1/Z2`, or compute embeddings from AnnData via `encode_adata`.

```python
from univi.evaluation import evaluate_alignment

# For paired data, you typically pass modality-specific latents for the two modalities
res = evaluate_alignment(
    model=model,
    adata1=rna,
    adata2=adt,
    mod1="rna",
    mod2="adt",
    device=device,
    layer1="counts",
    layer2="counts",
    latent="modality_mean",
    metric="euclidean",
    recall_ks=(1, 5, 10),
    k_mixing=20,
    k_transfer=15,
    # optional label transfer inputs:
    # labels_source=rna.obs["celltype"].to_numpy(),
    # labels_target=adt.obs["celltype"].to_numpy(),
)

print(res)  # dict includes foscttm(+sem), recall@k(+sem), modality_mixing(+sem), label transfer (optional)
```
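
If you want to sanity-check the headline numbers, FOSCTTM and Recall@k are easy to reproduce from paired embeddings (an illustrative NumPy sketch of the standard definitions; `univi.evaluation` remains the reference implementation):

```python
import numpy as np
from scipy.spatial.distance import cdist

def foscttm(Z1, Z2):
    """Fraction Of Samples Closer Than the True Match (0 = perfect pairing)."""
    D = cdist(Z1, Z2)                               # (n, n) cross-modality distances
    d_true = np.diag(D)                             # distance to each cell's true match
    frac_1to2 = (D < d_true[:, None]).mean(axis=1)  # Z2 cells closer than the match
    frac_2to1 = (D < d_true[None, :]).mean(axis=0)  # and the reverse direction
    return float((frac_1to2.mean() + frac_2to1.mean()) / 2)

def recall_at_k(Z1, Z2, k=10):
    """Fraction of cells whose true match is among their k nearest cross-modal neighbors."""
    D = cdist(Z1, Z2)
    nearest = D.argsort(axis=1)[:, :k]              # k nearest Z2 cells per Z1 cell
    hits = (nearest == np.arange(len(Z1))[:, None]).any(axis=1)
    return float(hits.mean())
```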

### 3) Denoise / reconstruct a modality (built-in)

`denoise_adata` runs “encode modality → decode same modality” and can write to a layer.

```python
from univi.evaluation import denoise_adata

Xhat_rna = denoise_adata(
    model,
    rna,
    modality="rna",
    device=device,
    layer="counts",
    out_layer="univi_denoised",  # writes rna.layers["univi_denoised"]
)

# Quick marker plots from denoised values:
import scanpy as sc
markers = ["TRAC", "NKG7", "LYZ", "MS4A1", "CD79A"]

rna_d = rna.copy()
rna_d.X = rna_d.layers["univi_denoised"]
sc.pl.umap(rna_d, color=markers, frameon=False, title=[f"{g} (denoised)" for g in markers])
```

### 4) Cross-modal reconstruction / imputation (built-in)

`cross_modal_predict` runs “encode src modality → decode target modality” and returns a dense numpy array.

```python
from univi.evaluation import cross_modal_predict

# Example: RNA -> predicted ADT
adt_from_rna = cross_modal_predict(
    model,
    adata_src=rna,
    src_mod="rna",
    tgt_mod="adt",
    device=device,
    layer="counts",
    batch_size=512,
    use_moe=True,  # for src-only input, MoE reduces to the src posterior
)
print(adt_from_rna.shape)  # (cells, adt_features)
```
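
Since the data are paired, one simple feature-recovery check is to correlate each predicted ADT feature with its measured counterpart (an illustrative sketch; assumes measured counts in `adt.layers["counts"]` and `adt_from_rna` from the block above):

```python
import numpy as np
import scipy.sparse as sp

adt_true = adt.layers["counts"]
if sp.issparse(adt_true):
    adt_true = adt_true.toarray()

def per_feature_pearson(X_true, X_pred):
    """Pearson r per feature (column) between measured and predicted values."""
    xt = X_true - X_true.mean(axis=0)
    xp = X_pred - X_pred.mean(axis=0)
    num = (xt * xp).sum(axis=0)
    den = np.sqrt((xt ** 2).sum(axis=0) * (xp ** 2).sum(axis=0)) + 1e-12
    return num / den

r = per_feature_pearson(np.asarray(adt_true, dtype=float), adt_from_rna)
print(f"median per-protein Pearson r: {np.median(r):.3f}")
```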

### 5) Direct model calls (advanced / debugging)

If you want full control (or want posterior means/variances explicitly), call the model methods directly.

```python
import torch

model.eval()
batch = next(iter(val_loader))
x_dict = {k: v.to(device) for k, v in batch.items()}

with torch.no_grad():
    # Per-modality posteriors
    mu_dict, logvar_dict = model.encode_modalities(x_dict)

    # Fused posterior (MoE/PoE or fused transformer, depending on config)
    mu_z, logvar_z, z = model.encode_fused(x_dict, use_mean=True)

    # Decode all modalities from a chosen latent (implementation-dependent keys)
    xhat_dict = model.decode_modalities(mu_z)
```

---

## CLI training (from JSON configs)

Most `scripts/*.py` entry points accept a parameter JSON.

**Train:**

```bash
python scripts/train_univi.py \
  --config parameter_files/defaults_cite_seq_scaled_gaussian_v1.json \
  --outdir saved_models/citeseq_v1_run1 \
  --data-root /path/to/your/data
```

**Evaluate:**

```bash
python scripts/evaluate_univi.py \
  --config parameter_files/defaults_cite_seq_scaled_gaussian_v1.json \
  --model-checkpoint saved_models/citeseq_v1_run1/checkpoints/univi_checkpoint.pt \
  --outdir saved_models/citeseq_v1_run1/eval
```

---

## Optional: Transformer encoders (per-modality)

By default, UniVI uses **MLP encoders** (`encoder_type="mlp"`), and classic workflows work unchanged.

If you want a transformer encoder for a modality, set:

* `encoder_type="transformer"`
* a `TokenizerConfig` (how `(B,F)` becomes `(B,T,D_in)`)
* a `TransformerConfig` (depth/width/pooling)

Example:

```python
from univi.config import TransformerConfig, TokenizerConfig

univi_cfg = UniVIConfig(
    latent_dim=40,
    beta=1.0,
    gamma=1.25,
    modalities=[
        ModalityConfig(
            name="rna",
            input_dim=rna.n_vars,
            encoder_hidden=[512, 256, 128],  # ignored by transformer encoder; kept for compatibility
            decoder_hidden=[128, 256, 512],
            likelihood="gaussian",
            encoder_type="transformer",
            tokenizer=TokenizerConfig(mode="topk_channels", n_tokens=512, channels=("value", "rank", "dropout")),
            transformer=TransformerConfig(
                d_model=256, num_heads=8, num_layers=4,
                dim_feedforward=1024, dropout=0.1, attn_dropout=0.1,
                activation="gelu", pooling="mean",
            ),
        ),
        ModalityConfig(
            name="adt",
            input_dim=adt.n_vars,
            encoder_hidden=[128, 64],
            decoder_hidden=[64, 128],
            likelihood="gaussian",
            encoder_type="mlp",
            tokenizer=TokenizerConfig(mode="topk_scalar", n_tokens=min(32, adt.n_vars)),  # useful for fused encoder
        ),
    ],
)
```

Notes:

* Tokenizers focus attention on the most informative features per cell (top-k) or local structure (patching); the sketch below illustrates the top-k idea.
* Transformer encoders expose optional interpretability hooks (token indices and, when enabled, attention maps).
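
To make the top-k idea concrete (an illustrative sketch of the mechanism, not the `tokenizers.py` implementation): each cell keeps only its k strongest features, and each kept feature becomes one token.

```python
import torch

x = torch.randn(8, 2000)   # (B, F): one batch of per-cell feature vectors
k = 512

values, feat_idx = torch.topk(x, k=k, dim=1)  # per cell: k largest values + their feature indices
tokens = values.unsqueeze(-1)                 # (B, T=k, D_in=1): a single "value" channel
# feat_idx (B, k) is the kind of metadata the token-level
# interpretability hooks can report back (top-k indices).
```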

---

## Optional: ATAC coordinate embeddings and distance attention bias (advanced)

For top-k tokenizers, UniVI can optionally incorporate genomic context:

* **Coordinate embeddings**: chromosome embedding + coordinate MLP per selected feature
* **Distance-based attention bias**: encourages attention between nearby peaks (same chromosome)

### Enable in the tokenizer config (ATAC example)

```python
TokenizerConfig(
    mode="topk_channels",
    n_tokens=512,
    channels=("value", "rank", "dropout"),
    use_coord_embedding=True,
    n_chroms=<num_chromosomes>,
    coord_scale=1e-6,
)
```

### Attach coordinates and configure distance bias via `UniVITrainer`

If your `UniVITrainer` supports `feature_coords` and `attn_bias_cfg`, you can attach genomic coordinates once and let the trainer build the bias for you:

```python
feature_coords = {
    "atac": {
        "chrom_ids": chrom_ids_long,  # (F,)
        "start": start_bp,            # (F,)
        "end": end_bp,                # (F,)
    }
}

attn_bias_cfg = {
    "atac": {
        "type": "distance",
        "lengthscale_bp": 50_000.0,
        "same_chrom_only": True,
    }
}

trainer = UniVITrainer(
    model,
    train_loader,
    val_loader=val_loader,
    train_cfg=TrainingConfig(...),
    device="cuda",
    feature_coords=feature_coords,
    attn_bias_cfg=attn_bias_cfg,
)
trainer.fit()
```

This path keeps the model code clean and makes the feature-coordinate plumbing consistent across runs.

---

## Optional: Fused multimodal transformer encoder (advanced)

A single transformer sees **concatenated tokens from multiple modalities** and returns a **single fused posterior** `q(z|all modalities)` using global CLS pooling (or mean pooling).

### Minimal config

```python
from univi.config import TransformerConfig

univi_cfg = UniVIConfig(
    latent_dim=40,
    beta=1.0,
    gamma=1.25,
    modalities=[...],  # your per-modality configs still exist
    fused_encoder_type="multimodal_transformer",
    fused_modalities=("rna", "adt", "atac"),  # default: all modalities
    fused_transformer=TransformerConfig(
        d_model=256, num_heads=8, num_layers=4,
        dim_feedforward=1024, dropout=0.1, attn_dropout=0.1,
        activation="gelu", pooling="cls",
    ),
)
```

Notes:

* Every modality in `fused_modalities` must define a `tokenizer` (even if its per-modality encoder is MLP).
* If `fused_require_all_modalities=True` and a fused modality is missing at inference, UniVI falls back to MoE/PoE fusion.

---

## Hyperparameter optimization (optional)

```python
from univi.hyperparam_optimization import (
    run_multiome_hparam_search,
    run_citeseq_hparam_search,
    run_teaseq_hparam_search,
    run_rna_hparam_search,
    run_atac_hparam_search,
    run_adt_hparam_search,
)
```

See `univi/hyperparam_optimization/` and `notebooks/` for examples.

---

## Contact, questions, and bug reports

* **Questions / comments:** open a GitHub Issue with the `question` label (or a Discussion if enabled).
* **Bug reports:** open a GitHub Issue and include:

  * your UniVI version: `python -c "import univi; print(univi.__version__)"`
  * minimal code to reproduce (or a short notebook snippet)
  * stack trace + OS/CUDA/PyTorch versions
* **Feature requests:** open an Issue describing the use-case + expected inputs/outputs (a tiny example is ideal).