univi 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
univi/__init__.py ADDED
@@ -0,0 +1,120 @@
+ # univi/__init__.py
+
+ from __future__ import annotations
+
+ from typing import Any, List
+
+ __version__ = "0.3.4"
+
+ # Eager (fast/light) public API
+ from .config import ModalityConfig, UniVIConfig, TrainingConfig
+ from .models import UniVIMultiModalVAE
+ from . import matching
+
+ __all__ = [
+     "__version__",
+     # configs
+     "ModalityConfig",
+     "UniVIConfig",
+     "TrainingConfig",
+     # model
+     "UniVIMultiModalVAE",
+     # lightweight module
+     "matching",
+     # model state
+     "save_checkpoint",
+     "load_checkpoint",
+     "restore_checkpoint",
+     # lazy exports
+     "UniVITrainer",
+     "write_univi_latent",
+     "MultiModalDataset",
+     "pipeline",
+     "diagnostics",
+     # modules
+     "evaluation",
+     "plotting",
+     # eval convenience (optional)
+     "encode_adata",
+     "evaluate_alignment",
+     # interpretability
+     "interpretability",
+     "fused_encode_with_meta_and_attn",
+     "feature_importance_for_head",
+     "top_cross_modal_feature_pairs_from_attn",
+ ]
+
+
+ def __getattr__(name: str) -> Any:
+     """
+     Lazy exports keep `import univi` fast/light and avoid heavy deps unless needed.
+     """
+     # ---- training ----
+     if name == "UniVITrainer":
+         from .trainer import UniVITrainer
+         return UniVITrainer
+
+     # ---- IO ----
+     if name == "write_univi_latent":
+         from .utils.io import write_univi_latent
+         return write_univi_latent
+
+     # ---- data ----
+     if name == "MultiModalDataset":
+         from .data import MultiModalDataset
+         return MultiModalDataset
+
+     # ---- model state ----
+     if name in {"save_checkpoint", "load_checkpoint", "restore_checkpoint"}:
+         from .utils.io import save_checkpoint, load_checkpoint, restore_checkpoint
+         return {"save_checkpoint": save_checkpoint, "load_checkpoint": load_checkpoint, "restore_checkpoint": restore_checkpoint}[name]
+
+     # ---- modules (return module objects) ----
+     if name == "pipeline":
+         from . import pipeline as _pipeline
+         return _pipeline
+
+     if name == "diagnostics":
+         from . import diagnostics as _diagnostics
+         return _diagnostics
+
+     if name == "evaluation":
+         from . import evaluation as _evaluation
+         return _evaluation
+
+     if name == "plotting":
+         from . import plotting as _plotting
+         return _plotting
+
+     # ---- interpretability ----
+     if name == "interpretability":
+         from . import interpretability as _interpretability
+         return _interpretability
+
+     if name in {
+         "fused_encode_with_meta_and_attn",
+         "feature_importance_for_head",
+         "top_cross_modal_feature_pairs_from_attn",
+     }:
+         from .interpretability import (
+             fused_encode_with_meta_and_attn,
+             feature_importance_for_head,
+             top_cross_modal_feature_pairs_from_attn,
+         )
+         return {
+             "fused_encode_with_meta_and_attn": fused_encode_with_meta_and_attn,
+             "feature_importance_for_head": feature_importance_for_head,
+             "top_cross_modal_feature_pairs_from_attn": top_cross_modal_feature_pairs_from_attn,
+         }[name]
+
+     # ---- eval convenience functions (re-export) ----
+     if name in {"encode_adata", "evaluate_alignment"}:
+         from .evaluation import encode_adata, evaluate_alignment
+         return {"encode_adata": encode_adata, "evaluate_alignment": evaluate_alignment}[name]
+
+     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+ def __dir__() -> List[str]:
+     return sorted(list(globals().keys()) + __all__)
+
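The `__getattr__`/`__dir__` pair above is the standard PEP 562 pattern for module-level lazy attributes: only the names bound eagerly at the top are imported on `import univi`; everything else in `__all__` is resolved on first access. A minimal sketch of what that means for a caller (attribute names are taken from `__all__` above):

import univi

print(univi.__version__)          # eager: bound at import time
Trainer = univi.UniVITrainer      # lazy: first access runs `from .trainer import UniVITrainer`
print("UniVITrainer" in dir(univi))  # True: __dir__ merges globals() with __all__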
univi/__main__.py ADDED
@@ -0,0 +1,5 @@
+ # univi/__main__.py
+ from .cli import main
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
univi/cli.py ADDED
@@ -0,0 +1,60 @@
+ # univi/cli.py
+ from __future__ import annotations
+
+ import argparse
+ import os
+ import sys
+
+ from .pipeline import load_model_and_data, encode_latents_paired
+ from .diagnostics import export_supplemental_table_s1
+
+
+ def main(argv=None):
+     ap = argparse.ArgumentParser(prog="univi")
+     sub = ap.add_subparsers(dest="cmd", required=True)
+
+     ap_s1 = sub.add_parser("export-s1", help="Export Supplemental_Table_S1.xlsx (env + hparams + dataset stats).")
+     ap_s1.add_argument("--config", required=True)
+     ap_s1.add_argument("--checkpoint", default=None)
+     ap_s1.add_argument("--data-root", default=None)
+     ap_s1.add_argument("--out", required=True)
+
+     ap_encode = sub.add_parser("encode", help="Encode paired latents and save as .npz")
+     ap_encode.add_argument("--config", required=True)
+     ap_encode.add_argument("--checkpoint", required=True)
+     ap_encode.add_argument("--data-root", default=None)
+     ap_encode.add_argument("--out", required=True)
+     ap_encode.add_argument("--device", default="cpu")
+     ap_encode.add_argument("--batch-size", type=int, default=512)
+
+     args = ap.parse_args(argv)
+
+     if args.cmd == "export-s1":
+         cfg, adata_dict, model, layer_by, xkey_by = load_model_and_data(
+             args.config, checkpoint_path=args.checkpoint, data_root=args.data_root, device="cpu"
+         )
+         export_supplemental_table_s1(
+             args.config,
+             adata_dict,
+             out_xlsx=args.out,
+             layer_by=layer_by,
+             xkey_by=xkey_by,
+             extra_metrics=None,
+         )
+         return 0
+
+     if args.cmd == "encode":
+         cfg, adata_dict, model, layer_by, xkey_by = load_model_and_data(
+             args.config, checkpoint_path=args.checkpoint, data_root=args.data_root, device=args.device
+         )
+         Z = encode_latents_paired(model, adata_dict, layer_by=layer_by, xkey_by=xkey_by, batch_size=args.batch_size, device=args.device, fused=True)
+         os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
+         import numpy as np
+         np.savez_compressed(args.out, **Z)
+         return 0
+
+     return 1
+
+
+ if __name__ == "__main__":
+     raise SystemExit(main())
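Because `main()` accepts an explicit `argv` list and `__main__.py` delegates to it, the subcommands can be driven either from a shell via `python -m univi ...` or programmatically. A small sketch (the config/checkpoint/output paths are placeholders, not files that ship with the package):

from univi.cli import main

# Equivalent to: python -m univi encode --config cfg.json --checkpoint model.pt --out latents.npz
rc = main([
    "encode",
    "--config", "cfg.json",        # placeholder path
    "--checkpoint", "model.pt",    # placeholder path
    "--out", "latents.npz",
    "--device", "cpu",
    "--batch-size", "256",
])  # returns 0 on success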
univi/config.py ADDED
@@ -0,0 +1,340 @@
+ # univi/config.py
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Literal, Sequence, Any
+
+
+ # =============================================================================
+ # Transformer + tokenizer config
+ # =============================================================================
+
+ @dataclass
+ class TransformerConfig:
+     """
+     Configuration for transformer encoder backends.
+
+     Notes
+     -----
+     - Mirrors fields expected by univi/models/transformer.py:TransformerConfig.
+     - max_tokens is only needed if you enable learned positional embeddings.
+     - Relative positional bias is optional and intended mainly for ATAC peaks
+       when you provide token_pos (basepair midpoints) at runtime.
+     """
+     d_model: int
+     num_heads: int
+     num_layers: int
+     dim_feedforward: int = 4096
+     dropout: float = 0.1
+     attn_dropout: float = 0.1
+     activation: Literal["relu", "gelu"] = "gelu"
+     pooling: Literal["cls", "mean"] = "mean"
+     max_tokens: Optional[int] = None
+
+     # Optional: binned relative-position attention bias (e.g., genomic distance)
+     use_relpos_bias: bool = False
+     relpos_num_bins: int = 32
+     relpos_max_dist: float = 1e6  # basepairs
+
+
+ @dataclass
+ class TokenizerConfig:
+     """
+     Turns (B, F) into (B, T, D_in) + optional key_padding_mask.
+
+     Modes
+     -----
+     - "topk_scalar": top-k features per cell, scalar value only -> (B, K, 1)
+     - "topk_channels": top-k features per cell, multiple channels -> (B, K, C)
+       channels from: "value", "rank", "dropout"
+     - "patch": split features into contiguous patches -> (B, T, patch_size)
+       OR project each patch -> (B, T, patch_proj_dim)
+     - "topk_embed": top-k features per cell with explicit feature identity:
+       token = Emb(feature_id) + MLP(channels)
+       -> (B, K, d_model)
+
+     Optionally add ATAC coordinate embeddings:
+       token += Emb(chrom_id) + MLP(midpoint_bp / coord_scale)
+
+     Notes
+     -----
+     - topk_embed is the recommended way to use attention over sparse omics features
+       without losing feature identity.
+     - If you need relative bias, you should pass token_pos (bp midpoints) into the
+       transformer at runtime. The tokenizer will stash it in tokenizer.last_meta["token_pos"].
+     """
+     mode: Literal["topk_scalar", "topk_channels", "patch", "topk_embed"] = "topk_scalar"
+
+     # top-k settings
+     n_tokens: int = 256
+     channels: Sequence[Literal["value", "rank", "dropout"]] = ("value",)
+
+     # patch settings
+     patch_size: int = 32
+     patch_proj_dim: Optional[int] = None
+
+     # general
+     add_cls_token: bool = False
+
+     # ---- topk_embed settings ----
+     # required for topk_embed
+     n_features: Optional[int] = None
+     d_model: Optional[int] = None
+     value_mlp_hidden: int = 256
+
+     # optional coord embeddings (mainly ATAC)
+     use_coords: bool = False
+     chrom_vocab_size: int = 0
+     coord_scale: float = 1e6  # divide bp midpoints by this before coord MLP
+
+     # Optional per-feature metadata for coords (set at runtime; not great for JSON)
+     # Expected keys: {"chrom": ..., "start": ..., "end": ...}
+     # Values can be lists/arrays/torch tensors; tokenizer will convert to tensors.
+     feature_info: Optional[Dict[str, Any]] = None
+
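As a concrete illustration, a hedged sketch of a `topk_embed` tokenizer for ATAC peaks with coordinate embeddings. The field values are illustrative, `feature_info` would normally be injected at runtime from the peak annotations, and the requirement that `d_model` here match the transformer's `d_model` is an assumption suggested by the (B, K, d_model) output shape:

from univi.config import TokenizerConfig

# Illustrative values only; n_features should match the modality's input_dim.
atac_tok = TokenizerConfig(
    mode="topk_embed",
    n_tokens=512,             # keep the 512 most accessible peaks per cell
    channels=("value", "rank"),
    n_features=120_000,       # hypothetical peak count
    d_model=128,              # presumably matches TransformerConfig.d_model
    use_coords=True,
    chrom_vocab_size=24,      # e.g., chr1..chr22 + chrX + chrY
    coord_scale=1e6,          # bp midpoints divided by this before the coord MLP
)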
+
+ # =============================================================================
+ # Core UniVI configs
+ # =============================================================================
+
+ @dataclass
+ class ModalityConfig:
+     """
+     Configuration for a single modality.
+
+     Notes
+     -----
+     - For categorical modalities, set:
+         likelihood="categorical"
+         input_dim = n_classes (C)
+
+       and optionally set:
+         input_kind="obs"
+         obs_key="your_obs_column"
+
+       The dataset returns a (B, 1) tensor of label codes; the model converts
+       to one-hot for encoding and to class indices for CE.
+
+     - ignore_index is used for unlabeled entries (masked in CE).
+     """
+     name: str
+     input_dim: int
+     encoder_hidden: List[int]
+     decoder_hidden: List[int]
+     likelihood: str = "gaussian"
+
+     # categorical modality support
+     ignore_index: int = -1
+     input_kind: Literal["matrix", "obs"] = "matrix"
+     obs_key: Optional[str] = None
+
+     # encoder backend (per-modality only)
+     encoder_type: Literal["mlp", "transformer"] = "mlp"
+     transformer: Optional[TransformerConfig] = None
+     tokenizer: Optional[TokenizerConfig] = None
+
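To make the wiring concrete, a sketch of an RNA modality using the transformer backend. Dimensions and hidden sizes are illustrative; note that `encoder_hidden`/`decoder_hidden` are required dataclass fields even with `encoder_type="transformer"` (whether only the decoder MLP consumes them in that case is an assumption, not something this file states):

from univi.config import ModalityConfig, TransformerConfig, TokenizerConfig

rna = ModalityConfig(
    name="rna",
    input_dim=2000,                 # e.g., 2000 highly variable genes
    encoder_hidden=[512, 256],      # required by the dataclass regardless of backend
    decoder_hidden=[256, 512],
    likelihood="gaussian",
    encoder_type="transformer",
    transformer=TransformerConfig(d_model=128, num_heads=4, num_layers=2,
                                  dim_feedforward=512, pooling="mean"),
    tokenizer=TokenizerConfig(mode="topk_embed", n_tokens=256,
                              n_features=2000, d_model=128),
)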
+
+ @dataclass
+ class ClassHeadConfig:
+     """
+     Configuration for an auxiliary supervised classification head p(y_h | z).
+
+     Notes
+     -----
+     - from_mu=True: classify from mu_z (more stable), else from sampled z.
+     - warmup: the epoch before which this head's loss is disabled.
+     - adversarial=True: gradient reversal head (domain/tech confusion).
+     """
+     name: str
+     n_classes: int
+     loss_weight: float = 1.0
+     ignore_index: int = -1
+     from_mu: bool = True
+     warmup: int = 0
+
+     adversarial: bool = False
+     adv_lambda: float = 1.0
+
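For instance, a supervised cell-type head plus an adversarial batch head might look like the following sketch (head names, class counts, and weights are illustrative; "label" matches the `label_head_name` default below):

from univi.config import ClassHeadConfig

heads = [
    # Predict cell type from mu_z; loss enabled starting at epoch 10.
    ClassHeadConfig(name="label", n_classes=12, loss_weight=1.0,
                    from_mu=True, warmup=10),
    # Gradient-reversal head that pushes the latent to be batch-invariant.
    ClassHeadConfig(name="batch", n_classes=4, loss_weight=0.5,
                    adversarial=True, adv_lambda=1.0),
]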
+
+ @dataclass
+ class UniVIConfig:
+     latent_dim: int
+     modalities: List[ModalityConfig]
+
+     beta: float = 1.0
+     gamma: float = 1.0
+
+     encoder_dropout: float = 0.0
+     decoder_dropout: float = 0.0
+     encoder_batchnorm: bool = True
+     decoder_batchnorm: bool = False
+
+     kl_anneal_start: int = 0
+     kl_anneal_end: int = 0
+     align_anneal_start: int = 0
+     align_anneal_end: int = 0
+
+     class_heads: Optional[List[ClassHeadConfig]] = None
+     label_head_name: str = "label"
+
+     # ---------------------------------------------------------------------
+     # Optional fused multimodal encoder over concatenated tokens
+     # ---------------------------------------------------------------------
+     fused_encoder_type: Literal["moe", "multimodal_transformer"] = "moe"
+     fused_transformer: Optional[TransformerConfig] = None
+     fused_modalities: Optional[Sequence[str]] = None  # default: all modalities
+     fused_add_modality_embeddings: bool = True
+     fused_require_all_modalities: bool = True  # if True: fall back to MoE when missing
+
+     def validate(self) -> None:
+         if int(self.latent_dim) <= 0:
+             raise ValueError(f"latent_dim must be > 0, got {self.latent_dim}")
+
+         # modality name sanity
+         names = [m.name for m in self.modalities]
+         if len(set(names)) != len(names):
+             dupes = sorted({n for n in names if names.count(n) > 1})
+             raise ValueError(f"Duplicate modality names in cfg.modalities: {dupes}")
+
+         mod_by_name: Dict[str, ModalityConfig] = {m.name: m for m in self.modalities}
+
+         for m in self.modalities:
+             if int(m.input_dim) <= 0:
+                 raise ValueError(f"Modality {m.name!r}: input_dim must be > 0, got {m.input_dim}")
+
+             lk = (m.likelihood or "").lower().strip()
+             if lk in ("categorical", "cat", "ce", "cross_entropy", "multinomial", "softmax"):
+                 if int(m.input_dim) < 2:
+                     raise ValueError(f"Categorical modality {m.name!r}: input_dim must be n_classes >= 2.")
+                 if m.input_kind == "obs" and not m.obs_key:
+                     raise ValueError(f"Categorical modality {m.name!r}: input_kind='obs' requires obs_key.")
+
+             enc_type = (m.encoder_type or "mlp").lower().strip()
+             if enc_type not in ("mlp", "transformer"):
+                 raise ValueError(
+                     f"Modality {m.name!r}: encoder_type must be 'mlp' or 'transformer', got {m.encoder_type!r}"
+                 )
+
+             if enc_type == "transformer":
+                 if m.transformer is None:
+                     raise ValueError(f"Modality {m.name!r}: encoder_type='transformer' requires transformer config.")
+                 if m.tokenizer is None:
+                     raise ValueError(f"Modality {m.name!r}: encoder_type='transformer' requires tokenizer config.")
+                 _validate_tokenizer(m.name, m.tokenizer)
+
+         # fused encoder sanity
+         fe = (self.fused_encoder_type or "moe").lower().strip()
+         if fe not in ("moe", "multimodal_transformer"):
+             raise ValueError(
+                 f"fused_encoder_type must be 'moe' or 'multimodal_transformer', got {self.fused_encoder_type!r}"
+             )
+
+         if fe == "multimodal_transformer":
+             if self.fused_transformer is None:
+                 raise ValueError("fused_encoder_type='multimodal_transformer' requires UniVIConfig.fused_transformer.")
+
+             fused_names = list(self.fused_modalities) if self.fused_modalities is not None else list(mod_by_name.keys())
+             if not fused_names:
+                 raise ValueError("fused_modalities is empty; expected at least one modality name.")
+
+             missing = [n for n in fused_names if n not in mod_by_name]
+             if missing:
+                 raise ValueError(f"fused_modalities contains unknown modalities: {missing}. Known: {list(mod_by_name)}")
+
+             for n in fused_names:
+                 tok = mod_by_name[n].tokenizer
+                 if tok is None:
+                     raise ValueError(
+                         f"Fused multimodal transformer requires ModalityConfig.tokenizer for modality {n!r}."
+                     )
+                 _validate_tokenizer(n, tok)
+
+         # class head sanity
+         if self.class_heads is not None:
+             hn = [h.name for h in self.class_heads]
+             if len(set(hn)) != len(hn):
+                 dupes = sorted({n for n in hn if hn.count(n) > 1})
+                 raise ValueError(f"Duplicate class head names in cfg.class_heads: {dupes}")
+             for h in self.class_heads:
+                 if int(h.n_classes) < 2:
+                     raise ValueError(f"Class head {h.name!r}: n_classes must be >= 2.")
+                 if float(h.loss_weight) < 0:
+                     raise ValueError(f"Class head {h.name!r}: loss_weight must be >= 0.")
+                 if int(h.warmup) < 0:
+                     raise ValueError(f"Class head {h.name!r}: warmup must be >= 0.")
+                 if float(getattr(h, "adv_lambda", 1.0)) < 0.0:
+                     raise ValueError(f"Class head {h.name!r}: adv_lambda must be >= 0.")
+
+         # anneal sanity
+         for k in ("kl_anneal_start", "kl_anneal_end", "align_anneal_start", "align_anneal_end"):
+             v = int(getattr(self, k))
+             if v < 0:
+                 raise ValueError(f"{k} must be >= 0, got {v}")
+
+
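Putting the pieces together, a minimal two-modality configuration that passes `validate()` as defined above (dimensions and hidden sizes are illustrative):

from univi.config import ModalityConfig, UniVIConfig

cfg = UniVIConfig(
    latent_dim=32,
    modalities=[
        ModalityConfig(name="rna", input_dim=2000,
                       encoder_hidden=[512, 256], decoder_hidden=[256, 512],
                       likelihood="gaussian"),
        ModalityConfig(name="adt", input_dim=120,
                       encoder_hidden=[128], decoder_hidden=[128],
                       likelihood="gaussian"),
    ],
    beta=1.0,
    gamma=1.0,
    kl_anneal_start=0,
    kl_anneal_end=20,   # presumably annealed over epochs 0..20 (schedule not shown here)
)
cfg.validate()  # raises ValueError on duplicate names, bad dims, missing tokenizer configs, etc.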
+ def _validate_tokenizer(mod_name: str, tok: TokenizerConfig) -> None:
+     mode = (tok.mode or "").lower().strip()
+     if mode not in ("topk_scalar", "topk_channels", "patch", "topk_embed"):
+         raise ValueError(
+             f"Modality {mod_name!r}: tokenizer.mode must be one of "
+             f"['topk_scalar','topk_channels','patch','topk_embed'], got {tok.mode!r}"
+         )
+
+     if mode in ("topk_scalar", "topk_channels", "topk_embed"):
+         if int(tok.n_tokens) <= 0:
+             raise ValueError(f"Modality {mod_name!r}: tokenizer.n_tokens must be > 0 for topk_*")
+
+     if mode == "topk_channels":
+         if not tok.channels:
+             raise ValueError(f"Modality {mod_name!r}: tokenizer.channels must be non-empty for topk_channels")
+         bad = [c for c in tok.channels if c not in ("value", "rank", "dropout")]
+         if bad:
+             raise ValueError(f"Modality {mod_name!r}: tokenizer.channels has invalid entries: {bad}")
+
+     if mode == "topk_embed":
+         if tok.n_features is None or int(tok.n_features) <= 0:
+             raise ValueError(f"Modality {mod_name!r}: tokenizer.n_features must be set (>0) for topk_embed")
+         if tok.d_model is None or int(tok.d_model) <= 0:
+             raise ValueError(f"Modality {mod_name!r}: tokenizer.d_model must be set (>0) for topk_embed")
+         if not tok.channels:
+             raise ValueError(f"Modality {mod_name!r}: tokenizer.channels must be non-empty for topk_embed")
+         bad = [c for c in tok.channels if c not in ("value", "rank", "dropout")]
+         if bad:
+             raise ValueError(f"Modality {mod_name!r}: tokenizer.channels has invalid entries: {bad}")
+
+         if tok.use_coords:
+             if int(tok.chrom_vocab_size) <= 0:
+                 raise ValueError(
+                     f"Modality {mod_name!r}: tokenizer.chrom_vocab_size must be > 0 when use_coords=True"
+                 )
+             # feature_info may be injected at runtime; validate if present
+             if tok.feature_info is not None:
+                 for k in ("chrom", "start", "end"):
+                     if k not in tok.feature_info:
+                         raise ValueError(
+                             f"Modality {mod_name!r}: tokenizer.feature_info missing key {k!r} (required for coords)"
+                         )
+
+     if mode == "patch":
+         if int(tok.patch_size) <= 0:
+             raise ValueError(f"Modality {mod_name!r}: tokenizer.patch_size must be > 0 for patch")
+         if tok.patch_proj_dim is not None and int(tok.patch_proj_dim) <= 0:
+             raise ValueError(f"Modality {mod_name!r}: tokenizer.patch_proj_dim must be > 0 if set")
+
+
+ @dataclass
+ class TrainingConfig:
+     n_epochs: int = 200
+     batch_size: int = 256
+     lr: float = 1e-3
+     weight_decay: float = 0.0
+     device: str = "cpu"
+     log_every: int = 10
+     grad_clip: Optional[float] = None
+     num_workers: int = 0
+     seed: int = 0
+
+     early_stopping: bool = False
+     patience: int = 20
+     min_delta: float = 0.0
+
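Finally, a sketch of overriding the training defaults. All fields shown exist in the dataclass above; whether `UniVITrainer` consumes this object directly is an assumption based on the exports in `univi/__init__.py`, not something this file specifies:

from univi.config import TrainingConfig

train_cfg = TrainingConfig(
    n_epochs=100,
    batch_size=512,
    lr=5e-4,
    device="cuda",
    grad_clip=5.0,         # None disables gradient clipping
    early_stopping=True,   # stop after `patience` epochs without an
    patience=15,           # improvement greater than `min_delta`
    seed=42,
)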