univi 0.2.1__tar.gz → 0.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {univi-0.2.1/univi.egg-info → univi-0.2.5}/PKG-INFO +336 -108
  2. {univi-0.2.1 → univi-0.2.5}/README.md +335 -107
  3. {univi-0.2.1 → univi-0.2.5}/pyproject.toml +1 -1
  4. univi-0.2.5/univi/__init__.py +85 -0
  5. univi-0.2.5/univi/config.py +110 -0
  6. univi-0.2.5/univi/data.py +364 -0
  7. univi-0.2.5/univi/evaluation.py +377 -0
  8. {univi-0.2.1 → univi-0.2.5}/univi/models/__init__.py +4 -0
  9. univi-0.2.5/univi/models/decoders.py +249 -0
  10. {univi-0.2.1 → univi-0.2.5}/univi/models/encoders.py +3 -3
  11. {univi-0.2.1 → univi-0.2.5}/univi/models/mlp.py +11 -5
  12. univi-0.2.5/univi/models/univi.py +886 -0
  13. univi-0.2.5/univi/plotting.py +126 -0
  14. univi-0.2.5/univi/trainer.py +340 -0
  15. univi-0.2.5/univi/utils/io.py +222 -0
  16. {univi-0.2.1 → univi-0.2.5/univi.egg-info}/PKG-INFO +336 -108
  17. univi-0.2.1/univi/__init__.py +0 -35
  18. univi-0.2.1/univi/config.py +0 -71
  19. univi-0.2.1/univi/data.py +0 -190
  20. univi-0.2.1/univi/evaluation.py +0 -555
  21. univi-0.2.1/univi/models/decoders.py +0 -443
  22. univi-0.2.1/univi/models/univi.py +0 -440
  23. univi-0.2.1/univi/plotting.py +0 -129
  24. univi-0.2.1/univi/trainer.py +0 -294
  25. univi-0.2.1/univi/utils/io.py +0 -230
  26. {univi-0.2.1 → univi-0.2.5}/LICENSE +0 -0
  27. {univi-0.2.1 → univi-0.2.5}/setup.cfg +0 -0
  28. {univi-0.2.1 → univi-0.2.5}/univi/__main__.py +0 -0
  29. {univi-0.2.1 → univi-0.2.5}/univi/cli.py +0 -0
  30. {univi-0.2.1 → univi-0.2.5}/univi/diagnostics.py +0 -0
  31. {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/__init__.py +0 -0
  32. {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/common.py +0 -0
  33. {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_adt_hparam_search.py +0 -0
  34. {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_atac_hparam_search.py +0 -0
  35. {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_citeseq_hparam_search.py +0 -0
  36. {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_multiome_hparam_search.py +0 -0
  37. {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_rna_hparam_search.py +0 -0
  38. {univi-0.2.1 → univi-0.2.5}/univi/hyperparam_optimization/run_teaseq_hparam_search.py +0 -0
  39. {univi-0.2.1 → univi-0.2.5}/univi/matching.py +0 -0
  40. {univi-0.2.1 → univi-0.2.5}/univi/objectives.py +0 -0
  41. {univi-0.2.1 → univi-0.2.5}/univi/pipeline.py +0 -0
  42. {univi-0.2.1 → univi-0.2.5}/univi/utils/__init__.py +0 -0
  43. {univi-0.2.1 → univi-0.2.5}/univi/utils/logging.py +0 -0
  44. {univi-0.2.1 → univi-0.2.5}/univi/utils/seed.py +0 -0
  45. {univi-0.2.1 → univi-0.2.5}/univi/utils/stats.py +0 -0
  46. {univi-0.2.1 → univi-0.2.5}/univi/utils/torch_utils.py +0 -0
  47. {univi-0.2.1 → univi-0.2.5}/univi.egg-info/SOURCES.txt +0 -0
  48. {univi-0.2.1 → univi-0.2.5}/univi.egg-info/dependency_links.txt +0 -0
  49. {univi-0.2.1 → univi-0.2.5}/univi.egg-info/entry_points.txt +0 -0
  50. {univi-0.2.1 → univi-0.2.5}/univi.egg-info/requires.txt +0 -0
  51. {univi-0.2.1 → univi-0.2.5}/univi.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: univi
- Version: 0.2.1
+ Version: 0.2.5
  Summary: UniVI: a scalable multi-modal variational autoencoder toolkit for seamless integration and analysis of multimodal single-cell data.
  Author-email: "Andrew J. Ashford" <ashforda@ohsu.edu>
  License: MIT License
@@ -57,33 +57,33 @@ Dynamic: license-file
  # UniVI

  [![PyPI version](https://img.shields.io/pypi/v/univi)](https://pypi.org/project/univi/)
- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/univi.svg?v=0.2.1)](https://pypi.org/project/univi/)
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/univi.svg?v=0.2.5)](https://pypi.org/project/univi/)

  <picture>
  <!-- Dark mode (GitHub supports this; PyPI may ignore <source>) -->
  <source media="(prefers-color-scheme: dark)"
- srcset="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.2.1/assets/figures/univi_overview_dark.png">
+ srcset="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.2.5/assets/figures/univi_overview_dark.png">
  <!-- Light mode / fallback (works on GitHub + PyPI) -->
- <img src="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.2.1/assets/figures/univi_overview_light.png"
+ <img src="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.2.5/assets/figures/univi_overview_light.png"
  alt="UniVI overview and evaluation roadmap"
  width="100%">
  </picture>

- **UniVI overview and evaluation roadmap.**
- (a) Generic UniVI architecture schematic. (b) Core training objective (for UniVI v1 - see documentation for UniVI-lite training objective). (c) Example modality combinations beyond bi-modal data (e.g. TEA-seq (tri-modal RNA + ATAC + ADT)). (d) Evaluation roadmap spanning latent alignment (FOSCTTM), modality mixing, label transfer, reconstruction/prediction NLL, and downstream biological consistency.
+ **UniVI overview and evaluation roadmap.**
+ (a) Generic UniVI architecture schematic. (b) Core training objective (for UniVI v1; see the documentation for the UniVI-lite training objective). (c) Example modality combinations beyond bi-modal data (e.g. TEA-seq: tri-modal RNA + ATAC + ADT). (d) Evaluation roadmap spanning latent alignment (FOSCTTM), modality mixing, label transfer, reconstruction/prediction NLL, and downstream biological consistency.

  ---

  UniVI is a **multi-modal variational autoencoder (VAE)** framework for aligning and integrating single-cell modalities such as RNA, ADT (CITE-seq), and ATAC. It’s built to support experiments like:

- - Joint embedding of RNA + ADT (CITE-seq)
- - RNA + ATAC (Multiome) integration
- - RNA + ADT + ATAC (TEA-seq) tri-modal data integration
- - Independent non-paired modalities from the same tissue type
- - Cross-modal reconstruction and imputation
- - Data denoising
- - Structured evaluation of alignment quality (FOSCTTM, modality mixing, label transfer, etc.)
- - Exploratory analysis of the relationships between heterogeneous molecular readouts that inform biological functional dimensions
+ * Joint embedding of RNA + ADT (CITE-seq)
+ * RNA + ATAC (Multiome) integration
+ * RNA + ADT + ATAC (TEA-seq) tri-modal data integration
+ * Independent non-paired modalities from the same tissue type
+ * Cross-modal reconstruction and imputation
+ * Data denoising
+ * Structured evaluation of alignment quality (FOSCTTM, modality mixing, label transfer, etc.)
+ * Exploratory analysis of the relationships between heterogeneous molecular readouts that inform biological functional dimensions

  This repository contains the core UniVI code, training scripts, parameter files, and example notebooks.

@@ -93,8 +93,8 @@ This repository contains the core UniVI code, training scripts, parameter files,

  If you use UniVI in your work, please cite:

- > Ashford AJ, Enright T, Nikolova O, Demir E.
- > **Unifying Multimodal Single-Cell Data Using a Mixture of Experts β-Variational Autoencoder-Based Framework.**
+ > Ashford AJ, Enright T, Nikolova O, Demir E.
+ > **Unifying Multimodal Single-Cell Data Using a Mixture of Experts β-Variational Autoencoder-Based Framework.**
  > *bioRxiv* (2025). doi: [10.1101/2025.02.28.640429](https://www.biorxiv.org/content/10.1101/2025.02.28.640429v1.full)

  ```bibtex
@@ -106,11 +106,12 @@ If you use UniVI in your work, please cite:
  doi = {10.1101/2025.02.28.640429},
  url = {https://www.biorxiv.org/content/10.1101/2025.02.28.640429v1}
  }
- ````
+ ```

  ---

  ## License
+
  MIT License — see `LICENSE`.

  ---
@@ -145,7 +146,7 @@ UniVI/
  │   ├── evaluate_univi.py                 # Evaluate trained models (FOSCTTM, label transfer, etc.)
  │   ├── benchmark_univi_citeseq.py        # CITE-seq-specific benchmarking script
  │   ├── run_multiome_hparam_search.py
- │   ├── run_frequency_robustness.py.      # Composition/frequency mismatch robustness
+ │   ├── run_frequency_robustness.py       # Composition/frequency mismatch robustness
  │   ├── run_do_not_integrate_detection.py # “Do-not-integrate” unmatched population demo
  │   ├── run_benchmarks.py                 # Unified wrapper (includes optional Harmony baseline)
  │   └── revision_reproduce_all.sh         # One-click: reproduces figures + supplemental tables
@@ -189,6 +190,8 @@ UniVI/

  ```

+ ---
+
  ## Generated outputs

  Most entry-point scripts write results into a user-specified output directory (commonly `runs/`), which is **not** tracked in git.
@@ -339,49 +342,252 @@ See the notebooks under `notebooks/` for end-to-end preprocessing examples for C

  ---

- ## Running a minimal training script (UniVI v1 vs UniVI-lite)
+ ## Training modes & example recipes (v1 vs v2/lite + supervised options)

  UniVI supports two training regimes:

- * **UniVI v1**: paired/pseudo-paired batches + cross-modal reconstruction (e.g., RNA→ADT and ADT→RNA) + posterior alignment.
- * **UniVI-lite**: missing-modality friendly (can train when only a subset of modalities are present in a batch), typically with a lighter latent alignment term.
+ * **UniVI v1**: per-modality posteriors + reconstruction terms controlled by `v1_recon` (cross/self/avg/etc.) + posterior alignment across modality posteriors.
+ * **UniVI-lite / v2**: fused latent posterior (precision-weighted MoE/PoE style; see the fusion formula below) + per-modality reconstruction + β·KL(q_fused||p) + γ·pairwise alignment between modality posteriors. Scales cleanly to 3+ modalities and is the recommended default.
+
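+ For intuition: in precision-weighted (PoE-style) fusion, each modality's Gaussian posterior votes with weight proportional to its precision. As a sketch of the standard identity (not necessarily UniVI's exact implementation):
+
+ ```math
+ q_{\mathrm{fused}}(z \mid x_{1:M}) = \mathcal{N}(\mu, \sigma^2), \qquad \frac{1}{\sigma^2} = \sum_{m=1}^{M} \frac{1}{\sigma_m^2}, \qquad \mu = \sigma^2 \sum_{m=1}^{M} \frac{\mu_m}{\sigma_m^2}
+ ```
+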
+ ### Which supervised option should I use?
+
+ Use labels to “shape” the latent in one of three ways:
+
+ 1. **Classification head (decoder-only)** — `p(y|z)` (**recommended default**)
+    *Works for `loss_mode="lite"` and `loss_mode="v1"`.*
+    Best if you want the latent to be predictive/separable without changing how modalities reconstruct.
+
+ 2. **Label expert injected into fusion (encoder-side)** — `q(z|y)` (**lite/v2 only**)
+    *Works only for `loss_mode="lite"` / `v2`.*
+    Best for semi-supervised settings where labels should directly influence the **fused posterior**.
+
+ 3. **Labels as a full categorical “modality”** — `"celltype"` modality with likelihood `"categorical"`
+    *Best with `loss_mode="lite"`.*
+    Useful when you want cell types to behave like a first-class modality (encode/decode/reconstruct), but avoid `v1` cross-reconstruction unless you really know you want it.
+
+ ---
+
+ ## Supervised labels (three supported patterns)
+
+ ### A) Latent classification head (decoder-only): `p(y|z)` (works in **lite/v2** and **v1**)
+
+ This is the simplest way to shape the latent. UniVI attaches a categorical head to the latent `z` and adds:
+
+ ```math
+ \mathcal{L} \;+=\; \lambda \cdot \mathrm{CE}(\mathrm{logits}(z), y)
+ ```
+
+ **How to enable:** initialize the model with:
+
+ * `n_label_classes > 0`
+ * `label_loss_weight` (default `1.0`)
+ * `label_ignore_index` (default `-1`, used to mask unlabeled rows)
+
+ ```python
+ import numpy as np
+ import torch
+
+ from univi import UniVIMultiModalVAE, UniVIConfig, ModalityConfig
+
+ # Example labels (0..C-1) from AnnData
+ y_codes = rna.obs["celltype"].astype("category").cat.codes.to_numpy()
+ n_classes = int(y_codes.max() + 1)
+
+ univi_cfg = UniVIConfig(
+     latent_dim=40,
+     beta=5.0,
+     gamma=40.0,
+     modalities=[
+         ModalityConfig("rna", rna.n_vars, [1024, 512], [512, 1024], likelihood="nb"),
+         ModalityConfig("adt", adt.n_vars, [256, 128], [128, 256], likelihood="nb"),
+     ],
+ )
+
+ model = UniVIMultiModalVAE(
+     univi_cfg,
+     loss_mode="lite",        # OR "v1"
+     n_label_classes=n_classes,
+     label_loss_weight=1.0,
+     label_ignore_index=-1,
+     classify_from_mu=True,
+ ).to("cuda")
+ ```
+
+ During training your batch should provide `y`, and your loop should call:
+
+ ```python
+ out = model(x_dict, y=y, epoch=epoch)
+ loss = out["loss"]
+ ```
+
+ Unlabeled cells are supported: set `y=-1` and CE is automatically masked.
+
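+ For example, to simulate a semi-supervised setting, hide a fraction of the labels (sketch reusing `y_codes` and `torch` from above):
+
+ ```python
+ # Keep ~20% of labels; mark the rest unlabeled (-1) so CE skips them
+ y = torch.as_tensor(y_codes, dtype=torch.long)
+ hide = torch.rand(len(y)) < 0.8
+ y[hide] = -1
+ ```
+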
+ ---
+
+ ### B) Label expert injected into fusion: `q(z|y)` (**lite/v2 only**)
+
+ In **lite/v2**, UniVI can optionally add a **label encoder** as an additional expert into MoE fusion. Labeled cells get an extra “expert vote” in the fused posterior; unlabeled cells ignore it automatically.
+
+ ```python
+ model = UniVIMultiModalVAE(
+     univi_cfg,
+     loss_mode="lite",
+
+     # Optional: keep the decoder-side classification head too
+     n_label_classes=n_classes,
+     label_loss_weight=1.0,
+
+     # Encoder-side label expert injected into fusion
+     use_label_encoder=True,
+     label_moe_weight=1.0,      # >1 => labels influence fusion more
+     unlabeled_logvar=20.0,     # very high => tiny precision => ignored in fusion
+     label_encoder_warmup=5,    # wait N epochs before injecting labels into fusion
+     label_ignore_index=-1,
+ ).to("cuda")
+ ```
+
+ **Notes**
+
+ * This pathway is **only used in `loss_mode="lite"` / `v2`**, because it is implemented as an extra expert inside fusion.
+ * Unlabeled cells (`y=-1`) are automatically ignored in fusion via a huge log-variance (illustrated below).
+
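+ To see why `unlabeled_logvar=20.0` effectively removes the label expert: in precision-weighted fusion an expert's weight scales as `exp(-logvar)`, so the arithmetic (illustration only) is:
+
+ ```python
+ import math
+
+ print(math.exp(-20.0))  # ~2.06e-09: the unlabeled label "expert" is numerically negligible
+ print(math.exp(0.0))    # 1.0: a typical data expert with logvar near 0 dominates fusion
+ ```
+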
+ ---
+
+ ### C) Treat labels as a categorical “modality” (best with **lite/v2**)
+
+ Instead of providing `y` separately, you can represent labels as another modality (e.g. `"celltype"`) with likelihood `"categorical"`. This makes labels a first-class modality with its own encoder/decoder.
+
+ **Recommended representation:** one-hot matrix `(B, C)` stored in `.X`.
+
+ ```python
+ import numpy as np
+ from anndata import AnnData
+
+ # y codes (0..C-1)
+ y_codes = rna.obs["celltype"].astype("category").cat.codes.to_numpy()
+ C = int(y_codes.max() + 1)
+
+ Y = np.eye(C, dtype=np.float32)[y_codes]  # (B, C) one-hot
+
+ celltype = AnnData(X=Y)
+ celltype.obs_names = rna.obs_names.copy()  # MUST match paired modalities
+ celltype.var_names = [f"class_{i}" for i in range(C)]
+
+ adata_dict = {"rna": rna, "adt": adt, "celltype": celltype}
+
+ univi_cfg = UniVIConfig(
+     latent_dim=40,
+     beta=5.0,
+     gamma=40.0,
+     modalities=[
+         ModalityConfig("rna", rna.n_vars, [1024, 512], [512, 1024], likelihood="nb"),
+         ModalityConfig("adt", adt.n_vars, [256, 128], [128, 256], likelihood="nb"),
+         ModalityConfig("celltype", C, [128], [128], likelihood="categorical"),
+     ],
+ )
+
+ model = UniVIMultiModalVAE(univi_cfg, loss_mode="lite").to("cuda")
+ ```
+
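+ To map integer codes back to readable labels later, plain pandas suffices (no UniVI API assumed):
+
+ ```python
+ # Inverse of .cat.codes: look up category names for an array of codes 0..C-1
+ cats = rna.obs["celltype"].astype("category").cat.categories
+ labels = cats[y_codes]
+ ```
+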
492
+ **Important caveat for `loss_mode="v1"`**
493
+ `v1` can perform cross-reconstruction across all modalities. If you include `"celltype"` as a modality, you typically **do not** want cross-recon terms like `celltype → RNA`. If you must run `v1` with label-as-modality, prefer:
494
+
495
+ ```python
496
+ model = UniVIMultiModalVAE(univi_cfg, loss_mode="v1", v1_recon="self").to("cuda")
497
+ ```
498
+
499
+ If you want full `v1` cross-reconstruction and label shaping, prefer **Pattern A (classification head)** instead.
500
+
501
+ ---
502
+
503
+ ## Running a minimal training script (UniVI v1 vs UniVI-lite)
348
504
 
349
505
  ### 0) Choose the training objective (`loss_mode`) in your config JSON
350
506
 
351
- In `parameter_files/*.json`, set a single switch that controls the objective:
507
+ In `parameter_files/*.json`, set a single switch that controls the objective.
352
508
 
353
- **Paper objective (v1; cross-reconstruction + cross-posterior alignment):**
509
+ **Paper objective (v1; `"avg"` trains with 50% weight on self-reconstruction and 50% weight on cross-reconstruction, with weights automatically adjusted so this stays true for any number of modalities):**
354
510
 
355
511
  ```json5
356
512
  {
357
513
  "model": {
358
514
  "loss_mode": "v1",
359
- "v1_recon": "cross", // or "cross" | "self" | "avg" | "moe" | "src:rna" etc.
360
- "v1_recon_mix": 0.0, // optional extra averaged-z recon weight
515
+ "v1_recon": "avg",
361
516
  "normalize_v1_terms": true
517
+ }
362
518
  }
363
- ````
519
+ ```
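+
+ One way to read the 50/50 guarantee (a normalization sketch, not necessarily the exact code):
+
+ ```math
+ \mathcal{L}_{\mathrm{recon}} = \frac{1}{2} \cdot \frac{1}{M} \sum_{m} \mathrm{NLL}(x_m \mid z_m) \;+\; \frac{1}{2} \cdot \frac{1}{M(M-1)} \sum_{m \neq m'} \mathrm{NLL}(x_{m'} \mid z_m)
+ ```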

  **UniVI-lite objective (v2; lightweight / fusion-based):**

  ```json5
  {
    "model": {
-     "loss_mode": "lite",        // "lite" is also a proxy for "v2"
-     "v1_recon": "cross",        // doesn't get used if loss_mode="lite"
-     "v1_recon_mix": 0.0,        // doesn't get used if loss_mode="lite"
-     "normalize_v1_terms": true  // doesn't get used if loss_mode="lite"
+     "loss_mode": "lite"
+   }
  }
  ```

  > **Note**
  > `loss_mode: "lite"` is an alias for `loss_mode: "v2"` (they run the same objective in the current code).

+ ### 0b) (Optional) Enable supervised labels from config JSON
+
+ **Classification head (decoder-only):**
+
+ ```json5
+ {
+   "model": {
+     "loss_mode": "lite",
+     "n_label_classes": 30,
+     "label_loss_weight": 1.0,
+     "label_ignore_index": -1,
+     "classify_from_mu": true
+   }
+ }
+ ```
+
+ **Lite + label expert injected into fusion (encoder-side):**
+
+ ```json5
+ {
+   "model": {
+     "loss_mode": "lite",
+     "n_label_classes": 30,
+     "label_loss_weight": 1.0,
+
+     "use_label_encoder": true,
+     "label_moe_weight": 1.0,
+     "unlabeled_logvar": 20.0,
+     "label_encoder_warmup": 5,
+     "label_ignore_index": -1
+   }
+ }
+ ```
+
+ **Labels as a categorical modality:** add an additional `"celltype"` modality in `"data.modalities"` and provide a matching AnnData on disk (or build it in Python).
+
+ ```json5
+ {
+   "model": { "loss_mode": "lite" },
+   "data": {
+     "modalities": [
+       { "name": "rna",      "likelihood": "nb",          "X_key": "X", "layer": "counts" },
+       { "name": "adt",      "likelihood": "nb",          "X_key": "X", "layer": "counts" },
+       { "name": "celltype", "likelihood": "categorical", "X_key": "X", "layer": null }
+     ]
+   }
+ }
+ ```
+
  ### 1) Normalization / representation switch (counts vs continuous)

- UniVI can be trained on **counts** (NB/ZINB/Poisson likelihoods) or **continuous** representations (Gaussian/MSE likelihoods). In your configs, keep this explicit.
+ **Important note on selectors:**
+
+ * `layer` selects `.layers[layer]` (only consulted when `X_key == "X"`).
+ * If `X_key == "X"`, data comes from `.X` (or `.layers[layer]` when `layer` is set); otherwise `X_key` names a key in `.obsm` (e.g. `"X_lsi"`).
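+
+ In Python terms, the selector rule amounts to the following (hypothetical helper for illustration; not the library's actual function):
+
+ ```python
+ def get_matrix(adata, X_key="X", layer=None):
+     """Sketch of the per-modality matrix-selection rule described above."""
+     if X_key == "X":
+         return adata.X if layer is None else adata.layers[layer]
+     return adata.obsm[X_key]  # e.g. X_key="X_lsi" -> adata.obsm["X_lsi"]
+ ```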

- Recommended pattern (example showing several preprocessing options for the different data types, YMMV):
+ Correct pattern:

  ```json5
  {
@@ -389,35 +595,34 @@ Recommended pattern (example showing several preprocessing options for the diffe
    "modalities": [
      {
        "name": "rna",
-       "layer": "log1p",        // uses .X or .layers["log1p"]
+       "layer": "log1p",        // uses adata.layers["log1p"] (since X_key=="X")
        "X_key": "X",
-       "assume_log1p": true,    // already log-normalized
+       "assume_log1p": true,
        "likelihood": "gaussian"
      },
      {
        "name": "adt",
-       "layer": "counts",       // raw counts in .layers["counts"]
-       "X_key": "counts",
-       "assume_log1p": false,   // use raw for ZINB
+       "layer": "counts",       // uses adata.layers["counts"] (since X_key=="X")
+       "X_key": "X",
+       "assume_log1p": false,
        "likelihood": "zinb"
      },
      {
        "name": "atac",
-       "layer": "X_lsi",        // continuous LSI features
-       "X_key": "X_lsi",
+       "layer": null,           // ignored because X_key != "X"
+       "X_key": "X_lsi",        // uses adata.obsm["X_lsi"]
        "assume_log1p": false,
        "likelihood": "gaussian"
      }
    ]
  }
  }
-
  ```

  * Use `.layers["counts"]` when you want NB/ZINB/Poisson decoders.
- * Use continuous `.X` (log1p/CLR/LSI) when you want Gaussian/MSE decoders.
+ * Use continuous `.X` or `.obsm["X_lsi"]` when you want Gaussian/MSE decoders.

- > Jupyter Notebooks in this repository (UniVI/notebooks/) show recommended preprocessing per dataset for different data types and analyses. Depending on your research goals, you can use several different methods of preprocessing. The model is quite robust when it comes to learning underlying biology regardless of input data processing method used; the main key is that the decoder likelihood should roughly match the input distribution per-modality.
+ > Jupyter notebooks in this repository (UniVI/notebooks/) show recommended preprocessing per dataset for different data types and analyses. Depending on your research goals, you can use several different methods of preprocessing. The model is robust when it comes to learning underlying biology regardless of preprocessing; the key is that the decoder likelihood should roughly match the input distribution per-modality.

  ### 2) Train (CLI)

@@ -456,8 +661,8 @@

  ```bash
  python scripts/train_univi.py \
-   --config parameter_files/ \
-   --outdir saved_models/defaults_multiome_lite.json \
+   --config parameter_files/defaults_multiome_lite.json \
+   --outdir saved_models/multiome_lite_run1 \
    --data-root /path/to/your/data
  ```

@@ -481,7 +686,9 @@
    --data-root /path/to/your/data
  ```

- ### 3) Quickstart: run UniVI from Python / Jupyter
+ ---
+
+ ## Quickstart: run UniVI from Python / Jupyter

  If you prefer to stay inside a notebook or a Python script instead of calling the CLI, you can build the configs, model, and trainer directly.

@@ -500,33 +707,29 @@ from univi import (
    UniVIConfig,
    TrainingConfig,
  )
- from univi.data import MultiModalDataset
+ from univi.data import MultiModalDataset, align_paired_obs_names
  from univi.trainer import UniVITrainer
- ````
+ ```

- #### 1) Load preprocessed AnnData (paired cells)
+ ### 1) Load preprocessed AnnData (paired cells)

  ```python
- # Example: CITE-seq with RNA + ADT
  rna = sc.read_h5ad("path/to/rna_citeseq.h5ad")
  adt = sc.read_h5ad("path/to/adt_citeseq.h5ad")

- # Assumes rna.obs_names == adt.obs_names (same cells, same order)
- adata_dict = {
-     "rna": rna,
-     "adt": adt,
- }
+ adata_dict = {"rna": rna, "adt": adt}
+ adata_dict = align_paired_obs_names(adata_dict)  # ensures same obs_names/order
  ```
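+
+ Conceptually, paired alignment keeps the intersection of cell barcodes and puts every modality in the same order; a standalone sketch of the idea (not the library's code):
+
+ ```python
+ common = rna.obs_names.intersection(adt.obs_names)  # shared cell barcodes
+ rna, adt = rna[common].copy(), adt[common].copy()   # same cells, same order
+ ```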

- #### 2) Build `MultiModalDataset` and DataLoaders
+ ### 2) Build `MultiModalDataset` and DataLoaders (unsupervised)

  ```python
  device = "cuda" if torch.cuda.is_available() else "cpu"

  dataset = MultiModalDataset(
      adata_dict=adata_dict,
-     X_key="X",       # use .X from each AnnData for training
-     device=device,   # tensors moved to this device on-the-fly
+     X_key="X",
+     device=None,     # "cpu" or "cuda"
  )

  n_cells = rna.n_obs
@@ -542,29 +745,35 @@ val_ds = Subset(dataset, val_idx)

  batch_size = 256

- train_loader = DataLoader(
-     train_ds,
-     batch_size=batch_size,
-     shuffle=True,
-     num_workers=0,
- )
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
+ val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
+ ```

- val_loader = DataLoader(
-     val_ds,
-     batch_size=batch_size,
-     shuffle=False,
-     num_workers=0,
- )
+ ### 2b) (Optional) Supervised batches for Pattern A/B (`(x_dict, y)`)
+
+ If you use the classification head and/or label expert injection, supply `y` as integer class indices and mark unlabeled cells with `-1`.
+
+ ```python
+ y_codes = rna.obs["celltype"].astype("category").cat.codes.to_numpy()
+
+ dataset_sup = MultiModalDataset(adata_dict=adata_dict, X_key="X", labels=y_codes)
+
+ def collate_xy(batch):
+     xs, ys = zip(*batch)
+     x = {k: torch.stack([d[k] for d in xs], 0) for k in xs[0].keys()}
+     y = torch.as_tensor(ys, dtype=torch.long)
+     return x, y
+
+ train_loader = DataLoader(dataset_sup, batch_size=batch_size, shuffle=True, collate_fn=collate_xy)
  ```
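+
+ A training step then unpacks both pieces. A minimal manual loop, as a sketch (assumes a standard torch optimizer; the `UniVITrainer` used below handles this for you):
+
+ ```python
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+ for epoch in range(10):
+     for x_dict, y in train_loader:
+         out = model(x_dict, y=y, epoch=epoch)  # same forward call as Pattern A
+         optimizer.zero_grad()
+         out["loss"].backward()
+         optimizer.step()
+ ```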

- #### 3) Define UniVI configs (v1 vs UniVI-lite)
+ ### 3) Define UniVI configs (v1 vs UniVI-lite)

  ```python
- # UniVI model config (architecture + regularization)
  univi_cfg = UniVIConfig(
      latent_dim=40,
-     beta=5.0,    # KL weight
-     gamma=40.0,  # alignment weight (used differently in v1 vs lite)
+     beta=5.0,
+     gamma=40.0,
      encoder_dropout=0.1,
      decoder_dropout=0.0,
      encoder_batchnorm=True,
@@ -574,24 +783,11 @@ univi_cfg = UniVIConfig(
      align_anneal_start=0,
      align_anneal_end=25,
      modalities=[
-         ModalityConfig(
-             name="rna",
-             input_dim=rna.n_vars,
-             encoder_hidden=[1024, 512],
-             decoder_hidden=[512, 1024],
-             likelihood="nb",  # counts-like RNA
-         ),
-         ModalityConfig(
-             name="adt",
-             input_dim=adt.n_vars,
-             encoder_hidden=[256, 128],
-             decoder_hidden=[128, 256],
-             likelihood="nb",  # counts-like ADT
-         ),
+         ModalityConfig("rna", rna.n_vars, [1024, 512], [512, 1024], likelihood="nb"),
+         ModalityConfig("adt", adt.n_vars, [256, 128], [128, 256], likelihood="nb"),
      ],
  )

- # Training config (epochs, LR, device, etc.)
  train_cfg = TrainingConfig(
      n_epochs=200,
      batch_size=batch_size,
@@ -608,27 +804,47 @@ train_cfg = TrainingConfig(
  )
  ```

- #### 4) Choose the objective: **v1** vs **UniVI-lite**
-
- * **v1** (paper objective): cross-reconstruction + cross-posterior alignment.
-   * Best when batches are paired/pseudo-paired and you want explicit cross-prediction.
- * **lite** (aka `"v2"`): missing-modality friendly; trains even if some modalities are absent in a batch.
+ ### 4) Choose the objective + supervised option

  ```python
- # Option A: UniVI v1 (paper)
+ # Option A: UniVI v1 (unsupervised)
  model = UniVIMultiModalVAE(
      univi_cfg,
      loss_mode="v1",
-     v1_recon="cross",    # "cross" | "self" | "avg" | "moe" | "src:rna" etc.
-     v1_recon_mix=0.0,    # optional extra averaged-z recon weight
+     v1_recon="avg",
+     v1_recon_mix=0.0,
      normalize_v1_terms=True,
  ).to(device)

- # Option B: UniVI-lite (v2)
+ # Option B: UniVI-lite / v2 (unsupervised)
  # model = UniVIMultiModalVAE(univi_cfg, loss_mode="lite").to(device)
+
+ # Option C: Add classification head (Pattern A; works in lite/v2 AND v1)
+ # n_classes = int(y_codes.max() + 1)
+ # model = UniVIMultiModalVAE(
+ #     univi_cfg,
+ #     loss_mode="lite",
+ #     n_label_classes=n_classes,
+ #     label_loss_weight=1.0,
+ #     label_ignore_index=-1,
+ #     classify_from_mu=True,
+ # ).to(device)
+
+ # Option D: Add label expert injection into fusion (Pattern B; lite/v2 ONLY)
+ # model = UniVIMultiModalVAE(
+ #     univi_cfg,
+ #     loss_mode="lite",
+ #     n_label_classes=n_classes,
+ #     label_loss_weight=1.0,
+ #     use_label_encoder=True,
+ #     label_moe_weight=1.0,
+ #     unlabeled_logvar=20.0,
+ #     label_encoder_warmup=5,
+ #     label_ignore_index=-1,
+ # ).to(device)
  ```

- #### 5) Train inside Python / Jupyter
+ ### 5) Train inside Python / Jupyter

  ```python
  trainer = UniVITrainer(
@@ -639,26 +855,36 @@ trainer = UniVITrainer(
      device=device,
  )

- history = trainer.fit()  # runs the training loop
+ history = trainer.fit()
  ```

- `history` typically contains per-epoch loss and metric curves. After training, you can reuse `model` directly for:
-
- * Computing latent embeddings (`encode_modalities` / `mixture_of_experts`)
- * Cross-modal reconstruction (forward passes with different modality subsets)
- * Exporting `z` to AnnData or NumPy for downstream analysis (UMAP, clustering, DE, etc.)
-
- #### 6) Write latent `z` into AnnData `.obsm["X_univi"]`
+ ### 6) Write latent `z` into AnnData `.obsm["X_univi"]`

  ```python
  from univi import write_univi_latent

- Z = write_univi_latent(model, adata_dict, obsm_key="X_univi", device=device)
+ Z = write_univi_latent(model, adata_dict, obsm_key="X_univi", device=device, use_mean=True)
  print("Embedding shape:", Z.shape)
  ```

  > **Tip**
- > If you want deterministic embeddings for plotting, add the argument `use_mean=True` to the `write_univi_latent` function so you store mu_z instead of a sampled z. Of note, a sampled z is a stochastic sampling from each latent distribution which allows for generative modeling, while mu_z uses the means of each latent distribution for a more informative view of overall latent structure.
+ > Use `use_mean=True` for deterministic plotting/UMAP. Sampling (`use_mean=False`) is stochastic and useful for generative behavior.
+
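+ Downstream, the stored embedding plugs into standard scanpy workflows (assuming `write_univi_latent` wrote the matrix into each AnnData's `.obsm["X_univi"]`):
+
+ ```python
+ import scanpy as sc
+
+ sc.pp.neighbors(rna, use_rep="X_univi")  # kNN graph on the UniVI latent
+ sc.tl.umap(rna)
+ sc.pl.umap(rna, color="celltype")
+ ```
+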
+ ---
+
+ ## Evaluating / encoding: choosing the latent representation
+
+ Some utilities (e.g., `encode_adata`) support selecting which embedding to return:
+
+ * `"moe_mean"` / `"moe_sample"`: fused latent (MoE/PoE)
+ * `"modality_mean"` / `"modality_sample"`: per-modality latent
+
+ ```python
+ from univi.evaluation import encode_adata
+
+ Z_rna = encode_adata(model, rna, modality="rna", device=device, layer="counts", latent="modality_mean")
+ Z_moe = encode_adata(model, rna, modality="rna", device=device, layer="counts", latent="moe_mean")
+ ```
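+
+ As a quick sanity check, you can measure how far each cell's per-modality latent sits from the fused one (assuming both calls return NumPy arrays of shape `(n_cells, latent_dim)`):
+
+ ```python
+ import numpy as np
+
+ d = np.linalg.norm(Z_rna - Z_moe, axis=1)  # per-cell distance between latents
+ print("mean ||z_rna - z_moe||:", float(d.mean()))
+ ```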

  ---

@@ -819,3 +1045,5 @@ Typical evaluation outputs include:

  For richer, exploratory workflows (TEA-seq tri-modal integration, Multiome RNA+ATAC, non-paired matching, etc.), see the notebooks in `notebooks/`.

+ ---
+