edgepython 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edgepython/__init__.py +114 -0
- edgepython/classes.py +517 -0
- edgepython/compressed_matrix.py +388 -0
- edgepython/dgelist.py +314 -0
- edgepython/dispersion.py +920 -0
- edgepython/dispersion_lowlevel.py +1066 -0
- edgepython/exact_test.py +525 -0
- edgepython/expression.py +323 -0
- edgepython/filtering.py +96 -0
- edgepython/gene_sets.py +1215 -0
- edgepython/glm_fit.py +653 -0
- edgepython/glm_levenberg.py +359 -0
- edgepython/glm_test.py +375 -0
- edgepython/io.py +1887 -0
- edgepython/limma_port.py +987 -0
- edgepython/normalization.py +546 -0
- edgepython/ql_weights.py +765 -0
- edgepython/results.py +236 -0
- edgepython/sc_fit.py +1511 -0
- edgepython/smoothing.py +474 -0
- edgepython/splicing.py +537 -0
- edgepython/utils.py +1050 -0
- edgepython/visualization.py +409 -0
- edgepython/weighted_lowess.py +323 -0
- edgepython-0.2.0.dist-info/METADATA +201 -0
- edgepython-0.2.0.dist-info/RECORD +29 -0
- edgepython-0.2.0.dist-info/WHEEL +5 -0
- edgepython-0.2.0.dist-info/licenses/LICENSE +674 -0
- edgepython-0.2.0.dist-info/top_level.txt +1 -0
edgepython/sc_fit.py
ADDED
|
@@ -0,0 +1,1511 @@
|
|
|
1
|
+
# This code was written by Claude (Anthropic). The project was directed by Lior Pachter.
|
|
2
|
+
"""Single-cell NB mixed model fitting (NEBULA-LN port).
|
|
3
|
+
|
|
4
|
+
Implements ``glm_sc_fit()`` and ``glm_sc_test()`` for cell-level
|
|
5
|
+
negative binomial gamma mixed model (NBGMM) analysis of multi-subject
|
|
6
|
+
single-cell RNA-seq data.
|
|
7
|
+
|
|
8
|
+
Reference
|
|
9
|
+
---------
|
|
10
|
+
He L, Davila-Velderrain J, Sumida TS, Hafler DA, Bhatt DL et al.
|
|
11
|
+
NEBULA is a fast negative binomial mixed model for differential or
|
|
12
|
+
co-expression analysis of large-scale multi-subject single-cell data.
|
|
13
|
+
*Communications Biology*, 4:629, 2021.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import math
|
|
19
|
+
import warnings
|
|
20
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
21
|
+
from math import lgamma as _lgamma
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import pandas as pd
|
|
26
|
+
from numba import njit
|
|
27
|
+
from scipy.optimize import minimize as _minimize
|
|
28
|
+
from scipy.special import digamma as _digamma, gammaln as _gammaln
|
|
29
|
+
from scipy.stats import chi2 as _chi2
|
|
30
|
+
|
|
31
|
+
from .normalization import calc_norm_factors
|
|
32
|
+
from .dgelist import make_dgelist
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ---------------------------------------------------------------------------
|
|
36
|
+
# Numba-accelerated core functions
|
|
37
|
+
# ---------------------------------------------------------------------------
|
|
38
|
+
|
|
39
|
+
@njit(cache=True)
def _digamma_nb(x):
    """Digamma (psi) function for x > 0. Accurate to ~15 digits.

    Uses the recurrence psi(x) = psi(x+1) - 1/x to shift the argument up
    to x >= 7, then evaluates the standard asymptotic expansion
    log(x) - 1/(2x) - sum of terms in 1/x^2 (Horner form below).
    """
    result = 0.0
    # Recurrence: accumulate -1/x while shifting x into the series' range.
    while x < 7.0:
        result -= 1.0 / x
        x += 1.0
    r = 1.0 / (x * x)
    result += math.log(x) - 0.5 / x
    # Asymptotic series in r = 1/x^2, nested (Horner) evaluation.
    result -= r * (1.0/12.0 - r * (1.0/120.0 - r * (1.0/252.0
        - r * (1.0/240.0 - r * (5.0/660.0 - r * 691.0/32760.0)))))
    return result
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@njit(cache=True)
def _ptmg_negll_and_grad_nb(para, X, offset, Y, n_one, n_two, ytwo,
                            fid, cumsumy, posind, posindy, nb, nind, k):
    """Numba-compiled NBGMM negative log-likelihood and analytic gradient.

    Parameters
    ----------
    para : 1D array, length nb + 2
        Packed parameters: ``beta`` (nb fixed effects), the subject-level
        variance transform ``sigma_param``, and the cell-level
        overdispersion ``phi``.
    X : (nind, nb) design matrix.
    offset : (nind,) per-cell log-offset.
    Y : counts of the non-zero cells only.
    n_one, n_two : numbers of cells with count exactly 1 and exactly 2
        (these admit closed-form lgamma terms below).
    ytwo : counts strictly greater than 2.
    fid : (k+1,) 0-based sample boundary indices into the cell axis.
    cumsumy : (k,) per-sample count totals for this gene.
    posind : indices into ``cumsumy`` used for the subject-level lgamma
        corrections (presumably samples with a positive total — confirm
        at the caller).
    posindy : 0-based cell indices of the non-zero counts.
    nb, nind, k : number of covariates, cells, and samples.

    Returns
    -------
    (fn, gr)
        Negative log-likelihood (float) and its gradient (length nb + 2).
        Returns (1e30, zeros) when the variance transform is infeasible.
    """
    beta = para[:nb]
    sigma_param = para[nb]
    phi = para[nb + 1]

    exps = math.exp(sigma_param)
    exps_m1 = exps - 1.0
    if exps_m1 <= 0:
        # Infeasible variance transform: return a huge penalty value.
        return 1e30, np.zeros(nb + 2)
    alpha = 1.0 / exps_m1
    exps_s = math.sqrt(exps)
    lam = alpha / exps_s
    gamma = phi

    # Derivatives of alpha and lambda w.r.t. sigma_param (chain rule terms).
    exps_m1_sq = exps_m1 * exps_m1
    alpha_pr = -exps / exps_m1_sq
    lambda_pr = (1.0 - 3.0 * exps) / (2.0 * exps_s * exps_m1_sq)

    log_lambda = math.log(lam)
    log_gamma = math.log(gamma) if gamma > 0 else -1e30

    nelem = len(posindy)

    # Linear predictor: xtb = offset + X @ beta
    xtb = np.empty(nind)
    for i in range(nind):
        s = offset[i]
        for j in range(nb):
            s += X[i, j] * beta[j]
        xtb[i] = s

    # term1 = sum y_j * eta_j (only non-zero cells contribute)
    term1 = 0.0
    for i in range(nelem):
        term1 += xtb[posindy[i]] * Y[i]

    # exp(xtb) with overflow protection (clip exponent at 500)
    extb = np.empty(nind)
    for i in range(nind):
        v = xtb[i]
        extb[i] = math.exp(min(v, 500.0))

    # Per-sample sums of exp(linear predictor)
    cumsumxtb = np.empty(k)
    for s in range(k):
        start = fid[s]
        end = fid[s + 1]
        acc = 0.0
        for i in range(start, end):
            acc += extb[i]
        cumsumxtb[s] = acc

    # Posterior-style per-sample quantities:
    #   ystar  = sample count total + alpha
    #   mustar = sample fitted total + lambda
    ystar = np.empty(k)
    mustar = np.empty(k)
    mustar_log = np.empty(k)
    ymustar = np.empty(k)
    ymumustar = np.empty(k)
    for s in range(k):
        ystar[s] = cumsumy[s] + alpha
        mustar[s] = cumsumxtb[s] + lam
        mustar_log[s] = math.log(mustar[s])
        ymustar[s] = ystar[s] / mustar[s]
        ymumustar[s] = ymustar[s] / mustar[s]

    for s in range(k):
        term1 -= ystar[s] * mustar_log[s]
    term1 += k * alpha * log_lambda
    term1 += nind * gamma * log_gamma

    # gstar = gamma + y_j (gamma for zero cells)
    gstar_vec = np.full(nind, gamma)
    for i in range(nelem):
        gstar_vec[posindy[i]] += Y[i]

    # sum_elgcp[j] = ymustar[s(j)] * extb[j]
    sum_elgcp = np.empty(nind)
    for s in range(k):
        start = fid[s]
        end = fid[s + 1]
        val = ymustar[s]
        for i in range(start, end):
            sum_elgcp[i] = val * extb[i]

    for i in range(nind):
        term1 += sum_elgcp[i]

    sum_elgcp_pg = np.empty(nind)
    gstar_phiymustar = np.empty(nind)
    log_sum_elgcp_pg = np.empty(nind)
    slpey = 0.0
    for i in range(nind):
        sum_elgcp_pg[i] = sum_elgcp[i] + gamma
        gstar_phiymustar[i] = gstar_vec[i] / sum_elgcp_pg[i]
        log_sum_elgcp_pg[i] = math.log(sum_elgcp_pg[i])
        slpey += log_sum_elgcp_pg[i]

    term1 -= gamma * slpey
    for i in range(nelem):
        term1 -= Y[i] * log_sum_elgcp_pg[posindy[i]]

    # term1 holds the C++-level log-likelihood; negate for minimization.
    fn_cpp = -term1

    # --- Gradient ---
    dbeta_42 = np.zeros(k)
    xexb_f = np.zeros((nb, k))
    dbeta_41 = np.zeros((nb, k))

    for s in range(k):
        start = fid[s]
        end = fid[s + 1]
        for i in range(start, end):
            gp = gstar_phiymustar[i]
            ext_i = extb[i]
            dbeta_42[s] += gp * ext_i
            for j in range(nb):
                xexb = X[i, j] * ext_i
                xexb_f[j, s] += xexb
                dbeta_41[j, s] += gp * xexb

    # db starts as X^T @ y over the non-zero cells.
    db = np.zeros(nb)
    for i in range(nelem):
        for j in range(nb):
            db[j] += X[posindy[i], j] * Y[i]

    for s in range(k):
        val = ymumustar[s] * (dbeta_42[s] - cumsumxtb[s])
        for j in range(nb):
            db[j] += xexb_f[j, s] * val - dbeta_41[j, s] * ymustar[s]

    ldm = log_lambda * k
    for s in range(k):
        ldm -= mustar_log[s]
    adlmy = exps_s * k
    for s in range(k):
        adlmy -= ymustar[s]

    # dtau: derivative w.r.t. sigma_param via alpha_pr / lambda_pr.
    dtau = 0.0
    dtau_lp = 0.0
    for s in range(k):
        dtau += alpha_pr * (cumsumxtb[s] - dbeta_42[s]) / mustar[s]
        dtau_lp += ymumustar[s] * (dbeta_42[s] - cumsumxtb[s])
    dtau += lambda_pr * dtau_lp
    dtau += alpha_pr * ldm + lambda_pr * adlmy

    # dtau2: derivative w.r.t. phi (gamma).
    dtau2 = log_gamma * nind + nind - slpey
    for i in range(nind):
        dtau2 -= gstar_phiymustar[i]

    gr = np.zeros(nb + 2)
    for j in range(nb):
        gr[j] = -db[j]
    gr[nb] = -dtau
    gr[nb + 1] = -dtau2

    # --- R-level lgamma corrections ---
    # Counts of 1 and 2 use closed-form log terms instead of lgamma calls.
    n_one_plus_two = n_one + n_two

    lgamma_fn = 0.0
    for s_idx in range(len(posind)):
        lgamma_fn += math.lgamma(cumsumy[posind[s_idx]] + alpha)
    lgamma_fn -= len(posind) * math.lgamma(alpha)
    for v_idx in range(len(ytwo)):
        lgamma_fn += math.lgamma(ytwo[v_idx] + gamma)
    lgamma_fn -= (nelem - n_one_plus_two) * math.lgamma(gamma)
    if n_one_plus_two > 0:
        lgamma_fn += n_one_plus_two * math.log(gamma)
    if n_two > 0:
        lgamma_fn += n_two * math.log(gamma + 1.0)

    fn = fn_cpp - lgamma_fn

    # --- Digamma corrections (gradient of the lgamma terms) ---
    dig_alpha_sum = 0.0
    for s_idx in range(len(posind)):
        dig_alpha_sum += _digamma_nb(cumsumy[posind[s_idx]] + alpha)
    dig_alpha_sum -= len(posind) * _digamma_nb(alpha)

    dig_gamma_sum = 0.0
    for v_idx in range(len(ytwo)):
        dig_gamma_sum += _digamma_nb(ytwo[v_idx] + gamma)
    dig_gamma_sum -= (nelem - n_one_plus_two) * _digamma_nb(gamma)
    if n_one_plus_two > 0:
        dig_gamma_sum += n_one_plus_two / gamma
    if n_two > 0:
        dig_gamma_sum += n_two / (gamma + 1.0)

    gr[nb] -= alpha_pr * dig_alpha_sum
    gr[nb + 1] -= dig_gamma_sum

    return fn, gr
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@njit(cache=True)
def _compute_pml_loglik_nb(offset, X, beta, logw, fid, k, posindy, Y,
                           cumsumy, gamma, alpha, lam, nind, nb):
    """Numba-compiled PML log-likelihood evaluation.

    Evaluates the penalized (h-)likelihood at fixed effects ``beta`` and
    per-sample log random effects ``logw``.

    Returns
    -------
    (loglik, extb)
        The log-likelihood value and the per-cell fitted means
        exp(offset + X @ beta + logw[sample]) — the latter is reused by
        the caller to build gradients/Hessians.
    """
    nelem = len(posindy)

    # extb_lin = offset + X @ beta (random effect added below)
    extb_lin = np.empty(nind)
    for i in range(nind):
        s = offset[i]
        for j in range(nb):
            s += X[i, j] * beta[j]
        extb_lin[i] = s

    loglik = 0.0
    for i in range(nelem):
        loglik += extb_lin[posindy[i]] * Y[i]

    # logw @ cumsumy — random-effect contribution of the count totals
    for s in range(k):
        loglik += logw[s] * cumsumy[s]

    # Add logw to the linear predictor per sample
    for s in range(k):
        start = fid[s]
        end = fid[s + 1]
        for i in range(start, end):
            extb_lin[i] += logw[s]

    # exp(extb_lin), exponent clipped at 500 to avoid overflow
    extb = np.empty(nind)
    for i in range(nind):
        extb[i] = math.exp(min(extb_lin[i], 500.0))

    for i in range(nind):
        extbphil = math.log(extb[i] + gamma)
        loglik -= gamma * extbphil

    for i in range(nelem):
        loglik -= Y[i] * math.log(extb[posindy[i]] + gamma)

    # Gamma prior/penalty on the random effects w = exp(logw)
    for s in range(k):
        loglik += alpha * logw[s] - lam * math.exp(logw[s])

    return loglik, extb
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
@njit(cache=True)
def _opt_pml_nb(X, offset, Y_vals, fid, cumsumy, posindy, nb, nind, k,
                beta_init, sigma0, sigma1, eps, ord_):
    """Numba-compiled PML optimizer.

    Alternating damped Newton updates of the fixed effects ``beta`` and the
    per-sample log random effects ``logw`` are taken until the
    log-likelihood improvement drops below ``eps`` or 50 outer iterations
    are reached.  A halving line search (up to 10 halvings) guards each
    step.  When ``ord_`` > 1 (and > 2), higher-order Laplace correction
    terms are accumulated into ``sec_ord``.

    Returns (beta, logw, vb2, loglik, loglikp, logdet, step, stepd, sec_ord):
    fitted coefficients, fitted log random effects, Schur-complement beta
    Hessian block, final and previous log-likelihood, log-determinant of
    the logw Hessian diagonal, outer iteration count, final line-search
    damp count, and the second-order correction.
    """
    # Transform sigma0 into the NBGMM (alpha, lambda) parameterization;
    # sigma1 is the cell-level overdispersion gamma.
    exps = math.exp(sigma0)
    alpha = 1.0 / (exps - 1.0)
    lam = 1.0 / (math.sqrt(exps) * (exps - 1.0))
    gamma = sigma1

    logw = np.zeros(k)
    beta = beta_init.copy()

    # gstar: gamma + y for non-zero cells (gamma alone for zero cells)
    gstar = np.full(nind, gamma)
    nelem = len(posindy)
    for i in range(nelem):
        gstar[posindy[i]] += Y_vals[i]

    # Precompute yx = X^T @ y (only non-zero entries contribute)
    yx = np.zeros(nb)
    for i in range(nelem):
        for j in range(nb):
            yx[j] += X[posindy[i], j] * Y_vals[i]

    # Initial log-likelihood
    loglik, extb = _compute_pml_loglik_nb(
        offset, X, beta, logw, fid, k, posindy, Y_vals,
        cumsumy, gamma, alpha, lam, nind, nb
    )

    loglikp = 0.0
    step = 0
    maxstep = 50    # outer Newton iteration cap
    maxstd = 10     # line-search halving cap
    convd = 0.01    # gradient tolerance used when the line search gives up
    stepd = 0

    vb = np.zeros((nb, nb))
    vb2 = np.zeros((nb, nb))
    vw = np.zeros(k)
    vwb = np.zeros((k, nb))

    # Outer loop always runs at least once (step == 0 on entry).
    while step == 0 or (loglik - loglikp > eps and step < maxstep):
        step += 1

        damp = np.ones(nb)
        damp_w = np.ones(k)

        # gstar_extb_phi = gstar / (1 + gamma/extb)
        gstar_extb_phi = np.empty(nind)
        for i in range(nind):
            if extb[i] < 1e-300:
                gstar_extb_phi[i] = 0.0
            else:
                gstar_extb_phi[i] = gstar[i] / (1.0 + gamma / extb[i])

        # Gradient w.r.t. beta: db = yx - X^T @ gstar_extb_phi
        db = np.empty(nb)
        for j in range(nb):
            s = yx[j]
            for i in range(nind):
                s -= X[i, j] * gstar_extb_phi[i]
            db[j] = s

        # Gradient w.r.t. logw (per-sample)
        dw = np.empty(k)
        w = np.empty(k)
        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            acc = 0.0
            for i in range(start, end):
                acc += gstar_extb_phi[i]
            w[s] = math.exp(logw[s])
            dw[s] = cumsumy[s] - acc - lam * w[s] + alpha

        # Hessian diagonal w.r.t. logw
        gstar_extb_phi2 = np.empty(nind)
        for i in range(nind):
            denom = extb[i] + gamma
            if denom < 1e-300:
                gstar_extb_phi2[i] = 0.0
            else:
                gstar_extb_phi2[i] = gstar_extb_phi[i] / denom

        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            acc = 0.0
            for i in range(start, end):
                acc += gstar_extb_phi2[i]
            vw[s] = gamma * acc + lam * w[s]

        # Cross-term Hessian vwb (k × nb)
        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            for j in range(nb):
                acc = 0.0
                for i in range(start, end):
                    acc += X[i, j] * gstar_extb_phi2[i]
                vwb[s, j] = gamma * acc

        # Hessian w.r.t. beta (nb × nb), filled symmetrically
        for ii in range(nb):
            for jj in range(ii, nb):
                acc = 0.0
                for i in range(nind):
                    acc += X[i, ii] * gstar_extb_phi2[i] * X[i, jj]
                vb[ii, jj] = gamma * acc
                if ii != jj:
                    vb[jj, ii] = vb[ii, jj]

        # Floor vw to avoid division by zero
        for s in range(k):
            if vw[s] < 1e-15:
                vw[s] = 1e-15

        # Schur complement: vb2 = vb - vwb^T @ diag(1/vw) @ vwb
        for ii in range(nb):
            for jj in range(nb):
                acc = 0.0
                for s in range(k):
                    acc += vwb[s, ii] * vwb[s, jj] / vw[s]
                vb2[ii, jj] = vb[ii, jj] - acc

        # Newton step
        dwvw = np.empty(k)
        for s in range(k):
            dwvw[s] = dw[s] / vw[s]

        # rhs = db - vwb^T @ dwvw
        rhs = np.empty(nb)
        for j in range(nb):
            acc = 0.0
            for s in range(k):
                acc += vwb[s, j] * dwvw[s]
            rhs[j] = db[j] - acc

        # Regularize near-singular diagonal entries before the solve
        for ii in range(nb):
            if abs(vb2[ii, ii]) < 1e-10:
                vb2[ii, ii] += 1e-8

        stepbeta = np.linalg.solve(vb2, rhs)

        # steplogw = dwvw - (vwb @ stepbeta) / vw (vw already floored above)
        steplogw = np.empty(k)
        for s in range(k):
            acc = 0.0
            for j in range(nb):
                acc += vwb[s, j] * stepbeta[j]
            steplogw[s] = dwvw[s] - acc / vw[s]

        new_b = beta + stepbeta
        new_w = logw + steplogw

        loglikp = loglik
        loglik, extb = _compute_pml_loglik_nb(
            offset, X, new_b, new_w, fid, k, posindy, Y_vals,
            cumsumy, gamma, alpha, lam, nind, nb
        )

        likdif = loglik - loglikp
        stepd = 0
        minstep = 40.0

        # Line search: halve the step while the likelihood got worse.
        while likdif < 0 or math.isinf(loglik):
            stepd += 1
            minstep /= 2.0

            if stepd > maxstd:
                # Give up on this step; keep the previous point.  Bump
                # stepd once more if the gradient is still large so the
                # caller can flag non-convergence.
                likdif = 0.0
                loglik = loglikp
                mabsdb = 0.0
                mabsdw = 0.0
                for j in range(nb):
                    if abs(db[j]) > mabsdb:
                        mabsdb = abs(db[j])
                for s in range(k):
                    if abs(dw[s]) > mabsdw:
                        mabsdw = abs(dw[s])
                if mabsdb > convd or mabsdw > convd:
                    stepd += 1
                break

            for i in range(nb):
                if -40 < stepbeta[i] < 40:
                    damp[i] /= 2.0
                    new_b[i] = beta[i] + stepbeta[i] * damp[i]
                else:
                    new_b[i] = beta[i] + (minstep if stepbeta[i] > 0 else -minstep)

            for s in range(k):
                if -40 < steplogw[s] < 40:
                    damp_w[s] /= 2.0
                    new_w[s] = logw[s] + steplogw[s] * damp_w[s]
                else:
                    new_w[s] = logw[s] + (minstep if steplogw[s] > 0 else -minstep)

            loglik, extb = _compute_pml_loglik_nb(
                offset, X, new_b, new_w, fid, k, posindy, Y_vals,
                cumsumy, gamma, alpha, lam, nind, nb
            )
            likdif = loglik - loglikp

        beta = new_b
        logw = new_w

    # Log-determinant of the (floored) logw Hessian diagonal
    logdet = 0.0
    for s in range(k):
        logdet += math.log(max(abs(vw[s]), 1e-300))

    # Second-order Laplace correction (ord_ > 1), third order for ord_ > 2.
    # gstar_extb_phi is progressively divided by (extb + gamma) between
    # stages to build the successive derivative terms.
    sec_ord = 0.0
    if ord_ > 1:
        for i in range(nind):
            if extb[i] < 1e-300:
                gstar_extb_phi[i] = 0.0
            else:
                gstar_extb_phi[i] = gstar[i] / (1.0 + gamma / extb[i])
        extbg = np.empty(nind)
        for i in range(nind):
            extbg[i] = extb[i] + gamma
            if extbg[i] < 1e-300:
                gstar_extb_phi[i] = 0.0
            else:
                gstar_extb_phi[i] /= extbg[i]
        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            acc = 0.0
            for i in range(start, end):
                acc += gstar_extb_phi[i]
            vw[s] = gamma * acc + lam * math.exp(logw[s])
            if vw[s] < 1e-15:
                vw[s] = 1e-15
        vws = np.empty(k)
        for s in range(k):
            vws[s] = vw[s] * vw[s]

        for i in range(nind):
            if extbg[i] < 1e-300:
                gstar_extb_phi[i] = 0.0
            else:
                gstar_extb_phi[i] /= extbg[i]
        third_der = np.empty(k)
        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            acc = 0.0
            for i in range(start, end):
                acc += gstar_extb_phi[i] * (gamma - extb[i])
            third_der[s] = gamma * acc + lam * math.exp(logw[s])
        acc = 0.0
        for s in range(k):
            acc += third_der[s] * third_der[s] / (vws[s] * vw[s])
        sec_ord += 5.0 / 24.0 * acc

        if ord_ > 2:
            for i in range(nind):
                if extbg[i] < 1e-300:
                    gstar_extb_phi[i] = 0.0
                else:
                    gstar_extb_phi[i] /= extbg[i]
            four_der = np.empty(k)
            for s in range(k):
                start = fid[s]
                end = fid[s + 1]
                acc = 0.0
                for i in range(start, end):
                    extbp = extb[i] * extb[i]
                    acc += gstar_extb_phi[i] * (gamma*gamma + extbp - 4*gamma*extb[i])
                four_der[s] = gamma * acc + lam * math.exp(logw[s])
            acc = 0.0
            for s in range(k):
                acc += four_der[s] / vws[s]
            sec_ord -= acc / 8.0

            for i in range(nind):
                if extbg[i] < 1e-300:
                    gstar_extb_phi[i] = 0.0
                else:
                    gstar_extb_phi[i] /= extbg[i]
            for s in range(k):
                start = fid[s]
                end = fid[s + 1]
                acc2 = 0.0
                for i in range(start, end):
                    extbp = extb[i] * extb[i]
                    acc2 += gstar_extb_phi[i] * (
                        gamma**3 - 11*gamma*gamma*extb[i]
                        + 11*gamma*extbp - extbp*extb[i]
                    )
                four_der[s] = gamma * acc2 + lam * math.exp(logw[s])
            acc = 0.0
            for s in range(k):
                acc += four_der[s] * third_der[s] / (vws[s] * vws[s])
            sec_ord += 7.0 / 48.0 * acc

    return beta, logw, vb2, loglik, loglikp, logdet, step, stepd, sec_ord
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
@njit(cache=True)
def _get_cell_nb(X, fid, nb, k):
    """Flag covariates that vary within at least one subject.

    Returns a float vector of length ``nb``: entry ``col`` is 1.0 when
    column ``col`` of ``X`` takes more than one value inside some
    subject's contiguous block of rows (a cell-level covariate), and 0.0
    when the column is constant within every subject.
    """
    flags = np.zeros(nb)
    for col in range(nb):
        subj = 0
        # Scan subjects until variation is found for this column.
        while subj < k and flags[col] == 0.0:
            lo = fid[subj]
            hi = fid[subj + 1]
            first = X[lo, col]
            for row in range(lo + 1, hi):
                if X[row, col] != first:
                    flags[col] = 1.0
                    break
            subj += 1
    return flags
|
|
620
|
+
|
|
621
|
+
# ---------------------------------------------------------------------------
|
|
622
|
+
# Utility helpers
|
|
623
|
+
# ---------------------------------------------------------------------------
|
|
624
|
+
|
|
625
|
+
def _center_design(pred: np.ndarray):
|
|
626
|
+
"""Center design columns and scale to unit variance.
|
|
627
|
+
|
|
628
|
+
Matches nebula's ``center_m`` C++ function exactly.
|
|
629
|
+
|
|
630
|
+
Returns
|
|
631
|
+
-------
|
|
632
|
+
pred_centered : ndarray (n × p)
|
|
633
|
+
sds : ndarray (p,)
|
|
634
|
+
Column standard deviations (population, not sample).
|
|
635
|
+
The intercept column gets sd=0; zero-vector columns get sd=-1.
|
|
636
|
+
int_col : int
|
|
637
|
+
0-based index of the intercept column.
|
|
638
|
+
"""
|
|
639
|
+
pred = np.asarray(pred, dtype=np.float64).copy()
|
|
640
|
+
n, p = pred.shape
|
|
641
|
+
means = pred.mean(axis=0)
|
|
642
|
+
cm = pred - means
|
|
643
|
+
sds = np.sqrt((cm * cm).mean(axis=0))
|
|
644
|
+
|
|
645
|
+
int_col = None
|
|
646
|
+
for i in range(p):
|
|
647
|
+
if sds[i] > 0:
|
|
648
|
+
cm[:, i] /= sds[i]
|
|
649
|
+
else:
|
|
650
|
+
if pred[0, i] != 0:
|
|
651
|
+
# intercept column: fill with ones
|
|
652
|
+
cm[:, i] = 1.0
|
|
653
|
+
sds[i] = 0.0
|
|
654
|
+
int_col = i
|
|
655
|
+
else:
|
|
656
|
+
sds[i] = -1.0
|
|
657
|
+
|
|
658
|
+
if int_col is None:
|
|
659
|
+
raise ValueError("The design matrix must include an intercept term.")
|
|
660
|
+
if (sds == 0).sum() > 1 or (sds < 0).any():
|
|
661
|
+
raise ValueError(
|
|
662
|
+
"Some predictors have zero variation or a zero vector."
|
|
663
|
+
)
|
|
664
|
+
|
|
665
|
+
return cm, sds, int_col
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
def _cv_offset(offset: np.ndarray | None, nind: int):
|
|
669
|
+
"""Process offset, matching nebula's ``cv_offset`` C++ function.
|
|
670
|
+
|
|
671
|
+
Parameters
|
|
672
|
+
----------
|
|
673
|
+
offset : array (nind,) of *positive* scaling factors, or None.
|
|
674
|
+
nind : int
|
|
675
|
+
|
|
676
|
+
Returns
|
|
677
|
+
-------
|
|
678
|
+
log_offset : ndarray (nind,) — log of the offset
|
|
679
|
+
moffset : float — mean of log-offset (0 if offset was None)
|
|
680
|
+
cv2 : float — squared CV of the raw offset
|
|
681
|
+
"""
|
|
682
|
+
if offset is None:
|
|
683
|
+
log_offset = np.zeros(nind)
|
|
684
|
+
return log_offset, 0.0, 0.0
|
|
685
|
+
|
|
686
|
+
offset = np.asarray(offset, dtype=np.float64)
|
|
687
|
+
moffset_raw = offset.mean()
|
|
688
|
+
cv = 0.0
|
|
689
|
+
if moffset_raw > 0:
|
|
690
|
+
cv = np.sqrt(((offset - moffset_raw) ** 2).sum() / nind) / moffset_raw
|
|
691
|
+
log_offset = np.log(offset)
|
|
692
|
+
moffset = log_offset.mean()
|
|
693
|
+
return log_offset, moffset, cv * cv
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
def _call_cumsumy(count, fid, k, ngene):
|
|
697
|
+
"""Sum counts per gene per sample, matching nebula's ``call_cumsumy``.
|
|
698
|
+
|
|
699
|
+
Parameters
|
|
700
|
+
----------
|
|
701
|
+
count : sparse or dense, genes × cells
|
|
702
|
+
fid : int array (k+1,) of segment boundaries (0-based)
|
|
703
|
+
k : int, number of samples
|
|
704
|
+
ngene : int
|
|
705
|
+
|
|
706
|
+
Returns
|
|
707
|
+
-------
|
|
708
|
+
cumsumy : ndarray (ngene, k)
|
|
709
|
+
"""
|
|
710
|
+
cumsumy = np.zeros((ngene, k), dtype=np.float64)
|
|
711
|
+
# Use dense or sparse slicing
|
|
712
|
+
for s in range(k):
|
|
713
|
+
start, end = fid[s], fid[s + 1]
|
|
714
|
+
chunk = count[:, start:end]
|
|
715
|
+
if hasattr(chunk, 'toarray'):
|
|
716
|
+
cumsumy[:, s] = np.asarray(chunk.sum(axis=1)).ravel()
|
|
717
|
+
else:
|
|
718
|
+
cumsumy[:, s] = chunk.sum(axis=1).ravel()
|
|
719
|
+
return cumsumy
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def _call_posindy(y_gene: np.ndarray):
|
|
723
|
+
"""Extract non-zero positions and counts for a single gene row.
|
|
724
|
+
|
|
725
|
+
Matches nebula's ``call_posindy``.
|
|
726
|
+
|
|
727
|
+
Parameters
|
|
728
|
+
----------
|
|
729
|
+
y_gene : 1D array (ncells,)
|
|
730
|
+
|
|
731
|
+
Returns
|
|
732
|
+
-------
|
|
733
|
+
dict with keys:
|
|
734
|
+
posindy : int array — 0-based indices of non-zero cells
|
|
735
|
+
Y : float array — corresponding count values
|
|
736
|
+
mct : float — mean count per cell
|
|
737
|
+
n_onetwo : int array (2,) — [n_one, n_two]
|
|
738
|
+
ytwo : float array — counts > 2
|
|
739
|
+
"""
|
|
740
|
+
nz = np.nonzero(y_gene)[0]
|
|
741
|
+
posindy = nz.astype(np.int32)
|
|
742
|
+
Y = y_gene[nz].astype(np.float64)
|
|
743
|
+
mct = Y.sum() / len(y_gene)
|
|
744
|
+
|
|
745
|
+
n_one = int((Y == 1).sum())
|
|
746
|
+
n_two = int((Y == 2).sum())
|
|
747
|
+
ytwo = Y[Y > 2]
|
|
748
|
+
|
|
749
|
+
return {
|
|
750
|
+
'posindy': posindy,
|
|
751
|
+
'Y': Y,
|
|
752
|
+
'mct': mct,
|
|
753
|
+
'n_onetwo': np.array([n_one, n_two], dtype=np.int32),
|
|
754
|
+
'ytwo': ytwo,
|
|
755
|
+
}
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
def _get_cell(X, fid, nb, k):
    """Identify cell-level covariates (those that vary within a subject).

    Thin Python wrapper; delegates to the numba-compiled ``_get_cell_nb``.

    Parameters
    ----------
    X : (nind, nb) design matrix.
    fid : (k+1,) int array of subject boundary row indices.
    nb : int — number of covariate columns.
    k : int — number of subjects.

    Returns
    -------
    ndarray (nb,) of 0.0/1.0 flags; 1.0 where the column varies within
    at least one subject.
    """
    return _get_cell_nb(X, fid, nb, k)
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
def _get_cv(offset, X, beta, cell_ind, ncell, nc):
|
|
767
|
+
"""Compute squared CV of fitted values at cell-level predictors.
|
|
768
|
+
|
|
769
|
+
Matches nebula's ``get_cv``.
|
|
770
|
+
"""
|
|
771
|
+
extb = offset.copy()
|
|
772
|
+
for i in range(ncell):
|
|
773
|
+
ind = int(cell_ind[i])
|
|
774
|
+
extb = extb + X[:, ind] * beta[ind]
|
|
775
|
+
with np.errstate(over='ignore'):
|
|
776
|
+
extb = np.exp(extb)
|
|
777
|
+
m = extb.mean()
|
|
778
|
+
if m > 0:
|
|
779
|
+
return ((extb - m) ** 2).sum() / nc / (m * m)
|
|
780
|
+
return 0.0
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
# ---------------------------------------------------------------------------
|
|
784
|
+
# NBGMM log-likelihood + gradient (ptmg_ll_der)
|
|
785
|
+
# ---------------------------------------------------------------------------
|
|
786
|
+
|
|
787
|
+
def _ptmg_negll_and_grad(para, X, offset, Y, n_onetwo, ytwo, fid, cumsumy,
                         posind, posindy, nb, nind, k):
    """Negative log-likelihood and gradient for NBGMM (L-BFGS-B stage).

    Thin wrapper that unpacks the ``n_onetwo`` pair into plain ints
    (numba-friendly scalars) and delegates to the numba-compiled
    ``_ptmg_negll_and_grad_nb``.

    Returns
    -------
    (fn, gr) : float negative log-likelihood and gradient of length
        ``nb + 2`` — suitable for ``scipy.optimize.minimize(..., jac=True)``.
    """
    return _ptmg_negll_and_grad_nb(
        para, X, offset, Y,
        int(n_onetwo[0]), int(n_onetwo[1]), ytwo,
        fid, cumsumy, posind, posindy, nb, nind, k
    )
|
|
798
|
+
|
|
799
|
+
|
|
800
|
+
# ---------------------------------------------------------------------------
|
|
801
|
+
# Penalized ML optimizer (opt_pml for NBGMM)
|
|
802
|
+
# ---------------------------------------------------------------------------
|
|
803
|
+
|
|
804
|
+
def _opt_pml(X, offset, Y_vals, fid, cumsumy, posind, posindy, nb, nind, k,
             beta_init, sigma, reml=0, eps=1e-6, ord_=1):
    """Port of nebula's ``opt_pml`` C++ function.

    Delegates to the numba-compiled ``_opt_pml_nb`` and repackages its
    tuple result as a dict.  ``sigma`` is the pair (sigma_param, phi)
    passed through as two scalars.

    Note: ``posind`` and ``reml`` are accepted for signature parity with
    the C++ original but are not forwarded to ``_opt_pml_nb``.

    Returns
    -------
    dict with keys: 'beta', 'logw', 'var' (Schur-complement beta Hessian),
    'loglik', 'loglikp', 'logdet', 'iter', 'damp', 'second'.
    """
    beta, logw, vb2, loglik, loglikp, logdet, step, stepd, sec_ord = \
        _opt_pml_nb(X, offset, Y_vals, fid, cumsumy, posindy, nb, nind, k,
                    beta_init, sigma[0], sigma[1], eps, ord_)
    return {
        'beta': beta,
        'logw': logw,
        'var': vb2,
        'loglik': loglik,
        'loglikp': loglikp,
        'logdet': logdet,
        'iter': int(step),
        'damp': int(stepd),
        'second': sec_ord,
    }
|
|
824
|
+
|
|
825
|
+
|
|
826
|
+
# ---------------------------------------------------------------------------
|
|
827
|
+
# Convergence check
|
|
828
|
+
# ---------------------------------------------------------------------------
|
|
829
|
+
|
|
830
|
+
def _check_conv(repml, conv, nb, vare, min_bounds, max_bounds, cutoff=1e-8):
|
|
831
|
+
"""Port of nebula's ``check_conv``."""
|
|
832
|
+
if conv == 1:
|
|
833
|
+
if vare[0] == max_bounds[0] or vare[1] == min_bounds[1]:
|
|
834
|
+
conv = -60
|
|
835
|
+
elif np.isnan(repml['loglik']):
|
|
836
|
+
conv = -30
|
|
837
|
+
elif repml['iter'] == 50:
|
|
838
|
+
conv = -20
|
|
839
|
+
elif repml['damp'] == 11:
|
|
840
|
+
conv = -10
|
|
841
|
+
elif repml['damp'] == 12:
|
|
842
|
+
conv = -40
|
|
843
|
+
|
|
844
|
+
if nb > 1:
|
|
845
|
+
try:
|
|
846
|
+
eigvals = np.linalg.eigvalsh(repml['var'])
|
|
847
|
+
if eigvals.min() < cutoff:
|
|
848
|
+
conv = -25
|
|
849
|
+
except np.linalg.LinAlgError:
|
|
850
|
+
conv = -25
|
|
851
|
+
|
|
852
|
+
return conv
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
# ---------------------------------------------------------------------------
|
|
856
|
+
# Per-gene fitting
|
|
857
|
+
# ---------------------------------------------------------------------------
|
|
858
|
+
|
|
859
|
+
def _fit_gene_nebula_ln(gene_idx, y_gene, X, offset, fid, cumsumy_gene,
                        posind, nb, nind, k, sds, int_col, moffset,
                        min_bounds, max_bounds, mfs, cutoff_cell, kappa):
    """Fit NBGMM (NEBULA-LN) for a single gene.

    Two-stage fit: (1) bounded L-BFGS-B on the marginal likelihood to
    estimate the variance components, then (2) penalized ML
    (``_opt_pml``) for the fixed effects and their covariance.

    Returns
    -------
    tuple: (beta_rescaled, se_rescaled, sigma2, inv_phi, conv, fit, logw)
        ``fit`` is the algorithm-tracking code (1 = LN accurate,
        2 = would need NEBULA-HL refinement, 3 = low-accuracy kappa);
        ``logw`` holds the per-subject random-effect estimates.
    """
    # Decompose the gene's counts into the sparse summaries the
    # likelihood kernels consume (non-zero positions, mean count, etc.).
    posv = _call_posindy(y_gene)
    posindy = posv['posindy']
    Y = posv['Y']
    mct = posv['mct']
    n_onetwo = posv['n_onetwo']
    ytwo = posv['ytwo']

    # ord parameter: higher-order Laplace correction for very sparse
    # genes (mean count per subject < 3).
    if mct * mfs < 3:
        ord_ = 3
    else:
        ord_ = 1

    # Initial beta: intercept at log(mean count) minus the mean log
    # offset; the 1e-300 floor guards log(0) for all-zero input.
    lmct = np.log(max(mct, 1e-300))
    para_init = np.zeros(nb + 2)
    para_init[int_col] = lmct - moffset
    para_init[nb] = 1.0  # sigma_param
    para_init[nb + 1] = 1.0  # phi (cell-level overdispersion)

    # Box constraints: fixed effects in [-100, 100], variance
    # components within the user-supplied bounds.
    lower = np.concatenate([np.full(nb, -100.0), [min_bounds[0], min_bounds[1]]])
    upper = np.concatenate([np.full(nb, 100.0), [max_bounds[0], max_bounds[1]]])
    bounds = list(zip(lower, upper))

    # Stage 1: L-BFGS-B
    try:
        res = _minimize(
            _ptmg_negll_and_grad,
            para_init,
            args=(X, offset, Y, n_onetwo, ytwo, fid, cumsumy_gene,
                  posind, posindy, nb, nind, k),
            method='L-BFGS-B',
            jac=True,
            bounds=bounds,
            options={'ftol': 1e-6, 'maxiter': 200},
        )
        refp = res.x
        is_conv = 1 if res.success else 0
    except Exception:
        # Fallback: use initial values
        refp = para_init.copy()
        is_conv = 0

    conv = is_conv
    # (sigma_param, phi) estimates from stage 1.
    vare = np.array([refp[nb], refp[nb + 1]])

    # Determine cell-level predictor CV
    cell_ind_arr = _get_cell(X, fid, nb, k)
    ncell = int(cell_ind_arr.sum())
    cell_ind = np.where(cell_ind_arr > 0)[0]
    if ncell > 0:
        try:
            cv2p = _get_cv(offset, X, refp[:nb], cell_ind, ncell, nind)
        except Exception:
            cv2p = float('nan')
    else:
        cv2p = 0.0

    # gni = average cells-per-subject times phi; drives the LN/HL
    # accuracy decision below.
    gni = mfs * vare[1]

    # Determine if we need HL refinement
    fit = 1
    if (gni < cutoff_cell) or (conv == 0) or np.isnan(cv2p):
        # Would need NEBULA-HL refinement — skip for LN-only impl
        # Just note fit=2 for algorithm tracking
        fit = 2
    else:
        kappa_obs = gni / (1.0 + cv2p)
        if (kappa_obs < 20) or (kappa_obs < kappa and vare[0] < 8.0 / kappa_obs):
            fit = 3

    # Beta for PML: start from intercept init, not L-BFGS-B result
    betae = np.zeros(nb)
    betae[int_col] = lmct - moffset

    # Bias correction for intercept
    betae[int_col] -= vare[0] / 2.0

    # Stage 2: Penalized ML
    try:
        repml = _opt_pml(
            X, offset, Y, fid, cumsumy_gene, posind, posindy,
            nb, nind, k, betae, vare, reml=0, eps=1e-6, ord_=ord_,
        )
    except Exception:
        # Numerical failure in PML — mark as non-converged
        return (np.full(nb, np.nan), np.full(nb, np.nan),
                vare[0], 1.0 / vare[1] if vare[1] > 0 else np.inf,
                -30, 0, np.zeros(k))

    conv = _check_conv(repml, conv, nb, vare, min_bounds, max_bounds)

    # Invert Fisher information to get covariance
    beta_pml = repml['beta']
    logw = repml['logw']
    fisher = repml['var']  # This is vb2 (the Schur complement)

    se = np.full(nb, np.nan)
    if conv != -25:
        try:
            cov = np.linalg.inv(fisher)
            se = np.sqrt(np.maximum(np.diag(cov), 0.0))
        except np.linalg.LinAlgError:
            conv = -25

    # Rescale by column SDs (undo centering); the intercept column is
    # never scaled, so force its divisor to 1.
    sds_use = sds.copy()
    sds_use[int_col] = 1.0
    beta_rescaled = beta_pml / sds_use
    se_rescaled = se / sds_use

    sigma2 = vare[0]
    inv_phi = 1.0 / vare[1] if vare[1] > 0 else np.inf

    return beta_rescaled, se_rescaled, sigma2, inv_phi, conv, fit, logw
|
|
983
|
+
|
|
984
|
+
|
|
985
|
+
# ---------------------------------------------------------------------------
|
|
986
|
+
# Main entry points
|
|
987
|
+
# ---------------------------------------------------------------------------
|
|
988
|
+
|
|
989
|
+
def glm_sc_fit(y, cell_meta=None, design=None, sample=None,
               offset=None, norm_method='TMM', method='nebula',
               min_bounds=None, max_bounds=None,
               cpc=0.005, mincp=5, cutoff_cell=20, kappa=800,
               ncore=1, verbose=True):
    """Fit a single-cell NB gamma mixed model (NEBULA-LN).

    Parameters
    ----------
    y : AnnData, dict, or ndarray
        Count data. AnnData objects are (cells × genes); raw matrices
        should be (genes × cells).
    cell_meta : DataFrame, optional
        Cell-level metadata. Extracted from ``y.obs`` for AnnData.
    design : ndarray or str, optional
        Design matrix (cells × predictors) with an intercept column.
        If ``None``, an intercept-only model is fitted.
    sample : str or array-like
        Subject/sample identifiers. If a string, it names a column in
        *cell_meta*.
    offset : array-like, optional
        Positive per-cell scaling factors. If provided, ``norm_method``
        is ignored.
    norm_method : str
        ``'TMM'`` (default): compute per-cell offset from per-cell
        library sizes and pseudobulk TMM normalization factors.
        ``'none'``: all-ones offset (original nebula behaviour).
    method : str
        ``'nebula'`` (default): NEBULA-LN algorithm.
        NOTE(review): this argument is accepted but not referenced in
        the function body — only NEBULA-LN is implemented.
    min_bounds, max_bounds : tuple of float, optional
        Bounds for (sigma_param, phi). Defaults (1e-4, 1e-4) and
        (10, 1000).
    cpc : float
        Minimum mean counts per cell for gene filtering.
    mincp : int
        Minimum non-zero cells for gene filtering.
    cutoff_cell : float
        Threshold for NEBULA-HL fallback (cells_per_subject × phi).
    kappa : float
        Accuracy threshold for subject-level overdispersion.
    ncore : int
        Number of parallel workers (1 = sequential).
    verbose : bool
        Print progress messages.

    Returns
    -------
    dict
        DGEGLM-like fit result with keys ``'coefficients'``, ``'se'``,
        ``'dispersion'``, ``'design'``, ``'offset'``, ``'genes'``,
        ``'sigma_sample'``, ``'convergence'``, ``'method'``, etc.
        Pass to ``top_tags(fit, coef=...)`` for Wald testing.
    """
    if min_bounds is None:
        min_bounds = (1e-4, 1e-4)
    if max_bounds is None:
        max_bounds = (10.0, 1000.0)

    # --- Input handling ---
    # All three input paths below normalize to a dense genes × cells
    # float64 matrix in ``counts``.
    gene_names = None
    try:
        import anndata
        is_anndata = isinstance(y, anndata.AnnData)
    except ImportError:
        is_anndata = False

    if is_anndata:
        adata = y
        X_raw = adata.X
        if hasattr(X_raw, 'toarray'):
            X_raw = X_raw.toarray()
        counts = np.asarray(X_raw, dtype=np.float64).T  # genes × cells
        if cell_meta is None:
            cell_meta = adata.obs.copy()
        gene_names = np.array(adata.var_names)
    elif isinstance(y, dict) and 'counts' in y:
        counts = np.asarray(y['counts'], dtype=np.float64)
        if cell_meta is None and 'obs' in y:
            cell_meta = y['obs']
        if gene_names is None and 'genes' in y:
            gene_names = np.asarray(y['genes'])
    else:
        if hasattr(y, 'toarray'):
            counts = np.asarray(y.toarray(), dtype=np.float64)
        else:
            counts = np.asarray(y, dtype=np.float64)

    ngene, nind = counts.shape
    if nind < 2:
        raise ValueError("There is no more than one cell in the count matrix.")

    # --- Resolve sample IDs ---
    if sample is None:
        raise ValueError(
            "The 'sample' argument is required. Provide per-cell sample IDs."
        )
    if isinstance(sample, str):
        if cell_meta is None:
            raise ValueError(
                f"sample='{sample}' requires cell_meta with that column."
            )
        sample_ids = np.asarray(cell_meta[sample])
    else:
        sample_ids = np.asarray(sample)
    if len(sample_ids) != nind:
        raise ValueError(
            "Length of sample IDs should equal the number of cells."
        )

    # --- Save design column names before sort (which may convert DataFrame → numpy) ---
    _design_colnames = None
    if design is not None and hasattr(design, 'columns'):
        _design_colnames = list(design.columns)

    # --- Sort cells by sample (group_cell) ---
    # The likelihood kernels require cells of the same subject to be
    # contiguous; sort (stably) only if they are not already grouped.
    sample_ids_str = np.array([str(s) for s in sample_ids])
    levels = list(dict.fromkeys(sample_ids_str))  # unique, order-preserving
    sample_numeric = np.array(
        [levels.index(s) + 1 for s in sample_ids_str], dtype=np.int32
    )
    if not np.all(sample_numeric[:-1] <= sample_numeric[1:]):
        # Need to sort
        order = np.argsort(sample_numeric, kind='stable')
        counts = counts[:, order]
        sample_numeric = sample_numeric[order]
        sample_ids_str = sample_ids_str[order]
        if cell_meta is not None:
            if isinstance(cell_meta, pd.DataFrame):
                cell_meta = cell_meta.iloc[order].reset_index(drop=True)
            else:
                cell_meta = cell_meta[order]
        if offset is not None:
            offset = np.asarray(offset, dtype=np.float64)[order]
        if design is not None and not isinstance(design, str):
            design = np.asarray(design, dtype=np.float64)[order]

    k = len(levels)
    # Build fid: 0-based start index of each sample's cells + sentinel
    diffs = np.where(np.concatenate([[1], np.diff(sample_numeric)]))[0]
    fid = np.concatenate([diffs, [nind]]).astype(np.int32)

    # --- Design matrix ---
    if design is None:
        # Intercept-only model; sds placeholder is ignored because the
        # intercept column is never rescaled.
        pred = np.ones((nind, 1), dtype=np.float64)
        predn = None
        sds = np.array([0.0])
        int_col = 0
    else:
        if isinstance(design, str):
            # Formula — resolve against cell_meta
            from .utils import model_matrix
            pred = np.asarray(
                model_matrix(design, cell_meta), dtype=np.float64
            )
        else:
            pred = np.asarray(design, dtype=np.float64)
        if pred.shape[0] != nind:
            raise ValueError(
                "Design matrix rows must equal number of cells."
            )
        predn = None
        if hasattr(design, 'columns'):
            predn = list(design.columns)
        elif isinstance(design, pd.DataFrame):
            predn = list(design.columns)
        # Center/scale non-intercept columns; sds and int_col are needed
        # later to undo the scaling on fitted coefficients.
        pred, sds, int_col = _center_design(pred)

    nb = pred.shape[1]

    # --- Offset ---
    # ``cv2`` (squared CV of the offsets) is returned by _cv_offset but
    # not used further in this function.
    if offset is not None:
        # User-provided offset (positive scaling factors)
        log_offset, moffset, cv2 = _cv_offset(offset, nind)
    elif norm_method.upper() == 'TMM':
        # Pseudobulk TMM normalization → per-cell offset
        lib_size = counts.sum(axis=0).astype(np.float64)
        pb = np.zeros((ngene, k), dtype=np.float64)
        for s in range(k):
            start, end = fid[s], fid[s + 1]
            chunk = counts[:, start:end]
            if hasattr(chunk, 'toarray'):
                pb[:, s] = np.asarray(chunk.sum(axis=1)).ravel()
            else:
                pb[:, s] = chunk.sum(axis=1).ravel()
        pb_dge = make_dgelist(pb)
        pb_dge = calc_norm_factors(pb_dge)
        norm_factors = pb_dge['samples']['norm.factors'].values
        # Expand sample-level norm factors to per-cell
        cell_nf = np.empty(nind, dtype=np.float64)
        for s in range(k):
            start, end = fid[s], fid[s + 1]
            cell_nf[start:end] = norm_factors[s]
        # Floor at 0.5 to avoid log(0) for zero-count cells
        offset_raw = np.maximum(lib_size * cell_nf, 0.5)
        log_offset, moffset, cv2 = _cv_offset(offset_raw, nind)
    else:
        # No normalization (original nebula behaviour)
        log_offset, moffset, cv2 = _cv_offset(None, nind)

    # --- CPS check ---
    # mfs = mean cells per subject; LN approximation degrades when small.
    mfs = nind / k
    if mfs < 30 and verbose:
        warnings.warn(
            f"The average number of cells per subject ({mfs:.1f}) is less "
            f"than 30. NEBULA-LN may be inaccurate for small cell counts."
        )

    # --- Cumsumy ---
    # Per-gene, per-subject count sums used by the likelihood kernels.
    cumsumy = _call_cumsumy(counts, fid, k, ngene)

    # --- Gene filtering ---
    # Non-zero cell counts per gene
    # NOTE(review): by this point ``counts`` is always a dense ndarray
    # (all input branches densify), so the sparse branch below is
    # effectively dead code kept for safety.
    if hasattr(counts, 'nnz'):
        # sparse
        from scipy.sparse import issparse
        nz_per_gene = np.diff(counts.indptr) if hasattr(counts, 'indptr') else \
            np.array([(counts[g, :] != 0).sum() for g in range(ngene)])
    else:
        nz_per_gene = (counts != 0).sum(axis=1)

    # Keep genes with mean counts-per-cell above ``cpc`` AND at least
    # ``mincp`` expressing cells.
    mean_cpc = cumsumy.sum(axis=1) / nind
    mask_cpc = mean_cpc > cpc
    mask_mincp = nz_per_gene >= mincp
    gene_mask = mask_cpc & mask_mincp
    gid = np.where(gene_mask)[0]
    lgid = len(gid)

    if verbose:
        print(f"Remove {ngene - lgid} genes having low expression.")
    if lgid == 0:
        raise ValueError("No gene passed the filtering.")
    if verbose:
        print(f"Analyzing {lgid} genes with {k} subjects and {nind} cells.")

    # posind per gene: which samples have non-zero counts
    posind_per_gene = [np.where(cumsumy[g, :] > 0)[0] for g in gid]

    # --- Per-gene fitting ---
    def _fit_one(idx):
        # Fit one filtered gene; ``idx`` indexes into ``gid``.
        g = gid[idx]
        if hasattr(counts, 'toarray'):
            y_gene = np.asarray(counts[g, :].toarray()).ravel()
        else:
            y_gene = counts[g, :]
        return _fit_gene_nebula_ln(
            g, y_gene, pred, log_offset, fid, cumsumy[g, :],
            posind_per_gene[idx], nb, nind, k, sds, int_col, moffset,
            min_bounds, max_bounds, mfs, cutoff_cell, kappa,
        )

    if ncore > 1:
        # Parallel execution
        with ProcessPoolExecutor(max_workers=ncore) as executor:
            results = list(executor.map(_fit_one, range(lgid)))
    else:
        results = []
        for idx in range(lgid):
            if verbose and lgid > 100 and idx % max(1, lgid // 10) == 0:
                print(f" Gene {idx + 1}/{lgid}...")
            results.append(_fit_one(idx))

    # --- Collect results ---
    coefficients = np.zeros((lgid, nb))
    se_arr = np.zeros((lgid, nb))
    sigma_sample = np.zeros(lgid)
    cell_disp = np.zeros(lgid)  # 1/phi
    convergence = np.zeros(lgid, dtype=np.int32)
    # NOTE(review): algorithm_codes (the per-gene ``fit`` code) is
    # collected but not included in the return dict below.
    algorithm_codes = np.zeros(lgid, dtype=np.int32)

    for idx, res in enumerate(results):
        beta_r, se_r, sigma2, inv_phi, conv, fit, logw = res
        coefficients[idx, :] = beta_r
        se_arr[idx, :] = se_r
        sigma_sample[idx] = sigma2
        cell_disp[idx] = inv_phi
        convergence[idx] = conv
        algorithm_codes[idx] = fit

    # --- Resolve predictor names ---
    # Priority: names captured from the design object, then the
    # pre-sort snapshot, then generic V1..Vp placeholders.
    if predn is None:
        predn = _design_colnames
    if predn is None:
        if design is not None and hasattr(design, 'columns'):
            predn = list(design.columns)
    if predn is None:
        predn = [f"V{i+1}" for i in range(nb)]

    # --- Gene annotation DataFrame ---
    if gene_names is not None:
        genes_df = pd.DataFrame({'gene': gene_names[gid]})
    else:
        genes_df = None

    # --- Average log abundance for filtered genes ---
    ave_log_abund = np.log2(mean_cpc[gid] + 0.5)

    # --- DGEGLM-like return ---
    return {
        'coefficients': coefficients,
        'se': se_arr,
        'dispersion': cell_disp,
        'sigma_sample': sigma_sample,
        'convergence': convergence,
        'design': pred,
        'offset': log_offset,
        'genes': genes_df,
        'gene_mask': gene_mask,
        'method': 'nebula_ln',
        'ncells': nind,
        'nsamples': k,
        'predictor_names': predn,
        'sample_map': sample_ids_str,
        'samples_unique': np.array(levels),
        'ave_log_abundance': ave_log_abund,
    }
|
|
1304
|
+
|
|
1305
|
+
|
|
1306
|
+
def shrink_sc_disp(fit, counts=None, covariate=None, robust=True):
    """Empirical Bayes shrinkage of cell-level NB dispersion.

    Shrinks the per-gene NB overdispersion parameter phi toward a
    (possibly trended) prior using limma's squeezeVar framework.

    Parameters
    ----------
    fit : dict
        Output from ``glm_sc_fit()``.
    counts : ndarray or sparse matrix, optional
        Gene-by-cell count matrix (same genes/ordering as used in
        ``glm_sc_fit``). Used to compute log-mean abundance as
        covariate for the trended prior. If *None* and
        ``fit['ave_log_abundance']`` exists, that is used instead.
    covariate : array-like, optional
        Custom covariate for the trended prior. Overrides the
        abundance covariate derived from *counts*.
    robust : bool
        Use robust estimation (default True). Protects against
        outlier genes with extremely high or low dispersion.

    Returns
    -------
    dict
        The input *fit* dict, updated in-place with new keys:

        - ``phi_raw`` : raw per-gene phi (= 1/dispersion)
        - ``phi_post`` : posterior (shrunk) phi
        - ``phi_prior`` : prior phi (scalar or trended)
        - ``df_residual`` : residual degrees of freedom
        - ``df_prior_phi`` : prior df from empirical Bayes
        - ``dispersion_shrunk`` : 1/phi_post (shrunk dispersion)
    """
    import warnings
    from .limma_port import squeeze_var

    dispersion = fit['dispersion']
    ngenes = len(dispersion)

    # Convert to phi = 1/dispersion
    with np.errstate(divide='ignore'):
        phi_raw = np.where(dispersion > 0, 1.0 / dispersion, np.inf)

    # Convergence mask: only use converged genes for prior estimation
    conv_mask = fit['convergence'] == 1

    # Floor phi at a small positive value; mark inf as NaN
    phi_floor = 1e-8
    phi_use = np.maximum(phi_raw.copy(), phi_floor)
    phi_use[~np.isfinite(phi_use)] = np.nan

    # Residual degrees of freedom: N - p - (K - 1)
    # (cells minus fixed effects minus free random-effect levels),
    # floored at 1.
    n_cells = fit['ncells']
    n_predictors = fit['design'].shape[1]
    n_samples = fit['nsamples']
    df_residual = n_cells - n_predictors - (n_samples - 1)
    df_residual = max(df_residual, 1)

    # Determine covariate for trended prior
    # Priority: explicit covariate > recomputed from counts > the
    # abundance stored by glm_sc_fit > no trend.
    if covariate is not None:
        cov = np.asarray(covariate, dtype=np.float64)
    elif counts is not None:
        if hasattr(counts, 'toarray'):
            mean_cpc = np.asarray(counts.mean(axis=1)).ravel()
        else:
            mean_cpc = counts.mean(axis=1).ravel()
        gene_mask = fit['gene_mask']
        cov = np.log2(mean_cpc[gene_mask] + 0.5)
    elif 'ave_log_abundance' in fit:
        cov = fit['ave_log_abundance']
    else:
        cov = None

    # Filter to converged genes with finite phi
    ok_mask = conv_mask & np.isfinite(phi_use)
    idx_ok = np.where(ok_mask)[0]

    if len(idx_ok) < 3:
        # Too few genes to estimate a prior: return unshrunk values.
        warnings.warn("Fewer than 3 converged genes; skipping shrinkage.")
        fit['phi_raw'] = phi_raw
        fit['phi_post'] = phi_raw.copy()
        fit['phi_prior'] = np.nan
        fit['df_residual'] = df_residual
        fit['df_prior_phi'] = 0.0
        fit['dispersion_shrunk'] = fit['dispersion'].copy()
        return fit

    phi_ok = phi_use[idx_ok]
    cov_ok = cov[idx_ok] if cov is not None else None

    # Call squeeze_var with scalar df (same for all genes).
    # Fall back gracefully: trended → untrended → no shrinkage.
    sv = None
    for cov_attempt in ([cov_ok, None] if cov_ok is not None else [None]):
        try:
            sv = squeeze_var(phi_ok, df=float(df_residual),
                             covariate=cov_attempt, robust=robust)
            break
        except (ValueError, RuntimeError):
            continue
    if sv is None:
        # Last resort: non-robust, untrended shrinkage.
        try:
            sv = squeeze_var(phi_ok, df=float(df_residual),
                             covariate=None, robust=False)
        except (ValueError, RuntimeError):
            warnings.warn("squeeze_var failed; returning unshrunk estimates.")
            fit['phi_raw'] = phi_raw
            fit['phi_post'] = phi_raw.copy()
            fit['phi_prior'] = np.nanmedian(phi_ok)
            fit['df_residual'] = df_residual
            fit['df_prior_phi'] = 0.0
            fit['dispersion_shrunk'] = fit['dispersion'].copy()
            return fit

    # Map results back to full gene array
    phi_post = np.full(ngenes, np.nan)
    phi_post[idx_ok] = sv['var_post']

    phi_prior_full = np.full(ngenes, np.nan)
    if isinstance(sv['var_prior'], np.ndarray):
        # Trended prior: one value per (converged) gene.
        phi_prior_full[idx_ok] = sv['var_prior']
        median_prior = np.nanmedian(sv['var_prior'])
    else:
        # Constant prior: broadcast the scalar.
        phi_prior_full[:] = sv['var_prior']
        median_prior = sv['var_prior']

    # Non-converged genes get the prior value
    phi_post[~ok_mask] = median_prior
    phi_prior_full[~ok_mask] = median_prior

    # Store results
    fit['phi_raw'] = phi_raw
    fit['phi_post'] = phi_post
    fit['phi_prior'] = phi_prior_full
    fit['df_residual'] = df_residual
    fit['df_prior_phi'] = sv['df_prior']
    with np.errstate(divide='ignore'):
        fit['dispersion_shrunk'] = np.where(
            phi_post > 0, 1.0 / phi_post, np.inf
        )

    return fit
|
|
1449
|
+
|
|
1450
|
+
|
|
1451
|
+
def glm_sc_test(fit, coef=None, contrast=None):
    """Wald test on a ``glm_sc_fit`` result.

    Parameters
    ----------
    fit : dict
        Output from ``glm_sc_fit()``.
    coef : int, optional
        0-based column index of the coefficient to test. Default: last
        column.
    contrast : ndarray, optional
        Custom contrast vector (length p). If given, *coef* is ignored.

    Returns
    -------
    dict with key ``'table'`` containing a DataFrame with columns:
        logFC, SE, z, PValue, FDR, sigma_sample, dispersion, converged.
    """
    coefficients = fit['coefficients']
    se_arr = fit['se']
    ngenes, nb = coefficients.shape

    if contrast is not None:
        contrast = np.asarray(contrast, dtype=np.float64)
        logFC = coefficients @ contrast
        # NOTE: only per-coefficient SEs are stored in the fit, so the
        # contrast SE below ignores covariances between coefficients.
        se = np.sqrt(np.maximum(
            np.sum((se_arr ** 2) * (contrast ** 2), axis=1), 0
        ))
    else:
        if coef is None:
            coef = nb - 1
        logFC = coefficients[:, coef]
        se = se_arr[:, coef]

    # Wald z-statistic. Suppress divide warnings: se == 0 yields ±inf
    # (p-value 0); NaN SEs from non-converged genes propagate to NaN
    # p-values and are excluded from the FDR adjustment below.
    with np.errstate(divide='ignore', invalid='ignore'):
        z = logFC / se
    pvalue = _chi2.sf(z ** 2, 1)

    # FDR (Benjamini-Hochberg) over genes with a defined p-value
    n = len(pvalue)
    valid = ~np.isnan(pvalue)
    fdr = np.full(n, np.nan)
    if valid.any():
        from statsmodels.stats.multitest import multipletests
        _, fdr_vals, _, _ = multipletests(pvalue[valid], method='fdr_bh')
        fdr[valid] = fdr_vals

    table = pd.DataFrame({
        'logFC': logFC,
        'SE': se,
        'z': z,
        'PValue': pvalue,
        'FDR': fdr,
        'sigma_sample': fit['sigma_sample'],
        'dispersion': fit['dispersion'],
        'converged': fit['convergence'],
    })

    genes = fit.get('genes')
    if genes is not None:
        # ``glm_sc_fit`` stores gene annotation as a one-column
        # DataFrame; assigning a DataFrame directly to ``.index``
        # raises in pandas ("Index data must be 1-dimensional"), so
        # extract the first column (or use an array-like as-is).
        if isinstance(genes, pd.DataFrame):
            table.index = pd.Index(genes.iloc[:, 0])
        else:
            table.index = pd.Index(np.asarray(genes))

    return {'table': table}
|