univi 0.2.1__tar.gz → 0.2.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {univi-0.2.1/univi.egg-info → univi-0.2.5}/PKG-INFO +336 -108
- {univi-0.2.1 → univi-0.2.5}/README.md +335 -107
- {univi-0.2.1 → univi-0.2.5}/pyproject.toml +1 -1
- univi-0.2.5/univi/__init__.py +85 -0
- univi-0.2.5/univi/config.py +110 -0
- univi-0.2.5/univi/data.py +364 -0
- univi-0.2.5/univi/evaluation.py +377 -0
- {univi-0.2.1 → univi-0.2.5}/univi/models/__init__.py +4 -0
- univi-0.2.5/univi/models/decoders.py +249 -0
- {univi-0.2.1 → univi-0.2.5}/univi/models/encoders.py +3 -3
- {univi-0.2.1 → univi-0.2.5}/univi/models/mlp.py +11 -5
- univi-0.2.5/univi/models/univi.py +886 -0
- univi-0.2.5/univi/plotting.py +126 -0
- univi-0.2.5/univi/trainer.py +340 -0
- univi-0.2.5/univi/utils/io.py +222 -0
- {univi-0.2.1 → univi-0.2.5/univi.egg-info}/PKG-INFO +336 -108
- univi-0.2.1/univi/__init__.py +0 -35
- univi-0.2.1/univi/config.py +0 -71
- univi-0.2.1/univi/data.py +0 -190
- univi-0.2.1/univi/evaluation.py +0 -555
- univi-0.2.1/univi/models/decoders.py +0 -443
- univi-0.2.1/univi/models/univi.py +0 -440
- univi-0.2.1/univi/plotting.py +0 -129
- univi-0.2.1/univi/trainer.py +0 -294
- univi-0.2.1/univi/utils/io.py +0 -230
- {univi-0.2.1 → univi-0.2.5}/LICENSE +0 -0
- {univi-0.2.1 → univi-0.2.5}/setup.cfg +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/__main__.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/cli.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/diagnostics.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/__init__.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/common.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_adt_hparam_search.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_atac_hparam_search.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_citeseq_hparam_search.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_multiome_hparam_search.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_rna_hparam_search.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_teaseq_hparam_search.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/matching.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/objectives.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/pipeline.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/utils/__init__.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/utils/logging.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/utils/seed.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/utils/stats.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi/utils/torch_utils.py +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi.egg-info/SOURCES.txt +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi.egg-info/dependency_links.txt +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi.egg-info/entry_points.txt +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi.egg-info/requires.txt +0 -0
- {univi-0.2.1 → univi-0.2.5}/univi.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: univi
-Version: 0.2.1
+Version: 0.2.5
 Summary: UniVI: a scalable multi-modal variational autoencoder toolkit for seamless integration and analysis of multimodal single-cell data.
 Author-email: "Andrew J. Ashford" <ashforda@ohsu.edu>
 License: MIT License
@@ -57,33 +57,33 @@ Dynamic: license-file
 # UniVI

 [](https://pypi.org/project/univi/)
-[](https://pypi.org/project/univi/)

 <picture>
   <!-- Dark mode (GitHub supports this; PyPI may ignore <source>) -->
   <source media="(prefers-color-scheme: dark)"
-          srcset="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.2.
+          srcset="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.2.5/assets/figures/univi_overview_dark.png">
   <!-- Light mode / fallback (works on GitHub + PyPI) -->
-  <img src="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.2.
+  <img src="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.2.5/assets/figures/univi_overview_light.png"
        alt="UniVI overview and evaluation roadmap"
        width="100%">
 </picture>

-**UniVI overview and evaluation roadmap.**
-(a) Generic UniVI architecture schematic. (b) Core training objective (for UniVI v1 - see documentation for UniVI-lite training objective). (c) Example modality combinations beyond bi-modal data (e.g. TEA-seq (tri-modal RNA + ATAC + ADT)). (d) Evaluation roadmap spanning latent alignment (FOSCTTM
+**UniVI overview and evaluation roadmap.**
+(a) Generic UniVI architecture schematic. (b) Core training objective (for UniVI v1 - see documentation for UniVI-lite training objective). (c) Example modality combinations beyond bi-modal data (e.g. TEA-seq (tri-modal RNA + ATAC + ADT)). (d) Evaluation roadmap spanning latent alignment (FOSCTTM), modality mixing, label transfer, reconstruction/prediction NLL, and downstream biological consistency.

 ---

 UniVI is a **multi-modal variational autoencoder (VAE)** framework for aligning and integrating single-cell modalities such as RNA, ADT (CITE-seq), and ATAC. It’s built to support experiments like:

-
-
-
-
-
-
-
-
+* Joint embedding of RNA + ADT (CITE-seq)
+* RNA + ATAC (Multiome) integration
+* RNA + ADT + ATAC (TEA-seq) tri-modal data integration
+* Independent non-paired modalities from the same tissue type
+* Cross-modal reconstruction and imputation
+* Data denoising
+* Structured evaluation of alignment quality (FOSCTTM, modality mixing, label transfer, etc.)
+* Exploratory analysis of the relationships between heterogeneous molecular readouts that inform biological functional dimensions

 This repository contains the core UniVI code, training scripts, parameter files, and example notebooks.

@@ -93,8 +93,8 @@ This repository contains the core UniVI code, training scripts, parameter files,

 If you use UniVI in your work, please cite:

-> Ashford AJ, Enright T, Nikolova O, Demir E.
-> **Unifying Multimodal Single-Cell Data Using a Mixture of Experts β-Variational Autoencoder-Based Framework.**
+> Ashford AJ, Enright T, Nikolova O, Demir E.
+> **Unifying Multimodal Single-Cell Data Using a Mixture of Experts β-Variational Autoencoder-Based Framework.**
 > *bioRxiv* (2025). doi: [10.1101/2025.02.28.640429](https://www.biorxiv.org/content/10.1101/2025.02.28.640429v1.full)

 ```bibtex
@@ -106,11 +106,12 @@ If you use UniVI in your work, please cite:
   doi = {10.1101/2025.02.28.640429},
   url = {https://www.biorxiv.org/content/10.1101/2025.02.28.640429v1}
 }
-
+```

 ---

 ## License
+
 MIT License — see `LICENSE`.

 ---
@@ -145,7 +146,7 @@ UniVI/
 │   ├── evaluate_univi.py                  # Evaluate trained models (FOSCTTM, label transfer, etc.)
 │   ├── benchmark_univi_citeseq.py         # CITE-seq-specific benchmarking script
 │   ├── run_multiome_hparam_search.py
-│   ├── run_frequency_robustness.py
+│   ├── run_frequency_robustness.py        # Composition/frequency mismatch robustness
 │   ├── run_do_not_integrate_detection.py  # “Do-not-integrate” unmatched population demo
 │   ├── run_benchmarks.py                  # Unified wrapper (includes optional Harmony baseline)
 │   └── revision_reproduce_all.sh          # One-click: reproduces figures + supplemental tables
@@ -189,6 +190,8 @@ UniVI/

 ```

+---
+
 ## Generated outputs

 Most entry-point scripts write results into a user-specified output directory (commonly `runs/`), which is **not** tracked in git.
@@ -339,49 +342,252 @@ See the notebooks under `notebooks/` for end-to-end preprocessing examples for C

 ---

-##
+## Training modes & example recipes (v1 vs v2/lite + supervised options)

 UniVI supports two training regimes:

-* **UniVI v1**:
-* **UniVI-lite**:
+* **UniVI v1**: per-modality posteriors + reconstruction terms controlled by `v1_recon` (cross/self/avg/etc.) + posterior alignment across modality posteriors.
+* **UniVI-lite / v2**: fused latent posterior (precision-weighted MoE/PoE style) + per-modality reconstruction + β·KL(q_fused||p) + γ·pairwise alignment between modality posteriors. Scales cleanly to 3+ modalities and is the recommended default.
+
+### Which supervised option should I use?
+
+Use labels to “shape” the latent in one of three ways:
+
+1. **Classification head (decoder-only)** — `p(y|z)` (**recommended default**)
+   *Works for `loss_mode="lite"` and `loss_mode="v1"`.*
+   Best if you want the latent to be predictive/separable without changing how modalities reconstruct.
+
+2. **Label expert injected into fusion (encoder-side)** — `q(z|y)` (**lite/v2 only**)
+   *Works only for `loss_mode="lite"` / `v2`.*
+   Best for semi-supervised settings where labels should directly influence the **fused posterior**.
+
+3. **Labels as a full categorical “modality”** — `"celltype"` modality with likelihood `"categorical"`
+   *Best with `loss_mode="lite"`.*
+   Useful when you want cell types to behave like a first-class modality (encode/decode/reconstruct), but avoid `v1` cross-reconstruction unless you really know you want it.
+
+---
+
+## Supervised labels (three supported patterns)
+
+### A) Latent classification head (decoder-only): `p(y|z)` (works in **lite/v2** and **v1**)
+
+This is the simplest way to shape the latent. UniVI attaches a categorical head to the latent `z` and adds:
+
+```math
+\mathcal{L} \;+=\; \lambda \cdot \mathrm{CE}(\mathrm{logits}(z), y)
+```
+
+**How to enable:** initialize the model with:
+
+* `n_label_classes > 0`
+* `label_loss_weight` (default `1.0`)
+* `label_ignore_index` (default `-1`, used to mask unlabeled rows)
+
+```python
+import numpy as np
+import torch
+
+from univi import UniVIMultiModalVAE, UniVIConfig, ModalityConfig
+
+# Example labels (0..C-1) from AnnData
+y_codes = rna.obs["celltype"].astype("category").cat.codes.to_numpy()
+n_classes = int(y_codes.max() + 1)
+
+univi_cfg = UniVIConfig(
+    latent_dim=40,
+    beta=5.0,
+    gamma=40.0,
+    modalities=[
+        ModalityConfig("rna", rna.n_vars, [1024, 512], [512, 1024], likelihood="nb"),
+        ModalityConfig("adt", adt.n_vars, [256, 128], [128, 256], likelihood="nb"),
+    ],
+)
+
+model = UniVIMultiModalVAE(
+    univi_cfg,
+    loss_mode="lite",          # OR "v1"
+    n_label_classes=n_classes,
+    label_loss_weight=1.0,
+    label_ignore_index=-1,
+    classify_from_mu=True,
+).to("cuda")
+```
+
+During training your batch should provide `y`, and your loop should call:
+
+```python
+out = model(x_dict, y=y, epoch=epoch)
+loss = out["loss"]
+```
+
+Unlabeled cells are supported: set `y=-1` and CE is automatically masked.
+
+---
+
+### B) Label expert injected into fusion: `q(z|y)` (**lite/v2 only**)
+
+In **lite/v2**, UniVI can optionally add a **label encoder** as an additional expert into MoE fusion. Labeled cells get an extra “expert vote” in the fused posterior; unlabeled cells ignore it automatically.
+
+```python
+model = UniVIMultiModalVAE(
+    univi_cfg,
+    loss_mode="lite",
+
+    # Optional: keep the decoder-side classification head too
+    n_label_classes=n_classes,
+    label_loss_weight=1.0,
+
+    # Encoder-side label expert injected into fusion
+    use_label_encoder=True,
+    label_moe_weight=1.0,      # >1 => labels influence fusion more
+    unlabeled_logvar=20.0,     # very high => tiny precision => ignored in fusion
+    label_encoder_warmup=5,    # wait N epochs before injecting labels into fusion
+    label_ignore_index=-1,
+).to("cuda")
+```
+
+**Notes**
+
+* This pathway is **only used in `loss_mode="lite"` / `v2`**, because it is implemented as an extra expert inside fusion.
+* Unlabeled cells (`y=-1`) are automatically ignored in fusion via a huge log-variance.
+
+---
+
+### C) Treat labels as a categorical “modality” (best with **lite/v2**)
+
+Instead of providing `y` separately, you can represent labels as another modality (e.g. `"celltype"`) with likelihood `"categorical"`. This makes labels a first-class modality with its own encoder/decoder.
+
+**Recommended representation:** one-hot matrix `(B, C)` stored in `.X`.
+
+```python
+import numpy as np
+from anndata import AnnData
+
+# y codes (0..C-1)
+y_codes = rna.obs["celltype"].astype("category").cat.codes.to_numpy()
+C = int(y_codes.max() + 1)
+
+Y = np.eye(C, dtype=np.float32)[y_codes]    # (B, C) one-hot
+
+celltype = AnnData(X=Y)
+celltype.obs_names = rna.obs_names.copy()   # MUST match paired modalities
+celltype.var_names = [f"class_{i}" for i in range(C)]
+
+adata_dict = {"rna": rna, "adt": adt, "celltype": celltype}
+
+univi_cfg = UniVIConfig(
+    latent_dim=40,
+    beta=5.0,
+    gamma=40.0,
+    modalities=[
+        ModalityConfig("rna", rna.n_vars, [1024, 512], [512, 1024], likelihood="nb"),
+        ModalityConfig("adt", adt.n_vars, [256, 128], [128, 256], likelihood="nb"),
+        ModalityConfig("celltype", C, [128], [128], likelihood="categorical"),
+    ],
+)
+
+model = UniVIMultiModalVAE(univi_cfg, loss_mode="lite").to("cuda")
+```
+
+**Important caveat for `loss_mode="v1"`**
+`v1` can perform cross-reconstruction across all modalities. If you include `"celltype"` as a modality, you typically **do not** want cross-recon terms like `celltype → RNA`. If you must run `v1` with label-as-modality, prefer:
+
+```python
+model = UniVIMultiModalVAE(univi_cfg, loss_mode="v1", v1_recon="self").to("cuda")
+```
+
+If you want full `v1` cross-reconstruction and label shaping, prefer **Pattern A (classification head)** instead.
+
+---
+
+## Running a minimal training script (UniVI v1 vs UniVI-lite)

 ### 0) Choose the training objective (`loss_mode`) in your config JSON

-In `parameter_files/*.json`, set a single switch that controls the objective
+In `parameter_files/*.json`, set a single switch that controls the objective.

-**Paper objective (v1;
+**Paper objective (v1; `"avg"` trains with 50% weight on self-reconstruction and 50% weight on cross-reconstruction, with weights automatically adjusted so this stays true for any number of modalities):**

 ```json5
 {
   "model": {
     "loss_mode": "v1",
-    "v1_recon": "
-    "v1_recon_mix": 0.0,        // optional extra averaged-z recon weight
+    "v1_recon": "avg",
     "normalize_v1_terms": true
+  }
 }
-
+```

 **UniVI-lite objective (v2; lightweight / fusion-based):**

 ```json5
 {
   "model": {
-    "loss_mode": "lite"
-
-    "v1_recon_mix": 0.0,        // doesn't get used if loss_mode="lite"
-    "normalize_v1_terms": true  // doesn't get used if loss_mode="lite"
+    "loss_mode": "lite"
+  }
 }
 ```

 > **Note**
 > `loss_mode: "lite"` is an alias for `loss_mode: "v2"` (they run the same objective in the current code).

+### 0b) (Optional) Enable supervised labels from config JSON
+
+**Classification head (decoder-only):**
+
+```json5
+{
+  "model": {
+    "loss_mode": "lite",
+    "n_label_classes": 30,
+    "label_loss_weight": 1.0,
+    "label_ignore_index": -1,
+    "classify_from_mu": true
+  }
+}
+```
+
+**Lite + label expert injected into fusion (encoder-side):**
+
+```json5
+{
+  "model": {
+    "loss_mode": "lite",
+    "n_label_classes": 30,
+    "label_loss_weight": 1.0,
+
+    "use_label_encoder": true,
+    "label_moe_weight": 1.0,
+    "unlabeled_logvar": 20.0,
+    "label_encoder_warmup": 5,
+    "label_ignore_index": -1
+  }
+}
+```
+
+**Labels as a categorical modality:** add an additional `"celltype"` modality in `"data.modalities"` and provide a matching AnnData on disk (or build it in Python).
+
+```json5
+{
+  "model": { "loss_mode": "lite" },
+  "data": {
+    "modalities": [
+      { "name": "rna", "likelihood": "nb", "X_key": "X", "layer": "counts" },
+      { "name": "adt", "likelihood": "nb", "X_key": "X", "layer": "counts" },
+      { "name": "celltype", "likelihood": "categorical", "X_key": "X", "layer": null }
+    ]
+  }
+}
+```
+
 ### 1) Normalization / representation switch (counts vs continuous)

-
+**Important note on selectors:**
+
+* `layer` selects `.layers[layer]` (if `X_key == "X"`).
+* `X_key == "X"` selects `.X`/`.layers[layer]`; otherwise `X_key` selects `.obsm[X_key]`.

-
+Correct pattern:

 ```json5
 {
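The UniVI-lite objective added above is described in prose (per-modality reconstruction, a β-weighted KL on the fused posterior, and a γ-weighted pairwise alignment between modality posteriors). A rough symbolic form of that description, written here as a reading aid and not as an excerpt from `univi/models/univi.py`:

```math
\mathcal{L}_{\text{lite}}
= \sum_{m=1}^{M} \mathrm{NLL}\big(x_m \mid z \sim q_{\text{fused}}\big)
+ \beta \, \mathrm{KL}\big(q_{\text{fused}}(z \mid x_{1:M}) \,\|\, p(z)\big)
+ \gamma \sum_{m < m'} D\big(q_m(z \mid x_m)\,,\, q_{m'}(z \mid x_{m'})\big)
```

Here `q_fused` is the precision-weighted MoE/PoE combination of the per-modality posteriors and `D` stands for the pairwise alignment divergence mentioned in the bullet; the exact divergence and weighting conventions are defined by the package itself.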
@@ -389,35 +595,34 @@ Recommended pattern (example showing several preprocessing options for the diffe
     "modalities": [
       {
         "name": "rna",
-        "layer": "log1p",
+        "layer": "log1p",          // uses adata.layers["log1p"] (since X_key=="X")
         "X_key": "X",
-        "assume_log1p": true,
+        "assume_log1p": true,
         "likelihood": "gaussian"
       },
       {
         "name": "adt",
-        "layer": "counts",
-        "X_key": "
-        "assume_log1p": false,
+        "layer": "counts",         // uses adata.layers["counts"] (since X_key=="X")
+        "X_key": "X",
+        "assume_log1p": false,
        "likelihood": "zinb"
       },
       {
         "name": "atac",
-        "layer":
-        "X_key": "X_lsi",
+        "layer": null,             // ignored because X_key != "X"
+        "X_key": "X_lsi",          // uses adata.obsm["X_lsi"]
         "assume_log1p": false,
         "likelihood": "gaussian"
       }
     ]
   }
 }
-
 ```

 * Use `.layers["counts"]` when you want NB/ZINB/Poisson decoders.
-* Use continuous `.X`
+* Use continuous `.X` or `.obsm["X_lsi"]` when you want Gaussian/MSE decoders.

-> Jupyter
+> Jupyter notebooks in this repository (UniVI/notebooks/) show recommended preprocessing per dataset for different data types and analyses. Depending on your research goals, you can use several different methods of preprocessing. The model is robust when it comes to learning underlying biology regardless of preprocessing; the key is that the decoder likelihood should roughly match the input distribution per-modality.

 ### 2) Train (CLI)

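The `layer` / `X_key` selection rule quoted above reduces to a small lookup. A minimal sketch of that rule, using a hypothetical helper name (`pick_matrix`) that is not part of the univi API:

```python
from typing import Optional

import numpy as np
from anndata import AnnData


def pick_matrix(adata: AnnData, X_key: str = "X", layer: Optional[str] = None) -> np.ndarray:
    """Sketch of the selector behaviour described in the README:
    X_key == "X"  -> use .layers[layer] if a layer is named, else .X
    X_key != "X"  -> use .obsm[X_key] (layer is ignored)."""
    if X_key == "X":
        X = adata.layers[layer] if layer else adata.X
    else:
        X = adata.obsm[X_key]  # e.g. "X_lsi" for ATAC LSI factors
    # Densify sparse matrices so downstream code always sees an ndarray.
    return np.asarray(X.toarray()) if hasattr(X, "toarray") else np.asarray(X)
```

Under that rule, the `rna` and `adt` blocks in the JSON above read count/log1p layers, while the `atac` block pulls LSI factors from `.obsm["X_lsi"]`.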
@@ -456,8 +661,8 @@ python scripts/train_univi.py \

 ```bash
 python scripts/train_univi.py \
-  --config parameter_files/ \
-  --outdir saved_models/
+  --config parameter_files/defaults_multiome_lite.json \
+  --outdir saved_models/multiome_lite_run1 \
   --data-root /path/to/your/data
 ```

@@ -481,7 +686,9 @@ python scripts/train_univi.py \
   --data-root /path/to/your/data
 ```

-
+---
+
+## Quickstart: run UniVI from Python / Jupyter

 If you prefer to stay inside a notebook or a Python script instead of calling the CLI, you can build the configs, model, and trainer directly.

@@ -500,33 +707,29 @@ from univi import (
     UniVIConfig,
     TrainingConfig,
 )
-from univi.data import MultiModalDataset
+from univi.data import MultiModalDataset, align_paired_obs_names
 from univi.trainer import UniVITrainer
-
+```

-
+### 1) Load preprocessed AnnData (paired cells)

 ```python
-# Example: CITE-seq with RNA + ADT
 rna = sc.read_h5ad("path/to/rna_citeseq.h5ad")
 adt = sc.read_h5ad("path/to/adt_citeseq.h5ad")

-
-adata_dict =
-    "rna": rna,
-    "adt": adt,
-}
+adata_dict = {"rna": rna, "adt": adt}
+adata_dict = align_paired_obs_names(adata_dict)   # ensures same obs_names/order
 ```

-
+### 2) Build `MultiModalDataset` and DataLoaders (unsupervised)

 ```python
 device = "cuda" if torch.cuda.is_available() else "cpu"

 dataset = MultiModalDataset(
     adata_dict=adata_dict,
-    X_key="X",
-    device=
+    X_key="X",
+    device=None,        # "cpu" or "cuda"
 )

 n_cells = rna.n_obs
@@ -542,29 +745,35 @@ val_ds = Subset(dataset, val_idx)

 batch_size = 256

-train_loader = DataLoader(
-
-
-    shuffle=True,
-    num_workers=0,
-)
+train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
+val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
+```

-
-
-
-
-
-)
+### 2b) (Optional) Supervised batches for Pattern A/B (`(x_dict, y)`)
+
+If you use the classification head and/or label expert injection, supply `y` as integer class indices and mask unlabeled with `-1`.
+
+```python
+y_codes = rna.obs["celltype"].astype("category").cat.codes.to_numpy()
+
+dataset_sup = MultiModalDataset(adata_dict=adata_dict, X_key="X", labels=y_codes)
+
+def collate_xy(batch):
+    xs, ys = zip(*batch)
+    x = {k: torch.stack([d[k] for d in xs], 0) for k in xs[0].keys()}
+    y = torch.as_tensor(ys, dtype=torch.long)
+    return x, y
+
+train_loader = DataLoader(dataset_sup, batch_size=batch_size, shuffle=True, collate_fn=collate_xy)
 ```

-
+### 3) Define UniVI configs (v1 vs UniVI-lite)

 ```python
-# UniVI model config (architecture + regularization)
 univi_cfg = UniVIConfig(
     latent_dim=40,
-    beta=5.0,
-    gamma=40.0,
+    beta=5.0,
+    gamma=40.0,
     encoder_dropout=0.1,
     decoder_dropout=0.0,
     encoder_batchnorm=True,
@@ -574,24 +783,11 @@ univi_cfg = UniVIConfig(
     align_anneal_start=0,
     align_anneal_end=25,
     modalities=[
-        ModalityConfig(
-
-            input_dim=rna.n_vars,
-            encoder_hidden=[1024, 512],
-            decoder_hidden=[512, 1024],
-            likelihood="nb",   # counts-like RNA
-        ),
-        ModalityConfig(
-            name="adt",
-            input_dim=adt.n_vars,
-            encoder_hidden=[256, 128],
-            decoder_hidden=[128, 256],
-            likelihood="nb",   # counts-like ADT
-        ),
+        ModalityConfig("rna", rna.n_vars, [1024, 512], [512, 1024], likelihood="nb"),
+        ModalityConfig("adt", adt.n_vars, [256, 128], [128, 256], likelihood="nb"),
     ],
 )

-# Training config (epochs, LR, device, etc.)
 train_cfg = TrainingConfig(
     n_epochs=200,
     batch_size=batch_size,
@@ -608,27 +804,47 @@ train_cfg = TrainingConfig(
 )
 ```

-
-
-* **v1** (paper objective): cross-reconstruction + cross-posterior alignment.
-  * Best when batches are paired/pseudo-paired and you want explicit cross-prediction.
-* **lite** (aka `"v2"`): missing-modality friendly; trains even if some modalities are absent in a batch.
+### 4) Choose the objective + supervised option

 ```python
-# Option A: UniVI v1 (
+# Option A: UniVI v1 (unsupervised)
 model = UniVIMultiModalVAE(
     univi_cfg,
     loss_mode="v1",
-    v1_recon="
-    v1_recon_mix=0.0,
+    v1_recon="avg",
+    v1_recon_mix=0.0,
     normalize_v1_terms=True,
 ).to(device)

-# Option B: UniVI-lite (
+# Option B: UniVI-lite / v2 (unsupervised)
 # model = UniVIMultiModalVAE(univi_cfg, loss_mode="lite").to(device)
+
+# Option C: Add classification head (Pattern A; works in lite/v2 AND v1)
+# n_classes = int(y_codes.max() + 1)
+# model = UniVIMultiModalVAE(
+#     univi_cfg,
+#     loss_mode="lite",
+#     n_label_classes=n_classes,
+#     label_loss_weight=1.0,
+#     label_ignore_index=-1,
+#     classify_from_mu=True,
+# ).to(device)
+
+# Option D: Add label expert injection into fusion (Pattern B; lite/v2 ONLY)
+# model = UniVIMultiModalVAE(
+#     univi_cfg,
+#     loss_mode="lite",
+#     n_label_classes=n_classes,
+#     label_loss_weight=1.0,
+#     use_label_encoder=True,
+#     label_moe_weight=1.0,
+#     unlabeled_logvar=20.0,
+#     label_encoder_warmup=5,
+#     label_ignore_index=-1,
+# ).to(device)
 ```

-
+### 5) Train inside Python / Jupyter

 ```python
 trainer = UniVITrainer(
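If you drive training yourself instead of using `UniVITrainer`, the supervised batches from `collate_xy` plug directly into the `out = model(x_dict, y=y, epoch=epoch)` call shown under Pattern A. A minimal sketch of such a loop; the optimizer choice and device handling are illustrative assumptions, not something the package prescribes:

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(train_cfg.n_epochs):
    model.train()
    for x_dict, y in train_loader:
        # Move the per-modality tensors and labels to the training device.
        x_dict = {k: v.to(device) for k, v in x_dict.items()}
        y = y.to(device)

        out = model(x_dict, y=y, epoch=epoch)   # CE is masked wherever y == -1
        loss = out["loss"]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```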
@@ -639,26 +855,36 @@ trainer = UniVITrainer(
     device=device,
 )

-history = trainer.fit()
+history = trainer.fit()
 ```

-
-
-* Computing latent embeddings (`encode_modalities` / `mixture_of_experts`)
-* Cross-modal reconstruction (forward passes with different modality subsets)
-* Exporting `z` to AnnData or NumPy for downstream analysis (UMAP, clustering, DE, etc.)
-
-#### 6) Write latent `z` into AnnData `.obsm["X_univi"]`
+### 6) Write latent `z` into AnnData `.obsm["X_univi"]`

 ```python
 from univi import write_univi_latent

-Z = write_univi_latent(model, adata_dict, obsm_key="X_univi", device=device)
+Z = write_univi_latent(model, adata_dict, obsm_key="X_univi", device=device, use_mean=True)
 print("Embedding shape:", Z.shape)
 ```

 > **Tip**
->
+> Use `use_mean=True` for deterministic plotting/UMAP. Sampling (`use_mean=False`) is stochastic and useful for generative behavior.
+
+---
+
+## Evaluating / encoding: choosing the latent representation
+
+Some utilities (e.g., `encode_adata`) support selecting what embedding to return:
+
+* `"moe_mean"` / `"moe_sample"`: fused latent (MoE/PoE)
+* `"modality_mean"` / `"modality_sample"`: per-modality latent
+
+```python
+from univi.evaluation import encode_adata
+
+Z_rna = encode_adata(model, rna, modality="rna", device=device, layer="counts", latent="modality_mean")
+Z_moe = encode_adata(model, rna, modality="rna", device=device, layer="counts", latent="moe_mean")
+```

 ---

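The fused (`"moe_mean"`) latent above refers to the precision-weighted MoE/PoE combination of per-modality Gaussian posteriors described for UniVI-lite. A standalone sketch of that fusion rule, written under that description rather than copied from the package's internal function; it also shows why `unlabeled_logvar=20.0` effectively removes the label expert from fusion:

```python
from typing import List, Tuple

import torch


def precision_weighted_fusion(
    mus: List[torch.Tensor], logvars: List[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fuse per-modality posteriors N(mu_m, var_m) by precision weighting (PoE-style)."""
    precisions = [torch.exp(-lv) for lv in logvars]        # 1 / var_m
    total_prec = torch.stack(precisions, dim=0).sum(dim=0)
    fused_var = 1.0 / total_prec
    fused_mu = fused_var * torch.stack(
        [p * mu for p, mu in zip(precisions, mus)], dim=0
    ).sum(dim=0)
    return fused_mu, torch.log(fused_var)


# An expert with logvar = 20 has precision exp(-20) ≈ 2e-9, so it contributes
# essentially nothing to the fused mean or variance — the "ignored in fusion"
# behaviour used for unlabeled cells.
```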
@@ -819,3 +1045,5 @@ Typical evaluation outputs include:

 For richer, exploratory workflows (TEA-seq tri-modal integration, Multiome RNA+ATAC, non-paired matching, etc.), see the notebooks in `notebooks/`.

+---
+