scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdesigner/base/__init__.py +8 -0
- scdesigner/base/copula.py +416 -0
- scdesigner/base/marginal.py +391 -0
- scdesigner/base/simulator.py +59 -0
- scdesigner/copulas/__init__.py +8 -0
- scdesigner/copulas/standard_copula.py +645 -0
- scdesigner/datasets/__init__.py +5 -0
- scdesigner/datasets/pancreas.py +39 -0
- scdesigner/distributions/__init__.py +19 -0
- scdesigner/{minimal → distributions}/bernoulli.py +42 -14
- scdesigner/distributions/gaussian.py +114 -0
- scdesigner/distributions/negbin.py +121 -0
- scdesigner/distributions/negbin_irls.py +72 -0
- scdesigner/distributions/negbin_irls_funs.py +456 -0
- scdesigner/distributions/poisson.py +88 -0
- scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
- scdesigner/distributions/zero_inflated_poisson.py +103 -0
- scdesigner/simulators/__init__.py +24 -28
- scdesigner/simulators/composite.py +239 -0
- scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
- scdesigner/simulators/scd3.py +486 -0
- scdesigner/transform/__init__.py +8 -6
- scdesigner/{minimal → transform}/transform.py +1 -1
- scdesigner/{minimal → utils}/kwargs.py +4 -1
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
- scdesigner-0.0.10.dist-info/RECORD +28 -0
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
- scdesigner/data/__init__.py +0 -16
- scdesigner/data/formula.py +0 -137
- scdesigner/data/group.py +0 -123
- scdesigner/data/sparse.py +0 -39
- scdesigner/diagnose/__init__.py +0 -65
- scdesigner/diagnose/aic_bic.py +0 -119
- scdesigner/diagnose/plot.py +0 -242
- scdesigner/estimators/__init__.py +0 -32
- scdesigner/estimators/bernoulli.py +0 -85
- scdesigner/estimators/gaussian.py +0 -121
- scdesigner/estimators/gaussian_copula_factory.py +0 -367
- scdesigner/estimators/glm_factory.py +0 -75
- scdesigner/estimators/negbin.py +0 -153
- scdesigner/estimators/pnmf.py +0 -160
- scdesigner/estimators/poisson.py +0 -124
- scdesigner/estimators/zero_inflated_negbin.py +0 -195
- scdesigner/estimators/zero_inflated_poisson.py +0 -85
- scdesigner/format/__init__.py +0 -4
- scdesigner/format/format.py +0 -20
- scdesigner/format/print.py +0 -30
- scdesigner/minimal/__init__.py +0 -17
- scdesigner/minimal/composite.py +0 -119
- scdesigner/minimal/copula.py +0 -205
- scdesigner/minimal/formula.py +0 -23
- scdesigner/minimal/gaussian.py +0 -65
- scdesigner/minimal/loader.py +0 -211
- scdesigner/minimal/marginal.py +0 -154
- scdesigner/minimal/negbin.py +0 -73
- scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
- scdesigner/minimal/scd3.py +0 -96
- scdesigner/minimal/scd3_instances.py +0 -50
- scdesigner/minimal/simulator.py +0 -25
- scdesigner/minimal/standard_copula.py +0 -383
- scdesigner/predictors/__init__.py +0 -15
- scdesigner/predictors/bernoulli.py +0 -9
- scdesigner/predictors/gaussian.py +0 -16
- scdesigner/predictors/negbin.py +0 -17
- scdesigner/predictors/poisson.py +0 -12
- scdesigner/predictors/zero_inflated_negbin.py +0 -18
- scdesigner/predictors/zero_inflated_poisson.py +0 -18
- scdesigner/samplers/__init__.py +0 -23
- scdesigner/samplers/bernoulli.py +0 -27
- scdesigner/samplers/gaussian.py +0 -25
- scdesigner/samplers/glm_factory.py +0 -103
- scdesigner/samplers/negbin.py +0 -25
- scdesigner/samplers/poisson.py +0 -25
- scdesigner/samplers/zero_inflated_negbin.py +0 -40
- scdesigner/samplers/zero_inflated_poisson.py +0 -16
- scdesigner/simulators/composite_regressor.py +0 -72
- scdesigner/simulators/glm_simulator.py +0 -167
- scdesigner/simulators/pnmf_regression.py +0 -61
- scdesigner/transform/amplify.py +0 -14
- scdesigner/transform/mask.py +0 -33
- scdesigner/transform/nullify.py +0 -25
- scdesigner/transform/split.py +0 -23
- scdesigner/transform/substitute.py +0 -14
- scdesigner-0.0.5.dist-info/RECORD +0 -66
|
@@ -0,0 +1,391 @@
|
|
|
1
|
+
from ..utils.kwargs import DEFAULT_ALLOWED_KWARGS, _filter_kwargs
|
|
2
|
+
from ..data.loader import adata_loader, get_device
|
|
3
|
+
from anndata import AnnData
|
|
4
|
+
from typing import Union, Dict, Optional, Tuple
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Marginal(ABC):
    """
    A Feature-wise Marginal Model

    This is a class for handling feature-wise (e.g., gene-level) modeling that
    ignores any correlation between features. For example, it can be used to
    model the relationship between experimental design features, such as cell
    type or treatment, and the parameters of a collection of negative binomial
    models (one for each gene).

    These marginals can be plugged into a copula to build the complete scDesign3
    simulator. Methods are expected for estimating model parameters, evaluating
    likelihoods on new samples, and generating new samples given conditioning
    information. Since these marginal models are intended to be used within a
    copula, they should also provide utilities for evaluating quantiles and
    computing cumulative distribution functions.

    Parameters
    ----------
    formula : dict or str
        A dictionary or string specifying the relationship between the columns
        of an input data frame (adata.obs, adata.var, or similar attributes) and
        the parameters of the marginal model. If only a string is provided,
        then the means are allowed to depend on the design parameters, while all
        other parameters are treated as fixed. If a dictionary is provided,
        each key should correspond to a parameter. The string values should be
        in a format that can be parsed by the formulaic package. For example,
        '~ x' will ensure that the parameter varies linearly with X.
    device : torch.device, optional
        The requested device. It is passed through ``get_device`` when the
        object is constructed and the result is stored in ``self.device``.

    Attributes
    ----------
    formula : dict or str
        The formula specification passed to the constructor (see Parameters).

    feature_dims : dict
        A dictionary containing the number of predictors associated with each
        distributional parameter. Note that this number is repeated for every
        feature (e.g., gene) in the marginal model. This information is often
        useful for computing the complexity of the estimated model.

    loader : torch.utils.data.DataLoader
        A torch DataLoader object that returns batches of data for use during
        training. This loader is constructed internally within the
        setup_data method. Enumerating this loader returns a tuple: the
        first element contains a tensor of feature measurements (y), and the
        second element is a dictionary of tensors containing predictors to use
        for each parameter (x, for each parameter theta(x)). This design is
        useful because the design matrices may differ between parameters of the
        marginal model, y | x ~ F_(theta(x))(y)

    n_outcomes : int
        The number of features modeled by this marginal model. For example,
        this corresponds to the number of genes being simulated.

    predict : nn.Module
        A torch.nn.Module storing the relationship between predictors for each
        parameter and the predicted feature-wise outcomes. This module is expected
        to take the second element of the tuple defined by each batch and then
        predict a tensor with the same shape as the first element of the batch
        tuple. The training loop in ``fit`` additionally uses its ``optimizer``
        and ``loss_fn`` attributes, and ``format_parameters`` reads its
        ``coefs`` mapping.

    predictor_names : dict of list of str
        A dictionary whose keys are the parameter names associated with this
        marginal model. The values for each key are the names of predictors in
        the design matrix implied by the associated formula. Note that these
        names may have been expanded from the original formula specification.
        For example, if cell_type is included in the formula, then the predictor
        names will include the unique levels of cell_type as separate columns in
        the design matrix and therefore as separate elements in this list.

    parameters : dict of pandas.DataFrame
        A dictionary whose keys are the parameter names associated with this
        marginal model. The values for each key are pandas DataFrames storing
        the fitted parameter values. The rows of each DataFrame are the
        experimental features specified by the associated formula object (the
        rownames are the same as those in predictor_names). The columns are
        features that are being predicted.

    device : torch.device
        The device that training batches are moved to. It is resolved from
        the ``device`` constructor argument by ``get_device``.

    adata : AnnData
        The template dataset. This attribute only exists after ``setup_data``
        has been called.

    Examples
    --------
    >>> class DummyModel(Marginal):
    ...     def fit(self):
    ...         pass
    ...
    ...     def likelihood(self, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]):
    ...         pass
    ...
    ...     def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]):
    ...         pass
    ...
    ...     def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
    ...         pass
    ...
    ...     def setup_optimizer(self):
    ...         pass
    ...
    >>> model = DummyModel("~ cell_type")
    >>> model.fit()
    """
    def __init__(self, formula: Union[Dict, str], device: Optional[torch.device] = None):
        self.formula = formula
        self.feature_dims = None
        self.loader = None
        self.n_outcomes = None
        self.predict = None
        self.predictor_names = None
        self.parameters = None
        self.device = get_device(device)

    def setup_data(self, adata: AnnData, batch_size: int = 1024, **kwargs):
        """Set up the dataloader for the AnnData object.

        The marginal class definition doesn't actually require any particular
        template dataset. This is helpful for reasoning about simulators
        abstractly, but when we actually want to estimate parameters, we need a
        template. This method takes an input template dataset and adds
        attributes to the object for later estimation and sampling
        steps.

        Parameters
        ----------
        adata : AnnData
            This is the object on which we want to estimate the simulator. This
            serves as the template for all downstream fitting.
        batch_size : int
            The number of samples to return on each call of the data loader.
            Defaults to 1024.
        **kwargs : Any
            Other keyword arguments passed to data loader construction. Any
            argument recognized by the PyTorch DataLoader can be passed in here.

        Returns
        -------
        None
            The method populates ``self.adata``, ``self.loader``,
            ``self.n_outcomes``, ``self.feature_dims`` and
            ``self.predictor_names``.
        """
        # keep a reference to the AnnData for later use (e.g., var_names)
        self.adata = adata
        self.loader = adata_loader(adata, self.formula, batch_size=batch_size, **kwargs)
        # Peek at one batch to learn the outcome and design-matrix widths.
        X_batch, obs_batch = next(iter(self.loader))
        self.n_outcomes = X_batch.shape[1]
        self.feature_dims = {k: v.shape[1] for k, v in obs_batch.items()}
        self.predictor_names = self.loader.dataset.predictor_names

    def fit(self, max_epochs: int = 100, verbose: bool = True, **kwargs):
        """Fit the marginal predictor using vanilla PyTorch training loop.

        This method runs stochastic gradient optimization using the template
        dataset defined by the setup_data method. The optimizer and loss are
        whatever the setup_optimizer implementation attaches to
        ``self.predict``.

        Note that, unlike `fit` in class `Simulator`, this method does not allow
        the template dataset as input. This requires `.setup_data()` to be
        called first. We want to give finer-grained control over the data
        loading and optimization in this class relative to the specific
        `Simulator` implementations, which are designed to be easy to run with
        as few steps as possible.

        Parameters
        ----------
        max_epochs : int
            The maximum number of epochs. This is the number of times we feed
            through our cells in the dataset.
        verbose : bool
            Should we print intermediate training outputs?
        **kwargs : Any
            Forwarded to ``setup_optimizer`` when no predictor has been set up
            yet (``self.predict is None``); ignored otherwise.

        Returns
        -------
        None
            This method doesn't return anything but modifies the self.parameters
            attribute with the trained model parameters.

        Raises
        ------
        RuntimeError
            If no data loader is available, i.e. ``setup_data`` has not been
            called.
        """
        # Fail early with an actionable message instead of a TypeError raised
        # while trying to iterate over None.
        if self.loader is None:
            raise RuntimeError(
                "No data loader is available; call setup_data() before fit()."
            )

        if self.predict is None:
            self.setup_optimizer(**kwargs)

        for epoch in range(max_epochs):
            epoch_loss, n_batches = 0.0, 0

            for y, x in self.loader:
                if y.device != self.device:
                    y = y.to(self.device)
                    x = {k: v.to(self.device) for k, v in x.items()}

                self.predict.optimizer.zero_grad()
                loss = self.predict.loss_fn((y, x))
                loss.backward()
                self.predict.optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

            # An empty loader yields no batches; report NaN rather than
            # dividing by zero.
            avg_loss = epoch_loss / n_batches if n_batches else float("nan")
            if verbose:
                print(f"Epoch {epoch}/{max_epochs}, Loss: {avg_loss:.4f}", end='\r')
        self.parameters = self.format_parameters()

    def format_parameters(self):
        """Convert fitted coefficient tensors into pandas DataFrames.

        Requires ``setup_data`` to have been called (for ``self.adata`` and
        ``self.predictor_names``) and a predictor exposing ``coefs``.

        Returns:
            dict: mapping from parameter name -> pandas.DataFrame with rows
                corresponding to predictor column names (from
                `self.predictor_names[param]`) and columns corresponding to
                `self.adata.var_names` (gene names). The values are detached,
                moved to CPU and converted to numpy arrays.
        """
        var_names = list(self.adata.var_names)

        dfs = {}
        for param, tensor in self.predict.coefs.items():
            coef_np = tensor.detach().cpu().numpy()
            row_names = list(self.predictor_names[param])
            dfs[param] = pd.DataFrame(coef_np, index=row_names, columns=var_names)
        return dfs

    def num_params(self):
        """Return the number of parameters.

        Count the number of trainable parameters in the marginal simulator.
        Usually this is just the number of predictors times the number of
        genes, because we use a linear model. However, in specific
        implementations, it's possible to use more flexible models, in which
        case the number of parameters would increase.

        Returns
        -------
        int
            Total element count of the predictor's parameters that have
            ``requires_grad`` set, or 0 when no predictor has been set up.
        """
        if self.predict is None:
            return 0
        return sum(p.numel() for p in self.predict.parameters() if p.requires_grad)

    @abstractmethod
    def setup_optimizer(self, **kwargs):
        """Create the predictor used for training and store it in ``self.predict``.

        ``fit`` calls this hook (forwarding its extra keyword arguments) when
        ``self.predict`` is still ``None``, and afterwards relies on
        ``self.predict.optimizer`` and ``self.predict.loss_fn``.
        """
        raise NotImplementedError

    @abstractmethod
    def likelihood(self, batch: Tuple[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """Compute the log-likelihood for a batch.

        The likelihood is used for maximum likelihood estimation. It is also
        used when computing AIC and BIC scores, which are important when
        choosing an appropriate model complexity.

        Parameters
        ----------
        batch : tuple of (torch.Tensor, dict of str -> torch.Tensor)
            A tuple of gene expression (y) and experimental factors (x) used to
            evaluate the model likelihood. The first element of the tuple is a
            cells x genes tensor. The second is a dictionary of tensors, with
            one key/value pair per parameter. These tensors are the
            conditioning information to pass to the .predict() function of this
            distribution class. They are the numerical design matrices implied
            by the initializing formulas.

        Returns
        -------
        torch.Tensor
            A scalar containing the log-likelihood of the batch under the
            current model parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
        """Invert pseudoobservations.

        Return a quantile from the distribution. This handles the link between
        the marginal and the copula model. The idea is that the copula will
        generate pseudo-observations on the unit cube. All the values will be
        between zero and one. By calling invert, we transform these zero-to-one
        valued pseudo-observations into observations in the original data space;
        values between zero and one can be thought of like quantiles.

        Parameters
        ----------
        u : torch.Tensor
            Scalars between [0, 1] that specify the quantile level we want.
        x : Dict of (str -> torch.Tensor)
            A dictionary of tensors, with one key/value pair per parameter.
            These tensors are the conditioning information to pass to the
            .predict() function of this distribution class. They are the
            numerical design matrices implied by the initializing formulas.

        Returns
        -------
        z : torch.Tensor
            A tensor with dimension dim(u) x num_genes. Each row gives the
            requested quantile for the marginal distributions across all genes.
        """
        raise NotImplementedError

    @abstractmethod
    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor]):
        """Uniformize using learned CDF.

        Apply a quantile/CDF transformation to observations y, accounting for
        conditioning variables x. This step is used in training the copula
        model. Since the copula needs to operate on the unit cube, we need to
        transform the original data onto the unit cube. This can be done by
        applying a CDF transformation.

        Parameters
        ----------
        y : torch.Tensor
            A cells x genes tensor with gene expression levels across all cells.
        x : Dict of (str -> torch.Tensor)
            A dictionary of tensors, with one key/value pair per parameter.
            These tensors are the conditioning information to pass to the
            .predict() function of this distribution class. They are the
            numerical design matrices implied by the initializing formulas.
        """
        raise NotImplementedError
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
class GLMPredictor(nn.Module):
    """GLM-style predictor with arbitrary named parameters.

    Args:
        n_outcomes: number of model outputs (e.g. genes)
        feature_dims: mapping from param name -> number of covariate features
        link_fns: optional mapping from param name -> callable(link) applied to
            the linear predictor. When omitted (or empty), ``torch.exp`` is
            used for every parameter; a parameter missing from the mapping
            also falls back to ``torch.exp`` in ``forward``.
        loss_fn: optional callable stored as ``self.loss_fn``. The training
            loop in ``Marginal.fit`` calls it with the ``(y, x)`` batch tuple
            and back-propagates through the returned value.
        optimizer_class: optimizer constructor called as
            ``optimizer_class(self.parameters(), **kwargs)``. ``None`` selects
            the default, ``torch.optim.AdamW``.
        optimizer_kwargs: optional keyword arguments for the optimizer. They
            are passed through ``_filter_kwargs`` with
            ``DEFAULT_ALLOWED_KWARGS['optimizer']`` before use.
        device: optional device; resolved with ``get_device`` and the module
            is moved there before the optimizer is created.

    The module will create one coefficient matrix per named parameter with shape
    (n_features_for_param, n_outcomes) and expose them as Parameters under
    `self.coefs[param_name]`.
    """
    def __init__(
        self,
        n_outcomes: int,
        feature_dims: Dict[str, int],
        link_fns: Optional[Dict[str, callable]] = None,
        loss_fn: Optional[callable] = None,
        optimizer_class: Optional[callable] = torch.optim.AdamW,
        optimizer_kwargs: Optional[Dict] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.n_outcomes = int(n_outcomes)
        self.feature_dims = dict(feature_dims)
        self.param_names = list(self.feature_dims.keys())

        self.link_fns = link_fns or {k: torch.exp for k in self.param_names}
        self.coefs = nn.ParameterDict()
        for key, dim in self.feature_dims.items():
            self.coefs[key] = nn.Parameter(torch.zeros(dim, self.n_outcomes))
        self.reset_parameters()

        self.loss_fn = loss_fn
        # Move the parameters first so the optimizer is built over tensors on
        # the target device.
        self.to(get_device(device))

        # The signature advertises Optional, so treat None as "use the default"
        # instead of failing with "'NoneType' object is not callable".
        if optimizer_class is None:
            optimizer_class = torch.optim.AdamW
        optimizer_kwargs = optimizer_kwargs or {}
        filtered_kwargs = _filter_kwargs(optimizer_kwargs, DEFAULT_ALLOWED_KWARGS['optimizer'])
        self.optimizer = optimizer_class(self.parameters(), **filtered_kwargs)

    def reset_parameters(self):
        """Re-draw every coefficient matrix from N(0, 1e-4**2), in place."""
        for p in self.coefs.values():
            nn.init.normal_(p, mean=0.0, std=1e-4)

    def forward(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward Pass for Given Covariates

        obs_dict : Dict of (str -> torch.Tensor)
            A dictionary of tensors, with one key/value pair per parameter.
            These tensors are the conditioning information (numerical design
            matrices implied by the initializing formulas). Each one is
            matrix-multiplied with the coefficient matrix of the same name.

        Returns a dictionary with one entry per name in ``self.param_names``:
        the link function applied to ``obs_dict[name] @ self.coefs[name]``.
        """
        out = {}
        for name in self.param_names:
            x_beta = obs_dict[name] @ self.coefs[name]
            link = self.link_fns.get(name, torch.exp)
            out[name] = link(x_beta)
        return out
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from anndata import AnnData
|
|
2
|
+
from typing import Dict
|
|
3
|
+
from pandas import DataFrame
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
|
|
6
|
+
class Simulator:
    """Simulation abstract class

    This abstract simulator class defines the minimal methods that must be
    exported by every simulator in scdesigner. These methods are:

    * `fit`: Given an anndata dataset object, estimate the model parameters.
    * `predict`: Given experimental and biological features x in the form of
      a cell-by-features pd.DataFrame, return the parameters theta(x) of
      interest.
    * `sample`: Given the same experimental/biological information x as
      `predict`, simulate hypothetical profiles associated with those
      samples.

    Example instantiations of this class are given in the module
    `scdesigner.simulators`.

    Note that this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced at instantiation time: the
    base class can be constructed (as in the example below), and calling an
    unimplemented method raises ``NotImplementedError``.

    Attributes
    ----------
    parameters
        ``None`` on construction; nothing in this base class assigns it
        afterwards.
    template
        ``None`` on construction; ``fit`` stores the dataset it receives here.

    Examples
    --------
    >>> from scdesigner.datasets import pancreas
    >>> sim = Simulator()
    >>> sim.parameters
    >>>
    >>> # this is how a subclass would run, once its fit, predict, and sample
    >>> # methods are implemented.
    >>> sim.fit(pancreas) # doctest: +SKIP
    >>> sim.predict(pancreas.obs) # doctest: +SKIP
    >>> sim.sample(pancreas.obs) # doctest: +SKIP
    """
    def __init__(self):
        self.parameters = None
        # Declared up front so the attribute can be read before fit() is called.
        self.template = None

    @abstractmethod
    def fit(self, anndata: AnnData, **kwargs) -> None:
        """Fit the simulator

        Parameters
        ----------
        anndata : AnnData
            This is the object on which we want to estimate the simulator. This
            serves as the template for all downstream fitting. It is stored in
            ``self.template`` before the base implementation raises.
        **kwargs : Any
            Unused by the base implementation.

        Raises
        ------
        NotImplementedError
            Always; subclasses must provide the estimation logic.
        """
        self.template = anndata
        raise NotImplementedError

    @abstractmethod
    def predict(self, obs: DataFrame = None, **kwargs) -> Dict:
        """Predict from an obs dataframe"""
        raise NotImplementedError

    @abstractmethod
    def sample(self, obs: DataFrame = None, **kwargs) -> AnnData:
        """Generate samples."""
        raise NotImplementedError
|