scikit_survival-0.23.1-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. scikit_survival-0.23.1.dist-info/COPYING +674 -0
  2. scikit_survival-0.23.1.dist-info/METADATA +888 -0
  3. scikit_survival-0.23.1.dist-info/RECORD +55 -0
  4. scikit_survival-0.23.1.dist-info/WHEEL +5 -0
  5. scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +138 -0
  7. sksurv/base.py +103 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
  10. sksurv/column.py +201 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +10 -0
  13. sksurv/datasets/base.py +436 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  17. sksurv/datasets/data/flchain.arff +7887 -0
  18. sksurv/datasets/data/veteran.arff +148 -0
  19. sksurv/datasets/data/whas500.arff +520 -0
  20. sksurv/ensemble/__init__.py +2 -0
  21. sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
  22. sksurv/ensemble/boosting.py +1610 -0
  23. sksurv/ensemble/forest.py +947 -0
  24. sksurv/ensemble/survival_loss.py +151 -0
  25. sksurv/exceptions.py +18 -0
  26. sksurv/functions.py +114 -0
  27. sksurv/io/__init__.py +2 -0
  28. sksurv/io/arffread.py +58 -0
  29. sksurv/io/arffwrite.py +145 -0
  30. sksurv/kernels/__init__.py +1 -0
  31. sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
  32. sksurv/kernels/clinical.py +328 -0
  33. sksurv/linear_model/__init__.py +3 -0
  34. sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
  35. sksurv/linear_model/aft.py +205 -0
  36. sksurv/linear_model/coxnet.py +543 -0
  37. sksurv/linear_model/coxph.py +618 -0
  38. sksurv/meta/__init__.py +4 -0
  39. sksurv/meta/base.py +35 -0
  40. sksurv/meta/ensemble_selection.py +642 -0
  41. sksurv/meta/stacking.py +349 -0
  42. sksurv/metrics.py +996 -0
  43. sksurv/nonparametric.py +588 -0
  44. sksurv/preprocessing.py +155 -0
  45. sksurv/svm/__init__.py +11 -0
  46. sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
  47. sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
  48. sksurv/svm/minlip.py +606 -0
  49. sksurv/svm/naive_survival_svm.py +221 -0
  50. sksurv/svm/survival_svm.py +1228 -0
  51. sksurv/testing.py +108 -0
  52. sksurv/tree/__init__.py +1 -0
  53. sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
  54. sksurv/tree/tree.py +703 -0
  55. sksurv/util.py +333 -0
sksurv/tree/tree.py ADDED
@@ -0,0 +1,703 @@
+from math import ceil
+from numbers import Integral, Real
+
+import numpy as np
+from scipy.sparse import issparse
+from sklearn.base import BaseEstimator
+from sklearn.tree import _tree
+from sklearn.tree._classes import DENSE_SPLITTERS, SPARSE_SPLITTERS
+from sklearn.tree._splitter import Splitter
+from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
+from sklearn.tree._utils import _any_isnan_axis0
+from sklearn.utils._param_validation import Interval, StrOptions
+from sklearn.utils.validation import (
+    _assert_all_finite_element_wise,
+    assert_all_finite,
+    check_is_fitted,
+    check_random_state,
+)
+
+from ..base import SurvivalAnalysisMixin
+from ..functions import StepFunction
+from ..util import check_array_survival
+from ._criterion import LogrankCriterion, get_unique_times
+
+__all__ = ["ExtraSurvivalTree", "SurvivalTree"]
+
+DTYPE = _tree.DTYPE
+
+
+def _array_to_step_function(x, array):
+    n_samples = array.shape[0]
+    funcs = np.empty(n_samples, dtype=np.object_)
+    for i in range(n_samples):
+        funcs[i] = StepFunction(x=x, y=array[i])
+    return funcs
+
+
+class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
+    """A survival tree.
+
+    The quality of a split is measured by the
+    log-rank splitting rule.
+
+    If ``splitter='best'``, fit and predict methods support
+    missing values. See :ref:`tree_missing_value_support` for details.
+
+    See [1]_, [2]_ and [3]_ for further description.
+
+    Parameters
+    ----------
+    splitter : {'best', 'random'}, default: 'best'
+        The strategy used to choose the split at each node. Supported
+        strategies are 'best' to choose the best split and 'random' to choose
+        the best random split.
+
+    max_depth : int or None, optional, default: None
+        The maximum depth of the tree. If None, then nodes are expanded until
+        all leaves are pure or until all leaves contain fewer than
+        `min_samples_split` samples.
+
+    min_samples_split : int or float, optional, default: 6
+        The minimum number of samples required to split an internal node:
+
+        - If int, then consider `min_samples_split` as the minimum number.
+        - If float, then `min_samples_split` is a fraction and
+          `ceil(min_samples_split * n_samples)` is the minimum
+          number of samples for each split.
+
+    min_samples_leaf : int or float, optional, default: 3
+        The minimum number of samples required to be at a leaf node.
+        A split point at any depth will only be considered if it leaves at
+        least ``min_samples_leaf`` training samples in each of the left and
+        right branches. This may have the effect of smoothing the model,
+        especially in regression.
+
+        - If int, then consider `min_samples_leaf` as the minimum number.
+        - If float, then `min_samples_leaf` is a fraction and
+          `ceil(min_samples_leaf * n_samples)` is the minimum
+          number of samples for each node.
+
+    min_weight_fraction_leaf : float, optional, default: 0.
+        The minimum weighted fraction of the sum total of weights (of all
+        the input samples) required to be at a leaf node. Samples have
+        equal weight when sample_weight is not provided.
+
+    max_features : int, float, str, or None, optional, default: None
+        The number of features to consider when looking for the best split:
+
+        - If int, then consider `max_features` features at each split.
+        - If float, then `max_features` is a fraction and
+          `max(1, int(max_features * n_features_in_))` features are considered at
+          each split.
+        - If "sqrt", then `max_features=sqrt(n_features)`.
+        - If "log2", then `max_features=log2(n_features)`.
+        - If None, then `max_features=n_features`.
+
+        Note: the search for a split does not stop until at least one
+        valid partition of the node samples is found, even if it requires
+        effectively inspecting more than ``max_features`` features.
+
+    random_state : int, RandomState instance or None, optional, default: None
+        Controls the randomness of the estimator. The features are always
+        randomly permuted at each split, even if ``splitter`` is set to
+        ``"best"``. When ``max_features < n_features``, the algorithm will
+        select ``max_features`` at random at each split before finding the best
+        split among them. The best found split may vary across different
+        runs, even if ``max_features=n_features``; that is the case if the
+        improvement of the criterion is identical for several splits and one
+        split has to be selected at random. To obtain deterministic behavior
+        during fitting, ``random_state`` has to be fixed to an integer.
+
+    max_leaf_nodes : int or None, optional, default: None
+        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
+        Best nodes are defined as relative reduction in impurity.
+        If None, then the number of leaf nodes is unlimited.
+
+    low_memory : boolean, default: False
+        If set, ``predict`` computations use reduced memory, but
+        ``predict_cumulative_hazard_function`` and
+        ``predict_survival_function`` are not implemented.
+
+    Attributes
+    ----------
+    unique_times_ : array of shape = (n_unique_times,)
+        Unique time points.
+
+    max_features_ : int
+        The inferred value of max_features.
+
+    n_features_in_ : int
+        Number of features seen during ``fit``.
+
+    feature_names_in_ : ndarray of shape (`n_features_in_`,)
+        Names of features seen during ``fit``. Defined only when `X`
+        has feature names that are all strings.
+
+    tree_ : Tree object
+        The underlying Tree object. Please refer to
+        ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object.
+
+    See also
+    --------
+    sksurv.ensemble.RandomSurvivalForest
+        An ensemble of SurvivalTrees.
+
+    References
+    ----------
+    .. [1] Leblanc, M., & Crowley, J. (1993). Survival Trees by Goodness of Split.
+           Journal of the American Statistical Association, 88(422), 457–467.
+
+    .. [2] Ishwaran, H., Kogalur, U. B., Blackstone, E. H., & Lauer, M. S. (2008).
+           Random survival forests. The Annals of Applied Statistics, 2(3), 841–860.
+
+    .. [3] Ishwaran, H., Kogalur, U. B. (2007). Random survival forests for R.
+           R News, 7(2), 25–31. https://cran.r-project.org/doc/Rnews/Rnews_2007-2.pdf.
+    """
+
+    _parameter_constraints = {
+        "splitter": [StrOptions({"best", "random"})],
+        "max_depth": [Interval(Integral, 1, None, closed="left"), None],
+        "min_samples_split": [
+            Interval(Integral, 2, None, closed="left"),
+            Interval(Real, 0.0, 1.0, closed="neither"),
+        ],
+        "min_samples_leaf": [
+            Interval(Integral, 1, None, closed="left"),
+            Interval(Real, 0.0, 0.5, closed="right"),
+        ],
+        "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
+        "max_features": [
+            Interval(Integral, 1, None, closed="left"),
+            Interval(Real, 0.0, 1.0, closed="right"),
+            StrOptions({"sqrt", "log2"}),
+            None,
+        ],
+        "random_state": ["random_state"],
+        "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
+        "low_memory": ["boolean"],
+    }
+
+    criterion = "logrank"
+
+    def __init__(
+        self,
+        *,
+        splitter="best",
+        max_depth=None,
+        min_samples_split=6,
+        min_samples_leaf=3,
+        min_weight_fraction_leaf=0.0,
+        max_features=None,
+        random_state=None,
+        max_leaf_nodes=None,
+        low_memory=False,
+    ):
+        self.splitter = splitter
+        self.max_depth = max_depth
+        self.min_samples_split = min_samples_split
+        self.min_samples_leaf = min_samples_leaf
+        self.min_weight_fraction_leaf = min_weight_fraction_leaf
+        self.max_features = max_features
+        self.random_state = random_state
+        self.max_leaf_nodes = max_leaf_nodes
+        self.low_memory = low_memory
+
+    def _more_tags(self):
+        allow_nan = self.splitter == "best"
+        return {"allow_nan": allow_nan}
+
+    def _support_missing_values(self, X):
+        return not issparse(X) and self._get_tags()["allow_nan"]
+
+    def _compute_missing_values_in_feature_mask(self, X, estimator_name=None):
+        """Return boolean mask denoting if there are missing values for each feature.
+
+        This method also ensures that X is finite.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features), dtype=DOUBLE
+            Input data.
+
+        estimator_name : str or None, default=None
+            Name to use when raising an error. Defaults to the class name.
+
+        Returns
+        -------
+        missing_values_in_feature_mask : ndarray of shape (n_features,), or None
+            Missing value mask. If missing values are not supported or there
+            are no missing values, return None.
+        """
+        estimator_name = estimator_name or self.__class__.__name__
+        common_kwargs = dict(estimator_name=estimator_name, input_name="X")
+
+        if not self._support_missing_values(X):
+            assert_all_finite(X, **common_kwargs)
+            return None
+
+        with np.errstate(over="ignore"):
+            overall_sum = np.sum(X)
+
+        if not np.isfinite(overall_sum):
+            # Raise a ValueError in case of the presence of an infinite element.
+            _assert_all_finite_element_wise(X, xp=np, allow_nan=True, **common_kwargs)
+
+        # If the sum is not NaN, then there are no missing values.
+        if not np.isnan(overall_sum):
+            return None
+
+        missing_values_in_feature_mask = _any_isnan_axis0(X)
+        return missing_values_in_feature_mask
+
+    def fit(self, X, y, sample_weight=None, check_input=True):
+        """Build a survival tree from the training set (X, y).
+
+        If ``splitter='best'``, `X` is allowed to contain missing
+        values. In addition to evaluating each potential threshold on
+        the non-missing data, the splitter will evaluate the split
+        with all the missing values going to the left node or the
+        right node. See :ref:`tree_missing_value_support` for details.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            Data matrix.
+
+        y : structured array, shape = (n_samples,)
+            A structured array containing the binary event indicator
+            as first field, and time of event or time of censoring as
+            second field.
+
+        sample_weight : array-like, shape = (n_samples,), optional
+            Sample weights. If None, all samples are weighted equally.
+
+        check_input : boolean, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        self
+        """
+        self._fit(X, y, sample_weight, check_input)
+        return self
+
+    def _fit(self, X, y, sample_weight=None, check_input=True, missing_values_in_feature_mask=None):
+        random_state = check_random_state(self.random_state)
+
+        if check_input:
+            X = self._validate_data(X, dtype=DTYPE, ensure_min_samples=2, accept_sparse="csc", force_all_finite=False)
+            event, time = check_array_survival(X, y)
+            time = time.astype(np.float64)
+            self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
+            missing_values_in_feature_mask = self._compute_missing_values_in_feature_mask(X)
+            if issparse(X):
+                X.sort_indices()
+
+            y_numeric = np.empty((X.shape[0], 2), dtype=np.float64)
+            y_numeric[:, 0] = time
+            y_numeric[:, 1] = event.astype(np.float64)
+        else:
+            # If input checking is disabled, y is expected to be the
+            # pre-computed triple (y_numeric, unique_times, is_event_time).
+            y_numeric, self.unique_times_, self.is_event_time_ = y
+
+        n_samples, self.n_features_in_ = X.shape
+        params = self._check_params(n_samples)
+
+        if self.low_memory:
+            self.n_outputs_ = 1
+            # one "class" only, for the sum over the CHF
+            self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp)
+        else:
+            self.n_outputs_ = self.unique_times_.shape[0]
+            # one "class" for CHF, one for survival function
+            self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2
+
+        # Build tree
+        criterion = LogrankCriterion(self.n_outputs_, n_samples, self.unique_times_, self.is_event_time_)
+
+        SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
+
+        splitter = self.splitter
+        if not isinstance(self.splitter, Splitter):
+            splitter = SPLITTERS[self.splitter](
+                criterion,
+                self.max_features_,
+                params["min_samples_leaf"],
+                params["min_weight_leaf"],
+                random_state,
+                None,  # monotonic_cst
+            )
+
+        self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_)
+
+        # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
+        if params["max_leaf_nodes"] < 0:
+            builder = DepthFirstTreeBuilder(
+                splitter,
+                params["min_samples_split"],
+                params["min_samples_leaf"],
+                params["min_weight_leaf"],
+                params["max_depth"],
+                0.0,  # min_impurity_decrease
+            )
+        else:
+            builder = BestFirstTreeBuilder(
+                splitter,
+                params["min_samples_split"],
+                params["min_samples_leaf"],
+                params["min_weight_leaf"],
+                params["max_depth"],
+                params["max_leaf_nodes"],
+                0.0,  # min_impurity_decrease
+            )
+
+        builder.build(self.tree_, X, y_numeric, sample_weight, missing_values_in_feature_mask)
+
+        return self
+
+    def _check_params(self, n_samples):
+        self._validate_params()
+
+        # Check parameters
+        max_depth = (2**31) - 1 if self.max_depth is None else self.max_depth
+
+        max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes
+
+        if isinstance(self.min_samples_leaf, (Integral, np.integer)):
+            min_samples_leaf = self.min_samples_leaf
+        else:  # float
+            min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))
+
+        if isinstance(self.min_samples_split, Integral):
+            min_samples_split = self.min_samples_split
+        else:  # float
+            min_samples_split = int(ceil(self.min_samples_split * n_samples))
+            min_samples_split = max(2, min_samples_split)
+
+        min_samples_split = max(min_samples_split, 2 * min_samples_leaf)
+
+        self._check_max_features()
+
+        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
+            raise ValueError("min_weight_fraction_leaf must be in [0, 0.5]")
+
+        min_weight_leaf = self.min_weight_fraction_leaf * n_samples
+
+        return {
+            "max_depth": max_depth,
+            "max_leaf_nodes": max_leaf_nodes,
+            "min_samples_leaf": min_samples_leaf,
+            "min_samples_split": min_samples_split,
+            "min_weight_leaf": min_weight_leaf,
+        }
+
+    def _check_max_features(self):
+        if isinstance(self.max_features, str):
+            if self.max_features == "sqrt":
+                max_features = max(1, int(np.sqrt(self.n_features_in_)))
+            elif self.max_features == "log2":
+                max_features = max(1, int(np.log2(self.n_features_in_)))
+        elif self.max_features is None:
+            max_features = self.n_features_in_
+        elif isinstance(self.max_features, (Integral, np.integer)):
+            max_features = self.max_features
+        else:  # float
+            if self.max_features > 0.0:
+                max_features = max(1, int(self.max_features * self.n_features_in_))
+            else:
+                max_features = 0
+
+        if not 0 < max_features <= self.n_features_in_:
+            raise ValueError("max_features must be in (0, n_features]")
+
+        self.max_features_ = max_features
+
+    def _check_low_memory(self, function):
+        """Check if `function` is supported in low memory mode and raise if it is not."""
+        if self.low_memory:
+            raise NotImplementedError(
+                f"{function} is not implemented in low memory mode."
+                " Run fit with low_memory=False to disable low memory mode."
+            )
+
+    def _validate_X_predict(self, X, check_input, accept_sparse="csr"):
+        """Validate X whenever one tries to predict."""
+        if check_input:
+            if self._support_missing_values(X):
+                force_all_finite = "allow-nan"
+            else:
+                force_all_finite = True
+            X = self._validate_data(
+                X,
+                dtype=DTYPE,
+                accept_sparse=accept_sparse,
+                reset=False,
+                force_all_finite=force_all_finite,
+            )
+        else:
+            # The number of features is checked regardless of `check_input`
+            self._check_n_features(X, reset=False)
+
+        return X
+
+    def predict(self, X, check_input=True):
+        """Predict risk score.
+
+        The risk score is the total number of events, which can
+        be estimated by the sum of the estimated cumulative
+        hazard function :math:`\\hat{H}_h` in terminal node :math:`h`.
+
+        .. math::
+
+            \\sum_{j=1}^{n(h)} \\hat{H}_h(T_{j} \\mid x) ,
+
+        where :math:`n(h)` denotes the number of distinct event times
+        of samples belonging to the same terminal node as :math:`x`.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            Data matrix.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : boolean, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        risk_scores : ndarray, shape = (n_samples,)
+            Predicted risk scores.
+        """
+        if self.low_memory:
+            check_is_fitted(self, "tree_")
+            X = self._validate_X_predict(X, check_input, accept_sparse="csr")
+            pred = self.tree_.predict(X)
+            return pred[..., 0]
+
+        chf = self.predict_cumulative_hazard_function(X, check_input, return_array=True)
+        return chf[:, self.is_event_time_].sum(1)
+
+    def predict_cumulative_hazard_function(self, X, check_input=True, return_array=False):
+        """Predict cumulative hazard function.
+
+        The cumulative hazard function (CHF) for an individual
+        with feature vector :math:`x` is computed from
+        all samples of the training data that are in the
+        same terminal node as :math:`x`.
+        It is estimated by the Nelson–Aalen estimator.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            Data matrix.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : boolean, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        return_array : boolean, default: False
+            If set, return an array with the cumulative hazard rate
+            for each `self.unique_times_`, otherwise an array of
+            :class:`sksurv.functions.StepFunction`.
+
+        Returns
+        -------
+        cum_hazard : ndarray
+            If `return_array` is set, an array with the cumulative hazard rate
+            for each `self.unique_times_`, otherwise an array of length `n_samples`
+            of :class:`sksurv.functions.StepFunction` instances will be returned.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sksurv.datasets import load_whas500
+        >>> from sksurv.tree import SurvivalTree
+
+        Load and prepare the data.
+
+        >>> X, y = load_whas500()
+        >>> X = X.astype(float)
+
+        Fit the model.
+
+        >>> estimator = SurvivalTree().fit(X, y)
+
+        Estimate the cumulative hazard function for the first 5 samples.
+
+        >>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])
+
+        Plot the estimated cumulative hazard functions.
+
+        >>> for fn in chf_funcs:
+        ...     plt.step(fn.x, fn(fn.x), where="post")
+        ...
+        >>> plt.ylim(0, 1)
+        >>> plt.show()
+        """
+        self._check_low_memory("predict_cumulative_hazard_function")
+        check_is_fitted(self, "tree_")
+        X = self._validate_X_predict(X, check_input, accept_sparse="csr")
+
+        pred = self.tree_.predict(X)
+        arr = pred[..., 0]
+        if return_array:
+            return arr
+        return _array_to_step_function(self.unique_times_, arr)
+
+    def predict_survival_function(self, X, check_input=True, return_array=False):
+        """Predict survival function.
+
+        The survival function for an individual
+        with feature vector :math:`x` is computed from
+        all samples of the training data that are in the
+        same terminal node as :math:`x`.
+        It is estimated by the Kaplan–Meier estimator.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            Data matrix.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : boolean, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        return_array : boolean, default: False
+            If set, return an array with the probability
+            of survival for each `self.unique_times_`,
+            otherwise an array of :class:`sksurv.functions.StepFunction`.
+
+        Returns
+        -------
+        survival : ndarray
+            If `return_array` is set, an array with the probability of
+            survival for each `self.unique_times_`, otherwise an array of
+            length `n_samples` of :class:`sksurv.functions.StepFunction`
+            instances will be returned.
+
+        Examples
+        --------
+        >>> import matplotlib.pyplot as plt
+        >>> from sksurv.datasets import load_whas500
+        >>> from sksurv.tree import SurvivalTree
+
+        Load and prepare the data.
+
+        >>> X, y = load_whas500()
+        >>> X = X.astype(float)
+
+        Fit the model.
+
+        >>> estimator = SurvivalTree().fit(X, y)
+
+        Estimate the survival function for the first 5 samples.
+
+        >>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])
+
+        Plot the estimated survival functions.
+
+        >>> for fn in surv_funcs:
+        ...     plt.step(fn.x, fn(fn.x), where="post")
+        ...
+        >>> plt.ylim(0, 1)
+        >>> plt.show()
+        """
+        self._check_low_memory("predict_survival_function")
+        check_is_fitted(self, "tree_")
+        X = self._validate_X_predict(X, check_input, accept_sparse="csr")
+
+        pred = self.tree_.predict(X)
+        arr = pred[..., 1]
+        if return_array:
+            return arr
+        return _array_to_step_function(self.unique_times_, arr)
+
+    def apply(self, X, check_input=True):
+        """Return the index of the leaf that each sample is predicted as.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            The input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csr_matrix``.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : bool, default: True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        X_leaves : array-like, shape = (n_samples,)
+            For each datapoint x in X, return the index of the leaf x
+            ends up in. Leaves are numbered within
+            ``[0; self.tree_.node_count)``, possibly with gaps in the
+            numbering.
+        """
+        check_is_fitted(self, "tree_")
+        X = self._validate_X_predict(X, check_input)
+        return self.tree_.apply(X)
+
+    def decision_path(self, X, check_input=True):
+        """Return the decision path in the tree.
+
+        Parameters
+        ----------
+        X : array-like or sparse matrix, shape = (n_samples, n_features)
+            The input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csr_matrix``.
+            If ``splitter='best'``, `X` is allowed to contain missing
+            values and decisions are made as described in
+            :ref:`tree_missing_value_support`.
+
+        check_input : bool, default=True
+            Allow to bypass several input checks.
+            Don't use this parameter unless you know what you're doing.
+
+        Returns
+        -------
+        indicator : sparse matrix, shape = (n_samples, n_nodes)
+            Return a node indicator CSR matrix, where non-zero elements
+            indicate that the sample passes through the corresponding nodes.
+        """
+        X = self._validate_X_predict(X, check_input)
+        return self.tree_.decision_path(X)
+
+
+class ExtraSurvivalTree(SurvivalTree):
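+    """A survival tree with randomized split points.
+
+    Identical to :class:`SurvivalTree`, except that ``splitter``
+    defaults to "random": a candidate threshold is drawn at random
+    for each considered feature and the best of these is used.
+    """
+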
+    def __init__(
+        self,
+        *,
+        splitter="random",
+        max_depth=None,
+        min_samples_split=6,
+        min_samples_leaf=3,
+        min_weight_fraction_leaf=0.0,
+        max_features=None,
+        random_state=None,
+        max_leaf_nodes=None,
+        low_memory=False,
+    ):
+        super().__init__(
+            splitter=splitter,
+            max_depth=max_depth,
+            min_samples_split=min_samples_split,
+            min_samples_leaf=min_samples_leaf,
+            min_weight_fraction_leaf=min_weight_fraction_leaf,
+            max_features=max_features,
+            random_state=random_state,
+            max_leaf_nodes=max_leaf_nodes,
+            low_memory=low_memory,
+        )
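
For orientation, a minimal usage sketch of the estimator added above. It relies only on calls documented in the docstrings of this file; the parameter values (`max_depth`, `random_state`) are illustrative choices, not package defaults.

    from sksurv.datasets import load_whas500
    from sksurv.tree import SurvivalTree

    # Load the Worcester Heart Attack Study data bundled with scikit-survival.
    X, y = load_whas500()
    X = X.astype(float)

    # Fit a survival tree; candidate splits are scored by the log-rank rule.
    tree = SurvivalTree(max_depth=4, random_state=0).fit(X, y)

    # Risk scores: higher values indicate a higher risk of experiencing an event.
    risk_scores = tree.predict(X.iloc[:5])

    # Per-sample survival curves as step functions over tree.unique_times_.
    for fn in tree.predict_survival_function(X.iloc[:5]):
        print(fn.x[0], fn(fn.x[0]))

Swapping `SurvivalTree` for `ExtraSurvivalTree` in the sketch yields the randomized-split variant with the same interface.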