scikit_survival-0.23.1-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. scikit_survival-0.23.1.dist-info/COPYING +674 -0
  2. scikit_survival-0.23.1.dist-info/METADATA +888 -0
  3. scikit_survival-0.23.1.dist-info/RECORD +55 -0
  4. scikit_survival-0.23.1.dist-info/WHEEL +5 -0
  5. scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +138 -0
  7. sksurv/base.py +103 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
  10. sksurv/column.py +201 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +10 -0
  13. sksurv/datasets/base.py +436 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  17. sksurv/datasets/data/flchain.arff +7887 -0
  18. sksurv/datasets/data/veteran.arff +148 -0
  19. sksurv/datasets/data/whas500.arff +520 -0
  20. sksurv/ensemble/__init__.py +2 -0
  21. sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
  22. sksurv/ensemble/boosting.py +1610 -0
  23. sksurv/ensemble/forest.py +947 -0
  24. sksurv/ensemble/survival_loss.py +151 -0
  25. sksurv/exceptions.py +18 -0
  26. sksurv/functions.py +114 -0
  27. sksurv/io/__init__.py +2 -0
  28. sksurv/io/arffread.py +58 -0
  29. sksurv/io/arffwrite.py +145 -0
  30. sksurv/kernels/__init__.py +1 -0
  31. sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
  32. sksurv/kernels/clinical.py +328 -0
  33. sksurv/linear_model/__init__.py +3 -0
  34. sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
  35. sksurv/linear_model/aft.py +205 -0
  36. sksurv/linear_model/coxnet.py +543 -0
  37. sksurv/linear_model/coxph.py +618 -0
  38. sksurv/meta/__init__.py +4 -0
  39. sksurv/meta/base.py +35 -0
  40. sksurv/meta/ensemble_selection.py +642 -0
  41. sksurv/meta/stacking.py +349 -0
  42. sksurv/metrics.py +996 -0
  43. sksurv/nonparametric.py +588 -0
  44. sksurv/preprocessing.py +155 -0
  45. sksurv/svm/__init__.py +11 -0
  46. sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
  47. sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
  48. sksurv/svm/minlip.py +606 -0
  49. sksurv/svm/naive_survival_svm.py +221 -0
  50. sksurv/svm/survival_svm.py +1228 -0
  51. sksurv/testing.py +108 -0
  52. sksurv/tree/__init__.py +1 -0
  53. sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
  54. sksurv/tree/tree.py +703 -0
  55. sksurv/util.py +333 -0
sksurv/ensemble/forest.py
@@ -0,0 +1,947 @@
+from abc import ABCMeta, abstractmethod
+from functools import partial
+import threading
+import warnings
+
+from joblib import Parallel, delayed
+import numpy as np
+from sklearn.ensemble._base import _partition_estimators
+from sklearn.ensemble._forest import (
+    BaseForest,
+    _accumulate_prediction,
+    _generate_unsampled_indices,
+    _get_n_samples_bootstrap,
+    _parallel_build_trees,
+)
+from sklearn.tree._tree import DTYPE
+from sklearn.utils._tags import _safe_tags
+from sklearn.utils.validation import check_is_fitted, check_random_state
+
+from ..base import SurvivalAnalysisMixin
+from ..metrics import concordance_index_censored
+from ..tree import ExtraSurvivalTree, SurvivalTree
+from ..tree._criterion import get_unique_times
+from ..tree.tree import _array_to_step_function
+from ..util import check_array_survival
+
+__all__ = ["RandomSurvivalForest", "ExtraSurvivalTrees"]
+
+MAX_INT = np.iinfo(np.int32).max
+
+
+def _more_tags_patch(self):
+    # BaseForest._more_tags calls
+    # type(self.estimator)(criterion=self.criterion),
+    # which is incompatible with LogrankCriterion
+    if isinstance(self, _BaseSurvivalForest):
+        estimator = type(self.estimator)()
+    else:
+        estimator = type(self.estimator)(criterion=self.criterion)
+    return {"allow_nan": _safe_tags(estimator, key="allow_nan")}
+
+
+BaseForest._more_tags = _more_tags_patch
+
+
+class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
+    """
+    Base class for forest-based estimators for survival analysis.
+
+    Warning: This class should not be used directly. Use derived classes
+    instead.
+    """
+
+    @abstractmethod
+    def __init__(
+        self,
+        estimator,
+        n_estimators=100,
+        *,
+        estimator_params=tuple(),
+        bootstrap=False,
+        oob_score=False,
+        n_jobs=None,
+        random_state=None,
+        verbose=0,
+        warm_start=False,
+        max_samples=None,
+    ):
+        super().__init__(
+            estimator,
+            n_estimators=n_estimators,
+            estimator_params=estimator_params,
+            bootstrap=bootstrap,
+            oob_score=oob_score,
+            n_jobs=n_jobs,
+            random_state=random_state,
+            verbose=verbose,
+            warm_start=warm_start,
+            class_weight=None,
+            max_samples=max_samples,
+        )
+
+    @property
+    def feature_importances_(self):
+        """Not implemented"""
+        raise NotImplementedError()
+
+    def fit(self, X, y, sample_weight=None):
+        """Build a forest of survival trees from the training set (X, y).
+
+        Parameters
+        ----------
+        X : array-like, shape = (n_samples, n_features)
+            Data matrix
+
+        y : structured array, shape = (n_samples,)
+            A structured array containing the binary event indicator
+            as first field, and time of event or time of censoring as
+            second field.
+
+        Returns
+        -------
+        self
+        """
+        self._validate_params()
+
+        X = self._validate_data(X, dtype=DTYPE, accept_sparse="csc", ensure_min_samples=2, force_all_finite=False)
+        event, time = check_array_survival(X, y)
+
+        # _compute_missing_values_in_feature_mask checks if X has missing values and
+        # will raise an error if the underlying tree base estimator can't handle missing
+        # values.
+        estimator = type(self.estimator)()
+        missing_values_in_feature_mask = estimator._compute_missing_values_in_feature_mask(
+            X, estimator_name=self.__class__.__name__
+        )
+
+        self.n_features_in_ = X.shape[1]
+        time = time.astype(np.float64)
+        self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
+        self.n_outputs_ = self.unique_times_.shape[0]
+
+        y_numeric = np.empty((X.shape[0], 2), dtype=np.float64)
+        y_numeric[:, 0] = time
+        y_numeric[:, 1] = event.astype(np.float64)
+
+        # Get bootstrap sample size
+        n_samples_bootstrap = _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
+
+        # Check parameters
+        self._validate_estimator()
+
+        if not self.bootstrap and self.oob_score:
+            raise ValueError("Out of bag estimation only available if bootstrap=True")
+
+        random_state = check_random_state(self.random_state)
+
+        if not self.warm_start or not hasattr(self, "estimators_"):
+            # Free allocated memory, if any
+            self.estimators_ = []
+
+        n_more_estimators = self.n_estimators - len(self.estimators_)
+
+        if n_more_estimators < 0:
+            raise ValueError(
+                f"n_estimators={self.n_estimators} must be larger or equal to "
+                f"len(estimators_)={len(self.estimators_)} when warm_start==True"
+            )
+
+        if n_more_estimators == 0:
+            warnings.warn("Warm-start fitting without increasing n_estimators does not fit new trees.", stacklevel=2)
+        else:
+            if self.warm_start and len(self.estimators_) > 0:
+                # We draw from the random state to get the random state we
+                # would have got if we hadn't used a warm_start.
+                random_state.randint(MAX_INT, size=len(self.estimators_))
+
+            trees = [self._make_estimator(append=False, random_state=random_state) for i in range(n_more_estimators)]
+
+            y_tree = (
+                y_numeric,
+                self.unique_times_,
+                self.is_event_time_,
+            )
+            # Parallel loop: we prefer the threading backend as the Cython code
+            # for fitting the trees is internally releasing the Python GIL
+            # making threading more efficient than multiprocessing in
+            # that case. However, for joblib 0.12+ we respect any
+            # parallel_backend contexts set at a higher level,
+            # since correctness does not rely on using threads.
+            trees = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, prefer="threads")(
+                delayed(_parallel_build_trees)(
+                    t,
+                    self.bootstrap,
+                    X,
+                    y_tree,
+                    sample_weight,
+                    i,
+                    len(trees),
+                    verbose=self.verbose,
+                    n_samples_bootstrap=n_samples_bootstrap,
+                    missing_values_in_feature_mask=missing_values_in_feature_mask,
+                )
+                for i, t in enumerate(trees)
+            )
+
+            # Collect newly grown trees
+            self.estimators_.extend(trees)
+
+        if self.oob_score:
+            self._set_oob_score_and_attributes(X, (event, time))
+
+        return self
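The structured `y` that `fit` expects can be built with `sksurv.util.Surv`. A minimal sketch with made-up toy values (the helper `Surv.from_arrays` and the field order match the docstring above):

    import numpy as np

    from sksurv.ensemble import RandomSurvivalForest
    from sksurv.util import Surv

    # toy data: 5 samples, 2 features (illustrative values only)
    X = np.array([[1.0, 2.0], [3.0, 1.0], [2.0, 5.0], [4.0, 2.0], [0.0, 1.0]])
    event = np.array([True, False, True, True, False])  # binary event indicator
    time = np.array([12.0, 30.0, 7.0, 22.0, 18.0])      # event or censoring time

    # event indicator as first field, time as second field
    y = Surv.from_arrays(event=event, time=time)
    forest = RandomSurvivalForest(n_estimators=10, random_state=0).fit(X, y)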
+
+    def _set_oob_score_and_attributes(self, X, y):
+        """Calculate out of bag predictions and score."""
+        n_samples = X.shape[0]
+        event, time = y
+
+        predictions = np.zeros(n_samples)
+        n_predictions = np.zeros(n_samples)
+
+        n_samples_bootstrap = _get_n_samples_bootstrap(n_samples, self.max_samples)
+
+        for estimator in self.estimators_:
+            unsampled_indices = _generate_unsampled_indices(estimator.random_state, n_samples, n_samples_bootstrap)
+            p_estimator = estimator.predict(X[unsampled_indices, :], check_input=False)
+
+            predictions[unsampled_indices] += p_estimator
+            n_predictions[unsampled_indices] += 1
+
+        if (n_predictions == 0).any():
+            warnings.warn(
+                "Some inputs do not have OOB scores. "
+                "This probably means too few trees were used "
+                "to compute any reliable oob estimates.",
+                stacklevel=3,
+            )
+            n_predictions[n_predictions == 0] = 1
+
+        predictions /= n_predictions
+        self.oob_prediction_ = predictions
+
+        self.oob_score_ = concordance_index_censored(event, time, predictions)[0]
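`fit` calls this method when `oob_score=True`, so the out-of-bag concordance index and the averaged per-sample OOB risk scores are available as attributes afterwards. A minimal sketch, reusing the whas500 data from the docstring examples:

    from sksurv.datasets import load_whas500
    from sksurv.ensemble import RandomSurvivalForest

    X, y = load_whas500()
    X = X.astype(float)

    # bootstrap=True (the default) is required for out-of-bag estimates
    forest = RandomSurvivalForest(n_estimators=100, oob_score=True, random_state=0)
    forest.fit(X, y)

    print(forest.oob_score_)       # out-of-bag concordance index
    print(forest.oob_prediction_)  # averaged OOB risk score per training sample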
+
+    def _predict(self, predict_fn, X):
+        check_is_fitted(self, "estimators_")
+        # Check data
+        X = self._validate_X_predict(X)
+
+        # Assign chunk of trees to jobs
+        n_jobs, _, _ = _partition_estimators(self.n_estimators, self.n_jobs)
+
+        # avoid storing the output of every estimator by summing them here
+        if predict_fn == "predict":
+            y_hat = np.zeros((X.shape[0]), dtype=np.float64)
+        else:
+            y_hat = np.zeros((X.shape[0], self.n_outputs_), dtype=np.float64)
+
+        def _get_fn(est, name):
+            fn = getattr(est, name)
+            if name in ("predict_cumulative_hazard_function", "predict_survival_function"):
+                fn = partial(fn, return_array=True)
+            return fn
+
+        # Parallel loop
+        lock = threading.Lock()
+        Parallel(n_jobs=n_jobs, verbose=self.verbose, require="sharedmem")(
+            delayed(_accumulate_prediction)(_get_fn(e, predict_fn), X, [y_hat], lock) for e in self.estimators_
+        )
+
+        y_hat /= len(self.estimators_)
+
+        return y_hat
+
+    def predict(self, X):
+        """Predict risk score.
+
+        The ensemble risk score is the total number of events,
+        which can be estimated by the sum of the estimated
+        ensemble cumulative hazard function :math:`\\hat{H}_e`.
+
+        .. math::
+
+            \\sum_{j=1}^{n} \\hat{H}_e(T_{j} \\mid x) ,
+
+        where :math:`n` denotes the total number of distinct
+        event times in the training data.
+
+        Parameters
+        ----------
+        X : array-like, shape = (n_samples, n_features)
+            Data matrix.
+
+        Returns
+        -------
+        risk_scores : ndarray, shape = (n_samples,)
+            Predicted risk scores.
+        """
+        return self._predict("predict", X)
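The formula above implies that the risk score equals the ensemble CHF summed over the distinct event times. A sketch checking that relationship numerically, assuming `is_event_time_` marks which entries of `unique_times_` are event times (as set in `fit`):

    import numpy as np

    from sksurv.datasets import load_whas500
    from sksurv.ensemble import RandomSurvivalForest

    X, y = load_whas500()
    X = X.astype(float)
    forest = RandomSurvivalForest(n_estimators=10, random_state=0).fit(X, y)

    risk = forest.predict(X.iloc[:5])
    chf = forest.predict_cumulative_hazard_function(X.iloc[:5], return_array=True)

    # sum the CHF at the distinct event times only, per the formula above
    total = chf[:, forest.is_event_time_].sum(axis=1)
    print(np.allclose(risk, total))  # expected to print True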
+
+    def predict_cumulative_hazard_function(self, X, return_array=False):
+        arr = self._predict("predict_cumulative_hazard_function", X)
+        if return_array:
+            return arr
+        return _array_to_step_function(self.unique_times_, arr)
+
+    def predict_survival_function(self, X, return_array=False):
+        arr = self._predict("predict_survival_function", X)
+        if return_array:
+            return arr
+        return _array_to_step_function(self.unique_times_, arr)
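Both helpers either return the raw value matrix or wrap it in step functions over `unique_times_`. A short sketch of the two return types:

    from sksurv.datasets import load_whas500
    from sksurv.ensemble import RandomSurvivalForest

    X, y = load_whas500()
    X = X.astype(float)
    forest = RandomSurvivalForest(n_estimators=10, random_state=0).fit(X, y)

    # return_array=True: one row per sample, one column per unique time point
    surv = forest.predict_survival_function(X.iloc[:3], return_array=True)
    print(surv.shape == (3, forest.unique_times_.shape[0]))  # expected: True

    # default: an array of StepFunction objects, evaluated at their grid points
    fns = forest.predict_survival_function(X.iloc[:3])
    print(fns[0](fns[0].x)[:3])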
+
+
+class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
+    """A random survival forest.
+
+    A random survival forest is a meta estimator that fits a number of
+    survival trees on various sub-samples of the dataset and uses
+    averaging to improve the predictive accuracy and control over-fitting.
+    The sub-sample size is always the same as the original input sample
+    size but the samples are drawn with replacement if
+    `bootstrap=True` (default).
+
+    In each survival tree, the quality of a split is measured by the
+    log-rank splitting rule.
+
+    See the :ref:`User Guide </user_guide/random-survival-forest.ipynb>`,
+    [1]_ and [2]_ for further description.
+
+    Parameters
+    ----------
+    n_estimators : integer, optional, default: 100
+        The number of trees in the forest.
+
+    max_depth : int or None, optional, default: None
+        The maximum depth of the tree. If None, then nodes are expanded until
+        all leaves are pure or until all leaves contain less than
+        min_samples_split samples.
+
+    min_samples_split : int, float, optional, default: 6
+        The minimum number of samples required to split an internal node:
+
+        - If int, then consider `min_samples_split` as the minimum number.
+        - If float, then `min_samples_split` is a fraction and
+          `ceil(min_samples_split * n_samples)` are the minimum
+          number of samples for each split.
+
+    min_samples_leaf : int, float, optional, default: 3
+        The minimum number of samples required to be at a leaf node.
+        A split point at any depth will only be considered if it leaves at
+        least ``min_samples_leaf`` training samples in each of the left and
+        right branches. This may have the effect of smoothing the model,
+        especially in regression.
+
+        - If int, then consider `min_samples_leaf` as the minimum number.
+        - If float, then `min_samples_leaf` is a fraction and
+          `ceil(min_samples_leaf * n_samples)` are the minimum
+          number of samples for each node.
+
+    min_weight_fraction_leaf : float, optional, default: 0.
+        The minimum weighted fraction of the sum total of weights (of all
+        the input samples) required to be at a leaf node. Samples have
+        equal weight when sample_weight is not provided.
+
+    max_features : int, float, string or None, optional, default: "sqrt"
+        The number of features to consider when looking for the best split:
+
+        - If int, then consider `max_features` features at each split.
+        - If float, then `max_features` is a fraction and
+          `int(max_features * n_features)` features are considered at each
+          split.
+        - If "sqrt", then `max_features=sqrt(n_features)`.
+        - If "log2", then `max_features=log2(n_features)`.
+        - If None, then `max_features=n_features`.
+
+        Note: the search for a split does not stop until at least one
+        valid partition of the node samples is found, even if it requires to
+        effectively inspect more than ``max_features`` features.
+
+    max_leaf_nodes : int or None, optional, default: None
+        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
+        Best nodes are defined as relative reduction in impurity.
+        If None then unlimited number of leaf nodes.
+
+    bootstrap : boolean, optional, default: True
+        Whether bootstrap samples are used when building trees. If False, the
+        whole dataset is used to build each tree.
+
+    oob_score : bool, default: False
+        Whether to use out-of-bag samples to estimate
+        the generalization accuracy.
+
+    n_jobs : int or None, optional, default: None
+        The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
+        :meth:`decision_path` and :meth:`apply` are all parallelized over the
+        trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
+        context. ``-1`` means using all processors.
+
+    random_state : int, RandomState instance or None, optional, default: None
+        Controls both the randomness of the bootstrapping of the samples used
+        when building trees (if ``bootstrap=True``) and the sampling of the
+        features to consider when looking for the best split at each node
+        (if ``max_features < n_features``).
+
+    verbose : int, optional, default: 0
+        Controls the verbosity when fitting and predicting.
+
+    warm_start : bool, optional, default: False
+        When set to ``True``, reuse the solution of the previous call to fit
+        and add more estimators to the ensemble, otherwise, just fit a whole
+        new forest.
+
+    max_samples : int or float, optional, default: None
+        If bootstrap is True, the number of samples to draw from X
+        to train each base estimator.
+
+        - If None (default), then draw `X.shape[0]` samples.
+        - If int, then draw `max_samples` samples.
+        - If float, then draw `max_samples * X.shape[0]` samples. Thus,
+          `max_samples` should be in the interval `(0.0, 1.0]`.
+
+    low_memory : boolean, default: False
+        If set, ``predict`` computations use reduced memory but ``predict_cumulative_hazard_function``
+        and ``predict_survival_function`` are not implemented.
+
+    Attributes
+    ----------
+    estimators_ : list of SurvivalTree instances
+        The collection of fitted sub-estimators.
+
+    unique_times_ : array of shape = (n_unique_times,)
+        Unique time points.
+
+    n_features_in_ : int
+        Number of features seen during ``fit``.
+
+    feature_names_in_ : ndarray of shape (`n_features_in_`,)
+        Names of features seen during ``fit``. Defined only when `X`
+        has feature names that are all strings.
+
+    oob_score_ : float
+        Concordance index of the training dataset obtained
+        using an out-of-bag estimate.
+
+    See also
+    --------
+    sksurv.tree.SurvivalTree
+        A single survival tree.
+
+    Notes
+    -----
+    The default values for the parameters controlling the size of the trees
+    (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
+    unpruned trees which can potentially be very large on some data sets. To
+    reduce memory consumption, the complexity and size of the trees should be
+    controlled by setting those parameter values.
+
+    Compared to scikit-learn's random forest models, :class:`RandomSurvivalForest`
+    currently does not support controlling the depth of a tree based on the log-rank
+    test statistic or its associated p-value, i.e., the parameters
+    `min_impurity_decrease` or `min_impurity_split` are absent.
+    In addition, the `feature_importances_` attribute is not available.
+    It is recommended to estimate feature importances via
+    `permutation-based methods <https://eli5.readthedocs.io>`_.
+
+    The features are always randomly permuted at each split. Therefore,
+    the best found split may vary, even with the same training data,
+    ``max_features=n_features`` and ``bootstrap=False``, if the improvement
+    of the criterion is identical for several splits enumerated during the
+    search of the best split. To obtain a deterministic behavior during
+    fitting, ``random_state`` has to be fixed.
+
+    References
+    ----------
+    .. [1] Ishwaran, H., Kogalur, U. B., Blackstone, E. H., & Lauer, M. S. (2008).
+           Random survival forests. The Annals of Applied Statistics, 2(3), 841–860.
+
+    .. [2] Ishwaran, H., Kogalur, U. B. (2007). Random survival forests for R.
+           R News, 7(2), 25–31. https://cran.r-project.org/doc/Rnews/Rnews_2007-2.pdf.
+    """
+
+    _parameter_constraints = {
+        **BaseForest._parameter_constraints,
+        **SurvivalTree._parameter_constraints,
+    }
+    _parameter_constraints.pop("splitter")
+
+    def __init__(
+        self,
+        n_estimators=100,
+        *,
+        max_depth=None,
+        min_samples_split=6,
+        min_samples_leaf=3,
+        min_weight_fraction_leaf=0.0,
+        max_features="sqrt",
+        max_leaf_nodes=None,
+        bootstrap=True,
+        oob_score=False,
+        n_jobs=None,
+        random_state=None,
+        verbose=0,
+        warm_start=False,
+        max_samples=None,
+        low_memory=False,
+    ):
+        super().__init__(
+            estimator=SurvivalTree(),
+            n_estimators=n_estimators,
+            estimator_params=(
+                "max_depth",
+                "min_samples_split",
+                "min_samples_leaf",
+                "min_weight_fraction_leaf",
+                "max_features",
+                "max_leaf_nodes",
+                "random_state",
+                "low_memory",
+            ),
+            bootstrap=bootstrap,
+            oob_score=oob_score,
+            n_jobs=n_jobs,
+            random_state=random_state,
+            verbose=verbose,
+            warm_start=warm_start,
+            max_samples=max_samples,
+        )
+
+        self.max_depth = max_depth
+        self.min_samples_split = min_samples_split
+        self.min_samples_leaf = min_samples_leaf
+        self.min_weight_fraction_leaf = min_weight_fraction_leaf
+        self.max_features = max_features
+        self.max_leaf_nodes = max_leaf_nodes
+        self.low_memory = low_memory
+
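A sketch of growing a forest incrementally with `warm_start`; per the `fit` logic above, adding trees requires raising `n_estimators` between calls:

    from sksurv.datasets import load_whas500
    from sksurv.ensemble import RandomSurvivalForest

    X, y = load_whas500()
    X = X.astype(float)

    forest = RandomSurvivalForest(n_estimators=50, warm_start=True, random_state=0)
    forest.fit(X, y)
    print(len(forest.estimators_))  # 50

    # raising n_estimators and refitting appends 50 new trees,
    # reusing the 50 already grown
    forest.set_params(n_estimators=100)
    forest.fit(X, y)
    print(len(forest.estimators_))  # 100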
+    def predict_cumulative_hazard_function(self, X, return_array=False):
+        """Predict cumulative hazard function.
+
+        For each tree in the ensemble, the cumulative hazard
+        function (CHF) for an individual with feature vector
+        :math:`x` is computed from all samples of the bootstrap
+        sample that are in the same terminal node as :math:`x`.
+        It is estimated by the Nelson–Aalen estimator.
+        The ensemble CHF at time :math:`t` is the average
+        value across all trees in the ensemble at the
+        specified time point.
+
+        Parameters
+        ----------
+        X : array-like, shape = (n_samples, n_features)
+            Data matrix.
+
+        return_array : boolean, default: False
+            If set, return an array with the cumulative hazard rate
+            for each `self.unique_times_`, otherwise an array of
+            :class:`sksurv.functions.StepFunction`.
+
+        Returns
+        -------
+        cum_hazard : ndarray
+            If `return_array` is set, an array with the cumulative hazard rate
+            for each `self.unique_times_`, otherwise an array of length `n_samples`
+            of :class:`sksurv.functions.StepFunction` instances will be returned.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sksurv.datasets import load_whas500
+        >>> from sksurv.ensemble import RandomSurvivalForest
+
+        Load and prepare the data.
+
+        >>> X, y = load_whas500()
+        >>> X = X.astype(float)
+
+        Fit the model.
+
+        >>> estimator = RandomSurvivalForest().fit(X, y)
+
+        Estimate the cumulative hazard function for the first 5 samples.
+
+        >>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])
+
+        Plot the estimated cumulative hazard functions.
+
+        >>> for fn in chf_funcs:
+        ...     plt.step(fn.x, fn(fn.x), where="post")
+        ...
+        >>> plt.ylim(0, 1)
+        >>> plt.show()
+        """
+        return super().predict_cumulative_hazard_function(X, return_array)
+
+    def predict_survival_function(self, X, return_array=False):
+        """Predict survival function.
+
+        For each tree in the ensemble, the survival function
+        for an individual with feature vector :math:`x` is
+        computed from all samples of the bootstrap sample that
+        are in the same terminal node as :math:`x`.
+        It is estimated by the Kaplan-Meier estimator.
+        The ensemble survival function at time :math:`t` is
+        the average value across all trees in the ensemble at
+        the specified time point.
+
+        Parameters
+        ----------
+        X : array-like, shape = (n_samples, n_features)
+            Data matrix.
+
+        return_array : boolean, default: False
+            If set, return an array with the probability
+            of survival for each `self.unique_times_`,
+            otherwise an array of :class:`sksurv.functions.StepFunction`.
+
+        Returns
+        -------
+        survival : ndarray
+            If `return_array` is set, an array with the probability
+            of survival for each `self.unique_times_`,
+            otherwise an array of :class:`sksurv.functions.StepFunction`
+            will be returned.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sksurv.datasets import load_whas500
+        >>> from sksurv.ensemble import RandomSurvivalForest
+
+        Load and prepare the data.
+
+        >>> X, y = load_whas500()
+        >>> X = X.astype(float)
+
+        Fit the model.
+
+        >>> estimator = RandomSurvivalForest().fit(X, y)
+
+        Estimate the survival function for the first 5 samples.
+
+        >>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])
+
+        Plot the estimated survival functions.
+
+        >>> for fn in surv_funcs:
+        ...     plt.step(fn.x, fn(fn.x), where="post")
+        ...
+        >>> plt.ylim(0, 1)
+        >>> plt.show()
+        """
+        return super().predict_survival_function(X, return_array)
+
+
+class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
+    """An extremely random survival forest.
+
+    This class implements a meta estimator that fits a number of randomized
+    survival trees (a.k.a. extra-trees) on various sub-samples of the dataset
+    and uses averaging to improve the predictive accuracy and control
+    over-fitting. The sub-sample size is always the same as the original
+    input sample size but the samples are drawn with replacement if
+    `bootstrap=True` (default).
+
+    In each randomized survival tree, the quality of a split is measured by
+    the log-rank splitting rule.
+
+    Compared to :class:`RandomSurvivalForest`, randomness goes one step
+    further in the way splits are computed. As in
+    :class:`RandomSurvivalForest`, a random subset of candidate features is
+    used, but instead of looking for the most discriminative thresholds,
+    thresholds are drawn at random for each candidate feature and the best of
+    these randomly-generated thresholds is picked as the splitting rule.
+
+    Parameters
+    ----------
+    n_estimators : integer, optional, default: 100
+        The number of trees in the forest.
+
+    max_depth : int or None, optional, default: None
+        The maximum depth of the tree. If None, then nodes are expanded until
+        all leaves are pure or until all leaves contain less than
+        min_samples_split samples.
+
+    min_samples_split : int, float, optional, default: 6
+        The minimum number of samples required to split an internal node:
+
+        - If int, then consider `min_samples_split` as the minimum number.
+        - If float, then `min_samples_split` is a fraction and
+          `ceil(min_samples_split * n_samples)` are the minimum
+          number of samples for each split.
+
+    min_samples_leaf : int, float, optional, default: 3
+        The minimum number of samples required to be at a leaf node.
+        A split point at any depth will only be considered if it leaves at
+        least ``min_samples_leaf`` training samples in each of the left and
+        right branches. This may have the effect of smoothing the model,
+        especially in regression.
+
+        - If int, then consider `min_samples_leaf` as the minimum number.
+        - If float, then `min_samples_leaf` is a fraction and
+          `ceil(min_samples_leaf * n_samples)` are the minimum
+          number of samples for each node.
+
+    min_weight_fraction_leaf : float, optional, default: 0.
+        The minimum weighted fraction of the sum total of weights (of all
+        the input samples) required to be at a leaf node. Samples have
+        equal weight when sample_weight is not provided.
+
+    max_features : int, float, string or None, optional, default: "sqrt"
+        The number of features to consider when looking for the best split:
+
+        - If int, then consider `max_features` features at each split.
+        - If float, then `max_features` is a fraction and
+          `int(max_features * n_features)` features are considered at each
+          split.
+        - If "sqrt", then `max_features=sqrt(n_features)`.
+        - If "log2", then `max_features=log2(n_features)`.
+        - If None, then `max_features=n_features`.
+
+        Note: the search for a split does not stop until at least one
+        valid partition of the node samples is found, even if it requires to
+        effectively inspect more than ``max_features`` features.
+
+    max_leaf_nodes : int or None, optional, default: None
+        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
+        Best nodes are defined as relative reduction in impurity.
+        If None then unlimited number of leaf nodes.
+
+    bootstrap : boolean, optional, default: True
+        Whether bootstrap samples are used when building trees. If False, the
+        whole dataset is used to build each tree.
+
+    oob_score : bool, default: False
+        Whether to use out-of-bag samples to estimate
+        the generalization accuracy.
+
+    n_jobs : int or None, optional, default: None
+        The number of jobs to run in parallel. :meth:`fit`, :meth:`predict`,
+        :meth:`decision_path` and :meth:`apply` are all parallelized over the
+        trees. ``None`` means 1 unless in a :obj:`joblib.parallel_backend`
+        context. ``-1`` means using all processors.
+
+    random_state : int, RandomState instance or None, optional, default: None
+        Controls both the randomness of the bootstrapping of the samples used
+        when building trees (if ``bootstrap=True``) and the sampling of the
+        features to consider when looking for the best split at each node
+        (if ``max_features < n_features``).
+
+    verbose : int, optional, default: 0
+        Controls the verbosity when fitting and predicting.
+
+    warm_start : bool, optional, default: False
+        When set to ``True``, reuse the solution of the previous call to fit
+        and add more estimators to the ensemble, otherwise, just fit a whole
+        new forest.
+
+    max_samples : int or float, optional, default: None
+        If bootstrap is True, the number of samples to draw from X
+        to train each base estimator.
+
+        - If None (default), then draw `X.shape[0]` samples.
+        - If int, then draw `max_samples` samples.
+        - If float, then draw `max_samples * X.shape[0]` samples. Thus,
+          `max_samples` should be in the interval `(0.0, 1.0]`.
+
+    low_memory : boolean, default: False
+        If set, ``predict`` computations use reduced memory but ``predict_cumulative_hazard_function``
+        and ``predict_survival_function`` are not implemented.
+
+    Attributes
+    ----------
+    estimators_ : list of SurvivalTree instances
+        The collection of fitted sub-estimators.
+
+    unique_times_ : array of shape = (n_unique_times,)
+        Unique time points.
+
+    n_features_in_ : int
+        The number of features when ``fit`` is performed.
+
+    feature_names_in_ : ndarray of shape (`n_features_in_`,)
+        Names of features seen during ``fit``. Defined only when `X`
+        has feature names that are all strings.
+
+    oob_score_ : float
+        Concordance index of the training dataset obtained
+        using an out-of-bag estimate.
+
+    See also
+    --------
+    sksurv.tree.SurvivalTree
+        A single survival tree.
+    """
+
+    _parameter_constraints = {
+        **BaseForest._parameter_constraints,
+        **SurvivalTree._parameter_constraints,
+    }
+    _parameter_constraints.pop("splitter")
+
+    def __init__(
+        self,
+        n_estimators=100,
+        *,
+        max_depth=None,
+        min_samples_split=6,
+        min_samples_leaf=3,
+        min_weight_fraction_leaf=0.0,
+        max_features="sqrt",
+        max_leaf_nodes=None,
+        bootstrap=True,
+        oob_score=False,
+        n_jobs=None,
+        random_state=None,
+        verbose=0,
+        warm_start=False,
+        max_samples=None,
+        low_memory=False,
+    ):
+        super().__init__(
+            estimator=ExtraSurvivalTree(),
+            n_estimators=n_estimators,
+            estimator_params=(
+                "max_depth",
+                "min_samples_split",
+                "min_samples_leaf",
+                "min_weight_fraction_leaf",
+                "max_features",
+                "max_leaf_nodes",
+                "random_state",
+                "low_memory",
+            ),
+            bootstrap=bootstrap,
+            oob_score=oob_score,
+            n_jobs=n_jobs,
+            random_state=random_state,
+            verbose=verbose,
+            warm_start=warm_start,
+            max_samples=max_samples,
+        )
+
+        self.max_depth = max_depth
+        self.min_samples_split = min_samples_split
+        self.min_samples_leaf = min_samples_leaf
+        self.min_weight_fraction_leaf = min_weight_fraction_leaf
+        self.max_features = max_features
+        self.max_leaf_nodes = max_leaf_nodes
+        self.low_memory = low_memory
+
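A sketch contrasting the two ensembles on the same data; as described above, `ExtraSurvivalTrees` draws split thresholds at random instead of searching for the best one. `score` (inherited from `SurvivalAnalysisMixin`) returns the concordance index:

    from sksurv.datasets import load_whas500
    from sksurv.ensemble import ExtraSurvivalTrees, RandomSurvivalForest

    X, y = load_whas500()
    X = X.astype(float)

    rsf = RandomSurvivalForest(n_estimators=100, random_state=0).fit(X, y)
    ext = ExtraSurvivalTrees(n_estimators=100, random_state=0).fit(X, y)

    # concordance index on the training data for both ensembles
    print(rsf.score(X, y))
    print(ext.score(X, y))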
+    def predict_cumulative_hazard_function(self, X, return_array=False):
+        """Predict cumulative hazard function.
+
+        For each tree in the ensemble, the cumulative hazard
+        function (CHF) for an individual with feature vector
+        :math:`x` is computed from all samples of the bootstrap
+        sample that are in the same terminal node as :math:`x`.
+        It is estimated by the Nelson–Aalen estimator.
+        The ensemble CHF at time :math:`t` is the average
+        value across all trees in the ensemble at the
+        specified time point.
+
+        Parameters
+        ----------
+        X : array-like, shape = (n_samples, n_features)
+            Data matrix.
+
+        return_array : boolean, default: False
+            If set, return an array with the cumulative hazard rate
+            for each `self.unique_times_`, otherwise an array of
+            :class:`sksurv.functions.StepFunction`.
+
+        Returns
+        -------
+        cum_hazard : ndarray
+            If `return_array` is set, an array with the cumulative hazard rate
+            for each `self.unique_times_`, otherwise an array of length `n_samples`
+            of :class:`sksurv.functions.StepFunction` instances will be returned.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sksurv.datasets import load_whas500
+        >>> from sksurv.ensemble import ExtraSurvivalTrees
+
+        Load and prepare the data.
+
+        >>> X, y = load_whas500()
+        >>> X = X.astype(float)
+
+        Fit the model.
+
+        >>> estimator = ExtraSurvivalTrees().fit(X, y)
+
+        Estimate the cumulative hazard function for the first 5 samples.
+
+        >>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])
+
+        Plot the estimated cumulative hazard functions.
+
+        >>> for fn in chf_funcs:
+        ...     plt.step(fn.x, fn(fn.x), where="post")
+        ...
+        >>> plt.ylim(0, 1)
+        >>> plt.show()
+        """
+        return super().predict_cumulative_hazard_function(X, return_array)
+
+    def predict_survival_function(self, X, return_array=False):
+        """Predict survival function.
+
+        For each tree in the ensemble, the survival function
+        for an individual with feature vector :math:`x` is
+        computed from all samples of the bootstrap sample that
+        are in the same terminal node as :math:`x`.
+        It is estimated by the Kaplan-Meier estimator.
+        The ensemble survival function at time :math:`t` is
+        the average value across all trees in the ensemble at
+        the specified time point.
+
+        Parameters
+        ----------
+        X : array-like, shape = (n_samples, n_features)
+            Data matrix.
+
+        return_array : boolean, default: False
+            If set, return an array with the probability
+            of survival for each `self.unique_times_`,
+            otherwise an array of :class:`sksurv.functions.StepFunction`.
+
+        Returns
+        -------
+        survival : ndarray
+            If `return_array` is set, an array with the probability of
+            survival for each `self.unique_times_`, otherwise an array of
+            length `n_samples` of :class:`sksurv.functions.StepFunction`
+            instances will be returned.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sksurv.datasets import load_whas500
+        >>> from sksurv.ensemble import ExtraSurvivalTrees
+
+        Load and prepare the data.
+
+        >>> X, y = load_whas500()
+        >>> X = X.astype(float)
+
+        Fit the model.
+
+        >>> estimator = ExtraSurvivalTrees().fit(X, y)
+
+        Estimate the survival function for the first 5 samples.
+
+        >>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])
+
+        Plot the estimated survival functions.
+
+        >>> for fn in surv_funcs:
+        ...     plt.step(fn.x, fn(fn.x), where="post")
+        ...
+        >>> plt.ylim(0, 1)
+        >>> plt.show()
+        """
+        return super().predict_survival_function(X, return_array)