scdesigner 0.0.4__tar.gz → 0.0.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (70)
  1. {scdesigner-0.0.4 → scdesigner-0.0.5}/PKG-INFO +1 -2
  2. {scdesigner-0.0.4 → scdesigner-0.0.5}/pyproject.toml +1 -2
  3. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/loader.py +85 -40
  4. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/marginal.py +33 -24
  5. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/standard_copula.py +49 -49
  6. {scdesigner-0.0.4 → scdesigner-0.0.5}/.gitignore +0 -0
  7. {scdesigner-0.0.4 → scdesigner-0.0.5}/README.md +0 -0
  8. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/__init__.py +0 -0
  9. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/data/__init__.py +0 -0
  10. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/data/formula.py +0 -0
  11. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/data/group.py +0 -0
  12. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/data/sparse.py +0 -0
  13. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/diagnose/__init__.py +0 -0
  14. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/diagnose/aic_bic.py +0 -0
  15. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/diagnose/plot.py +0 -0
  16. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/__init__.py +0 -0
  17. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/bernoulli.py +0 -0
  18. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/gaussian.py +0 -0
  19. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/gaussian_copula_factory.py +0 -0
  20. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/glm_factory.py +0 -0
  21. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/negbin.py +0 -0
  22. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/pnmf.py +0 -0
  23. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/poisson.py +0 -0
  24. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/zero_inflated_negbin.py +0 -0
  25. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/estimators/zero_inflated_poisson.py +0 -0
  26. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/format/__init__.py +0 -0
  27. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/format/format.py +0 -0
  28. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/format/print.py +0 -0
  29. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/__init__.py +0 -0
  30. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/bernoulli.py +0 -0
  31. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/composite.py +0 -0
  32. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/copula.py +0 -0
  33. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/formula.py +0 -0
  34. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/gaussian.py +0 -0
  35. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/kwargs.py +0 -0
  36. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/negbin.py +0 -0
  37. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -0
  38. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/scd3.py +0 -0
  39. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/scd3_instances.py +0 -0
  40. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/simulator.py +0 -0
  41. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/transform.py +0 -0
  42. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/minimal/zero_inflated_negbin.py +0 -0
  43. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/__init__.py +0 -0
  44. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/bernoulli.py +0 -0
  45. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/gaussian.py +0 -0
  46. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/negbin.py +0 -0
  47. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/poisson.py +0 -0
  48. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/zero_inflated_negbin.py +0 -0
  49. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/predictors/zero_inflated_poisson.py +0 -0
  50. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/__init__.py +0 -0
  51. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/bernoulli.py +0 -0
  52. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/gaussian.py +0 -0
  53. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/glm_factory.py +0 -0
  54. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/negbin.py +0 -0
  55. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/poisson.py +0 -0
  56. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/zero_inflated_negbin.py +0 -0
  57. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/samplers/zero_inflated_poisson.py +0 -0
  58. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/simulators/__init__.py +0 -0
  59. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/simulators/composite_regressor.py +0 -0
  60. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/simulators/glm_simulator.py +0 -0
  61. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/simulators/pnmf_regression.py +0 -0
  62. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/__init__.py +0 -0
  63. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/amplify.py +0 -0
  64. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/mask.py +0 -0
  65. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/nullify.py +0 -0
  66. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/split.py +0 -0
  67. {scdesigner-0.0.4 → scdesigner-0.0.5}/src/scdesigner/transform/substitute.py +0 -0
  68. {scdesigner-0.0.4 → scdesigner-0.0.5}/tests/__init__.py +0 -0
  69. {scdesigner-0.0.4 → scdesigner-0.0.5}/tests/test_negative_binomial.py +0 -0
  70. {scdesigner-0.0.4 → scdesigner-0.0.5}/tests/test_simulator.py +0 -0
--- scdesigner-0.0.4/PKG-INFO
+++ scdesigner-0.0.5/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: scdesigner
-Version: 0.0.4
+Version: 0.0.5
 Summary: Interactive simulation for rigorous and transparent multi-omics analysis.
 Project-URL: Homepage, https://github.com/krisrs1128/scDesigner/
 Project-URL: Issues, https://github.com/krisrs1128/scDesigner/Issues/
@@ -11,7 +11,6 @@ Classifier: Programming Language :: Python :: 3
 Requires-Python: >=3.8
 Requires-Dist: anndata
 Requires-Dist: formulaic
-Requires-Dist: lightning
 Requires-Dist: numpy
 Requires-Dist: pandas
 Requires-Dist: rich
--- scdesigner-0.0.4/pyproject.toml
+++ scdesigner-0.0.5/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "scdesigner"
-version = "0.0.4"
+version = "0.0.5"
 authors = [
   { name="Kris Sankaran", email="ksankaran@wisc.edu" },
 ]
@@ -15,7 +15,6 @@ classifiers = [
 dependencies = [
     "anndata",
     "formulaic",
-    "lightning",
    "numpy",
    "pandas",
    "rich",
--- scdesigner-0.0.4/src/scdesigner/minimal/loader.py
+++ scdesigner-0.0.5/src/scdesigner/minimal/loader.py
@@ -5,21 +5,45 @@ from torch.utils.data import Dataset, DataLoader
 from typing import Dict
 import numpy as np
 import pandas as pd
+import scipy.sparse
 import torch
 
+def get_device():
+    """Detect and return the best available device (MPS, CUDA, or CPU)."""
+    if torch.backends.mps.is_available():
+        return torch.device("mps")
+    elif torch.cuda.is_available():
+        return torch.device("cuda")
+    else:
+        return torch.device("cpu")
+
+
+class PreloadedDataset(Dataset):
+    """Dataset that assumes x and y are both fully in memory."""
+    def __init__(self, y_tensor, x_tensors, predictor_names):
+        self.y = y_tensor
+        self.x = x_tensors
+        self.predictor_names = predictor_names
+
+    def __len__(self):
+        return len(self.y)
+
+    def __getitem__(self, idx):
+        return self.y[idx], {k: v[idx] for k, v in self.x.items()}
+
 class AnnDataDataset(Dataset):
     """Simple PyTorch Dataset for AnnData objects.
 
     Supports optional chunked loading for backed AnnData objects. When
     `chunk_size` is provided, the dataset will load contiguous slices
     of rows (of size `chunk_size`) into memory once and serve individual
-    rows from that cached chunk. This avoids calling `to_memory()` on
-    a per-row basis which is expensive for large backed files.
+    rows from that cached chunk. Chunks are moved to device for faster access.
     """
     def __init__(self, adata: AnnData, formula: Dict[str, str], chunk_size: int):
         self.adata = adata
         self.formula = formula
         self.chunk_size = chunk_size
+        self.device = get_device()
 
         # keeping track of covariate-related information
         self.obs_levels = categories(self.adata.obs)
@@ -28,6 +52,7 @@ class AnnDataDataset(Dataset):
 
         # Internal cache for the currently loaded chunk
         self._chunk: AnnData | None = None
+        self._chunk_X = None
         self._chunk_start = 0
 
     def __len__(self):
@@ -42,19 +67,12 @@ class AnnDataDataset(Dataset):
         """
         self._ensure_chunk_loaded(idx)
         local_idx = idx - self._chunk_start
-        adata_slice = self._chunk[local_idx]
-
-        # Get X data, accounting for potential sparse matrices
-        X = adata_slice.X
-        if hasattr(X, 'toarray'):
-            X = X.toarray()
 
-        # Get obs data
+        # Get obs data from GPU-cached matrices
         obs_dict = {}
         for key in self.formula.keys():
-            mat = self.obs_matrices.get(key)
-            obs_dict[key] = to_tensor(mat.values[local_idx: local_idx + 1])
-        return to_tensor(X), obs_dict
+            obs_dict[key] = self.obs_matrices[key][local_idx: local_idx + 1]
+        return self._chunk_X[local_idx], obs_dict
 
     def _ensure_chunk_loaded(self, idx: int) -> None:
         """Load the chunk that contains `idx` into the internal cache."""
@@ -69,36 +87,45 @@
         self._chunk = chunk
         self._chunk_start = start
 
-        # Compute model matrices for this chunk's `obs` so we don't need
-        # to keep the full obs data model matrices in memory.
+        # Move chunk to GPU
+        X = chunk.X
+        if hasattr(X, 'toarray'):
+            X = X.toarray()
+        self._chunk_X = torch.tensor(X, dtype=torch.float32).to(self.device)
+
+        # Compute model matrices for this chunk's `obs` and move to GPU
         obs_coded_chunk = code_levels(self._chunk.obs.copy(), self.obs_levels)
         self.obs_matrices = {}
+        predictor_names = {}
         for key, f in self.formula.items():
-            self.obs_matrices[key] = model_matrix(f, obs_coded_chunk)
+            mat = model_matrix(f, obs_coded_chunk)
+            predictor_names[key] = list(mat.columns)
+            self.obs_matrices[key] = torch.tensor(mat.values, dtype=torch.float32).to(self.device)
 
         # Capture predictor (column) names from the model matrices once.
         if self.predictor_names is None:
-            self.predictor_names = {k: list(v.columns) for k, v in self.obs_matrices.items()}
-
-
-def adata_loader(adata: AnnData,
-                 formula: Dict[str, str],
-                 chunk_size: int = None,
-                 batch_size: int = 1024,
-                 shuffle: bool = False,
-                 num_workers: int = 0,
-                 **kwargs) -> DataLoader:
-    """
-    Create a DataLoader from AnnData that returns batches of (X, obs).
-    """
+            self.predictor_names = predictor_names
+
+
+def adata_loader(
+    adata: AnnData,
+    formula: Dict[str, str],
+    chunk_size: int = None,
+    batch_size: int = 1024,
+    shuffle: bool = False,
+    num_workers: int = 0,
+    **kwargs
+) -> DataLoader:
+    """Create a DataLoader from AnnData that returns batches of (X, obs)."""
     data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
-    if chunk_size is None:
-        if getattr(adata, 'isbacked', False):
-            chunk_size = 5000
-        else:
-            chunk_size = len(adata)
+    device = get_device()
+
+    # separate chunked from non-chunked cases
+    if not getattr(adata, 'isbacked', False):
+        dataset = _preloaded_adata(adata, formula, device)
+    else:
+        dataset = AnnDataDataset(adata, formula, chunk_size or 5000)
 
-    dataset = AnnDataDataset(adata, formula, chunk_size)
     return DataLoader(
         dataset,
         batch_size=batch_size,
@@ -109,12 +136,30 @@ def adata_loader(adata: AnnData,
     )
 
 def obs_loader(obs: pd.DataFrame, marginal_formula, **kwargs):
-    adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
-    return adata_loader(
-        adata,
-        marginal_formula,
-        **kwargs
-    )
+    adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
+    return adata_loader(
+        adata,
+        marginal_formula,
+        **kwargs
+    )
+
+################################################################################
+## Extraction of in-memory AnnData to PreloadedDataset
+################################################################################
+
+def _preloaded_adata(adata: AnnData, formula: Dict[str, str], device: torch.device) -> PreloadedDataset:
+    X = adata.X
+    if scipy.sparse.issparse(X):
+        X = X.toarray()
+    y = torch.tensor(X, dtype=torch.float32).to(device)
+
+    obs = code_levels(adata.obs.copy(), categories(adata.obs))
+    x = {
+        k: torch.tensor(model_matrix(f, obs).values, dtype=torch.float32).to(device)
+        for k, f in formula.items()
+    }
+    predictor_names = {k: list(model_matrix(f, obs).columns) for k, f in formula.items()}
+    return PreloadedDataset(y, x, predictor_names)
 
 ################################################################################
 ## Helper functions
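
A minimal usage sketch of the new loading path (the toy AnnData and the "mu" formula key are illustrative assumptions, not package defaults): with an in-memory AnnData, adata_loader now builds a PreloadedDataset, so every batch arrives as float32 tensors already on the device reported by get_device().

# Hypothetical usage of the loader API changed above; the toy data and the
# "mu" formula key are assumptions for illustration.
import numpy as np
import pandas as pd
from anndata import AnnData
from scdesigner.minimal.loader import adata_loader, get_device

obs = pd.DataFrame({"group": ["a", "b"] * 50})
counts = np.random.poisson(2.0, size=(100, 5)).astype(np.float32)
adata = AnnData(X=counts, obs=obs)

# In-memory AnnData takes the PreloadedDataset path: the expression matrix and
# the model matrices are converted to tensors once and kept on the device.
loader = adata_loader(adata, {"mu": "~ group"}, batch_size=32)
y, x = next(iter(loader))
print(get_device(), y.shape, {k: v.shape for k, v in x.items()})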
--- scdesigner-0.0.4/src/scdesigner/minimal/marginal.py
+++ scdesigner-0.0.5/src/scdesigner/minimal/marginal.py
@@ -1,9 +1,8 @@
 from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
-from .loader import adata_loader
+from .loader import adata_loader, get_device
 from anndata import AnnData
 from typing import Union, Dict, Optional, Tuple
 import pandas as pd
-import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 from abc import ABC, abstractmethod
@@ -18,6 +17,7 @@ class Marginal(ABC):
         self.predict = None
         self.predictor_names = None
         self.parameters = None
+        self.device = get_device()
 
     def setup_data(self, adata: AnnData, batch_size: int = 1024, **kwargs):
         """Set up the dataloader for the AnnData object."""
@@ -29,13 +29,30 @@ class Marginal(ABC):
         self.feature_dims = {k: v.shape[1] for k, v in obs_batch.items()}
         self.predictor_names = self.loader.dataset.predictor_names
 
-    def fit(self, **kwargs):
-        """Fit the marginal predictor"""
+    def fit(self, max_epochs: int = 100, **kwargs):
+        """Fit the marginal predictor using vanilla PyTorch training loop."""
         if self.predict is None:
             self.setup_optimizer(**kwargs)
-        trainer_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['trainer'])
-        trainer = pl.Trainer(**trainer_kwargs)
-        trainer.fit(self.predict, train_dataloaders=self.loader)
+
+        for epoch in range(max_epochs):
+            epoch_loss, n_batches = 0.0, 0
+
+            for batch in self.loader:
+                y, x = batch
+                if y.device != self.device:
+                    y = y.to(self.device)
+                    x = {k: v.to(self.device) for k, v in x.items()}
+
+                self.predict.optimizer.zero_grad()
+                loss = self.predict.loss_fn((y, x))
+                loss.backward()
+                self.predict.optimizer.step()
+
+                epoch_loss += loss.item()
+                n_batches += 1
+
+            avg_loss = epoch_loss / n_batches
+            print(f"Epoch {epoch}/{max_epochs}, Loss: {avg_loss:.4f}", end='\r')
         self.parameters = self.format_parameters()
 
     def format_parameters(self):
@@ -85,7 +102,7 @@ class Marginal(ABC):
         raise NotImplementedError
 
 
-class GLMPredictor(pl.LightningModule):
+class GLMPredictor(nn.Module):
     """GLM-style predictor with arbitrary named parameters.
 
     Args:
@@ -111,21 +128,22 @@ class GLMPredictor(pl.LightningModule):
         self.feature_dims = dict(feature_dims)
         self.param_names = list(self.feature_dims.keys())
 
-        # create default link functions and parameter matrices
         self.link_fns = link_fns or {k: torch.exp for k in self.param_names}
         self.coefs = nn.ParameterDict()
         for key, dim in self.feature_dims.items():
             self.coefs[key] = nn.Parameter(torch.zeros(dim, self.n_outcomes))
-
-        # optimization parameters
         self.reset_parameters()
+
         self.loss_fn = loss_fn
-        self.optimizer_class = optimizer_class
-        self.optimizer_kwargs = optimizer_kwargs
+        self.to(get_device())
+
+        optimizer_kwargs = optimizer_kwargs or {}
+        filtered_kwargs = _filter_kwargs(optimizer_kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
+        self.optimizer = optimizer_class(self.parameters(), **filtered_kwargs)
 
     def reset_parameters(self):
         for p in self.coefs.values():
-            nn.init.normal_(p, mean=0.0, std=1e-2)
+            nn.init.normal_(p, mean=0.0, std=1e-4)
 
     def forward(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         out = {}
@@ -133,13 +151,4 @@ class GLMPredictor(pl.LightningModule):
             x_beta = obs_dict[name] @ self.coefs[name]
             link = self.link_fns.get(name, torch.exp)
             out[name] = link(x_beta)
-        return out
-
-    def training_step(self, batch):
-        loss = self.loss_fn(batch)
-        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
-        return loss
-
-    def configure_optimizers(self, **kwargs):
-        optimizer_kwargs = _filter_kwargs(self.optimizer_kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
-        return self.optimizer_class(self.parameters(), **optimizer_kwargs)
+        return out
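
The deleted training_step/configure_optimizers hooks are replaced by the explicit loop in Marginal.fit above: the module now owns its optimizer and the loop steps it directly. A self-contained sketch of that pattern (all names here are illustrative, not scdesigner's API):

# Standalone sketch of the diff's training pattern: an nn.Module that owns its
# optimizer, fit by a plain PyTorch loop instead of a Lightning Trainer.
import torch
import torch.nn as nn

class TinyGLM(nn.Module):
    def __init__(self, n_features, n_outcomes):
        super().__init__()
        self.coef = nn.Parameter(torch.zeros(n_features, n_outcomes))
        nn.init.normal_(self.coef, mean=0.0, std=1e-4)  # small init, as in the diff
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.1)

    def forward(self, x):
        return torch.exp(x @ self.coef)  # exp link, Poisson-style mean

x = torch.randn(256, 3)
true_coef = torch.tensor([[0.5], [-0.3], [0.2]])
y = torch.poisson(torch.exp(x @ true_coef))

model = TinyGLM(3, 1)
for epoch in range(100):
    model.optimizer.zero_grad()
    mu = model(x)
    loss = (mu - y * torch.log(mu + 1e-8)).mean()  # Poisson NLL up to a constant
    loss.backward()
    model.optimizer.step()
print(model.coef.detach().squeeze())  # should approach [0.5, -0.3, 0.2]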
--- scdesigner-0.0.4/src/scdesigner/minimal/standard_copula.py
+++ scdesigner-0.0.5/src/scdesigner/minimal/standard_copula.py
@@ -51,16 +51,16 @@ class StandardCopula(Copula):
     def fit(self, uniformizer: Callable, **kwargs):
         """
         Fit the copula covariance model.
-        
+
         Args:
             uniformizer (Callable): Function to convert data to uniform distribution
             **kwargs: Additional arguments
                 top_k (int, optional): Use only top-k most expressed genes for covariance estimation.
                     If None, estimates full covariance for all genes.
-        
+
         Returns:
             None: Stores fitted parameters in self.parameters as dict of CovarianceStructure objects.
-        
+
         Raises:
             ValueError: If top_k is not a positive integer or exceeds n_outcomes
         """
@@ -76,11 +76,11 @@ class StandardCopula(Copula):
             sorted_indices = np.argsort(gene_total_expression)
             top_k_indices = sorted_indices[-top_k:]
             remaining_indices = sorted_indices[:-top_k]
-            covariances = self._compute_block_covariance(uniformizer, top_k_indices, 
+            covariances = self._compute_block_covariance(uniformizer, top_k_indices,
                                                          remaining_indices, top_k)
         else:
             covariances = self._compute_full_covariance(uniformizer)
-        
+
         self.parameters = covariances
 
     def pseudo_obs(self, x_dict: Dict):
@@ -88,7 +88,7 @@ class StandardCopula(Copula):
         # {"group1": [indices of group 1], "group2": [indices of group 2]}
         # The initialization method ensures that x_dict will always have a "group" key.
         group_data = x_dict.get("group")
-        memberships = group_data.numpy()
+        memberships = group_data.cpu().numpy()
         group_ix = {g: np.where(memberships[:, self.group_col[g] == 1])[0] for g in self.groups}
 
         # initialize the result
@@ -106,14 +106,14 @@ class StandardCopula(Copula):
     def likelihood(self, uniformizer: Callable, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
         """
         Compute likelihood of data given the copula model.
-        
+
         Args:
             uniformizer (Callable): Function to convert expression data to uniform distribution
             batch (Tuple[torch.Tensor, Dict[str, torch.Tensor]]): Data batch containing:
                 - Y (torch.Tensor): Expression data of shape (n_cells, n_genes)
                 - X_dict (Dict[str, torch.Tensor]): Covariates dict with keys as parameter names
                   and values as tensors of shape (n_cells, n_covariates)
-        
+
         Returns:
             np.ndarray: Log-likelihood for each cell, shape (n_cells,)
         """
@@ -132,19 +132,19 @@ class StandardCopula(Copula):
         group_ix = {g: np.where(memberships[:, self.group_col[g] == 1])[0] for g in self.groups}
 
         ll = np.zeros(len(z))
-        
+
         for group, cov_struct in parameters.items():
             ix = group_ix[group]
             if len(ix) > 0:
                 z_modeled = z[ix][:, cov_struct.modeled_indices]
-                
+
                 ll_modeled = multivariate_normal.logpdf(z_modeled,
-                                                        np.zeros(cov_struct.num_modeled_genes), 
+                                                        np.zeros(cov_struct.num_modeled_genes),
                                                         cov_struct.cov.values)
                 if cov_struct.num_remaining_genes > 0:
                     z_remaining = z[ix][:, cov_struct.remaining_indices]
                     ll_remaining = norm.logpdf(z_remaining,
-                                               loc=0, 
+                                               loc=0,
                                                scale = np.sqrt(cov_struct.remaining_var.values))
                 else:
                     ll_remaining = 0
@@ -155,7 +155,7 @@ class StandardCopula(Copula):
         S = self.parameters
         per_group = [((S[g].num_modeled_genes * (S[g].num_modeled_genes - 1)) / 2) for g in self.groups]
         return sum(per_group)
-    
+
     def _validate_parameters(self, **kwargs):
         top_k = kwargs.get("top_k", None)
         if top_k is not None:
@@ -166,14 +166,14 @@ class StandardCopula(Copula):
             if top_k > self.n_outcomes:
                 raise ValueError(f"top_k ({top_k}) cannot exceed number of outcomes ({self.n_outcomes})")
         return top_k
-    
-    
+
+
 
     def _accumulate_top_k_stats(self, uniformizer:Callable, top_k_idx, rem_idx, top_k) \
-        -> Tuple[Dict[Union[str, int], np.ndarray], 
-                 Dict[Union[str, int], np.ndarray], 
-                 Dict[Union[str, int], np.ndarray], 
-                 Dict[Union[str, int], np.ndarray], 
+        -> Tuple[Dict[Union[str, int], np.ndarray],
+                 Dict[Union[str, int], np.ndarray],
+                 Dict[Union[str, int], np.ndarray],
+                 Dict[Union[str, int], np.ndarray],
                  Dict[Union[str, int], int]]:
         """Accumulate sufficient statistics for top-k covariance estimation.
 
@@ -198,7 +198,7 @@ class StandardCopula(Copula):
 
         for y, x_dict in tqdm(self.loader, desc="Estimating top-k copula covariance"):
             group_data = x_dict.get("group")
-            memberships = group_data.numpy()
+            memberships = group_data.cpu().numpy()
             u = uniformizer(y, x_dict)
             z = norm.ppf(u)
 
@@ -211,20 +211,20 @@ class StandardCopula(Copula):
                 n_g = mask.sum()
 
                 top_k_z, rem_z = z_g[:, top_k_idx], z_g[:, rem_idx]
-                
+
                 top_k_sums[g] += top_k_z.sum(axis=0)
                 top_k_second_moments[g] += top_k_z.T @ top_k_z
-                
+
                 rem_sums[g] += rem_z.sum(axis=0)
                 rem_second_moments[g] += (rem_z ** 2).sum(axis=0)
-                
+
                 Ng[g] += n_g
 
         return top_k_sums, top_k_second_moments, rem_sums, rem_second_moments, Ng
-    
+
     def _accumulate_full_stats(self, uniformizer:Callable) \
-        -> Tuple[Dict[Union[str, int], np.ndarray], 
-                 Dict[Union[str, int], np.ndarray], 
+        -> Tuple[Dict[Union[str, int], np.ndarray],
+                 Dict[Union[str, int], np.ndarray],
                  Dict[Union[str, int], int]]:
         """Accumulate sufficient statistics for full covariance estimation.
 
@@ -242,14 +242,14 @@ class StandardCopula(Copula):
 
         for y, x_dict in tqdm(self.loader, desc="Estimating copula covariance"):
             group_data = x_dict.get("group")
-            memberships = group_data.numpy()
-            
+            memberships = group_data.cpu().numpy()
+
             u = uniformizer(y, x_dict)
             z = norm.ppf(u)
 
             for g in self.groups:
                 mask = memberships[:, self.group_col[g]] == 1
-                
+
                 if not np.any(mask):
                     continue
 
@@ -258,12 +258,12 @@ class StandardCopula(Copula):
 
                 second_moments[g] += z_g.T @ z_g
                 sums[g] += z_g.sum(axis=0)
-                
+
                 Ng[g] += n_g
 
         return sums, second_moments, Ng
-    
-    def _compute_block_covariance(self, uniformizer:Callable, 
+
+    def _compute_block_covariance(self, uniformizer:Callable,
                                   top_k_idx: np.ndarray, rem_idx: np.ndarray, top_k: int) \
         -> Dict[Union[str, int], CovarianceStructure]:
         """Compute the covariance matrix for the top-k and remaining genes.
@@ -300,7 +300,7 @@ class StandardCopula(Copula):
             remaining_names=remaining_names
         )
         return covariance
-    
+
     def _compute_full_covariance(self, uniformizer:Callable) -> Dict[Union[str, int], CovarianceStructure]:
         """Compute the covariance matrix for the full genes.
 
@@ -327,7 +327,7 @@ class StandardCopula(Copula):
             remaining_names=None
        )
        return covariance
-    
+
     def _fast_normal_pseudo_obs(self, n_samples: int, cov_struct: CovarianceStructure) -> np.ndarray:
         """Sample pseudo-observations from the covariance structure.
 
@@ -339,28 +339,28 @@ class StandardCopula(Copula):
             np.ndarray: Pseudo-observations with shape (n_samples, total_genes)
         """
         u = np.zeros((n_samples, cov_struct.total_genes))
-        
+
         z_modeled = np.random.multivariate_normal(
-            mean=np.zeros(cov_struct.num_modeled_genes), 
-            cov=cov_struct.cov.values, 
+            mean=np.zeros(cov_struct.num_modeled_genes),
+            cov=cov_struct.cov.values,
             size=n_samples
         )
-        
+
         z_remaining = np.random.normal(
-            loc=0, 
-            scale=cov_struct.remaining_var.values ** 0.5, 
+            loc=0,
+            scale=cov_struct.remaining_var.values ** 0.5,
             size=(n_samples, cov_struct.num_remaining_genes)
         )
-        
+
         normal_distn_modeled = norm(0, np.diag(cov_struct.cov.values) ** 0.5)
         u[:, cov_struct.modeled_indices] = normal_distn_modeled.cdf(z_modeled)
-        
+
         normal_distn_remaining = norm(0, cov_struct.remaining_var.values ** 0.5)
         u[:, cov_struct.remaining_indices] = normal_distn_remaining.cdf(z_remaining)
-        
+
         return u
-    
-    def _normal_pseudo_obs(self, n_samples: int, cov_struct: CovarianceStructure) -> np.ndarray: 
+
+    def _normal_pseudo_obs(self, n_samples: int, cov_struct: CovarianceStructure) -> np.ndarray:
         """Sample pseudo-observations from the covariance structure.
 
         Args:
@@ -372,12 +372,12 @@ class StandardCopula(Copula):
         """
         u = np.zeros((n_samples, cov_struct.total_genes))
         z = np.random.multivariate_normal(
-            mean=np.zeros(cov_struct.total_genes), 
-            cov=cov_struct.cov.values, 
+            mean=np.zeros(cov_struct.total_genes),
+            cov=cov_struct.cov.values,
             size=n_samples
        )
-        
+
         normal_distn = norm(0, np.diag(cov_struct.cov.values) ** 0.5)
         u = normal_distn.cdf(z)
-        
+
         return u
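
Aside from trailing-whitespace cleanup, the substantive change in this file is the repeated switch from group_data.numpy() to group_data.cpu().numpy(): now that the loaders can return CUDA or MPS tensors, they must be copied back to host memory before NumPy conversion. A minimal illustration (device chosen at runtime; uses only stock PyTorch):

import torch

# .numpy() works only on CPU tensors; on CUDA/MPS it raises a TypeError.
device = "cuda" if torch.cuda.is_available() else "cpu"
t = torch.ones(3, device=device)
arr = t.cpu().numpy()  # .cpu() is a no-op on CPU tensors, a host copy otherwise
print(arr)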