scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84) hide show
  1. scdesigner/base/__init__.py +8 -0
  2. scdesigner/base/copula.py +416 -0
  3. scdesigner/base/marginal.py +391 -0
  4. scdesigner/base/simulator.py +59 -0
  5. scdesigner/copulas/__init__.py +8 -0
  6. scdesigner/copulas/standard_copula.py +645 -0
  7. scdesigner/datasets/__init__.py +5 -0
  8. scdesigner/datasets/pancreas.py +39 -0
  9. scdesigner/distributions/__init__.py +19 -0
  10. scdesigner/{minimal → distributions}/bernoulli.py +42 -14
  11. scdesigner/distributions/gaussian.py +114 -0
  12. scdesigner/distributions/negbin.py +121 -0
  13. scdesigner/distributions/negbin_irls.py +72 -0
  14. scdesigner/distributions/negbin_irls_funs.py +456 -0
  15. scdesigner/distributions/poisson.py +88 -0
  16. scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
  17. scdesigner/distributions/zero_inflated_poisson.py +103 -0
  18. scdesigner/simulators/__init__.py +24 -28
  19. scdesigner/simulators/composite.py +239 -0
  20. scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
  21. scdesigner/simulators/scd3.py +486 -0
  22. scdesigner/transform/__init__.py +8 -6
  23. scdesigner/{minimal → transform}/transform.py +1 -1
  24. scdesigner/{minimal → utils}/kwargs.py +4 -1
  25. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
  26. scdesigner-0.0.10.dist-info/RECORD +28 -0
  27. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
  28. scdesigner/data/__init__.py +0 -16
  29. scdesigner/data/formula.py +0 -137
  30. scdesigner/data/group.py +0 -123
  31. scdesigner/data/sparse.py +0 -39
  32. scdesigner/diagnose/__init__.py +0 -65
  33. scdesigner/diagnose/aic_bic.py +0 -119
  34. scdesigner/diagnose/plot.py +0 -242
  35. scdesigner/estimators/__init__.py +0 -32
  36. scdesigner/estimators/bernoulli.py +0 -85
  37. scdesigner/estimators/gaussian.py +0 -121
  38. scdesigner/estimators/gaussian_copula_factory.py +0 -367
  39. scdesigner/estimators/glm_factory.py +0 -75
  40. scdesigner/estimators/negbin.py +0 -153
  41. scdesigner/estimators/pnmf.py +0 -160
  42. scdesigner/estimators/poisson.py +0 -124
  43. scdesigner/estimators/zero_inflated_negbin.py +0 -195
  44. scdesigner/estimators/zero_inflated_poisson.py +0 -85
  45. scdesigner/format/__init__.py +0 -4
  46. scdesigner/format/format.py +0 -20
  47. scdesigner/format/print.py +0 -30
  48. scdesigner/minimal/__init__.py +0 -17
  49. scdesigner/minimal/composite.py +0 -119
  50. scdesigner/minimal/copula.py +0 -205
  51. scdesigner/minimal/formula.py +0 -23
  52. scdesigner/minimal/gaussian.py +0 -65
  53. scdesigner/minimal/loader.py +0 -211
  54. scdesigner/minimal/marginal.py +0 -154
  55. scdesigner/minimal/negbin.py +0 -73
  56. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
  57. scdesigner/minimal/scd3.py +0 -96
  58. scdesigner/minimal/scd3_instances.py +0 -50
  59. scdesigner/minimal/simulator.py +0 -25
  60. scdesigner/minimal/standard_copula.py +0 -383
  61. scdesigner/predictors/__init__.py +0 -15
  62. scdesigner/predictors/bernoulli.py +0 -9
  63. scdesigner/predictors/gaussian.py +0 -16
  64. scdesigner/predictors/negbin.py +0 -17
  65. scdesigner/predictors/poisson.py +0 -12
  66. scdesigner/predictors/zero_inflated_negbin.py +0 -18
  67. scdesigner/predictors/zero_inflated_poisson.py +0 -18
  68. scdesigner/samplers/__init__.py +0 -23
  69. scdesigner/samplers/bernoulli.py +0 -27
  70. scdesigner/samplers/gaussian.py +0 -25
  71. scdesigner/samplers/glm_factory.py +0 -103
  72. scdesigner/samplers/negbin.py +0 -25
  73. scdesigner/samplers/poisson.py +0 -25
  74. scdesigner/samplers/zero_inflated_negbin.py +0 -40
  75. scdesigner/samplers/zero_inflated_poisson.py +0 -16
  76. scdesigner/simulators/composite_regressor.py +0 -72
  77. scdesigner/simulators/glm_simulator.py +0 -167
  78. scdesigner/simulators/pnmf_regression.py +0 -61
  79. scdesigner/transform/amplify.py +0 -14
  80. scdesigner/transform/mask.py +0 -33
  81. scdesigner/transform/nullify.py +0 -25
  82. scdesigner/transform/split.py +0 -23
  83. scdesigner/transform/substitute.py +0 -14
  84. scdesigner-0.0.5.dist-info/RECORD +0 -66
@@ -0,0 +1,477 @@
1
+ from ..data.formula import standardize_formula
2
+ from ..data.loader import _to_numpy
3
+ from ..base.simulator import Simulator
4
+ from anndata import AnnData
5
+ from formulaic import model_matrix
6
+ from scipy.stats import gamma
7
+ from typing import Union, Dict
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+
12
+ ################################################################################
13
+ ## Functions for estimating PNMF regression
14
+ ################################################################################
15
+
16
# Computes the PNMF weight matrix and the per-cell scores; `nbase` sets the
# number of latent bases.
def pnmf(log_data, nbase=3, **kwargs):
    """
    Factorize log-transformed expression data with PNMF.

    Parameters
    ----------
    log_data : np.ndarray
        Log-transformed expression matrix (genes × cells).
    nbase : int, optional
        Number of latent PNMF bases to extract.
    **kwargs
        Extra options passed straight through to :func:`pnmf_eucdist`.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        ``(W, S)`` where ``W`` is the weight matrix rescaled by its spectral
        norm and ``S = W.T @ log_data`` holds the scores of every cell on
        each basis.
    """
    # Start the multiplicative updates from the (absolute) leading singular vectors.
    initial_basis = left_singular(log_data, nbase)
    weights = pnmf_eucdist(log_data, initial_basis, **kwargs)

    # ord=2 on a matrix is the spectral norm (largest singular value).
    spectral_norm = np.linalg.norm(weights, ord=2)
    weights = weights / spectral_norm

    scores = weights.T @ log_data
    return weights, scores
41
+
42
+
43
def gamma_regression_array(
    x: np.array, y: np.array, lr: float = 0.1, epochs: int = 40
) -> dict:
    """
    Estimate shifted-gamma regression coefficients for all outcomes at once.

    The coefficients are optimized with Adam against
    :func:`negative_gamma_log_likelihood`. That function exponentiates the
    ``a`` and ``beta`` coefficients, so the values returned here for those two
    keys live on the log scale.

    Parameters
    ----------
    x : np.ndarray
        Design matrix for covariates (cells × covariates).
    y : np.ndarray
        Target matrix (cells × latent features) derived from PNMF scores.
    lr : float, optional
        Learning rate for the Adam optimizer.
    epochs : int, optional
        Number of full-batch optimization steps.

    Returns
    -------
    dict
        Dictionary with keys ``"a"``, ``"loc"`` and ``"beta"``; each value is
        reshaped to (covariates, outcomes).
    """
    design = torch.tensor(x, dtype=torch.float32)
    target = torch.tensor(y, dtype=torch.float32)

    n_features = design.shape[1]
    n_outcomes = target.shape[1]
    n_coef = n_features * n_outcomes

    # Flat, zero-initialized coefficient vectors; they are reshaped on use.
    a = torch.zeros(n_coef, requires_grad=True)
    loc = torch.zeros(n_coef, requires_grad=True)
    beta = torch.zeros(n_coef, requires_grad=True)
    optimizer = torch.optim.Adam([a, loc, beta], lr=lr)

    for _ in range(epochs):
        optimizer.zero_grad()
        nll = negative_gamma_log_likelihood(a, beta, loc, design, target)
        nll.backward()
        optimizer.step()

    a, loc, beta = _to_numpy(a, loc, beta)
    return {
        "a": a.reshape(n_features, n_outcomes),
        "loc": loc.reshape(n_features, n_outcomes),
        "beta": beta.reshape(n_features, n_outcomes),
    }
86
+
87
+
88
def class_generator(score, n_clusters=3):
    """
    Group cells into discrete classes by running k-means on their PNMF scores.
    (This function is not used in the current implementation.)

    Parameters
    ----------
    score : np.ndarray
        PNMF scores (latent factors) of shape (n_features, n_cells).
    n_clusters : int, optional
        Number of clusters to form.

    Returns
    -------
    np.ndarray
        Array of cluster labels of length ``n_cells``.
    """
    # Imported lazily so scikit-learn is only needed when this helper is called.
    from sklearn.cluster import KMeans

    # Fixed seed keeps the labelling reproducible; cells are rows after the transpose.
    model = KMeans(n_clusters, random_state=0)
    return model.fit(score.T).labels_
109
+
110
+
111
+ ###############################################################################
112
+ ## Helpers for deriving PNMF
113
+ ###############################################################################
114
+
115
+
116
def pnmf_eucdist(X, W_init, maxIter=500, threshold=1e-4, tol=1e-10, verbose=False, **kwargs):
    """
    Optimize PNMF weights via Euclidean distance minimization.

    Runs multiplicative updates on ``W``, renormalizing by the Frobenius norm
    and zeroing tiny entries after every step, until the relative change in
    ``W`` drops below ``threshold`` or ``maxIter`` iterations have run.

    Parameters
    ----------
    X : np.ndarray
        Input expression matrix (genes × cells).
    W_init : np.ndarray
        Initial estimate of the weight matrix. It is not modified in place.
    maxIter : int, optional
        Maximum number of iterations.
    threshold : float, optional
        Convergence threshold on relative weight change.
    tol : float, optional
        Lower bound applied to the update denominator and cut-off below which
        entries of ``W`` are set to zero.
    verbose : bool, optional
        If True, print progress every 10 iterations.
    **kwargs
        Accepted and ignored; reserved for future options.

    Returns
    -------
    np.ndarray
        Weight matrix after the final update. If an update drives every entry
        to zero (for example when ``X`` is all zeros) the all-zero matrix is
        returned instead of a matrix of NaNs.
    """
    W = W_init  # callers initialize this from the singular vectors of X
    XX = X @ X.T  # loop-invariant Gram matrix

    for n_iter in range(1, maxIter + 1):
        if verbose and n_iter % 10 == 0:
            print("%d iterations used." % n_iter)
        W_old = W

        XXW = XX @ W
        SclFactor = np.dot(W, W.T @ XXW) + np.dot(XXW, W.T @ W)

        # Bound the denominator away from zero before dividing.
        SclFactor = MatFindlb(SclFactor, tol)

        # Must be out-of-place: W_old (and W_init on the first pass) alias W,
        # so `W *= ...` would overwrite the previous iterate and the caller's array.
        W = W * (XXW / SclFactor)

        norm_W = np.linalg.norm(W)
        if norm_W == 0:
            # Degenerate update: dividing by a zero norm would fill W with NaN
            # and the convergence test could never succeed. Stop here.
            break
        W = MatFind(W / norm_W, tol)

        diffW = np.linalg.norm(W_old - W) / np.linalg.norm(W_old)
        if diffW < threshold:
            break

    return W
169
+
170
+
171
# Leading left singular vectors of X, used to seed the PNMF iterations.
def left_singular(X, k):
    """
    Return the element-wise absolute value of `k` left singular vectors of `X`
    as computed by :func:`scipy.sparse.linalg.svds`.
    """
    from scipy.sparse.linalg import svds

    # svds returns (U, s, Vt); only U is needed here.
    left_vectors = svds(X, k=k)[0]
    # Singular vectors are sign-ambiguous; the absolute value gives a nonnegative start.
    return np.abs(left_vectors)
179
+
180
+
181
def MatFindlb(A, lb):
    """
    Return a copy of `A` in which every entry smaller than `lb` is replaced by `lb`.

    `A` itself is left untouched.
    """
    too_small = A < lb
    # Constant matrix of the bound, same shape as A.
    bound = lb * np.ones(A.shape)
    return np.where(too_small, bound, A)
188
+
189
+
190
def MatFind(A, ZeroThres):
    """
    Return a copy of `A` in which every entry smaller than `ZeroThres` is set to zero.

    `A` itself is left untouched.
    """
    below = A < ZeroThres
    zeros = np.zeros(A.shape)
    return np.where(below, zeros, A)
197
+
198
+
199
+ ###############################################################################
200
+ ## Helpers for training PNMF regression
201
+ ###############################################################################
202
+
203
+
204
def shifted_gamma_pdf(x, alpha, beta, loc):
    """
    Compute the penalized mean negative log-density of a shifted gamma.

    Observations below their location parameter lie outside the support of
    the distribution. They are left out of the average and instead add a
    penalty of ``1e10`` each, so parameter settings that put data outside the
    support score very badly.

    Parameters
    ----------
    x : torch.Tensor or array-like
        Observed values.
    alpha : torch.Tensor
        Shape parameters.
    beta : torch.Tensor
        Rate parameters.
    loc : torch.Tensor
        Location parameters (shift).

    Returns
    -------
    torch.Tensor
        Scalar tensor: the mean negative log-density over the observations
        with ``x >= loc``, plus ``1e10`` for every observation with
        ``x < loc``. If no observation is inside the support the mean term is
        zero, so the result is the finite penalty rather than NaN.
    """
    if not torch.is_tensor(x):
        x = torch.tensor(x)
    mask = x < loc  # True where the observation is outside the support
    # Clamp so the log below stays finite for entries at or below loc.
    y_clamped = torch.clamp(x - loc, min=1e-12)

    log_pdf = (
        alpha * torch.log(beta)
        - torch.lgamma(alpha)
        + (alpha - 1) * torch.log(y_clamped)
        - beta * y_clamped
    )

    # Sum / count instead of torch.mean: the mean of an empty selection is NaN,
    # whereas an empty sum is 0 and still keeps the autograd graph connected.
    valid = log_pdf[~mask]
    loss = -valid.sum() / max(valid.numel(), 1)

    n_invalid = mask.sum()
    if n_invalid > 0:  # force samples to be greater than loc
        loss = loss + 1e10 * n_invalid.float()
    return loss
242
+
243
+
244
def negative_gamma_log_likelihood(log_a, log_beta, loc, X, y):
    """
    Evaluate the shifted-gamma regression loss for a set of coefficients.

    The flat coefficient vectors are reshaped to (covariates × outcomes). The
    shape and rate coefficients are exponentiated before being multiplied by
    the design matrix; the location coefficients are used as they are. The
    resulting per-observation parameters are scored with
    :func:`shifted_gamma_pdf`.

    Parameters
    ----------
    log_a : torch.Tensor
        Log-scale shape coefficients with covariates × outcomes elements.
    log_beta : torch.Tensor
        Log-scale rate coefficients with covariates × outcomes elements.
    loc : torch.Tensor
        Location coefficients with covariates × outcomes elements.
    X : torch.Tensor
        Design matrix (cells × covariates).
    y : torch.Tensor
        Observed PNMF scores (cells × outcomes).

    Returns
    -------
    torch.Tensor
        Scalar loss returned by :func:`shifted_gamma_pdf`.
    """
    coef_shape = (X.shape[1], y.shape[1])

    shape_coef = torch.exp(log_a.reshape(coef_shape))
    rate_coef = torch.exp(log_beta.reshape(coef_shape))
    loc_coef = loc.reshape(coef_shape)

    alpha_hat = X @ shape_coef
    beta_hat = X @ rate_coef
    loc_hat = X @ loc_coef
    return shifted_gamma_pdf(y, alpha_hat, beta_hat, loc_hat)
273
+
274
def format_gamma_parameters(
    parameters: dict,
    W_index: list,
    coef_index: list,
) -> dict:
    """
    Wrap the fitted arrays in labelled DataFrames.

    The dictionary is modified in place and also returned.

    Parameters
    ----------
    parameters : dict
        Dictionary containing ``"a"``, ``"loc"``, ``"beta"``, and ``"W"`` arrays.
    W_index : list
        Row labels for the PNMF weights (typically gene names).
    coef_index : list
        Row labels for the regression coefficients (typically covariate names).

    Returns
    -------
    dict
        The same dictionary, with each of the four entries replaced by a
        :class:`pandas.DataFrame`.
    """
    # The three regression tables share the covariate labels as their row index.
    for key in ("a", "loc", "beta"):
        parameters[key] = pd.DataFrame(parameters[key], index=coef_index)

    # The weight matrix is labelled separately.
    parameters["W"] = pd.DataFrame(parameters["W"], index=W_index)
    return parameters
301
+
302
+
303
+ ################################################################################
304
+ ## Associated PNMF Objects
305
+ ################################################################################
306
+
307
class PositiveNMF(Simulator):
    """
    Positive nonnegative matrix factorization (PNMF) simulator with gamma regression.

    This simulator fits a low-rank positive factorization on log-transformed
    expression data and then models the resulting latent scores using a
    covariate-dependent shifted gamma distribution. Sampling proceeds by
    drawing gamma latent scores and mapping them back to the gene space via the
    learned PNMF weights.

    Parameters
    ----------
    formula : dict or str
        Mean-model formula for the gamma regression. If a string is provided,
        it is interpreted as the mean formula and stored under the key
        ``"mean"``. The formula is evaluated against ``adata.obs`` via
        :func:`formulaic.model_matrix`.
    **kwargs
        Keyword arguments forwarded to :func:`pnmf` (e.g. ``nbase``, ``maxIter``).

    Attributes
    ----------
    formula : dict
        Standardized formula dictionary containing at least the ``"mean"`` key.
    parameters : dict or None
        Fitted parameters after calling :meth:`fit`. Keys include:

        * ``"a"`` (:class:`pandas.DataFrame`): gamma shape regression coefficients.
        * ``"loc"`` (:class:`pandas.DataFrame`): gamma location regression coefficients.
        * ``"beta"`` (:class:`pandas.DataFrame`): gamma rate regression coefficients.
        * ``"W"`` (:class:`pandas.DataFrame`): PNMF weight matrix mapping latent
          scores to genes.

        The ``"a"`` and ``"beta"`` tables are exponentiated before use in
        :meth:`predict`.
    n_outcomes : int
        Number of cells in the training data.
    columns : pandas.Index
        Column names of the design matrix produced from ``formula["mean"]``.

    Examples
    --------
    Fit a PNMF simulator, inspect fitted parameters, and generate samples:

    >>> import numpy as np
    >>> import pandas as pd
    >>> from anndata import AnnData
    >>> from scdesigner.simulators import PositiveNMF
    >>>
    >>> rng = np.random.default_rng(0)
    >>> X = rng.poisson(lam=2.0, size=(50, 20)).astype(float)  # (cells × genes)
    >>> obs = pd.DataFrame({"condition": rng.choice(["A", "B"], size=50)})
    >>> adata = AnnData(X=X, obs=obs)
    >>> adata.var_names = [f"g{i}" for i in range(adata.n_vars)]
    >>>
    >>> sim = PositiveNMF("~ 1 + condition", nbase=3, maxIter=50)
    >>> sim.fit(adata, lr=0.1)
    >>> isinstance(sim.parameters, dict)
    True
    >>>
    >>> # Predict gamma parameters for new observations
    >>> new_obs = pd.DataFrame({"condition": ["A", "B", "A"]})
    >>> pred = sim.predict(new_obs)
    >>> sorted(pred.keys())
    ['a', 'beta', 'loc']
    >>>
    >>> # Sample a new dataset with the same genes
    >>> adata_sim = sim.sample(new_obs)
    >>> adata_sim.n_obs == 3 and adata_sim.n_vars == adata.n_vars
    True
    """

    def __init__(self, formula: Union[Dict, str], **kwargs):
        """
        Parameters
        ----------
        formula : dict or str
            Formula describing the mean model for the gamma regression.
        **kwargs
            Keyword arguments passed through to :func:`pnmf`.
        """
        self.formula = standardize_formula(formula, allowed_keys=['mean'])
        self.parameters = None
        self._hyperparams = kwargs
        # Encoding learned from the training design matrix; set in setup_data.
        self._model_spec = None

    def setup_data(self, adata: AnnData, **kwargs):
        """
        Cache the transformed expression matrix and design matrix for fitting.

        Parameters
        ----------
        adata : AnnData
            Training data; ``adata.X`` supplies the expression values and
            ``adata.obs`` the covariates.
        **kwargs
            Accepted and ignored.
        """
        self.log_data = np.log1p(adata.X).T  # (genes x cells)
        self.n_outcomes = self.log_data.shape[1]
        self._template = adata
        design = model_matrix(self.formula["mean"], adata.obs)
        # Keep the fitted encoding (category levels, column order) so that new
        # observations can later be encoded exactly like the training data.
        self._model_spec = getattr(design, "model_spec", None)
        self.columns = design.columns
        self.x = np.asarray(design)

    def fit(self, adata: AnnData, lr: float = 0.1):
        """
        Fit the PNMF marginals on the provided AnnData.

        Parameters
        ----------
        adata : AnnData
            Dataset used to estimate PNMF weights and gamma coefficients.
        lr : float, optional
            Learning rate for the gamma regression solver.
        """
        self.setup_data(adata)
        W, S = pnmf(self.log_data, **self._hyperparams)
        # S is (bases × cells); the regression expects cells as rows.
        parameters = gamma_regression_array(self.x, S.T, lr)
        parameters["W"] = W
        self.parameters = format_gamma_parameters(
            parameters, list(self._template.var_names), list(self.columns)
        )

    def predict(self, obs=None, **kwargs):
        """
        Predict gamma regression parameters for new observations.

        Parameters
        ----------
        obs : pandas.DataFrame, optional
            Observation metadata used to construct the design matrix. Defaults
            to the training ``AnnData`` observations.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        dict
            Dictionary with ``"a"``, ``"loc"``, and ``"beta"`` entries, each
            the product of the design matrix with the corresponding
            coefficient table (one row per observation, one column per latent
            basis).
        """
        if obs is None:
            obs = self._template.obs

        # Re-evaluating the formula string on new data would derive the
        # categorical encoding from that data alone, so the columns could
        # differ from those the coefficients were fitted on (for example when
        # `obs` contains only one level). Reuse the training encoding instead.
        model_spec = getattr(self, "_model_spec", None)
        if model_spec is not None:
            x = model_spec.get_model_matrix(obs)
        else:
            x = model_matrix(self.formula["mean"], obs)

        a, loc, beta = (
            x @ np.exp(self.parameters["a"]),
            x @ self.parameters["loc"],
            x @ np.exp(self.parameters["beta"]),
        )
        return {"a": a, "loc": loc, "beta": beta}

    def sample(self, obs=None):
        """
        Generate expression samples from the fitted model.

        Parameters
        ----------
        obs : pandas.DataFrame, optional
            Metadata for the observations to simulate. Defaults to the training
            ``AnnData`` annotations.

        Returns
        -------
        AnnData
            Simulated :class:`AnnData` matrix containing generated expression
            counts on the original feature ordering.
        """
        if obs is None:
            obs = self._template.obs
        W = self.parameters["W"]
        parameters = self.predict(obs)
        a, loc, beta = parameters["a"], parameters["loc"], parameters["beta"]
        # scipy parameterizes the gamma by scale, i.e. the reciprocal of the rate.
        sim_score = gamma(a, loc, 1 / beta).rvs()
        samples = np.exp(W @ sim_score.T).T

        # thresholding samples: round up only when the fractional part is at
        # least 0.9, then subtract 1 (exp followed by -1 mirrors the log1p
        # applied in setup_data) and clip negative values to zero.
        floor = np.floor(samples)
        samples = floor + np.where(samples - floor < 0.9, 0, 1) - 1
        samples = np.where(samples < 0, 0, samples)

        result = AnnData(X=samples, obs=obs)
        result.var_names = self._template.var_names
        return result