edgepython 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edgepython/sc_fit.py ADDED
@@ -0,0 +1,1511 @@
1
+ # This code was written by Claude (Anthropic). The project was directed by Lior Pachter.
2
+ """Single-cell NB mixed model fitting (NEBULA-LN port).
3
+
4
+ Implements ``glm_sc_fit()`` and ``glm_sc_test()`` for cell-level
5
+ negative binomial gamma mixed model (NBGMM) analysis of multi-subject
6
+ single-cell RNA-seq data.
7
+
8
+ Reference
9
+ ---------
10
+ He L, Davila-Velderrain J, Sumida TS, Hafler DA, Bhatt DL et al.
11
+ NEBULA is a fast negative binomial mixed model for differential or
12
+ co-expression analysis of large-scale multi-subject single-cell data.
13
+ *Communications Biology*, 4:629, 2021.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import math
19
+ import warnings
20
+ from concurrent.futures import ProcessPoolExecutor
21
+ from math import lgamma as _lgamma
22
+ from typing import Any
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ from numba import njit
27
+ from scipy.optimize import minimize as _minimize
28
+ from scipy.special import digamma as _digamma, gammaln as _gammaln
29
+ from scipy.stats import chi2 as _chi2
30
+
31
+ from .normalization import calc_norm_factors
32
+ from .dgelist import make_dgelist
33
+
34
+
35
+ # ---------------------------------------------------------------------------
36
+ # Numba-accelerated core functions
37
+ # ---------------------------------------------------------------------------
38
+
39
@njit(cache=True)
def _digamma_nb(x):
    """Digamma (psi) function for x > 0. Accurate to ~15 digits.

    Uses the recurrence psi(x) = psi(x + 1) - 1/x to push the argument
    above 7, then evaluates the standard asymptotic (Bernoulli-number)
    series in 1/x**2 via a Horner-form polynomial.
    """
    shift = 0.0
    # Recurrence: accumulate -1/x while moving x into the asymptotic range.
    while x < 7.0:
        shift -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    # Leading asymptotic terms: log(x) - 1/(2x).
    total = shift + (math.log(x) - 0.5 / x)
    # Tail: sum of B_{2n} / (2n * x^{2n}) terms, nested Horner evaluation.
    total -= inv2 * (1.0/12.0 - inv2 * (1.0/120.0 - inv2 * (1.0/252.0
        - inv2 * (1.0/240.0 - inv2 * (5.0/660.0 - inv2 * 691.0/32760.0)))))
    return total
51
+
52
+
53
@njit(cache=True)
def _ptmg_negll_and_grad_nb(para, X, offset, Y, n_one, n_two, ytwo,
                            fid, cumsumy, posind, posindy, nb, nind, k):
    """Numba-compiled NBGMM negative log-likelihood + gradient.

    Parameters (shapes inferred from the indexing in this function):
      para    : (nb + 2,) packed parameter vector [beta, sigma_param, phi]
      X       : (nind, nb) design matrix
      offset  : (nind,) per-cell log-offset
      Y       : counts of the non-zero cells, aligned with ``posindy``
      n_one   : number of cells with count exactly 1
      n_two   : number of cells with count exactly 2
      ytwo    : counts strictly greater than 2
      fid     : (k + 1,) cell-range boundaries per subject (0-based)
      cumsumy : (k,) per-subject total counts for this gene
      posind  : indices of subjects entering the lgamma/digamma corrections
      posindy : indices of the non-zero cells
      nb, nind, k : number of predictors, cells, subjects

    Returns ``(fn, gr)``: the negative log-likelihood and its gradient
    w.r.t. ``para``. If sigma_param implies a non-positive exp(sigma)-1,
    a large penalty (1e30) with zero gradient is returned instead.
    """
    beta = para[:nb]
    sigma_param = para[nb]
    phi = para[nb + 1]

    # Reparameterization: alpha = 1/(e^sigma - 1), lam = alpha / sqrt(e^sigma).
    exps = math.exp(sigma_param)
    exps_m1 = exps - 1.0
    if exps_m1 <= 0:
        # Out-of-domain sigma_param: return a penalty value.
        return 1e30, np.zeros(nb + 2)
    alpha = 1.0 / exps_m1
    exps_s = math.sqrt(exps)
    lam = alpha / exps_s
    gamma = phi

    # Exact derivatives of alpha and lam w.r.t. sigma_param (chain rule):
    #   d alpha/d sigma = -e^s / (e^s - 1)^2
    #   d lam/d sigma   = (1 - 3 e^s) / (2 sqrt(e^s) (e^s - 1)^2)
    exps_m1_sq = exps_m1 * exps_m1
    alpha_pr = -exps / exps_m1_sq
    lambda_pr = (1.0 - 3.0 * exps) / (2.0 * exps_s * exps_m1_sq)

    log_lambda = math.log(lam)
    log_gamma = math.log(gamma) if gamma > 0 else -1e30

    nelem = len(posindy)

    # Linear predictor: xtb = offset + X @ beta
    xtb = np.empty(nind)
    for i in range(nind):
        s = offset[i]
        for j in range(nb):
            s += X[i, j] * beta[j]
        xtb[i] = s

    # term1 = sum y_j * eta_j (only non-zero)
    term1 = 0.0
    for i in range(nelem):
        term1 += xtb[posindy[i]] * Y[i]

    # exp(xtb) with overflow protection (cap the exponent at 500)
    extb = np.empty(nind)
    for i in range(nind):
        v = xtb[i]
        extb[i] = math.exp(min(v, 500.0))

    # Per-sample sums of exp(linear predictor)
    cumsumxtb = np.empty(k)
    for s in range(k):
        start = fid[s]
        end = fid[s + 1]
        acc = 0.0
        for i in range(start, end):
            acc += extb[i]
        cumsumxtb[s] = acc

    # Subject-level posterior quantities:
    #   ystar  = total counts + alpha, mustar = total fitted mean + lam
    ystar = np.empty(k)
    mustar = np.empty(k)
    mustar_log = np.empty(k)
    ymustar = np.empty(k)
    ymumustar = np.empty(k)
    for s in range(k):
        ystar[s] = cumsumy[s] + alpha
        mustar[s] = cumsumxtb[s] + lam
        mustar_log[s] = math.log(mustar[s])
        ymustar[s] = ystar[s] / mustar[s]
        ymumustar[s] = ymustar[s] / mustar[s]

    for s in range(k):
        term1 -= ystar[s] * mustar_log[s]
    term1 += k * alpha * log_lambda
    term1 += nind * gamma * log_gamma

    # gstar = gamma + y_j
    gstar_vec = np.full(nind, gamma)
    for i in range(nelem):
        gstar_vec[posindy[i]] += Y[i]

    # sum_elgcp[j] = ymustar[s(j)] * extb[j]
    sum_elgcp = np.empty(nind)
    for s in range(k):
        start = fid[s]
        end = fid[s + 1]
        val = ymustar[s]
        for i in range(start, end):
            sum_elgcp[i] = val * extb[i]

    for i in range(nind):
        term1 += sum_elgcp[i]

    # Cell-level mixture quantities and their log-sum contribution.
    sum_elgcp_pg = np.empty(nind)
    gstar_phiymustar = np.empty(nind)
    log_sum_elgcp_pg = np.empty(nind)
    slpey = 0.0
    for i in range(nind):
        sum_elgcp_pg[i] = sum_elgcp[i] + gamma
        gstar_phiymustar[i] = gstar_vec[i] / sum_elgcp_pg[i]
        log_sum_elgcp_pg[i] = math.log(sum_elgcp_pg[i])
        slpey += log_sum_elgcp_pg[i]

    term1 -= gamma * slpey
    for i in range(nelem):
        term1 -= Y[i] * log_sum_elgcp_pg[posindy[i]]

    # Negative log-likelihood before the lgamma corrections below.
    fn_cpp = -term1

    # --- Gradient ---
    dbeta_42 = np.zeros(k)
    xexb_f = np.zeros((nb, k))
    dbeta_41 = np.zeros((nb, k))

    # Accumulate per-subject pieces of the beta gradient.
    for s in range(k):
        start = fid[s]
        end = fid[s + 1]
        for i in range(start, end):
            gp = gstar_phiymustar[i]
            ext_i = extb[i]
            dbeta_42[s] += gp * ext_i
            for j in range(nb):
                xexb = X[i, j] * ext_i
                xexb_f[j, s] += xexb
                dbeta_41[j, s] += gp * xexb

    # db starts from X^T @ y over the non-zero cells.
    db = np.zeros(nb)
    for i in range(nelem):
        for j in range(nb):
            db[j] += X[posindy[i], j] * Y[i]

    for s in range(k):
        val = ymumustar[s] * (dbeta_42[s] - cumsumxtb[s])
        for j in range(nb):
            db[j] += xexb_f[j, s] * val - dbeta_41[j, s] * ymustar[s]

    # Pieces of the sigma_param derivative.
    ldm = log_lambda * k
    for s in range(k):
        ldm -= mustar_log[s]
    adlmy = exps_s * k
    for s in range(k):
        adlmy -= ymustar[s]

    dtau = 0.0
    dtau_lp = 0.0
    for s in range(k):
        dtau += alpha_pr * (cumsumxtb[s] - dbeta_42[s]) / mustar[s]
        dtau_lp += ymumustar[s] * (dbeta_42[s] - cumsumxtb[s])
    dtau += lambda_pr * dtau_lp
    dtau += alpha_pr * ldm + lambda_pr * adlmy

    # Derivative w.r.t. phi (cell-level overdispersion), before corrections.
    dtau2 = log_gamma * nind + nind - slpey
    for i in range(nind):
        dtau2 -= gstar_phiymustar[i]

    # Negate: we return the gradient of the NEGATIVE log-likelihood.
    gr = np.zeros(nb + 2)
    for j in range(nb):
        gr[j] = -db[j]
    gr[nb] = -dtau
    gr[nb + 1] = -dtau2

    # --- R-level lgamma corrections ---
    # Counts of 1 and 2 are handled in closed form (log terms below);
    # only counts > 2 need explicit lgamma evaluation.
    n_one_plus_two = n_one + n_two

    lgamma_fn = 0.0
    for s_idx in range(len(posind)):
        lgamma_fn += math.lgamma(cumsumy[posind[s_idx]] + alpha)
    lgamma_fn -= len(posind) * math.lgamma(alpha)
    for v_idx in range(len(ytwo)):
        lgamma_fn += math.lgamma(ytwo[v_idx] + gamma)
    lgamma_fn -= (nelem - n_one_plus_two) * math.lgamma(gamma)
    if n_one_plus_two > 0:
        lgamma_fn += n_one_plus_two * math.log(gamma)
    if n_two > 0:
        lgamma_fn += n_two * math.log(gamma + 1.0)

    fn = fn_cpp - lgamma_fn

    # --- Digamma corrections (gradient counterparts of the lgamma terms) ---
    dig_alpha_sum = 0.0
    for s_idx in range(len(posind)):
        dig_alpha_sum += _digamma_nb(cumsumy[posind[s_idx]] + alpha)
    dig_alpha_sum -= len(posind) * _digamma_nb(alpha)

    dig_gamma_sum = 0.0
    for v_idx in range(len(ytwo)):
        dig_gamma_sum += _digamma_nb(ytwo[v_idx] + gamma)
    dig_gamma_sum -= (nelem - n_one_plus_two) * _digamma_nb(gamma)
    if n_one_plus_two > 0:
        dig_gamma_sum += n_one_plus_two / gamma
    if n_two > 0:
        dig_gamma_sum += n_two / (gamma + 1.0)

    gr[nb] -= alpha_pr * dig_alpha_sum
    gr[nb + 1] -= dig_gamma_sum

    return fn, gr
246
+
247
+
248
@njit(cache=True)
def _compute_pml_loglik_nb(offset, X, beta, logw, fid, k, posindy, Y,
                           cumsumy, gamma, alpha, lam, nind, nb):
    """Numba-compiled PML log-likelihood evaluation.

    Evaluates the penalized (h-)likelihood at the current ``beta`` and
    per-subject random effects ``logw``.

    Returns
    -------
    (loglik, extb) : tuple
        ``loglik`` is the penalized log-likelihood; ``extb`` is
        ``exp(offset + X @ beta + logw[subject])`` with the exponent
        capped at 500 for overflow protection (reused by the caller).
    """
    nelem = len(posindy)

    # extb_lin = offset + X @ beta
    extb_lin = np.empty(nind)
    for i in range(nind):
        s = offset[i]
        for j in range(nb):
            s += X[i, j] * beta[j]
        extb_lin[i] = s

    # sum of y_j * eta_j over non-zero cells
    loglik = 0.0
    for i in range(nelem):
        loglik += extb_lin[posindy[i]] * Y[i]

    # logw @ cumsumy
    for s in range(k):
        loglik += logw[s] * cumsumy[s]

    # Add logw to linear predictor per sample
    for s in range(k):
        start = fid[s]
        end = fid[s + 1]
        for i in range(start, end):
            extb_lin[i] += logw[s]

    # exp(extb_lin), exponent capped at 500
    extb = np.empty(nind)
    for i in range(nind):
        extb[i] = math.exp(min(extb_lin[i], 500.0))

    # NB denominator terms: -gamma * log(mu + gamma) per cell ...
    for i in range(nind):
        extbphil = math.log(extb[i] + gamma)
        loglik -= gamma * extbphil

    # ... and -y * log(mu + gamma) for the non-zero cells
    for i in range(nelem):
        loglik -= Y[i] * math.log(extb[posindy[i]] + gamma)

    # Gamma prior penalty on the subject random effects w = exp(logw)
    for s in range(k):
        loglik += alpha * logw[s] - lam * math.exp(logw[s])

    return loglik, extb
293
+
294
+
295
@njit(cache=True)
def _opt_pml_nb(X, offset, Y_vals, fid, cumsumy, posindy, nb, nind, k,
                beta_init, sigma0, sigma1, eps, ord_):
    """Numba-compiled PML optimizer.

    Maximizes the penalized likelihood over ``beta`` (fixed effects) and
    ``logw`` (per-subject log random effects) by damped Newton steps,
    using the Schur complement of the blocked Hessian so only an
    (nb x nb) solve is needed per iteration.  ``sigma0``/``sigma1`` are
    the (fixed) subject- and cell-level dispersion parameters from the
    L-BFGS-B stage; ``ord_`` selects the order of the Laplace
    correction (>1 adds third-derivative terms, >2 fourth-derivative).

    Returns (beta, logw, vb2, loglik, loglikp, logdet, step, stepd, sec_ord).
    """
    # Same reparameterization as the likelihood stage:
    # alpha = 1/(e^sigma0 - 1), lam = 1/(sqrt(e^sigma0)(e^sigma0 - 1)).
    exps = math.exp(sigma0)
    alpha = 1.0 / (exps - 1.0)
    lam = 1.0 / (math.sqrt(exps) * (exps - 1.0))
    gamma = sigma1

    logw = np.zeros(k)
    beta = beta_init.copy()

    # gstar: gamma + y for non-zero cells
    gstar = np.full(nind, gamma)
    nelem = len(posindy)
    for i in range(nelem):
        gstar[posindy[i]] += Y_vals[i]

    # Precompute yx = X^T @ y (only non-zero entries)
    yx = np.zeros(nb)
    for i in range(nelem):
        for j in range(nb):
            yx[j] += X[posindy[i], j] * Y_vals[i]

    # Initial log-likelihood
    loglik, extb = _compute_pml_loglik_nb(
        offset, X, beta, logw, fid, k, posindy, Y_vals,
        cumsumy, gamma, alpha, lam, nind, nb
    )

    loglikp = 0.0
    step = 0
    maxstep = 50       # Newton iteration cap (iter == 50 flags non-convergence)
    maxstd = 10        # damping sub-iteration cap
    convd = 0.01       # gradient tolerance used when damping gives up
    stepd = 0

    vb = np.zeros((nb, nb))
    vb2 = np.zeros((nb, nb))
    vw = np.zeros(k)
    vwb = np.zeros((k, nb))

    # Newton loop: continue while the likelihood improves by more than eps.
    while step == 0 or (loglik - loglikp > eps and step < maxstep):
        step += 1

        damp = np.ones(nb)
        damp_w = np.ones(k)

        # gstar_extb_phi = gstar / (1 + gamma/extb)
        gstar_extb_phi = np.empty(nind)
        for i in range(nind):
            if extb[i] < 1e-300:
                gstar_extb_phi[i] = 0.0
            else:
                gstar_extb_phi[i] = gstar[i] / (1.0 + gamma / extb[i])

        # Gradient w.r.t. beta: db = yx - X^T @ gstar_extb_phi
        db = np.empty(nb)
        for j in range(nb):
            s = yx[j]
            for i in range(nind):
                s -= X[i, j] * gstar_extb_phi[i]
            db[j] = s

        # Gradient w.r.t. logw
        dw = np.empty(k)
        w = np.empty(k)
        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            acc = 0.0
            for i in range(start, end):
                acc += gstar_extb_phi[i]
            w[s] = math.exp(logw[s])
            dw[s] = cumsumy[s] - acc - lam * w[s] + alpha

        # Hessian diagonal w.r.t. logw
        gstar_extb_phi2 = np.empty(nind)
        for i in range(nind):
            denom = extb[i] + gamma
            if denom < 1e-300:
                gstar_extb_phi2[i] = 0.0
            else:
                gstar_extb_phi2[i] = gstar_extb_phi[i] / denom

        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            acc = 0.0
            for i in range(start, end):
                acc += gstar_extb_phi2[i]
            vw[s] = gamma * acc + lam * w[s]

        # Cross-term Hessian vwb (k × nb)
        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            for j in range(nb):
                acc = 0.0
                for i in range(start, end):
                    acc += X[i, j] * gstar_extb_phi2[i]
                vwb[s, j] = gamma * acc

        # Hessian w.r.t. beta (nb × nb), symmetric — fill upper then mirror
        for ii in range(nb):
            for jj in range(ii, nb):
                acc = 0.0
                for i in range(nind):
                    acc += X[i, ii] * gstar_extb_phi2[i] * X[i, jj]
                vb[ii, jj] = gamma * acc
                if ii != jj:
                    vb[jj, ii] = vb[ii, jj]

        # Floor vw to avoid division by zero
        for s in range(k):
            if vw[s] < 1e-15:
                vw[s] = 1e-15

        # Schur complement: vb2 = vb - vwb^T @ diag(1/vw) @ vwb
        for ii in range(nb):
            for jj in range(nb):
                acc = 0.0
                for s in range(k):
                    acc += vwb[s, ii] * vwb[s, jj] / vw[s]
                vb2[ii, jj] = vb[ii, jj] - acc

        # Newton step
        dwvw = np.empty(k)
        for s in range(k):
            dwvw[s] = dw[s] / vw[s]

        # rhs = db - vwb^T @ dwvw
        rhs = np.empty(nb)
        for j in range(nb):
            acc = 0.0
            for s in range(k):
                acc += vwb[s, j] * dwvw[s]
            rhs[j] = db[j] - acc

        # Regularize if needed
        for ii in range(nb):
            if abs(vb2[ii, ii]) < 1e-10:
                vb2[ii, ii] += 1e-8

        stepbeta = np.linalg.solve(vb2, rhs)

        # steplogw = dwvw - (vwb @ stepbeta) / vw (vw already floored above)
        steplogw = np.empty(k)
        for s in range(k):
            acc = 0.0
            for j in range(nb):
                acc += vwb[s, j] * stepbeta[j]
            steplogw[s] = dwvw[s] - acc / vw[s]

        new_b = beta + stepbeta
        new_w = logw + steplogw

        loglikp = loglik
        loglik, extb = _compute_pml_loglik_nb(
            offset, X, new_b, new_w, fid, k, posindy, Y_vals,
            cumsumy, gamma, alpha, lam, nind, nb
        )

        likdif = loglik - loglikp
        stepd = 0
        minstep = 40.0

        # Damping loop: if the full Newton step decreased the likelihood
        # (or overflowed), repeatedly halve the step until it improves.
        while likdif < 0 or math.isinf(loglik):
            stepd += 1
            minstep /= 2.0

            if stepd > maxstd:
                # Give up damping; keep the previous point.  stepd is
                # bumped once more if the gradient is still large, so the
                # caller can distinguish the two failure modes (11 vs 12).
                likdif = 0.0
                loglik = loglikp
                mabsdb = 0.0
                mabsdw = 0.0
                for j in range(nb):
                    if abs(db[j]) > mabsdb:
                        mabsdb = abs(db[j])
                for s in range(k):
                    if abs(dw[s]) > mabsdw:
                        mabsdw = abs(dw[s])
                if mabsdb > convd or mabsdw > convd:
                    stepd += 1
                break

            # Halve moderate steps; clamp extreme steps to +/- minstep.
            for i in range(nb):
                if -40 < stepbeta[i] < 40:
                    damp[i] /= 2.0
                    new_b[i] = beta[i] + stepbeta[i] * damp[i]
                else:
                    new_b[i] = beta[i] + (minstep if stepbeta[i] > 0 else -minstep)

            for s in range(k):
                if -40 < steplogw[s] < 40:
                    damp_w[s] /= 2.0
                    new_w[s] = logw[s] + steplogw[s] * damp_w[s]
                else:
                    new_w[s] = logw[s] + (minstep if steplogw[s] > 0 else -minstep)

            loglik, extb = _compute_pml_loglik_nb(
                offset, X, new_b, new_w, fid, k, posindy, Y_vals,
                cumsumy, gamma, alpha, lam, nind, nb
            )
            likdif = loglik - loglikp

        beta = new_b
        logw = new_w

    # Log-determinant of the logw Hessian block (for the Laplace approximation)
    logdet = 0.0
    for s in range(k):
        logdet += math.log(max(abs(vw[s]), 1e-300))

    # Second-order correction (higher-order Laplace terms)
    sec_ord = 0.0
    if ord_ > 1:
        # Rebuild gstar_extb_phi at the converged point.
        for i in range(nind):
            if extb[i] < 1e-300:
                gstar_extb_phi[i] = 0.0
            else:
                gstar_extb_phi[i] = gstar[i] / (1.0 + gamma / extb[i])
        extbg = np.empty(nind)
        for i in range(nind):
            extbg[i] = extb[i] + gamma
            if extbg[i] < 1e-300:
                gstar_extb_phi[i] = 0.0
            else:
                gstar_extb_phi[i] /= extbg[i]
        # Refresh vw at the converged point.
        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            acc = 0.0
            for i in range(start, end):
                acc += gstar_extb_phi[i]
            vw[s] = gamma * acc + lam * math.exp(logw[s])
            if vw[s] < 1e-15:
                vw[s] = 1e-15
        vws = np.empty(k)
        for s in range(k):
            vws[s] = vw[s] * vw[s]

        # Each division by extbg raises the power of the denominator by one;
        # gstar_extb_phi is progressively reused for higher derivatives.
        for i in range(nind):
            if extbg[i] < 1e-300:
                gstar_extb_phi[i] = 0.0
            else:
                gstar_extb_phi[i] /= extbg[i]
        third_der = np.empty(k)
        for s in range(k):
            start = fid[s]
            end = fid[s + 1]
            acc = 0.0
            for i in range(start, end):
                acc += gstar_extb_phi[i] * (gamma - extb[i])
            third_der[s] = gamma * acc + lam * math.exp(logw[s])
        acc = 0.0
        for s in range(k):
            acc += third_der[s] * third_der[s] / (vws[s] * vw[s])
        sec_ord += 5.0 / 24.0 * acc

        if ord_ > 2:
            for i in range(nind):
                if extbg[i] < 1e-300:
                    gstar_extb_phi[i] = 0.0
                else:
                    gstar_extb_phi[i] /= extbg[i]
            four_der = np.empty(k)
            for s in range(k):
                start = fid[s]
                end = fid[s + 1]
                acc = 0.0
                for i in range(start, end):
                    extbp = extb[i] * extb[i]
                    acc += gstar_extb_phi[i] * (gamma*gamma + extbp - 4*gamma*extb[i])
                four_der[s] = gamma * acc + lam * math.exp(logw[s])
            acc = 0.0
            for s in range(k):
                acc += four_der[s] / vws[s]
            sec_ord -= acc / 8.0

            for i in range(nind):
                if extbg[i] < 1e-300:
                    gstar_extb_phi[i] = 0.0
                else:
                    gstar_extb_phi[i] /= extbg[i]
            for s in range(k):
                start = fid[s]
                end = fid[s + 1]
                acc2 = 0.0
                for i in range(start, end):
                    extbp = extb[i] * extb[i]
                    acc2 += gstar_extb_phi[i] * (
                        gamma**3 - 11*gamma*gamma*extb[i]
                        + 11*gamma*extbp - extbp*extb[i]
                    )
                four_der[s] = gamma * acc2 + lam * math.exp(logw[s])
            acc = 0.0
            for s in range(k):
                acc += four_der[s] * third_der[s] / (vws[s] * vws[s])
            sec_ord += 7.0 / 48.0 * acc

    return beta, logw, vb2, loglik, loglikp, logdet, step, stepd, sec_ord
600
+
601
+
602
@njit(cache=True)
def _get_cell_nb(X, fid, nb, k):
    """Numba-compiled cell-level covariate detection.

    A covariate is "cell-level" if its values vary within at least one
    subject's block of cells.  Returns a float vector of 0/1 flags,
    one per design column.
    """
    flags = np.zeros(nb)
    for col in range(nb):
        for grp in range(k):
            lo = fid[grp]
            hi = fid[grp + 1]
            first = X[lo, col]
            varies = False
            # Any deviation from the block's first value means the
            # covariate is not constant within this subject.
            for row in range(lo + 1, hi):
                if X[row, col] != first:
                    varies = True
                    break
            if varies:
                flags[col] = 1.0
                break
    return flags
620
+
621
+ # ---------------------------------------------------------------------------
622
+ # Utility helpers
623
+ # ---------------------------------------------------------------------------
624
+
625
+ def _center_design(pred: np.ndarray):
626
+ """Center design columns and scale to unit variance.
627
+
628
+ Matches nebula's ``center_m`` C++ function exactly.
629
+
630
+ Returns
631
+ -------
632
+ pred_centered : ndarray (n × p)
633
+ sds : ndarray (p,)
634
+ Column standard deviations (population, not sample).
635
+ The intercept column gets sd=0; zero-vector columns get sd=-1.
636
+ int_col : int
637
+ 0-based index of the intercept column.
638
+ """
639
+ pred = np.asarray(pred, dtype=np.float64).copy()
640
+ n, p = pred.shape
641
+ means = pred.mean(axis=0)
642
+ cm = pred - means
643
+ sds = np.sqrt((cm * cm).mean(axis=0))
644
+
645
+ int_col = None
646
+ for i in range(p):
647
+ if sds[i] > 0:
648
+ cm[:, i] /= sds[i]
649
+ else:
650
+ if pred[0, i] != 0:
651
+ # intercept column: fill with ones
652
+ cm[:, i] = 1.0
653
+ sds[i] = 0.0
654
+ int_col = i
655
+ else:
656
+ sds[i] = -1.0
657
+
658
+ if int_col is None:
659
+ raise ValueError("The design matrix must include an intercept term.")
660
+ if (sds == 0).sum() > 1 or (sds < 0).any():
661
+ raise ValueError(
662
+ "Some predictors have zero variation or a zero vector."
663
+ )
664
+
665
+ return cm, sds, int_col
666
+
667
+
668
+ def _cv_offset(offset: np.ndarray | None, nind: int):
669
+ """Process offset, matching nebula's ``cv_offset`` C++ function.
670
+
671
+ Parameters
672
+ ----------
673
+ offset : array (nind,) of *positive* scaling factors, or None.
674
+ nind : int
675
+
676
+ Returns
677
+ -------
678
+ log_offset : ndarray (nind,) — log of the offset
679
+ moffset : float — mean of log-offset (0 if offset was None)
680
+ cv2 : float — squared CV of the raw offset
681
+ """
682
+ if offset is None:
683
+ log_offset = np.zeros(nind)
684
+ return log_offset, 0.0, 0.0
685
+
686
+ offset = np.asarray(offset, dtype=np.float64)
687
+ moffset_raw = offset.mean()
688
+ cv = 0.0
689
+ if moffset_raw > 0:
690
+ cv = np.sqrt(((offset - moffset_raw) ** 2).sum() / nind) / moffset_raw
691
+ log_offset = np.log(offset)
692
+ moffset = log_offset.mean()
693
+ return log_offset, moffset, cv * cv
694
+
695
+
696
+ def _call_cumsumy(count, fid, k, ngene):
697
+ """Sum counts per gene per sample, matching nebula's ``call_cumsumy``.
698
+
699
+ Parameters
700
+ ----------
701
+ count : sparse or dense, genes × cells
702
+ fid : int array (k+1,) of segment boundaries (0-based)
703
+ k : int, number of samples
704
+ ngene : int
705
+
706
+ Returns
707
+ -------
708
+ cumsumy : ndarray (ngene, k)
709
+ """
710
+ cumsumy = np.zeros((ngene, k), dtype=np.float64)
711
+ # Use dense or sparse slicing
712
+ for s in range(k):
713
+ start, end = fid[s], fid[s + 1]
714
+ chunk = count[:, start:end]
715
+ if hasattr(chunk, 'toarray'):
716
+ cumsumy[:, s] = np.asarray(chunk.sum(axis=1)).ravel()
717
+ else:
718
+ cumsumy[:, s] = chunk.sum(axis=1).ravel()
719
+ return cumsumy
720
+
721
+
722
+ def _call_posindy(y_gene: np.ndarray):
723
+ """Extract non-zero positions and counts for a single gene row.
724
+
725
+ Matches nebula's ``call_posindy``.
726
+
727
+ Parameters
728
+ ----------
729
+ y_gene : 1D array (ncells,)
730
+
731
+ Returns
732
+ -------
733
+ dict with keys:
734
+ posindy : int array — 0-based indices of non-zero cells
735
+ Y : float array — corresponding count values
736
+ mct : float — mean count per cell
737
+ n_onetwo : int array (2,) — [n_one, n_two]
738
+ ytwo : float array — counts > 2
739
+ """
740
+ nz = np.nonzero(y_gene)[0]
741
+ posindy = nz.astype(np.int32)
742
+ Y = y_gene[nz].astype(np.float64)
743
+ mct = Y.sum() / len(y_gene)
744
+
745
+ n_one = int((Y == 1).sum())
746
+ n_two = int((Y == 2).sum())
747
+ ytwo = Y[Y > 2]
748
+
749
+ return {
750
+ 'posindy': posindy,
751
+ 'Y': Y,
752
+ 'mct': mct,
753
+ 'n_onetwo': np.array([n_one, n_two], dtype=np.int32),
754
+ 'ytwo': ytwo,
755
+ }
756
+
757
+
758
def _get_cell(X, fid, nb, k):
    """Flag covariates that vary within at least one subject's cells.

    Thin pass-through to the numba kernel ``_get_cell_nb``; see that
    function for the scan itself.
    """
    return _get_cell_nb(X, fid, nb, k)
764
+
765
+
766
+ def _get_cv(offset, X, beta, cell_ind, ncell, nc):
767
+ """Compute squared CV of fitted values at cell-level predictors.
768
+
769
+ Matches nebula's ``get_cv``.
770
+ """
771
+ extb = offset.copy()
772
+ for i in range(ncell):
773
+ ind = int(cell_ind[i])
774
+ extb = extb + X[:, ind] * beta[ind]
775
+ with np.errstate(over='ignore'):
776
+ extb = np.exp(extb)
777
+ m = extb.mean()
778
+ if m > 0:
779
+ return ((extb - m) ** 2).sum() / nc / (m * m)
780
+ return 0.0
781
+
782
+
783
+ # ---------------------------------------------------------------------------
784
+ # NBGMM log-likelihood + gradient (ptmg_ll_der)
785
+ # ---------------------------------------------------------------------------
786
+
787
def _ptmg_negll_and_grad(para, X, offset, Y, n_onetwo, ytwo, fid, cumsumy,
                         posind, posindy, nb, nind, k):
    """Negative log-likelihood and gradient for NBGMM (L-BFGS-B stage).

    Unpacks the (n_one, n_two) pair into plain ints and forwards all
    arguments to the numba kernel ``_ptmg_negll_and_grad_nb``.
    """
    n_one = int(n_onetwo[0])
    n_two = int(n_onetwo[1])
    return _ptmg_negll_and_grad_nb(para, X, offset, Y, n_one, n_two, ytwo,
                                   fid, cumsumy, posind, posindy,
                                   nb, nind, k)
798
+
799
+
800
+ # ---------------------------------------------------------------------------
801
+ # Penalized ML optimizer (opt_pml for NBGMM)
802
+ # ---------------------------------------------------------------------------
803
+
804
def _opt_pml(X, offset, Y_vals, fid, cumsumy, posind, posindy, nb, nind, k,
             beta_init, sigma, reml=0, eps=1e-6, ord_=1):
    """Port of nebula's ``opt_pml`` C++ function.

    Runs the numba kernel ``_opt_pml_nb`` and repackages its tuple
    result into a dict.  ``posind`` and ``reml`` are accepted for
    interface parity but not forwarded to the kernel.
    """
    raw = _opt_pml_nb(X, offset, Y_vals, fid, cumsumy, posindy, nb, nind, k,
                      beta_init, sigma[0], sigma[1], eps, ord_)
    keys = ('beta', 'logw', 'var', 'loglik', 'loglikp', 'logdet',
            'iter', 'damp', 'second')
    result = dict(zip(keys, raw))
    # Iteration counters come back as kernel scalars — expose plain ints.
    result['iter'] = int(result['iter'])
    result['damp'] = int(result['damp'])
    return result
824
+
825
+
826
+ # ---------------------------------------------------------------------------
827
+ # Convergence check
828
+ # ---------------------------------------------------------------------------
829
+
830
+ def _check_conv(repml, conv, nb, vare, min_bounds, max_bounds, cutoff=1e-8):
831
+ """Port of nebula's ``check_conv``."""
832
+ if conv == 1:
833
+ if vare[0] == max_bounds[0] or vare[1] == min_bounds[1]:
834
+ conv = -60
835
+ elif np.isnan(repml['loglik']):
836
+ conv = -30
837
+ elif repml['iter'] == 50:
838
+ conv = -20
839
+ elif repml['damp'] == 11:
840
+ conv = -10
841
+ elif repml['damp'] == 12:
842
+ conv = -40
843
+
844
+ if nb > 1:
845
+ try:
846
+ eigvals = np.linalg.eigvalsh(repml['var'])
847
+ if eigvals.min() < cutoff:
848
+ conv = -25
849
+ except np.linalg.LinAlgError:
850
+ conv = -25
851
+
852
+ return conv
853
+
854
+
855
+ # ---------------------------------------------------------------------------
856
+ # Per-gene fitting
857
+ # ---------------------------------------------------------------------------
858
+
859
def _fit_gene_nebula_ln(gene_idx, y_gene, X, offset, fid, cumsumy_gene,
                        posind, nb, nind, k, sds, int_col, moffset,
                        min_bounds, max_bounds, mfs, cutoff_cell, kappa):
    """Fit NBGMM (NEBULA-LN) for a single gene.

    Two-stage fit: (1) L-BFGS-B on the marginal likelihood to estimate
    the dispersion parameters, then (2) penalized ML (``_opt_pml``) for
    the fixed effects and random effects at those dispersions.

    Returns
    -------
    tuple: (beta_rescaled, se_rescaled, sigma2, inv_phi, conv, fit, logw)
        beta_rescaled, se_rescaled : (nb,) coefficients / standard errors
            rescaled back to the original (uncentered) covariate scale
        sigma2  : subject-level variance parameter estimate
        inv_phi : 1 / phi (cell-level overdispersion), inf if phi <= 0
        conv    : convergence code (negative on failure, see _check_conv)
        fit     : algorithm-tracking flag (1=LN ok, 2=would need HL, 3=low kappa)
        logw    : (k,) fitted per-subject log random effects
    """
    posv = _call_posindy(y_gene)
    posindy = posv['posindy']
    Y = posv['Y']
    mct = posv['mct']
    n_onetwo = posv['n_onetwo']
    ytwo = posv['ytwo']

    # ord parameter: use higher-order Laplace correction for sparse genes
    # (few expected counts per subject).
    if mct * mfs < 3:
        ord_ = 3
    else:
        ord_ = 1

    # Initial beta: intercept at log(mean count) minus mean log-offset.
    lmct = np.log(max(mct, 1e-300))
    para_init = np.zeros(nb + 2)
    para_init[int_col] = lmct - moffset
    para_init[nb] = 1.0  # sigma_param (subject-level dispersion, log scale)
    para_init[nb + 1] = 1.0  # phi (cell-level overdispersion)

    lower = np.concatenate([np.full(nb, -100.0), [min_bounds[0], min_bounds[1]]])
    upper = np.concatenate([np.full(nb, 100.0), [max_bounds[0], max_bounds[1]]])
    bounds = list(zip(lower, upper))

    # Stage 1: L-BFGS-B on the marginal NBGMM likelihood (with gradient).
    try:
        res = _minimize(
            _ptmg_negll_and_grad,
            para_init,
            args=(X, offset, Y, n_onetwo, ytwo, fid, cumsumy_gene,
                  posind, posindy, nb, nind, k),
            method='L-BFGS-B',
            jac=True,
            bounds=bounds,
            options={'ftol': 1e-6, 'maxiter': 200},
        )
        refp = res.x
        is_conv = 1 if res.success else 0
    except Exception:
        # Fallback: use initial values
        refp = para_init.copy()
        is_conv = 0

    conv = is_conv
    # Estimated dispersions: vare = [sigma_param, phi].
    vare = np.array([refp[nb], refp[nb + 1]])

    # Determine cell-level predictor CV
    cell_ind_arr = _get_cell(X, fid, nb, k)
    ncell = int(cell_ind_arr.sum())
    cell_ind = np.where(cell_ind_arr > 0)[0]
    if ncell > 0:
        try:
            cv2p = _get_cv(offset, X, refp[:nb], cell_ind, ncell, nind)
        except Exception:
            cv2p = float('nan')
    else:
        cv2p = 0.0

    # gni = cells-per-subject × phi, the LN accuracy driver.
    gni = mfs * vare[1]

    # Determine if we need HL refinement
    fit = 1
    if (gni < cutoff_cell) or (conv == 0) or np.isnan(cv2p):
        # Would need NEBULA-HL refinement — skip for LN-only impl
        # Just note fit=2 for algorithm tracking
        fit = 2
    else:
        kappa_obs = gni / (1.0 + cv2p)
        if (kappa_obs < 20) or (kappa_obs < kappa and vare[0] < 8.0 / kappa_obs):
            fit = 3

    # Beta for PML: start from intercept init, not L-BFGS-B result
    betae = np.zeros(nb)
    betae[int_col] = lmct - moffset

    # Bias correction for intercept (lognormal mean adjustment by sigma^2/2)
    betae[int_col] -= vare[0] / 2.0

    # Stage 2: Penalized ML
    try:
        repml = _opt_pml(
            X, offset, Y, fid, cumsumy_gene, posind, posindy,
            nb, nind, k, betae, vare, reml=0, eps=1e-6, ord_=ord_,
        )
    except Exception:
        # Numerical failure in PML — mark as non-converged (conv=-30, fit=0)
        return (np.full(nb, np.nan), np.full(nb, np.nan),
                vare[0], 1.0 / vare[1] if vare[1] > 0 else np.inf,
                -30, 0, np.zeros(k))

    conv = _check_conv(repml, conv, nb, vare, min_bounds, max_bounds)

    # Invert Fisher information to get covariance
    beta_pml = repml['beta']
    logw = repml['logw']
    fisher = repml['var']  # This is vb2 (the Schur complement)

    se = np.full(nb, np.nan)
    if conv != -25:
        try:
            cov = np.linalg.inv(fisher)
            # Clamp tiny negative diagonal entries before the sqrt.
            se = np.sqrt(np.maximum(np.diag(cov), 0.0))
        except np.linalg.LinAlgError:
            conv = -25

    # Rescale by column SDs (undo centering); intercept column kept as-is.
    sds_use = sds.copy()
    sds_use[int_col] = 1.0
    beta_rescaled = beta_pml / sds_use
    se_rescaled = se / sds_use

    sigma2 = vare[0]
    inv_phi = 1.0 / vare[1] if vare[1] > 0 else np.inf

    return beta_rescaled, se_rescaled, sigma2, inv_phi, conv, fit, logw
983
+
984
+
985
+ # ---------------------------------------------------------------------------
986
+ # Main entry points
987
+ # ---------------------------------------------------------------------------
988
+
989
+ def glm_sc_fit(y, cell_meta=None, design=None, sample=None,
990
+ offset=None, norm_method='TMM', method='nebula',
991
+ min_bounds=None, max_bounds=None,
992
+ cpc=0.005, mincp=5, cutoff_cell=20, kappa=800,
993
+ ncore=1, verbose=True):
994
+ """Fit a single-cell NB gamma mixed model (NEBULA-LN).
995
+
996
+ Parameters
997
+ ----------
998
+ y : AnnData, dict, or ndarray
999
+ Count data. AnnData objects are (cells × genes); raw matrices
1000
+ should be (genes × cells).
1001
+ cell_meta : DataFrame, optional
1002
+ Cell-level metadata. Extracted from ``y.obs`` for AnnData.
1003
+ design : ndarray or str, optional
1004
+ Design matrix (cells × predictors) with an intercept column.
1005
+ If ``None``, an intercept-only model is fitted.
1006
+ sample : str or array-like
1007
+ Subject/sample identifiers. If a string, it names a column in
1008
+ *cell_meta*.
1009
+ offset : array-like, optional
1010
+ Positive per-cell scaling factors. If provided, ``norm_method``
1011
+ is ignored.
1012
+ norm_method : str
1013
+ ``'TMM'`` (default): compute per-cell offset from per-cell
1014
+ library sizes and pseudobulk TMM normalization factors.
1015
+ ``'none'``: all-ones offset (original nebula behaviour).
1016
+ method : str
1017
+ ``'nebula'`` (default): NEBULA-LN algorithm.
1018
+ min_bounds, max_bounds : tuple of float, optional
1019
+ Bounds for (sigma_param, phi). Defaults (1e-4, 1e-4) and
1020
+ (10, 1000).
1021
+ cpc : float
1022
+ Minimum mean counts per cell for gene filtering.
1023
+ mincp : int
1024
+ Minimum non-zero cells for gene filtering.
1025
+ cutoff_cell : float
1026
+ Threshold for NEBULA-HL fallback (cells_per_subject × phi).
1027
+ kappa : float
1028
+ Accuracy threshold for subject-level overdispersion.
1029
+ ncore : int
1030
+ Number of parallel workers (1 = sequential).
1031
+ verbose : bool
1032
+ Print progress messages.
1033
+
1034
+ Returns
1035
+ -------
1036
+ dict
1037
+ DGEGLM-like fit result with keys ``'coefficients'``, ``'se'``,
1038
+ ``'dispersion'``, ``'design'``, ``'offset'``, ``'genes'``,
1039
+ ``'sigma_sample'``, ``'convergence'``, ``'method'``, etc.
1040
+ Pass to ``top_tags(fit, coef=...)`` for Wald testing.
1041
+ """
1042
+ if min_bounds is None:
1043
+ min_bounds = (1e-4, 1e-4)
1044
+ if max_bounds is None:
1045
+ max_bounds = (10.0, 1000.0)
1046
+
1047
+ # --- Input handling ---
1048
+ gene_names = None
1049
+ try:
1050
+ import anndata
1051
+ is_anndata = isinstance(y, anndata.AnnData)
1052
+ except ImportError:
1053
+ is_anndata = False
1054
+
1055
+ if is_anndata:
1056
+ adata = y
1057
+ X_raw = adata.X
1058
+ if hasattr(X_raw, 'toarray'):
1059
+ X_raw = X_raw.toarray()
1060
+ counts = np.asarray(X_raw, dtype=np.float64).T # genes × cells
1061
+ if cell_meta is None:
1062
+ cell_meta = adata.obs.copy()
1063
+ gene_names = np.array(adata.var_names)
1064
+ elif isinstance(y, dict) and 'counts' in y:
1065
+ counts = np.asarray(y['counts'], dtype=np.float64)
1066
+ if cell_meta is None and 'obs' in y:
1067
+ cell_meta = y['obs']
1068
+ if gene_names is None and 'genes' in y:
1069
+ gene_names = np.asarray(y['genes'])
1070
+ else:
1071
+ if hasattr(y, 'toarray'):
1072
+ counts = np.asarray(y.toarray(), dtype=np.float64)
1073
+ else:
1074
+ counts = np.asarray(y, dtype=np.float64)
1075
+
1076
+ ngene, nind = counts.shape
1077
+ if nind < 2:
1078
+ raise ValueError("There is no more than one cell in the count matrix.")
1079
+
1080
+ # --- Resolve sample IDs ---
1081
+ if sample is None:
1082
+ raise ValueError(
1083
+ "The 'sample' argument is required. Provide per-cell sample IDs."
1084
+ )
1085
+ if isinstance(sample, str):
1086
+ if cell_meta is None:
1087
+ raise ValueError(
1088
+ f"sample='{sample}' requires cell_meta with that column."
1089
+ )
1090
+ sample_ids = np.asarray(cell_meta[sample])
1091
+ else:
1092
+ sample_ids = np.asarray(sample)
1093
+ if len(sample_ids) != nind:
1094
+ raise ValueError(
1095
+ "Length of sample IDs should equal the number of cells."
1096
+ )
1097
+
1098
+ # --- Save design column names before sort (which may convert DataFrame → numpy) ---
1099
+ _design_colnames = None
1100
+ if design is not None and hasattr(design, 'columns'):
1101
+ _design_colnames = list(design.columns)
1102
+
1103
+ # --- Sort cells by sample (group_cell) ---
1104
+ sample_ids_str = np.array([str(s) for s in sample_ids])
1105
+ levels = list(dict.fromkeys(sample_ids_str)) # unique, order-preserving
1106
+ sample_numeric = np.array(
1107
+ [levels.index(s) + 1 for s in sample_ids_str], dtype=np.int32
1108
+ )
1109
+ if not np.all(sample_numeric[:-1] <= sample_numeric[1:]):
1110
+ # Need to sort
1111
+ order = np.argsort(sample_numeric, kind='stable')
1112
+ counts = counts[:, order]
1113
+ sample_numeric = sample_numeric[order]
1114
+ sample_ids_str = sample_ids_str[order]
1115
+ if cell_meta is not None:
1116
+ if isinstance(cell_meta, pd.DataFrame):
1117
+ cell_meta = cell_meta.iloc[order].reset_index(drop=True)
1118
+ else:
1119
+ cell_meta = cell_meta[order]
1120
+ if offset is not None:
1121
+ offset = np.asarray(offset, dtype=np.float64)[order]
1122
+ if design is not None and not isinstance(design, str):
1123
+ design = np.asarray(design, dtype=np.float64)[order]
1124
+
1125
+ k = len(levels)
1126
+ # Build fid: 0-based start index of each sample's cells + sentinel
1127
+ diffs = np.where(np.concatenate([[1], np.diff(sample_numeric)]))[0]
1128
+ fid = np.concatenate([diffs, [nind]]).astype(np.int32)
1129
+
1130
+ # --- Design matrix ---
1131
+ if design is None:
1132
+ pred = np.ones((nind, 1), dtype=np.float64)
1133
+ predn = None
1134
+ sds = np.array([0.0])
1135
+ int_col = 0
1136
+ else:
1137
+ if isinstance(design, str):
1138
+ # Formula — resolve against cell_meta
1139
+ from .utils import model_matrix
1140
+ pred = np.asarray(
1141
+ model_matrix(design, cell_meta), dtype=np.float64
1142
+ )
1143
+ else:
1144
+ pred = np.asarray(design, dtype=np.float64)
1145
+ if pred.shape[0] != nind:
1146
+ raise ValueError(
1147
+ "Design matrix rows must equal number of cells."
1148
+ )
1149
+ predn = None
1150
+ if hasattr(design, 'columns'):
1151
+ predn = list(design.columns)
1152
+ elif isinstance(design, pd.DataFrame):
1153
+ predn = list(design.columns)
1154
+ pred, sds, int_col = _center_design(pred)
1155
+
1156
+ nb = pred.shape[1]
1157
+
1158
+ # --- Offset ---
1159
+ if offset is not None:
1160
+ # User-provided offset (positive scaling factors)
1161
+ log_offset, moffset, cv2 = _cv_offset(offset, nind)
1162
+ elif norm_method.upper() == 'TMM':
1163
+ # Pseudobulk TMM normalization → per-cell offset
1164
+ lib_size = counts.sum(axis=0).astype(np.float64)
1165
+ pb = np.zeros((ngene, k), dtype=np.float64)
1166
+ for s in range(k):
1167
+ start, end = fid[s], fid[s + 1]
1168
+ chunk = counts[:, start:end]
1169
+ if hasattr(chunk, 'toarray'):
1170
+ pb[:, s] = np.asarray(chunk.sum(axis=1)).ravel()
1171
+ else:
1172
+ pb[:, s] = chunk.sum(axis=1).ravel()
1173
+ pb_dge = make_dgelist(pb)
1174
+ pb_dge = calc_norm_factors(pb_dge)
1175
+ norm_factors = pb_dge['samples']['norm.factors'].values
1176
+ # Expand sample-level norm factors to per-cell
1177
+ cell_nf = np.empty(nind, dtype=np.float64)
1178
+ for s in range(k):
1179
+ start, end = fid[s], fid[s + 1]
1180
+ cell_nf[start:end] = norm_factors[s]
1181
+ # Floor at 0.5 to avoid log(0) for zero-count cells
1182
+ offset_raw = np.maximum(lib_size * cell_nf, 0.5)
1183
+ log_offset, moffset, cv2 = _cv_offset(offset_raw, nind)
1184
+ else:
1185
+ # No normalization (original nebula behaviour)
1186
+ log_offset, moffset, cv2 = _cv_offset(None, nind)
1187
+
1188
+ # --- CPS check ---
1189
+ mfs = nind / k
1190
+ if mfs < 30 and verbose:
1191
+ warnings.warn(
1192
+ f"The average number of cells per subject ({mfs:.1f}) is less "
1193
+ f"than 30. NEBULA-LN may be inaccurate for small cell counts."
1194
+ )
1195
+
1196
+ # --- Cumsumy ---
1197
+ cumsumy = _call_cumsumy(counts, fid, k, ngene)
1198
+
1199
+ # --- Gene filtering ---
1200
+ # Non-zero cell counts per gene
1201
+ if hasattr(counts, 'nnz'):
1202
+ # sparse
1203
+ from scipy.sparse import issparse
1204
+ nz_per_gene = np.diff(counts.indptr) if hasattr(counts, 'indptr') else \
1205
+ np.array([(counts[g, :] != 0).sum() for g in range(ngene)])
1206
+ else:
1207
+ nz_per_gene = (counts != 0).sum(axis=1)
1208
+
1209
+ mean_cpc = cumsumy.sum(axis=1) / nind
1210
+ mask_cpc = mean_cpc > cpc
1211
+ mask_mincp = nz_per_gene >= mincp
1212
+ gene_mask = mask_cpc & mask_mincp
1213
+ gid = np.where(gene_mask)[0]
1214
+ lgid = len(gid)
1215
+
1216
+ if verbose:
1217
+ print(f"Remove {ngene - lgid} genes having low expression.")
1218
+ if lgid == 0:
1219
+ raise ValueError("No gene passed the filtering.")
1220
+ if verbose:
1221
+ print(f"Analyzing {lgid} genes with {k} subjects and {nind} cells.")
1222
+
1223
+ # posind per gene: which samples have non-zero counts
1224
+ posind_per_gene = [np.where(cumsumy[g, :] > 0)[0] for g in gid]
1225
+
1226
+ # --- Per-gene fitting ---
1227
+ def _fit_one(idx):
1228
+ g = gid[idx]
1229
+ if hasattr(counts, 'toarray'):
1230
+ y_gene = np.asarray(counts[g, :].toarray()).ravel()
1231
+ else:
1232
+ y_gene = counts[g, :]
1233
+ return _fit_gene_nebula_ln(
1234
+ g, y_gene, pred, log_offset, fid, cumsumy[g, :],
1235
+ posind_per_gene[idx], nb, nind, k, sds, int_col, moffset,
1236
+ min_bounds, max_bounds, mfs, cutoff_cell, kappa,
1237
+ )
1238
+
1239
+ if ncore > 1:
1240
+ # Parallel execution
1241
+ with ProcessPoolExecutor(max_workers=ncore) as executor:
1242
+ results = list(executor.map(_fit_one, range(lgid)))
1243
+ else:
1244
+ results = []
1245
+ for idx in range(lgid):
1246
+ if verbose and lgid > 100 and idx % max(1, lgid // 10) == 0:
1247
+ print(f" Gene {idx + 1}/{lgid}...")
1248
+ results.append(_fit_one(idx))
1249
+
1250
+ # --- Collect results ---
1251
+ coefficients = np.zeros((lgid, nb))
1252
+ se_arr = np.zeros((lgid, nb))
1253
+ sigma_sample = np.zeros(lgid)
1254
+ cell_disp = np.zeros(lgid) # 1/phi
1255
+ convergence = np.zeros(lgid, dtype=np.int32)
1256
+ algorithm_codes = np.zeros(lgid, dtype=np.int32)
1257
+
1258
+ for idx, res in enumerate(results):
1259
+ beta_r, se_r, sigma2, inv_phi, conv, fit, logw = res
1260
+ coefficients[idx, :] = beta_r
1261
+ se_arr[idx, :] = se_r
1262
+ sigma_sample[idx] = sigma2
1263
+ cell_disp[idx] = inv_phi
1264
+ convergence[idx] = conv
1265
+ algorithm_codes[idx] = fit
1266
+
1267
+ # --- Resolve predictor names ---
1268
+ if predn is None:
1269
+ predn = _design_colnames
1270
+ if predn is None:
1271
+ if design is not None and hasattr(design, 'columns'):
1272
+ predn = list(design.columns)
1273
+ if predn is None:
1274
+ predn = [f"V{i+1}" for i in range(nb)]
1275
+
1276
+ # --- Gene annotation DataFrame ---
1277
+ if gene_names is not None:
1278
+ genes_df = pd.DataFrame({'gene': gene_names[gid]})
1279
+ else:
1280
+ genes_df = None
1281
+
1282
+ # --- Average log abundance for filtered genes ---
1283
+ ave_log_abund = np.log2(mean_cpc[gid] + 0.5)
1284
+
1285
+ # --- DGEGLM-like return ---
1286
+ return {
1287
+ 'coefficients': coefficients,
1288
+ 'se': se_arr,
1289
+ 'dispersion': cell_disp,
1290
+ 'sigma_sample': sigma_sample,
1291
+ 'convergence': convergence,
1292
+ 'design': pred,
1293
+ 'offset': log_offset,
1294
+ 'genes': genes_df,
1295
+ 'gene_mask': gene_mask,
1296
+ 'method': 'nebula_ln',
1297
+ 'ncells': nind,
1298
+ 'nsamples': k,
1299
+ 'predictor_names': predn,
1300
+ 'sample_map': sample_ids_str,
1301
+ 'samples_unique': np.array(levels),
1302
+ 'ave_log_abundance': ave_log_abund,
1303
+ }
1304
+
1305
+
1306
def _no_shrink_result(fit, phi_raw, phi_prior, df_residual):
    """Fill *fit* with unshrunk estimates when shrinkage is skipped or fails.

    Sets the same output keys as the successful path, but with
    ``phi_post`` equal to the raw phi and a zero prior df.
    """
    fit['phi_raw'] = phi_raw
    fit['phi_post'] = phi_raw.copy()
    fit['phi_prior'] = phi_prior
    fit['df_residual'] = df_residual
    fit['df_prior_phi'] = 0.0
    fit['dispersion_shrunk'] = fit['dispersion'].copy()
    return fit


def shrink_sc_disp(fit, counts=None, covariate=None, robust=True):
    """Empirical Bayes shrinkage of cell-level NB dispersion.

    Shrinks the per-gene NB overdispersion parameter phi toward a
    (possibly trended) prior using limma's squeezeVar framework.

    Parameters
    ----------
    fit : dict
        Output from ``glm_sc_fit()``.
    counts : ndarray or sparse matrix, optional
        Gene-by-cell count matrix (same genes/ordering as used in
        ``glm_sc_fit``). Used to compute log-mean abundance as
        covariate for the trended prior. If *None* and
        ``fit['ave_log_abundance']`` exists, that is used instead.
    covariate : array-like, optional
        Custom covariate for the trended prior. Overrides the
        abundance covariate derived from *counts*.
    robust : bool
        Use robust estimation (default True). Protects against
        outlier genes with extremely high or low dispersion.

    Returns
    -------
    dict
        The input *fit* dict, updated in-place with new keys:

        - ``phi_raw`` : raw per-gene phi (= 1/dispersion)
        - ``phi_post`` : posterior (shrunk) phi
        - ``phi_prior`` : prior phi (scalar or trended)
        - ``df_residual`` : residual degrees of freedom
        - ``df_prior_phi`` : prior df from empirical Bayes
        - ``dispersion_shrunk`` : 1/phi_post (shrunk dispersion)
    """
    dispersion = fit['dispersion']
    ngenes = len(dispersion)

    # Convert to phi = 1/dispersion; zero dispersion maps to +inf.
    with np.errstate(divide='ignore'):
        phi_raw = np.where(dispersion > 0, 1.0 / dispersion, np.inf)

    # Convergence mask: only use converged genes for prior estimation.
    conv_mask = fit['convergence'] == 1

    # Floor phi at a small positive value; mark non-finite phi as NaN
    # so the finiteness mask below excludes it.
    phi_floor = 1e-8
    phi_use = np.maximum(phi_raw.copy(), phi_floor)
    phi_use[~np.isfinite(phi_use)] = np.nan

    # Residual degrees of freedom: N - p - (K - 1), floored at 1.
    n_cells = fit['ncells']
    n_predictors = fit['design'].shape[1]
    n_samples = fit['nsamples']
    df_residual = max(n_cells - n_predictors - (n_samples - 1), 1)

    # Determine covariate for trended prior: explicit covariate >
    # abundance derived from counts > stored abundance > none.
    if covariate is not None:
        cov = np.asarray(covariate, dtype=np.float64)
    elif counts is not None:
        if hasattr(counts, 'toarray'):
            mean_cpc = np.asarray(counts.mean(axis=1)).ravel()
        else:
            mean_cpc = counts.mean(axis=1).ravel()
        gene_mask = fit['gene_mask']
        cov = np.log2(mean_cpc[gene_mask] + 0.5)
    elif 'ave_log_abundance' in fit:
        cov = fit['ave_log_abundance']
    else:
        cov = None

    # Filter to converged genes with finite phi.
    ok_mask = conv_mask & np.isfinite(phi_use)
    idx_ok = np.where(ok_mask)[0]

    if len(idx_ok) < 3:
        warnings.warn("Fewer than 3 converged genes; skipping shrinkage.")
        return _no_shrink_result(fit, phi_raw, np.nan, df_residual)

    # Imported lazily so the early-return path above does not require
    # limma_port to be importable.
    from .limma_port import squeeze_var

    phi_ok = phi_use[idx_ok]
    cov_ok = cov[idx_ok] if cov is not None else None

    # Call squeeze_var with scalar df (same for all genes).
    # Fall back gracefully: trended -> untrended -> non-robust -> no shrinkage.
    sv = None
    for cov_attempt in ([cov_ok, None] if cov_ok is not None else [None]):
        try:
            sv = squeeze_var(phi_ok, df=float(df_residual),
                             covariate=cov_attempt, robust=robust)
            break
        except (ValueError, RuntimeError):
            continue
    if sv is None:
        try:
            sv = squeeze_var(phi_ok, df=float(df_residual),
                             covariate=None, robust=False)
        except (ValueError, RuntimeError):
            warnings.warn("squeeze_var failed; returning unshrunk estimates.")
            return _no_shrink_result(fit, phi_raw, np.nanmedian(phi_ok),
                                     df_residual)

    # Map results back to the full gene array.
    phi_post = np.full(ngenes, np.nan)
    phi_post[idx_ok] = sv['var_post']

    phi_prior_full = np.full(ngenes, np.nan)
    if isinstance(sv['var_prior'], np.ndarray):
        phi_prior_full[idx_ok] = sv['var_prior']
        median_prior = np.nanmedian(sv['var_prior'])
    else:
        phi_prior_full[:] = sv['var_prior']
        median_prior = sv['var_prior']

    # Non-converged genes get the (median) prior value.
    phi_post[~ok_mask] = median_prior
    phi_prior_full[~ok_mask] = median_prior

    # Store results on the fit dict (updated in place).
    fit['phi_raw'] = phi_raw
    fit['phi_post'] = phi_post
    fit['phi_prior'] = phi_prior_full
    fit['df_residual'] = df_residual
    fit['df_prior_phi'] = sv['df_prior']
    with np.errstate(divide='ignore'):
        fit['dispersion_shrunk'] = np.where(
            phi_post > 0, 1.0 / phi_post, np.inf
        )

    return fit
1449
+
1450
+
1451
def _bh_adjust(pvalues):
    """Benjamini-Hochberg FDR adjustment for a 1-D array of p-values.

    Equivalent to statsmodels ``multipletests(..., method='fdr_bh')``
    adjusted p-values: sort ascending, scale by n/rank, enforce
    monotonicity from the largest p downward, clip at 1.
    """
    p = np.asarray(pvalues, dtype=np.float64)
    n = len(p)
    order = np.argsort(p)
    scaled = p[order] * n / np.arange(1, n + 1)
    # Running minimum from the end enforces monotone non-decreasing
    # adjusted p-values in rank order.
    adj = np.minimum.accumulate(scaled[::-1])[::-1]
    adj = np.minimum(adj, 1.0)
    out = np.empty(n)
    out[order] = adj
    return out


def glm_sc_test(fit, coef=None, contrast=None):
    """Wald test on a ``glm_sc_fit`` result.

    Parameters
    ----------
    fit : dict
        Output from ``glm_sc_fit()``.
    coef : int, optional
        0-based column index of the coefficient to test. Default: last
        column.
    contrast : ndarray, optional
        Custom contrast vector (length p). If given, *coef* is ignored.

    Returns
    -------
    dict with key ``'table'`` containing a DataFrame with columns:
        logFC, SE, z, PValue, FDR, sigma_sample, dispersion, converged.
    """
    coefficients = fit['coefficients']
    se_arr = fit['se']
    ngenes, nb = coefficients.shape

    if contrast is not None:
        contrast = np.asarray(contrast, dtype=np.float64)
        logFC = coefficients @ contrast
        # NOTE(review): the fit stores only per-coefficient SEs, not the
        # full covariance, so the contrast SE assumes independence of
        # the coefficient estimates.
        se = np.sqrt(np.maximum(
            np.sum((se_arr ** 2) * (contrast ** 2), axis=1), 0
        ))
    else:
        if coef is None:
            coef = nb - 1
        logFC = coefficients[:, coef]
        se = se_arr[:, coef]

    # Wald statistic. se == 0 gives +/-inf z (p-value 0); 0/0 gives NaN
    # which is excluded from FDR below.
    with np.errstate(divide='ignore', invalid='ignore'):
        z = logFC / se
    pvalue = _chi2.sf(z ** 2, 1)

    # FDR (Benjamini-Hochberg) over genes with a defined p-value.
    n = len(pvalue)
    valid = ~np.isnan(pvalue)
    fdr = np.full(n, np.nan)
    if valid.any():
        fdr[valid] = _bh_adjust(pvalue[valid])

    table = pd.DataFrame({
        'logFC': logFC,
        'SE': se,
        'z': z,
        'PValue': pvalue,
        'FDR': fdr,
        'sigma_sample': fit['sigma_sample'],
        'dispersion': fit['dispersion'],
        'converged': fit['convergence'],
    })

    genes = fit.get('genes')
    if genes is not None:
        if isinstance(genes, pd.DataFrame):
            # glm_sc_fit stores a one-column annotation frame; a DataFrame
            # cannot be assigned as an index directly (pandas raises
            # "Index data must be 1-dimensional"), so use its first column.
            table.index = pd.Index(genes.iloc[:, 0].to_numpy(),
                                   name=str(genes.columns[0]))
        else:
            table.index = np.asarray(genes)

    return {'table': table}