univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/models/encoders.py
ADDED
@@ -0,0 +1,848 @@
# univi/models/encoders.py
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Dict, List, Optional, Sequence, Tuple, Union, Any, Mapping

import torch
from torch import nn
import torch.nn.functional as F

from ..config import UniVIConfig, ModalityConfig, TokenizerConfig, TransformerConfig as CFGTransformerConfig
from .mlp import build_mlp
from .transformer import TransformerEncoder, TransformerConfig as ModelTransformerConfig


# =============================================================================
# Small config helper
# =============================================================================

@dataclass
class EncoderConfig:
    input_dim: int
    hidden_dims: List[int]
    latent_dim: int
    dropout: float = 0.1
    batchnorm: bool = True


# =============================================================================
# Base encoders
# =============================================================================

class GaussianEncoder(nn.Module):
    """Base: x -> (mu, logvar) for a diagonal Gaussian."""
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError


class MLPGaussianEncoder(GaussianEncoder):
    """MLP encoder: x -> (mu, logvar) directly."""
    def __init__(self, cfg: EncoderConfig):
        super().__init__()
        self.cfg = cfg
        self.net = build_mlp(
            in_dim=int(cfg.input_dim),
            hidden_dims=list(cfg.hidden_dims),
            out_dim=2 * int(cfg.latent_dim),
            dropout=float(cfg.dropout),
            batchnorm=bool(cfg.batchnorm),
        )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        h = self.net(x)
        mu, logvar = torch.chunk(h, 2, dim=-1)
        return mu, logvar


# =============================================================================
# Tokenization: vector -> tokens (+ optional embeddings / bias)
# =============================================================================

def _mlp(in_dim: int, out_dim: int, hidden: int = 128) -> nn.Module:
    return nn.Sequential(
        nn.LayerNorm(in_dim),
        nn.Linear(in_dim, hidden),
        nn.GELU(),
        nn.Linear(hidden, out_dim),
    )
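
# [editor note - illustrative sketch, not part of the released file]
# The MLP path above maps a dense (B, F) matrix straight to the parameters of a
# diagonal Gaussian posterior. Assuming build_mlp (univi/models/mlp.py) has the
# keyword signature used in MLPGaussianEncoder, usage looks roughly like:
#
#     enc = MLPGaussianEncoder(EncoderConfig(input_dim=2000, hidden_dims=[256, 128], latent_dim=32))
#     mu, logvar = enc(torch.randn(8, 2000))                    # each (8, 32)
#     z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)   # reparameterized sample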


class _VectorToTokens(nn.Module):
    """
    Turn a vector x (B, F) into tokens (B, T, D_in) and optional key_padding_mask.

    TokenizerConfig modes
    ---------------------
    - topk_scalar:
        Select top-k features per cell (by value), output tokens (B, K, 1)
    - topk_channels:
        Select top-k features per cell, output tokens (B, K, C) where C=len(channels)
        channels in {"value","rank","dropout"}:
          * value: raw x at selected indices
          * rank: rank01 among selected K tokens (0..1), per cell
          * dropout: indicator (value==0)
    - patch:
        Split contiguous features into patches of size P:
        tokens (B, T, P) or (B, T, patch_proj_dim) if patch_proj_dim is set

    add_cls_token:
        If True, prepend a learned CLS token embedding to tokens.

    NEW (optional)
    --------------
    - use_feature_embedding:
        Adds learned embedding for selected feature IDs (topk modes), or patch IDs (patch mode).
    - use_coord_embedding (ATAC):
        Adds chromosome embedding + coordinate MLP for selected feature coords (topk modes).
        Call set_feature_coords(chrom_ids, start, end) to attach coords.
    - token_proj_dim:
        If set (or implied by embeddings), project raw token channels/patches to token_proj_dim.

    Notes
    -----
    - Expects dense float input (B,F). If modality data is sparse, densify upstream.
    - key_padding_mask uses True = PAD/ignore (MultiheadAttention convention).
    """
    def __init__(self, *, input_dim: int, tok: TokenizerConfig):
        super().__init__()
        self.input_dim = int(input_dim)
        self.tok = tok

        mode = str(tok.mode).lower().strip()
        if mode not in ("topk_scalar", "topk_channels", "patch"):
            raise ValueError(f"Unknown tokenizer mode {tok.mode!r}")

        self.mode = mode
        self.add_cls_token = bool(getattr(tok, "add_cls_token", False))

        # ------------------------------------------------------------------
        # NEW options (all default to False/None -> no behavior change)
        # ------------------------------------------------------------------
        self.use_feature_emb = bool(getattr(tok, "use_feature_embedding", False))
        self.feature_emb_mode = str(getattr(tok, "feature_emb_mode", "add")).lower().strip()
        self.n_features = getattr(tok, "n_features", None)
        self.feature_emb_dim = getattr(tok, "feature_emb_dim", None)

        self.use_coord_emb = bool(getattr(tok, "use_coord_embedding", False))
        self.coord_mode = str(getattr(tok, "coord_mode", "midpoint")).lower().strip()
        self.coord_scale = float(getattr(tok, "coord_scale", 1e-6))
        self.coord_emb_dim = getattr(tok, "coord_emb_dim", None)
        self.n_chroms = getattr(tok, "n_chroms", None)
        self.coord_mlp_hidden = int(getattr(tok, "coord_mlp_hidden", 128))

        self.token_proj_dim = getattr(tok, "token_proj_dim", None)
        self._has_coords = False
        self.register_buffer("_chrom_ids", torch.empty(0, dtype=torch.long), persistent=False)
        self.register_buffer("_start", torch.empty(0), persistent=False)
        self.register_buffer("_end", torch.empty(0), persistent=False)

        # ------------------------------------------------------------------
        # Base token dims per mode (pre-projection)
        # ------------------------------------------------------------------
        if self.mode == "topk_scalar":
            self.n_tokens = int(tok.n_tokens)
            base_d = 1
            self.channels: Sequence[str] = ("value",)

        elif self.mode == "topk_channels":
            self.n_tokens = int(tok.n_tokens)
            ch = list(tok.channels)
            if not ch:
                raise ValueError("topk_channels requires tokenizer.channels non-empty")
            bad = [c for c in ch if c not in ("value", "rank", "dropout")]
            if bad:
                raise ValueError(f"topk_channels invalid channels: {bad}")
            self.channels = tuple(ch)
            base_d = len(self.channels)

        else:  # patch
            P = int(tok.patch_size)
            if P <= 0:
                raise ValueError("patch_size must be > 0")
            self.patch_size = P
            T = (self.input_dim + P - 1) // P
            self.n_tokens = int(T)

            proj_dim = tok.patch_proj_dim
            if proj_dim is None:
                self.patch_proj = None
                base_d = P
            else:
                proj_dim = int(proj_dim)
                if proj_dim <= 0:
                    raise ValueError("patch_proj_dim must be > 0 if set")
                self.patch_proj = nn.Linear(P, proj_dim)
                base_d = proj_dim

        # ------------------------------------------------------------------
        # Decide final token dim (d_in)
        # - if token_proj_dim set -> project to it
        # - else if embeddings enabled -> pick a reasonable dim (feature/coord emb dim, or 64)
        # - else -> keep base_d (exact backward-compat)
        # ------------------------------------------------------------------
        implied_proj: Optional[int] = None
        if self.token_proj_dim is not None:
            implied_proj = int(self.token_proj_dim)
        elif self.use_feature_emb or self.use_coord_emb:
            implied_proj = int(self.feature_emb_dim or self.coord_emb_dim or 64)

        self._d_in = int(implied_proj) if implied_proj is not None else int(base_d)

        # ------------------------------------------------------------------
        # Optional projection from base_d -> d_in
        # ------------------------------------------------------------------
        self.val_proj: Optional[nn.Module] = None
        if self._d_in != int(base_d):
            self.val_proj = _mlp(int(base_d), self._d_in, hidden=self.coord_mlp_hidden)

        # ------------------------------------------------------------------
        # NEW: feature ID embedding (topk) / patch ID embedding (patch)
        # ------------------------------------------------------------------
        self.id_emb: Optional[nn.Embedding] = None
        self.id_fuse: Optional[nn.Linear] = None
        if self.use_feature_emb:
            if self.mode in ("topk_scalar", "topk_channels"):
                if self.n_features is None or int(self.n_features) <= 0:
                    raise ValueError("use_feature_embedding=True for topk requires tok.n_features > 0")
                emb_dim = int(self.feature_emb_dim or self._d_in)
                self.id_emb = nn.Embedding(int(self.n_features), emb_dim)
                if self.feature_emb_mode == "concat":
                    self.id_fuse = nn.Linear(self._d_in + emb_dim, self._d_in)
            else:
                # patch: embed patch index (0..T-1)
                emb_dim = int(self.feature_emb_dim or self._d_in)
                self.id_emb = nn.Embedding(int(self.n_tokens), emb_dim)
                if self.feature_emb_mode == "concat":
                    self.id_fuse = nn.Linear(self._d_in + emb_dim, self._d_in)

        # ------------------------------------------------------------------
        # NEW: coord embeddings (topk only)
        # ------------------------------------------------------------------
        self.chrom_emb: Optional[nn.Embedding] = None
        self.coord_mlp: Optional[nn.Module] = None
        if self.use_coord_emb:
            if self.mode not in ("topk_scalar", "topk_channels"):
                raise ValueError("use_coord_embedding=True is only supported for topk_* tokenizers.")
            if self.n_chroms is None or int(self.n_chroms) <= 0:
                raise ValueError("use_coord_embedding=True requires tok.n_chroms > 0")
            self.chrom_emb = nn.Embedding(int(self.n_chroms), self._d_in)
            cd_in = 1 if self.coord_mode == "midpoint" else 2
            self.coord_mlp = _mlp(cd_in, self._d_in, hidden=self.coord_mlp_hidden)

        # ------------------------------------------------------------------
        # CLS token (learned, matches d_in)
        # ------------------------------------------------------------------
        self.cls_token: Optional[nn.Parameter] = None
        if self.add_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self._d_in))
            nn.init.normal_(self.cls_token, mean=0.0, std=0.02)

    @property
    def d_in(self) -> int:
        return int(self._d_in)

    def set_feature_coords(self, chrom_ids: torch.Tensor, start: torch.Tensor, end: torch.Tensor) -> None:
        """
        Attach per-feature genomic coordinates (for ATAC coordinate embeddings and distance bias).

        Parameters
        ----------
        chrom_ids : LongTensor [F] with values in [0, n_chroms)
        start/end : Float/LongTensor [F] in basepairs
        """
        chrom_ids = chrom_ids.long().contiguous()
        start = start.to(dtype=torch.float32).contiguous()
        end = end.to(dtype=torch.float32).contiguous()
        if chrom_ids.ndim != 1 or start.ndim != 1 or end.ndim != 1:
            raise ValueError("set_feature_coords expects 1D tensors: chrom_ids/start/end.")
        if not (chrom_ids.shape[0] == start.shape[0] == end.shape[0] == self.input_dim):
            raise ValueError(
                f"set_feature_coords expects length F={self.input_dim}; got "
                f"{chrom_ids.shape[0]}, {start.shape[0]}, {end.shape[0]}"
            )
        self._chrom_ids = chrom_ids
        self._start = start
        self._end = end
        self._has_coords = True

    def _apply_id_emb(self, tokens: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
        if self.id_emb is None:
            return tokens
        idv = self.id_emb(ids)  # (..., E)

        if self.feature_emb_mode == "add":
            if idv.shape[-1] != tokens.shape[-1]:
                raise ValueError(
                    f"feature_emb_mode='add' requires emb_dim==d_in; got emb_dim={idv.shape[-1]} vs d_in={tokens.shape[-1]}"
                )
            return tokens + idv

        if self.feature_emb_mode == "concat":
            if self.id_fuse is None:
                raise RuntimeError("id_fuse not initialized for concat mode.")
            return self.id_fuse(torch.cat([tokens, idv], dim=-1))

        raise ValueError(f"Unknown feature_emb_mode={self.feature_emb_mode!r}")

    def _apply_coord_emb(self, tokens: torch.Tensor, topk_idx: torch.Tensor) -> torch.Tensor:
        if not self.use_coord_emb:
            return tokens
        if not self._has_coords:
            # coords not attached -> silently no-op (keeps things robust)
            return tokens
        if self.chrom_emb is None or self.coord_mlp is None:
            return tokens

        B, K = topk_idx.shape
        Fdim = self.input_dim
        if self._chrom_ids.numel() != Fdim:
            raise ValueError(f"Tokenizer coords are for F={self._chrom_ids.numel()} but input_dim={Fdim}")

        chrom = torch.gather(self._chrom_ids.view(1, Fdim).expand(B, Fdim), 1, topk_idx)  # (B,K)
        start = torch.gather(self._start.view(1, Fdim).expand(B, Fdim), 1, topk_idx)  # (B,K)
        end = torch.gather(self._end.view(1, Fdim).expand(B, Fdim), 1, topk_idx)  # (B,K)

        chrom_e = self.chrom_emb(chrom)  # (B,K,D)

        if self.coord_mode == "midpoint":
            mid = 0.5 * (start + end)
            pos = (mid * self.coord_scale).unsqueeze(-1)  # (B,K,1)
        else:
            pos = torch.stack([start * self.coord_scale, end * self.coord_scale], dim=-1)  # (B,K,2)

        pos_e = self.coord_mlp(pos)  # (B,K,D)
        return tokens + chrom_e + pos_e

    def build_distance_attn_bias(
        self,
        topk_idx: torch.Tensor,
        *,
        lengthscale_bp: float = 50_000.0,
        same_chrom_only: bool = True,
        include_cls: bool = False,
        cls_is_zero: bool = True,
    ) -> torch.Tensor:
        """
        Build an additive attention bias matrix for topk tokens based on genomic distance.

        Returns
        -------
        attn_bias : FloatTensor (B, T, T)
            Additive bias to attention logits (higher = more attention).
            Uses a Gaussian kernel: bias = - (dist / lengthscale)^2

        Notes
        -----
        - Only valid when coords are attached via set_feature_coords().
        - For CLS handling:
            include_cls=True assumes your tokens will have CLS prepended at position 0.
        """
        if self.mode not in ("topk_scalar", "topk_channels"):
            raise ValueError("build_distance_attn_bias is only supported for topk_* modes.")
        if not self._has_coords:
            raise ValueError("build_distance_attn_bias requires feature coords; call set_feature_coords first.")
        if float(lengthscale_bp) <= 0:
            raise ValueError("lengthscale_bp must be > 0")

        B, K = topk_idx.shape
        Fdim = self.input_dim

        chrom = torch.gather(self._chrom_ids.view(1, Fdim).expand(B, Fdim), 1, topk_idx)  # (B,K)
        start = torch.gather(self._start.view(1, Fdim).expand(B, Fdim), 1, topk_idx)  # (B,K)
        end = torch.gather(self._end.view(1, Fdim).expand(B, Fdim), 1, topk_idx)  # (B,K)
        mid = 0.5 * (start + end)  # (B,K)

        # pairwise distances per batch: (B,K,K)
        dist = (mid.unsqueeze(2) - mid.unsqueeze(1)).abs()

        if same_chrom_only:
            same = (chrom.unsqueeze(2) == chrom.unsqueeze(1))
        else:
            same = torch.ones((B, K, K), device=dist.device, dtype=torch.bool)

        ls = float(lengthscale_bp)
        bias = -((dist / ls) ** 2)
        bias = torch.where(same, bias, torch.full_like(bias, -1e4))

        if include_cls:
            T = K + 1
            out = torch.zeros((B, T, T), device=bias.device, dtype=bias.dtype)
            out[:, 1:, 1:] = bias
            if not cls_is_zero:
                # (rare) you could choose to bias CLS-to-all; default keeps it neutral
                pass
            return out

        return bias

    def forward(
        self,
        x: torch.Tensor,
        *,
        return_indices: bool = False,
    ) -> Union[
        Tuple[torch.Tensor, Optional[torch.Tensor]],
        Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, Any]],
    ]:
        if x.dim() != 2:
            raise ValueError(f"_VectorToTokens expects x as (B,F); got shape {tuple(x.shape)}")
        B, Fdim = x.shape
        if Fdim != self.input_dim:
            raise ValueError(f"Expected input_dim={self.input_dim}, got F={Fdim}")

        key_padding_mask: Optional[torch.Tensor] = None
        meta: Dict[str, Any] = {}

        if self.mode == "topk_scalar":
            K = min(int(self.n_tokens), Fdim)
            vals, idx = torch.topk(x, k=K, dim=1, largest=True, sorted=False)  # (B,K)
            tokens = vals.unsqueeze(-1)  # (B,K,1)
            key_padding_mask = None
            if return_indices:
                meta["topk_idx"] = idx  # (B,K)

            if self.val_proj is not None:
                tokens = self.val_proj(tokens)  # (B,K,D)

            if self.use_feature_emb:
                tokens = self._apply_id_emb(tokens, idx)

            tokens = self._apply_coord_emb(tokens, idx)

        elif self.mode == "topk_channels":
            K = min(int(self.n_tokens), Fdim)
            vals, idx = torch.topk(x, k=K, dim=1, largest=True, sorted=True)  # (B,K)
            feats = []
            for c in self.channels:
                if c == "value":
                    feats.append(vals)
                elif c == "rank":
                    if K <= 1:
                        rank01 = torch.zeros((B, K), device=x.device, dtype=torch.float32)
                    else:
                        rank01 = torch.linspace(0.0, 1.0, steps=K, device=x.device, dtype=torch.float32)
                        rank01 = rank01.unsqueeze(0).expand(B, K)
                    feats.append(rank01)
                elif c == "dropout":
                    feats.append((vals == 0).to(torch.float32))
                else:
                    raise RuntimeError(f"Unhandled channel: {c!r}")
            tokens = torch.stack(feats, dim=-1)  # (B,K,C)
            key_padding_mask = None
            if return_indices:
                meta["topk_idx"] = idx  # (B,K)

            if self.val_proj is not None:
                tokens = self.val_proj(tokens)  # (B,K,D)

            if self.use_feature_emb:
                tokens = self._apply_id_emb(tokens, idx)

            tokens = self._apply_coord_emb(tokens, idx)

        else:  # patch
            P = int(self.patch_size)
            T = int(self.n_tokens)
            pad = T * P - Fdim
            if pad > 0:
                x_pad = F.pad(x, (0, pad), mode="constant", value=0.0)  # (B, T*P)
            else:
                x_pad = x
            patches = x_pad.view(B, T, P)  # (B,T,P)

            tokens = self.patch_proj(patches) if self.patch_proj is not None else patches

            if pad > 0:
                real_counts = torch.full((T,), P, device=x.device, dtype=torch.int64)
                last_real = Fdim - (T - 1) * P
                if last_real < P:
                    real_counts[-1] = max(int(last_real), 0)
                key_padding_mask = (real_counts == 0).unsqueeze(0).expand(B, T)
            else:
                key_padding_mask = None

            if return_indices:
                meta["patch_size"] = P
                meta["n_patches"] = T
                meta["pad"] = pad

            if self.use_feature_emb and self.id_emb is not None:
                pid = torch.arange(T, device=x.device).view(1, T).expand(B, T)  # (B,T)
                tokens = self._apply_id_emb(tokens, pid)

        if self.add_cls_token:
            assert self.cls_token is not None
            cls = self.cls_token.expand(B, -1, -1)  # (B,1,D)
            tokens = torch.cat([cls, tokens], dim=1)
            if key_padding_mask is not None:
                cls_mask = torch.zeros((B, 1), device=x.device, dtype=torch.bool)
                key_padding_mask = torch.cat([cls_mask, key_padding_mask], dim=1)

        if return_indices:
            return tokens, key_padding_mask, meta
        return tokens, key_padding_mask
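
# [editor note - illustrative sketch, not part of the released file]
# _VectorToTokens turns a dense (B, F) matrix into a token sequence plus an optional
# key_padding_mask. Assuming TokenizerConfig (univi/config.py, whose contents are not
# shown here) accepts the field names referenced above, a topk_channels setup behaves like:
#
#     tok_cfg = TokenizerConfig(mode="topk_channels", n_tokens=256,
#                               channels=["value", "rank", "dropout"], add_cls_token=True)
#     v2t = _VectorToTokens(input_dim=20000, tok=tok_cfg)
#     tokens, mask, meta = v2t(torch.rand(4, 20000), return_indices=True)
#     # tokens: (4, 257, 3)   256 top-k tokens + 1 CLS, one channel per entry in `channels`
#     # mask:   None          top-k modes never pad; only `patch` mode can produce a mask
#     # meta["topk_idx"]: (4, 256) indices of the selected features per cell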


# =============================================================================
# Per-modality transformer Gaussian encoder
# =============================================================================

def _cfg_to_model_tcfg(cfg: CFGTransformerConfig) -> ModelTransformerConfig:
    return ModelTransformerConfig(
        d_model=int(cfg.d_model),
        num_heads=int(cfg.num_heads),
        num_layers=int(cfg.num_layers),
        dim_feedforward=int(cfg.dim_feedforward),
        dropout=float(cfg.dropout),
        attn_dropout=float(cfg.attn_dropout),
        activation=str(cfg.activation),
        pooling=str(cfg.pooling),
        max_tokens=None if cfg.max_tokens is None else int(cfg.max_tokens),
    )


class TransformerGaussianEncoder(GaussianEncoder):
    """
    (B,F) -> tokens (B,T,D_in) -> TransformerEncoder -> (mu, logvar)

    New convenience:
    - attach coords: self.vec2tok.set_feature_coords(...)
    - optional attn_bias passthrough (e.g., distance bias)
    """
    def __init__(
        self,
        *,
        input_dim: int,
        latent_dim: int,
        tokenizer: _VectorToTokens,
        tcfg: ModelTransformerConfig,
        use_positional_encoding: bool = True,
    ):
        super().__init__()
        self.vec2tok = tokenizer

        if use_positional_encoding and tcfg.max_tokens is None:
            tcfg.max_tokens = int(self.vec2tok.n_tokens + (1 if self.vec2tok.add_cls_token else 0))

        self.encoder = TransformerEncoder(
            cfg=tcfg,
            d_in=int(self.vec2tok.d_in),
            d_out=2 * int(latent_dim),
            use_positional_encoding=bool(use_positional_encoding),
        )

    def forward(
        self,
        x: torch.Tensor,
        *,
        return_attn: bool = False,
        attn_average_heads: bool = True,
        return_token_meta: bool = False,
        attn_bias: Optional[torch.Tensor] = None,
    ):
        if return_token_meta:
            tokens, key_padding_mask, meta = self.vec2tok(x, return_indices=True)
        else:
            tokens, key_padding_mask = self.vec2tok(x, return_indices=False)
            meta = None

        if return_attn:
            h, attn_all = self.encoder(
                tokens,
                key_padding_mask=key_padding_mask,
                attn_bias=attn_bias,
                return_attn=True,
                attn_average_heads=attn_average_heads,
            )
        else:
            h = self.encoder(
                tokens,
                key_padding_mask=key_padding_mask,
                attn_bias=attn_bias,
                return_attn=False,
            )
            attn_all = None

        mu, logvar = torch.chunk(h, 2, dim=-1)

        if return_attn and return_token_meta:
            return mu, logvar, attn_all, meta
        if return_attn:
            return mu, logvar, attn_all
        if return_token_meta:
            return mu, logvar, meta
        return mu, logvar
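
# [editor note - illustrative sketch, not part of the released file]
# The attn_bias passthrough above is what makes the genomic-distance prior usable for a
# top-k ATAC tokenizer: attach coordinates once, derive the (B, T, T) bias from the
# selected peaks, then pass it back into forward(). One hypothetical two-pass pattern:
#
#     enc.vec2tok.set_feature_coords(chrom_ids, peak_start, peak_end)  # 1D tensors of length F
#     mu, logvar, meta = enc(x_atac, return_token_meta=True)           # meta["topk_idx"]: (B, K)
#     bias = enc.vec2tok.build_distance_attn_bias(meta["topk_idx"], include_cls=True)
#     mu, logvar = enc(x_atac, attn_bias=bias)                         # re-encode with the bias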


# =============================================================================
# Multimodal concatenated-token transformer Gaussian encoder (fused)
# =============================================================================

def _tokcfg_without_cls(tok_cfg_in: TokenizerConfig) -> TokenizerConfig:
    # preserve all fields, only override add_cls_token=False
    return replace(tok_cfg_in, add_cls_token=False)


class MultiModalTransformerGaussianEncoder(nn.Module):
    """
    Fused encoder over multiple modalities by concatenating tokens.

    Produces ONE fused posterior q(z | x_all). It does not replace per-modality q(z|x_m).

    Clean coord hook:
        fused_encoder.set_feature_coords("atac", chrom_ids, start, end)

    Attention bias:
        You can pass attn_bias=... into forward, or pass attn_bias_fn(meta)->bias.
    """
    def __init__(
        self,
        *,
        modalities: Sequence[str],
        input_dims: Dict[str, int],
        tokenizers: Dict[str, TokenizerConfig],
        transformer_cfg: CFGTransformerConfig,
        latent_dim: int,
        add_modality_embeddings: bool = True,
        use_positional_encoding: bool = True,
    ):
        super().__init__()

        self.modalities = list(modalities)
        self.latent_dim = int(latent_dim)

        tcfg_model = _cfg_to_model_tcfg(transformer_cfg)
        d_model = int(tcfg_model.d_model)

        self.vec2tok = nn.ModuleDict()
        self.proj = nn.ModuleDict()
        self.mod_emb = nn.ParameterDict() if add_modality_embeddings else None

        total_tokens = 0
        for m in self.modalities:
            tok_cfg_in = tokenizers[m]
            tok_cfg = _tokcfg_without_cls(tok_cfg_in)  # force per-modality CLS off (one global CLS optional)

            tok = _VectorToTokens(input_dim=int(input_dims[m]), tok=tok_cfg)
            self.vec2tok[m] = tok
            self.proj[m] = nn.Linear(int(tok.d_in), d_model, bias=True)

            if self.mod_emb is not None:
                self.mod_emb[m] = nn.Parameter(torch.zeros(1, 1, d_model))
                nn.init.normal_(self.mod_emb[m], mean=0.0, std=0.02)

            total_tokens += int(tok.n_tokens)

        self.pooling = str(tcfg_model.pooling).lower().strip()
        self.use_global_cls = (self.pooling == "cls")
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if self.use_global_cls else None
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, mean=0.0, std=0.02)

        if use_positional_encoding and tcfg_model.max_tokens is None:
            tcfg_model.max_tokens = int(total_tokens + (1 if self.use_global_cls else 0))

        self.encoder = TransformerEncoder(
            cfg=tcfg_model,
            d_in=d_model,
            d_out=2 * int(latent_dim),
            use_positional_encoding=bool(use_positional_encoding),
        )

    def set_feature_coords(self, modality: str, chrom_ids: torch.Tensor, start: torch.Tensor, end: torch.Tensor) -> None:
        modality = str(modality)
        if modality not in self.vec2tok:
            raise KeyError(f"Unknown modality {modality!r}. Known: {list(self.vec2tok.keys())}")
        self.vec2tok[modality].set_feature_coords(chrom_ids, start, end)

    def forward(
        self,
        x_dict: Dict[str, torch.Tensor],
        *,
        return_token_meta: bool = False,
        return_attn: bool = False,
        attn_average_heads: bool = True,
        attn_bias: Optional[torch.Tensor] = None,
        attn_bias_fn: Optional[Any] = None,
    ):
        tokens_list: List[torch.Tensor] = []
        masks_list: List[Optional[torch.Tensor]] = []

        meta: Dict[str, Any] = {
            "modalities": self.modalities,
            "slices": {},  # modality -> (start, end) in CONCAT TOKEN SPACE (excluding global CLS)
            "has_global_cls": bool(self.use_global_cls),
            "cls_index": 0 if self.use_global_cls else None,
        }
        t_cursor = 0

        for m in self.modalities:
            x = x_dict[m]

            if return_token_meta:
                tok, mask, mmeta = self.vec2tok[m](x, return_indices=True)
                meta[m] = mmeta
            else:
                tok, mask = self.vec2tok[m](x, return_indices=False)

            tok = self.proj[m](tok)  # (B,T,d_model)
            if self.mod_emb is not None:
                tok = tok + self.mod_emb[m]

            Tm = tok.shape[1]
            meta["slices"][m] = (t_cursor, t_cursor + Tm)
            t_cursor += Tm

            tokens_list.append(tok)
            masks_list.append(mask)

        tokens = torch.cat(tokens_list, dim=1)  # (B, T_total, d_model)

        key_padding_mask: Optional[torch.Tensor] = None
        if any(m is not None for m in masks_list):
            B = tokens.shape[0]
            built: List[torch.Tensor] = []
            for i, mname in enumerate(self.modalities):
                mask = masks_list[i]
                if mask is None:
                    Tm = tokens_list[i].shape[1]
                    built.append(torch.zeros((B, Tm), device=tokens.device, dtype=torch.bool))
                else:
                    built.append(mask.to(dtype=torch.bool))
            key_padding_mask = torch.cat(built, dim=1)

        # Prepend ONE global CLS if pooling="cls"
        if self.use_global_cls:
            assert self.cls_token is not None
            cls = self.cls_token.expand(tokens.shape[0], -1, -1)  # (B,1,D)
            tokens = torch.cat([cls, tokens], dim=1)

            meta["slices_with_cls"] = {k: (a + 1, b + 1) for k, (a, b) in meta["slices"].items()}

            if key_padding_mask is not None:
                cls_mask = torch.zeros((tokens.shape[0], 1), device=tokens.device, dtype=torch.bool)
                key_padding_mask = torch.cat([cls_mask, key_padding_mask], dim=1)
        else:
            meta["slices_with_cls"] = dict(meta["slices"])

        # Optionally let user build bias from meta
        if attn_bias is None and attn_bias_fn is not None:
            attn_bias = attn_bias_fn(meta)

        # Run transformer (+ optional attn collection)
        if return_attn:
            h, attn_all = self.encoder(
                tokens,
                key_padding_mask=key_padding_mask,
                attn_bias=attn_bias,
                return_attn=True,
                attn_average_heads=attn_average_heads,
            )
        else:
            h = self.encoder(
                tokens,
                key_padding_mask=key_padding_mask,
                attn_bias=attn_bias,
                return_attn=False,
            )
            attn_all = None

        mu, logvar = torch.chunk(h, 2, dim=-1)

        if return_attn and return_token_meta:
            return mu, logvar, attn_all, meta
        if return_attn:
            return mu, logvar, attn_all
        if return_token_meta:
            return mu, logvar, meta
        return mu, logvar
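
# [editor note - illustrative sketch, not part of the released file]
# meta["slices"] / meta["slices_with_cls"] record each modality's (start, end) range on the
# concatenated token axis, which is exactly what attn_bias_fn needs to place a bias block on
# the right rows/columns. A hypothetical attn_bias_fn reusing the ATAC distance bias:
#
#     def atac_distance_bias_fn(meta):
#         # fused_enc: the MultiModalTransformerGaussianEncoder instance (closure variable)
#         a, b = meta["slices_with_cls"]["atac"]
#         idx = meta["atac"]["topk_idx"]                                   # needs return_token_meta=True
#         block = fused_enc.vec2tok["atac"].build_distance_attn_bias(idx)  # (B, K, K), K == b - a
#         T = sum(e - s for s, e in meta["slices"].values()) + (1 if meta["has_global_cls"] else 0)
#         bias = block.new_zeros(block.shape[0], T, T)
#         bias[:, a:b, a:b] = block
#         return bias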


# =============================================================================
# Factories
# =============================================================================

def build_gaussian_encoder(*, uni_cfg: UniVIConfig, mod_cfg: ModalityConfig) -> GaussianEncoder:
    """
    Factory for per-modality Gaussian encoders.

    Supported mod_cfg.encoder_type:
    - "mlp" (default)
    - "transformer"
    """
    kind = (mod_cfg.encoder_type or "mlp").lower().strip()

    if kind == "mlp":
        return MLPGaussianEncoder(
            EncoderConfig(
                input_dim=int(mod_cfg.input_dim),
                hidden_dims=list(mod_cfg.encoder_hidden),
                latent_dim=int(uni_cfg.latent_dim),
                dropout=float(uni_cfg.encoder_dropout),
                batchnorm=bool(uni_cfg.encoder_batchnorm),
            )
        )

    if kind == "transformer":
        if mod_cfg.transformer is None:
            raise ValueError(f"Modality {mod_cfg.name!r}: encoder_type='transformer' requires mod_cfg.transformer.")
        if mod_cfg.tokenizer is None:
            raise ValueError(f"Modality {mod_cfg.name!r}: encoder_type='transformer' requires mod_cfg.tokenizer.")

        tokenizer = _VectorToTokens(
            input_dim=int(mod_cfg.input_dim),
            tok=mod_cfg.tokenizer,
        )

        tcfg = _cfg_to_model_tcfg(mod_cfg.transformer)
        if tcfg.max_tokens is None:
            tcfg.max_tokens = int(tokenizer.n_tokens + (1 if tokenizer.add_cls_token else 0))

        return TransformerGaussianEncoder(
            input_dim=int(mod_cfg.input_dim),
            latent_dim=int(uni_cfg.latent_dim),
            tokenizer=tokenizer,
            tcfg=tcfg,
            use_positional_encoding=True,
        )

    raise ValueError(f"Unknown encoder_type={kind!r} for modality {mod_cfg.name!r}")


def build_multimodal_transformer_encoder(
    *,
    uni_cfg: UniVIConfig,
    modalities: Sequence[ModalityConfig],
    fused_modalities: Optional[Sequence[str]] = None,
) -> MultiModalTransformerGaussianEncoder:
    """
    Build the fused multimodal transformer encoder from existing per-modality configs.

    Requirements:
    - uni_cfg.fused_transformer is set
    - each fused modality has mod_cfg.tokenizer set (even if its per-modality encoder is MLP)
    """
    if uni_cfg.fused_transformer is None:
        raise ValueError("UniVIConfig.fused_transformer must be set for fused_encoder_type='multimodal_transformer'.")

    mods = {m.name: m for m in modalities}
    use_names = list(fused_modalities) if fused_modalities is not None else list(mods.keys())

    input_dims = {n: int(mods[n].input_dim) for n in use_names}
    tokenizers: Dict[str, TokenizerConfig] = {}
    for n in use_names:
        if mods[n].tokenizer is None:
            raise ValueError(f"Fused multimodal encoder requires tokenizer for modality {n!r}")
        tokenizers[n] = mods[n].tokenizer

    return MultiModalTransformerGaussianEncoder(
        modalities=use_names,
        input_dims=input_dims,
        tokenizers=tokenizers,
        transformer_cfg=uni_cfg.fused_transformer,
        latent_dim=int(uni_cfg.latent_dim),
        add_modality_embeddings=bool(getattr(uni_cfg, "fused_add_modality_embeddings", True)),
        use_positional_encoding=True,
    )
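
Editor note: the two factories above are the intended entry points for this module. build_gaussian_encoder picks an MLP or per-modality transformer encoder from a ModalityConfig, and build_multimodal_transformer_encoder assembles the fused concatenated-token encoder. Below is a minimal construction sketch that bypasses the config factories and instantiates the transformer path directly; the TokenizerConfig and TransformerConfig keyword arguments are assumptions inferred from the attribute names used in encoders.py, since univi/config.py and univi/models/transformer.py are not reproduced in this section of the diff.

    import torch
    from univi.config import TokenizerConfig
    from univi.models.encoders import _VectorToTokens, TransformerGaussianEncoder
    from univi.models.transformer import TransformerConfig

    # Hypothetical field values; names mirror the attribute accesses in encoders.py.
    tok_cfg = TokenizerConfig(mode="topk_scalar", n_tokens=256, add_cls_token=True)
    tokenizer = _VectorToTokens(input_dim=10_000, tok=tok_cfg)

    tcfg = TransformerConfig(d_model=128, num_heads=4, num_layers=2, dim_feedforward=256,
                             dropout=0.1, attn_dropout=0.0, activation="gelu", pooling="cls",
                             max_tokens=None)
    enc = TransformerGaussianEncoder(input_dim=10_000, latent_dim=32, tokenizer=tokenizer, tcfg=tcfg)

    mu, logvar = enc(torch.rand(8, 10_000))   # each (8, 32): parameters of q(z | x)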