univi-0.3.4-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
univi/models/univi.py ADDED
@@ -0,0 +1,1284 @@
1
+ # univi/models/univi.py
2
+ from __future__ import annotations
3
+
4
+ from typing import Dict, Tuple, Optional, Any, List, Union, Mapping, Callable
5
+ import math
6
+
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+
11
+ from ..config import UniVIConfig, ModalityConfig, ClassHeadConfig
12
+ from .mlp import build_mlp
13
+ from .decoders import DecoderConfig, build_decoder
14
+ from .encoders import build_gaussian_encoder, build_multimodal_transformer_encoder
15
+
16
+
17
+ YType = Union[torch.Tensor, Dict[str, torch.Tensor]]
18
+
19
+
20
+ class _GradReverseFn(torch.autograd.Function):
21
+ """Gradient reversal layer: forward identity, backward -lambda * grad."""
22
+ @staticmethod
23
+ def forward(ctx, x: torch.Tensor, lambd: float):
24
+ ctx.lambd = float(lambd)
25
+ return x.view_as(x)
26
+
27
+ @staticmethod
28
+ def backward(ctx, grad_output: torch.Tensor):
29
+ return -ctx.lambd * grad_output, None
30
+
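+ # Note on _GradReverseFn: this is the standard gradient-reversal trick used by the
+ # adversarial class heads further below. The forward pass is the identity; the
+ # backward pass scales incoming gradients by -lambda. Illustrative sketch (the
+ # names `head`, `z`, and `batch_labels` are placeholders, not part of this module):
+ #
+ #     z_rev = _GradReverseFn.apply(z, 1.0)              # values identical to z
+ #     loss = F.cross_entropy(head(z_rev), batch_labels)
+ #     loss.backward()                                   # encoder sees -1.0 * dloss/dz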
31
+
32
+ class UniVIMultiModalVAE(nn.Module):
33
+ """
34
+ Multi-modal β-VAE with per-modality encoders and decoders.
35
+
36
+ Defaults (backwards-compatible)
37
+ -------------------------------
38
+ - Per-modality encoders: MLP (or transformer if set on that modality)
39
+ - Fused posterior: MoE/PoE-style precision fusion over per-modality posteriors
40
+
41
+ Optional fused multimodal transformer posterior
42
+ ----------------------------------------------
43
+ If cfg.fused_encoder_type == "multimodal_transformer", build a fused encoder that:
44
+ x_dict -> concat(tokens_rna, tokens_adt, tokens_atac, ...) -> transformer -> (mu_fused, logvar_fused)
45
+
46
+ In that mode:
47
+ - v2/lite can use mu_fused/logvar_fused for z
48
+ - alignment term becomes mean_i ||mu_i - mu_fused||^2 (instead of pairwise)
49
+ - if required fused modalities are missing (cfg.fused_require_all_modalities=True), fall back to MoE fusion
50
+
51
+ Optional attention bias (safe no-op)
52
+ ------------------------------------
53
+ You can pass `attn_bias_cfg` into forward/encode_fused/predict_heads:
54
+
55
+ attn_bias_cfg = {
56
+ "atac": {"type": "distance", "lengthscale_bp": 50000, "same_chrom_only": True},
57
+ }
58
+
59
+ - Per-modality transformer encoders: will try to build a distance attention bias if the tokenizer
60
+ supports topk indices + coords and exposes build_distance_attn_bias().
61
+ - Fused multimodal transformer encoder: will fill within-modality blocks (e.g. ATAC slice) and
62
+ keep cross-modality blocks neutral (0).
63
+ """
64
+
65
+ LOGVAR_MIN = -10.0
66
+ LOGVAR_MAX = 10.0
67
+ EPS = 1e-8
68
+
69
+ def __init__(
70
+ self,
71
+ cfg: UniVIConfig,
72
+ *,
73
+ loss_mode: str = "v2",
74
+ v1_recon: str = "cross",
75
+ v1_recon_mix: float = 0.0,
76
+ normalize_v1_terms: bool = True,
77
+
78
+ # ---- legacy single label head (kept) ----
79
+ n_label_classes: int = 0,
80
+ label_loss_weight: float = 1.0,
81
+ use_label_encoder: bool = False,
82
+ label_moe_weight: float = 1.0,
83
+ unlabeled_logvar: float = 20.0,
84
+ label_encoder_warmup: int = 0,
85
+ label_ignore_index: int = -1,
86
+ classify_from_mu: bool = True,
87
+ label_head_name: Optional[str] = None,
88
+ ):
89
+ super().__init__()
90
+ self.cfg = cfg
91
+ self.cfg.validate()
92
+
93
+ self.loss_mode = str(loss_mode).lower().strip()
94
+ self.v1_recon = str(v1_recon).lower().strip()
95
+ self.v1_recon_mix = float(v1_recon_mix)
96
+ self.normalize_v1_terms = bool(normalize_v1_terms)
97
+
98
+ self.latent_dim = int(cfg.latent_dim)
99
+ self.beta_max = float(cfg.beta)
100
+ self.gamma_max = float(cfg.gamma)
101
+
102
+ self.modality_names: List[str] = [m.name for m in cfg.modalities]
103
+ self.mod_cfg_by_name: Dict[str, ModalityConfig] = {m.name: m for m in cfg.modalities}
104
+
105
+ # ------------------------------------------------------------
106
+ # Per-modality modules
107
+ # ------------------------------------------------------------
108
+ self.encoders = nn.ModuleDict()
109
+ self.encoder_heads = nn.ModuleDict()
110
+ self.decoders = nn.ModuleDict()
111
+
112
+ for m in cfg.modalities:
113
+ if not isinstance(m, ModalityConfig):
114
+ raise TypeError(f"cfg.modalities must contain ModalityConfig, got {type(m)}")
115
+
116
+ self.encoders[m.name] = build_gaussian_encoder(uni_cfg=cfg, mod_cfg=m)
117
+ self.encoder_heads[m.name] = nn.Identity()
118
+
119
+ dec_hidden = list(m.decoder_hidden) if m.decoder_hidden else [max(64, self.latent_dim)]
120
+ dec_cfg = DecoderConfig(
121
+ output_dim=int(m.input_dim),
122
+ hidden_dims=dec_hidden,
123
+ dropout=float(cfg.decoder_dropout),
124
+ batchnorm=bool(cfg.decoder_batchnorm),
125
+ )
126
+
127
+ lk = (m.likelihood or "gaussian").lower().strip()
128
+ decoder_kwargs: Dict[str, Any] = {}
129
+
130
+ if lk in ("nb", "negative_binomial", "zinb", "zero_inflated_negative_binomial"):
131
+ dispersion = getattr(m, "dispersion", "gene")
132
+ init_log_theta = float(getattr(m, "init_log_theta", 0.0))
133
+ decoder_kwargs = dict(
134
+ dispersion=dispersion,
135
+ init_log_theta=init_log_theta,
136
+ eps=self.EPS,
137
+ )
138
+
139
+ self.decoders[m.name] = build_decoder(
140
+ lk,
141
+ cfg=dec_cfg,
142
+ latent_dim=self.latent_dim,
143
+ **decoder_kwargs,
144
+ )
145
+
146
+ # Shared prior N(0, I)
147
+ self.register_buffer("prior_mu", torch.zeros(self.latent_dim))
148
+ self.register_buffer("prior_logvar", torch.zeros(self.latent_dim))
149
+
150
+ # ------------------------------------------------------------
151
+ # Optional fused multimodal encoder
152
+ # ------------------------------------------------------------
153
+ self.fused_encoder_type = (getattr(cfg, "fused_encoder_type", "moe") or "moe").lower().strip()
154
+ self.fused_require_all = bool(getattr(cfg, "fused_require_all_modalities", True))
155
+ self.fused_modalities = list(getattr(cfg, "fused_modalities", None) or self.modality_names)
156
+
157
+ if self.fused_encoder_type == "multimodal_transformer":
158
+ self.fused_encoder = build_multimodal_transformer_encoder(
159
+ uni_cfg=cfg,
160
+ modalities=cfg.modalities,
161
+ fused_modalities=self.fused_modalities,
162
+ )
163
+ else:
164
+ self.fused_encoder = None
165
+
166
+ # ------------------------------------------------------------
167
+ # Legacy single label head (kept)
168
+ # ------------------------------------------------------------
169
+ self.n_label_classes = int(n_label_classes) if n_label_classes is not None else 0
170
+ self.label_loss_weight = float(label_loss_weight)
171
+ self.label_ignore_index = int(label_ignore_index)
172
+ self.classify_from_mu = bool(classify_from_mu)
173
+ self.label_head_name = str(label_head_name or getattr(cfg, "label_head_name", "label"))
174
+
175
+ if self.n_label_classes > 0:
176
+ dec_cfg = DecoderConfig(
177
+ output_dim=self.n_label_classes,
178
+ hidden_dims=[max(64, self.latent_dim)],
179
+ dropout=float(cfg.decoder_dropout),
180
+ batchnorm=bool(cfg.decoder_batchnorm),
181
+ )
182
+ self.label_decoder = build_decoder("categorical", cfg=dec_cfg, latent_dim=self.latent_dim)
183
+ else:
184
+ self.label_decoder = None
185
+
186
+ self.label_names: Optional[List[str]] = None
187
+ self.label_name_to_id: Optional[Dict[str, int]] = None
188
+
189
+ # Legacy label encoder expert (optional)
190
+ self.use_label_encoder = bool(use_label_encoder)
191
+ self.label_moe_weight = float(label_moe_weight)
192
+ self.unlabeled_logvar = float(unlabeled_logvar)
193
+ self.label_encoder_warmup = int(label_encoder_warmup)
194
+
195
+ if self.n_label_classes > 0 and self.use_label_encoder:
196
+ self.label_encoder = build_mlp(
197
+ in_dim=self.n_label_classes,
198
+ hidden_dims=[max(64, self.latent_dim)],
199
+ out_dim=self.latent_dim * 2,
200
+ activation=nn.ReLU(),
201
+ dropout=float(cfg.encoder_dropout),
202
+ batchnorm=bool(cfg.encoder_batchnorm),
203
+ )
204
+ else:
205
+ self.label_encoder = None
206
+
207
+ # Multi-head supervised decoders (incl. adversarial heads)
208
+ self.class_heads = nn.ModuleDict()
209
+ self.class_heads_cfg: Dict[str, Dict[str, Any]] = {}
210
+ self.head_label_names: Dict[str, List[str]] = {}
211
+ self.head_label_name_to_id: Dict[str, Dict[str, int]] = {}
212
+
213
+ if cfg.class_heads:
214
+ for h in cfg.class_heads:
215
+ if not isinstance(h, ClassHeadConfig):
216
+ raise TypeError(f"cfg.class_heads must contain ClassHeadConfig, got {type(h)}")
217
+
218
+ name = str(h.name)
219
+ n_classes = int(h.n_classes)
220
+
221
+ dec_cfg = DecoderConfig(
222
+ output_dim=n_classes,
223
+ hidden_dims=[max(64, self.latent_dim)],
224
+ dropout=float(cfg.decoder_dropout),
225
+ batchnorm=bool(cfg.decoder_batchnorm),
226
+ )
227
+ self.class_heads[name] = build_decoder("categorical", cfg=dec_cfg, latent_dim=self.latent_dim)
228
+ self.class_heads_cfg[name] = {
229
+ "n_classes": n_classes,
230
+ "loss_weight": float(h.loss_weight),
231
+ "ignore_index": int(h.ignore_index),
232
+ "from_mu": bool(h.from_mu),
233
+ "warmup": int(h.warmup),
234
+ "adversarial": bool(getattr(h, "adversarial", False)),
235
+ "adv_lambda": float(getattr(h, "adv_lambda", 1.0)),
236
+ }
237
+
238
+ # ----------------------------- label name utilities -----------------------------
239
+
240
+ def set_label_names(self, label_names: List[str]) -> None:
241
+ if self.n_label_classes <= 0:
242
+ raise ValueError("n_label_classes=0; cannot set label names.")
243
+ if len(label_names) != int(self.n_label_classes):
244
+ raise ValueError(f"label_names length {len(label_names)} != n_label_classes {self.n_label_classes}")
245
+
246
+ self.label_names = [str(x) for x in label_names]
247
+
248
+ def norm(s: str) -> str:
249
+ return " ".join(str(s).strip().lower().split())
250
+
251
+ m: Dict[str, int] = {}
252
+ for i, name in enumerate(self.label_names):
253
+ m[name] = i
254
+ m[norm(name)] = i
255
+ self.label_name_to_id = m
256
+
257
+ def set_head_label_names(self, head: str, label_names: List[str]) -> None:
258
+ head = str(head)
259
+ if head not in self.class_heads_cfg:
260
+ raise KeyError(f"Unknown head {head!r}. Known heads: {list(self.class_heads_cfg)}")
261
+
262
+ n = int(self.class_heads_cfg[head]["n_classes"])
263
+ if len(label_names) != n:
264
+ raise ValueError(f"Head {head!r}: label_names length {len(label_names)} != n_classes {n}")
265
+
266
+ names = [str(x) for x in label_names]
267
+ self.head_label_names[head] = names
268
+
269
+ def norm(s: str) -> str:
270
+ return " ".join(str(s).strip().lower().split())
271
+
272
+ m: Dict[str, int] = {}
273
+ for i, name in enumerate(names):
274
+ m[name] = i
275
+ m[norm(name)] = i
276
+ self.head_label_name_to_id[head] = m
277
+
278
+ # ----------------------------- helpers -----------------------------
279
+
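+ # _reparameterize applies the standard reparameterization trick,
+ # z = mu + sigma * eps with sigma = exp(0.5 * logvar) and eps ~ N(0, I),
+ # which keeps the sample differentiable with respect to mu and logvar.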
280
+ def _reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
281
+ std = torch.exp(0.5 * logvar)
282
+ eps = torch.randn_like(std)
283
+ return mu + eps * std
284
+
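+ # _kl_gaussian implements the closed-form KL divergence between diagonal Gaussians,
+ #   KL(q || p) = 0.5 * sum_d [ log(var_p / var_q) + (var_q + (mu_q - mu_p)^2) / var_p - 1 ],
+ # summed over latent dimensions to give one value per cell.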
285
+ def _kl_gaussian(self, mu_q: torch.Tensor, logvar_q: torch.Tensor, mu_p: torch.Tensor, logvar_p: torch.Tensor) -> torch.Tensor:
286
+ var_q = torch.exp(logvar_q)
287
+ var_p = torch.exp(logvar_p)
288
+ kl = logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0
289
+ return 0.5 * kl.sum(dim=-1)
290
+
291
+ def _alignment_loss_l2mu_pairwise(self, mu_per_mod: Dict[str, torch.Tensor]) -> torch.Tensor:
292
+ names = list(mu_per_mod.keys())
293
+ if len(names) < 2:
294
+ if len(names) == 0:
295
+ return torch.tensor(0.0, device=next(self.parameters()).device).expand(1)
296
+ k = names[0]
297
+ return torch.zeros(mu_per_mod[k].size(0), device=mu_per_mod[k].device)
298
+
299
+ losses = []
300
+ for i in range(len(names)):
301
+ for j in range(i + 1, len(names)):
302
+ losses.append(((mu_per_mod[names[i]] - mu_per_mod[names[j]]) ** 2).sum(dim=-1))
303
+ return torch.stack(losses, dim=0).mean(dim=0)
304
+
305
+ def _alignment_loss_to_fused(self, mu_per_mod: Dict[str, torch.Tensor], mu_fused: torch.Tensor) -> torch.Tensor:
306
+ if len(mu_per_mod) == 0:
307
+ return torch.zeros(mu_fused.size(0), device=mu_fused.device)
308
+ losses = [((mu - mu_fused) ** 2).sum(dim=-1) for mu in mu_per_mod.values()]
309
+ return torch.stack(losses, dim=0).mean(dim=0)
310
+
311
+ @staticmethod
312
+ def _is_categorical_likelihood(likelihood: Optional[str]) -> bool:
313
+ lk = (likelihood or "").lower().strip()
314
+ return lk in ("categorical", "cat", "ce", "cross_entropy", "multinomial", "softmax")
315
+
316
+ def _categorical_targets_and_mask(
317
+ self,
318
+ x: torch.Tensor,
319
+ *,
320
+ n_classes: int,
321
+ ignore_index: int,
322
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
323
+ if x.dim() == 2 and x.size(1) == 1:
324
+ x = x[:, 0]
325
+
326
+ if x.dim() == 2:
327
+ mask = x.sum(dim=-1) > 0.5
328
+ y = x.argmax(dim=-1).long()
329
+ return y, mask
330
+
331
+ y = x.view(-1)
332
+ if y.dtype.is_floating_point:
333
+ y = y.round()
334
+ y = y.long()
335
+
336
+ mask = (y != int(ignore_index))
337
+ return y, mask
338
+
339
+ def _encode_categorical_input_if_needed(self, mod_name: str, x: torch.Tensor) -> torch.Tensor:
340
+ m_cfg = self.mod_cfg_by_name[mod_name]
341
+ if not self._is_categorical_likelihood(m_cfg.likelihood):
342
+ return x
343
+
344
+ C = int(m_cfg.input_dim)
345
+ ignore_index = int(getattr(m_cfg, "ignore_index", self.label_ignore_index))
346
+
347
+ if x.dim() == 2 and x.size(1) == 1:
348
+ x = x[:, 0]
349
+ if x.dim() == 2:
350
+ return x.float()
351
+
352
+ y, mask = self._categorical_targets_and_mask(x, n_classes=C, ignore_index=ignore_index)
353
+ B = y.shape[0]
354
+ x_oh = torch.zeros((B, C), device=y.device, dtype=torch.float32)
355
+ if mask.any():
356
+ x_oh[mask] = F.one_hot(y[mask], num_classes=C).float()
357
+ return x_oh
358
+
359
+ # -------------------------- attention-bias helpers (optional, safe no-op) --------------------------
360
+
361
+ def _build_distance_bias_for_permod_transformer(
362
+ self,
363
+ enc: nn.Module,
364
+ x_in: torch.Tensor,
365
+ cfg_m: Mapping[str, Any],
366
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
367
+ """
368
+ Attempt to build (tokens, key_padding_mask, attn_bias) for a TransformerGaussianEncoder-like module.
369
+
370
+ Returns (None, None, None) if:
371
+ - encoder isn't transformer-style,
372
+ - tokenizer doesn't expose topk indices,
373
+ - coords aren't attached,
374
+ - config isn't distance type, etc.
375
+ """
376
+ typ = str(cfg_m.get("type", "")).lower().strip()
377
+ if typ != "distance":
378
+ return None, None, None
379
+
380
+ vec2tok = getattr(enc, "vec2tok", None)
381
+ core = getattr(enc, "encoder", None)
382
+ if vec2tok is None or core is None:
383
+ return None, None, None
384
+
385
+ if not hasattr(vec2tok, "build_distance_attn_bias"):
386
+ return None, None, None
387
+
388
+ try:
389
+ tokens, key_padding_mask, meta = vec2tok(x_in, return_indices=True)
390
+ except Exception:
391
+ return None, None, None
392
+
393
+ topk_idx = None if meta is None else meta.get("topk_idx", None)
394
+ if topk_idx is None:
395
+ return None, None, None
396
+
397
+ lengthscale_bp = float(cfg_m.get("lengthscale_bp", 50_000.0))
398
+ same_chrom_only = bool(cfg_m.get("same_chrom_only", True))
399
+ include_cls = bool(getattr(vec2tok, "add_cls_token", False))
400
+
401
+ try:
402
+ attn_bias = vec2tok.build_distance_attn_bias(
403
+ topk_idx,
404
+ lengthscale_bp=lengthscale_bp,
405
+ same_chrom_only=same_chrom_only,
406
+ include_cls=include_cls,
407
+ )
408
+ except Exception:
409
+ return None, None, None
410
+
411
+ return tokens, key_padding_mask, attn_bias
412
+
413
+ def _build_fused_attn_bias_fn(
414
+ self,
415
+ attn_bias_cfg: Mapping[str, Any],
416
+ sub_x_dict: Dict[str, torch.Tensor],
417
+ ) -> Optional[Callable[[Dict[str, Any]], Optional[torch.Tensor]]]:
418
+ """
419
+ Build a callable(meta)->(B,T,T) for MultiModalTransformerGaussianEncoder.
420
+
421
+ Behavior:
422
+ - fills ONLY within-modality blocks for modalities configured as distance bias
423
+ - keeps cross-modality logits neutral (0)
424
+ """
425
+ fused = self.fused_encoder
426
+ if fused is None:
427
+ return None
428
+
429
+ vec2tok_map = getattr(fused, "vec2tok", None)
430
+ if vec2tok_map is None:
431
+ return None
432
+
433
+ # only bother if any modality is configured for distance bias
434
+ want_any = any(str(v.get("type", "")).lower().strip() == "distance" for v in attn_bias_cfg.values() if isinstance(v, Mapping))
435
+ if not want_any:
436
+ return None
437
+
438
+ def fn(meta: Dict[str, Any]) -> Optional[torch.Tensor]:
439
+ if not meta:
440
+ return None
441
+
442
+ # batch size
443
+ any_x = next(iter(sub_x_dict.values()))
444
+ B = int(any_x.shape[0])
445
+
446
+ slices = meta.get("slices_with_cls", meta.get("slices", {}))
447
+ if not slices:
448
+ return None
449
+
450
+ # total token length in fused space (already includes global CLS if present)
451
+ T = 0
452
+ for _, (a, b) in slices.items():
453
+ T = max(T, int(b))
454
+ if bool(meta.get("has_global_cls", False)):
455
+ T = max(T, 1)
456
+
457
+ bias_full = torch.zeros((B, T, T), device=any_x.device, dtype=torch.float32)
458
+
459
+ for m, cfg_m in attn_bias_cfg.items():
460
+ if not isinstance(cfg_m, Mapping):
461
+ continue
462
+ if str(cfg_m.get("type", "")).lower().strip() != "distance":
463
+ continue
464
+ if m not in slices:
465
+ continue
466
+ if m not in vec2tok_map:
467
+ continue
468
+
469
+ tok = vec2tok_map[m]
470
+ if not hasattr(tok, "build_distance_attn_bias"):
471
+ continue
472
+
473
+ mmeta = meta.get(m, {}) or {}
474
+ topk_idx = mmeta.get("topk_idx", None)
475
+ if topk_idx is None:
476
+ continue
477
+
478
+ lengthscale_bp = float(cfg_m.get("lengthscale_bp", 50_000.0))
479
+ same_chrom_only = bool(cfg_m.get("same_chrom_only", True))
480
+
481
+ try:
482
+ local = tok.build_distance_attn_bias(
483
+ topk_idx,
484
+ lengthscale_bp=lengthscale_bp,
485
+ same_chrom_only=same_chrom_only,
486
+ include_cls=False, # per-modality cls is forced off in fused encoder
487
+ ) # (B, K, K)
488
+ except Exception:
489
+ continue
490
+
491
+ a, b = slices[m]
492
+ a = int(a); b = int(b)
493
+ if (b - a) != int(local.shape[1]):
494
+ continue
495
+
496
+ bias_full[:, a:b, a:b] = local
497
+
498
+ return bias_full
499
+
500
+ return fn
501
+
502
+ # -------------------------- encode/decode/fuse --------------------------
503
+
504
+ def encode_modalities(
505
+ self,
506
+ x_dict: Dict[str, torch.Tensor],
507
+ *,
508
+ attn_bias_cfg: Optional[Mapping[str, Any]] = None,
509
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
510
+ mu_dict: Dict[str, torch.Tensor] = {}
511
+ logvar_dict: Dict[str, torch.Tensor] = {}
512
+
513
+ for m in self.modality_names:
514
+ if m not in x_dict or x_dict[m] is None:
515
+ continue
516
+
517
+ x_in = self._encode_categorical_input_if_needed(m, x_dict[m])
518
+ enc = self.encoders[m]
519
+
520
+ # optional: transformer distance bias (single-pass through core encoder)
521
+ tokens = key_padding_mask = attn_bias = None
522
+ if attn_bias_cfg is not None and isinstance(attn_bias_cfg, Mapping):
523
+ cfg_m = attn_bias_cfg.get(m, None)
524
+ if isinstance(cfg_m, Mapping):
525
+ tokens, key_padding_mask, attn_bias = self._build_distance_bias_for_permod_transformer(enc, x_in, cfg_m)
526
+
527
+ if tokens is not None and getattr(enc, "encoder", None) is not None:
528
+ # enc is TransformerGaussianEncoder-like; run core directly
529
+ h = enc.encoder(
530
+ tokens,
531
+ key_padding_mask=key_padding_mask,
532
+ attn_bias=attn_bias,
533
+ return_attn=False,
534
+ )
535
+ mu, logvar = torch.chunk(h, 2, dim=-1)
536
+ else:
537
+ # fallback: standard encoder call
538
+ try:
539
+ mu, logvar = enc(x_in, attn_bias=attn_bias) # transformer supports attn_bias kw
540
+ except TypeError:
541
+ mu, logvar = enc(x_in)
542
+
543
+ mu = self.encoder_heads[m](mu)
544
+ logvar = torch.clamp(logvar, self.LOGVAR_MIN, self.LOGVAR_MAX)
545
+
546
+ mu_dict[m] = mu
547
+ logvar_dict[m] = logvar
548
+
549
+ return mu_dict, logvar_dict
550
+
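+ # mixture_of_experts fuses the per-modality posteriors by precision weighting
+ # (product-of-experts style): with precision tau_i = 1 / var_i,
+ #   var_comb = 1 / sum_i tau_i,   mu_comb = (sum_i tau_i * mu_i) / sum_i tau_i.
+ # Modalities absent from the input dicts simply drop out of the sums.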
551
+ def mixture_of_experts(self, mu_dict: Dict[str, torch.Tensor], logvar_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
552
+ mus = list(mu_dict.values())
553
+ logvars = list(logvar_dict.values())
554
+
555
+ precisions = [torch.exp(-lv) for lv in logvars]
556
+ precision_sum = torch.stack(precisions, dim=0).sum(dim=0).clamp_min(self.EPS)
557
+
558
+ mu_weighted = torch.stack([m * p for m, p in zip(mus, precisions)], dim=0).sum(dim=0)
559
+ mu_comb = mu_weighted / precision_sum
560
+
561
+ var_comb = 1.0 / precision_sum
562
+ logvar_comb = torch.log(var_comb.clamp_min(self.EPS))
563
+ return mu_comb, logvar_comb
564
+
565
+ def decode_modalities(self, z: torch.Tensor) -> Dict[str, Any]:
566
+ return {m: self.decoders[m](z) for m in self.modality_names}
567
+
568
+ # -------------------- fused posterior --------------------
569
+
570
+ def _present_fused_modalities(self, x_dict: Dict[str, torch.Tensor]) -> List[str]:
571
+ return [m for m in self.fused_modalities if (m in x_dict) and (x_dict[m] is not None)]
572
+
573
+ def _can_use_fused_encoder(self, x_dict: Dict[str, torch.Tensor]) -> bool:
574
+ if self.fused_encoder is None:
575
+ return False
576
+ if not self.fused_modalities:
577
+ return False
578
+
579
+ present = self._present_fused_modalities(x_dict)
580
+ if self.fused_require_all:
581
+ return len(present) == len(self.fused_modalities)
582
+ return len(present) >= 1
583
+
584
+ def _compute_fused_posterior(
585
+ self,
586
+ x_dict: Dict[str, torch.Tensor],
587
+ *,
588
+ attn_bias_cfg: Optional[Mapping[str, Any]] = None,
589
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
590
+ assert self.fused_encoder is not None
591
+ present = self._present_fused_modalities(x_dict)
592
+ sub = {m: x_dict[m] for m in present}
593
+
594
+ attn_bias_fn = None
595
+ want_meta = False
596
+ if attn_bias_cfg is not None and isinstance(attn_bias_cfg, Mapping):
597
+ attn_bias_fn = self._build_fused_attn_bias_fn(attn_bias_cfg, sub)
598
+ want_meta = attn_bias_fn is not None
599
+
600
+ out = self.fused_encoder(sub, return_token_meta=want_meta, attn_bias_fn=attn_bias_fn)
601
+ if want_meta:
602
+ mu_f, logvar_f, _meta = out # type: ignore[misc]
603
+ else:
604
+ mu_f, logvar_f = out # type: ignore[misc]
605
+
606
+ logvar_f = torch.clamp(logvar_f, self.LOGVAR_MIN, self.LOGVAR_MAX)
607
+ return mu_f, logvar_f
608
+
609
+ # -------------------- label expert --------------------
610
+
611
+ def _encode_labels_as_expert(self, y: Optional[torch.Tensor], B: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
612
+ mu_y = torch.zeros(B, self.latent_dim, device=device)
613
+ logvar_y = torch.full((B, self.latent_dim), float(self.unlabeled_logvar), device=device)
614
+
615
+ if self.label_encoder is None or y is None:
616
+ return mu_y, logvar_y
617
+
618
+ y = y.long()
619
+ mask = (y >= 0) & (y != self.label_ignore_index)
620
+ if not mask.any():
621
+ return mu_y, logvar_y
622
+
623
+ y_oh = F.one_hot(y[mask], num_classes=self.n_label_classes).float()
624
+ h = self.label_encoder(y_oh)
625
+ mu_l, logvar_l = torch.chunk(h, 2, dim=-1)
626
+ logvar_l = torch.clamp(logvar_l, self.LOGVAR_MIN, self.LOGVAR_MAX)
627
+
628
+ w = float(self.label_moe_weight)
629
+ if w != 1.0:
630
+ logvar_l = logvar_l - math.log(max(w, 1e-8))
631
+
632
+ mu_y[mask] = mu_l
633
+ logvar_y[mask] = logvar_l
634
+ return mu_y, logvar_y
635
+
636
+ def _extract_legacy_y(self, y: Optional[YType]) -> Optional[torch.Tensor]:
637
+ if y is None:
638
+ return None
639
+ if isinstance(y, Mapping):
640
+ v = y.get(self.label_head_name, None)
641
+ return v.long() if v is not None else None
642
+ return y.long()
643
+
644
+ def _grad_reverse(self, x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
645
+ return _GradReverseFn.apply(x, float(lambd))
646
+
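+ # _apply_multihead_losses adds one masked cross-entropy term per configured class
+ # head. Heads flagged adversarial route z through the gradient-reversal layer with
+ # weight adv_lambda, so the head is trained to predict its label while the encoder
+ # is pushed to remove that information from the latent space.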
647
+ def _apply_multihead_losses(
648
+ self,
649
+ *,
650
+ mu_z: torch.Tensor,
651
+ z: torch.Tensor,
652
+ y: Optional[YType],
653
+ epoch: int,
654
+ loss: torch.Tensor,
655
+ loss_annealed: torch.Tensor,
656
+ loss_fixed: torch.Tensor,
657
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
658
+ head_logits: Dict[str, torch.Tensor] = {}
659
+ head_loss_means: Dict[str, torch.Tensor] = {}
660
+
661
+ if y is None or not isinstance(y, Mapping) or len(self.class_heads) == 0:
662
+ return loss, loss_annealed, loss_fixed, head_logits, head_loss_means
663
+
664
+ B = int(mu_z.size(0))
665
+ device = mu_z.device
666
+
667
+ for name, head in self.class_heads.items():
668
+ cfg_h = self.class_heads_cfg[name]
669
+ if int(epoch) < int(cfg_h["warmup"]):
670
+ continue
671
+
672
+ y_h = y.get(name, None)
673
+ if y_h is None:
674
+ continue
675
+
676
+ y_h = y_h.long()
677
+ z_in = mu_z if bool(cfg_h["from_mu"]) else z
678
+
679
+ if bool(cfg_h.get("adversarial", False)):
680
+ z_in = self._grad_reverse(z_in, cfg_h.get("adv_lambda", 1.0))
681
+
682
+ dec_out = head(z_in)
683
+ logits = dec_out["logits"] if isinstance(dec_out, dict) and "logits" in dec_out else dec_out
684
+ head_logits[name] = logits
685
+
686
+ ignore_index = int(cfg_h["ignore_index"])
687
+ mask = (y_h != ignore_index) & (y_h >= 0)
688
+ per_cell = torch.zeros(B, device=device)
689
+ if mask.any():
690
+ per_cell[mask] = F.cross_entropy(logits[mask], y_h[mask], reduction="none")
691
+
692
+ w = float(cfg_h["loss_weight"])
693
+ if w != 0.0:
694
+ loss = loss + w * per_cell
695
+ loss_annealed = loss_annealed + w * per_cell
696
+ loss_fixed = loss_fixed + w * per_cell
697
+
698
+ head_loss_means[name] = per_cell.mean()
699
+
700
+ return loss, loss_annealed, loss_fixed, head_logits, head_loss_means
701
+
702
+ # ------------------------------ forward dispatcher ------------------------------
703
+
704
+ def forward(
705
+ self,
706
+ x_dict: Dict[str, torch.Tensor],
707
+ epoch: int = 0,
708
+ y: Optional[YType] = None,
709
+ attn_bias_cfg: Optional[Mapping[str, Any]] = None,
710
+ **kwargs: Any, # keep older wrappers from breaking if they pass extra args
711
+ ) -> Dict[str, Any]:
712
+ mode = (self.loss_mode or "v2").lower()
713
+ if mode in ("v1", "paper", "cross"):
714
+ return self._forward_v1(x_dict=x_dict, epoch=epoch, y=y, attn_bias_cfg=attn_bias_cfg)
715
+ if mode in ("v2", "lite", "light", "moe", "poe", "fused"):
716
+ return self._forward_v2(x_dict=x_dict, epoch=epoch, y=y, attn_bias_cfg=attn_bias_cfg)
717
+ raise ValueError(f"Unknown loss_mode={self.loss_mode!r}.")
718
+
719
+ # ------------------------------ v2 / lite ------------------------------
720
+
721
+ def _forward_v2(
722
+ self,
723
+ x_dict: Dict[str, torch.Tensor],
724
+ epoch: int = 0,
725
+ y: Optional[YType] = None,
726
+ attn_bias_cfg: Optional[Mapping[str, Any]] = None,
727
+ ) -> Dict[str, Any]:
728
+ mu_dict, logvar_dict = self.encode_modalities(x_dict, attn_bias_cfg=attn_bias_cfg)
729
+ if len(mu_dict) == 0:
730
+ raise ValueError("At least one modality must be present in x_dict.")
731
+
732
+ use_fused = (self.fused_encoder_type == "multimodal_transformer") and self._can_use_fused_encoder(x_dict)
733
+ mu_fused: Optional[torch.Tensor] = None
734
+ logvar_fused: Optional[torch.Tensor] = None
735
+
736
+ if use_fused:
737
+ mu_fused, logvar_fused = self._compute_fused_posterior(x_dict, attn_bias_cfg=attn_bias_cfg)
738
+ mu_z, logvar_z = mu_fused, logvar_fused
739
+ else:
740
+ mu_z, logvar_z = self.mixture_of_experts(mu_dict, logvar_dict)
741
+
742
+ y_legacy = self._extract_legacy_y(y)
743
+ if self.label_encoder is not None and y_legacy is not None and epoch >= self.label_encoder_warmup:
744
+ B = mu_z.shape[0]
745
+ device = mu_z.device
746
+ mu_y, logvar_y = self._encode_labels_as_expert(y=y_legacy, B=B, device=device)
747
+
748
+ base_mu_dict = {"__base__": mu_z, "__label__": mu_y}
749
+ base_lv_dict = {"__base__": logvar_z, "__label__": logvar_y}
750
+ mu_z, logvar_z = self.mixture_of_experts(base_mu_dict, base_lv_dict)
751
+
752
+ z = self._reparameterize(mu_z, logvar_z)
753
+ xhat_dict = self.decode_modalities(z)
754
+
755
+ recon_total: Optional[torch.Tensor] = None
756
+ recon_losses: Dict[str, torch.Tensor] = {}
757
+
758
+ for name, m_cfg in self.mod_cfg_by_name.items():
759
+ if name not in x_dict or x_dict[name] is None:
760
+ continue
761
+ loss_m = self._recon_loss(
762
+ x=x_dict[name],
763
+ raw_dec_out=xhat_dict[name],
764
+ likelihood=m_cfg.likelihood,
765
+ mod_name=name,
766
+ )
767
+ recon_losses[name] = loss_m
768
+ recon_total = loss_m if recon_total is None else (recon_total + loss_m)
769
+
770
+ if recon_total is None:
771
+ raise RuntimeError("No modalities in x_dict produced recon loss.")
772
+
773
+ mu_p = self.prior_mu.expand_as(mu_z)
774
+ logvar_p = self.prior_logvar.expand_as(logvar_z)
775
+ kl = self._kl_gaussian(mu_z, logvar_z, mu_p, logvar_p)
776
+
777
+ real_mu = {k: v for k, v in mu_dict.items() if not k.startswith("__")}
778
+ if use_fused and (mu_fused is not None):
779
+ align_loss = self._alignment_loss_to_fused(real_mu, mu_fused)
780
+ else:
781
+ align_loss = self._alignment_loss_l2mu_pairwise(real_mu)
782
+
783
+ beta_t = self._anneal_weight(epoch, self.cfg.kl_anneal_start, self.cfg.kl_anneal_end, self.beta_max)
784
+ gamma_t = self._anneal_weight(epoch, self.cfg.align_anneal_start, self.cfg.align_anneal_end, self.gamma_max)
785
+
786
+ beta_used = beta_t if self.training else self.beta_max
787
+ gamma_used = gamma_t if self.training else self.gamma_max
788
+
789
+ loss_annealed = recon_total + beta_t * kl + gamma_t * align_loss
790
+ loss_fixed = recon_total + self.beta_max * kl + self.gamma_max * align_loss
791
+ loss = recon_total + beta_used * kl + gamma_used * align_loss
792
+
793
+ class_loss = None
794
+ class_logits = None
795
+ if self.label_decoder is not None:
796
+ z_for_cls = mu_z if self.classify_from_mu else z
797
+ dec_out = self.label_decoder(z_for_cls)
798
+ class_logits = dec_out["logits"] if isinstance(dec_out, dict) and "logits" in dec_out else dec_out
799
+
800
+ if y_legacy is not None:
801
+ B = loss.shape[0]
802
+ yy = y_legacy.long()
803
+ mask = (yy >= 0) & (yy != self.label_ignore_index)
804
+ per_cell = torch.zeros(B, device=loss.device)
805
+ if mask.any():
806
+ per_cell[mask] = F.cross_entropy(class_logits[mask], yy[mask], reduction="none")
807
+ class_loss = per_cell
808
+
809
+ loss = loss + self.label_loss_weight * class_loss
810
+ loss_annealed = loss_annealed + self.label_loss_weight * class_loss
811
+ loss_fixed = loss_fixed + self.label_loss_weight * class_loss
812
+
813
+ loss, loss_annealed, loss_fixed, head_logits, head_loss_means = self._apply_multihead_losses(
814
+ mu_z=mu_z,
815
+ z=z,
816
+ y=y,
817
+ epoch=epoch,
818
+ loss=loss,
819
+ loss_annealed=loss_annealed,
820
+ loss_fixed=loss_fixed,
821
+ )
822
+
823
+ out: Dict[str, Any] = {
824
+ "loss": loss.mean(),
825
+ "recon_total": recon_total.mean(),
826
+ "kl": kl.mean(),
827
+ "align": align_loss.mean(),
828
+ "mu_z": mu_z,
829
+ "logvar_z": logvar_z,
830
+ "z": z,
831
+ "xhat": xhat_dict,
832
+ "mu_dict": mu_dict,
833
+ "logvar_dict": logvar_dict,
834
+ "recon_per_modality": {k: v.mean() for k, v in recon_losses.items()},
835
+ "beta": torch.tensor(beta_t, device=loss.device),
836
+ "gamma": torch.tensor(gamma_t, device=loss.device),
837
+ "beta_used": torch.tensor(beta_used, device=loss.device),
838
+ "gamma_used": torch.tensor(gamma_used, device=loss.device),
839
+ "loss_annealed": loss_annealed.mean(),
840
+ "loss_fixed": loss_fixed.mean(),
841
+ "used_fused_encoder": torch.tensor(1.0 if use_fused else 0.0, device=loss.device),
842
+ }
843
+ if class_loss is not None:
844
+ out["class_loss"] = class_loss.mean()
845
+ if class_logits is not None:
846
+ out["class_logits"] = class_logits
847
+ if head_logits:
848
+ out["head_logits"] = head_logits
849
+ if head_loss_means:
850
+ out["head_losses"] = head_loss_means
851
+ return out
852
+
853
+ # ------------------------------ v1 ------------------------------
854
+
855
+ def _forward_v1(
856
+ self,
857
+ x_dict: Dict[str, torch.Tensor],
858
+ epoch: int = 0,
859
+ y: Optional[YType] = None,
860
+ attn_bias_cfg: Optional[Mapping[str, Any]] = None,
861
+ ) -> Dict[str, Any]:
862
+ mu_dict, logvar_dict = self.encode_modalities(x_dict, attn_bias_cfg=attn_bias_cfg)
863
+ if len(mu_dict) == 0:
864
+ raise ValueError("At least one modality must be present in x_dict.")
865
+
866
+ present = list(mu_dict.keys())
867
+ K = len(present)
868
+
869
+ use_fused = (self.fused_encoder_type == "multimodal_transformer") and self._can_use_fused_encoder(x_dict)
870
+ if use_fused:
871
+ mu_moe, logvar_moe = self._compute_fused_posterior(x_dict, attn_bias_cfg=attn_bias_cfg)
872
+ else:
873
+ mu_moe, logvar_moe = self.mixture_of_experts(mu_dict, logvar_dict)
874
+
875
+ z_moe = self._reparameterize(mu_moe, logvar_moe)
876
+ z_mod = {m: self._reparameterize(mu_dict[m], logvar_dict[m]) for m in present}
877
+
878
+ def recon_from_z_to_target(z_src: torch.Tensor, target_mod: str) -> torch.Tensor:
879
+ raw = self.decoders[target_mod](z_src)
880
+ m_cfg = self.mod_cfg_by_name[target_mod]
881
+ return self._recon_loss(x=x_dict[target_mod], raw_dec_out=raw, likelihood=m_cfg.likelihood, mod_name=target_mod)
882
+
883
+ device = next(iter(mu_dict.values())).device
884
+ B = next(iter(mu_dict.values())).shape[0]
885
+
886
+ recon_per_target: Dict[str, torch.Tensor] = {m: torch.zeros(B, device=device) for m in present}
887
+ recon_counts: Dict[str, int] = {m: 0 for m in present}
888
+
889
+ recon_total: Optional[torch.Tensor] = None
890
+ n_terms = 0
891
+
892
+ def add_term(t: torch.Tensor, tgt: str):
893
+ nonlocal recon_total, n_terms
894
+ recon_per_target[tgt] = recon_per_target[tgt] + t
895
+ recon_counts[tgt] += 1
896
+ recon_total = t if recon_total is None else (recon_total + t)
897
+ n_terms += 1
898
+
899
+ v1_recon = self.v1_recon
900
+
901
+ if v1_recon in ("avg", "both", "paper", "self+cross", "self_cross", "hybrid"):
902
+ if K <= 1:
903
+ tgt = present[0]
904
+ add_term(recon_from_z_to_target(z_mod[tgt], tgt), tgt=tgt)
905
+ else:
906
+ w_self = 0.5 / float(K)
907
+ w_cross = 0.5 / float(K * (K - 1))
908
+ for tgt in present:
909
+ add_term(w_self * recon_from_z_to_target(z_mod[tgt], tgt), tgt=tgt)
910
+ for src in present:
911
+ for tgt in present:
912
+ if src == tgt:
913
+ continue
914
+ add_term(w_cross * recon_from_z_to_target(z_mod[src], tgt), tgt=tgt)
915
+
916
+ elif v1_recon.startswith("src:"):
917
+ src_name = v1_recon.split("src:", 1)[1].strip()
918
+ if src_name not in z_mod:
919
+ raise ValueError(f"v1_recon={self.v1_recon!r} but '{src_name}' not present. Present={present}")
920
+ for tgt in present:
921
+ add_term(recon_from_z_to_target(z_mod[src_name], tgt), tgt=tgt)
922
+
923
+ elif v1_recon == "self":
924
+ for tgt in present:
925
+ add_term(recon_from_z_to_target(z_mod[tgt], tgt), tgt=tgt)
926
+
927
+ elif v1_recon in ("avg_z", "mean_z", "average_z"):
928
+ z_avg = torch.stack([z_mod[m] for m in present], dim=0).mean(dim=0)
929
+ for tgt in present:
930
+ add_term(recon_from_z_to_target(z_avg, tgt), tgt=tgt)
931
+ z_moe = z_avg
932
+
933
+ elif v1_recon in ("moe", "poe", "fused"):
934
+ for tgt in present:
935
+ add_term(recon_from_z_to_target(z_moe, tgt), tgt=tgt)
936
+
937
+ else:
938
+ if K >= 2:
939
+ for src in present:
940
+ for tgt in present:
941
+ if src == tgt:
942
+ continue
943
+ add_term(recon_from_z_to_target(z_mod[src], tgt), tgt=tgt)
944
+ else:
945
+ tgt = present[0]
946
+ add_term(recon_from_z_to_target(z_mod[tgt], tgt), tgt=tgt)
947
+
948
+ if self.v1_recon_mix > 0.0 and K >= 2:
949
+ mix = float(self.v1_recon_mix)
950
+ z_avg = torch.stack([z_mod[m] for m in present], dim=0).mean(dim=0)
951
+ for tgt in present:
952
+ add_term(mix * recon_from_z_to_target(z_avg, tgt), tgt=tgt)
953
+
954
+ if recon_total is None or n_terms == 0:
955
+ raise RuntimeError("v1 reconstruction produced no loss terms.")
956
+
957
+ weighted_mode = v1_recon in ("avg", "both", "paper", "self+cross", "self_cross", "hybrid")
958
+ if self.normalize_v1_terms and (not weighted_mode):
959
+ recon_total = recon_total / float(n_terms)
960
+
961
+ recon_per_target_mean: Dict[str, torch.Tensor] = {}
962
+ for tgt in present:
963
+ ct = max(int(recon_counts[tgt]), 1)
964
+ recon_per_target_mean[tgt] = recon_per_target[tgt] / float(ct)
965
+
966
+ mu_p = self.prior_mu.expand_as(mu_dict[present[0]])
967
+ logvar_p = self.prior_logvar.expand_as(logvar_dict[present[0]])
968
+
969
+ kl_terms = [self._kl_gaussian(mu_dict[m], logvar_dict[m], mu_p, logvar_p) for m in present]
970
+ kl = torch.stack(kl_terms, dim=0).sum(dim=0)
971
+ if self.normalize_v1_terms:
972
+ kl = kl / float(max(K, 1))
973
+
974
+ if K < 2:
975
+ cross_kl = torch.zeros_like(kl)
976
+ else:
977
+ cross_terms = []
978
+ for i in range(K):
979
+ for j in range(K):
980
+ if i == j:
981
+ continue
982
+ mi, lvi = mu_dict[present[i]], logvar_dict[present[i]]
983
+ mj, lvj = mu_dict[present[j]], logvar_dict[present[j]]
984
+ cross_terms.append(self._kl_gaussian(mi, lvi, mj, lvj))
985
+ cross_kl = torch.stack(cross_terms, dim=0).sum(dim=0)
986
+ if self.normalize_v1_terms:
987
+ cross_kl = cross_kl / float(K * (K - 1))
988
+
989
+ beta_t = self._anneal_weight(epoch, self.cfg.kl_anneal_start, self.cfg.kl_anneal_end, self.beta_max)
990
+ gamma_t = self._anneal_weight(epoch, self.cfg.align_anneal_start, self.cfg.align_anneal_end, self.gamma_max)
991
+
992
+ beta_used = beta_t if self.training else self.beta_max
993
+ gamma_used = gamma_t if self.training else self.gamma_max
994
+
995
+ loss_annealed = recon_total + beta_t * kl + gamma_t * cross_kl
996
+ loss_fixed = recon_total + self.beta_max * kl + self.gamma_max * cross_kl
997
+ loss = recon_total + beta_used * kl + gamma_used * cross_kl
998
+
999
+ y_legacy = self._extract_legacy_y(y)
1000
+ class_loss = None
1001
+ class_logits = None
1002
+ if self.label_decoder is not None:
1003
+ z_for_cls = mu_moe if self.classify_from_mu else z_moe
1004
+ dec_out = self.label_decoder(z_for_cls)
1005
+ class_logits = dec_out["logits"] if isinstance(dec_out, dict) and "logits" in dec_out else dec_out
1006
+
1007
+ if y_legacy is not None:
1008
+ yy = y_legacy.long()
1009
+ mask = (yy >= 0) & (yy != self.label_ignore_index)
1010
+ per_cell = torch.zeros(B, device=loss.device)
1011
+ if mask.any():
1012
+ per_cell[mask] = F.cross_entropy(class_logits[mask], yy[mask], reduction="none")
1013
+ class_loss = per_cell
1014
+
1015
+ loss = loss + self.label_loss_weight * class_loss
1016
+ loss_annealed = loss_annealed + self.label_loss_weight * class_loss
1017
+ loss_fixed = loss_fixed + self.label_loss_weight * class_loss
1018
+
1019
+ loss, loss_annealed, loss_fixed, head_logits, head_loss_means = self._apply_multihead_losses(
1020
+ mu_z=mu_moe,
1021
+ z=z_moe,
1022
+ y=y,
1023
+ epoch=epoch,
1024
+ loss=loss,
1025
+ loss_annealed=loss_annealed,
1026
+ loss_fixed=loss_fixed,
1027
+ )
1028
+
1029
+ xhat_dict = self.decode_modalities(z_moe)
1030
+
1031
+ out: Dict[str, Any] = {
1032
+ "loss": loss.mean(),
1033
+ "recon_total": recon_total.mean(),
1034
+ "kl": kl.mean(),
1035
+ "align": cross_kl.mean(),
1036
+ "cross_kl": cross_kl.mean(),
1037
+ "mu_z": mu_moe,
1038
+ "logvar_z": logvar_moe,
1039
+ "z": z_moe,
1040
+ "xhat": xhat_dict,
1041
+ "mu_dict": mu_dict,
1042
+ "logvar_dict": logvar_dict,
1043
+ "recon_per_modality": {k: v.mean() for k, v in recon_per_target_mean.items()},
1044
+ "beta": torch.tensor(beta_t, device=loss.device),
1045
+ "gamma": torch.tensor(gamma_t, device=loss.device),
1046
+ "beta_used": torch.tensor(beta_used, device=loss.device),
1047
+ "gamma_used": torch.tensor(gamma_used, device=loss.device),
1048
+ "loss_annealed": loss_annealed.mean(),
1049
+ "loss_fixed": loss_fixed.mean(),
1050
+ "v1_recon_terms": torch.tensor(float(n_terms), device=loss.device),
1051
+ "used_fused_encoder": torch.tensor(1.0 if use_fused else 0.0, device=loss.device),
1052
+ }
1053
+ if class_loss is not None:
1054
+ out["class_loss"] = class_loss.mean()
1055
+ if class_logits is not None:
1056
+ out["class_logits"] = class_logits
1057
+ if head_logits:
1058
+ out["head_logits"] = head_logits
1059
+ if head_loss_means:
1060
+ out["head_losses"] = head_loss_means
1061
+ return out
1062
+
1063
+ # ------------------------------ reconstruction loss (unchanged) ------------------------------
1064
+
1065
+ def _unwrap_decoder_out(self, dec_out: Any) -> Any:
1066
+ if isinstance(dec_out, (tuple, list)):
1067
+ if len(dec_out) != 1:
1068
+ raise TypeError(f"Unsupported decoder output container of length {len(dec_out)}: {type(dec_out)!r}")
1069
+ dec_out = dec_out[0]
1070
+ if torch.is_tensor(dec_out):
1071
+ return dec_out
1072
+ if isinstance(dec_out, dict):
1073
+ out = dict(dec_out)
1074
+ if "mu" not in out and "mean" in out:
1075
+ out["mu"] = out["mean"]
1076
+ return out
1077
+ raise TypeError(f"Unsupported decoder output type: {type(dec_out)!r}")
1078
+
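+ # _nb_nll evaluates the negative binomial negative log-likelihood in the
+ # mean/dispersion parameterization (mu, theta):
+ #   log NB(x; mu, theta) = lgamma(x + theta) - lgamma(theta) - lgamma(x + 1)
+ #                          + theta * log(theta / (theta + mu)) + x * log(mu / (theta + mu)).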
1079
+ @staticmethod
1080
+ def _nb_nll(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
1081
+ mu = mu.clamp(min=eps)
1082
+ theta = theta.clamp(min=eps)
1083
+ t1 = torch.lgamma(x + theta) - torch.lgamma(theta) - torch.lgamma(x + 1.0)
1084
+ t2 = theta * (torch.log(theta) - torch.log(theta + mu))
1085
+ t3 = x * (torch.log(mu) - torch.log(theta + mu))
1086
+ return -(t1 + t2 + t3)
1087
+
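+ # _zinb_nll extends the NB likelihood with a zero-inflation mixture: with
+ # pi = sigmoid(logit_pi),
+ #   p(x = 0) = pi + (1 - pi) * NB(0; mu, theta)
+ #   p(x > 0) = (1 - pi) * NB(x; mu, theta)
+ # where NB(0; mu, theta) = (theta / (theta + mu))^theta, matching log_nb_zero below.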
1088
+ @staticmethod
1089
+ def _zinb_nll(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, logit_pi: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
1090
+ mu = mu.clamp(min=eps)
1091
+ theta = theta.clamp(min=eps)
1092
+ pi = torch.sigmoid(logit_pi)
1093
+ t1 = torch.lgamma(x + theta) - torch.lgamma(theta) - torch.lgamma(x + 1.0)
1094
+ t2 = theta * (torch.log(theta) - torch.log(theta + mu))
1095
+ t3 = x * (torch.log(mu) - torch.log(theta + mu))
1096
+ log_nb = t1 + t2 + t3
1097
+ is_zero = (x < eps)
1098
+ log_prob_pos = torch.log1p(-pi + eps) + log_nb
1099
+ log_nb_zero = theta * (torch.log(theta) - torch.log(theta + mu))
1100
+ log_prob_zero = torch.log(pi + (1.0 - pi) * torch.exp(log_nb_zero) + eps)
1101
+ log_prob = torch.where(is_zero, log_prob_zero, log_prob_pos)
1102
+ return -log_prob
1103
+
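+ # _recon_loss dispatches on the modality's likelihood string. Expected decoder
+ # outputs (after _unwrap_decoder_out): gaussian/normal/mse -> tensor or
+ # {"mean", "logvar"}; bernoulli -> logits; poisson -> {"log_rate"} or tensor;
+ # nb -> {"mu", "log_theta"}; zinb -> {"mu", "log_theta", "logit_pi"};
+ # categorical -> {"logits"}. Unknown likelihoods fall back to a summed squared error.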
1104
+ def _recon_loss(self, x: torch.Tensor, raw_dec_out: Any, likelihood: str, mod_name: str) -> torch.Tensor:
1105
+ likelihood = (likelihood or "gaussian").lower().strip()
1106
+ dec_out = self._unwrap_decoder_out(raw_dec_out)
1107
+
1108
+ if self._is_categorical_likelihood(likelihood):
1109
+ m_cfg = self.mod_cfg_by_name[mod_name]
1110
+ C = int(m_cfg.input_dim)
1111
+ ignore_index = int(getattr(m_cfg, "ignore_index", self.label_ignore_index))
1112
+
1113
+ logits = dec_out["logits"] if isinstance(dec_out, dict) else dec_out
1114
+ y, mask = self._categorical_targets_and_mask(x, n_classes=C, ignore_index=ignore_index)
1115
+
1116
+ nll = torch.zeros(y.shape[0], device=logits.device)
1117
+ if mask.any():
1118
+ nll[mask] = F.cross_entropy(logits[mask], y[mask], reduction="none")
1119
+ return nll
1120
+
1121
+ if likelihood in ("gaussian", "normal", "mse", "gaussian_diag"):
1122
+ if isinstance(dec_out, dict) and ("mean" in dec_out) and ("logvar" in dec_out):
1123
+ mean = dec_out["mean"]
1124
+ logvar = dec_out["logvar"].clamp(self.LOGVAR_MIN, self.LOGVAR_MAX)
1125
+ var = torch.exp(logvar)
1126
+ nll = 0.5 * (logvar + (x - mean) ** 2 / (var + self.EPS))
1127
+ return nll.sum(dim=-1)
1128
+
1129
+ pred = dec_out["mean"] if isinstance(dec_out, dict) and ("mean" in dec_out) else dec_out
1130
+ if likelihood == "mse":
1131
+ return ((x - pred) ** 2).mean(dim=-1)
1132
+ return ((x - pred) ** 2).sum(dim=-1)
1133
+
1134
+ if likelihood == "bernoulli":
1135
+ logits = dec_out["logits"] if isinstance(dec_out, dict) else dec_out
1136
+ nll = F.binary_cross_entropy_with_logits(logits, x, reduction="none")
1137
+ return nll.sum(dim=-1)
1138
+
1139
+ if likelihood == "poisson":
1140
+ log_rate = dec_out["log_rate"] if isinstance(dec_out, dict) and "log_rate" in dec_out else dec_out
1141
+ nll = F.poisson_nll_loss(log_rate, x, log_input=True, full=False, reduction="none")
1142
+ return nll.sum(dim=-1)
1143
+
1144
+ if likelihood in ("nb", "negative_binomial"):
1145
+ if not isinstance(dec_out, dict) or ("mu" not in dec_out) or ("log_theta" not in dec_out):
1146
+ raise ValueError(f"NB recon expects dict with keys ('mu','log_theta'); got {type(dec_out)}")
1147
+ mu = dec_out["mu"]
1148
+ theta = torch.exp(dec_out["log_theta"])
1149
+ if theta.dim() == 1:
1150
+ theta = theta.unsqueeze(0).expand_as(mu)
1151
+ return self._nb_nll(x, mu, theta, eps=self.EPS).sum(dim=-1)
1152
+
1153
+ if likelihood in ("zinb", "zero_inflated_negative_binomial"):
1154
+ if (
1155
+ not isinstance(dec_out, dict)
1156
+ or ("mu" not in dec_out)
1157
+ or ("log_theta" not in dec_out)
1158
+ or ("logit_pi" not in dec_out)
1159
+ ):
1160
+ raise ValueError("ZINB recon expects dict with keys ('mu','log_theta','logit_pi').")
1161
+ mu = dec_out["mu"]
1162
+ theta = torch.exp(dec_out["log_theta"])
1163
+ logit_pi = dec_out["logit_pi"]
1164
+ if theta.dim() == 1:
1165
+ theta = theta.unsqueeze(0).expand_as(mu)
1166
+ if logit_pi.dim() == 1:
1167
+ logit_pi = logit_pi.unsqueeze(0).expand_as(mu)
1168
+ return self._zinb_nll(x, mu, theta, logit_pi, eps=self.EPS).sum(dim=-1)
1169
+
1170
+ pred = dec_out["mean"] if isinstance(dec_out, dict) and ("mean" in dec_out) else dec_out
1171
+ return ((x - pred) ** 2).sum(dim=-1)
1172
+
1173
+ # ------------------------------ annealing ------------------------------
1174
+
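+ # _anneal_weight ramps a weight linearly from 0 at `start` to `max_val` at `end`
+ # (and returns max_val whenever end <= start). For example, with start=10, end=30,
+ # max_val=1.0, epoch 20 yields 0.5.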
1175
+ def _anneal_weight(self, epoch: int, start: int, end: int, max_val: float) -> float:
1176
+ start = int(start)
1177
+ end = int(end)
1178
+ if end <= start:
1179
+ return float(max_val)
1180
+ if epoch <= start:
1181
+ return 0.0
1182
+ if epoch >= end:
1183
+ return float(max_val)
1184
+ frac = (epoch - start) / float(end - start)
1185
+ return float(max_val) * float(frac)
1186
+
1187
+ # --------------------------- convenience API ---------------------------
1188
+
1189
+ @property
1190
+ def device(self) -> torch.device:
1191
+ return next(self.parameters()).device
1192
+
1193
+ @torch.no_grad()
1194
+ def encode_fused(
1195
+ self,
1196
+ x_dict: Dict[str, torch.Tensor],
1197
+ *,
1198
+ epoch: int = 0,
1199
+ y: Optional[YType] = None,
1200
+ use_mean: bool = True,
1201
+ inject_label_expert: bool = True,
1202
+ attn_bias_cfg: Optional[Mapping[str, Any]] = None,
1203
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1204
+ mu_dict, logvar_dict = self.encode_modalities(x_dict, attn_bias_cfg=attn_bias_cfg)
1205
+ if len(mu_dict) == 0:
1206
+ raise ValueError("At least one modality must be present in x_dict.")
1207
+
1208
+ use_fused = (self.fused_encoder_type == "multimodal_transformer") and self._can_use_fused_encoder(x_dict)
1209
+ if use_fused:
1210
+ mu_z, logvar_z = self._compute_fused_posterior(x_dict, attn_bias_cfg=attn_bias_cfg)
1211
+ else:
1212
+ mu_z, logvar_z = self.mixture_of_experts(mu_dict, logvar_dict)
1213
+
1214
+ y_legacy = self._extract_legacy_y(y)
1215
+ if (
1216
+ inject_label_expert
1217
+ and (self.label_encoder is not None)
1218
+ and (y_legacy is not None)
1219
+ and (epoch >= self.label_encoder_warmup)
1220
+ ):
1221
+ B = mu_z.shape[0]
1222
+ dev = mu_z.device
1223
+ mu_y, logvar_y = self._encode_labels_as_expert(y=y_legacy, B=B, device=dev)
1224
+ base_mu_dict = {"__base__": mu_z, "__label__": mu_y}
1225
+ base_lv_dict = {"__base__": logvar_z, "__label__": logvar_y}
1226
+ mu_z, logvar_z = self.mixture_of_experts(base_mu_dict, base_lv_dict)
1227
+
1228
+ z = mu_z if use_mean else self._reparameterize(mu_z, logvar_z)
1229
+ return mu_z, logvar_z, z
1230
+
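+ # Illustrative inference sketch (inputs are placeholders, not part of the API):
+ #
+ #     probs = model.predict_heads({"rna": x_rna, "atac": x_atac}, use_mean=True)
+ #     # probs maps each head name (plus the legacy label head, if configured) to
+ #     # softmax probabilities of shape (B, n_classes).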
1231
+ @torch.no_grad()
1232
+ def predict_heads(
1233
+ self,
1234
+ x_dict: Dict[str, torch.Tensor],
1235
+ *,
1236
+ epoch: int = 0,
1237
+ y: Optional[YType] = None,
1238
+ use_mean: bool = True,
1239
+ inject_label_expert: bool = True,
1240
+ return_probs: bool = True,
1241
+ attn_bias_cfg: Optional[Mapping[str, Any]] = None,
1242
+ ) -> Dict[str, torch.Tensor]:
1243
+ mu_z, logvar_z, z = self.encode_fused(
1244
+ x_dict,
1245
+ epoch=epoch,
1246
+ y=y,
1247
+ use_mean=use_mean,
1248
+ inject_label_expert=inject_label_expert,
1249
+ attn_bias_cfg=attn_bias_cfg,
1250
+ )
1251
+ out: Dict[str, torch.Tensor] = {}
1252
+
1253
+ if self.label_decoder is not None:
1254
+ z_for_cls = mu_z if self.classify_from_mu else z
1255
+ dec_out = self.label_decoder(z_for_cls)
1256
+ logits = dec_out["logits"] if isinstance(dec_out, dict) and "logits" in dec_out else dec_out
1257
+ out[self.label_head_name] = logits if not return_probs else F.softmax(logits, dim=-1)
1258
+
1259
+ for name, head in self.class_heads.items():
1260
+ cfg_h = self.class_heads_cfg[name]
1261
+ z_in = mu_z if bool(cfg_h["from_mu"]) else z
1262
+ dec_out = head(z_in)
1263
+ logits = dec_out["logits"] if isinstance(dec_out, dict) and "logits" in dec_out else dec_out
1264
+ out[name] = logits if not return_probs else F.softmax(logits, dim=-1)
1265
+
1266
+ return out
1267
+
1268
+ def get_classification_meta(self) -> Dict[str, Any]:
1269
+ meta: Dict[str, Any] = {
1270
+ "label_head_name": self.label_head_name,
1271
+ "legacy": {
1272
+ "n_label_classes": int(self.n_label_classes),
1273
+ "label_ignore_index": int(self.label_ignore_index),
1274
+ },
1275
+ "multi": {
1276
+ "heads": {k: dict(v) for k, v in self.class_heads_cfg.items()},
1277
+ },
1278
+ }
1279
+ if self.label_names is not None:
1280
+ meta["legacy"]["label_names"] = list(self.label_names)
1281
+ if self.head_label_names:
1282
+ meta["multi"]["label_names"] = {k: list(v) for k, v in self.head_label_names.items()}
1283
+ return meta
1284
+