statgpu 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statgpu/__init__.py +174 -0
- statgpu/_base.py +544 -0
- statgpu/_config.py +127 -0
- statgpu/anova/__init__.py +5 -0
- statgpu/anova/_oneway.py +194 -0
- statgpu/backends/__init__.py +83 -0
- statgpu/backends/_array_ops.py +529 -0
- statgpu/backends/_base.py +184 -0
- statgpu/backends/_cupy.py +453 -0
- statgpu/backends/_factory.py +65 -0
- statgpu/backends/_gpu_inference_cupy.py +214 -0
- statgpu/backends/_gpu_inference_torch.py +422 -0
- statgpu/backends/_numpy.py +324 -0
- statgpu/backends/_torch.py +685 -0
- statgpu/backends/_torch_safe.py +47 -0
- statgpu/backends/_utils.py +423 -0
- statgpu/core/__init__.py +10 -0
- statgpu/core/formula/__init__.py +33 -0
- statgpu/core/formula/_design.py +99 -0
- statgpu/core/formula/_parser.py +191 -0
- statgpu/core/formula/_terms.py +70 -0
- statgpu/core/formula/tests/__init__.py +0 -0
- statgpu/core/formula/tests/test_parser.py +194 -0
- statgpu/covariance/__init__.py +6 -0
- statgpu/covariance/_empirical.py +310 -0
- statgpu/covariance/_shrinkage.py +248 -0
- statgpu/cross_validation/__init__.py +31 -0
- statgpu/cross_validation/_base.py +410 -0
- statgpu/cross_validation/_engine.py +167 -0
- statgpu/diagnostics/__init__.py +7 -0
- statgpu/diagnostics/_regression_diagnostics.py +188 -0
- statgpu/feature_selection/__init__.py +24 -0
- statgpu/feature_selection/_knockoff.py +870 -0
- statgpu/feature_selection/_knockoff_utils.py +1003 -0
- statgpu/feature_selection/_stepwise.py +300 -0
- statgpu/glm_core/__init__.py +81 -0
- statgpu/glm_core/_base.py +202 -0
- statgpu/glm_core/_family.py +362 -0
- statgpu/glm_core/_fused.py +149 -0
- statgpu/glm_core/_gamma.py +111 -0
- statgpu/glm_core/_inverse_gaussian.py +62 -0
- statgpu/glm_core/_irls.py +561 -0
- statgpu/glm_core/_logistic.py +82 -0
- statgpu/glm_core/_negative_binomial.py +68 -0
- statgpu/glm_core/_poisson.py +60 -0
- statgpu/glm_core/_solver_legacy.py +100 -0
- statgpu/glm_core/_squared.py +53 -0
- statgpu/glm_core/_tweedie.py +74 -0
- statgpu/inference/__init__.py +239 -0
- statgpu/inference/_distributions_backend.py +2610 -0
- statgpu/inference/_multiple_testing.py +391 -0
- statgpu/inference/_resampling.py +1400 -0
- statgpu/inference/_results.py +265 -0
- statgpu/linear_model/__init__.py +75 -0
- statgpu/linear_model/_gaussian_inference.py +306 -0
- statgpu/linear_model/_glm_base.py +1261 -0
- statgpu/linear_model/_ordered_logit.py +52 -0
- statgpu/linear_model/_ordered_probit.py +50 -0
- statgpu/linear_model/_stats.py +170 -0
- statgpu/linear_model/cv/__init__.py +13 -0
- statgpu/linear_model/cv/_elasticnet_cv.py +892 -0
- statgpu/linear_model/cv/_lasso_cv.py +253 -0
- statgpu/linear_model/cv/_logistic_cv.py +895 -0
- statgpu/linear_model/cv/_ridge_cv.py +1160 -0
- statgpu/linear_model/legacy/__init__.py +1 -0
- statgpu/linear_model/legacy/_distributions_legacy_gpu.py +340 -0
- statgpu/linear_model/legacy/_elasticnet_legacy.py +936 -0
- statgpu/linear_model/legacy/_lasso_legacy.py +4876 -0
- statgpu/linear_model/legacy/_penalized_legacy.py +1174 -0
- statgpu/linear_model/legacy/_ridge_legacy.py +863 -0
- statgpu/linear_model/legacy/_solver_legacy.py +104 -0
- statgpu/linear_model/penalized/__init__.py +25 -0
- statgpu/linear_model/penalized/_base.py +437 -0
- statgpu/linear_model/penalized/_fit_mixin.py +1877 -0
- statgpu/linear_model/penalized/_inference_mixin.py +1179 -0
- statgpu/linear_model/penalized/_penalized_cv.py +2699 -0
- statgpu/linear_model/penalized/_penalized_gamma.py +86 -0
- statgpu/linear_model/penalized/_penalized_inverse_gaussian.py +62 -0
- statgpu/linear_model/penalized/_penalized_linear.py +236 -0
- statgpu/linear_model/penalized/_penalized_logistic.py +100 -0
- statgpu/linear_model/penalized/_penalized_negative_binomial.py +65 -0
- statgpu/linear_model/penalized/_penalized_poisson.py +62 -0
- statgpu/linear_model/penalized/_penalized_tweedie.py +65 -0
- statgpu/linear_model/penalized/_predict_mixin.py +182 -0
- statgpu/linear_model/wrappers/__init__.py +31 -0
- statgpu/linear_model/wrappers/_adaptive_lasso.py +63 -0
- statgpu/linear_model/wrappers/_elasticnet.py +75 -0
- statgpu/linear_model/wrappers/_gamma.py +67 -0
- statgpu/linear_model/wrappers/_inverse_gaussian.py +47 -0
- statgpu/linear_model/wrappers/_lasso.py +2124 -0
- statgpu/linear_model/wrappers/_linear.py +1127 -0
- statgpu/linear_model/wrappers/_logistic.py +1435 -0
- statgpu/linear_model/wrappers/_mcp.py +58 -0
- statgpu/linear_model/wrappers/_negative_binomial.py +58 -0
- statgpu/linear_model/wrappers/_poisson.py +48 -0
- statgpu/linear_model/wrappers/_ridge.py +166 -0
- statgpu/linear_model/wrappers/_scad.py +58 -0
- statgpu/linear_model/wrappers/_tweedie.py +57 -0
- statgpu/metrics/__init__.py +21 -0
- statgpu/metrics/_classification.py +591 -0
- statgpu/nonparametric/__init__.py +50 -0
- statgpu/nonparametric/kernel_methods/__init__.py +25 -0
- statgpu/nonparametric/kernel_methods/_kernels.py +246 -0
- statgpu/nonparametric/kernel_methods/_krr.py +234 -0
- statgpu/nonparametric/kernel_methods/_krr_cv.py +380 -0
- statgpu/nonparametric/kernel_smoothing/__init__.py +39 -0
- statgpu/nonparametric/kernel_smoothing/_bandwidth_selection.py +1083 -0
- statgpu/nonparametric/kernel_smoothing/_kde.py +761 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_common.py +348 -0
- statgpu/nonparametric/kernel_smoothing/_kernel_regression.py +748 -0
- statgpu/nonparametric/splines/__init__.py +5 -0
- statgpu/nonparametric/splines/_bspline_basis.py +336 -0
- statgpu/nonparametric/splines/_penalized.py +349 -0
- statgpu/panel/__init__.py +19 -0
- statgpu/panel/_covariance.py +140 -0
- statgpu/panel/_fixed_effects.py +420 -0
- statgpu/panel/_random_effects.py +385 -0
- statgpu/panel/_utils.py +482 -0
- statgpu/penalties/__init__.py +139 -0
- statgpu/penalties/_adaptive_l1.py +313 -0
- statgpu/penalties/_base.py +261 -0
- statgpu/penalties/_categories.py +39 -0
- statgpu/penalties/_elasticnet.py +98 -0
- statgpu/penalties/_group_lasso.py +678 -0
- statgpu/penalties/_group_mcp.py +553 -0
- statgpu/penalties/_group_scad.py +605 -0
- statgpu/penalties/_l1.py +107 -0
- statgpu/penalties/_l2.py +77 -0
- statgpu/penalties/_mcp.py +237 -0
- statgpu/penalties/_scad.py +260 -0
- statgpu/semiparametric/__init__.py +5 -0
- statgpu/semiparametric/_gam.py +401 -0
- statgpu/solvers/__init__.py +24 -0
- statgpu/solvers/_admm.py +241 -0
- statgpu/solvers/_constants.py +15 -0
- statgpu/solvers/_convergence.py +6 -0
- statgpu/solvers/_fista.py +436 -0
- statgpu/solvers/_fista_bb.py +513 -0
- statgpu/solvers/_fista_lla.py +541 -0
- statgpu/solvers/_lbfgs.py +206 -0
- statgpu/solvers/_newton.py +149 -0
- statgpu/solvers/_utils.py +277 -0
- statgpu/survival/__init__.py +14 -0
- statgpu/survival/_cox.py +3974 -0
- statgpu/survival/_cox_breslow_triton_kernel.py +106 -0
- statgpu/survival/_cox_cv.py +1159 -0
- statgpu/survival/_cox_efron_cuda.py +1280 -0
- statgpu/survival/_cox_efron_triton.py +359 -0
- statgpu/unsupervised/__init__.py +29 -0
- statgpu/unsupervised/_agglomerative.py +307 -0
- statgpu/unsupervised/_dbscan.py +263 -0
- statgpu/unsupervised/_dbscan_cpu.pyx +125 -0
- statgpu/unsupervised/_gmm.py +332 -0
- statgpu/unsupervised/_incremental_pca.py +176 -0
- statgpu/unsupervised/_kmeans.py +261 -0
- statgpu/unsupervised/_minibatch_kmeans.py +299 -0
- statgpu/unsupervised/_minibatch_nmf.py +252 -0
- statgpu/unsupervised/_nmf.py +190 -0
- statgpu/unsupervised/_pca.py +189 -0
- statgpu/unsupervised/_truncated_svd.py +132 -0
- statgpu/unsupervised/_tsne.py +192 -0
- statgpu/unsupervised/_umap.py +224 -0
- statgpu/unsupervised/_utils.py +134 -0
- statgpu-0.1.0.dist-info/METADATA +245 -0
- statgpu-0.1.0.dist-info/RECORD +168 -0
- statgpu-0.1.0.dist-info/WHEEL +5 -0
- statgpu-0.1.0.dist-info/licenses/LICENSE +199 -0
- statgpu-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1159 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CoxPHCV: Cross-validated Cox Proportional Hazards regression.
|
|
3
|
+
|
|
4
|
+
Implements K-fold cross-validation to select the optimal penalty (L2 regularization)
|
|
5
|
+
parameter for Cox PH models.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Optional, Union, Tuple, Dict, Any, List
|
|
9
|
+
from collections import OrderedDict
|
|
10
|
+
import hashlib
|
|
11
|
+
import os
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
from statgpu._config import Device
|
|
15
|
+
from statgpu.backends import _get_torch_device_str
|
|
16
|
+
from statgpu.cross_validation._base import CVEstimatorBase
|
|
17
|
+
from statgpu.survival._cox import CoxPH
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# =============================================================================
|
|
21
|
+
# CV Cache
|
|
22
|
+
# =============================================================================
|
|
23
|
+
|
|
24
|
+
# Upper bound on the number of CV results retained by the module-level LRU cache.
_COXPH_CV_CACHE_MAXSIZE = 64
# cache key -> cached CV result dict. Insertion order tracks recency (oldest
# first): lookups/stores move entries to the end and eviction pops the front.
_COXPH_CV_CACHE: "OrderedDict[str, Dict[str, Any]]" = OrderedDict()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _env_flag(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean.

    Returns ``bool(default)`` when the variable is unset. Otherwise the value
    is stripped and lower-cased; only "1", "true", "yes" or "on" count as
    true, and every other string (including an empty one) is false.
    """
    value = os.environ.get(name)
    if value is not None:
        normalized = str(value).strip().lower()
        return normalized in ("1", "true", "yes", "on")
    return bool(default)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _env_int(
    name: str,
    default: int,
    *,
    min_value: Optional[int] = None,
    max_value: Optional[int] = None,
) -> int:
    """Read environment variable *name* as an integer clamped to optional bounds.

    An unset variable, or one that ``int()`` cannot parse, falls back to
    ``int(default)``. The lower bound is applied before the upper bound, so
    when ``min_value > max_value`` the result is ``max_value``.
    """
    text = os.environ.get(name)
    if text is None:
        result = int(default)
    else:
        try:
            result = int(text)
        except (TypeError, ValueError):
            result = int(default)

    lower_bounded = result if min_value is None else max(min_value, result)
    return lower_bounded if max_value is None else min(max_value, lower_bounded)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _env_float(name: str, default: float, *, min_value: Optional[float] = None) -> float:
    """Safely parse a float env var with an optional lower bound.

    Parameters
    ----------
    name : str
        Environment variable to read.
    default : float
        Used when the variable is unset, unparsable, or not finite.
    min_value : float or None
        If given, the result is raised to at least this value.

    Returns
    -------
    float
        The parsed (or default) value, clamped from below by ``min_value``.
    """
    raw = os.environ.get(name)
    try:
        val = float(raw) if raw is not None else float(default)
    except (TypeError, ValueError):
        val = float(default)
    # float() happily accepts "nan"/"inf"; a NaN or infinite tolerance would
    # silently break convergence checks downstream, so treat it as invalid.
    if not np.isfinite(val):
        val = float(default)
    if min_value is not None:
        val = max(min_value, val)
    return val
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _hash_optional_array(h: "hashlib._blake2.blake2b", tag: str, arr: Optional[np.ndarray]) -> None:
    """Feed an optional array into hash *h* for cache-key disambiguation.

    ``None`` hashes as ``"<tag>:none"``. Otherwise the tag, shape, dtype and
    element contents are hashed.

    Object-dtype arrays (e.g. string cluster labels coming from a pandas
    ``object`` column) are hashed via the ``repr`` of their elements.
    ``tobytes()`` on such arrays returns raw object pointers, which differ
    between equal-valued arrays and may coincide for different values once
    memory is reused.
    """
    if arr is None:
        h.update(f"{tag}:none".encode("utf-8"))
        return
    h.update(tag.encode("utf-8"))
    arr_np = np.asarray(arr)
    h.update(np.asarray(arr_np.shape, dtype=np.int64).tobytes())
    h.update(str(arr_np.dtype).encode("utf-8"))
    if arr_np.dtype.hasobject:
        h.update(repr(arr_np.ravel().tolist()).encode("utf-8"))
    else:
        h.update(np.ascontiguousarray(arr_np).tobytes())
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _coxcv_cache_get(cache_key: Optional[str]) -> Optional[Dict[str, Any]]:
    """Look up cached CoxPH CV results and mark a hit as most recently used.

    Returns ``None`` when *cache_key* is ``None`` or has no (non-``None``)
    entry in ``_COXPH_CV_CACHE``; otherwise returns the stored dict itself
    (not a copy).
    """
    hit = None if cache_key is None else _COXPH_CV_CACHE.get(cache_key)
    if hit is None:
        return None
    _COXPH_CV_CACHE.move_to_end(cache_key)
    return hit
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def _coxcv_cache_put(cache_key: Optional[str], value: Dict[str, Any]) -> None:
    """Store CoxPH CV results under *cache_key* with least-recently-used eviction.

    A ``None`` key is ignored. The entry is placed at the most-recent end of
    ``_COXPH_CV_CACHE``; the oldest entries are then dropped until the cache
    holds at most ``_COXPH_CV_CACHE_MAXSIZE`` items.
    """
    if cache_key is not None:
        _COXPH_CV_CACHE[cache_key] = value
        _COXPH_CV_CACHE.move_to_end(cache_key)
        overflow = len(_COXPH_CV_CACHE) - _COXPH_CV_CACHE_MAXSIZE
        for _ in range(max(overflow, 0)):
            _COXPH_CV_CACHE.popitem(last=False)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _sample_hash(h, arr, max_rows=50):
    """Hash the complete contents of *arr* into *h* for cache-key generation.

    Hashing only a sample (e.g. the first/last few flattened elements) lets
    two datasets of identical shape that differ anywhere else produce the same
    cache key, so the CV cache would return results computed on different
    data. The whole array is therefore hashed; this is linear in its size and
    cheap compared with fitting a cross-validation path.

    Parameters
    ----------
    h : hashlib hash object
        Updated in place.
    arr : array_like
        Converted to a flat float64 array before hashing.
    max_rows : int, optional
        Kept for backward compatibility with existing callers; it no longer
        limits how much of the array is hashed.
    """
    flat = np.ascontiguousarray(np.asarray(arr, dtype=np.float64).ravel())
    n = flat.shape[0]
    h.update(np.asarray([n], dtype=np.int64).tobytes())
    # Feed fixed-size chunks so a large design matrix is not duplicated in
    # memory by a single tobytes() call.
    chunk = 1 << 20
    for start in range(0, n, chunk):
        h.update(flat[start:start + chunk].tobytes())
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _make_coxph_cv_auto_cache_key(
    X_shape: Tuple[int, ...],
    time_shape: Tuple[int, ...],
    event_shape: Tuple[int, ...],
    penalties: Optional[np.ndarray],
    n_penalties: int,
    penalty_min_ratio: float,
    folds: List[Tuple[np.ndarray, np.ndarray]],
    ties: str,
    use_gpu: bool,
    fit_device: str,
    cv_cuda_torch_bridge: bool,
    entry: Optional[np.ndarray],
    cluster: Optional[np.ndarray],
    two_stage_enabled: bool,
    halving_enabled: bool,
    coarse_n: int,
    window: int,
    halving_topk: int,
    fast_iter: int,
    fast_tol: float,
    max_iter: int,
    tol: float,
    X_data=None,
    time_data=None,
    event_data=None,
) -> str:
    """
    Generate automatic cache key for CoxPH CV.

    Includes structural inputs (shapes/grid/folds), data content, execution-path
    settings (fit device/bridge/two-stage/halving), and optional delayed-entry or
    clustering arrays to avoid stale collisions across distinct CV runs.

    CV splits are hashed from their exact index arrays. ``str(folds)`` is not
    used because numpy abbreviates arrays longer than its print threshold
    (1000 elements by default) with ``...``, so distinct splits of a large
    dataset could render to identical text and share a cache entry.

    Returns
    -------
    str
        Hex digest of a 32-byte BLAKE2b hash.
    """
    h = hashlib.blake2b(digest_size=32)
    for shape in (X_shape, time_shape, event_shape):
        h.update(np.asarray(shape, dtype=np.int64).tobytes())

    # Data content, so that same-shaped datasets do not share an entry.
    for label, data in (("X", X_data), ("time", time_data), ("event", event_data)):
        h.update(f"{label}:".encode("utf-8"))
        if data is not None:
            _sample_hash(h, data, max_rows=50)

    _hash_optional_array(
        h,
        "penalties",
        None if penalties is None else np.asarray(penalties, dtype=np.float64),
    )

    # Exact fold contents (see docstring for why str(folds) is unsafe).
    fold_list = list(folds)
    h.update(f"folds:{len(fold_list)}".encode("utf-8"))
    for train_idx, test_idx in fold_list:
        _hash_optional_array(h, "train", np.asarray(train_idx))
        _hash_optional_array(h, "test", np.asarray(test_idx))

    _hash_optional_array(h, "entry", entry)
    _hash_optional_array(h, "cluster", cluster)

    # Labelled, delimited scalar settings so adjacent values cannot run together.
    settings = (
        ("n_penalties", n_penalties),
        ("penalty_min_ratio", penalty_min_ratio),
        ("ties", ties),
        ("use_gpu", use_gpu),
        ("fit_device", fit_device),
        ("cv_cuda_torch_bridge", cv_cuda_torch_bridge),
        ("two_stage_enabled", two_stage_enabled),
        ("halving_enabled", halving_enabled),
        ("coarse_n", coarse_n),
        ("window", window),
        ("halving_topk", halving_topk),
        ("fast_iter", fast_iter),
        ("fast_tol", fast_tol),
        ("max_iter", max_iter),
        ("tol", tol),
    )
    for label, value in settings:
        h.update(f"{label}={value!s};".encode("utf-8"))
    return h.hexdigest()
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# =============================================================================
|
|
181
|
+
# K-fold helpers
|
|
182
|
+
# =============================================================================
|
|
183
|
+
|
|
184
|
+
def _kfold_indices(n_samples: int, n_splits: int, random_state: Optional[int] = None):
    """Split ``range(n_samples)`` into shuffled K-fold train/test index pairs.

    The indices are shuffled once with ``np.random.RandomState(random_state)``
    and cut into ``n_splits`` contiguous test blocks; the first
    ``n_samples % n_splits`` blocks receive one extra sample. Each training
    set holds every shuffled index outside its test block, in shuffled order.

    Returns
    -------
    list of (ndarray, ndarray)
        ``n_splits`` pairs ``(train_idx, test_idx)``.
    """
    generator = np.random.RandomState(random_state)
    order = np.arange(n_samples)
    generator.shuffle(order)

    base, extra = divmod(n_samples, n_splits)
    sizes = np.full(n_splits, base, dtype=np.int64)
    sizes[:extra] += 1
    stops = np.cumsum(sizes)
    starts = stops - sizes

    return [
        (np.concatenate([order[:lo], order[hi:]]), order[lo:hi])
        for lo, hi in zip(starts, stops)
    ]
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def _folds_are_complements(folds, n_samples: int) -> bool:
    """Return True when the folds' test sets partition ``range(n_samples)``.

    Each element of *folds* is indexed as ``fold[1]`` for its test indices.
    The check passes only if every index ``0 .. n_samples-1`` appears exactly
    once across all test sets.

    An empty fold list returns False (rather than letting ``np.concatenate``
    raise), and the result is always a plain Python ``bool``.
    """
    test_parts = [np.asarray(fold[1]) for fold in folds]
    if not test_parts:
        return False
    test_indices = np.concatenate(test_parts)
    if len(test_indices) != n_samples:
        return False
    return bool(np.array_equal(np.sort(test_indices), np.arange(n_samples)))
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
# =============================================================================
|
|
211
|
+
# Penalty grid generation
|
|
212
|
+
# =============================================================================
|
|
213
|
+
|
|
214
|
+
def _default_coxph_penalty_grid(
    X: np.ndarray,
    time: np.ndarray,
    event: np.ndarray,
    n_penalties: int = 100,
    penalty_min_ratio: float = 1e-3,
) -> np.ndarray:
    """
    Generate default penalty grid for CoxPHCV.

    Penalty values are log-spaced, similar to alpha grid in RidgeCV, and
    always ordered from the largest penalty down to the smallest.

    The largest penalty is ``0.1 * n_events * max_j Var(X[:, j])``, floored
    at 1.0 (and replaced by 1.0 if the heuristic is not finite). The smallest
    is ``penalty_min_ratio`` times the largest. With no events the heuristic
    has nothing to scale by, so the grid runs from 1.0 down to
    ``penalty_min_ratio``, in the same descending order as the regular grid.

    Parameters
    ----------
    X : ndarray
        Design matrix (n_samples, n_features).
    time : ndarray
        Survival times (not used by the heuristic; kept for API symmetry).
    event : ndarray
        Event indicators (1=event, 0=censored).
    n_penalties : int
        Number of penalty values.
    penalty_min_ratio : float
        Minimum penalty as ratio of max penalty.

    Returns
    -------
    penalties : ndarray of float64, shape (n_penalties,)
        Log-spaced penalty values in descending order.

    Raises
    ------
    ValueError
        If ``X`` is not two-dimensional.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (n_samples, n_features); got shape {X.shape}")
    n_events = int(np.sum(event))

    if n_events == 0:
        penalty_max = 1.0
    else:
        # Larger feature variance -> larger potential penalty.
        X_var = np.var(X, axis=0)
        penalty_max = float(np.max(X_var)) * n_events * 0.1
        if not np.isfinite(penalty_max):
            # e.g. NaN in X; fall back instead of producing an all-NaN grid.
            penalty_max = 1.0
        # Ensure penalty_max is positive and reasonable.
        penalty_max = max(penalty_max, 1.0)

    penalty_min = penalty_min_ratio * penalty_max
    penalties = np.geomspace(penalty_max, penalty_min, n_penalties)
    return penalties.astype(np.float64)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
# =============================================================================
|
|
265
|
+
# Partial likelihood computation for CV evaluation
|
|
266
|
+
# =============================================================================
|
|
267
|
+
|
|
268
|
+
def _compute_partial_likelihood(
    X: np.ndarray,
    time: np.ndarray,
    event: np.ndarray,
    coef: np.ndarray,
    entry: Optional[np.ndarray] = None,
    ties: str = 'breslow',
) -> float:
    """
    Compute log partial likelihood for given coefficients.

    This is used for CV evaluation on held-out test folds. With linear
    predictor ``eta = X @ coef``, each distinct event time ``t`` with tied
    event set ``D`` (size ``d``) and risk set ``R(t)`` contributes

    * Breslow: ``sum_{i in D} eta_i - d * log(S)``
    * Efron:   ``sum_{i in D} eta_i - sum_{k=0}^{d-1} log(S - (k/d) * S_D)``

    where ``S = sum_{j in R(t)} exp(eta_j)`` and ``S_D = sum_{i in D} exp(eta_i)``.
    Without delayed entry ``R(t) = {j: time_j >= t}`` (all subjects tied at
    ``t`` are included); with delayed entry ``R(t) = {j: entry_j <= t <= time_j}``.
    Risk-set sums are evaluated in log space (log-sum-exp), so large linear
    predictors do not overflow ``exp``.

    Parameters
    ----------
    X : ndarray
        Design matrix (n_samples, n_features).
    time : ndarray
        Survival times.
    event : ndarray
        Event indicators (an entry equal to 1 marks an event).
    coef : ndarray or None
        Coefficient values. ``None`` is the null model (``eta = 0``) and is
        evaluated with the same formulas, so ties and delayed entry are
        handled consistently with non-zero coefficients.
    entry : ndarray or None
        Delayed-entry times (left truncation). If None, assumes entry=0 for all samples.
    ties : str
        'breslow' or 'efron' (case-insensitive).

    Returns
    -------
    log_pl : float
        Log partial likelihood value (0.0 when there are no events).

    Raises
    ------
    ValueError
        If ``ties`` is not 'breslow' or 'efron'.
    """
    ties_key = str(ties).lower()
    if ties_key not in ("breslow", "efron"):
        raise ValueError(f"ties must be 'breslow' or 'efron', got {ties!r}")

    time = np.asarray(time, dtype=np.float64)
    event = np.asarray(event)
    n = time.shape[0]
    if coef is None:
        eta_all = np.zeros(n, dtype=np.float64)
    else:
        eta_all = np.asarray(X, dtype=np.float64) @ np.asarray(coef, dtype=np.float64)

    # Sort by time; a stable sort keeps the result deterministic under ties.
    order = np.argsort(time, kind="stable")
    time_sorted = time[order]
    eta = eta_all[order]
    event_idx = np.flatnonzero(event[order] == 1)
    if event_idx.size == 0:
        return 0.0

    # Event positions ascend in sorted time, so each distinct event time is a
    # contiguous run of counts[g] events starting at offset starts[g].
    unique_times, counts = np.unique(time_sorted[event_idx], return_counts=True)
    starts = np.concatenate(([0], np.cumsum(counts)[:-1]))

    # log S(t) for every distinct event time.
    if entry is None:
        # log_tail[i] = log(sum_{j >= i} exp(eta_j)), via reverse log-sum-exp.
        log_tail = np.logaddexp.accumulate(eta[::-1])[::-1]
        # Read the tail at the FIRST subject with time == t so that all tied
        # subjects are in the risk set (required by Breslow as well as Efron).
        log_risk = log_tail[np.searchsorted(time_sorted, unique_times, side="left")]
    else:
        entry_sorted = np.asarray(entry, dtype=np.float64)[order]
        log_floor = np.log(1e-300)
        log_risk = np.empty(unique_times.shape[0], dtype=np.float64)
        for g, t in enumerate(unique_times):
            # NOTE(review): subjects entering exactly at t count as at risk
            # (entry <= t); R's survival::coxph uses (start, stop] intervals,
            # i.e. entry < t. Confirm this matches CoxPH's fitting convention.
            at_risk = (entry_sorted <= t) & (time_sorted >= t)
            log_risk[g] = np.logaddexp.reduce(eta[at_risk]) if np.any(at_risk) else log_floor

    eta_events = eta[event_idx]
    log_risk_per_event = np.repeat(log_risk, counts)
    if ties_key == "efron":
        # k/d for the k-th (0-based) tied event within its group.
        pos_in_group = np.arange(event_idx.size) - np.repeat(starts, counts)
        frac = pos_in_group / np.repeat(counts, counts)
        log_tied = np.logaddexp.reduceat(eta_events, starts)
        # S_D / S, clipped at 1 against rounding so 1 - frac * ratio >= 1/d > 0.
        ratio = np.minimum(np.exp(log_tied - log_risk), 1.0)
        log_denoms = log_risk_per_event + np.log1p(-frac * np.repeat(ratio, counts))
    else:
        log_denoms = log_risk_per_event

    return float(np.sum(eta_events) - np.sum(log_denoms))
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
# =============================================================================
|
|
405
|
+
# CV main function
|
|
406
|
+
# =============================================================================
|
|
407
|
+
|
|
408
|
+
def _select_coxph_penalty_cv(
|
|
409
|
+
X,
|
|
410
|
+
time,
|
|
411
|
+
event,
|
|
412
|
+
entry=None,
|
|
413
|
+
cluster=None,
|
|
414
|
+
*,
|
|
415
|
+
penalties=None,
|
|
416
|
+
n_penalties: int = 100,
|
|
417
|
+
penalty_min_ratio: float = 1e-3,
|
|
418
|
+
cv_folds: int = 5,
|
|
419
|
+
cv_splits=None,
|
|
420
|
+
random_state: Optional[int] = None,
|
|
421
|
+
ties: str = "breslow",
|
|
422
|
+
device: Union[str, Device] = Device.CPU,
|
|
423
|
+
max_iter: int = 100,
|
|
424
|
+
tol: float = 1e-9,
|
|
425
|
+
return_details: bool = False,
|
|
426
|
+
cache_key: Optional[str] = None,
|
|
427
|
+
):
|
|
428
|
+
"""
|
|
429
|
+
Select penalty for CoxPH via K-fold cross-validation.
|
|
430
|
+
|
|
431
|
+
For each fold:
|
|
432
|
+
1. Split data into train/test
|
|
433
|
+
2. Fit CoxPH on train for each penalty
|
|
434
|
+
3. Evaluate partial likelihood on test
|
|
435
|
+
|
|
436
|
+
Returns the penalty with maximum mean partial likelihood.
|
|
437
|
+
|
|
438
|
+
Parameters
|
|
439
|
+
----------
|
|
440
|
+
X : ndarray
|
|
441
|
+
Design matrix (n_samples, n_features).
|
|
442
|
+
time : ndarray
|
|
443
|
+
Survival times (n_samples,).
|
|
444
|
+
event : ndarray
|
|
445
|
+
Event indicators (n_samples,).
|
|
446
|
+
entry : ndarray or None
|
|
447
|
+
Delayed-entry times.
|
|
448
|
+
cluster : ndarray or None
|
|
449
|
+
Cluster ids (used in model fitting; scoring remains partial likelihood).
|
|
450
|
+
penalties : ndarray or None
|
|
451
|
+
Penalty values to try. If None, generates grid.
|
|
452
|
+
n_penalties : int
|
|
453
|
+
Number of penalties (if penalties is None).
|
|
454
|
+
penalty_min_ratio : float
|
|
455
|
+
Minimum penalty ratio.
|
|
456
|
+
cv_folds : int
|
|
457
|
+
Number of CV folds.
|
|
458
|
+
cv_splits : list or None
|
|
459
|
+
Pre-computed CV splits.
|
|
460
|
+
random_state : int or None
|
|
461
|
+
Random seed.
|
|
462
|
+
ties : str
|
|
463
|
+
'breslow' or 'efron'.
|
|
464
|
+
device : str or Device
|
|
465
|
+
Computation device.
|
|
466
|
+
max_iter : int
|
|
467
|
+
Maximum iterations.
|
|
468
|
+
tol : float
|
|
469
|
+
Convergence tolerance.
|
|
470
|
+
return_details : bool
|
|
471
|
+
Whether to return full CV details.
|
|
472
|
+
cache_key : str or None
|
|
473
|
+
Cache key.
|
|
474
|
+
|
|
475
|
+
Returns
|
|
476
|
+
-------
|
|
477
|
+
best_penalty : float
|
|
478
|
+
details : dict (if return_details=True)
|
|
479
|
+
"""
|
|
480
|
+
device_name = str(device).lower() if not isinstance(device, Device) else device.value
|
|
481
|
+
use_gpu = device_name in (Device.CUDA.value, Device.TORCH.value)
|
|
482
|
+
# Optional CV bridge for CUDA: many medium-size CV workloads are faster with
|
|
483
|
+
# torch backend while preserving the same CoxPHCV public API.
|
|
484
|
+
cv_cuda_torch_bridge = os.environ.get(
|
|
485
|
+
"STATGPU_COXPHCV_CUDA_TORCH_BRIDGE", "0"
|
|
486
|
+
).strip().lower() in ("1", "true", "yes", "on")
|
|
487
|
+
|
|
488
|
+
# Convert to numpy arrays
|
|
489
|
+
X_np = np.asarray(X, dtype=np.float64)
|
|
490
|
+
time_np = np.asarray(time, dtype=np.float64)
|
|
491
|
+
event_np = np.asarray(event, dtype=np.int32)
|
|
492
|
+
entry_np = None if entry is None else np.asarray(entry, dtype=np.float64)
|
|
493
|
+
cluster_np = None if cluster is None else np.asarray(cluster)
|
|
494
|
+
|
|
495
|
+
n_samples = X_np.shape[0]
|
|
496
|
+
n_features = X_np.shape[1]
|
|
497
|
+
fit_device = device_name
|
|
498
|
+
if (
|
|
499
|
+
cv_cuda_torch_bridge
|
|
500
|
+
and device_name == Device.CUDA.value
|
|
501
|
+
and n_samples >= 1500
|
|
502
|
+
and n_features >= 40
|
|
503
|
+
):
|
|
504
|
+
fit_device = Device.TORCH.value
|
|
505
|
+
|
|
506
|
+
# Generate penalty grid
|
|
507
|
+
if penalties is None:
|
|
508
|
+
penalties = _default_coxph_penalty_grid(X_np, time_np, event_np, n_penalties, penalty_min_ratio)
|
|
509
|
+
else:
|
|
510
|
+
penalties = np.asarray(penalties, dtype=np.float64)
|
|
511
|
+
penalties = penalties[np.isfinite(penalties)]
|
|
512
|
+
penalties = penalties[penalties >= 0]
|
|
513
|
+
if penalties.size == 0:
|
|
514
|
+
penalties = _default_coxph_penalty_grid(X_np, time_np, event_np, n_penalties, penalty_min_ratio)
|
|
515
|
+
|
|
516
|
+
n_penalties_actual = len(penalties)
|
|
517
|
+
|
|
518
|
+
# Handle degenerate cases
|
|
519
|
+
if n_samples < 4 or cv_folds < 2:
|
|
520
|
+
if not return_details:
|
|
521
|
+
return float(penalties[0])
|
|
522
|
+
return {
|
|
523
|
+
"penalty": float(penalties[0]),
|
|
524
|
+
"penalties": penalties.astype(np.float64),
|
|
525
|
+
"pl_path": np.full((n_penalties_actual, 1), np.nan, dtype=np.float64),
|
|
526
|
+
"mean_pl": np.full(n_penalties_actual, np.nan, dtype=np.float64),
|
|
527
|
+
"best_pl": np.nan,
|
|
528
|
+
}
|
|
529
|
+
|
|
530
|
+
# Generate CV folds
|
|
531
|
+
if cv_splits is not None:
|
|
532
|
+
folds = cv_splits
|
|
533
|
+
else:
|
|
534
|
+
folds = _kfold_indices(n_samples, cv_folds, random_state)
|
|
535
|
+
|
|
536
|
+
folds_are_complements_flag = _folds_are_complements(folds, n_samples)
|
|
537
|
+
n_folds = len(folds)
|
|
538
|
+
|
|
539
|
+
# Keep exhaustive full-grid CV as the default behavior. Two-stage is opt-in.
|
|
540
|
+
two_stage_enabled = (
|
|
541
|
+
_env_flag("STATGPU_COXPHCV_TWO_STAGE", False) # default=False: opt-in
|
|
542
|
+
and device_name == Device.CUDA.value
|
|
543
|
+
and n_penalties_actual >= 8
|
|
544
|
+
)
|
|
545
|
+
halving_enabled = (
|
|
546
|
+
_env_flag("STATGPU_COXPHCV_SUCCESSIVE_HALVING", False)
|
|
547
|
+
and device_name == Device.CUDA.value
|
|
548
|
+
and n_penalties_actual >= 8
|
|
549
|
+
)
|
|
550
|
+
coarse_n = _env_int(
|
|
551
|
+
"STATGPU_COXPHCV_TWO_STAGE_COARSE",
|
|
552
|
+
6,
|
|
553
|
+
min_value=3,
|
|
554
|
+
max_value=n_penalties_actual,
|
|
555
|
+
)
|
|
556
|
+
window = _env_int("STATGPU_COXPHCV_TWO_STAGE_WINDOW", 2, min_value=1)
|
|
557
|
+
halving_topk = _env_int(
|
|
558
|
+
"STATGPU_COXPHCV_HALVING_TOPK",
|
|
559
|
+
3,
|
|
560
|
+
min_value=1,
|
|
561
|
+
max_value=n_penalties_actual,
|
|
562
|
+
)
|
|
563
|
+
fast_iter = _env_int(
|
|
564
|
+
"STATGPU_COXPHCV_HALVING_FAST_ITER",
|
|
565
|
+
30,
|
|
566
|
+
min_value=5,
|
|
567
|
+
max_value=max_iter,
|
|
568
|
+
)
|
|
569
|
+
fast_tol = _env_float("STATGPU_COXPHCV_HALVING_FAST_TOL", 1e-6, min_value=tol)
|
|
570
|
+
|
|
571
|
+
# Cache handling
|
|
572
|
+
cache_key_eff = cache_key
|
|
573
|
+
if cache_key_eff is None and _COXPH_CV_CACHE_MAXSIZE > 0:
|
|
574
|
+
cache_key_eff = _make_coxph_cv_auto_cache_key(
|
|
575
|
+
X_shape=X_np.shape,
|
|
576
|
+
time_shape=time_np.shape,
|
|
577
|
+
event_shape=event_np.shape,
|
|
578
|
+
X_data=X_np,
|
|
579
|
+
time_data=time_np,
|
|
580
|
+
event_data=event_np,
|
|
581
|
+
penalties=penalties,
|
|
582
|
+
n_penalties=n_penalties,
|
|
583
|
+
penalty_min_ratio=penalty_min_ratio,
|
|
584
|
+
folds=folds,
|
|
585
|
+
ties=ties,
|
|
586
|
+
use_gpu=use_gpu,
|
|
587
|
+
fit_device=fit_device,
|
|
588
|
+
cv_cuda_torch_bridge=cv_cuda_torch_bridge,
|
|
589
|
+
entry=entry_np,
|
|
590
|
+
cluster=cluster_np,
|
|
591
|
+
two_stage_enabled=two_stage_enabled,
|
|
592
|
+
halving_enabled=halving_enabled,
|
|
593
|
+
coarse_n=coarse_n,
|
|
594
|
+
window=window,
|
|
595
|
+
halving_topk=halving_topk,
|
|
596
|
+
fast_iter=fast_iter,
|
|
597
|
+
fast_tol=fast_tol,
|
|
598
|
+
max_iter=max_iter,
|
|
599
|
+
tol=tol,
|
|
600
|
+
)
|
|
601
|
+
|
|
602
|
+
cached_result = _coxcv_cache_get(cache_key_eff)
|
|
603
|
+
if cached_result is not None:
|
|
604
|
+
if return_details:
|
|
605
|
+
return cached_result["penalty"], cached_result
|
|
606
|
+
return cached_result["penalty"]
|
|
607
|
+
|
|
608
|
+
# Storage for partial likelihoods: (n_penalties, n_folds)
|
|
609
|
+
pl_path = np.full((n_penalties_actual, n_folds), np.nan, dtype=np.float64)
|
|
610
|
+
|
|
611
|
+
def _evaluate_penalty_indices(
|
|
612
|
+
penalty_indices: np.ndarray,
|
|
613
|
+
*,
|
|
614
|
+
fit_max_iter: int,
|
|
615
|
+
fit_tol: float,
|
|
616
|
+
) -> None:
|
|
617
|
+
if penalty_indices.size == 0:
|
|
618
|
+
return
|
|
619
|
+
penalty_indices = np.unique(np.asarray(penalty_indices, dtype=np.int64))
|
|
620
|
+
for fold_idx, (train_idx, test_idx) in enumerate(folds):
|
|
621
|
+
X_train, X_test = X_np[train_idx], X_np[test_idx]
|
|
622
|
+
time_train, time_test = time_np[train_idx], time_np[test_idx]
|
|
623
|
+
event_train, event_test = event_np[train_idx], event_np[test_idx]
|
|
624
|
+
entry_train = None if entry_np is None else entry_np[train_idx]
|
|
625
|
+
entry_test = None if entry_np is None else entry_np[test_idx]
|
|
626
|
+
cluster_train = None if cluster_np is None else cluster_np[train_idx]
|
|
627
|
+
X_fit = X_train
|
|
628
|
+
time_fit = time_train
|
|
629
|
+
event_fit = event_train
|
|
630
|
+
entry_fit = entry_train
|
|
631
|
+
cluster_fit = cluster_train
|
|
632
|
+
|
|
633
|
+
# Reduce repeated host->device conversions by preparing one fold
|
|
634
|
+
# tensor/array per backend and reusing it across penalties.
|
|
635
|
+
if fit_device == Device.CUDA.value:
|
|
636
|
+
try:
|
|
637
|
+
import cupy as cp
|
|
638
|
+
X_fit = cp.asarray(X_train, dtype=cp.float64)
|
|
639
|
+
time_fit = cp.asarray(time_train, dtype=cp.float64)
|
|
640
|
+
event_fit = cp.asarray(event_train, dtype=cp.int32)
|
|
641
|
+
entry_fit = None if entry_train is None else cp.asarray(entry_train, dtype=cp.float64)
|
|
642
|
+
cluster_fit = None if cluster_train is None else cp.asarray(cluster_train, dtype=cp.int64)
|
|
643
|
+
except Exception:
|
|
644
|
+
X_fit = X_train
|
|
645
|
+
time_fit = time_train
|
|
646
|
+
event_fit = event_train
|
|
647
|
+
entry_fit = entry_train
|
|
648
|
+
cluster_fit = cluster_train
|
|
649
|
+
elif fit_device == Device.TORCH.value:
|
|
650
|
+
try:
|
|
651
|
+
import torch
|
|
652
|
+
torch_device = _get_torch_device_str()
|
|
653
|
+
X_fit = torch.as_tensor(X_train, dtype=torch.float64, device=torch_device)
|
|
654
|
+
time_fit = torch.as_tensor(time_train, dtype=torch.float64, device=torch_device)
|
|
655
|
+
event_fit = torch.as_tensor(event_train, dtype=torch.int32, device=torch_device)
|
|
656
|
+
entry_fit = None if entry_train is None else torch.as_tensor(
|
|
657
|
+
entry_train, dtype=torch.float64, device=torch_device
|
|
658
|
+
)
|
|
659
|
+
cluster_fit = None if cluster_train is None else torch.as_tensor(
|
|
660
|
+
cluster_train, dtype=torch.int64, device=torch_device
|
|
661
|
+
)
|
|
662
|
+
except Exception:
|
|
663
|
+
X_fit = X_train
|
|
664
|
+
time_fit = time_train
|
|
665
|
+
event_fit = event_train
|
|
666
|
+
entry_fit = entry_train
|
|
667
|
+
cluster_fit = cluster_train
|
|
668
|
+
|
|
669
|
+
n_events_train = int(np.sum(event_train))
|
|
670
|
+
n_events_test = int(np.sum(event_test))
|
|
671
|
+
if n_events_train == 0 or n_events_test == 0:
|
|
672
|
+
continue
|
|
673
|
+
|
|
674
|
+
prev_coef = None
|
|
675
|
+
for penalty_idx in penalty_indices:
|
|
676
|
+
if np.isfinite(pl_path[penalty_idx, fold_idx]):
|
|
677
|
+
continue
|
|
678
|
+
penalty = penalties[penalty_idx]
|
|
679
|
+
model = CoxPH(
|
|
680
|
+
ties=ties,
|
|
681
|
+
max_iter=fit_max_iter,
|
|
682
|
+
tol=fit_tol,
|
|
683
|
+
device=fit_device,
|
|
684
|
+
compute_inference=False,
|
|
685
|
+
penalty=penalty,
|
|
686
|
+
)
|
|
687
|
+
try:
|
|
688
|
+
model.fit(
|
|
689
|
+
X_fit,
|
|
690
|
+
time_fit,
|
|
691
|
+
event_fit,
|
|
692
|
+
entry=entry_fit,
|
|
693
|
+
cluster=cluster_fit,
|
|
694
|
+
init_coef=prev_coef,
|
|
695
|
+
)
|
|
696
|
+
if not model._converged:
|
|
697
|
+
continue
|
|
698
|
+
prev_coef = np.asarray(model.coef_, dtype=np.float64).copy()
|
|
699
|
+
pl_test = _compute_partial_likelihood(
|
|
700
|
+
X_test, time_test, event_test, model.coef_, entry=entry_test, ties=ties
|
|
701
|
+
)
|
|
702
|
+
pl_path[penalty_idx, fold_idx] = pl_test
|
|
703
|
+
except Exception:
|
|
704
|
+
continue
|
|
705
|
+
|
|
706
|
+
if two_stage_enabled:
|
|
707
|
+
stage1_idx = np.unique(
|
|
708
|
+
np.linspace(0, n_penalties_actual - 1, num=coarse_n, dtype=np.int64)
|
|
709
|
+
)
|
|
710
|
+
_evaluate_penalty_indices(
|
|
711
|
+
stage1_idx,
|
|
712
|
+
fit_max_iter=(fast_iter if halving_enabled else max_iter),
|
|
713
|
+
fit_tol=(fast_tol if halving_enabled else tol),
|
|
714
|
+
)
|
|
715
|
+
stage1_mean = np.nanmean(pl_path[stage1_idx, :], axis=1)
|
|
716
|
+
if np.any(np.isfinite(stage1_mean)):
|
|
717
|
+
stage1_best = int(stage1_idx[int(np.nanargmax(stage1_mean))])
|
|
718
|
+
else:
|
|
719
|
+
stage1_best = int(stage1_idx[len(stage1_idx) // 2])
|
|
720
|
+
lo = max(0, stage1_best - window)
|
|
721
|
+
hi = min(n_penalties_actual - 1, stage1_best + window)
|
|
722
|
+
stage2_idx = np.arange(lo, hi + 1, dtype=np.int64)
|
|
723
|
+
_evaluate_penalty_indices(
|
|
724
|
+
stage2_idx,
|
|
725
|
+
fit_max_iter=(fast_iter if halving_enabled else max_iter),
|
|
726
|
+
fit_tol=(fast_tol if halving_enabled else tol),
|
|
727
|
+
)
|
|
728
|
+
if halving_enabled:
|
|
729
|
+
stage2_mean = np.full(stage2_idx.shape[0], np.nan, dtype=np.float64)
|
|
730
|
+
stage2_valid = np.any(np.isfinite(pl_path[stage2_idx, :]), axis=1)
|
|
731
|
+
if np.any(stage2_valid):
|
|
732
|
+
stage2_mean[stage2_valid] = np.nanmean(pl_path[stage2_idx[stage2_valid], :], axis=1)
|
|
733
|
+
order = np.argsort(np.nan_to_num(stage2_mean, nan=-np.inf))[::-1]
|
|
734
|
+
top_idx = stage2_idx[order[: min(halving_topk, len(stage2_idx))]]
|
|
735
|
+
# Re-evaluate top candidates with full precision and overwrite.
|
|
736
|
+
pl_path[top_idx, :] = np.nan
|
|
737
|
+
_evaluate_penalty_indices(top_idx, fit_max_iter=max_iter, fit_tol=tol)
|
|
738
|
+
else:
|
|
739
|
+
full_idx = np.arange(n_penalties_actual, dtype=np.int64)
|
|
740
|
+
if halving_enabled:
|
|
741
|
+
_evaluate_penalty_indices(full_idx, fit_max_iter=fast_iter, fit_tol=fast_tol)
|
|
742
|
+
full_mean = np.full(full_idx.shape[0], np.nan, dtype=np.float64)
|
|
743
|
+
full_valid = np.any(np.isfinite(pl_path[full_idx, :]), axis=1)
|
|
744
|
+
if np.any(full_valid):
|
|
745
|
+
full_mean[full_valid] = np.nanmean(pl_path[full_idx[full_valid], :], axis=1)
|
|
746
|
+
order = np.argsort(np.nan_to_num(full_mean, nan=-np.inf))[::-1]
|
|
747
|
+
top_idx = full_idx[order[:halving_topk]]
|
|
748
|
+
pl_path[top_idx, :] = np.nan
|
|
749
|
+
_evaluate_penalty_indices(top_idx, fit_max_iter=max_iter, fit_tol=tol)
|
|
750
|
+
else:
|
|
751
|
+
_evaluate_penalty_indices(full_idx, fit_max_iter=max_iter, fit_tol=tol)
|
|
752
|
+
|
|
753
|
+
# Safety fallback: if no penalty has any finite fold score, evaluate full grid once.
|
|
754
|
+
has_any_valid = np.any(np.isfinite(pl_path), axis=1)
|
|
755
|
+
if not np.any(has_any_valid):
|
|
756
|
+
_evaluate_penalty_indices(
|
|
757
|
+
np.arange(n_penalties_actual, dtype=np.int64),
|
|
758
|
+
fit_max_iter=max_iter,
|
|
759
|
+
fit_tol=tol,
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
# Compute mean partial likelihood across folds
|
|
763
|
+
mean_pl = np.full(n_penalties_actual, np.nan, dtype=np.float64)
|
|
764
|
+
valid_rows = np.any(np.isfinite(pl_path), axis=1)
|
|
765
|
+
if np.any(valid_rows):
|
|
766
|
+
mean_pl[valid_rows] = np.nanmean(pl_path[valid_rows], axis=1)
|
|
767
|
+
|
|
768
|
+
# Find best penalty (maximum partial likelihood)
|
|
769
|
+
if np.any(np.isfinite(mean_pl)):
|
|
770
|
+
best_idx = np.nanargmax(mean_pl)
|
|
771
|
+
best_penalty = float(penalties[best_idx])
|
|
772
|
+
best_pl = float(mean_pl[best_idx])
|
|
773
|
+
else:
|
|
774
|
+
# No valid CV results - use first penalty
|
|
775
|
+
best_penalty = float(penalties[0])
|
|
776
|
+
best_pl = np.nan
|
|
777
|
+
|
|
778
|
+
# Prepare details
|
|
779
|
+
details = {
|
|
780
|
+
"penalty": best_penalty,
|
|
781
|
+
"penalties": penalties.astype(np.float64),
|
|
782
|
+
"pl_path": pl_path.astype(np.float64),
|
|
783
|
+
"mean_pl": mean_pl.astype(np.float64),
|
|
784
|
+
"best_pl": best_pl,
|
|
785
|
+
"n_folds": n_folds,
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
# Cache result
|
|
789
|
+
if _COXPH_CV_CACHE_MAXSIZE > 0:
|
|
790
|
+
_coxcv_cache_put(cache_key_eff, details)
|
|
791
|
+
|
|
792
|
+
if return_details:
|
|
793
|
+
return best_penalty, details
|
|
794
|
+
|
|
795
|
+
return best_penalty
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
# =============================================================================
|
|
799
|
+
# CoxPHCV Class
|
|
800
|
+
# =============================================================================
|
|
801
|
+
|
|
802
|
+
class CoxPHCV(CVEstimatorBase):
    """
    Cross-validated Cox Proportional Hazards regression.

    This class implements K-fold cross-validation to select the optimal
    penalty (L2 regularization) parameter for Cox PH models, then refits a
    ``CoxPH`` on the full data with the selected penalty.

    Parameters
    ----------
    penalties : array-like, scalar or None
        Penalty values to try. A scalar is treated as a one-element grid.
        If None, generates n_penalties values.
    n_penalties : int, default=100
        Number of penalty values (if penalties is None).
    penalty_min_ratio : float, default=1e-3
        Minimum penalty as ratio of max penalty.
    cv : int, default=5
        Number of CV folds.
    cv_splits : iterable or None, default=None
        Pre-computed CV splits forwarded to the penalty search
        (presumably (train_idx, test_idx) pairs -- confirm against
        ``_select_coxph_penalty_cv``).
    ties : str, default='breslow'
        Method for handling ties: 'breslow' or 'efron'.
    tol : float, default=1e-9
        Convergence tolerance.
    max_iter : int, default=100
        Maximum iterations.
    device : str or Device, default='auto'
        Computation device: 'cpu', 'cuda', or 'auto'.
    n_jobs : int or None, default=None
        Forwarded to the base class and to the final ``CoxPH`` fit.
    compute_inference : bool, default=True
        Whether to compute standard errors for the final model.
    cov_type : str, default='nonrobust'
        Covariance estimator for the final model.
    gpu_memory_cleanup : bool, default=False
        Whether to free GPU memory after fitting.
    random_state : int or None
        Random seed for CV splits.

    Attributes
    ----------
    penalty_ : float
        Selected penalty value.
    penalties_ : ndarray
        All penalty values tested.
    cv_results_ : dict
        CV results: ``'pl_path'`` (held-out partial likelihood per penalty
        and fold; also available as ``'partial_likelihood_path'``) and
        ``'mean_pl'`` (mean across folds, per penalty).
    best_score_ : float
        Best (maximum) partial likelihood across CV folds.
    coef_ : ndarray
        Coefficients of the final model.
    hazard_ratios_ : ndarray
        exp(coef) = hazard ratios.
    estimator_ : CoxPH
        The fitted CoxPH with selected penalty.

    Examples
    --------
    >>> import numpy as np
    >>> from statgpu.survival import CoxPHCV
    >>> X = np.random.randn(1000, 20)
    >>> time = np.random.exponential(scale=100, size=1000)
    >>> event = np.random.binomial(1, 0.7, size=1000)
    >>> model = CoxPHCV(cv=5, device='cuda')
    >>> model.fit(X, time, event)
    >>> print(f"Selected penalty: {model.penalty_:.4f}")
    >>> print(f"Best CV score: {model.best_score_:.4f}")
    """

    def __init__(
        self,
        penalties=None,
        n_penalties: int = 100,
        penalty_min_ratio: float = 1e-3,
        cv: int = 5,
        cv_splits=None,
        ties: str = "breslow",
        tol: float = 1e-9,
        max_iter: int = 100,
        device: Union[str, Device] = Device.AUTO,
        n_jobs: Optional[int] = None,
        compute_inference: bool = True,
        cov_type: str = "nonrobust",
        gpu_memory_cleanup: bool = False,
        random_state: Optional[int] = None,
    ):
        super().__init__(
            cv=cv,
            random_state=random_state,
            device=device,
            n_jobs=n_jobs,
        )
        self.penalties = penalties
        self.n_penalties = int(n_penalties)
        self.penalty_min_ratio = float(penalty_min_ratio)
        self.cv = int(cv)
        self.cv_splits = cv_splits
        self.ties = str(ties)
        self.tol = float(tol)
        self.max_iter = int(max_iter)
        self.compute_inference = bool(compute_inference)
        self.cov_type = str(cov_type)
        self.gpu_memory_cleanup = bool(gpu_memory_cleanup)

        # Output attributes (initialized to None)
        self.penalty_ = None
        self.penalties_ = None
        self.cv_results_ = None
        self.best_score_ = None
        self.coef_ = None
        self.hazard_ratios_ = None
        self.estimator_ = None

    def _cleanup_cuda_memory(self):
        """Best-effort CuPy memory pool cleanup.

        Only acts when ``gpu_memory_cleanup`` is enabled and CuPy is already
        imported in this process. If CuPy was never imported, its pools
        cannot hold any blocks. Importing it here (possibly from ``__del__``
        during garbage collection or interpreter shutdown) would only add cost
        and failure modes.
        """
        if not self.gpu_memory_cleanup:
            return
        try:
            import sys

            cp = sys.modules.get("cupy")
            if cp is not None:
                cp.get_default_memory_pool().free_all_blocks()
                cp.get_default_pinned_memory_pool().free_all_blocks()
        except Exception:
            # Best effort: a cleanup failure must never break fit() or __del__.
            pass

    def _cleanup_torch_memory(self):
        """Best-effort Torch CUDA cache cleanup.

        Only acts when ``gpu_memory_cleanup`` is enabled, torch is already
        imported, and its CUDA state is already initialized. Calling
        ``torch.cuda.synchronize()`` on an uninitialized runtime would create
        a CUDA context just to free an empty cache.
        """
        if not self.gpu_memory_cleanup:
            return
        try:
            import sys

            torch = sys.modules.get("torch")
            if torch is not None and torch.cuda.is_initialized():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        except Exception:
            # Best effort: a cleanup failure must never break fit() or __del__.
            pass

    def __del__(self):
        try:
            self._cleanup_cuda_memory()
            self._cleanup_torch_memory()
        except Exception:
            pass

    def _fit_cv(self, X, time, event, entry=None, cluster=None):
        """
        Fit CoxPH with K-fold cross-validation.

        When the environment variable ``STATGPU_COXPHCV_CUDA_TORCH_BRIDGE`` is
        truthy ("1", "true", "yes", "on"), the resolved device is CUDA, and
        ``X`` has at least 1500 rows and 40 columns, fitting is routed to the
        torch backend instead.

        Parameters
        ----------
        X : array-like
            Design matrix (2-D).
        time : array-like
            Survival times.
        event : array-like
            Event indicators.
        entry : array-like, optional
            Entry times (delayed entry).
        cluster : array-like, optional
            Cluster ids.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not 2-dimensional.
        """
        device_name = self._get_compute_device().value

        # np.shape reads ``.shape`` directly when the object has one, so
        # CuPy arrays / torch tensors are not copied to host just to inspect
        # their dimensions.
        X_shape = tuple(np.shape(X))
        if len(X_shape) != 2:
            raise ValueError(
                f"X must be 2-dimensional (n_samples, n_features); got shape {X_shape}"
            )
        n_samples, n_features = X_shape

        cv_cuda_torch_bridge = os.environ.get(
            "STATGPU_COXPHCV_CUDA_TORCH_BRIDGE", "0"
        ).strip().lower() in ("1", "true", "yes", "on")
        fit_device_name = device_name
        if (
            cv_cuda_torch_bridge
            and device_name == Device.CUDA.value
            and n_samples >= 1500
            and n_features >= 40
        ):
            fit_device_name = Device.TORCH.value

        # Normalize user-supplied penalties to a 1-D float array. Any non-None
        # value is honored (lists, tuples, ndarrays, pandas Series, scalars);
        # previously anything other than list/tuple/ndarray was silently
        # replaced by the default grid.
        if self.penalties is None:
            penalties = None
        else:
            penalties = np.atleast_1d(np.asarray(self.penalties, dtype=np.float64))

        # Perform CV to find best penalty
        best_penalty, details = _select_coxph_penalty_cv(
            X, time, event,
            entry=entry,
            cluster=cluster,
            penalties=penalties,
            n_penalties=self.n_penalties,
            penalty_min_ratio=self.penalty_min_ratio,
            cv_folds=self.cv,
            cv_splits=self.cv_splits,
            random_state=self.random_state,
            ties=self.ties,
            device=fit_device_name,
            max_iter=self.max_iter,
            tol=self.tol,
            return_details=True,
        )

        # Store CV results. ``details`` may be the same dict object held by
        # the module-level CV cache, so take copies: in-place edits of the
        # public attributes must not corrupt later cache hits.
        self.penalty_ = float(best_penalty)
        self.penalties_ = np.array(details["penalties"], dtype=np.float64, copy=True)

        pl_path = np.array(details["pl_path"], dtype=np.float64, copy=True)
        mean_pl = np.array(details["mean_pl"], dtype=np.float64, copy=True)

        self.cv_results_ = {
            "pl_path": pl_path,
            "mean_pl": mean_pl,
            # Documented name for the same per-penalty, per-fold path.
            "partial_likelihood_path": pl_path,
        }
        self.best_score_ = float(details["best_pl"])

        # Fit final model on full data with best penalty
        final_model = CoxPH(
            ties=self.ties,
            tol=self.tol,
            max_iter=self.max_iter,
            device=fit_device_name,
            n_jobs=self.n_jobs,
            compute_inference=self.compute_inference,
            cov_type=self.cov_type,
            gpu_memory_cleanup=self.gpu_memory_cleanup,
            penalty=self.penalty_,
        )
        final_model.fit(X, time, event, entry=entry, cluster=cluster)

        self.estimator_ = final_model
        self.coef_ = final_model.coef_.copy()
        self.hazard_ratios_ = final_model.hazard_ratios_.copy()
        self._cleanup_cuda_memory()
        self._cleanup_torch_memory()

        return self

    def fit(self, X, time, event, entry=None, cluster=None):
        """
        Fit CoxPH model with cross-validation.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Covariate matrix.
        time : array-like of shape (n_samples,)
            Time to event or censoring.
        event : array-like of shape (n_samples,)
            Event indicator (1 = event, 0 = censored).
        entry : array-like, optional
            Entry time for delayed entry.
        cluster : array-like, optional
            Cluster ids.

        Returns
        -------
        self : CoxPHCV
        """
        return self._fit_cv(X, time, event, entry=entry, cluster=cluster)

    def predict(self, X):
        """
        Predict risk scores.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Covariate matrix.

        Returns
        -------
        risk_scores : ndarray
            Risk scores (linear predictor).

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if self.coef_ is None:
            raise ValueError("Model not fitted. Call fit() first.")

        X_arr = np.asarray(X, dtype=np.float64)
        return X_arr @ self.coef_

    def score(self, X, time, event):
        """
        Return C-index (concordance index).

        A pair (i, j) is permissible when subject i had an event and either
        ``time_i < time_j``, or the times are equal and j is censored.
        Concordant pairs have ``risk_i > risk_j``; tied risks count 0.5.

        Parameters
        ----------
        X : array-like
            Covariate matrix.
        time : array-like
            Survival times.
        event : array-like
            Event indicators.

        Returns
        -------
        c_index : float
            C-index (0.5 = random, 1.0 = perfect). Returns 0.5 when there are
            no events or no permissible pairs.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if self.coef_ is None:
            raise ValueError("Model not fitted. Call fit() first.")

        X_arr = np.asarray(X, dtype=np.float64)
        time_arr = np.asarray(time, dtype=np.float64)
        event_arr = np.asarray(event, dtype=np.int32)

        # Compute risk scores
        risk_scores = X_arr @ self.coef_

        n = len(time_arr)
        event_idx = np.flatnonzero(event_arr == 1)
        n_events = len(event_idx)

        if n_events == 0:
            return 0.5

        concordant = np.int64(0)
        permissible = np.int64(0)
        tied_risk = np.int64(0)

        # Chunk size: keep each (chunk x n) bool matrix <= 128 MB
        chunk_size = max(1, min(n_events, int(128e6 / max(n, 1))))

        # Row-invariant broadcast views, hoisted out of the chunk loop.
        time_j = time_arr[np.newaxis, :]
        risk_j = risk_scores[np.newaxis, :]
        event_j = event_arr[np.newaxis, :]

        for start in range(0, n_events, chunk_size):
            end = min(start + chunk_size, n_events)
            idx_chunk = event_idx[start:end]

            time_i = time_arr[idx_chunk, np.newaxis]
            risk_i = risk_scores[idx_chunk, np.newaxis]

            # Permissible pairs: earlier time OR same time with j censored
            perm = (time_i < time_j) | ((time_i == time_j) & (event_j == 0))

            # Exclude self-comparisons (defensive: an event row paired with
            # itself is already non-permissible under the rule above).
            chunk_indices = np.arange(end - start, dtype=np.int64)
            perm[chunk_indices, idx_chunk] = False

            concordant += int(np.sum(perm & (risk_i > risk_j)))
            tied_risk += int(np.sum(perm & (risk_i == risk_j)))
            permissible += int(np.sum(perm))

        if permissible == 0:
            return 0.5

        return (concordant + 0.5 * tied_risk) / permissible

    def summary(self):
        """Return summary of the fitted model.

        Raises
        ------
        RuntimeError
            If no estimator has been fitted, or it has no ``summary()``.
        """
        if self.estimator_ is None:
            raise RuntimeError("No fitted estimator available.")
        if not hasattr(self.estimator_, "summary"):
            raise RuntimeError(f"{self.estimator_.__class__.__name__} does not implement summary().")
        return self.estimator_.summary()
|