cyclevi 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cyclevi-0.1.0/CycleVI_model.py +1438 -0
- cyclevi-0.1.0/LICENSE +28 -0
- cyclevi-0.1.0/PKG-INFO +103 -0
- cyclevi-0.1.0/README.md +45 -0
- cyclevi-0.1.0/Tutorial.ipynb +887 -0
- cyclevi-0.1.0/Tutorial_colab.ipynb +1352 -0
- cyclevi-0.1.0/cyclevi/__init__.py +27 -0
- cyclevi-0.1.0/cyclevi/data/__init__.py +0 -0
- cyclevi-0.1.0/cyclevi/data/homo_sapiens_cc_genes.csv +98 -0
- cyclevi-0.1.0/cyclevi/model.py +1438 -0
- cyclevi-0.1.0/cyclevi.egg-info/PKG-INFO +103 -0
- cyclevi-0.1.0/cyclevi.egg-info/SOURCES.txt +16 -0
- cyclevi-0.1.0/cyclevi.egg-info/dependency_links.txt +1 -0
- cyclevi-0.1.0/cyclevi.egg-info/requires.txt +5 -0
- cyclevi-0.1.0/cyclevi.egg-info/top_level.txt +1 -0
- cyclevi-0.1.0/data/homo_sapiens_cc_genes.csv +98 -0
- cyclevi-0.1.0/pyproject.toml +46 -0
- cyclevi-0.1.0/setup.cfg +4 -0
|
@@ -0,0 +1,1438 @@
|
|
|
1
|
+
# ─────────────────────────────────────────────────────────────
|
|
2
|
+
# Imports
|
|
3
|
+
# ─────────────────────────────────────────────────────────────
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
import warnings
|
|
7
|
+
from collections.abc import Iterator, Iterable
|
|
8
|
+
from functools import partial
|
|
9
|
+
from numbers import Number
|
|
10
|
+
from typing import Callable, Literal
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
|
|
14
|
+
# PyTorch
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
from torch import Tensor
|
|
18
|
+
from torch.distributions import Distribution, Normal
|
|
19
|
+
from torch.nn.functional import one_hot
|
|
20
|
+
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
21
|
+
|
|
22
|
+
# Single-cell analysis tools
|
|
23
|
+
from anndata import AnnData
|
|
24
|
+
|
|
25
|
+
# scvi-tools
|
|
26
|
+
from scvi import REGISTRY_KEYS, settings
|
|
27
|
+
from scvi.data import AnnDataManager
|
|
28
|
+
from scvi.data._constants import ADATA_MINIFY_TYPE
|
|
29
|
+
from scvi.data._utils import _get_adata_minify_type
|
|
30
|
+
from scvi.data.fields import (
|
|
31
|
+
CategoricalJointObsField,
|
|
32
|
+
CategoricalObsField,
|
|
33
|
+
LayerField,
|
|
34
|
+
NumericalJointObsField,
|
|
35
|
+
NumericalObsField,
|
|
36
|
+
)
|
|
37
|
+
from scvi.distributions._utils import DistributionConcatenator
|
|
38
|
+
from scvi.model._utils import (
|
|
39
|
+
_get_batch_code_from_category,
|
|
40
|
+
_init_library_size,
|
|
41
|
+
scrna_raw_counts_properties,
|
|
42
|
+
)
|
|
43
|
+
from scvi.model.base import (
|
|
44
|
+
ArchesMixin,
|
|
45
|
+
BaseMinifiedModeModelClass,
|
|
46
|
+
EmbeddingMixin,
|
|
47
|
+
RNASeqMixin,
|
|
48
|
+
UnsupervisedTrainingMixin,
|
|
49
|
+
VAEMixin,
|
|
50
|
+
)
|
|
51
|
+
from scvi.model.base._de_core import _de_core
|
|
52
|
+
from scvi.module._constants import MODULE_KEYS
|
|
53
|
+
from scvi.module.base import (
|
|
54
|
+
BaseMinifiedModeModuleClass,
|
|
55
|
+
BaseModuleClass,
|
|
56
|
+
EmbeddingModuleMixin,
|
|
57
|
+
LossOutput,
|
|
58
|
+
auto_move_data,
|
|
59
|
+
)
|
|
60
|
+
from scvi.nn import Encoder, FCLayers
|
|
61
|
+
from scvi.train import (
|
|
62
|
+
TrainingPlan,
|
|
63
|
+
)
|
|
64
|
+
from scvi.utils import (
|
|
65
|
+
setup_anndata_dsp,
|
|
66
|
+
track,
|
|
67
|
+
unsupported_if_adata_minified,
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
# ─────────────────────────────────────────────────────────────
|
|
71
|
+
# Helpers
|
|
72
|
+
# ─────────────────────────────────────────────────────────────
|
|
73
|
+
|
|
74
|
+
def _identity(x):
|
|
75
|
+
return x
|
|
76
|
+
|
|
77
|
+
# Logger setup
|
|
78
|
+
logger = logging.getLogger(__name__)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# ─────────────────────────────────────────────────────────────
|
|
82
|
+
# Optional: Creating cell cycle gene mask
|
|
83
|
+
# ─────────────────────────────────────────────────────────────
|
|
84
|
+
|
|
85
|
+
def create_cell_cycle_gene_mask(adata: AnnData, genes_txt: str, var_column: str | None = None) -> torch.Tensor:
    """
    Create a boolean mask indicating which genes are cell cycle–dependent.

    Gene identifiers are compared case-insensitively and with surrounding
    whitespace removed, both for the gene list and for the AnnData side.

    Parameters
    ----------
    adata : AnnData
        The AnnData object containing single-cell expression data.
    genes_txt : str
        Path to a text file with a list of cell cycle genes (one per line).
        Only the first column is used if the file has several columns.
    var_column : str, optional
        Name of a column in adata.var to use for gene identifiers.
        If None (or empty), use adata.var_names (default).

    Returns
    -------
    torch.Tensor
        A boolean tensor of shape (n_genes,) where True indicates
        that the gene is in the cell cycle gene list.
    """
    # Read every entry as plain text: `dtype=str` prevents numeric identifiers
    # from being parsed as numbers (which would break the `.str` accessor), and
    # `keep_default_na=False` stops symbols such as "NA" from becoming NaN.
    gene_table = pd.read_csv(genes_txt, header=None, dtype=str, keep_default_na=False)
    gene_set = set(gene_table[0].str.strip().str.upper())
    gene_set.discard("")  # short/empty rows must never match anything

    if var_column:
        identifiers = pd.Index(adata.var[var_column].astype(str))
    else:
        identifiers = pd.Index(adata.var_names.astype(str))

    # Vectorised membership test; `Index.isin` returns a NumPy boolean array.
    is_cycle_gene = identifiers.str.strip().str.upper().isin(gene_set)
    return torch.tensor(is_cycle_gene, dtype=torch.bool)
|
|
114
|
+
|
|
115
|
+
# Example usage:
|
|
116
|
+
# Load cell cycle mask from the provided GO annotation file
|
|
117
|
+
# cycle_mask = create_cell_cycle_gene_mask(adata, "GO_cell_cycle_annotation_human.txt")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# ─────────────────────────────────────────────────────────────
|
|
121
|
+
# Cell Cycle Registry Keys
|
|
122
|
+
# ─────────────────────────────────────────────────────────────
|
|
123
|
+
|
|
124
|
+
class CYCLE_REGISTRY_KEYS:
    """Namespace of key names for the cell-cycle specific data fields."""

    # Key under which the discrete cycle label is looked up in a batch.
    CYCLE_LABEL_KEY = "cycle_initiation_label"
    # Key for the cycle angle field (its consumers are not shown in this part of the file).
    CYCLE_ANGLE_KEY = "cycle_initiation_angle"
|
|
127
|
+
|
|
128
|
+
# ─────────────────────────────────────────────────────────────
|
|
129
|
+
# Adversarial Classifier
|
|
130
|
+
# ─────────────────────────────────────────────────────────────
|
|
131
|
+
|
|
132
|
+
class Classifier(nn.Module):
    """
    Simple feedforward neural network for adversarial classification.

    The network is ``n_layers - 1`` blocks of ``Linear -> ReLU`` followed by a
    final ``Linear`` layer that maps to ``n_labels`` scores.

    Args:
        n_input (int): Dimensionality of input features.
        n_hidden (int): Number of units in hidden layers.
        n_labels (int): Number of output classes.
        n_layers (int): Number of linear layers (default: 2).
        logits (bool): Whether to output raw logits (default: True).
            When False, the output is softmax-normalised over the last dimension.
    """

    def __init__(self, n_input, n_hidden, n_labels, n_layers=2, logits=True):
        super().__init__()
        self.logits = logits

        layers = []
        in_dim = n_input
        for _ in range(n_layers - 1):
            layers.extend((nn.Linear(in_dim, n_hidden), nn.ReLU()))
            in_dim = n_hidden
        layers.append(nn.Linear(in_dim, n_labels))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores for ``x``: logits, or probabilities if ``logits=False``."""
        scores = self.network(x)
        if self.logits:
            return scores
        # Honour the `logits` flag, which was previously stored but never used.
        return torch.softmax(scores, dim=-1)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class PhaseAdversarialTrainingPlan(TrainingPlan):
    """
    Training plan with adversarial phase classifier to prevent cycle phase information
    leakage in non-circular latent space (z_other).

    The plan uses manual optimization with two optimizers: one for the module and
    one for a small classifier that predicts a 3-class phase label from
    ``z[:, 2:]`` (the latent vector without its first two, cycle, dimensions).

    Args:
        module (BaseModuleClass): scvi-tools model module.
        scale_adversarial_loss (float | str): Scaling factor for adversarial loss, or 'auto'.
            With 'auto' the factor is ``1 - kl_weight``.
        All other arguments follow scvi-tools TrainingPlan.

    Raises:
        ValueError: If ``scale_adversarial_loss`` is a string other than 'auto'.
    """

    # Extra multiplier applied to the "fool the classifier" term of the module loss.
    _FOOL_LOSS_MULTIPLIER = 1000

    def __init__(
        self,
        module: BaseModuleClass,
        *,
        optimizer: str = "Adam",
        optimizer_creator=None,
        lr: float = 1e-3,
        weight_decay: float = 1e-6,
        n_steps_kl_warmup: int | None = None,
        n_epochs_kl_warmup: int = 400,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 30,
        lr_threshold: float = 0.0,
        lr_scheduler_metric: str = "elbo_validation",
        lr_min: float = 0.0,
        scale_adversarial_loss: float | str = "auto",
        compile: bool = False,
        compile_kwargs: dict | None = None,
        **loss_kwargs,
    ):
        # Fail fast: any other string would only surface later as a TypeError
        # when it is compared with 0 inside `training_step`.
        if isinstance(scale_adversarial_loss, str) and scale_adversarial_loss != "auto":
            raise ValueError("`scale_adversarial_loss` must be a number or 'auto'.")

        super().__init__(
            module=module,
            optimizer=optimizer,
            optimizer_creator=optimizer_creator,
            lr=lr,
            weight_decay=weight_decay,
            n_steps_kl_warmup=n_steps_kl_warmup,
            n_epochs_kl_warmup=n_epochs_kl_warmup,
            reduce_lr_on_plateau=reduce_lr_on_plateau,
            lr_factor=lr_factor,
            lr_patience=lr_patience,
            lr_threshold=lr_threshold,
            lr_scheduler_metric=lr_scheduler_metric,
            lr_min=lr_min,
            compile=compile,
            compile_kwargs=compile_kwargs,
            **loss_kwargs,
        )

        # Setup adversarial classifier (predicts 3 discrete phases)
        self.adversarial_classifier = Classifier(
            n_input=self.module.n_latent - 2,  # exclude 2D z_cycle
            n_hidden=32,
            n_labels=3,
            n_layers=2,
            logits=True,
        )

        self.scale_adversarial_loss = scale_adversarial_loss
        self.automatic_optimization = False  # Manual optimization loop

    def loss_adversarial_classifier(self, z_other, phase_index, predict_true_class=True):
        """
        Computes adversarial loss either for classification or fooling.

        Args:
            z_other (Tensor): Latent space excluding z_cycle.
            phase_index (Tensor): True class indices (integer dtype).
            predict_true_class (bool): If False, trains to fool the classifier.

        Returns:
            Tensor: Mean cross-entropy between the classifier output and the target.
        """
        log_probs = torch.nn.functional.log_softmax(self.adversarial_classifier(z_other), dim=1)
        n_classes = log_probs.shape[1]

        # Flatten to one index per cell so the target always has shape [N, n_classes].
        true_one_hot = torch.nn.functional.one_hot(phase_index.reshape(-1), n_classes)

        if predict_true_class:
            target = true_one_hot
        else:
            # For fooling: soft target spread uniformly over all incorrect classes
            target = (~true_one_hot.bool()).float() / (n_classes - 1)

        return -(log_probs * target).sum(dim=1).mean()

    def training_step(self, batch, batch_idx):
        """
        Custom training step with adversarial optimization.

        Step 1: Optimize model to fool the classifier.
        Step 2: Optimize classifier to predict true phase.
        """
        if "kl_weight" in self.loss_kwargs:
            self.loss_kwargs["kl_weight"] = self.kl_weight
            self.log("kl_weight", self.kl_weight, on_step=True, on_epoch=False)

        # Determine scaling factor kappa
        if self.scale_adversarial_loss == "auto":
            kappa = 1 - self.kl_weight
        else:
            kappa = self.scale_adversarial_loss

        # Phase labels are read from the dedicated cycle-label field of the batch
        phase_labels = batch[CYCLE_REGISTRY_KEYS.CYCLE_LABEL_KEY].long().squeeze(-1)

        # Query the optimizers once; a single optimizer means no classifier step.
        optimizers = self.optimizers()
        if isinstance(optimizers, list):
            opt1, opt2 = optimizers
        else:
            opt1, opt2 = optimizers, None

        outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
        z = outputs["z"]
        z_other = z[:, 2:]  # Remove 2D z_cycle from full latent

        loss = scvi_loss.loss

        # Step 1: Fool classifier (optimize model to remove phase info)
        if kappa > 0:
            fool_loss = self.loss_adversarial_classifier(z_other, phase_labels, predict_true_class=False)
            # Out-of-place add so the tensor stored on `scvi_loss` is not modified.
            loss = loss + fool_loss * kappa * self._FOOL_LOSS_MULTIPLIER

        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")

        opt1.zero_grad()
        self.manual_backward(loss)
        opt1.step()

        # Step 2: Train classifier (optimize to correctly predict phase)
        if opt2 is not None:
            # Detach so the classifier update does not back-propagate into the module.
            cls_loss = self.loss_adversarial_classifier(z_other.detach(), phase_labels, predict_true_class=True)
            cls_loss = cls_loss * kappa

            opt2.zero_grad()
            self.manual_backward(cls_loss)
            opt2.step()

    def _step_plateau_scheduler(self):
        """Step the LR scheduler with the monitored metric, if both are available."""
        scheduler = self.lr_schedulers()
        if scheduler is None:
            return
        metric = self.trainer.callback_metrics.get(self.lr_scheduler_metric)
        if metric is None:
            # The monitored metric has not been logged yet; nothing to step on.
            return
        scheduler.step(metric)

    def on_train_epoch_end(self):
        """Step the plateau scheduler when it monitors a training metric.

        With manual optimization the scheduler returned by
        `configure_optimizers` is not stepped automatically.
        """
        super().on_train_epoch_end()
        if self.reduce_lr_on_plateau and "validation" not in self.lr_scheduler_metric:
            self._step_plateau_scheduler()

    def on_validation_epoch_end(self):
        """Step the plateau scheduler when it monitors a validation metric."""
        super().on_validation_epoch_end()
        if self.reduce_lr_on_plateau and "validation" in self.lr_scheduler_metric:
            self._step_plateau_scheduler()

    def configure_optimizers(self):
        """
        Returns separate optimizers for model and classifier.
        Optionally adds LR scheduler for the model.
        """
        # Optimizer for main model
        module_params = filter(lambda p: p.requires_grad, self.module.parameters())
        module_optimizer = self.get_optimizer_creator()(module_params)

        # Optimizer for adversarial classifier
        classifier_params = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters())
        classifier_optimizer = torch.optim.Adam(
            classifier_params, lr=1e-3, eps=0.01, weight_decay=self.weight_decay
        )

        optimizers = [module_optimizer, classifier_optimizer]
        if not self.reduce_lr_on_plateau:
            return optimizers

        # Optional learning rate scheduler (attached to the module optimizer only)
        scheduler = ReduceLROnPlateau(
            module_optimizer,
            patience=self.lr_patience,
            factor=self.lr_factor,
            threshold=self.lr_threshold,
            min_lr=self.lr_min,
            threshold_mode="abs",
        )
        return optimizers, [{"scheduler": scheduler, "monitor": self.lr_scheduler_metric}]
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
# ─────────────────────────────────────────────────────────────
|
|
326
|
+
# Decoder
|
|
327
|
+
# ─────────────────────────────────────────────────────────────
|
|
328
|
+
class DecoderCycleVI(nn.Module):
    """
    Custom decoder for CycleVI model that separates gene expression into:
    - Non-cyclic components via a feedforward network.
    - Cell cycle–dependent components via Fourier basis functions.

    The first two latent dimensions are interpreted as a 2D cycle embedding and
    converted to an angle; the remaining dimensions feed the feedforward network.

    Arguments:
        n_input (int): Input dimension (latent + covariates). The feedforward
            network receives ``n_input - 2`` features.
        n_output (int): Number of output genes.
        n_layers (int): Number of hidden layers in the feedforward decoder.
        n_hidden (int): Width of each hidden layer.
        n_cat_list: A list containing the number of categories
            for each category of interest. Each category will be
            included using a one-hot encoding
        inject_covariates (bool): Whether to inject covariates in FCLayers.
        use_batch_norm (bool): Whether to apply batch normalization.
        use_layer_norm (bool): Whether to apply layer normalization.
        scale_activation (str): 'softmax' or 'softplus' for output
            (any value other than 'softmax' selects softplus).
        cycle_gene_mask (torch.Tensor): Boolean mask over cycle-regulated genes.
            Defaults to all genes.
        n_fourier (int): Number of Fourier harmonics for cyclic signal.
        **kwargs: Accepted for call-site compatibility and ignored.

    Raises:
        ValueError: If ``cycle_gene_mask`` is not a 1D mask of length ``n_output``.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_layers: int = 1,
        n_hidden: int = 128,
        n_cat_list: Iterable[int] = None,
        inject_covariates: bool = True,
        use_batch_norm: bool = False,
        use_layer_norm: bool = False,
        scale_activation: str = "softmax",
        cycle_gene_mask: torch.Tensor = None,
        n_fourier: int = 3,
        **kwargs
    ):
        super().__init__()

        if cycle_gene_mask is None:
            cycle_gene_mask = torch.ones(n_output, dtype=torch.bool)
        else:
            # Coerce to a boolean tensor: the mask is used for boolean indexing
            # below, which fails for float (0/1) masks.
            cycle_gene_mask = torch.as_tensor(cycle_gene_mask).bool()
            if cycle_gene_mask.ndim != 1 or cycle_gene_mask.shape[0] != n_output:
                raise ValueError("`cycle_gene_mask` must match n_output")

        self.register_buffer("cycle_mask", cycle_gene_mask.float())
        self.n_output = n_output
        self.n_fourier = n_fourier

        # Feedforward decoder for non-cycle component (takes z_latent only)
        self.non_cycle_fc = FCLayers(
            n_in=n_input - 2,  # exclude 2D z_cycle
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            # Honour the constructor argument (it used to be hard-coded to True).
            inject_covariates=inject_covariates,
            activation_fn=nn.ReLU,
            use_activation=True,
            bias=True,
        )
        self.non_cycle_linear = nn.Linear(n_hidden, n_output)

        # Fourier weights for periodic expression modulation
        self.fourier_W = nn.Parameter(0.01 * torch.randn(2 * n_fourier, n_output))

        # Gradient mask so only cycle genes get updated via Fourier weights
        grad_mask = torch.zeros_like(self.fourier_W)
        grad_mask[:, cycle_gene_mask.to(grad_mask.device)] = 1.0
        self.fourier_W.register_hook(lambda grad: grad * grad_mask.to(grad.device))

        # Raw dispersion parameter for Negative Binomial model
        self.disp_raw = nn.Parameter(0.01 * torch.randn(n_output))

        # Output activation
        self.px_scale_activation = (
            nn.Softmax(dim=-1) if scale_activation == "softmax" else nn.Softplus()
        )

    def forward(
        self,
        z: torch.Tensor,                 # Latent vector [N, latent_dim]
        library: torch.Tensor,           # Log-library size [N, 1]
        remove_cell_cycle: bool = False, # Disable cyclic modulation if True
        *cat_list: int,
    ):
        """
        Decode latent vectors into expression parameters.

        NOTE(review): ``remove_cell_cycle`` is the third positional parameter, so
        categorical covariates can only be supplied positionally after it.

        Returns:
            tuple: 10 entries, in order: activated scale, raw dispersion
            parameter, rate (``exp(library) * scale``), None, cycle angle with
            shape [N, 1], None, None, the Fourier weight parameter, None, and
            the non-cyclic linear output.
        """
        # Split latent space: z_cycle (2D) and z_latent (rest)
        z_cycle = z[..., 0:2]
        z_latent = z[..., 2:]

        # Feedforward decoder for baseline gene expression
        hidden = self.non_cycle_fc(z_latent, *cat_list)
        non_cycle_out = self.non_cycle_linear(hidden)

        # Convert 2D z_cycle into phase angle θ
        angle = torch.atan2(z_cycle[..., 1], z_cycle[..., 0])  # shape: [N]

        # Compute cycle-dependent modulation (Fourier basis)
        if remove_cell_cycle:
            cycle_effect = 0.0
        else:
            # Construct Fourier basis: all cos(kθ) first, then all sin(kθ), k in 1..n
            harmonics = range(1, self.n_fourier + 1)
            basis = [torch.cos(k * angle) for k in harmonics] + [torch.sin(k * angle) for k in harmonics]
            fourier_basis = torch.stack(basis, dim=-1)  # shape: [N, 2 * n_fourier]
            # The mask zeroes the cyclic contribution for non-cycle genes.
            cycle_effect = torch.matmul(fourier_basis, self.fourier_W) * self.cycle_mask

        # Combine non-cycle and cycle effects
        eta = non_cycle_out + cycle_effect  # shape: [N, n_output]

        # Output gene scale and rate
        px_scale = self.px_scale_activation(eta)  # [N, G]
        px_rate = torch.exp(library) * px_scale   # scaled by library size
        disp = self.disp_raw                      # [G]

        return (
            px_scale,            # activated scale
            disp,                # gene-wise raw dispersion parameter
            px_rate,             # expected gene counts
            None,                # placeholder
            angle.unsqueeze(1),  # inferred cell cycle phase θ
            None, None,          # unused (e.g., radius)
            self.fourier_W,      # learned Fourier weights
            None,                # placeholder
            non_cycle_out        # non-cyclic output (for debugging)
        )
|
|
460
|
+
|
|
461
|
+
# ─────────────────────────────────────────────────────────────
|
|
462
|
+
# VAE
|
|
463
|
+
# ─────────────────────────────────────────────────────────────
|
|
464
|
+
|
|
465
|
+
class CycleVI_VAE(EmbeddingModuleMixin, BaseMinifiedModeModuleClass):
|
|
466
|
+
"""
|
|
467
|
+
Cell Cycle–aware Variational Autoencoder for single-cell RNA-seq data.
|
|
468
|
+
|
|
469
|
+
This model extends the standard scVI architecture by incorporating:
|
|
470
|
+
- A disentangled 2D circular latent space for modeling the cell cycle.
|
|
471
|
+
- A custom decoder with Fourier basis functions for periodic expression.
|
|
472
|
+
- Support for batch correction, observed and latent library sizes, and covariates.
|
|
473
|
+
"""
|
|
474
|
+
|
|
475
|
+
def __init__(
    self,
    n_input: int,
    n_batch: int = 0,
    n_labels: int = 0,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    n_continuous_cov: int = 0,
    n_cats_per_cov: list[int] | None = None,
    dropout_rate: float = 0.1,
    dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene-label",
    log_variational: bool = True,
    gene_likelihood: Literal["zinb", "nb", "poisson"] = "nb",
    latent_distribution: Literal["normal", "ln"] = "normal",
    encode_covariates: bool = False,
    deeply_inject_covariates: bool = True,
    batch_representation: Literal["one-hot", "embedding"] = "one-hot",
    use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
    use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
    use_size_factor_key: bool = False,
    use_observed_lib_size: bool = True,
    library_log_means: np.ndarray | None = None,
    library_log_vars: np.ndarray | None = None,
    var_activation: Callable[[torch.Tensor], torch.Tensor] = None,
    extra_encoder_kwargs: dict | None = None,
    extra_decoder_kwargs: dict | None = None,
    batch_embedding_kwargs: dict | None = None,
    cycle_gene_mask: torch.Tensor | None = None,
):
    """
    Build the latent encoder and the cycle-aware decoder.

    The observed library size is used whenever ``use_size_factor_key`` or
    ``use_observed_lib_size`` is set. Otherwise ``library_log_means`` and
    ``library_log_vars`` are registered as buffers and a library-size encoder
    (``l_encoder``) is created.

    Raises:
        ValueError: If the library prior parameters are required but missing, or
            if ``batch_representation`` is neither 'one-hot' nor 'embedding'.
    """
    super().__init__()

    # Store core configuration
    self.dispersion = dispersion
    self.n_latent = n_latent
    self.log_variational = log_variational
    self.gene_likelihood = gene_likelihood
    self.n_batch = n_batch
    self.n_labels = n_labels
    self.latent_distribution = latent_distribution
    self.encode_covariates = encode_covariates
    self.use_size_factor_key = use_size_factor_key
    self.use_observed_lib_size = use_size_factor_key or use_observed_lib_size

    # Handle library size modeling if not using observed values
    if not self.use_observed_lib_size:
        if library_log_means is None or library_log_vars is None:
            raise ValueError("Must provide library_log_means and library_log_vars if not using observed lib size.")
        self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float())
        self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float())

    # ─────────────────────────────────────────────────────────────
    # Setup batch representation
    # ─────────────────────────────────────────────────────────────
    self.batch_representation = batch_representation
    if batch_representation == "embedding":
        self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, **(batch_embedding_kwargs or {}))
        batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim
    elif batch_representation != "one-hot":
        raise ValueError("`batch_representation` must be either 'one-hot' or 'embedding'.")

    # ─────────────────────────────────────────────────────────────
    # Encoder Setup
    # ─────────────────────────────────────────────────────────────

    # Determine normalization configuration
    use_bn_enc = use_batch_norm in ["encoder", "both"]
    use_bn_dec = use_batch_norm in ["decoder", "both"]
    use_ln_enc = use_layer_norm in ["encoder", "both"]
    use_ln_dec = use_layer_norm in ["decoder", "both"]

    # Compute encoder input dimension (covariates only count when they are encoded)
    extra_cats = list([] if n_cats_per_cov is None else n_cats_per_cov)
    n_input_encoder = n_input + n_continuous_cov * encode_covariates
    if self.batch_representation == "embedding":
        # The batch enters as a dense embedding, not as a categorical covariate.
        n_input_encoder += batch_dim * encode_covariates
        cat_list = extra_cats
    else:
        cat_list = [n_batch] + extra_cats

    encoder_cat_list = cat_list if encode_covariates else None
    _extra_encoder_kwargs = extra_encoder_kwargs or {}

    self.z_encoder = Encoder(
        n_input=n_input_encoder,
        n_output=n_latent,
        n_cat_list=encoder_cat_list,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
        distribution=latent_distribution,
        inject_covariates=deeply_inject_covariates,
        use_batch_norm=use_bn_enc,
        use_layer_norm=use_ln_enc,
        var_activation=var_activation,
        return_dist=True,
        **_extra_encoder_kwargs,
    )

    # Library-size encoder: inference calls `self.l_encoder` when the observed
    # library size is not used, so it must exist in that configuration. It is
    # only created then, so modules using the observed library size keep the
    # same parameters as before.
    if not self.use_observed_lib_size:
        self.l_encoder = Encoder(
            n_input=n_input_encoder,
            n_output=1,
            n_layers=1,
            n_cat_list=encoder_cat_list,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_bn_enc,
            use_layer_norm=use_ln_enc,
            var_activation=var_activation,
            return_dist=True,
            **_extra_encoder_kwargs,
        )

    # ─────────────────────────────────────────────────────────────
    # Decoder Setup
    # ─────────────────────────────────────────────────────────────

    # Decoder input: z + continuous covariates + optional batch embedding
    n_input_decoder = n_latent + n_continuous_cov
    if batch_representation == "embedding":
        n_input_decoder += batch_dim

    _extra_decoder_kwargs = extra_decoder_kwargs or {}
    self.decoder = DecoderCycleVI(
        n_input=n_input_decoder,
        n_output=n_input,
        n_layers=n_layers,
        n_hidden=n_hidden,
        n_cat_list=cat_list,
        use_batch_norm=use_bn_dec,
        use_layer_norm=use_ln_dec,
        inject_covariates=deeply_inject_covariates,
        scale_activation="softplus" if use_size_factor_key else "softmax",
        cycle_gene_mask=cycle_gene_mask,
        **_extra_decoder_kwargs,
    )
|
|
598
|
+
|
|
599
|
+
# ─────────────────────────────────────────────────────────────
|
|
600
|
+
# Prepare tensors for inference
|
|
601
|
+
# ─────────────────────────────────────────────────────────────
|
|
602
|
+
|
|
603
|
+
def _get_inference_input(
    self,
    tensors: dict[str, torch.Tensor | None],
    full_forward_pass: bool = False,
) -> dict[str, torch.Tensor | None]:
    """
    Select the tensors passed to the inference step.

    Full-data inputs are returned when ``full_forward_pass`` is set or the
    module has no minified data type; otherwise the cached latent posterior
    parameters and the observed library size are returned.

    Raises:
        NotImplementedError: If the minified data type is not one of the
            latent-posterior types.
    """
    wants_full_data = full_forward_pass or self.minified_data_type is None

    if wants_full_data:
        # Standard tensors; the covariate entries are optional.
        return {
            MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY],
            MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY],
            MODULE_KEYS.CONT_COVS_KEY: tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None),
            MODULE_KEYS.CAT_COVS_KEY: tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None),
        }

    latent_posterior_types = [
        ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
        ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS,
    ]
    if self.minified_data_type not in latent_posterior_types:
        raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}")

    # Minified data: use cached latent parameters
    return {
        MODULE_KEYS.QZM_KEY: tensors[REGISTRY_KEYS.LATENT_QZM_KEY],
        MODULE_KEYS.QZV_KEY: tensors[REGISTRY_KEYS.LATENT_QZV_KEY],
        REGISTRY_KEYS.OBSERVED_LIB_SIZE: tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE],
    }
|
|
635
|
+
# ─────────────────────────────────────────────────────────────
|
|
636
|
+
# Prepare tensors for generative model
|
|
637
|
+
# ─────────────────────────────────────────────────────────────
|
|
638
|
+
|
|
639
|
+
def _get_generative_input(
    self,
    tensors: dict[str, torch.Tensor],
    inference_outputs: dict[str, torch.Tensor | Distribution | None],
) -> dict[str, torch.Tensor | None]:
    """
    Assemble the inputs of the generative step.

    Combines the latent sample and library from ``inference_outputs`` with the
    batch index, labels and optional covariates from ``tensors``. A size
    factor, when present, is passed on in log space.
    """
    raw_size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None)
    log_size_factor = None if raw_size_factor is None else torch.log(raw_size_factor)

    latent = inference_outputs[MODULE_KEYS.Z_KEY]
    library = inference_outputs[MODULE_KEYS.LIBRARY_KEY]
    batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
    labels = tensors[REGISTRY_KEYS.LABELS_KEY]
    # Covariates are optional and default to None when absent.
    cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None)
    cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None)

    return {
        MODULE_KEYS.Z_KEY: latent,
        MODULE_KEYS.LIBRARY_KEY: library,
        MODULE_KEYS.BATCH_INDEX_KEY: batch_index,
        MODULE_KEYS.Y_KEY: labels,
        MODULE_KEYS.CONT_COVS_KEY: cont_covs,
        MODULE_KEYS.CAT_COVS_KEY: cat_covs,
        MODULE_KEYS.SIZE_FACTOR_KEY: log_size_factor,
    }
|
|
661
|
+
|
|
662
|
+
# ─────────────────────────────────────────────────────────────
|
|
663
|
+
# For each cell, computes the mean and variance of the log library size for the corresponding batch.
|
|
664
|
+
# ─────────────────────────────────────────────────────────────
|
|
665
|
+
|
|
666
|
+
def _compute_local_library_params(
|
|
667
|
+
self,
|
|
668
|
+
batch_index: torch.Tensor,
|
|
669
|
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
670
|
+
"""
|
|
671
|
+
Computes local library parameters.
|
|
672
|
+
|
|
673
|
+
For each cell, computes the mean and variance of the log library size
|
|
674
|
+
for the corresponding batch.
|
|
675
|
+
"""
|
|
676
|
+
from torch.nn.functional import linear
|
|
677
|
+
|
|
678
|
+
n_batch = self.library_log_means.shape[1] # Number of batches from the library means buffer
|
|
679
|
+
# Compute local means using one-hot encoding for the batch index and linear transformation
|
|
680
|
+
local_library_log_means = linear(
|
|
681
|
+
one_hot(batch_index.squeeze(-1), n_batch).float(), self.library_log_means
|
|
682
|
+
)
|
|
683
|
+
# Compute local variances similarly
|
|
684
|
+
local_library_log_vars = linear(
|
|
685
|
+
one_hot(batch_index.squeeze(-1), n_batch).float(), self.library_log_vars
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
return local_library_log_means, local_library_log_vars
|
|
689
|
+
|
|
690
|
+
@auto_move_data # Automatically move inputs/outputs to the correct device (CPU/GPU)
|
|
691
|
+
|
|
692
|
+
# ─────────────────────────────────────────────────────────────
|
|
693
|
+
# Encodes input data into latent variables
|
|
694
|
+
# ─────────────────────────────────────────────────────────────
|
|
695
|
+
|
|
696
|
+
def _regular_inference(
    self,
    x: torch.Tensor,
    batch_index: torch.Tensor,
    cont_covs: torch.Tensor | None = None,
    cat_covs: torch.Tensor | None = None,
    n_samples: int = 1,
) -> dict[str, torch.Tensor | Distribution | None]:
    """Encode library-size-normalized counts into the latent variables.

    Parameters
    ----------
    x
        Count matrix; the library size is its sum over ``dim=1``.
    batch_index
        Batch code per cell, forwarded to the encoders or to the batch
        embedding depending on ``self.batch_representation``.
    cont_covs, cat_covs
        Optional continuous / categorical covariates. They are only fed to
        the encoders when ``self.encode_covariates`` is true.
    n_samples
        Number of Monte Carlo draws; values above 1 add a leading sample
        dimension to ``z`` and to the library.

    Returns
    -------
    dict
        ``MODULE_KEYS.Z_KEY`` (latent sample), ``MODULE_KEYS.QZ_KEY``
        (posterior over z), ``MODULE_KEYS.QL_KEY`` (posterior over the
        library, or ``None`` when the observed library is used) and
        ``MODULE_KEYS.LIBRARY_KEY`` (log-scale library size).
    """
    # Step 1: observed library size (total counts per cell), kept as a column.
    observed_library = torch.sum(x, dim=1, keepdim=True)

    # Step 2: per-cell normalization; the epsilon protects all-zero cells.
    x_normalized = x / (observed_library + 1e-8)

    # Step 3: optional log1p on the normalized profile.
    if self.log_variational:
        x_normalized = torch.log1p(x_normalized)

    # Step 4: assemble the encoder input.
    if cont_covs is not None and self.encode_covariates:
        encoder_input = torch.cat((x_normalized, cont_covs), dim=-1)
    else:
        encoder_input = x_normalized

    if cat_covs is not None and self.encode_covariates:
        categorical_input = torch.split(cat_covs, 1, dim=1)
    else:
        categorical_input = ()

    if self.batch_representation == "embedding" and self.encode_covariates:
        batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
        encoder_input = torch.cat([encoder_input, batch_rep], dim=-1)
        qz, z = self.z_encoder(encoder_input, *categorical_input)
    else:
        qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)

    # Step 5: log-scale library size.
    ql = None
    if self.use_observed_lib_size:
        log_library = torch.log(observed_library + 1e-8)
    else:
        if self.batch_representation == "embedding":
            ql, library_encoded = self.l_encoder(encoder_input, *categorical_input)
        else:
            ql, library_encoded = self.l_encoder(encoder_input, batch_index, *categorical_input)
        # `ql` is matched against the log-library prior in `loss`, and
        # `_cached_inference` stores log(observed_lib_size) under the same key,
        # so the encoded library is already on the log scale. Taking another
        # log here would yield NaN for negative draws and a value that is
        # inconsistent with `ql.log_prob(library)` in `marginal_ll`.
        log_library = library_encoded

    # Step 6: expand / resample for Monte Carlo estimation.
    if n_samples > 1:
        untran_z = qz.sample((n_samples,))
        z = self.z_encoder.z_transformation(untran_z)
        if self.use_observed_lib_size:
            log_library = log_library.unsqueeze(0).expand(
                (n_samples, log_library.size(0), log_library.size(1))
            )
        else:
            log_library = ql.sample((n_samples,))

    return {
        MODULE_KEYS.Z_KEY: z,
        MODULE_KEYS.QZ_KEY: qz,
        MODULE_KEYS.QL_KEY: ql,
        MODULE_KEYS.LIBRARY_KEY: log_library,  # used in decoder
    }
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
@auto_move_data
|
|
758
|
+
def _cached_inference(
    self,
    qzm: torch.Tensor,
    qzv: torch.Tensor,
    observed_lib_size: torch.Tensor,
    n_samples: int = 1,
) -> dict[str, torch.Tensor | None]:
    """Rebuild inference outputs from cached posterior parameters.

    Parameters
    ----------
    qzm
        Cached posterior means of the latent variable.
    qzv
        Cached posterior variances; their square root is used as the scale.
    observed_lib_size
        Observed library sizes; their logarithm is returned as the library.
    n_samples
        Number of Monte Carlo draws. Values above 1 add a leading sample
        dimension to both ``z`` and the library.

    Returns
    -------
    dict
        Same keys as the regular inference path; the library posterior
        entry is always ``None``.
    """
    # Posterior rebuilt from the stored mean / variance.
    qz = Normal(qzm, qzv.sqrt())

    # Plain (non-reparameterized) sampling, with or without a sample axis.
    if n_samples == 1:
        untran_z = qz.sample()
    else:
        untran_z = qz.sample((n_samples,))
    z = self.z_encoder.z_transformation(untran_z)

    # Library is carried on the log scale.
    log_library = torch.log(observed_lib_size)
    if n_samples > 1:
        target_shape = (n_samples, log_library.size(0), log_library.size(1))
        log_library = log_library.unsqueeze(0).expand(target_shape)

    return {
        MODULE_KEYS.Z_KEY: z,
        MODULE_KEYS.QZ_KEY: qz,
        MODULE_KEYS.QL_KEY: None,
        MODULE_KEYS.LIBRARY_KEY: log_library,
    }
|
|
784
|
+
|
|
785
|
+
@auto_move_data
|
|
786
|
+
|
|
787
|
+
# ─────────────────────────────────────────────────────────────
|
|
788
|
+
# Decodes latent z back to gene expression
|
|
789
|
+
# ─────────────────────────────────────────────────────────────
|
|
790
|
+
def generative(
    self,
    z,
    library,
    batch_index,
    cont_covs=None,
    cat_covs=None,
    size_factor=None,
    y=None,
    transform_batch=None,
    remove_cell_cycle: bool = False,
):
    """Decode latent samples into the expression likelihood and priors.

    Parameters
    ----------
    z
        Latent sample, optionally with a leading Monte Carlo dimension.
    library
        Log-scale library size; used as the size factor unless
        ``self.use_size_factor_key`` is set.
    batch_index
        Batch code per cell.
    cont_covs, cat_covs
        Optional continuous / categorical covariates for the decoder.
    size_factor
        Size factor, only honoured when ``self.use_size_factor_key`` is true.
    y
        Accepted for interface compatibility; not used here.
    transform_batch
        If given, every cell is decoded as if it belonged to this batch.
    remove_cell_cycle
        Forwarded to the decoder as its third positional argument.

    Returns
    -------
    dict
        ``MODULE_KEYS.PX_KEY`` (expression likelihood), ``MODULE_KEYS.PL_KEY``
        (library prior or ``None``), ``MODULE_KEYS.PZ_KEY`` (standard normal
        prior over z) and ``"angle"`` (angle reported by the decoder).

    Raises
    ------
    ValueError
        If ``self.gene_likelihood`` is not one of ``"zinb"``, ``"nb"``,
        ``"poisson"`` or ``"normal"``.
    """
    from scvi.distributions import NegativeBinomial, Normal, Poisson, ZeroInflatedNegativeBinomial

    # 0. Apply transform_batch BEFORE the batch is consumed anywhere below.
    #    Overriding it after the embedding / one-hot input has been built would
    #    leave the decoder conditioned on the original batch, making
    #    `transform_batch` a silent no-op for the decoded expression.
    if transform_batch is not None:
        batch_index = torch.ones_like(batch_index) * transform_batch

    # 1. Build decoder_input = [z (+ cont_covs)]
    if cont_covs is None:
        decoder_input = z
    elif z.dim() != cont_covs.dim():
        # z carries a leading sample dimension that the covariates lack.
        decoder_input = torch.cat(
            [z, cont_covs.unsqueeze(0).expand(z.size(0), -1, -1)], dim=-1
        )
    else:
        decoder_input = torch.cat([z, cont_covs], dim=-1)

    # 2. Append the batch embedding when the batch is represented that way.
    if self.batch_representation == "embedding":
        batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
        # Match dims in MC-sampling mode (z can be [n_samples, n_cells, ...]).
        if decoder_input.dim() != batch_rep.dim():
            batch_rep = batch_rep.unsqueeze(0).expand(decoder_input.size(0), -1, -1)
        decoder_input = torch.cat([decoder_input, batch_rep], dim=-1)
        # IMPORTANT: in this mode, batch is NOT part of the categorical inputs.

    # 3. Categorical inputs: registered categorical covariates first ...
    if cat_covs is not None:
        categorical_input = torch.split(cat_covs, 1, dim=1)
    else:
        categorical_input = ()

    # ... and the batch itself when it is one-hot encoded.
    if self.batch_representation == "one-hot":
        categorical_input = (batch_index, *categorical_input)

    # 4. Size factor falls back to the (log) library unless a key was registered.
    if not self.use_size_factor_key:
        size_factor = library

    # 5. Run the decoder. Only scale, dispersion, rate and angle are used here;
    #    the remaining outputs are unpacked to keep the arity check explicit.
    (
        px_scale,
        disp,
        px_rate,
        _,
        angle,
        _radius,
        _,
        _w_fourier,
        _,
        _baseline,
    ) = self.decoder(
        decoder_input,
        size_factor,
        remove_cell_cycle,
        *categorical_input,
    )

    px_r = torch.exp(disp)

    # 6. Expression likelihood.
    if self.gene_likelihood == "zinb":
        # NOTE(review): zi_logits is passed as None, i.e. no decoder output is
        # wired in as dropout logits — confirm this likelihood is usable.
        px = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=None, scale=px_scale)
    elif self.gene_likelihood == "nb":
        px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale)
    elif self.gene_likelihood == "poisson":
        px = Poisson(rate=px_rate, scale=px_scale)
    elif self.gene_likelihood == "normal":
        px = Normal(px_rate, px_r, normal_mu=px_scale)
    else:
        # Fail with a clear message instead of an UnboundLocalError on `px`.
        raise ValueError(f"Unsupported gene_likelihood: {self.gene_likelihood!r}")

    # 7. Priors.
    if self.use_observed_lib_size:
        pl = None
    else:
        local_library_log_means, local_library_log_vars = self._compute_local_library_params(batch_index)
        pl = Normal(local_library_log_means, local_library_log_vars.sqrt())

    pz = Normal(torch.zeros_like(z), torch.ones_like(z))

    return {
        MODULE_KEYS.PX_KEY: px,
        MODULE_KEYS.PL_KEY: pl,
        MODULE_KEYS.PZ_KEY: pz,
        "angle": angle,
    }
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
@unsupported_if_adata_minified # Mark this method as unsupported if AnnData is in minified mode
|
|
888
|
+
# ─────────────────────────────────────────────────────────────
|
|
889
|
+
# Loss function
|
|
890
|
+
# ─────────────────────────────────────────────────────────────
|
|
891
|
+
def loss(
    self,
    tensors: dict[str, torch.Tensor],
    inference_outputs: dict[str, torch.Tensor | Distribution | None],
    generative_outputs: dict[str, Distribution | torch.Tensor | None],
    kl_weight: torch.Tensor | float = 1.0,
) -> LossOutput:
    """Compute the training objective.

    The total loss is the batch mean of (reconstruction + weighted local KL)
    plus two cycle-specific terms computed from the first two coordinates of
    the posterior mean of z:

    * an angle term, ``mean(1 - cos(angle - target_angle))``, weighted by
      ``0.5 * 100 * (1 - kl_weight) ** 4``;
    * a radius term, ``mean(exp(-10 * radius))``, weighted by
      ``100 * kl_weight ** 2``.

    Parameters
    ----------
    tensors
        Minibatch tensors; must provide the counts, the target cycle angle,
        the batch and the labels.
    inference_outputs
        Output of the inference step.
    generative_outputs
        Output of the generative step.
    kl_weight
        Annealing weight applied to the KL term of z (not to the library KL).

    Returns
    -------
    LossOutput
        Total loss, per-cell reconstruction loss, local KL terms and the
        extra metrics listed at the end of this function.
    """
    from torch.distributions import kl_divergence

    counts = tensors[REGISTRY_KEYS.X_KEY]
    qz = inference_outputs[MODULE_KEYS.QZ_KEY]

    # KL(q(z|x) || p(z)), summed over latent dimensions.
    kl_divergence_z = kl_divergence(qz, generative_outputs[MODULE_KEYS.PZ_KEY]).sum(dim=-1)

    # KL for the library size; zero when the observed library is used.
    if self.use_observed_lib_size:
        kl_divergence_l = torch.zeros_like(kl_divergence_z)
    else:
        kl_divergence_l = kl_divergence(
            inference_outputs[MODULE_KEYS.QL_KEY], generative_outputs[MODULE_KEYS.PL_KEY]
        ).sum(dim=1)

    # Negative log-likelihood of the counts, summed over genes.
    reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(counts).sum(-1)

    # Polar coordinates of the posterior mean in its first two dimensions.
    z_mean = qz.loc
    x_latent = z_mean[:, 0]
    y_latent = z_mean[:, 1]
    angle = torch.atan2(y_latent, x_latent)
    radius = torch.sqrt(x_latent**2 + y_latent**2 + 1e-6)

    # Target angle registered for each cell.
    target_angle = tensors[CYCLE_REGISTRY_KEYS.CYCLE_ANGLE_KEY].squeeze(-1)

    # Angular mismatch measured as 1 - cos(delta): 0 when aligned, 2 when opposite.
    delta_angle = angle - target_angle
    angle_loss = torch.mean(1.0 - torch.cos(delta_angle))

    # Penalty that grows as the radius approaches zero.
    radius_penalty = torch.mean(torch.exp(-10 * radius))

    # Weighting: the angle term fades as kl_weight -> 1, the radius term grows.
    weighted_kl_local = kl_weight * kl_divergence_z + kl_divergence_l
    cycle_pos_weight = 100 * (1.0 - kl_weight)**4
    weighted_angle_loss = 0.5 * cycle_pos_weight * angle_loss
    weighted_radius_penalty = 100 * (kl_weight**2)*radius_penalty

    total = torch.mean(reconst_loss + weighted_kl_local) + weighted_angle_loss + weighted_radius_penalty

    kl_local = {
        MODULE_KEYS.KL_L_KEY: kl_divergence_l,
        MODULE_KEYS.KL_Z_KEY: kl_divergence_z,
    }
    extra_metrics = {
        "z": inference_outputs["z"],
        "batch": tensors[REGISTRY_KEYS.BATCH_KEY],
        "labels": tensors[REGISTRY_KEYS.LABELS_KEY],
        "angle_loss": angle_loss,
        "radius_penalty": radius_penalty,
        "weighted_angle_loss": weighted_angle_loss,
        "weighted_radius_penalty": weighted_radius_penalty,
    }
    return LossOutput(
        loss=total,
        reconstruction_loss=reconst_loss,
        kl_local=kl_local,
        extra_metrics=extra_metrics,
    )
|
|
960
|
+
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
@torch.inference_mode()
|
|
964
|
+
|
|
965
|
+
# ─────────────────────────────────────────────────────────────
|
|
966
|
+
# Samples gene expression from the posterior predictive distribution
|
|
967
|
+
# ─────────────────────────────────────────────────────────────
|
|
968
|
+
|
|
969
|
+
def sample(
    self,
    tensors: dict[str, torch.Tensor],
    n_samples: int = 1,
    max_poisson_rate: float = 1e8,
) -> torch.Tensor:
    r"""Draw samples from the posterior predictive distribution.

    The posterior predictive is :math:`p(\hat{x} \mid x)`, with :math:`x` the
    observed data and :math:`\hat{x}` the generated data. It is sampled by
    drawing ``n_samples`` latent values from :math:`q(z \mid x)` and then one
    observation from the likelihood :math:`p(\hat{x} \mid z)` for each draw.

    Parameters
    ----------
    tensors
        Minibatch tensors passed to ``forward``.
    n_samples
        Number of Monte Carlo samples per observation.
    max_poisson_rate
        Upper bound applied to the Poisson rate before sampling (only used
        when ``gene_likelihood == "poisson"``).

    Returns
    -------
    torch.Tensor
        Samples on the CPU. When ``n_samples > 1`` the leading sample axis is
        moved to the end via ``permute(1, 2, 0)``.
    """
    from scvi.distributions import Poisson

    # Forward pass without loss computation; only the likelihood is needed.
    _, generative_outputs = self.forward(
        tensors,
        inference_kwargs={"n_samples": n_samples},
        compute_loss=False,
    )
    likelihood = generative_outputs[MODULE_KEYS.PX_KEY]

    if self.gene_likelihood == "poisson":
        # On MPS the rate is moved to the CPU before clamping and sampling.
        if self.device.type == "mps":
            rate = likelihood.rate.to("cpu")
        else:
            rate = likelihood.rate
        likelihood = Poisson(torch.clamp(rate, max=max_poisson_rate))

    drawn = likelihood.sample()
    if n_samples > 1:
        # Move the sample axis last.
        drawn = torch.permute(drawn, (1, 2, 0))

    return drawn.cpu()
|
|
1007
|
+
|
|
1008
|
+
@torch.inference_mode()
|
|
1009
|
+
@auto_move_data
|
|
1010
|
+
|
|
1011
|
+
|
|
1012
|
+
# ─────────────────────────────────────────────────────────────
|
|
1013
|
+
# Estimates marginal log-likelihood with Monte Carlo sampling
|
|
1014
|
+
# ─────────────────────────────────────────────────────────────
|
|
1015
|
+
|
|
1016
|
+
def marginal_ll(
    self,
    tensors: dict[str, torch.Tensor],
    n_mc_samples: int,
    return_mean: bool = False,
    n_mc_samples_per_pass: int = 1,
):
    """Estimate the marginal log-likelihood by importance sampling.

    Parameters
    ----------
    tensors
        Minibatch tensors passed to ``forward``.
    n_mc_samples
        Total number of Monte Carlo samples used in the estimate.
    return_mean
        If true, return the mean over cells as a Python float; otherwise
        return the per-cell estimates as a CPU tensor.
    n_mc_samples_per_pass
        Samples drawn per forward pass (smaller values reduce memory use).
        Capped at ``n_mc_samples`` with a ``RuntimeWarning``.

    Returns
    -------
    float or torch.Tensor
        ``logsumexp`` of the per-sample log weights minus ``log(n_mc_samples)``.
    """
    from torch import logsumexp
    from torch.distributions import Normal

    batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]

    if n_mc_samples_per_pass > n_mc_samples:
        warnings.warn(
            "Number of chunks is larger than the total number of samples, setting it to the "
            "number of samples",
            RuntimeWarning,
            stacklevel=settings.warnings_stacklevel,
        )
        n_mc_samples_per_pass = n_mc_samples
    n_passes = int(np.ceil(n_mc_samples / n_mc_samples_per_pass))

    # Log importance weights, one entry (with a leading sample axis) per pass.
    log_weights = []
    for _ in range(n_passes):
        inference_outputs, _, losses = self.forward(
            tensors,
            inference_kwargs={"n_samples": n_mc_samples_per_pass},
            get_inference_input_kwargs={"full_forward_pass": True},
        )
        qz = inference_outputs[MODULE_KEYS.QZ_KEY]
        ql = inference_outputs[MODULE_KEYS.QL_KEY]
        z = inference_outputs[MODULE_KEYS.Z_KEY]
        library = inference_outputs[MODULE_KEYS.LIBRARY_KEY]

        reconst_loss = losses.dict_sum(losses.reconstruction_loss)

        # log p(z) + log p(x | z, l) - log q(z | x)
        standard_normal = Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale))
        p_z = standard_normal.log_prob(z).sum(dim=-1)
        p_x_zl = -reconst_loss
        q_z_x = qz.log_prob(z).sum(dim=-1)
        log_prob_sum = p_z + p_x_zl - q_z_x

        if not self.use_observed_lib_size:
            # Add log p(l) - log q(l | x) when the library is a latent variable.
            local_library_log_means, local_library_log_vars = self._compute_local_library_params(batch_index)
            library_prior = Normal(local_library_log_means, local_library_log_vars.sqrt())
            p_l = library_prior.log_prob(library).sum(dim=-1)
            q_l_x = ql.log_prob(library).sum(dim=-1)
            log_prob_sum += p_l - q_l_x

        if n_mc_samples_per_pass == 1:
            # Single-sample passes lack the sample axis; add it for concatenation.
            log_prob_sum = log_prob_sum.unsqueeze(0)

        log_weights.append(log_prob_sum)

    # Monte Carlo estimate: log-mean-exp over the sample axis.
    stacked = torch.cat(log_weights, dim=0)
    batch_log_lkl = logsumexp(stacked, dim=0) - np.log(n_mc_samples)

    if return_mean:
        return torch.mean(batch_log_lkl).item()
    return batch_log_lkl.cpu()
|
|
1084
|
+
|
|
1085
|
+
|
|
1086
|
+
# ─────────────────────────────────────────────────────────────
|
|
1087
|
+
# Model
|
|
1088
|
+
# ─────────────────────────────────────────────────────────────
|
|
1089
|
+
class CycleVI(EmbeddingMixin,
|
|
1090
|
+
RNASeqMixin,
|
|
1091
|
+
VAEMixin,
|
|
1092
|
+
ArchesMixin,
|
|
1093
|
+
UnsupervisedTrainingMixin,
|
|
1094
|
+
BaseMinifiedModeModelClass):
|
|
1095
|
+
|
|
1096
|
+
# Tell scvi-tools which module class this model uses
|
|
1097
|
+
_module_cls = CycleVI_VAE
|
|
1098
|
+
|
|
1099
|
+
# Keys for storing latent mean and variance in AnnData
|
|
1100
|
+
_LATENT_QZM_KEY = "ccvi_latent_qzm" # Key for the latent mean in AnnData
|
|
1101
|
+
_LATENT_QZV_KEY = "ccvi_latent_qzv" # Key for the latent variance in AnnData
|
|
1102
|
+
|
|
1103
|
+
# Define the training plan
|
|
1104
|
+
_training_plan_cls = PhaseAdversarialTrainingPlan
|
|
1105
|
+
|
|
1106
|
+
# ─────────────────────────────────────────────────────────────
|
|
1107
|
+
# Constructor
|
|
1108
|
+
# ─────────────────────────────────────────────────────────────
|
|
1109
|
+
def __init__(
    self,
    adata: AnnData | None = None,  # Input data; may be None, in which case module creation is deferred
    n_hidden: int = 128,  # Hidden units per layer
    n_latent: int = 10,  # Dimensionality of latent space
    n_layers: int = 1,  # Number of layers in encoder/decoder neural networks
    dropout_rate: float = 0.1,  # Dropout rate
    dispersion: Literal[...] = "gene-label",  # How to parameterize dispersion
    gene_likelihood: Literal[...] = "nb",  # Likelihood distribution for gene expression
    latent_distribution: Literal[...] = "normal",  # Latent distribution type
    cycle_gene_mask: torch.Tensor | None = None,  # Optional mask marking cycle genes, forwarded to the module
    **kwargs,  # Any other parameters passed to the module
):
    """Configure the model and, when data is available, build its module.

    All architecture arguments are stored in ``self._module_kwargs`` and
    forwarded to ``self._module_cls``. If the base class reports
    ``self._module_init_on_train`` (the model was created without ``adata``),
    ``self.module`` is left as ``None`` and a ``UserWarning`` is emitted;
    otherwise dataset statistics are read from the AnnData manager and the
    module is instantiated immediately.

    NOTE(review): ``Literal[...]`` in the signature is a placeholder
    annotation; the accepted string values are defined by the module class.
    """

    # Call the constructor of the parent mixin/base classes
    super().__init__(adata)

    # Store model configuration for the underlying PyTorch module
    self._module_kwargs = {
        "n_hidden": n_hidden,
        "n_latent": n_latent,
        "n_layers": n_layers,
        "dropout_rate": dropout_rate,
        "dispersion": dispersion,
        "gene_likelihood": gene_likelihood,
        "latent_distribution": latent_distribution,
        "cycle_gene_mask": cycle_gene_mask,
        **kwargs,
    }

    # Build a human-readable summary string of the model configuration
    self._model_summary_string = (
        "CycleVI model with the following parameters: \n"
        f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, "
        f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, "
        f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}, cycle_gene_mask: {cycle_gene_mask} "
    )

    # Lazy initialization: without data, postpone module creation until training
    if self._module_init_on_train:
        self.module = None
        warnings.warn(
            "Model was initialized without `adata`. The module will be initialized when "
            "calling `train`. This behavior is experimental and may change in the future.",
            UserWarning,
            stacklevel=settings.warnings_stacklevel,
        )
    else:
        # ─────────────────────────────────────────────
        # Collect dataset-specific information
        # ─────────────────────────────────────────────

        # Number of categories per registered categorical covariate (None if none registered)
        n_cats_per_cov = (
            self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key
            if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
            else None
        )

        # Number of batches reported by the summary statistics
        n_batch = self.summary_stats.n_batch

        # Whether per-cell size factors were registered
        use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry

        # Library size prior parameters are only computed when no size factor
        # is registered and the data is not minified to the latent posterior
        library_log_means, library_log_vars = None, None
        if (
            not use_size_factor_key
            and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
        ):
            library_log_means, library_log_vars = _init_library_size(
                self.adata_manager, n_batch
            )

        # ─────────────────────────────────────────────
        # Instantiate the VAE
        # ─────────────────────────────────────────────
        self.module = self._module_cls(
            n_input=self.summary_stats.n_vars,  # number of genes
            n_batch=n_batch,  # number of batches
            n_labels=self.summary_stats.n_labels,  # number of labels
            n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
            n_cats_per_cov=n_cats_per_cov,  # categorical covariates
            n_hidden=n_hidden,  # hidden units per layer
            n_latent=n_latent,  # latent dimension
            n_layers=n_layers,  # number of layers
            dropout_rate=dropout_rate,  # dropout probability
            dispersion=dispersion,  # dispersion parameterization mode
            gene_likelihood=gene_likelihood,  # likelihood model
            latent_distribution=latent_distribution,  # distribution type for z
            use_size_factor_key=use_size_factor_key,  # whether to use size factors
            library_log_means=library_log_means,  # init mean for library size prior
            library_log_vars=library_log_vars,  # init variance for library size prior
            cycle_gene_mask=cycle_gene_mask,  # custom mask for cycle genes
            **kwargs,  # forward any additional arguments
        )

        # Propagate the minified data type to the module
        self.module.minified_data_type = self.minified_data_type

    # Record the constructor arguments; this reads locals(), so the local
    # names defined above must not be renamed without checking that helper.
    self.init_params_ = self._get_init_params(locals())
|
|
1213
|
+
|
|
1214
|
+
# ─────────────────────────────────────────────
|
|
1215
|
+
# Register data with scvi-tools (what to read from AnnData and where)
|
|
1216
|
+
# ─────────────────────────────────────────────
|
|
1217
|
+
@classmethod
|
|
1218
|
+
@setup_anndata_dsp.dedent
|
|
1219
|
+
def setup_anndata(
    cls,
    adata: AnnData,
    layer: str | None = None,
    batch_key: str | None = None,
    labels_key: str | None = None,
    size_factor_key: str | None = None,
    categorical_covariate_keys: list[str] | None = None,
    continuous_covariate_keys: list[str] | None = None,
    # Cycle-specific inputs:
    cycle_initiation_label_key: str | None = None,  # e.g. "phase"
    cycle_initiation_angle_key: str | None = None,  # e.g. "cycle_angle_uniform"
    **kwargs,
):
    """
    Register AnnData fields for CycleVI.

    Parameters
    ----------
    %(param_adata)s
    %(param_layer)s
    %(param_batch_key)s
    %(param_labels_key)s
    %(param_size_factor_key)s
    %(param_cat_cov_keys)s
    %(param_cont_cov_keys)s
    cycle_initiation_label_key
        Attribute key of the categorical field registered under
        ``CYCLE_REGISTRY_KEYS.CYCLE_LABEL_KEY`` (presumably an ``adata.obs``
        column holding phase labels, e.g. ``"phase"``).
    cycle_initiation_angle_key
        Attribute key of the numerical field registered under
        ``CYCLE_REGISTRY_KEYS.CYCLE_ANGLE_KEY`` (presumably an ``adata.obs``
        column holding a target angle, e.g. ``"cycle_angle_uniform"``).
    **kwargs
        Forwarded to ``AnnDataManager.register_fields``.

    Notes
    -----
    Phase inputs are registered separately and NOT injected as covariates:
    - phase labels -> CYCLE_REGISTRY_KEYS.CYCLE_LABEL_KEY
    - phase angle  -> CYCLE_REGISTRY_KEYS.CYCLE_ANGLE_KEY
    """

    # Must stay the first statement: it captures the call arguments via locals().
    setup_method_args = cls._get_setup_method_args(**locals())

    anndata_fields = [
        LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
        CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
        CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
        NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),

        # regular covariates (kept separate from phase)
        CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
        NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),

        # separate phase inputs
        # NOTE(review): the angle field is not marked required=False, unlike the
        # size factor field above — confirm how a None key is handled.
        CategoricalObsField(CYCLE_REGISTRY_KEYS.CYCLE_LABEL_KEY, cycle_initiation_label_key),
        NumericalObsField(CYCLE_REGISTRY_KEYS.CYCLE_ANGLE_KEY, cycle_initiation_angle_key),


    ]

    # Minified AnnData objects need additional fields.
    adata_minify_type = _get_adata_minify_type(adata)
    if adata_minify_type is not None:
        anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)

    # Register the fields on the data and attach the manager to the class.
    adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
    adata_manager.register_fields(adata, **kwargs)
    cls.register_manager(adata_manager)
|
|
1279
|
+
|
|
1280
|
+
@torch.inference_mode()
|
|
1281
|
+
|
|
1282
|
+
|
|
1283
|
+
def get_normalized_expression(
    self,
    adata: AnnData | None = None,
    indices: list[int] | None = None,
    transform_batch: list[Number | str] | None = None,
    gene_list: list[str] | None = None,
    library_size: float | Literal["latent"] = 1,
    n_samples: int = 1,
    n_samples_overall: int = None,
    weights: Literal["uniform", "importance"] | None = None,
    batch_size: int | None = None,
    return_mean: bool = True,
    return_numpy: bool | None = None,
    silent: bool = True,
    dataloader: Iterator[dict[str, Tensor | None]] | None = None,
    remove_cell_cycle: bool = False,
    **importance_weighting_kwargs,
) -> np.ndarray | pd.DataFrame:
    """Return decoded (normalized) expression for the selected cells.

    Parameters
    ----------
    adata, indices, batch_size
        Data and cells to decode; ignored when ``dataloader`` is given.
    transform_batch
        Batch categories to decode under; results are averaged over them.
    gene_list
        Restrict the output to these genes (only without ``dataloader``).
    library_size
        ``"latent"`` reads the likelihood's ``"mu"`` output unscaled; any
        other value reads ``"scale"`` and multiplies it by that value.
    n_samples
        Posterior samples per cell.
    n_samples_overall
        If given, ``n_samples`` must be 1 and this many rows are resampled
        from the flattened samples.
    weights
        ``"importance"`` stores the distributions needed for importance
        weights; ``None`` / ``"uniform"`` resamples uniformly.
    return_mean
        Average over samples when ``n_samples > 1``.
    return_numpy
        Return an array instead of a DataFrame. Forced to true when
        ``n_samples > 1`` and ``return_mean`` is false.
    silent
        Disable the progress display over ``transform_batch``.
    dataloader
        Pre-built iterator of minibatches used instead of ``adata``.
    remove_cell_cycle
        Forwarded to the generative step.
    **importance_weighting_kwargs
        Forwarded to ``self._get_importance_weights``.

    Returns
    -------
    numpy.ndarray or pandas.DataFrame
        A DataFrame indexed by cell and gene names when ``return_numpy`` is
        falsy and no ``dataloader`` was supplied; otherwise an array.
    """
    if dataloader is not None:
        # Caller-provided minibatches: no gene filtering, no batch transform.
        scdl = dataloader
        gene_mask = slice(None)
        transform_batch = [None]
    else:
        adata = self._validate_anndata(adata)
        if indices is None:
            indices = np.arange(adata.n_obs)
        if n_samples_overall is not None:
            assert n_samples == 1
            n_samples = n_samples_overall // len(indices) + 1
        scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
        transform_batch = _get_batch_code_from_category(
            self.get_anndata_manager(adata, required=True), transform_batch
        )
        if gene_list is None:
            gene_mask = slice(None)
        else:
            gene_mask = adata.var_names.isin(gene_list)

    # Per-sample output cannot be represented as a DataFrame.
    if n_samples > 1 and return_mean is False:
        if return_numpy is False:
            warnings.warn("return_numpy must be True if n_samples > 1 and return_mean is False.")
        return_numpy = True

    if library_size == "latent":
        generative_output_key = "mu"
        scaling = 1
    else:
        generative_output_key = "scale"
        scaling = library_size

    exprs = []
    zs = []
    qz_store = DistributionConcatenator()
    px_store = DistributionConcatenator()

    for tensors in scdl:
        decoded_per_batch = []
        for batch in track(transform_batch, disable=silent):
            generative_kwargs = self._get_transform_batch_gen_kwargs(batch)
            generative_kwargs["remove_cell_cycle"] = remove_cell_cycle
            inference_outputs, generative_outputs = self.module.forward(
                tensors=tensors,
                inference_kwargs={"n_samples": n_samples},
                generative_kwargs=generative_kwargs,
                compute_loss=False,
            )
            px_generative = generative_outputs["px"]
            exp_ = px_generative.get_normalized(generative_output_key)
            exp_ = exp_[..., gene_mask] * scaling
            decoded_per_batch.append(exp_[None].cpu())
            if weights == "importance":
                qz_store.store_distribution(inference_outputs["qz"])
                px_store.store_distribution(px_generative)

        # Latent sample from the last transform_batch pass of this minibatch.
        zs.append(inference_outputs["z"].cpu())
        # Average the decoded expression over the requested batches.
        minibatch_expr = torch.cat(decoded_per_batch, dim=0).mean(0).numpy()
        exprs.append(minibatch_expr)

    # With several samples the cell axis is preceded by the sample axis.
    cell_axis = 1 if n_samples > 1 else 0
    exprs = np.concatenate(exprs, axis=cell_axis)

    if n_samples_overall is not None:
        # Flatten (samples, cells) and resample the requested number of rows.
        exprs = exprs.reshape(-1, exprs.shape[-1])
        n_samples_ = exprs.shape[0]
        if weights is None or weights == "uniform":
            p = None
        else:
            qz = qz_store.get_concatenated_distributions(axis=0)
            px = px_store.get_concatenated_distributions(axis=0 if n_samples == 1 else 1)
            p = self._get_importance_weights(
                adata, indices, qz, px, torch.concat(zs, dim=cell_axis), **importance_weighting_kwargs
            )
        chosen_rows = np.random.choice(n_samples_, n_samples_overall, p=p, replace=True)
        exprs = exprs[chosen_rows]
    elif n_samples > 1 and return_mean:
        exprs = exprs.mean(0)

    if dataloader is None and not return_numpy:
        return pd.DataFrame(exprs, columns=adata.var_names[gene_mask], index=adata.obs_names[indices])
    return exprs
|
|
1377
|
+
|
|
1378
|
+
def differential_expression(
    self,
    adata: AnnData | None = None,
    groupby: str | None = None,
    group1: list[str] | None = None,
    group2: str | None = None,
    idx1: list[int] | list[bool] | str | None = None,
    idx2: list[int] | list[bool] | str | None = None,
    mode: Literal["vanilla", "change"] = "vanilla",
    delta: float = 0.25,
    batch_size: int | None = None,
    all_stats: bool = True,
    batch_correction: bool = False,
    batchid1: list[str] | None = None,
    batchid2: list[str] | None = None,
    fdr_target: float = 0.05,
    silent: bool = False,
    weights: Literal["uniform", "importance"] | None = "uniform",
    filter_outlier_cells: bool = False,
    remove_cell_cycle: bool = False,
    importance_weighting_kwargs: dict | None = None,
    **kwargs,
) -> pd.DataFrame:
    """Run a differential expression analysis by delegating to ``_de_core``.

    The method validates ``adata``, prepares a sampler built on
    :meth:`get_normalized_expression`, and hands both, together with the
    grouping and testing options, to ``_de_core``.

    Parameters
    ----------
    adata
        Data to analyse; passed through ``self._validate_anndata`` first.
    groupby, group1, group2, idx1, idx2
        Group / cell selection options, forwarded unchanged to ``_de_core``.
    mode, delta, batch_correction, batchid1, batchid2, fdr_target, all_stats, silent
        Testing options, forwarded unchanged to ``_de_core``.
    batch_size
        Forwarded to :meth:`get_normalized_expression`.
    weights
        Forwarded to :meth:`get_normalized_expression`.
    filter_outlier_cells
        When true, ``self.get_latent_representation`` is supplied to
        ``_de_core`` as the representation function; otherwise ``None`` is.
    remove_cell_cycle
        Forwarded to :meth:`get_normalized_expression`.
    importance_weighting_kwargs
        Extra keyword arguments expanded into the
        :meth:`get_normalized_expression` call; ``None`` (or any empty
        value) is treated as no extra arguments.
    **kwargs
        Additional keyword arguments forwarded to ``_de_core``.

    Returns
    -------
    The value returned by ``_de_core`` (annotated as a ``pandas.DataFrame``).
    """
    adata = self._validate_anndata(adata)
    gene_names = adata.var_names
    weighting_options = importance_weighting_kwargs if importance_weighting_kwargs else {}

    # Expression sampler given to the DE core: always NumPy output, one
    # sample per call. Keyword expansion is kept (rather than merging dicts)
    # so a key clashing with an explicit argument still raises.
    expression_sampler = partial(
        self.get_normalized_expression,
        return_numpy=True,
        n_samples=1,
        batch_size=batch_size,
        weights=weights,
        remove_cell_cycle=remove_cell_cycle,
        **weighting_options,
    )

    # A representation function is only supplied when outlier filtering is on.
    if filter_outlier_cells:
        latent_fn = self.get_latent_representation
    else:
        latent_fn = None

    # ``_de_core`` is called positionally; the argument order must not change.
    return _de_core(
        self.get_anndata_manager(adata, required=True),
        expression_sampler,
        latent_fn,
        groupby,
        group1,
        group2,
        idx1,
        idx2,
        all_stats,
        scrna_raw_counts_properties,
        gene_names,
        mode,
        batchid1,
        batchid2,
        delta,
        batch_correction,
        fdr_target,
        silent,
        **kwargs,
    )
|