scikit_survival-0.25.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. scikit_survival-0.25.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.25.0.dist-info/RECORD +58 -0
  3. scikit_survival-0.25.0.dist-info/WHEEL +6 -0
  4. scikit_survival-0.25.0.dist-info/licenses/COPYING +674 -0
  5. scikit_survival-0.25.0.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +183 -0
  7. sksurv/base.py +115 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cpython-311-x86_64-linux-gnu.so +0 -0
  10. sksurv/column.py +205 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +12 -0
  13. sksurv/datasets/base.py +614 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/bmt.arff +46 -0
  17. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  18. sksurv/datasets/data/cgvhd.arff +118 -0
  19. sksurv/datasets/data/flchain.arff +7887 -0
  20. sksurv/datasets/data/veteran.arff +148 -0
  21. sksurv/datasets/data/whas500.arff +520 -0
  22. sksurv/docstrings.py +99 -0
  23. sksurv/ensemble/__init__.py +2 -0
  24. sksurv/ensemble/_coxph_loss.cpython-311-x86_64-linux-gnu.so +0 -0
  25. sksurv/ensemble/boosting.py +1564 -0
  26. sksurv/ensemble/forest.py +902 -0
  27. sksurv/ensemble/survival_loss.py +151 -0
  28. sksurv/exceptions.py +18 -0
  29. sksurv/functions.py +114 -0
  30. sksurv/io/__init__.py +2 -0
  31. sksurv/io/arffread.py +89 -0
  32. sksurv/io/arffwrite.py +181 -0
  33. sksurv/kernels/__init__.py +1 -0
  34. sksurv/kernels/_clinical_kernel.cpython-311-x86_64-linux-gnu.so +0 -0
  35. sksurv/kernels/clinical.py +348 -0
  36. sksurv/linear_model/__init__.py +3 -0
  37. sksurv/linear_model/_coxnet.cpython-311-x86_64-linux-gnu.so +0 -0
  38. sksurv/linear_model/aft.py +208 -0
  39. sksurv/linear_model/coxnet.py +592 -0
  40. sksurv/linear_model/coxph.py +637 -0
  41. sksurv/meta/__init__.py +4 -0
  42. sksurv/meta/base.py +35 -0
  43. sksurv/meta/ensemble_selection.py +724 -0
  44. sksurv/meta/stacking.py +370 -0
  45. sksurv/metrics.py +1028 -0
  46. sksurv/nonparametric.py +911 -0
  47. sksurv/preprocessing.py +183 -0
  48. sksurv/svm/__init__.py +11 -0
  49. sksurv/svm/_minlip.cpython-311-x86_64-linux-gnu.so +0 -0
  50. sksurv/svm/_prsvm.cpython-311-x86_64-linux-gnu.so +0 -0
  51. sksurv/svm/minlip.py +690 -0
  52. sksurv/svm/naive_survival_svm.py +249 -0
  53. sksurv/svm/survival_svm.py +1236 -0
  54. sksurv/testing.py +108 -0
  55. sksurv/tree/__init__.py +1 -0
  56. sksurv/tree/_criterion.cpython-311-x86_64-linux-gnu.so +0 -0
  57. sksurv/tree/tree.py +790 -0
  58. sksurv/util.py +415 -0
@@ -0,0 +1,902 @@
+ from abc import ABCMeta, abstractmethod
+ from functools import partial
+ import threading
+ import warnings
+
+ from joblib import Parallel, delayed
+ import numpy as np
+ from sklearn.ensemble._base import _partition_estimators
+ from sklearn.ensemble._forest import (
+     BaseForest,
+     _accumulate_prediction,
+     _generate_unsampled_indices,
+     _get_n_samples_bootstrap,
+     _parallel_build_trees,
+ )
+ from sklearn.tree._tree import DTYPE
+ from sklearn.utils._tags import get_tags
+ from sklearn.utils.validation import check_is_fitted, check_random_state, validate_data
+
+ from ..base import SurvivalAnalysisMixin
+ from ..docstrings import append_cumulative_hazard_example, append_survival_function_example
+ from ..metrics import concordance_index_censored
+ from ..tree import ExtraSurvivalTree, SurvivalTree
+ from ..tree._criterion import get_unique_times
+ from ..tree.tree import _array_to_step_function
+ from ..util import check_array_survival
+
+ __all__ = ["RandomSurvivalForest", "ExtraSurvivalTrees"]
+
+ MAX_INT = np.iinfo(np.int32).max
+
+
+ def _sklearn_tags_patch(self):
+     # BaseForest.__sklearn_tags__ calls
+     # type(self.estimator)(criterion=self.criterion),
+     # which is incompatible with LogrankCriterion
+     if isinstance(self, _BaseSurvivalForest):
+         estimator = type(self.estimator)()
+     else:
+         estimator = type(self.estimator)(criterion=self.criterion)
+     tags = super(BaseForest, self).__sklearn_tags__()
+     tags.input_tags.allow_nan = get_tags(estimator).input_tags.allow_nan
+     return tags
+
+
+ BaseForest.__sklearn_tags__ = _sklearn_tags_patch
+
+
+ class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
+     """
+     Base class for forest-based estimators for survival analysis.
+
+     Warning: This class should not be used directly. Use derived classes
+     instead.
+     """
+
+     @abstractmethod
+     def __init__(
+         self,
+         estimator,
+         n_estimators=100,
+         *,
+         estimator_params=tuple(),
+         bootstrap=False,
+         oob_score=False,
+         n_jobs=None,
+         random_state=None,
+         verbose=0,
+         warm_start=False,
+         max_samples=None,
+     ):
+         super().__init__(
+             estimator,
+             n_estimators=n_estimators,
+             estimator_params=estimator_params,
+             bootstrap=bootstrap,
+             oob_score=oob_score,
+             n_jobs=n_jobs,
+             random_state=random_state,
+             verbose=verbose,
+             warm_start=warm_start,
+             class_weight=None,
+             max_samples=max_samples,
+         )
+
+     @property
+     def feature_importances_(self):
+         """Not implemented"""
+         raise NotImplementedError()
+
+     def fit(self, X, y, sample_weight=None):
+         """Build a forest of survival trees from the training set (X, y).
+
+         Parameters
+         ----------
+         X : array-like, shape = (n_samples, n_features)
+             Data matrix.
+
+         y : structured array, shape = (n_samples,)
+             A structured array with two fields. The first field is a boolean
+             where ``True`` indicates an event and ``False`` indicates right-censoring.
+             The second field is a float with the time of event or time of censoring.
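+
+             For illustration, such an array can be constructed with
+             :func:`sksurv.util.Surv.from_arrays`:
+
+             >>> from sksurv.util import Surv  # illustrative sketch
+             >>> y = Surv.from_arrays(event=[True, False, True], time=[10.0, 25.0, 12.0])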
+
+         Returns
+         -------
+         self
+         """
+         self._validate_params()
+
+         X = validate_data(self, X, dtype=DTYPE, accept_sparse="csc", ensure_min_samples=2, ensure_all_finite=False)
+         event, time = check_array_survival(X, y)
+
+         # _compute_missing_values_in_feature_mask checks if X has missing values and
+         # will raise an error if the underlying tree base estimator can't handle missing
+         # values.
+         estimator = type(self.estimator)()
+         missing_values_in_feature_mask = estimator._compute_missing_values_in_feature_mask(
+             X, estimator_name=self.__class__.__name__
+         )
+
+         self._n_samples, self.n_features_in_ = X.shape
+         time = time.astype(np.float64)
+         self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
+         self.n_outputs_ = self.unique_times_.shape[0]
+
+         y_numeric = np.empty((X.shape[0], 2), dtype=np.float64)
+         y_numeric[:, 0] = time
+         y_numeric[:, 1] = event.astype(np.float64)
+
+         # Get bootstrap sample size
+         if not self.bootstrap and self.max_samples is not None:  # pylint: disable=no-else-raise
+             raise ValueError(
+                 "`max_samples` cannot be set if `bootstrap=False`. "
+                 "Either switch to `bootstrap=True` or set "
+                 "`max_samples=None`."
+             )
+         elif self.bootstrap:
+             n_samples_bootstrap = _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
+         else:
+             n_samples_bootstrap = None
+
+         self._n_samples_bootstrap = n_samples_bootstrap
+
+         # Check parameters
+         self._validate_estimator()
+
+         if not self.bootstrap and self.oob_score:
+             raise ValueError("Out of bag estimation only available if bootstrap=True")
+
+         random_state = check_random_state(self.random_state)
+
+         if not self.warm_start or not hasattr(self, "estimators_"):
+             # Free allocated memory, if any
+             self.estimators_ = []
+
+         n_more_estimators = self.n_estimators - len(self.estimators_)
+
+         if n_more_estimators < 0:  # pylint: disable=no-else-raise
+             raise ValueError(
+                 f"n_estimators={self.n_estimators} must be larger or equal to "
+                 f"len(estimators_)={len(self.estimators_)} when warm_start==True"
+             )
+         elif n_more_estimators == 0:
+             warnings.warn("Warm-start fitting without increasing n_estimators does not fit new trees.", stacklevel=2)
+         else:
+             if self.warm_start and len(self.estimators_) > 0:
+                 # We draw from the random state to get the random state we
+                 # would have got if we hadn't used a warm_start.
+                 random_state.randint(MAX_INT, size=len(self.estimators_))
+
+             trees = [self._make_estimator(append=False, random_state=random_state) for i in range(n_more_estimators)]
+
+             y_tree = (
+                 y_numeric,
+                 self.unique_times_,
+                 self.is_event_time_,
+             )
+             # Parallel loop: we prefer the threading backend as the Cython code
+             # for fitting the trees is internally releasing the Python GIL
+             # making threading more efficient than multiprocessing in
+             # that case. However, for joblib 0.12+ we respect any
+             # parallel_backend contexts set at a higher level,
+             # since correctness does not rely on using threads.
+             trees = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer="threads")(
+                 delayed(_parallel_build_trees)(
+                     t,
+                     self.bootstrap,
+                     X,
+                     y_tree,
+                     sample_weight,
+                     i,
+                     len(trees),
+                     verbose=self.verbose,
+                     n_samples_bootstrap=n_samples_bootstrap,
+                     missing_values_in_feature_mask=missing_values_in_feature_mask,
+                 )
+                 for i, t in enumerate(trees)
+             )
+
+             # Collect newly grown trees
+             self.estimators_.extend(trees)
+
+         if self.oob_score:
+             self._set_oob_score_and_attributes(X, (event, time))
+
+         return self
+
+     def _set_oob_score_and_attributes(self, X, y):
+         """Calculate out-of-bag predictions and score."""
+         n_samples = X.shape[0]
+         event, time = y
+
+         predictions = np.zeros(n_samples)
+         n_predictions = np.zeros(n_samples)
+
+         n_samples_bootstrap = _get_n_samples_bootstrap(n_samples, self.max_samples)
+
+         for estimator in self.estimators_:
+             unsampled_indices = _generate_unsampled_indices(estimator.random_state, n_samples, n_samples_bootstrap)
+             p_estimator = estimator.predict(X[unsampled_indices, :], check_input=False)
+
+             predictions[unsampled_indices] += p_estimator
+             n_predictions[unsampled_indices] += 1
+
+         if (n_predictions == 0).any():
+             warnings.warn(
+                 "Some inputs do not have OOB scores. "
+                 "This probably means too few trees were used "
+                 "to compute any reliable OOB estimates.",
+                 stacklevel=3,
+             )
+             n_predictions[n_predictions == 0] = 1
+
+         predictions /= n_predictions
+         self.oob_prediction_ = predictions
+
+         self.oob_score_ = concordance_index_censored(event, time, predictions)[0]
+
+     def _predict(self, predict_fn, X):
+         check_is_fitted(self, "estimators_")
+         # Check data
+         X = self._validate_X_predict(X)
+
+         # Assign chunk of trees to jobs
+         n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)
+
+         # avoid storing the output of every estimator by summing them here
+         if predict_fn == "predict":
+             y_hat = np.zeros((X.shape[0]), dtype=np.float64)
+         else:
+             y_hat = np.zeros((X.shape[0], self.n_outputs_), dtype=np.float64)
+
+         def _get_fn(est, name):
+             fn = getattr(est, name)
+             if name in ("predict_cumulative_hazard_function", "predict_survival_function"):
+                 fn = partial(fn, return_array=True)
+             return fn
+
+         # Parallel loop
+         lock = threading.Lock()
+         Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")(
+             delayed(_accumulate_prediction)(_get_fn(e, predict_fn), X, [y_hat], lock) for e in self.estimators_
+         )
+
+         y_hat /= len(self.estimators_)
+
+         return y_hat
+
+     def predict(self, X):
+         r"""Predict risk score.
+
+         The ensemble risk score is the expected number of events,
+         which can be estimated by the sum of the estimated
+         ensemble cumulative hazard function :math:`\hat{H}_e`:
+
+         .. math::
+
+             \sum_{j=1}^{n} \hat{H}_e(T_{j} \mid x) ,
+
+         where :math:`n` denotes the total number of distinct
+         event times in the training data.
+
+         Parameters
+         ----------
+         X : array-like, shape = (n_samples, n_features)
+             Data matrix.
+
+         Returns
+         -------
+         risk_scores : ndarray, shape = (n_samples,)
+             Predicted risk scores.
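+
+         Examples
+         --------
+         A minimal sketch on the bundled WHAS500 data, restricted to two
+         numeric features chosen purely for illustration:
+
+         >>> from sksurv.datasets import load_whas500
+         >>> from sksurv.ensemble import RandomSurvivalForest
+         >>> X, y = load_whas500()
+         >>> X = X.loc[:, ["age", "bmi"]]  # illustrative feature subset
+         >>> est = RandomSurvivalForest(n_estimators=10, random_state=0).fit(X, y)
+         >>> est.predict(X.iloc[:3]).shape
+         (3,)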
+         """
+         return self._predict("predict", X)
+
+     def predict_cumulative_hazard_function(self, X, return_array=False):
+         arr = self._predict("predict_cumulative_hazard_function", X)
+         if return_array:
+             return arr
+         return _array_to_step_function(self.unique_times_, arr)
+
+     def predict_survival_function(self, X, return_array=False):
+         arr = self._predict("predict_survival_function", X)
+         if return_array:
+             return arr
+         return _array_to_step_function(self.unique_times_, arr)
+
+
+ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
+     """A random survival forest.
+
+     A random survival forest is a meta estimator that fits a number of
+     survival trees on various sub-samples of the dataset and uses
+     averaging to improve the predictive accuracy and control over-fitting.
+     The sub-sample size is always the same as the original input sample
+     size but the samples are drawn with replacement if
+     `bootstrap=True` (default).
+
+     In each survival tree, the quality of a split is measured by the
+     log-rank splitting rule.
+
+     See the :ref:`User Guide </user_guide/random-survival-forest.ipynb>`,
+     [1]_ and [2]_ for further description.
+
+     Parameters
+     ----------
+     n_estimators : int, optional, default: 100
+         The number of trees in the forest.
+
+     max_depth : int or None, optional, default: None
+         The maximum depth of the tree. If None, then nodes are expanded until
+         all leaves are pure or until all leaves contain less than
+         min_samples_split samples.
+
+     min_samples_split : int, float, optional, default: 6
+         The minimum number of samples required to split an internal node:
+
+         - If int, then consider `min_samples_split` as the minimum number.
+         - If float, then `min_samples_split` is a fraction and
+           `ceil(min_samples_split * n_samples)` are the minimum
+           number of samples for each split.
+
+     min_samples_leaf : int, float, optional, default: 3
+         The minimum number of samples required to be at a leaf node.
+         A split point at any depth will only be considered if it leaves at
+         least ``min_samples_leaf`` training samples in each of the left and
+         right branches. This may have the effect of smoothing the model,
+         especially in regression.
+
+         - If int, then consider `min_samples_leaf` as the minimum number.
+         - If float, then `min_samples_leaf` is a fraction and
+           `ceil(min_samples_leaf * n_samples)` are the minimum
+           number of samples for each node.
+
+     min_weight_fraction_leaf : float, optional, default: 0.
+         The minimum weighted fraction of the sum total of weights (of all
+         the input samples) required to be at a leaf node. Samples have
+         equal weight when sample_weight is not provided.
+
+     max_features : int, float, {'sqrt', 'log2'} or None, optional, default: 'sqrt'
+         The number of features to consider when looking for the best split:
+
+         - If int, then consider `max_features` features at each split.
+         - If float, then `max_features` is a fraction and
+           `int(max_features * n_features)` features are considered at each
+           split.
+         - If "sqrt", then `max_features=sqrt(n_features)`.
+         - If "log2", then `max_features=log2(n_features)`.
+         - If None, then `max_features=n_features`.
+
+         Note: the search for a split does not stop until at least one
+         valid partition of the node samples is found, even if it requires
+         effectively inspecting more than ``max_features`` features.
+
+     max_leaf_nodes : int or None, optional, default: None
+         Grow a tree with ``max_leaf_nodes`` in best-first fashion.
+         Best nodes are defined as relative reduction in impurity.
+         If None, then an unlimited number of leaf nodes is allowed.
+
+     bootstrap : bool, optional, default: True
+         Whether bootstrap samples are used when building trees. If False, the
+         whole dataset is used to build each tree.
+
+     oob_score : bool, optional, default: False
+         Whether to use out-of-bag samples to estimate
+         the generalization accuracy.
+
+     n_jobs : int or None, optional, default: None
+         The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
+         :meth:`decision_path` and :meth:`apply` are all parallelized over the
+         trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
+         context. ``-1`` means using all processors.
+
+     random_state : int, RandomState instance or None, optional, default: None
+         Controls both the randomness of the bootstrapping of the samples used
+         when building trees (if ``bootstrap=True``) and the sampling of the
+         features to consider when looking for the best split at each node
+         (if ``max_features < n_features``).
+
+     verbose : int, optional, default: 0
+         Controls the verbosity when fitting and predicting.
+
+     warm_start : bool, optional, default: False
+         When set to ``True``, reuse the solution of the previous call to fit
+         and add more estimators to the ensemble; otherwise, just fit a whole
+         new forest (see the sketch at the end of the Notes section).
+
+     max_samples : int or float, optional, default: None
+         If bootstrap is True, the number of samples to draw from X
+         to train each base estimator.
+
+         - If None (default), then draw `X.shape[0]` samples.
+         - If int, then draw `max_samples` samples.
+         - If float, then draw `max_samples * X.shape[0]` samples. Thus,
+           `max_samples` should be in the interval `(0.0, 1.0]`.
+
+     low_memory : bool, optional, default: False
+         If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
+         and :meth:`predict_survival_function` are not implemented.
+
+     Attributes
+     ----------
+     estimators_ : list of SurvivalTree instances
+         The collection of fitted sub-estimators.
+
+     unique_times_ : ndarray, shape = (n_unique_times,)
+         Unique time points.
+
+     n_features_in_ : int
+         Number of features seen during ``fit``.
+
+     feature_names_in_ : ndarray, shape = (`n_features_in_`,)
+         Names of features seen during ``fit``. Defined only when `X`
+         has feature names that are all strings.
+
+     oob_score_ : float
+         Concordance index of the training dataset obtained
+         using an out-of-bag estimate.
+
+     See also
+     --------
+     sksurv.tree.SurvivalTree
+         A single survival tree.
+
+     Notes
+     -----
+     The default values for the parameters controlling the size of the trees
+     (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
+     unpruned trees, which can potentially be very large on some data sets. To
+     reduce memory consumption, the complexity and size of the trees should be
+     controlled by setting those parameter values.
+
+     Compared to scikit-learn's random forest models, :class:`RandomSurvivalForest`
+     currently does not support controlling the depth of a tree based on the log-rank
+     test statistic or its associated p-value, i.e., the parameters
+     `min_impurity_decrease` or `min_impurity_split` are absent.
+     In addition, the `feature_importances_` attribute is not available.
+     It is recommended to estimate feature importances via
+     :func:`sklearn.inspection.permutation_importance`.
+
+     The features are always randomly permuted at each split. Therefore,
+     the best found split may vary, even with the same training data,
+     ``max_features=n_features`` and ``bootstrap=False``, if the improvement
+     of the criterion is identical for several splits enumerated during the
+     search of the best split. To obtain a deterministic behavior during
+     fitting, ``random_state`` has to be fixed.
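+
+     When ``warm_start=True``, more trees can be added to an already fitted
+     forest by increasing ``n_estimators`` and calling ``fit`` again (a
+     minimal sketch, assuming ``X`` and ``y`` as in the example of
+     :meth:`predict`)::
+
+         est = RandomSurvivalForest(n_estimators=50, warm_start=True)
+         est.fit(X, y)                     # fits 50 trees
+         est.set_params(n_estimators=100)
+         est.fit(X, y)                     # adds 50 more trees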
+
+     References
+     ----------
+     .. [1] Ishwaran, H., Kogalur, U. B., Blackstone, E. H., & Lauer, M. S. (2008).
+            Random survival forests. The Annals of Applied Statistics, 2(3), 841–860.
+
+     .. [2] Ishwaran, H., Kogalur, U. B. (2007). Random survival forests for R.
+            R News, 7(2), 25–31. https://cran.r-project.org/doc/Rnews/Rnews_2007-2.pdf.
+     """
+
+     _parameter_constraints = {
+         **BaseForest._parameter_constraints,
+         **SurvivalTree._parameter_constraints,
+     }
+     _parameter_constraints.pop("splitter")
+
+     def __init__(
+         self,
+         n_estimators=100,
+         *,
+         max_depth=None,
+         min_samples_split=6,
+         min_samples_leaf=3,
+         min_weight_fraction_leaf=0.0,
+         max_features="sqrt",
+         max_leaf_nodes=None,
+         bootstrap=True,
+         oob_score=False,
+         n_jobs=None,
+         random_state=None,
+         verbose=0,
+         warm_start=False,
+         max_samples=None,
+         low_memory=False,
+     ):
+         super().__init__(
+             estimator=SurvivalTree(),
+             n_estimators=n_estimators,
+             estimator_params=(
+                 "max_depth",
+                 "min_samples_split",
+                 "min_samples_leaf",
+                 "min_weight_fraction_leaf",
+                 "max_features",
+                 "max_leaf_nodes",
+                 "random_state",
+                 "low_memory",
+             ),
+             bootstrap=bootstrap,
+             oob_score=oob_score,
+             n_jobs=n_jobs,
+             random_state=random_state,
+             verbose=verbose,
+             warm_start=warm_start,
+             max_samples=max_samples,
+         )
+
+         self.max_depth = max_depth
+         self.min_samples_split = min_samples_split
+         self.min_samples_leaf = min_samples_leaf
+         self.min_weight_fraction_leaf = min_weight_fraction_leaf
+         self.max_features = max_features
+         self.max_leaf_nodes = max_leaf_nodes
+         self.low_memory = low_memory
+
+     @append_cumulative_hazard_example(estimator_mod="ensemble", estimator_class="RandomSurvivalForest")
+     def predict_cumulative_hazard_function(self, X, return_array=False):
+         """Predict cumulative hazard function.
+
+         For each tree in the ensemble, the cumulative hazard
+         function (CHF) for an individual with feature vector
+         :math:`x` is computed from all samples of the bootstrap
+         sample that are in the same terminal node as :math:`x`.
+         It is estimated by the Nelson–Aalen estimator.
+         The ensemble CHF at time :math:`t` is the average
+         value across all trees in the ensemble at the
+         specified time point.
+
+         Parameters
+         ----------
+         X : array-like, shape = (n_samples, n_features)
+             Data matrix.
+
+         return_array : bool, default: False
+             Whether to return a single array of cumulative hazard values
+             or a list of step functions.
+
+             If `False`, a list of :class:`sksurv.functions.StepFunction`
+             objects is returned.
+
+             If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+             returned, where `n_unique_times` is the number of unique
+             event times in the training data. Each row represents the cumulative
+             hazard function of an individual evaluated at `unique_times_`.
+
+         Returns
+         -------
+         cum_hazard : ndarray
+             If `return_array` is `False`, an array of `n_samples`
+             :class:`sksurv.functions.StepFunction` instances is returned.
+
+             If `return_array` is `True`, a numeric array of shape
+             `(n_samples, n_unique_times_)` is returned.
+
+         Examples
+         --------
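+         An illustrative sketch (the ``append_cumulative_hazard_example``
+         decorator appends the canonical example to this docstring); assumes
+         ``X``, ``y``, and ``est`` as in the example of :meth:`predict`:
+
+         >>> chf = est.predict_cumulative_hazard_function(X.iloc[:2], return_array=True)
+         >>> chf.shape == (2, est.unique_times_.shape[0])  # illustrative check
+         True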
+         """
+         return super().predict_cumulative_hazard_function(X, return_array)
+
+     @append_survival_function_example(estimator_mod="ensemble", estimator_class="RandomSurvivalForest")
+     def predict_survival_function(self, X, return_array=False):
+         """Predict survival function.
+
+         For each tree in the ensemble, the survival function
+         for an individual with feature vector :math:`x` is
+         computed from all samples of the bootstrap sample that
+         are in the same terminal node as :math:`x`.
+         It is estimated by the Kaplan–Meier estimator.
+         The ensemble survival function at time :math:`t` is
+         the average value across all trees in the ensemble at
+         the specified time point.
+
+         Parameters
+         ----------
+         X : array-like, shape = (n_samples, n_features)
+             Data matrix.
+
+         return_array : bool, default: False
+             Whether to return a single array of survival probabilities
+             or a list of step functions.
+
+             If `False`, a list of :class:`sksurv.functions.StepFunction`
+             objects is returned.
+
+             If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+             returned, where `n_unique_times` is the number of unique
+             event times in the training data. Each row represents the survival
+             function of an individual evaluated at `unique_times_`.
+
+         Returns
+         -------
+         survival : ndarray
+             If `return_array` is `False`, an array of `n_samples`
+             :class:`sksurv.functions.StepFunction` instances is returned.
+
+             If `return_array` is `True`, a numeric array of shape
+             `(n_samples, n_unique_times_)` is returned.
+
+         Examples
+         --------
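+         An illustrative sketch (the ``append_survival_function_example``
+         decorator appends the canonical example to this docstring); assumes
+         ``X``, ``y``, and ``est`` as in the example of :meth:`predict`:
+
+         >>> surv = est.predict_survival_function(X.iloc[:2], return_array=True)
+         >>> bool(((surv >= 0.0) & (surv <= 1.0)).all())  # illustrative check
+         True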
+         """
+         return super().predict_survival_function(X, return_array)
+
+
+ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
+     """An extremely randomized survival forest.
+
+     This class implements a meta estimator that fits a number of randomized
+     survival trees (a.k.a. extra-trees) on various sub-samples of the dataset
+     and uses averaging to improve the predictive accuracy and control
+     over-fitting. The sub-sample size is always the same as the original
+     input sample size but the samples are drawn with replacement if
+     `bootstrap=True` (default).
+
+     In each randomized survival tree, the quality of a split is measured by
+     the log-rank splitting rule.
+
+     Compared to :class:`RandomSurvivalForest`, randomness goes one step
+     further in the way splits are computed. As in
+     :class:`RandomSurvivalForest`, a random subset of candidate features is
+     used, but instead of looking for the most discriminative thresholds,
+     thresholds are drawn at random for each candidate feature and the best of
+     these randomly generated thresholds is picked as the splitting rule.
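+
+     The estimator API mirrors :class:`RandomSurvivalForest`; a minimal
+     sketch on the bundled WHAS500 data, restricted to two numeric features
+     chosen purely for illustration:
+
+     >>> from sksurv.datasets import load_whas500
+     >>> from sksurv.ensemble import ExtraSurvivalTrees
+     >>> X, y = load_whas500()
+     >>> Xt = X.loc[:, ["age", "bmi"]]  # illustrative feature subset
+     >>> est = ExtraSurvivalTrees(n_estimators=10, random_state=0).fit(Xt, y)
+     >>> est.predict(Xt.iloc[:3]).shape
+     (3,)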
+
+     Parameters
+     ----------
+     n_estimators : int, optional, default: 100
+         The number of trees in the forest.
+
+     max_depth : int or None, optional, default: None
+         The maximum depth of the tree. If None, then nodes are expanded until
+         all leaves are pure or until all leaves contain less than
+         min_samples_split samples.
+
+     min_samples_split : int, float, optional, default: 6
+         The minimum number of samples required to split an internal node:
+
+         - If int, then consider `min_samples_split` as the minimum number.
+         - If float, then `min_samples_split` is a fraction and
+           `ceil(min_samples_split * n_samples)` are the minimum
+           number of samples for each split.
+
+     min_samples_leaf : int, float, optional, default: 3
+         The minimum number of samples required to be at a leaf node.
+         A split point at any depth will only be considered if it leaves at
+         least ``min_samples_leaf`` training samples in each of the left and
+         right branches. This may have the effect of smoothing the model,
+         especially in regression.
+
+         - If int, then consider `min_samples_leaf` as the minimum number.
+         - If float, then `min_samples_leaf` is a fraction and
+           `ceil(min_samples_leaf * n_samples)` are the minimum
+           number of samples for each node.
+
+     min_weight_fraction_leaf : float, optional, default: 0.
+         The minimum weighted fraction of the sum total of weights (of all
+         the input samples) required to be at a leaf node. Samples have
+         equal weight when sample_weight is not provided.
+
+     max_features : int, float, {'sqrt', 'log2'} or None, optional, default: 'sqrt'
+         The number of features to consider when looking for the best split:
+
+         - If int, then consider `max_features` features at each split.
+         - If float, then `max_features` is a fraction and
+           `int(max_features * n_features)` features are considered at each
+           split.
+         - If "sqrt", then `max_features=sqrt(n_features)`.
+         - If "log2", then `max_features=log2(n_features)`.
+         - If None, then `max_features=n_features`.
+
+         Note: the search for a split does not stop until at least one
+         valid partition of the node samples is found, even if it requires
+         effectively inspecting more than ``max_features`` features.
+
+     max_leaf_nodes : int or None, optional, default: None
+         Grow a tree with ``max_leaf_nodes`` in best-first fashion.
+         Best nodes are defined as relative reduction in impurity.
+         If None, then an unlimited number of leaf nodes is allowed.
+
+     bootstrap : bool, optional, default: True
+         Whether bootstrap samples are used when building trees. If False, the
+         whole dataset is used to build each tree.
+
+     oob_score : bool, optional, default: False
+         Whether to use out-of-bag samples to estimate
+         the generalization accuracy.
+
+     n_jobs : int or None, optional, default: None
+         The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
+         :meth:`decision_path` and :meth:`apply` are all parallelized over the
+         trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
+         context. ``-1`` means using all processors.
+
+     random_state : int, RandomState instance or None, optional, default: None
+         Controls both the randomness of the bootstrapping of the samples used
+         when building trees (if ``bootstrap=True``) and the sampling of the
+         features to consider when looking for the best split at each node
+         (if ``max_features < n_features``).
+
+     verbose : int, optional, default: 0
+         Controls the verbosity when fitting and predicting.
+
+     warm_start : bool, optional, default: False
+         When set to ``True``, reuse the solution of the previous call to fit
+         and add more estimators to the ensemble; otherwise, just fit a whole
+         new forest.
+
+     max_samples : int or float, optional, default: None
+         If bootstrap is True, the number of samples to draw from X
+         to train each base estimator.
+
+         - If None (default), then draw `X.shape[0]` samples.
+         - If int, then draw `max_samples` samples.
+         - If float, then draw `max_samples * X.shape[0]` samples. Thus,
+           `max_samples` should be in the interval `(0.0, 1.0]`.
+
+     low_memory : bool, optional, default: False
+         If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
+         and :meth:`predict_survival_function` are not implemented.
+
+     Attributes
+     ----------
+     estimators_ : list of SurvivalTree instances
+         The collection of fitted sub-estimators.
+
+     unique_times_ : ndarray, shape = (n_unique_times,)
+         Unique time points.
+
+     n_features_in_ : int
+         Number of features seen during ``fit``.
+
+     feature_names_in_ : ndarray, shape = (`n_features_in_`,)
+         Names of features seen during ``fit``. Defined only when `X`
+         has feature names that are all strings.
+
+     oob_score_ : float
+         Concordance index of the training dataset obtained
+         using an out-of-bag estimate.
+
+     See also
+     --------
+     sksurv.tree.SurvivalTree
+         A single survival tree.
+     """
+
+     _parameter_constraints = {
+         **BaseForest._parameter_constraints,
+         **SurvivalTree._parameter_constraints,
+     }
+     _parameter_constraints.pop("splitter")
+
+     def __init__(
+         self,
+         n_estimators=100,
+         *,
+         max_depth=None,
+         min_samples_split=6,
+         min_samples_leaf=3,
+         min_weight_fraction_leaf=0.0,
+         max_features="sqrt",
+         max_leaf_nodes=None,
+         bootstrap=True,
+         oob_score=False,
+         n_jobs=None,
+         random_state=None,
+         verbose=0,
+         warm_start=False,
+         max_samples=None,
+         low_memory=False,
+     ):
+         super().__init__(
+             estimator=ExtraSurvivalTree(),
+             n_estimators=n_estimators,
+             estimator_params=(
+                 "max_depth",
+                 "min_samples_split",
+                 "min_samples_leaf",
+                 "min_weight_fraction_leaf",
+                 "max_features",
+                 "max_leaf_nodes",
+                 "random_state",
+                 "low_memory",
+             ),
+             bootstrap=bootstrap,
+             oob_score=oob_score,
+             n_jobs=n_jobs,
+             random_state=random_state,
+             verbose=verbose,
+             warm_start=warm_start,
+             max_samples=max_samples,
+         )
+
+         self.max_depth = max_depth
+         self.min_samples_split = min_samples_split
+         self.min_samples_leaf = min_samples_leaf
+         self.min_weight_fraction_leaf = min_weight_fraction_leaf
+         self.max_features = max_features
+         self.max_leaf_nodes = max_leaf_nodes
+         self.low_memory = low_memory
+
+     @append_cumulative_hazard_example(estimator_mod="ensemble", estimator_class="ExtraSurvivalTrees")
+     def predict_cumulative_hazard_function(self, X, return_array=False):
+         """Predict cumulative hazard function.
+
+         For each tree in the ensemble, the cumulative hazard
+         function (CHF) for an individual with feature vector
+         :math:`x` is computed from all samples of the bootstrap
+         sample that are in the same terminal node as :math:`x`.
+         It is estimated by the Nelson–Aalen estimator.
+         The ensemble CHF at time :math:`t` is the average
+         value across all trees in the ensemble at the
+         specified time point.
+
+         Parameters
+         ----------
+         X : array-like, shape = (n_samples, n_features)
+             Data matrix.
+
+         return_array : bool, default: False
+             Whether to return a single array of cumulative hazard values
+             or a list of step functions.
+
+             If `False`, a list of :class:`sksurv.functions.StepFunction`
+             objects is returned.
+
+             If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+             returned, where `n_unique_times` is the number of unique
+             event times in the training data. Each row represents the cumulative
+             hazard function of an individual evaluated at `unique_times_`.
+
+         Returns
+         -------
+         cum_hazard : ndarray
+             If `return_array` is `False`, an array of `n_samples`
+             :class:`sksurv.functions.StepFunction` instances is returned.
+
+             If `return_array` is `True`, a numeric array of shape
+             `(n_samples, n_unique_times_)` is returned.
+
+         Examples
+         --------
+         """
+         return super().predict_cumulative_hazard_function(X, return_array)
+
+     @append_survival_function_example(estimator_mod="ensemble", estimator_class="ExtraSurvivalTrees")
+     def predict_survival_function(self, X, return_array=False):
+         """Predict survival function.
+
+         For each tree in the ensemble, the survival function
+         for an individual with feature vector :math:`x` is
+         computed from all samples of the bootstrap sample that
+         are in the same terminal node as :math:`x`.
+         It is estimated by the Kaplan–Meier estimator.
+         The ensemble survival function at time :math:`t` is
+         the average value across all trees in the ensemble at
+         the specified time point.
+
+         Parameters
+         ----------
+         X : array-like, shape = (n_samples, n_features)
+             Data matrix.
+
+         return_array : bool, default: False
+             Whether to return a single array of survival probabilities
+             or a list of step functions.
+
+             If `False`, a list of :class:`sksurv.functions.StepFunction`
+             objects is returned.
+
+             If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+             returned, where `n_unique_times` is the number of unique
+             event times in the training data. Each row represents the survival
+             function of an individual evaluated at `unique_times_`.
+
+         Returns
+         -------
+         survival : ndarray
+             If `return_array` is `False`, an array of `n_samples`
+             :class:`sksurv.functions.StepFunction` instances is returned.
+
+             If `return_array` is `True`, a numeric array of shape
+             `(n_samples, n_unique_times_)` is returned.
+
+         Examples
+         --------
+         """
+         return super().predict_survival_function(X, return_array)