scdesigner-0.0.3-py3-none-any.whl → scdesigner-0.0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdesigner/estimators/__init__.py +8 -3
- scdesigner/estimators/gaussian_copula_factory.py +222 -7
- scdesigner/estimators/negbin.py +24 -0
- scdesigner/estimators/poisson.py +24 -0
- scdesigner/minimal/composite.py +2 -2
- scdesigner/minimal/copula.py +178 -6
- scdesigner/minimal/loader.py +85 -40
- scdesigner/minimal/marginal.py +53 -39
- scdesigner/minimal/negbin.py +1 -1
- scdesigner/minimal/scd3.py +1 -0
- scdesigner/minimal/scd3_instances.py +5 -5
- scdesigner/minimal/standard_copula.py +383 -0
- scdesigner/minimal/transform.py +27 -30
- scdesigner/samplers/glm_factory.py +66 -4
- scdesigner/transform/nullify.py +1 -1
- {scdesigner-0.0.3.dist-info → scdesigner-0.0.5.dist-info}/METADATA +1 -2
- {scdesigner-0.0.3.dist-info → scdesigner-0.0.5.dist-info}/RECORD +18 -18
- scdesigner/minimal/standard_covariance.py +0 -124
- {scdesigner-0.0.3.dist-info → scdesigner-0.0.5.dist-info}/WHEEL +0 -0
scdesigner/minimal/loader.py
CHANGED

```diff
@@ -5,21 +5,45 @@ from torch.utils.data import Dataset, DataLoader
 from typing import Dict
 import numpy as np
 import pandas as pd
+import scipy.sparse
 import torch
 
+def get_device():
+    """Detect and return the best available device (MPS, CUDA, or CPU)."""
+    if torch.backends.mps.is_available():
+        return torch.device("mps")
+    elif torch.cuda.is_available():
+        return torch.device("cuda")
+    else:
+        return torch.device("cpu")
+
+
+class PreloadedDataset(Dataset):
+    """Dataset that assumes x and y are both fully in memory."""
+    def __init__(self, y_tensor, x_tensors, predictor_names):
+        self.y = y_tensor
+        self.x = x_tensors
+        self.predictor_names = predictor_names
+
+    def __len__(self):
+        return len(self.y)
+
+    def __getitem__(self, idx):
+        return self.y[idx], {k: v[idx] for k, v in self.x.items()}
+
 class AnnDataDataset(Dataset):
     """Simple PyTorch Dataset for AnnData objects.
 
     Supports optional chunked loading for backed AnnData objects. When
     `chunk_size` is provided, the dataset will load contiguous slices
     of rows (of size `chunk_size`) into memory once and serve individual
-    rows from that cached chunk.
-    a per-row basis which is expensive for large backed files.
+    rows from that cached chunk. Chunks are moved to device for faster access.
     """
     def __init__(self, adata: AnnData, formula: Dict[str, str], chunk_size: int):
         self.adata = adata
         self.formula = formula
         self.chunk_size = chunk_size
+        self.device = get_device()
 
         # keeping track of covariate-related information
         self.obs_levels = categories(self.adata.obs)
@@ -28,6 +52,7 @@ class AnnDataDataset(Dataset):
 
         # Internal cache for the currently loaded chunk
         self._chunk: AnnData | None = None
+        self._chunk_X = None
         self._chunk_start = 0
 
     def __len__(self):
@@ -42,19 +67,12 @@ class AnnDataDataset(Dataset):
         """
         self._ensure_chunk_loaded(idx)
         local_idx = idx - self._chunk_start
-        adata_slice = self._chunk[local_idx]
-
-        # Get X data, accounting for potential sparse matrices
-        X = adata_slice.X
-        if hasattr(X, 'toarray'):
-            X = X.toarray()
 
-        # Get obs data
+        # Get obs data from GPU-cached matrices
         obs_dict = {}
         for key in self.formula.keys():
-
-
-        return to_tensor(X), obs_dict
+            obs_dict[key] = self.obs_matrices[key][local_idx: local_idx + 1]
+        return self._chunk_X[local_idx], obs_dict
 
     def _ensure_chunk_loaded(self, idx: int) -> None:
         """Load the chunk that contains `idx` into the internal cache."""
@@ -69,36 +87,45 @@ class AnnDataDataset(Dataset):
         self._chunk = chunk
         self._chunk_start = start
 
-        #
-
+        # Move chunk to GPU
+        X = chunk.X
+        if hasattr(X, 'toarray'):
+            X = X.toarray()
+        self._chunk_X = torch.tensor(X, dtype=torch.float32).to(self.device)
+
+        # Compute model matrices for this chunk's `obs` and move to GPU
         obs_coded_chunk = code_levels(self._chunk.obs.copy(), self.obs_levels)
         self.obs_matrices = {}
+        predictor_names = {}
         for key, f in self.formula.items():
-
+            mat = model_matrix(f, obs_coded_chunk)
+            predictor_names[key] = list(mat.columns)
+            self.obs_matrices[key] = torch.tensor(mat.values, dtype=torch.float32).to(self.device)
 
         # Capture predictor (column) names from the model matrices once.
         if self.predictor_names is None:
-            self.predictor_names =
-
-
-def adata_loader(
-
-
-
-
-
-
-
-
-    """
+            self.predictor_names = predictor_names
+
+
+def adata_loader(
+    adata: AnnData,
+    formula: Dict[str, str],
+    chunk_size: int = None,
+    batch_size: int = 1024,
+    shuffle: bool = False,
+    num_workers: int = 0,
+    **kwargs
+) -> DataLoader:
+    """Create a DataLoader from AnnData that returns batches of (X, obs)."""
     data_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['data'])
-
-
-
-
-
+    device = get_device()
+
+    # separate chunked from non-chunked cases
+    if not getattr(adata, 'isbacked', False):
+        dataset = _preloaded_adata(adata, formula, device)
+    else:
+        dataset = AnnDataDataset(adata, formula, chunk_size or 5000)
 
-    dataset = AnnDataDataset(adata, formula, chunk_size)
     return DataLoader(
         dataset,
         batch_size=batch_size,
@@ -109,12 +136,30 @@ def adata_loader(adata: AnnData,
     )
 
 def obs_loader(obs: pd.DataFrame, marginal_formula, **kwargs):
-
-
-
-
-
-
+    adata = AnnData(X=np.zeros((len(obs), 1)), obs=obs)
+    return adata_loader(
+        adata,
+        marginal_formula,
+        **kwargs
+    )
+
+################################################################################
+## Extraction of in-memory AnnData to PreloadedDataset
+################################################################################
+
+def _preloaded_adata(adata: AnnData, formula: Dict[str, str], device: torch.device) -> PreloadedDataset:
+    X = adata.X
+    if scipy.sparse.issparse(X):
+        X = X.toarray()
+    y = torch.tensor(X, dtype=torch.float32).to(device)
+
+    obs = code_levels(adata.obs.copy(), categories(adata.obs))
+    x = {
+        k: torch.tensor(model_matrix(f, obs).values, dtype=torch.float32).to(device)
+        for k, f in formula.items()
+    }
+    predictor_names = {k: list(model_matrix(f, obs).columns) for k, f in formula.items()}
+    return PreloadedDataset(y, x, predictor_names)
 
 ################################################################################
 ## Helper functions
```
scdesigner/minimal/marginal.py
CHANGED

```diff
@@ -1,14 +1,14 @@
 from .kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
-from .loader import adata_loader
+from .loader import adata_loader, get_device
 from anndata import AnnData
 from typing import Union, Dict, Optional, Tuple
 import pandas as pd
-import pytorch_lightning as pl
 import torch
 import torch.nn as nn
+from abc import ABC, abstractmethod
 
 
-class Marginal:
+class Marginal(ABC):
     def __init__(self, formula: Union[Dict, str]):
         self.formula = formula
         self.feature_dims = None
@@ -17,6 +17,7 @@ class Marginal:
         self.predict = None
         self.predictor_names = None
         self.parameters = None
+        self.device = get_device()
 
     def setup_data(self, adata: AnnData, batch_size: int = 1024, **kwargs):
         """Set up the dataloader for the AnnData object."""
@@ -28,31 +29,31 @@ class Marginal:
         self.feature_dims = {k: v.shape[1] for k, v in obs_batch.items()}
         self.predictor_names = self.loader.dataset.predictor_names
 
-    def fit(self, **kwargs):
-        """Fit the marginal predictor"""
+    def fit(self, max_epochs: int = 100, **kwargs):
+        """Fit the marginal predictor using vanilla PyTorch training loop."""
         if self.predict is None:
             self.setup_optimizer(**kwargs)
-        trainer_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['trainer'])
-        trainer = pl.Trainer(**trainer_kwargs)
-        trainer.fit(self.predict, train_dataloaders=self.loader)
-        self.parameters = self.format_parameters()
 
-
-
+        for epoch in range(max_epochs):
+            epoch_loss, n_batches = 0.0, 0
 
-
-
-
-
+            for batch in self.loader:
+                y, x = batch
+                if y.device != self.device:
+                    y = y.to(self.device)
+                    x = {k: v.to(self.device) for k, v in x.items()}
 
-
-
-
+                self.predict.optimizer.zero_grad()
+                loss = self.predict.loss_fn((y, x))
+                loss.backward()
+                self.predict.optimizer.step()
 
-
-
-
-
+                epoch_loss += loss.item()
+                n_batches += 1
+
+            avg_loss = epoch_loss / n_batches
+            print(f"Epoch {epoch}/{max_epochs}, Loss: {avg_loss:.4f}", end='\r')
+        self.parameters = self.format_parameters()
 
     def format_parameters(self):
         """Convert fitted coefficient tensors into pandas DataFrames.
@@ -79,8 +80,29 @@ class Marginal:
             return 0
         return sum(p.numel() for p in self.predict.parameters() if p.requires_grad)
 
+    @abstractmethod
+    def setup_optimizer(self, **kwargs):
+        raise NotImplementedError
+
+    @abstractmethod
+    def likelihood(self, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
+        """Compute the (negative) log-likelihood or loss for a batch.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Invert pseudoobservations."""
+        raise NotImplementedError
 
-class GLMPredictor(pl.LightningModule):
+    @abstractmethod
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]):
+        """Uniformize using learned CDF.
+        """
+        raise NotImplementedError
+
+
+class GLMPredictor(nn.Module):
     """GLM-style predictor with arbitrary named parameters.
 
     Args:
@@ -106,21 +128,22 @@ class GLMPredictor(pl.LightningModule):
         self.feature_dims = dict(feature_dims)
         self.param_names = list(self.feature_dims.keys())
 
-        # create default link functions and parameter matrices
         self.link_fns = link_fns or {k: torch.exp for k in self.param_names}
         self.coefs = nn.ParameterDict()
         for key, dim in self.feature_dims.items():
             self.coefs[key] = nn.Parameter(torch.zeros(dim, self.n_outcomes))
-
-        # optimization parameters
         self.reset_parameters()
+
         self.loss_fn = loss_fn
-        self.
-
+        self.to(get_device())
+
+        optimizer_kwargs = optimizer_kwargs or {}
+        filtered_kwargs = _filter_kwargs(optimizer_kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
+        self.optimizer = optimizer_class(self.parameters(), **filtered_kwargs)
 
     def reset_parameters(self):
         for p in self.coefs.values():
-            nn.init.normal_(p, mean=0.0, std=1e-
+            nn.init.normal_(p, mean=0.0, std=1e-4)
 
     def forward(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         out = {}
@@ -128,13 +151,4 @@ class GLMPredictor(pl.LightningModule):
             x_beta = obs_dict[name] @ self.coefs[name]
             link = self.link_fns.get(name, torch.exp)
             out[name] = link(x_beta)
-        return out
-
-    def training_step(self, batch):
-        loss = self.loss_fn(batch)
-        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
-        return loss
-
-    def configure_optimizers(self, **kwargs):
-        optimizer_kwargs = _filter_kwargs(kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
-        return self.optimizer_class(self.parameters(), **optimizer_kwargs)
+        return out
```
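The new `fit()` replaces the `pytorch_lightning.Trainer` with a plain zero_grad → loss → backward → step loop that averages the loss per epoch. A minimal runnable sketch of the same pattern on a toy Poisson GLM (the data and names here are illustrative, not scdesigner's API):

```python
import torch
import torch.nn as nn

# Toy data: y ~ Poisson(exp(X @ beta_true))
torch.manual_seed(0)
X = torch.randn(512, 3)
beta_true = torch.tensor([[0.5], [-0.2], [0.1]])
y = torch.poisson(torch.exp(X @ beta_true))

beta = nn.Parameter(torch.zeros(3, 1))
optimizer = torch.optim.Adam([beta], lr=0.05)
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(X, y), batch_size=128
)

max_epochs = 100
for epoch in range(max_epochs):
    epoch_loss, n_batches = 0.0, 0
    for xb, yb in loader:
        optimizer.zero_grad()
        log_rate = xb @ beta
        # Poisson negative log-likelihood, up to an additive constant
        loss = (torch.exp(log_rate) - yb * log_rate).mean()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1
    print(f"Epoch {epoch}/{max_epochs}, Loss: {epoch_loss / n_batches:.4f}", end="\r")
```

This mirrors the diff's design: the optimizer now lives on the predictor itself (`self.predict.optimizer`), so no framework-level `configure_optimizers` hook is needed.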
scdesigner/minimal/negbin.py
CHANGED
scdesigner/minimal/scd3.py
CHANGED

scdesigner/minimal/scd3_instances.py
CHANGED

```diff
@@ -2,7 +2,7 @@ from .scd3 import SCD3Simulator
 from .negbin import NegBin
 from .zero_inflated_negbin import ZeroInflatedNegBin
 from .gaussian import Gaussian
-from .
+from .standard_copula import StandardCopula
 from typing import Optional
 
 
@@ -12,7 +12,7 @@ class NegBinCopula(SCD3Simulator):
                  dispersion_formula: Optional[str] = None,
                  copula_formula: Optional[str] = None) -> None:
         marginal = NegBin({"mean": mean_formula, "dispersion": dispersion_formula})
-        covariance =
+        covariance = StandardCopula(copula_formula)
         super().__init__(marginal, covariance)
 
 
@@ -27,7 +27,7 @@ class ZeroInflatedNegBinCopula(SCD3Simulator):
             "dispersion": dispersion_formula,
             "zero_inflation_formula": zero_inflation_formula
         })
-        covariance =
+        covariance = StandardCopula(copula_formula)
         super().__init__(marginal, covariance)
 
 
@@ -36,7 +36,7 @@ class BernoulliCopula(SCD3Simulator):
                  mean_formula: Optional[str] = None,
                  copula_formula: Optional[str] = None) -> None:
         marginal = NegBin({"mean": mean_formula})
-        covariance =
+        covariance = StandardCopula(copula_formula)
         super().__init__(marginal, covariance)
 
 
@@ -46,5 +46,5 @@ class GaussianCopula(SCD3Simulator):
                  sdev_formula: Optional[str] = None,
                  copula_formula: Optional[str] = None) -> None:
         marginal = Gaussian({"mean": mean_formula, "sdev": sdev_formula})
-        covariance =
+        covariance = StandardCopula(copula_formula)
         super().__init__(marginal, covariance)
```
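All four simulators now build their dependence structure as `StandardCopula(copula_formula)`. A construction sketch, assuming these classes are importable from `scdesigner.minimal.scd3_instances` in 0.0.5; the keyword names follow the signatures in the diff, while the formula strings are illustrative:

```python
# Hypothetical usage: only the constructors shown in the diff are exercised.
from scdesigner.minimal.scd3_instances import NegBinCopula, GaussianCopula

sim_nb = NegBinCopula(
    mean_formula="~ cell_type",      # NegBin marginal mean model
    dispersion_formula="~ 1",        # intercept-only dispersion
    copula_formula="~ cell_type",    # correlation structure per the new StandardCopula
)
sim_gauss = GaussianCopula(
    mean_formula="~ cell_type",
    sdev_formula="~ 1",
    copula_formula="~ cell_type",
)
```

Note that `BernoulliCopula` still constructs a `NegBin` marginal in this version, as the context lines in the diff show.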