univi 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- univi/__init__.py +120 -0
- univi/__main__.py +5 -0
- univi/cli.py +60 -0
- univi/config.py +340 -0
- univi/data.py +345 -0
- univi/diagnostics.py +130 -0
- univi/evaluation.py +632 -0
- univi/hyperparam_optimization/__init__.py +17 -0
- univi/hyperparam_optimization/common.py +339 -0
- univi/hyperparam_optimization/run_adt_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_atac_hparam_search.py +109 -0
- univi/hyperparam_optimization/run_citeseq_hparam_search.py +137 -0
- univi/hyperparam_optimization/run_multiome_hparam_search.py +145 -0
- univi/hyperparam_optimization/run_rna_hparam_search.py +111 -0
- univi/hyperparam_optimization/run_teaseq_hparam_search.py +146 -0
- univi/interpretability.py +399 -0
- univi/matching.py +394 -0
- univi/models/__init__.py +8 -0
- univi/models/decoders.py +249 -0
- univi/models/encoders.py +848 -0
- univi/models/mlp.py +36 -0
- univi/models/tokenizers.py +376 -0
- univi/models/transformer.py +249 -0
- univi/models/univi.py +1284 -0
- univi/objectives.py +46 -0
- univi/pipeline.py +194 -0
- univi/plotting.py +126 -0
- univi/trainer.py +478 -0
- univi/utils/__init__.py +5 -0
- univi/utils/io.py +621 -0
- univi/utils/logging.py +16 -0
- univi/utils/seed.py +18 -0
- univi/utils/stats.py +23 -0
- univi/utils/torch_utils.py +23 -0
- univi-0.3.4.dist-info/METADATA +908 -0
- univi-0.3.4.dist-info/RECORD +40 -0
- univi-0.3.4.dist-info/WHEEL +5 -0
- univi-0.3.4.dist-info/entry_points.txt +2 -0
- univi-0.3.4.dist-info/licenses/LICENSE +21 -0
- univi-0.3.4.dist-info/top_level.txt +1 -0
univi/trainer.py
ADDED
|
@@ -0,0 +1,478 @@
|
|
|
1
|
+
# univi/trainer.py
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import asdict
|
|
6
|
+
from typing import Any, Dict, List, Optional, Tuple, Union, Mapping
|
|
7
|
+
|
|
8
|
+
import contextlib
|
|
9
|
+
|
|
10
|
+
import anndata as ad
|
|
11
|
+
import numpy as np
|
|
12
|
+
import scipy.sparse as sp
|
|
13
|
+
import torch
|
|
14
|
+
from torch import nn, optim
|
|
15
|
+
from torch.utils.data import DataLoader
|
|
16
|
+
from tqdm.auto import tqdm
|
|
17
|
+
|
|
18
|
+
from .config import TrainingConfig
|
|
19
|
+
from .utils.io import restore_checkpoint, save_checkpoint
|
|
20
|
+
from .utils.logging import get_logger
|
|
21
|
+
|
|
22
|
+
# Label payload accepted by the trainer: either a single LongTensor (B,)
# [back-compat] or a mapping of head name -> LongTensor (B,) [multi-head].
YType = Union[Dict[str, torch.Tensor], torch.Tensor]

# One batch yielded by a loader: the modality dict alone, or a
# (modality dict, labels) pair.
BatchType = Union[Tuple[Dict[str, torch.Tensor], YType], Dict[str, torch.Tensor]]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class UniVITrainer:
    """
    Lightweight training loop for UniVI models.

    Supports:
      - mixed precision (AMP) via use_amp + amp_dtype
      - optional coordinate metadata and attention-bias configuration for tokenized ATAC
        (passed through to model forward when its signature accepts them)
      - checkpoint save/load via utils.io helpers
      - early stopping on validation loss with restoration of the best weights

    The model's forward is expected to return a mapping containing at least "loss";
    "loss_fixed" (validation), "beta"/"beta_used" and "gamma"/"gamma_used" are read
    when present.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        train_cfg: Optional[TrainingConfig] = None,
        device: Optional[str] = None,
        use_amp: bool = False,
        amp_dtype: str = "fp16",  # "fp16" or "bf16"
        *,
        feature_coords: Optional[Dict[str, Dict[str, Any]]] = None,
        attn_bias_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
    ):
        """
        Args:
            model: module to train; it is moved to ``device`` here.
            train_loader: iterable of batches accepted by ``_split_batch``.
            val_loader: optional validation loader; enables val loss tracking,
                early stopping and best-weight restoration in ``fit``.
            train_cfg: training configuration; defaults to ``TrainingConfig()``.
            device: overrides ``train_cfg.device`` when given.
            use_amp: enable autocast; only takes effect on a CUDA device.
            amp_dtype: "fp16" or "bf16" (case/whitespace-insensitive). A GradScaler is
                only enabled for fp16.
            feature_coords: optional per-modality coordinate metadata forwarded to the model.
            attn_bias_cfg: optional per-modality attention-bias config forwarded to the model.

        Raises:
            ValueError: if ``amp_dtype`` is not "fp16" or "bf16".
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.cfg = train_cfg or TrainingConfig()
        self.device = device or self.cfg.device

        self.use_amp = bool(use_amp)
        self.amp_dtype = str(amp_dtype).lower().strip()
        if self.amp_dtype not in ("fp16", "bf16"):
            raise ValueError(f"amp_dtype must be 'fp16' or 'bf16', got {amp_dtype!r}")

        # Optional genomic context / transformer bias configuration (safe defaults).
        # These are *data-derived* and intentionally not serialized by default.
        self.feature_coords: Dict[str, Dict[str, Any]] = feature_coords or {}
        self.attn_bias_cfg: Dict[str, Dict[str, Any]] = attn_bias_cfg or {}

        self.model.to(self.device)

        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=float(self.cfg.lr),
            weight_decay=float(self.cfg.weight_decay),
        )

        self._scaler: Optional[torch.cuda.amp.GradScaler] = None
        if self.use_amp and torch.cuda.is_available() and str(self.device).startswith("cuda"):
            # GradScaler is only used for fp16. bf16 autocast generally does not need scaling.
            self._scaler = torch.cuda.amp.GradScaler(enabled=(self.amp_dtype == "fp16"))

        self.logger = get_logger("UniVITrainer")

        # Early-stopping / best-model bookkeeping (updated by _maybe_early_stop).
        self.best_val_loss = float("inf")
        self.best_state_dict: Optional[Dict[str, Any]] = None
        self.epochs_no_improve = 0
        self.best_epoch: Optional[int] = None

        # Per-epoch curves; "val_loss" is only appended when val_loader is set.
        self.history: Dict[str, List[float]] = {
            "train_loss": [],
            "val_loss": [],
            "beta": [],
            "gamma": [],
        }

        self._log_config()

    # ------------------------------ training ------------------------------

    def fit(self) -> Dict[str, List[float]]:
        """
        Train for ``cfg.n_epochs`` epochs, validating each epoch when a val_loader exists.

        Early stopping and best-weight tracking are only active when
        ``cfg.early_stopping`` is enabled; in that case the best weights are
        loaded back into the model at the end.

        Returns:
            ``self.history`` (train_loss, val_loss, beta, gamma per epoch).
        """
        epoch_iter = tqdm(range(1, int(self.cfg.n_epochs) + 1), desc="Training UniVI", leave=True)

        for epoch in epoch_iter:
            tr_loss, tr_beta, tr_gamma = self._run_one_epoch(epoch, train=True)
            self.history["train_loss"].append(tr_loss)
            self.history["beta"].append(tr_beta)
            self.history["gamma"].append(tr_gamma)

            if self.val_loader is not None:
                va_loss, _, _ = self._run_one_epoch(epoch, train=False)
                self.history["val_loss"].append(va_loss)

                epoch_iter.set_postfix(
                    train_loss="%.4f" % tr_loss,
                    val_loss="%.4f" % va_loss,
                    beta="%.3f" % tr_beta,
                    gamma="%.3f" % tr_gamma,
                )

                self._maybe_early_stop(va_loss, epoch)
                if self._should_stop():
                    self.logger.info(
                        "Early stopping at epoch %d (best val loss=%.4f)" % (epoch, self.best_val_loss)
                    )
                    break
            else:
                epoch_iter.set_postfix(
                    train_loss="%.4f" % tr_loss,
                    beta="%.3f" % tr_beta,
                    gamma="%.3f" % tr_gamma,
                )

        if self.val_loader is not None and self.best_state_dict is not None:
            self.model.load_state_dict(self.best_state_dict)
            if self.best_epoch is not None:
                self.logger.info(
                    "Restored best model from epoch %d (val loss=%.4f)"
                    % (self.best_epoch, self.best_val_loss)
                )

        return self.history

    # ------------------------------ model state handling ------------------------------

    def state_dict(self) -> Dict[str, Any]:
        """Trainer bookkeeping stored in checkpoints (history, best-val tracking, AMP settings)."""
        # Keep this small. Large dataset-derived items (coords) are not stored here.
        return {
            "history": self.history,
            "best_val_loss": float(self.best_val_loss),
            "best_epoch": self.best_epoch,
            "epochs_no_improve": int(self.epochs_no_improve),
            "use_amp": bool(self.use_amp),
            "amp_dtype": str(self.amp_dtype),
        }

    def save(self, path: str, *, extra: Optional[Dict[str, Any]] = None, save_best: bool = False) -> None:
        """
        Write a checkpoint via ``save_checkpoint``.

        Args:
            path: destination path.
            extra: optional extra payload forwarded to ``save_checkpoint``.
            save_best: store the best tracked weights instead of the current ones
                (falls back to the current weights if no best state exists).
        """
        if save_best and self.best_state_dict is not None:
            model_state = self.best_state_dict
        else:
            model_state = self.model.state_dict()

        scaler_state = None
        if self._scaler is not None and self._scaler.is_enabled():
            try:
                scaler_state = self._scaler.state_dict()
            except Exception as exc:
                # Best-effort: a missing scaler state must not prevent saving the model,
                # but the omission should be visible.
                self.logger.warning(
                    "Could not capture GradScaler state (%s); saving checkpoint without it." % (exc,)
                )
                scaler_state = None

        save_checkpoint(
            path,
            model_state=model_state,
            optimizer_state=self.optimizer.state_dict(),
            extra=extra,
            model=self.model,
            trainer_state=self.state_dict(),
            scaler_state=scaler_state,
        )

    def load(
        self,
        path: str,
        *,
        map_location: Union[str, torch.device, None] = "cpu",
        strict: bool = True,
        restore_label_names: bool = True,
        enforce_label_compat: bool = True,
    ) -> Dict[str, Any]:
        """
        Restore model / optimizer / scaler via ``restore_checkpoint``, then trainer bookkeeping.

        The keyword options are forwarded unchanged to ``restore_checkpoint``. If the payload
        contains a "trainer_state" dict, history and best-val tracking are restored from it.
        The model is moved back to ``self.device`` afterwards.

        Returns:
            The raw payload returned by ``restore_checkpoint``.
        """
        payload = restore_checkpoint(
            path,
            model=self.model,
            optimizer=self.optimizer,
            scaler=self._scaler,
            map_location=map_location,
            strict=strict,
            restore_label_names=restore_label_names,
            enforce_label_compat=enforce_label_compat,
        )

        ts = payload.get("trainer_state", None)
        if isinstance(ts, dict):
            self.history = ts.get("history", self.history)
            self.best_val_loss = float(ts.get("best_val_loss", self.best_val_loss))
            self.best_epoch = ts.get("best_epoch", self.best_epoch)
            self.epochs_no_improve = int(ts.get("epochs_no_improve", self.epochs_no_improve))

        self.model.to(self.device)
        return payload

    # ------------------------------ batch handling ------------------------------

    def _split_batch(self, batch: BatchType) -> Tuple[Dict[str, torch.Tensor], Optional[YType]]:
        """
        Accepts either:
          - x_dict
          - (x_dict, y) where y is either:
              * LongTensor (B,)                [back-compat]
              * dict[str -> LongTensor(B,)]    [multi-head]

        Moves tensors to device. Does NOT force-cast x modalities (keeps float32, float16, etc).
        Non-tensor x values are converted to float32 tensors. Ensures y is long if provided.

        Raises:
            TypeError: if the x part of the batch is not a dict.
        """
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            x_dict, y = batch
        else:
            x_dict, y = batch, None

        if not isinstance(x_dict, dict):
            raise TypeError(f"Expected batch to be dict or (dict, y). Got {type(x_dict)!r}")

        x_out: Dict[str, torch.Tensor] = {}
        for k, v in x_dict.items():
            if v is None:
                # Missing modality for this batch: keep the key so the model can see it.
                x_out[k] = None  # type: ignore[assignment]
                continue
            if torch.is_tensor(v):
                x_out[k] = v.to(self.device, non_blocking=True)
            else:
                x_out[k] = torch.as_tensor(v, dtype=torch.float32, device=self.device)

        y_out: Optional[YType] = None
        if y is not None:
            if isinstance(y, Mapping):
                yd: Dict[str, torch.Tensor] = {}
                for hk, hv in y.items():
                    if hv is None:
                        yd[str(hk)] = None  # type: ignore[assignment]
                        continue
                    if not torch.is_tensor(hv):
                        hv = torch.as_tensor(hv)
                    yd[str(hk)] = hv.long().to(self.device, non_blocking=True)
                y_out = yd
            else:
                if not torch.is_tensor(y):
                    y = torch.as_tensor(y)
                y_out = y.long().to(self.device, non_blocking=True)

        return x_out, y_out

    # ------------------------------ forward wrappers ------------------------------

    @staticmethod
    def _forward_signature(model: Any) -> Optional[Any]:
        """
        Return the introspected signature of ``model.forward`` (or of ``model`` itself
        when it has no ``forward``), or None when it cannot be used to decide which
        keyword arguments are accepted: the signature is unavailable, or it takes
        ``**kwargs`` (e.g. a wrapper module that forwards everything to an inner model).
        """
        import inspect

        fn = getattr(model, "forward", model)
        try:
            sig = inspect.signature(fn)
        except (TypeError, ValueError):
            return None
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()):
            return None
        return sig

    def _forward_model(self, x_dict: Dict[str, torch.Tensor], epoch: int, y: Optional[YType]):
        """
        Best-effort forward dispatch for multiple model signatures.

        Candidate keyword sets, newest first:
            model(x, epoch=..., y=..., feature_coords=..., attn_bias_cfg=...)
            model(x, epoch=..., y=...)
            model(x, epoch=...)
            model(x, y=...)          (only when y is not None)
            model(x)

        When ``forward``'s signature is introspectable, the first candidate that binds
        is called and any TypeError it raises propagates (it is a real error inside
        forward, not a signature mismatch). Only when the signature cannot be inspected
        or accepts ``**kwargs`` does this fall back to trying each call and skipping
        ones that raise TypeError.
        """
        fc = self.feature_coords if self.feature_coords else None
        ab = self.attn_bias_cfg if self.attn_bias_cfg else None

        candidates: List[Dict[str, Any]] = [
            {"epoch": epoch, "y": y, "feature_coords": fc, "attn_bias_cfg": ab},
            {"epoch": epoch, "y": y},
            {"epoch": epoch},
        ]
        if y is not None:
            candidates.append({"y": y})

        sig = self._forward_signature(self.model)
        if sig is not None:
            for kwargs in candidates:
                try:
                    sig.bind(x_dict, **kwargs)
                except TypeError:
                    continue
                return self.model(x_dict, **kwargs)
            # oldest
            return self.model(x_dict)

        # Signature unknown / accepts **kwargs: legacy try-and-fall-back cascade.
        for kwargs in candidates:
            try:
                return self.model(x_dict, **kwargs)
            except TypeError:
                pass
        return self.model(x_dict)

    @staticmethod
    def _as_float(v: Any) -> float:
        """Convert a scalar-like (None, number, 0-d/1-element tensor) to float; None/unconvertible -> 0.0."""
        if v is None:
            return 0.0
        if isinstance(v, (float, int)):
            return float(v)
        if torch.is_tensor(v):
            return float(v.detach().cpu().item())
        try:
            return float(v)
        except Exception:
            return 0.0

    def _amp_context(self):
        """Autocast context on CUDA when AMP is enabled, otherwise a no-op context."""
        if not (self.use_amp and torch.cuda.is_available() and str(self.device).startswith("cuda")):
            return contextlib.nullcontext()
        dtype = torch.float16 if self.amp_dtype == "fp16" else torch.bfloat16
        return torch.autocast(device_type="cuda", dtype=dtype, enabled=True)

    # ------------------------------ epoch loop ------------------------------

    def _run_one_epoch(self, epoch: int, train: bool = True) -> Tuple[float, float, float]:
        """
        Run one pass over the train or validation loader.

        Train mode: per batch zero_grad -> forward -> backward -> optional grad-norm
        clipping -> optimizer step (through the GradScaler when it is enabled).
        Eval mode: gradients disabled; ``out["loss_fixed"]`` is preferred over
        ``out["loss"]`` when the model provides it.

        Returns:
            (mean loss, mean beta, mean gamma) over batches; all 0.0 for an empty loader.

        Raises:
            ValueError: if train=False but no val_loader was given.
        """
        if train:
            self.model.train()
            loader = self.train_loader
        else:
            self.model.eval()
            if self.val_loader is None:
                raise ValueError("val_loader is None but train=False")
            loader = self.val_loader

        grad_clip = self.cfg.grad_clip
        do_clip = grad_clip is not None and float(grad_clip) > 0

        total_loss = 0.0
        total_beta = 0.0
        total_gamma = 0.0
        n_batches = 0

        for batch in loader:
            x_dict, y = self._split_batch(batch)

            if train:
                self.optimizer.zero_grad(set_to_none=True)

            with torch.set_grad_enabled(train):
                with self._amp_context():
                    out = self._forward_model(x_dict, epoch=epoch, y=y)
                    loss = out["loss"] if train else out.get("loss_fixed", out["loss"])

                if train:
                    if self._scaler is not None and self._scaler.is_enabled():
                        self._scaler.scale(loss).backward()
                        if do_clip:
                            # Gradients must be unscaled before clipping by norm.
                            self._scaler.unscale_(self.optimizer)
                            nn.utils.clip_grad_norm_(self.model.parameters(), float(grad_clip))
                        self._scaler.step(self.optimizer)
                        self._scaler.update()
                    else:
                        loss.backward()
                        if do_clip:
                            nn.utils.clip_grad_norm_(self.model.parameters(), float(grad_clip))
                        self.optimizer.step()

            total_loss += float(loss.detach().cpu().item())
            total_beta += self._as_float(out.get("beta_used", out.get("beta", 0.0)))
            total_gamma += self._as_float(out.get("gamma_used", out.get("gamma", 0.0)))
            n_batches += 1

        avg_loss = total_loss / max(1, n_batches)
        avg_beta = total_beta / max(1, n_batches)
        avg_gamma = total_gamma / max(1, n_batches)

        # Epoch 1 is always logged; a non-positive log_every disables periodic logging
        # instead of dividing by zero.
        log_every = int(self.cfg.log_every)
        if epoch == 1 or (log_every > 0 and epoch % log_every == 0):
            self.logger.info(
                "[Epoch %03d] %s loss=%.4f (beta=%.3f, gamma=%.3f)"
                % (epoch, "Train" if train else "Val", avg_loss, avg_beta, avg_gamma)
            )

        return avg_loss, avg_beta, avg_gamma

    # ------------------------------ early stopping ------------------------------

    def _maybe_early_stop(self, val_loss: float, epoch: int) -> None:
        """Update best-val tracking; a CPU copy of the weights is kept on improvement > min_delta."""
        if not self.cfg.early_stopping:
            return

        improved = (self.best_val_loss - float(val_loss)) > float(self.cfg.min_delta)
        if improved:
            self.best_val_loss = float(val_loss)
            self.best_state_dict = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}
            self.best_epoch = int(epoch)
            self.epochs_no_improve = 0
            self.logger.info("[Epoch %03d] New best val loss: %.4f" % (epoch, val_loss))
        else:
            self.epochs_no_improve += 1

    def _should_stop(self) -> bool:
        """True once ``cfg.patience`` consecutive epochs passed without improvement."""
        if not self.cfg.early_stopping:
            return False
        return int(self.epochs_no_improve) >= int(self.cfg.patience)

    # ------------------------------ encoding utility ------------------------------

    @staticmethod
    def _select_input_matrix(
        adata: ad.AnnData,
        *,
        layer: Optional[str],
        X_key: str,
        obs_key: Optional[str],
    ) -> Any:
        """
        Pick the encoder input from ``adata``: an obs column (obs_key), an obsm entry
        (X_key != "X"), a layer, or X.

        Sparse inputs are returned in CSR form so they can be row-sliced and densified
        batch by batch; everything else is returned as a dense float32 ndarray.

        Raises:
            KeyError: if obs_key / X_key are not present in adata.obs / adata.obsm.
        """
        if obs_key is not None:
            if obs_key not in adata.obs:
                raise KeyError(f"obs_key={obs_key!r} not found in adata.obs.")
            col = adata.obs[obs_key]
            # Categorical columns are encoded by their integer category codes.
            vals = col.cat.codes.to_numpy() if hasattr(col, "cat") else np.asarray(col.values)
            return vals.astype(np.float32).reshape(-1, 1)

        if X_key != "X":
            if X_key not in adata.obsm:
                raise KeyError(f"X_key={X_key!r} not in adata.obsm.")
            X = adata.obsm[X_key]
        else:
            X = adata.layers[layer] if layer is not None else adata.X

        if sp.issparse(X):
            return X.tocsr()
        return np.asarray(X, dtype=np.float32)

    def _latent_mean(
        self,
        mu_dict: Dict[str, torch.Tensor],
        logvar_dict: Dict[str, torch.Tensor],
        modality: str,
        use_moe: bool,
    ) -> torch.Tensor:
        """Fused mean via the model's MoE/PoE helper when requested and available, else the modality's own mean."""
        if use_moe:
            if hasattr(self.model, "mixture_of_experts"):
                mu_z, _ = self.model.mixture_of_experts(mu_dict, logvar_dict)
                return mu_z
            if hasattr(self.model, "fuse_posteriors"):
                mu_z, _ = self.model.fuse_posteriors(mu_dict, logvar_dict)
                return mu_z
        return mu_dict[modality]

    def encode_modality(
        self,
        adata: ad.AnnData,
        modality: str,
        layer: Optional[str] = None,
        X_key: str = "X",
        obs_key: Optional[str] = None,
        batch_size: int = 512,
        *,
        use_moe: bool = True,
    ) -> np.ndarray:
        """
        Encode a single modality AnnData into latent means.

        By default, this returns the fused mean via MoE/PoE if available (use_moe=True),
        which is identical to the modality posterior when only one modality is provided.

        If you want strictly per-modality posterior mean, set use_moe=False.

        Sparse inputs are densified one batch at a time to bound peak memory. The model
        is put in eval mode for encoding and its previous training mode is restored.

        Returns:
            Array of stacked per-batch latent means, one row per observation.

        Raises:
            ValueError: unknown modality, non-positive batch_size, or no observations.
            KeyError: obs_key / X_key not found.
        """
        names = getattr(self.model, "modality_names", None)
        if names is not None and modality not in names:
            raise ValueError("Unknown modality %r. Available: %s" % (modality, names))

        bs = int(batch_size)
        if bs <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size!r}")

        X = self._select_input_matrix(adata, layer=layer, X_key=X_key, obs_key=obs_key)
        n_obs = int(X.shape[0])
        if n_obs == 0:
            raise ValueError("adata has no observations to encode.")

        dev = torch.device(self.device)
        was_training = bool(getattr(self.model, "training", False))
        self.model.eval()

        zs: List[np.ndarray] = []
        try:
            with torch.no_grad():
                for start in range(0, n_obs, bs):
                    chunk = X[start:start + bs]
                    if sp.issparse(chunk):
                        chunk = chunk.toarray()
                    xb = torch.as_tensor(np.asarray(chunk, dtype=np.float32), device=dev)

                    mu_dict, logvar_dict = self.model.encode_modalities({modality: xb})
                    mu_z = self._latent_mean(mu_dict, logvar_dict, modality, use_moe)
                    zs.append(mu_z.detach().cpu().numpy())
        finally:
            if was_training:
                self.model.train()

        return np.vstack(zs)

    # ------------------------------ logging ------------------------------

    def _log_config(self) -> None:
        """Log every field of the (dataclass) TrainingConfig at INFO level."""
        cfg_dict = asdict(self.cfg)
        self.logger.info("TrainingConfig:")
        for k, v in cfg_dict.items():
            self.logger.info("  %s: %r" % (k, v))