statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1280 @@
1
+ """
2
+ CUDA RawKernel for Cox PH Efron backward gradient/Hessian.
3
+
4
+ Sequential scan over unique failure times (ii). Enter/exit/failure-at-risk updates are
5
+ commutative; large index lists use ``atomicAdd`` (double, sm_60+), small lists (<= ``seq_thresh``)
6
+ use thread-0 sequential adds to avoid atomic overhead. Failure accumulation for large ``m`` is
7
+ parallel; Efron formulas remain on thread 0. Workspace ends with a scratch double for parallel
8
+ ``xp0f`` sum.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ from typing import Any, List, Optional, Tuple
15
+
16
+ import numpy as np
17
+
18
# CUDA source for the parallel Efron backward scan (gradient + Hessian).
#
# Launch contract, as read from the kernel body:
# * only block (0, 0, 0) does any work, and threads with threadIdx.y/z != 0
#   return immediately, so launch a single 1-D block;
# * ``meta`` = [n, p, nuft, seq_thresh] (n is read but unused);
# * ``workspace`` holds 2 + 2*p + 3*p*p doubles laid out as
#   xp0 | xp1[p] | xp2[p*p] | hess_acc[p*p] | xp1f[p] | xp2f[p*p] | scratch_xp0f;
# * for p <= 64 the parallel (group size > seq_thresh) stages carve their scratch
#   out of the *dynamic* shared-memory buffer: the enter and exit stages use
#   1 + 64 + p*p doubles and the failure stage uses 1 + 64 + 64 + p*p doubles,
#   so whenever those paths can run the launch must supply at least the latter.
#   No static __shared__ arrays are declared, so the block's shared-memory
#   footprint is exactly the dynamic size requested at launch.
_KERNEL_SOURCE = r"""
/* sm_60+ double atomicAdd. Small batches: thread0 sequential (no atomics). Large: parallel atomics. */
#define EFRON_MAX_P_STACK 128
// seq_thresh is passed via meta[3] for runtime tuning (see python launch code).

// Apply one tied-failure group's Efron contribution. Called by thread 0 only.
// With J_k = k/m and c_k = max(xp0v - J_k * xp0f, 1e-300) for k = 0..m-1:
//   grad     -= xp1 * sum(1/c_k) - xp1f * sum(J_k/c_k)
//   hess_acc += xp2 * sum(1/c_k) - xp2f * sum(J_k/c_k)
//               - sum(1/c_k^2) xp1 xp1' - sum(J_k^2/c_k^2) xp1f xp1f'
//               + sum(J_k/c_k^2) (xp1 xp1f' + xp1f xp1')
static __device__ void efron_apply_tie_group(
    int p,
    int m,
    double xp0v,
    double xp0f,
    const double* xp1,
    const double* xp2,
    const double* xp1f,
    const double* xp2f,
    double* grad_out,
    double* hess_acc
) {
    double sum_inv_c0 = 0.0;
    double sum_J_c0 = 0.0;
    double sum_aa = 0.0;
    double sum_bb = 0.0;
    double sum_ab = 0.0;
    for (int kk = 0; kk < m; kk++) {
        double Jk = (double)kk / (double)m;
        double c0 = xp0v - Jk * xp0f;
        if (c0 < 1e-300) c0 = 1e-300;
        double ak = 1.0 / c0;
        double bk = Jk * ak;
        sum_inv_c0 += ak;
        sum_J_c0 += Jk / c0;
        sum_aa += ak * ak;
        sum_bb += bk * bk;
        sum_ab += ak * bk;
    }
    for (int j = 0; j < p; j++) {
        grad_out[j] -= (xp1[j] * sum_inv_c0 - xp1f[j] * sum_J_c0);
    }
    for (int j = 0; j < p * p; j++) {
        hess_acc[j] += xp2[j] * sum_inv_c0;
        hess_acc[j] -= xp2f[j] * sum_J_c0;
    }
    for (int j1 = 0; j1 < p; j1++) {
        for (int j2 = j1; j2 < p; j2++) {
            double o11 = xp1[j1] * xp1[j2];
            double off = xp1f[j1] * xp1f[j2];
            double cross = xp1[j1] * xp1f[j2] + xp1f[j1] * xp1[j2];
            double hsub = sum_aa * o11 + sum_bb * off - sum_ab * cross;
            hess_acc[j1 * p + j2] -= hsub;
            if (j2 != j1) hess_acc[j2 * p + j1] -= hsub;
        }
    }
}

extern "C" __global__
void efron_backward_scan(
    const double* __restrict__ X,
    const double* __restrict__ e_eta,
    const int* __restrict__ meta,
    const int* __restrict__ enter_ptr,
    const int* __restrict__ enter_ind,
    const int* __restrict__ exit_ptr,
    const int* __restrict__ exit_ind,
    const int* __restrict__ fail_ptr,
    const int* __restrict__ fail_ind,
    double* __restrict__ grad_out,
    double* __restrict__ hess_out,
    double* __restrict__ workspace
) {
    int n = meta[0];
    int p = meta[1];
    int nuft = meta[2];
    int seq_thresh = meta[3];
    (void)n;
    if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return;
    if (threadIdx.y != 0 || threadIdx.z != 0) return;

    double* xp0_ptr = workspace;
    double* xp1 = xp0_ptr + 1;
    double* xp2 = xp1 + p;
    double* hess_acc = xp2 + p * p;
    double* xp1f = hess_acc + p * p;
    double* xp2f = xp1f + p;
    double* scratch_xp0f = xp2f + p * p;

    int ws_doubles = 2 + 2 * p + 3 * p * p;
    for (int i = threadIdx.x; i < ws_doubles; i += blockDim.x) {
        workspace[i] = 0.0;
    }
    for (int j = threadIdx.x; j < p; j += blockDim.x) {
        grad_out[j] = 0.0;
    }
    __syncthreads();

    for (int ii = nuft - 1; ii >= 0; ii--) {
        // ---- Enter: add rows entering the risk set to xp0/xp1/xp2. ----
        int e0 = enter_ptr[ii];
        int e1 = enter_ptr[ii + 1];
        int nt = e1 - e0;
        if (nt <= seq_thresh) {
            if (threadIdx.x == 0) {
                for (int t = e0; t < e1; t++) {
                    int idx = enter_ind[t];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    *xp0_ptr += elx;
                    if (p <= EFRON_MAX_P_STACK) {
                        double row[EFRON_MAX_P_STACK];
                        for (int j = 0; j < p; j++) row[j] = Xrow[j];
                        for (int j = 0; j < p; j++) xp1[j] += elx * row[j];
                        for (int j = 0; j < p; j++)
                            for (int k = 0; k < p; k++)
                                xp2[j * p + k] += elx * row[j] * row[k];
                    } else {
                        for (int j = 0; j < p; j++) xp1[j] += elx * Xrow[j];
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++)
                                xp2[j * p + k] += elx * vj * Xrow[k];
                        }
                    }
                }
            }
        } else {
            // Stage A: block-local accumulation for small p to reduce global atomics.
            // Stage B: one global writeback per aggregated entry.
            if (p <= 64) {
                extern __shared__ double sh_mem[];
                double* sh_xp0_s = sh_mem;
                double* sh_xp1_s = sh_xp0_s + 1;
                double* sh_xp2_s = sh_xp1_s + 64;
                if (threadIdx.x == 0) *sh_xp0_s = 0.0;
                for (int j = threadIdx.x; j < p; j += blockDim.x) {
                    sh_xp1_s[j] = 0.0;
                }
                for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                    sh_xp2_s[j] = 0.0;
                }
                __syncthreads();
                for (int tt = threadIdx.x; tt < nt; tt += blockDim.x) {
                    int idx = enter_ind[e0 + tt];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    atomicAdd(sh_xp0_s, elx);
                    for (int j = 0; j < p; j++) {
                        double vj = Xrow[j];
                        atomicAdd(sh_xp1_s + j, elx * vj);
                    }
                    for (int j = 0; j < p; j++) {
                        double vj = Xrow[j];
                        for (int k = 0; k < p; k++) {
                            atomicAdd(sh_xp2_s + j * p + k, elx * vj * Xrow[k]);
                        }
                    }
                }
                __syncthreads();
                if (threadIdx.x == 0) atomicAdd(xp0_ptr, *sh_xp0_s);
                for (int j = threadIdx.x; j < p; j += blockDim.x) {
                    atomicAdd(xp1 + j, sh_xp1_s[j]);
                }
                for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                    atomicAdd(xp2 + j, sh_xp2_s[j]);
                }
            } else {
                for (int tt = threadIdx.x; tt < nt; tt += blockDim.x) {
                    int idx = enter_ind[e0 + tt];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    atomicAdd(xp0_ptr, elx);
                    if (p <= EFRON_MAX_P_STACK) {
                        double row[EFRON_MAX_P_STACK];
                        for (int j = 0; j < p; j++) row[j] = Xrow[j];
                        for (int j = 0; j < p; j++) atomicAdd(xp1 + j, elx * row[j]);
                        for (int j = 0; j < p; j++)
                            for (int k = 0; k < p; k++)
                                atomicAdd(xp2 + j * p + k, elx * row[j] * row[k]);
                    } else {
                        for (int j = 0; j < p; j++) atomicAdd(xp1 + j, elx * Xrow[j]);
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++)
                                atomicAdd(xp2 + j * p + k, elx * vj * Xrow[k]);
                        }
                    }
                }
            }
        }
        __syncthreads();

        // ---- Failures: accumulate tied-failure sums, then apply Efron update. ----
        int f0 = fail_ptr[ii];
        int f1 = fail_ptr[ii + 1];
        int m = f1 - f0;
        if (m > 0) {
            for (int j = threadIdx.x; j < p; j += blockDim.x) {
                xp1f[j] = 0.0;
            }
            for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                xp2f[j] = 0.0;
            }
            __syncthreads();

            if (m <= seq_thresh) {
                if (threadIdx.x == 0) {
                    double xp0v = *xp0_ptr;
                    double xp0f = 0.0;
                    for (int t = f0; t < f1; t++) {
                        int idx = fail_ind[t];
                        const double* Xrow = X + (size_t)idx * (size_t)p;
                        double elx = e_eta[idx];
                        xp0f += elx;
                        if (p <= EFRON_MAX_P_STACK) {
                            double row[EFRON_MAX_P_STACK];
                            for (int j = 0; j < p; j++) row[j] = Xrow[j];
                            for (int j = 0; j < p; j++) {
                                xp1f[j] += elx * row[j];
                                grad_out[j] += row[j];
                            }
                            for (int j = 0; j < p; j++)
                                for (int k = 0; k < p; k++)
                                    xp2f[j * p + k] += elx * row[j] * row[k];
                        } else {
                            for (int j = 0; j < p; j++) {
                                double vj = Xrow[j];
                                xp1f[j] += elx * vj;
                                grad_out[j] += vj;
                            }
                            for (int j = 0; j < p; j++) {
                                double vj = Xrow[j];
                                for (int k = 0; k < p; k++)
                                    xp2f[j * p + k] += elx * vj * Xrow[k];
                            }
                        }
                    }
                    efron_apply_tie_group(p, m, xp0v, xp0f, xp1, xp2, xp1f, xp2f,
                                          grad_out, hess_acc);
                }
            } else {
                if (p <= 64) {
                    extern __shared__ double sh_mem2[];
                    double* sh_xp0f_s = sh_mem2;
                    double* sh_xp1f_s = sh_xp0f_s + 1;
                    double* sh_grad_s = sh_xp1f_s + 64;
                    double* sh_xp2f_s = sh_grad_s + 64;
                    if (threadIdx.x == 0) *sh_xp0f_s = 0.0;
                    for (int j = threadIdx.x; j < p; j += blockDim.x) {
                        sh_xp1f_s[j] = 0.0;
                        sh_grad_s[j] = 0.0;
                    }
                    for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                        sh_xp2f_s[j] = 0.0;
                    }
                    __syncthreads();
                    for (int tt = threadIdx.x; tt < m; tt += blockDim.x) {
                        int idx = fail_ind[f0 + tt];
                        const double* Xrow = X + (size_t)idx * (size_t)p;
                        double elx = e_eta[idx];
                        atomicAdd(sh_xp0f_s, elx);
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            atomicAdd(sh_xp1f_s + j, elx * vj);
                            atomicAdd(sh_grad_s + j, vj);
                        }
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++) {
                                atomicAdd(sh_xp2f_s + j * p + k, elx * vj * Xrow[k]);
                            }
                        }
                    }
                    __syncthreads();
                    for (int j = threadIdx.x; j < p; j += blockDim.x) {
                        xp1f[j] = sh_xp1f_s[j];
                        atomicAdd(grad_out + j, sh_grad_s[j]);
                    }
                    for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                        xp2f[j] = sh_xp2f_s[j];
                    }
                    if (threadIdx.x == 0) {
                        *scratch_xp0f = *sh_xp0f_s;
                    }
                } else {
                    if (threadIdx.x == 0) {
                        *scratch_xp0f = 0.0;
                    }
                    __syncthreads();
                    for (int tt = threadIdx.x; tt < m; tt += blockDim.x) {
                        int idx = fail_ind[f0 + tt];
                        const double* Xrow = X + (size_t)idx * (size_t)p;
                        double elx = e_eta[idx];
                        atomicAdd(scratch_xp0f, elx);
                        if (p <= EFRON_MAX_P_STACK) {
                            double row[EFRON_MAX_P_STACK];
                            for (int j = 0; j < p; j++) row[j] = Xrow[j];
                            for (int j = 0; j < p; j++) {
                                atomicAdd(xp1f + j, elx * row[j]);
                                atomicAdd(grad_out + j, row[j]);
                            }
                            for (int j = 0; j < p; j++)
                                for (int k = 0; k < p; k++)
                                    atomicAdd(xp2f + j * p + k, elx * row[j] * row[k]);
                        } else {
                            for (int j = 0; j < p; j++) {
                                double vj = Xrow[j];
                                atomicAdd(xp1f + j, elx * vj);
                                atomicAdd(grad_out + j, vj);
                            }
                            for (int j = 0; j < p; j++) {
                                double vj = Xrow[j];
                                for (int k = 0; k < p; k++)
                                    atomicAdd(xp2f + j * p + k, elx * vj * Xrow[k]);
                            }
                        }
                    }
                }
                __syncthreads();
                if (threadIdx.x == 0) {
                    efron_apply_tie_group(p, m, *xp0_ptr, *scratch_xp0f, xp1, xp2,
                                          xp1f, xp2f, grad_out, hess_acc);
                }
            }
        }
        __syncthreads();

        // ---- Exit: remove rows leaving the risk set from xp0/xp1/xp2. ----
        int x0 = exit_ptr[ii];
        int x1 = exit_ptr[ii + 1];
        int nx = x1 - x0;
        if (nx <= seq_thresh) {
            if (threadIdx.x == 0) {
                for (int t = x0; t < x1; t++) {
                    int idx = exit_ind[t];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    *xp0_ptr -= elx;
                    if (p <= EFRON_MAX_P_STACK) {
                        double row[EFRON_MAX_P_STACK];
                        for (int j = 0; j < p; j++) row[j] = Xrow[j];
                        for (int j = 0; j < p; j++) xp1[j] -= elx * row[j];
                        for (int j = 0; j < p; j++)
                            for (int k = 0; k < p; k++)
                                xp2[j * p + k] -= elx * row[j] * row[k];
                    } else {
                        for (int j = 0; j < p; j++) xp1[j] -= elx * Xrow[j];
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++)
                                xp2[j * p + k] -= elx * vj * Xrow[k];
                        }
                    }
                }
            }
        } else {
            // Stage A/B block accumulation for exit updates, mirroring enter side.
            if (p <= 64) {
                // Reuse the dynamic shared buffer with the enter-stage layout. Static
                // __shared__ arrays here would be allocated in addition to the dynamic
                // buffer on every launch (about 33 KB for 1 + 64 + 4096 doubles).
                // All earlier readers of the buffer finished before the preceding
                // __syncthreads(), so it is safe to overwrite.
                extern __shared__ double sh_mem3[];
                double* sh_xp0_x = sh_mem3;
                double* sh_xp1_x = sh_xp0_x + 1;
                double* sh_xp2_x = sh_xp1_x + 64;
                if (threadIdx.x == 0) *sh_xp0_x = 0.0;
                for (int j = threadIdx.x; j < p; j += blockDim.x) {
                    sh_xp1_x[j] = 0.0;
                }
                for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                    sh_xp2_x[j] = 0.0;
                }
                __syncthreads();
                for (int tt = threadIdx.x; tt < nx; tt += blockDim.x) {
                    int idx = exit_ind[x0 + tt];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    atomicAdd(sh_xp0_x, -elx);
                    for (int j = 0; j < p; j++) {
                        double vj = Xrow[j];
                        atomicAdd(sh_xp1_x + j, -elx * vj);
                    }
                    for (int j = 0; j < p; j++) {
                        double vj = Xrow[j];
                        for (int k = 0; k < p; k++) {
                            atomicAdd(sh_xp2_x + j * p + k, -elx * vj * Xrow[k]);
                        }
                    }
                }
                __syncthreads();
                if (threadIdx.x == 0) atomicAdd(xp0_ptr, *sh_xp0_x);
                for (int j = threadIdx.x; j < p; j += blockDim.x) {
                    atomicAdd(xp1 + j, sh_xp1_x[j]);
                }
                for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                    atomicAdd(xp2 + j, sh_xp2_x[j]);
                }
            } else {
                for (int tt = threadIdx.x; tt < nx; tt += blockDim.x) {
                    int idx = exit_ind[x0 + tt];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    atomicAdd(xp0_ptr, -elx);
                    if (p <= EFRON_MAX_P_STACK) {
                        double row[EFRON_MAX_P_STACK];
                        for (int j = 0; j < p; j++) row[j] = Xrow[j];
                        for (int j = 0; j < p; j++) atomicAdd(xp1 + j, -elx * row[j]);
                        for (int j = 0; j < p; j++)
                            for (int k = 0; k < p; k++)
                                atomicAdd(xp2 + j * p + k, -elx * row[j] * row[k]);
                    } else {
                        for (int j = 0; j < p; j++) atomicAdd(xp1 + j, -elx * Xrow[j]);
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++)
                                atomicAdd(xp2 + j * p + k, -elx * vj * Xrow[k]);
                        }
                    }
                }
            }
        }
        __syncthreads();
    }

    if (threadIdx.x == 0) {
        for (int j = 0; j < p * p; j++) {
            hess_out[j] = -hess_acc[j];
        }
    }
}
"""
456
+
457
# Single-thread variant of the Efron backward scan (gradient + Hessian).
# Only the thread with blockIdx.x == 0 and threadIdx.x == 0 does any work (the
# y/z indices are not checked, so launch it 1-D). It uses no shared memory and no
# atomics. It reads p = meta[1] and nuft = meta[2], and uses the same workspace
# layout as ``efron_backward_scan``
# (xp0 | xp1[p] | xp2[p*p] | hess_acc[p*p] | xp1f[p] | xp2f[p*p]) but never
# touches the trailing scratch slot. The scan walks the unique failure times from
# last to first: it adds entering rows, applies the Efron tie update for the
# failures, then subtracts exiting rows. ``hess_out`` receives the negated
# accumulator. NOTE: the EFRON_MAX_P_STACK macro is defined but unused here.
_KERNEL_SOURCE_SERIAL = r"""
#define EFRON_MAX_P_STACK 128
extern "C" __global__
void efron_backward_scan_serial(
const double* __restrict__ X,
const double* __restrict__ e_eta,
const int* __restrict__ meta,
const int* __restrict__ enter_ptr,
const int* __restrict__ enter_ind,
const int* __restrict__ exit_ptr,
const int* __restrict__ exit_ind,
const int* __restrict__ fail_ptr,
const int* __restrict__ fail_ind,
double* __restrict__ grad_out,
double* __restrict__ hess_out,
double* __restrict__ workspace
) {
if (blockIdx.x != 0 || threadIdx.x != 0) return;
int p = meta[1];
int nuft = meta[2];

double* xp0_ptr = workspace;
double* xp1 = xp0_ptr + 1;
double* xp2 = xp1 + p;
double* hess_acc = xp2 + p * p;
double* xp1f = hess_acc + p * p;
double* xp2f = xp1f + p;

*xp0_ptr = 0.0;
for (int j = 0; j < p; j++) {
grad_out[j] = 0.0;
xp1[j] = 0.0;
xp1f[j] = 0.0;
}
for (int j = 0; j < p * p; j++) {
xp2[j] = 0.0;
xp2f[j] = 0.0;
hess_acc[j] = 0.0;
}

for (int ii = nuft - 1; ii >= 0; ii--) {
int e0 = enter_ptr[ii];
int e1 = enter_ptr[ii + 1];
for (int t = e0; t < e1; t++) {
int idx = enter_ind[t];
const double* Xrow = X + (size_t)idx * (size_t)p;
double elx = e_eta[idx];
*xp0_ptr += elx;
for (int j = 0; j < p; j++) xp1[j] += elx * Xrow[j];
for (int j = 0; j < p; j++) {
double vj = Xrow[j];
for (int k = 0; k < p; k++) xp2[j * p + k] += elx * vj * Xrow[k];
}
}

int f0 = fail_ptr[ii];
int f1 = fail_ptr[ii + 1];
int m = f1 - f0;
if (m > 0) {
for (int j = 0; j < p; j++) xp1f[j] = 0.0;
for (int j = 0; j < p * p; j++) xp2f[j] = 0.0;
double xp0v = *xp0_ptr;
double xp0f = 0.0;

for (int t = f0; t < f1; t++) {
int idx = fail_ind[t];
const double* Xrow = X + (size_t)idx * (size_t)p;
double elx = e_eta[idx];
xp0f += elx;
for (int j = 0; j < p; j++) {
xp1f[j] += elx * Xrow[j];
grad_out[j] += Xrow[j];
}
for (int j = 0; j < p; j++) {
double vj = Xrow[j];
for (int k = 0; k < p; k++) xp2f[j * p + k] += elx * vj * Xrow[k];
}
}

double sum_inv_c0 = 0.0;
double sum_J_c0 = 0.0;
double sum_aa = 0.0;
double sum_bb = 0.0;
double sum_ab = 0.0;
for (int kk = 0; kk < m; kk++) {
double Jk = (double)kk / (double)m;
double c0 = xp0v - Jk * xp0f;
if (c0 < 1e-300) c0 = 1e-300;
double ak = 1.0 / c0;
double bk = Jk * ak;
sum_inv_c0 += ak;
sum_J_c0 += Jk / c0;
sum_aa += ak * ak;
sum_bb += bk * bk;
sum_ab += ak * bk;
}

for (int j = 0; j < p; j++) grad_out[j] -= (xp1[j] * sum_inv_c0 - xp1f[j] * sum_J_c0);
for (int j = 0; j < p * p; j++) {
hess_acc[j] += xp2[j] * sum_inv_c0;
hess_acc[j] -= xp2f[j] * sum_J_c0;
}
for (int j1 = 0; j1 < p; j1++) {
for (int j2 = j1; j2 < p; j2++) {
double o11 = xp1[j1] * xp1[j2];
double off = xp1f[j1] * xp1f[j2];
double cross = xp1[j1] * xp1f[j2] + xp1f[j1] * xp1[j2];
double hsub = sum_aa * o11 + sum_bb * off - sum_ab * cross;
hess_acc[j1 * p + j2] -= hsub;
if (j2 != j1) hess_acc[j2 * p + j1] -= hsub;
}
}
}

int x0 = exit_ptr[ii];
int x1 = exit_ptr[ii + 1];
for (int t = x0; t < x1; t++) {
int idx = exit_ind[t];
const double* Xrow = X + (size_t)idx * (size_t)p;
double elx = e_eta[idx];
*xp0_ptr -= elx;
for (int j = 0; j < p; j++) xp1[j] -= elx * Xrow[j];
for (int j = 0; j < p; j++) {
double vj = Xrow[j];
for (int k = 0; k < p; k++) xp2[j * p + k] -= elx * vj * Xrow[k];
}
}
}

for (int j = 0; j < p * p; j++) hess_out[j] = -hess_acc[j];
}
"""
589
+
590
# Default block width for ``efron_backward_scan`` when neither the shape
# heuristic nor an environment override picks another value.
# (Backward-kernel workspace layout, in doubles:
#  xp0(1) + xp1(p) + xp2(p*p) + hess_acc(p*p) + xp1f(p) + xp2f(p*p) + scratch(1).)
EFRON_BACKWARD_THREADS: int = 128

# Source versions for the two backward kernels. A getter rebuilds its cached
# RawKernel whenever the stored version differs from these values.
_KERNEL_VER = 9
_KERNEL_VER_SERIAL = 1

# Lazily populated ``(RawKernel, version)`` caches; ``None`` until first use.
_kernel_cache: Any = None
_kernel_cache_serial: Any = None
597
+
598
+
599
+ def _env_int(name: str, default: int) -> int:
600
+ """Parse env int with a safe fallback."""
601
+ v = os.environ.get(name)
602
+ if v is None:
603
+ return default
604
+ try:
605
+ return int(v)
606
+ except Exception:
607
+ return default
608
+
609
+
610
+ def _pick_backward_launch_params(p: int, nuft: int, n: int) -> Tuple[int, int]:
611
+ """Choose (seq_thresh, threads) with env override + sane defaults.
612
+
613
+ - `STATGPU_EFRON_SEQ_THRESH` >= 0 overrides seq threshold.
614
+ - `STATGPU_EFRON_BACKWARD_THREADS` > 0 overrides threads.
615
+ - Otherwise choose a heavier sequential threshold for small/medium `p`,
616
+ which often reduces atomic overhead in heavy-ties cases.
617
+ """
618
+ seq_env = _env_int("STATGPU_EFRON_SEQ_THRESH", -1)
619
+ th_env = _env_int("STATGPU_EFRON_BACKWARD_THREADS", -1)
620
+
621
+ if th_env > 0:
622
+ threads = int(th_env)
623
+ else:
624
+ # Small/medium shapes are usually latency-sensitive; fewer threads can help.
625
+ if p <= 24 and nuft <= 512:
626
+ threads = 64
627
+ elif p <= 64 and nuft <= 512:
628
+ threads = 64
629
+ else:
630
+ threads = EFRON_BACKWARD_THREADS
631
+
632
+ if seq_env >= 0:
633
+ seq_thresh = int(seq_env)
634
+ else:
635
+ # Favor single-thread local accumulation for moderate group sizes.
636
+ avg_group = float(n) / max(1.0, float(nuft))
637
+ if p <= 24:
638
+ seq_thresh = 64
639
+ elif p <= 64:
640
+ # Heavy ties (large n/nuft) are often dominated by atomic contention.
641
+ # Favor larger sequential groups to reduce p^2 atomic pressure.
642
+ if avg_group >= 64.0:
643
+ seq_thresh = 256
644
+ elif avg_group >= 32.0:
645
+ seq_thresh = 128
646
+ else:
647
+ seq_thresh = 64
648
+ else:
649
+ seq_thresh = 16
650
+
651
+ return seq_thresh, threads
652
+
653
+
654
+ def _pack_csr(groups: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
655
+ ptr = [0]
656
+ ind: List[int] = []
657
+ for g in groups:
658
+ ind.extend(int(x) for x in g)
659
+ ptr.append(len(ind))
660
+ return np.asarray(ptr, dtype=np.int32), np.asarray(ind, dtype=np.int32)
661
+
662
+
663
def efron_indices_to_csr(
    uft_ix: List[List[int]], risk_enter: List[List[int]], risk_exit: List[List[int]], nuft: int
) -> Tuple[np.ndarray, ...]:
    """Pack per-failure-time index lists into the CSR arrays the Efron kernels read.

    Args:
        uft_ix: One group per unique failure time, holding the row indices whose
            failures are accumulated into that time's Efron tie update.
        risk_enter: One group per unique failure time, holding the rows added to
            the running risk-set sums at that time.
        risk_exit: One group per unique failure time, holding the rows removed
            from the running risk-set sums after that time is processed.
        nuft: Number of unique failure times; each input must have exactly this
            many groups.

    Returns:
        ``(enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind)`` as
        int32 arrays. Each ``*_ptr`` has length ``nuft + 1``.

    Raises:
        ValueError: If any input does not contain exactly ``nuft`` groups.
    """
    enter_ptr, enter_ind = _pack_csr(risk_enter)
    exit_ptr, exit_ind = _pack_csr(risk_exit)
    fail_ptr, fail_ind = _pack_csr(uft_ix)
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    # The kernels read ``*_ptr[ii + 1]`` for every ii < nuft, so a short pointer
    # array would make them read out of bounds.
    for label, ptr in (
        ("risk_enter", enter_ptr),
        ("risk_exit", exit_ptr),
        ("uft_ix", fail_ptr),
    ):
        if ptr.size != nuft + 1:
            raise ValueError(
                f"{label} has {ptr.size - 1} groups; expected nuft={nuft}"
            )
    return enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind
671
+
672
+
673
def get_efron_backward_kernel(cp):
    """Return the compiled ``efron_backward_scan`` kernel, building it on demand.

    The kernel is created with ``cp.RawKernel(_KERNEL_SOURCE, ...)`` and kept in
    the module-level ``_kernel_cache`` as ``(kernel, _KERNEL_VER)``. The cache is
    rebuilt if it holds anything other than a tuple tagged with the current
    ``_KERNEL_VER``.

    Args:
        cp: The CuPy module, or any object with a compatible ``RawKernel``.

    Returns:
        The cached kernel object.
    """
    global _kernel_cache
    cached = _kernel_cache
    if isinstance(cached, tuple) and cached[1] == _KERNEL_VER:
        return cached[0]
    kernel = cp.RawKernel(_KERNEL_SOURCE, "efron_backward_scan")
    _kernel_cache = (kernel, _KERNEL_VER)
    return kernel
682
+
683
+
684
def get_efron_backward_kernel_serial(cp):
    """Return the compiled ``efron_backward_scan_serial`` kernel, building it on demand.

    The kernel is created with ``cp.RawKernel(_KERNEL_SOURCE_SERIAL, ...)`` and
    kept in ``_kernel_cache_serial`` as ``(kernel, _KERNEL_VER_SERIAL)``. The cache
    is rebuilt if it holds anything other than a tuple tagged with the current
    ``_KERNEL_VER_SERIAL``.

    Args:
        cp: The CuPy module, or any object with a compatible ``RawKernel``.

    Returns:
        The cached kernel object.
    """
    global _kernel_cache_serial
    cached = _kernel_cache_serial
    if isinstance(cached, tuple) and cached[1] == _KERNEL_VER_SERIAL:
        return cached[0]
    kernel = cp.RawKernel(_KERNEL_SOURCE_SERIAL, "efron_backward_scan_serial")
    _kernel_cache_serial = (kernel, _KERNEL_VER_SERIAL)
    return kernel
696
+
697
+
698
# Block width intended for launching ``efron_loglik_by_group`` (the launch site
# is not shown here; confirm there). It must not exceed the kernel's own
# EFRON_LOGLIK_THREADS macro (128), which sizes the shared reduction buffers, and
# it must be a power of two because the reductions halve blockDim.x each step.
_LOGLIK_THREADS: int = 128

# Per-group Efron log-likelihood terms, one CUDA block per unique failure time
# g = blockIdx.x (blocks with g >= meta[0] return immediately):
#   out_ll[g] = sum(eta[i] for failures i in group g)
#               - sum_{k=0}^{m-1} log(max(R - (k/m) * E, 1e-300))
# Here m is the group's failure count, E = sum(exp_eta[i]) over those failures,
# and R = risk_sum[first_idx_uft[g]]. R is presumably the risk-set sum of
# exp(eta) at that time; confirm against the caller. Both reductions are
# shared-memory trees over EFRON_LOGLIK_THREADS doubles, so the block size must
# be a power of two no larger than 128.
_LOGLIK_KERNEL_SOURCE = r"""
#define EFRON_LOGLIK_THREADS 128
extern "C" __global__
void efron_loglik_by_group(
const double* __restrict__ eta,
const double* __restrict__ exp_eta,
const double* __restrict__ risk_sum,
const int* __restrict__ meta, // meta[0] = nuft
const int* __restrict__ fail_ptr,
const int* __restrict__ fail_ind,
const int* __restrict__ first_idx_uft,
double* __restrict__ out_ll
) {
int tid = (int)threadIdx.x;
int g = (int)blockIdx.x;
int nuft = meta[0];
if (g >= nuft) return;

int start = fail_ptr[g];
int end = fail_ptr[g + 1];
int m = end - start;

__shared__ double sh_events[EFRON_LOGLIK_THREADS];
__shared__ double sh_eta[EFRON_LOGLIK_THREADS];
__shared__ double sh_logs[EFRON_LOGLIK_THREADS];

double local_events = 0.0;
double local_eta = 0.0;
for (int i = start + tid; i < end; i += (int)blockDim.x) {
int idx = fail_ind[i];
double ex = exp_eta[idx];
local_events += ex;
local_eta += eta[idx];
}

sh_events[tid] = local_events;
sh_eta[tid] = local_eta;
__syncthreads();

// Reduce sum_events and sum_eta.
for (int stride = (int)blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
sh_events[tid] += sh_events[tid + stride];
sh_eta[tid] += sh_eta[tid + stride];
}
__syncthreads();
}

double sum_events = sh_events[0];
double sum_eta = sh_eta[0];
double risk_at_t = risk_sum[first_idx_uft[g]];

double local_logs = 0.0;
for (int k = tid; k < m; k += (int)blockDim.x) {
double Jk = (double)k / (double)m;
double denom = risk_at_t - Jk * sum_events;
if (denom < 1e-300) denom = 1e-300;
local_logs += log(denom);
}

sh_logs[tid] = local_logs;
__syncthreads();

// Reduce sum_logs.
for (int stride = (int)blockDim.x / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
sh_logs[tid] += sh_logs[tid + stride];
}
__syncthreads();
}

if (tid == 0) {
// ll = sum(eta[idx]) - sum_{k=0..m-1} log(risk_at_t - k/m * sum_events)
out_ll[g] = sum_eta - sh_logs[0];
}
}
"""
777
+
778
+ _kernel_cache_loglik: Any = None
779
+ _KERNEL_VER_LOGLIK = 1
780
+
781
+
782
def get_efron_loglik_kernel(cp):
    """Return the ``efron_loglik_by_group`` RawKernel, reusing the module cache.

    ``_kernel_cache_loglik`` keeps a ``(kernel, version)`` tuple. When it is
    unset, not a tuple, or its version differs from ``_KERNEL_VER_LOGLIK``, a
    fresh kernel is built from ``_LOGLIK_KERNEL_SOURCE`` and stored.

    Parameters
    ----------
    cp : module
        The CuPy module providing ``RawKernel``.
    """
    global _kernel_cache_loglik
    entry = _kernel_cache_loglik
    if isinstance(entry, tuple) and entry[1] == _KERNEL_VER_LOGLIK:
        return entry[0]
    loglik_kernel = cp.RawKernel(_LOGLIK_KERNEL_SOURCE, "efron_loglik_by_group")
    _kernel_cache_loglik = (loglik_kernel, _KERNEL_VER_LOGLIK)
    return loglik_kernel
794
+
795
+
796
def compute_efron_loglik_raw_csr(
    eta,
    exp_eta,
    risk_sum,
    fail_ptr,
    fail_ind,
    first_idx_uft,
    nuft: int,
    *,
    cupy_module,
) -> Any:
    """
    Compute the scalar Efron log partial likelihood on GPU using a single kernel.

    Launches ``efron_loglik_by_group`` with one block of ``_LOGLIK_THREADS``
    threads per unique failure time and sums the per-group contributions.

    Parameters
    ----------
    eta, exp_eta, risk_sum : array
        Linear predictor, its exponential, and the risk-set sums indexed by
        ``first_idx_uft``. The kernel reads them as raw ``double`` buffers, so
        they are converted to C-contiguous float64. This is a no-op when they
        already are.
    fail_ptr, fail_ind : array
        CSR arrays for ``uft_ix``. ``fail_ptr`` must hold ``nuft + 1`` offsets.
        Both are converted to C-contiguous int32.
    first_idx_uft : array
        Per-group index into ``risk_sum``. It needs at least ``nuft`` entries and
        is converted to C-contiguous int32.
    nuft : int
        Number of unique failure times.
    cupy_module : module
        The CuPy module.

    Returns
    -------
    0-d float64 array holding the log-likelihood (0.0 when ``nuft == 0``).

    Raises
    ------
    ValueError
        If ``fail_ptr`` or ``first_idx_uft`` is too short for ``nuft`` groups;
        the kernel would otherwise read out of bounds.
    Exception
        Kernel compilation/launch errors propagate unchanged so the caller can
        fall back to the Python loop.
    """
    cp = cupy_module
    nuft = int(nuft)
    if nuft == 0:
        return cp.array(0.0, dtype=cp.float64)

    def _contig(a, dtype):
        # asarray/ascontiguousarray return the input itself when it already has
        # the requested dtype and C layout, so this only copies when required.
        return cp.ascontiguousarray(cp.asarray(a, dtype=dtype))

    # The kernel parameters are ``const double*`` / ``const int*``. A float32 or
    # strided array would be silently misread, so normalise dtype and layout.
    eta = _contig(eta, cp.float64)
    exp_eta = _contig(exp_eta, cp.float64)
    risk_sum = _contig(risk_sum, cp.float64)
    fail_ptr_g = _contig(fail_ptr, cp.int32)
    fail_ind_g = _contig(fail_ind, cp.int32)
    first_idx_uft_g = _contig(first_idx_uft, cp.int32)

    # Block g reads fail_ptr[g + 1] and first_idx_uft[g] for g < nuft.
    if int(fail_ptr_g.size) != nuft + 1:
        raise ValueError(
            f"fail_ptr must have nuft + 1 = {nuft + 1} entries, got {int(fail_ptr_g.size)}"
        )
    if int(first_idx_uft_g.size) < nuft:
        raise ValueError(
            f"first_idx_uft must have at least nuft = {nuft} entries, got {int(first_idx_uft_g.size)}"
        )

    out_ll = cp.zeros(nuft, dtype=cp.float64)
    meta = cp.array([nuft], dtype=cp.int32)
    kernel = get_efron_loglik_kernel(cp)
    # Launch errors are deliberately not caught here: the caller handles them
    # and falls back to the Python loop.
    kernel(
        (nuft,),
        (_LOGLIK_THREADS,),
        (
            eta,
            exp_eta,
            risk_sum,
            meta,
            fail_ptr_g,
            fail_ind_g,
            first_idx_uft_g,
            out_ll,
        ),
    )
    return cp.sum(out_ll)
849
+
850
+
851
def compute_efron_loglik_raw(eta, exp_eta, risk_sum, time, efron_pre, *, cupy_module):
    """
    Efron partial log-likelihood (scalar) evaluated on the GPU.

    A Python loop visits each cached unique-failure-time group. Within a group,
    the Efron correction sum over ``k = 0..d-1`` is vectorised. When
    ``efron_pre`` carries a sixth element ``first_idx_uft`` (produced by
    ``_efron_unique_failure_indices``), each group's risk-set start index is
    read from it. Otherwise the start is found with a host-side ``searchsorted``
    on ``time``.

    Passing ``cupy_module=None`` delegates to the NumPy implementation, which is
    the Torch backend's CPU fallback.
    """
    xp = cupy_module

    # Torch backend fallback: compute on the host with NumPy.
    if xp is None:
        return _compute_efron_loglik_raw_numpy(eta, exp_eta, risk_sum, time, efron_pre)

    if len(efron_pre) == 6:
        uft_arr, uft_ix, _, _, nuft, first_idx_uft = efron_pre
    else:
        (uft_arr, uft_ix, _, _, nuft), first_idx_uft = efron_pre, None
    if nuft == 0:
        return xp.array(0.0, dtype=xp.float64)

    total = xp.zeros((), dtype=xp.float64)
    if first_idx_uft is None:
        starts_dev = None
        host_time = xp.asnumpy(time).astype(np.float64, copy=False)
    else:
        starts_dev = xp.asarray(first_idx_uft, dtype=xp.int32)
        host_time = None

    for g in range(nuft):
        members = uft_ix[g]
        d = len(members)
        if not d:
            continue
        if starts_dev is None:
            start = int(np.searchsorted(host_time, float(uft_arr[g]), side="left"))
        else:
            start = starts_dev[g]
        risk_at_t = risk_sum[start]
        rows = xp.asarray(members, dtype=xp.int32)
        tie_mass = xp.sum(exp_eta[rows])
        frac = xp.arange(d, dtype=xp.float64) / float(d)
        total -= xp.sum(xp.log(xp.maximum(risk_at_t - frac * tie_mass, 1e-300)))
        total += xp.sum(eta[rows])
    return total
898
+
899
+
900
+ def _compute_efron_loglik_raw_numpy(eta, exp_eta, risk_sum, time, efron_pre):
901
+ """
902
+ NumPy implementation of Efron log-likelihood for Torch backend fallback.
903
+
904
+ Parameters
905
+ ----------
906
+ eta : ndarray
907
+ Linear predictor values (n_samples,)
908
+ exp_eta : ndarray
909
+ exp(eta) values (n_samples,)
910
+ risk_sum : ndarray
911
+ Cumulative risk sums (n_samples,)
912
+ time : ndarray
913
+ Event times (n_samples,)
914
+ efron_pre : tuple
915
+ Precomputed failure time indices from _efron_unique_failure_indices
916
+
917
+ Returns
918
+ -------
919
+ float
920
+ Log-likelihood value
921
+ """
922
+ if len(efron_pre) == 6:
923
+ uft_arr, uft_ix, _, _, nuft, first_idx_uft = efron_pre
924
+ else:
925
+ uft_arr, uft_ix, _, _, nuft = efron_pre
926
+ first_idx_uft = None
927
+
928
+ if nuft == 0:
929
+ return 0.0
930
+
931
+ ll = 0.0
932
+ time_np = None
933
+ if first_idx_uft is None:
934
+ time_np = time.astype(np.float64, copy=False)
935
+
936
+ for i in range(nuft):
937
+ ix_ev = uft_ix[i]
938
+ d = len(ix_ev)
939
+ if d == 0:
940
+ continue
941
+
942
+ if first_idx_uft is not None:
943
+ first_idx = first_idx_uft[i]
944
+ else:
945
+ first_idx = int(np.searchsorted(time_np, float(uft_arr[i]), side="left"))
946
+
947
+ risk_at_t = risk_sum[first_idx]
948
+ sum_events = np.sum(exp_eta[ix_ev])
949
+ kd = float(d)
950
+
951
+ # Efron correction: loop over k=0..d-1
952
+ k = np.arange(d, dtype=np.float64)
953
+ denom = risk_at_t - (k / kd) * sum_events
954
+ ll -= np.sum(np.log(np.maximum(denom, 1e-300)))
955
+ ll += np.sum(eta[ix_ev])
956
+
957
+ return ll
958
+
959
+
960
def compute_efron_grad_hess_raw(
    X,
    beta,
    efron_pre,
    *,
    cupy_module,
    efron_csr=None,
) -> Tuple[Any, Any]:
    """
    Efron gradient and Hessian computed by a single fused CuPy RawKernel launch.

    Parameters
    ----------
    X : array, shape (n, p)
        Design matrix. It is converted to a C-contiguous float64 array, a no-op
        when it already is one, because the kernel receives its raw buffer.
    beta : array, shape (p,)
        Coefficients. ``X @ beta`` is shifted by its maximum before ``exp``.
    efron_pre : tuple
        ``(uft_arr, uft_ix, risk_enter, risk_exit, nuft[, first_idx_uft])``.
        Only used when ``efron_csr`` is None; it is then packed into CSR form
        with ``efron_indices_to_csr``.
    cupy_module : module
        The CuPy module.
    efron_csr : tuple, optional
        Pre-packed ``(enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr,
        fail_ind, first_idx_uft, nuft)``.

    Returns
    -------
    (grad, hess) : float64 arrays of shape (p,) and (p, p), both zero when there
    are no unique failure times. Returns ``None`` instead of the tuple if
    compiling, launching or running the kernel raises; the caller then uses the
    Python path. The return annotation predates that ``None`` case.
    """
    cp = cupy_module
    if efron_csr is not None:
        # (enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind, first_idx_uft, nuft)
        enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind, _, nuft = efron_csr
    else:
        if len(efron_pre) == 6:
            _, uft_ix, risk_enter, risk_exit, nuft, _ = efron_pre
        else:
            _, uft_ix, risk_enter, risk_exit, nuft = efron_pre
        enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind = efron_indices_to_csr(
            uft_ix, risk_enter, risk_exit, nuft
        )

    nuft = int(nuft)
    n, p = int(X.shape[0]), int(X.shape[1])
    if nuft == 0:
        return cp.zeros(p, dtype=cp.float64), cp.zeros((p, p), dtype=cp.float64)

    # The kernel is handed raw pointers. A Fortran-ordered or float32 X, or a
    # non-contiguous exp(eta), would be silently misread, so normalise both.
    # These calls do not copy when the input already conforms.
    X = cp.ascontiguousarray(cp.asarray(X, dtype=cp.float64))
    linpred = X @ beta
    linpred = linpred - cp.max(linpred)
    e_eta = cp.ascontiguousarray(cp.exp(linpred), dtype=cp.float64)

    # CSR index arrays keep their dtype; only their layout is made contiguous.
    enter_ptr_g = cp.ascontiguousarray(cp.asarray(enter_ptr))
    enter_ind_g = cp.ascontiguousarray(cp.asarray(enter_ind))
    exit_ptr_g = cp.ascontiguousarray(cp.asarray(exit_ptr))
    exit_ind_g = cp.ascontiguousarray(cp.asarray(exit_ind))
    fail_ptr_g = cp.ascontiguousarray(cp.asarray(fail_ptr))
    fail_ind_g = cp.ascontiguousarray(cp.asarray(fail_ind))

    grad_out = cp.zeros(p, dtype=cp.float64)
    hess_out = cp.zeros((p, p), dtype=cp.float64)
    workspace = cp.zeros(2 + 2 * p + 3 * p * p, dtype=cp.float64)

    seq_thresh, threads = _pick_backward_launch_params(p, nuft, n)
    # The single-thread serial kernel is used for small problems unless
    # disabled via STATGPU_EFRON_SERIAL_KERNEL != 1.
    use_serial = (
        p <= 24
        and nuft <= 512
        and _env_int("STATGPU_EFRON_SERIAL_KERNEL", 1) == 1
    )
    meta = cp.array([n, p, nuft, seq_thresh], dtype=cp.int32)
    kernel = get_efron_backward_kernel_serial(cp) if use_serial else get_efron_backward_kernel(cp)
    shared_mem = 33800 if (not use_serial and p <= 64) else 0
    try:
        kernel(
            (1,),
            ((1,) if use_serial else (threads,)),
            (
                X,
                e_eta,
                meta,
                enter_ptr_g,
                enter_ind_g,
                exit_ptr_g,
                exit_ind_g,
                fail_ptr_g,
                fail_ind_g,
                grad_out,
                hess_out.reshape(-1),
                workspace,
            ),
            shared_mem=shared_mem,
        )
        # Surface asynchronous kernel execution errors at this call site so fallback
        # behavior is reliable and diagnostics point to the correct launch.
        cp.cuda.Stream.null.synchronize()
    except Exception:
        # Best-effort fast path: signal the caller to use the Python implementation.
        return None

    return grad_out, hess_out
1037
+
1038
+
1039
# Version tags and ``(RawKernel, version)`` caches for the Breslow kernels.
# The matching getter rebuilds its kernel whenever the cached tag differs from
# the version constant, so bump the constant when the source changes.
_BRESLOW_KERNEL_VER = 1
_breslow_kernel_cache = None
_BRESLOW_UPDATE_KERNEL_VER = 1
_breslow_update_kernel_cache = None

# CUDA source for ``breslow_backward_hess_scan``, a single-block kernel. Any
# block other than (0, 0, 0), and any thread with a non-zero y/z index,
# returns immediately.
#
# Workspace layout, ``1 + p + 2*p*p`` doubles:
#   S0      = sum of e_eta
#   S1[p]   = sum of e_eta * x
#   S2[p*p] = sum of e_eta * x x^T
#   then a (p*p) Hessian accumulator.
#
# The kernel walks unique failure times from last to first. At step ii it adds
# rows ``[first_idx[ii], first_idx[ii + 1])`` to the running sums, or
# ``[first_idx[nuft - 1], n)`` for the last time. So at step ii the sums cover
# rows first_idx[ii]..n-1. This is presumably the risk set, which assumes the
# rows are sorted by ascending time; confirm with the caller.
#
# Slices with at most ``seq_thresh`` rows are accumulated by thread 0 alone.
# Larger slices are split across threads using double-precision ``atomicAdd``.
# Thread 0 then adds ``counts[ii] * (S2 / S0 - S1 S1^T / S0^2)`` to the
# accumulator. This is skipped when S0 <= 1e-300 or counts[ii] == 0. Finally
# ``hess_out = -accumulator``. X is read as a row-major (n, p) double matrix
# (``X + r * p``).
_BRESLOW_KERNEL_SOURCE = r"""
extern "C" __global__
void breslow_backward_hess_scan(
const double* __restrict__ X,
const double* __restrict__ e_eta,
const int* __restrict__ first_idx,
const double* __restrict__ counts,
const int* __restrict__ meta, // [n, p, nuft, seq_thresh]
double* __restrict__ hess_out,
double* __restrict__ workspace
) {
int n = meta[0];
int p = meta[1];
int nuft = meta[2];
int seq_thresh = meta[3];
if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return;
if (threadIdx.y != 0 || threadIdx.z != 0) return;

double* xp0_ptr = workspace; // 1
double* xp1 = xp0_ptr + 1; // p
double* xp2 = xp1 + p; // p*p
double* hess_acc = xp2 + p * p; // p*p
int ws_doubles = 1 + p + 2 * p * p;

for (int i = threadIdx.x; i < ws_doubles; i += blockDim.x) {
workspace[i] = 0.0;
}
__syncthreads();

for (int ii = nuft - 1; ii >= 0; --ii) {
int start = first_idx[ii];
int end = (ii == nuft - 1) ? n : first_idx[ii + 1];
int nt = end - start;

if (nt > 0) {
if (nt <= seq_thresh) {
if (threadIdx.x == 0) {
for (int r = start; r < end; ++r) {
const double* Xrow = X + (size_t)r * (size_t)p;
double elx = e_eta[r];
*xp0_ptr += elx;
for (int j = 0; j < p; ++j) {
double vj = Xrow[j];
xp1[j] += elx * vj;
}
for (int j = 0; j < p; ++j) {
double vj = Xrow[j];
for (int k = 0; k < p; ++k) {
xp2[j * p + k] += elx * vj * Xrow[k];
}
}
}
}
} else {
for (int rr = threadIdx.x; rr < nt; rr += blockDim.x) {
int r = start + rr;
const double* Xrow = X + (size_t)r * (size_t)p;
double elx = e_eta[r];
atomicAdd(xp0_ptr, elx);
for (int j = 0; j < p; ++j) {
atomicAdd(xp1 + j, elx * Xrow[j]);
}
for (int j = 0; j < p; ++j) {
double vj = Xrow[j];
for (int k = 0; k < p; ++k) {
atomicAdd(xp2 + j * p + k, elx * vj * Xrow[k]);
}
}
}
}
__syncthreads();
}

if (threadIdx.x == 0) {
double rs = *xp0_ptr;
double w = counts[ii];
if (rs > 1e-300 && w != 0.0) {
double inv = 1.0 / rs;
double inv2 = inv * inv;
for (int j1 = 0; j1 < p; ++j1) {
double x1 = xp1[j1];
for (int j2 = 0; j2 < p; ++j2) {
double exx = xp2[j1 * p + j2] * inv;
double ex1ex2 = x1 * xp1[j2] * inv2;
hess_acc[j1 * p + j2] += w * (exx - ex1ex2);
}
}
}
}
__syncthreads();
}

for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
hess_out[j] = -hess_acc[j];
}
}
"""
1141
+
1142
+
1143
def get_breslow_hess_kernel(cp):
    """Return the ``breslow_backward_hess_scan`` RawKernel, rebuilding the cache if stale.

    ``_breslow_kernel_cache`` holds a ``(kernel, version)`` tuple. A new kernel
    is constructed from ``_BRESLOW_KERNEL_SOURCE`` when the cache is unset, is
    not a tuple, or is tagged with a version other than ``_BRESLOW_KERNEL_VER``.

    Parameters
    ----------
    cp : module
        The CuPy module providing ``RawKernel``.
    """
    global _breslow_kernel_cache
    slot = _breslow_kernel_cache
    if isinstance(slot, tuple) and slot[1] == _BRESLOW_KERNEL_VER:
        return slot[0]
    hess_kernel = cp.RawKernel(_BRESLOW_KERNEL_SOURCE, "breslow_backward_hess_scan")
    _breslow_kernel_cache = (hess_kernel, _BRESLOW_KERNEL_VER)
    return hess_kernel
1155
+
1156
+
1157
# CUDA source for ``breslow_hess_update``. There is one thread per entry
# ``idx = j1 * p + j2`` of a row-major (p, p) ``hess``, and each computes
#   hess[idx] -= count * (risk_x2[idx] / rs - ex[j1] * ex[j2])
# Threads with idx >= p*p exit immediately. The kernel has no guard against
# rs == 0, so callers must ensure rs is non-zero.
_BRESLOW_UPDATE_KERNEL_SOURCE = r"""
extern "C" __global__
void breslow_hess_update(
double* __restrict__ hess,
const double* __restrict__ risk_x2,
const double* __restrict__ ex,
const double rs,
const double count,
const int p
) {
int idx = (int)(blockIdx.x * blockDim.x + threadIdx.x);
int total = p * p;
if (idx >= total) return;
int j1 = idx / p;
int j2 = idx - j1 * p;
double exx = risk_x2[idx] / rs;
double outer = ex[j1] * ex[j2];
hess[idx] -= count * (exx - outer);
}
"""
1177
+
1178
+
1179
def get_breslow_hess_update_kernel(cp):
    """Return the ``breslow_hess_update`` RawKernel from the module-level cache.

    ``_breslow_update_kernel_cache`` stores ``(kernel, version)``. It is
    refreshed from ``_BRESLOW_UPDATE_KERNEL_SOURCE`` when it is unset, not a
    tuple, or carries a version other than ``_BRESLOW_UPDATE_KERNEL_VER``.

    Parameters
    ----------
    cp : module
        The CuPy module providing ``RawKernel``.
    """
    global _breslow_update_kernel_cache
    stored = _breslow_update_kernel_cache
    if isinstance(stored, tuple) and stored[1] == _BRESLOW_UPDATE_KERNEL_VER:
        return stored[0]
    update_kernel = cp.RawKernel(_BRESLOW_UPDATE_KERNEL_SOURCE, "breslow_hess_update")
    _breslow_update_kernel_cache = (update_kernel, _BRESLOW_UPDATE_KERNEL_VER)
    return update_kernel
1191
+
1192
+
1193
def apply_breslow_hess_update_raw(hess, risk_x2, ex, rs, count, *, cupy_module):
    """
    In place on the GPU: ``hess -= count * (risk_x2 / rs - outer(ex, ex))``.

    Parameters
    ----------
    hess : cupy array with p*p elements
        Updated in place. If it is not C-contiguous float64, the update is done
        on a contiguous float64 copy and written back.
    risk_x2 : array with p*p elements
        Row-major (p, p) second-moment sums.
    ex : array, shape (p,)
        Vector whose outer product is subtracted; ``p`` is taken from its length.
    rs : float
        Divisor for ``risk_x2``. It must be non-zero; the kernel does not check.
    count : float
        Weight applied to the whole term.
    cupy_module : module
        The CuPy module.

    Returns
    -------
    None. Returns immediately when ``p == 0``.

    Raises
    ------
    ValueError
        If ``hess`` or ``risk_x2`` does not have ``p * p`` elements, which would
        otherwise cause out-of-bounds kernel accesses.
    """
    cp = cupy_module
    p = int(ex.shape[0])
    if p <= 0:
        return
    n_entries = p * p
    if int(hess.size) != n_entries or int(risk_x2.size) != n_entries:
        raise ValueError(
            f"hess and risk_x2 must have p*p = {n_entries} elements "
            f"(got {int(hess.size)} and {int(risk_x2.size)})"
        )

    # The kernel treats every buffer as a flat, row-major float64 array.
    # ``reshape(-1)`` of a non-contiguous ``hess`` would return a *copy*, and the
    # in-place update would be silently lost. So update a contiguous float64
    # buffer (``hess`` itself when it already qualifies) and copy back if needed.
    hess_buf = cp.ascontiguousarray(hess, dtype=cp.float64)
    risk_x2_buf = cp.ascontiguousarray(risk_x2, dtype=cp.float64)
    ex_buf = cp.ascontiguousarray(ex, dtype=cp.float64)

    threads = 256
    blocks = (n_entries + threads - 1) // threads
    kernel = get_breslow_hess_update_kernel(cp)
    kernel(
        (blocks,),
        (threads,),
        (
            hess_buf.reshape(-1),
            risk_x2_buf.reshape(-1),
            ex_buf,
            float(rs),
            float(count),
            # CuPy passes Python ints as ``long long``; the kernel expects ``int``.
            cp.int32(p),
        ),
    )
    if hess_buf is not hess:
        hess[...] = hess_buf
1213
+
1214
+
1215
def compute_breslow_hess_raw(
    X,
    first_idx_uft,
    counts_uft,
    *,
    cupy_module,
    exp_eta=None,
    beta=None,
):
    """
    Fused CuPy RawKernel Hessian for Breslow grouped ties.

    Parameters
    ----------
    X : array, shape (n, p)
        Design matrix. The kernel indexes it as a row-major float64 buffer, so it
        is converted to C-contiguous float64 (no copy when already compliant).
    first_idx_uft : array, shape (nuft,)
        Start row of each unique failure time's slice; cast to int32.
    counts_uft : array, shape (nuft,)
        Per-failure-time weight that multiplies that time's covariance term;
        cast to float64.
    cupy_module : module
        The CuPy module.
    exp_eta : array, shape (n,), optional
        Precomputed exponentiated linear predictor. When omitted it is computed
        from ``beta`` as ``exp(X @ beta - max(X @ beta))``.
    beta : array, shape (p,), optional
        Coefficients; required when ``exp_eta`` is not given.

    Environment
    -----------
    STATGPU_BRESLOW_HESS_THREADS
        Optional thread-count override, clamped to [32, 512].
    STATGPU_BRESLOW_SEQ_THRESH
        Optional sequential-slice threshold, at least 1.
    Malformed values are ignored.

    Returns
    -------
    (p, p) float64 array, all zeros when there are no unique failure times.
    Returns ``None`` if compiling, launching or executing the kernel raises.

    Raises
    ------
    ValueError
        If there is at least one unique failure time and neither ``exp_eta``
        nor ``beta`` is given.
    """
    cp = cupy_module
    nuft = int(first_idx_uft.size)
    p = int(X.shape[1])
    if nuft == 0:
        return cp.zeros((p, p), dtype=cp.float64)

    n = int(X.shape[0])
    if exp_eta is None and beta is None:
        raise ValueError("compute_breslow_hess_raw requires either exp_eta or beta")

    # Raw pointers are handed to the kernel (``X + r * p`` row-major doubles),
    # so normalise dtype and layout. This does not copy when already compliant.
    X = cp.ascontiguousarray(cp.asarray(X, dtype=cp.float64))
    if exp_eta is None:
        linpred = X @ beta
        linpred = linpred - cp.max(linpred)
        exp_eta = cp.exp(linpred)
    e_eta = cp.ascontiguousarray(cp.asarray(exp_eta, dtype=cp.float64))

    first_idx_g = cp.ascontiguousarray(cp.asarray(first_idx_uft, dtype=cp.int32))
    counts_g = cp.ascontiguousarray(cp.asarray(counts_uft, dtype=cp.float64))
    hess_out = cp.zeros((p, p), dtype=cp.float64)
    # Layout expected by the kernel: 1 (S0) + p (S1) + p*p (S2) + p*p (accumulator).
    workspace = cp.zeros(1 + p + 2 * p * p, dtype=cp.float64)

    # Keep stable default launch behavior, allow opt-in manual tuning.
    seq_thresh, threads = _pick_backward_launch_params(p, nuft, n)
    th_env = os.environ.get("STATGPU_BRESLOW_HESS_THREADS", "").strip()
    if th_env:
        try:
            threads = max(32, min(512, int(th_env)))
        except ValueError:
            pass  # malformed override: keep the tuned default
    seq_env = os.environ.get("STATGPU_BRESLOW_SEQ_THRESH", "").strip()
    if seq_env:
        try:
            seq_thresh = max(1, int(seq_env))
        except ValueError:
            pass  # malformed override: keep the tuned default

    meta = cp.array([n, p, nuft, seq_thresh], dtype=cp.int32)
    kernel = get_breslow_hess_kernel(cp)
    try:
        kernel(
            (1,),
            (threads,),
            (
                X,
                e_eta,
                first_idx_g,
                counts_g,
                meta,
                hess_out.reshape(-1),
                workspace,
            ),
        )
        # Kernel execution is asynchronous. Synchronize so that execution errors
        # are raised here, inside the try, and the documented ``None`` fallback
        # actually triggers.
        cp.cuda.Stream.null.synchronize()
    except Exception:
        return None
    return hess_out