cyclevi 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1438 @@
1
+ # ─────────────────────────────────────────────────────────────
2
+ # Imports
3
+ # ─────────────────────────────────────────────────────────────
4
+
5
+ import logging
6
+ import warnings
7
+ from collections.abc import Iterator, Iterable
8
+ from functools import partial
9
+ from numbers import Number
10
+ from typing import Callable, Literal
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ # PyTorch
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch import Tensor
18
+ from torch.distributions import Distribution, Normal
19
+ from torch.nn.functional import one_hot
20
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
21
+
22
+ # Single-cell analysis tools
23
+ from anndata import AnnData
24
+
25
+ # scvi-tools
26
+ from scvi import REGISTRY_KEYS, settings
27
+ from scvi.data import AnnDataManager
28
+ from scvi.data._constants import ADATA_MINIFY_TYPE
29
+ from scvi.data._utils import _get_adata_minify_type
30
+ from scvi.data.fields import (
31
+ CategoricalJointObsField,
32
+ CategoricalObsField,
33
+ LayerField,
34
+ NumericalJointObsField,
35
+ NumericalObsField,
36
+ )
37
+ from scvi.distributions._utils import DistributionConcatenator
38
+ from scvi.model._utils import (
39
+ _get_batch_code_from_category,
40
+ _init_library_size,
41
+ scrna_raw_counts_properties,
42
+ )
43
+ from scvi.model.base import (
44
+ ArchesMixin,
45
+ BaseMinifiedModeModelClass,
46
+ EmbeddingMixin,
47
+ RNASeqMixin,
48
+ UnsupervisedTrainingMixin,
49
+ VAEMixin,
50
+ )
51
+ from scvi.model.base._de_core import _de_core
52
+ from scvi.module._constants import MODULE_KEYS
53
+ from scvi.module.base import (
54
+ BaseMinifiedModeModuleClass,
55
+ BaseModuleClass,
56
+ EmbeddingModuleMixin,
57
+ LossOutput,
58
+ auto_move_data,
59
+ )
60
+ from scvi.nn import Encoder, FCLayers
61
+ from scvi.train import (
62
+ TrainingPlan,
63
+ )
64
+ from scvi.utils import (
65
+ setup_anndata_dsp,
66
+ track,
67
+ unsupported_if_adata_minified,
68
+ )
69
+
70
+ # ─────────────────────────────────────────────────────────────
71
+ # Helpers
72
+ # ─────────────────────────────────────────────────────────────
73
+
74
+ def _identity(x):
75
+ return x
76
+
77
+ # Module-level logger, named after this module via ``__name__``.
78
+ logger = logging.getLogger(__name__)
79
+
80
+
81
+ # ─────────────────────────────────────────────────────────────
82
+ # Optional: Creating cell cycle gene mask
83
+ # ─────────────────────────────────────────────────────────────
84
+
85
def create_cell_cycle_gene_mask(adata: AnnData, genes_txt: str, var_column: str = None) -> torch.Tensor:
    """
    Create a boolean mask indicating which genes are cell cycle–dependent.

    Matching is case-insensitive; identifiers read from the file additionally
    have surrounding whitespace removed.

    Parameters
    ----------
    adata : AnnData
        The AnnData object containing single-cell expression data.
    genes_txt : str
        Path to a text file with a list of cell cycle genes (one per line).
        Only the first column is used.
    var_column : str, optional
        Name of a column in adata.var to use for gene identifiers.
        If None (or empty), use adata.var_names (default).

    Returns
    -------
    torch.Tensor
        A boolean tensor of shape (n_genes,) where True indicates
        that the gene is in the cell cycle gene list.
    """
    # dtype=str keeps purely numeric identifiers as text; without it pandas
    # infers a numeric column and the ``.str`` accessor raises.
    listed = pd.read_csv(genes_txt, header=None, dtype=str)[0]
    # Drop missing entries and stray whitespace so they cannot cause silent misses.
    gene_set = set(listed.dropna().str.strip().str.upper())

    if var_column:
        genes = adata.var[var_column].astype(str).str.upper()
    else:
        # astype(str) guards against a non-string index.
        genes = adata.var_names.astype(str).str.upper()

    # Vectorised membership test; torch.tensor copies, so the result owns its memory.
    return torch.tensor(np.asarray(genes.isin(gene_set), dtype=bool), dtype=torch.bool)
114
+
115
+ # Example usage:
116
+ # Load cell cycle mask from the provided GO annotation file
117
+ # cycle_mask = create_cell_cycle_gene_mask(adata, "GO_cell_cycle_annotation_human.txt")
118
+
119
+
120
+ # ─────────────────────────────────────────────────────────────
121
+ # Cell Cycle Registry Keys
122
+ # ─────────────────────────────────────────────────────────────
123
+
124
class CYCLE_REGISTRY_KEYS:
    """String keys naming the cell-cycle fields.

    ``CYCLE_LABEL_KEY`` is the key ``PhaseAdversarialTrainingPlan.training_step``
    uses to read the phase labels out of a minibatch dictionary.
    """

    # Key for the discrete phase label.
    CYCLE_LABEL_KEY = "cycle_initiation_label"
    # Key for the phase angle (presumably a continuous angle — confirm against setup code).
    CYCLE_ANGLE_KEY = "cycle_initiation_angle"
127
+
128
+ # ─────────────────────────────────────────────────────────────
129
+ # Adversarial Classifier
130
+ # ─────────────────────────────────────────────────────────────
131
+
132
class Classifier(nn.Module):
    """
    Simple feedforward neural network for adversarial classification.

    Args:
        n_input (int): Dimensionality of input features.
        n_hidden (int): Number of units in hidden layers.
        n_labels (int): Number of output classes.
        n_layers (int): Number of linear layers (default: 2). The first
            ``n_layers - 1`` are followed by a ReLU; the last one maps to
            ``n_labels`` outputs.
        logits (bool): Whether to output raw logits (default: True). When
            False, a softmax over the last dimension is applied and class
            probabilities are returned.
    """

    def __init__(self, n_input, n_hidden, n_labels, n_layers=2, logits=True):
        super().__init__()
        self.logits = logits

        layers = []
        in_dim = n_input
        for _ in range(n_layers - 1):
            layers.append(nn.Linear(in_dim, n_hidden))
            layers.append(nn.ReLU())
            in_dim = n_hidden
        layers.append(nn.Linear(in_dim, n_labels))

        # The softmax is applied in ``forward`` rather than appended here so the
        # layer indices inside ``network`` do not depend on ``logits``.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits, or class probabilities when ``logits`` is False."""
        scores = self.network(x)
        if self.logits:
            return scores
        return torch.softmax(scores, dim=-1)
159
+
160
+
161
class PhaseAdversarialTrainingPlan(TrainingPlan):
    """
    Training plan with adversarial phase classifier to prevent cycle phase information
    leakage in non-circular latent space (z_other).

    Optimization is manual (``automatic_optimization = False``): each training
    step first updates the model (including a term that fools the classifier)
    and then updates the classifier. Because Lightning does not step learning
    rate schedulers under manual optimization, the epoch-end hooks of this
    class step the optional ``ReduceLROnPlateau`` scheduler explicitly.

    Args:
        module (BaseModuleClass): scvi-tools model module.
        scale_adversarial_loss (float | str): Scaling factor for adversarial loss, or 'auto'
            (in which case ``1 - kl_weight`` is used).
        All other arguments follow scvi-tools TrainingPlan.
    """

    # Fixed multiplier applied, together with kappa, to the fooling loss
    # before it is added to the model loss.
    _FOOL_LOSS_MULTIPLIER = 1000

    def __init__(
        self,
        module: BaseModuleClass,
        *,
        optimizer: str = "Adam",
        optimizer_creator=None,
        lr: float = 1e-3,
        weight_decay: float = 1e-6,
        n_steps_kl_warmup: int = None,
        n_epochs_kl_warmup: int = 400,
        reduce_lr_on_plateau: bool = False,
        lr_factor: float = 0.6,
        lr_patience: int = 30,
        lr_threshold: float = 0.0,
        lr_scheduler_metric: str = "elbo_validation",
        lr_min: float = 0.0,
        scale_adversarial_loss: float | str = "auto",
        compile: bool = False,
        compile_kwargs: dict | None = None,
        **loss_kwargs,
    ):
        super().__init__(
            module=module,
            optimizer=optimizer,
            optimizer_creator=optimizer_creator,
            lr=lr,
            weight_decay=weight_decay,
            n_steps_kl_warmup=n_steps_kl_warmup,
            n_epochs_kl_warmup=n_epochs_kl_warmup,
            reduce_lr_on_plateau=reduce_lr_on_plateau,
            lr_factor=lr_factor,
            lr_patience=lr_patience,
            lr_threshold=lr_threshold,
            lr_scheduler_metric=lr_scheduler_metric,
            lr_min=lr_min,
            compile=compile,
            compile_kwargs=compile_kwargs,
            **loss_kwargs,
        )

        # Setup adversarial classifier (e.g., to predict 3 discrete phases)
        self.adversarial_classifier = Classifier(
            n_input=self.module.n_latent - 2,  # exclude 2D z_cycle
            n_hidden=32,
            n_labels=3,
            n_layers=2,
            logits=True,
        )

        self.scale_adversarial_loss = scale_adversarial_loss
        self.automatic_optimization = False  # Manual optimization loop

    def loss_adversarial_classifier(self, z_other, phase_index, predict_true_class=True):
        """
        Computes adversarial loss either for classification or fooling.

        Args:
            z_other (Tensor): Latent space excluding z_cycle.
            phase_index (Tensor): True class indices.
            predict_true_class (bool): If False, trains to fool the classifier.

        Returns:
            Tensor: Loss value (cross-entropy against the chosen target, averaged over cells).
        """
        log_probs = torch.nn.functional.log_softmax(self.adversarial_classifier(z_other), dim=1)
        n_classes = log_probs.shape[1]
        true_one_hot = torch.nn.functional.one_hot(phase_index.squeeze(-1), n_classes)

        if predict_true_class:
            cls_target = true_one_hot
        else:
            # For fooling: soft target spread uniformly over all incorrect classes
            cls_target = (~true_one_hot.bool()).float() / (n_classes - 1)

        return -(log_probs * cls_target).sum(dim=1).mean()

    def training_step(self, batch, batch_idx):
        """
        Custom training step with adversarial optimization.

        Step 1: Optimize model to fool the classifier.
        Step 2: Optimize classifier to predict true phase.
        """
        if "kl_weight" in self.loss_kwargs:
            self.loss_kwargs["kl_weight"] = self.kl_weight
            self.log("kl_weight", self.kl_weight, on_step=True, on_epoch=False)

        # Determine scaling factor kappa
        if self.scale_adversarial_loss == "auto":
            kappa = 1 - self.kl_weight
        else:
            kappa = self.scale_adversarial_loss

        # Phase labels are read from the dedicated cycle-label field of the batch
        batch_cat = batch[CYCLE_REGISTRY_KEYS.CYCLE_LABEL_KEY].long().squeeze(-1)

        # Query the optimizers once; a single optimizer means no classifier step.
        optimizers = self.optimizers()
        if isinstance(optimizers, list):
            opt1, opt2 = optimizers
        else:
            opt1, opt2 = optimizers, None

        outputs, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
        z = outputs["z"]
        z_other = z[:, 2:]  # Remove 2D z_cycle from full latent

        loss = scvi_loss.loss

        # Step 1: Fool classifier (optimize model to remove phase info)
        if kappa > 0:
            fool_loss = self.loss_adversarial_classifier(z_other, batch_cat, predict_true_class=False)
            # Out-of-place add: ``scvi_loss.loss`` itself must stay untouched.
            loss = loss + fool_loss * kappa * self._FOOL_LOSS_MULTIPLIER

        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")

        opt1.zero_grad()
        self.manual_backward(loss)
        opt1.step()

        # Step 2: Train classifier (optimize to correctly predict phase)
        if opt2 is not None:
            # detach(): the classifier update must not back-propagate into the model
            cls_loss = kappa * self.loss_adversarial_classifier(
                z_other.detach(), batch_cat, predict_true_class=True
            )

            opt2.zero_grad()
            self.manual_backward(cls_loss)
            opt2.step()

    def _step_lr_scheduler(self):
        """Step the plateau scheduler with the monitored metric, if both are available."""
        scheduler = self.lr_schedulers()
        metric = self.trainer.callback_metrics.get(self.lr_scheduler_metric)
        if scheduler is None or metric is None:
            # No scheduler configured, or the metric has not been logged yet.
            return
        scheduler.step(metric)

    def on_train_epoch_end(self):
        """Step the LR scheduler when it monitors a training metric."""
        super().on_train_epoch_end()
        if self.reduce_lr_on_plateau and "validation" not in self.lr_scheduler_metric:
            self._step_lr_scheduler()

    def on_validation_epoch_end(self):
        """Step the LR scheduler when it monitors a validation metric."""
        super().on_validation_epoch_end()
        if self.reduce_lr_on_plateau and "validation" in self.lr_scheduler_metric:
            self._step_lr_scheduler()

    def configure_optimizers(self):
        """
        Returns separate optimizers for model and classifier.
        Optionally adds LR scheduler for the model.
        """
        # Optimizer for main model
        params1 = filter(lambda p: p.requires_grad, self.module.parameters())
        optimizer1 = self.get_optimizer_creator()(params1)
        config1 = {"optimizer": optimizer1}

        # Optional learning rate scheduler
        if self.reduce_lr_on_plateau:
            scheduler = ReduceLROnPlateau(
                optimizer1,
                patience=self.lr_patience,
                factor=self.lr_factor,
                threshold=self.lr_threshold,
                min_lr=self.lr_min,
                threshold_mode="abs",
            )
            config1["lr_scheduler"] = {"scheduler": scheduler, "monitor": self.lr_scheduler_metric}

        # Optimizer for adversarial classifier
        params2 = filter(lambda p: p.requires_grad, self.adversarial_classifier.parameters())
        optimizer2 = torch.optim.Adam(params2, lr=1e-3, eps=0.01, weight_decay=self.weight_decay)

        if "lr_scheduler" in config1:
            return [config1["optimizer"], optimizer2], [config1["lr_scheduler"]]
        return [config1["optimizer"], optimizer2]
323
+
324
+
325
+ # ─────────────────────────────────────────────────────────────
326
+ # Decoder
327
+ # ─────────────────────────────────────────────────────────────
328
class DecoderCycleVI(nn.Module):
    """
    Custom decoder for CycleVI model that separates gene expression into:
    - Non-cyclic components via a feedforward network.
    - Cell cycle–dependent components via Fourier basis functions.

    The first two columns of the decoder input are treated as the 2D cycle
    embedding; everything after them feeds the feedforward network.

    Arguments:
        n_input (int): Input dimension (latent + covariates).
        n_output (int): Number of output genes.
        n_layers (int): Number of hidden layers in the feedforward decoder.
        n_hidden (int): Width of each hidden layer.
        n_cat_list: A list containing the number of categories
                    for each category of interest. Each category will be
                    included using a one-hot encoding
        inject_covariates (bool): Whether to inject covariates in FCLayers.
        use_batch_norm (bool): Whether to apply batch normalization.
        use_layer_norm (bool): Whether to apply layer normalization.
        scale_activation (str): 'softmax' selects a softmax output; any other
            value selects softplus.
        cycle_gene_mask (torch.Tensor): Boolean mask over cycle-regulated genes
            (length ``n_output``). Any array-like is accepted and cast to bool.
            Defaults to all genes.
        n_fourier (int): Number of Fourier harmonics for cyclic signal.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``cycle_gene_mask`` is not 1-D of length ``n_output``.
    """

    def __init__(
        self,
        n_input: int,
        n_output: int,
        n_layers: int = 1,
        n_hidden: int = 128,
        n_cat_list: Iterable[int] = None,
        inject_covariates: bool = True,
        use_batch_norm: bool = False,
        use_layer_norm: bool = False,
        scale_activation: str = "softmax",
        cycle_gene_mask: torch.Tensor = None,
        n_fourier: int = 3,
        **kwargs
    ):
        super().__init__()

        if cycle_gene_mask is None:
            cycle_gene_mask = torch.ones(n_output, dtype=torch.bool)
        else:
            # Cast to bool so a 0/1 integer or float mask is used as a mask,
            # never as integer column indices.
            cycle_gene_mask = torch.as_tensor(cycle_gene_mask).bool()
            if cycle_gene_mask.ndim != 1 or cycle_gene_mask.shape[0] != n_output:
                raise ValueError("`cycle_gene_mask` must match n_output")

        self.register_buffer("cycle_mask", cycle_gene_mask.float())
        self.n_output = n_output
        self.n_fourier = n_fourier

        # Feedforward decoder for non-cycle component (takes z_latent only)
        self.non_cycle_fc = FCLayers(
            n_in=n_input - 2,  # exclude 2D z_cycle
            n_out=n_hidden,
            n_cat_list=n_cat_list,
            n_layers=n_layers,
            n_hidden=n_hidden,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm,
            inject_covariates=inject_covariates,
            activation_fn=nn.ReLU,
            use_activation=True,
            bias=True,
        )
        self.non_cycle_linear = nn.Linear(n_hidden, n_output)

        # Fourier weights for periodic expression modulation
        self.fourier_W = nn.Parameter(0.01 * torch.randn(2 * n_fourier, n_output))

        # Only cycle genes receive gradient through the Fourier weights. The
        # hook reuses the ``cycle_mask`` buffer, which follows the module across devices.
        self.fourier_W.register_hook(self._mask_fourier_grad)

        # Raw dispersion parameter for Negative Binomial model
        self.disp_raw = nn.Parameter(0.01 * torch.randn(n_output))

        # Output activation
        self.px_scale_activation = (
            nn.Softmax(dim=-1) if scale_activation == "softmax" else nn.Softplus()
        )

    def _mask_fourier_grad(self, grad: torch.Tensor) -> torch.Tensor:
        """Zero the Fourier-weight gradient columns of non-cycle genes."""
        # (2 * n_fourier, n_output) * (n_output,) broadcasts across harmonics.
        return grad * self.cycle_mask.to(device=grad.device, dtype=grad.dtype)

    def forward(
        self,
        z: torch.Tensor,  # Decoder input [..., n_input]; first two columns are z_cycle
        library: torch.Tensor,  # Log-library size [..., 1]
        remove_cell_cycle: bool = False,  # Disable cyclic modulation if True
        *cat_list: torch.Tensor,
    ):
        """
        Decode latent variables into expression parameters.

        Returns a 10-tuple: ``(px_scale, disp, px_rate, None, angle, None, None,
        fourier_W, None, non_cycle_out)``. ``disp`` is the raw (untransformed)
        dispersion parameter and ``angle`` has a trailing singleton dimension.
        """
        # Split latent space: z_cycle (2D) and z_latent (rest)
        z_cycle = z[..., 0:2]
        z_latent = z[..., 2:]

        # Feedforward decoder for baseline gene expression
        hidden = self.non_cycle_fc(z_latent, *cat_list)
        non_cycle_out = self.non_cycle_linear(hidden)

        # Convert 2D z_cycle into phase angle θ
        angle = torch.atan2(z_cycle[..., 1], z_cycle[..., 0])  # shape: [...]

        # Compute cycle-dependent modulation (Fourier basis)
        if remove_cell_cycle:
            cycle_effect = 0.0
        else:
            # Construct Fourier basis: all cos(kθ) first, then all sin(kθ), k in 1..n
            harmonics = range(1, self.n_fourier + 1)
            basis = [torch.cos(k * angle) for k in harmonics] + [
                torch.sin(k * angle) for k in harmonics
            ]
            fourier_basis = torch.stack(basis, dim=-1)  # shape: [..., 2 * n_fourier]
            cycle_effect = torch.matmul(fourier_basis, self.fourier_W) * self.cycle_mask

        # Combine non-cycle and cycle effects
        eta = non_cycle_out + cycle_effect  # shape: [..., n_output]

        # Output gene proportions (scale) and rate
        px_scale = self.px_scale_activation(eta)  # [..., G]
        px_rate = torch.exp(library) * px_scale  # scaled by library size
        disp = self.disp_raw  # [G]

        return (
            px_scale,  # gene proportions or log-expression rates
            disp,  # gene-wise dispersion (raw)
            px_rate,  # expected gene counts
            None,  # placeholder
            # unsqueeze(-1) keeps the cell axis in place for 3-D (sampled) inputs
            angle.unsqueeze(-1),  # inferred cell cycle phase θ
            None, None,  # unused (e.g., radius)
            self.fourier_W,  # learned Fourier weights
            None,  # placeholder
            non_cycle_out,  # non-cyclic output (for debugging)
        )
460
+
461
+ # ─────────────────────────────────────────────────────────────
462
+ # VAE
463
+ # ─────────────────────────────────────────────────────────────
464
+
465
+ class CycleVI_VAE(EmbeddingModuleMixin, BaseMinifiedModeModuleClass):
466
+ """
467
+ Cell Cycle–aware Variational Autoencoder for single-cell RNA-seq data.
468
+
469
+ This model extends the standard scVI architecture by incorporating:
470
+ - A disentangled 2D circular latent space for modeling the cell cycle.
471
+ - A custom decoder with Fourier basis functions for periodic expression.
472
+ - Support for batch correction, observed and latent library sizes, and covariates.
473
+ """
474
+
475
def __init__(
    self,
    n_input: int,
    n_batch: int = 0,
    n_labels: int = 0,
    n_hidden: int = 128,
    n_latent: int = 10,
    n_layers: int = 1,
    n_continuous_cov: int = 0,
    n_cats_per_cov: list[int] | None = None,
    dropout_rate: float = 0.1,
    dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene-label",
    log_variational: bool = True,
    gene_likelihood: Literal["zinb", "nb", "poisson"] = "nb",
    latent_distribution: Literal["normal", "ln"] = "normal",
    encode_covariates: bool = False,
    deeply_inject_covariates: bool = True,
    batch_representation: Literal["one-hot", "embedding"] = "one-hot",
    use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
    use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
    use_size_factor_key: bool = False,
    use_observed_lib_size: bool = True,
    library_log_means: np.ndarray | None = None,
    library_log_vars: np.ndarray | None = None,
    var_activation: Callable[[torch.Tensor], torch.Tensor] = None,
    extra_encoder_kwargs: dict | None = None,
    extra_decoder_kwargs: dict | None = None,
    batch_embedding_kwargs: dict | None = None,
    cycle_gene_mask: torch.Tensor | None = None,
):
    """
    Build the encoder(s) and the cycle-aware decoder.

    Notes
    -----
    - ``use_size_factor_key=True`` implies the observed library size is used
      and switches the decoder output activation to softplus.
    - When the library size is not observed, ``library_log_means`` and
      ``library_log_vars`` are required and a library encoder
      (``self.l_encoder``) is created next to the latent encoder.
    - ``dispersion`` is stored on the module; this constructor does not use
      it otherwise.
    - ``cycle_gene_mask`` and ``extra_decoder_kwargs`` are forwarded to
      ``DecoderCycleVI``.

    Raises
    ------
    ValueError
        If the library prior parameters are missing while required, or if
        ``batch_representation`` is neither ``'one-hot'`` nor ``'embedding'``.
    """
    super().__init__()

    # Store core configuration
    self.dispersion = dispersion
    self.n_latent = n_latent
    self.log_variational = log_variational
    self.gene_likelihood = gene_likelihood
    self.n_batch = n_batch
    self.n_labels = n_labels
    self.latent_distribution = latent_distribution
    self.encode_covariates = encode_covariates
    self.use_size_factor_key = use_size_factor_key
    self.use_observed_lib_size = use_size_factor_key or use_observed_lib_size

    # Handle library size modeling if not using observed values
    if not self.use_observed_lib_size:
        if library_log_means is None or library_log_vars is None:
            raise ValueError("Must provide library_log_means and library_log_vars if not using observed lib size.")
        self.register_buffer("library_log_means", torch.from_numpy(library_log_means).float())
        self.register_buffer("library_log_vars", torch.from_numpy(library_log_vars).float())

    # ─────────────────────────────────────────────────────────────
    # Setup batch representation
    # ─────────────────────────────────────────────────────────────
    self.batch_representation = batch_representation
    if batch_representation == "embedding":
        self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, **(batch_embedding_kwargs or {}))
        batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim
    elif batch_representation != "one-hot":
        raise ValueError("`batch_representation` must be either 'one-hot' or 'embedding'.")

    # ─────────────────────────────────────────────────────────────
    # Encoder Setup
    # ─────────────────────────────────────────────────────────────

    # Determine normalization configuration
    use_bn_enc = use_batch_norm in ["encoder", "both"]
    use_bn_dec = use_batch_norm in ["decoder", "both"]
    use_ln_enc = use_layer_norm in ["encoder", "both"]
    use_ln_dec = use_layer_norm in ["decoder", "both"]

    # Compute encoder input dimension (covariates only count when encoded)
    extra_cats = list([] if n_cats_per_cov is None else n_cats_per_cov)
    n_input_encoder = n_input + n_continuous_cov * encode_covariates
    if self.batch_representation == "embedding":
        n_input_encoder += batch_dim * encode_covariates
        # Batch enters as an embedding, so it is not a categorical input
        cat_list = extra_cats
    else:
        cat_list = [n_batch] + extra_cats

    encoder_cat_list = cat_list if encode_covariates else None
    _extra_encoder_kwargs = extra_encoder_kwargs or {}

    self.z_encoder = Encoder(
        n_input=n_input_encoder,
        n_output=n_latent,
        n_cat_list=encoder_cat_list,
        n_layers=n_layers,
        n_hidden=n_hidden,
        dropout_rate=dropout_rate,
        distribution=latent_distribution,
        inject_covariates=deeply_inject_covariates,
        use_batch_norm=use_bn_enc,
        use_layer_norm=use_ln_enc,
        var_activation=var_activation,
        return_dist=True,
        **_extra_encoder_kwargs,
    )

    # Library-size encoder. Inference calls ``self.l_encoder`` whenever the
    # library size is not observed, so it must exist in that configuration.
    # It is created only then, leaving the parameter set of the
    # observed-library configuration unchanged.
    if not self.use_observed_lib_size:
        self.l_encoder = Encoder(
            n_input=n_input_encoder,
            n_output=1,
            n_cat_list=encoder_cat_list,
            n_layers=1,
            n_hidden=n_hidden,
            dropout_rate=dropout_rate,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_bn_enc,
            use_layer_norm=use_ln_enc,
            var_activation=var_activation,
            return_dist=True,
            **_extra_encoder_kwargs,
        )

    # ─────────────────────────────────────────────────────────────
    # Decoder Setup
    # ─────────────────────────────────────────────────────────────

    # Decoder input: z + continuous covariates + optional batch embedding
    n_input_decoder = n_latent + n_continuous_cov
    if batch_representation == "embedding":
        n_input_decoder += batch_dim

    _extra_decoder_kwargs = extra_decoder_kwargs or {}
    self.decoder = DecoderCycleVI(
        n_input=n_input_decoder,
        n_output=n_input,
        n_layers=n_layers,
        n_hidden=n_hidden,
        n_cat_list=cat_list,
        use_batch_norm=use_bn_dec,
        use_layer_norm=use_ln_dec,
        inject_covariates=deeply_inject_covariates,
        scale_activation="softplus" if use_size_factor_key else "softmax",
        cycle_gene_mask=cycle_gene_mask,
        **_extra_decoder_kwargs,
    )
598
+
599
+ # ─────────────────────────────────────────────────────────────
600
+ # Prepare tensors for inference
601
+ # ─────────────────────────────────────────────────────────────
602
+
603
def _get_inference_input(
    self,
    tensors: dict[str, torch.Tensor | None],
    full_forward_pass: bool = False,
) -> dict[str, torch.Tensor | None]:
    """Select the tensors the inference step needs.

    Full data is used when ``full_forward_pass`` is set or the module holds
    no minified data; otherwise the cached latent posterior parameters and
    the observed library size are returned. Any other minified-data type
    raises ``NotImplementedError``.
    """
    use_full_data = full_forward_pass or self.minified_data_type is None

    if not use_full_data:
        supported_minified_types = [
            ADATA_MINIFY_TYPE.LATENT_POSTERIOR,
            ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS,
        ]
        if self.minified_data_type not in supported_minified_types:
            raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}")

        # Minified data: reuse the cached posterior instead of re-encoding
        return {
            MODULE_KEYS.QZM_KEY: tensors[REGISTRY_KEYS.LATENT_QZM_KEY],
            MODULE_KEYS.QZV_KEY: tensors[REGISTRY_KEYS.LATENT_QZV_KEY],
            REGISTRY_KEYS.OBSERVED_LIB_SIZE: tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE],
        }

    # Full data: counts and batch are mandatory, covariates are optional
    return {
        MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY],
        MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY],
        MODULE_KEYS.CONT_COVS_KEY: tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None),
        MODULE_KEYS.CAT_COVS_KEY: tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None),
    }
635
+ # ─────────────────────────────────────────────────────────────
636
+ # Prepare tensors for generative model
637
+ # ─────────────────────────────────────────────────────────────
638
+
639
def _get_generative_input(
    self,
    tensors: dict[str, torch.Tensor],
    inference_outputs: dict[str, torch.Tensor | Distribution | None],
) -> dict[str, torch.Tensor | None]:
    """Assemble the keyword arguments for the generative step.

    Combines the sampled latent state and library from ``inference_outputs``
    with the batch, label and covariate tensors of the minibatch. A size
    factor, when present, is passed on in log space.
    """
    raw_size_factor = tensors.get(REGISTRY_KEYS.SIZE_FACTOR_KEY, None)
    log_size_factor = None if raw_size_factor is None else torch.log(raw_size_factor)

    generative_input = {
        MODULE_KEYS.Z_KEY: inference_outputs[MODULE_KEYS.Z_KEY],
        MODULE_KEYS.LIBRARY_KEY: inference_outputs[MODULE_KEYS.LIBRARY_KEY],
        MODULE_KEYS.BATCH_INDEX_KEY: tensors[REGISTRY_KEYS.BATCH_KEY],
        MODULE_KEYS.Y_KEY: tensors[REGISTRY_KEYS.LABELS_KEY],
        # Covariates are optional and default to None
        MODULE_KEYS.CONT_COVS_KEY: tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None),
        MODULE_KEYS.CAT_COVS_KEY: tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None),
        MODULE_KEYS.SIZE_FACTOR_KEY: log_size_factor,
    }
    return generative_input
661
+
662
+ # ─────────────────────────────────────────────────────────────
663
+ # For each cell, computes the mean and variance of the log library size for the corresponding batch.
664
+ # ─────────────────────────────────────────────────────────────
665
+
666
+ def _compute_local_library_params(
667
+ self,
668
+ batch_index: torch.Tensor,
669
+ ) -> tuple[torch.Tensor, torch.Tensor]:
670
+ """
671
+ Computes local library parameters.
672
+
673
+ For each cell, computes the mean and variance of the log library size
674
+ for the corresponding batch.
675
+ """
676
+ from torch.nn.functional import linear
677
+
678
+ n_batch = self.library_log_means.shape[1] # Number of batches from the library means buffer
679
+ # Compute local means using one-hot encoding for the batch index and linear transformation
680
+ local_library_log_means = linear(
681
+ one_hot(batch_index.squeeze(-1), n_batch).float(), self.library_log_means
682
+ )
683
+ # Compute local variances similarly
684
+ local_library_log_vars = linear(
685
+ one_hot(batch_index.squeeze(-1), n_batch).float(), self.library_log_vars
686
+ )
687
+
688
+ return local_library_log_means, local_library_log_vars
689
+
690
# ─────────────────────────────────────────────────────────────
# Encodes input data into latent variables
# ─────────────────────────────────────────────────────────────
@auto_move_data  # Automatically move inputs/outputs to the correct device (CPU/GPU)
def _regular_inference(
    self,
    x: torch.Tensor,
    batch_index: torch.Tensor,
    cont_covs: torch.Tensor | None = None,
    cat_covs: torch.Tensor | None = None,
    n_samples: int = 1,
) -> dict[str, torch.Tensor | Distribution | None]:
    """Run the regular inference process with normalization by library size.

    Returns a dict holding the sampled latent ``z``, its posterior ``qz``,
    the library posterior ``ql`` (``None`` when the observed library size is
    used) and the library in log space.
    """
    # Step 1: Compute observed library size (sum over genes per cell)
    observed_library = torch.sum(x, dim=1, keepdim=True)  # shape [N, 1]

    # Step 2: Normalize expression per cell
    x_normalized = x / (observed_library + 1e-8)  # Add small epsilon to avoid division by zero

    # Step 3: Apply log1p for numerical stability if enabled
    if self.log_variational:
        x_normalized = torch.log1p(x_normalized)

    # Step 4: Prepare encoder input
    if cont_covs is not None and self.encode_covariates:
        encoder_input = torch.cat((x_normalized, cont_covs), dim=-1)
    else:
        encoder_input = x_normalized

    if cat_covs is not None and self.encode_covariates:
        categorical_input = torch.split(cat_covs, 1, dim=1)
    else:
        categorical_input = ()

    if self.batch_representation == "embedding" and self.encode_covariates:
        batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
        encoder_input = torch.cat([encoder_input, batch_rep], dim=-1)
        qz, z = self.z_encoder(encoder_input, *categorical_input)
    else:
        qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)

    # The decoder expects the library in log space. Only the observed library
    # is a raw count total that needs the log; an encoded library is used as
    # produced by the library encoder and must not be logged a second time.
    library = torch.log(observed_library + 1e-8)

    # If you're not using observed_lib_size, compute encoded one
    ql = None
    if not self.use_observed_lib_size:
        if self.batch_representation == "embedding":
            ql, library = self.l_encoder(encoder_input, *categorical_input)
        else:
            ql, library = self.l_encoder(encoder_input, batch_index, *categorical_input)

    # Expand for MC sampling if needed
    if n_samples > 1:
        untran_z = qz.sample((n_samples,))
        z = self.z_encoder.z_transformation(untran_z)
        if self.use_observed_lib_size:
            library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1)))
        else:
            library = ql.sample((n_samples,))

    return {
        MODULE_KEYS.Z_KEY: z,
        MODULE_KEYS.QZ_KEY: qz,
        MODULE_KEYS.QL_KEY: ql,
        MODULE_KEYS.LIBRARY_KEY: library,  # log-scale, used in decoder
    }
755
+
756
+
757
@auto_move_data
def _cached_inference(
    self,
    qzm: torch.Tensor,  # Cached latent mean values
    qzv: torch.Tensor,  # Cached latent variance values
    observed_lib_size: torch.Tensor,  # Observed library size values
    n_samples: int = 1,  # Number of samples for Monte Carlo approximation
) -> dict[str, torch.Tensor | None]:
    """Run inference from cached posterior parameters instead of the encoder.

    The latent posterior is rebuilt from its cached mean and variance, sampled,
    and passed through the encoder's latent transformation. The library is the
    log of the observed library size.
    """
    # Normal takes a standard deviation, hence the square root of the variance
    qz = Normal(qzm, qzv.sqrt())

    # sample() is the non-reparameterized draw
    if n_samples == 1:
        untran_z = qz.sample()
    else:
        untran_z = qz.sample((n_samples,))
    z = self.z_encoder.z_transformation(untran_z)

    library = torch.log(observed_lib_size)
    if n_samples > 1:
        # Repeat the library along a leading sample axis to line up with z
        n_cells, n_cols = library.size(0), library.size(1)
        library = library.unsqueeze(0).expand((n_samples, n_cells, n_cols))

    return {
        MODULE_KEYS.Z_KEY: z,
        MODULE_KEYS.QZ_KEY: qz,
        MODULE_KEYS.QL_KEY: None,
        MODULE_KEYS.LIBRARY_KEY: library,
    }
784
+
785
# ─────────────────────────────────────────────────────────────
# Decodes latent z back to gene expression
# ─────────────────────────────────────────────────────────────
@auto_move_data
def generative(
    self,
    z,
    library,
    batch_index,
    cont_covs=None,
    cat_covs=None,
    size_factor=None,
    y=None,
    transform_batch=None,
    remove_cell_cycle: bool = False,
):
    """Decode latent ``z`` into the expression likelihood and the priors.

    Parameters
    ----------
    z
        Latent sample. When it has one more dimension than the covariates /
        batch embedding, those are expanded along a new leading dimension to
        match it (the Monte Carlo sampling case).
    library
        Library size term; used as the decoder size factor unless
        ``self.use_size_factor_key`` is truthy.
    batch_index
        Batch code per cell.
    cont_covs
        Optional continuous covariates, concatenated to ``z`` on the last axis.
    cat_covs
        Optional categorical covariates, split column-wise and passed to the
        decoder as separate positional inputs.
    size_factor
        Only used when ``self.use_size_factor_key`` is truthy; otherwise it is
        replaced by ``library``.
    y
        Not used by this implementation; kept in the signature for callers
        that pass it.
    transform_batch
        If not ``None``, every cell is decoded as if it belonged to this batch
        code. The override is applied *before* the batch representation is
        built so that it actually reaches the decoder.
    remove_cell_cycle
        Forwarded to the decoder as its third positional argument.
        NOTE(review): presumably asks the decoder to drop its cell-cycle
        component - confirm against the decoder implementation.

    Returns
    -------
    dict
        ``MODULE_KEYS.PX_KEY`` (expression likelihood), ``MODULE_KEYS.PL_KEY``
        (library prior, ``None`` when the observed library size is used),
        ``MODULE_KEYS.PZ_KEY`` (standard-normal prior shaped like ``z``) and
        ``"angle"`` (the angle returned by the decoder).

    Raises
    ------
    ValueError
        If ``self.gene_likelihood`` is not one of ``"zinb"``, ``"nb"``,
        ``"poisson"`` or ``"normal"``.
    """
    from scvi.distributions import NegativeBinomial, Normal, Poisson, ZeroInflatedNegativeBinomial

    # 0. Resolve the counterfactual batch first. Doing this after the batch
    #    embedding / one-hot input has been built (as before) meant that
    #    ``transform_batch`` never influenced the decoded expression.
    if transform_batch is not None:
        batch_index = torch.ones_like(batch_index) * transform_batch

    # 1. Build decoder_input = [z (+ cont_covs)]
    if cont_covs is None:
        decoder_input = z
    elif z.dim() != cont_covs.dim():
        decoder_input = torch.cat(
            [z, cont_covs.unsqueeze(0).expand(z.size(0), -1, -1)], dim=-1
        )
    else:
        decoder_input = torch.cat([z, cont_covs], dim=-1)

    # 2. Add the batch *embedding* if we're in embedding mode
    if self.batch_representation == "embedding":
        batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index)
        # make dims match if we're in MC-sampling mode (z can be [n_samples, n_cells, ...])
        if decoder_input.dim() != batch_rep.dim():
            batch_rep = batch_rep.unsqueeze(0).expand(decoder_input.size(0), -1, -1)
        decoder_input = torch.cat([decoder_input, batch_rep], dim=-1)
        # IMPORTANT: in this mode, batch is NOT part of the categorical inputs

    # 3. Build categorical inputs for the decoder, starting with any
    #    categorical covariates registered in setup_anndata(...)
    if cat_covs is not None:
        categorical_input = torch.split(cat_covs, 1, dim=1)
    else:
        categorical_input = ()

    # If batch is represented as one-hot (not embedding), it leads the categorical inputs
    if self.batch_representation == "one-hot":
        categorical_input = (batch_index, *categorical_input)

    # 4. size_factor / library handling
    if not self.use_size_factor_key:
        size_factor = library

    # 5. Run the decoder. It returns a 10-tuple; only the scale, dispersion,
    #    rate and angle are consumed here.
    (
        px_scale,
        disp,
        px_rate,
        _,
        angle,
        _radius,
        _,
        _w_fourier,
        _,
        _baseline,
    ) = self.decoder(
        decoder_input,
        size_factor,
        remove_cell_cycle,
        *categorical_input,
    )

    px_r = torch.exp(disp)

    # 6. Likelihood over counts
    if self.gene_likelihood == "zinb":
        # NOTE(review): zi_logits is passed as None, i.e. no dropout logits
        # from the decoder are wired in - confirm that "zinb" is meant to be
        # supported by this module.
        px = ZeroInflatedNegativeBinomial(mu=px_rate, theta=px_r, zi_logits=None, scale=px_scale)
    elif self.gene_likelihood == "nb":
        px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale)
    elif self.gene_likelihood == "poisson":
        px = Poisson(rate=px_rate, scale=px_scale)
    elif self.gene_likelihood == "normal":
        px = Normal(px_rate, px_r, normal_mu=px_scale)
    else:
        raise ValueError(
            f"Unsupported gene_likelihood {self.gene_likelihood!r}; "
            "expected one of 'zinb', 'nb', 'poisson', 'normal'."
        )

    # 7. Priors
    if self.use_observed_lib_size:
        pl = None
    else:
        local_library_log_means, local_library_log_vars = self._compute_local_library_params(batch_index)
        pl = Normal(local_library_log_means, local_library_log_vars.sqrt())

    pz = Normal(torch.zeros_like(z), torch.ones_like(z))

    return {
        MODULE_KEYS.PX_KEY: px,
        MODULE_KEYS.PL_KEY: pl,
        MODULE_KEYS.PZ_KEY: pz,
        "angle": angle,
    }
885
+
886
+
887
# ─────────────────────────────────────────────────────────────
# Loss function
# ─────────────────────────────────────────────────────────────
@unsupported_if_adata_minified  # not available when the AnnData is in minified mode
def loss(
    self,
    tensors: dict[str, torch.Tensor],
    inference_outputs: dict[str, torch.Tensor | Distribution | None],
    generative_outputs: dict[str, Distribution | torch.Tensor | None],
    kl_weight: torch.Tensor | float = 1.0,
) -> LossOutput:
    """Combine reconstruction, KL, angle-alignment and radius terms.

    The total loss is::

        mean(reconstruction + kl_weight * KL_z + KL_l)
        + 0.5 * 100 * (1 - kl_weight)**4 * angle_loss
        + 100 * kl_weight**2 * radius_penalty

    where ``angle_loss = mean(1 - cos(angle - target_angle))`` compares the
    polar angle of the first two posterior-mean coordinates with the angle
    registered under ``CYCLE_REGISTRY_KEYS.CYCLE_ANGLE_KEY``, and
    ``radius_penalty = mean(exp(-10 * radius))`` grows as that 2-D point
    approaches the origin.

    Parameters
    ----------
    tensors
        Minibatch tensors; the counts, batch, labels and cycle-angle entries
        are read.
    inference_outputs
        Output of ``inference``; ``qz``, ``ql`` and ``z`` are read.
    generative_outputs
        Output of ``generative``; ``px``, ``pl`` and ``pz`` are read.
    kl_weight
        Weight on the latent KL term; it also drives the schedule of the
        angle and radius terms as shown above.

    Returns
    -------
    LossOutput
        Total loss, per-cell reconstruction loss, the two local KL terms and
        extra metrics (``z``, batch, labels and the raw / weighted angle and
        radius terms).
    """
    from torch.distributions import kl_divergence

    counts = tensors[REGISTRY_KEYS.X_KEY]
    posterior_z = inference_outputs[MODULE_KEYS.QZ_KEY]

    # KL(q(z|x) || p(z)), summed over latent dimensions
    kl_divergence_z = kl_divergence(
        posterior_z, generative_outputs[MODULE_KEYS.PZ_KEY]
    ).sum(dim=-1)

    # KL on the library size only exists when the library is modelled
    if self.use_observed_lib_size:
        kl_divergence_l = torch.zeros_like(kl_divergence_z)
    else:
        kl_divergence_l = kl_divergence(
            inference_outputs[MODULE_KEYS.QL_KEY], generative_outputs[MODULE_KEYS.PL_KEY]
        ).sum(dim=1)

    # Negative log-likelihood of the counts, summed over the last axis
    reconst_loss = -generative_outputs[MODULE_KEYS.PX_KEY].log_prob(counts).sum(-1)

    # Polar coordinates of the first two posterior-mean dimensions
    posterior_mean = posterior_z.loc
    first_coord = posterior_mean[:, 0]
    second_coord = posterior_mean[:, 1]
    angle = torch.atan2(second_coord, first_coord)
    radius = torch.sqrt(first_coord**2 + second_coord**2 + 1e-6)

    # Reference angle registered through setup_anndata
    target_angle = tensors[CYCLE_REGISTRY_KEYS.CYCLE_ANGLE_KEY].squeeze(-1)

    # Angle loss: 1 - cos of the angular difference (0 when aligned)
    angle_loss = torch.mean(1.0 - torch.cos(angle - target_angle))

    # Radius penalty: large when the point sits near the origin
    radius_penalty = torch.mean(torch.exp(-10 * radius))

    # Weighting: the angle term dominates while kl_weight is small, the
    # radius term grows with kl_weight
    weighted_kl_local = kl_weight * kl_divergence_z + kl_divergence_l
    cycle_pos_weight = 100 * (1.0 - kl_weight)**4
    weighted_angle_loss = 0.5 * cycle_pos_weight * angle_loss
    weighted_radius_penalty = 100 * (kl_weight**2) * radius_penalty

    total = torch.mean(reconst_loss + weighted_kl_local) + weighted_angle_loss + weighted_radius_penalty

    kl_terms = {
        MODULE_KEYS.KL_L_KEY: kl_divergence_l,
        MODULE_KEYS.KL_Z_KEY: kl_divergence_z,
    }
    metrics = {
        "z": inference_outputs["z"],
        "batch": tensors[REGISTRY_KEYS.BATCH_KEY],
        "labels": tensors[REGISTRY_KEYS.LABELS_KEY],
        "angle_loss": angle_loss,
        "radius_penalty": radius_penalty,
        "weighted_angle_loss": weighted_angle_loss,
        "weighted_radius_penalty": weighted_radius_penalty,
    }
    return LossOutput(
        loss=total,
        reconstruction_loss=reconst_loss,
        kl_local=kl_terms,
        extra_metrics=metrics,
    )
960
+
961
+
962
+
963
# ─────────────────────────────────────────────────────────────
# Samples gene expression from the posterior predictive distribution
# ─────────────────────────────────────────────────────────────
@torch.inference_mode()
def sample(
    self,
    tensors: dict[str, torch.Tensor],
    n_samples: int = 1,
    max_poisson_rate: float = 1e8,
) -> torch.Tensor:
    r"""Draw samples from the posterior predictive :math:`p(\hat{x} \mid x)`.

    A forward pass (without loss computation) is run with ``n_samples``
    latent draws from :math:`q(z \mid x)`; one draw is then taken from the
    resulting likelihood :math:`p(\hat{x} \mid z)`.

    Parameters
    ----------
    tensors
        Minibatch tensors passed to ``self.forward``.
    n_samples
        Number of latent samples per observation.
    max_poisson_rate
        Upper clamp applied to the Poisson rate; only used when
        ``self.gene_likelihood == "poisson"``.

    Returns
    -------
    torch.Tensor
        Samples on CPU. When ``n_samples > 1`` the leading sample axis is
        moved to the end via ``permute(1, 2, 0)``.
    """
    from scvi.distributions import Poisson

    _, generative_outputs = self.forward(
        tensors, inference_kwargs={"n_samples": n_samples}, compute_loss=False
    )
    likelihood = generative_outputs[MODULE_KEYS.PX_KEY]

    if self.gene_likelihood == "poisson":
        # Rebuild the Poisson with a clamped rate; on MPS devices the rate is
        # moved to CPU first.
        if self.device.type == "mps":
            rate = likelihood.rate.to("cpu")
        else:
            rate = likelihood.rate
        likelihood = Poisson(torch.clamp(rate, max=max_poisson_rate))

    draws = likelihood.sample()
    if n_samples > 1:
        # (n_samples, a, b) -> (a, b, n_samples)
        draws = torch.permute(draws, (1, 2, 0))

    return draws.cpu()
1007
+
1008
# ─────────────────────────────────────────────────────────────
# Estimates marginal log-likelihood with Monte Carlo sampling
# ─────────────────────────────────────────────────────────────
@torch.inference_mode()
@auto_move_data
def marginal_ll(
    self,
    tensors: dict[str, torch.Tensor],
    n_mc_samples: int,
    return_mean: bool = False,
    n_mc_samples_per_pass: int = 1,
):
    """Compute the marginal log-likelihood of the data under the model.

    Importance-sampling estimate: for each posterior draw the log weight
    ``log p(z) + log p(x | z, l) - log q(z | x)`` (plus the library terms
    when the library size is modelled) is accumulated, and the estimate is
    ``logsumexp(weights) - log(K)`` with ``K`` the number of draws.

    Parameters
    ----------
    tensors
        Minibatch tensors passed to ``self.forward``.
    n_mc_samples
        Requested number of Monte Carlo samples.
    return_mean
        If ``True`` return the mean over cells as a Python float, otherwise
        a CPU tensor with one value per cell.
    n_mc_samples_per_pass
        Samples drawn per forward pass (memory/speed trade-off). Capped at
        ``n_mc_samples`` with a ``RuntimeWarning``.

    Notes
    -----
    ``ceil(n_mc_samples / n_mc_samples_per_pass)`` full passes are run, so
    slightly more than ``n_mc_samples`` draws can be used when the two are
    not multiples. The estimate is normalised by the number of draws
    actually used; normalising by ``n_mc_samples`` in that case would bias
    the result upward.
    """
    from torch import logsumexp
    from torch.distributions import Normal

    batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]

    if n_mc_samples_per_pass > n_mc_samples:
        warnings.warn(
            "Number of chunks is larger than the total number of samples, setting it to the "
            "number of samples",
            RuntimeWarning,
            stacklevel=settings.warnings_stacklevel,
        )
        n_mc_samples_per_pass = n_mc_samples
    n_passes = int(np.ceil(n_mc_samples / n_mc_samples_per_pass))

    to_sum = []  # per-pass log importance weights, stacked on dim 0 below
    for _ in range(n_passes):
        # Full forward pass: we need both the posterior and the loss terms
        inference_outputs, _, losses = self.forward(
            tensors,
            inference_kwargs={"n_samples": n_mc_samples_per_pass},
            get_inference_input_kwargs={"full_forward_pass": True},
        )
        qz = inference_outputs[MODULE_KEYS.QZ_KEY]
        ql = inference_outputs[MODULE_KEYS.QL_KEY]
        z = inference_outputs[MODULE_KEYS.Z_KEY]
        library = inference_outputs[MODULE_KEYS.LIBRARY_KEY]

        reconst_loss = losses.dict_sum(losses.reconstruction_loss)

        # log p(z) under the standard-normal prior, log p(x | z, l), log q(z | x)
        p_z = (
            Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale)).log_prob(z).sum(dim=-1)
        )
        p_x_zl = -reconst_loss
        q_z_x = qz.log_prob(z).sum(dim=-1)
        log_prob_sum = p_z + p_x_zl - q_z_x

        if not self.use_observed_lib_size:
            # Library size is latent: add its prior / posterior ratio
            local_library_log_means, local_library_log_vars = self._compute_local_library_params(batch_index)
            p_l = (
                Normal(local_library_log_means, local_library_log_vars.sqrt())
                .log_prob(library)
                .sum(dim=-1)
            )
            q_l_x = ql.log_prob(library).sum(dim=-1)
            log_prob_sum += p_l - q_l_x

        if n_mc_samples_per_pass == 1:
            # A single-sample pass has no sample axis; add one for stacking
            log_prob_sum = log_prob_sum.unsqueeze(0)

        to_sum.append(log_prob_sum)

    to_sum = torch.cat(to_sum, dim=0)
    # Normalise by the number of draws actually stacked (n_passes * per_pass),
    # which can exceed n_mc_samples when it is not a multiple of the pass size.
    n_drawn = to_sum.shape[0]
    batch_log_lkl = logsumexp(to_sum, dim=0) - np.log(n_drawn)

    if return_mean:
        batch_log_lkl = torch.mean(batch_log_lkl).item()
    else:
        batch_log_lkl = batch_log_lkl.cpu()
    return batch_log_lkl
1084
+
1085
+
1086
+ # ─────────────────────────────────────────────────────────────
1087
+ # Model
1088
+ # ─────────────────────────────────────────────────────────────
1089
class CycleVI(
    EmbeddingMixin,
    RNASeqMixin,
    VAEMixin,
    ArchesMixin,
    UnsupervisedTrainingMixin,
    BaseMinifiedModeModelClass,
):
    """scvi-tools model wrapper around :class:`CycleVI_VAE`.

    Parameters
    ----------
    adata
        AnnData registered via :meth:`setup_anndata`. May be ``None``, in
        which case module creation is postponed until ``train`` is called.
    n_hidden
        Hidden units per layer.
    n_latent
        Dimensionality of the latent space.
    n_layers
        Number of layers in the encoder/decoder networks.
    dropout_rate
        Dropout rate.
    dispersion
        Dispersion parameterisation, forwarded to the module.
    gene_likelihood
        Likelihood for gene expression, forwarded to the module.
    latent_distribution
        Latent distribution type, forwarded to the module.
    cycle_gene_mask
        Optional mask marking cycle genes, forwarded to the module.
    **kwargs
        Additional keyword arguments forwarded to the module.
    """

    # Tell scvi-tools which module class this model uses
    _module_cls = CycleVI_VAE

    # Keys for storing the latent mean and variance in AnnData
    _LATENT_QZM_KEY = "ccvi_latent_qzm"
    _LATENT_QZV_KEY = "ccvi_latent_qzv"

    # Training plan used by ``train``
    _training_plan_cls = PhaseAdversarialTrainingPlan

    # ─────────────────────────────────────────────────────────────
    # Constructor
    # ─────────────────────────────────────────────────────────────
    def __init__(
        self,
        adata: AnnData | None = None,
        n_hidden: int = 128,
        n_latent: int = 10,
        n_layers: int = 1,
        dropout_rate: float = 0.1,
        dispersion: Literal[...] = "gene-label",
        gene_likelihood: Literal[...] = "nb",
        latent_distribution: Literal[...] = "normal",
        cycle_gene_mask: torch.Tensor | None = None,
        **kwargs,
    ):
        super().__init__(adata)

        # Configuration for the underlying PyTorch module; stored so the
        # module can also be built later (lazy initialisation).
        self._module_kwargs = {
            "n_hidden": n_hidden,
            "n_latent": n_latent,
            "n_layers": n_layers,
            "dropout_rate": dropout_rate,
            "dispersion": dispersion,
            "gene_likelihood": gene_likelihood,
            "latent_distribution": latent_distribution,
            "cycle_gene_mask": cycle_gene_mask,
            **kwargs,
        }

        # Human-readable summary of the model configuration
        self._model_summary_string = (
            "CycleVI model with the following parameters: \n"
            f"n_hidden: {n_hidden}, n_latent: {n_latent}, n_layers: {n_layers}, "
            f"dropout_rate: {dropout_rate}, dispersion: {dispersion}, "
            f"gene_likelihood: {gene_likelihood}, latent_distribution: {latent_distribution}, cycle_gene_mask: {cycle_gene_mask} "
        )

        if self._module_init_on_train:
            # No adata yet: postpone module creation until training
            self.module = None
            warnings.warn(
                "Model was initialized without `adata`. The module will be initialized when "
                "calling `train`. This behavior is experimental and may change in the future.",
                UserWarning,
                stacklevel=settings.warnings_stacklevel,
            )
        else:
            # Number of categories per registered categorical covariate
            n_cats_per_cov = (
                self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key
                if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry
                else None
            )

            n_batch = self.summary_stats.n_batch

            # Were per-cell size factors registered?
            use_size_factor_key = REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry

            # Library-size prior parameters are only initialised when no size
            # factor is registered and the data is not latent-posterior minified
            library_log_means, library_log_vars = None, None
            if (
                not use_size_factor_key
                and self.minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR
            ):
                library_log_means, library_log_vars = _init_library_size(
                    self.adata_manager, n_batch
                )

            # Instantiate the module
            self.module = self._module_cls(
                n_input=self.summary_stats.n_vars,
                n_batch=n_batch,
                n_labels=self.summary_stats.n_labels,
                n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
                n_cats_per_cov=n_cats_per_cov,
                n_hidden=n_hidden,
                n_latent=n_latent,
                n_layers=n_layers,
                dropout_rate=dropout_rate,
                dispersion=dispersion,
                gene_likelihood=gene_likelihood,
                latent_distribution=latent_distribution,
                use_size_factor_key=use_size_factor_key,
                library_log_means=library_log_means,
                library_log_vars=library_log_vars,
                cycle_gene_mask=cycle_gene_mask,
                **kwargs,
            )

            # Propagate the minified type to the module
            self.module.minified_data_type = self.minified_data_type

        # Save init parameters (must be called from __init__ itself: it reads locals())
        self.init_params_ = self._get_init_params(locals())

    # ─────────────────────────────────────────────
    # Register data with scvi-tools (what to read from AnnData and where)
    # ─────────────────────────────────────────────
    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        batch_key: str | None = None,
        labels_key: str | None = None,
        size_factor_key: str | None = None,
        categorical_covariate_keys: list[str] | None = None,
        continuous_covariate_keys: list[str] | None = None,
        cycle_initiation_label_key: str | None = None,  # e.g. "phase"
        cycle_initiation_angle_key: str | None = None,  # e.g. "cycle_angle_uniform"
        **kwargs,
    ):
        """
        Register AnnData fields for CycleVI.

        Parameters
        ----------
        %(param_adata)s
        %(param_layer)s
        %(param_batch_key)s
        %(param_labels_key)s
        %(param_size_factor_key)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s
        cycle_initiation_label_key
            Key of the categorical phase label, registered under
            ``CYCLE_REGISTRY_KEYS.CYCLE_LABEL_KEY``.
        cycle_initiation_angle_key
            Key of the numerical phase angle, registered under
            ``CYCLE_REGISTRY_KEYS.CYCLE_ANGLE_KEY``. Required: the field is
            registered as a required numerical field.

        Raises
        ------
        ValueError
            If ``cycle_initiation_angle_key`` is ``None``.

        Notes
        -----
        Phase inputs are registered separately and NOT injected as covariates:
          - phase labels -> CYCLE_REGISTRY_KEYS.CYCLE_LABEL_KEY
          - phase angle  -> CYCLE_REGISTRY_KEYS.CYCLE_ANGLE_KEY
        """
        # Fail early with an actionable message instead of the generic
        # "`attr_key` cannot be `None`" raised by the required field below.
        # (No new locals are introduced before the locals() call.)
        if cycle_initiation_angle_key is None:
            raise ValueError(
                "`cycle_initiation_angle_key` must be provided: CycleVI registers the "
                "cycle angle as a required numerical field."
            )

        setup_method_args = cls._get_setup_method_args(**locals())

        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            NumericalObsField(REGISTRY_KEYS.SIZE_FACTOR_KEY, size_factor_key, required=False),
            # regular covariates (kept separate from phase)
            CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
            # *separate* phase inputs
            CategoricalObsField(CYCLE_REGISTRY_KEYS.CYCLE_LABEL_KEY, cycle_initiation_label_key),
            NumericalObsField(CYCLE_REGISTRY_KEYS.CYCLE_ANGLE_KEY, cycle_initiation_angle_key),
        ]

        adata_minify_type = _get_adata_minify_type(adata)
        if adata_minify_type is not None:
            anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type)

        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    @torch.inference_mode()
    def get_normalized_expression(
        self,
        adata: AnnData | None = None,
        indices: list[int] | None = None,
        transform_batch: list[Number | str] | None = None,
        gene_list: list[str] | None = None,
        library_size: float | Literal["latent"] = 1,
        n_samples: int = 1,
        n_samples_overall: int | None = None,
        weights: Literal["uniform", "importance"] | None = None,
        batch_size: int | None = None,
        return_mean: bool = True,
        return_numpy: bool | None = None,
        silent: bool = True,
        dataloader: Iterator[dict[str, Tensor | None]] | None = None,
        remove_cell_cycle: bool = False,
        **importance_weighting_kwargs,
    ) -> np.ndarray | pd.DataFrame:
        """Return normalized expression decoded by the module.

        Parameters
        ----------
        adata
            AnnData to use; validated with ``self._validate_anndata``. Ignored
            when ``dataloader`` is given.
        indices
            Cells to use; defaults to all cells of ``adata``.
        transform_batch
            Batch(es) to decode under; results are averaged over the listed
            batches. Not used when ``dataloader`` is given.
        gene_list
            Restrict the output to these genes (matched against
            ``adata.var_names``). Not used when ``dataloader`` is given.
        library_size
            ``"latent"`` returns the likelihood's ``"mu"``; any number returns
            its ``"scale"`` multiplied by that number.
        n_samples
            Number of posterior samples per cell.
        n_samples_overall
            If given, ``n_samples`` must be 1; enough samples per cell are
            drawn and then ``n_samples_overall`` rows are resampled (with
            replacement) from the pooled result.
        weights
            Resampling weights for ``n_samples_overall``: ``None``/``"uniform"``
            for uniform, ``"importance"`` to use ``self._get_importance_weights``.
        batch_size
            Minibatch size for the data loader.
        return_mean
            Average over samples when ``n_samples > 1``.
        return_numpy
            Return an array instead of a DataFrame. Forced to ``True`` when
            ``n_samples > 1`` and ``return_mean`` is ``False``.
        silent
            Disable the progress bar over ``transform_batch``.
        dataloader
            Pre-built iterator of minibatches to use instead of ``adata``;
            the result is then always returned as an array.
        remove_cell_cycle
            Forwarded to the module's ``generative`` step.
        **importance_weighting_kwargs
            Forwarded to ``self._get_importance_weights``.
        """
        if dataloader is None:
            adata = self._validate_anndata(adata)
            if indices is None:
                indices = np.arange(adata.n_obs)
            if n_samples_overall is not None:
                assert n_samples == 1
                n_samples = n_samples_overall // len(indices) + 1
            scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
            transform_batch = _get_batch_code_from_category(
                self.get_anndata_manager(adata, required=True), transform_batch
            )
            gene_mask = slice(None) if gene_list is None else adata.var_names.isin(gene_list)
        else:
            scdl = dataloader
            gene_mask = slice(None)
            transform_batch = [None]

        if n_samples > 1 and return_mean is False:
            if return_numpy is False:
                warnings.warn(
                    "return_numpy must be True if n_samples > 1 and return_mean is False.",
                    UserWarning,
                    stacklevel=settings.warnings_stacklevel,
                )
            return_numpy = True

        generative_output_key = "mu" if library_size == "latent" else "scale"
        scaling = 1 if library_size == "latent" else library_size

        exprs = []
        zs = []
        qz_store = DistributionConcatenator()
        px_store = DistributionConcatenator()

        for tensors in scdl:
            per_batch_exprs = []
            for batch in track(transform_batch, disable=silent):
                generative_kwargs = self._get_transform_batch_gen_kwargs(batch)
                generative_kwargs["remove_cell_cycle"] = remove_cell_cycle
                inference_outputs, generative_outputs = self.module.forward(
                    tensors=tensors,
                    inference_kwargs={"n_samples": n_samples},
                    generative_kwargs=generative_kwargs,
                    compute_loss=False,
                )
                px_generative = generative_outputs["px"]
                exp_ = px_generative.get_normalized(generative_output_key)
                exp_ = exp_[..., gene_mask] * scaling
                per_batch_exprs.append(exp_[None].cpu())
                if weights == "importance":
                    # Keep the distributions needed to compute importance weights
                    qz_store.store_distribution(inference_outputs["qz"])
                    px_store.store_distribution(px_generative)

            zs.append(inference_outputs["z"].cpu())
            # Average over the requested transform batches
            per_batch_exprs = torch.cat(per_batch_exprs, dim=0).mean(0).numpy()
            exprs.append(per_batch_exprs)

        # With several samples the leading axis is the sample axis, so cells are on axis 1
        cell_axis = 1 if n_samples > 1 else 0
        exprs = np.concatenate(exprs, axis=cell_axis)

        if n_samples_overall is not None:
            # Pool all (sample, cell) rows and resample n_samples_overall of them
            exprs = exprs.reshape(-1, exprs.shape[-1])
            n_samples_ = exprs.shape[0]
            if weights is None or weights == "uniform":
                p = None
            else:
                qz = qz_store.get_concatenated_distributions(axis=0)
                px = px_store.get_concatenated_distributions(axis=0 if n_samples == 1 else 1)
                p = self._get_importance_weights(
                    adata, indices, qz, px, torch.concat(zs, dim=cell_axis), **importance_weighting_kwargs
                )
            exprs = exprs[np.random.choice(n_samples_, n_samples_overall, p=p, replace=True)]
        elif n_samples > 1 and return_mean:
            exprs = exprs.mean(0)

        if (return_numpy is None or not return_numpy) and dataloader is None:
            return pd.DataFrame(exprs, columns=adata.var_names[gene_mask], index=adata.obs_names[indices])
        return exprs

    def differential_expression(
        self,
        adata: AnnData | None = None,
        groupby: str | None = None,
        group1: list[str] | None = None,
        group2: str | None = None,
        idx1: list[int] | list[bool] | str | None = None,
        idx2: list[int] | list[bool] | str | None = None,
        mode: Literal["vanilla", "change"] = "vanilla",
        delta: float = 0.25,
        batch_size: int | None = None,
        all_stats: bool = True,
        batch_correction: bool = False,
        batchid1: list[str] | None = None,
        batchid2: list[str] | None = None,
        fdr_target: float = 0.05,
        silent: bool = False,
        weights: Literal["uniform", "importance"] | None = "uniform",
        filter_outlier_cells: bool = False,
        remove_cell_cycle: bool = False,
        importance_weighting_kwargs: dict | None = None,
        **kwargs,
    ) -> pd.DataFrame:
        """Run differential expression through scvi-tools' ``_de_core``.

        Expression is obtained from :meth:`get_normalized_expression` with
        ``return_numpy=True``, ``n_samples=1`` and the given ``batch_size``,
        ``weights``, ``remove_cell_cycle`` and ``importance_weighting_kwargs``.
        When ``filter_outlier_cells`` is true, ``self.get_latent_representation``
        is passed to ``_de_core`` as the representation function. All other
        arguments and ``**kwargs`` are forwarded to ``_de_core`` unchanged.

        Returns
        -------
        pd.DataFrame
            The result of ``_de_core``.
        """
        adata = self._validate_anndata(adata)
        col_names = adata.var_names
        importance_weighting_kwargs = importance_weighting_kwargs or {}

        model_fn = partial(
            self.get_normalized_expression,
            return_numpy=True,
            n_samples=1,
            batch_size=batch_size,
            weights=weights,
            remove_cell_cycle=remove_cell_cycle,
            **importance_weighting_kwargs,
        )
        representation_fn = self.get_latent_representation if filter_outlier_cells else None

        result = _de_core(
            self.get_anndata_manager(adata, required=True),
            model_fn,
            representation_fn,
            groupby,
            group1,
            group2,
            idx1,
            idx2,
            all_stats,
            scrna_raw_counts_properties,
            col_names,
            mode,
            batchid1,
            batchid2,
            delta,
            batch_correction,
            fdr_target,
            silent,
            **kwargs,
        )

        return result