univi-0.3.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,848 @@
1
+ # univi/models/encoders.py
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, replace
5
+ from typing import Dict, List, Optional, Sequence, Tuple, Union, Any
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+
11
+ from ..config import UniVIConfig, ModalityConfig, TokenizerConfig, TransformerConfig as CFGTransformerConfig
12
+ from .mlp import build_mlp
13
+ from .transformer import TransformerEncoder, TransformerConfig as ModelTransformerConfig
14
+
15
+
16
+ # =============================================================================
17
+ # Small config helper
18
+ # =============================================================================
19
+
20
+ @dataclass
21
+ class EncoderConfig:
22
+ input_dim: int
23
+ hidden_dims: List[int]
24
+ latent_dim: int
25
+ dropout: float = 0.1
26
+ batchnorm: bool = True
27
+
28
+
29
+ # =============================================================================
30
+ # Base encoders
31
+ # =============================================================================
32
+
33
+ class GaussianEncoder(nn.Module):
34
+ """Base: x -> (mu, logvar) for a diagonal Gaussian."""
35
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
36
+ raise NotImplementedError
37
+
38
+
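Both encoder variants return the (mu, logvar) parameters of a diagonal Gaussian posterior; the sampling step itself is not part of this file. For reference, a standard reparameterization over these outputs would look like the sketch below (illustrative only, not univi API).

import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # z = mu + sigma * eps, with sigma = exp(0.5 * logvar) and eps ~ N(0, I)
    std = torch.exp(0.5 * logvar)
    return mu + std * torch.randn_like(mu)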
39
+ class MLPGaussianEncoder(GaussianEncoder):
40
+ """MLP encoder: x -> (mu, logvar) directly."""
41
+ def __init__(self, cfg: EncoderConfig):
42
+ super().__init__()
43
+ self.cfg = cfg
44
+ self.net = build_mlp(
45
+ in_dim=int(cfg.input_dim),
46
+ hidden_dims=list(cfg.hidden_dims),
47
+ out_dim=2 * int(cfg.latent_dim),
48
+ dropout=float(cfg.dropout),
49
+ batchnorm=bool(cfg.batchnorm),
50
+ )
51
+
52
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
53
+ h = self.net(x)
54
+ mu, logvar = torch.chunk(h, 2, dim=-1)
55
+ return mu, logvar
56
+
57
+
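A minimal usage sketch for the MLP path. It assumes only what the call above implies about build_mlp (a module mapping input_dim through hidden_dims to out_dim); all sizes are illustrative.

import torch
from univi.models.encoders import EncoderConfig, MLPGaussianEncoder

cfg = EncoderConfig(input_dim=2000, hidden_dims=[512, 256], latent_dim=32, dropout=0.1, batchnorm=True)
enc = MLPGaussianEncoder(cfg)
mu, logvar = enc(torch.rand(16, 2000))  # each (16, 32): the head emits 2 * latent_dim and is chunked in half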
58
+ # =============================================================================
59
+ # Tokenization: vector -> tokens (+ optional embeddings / bias)
60
+ # =============================================================================
61
+
62
+ def _mlp(in_dim: int, out_dim: int, hidden: int = 128) -> nn.Module:
63
+ return nn.Sequential(
64
+ nn.LayerNorm(in_dim),
65
+ nn.Linear(in_dim, hidden),
66
+ nn.GELU(),
67
+ nn.Linear(hidden, out_dim),
68
+ )
69
+
70
+
71
+ class _VectorToTokens(nn.Module):
72
+ """
73
+ Turn a vector x (B, F) into tokens (B, T, D_in) and optional key_padding_mask.
74
+
75
+ TokenizerConfig modes
76
+ ---------------------
77
+ - topk_scalar:
78
+ Select top-k features per cell (by value), output tokens (B, K, 1)
79
+ - topk_channels:
80
+ Select top-k features per cell, output tokens (B, K, C) where C=len(channels)
81
+ channels in {"value","rank","dropout"}:
82
+ * value: raw x at selected indices
83
+ * rank: rank01 among selected K tokens (0..1), per cell
84
+ * dropout: indicator (value==0)
85
+ - patch:
86
+ Split contiguous features into patches of size P:
87
+ tokens (B, T, P) or (B, T, patch_proj_dim) if patch_proj_dim is set
88
+
89
+ add_cls_token:
90
+ If True, prepend a learned CLS token embedding to tokens.
91
+
92
+ NEW (optional)
93
+ --------------
94
+ - use_feature_embedding:
95
+ Adds learned embedding for selected feature IDs (topk modes), or patch IDs (patch mode).
96
+ - use_coord_embedding (ATAC):
97
+ Adds chromosome embedding + coordinate MLP for selected feature coords (topk modes).
98
+ Call set_feature_coords(chrom_ids, start, end) to attach coords.
99
+ - token_proj_dim:
100
+ If set (or implied by embeddings), project raw token channels/patches to token_proj_dim.
101
+
102
+ Notes
103
+ -----
104
+ - Expects dense float input (B,F). If modality data is sparse, densify upstream.
105
+ - key_padding_mask uses True = PAD/ignore (MultiheadAttention convention).
106
+ """
107
+ def __init__(self, *, input_dim: int, tok: TokenizerConfig):
108
+ super().__init__()
109
+ self.input_dim = int(input_dim)
110
+ self.tok = tok
111
+
112
+ mode = str(tok.mode).lower().strip()
113
+ if mode not in ("topk_scalar", "topk_channels", "patch"):
114
+ raise ValueError(f"Unknown tokenizer mode {tok.mode!r}")
115
+
116
+ self.mode = mode
117
+ self.add_cls_token = bool(getattr(tok, "add_cls_token", False))
118
+
119
+ # ------------------------------------------------------------------
120
+ # NEW options (all default to False/None -> no behavior change)
121
+ # ------------------------------------------------------------------
122
+ self.use_feature_emb = bool(getattr(tok, "use_feature_embedding", False))
123
+ self.feature_emb_mode = str(getattr(tok, "feature_emb_mode", "add")).lower().strip()
124
+ self.n_features = getattr(tok, "n_features", None)
125
+ self.feature_emb_dim = getattr(tok, "feature_emb_dim", None)
126
+
127
+ self.use_coord_emb = bool(getattr(tok, "use_coord_embedding", False))
128
+ self.coord_mode = str(getattr(tok, "coord_mode", "midpoint")).lower().strip()
129
+ self.coord_scale = float(getattr(tok, "coord_scale", 1e-6))
130
+ self.coord_emb_dim = getattr(tok, "coord_emb_dim", None)
131
+ self.n_chroms = getattr(tok, "n_chroms", None)
132
+ self.coord_mlp_hidden = int(getattr(tok, "coord_mlp_hidden", 128))
133
+
134
+ self.token_proj_dim = getattr(tok, "token_proj_dim", None)
135
+ self._has_coords = False
136
+ self.register_buffer("_chrom_ids", torch.empty(0, dtype=torch.long), persistent=False)
137
+ self.register_buffer("_start", torch.empty(0), persistent=False)
138
+ self.register_buffer("_end", torch.empty(0), persistent=False)
139
+
140
+ # ------------------------------------------------------------------
141
+ # Base token dims per mode (pre-projection)
142
+ # ------------------------------------------------------------------
143
+ if self.mode == "topk_scalar":
144
+ self.n_tokens = int(tok.n_tokens)
145
+ base_d = 1
146
+ self.channels: Sequence[str] = ("value",)
147
+
148
+ elif self.mode == "topk_channels":
149
+ self.n_tokens = int(tok.n_tokens)
150
+ ch = list(tok.channels)
151
+ if not ch:
152
+ raise ValueError("topk_channels requires tokenizer.channels non-empty")
153
+ bad = [c for c in ch if c not in ("value", "rank", "dropout")]
154
+ if bad:
155
+ raise ValueError(f"topk_channels invalid channels: {bad}")
156
+ self.channels = tuple(ch)
157
+ base_d = len(self.channels)
158
+
159
+ else: # patch
160
+ P = int(tok.patch_size)
161
+ if P <= 0:
162
+ raise ValueError("patch_size must be > 0")
163
+ self.patch_size = P
164
+ T = (self.input_dim + P - 1) // P
165
+ self.n_tokens = int(T)
166
+
167
+ proj_dim = tok.patch_proj_dim
168
+ if proj_dim is None:
169
+ self.patch_proj = None
170
+ base_d = P
171
+ else:
172
+ proj_dim = int(proj_dim)
173
+ if proj_dim <= 0:
174
+ raise ValueError("patch_proj_dim must be > 0 if set")
175
+ self.patch_proj = nn.Linear(P, proj_dim)
176
+ base_d = proj_dim
177
+
178
+ # ------------------------------------------------------------------
179
+ # Decide final token dim (d_in)
180
+ # - if token_proj_dim set -> project to it
181
+ # - else if embeddings enabled -> pick a reasonable dim (feature/coord emb dim, or 64)
182
+ # - else -> keep base_d (exact backward-compat)
183
+ # ------------------------------------------------------------------
184
+ implied_proj: Optional[int] = None
185
+ if self.token_proj_dim is not None:
186
+ implied_proj = int(self.token_proj_dim)
187
+ elif self.use_feature_emb or self.use_coord_emb:
188
+ implied_proj = int(self.feature_emb_dim or self.coord_emb_dim or 64)
189
+
190
+ self._d_in = int(implied_proj) if implied_proj is not None else int(base_d)
191
+
192
+ # ------------------------------------------------------------------
193
+ # Optional projection from base_d -> d_in
194
+ # ------------------------------------------------------------------
195
+ self.val_proj: Optional[nn.Module] = None
196
+ if self._d_in != int(base_d):
197
+ self.val_proj = _mlp(int(base_d), self._d_in, hidden=self.coord_mlp_hidden)
198
+
199
+ # ------------------------------------------------------------------
200
+ # NEW: feature ID embedding (topk) / patch ID embedding (patch)
201
+ # ------------------------------------------------------------------
202
+ self.id_emb: Optional[nn.Embedding] = None
203
+ self.id_fuse: Optional[nn.Linear] = None
204
+ if self.use_feature_emb:
205
+ if self.mode in ("topk_scalar", "topk_channels"):
206
+ if self.n_features is None or int(self.n_features) <= 0:
207
+ raise ValueError("use_feature_embedding=True for topk requires tok.n_features > 0")
208
+ emb_dim = int(self.feature_emb_dim or self._d_in)
209
+ self.id_emb = nn.Embedding(int(self.n_features), emb_dim)
210
+ if self.feature_emb_mode == "concat":
211
+ self.id_fuse = nn.Linear(self._d_in + emb_dim, self._d_in)
212
+ else:
213
+ # patch: embed patch index (0..T-1)
214
+ emb_dim = int(self.feature_emb_dim or self._d_in)
215
+ self.id_emb = nn.Embedding(int(self.n_tokens), emb_dim)
216
+ if self.feature_emb_mode == "concat":
217
+ self.id_fuse = nn.Linear(self._d_in + emb_dim, self._d_in)
218
+
219
+ # ------------------------------------------------------------------
220
+ # NEW: coord embeddings (topk only)
221
+ # ------------------------------------------------------------------
222
+ self.chrom_emb: Optional[nn.Embedding] = None
223
+ self.coord_mlp: Optional[nn.Module] = None
224
+ if self.use_coord_emb:
225
+ if self.mode not in ("topk_scalar", "topk_channels"):
226
+ raise ValueError("use_coord_embedding=True is only supported for topk_* tokenizers.")
227
+ if self.n_chroms is None or int(self.n_chroms) <= 0:
228
+ raise ValueError("use_coord_embedding=True requires tok.n_chroms > 0")
229
+ self.chrom_emb = nn.Embedding(int(self.n_chroms), self._d_in)
230
+ cd_in = 1 if self.coord_mode == "midpoint" else 2
231
+ self.coord_mlp = _mlp(cd_in, self._d_in, hidden=self.coord_mlp_hidden)
232
+
233
+ # ------------------------------------------------------------------
234
+ # CLS token (learned, matches d_in)
235
+ # ------------------------------------------------------------------
236
+ self.cls_token: Optional[nn.Parameter] = None
237
+ if self.add_cls_token:
238
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self._d_in))
239
+ nn.init.normal_(self.cls_token, mean=0.0, std=0.02)
240
+
241
+ @property
242
+ def d_in(self) -> int:
243
+ return int(self._d_in)
244
+
245
+ def set_feature_coords(self, chrom_ids: torch.Tensor, start: torch.Tensor, end: torch.Tensor) -> None:
246
+ """
247
+ Attach per-feature genomic coordinates (for ATAC coordinate embeddings and distance bias).
248
+
249
+ Parameters
250
+ ----------
251
+ chrom_ids : LongTensor [F] with values in [0, n_chroms)
252
+ start/end : Float/LongTensor [F] in basepairs
253
+ """
254
+ chrom_ids = chrom_ids.long().contiguous()
255
+ start = start.to(dtype=torch.float32).contiguous()
256
+ end = end.to(dtype=torch.float32).contiguous()
257
+ if chrom_ids.ndim != 1 or start.ndim != 1 or end.ndim != 1:
258
+ raise ValueError("set_feature_coords expects 1D tensors: chrom_ids/start/end.")
259
+ if not (chrom_ids.shape[0] == start.shape[0] == end.shape[0] == self.input_dim):
260
+ raise ValueError(
261
+ f"set_feature_coords expects length F={self.input_dim}; got "
262
+ f"{chrom_ids.shape[0]}, {start.shape[0]}, {end.shape[0]}"
263
+ )
264
+ self._chrom_ids = chrom_ids
265
+ self._start = start
266
+ self._end = end
267
+ self._has_coords = True
268
+
269
+ def _apply_id_emb(self, tokens: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
270
+ if self.id_emb is None:
271
+ return tokens
272
+ idv = self.id_emb(ids) # (..., E)
273
+
274
+ if self.feature_emb_mode == "add":
275
+ if idv.shape[-1] != tokens.shape[-1]:
276
+ raise ValueError(
277
+ f"feature_emb_mode='add' requires emb_dim==d_in; got emb_dim={idv.shape[-1]} vs d_in={tokens.shape[-1]}"
278
+ )
279
+ return tokens + idv
280
+
281
+ if self.feature_emb_mode == "concat":
282
+ if self.id_fuse is None:
283
+ raise RuntimeError("id_fuse not initialized for concat mode.")
284
+ return self.id_fuse(torch.cat([tokens, idv], dim=-1))
285
+
286
+ raise ValueError(f"Unknown feature_emb_mode={self.feature_emb_mode!r}")
287
+
288
+ def _apply_coord_emb(self, tokens: torch.Tensor, topk_idx: torch.Tensor) -> torch.Tensor:
289
+ if not self.use_coord_emb:
290
+ return tokens
291
+ if not self._has_coords:
292
+ # coords not attached -> silently no-op (keeps things robust)
293
+ return tokens
294
+ if self.chrom_emb is None or self.coord_mlp is None:
295
+ return tokens
296
+
297
+ B, K = topk_idx.shape
298
+ Fdim = self.input_dim
299
+ if self._chrom_ids.numel() != Fdim:
300
+ raise ValueError(f"Tokenizer coords are for F={self._chrom_ids.numel()} but input_dim={Fdim}")
301
+
302
+ chrom = torch.gather(self._chrom_ids.view(1, Fdim).expand(B, Fdim), 1, topk_idx) # (B,K)
303
+ start = torch.gather(self._start.view(1, Fdim).expand(B, Fdim), 1, topk_idx) # (B,K)
304
+ end = torch.gather(self._end.view(1, Fdim).expand(B, Fdim), 1, topk_idx) # (B,K)
305
+
306
+ chrom_e = self.chrom_emb(chrom) # (B,K,D)
307
+
308
+ if self.coord_mode == "midpoint":
309
+ mid = 0.5 * (start + end)
310
+ pos = (mid * self.coord_scale).unsqueeze(-1) # (B,K,1)
311
+ else:
312
+ pos = torch.stack([start * self.coord_scale, end * self.coord_scale], dim=-1) # (B,K,2)
313
+
314
+ pos_e = self.coord_mlp(pos) # (B,K,D)
315
+ return tokens + chrom_e + pos_e
316
+
317
+ def build_distance_attn_bias(
318
+ self,
319
+ topk_idx: torch.Tensor,
320
+ *,
321
+ lengthscale_bp: float = 50_000.0,
322
+ same_chrom_only: bool = True,
323
+ include_cls: bool = False,
324
+ cls_is_zero: bool = True,
325
+ ) -> torch.Tensor:
326
+ """
327
+ Build an additive attention bias matrix for topk tokens based on genomic distance.
328
+
329
+ Returns
330
+ -------
331
+ attn_bias : FloatTensor (B, T, T)
332
+ Additive bias to attention logits (higher = more attention).
333
+ Uses a Gaussian kernel: bias = - (dist / lengthscale)^2
334
+
335
+ Notes
336
+ -----
337
+ - Only valid when coords are attached via set_feature_coords().
338
+ - For CLS handling:
339
+ include_cls=True assumes your tokens will have CLS prepended at position 0.
340
+ """
341
+ if self.mode not in ("topk_scalar", "topk_channels"):
342
+ raise ValueError("build_distance_attn_bias is only supported for topk_* modes.")
343
+ if not self._has_coords:
344
+ raise ValueError("build_distance_attn_bias requires feature coords; call set_feature_coords first.")
345
+ if float(lengthscale_bp) <= 0:
346
+ raise ValueError("lengthscale_bp must be > 0")
347
+
348
+ B, K = topk_idx.shape
349
+ Fdim = self.input_dim
350
+
351
+ chrom = torch.gather(self._chrom_ids.view(1, Fdim).expand(B, Fdim), 1, topk_idx) # (B,K)
352
+ start = torch.gather(self._start.view(1, Fdim).expand(B, Fdim), 1, topk_idx) # (B,K)
353
+ end = torch.gather(self._end.view(1, Fdim).expand(B, Fdim), 1, topk_idx) # (B,K)
354
+ mid = 0.5 * (start + end) # (B,K)
355
+
356
+ # pairwise distances per batch: (B,K,K)
357
+ dist = (mid.unsqueeze(2) - mid.unsqueeze(1)).abs()
358
+
359
+ if same_chrom_only:
360
+ same = (chrom.unsqueeze(2) == chrom.unsqueeze(1))
361
+ else:
362
+ same = torch.ones((B, K, K), device=dist.device, dtype=torch.bool)
363
+
364
+ ls = float(lengthscale_bp)
365
+ bias = -((dist / ls) ** 2)
366
+ bias = torch.where(same, bias, torch.full_like(bias, -1e4))
367
+
368
+ if include_cls:
369
+ T = K + 1
370
+ out = torch.zeros((B, T, T), device=bias.device, dtype=bias.dtype)
371
+ out[:, 1:, 1:] = bias
372
+ if not cls_is_zero:
373
+ # (rare) you could choose to bias CLS-to-all; default keeps it neutral
374
+ pass
375
+ return out
376
+
377
+ return bias
378
+
379
+ def forward(
380
+ self,
381
+ x: torch.Tensor,
382
+ *,
383
+ return_indices: bool = False,
384
+ ) -> Union[
385
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
386
+ Tuple[torch.Tensor, Optional[torch.Tensor], Dict[str, Any]],
387
+ ]:
388
+ if x.dim() != 2:
389
+ raise ValueError(f"_VectorToTokens expects x as (B,F); got shape {tuple(x.shape)}")
390
+ B, Fdim = x.shape
391
+ if Fdim != self.input_dim:
392
+ raise ValueError(f"Expected input_dim={self.input_dim}, got F={Fdim}")
393
+
394
+ key_padding_mask: Optional[torch.Tensor] = None
395
+ meta: Dict[str, Any] = {}
396
+
397
+ if self.mode == "topk_scalar":
398
+ K = min(int(self.n_tokens), Fdim)
399
+ vals, idx = torch.topk(x, k=K, dim=1, largest=True, sorted=False) # (B,K)
400
+ tokens = vals.unsqueeze(-1) # (B,K,1)
401
+ key_padding_mask = None
402
+ if return_indices:
403
+ meta["topk_idx"] = idx # (B,K)
404
+
405
+ if self.val_proj is not None:
406
+ tokens = self.val_proj(tokens) # (B,K,D)
407
+
408
+ if self.use_feature_emb:
409
+ tokens = self._apply_id_emb(tokens, idx)
410
+
411
+ tokens = self._apply_coord_emb(tokens, idx)
412
+
413
+ elif self.mode == "topk_channels":
414
+ K = min(int(self.n_tokens), Fdim)
415
+ vals, idx = torch.topk(x, k=K, dim=1, largest=True, sorted=True) # (B,K)
416
+ feats = []
417
+ for c in self.channels:
418
+ if c == "value":
419
+ feats.append(vals)
420
+ elif c == "rank":
421
+ if K <= 1:
422
+ rank01 = torch.zeros((B, K), device=x.device, dtype=torch.float32)
423
+ else:
424
+ rank01 = torch.linspace(0.0, 1.0, steps=K, device=x.device, dtype=torch.float32)
425
+ rank01 = rank01.unsqueeze(0).expand(B, K)
426
+ feats.append(rank01)
427
+ elif c == "dropout":
428
+ feats.append((vals == 0).to(torch.float32))
429
+ else:
430
+ raise RuntimeError(f"Unhandled channel: {c!r}")
431
+ tokens = torch.stack(feats, dim=-1) # (B,K,C)
432
+ key_padding_mask = None
433
+ if return_indices:
434
+ meta["topk_idx"] = idx # (B,K)
435
+
436
+ if self.val_proj is not None:
437
+ tokens = self.val_proj(tokens) # (B,K,D)
438
+
439
+ if self.use_feature_emb:
440
+ tokens = self._apply_id_emb(tokens, idx)
441
+
442
+ tokens = self._apply_coord_emb(tokens, idx)
443
+
444
+ else: # patch
445
+ P = int(self.patch_size)
446
+ T = int(self.n_tokens)
447
+ pad = T * P - Fdim
448
+ if pad > 0:
449
+ x_pad = F.pad(x, (0, pad), mode="constant", value=0.0) # (B, T*P)
450
+ else:
451
+ x_pad = x
452
+ patches = x_pad.view(B, T, P) # (B,T,P)
453
+
454
+ tokens = self.patch_proj(patches) if self.patch_proj is not None else patches
455
+
456
+ if pad > 0:
457
+ real_counts = torch.full((T,), P, device=x.device, dtype=torch.int64)
458
+ last_real = Fdim - (T - 1) * P
459
+ if last_real < P:
460
+ real_counts[-1] = max(int(last_real), 0)
461
+ key_padding_mask = (real_counts == 0).unsqueeze(0).expand(B, T)
462
+ else:
463
+ key_padding_mask = None
464
+
465
+ if return_indices:
466
+ meta["patch_size"] = P
467
+ meta["n_patches"] = T
468
+ meta["pad"] = pad
469
+
470
+ if self.use_feature_emb and self.id_emb is not None:
471
+ pid = torch.arange(T, device=x.device).view(1, T).expand(B, T) # (B,T)
472
+ tokens = self._apply_id_emb(tokens, pid)
473
+
474
+ if self.add_cls_token:
475
+ assert self.cls_token is not None
476
+ cls = self.cls_token.expand(B, -1, -1) # (B,1,D)
477
+ tokens = torch.cat([cls, tokens], dim=1)
478
+ if key_padding_mask is not None:
479
+ cls_mask = torch.zeros((B, 1), device=x.device, dtype=torch.bool)
480
+ key_padding_mask = torch.cat([cls_mask, key_padding_mask], dim=1)
481
+
482
+ if return_indices:
483
+ return tokens, key_padding_mask, meta
484
+ return tokens, key_padding_mask
485
+
486
+
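To make the three tokenizer modes concrete, here is a small shape check run against this module. The SimpleNamespace stand-ins carry only the attributes _VectorToTokens reads; the real TokenizerConfig dataclass lives in univi.config and may have more fields. All sizes are illustrative.

from types import SimpleNamespace
import torch

x = torch.rand(4, 1000)  # dense (B, F) input

# topk_channels: 64 selected features, three channels per token, plus a CLS token
tok = _VectorToTokens(input_dim=1000, tok=SimpleNamespace(
    mode="topk_channels", n_tokens=64, channels=["value", "rank", "dropout"], add_cls_token=True))
tokens, mask = tok(x)   # tokens: (4, 65, 3), mask: None

# patch: ceil(1000 / 128) = 8 patches, each projected from 128 raw features to 32 dims
tok = _VectorToTokens(input_dim=1000, tok=SimpleNamespace(
    mode="patch", patch_size=128, patch_proj_dim=32, add_cls_token=False))
tokens, mask = tok(x)   # tokens: (4, 8, 32), mask: (4, 8) bool, all False here (no fully padded patch)

# topk_scalar with a learned feature-ID embedding: tokens are projected up to the embedding width
tok = _VectorToTokens(input_dim=1000, tok=SimpleNamespace(
    mode="topk_scalar", n_tokens=32, add_cls_token=False,
    use_feature_embedding=True, n_features=1000, feature_emb_dim=16))
tokens, mask = tok(x)   # tokens: (4, 32, 16), mask: None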
487
+ # =============================================================================
488
+ # Per-modality transformer Gaussian encoder
489
+ # =============================================================================
490
+
491
+ def _cfg_to_model_tcfg(cfg: CFGTransformerConfig) -> ModelTransformerConfig:
492
+ return ModelTransformerConfig(
493
+ d_model=int(cfg.d_model),
494
+ num_heads=int(cfg.num_heads),
495
+ num_layers=int(cfg.num_layers),
496
+ dim_feedforward=int(cfg.dim_feedforward),
497
+ dropout=float(cfg.dropout),
498
+ attn_dropout=float(cfg.attn_dropout),
499
+ activation=str(cfg.activation),
500
+ pooling=str(cfg.pooling),
501
+ max_tokens=None if cfg.max_tokens is None else int(cfg.max_tokens),
502
+ )
503
+
504
+
505
+ class TransformerGaussianEncoder(GaussianEncoder):
506
+ """
507
+ (B,F) -> tokens (B,T,D_in) -> TransformerEncoder -> (mu, logvar)
508
+
509
+ New convenience:
510
+ - attach coords: self.vec2tok.set_feature_coords(...)
511
+ - optional attn_bias passthrough (e.g., distance bias)
512
+ """
513
+ def __init__(
514
+ self,
515
+ *,
516
+ input_dim: int,
517
+ latent_dim: int,
518
+ tokenizer: _VectorToTokens,
519
+ tcfg: ModelTransformerConfig,
520
+ use_positional_encoding: bool = True,
521
+ ):
522
+ super().__init__()
523
+ self.vec2tok = tokenizer
524
+
525
+ if use_positional_encoding and tcfg.max_tokens is None:
526
+ tcfg.max_tokens = int(self.vec2tok.n_tokens + (1 if self.vec2tok.add_cls_token else 0))
527
+
528
+ self.encoder = TransformerEncoder(
529
+ cfg=tcfg,
530
+ d_in=int(self.vec2tok.d_in),
531
+ d_out=2 * int(latent_dim),
532
+ use_positional_encoding=bool(use_positional_encoding),
533
+ )
534
+
535
+ def forward(
536
+ self,
537
+ x: torch.Tensor,
538
+ *,
539
+ return_attn: bool = False,
540
+ attn_average_heads: bool = True,
541
+ return_token_meta: bool = False,
542
+ attn_bias: Optional[torch.Tensor] = None,
543
+ ):
544
+ if return_token_meta:
545
+ tokens, key_padding_mask, meta = self.vec2tok(x, return_indices=True)
546
+ else:
547
+ tokens, key_padding_mask = self.vec2tok(x, return_indices=False)
548
+ meta = None
549
+
550
+ if return_attn:
551
+ h, attn_all = self.encoder(
552
+ tokens,
553
+ key_padding_mask=key_padding_mask,
554
+ attn_bias=attn_bias,
555
+ return_attn=True,
556
+ attn_average_heads=attn_average_heads,
557
+ )
558
+ else:
559
+ h = self.encoder(
560
+ tokens,
561
+ key_padding_mask=key_padding_mask,
562
+ attn_bias=attn_bias,
563
+ return_attn=False,
564
+ )
565
+ attn_all = None
566
+
567
+ mu, logvar = torch.chunk(h, 2, dim=-1)
568
+
569
+ if return_attn and return_token_meta:
570
+ return mu, logvar, attn_all, meta
571
+ if return_attn:
572
+ return mu, logvar, attn_all
573
+ if return_token_meta:
574
+ return mu, logvar, meta
575
+ return mu, logvar
576
+
577
+
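The sketch below wires the pieces together for an ATAC-style modality: attach peak coordinates, build a genomic-distance attention bias, and pass it through TransformerGaussianEncoder. ModelTransformerConfig is this module's alias for univi.models.transformer.TransformerConfig and is constructed with the same keyword arguments _cfg_to_model_tcfg uses; it is assumed here that TransformerEncoder accepts the additive (B, T, T) attn_bias it is handed and pools to (B, d_out), and that pooling="cls" and activation="gelu" are supported values. The peak count, coordinates, and sizes are made up for illustration.

from types import SimpleNamespace
import torch

F_peaks = 5_000
tok = _VectorToTokens(
    input_dim=F_peaks,
    tok=SimpleNamespace(mode="topk_scalar", n_tokens=128, add_cls_token=True,
                        use_coord_embedding=True, coord_mode="midpoint", n_chroms=24),
)
chrom_ids = torch.randint(0, 24, (F_peaks,))
start = torch.randint(0, 100_000_000, (F_peaks,)).float()
tok.set_feature_coords(chrom_ids, start, start + 500.0)   # toy 500 bp peaks

tcfg = ModelTransformerConfig(d_model=128, num_heads=4, num_layers=2, dim_feedforward=256,
                              dropout=0.1, attn_dropout=0.1, activation="gelu",
                              pooling="cls", max_tokens=None)
enc = TransformerGaussianEncoder(input_dim=F_peaks, latent_dim=32, tokenizer=tok, tcfg=tcfg)

x = torch.rand(8, F_peaks)
_, _, meta = tok(x, return_indices=True)   # topk is deterministic, so these indices match the ones recomputed inside enc.forward
bias = tok.build_distance_attn_bias(meta["topk_idx"], lengthscale_bp=50_000.0, include_cls=True)
mu, logvar, token_meta = enc(x, attn_bias=bias, return_token_meta=True)   # mu, logvar: (8, 32)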
578
+ # =============================================================================
579
+ # Multimodal concatenated-token transformer Gaussian encoder (fused)
580
+ # =============================================================================
581
+
582
+ def _tokcfg_without_cls(tok_cfg_in: TokenizerConfig) -> TokenizerConfig:
583
+ # preserve all fields, only override add_cls_token=False
584
+ return replace(tok_cfg_in, add_cls_token=False)
585
+
586
+
587
+ class MultiModalTransformerGaussianEncoder(nn.Module):
588
+ """
589
+ Fused encoder over multiple modalities by concatenating tokens.
590
+
591
+ Produces ONE fused posterior q(z | x_all). It does not replace per-modality q(z|x_m).
592
+
593
+ Clean coord hook:
594
+ fused_encoder.set_feature_coords("atac", chrom_ids, start, end)
595
+
596
+ Attention bias:
597
+ You can pass attn_bias=... into forward, or pass attn_bias_fn(meta)->bias.
598
+ """
599
+ def __init__(
600
+ self,
601
+ *,
602
+ modalities: Sequence[str],
603
+ input_dims: Dict[str, int],
604
+ tokenizers: Dict[str, TokenizerConfig],
605
+ transformer_cfg: CFGTransformerConfig,
606
+ latent_dim: int,
607
+ add_modality_embeddings: bool = True,
608
+ use_positional_encoding: bool = True,
609
+ ):
610
+ super().__init__()
611
+
612
+ self.modalities = list(modalities)
613
+ self.latent_dim = int(latent_dim)
614
+
615
+ tcfg_model = _cfg_to_model_tcfg(transformer_cfg)
616
+ d_model = int(tcfg_model.d_model)
617
+
618
+ self.vec2tok = nn.ModuleDict()
619
+ self.proj = nn.ModuleDict()
620
+ self.mod_emb = nn.ParameterDict() if add_modality_embeddings else None
621
+
622
+ total_tokens = 0
623
+ for m in self.modalities:
624
+ tok_cfg_in = tokenizers[m]
625
+ tok_cfg = _tokcfg_without_cls(tok_cfg_in) # force per-modality CLS off (one global CLS optional)
626
+
627
+ tok = _VectorToTokens(input_dim=int(input_dims[m]), tok=tok_cfg)
628
+ self.vec2tok[m] = tok
629
+ self.proj[m] = nn.Linear(int(tok.d_in), d_model, bias=True)
630
+
631
+ if self.mod_emb is not None:
632
+ self.mod_emb[m] = nn.Parameter(torch.zeros(1, 1, d_model))
633
+ nn.init.normal_(self.mod_emb[m], mean=0.0, std=0.02)
634
+
635
+ total_tokens += int(tok.n_tokens)
636
+
637
+ self.pooling = str(tcfg_model.pooling).lower().strip()
638
+ self.use_global_cls = (self.pooling == "cls")
639
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) if self.use_global_cls else None
640
+ if self.cls_token is not None:
641
+ nn.init.normal_(self.cls_token, mean=0.0, std=0.02)
642
+
643
+ if use_positional_encoding and tcfg_model.max_tokens is None:
644
+ tcfg_model.max_tokens = int(total_tokens + (1 if self.use_global_cls else 0))
645
+
646
+ self.encoder = TransformerEncoder(
647
+ cfg=tcfg_model,
648
+ d_in=d_model,
649
+ d_out=2 * int(latent_dim),
650
+ use_positional_encoding=bool(use_positional_encoding),
651
+ )
652
+
653
+ def set_feature_coords(self, modality: str, chrom_ids: torch.Tensor, start: torch.Tensor, end: torch.Tensor) -> None:
654
+ modality = str(modality)
655
+ if modality not in self.vec2tok:
656
+ raise KeyError(f"Unknown modality {modality!r}. Known: {list(self.vec2tok.keys())}")
657
+ self.vec2tok[modality].set_feature_coords(chrom_ids, start, end)
658
+
659
+ def forward(
660
+ self,
661
+ x_dict: Dict[str, torch.Tensor],
662
+ *,
663
+ return_token_meta: bool = False,
664
+ return_attn: bool = False,
665
+ attn_average_heads: bool = True,
666
+ attn_bias: Optional[torch.Tensor] = None,
667
+ attn_bias_fn: Optional[Any] = None,
668
+ ):
669
+ tokens_list: List[torch.Tensor] = []
670
+ masks_list: List[Optional[torch.Tensor]] = []
671
+
672
+ meta: Dict[str, Any] = {
673
+ "modalities": self.modalities,
674
+ "slices": {}, # modality -> (start, end) in CONCAT TOKEN SPACE (excluding global CLS)
675
+ "has_global_cls": bool(self.use_global_cls),
676
+ "cls_index": 0 if self.use_global_cls else None,
677
+ }
678
+ t_cursor = 0
679
+
680
+ for m in self.modalities:
681
+ x = x_dict[m]
682
+
683
+ if return_token_meta:
684
+ tok, mask, mmeta = self.vec2tok[m](x, return_indices=True)
685
+ meta[m] = mmeta
686
+ else:
687
+ tok, mask = self.vec2tok[m](x, return_indices=False)
688
+
689
+ tok = self.proj[m](tok) # (B,T,d_model)
690
+ if self.mod_emb is not None:
691
+ tok = tok + self.mod_emb[m]
692
+
693
+ Tm = tok.shape[1]
694
+ meta["slices"][m] = (t_cursor, t_cursor + Tm)
695
+ t_cursor += Tm
696
+
697
+ tokens_list.append(tok)
698
+ masks_list.append(mask)
699
+
700
+ tokens = torch.cat(tokens_list, dim=1) # (B, T_total, d_model)
701
+
702
+ key_padding_mask: Optional[torch.Tensor] = None
703
+ if any(m is not None for m in masks_list):
704
+ B = tokens.shape[0]
705
+ built: List[torch.Tensor] = []
706
+ for i, mname in enumerate(self.modalities):
707
+ mask = masks_list[i]
708
+ if mask is None:
709
+ Tm = tokens_list[i].shape[1]
710
+ built.append(torch.zeros((B, Tm), device=tokens.device, dtype=torch.bool))
711
+ else:
712
+ built.append(mask.to(dtype=torch.bool))
713
+ key_padding_mask = torch.cat(built, dim=1)
714
+
715
+ # Prepend ONE global CLS if pooling="cls"
716
+ if self.use_global_cls:
717
+ assert self.cls_token is not None
718
+ cls = self.cls_token.expand(tokens.shape[0], -1, -1) # (B,1,D)
719
+ tokens = torch.cat([cls, tokens], dim=1)
720
+
721
+ meta["slices_with_cls"] = {k: (a + 1, b + 1) for k, (a, b) in meta["slices"].items()}
722
+
723
+ if key_padding_mask is not None:
724
+ cls_mask = torch.zeros((tokens.shape[0], 1), device=tokens.device, dtype=torch.bool)
725
+ key_padding_mask = torch.cat([cls_mask, key_padding_mask], dim=1)
726
+ else:
727
+ meta["slices_with_cls"] = dict(meta["slices"])
728
+
729
+ # Optionally let user build bias from meta
730
+ if attn_bias is None and attn_bias_fn is not None:
731
+ attn_bias = attn_bias_fn(meta)
732
+
733
+ # Run transformer (+ optional attn collection)
734
+ if return_attn:
735
+ h, attn_all = self.encoder(
736
+ tokens,
737
+ key_padding_mask=key_padding_mask,
738
+ attn_bias=attn_bias,
739
+ return_attn=True,
740
+ attn_average_heads=attn_average_heads,
741
+ )
742
+ else:
743
+ h = self.encoder(
744
+ tokens,
745
+ key_padding_mask=key_padding_mask,
746
+ attn_bias=attn_bias,
747
+ return_attn=False,
748
+ )
749
+ attn_all = None
750
+
751
+ mu, logvar = torch.chunk(h, 2, dim=-1)
752
+
753
+ if return_attn and return_token_meta:
754
+ return mu, logvar, attn_all, meta
755
+ if return_attn:
756
+ return mu, logvar, attn_all
757
+ if return_token_meta:
758
+ return mu, logvar, meta
759
+ return mu, logvar
760
+
761
+
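A direct-construction sketch for the fused encoder. Because _tokcfg_without_cls calls dataclasses.replace, the tokenizer stand-in below is a small dataclass rather than a namespace; it is hypothetical and only mirrors the fields this module reads from TokenizerConfig. The transformer settings go through a SimpleNamespace exposing the fields _cfg_to_model_tcfg reads. All numbers are illustrative, and the output shapes assume TransformerEncoder pools to (B, d_out).

from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional, Tuple
import torch

@dataclass
class TokCfg:  # hypothetical stand-in for univi.config.TokenizerConfig (must be a dataclass for replace())
    mode: str
    n_tokens: int = 0
    channels: Tuple[str, ...] = ()
    patch_size: int = 0
    patch_proj_dim: Optional[int] = None
    add_cls_token: bool = False

tcfg = SimpleNamespace(d_model=128, num_heads=4, num_layers=2, dim_feedforward=256,
                       dropout=0.1, attn_dropout=0.1, activation="gelu",
                       pooling="cls", max_tokens=None)

fused = MultiModalTransformerGaussianEncoder(
    modalities=["rna", "atac"],
    input_dims={"rna": 2000, "atac": 10_000},
    tokenizers={"rna": TokCfg(mode="topk_scalar", n_tokens=64),
                "atac": TokCfg(mode="patch", patch_size=256)},
    transformer_cfg=tcfg,
    latent_dim=32,
)

x = {"rna": torch.rand(8, 2000), "atac": torch.rand(8, 10_000)}
mu, logvar, meta = fused(x, return_token_meta=True)   # mu, logvar: (8, 32)
# meta["slices_with_cls"] -> {"rna": (1, 65), "atac": (65, 105)}: token ranges after the global CLS at index 0.
# An attn_bias_fn(meta) callable could use these slices to assemble a (B, T, T) additive bias before encoding.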
762
+ # =============================================================================
763
+ # Factories
764
+ # =============================================================================
765
+
766
+ def build_gaussian_encoder(*, uni_cfg: UniVIConfig, mod_cfg: ModalityConfig) -> GaussianEncoder:
767
+ """
768
+ Factory for per-modality Gaussian encoders.
769
+
770
+ Supported mod_cfg.encoder_type:
771
+ - "mlp" (default)
772
+ - "transformer"
773
+ """
774
+ kind = (mod_cfg.encoder_type or "mlp").lower().strip()
775
+
776
+ if kind == "mlp":
777
+ return MLPGaussianEncoder(
778
+ EncoderConfig(
779
+ input_dim=int(mod_cfg.input_dim),
780
+ hidden_dims=list(mod_cfg.encoder_hidden),
781
+ latent_dim=int(uni_cfg.latent_dim),
782
+ dropout=float(uni_cfg.encoder_dropout),
783
+ batchnorm=bool(uni_cfg.encoder_batchnorm),
784
+ )
785
+ )
786
+
787
+ if kind == "transformer":
788
+ if mod_cfg.transformer is None:
789
+ raise ValueError(f"Modality {mod_cfg.name!r}: encoder_type='transformer' requires mod_cfg.transformer.")
790
+ if mod_cfg.tokenizer is None:
791
+ raise ValueError(f"Modality {mod_cfg.name!r}: encoder_type='transformer' requires mod_cfg.tokenizer.")
792
+
793
+ tokenizer = _VectorToTokens(
794
+ input_dim=int(mod_cfg.input_dim),
795
+ tok=mod_cfg.tokenizer,
796
+ )
797
+
798
+ tcfg = _cfg_to_model_tcfg(mod_cfg.transformer)
799
+ if tcfg.max_tokens is None:
800
+ tcfg.max_tokens = int(tokenizer.n_tokens + (1 if tokenizer.add_cls_token else 0))
801
+
802
+ return TransformerGaussianEncoder(
803
+ input_dim=int(mod_cfg.input_dim),
804
+ latent_dim=int(uni_cfg.latent_dim),
805
+ tokenizer=tokenizer,
806
+ tcfg=tcfg,
807
+ use_positional_encoding=True,
808
+ )
809
+
810
+ raise ValueError(f"Unknown encoder_type={kind!r} for modality {mod_cfg.name!r}")
811
+
812
+
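A hedged sketch of the factory dispatch for the default MLP path. The SimpleNamespace stand-ins expose only the attributes build_gaussian_encoder reads; the real UniVIConfig and ModalityConfig dataclasses in univi.config may require more fields.

from types import SimpleNamespace
import torch

uni_cfg = SimpleNamespace(latent_dim=32, encoder_dropout=0.1, encoder_batchnorm=True)
mod_cfg = SimpleNamespace(name="rna", input_dim=2000, encoder_type="mlp",
                          encoder_hidden=[512, 256], transformer=None, tokenizer=None)

enc = build_gaussian_encoder(uni_cfg=uni_cfg, mod_cfg=mod_cfg)   # -> MLPGaussianEncoder
mu, logvar = enc(torch.rand(16, 2000))                           # each (16, 32)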
813
+ def build_multimodal_transformer_encoder(
814
+ *,
815
+ uni_cfg: UniVIConfig,
816
+ modalities: Sequence[ModalityConfig],
817
+ fused_modalities: Optional[Sequence[str]] = None,
818
+ ) -> MultiModalTransformerGaussianEncoder:
819
+ """
820
+ Build the fused multimodal transformer encoder from existing per-modality configs.
821
+
822
+ Requirements:
823
+ - uni_cfg.fused_transformer is set
824
+ - each fused modality has mod_cfg.tokenizer set (even if its per-modality encoder is MLP)
825
+ """
826
+ if uni_cfg.fused_transformer is None:
827
+ raise ValueError("UniVIConfig.fused_transformer must be set for fused_encoder_type='multimodal_transformer'.")
828
+
829
+ mods = {m.name: m for m in modalities}
830
+ use_names = list(fused_modalities) if fused_modalities is not None else list(mods.keys())
831
+
832
+ input_dims = {n: int(mods[n].input_dim) for n in use_names}
833
+ tokenizers: Dict[str, TokenizerConfig] = {}
834
+ for n in use_names:
835
+ if mods[n].tokenizer is None:
836
+ raise ValueError(f"Fused multimodal encoder requires tokenizer for modality {n!r}")
837
+ tokenizers[n] = mods[n].tokenizer
838
+
839
+ return MultiModalTransformerGaussianEncoder(
840
+ modalities=use_names,
841
+ input_dims=input_dims,
842
+ tokenizers=tokenizers,
843
+ transformer_cfg=uni_cfg.fused_transformer,
844
+ latent_dim=int(uni_cfg.latent_dim),
845
+ add_modality_embeddings=bool(getattr(uni_cfg, "fused_add_modality_embeddings", True)),
846
+ use_positional_encoding=True,
847
+ )
848
+
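The fused factory covers the same ground starting from per-modality configs. The sketch below reuses the TokCfg dataclass stand-in defined after MultiModalTransformerGaussianEncoder above, plus namespaces exposing only the attributes read here (name, input_dim, tokenizer, latent_dim, fused_transformer); fused_add_modality_embeddings falls back to True via getattr. All values are illustrative.

from types import SimpleNamespace

uni_cfg = SimpleNamespace(
    latent_dim=32,
    fused_transformer=SimpleNamespace(d_model=128, num_heads=4, num_layers=2, dim_feedforward=256,
                                      dropout=0.1, attn_dropout=0.1, activation="gelu",
                                      pooling="cls", max_tokens=None),
)
rna = SimpleNamespace(name="rna", input_dim=2000, tokenizer=TokCfg(mode="topk_scalar", n_tokens=64))
atac = SimpleNamespace(name="atac", input_dim=10_000, tokenizer=TokCfg(mode="patch", patch_size=256))

fused = build_multimodal_transformer_encoder(uni_cfg=uni_cfg, modalities=[rna, atac])
# fused_modalities can restrict fusion to a subset, e.g. fused_modalities=["rna"]; a modality
# without a tokenizer raises a ValueError even if its own per-modality encoder is an MLP.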