univi 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
univi/trainer.py ADDED
@@ -0,0 +1,478 @@
+ # univi/trainer.py
+
+ from __future__ import annotations
+
+ from dataclasses import asdict
+ from typing import Any, Dict, List, Optional, Tuple, Union, Mapping
+
+ import contextlib
+
+ import anndata as ad
+ import numpy as np
+ import scipy.sparse as sp
+ import torch
+ from torch import nn, optim
+ from torch.utils.data import DataLoader
+ from tqdm.auto import tqdm
+
+ from .config import TrainingConfig
+ from .utils.io import restore_checkpoint, save_checkpoint
+ from .utils.logging import get_logger
+
+ YType = Union[torch.Tensor, Dict[str, torch.Tensor]]
+ BatchType = Union[Dict[str, torch.Tensor], Tuple[Dict[str, torch.Tensor], YType]]
+
+
+ class UniVITrainer:
+     """
+     Lightweight training loop for UniVI models.
+
+     Supports:
+       - mixed precision (AMP) via use_amp + amp_dtype
+       - optional coordinate metadata and attention-bias configuration for tokenized ATAC
+         (passed through to model forward when supported)
+       - checkpoint save/load via utils.io helpers
+     """
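+
+     # Usage sketch (illustrative, not executed by the package): the model object,
+     # loaders, and config values below are assumptions; only the UniVITrainer
+     # arguments shown come from this file.
+     #
+     #   trainer = UniVITrainer(
+     #       model=model,                    # any nn.Module with a UniVI-style forward
+     #       train_loader=train_loader,
+     #       val_loader=val_loader,
+     #       train_cfg=TrainingConfig(),     # defaults; see univi.config
+     #       use_amp=True,                   # autocast on CUDA
+     #       amp_dtype="bf16",               # "fp16" uses a GradScaler, "bf16" does not
+     #   )
+     #   history = trainer.fit()             # dict with "train_loss", "val_loss", "beta", "gamma"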
+
+     def __init__(
+         self,
+         model: nn.Module,
+         train_loader: DataLoader,
+         val_loader: Optional[DataLoader] = None,
+         train_cfg: Optional[TrainingConfig] = None,
+         device: Optional[str] = None,
+         use_amp: bool = False,
+         amp_dtype: str = "fp16",  # "fp16" or "bf16"
+         *,
+         feature_coords: Optional[Dict[str, Dict[str, Any]]] = None,
+         attn_bias_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
+     ):
+         self.model = model
+         self.train_loader = train_loader
+         self.val_loader = val_loader
+
+         self.cfg = train_cfg or TrainingConfig()
+         self.device = device or self.cfg.device
+
+         self.use_amp = bool(use_amp)
+         self.amp_dtype = str(amp_dtype).lower().strip()
+         if self.amp_dtype not in ("fp16", "bf16"):
+             raise ValueError(f"amp_dtype must be 'fp16' or 'bf16', got {amp_dtype!r}")
+
+         # Optional genomic context / transformer bias configuration (safe defaults).
+         # These are *data-derived* and intentionally not serialized by default.
+         self.feature_coords: Dict[str, Dict[str, Any]] = feature_coords or {}
+         self.attn_bias_cfg: Dict[str, Dict[str, Any]] = attn_bias_cfg or {}
+
+         self.model.to(self.device)
+
+         self.optimizer = optim.Adam(
+             self.model.parameters(),
+             lr=float(self.cfg.lr),
+             weight_decay=float(self.cfg.weight_decay),
+         )
+
+         self._scaler: Optional[torch.cuda.amp.GradScaler] = None
+         if self.use_amp and torch.cuda.is_available() and str(self.device).startswith("cuda"):
+             # GradScaler is only used for fp16. bf16 autocast generally does not need scaling.
+             self._scaler = torch.cuda.amp.GradScaler(enabled=(self.amp_dtype == "fp16"))
+
+         self.logger = get_logger("UniVITrainer")
+
+         self.best_val_loss = float("inf")
+         self.best_state_dict: Optional[Dict[str, Any]] = None
+         self.epochs_no_improve = 0
+         self.best_epoch: Optional[int] = None
+
+         self.history: Dict[str, List[float]] = {
+             "train_loss": [],
+             "val_loss": [],
+             "beta": [],
+             "gamma": [],
+         }
+
+         self._log_config()
+
+     # ------------------------------ training ------------------------------
+
+     def fit(self) -> Dict[str, List[float]]:
+         epoch_iter = tqdm(range(1, int(self.cfg.n_epochs) + 1), desc="Training UniVI", leave=True)
+
+         for epoch in epoch_iter:
+             tr_loss, tr_beta, tr_gamma = self._run_one_epoch(epoch, train=True)
+             self.history["train_loss"].append(tr_loss)
+             self.history["beta"].append(tr_beta)
+             self.history["gamma"].append(tr_gamma)
+
+             if self.val_loader is not None:
+                 va_loss, _, _ = self._run_one_epoch(epoch, train=False)
+                 self.history["val_loss"].append(va_loss)
+
+                 epoch_iter.set_postfix(
+                     train_loss="%.4f" % tr_loss,
+                     val_loss="%.4f" % va_loss,
+                     beta="%.3f" % tr_beta,
+                     gamma="%.3f" % tr_gamma,
+                 )
+
+                 self._maybe_early_stop(va_loss, epoch)
+                 if self._should_stop():
+                     self.logger.info(
+                         "Early stopping at epoch %d (best val loss=%.4f)" % (epoch, self.best_val_loss)
+                     )
+                     break
+             else:
+                 epoch_iter.set_postfix(
+                     train_loss="%.4f" % tr_loss,
+                     beta="%.3f" % tr_beta,
+                     gamma="%.3f" % tr_gamma,
+                 )
+
+         if self.val_loader is not None and self.best_state_dict is not None:
+             self.model.load_state_dict(self.best_state_dict)
+             if self.best_epoch is not None:
+                 self.logger.info(
+                     "Restored best model from epoch %d (val loss=%.4f)"
+                     % (self.best_epoch, self.best_val_loss)
+                 )
+
+         return self.history
+
+     # ------------------------------ model state handling ------------------------------
+
+     def state_dict(self) -> Dict[str, Any]:
+         # Keep this small. Large dataset-derived items (coords) are not stored here.
+         return {
+             "history": self.history,
+             "best_val_loss": float(self.best_val_loss),
+             "best_epoch": self.best_epoch,
+             "epochs_no_improve": int(self.epochs_no_improve),
+             "use_amp": bool(self.use_amp),
+             "amp_dtype": str(self.amp_dtype),
+         }
+
+     def save(self, path: str, *, extra: Optional[Dict[str, Any]] = None, save_best: bool = False) -> None:
+         model_state = self.best_state_dict if (save_best and self.best_state_dict is not None) else self.model.state_dict()
+
+         scaler_state = None
+         if self._scaler is not None and self._scaler.is_enabled():
+             try:
+                 scaler_state = self._scaler.state_dict()
+             except Exception:
+                 scaler_state = None
+
+         save_checkpoint(
+             path,
+             model_state=model_state,
+             optimizer_state=self.optimizer.state_dict(),
+             extra=extra,
+             model=self.model,
+             trainer_state=self.state_dict(),
+             scaler_state=scaler_state,
+         )
+
+     def load(
+         self,
+         path: str,
+         *,
+         map_location: Union[str, torch.device, None] = "cpu",
+         strict: bool = True,
+         restore_label_names: bool = True,
+         enforce_label_compat: bool = True,
+     ) -> Dict[str, Any]:
+         payload = restore_checkpoint(
+             path,
+             model=self.model,
+             optimizer=self.optimizer,
+             scaler=self._scaler,
+             map_location=map_location,
+             strict=strict,
+             restore_label_names=restore_label_names,
+             enforce_label_compat=enforce_label_compat,
+         )
+
+         ts = payload.get("trainer_state", None)
+         if isinstance(ts, dict):
+             self.history = ts.get("history", self.history)
+             self.best_val_loss = float(ts.get("best_val_loss", self.best_val_loss))
+             self.best_epoch = ts.get("best_epoch", self.best_epoch)
+             self.epochs_no_improve = int(ts.get("epochs_no_improve", self.epochs_no_improve))
+
+         self.model.to(self.device)
+         return payload
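+
+     # Round-trip sketch (the checkpoint path is an assumption): save() writes model,
+     # optimizer, and scaler state plus the small trainer_state dict above; load()
+     # restores them in place and returns the raw payload.
+     #
+     #   trainer.save("runs/univi.pt", save_best=True)
+     #   payload = trainer.load("runs/univi.pt", map_location="cpu")
+     #   trainer.history["val_loss"]    # restored from trainer_state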
+
+     # ------------------------------ batch handling ------------------------------
+
+     def _split_batch(self, batch: BatchType) -> Tuple[Dict[str, torch.Tensor], Optional[YType]]:
+         """
+         Accepts either:
+           - x_dict
+           - (x_dict, y) where y is either:
+               * LongTensor (B,)              [back-compat]
+               * dict[str -> LongTensor(B,)]  [multi-head]
+
+         Moves tensors to device. Does NOT force-cast x modalities (keeps float32, float16, etc).
+         Ensures y is long if provided.
+         """
+         if isinstance(batch, (tuple, list)) and len(batch) == 2:
+             x_dict, y = batch
+         else:
+             x_dict, y = batch, None
+
+         if not isinstance(x_dict, dict):
+             raise TypeError(f"Expected batch to be dict or (dict, y). Got {type(x_dict)!r}")
+
+         x_out: Dict[str, torch.Tensor] = {}
+         for k, v in x_dict.items():
+             if v is None:
+                 x_out[k] = None  # type: ignore[assignment]
+                 continue
+             if torch.is_tensor(v):
+                 x_out[k] = v.to(self.device, non_blocking=True)
+             else:
+                 x_out[k] = torch.as_tensor(v, dtype=torch.float32, device=self.device)
+
+         y_out: Optional[YType] = None
+         if y is not None:
+             if isinstance(y, Mapping):
+                 yd: Dict[str, torch.Tensor] = {}
+                 for hk, hv in y.items():
+                     if hv is None:
+                         yd[str(hk)] = None  # type: ignore[assignment]
+                         continue
+                     if not torch.is_tensor(hv):
+                         hv = torch.as_tensor(hv)
+                     yd[str(hk)] = hv.long().to(self.device, non_blocking=True)
+                 y_out = yd
+             else:
+                 if not torch.is_tensor(y):
+                     y = torch.as_tensor(y)
+                 y_out = y.long().to(self.device, non_blocking=True)
+
+         return x_out, y_out
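+
+     # The three batch layouts accepted above, as a sketch (tensor names, modality
+     # keys, and head names are illustrative; they depend on your dataset):
+     #
+     #   batch = {"rna": x_rna, "atac": x_atac}                      # x_dict only
+     #   batch = ({"rna": x_rna}, y)                                 # y: LongTensor (B,)
+     #   batch = ({"rna": x_rna}, {"cell_type": y1, "batch": y2})    # y: dict of LongTensors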
+
+     # ------------------------------ forward wrappers ------------------------------
+
+     def _forward_model(self, x_dict: Dict[str, torch.Tensor], epoch: int, y: Optional[YType]):
+         """
+         Best-effort forward dispatch for multiple model signatures.
+
+         Tries newest first:
+             model(x, epoch=..., y=..., feature_coords=..., attn_bias_cfg=...)
+
+         Then falls back to:
+             model(x, epoch=..., y=...)
+             model(x, epoch=...)
+             model(x, y=...)
+             model(x)
+         """
+         fc = self.feature_coords if self.feature_coords else None
+         ab = self.attn_bias_cfg if self.attn_bias_cfg else None
+
+         # Newest signature (includes optional bias plumbing)
+         try:
+             return self.model(x_dict, epoch=epoch, y=y, feature_coords=fc, attn_bias_cfg=ab)
+         except TypeError:
+             pass
+
+         # Next: epoch + y
+         try:
+             return self.model(x_dict, epoch=epoch, y=y)
+         except TypeError:
+             pass
+
+         # epoch only
+         try:
+             return self.model(x_dict, epoch=epoch)
+         except TypeError:
+             pass
+
+         # y only
+         if y is not None:
+             try:
+                 return self.model(x_dict, y=y)
+             except TypeError:
+                 pass
+
+         # oldest
+         return self.model(x_dict)
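+
+     # What the trainer expects back from the model forward, sketched from how
+     # _run_one_epoch reads it (other keys are ignored; exact contents depend on
+     # the UniVI model implementation):
+     #
+     #   out = {
+     #       "loss": ...,          # scalar tensor; backpropagated when training
+     #       "loss_fixed": ...,    # optional; preferred for validation if present
+     #       "beta_used": ...,     # optional; falls back to "beta", averaged per epoch
+     #       "gamma_used": ...,    # optional; falls back to "gamma", averaged per epoch
+     #   }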
+
+     @staticmethod
+     def _as_float(v: Any) -> float:
+         if v is None:
+             return 0.0
+         if isinstance(v, (float, int)):
+             return float(v)
+         if torch.is_tensor(v):
+             return float(v.detach().cpu().item())
+         try:
+             return float(v)
+         except Exception:
+             return 0.0
+
+     def _amp_context(self):
+         if not (self.use_amp and torch.cuda.is_available() and str(self.device).startswith("cuda")):
+             return contextlib.nullcontext()
+         dtype = torch.float16 if self.amp_dtype == "fp16" else torch.bfloat16
+         return torch.autocast(device_type="cuda", dtype=dtype, enabled=True)
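+
+     # AMP behavior in one place (sketch): autocast only engages on CUDA when
+     # use_amp=True; "fp16" pairs it with the GradScaler created in __init__, while
+     # "bf16" autocasts without scaling, so the plain backward()/step() path is used.
+     #
+     #   with trainer._amp_context():    # nullcontext() on CPU or with use_amp=False
+     #       out = model(x_dict)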
+     # ------------------------------ epoch loop ------------------------------
+
+     def _run_one_epoch(self, epoch: int, train: bool = True) -> Tuple[float, float, float]:
+         if train:
+             self.model.train()
+             loader = self.train_loader
+         else:
+             self.model.eval()
+             if self.val_loader is None:
+                 raise ValueError("val_loader is None but train=False")
+             loader = self.val_loader
+
+         total_loss = 0.0
+         total_beta = 0.0
+         total_gamma = 0.0
+         n_batches = 0
+
+         for batch in loader:
+             x_dict, y = self._split_batch(batch)
+
+             if train:
+                 self.optimizer.zero_grad(set_to_none=True)
+
+             with torch.set_grad_enabled(train):
+                 with self._amp_context():
+                     out = self._forward_model(x_dict, epoch=epoch, y=y)
+                     loss = out["loss"] if train else out.get("loss_fixed", out["loss"])
+
+                 if train:
+                     if self._scaler is not None and self._scaler.is_enabled():
+                         self._scaler.scale(loss).backward()
+                         if self.cfg.grad_clip is not None and float(self.cfg.grad_clip) > 0:
+                             self._scaler.unscale_(self.optimizer)
+                             nn.utils.clip_grad_norm_(self.model.parameters(), float(self.cfg.grad_clip))
+                         self._scaler.step(self.optimizer)
+                         self._scaler.update()
+                     else:
+                         loss.backward()
+                         if self.cfg.grad_clip is not None and float(self.cfg.grad_clip) > 0:
+                             nn.utils.clip_grad_norm_(self.model.parameters(), float(self.cfg.grad_clip))
+                         self.optimizer.step()
+
+             total_loss += float(loss.detach().cpu().item())
+             total_beta += self._as_float(out.get("beta_used", out.get("beta", 0.0)))
+             total_gamma += self._as_float(out.get("gamma_used", out.get("gamma", 0.0)))
+             n_batches += 1
+
+         avg_loss = total_loss / max(1, n_batches)
+         avg_beta = total_beta / max(1, n_batches)
+         avg_gamma = total_gamma / max(1, n_batches)
+
+         if epoch % int(self.cfg.log_every) == 0 or epoch == 1:
+             self.logger.info(
+                 "[Epoch %03d] %s loss=%.4f (beta=%.3f, gamma=%.3f)"
+                 % (epoch, "Train" if train else "Val", avg_loss, avg_beta, avg_gamma)
+             )
+
+         return avg_loss, avg_beta, avg_gamma
+
+     # ------------------------------ early stopping ------------------------------
+
+     def _maybe_early_stop(self, val_loss: float, epoch: int) -> None:
+         if not self.cfg.early_stopping:
+             return
+
+         improved = (self.best_val_loss - float(val_loss)) > float(self.cfg.min_delta)
+         if improved:
+             self.best_val_loss = float(val_loss)
+             self.best_state_dict = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}
+             self.best_epoch = int(epoch)
+             self.epochs_no_improve = 0
+             self.logger.info("[Epoch %03d] New best val loss: %.4f" % (epoch, val_loss))
+         else:
+             self.epochs_no_improve += 1
+
+     def _should_stop(self) -> bool:
+         if not self.cfg.early_stopping:
+             return False
+         return int(self.epochs_no_improve) >= int(self.cfg.patience)
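+
+     # Early stopping is driven entirely by TrainingConfig fields referenced above
+     # (early_stopping, min_delta, patience). A sketch, assuming the dataclass
+     # accepts these as keyword arguments; defaults live in univi.config:
+     #
+     #   cfg = TrainingConfig(n_epochs=200, early_stopping=True, patience=20, min_delta=1e-4)
+     #   trainer = UniVITrainer(model, train_loader, val_loader, train_cfg=cfg)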
+
+     # ------------------------------ encoding utility ------------------------------
+
+     def encode_modality(
+         self,
+         adata: ad.AnnData,
+         modality: str,
+         layer: Optional[str] = None,
+         X_key: str = "X",
+         obs_key: Optional[str] = None,
+         batch_size: int = 512,
+         *,
+         use_moe: bool = True,
+     ) -> np.ndarray:
+         """
+         Encode a single-modality AnnData into latent means.
+
+         By default, this returns the fused mean via MoE/PoE if available (use_moe=True),
+         which is identical to the modality posterior when only one modality is provided.
+
+         If you want strictly the per-modality posterior mean, set use_moe=False.
+         """
+         names = getattr(self.model, "modality_names", None)
+         if names is not None and modality not in names:
+             raise ValueError("Unknown modality %r. Available: %s" % (modality, names))
+
+         self.model.eval()
+
+         if obs_key is not None:
+             if obs_key not in adata.obs:
+                 raise KeyError(f"obs_key={obs_key!r} not found in adata.obs.")
+             col = adata.obs[obs_key]
+             if hasattr(col, "cat"):
+                 vals = col.cat.codes.to_numpy()
+             else:
+                 vals = np.asarray(col.values)
+             X = vals.astype(np.float32).reshape(-1, 1)
+         else:
+             if X_key != "X":
+                 if X_key not in adata.obsm:
+                     raise KeyError(f"X_key={X_key!r} not in adata.obsm.")
+                 X = adata.obsm[X_key]
+             else:
+                 X = adata.layers[layer] if layer is not None else adata.X
+
+         if sp.issparse(X):
+             X = X.toarray()
+         X = np.asarray(X, dtype=np.float32)
+
+         zs = []
+         dev = torch.device(self.device)
+
+         with torch.no_grad():
+             for start in range(0, X.shape[0], int(batch_size)):
+                 end = min(start + int(batch_size), X.shape[0])
+                 xb = torch.as_tensor(X[start:end], dtype=torch.float32, device=dev)
+
+                 mu_dict, logvar_dict = self.model.encode_modalities({modality: xb})
+
+                 if use_moe:
+                     if hasattr(self.model, "mixture_of_experts"):
+                         mu_z, _ = self.model.mixture_of_experts(mu_dict, logvar_dict)
+                     elif hasattr(self.model, "fuse_posteriors"):
+                         mu_z, _ = self.model.fuse_posteriors(mu_dict, logvar_dict)
+                     else:
+                         mu_z = mu_dict[modality]
+                 else:
+                     mu_z = mu_dict[modality]
+
+                 zs.append(mu_z.detach().cpu().numpy())
+
+         return np.vstack(zs)
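+
+     # Sketch of the intended use (the AnnData object, modality name, and obsm key
+     # below are assumptions, not part of this file):
+     #
+     #   Z = trainer.encode_modality(adata_rna, modality="rna", batch_size=512)
+     #   adata_rna.obsm["X_univi"] = Z    # latent means, one row per cell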
+
+     # ------------------------------ logging ------------------------------
+
+     def _log_config(self) -> None:
+         cfg_dict = asdict(self.cfg)
+         self.logger.info("TrainingConfig:")
+         for k, v in cfg_dict.items():
+             self.logger.info("  %s: %r" % (k, v))
+
@@ -0,0 +1,5 @@
+ from .io import write_univi_latent
+
+ __all__ = [
+     "write_univi_latent",
+ ]