scdesigner 0.0.5-py3-none-any.whl → 0.0.10-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scdesigner/base/__init__.py +8 -0
- scdesigner/base/copula.py +416 -0
- scdesigner/base/marginal.py +391 -0
- scdesigner/base/simulator.py +59 -0
- scdesigner/copulas/__init__.py +8 -0
- scdesigner/copulas/standard_copula.py +645 -0
- scdesigner/datasets/__init__.py +5 -0
- scdesigner/datasets/pancreas.py +39 -0
- scdesigner/distributions/__init__.py +19 -0
- scdesigner/{minimal → distributions}/bernoulli.py +42 -14
- scdesigner/distributions/gaussian.py +114 -0
- scdesigner/distributions/negbin.py +121 -0
- scdesigner/distributions/negbin_irls.py +72 -0
- scdesigner/distributions/negbin_irls_funs.py +456 -0
- scdesigner/distributions/poisson.py +88 -0
- scdesigner/{minimal → distributions}/zero_inflated_negbin.py +39 -10
- scdesigner/distributions/zero_inflated_poisson.py +103 -0
- scdesigner/simulators/__init__.py +24 -28
- scdesigner/simulators/composite.py +239 -0
- scdesigner/simulators/positive_nonnegative_matrix_factorization.py +477 -0
- scdesigner/simulators/scd3.py +486 -0
- scdesigner/transform/__init__.py +8 -6
- scdesigner/{minimal → transform}/transform.py +1 -1
- scdesigner/{minimal → utils}/kwargs.py +4 -1
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/METADATA +1 -1
- scdesigner-0.0.10.dist-info/RECORD +28 -0
- {scdesigner-0.0.5.dist-info → scdesigner-0.0.10.dist-info}/WHEEL +1 -1
- scdesigner/data/__init__.py +0 -16
- scdesigner/data/formula.py +0 -137
- scdesigner/data/group.py +0 -123
- scdesigner/data/sparse.py +0 -39
- scdesigner/diagnose/__init__.py +0 -65
- scdesigner/diagnose/aic_bic.py +0 -119
- scdesigner/diagnose/plot.py +0 -242
- scdesigner/estimators/__init__.py +0 -32
- scdesigner/estimators/bernoulli.py +0 -85
- scdesigner/estimators/gaussian.py +0 -121
- scdesigner/estimators/gaussian_copula_factory.py +0 -367
- scdesigner/estimators/glm_factory.py +0 -75
- scdesigner/estimators/negbin.py +0 -153
- scdesigner/estimators/pnmf.py +0 -160
- scdesigner/estimators/poisson.py +0 -124
- scdesigner/estimators/zero_inflated_negbin.py +0 -195
- scdesigner/estimators/zero_inflated_poisson.py +0 -85
- scdesigner/format/__init__.py +0 -4
- scdesigner/format/format.py +0 -20
- scdesigner/format/print.py +0 -30
- scdesigner/minimal/__init__.py +0 -17
- scdesigner/minimal/composite.py +0 -119
- scdesigner/minimal/copula.py +0 -205
- scdesigner/minimal/formula.py +0 -23
- scdesigner/minimal/gaussian.py +0 -65
- scdesigner/minimal/loader.py +0 -211
- scdesigner/minimal/marginal.py +0 -154
- scdesigner/minimal/negbin.py +0 -73
- scdesigner/minimal/positive_nonnegative_matrix_factorization.py +0 -231
- scdesigner/minimal/scd3.py +0 -96
- scdesigner/minimal/scd3_instances.py +0 -50
- scdesigner/minimal/simulator.py +0 -25
- scdesigner/minimal/standard_copula.py +0 -383
- scdesigner/predictors/__init__.py +0 -15
- scdesigner/predictors/bernoulli.py +0 -9
- scdesigner/predictors/gaussian.py +0 -16
- scdesigner/predictors/negbin.py +0 -17
- scdesigner/predictors/poisson.py +0 -12
- scdesigner/predictors/zero_inflated_negbin.py +0 -18
- scdesigner/predictors/zero_inflated_poisson.py +0 -18
- scdesigner/samplers/__init__.py +0 -23
- scdesigner/samplers/bernoulli.py +0 -27
- scdesigner/samplers/gaussian.py +0 -25
- scdesigner/samplers/glm_factory.py +0 -103
- scdesigner/samplers/negbin.py +0 -25
- scdesigner/samplers/poisson.py +0 -25
- scdesigner/samplers/zero_inflated_negbin.py +0 -40
- scdesigner/samplers/zero_inflated_poisson.py +0 -16
- scdesigner/simulators/composite_regressor.py +0 -72
- scdesigner/simulators/glm_simulator.py +0 -167
- scdesigner/simulators/pnmf_regression.py +0 -61
- scdesigner/transform/amplify.py +0 -14
- scdesigner/transform/mask.py +0 -33
- scdesigner/transform/nullify.py +0 -25
- scdesigner/transform/split.py +0 -23
- scdesigner/transform/substitute.py +0 -14
- scdesigner-0.0.5.dist-info/RECORD +0 -66
scdesigner/distributions/negbin_irls_funs.py (new file)

@@ -0,0 +1,456 @@
+import torch
+from typing import Optional
+import torch.special as spec
+
+# ==============================================================================
+# Weighted Least Squares Solver
+# ==============================================================================
+
+def solve_weighted_least_squares(X, weights, responses):
+    """
+    Solve multiple independent weighted least squares problems in parallel.
+
+    For each column j, solves: (X'W_j X)β_j = X'W_j z_j
+    where W_j is a diagonal matrix with weights[:, j] on the diagonal.
+
+    Parameters
+    ----------
+    X : torch.Tensor
+        Design matrix (n × p)
+    weights : torch.Tensor
+        Weight matrix (n × m), one weight vector per response
+    responses : torch.Tensor
+        Working responses (n × m)
+
+    Returns
+    -------
+    torch.Tensor
+        Coefficient matrix (p × m)
+    """
+    # Precompute outer products X_i ⊗ X_i for each observation
+    X_outer = torch.einsum("ni,nj->nij", X, X)  # (n × p × p)
+
+    # Compute weighted normal equations: (X'WX) for all m responses at once
+    eye = torch.eye(X.shape[1]).unsqueeze(0)
+    weighted_XX = torch.einsum("nm,nij->mij", weights, X_outer)  # (m × p × p)
+    weighted_XX = weighted_XX + 1e-5 * eye
+
+    # Compute X'Wz for all responses
+    weighted_Xy = torch.einsum("ni,nm->mi", X, weights * responses)  # (m × p)
+
+    # Solve all systems at once
+    coefficients = torch.linalg.solve(weighted_XX, weighted_Xy.unsqueeze(-1))
+    return coefficients.squeeze(-1).T  # (p × m)
+
+
+# ==============================================================================
+# Mean Parameter Updates (Beta)
+# ==============================================================================
+
+def update_mean_coefficients(X, counts, beta, dispersion, clip: float = 5.0):
+    """
+    Update mean model coefficients using one Newton-Raphson step.
+
+    Uses IRLS (Iteratively Reweighted Least Squares) with:
+    - Working weights: W = μ/(1 + μ/θ)
+    - Working response: Z = Xβ + (Y - μ)/μ
+
+    Parameters
+    ----------
+    X : torch.Tensor
+        Design matrix (n × p)
+    counts : torch.Tensor
+        Observed counts (n × m)
+    beta : torch.Tensor
+        Current coefficients (p × m)
+    dispersion : torch.Tensor
+        Current dispersion parameters (n × m)
+    clip : float, optional
+        Maximum absolute value for linear predictor, by default 5.0
+
+    Returns
+    -------
+    torch.Tensor
+        Updated coefficients (p × m)
+    """
+    linear_pred = torch.clip(X @ beta, min=-clip, max=clip)
+    mean = torch.exp(linear_pred)
+    weights = mean / (1 + mean / dispersion)
+    working_response = linear_pred + (counts - mean) / mean
+    return solve_weighted_least_squares(X, weights, working_response)
+
+
+# ==============================================================================
+# Dispersion Parameter Updates (Gamma)
+# ==============================================================================
+
+def update_dispersion_coefficients(Z, counts, mean, gamma, clip: float = 5.0):
+    """
+    Update dispersion model coefficients using one Fisher scoring step.
+
+    Uses working response U = η + θ·s/w where:
+    - η = Zγ (linear predictor)
+    - s = score with respect to θ
+    - w = approximate Fisher information
+
+    Parameters
+    ----------
+    Z : torch.Tensor
+        Dispersion design matrix (n × q)
+    counts : torch.Tensor
+        Observed counts (n × m)
+    mean : torch.Tensor
+        Current mean estimates (n × m)
+    gamma : torch.Tensor
+        Current dispersion coefficients (q × m)
+    clip : float, optional
+        Maximum absolute value for linear predictor, by default 5.0
+
+    Returns
+    -------
+    torch.Tensor
+        Updated dispersion coefficients (q × m)
+    """
+    linear_pred = torch.clip(Z @ gamma, min=-clip, max=clip)
+    dispersion = torch.exp(linear_pred)
+
+    # Score: ∂ℓ/∂θ
+    psi_diff = spec.digamma(counts + dispersion) - spec.digamma(dispersion)
+    score = (psi_diff + torch.log(dispersion) - torch.log(mean + dispersion) +
+             (mean - counts) / (mean + dispersion))
+
+    # Approximate Fisher information (replaces exact Hessian)
+    # Approximation: θY/(θ + Y) ≈ θ²[ψ₁(θ) - ψ₁(Y + θ)]
+    weights = ((dispersion * counts) / (dispersion + counts)).clip(min=1e-6)
+    working_response = linear_pred + (dispersion * score) / weights
+    return solve_weighted_least_squares(Z, weights, working_response)
+
+
+# ==============================================================================
+# Initialization
+# ==============================================================================
+
+def estimate_constant_dispersion(X, counts, beta):
+    """
+    Estimate constant dispersion for each response using method of moments.
+
+    Uses Pearson residuals: θ̂ = (Σμ) / max(χ² - df, 0.1)
+    where χ² = Σ(Y - μ)²/μ and df = n - p.
+
+    Parameters
+    ----------
+    X : torch.Tensor
+        Design matrix (n × p)
+    counts : torch.Tensor
+        Observed counts (n × m)
+    beta : torch.Tensor
+        Mean coefficients (p × m)
+
+    Returns
+    -------
+    torch.Tensor
+        Dispersion estimates (m,)
+    """
+    mean = torch.exp(X @ beta)
+    pearson_chi2 = torch.sum((counts - mean)**2 / mean, dim=0)
+    sum_mean = torch.sum(mean, dim=0)
+
+    degrees_freedom = counts.shape[0] - X.shape[1]
+    dispersion = sum_mean / torch.clip(pearson_chi2 - degrees_freedom, min=0.1)
+    return torch.clip(dispersion, min=0.1)
+
+
+def fit_poisson_initial(X, counts, tol: float = 1e-3, max_iter: int = 100, clip: float = 5.0):
+    """
+    Fit Poisson GLM to initialize mean parameters.
+
+    Parameters
+    ----------
+    X : torch.Tensor
+        Design matrix (n × p)
+    counts : torch.Tensor
+        Observed counts (n × m)
+    tol : float, optional
+        Convergence tolerance, by default 1e-3
+    max_iter : int, optional
+        Maximum iterations, by default 100
+    clip : float, optional
+        Maximum absolute value for linear predictor, by default 5.0
+
+    Returns
+    -------
+    torch.Tensor
+        Initial coefficients (p × m)
+    """
+    n_features, n_responses = X.shape[1], counts.shape[1]
+    beta = torch.zeros((n_features, n_responses))
+
+    for _ in range(max_iter):
+        beta_old = beta.clone()
+        linear_pred = torch.clip(X @ beta, min=-clip, max=clip)
+        mean = torch.exp(linear_pred)
+        working_response = linear_pred + (counts - mean) / mean
+
+        beta = solve_weighted_least_squares(X, mean, working_response)
+        if torch.max(torch.abs(beta - beta_old)) < tol:
+            break
+
+    return beta
+
+
+def accumulate_poisson_statistics(loader, beta, n_genes, p_mean, clip = 5):
+    """
+    Accumulate weighted normal equations for Poisson IRLS across batches.
+
+    Parameters
+    ----------
+    loader : DataLoader
+        DataLoader yielding (y_batch, x_dict)
+    beta : torch.Tensor
+        Current coefficients (p_mean × n_genes)
+    n_genes : int
+        Number of genes
+    p_mean : int
+        Number of mean predictors
+    clip : float, optional
+        Maximum absolute value for linear predictor, by default 5
+
+    Returns
+    -------
+    weighted_XX : torch.Tensor
+        Accumulated X'WX (n_genes × p_mean × p_mean)
+    weighted_Xy : torch.Tensor
+        Accumulated X'Wz (p_mean × n_genes)
+    """
+    weighted_XX = torch.zeros((n_genes, p_mean, p_mean))
+    weighted_Xy = torch.zeros((p_mean, n_genes))
+
+    for y_batch, x_dict in loader:
+        X = x_dict['mean'].to("cpu")
+
+        linear_pred = torch.clip(X @ beta, min=-clip, max=clip)
+        mean = torch.exp(linear_pred)
+        working_response = linear_pred + (y_batch.to("cpu") - mean) / mean
+
+        X_outer = torch.einsum("ni,nj->nij", X, X)
+        weighted_XX += torch.einsum("nm,nij->mij", mean, X_outer)
+        weighted_Xy += torch.einsum("ni,nm->im", X, mean * working_response)
+
+    return weighted_XX, weighted_Xy
+
+
+def accumulate_dispersion_statistics(loader, beta, clip = 5):
+    """
+    Accumulate Pearson statistics for method of moments dispersion estimation.
+
+    Parameters
+    ----------
+    loader : DataLoader
+        DataLoader yielding (y_batch, x_dict)
+    beta : torch.Tensor
+        Mean coefficients (p_mean × n_genes)
+    clip : float, optional
+        Maximum absolute value for linear predictor, by default 5
+
+    Returns
+    -------
+    sum_mean : torch.Tensor
+        Total predicted mean (n_genes,)
+    sum_pearson : torch.Tensor
+        Total Pearson chi-squared (n_genes,)
+    n_total : int
+        Total number of observations
+    """
+    sum_mean = torch.zeros(beta.shape[1])
+    sum_pearson = torch.zeros(beta.shape[1])
+    n_total = 0
+
+    for y_batch, x_dict in loader:
+        X = x_dict['mean'].to('cpu')
+        linear_pred = torch.clip(X @ beta, min=-clip, max=clip)
+        mean_batch = torch.exp(linear_pred)
+
+        sum_mean += mean_batch.sum(dim=0)
+        sum_pearson += ((y_batch.to('cpu') - mean_batch)**2 / mean_batch).sum(dim=0)
+        n_total += y_batch.shape[0]
+
+    return sum_mean, sum_pearson, n_total
+
+
+def initialize_parameters(loader, n_genes, p_mean, p_disp, max_iter = 10,
+                          tol = 1e-3, clip = 5):
+    """
+    Initialize parameters using batched Poisson IRLS followed by MoM dispersion.
+
+    Logic:
+    1. Iteratively fit Poisson GLM by accumulating X'WX and X'WZ across batches
+    2. Use fitted Poisson means to estimate dispersion via Method of Moments
+
+    Parameters
+    ----------
+    loader : DataLoader
+        DataLoader yielding (y_batch, x_dict)
+    n_genes : int
+        Number of response columns (genes)
+    p_mean : int
+        Number of predictors in the mean model
+    p_disp : int
+        Number of predictors in the dispersion model
+    max_iter : int, optional
+        Maximum Poisson IRLS iterations, by default 10
+    tol : float, optional
+        Convergence tolerance for beta coefficients, by default 1e-3
+    clip : float, optional
+        Maximum absolute value for linear predictor, by default 5
+
+    Returns
+    -------
+    beta_init : torch.Tensor
+        (p_mean × n_genes) tensor
+    gamma_init : torch.Tensor
+        (p_disp × n_genes) tensor
+    """
+    beta = torch.zeros((p_mean, n_genes))
+    for _ in range(max_iter):
+        weighted_XX, weighted_Xy = accumulate_poisson_statistics(
+            loader, beta, n_genes, p_mean, clip
+        )
+
+        eye = torch.eye(p_mean).unsqueeze(0)
+        weighted_XX_reg = weighted_XX + 1e-6 * eye
+        beta_new = torch.linalg.solve(
+            weighted_XX_reg, weighted_Xy.T.unsqueeze(-1)
+        ).squeeze(-1).T
+
+        if torch.max(torch.abs(beta_new - beta)) < tol:
+            beta = beta_new
+            break
+        beta = beta_new
+
+    sum_mean, sum_pearson, n_total = accumulate_dispersion_statistics(
+        loader, beta, clip
+    )
+
+    degrees_freedom = n_total - p_mean
+    dispersion = sum_mean / torch.clip(sum_pearson - degrees_freedom, min=0.1)
+
+    gamma = torch.zeros((p_disp, n_genes))
+    gamma[0, :] = torch.log(torch.clip(dispersion, min=0.1))
+    return beta, gamma
+
+
+# ==============================================================================
+# Batch Log-Likelihood
+# ==============================================================================
+
+def compute_batch_loglikelihood(y, mu, r):
+    """
+    Compute the negative binomial log-likelihood for a batch.
+
+    Formula:
+    ℓ = Σ [log Γ(Y+θ) - log Γ(θ) - log Γ(Y+1) + θ log θ + Y log μ - (Y+θ)log(μ+θ)]
+
+    Parameters
+    ----------
+    y : torch.Tensor
+        Observed counts (n_batch × m_active)
+    mu : torch.Tensor
+        Predicted means (n_batch × m_active)
+    r : torch.Tensor
+        Dispersion parameters (n_batch × m_active)
+
+    Returns
+    -------
+    torch.Tensor
+        Total log-likelihood per response (m_active,)
+    """
+    ll = (
+        torch.lgamma(y + r) - torch.lgamma(r) - torch.lgamma(y + 1.0)
+        + r * torch.log(r) + y * torch.log(mu)
+        - (r + y) * torch.log(r + mu)
+    )
+    return torch.sum(ll, dim=0)
+
+
+# ==============================================================================
+# Stochastic IRLS Step
+# ==============================================================================
+
+def step_stochastic_irls(
+    y,
+    X,
+    Z,
+    beta,
+    gamma,
+    eta: float = 0.8,
+    tol: float = 1e-4,
+    ll_prev: Optional[torch.Tensor] = None,
+    clip_mean: float = 10.0,
+    clip_disp: float = 10.0
+):
+    """
+    Perform a single damped Newton-Raphson update on a minibatch.
+
+    Logic:
+    1. Compute log-likelihood with current coefficients.
+    2. Perform one IRLS step for Mean (Beta) and Dispersion (Gamma).
+    3. Re-compute log-likelihood to determine convergence.
+    4. Return updated coefficients and boolean convergence mask.
+
+    Parameters
+    ----------
+    y : torch.Tensor
+        Count batch (n × m)
+    X : torch.Tensor
+        Mean design matrix (n × p)
+    Z : torch.Tensor
+        Dispersion design matrix (n × q)
+    beta : torch.Tensor
+        Current mean coefficients (p × m)
+    gamma : torch.Tensor
+        Current dispersion coefficients (q × m)
+    eta : float, optional
+        Damping factor (learning rate), 1.0 is pure Newton step, by default 0.8
+    tol : float, optional
+        Relative log-likelihood change threshold for convergence, by default 1e-4
+    ll_prev : torch.Tensor, optional
+        Previous log-likelihood values, by default None
+    clip_mean : float, optional
+        Maximum absolute value for mean linear predictor, by default 10.0
+    clip_disp : float, optional
+        Maximum absolute value for dispersion linear predictor, by default 10.0
+
+    Returns
+    -------
+    beta_next : torch.Tensor
+        Updated mean coefficients (p × m)
+    gamma_next : torch.Tensor
+        Updated dispersion coefficients (q × m)
+    converged : torch.Tensor
+        Boolean mask of converged responses (m,)
+    ll_next : torch.Tensor
+        Updated log-likelihood values (m,)
+    """
+    # --- 2. Update Mean (Beta) ---
+    # Working weights W = μ/(1 + μ/θ)
+    beta_target = update_mean_coefficients(X, y, beta, torch.exp(Z @ gamma), clip=clip_mean)
+    beta_next = (1 - eta) * beta + eta * beta_target
+
+    # --- 3. Update Dispersion (Gamma) ---
+    # Update depends on the latest mean estimates
+    linear_pred_mu = torch.clip(X @ beta_next, min=-clip_mean, max=clip_mean)
+    mu = torch.exp(linear_pred_mu)
+    gamma_target = update_dispersion_coefficients(Z, y, mu, gamma, clip=clip_disp)
+    gamma_next = (1 - eta) * gamma + eta * gamma_target
+
+    # --- 4. Convergence Check ---
+    linear_pred_r_next = torch.clip(Z @ gamma_next, min=-clip_disp, max=clip_disp)
+    ll_next = compute_batch_loglikelihood(y, mu, torch.exp(linear_pred_r_next))
+
+    # Relative improvement in the objective function
+    if ll_prev is not None:
+        rel_change = torch.abs(ll_next - ll_prev) / (torch.abs(ll_prev) + 1e-10)
+        converged = rel_change <= tol
+    else:
+        converged = False
+    return beta_next, gamma_next, converged, ll_next
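The module above only provides the per-step building blocks; the driver that chains them ships separately in this release (presumably scdesigner/distributions/negbin_irls.py, +72 lines in the file list). The following is a minimal sketch of how those pieces compose, not the package's actual driver. It assumes a loader that yields (y_batch, x_dict) with the mean design matrix under x_dict['mean'], as in the docstrings, and additionally assumes a 'dispersion' key for the dispersion design matrix.

# Hypothetical driver loop for the helpers above; an illustration only,
# not the package's actual negbin_irls.py.
import torch
from scdesigner.distributions.negbin_irls_funs import (
    initialize_parameters,
    step_stochastic_irls,
)

def fit_negbin_irls(loader, n_genes, p_mean, p_disp, n_epochs=20):
    # Warm start: batched Poisson IRLS, then method-of-moments dispersion.
    beta, gamma = initialize_parameters(loader, n_genes, p_mean, p_disp)

    ll_prev = None
    converged = False
    for _ in range(n_epochs):
        for y_batch, x_dict in loader:
            X = x_dict["mean"]          # mean design matrix (n x p_mean)
            Z = x_dict["dispersion"]    # assumed key for the dispersion design (n x p_disp)
            # One damped Newton step per minibatch for both beta and gamma.
            beta, gamma, converged, ll_prev = step_stochastic_irls(
                y_batch, X, Z, beta, gamma, eta=0.8, ll_prev=ll_prev
            )
        # `converged` is a per-gene boolean mask once a previous log-likelihood exists.
        if torch.is_tensor(converged) and bool(converged.all()):
            break
    return beta, gamma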
scdesigner/distributions/poisson.py (new file)

@@ -0,0 +1,88 @@
+from ..data.formula import standardize_formula
+from ..base.marginal import GLMPredictor, Marginal
+from ..data.loader import _to_numpy
+from typing import Union, Dict, Optional, Tuple
+import torch
+import numpy as np
+from scipy.stats import poisson
+
+class Poisson(Marginal):
+    """Poisson marginal estimator
+
+    This subclass behaves like `Marginal` but assumes each gene follows a
+    Poisson distribution with mean `mu_j(x)` that depends on covariates `x`
+    via the provided `formula` object.
+
+    The allowed formula keys are 'mean', defaulting to a single `mean` term
+    if a string formula is supplied.
+
+    Examples
+    --------
+    >>> from scdesigner.distributions import Poisson
+    >>> from scdesigner.datasets import pancreas
+    >>>
+    >>> sim = Poisson(formula="~ bs(pseudotime, df=5)")
+    >>> sim.setup_data(pancreas)
+    >>> sim.fit(max_epochs=1, verbose=False)
+    >>>
+    >>> # evaluate p(y | x) and mu(x)
+    >>> y, x = next(iter(sim.loader))
+    >>> l = sim.likelihood((y, x))
+    >>> y_hat = sim.predict(x)
+    >>>
+    >>> # convert to quantiles and back
+    >>> u = sim.uniformize(y, x)
+    >>> x_star = sim.invert(u, x)
+    """
+    def __init__(self, formula: Union[Dict, str]):
+        formula = standardize_formula(formula, allowed_keys=['mean'])
+        super().__init__(formula)
+
+    def setup_optimizer(
+        self,
+        optimizer_class: Optional[callable] = torch.optim.Adam,
+        **optimizer_kwargs,
+    ):
+        if self.loader is None:
+            raise RuntimeError("self.loader is not set (call setup_data first)")
+
+        def nll(batch):
+            return -self.likelihood(batch).sum()
+
+        self.predict = GLMPredictor(
+            n_outcomes=self.n_outcomes,
+            feature_dims=self.feature_dims,
+            loss_fn=nll,
+            optimizer_class=optimizer_class,
+            optimizer_kwargs=optimizer_kwargs
+        )
+
+    def likelihood(self, batch) -> torch.Tensor:
+        """Compute the log-likelihood"""
+        y, x = batch
+        params = self.predict(x)
+        mu = params.get("mean")
+        return y * torch.log(mu) - mu - torch.lgamma(y + 1)
+
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
+        """Invert pseudoobservations."""
+        mu, u = self._local_params(x, u)
+        y = poisson(mu).ppf(u)
+        return torch.from_numpy(y).float()
+
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
+        """Return uniformized pseudo-observations for counts y given covariates x."""
+        # cdf values using scipy's parameterization
+        mu, y = self._local_params(x, y)
+        u1 = poisson(mu).cdf(y)
+        u2 = np.where(y > 0, poisson(mu).cdf(y - 1), 0)
+        v = np.random.uniform(size=y.shape)
+        u = np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
+        return torch.from_numpy(u).float()
+
+    def _local_params(self, x, y=None) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
+        params = self.predict(x)
+        mu = params.get('mean')
+        if y is None:
+            return _to_numpy(mu)
+        return _to_numpy(mu, y)
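The `uniformize` method above is the distributional transform for discrete data: because the Poisson CDF is a step function, it draws v ~ Uniform(0, 1) and returns u = v * F(y) + (1 - v) * F(y - 1), which is uniform on (0, 1) when y really is Poisson(mu). Below is a standalone sketch of the same transform using scipy directly; the means in `mu` are made up for illustration.

import numpy as np
from scipy.stats import poisson

rng = np.random.default_rng(0)
mu = np.array([2.0, 5.0, 0.5])             # hypothetical Poisson means
y = poisson(mu).rvs(random_state=rng)      # simulated counts, one per mean

# Distributional transform: u = v * F(y) + (1 - v) * F(y - 1), v ~ Uniform(0, 1)
u_hi = poisson(mu).cdf(y)
u_lo = np.where(y > 0, poisson(mu).cdf(y - 1), 0.0)
v = rng.uniform(size=y.shape)
u = np.clip(v * u_hi + (1 - v) * u_lo, 1e-6, 1 - 1e-6)
# u is approximately Uniform(0, 1) under the model, and poisson(mu).ppf(u)
# returns the original counts, mirroring Poisson.invert above.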
scdesigner/{minimal → distributions}/zero_inflated_negbin.py

@@ -1,13 +1,41 @@
-from .formula import standardize_formula
-from .marginal import GLMPredictor, Marginal
-from .loader import _to_numpy
-from typing import Union, Dict, Optional
+from ..data.formula import standardize_formula
+from ..base.marginal import GLMPredictor, Marginal
+from ..data.loader import _to_numpy
+from typing import Union, Dict, Optional, Tuple
 import torch
 import numpy as np
 from scipy.stats import nbinom, bernoulli
 
 class ZeroInflatedNegBin(Marginal):
-    """Zero-inflated negative-binomial marginal estimator
+    """Zero-inflated negative-binomial marginal estimator
+
+    This subclass models a two-part mixture for counts. For each feature
+    j the observation follows a mixture: with probability `pi_j(x)` the value
+    is an extra zero (inflation), otherwise the count is drawn from a
+    negative-binomial distribution NB(mu_j(x), r_j(x)) parameterized here via
+    a mean `mu_j(x)` and dispersion `r_j(x)`. All parameters may depend on
+    covariates `x` through the `formula` argument.
+
+    The allowed formula keys are 'mean', 'dispersion', and 'zero_inflation'.
+
+    Examples
+    --------
+    >>> from scdesigner.distributions import ZeroInflatedNegBin
+    >>> from scdesigner.datasets import pancreas
+    >>>
+    >>> sim = ZeroInflatedNegBin(formula={"mean": "~ pseudotime", "dispersion": "~ 1", "zero_inflation": "~ pseudotime"})
+    >>> sim.setup_data(pancreas)
+    >>> sim.fit(max_epochs=1, verbose=False)
+    >>>
+    >>> # evaluate p(y | x) and model parameters
+    >>> y, x = next(iter(sim.loader))
+    >>> l = sim.likelihood((y, x))
+    >>> y_hat = sim.predict(x)
+    >>>
+    >>> # convert to quantiles and back
+    >>> u = sim.uniformize(y, x)
+    >>> x_star = sim.invert(u, x)
+    """
     def __init__(self, formula: Union[Dict, str]):
         formula = standardize_formula(formula, allowed_keys=['mean', 'dispersion', 'zero_inflation'])
         super().__init__(formula)
@@ -25,7 +53,8 @@ class ZeroInflatedNegBin(Marginal):
             "dispersion": torch.exp,
             "zero_inflation": torch.sigmoid,
         }
-        nll
+        def nll(batch):
+            return -self.likelihood(batch).sum()
         self.predict = GLMPredictor(
             n_outcomes=self.n_outcomes,
             feature_dims=self.feature_dims,
@@ -35,7 +64,7 @@ class ZeroInflatedNegBin(Marginal):
             optimizer_kwargs=optimizer_kwargs
         )
 
-    def likelihood(self, batch):
+    def likelihood(self, batch) -> torch.Tensor:
         """Compute the negative log-likelihood"""
         y, x = batch
         params = self.predict(x)
@@ -56,14 +85,14 @@ class ZeroInflatedNegBin(Marginal):
         # return the mixture, with an offset to prevent log(0)
         return torch.log(pi * (y == 0) + (1 - pi) * torch.exp(negbin_loglikelihood) + 1e-10)
 
-    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]):
+    def invert(self, u: torch.Tensor, x: Dict[str, torch.Tensor]) -> torch.Tensor:
         """Invert pseudoobservations."""
         mu, r, pi, u = self._local_params(x, u)
         y = nbinom(n=r, p=r / (r + mu)).ppf(u)
         delta = bernoulli(1 - pi).ppf(u)
         return torch.from_numpy(y * delta).float()
 
-    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6):
+    def uniformize(self, y: torch.Tensor, x: Dict[str, torch.Tensor], epsilon=1e-6) -> torch.Tensor:
         """Return uniformized pseudo-observations for counts y given covariates x."""
         # cdf values using scipy's parameterization
         mu, r, pi, y = self._local_params(x, y)
@@ -76,7 +105,7 @@ class ZeroInflatedNegBin(Marginal):
         u = np.clip(v * u1 + (1 - v) * u2, epsilon, 1 - epsilon)
         return torch.from_numpy(u).float()
 
-    def _local_params(self, x, y=None):
+    def _local_params(self, x, y=None) -> Tuple:
         params = self.predict(x)
         mu = params.get('mean')
         r = params.get('dispersion')