scikit-survival 0.25.0 (cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. scikit_survival-0.25.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.25.0.dist-info/RECORD +58 -0
  3. scikit_survival-0.25.0.dist-info/WHEEL +6 -0
  4. scikit_survival-0.25.0.dist-info/licenses/COPYING +674 -0
  5. scikit_survival-0.25.0.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +183 -0
  7. sksurv/base.py +115 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cpython-313-x86_64-linux-gnu.so +0 -0
  10. sksurv/column.py +205 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +12 -0
  13. sksurv/datasets/base.py +614 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/bmt.arff +46 -0
  17. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  18. sksurv/datasets/data/cgvhd.arff +118 -0
  19. sksurv/datasets/data/flchain.arff +7887 -0
  20. sksurv/datasets/data/veteran.arff +148 -0
  21. sksurv/datasets/data/whas500.arff +520 -0
  22. sksurv/docstrings.py +99 -0
  23. sksurv/ensemble/__init__.py +2 -0
  24. sksurv/ensemble/_coxph_loss.cpython-313-x86_64-linux-gnu.so +0 -0
  25. sksurv/ensemble/boosting.py +1564 -0
  26. sksurv/ensemble/forest.py +902 -0
  27. sksurv/ensemble/survival_loss.py +151 -0
  28. sksurv/exceptions.py +18 -0
  29. sksurv/functions.py +114 -0
  30. sksurv/io/__init__.py +2 -0
  31. sksurv/io/arffread.py +89 -0
  32. sksurv/io/arffwrite.py +181 -0
  33. sksurv/kernels/__init__.py +1 -0
  34. sksurv/kernels/_clinical_kernel.cpython-313-x86_64-linux-gnu.so +0 -0
  35. sksurv/kernels/clinical.py +348 -0
  36. sksurv/linear_model/__init__.py +3 -0
  37. sksurv/linear_model/_coxnet.cpython-313-x86_64-linux-gnu.so +0 -0
  38. sksurv/linear_model/aft.py +208 -0
  39. sksurv/linear_model/coxnet.py +592 -0
  40. sksurv/linear_model/coxph.py +637 -0
  41. sksurv/meta/__init__.py +4 -0
  42. sksurv/meta/base.py +35 -0
  43. sksurv/meta/ensemble_selection.py +724 -0
  44. sksurv/meta/stacking.py +370 -0
  45. sksurv/metrics.py +1028 -0
  46. sksurv/nonparametric.py +911 -0
  47. sksurv/preprocessing.py +183 -0
  48. sksurv/svm/__init__.py +11 -0
  49. sksurv/svm/_minlip.cpython-313-x86_64-linux-gnu.so +0 -0
  50. sksurv/svm/_prsvm.cpython-313-x86_64-linux-gnu.so +0 -0
  51. sksurv/svm/minlip.py +690 -0
  52. sksurv/svm/naive_survival_svm.py +249 -0
  53. sksurv/svm/survival_svm.py +1236 -0
  54. sksurv/testing.py +108 -0
  55. sksurv/tree/__init__.py +1 -0
  56. sksurv/tree/_criterion.cpython-313-x86_64-linux-gnu.so +0 -0
  57. sksurv/tree/tree.py +790 -0
  58. sksurv/util.py +415 -0
sksurv/tree/tree.py ADDED
@@ -0,0 +1,790 @@
+from math import ceil
+from numbers import Integral, Real
+
+import numpy as np
+from scipy.sparse import issparse
+from sklearn.base import BaseEstimator
+from sklearn.tree import _tree
+from sklearn.tree._classes import DENSE_SPLITTERS, SPARSE_SPLITTERS
+from sklearn.tree._splitter import Splitter
+from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
+from sklearn.tree._utils import _any_isnan_axis0
+from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions
+from sklearn.utils.validation import (
+    _assert_all_finite_element_wise,
+    _check_n_features,
+    assert_all_finite,
+    check_is_fitted,
+    check_random_state,
+    validate_data,
+)
+
+from ..base import SurvivalAnalysisMixin
+from ..docstrings import append_cumulative_hazard_example, append_survival_function_example
+from ..functions import StepFunction
+from ..util import check_array_survival
+from ._criterion import LogrankCriterion, get_unique_times
+
+__all__ = ["ExtraSurvivalTree", "SurvivalTree"]
+
+DTYPE = _tree.DTYPE
+
+
+def _array_to_step_function(x, array):
+    n_samples = array.shape[0]
+    funcs = np.empty(n_samples, dtype=np.object_)
+    for i in range(n_samples):
+        funcs[i] = StepFunction(x=x, y=array[i])
+    return funcs
+
+
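For orientation, here is a minimal sketch of what `_array_to_step_function` produces, using a hand-made time grid and values (both toy assumptions, not package data). Each row of the 2d array becomes one callable step function:

    import numpy as np
    from sksurv.functions import StepFunction

    times = np.array([1.0, 3.0, 7.0])
    chf_values = np.array([[0.1, 0.4, 0.9],
                           [0.2, 0.2, 0.5]])
    # one callable step function per row, exactly as the helper builds them
    funcs = [StepFunction(x=times, y=row) for row in chf_values]
    # evaluable at arbitrary points within [times[0], times[-1]]
    print(funcs[0](np.array([1.0, 5.0])))  # -> [0.1 0.4]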
+class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
+    """A single survival tree.
+
+    The quality of a split is measured by the log-rank splitting rule.
+
+    If ``splitter='best'``, the fit and predict methods support
+    missing values. See :ref:`tree_missing_value_support` for details.
+
+    See [1]_, [2]_ and [3]_ for further description.
+
+    Parameters
+    ----------
+    splitter : {'best', 'random'}, default: 'best'
+        The strategy used to choose the split at each node. Supported
+        strategies are 'best' to choose the best split and 'random' to choose
+        the best random split.
+
+    max_depth : int or None, optional, default: None
+        The maximum depth of the tree. If None, then nodes are expanded until
+        all leaves are pure or until all leaves contain fewer than
+        `min_samples_split` samples.
+
+    min_samples_split : int, float, optional, default: 6
+        The minimum number of samples required to split an internal node:
+
+        - If int, then consider `min_samples_split` as the minimum number.
+        - If float, then `min_samples_split` is a fraction and
+          `ceil(min_samples_split * n_samples)` is the minimum
+          number of samples for each split.
+
+    min_samples_leaf : int, float, optional, default: 3
+        The minimum number of samples required to be at a leaf node.
+        A split point at any depth will only be considered if it leaves at
+        least ``min_samples_leaf`` training samples in each of the left and
+        right branches. This may have the effect of smoothing the model.
+
+        - If int, then consider `min_samples_leaf` as the minimum number.
+        - If float, then `min_samples_leaf` is a fraction and
+          `ceil(min_samples_leaf * n_samples)` is the minimum
+          number of samples for each node.
+
+    min_weight_fraction_leaf : float, optional, default: 0.
+        The minimum weighted fraction of the sum total of weights (of all
+        the input samples) required to be at a leaf node. Samples have
+        equal weight when sample_weight is not provided.
+
+    max_features : int, float or {'sqrt', 'log2'} or None, optional, default: None
+        The number of features to consider when looking for the best split:
+
+        - If int, then consider `max_features` features at each split.
+        - If float, then `max_features` is a fraction and
+          `max(1, int(max_features * n_features_in_))` features are considered at
+          each split.
+        - If "sqrt", then `max_features=sqrt(n_features)`.
+        - If "log2", then `max_features=log2(n_features)`.
+        - If None, then `max_features=n_features`.
+
+        Note: the search for a split does not stop until at least one
+        valid partition of the node samples is found, even if it requires
+        effectively inspecting more than ``max_features`` features.
+
+    random_state : int, RandomState instance or None, optional, default: None
+        Controls the randomness of the estimator. The features are always
+        randomly permuted at each split, even if ``splitter`` is set to
+        ``"best"``. When ``max_features < n_features``, the algorithm will
+        select ``max_features`` at random at each split before finding the best
+        split among them. But the best found split may vary across different
+        runs, even if ``max_features=n_features``. That is the case if the
+        improvement of the criterion is identical for several splits and one
+        split has to be selected at random. To obtain deterministic behavior
+        during fitting, ``random_state`` has to be fixed to an integer.
+
+    max_leaf_nodes : int or None, optional, default: None
+        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
+        Best nodes are defined as relative reduction in impurity.
+        If None, then the number of leaf nodes is unlimited.
+
+    low_memory : bool, optional, default: False
+        If set, :meth:`predict` computations use reduced memory, but :meth:`predict_cumulative_hazard_function`
+        and :meth:`predict_survival_function` are not implemented.
+
+    Attributes
+    ----------
+    unique_times_ : ndarray, shape = (n_unique_times,), dtype = float
+        Unique time points.
+
+    max_features_ : int
+        The inferred value of max_features.
+
+    n_features_in_ : int
+        Number of features seen during ``fit``.
+
+    feature_names_in_ : ndarray, shape = (`n_features_in_`,), dtype = object
+        Names of features seen during ``fit``. Defined only when `X`
+        has feature names that are all strings.
+
+    tree_ : Tree object
+        The underlying Tree object. Please refer to
+        ``help(sklearn.tree._tree.Tree)`` for attributes of the Tree object.
+
+    See also
+    --------
+    sksurv.ensemble.RandomSurvivalForest : An ensemble of SurvivalTrees.
+
+    References
+    ----------
+    .. [1] Leblanc, M., & Crowley, J. (1993). Survival Trees by Goodness of Split.
+           Journal of the American Statistical Association, 88(422), 457–467.
+
+    .. [2] Ishwaran, H., Kogalur, U. B., Blackstone, E. H., & Lauer, M. S. (2008).
+           Random survival forests. The Annals of Applied Statistics, 2(3), 841–860.
+
+    .. [3] Ishwaran, H., Kogalur, U. B. (2007). Random survival forests for R.
+           R News, 7(2), 25–31. https://cran.r-project.org/doc/Rnews/Rnews_2007-2.pdf.
+    """
+
+    _parameter_constraints = {
+        "splitter": [StrOptions({"best", "random"})],
+        "max_depth": [Interval(Integral, 1, None, closed="left"), None],
+        "min_samples_split": [
+            Interval(Integral, 2, None, closed="left"),
+            Interval(RealNotInt, 0.0, 1.0, closed="neither"),
+        ],
+        "min_samples_leaf": [
+            Interval(Integral, 1, None, closed="left"),
+            Interval(RealNotInt, 0.0, 0.5, closed="right"),
+        ],
+        "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
+        "max_features": [
+            Interval(Integral, 1, None, closed="left"),
+            Interval(RealNotInt, 0.0, 1.0, closed="right"),
+            StrOptions({"sqrt", "log2"}),
+            None,
+        ],
+        "random_state": ["random_state"],
+        "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
+        "low_memory": ["boolean"],
+    }
+
+    criterion = "logrank"
+
+    def __init__(
+        self,
+        *,
+        splitter="best",
+        max_depth=None,
+        min_samples_split=6,
+        min_samples_leaf=3,
+        min_weight_fraction_leaf=0.0,
+        max_features=None,
+        random_state=None,
+        max_leaf_nodes=None,
+        low_memory=False,
+    ):
+        self.splitter = splitter
+        self.max_depth = max_depth
+        self.min_samples_split = min_samples_split
+        self.min_samples_leaf = min_samples_leaf
+        self.min_weight_fraction_leaf = min_weight_fraction_leaf
+        self.max_features = max_features
+        self.random_state = random_state
+        self.max_leaf_nodes = max_leaf_nodes
+        self.low_memory = low_memory
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags.input_tags.allow_nan = self.splitter in ("best", "random")
+        return tags
+
+    def _support_missing_values(self, X):
+        return not issparse(X) and self.__sklearn_tags__().input_tags.allow_nan
+
+    def _compute_missing_values_in_feature_mask(self, X, estimator_name=None):
+        """Return a boolean mask denoting if there are missing values for each feature.
+
+        This method also ensures that X is finite.
+
+        Parameters
+        ----------
+        X : array-like, shape = (n_samples, n_features), dtype = DOUBLE
+            Input data.
+
+        estimator_name : str or None, default=None
+            Name to use when raising an error. Defaults to the class name.
+
+        Returns
+        -------
+        missing_values_in_feature_mask : ndarray of shape (n_features,), or None
+            Missing value mask. If missing values are not supported or there
+            are no missing values, return None.
+        """
+        estimator_name = estimator_name or self.__class__.__name__
+        common_kwargs = dict(estimator_name=estimator_name, input_name="X")
+
+        if not self._support_missing_values(X):
+            assert_all_finite(X, **common_kwargs)
+            return None
+
+        with np.errstate(over="ignore"):
+            overall_sum = np.sum(X)
+
+        if not np.isfinite(overall_sum):
+            # Raise a ValueError in case of the presence of an infinite element.
+            _assert_all_finite_element_wise(X, xp=np, allow_nan=True, **common_kwargs)
+
+        # If the sum is not NaN, then there are no missing values.
+        if not np.isnan(overall_sum):
+            return None
+
+        missing_values_in_feature_mask = _any_isnan_axis0(X)
+        return missing_values_in_feature_mask
+
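The detection trick above is worth spelling out: a single `np.sum` over the matrix classifies it before any per-feature work is done, since one NaN anywhere makes the whole sum NaN. A plain-numpy sketch, with `np.isnan(X).any(axis=0)` standing in for the private `_any_isnan_axis0` helper and a made-up matrix:

    import numpy as np

    X = np.array([[1.0, np.nan], [2.0, 3.0]], dtype=np.float32)
    with np.errstate(over="ignore"):
        overall_sum = np.sum(X)         # NaN if any element is NaN
    if np.isnan(overall_sum):
        mask = np.isnan(X).any(axis=0)  # per-feature missing-value mask
        print(mask)                     # -> [False  True]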
+    def fit(self, X, y, sample_weight=None, check_input=True):
+        """Build a survival tree from the training set (X, y).
+
+        If ``splitter='best'``, `X` is allowed to contain missing
+        values. In addition to evaluating each potential threshold on
+        the non-missing data, the splitter will evaluate the split
+        with all the missing values going to the left node or the
+        right node. See :ref:`tree_missing_value_support` for details.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            Data matrix.
+
+        y : structured array, shape = (n_samples,)
+            A structured array with two fields. The first field is a boolean
+            where ``True`` indicates an event and ``False`` indicates right-censoring.
+            The second field is a float with the time of event or time of censoring.
+
+        sample_weight : array-like, shape = (n_samples,), optional
+            Sample weights. If None, all samples are weighted equally.
+
+        check_input : boolean, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        self
+        """
+        self._fit(X, y, sample_weight, check_input)
+        return self
+
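A minimal fit/predict sketch, assuming the veteran lung cancer data that ships with the package (restricted to its numeric columns to keep the example short); the commented `Surv.from_arrays` call shows how a structured `y` can be built from scratch:

    from sksurv.datasets import load_veterans_lung_cancer
    from sksurv.tree import SurvivalTree
    from sksurv.util import Surv

    X, y = load_veterans_lung_cancer()
    X = X.select_dtypes("number")  # drop categorical columns for brevity
    # y is already structured; an equivalent array can be built by hand:
    # y = Surv.from_arrays(event=[True, False], time=[12.0, 7.5])
    tree = SurvivalTree(max_depth=3, random_state=0)
    tree.fit(X, y)
    print(tree.predict(X.iloc[:5]))  # higher score = higher risk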
+    def _fit(self, X, y, sample_weight=None, check_input=True, missing_values_in_feature_mask=None):
+        random_state = check_random_state(self.random_state)
+
+        if check_input:
+            X = validate_data(self, X, dtype=DTYPE, ensure_min_samples=2, accept_sparse="csc", ensure_all_finite=False)
+            event, time = check_array_survival(X, y)
+            time = time.astype(np.float64)
+            self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
+            missing_values_in_feature_mask = self._compute_missing_values_in_feature_mask(X)
+            if issparse(X):
+                X.sort_indices()
+
+            y_numeric = np.empty((X.shape[0], 2), dtype=np.float64)
+            y_numeric[:, 0] = time
+            y_numeric[:, 1] = event.astype(np.float64)
+        else:
+            y_numeric, self.unique_times_, self.is_event_time_ = y
+
+        n_samples, self.n_features_in_ = X.shape
+        params = self._check_params(n_samples)
+
+        if self.low_memory:
+            self.n_outputs_ = 1
+            # one "class" only, for the sum over the CHF
+            self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp)
+        else:
+            self.n_outputs_ = self.unique_times_.shape[0]
+            # one "class" for CHF, one for survival function
+            self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2
+
+        # Build tree
+        criterion = LogrankCriterion(self.n_outputs_, n_samples, self.unique_times_, self.is_event_time_)
+
+        SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
+
+        splitter = self.splitter
+        if not isinstance(self.splitter, Splitter):
+            splitter = SPLITTERS[self.splitter](
+                criterion,
+                self.max_features_,
+                params["min_samples_leaf"],
+                params["min_weight_leaf"],
+                random_state,
+                None,  # monotonic_cst
+            )
+
+        self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_)
+
+        # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
+        if params["max_leaf_nodes"] < 0:
+            builder = DepthFirstTreeBuilder(
+                splitter,
+                params["min_samples_split"],
+                params["min_samples_leaf"],
+                params["min_weight_leaf"],
+                params["max_depth"],
+                0.0,  # min_impurity_decrease
+            )
+        else:
+            builder = BestFirstTreeBuilder(
+                splitter,
+                params["min_samples_split"],
+                params["min_samples_leaf"],
+                params["min_weight_leaf"],
+                params["max_depth"],
+                params["max_leaf_nodes"],
+                0.0,  # min_impurity_decrease
+            )
+
+        builder.build(self.tree_, X, y_numeric, sample_weight, missing_values_in_feature_mask)
+
+        return self
+
+    def _check_params(self, n_samples):
+        self._validate_params()
+
+        # Check parameters
+        max_depth = (2**31) - 1 if self.max_depth is None else self.max_depth
+
+        max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes
+
+        if isinstance(self.min_samples_leaf, Integral):
+            min_samples_leaf = self.min_samples_leaf
+        else:  # float
+            min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))
+
+        if isinstance(self.min_samples_split, Integral):
+            min_samples_split = self.min_samples_split
+        else:  # float
+            min_samples_split = int(ceil(self.min_samples_split * n_samples))
+            min_samples_split = max(2, min_samples_split)
+
+        min_samples_split = max(min_samples_split, 2 * min_samples_leaf)
+
+        self._check_max_features()
+
+        min_weight_leaf = self.min_weight_fraction_leaf * n_samples
+
+        return {
+            "max_depth": max_depth,
+            "max_leaf_nodes": max_leaf_nodes,
+            "min_samples_leaf": min_samples_leaf,
+            "min_samples_split": min_samples_split,
+            "min_weight_leaf": min_weight_leaf,
+        }
+
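To make the fraction handling above concrete, a small arithmetic sketch with toy values (n_samples=100, min_samples_leaf=0.05, min_samples_split=0.1):

    from math import ceil

    n_samples = 100
    min_samples_leaf = int(ceil(0.05 * n_samples))           # -> 5
    min_samples_split = max(2, int(ceil(0.1 * n_samples)))   # -> 10
    # a valid split must leave room for two leaves:
    min_samples_split = max(min_samples_split, 2 * min_samples_leaf)
    print(min_samples_leaf, min_samples_split)               # -> 5 10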
+    def _check_max_features(self):
+        if isinstance(self.max_features, str):
+            if self.max_features == "sqrt":
+                max_features = max(1, int(np.sqrt(self.n_features_in_)))
+            elif self.max_features == "log2":
+                max_features = max(1, int(np.log2(self.n_features_in_)))
+        elif self.max_features is None:
+            max_features = self.n_features_in_
+        elif isinstance(self.max_features, Integral):
+            max_features = self.max_features
+        else:  # float
+            if self.max_features > 0.0:
+                max_features = max(1, int(self.max_features * self.n_features_in_))
+            else:
+                max_features = 0  # pragma: no cover
+
+        if not 0 < max_features <= self.n_features_in_:
+            raise ValueError("max_features must be in (0, n_features]")
+
+        self.max_features_ = max_features
+
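The same resolution in plain Python, assuming a hypothetical `n_features_in_` of 30:

    import numpy as np

    n_features = 30
    print(max(1, int(np.sqrt(n_features))))  # 'sqrt' -> 5
    print(max(1, int(np.log2(n_features))))  # 'log2' -> 4
    print(max(1, int(0.25 * n_features)))    # 0.25  -> 7
    print(n_features)                        # None  -> 30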
+    def _check_low_memory(self, function):
+        """Check if `function` is supported in low memory mode and raise otherwise."""
+        if self.low_memory:
+            raise NotImplementedError(
+                f"{function} is not implemented in low memory mode."
+                " Run fit with low_memory=False to disable low memory mode."
+            )
+
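The trade-off is observable directly; a sketch using the bundled veteran data again (numeric columns only):

    from sksurv.datasets import load_veterans_lung_cancer
    from sksurv.tree import SurvivalTree

    X, y = load_veterans_lung_cancer()
    X = X.select_dtypes("number")
    tree = SurvivalTree(low_memory=True, random_state=0).fit(X, y)
    tree.predict(X)  # works: reduced-memory risk scores
    try:
        tree.predict_survival_function(X)
    except NotImplementedError as err:
        print(err)  # points back to low_memory=False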
+    def _validate_X_predict(self, X, check_input, accept_sparse="csr"):
+        """Validate X whenever one tries to predict"""
+        if check_input:
+            if self._support_missing_values(X):
+                ensure_all_finite = "allow-nan"
+            else:
+                ensure_all_finite = True
+            X = validate_data(
+                self,
+                X,
+                dtype=DTYPE,
+                accept_sparse=accept_sparse,
+                reset=False,
+                ensure_all_finite=ensure_all_finite,
+            )
+        else:
+            # The number of features is checked regardless of `check_input`
+            _check_n_features(self, X, reset=False)
+
+        return X
+
+    def predict(self, X, check_input=True):
+        r"""Predict risk score.
+
+        The risk score is the total number of events, which can
+        be estimated by the sum of the estimated cumulative
+        hazard function :math:`\hat{H}_h` in terminal node :math:`h`.
+
+        .. math::
+
+            \sum_{j=1}^{n(h)} \hat{H}_h(T_j \mid x),
+
+        where :math:`n(h)` denotes the number of distinct event times
+        of samples belonging to the same terminal node as :math:`x`.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            Data matrix.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : boolean, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        risk_scores : ndarray, shape = (n_samples,), dtype=float
+            Predicted risk scores.
+        """
+        if self.low_memory:
+            check_is_fitted(self, "tree_")
+            X = self._validate_X_predict(X, check_input, accept_sparse="csr")
+            pred = self.tree_.predict(X)
+            return pred[..., 0]
+
+        chf = self.predict_cumulative_hazard_function(X, check_input, return_array=True)
+        return chf[:, self.is_event_time_].sum(1)
+
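The low_memory=False branch can be verified against the docstring's formula: the risk score equals the CHF summed over the training event times. A self-contained sketch with the veteran data:

    import numpy as np
    from sksurv.datasets import load_veterans_lung_cancer
    from sksurv.tree import SurvivalTree

    X, y = load_veterans_lung_cancer()
    X = X.select_dtypes("number")
    tree = SurvivalTree(max_depth=3, random_state=0).fit(X, y)
    chf = tree.predict_cumulative_hazard_function(X, return_array=True)
    risk = chf[:, tree.is_event_time_].sum(axis=1)
    assert np.allclose(risk, tree.predict(X))  # same numbers, two routes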
+    @append_cumulative_hazard_example(estimator_mod="tree", estimator_class="SurvivalTree")
+    def predict_cumulative_hazard_function(self, X, check_input=True, return_array=False):
+        """Predict cumulative hazard function.
+
+        The cumulative hazard function (CHF) for an individual
+        with feature vector :math:`x` is computed from
+        all samples of the training data that are in the
+        same terminal node as :math:`x`.
+        It is estimated by the Nelson–Aalen estimator.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            Data matrix.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : boolean, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        return_array : bool, default: False
+            Whether to return a single array of cumulative hazard values
+            or a list of step functions.
+
+            If `False`, a list of :class:`sksurv.functions.StepFunction`
+            objects is returned.
+
+            If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+            returned, where `n_unique_times` is the number of unique
+            event times in the training data. Each row represents the cumulative
+            hazard function of an individual evaluated at `unique_times_`.
+
+        Returns
+        -------
+        cum_hazard : ndarray
+            If `return_array` is `False`, an array of `n_samples`
+            :class:`sksurv.functions.StepFunction` instances is returned.
+
+            If `return_array` is `True`, a numeric array of shape
+            `(n_samples, n_unique_times)` is returned.
+
+        Examples
+        --------
+        """
+        self._check_low_memory("predict_cumulative_hazard_function")
+        check_is_fitted(self, "tree_")
+        X = self._validate_X_predict(X, check_input, accept_sparse="csr")
+
+        pred = self.tree_.predict(X)
+        arr = pred[..., 0]
+        if return_array:
+            return arr
+        return _array_to_step_function(self.unique_times_, arr)
+
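Both return modes side by side, in the same toy setup as above (the `append_cumulative_hazard_example` decorator fills in the official Examples section at import time):

    from sksurv.datasets import load_veterans_lung_cancer
    from sksurv.tree import SurvivalTree

    X, y = load_veterans_lung_cancer()
    X = X.select_dtypes("number")
    tree = SurvivalTree(max_depth=3, random_state=0).fit(X, y)

    funcs = tree.predict_cumulative_hazard_function(X.iloc[:2])
    print(funcs[0](tree.unique_times_[:3]))  # CHF at the first three time points

    arr = tree.predict_cumulative_hazard_function(X.iloc[:2], return_array=True)
    print(arr.shape)  # -> (2, n_unique_times)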
+    @append_survival_function_example(estimator_mod="tree", estimator_class="SurvivalTree")
+    def predict_survival_function(self, X, check_input=True, return_array=False):
+        """Predict survival function.
+
+        The survival function for an individual
+        with feature vector :math:`x` is computed from
+        all samples of the training data that are in the
+        same terminal node as :math:`x`.
+        It is estimated by the Kaplan-Meier estimator.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            Data matrix.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : boolean, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        return_array : bool, default: False
+            Whether to return a single array of survival probabilities
+            or a list of step functions.
+
+            If `False`, a list of :class:`sksurv.functions.StepFunction`
+            objects is returned.
+
+            If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+            returned, where `n_unique_times` is the number of unique
+            event times in the training data. Each row represents the survival
+            function of an individual evaluated at `unique_times_`.
+
+        Returns
+        -------
+        survival : ndarray
+            If `return_array` is `False`, an array of `n_samples`
+            :class:`sksurv.functions.StepFunction` instances is returned.
+
+            If `return_array` is `True`, a numeric array of shape
+            `(n_samples, n_unique_times)` is returned.
+
+        Examples
+        --------
+        """
+        self._check_low_memory("predict_survival_function")
+        check_is_fitted(self, "tree_")
+        X = self._validate_X_predict(X, check_input, accept_sparse="csr")
+
+        pred = self.tree_.predict(X)
+        arr = pred[..., 1]
+        if return_array:
+            return arr
+        return _array_to_step_function(self.unique_times_, arr)
+
+    def apply(self, X, check_input=True):
+        """Return the index of the leaf that each sample is predicted as.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            The input samples. Internally, it will be converted to
+            ``dtype=np.float32``, and if a sparse matrix is provided,
+            to a sparse ``csr_matrix``.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : bool, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        X_leaves : ndarray, shape = (n_samples,), dtype=int
+            For each datapoint x in X, return the index of the leaf x
+            ends up in. Leaves are numbered within
+            ``[0; self.tree_.node_count)``, possibly with gaps in the
+            numbering.
+        """
+        check_is_fitted(self, "tree_")
+        X = self._validate_X_predict(X, check_input)
+        return self.tree_.apply(X)
+
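Leaf co-membership is what ties apply to the predictions: samples mapped to the same leaf share one Nelson–Aalen estimate. A quick check, using the same toy setup as above:

    import numpy as np
    from sksurv.datasets import load_veterans_lung_cancer
    from sksurv.tree import SurvivalTree

    X, y = load_veterans_lung_cancer()
    X = X.select_dtypes("number")
    tree = SurvivalTree(max_depth=3, random_state=0).fit(X, y)

    leaves = tree.apply(X)
    chf = tree.predict_cumulative_hazard_function(X, return_array=True)
    same_leaf = leaves == leaves[0]
    assert np.allclose(chf[same_leaf], chf[same_leaf][0])  # identical CHF rows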
+    def decision_path(self, X, check_input=True):
+        """Return the decision path in the tree.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            The input samples. Internally, it will be converted to
+            ``dtype=np.float32``, and if a sparse matrix is provided,
+            to a sparse ``csr_matrix``.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : bool, default=True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        indicator : sparse matrix, shape = (n_samples, n_nodes)
+            Return a node indicator CSR matrix, where non-zero elements
+            indicate that the sample goes through the corresponding nodes.
+        """
+        X = self._validate_X_predict(X, check_input)
+        return self.tree_.decision_path(X)
+
+
+class ExtraSurvivalTree(SurvivalTree):
+    """An Extremely Randomized Survival Tree.
+
+    This class implements an Extremely Randomized Tree for survival analysis.
+    It differs from :class:`SurvivalTree` in how splits are chosen:
+    instead of searching for the optimal split, it considers a random subset
+    of features and random thresholds for each feature, then picks the best
+    among these random candidates.
+
+    Parameters
+    ----------
+    splitter : {'best', 'random'}, default: 'random'
+        The strategy used to choose the split at each node. Supported
+        strategies are 'best' to choose the best split and 'random' to choose
+        the best random split.
+
+    max_depth : int or None, optional, default: None
+        The maximum depth of the tree. If None, then nodes are expanded until
+        all leaves are pure or until all leaves contain fewer than
+        `min_samples_split` samples.
+
+    min_samples_split : int, float, optional, default: 6
+        The minimum number of samples required to split an internal node:
+
+        - If int, then consider `min_samples_split` as the minimum number.
+        - If float, then `min_samples_split` is a fraction and
+          `ceil(min_samples_split * n_samples)` is the minimum
+          number of samples for each split.
+
+    min_samples_leaf : int, float, optional, default: 3
+        The minimum number of samples required to be at a leaf node.
+        A split point at any depth will only be considered if it leaves at
+        least ``min_samples_leaf`` training samples in each of the left and
+        right branches. This may have the effect of smoothing the model.
+
+        - If int, then consider `min_samples_leaf` as the minimum number.
+        - If float, then `min_samples_leaf` is a fraction and
+          `ceil(min_samples_leaf * n_samples)` is the minimum
+          number of samples for each node.
+
+    min_weight_fraction_leaf : float, optional, default: 0.
+        The minimum weighted fraction of the sum total of weights (of all
+        the input samples) required to be at a leaf node. Samples have
+        equal weight when sample_weight is not provided.
+
+    max_features : int, float or {'sqrt', 'log2'} or None, optional, default: None
+        The number of features to consider when looking for the best split:
+
+        - If int, then consider `max_features` features at each split.
+        - If float, then `max_features` is a fraction and
+          `max(1, int(max_features * n_features_in_))` features are considered at
+          each split.
+        - If "sqrt", then `max_features=sqrt(n_features)`.
+        - If "log2", then `max_features=log2(n_features)`.
+        - If None, then `max_features=n_features`.
+
+        Note: the search for a split does not stop until at least one
+        valid partition of the node samples is found, even if it requires
+        effectively inspecting more than ``max_features`` features.
+
+    random_state : int, RandomState instance or None, optional, default: None
+        Controls the randomness of the estimator. The features are always
+        randomly permuted at each split, even if ``splitter`` is set to
+        ``"best"``. When ``max_features < n_features``, the algorithm will
+        select ``max_features`` at random at each split before finding the best
+        split among them. But the best found split may vary across different
+        runs, even if ``max_features=n_features``. That is the case if the
+        improvement of the criterion is identical for several splits and one
+        split has to be selected at random. To obtain deterministic behavior
+        during fitting, ``random_state`` has to be fixed to an integer.
+
+    max_leaf_nodes : int or None, optional, default: None
+        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
+        Best nodes are defined as relative reduction in impurity.
+        If None, then the number of leaf nodes is unlimited.
+
+    low_memory : bool, optional, default: False
+        If set, :meth:`predict` computations use reduced memory, but :meth:`predict_cumulative_hazard_function`
+        and :meth:`predict_survival_function` are not implemented.
+
+    Attributes
+    ----------
+    unique_times_ : ndarray, shape = (n_unique_times,), dtype = float
+        Unique time points.
+
+    max_features_ : int
+        The inferred value of max_features.
+
+    n_features_in_ : int
+        Number of features seen during ``fit``.
+
+    feature_names_in_ : ndarray, shape = (`n_features_in_`,), dtype = object
+        Names of features seen during ``fit``. Defined only when `X`
+        has feature names that are all strings.
+
+    tree_ : Tree object
+        The underlying Tree object. Please refer to
+        ``help(sklearn.tree._tree.Tree)`` for attributes of the Tree object.
+
+    See also
+    --------
+    sksurv.ensemble.ExtraSurvivalTrees : An ensemble of ExtraSurvivalTrees.
+    """
+
+    def __init__(
+        self,
+        *,
+        splitter="random",
+        max_depth=None,
+        min_samples_split=6,
+        min_samples_leaf=3,
+        min_weight_fraction_leaf=0.0,
+        max_features=None,
+        random_state=None,
+        max_leaf_nodes=None,
+        low_memory=False,
+    ):
+        super().__init__(
+            splitter=splitter,
+            max_depth=max_depth,
+            min_samples_split=min_samples_split,
+            min_samples_leaf=min_samples_leaf,
+            min_weight_fraction_leaf=min_weight_fraction_leaf,
+            max_features=max_features,
+            random_state=random_state,
+            max_leaf_nodes=max_leaf_nodes,
+            low_memory=low_memory,
+        )
+
+    def predict_cumulative_hazard_function(self, X, check_input=True, return_array=False):
+        ExtraSurvivalTree.predict_cumulative_hazard_function.__doc__ = (
+            SurvivalTree.predict_cumulative_hazard_function.__doc__.replace("SurvivalTree", "ExtraSurvivalTree")
+        )
+        return super().predict_cumulative_hazard_function(X, check_input=check_input, return_array=return_array)
+
+    def predict_survival_function(self, X, check_input=True, return_array=False):
+        ExtraSurvivalTree.predict_survival_function.__doc__ = SurvivalTree.predict_survival_function.__doc__.replace(
+            "SurvivalTree", "ExtraSurvivalTree"
+        )
+        return super().predict_survival_function(X, check_input=check_input, return_array=return_array)
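Because ExtraSurvivalTree only changes the splitter default (and, through it, how thresholds are drawn), it is a drop-in replacement for SurvivalTree; a short sketch contrasting the two on the veteran data:

    from sksurv.datasets import load_veterans_lung_cancer
    from sksurv.tree import ExtraSurvivalTree, SurvivalTree

    X, y = load_veterans_lung_cancer()
    X = X.select_dtypes("number")
    st = SurvivalTree(max_depth=3, random_state=0).fit(X, y)
    et = ExtraSurvivalTree(max_depth=3, random_state=0).fit(X, y)
    # random thresholds usually yield a different partition than exhaustive search
    print((st.apply(X) != et.apply(X)).any())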