univi-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,908 @@
1
+ Metadata-Version: 2.4
2
+ Name: univi
3
+ Version: 0.3.4
4
+ Summary: UniVI: a scalable multi-modal variational autoencoder toolkit for seamless integration and analysis of multimodal single-cell data.
5
+ Author-email: "Andrew J. Ashford" <ashforda@ohsu.edu>
6
+ License: MIT License
7
+
8
+ Copyright (c) 2025 Andrew J. Ashford
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the “Software”), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+
28
+ Project-URL: Homepage, https://github.com/Ashford-A/UniVI
29
+ Project-URL: Repository, https://github.com/Ashford-A/UniVI
30
+ Project-URL: Bug Tracker, https://github.com/Ashford-A/UniVI/issues
31
+ Classifier: License :: OSI Approved :: MIT License
32
+ Classifier: Programming Language :: Python :: 3
33
+ Classifier: Programming Language :: Python :: 3.10
34
+ Classifier: Programming Language :: Python :: 3 :: Only
35
+ Requires-Python: >=3.10
36
+ Description-Content-Type: text/markdown
37
+ License-File: LICENSE
38
+ Requires-Dist: numpy>=1.26
39
+ Requires-Dist: scipy>=1.11
40
+ Requires-Dist: pandas>=2.1
41
+ Requires-Dist: anndata>=0.10
42
+ Requires-Dist: scanpy>=1.11
43
+ Requires-Dist: torch>=2.2
44
+ Requires-Dist: scikit-learn>=1.3
45
+ Requires-Dist: h5py>=3.10
46
+ Requires-Dist: pyyaml>=6.0
47
+ Requires-Dist: matplotlib>=3.8
48
+ Requires-Dist: seaborn>=0.13
49
+ Requires-Dist: igraph>=0.11
50
+ Requires-Dist: leidenalg>=0.10
51
+ Requires-Dist: tqdm>=4.66
52
+ Requires-Dist: openpyxl>=3.1
53
+ Provides-Extra: bench
54
+ Requires-Dist: harmonypy>=0.0.9; extra == "bench"
55
+ Dynamic: license-file
56
+
57
+ # UniVI
58
+
59
+ [![PyPI version](https://img.shields.io/pypi/v/univi)](https://pypi.org/project/univi/)
60
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/univi.svg?v=0.3.4)](https://pypi.org/project/univi/)
61
+
62
+ <picture>
63
+ <!-- Dark mode (GitHub supports this; PyPI may ignore <source>) -->
64
+ <source media="(prefers-color-scheme: dark)"
65
+ srcset="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.3.4/assets/figures/univi_overview_dark.png">
66
+ <!-- Light mode / fallback (works on GitHub + PyPI) -->
67
+ <img src="https://raw.githubusercontent.com/Ashford-A/UniVI/v0.3.4/assets/figures/univi_overview_light.png"
68
+ alt="UniVI overview and evaluation roadmap"
69
+ width="100%">
70
+ </picture>
71
+
72
+ **UniVI** is a **multi-modal variational autoencoder (VAE)** framework for aligning and integrating single-cell modalities such as RNA, ADT (CITE-seq), and ATAC.
73
+
74
+ It’s designed for experiments like:
75
+
76
+ - **Joint embedding** of paired multimodal data (CITE-seq, Multiome, TEA-seq)
77
+ - **Zero-shot projection** of external unimodal cohorts into a paired “bridge” latent
78
+ - **Cross-modal reconstruction / imputation** (RNA→ADT, ATAC→RNA, etc.)
79
+ - **Denoising** via learned generative decoders
80
+ - **Evaluation** (FOSCTTM, modality mixing, label transfer, feature recovery)
81
+ - **Optional supervised heads** for harmonized annotation and domain confusion
82
+ - **Optional transformer encoders** (per-modality and/or fused multimodal transformer posterior)
83
+ - **Token-level hooks** for interpretability (top-k indices; optional attention maps if enabled)
84
+
85
+ ---
86
+
87
+ ## Preprint
88
+
89
+ If you use UniVI in your work, please cite:
90
+
91
+ > Ashford AJ, Enright T, Nikolova O, Demir E.
92
+ > **Unifying Multimodal Single-Cell Data Using a Mixture of Experts β-Variational Autoencoder-Based Framework.**
93
+ > *bioRxiv* (2025). doi: [10.1101/2025.02.28.640429](https://www.biorxiv.org/content/10.1101/2025.02.28.640429v1.full)
94
+
95
+ ```bibtex
96
+ @article{Ashford2025UniVI,
97
+ title = {Unifying Multimodal Single-Cell Data Using a Mixture of Experts β-Variational Autoencoder-Based Framework},
98
+ author = {Ashford, Andrew J. and Enright, Trevor and Nikolova, Olga and Demir, Emek},
99
+ journal = {bioRxiv},
100
+ year = {2025},
101
+ doi = {10.1101/2025.02.28.640429},
102
+ url = {https://www.biorxiv.org/content/10.1101/2025.02.28.640429v1}
103
+ }
104
+ ```
105
+
106
+ ---
107
+
108
+ ## License
109
+
110
+ MIT License — see `LICENSE`.
111
+
112
+ ---
113
+
114
+ ## Repository structure
115
+
116
+ ```text
117
+ UniVI/
118
+ ├── README.md # Project overview, installation, quickstart
119
+ ├── LICENSE # MIT license text file
120
+ ├── pyproject.toml # Python packaging config (pip / PyPI)
121
+ ├── assets/ # Static assets used by README/docs
122
+ │ └── figures/ # Schematic figure(s) for repository front page
123
+ ├── conda.recipe/ # Conda build recipe (for conda-build)
124
+ │ └── meta.yaml
125
+ ├── envs/ # Example conda environments
126
+ │ ├── UniVI_working_environment.yml
127
+ │ ├── UniVI_working_environment_v2_full.yml
128
+ │ ├── UniVI_working_environment_v2_minimal.yml
129
+ │ └── univi_env.yml # Recommended env (CUDA-friendly)
130
+ ├── data/ # Small example data notes (datasets are typically external)
131
+ │ └── README.md # Notes on data sources / formats
132
+ ├── notebooks/ # Jupyter notebooks (demos & benchmarks)
133
+ │ ├── UniVI_CITE-seq_*.ipynb
134
+ │ ├── UniVI_10x_Multiome_*.ipynb
135
+ │ └── UniVI_TEA-seq_*.ipynb
136
+ ├── parameter_files/ # JSON configs for model + training + data selectors
137
+ │ ├── defaults_*.json # Default configs (per experiment)
138
+ │ └── params_*.json # Example “named” configs (RNA, ADT, ATAC, etc.)
139
+ ├── scripts/ # Reproducible entry points (revision-friendly)
140
+ │ ├── train_univi.py # Train UniVI from a parameter JSON
141
+ │ ├── evaluate_univi.py # Evaluate trained models (FOSCTTM, label transfer, etc.)
142
+ │ ├── benchmark_univi_citeseq.py # CITE-seq-specific benchmarking script
143
+ │ ├── run_multiome_hparam_search.py
144
+ │ ├── run_frequency_robustness.py # Composition/frequency mismatch robustness
145
+ │ ├── run_do_not_integrate_detection.py # “Do-not-integrate” unmatched population demo
146
+ │ ├── run_benchmarks.py # Unified wrapper (includes optional Harmony baseline)
147
+ │ └── revision_reproduce_all.sh # One-click: reproduces figures + supplemental tables
148
+ └── univi/ # UniVI Python package (importable as `import univi`)
149
+ ├── __init__.py # Package exports and __version__
150
+ ├── __main__.py # Enables: `python -m univi ...`
151
+ ├── cli.py # Minimal CLI (e.g., export-s1, encode)
152
+ ├── pipeline.py # Config-driven model+data loading; latent encoding helpers
153
+ ├── diagnostics.py # Exports Supplemental_Table_S1.xlsx (env + hparams + dataset stats)
154
+ ├── config.py # Config dataclasses (UniVIConfig, ModalityConfig, TrainingConfig)
155
+ ├── data.py # Dataset wrappers + matrix selectors (layer/X_key, obsm support)
156
+ ├── evaluation.py # Metrics (FOSCTTM, mixing, label transfer, feature recovery)
157
+ ├── matching.py # Modality matching / alignment helpers
158
+ ├── objectives.py # Losses (ELBO variants, KL/alignment annealing, etc.)
159
+ ├── plotting.py # Plotting helpers + consistent style defaults
160
+ ├── trainer.py # UniVITrainer: training loop, logging, checkpointing
161
+ ├── interpretability.py # Helper scripts for transformer token weight interpretability
162
+ ├── figures/ # Package-internal figure assets (placeholder)
163
+ │ └── .gitkeep
164
+ ├── models/ # VAE architectures + building blocks
165
+ │ ├── __init__.py
166
+ │ ├── mlp.py # Shared MLP building blocks
167
+ │ ├── encoders.py # Modality encoders (MLP + transformer + fused transformer)
168
+ │ ├── decoders.py # Likelihood-specific decoders (NB, ZINB, Gaussian, etc.)
169
+ │ ├── transformer.py # Transformer blocks + encoder (+ optional attn bias support)
170
+ │ ├── tokenizer.py # Tokenization configs/helpers (top-k / patch)
171
+ │ └── univi.py # Core UniVI multi-modal VAE
172
+ ├── hyperparam_optimization/ # Hyperparameter search scripts
173
+ │ ├── __init__.py
174
+ │ ├── common.py
175
+ │ ├── run_adt_hparam_search.py
176
+ │ ├── run_atac_hparam_search.py
177
+ │ ├── run_citeseq_hparam_search.py
178
+ │ ├── run_multiome_hparam_search.py
179
+ │ ├── run_rna_hparam_search.py
180
+ │ └── run_teaseq_hparam_search.py
181
+ └── utils/ # General utilities
182
+ ├── __init__.py
183
+ ├── io.py # I/O helpers (AnnData, configs, checkpoints)
184
+ ├── logging.py # Logging configuration / progress reporting
185
+ ├── seed.py # Reproducibility helpers (seeding RNGs)
186
+ ├── stats.py # Small statistical helpers / transforms
187
+ └── torch_utils.py # PyTorch utilities (device, tensor helpers)
188
+ ```
189
+
190
+ ---
191
+
192
+ ## Generated outputs
193
+
194
+ Most entry-point scripts write results into a user-specified output directory (commonly `runs/`), which is not tracked in git.
195
+
196
+ A typical `runs/` folder produced by `scripts/revision_reproduce_all.sh` looks like:
197
+
198
+ ```text
199
+ runs/
200
+ └── <run_name>/ # user-chosen run name (often includes dataset + date)
201
+ ├── checkpoints/ # model/trainer state for resuming or export
202
+ │ ├── univi_checkpoint.pt # primary checkpoint (model + optimizer + config, if enabled)
203
+ │ └── best.pt # optional: best-val checkpoint (if early stopping enabled)
204
+ ├── eval/ # evaluation summaries and derived plots
205
+ │ ├── metrics.json # machine-readable metrics summary
206
+ │ ├── metrics.csv # flat table for quick comparisons
207
+ │ └── plots/ # optional: UMAPs, heatmaps, and benchmark figures
208
+ ├── embeddings/ # optional: exported latents for downstream analysis
209
+ │ ├── mu_z.npy # fused mean embedding (cells x latent_dim)
210
+ │ ├── modality_mu/ # per-modality embeddings q(z|x_m)
211
+ │ │ ├── rna.npy
212
+ │ │ ├── adt.npy
213
+ │ │ └── atac.npy
214
+ │ └── obs_names.txt # row order for embeddings (safe joins)
215
+ ├── reconstructions/ # optional: recon and cross-recon exports
216
+ │ ├── rna_from_rna.npy # denoised reconstruction
217
+ │ ├── adt_from_adt.npy
218
+ │ ├── adt_from_rna.npy # cross-modal imputation example
219
+ │ └── rna_from_atac.npy
220
+ ├── robustness/ # robustness experiments (frequency mismatch, DnI, etc.)
221
+ │ ├── frequency_perturbation_results.csv
222
+ │ ├── frequency_perturbation_plot.png
223
+ │ ├── frequency_perturbation_plot.pdf
224
+ │ ├── do_not_integrate_summary.csv
225
+ │ ├── do_not_integrate_plot.png
226
+ │ └── do_not_integrate_plot.pdf
227
+ ├── benchmarks/ # baseline comparisons (optionally includes Harmony, etc.)
228
+ │ ├── results.csv
229
+ │ ├── results.png
230
+ │ └── results.pdf
231
+ ├── tables/
232
+ │ └── Supplemental_Table_S1.xlsx # environment + hparams + dataset statistics snapshot
233
+ └── logs/
234
+ ├── train.log # training log (stdout/stderr capture)
235
+ └── history.csv # per-epoch train/val traces (if enabled)
236
+ ```
237
+
238
+ (Exact subfolders vary by script and flags; the layout above shows the common outputs across the pipeline.)
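+
+ If a run exported the optional embeddings and metrics shown above, they can be loaded back into Python for downstream analysis. A minimal sketch (the run name is hypothetical; the file paths follow the layout above):
+
+ ```python
+ import json
+
+ import numpy as np
+ import pandas as pd
+
+ run_dir = "runs/my_citeseq_run"  # hypothetical run name; substitute your own
+
+ # Machine-readable metrics summary
+ with open(f"{run_dir}/eval/metrics.json") as fh:
+     metrics = json.load(fh)
+
+ # Fused latent plus the cell order it was written in (for safe joins)
+ mu_z = np.load(f"{run_dir}/embeddings/mu_z.npy")  # (cells, latent_dim)
+ obs_names = pd.read_csv(f"{run_dir}/embeddings/obs_names.txt", header=None)[0].astype(str)
+
+ emb = pd.DataFrame(mu_z, index=obs_names)
+ print(metrics)
+ print(emb.shape)
+ ```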
239
+
240
+ ---
241
+
242
+ ## Installation
243
+
244
+ ### Install via PyPI
245
+
246
+ ```bash
247
+ pip install univi
248
+ ```
249
+
250
+ > **Note:** UniVI requires `torch`. If `import torch` fails, install PyTorch for your platform/CUDA from:
251
+ > [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
252
+
253
+ ### Development install (from source)
254
+
255
+ ```bash
256
+ git clone https://github.com/Ashford-A/UniVI.git
257
+ cd UniVI
258
+
259
+ conda env create -f envs/univi_env.yml
260
+ conda activate univi_env
261
+
262
+ pip install -e .
263
+ ```
264
+
265
+ ### (Optional) Install via conda / mamba
266
+
267
+ ```bash
268
+ conda install -c conda-forge univi
269
+ # or
270
+ mamba install -c conda-forge univi
271
+ ```
272
+
273
+ UniVI is also installable from a custom channel:
274
+
275
+ ```bash
276
+ conda install ashford-a::univi
277
+ # or
278
+ mamba install ashford-a::univi
279
+ ```
280
+
281
+ ---
282
+
283
+ ## Data expectations (high-level)
284
+
285
+ UniVI expects **per-modality AnnData** objects with matching cells (fully paired data, or data paired consistently across modalities).
286
+
287
+ Typical expectations:
288
+
289
+ * Each modality is an `AnnData` with the same `obs_names` (same cells, same order)
290
+ * Raw counts often live in `.layers["counts"]`
291
+ * A processed training representation lives in `.X` (or `.obsm["X_*"]` for ATAC LSI)
292
+ * Decoder likelihoods should roughly match the training representation:
293
+
294
+ * counts-like → `nb` / `zinb` / `poisson`
295
+ * continuous → `gaussian` / `mse`
296
+
297
+ See `notebooks/` for end-to-end preprocessing examples.
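+
+ As a quick pre-flight check, these expectations can be asserted directly on your objects. A minimal sketch, assuming raw counts live in `.layers["counts"]` as described above:
+
+ ```python
+ import scanpy as sc
+
+ rna = sc.read_h5ad("path/to/rna.h5ad")
+ adt = sc.read_h5ad("path/to/adt.h5ad")
+
+ # Same cells, same order, in every modality
+ assert (rna.obs_names == adt.obs_names).all(), "obs_names must match across modalities"
+
+ # Raw counts kept alongside the processed training representation in .X
+ assert "counts" in rna.layers and "counts" in adt.layers
+
+ # Pick decoder likelihoods to match what the model will see:
+ # counts-like .X -> "nb" / "zinb" / "poisson"; continuous .X -> "gaussian" / "mse"
+ ```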
298
+
299
+ ---
300
+
301
+ ## Training objectives (v1 vs v2/lite)
302
+
303
+ UniVI supports two main training regimes:
304
+
305
+ * **UniVI v1 (“paper”)**
306
+ Per-modality posteriors + flexible reconstruction scheme (cross/self/avg) + posterior alignment across modalities.
307
+
308
+ * **UniVI v2 / lite**
309
+ A fused posterior (precision-weighted MoE/PoE-style by default; optional fused transformer) + per-modality recon + β·KL + γ·alignment.
310
+ Convenient for 3+ modalities and “loosely paired” settings.
311
+
312
+ You choose via `loss_mode` at model construction (Python) or config JSON (CLI scripts).
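+
+ In Python this is just a constructor argument (the same calls appear in the quickstart below):
+
+ ```python
+ # v1 ("paper") objective with averaged cross/self reconstruction
+ model_v1 = UniVIMultiModalVAE(univi_cfg, loss_mode="v1", v1_recon="avg")
+
+ # v2 / lite objective with a fused posterior
+ model_v2 = UniVIMultiModalVAE(univi_cfg, loss_mode="v2")
+ ```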
313
+
314
+ ---
315
+
316
+ ## Quickstart (Python / Jupyter)
317
+
318
+ Below is a minimal paired **CITE-seq (RNA + ADT)** example using `MultiModalDataset` + `UniVITrainer`.
319
+
320
+ ```python
321
+ import numpy as np
322
+ import scanpy as sc
323
+ import torch
324
+ from torch.utils.data import DataLoader, Subset
325
+
326
+ from univi import UniVIMultiModalVAE, ModalityConfig, UniVIConfig, TrainingConfig
327
+ from univi.data import MultiModalDataset, align_paired_obs_names
328
+ from univi.trainer import UniVITrainer
329
+ ```
330
+
331
+ ### 1) Load paired AnnData
332
+
333
+ ```python
334
+ rna = sc.read_h5ad("path/to/rna_citeseq.h5ad")
335
+ adt = sc.read_h5ad("path/to/adt_citeseq.h5ad")
336
+
337
+ adata_dict = {"rna": rna, "adt": adt}
338
+ adata_dict = align_paired_obs_names(adata_dict) # ensures same obs_names/order
339
+ ```
340
+
341
+ ### 2) Dataset + dataloaders
342
+
343
+ ```python
344
+ device = "cuda" if torch.cuda.is_available() else "cpu"
345
+
346
+ dataset = MultiModalDataset(
347
+ adata_dict=adata_dict,
348
+ X_key="X", # uses .X by default
349
+ device=None, # dataset returns CPU tensors; model moves to GPU
350
+ )
351
+
352
+ n = rna.n_obs
353
+ idx = np.arange(n)
354
+ rng = np.random.default_rng(0)
355
+ rng.shuffle(idx)
356
+ split = int(0.8 * n)
357
+ train_idx, val_idx = idx[:split], idx[split:]
358
+
359
+ train_ds = Subset(dataset, train_idx)
360
+ val_ds = Subset(dataset, val_idx)
361
+
362
+ train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=0)
363
+ val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=0)
364
+ ```
365
+
366
+ ### 3) Config + model
367
+
368
+ ```python
369
+ univi_cfg = UniVIConfig(
370
+ latent_dim=40,
371
+ beta=1.5,
372
+ gamma=2.5,
373
+ encoder_dropout=0.1,
374
+ decoder_dropout=0.0,
375
+ modalities=[
376
+ ModalityConfig("rna", rna.n_vars, [512, 256, 128], [128, 256, 512], likelihood="nb"),
377
+ ModalityConfig("adt", adt.n_vars, [128, 64], [64, 128], likelihood="nb"),
378
+ ],
379
+ )
380
+
381
+ train_cfg = TrainingConfig(
382
+ n_epochs=1000,
383
+ batch_size=256,
384
+ lr=1e-3,
385
+ weight_decay=1e-4,
386
+ device=device,
387
+ log_every=20,
388
+ grad_clip=5.0,
389
+ early_stopping=True,
390
+ patience=50,
391
+ )
392
+
393
+ # v1 (paper)
394
+ model = UniVIMultiModalVAE(
395
+ univi_cfg,
396
+ loss_mode="v1",
397
+ v1_recon="avg",
398
+ normalize_v1_terms=True,
399
+ ).to(device)
400
+
401
+ # Or: v2/lite
402
+ # model = UniVIMultiModalVAE(univi_cfg, loss_mode="v2").to(device)
403
+ ```
404
+
405
+ ### 4) Train
406
+
407
+ ```python
408
+ trainer = UniVITrainer(
409
+ model=model,
410
+ train_loader=train_loader,
411
+ val_loader=val_loader,
412
+ train_cfg=train_cfg,
413
+ device=device,
414
+ )
415
+
416
+ history = trainer.fit()
417
+ ```
418
+
419
+ ---
420
+
421
+ ## Mixed precision (AMP)
422
+
423
+ AMP (automatic mixed precision) can reduce VRAM usage and speed up training on GPUs by running selected ops in lower precision (fp16 or bf16) while keeping numerically sensitive parts in fp32.
424
+
425
+ If your trainer supports AMP flags, prefer bf16 where available. If using fp16, gradient scaling is typically used internally to avoid underflow.
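+
+ For reference, a generic PyTorch AMP step looks like the sketch below. This is a hand-rolled-loop pattern, not `UniVITrainer` internals; whether the trainer exposes AMP flags depends on your configuration.
+
+ ```python
+ import torch
+
+ # Assumes a manual loop with `model`, `optimizer`, `train_loader`, and `device` already set up.
+ amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+ scaler = torch.cuda.amp.GradScaler(enabled=(amp_dtype is torch.float16))  # fp16 needs loss scaling
+
+ for batch in train_loader:
+     x_dict = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
+     optimizer.zero_grad(set_to_none=True)
+     with torch.autocast(device_type="cuda", dtype=amp_dtype):
+         out = model(x_dict)
+         loss = out["loss"]
+     scaler.scale(loss).backward()  # scaling is a no-op when the scaler is disabled (bf16)
+     scaler.step(optimizer)
+     scaler.update()
+ ```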
426
+
427
+ ---
428
+
429
+ ## Checkpointing and resuming
430
+
431
+ Training is often run on clusters, so checkpoints are treated as first-class outputs.
432
+
433
+ Typical checkpoints contain:
434
+
435
+ * model weights
436
+ * optimizer state (for faithful resumption)
437
+ * training config/model config (for reproducibility)
438
+ * optional AMP scaler state (when using fp16 AMP)
439
+
440
+ See `univi/utils/io.py` for the exact checkpoint read/write helpers used by the trainer.
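+
+ The underlying pattern is standard PyTorch; the dictionary keys below are illustrative, and the authoritative format is whatever `univi/utils/io.py` writes:
+
+ ```python
+ import torch
+
+ # Saving (illustrative keys, not necessarily UniVI's exact schema)
+ torch.save(
+     {
+         "model_state": model.state_dict(),
+         "optimizer_state": optimizer.state_dict(),  # assumes a hand-rolled loop with an `optimizer`
+         "univi_config": univi_cfg,
+         "train_config": train_cfg,
+         "epoch": epoch,
+     },
+     "runs/my_run/checkpoints/univi_checkpoint.pt",
+ )
+
+ # Resuming
+ ckpt = torch.load("runs/my_run/checkpoints/univi_checkpoint.pt", map_location=device)
+ model.load_state_dict(ckpt["model_state"])
+ optimizer.load_state_dict(ckpt["optimizer_state"])
+ start_epoch = ckpt["epoch"] + 1
+ ```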
441
+
442
+ ---
443
+
444
+ ## Classification (built-in heads)
445
+
446
+ UniVI supports **in-model supervised classification heads** (single “legacy” label head and/or multi-head auxiliary decoders). This is useful for:
447
+
448
+ * harmonized cell-type annotation (e.g., bridge → projected cohorts)
449
+ * batch/tech/patient prediction (sanity checks, confounding)
450
+ * adversarial domain confusion via gradient reversal (GRL)
451
+ * multi-task setups (e.g., celltype + patient + mutation flags)
452
+
453
+ ### How it works
454
+
455
+ * Heads are configured via `UniVIConfig.class_heads` using `ClassHeadConfig`.
456
+ * Training targets are passed as `y`, a **dict mapping head name → integer class indices** with shape `(B,)`.
457
+ * Unlabeled entries should use `ignore_index` (default `-1`) and are masked out automatically.
458
+ * Each head can be delayed with `warmup` and weighted with `loss_weight`.
459
+ * Set `adversarial=True` for GRL heads (domain confusion).
460
+
461
+ ### 1) Add heads in the config
462
+
463
+ ```python
464
+ from univi.config import ClassHeadConfig
465
+
466
+ univi_cfg = UniVIConfig(
467
+ latent_dim=40,
468
+ beta=1.5,
469
+ gamma=2.5,
470
+ modalities=[
471
+ ModalityConfig("rna", rna.n_vars, [512,256,128], [128,256,512], likelihood="nb"),
472
+ ModalityConfig("adt", adt.n_vars, [128,64], [64,128], likelihood="nb"),
473
+ ],
474
+ class_heads=[
475
+ ClassHeadConfig(
476
+ name="celltype",
477
+ n_classes=int(rna.obs["celltype"].astype("category").cat.categories.size),
478
+ loss_weight=1.0,
479
+ ignore_index=-1,
480
+ from_mu=True, # classify from mu_z (more stable)
481
+ warmup=0,
482
+ ),
483
+ ClassHeadConfig(
484
+ name="batch",
485
+ n_classes=int(rna.obs["batch"].astype("category").cat.categories.size),
486
+ loss_weight=0.2,
487
+ ignore_index=-1,
488
+ from_mu=True,
489
+ warmup=10,
490
+ adversarial=True, # GRL head (domain confusion)
491
+ adv_lambda=1.0,
492
+ ),
493
+ ],
494
+ )
495
+ ```
496
+
497
+ Optional: attach readable label names (for your own decoding later):
498
+
499
+ ```python
500
+ model.set_head_label_names("celltype", list(rna.obs["celltype"].astype("category").cat.categories))
501
+ model.set_head_label_names("batch", list(rna.obs["batch"].astype("category").cat.categories))
502
+ ```
503
+
504
+ ### 2) Pass `y` to the model during training
505
+
506
+ Example pattern (construct labels from arrays aligned to dataset order):
507
+
508
+ ```python
509
+ celltype_codes = rna.obs["celltype"].astype("category").cat.codes.to_numpy()
510
+ batch_codes = rna.obs["batch"].astype("category").cat.codes.to_numpy()
511
+
512
+ y = {
513
+ "celltype": torch.tensor(celltype_codes[batch_idx], device=device),
514
+ "batch": torch.tensor(batch_codes[batch_idx], device=device),
515
+ }
516
+
517
+ out = model(x_dict, epoch=epoch, y=y)
518
+ loss = out["loss"]
519
+ loss.backward()
520
+ ```
521
+
522
+ When labels are provided, the forward output can include:
523
+
524
+ * `out["head_logits"]`: dict of logits `(B, n_classes)` per head
525
+ * `out["head_losses"]`: mean CE per head (masked by `ignore_index`)
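+
+ For example, hard predictions can be read off the logits directly:
+
+ ```python
+ # Argmax over each head's logits -> predicted class indices, shape (B,)
+ preds = {name: logits.argmax(dim=-1) for name, logits in out["head_logits"].items()}
+ ```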
526
+
527
+ ### 3) Predict heads after training
528
+
529
+ ```python
530
+ model.eval()
531
+ batch = next(iter(val_loader))
532
+ x_dict = {k: v.to(device) for k, v in batch.items()}
533
+
534
+ with torch.no_grad():
535
+ probs = model.predict_heads(x_dict, return_probs=True)
536
+
537
+ for head_name, P in probs.items():
538
+ print(head_name, P.shape) # (B, n_classes)
539
+ ```
540
+
541
+ To inspect which heads exist + their settings:
542
+
543
+ ```python
544
+ meta = model.get_classification_meta()
545
+ print(meta)
546
+ ```
547
+
548
+ ---
549
+
550
+ ## After training: what you can do with a trained UniVI model
551
+
552
+ UniVI isn’t just “map to latent”. With a trained model you can typically:
553
+
554
+ * **Encode modality-specific posteriors** `q(z|x_rna)`, `q(z|x_adt)`, …
555
+ * **Encode a fused posterior** (MoE/PoE by default; optional fused multimodal transformer posterior)
556
+ * **Denoise / reconstruct** inputs via the learned decoders
557
+ * **Cross-reconstruct / impute** across modalities (RNA→ADT, ATAC→RNA, etc.)
558
+ * **Evaluate alignment** (FOSCTTM, Recall@k, modality mixing, label transfer)
559
+ * **Predict supervised targets** via built-in classification heads (if enabled)
560
+ * **Inspect uncertainty** via per-modality posterior means/variances
561
+ * (Optional) **Inspect transformer token metadata** (top-k indices; attention maps when enabled)
562
+
563
+ ### Fused posterior options
564
+
565
+ UniVI can produce a fused latent in two ways:
566
+
567
+ * Default: **precision-weighted MoE/PoE fusion** over per-modality posteriors
568
+ * Optional: **fused multimodal transformer posterior** (`fused_encoder_type="multimodal_transformer"`)
569
+
570
+ In both cases, the standard embedding used for plotting/neighbors is the fused mean:
571
+
572
+ ```python
573
+ mu_z, logvar_z, z = model.encode_fused(x_dict, use_mean=True)
574
+ ```
575
+
576
+ ### 1) Encode embeddings for plotting / neighbors (built-in)
577
+
578
+ Use `encode_adata` to get either fused (MoE/PoE) or modality-specific latents directly from an AnnData.
579
+
580
+ ```python
581
+ import scanpy as sc
582
+ import torch
583
+ from univi.evaluation import encode_adata
584
+
585
+ device = "cuda" if torch.cuda.is_available() else "cpu"
586
+ model.eval()
587
+
588
+ # Fused latent (MoE/PoE) from a single observed modality
589
+ Z_fused = encode_adata(
590
+ model,
591
+ rna,
592
+ modality="rna",
593
+ device=device,
594
+ layer="counts", # or None to use .X
595
+ latent="moe_mean", # {"moe_mean","moe_sample","modality_mean","modality_sample"}
596
+ )
597
+
598
+ # Modality-specific latent (projection / diagnostics)
599
+ Z_rna = encode_adata(
600
+ model,
601
+ rna,
602
+ modality="rna",
603
+ device=device,
604
+ layer="counts",
605
+ latent="modality_mean",
606
+ )
607
+
608
+ rna.obsm["X_univi_fused"] = Z_fused
609
+ rna.obsm["X_univi_rna"] = Z_rna
610
+
611
+ sc.pp.neighbors(rna, use_rep="X_univi_fused")
612
+ sc.tl.umap(rna)
613
+ sc.pl.umap(rna, color=["celltype"], frameon=False)
614
+ ```
615
+
616
+ ### 2) Evaluate paired alignment (FOSCTTM, Recall@k, mixing, label transfer)
617
+
618
+ `evaluate_alignment` is a figure-ready wrapper. It can take precomputed `Z1/Z2`, or compute embeddings from AnnData via `encode_adata`.
619
+
620
+ ```python
621
+ from univi.evaluation import evaluate_alignment
622
+
623
+ # For paired data, you typically pass modality-specific latents for the two modalities
624
+ res = evaluate_alignment(
625
+ model=model,
626
+ adata1=rna,
627
+ adata2=adt,
628
+ mod1="rna",
629
+ mod2="adt",
630
+ device=device,
631
+ layer1="counts",
632
+ layer2="counts",
633
+ latent="modality_mean",
634
+ metric="euclidean",
635
+ recall_ks=(1, 5, 10),
636
+ k_mixing=20,
637
+ k_transfer=15,
638
+ # optional label transfer inputs:
639
+ # labels_source=rna.obs["celltype"].to_numpy(),
640
+ # labels_target=adt.obs["celltype"].to_numpy(),
641
+ )
642
+
643
+ print(res) # dict includes foscttm(+sem), recall@k(+sem), modality_mixing(+sem), label transfer (optional)
644
+ ```
645
+
646
+ ### 3) Denoise / reconstruct a modality (built-in)
647
+
648
+ `denoise_adata` runs “encode modality → decode same modality” and can write to a layer.
649
+
650
+ ```python
651
+ from univi.evaluation import denoise_adata
652
+
653
+ Xhat_rna = denoise_adata(
654
+ model,
655
+ rna,
656
+ modality="rna",
657
+ device=device,
658
+ layer="counts",
659
+ out_layer="univi_denoised", # writes rna.layers["univi_denoised"]
660
+ )
661
+
662
+ # Quick marker plots from denoised values:
663
+ import scanpy as sc
664
+ markers = ["TRAC", "NKG7", "LYZ", "MS4A1", "CD79A"]
665
+
666
+ rna_d = rna.copy()
667
+ rna_d.X = rna_d.layers["univi_denoised"]
668
+ sc.pl.umap(rna_d, color=markers, frameon=False, title=[f"{g} (denoised)" for g in markers])
669
+ ```
670
+
671
+ ### 4) Cross-modal reconstruction / imputation (built-in)
672
+
673
+ `cross_modal_predict` runs “encode src modality → decode target modality” and returns a dense numpy array.
674
+
675
+ ```python
676
+ from univi.evaluation import cross_modal_predict
677
+
678
+ # Example: RNA -> predicted ADT
679
+ adt_from_rna = cross_modal_predict(
680
+ model,
681
+ adata_src=rna,
682
+ src_mod="rna",
683
+ tgt_mod="adt",
684
+ device=device,
685
+ layer="counts",
686
+ batch_size=512,
687
+ use_moe=True, # for src-only input, MoE reduces to the src posterior
688
+ )
689
+ print(adt_from_rna.shape) # (cells, adt_features)
690
+ ```
691
+
692
+ ### 5) Direct model calls (advanced / debugging)
693
+
694
+ If you want full control (or want posterior means/variances explicitly), call the model methods directly.
695
+
696
+ ```python
697
+ import torch
698
+
699
+ model.eval()
700
+ batch = next(iter(val_loader))
701
+ x_dict = {k: v.to(device) for k, v in batch.items()}
702
+
703
+ with torch.no_grad():
704
+ # Per-modality posteriors
705
+ mu_dict, logvar_dict = model.encode_modalities(x_dict)
706
+
707
+ # Fused posterior (MoE/PoE or fused transformer, depending on config)
708
+ mu_z, logvar_z, z = model.encode_fused(x_dict, use_mean=True)
709
+
710
+ # Decode all modalities from a chosen latent (implementation-dependent keys)
711
+ xhat_dict = model.decode_modalities(mu_z)
712
+ ```
713
+
714
+ ---
715
+
716
+ ## CLI training (from JSON configs)
717
+
718
+ Most `scripts/*.py` entry points accept a parameter JSON.
719
+
720
+ **Train:**
721
+
722
+ ```bash
723
+ python scripts/train_univi.py \
724
+ --config parameter_files/defaults_cite_seq_scaled_gaussian_v1.json \
725
+ --outdir saved_models/citeseq_v1_run1 \
726
+ --data-root /path/to/your/data
727
+ ```
728
+
729
+ **Evaluate:**
730
+
731
+ ```bash
732
+ python scripts/evaluate_univi.py \
733
+ --config parameter_files/defaults_cite_seq_scaled_gaussian_v1.json \
734
+ --model-checkpoint saved_models/citeseq_v1_run1/checkpoints/univi_checkpoint.pt \
735
+ --outdir saved_models/citeseq_v1_run1/eval
736
+ ```
737
+
738
+ ---
739
+
740
+ ## Optional: Transformer encoders (per-modality)
741
+
742
+ By default, UniVI uses **MLP encoders** (`encoder_type="mlp"`), and classic workflows work unchanged.
743
+
744
+ If you want a transformer encoder for a modality, set:
745
+
746
+ * `encoder_type="transformer"`
747
+ * a `TokenizerConfig` (how `(B,F)` becomes `(B,T,D_in)`)
748
+ * a `TransformerConfig` (depth/width/pooling)
749
+
750
+ Example:
751
+
752
+ ```python
753
+ from univi.config import TransformerConfig, TokenizerConfig
754
+
755
+ univi_cfg = UniVIConfig(
756
+ latent_dim=40,
757
+ beta=1.0,
758
+ gamma=1.25,
759
+ modalities=[
760
+ ModalityConfig(
761
+ name="rna",
762
+ input_dim=rna.n_vars,
763
+ encoder_hidden=[512, 256, 128], # ignored by transformer encoder; kept for compatibility
764
+ decoder_hidden=[128, 256, 512],
765
+ likelihood="gaussian",
766
+ encoder_type="transformer",
767
+ tokenizer=TokenizerConfig(mode="topk_channels", n_tokens=512, channels=("value","rank","dropout")),
768
+ transformer=TransformerConfig(
769
+ d_model=256, num_heads=8, num_layers=4,
770
+ dim_feedforward=1024, dropout=0.1, attn_dropout=0.1,
771
+ activation="gelu", pooling="mean",
772
+ ),
773
+ ),
774
+ ModalityConfig(
775
+ name="adt",
776
+ input_dim=adt.n_vars,
777
+ encoder_hidden=[128, 64],
778
+ decoder_hidden=[64, 128],
779
+ likelihood="gaussian",
780
+ encoder_type="mlp",
781
+ tokenizer=TokenizerConfig(mode="topk_scalar", n_tokens=min(32, adt.n_vars)), # useful for fused encoder
782
+ ),
783
+ ],
784
+ )
785
+ ```
786
+
787
+ Notes:
788
+
789
+ * Tokenizers focus attention on the most informative features per cell (top-k) or local structure (patching).
790
+ * Transformer encoders expose optional interpretability hooks (token indices and, when enabled, attention maps).
791
+
792
+ ---
793
+
794
+ ## Optional: ATAC coordinate embeddings and distance attention bias (advanced)
795
+
796
+ For top-k tokenizers, UniVI can optionally incorporate genomic context:
797
+
798
+ * **Coordinate embeddings**: chromosome embedding + coordinate MLP per selected feature
799
+ * **Distance-based attention bias**: encourages attention between nearby peaks (same chromosome)
800
+
801
+ ### Enable in the tokenizer config (ATAC example)
802
+
803
+ ```python
804
+ TokenizerConfig(
805
+ mode="topk_channels",
806
+ n_tokens=512,
807
+ channels=("value","rank","dropout"),
808
+ use_coord_embedding=True,
809
+ n_chroms=<num_chromosomes>,
810
+ coord_scale=1e-6,
811
+ )
812
+ ```
813
+
814
+ ### Attach coordinates and configure distance bias via `UniVITrainer`
815
+
816
+ If your `UniVITrainer` supports `feature_coords` and `attn_bias_cfg`, you can attach genomic coordinates once and let the trainer build the bias for you:
817
+
818
+ ```python
819
+ feature_coords = {
820
+ "atac": {
821
+ "chrom_ids": chrom_ids_long, # (F,)
822
+ "start": start_bp, # (F,)
823
+ "end": end_bp, # (F,)
824
+ }
825
+ }
826
+
827
+ attn_bias_cfg = {
828
+ "atac": {
829
+ "type": "distance",
830
+ "lengthscale_bp": 50_000.0,
831
+ "same_chrom_only": True,
832
+ }
833
+ }
834
+
835
+ trainer = UniVITrainer(
836
+ model,
837
+ train_loader,
838
+ val_loader=val_loader,
839
+ train_cfg=TrainingConfig(...),
840
+ device="cuda",
841
+ feature_coords=feature_coords,
842
+ attn_bias_cfg=attn_bias_cfg,
843
+ )
844
+ trainer.fit()
845
+ ```
846
+
847
+ This path keeps the model code clean and makes the feature-coordinate plumbing consistent across runs.
848
+
849
+ ---
850
+
851
+ ## Optional: Fused multimodal transformer encoder (advanced)
852
+
853
+ A single transformer sees **concatenated tokens from multiple modalities** and returns a **single fused posterior** `q(z|all modalities)` using global CLS pooling (or mean pooling).
854
+
855
+ ### Minimal config
856
+
857
+ ```python
858
+ from univi.config import TransformerConfig
859
+
860
+ univi_cfg = UniVIConfig(
861
+ latent_dim=40,
862
+ beta=1.0,
863
+ gamma=1.25,
864
+ modalities=[...], # your per-modality configs still exist
865
+ fused_encoder_type="multimodal_transformer",
866
+ fused_modalities=("rna", "adt", "atac"), # default: all modalities
867
+ fused_transformer=TransformerConfig(
868
+ d_model=256, num_heads=8, num_layers=4,
869
+ dim_feedforward=1024, dropout=0.1, attn_dropout=0.1,
870
+ activation="gelu", pooling="cls",
871
+ ),
872
+ )
873
+ ```
874
+
875
+ Notes:
876
+
877
+ * Every modality in `fused_modalities` must define a `tokenizer` (even if its per-modality encoder is MLP).
878
+ * If `fused_require_all_modalities=True` and a fused modality is missing at inference, UniVI falls back to MoE/PoE fusion.
879
+
880
+ ---
881
+
882
+ ## Hyperparameter optimization (optional)
883
+
884
+ ```python
885
+ from univi.hyperparam_optimization import (
886
+ run_multiome_hparam_search,
887
+ run_citeseq_hparam_search,
888
+ run_teaseq_hparam_search,
889
+ run_rna_hparam_search,
890
+ run_atac_hparam_search,
891
+ run_adt_hparam_search,
892
+ )
893
+ ```
894
+
895
+ See `univi/hyperparam_optimization/` and `notebooks/` for examples.
896
+
897
+ ---
898
+
899
+ ## Contact, questions, and bug reports
900
+
901
+ * **Questions / comments:** open a GitHub Issue with the `question` label (or a Discussion if enabled).
902
+ * **Bug reports:** open a GitHub Issue and include:
903
+
904
+ * your UniVI version: `python -c "import univi; print(univi.__version__)"`
905
+ * minimal code to reproduce (or a short notebook snippet)
906
+ * stack trace + OS/CUDA/PyTorch versions
907
+ * **Feature requests:** open an Issue describing the use-case + expected inputs/outputs (a tiny example is ideal).
908
+