statgpu 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. statgpu/__init__.py +174 -0
  2. statgpu/_base.py +544 -0
  3. statgpu/_config.py +127 -0
  4. statgpu/anova/__init__.py +5 -0
  5. statgpu/anova/_oneway.py +194 -0
  6. statgpu/backends/__init__.py +83 -0
  7. statgpu/backends/_array_ops.py +529 -0
  8. statgpu/backends/_base.py +184 -0
  9. statgpu/backends/_cupy.py +453 -0
  10. statgpu/backends/_factory.py +65 -0
  11. statgpu/backends/_gpu_inference_cupy.py +214 -0
  12. statgpu/backends/_gpu_inference_torch.py +422 -0
  13. statgpu/backends/_numpy.py +324 -0
  14. statgpu/backends/_torch.py +685 -0
  15. statgpu/backends/_torch_safe.py +47 -0
  16. statgpu/backends/_utils.py +423 -0
  17. statgpu/core/__init__.py +10 -0
  18. statgpu/core/formula/__init__.py +33 -0
  19. statgpu/core/formula/_design.py +99 -0
  20. statgpu/core/formula/_parser.py +191 -0
  21. statgpu/core/formula/_terms.py +70 -0
  22. statgpu/core/formula/tests/__init__.py +0 -0
  23. statgpu/core/formula/tests/test_parser.py +194 -0
  24. statgpu/covariance/__init__.py +6 -0
  25. statgpu/covariance/_empirical.py +310 -0
  26. statgpu/covariance/_shrinkage.py +248 -0
  27. statgpu/cross_validation/__init__.py +31 -0
  28. statgpu/cross_validation/_base.py +410 -0
  29. statgpu/cross_validation/_engine.py +167 -0
  30. statgpu/diagnostics/__init__.py +7 -0
  31. statgpu/diagnostics/_regression_diagnostics.py +188 -0
  32. statgpu/feature_selection/__init__.py +24 -0
  33. statgpu/feature_selection/_knockoff.py +870 -0
  34. statgpu/feature_selection/_knockoff_utils.py +1003 -0
  35. statgpu/feature_selection/_stepwise.py +300 -0
  36. statgpu/glm_core/__init__.py +81 -0
  37. statgpu/glm_core/_base.py +202 -0
  38. statgpu/glm_core/_family.py +362 -0
  39. statgpu/glm_core/_fused.py +149 -0
  40. statgpu/glm_core/_gamma.py +111 -0
  41. statgpu/glm_core/_inverse_gaussian.py +62 -0
  42. statgpu/glm_core/_irls.py +561 -0
  43. statgpu/glm_core/_logistic.py +82 -0
  44. statgpu/glm_core/_negative_binomial.py +68 -0
  45. statgpu/glm_core/_poisson.py +60 -0
  46. statgpu/glm_core/_solver_legacy.py +100 -0
  47. statgpu/glm_core/_squared.py +53 -0
  48. statgpu/glm_core/_tweedie.py +74 -0
  49. statgpu/inference/__init__.py +239 -0
  50. statgpu/inference/_distributions_backend.py +2610 -0
  51. statgpu/inference/_multiple_testing.py +391 -0
  52. statgpu/inference/_resampling.py +1400 -0
  53. statgpu/inference/_results.py +265 -0
  54. statgpu/linear_model/__init__.py +75 -0
  55. statgpu/linear_model/_gaussian_inference.py +306 -0
  56. statgpu/linear_model/_glm_base.py +1261 -0
  57. statgpu/linear_model/_ordered_logit.py +52 -0
  58. statgpu/linear_model/_ordered_probit.py +50 -0
  59. statgpu/linear_model/_stats.py +170 -0
  60. statgpu/linear_model/cv/__init__.py +13 -0
  61. statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
  62. statgpu/linear_model/cv/_lasso_cv.py +253 -0
  63. statgpu/linear_model/cv/_logistic_cv.py +895 -0
  64. statgpu/linear_model/cv/_ridge_cv.py +1160 -0
  65. statgpu/linear_model/legacy/__init__.py +1 -0
  66. statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
  67. statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
  68. statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
  69. statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
  70. statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
  71. statgpu/linear_model/legacy/_solver_legacy.py +104 -0
  72. statgpu/linear_model/penalized/__init__.py +25 -0
  73. statgpu/linear_model/penalized/_base.py +437 -0
  74. statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
  75. statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
  76. statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
  77. statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
  78. statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
  79. statgpu/linear_model/penalized/_penalized_linear.py +236 -0
  80. statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
  81. statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
  82. statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
  83. statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
  84. statgpu/linear_model/penalized/_predict_mixin.py +182 -0
  85. statgpu/linear_model/wrappers/__init__.py +31 -0
  86. statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
  87. statgpu/linear_model/wrappers/_elasticnet.py +75 -0
  88. statgpu/linear_model/wrappers/_gamma.py +67 -0
  89. statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
  90. statgpu/linear_model/wrappers/_lasso.py +2124 -0
  91. statgpu/linear_model/wrappers/_linear.py +1127 -0
  92. statgpu/linear_model/wrappers/_logistic.py +1435 -0
  93. statgpu/linear_model/wrappers/_mcp.py +58 -0
  94. statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
  95. statgpu/linear_model/wrappers/_poisson.py +48 -0
  96. statgpu/linear_model/wrappers/_ridge.py +166 -0
  97. statgpu/linear_model/wrappers/_scad.py +58 -0
  98. statgpu/linear_model/wrappers/_tweedie.py +57 -0
  99. statgpu/metrics/__init__.py +21 -0
  100. statgpu/metrics/_classification.py +591 -0
  101. statgpu/nonparametric/__init__.py +50 -0
  102. statgpu/nonparametric/kernel_methods/__init__.py +25 -0
  103. statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
  104. statgpu/nonparametric/kernel_methods/_krr.py +234 -0
  105. statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
  106. statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
  107. statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
  108. statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
  109. statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
  110. statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
  111. statgpu/nonparametric/splines/__init__.py +5 -0
  112. statgpu/nonparametric/splines/_bspline_basis.py +336 -0
  113. statgpu/nonparametric/splines/_penalized.py +349 -0
  114. statgpu/panel/__init__.py +19 -0
  115. statgpu/panel/_covariance.py +140 -0
  116. statgpu/panel/_fixed_effects.py +420 -0
  117. statgpu/panel/_random_effects.py +385 -0
  118. statgpu/panel/_utils.py +482 -0
  119. statgpu/penalties/__init__.py +139 -0
  120. statgpu/penalties/_adaptive_l1.py +313 -0
  121. statgpu/penalties/_base.py +261 -0
  122. statgpu/penalties/_categories.py +39 -0
  123. statgpu/penalties/_elasticnet.py +98 -0
  124. statgpu/penalties/_group_lasso.py +678 -0
  125. statgpu/penalties/_group_mcp.py +553 -0
  126. statgpu/penalties/_group_scad.py +605 -0
  127. statgpu/penalties/_l1.py +107 -0
  128. statgpu/penalties/_l2.py +77 -0
  129. statgpu/penalties/_mcp.py +237 -0
  130. statgpu/penalties/_scad.py +260 -0
  131. statgpu/semiparametric/__init__.py +5 -0
  132. statgpu/semiparametric/_gam.py +401 -0
  133. statgpu/solvers/__init__.py +24 -0
  134. statgpu/solvers/_admm.py +241 -0
  135. statgpu/solvers/_constants.py +15 -0
  136. statgpu/solvers/_convergence.py +6 -0
  137. statgpu/solvers/_fista.py +436 -0
  138. statgpu/solvers/_fista_bb.py +513 -0
  139. statgpu/solvers/_fista_lla.py +541 -0
  140. statgpu/solvers/_lbfgs.py +206 -0
  141. statgpu/solvers/_newton.py +149 -0
  142. statgpu/solvers/_utils.py +277 -0
  143. statgpu/survival/__init__.py +14 -0
  144. statgpu/survival/_cox.py +3974 -0
  145. statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
  146. statgpu/survival/_cox_cv.py +1159 -0
  147. statgpu/survival/_cox_efron_cuda.py +1280 -0
  148. statgpu/survival/_cox_efron_triton.py +359 -0
  149. statgpu/unsupervised/__init__.py +29 -0
  150. statgpu/unsupervised/_agglomerative.py +307 -0
  151. statgpu/unsupervised/_dbscan.py +263 -0
  152. statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
  153. statgpu/unsupervised/_gmm.py +332 -0
  154. statgpu/unsupervised/_incremental_pca.py +176 -0
  155. statgpu/unsupervised/_kmeans.py +261 -0
  156. statgpu/unsupervised/_minibatch_kmeans.py +299 -0
  157. statgpu/unsupervised/_minibatch_nmf.py +252 -0
  158. statgpu/unsupervised/_nmf.py +190 -0
  159. statgpu/unsupervised/_pca.py +189 -0
  160. statgpu/unsupervised/_truncated_svd.py +132 -0
  161. statgpu/unsupervised/_tsne.py +192 -0
  162. statgpu/unsupervised/_umap.py +224 -0
  163. statgpu/unsupervised/_utils.py +134 -0
  164. statgpu-0.1.0.dist-info/METADATA +245 -0
  165. statgpu-0.1.0.dist-info/RECORD +168 -0
  166. statgpu-0.1.0.dist-info/WHEEL +5 -0
  167. statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
  168. statgpu-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1159 @@
1
+ """
2
+ CoxPHCV: Cross-validated Cox Proportional Hazards regression.
3
+
4
+ Implements K-fold cross-validation to select the optimal penalty (L2 regularization)
5
+ parameter for Cox PH models.
6
+ """
7
+
8
+ from typing import Optional, Union, Tuple, Dict, Any, List
9
+ from collections import OrderedDict
10
+ import hashlib
11
+ import os
12
+ import numpy as np
13
+
14
+ from statgpu._config import Device
15
+ from statgpu.backends import _get_torch_device_str
16
+ from statgpu.cross_validation._base import CVEstimatorBase
17
+ from statgpu.survival._cox import CoxPH
18
+
19
+
20
+ # =============================================================================
21
+ # CV Cache
22
+ # =============================================================================
23
+
24
# Maximum number of CV results retained by the module-level cache below.
_COXPH_CV_CACHE_MAXSIZE = 64
# Ordered mapping from cache key to a CV result dict; the ordering is what the
# cache helpers in this module use to track recency of use.
_COXPH_CV_CACHE: "OrderedDict[str, Dict[str, Any]]" = OrderedDict()
26
+
27
+
28
def _env_flag(name: str, default: bool = False) -> bool:
    """Read environment variable ``name`` as an on/off switch.

    An unset variable yields ``bool(default)``. Otherwise the value is
    stripped, lower-cased, and treated as true only when it is one of
    ``"1"``, ``"true"``, ``"yes"`` or ``"on"``; anything else is false.
    """
    truthy = ("1", "true", "yes", "on")
    value = os.environ.get(name)
    return bool(default) if value is None else str(value).strip().lower() in truthy
34
+
35
+
36
def _env_int(
    name: str,
    default: int,
    *,
    min_value: Optional[int] = None,
    max_value: Optional[int] = None,
) -> int:
    """Read environment variable ``name`` as an integer, clamped to bounds.

    The variable is parsed with ``int``; when it is unset or cannot be
    parsed, ``int(default)`` is used instead. The lower bound is applied
    first and the upper bound second, so the upper bound wins if the two
    conflict.
    """
    text = os.environ.get(name)
    try:
        parsed = int(default if text is None else text)
    except (TypeError, ValueError):
        parsed = int(default)
    bounded = parsed if min_value is None else max(min_value, parsed)
    return bounded if max_value is None else min(max_value, bounded)
54
+
55
+
56
def _env_float(name: str, default: float, *, min_value: Optional[float] = None) -> float:
    """Read environment variable ``name`` as a float with an optional floor.

    The variable is parsed with ``float``; when it is unset or cannot be
    parsed, ``float(default)`` is used instead. If ``min_value`` is given the
    result is raised to at least that value.
    """
    text = os.environ.get(name)
    try:
        parsed = float(default if text is None else text)
    except (TypeError, ValueError):
        parsed = float(default)
    return parsed if min_value is None else max(min_value, parsed)
66
+
67
+
68
def _hash_optional_array(h: "hashlib._blake2.blake2b", tag: str, arr: Optional[np.ndarray]) -> None:
    """Mix an optional array into hasher ``h`` under the label ``tag``.

    A missing array contributes the marker ``"<tag>:none"``. A present array
    contributes, in order: the tag, its shape (as int64 bytes), its dtype
    name, and its raw contents in C-contiguous layout.
    """
    if arr is not None:
        data = np.asarray(arr)
        pieces = (
            tag.encode("utf-8"),
            np.asarray(data.shape, dtype=np.int64).tobytes(),
            str(data.dtype).encode("utf-8"),
            np.ascontiguousarray(data).tobytes(),
        )
        for piece in pieces:
            h.update(piece)
        return
    h.update(f"{tag}:none".encode("utf-8"))
78
+
79
+
80
def _coxcv_cache_get(cache_key: Optional[str]) -> Optional[Dict[str, Any]]:
    """Look up a stored CoxPH CV result.

    Returns ``None`` when ``cache_key`` is ``None`` or nothing is stored
    under it. On a hit the entry is moved to the end of the ordered cache
    (marking it most recently used) and the stored dict itself is returned,
    not a copy.
    """
    if cache_key is None:
        return None
    hit = _COXPH_CV_CACHE.get(cache_key)
    if hit is None:
        return hit
    _COXPH_CV_CACHE.move_to_end(cache_key)
    return hit
88
+
89
+
90
def _coxcv_cache_put(cache_key: Optional[str], value: Dict[str, Any]) -> None:
    """Store a CoxPH CV result and trim the cache to its size limit.

    A ``None`` key is a no-op. Otherwise ``value`` is stored as the most
    recently used entry, and entries are dropped from the front of the
    ordered cache until at most ``_COXPH_CV_CACHE_MAXSIZE`` remain.
    """
    if cache_key is not None:
        _COXPH_CV_CACHE[cache_key] = value
        _COXPH_CV_CACHE.move_to_end(cache_key)
        excess = len(_COXPH_CV_CACHE) - _COXPH_CV_CACHE_MAXSIZE
        for _ in range(excess):
            # Front of the ordered dict is the least recently used entry.
            _COXPH_CV_CACHE.popitem(last=False)
98
+
99
+
100
def _sample_hash(h, arr, max_rows=50):
    """Mix a bounded sample of ``arr`` into hasher ``h`` for cache keys.

    The array is converted to float64 and flattened, so ``max_rows`` counts
    flattened elements. The element count is always hashed first. Arrays
    with at most ``max_rows`` elements are hashed in full; longer arrays
    contribute three contiguous windows taken from the start, the middle and
    the end (about ``max_rows // 3`` elements each, with the remainder going
    to the tail window).

    Parameters
    ----------
    h : hash object
        Any object exposing ``update(bytes)``.
    arr : array-like
        Numeric data to sample.
    max_rows : int
        Upper bound on the number of sampled elements.

    Notes
    -----
    Sampling is a speed/uniqueness trade-off: arrays that differ only
    outside the sampled windows still produce the same digest.
    """
    flat = np.asarray(arr, dtype=np.float64).ravel()
    n = flat.shape[0]
    # Length prefix so that consecutive arrays cannot blur into each other.
    h.update(np.int64(n).tobytes())
    if n <= max_rows:
        h.update(flat.tobytes())
        return

    budget = max(int(max_rows), 0)
    window = budget // 3
    tail_len = budget - 2 * window
    mid_start = (n - window) // 2
    indices = np.concatenate(
        [
            np.arange(window),
            np.arange(mid_start, mid_start + window),
            np.arange(n - tail_len, n),
        ]
    )
    h.update(flat[indices].tobytes())
110
+
111
+
112
def _make_coxph_cv_auto_cache_key(
    X_shape: Tuple[int, ...],
    time_shape: Tuple[int, ...],
    event_shape: Tuple[int, ...],
    penalties: Optional[np.ndarray],
    n_penalties: int,
    penalty_min_ratio: float,
    folds: List[Tuple[np.ndarray, np.ndarray]],
    ties: str,
    use_gpu: bool,
    fit_device: str,
    cv_cuda_torch_bridge: bool,
    entry: Optional[np.ndarray],
    cluster: Optional[np.ndarray],
    two_stage_enabled: bool,
    halving_enabled: bool,
    coarse_n: int,
    window: int,
    halving_topk: int,
    fast_iter: int,
    fast_tol: float,
    max_iter: int,
    tol: float,
    X_data=None,
    time_data=None,
    event_data=None,
) -> str:
    """
    Generate automatic cache key for CoxPH CV.

    The key is a 32-byte BLAKE2b hex digest over: the input shapes, a
    bounded sample of ``X_data`` / ``time_data`` / ``event_data`` when they
    are supplied, the penalty grid and its generation settings, the exact
    fold index arrays, the execution-path settings (device, bridge,
    two-stage, halving, iteration limits and tolerances), and the optional
    delayed-entry and cluster arrays.

    Every scalar setting is written as a labelled ``|name=value`` field so
    adjacent values cannot run together (``6`` followed by ``12`` must not
    hash like ``61`` followed by ``2``). Fold indices are hashed from their
    raw bytes rather than from ``str(folds)``, because NumPy abbreviates the
    text form of long arrays and distinct splits could then share a key.

    Returns
    -------
    key : str
        Hexadecimal digest.
    """
    h = hashlib.blake2b(digest_size=32)

    def _mix(label, value) -> None:
        # Labelled and delimited so that neighbouring fields stay distinct.
        h.update(f"|{label}={value}".encode("utf-8"))

    _mix("X_shape", tuple(X_shape))
    _mix("time_shape", tuple(time_shape))
    _mix("event_shape", tuple(event_shape))

    # Include sampled data content to avoid collisions across datasets with same shape
    for label, data in (("X_data", X_data), ("time_data", time_data), ("event_data", event_data)):
        if data is None:
            _mix(label, None)
        else:
            h.update(f"|{label}=".encode("utf-8"))
            _sample_hash(h, data, max_rows=50)

    if penalties is None:
        _mix("penalties", None)
    else:
        grid = np.asarray(penalties, dtype=np.float64)
        _mix("penalties", grid.shape)
        h.update(grid.tobytes())
    _mix("n_penalties", n_penalties)
    _mix("penalty_min_ratio", penalty_min_ratio)

    # Hash the full index arrays: the string form of a long array is elided.
    for fold_no, fold in enumerate(folds):
        _hash_optional_array(h, f"|fold{fold_no}:train", fold[0])
        _hash_optional_array(h, f"|fold{fold_no}:test", fold[1])

    _mix("ties", ties)
    _mix("use_gpu", use_gpu)
    _mix("fit_device", fit_device)
    _mix("cv_cuda_torch_bridge", cv_cuda_torch_bridge)
    _hash_optional_array(h, "|entry", entry)
    _hash_optional_array(h, "|cluster", cluster)
    _mix("two_stage_enabled", two_stage_enabled)
    _mix("halving_enabled", halving_enabled)
    _mix("coarse_n", coarse_n)
    _mix("window", window)
    _mix("halving_topk", halving_topk)
    _mix("fast_iter", fast_iter)
    _mix("fast_tol", fast_tol)
    _mix("max_iter", max_iter)
    _mix("tol", tol)
    return h.hexdigest()
178
+
179
+
180
+ # =============================================================================
181
+ # K-fold helpers
182
+ # =============================================================================
183
+
184
def _kfold_indices(n_samples: int, n_splits: int, random_state: Optional[int] = None):
    """Build shuffled K-fold splits.

    The indices ``0..n_samples-1`` are shuffled with a
    ``numpy.random.RandomState`` seeded by ``random_state`` and cut into
    ``n_splits`` consecutive test blocks; the first ``n_samples % n_splits``
    blocks hold one extra index.

    Returns
    -------
    folds : list of (train_idx, test_idx)
        One pair of index arrays per fold, where ``train_idx`` is every
        shuffled index outside that fold's test block.
    """
    shuffled = np.arange(n_samples)
    np.random.RandomState(random_state).shuffle(shuffled)

    sizes = np.full(n_splits, n_samples // n_splits, dtype=np.int64)
    sizes[: n_samples % n_splits] += 1
    stops = np.cumsum(sizes)
    starts = stops - sizes

    return [
        (np.concatenate([shuffled[:lo], shuffled[hi:]]), shuffled[lo:hi])
        for lo, hi in zip(starts, stops)
    ]
200
+
201
+
202
def _folds_are_complements(folds, n_samples: int) -> bool:
    """Tell whether the test parts of ``folds`` tile ``0..n_samples-1`` exactly.

    The second element of every fold is concatenated; the result must have
    exactly ``n_samples`` entries and, once sorted, equal
    ``arange(n_samples)`` — i.e. each sample is held out exactly once.
    """
    held_out = np.concatenate([fold[1] for fold in folds])
    if len(held_out) == n_samples:
        return np.array_equal(np.sort(held_out), np.arange(n_samples))
    return False
208
+
209
+
210
+ # =============================================================================
211
+ # Penalty grid generation
212
+ # =============================================================================
213
+
214
def _default_coxph_penalty_grid(
    X: np.ndarray,
    time: np.ndarray,
    event: np.ndarray,
    n_penalties: int = 100,
    penalty_min_ratio: float = 1e-3,
) -> np.ndarray:
    """
    Generate default penalty grid for CoxPHCV.

    Penalty values are log-spaced. The largest value is
    ``max(feature variance) * n_events * 0.1`` with a floor of 1.0, and the
    smallest is ``penalty_min_ratio`` times the largest. The grid is returned
    from largest to smallest.

    Parameters
    ----------
    X : ndarray
        Design matrix (n_samples, n_features).
    time : ndarray
        Survival times. Not used by the heuristic; accepted for interface
        symmetry with the other CV helpers.
    event : ndarray
        Event indicators (1=event, 0=censored).
    n_penalties : int
        Number of penalty values.
    penalty_min_ratio : float
        Minimum penalty as ratio of max penalty.

    Returns
    -------
    penalties : ndarray
        Log-spaced penalty values (float64). When there are no events a
        fixed ascending grid from 1e-3 to 1 is returned instead.

    Notes
    -----
    Non-finite column variances (e.g. a column containing NaN) are ignored
    when choosing the maximum, and a design matrix without any usable column
    falls back to a maximum of 1.0, so the grid is always finite.
    """
    n_events = int(np.sum(event))
    if n_events == 0:
        # No events - return simple grid
        return np.geomspace(1e-3, 1, n_penalties)

    # Estimate penalty_max from data variance: larger variance -> larger
    # potential penalty. Only finite variances take part, because Python's
    # max(nan, 1.0) returns nan and would poison the whole grid.
    feature_var = np.atleast_1d(np.var(np.asarray(X, dtype=np.float64), axis=0))
    usable_var = feature_var[np.isfinite(feature_var)]
    penalty_max = float(usable_var.max()) * n_events * 0.1 if usable_var.size else 1.0

    # Ensure penalty_max is positive, finite and reasonable
    if not np.isfinite(penalty_max) or penalty_max < 1.0:
        penalty_max = 1.0
    penalty_min = penalty_min_ratio * penalty_max

    penalties = np.geomspace(penalty_max, penalty_min, n_penalties)
    return penalties.astype(np.float64)
262
+
263
+
264
+ # =============================================================================
265
+ # Partial likelihood computation for CV evaluation
266
+ # =============================================================================
267
+
268
def _log_sum_exp(values: np.ndarray) -> float:
    """Return ``log(sum(exp(values)))`` without overflow; ``-inf`` if empty."""
    if values.size == 0:
        return float("-inf")
    peak = float(np.max(values))
    if not np.isfinite(peak):
        # All -inf, or contains +inf / nan: the peak is already the answer.
        return peak
    return peak + float(np.log(np.sum(np.exp(values - peak))))


def _compute_partial_likelihood(
    X: np.ndarray,
    time: np.ndarray,
    event: np.ndarray,
    coef: np.ndarray,
    entry: Optional[np.ndarray] = None,
    ties: str = 'breslow',
) -> float:
    """
    Compute log partial likelihood for given coefficients.

    This is used for CV evaluation on held-out test folds.

    All risk-set sums are accumulated in log space, so large linear
    predictors do not overflow. The risk set at failure time ``t`` is
    ``{j : time_j >= t}``, additionally restricted to ``entry_j <= t`` when
    ``entry`` is given. Every event tied at ``t`` shares that same risk set,
    so passing ``entry=None`` and passing an all-zero ``entry`` (with
    positive times) give the same value, and the result does not depend on
    how tied rows happen to be ordered.

    Parameters
    ----------
    X : ndarray
        Design matrix (n_samples, n_features).
    time : ndarray
        Survival times.
    event : ndarray
        Event indicators; rows equal to 1 are treated as events.
    coef : ndarray or None
        Coefficient values. ``None`` is evaluated as the null model
        (all-zero linear predictor).
    entry : ndarray or None
        Delayed-entry times (left truncation). If None, every sample is at
        risk from the start.
    ties : str
        'breslow' or 'efron'.

    Returns
    -------
    log_pl : float
        Log partial likelihood value. ``0.0`` is returned when there are no
        events.
    """
    time_arr = np.asarray(time, dtype=np.float64).ravel()
    event_arr = np.asarray(event).ravel()
    n = time_arr.shape[0]

    if coef is None:
        eta = np.zeros(n, dtype=np.float64)
    else:
        eta = (
            np.asarray(X, dtype=np.float64) @ np.asarray(coef, dtype=np.float64)
        ).reshape(-1)

    # Stable sort keeps the result reproducible when times are tied.
    order = np.argsort(time_arr, kind="stable")
    time_sorted = time_arr[order]
    eta_sorted = eta[order]
    event_pos = np.flatnonzero(event_arr[order] == 1)

    # NOTE(review): an unrecognised ``ties`` value scores 0.0, as it did for
    # non-zero coefficients before; validating it upstream would be safer.
    if event_pos.size == 0 or ties not in ("breslow", "efron"):
        return 0.0

    eta_events = eta_sorted[event_pos]
    fail_times, group_of, tie_counts = np.unique(
        time_sorted[event_pos], return_inverse=True, return_counts=True
    )
    group_of = np.asarray(group_of).reshape(-1)

    # log of the risk-set sum at each distinct failure time
    if entry is None:
        # Suffix log-sum-exp over the time-sorted scores, read at the first
        # row carrying each failure time so ties share one risk set.
        log_suffix = np.logaddexp.accumulate(eta_sorted[::-1])[::-1]
        log_risk = log_suffix[np.searchsorted(time_sorted, fail_times, side="left")]
    else:
        entry_sorted = np.asarray(entry, dtype=np.float64).ravel()[order]
        log_risk = np.array(
            [
                _log_sum_exp(eta_sorted[(entry_sorted <= t) & (time_sorted >= t)])
                for t in fail_times
            ],
            dtype=np.float64,
        )
    # An empty risk set (inconsistent entry/time data) is floored rather
    # than allowed to produce an infinite term.
    log_risk = np.where(np.isneginf(log_risk), np.log(1e-300), log_risk)

    # Breslow: every one of the d tied events uses the full risk-set sum.
    log_pl = float(np.sum(eta_events) - np.sum(tie_counts * log_risk))

    if ties == "efron":
        # Efron replaces the k-th tied denominator by
        #   risk - (k / d) * tied_sum = risk * (1 - (k / d) * tied_sum / risk),
        # so only the log1p correction has to be added on top of Breslow.
        for g in np.flatnonzero(tie_counts > 1):
            d = int(tie_counts[g])
            log_tied = _log_sum_exp(eta_events[group_of == g])
            # tied_sum <= risk for consistent data; clip guards rounding.
            ratio = np.exp(min(log_tied - log_risk[g], 0.0))
            steps = np.arange(d, dtype=np.float64) / d
            log_pl -= float(np.sum(np.log1p(-steps * ratio)))

    return log_pl
402
+
403
+
404
+ # =============================================================================
405
+ # CV main function
406
+ # =============================================================================
407
+
408
+ def _select_coxph_penalty_cv(
409
+ X,
410
+ time,
411
+ event,
412
+ entry=None,
413
+ cluster=None,
414
+ *,
415
+ penalties=None,
416
+ n_penalties: int = 100,
417
+ penalty_min_ratio: float = 1e-3,
418
+ cv_folds: int = 5,
419
+ cv_splits=None,
420
+ random_state: Optional[int] = None,
421
+ ties: str = "breslow",
422
+ device: Union[str, Device] = Device.CPU,
423
+ max_iter: int = 100,
424
+ tol: float = 1e-9,
425
+ return_details: bool = False,
426
+ cache_key: Optional[str] = None,
427
+ ):
428
+ """
429
+ Select penalty for CoxPH via K-fold cross-validation.
430
+
431
+ For each fold:
432
+ 1. Split data into train/test
433
+ 2. Fit CoxPH on train for each penalty
434
+ 3. Evaluate partial likelihood on test
435
+
436
+ Returns the penalty with maximum mean partial likelihood.
437
+
438
+ Parameters
439
+ ----------
440
+ X : ndarray
441
+ Design matrix (n_samples, n_features).
442
+ time : ndarray
443
+ Survival times (n_samples,).
444
+ event : ndarray
445
+ Event indicators (n_samples,).
446
+ entry : ndarray or None
447
+ Delayed-entry times.
448
+ cluster : ndarray or None
449
+ Cluster ids (used in model fitting; scoring remains partial likelihood).
450
+ penalties : ndarray or None
451
+ Penalty values to try. If None, generates grid.
452
+ n_penalties : int
453
+ Number of penalties (if penalties is None).
454
+ penalty_min_ratio : float
455
+ Minimum penalty ratio.
456
+ cv_folds : int
457
+ Number of CV folds.
458
+ cv_splits : list or None
459
+ Pre-computed CV splits.
460
+ random_state : int or None
461
+ Random seed.
462
+ ties : str
463
+ 'breslow' or 'efron'.
464
+ device : str or Device
465
+ Computation device.
466
+ max_iter : int
467
+ Maximum iterations.
468
+ tol : float
469
+ Convergence tolerance.
470
+ return_details : bool
471
+ Whether to return full CV details.
472
+ cache_key : str or None
473
+ Cache key.
474
+
475
+ Returns
476
+ -------
477
+ best_penalty : float
478
+ details : dict (if return_details=True)
479
+ """
480
+ device_name = str(device).lower() if not isinstance(device, Device) else device.value
481
+ use_gpu = device_name in (Device.CUDA.value, Device.TORCH.value)
482
+ # Optional CV bridge for CUDA: many medium-size CV workloads are faster with
483
+ # torch backend while preserving the same CoxPHCV public API.
484
+ cv_cuda_torch_bridge = os.environ.get(
485
+ "STATGPU_COXPHCV_CUDA_TORCH_BRIDGE", "0"
486
+ ).strip().lower() in ("1", "true", "yes", "on")
487
+
488
+ # Convert to numpy arrays
489
+ X_np = np.asarray(X, dtype=np.float64)
490
+ time_np = np.asarray(time, dtype=np.float64)
491
+ event_np = np.asarray(event, dtype=np.int32)
492
+ entry_np = None if entry is None else np.asarray(entry, dtype=np.float64)
493
+ cluster_np = None if cluster is None else np.asarray(cluster)
494
+
495
+ n_samples = X_np.shape[0]
496
+ n_features = X_np.shape[1]
497
+ fit_device = device_name
498
+ if (
499
+ cv_cuda_torch_bridge
500
+ and device_name == Device.CUDA.value
501
+ and n_samples >= 1500
502
+ and n_features >= 40
503
+ ):
504
+ fit_device = Device.TORCH.value
505
+
506
+ # Generate penalty grid
507
+ if penalties is None:
508
+ penalties = _default_coxph_penalty_grid(X_np, time_np, event_np, n_penalties, penalty_min_ratio)
509
+ else:
510
+ penalties = np.asarray(penalties, dtype=np.float64)
511
+ penalties = penalties[np.isfinite(penalties)]
512
+ penalties = penalties[penalties >= 0]
513
+ if penalties.size == 0:
514
+ penalties = _default_coxph_penalty_grid(X_np, time_np, event_np, n_penalties, penalty_min_ratio)
515
+
516
+ n_penalties_actual = len(penalties)
517
+
518
+ # Handle degenerate cases
519
+ if n_samples < 4 or cv_folds < 2:
520
+ if not return_details:
521
+ return float(penalties[0])
522
+ return {
523
+ "penalty": float(penalties[0]),
524
+ "penalties": penalties.astype(np.float64),
525
+ "pl_path": np.full((n_penalties_actual, 1), np.nan, dtype=np.float64),
526
+ "mean_pl": np.full(n_penalties_actual, np.nan, dtype=np.float64),
527
+ "best_pl": np.nan,
528
+ }
529
+
530
+ # Generate CV folds
531
+ if cv_splits is not None:
532
+ folds = cv_splits
533
+ else:
534
+ folds = _kfold_indices(n_samples, cv_folds, random_state)
535
+
536
+ folds_are_complements_flag = _folds_are_complements(folds, n_samples)
537
+ n_folds = len(folds)
538
+
539
+ # Keep exhaustive full-grid CV as the default behavior. Two-stage is opt-in.
540
+ two_stage_enabled = (
541
+ _env_flag("STATGPU_COXPHCV_TWO_STAGE", False) # default=False: opt-in
542
+ and device_name == Device.CUDA.value
543
+ and n_penalties_actual >= 8
544
+ )
545
+ halving_enabled = (
546
+ _env_flag("STATGPU_COXPHCV_SUCCESSIVE_HALVING", False)
547
+ and device_name == Device.CUDA.value
548
+ and n_penalties_actual >= 8
549
+ )
550
+ coarse_n = _env_int(
551
+ "STATGPU_COXPHCV_TWO_STAGE_COARSE",
552
+ 6,
553
+ min_value=3,
554
+ max_value=n_penalties_actual,
555
+ )
556
+ window = _env_int("STATGPU_COXPHCV_TWO_STAGE_WINDOW", 2, min_value=1)
557
+ halving_topk = _env_int(
558
+ "STATGPU_COXPHCV_HALVING_TOPK",
559
+ 3,
560
+ min_value=1,
561
+ max_value=n_penalties_actual,
562
+ )
563
+ fast_iter = _env_int(
564
+ "STATGPU_COXPHCV_HALVING_FAST_ITER",
565
+ 30,
566
+ min_value=5,
567
+ max_value=max_iter,
568
+ )
569
+ fast_tol = _env_float("STATGPU_COXPHCV_HALVING_FAST_TOL", 1e-6, min_value=tol)
570
+
571
+ # Cache handling
572
+ cache_key_eff = cache_key
573
+ if cache_key_eff is None and _COXPH_CV_CACHE_MAXSIZE > 0:
574
+ cache_key_eff = _make_coxph_cv_auto_cache_key(
575
+ X_shape=X_np.shape,
576
+ time_shape=time_np.shape,
577
+ event_shape=event_np.shape,
578
+ X_data=X_np,
579
+ time_data=time_np,
580
+ event_data=event_np,
581
+ penalties=penalties,
582
+ n_penalties=n_penalties,
583
+ penalty_min_ratio=penalty_min_ratio,
584
+ folds=folds,
585
+ ties=ties,
586
+ use_gpu=use_gpu,
587
+ fit_device=fit_device,
588
+ cv_cuda_torch_bridge=cv_cuda_torch_bridge,
589
+ entry=entry_np,
590
+ cluster=cluster_np,
591
+ two_stage_enabled=two_stage_enabled,
592
+ halving_enabled=halving_enabled,
593
+ coarse_n=coarse_n,
594
+ window=window,
595
+ halving_topk=halving_topk,
596
+ fast_iter=fast_iter,
597
+ fast_tol=fast_tol,
598
+ max_iter=max_iter,
599
+ tol=tol,
600
+ )
601
+
602
+ cached_result = _coxcv_cache_get(cache_key_eff)
603
+ if cached_result is not None:
604
+ if return_details:
605
+ return cached_result["penalty"], cached_result
606
+ return cached_result["penalty"]
607
+
608
+ # Storage for partial likelihoods: (n_penalties, n_folds)
609
+ pl_path = np.full((n_penalties_actual, n_folds), np.nan, dtype=np.float64)
610
+
611
+ def _evaluate_penalty_indices(
612
+ penalty_indices: np.ndarray,
613
+ *,
614
+ fit_max_iter: int,
615
+ fit_tol: float,
616
+ ) -> None:
617
+ if penalty_indices.size == 0:
618
+ return
619
+ penalty_indices = np.unique(np.asarray(penalty_indices, dtype=np.int64))
620
+ for fold_idx, (train_idx, test_idx) in enumerate(folds):
621
+ X_train, X_test = X_np[train_idx], X_np[test_idx]
622
+ time_train, time_test = time_np[train_idx], time_np[test_idx]
623
+ event_train, event_test = event_np[train_idx], event_np[test_idx]
624
+ entry_train = None if entry_np is None else entry_np[train_idx]
625
+ entry_test = None if entry_np is None else entry_np[test_idx]
626
+ cluster_train = None if cluster_np is None else cluster_np[train_idx]
627
+ X_fit = X_train
628
+ time_fit = time_train
629
+ event_fit = event_train
630
+ entry_fit = entry_train
631
+ cluster_fit = cluster_train
632
+
633
+ # Reduce repeated host->device conversions by preparing one fold
634
+ # tensor/array per backend and reusing it across penalties.
635
+ if fit_device == Device.CUDA.value:
636
+ try:
637
+ import cupy as cp
638
+ X_fit = cp.asarray(X_train, dtype=cp.float64)
639
+ time_fit = cp.asarray(time_train, dtype=cp.float64)
640
+ event_fit = cp.asarray(event_train, dtype=cp.int32)
641
+ entry_fit = None if entry_train is None else cp.asarray(entry_train, dtype=cp.float64)
642
+ cluster_fit = None if cluster_train is None else cp.asarray(cluster_train, dtype=cp.int64)
643
+ except Exception:
644
+ X_fit = X_train
645
+ time_fit = time_train
646
+ event_fit = event_train
647
+ entry_fit = entry_train
648
+ cluster_fit = cluster_train
649
+ elif fit_device == Device.TORCH.value:
650
+ try:
651
+ import torch
652
+ torch_device = _get_torch_device_str()
653
+ X_fit = torch.as_tensor(X_train, dtype=torch.float64, device=torch_device)
654
+ time_fit = torch.as_tensor(time_train, dtype=torch.float64, device=torch_device)
655
+ event_fit = torch.as_tensor(event_train, dtype=torch.int32, device=torch_device)
656
+ entry_fit = None if entry_train is None else torch.as_tensor(
657
+ entry_train, dtype=torch.float64, device=torch_device
658
+ )
659
+ cluster_fit = None if cluster_train is None else torch.as_tensor(
660
+ cluster_train, dtype=torch.int64, device=torch_device
661
+ )
662
+ except Exception:
663
+ X_fit = X_train
664
+ time_fit = time_train
665
+ event_fit = event_train
666
+ entry_fit = entry_train
667
+ cluster_fit = cluster_train
668
+
669
+ n_events_train = int(np.sum(event_train))
670
+ n_events_test = int(np.sum(event_test))
671
+ if n_events_train == 0 or n_events_test == 0:
672
+ continue
673
+
674
+ prev_coef = None
675
+ for penalty_idx in penalty_indices:
676
+ if np.isfinite(pl_path[penalty_idx, fold_idx]):
677
+ continue
678
+ penalty = penalties[penalty_idx]
679
+ model = CoxPH(
680
+ ties=ties,
681
+ max_iter=fit_max_iter,
682
+ tol=fit_tol,
683
+ device=fit_device,
684
+ compute_inference=False,
685
+ penalty=penalty,
686
+ )
687
+ try:
688
+ model.fit(
689
+ X_fit,
690
+ time_fit,
691
+ event_fit,
692
+ entry=entry_fit,
693
+ cluster=cluster_fit,
694
+ init_coef=prev_coef,
695
+ )
696
+ if not model._converged:
697
+ continue
698
+ prev_coef = np.asarray(model.coef_, dtype=np.float64).copy()
699
+ pl_test = _compute_partial_likelihood(
700
+ X_test, time_test, event_test, model.coef_, entry=entry_test, ties=ties
701
+ )
702
+ pl_path[penalty_idx, fold_idx] = pl_test
703
+ except Exception:
704
+ continue
705
+
706
+ if two_stage_enabled:
707
+ stage1_idx = np.unique(
708
+ np.linspace(0, n_penalties_actual - 1, num=coarse_n, dtype=np.int64)
709
+ )
710
+ _evaluate_penalty_indices(
711
+ stage1_idx,
712
+ fit_max_iter=(fast_iter if halving_enabled else max_iter),
713
+ fit_tol=(fast_tol if halving_enabled else tol),
714
+ )
715
+ stage1_mean = np.nanmean(pl_path[stage1_idx, :], axis=1)
716
+ if np.any(np.isfinite(stage1_mean)):
717
+ stage1_best = int(stage1_idx[int(np.nanargmax(stage1_mean))])
718
+ else:
719
+ stage1_best = int(stage1_idx[len(stage1_idx) // 2])
720
+ lo = max(0, stage1_best - window)
721
+ hi = min(n_penalties_actual - 1, stage1_best + window)
722
+ stage2_idx = np.arange(lo, hi + 1, dtype=np.int64)
723
+ _evaluate_penalty_indices(
724
+ stage2_idx,
725
+ fit_max_iter=(fast_iter if halving_enabled else max_iter),
726
+ fit_tol=(fast_tol if halving_enabled else tol),
727
+ )
728
+ if halving_enabled:
729
+ stage2_mean = np.full(stage2_idx.shape[0], np.nan, dtype=np.float64)
730
+ stage2_valid = np.any(np.isfinite(pl_path[stage2_idx, :]), axis=1)
731
+ if np.any(stage2_valid):
732
+ stage2_mean[stage2_valid] = np.nanmean(pl_path[stage2_idx[stage2_valid], :], axis=1)
733
+ order = np.argsort(np.nan_to_num(stage2_mean, nan=-np.inf))[::-1]
734
+ top_idx = stage2_idx[order[: min(halving_topk, len(stage2_idx))]]
735
+ # Re-evaluate top candidates with full precision and overwrite.
736
+ pl_path[top_idx, :] = np.nan
737
+ _evaluate_penalty_indices(top_idx, fit_max_iter=max_iter, fit_tol=tol)
738
+ else:
739
+ full_idx = np.arange(n_penalties_actual, dtype=np.int64)
740
+ if halving_enabled:
741
+ _evaluate_penalty_indices(full_idx, fit_max_iter=fast_iter, fit_tol=fast_tol)
742
+ full_mean = np.full(full_idx.shape[0], np.nan, dtype=np.float64)
743
+ full_valid = np.any(np.isfinite(pl_path[full_idx, :]), axis=1)
744
+ if np.any(full_valid):
745
+ full_mean[full_valid] = np.nanmean(pl_path[full_idx[full_valid], :], axis=1)
746
+ order = np.argsort(np.nan_to_num(full_mean, nan=-np.inf))[::-1]
747
+ top_idx = full_idx[order[:halving_topk]]
748
+ pl_path[top_idx, :] = np.nan
749
+ _evaluate_penalty_indices(top_idx, fit_max_iter=max_iter, fit_tol=tol)
750
+ else:
751
+ _evaluate_penalty_indices(full_idx, fit_max_iter=max_iter, fit_tol=tol)
752
+
753
+ # Safety fallback: if no penalty has any finite fold score, evaluate full grid once.
754
+ has_any_valid = np.any(np.isfinite(pl_path), axis=1)
755
+ if not np.any(has_any_valid):
756
+ _evaluate_penalty_indices(
757
+ np.arange(n_penalties_actual, dtype=np.int64),
758
+ fit_max_iter=max_iter,
759
+ fit_tol=tol,
760
+ )
761
+
762
+ # Compute mean partial likelihood across folds
763
+ mean_pl = np.full(n_penalties_actual, np.nan, dtype=np.float64)
764
+ valid_rows = np.any(np.isfinite(pl_path), axis=1)
765
+ if np.any(valid_rows):
766
+ mean_pl[valid_rows] = np.nanmean(pl_path[valid_rows], axis=1)
767
+
768
+ # Find best penalty (maximum partial likelihood)
769
+ if np.any(np.isfinite(mean_pl)):
770
+ best_idx = np.nanargmax(mean_pl)
771
+ best_penalty = float(penalties[best_idx])
772
+ best_pl = float(mean_pl[best_idx])
773
+ else:
774
+ # No valid CV results - use first penalty
775
+ best_penalty = float(penalties[0])
776
+ best_pl = np.nan
777
+
778
+ # Prepare details
779
+ details = {
780
+ "penalty": best_penalty,
781
+ "penalties": penalties.astype(np.float64),
782
+ "pl_path": pl_path.astype(np.float64),
783
+ "mean_pl": mean_pl.astype(np.float64),
784
+ "best_pl": best_pl,
785
+ "n_folds": n_folds,
786
+ }
787
+
788
+ # Cache result
789
+ if _COXPH_CV_CACHE_MAXSIZE > 0:
790
+ _coxcv_cache_put(cache_key_eff, details)
791
+
792
+ if return_details:
793
+ return best_penalty, details
794
+
795
+ return best_penalty
796
+
797
+
798
+ # =============================================================================
799
+ # CoxPHCV Class
800
+ # =============================================================================
801
+
802
+ class CoxPHCV(CVEstimatorBase):
803
+ """
804
+ Cross-validated Cox Proportional Hazards regression.
805
+
806
+ This class implements K-fold cross-validation to select the optimal
807
+ penalty (L2 regularization) parameter for Cox PH models.
808
+
809
+ Parameters
810
+ ----------
811
+ penalties : array-like or None
812
+ Penalty values to try. If None, generates n_penalties values.
813
+ n_penalties : int, default=100
814
+ Number of penalty values (if penalties is None).
815
+ penalty_min_ratio : float, default=1e-3
816
+ Minimum penalty as ratio of max penalty.
817
+ cv : int, default=5
818
+ Number of CV folds.
819
+ ties : str, default='breslow'
820
+ Method for handling ties: 'breslow' or 'efron'.
821
+ tol : float, default=1e-9
822
+ Convergence tolerance.
823
+ max_iter : int, default=100
824
+ Maximum iterations.
825
+ device : str or Device, default='auto'
826
+ Computation device: 'cpu', 'cuda', or 'auto'.
827
+ compute_inference : bool, default=True
828
+ Whether to compute standard errors after fitting.
829
+ cov_type : str, default='nonrobust'
830
+ Covariance estimator.
831
+ gpu_memory_cleanup : bool, default=False
832
+ Whether to free GPU memory after fitting.
833
+ random_state : int or None
834
+ Random seed for CV splits.
835
+
836
+ Attributes
837
+ ----------
838
+ penalty_ : float
839
+ Selected penalty value.
840
+ penalties_ : ndarray
841
+ All penalty values tested.
842
+ cv_results_ : dict
843
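+ CV results with keys ``pl_path`` (held-out partial likelihood for each penalty and fold) and ``mean_pl`` (per-penalty mean over folds, ignoring NaNs).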
844
+ best_score_ : float
845
+ Best (maximum) partial likelihood across CV folds.
846
+ coef_ : ndarray
847
+ Coefficients of the final model.
848
+ hazard_ratios_ : ndarray
849
+ exp(coef) = hazard ratios.
850
+ estimator_ : CoxPH
851
+ The fitted CoxPH with selected penalty.
852
+
853
+ Examples
854
+ --------
855
+ >>> import numpy as np
856
+ >>> from statgpu.survival import CoxPHCV
857
+ >>> X = np.random.randn(1000, 20)
858
+ >>> time = np.random.exponential(scale=100, size=1000)
859
+ >>> event = np.random.binomial(1, 0.7, size=1000)
860
+ >>> model = CoxPHCV(cv=5, device='cuda')
861
+ >>> model.fit(X, time, event)
862
+ >>> print(f"Selected penalty: {model.penalty_:.4f}")
863
+ >>> print(f"Best CV score: {model.best_score_:.4f}")
864
+ """
865
+
866
def __init__(
    self,
    penalties=None,
    n_penalties: int = 100,
    penalty_min_ratio: float = 1e-3,
    cv: int = 5,
    cv_splits=None,
    ties: str = "breslow",
    tol: float = 1e-9,
    max_iter: int = 100,
    device: Union[str, Device] = Device.AUTO,
    n_jobs: Optional[int] = None,
    compute_inference: bool = True,
    cov_type: str = "nonrobust",
    gpu_memory_cleanup: bool = False,
    random_state: Optional[int] = None,
):
    """
    Store the hyper-parameters and reset every fitted attribute to ``None``.

    ``cv``, ``random_state``, ``device`` and ``n_jobs`` are handed to the
    base-class constructor; the remaining arguments are coerced to their
    plain Python types and kept on the instance. ``cv_splits`` is stored
    unchanged and later forwarded to the penalty-selection routine.
    """
    super().__init__(cv=cv, random_state=random_state, device=device, n_jobs=n_jobs)

    # Penalty grid and cross-validation settings.
    self.penalties = penalties
    self.n_penalties = int(n_penalties)
    self.penalty_min_ratio = float(penalty_min_ratio)
    self.cv = int(cv)
    self.cv_splits = cv_splits

    # Settings passed on to the underlying CoxPH solver.
    self.ties = str(ties)
    self.tol = float(tol)
    self.max_iter = int(max_iter)
    self.compute_inference = bool(compute_inference)
    self.cov_type = str(cov_type)
    self.gpu_memory_cleanup = bool(gpu_memory_cleanup)

    # Fitted state: populated by fit(), None until then.
    for fitted_name in (
        "penalty_",
        "penalties_",
        "cv_results_",
        "best_score_",
        "coef_",
        "hazard_ratios_",
        "estimator_",
    ):
        setattr(self, fitted_name, None)
909
+
910
def _cleanup_cuda_memory(self):
    """
    Best-effort release of CuPy's pooled memory.

    Does nothing unless ``gpu_memory_cleanup`` is enabled. Otherwise frees
    all blocks of CuPy's default memory pool and then of its default pinned
    memory pool. Any failure (for example CuPy not being importable) is
    deliberately ignored so cleanup can never break the caller.
    """
    if self.gpu_memory_cleanup:
        try:
            import cupy

            device_pool = cupy.get_default_memory_pool()
            device_pool.free_all_blocks()
            pinned_pool = cupy.get_default_pinned_memory_pool()
            pinned_pool.free_all_blocks()
        except Exception:
            # Cleanup is optional; swallow every error on purpose.
            pass
921
+
922
def _cleanup_torch_memory(self):
    """
    Best-effort release of Torch's CUDA cache.

    Does nothing unless ``gpu_memory_cleanup`` is enabled. Otherwise calls
    ``torch.cuda.empty_cache()`` followed by ``torch.cuda.synchronize()``.
    Any failure (for example Torch not being importable) is deliberately
    ignored so cleanup can never break the caller.
    """
    if self.gpu_memory_cleanup:
        try:
            import torch

            cuda_api = torch.cuda
            cuda_api.empty_cache()
            cuda_api.synchronize()
        except Exception:
            # Cleanup is optional; swallow every error on purpose.
            pass
933
+
934
def __del__(self):
    """Run both GPU cleanup hooks on finalisation, ignoring any error."""
    try:
        # CuPy hook first, then Torch, looked up and invoked one at a time.
        for hook_name in ("_cleanup_cuda_memory", "_cleanup_torch_memory"):
            getattr(self, hook_name)()
    except Exception:
        # A finaliser must never raise.
        pass
940
+
941
def _fit_cv(self, X, time, event, entry=None, cluster=None):
    """
    Select the penalty by cross-validation, then refit on the full data.

    The penalty is chosen by ``_select_coxph_penalty_cv``. A final ``CoxPH``
    model using that penalty is then fitted on all samples, and its
    coefficients and hazard ratios are copied onto this object.

    Parameters
    ----------
    X : array-like
        Design matrix (must be convertible to a 2-D NumPy array).
    time : array-like
        Survival times.
    event : array-like
        Event indicators.
    entry : array-like, optional
        Entry times (delayed entry).
    cluster : array-like, optional
        Cluster ids.

    Returns
    -------
    self

    Notes
    -----
    If the environment variable ``STATGPU_COXPHCV_CUDA_TORCH_BRIDGE`` is one
    of ``1``, ``true``, ``yes`` or ``on`` (case-insensitive), the resolved
    device is CUDA, and the data has at least 1500 samples and 40 features,
    both the CV search and the final fit run on the Torch device instead.

    Any ``penalties`` value other than ``None`` is converted to a float64
    array, so every array-like (list, tuple, ndarray, ``range``, pandas
    Series, scalar, ...) is honoured; conversion errors raised by NumPy
    propagate to the caller.
    """
    device_name = self._get_compute_device().value
    n_samples, n_features = np.asarray(X).shape

    # Opt-in switch that reroutes large CUDA problems to the Torch backend.
    cv_cuda_torch_bridge = os.environ.get(
        "STATGPU_COXPHCV_CUDA_TORCH_BRIDGE", "0"
    ).strip().lower() in ("1", "true", "yes", "on")
    fit_device_name = device_name
    if (
        cv_cuda_torch_bridge
        and device_name == Device.CUDA.value
        and n_samples >= 1500
        and n_features >= 40
    ):
        fit_device_name = Device.TORCH.value

    # Normalise the user-supplied grid. Only None means "generate a grid";
    # previously any array-like that was not a list/tuple/ndarray was
    # silently dropped and replaced by the generated grid.
    if self.penalties is None:
        penalties = None
    else:
        penalties = np.atleast_1d(np.asarray(self.penalties, dtype=np.float64))

    # Perform CV to find the best penalty.
    best_penalty, details = _select_coxph_penalty_cv(
        X, time, event,
        entry=entry,
        cluster=cluster,
        penalties=penalties,
        n_penalties=self.n_penalties,
        penalty_min_ratio=self.penalty_min_ratio,
        cv_folds=self.cv,
        cv_splits=self.cv_splits,
        random_state=self.random_state,
        ties=self.ties,
        device=fit_device_name,
        max_iter=self.max_iter,
        tol=self.tol,
        return_details=True,
    )

    # Store CV results.
    self.penalty_ = float(best_penalty)
    self.penalties_ = np.asarray(details["penalties"], dtype=np.float64)

    pl_path = np.asarray(details["pl_path"], dtype=np.float64)
    mean_pl = np.asarray(details["mean_pl"], dtype=np.float64)

    self.cv_results_ = {
        "pl_path": pl_path,
        "mean_pl": mean_pl,
    }
    self.best_score_ = float(details["best_pl"])

    # Fit the final model on the full data with the selected penalty.
    final_model = CoxPH(
        ties=self.ties,
        tol=self.tol,
        max_iter=self.max_iter,
        device=fit_device_name,
        n_jobs=self.n_jobs,
        compute_inference=self.compute_inference,
        cov_type=self.cov_type,
        gpu_memory_cleanup=self.gpu_memory_cleanup,
        penalty=self.penalty_,
    )
    final_model.fit(X, time, event, entry=entry, cluster=cluster)

    self.estimator_ = final_model
    self.coef_ = final_model.coef_.copy()
    self.hazard_ratios_ = final_model.hazard_ratios_.copy()
    self._cleanup_cuda_memory()
    self._cleanup_torch_memory()

    return self
1034
+
1035
def fit(self, X, time, event, entry=None, cluster=None):
    """
    Fit the Cox model, choosing the penalty by cross-validation.

    This is the public entry point; all work is delegated to ``_fit_cv``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix.
    time : array-like of shape (n_samples,)
        Time to event or censoring.
    event : array-like of shape (n_samples,)
        Event indicator (1 = event, 0 = censored).
    entry : array-like, optional
        Entry time for delayed entry.
    cluster : array-like, optional
        Cluster ids.

    Returns
    -------
    self : CoxPHCV
    """
    fitted = self._fit_cv(X, time, event, entry=entry, cluster=cluster)
    return fitted
1057
+
1058
def predict(self, X):
    """
    Compute the linear predictor ``X @ coef_`` for each row of ``X``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Covariate matrix; converted to a float64 array.

    Returns
    -------
    risk_scores : ndarray
        Risk scores (linear predictor).

    Raises
    ------
    ValueError
        If the model has not been fitted yet.
    """
    coefficients = self.coef_
    if coefficients is None:
        raise ValueError("Model not fitted. Call fit() first.")

    design = np.asarray(X, dtype=np.float64)
    return design @ coefficients
1077
+
1078
def score(self, X, time, event):
    """
    Return the C-index (concordance index) of the fitted model.

    Only subjects with an observed event (``event == 1``) anchor a pair.
    A pair (i, j) is permissible when ``time[i] < time[j]``, or when the
    times are equal and j is censored (``event[j] == 0``). A permissible
    pair is concordant when i has the strictly larger risk score; pairs
    with equal risk scores count one half.

    Parameters
    ----------
    X : array-like
        Covariate matrix.
    time : array-like
        Survival times.
    event : array-like
        Event indicators.

    Returns
    -------
    c_index : float
        C-index (0.5 = random, 1.0 = perfect). 0.5 is also returned when
        there are no events or no permissible pairs. Always a builtin
        ``float``.

    Raises
    ------
    ValueError
        If the model has not been fitted yet.
    """
    if self.coef_ is None:
        raise ValueError("Model not fitted. Call fit() first.")

    X_arr = np.asarray(X, dtype=np.float64)
    time_arr = np.asarray(time, dtype=np.float64)
    event_arr = np.asarray(event, dtype=np.int32)

    # Risk score = linear predictor.
    risk_scores = X_arr @ self.coef_

    n = len(time_arr)
    event_idx = np.where(event_arr == 1)[0]
    n_events = len(event_idx)
    if n_events == 0:
        return 0.5

    # Row views shared by every chunk (loop invariants).
    time_j = time_arr[np.newaxis, :]
    risk_j = risk_scores[np.newaxis, :]
    censored_j = (event_arr == 0)[np.newaxis, :]

    concordant = 0
    permissible = 0
    tied_risk = 0

    # Chunk size: keep each (chunk x n) bool matrix <= 128 MB.
    chunk_size = max(1, min(n_events, int(128e6 / max(n, 1))))

    for start in range(0, n_events, chunk_size):
        idx_chunk = event_idx[start:start + chunk_size]

        time_i = time_arr[idx_chunk, np.newaxis]
        risk_i = risk_scores[idx_chunk, np.newaxis]

        # Permissible pairs: earlier time OR same time with j censored.
        perm = (time_i < time_j) | ((time_i == time_j) & censored_j)

        # Never compare a subject with itself.
        perm[np.arange(len(idx_chunk), dtype=np.int64), idx_chunk] = False

        concordant += int(np.sum(perm & (risk_i > risk_j)))
        tied_risk += int(np.sum(perm & (risk_i == risk_j)))
        permissible += int(np.sum(perm))

    if permissible == 0:
        return 0.5

    return float((concordant + 0.5 * tied_risk) / permissible)
1152
+
1153
def summary(self):
    """
    Return the summary produced by the fitted underlying estimator.

    Raises
    ------
    RuntimeError
        If no estimator has been fitted yet, or if the fitted estimator
        has no ``summary`` attribute.
    """
    estimator = self.estimator_
    if estimator is None:
        raise RuntimeError("No fitted estimator available.")
    if hasattr(estimator, "summary"):
        return estimator.summary()
    raise RuntimeError(f"{self.estimator_.__class__.__name__} does not implement summary().")