univi-0.3.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/__init__.py
ADDED
@@ -0,0 +1,120 @@
+# univi/__init__.py
+
+from __future__ import annotations
+
+from typing import Any, List
+
+__version__ = "0.3.4"
+
+# Eager (fast/light) public API
+from .config import ModalityConfig, UniVIConfig, TrainingConfig
+from .models import UniVIMultiModalVAE
+from . import matching
+
+__all__ = [
+    "__version__",
+    # configs
+    "ModalityConfig",
+    "UniVIConfig",
+    "TrainingConfig",
+    # model
+    "UniVIMultiModalVAE",
+    # lightweight module
+    "matching",
+    # model state
+    "save_checkpoint",
+    "load_checkpoint",
+    "restore_checkpoint",
+    # lazy exports
+    "UniVITrainer",
+    "write_univi_latent",
+    "MultiModalDataset",
+    "pipeline",
+    "diagnostics",
+    # modules
+    "evaluation",
+    "plotting",
+    # eval convenience (optional)
+    "encode_adata",
+    "evaluate_alignment",
+    # interpretability
+    "interpretability",
+    "fused_encode_with_meta_and_attn",
+    "feature_importance_for_head",
+    "top_cross_modal_feature_pairs_from_attn",
+]
+
+
+def __getattr__(name: str) -> Any:
+    """
+    Lazy exports keep `import univi` fast/light and avoid heavy deps unless needed.
+    """
+    # ---- training ----
+    if name == "UniVITrainer":
+        from .trainer import UniVITrainer
+        return UniVITrainer
+
+    # ---- IO ----
+    if name == "write_univi_latent":
+        from .utils.io import write_univi_latent
+        return write_univi_latent
+
+    # ---- data ----
+    if name == "MultiModalDataset":
+        from .data import MultiModalDataset
+        return MultiModalDataset
+
+    # ---- model state ----
+    if name in {"save_checkpoint", "load_checkpoint", "restore_checkpoint"}:
+        from .utils.io import save_checkpoint, load_checkpoint, restore_checkpoint
+        return {"save_checkpoint": save_checkpoint, "load_checkpoint": load_checkpoint, "restore_checkpoint": restore_checkpoint}[name]
+
+    # ---- modules (return module objects) ----
+    if name == "pipeline":
+        from . import pipeline as _pipeline
+        return _pipeline
+
+    if name == "diagnostics":
+        from . import diagnostics as _diagnostics
+        return _diagnostics
+
+    if name == "evaluation":
+        from . import evaluation as _evaluation
+        return _evaluation
+
+    if name == "plotting":
+        from . import plotting as _plotting
+        return _plotting
+
+    # ---- interpretability ----
+    if name == "interpretability":
+        from . import interpretability as _interpretability
+        return _interpretability
+
+    if name in {
+        "fused_encode_with_meta_and_attn",
+        "feature_importance_for_head",
+        "top_cross_modal_feature_pairs_from_attn",
+    }:
+        from .interpretability import (
+            fused_encode_with_meta_and_attn,
+            feature_importance_for_head,
+            top_cross_modal_feature_pairs_from_attn,
+        )
+        return {
+            "fused_encode_with_meta_and_attn": fused_encode_with_meta_and_attn,
+            "feature_importance_for_head": feature_importance_for_head,
+            "top_cross_modal_feature_pairs_from_attn": top_cross_modal_feature_pairs_from_attn,
+        }[name]
+
+    # ---- eval convenience functions (re-export) ----
+    if name in {"encode_adata", "evaluate_alignment"}:
+        from .evaluation import encode_adata, evaluate_alignment
+        return {"encode_adata": encode_adata, "evaluate_alignment": evaluate_alignment}[name]
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+
+def __dir__() -> List[str]:
+    return sorted(list(globals().keys()) + __all__)
+
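Since `univi/__init__.py` uses PEP 562 module-level `__getattr__` for lazy exports, a short sketch of the resulting import behavior may help (assuming the package and its training dependencies are installed; `not_a_real_export` is a deliberately bogus name):

import univi

print(univi.__version__)          # "0.3.4" -- bound eagerly at import time
cfg_cls = univi.UniVIConfig       # eager: already imported from univi.config

trainer_cls = univi.UniVITrainer  # lazy: first access runs `from .trainer import UniVITrainer`
ds_cls = univi.MultiModalDataset  # lazy: first access runs `from .data import MultiModalDataset`

# Names outside __all__ and the __getattr__ branches raise AttributeError as usual:
try:
    univi.not_a_real_export
except AttributeError as exc:
    print(exc)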
univi/__main__.py
ADDED
univi/cli.py
ADDED
@@ -0,0 +1,60 @@
+# univi/cli.py
+from __future__ import annotations
+
+import argparse
+import os
+import sys
+
+from .pipeline import load_model_and_data, encode_latents_paired
+from .diagnostics import export_supplemental_table_s1
+
+
+def main(argv=None):
+    ap = argparse.ArgumentParser(prog="univi")
+    sub = ap.add_subparsers(dest="cmd", required=True)
+
+    ap_s1 = sub.add_parser("export-s1", help="Export Supplemental_Table_S1.xlsx (env + hparams + dataset stats).")
+    ap_s1.add_argument("--config", required=True)
+    ap_s1.add_argument("--checkpoint", default=None)
+    ap_s1.add_argument("--data-root", default=None)
+    ap_s1.add_argument("--out", required=True)
+
+    ap_encode = sub.add_parser("encode", help="Encode paired latents and save as .npz")
+    ap_encode.add_argument("--config", required=True)
+    ap_encode.add_argument("--checkpoint", required=True)
+    ap_encode.add_argument("--data-root", default=None)
+    ap_encode.add_argument("--out", required=True)
+    ap_encode.add_argument("--device", default="cpu")
+    ap_encode.add_argument("--batch-size", type=int, default=512)
+
+    args = ap.parse_args(argv)
+
+    if args.cmd == "export-s1":
+        cfg, adata_dict, model, layer_by, xkey_by = load_model_and_data(
+            args.config, checkpoint_path=args.checkpoint, data_root=args.data_root, device="cpu"
+        )
+        export_supplemental_table_s1(
+            args.config,
+            adata_dict,
+            out_xlsx=args.out,
+            layer_by=layer_by,
+            xkey_by=xkey_by,
+            extra_metrics=None,
+        )
+        return 0
+
+    if args.cmd == "encode":
+        cfg, adata_dict, model, layer_by, xkey_by = load_model_and_data(
+            args.config, checkpoint_path=args.checkpoint, data_root=args.data_root, device=args.device
+        )
+        Z = encode_latents_paired(model, adata_dict, layer_by=layer_by, xkey_by=xkey_by, batch_size=args.batch_size, device=args.device, fused=True)
+        os.makedirs(os.path.dirname(args.out) or ".", exist_ok=True)
+        import numpy as np
+        np.savez_compressed(args.out, **Z)
+        return 0
+
+    return 1
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
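For reference, a sketch of driving the CLI programmatically through `main(argv)` (the config and checkpoint paths below are placeholders, not files shipped with the package). The same flags should also work from a shell via `python -m univi ...`, assuming the five-line `univi/__main__.py` above dispatches to `cli.main`, as the `entry_points.txt` in the dist-info suggests:

from univi.cli import main

# Equivalent to: univi encode --config ... --checkpoint ... --out ...
rc = main([
    "encode",
    "--config", "configs/model.json",        # placeholder path
    "--checkpoint", "checkpoints/univi.pt",  # placeholder path
    "--out", "out/latents.npz",
    "--device", "cpu",
    "--batch-size", "256",
])
assert rc == 0  # argparse rejects unknown subcommands before the fallback `return 1`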
univi/config.py
ADDED
@@ -0,0 +1,340 @@
+# univi/config.py
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Literal, Sequence, Any
+
+
+# =============================================================================
+# Transformer + tokenizer config
+# =============================================================================
+
+@dataclass
+class TransformerConfig:
+    """
+    Configuration for transformer encoder backends.
+
+    Notes
+    -----
+    - Mirrors fields expected by univi/models/transformer.py:TransformerConfig.
+    - max_tokens is only needed if you enable learned positional embeddings.
+    - Relative positional bias is optional and intended mainly for ATAC peaks
+      when you provide token_pos (basepair midpoints) at runtime.
+    """
+    d_model: int
+    num_heads: int
+    num_layers: int
+    dim_feedforward: int = 4096
+    dropout: float = 0.1
+    attn_dropout: float = 0.1
+    activation: Literal["relu", "gelu"] = "gelu"
+    pooling: Literal["cls", "mean"] = "mean"
+    max_tokens: Optional[int] = None
+
+    # Optional: binned relative-position attention bias (e.g., genomic distance)
+    use_relpos_bias: bool = False
+    relpos_num_bins: int = 32
+    relpos_max_dist: float = 1e6  # basepairs
+
+
+@dataclass
+class TokenizerConfig:
+    """
+    Turns (B, F) into (B, T, D_in) + optional key_padding_mask.
+
+    Modes
+    -----
+    - "topk_scalar": top-k features per cell, scalar value only -> (B, K, 1)
+    - "topk_channels": top-k features per cell, multiple channels -> (B, K, C)
+      channels from: "value", "rank", "dropout"
+    - "patch": split features into contiguous patches -> (B, T, patch_size)
+      OR project each patch -> (B, T, patch_proj_dim)
+    - "topk_embed": top-k features per cell with explicit feature identity:
+          token = Emb(feature_id) + MLP(channels)
+      -> (B, K, d_model)
+
+      Optionally add ATAC coordinate embeddings:
+          token += Emb(chrom_id) + MLP(midpoint_bp / coord_scale)
+
+    Notes
+    -----
+    - topk_embed is the recommended way to use attention over sparse omics features
+      without losing feature identity.
+    - If you need relative bias, you should pass token_pos (bp midpoints) into the
+      transformer at runtime. The tokenizer will stash it in tokenizer.last_meta["token_pos"].
+    """
+    mode: Literal["topk_scalar", "topk_channels", "patch", "topk_embed"] = "topk_scalar"
+
+    # top-k settings
+    n_tokens: int = 256
+    channels: Sequence[Literal["value", "rank", "dropout"]] = ("value",)
+
+    # patch settings
+    patch_size: int = 32
+    patch_proj_dim: Optional[int] = None
+
+    # general
+    add_cls_token: bool = False
+
+    # ---- topk_embed settings ----
+    # required for topk_embed
+    n_features: Optional[int] = None
+    d_model: Optional[int] = None
+    value_mlp_hidden: int = 256
+
+    # optional coord embeddings (mainly ATAC)
+    use_coords: bool = False
+    chrom_vocab_size: int = 0
+    coord_scale: float = 1e6  # divide bp midpoints by this before coord MLP
+
+    # Optional per-feature metadata for coords (set at runtime; not great for JSON)
+    # Expected keys: {"chrom": ..., "start": ..., "end": ...}
+    # Values can be lists/arrays/torch tensors; tokenizer will convert to tensors.
+    feature_info: Optional[Dict[str, Any]] = None
+
+
+# =============================================================================
+# Core UniVI configs
+# =============================================================================
+
+@dataclass
+class ModalityConfig:
+    """
+    Configuration for a single modality.
+
+    Notes
+    -----
+    - For categorical modalities, set:
+          likelihood="categorical"
+          input_dim = n_classes (C)
+
+      and optionally set:
+          input_kind="obs"
+          obs_key="your_obs_column"
+
+      The dataset returns a (B,1) tensor of label codes; the model converts
+      to one-hot for encoding and to class indices for CE.
+
+    - ignore_index is used for unlabeled entries (masked in CE).
+    """
+    name: str
+    input_dim: int
+    encoder_hidden: List[int]
+    decoder_hidden: List[int]
+    likelihood: str = "gaussian"
+
+    # categorical modality support
+    ignore_index: int = -1
+    input_kind: Literal["matrix", "obs"] = "matrix"
+    obs_key: Optional[str] = None
+
+    # encoder backend (per-modality only)
+    encoder_type: Literal["mlp", "transformer"] = "mlp"
+    transformer: Optional[TransformerConfig] = None
+    tokenizer: Optional[TokenizerConfig] = None
+
+
+@dataclass
+class ClassHeadConfig:
+    """
+    Configuration for an auxiliary supervised classification head p(y_h | z).
+
+    Notes
+    -----
+    - from_mu=True: classify from mu_z (more stable), else from sampled z.
+    - warmup: epoch before enabling this head's loss.
+    - adversarial=True: gradient reversal head (domain/tech confusion).
+    """
+    name: str
+    n_classes: int
+    loss_weight: float = 1.0
+    ignore_index: int = -1
+    from_mu: bool = True
+    warmup: int = 0
+
+    adversarial: bool = False
+    adv_lambda: float = 1.0
+
+
+@dataclass
+class UniVIConfig:
+    latent_dim: int
+    modalities: List[ModalityConfig]
+
+    beta: float = 1.0
+    gamma: float = 1.0
+
+    encoder_dropout: float = 0.0
+    decoder_dropout: float = 0.0
+    encoder_batchnorm: bool = True
+    decoder_batchnorm: bool = False
+
+    kl_anneal_start: int = 0
+    kl_anneal_end: int = 0
+    align_anneal_start: int = 0
+    align_anneal_end: int = 0
+
+    class_heads: Optional[List[ClassHeadConfig]] = None
+    label_head_name: str = "label"
+
+    # ---------------------------------------------------------------------
+    # Optional fused multimodal encoder over concatenated tokens
+    # ---------------------------------------------------------------------
+    fused_encoder_type: Literal["moe", "multimodal_transformer"] = "moe"
+    fused_transformer: Optional[TransformerConfig] = None
+    fused_modalities: Optional[Sequence[str]] = None  # default: all modalities
+    fused_add_modality_embeddings: bool = True
+    fused_require_all_modalities: bool = True  # if True: fall back to MoE when missing
+
+    def validate(self) -> None:
+        if int(self.latent_dim) <= 0:
+            raise ValueError(f"latent_dim must be > 0, got {self.latent_dim}")
+
+        # modality name sanity
+        names = [m.name for m in self.modalities]
+        if len(set(names)) != len(names):
+            dupes = sorted({n for n in names if names.count(n) > 1})
+            raise ValueError(f"Duplicate modality names in cfg.modalities: {dupes}")
+
+        mod_by_name: Dict[str, ModalityConfig] = {m.name: m for m in self.modalities}
+
+        for m in self.modalities:
+            if int(m.input_dim) <= 0:
+                raise ValueError(f"Modality {m.name!r}: input_dim must be > 0, got {m.input_dim}")
+
+            lk = (m.likelihood or "").lower().strip()
+            if lk in ("categorical", "cat", "ce", "cross_entropy", "multinomial", "softmax"):
+                if int(m.input_dim) < 2:
+                    raise ValueError(f"Categorical modality {m.name!r}: input_dim must be n_classes >= 2.")
+                if m.input_kind == "obs" and not m.obs_key:
+                    raise ValueError(f"Categorical modality {m.name!r}: input_kind='obs' requires obs_key.")
+
+            enc_type = (m.encoder_type or "mlp").lower().strip()
+            if enc_type not in ("mlp", "transformer"):
+                raise ValueError(
+                    f"Modality {m.name!r}: encoder_type must be 'mlp' or 'transformer', got {m.encoder_type!r}"
+                )
+
+            if enc_type == "transformer":
+                if m.transformer is None:
+                    raise ValueError(f"Modality {m.name!r}: encoder_type='transformer' requires transformer config.")
+                if m.tokenizer is None:
+                    raise ValueError(f"Modality {m.name!r}: encoder_type='transformer' requires tokenizer config.")
+                _validate_tokenizer(m.name, m.tokenizer)
+
+        # fused encoder sanity
+        fe = (self.fused_encoder_type or "moe").lower().strip()
+        if fe not in ("moe", "multimodal_transformer"):
+            raise ValueError(
+                f"fused_encoder_type must be 'moe' or 'multimodal_transformer', got {self.fused_encoder_type!r}"
+            )
+
+        if fe == "multimodal_transformer":
+            if self.fused_transformer is None:
+                raise ValueError("fused_encoder_type='multimodal_transformer' requires UniVIConfig.fused_transformer.")
+
+            fused_names = list(self.fused_modalities) if self.fused_modalities is not None else list(mod_by_name.keys())
+            if not fused_names:
+                raise ValueError("fused_modalities is empty; expected at least one modality name.")
+
+            missing = [n for n in fused_names if n not in mod_by_name]
+            if missing:
+                raise ValueError(f"fused_modalities contains unknown modalities: {missing}. Known: {list(mod_by_name)}")
+
+            for n in fused_names:
+                tok = mod_by_name[n].tokenizer
+                if tok is None:
+                    raise ValueError(
+                        f"Fused multimodal transformer requires ModalityConfig.tokenizer for modality {n!r}."
+                    )
+                _validate_tokenizer(n, tok)
+
+        # class head sanity
+        if self.class_heads is not None:
+            hn = [h.name for h in self.class_heads]
+            if len(set(hn)) != len(hn):
+                dupes = sorted({n for n in hn if hn.count(n) > 1})
+                raise ValueError(f"Duplicate class head names in cfg.class_heads: {dupes}")
+            for h in self.class_heads:
+                if int(h.n_classes) < 2:
+                    raise ValueError(f"Class head {h.name!r}: n_classes must be >= 2.")
+                if float(h.loss_weight) < 0:
+                    raise ValueError(f"Class head {h.name!r}: loss_weight must be >= 0.")
+                if int(h.warmup) < 0:
+                    raise ValueError(f"Class head {h.name!r}: warmup must be >= 0.")
+                if float(getattr(h, "adv_lambda", 1.0)) < 0.0:
+                    raise ValueError(f"Class head {h.name!r}: adv_lambda must be >= 0.")
+
+        # anneal sanity
+        for k in ("kl_anneal_start", "kl_anneal_end", "align_anneal_start", "align_anneal_end"):
+            v = int(getattr(self, k))
+            if v < 0:
+                raise ValueError(f"{k} must be >= 0, got {v}")
+
+
+def _validate_tokenizer(mod_name: str, tok: TokenizerConfig) -> None:
+    mode = (tok.mode or "").lower().strip()
+    if mode not in ("topk_scalar", "topk_channels", "patch", "topk_embed"):
+        raise ValueError(
+            f"Modality {mod_name!r}: tokenizer.mode must be one of "
+            f"['topk_scalar','topk_channels','patch','topk_embed'], got {tok.mode!r}"
+        )
+
+    if mode in ("topk_scalar", "topk_channels", "topk_embed"):
+        if int(tok.n_tokens) <= 0:
+            raise ValueError(f"Modality {mod_name!r}: tokenizer.n_tokens must be > 0 for topk_*")
+
+    if mode == "topk_channels":
+        if not tok.channels:
+            raise ValueError(f"Modality {mod_name!r}: tokenizer.channels must be non-empty for topk_channels")
+        bad = [c for c in tok.channels if c not in ("value", "rank", "dropout")]
+        if bad:
+            raise ValueError(f"Modality {mod_name!r}: tokenizer.channels has invalid entries: {bad}")
+
+    if mode == "topk_embed":
+        if tok.n_features is None or int(tok.n_features) <= 0:
+            raise ValueError(f"Modality {mod_name!r}: tokenizer.n_features must be set (>0) for topk_embed")
+        if tok.d_model is None or int(tok.d_model) <= 0:
+            raise ValueError(f"Modality {mod_name!r}: tokenizer.d_model must be set (>0) for topk_embed")
+        if not tok.channels:
+            raise ValueError(f"Modality {mod_name!r}: tokenizer.channels must be non-empty for topk_embed")
+        bad = [c for c in tok.channels if c not in ("value", "rank", "dropout")]
+        if bad:
+            raise ValueError(f"Modality {mod_name!r}: tokenizer.channels has invalid entries: {bad}")
+
+        if tok.use_coords:
+            if int(tok.chrom_vocab_size) <= 0:
+                raise ValueError(
+                    f"Modality {mod_name!r}: tokenizer.chrom_vocab_size must be > 0 when use_coords=True"
+                )
+            # feature_info may be injected at runtime; validate if present
+            if tok.feature_info is not None:
+                for k in ("chrom", "start", "end"):
+                    if k not in tok.feature_info:
+                        raise ValueError(
+                            f"Modality {mod_name!r}: tokenizer.feature_info missing key {k!r} (required for coords)"
+                        )
+
+    if mode == "patch":
+        if int(tok.patch_size) <= 0:
+            raise ValueError(f"Modality {mod_name!r}: tokenizer.patch_size must be > 0 for patch")
+        if tok.patch_proj_dim is not None and int(tok.patch_proj_dim) <= 0:
+            raise ValueError(f"Modality {mod_name!r}: tokenizer.patch_proj_dim must be > 0 if set")
+
+
+@dataclass
+class TrainingConfig:
+    n_epochs: int = 200
+    batch_size: int = 256
+    lr: float = 1e-3
+    weight_decay: float = 0.0
+    device: str = "cpu"
+    log_every: int = 10
+    grad_clip: Optional[float] = None
+    num_workers: int = 0
+    seed: int = 0
+
+    early_stopping: bool = False
+    patience: int = 20
+    min_delta: float = 0.0
+
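To tie the dataclasses together, a configuration sketch using only fields defined above (the modality names, dimensions, and hidden sizes are illustrative assumptions, not recommended settings):

from univi.config import (
    ModalityConfig,
    TokenizerConfig,
    TrainingConfig,
    TransformerConfig,
    UniVIConfig,
)

# RNA with the default MLP encoder and Gaussian likelihood.
rna = ModalityConfig(
    name="rna",
    input_dim=2000,
    encoder_hidden=[512, 256],
    decoder_hidden=[256, 512],
)

# ATAC with a transformer encoder; "topk_embed" keeps feature identity and
# requires n_features and d_model (enforced by _validate_tokenizer above).
atac = ModalityConfig(
    name="atac",
    input_dim=50000,
    encoder_hidden=[512, 256],
    decoder_hidden=[256, 512],
    encoder_type="transformer",
    transformer=TransformerConfig(d_model=128, num_heads=4, num_layers=2),
    tokenizer=TokenizerConfig(
        mode="topk_embed", n_tokens=512, n_features=50000, d_model=128
    ),
)

cfg = UniVIConfig(latent_dim=32, modalities=[rna, atac])
cfg.validate()  # raises ValueError on inconsistent settings; returns None otherwise

train = TrainingConfig(n_epochs=100, batch_size=256, lr=1e-3, device="cpu")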