scikit-survival 0.26.0__cp314-cp314-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.26.0.dist-info/METADATA +185 -0
- scikit_survival-0.26.0.dist-info/RECORD +58 -0
- scikit_survival-0.26.0.dist-info/WHEEL +6 -0
- scikit_survival-0.26.0.dist-info/licenses/COPYING +674 -0
- scikit_survival-0.26.0.dist-info/top_level.txt +1 -0
- sksurv/__init__.py +183 -0
- sksurv/base.py +115 -0
- sksurv/bintrees/__init__.py +15 -0
- sksurv/bintrees/_binarytrees.cpython-314-darwin.so +0 -0
- sksurv/column.py +204 -0
- sksurv/compare.py +123 -0
- sksurv/datasets/__init__.py +12 -0
- sksurv/datasets/base.py +614 -0
- sksurv/datasets/data/GBSG2.arff +700 -0
- sksurv/datasets/data/actg320.arff +1169 -0
- sksurv/datasets/data/bmt.arff +46 -0
- sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
- sksurv/datasets/data/cgvhd.arff +118 -0
- sksurv/datasets/data/flchain.arff +7887 -0
- sksurv/datasets/data/veteran.arff +148 -0
- sksurv/datasets/data/whas500.arff +520 -0
- sksurv/docstrings.py +99 -0
- sksurv/ensemble/__init__.py +2 -0
- sksurv/ensemble/_coxph_loss.cpython-314-darwin.so +0 -0
- sksurv/ensemble/boosting.py +1564 -0
- sksurv/ensemble/forest.py +902 -0
- sksurv/ensemble/survival_loss.py +151 -0
- sksurv/exceptions.py +18 -0
- sksurv/functions.py +114 -0
- sksurv/io/__init__.py +2 -0
- sksurv/io/arffread.py +91 -0
- sksurv/io/arffwrite.py +181 -0
- sksurv/kernels/__init__.py +1 -0
- sksurv/kernels/_clinical_kernel.cpython-314-darwin.so +0 -0
- sksurv/kernels/clinical.py +348 -0
- sksurv/linear_model/__init__.py +3 -0
- sksurv/linear_model/_coxnet.cpython-314-darwin.so +0 -0
- sksurv/linear_model/aft.py +208 -0
- sksurv/linear_model/coxnet.py +592 -0
- sksurv/linear_model/coxph.py +637 -0
- sksurv/meta/__init__.py +4 -0
- sksurv/meta/base.py +35 -0
- sksurv/meta/ensemble_selection.py +724 -0
- sksurv/meta/stacking.py +370 -0
- sksurv/metrics.py +1028 -0
- sksurv/nonparametric.py +911 -0
- sksurv/preprocessing.py +195 -0
- sksurv/svm/__init__.py +11 -0
- sksurv/svm/_minlip.cpython-314-darwin.so +0 -0
- sksurv/svm/_prsvm.cpython-314-darwin.so +0 -0
- sksurv/svm/minlip.py +695 -0
- sksurv/svm/naive_survival_svm.py +249 -0
- sksurv/svm/survival_svm.py +1236 -0
- sksurv/testing.py +155 -0
- sksurv/tree/__init__.py +1 -0
- sksurv/tree/_criterion.cpython-314-darwin.so +0 -0
- sksurv/tree/tree.py +790 -0
- sksurv/util.py +416 -0
@@ -0,0 +1,902 @@
from abc import ABCMeta, abstractmethod
from functools import partial
import threading
import warnings

from joblib import Parallel, delayed
import numpy as np
from sklearn.ensemble._base import _partition_estimators
from sklearn.ensemble._forest import (
    BaseForest,
    _accumulate_prediction,
    _generate_unsampled_indices,
    _get_n_samples_bootstrap,
    _parallel_build_trees,
)
from sklearn.tree._tree import DTYPE
from sklearn.utils._tags import get_tags
from sklearn.utils.validation import check_is_fitted, check_random_state, validate_data

from ..base import SurvivalAnalysisMixin
from ..docstrings import append_cumulative_hazard_example, append_survival_function_example
from ..metrics import concordance_index_censored
from ..tree import ExtraSurvivalTree, SurvivalTree
from ..tree._criterion import get_unique_times
from ..tree.tree import _array_to_step_function
from ..util import check_array_survival

__all__ = ["RandomSurvivalForest", "ExtraSurvivalTrees"]

MAX_INT = np.iinfo(np.int32).max


def _sklearn_tags_patch(self):
    # BaseForest.__sklearn_tags__ calls
    # type(self.estimator)(criterion=self.criterion),
    # which is incompatible with LogrankCriterion
    if isinstance(self, _BaseSurvivalForest):
        estimator = type(self.estimator)()
    else:
        estimator = type(self.estimator)(criterion=self.criterion)
    tags = super(BaseForest, self).__sklearn_tags__()
    tags.input_tags.allow_nan = get_tags(estimator).input_tags.allow_nan
    return tags


BaseForest.__sklearn_tags__ = _sklearn_tags_patch

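# A minimal sketch of what the patch above enables (illustrative, not part of
# the released module): tag introspection on a survival forest defers the
# `allow_nan` input tag to its tree estimator, so it may report support for
# missing values when SurvivalTree does.
# >>> from sklearn.utils._tags import get_tags
# >>> from sksurv.ensemble import RandomSurvivalForest
# >>> get_tags(RandomSurvivalForest()).input_tags.allow_nan
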
class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
    """
    Base class for forest-based estimators for survival analysis.

    Warning: This class should not be used directly. Use derived classes
    instead.
    """

    @abstractmethod
    def __init__(
        self,
        estimator,
        n_estimators=100,
        *,
        estimator_params=tuple(),
        bootstrap=False,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        max_samples=None,
    ):
        super().__init__(
            estimator,
            n_estimators=n_estimators,
            estimator_params=estimator_params,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            class_weight=None,
            max_samples=max_samples,
        )

    @property
    def feature_importances_(self):
        """Not implemented"""
        raise NotImplementedError()

    def fit(self, X, y, sample_weight=None):
        """Build a forest of survival trees from the training set (X, y).

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Data matrix

        y : structured array, shape = (n_samples,)
            A structured array with two fields. The first field is a boolean
            where ``True`` indicates an event and ``False`` indicates right-censoring.
            The second field is a float with the time of event or time of censoring.

        Returns
        -------
        self
        """
        self._validate_params()

        X = validate_data(self, X, dtype=DTYPE, accept_sparse="csc", ensure_min_samples=2, ensure_all_finite=False)
        event, time = check_array_survival(X, y)

        # _compute_missing_values_in_feature_mask checks if X has missing values and
        # will raise an error if the underlying tree base estimator can't handle missing
        # values.
        estimator = type(self.estimator)()
        missing_values_in_feature_mask = estimator._compute_missing_values_in_feature_mask(
            X, estimator_name=self.__class__.__name__
        )

        self._n_samples, self.n_features_in_ = X.shape
        time = time.astype(np.float64)
        self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
        self.n_outputs_ = self.unique_times_.shape[0]

        y_numeric = np.empty((X.shape[0], 2), dtype=np.float64)
        y_numeric[:, 0] = time
        y_numeric[:, 1] = event.astype(np.float64)

        # Get bootstrap sample size
        if not self.bootstrap and self.max_samples is not None:  # pylint: disable=no-else-raise
            raise ValueError(
                "`max_samples` cannot be set if `bootstrap=False`. "
                "Either switch to `bootstrap=True` or set "
                "`max_samples=None`."
            )
        elif self.bootstrap:
            n_samples_bootstrap = _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
        else:
            n_samples_bootstrap = None

        self._n_samples_bootstrap = n_samples_bootstrap

        # Check parameters
        self._validate_estimator()

        if not self.bootstrap and self.oob_score:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        random_state = check_random_state(self.random_state)

        if not self.warm_start or not hasattr(self, "estimators_"):
            # Free allocated memory, if any
            self.estimators_ = []

        n_more_estimators = self.n_estimators - len(self.estimators_)

        if n_more_estimators < 0:  # pylint: disable=no-else-raise
            raise ValueError(
                f"n_estimators={self.n_estimators} must be larger or equal to "
                f"len(estimators_)={len(self.estimators_)} when warm_start==True"
            )
        elif n_more_estimators == 0:
            warnings.warn("Warm-start fitting without increasing n_estimators does not fit new trees.", stacklevel=2)
        else:
            if self.warm_start and len(self.estimators_) > 0:
                # We draw from the random state to get the random state we
                # would have got if we hadn't used a warm_start.
                random_state.randint(MAX_INT, size=len(self.estimators_))

            trees = [self._make_estimator(append=False, random_state=random_state) for i in range(n_more_estimators)]

            y_tree = (
                y_numeric,
                self.unique_times_,
                self.is_event_time_,
            )
            # Parallel loop: we prefer the threading backend as the Cython code
            # for fitting the trees is internally releasing the Python GIL
            # making threading more efficient than multiprocessing in
            # that case. However, for joblib 0.12+ we respect any
            # parallel_backend contexts set at a higher level,
            # since correctness does not rely on using threads.
            trees = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer="threads")(
                delayed(_parallel_build_trees)(
                    t,
                    self.bootstrap,
                    X,
                    y_tree,
                    sample_weight,
                    i,
                    len(trees),
                    verbose=self.verbose,
                    n_samples_bootstrap=n_samples_bootstrap,
                    missing_values_in_feature_mask=missing_values_in_feature_mask,
                )
                for i, t in enumerate(trees)
            )

            # Collect newly grown trees
            self.estimators_.extend(trees)

        if self.oob_score:
            self._set_oob_score_and_attributes(X, (event, time))

        return self

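    # A minimal fit sketch (illustrative, using synthetic data): the structured
    # `y` described in the docstring above can be built with `sksurv.util.Surv`.
    # >>> import numpy as np
    # >>> from sksurv.util import Surv
    # >>> from sksurv.ensemble import RandomSurvivalForest
    # >>> rng = np.random.default_rng(0)
    # >>> X = rng.normal(size=(100, 5))
    # >>> y = Surv.from_arrays(event=rng.random(100) < 0.7,
    # ...                      time=rng.uniform(1.0, 100.0, size=100))
    # >>> est = RandomSurvivalForest(n_estimators=50, random_state=0).fit(X, y)
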
    def _set_oob_score_and_attributes(self, X, y):
        """Calculate out of bag predictions and score."""
        n_samples = X.shape[0]
        event, time = y

        predictions = np.zeros(n_samples)
        n_predictions = np.zeros(n_samples)

        n_samples_bootstrap = _get_n_samples_bootstrap(n_samples, self.max_samples)

        for estimator in self.estimators_:
            unsampled_indices = _generate_unsampled_indices(estimator.random_state, n_samples, n_samples_bootstrap)
            p_estimator = estimator.predict(X[unsampled_indices, :], check_input=False)

            predictions[unsampled_indices] += p_estimator
            n_predictions[unsampled_indices] += 1

        if (n_predictions == 0).any():
            warnings.warn(
                "Some inputs do not have OOB scores. "
                "This probably means too few trees were used "
                "to compute any reliable oob estimates.",
                stacklevel=3,
            )
            n_predictions[n_predictions == 0] = 1

        predictions /= n_predictions
        self.oob_prediction_ = predictions

        self.oob_score_ = concordance_index_censored(event, time, predictions)[0]

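    # A short sketch of reading the OOB estimate (illustrative; `X`, `y` as in
    # the fit sketch above): with `bootstrap=True` and `oob_score=True`, the
    # concordance index computed here is exposed after fitting.
    # >>> est = RandomSurvivalForest(n_estimators=100, oob_score=True,
    # ...                            random_state=0).fit(X, y)
    # >>> est.oob_score_  # out-of-bag concordance index
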
    def _predict(self, predict_fn, X):
        check_is_fitted(self, "estimators_")
        # Check data
        X = self._validate_X_predict(X)

        # Assign chunk of trees to jobs
        n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)

        # avoid storing the output of every estimator by summing them here
        if predict_fn == "predict":
            y_hat = np.zeros((X.shape[0]), dtype=np.float64)
        else:
            y_hat = np.zeros((X.shape[0], self.n_outputs_), dtype=np.float64)

        def _get_fn(est, name):
            fn = getattr(est, name)
            if name in ("predict_cumulative_hazard_function", "predict_survival_function"):
                fn = partial(fn, return_array=True)
            return fn

        # Parallel loop
        lock = threading.Lock()
        Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")(
            delayed(_accumulate_prediction)(_get_fn(e, predict_fn), X, [y_hat], lock) for e in self.estimators_
        )

        y_hat /= len(self.estimators_)

        return y_hat

    def predict(self, X):
        r"""Predict risk score.

        The ensemble risk score is the total number of events,
        which can be estimated by the sum of the estimated
        ensemble cumulative hazard function :math:`\hat{H}_e`.

        .. math::

            \sum_{j=1}^{n} \hat{H}_e(T_{j} \mid x) ,

        where :math:`n` denotes the total number of distinct
        event times in the training data.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Data matrix.

        Returns
        -------
        risk_scores : ndarray, shape = (n_samples,)
            Predicted risk scores.
        """
        return self._predict("predict", X)

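    # A minimal prediction sketch (illustrative; `est`, `X` as in the fit
    # sketch): higher risk scores indicate shorter expected survival, matching
    # the sum-of-CHF definition in the docstring above.
    # >>> risk = est.predict(X[:5])
    # >>> risk.shape  # (5,)
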
    def predict_cumulative_hazard_function(self, X, return_array=False):
        arr = self._predict("predict_cumulative_hazard_function", X)
        if return_array:
            return arr
        return _array_to_step_function(self.unique_times_, arr)

    def predict_survival_function(self, X, return_array=False):
        arr = self._predict("predict_survival_function", X)
        if return_array:
            return arr
        return _array_to_step_function(self.unique_times_, arr)


class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
    """A random survival forest.

    A random survival forest is a meta estimator that fits a number of
    survival trees on various sub-samples of the dataset and uses
    averaging to improve the predictive accuracy and control over-fitting.
    The sub-sample size is always the same as the original input sample
    size but the samples are drawn with replacement if
    `bootstrap=True` (default).

    In each survival tree, the quality of a split is measured by the
    log-rank splitting rule.

    See the :ref:`User Guide </user_guide/random-survival-forest.ipynb>`,
    [1]_ and [2]_ for further description.

    Parameters
    ----------
    n_estimators : int, optional, default: 100
        The number of trees in the forest.

    max_depth : int or None, optional, default: None
        The maximum depth of the tree. If None, then nodes are expanded until
        all leaves are pure or until all leaves contain less than
        min_samples_split samples.

    min_samples_split : int, float, optional, default: 6
        The minimum number of samples required to split an internal node:

        - If int, then consider `min_samples_split` as the minimum number.
        - If float, then `min_samples_split` is a fraction and
          `ceil(min_samples_split * n_samples)` are the minimum
          number of samples for each split.

    min_samples_leaf : int, float, optional, default: 3
        The minimum number of samples required to be at a leaf node.
        A split point at any depth will only be considered if it leaves at
        least ``min_samples_leaf`` training samples in each of the left and
        right branches. This may have the effect of smoothing the model,
        especially in regression.

        - If int, then consider `min_samples_leaf` as the minimum number.
        - If float, then `min_samples_leaf` is a fraction and
          `ceil(min_samples_leaf * n_samples)` are the minimum
          number of samples for each node.

    min_weight_fraction_leaf : float, optional, default: 0.
        The minimum weighted fraction of the sum total of weights (of all
        the input samples) required to be at a leaf node. Samples have
        equal weight when sample_weight is not provided.

    max_features : int, float, {'sqrt', 'log2'} or None, optional, default: 'sqrt'
        The number of features to consider when looking for the best split:

        - If int, then consider `max_features` features at each split.
        - If float, then `max_features` is a fraction and
          `int(max_features * n_features)` features are considered at each
          split.
        - If "sqrt", then `max_features=sqrt(n_features)`.
        - If "log2", then `max_features=log2(n_features)`.
        - If None, then `max_features=n_features`.

        Note: the search for a split does not stop until at least one
        valid partition of the node samples is found, even if it requires to
        effectively inspect more than ``max_features`` features.

    max_leaf_nodes : int or None, optional, default: None
        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
        Best nodes are defined as relative reduction in impurity.
        If None then unlimited number of leaf nodes.

    bootstrap : bool, optional, default: True
        Whether bootstrap samples are used when building trees. If False, the
        whole dataset is used to build each tree.

    oob_score : bool, optional, default: False
        Whether to use out-of-bag samples to estimate
        the generalization accuracy.

    n_jobs : int or None, optional, default: None
        The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
        :meth:`decision_path` and :meth:`apply` are all parallelized over the
        trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
        context. ``-1`` means using all processors.

    random_state : int, RandomState instance or None, optional, default: None
        Controls both the randomness of the bootstrapping of the samples used
        when building trees (if ``bootstrap=True``) and the sampling of the
        features to consider when looking for the best split at each node
        (if ``max_features < n_features``).

    verbose : int, optional, default: 0
        Controls the verbosity when fitting and predicting.

    warm_start : bool, optional, default: False
        When set to ``True``, reuse the solution of the previous call to fit
        and add more estimators to the ensemble, otherwise, just fit a whole
        new forest.

    max_samples : int or float, optional, default: None
        If bootstrap is True, the number of samples to draw from X
        to train each base estimator.

        - If None (default), then draw `X.shape[0]` samples.
        - If int, then draw `max_samples` samples.
        - If float, then draw `max_samples * X.shape[0]` samples. Thus,
          `max_samples` should be in the interval `(0.0, 1.0]`.

    low_memory : bool, optional, default: False
        If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
        and :meth:`predict_survival_function` are not implemented.

    Attributes
    ----------
    estimators_ : list of SurvivalTree instances
        The collection of fitted sub-estimators.

    unique_times_ : ndarray, shape = (n_unique_times,)
        Unique time points.

    n_features_in_ : int
        Number of features seen during ``fit``.

    feature_names_in_ : ndarray, shape = (`n_features_in_`,)
        Names of features seen during ``fit``. Defined only when `X`
        has feature names that are all strings.

    oob_score_ : float
        Concordance index of the training dataset obtained
        using an out-of-bag estimate.

    See also
    --------
    sksurv.tree.SurvivalTree
        A single survival tree.

    Notes
    -----
    The default values for the parameters controlling the size of the trees
    (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
    unpruned trees which can potentially be very large on some data sets. To
    reduce memory consumption, the complexity and size of the trees should be
    controlled by setting those parameter values.

    Compared to scikit-learn's random forest models, :class:`RandomSurvivalForest`
    currently does not support controlling the depth of a tree based on the log-rank
    test statistic or its associated p-value, i.e., the parameters
    `min_impurity_decrease` or `min_impurity_split` are absent.
    In addition, the `feature_importances_` attribute is not available.
    It is recommended to estimate feature importances via
    :func:`sklearn.inspection.permutation_importance`.

    The features are always randomly permuted at each split. Therefore,
    the best found split may vary, even with the same training data,
    ``max_features=n_features`` and ``bootstrap=False``, if the improvement
    of the criterion is identical for several splits enumerated during the
    search of the best split. To obtain a deterministic behavior during
    fitting, ``random_state`` has to be fixed.

    References
    ----------
    .. [1] Ishwaran, H., Kogalur, U. B., Blackstone, E. H., & Lauer, M. S. (2008).
           Random survival forests. The Annals of Applied Statistics, 2(3), 841–860.

    .. [2] Ishwaran, H., Kogalur, U. B. (2007). Random survival forests for R.
           R News, 7(2), 25–31. https://cran.r-project.org/doc/Rnews/Rnews_2007-2.pdf.
    """

    _parameter_constraints = {
        **BaseForest._parameter_constraints,
        **SurvivalTree._parameter_constraints,
    }
    _parameter_constraints.pop("splitter")

    def __init__(
        self,
        n_estimators=100,
        *,
        max_depth=None,
        min_samples_split=6,
        min_samples_leaf=3,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        bootstrap=True,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        max_samples=None,
        low_memory=False,
    ):
        super().__init__(
            estimator=SurvivalTree(),
            n_estimators=n_estimators,
            estimator_params=(
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "random_state",
                "low_memory",
            ),
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            max_samples=max_samples,
        )

        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.low_memory = low_memory

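    # A hedged sketch of the feature-importance recommendation from the Notes
    # above (illustrative; `est`, `X`, `y` as in the earlier fit sketch). The
    # estimator's concordance-index `score` is used as the permutation metric.
    # >>> from sklearn.inspection import permutation_importance
    # >>> result = permutation_importance(est, X, y, n_repeats=10, random_state=0)
    # >>> result.importances_mean  # mean drop in concordance per permuted feature
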
    @append_cumulative_hazard_example(estimator_mod="ensemble", estimator_class="RandomSurvivalForest")
    def predict_cumulative_hazard_function(self, X, return_array=False):
        """Predict cumulative hazard function.

        For each tree in the ensemble, the cumulative hazard
        function (CHF) for an individual with feature vector
        :math:`x` is computed from all samples of the bootstrap
        sample that are in the same terminal node as :math:`x`.
        It is estimated by the Nelson–Aalen estimator.
        The ensemble CHF at time :math:`t` is the average
        value across all trees in the ensemble at the
        specified time point.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Data matrix.

        return_array : bool, default: False
            Whether to return a single array of cumulative hazard values
            or a list of step functions.

            If `False`, a list of :class:`sksurv.functions.StepFunction`
            objects is returned.

            If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
            returned, where `n_unique_times` is the number of unique
            event times in the training data. Each row represents the cumulative
            hazard function of an individual evaluated at `unique_times_`.

        Returns
        -------
        cum_hazard : ndarray
            If `return_array` is `False`, an array of `n_samples`
            :class:`sksurv.functions.StepFunction` instances is returned.

            If `return_array` is `True`, a numeric array of shape
            `(n_samples, n_unique_times_)` is returned.

        Examples
        --------
        """
        return super().predict_cumulative_hazard_function(X, return_array)

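    # A minimal CHF sketch (illustrative; `est`, `X` as in the fit sketch):
    # >>> chf = est.predict_cumulative_hazard_function(X[:2], return_array=True)
    # >>> chf.shape  # (2, len(est.unique_times_))
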
    @append_survival_function_example(estimator_mod="ensemble", estimator_class="RandomSurvivalForest")
    def predict_survival_function(self, X, return_array=False):
        """Predict survival function.

        For each tree in the ensemble, the survival function
        for an individual with feature vector :math:`x` is
        computed from all samples of the bootstrap sample that
        are in the same terminal node as :math:`x`.
        It is estimated by the Kaplan-Meier estimator.
        The ensemble survival function at time :math:`t` is
        the average value across all trees in the ensemble at
        the specified time point.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Data matrix.

        return_array : bool, default: False
            Whether to return a single array of survival probabilities
            or a list of step functions.

            If `False`, a list of :class:`sksurv.functions.StepFunction`
            objects is returned.

            If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
            returned, where `n_unique_times` is the number of unique
            event times in the training data. Each row represents the survival
            function of an individual evaluated at `unique_times_`.

        Returns
        -------
        survival : ndarray
            If `return_array` is `False`, an array of `n_samples`
            :class:`sksurv.functions.StepFunction` instances is returned.

            If `return_array` is `True`, a numeric array of shape
            `(n_samples, n_unique_times_)` is returned.

        Examples
        --------
        """
        return super().predict_survival_function(X, return_array)


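# A short step-function sketch (illustrative; `est`, `X` as in the fit
# sketch): with `return_array=False`, each returned StepFunction is callable
# and can be evaluated at time points within the observed range.
# >>> surv_fns = est.predict_survival_function(X[:1])
# >>> surv_fns[0](est.unique_times_[:3])  # survival probabilities at the first three unique times

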
class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
    """An extremely random survival forest.

    This class implements a meta estimator that fits a number of randomized
    survival trees (a.k.a. extra-trees) on various sub-samples of the dataset
    and uses averaging to improve the predictive accuracy and control
    over-fitting. The sub-sample size is always the same as the original
    input sample size but the samples are drawn with replacement if
    `bootstrap=True` (default).

    In each randomized survival tree, the quality of a split is measured by
    the log-rank splitting rule.

    Compared to :class:`RandomSurvivalForest`, randomness goes one step
    further in the way splits are computed. As in
    :class:`RandomSurvivalForest`, a random subset of candidate features is
    used, but instead of looking for the most discriminative thresholds,
    thresholds are drawn at random for each candidate feature and the best of
    these randomly-generated thresholds is picked as the splitting rule.

    Parameters
    ----------
    n_estimators : int, optional, default: 100
        The number of trees in the forest.

    max_depth : int or None, optional, default: None
        The maximum depth of the tree. If None, then nodes are expanded until
        all leaves are pure or until all leaves contain less than
        min_samples_split samples.

    min_samples_split : int, float, optional, default: 6
        The minimum number of samples required to split an internal node:

        - If int, then consider `min_samples_split` as the minimum number.
        - If float, then `min_samples_split` is a fraction and
          `ceil(min_samples_split * n_samples)` are the minimum
          number of samples for each split.

    min_samples_leaf : int, float, optional, default: 3
        The minimum number of samples required to be at a leaf node.
        A split point at any depth will only be considered if it leaves at
        least ``min_samples_leaf`` training samples in each of the left and
        right branches. This may have the effect of smoothing the model,
        especially in regression.

        - If int, then consider `min_samples_leaf` as the minimum number.
        - If float, then `min_samples_leaf` is a fraction and
          `ceil(min_samples_leaf * n_samples)` are the minimum
          number of samples for each node.

    min_weight_fraction_leaf : float, optional, default: 0.
        The minimum weighted fraction of the sum total of weights (of all
        the input samples) required to be at a leaf node. Samples have
        equal weight when sample_weight is not provided.

    max_features : int, float, {'sqrt', 'log2'} or None, optional, default: 'sqrt'
        The number of features to consider when looking for the best split:

        - If int, then consider `max_features` features at each split.
        - If float, then `max_features` is a fraction and
          `int(max_features * n_features)` features are considered at each
          split.
        - If "sqrt", then `max_features=sqrt(n_features)`.
        - If "log2", then `max_features=log2(n_features)`.
        - If None, then `max_features=n_features`.

        Note: the search for a split does not stop until at least one
        valid partition of the node samples is found, even if it requires to
        effectively inspect more than ``max_features`` features.

    max_leaf_nodes : int or None, optional, default: None
        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
        Best nodes are defined as relative reduction in impurity.
        If None then unlimited number of leaf nodes.

    bootstrap : bool, optional, default: True
        Whether bootstrap samples are used when building trees. If False, the
        whole dataset is used to build each tree.

    oob_score : bool, optional, default: False
        Whether to use out-of-bag samples to estimate
        the generalization accuracy.

    n_jobs : int or None, optional, default: None
        The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
        :meth:`decision_path` and :meth:`apply` are all parallelized over the
        trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
        context. ``-1`` means using all processors.

    random_state : int, RandomState instance or None, optional, default: None
        Controls both the randomness of the bootstrapping of the samples used
        when building trees (if ``bootstrap=True``) and the sampling of the
        features to consider when looking for the best split at each node
        (if ``max_features < n_features``).

    verbose : int, optional, default: 0
        Controls the verbosity when fitting and predicting.

    warm_start : bool, optional, default: False
        When set to ``True``, reuse the solution of the previous call to fit
        and add more estimators to the ensemble, otherwise, just fit a whole
        new forest.

    max_samples : int or float, optional, default: None
        If bootstrap is True, the number of samples to draw from X
        to train each base estimator.

        - If None (default), then draw `X.shape[0]` samples.
        - If int, then draw `max_samples` samples.
        - If float, then draw `max_samples * X.shape[0]` samples. Thus,
          `max_samples` should be in the interval `(0.0, 1.0]`.

    low_memory : bool, optional, default: False
        If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
        and :meth:`predict_survival_function` are not implemented.

    Attributes
    ----------
    estimators_ : list of SurvivalTree instances
        The collection of fitted sub-estimators.

    unique_times_ : ndarray, shape = (n_unique_times,)
        Unique time points.

    n_features_in_ : int
        Number of features seen during ``fit``.

    feature_names_in_ : ndarray, shape = (`n_features_in_`,)
        Names of features seen during ``fit``. Defined only when `X`
        has feature names that are all strings.

    oob_score_ : float
        Concordance index of the training dataset obtained
        using an out-of-bag estimate.

    See also
    --------
    sksurv.tree.SurvivalTree
        A single survival tree.
    """

    _parameter_constraints = {
        **BaseForest._parameter_constraints,
        **SurvivalTree._parameter_constraints,
    }
    _parameter_constraints.pop("splitter")

    def __init__(
        self,
        n_estimators=100,
        *,
        max_depth=None,
        min_samples_split=6,
        min_samples_leaf=3,
        min_weight_fraction_leaf=0.0,
        max_features="sqrt",
        max_leaf_nodes=None,
        bootstrap=True,
        oob_score=False,
        n_jobs=None,
        random_state=None,
        verbose=0,
        warm_start=False,
        max_samples=None,
        low_memory=False,
    ):
        super().__init__(
            estimator=ExtraSurvivalTree(),
            n_estimators=n_estimators,
            estimator_params=(
                "max_depth",
                "min_samples_split",
                "min_samples_leaf",
                "min_weight_fraction_leaf",
                "max_features",
                "max_leaf_nodes",
                "random_state",
                "low_memory",
            ),
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            max_samples=max_samples,
        )

        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.low_memory = low_memory

    @append_cumulative_hazard_example(estimator_mod="ensemble", estimator_class="ExtraSurvivalTrees")
    def predict_cumulative_hazard_function(self, X, return_array=False):
        """Predict cumulative hazard function.

        For each tree in the ensemble, the cumulative hazard
        function (CHF) for an individual with feature vector
        :math:`x` is computed from all samples of the bootstrap
        sample that are in the same terminal node as :math:`x`.
        It is estimated by the Nelson–Aalen estimator.
        The ensemble CHF at time :math:`t` is the average
        value across all trees in the ensemble at the
        specified time point.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Data matrix.

        return_array : bool, default: False
            Whether to return a single array of cumulative hazard values
            or a list of step functions.

            If `False`, a list of :class:`sksurv.functions.StepFunction`
            objects is returned.

            If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
            returned, where `n_unique_times` is the number of unique
            event times in the training data. Each row represents the cumulative
            hazard function of an individual evaluated at `unique_times_`.

        Returns
        -------
        cum_hazard : ndarray
            If `return_array` is `False`, an array of `n_samples`
            :class:`sksurv.functions.StepFunction` instances is returned.

            If `return_array` is `True`, a numeric array of shape
            `(n_samples, n_unique_times_)` is returned.

        Examples
        --------
        """
        return super().predict_cumulative_hazard_function(X, return_array)

    @append_survival_function_example(estimator_mod="ensemble", estimator_class="ExtraSurvivalTrees")
    def predict_survival_function(self, X, return_array=False):
        """Predict survival function.

        For each tree in the ensemble, the survival function
        for an individual with feature vector :math:`x` is
        computed from all samples of the bootstrap sample that
        are in the same terminal node as :math:`x`.
        It is estimated by the Kaplan-Meier estimator.
        The ensemble survival function at time :math:`t` is
        the average value across all trees in the ensemble at
        the specified time point.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Data matrix.

        return_array : bool, default: False
            Whether to return a single array of survival probabilities
            or a list of step functions.

            If `False`, a list of :class:`sksurv.functions.StepFunction`
            objects is returned.

            If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
            returned, where `n_unique_times` is the number of unique
            event times in the training data. Each row represents the survival
            function of an individual evaluated at `unique_times_`.

        Returns
        -------
        survival : ndarray
            If `return_array` is `False`, an array of `n_samples`
            :class:`sksurv.functions.StepFunction` instances is returned.

            If `return_array` is `True`, a numeric array of shape
            `(n_samples, n_unique_times_)` is returned.

        Examples
        --------
        """
        return super().predict_survival_function(X, return_array)
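
# A brief comparison sketch (illustrative; `X`, `y` as in the earlier fit
# sketch): ExtraSurvivalTrees is a drop-in alternative to RandomSurvivalForest
# that trades split optimality for extra randomization, which can reduce
# variance at the cost of a slightly higher bias.
# >>> from sksurv.ensemble import ExtraSurvivalTrees
# >>> xst = ExtraSurvivalTrees(n_estimators=50, random_state=0).fit(X, y)
# >>> xst.predict(X[:5])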