scdesigner 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,21 +5,45 @@ from torch.utils.data import Dataset, DataLoader
 from typing import Dict
 import numpy as np
 import pandas as pd
+import scipy.sparse
 import torch
 
+def get_device():
+    """Detect and return the best available device (MPS, CUDA, or CPU)."""
+    if torch.backends.mps.is_available():
+        return torch.device("mps")
+    elif torch.cuda.is_available():
+        return torch.device("cuda")
+    else:
+        return torch.device("cpu")
+
+
+class PreloadedDataset(Dataset):
+    """Dataset that assumes x and y are both fully in memory."""
+    def __init__(self, y_tensor, x_tensors, predictor_names):
+        self.y = y_tensor
+        self.x = x_tensors
+        self.predictor_names = predictor_names
+
+    def __len__(self):
+        return len(self.y)
+
+    def __getitem__(self, idx):
+        return self.y[idx], {k: v[idx] for k, v in self.x.items()}
+
 class AnnDataDataset(Dataset):
     """Simple PyTorch Dataset for AnnData objects.
 
     Supports optional chunked loading for backed AnnData objects. When
     `chunk_size` is provided, the dataset will load contiguous slices
     of rows (of size `chunk_size`) into memory once and serve individual
-    rows from that cached chunk. This avoids calling `to_memory()` on
-    a per-row basis which is expensive for large backed files.
+    rows from that cached chunk. Chunks are moved to device for faster access.
     """
     def __init__(self, adata: AnnData, formula: Dict[str, str], chunk_size: int):
         self.adata = adata
         self.formula = formula
         self.chunk_size = chunk_size
+        self.device = get_device()
 
         # keeping track of covariate-related information
         self.obs_levels = categories(self.adata.obs)
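For orientation, here is a minimal, self-contained sketch of how the new device detection and `PreloadedDataset` compose; the shapes and predictor names are made up for illustration:

    import torch
    from torch.utils.data import Dataset

    class PreloadedDataset(Dataset):
        """Trimmed copy of the class added above."""
        def __init__(self, y_tensor, x_tensors, predictor_names):
            self.y = y_tensor
            self.x = x_tensors
            self.predictor_names = predictor_names

        def __len__(self):
            return len(self.y)

        def __getitem__(self, idx):
            return self.y[idx], {k: v[idx] for k, v in self.x.items()}

    # Same fallback order as get_device(): MPS, then CUDA, then CPU.
    device = (torch.device("mps") if torch.backends.mps.is_available()
              else torch.device("cuda") if torch.cuda.is_available()
              else torch.device("cpu"))

    # Hypothetical shapes: 100 cells x 5 genes, one 3-column design matrix.
    y = torch.rand(100, 5, device=device)
    x = {"mean": torch.rand(100, 3, device=device)}
    ds = PreloadedDataset(y, x, {"mean": ["Intercept", "a", "b"]})
    y0, x0 = ds[0]
    print(y0.shape, x0["mean"].shape)  # torch.Size([5]) torch.Size([3])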
@@ -28,6 +52,7 @@ class AnnDataDataset(Dataset):
 
         # Internal cache for the currently loaded chunk
         self._chunk: AnnData | None = None
+        self._chunk_X = None
         self._chunk_start = 0
 
     def __len__(self):
@@ -42,19 +67,12 @@ class AnnDataDataset(Dataset):
         """
         self._ensure_chunk_loaded(idx)
         local_idx = idx - self._chunk_start
-        adata_slice = self._chunk[local_idx]
-
-        # Get X data, accounting for potential sparse matrices
-        X = adata_slice.X
-        if hasattr(X, 'toarray'):
-            X = X.toarray()
 
-        # Get obs data
+        # Get obs data from GPU-cached matrices
         obs_dict = {}
         for key in self.formula.keys():
-            mat = self.obs_matrices.get(key)
-            obs_dict[key] = to_tensor(mat.values[local_idx: local_idx + 1])
-        return to_tensor(X), obs_dict
+            obs_dict[key] = self.obs_matrices[key][local_idx: local_idx + 1]
+        return self._chunk_X[local_idx], obs_dict
 
     def _ensure_chunk_loaded(self, idx: int) -> None:
         """Load the chunk that contains `idx` into the internal cache."""
@@ -69,36 +87,45 @@ class AnnDataDataset(Dataset):
         self._chunk = chunk
         self._chunk_start = start
 
-        # Compute model matrices for this chunk's `obs` so we don't need
-        # to keep the full obs data model matrices in memory.
+        # Move chunk to GPU
+        X = chunk.X
+        if hasattr(X, 'toarray'):
+            X = X.toarray()
+        self._chunk_X = torch.tensor(X, dtype=torch.float32).to(self.device)
+
+        # Compute model matrices for this chunk's `obs` and move to GPU
         obs_coded_chunk = code_levels(self._chunk.obs.copy(), self.obs_levels)
         self.obs_matrices = {}
+        predictor_names = {}
         for key, f in self.formula.items():
-            self.obs_matrices[key] = model_matrix(f, obs_coded_chunk)
+            mat = model_matrix(f, obs_coded_chunk)
+            predictor_names[key] = list(mat.columns)
+            self.obs_matrices[key] = torch.tensor(mat.values, dtype=torch.float32).to(self.device)
 
         # Capture predictor (column) names from the model matrices once.
         if self.predictor_names is None:
-            self.predictor_names = {k: list(v.columns) for k, v in self.obs_matrices.items()}
-
-
-def adata_loader(adata: AnnData,
-                 formula: Dict[str, str],
-                 chunk_size: int = None,
-                 batch_size: int = 1024,
-                 shuffle: bool = False,
-                 num_workers: int = 0,
-                 **kwargs) -> DataLoader:
-    """
-    Create a DataLoader from AnnData that returns batches of (X, obs).
-    """
+            self.predictor_names = predictor_names
+
+
+def adata_loader(
+    adata: AnnData,
+    formula: Dict[str, str],
+    chunk_size: int = None,
+    batch_size: int = 1024,
+    shuffle: bool = False,
+    num_workers: int = 0,
+    **kwargs
+) -> DataLoader:
+    """Create a DataLoader from AnnData that returns batches of (X, obs)."""
     data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
-    if chunk_size is None:
-        if getattr(adata, 'isbacked', False):
-            chunk_size = 5000
-        else:
-            chunk_size = len(adata)
+    device = get_device()
+
+    # separate chunked from non-chunked cases
+    if not getattr(adata, 'isbacked', False):
+        dataset = _preloaded_adata(adata, formula, device)
+    else:
+        dataset = AnnDataDataset(adata, formula, chunk_size or 5000)
 
-    dataset = AnnDataDataset(adata, formula, chunk_size)
     return DataLoader(
         dataset,
         batch_size=batch_size,
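The chunk cache above only reloads when an index leaves the current slice. The following toy version illustrates that access pattern; the chunk-selection arithmetic and the random stand-in for backed I/O are assumptions, since that part of `_ensure_chunk_loaded` falls outside this hunk:

    import torch

    class ChunkCache:
        """Toy analogue of AnnDataDataset's chunk caching (illustrative only)."""
        def __init__(self, n_rows, chunk_size):
            self.n_rows, self.chunk_size = n_rows, chunk_size
            self._chunk_X, self._chunk_start = None, 0

        def _ensure_chunk_loaded(self, idx):
            start = (idx // self.chunk_size) * self.chunk_size
            if self._chunk_X is None or start != self._chunk_start:
                stop = min(start + self.chunk_size, self.n_rows)
                # Stand-in for reading a backed slice and moving it to the device.
                self._chunk_X = torch.rand(stop - start, 5)
                self._chunk_start = start

        def __getitem__(self, idx):
            self._ensure_chunk_loaded(idx)
            return self._chunk_X[idx - self._chunk_start]

    cache = ChunkCache(n_rows=12, chunk_size=5)
    print(cache[0].shape, cache[7].shape)  # second access loads a new chunk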
@@ -109,12 +136,30 @@ def adata_loader(adata: AnnData,
     )
 
 def obs_loader(obs: pd.DataFrame, marginal_formula, **kwargs):
-    adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
-    return adata_loader(
-        adata,
-        marginal_formula,
-        **kwargs
-    )
+    adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
+    return adata_loader(
+        adata,
+        marginal_formula,
+        **kwargs
+    )
+
+################################################################################
+## Extraction of in-memory AnnData to PreloadedDataset
+################################################################################
+
+def _preloaded_adata(adata: AnnData, formula: Dict[str, str], device: torch.device) -> PreloadedDataset:
+    X = adata.X
+    if scipy.sparse.issparse(X):
+        X = X.toarray()
+    y = torch.tensor(X, dtype=torch.float32).to(device)
+
+    obs = code_levels(adata.obs.copy(), categories(adata.obs))
+    x = {
+        k: torch.tensor(model_matrix(f, obs).values, dtype=torch.float32).to(device)
+        for k, f in formula.items()
+    }
+    predictor_names = {k: list(model_matrix(f, obs).columns) for k, f in formula.items()}
+    return PreloadedDataset(y, x, predictor_names)
 
 ################################################################################
 ## Helper functions
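A hedged usage sketch of the new dispatch: an in-memory AnnData goes through `_preloaded_adata`, a backed one through `AnnDataDataset`. The import path and the patsy-style formula string are assumptions, not confirmed by the diff:

    import numpy as np
    import pandas as pd
    from anndata import AnnData
    from scdesigner.data import adata_loader  # hypothetical import path

    obs = pd.DataFrame({"condition": pd.Categorical(["a", "b"] * 50)})
    adata = AnnData(X=np.random.poisson(2.0, (100, 5)).astype(np.float32), obs=obs)

    # In-memory AnnData -> PreloadedDataset with tensors already on the device;
    # a backed AnnData would instead use AnnDataDataset(adata, formula, chunk_size or 5000).
    loader = adata_loader(adata, formula={"mean": "~ condition"}, batch_size=32)
    y, x = next(iter(loader))
    print(y.shape, x["mean"].shape)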
@@ -1,14 +1,14 @@
 from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
-from .loader import adata_loader
+from .loader import adata_loader, get_device
 from anndata import AnnData
 from typing import Union, Dict, Optional, Tuple
 import pandas as pd
-import pytorch_lightning as pl
 import torch
 import torch.nn as nn
+from abc import ABC, abstractmethod
 
 
-class Marginal:
+class Marginal(ABC):
     def __init__(self, formula: Union[Dict, str]):
         self.formula = formula
         self.feature_dims = None
@@ -17,6 +17,7 @@ class Marginal:
         self.predict = None
         self.predictor_names = None
         self.parameters = None
+        self.device = get_device()
 
     def setup_data(self, adata: AnnData, batch_size: int = 1024, **kwargs):
         """Set up the dataloader for the AnnData object."""
@@ -28,31 +29,31 @@ class Marginal:
         self.feature_dims = {k: v.shape[1] for k, v in obs_batch.items()}
         self.predictor_names = self.loader.dataset.predictor_names
 
-    def fit(self, **kwargs):
-        """Fit the marginal predictor"""
+    def fit(self, max_epochs: int = 100, **kwargs):
+        """Fit the marginal predictor using a vanilla PyTorch training loop."""
         if self.predict is None:
             self.setup_optimizer(**kwargs)
-        trainer_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['trainer'])
-        trainer = pl.Trainer(**trainer_kwargs)
-        trainer.fit(self.predict, train_dataloaders=self.loader)
-        self.parameters = self.format_parameters()
 
-    def setup_optimizer(self, **kwargs):
-        raise NotImplementedError
+        for epoch in range(max_epochs):
+            epoch_loss, n_batches = 0.0, 0
 
-    def likelihood(self, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
-        """Compute the (negative) log-likelihood or loss for a batch.
-        """
-        raise NotImplementedError
+            for batch in self.loader:
+                y, x = batch
+                if y.device != self.device:
+                    y = y.to(self.device)
+                    x = {k: v.to(self.device) for k, v in x.items()}
 
-    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
-        """Invert pseudoobservations."""
-        raise NotImplementedError
+                self.predict.optimizer.zero_grad()
+                loss = self.predict.loss_fn((y, x))
+                loss.backward()
+                self.predict.optimizer.step()
 
-    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]):
-        """Uniformize using learned CDF.
-        """
-        raise NotImplementedError
+                epoch_loss += loss.item()
+                n_batches += 1
+
+            avg_loss = epoch_loss / n_batches
+            print(f"Epoch {epoch}/{max_epochs}, Loss: {avg_loss:.4f}", end='\r')
+        self.parameters = self.format_parameters()
 
     def format_parameters(self):
         """Convert fitted coefficient tensors into pandas DataFrames.
@@ -79,8 +80,29 @@ class Marginal:
             return 0
         return sum(p.numel() for p in self.predict.parameters() if p.requires_grad)
 
+    @abstractmethod
+    def setup_optimizer(self, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def likelihood(self, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
+        """Compute the (negative) log-likelihood or loss for a batch.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Invert pseudoobservations."""
+        raise NotImplementedError
 
-class GLMPredictor(pl.LightningModule):
+    @abstractmethod
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Uniformize using learned CDF.
+        """
+        raise NotImplementedError
+
+
+class GLMPredictor(nn.Module):
     """GLM-style predictor with arbitrary named parameters.
 
     Args:
@@ -106,21 +128,22 @@ class GLMPredictor(pl.LightningModule):
         self.feature_dims = dict(feature_dims)
         self.param_names = list(self.feature_dims.keys())
 
-        # create default link functions and parameter matrices
         self.link_fns = link_fns or {k: torch.exp for k in self.param_names}
         self.coefs = nn.ParameterDict()
         for key, dim in self.feature_dims.items():
             self.coefs[key] = nn.Parameter(torch.zeros(dim, self.n_outcomes))
-
-        # optimization parameters
         self.reset_parameters()
+
         self.loss_fn = loss_fn
-        self.optimizer_class = optimizer_class
-        self.optimizer_kwargs = optimizer_kwargs
+        self.to(get_device())
+
+        optimizer_kwargs = optimizer_kwargs or {}
+        filtered_kwargs = _filter_kwargs(optimizer_kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
+        self.optimizer = optimizer_class(self.parameters(), **filtered_kwargs)
 
     def reset_parameters(self):
         for p in self.coefs.values():
-            nn.init.normal_(p, mean=0.0, std=1e-2)
+            nn.init.normal_(p, mean=0.0, std=1e-4)
 
     def forward(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         out = {}
@@ -128,13 +151,4 @@ class GLMPredictor(pl.LightningModule):
             x_beta = obs_dict[name] @ self.coefs[name]
             link = self.link_fns.get(name, torch.exp)
             out[name] = link(x_beta)
-        return out
-
-    def training_step(self, batch):
-        loss = self.loss_fn(batch)
-        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
-        return loss
-
-    def configure_optimizers(self, **kwargs):
-        optimizer_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
-        return self.optimizer_class(self.parameters(), **optimizer_kwargs)
+        return out
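Without the Lightning scaffolding, `GLMPredictor` reduces to one coefficient matrix per named parameter, pushed through a link function. A condensed sketch of that forward pattern, with hypothetical dimensions:

    import torch
    import torch.nn as nn

    feature_dims = {"mean": 3, "dispersion": 1}  # design-matrix widths (made up)
    n_outcomes = 5                               # e.g. number of genes

    coefs = nn.ParameterDict({
        k: nn.Parameter(torch.randn(d, n_outcomes) * 1e-4)  # mirrors the new std=1e-4 init
        for k, d in feature_dims.items()
    })
    obs = {"mean": torch.rand(32, 3), "dispersion": torch.ones(32, 1)}

    # forward(): x @ beta, then an exp link to keep parameters positive
    out = {k: torch.exp(obs[k] @ coefs[k]) for k in coefs}
    print(out["mean"].shape)  # torch.Size([32, 5])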
@@ -30,7 +30,7 @@ class NegBin(Marginal):
         )
 
     def likelihood(self, batch):
-        """Compute the negative log-likelihood"""
+        """Compute the log-likelihood"""
         y, x = batch
         params = self.predict(x)
         mu = params.get('mean')
@@ -6,6 +6,7 @@ from anndata import AnnData
 from tqdm import tqdm
 import torch
 import numpy as np
+from abc import ABC, abstractmethod
 
 class SCD3Simulator(Simulator):
     """Simulation wrapper"""
@@ -2,7 +2,7 @@ from .scd3 import SCD3Simulator
 from .negbin import NegBin
 from .zero_inflated_negbin import ZeroInflatedNegBin
 from .gaussian import Gaussian
-from .standard_covariance import StandardCovariance
+from .standard_copula import StandardCopula
 from typing import Optional
 
 
@@ -12,7 +12,7 @@ class NegBinCopula(SCD3Simulator):
                  dispersion_formula: Optional[str] = None,
                  copula_formula: Optional[str] = None) -> None:
         marginal = NegBin({"mean": mean_formula, "dispersion": dispersion_formula})
-        covariance = StandardCovariance(copula_formula)
+        covariance = StandardCopula(copula_formula)
         super().__init__(marginal, covariance)
 
 
@@ -27,7 +27,7 @@ class ZeroInflatedNegBinCopula(SCD3Simulator):
             "dispersion": dispersion_formula,
             "zero_inflation_formula": zero_inflation_formula
         })
-        covariance = StandardCovariance(copula_formula)
+        covariance = StandardCopula(copula_formula)
         super().__init__(marginal, covariance)
 
 
@@ -36,7 +36,7 @@ class BernoulliCopula(SCD3Simulator):
                  mean_formula: Optional[str] = None,
                  copula_formula: Optional[str] = None) -> None:
         marginal = NegBin({"mean": mean_formula})
-        covariance = StandardCovariance(copula_formula)
+        covariance = StandardCopula(copula_formula)
         super().__init__(marginal, covariance)
 
 
@@ -46,5 +46,5 @@ class GaussianCopula(SCD3Simulator):
                  sdev_formula: Optional[str] = None,
                  copula_formula: Optional[str] = None) -> None:
         marginal = Gaussian({"mean": mean_formula, "sdev": sdev_formula})
-        covariance = StandardCovariance(copula_formula)
+        covariance = StandardCopula(copula_formula)
         super().__init__(marginal, covariance)
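All four simulator constructors swap `StandardCovariance` for the renamed `StandardCopula`; their public signatures are unchanged. A hedged construction example, where the import path and the patsy-style formula strings are assumptions:

    from scdesigner.simulators import NegBinCopula  # hypothetical import path

    # Constructor arguments follow the signatures shown in the diff.
    sim = NegBinCopula(
        mean_formula="~ cell_type",
        dispersion_formula="~ 1",
        copula_formula="~ 1",
    )
    # Internally this now pairs a NegBin marginal with StandardCopula.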