statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
statgpu/survival/_cox.py
ADDED
|
@@ -0,0 +1,3974 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Cox Proportional Hazards regression with GPU acceleration.
|
|
3
|
+
|
|
4
|
+
Implements Cox PH model using Breslow and Efron approximations for ties with
|
|
5
|
+
Newton-Raphson optimization. Matches R's survival::coxph() API.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional, Union, Tuple, Dict, Any, List
|
|
9
|
+
import os
|
|
10
|
+
import numpy as np
|
|
11
|
+
from scipy import stats
|
|
12
|
+
|
|
13
|
+
from statgpu._base import BaseEstimator
|
|
14
|
+
from statgpu._config import Device
|
|
15
|
+
|
|
16
|
+
# Optional Cython import for faster Efron gradient/Hessian computation
|
|
17
|
+
try:
|
|
18
|
+
from ._cox_efron_cy import efron_grad_hess as _efron_grad_hess_cython
|
|
19
|
+
HAS_CYTHON_EFRON = True
|
|
20
|
+
except ImportError:
|
|
21
|
+
HAS_CYTHON_EFRON = False
|
|
22
|
+
_efron_grad_hess_cython = None
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from statgpu.survival._cox_efron_triton import _find_p_ce
|
|
26
|
+
HAS_TRITON_EFRON = True
|
|
27
|
+
except ImportError:
|
|
28
|
+
HAS_TRITON_EFRON = False
|
|
29
|
+
_find_p_ce = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _unpack_efron_pre6(efron_pre):
    """Normalise an Efron precompute bundle to its 6-field form.

    The 6-field layout is
    ``(uft, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft)``.
    A 6-field input is handed back untouched (same object). A legacy
    5-field input gets ``None`` appended for ``first_idx_uft``.

    Raises
    ------
    ValueError
        If the bundle has neither 5 nor 6 fields.
    """
    n_fields = len(efron_pre)
    if n_fields == 6:
        return efron_pre
    if n_fields != 5:
        raise ValueError(f"invalid efron_pre length {n_fields}")
    # Legacy layout: pad the missing trailing field.
    return (*efron_pre, None)
|
|
40
|
+
|
|
41
|
+
class CoxPH(BaseEstimator):
|
|
42
|
+
"""
|
|
43
|
+
Cox Proportional Hazards regression with GPU acceleration.
|
|
44
|
+
|
|
45
|
+
Parameters
|
|
46
|
+
----------
|
|
47
|
+
ties : str, default='breslow'
|
|
48
|
+
Method for handling ties: 'breslow' or 'efron'.
|
|
49
|
+
tol : float, default=1e-9
|
|
50
|
+
Convergence tolerance for Newton-Raphson.
|
|
51
|
+
max_iter : int, default=100
|
|
52
|
+
Maximum number of iterations.
|
|
53
|
+
device : str or Device, default='auto'
|
|
54
|
+
Computation device: 'cpu', 'cuda', or 'auto'.
|
|
55
|
+
compute_inference : bool, default=True
|
|
56
|
+
If True, compute standard errors/tests/baseline hazard on CPU after fitting.
|
|
57
|
+
Set to False to reduce CPU-GPU data transfers in CUDA mode.
|
|
58
|
+
compute_cindex : bool, default=True
|
|
59
|
+
If True, compute training-set C-index during fit. Disabling this can
|
|
60
|
+
significantly reduce fit time, especially on CUDA/Torch for moderate n.
|
|
61
|
+
|
|
62
|
+
Attributes
|
|
63
|
+
----------
|
|
64
|
+
coef_ : ndarray of shape (n_features,)
|
|
65
|
+
Estimated coefficients (log hazard ratios).
|
|
66
|
+
hazard_ratios_ : ndarray of shape (n_features,)
|
|
67
|
+
exp(coef) = hazard ratios.
|
|
68
|
+
"""
|
|
69
|
+
|
|
70
|
+
def __init__(
    self,
    ties: str = 'breslow',
    tol: float = 1e-9,
    max_iter: int = 100,
    device: Union[str, Device] = Device.AUTO,
    n_jobs: Optional[int] = None,
    compute_inference: bool = True,
    compute_cindex: bool = True,
    cov_type: str = "nonrobust",
    gpu_memory_cleanup: bool = False,
    penalty: float = 0.0,
):
    """Configure the estimator and create every fitted/cache attribute.

    Parameters
    ----------
    ties : str, default='breslow'
        Tie handling, ``'breslow'`` or ``'efron'`` (case-insensitive).
    tol : float, default=1e-9
        Convergence tolerance for Newton-Raphson.
    max_iter : int, default=100
        Maximum number of Newton iterations.
    device : str or Device, default=Device.AUTO
        Computation device, forwarded to ``BaseEstimator``.
    n_jobs : int or None, default=None
        Forwarded to ``BaseEstimator``.
    compute_inference : bool, default=True
        Whether inference statistics and the baseline hazard are computed
        after fitting.
    compute_cindex : bool, default=True
        Whether the training C-index is computed during ``fit``.
    cov_type : str, default='nonrobust'
        One of ``'nonrobust'``, ``'hc0'``, ``'hc1'``, ``'cluster'``
        (case-insensitive).
    gpu_memory_cleanup : bool, default=False
        Enables the best-effort CuPy/Torch memory cleanup helpers.
    penalty : float, default=0.0
        Non-negative L2 penalty strength.

    Raises
    ------
    ValueError
        If ``ties`` or ``cov_type`` is not a supported value, or if
        ``penalty`` is negative.
    """
    super().__init__(device=device, n_jobs=n_jobs)
    self.ties = ties.lower()
    self.tol = tol
    self.max_iter = max_iter
    self.compute_inference = compute_inference
    self.compute_cindex = bool(compute_cindex)
    self.cov_type = cov_type.lower()
    self.gpu_memory_cleanup = bool(gpu_memory_cleanup)
    self.penalty = float(penalty)

    if self.ties not in ('breslow', 'efron'):
        raise ValueError("ties must be 'breslow' or 'efron'")
    if self.cov_type not in ("nonrobust", "hc0", "hc1", "cluster"):
        raise ValueError("cov_type must be one of: 'nonrobust', 'hc0', 'hc1', 'cluster'")
    if self.penalty < 0:
        raise ValueError("penalty must be non-negative")

    # Fitted attributes
    self.coef_ = None
    self.hazard_ratios_ = None

    # Internal storage for inference
    self._time = None
    self._event = None
    self._X = None
    self._entry = None
    self._nobs = None
    self._nevents = None
    self._bse = None
    self._zvalues = None
    self._pvalues = None
    self._conf_int = None
    self._log_likelihood = None
    self._log_likelihood_null = None
    self._iterations = 0
    self._converged = False
    self._var_matrix = None
    self._score_test_stat = None
    self._baseline_hazard = None
    self._baseline_cumulative_hazard = None
    self._unique_times = None
    self._cindex = None
    self._feature_names = None
    self._wald_test_stat = None
    self._wald_test_pvalue = None
    self._lr_test_stat = None
    self._lr_test_pvalue = None
    self._score_test_pvalue = None
    # patsy design info of the last formula-based fit. Created here so that
    # reading it on an unfitted estimator yields None, not AttributeError.
    self._design_info = None
    # Efron only: cached (uft, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft); depends only on sorted time/event.
    self._efron_pre = None
    # Efron optimization: True when all failure groups are singletons (no ties),
    # in which case Efron equals Breslow and we can use faster vectorized paths.
    self._efron_all_singletons = False
    # Efron only: cached CSR packed indices for GPU kernels.
    # (enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind, first_idx_uft, nuft)
    self._efron_pre_csr = None
    # Breslow only: cached (first_idx_uft, counts_uft) on CPU.
    self._breslow_pre = None
    # Breslow only: cached (first_idx_uft_gpu, counts_uft_gpu) on GPU.
    self._breslow_pre_gpu = None
    # Delayed-entry caches that the CPU fit path (re)sets; created here so
    # they always exist, whichever device path ends up fitting.
    self._entry_fail_groups_np = None
    self._entry_fail_times_np = None
    self._entry_order_np = None
    self._entry_add_end_np = None
    self._entry_rem_end_np = None
|
|
143
|
+
|
|
144
|
+
def _cleanup_cuda_memory(self):
    """Best-effort CuPy memory pool cleanup.

    Does nothing unless ``gpu_memory_cleanup`` is enabled. CuPy is only
    touched when it has already been imported by this process: if it was
    never loaded its pools cannot hold anything, and importing a GPU
    library purely to free memory (possibly from ``__del__``) is wasted,
    slow work. Every failure is swallowed on purpose, because cleanup must
    never break a fit or object finalisation.
    """
    if not self.gpu_memory_cleanup:
        return
    try:
        import sys

        cp = sys.modules.get("cupy")
        if cp is None:
            return
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()
    except Exception:
        # Deliberate best-effort: e.g. no CUDA runtime, or interpreter
        # shutdown has already torn down module state.
        pass
|
|
154
|
+
|
|
155
|
+
def _cleanup_torch_memory(self):
    """Best-effort Torch CUDA cache cleanup.

    Does nothing unless ``gpu_memory_cleanup`` is enabled. Torch is only
    touched when it is already imported and its CUDA state is already
    initialised; otherwise there is no cache to release, and calling
    ``torch.cuda.synchronize()`` would only bring up CUDA for nothing.
    Every failure is swallowed on purpose, because cleanup must never break
    a fit or object finalisation.
    """
    if not self.gpu_memory_cleanup:
        return
    try:
        import sys

        torch = sys.modules.get("torch")
        if torch is None or not torch.cuda.is_initialized():
            return
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    except Exception:
        # Deliberate best-effort: e.g. CPU-only torch build, or interpreter
        # shutdown has already torn down module state.
        pass
|
|
166
|
+
|
|
167
|
+
def __del__(self):
    """Run the CuPy and then the Torch cleanup hook when the object dies.

    Any exception stops the remaining hooks and is discarded, so
    finalisation never raises.
    """
    try:
        for hook_name in ("_cleanup_cuda_memory", "_cleanup_torch_memory"):
            getattr(self, hook_name)()
    except Exception:
        pass
|
|
173
|
+
|
|
174
|
+
@staticmethod
def _extract_convergence_status(result):
    """Read a convergence flag off a fit-result object, if it carries one.

    Lookup order: ``result.converged`` first, then ``converged`` inside
    ``result.mle_retvals`` (read as a dict key when that object is a dict,
    as an attribute otherwise).

    Returns
    -------
    bool or None
        The flag coerced to ``bool``, or ``None`` when no non-None flag is
        found.
    """
    top_level = getattr(result, "converged", None)
    if top_level is not None:
        return bool(top_level)

    retvals = getattr(result, "mle_retvals", None)
    if retvals is None:
        return None

    if isinstance(retvals, dict):
        nested = retvals.get("converged")
    else:
        nested = getattr(retvals, "converged", None)
    return None if nested is None else bool(nested)
|
|
191
|
+
|
|
192
|
+
def fit(self, X=None, time=None, event=None, entry=None, cluster=None, init_coef=None, formula=None, data=None):
    """
    Fit Cox Proportional Hazards model.

    The data come either from ``formula`` + ``data`` or from explicit
    ``X`` + ``time`` + ``event`` arrays. When ``formula`` is given it takes
    precedence and any ``X``/``time``/``event`` arguments are replaced by
    the values derived from the formula. The inputs are then converted for
    the device returned by ``_get_compute_device()`` and handed to the
    matching backend (``_fit_gpu``, ``_fit_torch`` or ``_fit_cpu``).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix. Required if ``formula`` is None. A 1-D input is
        treated as a single feature column.
    time : array-like of shape (n_samples,)
        Time to event or censoring. Required if ``formula`` is None.
    event : array-like of shape (n_samples,)
        Event indicator (1 = event, 0 = censored). Required if ``formula`` is None.
    entry : array-like of shape (n_samples,), optional
        Entry time for delayed entry (left truncation).
    cluster : array-like of shape (n_samples,), optional
        Cluster ids for cluster-robust covariance when `cov_type='cluster'`.
    init_coef : array-like of shape (n_features,), optional
        Initial coefficient guess for warm-start optimization.
    formula : str or None
        R-style formula with Surv() response, e.g.
        ``"Surv(time, event) ~ x1 + x2 + C(sex)"``.
    data : pd.DataFrame or None
        DataFrame used with ``formula`` for column lookup.

    Returns
    -------
    self : CoxPH
        Fitted estimator.

    Raises
    ------
    ValueError
        If ``formula`` is given without ``data``; if the formula response
        does not provide both a time and an event column; if neither a
        formula nor all of ``X``/``time``/``event`` are given; or if
        ``time``, ``event`` or ``entry`` do not have one value per row
        of ``X``.
    """

    def _check_lengths(n_rows, time_arr, event_arr, entry_arr):
        # A shorter time/event vector would otherwise silently select a
        # subset of rows when the backends index by the time ordering.
        if (tuple(time_arr.shape[:1]) != (n_rows,)
                or tuple(event_arr.shape[:1]) != (n_rows,)):
            raise ValueError("time and event must have shape (n_samples,)")
        if entry_arr is not None and entry_arr.shape[0] != n_rows:
            raise ValueError("entry must have shape (n_samples,)")

    def _ensure_feature_names(n_cols):
        # Generate default names when none exist, and also when names left
        # over from an earlier fit no longer match the column count.
        names = self._feature_names
        if names is None or len(names) != n_cols:
            self._feature_names = [f'x{i+1}' for i in range(n_cols)]

    # Handle formula interface
    if formula is not None:
        if data is None:
            raise ValueError(
                "formula was provided but data is None. "
                "Pass data=your_dataframe when using formula."
            )
        from statgpu.core.formula import make_surv_env
        import patsy
        from patsy import EvalEnvironment

        # Create evaluation environment with custom Surv function
        custom_env = EvalEnvironment([make_surv_env()])
        y_patsy, X_patsy = patsy.dmatrices(
            formula, data, eval_env=custom_env, return_type="matrix",
        )
        design_info = X_patsy.design_info
        # y_patsy is the result of Surv(time, event) -> shape (n, 2).
        # A plain single-variable response arrives as one column, so the
        # column count (not just ndim) has to be checked.
        y_arr = np.asarray(y_patsy)
        if y_arr.ndim != 2 or y_arr.shape[1] < 2:
            raise ValueError(
                "Formula response must be Surv(time, event), not a single variable. "
                "Use: formula='Surv(time, event) ~ x1 + x2'"
            )
        time = y_arr[:, 0]
        event = y_arr[:, 1]
        X_arr = np.asarray(X_patsy)

        # Drop intercept column from design matrix (CoxPH doesn't use intercept)
        # NOTE(review): assumes the Intercept column is the first design
        # column, as in patsy's usual layout.
        self._feature_names = list(design_info.column_names)
        if "Intercept" in self._feature_names:
            self._feature_names.remove("Intercept")
            X_arr = X_arr[:, 1:]
        self._design_info = design_info
        X = X_arr
    else:
        if X is None or time is None or event is None:
            raise ValueError(
                "Either formula+data or X+time+event must be provided."
            )
        self._design_info = None
    device = self._get_compute_device()

    if device == Device.CUDA:
        import cupy as cp

        X_gpu = cp.asarray(self._to_array(X), dtype=cp.float64)
        time_gpu = cp.asarray(self._to_array(time), dtype=cp.float64)
        event_gpu = cp.asarray(self._to_array(event), dtype=cp.int32)
        entry_gpu = None if entry is None else cp.asarray(self._to_array(entry), dtype=cp.float64)

        if X_gpu.ndim == 1:
            X_gpu = X_gpu.reshape(-1, 1)
        _check_lengths(int(X_gpu.shape[0]), time_gpu, event_gpu, entry_gpu)

        self._nobs = int(X_gpu.shape[0])
        self._nevents = int(cp.sum(event_gpu).item())
        _ensure_feature_names(int(X_gpu.shape[1]))

        # Keep CPU copies only when CPU-side inference/baseline stats are requested.
        if self.compute_inference:
            self._X = cp.asnumpy(X_gpu)
            self._time = cp.asnumpy(time_gpu)
            self._event = cp.asnumpy(event_gpu)
            self._entry = None if entry_gpu is None else cp.asnumpy(entry_gpu)
        else:
            self._X = None
            self._time = None
            self._event = None
            self._entry = None

        cluster_gpu = None if cluster is None else cp.asarray(self._to_array(cluster), dtype=cp.int64)
        self._fit_gpu(X_gpu, time_gpu, event_gpu, entry_gpu, cluster_gpu, init_coef=init_coef)
    elif device == Device.TORCH:
        import torch

        torch_device = "cuda"

        X_torch = self._to_array(X, Device.TORCH, backend="torch").to(dtype=torch.float64)
        time_torch = self._to_array(time, Device.TORCH, backend="torch").to(dtype=torch.float64)
        event_torch = self._to_array(event, Device.TORCH, backend="torch").to(dtype=torch.int32)
        entry_torch = None if entry is None else self._to_array(
            entry, Device.TORCH, backend="torch"
        ).to(dtype=torch.float64)

        if X_torch.ndim == 1:
            X_torch = X_torch.reshape(-1, 1)
        _check_lengths(int(X_torch.shape[0]), time_torch, event_torch, entry_torch)

        self._nobs = int(X_torch.shape[0])
        self._nevents = int(torch.sum(event_torch).item())
        _ensure_feature_names(int(X_torch.shape[1]))

        # Keep CPU copies only when CPU-side inference/baseline stats are requested.
        if self.compute_inference:
            self._X = X_torch.cpu().numpy()
            self._time = time_torch.cpu().numpy()
            self._event = event_torch.cpu().numpy()
            self._entry = None if entry_torch is None else entry_torch.cpu().numpy()
        else:
            self._X = None
            self._time = None
            self._event = None
            self._entry = None

        cluster_torch = None if cluster is None else self._to_array(
            cluster, Device.TORCH, backend="torch"
        ).to(dtype=torch.int64)
        self._fit_torch(
            X_torch,
            time_torch,
            event_torch,
            entry_torch,
            cluster_torch,
            torch_device,
            init_coef=init_coef,
        )
    else:
        X_np = np.asarray(self._to_array(X, Device.CPU), dtype=np.float64)
        time_np = np.asarray(self._to_array(time, Device.CPU), dtype=np.float64)
        event_np = np.asarray(self._to_array(event, Device.CPU), dtype=np.int32)
        entry_np = None if entry is None else np.asarray(self._to_array(entry, Device.CPU), dtype=np.float64)

        if X_np.ndim == 1:
            X_np = X_np.reshape(-1, 1)
        _check_lengths(X_np.shape[0], time_np, event_np, entry_np)

        self._nobs = X_np.shape[0]
        # Plain int, matching what the CUDA and Torch branches store.
        self._nevents = int(np.sum(event_np))

        # Store original data (CPU mode is CPU-only)
        self._time = time_np.copy()
        self._event = event_np.copy()
        self._X = X_np.copy()
        self._entry = None if entry_np is None else entry_np.copy()
        _ensure_feature_names(X_np.shape[1])

        cluster_np = None if cluster is None else np.asarray(self._to_array(cluster, Device.CPU))
        self._fit_cpu(X_np, time_np, event_np, entry_np, cluster_np, init_coef=init_coef)

    self._fitted = True
    return self
|
|
370
|
+
|
|
371
|
+
def _fit_cpu(self, X, time, event, entry=None, cluster=None, init_coef=None):
    """Fit using CPU (NumPy) Newton-Raphson with a step-halving line search.

    When ``entry`` is given the whole fit is delegated to
    ``_fit_cpu_with_entry`` and this method returns immediately; note that
    ``init_coef`` is not forwarded on that path. Otherwise the data are
    sorted by ascending time, the tie-handling caches are rebuilt, and the
    (optionally L2-penalised) log-likelihood
    ``loglik(beta) - penalty * sum(beta**2)`` is maximised.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
    time, event : ndarray of shape (n_samples,)
    entry : ndarray of shape (n_samples,), optional
    cluster : ndarray of shape (n_samples,), optional
        Reordered like the data and passed to the inference step.
    init_coef : array-like of shape (n_features,), optional
        Starting coefficients; zeros when omitted.

    Raises
    ------
    ValueError
        If ``init_coef`` does not have ``n_features`` entries.
    """
    if entry is not None:
        self._fit_cpu_with_entry(X, time, event, np.asarray(entry, dtype=np.float64), cluster)
        return
    _, n_features = X.shape

    # Sort by time ascending so risk-set terms are suffix sums:
    # R(t_i) = {j: t_j >= t_i} -> indices i..n-1 after ascending sort.
    order = np.argsort(time)
    X_sorted = X[order]
    time_sorted = time[order]
    event_sorted = event[order]
    cluster_sorted = None if cluster is None else np.asarray(cluster)[order]

    # Rebuild the tie-handling caches for this data set.
    self._efron_pre = None
    self._breslow_pre = None
    self._breslow_pre_gpu = None
    self._efron_all_singletons = False
    if self.ties == "efron":
        self._efron_pre = self._efron_unique_failure_indices(time_sorted, event_sorted)
        try:
            _, uft_ix, _, _, nuft, _ = _unpack_efron_pre6(self._efron_pre)
            self._efron_all_singletons = bool(nuft > 0) and all(
                len(ix) == 1 for ix in uft_ix
            )
        except Exception:
            # Best-effort flag only; fall back to the general Efron path.
            self._efron_all_singletons = False
    else:
        self._breslow_pre = self._breslow_unique_failure_groups(
            time_sorted, event_sorted
        )

    # Delayed entry never reaches this point (handled by the early return),
    # so the entry-specific caches are simply cleared.
    self._entry_fail_groups_np = None
    self._entry_fail_times_np = None
    self._entry_order_np = None
    self._entry_add_end_np = None
    self._entry_rem_end_np = None

    # Initialize coefficients (supports warm-start path in CV)
    if init_coef is None:
        beta = np.zeros(n_features, dtype=np.float64)
    else:
        # Copy so the stored coefficients never alias the caller's array.
        beta = np.array(init_coef, dtype=np.float64).reshape(-1)
        if beta.shape[0] != n_features:
            raise ValueError("init_coef must have shape (n_features,)")

    # Newton-Raphson optimization with L2 penalty
    penalty = float(self.penalty) if hasattr(self, 'penalty') else 0.0
    use_penalty = penalty > 0.0

    def loglik(coefs):
        # Unpenalised partial log-likelihood on the sorted data.
        return self._compute_log_likelihood(
            coefs, X_sorted, time_sorted, event_sorted, self._efron_pre, entry=None
        )

    def objective(coefs):
        # Quantity the line search maximises.
        value = loglik(coefs)
        if use_penalty:
            value = value - penalty * np.sum(coefs ** 2)
        return value

    # Compute null log-likelihood (beta = 0)
    self._log_likelihood_null = loglik(np.zeros(n_features))

    # Loop-invariant penalty curvature term.
    ridge = 2 * penalty * np.eye(n_features, dtype=np.float64) if use_penalty else None

    # Reset the flag so a refit cannot report a stale True from an earlier fit.
    self._converged = False
    # Preferred Newton direction for CPU path; updated adaptively.
    preferred_direction = -1.0
    iteration = -1  # default if max_iter=0
    grad = hess = delta = None

    for iteration in range(self.max_iter):
        # Compute gradient and Hessian
        grad, hess = self._compute_gradient_hessian(
            beta, X_sorted, time_sorted, event_sorted, self._efron_pre, entry=None
        )

        # Add penalty terms: gradient -= 2*penalty*beta, hessian -= 2*penalty*I
        # NOTE(review): this sign is right only if `hess` is the true
        # (negative-definite) Hessian; confirm against
        # _compute_gradient_hessian, whose sign convention is not visible here.
        if use_penalty:
            grad = grad - 2 * penalty * beta
            hess = hess - ridge

        # Solve a Newton-like step on (-hess). In practice, different tie paths
        # may expose Hessian with different sign conventions, so we choose the
        # ascent direction adaptively below using objective evaluation.
        try:
            delta = np.linalg.solve(-hess, grad)
        except np.linalg.LinAlgError:
            # Use least squares if singular
            delta = np.linalg.lstsq(-hess, grad, rcond=None)[0]

        old_ll = objective(beta)

        def improves(candidate_ll, _floor=old_ll - 1e-8):
            # NaN compares False, so a non-finite objective is never accepted.
            return candidate_ll > _floor

        # Fast path: try preferred direction first, only test opposite
        # when the preferred full step does not improve.
        direction = preferred_direction
        new_beta = beta + direction * delta
        new_ll = objective(new_beta)
        step = 1.0
        line_search_failed = False

        if improves(new_ll):
            # Keep successful direction for the next iteration.
            preferred_direction = direction
        else:
            # Probe the opposite direction only when needed.
            alt_direction = -direction
            alt_beta = beta + alt_direction * delta
            alt_ll = objective(alt_beta)
            if alt_ll > new_ll or np.isnan(new_ll):
                direction = alt_direction
                preferred_direction = alt_direction
                new_beta = alt_beta
                new_ll = alt_ll

            # Backtracking line search from step=0.5; step=1 was already evaluated.
            if not improves(new_ll):
                step = 0.5
                for _ in range(20):
                    trial_beta = beta + direction * step * delta
                    trial_ll = objective(trial_beta)
                    if improves(trial_ll):
                        new_beta = trial_beta
                        new_ll = trial_ll
                        break
                    step *= 0.5
                else:
                    # No step size helped: keep the current point rather
                    # than moving to one with a worse objective.
                    line_search_failed = True

        if line_search_failed:
            # Repeating the same failing step is pointless; stop unconverged.
            break

        # Check convergence
        if np.linalg.norm(delta) * step < self.tol:
            self._converged = True
            beta = new_beta
            break

        beta = new_beta

    self._iterations = iteration + 1
    self.coef_ = beta
    self.hazard_ratios_ = np.exp(beta)

    # Compute final log-likelihood (reported without the penalty term)
    self._log_likelihood = loglik(beta)

    # Compute optional inference statistics
    if self.compute_inference:
        self._compute_inference_cpu(X_sorted, time_sorted, event_sorted, cluster_sorted)
        self._compute_baseline_hazard(X_sorted, time_sorted, event_sorted)
    else:
        for attr in (
            "_var_matrix",
            "_bse",
            "_zvalues",
            "_pvalues",
            "_conf_int",
            "_score_test_stat",
            "_score_test_pvalue",
            "_wald_test_stat",
            "_wald_test_pvalue",
            "_lr_test_stat",
            "_lr_test_pvalue",
            "_baseline_hazard",
            "_baseline_cumulative_hazard",
            "_unique_times",
        ):
            setattr(self, attr, None)

    # Release the large temporaries before the C-index pass.
    del X_sorted, time_sorted, event_sorted, grad, hess, delta
    self._cleanup_cuda_memory()
    if self.compute_cindex:
        self._compute_cindex()
    else:
        self._cindex = None
|
|
590
|
+
|
|
591
|
+
def _fit_cpu_with_entry(self, X, time, event, entry, cluster=None):
    """Fit a delayed-entry Cox model on the CPU via statsmodels ``PHReg``.

    The fit, the covariance matrix, the confidence intervals and the
    baseline cumulative hazard are all taken from the statsmodels result;
    z/p-values and the LR/Wald tests are derived from them here.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Design matrix, forwarded to ``PHReg`` as ``exog``.
    time, event, entry : array-like
        Forwarded to ``PHReg`` as ``endog``, ``status`` and ``entry``.
    cluster : optional
        Accepted for signature parity but not used in this path (robust
        covariance is intentionally skipped, see the comment below).

    Notes
    -----
    L2 penalty is not applied in this path (statsmodels PHReg does not
    support penalized fitting). A warning is emitted when a penalty is
    specified. The score test is not computed and is reported as NaN.
    """
    if float(self.penalty) > 0:
        import warnings
        warnings.warn(
            "CoxPH with entry (delayed entry) does not support penalties via "
            "statsmodels PHReg. The penalty will be ignored. "
            "Use the GPU/torch path for penalized Cox with delayed entry.",
            UserWarning, stacklevel=3,
        )
    import statsmodels.duration.api as smd

    n_samples, n_features = X.shape
    model = smd.PHReg(time, X, status=event, entry=entry, ties=self.ties)
    res = model.fit(disp=0)

    self._iterations = int(getattr(res, "iterations", 0) or 0)
    conv_attr = self._extract_convergence_status(res)
    # An unknown convergence status is reported as "not converged".
    self._converged = bool(conv_attr) if conv_attr is not None else False
    self.coef_ = np.asarray(res.params, dtype=np.float64)
    self.hazard_ratios_ = np.exp(self.coef_)
    self._log_likelihood = float(res.llf)

    # Null model: a single all-zero covariate. Best effort only; if
    # statsmodels cannot fit it, the LR test below degrades to NaN.
    try:
        null_model = smd.PHReg(
            time,
            np.zeros((n_samples, 1), dtype=np.float64),
            status=event,
            entry=entry,
            ties=self.ties,
        )
        null_res = null_model.fit(disp=0)
        self._log_likelihood_null = float(null_res.llf)
    except Exception:
        self._log_likelihood_null = np.nan

    cov = np.asarray(res.cov_params(), dtype=np.float64)
    if cov.shape != (n_features, n_features):
        # Unexpected shape from statsmodels: report NaNs instead of failing.
        cov = np.full((n_features, n_features), np.nan, dtype=np.float64)
    self._var_matrix = cov
    self._bse = np.sqrt(np.maximum(np.diag(cov), 0.0))
    # The tiny offset avoids a division-by-zero warning when a standard error is 0.
    self._zvalues = self.coef_ / (self._bse + 1e-30)
    # Survival function instead of 1 - cdf: the latter cancels to exactly 0
    # for large |z|, losing all precision in the tail.
    self._pvalues = 2 * stats.norm.sf(np.abs(self._zvalues))
    self._conf_int = np.asarray(res.conf_int(), dtype=np.float64)

    # Delayed-entry robust covariance override is intentionally skipped:
    # current internal robust score/hessian helpers do not account for entry.

    self._lr_test_stat = 2 * (self._log_likelihood - self._log_likelihood_null)
    self._lr_test_pvalue = stats.chi2.sf(self._lr_test_stat, n_features)
    try:
        var_inv = np.linalg.solve(self._var_matrix, np.eye(n_features))
        self._wald_test_stat = self.coef_ @ var_inv @ self.coef_
    except np.linalg.LinAlgError:
        self._wald_test_stat = np.nan
    self._wald_test_pvalue = stats.chi2.sf(self._wald_test_stat, n_features)
    self._score_test_stat = np.nan
    self._score_test_pvalue = np.nan

    # Baseline hazard from PHReg output.
    try:
        base = res.baseline_cumulative_hazard[0]
        self._unique_times = np.asarray(base[0], dtype=np.float64)
        self._baseline_cumulative_hazard = np.asarray(base[1], dtype=np.float64)
        if self._baseline_cumulative_hazard.size > 0:
            # Hazard increments are first differences of the cumulative hazard.
            self._baseline_hazard = np.diff(
                np.concatenate([[0.0], self._baseline_cumulative_hazard])
            )
        else:
            self._baseline_hazard = np.array([], dtype=np.float64)
    except Exception:
        self._baseline_hazard = None
        self._baseline_cumulative_hazard = None
        self._unique_times = None

    if self.compute_cindex:
        self._compute_cindex()
    else:
        self._cindex = None
|
|
669
|
+
|
|
670
|
+
def _fit_gpu(self, X, time, event, entry=None, cluster=None, init_coef=None):
    """Fit the Cox model with CuPy, running Newton-Raphson on the GPU.

    Observations are sorted by ascending time, the tie-handling index
    structures for ``self.ties`` are precomputed once, and the coefficients
    are updated with (optionally L2-penalized) Newton steps. With delayed
    entry, each step is guarded by a step-halving line search on the
    penalized log-likelihood.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
    time, event : array-like
        Converted to float64 / int32 CuPy arrays.
    entry : array-like, optional
        Delayed-entry times; enables the entry-aware code paths.
    cluster : array-like, optional
        Cluster ids; required when ``cov_type == "cluster"`` and inference
        is requested.
    init_coef : array-like, optional
        Warm-start coefficients; must flatten to ``n_features`` values.

    Raises
    ------
    ValueError
        If ``init_coef`` has the wrong length, or if ``cov_type == "cluster"``
        is requested without cluster ids.

    Notes
    -----
    Results are written to attributes on ``self`` (``coef_``,
    ``hazard_ratios_``, log-likelihoods, inference statistics, cached index
    structures); nothing is returned.
    """
    import cupy as cp
    from statgpu.inference._distributions_backend import norm

    _, n_features = X.shape

    # Transfer to GPU once. entry/cluster are moved too: they are indexed
    # with a device-side permutation below, which a host array cannot accept.
    # For arrays that are already on the device this is a no-op.
    X = cp.asarray(X, dtype=cp.float64)
    time = cp.asarray(time, dtype=cp.float64)
    event = cp.asarray(event, dtype=cp.int32)
    if entry is not None:
        entry = cp.asarray(entry)
    if cluster is not None:
        cluster = cp.asarray(cluster)

    # Sort by time ascending so risk-set terms are suffix sums:
    # R(t_i) = {j: t_j >= t_i} -> indices i..n-1 after ascending sort.
    order = cp.argsort(time, kind="stable")
    X_sorted = X[order]
    time_sorted = time[order]
    event_sorted = event[order]
    entry_sorted = None if entry is None else entry[order]
    cluster_sorted = None if cluster is None else cluster[order]
    event_idx_sorted = cp.where(event_sorted == 1)[0]
    self._event_idx_gpu = event_idx_sorted
    self._event_X_sum_gpu = (
        cp.sum(X_sorted[event_idx_sorted], axis=0)
        if int(event_idx_sorted.size) > 0
        else cp.zeros(n_features, dtype=cp.float64)
    )

    # Precompute tie structure once (depends only on time/event order).
    efron_pre = None
    self._breslow_pre = None
    self._breslow_pre_gpu = None
    if self.ties == "efron":
        if entry_sorted is None:
            efron_pre = self._efron_unique_failure_indices(
                cp.asnumpy(time_sorted), cp.asnumpy(event_sorted)
            )
            self._efron_pre = efron_pre
            try:
                _, uft_ix, _, _, nuft, _ = _unpack_efron_pre6(efron_pre)
                self._efron_all_singletons = bool(nuft > 0) and all(
                    len(ix) == 1 for ix in uft_ix
                )
            except Exception:
                self._efron_all_singletons = False
            # Pack enter/exit/fail indices once; reuse across Newton steps on GPU.
            # Best effort: if packing fails, both CSR caches are left as None.
            try:
                from ._cox_efron_cuda import efron_indices_to_csr

                uft, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft = _unpack_efron_pre6(
                    efron_pre
                )
                (
                    enter_ptr,
                    enter_ind,
                    exit_ptr,
                    exit_ind,
                    fail_ptr,
                    fail_ind,
                ) = efron_indices_to_csr(uft_ix, risk_enter, risk_exit, nuft)
                csr_host = (enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind)
                self._efron_pre_csr = csr_host + (first_idx_uft, nuft)
                self._efron_pre_csr_gpu = tuple(
                    cp.asarray(arr, dtype=cp.int32) for arr in csr_host
                ) + (cp.asarray(first_idx_uft, dtype=cp.int32), int(nuft))
            except Exception:
                self._efron_pre_csr = None
                self._efron_pre_csr_gpu = None
        else:
            # Delayed entry: the Efron precomputation is not used.
            self._efron_pre = None
            self._efron_pre_csr = None
            self._efron_pre_csr_gpu = None
    else:
        self._efron_pre = None
        self._efron_all_singletons = False
        self._efron_pre_csr = None
        self._efron_pre_csr_gpu = None
        first_idx_uft, counts_uft = self._breslow_unique_failure_groups(
            cp.asnumpy(time_sorted), cp.asnumpy(event_sorted)
        )
        self._breslow_pre = (first_idx_uft, counts_uft)
        self._breslow_pre_gpu = (
            cp.asarray(first_idx_uft, dtype=cp.int32),
            cp.asarray(counts_uft, dtype=cp.int32),
        )
        self._breslow_counts_f_gpu = cp.asarray(counts_uft, dtype=cp.float64)
        self._breslow_first_idx_np = np.asarray(first_idx_uft, dtype=np.int64)
        self._breslow_counts_np = np.asarray(counts_uft, dtype=np.float64)
        # Drop cached entry index structures so indices from a previous fit
        # (different sort permutation) are never reused.
        self._entry_fail_groups_gpu = None
        self._entry_fail_times_gpu = None
        self._entry_order_gpu = None
        self._entry_add_end_np_gpu = None
        self._entry_rem_end_np_gpu = None

    # Initialize coefficients on GPU (supports warm-start path in CV)
    if init_coef is None:
        beta = cp.zeros(n_features, dtype=cp.float64)
    else:
        beta = cp.asarray(np.asarray(init_coef, dtype=np.float64), dtype=cp.float64).reshape(-1)
        if int(beta.shape[0]) != int(n_features):
            raise ValueError("init_coef must have shape (n_features,)")

    # Build the delayed-entry context once; it is reused by every
    # likelihood / gradient evaluation below.
    entry_ctx_gpu = None
    if entry_sorted is not None:
        _ctx = self._build_entry_ctx_gpu(time_sorted, event_sorted, entry_sorted, cp)
        event_idx_ctx = _ctx[5]
        entry_ctx_gpu = (
            _ctx[0], _ctx[1], _ctx[2], _ctx[3],
            cp.ascontiguousarray(X_sorted[_ctx[0]]),
            cp.ascontiguousarray(X_sorted),
            event_idx_ctx,
            cp.sum(X_sorted[event_idx_ctx], axis=0),
            _ctx[6],
        )

    # Compute null log-likelihood on GPU
    loglik_null_gpu = self._compute_log_likelihood_gpu(
        cp.zeros(n_features, dtype=cp.float64),
        X_sorted,
        time_sorted,
        event_sorted,
        efron_pre,
        entry=entry_sorted,
        entry_ctx=entry_ctx_gpu,
    )

    # Newton-Raphson optimization on GPU with L2 penalty
    penalty = float(self.penalty) if hasattr(self, 'penalty') else 0.0
    use_penalty = penalty > 0.0
    diag_idx = cp.arange(n_features, dtype=cp.int64) if use_penalty else None
    eye_cache = (
        cp.eye(n_features, dtype=cp.float64)
        if (self.compute_inference or use_penalty)
        else None
    )

    loglik_gpu = None
    current_obj = None
    hess = None  # stays None only if the loop never runs (max_iter=0)
    iteration = -1  # default if max_iter=0
    for iteration in range(self.max_iter):
        # Compute gradient and Hessian on GPU
        grad, hess, aux_stats = self._compute_gradient_hessian_gpu(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            return_aux=True, entry=entry_sorted, entry_ctx=entry_ctx_gpu,
        )

        # Add penalty terms: gradient -= 2*penalty*beta, hessian -= 2*penalty*I
        if use_penalty:
            grad = grad - 2 * penalty * beta
            # In-place diagonal shift avoids allocating a new dense eye each iteration.
            hess[diag_idx, diag_idx] -= 2 * penalty

        # Newton: delta = inv(hess) @ grad; hess is NSD — solve (-hess) x = grad, delta = -x
        delta = self._solve_newton_delta_gpu(hess, grad, cp, eye_cache=eye_cache)
        step = 1.0
        accepted_step = True
        if entry_sorted is not None:
            # Objective at the current beta; reused from the previous
            # accepted step when available.
            if current_obj is None:
                old_ll = self._compute_log_likelihood_gpu_from_stats(
                    aux_stats[0], aux_stats[1], aux_stats[2],
                    time_sorted, event_sorted, efron_pre,
                    entry=entry_sorted, entry_ctx=entry_ctx_gpu,
                )
                if use_penalty:
                    old_ll = old_ll - penalty * cp.sum(beta * beta)
                current_obj = old_ll
            else:
                old_ll = current_obj
            new_beta = beta - delta
            new_ll = self._compute_log_likelihood_gpu(
                new_beta, X_sorted, time_sorted, event_sorted, efron_pre,
                entry=entry_sorted, entry_ctx=entry_ctx_gpu,
            )
            if use_penalty:
                new_ll = new_ll - penalty * cp.sum(new_beta * new_beta)
            if float((new_ll - old_ll).item()) <= -1e-8:
                # Full step decreased the objective: halve the step up to 20 times.
                step = 0.5
                accepted = False
                for _ in range(20):
                    trial_beta = beta - step * delta
                    trial_ll = self._compute_log_likelihood_gpu(
                        trial_beta, X_sorted, time_sorted, event_sorted, efron_pre,
                        entry=entry_sorted, entry_ctx=entry_ctx_gpu,
                    )
                    if use_penalty:
                        trial_ll = trial_ll - penalty * cp.sum(trial_beta * trial_beta)
                    if float((trial_ll - old_ll).item()) > -1e-8:
                        beta = trial_beta
                        current_obj = trial_ll
                        accepted = True
                        break
                    step *= 0.5
                if not accepted:
                    accepted_step = False
            else:
                beta = new_beta
                current_obj = new_ll
        else:
            beta = beta - delta

        # Check convergence on GPU
        if entry_sorted is not None:
            delta_norm = float(cp.linalg.norm(delta).item())
            if accepted_step and delta_norm * step < self.tol:
                self._converged = True
                loglik_gpu = self._compute_log_likelihood_gpu(
                    beta, X_sorted, time_sorted, event_sorted, efron_pre,
                    entry=entry_sorted, entry_ctx=entry_ctx_gpu,
                )
                break
        else:
            grad_norm = float(cp.linalg.norm(grad).item())
            delta_norm = float(cp.linalg.norm(delta).item())
            if accepted_step and grad_norm < max(self.tol * 10.0, 1e-8) and delta_norm * step < self.tol:
                self._converged = True
                # Reuse current iteration statistics to avoid an extra
                # Efron log-likelihood setup pass when converged.
                eta_cur, exp_eta_cur, risk_sum_cur = aux_stats
                loglik_gpu = self._compute_log_likelihood_gpu_from_stats(
                    eta_cur, exp_eta_cur, risk_sum_cur,
                    time_sorted, event_sorted, efron_pre, entry=entry_sorted,
                )
                break

    # Compute final log-likelihood on GPU unless already obtained on convergence.
    if loglik_gpu is None:
        loglik_gpu = self._compute_log_likelihood_gpu(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            entry=entry_sorted, entry_ctx=entry_ctx_gpu,
        )

    # With max_iter=0 the loop never produced a Hessian; evaluate it at the
    # starting point so inference does not fail on an unbound name.
    if self.compute_inference and hess is None:
        _, hess, _ = self._compute_gradient_hessian_gpu(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            return_aux=True, entry=entry_sorted, entry_ctx=entry_ctx_gpu,
        )
        if use_penalty:
            hess[diag_idx, diag_idx] -= 2 * penalty

    # Single transfer at the end
    self._iterations = iteration + 1
    self.coef_ = cp.asnumpy(beta)
    self.hazard_ratios_ = np.exp(self.coef_)
    self._log_likelihood_null = float(cp.asnumpy(loglik_null_gpu))
    self._log_likelihood = float(cp.asnumpy(loglik_gpu))
    if self.compute_cindex:
        cindex_gpu = self._compute_cindex_gpu(X_sorted, time_sorted, event_sorted, beta)
        self._cindex = float(cp.asnumpy(cindex_gpu))
    else:
        self._cindex = None

    if self.compute_inference:
        # Inverse of the (penalized) observed information; it is the model
        # covariance for "nonrobust" and the sandwich "bread" otherwise.
        try:
            info = -hess
            rhs_eye = eye_cache if eye_cache is not None else cp.eye(info.shape[0], dtype=info.dtype)
            info_inv = cp.linalg.solve(info, rhs_eye)
        except Exception:
            info_inv = cp.linalg.pinv(-hess)

        robust = self.cov_type != "nonrobust"
        if not robust:
            var_gpu = info_inv
        else:
            score_resid_gpu = self._compute_robust_score_residuals_gpu(X_sorted, time_sorted, event_sorted)
            if self.cov_type == "cluster":
                if cluster_sorted is None:
                    raise ValueError("cov_type='cluster' requires cluster ids in fit(..., cluster=...)")
                # Sum score residuals within each cluster before the outer product.
                meat = cp.zeros((n_features, n_features), dtype=cp.float64)
                for g in cp.unique(cluster_sorted):
                    u_g = cp.sum(score_resid_gpu[cluster_sorted == g], axis=0)
                    meat += cp.outer(u_g, u_g)
            else:
                meat = score_resid_gpu.T @ score_resid_gpu
                if self.cov_type == "hc1":
                    # Small-sample degrees-of-freedom correction.
                    n, k = X_sorted.shape
                    if n > k:
                        meat = meat * (n / (n - k))
            var_gpu = info_inv @ meat @ info_inv

        bse_gpu = cp.sqrt(cp.maximum(cp.diag(var_gpu), 0.0))
        z_gpu = beta / (bse_gpu + 1e-30)
        p_gpu = cp.minimum(1.0, 2.0 * norm.sf(cp.abs(z_gpu)))
        z_crit = norm.ppf(0.975)
        ci_gpu = cp.stack([beta - z_crit * bse_gpu, beta + z_crit * bse_gpu], axis=1)

        # Keep the full covariance matrix (not only its diagonal) so the
        # Wald statistic below accounts for correlation between coefficients.
        self._var_matrix = cp.asnumpy(var_gpu)
        self._bse = cp.asnumpy(bse_gpu)
        self._zvalues = cp.asnumpy(z_gpu)
        self._pvalues = cp.asnumpy(p_gpu)
        self._conf_int = cp.asnumpy(ci_gpu)
        self._lr_test_stat = 2 * (self._log_likelihood - self._log_likelihood_null)
        # Survival function avoids the cancellation in 1 - cdf for large statistics.
        self._lr_test_pvalue = stats.chi2.sf(self._lr_test_stat, n_features)
        try:
            var_inv = np.linalg.solve(self._var_matrix, np.eye(self._var_matrix.shape[0]))
            self._wald_test_stat = self.coef_ @ var_inv @ self.coef_
        except np.linalg.LinAlgError:
            self._wald_test_stat = np.nan
        self._wald_test_pvalue = stats.chi2.sf(self._wald_test_stat, n_features)
        self._score_test_stat = np.nan
        self._score_test_pvalue = np.nan

        if robust:
            # Compute baseline hazard on GPU
            # NOTE(review): assumed to belong to the robust branch only, since the
            # nonrobust branch explicitly leaves the baseline unset — confirm.
            self._compute_baseline_hazard_gpu(X_sorted, time_sorted, event_sorted, beta)
        else:
            # Keep baseline hazard optional in CUDA fast path to reduce transfer overhead.
            self._baseline_hazard = None
            self._baseline_cumulative_hazard = None
            self._unique_times = None
    else:
        for attr in (
            "_var_matrix",
            "_bse",
            "_zvalues",
            "_pvalues",
            "_conf_int",
            "_score_test_stat",
            "_score_test_pvalue",
            "_wald_test_stat",
            "_wald_test_pvalue",
            "_lr_test_stat",
            "_lr_test_pvalue",
            "_baseline_hazard",
            "_baseline_cumulative_hazard",
            "_unique_times",
        ):
            setattr(self, attr, None)
|
|
1037
|
+
|
|
1038
|
+
def _fit_torch(self, X, time, event, entry=None, cluster=None, torch_device="cuda", init_coef=None):
    """Fit the Cox model with PyTorch, running Newton-Raphson on ``torch_device``.

    Observations are sorted by ascending time, tie-handling index structures
    for ``self.ties`` are precomputed once, and the coefficients are updated
    with (optionally L2-penalized) Newton steps. With delayed entry, each
    step is guarded by a step-halving line search on the penalized
    log-likelihood.

    Parameters
    ----------
    X : tensor of shape (n_samples, n_features)
    time, event : tensors
        Indexed with the sort permutation; no conversion is done here.
    entry : tensor, optional
        Delayed-entry times; enables the entry-aware code paths.
    cluster : tensor, optional
        Cluster ids, forwarded to the CPU inference path for robust
        covariance types.
    torch_device : str
        Device on which new tensors are allocated.
    init_coef : array-like, optional
        Warm-start coefficients; must flatten to ``n_features`` values.

    Raises
    ------
    ValueError
        If ``init_coef`` does not have ``n_features`` values.

    Notes
    -----
    Results are written to attributes on ``self``; nothing is returned.
    """
    import torch
    from statgpu.inference._distributions_backend import norm

    _, n_features = X.shape

    # Sort by time ascending so risk-set terms are suffix sums
    order = torch.argsort(time, stable=True)
    X_sorted = X[order]
    time_sorted = time[order]
    event_sorted = event[order]
    entry_sorted = None if entry is None else entry[order]
    cluster_sorted = None if cluster is None else cluster[order]

    # Precompute tie structure once (depends only on time/event order)
    efron_pre = None
    self._breslow_pre = None
    self._breslow_pre_torch = None
    if self.ties == "efron":
        if entry_sorted is None:
            efron_pre = self._efron_unique_failure_indices(
                time_sorted.cpu().numpy(), event_sorted.cpu().numpy()
            )
            self._efron_pre = efron_pre
            try:
                _, uft_ix, _, _, nuft, _ = _unpack_efron_pre6(efron_pre)
                self._efron_all_singletons = bool(nuft > 0) and all(
                    len(ix) == 1 for ix in uft_ix
                )
            except Exception:
                self._efron_all_singletons = False
            # Reuse CUDA CSR packing for Torch-CUDA fused kernels when available.
            # Best effort: without CuPy (or on any failure) both caches stay None.
            try:
                import cupy as cp
                from ._cox_efron_cuda import efron_indices_to_csr

                uft, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft = _unpack_efron_pre6(
                    efron_pre
                )
                (
                    enter_ptr,
                    enter_ind,
                    exit_ptr,
                    exit_ind,
                    fail_ptr,
                    fail_ind,
                ) = efron_indices_to_csr(uft_ix, risk_enter, risk_exit, nuft)
                csr_host = (enter_ptr, enter_ind, exit_ptr, exit_ind, fail_ptr, fail_ind)
                self._efron_pre_csr = csr_host + (first_idx_uft, nuft)
                self._efron_pre_csr_gpu = tuple(
                    cp.asarray(arr, dtype=cp.int32) for arr in csr_host
                ) + (cp.asarray(first_idx_uft, dtype=cp.int32), int(nuft))
            except Exception:
                self._efron_pre_csr = None
                self._efron_pre_csr_gpu = None
        else:
            # Delayed entry: the Efron precomputation is not used.
            self._efron_pre = None
            self._efron_pre_csr = None
            self._efron_pre_csr_gpu = None
    else:
        self._efron_pre = None
        self._efron_all_singletons = False
        self._efron_pre_csr = None
        self._efron_pre_csr_gpu = None
        first_idx_uft, counts_uft = self._breslow_unique_failure_groups(
            time_sorted.cpu().numpy(), event_sorted.cpu().numpy()
        )
        self._breslow_pre = (first_idx_uft, counts_uft)
        self._breslow_pre_torch = (
            torch.tensor(first_idx_uft, dtype=torch.int32, device=torch_device),
            torch.tensor(counts_uft, dtype=torch.int32, device=torch_device),
        )
        # Drop cached entry index structures so indices from a previous fit
        # (different sort permutation) are never reused.
        self._entry_fail_groups_torch = None
        self._entry_fail_times_torch = None
        self._entry_order_torch = None
        self._entry_add_end_np_torch = None
        self._entry_rem_end_np_torch = None

    # Initialize coefficients on Torch device (supports warm-start path in CV)
    if init_coef is None:
        beta = torch.zeros(n_features, dtype=torch.float64, device=torch_device)
    else:
        beta = torch.as_tensor(init_coef, dtype=torch.float64, device=torch_device).reshape(-1)
        if int(beta.shape[0]) != int(n_features):
            raise ValueError("init_coef must have shape (n_features,)")

    # Build the delayed-entry context once; it is reused by every
    # likelihood / gradient evaluation below.
    entry_ctx_torch = None
    if entry_sorted is not None:
        _ctx = self._build_entry_ctx_torch(time_sorted, event_sorted, entry_sorted, torch_device)
        event_idx_ctx = _ctx[5]
        entry_ctx_torch = (
            _ctx[0],
            _ctx[1],
            _ctx[2],
            _ctx[3],
            X_sorted.index_select(0, _ctx[0]).contiguous(),
            X_sorted.contiguous(),
            event_idx_ctx,
            torch.sum(X_sorted.index_select(0, event_idx_ctx), dim=0),
            _ctx[6],
        )

    # Compute null log-likelihood on Torch
    loglik_null_torch = self._compute_log_likelihood_torch(
        torch.zeros(n_features, dtype=torch.float64, device=torch_device),
        X_sorted,
        time_sorted,
        event_sorted,
        efron_pre,
        entry=entry_sorted,
        entry_ctx=entry_ctx_torch,
    )

    # Newton-Raphson optimization on Torch with L2 penalty
    penalty = float(self.penalty) if hasattr(self, 'penalty') else 0.0
    use_penalty = penalty > 0.0
    diag_idx = torch.arange(n_features, dtype=torch.long, device=torch_device) if use_penalty else None

    loglik_torch = None
    current_obj = None
    hess = None  # stays None only if the loop never runs (max_iter=0)
    iteration = -1  # so that _iterations is 0 when max_iter=0
    for iteration in range(self.max_iter):
        # Compute gradient and Hessian on Torch
        grad, hess, aux_stats = self._compute_gradient_hessian_torch(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            return_aux=True, entry=entry_sorted, entry_ctx=entry_ctx_torch,
        )

        # Add penalty terms: gradient -= 2*penalty*beta, hessian -= 2*penalty*I
        if use_penalty:
            grad = grad - 2 * penalty * beta
            hess[diag_idx, diag_idx] -= 2 * penalty

        # Newton: delta = inv(hess) @ grad; hess is NSD — solve (-hess) x = grad, delta = -x
        delta = self._solve_newton_delta_torch(hess, grad)
        step = 1.0
        accepted_step = True
        if entry_sorted is not None:
            # Objective at the current beta; reused from the previous
            # accepted step when available.
            if current_obj is None:
                old_ll = self._compute_log_likelihood_torch_from_stats(
                    aux_stats[0], aux_stats[1], aux_stats[2],
                    time_sorted, event_sorted, efron_pre,
                    entry=entry_sorted, entry_ctx=entry_ctx_torch,
                )
                if use_penalty:
                    old_ll = old_ll - penalty * torch.sum(beta * beta)
                current_obj = old_ll
            else:
                old_ll = current_obj
            new_beta = beta - delta
            new_ll = self._compute_log_likelihood_torch(
                new_beta, X_sorted, time_sorted, event_sorted, efron_pre,
                entry=entry_sorted, entry_ctx=entry_ctx_torch,
            )
            if use_penalty:
                new_ll = new_ll - penalty * torch.sum(new_beta * new_beta)
            if float((new_ll - old_ll).item()) <= -1e-8:
                # Full step decreased the objective: halve the step up to 20 times.
                step = 0.5
                accepted = False
                for _ in range(20):
                    trial_beta = beta - step * delta
                    trial_ll = self._compute_log_likelihood_torch(
                        trial_beta, X_sorted, time_sorted, event_sorted, efron_pre,
                        entry=entry_sorted, entry_ctx=entry_ctx_torch,
                    )
                    if use_penalty:
                        trial_ll = trial_ll - penalty * torch.sum(trial_beta * trial_beta)
                    if float((trial_ll - old_ll).item()) > -1e-8:
                        beta = trial_beta
                        current_obj = trial_ll
                        accepted = True
                        break
                    step *= 0.5
                if not accepted:
                    accepted_step = False
            else:
                beta = new_beta
                current_obj = new_ll
        else:
            beta = beta - delta

        # Check convergence
        if entry_sorted is not None:
            delta_norm = float(torch.linalg.norm(delta).item())
            if accepted_step and delta_norm * step < self.tol:
                self._converged = True
                loglik_torch = self._compute_log_likelihood_torch(
                    beta, X_sorted, time_sorted, event_sorted, efron_pre,
                    entry=entry_sorted, entry_ctx=entry_ctx_torch,
                )
                break
        else:
            grad_norm = float(torch.linalg.norm(grad).item())
            delta_norm = float(torch.linalg.norm(delta).item())
            if accepted_step and grad_norm < max(self.tol * 10.0, 1e-8) and delta_norm * step < self.tol:
                self._converged = True
                # Reuse this iteration's statistics instead of a fresh likelihood pass.
                eta_cur, exp_eta_cur, risk_sum_cur = aux_stats
                loglik_torch = self._compute_log_likelihood_torch_from_stats(
                    eta_cur, exp_eta_cur, risk_sum_cur,
                    time_sorted, event_sorted, efron_pre, entry=entry_sorted,
                )
                break

    # Compute final log-likelihood on Torch unless already obtained.
    if loglik_torch is None:
        loglik_torch = self._compute_log_likelihood_torch(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            entry=entry_sorted, entry_ctx=entry_ctx_torch,
        )

    # With max_iter=0 the loop never produced a Hessian; evaluate it at the
    # starting point so nonrobust inference does not fail on an unbound name.
    if self.compute_inference and hess is None:
        _, hess, _ = self._compute_gradient_hessian_torch(
            beta, X_sorted, time_sorted, event_sorted, efron_pre,
            return_aux=True, entry=entry_sorted, entry_ctx=entry_ctx_torch,
        )
        if use_penalty:
            hess[diag_idx, diag_idx] -= 2 * penalty

    # Single transfer at the end
    self._iterations = iteration + 1
    self.coef_ = beta.cpu().numpy()
    self.hazard_ratios_ = np.exp(self.coef_)
    self._log_likelihood_null = float(loglik_null_torch.item())
    self._log_likelihood = float(loglik_torch.item())
    if self.compute_cindex:
        cindex_torch = self._compute_cindex_torch(X_sorted, time_sorted, event_sorted, beta)
        self._cindex = float(cindex_torch.item())
    else:
        self._cindex = None

    # Inference: nonrobust on Torch, other types fall back to CPU
    if self.compute_inference:
        if self.cov_type == "nonrobust":
            try:
                info = -hess
                var_torch = torch.linalg.solve(
                    info, torch.eye(info.shape[0], dtype=info.dtype, device=torch_device)
                )
            except Exception:
                var_torch = torch.linalg.pinv(-hess)
            bse_torch = torch.sqrt(
                torch.maximum(
                    torch.diag(var_torch),
                    torch.tensor(0.0, dtype=torch.float64, device=torch_device),
                )
            )
            z_torch = beta / (bse_torch + 1e-30)
            p_torch = torch.minimum(
                torch.tensor(1.0, device=torch_device), 2.0 * norm.sf(torch.abs(z_torch))
            )
            z_crit = norm.ppf(0.975)
            ci_torch = torch.stack([beta - z_crit * bse_torch, beta + z_crit * bse_torch], dim=1)

            self._bse = bse_torch.cpu().numpy()
            self._zvalues = z_torch.cpu().numpy()
            self._pvalues = p_torch.cpu().numpy()
            self._conf_int = ci_torch.cpu().numpy()
            # Keep the full covariance matrix (not only its diagonal) so the
            # Wald statistic below accounts for correlation between coefficients.
            self._var_matrix = var_torch.cpu().numpy()
            self._lr_test_stat = 2 * (self._log_likelihood - self._log_likelihood_null)
            # Survival function avoids the cancellation in 1 - cdf for large statistics.
            self._lr_test_pvalue = stats.chi2.sf(self._lr_test_stat, n_features)
            try:
                var_inv = np.linalg.solve(self._var_matrix, np.eye(self._var_matrix.shape[0]))
                self._wald_test_stat = self.coef_ @ var_inv @ self.coef_
            except np.linalg.LinAlgError:
                self._wald_test_stat = np.nan
            self._wald_test_pvalue = stats.chi2.sf(self._wald_test_stat, n_features)
            self._score_test_stat = np.nan
            self._score_test_pvalue = np.nan
            # Compute baseline hazard on Torch
            self._compute_baseline_hazard_torch(X_sorted, time_sorted, event_sorted, beta)
        else:
            # For hc0/hc1/cluster, use CPU inference path
            self._compute_inference_cpu(
                X_sorted.cpu().numpy(),
                time_sorted.cpu().numpy(),
                event_sorted.cpu().numpy(),
                cluster_sorted.cpu().numpy() if cluster_sorted is not None else None,
            )
            self._baseline_hazard = None
            self._baseline_cumulative_hazard = None
            self._unique_times = None
    else:
        for attr in (
            "_var_matrix",
            "_bse",
            "_zvalues",
            "_pvalues",
            "_conf_int",
            "_score_test_stat",
            "_score_test_pvalue",
            "_wald_test_stat",
            "_wald_test_pvalue",
            "_lr_test_stat",
            "_lr_test_pvalue",
            "_baseline_hazard",
            "_baseline_cumulative_hazard",
            "_unique_times",
        ):
            setattr(self, attr, None)
    self._cleanup_torch_memory()
|
|
1339
|
+
|
|
1340
|
+
def _compute_log_likelihood(self, beta, X, time, event, efron_pre=None, entry=None):
    """Compute the Cox log partial likelihood on the CPU (Breslow or Efron ties).

    Parameters
    ----------
    beta : ndarray
        Coefficient vector; the linear predictor is ``X @ beta``.
    X : ndarray
        Design matrix, one row per observation, in the same order as ``time``.
    time : ndarray
        Observation times. Risk sets are located with suffix sums and
        ``np.searchsorted(time, ...)``, so ``time`` must be sorted ascending.
    event : ndarray
        Event indicator; entries equal to 1 are counted as failures.
    efron_pre : optional
        Precomputed Efron bookkeeping (unpacked with ``_unpack_efron_pre6``).
        Only consulted for Efron ties without ``entry``.
    entry : ndarray, optional
        Delayed-entry times. When given, a row is in the risk set of failure
        time ``t`` only if ``entry < t <= time``.

    Returns
    -------
    float
        The log partial likelihood; ``0.0`` when there are no events.
    """
    eta = X @ beta
    event_mask = event == 1
    if not np.any(event_mask):
        return 0.0

    use_breslow = self.ties == "breslow"

    if entry is not None:
        # Shift eta by its maximum before exponentiating. The shift cancels in
        # the partial likelihood (every failure contributes +eta_i and one
        # log-denominator), so the returned value is unchanged while exp()
        # cannot overflow.
        eta_c = eta - np.max(eta)
        w = np.exp(eta_c)

        fail_groups = getattr(self, "_entry_fail_groups_np", None)
        add_end_np = getattr(self, "_entry_add_end_np", None)
        rem_end_np = getattr(self, "_entry_rem_end_np", None)
        order_np = getattr(self, "_entry_order_np", None)
        if (
            fail_groups is None
            or add_end_np is None
            or rem_end_np is None
            or order_np is None
        ):
            # No cached bookkeeping: rebuild it from the arrays passed in.
            event_idx = np.flatnonzero(event_mask)
            uft_np, inv_np, tie_sizes = np.unique(
                time[event_idx], return_inverse=True, return_counts=True
            )
            # One stable sort + split groups the event rows by failure time
            # without scanning all events once per group.
            grouped = event_idx[np.argsort(inv_np, kind="stable")].astype(np.int64, copy=False)
            fail_groups = np.split(grouped, np.cumsum(tie_sizes)[:-1]) if uft_np.size else []
            entry_np = np.asarray(entry, dtype=np.float64)
            order_np = np.argsort(entry_np).astype(np.int64, copy=False)
            add_end_np = np.searchsorted(
                entry_np[order_np], uft_np, side="left"
            ).astype(np.int64, copy=False)
            rem_end_np = np.searchsorted(time, uft_np, side="left").astype(np.int64, copy=False)

        # s0 is the running sum of weights over the current risk set: rows are
        # added once their entry time is passed and removed once their own
        # time is passed.
        s0 = 0.0
        add_ptr = 0
        rem_ptr = 0
        ll = 0.0
        for g, fail_idx in enumerate(fail_groups):
            add_end = int(add_end_np[g])
            if add_end > add_ptr:
                s0 += float(np.sum(w[order_np[add_ptr:add_end]]))
                add_ptr = add_end
            rem_end = int(rem_end_np[g])
            if rem_end > rem_ptr:
                s0 -= float(np.sum(w[rem_ptr:rem_end]))
                rem_ptr = rem_end
            d_t = int(fail_idx.shape[0])
            if d_t <= 0:
                continue
            if use_breslow:
                log_term = d_t * np.log(max(s0, 1e-300))
            else:
                # Efron: the k-th tied failure sees the risk set with a
                # fraction k/d of the tied failures' weight removed.
                tied = float(np.sum(w[fail_idx]))
                steps = np.arange(d_t, dtype=np.float64) / d_t
                log_term = float(np.sum(np.log(np.maximum(s0 - steps * tied, 1e-300))))
            ll += float(np.sum(eta_c[fail_idx]) - log_term)
        return float(ll)

    # Standard (no-entry) path: risk-set sums are suffix sums over sorted rows.
    exp_eta = np.exp(eta)
    risk_sum = np.cumsum(exp_eta[::-1])[::-1]

    if use_breslow:
        # l(beta) = sum_i eta_i - sum_t d_t * log(S0(t))
        breslow_pre = getattr(self, "_breslow_pre", None)
        if (
            breslow_pre is not None
            and len(breslow_pre) == 2
            and breslow_pre[0].size > 0
        ):
            first_idx = breslow_pre[0].astype(np.int64, copy=False)
            counts = breslow_pre[1].astype(np.float64, copy=False)
        else:
            uft, counts_i = np.unique(time[event_mask], return_counts=True)
            first_idx = np.searchsorted(time, uft, side="left").astype(np.int64)
            counts = counts_i.astype(np.float64)
        return float(np.sum(eta[event_mask]) - np.sum(counts * np.log(risk_sum[first_idx])))

    # ---- Efron ----
    if efron_pre is not None:
        uft, uft_ix, _, _, nuft, first_idx_uft = _unpack_efron_pre6(efron_pre)
        eta_total = 0.0
        log_denom_total = 0.0
        for g in range(nuft):
            ix_ev = uft_ix[g]
            d = len(ix_ev)
            if d == 0:
                continue
            first_idx = (
                int(first_idx_uft[g])
                if first_idx_uft is not None
                else int(np.searchsorted(time, uft[g], side="left"))
            )
            tied = float(np.sum(exp_eta[ix_ev]))
            eta_total += float(np.sum(eta[ix_ev]))
            denom = risk_sum[first_idx] - (np.arange(d, dtype=np.float64) / d) * tied
            log_denom_total += float(np.sum(np.log(np.maximum(denom, 1e-300))))
        return float(eta_total - log_denom_total)

    # No precomputation: group event rows by unique failure time.
    event_idx = np.flatnonzero(event_mask)
    uft, inv, counts = np.unique(time[event_idx], return_inverse=True, return_counts=True)
    first_idx = np.searchsorted(time, uft, side="left").astype(np.int64)
    risk_at = risk_sum[first_idx]
    sum_events = np.bincount(inv, weights=exp_eta[event_idx], minlength=len(uft)).astype(np.float64)

    ll = float(np.sum(eta[event_idx]))
    for g in range(len(uft)):
        d = int(counts[g])
        if d == 0:
            continue
        k = np.arange(d, dtype=np.float64) / d
        denom = risk_at[g] - k * sum_events[g]
        ll -= float(np.sum(np.log(np.maximum(denom, 1e-300))))
    return float(ll)
|
|
1493
|
+
|
|
1494
|
+
def _solve_newton_delta_gpu(self, hess, grad, cp, eye_cache=None):
    """Return the Newton step ``inv(hess) @ grad`` using the array module ``cp``.

    The negated Hessian plus a small diagonal jitter is solved first through a
    Cholesky factorisation; if that fails or yields a non-finite step, a
    generic solve of the same system is used. If that also raises, the
    un-jittered system is solved directly, and finally by least squares.

    Parameters
    ----------
    hess : array, shape (p, p)
        Hessian of the log partial likelihood.
    grad : array, shape (p,)
        Gradient of the log partial likelihood.
    cp : module
        Array module providing ``linalg``, ``eye``, ``diag`` etc. (the callers
        in this file pass CuPy).
    eye_cache : array, optional
        Pre-allocated ``(p, p)`` identity reused for the jitter term.

    Returns
    -------
    array, shape (p,)
        The step ``delta``.
    """
    p = int(hess.shape[0])
    try:
        H = -hess
        # Jitter scaled by the largest diagonal magnitude of H.
        eps = 1e-11 * (cp.max(cp.abs(cp.diag(H))) + 1.0)
        jitter_eye = eye_cache if eye_cache is not None else cp.eye(p, dtype=cp.float64)
        H = H + eps * jitter_eye

        x = None
        try:
            L = cp.linalg.cholesky(H)
            y = cp.linalg.solve(L, grad)
            x = cp.linalg.solve(L.T, y)
        except Exception:
            x = None
        # A Cholesky routine that does not raise on a non-positive-definite
        # matrix can hand back NaN factors; only accept a finite step from the
        # fast path and otherwise retry with the generic solver.
        if x is not None and bool(cp.all(cp.isfinite(x))):
            return -x
        return -cp.linalg.solve(H, grad)
    except Exception:
        try:
            return cp.linalg.solve(hess, grad)
        except Exception:
            return cp.linalg.lstsq(hess, grad, rcond=None)[0].flatten()
|
|
1515
|
+
|
|
1516
|
+
def _compute_log_likelihood_gpu(self, beta, X, time, event, efron_pre=None, entry=None, entry_ctx=None):
    """Evaluate the log partial likelihood with CuPy.

    Forms the linear predictor and its exponential, then delegates to
    ``_compute_log_likelihood_gpu_from_stats``. The reversed cumulative sum of
    the exponentials is only built when ``entry`` is None; with delayed entry
    ``None`` is passed in its place, which saves a cumsum on every evaluation.
    """
    import cupy as cp

    linear_pred = X @ beta
    hazard_weights = cp.exp(linear_pred)
    if entry is None:
        suffix_risk = cp.cumsum(hazard_weights[::-1])[::-1]
    else:
        suffix_risk = None
    return self._compute_log_likelihood_gpu_from_stats(
        linear_pred,
        hazard_weights,
        suffix_risk,
        time,
        event,
        efron_pre,
        entry=entry,
        entry_ctx=entry_ctx,
    )
|
|
1528
|
+
|
|
1529
|
+
def _build_entry_ctx_gpu(self, time, event, entry, cp):
    """Assemble delayed-entry bookkeeping for one (time, event, entry) view.

    Returns a 7-tuple
    ``(entry_order, d_counts, add_end, rem_end, rem_order, event_idx, fail_ptr)``:

    * ``entry_order`` - device argsort of ``entry``.
    * ``d_counts`` - host float64 count of events per unique failure time.
    * ``add_end`` / ``rem_end`` - host int64 positions from ``searchsorted``
      (``side="left"``) of each unique failure time in the sorted entry times
      and in ``time`` respectively.
    * ``rem_order`` - device ``arange`` over the rows.
    * ``event_idx`` - device int64 indices of rows with ``event == 1``.
    * ``fail_ptr`` - host int64 prefix offsets of ``d_counts`` (length groups + 1).

    When there are no events every array is empty except ``fail_ptr``, which
    is a single zero.
    """
    fail_rows = cp.where(event == 1)[0]
    fail_times_host = cp.asnumpy(time[fail_rows])
    if fail_times_host.size == 0:
        empty_dev = (0,)
        return (
            cp.zeros(empty_dev, dtype=cp.int64),
            np.zeros((0,), dtype=np.float64),
            np.zeros((0,), dtype=np.int64),
            np.zeros((0,), dtype=np.int64),
            cp.zeros(empty_dev, dtype=cp.int64),
            cp.zeros(empty_dev, dtype=cp.int64),
            np.zeros((1,), dtype=np.int64),
        )

    uniq_fail_times, tie_sizes = np.unique(fail_times_host, return_counts=True)
    tie_sizes = tie_sizes.astype(np.float64, copy=False)

    by_entry = cp.argsort(entry)
    sorted_entry_host = cp.asnumpy(entry[by_entry])
    time_host = cp.asnumpy(time)

    # Host-side binary searches: how many rows have entered / left strictly
    # before each unique failure time.
    entered_before = np.searchsorted(
        sorted_entry_host, uniq_fail_times, side="left"
    ).astype(np.int64, copy=False)
    left_before = np.searchsorted(
        time_host, uniq_fail_times, side="left"
    ).astype(np.int64, copy=False)

    row_ids = cp.arange(int(time.shape[0]), dtype=cp.int64)
    fail_rows = fail_rows.astype(cp.int64, copy=False)

    group_ptr = np.empty(tie_sizes.shape[0] + 1, dtype=np.int64)
    group_ptr[0] = 0
    group_ptr[1:] = np.cumsum(tie_sizes.astype(np.int64), dtype=np.int64)

    return (
        by_entry,
        tie_sizes,
        entered_before,
        left_before,
        row_ids,
        fail_rows,
        group_ptr,
    )
|
|
1557
|
+
|
|
1558
|
+
def _compute_log_likelihood_gpu_from_stats(
    self, eta, exp_eta, risk_sum, time, event, efron_pre=None, entry=None, entry_ctx=None
):
    """Compute the log partial likelihood on GPU from precomputed statistics.

    Parameters
    ----------
    eta, exp_eta : cupy.ndarray
        Linear predictor and its exponential, one value per row.
    risk_sum : cupy.ndarray or None
        Reversed cumulative sum of ``exp_eta``; not used when ``entry`` is given.
    time, event : cupy.ndarray
        Observation times and event indicator (1 marks a failure). Risk-set
        starts are found with ``searchsorted(time, ...)``, which relies on
        ``time`` being sorted ascending.
    efron_pre : optional
        Precomputed Efron bookkeeping consumed by ``_unpack_efron_pre6`` and
        the ``_cox_efron_cuda`` helpers.
    entry : cupy.ndarray, optional
        Delayed-entry times; selects the entry-aware branch.
    entry_ctx : tuple, optional
        Prebuilt entry bookkeeping; built with ``_build_entry_ctx_gpu`` when
        omitted.

    Returns
    -------
    cupy.ndarray
        0-d array with the log partial likelihood (0.0 when there are no events).
    """
    import cupy as cp

    ll = cp.array(0.0, dtype=cp.float64)
    event_mask = event == 1

    if not cp.any(event_mask):
        return ll

    if entry is not None:
        if entry_ctx is None:
            (
                entry_order,
                d_counts,
                add_end_np,
                rem_end_np,
                _rem_order,
                event_idx,
                fail_ptr,
            ) = self._build_entry_ctx_gpu(time, event, entry, cp)
        else:
            # NOTE(review): a caller-supplied context is read at indices 6 and
            # 8, whereas `_build_entry_ctx_gpu` returns a 7-tuple with the
            # event indices at 5 and the offsets at 6. Presumably callers pass
            # a longer tuple built elsewhere - confirm before passing the
            # builder's output here directly.
            entry_order, d_counts, add_end_np, rem_end_np = entry_ctx[:4]
            event_idx = entry_ctx[6] if len(entry_ctx) > 6 else cp.where(event_mask)[0]
            fail_ptr = entry_ctx[8] if len(entry_ctx) > 8 else None
        n_groups = int(d_counts.shape[0])
        if n_groups == 0:
            return cp.array(0.0, dtype=cp.float64)
        if fail_ptr is None:
            fail_ptr = np.empty(n_groups + 1, dtype=np.int64)
            fail_ptr[0] = 0
            fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)

        # Risk-set mass per failure time = (mass entered so far) - (mass that
        # already left), both read from prefix sums.
        add_pref = cp.cumsum(exp_eta[entry_order], axis=0)
        rem_pref = cp.cumsum(exp_eta, axis=0)
        s0_add = cp.zeros(n_groups, dtype=cp.float64)
        s0_rem = cp.zeros(n_groups, dtype=cp.float64)
        mask_add = add_end_np > 0
        mask_rem = rem_end_np > 0
        if np.any(mask_add):
            idx_add = cp.asarray(add_end_np[mask_add] - 1, dtype=cp.int64)
            s0_add[cp.asarray(mask_add)] = add_pref[idx_add]
        if np.any(mask_rem):
            idx_rem = cp.asarray(rem_end_np[mask_rem] - 1, dtype=cp.int64)
            s0_rem[cp.asarray(mask_rem)] = rem_pref[idx_rem]
        s0_vec = cp.maximum(s0_add - s0_rem, 1e-300)
        event_eta = eta[event_idx]

        if self.ties == "breslow":
            d_vec = cp.asarray(d_counts, dtype=cp.float64)
            return cp.sum(event_eta) - cp.sum(d_vec * cp.log(s0_vec))

        # Efron with delayed entry: one vectorised reduction per tie group
        # instead of one tiny device op per tied event.
        ll = cp.sum(event_eta)
        event_exp = exp_eta[event_idx]
        for g in range(n_groups):
            d = int(d_counts[g])
            if d <= 0:
                continue
            tied_sum = cp.sum(event_exp[int(fail_ptr[g]):int(fail_ptr[g + 1])])
            fractions = cp.arange(d, dtype=cp.float64) / float(d)
            denom = cp.maximum(s0_vec[g] - fractions * tied_sum, 1e-300)
            ll = ll - cp.sum(cp.log(denom))
        return ll

    if self.ties == 'breslow':
        # Vectorised Breslow using cached failure groups when available, to
        # avoid Python loops and host-device syncs in the hot path.
        breslow_pre_gpu = getattr(self, "_breslow_pre_gpu", None)
        if (
            breslow_pre_gpu is not None
            and len(breslow_pre_gpu) == 2
            and int(breslow_pre_gpu[0].size) > 0
        ):
            first_idx_uft, counts_uft = breslow_pre_gpu
        else:
            uft, counts_uft = cp.unique(time[event_mask], return_counts=True)
            first_idx_uft = cp.searchsorted(time, uft, side="left")
            counts_uft = counts_uft.astype(cp.int32, copy=False)
        risk_at = risk_sum[first_idx_uft]
        return cp.sum(eta[event_mask]) - cp.sum(
            counts_uft.astype(cp.float64) * cp.log(risk_at)
        )

    # Efron: if every failure time has a single failure, Efron == Breslow.
    if getattr(self, "_efron_all_singletons", False):
        ep = efron_pre if efron_pre is not None else getattr(self, "_efron_pre", None)
        if ep is not None:
            uft_pre, _, _, _, nuft, first_idx_uft = _unpack_efron_pre6(ep)
            if first_idx_uft is None:
                # The CPU path treats a missing first-index array as valid and
                # falls back to a search on the failure times; do the same
                # here instead of failing inside cp.asarray(None).
                first_idx_uft = cp.searchsorted(time, cp.asarray(uft_pre), side="left")
            else:
                first_idx_uft = cp.asarray(first_idx_uft, dtype=cp.int32)
            counts_uft = cp.ones(int(nuft), dtype=cp.int32)
        else:
            uft, counts_uft = cp.unique(time[event_mask], return_counts=True)
            first_idx_uft = cp.searchsorted(time, uft, side="left")
            counts_uft = counts_uft.astype(cp.int32, copy=False)
        risk_at = risk_sum[first_idx_uft]
        return cp.sum(eta[event_mask]) - cp.sum(
            counts_uft.astype(cp.float64) * cp.log(risk_at)
        )

    # Efron with cached failure groups (kernels live in `_cox_efron_cuda`).
    if efron_pre is not None:
        try:
            csr_gpu = getattr(self, "_efron_pre_csr_gpu", None)
            if csr_gpu is not None:
                from ._cox_efron_cuda import compute_efron_loglik_raw_csr

                _, _, _, _, fail_ptr, fail_ind, first_idx_uft, nuft = csr_gpu
                return compute_efron_loglik_raw_csr(
                    eta,
                    exp_eta,
                    risk_sum,
                    fail_ptr,
                    fail_ind,
                    first_idx_uft,
                    nuft,
                    cupy_module=cp,
                )
        except Exception:
            # Best-effort fast path: any failure falls through to the non-CSR
            # kernel below.
            pass

        from ._cox_efron_cuda import compute_efron_loglik_raw

        return compute_efron_loglik_raw(
            eta, exp_eta, risk_sum, time, efron_pre, cupy_module=cp
        )

    # Efron without any precomputation: loop over unique failure times.
    unique_times = cp.unique(time[event_mask])
    for t in unique_times:
        events_at_t = (time == t) & event_mask
        d = int(cp.sum(events_at_t).item())

        if d == 0:
            continue

        risk_indices = cp.where(time >= t)[0]
        if risk_indices.size == 0:
            continue

        risk_at_t = risk_sum[risk_indices[0]]
        sum_events = cp.sum(exp_eta[events_at_t])

        ll += cp.sum(eta[events_at_t])
        fractions = cp.arange(d, dtype=cp.float64) / float(d)
        ll -= cp.sum(cp.log(cp.maximum(risk_at_t - fractions * sum_events, 1e-300)))

    return ll
|
|
1707
|
+
|
|
1708
|
+
def _compute_gradient_hessian(self, beta, X, time, event, efron_pre=None, entry=None):
    """
    Gradient and Hessian of the log partial likelihood (same sign convention as statsmodels).

    Parameters
    ----------
    beta : ndarray
        Coefficient vector; the linear predictor is ``X @ beta``.
    X : ndarray
        Design matrix, rows in the same order as ``time``.
    time : ndarray
        Observation times; ``np.searchsorted(time, ...)`` is used on it, so it
        must be sorted ascending.
    event : ndarray
        Event indicator; entries equal to 1 are failures.
    efron_pre : optional
        Output of `_efron_unique_failure_indices`. Pass the cached structure
        from `fit` to avoid O(n) Python work every Newton step.
    entry : ndarray, optional
        Delayed-entry times. With ``ties='breslow'`` a row is in the risk set
        of failure time ``t`` only if ``entry < t <= time``.

    Returns
    -------
    (grad, hess) : tuple of ndarray
        Gradient of shape ``(p,)`` and Hessian of shape ``(p, p)``.
    """
    n_features = X.shape[1]
    eta = X @ beta

    if self.ties == 'breslow':
        event_mask = event == 1
        grad = np.zeros(n_features, dtype=np.float64)

        if entry is not None:
            # Weights are shifted by max(eta); gradient and Hessian only use
            # ratios of weighted sums, so the shift does not change them.
            exp_eta = np.exp(eta - np.max(eta))

            fail_groups = getattr(self, "_entry_fail_groups_np", None)
            add_end_np = getattr(self, "_entry_add_end_np", None)
            rem_end_np = getattr(self, "_entry_rem_end_np", None)
            order_np = getattr(self, "_entry_order_np", None)
            if (
                fail_groups is None
                or add_end_np is None
                or rem_end_np is None
                or order_np is None
            ):
                # No cached bookkeeping: rebuild it from the arrays passed in.
                event_idx = np.flatnonzero(event_mask)
                uft_np, inv_np, tie_sizes = np.unique(
                    time[event_idx], return_inverse=True, return_counts=True
                )
                # Stable sort + split instead of one boolean scan per group.
                grouped = event_idx[np.argsort(inv_np, kind="stable")].astype(np.int64, copy=False)
                fail_groups = np.split(grouped, np.cumsum(tie_sizes)[:-1]) if uft_np.size else []
                entry_np = np.asarray(entry, dtype=np.float64)
                order_np = np.argsort(entry_np).astype(np.int64, copy=False)
                add_end_np = np.searchsorted(
                    entry_np[order_np], uft_np, side="left"
                ).astype(np.int64, copy=False)
                rem_end_np = np.searchsorted(time, uft_np, side="left").astype(np.int64, copy=False)

            hess = np.zeros((n_features, n_features), dtype=np.float64)
            # Running zeroth/first/second weighted moments of the risk set.
            s0 = 0.0
            s1 = np.zeros(n_features, dtype=np.float64)
            s2 = np.zeros((n_features, n_features), dtype=np.float64)
            add_ptr = 0
            rem_ptr = 0
            for g, fail_idx in enumerate(fail_groups):
                add_end = int(add_end_np[g])
                if add_end > add_ptr:
                    idx_add = order_np[add_ptr:add_end]
                    x_add = X[idx_add]
                    w_add = exp_eta[idx_add]
                    wx_add = x_add * w_add[:, np.newaxis]
                    s0 += float(np.sum(w_add))
                    s1 += np.sum(wx_add, axis=0)
                    s2 += wx_add.T @ x_add
                    add_ptr = add_end
                rem_end = int(rem_end_np[g])
                if rem_end > rem_ptr:
                    x_rem = X[rem_ptr:rem_end]
                    w_rem = exp_eta[rem_ptr:rem_end]
                    wx_rem = x_rem * w_rem[:, np.newaxis]
                    s0 -= float(np.sum(w_rem))
                    s1 -= np.sum(wx_rem, axis=0)
                    s2 -= wx_rem.T @ x_rem
                    rem_ptr = rem_end
                d_t = int(fail_idx.shape[0])
                if d_t <= 0:
                    continue
                grad += np.sum(X[fail_idx], axis=0)
                # A numerically empty risk set contributes no expectation term.
                if s0 <= 1e-15:
                    continue
                ex = s1 / s0
                grad -= float(d_t) * ex
                hess -= float(d_t) * (s2 / s0 - np.outer(ex, ex))
            return grad, hess

        # No entry: risk-set sums are suffix sums over the sorted rows. These
        # are only needed on this path, so they are built here rather than up
        # front for every tie method.
        exp_eta = np.exp(eta)
        risk_sum = np.cumsum(exp_eta[::-1])[::-1]
        X_exp_eta = X * exp_eta[:, np.newaxis]
        risk_X_sum = np.cumsum(X_exp_eta[::-1], axis=0)[::-1]

        first_idx = np.array([], dtype=np.int64)
        counts = np.array([], dtype=np.float64)
        if np.any(event_mask):
            breslow_pre = getattr(self, "_breslow_pre", None)
            if (
                breslow_pre is not None
                and len(breslow_pre) == 2
                and breslow_pre[0].size > 0
            ):
                first_idx = breslow_pre[0].astype(np.int64, copy=False)
                counts = breslow_pre[1].astype(np.float64, copy=False)
            else:
                uft, counts_i = np.unique(time[event_mask], return_counts=True)
                first_idx = np.searchsorted(time, uft, side="left").astype(np.int64)
                counts = counts_i.astype(np.float64)

            sum_X_events = np.sum(X[event_mask], axis=0)
            E_X = risk_X_sum[first_idx] / risk_sum[first_idx][:, np.newaxis]
            grad = sum_X_events - np.sum(E_X * counts[:, np.newaxis], axis=0)

        hess = self._compute_hessian_breslow_fast(
            X, time, event, risk_sum, risk_X_sum, exp_eta, first_idx, counts
        )
        return grad, hess

    # Efron: prefer the Cython core if available; fall back to the Python
    # implementation when the extension is missing or misbehaves. eta is
    # shifted by a constant for numerical stability in exp(eta).
    eta_efron = eta - np.max(eta)
    if HAS_CYTHON_EFRON and efron_pre is not None:
        try:
            uft, uft_ix, risk_enter, risk_exit, nuft, _ = _unpack_efron_pre6(efron_pre)
            grad, hess = _efron_grad_hess_cython(
                eta_efron, X, risk_enter, risk_exit, uft_ix, nuft
            )
            # Align sign convention with the CPU Efron backward path.
            hess = -hess
            if not (np.isfinite(grad).all() and np.isfinite(hess).all()):
                raise FloatingPointError("non-finite Cython Efron grad/hess")
        except Exception:
            from ._cox_efron_cy import efron_grad_hess_python

            uft, uft_ix, risk_enter, risk_exit, nuft, _ = _unpack_efron_pre6(efron_pre)
            grad, hess = efron_grad_hess_python(
                eta_efron, X, risk_enter, risk_exit, uft_ix, nuft
            )
            hess = -hess
            if not (np.isfinite(grad).all() and np.isfinite(hess).all()):
                grad, hess = self._compute_gradient_hessian_efron_backward(
                    beta, X, time, event, efron_pre
                )
    else:
        grad, hess = self._compute_gradient_hessian_efron_backward(
            beta, X, time, event, efron_pre
        )

    return grad, hess
|
|
1854
|
+
|
|
1855
|
+
def _compute_hessian_breslow_fast(
    self,
    X,
    time,
    event,
    risk_sum,
    risk_X_sum,
    exp_eta,
    first_idx=None,
    counts=None,
):
    """Compute the Breslow Hessian with an auto-selected CPU kernel.

    Parameters
    ----------
    X : ndarray, shape (n, p)
        Design matrix, rows in the same order as ``time``.
    time, event : ndarray
        Observation times (``np.searchsorted`` is applied to ``time``) and
        event indicator (1 marks a failure).
    risk_sum, risk_X_sum : ndarray
        Reversed cumulative sums of ``exp_eta`` and of ``X * exp_eta``.
    exp_eta : ndarray
        Exponentiated linear predictor.
    first_idx, counts : ndarray, optional
        First row index and event count for every unique failure time. When
        missing or empty they are taken from ``self._breslow_pre`` or rebuilt
        from ``time``/``event``.

    Returns
    -------
    ndarray, shape (p, p)
        The Hessian; all zeros when there are no events.
    """
    event_mask = event == 1
    if not np.any(event_mask):
        return np.zeros((X.shape[1], X.shape[1]), dtype=np.float64)

    # Group tied events by unique failure time so that all events at time t
    # share the same risk-set denominator (Breslow ties).
    if first_idx is None or counts is None or len(first_idx) == 0:
        breslow_pre = getattr(self, "_breslow_pre", None)
        if (
            breslow_pre is not None
            and len(breslow_pre) == 2
            and breslow_pre[0].size > 0
        ):
            first_idx = breslow_pre[0].astype(np.int64, copy=False)
            counts = breslow_pre[1].astype(np.float64, copy=False)
        else:
            event_times = time[event_mask]
            uft, counts_i = np.unique(event_times, return_counts=True)
            first_idx = np.searchsorted(time, uft, side="left").astype(np.int64)
            counts = counts_i.astype(np.float64)

    # Two CPU kernels are kept intentionally:
    # 1) Tensor path: higher memory, but can be faster for small p / few groups.
    # 2) Incremental path: lower memory traffic for larger (n, p).
    n = int(X.shape[0])
    p = int(X.shape[1])
    n_groups = int(len(first_idx))
    # The tensor kernel materialises (n, p, p) float64 arrays, so its memory
    # grows with n regardless of the group count; cap it at 2**24 elements
    # (128 MiB per array) and use the incremental kernel beyond that.
    max_tensor_elems = 1 << 24
    if p <= 24 and n_groups <= 512 and n * p * p <= max_tensor_elems:
        return self._compute_hessian_breslow_tensor_grouped(
            X, risk_sum, risk_X_sum, exp_eta, first_idx, counts
        )
    return self._compute_hessian_breslow_incremental_grouped(
        X, risk_sum, risk_X_sum, exp_eta, first_idx, counts
    )
|
|
1900
|
+
|
|
1901
|
+
def _compute_hessian_breslow_tensor_grouped(
    self, X, risk_sum, risk_X_sum, exp_eta, first_idx, counts
):
    """Grouped Breslow Hessian using explicit (n, p, p) tensor moments.

    Parameters
    ----------
    X : ndarray, shape (n, p)
        Design matrix.
    risk_sum : ndarray, shape (n,)
        Reversed cumulative sum of ``exp_eta``.
    risk_X_sum : ndarray, shape (n, p)
        Reversed cumulative sum of ``X * exp_eta``.
    exp_eta : ndarray, shape (n,)
        Per-row weights.
    first_idx : ndarray of int
        Row index at which each failure-time group's risk set starts.
    counts : ndarray
        Number of events in each group.

    Returns
    -------
    ndarray, shape (p, p)
        ``-sum_g counts[g] * (E_g[x x^T] - E_g[x] E_g[x]^T)``. Groups whose
        risk-set mass is not positive are skipped, matching the incremental
        kernel, so an underflowed weight cannot inject NaN into the result.
    """
    first_idx = np.asarray(first_idx)
    counts = np.asarray(counts)

    # Per-row weighted outer products and their suffix sums over rows.
    x2_weighted = np.einsum("ni,nj,n->nij", X, X, exp_eta)
    risk_X2_sum = np.cumsum(x2_weighted[::-1], axis=0)[::-1]

    risk_sum_at = risk_sum[first_idx]
    usable = risk_sum_at > 0.0
    if not np.all(usable):
        first_idx = first_idx[usable]
        counts = counts[usable]
        risk_sum_at = risk_sum_at[usable]

    E_X = risk_X_sum[first_idx] / risk_sum_at[:, np.newaxis]
    E_XX = risk_X2_sum[first_idx] / risk_sum_at[:, np.newaxis, np.newaxis]
    centered = E_XX - np.einsum("ni,nj->nij", E_X, E_X)
    return -np.sum(centered * counts[:, np.newaxis, np.newaxis], axis=0)
|
|
1912
|
+
|
|
1913
|
+
def _compute_hessian_breslow_incremental_grouped(
    self, X, risk_sum, risk_X_sum, exp_eta, first_idx, counts
):
    """Grouped Breslow Hessian that maintains one running (p, p) second moment.

    The running matrix starts as the weighted second moment of all rows and
    has the rows in front of each group's start index subtracted as the groups
    are visited in order, so it always describes the current risk set. Groups
    whose risk-set mass ``risk_sum[start]`` is not positive are skipped.
    Returns the accumulated ``(p, p)`` Hessian.
    """
    weighted_rows = X * exp_eta[:, np.newaxis]
    # sum_j exp_eta[j] * x_j x_j^T over every row; shrunk as rows leave.
    second_moment = weighted_rows.T @ X

    n_features = X.shape[1]
    result = np.zeros((n_features, n_features), dtype=np.float64)
    dropped_upto = 0
    for group, raw_start in enumerate(first_idx):
        start = int(raw_start)
        if start > dropped_upto:
            # Rows [dropped_upto, start) are no longer at risk.
            leaving = slice(dropped_upto, start)
            second_moment -= weighted_rows[leaving].T @ X[leaving]
            dropped_upto = start

        mass = float(risk_sum[start])
        if mass <= 0.0:
            continue
        mean_x = risk_X_sum[start] / mass
        mean_xx = second_moment / mass
        result -= counts[group] * (mean_xx - np.outer(mean_x, mean_x))

    return result
|
|
1939
|
+
|
|
1940
|
+
def _compute_hessian_breslow_incremental_grouped_cupy(
    self, X, risk_sum, risk_X_sum, exp_eta, first_idx, counts
):
    """CuPy grouped Breslow Hessian with an incrementally shrinking second moment.

    Environment switches read on every call:

    * ``STATGPU_BRESLOW_HESS_UPDATE_KERNEL`` (default ``"1"``) - when truthy
      the per-group update is delegated to ``apply_breslow_hess_update_raw``;
      otherwise it is done with CuPy ufuncs into two scratch buffers.
    * ``STATGPU_BRESLOW_GEMM_BLOCK`` (default ``"1024"``, floor 64) - maximum
      number of rows removed from the running second moment per matmul.

    Host copies of ``first_idx``/``counts`` are taken from
    ``self._breslow_first_idx_np`` / ``self._breslow_counts_np`` when present
    and of matching length, otherwise fetched from the device.
    """
    import cupy as cp
    from ._cox_efron_cuda import apply_breslow_hess_update_raw

    weighted_rows = X * exp_eta[:, cp.newaxis]
    second_moment = weighted_rows.T @ X

    p = int(X.shape[1])
    hess = cp.zeros((p, p), dtype=cp.float64)

    kernel_flag = os.environ.get("STATGPU_BRESLOW_HESS_UPDATE_KERNEL", "1").strip().lower()
    use_update_kernel = kernel_flag in ("1", "true", "yes", "on")
    if use_update_kernel:
        exx_buf = None
        outer_buf = None
    else:
        # Scratch space for the ufunc-based update.
        exx_buf = cp.empty((p, p), dtype=cp.float64)
        outer_buf = cp.empty((p, p), dtype=cp.float64)

    try:
        block_size = int(os.environ.get("STATGPU_BRESLOW_GEMM_BLOCK", "1024"))
    except (TypeError, ValueError):
        block_size = 1024
    block_size = max(64, block_size)

    # Reuse cached host-side tie metadata when it is present and consistent.
    first_idx_np = getattr(self, "_breslow_first_idx_np", None)
    counts_np = getattr(self, "_breslow_counts_np", None)
    cache_usable = (
        first_idx_np is not None
        and counts_np is not None
        and int(first_idx_np.shape[0]) == int(first_idx.shape[0])
    )
    if not cache_usable:
        first_idx_np = cp.asnumpy(first_idx).astype(np.int64, copy=False)
        counts_np = cp.asnumpy(counts).astype(np.float64, copy=False)
    risk_at_np = cp.asnumpy(risk_sum[first_idx]).astype(np.float64, copy=False)

    dropped_upto = 0
    for g in range(int(first_idx_np.size)):
        start = int(first_idx_np[g])
        if start > dropped_upto:
            # Remove departed rows in chunks to avoid many tiny GEMM launches.
            for lo in range(dropped_upto, start, block_size):
                chunk = slice(lo, min(start, lo + block_size))
                second_moment -= (weighted_rows[chunk].T @ X[chunk])
            dropped_upto = start

        mass = float(risk_at_np[g])
        if mass <= 0.0:
            continue
        mean_x = risk_X_sum[start] / mass
        if use_update_kernel:
            apply_breslow_hess_update_raw(
                hess, second_moment, mean_x, mass, counts_np[g], cupy_module=cp
            )
            continue
        inv_mass = 1.0 / mass
        cp.multiply(second_moment, inv_mass, out=exx_buf)
        cp.multiply(mean_x[:, cp.newaxis], mean_x[cp.newaxis, :], out=outer_buf)
        cp.subtract(exx_buf, outer_buf, out=exx_buf)
        hess -= counts_np[g] * exx_buf

    return hess
|
|
2005
|
+
|
|
2006
|
+
def _compute_hessian_breslow_fused_cupy(self, X, first_idx, counts, exp_eta):
    """Attempt the fused RawKernel Breslow Hessian.

    Returns whatever ``compute_breslow_hess_raw`` returns, or ``None`` if
    importing or running it raises. When the environment variable
    ``STATGPU_DEBUG_BRESLOW_FUSED`` is truthy (``1``/``true``/``yes``/``on``)
    the swallowed exception is printed.
    """
    import cupy as cp

    debug_flag = os.environ.get("STATGPU_DEBUG_BRESLOW_FUSED", "0").strip().lower()
    debug_fused = debug_flag in ("1", "true", "yes", "on")
    try:
        from ._cox_efron_cuda import compute_breslow_hess_raw

        fused = compute_breslow_hess_raw(
            X,
            first_idx,
            counts,
            cupy_module=cp,
            exp_eta=exp_eta,
        )
    except Exception as ex:
        if debug_fused:
            print(f"[CUDA Breslow fused fallback] {type(ex).__name__}: {ex}")
        return None
    return fused
|
|
2026
|
+
|
|
2027
|
+
def _compute_hessian_breslow(self, beta, X, time, event, risk_sum, risk_X_sum, exp_eta):
    """
    Compute the Hessian of the Breslow partial log-likelihood.

    Uses an incremental suffix-scan so total cost is O(n·p²):

    1. Compute the full second-moment matrix M = (X * exp_eta).T @ X.
    2. Walk through the distinct failure times left-to-right and subtract
       from M the rows that fall strictly before the current failure time.
       Each row is subtracted exactly once.

    Events sharing a failure time share one risk set (Breslow): the risk set
    of a tie group starts at the first row whose time equals that failure
    time, and the group's contribution is weighted by its event count.
    For data without ties this is identical to a per-event walk.

    Parameters
    ----------
    beta : unused here; kept for interface compatibility.
    X : (n, p) array of covariates.
    time : (n,) array of times, assumed sorted ascending (the row removal
        below relies on it).  If ``None``, every event row is treated as its
        own group.
    event : (n,) array; non-zero entries mark events.
    risk_sum : (n,) array indexed by row; ``risk_sum[i]`` is used as the
        risk-set total of ``exp_eta`` for a risk set starting at row ``i``.
    risk_X_sum : (n, p) array indexed by row, the matching weighted
        covariate sums.
    exp_eta : (n,) array of exp(linear predictor).

    Returns
    -------
    (p, p) float64 array: minus the sum over events of the risk-set
    weighted covariance of X.
    """
    n_features = X.shape[1]
    hess = np.zeros((n_features, n_features), dtype=np.float64)

    event_positions = np.flatnonzero(event)
    if event_positions.size == 0:
        return hess

    if time is None:
        group_starts = event_positions
        group_counts = np.ones(event_positions.size, dtype=np.int64)
    else:
        time = np.asarray(time)
        # First row at each event's failure time == start of its risk set.
        starts = np.searchsorted(time, time[event_positions], side="left")
        group_starts, group_counts = np.unique(starts, return_counts=True)

    X_exp = X * exp_eta[:, np.newaxis]  # (n, p)
    second_moment = X_exp.T @ X  # (p, p), covers rows [removed, n)

    removed = 0
    for start, n_tied in zip(group_starts.tolist(), group_counts.tolist()):
        if start > removed:
            # Rows [removed, start) have time < current failure time.
            second_moment -= X_exp[removed:start].T @ X[removed:start]
            removed = start

        s0 = risk_sum[start]
        if s0 <= 0.0:
            # Same guard as the grouped CuPy implementation: skip degenerate risk sets.
            continue
        mean_x = risk_X_sum[start] / s0  # (p,)
        hess -= n_tied * (second_moment / s0 - np.outer(mean_x, mean_x))

    return hess
|
|
2063
|
+
|
|
2064
|
+
def _efron_unique_failure_indices(self, time: np.ndarray, event: np.ndarray):
    """
    Build per-failure-time bookkeeping for the Efron scan (single stratum).

    `time` must be sorted ascending (as in fit).

    Returns the 6-tuple ``(uft, uft_ix, risk_enter, risk_exit, nuft,
    first_idx_uft)``:

    * ``uft``           - sorted unique failure times.
    * ``uft_ix[i]``     - row indices of the events that fail at ``uft[i]``.
    * ``risk_enter[i]`` - row indices with ``uft[i] <= time < uft[i+1]``
      (the last bucket takes every row with ``time >= uft[-1]``); rows with
      ``time < uft[0]`` appear in no bucket.  Scanning ``i`` from last to
      first and adding each bucket therefore accumulates the risk set
      ``{time >= uft[i]}``.
    * ``risk_exit[i]``  - always an empty list; nothing leaves the risk set
      during the backward scan.
    * ``nuft``          - number of unique failure times.
    * ``first_idx_uft`` - int32 index of the first row at each ``uft[i]``.
    """
    failure_rows = np.flatnonzero(event == 1)
    if failure_rows.size == 0:
        return np.array([], dtype=np.float64), [], [], [], 0, np.array([], dtype=np.int32)

    failure_times = time[failure_rows]
    uft = np.unique(failure_times)
    nuft = int(uft.size)

    # Looked up once here so likelihood loops need no searchsorted of their own.
    first_idx_uft = np.searchsorted(time, uft, side="left").astype(np.int32)

    def bucket_rows(rows, labels):
        # Group `rows` by integer label 0..nuft-1, keeping input order inside a bucket.
        order = np.argsort(labels, kind="stable")
        sizes = np.bincount(labels[order], minlength=nuft)
        edges = [0, *np.cumsum(sizes).tolist()]
        ordered = rows[order]
        return [ordered[lo:hi].tolist() for lo, hi in zip(edges, edges[1:])]

    # Events grouped by the unique failure time they belong to.
    event_group = np.searchsorted(uft, failure_times, side="left").astype(np.int32)
    uft_ix = bucket_rows(failure_rows, event_group)

    # Every row is labelled with the largest failure time not exceeding its own time.
    enter_group = np.searchsorted(uft, time, side="right").astype(np.int32) - 1
    at_risk = enter_group >= 0
    risk_enter = bucket_rows(np.nonzero(at_risk)[0], enter_group[at_risk])

    risk_exit = [[] for _ in range(nuft)]

    return uft, uft_ix, risk_enter, risk_exit, nuft, first_idx_uft
|
|
2116
|
+
|
|
2117
|
+
@staticmethod
def _use_heavy_ties_cpu_fallback() -> bool:
    """Report whether the heavy-ties CPU fallback has been opted into.

    Reads ``STATGPU_HEAVY_TIES_CPU_FALLBACK`` (default ``"0"``); the values
    1/true/yes/on, ignoring case and surrounding whitespace, enable it.
    """
    raw = os.environ.get("STATGPU_HEAVY_TIES_CPU_FALLBACK", "0")
    normalized = raw.strip().lower()
    enabled_values = ("1", "true", "yes", "on")
    return normalized in enabled_values
|
|
2122
|
+
|
|
2123
|
+
def _should_cpu_fallback_heavy_ties(self, n_samples, n_features, avg_tie_size):
    """Decide whether a heavy-ties problem should be run on the CPU instead.

    Returns ``True`` only when all of the following hold: the opt-in switch
    (``_use_heavy_ties_cpu_fallback``) is on, ``self.ties`` is ``"efron"`` or
    ``"breslow"``, ``avg_tie_size`` is not below 8, ``n_samples <= 20000`` and
    ``n_features <= 64``.
    """
    opted_in = self._use_heavy_ties_cpu_fallback()
    if not opted_in or self.ties not in ("efron", "breslow"):
        return False
    if avg_tie_size < 8.0:
        return False
    few_rows = int(n_samples) <= 20000
    return few_rows and int(n_features) <= 64
|
|
2132
|
+
|
|
2133
|
+
def _breslow_unique_failure_groups(self, time: np.ndarray, event: np.ndarray):
    """
    Group events by failure time for the Breslow tie handling.

    `time`/`event` are expected to be sorted by time.  Returns a pair of
    int32 arrays ``(first_idx_uft, counts_uft)``: for each unique failure
    time, the index of the first row with that time and the number of events
    at that time.  Both arrays are empty when there are no events.
    """
    event_rows = np.flatnonzero(event == 1)
    if not event_rows.size:
        no_starts = np.array([], dtype=np.int32)
        no_counts = np.array([], dtype=np.int32)
        return no_starts, no_counts

    unique_times, tie_counts = np.unique(time[event_rows], return_counts=True)
    group_starts = np.searchsorted(time, unique_times, side="left")
    return group_starts.astype(np.int32), tie_counts.astype(np.int32)
|
|
2145
|
+
|
|
2146
|
+
def _compute_gradient_hessian_efron_backward(self, beta, X, time, event, efron_pre=None):
    """
    Efron gradient and second-derivative matrix via a backward risk-set scan.

    For each unique failure time with ``d`` tied events, risk-set sums
    ``S0, S1, S2`` (of ``w``, ``w*x`` and ``w*x*x^T`` with ``w = exp(x @ beta)``)
    and the same sums ``D0, D1, D2`` restricted to the tied events, the
    Efron terms for ``k = 0..d-1`` are::

        den_k = S0 - (k/d) * D0
        s1_k  = S1 - (k/d) * D1
        s2_k  = S2 - (k/d) * D2
        grad += sum(x of tied events) - sum_k s1_k / den_k
        hess += sum_k ( s2_k / den_k - outer(s1_k, s1_k) / den_k**2 )

    The ``(k/d)`` corrections of the numerators are what distinguishes Efron
    from Breslow under ties; without ties (``d == 1``) they vanish.

    The risk-set sums are accumulated while walking the failure times from
    last to first, so only O(p²) extra memory is needed.

    Note: We do NOT center eta for consistency with _compute_log_likelihood.

    Parameters
    ----------
    beta : (p,) coefficient vector.
    X : (n, p) covariates.
    time : (n,) times; need not be sorted (rows are ordered internally).
    event : (n,) array; entries equal to 1 mark events.
    efron_pre : unused; kept for interface compatibility.

    Returns
    -------
    (grad, hess) : ``grad`` is the score vector of shape (p,).  ``hess`` has
        shape (p, p) and is returned with the same sign as before this
        rewrite, i.e. the positive semi-definite matrix accumulated above.
        NOTE(review): the CuPy Efron routines in this file return the
        negated matrix - confirm which convention callers of this method
        expect.
    """
    n_features = X.shape[1]
    grad = np.zeros(n_features, dtype=np.float64)
    hess = np.zeros((n_features, n_features), dtype=np.float64)

    event_mask = event == 1
    if not np.any(event_mask):
        return grad, hess

    # No centering - matches _compute_log_likelihood
    weights = np.exp(X @ beta)

    order = np.argsort(time, kind="stable")
    time_sorted = time[order]
    X_sorted = X[order]
    w_sorted = weights[order]
    failed_sorted = event_mask[order]

    uft = np.unique(time_sorted[failed_sorted])
    group_starts = np.searchsorted(time_sorted, uft, side="left")

    s0 = 0.0
    s1 = np.zeros(n_features, dtype=np.float64)
    s2 = np.zeros((n_features, n_features), dtype=np.float64)
    stop = time_sorted.shape[0]

    for g in range(uft.size - 1, -1, -1):
        start = int(group_starts[g])

        # Rows [start, stop) have uft[g] <= time < uft[g+1]; they join the risk set.
        x_blk = X_sorted[start:stop]
        w_blk = w_sorted[start:stop]
        wx_blk = x_blk * w_blk[:, np.newaxis]
        s0 = s0 + w_blk.sum()
        s1 = s1 + wx_blk.sum(axis=0)
        s2 = s2 + wx_blk.T @ x_blk

        # Every event in this block fails exactly at uft[g] (event times are in uft).
        tied = failed_sorted[start:stop]
        stop = start
        x_tied = x_blk[tied]
        wx_tied = wx_blk[tied]
        d0 = w_blk[tied].sum()
        d1 = wx_tied.sum(axis=0)
        d2 = wx_tied.T @ x_tied
        n_tied = int(x_tied.shape[0])

        grad += x_tied.sum(axis=0)
        for k in range(n_tied):
            frac = k / n_tied
            denom = s0 - frac * d0
            if denom < 1e-300:
                denom = 1e-300
            s1_k = s1 - frac * d1
            s2_k = s2 - frac * d2
            grad -= s1_k / denom
            hess += s2_k / denom - np.outer(s1_k, s1_k) / (denom * denom)

    return grad, hess
|
|
2246
|
+
|
|
2247
|
+
def _compute_gradient_hessian_gpu(
    self, beta, X, time, event, efron_pre=None, return_aux=False, entry=None, entry_ctx=None
):
    """Compute gradient and Hessian on GPU.

    Three code paths are taken, in this order:

    1. ``self.ties == "efron"`` and ``entry is None``: if the instance
       attribute ``_efron_all_singletons`` is truthy, the grouped Breslow
       formulas are used with every tie count set to one; otherwise the work
       is delegated to ``_compute_gradient_hessian_efron_backward_gpu``.
    2. ``entry is not None``: an entry-aware path that builds the risk-set
       sums as (prefix sum over rows in entry order) minus (prefix sum over
       rows in removal order), for both Breslow and Efron ties.
    3. Otherwise: grouped Breslow using suffix sums and one risk set per
       unique failure time.

    Parameters
    ----------
    beta, X, time, event : CuPy arrays (``cp.asnumpy`` is applied to
        ``time``/``event`` on the Efron path).
        # assumes rows are sorted by ascending time - TODO confirm against caller
    efron_pre : optional precomputed Efron bookkeeping, unpacked with
        ``_unpack_efron_pre6``.
    return_aux : when true, ``(eta, exp_eta, risk_sum)`` is appended to the
        result; ``risk_sum`` is ``None`` whenever ``entry`` is given.
    entry : optional entry (left-truncation) argument; only tested for
        ``None`` here and forwarded to ``_build_entry_ctx_gpu``.
    entry_ctx : optional precomputed context for the entry path; a sequence
        whose first four items are required and items 4..8 are optional.

    Environment variables read: ``STATGPU_PROFILE_BRESLOW_CUDA``,
    ``STATGPU_BRESLOW_FUSED_CUPY``, ``STATGPU_ENTRY_S2_BLOCK_SIZE``,
    ``STATGPU_ENTRY_S2_FUSED_CUPY``, ``STATGPU_ENTRY_S2_FUSED_MIN_ROWS``.

    Returns
    -------
    ``(grad, hess)`` or ``(grad, hess, (eta, exp_eta, risk_sum))``.
    """
    import cupy as cp
    import time as _time

    n_samples, n_features = X.shape

    # Timing output is produced only on the grouped Breslow path (path 3).
    profile_breslow = (
        os.environ.get("STATGPU_PROFILE_BRESLOW_CUDA", "0").strip().lower()
        in ("1", "true", "yes", "on")
    )
    _t0_all = _time.perf_counter() if profile_breslow else None
    eta = X @ beta
    exp_eta = cp.exp(eta)
    event_mask = event == 1

    # Risk sets (entry-aware path uses dynamic masks below).
    # Reverse cumulative sums: risk_sum[i] = sum(exp_eta[i:]), likewise for risk_X_sum.
    risk_sum = cp.cumsum(exp_eta[::-1])[::-1] if entry is None else None
    X_exp_eta = X * exp_eta[:, cp.newaxis]
    risk_X_sum = cp.cumsum(X_exp_eta[::-1], axis=0)[::-1] if entry is None else None
    if profile_breslow:
        # Synchronize so the timestamp covers the queued GPU work.
        cp.cuda.Stream.null.synchronize()
        _t_pre = _time.perf_counter()

    # Efron: when no ties, use Breslow vectorized path.
    if self.ties == "efron" and entry is None:
        if getattr(self, "_efron_all_singletons", False):
            ep = efron_pre if efron_pre is not None else getattr(self, "_efron_pre", None)
            if ep is not None:
                _, _, _, _, nuft, first_idx_uft = _unpack_efron_pre6(ep)
                first_idx_uft = cp.asarray(first_idx_uft, dtype=cp.int32)
                # All groups are singletons, so every count is one.
                counts_uft = cp.ones(int(nuft), dtype=cp.int32)
            else:
                uft, counts_uft = cp.unique(time[event_mask], return_counts=True)
                first_idx_uft = cp.searchsorted(time, uft, side="left")
                counts_uft = counts_uft.astype(cp.int32, copy=False)
            counts_f = counts_uft.astype(cp.float64)
            # Reuse the cached sum of event rows when its length matches n_features.
            grad_pre = getattr(self, "_event_X_sum_gpu", None)
            grad = (
                grad_pre.copy()
                if grad_pre is not None and int(grad_pre.shape[0]) == int(n_features)
                else cp.sum(X[event_mask], axis=0)
            )
            E_X = risk_X_sum[first_idx_uft] / risk_sum[first_idx_uft][:, cp.newaxis]
            grad = grad - cp.sum(E_X * counts_f[:, cp.newaxis], axis=0)
            use_fused_breslow = (
                os.environ.get("STATGPU_BRESLOW_FUSED_CUPY", "0").strip().lower()
                in ("1", "true", "yes", "on")
            )
            hess = None
            if use_fused_breslow:
                # Returns None on failure, which triggers the incremental path below.
                hess = self._compute_hessian_breslow_fused_cupy(
                    X, first_idx_uft, counts_f, exp_eta
                )
            if hess is None:
                hess = self._compute_hessian_breslow_incremental_grouped_cupy(
                    X, risk_sum, risk_X_sum, exp_eta, first_idx_uft, counts_f
                )
            if return_aux:
                return grad, hess, (eta, exp_eta, risk_sum)
            return grad, hess
        if efron_pre is None:
            # Bookkeeping is built on the host from copies of time/event.
            efron_pre = self._efron_unique_failure_indices(
                cp.asnumpy(time), cp.asnumpy(event)
            )
        out = self._compute_gradient_hessian_efron_backward_gpu(
            beta, X, efron_pre
        )
        if return_aux:
            return out[0], out[1], (eta, exp_eta, risk_sum)
        return out

    # Breslow gradient/Hessian (entry-aware path).
    # (The entry branch below also handles ties == "efron" when entry is given.)
    event_mask = event == 1
    grad = cp.zeros(n_features, dtype=cp.float64)

    if not cp.any(event_mask):
        # No events: zero gradient and zero Hessian.
        out = (grad, cp.zeros((n_features, n_features), dtype=cp.float64))
        if return_aux:
            return out[0], out[1], (eta, exp_eta, risk_sum)
        return out

    if entry is not None:
        if entry_ctx is None:
            entry_order, d_counts, add_end_np, rem_end_np, rem_order, event_idx, fail_ptr = self._build_entry_ctx_gpu(
                time, event, entry, cp
            )
            X_entry = cp.ascontiguousarray(X[entry_order])
            X_rem = cp.ascontiguousarray(X[rem_order])
            grad += cp.sum(X[event_idx], axis=0)
        else:
            # Items 0..3 are mandatory; later items fall back to values derived here.
            entry_order, d_counts, add_end_np, rem_end_np = entry_ctx[:4]
            X_entry = entry_ctx[4] if len(entry_ctx) > 4 else X[entry_order]
            X_rem = entry_ctx[5] if len(entry_ctx) > 5 else X
            event_idx = entry_ctx[6] if len(entry_ctx) > 6 else cp.where(event_mask)[0]
            grad += entry_ctx[7] if len(entry_ctx) > 7 else cp.sum(X[event_mask], axis=0)
            fail_ptr = entry_ctx[8] if len(entry_ctx) > 8 else None
        hess = cp.zeros((n_features, n_features), dtype=cp.float64)
        exp_entry = exp_eta[entry_order]
        # NOTE(review): removal weights are taken in original row order while X_rem
        # may be X[rem_order]; this is only consistent if rem_order is the identity
        # permutation - confirm against _build_entry_ctx_gpu.
        exp_rem = exp_eta
        wx_entry = X_entry * exp_entry[:, cp.newaxis]
        wx_rem = X_rem * exp_rem[:, cp.newaxis]
        n_groups = int(d_counts.shape[0])
        if n_groups == 0:
            if return_aux:
                return grad, hess, (eta, exp_eta, risk_sum)
            return grad, hess
        # Prefix sums; group g's risk-set sums are prefix[add_end-1] - prefix[rem_end-1].
        s0_add_pref = cp.cumsum(exp_entry, axis=0)
        s0_rem_pref = cp.cumsum(exp_rem, axis=0)
        s1_add_pref = cp.cumsum(wx_entry, axis=0)
        s1_rem_pref = cp.cumsum(wx_rem, axis=0)
        s0_add = cp.zeros(n_groups, dtype=cp.float64)
        s0_rem = cp.zeros(n_groups, dtype=cp.float64)
        s1_add = cp.zeros((n_groups, n_features), dtype=cp.float64)
        s1_rem = cp.zeros((n_groups, n_features), dtype=cp.float64)
        # An end pointer of 0 means nothing added/removed yet; those groups keep zeros.
        mask_add = add_end_np > 0
        mask_rem = rem_end_np > 0
        if np.any(mask_add):
            idx_add = cp.asarray(add_end_np[mask_add] - 1, dtype=cp.int64)
            mask_add_cp = cp.asarray(mask_add)
            s0_add[mask_add_cp] = s0_add_pref[idx_add]
            s1_add[mask_add_cp] = s1_add_pref[idx_add]
        if np.any(mask_rem):
            idx_rem = cp.asarray(rem_end_np[mask_rem] - 1, dtype=cp.int64)
            mask_rem_cp = cp.asarray(mask_rem)
            s0_rem[mask_rem_cp] = s0_rem_pref[idx_rem]
            s1_rem[mask_rem_cp] = s1_rem_pref[idx_rem]
        s0_vec = s0_add - s0_rem
        s1_vec = s1_add - s1_rem
        d_vec = cp.asarray(d_counts, dtype=cp.float64)
        # Floor the denominators to avoid division by zero.
        s0_safe_vec = cp.maximum(s0_vec, 1e-15)
        use_efron_entry = (self.ties == "efron")
        ex_vec = s1_vec / s0_safe_vec[:, cp.newaxis]
        if not use_efron_entry:
            # Breslow gradient is fully vectorized over groups.
            grad -= cp.sum(d_vec[:, cp.newaxis] * ex_vec, axis=0)
        if use_efron_entry:
            if fail_ptr is None:
                # CSR-style offsets into the event rows, one slice per group.
                fail_ptr = np.empty(n_groups + 1, dtype=np.int64)
                fail_ptr[0] = 0
                fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)
            event_exp = exp_eta[event_idx]
            X_fail = X[event_idx]
        add_ptr = 0
        rem_ptr = 0
        # Running weighted second-moment matrix of the current risk set.
        s2 = cp.zeros((n_features, n_features), dtype=cp.float64)
        s2_block_size = int(os.environ.get("STATGPU_ENTRY_S2_BLOCK_SIZE", "8192"))
        if s2_block_size <= 0:
            # Non-positive setting: use a huge block so the blocked branch is never taken.
            s2_block_size = 10**18
        use_s2_fused = (
            os.environ.get("STATGPU_ENTRY_S2_FUSED_CUPY", "0").strip().lower()
            in ("1", "true", "yes", "on")
        )
        s2_fused_min_rows = int(os.environ.get("STATGPU_ENTRY_S2_FUSED_MIN_ROWS", "512"))
        if s2_fused_min_rows < 1:
            s2_fused_min_rows = 1
        for g in range(n_groups):
            # Add the rows that enter the risk set up to this group.
            add_end = int(add_end_np[g])
            if add_end > add_ptr:
                x_add = X_entry[add_ptr:add_end]
                w_add = exp_entry[add_ptr:add_end]
                n_add = int(add_end - add_ptr)
                if use_s2_fused and n_add >= s2_fused_min_rows:
                    s2 = self._s2_weighted_update_cupy_fused(s2, x_add, w_add, sign=1.0)
                elif n_add <= s2_block_size:
                    s2 = s2 + (x_add.T @ (x_add * w_add[:, cp.newaxis]))
                else:
                    s2 = self._s2_weighted_update_cupy_blocked(
                        s2, x_add, w_add, s2_block_size, sign=1.0
                    )
                add_ptr = add_end

            # Remove the rows that have left the risk set up to this group.
            rem_end = int(rem_end_np[g])
            if rem_end > rem_ptr:
                x_rem = X_rem[rem_ptr:rem_end]
                w_rem = exp_eta[rem_ptr:rem_end]
                n_rem = int(rem_end - rem_ptr)
                if use_s2_fused and n_rem >= s2_fused_min_rows:
                    s2 = self._s2_weighted_update_cupy_fused(s2, x_rem, w_rem, sign=-1.0)
                elif n_rem <= s2_block_size:
                    s2 = s2 - (x_rem.T @ (x_rem * w_rem[:, cp.newaxis]))
                else:
                    s2 = self._s2_weighted_update_cupy_blocked(
                        s2, x_rem, w_rem, s2_block_size, sign=-1.0
                    )
                rem_ptr = rem_end

            d_t_f = float(d_counts[g])
            if d_t_f <= 0:
                continue
            if use_efron_entry:
                # Sums over the events tied at this group.
                st = int(fail_ptr[g])
                ed = int(fail_ptr[g + 1])
                ef = event_exp[st:ed]
                xf = X_fail[st:ed]
                ef_sum = cp.sum(ef)
                ef_x_sum = cp.sum(xf * ef[:, cp.newaxis], axis=0)
                ef_x2_sum = (xf.T @ (xf * ef[:, cp.newaxis]))
                s0_g = cp.maximum(s0_vec[g], 1e-15)
                s1_g = s1_vec[g]
                d_i = int(d_t_f)
                for k in range(d_i):
                    # Efron: subtract the fraction k/d of the tied-event sums.
                    frac = float(k) / float(d_i)
                    denom = cp.maximum(s0_g - frac * ef_sum, 1e-15)
                    s1_k = s1_g - frac * ef_x_sum
                    s2_k = s2 - frac * ef_x2_sum
                    ex_k = s1_k / denom
                    grad -= ex_k
                    hess -= s2_k / denom
                    hess += cp.outer(ex_k, ex_k)
            else:
                s0_safe = s0_safe_vec[g]
                hess -= (d_t_f / s0_safe) * s2
        if not use_efron_entry:
            # Breslow outer-product term, added once for all groups.
            hess += ex_vec.T @ (d_vec[:, cp.newaxis] * ex_vec)
        if return_aux:
            return grad, hess, (eta, exp_eta, risk_sum)
        return grad, hess

    # For Breslow ties, all events at the same failure time share the
    # same risk set R(t); grouping is required for correctness.
    breslow_pre_gpu = getattr(self, "_breslow_pre_gpu", None)
    if (
        breslow_pre_gpu is not None
        and len(breslow_pre_gpu) == 2
        and int(breslow_pre_gpu[0].size) > 0
    ):
        first_idx_uft, counts_uft = breslow_pre_gpu
    else:
        uft, counts_uft = cp.unique(time[event_mask], return_counts=True)
        first_idx_uft = cp.searchsorted(time, uft, side="left")
        counts_uft = counts_uft.astype(cp.int32, copy=False)

    # Use the cached float counts only if their length matches the groups in use.
    counts_f = getattr(self, "_breslow_counts_f_gpu", None)
    if counts_f is None or int(counts_f.shape[0]) != int(counts_uft.shape[0]):
        counts_f = counts_uft.astype(cp.float64)
    grad_pre = getattr(self, "_event_X_sum_gpu", None)
    grad = (
        grad_pre.copy()
        if grad_pre is not None and int(grad_pre.shape[0]) == int(n_features)
        else cp.sum(X[event_mask], axis=0)
    )
    E_X = risk_X_sum[first_idx_uft] / risk_sum[first_idx_uft][:, cp.newaxis]
    grad = grad - cp.sum(E_X * counts_f[:, cp.newaxis], axis=0)
    if profile_breslow:
        cp.cuda.Stream.null.synchronize()
        _t_grad = _time.perf_counter()
    use_fused_breslow = (
        os.environ.get("STATGPU_BRESLOW_FUSED_CUPY", "0").strip().lower()
        in ("1", "true", "yes", "on")
    )
    hess = None
    if use_fused_breslow:
        hess = self._compute_hessian_breslow_fused_cupy(
            X, first_idx_uft, counts_f, exp_eta
        )
    if hess is None:
        hess = self._compute_hessian_breslow_incremental_grouped_cupy(
            X, risk_sum, risk_X_sum, exp_eta, first_idx_uft, counts_f
        )
    if profile_breslow:
        cp.cuda.Stream.null.synchronize()
        _t_hess = _time.perf_counter()
        # _t_pre and _t_grad were set above under the same profiling flag.
        print(
            f"[CUDA Breslow profile] pre={(_t_pre - _t0_all):.4f}s "
            f"grad={(_t_grad - _t_pre):.4f}s "
            f"hess={(_t_hess - _t_grad):.4f}s "
            f"total={(_t_hess - _t0_all):.4f}s"
        )
    if return_aux:
        return grad, hess, (eta, exp_eta, risk_sum)
    return grad, hess
|
|
2520
|
+
|
|
2521
|
+
def _s2_weighted_update_cupy_blocked(self, s2, x, w, block_size, sign=1.0):
    """Blocked update for large slices: return ``s2 + sign * X^T (X * w)``.

    The rows of ``x`` are processed ``block_size`` at a time so that the
    temporary ``x * w`` never exceeds one block.  ``s2`` itself is not
    modified; a new array is returned (or ``s2`` unchanged when ``x`` has no
    rows).

    Only generic array operations are used (slicing, ``.T``, ``@`` and
    ``None`` indexing, which is what ``cupy.newaxis`` is), so no CuPy import
    is needed here.

    Parameters
    ----------
    s2 : (p, p) accumulator.
    x : (n, p) rows to add.
    w : (n,) per-row weights.
    block_size : rows per block; a non-positive value means "one block".
    sign : scalar multiplier applied to the update (callers use +1 / -1).
    """
    n = int(x.shape[0])
    if n <= 0:
        return s2

    step = int(block_size)
    if step <= 0 or step > n:
        # range() rejects a zero step and yields nothing for a negative one.
        step = n

    for start in range(0, n, step):
        stop = min(start + step, n)
        xb = x[start:stop]
        wb = w[start:stop]
        s2 = s2 + sign * (xb.T @ (xb * wb[:, None]))
    return s2
|
|
2534
|
+
|
|
2535
|
+
def _get_entry_s2_fused_kernel_cupy(self):
    """Return the CuPy RawKernel for the fused weighted X^T X product.

    The kernel object is created on first use and stored on the instance as
    ``_entry_s2_fused_kernel_cupy``; later calls return the stored object.

    Kernel contract (from the CUDA source below): one thread per output
    element ``(i, j)`` of a ``p x p`` row-major matrix, each looping over
    the ``n`` rows of row-major ``x`` and writing
    ``out[i, j] = sum_r w[r] * x[r, i] * x[r, j]``.
    """
    cached = getattr(self, "_entry_s2_fused_kernel_cupy", None)
    if cached is None:
        import cupy as cp

        kernel_source = r"""
extern "C" __global__
void entry_s2_outer_f64(const double* x, const double* w, double* out, int n, int p) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    int j = blockIdx.y * blockDim.y + threadIdx.y;
    if (i >= p || j >= p) return;
    double acc = 0.0;
    for (int r = 0; r < n; ++r) {
        double wr = w[r];
        double xi = x[(size_t)r * (size_t)p + (size_t)i];
        double xj = x[(size_t)r * (size_t)p + (size_t)j];
        acc += wr * xi * xj;
    }
    out[(size_t)i * (size_t)p + (size_t)j] = acc;
}
"""
        cached = cp.RawKernel(kernel_source, "entry_s2_outer_f64")
        self._entry_s2_fused_kernel_cupy = cached
    return cached
|
|
2561
|
+
|
|
2562
|
+
def _s2_weighted_update_cupy_fused(self, s2, x, w, sign=1.0):
    """CuPy fused kernel update: return ``s2 + sign * X^T (X * w)``.

    ``x`` and ``w`` are converted to contiguous float64 arrays, the raw
    kernel from ``_get_entry_s2_fused_kernel_cupy`` fills a fresh ``(p, p)``
    buffer with ``X^T diag(w) X``, and the scaled buffer is added to ``s2``.
    ``s2`` is not modified in place; it is returned unchanged when ``x`` has
    no rows or no columns.

    ``sign`` is applied as a multiplier, matching
    ``_s2_weighted_update_cupy_blocked``; the values +1 and -1 used by the
    callers take the plain add/subtract shortcuts.

    Raises
    ------
    ValueError
        If ``w`` does not have one weight per row of ``x`` (the kernel would
        otherwise read past the end of ``w``).
    """
    import cupy as cp

    n = int(x.shape[0])
    if n <= 0:
        return s2
    x = cp.ascontiguousarray(x, dtype=cp.float64)
    w = cp.ascontiguousarray(w, dtype=cp.float64)
    if int(w.size) != n:
        raise ValueError(
            f"w must have one weight per row of x: got {int(w.size)} weights for {n} rows"
        )
    p = int(x.shape[1])
    if p == 0:
        # Nothing to compute, and a zero-sized grid is not a valid launch.
        return s2

    out = cp.empty((p, p), dtype=cp.float64)
    threads = (16, 16, 1)
    tiles = (p + 15) // 16  # ceil(p / 16) thread blocks per axis
    blocks = (tiles, tiles, 1)
    ker = self._get_entry_s2_fused_kernel_cupy()
    ker(blocks, threads, (x, w, out, np.int32(n), np.int32(p)))

    if sign == 1.0:
        return s2 + out
    if sign == -1.0:
        return s2 - out
    return s2 + sign * out
|
|
2580
|
+
|
|
2581
|
+
def _compute_gradient_hessian_efron_backward_gpu(self, beta, X, efron_pre):
    """CuPy Efron grad/Hessian dispatcher.

    Order of preference:

    1. With no unique failure times, return zero arrays.
    2. If ``STATGPU_EFRON_GROUPED_GEMM`` is enabled (default on),
       ``n_features <= 192`` and there are at least 24 rows per unique
       failure time on average, use the grouped-GEMM implementation.
    3. Otherwise try the single CUDA RawKernel scan
       (``compute_efron_grad_hess_raw``), passing the cached
       ``_efron_pre_csr_gpu`` when the instance has one.
    4. If the raw kernel is unavailable, raises, or returns ``None``, fall
       back to the grouped-GEMM implementation.  It evaluates the same
       backward-scan formulas the former in-line Python loop did, using
       matrix products instead of ``einsum``.

    Returns ``(grad, hess)`` as CuPy arrays of shape ``(p,)`` and ``(p, p)``.
    """
    import cupy as cp

    _, _, _, _, nuft, _ = _unpack_efron_pre6(efron_pre)
    n_features = X.shape[1]
    if nuft == 0:
        return cp.zeros(n_features, dtype=cp.float64), cp.zeros(
            (n_features, n_features), dtype=cp.float64
        )

    n_samples = int(X.shape[0])
    # Average number of rows per unique failure time.
    avg_tie = float(n_samples) / max(1.0, float(nuft))
    use_grouped_gemm = (
        os.environ.get("STATGPU_EFRON_GROUPED_GEMM", "1").strip().lower()
        in ("1", "true", "yes", "on")
    )
    if use_grouped_gemm and n_features <= 192 and avg_tie >= 24.0:
        return self._compute_gradient_hessian_efron_grouped_gemm_cupy(
            beta, X, efron_pre
        )

    try:
        from ._cox_efron_cuda import compute_efron_grad_hess_raw

        csr_gpu = getattr(self, "_efron_pre_csr_gpu", None)
        if csr_gpu is not None:
            out = compute_efron_grad_hess_raw(
                X,
                beta,
                efron_pre,
                efron_csr=csr_gpu,
                cupy_module=cp,
            )
        else:
            out = compute_efron_grad_hess_raw(X, beta, efron_pre, cupy_module=cp)
        if out is not None:
            return out[0], out[1]
    except Exception:
        # Deliberate best-effort: any kernel/import failure selects the fallback.
        pass

    return self._compute_gradient_hessian_efron_grouped_gemm_cupy(
        beta, X, efron_pre
    )
|
|
2680
|
+
|
|
2681
|
+
def _compute_gradient_hessian_efron_grouped_gemm_cupy(self, beta, X, efron_pre):
    """Exact Efron grad/hess on CuPy via grouped GEMM updates (no p^2 atomics).

    Scans the unique failure times from last to first.  At each step the
    rows listed in ``risk_enter`` are added to the running risk-set sums
    (``s0``, ``s1``, ``s2`` of ``w``, ``w*x`` and ``w*x*x^T``), the Efron
    terms for the failures in ``uft_ix`` are accumulated, and the rows in
    ``risk_exit`` are subtracted again.  The linear predictor is shifted by
    its maximum before exponentiation.

    Returns ``(grad, hess)`` where ``hess`` is the negated accumulated
    matrix.
    """
    import cupy as cp

    _, uft_ix, risk_enter, risk_exit, nuft, _ = _unpack_efron_pre6(efron_pre)
    n_features = int(X.shape[1])

    eta = X @ beta
    weights = cp.exp(eta - cp.max(eta))

    def weighted_moments(rows):
        # Rows -> (x, sum w, sum w*x, x^T diag(w) x) for the selected samples.
        sel = cp.asarray(rows, dtype=cp.int32)
        x_sel = X[sel]
        w_sel = weights[sel]
        wx = x_sel * w_sel[:, None]
        return x_sel, cp.sum(w_sel), cp.sum(wx, axis=0), wx.T @ x_sel

    grad = cp.zeros(n_features, dtype=cp.float64)
    info = cp.zeros((n_features, n_features), dtype=cp.float64)
    s0 = cp.zeros((), dtype=cp.float64)
    s1 = cp.zeros(n_features, dtype=cp.float64)
    s2 = cp.zeros((n_features, n_features), dtype=cp.float64)
    # k/m fractions depend only on the tie size m, so build each once.
    frac_by_size = {}

    for g in reversed(range(nuft)):
        entering = risk_enter[g]
        if len(entering) > 0:
            _, a0, a1, a2 = weighted_moments(entering)
            s0 = s0 + a0
            s1 = s1 + a1
            s2 = s2 + a2

        failed = uft_ix[g]
        m = len(failed)
        if m > 0:
            x_f, f0, f1, f2 = weighted_moments(failed)
            if m not in frac_by_size:
                frac_by_size[m] = cp.arange(m, dtype=cp.float64) / float(max(m, 1))
            fracs = frac_by_size[m]

            # 1/den_k and (k/m)/den_k with den_k = s0 - (k/m) * f0, floored at 1e-300.
            inv_den = 1.0 / cp.maximum(s0 - fracs * f0, 1e-300)
            frac_inv_den = fracs * inv_den
            sum_a = cp.sum(inv_den)
            sum_b = cp.sum(frac_inv_den)
            sum_aa = cp.sum(inv_den * inv_den)
            sum_bb = cp.sum(frac_inv_den * frac_inv_den)
            sum_ab = cp.sum(inv_den * frac_inv_den)

            grad = grad + cp.sum(x_f, axis=0)
            grad = grad - (s1 * sum_a - f1 * sum_b)
            info = info + s2 * sum_a
            info = info - f2 * sum_b
            info = info - (
                sum_aa * cp.outer(s1, s1)
                + sum_bb * cp.outer(f1, f1)
                - sum_ab * (cp.outer(s1, f1) + cp.outer(f1, s1))
            )

        leaving = risk_exit[g]
        if len(leaving) > 0:
            _, e0, e1, e2 = weighted_moments(leaving)
            s0 = s0 - e0
            s1 = s1 - e1
            s2 = s2 - e2

    return grad, -info
|
|
2752
|
+
|
|
2753
|
+
def _solve_newton_delta_torch(self, hess, grad):
|
|
2754
|
+
"""Newton step delta = inv(hess) @ grad; prefer SPD solve on (-hess) with light jitter."""
|
|
2755
|
+
import torch
|
|
2756
|
+
|
|
2757
|
+
p = int(hess.shape[0])
|
|
2758
|
+
try:
|
|
2759
|
+
H = -hess
|
|
2760
|
+
eps = 1e-11 * (torch.max(torch.abs(torch.diag(H))) + 1.0)
|
|
2761
|
+
H = H + eps * torch.eye(p, dtype=torch.float64, device=hess.device)
|
|
2762
|
+
return -torch.linalg.solve(H, grad)
|
|
2763
|
+
except Exception:
|
|
2764
|
+
try:
|
|
2765
|
+
return torch.linalg.solve(hess, grad)
|
|
2766
|
+
except Exception:
|
|
2767
|
+
result = torch.linalg.lstsq(hess, grad)
|
|
2768
|
+
return result.solution.flatten()
|
|
2769
|
+
|
|
2770
|
+
def _compute_gradient_hessian_efron_grouped_gemm_torch(self, beta, X, efron_pre):
    """Efron gradient/Hessian on the Torch device using per-group GEMM updates.

    Walks the unique failure times from last to first, maintaining running
    risk-set moments (``xp0``, ``xp1``, ``xp2``) that are increased by the
    rows entering the risk set and decreased by the rows leaving it, and
    accumulates the Efron tie-corrected contributions of each failure group.

    Parameters
    ----------
    beta : torch.Tensor
        Coefficient vector; its device is used for every allocation.
    X : torch.Tensor
        Design matrix, indexed by the row indices stored in ``efron_pre``.
    efron_pre :
        Precomputed Efron structure, unpacked with ``_unpack_efron_pre6``.

    Returns
    -------
    tuple
        ``(grad, -hess_inner)``.
    """
    import torch

    _, uft_ix, risk_enter, risk_exit, nuft, _ = _unpack_efron_pre6(efron_pre)
    device = beta.device
    f64 = torch.float64
    n_features = int(X.shape[1])

    # Shift the linear predictor by its maximum before exponentiating.
    linpred = X @ beta
    linpred = linpred - torch.max(linpred)
    e_linpred = torch.exp(linpred)

    def _group_moments(rows):
        """Return (x, sum w, sum w*x, x^T diag(w) x) for the given row indices."""
        sel = torch.as_tensor(rows, dtype=torch.long, device=device)
        x_rows = X[sel]
        w_rows = e_linpred[sel]
        wx_rows = x_rows * w_rows[:, None]
        return (
            x_rows,
            torch.sum(w_rows),
            torch.sum(wx_rows, dim=0),
            wx_rows.transpose(0, 1) @ x_rows,
        )

    grad = torch.zeros(n_features, dtype=f64, device=device)
    hess_inner = torch.zeros((n_features, n_features), dtype=f64, device=device)
    xp0 = torch.zeros((), dtype=f64, device=device)
    xp1 = torch.zeros(n_features, dtype=f64, device=device)
    xp2 = torch.zeros((n_features, n_features), dtype=f64, device=device)
    # Cache of the k/m fractions, keyed by tie-group size m.
    j_cache = {}

    for i in reversed(range(nuft)):
        entering = risk_enter[i]
        if len(entering) > 0:
            _, m0, m1, m2 = _group_moments(entering)
            xp0 = xp0 + m0
            xp1 = xp1 + m1
            xp2 = xp2 + m2

        failing = uft_ix[i]
        if len(failing) > 0:
            x_fail, xp0f, xp1f, xp2f = _group_moments(failing)
            m = len(failing)
            if m not in j_cache:
                j_cache[m] = torch.arange(m, dtype=f64, device=device) / float(max(m, 1))
            J = j_cache[m]

            # Efron denominators xp0 - (k/m) * xp0f for k = 0..m-1, floored.
            inv_c0 = 1.0 / torch.clamp(xp0 - J * xp0f, min=1e-300)
            j_inv_c0 = J * inv_c0
            sum_inv_c0 = torch.sum(inv_c0)
            sum_J_c0 = torch.sum(j_inv_c0)
            sum_aa = torch.sum(inv_c0 * inv_c0)
            sum_bb = torch.sum(j_inv_c0 * j_inv_c0)
            sum_ab = torch.sum(inv_c0 * j_inv_c0)

            grad = grad + torch.sum(x_fail, dim=0)
            grad = grad - (xp1 * sum_inv_c0 - xp1f * sum_J_c0)

            hess_inner = hess_inner + xp2 * sum_inv_c0
            hess_inner = hess_inner - xp2f * sum_J_c0
            hess_inner = hess_inner - (
                sum_aa * torch.outer(xp1, xp1)
                + sum_bb * torch.outer(xp1f, xp1f)
                - sum_ab * (torch.outer(xp1, xp1f) + torch.outer(xp1f, xp1))
            )

        leaving = risk_exit[i]
        if len(leaving) > 0:
            _, m0, m1, m2 = _group_moments(leaving)
            xp0 = xp0 - m0
            xp1 = xp1 - m1
            xp2 = xp2 - m2

    return grad, -hess_inner
|
|
2841
|
+
|
|
2842
|
+
def _compute_log_likelihood_torch(self, beta, X, time, event, efron_pre=None, entry=None, entry_ctx=None):
|
|
2843
|
+
"""Compute log partial likelihood on Torch."""
|
|
2844
|
+
import torch
|
|
2845
|
+
|
|
2846
|
+
eta = X @ beta
|
|
2847
|
+
exp_eta = torch.exp(eta)
|
|
2848
|
+
# Entry+breslow path does not consume risk_sum; skip the cumsum to
|
|
2849
|
+
# reduce per-evaluation overhead during line-search probes.
|
|
2850
|
+
risk_sum = None if entry is not None else torch.cumsum(exp_eta.flip(0), dim=0).flip(0)
|
|
2851
|
+
return self._compute_log_likelihood_torch_from_stats(
|
|
2852
|
+
eta, exp_eta, risk_sum, time, event, efron_pre, entry=entry, entry_ctx=entry_ctx
|
|
2853
|
+
)
|
|
2854
|
+
|
|
2855
|
+
def _build_entry_ctx_torch(self, time, event, entry, device):
|
|
2856
|
+
"""Build entry-time grouped indexing context for a specific sorted Torch view."""
|
|
2857
|
+
import torch
|
|
2858
|
+
|
|
2859
|
+
event_mask = event == 1
|
|
2860
|
+
event_idx = torch.where(event_mask)[0]
|
|
2861
|
+
evt_t = time[event_idx].detach().cpu().numpy()
|
|
2862
|
+
if evt_t.size == 0:
|
|
2863
|
+
return (
|
|
2864
|
+
torch.zeros((0,), dtype=torch.long, device=device),
|
|
2865
|
+
np.zeros((0,), dtype=np.float64),
|
|
2866
|
+
np.zeros((0,), dtype=np.int64),
|
|
2867
|
+
np.zeros((0,), dtype=np.int64),
|
|
2868
|
+
torch.zeros((0,), dtype=torch.long, device=device),
|
|
2869
|
+
torch.zeros((0,), dtype=torch.long, device=device),
|
|
2870
|
+
np.zeros((1,), dtype=np.int64),
|
|
2871
|
+
)
|
|
2872
|
+
uft_np, d_counts = np.unique(evt_t, return_counts=True)
|
|
2873
|
+
d_counts = d_counts.astype(np.float64, copy=False)
|
|
2874
|
+
entry_order = torch.argsort(entry, stable=True)
|
|
2875
|
+
entry_sorted_np = entry.index_select(0, entry_order).detach().cpu().numpy()
|
|
2876
|
+
time_np = time.detach().cpu().numpy()
|
|
2877
|
+
add_end_np = np.searchsorted(entry_sorted_np, uft_np, side="left").astype(np.int64, copy=False)
|
|
2878
|
+
rem_end_np = np.searchsorted(time_np, uft_np, side="left").astype(np.int64, copy=False)
|
|
2879
|
+
rem_order = torch.arange(int(time.shape[0]), dtype=torch.long, device=device)
|
|
2880
|
+
event_idx = event_idx.to(torch.long)
|
|
2881
|
+
fail_ptr = np.empty(d_counts.shape[0] + 1, dtype=np.int64)
|
|
2882
|
+
fail_ptr[0] = 0
|
|
2883
|
+
fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)
|
|
2884
|
+
return (entry_order, d_counts, add_end_np, rem_end_np, rem_order, event_idx, fail_ptr)
|
|
2885
|
+
|
|
2886
|
+
def _compute_log_likelihood_torch_from_stats(
    self, eta, exp_eta, risk_sum, time, event, efron_pre=None, entry=None, entry_ctx=None
):
    """Compute the log partial likelihood on Torch from precomputed statistics.

    Three code paths are selected from ``entry`` and ``self.ties``:

    * ``entry`` given: risk-set sums are formed as differences of two prefix
      sums of ``exp_eta`` (one over rows in ``entry_order``, one over rows as
      given), followed by the Breslow closed form or an Efron per-tie loop.
    * no ``entry``, ``self.ties == "breslow"``: ``risk_sum`` is read at the
      first row of each unique failure time.
    * no ``entry``, otherwise (Efron): a Breslow-style shortcut when
      ``efron_pre`` is given and ``self._efron_all_singletons`` is truthy,
      else a Python loop over unique failure times.

    Parameters
    ----------
    eta, exp_eta : torch.Tensor
        Linear predictor and its exponential.
    risk_sum : torch.Tensor or None
        Read only on the no-``entry`` paths.
    time, event : torch.Tensor
        Rows with ``event == 1`` are treated as failures.  The no-``entry``
        paths call ``searchsorted`` on ``time`` / take the first index with
        ``time >= t`` - presumably ``time`` is sorted ascending; confirm
        against the caller.
    efron_pre : optional
        Precomputed Efron structure (unpacked with ``_unpack_efron_pre6``).
    entry : torch.Tensor, optional
        Entry times; enables the entry path.
    entry_ctx : sequence, optional
        Cached context for the entry path; built on the fly when ``None``.

    Returns
    -------
    torch.Tensor
        Scalar (0-d) log partial likelihood; ``0.0`` when there are no events.
    """
    import torch

    ll = torch.tensor(0.0, dtype=torch.float64, device=eta.device)
    event_mask = event == 1

    if not torch.any(event_mask):
        return ll

    if entry is not None:
        if entry_ctx is None:
            entry_order, d_counts, add_end_np, rem_end_np, _rem_order, event_idx, fail_ptr = self._build_entry_ctx_torch(
                time, event, entry, eta.device
            )
        else:
            # NOTE(review): a supplied entry_ctx is read with event_idx at
            # index 6 and fail_ptr at index 8, which is NOT the layout of the
            # 7-tuple returned by _build_entry_ctx_torch (event_idx at 5,
            # fail_ptr at 6).  Presumably callers pass an extended context;
            # verify before passing the builder's tuple here directly.
            entry_order, d_counts, add_end_np, rem_end_np = entry_ctx[:4]
            event_idx = entry_ctx[6] if len(entry_ctx) > 6 else torch.where(event_mask)[0]
            fail_ptr = entry_ctx[8] if len(entry_ctx) > 8 else None

        n_groups = int(d_counts.shape[0])
        if n_groups == 0:
            return torch.tensor(0.0, dtype=torch.float64, device=eta.device)
        if fail_ptr is None:
            # Prefix sums of the per-group event counts, with a leading 0.
            fail_ptr = np.empty(n_groups + 1, dtype=np.int64)
            fail_ptr[0] = 0
            fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)

        exp_entry = exp_eta.index_select(0, entry_order)
        exp_rem = exp_eta
        s0_add_pref = torch.cumsum(exp_entry, dim=0)
        s0_rem_pref = torch.cumsum(exp_rem, dim=0)
        s0_add = torch.zeros(n_groups, dtype=torch.float64, device=eta.device)
        s0_rem = torch.zeros(n_groups, dtype=torch.float64, device=eta.device)
        # An end offset of 0 means "no rows yet": leave the sum at 0 instead of
        # indexing the prefix array at -1.
        mask_add = add_end_np > 0
        mask_rem = rem_end_np > 0
        if np.any(mask_add):
            idx_add = torch.as_tensor(add_end_np[mask_add] - 1, dtype=torch.long, device=eta.device)
            s0_add[torch.as_tensor(mask_add, dtype=torch.bool, device=eta.device)] = s0_add_pref.index_select(0, idx_add)
        if np.any(mask_rem):
            idx_rem = torch.as_tensor(rem_end_np[mask_rem] - 1, dtype=torch.long, device=eta.device)
            s0_rem[torch.as_tensor(mask_rem, dtype=torch.bool, device=eta.device)] = s0_rem_pref.index_select(0, idx_rem)
        # Per-group risk-set sum, floored so the log below stays finite.
        s0_vec = torch.clamp(s0_add - s0_rem, min=1e-300)
        event_eta = eta.index_select(0, event_idx)

        if self.ties == "breslow":
            d_vec = torch.as_tensor(d_counts, dtype=torch.float64, device=eta.device)
            return torch.sum(event_eta) - torch.sum(d_vec * torch.log(s0_vec))

        # Efron with entry times: subtract k/d of the tied events' weight for
        # the k-th of the d tied failures in each group.
        ll = torch.sum(event_eta)
        event_exp = exp_eta.index_select(0, event_idx)
        for g in range(n_groups):
            d = int(d_counts[g])
            if d <= 0:
                continue
            st = int(fail_ptr[g])
            ed = int(fail_ptr[g + 1])
            ef = torch.sum(event_exp[st:ed])
            base = s0_vec[g]
            for k in range(d):
                denom = torch.clamp(base - (float(k) / float(d)) * ef, min=1e-300)
                ll = ll - torch.log(denom)
        return ll

    if self.ties == "breslow":
        # Vectorized Breslow using cached failure groups
        breslow_pre_torch = getattr(self, "_breslow_pre_torch", None)
        if (
            breslow_pre_torch is not None
            and len(breslow_pre_torch) == 2
            and int(breslow_pre_torch[0].numel()) > 0
        ):
            first_idx_uft, counts_uft = breslow_pre_torch
        else:
            # No usable cache: derive first indices and counts from the data.
            uft, counts_uft = torch.unique(time[event_mask], return_counts=True)
            first_idx_uft = torch.searchsorted(time, uft, side="left")
            counts_uft = counts_uft.to(torch.int32)
        risk_at = risk_sum[first_idx_uft]
        return torch.sum(eta[event_mask]) - torch.sum(
            counts_uft.to(torch.float64) * torch.log(risk_at)
        )

    # Efron: keep computation fully on torch backend.
    if efron_pre is not None:
        needs_exact_ties = not getattr(self, "_efron_all_singletons", False)
        # No-tie Efron equals Breslow; keep computation on torch device.
        if not needs_exact_ties:
            _, _, _, _, nuft, first_idx_uft = _unpack_efron_pre6(efron_pre)
            first_idx_t = torch.as_tensor(first_idx_uft, dtype=torch.int64, device=eta.device)
            counts_t = torch.ones(int(nuft), dtype=torch.float64, device=eta.device)
            risk_at = risk_sum[first_idx_t]
            return torch.sum(eta[event_mask]) - torch.sum(counts_t * torch.log(risk_at))

    # Fallback Efron (loop version)
    unique_times = torch.unique(time[event_mask])
    for t in unique_times:
        at_time_t = time == t
        events_at_t = at_time_t & event_mask
        d = int(torch.sum(events_at_t).item())

        if d == 0:
            continue

        risk_indices = torch.where(time >= t)[0]
        if risk_indices.numel() == 0:
            continue

        # First row with time >= t; risk_sum there is taken as the risk-set sum.
        first_idx = risk_indices[0]
        risk_at_t = risk_sum[first_idx]
        sum_events = torch.sum(exp_eta[events_at_t])

        ll += torch.sum(eta[events_at_t])
        for k in range(d):
            ll -= torch.log(torch.maximum(risk_at_t - (k / d) * sum_events, torch.tensor(1e-300, dtype=torch.float64, device=eta.device)))

    return ll
|
|
3004
|
+
|
|
3005
|
+
def _compute_gradient_hessian_torch(
|
|
3006
|
+
self, beta, X, time, event, efron_pre=None, return_aux=False, entry=None, entry_ctx=None
|
|
3007
|
+
):
|
|
3008
|
+
"""Fully vectorized gradient/Hessian for Torch - Efron and Breslow."""
|
|
3009
|
+
import torch
|
|
3010
|
+
n_samples, n_features = X.shape
|
|
3011
|
+
eta = X @ beta
|
|
3012
|
+
exp_eta = torch.exp(eta)
|
|
3013
|
+
rev_idx = torch.arange(n_samples - 1, -1, -1, device=beta.device)
|
|
3014
|
+
risk_sum = torch.cumsum(exp_eta[rev_idx], dim=0)[rev_idx] if entry is None else None
|
|
3015
|
+
|
|
3016
|
+
if self.ties == "efron" and efron_pre is not None and entry is None:
|
|
3017
|
+
needs_exact_ties = not getattr(self, "_efron_all_singletons", False)
|
|
3018
|
+
n_samples = int(X.shape[0])
|
|
3019
|
+
avg_tie = float(n_samples) / max(1.0, float(_unpack_efron_pre6(efron_pre)[4]))
|
|
3020
|
+
use_grouped_gemm = (
|
|
3021
|
+
os.environ.get("STATGPU_EFRON_GROUPED_GEMM", "1").strip().lower()
|
|
3022
|
+
in ("1", "true", "yes", "on")
|
|
3023
|
+
)
|
|
3024
|
+
# For real ties, use exact torch grouped GEMM path only.
|
|
3025
|
+
if needs_exact_ties and (
|
|
3026
|
+
use_grouped_gemm
|
|
3027
|
+
and beta.is_cuda
|
|
3028
|
+
and n_features <= 192
|
|
3029
|
+
and avg_tie >= 24.0
|
|
3030
|
+
):
|
|
3031
|
+
out = self._compute_gradient_hessian_efron_grouped_gemm_torch(
|
|
3032
|
+
beta, X, efron_pre
|
|
3033
|
+
)
|
|
3034
|
+
if return_aux:
|
|
3035
|
+
return out[0], out[1], (eta, exp_eta, risk_sum)
|
|
3036
|
+
return out
|
|
3037
|
+
|
|
3038
|
+
# ---- Triton Efron path ----
|
|
3039
|
+
if (
|
|
3040
|
+
os.environ.get("STATGPU_EFRON_TRITON", "0").strip().lower()
|
|
3041
|
+
in ("1", "true", "yes", "on")
|
|
3042
|
+
and beta.is_cuda
|
|
3043
|
+
and efron_pre is not None
|
|
3044
|
+
):
|
|
3045
|
+
from statgpu.survival._cox_efron_triton import compute_efron_grad_hess_triton
|
|
3046
|
+
triton_out = compute_efron_grad_hess_triton(X, beta, efron_pre)
|
|
3047
|
+
if triton_out is not None:
|
|
3048
|
+
grad, hess = triton_out
|
|
3049
|
+
if return_aux:
|
|
3050
|
+
return grad, hess, (eta, exp_eta, risk_sum)
|
|
3051
|
+
return grad, hess
|
|
3052
|
+
|
|
3053
|
+
# Reverse cumsum for risk sets (vectorized)
|
|
3054
|
+
risk_X_sum = torch.cumsum((X * exp_eta[:, None])[rev_idx], dim=0)[rev_idx] if entry is None else None
|
|
3055
|
+
|
|
3056
|
+
event_mask = event == 1
|
|
3057
|
+
if not torch.any(event_mask):
|
|
3058
|
+
out = (
|
|
3059
|
+
torch.zeros(n_features, dtype=torch.float64, device=beta.device),
|
|
3060
|
+
torch.zeros((n_features, n_features), dtype=torch.float64, device=beta.device),
|
|
3061
|
+
)
|
|
3062
|
+
if return_aux:
|
|
3063
|
+
return out[0], out[1], (eta, exp_eta, risk_sum)
|
|
3064
|
+
return out
|
|
3065
|
+
|
|
3066
|
+
if entry is not None:
|
|
3067
|
+
if entry_ctx is None:
|
|
3068
|
+
entry_order, d_counts, add_end_np, rem_end_np, rem_order, event_idx, fail_ptr = self._build_entry_ctx_torch(
|
|
3069
|
+
time, event, entry, beta.device
|
|
3070
|
+
)
|
|
3071
|
+
X_entry = X.index_select(0, entry_order).contiguous()
|
|
3072
|
+
X_rem = X.index_select(0, rem_order).contiguous()
|
|
3073
|
+
grad = torch.sum(X.index_select(0, event_idx), dim=0)
|
|
3074
|
+
else:
|
|
3075
|
+
entry_order, d_counts, add_end_np, rem_end_np = entry_ctx[:4]
|
|
3076
|
+
X_entry = entry_ctx[4] if len(entry_ctx) > 4 else X.index_select(0, entry_order)
|
|
3077
|
+
X_rem = entry_ctx[5] if len(entry_ctx) > 5 else X
|
|
3078
|
+
event_idx = entry_ctx[6] if len(entry_ctx) > 6 else torch.where(event_mask)[0]
|
|
3079
|
+
grad = entry_ctx[7] if len(entry_ctx) > 7 else torch.sum(X[event_mask], dim=0)
|
|
3080
|
+
fail_ptr = entry_ctx[8] if len(entry_ctx) > 8 else None
|
|
3081
|
+
hess = torch.zeros((n_features, n_features), dtype=torch.float64, device=beta.device)
|
|
3082
|
+
exp_entry = exp_eta.index_select(0, entry_order)
|
|
3083
|
+
exp_rem = exp_eta
|
|
3084
|
+
wx_entry = X_entry * exp_entry.unsqueeze(1)
|
|
3085
|
+
wx_rem = X_rem * exp_rem.unsqueeze(1)
|
|
3086
|
+
n_groups = int(d_counts.shape[0])
|
|
3087
|
+
if n_groups == 0:
|
|
3088
|
+
if return_aux:
|
|
3089
|
+
return grad, hess, (eta, exp_eta, risk_sum)
|
|
3090
|
+
return grad, hess
|
|
3091
|
+
s0_add_pref = torch.cumsum(exp_entry, dim=0)
|
|
3092
|
+
s0_rem_pref = torch.cumsum(exp_rem, dim=0)
|
|
3093
|
+
s1_add_pref = torch.cumsum(wx_entry, dim=0)
|
|
3094
|
+
s1_rem_pref = torch.cumsum(wx_rem, dim=0)
|
|
3095
|
+
s0_add = torch.zeros(n_groups, dtype=torch.float64, device=beta.device)
|
|
3096
|
+
s0_rem = torch.zeros(n_groups, dtype=torch.float64, device=beta.device)
|
|
3097
|
+
s1_add = torch.zeros((n_groups, n_features), dtype=torch.float64, device=beta.device)
|
|
3098
|
+
s1_rem = torch.zeros((n_groups, n_features), dtype=torch.float64, device=beta.device)
|
|
3099
|
+
mask_add = add_end_np > 0
|
|
3100
|
+
mask_rem = rem_end_np > 0
|
|
3101
|
+
if np.any(mask_add):
|
|
3102
|
+
idx_add = torch.as_tensor(add_end_np[mask_add] - 1, dtype=torch.long, device=beta.device)
|
|
3103
|
+
mask_add_t = torch.as_tensor(mask_add, dtype=torch.bool, device=beta.device)
|
|
3104
|
+
s0_add[mask_add_t] = s0_add_pref.index_select(0, idx_add)
|
|
3105
|
+
s1_add[mask_add_t] = s1_add_pref.index_select(0, idx_add)
|
|
3106
|
+
if np.any(mask_rem):
|
|
3107
|
+
idx_rem = torch.as_tensor(rem_end_np[mask_rem] - 1, dtype=torch.long, device=beta.device)
|
|
3108
|
+
mask_rem_t = torch.as_tensor(mask_rem, dtype=torch.bool, device=beta.device)
|
|
3109
|
+
s0_rem[mask_rem_t] = s0_rem_pref.index_select(0, idx_rem)
|
|
3110
|
+
s1_rem[mask_rem_t] = s1_rem_pref.index_select(0, idx_rem)
|
|
3111
|
+
s0_vec = s0_add - s0_rem
|
|
3112
|
+
s1_vec = s1_add - s1_rem
|
|
3113
|
+
d_vec = torch.as_tensor(d_counts, dtype=torch.float64, device=beta.device)
|
|
3114
|
+
s0_safe_vec = torch.clamp(s0_vec, min=1e-15)
|
|
3115
|
+
use_efron_entry = (self.ties == "efron")
|
|
3116
|
+
ex_vec = s1_vec / s0_safe_vec.unsqueeze(1)
|
|
3117
|
+
if not use_efron_entry:
|
|
3118
|
+
grad = grad - torch.sum(d_vec.unsqueeze(1) * ex_vec, dim=0)
|
|
3119
|
+
if use_efron_entry:
|
|
3120
|
+
if fail_ptr is None:
|
|
3121
|
+
fail_ptr = np.empty(n_groups + 1, dtype=np.int64)
|
|
3122
|
+
fail_ptr[0] = 0
|
|
3123
|
+
fail_ptr[1:] = np.cumsum(d_counts.astype(np.int64), dtype=np.int64)
|
|
3124
|
+
event_exp = exp_eta.index_select(0, event_idx)
|
|
3125
|
+
X_fail = X.index_select(0, event_idx)
|
|
3126
|
+
add_ptr = 0
|
|
3127
|
+
rem_ptr = 0
|
|
3128
|
+
s2 = torch.zeros((n_features, n_features), dtype=torch.float64, device=beta.device)
|
|
3129
|
+
s2_block_size = int(os.environ.get("STATGPU_ENTRY_S2_BLOCK_SIZE", "8192"))
|
|
3130
|
+
if s2_block_size <= 0:
|
|
3131
|
+
s2_block_size = 10**18
|
|
3132
|
+
s2_fn = self._get_entry_s2_torch_fn()
|
|
3133
|
+
for g in range(n_groups):
|
|
3134
|
+
add_end = int(add_end_np[g])
|
|
3135
|
+
if add_end > add_ptr:
|
|
3136
|
+
x_add = X_entry[add_ptr:add_end]
|
|
3137
|
+
w_add = exp_entry[add_ptr:add_end]
|
|
3138
|
+
n_add = int(add_end - add_ptr)
|
|
3139
|
+
if n_add <= s2_block_size:
|
|
3140
|
+
s2 = s2 + s2_fn(x_add, w_add)
|
|
3141
|
+
else:
|
|
3142
|
+
s2 = self._s2_weighted_update_torch_blocked(
|
|
3143
|
+
s2, x_add, w_add, s2_block_size, sign=1.0
|
|
3144
|
+
)
|
|
3145
|
+
add_ptr = add_end
|
|
3146
|
+
|
|
3147
|
+
rem_end = int(rem_end_np[g])
|
|
3148
|
+
if rem_end > rem_ptr:
|
|
3149
|
+
x_rem = X_rem[rem_ptr:rem_end]
|
|
3150
|
+
w_rem = exp_eta[rem_ptr:rem_end]
|
|
3151
|
+
n_rem = int(rem_end - rem_ptr)
|
|
3152
|
+
if n_rem <= s2_block_size:
|
|
3153
|
+
s2 = s2 - s2_fn(x_rem, w_rem)
|
|
3154
|
+
else:
|
|
3155
|
+
s2 = self._s2_weighted_update_torch_blocked(
|
|
3156
|
+
s2, x_rem, w_rem, s2_block_size, sign=-1.0
|
|
3157
|
+
)
|
|
3158
|
+
rem_ptr = rem_end
|
|
3159
|
+
|
|
3160
|
+
d_t_f = float(d_counts[g])
|
|
3161
|
+
if d_t_f <= 0:
|
|
3162
|
+
continue
|
|
3163
|
+
if use_efron_entry:
|
|
3164
|
+
st = int(fail_ptr[g])
|
|
3165
|
+
ed = int(fail_ptr[g + 1])
|
|
3166
|
+
ef = event_exp[st:ed]
|
|
3167
|
+
xf = X_fail[st:ed]
|
|
3168
|
+
ef_sum = torch.sum(ef)
|
|
3169
|
+
ef_x_sum = torch.sum(xf * ef.unsqueeze(1), dim=0)
|
|
3170
|
+
ef_x2_sum = xf.transpose(0, 1) @ (xf * ef.unsqueeze(1))
|
|
3171
|
+
s0_g = torch.clamp(s0_vec[g], min=1e-15)
|
|
3172
|
+
s1_g = s1_vec[g]
|
|
3173
|
+
d_i = int(d_t_f)
|
|
3174
|
+
for k in range(d_i):
|
|
3175
|
+
frac = float(k) / float(d_i)
|
|
3176
|
+
denom = torch.clamp(s0_g - frac * ef_sum, min=1e-15)
|
|
3177
|
+
s1_k = s1_g - frac * ef_x_sum
|
|
3178
|
+
s2_k = s2 - frac * ef_x2_sum
|
|
3179
|
+
ex_k = s1_k / denom
|
|
3180
|
+
grad = grad - ex_k
|
|
3181
|
+
hess = hess - (s2_k / denom)
|
|
3182
|
+
hess = hess + torch.outer(ex_k, ex_k)
|
|
3183
|
+
else:
|
|
3184
|
+
s0_safe = s0_safe_vec[g]
|
|
3185
|
+
hess = hess - (d_t_f / s0_safe) * s2
|
|
3186
|
+
if not use_efron_entry:
|
|
3187
|
+
hess = hess + ex_vec.transpose(0, 1) @ (d_vec.unsqueeze(1) * ex_vec)
|
|
3188
|
+
if return_aux:
|
|
3189
|
+
return grad, hess, (eta, exp_eta, risk_sum)
|
|
3190
|
+
return grad, hess
|
|
3191
|
+
|
|
3192
|
+
# Get event data
|
|
3193
|
+
event_times = time[event_mask]
|
|
3194
|
+
|
|
3195
|
+
# Unique failure times with inverse mapping
|
|
3196
|
+
uft, unique_inv = torch.unique(event_times, sorted=True, return_inverse=True)
|
|
3197
|
+
n_uft = len(uft)
|
|
3198
|
+
counts = torch.bincount(unique_inv).to(torch.float64)
|
|
3199
|
+
|
|
3200
|
+
# Get first index of each unique time
|
|
3201
|
+
sorted_times, sort_idx = torch.sort(time)
|
|
3202
|
+
first_in_sorted = torch.searchsorted(sorted_times, uft, side="left")
|
|
3203
|
+
first_idx = sort_idx[first_in_sorted]
|
|
3204
|
+
|
|
3205
|
+
# Risk values at unique times
|
|
3206
|
+
risk_at_uft = risk_sum[first_idx]
|
|
3207
|
+
risk_X_at_uft = risk_X_sum[first_idx]
|
|
3208
|
+
E_X_at_uft = risk_X_at_uft / risk_at_uft[:, None]
|
|
3209
|
+
|
|
3210
|
+
# Sum X and exp(eta) for events at each unique time
|
|
3211
|
+
event_indices = event_mask.nonzero(as_tuple=True)[0]
|
|
3212
|
+
sum_X_per_uft = torch.zeros((n_uft, n_features), dtype=torch.float64, device=beta.device)
|
|
3213
|
+
sum_X_per_uft.index_add_(0, unique_inv, X[event_indices])
|
|
3214
|
+
|
|
3215
|
+
# ============= GRADIENT =============
|
|
3216
|
+
if self.ties == "efron":
|
|
3217
|
+
# Efron closed-form: (d+1)/2 * E[X|R]
|
|
3218
|
+
efron_weight = (counts + 1) / 2.0
|
|
3219
|
+
grad = torch.sum(sum_X_per_uft - efron_weight[:, None] * E_X_at_uft, dim=0)
|
|
3220
|
+
else:
|
|
3221
|
+
# Breslow: d * E[X|R]
|
|
3222
|
+
grad = torch.sum(sum_X_per_uft - counts[:, None] * E_X_at_uft, dim=0)
|
|
3223
|
+
|
|
3224
|
+
# Hessian
|
|
3225
|
+
# Use incremental risk-set second moments to avoid materializing
|
|
3226
|
+
# a (n_samples, n_features, n_features) tensor on GPU (can OOM at 50k x 100).
|
|
3227
|
+
X_exp = X * exp_eta[:, None]
|
|
3228
|
+
risk_X2 = X_exp.transpose(0, 1) @ X
|
|
3229
|
+
|
|
3230
|
+
# Weight by counts (Breslow) or Efron-adjusted weights
|
|
3231
|
+
if self.ties == "efron":
|
|
3232
|
+
weights = efron_weight
|
|
3233
|
+
else:
|
|
3234
|
+
weights = counts
|
|
3235
|
+
|
|
3236
|
+
# ---- Triton Breslow path ----
|
|
3237
|
+
if (
|
|
3238
|
+
self.ties != "efron"
|
|
3239
|
+
and os.environ.get("STATGPU_BRESLOW_TRITON", "0").strip().lower()
|
|
3240
|
+
in ("1", "true", "yes", "on")
|
|
3241
|
+
and beta.is_cuda
|
|
3242
|
+
):
|
|
3243
|
+
from statgpu.survival._cox_efron_triton import compute_breslow_grad_hess_triton
|
|
3244
|
+
triton_out = compute_breslow_grad_hess_triton(X, beta, time, event)
|
|
3245
|
+
if triton_out is not None:
|
|
3246
|
+
grad, hess = triton_out
|
|
3247
|
+
if return_aux:
|
|
3248
|
+
return grad, hess, (eta, exp_eta, risk_sum)
|
|
3249
|
+
return grad, hess
|
|
3250
|
+
|
|
3251
|
+
hess = torch.zeros((n_features, n_features), dtype=torch.float64, device=beta.device)
|
|
3252
|
+
prev_idx = 0
|
|
3253
|
+
for g in range(n_uft):
|
|
3254
|
+
idx = int(first_idx[g].item())
|
|
3255
|
+
if idx > prev_idx:
|
|
3256
|
+
blk = slice(prev_idx, idx)
|
|
3257
|
+
risk_X2 = risk_X2 - (X_exp[blk].transpose(0, 1) @ X[blk])
|
|
3258
|
+
prev_idx = idx
|
|
3259
|
+
|
|
3260
|
+
rs = torch.clamp(risk_at_uft[g], min=1e-300)
|
|
3261
|
+
w = weights[g]
|
|
3262
|
+
ex = E_X_at_uft[g]
|
|
3263
|
+
# Torch-native p^2 update: avoid CuPy bridge and host sync.
|
|
3264
|
+
hess.sub_(risk_X2 * (w / rs))
|
|
3265
|
+
hess.add_(torch.outer(ex, ex) * w)
|
|
3266
|
+
|
|
3267
|
+
if return_aux:
|
|
3268
|
+
return grad, hess, (eta, exp_eta, risk_sum)
|
|
3269
|
+
return grad, hess
|
|
3270
|
+
|
|
3271
|
+
def _s2_weighted_update_torch_blocked(self, s2, x, w, block_size, sign=1.0):
|
|
3272
|
+
"""Blocked update for large slices: s2 += sign * X^T (X * w)."""
|
|
3273
|
+
s2_fn = self._get_entry_s2_torch_fn()
|
|
3274
|
+
|
|
3275
|
+
n = int(x.shape[0])
|
|
3276
|
+
if n <= 0:
|
|
3277
|
+
return s2
|
|
3278
|
+
for st in range(0, n, block_size):
|
|
3279
|
+
ed = min(st + block_size, n)
|
|
3280
|
+
xb = x[st:ed]
|
|
3281
|
+
wb = w[st:ed]
|
|
3282
|
+
s2 = s2 + sign * s2_fn(xb, wb)
|
|
3283
|
+
return s2
|
|
3284
|
+
|
|
3285
|
+
def _get_entry_s2_torch_fn(self):
|
|
3286
|
+
"""Build/cache torch or torch.compile function for weighted X^T X."""
|
|
3287
|
+
fn = getattr(self, "_entry_s2_torch_fn", None)
|
|
3288
|
+
if fn is not None:
|
|
3289
|
+
return fn
|
|
3290
|
+
import torch
|
|
3291
|
+
|
|
3292
|
+
def _s2_core(x, w):
|
|
3293
|
+
return x.transpose(0, 1) @ (x * w.unsqueeze(1))
|
|
3294
|
+
|
|
3295
|
+
use_compile = (
|
|
3296
|
+
os.environ.get("STATGPU_ENTRY_S2_COMPILE_TORCH", "0").strip().lower()
|
|
3297
|
+
in ("1", "true", "yes", "on")
|
|
3298
|
+
)
|
|
3299
|
+
if use_compile and hasattr(torch, "compile"):
|
|
3300
|
+
mode = os.environ.get("STATGPU_ENTRY_S2_COMPILE_MODE", "default")
|
|
3301
|
+
try:
|
|
3302
|
+
fn = torch.compile(_s2_core, dynamic=True, fullgraph=False, mode=mode)
|
|
3303
|
+
except Exception:
|
|
3304
|
+
fn = _s2_core
|
|
3305
|
+
else:
|
|
3306
|
+
fn = _s2_core
|
|
3307
|
+
self._entry_s2_torch_fn = fn
|
|
3308
|
+
return fn
|
|
3309
|
+
|
|
3310
|
+
def _compute_cindex_torch(self, X, time, event, beta):
|
|
3311
|
+
"""Compute concordance index (C-index) on Torch."""
|
|
3312
|
+
import torch
|
|
3313
|
+
|
|
3314
|
+
# Linear predictor (risk score)
|
|
3315
|
+
risk_score = X @ beta
|
|
3316
|
+
|
|
3317
|
+
n = len(time)
|
|
3318
|
+
event_mask = (event == 1)
|
|
3319
|
+
|
|
3320
|
+
if torch.sum(event_mask) == 0:
|
|
3321
|
+
return torch.tensor(0.5, dtype=torch.float64, device=beta.device)
|
|
3322
|
+
|
|
3323
|
+
# Use chunked vectorized approach for memory efficiency
|
|
3324
|
+
event_idx = torch.where(event_mask)[0]
|
|
3325
|
+
n_events = len(event_idx)
|
|
3326
|
+
|
|
3327
|
+
if n_events == 0:
|
|
3328
|
+
return torch.tensor(float("nan"), dtype=torch.float64, device=beta.device)
|
|
3329
|
+
|
|
3330
|
+
concordant = torch.tensor(0, dtype=torch.int64, device=beta.device)
|
|
3331
|
+
permissible = torch.tensor(0, dtype=torch.int64, device=beta.device)
|
|
3332
|
+
tied_risk = torch.tensor(0, dtype=torch.int64, device=beta.device)
|
|
3333
|
+
|
|
3334
|
+
# Chunk size for memory efficiency (~128 MB per batch matrix)
|
|
3335
|
+
chunk_size = max(1, min(n_events, int(128e6 / max(n, 1))))
|
|
3336
|
+
|
|
3337
|
+
for start in range(0, n_events, chunk_size):
|
|
3338
|
+
end = min(start + chunk_size, n_events)
|
|
3339
|
+
idx_chunk = event_idx[start:end]
|
|
3340
|
+
|
|
3341
|
+
time_i = time[idx_chunk][:, None]
|
|
3342
|
+
risk_i = risk_score[idx_chunk][:, None]
|
|
3343
|
+
time_j = time[None, :]
|
|
3344
|
+
risk_j = risk_score[None, :]
|
|
3345
|
+
event_j = event[None, :]
|
|
3346
|
+
|
|
3347
|
+
# Permissible pairs: earlier time OR same time with j censored
|
|
3348
|
+
perm = (time_i < time_j) | ((time_i == time_j) & (event_j == 0))
|
|
3349
|
+
# Exclude self-comparisons
|
|
3350
|
+
chunk_indices = torch.arange(end - start, device=beta.device)
|
|
3351
|
+
perm[chunk_indices, idx_chunk] = False
|
|
3352
|
+
|
|
3353
|
+
concordant += torch.sum(perm & (risk_i > risk_j))
|
|
3354
|
+
tied_risk += torch.sum(perm & (risk_i == risk_j))
|
|
3355
|
+
permissible += torch.sum(perm)
|
|
3356
|
+
|
|
3357
|
+
if permissible > 0:
|
|
3358
|
+
return (concordant.to(torch.float64) + 0.5 * tied_risk.to(torch.float64)) / permissible.to(torch.float64)
|
|
3359
|
+
else:
|
|
3360
|
+
return torch.tensor(float("nan"), dtype=torch.float64, device=beta.device)
|
|
3361
|
+
|
|
3362
|
+
def _compute_inference_cpu(self, X, time, event, cluster=None):
|
|
3363
|
+
"""Compute standard errors, z-values, p-values, and confidence intervals."""
|
|
3364
|
+
n_features = X.shape[1]
|
|
3365
|
+
|
|
3366
|
+
# Keep inference self-contained (no nested external model fitting),
|
|
3367
|
+
# so runtime reflects this implementation directly.
|
|
3368
|
+
|
|
3369
|
+
# Compute information matrix (negative Hessian at MLE)
|
|
3370
|
+
_, hess = self._compute_gradient_hessian(
|
|
3371
|
+
self.coef_, X, time, event, getattr(self, "_efron_pre", None), entry=getattr(self, "_entry", None)
|
|
3372
|
+
)
|
|
3373
|
+
|
|
3374
|
+
# Bread matrix from observed information.
|
|
3375
|
+
try:
|
|
3376
|
+
bread = np.linalg.solve(-hess, np.eye(n_features))
|
|
3377
|
+
except np.linalg.LinAlgError:
|
|
3378
|
+
bread = np.linalg.pinv(-hess)
|
|
3379
|
+
|
|
3380
|
+
if self.cov_type == "nonrobust":
|
|
3381
|
+
self._var_matrix = bread
|
|
3382
|
+
elif self.cov_type == "cluster":
|
|
3383
|
+
if cluster is None:
|
|
3384
|
+
raise ValueError("cov_type='cluster' requires cluster ids in fit(..., cluster=...)")
|
|
3385
|
+
cluster = np.asarray(cluster)
|
|
3386
|
+
score_resid = self._compute_robust_score_residuals(X, time, event)
|
|
3387
|
+
uniq = np.unique(cluster)
|
|
3388
|
+
meat = np.zeros((n_features, n_features), dtype=np.float64)
|
|
3389
|
+
for g in uniq:
|
|
3390
|
+
u_g = np.sum(score_resid[cluster == g], axis=0)
|
|
3391
|
+
meat += np.outer(u_g, u_g)
|
|
3392
|
+
self._var_matrix = bread @ meat @ bread
|
|
3393
|
+
else:
|
|
3394
|
+
score_resid = self._compute_robust_score_residuals(X, time, event)
|
|
3395
|
+
meat = score_resid.T @ score_resid
|
|
3396
|
+
self._var_matrix = bread @ meat @ bread
|
|
3397
|
+
if self.cov_type == "hc1":
|
|
3398
|
+
n = X.shape[0]
|
|
3399
|
+
k = X.shape[1]
|
|
3400
|
+
if n > k:
|
|
3401
|
+
self._var_matrix = self._var_matrix * (n / (n - k))
|
|
3402
|
+
|
|
3403
|
+
# Standard errors
|
|
3404
|
+
self._bse = np.sqrt(np.maximum(np.diag(self._var_matrix), 0.0))
|
|
3405
|
+
|
|
3406
|
+
# z-values (add epsilon to avoid division by zero)
|
|
3407
|
+
self._zvalues = self.coef_ / (self._bse + 1e-30)
|
|
3408
|
+
|
|
3409
|
+
# p-values (two-sided)
|
|
3410
|
+
self._pvalues = 2 * (1 - stats.norm.cdf(np.abs(self._zvalues)))
|
|
3411
|
+
|
|
3412
|
+
# 95% confidence intervals
|
|
3413
|
+
alpha = 0.05
|
|
3414
|
+
z_crit = stats.norm.ppf(1 - alpha / 2)
|
|
3415
|
+
self._conf_int = np.column_stack([
|
|
3416
|
+
self.coef_ - z_crit * self._bse,
|
|
3417
|
+
self.coef_ + z_crit * self._bse
|
|
3418
|
+
])
|
|
3419
|
+
|
|
3420
|
+
# Wald test (global test that all coefficients are 0)
|
|
3421
|
+
try:
|
|
3422
|
+
var_inv = np.linalg.solve(self._var_matrix, np.eye(n_features))
|
|
3423
|
+
self._wald_test_stat = self.coef_ @ var_inv @ self.coef_
|
|
3424
|
+
except np.linalg.LinAlgError:
|
|
3425
|
+
self._wald_test_stat = np.nan
|
|
3426
|
+
self._wald_test_pvalue = 1 - stats.chi2.cdf(self._wald_test_stat, n_features)
|
|
3427
|
+
|
|
3428
|
+
# Likelihood ratio test
|
|
3429
|
+
self._lr_test_stat = 2 * (self._log_likelihood - self._log_likelihood_null)
|
|
3430
|
+
self._lr_test_pvalue = 1 - stats.chi2.cdf(self._lr_test_stat, n_features)
|
|
3431
|
+
|
|
3432
|
+
# Score test (Rao's test) - computed at beta = 0
|
|
3433
|
+
ep = getattr(self, "_efron_pre", None)
|
|
3434
|
+
grad_0, _ = self._compute_gradient_hessian(np.zeros(n_features), X, time, event, ep, entry=getattr(self, "_entry", None))
|
|
3435
|
+
try:
|
|
3436
|
+
_, hess_0 = self._compute_gradient_hessian(np.zeros(n_features), X, time, event, ep, entry=getattr(self, "_entry", None))
|
|
3437
|
+
info_0 = -hess_0
|
|
3438
|
+
info_0_inv = np.linalg.solve(info_0, np.eye(n_features))
|
|
3439
|
+
self._score_test_stat = grad_0 @ info_0_inv @ grad_0
|
|
3440
|
+
except:
|
|
3441
|
+
self._score_test_stat = np.nan
|
|
3442
|
+
self._score_test_pvalue = 1 - stats.chi2.cdf(self._score_test_stat, n_features)
|
|
3443
|
+
|
|
3444
|
+
def _compute_robust_score_residuals(self, X, time, event):
    """
    Return per-observation score contributions for the sandwich estimator.

    The statsmodels-backed helper is tried first. When it returns ``None``
    (statsmodels could not be imported, or its computation failed), a
    native fallback is used instead:

    * ``self.ties == "breslow"`` -> ``_compute_score_residuals_exact_breslow``
    * any other setting          -> ``_compute_score_residuals_fast``

    Parameters
    ----------
    X, time, event
        Forwarded unchanged to whichever helper is selected.

    Returns
    -------
    The residual matrix produced by the selected helper.
    """
    residuals = self._score_residuals_via_statsmodels_if_available(X, time, event)
    if residuals is None:
        # statsmodels path unavailable: pick the native implementation.
        if self.ties == "breslow":
            fallback = self._compute_score_residuals_exact_breslow
        else:
            fallback = self._compute_score_residuals_fast
        residuals = fallback(X, time, event)
    return residuals
|
|
3465
|
+
|
|
3466
|
+
def _compute_robust_score_residuals_gpu(self, X, time, event):
    """
    CuPy version of the event-row score-residual approximation.

    For rows with ``event == 1`` the residual is the covariate row minus
    the risk-weighted covariate mean over that row and every row after it;
    all other rows stay zero. ``time`` is accepted but not used.

    NOTE(review): taking suffix sums as risk-set sums presumes the rows are
    already ordered by ascending time -- confirm against the caller.

    Returns
    -------
    cupy.ndarray of shape ``(X.shape[0], X.shape[1])`` and dtype float64.
    """
    import cupy as cp

    beta = cp.asarray(self.coef_)
    weights = cp.exp(X @ beta)

    # Reverse-cumsum-reverse yields suffix sums; the tiny constant keeps the
    # denominator strictly positive.
    denom = cp.cumsum(weights[::-1])[::-1] + 1e-30
    weighted_rows = X * weights[:, cp.newaxis]
    numer = cp.cumsum(weighted_rows[::-1], axis=0)[::-1]

    is_event = event == 1
    residuals = cp.zeros((X.shape[0], X.shape[1]), dtype=cp.float64)
    residuals[is_event] = X[is_event] - numer[is_event] / denom[is_event, cp.newaxis]
    return residuals
|
|
3478
|
+
|
|
3479
|
+
def _score_residuals_via_statsmodels_if_available(
    self, X: np.ndarray, time: np.ndarray, event: np.ndarray
):
    """
    Compute score residuals with statsmodels' ``PHReg`` when possible.

    A ``PHReg`` model is built from ``time``, ``X``, ``status=event`` and
    ``ties=self.ties``; its ``score_residuals`` is evaluated at
    ``self.coef_``.

    Returns
    -------
    ndarray of float64 with shape ``(X.shape[0], X.shape[1])``, in which
    NaN and infinite entries have been replaced by 0.0; or ``None`` when
    statsmodels cannot be imported, any step raises, or the residual
    matrix has an unexpected shape.
    """
    try:
        import statsmodels.duration.api as smd
    except Exception:
        return None

    try:
        ph_model = smd.PHReg(time, X, status=event, ties=self.ties)
        residuals = ph_model.score_residuals(self.coef_)
        shape_ok = residuals.shape == (X.shape[0], X.shape[1])
        if not shape_ok:
            return None
        # Non-finite rows (NaN / +-inf) are zeroed so they contribute
        # nothing to the sandwich "meat".
        cleaned = np.nan_to_num(residuals, nan=0.0, posinf=0.0, neginf=0.0)
        return np.asarray(cleaned, dtype=np.float64)
    except Exception:
        return None
|
|
3497
|
+
|
|
3498
|
+
def _compute_score_residuals_fast(self, X, time, event):
    """
    Approximate per-observation score residuals at the fitted coefficients.

    Event rows receive ``x_i`` minus the risk-weighted covariate mean taken
    over row ``i`` and every later row; rows with ``event != 1`` are left
    at zero. ``time`` is accepted but not used.

    NOTE(review): using suffix sums as risk-set sums presumes rows are
    sorted by ascending time -- confirm against the caller.

    Returns
    -------
    ndarray of float64 with shape ``X.shape``.
    """
    n_rows, n_cols = X.shape
    weights = np.exp(X @ self.coef_)

    # Suffix sums via reverse-cumsum-reverse; 1e-30 guards the division.
    denom = np.cumsum(weights[::-1])[::-1] + 1e-30
    weighted_rows = X * weights[:, np.newaxis]
    numer = np.cumsum(weighted_rows[::-1], axis=0)[::-1]

    is_event = event == 1
    residuals = np.zeros((n_rows, n_cols), dtype=np.float64)
    residuals[is_event] = X[is_event] - numer[is_event] / denom[is_event, np.newaxis]
    return residuals
|
|
3516
|
+
|
|
3517
|
+
def _compute_score_residuals_exact_breslow(self, X, time, event):
    """
    Per-observation score residuals for Breslow ties, built from cumsums.

    With ``s_i = x_i - E[X | R(t_i)]`` and ``r_i`` the risk-set sum at row
    ``i``, row ``j`` receives

        u_j = I(event_j) * s_j - exp(eta_j) * sum_{i<=j, event_i=1} s_i / r_i

    ``time`` is accepted but not used.

    NOTE(review): using suffix sums as risk-set sums presumes rows are
    sorted by ascending time -- confirm against the caller.

    Returns
    -------
    ndarray with the same shape as ``X``.
    """
    weights = np.exp(X @ self.coef_)

    # Suffix sums via reverse-cumsum-reverse; 1e-30 guards the divisions.
    denom = np.cumsum(weights[::-1])[::-1] + 1e-30
    numer = np.cumsum((X * weights[:, np.newaxis])[::-1], axis=0)[::-1]
    denom_col = denom[:, np.newaxis]

    delta = (event == 1).astype(np.float64)
    centered = X - (numer / denom_col)

    # Running sum of the event-row terms s_i / r_i up to each row.
    running = np.cumsum((delta[:, np.newaxis] * centered) / denom_col, axis=0)
    return delta[:, np.newaxis] * centered - weights[:, np.newaxis] * running
|
|
3534
|
+
|
|
3535
|
+
def _compute_baseline_hazard(self, X, time, event):
    """
    Breslow estimate of the baseline hazard at each distinct event time.

    For every distinct time ``t`` among rows with ``event == 1`` the hazard
    increment is (number of events at ``t``) divided by the sum of
    ``exp(X @ coef_)`` over rows with ``time >= t``.

    Sets
    ----
    self._unique_times : sorted distinct event times.
    self._baseline_hazard : per-time hazard increments.
    self._baseline_cumulative_hazard : running sum of the increments.

    All three become empty arrays when there are no events.
    """
    is_event = event == 1
    if not np.any(is_event):
        self._unique_times = np.array([])
        self._baseline_hazard = np.array([])
        self._baseline_cumulative_hazard = np.array([])
        return

    event_times = np.unique(time[is_event])
    self._unique_times = event_times

    weights = np.exp(X @ self.coef_)

    # One increment per distinct event time: d(t) / sum_{time >= t} exp(eta).
    increments = np.array(
        [
            np.sum((time == t) & is_event) / np.sum(weights[time >= t])
            for t in event_times
        ],
        dtype=np.float64,
    )

    self._baseline_cumulative_hazard = np.cumsum(increments)
    self._baseline_hazard = increments
|
|
3571
|
+
|
|
3572
|
+
def _compute_baseline_hazard_gpu(self, X, time, event, beta):
    """
    Breslow baseline hazard computed with CuPy, stored as NumPy arrays.

    For every distinct time ``t`` among rows with ``event == 1`` the hazard
    increment is (number of events at ``t``) divided by the sum of
    ``exp(X @ beta)`` over rows with ``time >= t``.

    Parameters
    ----------
    X, time, event : CuPy arrays (covariates, times, event indicators).
    beta : CuPy coefficient vector used for the linear predictor.

    Sets
    ----
    self._unique_times, self._baseline_hazard,
    self._baseline_cumulative_hazard : host (NumPy) arrays in every code
        path, including the no-event case, so later host-side code sees one
        array type regardless of the data.
    """
    import cupy as cp

    event_mask = event == 1
    if not cp.any(event_mask):
        # Previously this branch left device arrays on the instance while
        # the regular path stored host arrays; keep both paths consistent.
        self._unique_times = cp.asnumpy(cp.array([]))
        self._baseline_hazard = cp.asnumpy(cp.array([]))
        self._baseline_cumulative_hazard = cp.asnumpy(cp.array([]))
        return

    unique_times = cp.unique(time[event_mask])
    exp_eta = cp.exp(X @ beta)

    # One pass per distinct event time (each iteration synchronises with
    # the device through int()).
    hazard_increments = cp.zeros(len(unique_times))
    for i, t in enumerate(unique_times):
        d_i = int(cp.sum((time == t) & event_mask))
        risk_sum = cp.sum(exp_eta[time >= t])
        hazard_increments[i] = d_i / risk_sum

    # Transfer to CPU for storage.
    self._unique_times = cp.asnumpy(unique_times)
    self._baseline_hazard = cp.asnumpy(hazard_increments)
    self._baseline_cumulative_hazard = cp.asnumpy(cp.cumsum(hazard_increments))
|
|
3617
|
+
|
|
3618
|
+
def _compute_baseline_hazard_torch(self, X, time, event, beta):
    """
    Breslow baseline hazard computed with torch, stored as NumPy arrays.

    For every distinct time ``t`` among rows with ``event == 1`` the hazard
    increment is (number of events at ``t``) divided by the sum of
    ``exp(X @ beta)`` over rows with ``time >= t``.

    Parameters
    ----------
    X, time, event : torch tensors (covariates, times, event indicators).
    beta : torch coefficient vector; its device is used for temporaries.

    Sets
    ----
    self._unique_times, self._baseline_hazard,
    self._baseline_cumulative_hazard : NumPy arrays in every code path,
        including the no-event case, so later host-side code sees one array
        type regardless of the data.
    """
    import torch

    event_mask = event == 1
    if not torch.any(event_mask):
        # Previously this branch stored torch tensors while the regular
        # path stored NumPy arrays; keep both paths consistent.
        self._unique_times = torch.tensor([], dtype=torch.float64).cpu().numpy()
        self._baseline_hazard = torch.tensor([], dtype=torch.float64).cpu().numpy()
        self._baseline_cumulative_hazard = torch.tensor([], dtype=torch.float64).cpu().numpy()
        return

    unique_times = torch.unique(time[event_mask])
    exp_eta = torch.exp(X @ beta)

    # One pass per distinct event time.
    hazard_increments = torch.zeros(
        len(unique_times), dtype=torch.float64, device=beta.device
    )
    for i, t in enumerate(unique_times):
        d_i = int(torch.sum((time == t) & event_mask))
        risk_sum = torch.sum(exp_eta[time >= t])
        hazard_increments[i] = d_i / risk_sum

    # Transfer to CPU for storage.
    self._unique_times = unique_times.cpu().numpy()
    self._baseline_hazard = hazard_increments.cpu().numpy()
    self._baseline_cumulative_hazard = torch.cumsum(hazard_increments, dim=0).cpu().numpy()
|
|
3662
|
+
|
|
3663
|
+
def _compute_cindex_gpu(self, X, time, event, beta):
    """
    Concordance index computed with CuPy in row chunks.

    Each event row ``i`` is compared with every row ``j``. A pair is
    permissible when ``time_i < time_j``, or when the times are equal and
    ``j`` is not an event; a row is never paired with itself. The result is
    ``(concordant + 0.5 * tied) / permissible``, where a pair is concordant
    if ``i`` has the larger risk score and tied if the scores are equal.

    Returns
    -------
    CuPy float64 scalar: 0.5 when no row has ``event == 1``, NaN when no
    permissible pair exists, otherwise the concordance index.
    """
    import cupy as cp

    scores = X @ beta
    n_obs = len(time)
    is_event = (event == 1)

    if cp.sum(is_event) == 0:
        return cp.array(0.5, dtype=cp.float64)

    # At least one event exists from here on, so event_rows is non-empty.
    event_rows = cp.where(is_event)[0]
    n_events = len(event_rows)

    n_concordant = cp.int64(0)
    n_permissible = cp.int64(0)
    n_tied = cp.int64(0)

    # Cap each (chunk x n_obs) boolean matrix at about 128e6 entries.
    rows_per_chunk = max(1, min(n_events, int(128e6 / max(n_obs, 1))))

    # Row-vector views reused by every chunk.
    all_times = time[None, :]
    all_scores = scores[None, :]
    all_events = event[None, :]

    for lo in range(0, n_events, rows_per_chunk):
        hi = min(lo + rows_per_chunk, n_events)
        rows = event_rows[lo:hi]

        t_i = time[rows][:, None]
        s_i = scores[rows][:, None]

        usable = (t_i < all_times) | ((t_i == all_times) & (all_events == 0))
        # Drop each event row's comparison with itself.
        local = cp.arange(hi - lo, dtype=cp.int64)
        usable[local, rows] = False

        n_concordant += cp.sum(usable & (s_i > all_scores))
        n_tied += cp.sum(usable & (s_i == all_scores))
        n_permissible += cp.sum(usable)

    if n_permissible > 0:
        return (n_concordant.astype(cp.float64) + 0.5 * n_tied.astype(cp.float64)) / n_permissible.astype(cp.float64)
    return cp.array(float("nan"), dtype=cp.float64)
|
|
3714
|
+
|
|
3715
|
+
def _compute_cindex(self):
    """
    Compute the training concordance index and store it in ``self._cindex``.

    Risk scores are ``self._X @ self.coef_``. Each event row ``i`` is
    compared with every row ``j``; the pair is permissible when
    ``time_i < time_j``, or when the times are equal and ``j`` is not an
    event. Rows are never paired with themselves. Work is done in chunks
    of event rows so that each boolean comparison matrix holds at most
    about 128e6 entries.

    ``self._cindex`` becomes:
      * ``None`` if ``self._X`` or ``self.coef_`` is ``None``;
      * NaN if there are no events or no permissible pairs;
      * ``(concordant + 0.5 * tied) / permissible`` otherwise.
    """
    if self._X is None or self.coef_ is None:
        self._cindex = None
        return

    scores = self._X @ self.coef_
    times = self._time
    status = self._event
    n_obs = len(times)

    event_rows = np.where(status == 1)[0]
    n_events = len(event_rows)
    if n_events == 0:
        self._cindex = np.nan
        return

    n_concordant = np.int64(0)
    n_permissible = np.int64(0)
    n_tied = np.int64(0)

    rows_per_chunk = max(1, min(n_events, int(128e6 / max(n_obs, 1))))

    # (1, n) views shared by every chunk.
    all_times = times[np.newaxis, :]
    all_scores = scores[np.newaxis, :]
    all_status = status[np.newaxis, :]

    for lo in range(0, n_events, rows_per_chunk):
        rows = event_rows[lo:min(lo + rows_per_chunk, n_events)]  # (c,)

        t_i = times[rows, np.newaxis]  # (c, 1)
        s_i = scores[rows, np.newaxis]

        usable = (t_i < all_times) | ((t_i == all_times) & (all_status == 0))
        # An observation is never compared with itself.
        usable[np.arange(len(rows)), rows] = False

        n_concordant += int(np.sum(usable & (s_i > all_scores)))
        n_tied += int(np.sum(usable & (s_i == all_scores)))
        n_permissible += int(np.sum(usable))

    if n_permissible > 0:
        self._cindex = (n_concordant + 0.5 * n_tied) / n_permissible
    else:
        self._cindex = np.nan
|
|
3768
|
+
|
|
3769
|
+
def summary(self):
    """
    Print a text report of the fitted model to stdout.

    The report has a header (ties method, sample size, number of events,
    covariance type), a coefficient table, the concordance index, the
    global tests and the optimiser status. Standard errors, z, p-values and
    the hazard-ratio confidence limits are shown only when
    ``compute_inference`` is true and standard errors exist; otherwise only
    coefficients and hazard ratios are printed.

    Raises
    ------
    RuntimeError
        If the model has not been fitted.
    """
    if not self._fitted:
        raise RuntimeError("Model has not been fitted yet.")

    heavy_rule = "=" * 80
    light_rule = "-" * 80

    print(heavy_rule)
    print(" Cox Proportional Hazards Model")
    print(heavy_rule)
    print("Call:")
    print(f" coxph(formula = Surv(time, event) ~ ., ties = '{self.ties}')")
    print()
    print(f" n= {self._nobs}, number of events= {int(self._nevents)}")
    print(f" covariance type= {self.cov_type}")
    print()

    show_inference = self.compute_inference and self._bse is not None
    if show_inference:
        # Table 1: coefficient, hazard ratio, standard error, z and p-value.
        print(f"{'':<15} {'coef':>10} {'exp(coef)':>12} {'se(coef)':>10} {'z':>10} {'Pr(>|z|)':>10}")
        print(light_rule)
        for i, name in enumerate(self._feature_names):
            left = f"{name:<15} {self.coef_[i]:>10.4f} {self.hazard_ratios_[i]:>12.4f} "
            right = f"{self._bse[i]:>10.4f} {self._zvalues[i]:>10.3f} {self._pvalues[i]:>10.4f}"
            print(left + right)

        # Table 2: hazard ratio, its reciprocal and exponentiated CI limits.
        print(light_rule)
        print(f"{'':<15} {'exp(coef)':>12} {'exp(-coef)':>12} {'lower .95':>12} {'upper .95':>12}")
        print(light_rule)
        for i, name in enumerate(self._feature_names):
            hr = self.hazard_ratios_[i]
            left = f"{name:<15} {hr:>12.4f} {1/hr:>12.4f} "
            right = f"{np.exp(self._conf_int[i, 0]):>12.4f} {np.exp(self._conf_int[i, 1]):>12.4f}"
            print(left + right)
    else:
        print(f"{'':<15} {'coef':>10} {'exp(coef)':>12}")
        print(light_rule)
        for i, name in enumerate(self._feature_names):
            print(f"{name:<15} {self.coef_[i]:>10.4f} {self.hazard_ratios_[i]:>12.4f}")
        print(light_rule)
        print("Inference statistics disabled (compute_inference=False).")

    print(heavy_rule)
    if self._cindex is not None:
        print(f"Concordance: {self._cindex:.3f} (if 0.5-0.7: moderate, 0.7-0.9: strong)")
    else:
        print("Concordance: skipped (compute_cindex=False)")

    if self.compute_inference and self._lr_test_stat is not None:
        print(f"Likelihood ratio test: {self._lr_test_stat:.2f} on {len(self.coef_)} df, p={self._lr_test_pvalue:.4e}")
        print(f"Wald test: {self._wald_test_stat:.2f} on {len(self.coef_)} df, p={self._wald_test_pvalue:.4e}")
        print(f"Score (logrank) test: {self._score_test_stat:.2f} on {len(self.coef_)} df, p={self._score_test_pvalue:.4e}")
    else:
        print("Likelihood/Wald/Score tests skipped (compute_inference=False).")
    print(f"Number of Newton-Raphson iterations: {self._iterations}")
    print(f"Converged: {self._converged}")
    print(heavy_rule)
|
|
3821
|
+
|
|
3822
|
+
def predict_hazard_ratio(self, X):
    """
    Predict hazard ratios, i.e. ``exp(X @ coef_)``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix. A 1-D input is reshaped into a single column.

    Returns
    -------
    hazard_ratios : ndarray of shape (n_samples,)
        Exponentiated linear predictor for each row of ``X``.
    """
    self._check_is_fitted()
    design = np.asarray(X, dtype=np.float64)
    if design.ndim == 1:
        design = design.reshape(-1, 1)
    linear_predictor = design @ self.coef_
    return np.exp(linear_predictor)
|
|
3841
|
+
|
|
3842
|
+
def predict_risk_score(self, X):
    """
    Predict risk scores, i.e. the linear predictor ``X @ coef_``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix. A 1-D input is reshaped into a single column.

    Returns
    -------
    risk_scores : ndarray of shape (n_samples,)
        Linear predictor for each row of ``X``.
    """
    self._check_is_fitted()
    design = np.asarray(X, dtype=np.float64)
    if design.ndim == 1:
        design = design.reshape(-1, 1)
    risk_scores = design @ self.coef_
    return risk_scores
|
|
3861
|
+
|
|
3862
|
+
def predict_survival(self, X, times=None):
    """
    Predict the survival function S(t|X) = exp(-H0(t) * exp(X @ coef)).

    The baseline cumulative hazard H0 is treated as a right-continuous step
    function that jumps at the stored unique event times: H0(t) is the
    stored value at the last event time <= t, and 0 before the first one.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix. A 1-D input is reshaped into a single column.
    times : array-like, optional
        Times at which to evaluate the survival function. A scalar is
        treated as a single time. If None, the unique event times stored on
        the model are used.

    Returns
    -------
    survival : ndarray of shape (n_samples, n_times)
        Predicted survival probabilities. All ones when no times are
        requested or no baseline hazard is stored.
    times : ndarray
        Times at which survival is evaluated.
    """
    self._check_is_fitted()
    X = np.asarray(X, dtype=np.float64)
    if X.ndim == 1:
        X = X.reshape(-1, 1)

    event_times = self._unique_times
    cum_hazard = self._baseline_cumulative_hazard

    if times is None:
        on_training_grid = True
        times = np.array([]) if event_times is None else event_times
    else:
        on_training_grid = False
        # reshape(-1) also turns a scalar into a length-1 vector.
        times = np.asarray(times).reshape(-1)

    if len(times) == 0 or cum_hazard is None or event_times is None:
        return np.ones((X.shape[0], len(times))), times

    cum_hazard = np.asarray(cum_hazard, dtype=np.float64)
    if on_training_grid:
        h0 = cum_hazard
    else:
        # Earlier versions ignored `times` here and returned values on the
        # training grid. Look up the step function instead: `pos` counts the
        # event times <= t, and the leading 0.0 covers t before any event
        # (and an empty baseline).
        pos = np.searchsorted(np.asarray(event_times), times, side="right")
        h0 = np.concatenate(([0.0], cum_hazard))[pos]

    # Hazard ratios
    hr = np.exp(X @ self.coef_)

    # Survival function: S(t) = exp(-H0(t) * HR)
    survival = np.exp(-h0[np.newaxis, :] * hr[:, np.newaxis])

    return survival, times
|
|
3901
|
+
|
|
3902
|
+
def predict(self, X):
    """Return predicted hazard ratios; delegates to ``predict_hazard_ratio``."""
    hazard_ratios = self.predict_hazard_ratio(X)
    return hazard_ratios
|
|
3905
|
+
|
|
3906
|
+
def score(self, X, time, event):
    """
    Compute the concordance index on the supplied data.

    Risk scores come from ``predict_risk_score(X)``. Each event row ``i``
    is compared with every row ``j``; the pair is permissible when
    ``time_i < time_j``, or when the times are equal and ``j`` is not an
    event. Rows are never paired with themselves.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Test covariates.
    time : array-like of shape (n_samples,)
        Test event/censoring times.
    event : array-like of shape (n_samples,)
        Test event indicators.

    Returns
    -------
    cindex : float
        ``(concordant + 0.5 * tied) / permissible``; 0.5 when no row has
        ``event == 1``; NaN when no permissible pair exists.
    """
    self._check_is_fitted()

    scores = self.predict_risk_score(X)
    time = np.asarray(time)
    event = np.asarray(event)
    n_obs = len(time)

    event_rows = np.where(event == 1)[0]
    n_events = len(event_rows)
    if n_events == 0:
        return 0.5

    n_concordant = np.int64(0)
    n_permissible = np.int64(0)
    n_tied = np.int64(0)

    # Keep each (chunk x n_obs) boolean matrix at about 128e6 entries or fewer.
    rows_per_chunk = max(1, min(n_events, int(128e6 / max(n_obs, 1))))

    # (1, n) views shared by every chunk.
    all_times = time[np.newaxis, :]
    all_scores = scores[np.newaxis, :]
    all_events = event[np.newaxis, :]

    for lo in range(0, n_events, rows_per_chunk):
        rows = event_rows[lo:min(lo + rows_per_chunk, n_events)]

        t_i = time[rows, np.newaxis]
        s_i = scores[rows, np.newaxis]

        usable = (t_i < all_times) | ((t_i == all_times) & (all_events == 0))
        # An observation is never compared with itself.
        usable[np.arange(len(rows), dtype=np.int64), rows] = False

        n_concordant += int(np.sum(usable & (s_i > all_scores)))
        n_tied += int(np.sum(usable & (s_i == all_scores)))
        n_permissible += int(np.sum(usable))

    if n_permissible > 0:
        return (n_concordant + 0.5 * n_tied) / n_permissible
    return np.nan
|