scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdesigner/base/__init__.py +8 -0
- scdesigner/base/copula.py +416 -0
- scdesigner/base/marginal.py +391 -0
- scdesigner/base/simulator.py +59 -0
- scdesigner/copulas/__init__.py +8 -0
- scdesigner/copulas/standard_copula.py +645 -0
- scdesigner/datasets/__init__.py +5 -0
- scdesigner/datasets/pancreas.py +39 -0
- scdesigner/distributions/__init__.py +19 -0
- scdesigner/{minimal → distributions}/bernoulli.py +42 -14
- scdesigner/distributions/gaussian.py +114 -0
- scdesigner/distributions/negbin.py +121 -0
- scdesigner/distributions/negbin_irls.py +72 -0
- scdesigner/distributions/negbin_irls_funs.py +456 -0
- scdesigner/distributions/poisson.py +88 -0
- scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
- scdesigner/distributions/zero_inflated_poisson.py +103 -0
- scdesigner/simulators/__init__.py +24 -28
- scdesigner/simulators/composite.py +239 -0
- scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
- scdesigner/simulators/scd3.py +486 -0
- scdesigner/transform/__init__.py +8 -6
- scdesigner/{minimal → transform}/transform.py +1 -1
- scdesigner/{minimal → utils}/kwargs.py +4 -1
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
- scdesigner-0.0.10.dist-info/RECORD +28 -0
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
- scdesigner/data/__init__.py +0 -16
- scdesigner/data/formula.py +0 -137
- scdesigner/data/group.py +0 -123
- scdesigner/data/sparse.py +0 -39
- scdesigner/diagnose/__init__.py +0 -65
- scdesigner/diagnose/aic_bic.py +0 -119
- scdesigner/diagnose/plot.py +0 -242
- scdesigner/estimators/__init__.py +0 -32
- scdesigner/estimators/bernoulli.py +0 -85
- scdesigner/estimators/gaussian.py +0 -121
- scdesigner/estimators/gaussian_copula_factory.py +0 -367
- scdesigner/estimators/glm_factory.py +0 -75
- scdesigner/estimators/negbin.py +0 -153
- scdesigner/estimators/pnmf.py +0 -160
- scdesigner/estimators/poisson.py +0 -124
- scdesigner/estimators/zero_inflated_negbin.py +0 -195
- scdesigner/estimators/zero_inflated_poisson.py +0 -85
- scdesigner/format/__init__.py +0 -4
- scdesigner/format/format.py +0 -20
- scdesigner/format/print.py +0 -30
- scdesigner/minimal/__init__.py +0 -17
- scdesigner/minimal/composite.py +0 -119
- scdesigner/minimal/copula.py +0 -205
- scdesigner/minimal/formula.py +0 -23
- scdesigner/minimal/gaussian.py +0 -65
- scdesigner/minimal/loader.py +0 -211
- scdesigner/minimal/marginal.py +0 -154
- scdesigner/minimal/negbin.py +0 -73
- scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
- scdesigner/minimal/scd3.py +0 -96
- scdesigner/minimal/scd3_instances.py +0 -50
- scdesigner/minimal/simulator.py +0 -25
- scdesigner/minimal/standard_copula.py +0 -383
- scdesigner/predictors/__init__.py +0 -15
- scdesigner/predictors/bernoulli.py +0 -9
- scdesigner/predictors/gaussian.py +0 -16
- scdesigner/predictors/negbin.py +0 -17
- scdesigner/predictors/poisson.py +0 -12
- scdesigner/predictors/zero_inflated_negbin.py +0 -18
- scdesigner/predictors/zero_inflated_poisson.py +0 -18
- scdesigner/samplers/__init__.py +0 -23
- scdesigner/samplers/bernoulli.py +0 -27
- scdesigner/samplers/gaussian.py +0 -25
- scdesigner/samplers/glm_factory.py +0 -103
- scdesigner/samplers/negbin.py +0 -25
- scdesigner/samplers/poisson.py +0 -25
- scdesigner/samplers/zero_inflated_negbin.py +0 -40
- scdesigner/samplers/zero_inflated_poisson.py +0 -16
- scdesigner/simulators/composite_regressor.py +0 -72
- scdesigner/simulators/glm_simulator.py +0 -167
- scdesigner/simulators/pnmf_regression.py +0 -61
- scdesigner/transform/amplify.py +0 -14
- scdesigner/transform/mask.py +0 -33
- scdesigner/transform/nullify.py +0 -25
- scdesigner/transform/split.py +0 -23
- scdesigner/transform/substitute.py +0 -14
- scdesigner-0.0.5.dist-info/RECORD +0 -66
scdesigner/simulators/positive_nonnegative_matrix_factorization.py
@@ -0,0 +1,477 @@
from ..data.formula import standardize_formula
from ..data.loader import _to_numpy
from ..base.simulator import Simulator
from anndata import AnnData
from formulaic import model_matrix
from scipy.stats import gamma
from typing import Union, Dict
import numpy as np
import pandas as pd
import torch


################################################################################
## Functions for estimating PNMF regression
################################################################################


# Computes the PNMF weight and score matrices from a log-transformed count
# array; `nbase` specifies the number of latent bases.
def pnmf(log_data, nbase=3, **kwargs):
    """
    Estimate PNMF components from log-transformed expression counts.

    Parameters
    ----------
    log_data : np.ndarray
        Log-transformed expression matrix (genes × cells).
    nbase : int, optional
        Number of latent PNMF bases to extract.
    **kwargs
        Additional arguments forwarded to :func:`pnmf_eucdist`.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Tuple containing the learned PNMF weight matrix ``W`` and score matrix
        ``S`` (pseudo-basis loadings for each cell).
    """
    U = left_singular(log_data, nbase)
    W = pnmf_eucdist(log_data, U, **kwargs)
    W = W / np.linalg.norm(W, ord=2)
    S = W.T @ log_data
    return W, S
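
# A minimal usage sketch (illustrative only, on synthetic data; not part of the
# released file): pnmf() factorizes a (genes × cells) matrix into gene loadings
# W and per-cell scores S.
#
#     rng = np.random.default_rng(0)
#     log_data = np.log1p(rng.poisson(2.0, size=(30, 100)))  # 30 genes, 100 cells
#     W, S = pnmf(log_data, nbase=3)
#     W.shape  # (30, 3): gene loadings per basis
#     S.shape  # (3, 100): per-cell scores, S = W.T @ log_data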


def gamma_regression_array(
    x: np.ndarray, y: np.ndarray, lr: float = 0.1, epochs: int = 40
) -> dict:
    """
    Fit gamma regression coefficients in a batched regression context.

    Parameters
    ----------
    x : np.ndarray
        Design matrix for covariates (cells × covariates).
    y : np.ndarray
        Target matrix (cells × latent features) derived from PNMF scores.
    lr : float, optional
        Learning rate for the Adam optimizer.
    epochs : int, optional
        Number of training epochs.

    Returns
    -------
    dict[str, np.ndarray]
        Dictionary containing estimated ``"a"``, ``"loc"``, and ``"beta"``
        regression coefficients shaped (covariates, outcomes); ``"a"`` and
        ``"beta"`` are kept on the log scale.
    """
    x = torch.tensor(x, dtype=torch.float32)
    y = torch.tensor(y, dtype=torch.float32)

    n_features, n_outcomes = x.shape[1], y.shape[1]
    a = torch.zeros(n_features * n_outcomes, requires_grad=True)
    loc = torch.zeros(n_features * n_outcomes, requires_grad=True)
    beta = torch.zeros(n_features * n_outcomes, requires_grad=True)
    optimizer = torch.optim.Adam([a, loc, beta], lr=lr)

    for _ in range(epochs):
        optimizer.zero_grad()
        loss = negative_gamma_log_likelihood(a, beta, loc, x, y)
        loss.backward()
        optimizer.step()

    a, loc, beta = _to_numpy(a, loc, beta)
    a = a.reshape(n_features, n_outcomes)
    loc = loc.reshape(n_features, n_outcomes)
    beta = beta.reshape(n_features, n_outcomes)
    return {"a": a, "loc": loc, "beta": beta}


def class_generator(score, n_clusters=3):
    """
    Cluster PNMF scores and return discrete class labels.

    This function is not used in the current implementation.

    Parameters
    ----------
    score : np.ndarray
        PNMF scores (latent factors) of shape (n_features, n_cells).
    n_clusters : int, optional
        Number of target clusters for grouping the scores.

    Returns
    -------
    np.ndarray
        Array of cluster labels of length ``n_cells``.
    """
    from sklearn.cluster import KMeans

    kmeans = KMeans(n_clusters, random_state=0)
    kmeans.fit(score.T)  # cluster cells (columns of score)
    labels = kmeans.labels_
    return labels


###############################################################################
## Helpers for deriving PNMF
###############################################################################


def pnmf_eucdist(
    X, W_init, maxIter=500, threshold=1e-4, tol=1e-10, verbose=False, **kwargs
):
    """
    Optimize PNMF weights via Euclidean distance minimization.

    Parameters
    ----------
    X : np.ndarray
        Input expression matrix (genes × cells).
    W_init : np.ndarray
        Initial estimate of the weight matrix.
    maxIter : int, optional
        Maximum number of iterations.
    threshold : float, optional
        Convergence threshold on relative weight change.
    tol : float, optional
        Numeric tolerance for truncating small entries.
    verbose : bool, optional
        If True, print progress every 10 iterations.
    **kwargs
        Reserved for future options.

    Returns
    -------
    np.ndarray
        Normalized PNMF weight matrix with nonnegative entries.
    """
    # initialization: W starts from the left singular vectors (PCA) of X
    W = W_init
    XX = X @ X.T

    # multiplicative updates
    for it in range(maxIter):
        if verbose and (it + 1) % 10 == 0:
            print("%d iterations used." % (it + 1))
        W_old = W

        XXW = XX @ W
        SclFactor = np.dot(W, W.T @ XXW) + np.dot(XXW, W.T @ W)

        # quotient lower bound: keep the denominator away from zero
        SclFactor = MatFindlb(SclFactor, tol)
        SclFactor = XXW / SclFactor
        # out-of-place product on purpose: `W *= SclFactor` would also mutate
        # W_old (same array object), making the convergence check trivially zero
        W = W * SclFactor

        norm_W = np.linalg.norm(W)
        W /= norm_W
        W = MatFind(W, tol)

        diffW = np.linalg.norm(W_old - W) / np.linalg.norm(W_old)
        if diffW < threshold:
            break

    return W
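
# For reference, each pass of the loop above applies the multiplicative PNMF
# update for the Euclidean objective ||X - W W^T X||_F^2, elementwise:
#
#     W <- W * (X X^T W) / (W (W^T X X^T W) + (X X^T W) (W^T W))
#
# with the denominator clamped at `tol` to keep the quotient finite, followed
# by renormalization and truncation of entries below `tol` to zero.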


# left singular vectors of X
def left_singular(X, k):
    """
    Extract the top ``k`` left singular vectors of matrix ``X`` for initialization.
    """
    from scipy.sparse.linalg import svds

    U, _, _ = svds(X, k=k)
    return np.abs(U)


def MatFindlb(A, lb):
    """
    Clamp the entries of ``A`` to be greater than or equal to a lower bound ``lb``.
    """
    B = np.ones(A.shape) * lb
    Alb = np.where(A < lb, B, A)
    return Alb


def MatFind(A, ZeroThres):
    """
    Zero out entries of ``A`` that fall below the threshold ``ZeroThres``.
    """
    B = np.zeros(A.shape)
    Atrunc = np.where(A < ZeroThres, B, A)
    return Atrunc
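
# Both helpers reduce to standard NumPy idioms, noted here for readability:
#
#     MatFindlb(A, lb)      # == np.maximum(A, lb)
#     MatFind(A, thres)     # == np.where(A < thres, 0.0, A)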


###############################################################################
## Helpers for training PNMF regression
###############################################################################


def shifted_gamma_pdf(x, alpha, beta, loc):
    """
    Compute a penalized negative log density for a shifted gamma distribution.

    A large penalty is applied to values below the location parameter so that
    the fitted parameters keep the observations within the support of the
    gamma distribution.

    Parameters
    ----------
    x : torch.Tensor or array-like
        Observed values.
    alpha : torch.Tensor
        Shape parameters.
    beta : torch.Tensor
        Rate parameters.
    loc : torch.Tensor
        Location parameters (shift).

    Returns
    -------
    torch.Tensor
        Mean negative log likelihood over ``x``.
    """
    if not torch.is_tensor(x):
        x = torch.tensor(x)
    mask = x < loc
    y_clamped = torch.clamp(x - loc, min=1e-12)

    log_pdf = (
        alpha * torch.log(beta)
        - torch.lgamma(alpha)
        + (alpha - 1) * torch.log(y_clamped)
        - beta * y_clamped
    )
    loss = -torch.mean(log_pdf[~mask])
    n_invalid = mask.sum()
    if n_invalid > 0:  # penalize observations below loc
        loss = loss + 1e10 * n_invalid.float()
    return loss
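
# The unpenalized term is the log density of a gamma distribution with shape
# alpha and rate beta, shifted by loc (substituting y = x - loc):
#
#     log f(x) = alpha * log(beta) - lgamma(alpha)
#                + (alpha - 1) * log(x - loc) - beta * (x - loc),   for x > loc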


def negative_gamma_log_likelihood(log_a, log_beta, loc, X, y):
    """
    Compute the negative log likelihood of a gamma regression layer.

    Parameters
    ----------
    log_a : torch.Tensor
        Log-scale shape coefficients of shape (covariates × outcomes).
    log_beta : torch.Tensor
        Log-scale rate coefficients of shape (covariates × outcomes).
    loc : torch.Tensor
        Location coefficients of shape (covariates × outcomes).
    X : torch.Tensor
        Design matrix (cells × covariates).
    y : torch.Tensor
        Observed PNMF scores (cells × outcomes).

    Returns
    -------
    torch.Tensor
        Scalar tensor representing the mean negative log-likelihood.
    """
    n_features = X.shape[1]
    n_outcomes = y.shape[1]

    a = torch.exp(log_a.reshape(n_features, n_outcomes))
    beta = torch.exp(log_beta.reshape(n_features, n_outcomes))
    loc = loc.reshape(n_features, n_outcomes)
    return shifted_gamma_pdf(y, X @ a, X @ beta, X @ loc)
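
# The optimizer works with flattened, log-scale shape/rate coefficients, so the
# per-cell parameters X @ exp(coef) stay positive whenever the design matrix is
# nonnegative (e.g. intercept plus indicators). A quick differentiability and
# shape check (a sketch with arbitrary sizes):
#
#     X = torch.rand(50, 2)
#     y = torch.rand(50, 3) + 0.1
#     theta = [torch.zeros(6, requires_grad=True) for _ in range(3)]
#     nll = negative_gamma_log_likelihood(theta[0], theta[1], theta[2], X, y)
#     nll.backward()  # scalar loss, gradients flow to all three tensors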


def format_gamma_parameters(
    parameters: dict,
    W_index: list,
    coef_index: list,
) -> dict:
    """
    Format gamma regression parameters as DataFrames for downstream use.

    Parameters
    ----------
    parameters : dict
        Dictionary containing ``"a"``, ``"loc"``, ``"beta"``, and ``"W"`` arrays.
    W_index : list
        Row labels for the PNMF weights (typically gene names).
    coef_index : list
        Row labels for the regression coefficients (typically covariate names).

    Returns
    -------
    dict
        Updated dictionary with DataFrames stored under the original keys.
    """
    parameters["a"] = pd.DataFrame(parameters["a"], index=coef_index)
    parameters["loc"] = pd.DataFrame(parameters["loc"], index=coef_index)
    parameters["beta"] = pd.DataFrame(parameters["beta"], index=coef_index)
    parameters["W"] = pd.DataFrame(parameters["W"], index=W_index)
    return parameters


################################################################################
## Associated PNMF Objects
################################################################################


class PositiveNMF(Simulator):
    """
    Positive nonnegative matrix factorization (PNMF) simulator with gamma regression.

    This simulator fits a low-rank positive factorization on log-transformed
    expression data and then models the resulting latent scores with a
    covariate-dependent shifted gamma distribution. Sampling proceeds by
    drawing gamma latent scores and mapping them back to the gene space via
    the learned PNMF weights.

    Parameters
    ----------
    formula : dict or str
        Mean-model formula for the gamma regression. If a string is provided,
        it is interpreted as the mean formula and stored under the key
        ``"mean"``. The formula is evaluated against ``adata.obs`` via
        :func:`formulaic.model_matrix`.
    **kwargs
        Keyword arguments forwarded to :func:`pnmf` (e.g. ``nbase``, ``maxIter``).

    Attributes
    ----------
    formula : dict
        Standardized formula dictionary containing at least the ``"mean"`` key.
    parameters : dict or None
        Fitted parameters after calling :meth:`fit`. Keys include:

        * ``"a"`` (:class:`pandas.DataFrame`): gamma shape regression coefficients.
        * ``"loc"`` (:class:`pandas.DataFrame`): gamma location regression coefficients.
        * ``"beta"`` (:class:`pandas.DataFrame`): gamma rate regression coefficients.
        * ``"W"`` (:class:`pandas.DataFrame`): PNMF weight matrix mapping latent
          scores to genes.
    n_outcomes : int
        Number of cells (outcomes) in the training data.
    columns : pandas.Index
        Column names of the design matrix produced from ``formula["mean"]``.

    Examples
    --------
    Fit a PNMF simulator, inspect fitted parameters, and generate samples:

    >>> import numpy as np
    >>> import pandas as pd
    >>> from anndata import AnnData
    >>> from scdesigner.simulators import PositiveNMF
    >>>
    >>> rng = np.random.default_rng(0)
    >>> X = rng.poisson(lam=2.0, size=(50, 20)).astype(float)  # (cells × genes)
    >>> obs = pd.DataFrame({"condition": rng.choice(["A", "B"], size=50)})
    >>> adata = AnnData(X=X, obs=obs)
    >>> adata.var_names = [f"g{i}" for i in range(adata.n_vars)]
    >>>
    >>> sim = PositiveNMF("~ 1 + condition", nbase=3, maxIter=50)
    >>> sim.fit(adata, lr=0.1)
    >>> isinstance(sim.parameters, dict)
    True
    >>>
    >>> # Predict gamma parameters for new observations
    >>> new_obs = pd.DataFrame({"condition": ["A", "B", "A"]})
    >>> pred = sim.predict(new_obs)
    >>> sorted(pred.keys())
    ['a', 'beta', 'loc']
    >>>
    >>> # Sample a new dataset with the same genes
    >>> adata_sim = sim.sample(new_obs)
    >>> adata_sim.n_obs == 3 and adata_sim.n_vars == adata.n_vars
    True
    """
    def __init__(self, formula: Union[Dict, str], **kwargs):
        """
        Parameters
        ----------
        formula : dict or str
            Formula describing the mean model for the gamma regression.
        **kwargs
            Keyword arguments passed through to :func:`pnmf`.
        """
        self.formula = standardize_formula(formula, allowed_keys=["mean"])
        self.parameters = None
        self._hyperparams = kwargs

    def setup_data(self, adata: AnnData, **kwargs):
        self.log_data = np.log1p(adata.X).T  # (genes × cells)
        self.n_outcomes = self.log_data.shape[1]
        self._template = adata
        self.x = model_matrix(self.formula["mean"], adata.obs)
        self.columns = self.x.columns
        self.x = np.asarray(self.x)

    def fit(self, adata: AnnData, lr: float = 0.1):
        """
        Fit the PNMF weights and gamma regression on the provided AnnData.

        Parameters
        ----------
        adata : AnnData
            Dataset used to estimate PNMF weights and gamma coefficients.
        lr : float, optional
            Learning rate for the gamma regression solver.
        """
        self.setup_data(adata)
        W, S = pnmf(self.log_data, **self._hyperparams)
        parameters = gamma_regression_array(self.x, S.T, lr)
        parameters["W"] = W
        self.parameters = format_gamma_parameters(
            parameters, list(self._template.var_names), list(self.columns)
        )

    def predict(self, obs=None, **kwargs):
        """
        Predict gamma regression parameters for new observations.

        Parameters
        ----------
        obs : pandas.DataFrame, optional
            Observation metadata used to construct the design matrix. Defaults
            to the training ``AnnData`` observations.

        Returns
        -------
        dict[str, np.ndarray]
            Dictionary with ``"a"``, ``"loc"``, and ``"beta"`` arrays for each
            observation and latent score.
        """
        if obs is None:
            obs = self._template.obs

        x = model_matrix(self.formula["mean"], obs)
        a, loc, beta = (
            x @ np.exp(self.parameters["a"]),
            x @ self.parameters["loc"],
            x @ np.exp(self.parameters["beta"]),
        )
        return {"a": a, "loc": loc, "beta": beta}

    def sample(self, obs=None):
        """
        Generate expression samples from the fitted model.

        Parameters
        ----------
        obs : pandas.DataFrame, optional
            Metadata for the observations to simulate. Defaults to the training
            ``AnnData`` annotations.

        Returns
        -------
        AnnData
            Simulated :class:`AnnData` matrix containing generated expression
            counts with the original feature ordering.
        """
        if obs is None:
            obs = self._template.obs
        W = self.parameters["W"]
        parameters = self.predict(obs)
        a, loc, beta = parameters["a"], parameters["loc"], parameters["beta"]
        sim_score = gamma(a, loc, 1 / beta).rvs()
        samples = np.exp(W @ sim_score.T).T

        # thresholding samples
        floor = np.floor(samples)
        samples = floor + np.where(samples - floor < 0.9, 0, 1) - 1
        samples = np.where(samples < 0, 0, samples)

        result = AnnData(X=samples, obs=obs)
        result.var_names = self._template.var_names
        return result