univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/models/univi.py
ADDED
|
@@ -0,0 +1,1284 @@
|
|
|
1
|
+
# univi/models/univi.py
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Dict, Tuple, Optional, Any, List, Union, Mapping, Callable
|
|
5
|
+
import math
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from torch import nn
|
|
9
|
+
import torch.nn.functional as F
|
|
10
|
+
|
|
11
|
+
from ..config import UniVIConfig, ModalityConfig, ClassHeadConfig
|
|
12
|
+
from .mlp import build_mlp
|
|
13
|
+
from .decoders import DecoderConfig, build_decoder
|
|
14
|
+
from .encoders import build_gaussian_encoder, build_multimodal_transformer_encoder
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Label payload accepted by the model: either one label tensor (legacy single
# head) or a mapping from head name to that head's label tensor.
YType = Union[
    torch.Tensor,
    Dict[str, torch.Tensor],
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class _GradReverseFn(torch.autograd.Function):
    """Gradient reversal: identity going forward, gradient scaled by -lambda going backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, lambd: float):
        # Remember the reversal strength (coerced to a Python float) for backward().
        ctx.lambd = float(lambd)
        # Return a view rather than `x` itself (the usual idiom for identity Functions).
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        scale = -ctx.lambd
        # The second input (lambd) is a plain float and receives no gradient.
        return scale * grad_output, None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class UniVIMultiModalVAE(nn.Module):
|
|
33
|
+
"""
|
|
34
|
+
Multi-modal β-VAE with per-modality encoders and decoders.
|
|
35
|
+
|
|
36
|
+
Defaults (backwards-compatible)
|
|
37
|
+
-------------------------------
|
|
38
|
+
- Per-modality encoders: MLP (or transformer if set on that modality)
|
|
39
|
+
- Fused posterior: MoE/PoE-style precision fusion over per-modality posteriors
|
|
40
|
+
|
|
41
|
+
Optional fused multimodal transformer posterior
|
|
42
|
+
----------------------------------------------
|
|
43
|
+
If cfg.fused_encoder_type == "multimodal_transformer", build a fused encoder that:
|
|
44
|
+
x_dict -> concat(tokens_rna, tokens_adt, tokens_atac, ...) -> transformer -> (mu_fused, logvar_fused)
|
|
45
|
+
|
|
46
|
+
In that mode:
|
|
47
|
+
- v2/lite can use mu_fused/logvar_fused for z
|
|
48
|
+
- alignment term becomes mean_i ||mu_i - mu_fused||^2 (instead of pairwise)
|
|
49
|
+
- if required fused modalities are missing (cfg.fused_require_all_modalities=True), fall back to MoE fusion
|
|
50
|
+
|
|
51
|
+
Optional attention bias (safe no-op)
|
|
52
|
+
------------------------------------
|
|
53
|
+
You can pass `attn_bias_cfg` into forward/encode_fused/predict_heads:
|
|
54
|
+
|
|
55
|
+
attn_bias_cfg = {
|
|
56
|
+
"atac": {"type": "distance", "lengthscale_bp": 50000, "same_chrom_only": True},
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
- Per-modality transformer encoders: will try to build a distance attention bias if the tokenizer
|
|
60
|
+
supports topk indices + coords and exposes build_distance_attn_bias().
|
|
61
|
+
- Fused multimodal transformer encoder: will fill within-modality blocks (e.g. ATAC slice) and
|
|
62
|
+
keep cross-modality blocks neutral (0).
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
LOGVAR_MIN = -10.0
|
|
66
|
+
LOGVAR_MAX = 10.0
|
|
67
|
+
EPS = 1e-8
|
|
68
|
+
|
|
69
|
+
def __init__(
    self,
    cfg: UniVIConfig,
    *,
    loss_mode: str = "v2",
    v1_recon: str = "cross",
    v1_recon_mix: float = 0.0,
    normalize_v1_terms: bool = True,

    # ---- legacy single label head (kept) ----
    n_label_classes: int = 0,
    label_loss_weight: float = 1.0,
    use_label_encoder: bool = False,
    label_moe_weight: float = 1.0,
    unlabeled_logvar: float = 20.0,
    label_encoder_warmup: int = 0,
    label_ignore_index: int = -1,
    classify_from_mu: bool = True,
    label_head_name: Optional[str] = None,
):
    """Build per-modality encoders/decoders, the optional fused encoder and label heads.

    Args:
        cfg: Model configuration; ``cfg.validate()`` is called before anything is built.
        loss_mode: Objective selector used by ``forward`` (stored lower-cased and stripped).
        v1_recon, v1_recon_mix, normalize_v1_terms: Options stored for the v1 objective.
        n_label_classes: Number of classes of the legacy single label head; 0 disables it.
        label_loss_weight: Weight stored for the legacy label loss.
        use_label_encoder: If True (and ``n_label_classes > 0``) build a label "expert" encoder.
        label_moe_weight: Precision multiplier applied to the label expert.
        unlabeled_logvar: Log-variance of the label expert for cells without a label.
        label_encoder_warmup: Epoch from which the label expert is used.
        label_ignore_index: Label value treated as "missing".
        classify_from_mu: Stored flag selecting mu vs. sampled z for the legacy head.
        label_head_name: Key used to read legacy labels from a dict-valued ``y``;
            falls back to ``cfg.label_head_name`` and then to ``"label"``.

    Raises:
        TypeError: If ``cfg.modalities`` / ``cfg.class_heads`` contain unexpected types.
        ValueError: If two class heads share the same name.
    """
    super().__init__()
    self.cfg = cfg
    self.cfg.validate()

    self.loss_mode = str(loss_mode).lower().strip()
    self.v1_recon = str(v1_recon).lower().strip()
    self.v1_recon_mix = float(v1_recon_mix)
    self.normalize_v1_terms = bool(normalize_v1_terms)

    self.latent_dim = int(cfg.latent_dim)
    self.beta_max = float(cfg.beta)
    self.gamma_max = float(cfg.gamma)

    # Validate entry types before reading `.name`, so a bad entry raises the
    # intended TypeError rather than an AttributeError.
    for m in cfg.modalities:
        if not isinstance(m, ModalityConfig):
            raise TypeError(f"cfg.modalities must contain ModalityConfig, got {type(m)}")

    self.modality_names: List[str] = [m.name for m in cfg.modalities]
    self.mod_cfg_by_name: Dict[str, ModalityConfig] = {m.name: m for m in cfg.modalities}

    # ------------------------------------------------------------
    # Per-modality modules
    # ------------------------------------------------------------
    self.encoders = nn.ModuleDict()
    self.encoder_heads = nn.ModuleDict()
    self.decoders = nn.ModuleDict()

    for m in cfg.modalities:
        self.encoders[m.name] = build_gaussian_encoder(uni_cfg=cfg, mod_cfg=m)
        # Identity head applied to mu in encode_modalities.
        self.encoder_heads[m.name] = nn.Identity()

        dec_hidden = list(m.decoder_hidden) if m.decoder_hidden else [max(64, self.latent_dim)]
        dec_cfg = DecoderConfig(
            output_dim=int(m.input_dim),
            hidden_dims=dec_hidden,
            dropout=float(cfg.decoder_dropout),
            batchnorm=bool(cfg.decoder_batchnorm),
        )

        lk = (m.likelihood or "gaussian").lower().strip()
        decoder_kwargs: Dict[str, Any] = {}
        if lk in ("nb", "negative_binomial", "zinb", "zero_inflated_negative_binomial"):
            # (Zero-inflated) negative-binomial decoders also take dispersion settings.
            decoder_kwargs = dict(
                dispersion=getattr(m, "dispersion", "gene"),
                init_log_theta=float(getattr(m, "init_log_theta", 0.0)),
                eps=self.EPS,
            )

        self.decoders[m.name] = build_decoder(
            lk,
            cfg=dec_cfg,
            latent_dim=self.latent_dim,
            **decoder_kwargs,
        )

    # Shared prior N(0, I)
    self.register_buffer("prior_mu", torch.zeros(self.latent_dim))
    self.register_buffer("prior_logvar", torch.zeros(self.latent_dim))

    # ------------------------------------------------------------
    # Optional fused multimodal encoder
    # ------------------------------------------------------------
    self.fused_encoder_type = (getattr(cfg, "fused_encoder_type", "moe") or "moe").lower().strip()
    self.fused_require_all = bool(getattr(cfg, "fused_require_all_modalities", True))
    self.fused_modalities = list(getattr(cfg, "fused_modalities", None) or self.modality_names)

    if self.fused_encoder_type == "multimodal_transformer":
        self.fused_encoder = build_multimodal_transformer_encoder(
            uni_cfg=cfg,
            modalities=cfg.modalities,
            fused_modalities=self.fused_modalities,
        )
    else:
        self.fused_encoder = None

    # ------------------------------------------------------------
    # Legacy single label head (kept)
    # ------------------------------------------------------------
    self.n_label_classes = int(n_label_classes) if n_label_classes is not None else 0
    self.label_loss_weight = float(label_loss_weight)
    self.label_ignore_index = int(label_ignore_index)
    self.classify_from_mu = bool(classify_from_mu)
    # Fall back to cfg.label_head_name, then "label". A None/empty value on cfg
    # must not turn into the literal key "None".
    self.label_head_name = str(label_head_name or getattr(cfg, "label_head_name", None) or "label")

    if self.n_label_classes > 0:
        dec_cfg = DecoderConfig(
            output_dim=self.n_label_classes,
            hidden_dims=[max(64, self.latent_dim)],
            dropout=float(cfg.decoder_dropout),
            batchnorm=bool(cfg.decoder_batchnorm),
        )
        self.label_decoder = build_decoder("categorical", cfg=dec_cfg, latent_dim=self.latent_dim)
    else:
        self.label_decoder = None

    self.label_names: Optional[List[str]] = None
    self.label_name_to_id: Optional[Dict[str, int]] = None

    # Legacy label encoder expert (optional)
    self.use_label_encoder = bool(use_label_encoder)
    self.label_moe_weight = float(label_moe_weight)
    self.unlabeled_logvar = float(unlabeled_logvar)
    self.label_encoder_warmup = int(label_encoder_warmup)

    if self.n_label_classes > 0 and self.use_label_encoder:
        # Maps a one-hot label to (mu, logvar) concatenated along the last dim.
        self.label_encoder = build_mlp(
            in_dim=self.n_label_classes,
            hidden_dims=[max(64, self.latent_dim)],
            out_dim=self.latent_dim * 2,
            activation=nn.ReLU(),
            dropout=float(cfg.encoder_dropout),
            batchnorm=bool(cfg.encoder_batchnorm),
        )
    else:
        self.label_encoder = None

    # Multi-head supervised decoders (incl. adversarial heads)
    self.class_heads = nn.ModuleDict()
    self.class_heads_cfg: Dict[str, Dict[str, Any]] = {}
    self.head_label_names: Dict[str, List[str]] = {}
    self.head_label_name_to_id: Dict[str, Dict[str, int]] = {}

    for h in getattr(cfg, "class_heads", None) or []:
        if not isinstance(h, ClassHeadConfig):
            raise TypeError(f"cfg.class_heads must contain ClassHeadConfig, got {type(h)}")

        name = str(h.name)
        if name in self.class_heads_cfg:
            # Silently overwriting would drop the earlier head's module and settings.
            raise ValueError(f"Duplicate class head name {name!r} in cfg.class_heads.")
        n_classes = int(h.n_classes)

        dec_cfg = DecoderConfig(
            output_dim=n_classes,
            hidden_dims=[max(64, self.latent_dim)],
            dropout=float(cfg.decoder_dropout),
            batchnorm=bool(cfg.decoder_batchnorm),
        )
        self.class_heads[name] = build_decoder("categorical", cfg=dec_cfg, latent_dim=self.latent_dim)
        self.class_heads_cfg[name] = {
            "n_classes": n_classes,
            "loss_weight": float(h.loss_weight),
            "ignore_index": int(h.ignore_index),
            "from_mu": bool(h.from_mu),
            "warmup": int(h.warmup),
            "adversarial": bool(getattr(h, "adversarial", False)),
            "adv_lambda": float(getattr(h, "adv_lambda", 1.0)),
        }
|
|
237
|
+
|
|
238
|
+
# ----------------------------- label name utilities -----------------------------
|
|
239
|
+
|
|
240
|
+
def set_label_names(self, label_names: List[str]) -> None:
    """Attach class names to the legacy label head and build a name -> id lookup.

    ``self.label_name_to_id`` holds every exact name plus a normalized alias
    (stripped, lower-cased, internal whitespace collapsed). Exact names always
    take precedence, so one class's alias can never shadow another class's
    exact name. If several names share an alias, the later class keeps it.

    Raises:
        ValueError: If the legacy head is disabled or the length does not match.
    """
    if self.n_label_classes <= 0:
        raise ValueError("n_label_classes=0; cannot set label names.")
    if len(label_names) != int(self.n_label_classes):
        raise ValueError(f"label_names length {len(label_names)} != n_label_classes {self.n_label_classes}")

    self.label_names = [str(x) for x in label_names]

    def norm(s: str) -> str:
        # "  T   Cell " -> "t cell"
        return " ".join(str(s).strip().lower().split())

    aliases = {norm(name): i for i, name in enumerate(self.label_names)}
    exact = {name: i for i, name in enumerate(self.label_names)}
    # Exact names applied last so they override any colliding alias.
    m: Dict[str, int] = {**aliases, **exact}
    self.label_name_to_id = m
|
|
256
|
+
|
|
257
|
+
def set_head_label_names(self, head: str, label_names: List[str]) -> None:
    """Attach class names to a multi-head classifier and build its name -> id lookup.

    Same lookup rules as ``set_label_names``: exact names plus normalized
    aliases (stripped, lower-cased, whitespace collapsed), with exact names
    taking precedence over any alias.

    Raises:
        KeyError: If ``head`` is not a configured class head.
        ValueError: If the number of names differs from the head's ``n_classes``.
    """
    head = str(head)
    if head not in self.class_heads_cfg:
        raise KeyError(f"Unknown head {head!r}. Known heads: {list(self.class_heads_cfg)}")

    n = int(self.class_heads_cfg[head]["n_classes"])
    if len(label_names) != n:
        raise ValueError(f"Head {head!r}: label_names length {len(label_names)} != n_classes {n}")

    names = [str(x) for x in label_names]
    self.head_label_names[head] = names

    def norm(s: str) -> str:
        # "  T   Cell " -> "t cell"
        return " ".join(str(s).strip().lower().split())

    aliases = {norm(name): i for i, name in enumerate(names)}
    exact = {name: i for i, name in enumerate(names)}
    # Exact names applied last so they override any colliding alias.
    m: Dict[str, int] = {**aliases, **exact}
    self.head_label_name_to_id[head] = m
|
|
277
|
+
|
|
278
|
+
# ----------------------------- helpers -----------------------------
|
|
279
|
+
|
|
280
|
+
def _reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Sample z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
    sigma = logvar.mul(0.5).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma
|
|
284
|
+
|
|
285
|
+
def _kl_gaussian(self, mu_q: torch.Tensor, logvar_q: torch.Tensor, mu_p: torch.Tensor, logvar_p: torch.Tensor) -> torch.Tensor:
    """KL(q || p) between diagonal Gaussians, summed over the last dimension.

    Per dimension: 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p)^2) / var_p - 1).
    """
    sq_mean_gap = (mu_q - mu_p).pow(2)
    scaled = (logvar_q.exp() + sq_mean_gap) / logvar_p.exp()
    per_dim = (logvar_p - logvar_q) + scaled - 1.0
    return per_dim.sum(dim=-1).mul(0.5)
|
|
290
|
+
|
|
291
|
+
def _alignment_loss_l2mu_pairwise(self, mu_per_mod: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Per-cell mean of squared-L2 distances over every unordered pair of modality means.

    With one modality the result is zeros of length ``mu.size(0)``. With no
    modalities a 1-element zero tensor on the parameters' device is returned,
    since there is no batch size to infer.
    """
    keys = list(mu_per_mod)
    if not keys:
        return torch.tensor(0.0, device=next(self.parameters()).device).expand(1)
    if len(keys) == 1:
        only = mu_per_mod[keys[0]]
        return torch.zeros(only.size(0), device=only.device)

    pair_dists = [
        (mu_per_mod[a] - mu_per_mod[b]).pow(2).sum(dim=-1)
        for idx, a in enumerate(keys)
        for b in keys[idx + 1:]
    ]
    return torch.stack(pair_dists, dim=0).mean(dim=0)
|
|
304
|
+
|
|
305
|
+
def _alignment_loss_to_fused(self, mu_per_mod: Dict[str, torch.Tensor], mu_fused: torch.Tensor) -> torch.Tensor:
    """Per-cell mean squared-L2 distance from each modality mean to the fused mean.

    Returns zeros of length ``mu_fused.size(0)`` when no modality means are given.
    """
    if not mu_per_mod:
        return torch.zeros(mu_fused.size(0), device=mu_fused.device)
    dists = torch.stack(
        [(mu_i - mu_fused).pow(2).sum(dim=-1) for mu_i in mu_per_mod.values()],
        dim=0,
    )
    return dists.mean(dim=0)
|
|
310
|
+
|
|
311
|
+
@staticmethod
def _is_categorical_likelihood(likelihood: Optional[str]) -> bool:
    """True if ``likelihood`` names a categorical / cross-entropy likelihood (case- and whitespace-insensitive)."""
    if not likelihood:
        return False
    aliases = frozenset({"categorical", "cat", "ce", "cross_entropy", "multinomial", "softmax"})
    return likelihood.lower().strip() in aliases
|
|
315
|
+
|
|
316
|
+
def _categorical_targets_and_mask(
    self,
    x: torch.Tensor,
    *,
    n_classes: int,
    ignore_index: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert categorical input into long class targets plus a validity mask.

    Accepted layouts:
      * (B, 1) or (B,) class codes (floats are rounded first). A row is valid
        when its code differs from ``ignore_index`` and lies in ``[0, n_classes)``.
      * (B, C) rows (e.g. one-hot). The target is the argmax, and a row is
        valid when its entries sum to more than 0.5, so all-zero rows count
        as missing.

    Returns:
        (y, mask): ``y`` is a long tensor of targets; ``mask`` is a bool tensor
        of the same length marking rows that may be used as targets.
    """
    if x.dim() == 2 and x.size(1) == 1:
        x = x[:, 0]

    if x.dim() == 2:
        mask = x.sum(dim=-1) > 0.5
        y = x.argmax(dim=-1).long()
        return y, mask

    # reshape (not view) so non-contiguous inputs are accepted as well
    y = x.reshape(-1)
    if y.dtype.is_floating_point:
        y = y.round()
    y = y.long()

    # Codes outside [0, n_classes) would make F.one_hot / F.cross_entropy raise
    # downstream, so they are treated as missing too.
    mask = (y != int(ignore_index)) & (y >= 0) & (y < int(n_classes))
    return y, mask
|
|
338
|
+
|
|
339
|
+
def _encode_categorical_input_if_needed(self, mod_name: str, x: torch.Tensor) -> torch.Tensor:
    """Return encoder-ready input for ``mod_name``, one-hot encoding categorical codes.

    * Non-categorical modalities are returned unchanged.
    * For categorical modalities, 2-D input other than (B, 1) is returned as float.
    * Class codes ((B,) or (B, 1)) become a (B, C) float32 one-hot matrix, where
      C = ``input_dim``. Rows for missing or out-of-range codes are all zero.
    """
    m_cfg = self.mod_cfg_by_name[mod_name]
    if not self._is_categorical_likelihood(m_cfg.likelihood):
        return x

    C = int(m_cfg.input_dim)
    # The modality's own ignore_index wins; a missing or None value falls back to
    # the legacy label default (int(None) would otherwise raise).
    ignore_index = getattr(m_cfg, "ignore_index", None)
    if ignore_index is None:
        ignore_index = self.label_ignore_index
    ignore_index = int(ignore_index)

    if x.dim() == 2 and x.size(1) == 1:
        x = x[:, 0]
    if x.dim() == 2:
        return x.float()

    y, mask = self._categorical_targets_and_mask(x, n_classes=C, ignore_index=ignore_index)
    # F.one_hot raises on codes outside [0, C); leave those rows all-zero instead.
    mask = mask & (y >= 0) & (y < C)

    x_oh = torch.zeros((y.shape[0], C), device=y.device, dtype=torch.float32)
    if mask.any():
        x_oh[mask] = F.one_hot(y[mask], num_classes=C).float()
    return x_oh
|
|
358
|
+
|
|
359
|
+
# -------------------------- attention-bias helpers (optional, safe no-op) --------------------------
|
|
360
|
+
|
|
361
|
+
def _build_distance_bias_for_permod_transformer(
    self,
    enc: nn.Module,
    x_in: torch.Tensor,
    cfg_m: Mapping[str, Any],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Best-effort (tokens, key_padding_mask, attn_bias) for a transformer-style encoder.

    Yields (None, None, None) unless all of the following hold:
      * ``cfg_m["type"]`` is "distance" (case/whitespace-insensitive);
      * ``enc`` has both a ``vec2tok`` tokenizer and an ``encoder`` core;
      * the tokenizer exposes ``build_distance_attn_bias``;
      * tokenizing with ``return_indices=True`` succeeds and its meta carries ``topk_idx``;
      * building the bias succeeds.

    Exceptions from the tokenizer or the bias builder are swallowed on purpose,
    so callers can fall back to the plain encoder path.
    """
    nothing = (None, None, None)

    if str(cfg_m.get("type", "")).lower().strip() != "distance":
        return nothing

    tokenizer = getattr(enc, "vec2tok", None)
    if tokenizer is None or getattr(enc, "encoder", None) is None:
        return nothing
    if not hasattr(tokenizer, "build_distance_attn_bias"):
        return nothing

    try:
        tokens, key_padding_mask, meta = tokenizer(x_in, return_indices=True)
    except Exception:
        return nothing

    topk_idx = meta.get("topk_idx", None) if meta is not None else None
    if topk_idx is None:
        return nothing

    # Config parsing stays outside the try so malformed values still raise.
    lengthscale_bp = float(cfg_m.get("lengthscale_bp", 50_000.0))
    same_chrom_only = bool(cfg_m.get("same_chrom_only", True))
    include_cls = bool(getattr(tokenizer, "add_cls_token", False))

    try:
        attn_bias = tokenizer.build_distance_attn_bias(
            topk_idx,
            lengthscale_bp=lengthscale_bp,
            same_chrom_only=same_chrom_only,
            include_cls=include_cls,
        )
    except Exception:
        return nothing

    return tokens, key_padding_mask, attn_bias
|
|
412
|
+
|
|
413
|
+
def _build_fused_attn_bias_fn(
    self,
    attn_bias_cfg: Mapping[str, Any],
    sub_x_dict: Dict[str, torch.Tensor],
) -> Optional[Callable[[Dict[str, Any]], Optional[torch.Tensor]]]:
    """Return a ``meta -> (B, T, T)`` attention-bias callback for the fused encoder, or None.

    None is returned when:
      * there is no fused encoder;
      * the fused encoder has no ``vec2tok`` mapping;
      * no entry of ``attn_bias_cfg`` is a mapping with ``type == "distance"``.

    The callback:
      * returns None for empty meta or missing token slices;
      * otherwise returns a float32 zero tensor of shape (B, T, T), where B is
        taken from the first tensor in ``sub_x_dict`` and T is the furthest
        slice end;
      * fills only the within-modality diagonal blocks ``[a:b, a:b]``, so
        cross-modality entries stay 0.

    A modality is skipped when its tokenizer lacks ``build_distance_attn_bias``,
    its meta has no ``topk_idx``, the builder raises, or the block size does
    not match its slice.
    """
    if self.fused_encoder is None:
        return None
    tokenizers = getattr(self.fused_encoder, "vec2tok", None)
    if tokenizers is None:
        return None

    def _is_distance(spec: Any) -> bool:
        return isinstance(spec, Mapping) and str(spec.get("type", "")).lower().strip() == "distance"

    if not any(_is_distance(spec) for spec in attn_bias_cfg.values()):
        return None

    def fn(meta: Dict[str, Any]) -> Optional[torch.Tensor]:
        if not meta:
            return None

        ref_x = next(iter(sub_x_dict.values()))
        batch = int(ref_x.shape[0])

        slices = meta.get("slices_with_cls", meta.get("slices", {}))
        if not slices:
            return None

        # Fused sequence length = furthest slice end (a global CLS, if any, is already counted).
        total_len = max([0] + [int(end) for _start, end in slices.values()])
        if bool(meta.get("has_global_cls", False)):
            total_len = max(total_len, 1)

        bias = torch.zeros((batch, total_len, total_len), device=ref_x.device, dtype=torch.float32)

        for mod, spec in attn_bias_cfg.items():
            if not _is_distance(spec):
                continue
            if mod not in slices or mod not in tokenizers:
                continue

            tok = tokenizers[mod]
            if not hasattr(tok, "build_distance_attn_bias"):
                continue

            topk_idx = (meta.get(mod, {}) or {}).get("topk_idx", None)
            if topk_idx is None:
                continue

            lengthscale_bp = float(spec.get("lengthscale_bp", 50_000.0))
            same_chrom_only = bool(spec.get("same_chrom_only", True))
            try:
                local = tok.build_distance_attn_bias(
                    topk_idx,
                    lengthscale_bp=lengthscale_bp,
                    same_chrom_only=same_chrom_only,
                    include_cls=False,  # per-modality CLS is forced off in the fused encoder
                )
            except Exception:
                continue

            start, end = (int(v) for v in slices[mod])
            if (end - start) != int(local.shape[1]):
                continue
            bias[:, start:end, start:end] = local

        return bias

    return fn
|
|
501
|
+
|
|
502
|
+
# -------------------------- encode/decode/fuse --------------------------
|
|
503
|
+
|
|
504
|
+
def encode_modalities(
    self,
    x_dict: Dict[str, torch.Tensor],
    *,
    attn_bias_cfg: Optional[Mapping[str, Any]] = None,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """Encode every present modality with its own encoder.

    Modalities absent from ``x_dict`` (or mapped to None) are skipped. The
    returned (mu_dict, logvar_dict) therefore contain only the encoded
    modalities, in ``self.modality_names`` order, and log-variances are
    clamped to [LOGVAR_MIN, LOGVAR_MAX].

    If ``attn_bias_cfg[m]`` requests a distance bias and the encoder supports
    it, the tokenizer output goes straight to ``enc.encoder`` together with
    the bias. Otherwise the encoder is called normally.
    """
    mu_dict: Dict[str, torch.Tensor] = {}
    logvar_dict: Dict[str, torch.Tensor] = {}
    bias_cfg = attn_bias_cfg if isinstance(attn_bias_cfg, Mapping) else None

    for m in self.modality_names:
        if m not in x_dict or x_dict[m] is None:
            continue

        x_in = self._encode_categorical_input_if_needed(m, x_dict[m])
        enc = self.encoders[m]

        # optional: transformer distance bias (single pass through the core encoder)
        tokens = key_padding_mask = attn_bias = None
        cfg_m = bias_cfg.get(m, None) if bias_cfg is not None else None
        if isinstance(cfg_m, Mapping):
            tokens, key_padding_mask, attn_bias = self._build_distance_bias_for_permod_transformer(enc, x_in, cfg_m)

        if tokens is not None and getattr(enc, "encoder", None) is not None:
            # NOTE(review): assumes enc.encoder maps tokens directly to a
            # (..., 2 * latent_dim) tensor -- confirm against encoders.py.
            h = enc.encoder(
                tokens,
                key_padding_mask=key_padding_mask,
                attn_bias=attn_bias,
                return_attn=False,
            )
            mu, logvar = torch.chunk(h, 2, dim=-1)
        elif attn_bias is None:
            # Plain call: passing attn_bias=None and catching TypeError raised an
            # exception on every call for encoders without the kwarg, and also
            # swallowed genuine TypeErrors raised inside the encoder.
            # Assumes encoders accepting attn_bias default it to None -- TODO confirm.
            mu, logvar = enc(x_in)
        else:
            # Rare: a bias was built but no tokens came back. Keep the old
            # kwarg-then-plain fallback for encoders that may not accept attn_bias.
            try:
                mu, logvar = enc(x_in, attn_bias=attn_bias)
            except TypeError:
                mu, logvar = enc(x_in)

        mu = self.encoder_heads[m](mu)
        logvar = torch.clamp(logvar, self.LOGVAR_MIN, self.LOGVAR_MAX)

        mu_dict[m] = mu
        logvar_dict[m] = logvar

    return mu_dict, logvar_dict
|
|
550
|
+
|
|
551
|
+
def mixture_of_experts(self, mu_dict: Dict[str, torch.Tensor], logvar_dict: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fuse per-modality Gaussian posteriors by precision weighting.

        precision = sum_k exp(-logvar_k)             (clamped below by EPS)
        mu        = sum_k mu_k * exp(-logvar_k) / precision
        logvar    = log(1 / precision)               (variance clamped below by EPS)

    Experts are paired by key, so the two dicts may list modalities in any order.

    Raises:
        ValueError: If there are no experts or the two dicts have different keys.
    """
    if not mu_dict:
        raise ValueError("mixture_of_experts requires at least one expert.")
    if set(mu_dict) != set(logvar_dict):
        # Zipping mismatched dicts would silently pair mu and precision of different experts.
        raise ValueError(
            f"mu_dict and logvar_dict must have the same keys; got {list(mu_dict)} vs {list(logvar_dict)}."
        )

    keys = list(mu_dict)
    precisions = [torch.exp(-logvar_dict[k]) for k in keys]
    precision_sum = torch.stack(precisions, dim=0).sum(dim=0).clamp_min(self.EPS)

    mu_weighted = torch.stack([mu_dict[k] * p for k, p in zip(keys, precisions)], dim=0).sum(dim=0)
    mu_comb = mu_weighted / precision_sum

    var_comb = 1.0 / precision_sum
    logvar_comb = torch.log(var_comb.clamp_min(self.EPS))
    return mu_comb, logvar_comb
|
|
564
|
+
|
|
565
|
+
def decode_modalities(self, z: torch.Tensor) -> Dict[str, Any]:
    """Run ``z`` through every modality's decoder (all configured modalities, in order)."""
    outputs: Dict[str, Any] = {}
    for name in self.modality_names:
        outputs[name] = self.decoders[name](z)
    return outputs
|
|
567
|
+
|
|
568
|
+
# -------------------- fused posterior --------------------
|
|
569
|
+
|
|
570
|
+
def _present_fused_modalities(self, x_dict: Dict[str, torch.Tensor]) -> List[str]:
    """Names from ``self.fused_modalities`` present (and not None) in ``x_dict``, in fused order."""
    present: List[str] = []
    for name in self.fused_modalities:
        if name in x_dict and x_dict[name] is not None:
            present.append(name)
    return present
|
|
572
|
+
|
|
573
|
+
def _can_use_fused_encoder(self, x_dict: Dict[str, torch.Tensor]) -> bool:
    """Whether the fused encoder can run on ``x_dict``.

    Requires a fused encoder and a non-empty fused modality list. Then either
    all fused modalities must be present (``fused_require_all``) or at least one.
    """
    if self.fused_encoder is None or not self.fused_modalities:
        return False
    n_present = len(self._present_fused_modalities(x_dict))
    if self.fused_require_all:
        return n_present == len(self.fused_modalities)
    return n_present >= 1
|
|
583
|
+
|
|
584
|
+
def _compute_fused_posterior(
    self,
    x_dict: Dict[str, torch.Tensor],
    *,
    attn_bias_cfg: Optional[Mapping[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused posterior (mu, logvar) from the multimodal transformer encoder.

    * Only the fused modalities present in ``x_dict`` are passed in.
    * When ``attn_bias_cfg`` yields a bias callback, the encoder is also asked
      for token meta, which is discarded here.
    * logvar is clamped to [LOGVAR_MIN, LOGVAR_MAX].
    """
    assert self.fused_encoder is not None
    inputs = {name: x_dict[name] for name in self._present_fused_modalities(x_dict)}

    bias_fn = (
        self._build_fused_attn_bias_fn(attn_bias_cfg, inputs)
        if isinstance(attn_bias_cfg, Mapping)
        else None
    )
    with_meta = bias_fn is not None

    result = self.fused_encoder(inputs, return_token_meta=with_meta, attn_bias_fn=bias_fn)
    if with_meta:
        mu_f, logvar_f, _token_meta = result  # type: ignore[misc]
    else:
        mu_f, logvar_f = result  # type: ignore[misc]

    return mu_f, logvar_f.clamp(self.LOGVAR_MIN, self.LOGVAR_MAX)
|
|
608
|
+
|
|
609
|
+
# -------------------- label expert --------------------
|
|
610
|
+
|
|
611
|
+
def _encode_labels_as_expert(self, y: Optional[torch.Tensor], B: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode labels into a Gaussian "label expert" for posterior fusion.

    Returns (mu_y, logvar_y), each of shape (B, latent_dim).

    Default rows: cells get mu = 0 and logvar = ``unlabeled_logvar`` (a very
    broad expert at the __init__ default of 20) when any of these hold:
      * there is no label encoder;
      * ``y`` is None;
      * the label equals ``label_ignore_index``;
      * the label is negative or >= ``n_label_classes``.

    Labeled cells take the label encoder's (mu, logvar), with logvar clamped
    to [LOGVAR_MIN, LOGVAR_MAX]. A ``label_moe_weight`` w != 1 lowers logvar
    by log(w), i.e. multiplies the expert's precision by w.
    """
    mu_y = torch.zeros(B, self.latent_dim, device=device)
    logvar_y = torch.full((B, self.latent_dim), float(self.unlabeled_logvar), device=device)

    if self.label_encoder is None or y is None:
        return mu_y, logvar_y

    # Accept (B, 1) label columns as well as (B,) vectors; keep labels on the target device.
    if y.dim() == 2 and y.size(1) == 1:
        y = y[:, 0]
    y = y.to(device=device).long()

    # Out-of-range codes would make F.one_hot raise; treat them as unlabeled.
    mask = (y >= 0) & (y < self.n_label_classes) & (y != self.label_ignore_index)
    if not mask.any():
        return mu_y, logvar_y

    y_oh = F.one_hot(y[mask], num_classes=self.n_label_classes).float()
    h = self.label_encoder(y_oh)
    mu_l, logvar_l = torch.chunk(h, 2, dim=-1)
    logvar_l = torch.clamp(logvar_l, self.LOGVAR_MIN, self.LOGVAR_MAX)

    w = float(self.label_moe_weight)
    if w != 1.0:
        logvar_l = logvar_l - math.log(max(w, 1e-8))

    mu_y[mask] = mu_l
    logvar_y[mask] = logvar_l
    return mu_y, logvar_y
|
|
635
|
+
|
|
636
|
+
def _extract_legacy_y(self, y: Optional[YType]) -> Optional[torch.Tensor]:
    """Return the legacy single-head labels from ``y`` as a long tensor, or None.

    A mapping is looked up by ``self.label_head_name``; a bare tensor is used directly.
    """
    if y is None:
        return None
    labels = y.get(self.label_head_name, None) if isinstance(y, Mapping) else y
    return None if labels is None else labels.long()
|
|
643
|
+
|
|
644
|
+
def _grad_reverse(self, x: torch.Tensor, lambd: float = 1.0) -> torch.Tensor:
    """Identity in the forward pass; incoming gradients are multiplied by -lambd."""
    strength = float(lambd)
    return _GradReverseFn.apply(x, strength)
|
|
646
|
+
|
|
647
|
+
def _apply_multihead_losses(
    self,
    *,
    mu_z: torch.Tensor,
    z: torch.Tensor,
    y: Optional[YType],
    epoch: int,
    loss: torch.Tensor,
    loss_annealed: torch.Tensor,
    loss_fixed: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """Add weighted per-cell cross-entropy losses from every configured class head.

    Nothing is done unless ``y`` is a mapping and class heads exist. For each head:
      * it is skipped before its ``warmup`` epoch or when ``y`` has no entry for it;
      * its input is ``mu_z`` (``from_mu``) or ``z``, routed through gradient
        reversal for adversarial heads;
      * a cell contributes 0 when its label equals ``ignore_index``, is
        negative, or is >= the head's number of logits.

    Returns:
        (loss, loss_annealed, loss_fixed, head_logits, head_loss_means):
          * each of the three losses has ``loss_weight * per_cell`` added;
          * ``head_logits`` holds the logits of every head that ran;
          * ``head_loss_means[name]`` is the per-cell loss averaged over all B
            cells, with masked cells counting as 0.
    """
    head_logits: Dict[str, torch.Tensor] = {}
    head_loss_means: Dict[str, torch.Tensor] = {}

    if y is None or not isinstance(y, Mapping) or len(self.class_heads) == 0:
        return loss, loss_annealed, loss_fixed, head_logits, head_loss_means

    B = int(mu_z.size(0))
    device = mu_z.device

    for name, head in self.class_heads.items():
        cfg_h = self.class_heads_cfg[name]
        if int(epoch) < int(cfg_h["warmup"]):
            continue

        y_h = y.get(name, None)
        if y_h is None:
            continue

        # Accept (B, 1) label columns, and keep labels on the latent's device
        # so boolean indexing of the logits works.
        if y_h.dim() == 2 and y_h.size(1) == 1:
            y_h = y_h[:, 0]
        y_h = y_h.to(device=device).long()

        z_in = mu_z if bool(cfg_h["from_mu"]) else z
        if bool(cfg_h.get("adversarial", False)):
            z_in = self._grad_reverse(z_in, cfg_h.get("adv_lambda", 1.0))

        dec_out = head(z_in)
        logits = dec_out["logits"] if isinstance(dec_out, dict) and "logits" in dec_out else dec_out
        head_logits[name] = logits

        ignore_index = int(cfg_h["ignore_index"])
        n_classes = int(logits.size(-1))
        # Targets outside [0, n_classes) would make F.cross_entropy raise; mask them out.
        mask = (y_h != ignore_index) & (y_h >= 0) & (y_h < n_classes)
        per_cell = torch.zeros(B, device=device)
        if mask.any():
            per_cell[mask] = F.cross_entropy(logits[mask], y_h[mask], reduction="none")

        w = float(cfg_h["loss_weight"])
        if w != 0.0:
            weighted = w * per_cell
            loss = loss + weighted
            loss_annealed = loss_annealed + weighted
            loss_fixed = loss_fixed + weighted

        head_loss_means[name] = per_cell.mean()

    return loss, loss_annealed, loss_fixed, head_logits, head_loss_means
|
|
701
|
+
|
|
702
|
+
# ------------------------------ forward dispatcher ------------------------------
|
|
703
|
+
|
|
704
|
+
def forward(
    self,
    x_dict: Dict[str, torch.Tensor],
    epoch: int = 0,
    y: Optional[YType] = None,
    attn_bias_cfg: Optional[Mapping[str, Any]] = None,
    **kwargs: Any,  # keep older wrappers from breaking if they pass extra args
) -> Dict[str, torch.Tensor]:
    """Run the objective selected by ``self.loss_mode``.

    ``loss_mode`` is matched case-insensitively, and a falsy value means ``"v2"``.
    ``"v1"``/``"paper"``/``"cross"`` go to ``_forward_v1``.
    ``"v2"``/``"lite"``/``"light"``/``"moe"``/``"poe"``/``"fused"`` go to
    ``_forward_v2``. Extra keyword arguments are accepted and ignored.

    Raises:
        ValueError: if ``loss_mode`` names neither family.
    """
    v1_aliases = ("v1", "paper", "cross")
    v2_aliases = ("v2", "lite", "light", "moe", "poe", "fused")

    requested = (self.loss_mode or "v2").lower()
    if requested in v1_aliases:
        impl = self._forward_v1
    elif requested in v2_aliases:
        impl = self._forward_v2
    else:
        raise ValueError(f"Unknown loss_mode={self.loss_mode!r}.")
    return impl(x_dict=x_dict, epoch=epoch, y=y, attn_bias_cfg=attn_bias_cfg)
|
|
718
|
+
|
|
719
|
+
# ------------------------------ v2 / lite ------------------------------
|
|
720
|
+
|
|
721
|
+
def _forward_v2(
    self,
    x_dict: Dict[str, torch.Tensor],
    epoch: int = 0,
    y: Optional[YType] = None,
    attn_bias_cfg: Optional[Mapping[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """Single-posterior ("v2"/lite) objective.

    Steps:

    1. Encode every modality in ``x_dict``.
    2. Combine them into one posterior. This uses the fused multimodal
       transformer when it is configured and ``_can_use_fused_encoder``
       allows it; otherwise it uses ``mixture_of_experts``.
    3. Optionally mix in a label expert.
    4. Sample ``z`` once and score::

           loss = recon_total + beta * KL(posterior, prior) + gamma * align [+ class losses]

    Args:
        x_dict: Modality name -> input tensor. Modalities absent from the
            dict (or mapped to ``None``) are skipped in the reconstruction loop.
        epoch: Current epoch. Drives KL/alignment annealing, the label-expert
            warmup, and the per-head warmups in ``_apply_multihead_losses``.
        y: Labels. The legacy single-head target is taken via
            ``_extract_legacy_y``; the full ``y`` goes to
            ``_apply_multihead_losses``.
        attn_bias_cfg: Forwarded to the modality and fused encoders.

    Returns:
        A dict containing:

        - Batch-mean scalars ``loss``, ``recon_total``, ``kl``, ``align``,
          ``loss_annealed`` and ``loss_fixed``.
        - Latent tensors ``mu_z``, ``logvar_z`` and ``z``.
        - Nested dicts ``xhat``, ``mu_dict``, ``logvar_dict`` and
          ``recon_per_modality``.
        - 0-d tensors ``beta``/``gamma`` (scheduled values),
          ``beta_used``/``gamma_used`` (weights applied to ``loss``) and
          ``used_fused_encoder`` (1.0 or 0.0).
        - ``class_loss``, ``class_logits``, ``head_logits`` and
          ``head_losses``, only when they were computed.

        Despite the return annotation, several values are dicts, not tensors.

    Raises:
        ValueError: if ``encode_modalities`` returned no modality.
        RuntimeError: if no configured modality produced a reconstruction loss.
    """
    mu_dict, logvar_dict = self.encode_modalities(x_dict, attn_bias_cfg=attn_bias_cfg)
    if len(mu_dict) == 0:
        raise ValueError("At least one modality must be present in x_dict.")

    # Posterior: the fused transformer over all inputs when usable, else combine per-modality experts.
    use_fused = (self.fused_encoder_type == "multimodal_transformer") and self._can_use_fused_encoder(x_dict)
    mu_fused: Optional[torch.Tensor] = None
    logvar_fused: Optional[torch.Tensor] = None

    if use_fused:
        mu_fused, logvar_fused = self._compute_fused_posterior(x_dict, attn_bias_cfg=attn_bias_cfg)
        mu_z, logvar_z = mu_fused, logvar_fused
    else:
        mu_z, logvar_z = self.mixture_of_experts(mu_dict, logvar_dict)

    # After label_encoder_warmup, labels act as an extra expert combined with the base posterior.
    y_legacy = self._extract_legacy_y(y)
    if self.label_encoder is not None and y_legacy is not None and epoch >= self.label_encoder_warmup:
        B = mu_z.shape[0]
        device = mu_z.device
        mu_y, logvar_y = self._encode_labels_as_expert(y=y_legacy, B=B, device=device)

        base_mu_dict = {"__base__": mu_z, "__label__": mu_y}
        base_lv_dict = {"__base__": logvar_z, "__label__": logvar_y}
        mu_z, logvar_z = self.mixture_of_experts(base_mu_dict, base_lv_dict)

    # A single sample of z feeds every decoder.
    z = self._reparameterize(mu_z, logvar_z)
    xhat_dict = self.decode_modalities(z)

    recon_total: Optional[torch.Tensor] = None
    recon_losses: Dict[str, torch.Tensor] = {}

    # Per-cell reconstruction loss, summed over the configured modalities present in x_dict.
    for name, m_cfg in self.mod_cfg_by_name.items():
        if name not in x_dict or x_dict[name] is None:
            continue
        loss_m = self._recon_loss(
            x=x_dict[name],
            raw_dec_out=xhat_dict[name],
            likelihood=m_cfg.likelihood,
            mod_name=name,
        )
        recon_losses[name] = loss_m
        recon_total = loss_m if recon_total is None else (recon_total + loss_m)

    if recon_total is None:
        raise RuntimeError("No modalities in x_dict produced recon loss.")

    # KL of the final (possibly label-augmented) posterior against prior_mu / prior_logvar.
    mu_p = self.prior_mu.expand_as(mu_z)
    logvar_p = self.prior_logvar.expand_as(logvar_z)
    kl = self._kl_gaussian(mu_z, logvar_z, mu_p, logvar_p)

    # Alignment uses only real modalities ("__"-prefixed keys are dropped).
    # It compares each modality mean to the fused mean when the fused encoder ran,
    # otherwise compares modality means pairwise.
    real_mu = {k: v for k, v in mu_dict.items() if not k.startswith("__")}
    if use_fused and (mu_fused is not None):
        align_loss = self._alignment_loss_to_fused(real_mu, mu_fused)
    else:
        align_loss = self._alignment_loss_l2mu_pairwise(real_mu)

    beta_t = self._anneal_weight(epoch, self.cfg.kl_anneal_start, self.cfg.kl_anneal_end, self.beta_max)
    gamma_t = self._anneal_weight(epoch, self.cfg.align_anneal_start, self.cfg.align_anneal_end, self.gamma_max)

    # Training uses the annealed weights; eval uses the max weights
    # (presumably so eval losses are comparable across epochs).
    beta_used = beta_t if self.training else self.beta_max
    gamma_used = gamma_t if self.training else self.gamma_max

    # loss_annealed: always scheduled weights. loss_fixed: always max weights.
    # loss: whichever of the two the current mode uses.
    loss_annealed = recon_total + beta_t * kl + gamma_t * align_loss
    loss_fixed = recon_total + self.beta_max * kl + self.gamma_max * align_loss
    loss = recon_total + beta_used * kl + gamma_used * align_loss

    class_loss = None
    class_logits = None
    if self.label_decoder is not None:
        z_for_cls = mu_z if self.classify_from_mu else z
        dec_out = self.label_decoder(z_for_cls)
        class_logits = dec_out["logits"] if isinstance(dec_out, dict) and "logits" in dec_out else dec_out

        if y_legacy is not None:
            # Per-cell CE; negative labels and label_ignore_index contribute 0.
            B = loss.shape[0]
            yy = y_legacy.long()
            mask = (yy >= 0) & (yy != self.label_ignore_index)
            per_cell = torch.zeros(B, device=loss.device)
            if mask.any():
                per_cell[mask] = F.cross_entropy(class_logits[mask], yy[mask], reduction="none")
            class_loss = per_cell

            loss = loss + self.label_loss_weight * class_loss
            loss_annealed = loss_annealed + self.label_loss_weight * class_loss
            loss_fixed = loss_fixed + self.label_loss_weight * class_loss

    # Extra classification heads add their own weighted per-cell terms to all three losses.
    loss, loss_annealed, loss_fixed, head_logits, head_loss_means = self._apply_multihead_losses(
        mu_z=mu_z,
        z=z,
        y=y,
        epoch=epoch,
        loss=loss,
        loss_annealed=loss_annealed,
        loss_fixed=loss_fixed,
    )

    out: Dict[str, Any] = {
        "loss": loss.mean(),
        "recon_total": recon_total.mean(),
        "kl": kl.mean(),
        "align": align_loss.mean(),
        "mu_z": mu_z,
        "logvar_z": logvar_z,
        "z": z,
        "xhat": xhat_dict,
        "mu_dict": mu_dict,
        "logvar_dict": logvar_dict,
        "recon_per_modality": {k: v.mean() for k, v in recon_losses.items()},
        "beta": torch.tensor(beta_t, device=loss.device),
        "gamma": torch.tensor(gamma_t, device=loss.device),
        "beta_used": torch.tensor(beta_used, device=loss.device),
        "gamma_used": torch.tensor(gamma_used, device=loss.device),
        "loss_annealed": loss_annealed.mean(),
        "loss_fixed": loss_fixed.mean(),
        "used_fused_encoder": torch.tensor(1.0 if use_fused else 0.0, device=loss.device),
    }
    if class_loss is not None:
        out["class_loss"] = class_loss.mean()
    if class_logits is not None:
        out["class_logits"] = class_logits
    if head_logits:
        out["head_logits"] = head_logits
    if head_loss_means:
        out["head_losses"] = head_loss_means
    return out
|
|
852
|
+
|
|
853
|
+
# ------------------------------ v1 ------------------------------
|
|
854
|
+
|
|
855
|
+
def _forward_v1(
    self,
    x_dict: Dict[str, torch.Tensor],
    epoch: int = 0,
    y: Optional[YType] = None,
    attn_bias_cfg: Optional[Mapping[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """Per-modality ("v1"/paper/cross) objective.

    Each modality keeps its own posterior sample ``z_mod[m]``, and the
    reconstruction terms are chosen by ``self.v1_recon``:

    - ``avg``/``both``/``paper``/``self+cross``/``self_cross``/``hybrid``:
      self terms weighted ``0.5/K`` plus cross terms weighted
      ``0.5/(K*(K-1))``. With a single modality, only the self term is used.
    - ``src:<name>``: every present modality is decoded from ``z_mod[<name>]``.
    - ``self``: each modality is decoded from its own sample.
    - ``avg_z``/``mean_z``/``average_z``: all modalities are decoded from the
      mean of the per-modality samples.
    - ``moe``/``poe``/``fused``: all modalities are decoded from ``z_moe``.
    - Anything else: pure cross terms (src != tgt) when K >= 2, otherwise self.

    If ``v1_recon_mix > 0`` and K >= 2, extra ``mix``-weighted terms decode
    from the mean sample.

    The KL term sums KL(q_m, prior) over modalities. ``cross_kl`` sums
    ``_kl_gaussian`` over every ordered pair of distinct modalities. Both
    are divided by their term counts when ``normalize_v1_terms`` is set.

    Returns:
        The same keys as ``_forward_v2``, plus ``cross_kl`` (also reported as
        ``align``) and ``v1_recon_terms`` (number of recon terms added).
        ``mu_z``/``logvar_z``/``z`` describe the MoE/fused posterior.

    Raises:
        ValueError: no modality encoded, or a ``src:`` modality not present.
        RuntimeError: no reconstruction term was produced.
    """
    mu_dict, logvar_dict = self.encode_modalities(x_dict, attn_bias_cfg=attn_bias_cfg)
    if len(mu_dict) == 0:
        raise ValueError("At least one modality must be present in x_dict.")

    present = list(mu_dict.keys())
    K = len(present)

    # Joint posterior: used for the classifier/heads, for the "moe" recon mode and for the outputs.
    use_fused = (self.fused_encoder_type == "multimodal_transformer") and self._can_use_fused_encoder(x_dict)
    if use_fused:
        mu_moe, logvar_moe = self._compute_fused_posterior(x_dict, attn_bias_cfg=attn_bias_cfg)
    else:
        mu_moe, logvar_moe = self.mixture_of_experts(mu_dict, logvar_dict)

    z_moe = self._reparameterize(mu_moe, logvar_moe)
    z_mod = {m: self._reparameterize(mu_dict[m], logvar_dict[m]) for m in present}

    def recon_from_z_to_target(z_src: torch.Tensor, target_mod: str) -> torch.Tensor:
        # Decode z_src with target_mod's decoder and score it against x_dict[target_mod].
        raw = self.decoders[target_mod](z_src)
        m_cfg = self.mod_cfg_by_name[target_mod]
        return self._recon_loss(x=x_dict[target_mod], raw_dec_out=raw, likelihood=m_cfg.likelihood, mod_name=target_mod)

    device = next(iter(mu_dict.values())).device
    B = next(iter(mu_dict.values())).shape[0]

    # Per-target accumulators, for reporting only. They hold the already-weighted terms passed to add_term.
    recon_per_target: Dict[str, torch.Tensor] = {m: torch.zeros(B, device=device) for m in present}
    recon_counts: Dict[str, int] = {m: 0 for m in present}

    recon_total: Optional[torch.Tensor] = None
    n_terms = 0

    def add_term(t: torch.Tensor, tgt: str):
        # Accumulate one (already weighted) per-cell recon term into the total and the per-target stats.
        nonlocal recon_total, n_terms
        recon_per_target[tgt] = recon_per_target[tgt] + t
        recon_counts[tgt] += 1
        recon_total = t if recon_total is None else (recon_total + t)
        n_terms += 1

    v1_recon = self.v1_recon

    if v1_recon in ("avg", "both", "paper", "self+cross", "self_cross", "hybrid"):
        if K <= 1:
            tgt = present[0]
            add_term(recon_from_z_to_target(z_mod[tgt], tgt), tgt=tgt)
        else:
            # Self and cross halves each sum to weight 0.5.
            w_self = 0.5 / float(K)
            w_cross = 0.5 / float(K * (K - 1))
            for tgt in present:
                add_term(w_self * recon_from_z_to_target(z_mod[tgt], tgt), tgt=tgt)
            for src in present:
                for tgt in present:
                    if src == tgt:
                        continue
                    add_term(w_cross * recon_from_z_to_target(z_mod[src], tgt), tgt=tgt)

    elif v1_recon.startswith("src:"):
        src_name = v1_recon.split("src:", 1)[1].strip()
        if src_name not in z_mod:
            raise ValueError(f"v1_recon={self.v1_recon!r} but '{src_name}' not present. Present={present}")
        for tgt in present:
            add_term(recon_from_z_to_target(z_mod[src_name], tgt), tgt=tgt)

    elif v1_recon == "self":
        for tgt in present:
            add_term(recon_from_z_to_target(z_mod[tgt], tgt), tgt=tgt)

    elif v1_recon in ("avg_z", "mean_z", "average_z"):
        z_avg = torch.stack([z_mod[m] for m in present], dim=0).mean(dim=0)
        for tgt in present:
            add_term(recon_from_z_to_target(z_avg, tgt), tgt=tgt)
        # NOTE(review): only z_moe is replaced here. mu_moe/logvar_moe still hold the
        # MoE/fused posterior, so the returned mu_z/logvar_z do not describe z. Confirm intended.
        z_moe = z_avg

    elif v1_recon in ("moe", "poe", "fused"):
        for tgt in present:
            add_term(recon_from_z_to_target(z_moe, tgt), tgt=tgt)

    else:
        # Default: pure cross-reconstruction; falls back to self-recon with a single modality.
        if K >= 2:
            for src in present:
                for tgt in present:
                    if src == tgt:
                        continue
                    add_term(recon_from_z_to_target(z_mod[src], tgt), tgt=tgt)
        else:
            tgt = present[0]
            add_term(recon_from_z_to_target(z_mod[tgt], tgt), tgt=tgt)

    # Optional extra terms decoding from the average of the per-modality samples.
    if self.v1_recon_mix > 0.0 and K >= 2:
        mix = float(self.v1_recon_mix)
        z_avg = torch.stack([z_mod[m] for m in present], dim=0).mean(dim=0)
        for tgt in present:
            add_term(mix * recon_from_z_to_target(z_avg, tgt), tgt=tgt)

    if recon_total is None or n_terms == 0:
        raise RuntimeError("v1 reconstruction produced no loss terms.")

    # Weighted modes already carry explicit weights, so they are not divided by the term count.
    weighted_mode = v1_recon in ("avg", "both", "paper", "self+cross", "self_cross", "hybrid")
    if self.normalize_v1_terms and (not weighted_mode):
        recon_total = recon_total / float(n_terms)

    recon_per_target_mean: Dict[str, torch.Tensor] = {}
    for tgt in present:
        ct = max(int(recon_counts[tgt]), 1)
        recon_per_target_mean[tgt] = recon_per_target[tgt] / float(ct)

    mu_p = self.prior_mu.expand_as(mu_dict[present[0]])
    logvar_p = self.prior_logvar.expand_as(logvar_dict[present[0]])

    # Sum of per-modality KL terms against the prior (averaged when normalize_v1_terms).
    kl_terms = [self._kl_gaussian(mu_dict[m], logvar_dict[m], mu_p, logvar_p) for m in present]
    kl = torch.stack(kl_terms, dim=0).sum(dim=0)
    if self.normalize_v1_terms:
        kl = kl / float(max(K, 1))

    # Cross-modal KL over all K*(K-1) ordered pairs; zero with a single modality.
    if K < 2:
        cross_kl = torch.zeros_like(kl)
    else:
        cross_terms = []
        for i in range(K):
            for j in range(K):
                if i == j:
                    continue
                mi, lvi = mu_dict[present[i]], logvar_dict[present[i]]
                mj, lvj = mu_dict[present[j]], logvar_dict[present[j]]
                cross_terms.append(self._kl_gaussian(mi, lvi, mj, lvj))
        cross_kl = torch.stack(cross_terms, dim=0).sum(dim=0)
        if self.normalize_v1_terms:
            cross_kl = cross_kl / float(K * (K - 1))

    beta_t = self._anneal_weight(epoch, self.cfg.kl_anneal_start, self.cfg.kl_anneal_end, self.beta_max)
    gamma_t = self._anneal_weight(epoch, self.cfg.align_anneal_start, self.cfg.align_anneal_end, self.gamma_max)

    # Training uses the annealed weights; eval uses the max weights.
    beta_used = beta_t if self.training else self.beta_max
    gamma_used = gamma_t if self.training else self.gamma_max

    # In v1, gamma weights cross_kl (it plays the role v2's alignment loss plays).
    loss_annealed = recon_total + beta_t * kl + gamma_t * cross_kl
    loss_fixed = recon_total + self.beta_max * kl + self.gamma_max * cross_kl
    loss = recon_total + beta_used * kl + gamma_used * cross_kl

    y_legacy = self._extract_legacy_y(y)
    class_loss = None
    class_logits = None
    if self.label_decoder is not None:
        z_for_cls = mu_moe if self.classify_from_mu else z_moe
        dec_out = self.label_decoder(z_for_cls)
        class_logits = dec_out["logits"] if isinstance(dec_out, dict) and "logits" in dec_out else dec_out

        if y_legacy is not None:
            # Per-cell CE; negative labels and label_ignore_index contribute 0.
            yy = y_legacy.long()
            mask = (yy >= 0) & (yy != self.label_ignore_index)
            per_cell = torch.zeros(B, device=loss.device)
            if mask.any():
                per_cell[mask] = F.cross_entropy(class_logits[mask], yy[mask], reduction="none")
            class_loss = per_cell

            loss = loss + self.label_loss_weight * class_loss
            loss_annealed = loss_annealed + self.label_loss_weight * class_loss
            loss_fixed = loss_fixed + self.label_loss_weight * class_loss

    loss, loss_annealed, loss_fixed, head_logits, head_loss_means = self._apply_multihead_losses(
        mu_z=mu_moe,
        z=z_moe,
        y=y,
        epoch=epoch,
        loss=loss,
        loss_annealed=loss_annealed,
        loss_fixed=loss_fixed,
    )

    # Reported reconstructions come from the joint sample z_moe, not from the per-modality samples.
    xhat_dict = self.decode_modalities(z_moe)

    out: Dict[str, Any] = {
        "loss": loss.mean(),
        "recon_total": recon_total.mean(),
        "kl": kl.mean(),
        "align": cross_kl.mean(),
        "cross_kl": cross_kl.mean(),
        "mu_z": mu_moe,
        "logvar_z": logvar_moe,
        "z": z_moe,
        "xhat": xhat_dict,
        "mu_dict": mu_dict,
        "logvar_dict": logvar_dict,
        "recon_per_modality": {k: v.mean() for k, v in recon_per_target_mean.items()},
        "beta": torch.tensor(beta_t, device=loss.device),
        "gamma": torch.tensor(gamma_t, device=loss.device),
        "beta_used": torch.tensor(beta_used, device=loss.device),
        "gamma_used": torch.tensor(gamma_used, device=loss.device),
        "loss_annealed": loss_annealed.mean(),
        "loss_fixed": loss_fixed.mean(),
        "v1_recon_terms": torch.tensor(float(n_terms), device=loss.device),
        "used_fused_encoder": torch.tensor(1.0 if use_fused else 0.0, device=loss.device),
    }
    if class_loss is not None:
        out["class_loss"] = class_loss.mean()
    if class_logits is not None:
        out["class_logits"] = class_logits
    if head_logits:
        out["head_logits"] = head_logits
    if head_loss_means:
        out["head_losses"] = head_loss_means
    return out
|
|
1062
|
+
|
|
1063
|
+
# ------------------------------ reconstruction loss (unchanged) ------------------------------
|
|
1064
|
+
|
|
1065
|
+
def _unwrap_decoder_out(self, dec_out: Any) -> Any:
    """Normalise a decoder's raw output to a tensor or a plain dict.

    - A 1-element tuple/list is unpacked.
    - Tensors are returned as-is.
    - Dicts are shallow-copied. When ``"mean"`` is present and ``"mu"`` is
      not, ``"mu"`` is added as an alias of ``"mean"``.

    Raises:
        TypeError: for containers of any other length, or for unsupported types.
    """
    if isinstance(dec_out, (tuple, list)):
        if len(dec_out) != 1:
            raise TypeError(f"Unsupported decoder output container of length {len(dec_out)}: {type(dec_out)!r}")
        (dec_out,) = dec_out

    if torch.is_tensor(dec_out):
        return dec_out
    if not isinstance(dec_out, dict):
        raise TypeError(f"Unsupported decoder output type: {type(dec_out)!r}")

    unwrapped = {**dec_out}
    if "mean" in unwrapped:
        # Alias only; an existing "mu" entry wins.
        unwrapped.setdefault("mu", unwrapped["mean"])
    return unwrapped
|
|
1078
|
+
|
|
1079
|
+
@staticmethod
def _nb_nll(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Elementwise negative-binomial negative log-likelihood.

    Uses the mean/inverse-dispersion parameterisation::

        log p(x) = lgamma(x+theta) - lgamma(theta) - lgamma(x+1)
                   + theta*log(theta/(theta+mu)) + x*log(mu/(theta+mu))

    ``mu`` and ``theta`` are clamped to at least ``eps`` before taking logs.
    """
    mu_c = mu.clamp(min=eps)
    theta_c = theta.clamp(min=eps)
    log_total = torch.log(theta_c + mu_c)

    log_binom = torch.lgamma(x + theta_c) - torch.lgamma(theta_c) - torch.lgamma(x + 1.0)
    log_fail = theta_c * (torch.log(theta_c) - log_total)
    log_succ = x * (torch.log(mu_c) - log_total)
    return -(log_binom + log_fail + log_succ)
|
|
1087
|
+
|
|
1088
|
+
@staticmethod
def _zinb_nll(x: torch.Tensor, mu: torch.Tensor, theta: torch.Tensor, logit_pi: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Elementwise zero-inflated negative-binomial negative log-likelihood.

    ``pi = sigmoid(logit_pi)`` is the zero-inflation probability::

        p(0)   = pi + (1 - pi) * NB(0 | mu, theta)
        p(x>0) = (1 - pi) * NB(x | mu, theta)

    Everything is computed in log space straight from ``logit_pi``:
    ``log(pi) = logsigmoid(logit_pi)`` and ``log(1 - pi) = logsigmoid(-logit_pi)``.
    The zero case combines its two terms with ``logaddexp``.

    The previous probability-space version broke in float32. Once
    ``logit_pi`` exceeded about 16.6, ``sigmoid`` returned exactly 1.0, so
    ``log1p(-pi + eps)`` became ``log1p(-1.0) = -inf``. That gave an infinite
    loss and NaN gradients for every positive count.

    ``mu`` and ``theta`` are clamped to at least ``eps``. Entries with
    ``x < eps`` are treated as zeros.
    """
    mu = mu.clamp(min=eps)
    theta = theta.clamp(min=eps)

    log_theta_mu = torch.log(theta + mu)
    # log NB(0 | mu, theta) = theta * log(theta / (theta + mu)); also the second term of log NB(x).
    log_nb_zero = theta * (torch.log(theta) - log_theta_mu)
    log_nb = (
        torch.lgamma(x + theta)
        - torch.lgamma(theta)
        - torch.lgamma(x + 1.0)
        + log_nb_zero
        + x * (torch.log(mu) - log_theta_mu)
    )

    log_pi = F.logsigmoid(logit_pi)
    log_one_minus_pi = F.logsigmoid(-logit_pi)

    log_prob_zero = torch.logaddexp(log_pi, log_one_minus_pi + log_nb_zero)
    log_prob_pos = log_one_minus_pi + log_nb

    is_zero = x < eps
    log_prob = torch.where(is_zero, log_prob_zero, log_prob_pos)
    return -log_prob
|
|
1103
|
+
|
|
1104
|
+
def _recon_loss(self, x: torch.Tensor, raw_dec_out: Any, likelihood: str, mod_name: str) -> torch.Tensor:
    """Per-row reconstruction loss of ``x`` under one modality's likelihood.

    ``raw_dec_out`` is normalised by ``_unwrap_decoder_out``. The likelihood
    name is lower-cased and stripped, and ``None`` or empty means
    ``"gaussian"``. Dispatch:

    - Categorical (per ``_is_categorical_likelihood``): masked cross-entropy
      on ``"logits"``. Masked-out rows score 0.
    - ``gaussian``/``normal``/``mse``/``gaussian_diag``:
      - With ``"mean"`` and ``"logvar"`` from the decoder: diagonal Gaussian
        NLL without the constant term.
      - Otherwise: squared error, averaged over the last dim for ``mse`` and
        summed for the other names.
    - ``bernoulli``: BCE with logits.
    - ``poisson``: Poisson NLL on ``"log_rate"``.
    - ``nb``/``negative_binomial`` and ``zinb``/``zero_inflated_negative_binomial``:
      ``_nb_nll`` / ``_zinb_nll``. 1-D ``log_theta``/``logit_pi`` are
      broadcast over the batch.
    - Anything else: summed squared error against ``"mean"`` (or the raw output).

    Raises:
        ValueError: if NB/ZINB decoder outputs lack their required keys.
    """
    likelihood = (likelihood or "gaussian").lower().strip()
    dec_out = self._unwrap_decoder_out(raw_dec_out)
    is_dict = isinstance(dec_out, dict)

    def field_or_raw(key: str) -> Any:
        # The named entry when the decoder returned a dict containing it, else the raw output.
        return dec_out[key] if is_dict and key in dec_out else dec_out

    if self._is_categorical_likelihood(likelihood):
        mod_cfg = self.mod_cfg_by_name[mod_name]
        n_classes = int(mod_cfg.input_dim)
        ignore = int(getattr(mod_cfg, "ignore_index", self.label_ignore_index))

        cat_logits = dec_out["logits"] if is_dict else dec_out
        targets, valid = self._categorical_targets_and_mask(x, n_classes=n_classes, ignore_index=ignore)

        per_row = torch.zeros(targets.shape[0], device=cat_logits.device)
        if valid.any():
            per_row[valid] = F.cross_entropy(cat_logits[valid], targets[valid], reduction="none")
        return per_row

    if likelihood in ("gaussian", "normal", "mse", "gaussian_diag"):
        if is_dict and "mean" in dec_out and "logvar" in dec_out:
            log_var = dec_out["logvar"].clamp(self.LOGVAR_MIN, self.LOGVAR_MAX)
            sq_err = (x - dec_out["mean"]) ** 2
            return (0.5 * (log_var + sq_err / (torch.exp(log_var) + self.EPS))).sum(dim=-1)
        sq_err = (x - field_or_raw("mean")) ** 2
        return sq_err.mean(dim=-1) if likelihood == "mse" else sq_err.sum(dim=-1)

    if likelihood == "bernoulli":
        logits = dec_out["logits"] if is_dict else dec_out
        return F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(dim=-1)

    if likelihood == "poisson":
        per_elem = F.poisson_nll_loss(field_or_raw("log_rate"), x, log_input=True, full=False, reduction="none")
        return per_elem.sum(dim=-1)

    if likelihood in ("nb", "negative_binomial"):
        if not (is_dict and "mu" in dec_out and "log_theta" in dec_out):
            raise ValueError(f"NB recon expects dict with keys ('mu','log_theta'); got {type(dec_out)}")
        mu = dec_out["mu"]
        theta = torch.exp(dec_out["log_theta"])
        if theta.dim() == 1:
            theta = theta.unsqueeze(0).expand_as(mu)
        return self._nb_nll(x, mu, theta, eps=self.EPS).sum(dim=-1)

    if likelihood in ("zinb", "zero_inflated_negative_binomial"):
        required = ("mu", "log_theta", "logit_pi")
        if not (is_dict and all(key in dec_out for key in required)):
            raise ValueError("ZINB recon expects dict with keys ('mu','log_theta','logit_pi').")
        mu = dec_out["mu"]
        theta = torch.exp(dec_out["log_theta"])
        logit_pi = dec_out["logit_pi"]
        if theta.dim() == 1:
            theta = theta.unsqueeze(0).expand_as(mu)
        if logit_pi.dim() == 1:
            logit_pi = logit_pi.unsqueeze(0).expand_as(mu)
        return self._zinb_nll(x, mu, theta, logit_pi, eps=self.EPS).sum(dim=-1)

    # Unknown likelihood names fall back to summed squared error.
    return ((x - field_or_raw("mean")) ** 2).sum(dim=-1)
|
|
1172
|
+
|
|
1173
|
+
# ------------------------------ annealing ------------------------------
|
|
1174
|
+
|
|
1175
|
+
def _anneal_weight(self, epoch: int, start: int, end: int, max_val: float) -> float:
    """Linear warm-up weight: 0 up to ``start``, ramping to ``max_val`` at ``end``.

    ``start`` and ``end`` are cast to ``int``. A degenerate schedule
    (``end <= start``) always yields ``max_val``.
    """
    lo, hi = int(start), int(end)
    if hi <= lo or epoch >= hi:
        return float(max_val)
    if epoch <= lo:
        return 0.0
    progress = (epoch - lo) / float(hi - lo)
    return float(max_val) * float(progress)
|
|
1186
|
+
|
|
1187
|
+
# --------------------------- convenience API ---------------------------
|
|
1188
|
+
|
|
1189
|
+
@property
def device(self) -> torch.device:
    """Device of the first registered parameter (``StopIteration`` if there are none)."""
    params = self.parameters()
    first = next(params)
    return first.device
|
|
1192
|
+
|
|
1193
|
+
@torch.no_grad()
def encode_fused(
    self,
    x_dict: Dict[str, torch.Tensor],
    *,
    epoch: int = 0,
    y: Optional[YType] = None,
    use_mean: bool = True,
    inject_label_expert: bool = True,
    attn_bias_cfg: Optional[Mapping[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Infer the joint latent posterior without tracking gradients.

    Uses the fused multimodal transformer when configured and usable for
    ``x_dict``; otherwise it combines per-modality experts with
    ``mixture_of_experts``. If ``inject_label_expert`` is set, a label
    encoder exists, labels are given and ``epoch >= label_encoder_warmup``,
    the labels are combined with the base posterior as an extra expert.

    Returns:
        ``(mu_z, logvar_z, z)``. ``z`` is ``mu_z`` when ``use_mean`` is set,
        otherwise a reparameterised sample.

    Raises:
        ValueError: if ``encode_modalities`` returned no modality.
    """
    mu_dict, logvar_dict = self.encode_modalities(x_dict, attn_bias_cfg=attn_bias_cfg)
    if not mu_dict:
        raise ValueError("At least one modality must be present in x_dict.")

    fused_ok = (self.fused_encoder_type == "multimodal_transformer") and self._can_use_fused_encoder(x_dict)
    if fused_ok:
        posterior = self._compute_fused_posterior(x_dict, attn_bias_cfg=attn_bias_cfg)
    else:
        posterior = self.mixture_of_experts(mu_dict, logvar_dict)
    mu_z, logvar_z = posterior

    y_legacy = self._extract_legacy_y(y)
    wants_label_expert = (
        inject_label_expert
        and (self.label_encoder is not None)
        and (y_legacy is not None)
        and (epoch >= self.label_encoder_warmup)
    )
    if wants_label_expert:
        mu_y, logvar_y = self._encode_labels_as_expert(y=y_legacy, B=mu_z.shape[0], device=mu_z.device)
        mu_z, logvar_z = self.mixture_of_experts(
            {"__base__": mu_z, "__label__": mu_y},
            {"__base__": logvar_z, "__label__": logvar_y},
        )

    if use_mean:
        return mu_z, logvar_z, mu_z
    return mu_z, logvar_z, self._reparameterize(mu_z, logvar_z)
|
|
1230
|
+
|
|
1231
|
+
@torch.no_grad()
def predict_heads(
    self,
    x_dict: Dict[str, torch.Tensor],
    *,
    epoch: int = 0,
    y: Optional[YType] = None,
    use_mean: bool = True,
    inject_label_expert: bool = True,
    return_probs: bool = True,
    attn_bias_cfg: Optional[Mapping[str, Any]] = None,
) -> Dict[str, torch.Tensor]:
    """Run every classification head on the latent from ``encode_fused``.

    The legacy ``label_decoder`` (if any) is stored under
    ``self.label_head_name``. Each entry of ``self.class_heads`` is stored
    under its own name. Each head reads ``mu_z`` or ``z`` according to its
    ``classify_from_mu`` / ``from_mu`` setting. Values are softmax
    probabilities when ``return_probs`` is set, otherwise raw logits.
    """
    mu_z, _logvar_z, z = self.encode_fused(
        x_dict,
        epoch=epoch,
        y=y,
        use_mean=use_mean,
        inject_label_expert=inject_label_expert,
        attn_bias_cfg=attn_bias_cfg,
    )

    def finalize(head_out: Any) -> torch.Tensor:
        # Accept either a bare logits tensor or a dict carrying "logits".
        logits = head_out["logits"] if isinstance(head_out, dict) and "logits" in head_out else head_out
        return F.softmax(logits, dim=-1) if return_probs else logits

    predictions: Dict[str, torch.Tensor] = {}

    if self.label_decoder is not None:
        latent = mu_z if self.classify_from_mu else z
        predictions[self.label_head_name] = finalize(self.label_decoder(latent))

    for head_name, head in self.class_heads.items():
        from_mu = bool(self.class_heads_cfg[head_name]["from_mu"])
        predictions[head_name] = finalize(head(mu_z if from_mu else z))

    return predictions
|
|
1267
|
+
|
|
1268
|
+
def get_classification_meta(self) -> Dict[str, Any]:
    """Describe the classification setup as a plain, serialisable dict.

    Returns:
        A dict with these keys:

        - ``"label_head_name"``.
        - ``"legacy"``: class count and ignore index, plus ``"label_names"``
          when ``self.label_names`` is not ``None``.
        - ``"multi"``: a copy of each head's config under ``"heads"``, plus
          ``"label_names"`` when ``self.head_label_names`` is non-empty.
    """
    head_name = self.label_head_name
    legacy: Dict[str, Any] = {
        "n_label_classes": int(self.n_label_classes),
        "label_ignore_index": int(self.label_ignore_index),
    }
    multi: Dict[str, Any] = {
        "heads": {name: dict(cfg) for name, cfg in self.class_heads_cfg.items()},
    }

    if self.label_names is not None:
        legacy["label_names"] = list(self.label_names)
    if self.head_label_names:
        multi["label_names"] = {name: list(names) for name, names in self.head_label_names.items()}

    return {"label_head_name": head_name, "legacy": legacy, "multi": multi}
|
|
1284
|
+
|