scdesigner 0.0.5__py3-none-any.whl → 0.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. scdesigner/base/__init__.py +8 -0
  2. scdesigner/base/copula.py +416 -0
  3. scdesigner/base/marginal.py +391 -0
  4. scdesigner/base/simulator.py +59 -0
  5. scdesigner/copulas/__init__.py +8 -0
  6. scdesigner/copulas/standard_copula.py +645 -0
  7. scdesigner/datasets/__init__.py +5 -0
  8. scdesigner/datasets/pancreas.py +39 -0
  9. scdesigner/distributions/__init__.py +19 -0
  10. scdesigner/{minimal → distributions}/bernoulli.py +42 -14
  11. scdesigner/distributions/gaussian.py +114 -0
  12. scdesigner/distributions/negbin.py +121 -0
  13. scdesigner/distributions/negbin_irls.py +72 -0
  14. scdesigner/distributions/negbin_irls_funs.py +456 -0
  15. scdesigner/distributions/poisson.py +88 -0
  16. scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
  17. scdesigner/distributions/zero_inflated_poisson.py +103 -0
  18. scdesigner/simulators/__init__.py +24 -28
  19. scdesigner/simulators/composite.py +239 -0
  20. scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
  21. scdesigner/simulators/scd3.py +486 -0
  22. scdesigner/transform/__init__.py +8 -6
  23. scdesigner/{minimal → transform}/transform.py +1 -1
  24. scdesigner/{minimal → utils}/kwargs.py +4 -1
  25. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
  26. scdesigner-0.0.10.dist-info/RECORD +28 -0
  27. {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
  28. scdesigner/data/__init__.py +0 -16
  29. scdesigner/data/formula.py +0 -137
  30. scdesigner/data/group.py +0 -123
  31. scdesigner/data/sparse.py +0 -39
  32. scdesigner/diagnose/__init__.py +0 -65
  33. scdesigner/diagnose/aic_bic.py +0 -119
  34. scdesigner/diagnose/plot.py +0 -242
  35. scdesigner/estimators/__init__.py +0 -32
  36. scdesigner/estimators/bernoulli.py +0 -85
  37. scdesigner/estimators/gaussian.py +0 -121
  38. scdesigner/estimators/gaussian_copula_factory.py +0 -367
  39. scdesigner/estimators/glm_factory.py +0 -75
  40. scdesigner/estimators/negbin.py +0 -153
  41. scdesigner/estimators/pnmf.py +0 -160
  42. scdesigner/estimators/poisson.py +0 -124
  43. scdesigner/estimators/zero_inflated_negbin.py +0 -195
  44. scdesigner/estimators/zero_inflated_poisson.py +0 -85
  45. scdesigner/format/__init__.py +0 -4
  46. scdesigner/format/format.py +0 -20
  47. scdesigner/format/print.py +0 -30
  48. scdesigner/minimal/__init__.py +0 -17
  49. scdesigner/minimal/composite.py +0 -119
  50. scdesigner/minimal/copula.py +0 -205
  51. scdesigner/minimal/formula.py +0 -23
  52. scdesigner/minimal/gaussian.py +0 -65
  53. scdesigner/minimal/loader.py +0 -211
  54. scdesigner/minimal/marginal.py +0 -154
  55. scdesigner/minimal/negbin.py +0 -73
  56. scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
  57. scdesigner/minimal/scd3.py +0 -96
  58. scdesigner/minimal/scd3_instances.py +0 -50
  59. scdesigner/minimal/simulator.py +0 -25
  60. scdesigner/minimal/standard_copula.py +0 -383
  61. scdesigner/predictors/__init__.py +0 -15
  62. scdesigner/predictors/bernoulli.py +0 -9
  63. scdesigner/predictors/gaussian.py +0 -16
  64. scdesigner/predictors/negbin.py +0 -17
  65. scdesigner/predictors/poisson.py +0 -12
  66. scdesigner/predictors/zero_inflated_negbin.py +0 -18
  67. scdesigner/predictors/zero_inflated_poisson.py +0 -18
  68. scdesigner/samplers/__init__.py +0 -23
  69. scdesigner/samplers/bernoulli.py +0 -27
  70. scdesigner/samplers/gaussian.py +0 -25
  71. scdesigner/samplers/glm_factory.py +0 -103
  72. scdesigner/samplers/negbin.py +0 -25
  73. scdesigner/samplers/poisson.py +0 -25
  74. scdesigner/samplers/zero_inflated_negbin.py +0 -40
  75. scdesigner/samplers/zero_inflated_poisson.py +0 -16
  76. scdesigner/simulators/composite_regressor.py +0 -72
  77. scdesigner/simulators/glm_simulator.py +0 -167
  78. scdesigner/simulators/pnmf_regression.py +0 -61
  79. scdesigner/transform/amplify.py +0 -14
  80. scdesigner/transform/mask.py +0 -33
  81. scdesigner/transform/nullify.py +0 -25
  82. scdesigner/transform/split.py +0 -23
  83. scdesigner/transform/substitute.py +0 -14
  84. scdesigner-0.0.5.dist-info/RECORD +0 -66
scdesigner/distributions/negbin_irls_funs.py (new file)
@@ -0,0 +1,456 @@
+ import torch
+ from typing import Optional
+ import torch.special as spec
+
+ # ==============================================================================
+ # Weighted Least Squares Solver
+ # ==============================================================================
+
+ def solve_weighted_least_squares(X, weights, responses):
+     """
+     Solve multiple independent weighted least squares problems in parallel.
+
+     For each column j, solves: (X'W_j X)β_j = X'W_j z_j
+     where W_j is a diagonal matrix with weights[:, j] on the diagonal.
+
+     Parameters
+     ----------
+     X : torch.Tensor
+         Design matrix (n × p)
+     weights : torch.Tensor
+         Weight matrix (n × m), one weight vector per response
+     responses : torch.Tensor
+         Working responses (n × m)
+
+     Returns
+     -------
+     torch.Tensor
+         Coefficient matrix (p × m)
+     """
+     # Precompute outer products X_i ⊗ X_i for each observation
+     X_outer = torch.einsum("ni,nj->nij", X, X)  # (n × p × p)
+
+     # Compute weighted normal equations: (X'WX) for all m responses at once
+     eye = torch.eye(X.shape[1]).unsqueeze(0)
+     weighted_XX = torch.einsum("nm,nij->mij", weights, X_outer)  # (m × p × p)
+     weighted_XX = weighted_XX + 1e-5 * eye
+
+     # Compute X'Wz for all responses
+     weighted_Xy = torch.einsum("ni,nm->mi", X, weights * responses)  # (m × p)
+
+     # Solve all systems at once
+     coefficients = torch.linalg.solve(weighted_XX, weighted_Xy.unsqueeze(-1))
+     return coefficients.squeeze(-1).T  # (p × m)
+
+
+ # ==============================================================================
+ # Mean Parameter Updates (Beta)
+ # ==============================================================================
+
+ def update_mean_coefficients(X, counts, beta, dispersion, clip: float = 5.0):
+     """
+     Update mean model coefficients using one Newton-Raphson step.
+
+     Uses IRLS (Iteratively Reweighted Least Squares) with:
+     - Working weights: W = μ/(1 + μ/θ)
+     - Working response: Z = Xβ + (Y - μ)/μ
+
+     Parameters
+     ----------
+     X : torch.Tensor
+         Design matrix (n × p)
+     counts : torch.Tensor
+         Observed counts (n × m)
+     beta : torch.Tensor
+         Current coefficients (p × m)
+     dispersion : torch.Tensor
+         Current dispersion parameters (n × m)
+     clip : float, optional
+         Maximum absolute value for linear predictor, by default 5.0
+
+     Returns
+     -------
+     torch.Tensor
+         Updated coefficients (p × m)
+     """
+     linear_pred = torch.clip(X @ beta, min=-clip, max=clip)
+     mean = torch.exp(linear_pred)
+     weights = mean / (1 + mean / dispersion)
+     working_response = linear_pred + (counts - mean) / mean
+     return solve_weighted_least_squares(X, weights, working_response)
+
+
+ # ==============================================================================
+ # Dispersion Parameter Updates (Gamma)
+ # ==============================================================================
+
+ def update_dispersion_coefficients(Z, counts, mean, gamma, clip: float = 5.0):
+     """
+     Update dispersion model coefficients using one Fisher scoring step.
+
+     Uses working response U = η + θ·s/w where:
+     - η = Zγ (linear predictor)
+     - s = score with respect to θ
+     - w = approximate Fisher information
+
+     Parameters
+     ----------
+     Z : torch.Tensor
+         Dispersion design matrix (n × q)
+     counts : torch.Tensor
+         Observed counts (n × m)
+     mean : torch.Tensor
+         Current mean estimates (n × m)
+     gamma : torch.Tensor
+         Current dispersion coefficients (q × m)
+     clip : float, optional
+         Maximum absolute value for linear predictor, by default 5.0
+
+     Returns
+     -------
+     torch.Tensor
+         Updated dispersion coefficients (q × m)
+     """
+     linear_pred = torch.clip(Z @ gamma, min=-clip, max=clip)
+     dispersion = torch.exp(linear_pred)
+
+     # Score: ∂ℓ/∂θ
+     psi_diff = spec.digamma(counts + dispersion) - spec.digamma(dispersion)
+     score = (psi_diff + torch.log(dispersion) - torch.log(mean + dispersion) +
+              (mean - counts) / (mean + dispersion))
+
+     # Approximate Fisher information (replaces exact Hessian)
+     # Approximation: θY/(θ + Y) ≈ θ²[ψ₁(θ) - ψ₁(Y + θ)]
+     weights = ((dispersion * counts) / (dispersion + counts)).clip(min=1e-6)
+     working_response = linear_pred + (dispersion * score) / weights
+     return solve_weighted_least_squares(Z, weights, working_response)
+
+
+ # ==============================================================================
+ # Initialization
+ # ==============================================================================
+
+ def estimate_constant_dispersion(X, counts, beta):
+     """
+     Estimate constant dispersion for each response using method of moments.
+
+     Uses Pearson residuals: θ̂ = (Σμ) / max(χ² - df, 0.1)
+     where χ² = Σ(Y - μ)²/μ and df = n - p.
+
+     Parameters
+     ----------
+     X : torch.Tensor
+         Design matrix (n × p)
+     counts : torch.Tensor
+         Observed counts (n × m)
+     beta : torch.Tensor
+         Mean coefficients (p × m)
+
+     Returns
+     -------
+     torch.Tensor
+         Dispersion estimates (m,)
+     """
+     mean = torch.exp(X @ beta)
+     pearson_chi2 = torch.sum((counts - mean)**2 / mean, dim=0)
+     sum_mean = torch.sum(mean, dim=0)
+
+     degrees_freedom = counts.shape[0] - X.shape[1]
+     dispersion = sum_mean / torch.clip(pearson_chi2 - degrees_freedom, min=0.1)
+     return torch.clip(dispersion, min=0.1)
+
+
+ def fit_poisson_initial(X, counts, tol: float = 1e-3, max_iter: int = 100, clip: float = 5.0):
+     """
+     Fit Poisson GLM to initialize mean parameters.
+
+     Parameters
+     ----------
+     X : torch.Tensor
+         Design matrix (n × p)
+     counts : torch.Tensor
+         Observed counts (n × m)
+     tol : float, optional
+         Convergence tolerance, by default 1e-3
+     max_iter : int, optional
+         Maximum iterations, by default 100
+     clip : float, optional
+         Maximum absolute value for linear predictor, by default 5.0
+
+     Returns
+     -------
+     torch.Tensor
+         Initial coefficients (p × m)
+     """
+     n_features, n_responses = X.shape[1], counts.shape[1]
+     beta = torch.zeros((n_features, n_responses))
+
+     for _ in range(max_iter):
+         beta_old = beta.clone()
+         linear_pred = torch.clip(X @ beta, min=-clip, max=clip)
+         mean = torch.exp(linear_pred)
+         working_response = linear_pred + (counts - mean) / mean
+
+         beta = solve_weighted_least_squares(X, mean, working_response)
+         if torch.max(torch.abs(beta - beta_old)) < tol:
+             break
+
+     return beta
+
+
+ def accumulate_poisson_statistics(loader, beta, n_genes, p_mean, clip = 5):
+     """
+     Accumulate weighted normal equations for Poisson IRLS across batches.
+
+     Parameters
+     ----------
+     loader : DataLoader
+         DataLoader yielding (y_batch, x_dict)
+     beta : torch.Tensor
+         Current coefficients (p_mean × n_genes)
+     n_genes : int
+         Number of genes
+     p_mean : int
+         Number of mean predictors
+     clip : float, optional
+         Maximum absolute value for linear predictor, by default 5
+
+     Returns
+     -------
+     weighted_XX : torch.Tensor
+         Accumulated X'WX (n_genes × p_mean × p_mean)
+     weighted_Xy : torch.Tensor
+         Accumulated X'Wz (p_mean × n_genes)
+     """
+     weighted_XX = torch.zeros((n_genes, p_mean, p_mean))
+     weighted_Xy = torch.zeros((p_mean, n_genes))
+
+     for y_batch, x_dict in loader:
+         X = x_dict['mean'].to("cpu")
+
+         linear_pred = torch.clip(X @ beta, min=-clip, max=clip)
+         mean = torch.exp(linear_pred)
+         working_response = linear_pred + (y_batch.to("cpu") - mean) / mean
+
+         X_outer = torch.einsum("ni,nj->nij", X, X)
+         weighted_XX += torch.einsum("nm,nij->mij", mean, X_outer)
+         weighted_Xy += torch.einsum("ni,nm->im", X, mean * working_response)
+
+     return weighted_XX, weighted_Xy
+
+
+ def accumulate_dispersion_statistics(loader, beta, clip = 5):
+     """
+     Accumulate Pearson statistics for method of moments dispersion estimation.
+
+     Parameters
+     ----------
+     loader : DataLoader
+         DataLoader yielding (y_batch, x_dict)
+     beta : torch.Tensor
+         Mean coefficients (p_mean × n_genes)
+     clip : float, optional
+         Maximum absolute value for linear predictor, by default 5
+
+     Returns
+     -------
+     sum_mean : torch.Tensor
+         Total predicted mean (n_genes,)
+     sum_pearson : torch.Tensor
+         Total Pearson chi-squared (n_genes,)
+     n_total : int
+         Total number of observations
+     """
+     sum_mean = torch.zeros(beta.shape[1])
+     sum_pearson = torch.zeros(beta.shape[1])
+     n_total = 0
+
+     for y_batch, x_dict in loader:
+         X = x_dict['mean'].to('cpu')
+         linear_pred = torch.clip(X @ beta, min=-clip, max=clip)
+         mean_batch = torch.exp(linear_pred)
+
+         sum_mean += mean_batch.sum(dim=0)
+         sum_pearson += ((y_batch.to('cpu') - mean_batch)**2 / mean_batch).sum(dim=0)
+         n_total += y_batch.shape[0]
+
+     return sum_mean, sum_pearson, n_total
+
+
+ def initialize_parameters(loader, n_genes, p_mean, p_disp, max_iter = 10,
+                           tol = 1e-3, clip = 5):
+     """
+     Initialize parameters using batched Poisson IRLS followed by MoM dispersion.
+
+     Logic:
+     1. Iteratively fit Poisson GLM by accumulating X'WX and X'WZ across batches
+     2. Use fitted Poisson means to estimate dispersion via Method of Moments
+
+     Parameters
+     ----------
+     loader : DataLoader
+         DataLoader yielding (y_batch, x_dict)
+     n_genes : int
+         Number of response columns (genes)
+     p_mean : int
+         Number of predictors in the mean model
+     p_disp : int
+         Number of predictors in the dispersion model
+     max_iter : int, optional
+         Maximum Poisson IRLS iterations, by default 10
+     tol : float, optional
+         Convergence tolerance for beta coefficients, by default 1e-3
+     clip : float, optional
+         Maximum absolute value for linear predictor, by default 5
+
+     Returns
+     -------
+     beta_init : torch.Tensor
+         (p_mean × n_genes) tensor
+     gamma_init : torch.Tensor
+         (p_disp × n_genes) tensor
+     """
+     beta = torch.zeros((p_mean, n_genes))
+     for _ in range(max_iter):
+         weighted_XX, weighted_Xy = accumulate_poisson_statistics(
+             loader, beta, n_genes, p_mean, clip
+         )
+
+         eye = torch.eye(p_mean).unsqueeze(0)
+         weighted_XX_reg = weighted_XX + 1e-6 * eye
+         beta_new = torch.linalg.solve(
+             weighted_XX_reg, weighted_Xy.T.unsqueeze(-1)
+         ).squeeze(-1).T
+
+         if torch.max(torch.abs(beta_new - beta)) < tol:
+             beta = beta_new
+             break
+         beta = beta_new
+
+     sum_mean, sum_pearson, n_total = accumulate_dispersion_statistics(
+         loader, beta, clip
+     )
+
+     degrees_freedom = n_total - p_mean
+     dispersion = sum_mean / torch.clip(sum_pearson - degrees_freedom, min=0.1)
+
+     gamma = torch.zeros((p_disp, n_genes))
+     gamma[0, :] = torch.log(torch.clip(dispersion, min=0.1))
+     return beta, gamma
+
+
+ # ==============================================================================
+ # Batch Log-Likelihood
+ # ==============================================================================
+
+ def compute_batch_loglikelihood(y, mu, r):
+     """
+     Compute the negative binomial log-likelihood for a batch.
+
+     Formula:
+     ℓ = Σ [log Γ(Y+θ) - log Γ(θ) - log Γ(Y+1) + θ log θ + Y log μ - (Y+θ)log(μ+θ)]
+
+     Parameters
+     ----------
+     y : torch.Tensor
+         Observed counts (n_batch × m_active)
+     mu : torch.Tensor
+         Predicted means (n_batch × m_active)
+     r : torch.Tensor
+         Dispersion parameters (n_batch × m_active)
+
+     Returns
+     -------
+     torch.Tensor
+         Total log-likelihood per response (m_active,)
+     """
+     ll = (
+         torch.lgamma(y + r) - torch.lgamma(r) - torch.lgamma(y + 1.0)
+         + r * torch.log(r) + y * torch.log(mu)
+         - (r + y) * torch.log(r + mu)
+     )
+     return torch.sum(ll, dim=0)
+
+
+ # ==============================================================================
+ # Stochastic IRLS Step
+ # ==============================================================================
+
+ def step_stochastic_irls(
+     y,
+     X,
+     Z,
+     beta,
+     gamma,
+     eta: float = 0.8,
+     tol: float = 1e-4,
+     ll_prev: Optional[torch.Tensor] = None,
+     clip_mean: float = 10.0,
+     clip_disp: float = 10.0
+ ):
+     """
+     Perform a single damped Newton-Raphson update on a minibatch.
+
+     Logic:
+     1. Compute log-likelihood with current coefficients.
+     2. Perform one IRLS step for Mean (Beta) and Dispersion (Gamma).
+     3. Re-compute log-likelihood to determine convergence.
+     4. Return updated coefficients and boolean convergence mask.
+
+     Parameters
+     ----------
+     y : torch.Tensor
+         Count batch (n × m)
+     X : torch.Tensor
+         Mean design matrix (n × p)
+     Z : torch.Tensor
+         Dispersion design matrix (n × q)
+     beta : torch.Tensor
+         Current mean coefficients (p × m)
+     gamma : torch.Tensor
+         Current dispersion coefficients (q × m)
+     eta : float, optional
+         Damping factor (learning rate), 1.0 is a pure Newton step, by default 0.8
+     tol : float, optional
+         Relative log-likelihood change threshold for convergence, by default 1e-4
+     ll_prev : torch.Tensor, optional
+         Previous log-likelihood values, by default None
+     clip_mean : float, optional
+         Maximum absolute value for mean linear predictor, by default 10.0
+     clip_disp : float, optional
+         Maximum absolute value for dispersion linear predictor, by default 10.0
+
+     Returns
+     -------
+     beta_next : torch.Tensor
+         Updated mean coefficients (p × m)
+     gamma_next : torch.Tensor
+         Updated dispersion coefficients (q × m)
+     converged : torch.Tensor
+         Boolean mask of converged responses (m,)
+     ll_next : torch.Tensor
+         Updated log-likelihood values (m,)
+     """
+     # --- 2. Update Mean (Beta) ---
+     # Working weights W = μ/(1 + μ/θ)
+     beta_target = update_mean_coefficients(X, y, beta, torch.exp(Z @ gamma), clip=clip_mean)
+     beta_next = (1 - eta) * beta + eta * beta_target
+
+     # --- 3. Update Dispersion (Gamma) ---
+     # Update depends on the latest mean estimates
+     linear_pred_mu = torch.clip(X @ beta_next, min=-clip_mean, max=clip_mean)
+     mu = torch.exp(linear_pred_mu)
+     gamma_target = update_dispersion_coefficients(Z, y, mu, gamma, clip=clip_disp)
+     gamma_next = (1 - eta) * gamma + eta * gamma_target
+
+     # --- 4. Convergence Check ---
+     linear_pred_r_next = torch.clip(Z @ gamma_next, min=-clip_disp, max=clip_disp)
+     ll_next = compute_batch_loglikelihood(y, mu, torch.exp(linear_pred_r_next))
+
+     # Relative improvement in the objective function
+     if ll_prev is not None:
+         rel_change = torch.abs(ll_next - ll_prev) / (torch.abs(ll_prev) + 1e-10)
+         converged = rel_change <= tol
+     else:
+         converged = False
+     return beta_next, gamma_next, converged, ll_next
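Note: for orientation, here is a minimal sketch of how these helpers appear intended to compose: warm-start beta with a Poisson fit, initialize a constant dispersion by method of moments, then take a few damped IRLS steps. The simulated data, shapes, and the assumption that the functions are importable from scdesigner.distributions.negbin_irls_funs are illustrative only, not taken from the package.

import torch
from scdesigner.distributions.negbin_irls_funs import (
    fit_poisson_initial, estimate_constant_dispersion, step_stochastic_irls,
)

torch.manual_seed(0)
n, p, m = 2000, 2, 3
X = torch.cat([torch.ones(n, 1), torch.randn(n, 1)], dim=1)    # mean design (n × p)
Z = torch.ones(n, 1)                                           # intercept-only dispersion design (n × q)
beta_true = torch.tensor([[1.0, 0.5, 2.0], [0.3, -0.2, 0.1]])  # (p × m)
mu = torch.exp(X @ beta_true)
theta = torch.full((n, m), 5.0)                                # true dispersion
# negative binomial counts via the Gamma-Poisson mixture
y = torch.poisson(torch.distributions.Gamma(theta, theta / mu).sample())

beta = fit_poisson_initial(X, y)                                   # Poisson warm start (p × m)
gamma = torch.log(estimate_constant_dispersion(X, y, beta))[None]  # constant log-dispersion (1 × m)
ll = None
for _ in range(25):
    beta, gamma, converged, ll = step_stochastic_irls(y, X, Z, beta, gamma, ll_prev=ll)
print(torch.max(torch.abs(beta - beta_true)))  # should be small if the steps behave as documented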
scdesigner/distributions/poisson.py (new file)
@@ -0,0 +1,88 @@
+ from ..data.formula import standardize_formula
+ from ..base.marginal import GLMPredictor, Marginal
+ from ..data.loader import _to_numpy
+ from typing import Union, Dict, Optional, Tuple
+ import torch
+ import numpy as np
+ from scipy.stats import poisson
+
+ class Poisson(Marginal):
+     """Poisson marginal estimator
+
+     This subclass behaves like `Marginal` but assumes each gene follows a
+     Poisson distribution with mean `mu_j(x)` that depends on covariates `x`
+     via the provided `formula` object.
+
+     The only allowed formula key is 'mean'; a plain string formula is
+     treated as a single `mean` term.
+
+     Examples
+     --------
+     >>> from scdesigner.distributions import Poisson
+     >>> from scdesigner.datasets import pancreas
+     >>>
+     >>> sim = Poisson(formula="~ bs(pseudotime, df=5)")
+     >>> sim.setup_data(pancreas)
+     >>> sim.fit(max_epochs=1, verbose=False)
+     >>>
+     >>> # evaluate p(y | x) and mu(x)
+     >>> y, x = next(iter(sim.loader))
+     >>> l = sim.likelihood((y, x))
+     >>> y_hat = sim.predict(x)
+     >>>
+     >>> # convert to quantiles and back
+     >>> u = sim.uniformize(y, x)
+     >>> x_star = sim.invert(u, x)
+     """
+     def __init__(self, formula: Union[Dict, str]):
+         formula = standardize_formula(formula, allowed_keys=['mean'])
+         super().__init__(formula)
+
+     def setup_optimizer(
+         self,
+         optimizer_class: Optional[callable] = torch.optim.Adam,
+         **optimizer_kwargs,
+     ):
+         if self.loader is None:
+             raise RuntimeError("self.loader is not set (call setup_data first)")
+
+         def nll(batch):
+             return -self.likelihood(batch).sum()
+
+         self.predict = GLMPredictor(
+             n_outcomes=self.n_outcomes,
+             feature_dims=self.feature_dims,
+             loss_fn=nll,
+             optimizer_class=optimizer_class,
+             optimizer_kwargs=optimizer_kwargs
+         )
+
+     def likelihood(self, batch) -> torch.Tensor:
+         """Compute the log-likelihood"""
+         y, x = batch
+         params = self.predict(x)
+         mu = params.get("mean")
+         return y * torch.log(mu) - mu - torch.lgamma(y + 1)
+
+     def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
+         """Invert pseudoobservations."""
+         mu, u = self._local_params(x, u)
+         y = poisson(mu).ppf(u)
+         return torch.from_numpy(y).float()
+
+     def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
+         """Return uniformized pseudo-observations for counts y given covariates x."""
+         # cdf values using scipy's parameterization
+         mu, y = self._local_params(x, y)
+         u1 = poisson(mu).cdf(y)
+         u2 = np.where(y > 0, poisson(mu).cdf(y - 1), 0)
+         v = np.random.uniform(size=y.shape)
+         u = np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
+         return torch.from_numpy(u).float()
+
+     def _local_params(self, x, y=None) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+         params = self.predict(x)
+         mu = params.get('mean')
+         if y is None:
+             return _to_numpy(mu)
+         return _to_numpy(mu, y)
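Note: the uniformize/invert pair in this file implements the randomized probability integral transform for discrete counts: u is drawn uniformly between the CDF at y-1 and at y, so u is approximately Uniform(0, 1) when the Poisson model is correct. A small standalone numpy/scipy illustration of that identity (the rate and seed are made up for illustration):

import numpy as np
from scipy.stats import poisson

rng = np.random.default_rng(0)
mu = 3.5
y = rng.poisson(mu, size=100_000)

# randomized PIT: interpolate uniformly between F(y - 1) and F(y)
u1 = poisson(mu).cdf(y)
u2 = np.where(y > 0, poisson(mu).cdf(y - 1), 0.0)
v = rng.uniform(size=y.shape)
u = np.clip(v * u1 + (1 - v) * u2, 1e-6, 1 - 1e-6)

# if the model is right, u should look Uniform(0, 1)
print(u.mean(), u.var())  # roughly 0.5 and 1/12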
scdesigner/{minimal → distributions}/zero_inflated_negbin.py
@@ -1,13 +1,41 @@
- from .formula import standardize_formula
- from .marginal import GLMPredictor, Marginal
- from .loader import _to_numpy
- from typing import Union, Dict, Optional
+ from ..data.formula import standardize_formula
+ from ..base.marginal import GLMPredictor, Marginal
+ from ..data.loader import _to_numpy
+ from typing import Union, Dict, Optional, Tuple
  import torch
  import numpy as np
  from scipy.stats import nbinom, bernoulli

  class ZeroInflatedNegBin(Marginal):
-     """Zero-inflated negative-binomial marginal estimator"""
+     """Zero-inflated negative-binomial marginal estimator
+
+     This subclass models a two-part mixture for counts. For each feature
+     j the observation follows a mixture: with probability `pi_j(x)` the value
+     is an extra zero (inflation), otherwise the count is drawn from a
+     negative-binomial distribution NB(mu_j(x), r_j(x)) parameterized here via
+     a mean `mu_j(x)` and dispersion `r_j(x)`. All parameters may depend on
+     covariates `x` through the `formula` argument.
+
+     The allowed formula keys are 'mean', 'dispersion', and 'zero_inflation'.
+
+     Examples
+     --------
+     >>> from scdesigner.distributions import ZeroInflatedNegBin
+     >>> from scdesigner.datasets import pancreas
+     >>>
+     >>> sim = ZeroInflatedNegBin(formula={"mean": "~ pseudotime", "dispersion": "~ 1", "zero_inflation": "~ pseudotime"})
+     >>> sim.setup_data(pancreas)
+     >>> sim.fit(max_epochs=1, verbose=False)
+     >>>
+     >>> # evaluate p(y | x) and model parameters
+     >>> y, x = next(iter(sim.loader))
+     >>> l = sim.likelihood((y, x))
+     >>> y_hat = sim.predict(x)
+     >>>
+     >>> # convert to quantiles and back
+     >>> u = sim.uniformize(y, x)
+     >>> x_star = sim.invert(u, x)
+     """
      def __init__(self, formula: Union[Dict, str]):
          formula = standardize_formula(formula, allowed_keys=['mean', 'dispersion', 'zero_inflation'])
          super().__init__(formula)
@@ -25,7 +53,8 @@ class ZeroInflatedNegBin(Marginal):
              "dispersion": torch.exp,
              "zero_inflation": torch.sigmoid,
          }
-         nll = lambda batch: -self.likelihood(batch).sum()
+         def nll(batch):
+             return -self.likelihood(batch).sum()
          self.predict = GLMPredictor(
              n_outcomes=self.n_outcomes,
              feature_dims=self.feature_dims,
@@ -35,7 +64,7 @@ class ZeroInflatedNegBin(Marginal):
              optimizer_kwargs=optimizer_kwargs
          )

-     def likelihood(self, batch):
+     def likelihood(self, batch) -> torch.Tensor:
          """Compute the negative log-likelihood"""
          y, x = batch
          params = self.predict(x)
@@ -56,14 +85,14 @@ class ZeroInflatedNegBin(Marginal):
          # return the mixture, with an offset to prevent log(0)
          return torch.log(pi * (y == 0) + (1 - pi) * torch.exp(negbin_loglikelihood) + 1e-10)

-     def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+     def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
          """Invert pseudoobservations."""
          mu, r, pi, u = self._local_params(x, u)
          y = nbinom(n=r, p=r / (r + mu)).ppf(u)
          delta = bernoulli(1 - pi).ppf(u)
          return torch.from_numpy(y * delta).float()

-     def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
+     def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
          """Return uniformized pseudo-observations for counts y given covariates x."""
          # cdf values using scipy's parameterization
          mu, r, pi, y = self._local_params(x, y)
@@ -76,7 +105,7 @@ class ZeroInflatedNegBin(Marginal):
          u = np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
          return torch.from_numpy(u).float()

-     def _local_params(self, x, y=None):
+     def _local_params(self, x, y=None) -> Tuple:
          params = self.predict(x)
          mu = params.get('mean')
          r = params.get('dispersion')
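Note: the mixture likelihood in this class is p(y) = pi · 1{y=0} + (1 - pi) · NB(y; mu, r), with the negative binomial mapped to scipy's parameterization as n = r, p = r / (r + mu). A short standalone numpy/scipy sketch of the same density, using made-up parameter values, just to make that mapping explicit:

import numpy as np
from scipy.stats import nbinom

# illustrative parameter values (not from the package)
mu, r, pi = 4.0, 2.0, 0.3
y = np.arange(6)

nb_pmf = nbinom(n=r, p=r / (r + mu)).pmf(y)   # NB(mu, r) in scipy's (n, p) form
zinb_pmf = pi * (y == 0) + (1 - pi) * nb_pmf  # two-part mixture
log_lik = np.log(zinb_pmf + 1e-10)            # same log(0) guard as in the diff
print(zinb_pmf.round(4))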