statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,3974 @@
1
+ """
2
+ Cox Proportional Hazards regression with GPU acceleration.
3
+
4
+ Implements Cox PH model using Breslow and Efron approximations for ties with
5
+ Newton-Raphson optimization. Matches R's survival::coxph() API.
6
+ """
7
+
8
+ from typing import Optional, Union, Tuple, Dict, Any, List
9
+ import os
10
+ import numpy as np
11
+ from scipy import stats
12
+
13
+ from statgpu._base import BaseEstimator
14
+ from statgpu._config import Device
15
+
16
+ # Optional Cython import for faster Efron gradient/Hessian computation
17
+ try:
18
+ from ._cox_efron_cy import efron_grad_hess as _efron_grad_hess_cython
19
+ HAS_CYTHON_EFRON = True
20
+ except ImportError:
21
+ HAS_CYTHON_EFRON = False
22
+ _efron_grad_hess_cython = None
23
+
24
+ try:
25
+ from statgpu.survival._cox_efron_triton import _find_p_ce
26
+ HAS_TRITON_EFRON = True
27
+ except ImportError:
28
+ HAS_TRITON_EFRON = False
29
+ _find_p_ce = None
30
+
31
+
32
+ def _unpack_efron_pre6(efron_pre):
33
+ """``(uft, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft)`` — supports legacy 5-tuple in tests only."""
34
+ if len(efron_pre) == 6:
35
+ return efron_pre
36
+ if len(efron_pre) == 5:
37
+ uft, uft_ix, re, rx, nuft = efron_pre
38
+ return uft, uft_ix, re, rx, nuft, None
39
+ raise ValueError(f"invalid efron_pre length {len(efron_pre)}")
40
+
41
+ class CoxPH(BaseEstimator):
42
+ """
43
+ Cox Proportional Hazards regression with GPU acceleration.
44
+
45
+ Parameters
46
+ ----------
47
+ ties : str, default='breslow'
48
+ Method for handling ties: 'breslow' or 'efron'.
49
+ tol : float, default=1e-9
50
+ Convergence tolerance for Newton-Raphson.
51
+ max_iter : int, default=100
52
+ Maximum number of iterations.
53
+ device : str or Device, default='auto'
54
+ Computation device: 'cpu', 'cuda', or 'auto'.
55
+ compute_inference : bool, default=True
56
+ If True, compute standard errors/tests/baseline hazard on CPU after fitting.
57
+ Set to False to reduce CPU-GPU data transfers in CUDA mode.
58
+ compute_cindex : bool, default=True
59
+ If True, compute training-set C-index during fit. Disabling this can
60
+ significantly reduce fit time, especially on CUDA/Torch for moderate n.
61
+
62
+ Attributes
63
+ ----------
64
+ coef_ : ndarray of shape (n_features,)
65
+ Estimated coefficients (log hazard ratios).
66
+ hazard_ratios_ : ndarray of shape (n_features,)
67
+ exp(coef) = hazard ratios.
68
+ """
69
+
70
def __init__(
    self,
    ties: str = 'breslow',
    tol: float = 1e-9,
    max_iter: int = 100,
    device: Union[str, Device] = Device.AUTO,
    n_jobs: Optional[int] = None,
    compute_inference: bool = True,
    compute_cindex: bool = True,
    cov_type: str = "nonrobust",
    gpu_memory_cleanup: bool = False,
    penalty: float = 0.0,
):
    """Store normalized hyper-parameters and clear every fitted attribute.

    ``ties`` and ``cov_type`` are lower-cased, the boolean flags are coerced
    with ``bool`` and ``penalty`` with ``float`` before validation.

    Raises
    ------
    ValueError
        If ``ties`` is not 'breslow'/'efron', ``cov_type`` is not one of
        'nonrobust'/'hc0'/'hc1'/'cluster', or ``penalty`` is negative.
    """
    super().__init__(device=device, n_jobs=n_jobs)

    # --- hyper-parameters (normalized first, validated afterwards) ---
    self.ties = ties.lower()
    self.tol = tol
    self.max_iter = max_iter
    self.compute_inference = compute_inference
    self.compute_cindex = bool(compute_cindex)
    self.cov_type = cov_type.lower()
    self.gpu_memory_cleanup = bool(gpu_memory_cleanup)
    self.penalty = float(penalty)

    allowed_ties = ('breslow', 'efron')
    allowed_cov_types = ("nonrobust", "hc0", "hc1", "cluster")
    if self.ties not in allowed_ties:
        raise ValueError("ties must be 'breslow' or 'efron'")
    if self.cov_type not in allowed_cov_types:
        raise ValueError("cov_type must be one of: 'nonrobust', 'hc0', 'hc1', 'cluster'")
    if self.penalty < 0:
        raise ValueError("penalty must be non-negative")

    # --- public fitted attributes ---
    self.coef_ = self.hazard_ratios_ = None

    # --- data retained for inference ---
    self._time = self._event = self._X = self._entry = None
    self._nobs = self._nevents = None

    # --- per-coefficient inference and likelihood values ---
    self._bse = self._zvalues = self._pvalues = self._conf_int = None
    self._log_likelihood = self._log_likelihood_null = None

    # --- optimizer bookkeeping ---
    self._iterations = 0
    self._converged = False

    # --- covariance, baseline hazard, concordance, global tests ---
    self._var_matrix = self._score_test_stat = None
    self._baseline_hazard = self._baseline_cumulative_hazard = None
    self._unique_times = self._cindex = self._feature_names = None
    self._wald_test_stat = self._wald_test_pvalue = None
    self._lr_test_stat = self._lr_test_pvalue = None
    self._score_test_pvalue = None

    # Efron cache: (uft, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft);
    # it depends only on the sorted time/event arrays.
    self._efron_pre = None
    # Efron shortcut flag: True when every failure group is a singleton
    # (no ties), in which case Efron equals Breslow and the faster
    # vectorized paths can be used.
    self._efron_all_singletons = False
    # Efron cache of CSR-packed indices for GPU kernels:
    # (enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind, first_idx_uft, nuft)
    self._efron_pre_csr = None
    # Breslow cache on CPU: (first_idx_uft, counts_uft).
    self._breslow_pre = None
    # Breslow cache on GPU: (first_idx_uft_gpu, counts_uft_gpu).
    self._breslow_pre_gpu = None
143
+
144
+ def _cleanup_cuda_memory(self):
145
+ """Best-effort CuPy memory pool cleanup."""
146
+ if not self.gpu_memory_cleanup:
147
+ return
148
+ try:
149
+ import cupy as cp
150
+ cp.get_default_memory_pool().free_all_blocks()
151
+ cp.get_default_pinned_memory_pool().free_all_blocks()
152
+ except Exception:
153
+ pass
154
+
155
+ def _cleanup_torch_memory(self):
156
+ """Best-effort Torch CUDA cache cleanup."""
157
+ if not self.gpu_memory_cleanup:
158
+ return
159
+ try:
160
+ import torch
161
+
162
+ torch.cuda.empty_cache()
163
+ torch.cuda.synchronize()
164
+ except Exception:
165
+ pass
166
+
167
+ def __del__(self):
168
+ try:
169
+ self._cleanup_cuda_memory()
170
+ self._cleanup_torch_memory()
171
+ except Exception:
172
+ pass
173
+
174
+ @staticmethod
175
+ def _extract_convergence_status(result):
176
+ """Best-effort convergence extraction from statsmodels results."""
177
+ conv_attr = getattr(result, "converged", None)
178
+ if conv_attr is not None:
179
+ return bool(conv_attr)
180
+
181
+ mle_retvals = getattr(result, "mle_retvals", None)
182
+ if isinstance(mle_retvals, dict):
183
+ conv_attr = mle_retvals.get("converged")
184
+ if conv_attr is not None:
185
+ return bool(conv_attr)
186
+ elif mle_retvals is not None:
187
+ conv_attr = getattr(mle_retvals, "converged", None)
188
+ if conv_attr is not None:
189
+ return bool(conv_attr)
190
+ return None
191
+
192
def fit(self, X=None, time=None, event=None, entry=None, cluster=None, init_coef=None, formula=None, data=None):
    """
    Fit Cox Proportional Hazards model.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix. Required if ``formula`` is None. A 1-D input is
        treated as a single feature column.
    time : array-like of shape (n_samples,)
        Time to event or censoring. Required if ``formula`` is None.
    event : array-like of shape (n_samples,)
        Event indicator (1 = event, 0 = censored). Required if ``formula`` is None.
    entry : array-like of shape (n_samples,), optional
        Entry time for delayed entry (left truncation).
    cluster : array-like of shape (n_samples,), optional
        Cluster ids for cluster-robust covariance when `cov_type='cluster'`.
    init_coef : array-like of shape (n_features,), optional
        Initial coefficient guess for warm-start optimization.
    formula : str or None
        R-style formula with Surv() response, e.g.
        ``"Surv(time, event) ~ x1 + x2 + C(sex)"``. When given, ``X``,
        ``time`` and ``event`` are derived from ``data`` and any values
        passed for them are ignored.
    data : pd.DataFrame or None
        DataFrame used with ``formula`` for column lookup.

    Returns
    -------
    self : CoxPH
        Fitted estimator.

    Raises
    ------
    ValueError
        If ``formula`` is given without ``data``; if the formula response
        does not provide both a time and an event column; if neither a
        formula nor all of ``X``/``time``/``event`` are supplied; or if
        ``entry`` does not have one value per sample.
    """
    # Handle formula interface
    if formula is not None:
        if data is None:
            raise ValueError(
                "formula was provided but data is None. "
                "Pass data=your_dataframe when using formula."
            )
        from statgpu.core.formula import make_surv_env
        import patsy
        from patsy import EvalEnvironment

        # Evaluation environment that makes the custom Surv() function
        # visible to patsy while it evaluates the formula.
        custom_env = EvalEnvironment([make_surv_env()])
        y_patsy, X_patsy = patsy.dmatrices(
            formula, data, eval_env=custom_env, return_type="matrix",
        )
        design_info = X_patsy.design_info
        # Surv(time, event) yields a response with two columns. patsy
        # design matrices are always 2-D, so a plain single-variable
        # response shows up as shape (n, 1) and must be rejected by its
        # column count, not by its number of dimensions.
        y_arr = np.asarray(y_patsy)
        if y_arr.ndim != 2 or y_arr.shape[1] < 2:
            raise ValueError(
                "Formula response must be Surv(time, event), not a single variable. "
                "Use: formula='Surv(time, event) ~ x1 + x2'"
            )
        time = y_arr[:, 0]
        event = y_arr[:, 1]
        X_arr = np.asarray(X_patsy)

        # Drop intercept column from design matrix (CoxPH doesn't use intercept).
        # The column is located by name so names and data stay aligned.
        feature_names = list(design_info.column_names)
        if "Intercept" in feature_names:
            intercept_idx = feature_names.index("Intercept")
            del feature_names[intercept_idx]
            X_arr = np.delete(X_arr, intercept_idx, axis=1)
        self._feature_names = feature_names
        self._design_info = design_info
        X = X_arr
    else:
        if X is None or time is None or event is None:
            raise ValueError(
                "Either formula+data or X+time+event must be provided."
            )
        # Names produced by an earlier formula fit do not describe raw
        # arrays; discard them so defaults are generated below.
        if getattr(self, "_design_info", None) is not None:
            self._feature_names = None
        self._design_info = None
    device = self._get_compute_device()

    if device == Device.CUDA:
        import cupy as cp

        X_gpu = cp.asarray(self._to_array(X), dtype=cp.float64)
        time_gpu = cp.asarray(self._to_array(time), dtype=cp.float64)
        event_gpu = cp.asarray(self._to_array(event), dtype=cp.int32)
        entry_gpu = None if entry is None else cp.asarray(self._to_array(entry), dtype=cp.float64)

        if X_gpu.ndim == 1:
            X_gpu = X_gpu.reshape(-1, 1)
        if entry_gpu is not None and entry_gpu.shape[0] != X_gpu.shape[0]:
            raise ValueError("entry must have shape (n_samples,)")

        self._nobs = int(X_gpu.shape[0])
        self._nevents = int(cp.sum(event_gpu).item())
        n_features = int(X_gpu.shape[1])
        # Regenerate default names when none exist or when names left over
        # from a previous fit no longer match the number of columns.
        if self._feature_names is None or len(self._feature_names) != n_features:
            self._feature_names = [f'x{i+1}' for i in range(n_features)]

        # Keep CPU copies only when CPU-side inference/baseline stats are requested.
        if self.compute_inference:
            self._X = cp.asnumpy(X_gpu)
            self._time = cp.asnumpy(time_gpu)
            self._event = cp.asnumpy(event_gpu)
            self._entry = None if entry_gpu is None else cp.asnumpy(entry_gpu)
        else:
            self._X = None
            self._time = None
            self._event = None
            self._entry = None

        cluster_gpu = None if cluster is None else cp.asarray(self._to_array(cluster), dtype=cp.int64)
        self._fit_gpu(X_gpu, time_gpu, event_gpu, entry_gpu, cluster_gpu, init_coef=init_coef)
    elif device == Device.TORCH:
        import torch

        torch_device = "cuda"

        X_torch = self._to_array(X, Device.TORCH, backend="torch").to(dtype=torch.float64)
        time_torch = self._to_array(time, Device.TORCH, backend="torch").to(dtype=torch.float64)
        event_torch = self._to_array(event, Device.TORCH, backend="torch").to(dtype=torch.int32)
        entry_torch = None if entry is None else self._to_array(
            entry, Device.TORCH, backend="torch"
        ).to(dtype=torch.float64)

        if X_torch.ndim == 1:
            X_torch = X_torch.reshape(-1, 1)
        if entry_torch is not None and entry_torch.shape[0] != X_torch.shape[0]:
            raise ValueError("entry must have shape (n_samples,)")

        self._nobs = int(X_torch.shape[0])
        self._nevents = int(torch.sum(event_torch).item())
        n_features = int(X_torch.shape[1])
        if self._feature_names is None or len(self._feature_names) != n_features:
            self._feature_names = [f'x{i+1}' for i in range(n_features)]

        # Keep CPU copies only when CPU-side inference/baseline stats are requested.
        if self.compute_inference:
            self._X = X_torch.cpu().numpy()
            self._time = time_torch.cpu().numpy()
            self._event = event_torch.cpu().numpy()
            self._entry = None if entry_torch is None else entry_torch.cpu().numpy()
        else:
            self._X = None
            self._time = None
            self._event = None
            self._entry = None

        cluster_torch = None if cluster is None else self._to_array(
            cluster, Device.TORCH, backend="torch"
        ).to(dtype=torch.int64)
        self._fit_torch(
            X_torch,
            time_torch,
            event_torch,
            entry_torch,
            cluster_torch,
            torch_device,
            init_coef=init_coef,
        )
    else:
        X_np = np.asarray(self._to_array(X, Device.CPU), dtype=np.float64)
        time_np = np.asarray(self._to_array(time, Device.CPU), dtype=np.float64)
        event_np = np.asarray(self._to_array(event, Device.CPU), dtype=np.int32)
        entry_np = None if entry is None else np.asarray(self._to_array(entry, Device.CPU), dtype=np.float64)

        if X_np.ndim == 1:
            X_np = X_np.reshape(-1, 1)
        if entry_np is not None and entry_np.shape[0] != X_np.shape[0]:
            raise ValueError("entry must have shape (n_samples,)")

        # Plain Python ints, matching the CUDA and Torch branches.
        self._nobs = int(X_np.shape[0])
        self._nevents = int(np.sum(event_np))

        # Store original data (CPU mode is CPU-only)
        self._time = time_np.copy()
        self._event = event_np.copy()
        self._X = X_np.copy()
        self._entry = None if entry_np is None else entry_np.copy()
        n_features = int(X_np.shape[1])
        if self._feature_names is None or len(self._feature_names) != n_features:
            self._feature_names = [f'x{i+1}' for i in range(n_features)]

        cluster_np = None if cluster is None else np.asarray(self._to_array(cluster, Device.CPU))
        self._fit_cpu(X_np, time_np, event_np, entry_np, cluster_np, init_coef=init_coef)

    self._fitted = True
    return self
370
+
371
+ def _fit_cpu(self, X, time, event, entry=None, cluster=None, init_coef=None):
372
+ """Fit using CPU (NumPy)."""
373
+ if entry is not None:
374
+ self._fit_cpu_with_entry(X, time, event, np.asarray(entry, dtype=np.float64), cluster)
375
+ return
376
+ n_samples, n_features = X.shape
377
+
378
+ # Sort by time ascending so risk-set terms are suffix sums:
379
+ # R(t_i) = {j: t_j >= t_i} -> indices i..n-1 after ascending sort.
380
+ order = np.argsort(time)
381
+ X_sorted = X[order]
382
+ time_sorted = time[order]
383
+ event_sorted = event[order]
384
+ entry_sorted = None if entry is None else np.asarray(entry, dtype=np.float64)[order]
385
+ cluster_sorted = None if cluster is None else np.asarray(cluster)[order]
386
+
387
+ self._efron_pre = None
388
+ self._breslow_pre = None
389
+ self._breslow_pre_gpu = None
390
+ if self.ties == "efron":
391
+ self._efron_pre = self._efron_unique_failure_indices(time_sorted, event_sorted)
392
+ try:
393
+ uft, uft_ix, _, _, nuft, _ = _unpack_efron_pre6(self._efron_pre)
394
+ self._efron_all_singletons = bool(nuft > 0) and all(
395
+ len(ix) == 1 for ix in uft_ix
396
+ )
397
+ except Exception:
398
+ self._efron_all_singletons = False
399
+ else:
400
+ self._efron_all_singletons = False
401
+ self._breslow_pre = self._breslow_unique_failure_groups(
402
+ time_sorted, event_sorted
403
+ )
404
+ if entry_sorted is not None:
405
+ event_idx_np = np.flatnonzero(event_sorted.astype(np.int32) == 1)
406
+ event_times_np = time_sorted[event_idx_np].astype(np.float64, copy=False)
407
+ uft_np, inv_np = np.unique(event_times_np, return_inverse=True)
408
+ self._entry_fail_groups_np = [
409
+ event_idx_np[inv_np == g].astype(np.int64, copy=False)
410
+ for g in range(len(uft_np))
411
+ ]
412
+ self._entry_fail_times_np = uft_np.astype(np.float64, copy=False)
413
+ self._entry_order_np = np.argsort(entry_sorted).astype(np.int64, copy=False)
414
+ self._entry_add_end_np = np.searchsorted(
415
+ entry_sorted, uft_np, side="left"
416
+ ).astype(np.int64, copy=False)
417
+ self._entry_rem_end_np = np.searchsorted(
418
+ time_sorted, uft_np, side="left"
419
+ ).astype(np.int64, copy=False)
420
+ else:
421
+ self._entry_fail_groups_np = None
422
+ self._entry_fail_times_np = None
423
+ self._entry_order_np = None
424
+ self._entry_add_end_np = None
425
+ self._entry_rem_end_np = None
426
+
427
+ # Initialize coefficients (supports warm-start path in CV)
428
+ if init_coef is None:
429
+ beta = np.zeros(n_features, dtype=np.float64)
430
+ else:
431
+ beta = np.asarray(init_coef, dtype=np.float64).reshape(-1)
432
+ if beta.shape[0] != n_features:
433
+ raise ValueError("init_coef must have shape (n_features,)")
434
+
435
+ # Compute null log-likelihood (beta = 0)
436
+ self._log_likelihood_null = self._compute_log_likelihood(
437
+ np.zeros(n_features), X_sorted, time_sorted, event_sorted, self._efron_pre, entry=entry_sorted
438
+ )
439
+
440
+ # Newton-Raphson optimization with L2 penalty
441
+ penalty = float(self.penalty) if hasattr(self, 'penalty') else 0.0
442
+ use_penalty = penalty > 0.0
443
+ # Preferred Newton direction for CPU path; updated adaptively.
444
+ preferred_direction = -1.0
445
+ iteration = -1 # default if max_iter=0
446
+
447
+ for iteration in range(self.max_iter):
448
+ # Compute gradient and Hessian
449
+ grad, hess = self._compute_gradient_hessian(
450
+ beta, X_sorted, time_sorted, event_sorted, self._efron_pre, entry=entry_sorted
451
+ )
452
+
453
+ # Add penalty terms: gradient -= 2*penalty*beta, hessian -= 2*penalty*I
454
+ if use_penalty:
455
+ grad = grad - 2 * penalty * beta
456
+ hess = hess - 2 * penalty * np.eye(n_features, dtype=np.float64)
457
+
458
+ # Solve a Newton-like step on (-hess). In practice, different tie paths
459
+ # may expose Hessian with different sign conventions, so we choose the
460
+ # ascent direction adaptively below using objective evaluation.
461
+ try:
462
+ delta = np.linalg.solve(-hess, grad)
463
+ except np.linalg.LinAlgError:
464
+ # Use pseudo-inverse if singular
465
+ delta = np.linalg.lstsq(-hess, grad, rcond=None)[0]
466
+
467
+ # Line search with step halving
468
+ # Compute log-likelihood at current point
469
+ old_ll = self._compute_log_likelihood(
470
+ beta, X_sorted, time_sorted, event_sorted, self._efron_pre, entry=entry_sorted
471
+ )
472
+ if use_penalty:
473
+ old_ll = old_ll - penalty * np.sum(beta ** 2)
474
+
475
+ # Fast path: try preferred direction first, only test opposite
476
+ # when the preferred full step does not improve.
477
+ direction = preferred_direction
478
+ new_beta = beta + direction * delta
479
+ new_ll = self._compute_log_likelihood(
480
+ new_beta, X_sorted, time_sorted, event_sorted, self._efron_pre, entry=entry_sorted
481
+ )
482
+ if use_penalty:
483
+ new_ll = new_ll - penalty * np.sum(new_beta ** 2)
484
+
485
+ if new_ll <= old_ll - 1e-8:
486
+ # Probe the opposite direction only when needed.
487
+ if entry_sorted is None:
488
+ alt_direction = -direction
489
+ alt_beta = beta + alt_direction * delta
490
+ alt_ll = self._compute_log_likelihood(
491
+ alt_beta, X_sorted, time_sorted, event_sorted, self._efron_pre, entry=entry_sorted
492
+ )
493
+ if use_penalty:
494
+ alt_ll = alt_ll - penalty * np.sum(alt_beta ** 2)
495
+ if alt_ll > new_ll:
496
+ direction = alt_direction
497
+ preferred_direction = alt_direction
498
+ new_beta = alt_beta
499
+ new_ll = alt_ll
500
+
501
+ # Backtracking line search from step=0.5; step=1 was already evaluated.
502
+ if new_ll <= old_ll - 1e-8:
503
+ step = 0.5
504
+ for _ in range(20):
505
+ trial_beta = beta + direction * step * delta
506
+ trial_ll = self._compute_log_likelihood(
507
+ trial_beta, X_sorted, time_sorted, event_sorted, self._efron_pre, entry=entry_sorted
508
+ )
509
+ if use_penalty:
510
+ trial_ll = trial_ll - penalty * np.sum(trial_beta ** 2)
511
+ if trial_ll > old_ll - 1e-8:
512
+ new_beta = trial_beta
513
+ new_ll = trial_ll
514
+ break
515
+ step *= 0.5
516
+ else:
517
+ step = 1.0
518
+ else:
519
+ # Keep successful direction for the next iteration.
520
+ preferred_direction = direction
521
+ step = 1.0
522
+
523
+ # Check convergence
524
+ if np.linalg.norm(delta) * step < self.tol:
525
+ self._converged = True
526
+ beta = new_beta
527
+ break
528
+
529
+ beta = new_beta
530
+
531
+ self._iterations = iteration + 1
532
+ self.coef_ = beta
533
+ self.hazard_ratios_ = np.exp(beta)
534
+
535
+ # Compute final log-likelihood
536
+ self._log_likelihood = self._compute_log_likelihood(
537
+ beta, X_sorted, time_sorted, event_sorted, self._efron_pre, entry=entry_sorted
538
+ )
539
+
540
+ # Compute optional inference statistics
541
+ if self.compute_inference:
542
+ self._compute_inference_cpu(X_sorted, time_sorted, event_sorted, cluster_sorted)
543
+ self._compute_baseline_hazard(X_sorted, time_sorted, event_sorted)
544
+ else:
545
+ self._var_matrix = None
546
+ self._bse = None
547
+ self._zvalues = None
548
+ self._pvalues = None
549
+ self._conf_int = None
550
+ self._score_test_stat = None
551
+ self._score_test_pvalue = None
552
+ self._wald_test_stat = None
553
+ self._wald_test_pvalue = None
554
+ self._lr_test_stat = None
555
+ self._lr_test_pvalue = None
556
+ self._baseline_hazard = None
557
+ self._baseline_cumulative_hazard = None
558
+ self._unique_times = None
559
+
560
+ # Release large temporary GPU tensors early.
561
+ try:
562
+ del X_sorted
563
+ except Exception:
564
+ pass
565
+ try:
566
+ del time_sorted
567
+ except Exception:
568
+ pass
569
+ try:
570
+ del event_sorted
571
+ except Exception:
572
+ pass
573
+ try:
574
+ del grad
575
+ except Exception:
576
+ pass
577
+ try:
578
+ del hess
579
+ except Exception:
580
+ pass
581
+ try:
582
+ del delta
583
+ except Exception:
584
+ pass
585
+ self._cleanup_cuda_memory()
586
+ if self.compute_cindex:
587
+ self._compute_cindex()
588
+ else:
589
+ self._cindex = None
590
+
591
def _fit_cpu_with_entry(self, X, time, event, entry, cluster=None):
    """Fit using statsmodels PHReg when delayed entry is provided.

    Coefficients, covariance, confidence intervals and the baseline
    cumulative hazard are taken from the PHReg result. The score test is
    not available on this path and is stored as NaN.

    Note: L2 penalty is not applied in this path (statsmodels PHReg
    does not support penalized fitting). A warning is emitted when
    penalty is specified. ``cluster`` is accepted for signature parity
    but is not used: the robust covariance override is skipped because
    the internal robust helpers do not account for entry times.
    """
    if float(self.penalty) > 0:
        import warnings
        warnings.warn(
            "CoxPH with entry (delayed entry) does not support penalties via "
            "statsmodels PHReg. The penalty will be ignored. "
            "Use the GPU/torch path for penalized Cox with delayed entry.",
            UserWarning, stacklevel=3,
        )
    import statsmodels.duration.api as smd

    n_samples, n_features = X.shape
    model = smd.PHReg(time, X, status=event, entry=entry, ties=self.ties)
    res = model.fit(disp=0)

    self._iterations = int(getattr(res, "iterations", 0) or 0)
    conv_attr = self._extract_convergence_status(res)
    self._converged = bool(conv_attr) if conv_attr is not None else False
    self.coef_ = np.asarray(res.params, dtype=np.float64)
    self.hazard_ratios_ = np.exp(self.coef_)
    self._log_likelihood = float(res.llf)

    # Null log-likelihood: refit with a single all-zero covariate, which
    # cannot influence the likelihood. Falls back to NaN if that fit fails.
    try:
        null_model = smd.PHReg(time, np.zeros((n_samples, 1), dtype=np.float64), status=event, entry=entry, ties=self.ties)
        null_res = null_model.fit(disp=0)
        self._log_likelihood_null = float(null_res.llf)
    except Exception:
        self._log_likelihood_null = np.nan

    cov = np.asarray(res.cov_params(), dtype=np.float64)
    if cov.shape != (n_features, n_features):
        cov = np.full((n_features, n_features), np.nan, dtype=np.float64)
    self._var_matrix = cov
    self._bse = np.sqrt(np.maximum(np.diag(cov), 0.0))
    # The tiny offset avoids a division by zero for degenerate variances.
    self._zvalues = self.coef_ / (self._bse + 1e-30)
    # Survival functions keep precision in the far tail, where
    # ``1 - cdf`` would round to exactly zero.
    self._pvalues = 2 * stats.norm.sf(np.abs(self._zvalues))
    self._conf_int = np.asarray(res.conf_int(), dtype=np.float64)

    # Delayed-entry robust covariance override is intentionally skipped:
    # current internal robust score/hessian helpers do not account for entry.

    self._lr_test_stat = 2 * (self._log_likelihood - self._log_likelihood_null)
    self._lr_test_pvalue = stats.chi2.sf(self._lr_test_stat, n_features)
    try:
        # beta' V^{-1} beta via a linear solve, without forming V^{-1}.
        self._wald_test_stat = self.coef_ @ np.linalg.solve(self._var_matrix, self.coef_)
    except np.linalg.LinAlgError:
        self._wald_test_stat = np.nan
    self._wald_test_pvalue = stats.chi2.sf(self._wald_test_stat, n_features)
    self._score_test_stat = np.nan
    self._score_test_pvalue = np.nan

    # Baseline hazard from PHReg output.
    try:
        base = res.baseline_cumulative_hazard[0]
        self._unique_times = np.asarray(base[0], dtype=np.float64)
        self._baseline_cumulative_hazard = np.asarray(base[1], dtype=np.float64)
        if self._baseline_cumulative_hazard.size > 0:
            # Increments of the cumulative hazard, starting from zero.
            self._baseline_hazard = np.diff(
                np.concatenate([[0.0], self._baseline_cumulative_hazard])
            )
        else:
            self._baseline_hazard = np.array([], dtype=np.float64)
    except Exception:
        # Baseline quantities are optional; leave them unset on failure.
        self._baseline_hazard = None
        self._baseline_cumulative_hazard = None
        self._unique_times = None

    if self.compute_cindex:
        self._compute_cindex()
    else:
        self._cindex = None
669
+
670
def _fit_gpu(self, X, time, event, entry=None, cluster=None, init_coef=None):
    """Fit using GPU with full GPU computation (CuPy).

    Rows are sorted by ascending time, the tie structure is precomputed
    once, and Newton-Raphson steps (with an optional L2 penalty and, for
    delayed entry, step halving) run until the tolerances are met or
    ``max_iter`` is reached. Results are copied to host attributes at the end.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariates; copied to the device as float64.
    time, event : array-like
        Copied to the device as float64 / int32. Rows with ``event == 1``
        are treated as failures.
    entry : array or None
        Delayed-entry times; indexed with the device sort order.
        NOTE(review): presumably already a CuPy array -- confirm with caller.
    cluster : array or None
        Cluster ids, indexed like ``entry``; required for
        ``cov_type == "cluster"``.
    init_coef : array-like or None
        Warm-start coefficients of length ``n_features``.

    Raises
    ------
    ValueError
        If ``init_coef`` has the wrong length, or ``cov_type == "cluster"``
        is requested without ``cluster``.
    """
    import cupy as cp
    from statgpu.inference._distributions_backend import norm

    # Tuple unpacking doubles as a check that X is two-dimensional.
    _n_samples, n_features = X.shape

    # Transfer to GPU once
    X = cp.asarray(X, dtype=cp.float64)
    time = cp.asarray(time, dtype=cp.float64)
    event = cp.asarray(event, dtype=cp.int32)

    # Sort by time ascending so risk-set terms are suffix sums:
    # R(t_i) = {j: t_j >= t_i} -> indices i..n-1 after ascending sort.
    order = cp.argsort(time, kind="stable")
    X_sorted = X[order]
    time_sorted = time[order]
    event_sorted = event[order]
    entry_sorted = None if entry is None else entry[order]
    cluster_sorted = None if cluster is None else cluster[order]
    event_idx_sorted = cp.where(event_sorted == 1)[0]
    self._event_idx_gpu = event_idx_sorted
    self._event_X_sum_gpu = (
        cp.sum(X_sorted[event_idx_sorted], axis=0)
        if int(event_idx_sorted.size) > 0
        else cp.zeros(n_features, dtype=cp.float64)
    )

    # Precompute tie structure once (depends only on time/event order).
    # Every cache starts from a clean default so that nothing from a previous
    # fit (e.g. the Efron singleton flag) survives into this one.
    efron_pre = None
    self._breslow_pre = None
    self._breslow_pre_gpu = None
    self._efron_pre = None
    self._efron_all_singletons = False
    self._efron_pre_csr = None
    self._efron_pre_csr_gpu = None
    if self.ties == "efron":
        if entry_sorted is None:
            efron_pre = self._efron_unique_failure_indices(
                cp.asnumpy(time_sorted), cp.asnumpy(event_sorted)
            )
            self._efron_pre = efron_pre
            try:
                _, uft_ix, _, _, nuft, _ = _unpack_efron_pre6(efron_pre)
                self._efron_all_singletons = bool(nuft > 0) and all(
                    len(ix) == 1 for ix in uft_ix
                )
            except Exception:
                self._efron_all_singletons = False
            # Pack enter/exit/fail indices once; reuse across Newton steps on GPU.
            # Best effort: any failure simply disables the CSR caches.
            try:
                from ._cox_efron_cuda import efron_indices_to_csr

                _, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft = _unpack_efron_pre6(
                    efron_pre
                )
                (
                    enter_ptr,
                    enter_ind,
                    exit_ptr,
                    exit_ind,
                    fail_ptr,
                    fail_ind,
                ) = efron_indices_to_csr(uft_ix, risk_enter, risk_exit, nuft)
                index_arrays = (
                    enter_ptr,
                    enter_ind,
                    exit_ptr,
                    exit_ind,
                    fail_ptr,
                    fail_ind,
                    first_idx_uft,
                )
                self._efron_pre_csr = index_arrays + (nuft,)
                self._efron_pre_csr_gpu = tuple(
                    cp.asarray(arr, dtype=cp.int32) for arr in index_arrays
                ) + (int(nuft),)
            except Exception:
                self._efron_pre_csr = None
                self._efron_pre_csr_gpu = None
    else:
        first_idx_uft, counts_uft = self._breslow_unique_failure_groups(
            cp.asnumpy(time_sorted), cp.asnumpy(event_sorted)
        )
        self._breslow_pre = (first_idx_uft, counts_uft)
        self._breslow_pre_gpu = (
            cp.asarray(first_idx_uft, dtype=cp.int32),
            cp.asarray(counts_uft, dtype=cp.int32),
        )
        self._breslow_counts_f_gpu = cp.asarray(counts_uft, dtype=cp.float64)
        self._breslow_first_idx_np = np.asarray(first_idx_uft, dtype=np.int64)
        self._breslow_counts_np = np.asarray(counts_uft, dtype=np.float64)
        # Reset entry index caches so they cannot drift across different
        # sort permutations (done with and without delayed entry).
        self._entry_fail_groups_gpu = None
        self._entry_fail_times_gpu = None
        self._entry_order_gpu = None
        self._entry_add_end_np_gpu = None
        self._entry_rem_end_np_gpu = None

    # Initialize coefficients on GPU (supports warm-start path in CV)
    if init_coef is None:
        beta = cp.zeros(n_features, dtype=cp.float64)
    else:
        beta = cp.asarray(np.asarray(init_coef, dtype=np.float64), dtype=cp.float64).reshape(-1)
        if int(beta.shape[0]) != int(n_features):
            raise ValueError("init_coef must have shape (n_features,)")

    # Delayed-entry context: the helper's tuple is extended with contiguous
    # copies of X and the summed covariates of the event rows.
    entry_ctx_gpu = None
    if entry_sorted is not None:
        _ctx = self._build_entry_ctx_gpu(time_sorted, event_sorted, entry_sorted, cp)
        event_idx_ctx = _ctx[5]
        entry_ctx_gpu = (
            _ctx[0], _ctx[1], _ctx[2], _ctx[3],
            cp.ascontiguousarray(X_sorted[_ctx[0]]),
            cp.ascontiguousarray(X_sorted),
            event_idx_ctx,
            cp.sum(X_sorted[event_idx_ctx], axis=0),
            _ctx[6],
        )

    # Compute null log-likelihood on GPU
    loglik_null_gpu = self._compute_log_likelihood_gpu(
        cp.zeros(n_features, dtype=cp.float64),
        X_sorted,
        time_sorted,
        event_sorted,
        efron_pre,
        entry=entry_sorted,
        entry_ctx=entry_ctx_gpu,
    )

    # Newton-Raphson optimization on GPU with L2 penalty
    penalty = float(self.penalty) if hasattr(self, 'penalty') else 0.0
    use_penalty = penalty > 0.0
    diag_idx = cp.arange(n_features, dtype=cp.int64) if use_penalty else None
    eye_cache = (
        cp.eye(n_features, dtype=cp.float64)
        if (self.compute_inference or use_penalty)
        else None
    )

    loglik_gpu = None
    current_obj = None
    hess = None  # stays None when max_iter == 0
    iteration = -1  # default if max_iter=0
    for iteration in range(self.max_iter):
        # Compute gradient and Hessian on GPU
        grad, hess, aux_stats = self._compute_gradient_hessian_gpu(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            return_aux=True, entry=entry_sorted, entry_ctx=entry_ctx_gpu,
        )

        # Add penalty terms: gradient -= 2*penalty*beta, hessian -= 2*penalty*I
        if use_penalty:
            grad = grad - 2 * penalty * beta
            # In-place diagonal shift avoids allocating a new dense eye each iteration.
            hess[diag_idx, diag_idx] -= 2 * penalty

        # Newton: delta = inv(hess) @ grad; hess is NSD — solve (-hess) x = grad, delta = -x
        delta = self._solve_newton_delta_gpu(hess, grad, cp, eye_cache=eye_cache)
        step = 1.0
        accepted_step = True
        if entry_sorted is not None:
            # Delayed entry: guard the step with a halving line search on the
            # (penalized) objective.
            if current_obj is None:
                old_ll = self._compute_log_likelihood_gpu_from_stats(
                    aux_stats[0], aux_stats[1], aux_stats[2],
                    time_sorted, event_sorted, efron_pre,
                    entry=entry_sorted, entry_ctx=entry_ctx_gpu,
                )
                if use_penalty:
                    old_ll = old_ll - penalty * cp.sum(beta * beta)
                current_obj = old_ll
            else:
                old_ll = current_obj
            new_beta = beta - delta
            new_ll = self._compute_log_likelihood_gpu(
                new_beta, X_sorted, time_sorted, event_sorted, efron_pre,
                entry=entry_sorted, entry_ctx=entry_ctx_gpu,
            )
            if use_penalty:
                new_ll = new_ll - penalty * cp.sum(new_beta * new_beta)
            full_step_gain = float((new_ll - old_ll).item())
            # "not (gain > tol)" rather than "gain <= tol": a NaN objective
            # must trigger backtracking instead of being accepted.
            if not (full_step_gain > -1e-8):
                step = 0.5
                accepted = False
                for _ in range(20):
                    trial_beta = beta - step * delta
                    trial_ll = self._compute_log_likelihood_gpu(
                        trial_beta, X_sorted, time_sorted, event_sorted, efron_pre,
                        entry=entry_sorted, entry_ctx=entry_ctx_gpu,
                    )
                    if use_penalty:
                        trial_ll = trial_ll - penalty * cp.sum(trial_beta * trial_beta)
                    if float((trial_ll - old_ll).item()) > -1e-8:
                        beta = trial_beta
                        current_obj = trial_ll
                        accepted = True
                        break
                    step *= 0.5
                if not accepted:
                    # No acceptable step: beta is left unchanged this iteration.
                    accepted_step = False
            else:
                beta = new_beta
                current_obj = new_ll
        else:
            beta = beta - delta

        # Check convergence on GPU
        delta_norm = float(cp.linalg.norm(delta).item())
        if entry_sorted is not None:
            if accepted_step and delta_norm * step < self.tol:
                self._converged = True
                loglik_gpu = self._compute_log_likelihood_gpu(
                    beta, X_sorted, time_sorted, event_sorted, efron_pre,
                    entry=entry_sorted, entry_ctx=entry_ctx_gpu,
                )
                break
        else:
            grad_norm = float(cp.linalg.norm(grad).item())
            if accepted_step and grad_norm < max(self.tol * 10.0, 1e-8) and delta_norm * step < self.tol:
                self._converged = True
                # Reuse current iteration statistics to avoid an extra
                # Efron log-likelihood setup pass when converged.
                eta_cur, exp_eta_cur, risk_sum_cur = aux_stats
                loglik_gpu = self._compute_log_likelihood_gpu_from_stats(
                    eta_cur, exp_eta_cur, risk_sum_cur,
                    time_sorted, event_sorted, efron_pre, entry=entry_sorted,
                )
                break

    # Compute final log-likelihood on GPU unless already obtained on convergence.
    if loglik_gpu is None:
        loglik_gpu = self._compute_log_likelihood_gpu(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            entry=entry_sorted, entry_ctx=entry_ctx_gpu,
        )

    # Single transfer at the end
    self._iterations = iteration + 1
    self.coef_ = cp.asnumpy(beta)
    self.hazard_ratios_ = np.exp(self.coef_)
    self._log_likelihood_null = float(cp.asnumpy(loglik_null_gpu))
    self._log_likelihood = float(cp.asnumpy(loglik_gpu))
    if self.compute_cindex:
        cindex_gpu = self._compute_cindex_gpu(X_sorted, time_sorted, event_sorted, beta)
        self._cindex = float(cp.asnumpy(cindex_gpu))
    else:
        self._cindex = None

    if not self.compute_inference:
        self._var_matrix = None
        self._bse = None
        self._zvalues = None
        self._pvalues = None
        self._conf_int = None
        self._score_test_stat = None
        self._score_test_pvalue = None
        self._wald_test_stat = None
        self._wald_test_pvalue = None
        self._lr_test_stat = None
        self._lr_test_pvalue = None
        self._baseline_hazard = None
        self._baseline_cumulative_hazard = None
        self._unique_times = None
        return

    # Inference on GPU. The Hessian of the last Newton iteration is reused;
    # it only has to be computed here when no iteration ran (max_iter == 0).
    if hess is None:
        _, hess, _ = self._compute_gradient_hessian_gpu(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            return_aux=True, entry=entry_sorted, entry_ctx=entry_ctx_gpu,
        )
        if use_penalty:
            hess[diag_idx, diag_idx] -= 2 * penalty

    nonrobust = self.cov_type == "nonrobust"
    if not nonrobust:
        score_resid_gpu = self._compute_robust_score_residuals_gpu(
            X_sorted, time_sorted, event_sorted
        )

    # Inverse information matrix (the sandwich "bread"); pinv is the fallback.
    try:
        info = -hess
        rhs_eye = eye_cache if eye_cache is not None else cp.eye(info.shape[0], dtype=info.dtype)
        info_inv = cp.linalg.solve(info, rhs_eye)
    except Exception:
        info_inv = cp.linalg.pinv(-hess)

    if nonrobust:
        var_gpu = info_inv
    else:
        if self.cov_type == "cluster":
            if cluster_sorted is None:
                raise ValueError("cov_type='cluster' requires cluster ids in fit(..., cluster=...)")
            # Sum score residuals within each cluster before the outer product.
            meat = cp.zeros((n_features, n_features), dtype=cp.float64)
            for g in cp.unique(cluster_sorted):
                u_g = cp.sum(score_resid_gpu[cluster_sorted == g], axis=0)
                meat += cp.outer(u_g, u_g)
        else:
            meat = score_resid_gpu.T @ score_resid_gpu
            if self.cov_type == "hc1":
                # Small-sample scaling n / (n - k).
                n = X_sorted.shape[0]
                k = X_sorted.shape[1]
                if n > k:
                    meat = meat * (n / (n - k))
        var_gpu = info_inv @ meat @ info_inv

    bse_gpu = cp.sqrt(cp.maximum(cp.diag(var_gpu), 0.0))
    z_gpu = beta / (bse_gpu + 1e-30)
    p_gpu = cp.minimum(1.0, 2.0 * norm.sf(cp.abs(z_gpu)))
    z_crit = norm.ppf(0.975)
    ci_gpu = cp.stack([beta - z_crit * bse_gpu, beta + z_crit * bse_gpu], axis=1)

    # Keep the full covariance matrix (not just its diagonal) so the Wald
    # test accounts for the correlation between coefficients.
    self._var_matrix = cp.asnumpy(var_gpu)
    self._bse = cp.asnumpy(bse_gpu)
    self._zvalues = cp.asnumpy(z_gpu)
    self._pvalues = cp.asnumpy(p_gpu)
    self._conf_int = cp.asnumpy(ci_gpu)
    self._lr_test_stat = 2 * (self._log_likelihood - self._log_likelihood_null)
    # Survival function instead of 1 - cdf: no cancellation for large statistics.
    self._lr_test_pvalue = stats.chi2.sf(self._lr_test_stat, n_features)
    try:
        self._wald_test_stat = self.coef_ @ np.linalg.solve(self._var_matrix, self.coef_)
    except np.linalg.LinAlgError:
        self._wald_test_stat = np.nan
    self._wald_test_pvalue = stats.chi2.sf(self._wald_test_stat, n_features)
    self._score_test_stat = np.nan
    self._score_test_pvalue = np.nan

    if nonrobust:
        # Keep baseline hazard optional in CUDA fast path to reduce transfer overhead.
        self._baseline_hazard = None
        self._baseline_cumulative_hazard = None
        self._unique_times = None
    else:
        # Compute baseline hazard on GPU
        self._compute_baseline_hazard_gpu(X_sorted, time_sorted, event_sorted, beta)
1037
+
1038
def _fit_torch(self, X, time, event, entry=None, cluster=None, torch_device="cuda", init_coef=None):
    """Fit using Torch with full device-side computation.

    Rows are sorted by ascending time, the tie structure is precomputed
    once, and Newton-Raphson steps (with an optional L2 penalty and, for
    delayed entry, step halving) run until the tolerances are met or
    ``max_iter`` is reached. Results are copied to host attributes at the end.

    Parameters
    ----------
    X : torch.Tensor of shape (n_samples, n_features)
    time, event : torch.Tensor
        Rows with ``event == 1`` are treated as failures by the tie helpers.
    entry, cluster : torch.Tensor or None
        Delayed-entry times and cluster ids, indexed with the sort order.
    torch_device : str or torch.device
        Device on which new tensors are created.
        NOTE(review): presumably the device the inputs already live on --
        confirm with caller.
    init_coef : array-like or None
        Warm-start coefficients of length ``n_features``.

    Raises
    ------
    ValueError
        If ``init_coef`` does not have ``n_features`` entries.
    """
    import torch
    from statgpu.inference._distributions_backend import norm

    # Tuple unpacking doubles as a check that X is two-dimensional.
    _n_samples, n_features = X.shape

    # Sort by time ascending so risk-set terms are suffix sums
    order = torch.argsort(time, stable=True)
    X_sorted = X[order]
    time_sorted = time[order]
    event_sorted = event[order]
    entry_sorted = None if entry is None else entry[order]
    cluster_sorted = None if cluster is None else cluster[order]

    # Precompute tie structure once (depends only on time/event order).
    # Every cache starts from a clean default so that nothing from a previous
    # fit (e.g. the Efron singleton flag) survives into this one.
    efron_pre = None
    self._breslow_pre = None
    self._breslow_pre_torch = None
    self._efron_pre = None
    self._efron_all_singletons = False
    self._efron_pre_csr = None
    self._efron_pre_csr_gpu = None
    if self.ties == "efron":
        if entry_sorted is None:
            efron_pre = self._efron_unique_failure_indices(
                time_sorted.cpu().numpy(), event_sorted.cpu().numpy()
            )
            self._efron_pre = efron_pre
            try:
                _, uft_ix, _, _, nuft, _ = _unpack_efron_pre6(efron_pre)
                self._efron_all_singletons = bool(nuft > 0) and all(
                    len(ix) == 1 for ix in uft_ix
                )
            except Exception:
                self._efron_all_singletons = False
            # Reuse CUDA CSR packing for Torch-CUDA fused kernels when available.
            # Best effort: a missing CuPy or any other failure disables the caches.
            try:
                import cupy as cp
                from ._cox_efron_cuda import efron_indices_to_csr

                _, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft = _unpack_efron_pre6(
                    efron_pre
                )
                (
                    enter_ptr,
                    enter_ind,
                    exit_ptr,
                    exit_ind,
                    fail_ptr,
                    fail_ind,
                ) = efron_indices_to_csr(uft_ix, risk_enter, risk_exit, nuft)
                index_arrays = (
                    enter_ptr,
                    enter_ind,
                    exit_ptr,
                    exit_ind,
                    fail_ptr,
                    fail_ind,
                    first_idx_uft,
                )
                self._efron_pre_csr = index_arrays + (nuft,)
                self._efron_pre_csr_gpu = tuple(
                    cp.asarray(arr, dtype=cp.int32) for arr in index_arrays
                ) + (int(nuft),)
            except Exception:
                self._efron_pre_csr = None
                self._efron_pre_csr_gpu = None
    else:
        first_idx_uft, counts_uft = self._breslow_unique_failure_groups(
            time_sorted.cpu().numpy(), event_sorted.cpu().numpy()
        )
        self._breslow_pre = (first_idx_uft, counts_uft)
        self._breslow_pre_torch = (
            torch.tensor(first_idx_uft, dtype=torch.int32, device=torch_device),
            torch.tensor(counts_uft, dtype=torch.int32, device=torch_device),
        )
        # Reset entry index caches so they cannot drift across different
        # sort permutations (done with and without delayed entry).
        self._entry_fail_groups_torch = None
        self._entry_fail_times_torch = None
        self._entry_order_torch = None
        self._entry_add_end_np_torch = None
        self._entry_rem_end_np_torch = None

    # Initialize coefficients on Torch device (supports warm-start path in CV)
    if init_coef is None:
        beta = torch.zeros(n_features, dtype=torch.float64, device=torch_device)
    else:
        beta = torch.as_tensor(init_coef, dtype=torch.float64, device=torch_device).reshape(-1)
        if int(beta.shape[0]) != int(n_features):
            raise ValueError("init_coef must have shape (n_features,)")

    # Delayed-entry context: the helper's tuple is extended with contiguous
    # copies of X and the summed covariates of the event rows.
    entry_ctx_torch = None
    if entry_sorted is not None:
        _ctx = self._build_entry_ctx_torch(time_sorted, event_sorted, entry_sorted, torch_device)
        event_idx_ctx = _ctx[5]
        entry_ctx_torch = (
            _ctx[0],
            _ctx[1],
            _ctx[2],
            _ctx[3],
            X_sorted.index_select(0, _ctx[0]).contiguous(),
            X_sorted.contiguous(),
            event_idx_ctx,
            torch.sum(X_sorted.index_select(0, event_idx_ctx), dim=0),
            _ctx[6],
        )

    # Compute null log-likelihood on Torch
    loglik_null_torch = self._compute_log_likelihood_torch(
        torch.zeros(n_features, dtype=torch.float64, device=torch_device),
        X_sorted,
        time_sorted,
        event_sorted,
        efron_pre,
        entry=entry_sorted,
        entry_ctx=entry_ctx_torch,
    )

    # Newton-Raphson optimization on Torch with L2 penalty
    penalty = float(self.penalty) if hasattr(self, 'penalty') else 0.0
    use_penalty = penalty > 0.0
    diag_idx = torch.arange(n_features, dtype=torch.long, device=torch_device) if use_penalty else None

    loglik_torch = None
    current_obj = None
    hess = None  # stays None when max_iter == 0
    iteration = -1  # so that zero iterations are reported when max_iter == 0
    for iteration in range(self.max_iter):
        # Compute gradient and Hessian on Torch
        grad, hess, aux_stats = self._compute_gradient_hessian_torch(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            return_aux=True, entry=entry_sorted, entry_ctx=entry_ctx_torch,
        )

        # Add penalty terms: gradient -= 2*penalty*beta, hessian -= 2*penalty*I
        if use_penalty:
            grad = grad - 2 * penalty * beta
            hess[diag_idx, diag_idx] -= 2 * penalty

        # Newton: delta = inv(hess) @ grad; hess is NSD — solve (-hess) x = grad, delta = -x
        delta = self._solve_newton_delta_torch(hess, grad)
        step = 1.0
        accepted_step = True
        if entry_sorted is not None:
            # Delayed entry: guard the step with a halving line search on the
            # (penalized) objective.
            if current_obj is None:
                old_ll = self._compute_log_likelihood_torch_from_stats(
                    aux_stats[0], aux_stats[1], aux_stats[2],
                    time_sorted, event_sorted, efron_pre,
                    entry=entry_sorted, entry_ctx=entry_ctx_torch,
                )
                if use_penalty:
                    old_ll = old_ll - penalty * torch.sum(beta * beta)
                current_obj = old_ll
            else:
                old_ll = current_obj
            new_beta = beta - delta
            new_ll = self._compute_log_likelihood_torch(
                new_beta, X_sorted, time_sorted, event_sorted, efron_pre,
                entry=entry_sorted, entry_ctx=entry_ctx_torch,
            )
            if use_penalty:
                new_ll = new_ll - penalty * torch.sum(new_beta * new_beta)
            full_step_gain = float((new_ll - old_ll).item())
            # "not (gain > tol)" rather than "gain <= tol": a NaN objective
            # must trigger backtracking instead of being accepted.
            if not (full_step_gain > -1e-8):
                step = 0.5
                accepted = False
                for _ in range(20):
                    trial_beta = beta - step * delta
                    trial_ll = self._compute_log_likelihood_torch(
                        trial_beta, X_sorted, time_sorted, event_sorted, efron_pre,
                        entry=entry_sorted, entry_ctx=entry_ctx_torch,
                    )
                    if use_penalty:
                        trial_ll = trial_ll - penalty * torch.sum(trial_beta * trial_beta)
                    if float((trial_ll - old_ll).item()) > -1e-8:
                        beta = trial_beta
                        current_obj = trial_ll
                        accepted = True
                        break
                    step *= 0.5
                if not accepted:
                    # No acceptable step: beta is left unchanged this iteration.
                    accepted_step = False
            else:
                beta = new_beta
                current_obj = new_ll
        else:
            beta = beta - delta

        # Check convergence
        delta_norm = float(torch.linalg.norm(delta).item())
        if entry_sorted is not None:
            if accepted_step and delta_norm * step < self.tol:
                self._converged = True
                loglik_torch = self._compute_log_likelihood_torch(
                    beta, X_sorted, time_sorted, event_sorted, efron_pre,
                    entry=entry_sorted, entry_ctx=entry_ctx_torch,
                )
                break
        else:
            grad_norm = float(torch.linalg.norm(grad).item())
            if accepted_step and grad_norm < max(self.tol * 10.0, 1e-8) and delta_norm * step < self.tol:
                self._converged = True
                # Reuse this iteration's statistics instead of a fresh pass.
                eta_cur, exp_eta_cur, risk_sum_cur = aux_stats
                loglik_torch = self._compute_log_likelihood_torch_from_stats(
                    eta_cur, exp_eta_cur, risk_sum_cur,
                    time_sorted, event_sorted, efron_pre, entry=entry_sorted,
                )
                break

    # Compute final log-likelihood on Torch unless already obtained.
    if loglik_torch is None:
        loglik_torch = self._compute_log_likelihood_torch(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            entry=entry_sorted, entry_ctx=entry_ctx_torch,
        )

    # Single transfer at the end
    self._iterations = iteration + 1
    self.coef_ = beta.cpu().numpy()
    self.hazard_ratios_ = np.exp(self.coef_)
    self._log_likelihood_null = float(loglik_null_torch.item())
    self._log_likelihood = float(loglik_torch.item())
    if self.compute_cindex:
        cindex_torch = self._compute_cindex_torch(X_sorted, time_sorted, event_sorted, beta)
        self._cindex = float(cindex_torch.item())
    else:
        self._cindex = None

    # Inference: nonrobust on Torch, other types fall back to CPU
    if self.compute_inference:
        if self.cov_type == "nonrobust":
            # The Hessian of the last Newton iteration is reused; it only has
            # to be computed here when no iteration ran (max_iter == 0).
            if hess is None:
                _, hess, _ = self._compute_gradient_hessian_torch(
                    beta, X_sorted, time_sorted, event_sorted, efron_pre,
                    return_aux=True, entry=entry_sorted, entry_ctx=entry_ctx_torch,
                )
                if use_penalty:
                    hess[diag_idx, diag_idx] -= 2 * penalty
            try:
                info = -hess
                var_torch = torch.linalg.solve(
                    info, torch.eye(info.shape[0], dtype=info.dtype, device=torch_device)
                )
            except Exception:
                var_torch = torch.linalg.pinv(-hess)
            bse_torch = torch.sqrt(
                torch.maximum(
                    torch.diag(var_torch),
                    torch.tensor(0.0, dtype=torch.float64, device=torch_device),
                )
            )
            z_torch = beta / (bse_torch + 1e-30)
            p_torch = torch.minimum(
                torch.tensor(1.0, device=torch_device), 2.0 * norm.sf(torch.abs(z_torch))
            )
            z_crit = norm.ppf(0.975)
            ci_torch = torch.stack([beta - z_crit * bse_torch, beta + z_crit * bse_torch], dim=1)

            self._bse = bse_torch.cpu().numpy()
            self._zvalues = z_torch.cpu().numpy()
            self._pvalues = p_torch.cpu().numpy()
            self._conf_int = ci_torch.cpu().numpy()
            # Keep the full covariance matrix (not just its diagonal) so the
            # Wald test accounts for the correlation between coefficients.
            self._var_matrix = var_torch.cpu().numpy()
            self._lr_test_stat = 2 * (self._log_likelihood - self._log_likelihood_null)
            # Survival function instead of 1 - cdf: no cancellation for
            # large statistics.
            self._lr_test_pvalue = stats.chi2.sf(self._lr_test_stat, n_features)
            try:
                self._wald_test_stat = self.coef_ @ np.linalg.solve(self._var_matrix, self.coef_)
            except np.linalg.LinAlgError:
                self._wald_test_stat = np.nan
            self._wald_test_pvalue = stats.chi2.sf(self._wald_test_stat, n_features)
            self._score_test_stat = np.nan
            self._score_test_pvalue = np.nan
            # Compute baseline hazard on Torch
            self._compute_baseline_hazard_torch(X_sorted, time_sorted, event_sorted, beta)
        else:
            # For hc0/hc1/cluster, use CPU inference path
            self._compute_inference_cpu(
                X_sorted.cpu().numpy(),
                time_sorted.cpu().numpy(),
                event_sorted.cpu().numpy(),
                cluster_sorted.cpu().numpy() if cluster_sorted is not None else None,
            )
            self._baseline_hazard = None
            self._baseline_cumulative_hazard = None
            self._unique_times = None
    else:
        self._var_matrix = None
        self._bse = None
        self._zvalues = None
        self._pvalues = None
        self._conf_int = None
        self._score_test_stat = None
        self._score_test_pvalue = None
        self._wald_test_stat = None
        self._wald_test_pvalue = None
        self._lr_test_stat = None
        self._lr_test_pvalue = None
        self._baseline_hazard = None
        self._baseline_cumulative_hazard = None
        self._unique_times = None
    self._cleanup_torch_memory()
1339
+
1340
+ def _compute_log_likelihood(self, beta, X, time, event, efron_pre=None, entry=None):
1341
+ """Compute log partial likelihood (Breslow/Efron tie handling)."""
1342
+ eta = X @ beta
1343
+ eta_eff = eta
1344
+ if entry is not None and self.ties == "breslow":
1345
+ eta_eff = eta - np.max(eta)
1346
+ # Note: We do NOT center eta here. While centering prevents exp overflow,
1347
+ # it introduces a beta-dependent shift that complicates numeric gradient verification.
1348
+ # In practice, exp(eta) overflow is rare when beta is near convergence.
1349
+ exp_eta = np.exp(eta_eff)
1350
+
1351
+ # Risk set suffix sums for standard (no-entry) path.
1352
+ risk_sum = np.cumsum(exp_eta[::-1])[::-1] if entry is None else None
1353
+
1354
+ event_mask = event == 1
1355
+ if not np.any(event_mask):
1356
+ return 0.0
1357
+
1358
+ if self.ties == "breslow":
1359
+ if entry is not None:
1360
+ fail_groups = getattr(self, "_entry_fail_groups_np", None)
1361
+ add_end_np = getattr(self, "_entry_add_end_np", None)
1362
+ rem_end_np = getattr(self, "_entry_rem_end_np", None)
1363
+ order_np = getattr(self, "_entry_order_np", None)
1364
+ if (
1365
+ fail_groups is None
1366
+ or add_end_np is None
1367
+ or rem_end_np is None
1368
+ or order_np is None
1369
+ ):
1370
+ event_idx = np.flatnonzero(event_mask)
1371
+ event_times = time[event_idx]
1372
+ uft_np, inv_np = np.unique(event_times, return_inverse=True)
1373
+ fail_groups = [
1374
+ event_idx[inv_np == g].astype(np.int64, copy=False)
1375
+ for g in range(len(uft_np))
1376
+ ]
1377
+ order_np = np.argsort(np.asarray(entry, dtype=np.float64)).astype(np.int64, copy=False)
1378
+ add_end_np = np.searchsorted(
1379
+ np.asarray(entry, dtype=np.float64)[order_np], uft_np, side="left"
1380
+ ).astype(np.int64, copy=False)
1381
+ rem_end_np = np.searchsorted(time, uft_np, side="left").astype(np.int64, copy=False)
1382
+
1383
+ s0 = 0.0
1384
+ add_ptr = 0
1385
+ rem_ptr = 0
1386
+ ll = 0.0
1387
+ for g, fail_idx in enumerate(fail_groups):
1388
+ add_end = int(add_end_np[g])
1389
+ if add_end > add_ptr:
1390
+ idx_add = order_np[add_ptr:add_end]
1391
+ s0 += float(np.sum(exp_eta[idx_add]))
1392
+ add_ptr = add_end
1393
+ rem_end = int(rem_end_np[g])
1394
+ if rem_end > rem_ptr:
1395
+ s0 -= float(np.sum(exp_eta[rem_ptr:rem_end]))
1396
+ rem_ptr = rem_end
1397
+ d_t = int(fail_idx.shape[0])
1398
+ if d_t <= 0:
1399
+ continue
1400
+ s0_safe = max(s0, 1e-300)
1401
+ ll += float(np.sum(eta_eff[fail_idx]) - d_t * np.log(s0_safe))
1402
+ return float(ll)
1403
+
1404
+ # l(β) = sum_i(eta_i) - sum_t(d_t * log(S0(t)))
1405
+ breslow_pre = getattr(self, "_breslow_pre", None)
1406
+ if (
1407
+ breslow_pre is not None
1408
+ and len(breslow_pre) == 2
1409
+ and breslow_pre[0].size > 0
1410
+ ):
1411
+ first_idx = breslow_pre[0].astype(np.int64, copy=False)
1412
+ counts = breslow_pre[1].astype(np.float64, copy=False)
1413
+ else:
1414
+ event_times = time[event_mask]
1415
+ uft, counts_i = np.unique(event_times, return_counts=True)
1416
+ first_idx = np.searchsorted(time, uft, side="left").astype(np.int64)
1417
+ counts = counts_i.astype(np.float64)
1418
+ risk_at = risk_sum[first_idx]
1419
+ # With centering: ll = sum(eta_i - eta_max) - sum(d_t * log(S0(t) * exp(-eta_max)))
1420
+ # = sum(eta_i) - n_events*eta_max - sum(d_t * (log(S0(t)) - eta_max))
1421
+ # = sum(eta_i) - n_events*eta_max - sum(d_t * log(S0(t))) + n_events*eta_max
1422
+ # = sum(eta_i) - sum(d_t * log(S0(t))) [eta_max cancels]
1423
+ return float(np.sum(eta_eff[event_mask]) - np.sum(counts * np.log(risk_at)))
1424
+
1425
+ # ---- Efron ----
1426
+ ll = 0.0
1427
+ if efron_pre is not None:
1428
+ uft, uft_ix, _, _, nuft, first_idx_uft = _unpack_efron_pre6(efron_pre)
1429
+
1430
+ # Sum of eta for all events (centering cancels out, use original eta)
1431
+ all_eta_sum = 0.0
1432
+ all_log_denom_sum = 0.0
1433
+
1434
+ for g in range(nuft):
1435
+ ix_ev = uft_ix[g]
1436
+ d = len(ix_ev)
1437
+ if d == 0:
1438
+ continue
1439
+ first_idx = (
1440
+ int(first_idx_uft[g])
1441
+ if first_idx_uft is not None
1442
+ else int(np.searchsorted(time, uft[g], side="left"))
1443
+ )
1444
+ risk_at_t = risk_sum[first_idx]
1445
+ sum_events = float(np.sum(exp_eta[ix_ev]))
1446
+ all_eta_sum += float(np.sum(eta[ix_ev]))
1447
+
1448
+ # Vectorized log denominator sum
1449
+ # Pre-compute k/d values to avoid repeated division
1450
+ k_vals = np.arange(d, dtype=np.float64)
1451
+ denom = risk_at_t - (k_vals / d) * sum_events
1452
+ all_log_denom_sum += float(np.sum(np.log(np.maximum(denom, 1e-300))))
1453
+
1454
+ return float(all_eta_sum - all_log_denom_sum)
1455
+
1456
+ # No precomputation: group event rows by unique failure times (vectorized).
1457
+ event_idx = np.flatnonzero(event_mask)
1458
+ event_times = time[event_idx]
1459
+ uft, inv, counts = np.unique(event_times, return_inverse=True, return_counts=True)
1460
+ first_idx = np.searchsorted(time, uft, side="left").astype(np.int64)
1461
+ risk_at = risk_sum[first_idx]
1462
+
1463
+ sum_events = np.bincount(inv, weights=exp_eta[event_idx], minlength=len(uft)).astype(np.float64)
1464
+ sum_eta_events = np.bincount(inv, weights=eta[event_idx], minlength=len(uft)).astype(np.float64)
1465
+
1466
+ # Vectorized log-likelihood computation
1467
+ ll = float(np.sum(sum_eta_events))
1468
+
1469
+ # For each unique failure time, compute sum of log denominators
1470
+ max_d = int(np.max(counts)) if len(counts) > 0 else 0
1471
+ if max_d > 0:
1472
+ # Create k matrix: (n_uft, max_d) where each row has [0/d, 1/d, ..., (d-1)/d]
1473
+ # Use broadcasting with careful masking for different d values
1474
+ k_matrix = np.arange(max_d, dtype=np.float64) / np.arange(1, max_d + 1, dtype=np.float64)[:, np.newaxis]
1475
+ # This is complex; fall back to loop for correctness
1476
+ for g in range(len(uft)):
1477
+ d = int(counts[g])
1478
+ if d == 0:
1479
+ continue
1480
+ k = np.arange(d, dtype=np.float64) / d
1481
+ denom = risk_at[g] - k * sum_events[g]
1482
+ ll -= float(np.sum(np.log(np.maximum(denom, 1e-300))))
1483
+ else:
1484
+ for g in range(len(uft)):
1485
+ d = int(counts[g])
1486
+ if d == 0:
1487
+ continue
1488
+ k = np.arange(d, dtype=np.float64) / d
1489
+ denom = risk_at[g] - k * sum_events[g]
1490
+ ll -= float(np.sum(np.log(np.maximum(denom, 1e-300))))
1491
+
1492
+ return float(ll)
1493
+
1494
def _solve_newton_delta_gpu(self, hess, grad, cp, eye_cache=None):
    """Return the Newton step ``delta = inv(hess) @ grad``.

    The system is solved on ``H = -hess`` plus a small relative diagonal
    jitter, first through a Cholesky factorisation and then, if that fails,
    through progressively more general solvers.

    Parameters
    ----------
    hess : square array
        Hessian passed by the caller; ``-hess`` is what gets factorised.
    grad : array
        Right-hand side of the linear system.
    cp : module
        Array module providing ``max``, ``abs``, ``diag``, ``eye``,
        ``isfinite``, ``all`` and ``linalg`` (CuPy in production).
    eye_cache : array, optional
        Pre-built identity matrix reused for the jitter term; built on demand
        when omitted.

    Returns
    -------
    array
        The step ``delta``.
    """
    p = int(hess.shape[0])
    try:
        H = -hess
        # Jitter scales with the largest diagonal entry so it stays negligible
        # relative to the problem while still regularising a near-singular H.
        eps = 1e-11 * (cp.max(cp.abs(cp.diag(H))) + 1.0)
        jitter_eye = eye_cache if eye_cache is not None else cp.eye(p, dtype=cp.float64)
        H = H + eps * jitter_eye
        # Fast path: SPD solve via Cholesky is usually faster than generic solve.
        try:
            L = cp.linalg.cholesky(H)
            y = cp.linalg.solve(L, grad)
            x = cp.linalg.solve(L.T, y)
            # CuPy's Cholesky does not check positive definiteness by default
            # and may hand back NaNs instead of raising, so a non-finite
            # result must be routed to the fallback explicitly.
            if not bool(cp.all(cp.isfinite(x))):
                raise FloatingPointError("non-finite Cholesky Newton step")
            return -x
        except Exception:
            return -cp.linalg.solve(H, grad)
    except Exception:
        # Last resorts work on the un-jittered Hessian directly.
        try:
            return cp.linalg.solve(hess, grad)
        except Exception:
            return cp.linalg.lstsq(hess, grad, rcond=None)[0].flatten()
1515
+
1516
def _compute_log_likelihood_gpu(self, beta, X, time, event, efron_pre=None, entry=None, entry_ctx=None):
    """Evaluate the log partial likelihood at ``beta`` on the GPU.

    Builds the linear predictor and its exponential, then delegates to
    ``_compute_log_likelihood_gpu_from_stats``. ``efron_pre``, ``entry`` and
    ``entry_ctx`` are forwarded unchanged.
    """
    import cupy as cp

    linear_pred = X @ beta
    hazard = cp.exp(linear_pred)
    if entry is None:
        # Reverse cumulative sum: element i holds the sum over rows i..n-1.
        suffix_risk = cp.cumsum(hazard[::-1])[::-1]
    else:
        # The entry-aware branch of the delegate never reads the suffix sums,
        # so skipping them saves work on every line-search probe.
        suffix_risk = None
    return self._compute_log_likelihood_gpu_from_stats(
        linear_pred,
        hazard,
        suffix_risk,
        time,
        event,
        efron_pre,
        entry=entry,
        entry_ctx=entry_ctx,
    )
1528
+
1529
def _build_entry_ctx_gpu(self, time, event, entry, cp):
    """Build the entry-time indexing context for one sorted GPU view.

    Returns a 7-tuple
    ``(entry_order, d_counts, add_end, rem_end, rem_order, event_idx, fail_ptr)``:

    * ``entry_order`` - device argsort of ``entry``.
    * ``d_counts`` - host float64 count of events per unique failure time.
    * ``add_end`` - host int64; for each unique failure time, how many rows of
      the entry-sorted view have ``entry`` strictly below it.
    * ``rem_end`` - host int64; for each unique failure time, how many leading
      rows of ``time`` are strictly below it (``time`` is searched as-is, so it
      is expected to be sorted ascending).
    * ``rem_order`` - device ``arange`` over all rows.
    * ``event_idx`` - device int64 indices of rows with ``event == 1``.
    * ``fail_ptr`` - host int64 prefix offsets of ``d_counts`` (length
      ``n_groups + 1``).

    With no events every array is empty except ``fail_ptr``, which is ``[0]``.
    """
    event_rows = cp.where(event == 1)[0]
    failure_times_host = cp.asnumpy(time[event_rows])

    if failure_times_host.size == 0:
        no_order = cp.zeros((0,), dtype=cp.int64)
        no_counts = np.zeros((0,), dtype=np.float64)
        no_add = np.zeros((0,), dtype=np.int64)
        no_rem = np.zeros((0,), dtype=np.int64)
        no_rem_order = cp.zeros((0,), dtype=cp.int64)
        no_events = cp.zeros((0,), dtype=cp.int64)
        lone_ptr = np.zeros((1,), dtype=np.int64)
        return (no_order, no_counts, no_add, no_rem, no_rem_order, no_events, lone_ptr)

    unique_failures, tie_counts = np.unique(failure_times_host, return_counts=True)
    tie_counts = tie_counts.astype(np.float64, copy=False)

    by_entry = cp.argsort(entry)
    sorted_entry_host = cp.asnumpy(entry[by_entry])
    time_host = cp.asnumpy(time)

    # Binary searches run on the host against the unique failure times.
    added_upto = np.searchsorted(sorted_entry_host, unique_failures, side="left")
    added_upto = added_upto.astype(np.int64, copy=False)
    removed_upto = np.searchsorted(time_host, unique_failures, side="left")
    removed_upto = removed_upto.astype(np.int64, copy=False)

    identity_order = cp.arange(int(time.shape[0]), dtype=cp.int64)
    event_rows = event_rows.astype(cp.int64, copy=False)

    # CSR-style offsets: group g owns event slots group_ptr[g]:group_ptr[g + 1].
    group_ptr = np.zeros(tie_counts.shape[0] + 1, dtype=np.int64)
    group_ptr[1:] = np.cumsum(tie_counts.astype(np.int64), dtype=np.int64)

    return (by_entry, tie_counts, added_upto, removed_upto, identity_order, event_rows, group_ptr)
1557
+
1558
def _compute_log_likelihood_gpu_from_stats(
    self, eta, exp_eta, risk_sum, time, event, efron_pre=None, entry=None, entry_ctx=None
):
    """Compute the log partial likelihood on GPU from precomputed statistics.

    Parameters
    ----------
    eta, exp_eta : device arrays
        Linear predictor and its exponential, one entry per row.
    risk_sum : device array or None
        Suffix sums of ``exp_eta``; only read on the paths without ``entry``.
    time, event : device arrays
        Observation times (searched with ``searchsorted``, so expected sorted
        ascending) and event indicators (``event == 1`` marks a failure).
    efron_pre : optional
        Cached Efron tie structure, consumed via ``_unpack_efron_pre6`` and the
        ``_cox_efron_cuda`` kernels.
    entry : device array, optional
        Entry times; when given, risk sets are built from entry/exit prefix sums.
    entry_ctx : tuple, optional
        Cached indexing context for the entry path.

    Returns
    -------
    0-d device array
        The log partial likelihood (``0.0`` when there are no events).
    """
    import cupy as cp

    ll = cp.array(0.0, dtype=cp.float64)
    event_mask = event == 1

    if not cp.any(event_mask):
        return ll

    if entry is not None:
        if entry_ctx is None:
            (
                entry_order,
                d_counts,
                add_end_np,
                rem_end_np,
                _rem_order,
                event_idx,
                fail_ptr,
            ) = self._build_entry_ctx_gpu(time, event, entry, cp)
        else:
            # NOTE(review): a cached context is read at indices 6 (event_idx)
            # and 8 (fail_ptr), whereas `_build_entry_ctx_gpu` returns a
            # 7-tuple with those fields at indices 5 and 6. Presumably the
            # cache built elsewhere uses a longer layout; confirm that a bare
            # `_build_entry_ctx_gpu` result is never passed here as `entry_ctx`.
            entry_order, d_counts, add_end_np, rem_end_np = entry_ctx[:4]
            event_idx = entry_ctx[6] if len(entry_ctx) > 6 else cp.where(event_mask)[0]
            fail_ptr = entry_ctx[8] if len(entry_ctx) > 8 else None
        n_groups = int(d_counts.shape[0])
        if n_groups == 0:
            return cp.array(0.0, dtype=cp.float64)
        if fail_ptr is None:
            fail_ptr = np.empty(n_groups + 1, dtype=np.int64)
            fail_ptr[0] = 0
            fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)

        # S0(t) = (mass of rows that have entered) - (mass of rows that left),
        # both read off prefix sums at host-computed cut points.
        exp_entry = exp_eta[entry_order]
        exp_rem = exp_eta
        add_pref = cp.cumsum(exp_entry, axis=0)
        rem_pref = cp.cumsum(exp_rem, axis=0)
        s0_add = cp.zeros(n_groups, dtype=cp.float64)
        s0_rem = cp.zeros(n_groups, dtype=cp.float64)
        mask_add = add_end_np > 0
        mask_rem = rem_end_np > 0
        if np.any(mask_add):
            idx_add = cp.asarray(add_end_np[mask_add] - 1, dtype=cp.int64)
            s0_add[cp.asarray(mask_add)] = add_pref[idx_add]
        if np.any(mask_rem):
            idx_rem = cp.asarray(rem_end_np[mask_rem] - 1, dtype=cp.int64)
            s0_rem[cp.asarray(mask_rem)] = rem_pref[idx_rem]
        s0_vec = cp.maximum(s0_add - s0_rem, 1e-300)
        event_eta = eta[event_idx]

        if self.ties == "breslow":
            d_vec = cp.asarray(d_counts, dtype=cp.float64)
            return cp.sum(event_eta) - cp.sum(d_vec * cp.log(s0_vec))

        # Efron with entry times: one vectorised reduction per tie group
        # instead of one device op per tied event.
        ll = cp.sum(event_eta)
        event_exp = exp_eta[event_idx]
        for g in range(n_groups):
            d = int(d_counts[g])
            if d <= 0:
                continue
            st = int(fail_ptr[g])
            ed = int(fail_ptr[g + 1])
            ef = cp.sum(event_exp[st:ed])
            base = s0_vec[g]
            fracs = cp.arange(d, dtype=cp.float64) / d
            ll = ll - cp.sum(cp.log(cp.maximum(base - fracs * ef, 1e-300)))
        return ll

    if self.ties == 'breslow':
        # Vectorized Breslow using cached failure groups to avoid
        # Python loops and host-device sync in GPU hot path.
        breslow_pre_gpu = getattr(self, "_breslow_pre_gpu", None)
        if (
            breslow_pre_gpu is not None
            and len(breslow_pre_gpu) == 2
            and int(breslow_pre_gpu[0].size) > 0
        ):
            first_idx_uft, counts_uft = breslow_pre_gpu
        else:
            uft, counts_uft = cp.unique(time[event_mask], return_counts=True)
            first_idx_uft = cp.searchsorted(time, uft, side="left")
            counts_uft = counts_uft.astype(cp.int32, copy=False)
        risk_at = risk_sum[first_idx_uft]
        return cp.sum(eta[event_mask]) - cp.sum(
            counts_uft.astype(cp.float64) * cp.log(risk_at)
        )

    # Efron: if all groups are singleton failures, Efron == Breslow.
    if getattr(self, "_efron_all_singletons", False):
        ep = efron_pre if efron_pre is not None else getattr(self, "_efron_pre", None)
        if ep is not None:
            _, _, _, _, nuft, first_idx_uft = _unpack_efron_pre6(ep)
            first_idx_uft = cp.asarray(first_idx_uft, dtype=cp.int32)
            counts_uft = cp.ones(int(nuft), dtype=cp.int32)
        else:
            uft, counts_uft = cp.unique(time[event_mask], return_counts=True)
            first_idx_uft = cp.searchsorted(time, uft, side="left")
            counts_uft = counts_uft.astype(cp.int32, copy=False)
        risk_at = risk_sum[first_idx_uft]
        return cp.sum(eta[event_mask]) - cp.sum(
            counts_uft.astype(cp.float64) * cp.log(risk_at)
        )

    # Efron: loop over cached failure groups (see `_cox_efron_cuda.compute_efron_loglik_raw`)
    if efron_pre is not None:
        try:
            csr_gpu = getattr(self, "_efron_pre_csr_gpu", None)
            if csr_gpu is not None:
                from ._cox_efron_cuda import compute_efron_loglik_raw_csr

                _, _, _, _, fail_ptr, fail_ind, first_idx_uft, nuft = csr_gpu
                return compute_efron_loglik_raw_csr(
                    eta,
                    exp_eta,
                    risk_sum,
                    fail_ptr,
                    fail_ind,
                    first_idx_uft,
                    nuft,
                    cupy_module=cp,
                )
        except Exception:
            # Best effort: any failure of the CSR kernel falls through to the
            # non-CSR kernel below.
            pass

        from ._cox_efron_cuda import compute_efron_loglik_raw

        return compute_efron_loglik_raw(
            eta, exp_eta, risk_sum, time, efron_pre, cupy_module=cp
        )

    # No cached tie structure: scan unique failure times directly.
    unique_times = cp.unique(time[event_mask])
    for t in unique_times:
        at_time_t = time == t
        events_at_t = at_time_t & event_mask
        d = int(cp.sum(events_at_t).item())

        if d == 0:
            continue

        risk_indices = cp.where(time >= t)[0]
        if risk_indices.size == 0:
            continue

        first_idx = risk_indices[0]
        risk_at_t = risk_sum[first_idx]
        sum_events = cp.sum(exp_eta[events_at_t])

        ll += cp.sum(eta[events_at_t])
        # Efron denominators for k = 0..d-1, reduced in one shot.
        fracs = cp.arange(d, dtype=cp.float64) / d
        ll -= cp.sum(cp.log(cp.maximum(risk_at_t - fracs * sum_events, 1e-300)))

    return ll
1707
+
1708
def _compute_gradient_hessian(self, beta, X, time, event, efron_pre=None, entry=None):
    """
    Gradient and Hessian of the log partial likelihood (same sign convention as statsmodels).

    Parameters
    ----------
    beta : array of shape (p,)
        Coefficient vector.
    X : array of shape (n, p)
        Design matrix; rows are sliced by position, so they are expected to be
        ordered by ascending ``time``.
    time, event : arrays of shape (n,)
        Observation times and event indicators (``event == 1`` is a failure).
    efron_pre : optional
        Output of `_efron_unique_failure_indices`. Pass the cached structure
        from `fit` to avoid O(n) Python work every Newton step. When it is
        None (or the Cython core is unavailable) the Efron branch delegates to
        `_compute_gradient_hessian_efron_backward`.
    entry : array of shape (n,), optional
        Entry times. Only the Breslow branch consumes them.

    Returns
    -------
    (grad, hess) : tuple of arrays with shapes (p,) and (p, p)
    """
    _, n_features = X.shape

    # Linear predictor
    eta = X @ beta
    eta_eff = eta
    if entry is not None and self.ties == "breslow":
        # Shifting by a constant leaves the S1/S0 and S2/S0 ratios unchanged
        # while keeping exp() in range.
        eta_eff = eta - np.max(eta)
    exp_eta = np.exp(eta_eff)

    if entry is None:
        # Suffix sums over the sorted rows give the risk-set totals.
        risk_sum = np.cumsum(exp_eta[::-1])[::-1]
        risk_X_sum = np.cumsum((X * exp_eta[:, np.newaxis])[::-1], axis=0)[::-1]
    else:
        # The entry path accumulates its own running sums below, so the
        # (n, p) weighted design is never needed.
        risk_sum = None
        risk_X_sum = None

    if self.ties == 'breslow':
        event_mask = event == 1
        grad = np.zeros(n_features, dtype=np.float64)
        if entry is not None:
            fail_groups = getattr(self, "_entry_fail_groups_np", None)
            add_end_np = getattr(self, "_entry_add_end_np", None)
            rem_end_np = getattr(self, "_entry_rem_end_np", None)
            order_np = getattr(self, "_entry_order_np", None)
            if (
                fail_groups is None
                or add_end_np is None
                or rem_end_np is None
                or order_np is None
            ):
                # Rebuild the tie structure. A stable sort + split groups the
                # event rows in O(m log m) rather than one boolean mask per
                # unique failure time.
                event_idx = np.flatnonzero(event_mask)
                uft_np, inv_np, tie_counts = np.unique(
                    time[event_idx], return_inverse=True, return_counts=True
                )
                if uft_np.size:
                    by_group = event_idx[np.argsort(np.ravel(inv_np), kind="stable")]
                    by_group = by_group.astype(np.int64, copy=False)
                    fail_groups = np.split(by_group, np.cumsum(tie_counts)[:-1])
                else:
                    fail_groups = []
                entry_f64 = np.asarray(entry, dtype=np.float64)
                order_np = np.argsort(entry_f64).astype(np.int64, copy=False)
                add_end_np = np.searchsorted(
                    entry_f64[order_np], uft_np, side="left"
                ).astype(np.int64, copy=False)
                rem_end_np = np.searchsorted(time, uft_np, side="left").astype(np.int64, copy=False)

            hess = np.zeros((n_features, n_features), dtype=np.float64)
            # Running zeroth/first/second moments of the current risk set.
            s0 = 0.0
            s1 = np.zeros(n_features, dtype=np.float64)
            s2 = np.zeros((n_features, n_features), dtype=np.float64)
            add_ptr = 0
            rem_ptr = 0
            for g, fail_idx in enumerate(fail_groups):
                # Rows whose entry time precedes this failure time join ...
                add_end = int(add_end_np[g])
                if add_end > add_ptr:
                    idx_add = order_np[add_ptr:add_end]
                    x_add = X[idx_add]
                    w_add = exp_eta[idx_add]
                    wx_add = x_add * w_add[:, np.newaxis]
                    s0 += float(np.sum(w_add))
                    s1 += np.sum(wx_add, axis=0)
                    s2 += wx_add.T @ x_add
                    add_ptr = add_end
                # ... and rows whose own time precedes it leave.
                rem_end = int(rem_end_np[g])
                if rem_end > rem_ptr:
                    x_rem = X[rem_ptr:rem_end]
                    w_rem = exp_eta[rem_ptr:rem_end]
                    wx_rem = x_rem * w_rem[:, np.newaxis]
                    s0 -= float(np.sum(w_rem))
                    s1 -= np.sum(wx_rem, axis=0)
                    s2 -= wx_rem.T @ x_rem
                    rem_ptr = rem_end
                d_t = int(fail_idx.shape[0])
                if d_t <= 0:
                    continue
                d_t_f = float(d_t)
                grad += np.sum(X[fail_idx], axis=0)
                s0_safe = max(s0, 1e-300)
                if s0 <= 1e-15:
                    # Degenerate risk set: keep the event term, skip the ratio terms.
                    continue
                ex = s1 / s0_safe
                grad -= d_t_f * ex
                hess -= d_t_f * (s2 / s0_safe - np.outer(ex, ex))
            return grad, hess

        first_idx = np.array([], dtype=np.int64)
        counts = np.array([], dtype=np.float64)
        if np.any(event_mask):
            breslow_pre = getattr(self, "_breslow_pre", None)
            if (
                breslow_pre is not None
                and len(breslow_pre) == 2
                and breslow_pre[0].size > 0
            ):
                first_idx = breslow_pre[0].astype(np.int64, copy=False)
                counts = breslow_pre[1].astype(np.float64, copy=False)
            else:
                event_times = time[event_mask]
                uft, counts_i = np.unique(event_times, return_counts=True)
                first_idx = np.searchsorted(time, uft, side="left").astype(np.int64)
                counts = counts_i.astype(np.float64)

        sum_X_events = np.sum(X[event_mask], axis=0)
        E_X = risk_X_sum[first_idx] / risk_sum[first_idx][:, np.newaxis]
        grad = sum_X_events - np.sum(E_X * counts[:, np.newaxis], axis=0)

        hess = self._compute_hessian_breslow_fast(
            X, time, event, risk_sum, risk_X_sum, exp_eta, first_idx, counts
        )
    else:
        # Efron: prefer Cython core if available; fall back to Python implementation
        # for environments without compiled extension or unexpected runtime issues.
        # Shift eta by a constant for numerical stability in exp(eta). This does not
        # change Efron gradient/Hessian because terms are scale-invariant.
        # NOTE(review): `entry` is not consulted anywhere in this branch - confirm
        # that Efron ties with entry times are handled (or rejected) by the caller.
        eta_efron = eta - np.max(eta)
        if HAS_CYTHON_EFRON and efron_pre is not None:
            try:
                uft, uft_ix, risk_enter, risk_exit, nuft, _ = _unpack_efron_pre6(efron_pre)
                grad, hess = _efron_grad_hess_cython(
                    eta_efron, X, risk_enter, risk_exit, uft_ix, nuft
                )
                # Align sign convention with existing CPU Efron backward path.
                hess = -hess
                if not (np.isfinite(grad).all() and np.isfinite(hess).all()):
                    raise FloatingPointError("non-finite Cython Efron grad/hess")
            except Exception:
                from ._cox_efron_cy import efron_grad_hess_python
                uft, uft_ix, risk_enter, risk_exit, nuft, _ = _unpack_efron_pre6(efron_pre)
                grad, hess = efron_grad_hess_python(
                    eta_efron, X, risk_enter, risk_exit, uft_ix, nuft
                )
                hess = -hess
                if not (np.isfinite(grad).all() and np.isfinite(hess).all()):
                    grad, hess = self._compute_gradient_hessian_efron_backward(
                        beta, X, time, event, efron_pre
                    )
        else:
            grad, hess = self._compute_gradient_hessian_efron_backward(
                beta, X, time, event, efron_pre
            )

    return grad, hess
1854
+
1855
def _compute_hessian_breslow_fast(
    self,
    X,
    time,
    event,
    risk_sum,
    risk_X_sum,
    exp_eta,
    first_idx=None,
    counts=None,
):
    """Dispatch the Breslow Hessian to one of two CPU kernels.

    ``first_idx`` / ``counts`` describe the tie groups (first row index and
    number of events per unique failure time). When either is missing or
    ``first_idx`` is empty, they are taken from ``self._breslow_pre`` if that
    cache is usable, otherwise recomputed from ``time`` and ``event``.

    Returns a zero (p, p) matrix when there are no events.
    """
    is_event = event == 1
    if not np.any(is_event):
        return np.zeros((X.shape[1], X.shape[1]), dtype=np.float64)

    # All events tied at one failure time share the same risk set, so the
    # kernels work per unique failure time rather than per event.
    groups_missing = first_idx is None or counts is None or len(first_idx) == 0
    if groups_missing:
        cached = getattr(self, "_breslow_pre", None)
        cache_usable = cached is not None and len(cached) == 2 and cached[0].size > 0
        if cache_usable:
            first_idx = cached[0].astype(np.int64, copy=False)
            counts = cached[1].astype(np.float64, copy=False)
        else:
            failure_times, tie_sizes = np.unique(time[is_event], return_counts=True)
            first_idx = np.searchsorted(time, failure_times, side="left").astype(np.int64)
            counts = tie_sizes.astype(np.float64)

    # Kernel choice:
    #   tensor      - materialises (n, p, p) moments; fine for small p and few groups.
    #   incremental - keeps a single (p, p) running moment; used otherwise.
    n_cov = int(X.shape[1])
    n_tie_groups = int(len(first_idx))
    small_problem = n_cov <= 24 and n_tie_groups <= 512
    kernel = (
        self._compute_hessian_breslow_tensor_grouped
        if small_problem
        else self._compute_hessian_breslow_incremental_grouped
    )
    return kernel(X, risk_sum, risk_X_sum, exp_eta, first_idx, counts)
1900
+
1901
def _compute_hessian_breslow_tensor_grouped(
    self, X, risk_sum, risk_X_sum, exp_eta, first_idx, counts
):
    """Grouped Breslow Hessian built from an explicit (n, p, p) moment tensor.

    For every tie group the weighted covariance of ``X`` over the risk set is
    formed from suffix sums read at ``first_idx``; the result is the negated,
    ``counts``-weighted sum of those covariances.
    """
    # Per-row weighted outer products exp_eta[n] * x_n x_n^T, then suffix sums.
    weighted_outer = np.einsum("ni,nj,n->nij", X, X, exp_eta)
    suffix_second = np.cumsum(weighted_outer[::-1], axis=0)[::-1]

    group_mass = risk_sum[first_idx]
    mean_x = risk_X_sum[first_idx] / group_mass[:, np.newaxis]
    mean_xx = suffix_second[first_idx] / group_mass[:, np.newaxis, np.newaxis]

    covariance = mean_xx - mean_x[:, :, np.newaxis] * mean_x[:, np.newaxis, :]
    weighted = covariance * counts[:, np.newaxis, np.newaxis]
    return -np.sum(weighted, axis=0)
1912
+
1913
def _compute_hessian_breslow_incremental_grouped(
    self, X, risk_sum, risk_X_sum, exp_eta, first_idx, counts
):
    """Grouped Breslow Hessian using one running (p, p) second-moment matrix.

    The running matrix starts as the weighted second moment of all rows and
    has leading rows subtracted as the scan advances through ``first_idx``.
    Groups whose risk-set mass is not positive are skipped.
    """
    weighted_X = X * exp_eta[:, np.newaxis]
    # Sum over the current risk set of exp_eta[j] * x_j x_j^T.
    second_moment = weighted_X.T @ X

    n_cov = X.shape[1]
    hessian = np.zeros((n_cov, n_cov), dtype=np.float64)
    cursor = 0
    for g, raw_start in enumerate(first_idx):
        start = int(raw_start)
        if start > cursor:
            # Rows cursor..start-1 precede this group's first row: drop them.
            second_moment -= weighted_X[cursor:start].T @ X[cursor:start]
            cursor = start

        mass = float(risk_sum[start])
        if mass <= 0.0:
            continue
        mean_x = risk_X_sum[start] / mass
        mean_xx = second_moment / mass
        hessian -= counts[g] * (mean_xx - np.outer(mean_x, mean_x))

    return hessian
1939
+
1940
def _compute_hessian_breslow_incremental_grouped_cupy(
    self, X, risk_sum, risk_X_sum, exp_eta, first_idx, counts
):
    """CuPy grouped Breslow Hessian with a running (p, p) second-moment matrix.

    Environment switches read on every call:

    * ``STATGPU_BRESLOW_HESS_UPDATE_KERNEL`` (default ``"1"``) - use
      ``apply_breslow_hess_update_raw`` for the per-group update; otherwise the
      update is composed from element-wise CuPy ops into scratch buffers.
    * ``STATGPU_BRESLOW_GEMM_BLOCK`` (default ``"1024"``, floor 64) - maximum
      number of rows removed from the running moment per matrix product.
    """
    import cupy as cp
    from ._cox_efron_cuda import apply_breslow_hess_update_raw

    weighted_X = X * exp_eta[:, cp.newaxis]
    second_moment = weighted_X.T @ X

    n_cov = int(X.shape[1])
    hessian = cp.zeros((n_cov, n_cov), dtype=cp.float64)

    kernel_flag = os.environ.get("STATGPU_BRESLOW_HESS_UPDATE_KERNEL", "1")
    use_update_kernel = kernel_flag.strip().lower() in ("1", "true", "yes", "on")
    if use_update_kernel:
        exx_buf = None
        outer_buf = None
    else:
        # Scratch space reused across groups by the element-wise fallback.
        exx_buf = cp.empty((n_cov, n_cov), dtype=cp.float64)
        outer_buf = cp.empty((n_cov, n_cov), dtype=cp.float64)

    try:
        block_size = int(os.environ.get("STATGPU_BRESLOW_GEMM_BLOCK", "1024"))
    except (TypeError, ValueError):
        block_size = 1024
    block_size = max(64, block_size)

    # Reuse cached host-side tie metadata when available.
    first_idx_np = getattr(self, "_breslow_first_idx_np", None)
    counts_np = getattr(self, "_breslow_counts_np", None)
    cache_stale = (
        first_idx_np is None
        or counts_np is None
        or int(first_idx_np.shape[0]) != int(first_idx.shape[0])
    )
    if cache_stale:
        first_idx_np = cp.asnumpy(first_idx).astype(np.int64, copy=False)
        counts_np = cp.asnumpy(counts).astype(np.float64, copy=False)
    # One transfer for all group denominators instead of one per group.
    risk_at_np = cp.asnumpy(risk_sum[first_idx]).astype(np.float64, copy=False)

    cursor = 0
    for g in range(int(first_idx_np.size)):
        start = int(first_idx_np[g])
        if start > cursor:
            # Batch removals to reduce many tiny GEMM launches.
            for lo in range(cursor, start, block_size):
                hi = min(start, lo + block_size)
                second_moment -= (weighted_X[lo:hi].T @ X[lo:hi])
            cursor = start

        mass = float(risk_at_np[g])
        if mass <= 0.0:
            continue
        mean_x = risk_X_sum[start] / mass
        if use_update_kernel:
            apply_breslow_hess_update_raw(
                hessian, second_moment, mean_x, mass, counts_np[g], cupy_module=cp
            )
        else:
            cp.multiply(second_moment, 1.0 / mass, out=exx_buf)
            cp.multiply(mean_x[:, cp.newaxis], mean_x[cp.newaxis, :], out=outer_buf)
            cp.subtract(exx_buf, outer_buf, out=exx_buf)
            hessian -= counts_np[g] * exx_buf

    return hessian
2005
+
2006
def _compute_hessian_breslow_fused_cupy(self, X, first_idx, counts, exp_eta):
    """Attempt the fused Breslow Hessian kernel; return ``None`` if it fails.

    Any exception raised while importing or running
    ``compute_breslow_hess_raw`` is swallowed so the caller can fall back.
    Setting ``STATGPU_DEBUG_BRESLOW_FUSED`` to a truthy value prints the
    exception before returning ``None``.
    """
    import cupy as cp

    debug_flag = os.environ.get("STATGPU_DEBUG_BRESLOW_FUSED", "0")
    debug_fused = debug_flag.strip().lower() in ("1", "true", "yes", "on")
    try:
        from ._cox_efron_cuda import compute_breslow_hess_raw

        fused = compute_breslow_hess_raw(
            X, first_idx, counts, cupy_module=cp, exp_eta=exp_eta
        )
    except Exception as ex:
        if debug_fused:
            print(f"[CUDA Breslow fused fallback] {type(ex).__name__}: {ex}")
        return None
    return fused
2026
+
2027
def _compute_hessian_breslow(self, beta, X, time, event, risk_sum, risk_X_sum, exp_eta):
    """
    Per-event Breslow Hessian via a single left-to-right scan.

    The weighted second moment ``(X * exp_eta).T @ X`` of all rows is formed
    once. Visiting event rows in ascending position, the rows lying before the
    current event are subtracted from it, each exactly once, so the whole scan
    costs O(n * p^2). Every event then contributes minus the weighted
    covariance of ``X`` at its own row position.

    ``beta`` and ``time`` are accepted but not read.
    """
    _, n_cov = X.shape
    hessian = np.zeros((n_cov, n_cov), dtype=np.float64)

    weighted_X = X * exp_eta[:, np.newaxis]
    second_moment = weighted_X.T @ X

    cursor = 0
    for pos in np.flatnonzero(event):
        if pos > cursor:
            # Rows cursor..pos-1 come before this event and leave the running sum.
            second_moment -= weighted_X[cursor:pos].T @ X[cursor:pos]
            cursor = pos

        mass = risk_sum[pos]
        mean_x = risk_X_sum[pos] / mass
        mean_xx = second_moment / mass
        hessian -= mean_xx - np.outer(mean_x, mean_x)

    return hessian
2063
+
2064
def _efron_unique_failure_indices(self, time: np.ndarray, event: np.ndarray):
    """
    Unique failure-time bookkeeping for a single stratum.

    `time` must be sorted ascending (it is searched with `searchsorted`).

    Returns
    -------
    uft : float array
        Sorted unique failure times.
    uft_ix : list of list of int
        For each unique failure time, the row indices of its events.
    risk_enter : list of list of int
        For each unique failure time ``uft[i]``, the row indices with
        ``uft[i] <= time < uft[i + 1]`` (the last bucket is open-ended).
        Rows earlier than the first failure time appear in no bucket.
    risk_exit : list of list of int
        Always empty lists; kept only so the tuple has the expected layout.
    nuft : int
        Number of unique failure times.
    first_idx_uft : int32 array
        First row index at each unique failure time.

    With no events the result is ``(empty, [], [], [], 0, empty_int32)``.
    """
    ift = np.flatnonzero(event == 1)
    if ift.size == 0:
        return np.array([], dtype=np.float64), [], [], [], 0, np.array([], dtype=np.int32)
    ft = time[ift]
    uft = np.unique(ft)
    nuft = int(uft.size)

    def _bucket_rows(rows, labels):
        """Split `rows` into `nuft` lists by integer label, keeping row order."""
        order = np.argsort(labels, kind="stable")
        ptr = np.empty(nuft + 1, dtype=np.int32)
        ptr[0] = 0
        ptr[1:] = np.cumsum(np.bincount(labels[order], minlength=nuft), dtype=np.int32)
        rows_sorted = rows[order]
        return [rows_sorted[ptr[i] : ptr[i + 1]].tolist() for i in range(nuft)]

    # First row index at each unique failure time (sorted time); avoids searchsorted in log-likelihood loops.
    first_idx_uft = np.searchsorted(time, uft, side="left").astype(np.int32)

    # uft_ix: event rows grouped by the unique failure time they fall on.
    group_ids = np.searchsorted(uft, ft, side="left").astype(np.int32)
    uft_ix = _bucket_rows(ift, group_ids)

    # risk_enter: every row is labelled with the index of the largest unique
    # failure time not exceeding its own time; rows before uft[0] get -1 and
    # are dropped. Scanning the buckets from last to first therefore adds each
    # row to the risk set exactly once.
    j_enter = np.searchsorted(uft, time, side="right").astype(np.int32) - 1
    mask_enter = j_enter >= 0
    risk_enter = _bucket_rows(np.nonzero(mask_enter)[0], j_enter[mask_enter])

    # risk_exit is not needed by the backward scan (rows are only ever added
    # through risk_enter), so every bucket is left empty.
    risk_exit = [[] for _ in range(nuft)]

    return uft, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft
2116
+
2117
@staticmethod
def _use_heavy_ties_cpu_fallback() -> bool:
    """Report whether the heavy-ties CPU fallback has been opted into.

    Reads ``STATGPU_HEAVY_TIES_CPU_FALLBACK`` (default ``"0"``); the values
    ``1``, ``true``, ``yes`` and ``on`` enable it, ignoring case and
    surrounding whitespace.
    """
    raw_flag = os.environ.get("STATGPU_HEAVY_TIES_CPU_FALLBACK", "0")
    normalized = raw_flag.strip().lower()
    return normalized in ("1", "true", "yes", "on")
2122
+
2123
def _should_cpu_fallback_heavy_ties(self, n_samples, n_features, avg_tie_size):
    """Decide whether a heavy-ties problem should be routed to the CPU.

    True only when all of the following hold, checked in this order: the
    fallback is opted into, ``self.ties`` is ``"efron"`` or ``"breslow"``, the
    average tie size is not below 8, and the problem has at most 20000 samples
    and 64 features.
    """
    if not self._use_heavy_ties_cpu_fallback():
        return False
    ties_supported = self.ties in ("efron", "breslow")
    return (
        ties_supported
        and not (avg_tie_size < 8.0)
        and int(n_samples) <= 20000
        and int(n_features) <= 64
    )
2132
+
2133
def _breslow_unique_failure_groups(self, time: np.ndarray, event: np.ndarray):
    """
    Tie groups for the Breslow method on sorted ``time`` / ``event``.

    Returns ``(first_idx_uft, counts_uft)`` as int32 arrays: for each unique
    failure time, the first row index at which it occurs in ``time`` and the
    number of events sharing it. Both are empty when there are no events.
    """
    failure_rows = np.flatnonzero(event == 1)
    if failure_rows.size == 0:
        return np.array([], dtype=np.int32), np.array([], dtype=np.int32)

    distinct_times, tie_sizes = np.unique(time[failure_rows], return_counts=True)
    group_starts = np.searchsorted(time, distinct_times, side="left")
    return group_starts.astype(np.int32), tie_sizes.astype(np.int32)
2145
+
2146
def _compute_gradient_hessian_efron_backward(self, beta, X, time, event, efron_pre=None):
    """
    Efron gradient and (negated) Hessian by a backward scan over failure times.

    For a failure time with ``d`` tied events, risk-set moments ``S0, S1, S2``
    and event-only moments ``S0e, S1e, S2e`` (zeroth/first/second moments of
    ``X`` weighted by ``exp(X @ beta)``), the log-likelihood term is::

        sum_events(eta) - sum_{k=0}^{d-1} log(S0 - (k/d) * S0e)

    so each ``k`` contributes to the derivatives with numerators
    ``S1 - (k/d) * S1e`` and ``S2 - (k/d) * S2e`` over the denominator
    ``S0 - (k/d) * S0e``. The tied-event correction in the numerators is what
    keeps the gradient consistent with the log-likelihood when ``d > 1``.

    eta is not centred here, matching the comment in the original code that
    `_compute_log_likelihood` is not centred either.

    Parameters
    ----------
    beta : array of shape (p,)
    X : array of shape (n, p)
    time, event : arrays of shape (n,)
        ``event == 1`` marks a failure. ``time`` is re-sorted internally.
    efron_pre : optional
        Accepted for interface compatibility; not read.

    Returns
    -------
    grad : array of shape (p,)
        Gradient of the log partial likelihood.
    hess : array of shape (p, p)
        The *negated* Hessian of the log partial likelihood (positive
        semi-definite), i.e. the same sign this method has always returned.
    """
    n_features = X.shape[1]
    e_linpred = np.exp(X @ beta)

    event_idx = np.flatnonzero(event == 1)
    if event_idx.size == 0:
        return (
            np.zeros(n_features, dtype=np.float64),
            np.zeros((n_features, n_features), dtype=np.float64),
        )

    # Group event rows by unique failure time with one stable sort.
    uft, inv, counts = np.unique(
        time[event_idx], return_inverse=True, return_counts=True
    )
    nuft = int(uft.size)
    events_by_group = event_idx[np.argsort(np.ravel(inv), kind="stable")]
    group_ptr = np.concatenate(([0], np.cumsum(counts)))

    # Sorted view used only for the risk-set moments.
    order = np.argsort(time, kind="stable")
    time_sorted = time[order]
    e_sorted = e_linpred[order]
    X_sorted = X[order]
    Xe_sorted = X_sorted * e_sorted[:, np.newaxis]
    starts = np.searchsorted(time_sorted, uft, side="left")

    grad = np.zeros(n_features, dtype=np.float64)
    hess_inner = np.zeros((n_features, n_features), dtype=np.float64)

    # Running risk-set moments, grown from the latest failure time backwards so
    # that only O(p^2) state is held instead of an (n, p, p) cumulative tensor.
    S0 = 0.0
    S1 = np.zeros(n_features, dtype=np.float64)
    S2 = np.zeros((n_features, n_features), dtype=np.float64)
    upper = int(time_sorted.shape[0])

    for g in range(nuft - 1, -1, -1):
        lower = int(starts[g])
        blk = slice(lower, upper)
        S0 += float(e_sorted[blk].sum())
        S1 += Xe_sorted[blk].sum(axis=0)
        S2 += Xe_sorted[blk].T @ X_sorted[blk]
        upper = lower

        # Moments restricted to the events tied at this failure time.
        rows = events_by_group[group_ptr[g] : group_ptr[g + 1]]
        d_g = int(rows.size)
        X_ev = X[rows]
        e_ev = e_linpred[rows]
        Xe_ev = X_ev * e_ev[:, np.newaxis]
        S0_ev = float(e_ev.sum())
        S1_ev = Xe_ev.sum(axis=0)
        S2_ev = Xe_ev.T @ X_ev

        # Efron steps k = 0..d-1, handled together.
        fracs = np.arange(d_g, dtype=np.float64) / d_g
        denom = np.maximum(S0 - fracs * S0_ev, 1e-300)
        inv_denom = 1.0 / denom
        # Row k holds (S1 - (k/d) S1e) / denom_k.
        scores = (
            S1[np.newaxis, :] - fracs[:, np.newaxis] * S1_ev[np.newaxis, :]
        ) * inv_denom[:, np.newaxis]

        grad += X_ev.sum(axis=0) - scores.sum(axis=0)

        # Hessian of the log-likelihood: -(S2 - f S2e)/denom + score score^T.
        hess_inner -= S2 * inv_denom.sum() - S2_ev * (fracs * inv_denom).sum()
        hess_inner += scores.T @ scores

    hess = -hess_inner
    return grad, hess
2246
+
2247
def _compute_gradient_hessian_gpu(
    self, beta, X, time, event, efron_pre=None, return_aux=False, entry=None, entry_ctx=None
):
    """Compute the partial-likelihood gradient and Hessian with CuPy.

    Three routes are taken depending on ``self.ties`` and ``entry``:

    * ``ties == "efron"`` without ``entry``: when ``_efron_all_singletons``
      is set the grouped Breslow formulas are used with unit counts,
      otherwise the work is delegated to
      ``_compute_gradient_hessian_efron_backward_gpu``.
    * ``entry`` given: risk-set sums are formed as "entered minus removed"
      prefix sums, for both Breslow and Efron ties.
    * otherwise: grouped Breslow using reverse cumulative sums.

    NOTE(review): the ``searchsorted`` calls and reverse cumulative sums
    presumably rely on rows being sorted by ascending ``time`` -- confirm
    against the caller.

    Parameters
    ----------
    beta, X, time, event
        Coefficients, design matrix, times and event indicators
        (``event == 1`` marks a failure).
    efron_pre
        Optional precomputed Efron grouping (see ``_unpack_efron_pre6``).
    return_aux
        When true, ``(eta, exp_eta, risk_sum)`` is appended to the result;
        ``risk_sum`` is ``None`` on the ``entry`` route.
    entry
        Optional entry times; selects the entry-aware route.
    entry_ctx
        Optional precomputed context for the entry-aware route.

    Returns
    -------
    ``(grad, hess)`` or ``(grad, hess, (eta, exp_eta, risk_sum))``.
    """
    import cupy as cp
    import time as _time

    def _env_on(var_name):
        # Opt-in switch: an unset variable counts as "0" (off).
        raw = os.environ.get(var_name, "0")
        return raw.strip().lower() in ("1", "true", "yes", "on")

    def _pack(grad_out, hess_out):
        # Attach the shared intermediates only when the caller asked for them.
        if return_aux:
            return grad_out, hess_out, (eta, exp_eta, risk_sum)
        return grad_out, hess_out

    def _event_feature_sum():
        # Reuse the cached sum of event rows when its length matches.
        cached = getattr(self, "_event_X_sum_gpu", None)
        if cached is not None and int(cached.shape[0]) == int(n_features):
            return cached.copy()
        return cp.sum(X[event_mask], axis=0)

    def _grouped_gradient(group_start, group_weight):
        # sum_events(X) - sum_g d_g * S1(g) / S0(g)
        observed = _event_feature_sum()
        expected = risk_X_sum[group_start] / risk_sum[group_start][:, cp.newaxis]
        return observed - cp.sum(expected * group_weight[:, cp.newaxis], axis=0)

    def _grouped_hessian(group_start, group_weight):
        # The fused kernel is opt-in and may decline by returning None.
        result = None
        if _env_on("STATGPU_BRESLOW_FUSED_CUPY"):
            result = self._compute_hessian_breslow_fused_cupy(
                X, group_start, group_weight, exp_eta
            )
        if result is None:
            result = self._compute_hessian_breslow_incremental_grouped_cupy(
                X, risk_sum, risk_X_sum, exp_eta, group_start, group_weight
            )
        return result

    n_samples, n_features = X.shape

    profile_breslow = _env_on("STATGPU_PROFILE_BRESLOW_CUDA")
    _t0_all = _time.perf_counter() if profile_breslow else None
    eta = X @ beta
    exp_eta = cp.exp(eta)
    event_mask = event == 1

    # Reverse cumulative sums are only meaningful without entry times; the
    # entry-aware route builds its own prefix sums further down.
    untruncated = entry is None
    risk_sum = cp.cumsum(exp_eta[::-1])[::-1] if untruncated else None
    X_exp_eta = X * exp_eta[:, cp.newaxis]
    risk_X_sum = cp.cumsum(X_exp_eta[::-1], axis=0)[::-1] if untruncated else None
    if profile_breslow:
        cp.cuda.Stream.null.synchronize()
        _t_pre = _time.perf_counter()

    # ---- Efron without entry times -------------------------------------
    if self.ties == "efron" and untruncated:
        if getattr(self, "_efron_all_singletons", False):
            # Every failure time is flagged as untied, so the grouped
            # Breslow formulas are applied with a count of one per group.
            ep = efron_pre if efron_pre is not None else getattr(self, "_efron_pre", None)
            if ep is None:
                uft, counts_uft = cp.unique(time[event_mask], return_counts=True)
                first_idx_uft = cp.searchsorted(time, uft, side="left")
                counts_uft = counts_uft.astype(cp.int32, copy=False)
            else:
                _, _, _, _, nuft, first_idx_uft = _unpack_efron_pre6(ep)
                first_idx_uft = cp.asarray(first_idx_uft, dtype=cp.int32)
                counts_uft = cp.ones(int(nuft), dtype=cp.int32)
            counts_f = counts_uft.astype(cp.float64)
            grad = _grouped_gradient(first_idx_uft, counts_f)
            hess = _grouped_hessian(first_idx_uft, counts_f)
            return _pack(grad, hess)

        if efron_pre is None:
            efron_pre = self._efron_unique_failure_indices(
                cp.asnumpy(time), cp.asnumpy(event)
            )
        out = self._compute_gradient_hessian_efron_backward_gpu(beta, X, efron_pre)
        if return_aux:
            return out[0], out[1], (eta, exp_eta, risk_sum)
        return out

    grad = cp.zeros(n_features, dtype=cp.float64)

    if not cp.any(event_mask):
        # No failures: gradient and Hessian are identically zero.
        return _pack(grad, cp.zeros((n_features, n_features), dtype=cp.float64))

    # ---- Entry-aware route (Breslow or Efron) --------------------------
    if entry is not None:
        if entry_ctx is None:
            (
                entry_order,
                d_counts,
                add_end_np,
                rem_end_np,
                rem_order,
                event_idx,
                fail_ptr,
            ) = self._build_entry_ctx_gpu(time, event, entry, cp)
            X_entry = cp.ascontiguousarray(X[entry_order])
            X_rem = cp.ascontiguousarray(X[rem_order])
            grad += cp.sum(X[event_idx], axis=0)
        else:
            # Slots beyond the first four are optional in a supplied context.
            ctx_len = len(entry_ctx)
            entry_order, d_counts, add_end_np, rem_end_np = entry_ctx[:4]
            X_entry = entry_ctx[4] if ctx_len > 4 else X[entry_order]
            X_rem = entry_ctx[5] if ctx_len > 5 else X
            event_idx = entry_ctx[6] if ctx_len > 6 else cp.where(event_mask)[0]
            grad += entry_ctx[7] if ctx_len > 7 else cp.sum(X[event_mask], axis=0)
            fail_ptr = entry_ctx[8] if ctx_len > 8 else None

        hess = cp.zeros((n_features, n_features), dtype=cp.float64)
        exp_entry = exp_eta[entry_order]
        exp_rem = exp_eta
        wx_entry = X_entry * exp_entry[:, cp.newaxis]
        wx_rem = X_rem * exp_rem[:, cp.newaxis]
        n_groups = int(d_counts.shape[0])
        if n_groups == 0:
            return _pack(grad, hess)

        def _prefix_at(ends, weights, weighted_rows):
            # Value of the running sums just before position ``ends[g]``;
            # groups with ``ends[g] == 0`` keep a zero sum.
            pref0 = cp.cumsum(weights, axis=0)
            pref1 = cp.cumsum(weighted_rows, axis=0)
            at0 = cp.zeros(n_groups, dtype=cp.float64)
            at1 = cp.zeros((n_groups, n_features), dtype=cp.float64)
            nonempty = ends > 0
            if np.any(nonempty):
                last = cp.asarray(ends[nonempty] - 1, dtype=cp.int64)
                nonempty_cp = cp.asarray(nonempty)
                at0[nonempty_cp] = pref0[last]
                at1[nonempty_cp] = pref1[last]
            return at0, at1

        s0_add, s1_add = _prefix_at(add_end_np, exp_entry, wx_entry)
        s0_rem, s1_rem = _prefix_at(rem_end_np, exp_rem, wx_rem)

        # Risk-set sums per failure group: entered minus removed.
        s0_vec = s0_add - s0_rem
        s1_vec = s1_add - s1_rem
        d_vec = cp.asarray(d_counts, dtype=cp.float64)
        s0_safe_vec = cp.maximum(s0_vec, 1e-15)
        use_efron_entry = self.ties == "efron"
        ex_vec = s1_vec / s0_safe_vec[:, cp.newaxis]

        if use_efron_entry:
            if fail_ptr is None:
                # CSR-style offsets of each group's failures in event_idx.
                fail_ptr = np.empty(n_groups + 1, dtype=np.int64)
                fail_ptr[0] = 0
                fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)
            event_exp = exp_eta[event_idx]
            X_fail = X[event_idx]
        else:
            grad -= cp.sum(d_vec[:, cp.newaxis] * ex_vec, axis=0)

        s2 = cp.zeros((n_features, n_features), dtype=cp.float64)
        s2_block_size = int(os.environ.get("STATGPU_ENTRY_S2_BLOCK_SIZE", "8192"))
        if s2_block_size <= 0:
            # Non-positive means "never split into blocks".
            s2_block_size = 10**18
        use_s2_fused = _env_on("STATGPU_ENTRY_S2_FUSED_CUPY")
        s2_fused_min_rows = max(
            int(os.environ.get("STATGPU_ENTRY_S2_FUSED_MIN_ROWS", "512")), 1
        )

        def _update_s2(acc, rows, weights, n_rows, sign):
            # Pick fused kernel / single GEMM / blocked GEMM by slice size.
            if use_s2_fused and n_rows >= s2_fused_min_rows:
                return self._s2_weighted_update_cupy_fused(acc, rows, weights, sign=sign)
            if n_rows <= s2_block_size:
                gram = rows.T @ (rows * weights[:, cp.newaxis])
                if sign > 0:
                    return acc + gram
                return acc - gram
            return self._s2_weighted_update_cupy_blocked(
                acc, rows, weights, s2_block_size, sign=sign
            )

        add_ptr = 0
        rem_ptr = 0
        for g in range(n_groups):
            # Bring rows that have entered into the running second moment.
            add_end = int(add_end_np[g])
            if add_end > add_ptr:
                s2 = _update_s2(
                    s2,
                    X_entry[add_ptr:add_end],
                    exp_entry[add_ptr:add_end],
                    int(add_end - add_ptr),
                    1.0,
                )
                add_ptr = add_end

            # Drop rows that have left the risk set.
            rem_end = int(rem_end_np[g])
            if rem_end > rem_ptr:
                s2 = _update_s2(
                    s2,
                    X_rem[rem_ptr:rem_end],
                    exp_eta[rem_ptr:rem_end],
                    int(rem_end - rem_ptr),
                    -1.0,
                )
                rem_ptr = rem_end

            d_t_f = float(d_counts[g])
            if d_t_f <= 0:
                continue

            if not use_efron_entry:
                hess -= (d_t_f / s0_safe_vec[g]) * s2
                continue

            # Efron: remove k/d of the tied failures' mass for k = 0..d-1.
            st = int(fail_ptr[g])
            ed = int(fail_ptr[g + 1])
            ef = event_exp[st:ed]
            xf = X_fail[st:ed]
            ef_sum = cp.sum(ef)
            ef_x_sum = cp.sum(xf * ef[:, cp.newaxis], axis=0)
            ef_x2_sum = xf.T @ (xf * ef[:, cp.newaxis])
            s0_g = cp.maximum(s0_vec[g], 1e-15)
            s1_g = s1_vec[g]
            d_i = int(d_t_f)
            for k in range(d_i):
                frac = float(k) / float(d_i)
                denom = cp.maximum(s0_g - frac * ef_sum, 1e-15)
                s1_k = s1_g - frac * ef_x_sum
                s2_k = s2 - frac * ef_x2_sum
                ex_k = s1_k / denom
                grad -= ex_k
                hess -= s2_k / denom
                hess += cp.outer(ex_k, ex_k)

        if not use_efron_entry:
            hess += ex_vec.T @ (d_vec[:, cp.newaxis] * ex_vec)
        return _pack(grad, hess)

    # ---- Grouped Breslow without entry times ----------------------------
    # For Breslow ties, all events at the same failure time share the
    # same risk set R(t); grouping is required for correctness.
    breslow_pre_gpu = getattr(self, "_breslow_pre_gpu", None)
    if (
        breslow_pre_gpu is not None
        and len(breslow_pre_gpu) == 2
        and int(breslow_pre_gpu[0].size) > 0
    ):
        first_idx_uft, counts_uft = breslow_pre_gpu
    else:
        uft, counts_uft = cp.unique(time[event_mask], return_counts=True)
        first_idx_uft = cp.searchsorted(time, uft, side="left")
        counts_uft = counts_uft.astype(cp.int32, copy=False)

    counts_f = getattr(self, "_breslow_counts_f_gpu", None)
    if counts_f is None or int(counts_f.shape[0]) != int(counts_uft.shape[0]):
        counts_f = counts_uft.astype(cp.float64)

    grad = _grouped_gradient(first_idx_uft, counts_f)
    if profile_breslow:
        cp.cuda.Stream.null.synchronize()
        _t_grad = _time.perf_counter()

    hess = _grouped_hessian(first_idx_uft, counts_f)
    if profile_breslow:
        cp.cuda.Stream.null.synchronize()
        _t_hess = _time.perf_counter()
        print(
            f"[CUDA Breslow profile] pre={(_t_pre - _t0_all):.4f}s "
            f"grad={(_t_grad - _t_pre):.4f}s "
            f"hess={(_t_hess - _t_grad):.4f}s "
            f"total={(_t_hess - _t0_all):.4f}s"
        )
    return _pack(grad, hess)
2520
+
2521
+ def _s2_weighted_update_cupy_blocked(self, s2, x, w, block_size, sign=1.0):
2522
+ """Blocked update for large slices: s2 += sign * X^T (X * w)."""
2523
+ import cupy as cp
2524
+
2525
+ n = int(x.shape[0])
2526
+ if n <= 0:
2527
+ return s2
2528
+ for st in range(0, n, block_size):
2529
+ ed = min(st + block_size, n)
2530
+ xb = x[st:ed]
2531
+ wb = w[st:ed]
2532
+ s2 = s2 + sign * (xb.T @ (xb * wb[:, cp.newaxis]))
2533
+ return s2
2534
+
2535
def _get_entry_s2_fused_kernel_cupy(self):
    """Return the CuPy RawKernel for the weighted ``X^T X`` product.

    The kernel is created on first use and stored on the instance as
    ``_entry_s2_fused_kernel_cupy``; later calls return the stored object.
    Each CUDA thread computes one ``(i, j)`` entry of the ``p x p`` output
    by looping over the ``n`` rows of ``x``.
    """
    cached = getattr(self, "_entry_s2_fused_kernel_cupy", None)
    if cached is None:
        import cupy as cp

        kernel_source = r"""
        extern "C" __global__
        void entry_s2_outer_f64(const double* x, const double* w, double* out, int n, int p) {
            int i = blockIdx.x * blockDim.x + threadIdx.x;
            int j = blockIdx.y * blockDim.y + threadIdx.y;
            if (i >= p || j >= p) return;
            double acc = 0.0;
            for (int r = 0; r < n; ++r) {
                double wr = w[r];
                double xi = x[(size_t)r * (size_t)p + (size_t)i];
                double xj = x[(size_t)r * (size_t)p + (size_t)j];
                acc += wr * xi * xj;
            }
            out[(size_t)i * (size_t)p + (size_t)j] = acc;
        }
        """
        cached = cp.RawKernel(kernel_source, "entry_s2_outer_f64")
        self._entry_s2_fused_kernel_cupy = cached
    return cached
2561
+
2562
def _s2_weighted_update_cupy_fused(self, s2, x, w, sign=1.0):
    """CuPy fused-kernel update: return ``s2 + sign * X^T (X * w)``.

    Parameters
    ----------
    s2
        Running ``(p, p)`` accumulator; it is not modified in place.
    x
        ``(n, p)`` rows; converted to contiguous float64 for the kernel.
    w
        Row weights with at least ``n`` entries; converted to contiguous
        float64.
    sign
        Scalar multiplier for the contribution. ``+1.0`` / ``-1.0`` give
        the same results as the previous add/subtract branches, and other
        values now scale the update exactly like the blocked variant.

    Returns
    -------
    The updated accumulator (``s2`` itself when there is nothing to add).

    Raises
    ------
    ValueError
        If ``w`` has fewer entries than ``x`` has rows; the kernel would
        otherwise read past the end of ``w``.
    """
    import cupy as cp

    n = int(x.shape[0])
    if n <= 0:
        return s2
    x = cp.ascontiguousarray(x, dtype=cp.float64)
    w = cp.ascontiguousarray(w, dtype=cp.float64)
    if int(w.size) < n:
        raise ValueError(
            f"w has {int(w.size)} entries but x has {n} rows"
        )
    p = int(x.shape[1])
    if p == 0:
        # A zero-sized grid is not a valid launch configuration.
        return s2
    out = cp.empty((p, p), dtype=cp.float64)
    # One thread per output entry, tiled in 16x16 blocks.
    threads = (16, 16, 1)
    blocks = ((p + 15) // 16, (p + 15) // 16, 1)
    ker = self._get_entry_s2_fused_kernel_cupy()
    ker(blocks, threads, (x, w, out, np.int32(n), np.int32(p)))
    return s2 + sign * out
2580
+
2581
def _compute_gradient_hessian_efron_backward_gpu(self, beta, X, efron_pre):
    """Efron gradient/Hessian on CuPy with three tiers of implementation.

    Order of preference:

    1. the grouped-GEMM method, when ``STATGPU_EFRON_GROUPED_GEMM`` is on
       (default), ``n_features <= 192`` and the average rows per unique
       failure time is at least 24;
    2. ``compute_efron_grad_hess_raw`` from ``._cox_efron_cuda``; any
       exception there, or a ``None`` result, falls through;
    3. a Python loop that walks the unique failure times from last to
       first while maintaining running risk-set moments.

    Parameters
    ----------
    beta, X
        Coefficients and design matrix (CuPy arrays).
    efron_pre
        Precomputed grouping, unpacked by ``_unpack_efron_pre6``.

    Returns
    -------
    ``(grad, hess)``.
    """
    import cupy as cp

    _, uft_ix, risk_enter, risk_exit, nuft, _ = _unpack_efron_pre6(efron_pre)
    n_features = X.shape[1]
    if nuft == 0:
        zero_grad = cp.zeros(n_features, dtype=cp.float64)
        zero_hess = cp.zeros((n_features, n_features), dtype=cp.float64)
        return zero_grad, zero_hess

    n_samples = int(X.shape[0])
    avg_tie = float(n_samples) / max(1.0, float(nuft))
    gemm_flag = os.environ.get("STATGPU_EFRON_GROUPED_GEMM", "1").strip().lower()
    use_grouped_gemm = gemm_flag in ("1", "true", "yes", "on")
    if use_grouped_gemm and n_features <= 192 and avg_tie >= 24.0:
        return self._compute_gradient_hessian_efron_grouped_gemm_cupy(
            beta, X, efron_pre
        )

    try:
        from ._cox_efron_cuda import compute_efron_grad_hess_raw

        raw_kwargs = {"cupy_module": cp}
        csr_gpu = getattr(self, "_efron_pre_csr_gpu", None)
        if csr_gpu is not None:
            raw_kwargs["efron_csr"] = csr_gpu
        out = compute_efron_grad_hess_raw(X, beta, efron_pre, **raw_kwargs)
        if out is not None:
            return out[0], out[1]
    except Exception:
        # Best effort: any failure of the raw-kernel path falls back to
        # the pure CuPy loop below.
        pass

    # Shift by the maximum before exponentiating.
    linpred = X @ beta
    linpred = linpred - cp.max(linpred)
    e_linpred = cp.exp(linpred)

    def _moments(index_list):
        # Rows plus zeroth/first/second weighted moments of an index group.
        rows = cp.array(index_list, dtype=cp.int32)
        weights = e_linpred[rows]
        feats = X[rows]
        m0 = weights.sum()
        m1 = (weights[:, None] * feats).sum(axis=0)
        m2 = cp.einsum("ij,ik,i->jk", feats, feats, weights)
        return feats, m0, m1, m2

    grad = cp.zeros(n_features, dtype=cp.float64)
    hess_inner = cp.zeros((n_features, n_features), dtype=cp.float64)
    xp0 = cp.zeros((), dtype=cp.float64)
    xp1 = cp.zeros(n_features, dtype=cp.float64)
    xp2 = cp.zeros((n_features, n_features), dtype=cp.float64)

    for i in reversed(range(nuft)):
        entering = risk_enter[i]
        if len(entering) > 0:
            _, add0, add1, add2 = _moments(entering)
            xp0 = xp0 + add0
            xp1 = xp1 + add1
            xp2 = xp2 + add2

        failing = uft_ix[i]
        if len(failing) > 0:
            v, xp0f, xp1f, xp2f = _moments(failing)
            m = len(failing)
            # J[k] = k / m is the share of the tied failures' mass removed
            # from the denominator for the k-th term.
            J = cp.arange(m, dtype=cp.float64) / max(m, 1)
            c0 = cp.maximum(xp0 - J * xp0f, 1e-300)
            ak = 1.0 / c0
            bk = J * ak
            sum_inv_c0 = cp.sum(ak)
            sum_J_c0 = cp.sum(bk)
            sum_aa = cp.sum(ak * ak)
            sum_bb = cp.sum(bk * bk)
            sum_ab = cp.sum(ak * bk)
            grad = grad + v.sum(axis=0)
            grad = grad - (xp1 * sum_inv_c0 - xp1f * sum_J_c0)
            hess_inner = hess_inner + xp2 * sum_inv_c0
            hess_inner = hess_inner - xp2f * sum_J_c0
            hess_inner = hess_inner - (
                sum_aa * cp.outer(xp1, xp1)
                + sum_bb * cp.outer(xp1f, xp1f)
                - sum_ab * (cp.outer(xp1, xp1f) + cp.outer(xp1f, xp1))
            )

        leaving = risk_exit[i]
        if len(leaving) > 0:
            _, sub0, sub1, sub2 = _moments(leaving)
            xp0 = xp0 - sub0
            xp1 = xp1 - sub1
            xp2 = xp2 - sub2

    return grad, -hess_inner
2680
+
2681
def _compute_gradient_hessian_efron_grouped_gemm_cupy(self, beta, X, efron_pre):
    """Exact Efron grad/hess on CuPy via grouped GEMM updates (no p^2 atomics).

    Unique failure times are visited from last to first. Rows listed in
    ``risk_enter[i]`` are added to the running risk-set moments before the
    failures in ``uft_ix[i]`` are scored, and rows in ``risk_exit[i]`` are
    removed afterwards. Second moments are formed with matrix products.

    Parameters
    ----------
    beta, X
        Coefficients and design matrix (CuPy arrays).
    efron_pre
        Precomputed grouping, unpacked by ``_unpack_efron_pre6``.

    Returns
    -------
    ``(grad, hess)`` where ``hess`` is the negated accumulated matrix.
    """
    import cupy as cp

    _, uft_ix, risk_enter, risk_exit, nuft, _ = _unpack_efron_pre6(efron_pre)
    n_features = int(X.shape[1])
    # Shift by the maximum before exponentiating.
    linpred = X @ beta
    linpred = linpred - cp.max(linpred)
    e_linpred = cp.exp(linpred)

    def _moments(index_list):
        # Rows plus zeroth/first/second weighted moments of an index group.
        rows = cp.asarray(index_list, dtype=cp.int32)
        feats = X[rows]
        weights = e_linpred[rows]
        weighted = feats * weights[:, None]
        return feats, cp.sum(weights), cp.sum(weighted, axis=0), weighted.T @ feats

    grad = cp.zeros(n_features, dtype=cp.float64)
    hess_inner = cp.zeros((n_features, n_features), dtype=cp.float64)
    xp0 = cp.zeros((), dtype=cp.float64)
    xp1 = cp.zeros(n_features, dtype=cp.float64)
    xp2 = cp.zeros((n_features, n_features), dtype=cp.float64)
    # Tie-fraction vectors k/m, reused across groups of equal size m.
    j_cache = {}

    for i in reversed(range(nuft)):
        entering = risk_enter[i]
        if len(entering) > 0:
            _, add0, add1, add2 = _moments(entering)
            xp0 = xp0 + add0
            xp1 = xp1 + add1
            xp2 = xp2 + add2

        failing = uft_ix[i]
        if len(failing) > 0:
            v, xp0f, xp1f, xp2f = _moments(failing)
            m = len(failing)
            J = j_cache.get(m)
            if J is None:
                J = cp.arange(m, dtype=cp.float64) / float(max(m, 1))
                j_cache[m] = J
            c0 = cp.maximum(xp0 - J * xp0f, 1e-300)
            ak = 1.0 / c0
            bk = J * ak
            sum_inv_c0 = cp.sum(ak)
            sum_J_c0 = cp.sum(bk)
            sum_aa = cp.sum(ak * ak)
            sum_bb = cp.sum(bk * bk)
            sum_ab = cp.sum(ak * bk)
            grad = grad + cp.sum(v, axis=0)
            grad = grad - (xp1 * sum_inv_c0 - xp1f * sum_J_c0)
            hess_inner = hess_inner + xp2 * sum_inv_c0
            hess_inner = hess_inner - xp2f * sum_J_c0
            hess_inner = hess_inner - (
                sum_aa * cp.outer(xp1, xp1)
                + sum_bb * cp.outer(xp1f, xp1f)
                - sum_ab * (cp.outer(xp1, xp1f) + cp.outer(xp1f, xp1))
            )

        leaving = risk_exit[i]
        if len(leaving) > 0:
            _, sub0, sub1, sub2 = _moments(leaving)
            xp0 = xp0 - sub0
            xp1 = xp1 - sub1
            xp2 = xp2 - sub2

    return grad, -hess_inner
2752
+
2753
+ def _solve_newton_delta_torch(self, hess, grad):
2754
+ """Newton step delta = inv(hess) @ grad; prefer SPD solve on (-hess) with light jitter."""
2755
+ import torch
2756
+
2757
+ p = int(hess.shape[0])
2758
+ try:
2759
+ H = -hess
2760
+ eps = 1e-11 * (torch.max(torch.abs(torch.diag(H))) + 1.0)
2761
+ H = H + eps * torch.eye(p, dtype=torch.float64, device=hess.device)
2762
+ return -torch.linalg.solve(H, grad)
2763
+ except Exception:
2764
+ try:
2765
+ return torch.linalg.solve(hess, grad)
2766
+ except Exception:
2767
+ result = torch.linalg.lstsq(hess, grad)
2768
+ return result.solution.flatten()
2769
+
2770
def _compute_gradient_hessian_efron_grouped_gemm_torch(self, beta, X, efron_pre):
    """Exact Efron grad/hess on Torch device via grouped GEMM updates.

    Unique failure times are visited from last to first. Rows listed in
    ``risk_enter[i]`` are added to the running risk-set moments before the
    failures in ``uft_ix[i]`` are scored, and rows in ``risk_exit[i]`` are
    removed afterwards. All tensors are created on ``beta.device``.

    Parameters
    ----------
    beta, X
        Coefficients and design matrix (Torch tensors).
    efron_pre
        Precomputed grouping, unpacked by ``_unpack_efron_pre6``.

    Returns
    -------
    ``(grad, hess)`` where ``hess`` is the negated accumulated matrix.
    """
    import torch

    _, uft_ix, risk_enter, risk_exit, nuft, _ = _unpack_efron_pre6(efron_pre)
    n_features = int(X.shape[1])
    device = beta.device
    # Shift by the maximum before exponentiating.
    linpred = X @ beta
    linpred = linpred - torch.max(linpred)
    e_linpred = torch.exp(linpred)

    def _moments(index_list):
        # Rows plus zeroth/first/second weighted moments of an index group.
        rows = torch.as_tensor(index_list, dtype=torch.long, device=device)
        feats = X[rows]
        weights = e_linpred[rows]
        weighted = feats * weights[:, None]
        return (
            feats,
            torch.sum(weights),
            torch.sum(weighted, dim=0),
            weighted.transpose(0, 1) @ feats,
        )

    grad = torch.zeros(n_features, dtype=torch.float64, device=device)
    hess_inner = torch.zeros((n_features, n_features), dtype=torch.float64, device=device)
    xp0 = torch.zeros((), dtype=torch.float64, device=device)
    xp1 = torch.zeros(n_features, dtype=torch.float64, device=device)
    xp2 = torch.zeros((n_features, n_features), dtype=torch.float64, device=device)
    # Tie-fraction vectors k/m, reused across groups of equal size m.
    j_cache = {}

    for i in reversed(range(nuft)):
        entering = risk_enter[i]
        if len(entering) > 0:
            _, add0, add1, add2 = _moments(entering)
            xp0 = xp0 + add0
            xp1 = xp1 + add1
            xp2 = xp2 + add2

        failing = uft_ix[i]
        if len(failing) > 0:
            v, xp0f, xp1f, xp2f = _moments(failing)
            m = len(failing)
            J = j_cache.get(m)
            if J is None:
                J = torch.arange(m, dtype=torch.float64, device=device) / float(max(m, 1))
                j_cache[m] = J
            c0 = torch.clamp(xp0 - J * xp0f, min=1e-300)
            ak = 1.0 / c0
            bk = J * ak
            sum_inv_c0 = torch.sum(ak)
            sum_J_c0 = torch.sum(bk)
            sum_aa = torch.sum(ak * ak)
            sum_bb = torch.sum(bk * bk)
            sum_ab = torch.sum(ak * bk)
            grad = grad + torch.sum(v, dim=0)
            grad = grad - (xp1 * sum_inv_c0 - xp1f * sum_J_c0)
            hess_inner = hess_inner + xp2 * sum_inv_c0
            hess_inner = hess_inner - xp2f * sum_J_c0
            hess_inner = hess_inner - (
                sum_aa * torch.outer(xp1, xp1)
                + sum_bb * torch.outer(xp1f, xp1f)
                - sum_ab * (torch.outer(xp1, xp1f) + torch.outer(xp1f, xp1))
            )

        leaving = risk_exit[i]
        if len(leaving) > 0:
            _, sub0, sub1, sub2 = _moments(leaving)
            xp0 = xp0 - sub0
            xp1 = xp1 - sub1
            xp2 = xp2 - sub2

    return grad, -hess_inner
2841
+
2842
+ def _compute_log_likelihood_torch(self, beta, X, time, event, efron_pre=None, entry=None, entry_ctx=None):
2843
+ """Compute log partial likelihood on Torch."""
2844
+ import torch
2845
+
2846
+ eta = X @ beta
2847
+ exp_eta = torch.exp(eta)
2848
+ # Entry+breslow path does not consume risk_sum; skip the cumsum to
2849
+ # reduce per-evaluation overhead during line-search probes.
2850
+ risk_sum = None if entry is not None else torch.cumsum(exp_eta.flip(0), dim=0).flip(0)
2851
+ return self._compute_log_likelihood_torch_from_stats(
2852
+ eta, exp_eta, risk_sum, time, event, efron_pre, entry=entry, entry_ctx=entry_ctx
2853
+ )
2854
+
2855
+ def _build_entry_ctx_torch(self, time, event, entry, device):
2856
+ """Build entry-time grouped indexing context for a specific sorted Torch view."""
2857
+ import torch
2858
+
2859
+ event_mask = event == 1
2860
+ event_idx = torch.where(event_mask)[0]
2861
+ evt_t = time[event_idx].detach().cpu().numpy()
2862
+ if evt_t.size == 0:
2863
+ return (
2864
+ torch.zeros((0,), dtype=torch.long, device=device),
2865
+ np.zeros((0,), dtype=np.float64),
2866
+ np.zeros((0,), dtype=np.int64),
2867
+ np.zeros((0,), dtype=np.int64),
2868
+ torch.zeros((0,), dtype=torch.long, device=device),
2869
+ torch.zeros((0,), dtype=torch.long, device=device),
2870
+ np.zeros((1,), dtype=np.int64),
2871
+ )
2872
+ uft_np, d_counts = np.unique(evt_t, return_counts=True)
2873
+ d_counts = d_counts.astype(np.float64, copy=False)
2874
+ entry_order = torch.argsort(entry, stable=True)
2875
+ entry_sorted_np = entry.index_select(0, entry_order).detach().cpu().numpy()
2876
+ time_np = time.detach().cpu().numpy()
2877
+ add_end_np = np.searchsorted(entry_sorted_np, uft_np, side="left").astype(np.int64, copy=False)
2878
+ rem_end_np = np.searchsorted(time_np, uft_np, side="left").astype(np.int64, copy=False)
2879
+ rem_order = torch.arange(int(time.shape[0]), dtype=torch.long, device=device)
2880
+ event_idx = event_idx.to(torch.long)
2881
+ fail_ptr = np.empty(d_counts.shape[0] + 1, dtype=np.int64)
2882
+ fail_ptr[0] = 0
2883
+ fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)
2884
+ return (entry_order, d_counts, add_end_np, rem_end_np, rem_order, event_idx, fail_ptr)
2885
+
2886
+ def _compute_log_likelihood_torch_from_stats(
2887
+ self, eta, exp_eta, risk_sum, time, event, efron_pre=None, entry=None, entry_ctx=None
2888
+ ):
2889
+ """Compute log partial likelihood on Torch with precomputed stats."""
2890
+ import torch
2891
+
2892
+ ll = torch.tensor(0.0, dtype=torch.float64, device=eta.device)
2893
+ event_mask = event == 1
2894
+
2895
+ if not torch.any(event_mask):
2896
+ return ll
2897
+
2898
+ if entry is not None:
2899
+ if entry_ctx is None:
2900
+ entry_order, d_counts, add_end_np, rem_end_np, _rem_order, event_idx, fail_ptr = self._build_entry_ctx_torch(
2901
+ time, event, entry, eta.device
2902
+ )
2903
+ else:
2904
+ entry_order, d_counts, add_end_np, rem_end_np = entry_ctx[:4]
2905
+ event_idx = entry_ctx[6] if len(entry_ctx) > 6 else torch.where(event_mask)[0]
2906
+ fail_ptr = entry_ctx[8] if len(entry_ctx) > 8 else None
2907
+
2908
+ n_groups = int(d_counts.shape[0])
2909
+ if n_groups == 0:
2910
+ return torch.tensor(0.0, dtype=torch.float64, device=eta.device)
2911
+ if fail_ptr is None:
2912
+ fail_ptr = np.empty(n_groups + 1, dtype=np.int64)
2913
+ fail_ptr[0] = 0
2914
+ fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)
2915
+
2916
+ exp_entry = exp_eta.index_select(0, entry_order)
2917
+ exp_rem = exp_eta
2918
+ s0_add_pref = torch.cumsum(exp_entry, dim=0)
2919
+ s0_rem_pref = torch.cumsum(exp_rem, dim=0)
2920
+ s0_add = torch.zeros(n_groups, dtype=torch.float64, device=eta.device)
2921
+ s0_rem = torch.zeros(n_groups, dtype=torch.float64, device=eta.device)
2922
+ mask_add = add_end_np > 0
2923
+ mask_rem = rem_end_np > 0
2924
+ if np.any(mask_add):
2925
+ idx_add = torch.as_tensor(add_end_np[mask_add] - 1, dtype=torch.long, device=eta.device)
2926
+ s0_add[torch.as_tensor(mask_add, dtype=torch.bool, device=eta.device)] = s0_add_pref.index_select(0, idx_add)
2927
+ if np.any(mask_rem):
2928
+ idx_rem = torch.as_tensor(rem_end_np[mask_rem] - 1, dtype=torch.long, device=eta.device)
2929
+ s0_rem[torch.as_tensor(mask_rem, dtype=torch.bool, device=eta.device)] = s0_rem_pref.index_select(0, idx_rem)
2930
+ s0_vec = torch.clamp(s0_add - s0_rem, min=1e-300)
2931
+ event_eta = eta.index_select(0, event_idx)
2932
+
2933
+ if self.ties == "breslow":
2934
+ d_vec = torch.as_tensor(d_counts, dtype=torch.float64, device=eta.device)
2935
+ return torch.sum(event_eta) - torch.sum(d_vec * torch.log(s0_vec))
2936
+
2937
+ ll = torch.sum(event_eta)
2938
+ event_exp = exp_eta.index_select(0, event_idx)
2939
+ for g in range(n_groups):
2940
+ d = int(d_counts[g])
2941
+ if d <= 0:
2942
+ continue
2943
+ st = int(fail_ptr[g])
2944
+ ed = int(fail_ptr[g + 1])
2945
+ ef = torch.sum(event_exp[st:ed])
2946
+ base = s0_vec[g]
2947
+ for k in range(d):
2948
+ denom = torch.clamp(base - (float(k) / float(d)) * ef, min=1e-300)
2949
+ ll = ll - torch.log(denom)
2950
+ return ll
2951
+
2952
+ if self.ties == "breslow":
2953
+ # Vectorized Breslow using cached failure groups
2954
+ breslow_pre_torch = getattr(self, "_breslow_pre_torch", None)
2955
+ if (
2956
+ breslow_pre_torch is not None
2957
+ and len(breslow_pre_torch) == 2
2958
+ and int(breslow_pre_torch[0].numel()) > 0
2959
+ ):
2960
+ first_idx_uft, counts_uft = breslow_pre_torch
2961
+ else:
2962
+ uft, counts_uft = torch.unique(time[event_mask], return_counts=True)
2963
+ first_idx_uft = torch.searchsorted(time, uft, side="left")
2964
+ counts_uft = counts_uft.to(torch.int32)
2965
+ risk_at = risk_sum[first_idx_uft]
2966
+ return torch.sum(eta[event_mask]) - torch.sum(
2967
+ counts_uft.to(torch.float64) * torch.log(risk_at)
2968
+ )
2969
+
2970
+ # Efron: keep computation fully on torch backend.
2971
+ if efron_pre is not None:
2972
+ needs_exact_ties = not getattr(self, "_efron_all_singletons", False)
2973
+ # No-tie Efron equals Breslow; keep computation on torch device.
2974
+ if not needs_exact_ties:
2975
+ _, _, _, _, nuft, first_idx_uft = _unpack_efron_pre6(efron_pre)
2976
+ first_idx_t = torch.as_tensor(first_idx_uft, dtype=torch.int64, device=eta.device)
2977
+ counts_t = torch.ones(int(nuft), dtype=torch.float64, device=eta.device)
2978
+ risk_at = risk_sum[first_idx_t]
2979
+ return torch.sum(eta[event_mask]) - torch.sum(counts_t * torch.log(risk_at))
2980
+
2981
+ # Fallback Efron (loop version)
2982
+ unique_times = torch.unique(time[event_mask])
2983
+ for t in unique_times:
2984
+ at_time_t = time == t
2985
+ events_at_t = at_time_t & event_mask
2986
+ d = int(torch.sum(events_at_t).item())
2987
+
2988
+ if d == 0:
2989
+ continue
2990
+
2991
+ risk_indices = torch.where(time >= t)[0]
2992
+ if risk_indices.numel() == 0:
2993
+ continue
2994
+
2995
+ first_idx = risk_indices[0]
2996
+ risk_at_t = risk_sum[first_idx]
2997
+ sum_events = torch.sum(exp_eta[events_at_t])
2998
+
2999
+ ll += torch.sum(eta[events_at_t])
3000
+ for k in range(d):
3001
+ ll -= torch.log(torch.maximum(risk_at_t - (k / d) * sum_events, torch.tensor(1e-300, dtype=torch.float64, device=eta.device)))
3002
+
3003
+ return ll
3004
+
3005
def _compute_gradient_hessian_torch(
    self, beta, X, time, event, efron_pre=None, return_aux=False, entry=None, entry_ctx=None
):
    """Gradient and Hessian of the Cox partial log-likelihood on Torch tensors.

    Several code paths are tried in order:

    1. Efron fast paths (grouped GEMM, optional Triton kernel) when
       ``self.ties == "efron"``, ``efron_pre`` is given and ``entry`` is None.
    2. A delayed-entry path (``entry is not None``) built on prefix sums over
       entry-ordered and removal-ordered rows, with a Python loop over groups.
    3. The generic path using reverse cumulative sums, with an optional
       Triton kernel for Breslow ties and a Python loop over unique failure
       times for the Hessian.

    Parameters
    ----------
    beta : torch.Tensor
        Coefficient vector; its device is used for every new tensor.
    X : torch.Tensor
        Covariate matrix of shape (n_samples, n_features).
        NOTE(review): the reverse cumulative sums below only represent
        risk-set totals if rows are sorted by ascending time - confirm
        against the caller.
    time, event : torch.Tensor
        Observation times and event indicators (``event == 1`` marks a failure).
    efron_pre : optional
        Precomputed Efron structure, consumed by ``_unpack_efron_pre6`` and
        the grouped-GEMM / Triton helpers; its layout is not visible here.
    return_aux : bool
        When true, a third element ``(eta, exp_eta, risk_sum)`` is appended to
        the return value (``risk_sum`` is None on the delayed-entry path).
    entry : optional
        Delayed-entry times; any non-None value selects the entry path.
    entry_ctx : optional
        Cached context for the entry path. The first four items are unpacked
        as ``entry_order, d_counts, add_end_np, rem_end_np``; items 4-8, when
        present, supply ``X_entry, X_rem, event_idx, grad, fail_ptr``.

    Returns
    -------
    (grad, hess) or (grad, hess, aux)
        Gradient of shape (n_features,) and Hessian of shape
        (n_features, n_features).
    """
    import torch
    n_samples, n_features = X.shape
    eta = X @ beta
    exp_eta = torch.exp(eta)
    # Indices n-1, ..., 0: used to turn a forward cumsum into a suffix sum.
    rev_idx = torch.arange(n_samples - 1, -1, -1, device=beta.device)
    # Suffix sums of exp(eta); skipped on the delayed-entry path.
    risk_sum = torch.cumsum(exp_eta[rev_idx], dim=0)[rev_idx] if entry is None else None

    if self.ties == "efron" and efron_pre is not None and entry is None:
        # The fitted object may record that no failure time is tied.
        needs_exact_ties = not getattr(self, "_efron_all_singletons", False)
        n_samples = int(X.shape[0])
        # Ratio of sample count to element 4 of the unpacked efron_pre
        # (presumably the number of tie groups - TODO confirm).
        avg_tie = float(n_samples) / max(1.0, float(_unpack_efron_pre6(efron_pre)[4]))
        # Grouped GEMM is on by default; the environment variable disables it.
        use_grouped_gemm = (
            os.environ.get("STATGPU_EFRON_GROUPED_GEMM", "1").strip().lower()
            in ("1", "true", "yes", "on")
        )
        # For real ties, use exact torch grouped GEMM path only.
        if needs_exact_ties and (
            use_grouped_gemm
            and beta.is_cuda
            and n_features <= 192
            and avg_tie >= 24.0
        ):
            out = self._compute_gradient_hessian_efron_grouped_gemm_torch(
                beta, X, efron_pre
            )
            if return_aux:
                return out[0], out[1], (eta, exp_eta, risk_sum)
            return out

        # ---- Triton Efron path ----
        # Opt-in via STATGPU_EFRON_TRITON; a None result falls through to
        # the generic implementation below.
        if (
            os.environ.get("STATGPU_EFRON_TRITON", "0").strip().lower()
            in ("1", "true", "yes", "on")
            and beta.is_cuda
            and efron_pre is not None
        ):
            from statgpu.survival._cox_efron_triton import compute_efron_grad_hess_triton
            triton_out = compute_efron_grad_hess_triton(X, beta, efron_pre)
            if triton_out is not None:
                grad, hess = triton_out
                if return_aux:
                    return grad, hess, (eta, exp_eta, risk_sum)
                return grad, hess

    # Reverse cumsum for risk sets (vectorized)
    risk_X_sum = torch.cumsum((X * exp_eta[:, None])[rev_idx], dim=0)[rev_idx] if entry is None else None

    event_mask = event == 1
    # Without any failure the gradient and Hessian are returned as zeros.
    if not torch.any(event_mask):
        out = (
            torch.zeros(n_features, dtype=torch.float64, device=beta.device),
            torch.zeros((n_features, n_features), dtype=torch.float64, device=beta.device),
        )
        if return_aux:
            return out[0], out[1], (eta, exp_eta, risk_sum)
        return out

    if entry is not None:
        # ---- Delayed-entry path ----
        if entry_ctx is None:
            entry_order, d_counts, add_end_np, rem_end_np, rem_order, event_idx, fail_ptr = self._build_entry_ctx_torch(
                time, event, entry, beta.device
            )
            X_entry = X.index_select(0, entry_order).contiguous()
            X_rem = X.index_select(0, rem_order).contiguous()
            # The gradient starts as the sum of covariates over event rows.
            grad = torch.sum(X.index_select(0, event_idx), dim=0)
        else:
            # Reuse cached pieces; anything missing is rebuilt from X.
            entry_order, d_counts, add_end_np, rem_end_np = entry_ctx[:4]
            X_entry = entry_ctx[4] if len(entry_ctx) > 4 else X.index_select(0, entry_order)
            X_rem = entry_ctx[5] if len(entry_ctx) > 5 else X
            event_idx = entry_ctx[6] if len(entry_ctx) > 6 else torch.where(event_mask)[0]
            grad = entry_ctx[7] if len(entry_ctx) > 7 else torch.sum(X[event_mask], dim=0)
            fail_ptr = entry_ctx[8] if len(entry_ctx) > 8 else None
        hess = torch.zeros((n_features, n_features), dtype=torch.float64, device=beta.device)
        exp_entry = exp_eta.index_select(0, entry_order)
        # NOTE(review): exp_eta is used unpermuted for the removal side, so
        # X_rem is presumably in the same row order as X - confirm.
        exp_rem = exp_eta
        wx_entry = X_entry * exp_entry.unsqueeze(1)
        wx_rem = X_rem * exp_rem.unsqueeze(1)
        n_groups = int(d_counts.shape[0])
        if n_groups == 0:
            if return_aux:
                return grad, hess, (eta, exp_eta, risk_sum)
            return grad, hess
        # Prefix sums of weights (s0) and weighted covariates (s1) for rows
        # added to / removed from the risk set.
        s0_add_pref = torch.cumsum(exp_entry, dim=0)
        s0_rem_pref = torch.cumsum(exp_rem, dim=0)
        s1_add_pref = torch.cumsum(wx_entry, dim=0)
        s1_rem_pref = torch.cumsum(wx_rem, dim=0)
        s0_add = torch.zeros(n_groups, dtype=torch.float64, device=beta.device)
        s0_rem = torch.zeros(n_groups, dtype=torch.float64, device=beta.device)
        s1_add = torch.zeros((n_groups, n_features), dtype=torch.float64, device=beta.device)
        s1_rem = torch.zeros((n_groups, n_features), dtype=torch.float64, device=beta.device)
        # An end offset of 0 means nothing has been added/removed yet, so the
        # corresponding per-group sum stays zero.
        mask_add = add_end_np > 0
        mask_rem = rem_end_np > 0
        if np.any(mask_add):
            idx_add = torch.as_tensor(add_end_np[mask_add] - 1, dtype=torch.long, device=beta.device)
            mask_add_t = torch.as_tensor(mask_add, dtype=torch.bool, device=beta.device)
            s0_add[mask_add_t] = s0_add_pref.index_select(0, idx_add)
            s1_add[mask_add_t] = s1_add_pref.index_select(0, idx_add)
        if np.any(mask_rem):
            idx_rem = torch.as_tensor(rem_end_np[mask_rem] - 1, dtype=torch.long, device=beta.device)
            mask_rem_t = torch.as_tensor(mask_rem, dtype=torch.bool, device=beta.device)
            s0_rem[mask_rem_t] = s0_rem_pref.index_select(0, idx_rem)
            s1_rem[mask_rem_t] = s1_rem_pref.index_select(0, idx_rem)
        # Per-group risk-set sums = (added so far) - (removed so far).
        s0_vec = s0_add - s0_rem
        s1_vec = s1_add - s1_rem
        d_vec = torch.as_tensor(d_counts, dtype=torch.float64, device=beta.device)
        # Clamp the denominator away from zero before dividing.
        s0_safe_vec = torch.clamp(s0_vec, min=1e-15)
        use_efron_entry = (self.ties == "efron")
        ex_vec = s1_vec / s0_safe_vec.unsqueeze(1)
        if not use_efron_entry:
            # Breslow gradient: subtract d * E[X | risk set] for every group.
            grad = grad - torch.sum(d_vec.unsqueeze(1) * ex_vec, dim=0)
        if use_efron_entry:
            if fail_ptr is None:
                # Offsets of each group's events within the event-row list.
                fail_ptr = np.empty(n_groups + 1, dtype=np.int64)
                fail_ptr[0] = 0
                fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)
            event_exp = exp_eta.index_select(0, event_idx)
            X_fail = X.index_select(0, event_idx)
        add_ptr = 0
        rem_ptr = 0
        # Running weighted second moment X^T diag(w) X of the current risk set.
        s2 = torch.zeros((n_features, n_features), dtype=torch.float64, device=beta.device)
        s2_block_size = int(os.environ.get("STATGPU_ENTRY_S2_BLOCK_SIZE", "8192"))
        if s2_block_size <= 0:
            # A non-positive setting effectively disables blocking.
            s2_block_size = 10**18
        s2_fn = self._get_entry_s2_torch_fn()
        for g in range(n_groups):
            # Fold in rows that entered the risk set since the previous group.
            add_end = int(add_end_np[g])
            if add_end > add_ptr:
                x_add = X_entry[add_ptr:add_end]
                w_add = exp_entry[add_ptr:add_end]
                n_add = int(add_end - add_ptr)
                if n_add <= s2_block_size:
                    s2 = s2 + s2_fn(x_add, w_add)
                else:
                    s2 = self._s2_weighted_update_torch_blocked(
                        s2, x_add, w_add, s2_block_size, sign=1.0
                    )
                add_ptr = add_end

            # Remove rows that left the risk set since the previous group.
            rem_end = int(rem_end_np[g])
            if rem_end > rem_ptr:
                x_rem = X_rem[rem_ptr:rem_end]
                w_rem = exp_eta[rem_ptr:rem_end]
                n_rem = int(rem_end - rem_ptr)
                if n_rem <= s2_block_size:
                    s2 = s2 - s2_fn(x_rem, w_rem)
                else:
                    s2 = self._s2_weighted_update_torch_blocked(
                        s2, x_rem, w_rem, s2_block_size, sign=-1.0
                    )
                rem_ptr = rem_end

            d_t_f = float(d_counts[g])
            if d_t_f <= 0:
                continue
            if use_efron_entry:
                # Sums over the events tied in this group.
                st = int(fail_ptr[g])
                ed = int(fail_ptr[g + 1])
                ef = event_exp[st:ed]
                xf = X_fail[st:ed]
                ef_sum = torch.sum(ef)
                ef_x_sum = torch.sum(xf * ef.unsqueeze(1), dim=0)
                ef_x2_sum = xf.transpose(0, 1) @ (xf * ef.unsqueeze(1))
                s0_g = torch.clamp(s0_vec[g], min=1e-15)
                s1_g = s1_vec[g]
                d_i = int(d_t_f)
                # Efron correction: the k-th of d tied events sees the risk
                # set with a fraction k/d of the tied events' weight removed.
                for k in range(d_i):
                    frac = float(k) / float(d_i)
                    denom = torch.clamp(s0_g - frac * ef_sum, min=1e-15)
                    s1_k = s1_g - frac * ef_x_sum
                    s2_k = s2 - frac * ef_x2_sum
                    ex_k = s1_k / denom
                    grad = grad - ex_k
                    hess = hess - (s2_k / denom)
                    hess = hess + torch.outer(ex_k, ex_k)
            else:
                s0_safe = s0_safe_vec[g]
                hess = hess - (d_t_f / s0_safe) * s2
        if not use_efron_entry:
            # Breslow outer-product term, added once for all groups.
            hess = hess + ex_vec.transpose(0, 1) @ (d_vec.unsqueeze(1) * ex_vec)
        if return_aux:
            return grad, hess, (eta, exp_eta, risk_sum)
        return grad, hess

    # Get event data
    event_times = time[event_mask]

    # Unique failure times with inverse mapping
    uft, unique_inv = torch.unique(event_times, sorted=True, return_inverse=True)
    n_uft = len(uft)
    # Number of events at each unique failure time.
    counts = torch.bincount(unique_inv).to(torch.float64)

    # Get first index of each unique time
    sorted_times, sort_idx = torch.sort(time)
    first_in_sorted = torch.searchsorted(sorted_times, uft, side="left")
    first_idx = sort_idx[first_in_sorted]

    # Risk values at unique times
    risk_at_uft = risk_sum[first_idx]
    risk_X_at_uft = risk_X_sum[first_idx]
    E_X_at_uft = risk_X_at_uft / risk_at_uft[:, None]

    # Sum X and exp(eta) for events at each unique time
    event_indices = event_mask.nonzero(as_tuple=True)[0]
    sum_X_per_uft = torch.zeros((n_uft, n_features), dtype=torch.float64, device=beta.device)
    sum_X_per_uft.index_add_(0, unique_inv, X[event_indices])

    # ============= GRADIENT =============
    if self.ties == "efron":
        # Efron closed-form: (d+1)/2 * E[X|R]
        # NOTE(review): (d+1)/2 equals the tie count d only when d == 1, so
        # for tied groups this weight differs from the d-term Efron sum used
        # on the delayed-entry path above - confirm this is intended.
        efron_weight = (counts + 1) / 2.0
        grad = torch.sum(sum_X_per_uft - efron_weight[:, None] * E_X_at_uft, dim=0)
    else:
        # Breslow: d * E[X|R]
        grad = torch.sum(sum_X_per_uft - counts[:, None] * E_X_at_uft, dim=0)

    # Hessian
    # Use incremental risk-set second moments to avoid materializing
    # a (n_samples, n_features, n_features) tensor on GPU (can OOM at 50k x 100).
    X_exp = X * exp_eta[:, None]
    # Second moment over all rows; shrunk inside the loop below.
    risk_X2 = X_exp.transpose(0, 1) @ X

    # Weight by counts (Breslow) or Efron-adjusted weights
    if self.ties == "efron":
        weights = efron_weight
    else:
        weights = counts

    # ---- Triton Breslow path ----
    # Opt-in via STATGPU_BRESLOW_TRITON; replaces the gradient computed above
    # and skips the loop below when the kernel returns a result.
    if (
        self.ties != "efron"
        and os.environ.get("STATGPU_BRESLOW_TRITON", "0").strip().lower()
        in ("1", "true", "yes", "on")
        and beta.is_cuda
    ):
        from statgpu.survival._cox_efron_triton import compute_breslow_grad_hess_triton
        triton_out = compute_breslow_grad_hess_triton(X, beta, time, event)
        if triton_out is not None:
            grad, hess = triton_out
            if return_aux:
                return grad, hess, (eta, exp_eta, risk_sum)
            return grad, hess

    hess = torch.zeros((n_features, n_features), dtype=torch.float64, device=beta.device)
    prev_idx = 0
    for g in range(n_uft):
        # Drop rows located before this failure time's first row from the
        # running second moment.
        idx = int(first_idx[g].item())
        if idx > prev_idx:
            blk = slice(prev_idx, idx)
            risk_X2 = risk_X2 - (X_exp[blk].transpose(0, 1) @ X[blk])
            prev_idx = idx

        rs = torch.clamp(risk_at_uft[g], min=1e-300)
        w = weights[g]
        ex = E_X_at_uft[g]
        # Torch-native p^2 update: avoid CuPy bridge and host sync.
        hess.sub_(risk_X2 * (w / rs))
        hess.add_(torch.outer(ex, ex) * w)

    if return_aux:
        return grad, hess, (eta, exp_eta, risk_sum)
    return grad, hess
3270
+
3271
+ def _s2_weighted_update_torch_blocked(self, s2, x, w, block_size, sign=1.0):
3272
+ """Blocked update for large slices: s2 += sign * X^T (X * w)."""
3273
+ s2_fn = self._get_entry_s2_torch_fn()
3274
+
3275
+ n = int(x.shape[0])
3276
+ if n <= 0:
3277
+ return s2
3278
+ for st in range(0, n, block_size):
3279
+ ed = min(st + block_size, n)
3280
+ xb = x[st:ed]
3281
+ wb = w[st:ed]
3282
+ s2 = s2 + sign * s2_fn(xb, wb)
3283
+ return s2
3284
+
3285
+ def _get_entry_s2_torch_fn(self):
3286
+ """Build/cache torch or torch.compile function for weighted X^T X."""
3287
+ fn = getattr(self, "_entry_s2_torch_fn", None)
3288
+ if fn is not None:
3289
+ return fn
3290
+ import torch
3291
+
3292
+ def _s2_core(x, w):
3293
+ return x.transpose(0, 1) @ (x * w.unsqueeze(1))
3294
+
3295
+ use_compile = (
3296
+ os.environ.get("STATGPU_ENTRY_S2_COMPILE_TORCH", "0").strip().lower()
3297
+ in ("1", "true", "yes", "on")
3298
+ )
3299
+ if use_compile and hasattr(torch, "compile"):
3300
+ mode = os.environ.get("STATGPU_ENTRY_S2_COMPILE_MODE", "default")
3301
+ try:
3302
+ fn = torch.compile(_s2_core, dynamic=True, fullgraph=False, mode=mode)
3303
+ except Exception:
3304
+ fn = _s2_core
3305
+ else:
3306
+ fn = _s2_core
3307
+ self._entry_s2_torch_fn = fn
3308
+ return fn
3309
+
3310
+ def _compute_cindex_torch(self, X, time, event, beta):
3311
+ """Compute concordance index (C-index) on Torch."""
3312
+ import torch
3313
+
3314
+ # Linear predictor (risk score)
3315
+ risk_score = X @ beta
3316
+
3317
+ n = len(time)
3318
+ event_mask = (event == 1)
3319
+
3320
+ if torch.sum(event_mask) == 0:
3321
+ return torch.tensor(0.5, dtype=torch.float64, device=beta.device)
3322
+
3323
+ # Use chunked vectorized approach for memory efficiency
3324
+ event_idx = torch.where(event_mask)[0]
3325
+ n_events = len(event_idx)
3326
+
3327
+ if n_events == 0:
3328
+ return torch.tensor(float("nan"), dtype=torch.float64, device=beta.device)
3329
+
3330
+ concordant = torch.tensor(0, dtype=torch.int64, device=beta.device)
3331
+ permissible = torch.tensor(0, dtype=torch.int64, device=beta.device)
3332
+ tied_risk = torch.tensor(0, dtype=torch.int64, device=beta.device)
3333
+
3334
+ # Chunk size for memory efficiency (~128 MB per batch matrix)
3335
+ chunk_size = max(1, min(n_events, int(128e6 / max(n, 1))))
3336
+
3337
+ for start in range(0, n_events, chunk_size):
3338
+ end = min(start + chunk_size, n_events)
3339
+ idx_chunk = event_idx[start:end]
3340
+
3341
+ time_i = time[idx_chunk][:, None]
3342
+ risk_i = risk_score[idx_chunk][:, None]
3343
+ time_j = time[None, :]
3344
+ risk_j = risk_score[None, :]
3345
+ event_j = event[None, :]
3346
+
3347
+ # Permissible pairs: earlier time OR same time with j censored
3348
+ perm = (time_i < time_j) | ((time_i == time_j) & (event_j == 0))
3349
+ # Exclude self-comparisons
3350
+ chunk_indices = torch.arange(end - start, device=beta.device)
3351
+ perm[chunk_indices, idx_chunk] = False
3352
+
3353
+ concordant += torch.sum(perm & (risk_i > risk_j))
3354
+ tied_risk += torch.sum(perm & (risk_i == risk_j))
3355
+ permissible += torch.sum(perm)
3356
+
3357
+ if permissible > 0:
3358
+ return (concordant.to(torch.float64) + 0.5 * tied_risk.to(torch.float64)) / permissible.to(torch.float64)
3359
+ else:
3360
+ return torch.tensor(float("nan"), dtype=torch.float64, device=beta.device)
3361
+
3362
+ def _compute_inference_cpu(self, X, time, event, cluster=None):
3363
+ """Compute standard errors, z-values, p-values, and confidence intervals."""
3364
+ n_features = X.shape[1]
3365
+
3366
+ # Keep inference self-contained (no nested external model fitting),
3367
+ # so runtime reflects this implementation directly.
3368
+
3369
+ # Compute information matrix (negative Hessian at MLE)
3370
+ _, hess = self._compute_gradient_hessian(
3371
+ self.coef_, X, time, event, getattr(self, "_efron_pre", None), entry=getattr(self, "_entry", None)
3372
+ )
3373
+
3374
+ # Bread matrix from observed information.
3375
+ try:
3376
+ bread = np.linalg.solve(-hess, np.eye(n_features))
3377
+ except np.linalg.LinAlgError:
3378
+ bread = np.linalg.pinv(-hess)
3379
+
3380
+ if self.cov_type == "nonrobust":
3381
+ self._var_matrix = bread
3382
+ elif self.cov_type == "cluster":
3383
+ if cluster is None:
3384
+ raise ValueError("cov_type='cluster' requires cluster ids in fit(..., cluster=...)")
3385
+ cluster = np.asarray(cluster)
3386
+ score_resid = self._compute_robust_score_residuals(X, time, event)
3387
+ uniq = np.unique(cluster)
3388
+ meat = np.zeros((n_features, n_features), dtype=np.float64)
3389
+ for g in uniq:
3390
+ u_g = np.sum(score_resid[cluster == g], axis=0)
3391
+ meat += np.outer(u_g, u_g)
3392
+ self._var_matrix = bread @ meat @ bread
3393
+ else:
3394
+ score_resid = self._compute_robust_score_residuals(X, time, event)
3395
+ meat = score_resid.T @ score_resid
3396
+ self._var_matrix = bread @ meat @ bread
3397
+ if self.cov_type == "hc1":
3398
+ n = X.shape[0]
3399
+ k = X.shape[1]
3400
+ if n > k:
3401
+ self._var_matrix = self._var_matrix * (n / (n - k))
3402
+
3403
+ # Standard errors
3404
+ self._bse = np.sqrt(np.maximum(np.diag(self._var_matrix), 0.0))
3405
+
3406
+ # z-values (add epsilon to avoid division by zero)
3407
+ self._zvalues = self.coef_ / (self._bse + 1e-30)
3408
+
3409
+ # p-values (two-sided)
3410
+ self._pvalues = 2 * (1 - stats.norm.cdf(np.abs(self._zvalues)))
3411
+
3412
+ # 95% confidence intervals
3413
+ alpha = 0.05
3414
+ z_crit = stats.norm.ppf(1 - alpha / 2)
3415
+ self._conf_int = np.column_stack([
3416
+ self.coef_ - z_crit * self._bse,
3417
+ self.coef_ + z_crit * self._bse
3418
+ ])
3419
+
3420
+ # Wald test (global test that all coefficients are 0)
3421
+ try:
3422
+ var_inv = np.linalg.solve(self._var_matrix, np.eye(n_features))
3423
+ self._wald_test_stat = self.coef_ @ var_inv @ self.coef_
3424
+ except np.linalg.LinAlgError:
3425
+ self._wald_test_stat = np.nan
3426
+ self._wald_test_pvalue = 1 - stats.chi2.cdf(self._wald_test_stat, n_features)
3427
+
3428
+ # Likelihood ratio test
3429
+ self._lr_test_stat = 2 * (self._log_likelihood - self._log_likelihood_null)
3430
+ self._lr_test_pvalue = 1 - stats.chi2.cdf(self._lr_test_stat, n_features)
3431
+
3432
+ # Score test (Rao's test) - computed at beta = 0
3433
+ ep = getattr(self, "_efron_pre", None)
3434
+ grad_0, _ = self._compute_gradient_hessian(np.zeros(n_features), X, time, event, ep, entry=getattr(self, "_entry", None))
3435
+ try:
3436
+ _, hess_0 = self._compute_gradient_hessian(np.zeros(n_features), X, time, event, ep, entry=getattr(self, "_entry", None))
3437
+ info_0 = -hess_0
3438
+ info_0_inv = np.linalg.solve(info_0, np.eye(n_features))
3439
+ self._score_test_stat = grad_0 @ info_0_inv @ grad_0
3440
+ except:
3441
+ self._score_test_stat = np.nan
3442
+ self._score_test_pvalue = 1 - stats.chi2.cdf(self._score_test_stat, n_features)
3443
+
3444
+ def _compute_robust_score_residuals(self, X, time, event):
3445
+ """
3446
+ Per-observation contributions for sandwich (HC0/HC1/cluster).
3447
+
3448
+ When `statsmodels` is available, uses `PHReg.score_residuals`, which
3449
+ follows the martingale / leverage construction used by statsmodels for
3450
+ cluster-robust covariance (same for both Breslow and Efron partial
3451
+ likelihood). This aligns robust SEs with statsmodels much more closely
3452
+ than the closed-form Breslow score residual or the fast Efron
3453
+ approximation.
3454
+
3455
+ Falls back to `_compute_score_residuals_exact_breslow` (Breslow) or
3456
+ `_compute_score_residuals_fast` (Efron) when statsmodels is missing or
3457
+ raises.
3458
+ """
3459
+ sr = self._score_residuals_via_statsmodels_if_available(X, time, event)
3460
+ if sr is not None:
3461
+ return sr
3462
+ if self.ties == "breslow":
3463
+ return self._compute_score_residuals_exact_breslow(X, time, event)
3464
+ return self._compute_score_residuals_fast(X, time, event)
3465
+
3466
def _compute_robust_score_residuals_gpu(self, X, time, event):
    """Event-row approximation of the robust score residuals, computed with CuPy.

    Event rows receive ``x_i`` minus the ratio of the reverse cumulative sums
    of ``x * exp(eta)`` and ``exp(eta)`` taken at row ``i``; all other rows
    stay zero. ``time`` is accepted but not read.
    NOTE(review): the reverse cumulative sums presumably rely on rows being
    sorted by ascending time - confirm against the caller.
    """
    import cupy as cp

    hazard = cp.exp(X @ cp.asarray(self.coef_))
    # Tiny offset keeps the denominator strictly positive.
    denom = cp.cumsum(hazard[::-1])[::-1] + 1e-30
    numer = cp.cumsum((X * hazard[:, cp.newaxis])[::-1], axis=0)[::-1]
    is_event = event == 1
    resid = cp.zeros((X.shape[0], X.shape[1]), dtype=cp.float64)
    resid[is_event] = X[is_event] - numer[is_event] / denom[is_event, cp.newaxis]
    return resid
3478
+
3479
+ def _score_residuals_via_statsmodels_if_available(
3480
+ self, X: np.ndarray, time: np.ndarray, event: np.ndarray
3481
+ ):
3482
+ """Return statsmodels-style score residuals, or None if unavailable."""
3483
+ try:
3484
+ import statsmodels.duration.api as smd
3485
+ except Exception:
3486
+ return None
3487
+ try:
3488
+ model = smd.PHReg(time, X, status=event, ties=self.ties)
3489
+ sr = model.score_residuals(self.coef_)
3490
+ if sr.shape != (X.shape[0], X.shape[1]):
3491
+ return None
3492
+ # Undefined strata / risk-set rows are NaN in statsmodels; drop from meat.
3493
+ sr = np.nan_to_num(sr, nan=0.0, posinf=0.0, neginf=0.0)
3494
+ return np.asarray(sr, dtype=np.float64)
3495
+ except Exception:
3496
+ return None
3497
+
3498
+ def _compute_score_residuals_fast(self, X, time, event):
3499
+ """
3500
+ Fast approximate per-observation score residuals at fitted beta.
3501
+
3502
+ Event-row approximation:
3503
+ u_i = x_i - E[X | R(t_i)] for event rows, 0 for censored rows.
3504
+ This is substantially faster for larger n.
3505
+ """
3506
+ n_samples, n_features = X.shape
3507
+ eta = X @ self.coef_
3508
+ exp_eta = np.exp(eta)
3509
+ risk_sum = np.cumsum(exp_eta[::-1])[::-1] + 1e-30
3510
+ risk_X_sum = np.cumsum((X * exp_eta[:, np.newaxis])[::-1], axis=0)[::-1]
3511
+ u = np.zeros((n_samples, n_features), dtype=np.float64)
3512
+ # Vectorized: fill only event rows.
3513
+ event_mask = event == 1
3514
+ u[event_mask] = X[event_mask] - risk_X_sum[event_mask] / risk_sum[event_mask, np.newaxis]
3515
+ return u
3516
+
3517
+ def _compute_score_residuals_exact_breslow(self, X, time, event):
3518
+ """
3519
+ Exact per-observation score residuals for Breslow ties in O(n p).
3520
+
3521
+ u_j = I(event_j) * s_j - exp_eta_j * sum_{i<=j, event_i=1}(s_i / risk_sum_i),
3522
+ where s_i = x_i - E[X|R(t_i)].
3523
+ """
3524
+ eta = X @ self.coef_
3525
+ exp_eta = np.exp(eta)
3526
+ risk_sum = np.cumsum(exp_eta[::-1])[::-1] + 1e-30
3527
+ risk_X_sum = np.cumsum((X * exp_eta[:, np.newaxis])[::-1], axis=0)[::-1]
3528
+ event_mask = (event == 1).astype(np.float64)
3529
+ s = X - (risk_X_sum / risk_sum[:, np.newaxis])
3530
+ a = (event_mask[:, np.newaxis] * s) / risk_sum[:, np.newaxis]
3531
+ csum_a = np.cumsum(a, axis=0)
3532
+ u = event_mask[:, np.newaxis] * s - exp_eta[:, np.newaxis] * csum_a
3533
+ return u
3534
+
3535
+ def _compute_baseline_hazard(self, X, time, event):
3536
+ """Compute Breslow estimator of baseline hazard and survival function."""
3537
+ # Get unique event times
3538
+ event_mask = event == 1
3539
+ if not np.any(event_mask):
3540
+ self._unique_times = np.array([])
3541
+ self._baseline_hazard = np.array([])
3542
+ self._baseline_cumulative_hazard = np.array([])
3543
+ return
3544
+
3545
+ unique_times = np.unique(time[event_mask])
3546
+ self._unique_times = unique_times
3547
+
3548
+ # Linear predictor
3549
+ eta = X @ self.coef_
3550
+ exp_eta = np.exp(eta)
3551
+
3552
+ # Compute baseline cumulative hazard using Breslow estimator
3553
+ cumulative_hazard = np.zeros(len(unique_times))
3554
+
3555
+ for i, t in enumerate(unique_times):
3556
+ # Events at time t
3557
+ d_i = np.sum((time == t) & (event == 1))
3558
+
3559
+ # Risk set at time t (all with time >= t)
3560
+ risk_set = time >= t
3561
+ risk_sum = np.sum(exp_eta[risk_set])
3562
+
3563
+ # Breslow estimator contribution
3564
+ cumulative_hazard[i] = d_i / risk_sum
3565
+
3566
+ # Cumulative sum
3567
+ self._baseline_cumulative_hazard = np.cumsum(cumulative_hazard)
3568
+
3569
+ # Hazard (discrete)
3570
+ self._baseline_hazard = cumulative_hazard
3571
+
3572
def _compute_baseline_hazard_gpu(self, X, time, event, beta):
    """Compute the Breslow baseline hazard with CuPy and store it as NumPy arrays.

    For every distinct event time ``t`` the increment is the number of events
    at ``t`` divided by the sum of ``exp(X @ beta)`` over rows with
    ``time >= t``. Results are stored in ``_unique_times``,
    ``_baseline_hazard`` and ``_baseline_cumulative_hazard``.

    All three attributes are host (NumPy) arrays on every path. The no-event
    branch previously left empty device arrays behind, unlike the normal
    path, which copies to the host.
    """
    import cupy as cp

    event_mask = event == 1
    if not cp.any(event_mask):
        # Keep the storage type identical to the normal path below.
        self._unique_times = np.array([])
        self._baseline_hazard = np.array([])
        self._baseline_cumulative_hazard = np.array([])
        return

    unique_times = cp.unique(time[event_mask])
    self._unique_times = unique_times

    # Relative hazard of every row.
    exp_eta = cp.exp(X @ beta)

    increments = cp.zeros(len(unique_times))

    # One pass per distinct event time (a Python loop, not vectorized).
    for i, t in enumerate(unique_times):
        # Events at time t
        d_i = int(cp.sum((time == t) & event_mask))
        # Risk set at time t (all with time >= t)
        risk_sum = cp.sum(exp_eta[time >= t])
        # Breslow estimator contribution
        increments[i] = d_i / risk_sum

    # Transfer to CPU for storage
    self._unique_times = cp.asnumpy(unique_times)
    self._baseline_hazard = cp.asnumpy(increments)
    self._baseline_cumulative_hazard = cp.asnumpy(cp.cumsum(increments))
3617
+
3618
+ def _compute_baseline_hazard_torch(self, X, time, event, beta):
3619
+ """Compute Breslow estimator of baseline hazard and survival function on Torch."""
3620
+ import torch
3621
+
3622
+ # Get unique event times
3623
+ event_mask = event == 1
3624
+ if not torch.any(event_mask):
3625
+ self._unique_times = torch.tensor([], dtype=torch.float64, device=beta.device)
3626
+ self._baseline_hazard = torch.tensor([], dtype=torch.float64, device=beta.device)
3627
+ self._baseline_cumulative_hazard = torch.tensor([], dtype=torch.float64, device=beta.device)
3628
+ return
3629
+
3630
+ unique_times = torch.unique(time[event_mask])
3631
+ self._unique_times = unique_times
3632
+
3633
+ # Linear predictor
3634
+ eta = X @ beta
3635
+ exp_eta = torch.exp(eta)
3636
+
3637
+ # Compute baseline cumulative hazard using Breslow estimator (vectorized)
3638
+ cumulative_hazard = torch.zeros(len(unique_times), dtype=torch.float64, device=beta.device)
3639
+
3640
+ # Vectorized computation
3641
+ for i, t in enumerate(unique_times):
3642
+ # Events at time t
3643
+ d_i = int(torch.sum((time == t) & (event == 1)))
3644
+
3645
+ # Risk set at time t (all with time >= t)
3646
+ risk_set = time >= t
3647
+ risk_sum = torch.sum(exp_eta[risk_set])
3648
+
3649
+ # Breslow estimator contribution
3650
+ cumulative_hazard[i] = d_i / risk_sum
3651
+
3652
+ # Cumulative sum
3653
+ self._baseline_cumulative_hazard = torch.cumsum(cumulative_hazard, dim=0)
3654
+
3655
+ # Hazard (discrete)
3656
+ self._baseline_hazard = cumulative_hazard
3657
+
3658
+ # Transfer to CPU for storage
3659
+ self._unique_times = self._unique_times.cpu().numpy()
3660
+ self._baseline_hazard = self._baseline_hazard.cpu().numpy()
3661
+ self._baseline_cumulative_hazard = self._baseline_cumulative_hazard.cpu().numpy()
3662
+
3663
def _compute_cindex_gpu(self, X, time, event, beta):
    """Compute the concordance index (C-index) with CuPy in row chunks.

    Each event row ``i`` is compared with every row ``j``. A pair is
    permissible when ``time[i] < time[j]``, or when the times are equal and
    ``j`` is censored. Concordant pairs have ``risk[i] > risk[j]``; pairs with
    equal risk scores count one half.

    Returns
    -------
    0-d float64 CuPy array: ``0.5`` when there is no event at all, NaN when
    there are events but no permissible pair, otherwise the C-index.
    """
    import cupy as cp

    # Linear predictor (risk score) on GPU
    risk_score = X @ beta
    n = len(time)

    # A single emptiness check replaces the former pair of checks, the second
    # of which could never trigger.
    event_idx = cp.where(event == 1)[0]
    n_events = len(event_idx)
    if n_events == 0:
        return cp.array(0.5, dtype=cp.float64)

    concordant = cp.int64(0)
    permissible = cp.int64(0)
    tied_risk = cp.int64(0)

    # Row-vector views are identical for every chunk, so build them once.
    time_j = time[None, :]
    risk_j = risk_score[None, :]
    censored_j = event[None, :] == 0

    # Bound each (chunk x n) comparison matrix to about 128e6 elements.
    chunk_size = max(1, min(n_events, int(128e6 / max(n, 1))))

    for start in range(0, n_events, chunk_size):
        end = min(start + chunk_size, n_events)
        idx_chunk = event_idx[start:end]

        time_i = time[idx_chunk][:, None]
        risk_i = risk_score[idx_chunk][:, None]

        # Permissible pairs: earlier time OR same time with j censored
        perm = (time_i < time_j) | ((time_i == time_j) & censored_j)
        # Exclude self-comparisons
        chunk_rows = cp.arange(end - start, dtype=cp.int64)
        perm[chunk_rows, idx_chunk] = False

        concordant += cp.sum(perm & (risk_i > risk_j))
        tied_risk += cp.sum(perm & (risk_i == risk_j))
        permissible += cp.sum(perm)

    if permissible > 0:
        return (concordant.astype(cp.float64) + 0.5 * tied_risk.astype(cp.float64)) / permissible.astype(cp.float64)
    return cp.array(float("nan"), dtype=cp.float64)
3714
+
3715
+ def _compute_cindex(self):
3716
+ """
3717
+ Compute concordance index (C-index) using chunked vectorized NumPy.
3718
+
3719
+ Replaces the O(n²) double Python loop with batched boolean matrix ops.
3720
+ Chunk size is chosen so each batch matrix stays within ~128 MB.
3721
+ """
3722
+ if self._X is None or self.coef_ is None:
3723
+ self._cindex = None
3724
+ return
3725
+
3726
+ risk_score = self._X @ self.coef_
3727
+ time = self._time
3728
+ event = self._event
3729
+ n = len(time)
3730
+
3731
+ event_idx = np.where(event == 1)[0]
3732
+ n_events = len(event_idx)
3733
+
3734
+ if n_events == 0:
3735
+ self._cindex = np.nan
3736
+ return
3737
+
3738
+ concordant = np.int64(0)
3739
+ permissible = np.int64(0)
3740
+ tied_risk = np.int64(0)
3741
+
3742
+ # Chunk so each (chunk × n) bool matrix is ≤ 128 MB.
3743
+ chunk_size = max(1, min(n_events, int(128e6 / max(n, 1))))
3744
+
3745
+ for start in range(0, n_events, chunk_size):
3746
+ end = min(start + chunk_size, n_events)
3747
+ idx_chunk = event_idx[start:end] # (c,)
3748
+
3749
+ time_i = time[idx_chunk, np.newaxis] # (c, 1)
3750
+ risk_i = risk_score[idx_chunk, np.newaxis]
3751
+ time_j = time[np.newaxis, :] # (1, n)
3752
+ risk_j = risk_score[np.newaxis, :]
3753
+ event_j = event[np.newaxis, :]
3754
+
3755
+ # Permissible pairs: earlier time OR same time with j censored.
3756
+ perm = (time_i < time_j) | ((time_i == time_j) & (event_j == 0))
3757
+ # Exclude self-comparisons.
3758
+ perm[np.arange(end - start), idx_chunk] = False
3759
+
3760
+ concordant += int(np.sum(perm & (risk_i > risk_j)))
3761
+ tied_risk += int(np.sum(perm & (risk_i == risk_j)))
3762
+ permissible += int(np.sum(perm))
3763
+
3764
+ if permissible > 0:
3765
+ self._cindex = (concordant + 0.5 * tied_risk) / permissible
3766
+ else:
3767
+ self._cindex = np.nan
3768
+
3769
def summary(self):
    """Print a text summary of the fitted model, in the spirit of R's summary(coxph()).

    The report lists the call, sample and event counts, the covariance type
    and a coefficient table. When ``compute_inference`` is set and standard
    errors exist, the table includes standard errors, z-values, p-values and
    a hazard-ratio table with 95% bounds, followed by the likelihood-ratio,
    Wald and score tests. Concordance and convergence information close the
    report.

    Raises
    ------
    RuntimeError
        If the model has not been fitted.
    """
    if not self._fitted:
        raise RuntimeError("Model has not been fitted yet.")

    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print(heavy_rule)
    print(" Cox Proportional Hazards Model")
    print(heavy_rule)
    print("Call:")
    print(f" coxph(formula = Surv(time, event) ~ ., ties = '{self.ties}')")
    print()
    print(f" n= {self._nobs}, number of events= {int(self._nevents)}")
    print(f" covariance type= {self.cov_type}")
    print()

    if self.compute_inference and self._bse is not None:
        # Coefficient table with inference columns.
        print(f"{'':<15} {'coef':>10} {'exp(coef)':>12} {'se(coef)':>10} {'z':>10} {'Pr(>|z|)':>10}")
        print(light_rule)
        for i, name in enumerate(self._feature_names):
            estimate_part = f"{name:<15} {self.coef_[i]:>10.4f} {self.hazard_ratios_[i]:>12.4f} "
            inference_part = f"{self._bse[i]:>10.4f} {self._zvalues[i]:>10.3f} {self._pvalues[i]:>10.4f}"
            print(estimate_part + inference_part)

        # Hazard-ratio table with confidence bounds on the ratio scale.
        print(light_rule)
        print(f"{'':<15} {'exp(coef)':>12} {'exp(-coef)':>12} {'lower .95':>12} {'upper .95':>12}")
        print(light_rule)
        for i, name in enumerate(self._feature_names):
            hr = self.hazard_ratios_[i]
            ratio_part = f"{name:<15} {hr:>12.4f} {1/hr:>12.4f} "
            bounds_part = f"{np.exp(self._conf_int[i, 0]):>12.4f} {np.exp(self._conf_int[i, 1]):>12.4f}"
            print(ratio_part + bounds_part)
    else:
        print(f"{'':<15} {'coef':>10} {'exp(coef)':>12}")
        print(light_rule)
        for i, name in enumerate(self._feature_names):
            print(f"{name:<15} {self.coef_[i]:>10.4f} {self.hazard_ratios_[i]:>12.4f}")
        print(light_rule)
        print("Inference statistics disabled (compute_inference=False).")

    print(heavy_rule)
    if self._cindex is not None:
        print(f"Concordance: {self._cindex:.3f} (if 0.5-0.7: moderate, 0.7-0.9: strong)")
    else:
        print("Concordance: skipped (compute_cindex=False)")

    if self.compute_inference and self._lr_test_stat is not None:
        print(f"Likelihood ratio test: {self._lr_test_stat:.2f} on {len(self.coef_)} df, p={self._lr_test_pvalue:.4e}")
        print(f"Wald test: {self._wald_test_stat:.2f} on {len(self.coef_)} df, p={self._wald_test_pvalue:.4e}")
        print(f"Score (logrank) test: {self._score_test_stat:.2f} on {len(self.coef_)} df, p={self._score_test_pvalue:.4e}")
    else:
        print("Likelihood/Wald/Score tests skipped (compute_inference=False).")
    print(f"Number of Newton-Raphson iterations: {self._iterations}")
    print(f"Converged: {self._converged}")
    print(heavy_rule)
3821
+
3822
def predict_hazard_ratio(self, X):
    """
    Predict hazard ratios, ``exp(X @ coef_)``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix. A 1-D input is treated as a single column, i.e.
        ``n_samples`` observations of one feature.

    Returns
    -------
    hazard_ratios : ndarray of shape (n_samples,)
        Predicted hazard ratios.
    """
    self._check_is_fitted()
    design = np.asarray(X, dtype=np.float64)
    if design.ndim == 1:
        design = design[:, np.newaxis]
    linear_predictor = design @ self.coef_
    return np.exp(linear_predictor)
3841
+
3842
def predict_risk_score(self, X):
    """
    Return the linear predictor ``X @ coef_`` for each sample.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix. It is converted to a float64 array; a 1-D
        input is reshaped into a single column (one feature).

    Returns
    -------
    risk_scores : ndarray of shape (n_samples,)
        Linear predictor (log hazard ratio) for each row of ``X``.
    """
    self._check_is_fitted()
    design = np.asarray(X, dtype=np.float64)
    # Promote a flat vector to an (n_samples, 1) column before the product.
    design = design[:, np.newaxis] if design.ndim == 1 else design
    return design @ self.coef_
3861
+
3862
def predict_survival(self, X, times=None):
    """
    Predict survival function S(t|X) = exp(-H0(t) * exp(X @ coef)).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix. A 1-D input is reshaped into a single column.
    times : array-like or scalar, optional
        Times at which to evaluate the survival function.
        If None, the stored training grid ``self._unique_times`` is used.
        Otherwise the baseline cumulative hazard is evaluated at each
        requested time as a right-continuous step function of the
        training grid: the value at the last grid time ``<= t``, and 0
        before the first grid time.

    Returns
    -------
    survival : ndarray of shape (n_samples, n_times)
        Predicted survival probabilities; column ``k`` corresponds to
        ``times[k]``. All ones when there are no times to evaluate or
        no baseline cumulative hazard is stored.
    times : ndarray
        Times at which survival is evaluated.
    """
    self._check_is_fitted()
    X = np.asarray(X, dtype=np.float64)
    if X.ndim == 1:
        X = X.reshape(-1, 1)

    baseline = self._baseline_cumulative_hazard
    use_training_grid = times is None
    if use_training_grid:
        times = self._unique_times
    else:
        # atleast_1d lets a single scalar time be passed as well.
        times = np.atleast_1d(np.asarray(times))

    if len(times) == 0 or baseline is None:
        return np.ones((X.shape[0], len(times))), times

    if use_training_grid:
        cumulative_hazard = baseline
    else:
        # NOTE(review): assumes baseline[k] is H0 at self._unique_times[k]
        # (the times=None path relies on the same alignment) -- confirm
        # against the fitting code.
        grid = np.asarray(self._unique_times)
        hazard = np.asarray(baseline, dtype=np.float64)
        if grid.size == 0:
            cumulative_hazard = np.zeros(len(times))
        else:
            # Sort defensively so searchsorted is valid even if the stored
            # grid is not ascending.
            order = np.argsort(grid, kind="stable")
            grid = grid[order]
            hazard = hazard[order]
            # side="right": index of the first grid time strictly greater
            # than t, so position - 1 is the last grid time <= t.
            position = np.searchsorted(grid, times, side="right")
            cumulative_hazard = np.where(
                position > 0, hazard[np.maximum(position - 1, 0)], 0.0
            )

    # Hazard ratios
    hr = np.exp(X @ self.coef_)

    # Survival function: S(t) = exp(-H0(t) * HR)
    survival = np.exp(-cumulative_hazard[np.newaxis, :] * hr[:, np.newaxis])

    return survival, times
3901
+
3902
def predict(self, X):
    """
    Alias for ``predict_hazard_ratio``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix, forwarded unchanged.

    Returns
    -------
    hazard_ratios
        Whatever ``predict_hazard_ratio(X)`` returns.
    """
    hazard_ratios = self.predict_hazard_ratio(X)
    return hazard_ratios
3905
+
3906
def score(self, X, time, event):
    """
    Compute the concordance index (C-index) of the model on test data.

    For every sample ``i`` with ``event == 1``, sample ``j`` forms a
    permissible pair when ``time[i] < time[j]``, or when the two times
    are equal and ``j`` is censored (``event == 0``). A pair is
    concordant when ``i`` has the strictly larger risk score; pairs with
    equal risk scores count as one half.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Test covariates, passed to ``predict_risk_score``.
    time : array-like of shape (n_samples,)
        Test event/censoring times.
    event : array-like of shape (n_samples,)
        Test event indicators (1 = event observed, 0 = censored).

    Returns
    -------
    cindex : float
        ``(concordant + 0.5 * tied) / permissible``. Returns 0.5 when no
        sample has ``event == 1`` and ``nan`` when there are events but
        no permissible pair.
    """
    self._check_is_fitted()

    risk = self.predict_risk_score(X)
    time = np.asarray(time)
    event = np.asarray(event)

    n_samples = len(time)
    has_event = event == 1
    if not np.any(has_event):
        return 0.5

    event_rows = np.where(has_event)[0]
    n_event_rows = len(event_rows)

    # Row-vector views of the full sample; identical for every chunk.
    time_row = time[np.newaxis, :]
    risk_row = risk[np.newaxis, :]
    censored_row = event[np.newaxis, :] == 0

    n_concordant = np.int64(0)
    n_permissible = np.int64(0)
    n_tied = np.int64(0)

    # Process event rows in chunks so each (rows x n_samples) boolean
    # matrix holds at most about 128e6 entries.
    rows_per_chunk = max(1, min(n_event_rows, int(128e6 / max(n_samples, 1))))

    for start in range(0, n_event_rows, rows_per_chunk):
        rows = event_rows[start:start + rows_per_chunk]
        time_col = time[rows, np.newaxis]
        risk_col = risk[rows, np.newaxis]

        # Comparable pairs: i strictly earlier, or same time with j censored.
        strictly_earlier = time_col < time_row
        same_time_censored = (time_col == time_row) & censored_row
        comparable = strictly_earlier | same_time_censored

        # A sample is never compared with itself.
        comparable[np.arange(len(rows), dtype=np.int64), rows] = False

        n_concordant += int(np.count_nonzero(comparable & (risk_col > risk_row)))
        n_tied += int(np.count_nonzero(comparable & (risk_col == risk_row)))
        n_permissible += int(np.count_nonzero(comparable))

    if n_permissible > 0:
        return (n_concordant + 0.5 * n_tied) / n_permissible
    return np.nan