statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1280 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CUDA RawKernel for Cox PH Efron backward gradient/Hessian.
|
|
3
|
+
|
|
4
|
+
Sequential scan over unique failure times (ii). Enter/exit/failure-at-risk updates are
|
|
5
|
+
commutative; large index lists use ``atomicAdd`` (double, sm_60+), small lists (<= ``seq_thresh``)
|
|
6
|
+
use thread-0 sequential adds to avoid atomic overhead. Failure accumulation for large ``m`` is
|
|
7
|
+
parallel; Efron formulas remain on thread 0. Workspace ends with a scratch double for parallel
|
|
8
|
+
``xp0f`` sum.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import os
|
|
14
|
+
from typing import Any, List, Optional, Tuple
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
# CUDA source for the parallel Efron backward scan (kernel ``efron_backward_scan``).
#
# Launch contract visible in the source below:
#   * Only block (0, 0, 0) and threads with threadIdx.y == threadIdx.z == 0 do work.
#   * ``meta`` = [n, p, nuft, seq_thresh]; ``n`` is read but unused.
#   * ``workspace`` holds 2 + 2*p + 3*p*p doubles laid out as
#     xp0(1) | xp1(p) | xp2(p*p) | hess_acc(p*p) | xp1f(p) | xp2f(p*p) | scratch(1)
#     and is zeroed by the kernel itself.
#   * When ``p <= 64`` and a group is larger than ``seq_thresh``, the kernel uses the
#     dynamic shared-memory buffer (``extern __shared__``).  The largest user is the
#     failure phase: 1 + 64 + 64 + p*p doubles.  The enter and exit phases need
#     1 + 64 + p*p doubles.
#
# The exit phase previously declared ~33 KB of *static* shared arrays
# (``sh_xp2[4096]`` etc.) in addition to the dynamic buffer.  Static and dynamic
# shared memory are summed per block, so that could exceed the default 48 KB
# limit for larger ``p``.  The exit phase now reuses the dynamic buffer exactly
# like the enter phase; all three phases are separated by ``__syncthreads()``
# so the aliasing is safe.
_KERNEL_SOURCE = r"""
/* sm_60+ double atomicAdd. Small batches: thread0 sequential (no atomics). Large: parallel atomics. */
#define EFRON_MAX_P_STACK 128
// seq_thresh is passed via meta[3] for runtime tuning (see python launch code).

extern "C" __global__
void efron_backward_scan(
    const double* __restrict__ X,
    const double* __restrict__ e_eta,
    const int* __restrict__ meta,
    const int* __restrict__ enter_ptr,
    const int* __restrict__ enter_ind,
    const int* __restrict__ exit_ptr,
    const int* __restrict__ exit_ind,
    const int* __restrict__ fail_ptr,
    const int* __restrict__ fail_ind,
    double* __restrict__ grad_out,
    double* __restrict__ hess_out,
    double* __restrict__ workspace
) {
    int n = meta[0];
    int p = meta[1];
    int nuft = meta[2];
    int seq_thresh = meta[3];
    (void)n;
    if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return;
    if (threadIdx.y != 0 || threadIdx.z != 0) return;

    // Workspace layout: xp0 | xp1 | xp2 | hess_acc | xp1f | xp2f | scratch_xp0f
    double* xp0_ptr = workspace;
    double* xp1 = xp0_ptr + 1;
    double* xp2 = xp1 + p;
    double* hess_acc = xp2 + p * p;
    double* xp1f = hess_acc + p * p;
    double* xp2f = xp1f + p;
    double* scratch_xp0f = xp2f + p * p;

    int ws_doubles = 2 + 2 * p + 3 * p * p;
    for (int i = threadIdx.x; i < ws_doubles; i += blockDim.x) {
        workspace[i] = 0.0;
    }
    for (int j = threadIdx.x; j < p; j += blockDim.x) {
        grad_out[j] = 0.0;
    }
    __syncthreads();

    for (int ii = nuft - 1; ii >= 0; ii--) {
        // ---- Enter phase: add rows entering the risk set to xp0/xp1/xp2 ----
        int e0 = enter_ptr[ii];
        int e1 = enter_ptr[ii + 1];
        int nt = e1 - e0;
        if (nt <= seq_thresh) {
            if (threadIdx.x == 0) {
                for (int t = e0; t < e1; t++) {
                    int idx = enter_ind[t];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    *xp0_ptr += elx;
                    if (p <= EFRON_MAX_P_STACK) {
                        double row[EFRON_MAX_P_STACK];
                        for (int j = 0; j < p; j++) row[j] = Xrow[j];
                        for (int j = 0; j < p; j++) xp1[j] += elx * row[j];
                        for (int j = 0; j < p; j++)
                            for (int k = 0; k < p; k++)
                                xp2[j * p + k] += elx * row[j] * row[k];
                    } else {
                        for (int j = 0; j < p; j++) xp1[j] += elx * Xrow[j];
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++)
                                xp2[j * p + k] += elx * vj * Xrow[k];
                        }
                    }
                }
            }
        } else {
            // Stage A: block-local accumulation for small p to reduce global atomics.
            // Stage B: one global writeback per aggregated entry.
            if (p <= 64) {
                extern __shared__ double sh_mem[];
                double* sh_xp0_s = sh_mem;
                double* sh_xp1_s = sh_xp0_s + 1;
                double* sh_xp2_s = sh_xp1_s + 64;
                if (threadIdx.x == 0) *sh_xp0_s = 0.0;
                for (int j = threadIdx.x; j < p; j += blockDim.x) {
                    sh_xp1_s[j] = 0.0;
                }
                for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                    sh_xp2_s[j] = 0.0;
                }
                __syncthreads();
                for (int tt = threadIdx.x; tt < nt; tt += blockDim.x) {
                    int idx = enter_ind[e0 + tt];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    atomicAdd(sh_xp0_s, elx);
                    for (int j = 0; j < p; j++) {
                        double vj = Xrow[j];
                        atomicAdd(sh_xp1_s + j, elx * vj);
                    }
                    for (int j = 0; j < p; j++) {
                        double vj = Xrow[j];
                        for (int k = 0; k < p; k++) {
                            atomicAdd(sh_xp2_s + j * p + k, elx * vj * Xrow[k]);
                        }
                    }
                }
                __syncthreads();
                if (threadIdx.x == 0) atomicAdd(xp0_ptr, *sh_xp0_s);
                for (int j = threadIdx.x; j < p; j += blockDim.x) {
                    atomicAdd(xp1 + j, sh_xp1_s[j]);
                }
                for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                    atomicAdd(xp2 + j, sh_xp2_s[j]);
                }
            } else {
                for (int tt = threadIdx.x; tt < nt; tt += blockDim.x) {
                    int idx = enter_ind[e0 + tt];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    atomicAdd(xp0_ptr, elx);
                    if (p <= EFRON_MAX_P_STACK) {
                        double row[EFRON_MAX_P_STACK];
                        for (int j = 0; j < p; j++) row[j] = Xrow[j];
                        for (int j = 0; j < p; j++) atomicAdd(xp1 + j, elx * row[j]);
                        for (int j = 0; j < p; j++)
                            for (int k = 0; k < p; k++)
                                atomicAdd(xp2 + j * p + k, elx * row[j] * row[k]);
                    } else {
                        for (int j = 0; j < p; j++) atomicAdd(xp1 + j, elx * Xrow[j]);
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++)
                                atomicAdd(xp2 + j * p + k, elx * vj * Xrow[k]);
                        }
                    }
                }
            }
        }
        __syncthreads();

        // ---- Failure phase: accumulate xp0f/xp1f/xp2f, then Efron terms on thread 0 ----
        int f0 = fail_ptr[ii];
        int f1 = fail_ptr[ii + 1];
        int m = f1 - f0;
        if (m > 0) {
            for (int j = threadIdx.x; j < p; j += blockDim.x) {
                xp1f[j] = 0.0;
            }
            for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                xp2f[j] = 0.0;
            }
            __syncthreads();

            if (m <= seq_thresh) {
                if (threadIdx.x == 0) {
                    double xp0v = *xp0_ptr;
                    double xp0f = 0.0;
                    for (int t = f0; t < f1; t++) {
                        int idx = fail_ind[t];
                        const double* Xrow = X + (size_t)idx * (size_t)p;
                        double elx = e_eta[idx];
                        xp0f += elx;
                        if (p <= EFRON_MAX_P_STACK) {
                            double row[EFRON_MAX_P_STACK];
                            for (int j = 0; j < p; j++) row[j] = Xrow[j];
                            for (int j = 0; j < p; j++) {
                                xp1f[j] += elx * row[j];
                                grad_out[j] += row[j];
                            }
                            for (int j = 0; j < p; j++)
                                for (int k = 0; k < p; k++)
                                    xp2f[j * p + k] += elx * row[j] * row[k];
                        } else {
                            for (int j = 0; j < p; j++) {
                                double vj = Xrow[j];
                                xp1f[j] += elx * vj;
                                grad_out[j] += vj;
                            }
                            for (int j = 0; j < p; j++) {
                                double vj = Xrow[j];
                                for (int k = 0; k < p; k++)
                                    xp2f[j * p + k] += elx * vj * Xrow[k];
                            }
                        }
                    }
                    double sum_inv_c0 = 0.0;
                    double sum_J_c0 = 0.0;
                    double sum_aa = 0.0;
                    double sum_bb = 0.0;
                    double sum_ab = 0.0;
                    for (int kk = 0; kk < m; kk++) {
                        double Jk = (double)kk / (double)m;
                        double c0 = xp0v - Jk * xp0f;
                        if (c0 < 1e-300) c0 = 1e-300;
                        double ak = 1.0 / c0;
                        double bk = Jk * ak;
                        sum_inv_c0 += ak;
                        sum_J_c0 += Jk / c0;
                        sum_aa += ak * ak;
                        sum_bb += bk * bk;
                        sum_ab += ak * bk;
                    }
                    for (int j = 0; j < p; j++) {
                        grad_out[j] -= (xp1[j] * sum_inv_c0 - xp1f[j] * sum_J_c0);
                    }
                    for (int j = 0; j < p * p; j++) {
                        hess_acc[j] += xp2[j] * sum_inv_c0;
                        hess_acc[j] -= xp2f[j] * sum_J_c0;
                    }
                    for (int j1 = 0; j1 < p; j1++) {
                        for (int j2 = j1; j2 < p; j2++) {
                            double o11 = xp1[j1] * xp1[j2];
                            double off = xp1f[j1] * xp1f[j2];
                            double cross = xp1[j1] * xp1f[j2] + xp1f[j1] * xp1[j2];
                            double hsub = sum_aa * o11 + sum_bb * off - sum_ab * cross;
                            hess_acc[j1 * p + j2] -= hsub;
                            if (j2 != j1) hess_acc[j2 * p + j1] -= hsub;
                        }
                    }
                }
            } else {
                if (p <= 64) {
                    extern __shared__ double sh_mem2[];
                    double* sh_xp0f_s = sh_mem2;
                    double* sh_xp1f_s = sh_xp0f_s + 1;
                    double* sh_grad_s = sh_xp1f_s + 64;
                    double* sh_xp2f_s = sh_grad_s + 64;
                    if (threadIdx.x == 0) *sh_xp0f_s = 0.0;
                    for (int j = threadIdx.x; j < p; j += blockDim.x) {
                        sh_xp1f_s[j] = 0.0;
                        sh_grad_s[j] = 0.0;
                    }
                    for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                        sh_xp2f_s[j] = 0.0;
                    }
                    __syncthreads();
                    for (int tt = threadIdx.x; tt < m; tt += blockDim.x) {
                        int idx = fail_ind[f0 + tt];
                        const double* Xrow = X + (size_t)idx * (size_t)p;
                        double elx = e_eta[idx];
                        atomicAdd(sh_xp0f_s, elx);
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            atomicAdd(sh_xp1f_s + j, elx * vj);
                            atomicAdd(sh_grad_s + j, vj);
                        }
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++) {
                                atomicAdd(sh_xp2f_s + j * p + k, elx * vj * Xrow[k]);
                            }
                        }
                    }
                    __syncthreads();
                    for (int j = threadIdx.x; j < p; j += blockDim.x) {
                        xp1f[j] = sh_xp1f_s[j];
                        atomicAdd(grad_out + j, sh_grad_s[j]);
                    }
                    for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                        xp2f[j] = sh_xp2f_s[j];
                    }
                    if (threadIdx.x == 0) {
                        *scratch_xp0f = *sh_xp0f_s;
                    }
                } else {
                    if (threadIdx.x == 0) {
                        *scratch_xp0f = 0.0;
                    }
                    __syncthreads();
                    for (int tt = threadIdx.x; tt < m; tt += blockDim.x) {
                        int idx = fail_ind[f0 + tt];
                        const double* Xrow = X + (size_t)idx * (size_t)p;
                        double elx = e_eta[idx];
                        atomicAdd(scratch_xp0f, elx);
                        if (p <= EFRON_MAX_P_STACK) {
                            double row[EFRON_MAX_P_STACK];
                            for (int j = 0; j < p; j++) row[j] = Xrow[j];
                            for (int j = 0; j < p; j++) {
                                atomicAdd(xp1f + j, elx * row[j]);
                                atomicAdd(grad_out + j, row[j]);
                            }
                            for (int j = 0; j < p; j++)
                                for (int k = 0; k < p; k++)
                                    atomicAdd(xp2f + j * p + k, elx * row[j] * row[k]);
                        } else {
                            for (int j = 0; j < p; j++) {
                                double vj = Xrow[j];
                                atomicAdd(xp1f + j, elx * vj);
                                atomicAdd(grad_out + j, vj);
                            }
                            for (int j = 0; j < p; j++) {
                                double vj = Xrow[j];
                                for (int k = 0; k < p; k++)
                                    atomicAdd(xp2f + j * p + k, elx * vj * Xrow[k]);
                            }
                        }
                    }
                }
                __syncthreads();
                if (threadIdx.x == 0) {
                    double xp0v = *xp0_ptr;
                    double xp0f = *scratch_xp0f;
                    double sum_inv_c0 = 0.0;
                    double sum_J_c0 = 0.0;
                    double sum_aa = 0.0;
                    double sum_bb = 0.0;
                    double sum_ab = 0.0;
                    for (int kk = 0; kk < m; kk++) {
                        double Jk = (double)kk / (double)m;
                        double c0 = xp0v - Jk * xp0f;
                        if (c0 < 1e-300) c0 = 1e-300;
                        double ak = 1.0 / c0;
                        double bk = Jk * ak;
                        sum_inv_c0 += ak;
                        sum_J_c0 += Jk / c0;
                        sum_aa += ak * ak;
                        sum_bb += bk * bk;
                        sum_ab += ak * bk;
                    }
                    for (int j = 0; j < p; j++) {
                        grad_out[j] -= (xp1[j] * sum_inv_c0 - xp1f[j] * sum_J_c0);
                    }
                    for (int j = 0; j < p * p; j++) {
                        hess_acc[j] += xp2[j] * sum_inv_c0;
                        hess_acc[j] -= xp2f[j] * sum_J_c0;
                    }
                    for (int j1 = 0; j1 < p; j1++) {
                        for (int j2 = j1; j2 < p; j2++) {
                            double o11 = xp1[j1] * xp1[j2];
                            double off = xp1f[j1] * xp1f[j2];
                            double cross = xp1[j1] * xp1f[j2] + xp1f[j1] * xp1[j2];
                            double hsub = sum_aa * o11 + sum_bb * off - sum_ab * cross;
                            hess_acc[j1 * p + j2] -= hsub;
                            if (j2 != j1) hess_acc[j2 * p + j1] -= hsub;
                        }
                    }
                }
            }
        }
        __syncthreads();

        // ---- Exit phase: subtract rows leaving the risk set from xp0/xp1/xp2 ----
        int x0 = exit_ptr[ii];
        int x1 = exit_ptr[ii + 1];
        int nx = x1 - x0;
        if (nx <= seq_thresh) {
            if (threadIdx.x == 0) {
                for (int t = x0; t < x1; t++) {
                    int idx = exit_ind[t];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    *xp0_ptr -= elx;
                    if (p <= EFRON_MAX_P_STACK) {
                        double row[EFRON_MAX_P_STACK];
                        for (int j = 0; j < p; j++) row[j] = Xrow[j];
                        for (int j = 0; j < p; j++) xp1[j] -= elx * row[j];
                        for (int j = 0; j < p; j++)
                            for (int k = 0; k < p; k++)
                                xp2[j * p + k] -= elx * row[j] * row[k];
                    } else {
                        for (int j = 0; j < p; j++) xp1[j] -= elx * Xrow[j];
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++)
                                xp2[j * p + k] -= elx * vj * Xrow[k];
                        }
                    }
                }
            }
        } else {
            // Stage A/B block accumulation for exit updates, mirroring enter side.
            // Uses the same dynamic shared buffer and layout as the enter phase
            // (1 + 64 + p*p doubles) instead of separate static arrays, so the
            // kernel's total shared-memory footprint is the dynamic buffer only.
            if (p <= 64) {
                extern __shared__ double sh_mem3[];
                double* sh_xp0_x = sh_mem3;
                double* sh_xp1_x = sh_xp0_x + 1;
                double* sh_xp2_x = sh_xp1_x + 64;
                if (threadIdx.x == 0) *sh_xp0_x = 0.0;
                for (int j = threadIdx.x; j < p; j += blockDim.x) {
                    sh_xp1_x[j] = 0.0;
                }
                for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                    sh_xp2_x[j] = 0.0;
                }
                __syncthreads();
                for (int tt = threadIdx.x; tt < nx; tt += blockDim.x) {
                    int idx = exit_ind[x0 + tt];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    atomicAdd(sh_xp0_x, -elx);
                    for (int j = 0; j < p; j++) {
                        double vj = Xrow[j];
                        atomicAdd(sh_xp1_x + j, -elx * vj);
                    }
                    for (int j = 0; j < p; j++) {
                        double vj = Xrow[j];
                        for (int k = 0; k < p; k++) {
                            atomicAdd(sh_xp2_x + j * p + k, -elx * vj * Xrow[k]);
                        }
                    }
                }
                __syncthreads();
                if (threadIdx.x == 0) atomicAdd(xp0_ptr, *sh_xp0_x);
                for (int j = threadIdx.x; j < p; j += blockDim.x) {
                    atomicAdd(xp1 + j, sh_xp1_x[j]);
                }
                for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
                    atomicAdd(xp2 + j, sh_xp2_x[j]);
                }
            } else {
                for (int tt = threadIdx.x; tt < nx; tt += blockDim.x) {
                    int idx = exit_ind[x0 + tt];
                    const double* Xrow = X + (size_t)idx * (size_t)p;
                    double elx = e_eta[idx];
                    atomicAdd(xp0_ptr, -elx);
                    if (p <= EFRON_MAX_P_STACK) {
                        double row[EFRON_MAX_P_STACK];
                        for (int j = 0; j < p; j++) row[j] = Xrow[j];
                        for (int j = 0; j < p; j++) atomicAdd(xp1 + j, -elx * row[j]);
                        for (int j = 0; j < p; j++)
                            for (int k = 0; k < p; k++)
                                atomicAdd(xp2 + j * p + k, -elx * row[j] * row[k]);
                    } else {
                        for (int j = 0; j < p; j++) atomicAdd(xp1 + j, -elx * Xrow[j]);
                        for (int j = 0; j < p; j++) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; k++)
                                atomicAdd(xp2 + j * p + k, -elx * vj * Xrow[k]);
                        }
                    }
                }
            }
        }
        __syncthreads();
    }

    if (threadIdx.x == 0) {
        for (int j = 0; j < p * p; j++) {
            hess_out[j] = -hess_acc[j];
        }
    }
}
"""
|
|
456
|
+
|
|
457
|
+
_KERNEL_SOURCE_SERIAL = r"""
|
|
458
|
+
#define EFRON_MAX_P_STACK 128
|
|
459
|
+
extern "C" __global__
|
|
460
|
+
void efron_backward_scan_serial(
|
|
461
|
+
const double* __restrict__ X,
|
|
462
|
+
const double* __restrict__ e_eta,
|
|
463
|
+
const int* __restrict__ meta,
|
|
464
|
+
const int* __restrict__ enter_ptr,
|
|
465
|
+
const int* __restrict__ enter_ind,
|
|
466
|
+
const int* __restrict__ exit_ptr,
|
|
467
|
+
const int* __restrict__ exit_ind,
|
|
468
|
+
const int* __restrict__ fail_ptr,
|
|
469
|
+
const int* __restrict__ fail_ind,
|
|
470
|
+
double* __restrict__ grad_out,
|
|
471
|
+
double* __restrict__ hess_out,
|
|
472
|
+
double* __restrict__ workspace
|
|
473
|
+
) {
|
|
474
|
+
if (blockIdx.x != 0 || threadIdx.x != 0) return;
|
|
475
|
+
int p = meta[1];
|
|
476
|
+
int nuft = meta[2];
|
|
477
|
+
|
|
478
|
+
double* xp0_ptr = workspace;
|
|
479
|
+
double* xp1 = xp0_ptr + 1;
|
|
480
|
+
double* xp2 = xp1 + p;
|
|
481
|
+
double* hess_acc = xp2 + p * p;
|
|
482
|
+
double* xp1f = hess_acc + p * p;
|
|
483
|
+
double* xp2f = xp1f + p;
|
|
484
|
+
|
|
485
|
+
*xp0_ptr = 0.0;
|
|
486
|
+
for (int j = 0; j < p; j++) {
|
|
487
|
+
grad_out[j] = 0.0;
|
|
488
|
+
xp1[j] = 0.0;
|
|
489
|
+
xp1f[j] = 0.0;
|
|
490
|
+
}
|
|
491
|
+
for (int j = 0; j < p * p; j++) {
|
|
492
|
+
xp2[j] = 0.0;
|
|
493
|
+
xp2f[j] = 0.0;
|
|
494
|
+
hess_acc[j] = 0.0;
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
for (int ii = nuft - 1; ii >= 0; ii--) {
|
|
498
|
+
int e0 = enter_ptr[ii];
|
|
499
|
+
int e1 = enter_ptr[ii + 1];
|
|
500
|
+
for (int t = e0; t < e1; t++) {
|
|
501
|
+
int idx = enter_ind[t];
|
|
502
|
+
const double* Xrow = X + (size_t)idx * (size_t)p;
|
|
503
|
+
double elx = e_eta[idx];
|
|
504
|
+
*xp0_ptr += elx;
|
|
505
|
+
for (int j = 0; j < p; j++) xp1[j] += elx * Xrow[j];
|
|
506
|
+
for (int j = 0; j < p; j++) {
|
|
507
|
+
double vj = Xrow[j];
|
|
508
|
+
for (int k = 0; k < p; k++) xp2[j * p + k] += elx * vj * Xrow[k];
|
|
509
|
+
}
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
int f0 = fail_ptr[ii];
|
|
513
|
+
int f1 = fail_ptr[ii + 1];
|
|
514
|
+
int m = f1 - f0;
|
|
515
|
+
if (m > 0) {
|
|
516
|
+
for (int j = 0; j < p; j++) xp1f[j] = 0.0;
|
|
517
|
+
for (int j = 0; j < p * p; j++) xp2f[j] = 0.0;
|
|
518
|
+
double xp0v = *xp0_ptr;
|
|
519
|
+
double xp0f = 0.0;
|
|
520
|
+
|
|
521
|
+
for (int t = f0; t < f1; t++) {
|
|
522
|
+
int idx = fail_ind[t];
|
|
523
|
+
const double* Xrow = X + (size_t)idx * (size_t)p;
|
|
524
|
+
double elx = e_eta[idx];
|
|
525
|
+
xp0f += elx;
|
|
526
|
+
for (int j = 0; j < p; j++) {
|
|
527
|
+
xp1f[j] += elx * Xrow[j];
|
|
528
|
+
grad_out[j] += Xrow[j];
|
|
529
|
+
}
|
|
530
|
+
for (int j = 0; j < p; j++) {
|
|
531
|
+
double vj = Xrow[j];
|
|
532
|
+
for (int k = 0; k < p; k++) xp2f[j * p + k] += elx * vj * Xrow[k];
|
|
533
|
+
}
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
double sum_inv_c0 = 0.0;
|
|
537
|
+
double sum_J_c0 = 0.0;
|
|
538
|
+
double sum_aa = 0.0;
|
|
539
|
+
double sum_bb = 0.0;
|
|
540
|
+
double sum_ab = 0.0;
|
|
541
|
+
for (int kk = 0; kk < m; kk++) {
|
|
542
|
+
double Jk = (double)kk / (double)m;
|
|
543
|
+
double c0 = xp0v - Jk * xp0f;
|
|
544
|
+
if (c0 < 1e-300) c0 = 1e-300;
|
|
545
|
+
double ak = 1.0 / c0;
|
|
546
|
+
double bk = Jk * ak;
|
|
547
|
+
sum_inv_c0 += ak;
|
|
548
|
+
sum_J_c0 += Jk / c0;
|
|
549
|
+
sum_aa += ak * ak;
|
|
550
|
+
sum_bb += bk * bk;
|
|
551
|
+
sum_ab += ak * bk;
|
|
552
|
+
}
|
|
553
|
+
|
|
554
|
+
for (int j = 0; j < p; j++) grad_out[j] -= (xp1[j] * sum_inv_c0 - xp1f[j] * sum_J_c0);
|
|
555
|
+
for (int j = 0; j < p * p; j++) {
|
|
556
|
+
hess_acc[j] += xp2[j] * sum_inv_c0;
|
|
557
|
+
hess_acc[j] -= xp2f[j] * sum_J_c0;
|
|
558
|
+
}
|
|
559
|
+
for (int j1 = 0; j1 < p; j1++) {
|
|
560
|
+
for (int j2 = j1; j2 < p; j2++) {
|
|
561
|
+
double o11 = xp1[j1] * xp1[j2];
|
|
562
|
+
double off = xp1f[j1] * xp1f[j2];
|
|
563
|
+
double cross = xp1[j1] * xp1f[j2] + xp1f[j1] * xp1[j2];
|
|
564
|
+
double hsub = sum_aa * o11 + sum_bb * off - sum_ab * cross;
|
|
565
|
+
hess_acc[j1 * p + j2] -= hsub;
|
|
566
|
+
if (j2 != j1) hess_acc[j2 * p + j1] -= hsub;
|
|
567
|
+
}
|
|
568
|
+
}
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
int x0 = exit_ptr[ii];
|
|
572
|
+
int x1 = exit_ptr[ii + 1];
|
|
573
|
+
for (int t = x0; t < x1; t++) {
|
|
574
|
+
int idx = exit_ind[t];
|
|
575
|
+
const double* Xrow = X + (size_t)idx * (size_t)p;
|
|
576
|
+
double elx = e_eta[idx];
|
|
577
|
+
*xp0_ptr -= elx;
|
|
578
|
+
for (int j = 0; j < p; j++) xp1[j] -= elx * Xrow[j];
|
|
579
|
+
for (int j = 0; j < p; j++) {
|
|
580
|
+
double vj = Xrow[j];
|
|
581
|
+
for (int k = 0; k < p; k++) xp2[j * p + k] -= elx * vj * Xrow[k];
|
|
582
|
+
}
|
|
583
|
+
}
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
for (int j = 0; j < p * p; j++) hess_out[j] = -hess_acc[j];
|
|
587
|
+
}
|
|
588
|
+
"""
|
|
589
|
+
|
|
590
|
+
# Workspace: xp0(1) + xp1(p) + xp2(p*p) + hess_acc(p*p) + xp1f(p) + xp2f(p*p) + scratch(1)
# i.e. 2 + 2*p + 3*p*p doubles, matching ``ws_doubles`` in the parallel kernel.

# Thread count returned by ``_pick_backward_launch_params`` when neither the
# env override nor the small-shape rule (p <= 64 and nuft <= 512) applies.
EFRON_BACKWARD_THREADS: int = 128

# Lazily built ``(RawKernel, version)`` tuple for the parallel kernel;
# ``None`` until ``get_efron_backward_kernel`` is first called.
_kernel_cache: Any = None
# Version tag stored alongside the cached parallel kernel; a cached tuple whose
# tag differs from this value is rebuilt.
_KERNEL_VER = 9
# Same caching scheme for the serial kernel (``get_efron_backward_kernel_serial``).
_kernel_cache_serial: Any = None
_KERNEL_VER_SERIAL = 1
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
def _env_int(name: str, default: int) -> int:
    """Return environment variable *name* parsed as an int, or *default*.

    Parameters
    ----------
    name:
        Name of the environment variable to read.
    default:
        Value returned when the variable is unset or is not a valid
        integer literal (for example ``""`` or ``"abc"``).

    Returns
    -------
    int
        The parsed value, or *default* on any parse failure.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # ``os.environ`` values are always ``str``; ``int(str)`` can only fail
        # with ValueError, so nothing broader needs to be swallowed here.
        return default
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def _pick_backward_launch_params(p: int, nuft: int, n: int) -> Tuple[int, int]:
    """Return ``(seq_thresh, threads)`` for the Efron backward kernel launch.

    Environment overrides take precedence over the built-in heuristics:

    - ``STATGPU_EFRON_SEQ_THRESH`` (any value >= 0) fixes the sequential threshold.
    - ``STATGPU_EFRON_BACKWARD_THREADS`` (any value > 0) fixes the thread count.

    Without overrides, shapes with ``p <= 64`` and ``nuft <= 512`` get 64
    threads and everything else gets ``EFRON_BACKWARD_THREADS``.  The
    sequential threshold depends on ``p`` and, for ``24 < p <= 64``, grows
    with the average group size ``n / nuft``.
    """
    seq_override = _env_int("STATGPU_EFRON_SEQ_THRESH", -1)
    threads_override = _env_int("STATGPU_EFRON_BACKWARD_THREADS", -1)

    # --- threads per block ---
    if threads_override > 0:
        threads = int(threads_override)
    elif p <= 64 and nuft <= 512:
        # Small/medium shapes are usually latency-sensitive; fewer threads can help.
        threads = 64
    else:
        threads = EFRON_BACKWARD_THREADS

    # --- sequential threshold ---
    if seq_override >= 0:
        return int(seq_override), threads

    # Favor single-thread local accumulation for moderate group sizes.
    avg_group = float(n) / max(1.0, float(nuft))
    if p > 64:
        seq_thresh = 16
    elif p <= 24:
        seq_thresh = 64
    elif avg_group >= 64.0:
        # Heavy ties (large n/nuft) are often dominated by atomic contention;
        # larger sequential groups reduce p^2 atomic pressure.
        seq_thresh = 256
    elif avg_group >= 32.0:
        seq_thresh = 128
    else:
        seq_thresh = 64

    return seq_thresh, threads
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
def _pack_csr(groups: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
|
|
655
|
+
ptr = [0]
|
|
656
|
+
ind: List[int] = []
|
|
657
|
+
for g in groups:
|
|
658
|
+
ind.extend(int(x) for x in g)
|
|
659
|
+
ptr.append(len(ind))
|
|
660
|
+
return np.asarray(ptr, dtype=np.int32), np.asarray(ind, dtype=np.int32)
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def efron_indices_to_csr(
    uft_ix: List[List[int]], risk_enter: List[List[int]], risk_exit: List[List[int]], nuft: int
) -> Tuple[np.ndarray, ...]:
    """Pack the three per-failure-time index lists into CSR arrays.

    Parameters
    ----------
    uft_ix, risk_enter, risk_exit : list of list of int
        One index group per unique failure time; each list must contain
        exactly ``nuft`` groups.
    nuft : int
        Number of unique failure times.

    Returns
    -------
    tuple of np.ndarray
        ``(enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind)``,
        all int32 as produced by ``_pack_csr``.

    Raises
    ------
    ValueError
        If any of the three lists does not contain ``nuft`` groups.
    """
    enter_ptr, enter_ind = _pack_csr(risk_enter)
    exit_ptr, exit_ind = _pack_csr(risk_exit)
    fail_ptr, fail_ind = _pack_csr(uft_ix)

    # Explicit check instead of ``assert`` so the validation still runs when
    # Python is started with -O (which strips assert statements).
    expected = nuft + 1
    for label, ptr in (
        ("risk_enter", enter_ptr),
        ("risk_exit", exit_ptr),
        ("uft_ix", fail_ptr),
    ):
        if ptr.size != expected:
            raise ValueError(
                f"{label} has {ptr.size - 1} groups but nuft={nuft}"
            )
    return enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind
|
|
671
|
+
|
|
672
|
+
|
|
673
|
+
def get_efron_backward_kernel(cp):
    """Return the cached ``efron_backward_scan`` RawKernel, creating it if needed.

    The module-level cache holds a ``(kernel, version)`` tuple. A new
    ``cp.RawKernel`` is built whenever the cache is empty, is not a tuple, or
    was stored under a version other than ``_KERNEL_VER``.
    """
    global _kernel_cache
    entry = _kernel_cache
    if isinstance(entry, tuple) and entry[1] == _KERNEL_VER:
        return entry[0]
    fresh = cp.RawKernel(_KERNEL_SOURCE, "efron_backward_scan")
    _kernel_cache = (fresh, _KERNEL_VER)
    return fresh
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
def get_efron_backward_kernel_serial(cp):
    """Return the cached ``efron_backward_scan_serial`` RawKernel, creating it if needed.

    The module-level cache holds a ``(kernel, version)`` tuple. A new
    ``cp.RawKernel`` is built whenever the cache is empty, is not a tuple, or
    was stored under a version other than ``_KERNEL_VER_SERIAL``.
    """
    global _kernel_cache_serial
    entry = _kernel_cache_serial
    if isinstance(entry, tuple) and entry[1] == _KERNEL_VER_SERIAL:
        return entry[0]
    fresh = cp.RawKernel(_KERNEL_SOURCE_SERIAL, "efron_backward_scan_serial")
    _kernel_cache_serial = (fresh, _KERNEL_VER_SERIAL)
    return fresh
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
# Block size used when launching ``efron_loglik_by_group``. It must stay equal
# to the EFRON_LOGLIK_THREADS macro in the kernel source below: the kernel's
# shared-memory arrays are sized by that macro and indexed by threadIdx.x.
# The in-kernel reductions halve the stride starting from blockDim.x / 2, so
# the block size also needs to be a power of two for every slot to be summed.
_LOGLIK_THREADS: int = 128

# CUDA source for the per-group Efron log-likelihood kernel.
#
# One block handles one failure group g = blockIdx.x (blocks with
# g >= meta[0] exit immediately). Within the block:
#   1. threads stride over the group's entries fail_ind[fail_ptr[g]:fail_ptr[g+1]]
#      and accumulate sum(exp_eta) and sum(eta), reduced through shared memory;
#   2. threads stride over k = 0..m-1 and accumulate
#      log(max(risk_sum[first_idx_uft[g]] - (k/m) * sum_events, 1e-300)),
#      reduced the same way;
#   3. thread 0 writes out_ll[g] = sum_eta - sum_logs.
_LOGLIK_KERNEL_SOURCE = r"""
#define EFRON_LOGLIK_THREADS 128
extern "C" __global__
void efron_loglik_by_group(
    const double* __restrict__ eta,
    const double* __restrict__ exp_eta,
    const double* __restrict__ risk_sum,
    const int* __restrict__ meta, // meta[0] = nuft
    const int* __restrict__ fail_ptr,
    const int* __restrict__ fail_ind,
    const int* __restrict__ first_idx_uft,
    double* __restrict__ out_ll
) {
    int tid = (int)threadIdx.x;
    int g = (int)blockIdx.x;
    int nuft = meta[0];
    if (g >= nuft) return;

    int start = fail_ptr[g];
    int end = fail_ptr[g + 1];
    int m = end - start;

    __shared__ double sh_events[EFRON_LOGLIK_THREADS];
    __shared__ double sh_eta[EFRON_LOGLIK_THREADS];
    __shared__ double sh_logs[EFRON_LOGLIK_THREADS];

    double local_events = 0.0;
    double local_eta = 0.0;
    for (int i = start + tid; i < end; i += (int)blockDim.x) {
        int idx = fail_ind[i];
        double ex = exp_eta[idx];
        local_events += ex;
        local_eta += eta[idx];
    }

    sh_events[tid] = local_events;
    sh_eta[tid] = local_eta;
    __syncthreads();

    // Reduce sum_events and sum_eta.
    for (int stride = (int)blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            sh_events[tid] += sh_events[tid + stride];
            sh_eta[tid] += sh_eta[tid + stride];
        }
        __syncthreads();
    }

    double sum_events = sh_events[0];
    double sum_eta = sh_eta[0];
    double risk_at_t = risk_sum[first_idx_uft[g]];

    double local_logs = 0.0;
    for (int k = tid; k < m; k += (int)blockDim.x) {
        double Jk = (double)k / (double)m;
        double denom = risk_at_t - Jk * sum_events;
        if (denom < 1e-300) denom = 1e-300;
        local_logs += log(denom);
    }

    sh_logs[tid] = local_logs;
    __syncthreads();

    // Reduce sum_logs.
    for (int stride = (int)blockDim.x / 2; stride > 0; stride >>= 1) {
        if (tid < stride) {
            sh_logs[tid] += sh_logs[tid + stride];
        }
        __syncthreads();
    }

    if (tid == 0) {
        // ll = sum(eta[idx]) - sum_{k=0..m-1} log(risk_at_t - k/m * sum_events)
        out_ll[g] = sum_eta - sh_logs[0];
    }
}
"""

# Cache slot for the loglik kernel: None until first use, then a
# ``(RawKernel, version)`` tuple written by ``get_efron_loglik_kernel``.
_kernel_cache_loglik: Any = None
# Version tag stored alongside the cached kernel; a cached entry whose tag
# differs from this value is rebuilt by ``get_efron_loglik_kernel``.
_KERNEL_VER_LOGLIK = 1
|
|
780
|
+
|
|
781
|
+
|
|
782
|
+
def get_efron_loglik_kernel(cp):
    """Return the cached ``efron_loglik_by_group`` RawKernel, creating it if needed.

    The module-level cache holds a ``(kernel, version)`` tuple. A new
    ``cp.RawKernel`` is built whenever the cache is empty, is not a tuple, or
    was stored under a version other than ``_KERNEL_VER_LOGLIK``.
    """
    global _kernel_cache_loglik
    entry = _kernel_cache_loglik
    if isinstance(entry, tuple) and entry[1] == _KERNEL_VER_LOGLIK:
        return entry[0]
    fresh = cp.RawKernel(_LOGLIK_KERNEL_SOURCE, "efron_loglik_by_group")
    _kernel_cache_loglik = (fresh, _KERNEL_VER_LOGLIK)
    return fresh
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
def compute_efron_loglik_raw_csr(
    eta,
    exp_eta,
    risk_sum,
    fail_ptr,
    fail_ind,
    first_idx_uft,
    nuft: int,
    *,
    cupy_module,
) -> Any:
    """
    Evaluate the Efron log partial likelihood with one RawKernel launch.

    ``fail_ptr``/``fail_ind`` are the CSR arrays of the failure groups and
    ``first_idx_uft`` gives, per group, the index into ``risk_sum``; all three
    are converted to int32 device arrays. The kernel writes one value per
    group (one block per group, ``_LOGLIK_THREADS`` threads each) and the
    result is the device-side sum of those values. With ``nuft == 0`` a 0-d
    float64 zero is returned without launching anything.

    Any exception raised by the launch propagates unchanged so the caller can
    fall back to the Python loop.
    """
    cp = cupy_module
    if nuft == 0:
        return cp.array(0.0, dtype=cp.float64)

    def _c_contiguous(arr):
        # RawKernel works on raw storage, so only copy when the array has no
        # usable flags or is not already C-contiguous.
        if not getattr(arr, "flags", None) or not arr.flags.c_contiguous:
            return cp.ascontiguousarray(arr)
        return arr

    eta = _c_contiguous(eta)
    exp_eta = _c_contiguous(exp_eta)
    risk_sum = _c_contiguous(risk_sum)

    n_groups = int(nuft)
    kernel_args = (
        eta,
        exp_eta,
        risk_sum,
        cp.array([n_groups], dtype=cp.int32),  # meta[0] = nuft
        cp.asarray(fail_ptr, dtype=cp.int32),
        cp.asarray(fail_ind, dtype=cp.int32),
        cp.asarray(first_idx_uft, dtype=cp.int32),
        cp.zeros(n_groups, dtype=cp.float64),  # per-group output
    )
    launch = get_efron_loglik_kernel(cp)
    launch((n_groups,), (_LOGLIK_THREADS,), kernel_args)
    return cp.sum(kernel_args[-1])
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
def compute_efron_loglik_raw(eta, exp_eta, risk_sum, time, efron_pre, *, cupy_module):
    """
    Efron partial log-likelihood as a scalar, accumulated group by group.

    ``efron_pre`` is either a 5-tuple ``(uft_arr, uft_ix, _, _, nuft)`` or a
    6-tuple that additionally ends with ``first_idx_uft``. When
    ``first_idx_uft`` is present it supplies the ``risk_sum`` index of every
    group; otherwise ``time`` is copied to the host and the index is found
    with ``np.searchsorted(..., side="left")`` on each group's failure time.

    For a group with ``d`` events the contribution is
    ``sum(eta[events]) - sum_k log(max(risk - (k/d) * sum(exp_eta[events]), 1e-300))``
    for ``k = 0..d-1``; the inner sum over ``k`` is vectorized.

    With ``cupy_module=None`` the work is delegated to
    ``_compute_efron_loglik_raw_numpy`` (CPU path).
    """
    cp = cupy_module
    if cp is None:
        # No array module supplied: run the NumPy implementation instead.
        return _compute_efron_loglik_raw_numpy(eta, exp_eta, risk_sum, time, efron_pre)

    first_idx_uft = None
    if len(efron_pre) == 6:
        uft_arr, uft_ix, _, _, nuft, first_idx_uft = efron_pre
    else:
        uft_arr, uft_ix, _, _, nuft = efron_pre
    if nuft == 0:
        return cp.array(0.0, dtype=cp.float64)

    have_first_idx = first_idx_uft is not None
    first_idx_dev = cp.asarray(first_idx_uft, dtype=cp.int32) if have_first_idx else None
    # Host copy of the times is only needed for the searchsorted fallback.
    host_time = None if have_first_idx else cp.asnumpy(time).astype(np.float64, copy=False)

    loglik = cp.zeros((), dtype=cp.float64)
    for g in range(nuft):
        events = uft_ix[g]
        n_tied = len(events)
        if n_tied == 0:
            continue

        if have_first_idx:
            risk_pos = first_idx_dev[g]
        else:
            risk_pos = int(np.searchsorted(host_time, float(uft_arr[g]), side="left"))

        event_idx = cp.asarray(events, dtype=cp.int32)
        tied_hazard = cp.sum(exp_eta[event_idx])
        fractions = cp.arange(n_tied, dtype=cp.float64) / float(n_tied)
        denominators = risk_sum[risk_pos] - fractions * tied_hazard
        # Clamp before the log so a non-positive denominator cannot yield nan/-inf.
        loglik -= cp.sum(cp.log(cp.maximum(denominators, 1e-300)))
        loglik += cp.sum(eta[event_idx])
    return loglik
|
|
898
|
+
|
|
899
|
+
|
|
900
|
+
def _compute_efron_loglik_raw_numpy(eta, exp_eta, risk_sum, time, efron_pre):
|
|
901
|
+
"""
|
|
902
|
+
NumPy implementation of Efron log-likelihood for Torch backend fallback.
|
|
903
|
+
|
|
904
|
+
Parameters
|
|
905
|
+
----------
|
|
906
|
+
eta : ndarray
|
|
907
|
+
Linear predictor values (n_samples,)
|
|
908
|
+
exp_eta : ndarray
|
|
909
|
+
exp(eta) values (n_samples,)
|
|
910
|
+
risk_sum : ndarray
|
|
911
|
+
Cumulative risk sums (n_samples,)
|
|
912
|
+
time : ndarray
|
|
913
|
+
Event times (n_samples,)
|
|
914
|
+
efron_pre : tuple
|
|
915
|
+
Precomputed failure time indices from _efron_unique_failure_indices
|
|
916
|
+
|
|
917
|
+
Returns
|
|
918
|
+
-------
|
|
919
|
+
float
|
|
920
|
+
Log-likelihood value
|
|
921
|
+
"""
|
|
922
|
+
if len(efron_pre) == 6:
|
|
923
|
+
uft_arr, uft_ix, _, _, nuft, first_idx_uft = efron_pre
|
|
924
|
+
else:
|
|
925
|
+
uft_arr, uft_ix, _, _, nuft = efron_pre
|
|
926
|
+
first_idx_uft = None
|
|
927
|
+
|
|
928
|
+
if nuft == 0:
|
|
929
|
+
return 0.0
|
|
930
|
+
|
|
931
|
+
ll = 0.0
|
|
932
|
+
time_np = None
|
|
933
|
+
if first_idx_uft is None:
|
|
934
|
+
time_np = time.astype(np.float64, copy=False)
|
|
935
|
+
|
|
936
|
+
for i in range(nuft):
|
|
937
|
+
ix_ev = uft_ix[i]
|
|
938
|
+
d = len(ix_ev)
|
|
939
|
+
if d == 0:
|
|
940
|
+
continue
|
|
941
|
+
|
|
942
|
+
if first_idx_uft is not None:
|
|
943
|
+
first_idx = first_idx_uft[i]
|
|
944
|
+
else:
|
|
945
|
+
first_idx = int(np.searchsorted(time_np, float(uft_arr[i]), side="left"))
|
|
946
|
+
|
|
947
|
+
risk_at_t = risk_sum[first_idx]
|
|
948
|
+
sum_events = np.sum(exp_eta[ix_ev])
|
|
949
|
+
kd = float(d)
|
|
950
|
+
|
|
951
|
+
# Efron correction: loop over k=0..d-1
|
|
952
|
+
k = np.arange(d, dtype=np.float64)
|
|
953
|
+
denom = risk_at_t - (k / kd) * sum_events
|
|
954
|
+
ll -= np.sum(np.log(np.maximum(denom, 1e-300)))
|
|
955
|
+
ll += np.sum(eta[ix_ev])
|
|
956
|
+
|
|
957
|
+
return ll
|
|
958
|
+
|
|
959
|
+
|
|
960
|
+
def compute_efron_grad_hess_raw(
    X,
    beta,
    efron_pre,
    *,
    cupy_module,
    efron_csr=None,
) -> Tuple[Any, Any]:
    """
    Compute the Efron gradient and Hessian with one single-block kernel launch.

    Parameters
    ----------
    X : cupy array, shape (n, p)
        Design matrix; made C-contiguous before it is handed to the kernel.
    beta : cupy array
        Coefficients; the linear predictor is ``X @ beta``, shifted by its
        maximum before exponentiation so ``exp`` cannot overflow.
    efron_pre : tuple
        ``(_, uft_ix, risk_enter, risk_exit, nuft)`` with an optional sixth
        element. Only read when ``efron_csr`` is None, in which case the CSR
        arrays are built with ``efron_indices_to_csr``.
    cupy_module : module
        Array module used for all allocations and the kernel launch.
    efron_csr : tuple, optional
        Pre-packed ``(enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr,
        fail_ind, first_idx_uft, nuft)``; ``first_idx_uft`` is ignored here.

    Returns
    -------
    (grad, hess) as float64 arrays of shape ``(p,)`` and ``(p, p)``, or
    ``None`` if the launch or the following synchronize raises, so the caller
    can switch to the Python path. With ``nuft == 0`` both arrays are zeros.
    """
    cp = cupy_module
    if efron_csr is not None:
        # (enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind, first_idx_uft, nuft)
        enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind, _, nuft = efron_csr
    else:
        if len(efron_pre) == 6:
            _, uft_ix, risk_enter, risk_exit, nuft, _ = efron_pre
        else:
            _, uft_ix, risk_enter, risk_exit, nuft = efron_pre
        enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind = efron_indices_to_csr(
            uft_ix, risk_enter, risk_exit, nuft
        )

    if nuft == 0:
        p = int(X.shape[1])
        return cp.zeros(p, dtype=cp.float64), cp.zeros((p, p), dtype=cp.float64)

    def _dev_int32(arr):
        # The non-CSR path produces int32 arrays via _pack_csr, so the kernel
        # consumes int32 indices; coerce caller-supplied efron_csr arrays to
        # the same dtype and to contiguous storage (no copy if already so).
        return cp.ascontiguousarray(cp.asarray(arr, dtype=cp.int32))

    # RawKernel reads raw device memory and ignores strides, so a
    # non-contiguous X (e.g. a transposed view) would be misread silently.
    # NOTE(review): the Efron kernel source is not visible here; it presumably
    # takes double* like the other kernels in this file — confirm X is float64.
    X = cp.ascontiguousarray(X)
    n, p = int(X.shape[0]), int(X.shape[1])
    linpred = X @ beta
    linpred = linpred - cp.max(linpred)
    e_eta = cp.ascontiguousarray(cp.exp(linpred))

    enter_ptr_g = _dev_int32(enter_ptr)
    enter_ind_g = _dev_int32(enter_ind)
    exit_ptr_g = _dev_int32(exit_ptr)
    exit_ind_g = _dev_int32(exit_ind)
    fail_ptr_g = _dev_int32(fail_ptr)
    fail_ind_g = _dev_int32(fail_ind)

    grad_out = cp.zeros(p, dtype=cp.float64)
    hess_out = cp.zeros((p, p), dtype=cp.float64)
    # Scratch buffer for the kernel: 2 + 2p + 3p^2 doubles.
    ws = 2 + 2 * p + 3 * p * p
    workspace = cp.zeros(ws, dtype=cp.float64)

    seq_thresh, threads = _pick_backward_launch_params(p, int(nuft), int(n))
    # The serial kernel is the default for small problems; setting
    # STATGPU_EFRON_SERIAL_KERNEL to anything other than 1 disables it.
    use_serial = (p <= 24 and int(nuft) <= 512 and _env_int("STATGPU_EFRON_SERIAL_KERNEL", 1) == 1)
    meta = cp.array([n, p, nuft, seq_thresh], dtype=cp.int32)
    kernel = get_efron_backward_kernel_serial(cp) if use_serial else get_efron_backward_kernel(cp)
    # Dynamic shared memory (bytes) requested for the parallel kernel when p <= 64.
    shared_mem = 33800 if (not use_serial and p <= 64) else 0
    try:
        kernel(
            (1,),
            ((1,) if use_serial else (threads,)),
            (
                X,
                e_eta,
                meta,
                enter_ptr_g,
                enter_ind_g,
                exit_ptr_g,
                exit_ind_g,
                fail_ptr_g,
                fail_ind_g,
                grad_out,
                hess_out.reshape(-1),
                workspace,
            ),
            shared_mem=shared_mem,
        )
        # Surface asynchronous kernel execution errors at this call site so fallback
        # behavior is reliable and diagnostics point to the correct launch.
        cp.cuda.Stream.null.synchronize()
    except Exception:
        # Deliberate best-effort: any failure means "use the Python path".
        return None

    return grad_out, hess_out
|
|
1037
|
+
|
|
1038
|
+
|
|
1039
|
+
# Version tags and cache slots for the two Breslow kernels. Each cache is None
# until first use and then holds a ``(RawKernel, version)`` tuple; the getter
# functions rebuild the entry when the stored version differs from the tag.
_BRESLOW_KERNEL_VER = 1
_breslow_kernel_cache = None
_BRESLOW_UPDATE_KERNEL_VER = 1
_breslow_update_kernel_cache = None

# CUDA source for the fused Breslow Hessian kernel.
#
# Only block (0, 0, 0) and threads with threadIdx.y == threadIdx.z == 0 do
# any work. The kernel first zeroes its workspace, which is laid out as
#   [xp0 (1) | xp1 (p) | xp2 (p*p) | hess_acc (p*p)]  ->  1 + p + 2*p*p doubles.
# Groups are then visited from the last to the first. For group ii the rows
# first_idx[ii] .. (first_idx[ii+1] or n) - 1 are added to the running sums
#   xp0 += e_eta[r],  xp1 += e_eta[r] * X[r],  xp2 += e_eta[r] * X[r] X[r]^T,
# by thread 0 alone when the group has at most seq_thresh rows, otherwise by
# all threads through atomicAdd. Thread 0 then adds
#   counts[ii] * (xp2 / xp0 - xp1 xp1^T / xp0^2)
# to hess_acc (skipped when xp0 <= 1e-300 or counts[ii] == 0). Finally
# hess_out receives -hess_acc. X is indexed row-major with row stride p.
_BRESLOW_KERNEL_SOURCE = r"""
extern "C" __global__
void breslow_backward_hess_scan(
    const double* __restrict__ X,
    const double* __restrict__ e_eta,
    const int* __restrict__ first_idx,
    const double* __restrict__ counts,
    const int* __restrict__ meta, // [n, p, nuft, seq_thresh]
    double* __restrict__ hess_out,
    double* __restrict__ workspace
) {
    int n = meta[0];
    int p = meta[1];
    int nuft = meta[2];
    int seq_thresh = meta[3];
    if (blockIdx.x != 0 || blockIdx.y != 0 || blockIdx.z != 0) return;
    if (threadIdx.y != 0 || threadIdx.z != 0) return;

    double* xp0_ptr = workspace; // 1
    double* xp1 = xp0_ptr + 1; // p
    double* xp2 = xp1 + p; // p*p
    double* hess_acc = xp2 + p * p; // p*p
    int ws_doubles = 1 + p + 2 * p * p;

    for (int i = threadIdx.x; i < ws_doubles; i += blockDim.x) {
        workspace[i] = 0.0;
    }
    __syncthreads();

    for (int ii = nuft - 1; ii >= 0; --ii) {
        int start = first_idx[ii];
        int end = (ii == nuft - 1) ? n : first_idx[ii + 1];
        int nt = end - start;

        if (nt > 0) {
            if (nt <= seq_thresh) {
                if (threadIdx.x == 0) {
                    for (int r = start; r < end; ++r) {
                        const double* Xrow = X + (size_t)r * (size_t)p;
                        double elx = e_eta[r];
                        *xp0_ptr += elx;
                        for (int j = 0; j < p; ++j) {
                            double vj = Xrow[j];
                            xp1[j] += elx * vj;
                        }
                        for (int j = 0; j < p; ++j) {
                            double vj = Xrow[j];
                            for (int k = 0; k < p; ++k) {
                                xp2[j * p + k] += elx * vj * Xrow[k];
                            }
                        }
                    }
                }
            } else {
                for (int rr = threadIdx.x; rr < nt; rr += blockDim.x) {
                    int r = start + rr;
                    const double* Xrow = X + (size_t)r * (size_t)p;
                    double elx = e_eta[r];
                    atomicAdd(xp0_ptr, elx);
                    for (int j = 0; j < p; ++j) {
                        atomicAdd(xp1 + j, elx * Xrow[j]);
                    }
                    for (int j = 0; j < p; ++j) {
                        double vj = Xrow[j];
                        for (int k = 0; k < p; ++k) {
                            atomicAdd(xp2 + j * p + k, elx * vj * Xrow[k]);
                        }
                    }
                }
            }
            __syncthreads();
        }

        if (threadIdx.x == 0) {
            double rs = *xp0_ptr;
            double w = counts[ii];
            if (rs > 1e-300 && w != 0.0) {
                double inv = 1.0 / rs;
                double inv2 = inv * inv;
                for (int j1 = 0; j1 < p; ++j1) {
                    double x1 = xp1[j1];
                    for (int j2 = 0; j2 < p; ++j2) {
                        double exx = xp2[j1 * p + j2] * inv;
                        double ex1ex2 = x1 * xp1[j2] * inv2;
                        hess_acc[j1 * p + j2] += w * (exx - ex1ex2);
                    }
                }
            }
        }
        __syncthreads();
    }

    for (int j = threadIdx.x; j < p * p; j += blockDim.x) {
        hess_out[j] = -hess_acc[j];
    }
}
"""
|
|
1141
|
+
|
|
1142
|
+
|
|
1143
|
+
def get_breslow_hess_kernel(cp):
    """Return the cached ``breslow_backward_hess_scan`` RawKernel, creating it if needed.

    The module-level cache holds a ``(kernel, version)`` tuple. A new
    ``cp.RawKernel`` is built whenever the cache is empty, is not a tuple, or
    was stored under a version other than ``_BRESLOW_KERNEL_VER``.
    """
    global _breslow_kernel_cache
    entry = _breslow_kernel_cache
    if isinstance(entry, tuple) and entry[1] == _BRESLOW_KERNEL_VER:
        return entry[0]
    fresh = cp.RawKernel(_BRESLOW_KERNEL_SOURCE, "breslow_backward_hess_scan")
    _breslow_kernel_cache = (fresh, _BRESLOW_KERNEL_VER)
    return fresh
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
_BRESLOW_UPDATE_KERNEL_SOURCE = r"""
|
|
1158
|
+
extern "C" __global__
|
|
1159
|
+
void breslow_hess_update(
|
|
1160
|
+
double* __restrict__ hess,
|
|
1161
|
+
const double* __restrict__ risk_x2,
|
|
1162
|
+
const double* __restrict__ ex,
|
|
1163
|
+
const double rs,
|
|
1164
|
+
const double count,
|
|
1165
|
+
const int p
|
|
1166
|
+
) {
|
|
1167
|
+
int idx = (int)(blockIdx.x * blockDim.x + threadIdx.x);
|
|
1168
|
+
int total = p * p;
|
|
1169
|
+
if (idx >= total) return;
|
|
1170
|
+
int j1 = idx / p;
|
|
1171
|
+
int j2 = idx - j1 * p;
|
|
1172
|
+
double exx = risk_x2[idx] / rs;
|
|
1173
|
+
double outer = ex[j1] * ex[j2];
|
|
1174
|
+
hess[idx] -= count * (exx - outer);
|
|
1175
|
+
}
|
|
1176
|
+
"""
|
|
1177
|
+
|
|
1178
|
+
|
|
1179
|
+
def get_breslow_hess_update_kernel(cp):
    """Return the cached ``breslow_hess_update`` RawKernel, creating it if needed.

    The module-level cache holds a ``(kernel, version)`` tuple. A new
    ``cp.RawKernel`` is built whenever the cache is empty, is not a tuple, or
    was stored under a version other than ``_BRESLOW_UPDATE_KERNEL_VER``.
    """
    global _breslow_update_kernel_cache
    entry = _breslow_update_kernel_cache
    if isinstance(entry, tuple) and entry[1] == _BRESLOW_UPDATE_KERNEL_VER:
        return entry[0]
    fresh = cp.RawKernel(_BRESLOW_UPDATE_KERNEL_SOURCE, "breslow_hess_update")
    _breslow_update_kernel_cache = (fresh, _BRESLOW_UPDATE_KERNEL_VER)
    return fresh
|
|
1191
|
+
|
|
1192
|
+
|
|
1193
|
+
def apply_breslow_hess_update_raw(hess, risk_x2, ex, rs, count, *, cupy_module):
    """Apply one Breslow group update to ``hess`` in place on the GPU.

    Launches ``breslow_hess_update`` with one thread per element of the
    ``p x p`` matrix (``p = ex.shape[0]``), computing
    ``hess[j1, j2] -= count * (risk_x2[j1, j2] / rs - ex[j1] * ex[j2])``.

    Parameters
    ----------
    hess, risk_x2 : cupy arrays with ``p * p`` elements
        Passed to the kernel through ``reshape(-1)``.
        NOTE(review): ``reshape`` returns a copy for non-contiguous input, in
        which case the update would not reach ``hess`` — callers presumably
        pass C-contiguous arrays; confirm.
    ex : cupy array, shape (p,)
    rs, count : scalars convertible with ``float``
    cupy_module : module
        Array module used to build/fetch the kernel.

    Returns
    -------
    None. Nothing is launched when ``p <= 0``.
    """
    cp = cupy_module
    p = int(ex.shape[0])
    if p <= 0:
        return
    threads = 256
    blocks = (p * p + threads - 1) // threads  # ceil(p*p / threads)
    kernel = get_breslow_hess_update_kernel(cp)
    kernel(
        (blocks,),
        (threads,),
        (
            hess.reshape(-1),
            risk_x2.reshape(-1),
            ex,
            # RawKernel does not check argument types against the CUDA
            # signature, and CuPy maps a Python int to a 64-bit integer.
            # Pass NumPy scalars whose widths match `double, double, int`.
            np.float64(float(rs)),
            np.float64(float(count)),
            np.int32(p),
        ),
    )
|
|
1213
|
+
|
|
1214
|
+
|
|
1215
|
+
def compute_breslow_hess_raw(
    X,
    first_idx_uft,
    counts_uft,
    *,
    cupy_module,
    exp_eta=None,
    beta=None,
):
    """
    Fused CuPy RawKernel Hessian for Breslow grouped ties.

    Parameters
    ----------
    X : cupy array, shape (n, p)
        Design matrix; coerced to C-contiguous float64 because the kernel
        reads it as a dense row-major ``double`` buffer.
    first_idx_uft : array of int
        First row index of each failure group; its size defines ``nuft``.
    counts_uft : array
        Per-group weight applied to the group's Hessian contribution.
    cupy_module : module
        Array module used for allocations and the launch.
    exp_eta : cupy array, optional
        Precomputed ``exp`` of the linear predictor. When omitted it is
        computed from ``beta`` as ``exp(X @ beta - max(X @ beta))``.
    beta : cupy array, optional
        Coefficients; required when ``exp_eta`` is not given.

    Returns
    -------
    The ``(p, p)`` float64 Hessian, all zeros when ``nuft == 0``, or ``None``
    when the kernel call raises (launch/compile failure).

    Raises
    ------
    ValueError
        If neither ``exp_eta`` nor ``beta`` is supplied (and ``nuft > 0``).

    Environment overrides: ``STATGPU_BRESLOW_HESS_THREADS`` (clamped to
    32..512) and ``STATGPU_BRESLOW_SEQ_THRESH`` (at least 1). Values that do
    not parse as integers are ignored.
    """
    cp = cupy_module
    nuft = int(first_idx_uft.size)
    p = int(X.shape[1])
    if nuft == 0:
        return cp.zeros((p, p), dtype=cp.float64)

    n = int(X.shape[0])
    if exp_eta is None and beta is None:
        raise ValueError("compute_breslow_hess_raw requires either exp_eta or beta")

    # The kernel indexes X as `X + r * p` over doubles, so strides and dtype
    # must match exactly; both calls are no-ops for conforming inputs.
    X = cp.ascontiguousarray(X, dtype=cp.float64)
    if exp_eta is None:
        linpred = X @ beta
        # Shift by the maximum so exp() cannot overflow.
        linpred = linpred - cp.max(linpred)
        e_eta = cp.exp(linpred)
    else:
        e_eta = exp_eta
    e_eta = cp.ascontiguousarray(e_eta, dtype=cp.float64)

    first_idx_g = cp.ascontiguousarray(cp.asarray(first_idx_uft, dtype=cp.int32))
    counts_g = cp.ascontiguousarray(cp.asarray(counts_uft, dtype=cp.float64))
    hess_out = cp.zeros((p, p), dtype=cp.float64)
    # Workspace layout expected by the kernel: 1 + p + 2*p*p doubles.
    workspace = cp.zeros(1 + p + 2 * p * p, dtype=cp.float64)
    # Keep stable default launch behavior, allow opt-in manual tuning.
    seq_thresh, threads = _pick_backward_launch_params(p, nuft, n)
    th_env = os.environ.get("STATGPU_BRESLOW_HESS_THREADS", "").strip()
    if th_env:
        try:
            threads = max(32, min(512, int(th_env)))
        except ValueError:
            # Unparseable override: keep the default chosen above.
            pass
    seq_env = os.environ.get("STATGPU_BRESLOW_SEQ_THRESH", "").strip()
    if seq_env:
        try:
            seq_thresh = max(1, int(seq_env))
        except ValueError:
            # Unparseable override: keep the default chosen above.
            pass
    meta = cp.array([n, p, nuft, seq_thresh], dtype=cp.int32)
    kernel = get_breslow_hess_kernel(cp)
    try:
        kernel(
            (1,),
            (threads,),
            (
                X,
                e_eta,
                first_idx_g,
                counts_g,
                meta,
                hess_out.reshape(-1),
                workspace,
            ),
        )
    except Exception:
        # Deliberate best-effort: the caller treats None as "use another path".
        return None
    return hess_out
|