survival-1.1.36-cp314-cp314-macosx_10_12_x86_64.whl
- survival/__init__.py +15 -0
- survival/_survival.cpython-314-darwin.so +0 -0
- survival/_survival.pyi +732 -0
- survival/py.typed +0 -0
- survival/sklearn_compat.py +1308 -0
- survival-1.1.36.dist-info/METADATA +678 -0
- survival-1.1.36.dist-info/RECORD +9 -0
- survival-1.1.36.dist-info/WHEEL +4 -0
- survival-1.1.36.dist-info/licenses/LICENSE +21 -0
survival/sklearn_compat.py
@@ -0,0 +1,1308 @@
# ruff: noqa: N803, N806, UP037
from __future__ import annotations

from collections.abc import Iterator
from typing import TYPE_CHECKING

import numpy as np

from survival import _survival as _surv

if TYPE_CHECKING:
    from numpy.typing import ArrayLike, NDArray

try:
    from sklearn.base import BaseEstimator, RegressorMixin
    from sklearn.utils.validation import check_array, check_is_fitted

    _HAS_SKLEARN = True
except ImportError:
    _HAS_SKLEARN = False

    # Minimal stand-ins so the module stays importable without scikit-learn.
    class BaseEstimator:
        def get_params(self, deep: bool = True) -> dict:
            return {
                k: getattr(self, k)
                for k in self.__init__.__code__.co_varnames[1 : self.__init__.__code__.co_argcount]
            }

        def set_params(self, **params) -> "BaseEstimator":
            for key, value in params.items():
                setattr(self, key, value)
            return self

    class RegressorMixin:
        pass

    def check_array(X, **kwargs):
        # Honor a requested dtype so the shim matches sklearn's check_array
        # closely enough for the float64 conversions used below.
        return np.asarray(X, dtype=kwargs.get("dtype"))

    def check_is_fitted(estimator, attributes=None):
        if not hasattr(estimator, "is_fitted_") or not estimator.is_fitted_:
            raise ValueError(f"{type(estimator).__name__} is not fitted yet.")

def _validate_survival_data(
    X: ArrayLike, y: ArrayLike
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.int32]]:
    X = check_array(X, dtype=np.float64, ensure_2d=True)
    y = np.asarray(y)

    if y.ndim != 2:
        raise ValueError("y must be a 2D array with columns [time, status]")
    if y.shape[1] != 2:
        raise ValueError("y must have exactly 2 columns: [time, status]")

    time = y[:, 0].astype(np.float64)
    status = y[:, 1].astype(np.int32)

    if X.shape[0] != len(time):
        raise ValueError(f"X has {X.shape[0]} samples, but y has {len(time)} samples")

    return X, time, status

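# --- Editor's sketch (not part of the released file) -------------------------
# The y layout every estimator in this module expects: column 0 is follow-up
# time, column 1 is the event indicator (1 = event observed, 0 = censored).
_demo_t = np.array([5.0, 8.0, 3.0])
_demo_e = np.array([1, 0, 1])
_X3, _time3, _status3 = _validate_survival_data(np.zeros((3, 2)), np.column_stack([_demo_t, _demo_e]))
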
def _compute_concordance_index(
    time: NDArray[np.float64],
    status: NDArray[np.int32],
    risk_scores: NDArray[np.float64],
) -> float:
    """Compute concordance index (C-index) for survival predictions."""
    n = len(time)
    concordant = 0.0
    comparable = 0.0

    for i in range(n):
        if status[i] == 0:
            continue
        for j in range(n):
            if i == j:
                continue
            if time[i] < time[j]:
                comparable += 1.0
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5

    return concordant / comparable if comparable > 0 else 0.5

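# --- Editor's sketch (not part of the released file) -------------------------
# Hand-checkable example: subject 0 fails first and has the highest risk
# score, so the comparable pairs (0,1), (0,2), and (1,2) are all concordant
# and the C-index is exactly 1.0.
assert _compute_concordance_index(
    np.array([2.0, 4.0, 6.0]), np.array([1, 1, 0], dtype=np.int32), np.array([3.0, 2.0, 1.0])
) == 1.0
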
class SurvivalScoreMixin:
    """Mixin providing concordance index scoring for survival models."""

    def score(self, X: ArrayLike, y: ArrayLike) -> float:
        """Return the concordance index on the given test data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples.
        y : array-like of shape (n_samples, 2)
            True target values.

        Returns
        -------
        score : float
            Concordance index (C-index), between 0 and 1.
        """
        check_is_fitted(self)
        X, time, status = _validate_survival_data(X, y)
        risk_scores = self.predict(X)
        return _compute_concordance_index(time, status, risk_scores)

class CoxPHEstimator(SurvivalScoreMixin, BaseEstimator, RegressorMixin):
    """Scikit-learn compatible Cox Proportional Hazards model.

    Parameters
    ----------
    n_iters : int, default=20
        Maximum number of iterations for the Newton-Raphson optimization.

    Attributes
    ----------
    model_ : CoxPHModel
        The underlying fitted Cox model.
    coef_ : ndarray of shape (n_features,)
        Estimated coefficients.
    n_features_in_ : int
        Number of features seen during fit.

    Examples
    --------
    >>> from survival.sklearn_compat import CoxPHEstimator
    >>> import numpy as np
    >>> X = np.random.randn(100, 3)
    >>> y = np.column_stack([np.random.exponential(10, 100), np.random.binomial(1, 0.7, 100)])
    >>> model = CoxPHEstimator()
    >>> model.fit(X, y)
    >>> risk_scores = model.predict(X)
    """

    def __init__(self, n_iters: int = 20):
        self.n_iters = n_iters

    def fit(self, X: ArrayLike, y: ArrayLike) -> "CoxPHEstimator":
        """Fit the Cox PH model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : CoxPHEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]

        covariates = X.tolist()
        self.model_ = _surv.CoxPHModel.new_with_data(covariates, time.tolist(), status.tolist())
        self.model_.fit(self.n_iters)

        self.coef_ = np.array(self.model_.coefficients)
        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict risk scores for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores (higher = higher risk).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        return np.array(self.model_.predict(X.tolist()))

    def predict_survival_function(
        self, X: ArrayLike, times: ArrayLike | None = None
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Predict survival function for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        times : array-like of shape (n_times,), optional
            Time points at which to evaluate the survival function.

        Returns
        -------
        times : ndarray of shape (n_times,)
            Time points.
        survival : ndarray of shape (n_samples, n_times)
            Survival probabilities.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        # Accept any array-like (not just ndarrays) for the evaluation grid.
        times_list = np.asarray(times, dtype=np.float64).tolist() if times is not None else None
        t, surv = self.model_.survival_curve(X.tolist(), times_list)
        return np.array(t), np.array(surv)

    def predict_median_survival_time(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times (NaN if survival never drops below 0.5).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        result = self.model_.predicted_survival_time(X.tolist(), 0.5)
        return np.array([t if t is not None else np.nan for t in result])

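# --- Editor's sketch (not part of the released file) -------------------------
# Because CoxPHEstimator follows the sklearn estimator contract (get_params /
# set_params / fit / predict / score), it can be dropped into standard sklearn
# tooling. A minimal cross-validation sketch, assuming scikit-learn is
# installed; the default score is the concordance index from SurvivalScoreMixin.
from sklearn.model_selection import KFold, cross_val_score

_rng = np.random.default_rng(0)
X_demo = _rng.normal(size=(200, 3))
y_demo = np.column_stack([_rng.exponential(10.0, 200), _rng.binomial(1, 0.7, 200)])
cv_scores = cross_val_score(CoxPHEstimator(n_iters=25), X_demo, y_demo, cv=KFold(n_splits=5))
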
class GradientBoostSurvivalEstimator(SurvivalScoreMixin, BaseEstimator, RegressorMixin):
    """Scikit-learn compatible Gradient Boosting Survival model.

    Parameters
    ----------
    n_estimators : int, default=100
        Number of boosting iterations.
    learning_rate : float, default=0.1
        Learning rate shrinks the contribution of each tree.
    max_depth : int, default=3
        Maximum depth of the individual regression trees.
    min_samples_split : int, default=10
        Minimum number of samples required to split an internal node.
    min_samples_leaf : int, default=5
        Minimum number of samples required at each leaf node.
    subsample : float, default=1.0
        Fraction of samples used for fitting individual trees.
    max_features : int or None, default=None
        Number of features to consider for splits.
    seed : int or None, default=None
        Random seed for reproducibility.

    Attributes
    ----------
    model_ : GradientBoostSurvival
        The underlying fitted model.
    feature_importances_ : ndarray of shape (n_features,)
        Feature importances.
    n_features_in_ : int
        Number of features seen during fit.
    """

    def __init__(
        self,
        n_estimators: int = 100,
        learning_rate: float = 0.1,
        max_depth: int = 3,
        min_samples_split: int = 10,
        min_samples_leaf: int = 5,
        subsample: float = 1.0,
        max_features: int | None = None,
        seed: int | None = None,
    ):
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.subsample = subsample
        self.max_features = max_features
        self.seed = seed

    def fit(self, X: ArrayLike, y: ArrayLike) -> "GradientBoostSurvivalEstimator":
        """Fit the gradient boosting survival model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : GradientBoostSurvivalEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]
        n_obs = X.shape[0]

        config = _surv.GradientBoostSurvivalConfig(
            n_estimators=self.n_estimators,
            learning_rate=self.learning_rate,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            min_samples_leaf=self.min_samples_leaf,
            subsample=self.subsample,
            max_features=self.max_features,
            seed=self.seed,
        )

        x_flat = X.flatten().tolist()
        self.model_ = _surv.GradientBoostSurvival.fit(
            x_flat, n_obs, self.n_features_in_, time.tolist(), status.tolist(), config
        )

        self.feature_importances_ = np.array(self.model_.feature_importance)
        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict risk scores for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores (higher = higher risk).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        x_flat = X.flatten().tolist()
        return np.array(self.model_.predict_risk(x_flat, X.shape[0]))

    def predict_survival_function(
        self, X: ArrayLike
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Predict survival function for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        times : ndarray of shape (n_times,)
            Time points.
        survival : ndarray of shape (n_samples, n_times)
            Survival probabilities.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        survival = self.model_.predict_survival(x_flat, X.shape[0])
        return np.array(self.model_.unique_times), np.array(survival)

    def predict_median_survival_time(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times (NaN if survival never drops below 0.5).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        result = self.model_.predict_median_survival_time(x_flat, X.shape[0])
        return np.array([t if t is not None else np.nan for t in result])

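# --- Editor's sketch (not part of the released file) -------------------------
# feature_importances_ mirrors the sklearn attribute of the same name, so the
# usual ranking idiom applies. X_demo / y_demo are the toy arrays from the
# CoxPHEstimator sketch above.
_gbm = GradientBoostSurvivalEstimator(n_estimators=50, max_depth=2, seed=42).fit(X_demo, y_demo)
_ranked = np.argsort(_gbm.feature_importances_)[::-1]  # most important feature first
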
class SurvivalForestEstimator(SurvivalScoreMixin, BaseEstimator, RegressorMixin):
    """Scikit-learn compatible Random Survival Forest model.

    Parameters
    ----------
    n_trees : int, default=500
        Number of trees in the forest.
    max_depth : int or None, default=None
        Maximum depth of trees.
    min_node_size : int, default=15
        Minimum number of samples at each leaf node.
    mtry : int or None, default=None
        Number of features to consider at each split (default: sqrt(n_features)).
    sample_fraction : float, default=0.632
        Fraction of samples used for each tree.
    seed : int or None, default=None
        Random seed for reproducibility.
    oob_error : bool, default=True
        Whether to compute out-of-bag error.

    Attributes
    ----------
    model_ : SurvivalForest
        The underlying fitted model.
    variable_importance_ : ndarray of shape (n_features,)
        Variable importances.
    oob_error_ : float or None
        Out-of-bag error (if computed).
    n_features_in_ : int
        Number of features seen during fit.
    """

    def __init__(
        self,
        n_trees: int = 500,
        max_depth: int | None = None,
        min_node_size: int = 15,
        mtry: int | None = None,
        sample_fraction: float = 0.632,
        seed: int | None = None,
        oob_error: bool = True,
    ):
        self.n_trees = n_trees
        self.max_depth = max_depth
        self.min_node_size = min_node_size
        self.mtry = mtry
        self.sample_fraction = sample_fraction
        self.seed = seed
        self.oob_error = oob_error

    def fit(self, X: ArrayLike, y: ArrayLike) -> "SurvivalForestEstimator":
        """Fit the random survival forest model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : SurvivalForestEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]
        n_obs = X.shape[0]

        config = _surv.SurvivalForestConfig(
            n_trees=self.n_trees,
            max_depth=self.max_depth,
            min_node_size=self.min_node_size,
            mtry=self.mtry,
            sample_fraction=self.sample_fraction,
            seed=self.seed,
            oob_error=self.oob_error,
        )

        x_flat = X.flatten().tolist()
        self.model_ = _surv.SurvivalForest.fit(
            x_flat, n_obs, self.n_features_in_, time.tolist(), status.tolist(), config
        )

        self.variable_importance_ = np.array(self.model_.variable_importance)
        self.oob_error_ = self.model_.oob_error
        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict risk scores for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores (cumulative hazard at last time point).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        x_flat = X.flatten().tolist()
        return np.array(self.model_.predict_risk(x_flat, X.shape[0]))

    def predict_survival_function(
        self, X: ArrayLike
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Predict survival function for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        times : ndarray of shape (n_times,)
            Time points.
        survival : ndarray of shape (n_samples, n_times)
            Survival probabilities.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        survival = self.model_.predict_survival(x_flat, X.shape[0])
        return np.array(self.model_.unique_times), np.array(survival)

    def predict_median_survival_time(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times (NaN if survival never drops below 0.5).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        result = self.model_.predict_median_survival_time(x_flat, X.shape[0])
        return np.array([t if t is not None else np.nan for t in result])

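# --- Editor's sketch (not part of the released file) -------------------------
# oob_error_ gives a fit-time generalization estimate without a held-out set,
# which is handy when the dataset is too small to spare a test split.
_rsf = SurvivalForestEstimator(n_trees=100, seed=42, oob_error=True).fit(X_demo, y_demo)
print("OOB error:", _rsf.oob_error_)
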
class AFTEstimator(BaseEstimator, RegressorMixin):
    """Scikit-learn compatible Accelerated Failure Time (AFT) model.

    AFT models assume that covariates act multiplicatively on the survival time,
    i.e., log(T) = X @ beta + sigma * epsilon, where epsilon follows a specified
    error distribution.

    Parameters
    ----------
    distribution : str, default="weibull"
        Error distribution. One of:
        - "weibull": Weibull distribution (extreme value errors)
        - "lognormal": Log-normal distribution (Gaussian errors)
        - "loglogistic": Log-logistic distribution (logistic errors)
        - "exponential": Exponential distribution (special case of Weibull)
        - "gaussian": Gaussian distribution (for linear models)
        - "logistic": Logistic distribution (for linear models)
    max_iter : int, default=200
        Maximum number of iterations for optimization.
    tol : float, default=1e-9
        Convergence tolerance.

    Attributes
    ----------
    model_ : SurvivalFit
        The underlying fitted AFT model.
    coef_ : ndarray of shape (n_features,)
        Estimated coefficients (acceleration factors in log scale).
    scale_ : float
        Estimated scale parameter (sigma).
    n_features_in_ : int
        Number of features seen during fit.

    Examples
    --------
    >>> from survival.sklearn_compat import AFTEstimator
    >>> import numpy as np
    >>> X = np.random.randn(100, 3)
    >>> y = np.column_stack([np.random.exponential(10, 100), np.random.binomial(1, 0.7, 100)])
    >>> model = AFTEstimator(distribution="weibull")
    >>> model.fit(X, y)
    >>> predicted_times = model.predict(X)

    Notes
    -----
    The AFT model interprets coefficients as acceleration factors:
    - Positive coefficients increase expected survival time
    - Negative coefficients decrease expected survival time
    - exp(coef) gives the multiplicative effect on survival time
    """

    def __init__(
        self,
        distribution: str = "weibull",
        max_iter: int = 200,
        tol: float = 1e-9,
    ):
        self.distribution = distribution
        self.max_iter = max_iter
        self.tol = tol

    def fit(self, X: ArrayLike, y: ArrayLike) -> "AFTEstimator":
        """Fit the AFT model using maximum likelihood estimation.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : AFTEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]
        n = len(time)

        events = status == 1
        n_events = events.sum()

        if n_events < X.shape[1] + 1:
            raise ValueError(
                f"Not enough events ({n_events}) to fit model with {X.shape[1]} features"
            )

        X_with_intercept = np.column_stack([np.ones(n), X])

        self.model_ = _surv.survreg(
            time=time.tolist(),
            status=status.tolist(),
            covariates=X_with_intercept.tolist(),
            distribution=self.distribution,
            max_iter=self.max_iter,
            eps=self.tol,
        )

        # The fitted parameter vector is [intercept, feature coefficients...,
        # log(sigma)]; the last entry is the log of the scale parameter.
        coefs = np.array(self.model_.coefficients)
        self.intercept_ = coefs[0]
        self.coef_ = coefs[1:-1]
        self.scale_ = np.exp(coefs[-1])
        self.converged_ = self.model_.convergence_flag == 0

        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict expected survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        survival_times : ndarray of shape (n_samples,)
            Predicted survival times, computed as exp(intercept + X @ coef)
            with the error term at zero; see predict_median for the
            distribution-specific median.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        linear_pred = self.intercept_ + X @ self.coef_
        return np.exp(linear_pred)

    def predict_median(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        linear_pred = self.intercept_ + X @ self.coef_

        if self.distribution in ("weibull", "exponential", "extreme_value"):
            median_z = np.log(np.log(2))
        elif self.distribution in ("lognormal", "gaussian", "loglogistic", "logistic"):
            median_z = 0.0
        else:
            median_z = 0.0

        return np.exp(linear_pred + self.scale_ * median_z)

    def predict_quantile(self, X: ArrayLike, q: float = 0.5) -> NDArray[np.float64]:
        """Predict survival time quantile for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        q : float, default=0.5
            Quantile to predict (0 < q < 1). Default is median (0.5).

        Returns
        -------
        quantile_times : ndarray of shape (n_samples,)
            Predicted survival times at the given quantile.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if not 0 < q < 1:
            raise ValueError("q must be between 0 and 1")

        linear_pred = self.intercept_ + X @ self.coef_

        if self.distribution in ("weibull", "exponential", "extreme_value"):
            z_q = np.log(-np.log(1 - q))
        elif self.distribution in ("lognormal", "gaussian"):
            from scipy.stats import norm

            z_q = norm.ppf(q)
        elif self.distribution in ("loglogistic", "logistic"):
            z_q = np.log(q / (1 - q))
        else:
            z_q = 0.0

        return np.exp(linear_pred + self.scale_ * z_q)

    def score(self, X: ArrayLike, y: ArrayLike) -> float:
        """Return the concordance index on the given test data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples.
        y : array-like of shape (n_samples, 2)
            True target values.

        Returns
        -------
        score : float
            Concordance index (C-index), between 0 and 1.
        """
        check_is_fitted(self)
        X, time, status = _validate_survival_data(X, y)
        predicted_times = self.predict(X)
        # Longer predicted survival means lower risk, hence the sign flip.
        return _compute_concordance_index(time, status, -predicted_times)

    @property
    def acceleration_factors(self) -> NDArray[np.float64]:
        """Return acceleration factors (exp of coefficients).

        Returns
        -------
        af : ndarray of shape (n_features,)
            Acceleration factors. Values > 1 increase survival time,
            values < 1 decrease survival time.
        """
        check_is_fitted(self)
        return np.exp(self.coef_)

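# --- Editor's sketch (not part of the released file) -------------------------
# Reading AFT output: exp(coef) is the multiplicative effect on survival time,
# so an acceleration factor of 1.25 means roughly 25% longer survival per unit
# increase in that feature. X_demo / y_demo come from the sketch above.
_aft = AFTEstimator(distribution="weibull").fit(X_demo, y_demo)
for _i, _af in enumerate(_aft.acceleration_factors):
    print(f"feature {_i}: x{_af:.2f} survival time per unit increase")
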
def iter_chunks(X: ArrayLike, batch_size: int = 1000) -> Iterator[tuple[int, NDArray[np.float64]]]:
    """Iterate over an array in chunks.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input array.
    batch_size : int, default=1000
        Number of samples per chunk.

    Yields
    ------
    start_idx : int
        Starting index of the chunk.
    chunk : ndarray
        Chunk of the input array.

    Examples
    --------
    >>> import numpy as np
    >>> X = np.random.randn(10000, 5)
    >>> for start_idx, chunk in iter_chunks(X, batch_size=1000):
    ...     print(f"Processing samples {start_idx} to {start_idx + len(chunk)}")
    """
    X = np.asarray(X)
    n_samples = X.shape[0]
    for start_idx in range(0, n_samples, batch_size):
        end_idx = min(start_idx + batch_size, n_samples)
        yield start_idx, X[start_idx:end_idx]

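# --- Editor's sketch (not part of the released file) -------------------------
# iter_chunks tiles the input exactly: chunk sizes sum to n_samples and the
# final chunk is simply shorter when batch_size does not divide evenly.
assert sum(c.shape[0] for _, c in iter_chunks(np.zeros((2500, 4)), 1000)) == 2500
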
class StreamingMixin:
    """Mixin class providing streaming/batched prediction methods."""

    def predict_batched(
        self, X: ArrayLike, batch_size: int = 1000
    ) -> Iterator[NDArray[np.float64]]:
        """Predict risk scores in batches to handle large datasets.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        batch_size : int, default=1000
            Number of samples per batch.

        Yields
        ------
        risk_scores : ndarray of shape (batch_size,) or smaller for last batch
            Predicted risk scores for each batch.

        Examples
        --------
        >>> model = GradientBoostSurvivalEstimator()
        >>> model.fit(X_train, y_train)
        >>> all_predictions = []
        >>> for batch_preds in model.predict_batched(X_large, batch_size=5000):
        ...     all_predictions.append(batch_preds)
        >>> predictions = np.concatenate(all_predictions)
        """
        for _, chunk in iter_chunks(X, batch_size):
            yield self.predict(chunk)

    def predict_survival_batched(
        self, X: ArrayLike, batch_size: int = 1000
    ) -> Iterator[tuple[NDArray[np.float64], NDArray[np.float64]]]:
        """Predict survival functions in batches.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        batch_size : int, default=1000
            Number of samples per batch.

        Yields
        ------
        times : ndarray of shape (n_times,)
            Time points (same for all batches).
        survival : ndarray of shape (batch_size, n_times)
            Survival probabilities for each batch.
        """
        for _, chunk in iter_chunks(X, batch_size):
            yield self.predict_survival_function(chunk)

    def predict_to_array(
        self, X: ArrayLike, batch_size: int = 1000, out: NDArray | None = None
    ) -> NDArray[np.float64]:
        """Predict risk scores with optional pre-allocated output array.

        This method is memory-efficient for large datasets as it can write
        directly to a pre-allocated array or memory-mapped file.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.
        batch_size : int, default=1000
            Number of samples per batch.
        out : ndarray of shape (n_samples,), optional
            Pre-allocated output array. If None, a new array is created.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores.

        Examples
        --------
        >>> # Using with memory-mapped array for very large datasets
        >>> import numpy as np
        >>> out = np.memmap('predictions.dat', dtype='float64', mode='w+', shape=(1000000,))
        >>> model.predict_to_array(X_large, batch_size=10000, out=out)
        >>> out.flush()  # Write to disk
        """
        X = np.asarray(X)
        n_samples = X.shape[0]

        if out is None:
            out = np.empty(n_samples, dtype=np.float64)
        elif out.shape[0] != n_samples:
            raise ValueError(f"out has shape {out.shape}, expected ({n_samples},)")

        for start_idx, chunk in iter_chunks(X, batch_size):
            end_idx = start_idx + chunk.shape[0]
            out[start_idx:end_idx] = self.predict(chunk)

        return out

class StreamingCoxPHEstimator(CoxPHEstimator, StreamingMixin):
    """Cox PH Estimator with streaming/batched prediction support.

    This class extends CoxPHEstimator with methods for processing large
    datasets that don't fit in memory.

    See CoxPHEstimator for full documentation.
    """

    pass


class StreamingGradientBoostSurvivalEstimator(GradientBoostSurvivalEstimator, StreamingMixin):
    """Gradient Boosting Survival Estimator with streaming support.

    This class extends GradientBoostSurvivalEstimator with methods for
    processing large datasets that don't fit in memory.

    See GradientBoostSurvivalEstimator for full documentation.
    """

    pass


class StreamingSurvivalForestEstimator(SurvivalForestEstimator, StreamingMixin):
    """Survival Forest Estimator with streaming support.

    This class extends SurvivalForestEstimator with methods for processing
    large datasets that don't fit in memory.

    See SurvivalForestEstimator for full documentation.
    """

    pass


class StreamingAFTEstimator(AFTEstimator, StreamingMixin):
    """AFT Estimator with streaming/batched prediction support.

    This class extends AFTEstimator with methods for processing large
    datasets that don't fit in memory.

    See AFTEstimator for full documentation.
    """

    pass

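# --- Editor's sketch (not part of the released file) -------------------------
# The Streaming* variants add no new state; they only mix StreamingMixin's
# batched prediction methods into the fitted estimator, so chunked output
# matches a single full-array predict call.
_scox = StreamingCoxPHEstimator(n_iters=25).fit(X_demo, y_demo)
_chunked = np.concatenate(list(_scox.predict_batched(X_demo, batch_size=64)))
assert np.allclose(_chunked, _scox.predict(X_demo))
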
class DeepSurvEstimator(BaseEstimator, RegressorMixin):
    """Scikit-learn compatible DeepSurv model.

    DeepSurv is a deep feedforward neural network for survival analysis
    using Cox partial likelihood loss.

    Parameters
    ----------
    hidden_layers : list of int, default=[64, 32]
        Number of neurons in each hidden layer.
    activation : str, default="selu"
        Activation function. One of "relu", "selu", or "tanh".
    dropout_rate : float, default=0.2
        Dropout rate applied after each hidden layer.
    learning_rate : float, default=0.001
        Learning rate for the Adam optimizer.
    batch_size : int, default=256
        Mini-batch size for training.
    n_epochs : int, default=100
        Number of training epochs.
    l2_reg : float, default=0.0001
        L2 regularization (weight decay) coefficient.
    seed : int or None, default=None
        Random seed for reproducibility.
    early_stopping_patience : int or None, default=10
        Number of epochs without improvement before early stopping.
        Set to None to disable early stopping.
    validation_fraction : float, default=0.1
        Fraction of training data to use for validation.

    Attributes
    ----------
    model_ : DeepSurv
        The underlying fitted model.
    n_features_in_ : int
        Number of features seen during fit.
    """

    def __init__(
        self,
        hidden_layers: list[int] | None = None,
        activation: str = "selu",
        dropout_rate: float = 0.2,
        learning_rate: float = 0.001,
        batch_size: int = 256,
        n_epochs: int = 100,
        l2_reg: float = 0.0001,
        seed: int | None = None,
        early_stopping_patience: int | None = 10,
        validation_fraction: float = 0.1,
    ):
        self.hidden_layers = hidden_layers if hidden_layers is not None else [64, 32]
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.l2_reg = l2_reg
        self.seed = seed
        self.early_stopping_patience = early_stopping_patience
        self.validation_fraction = validation_fraction

    def fit(self, X: ArrayLike, y: ArrayLike) -> "DeepSurvEstimator":
        """Fit the DeepSurv model.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.
        y : array-like of shape (n_samples, 2)
            Target values where y[:, 0] is survival time and y[:, 1] is event status.

        Returns
        -------
        self : DeepSurvEstimator
            Fitted estimator.
        """
        X, time, status = _validate_survival_data(X, y)
        self.n_features_in_ = X.shape[1]
        n_obs = X.shape[0]

        activation = _surv.Activation(self.activation)
        config = _surv.DeepSurvConfig(
            hidden_layers=self.hidden_layers,
            activation=activation,
            dropout_rate=self.dropout_rate,
            learning_rate=self.learning_rate,
            batch_size=self.batch_size,
            n_epochs=self.n_epochs,
            l2_reg=self.l2_reg,
            seed=self.seed,
            early_stopping_patience=self.early_stopping_patience,
            validation_fraction=self.validation_fraction,
        )

        x_flat = X.flatten().tolist()
        self.model_ = _surv.DeepSurv.fit(
            x_flat, n_obs, self.n_features_in_, time.tolist(), status.tolist(), config
        )

        self.is_fitted_ = True
        return self

    def predict(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict risk scores for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        risk_scores : ndarray of shape (n_samples,)
            Predicted risk scores (higher = higher risk).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but model expects {self.n_features_in_}"
            )

        x_flat = X.flatten().tolist()
        return np.array(self.model_.predict_risk(x_flat, X.shape[0]))

    def predict_survival_function(
        self, X: ArrayLike
    ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
        """Predict survival function for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        times : ndarray of shape (n_times,)
            Time points.
        survival : ndarray of shape (n_samples, n_times)
            Survival probabilities.
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        survival = self.model_.predict_survival(x_flat, X.shape[0])
        return np.array(self.model_.unique_times), np.array(survival)

    def predict_median_survival_time(self, X: ArrayLike) -> NDArray[np.float64]:
        """Predict median survival time for samples.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Samples to predict.

        Returns
        -------
        median_times : ndarray of shape (n_samples,)
            Predicted median survival times (NaN if survival never drops below 0.5).
        """
        check_is_fitted(self)
        X = check_array(X, dtype=np.float64, ensure_2d=True)

        x_flat = X.flatten().tolist()
        result = self.model_.predict_median_survival_time(x_flat, X.shape[0])
        return np.array([t if t is not None else np.nan for t in result])

    def score(self, X: ArrayLike, y: ArrayLike) -> float:
        """Return the concordance index on the given test data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Test samples.
        y : array-like of shape (n_samples, 2)
            True target values.

        Returns
        -------
        score : float
            Concordance index (C-index), between 0 and 1.
        """
        check_is_fitted(self)
        X, time, status = _validate_survival_data(X, y)
        risk_scores = self.predict(X)
        return _compute_concordance_index(time, status, risk_scores)

    @property
    def train_loss(self) -> NDArray[np.float64]:
        """Training loss history."""
        check_is_fitted(self)
        return np.array(self.model_.train_loss)

    @property
    def val_loss(self) -> NDArray[np.float64]:
        """Validation loss history."""
        check_is_fitted(self)
        return np.array(self.model_.val_loss)


class StreamingDeepSurvEstimator(DeepSurvEstimator, StreamingMixin):
    """DeepSurv Estimator with streaming/batched prediction support.

    This class extends DeepSurvEstimator with methods for processing large
    datasets that don't fit in memory.

    See DeepSurvEstimator for full documentation.
    """

    pass

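# --- Editor's sketch (not part of the released file) -------------------------
# The train_loss / val_loss properties expose per-epoch histories, the usual
# way to check for overfitting or a too-aggressive learning rate. Assumes the
# default validation split (validation_fraction=0.1) so val_loss is populated.
_deep = DeepSurvEstimator(hidden_layers=[32, 16], n_epochs=20, seed=42).fit(X_demo, y_demo)
_gap = _deep.val_loss[-1] - _deep.train_loss[-1]  # a large gap suggests overfitting
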
def predict_large_dataset(
    estimator,
    X: ArrayLike,
    batch_size: int = 1000,
    output_file: str | None = None,
    verbose: bool = False,
) -> NDArray[np.float64]:
    """Predict on a large dataset using batched processing.

    This is a utility function for making predictions on datasets that may
    not fit in memory. It processes the data in batches and optionally
    writes results to a memory-mapped file.

    Parameters
    ----------
    estimator : fitted estimator
        A fitted survival estimator with a predict method.
    X : array-like of shape (n_samples, n_features)
        Samples to predict. Can be a numpy array or memory-mapped array.
    batch_size : int, default=1000
        Number of samples to process at once.
    output_file : str, optional
        Path to output file for memory-mapped results. If provided, results
        are written to this file and can exceed available RAM.
    verbose : bool, default=False
        If True, print progress information.

    Returns
    -------
    predictions : ndarray of shape (n_samples,)
        Predicted risk scores. If output_file is provided, this is a
        memory-mapped array.

    Examples
    --------
    >>> # Process a very large dataset
    >>> predictions = predict_large_dataset(
    ...     model, X_huge, batch_size=10000,
    ...     output_file='predictions.mmap', verbose=True
    ... )
    """
    X = np.asarray(X)
    n_samples = X.shape[0]

    if output_file is not None:
        predictions = np.memmap(output_file, dtype=np.float64, mode="w+", shape=(n_samples,))
    else:
        predictions = np.empty(n_samples, dtype=np.float64)

    n_batches = (n_samples + batch_size - 1) // batch_size

    for batch_idx, (start_idx, chunk) in enumerate(iter_chunks(X, batch_size)):
        end_idx = start_idx + chunk.shape[0]
        predictions[start_idx:end_idx] = estimator.predict(chunk)

        if verbose:
            print(f"Processed batch {batch_idx + 1}/{n_batches} (samples {start_idx}-{end_idx})")

    if output_file is not None:
        predictions.flush()

    return predictions

def survival_curves_to_disk(
    estimator,
    X: ArrayLike,
    output_file: str,
    batch_size: int = 100,
    verbose: bool = False,
) -> tuple[NDArray[np.float64], np.memmap]:
    """Compute survival curves and write to disk for large datasets.

    This function computes survival curves in batches and stores them in
    a memory-mapped file, allowing processing of datasets larger than RAM.

    Parameters
    ----------
    estimator : fitted estimator
        A fitted survival estimator with predict_survival_function method.
    X : array-like of shape (n_samples, n_features)
        Samples to predict.
    output_file : str
        Path to output file for memory-mapped survival curves.
    batch_size : int, default=100
        Number of samples to process at once. Smaller values use less
        memory but are slower.
    verbose : bool, default=False
        If True, print progress information.

    Returns
    -------
    times : ndarray of shape (n_times,)
        Time points for the survival curves.
    survival : memmap of shape (n_samples, n_times)
        Memory-mapped array of survival probabilities.

    Examples
    --------
    >>> times, survival_curves = survival_curves_to_disk(
    ...     model, X_huge, 'survival_curves.mmap',
    ...     batch_size=100, verbose=True
    ... )
    >>> # Access individual survival curves without loading all into memory
    >>> curve_0 = survival_curves[0]  # Loads only first curve
    """
    X = np.asarray(X)
    n_samples = X.shape[0]

    # Probe one sample to learn the time grid and size the on-disk array.
    first_times, _ = estimator.predict_survival_function(X[:1])
    n_times = len(first_times)
    times = first_times

    survival = np.memmap(output_file, dtype=np.float64, mode="w+", shape=(n_samples, n_times))

    n_batches = (n_samples + batch_size - 1) // batch_size

    for batch_idx, (start_idx, chunk) in enumerate(iter_chunks(X, batch_size)):
        end_idx = start_idx + chunk.shape[0]
        _, batch_surv = estimator.predict_survival_function(chunk)
        survival[start_idx:end_idx] = batch_surv

        if verbose:
            print(f"Processed batch {batch_idx + 1}/{n_batches} (samples {start_idx}-{end_idx})")

    survival.flush()
    return times, survival