scikit-survival 0.25.0__cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.25.0.dist-info/METADATA +185 -0
- scikit_survival-0.25.0.dist-info/RECORD +58 -0
- scikit_survival-0.25.0.dist-info/WHEEL +6 -0
- scikit_survival-0.25.0.dist-info/licenses/COPYING +674 -0
- scikit_survival-0.25.0.dist-info/top_level.txt +1 -0
- sksurv/__init__.py +183 -0
- sksurv/base.py +115 -0
- sksurv/bintrees/__init__.py +15 -0
- sksurv/bintrees/_binarytrees.cpython-313-x86_64-linux-gnu.so +0 -0
- sksurv/column.py +205 -0
- sksurv/compare.py +123 -0
- sksurv/datasets/__init__.py +12 -0
- sksurv/datasets/base.py +614 -0
- sksurv/datasets/data/GBSG2.arff +700 -0
- sksurv/datasets/data/actg320.arff +1169 -0
- sksurv/datasets/data/bmt.arff +46 -0
- sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
- sksurv/datasets/data/cgvhd.arff +118 -0
- sksurv/datasets/data/flchain.arff +7887 -0
- sksurv/datasets/data/veteran.arff +148 -0
- sksurv/datasets/data/whas500.arff +520 -0
- sksurv/docstrings.py +99 -0
- sksurv/ensemble/__init__.py +2 -0
- sksurv/ensemble/_coxph_loss.cpython-313-x86_64-linux-gnu.so +0 -0
- sksurv/ensemble/boosting.py +1564 -0
- sksurv/ensemble/forest.py +902 -0
- sksurv/ensemble/survival_loss.py +151 -0
- sksurv/exceptions.py +18 -0
- sksurv/functions.py +114 -0
- sksurv/io/__init__.py +2 -0
- sksurv/io/arffread.py +89 -0
- sksurv/io/arffwrite.py +181 -0
- sksurv/kernels/__init__.py +1 -0
- sksurv/kernels/_clinical_kernel.cpython-313-x86_64-linux-gnu.so +0 -0
- sksurv/kernels/clinical.py +348 -0
- sksurv/linear_model/__init__.py +3 -0
- sksurv/linear_model/_coxnet.cpython-313-x86_64-linux-gnu.so +0 -0
- sksurv/linear_model/aft.py +208 -0
- sksurv/linear_model/coxnet.py +592 -0
- sksurv/linear_model/coxph.py +637 -0
- sksurv/meta/__init__.py +4 -0
- sksurv/meta/base.py +35 -0
- sksurv/meta/ensemble_selection.py +724 -0
- sksurv/meta/stacking.py +370 -0
- sksurv/metrics.py +1028 -0
- sksurv/nonparametric.py +911 -0
- sksurv/preprocessing.py +183 -0
- sksurv/svm/__init__.py +11 -0
- sksurv/svm/_minlip.cpython-313-x86_64-linux-gnu.so +0 -0
- sksurv/svm/_prsvm.cpython-313-x86_64-linux-gnu.so +0 -0
- sksurv/svm/minlip.py +690 -0
- sksurv/svm/naive_survival_svm.py +249 -0
- sksurv/svm/survival_svm.py +1236 -0
- sksurv/testing.py +108 -0
- sksurv/tree/__init__.py +1 -0
- sksurv/tree/_criterion.cpython-313-x86_64-linux-gnu.so +0 -0
- sksurv/tree/tree.py +790 -0
- sksurv/util.py +415 -0
sksurv/tree/tree.py
ADDED
@@ -0,0 +1,790 @@
from math import ceil
from numbers import Integral, Real

import numpy as np
from scipy.sparse import issparse
from sklearn.base import BaseEstimator
from sklearn.tree import _tree
from sklearn.tree._classes import DENSE_SPLITTERS, SPARSE_SPLITTERS
from sklearn.tree._splitter import Splitter
from sklearn.tree._tree import BestFirstTreeBuilder, DepthFirstTreeBuilder, Tree
from sklearn.tree._utils import _any_isnan_axis0
from sklearn.utils._param_validation import Interval, RealNotInt, StrOptions
from sklearn.utils.validation import (
    _assert_all_finite_element_wise,
    _check_n_features,
    assert_all_finite,
    check_is_fitted,
    check_random_state,
    validate_data,
)

from ..base import SurvivalAnalysisMixin
from ..docstrings import append_cumulative_hazard_example, append_survival_function_example
from ..functions import StepFunction
from ..util import check_array_survival
from ._criterion import LogrankCriterion, get_unique_times

__all__ = ["ExtraSurvivalTree", "SurvivalTree"]

DTYPE = _tree.DTYPE


def _array_to_step_function(x, array):
    n_samples = array.shape[0]
    funcs = np.empty(n_samples, dtype=np.object_)
    for i in range(n_samples):
        funcs[i] = StepFunction(x=x, y=array[i])
    return funcs

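`_array_to_step_function` wraps each row of a prediction matrix in a `StepFunction` over the unique time points, so every sample gets a callable curve. A minimal sketch of the resulting objects; the `times` and `values` arrays are made up:

import numpy as np
from sksurv.functions import StepFunction

times = np.array([1.0, 3.0, 7.0])    # unique time points (x)
values = np.array([0.2, 0.5, 1.1])   # one row of predicted values (y)

fn = StepFunction(x=times, y=values)
# Between breakpoints the function holds the value of the previous step,
# so fn(5.0) falls back to the value at t=3.0.
print(fn(np.array([1.0, 5.0, 7.0])))  # -> [0.2 0.5 1.1]
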
class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
    """A single survival tree.

    The quality of a split is measured by the log-rank splitting rule.

    If ``splitter='best'``, fit and predict methods support
    missing values. See :ref:`tree_missing_value_support` for details.

    See [1]_, [2]_ and [3]_ for further description.

    Parameters
    ----------
    splitter : {'best', 'random'}, default: 'best'
        The strategy used to choose the split at each node. Supported
        strategies are 'best' to choose the best split and 'random' to choose
        the best random split.

    max_depth : int or None, optional, default: None
        The maximum depth of the tree. If None, then nodes are expanded until
        all leaves are pure or until all leaves contain less than
        `min_samples_split` samples.

    min_samples_split : int or float, optional, default: 6
        The minimum number of samples required to split an internal node:

        - If int, then consider `min_samples_split` as the minimum number.
        - If float, then `min_samples_split` is a fraction and
          `ceil(min_samples_split * n_samples)` is the minimum
          number of samples for each split.

    min_samples_leaf : int or float, optional, default: 3
        The minimum number of samples required to be at a leaf node.
        A split point at any depth will only be considered if it leaves at
        least ``min_samples_leaf`` training samples in each of the left and
        right branches. This may have the effect of smoothing the model,
        especially in regression.

        - If int, then consider `min_samples_leaf` as the minimum number.
        - If float, then `min_samples_leaf` is a fraction and
          `ceil(min_samples_leaf * n_samples)` is the minimum
          number of samples for each node.

    min_weight_fraction_leaf : float, optional, default: 0.
        The minimum weighted fraction of the sum total of weights (of all
        the input samples) required to be at a leaf node. Samples have
        equal weight when sample_weight is not provided.

    max_features : int, float, {'sqrt', 'log2'} or None, optional, default: None
        The number of features to consider when looking for the best split:

        - If int, then consider `max_features` features at each split.
        - If float, then `max_features` is a fraction and
          `max(1, int(max_features * n_features_in_))` features are considered at
          each split.
        - If "sqrt", then `max_features=sqrt(n_features)`.
        - If "log2", then `max_features=log2(n_features)`.
        - If None, then `max_features=n_features`.

        Note: the search for a split does not stop until at least one
        valid partition of the node samples is found, even if it requires
        effectively inspecting more than ``max_features`` features.

    random_state : int, RandomState instance or None, optional, default: None
        Controls the randomness of the estimator. The features are always
        randomly permuted at each split, even if ``splitter`` is set to
        ``"best"``. When ``max_features < n_features``, the algorithm will
        select ``max_features`` at random at each split before finding the best
        split among them. But the best found split may vary across different
        runs, even if ``max_features=n_features``. That is the case if the
        improvement of the criterion is identical for several splits and one
        split has to be selected at random. To obtain deterministic behavior
        during fitting, ``random_state`` has to be fixed to an integer.

    max_leaf_nodes : int or None, optional, default: None
        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
        Best nodes are defined as relative reduction in impurity.
        If None, the number of leaf nodes is unlimited.

    low_memory : bool, optional, default: False
        If set, :meth:`predict` computations use reduced memory, but :meth:`predict_cumulative_hazard_function`
        and :meth:`predict_survival_function` are not implemented.

    Attributes
    ----------
    unique_times_ : ndarray, shape = (n_unique_times,), dtype = float
        Unique time points.

    max_features_ : int
        The inferred value of max_features.

    n_features_in_ : int
        Number of features seen during ``fit``.

    feature_names_in_ : ndarray, shape = (`n_features_in_`,), dtype = object
        Names of features seen during ``fit``. Defined only when `X`
        has feature names that are all strings.

    tree_ : Tree object
        The underlying Tree object. Please refer to
        ``help(sklearn.tree._tree.Tree)`` for attributes of the Tree object.

    See also
    --------
    sksurv.ensemble.RandomSurvivalForest : An ensemble of SurvivalTrees.

    References
    ----------
    .. [1] Leblanc, M., & Crowley, J. (1993). Survival Trees by Goodness of Split.
           Journal of the American Statistical Association, 88(422), 457–467.

    .. [2] Ishwaran, H., Kogalur, U. B., Blackstone, E. H., & Lauer, M. S. (2008).
           Random survival forests. The Annals of Applied Statistics, 2(3), 841–860.

    .. [3] Ishwaran, H., Kogalur, U. B. (2007). Random survival forests for R.
           R News, 7(2), 25–31. https://cran.r-project.org/doc/Rnews/Rnews_2007-2.pdf.
    """

    _parameter_constraints = {
        "splitter": [StrOptions({"best", "random"})],
        "max_depth": [Interval(Integral, 1, None, closed="left"), None],
        "min_samples_split": [
            Interval(Integral, 2, None, closed="left"),
            Interval(RealNotInt, 0.0, 1.0, closed="neither"),
        ],
        "min_samples_leaf": [
            Interval(Integral, 1, None, closed="left"),
            Interval(RealNotInt, 0.0, 0.5, closed="right"),
        ],
        "min_weight_fraction_leaf": [Interval(Real, 0.0, 0.5, closed="both")],
        "max_features": [
            Interval(Integral, 1, None, closed="left"),
            Interval(RealNotInt, 0.0, 1.0, closed="right"),
            StrOptions({"sqrt", "log2"}),
            None,
        ],
        "random_state": ["random_state"],
        "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
        "low_memory": ["boolean"],
    }

    criterion = "logrank"

    def __init__(
        self,
        *,
        splitter="best",
        max_depth=None,
        min_samples_split=6,
        min_samples_leaf=3,
        min_weight_fraction_leaf=0.0,
        max_features=None,
        random_state=None,
        max_leaf_nodes=None,
        low_memory=False,
    ):
        self.splitter = splitter
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.random_state = random_state
        self.max_leaf_nodes = max_leaf_nodes
        self.low_memory = low_memory

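A minimal construction sketch. All keywords are optional; compared with sklearn's regression trees, the defaults (`min_samples_split=6`, `min_samples_leaf=3`) favour slightly larger leaves, and the values below are purely illustrative:

from sksurv.tree import SurvivalTree

tree = SurvivalTree(
    max_depth=4,            # cap depth instead of growing until leaves are pure
    min_samples_leaf=0.05,  # fraction: resolved to ceil(0.05 * n_samples) at fit time
    random_state=0,         # fixes the random feature permutation at each split
)
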
    def __sklearn_tags__(self):
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = self.splitter in ("best", "random")
        return tags

    def _support_missing_values(self, X):
        return not issparse(X) and self.__sklearn_tags__().input_tags.allow_nan

    def _compute_missing_values_in_feature_mask(self, X, estimator_name=None):
        """Return a boolean mask denoting if there are missing values for each feature.

        This method also ensures that X is finite.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features), dtype = DOUBLE
            Input data.

        estimator_name : str or None, default=None
            Name to use when raising an error. Defaults to the class name.

        Returns
        -------
        missing_values_in_feature_mask : ndarray of shape (n_features,), or None
            Missing value mask. If missing values are not supported or there
            are no missing values, return None.
        """
        estimator_name = estimator_name or self.__class__.__name__
        common_kwargs = dict(estimator_name=estimator_name, input_name="X")

        if not self._support_missing_values(X):
            assert_all_finite(X, **common_kwargs)
            return None

        with np.errstate(over="ignore"):
            overall_sum = np.sum(X)

        if not np.isfinite(overall_sum):
            # Raise a ValueError in case of the presence of an infinite element.
            _assert_all_finite_element_wise(X, xp=np, allow_nan=True, **common_kwargs)

        # If the sum is not NaN, then there are no missing values
        if not np.isnan(overall_sum):
            return None

        missing_values_in_feature_mask = _any_isnan_axis0(X)
        return missing_values_in_feature_mask

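The method above avoids a per-column scan when the data is clean: after ruling out infinities, a single `np.sum` is NaN exactly when some element is NaN, and only then is the per-feature mask computed. A plain-NumPy sketch of the equivalent logic (`np.isnan(X).any(axis=0)` standing in for the private `_any_isnan_axis0` helper), on a made-up matrix:

import numpy as np

X = np.array([[1.0, np.nan, 3.0],
              [4.0, 5.0, 6.0]])

overall_sum = np.sum(X)        # NaN -> at least one missing value
if np.isnan(overall_sum):
    mask = np.isnan(X).any(axis=0)
    print(mask)                # [False  True False]
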
    def fit(self, X, y, sample_weight=None, check_input=True):
        """Build a survival tree from the training set (X, y).

        If ``splitter='best'``, `X` is allowed to contain missing
        values. In addition to evaluating each potential threshold on
        the non-missing data, the splitter will evaluate the split
        with all the missing values going to the left node or the
        right node. See :ref:`tree_missing_value_support` for details.

        Parameters
        ----------
        X : array-like or sparse matrix, shape = (n_samples, n_features)
            Data matrix.

        y : structured array, shape = (n_samples,)
            A structured array with two fields. The first field is a boolean
            where ``True`` indicates an event and ``False`` indicates right-censoring.
            The second field is a float with the time of event or time of censoring.

        sample_weight : array-like, shape = (n_samples,), optional
            Weights for each sample. If None, all samples are weighted equally.

        check_input : boolean, default: True
            Allows bypassing several input checks.
            Don't use this parameter unless you know what you're doing.

        Returns
        -------
        self
        """
        self._fit(X, y, sample_weight, check_input)
        return self

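A minimal end-to-end fit, assuming the whas500 dataset bundled with scikit-survival and restricting to numeric columns (categorical columns would need encoding first). Later sketches reuse `est`, `X`, and `y` from here:

from sksurv.datasets import load_whas500
from sksurv.tree import SurvivalTree

X, y = load_whas500()
X = X.loc[:, ["age", "bmi", "diasbp", "hr", "sysbp"]]  # numeric columns only

est = SurvivalTree(max_depth=3, random_state=0)
est.fit(X, y)
print(est.n_features_in_, est.unique_times_.shape)
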
    def _fit(self, X, y, sample_weight=None, check_input=True, missing_values_in_feature_mask=None):
        random_state = check_random_state(self.random_state)

        if check_input:
            X = validate_data(self, X, dtype=DTYPE, ensure_min_samples=2, accept_sparse="csc", ensure_all_finite=False)
            event, time = check_array_survival(X, y)
            time = time.astype(np.float64)
            self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
            missing_values_in_feature_mask = self._compute_missing_values_in_feature_mask(X)
            if issparse(X):
                X.sort_indices()

            y_numeric = np.empty((X.shape[0], 2), dtype=np.float64)
            y_numeric[:, 0] = time
            y_numeric[:, 1] = event.astype(np.float64)
        else:
            y_numeric, self.unique_times_, self.is_event_time_ = y

        n_samples, self.n_features_in_ = X.shape
        params = self._check_params(n_samples)

        if self.low_memory:
            self.n_outputs_ = 1
            # one "class" only, for the sum over the CHF
            self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp)
        else:
            self.n_outputs_ = self.unique_times_.shape[0]
            # one "class" for CHF, one for survival function
            self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2

        # Build tree
        criterion = LogrankCriterion(self.n_outputs_, n_samples, self.unique_times_, self.is_event_time_)

        SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS

        splitter = self.splitter
        if not isinstance(self.splitter, Splitter):
            splitter = SPLITTERS[self.splitter](
                criterion,
                self.max_features_,
                params["min_samples_leaf"],
                params["min_weight_leaf"],
                random_state,
                None,  # monotonic_cst
            )

        self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_)

        # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
        if params["max_leaf_nodes"] < 0:
            builder = DepthFirstTreeBuilder(
                splitter,
                params["min_samples_split"],
                params["min_samples_leaf"],
                params["min_weight_leaf"],
                params["max_depth"],
                0.0,  # min_impurity_decrease
            )
        else:
            builder = BestFirstTreeBuilder(
                splitter,
                params["min_samples_split"],
                params["min_samples_leaf"],
                params["min_weight_leaf"],
                params["max_depth"],
                params["max_leaf_nodes"],
                0.0,  # min_impurity_decrease
            )

        builder.build(self.tree_, X, y_numeric, sample_weight, missing_values_in_feature_mask)

        return self

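The tree builder consumes a plain `(n_samples, 2)` float matrix rather than the structured survival array, exactly as packed above. A short sketch with a made-up three-sample array:

import numpy as np
from sksurv.util import Surv

y_toy = Surv.from_arrays(event=[True, False, True], time=[10.0, 12.5, 3.0])

y_numeric = np.empty((y_toy.shape[0], 2), dtype=np.float64)
y_numeric[:, 0] = y_toy["time"]                      # column 0: observed time
y_numeric[:, 1] = y_toy["event"].astype(np.float64)  # column 1: 1.0 event, 0.0 censored
print(y_numeric)
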
    def _check_params(self, n_samples):
        self._validate_params()

        # Check parameters
        max_depth = (2**31) - 1 if self.max_depth is None else self.max_depth

        max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes

        if isinstance(self.min_samples_leaf, Integral):
            min_samples_leaf = self.min_samples_leaf
        else:  # float
            min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))

        if isinstance(self.min_samples_split, Integral):
            min_samples_split = self.min_samples_split
        else:  # float
            min_samples_split = int(ceil(self.min_samples_split * n_samples))
            min_samples_split = max(2, min_samples_split)

        min_samples_split = max(min_samples_split, 2 * min_samples_leaf)

        self._check_max_features()

        min_weight_leaf = self.min_weight_fraction_leaf * n_samples

        return {
            "max_depth": max_depth,
            "max_leaf_nodes": max_leaf_nodes,
            "min_samples_leaf": min_samples_leaf,
            "min_samples_split": min_samples_split,
            "min_weight_leaf": min_weight_leaf,
        }

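A worked example of this resolution, assuming a hypothetical fit on 200 samples with `min_samples_leaf=0.05` and `min_samples_split=0.04`:

from math import ceil

n_samples = 200
min_samples_leaf = int(ceil(0.05 * n_samples))            # 10
min_samples_split = max(2, int(ceil(0.04 * n_samples)))   # 8

# A node can only split if both children can hold min_samples_leaf samples,
# so the split threshold is raised to at least twice the leaf minimum:
min_samples_split = max(min_samples_split, 2 * min_samples_leaf)  # 20
print(min_samples_leaf, min_samples_split)
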
    def _check_max_features(self):
        if isinstance(self.max_features, str):
            if self.max_features == "sqrt":
                max_features = max(1, int(np.sqrt(self.n_features_in_)))
            elif self.max_features == "log2":
                max_features = max(1, int(np.log2(self.n_features_in_)))
        elif self.max_features is None:
            max_features = self.n_features_in_
        elif isinstance(self.max_features, Integral):
            max_features = self.max_features
        else:  # float
            if self.max_features > 0.0:
                max_features = max(1, int(self.max_features * self.n_features_in_))
            else:
                max_features = 0  # pragma: no cover

        if not 0 < max_features <= self.n_features_in_:
            raise ValueError("max_features must be in (0, n_features]")

        self.max_features_ = max_features

    def _check_low_memory(self, function):
        """Check if `function` is supported in low memory mode and throw if it is not."""
        if self.low_memory:
            raise NotImplementedError(
                f"{function} is not implemented in low memory mode."
                " Run fit with low_memory=False to disable low memory mode."
            )

    def _validate_X_predict(self, X, check_input, accept_sparse="csr"):
        """Validate X whenever one tries to predict."""
        if check_input:
            if self._support_missing_values(X):
                ensure_all_finite = "allow-nan"
            else:
                ensure_all_finite = True
            X = validate_data(
                self,
                X,
                dtype=DTYPE,
                accept_sparse=accept_sparse,
                reset=False,
                ensure_all_finite=ensure_all_finite,
            )
        else:
            # The number of features is checked regardless of `check_input`
            _check_n_features(self, X, reset=False)

        return X

    def predict(self, X, check_input=True):
        r"""Predict risk score.

        The risk score is the total number of events, which can
        be estimated by the sum of the estimated cumulative
        hazard function :math:`\hat{H}_h` in terminal node :math:`h`.

        .. math::

            \sum_{j=1}^{n(h)} \hat{H}_h(T_{j} \mid x) ,

        where :math:`n(h)` denotes the number of distinct event times
        of samples belonging to the same terminal node as :math:`x`.

        Parameters
        ----------
        X : array-like or sparse matrix, shape = (n_samples, n_features)
            Data matrix.
            If ``splitter='best'``, `X` is allowed to contain missing
            values and decisions are made as described in
            :ref:`tree_missing_value_support`.

        check_input : boolean, default: True
            Allows bypassing several input checks.
            Don't use this parameter unless you know what you're doing.

        Returns
        -------
        risk_scores : ndarray, shape = (n_samples,), dtype=float
            Predicted risk scores.
        """

        if self.low_memory:
            check_is_fitted(self, "tree_")
            X = self._validate_X_predict(X, check_input, accept_sparse="csr")
            pred = self.tree_.predict(X)
            return pred[..., 0]

        chf = self.predict_cumulative_hazard_function(X, check_input, return_array=True)
        return chf[:, self.is_event_time_].sum(1)

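The final two lines implement the identity stated in the docstring; a quick check on the fitted `est` and `X` from the `fit` sketch above:

import numpy as np

chf = est.predict_cumulative_hazard_function(X, return_array=True)
risk = est.predict(X)

# The risk score sums the CHF over the unique *event* times only;
# time points where every observation was censored are excluded.
assert np.allclose(risk, chf[:, est.is_event_time_].sum(axis=1))
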
    @append_cumulative_hazard_example(estimator_mod="tree", estimator_class="SurvivalTree")
    def predict_cumulative_hazard_function(self, X, check_input=True, return_array=False):
        """Predict cumulative hazard function.

        The cumulative hazard function (CHF) for an individual
        with feature vector :math:`x` is computed from
        all samples of the training data that are in the
        same terminal node as :math:`x`.
        It is estimated by the Nelson–Aalen estimator.

        Parameters
        ----------
        X : array-like or sparse matrix, shape = (n_samples, n_features)
            Data matrix.
            If ``splitter='best'``, `X` is allowed to contain missing
            values and decisions are made as described in
            :ref:`tree_missing_value_support`.

        check_input : boolean, default: True
            Allows bypassing several input checks.
            Don't use this parameter unless you know what you're doing.

        return_array : bool, default: False
            Whether to return a single array of cumulative hazard values
            or a list of step functions.

            If `False`, a list of :class:`sksurv.functions.StepFunction`
            objects is returned.

            If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
            returned, where `n_unique_times` is the number of unique
            event times in the training data. Each row represents the cumulative
            hazard function of an individual evaluated at `unique_times_`.

        Returns
        -------
        cum_hazard : ndarray
            If `return_array` is `False`, an array of `n_samples`
            :class:`sksurv.functions.StepFunction` instances is returned.

            If `return_array` is `True`, a numeric array of shape
            `(n_samples, n_unique_times_)` is returned.

        Examples
        --------
        """
        self._check_low_memory("predict_cumulative_hazard_function")
        check_is_fitted(self, "tree_")
        X = self._validate_X_predict(X, check_input, accept_sparse="csr")

        pred = self.tree_.predict(X)
        arr = pred[..., 0]
        if return_array:
            return arr
        return _array_to_step_function(self.unique_times_, arr)

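Both return forms in use, again with `est` and `X` from above (the empty `Examples` section is filled in at import time by the `append_cumulative_hazard_example` decorator):

# As step functions: one callable per sample.
for fn in est.predict_cumulative_hazard_function(X.iloc[:2]):
    print(fn(est.unique_times_[:5]))  # CHF evaluated at the first five time points

# As a dense array: rows are samples, columns follow est.unique_times_.
chf_arr = est.predict_cumulative_hazard_function(X.iloc[:2], return_array=True)
print(chf_arr.shape)                  # (2, est.unique_times_.shape[0])
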
    @append_survival_function_example(estimator_mod="tree", estimator_class="SurvivalTree")
    def predict_survival_function(self, X, check_input=True, return_array=False):
        """Predict survival function.

        The survival function for an individual
        with feature vector :math:`x` is computed from
        all samples of the training data that are in the
        same terminal node as :math:`x`.
        It is estimated by the Kaplan-Meier estimator.

        Parameters
        ----------
        X : array-like or sparse matrix, shape = (n_samples, n_features)
            Data matrix.
            If ``splitter='best'``, `X` is allowed to contain missing
            values and decisions are made as described in
            :ref:`tree_missing_value_support`.

        check_input : boolean, default: True
            Allows bypassing several input checks.
            Don't use this parameter unless you know what you're doing.

        return_array : bool, default: False
            Whether to return a single array of survival probabilities
            or a list of step functions.

            If `False`, a list of :class:`sksurv.functions.StepFunction`
            objects is returned.

            If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
            returned, where `n_unique_times` is the number of unique
            event times in the training data. Each row represents the survival
            function of an individual evaluated at `unique_times_`.

        Returns
        -------
        survival : ndarray
            If `return_array` is `False`, an array of `n_samples`
            :class:`sksurv.functions.StepFunction` instances is returned.

            If `return_array` is `True`, a numeric array of shape
            `(n_samples, n_unique_times_)` is returned.

        Examples
        --------
        """
        self._check_low_memory("predict_survival_function")
        check_is_fitted(self, "tree_")
        X = self._validate_X_predict(X, check_input, accept_sparse="csr")

        pred = self.tree_.predict(X)
        arr = pred[..., 1]
        if return_array:
            return arr
        return _array_to_step_function(self.unique_times_, arr)

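A sanity check on the array form, reusing `est` and `X`: per-leaf Kaplan-Meier curves are non-increasing in time:

import numpy as np

surv = est.predict_survival_function(X.iloc[:3], return_array=True)
assert surv.shape == (3, est.unique_times_.shape[0])
# Survival probabilities never increase over time.
assert np.all(np.diff(surv, axis=1) <= 0)
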
    def apply(self, X, check_input=True):
        """Return the index of the leaf that each sample is predicted as.

        Parameters
        ----------
        X : array-like or sparse matrix, shape = (n_samples, n_features)
            The input samples. Internally, it will be converted to
            ``dtype=np.float32``, and a sparse matrix will be converted
            to a sparse ``csr_matrix``.
            If ``splitter='best'``, `X` is allowed to contain missing
            values and decisions are made as described in
            :ref:`tree_missing_value_support`.

        check_input : bool, default: True
            Allows bypassing several input checks.
            Don't use this parameter unless you know what you're doing.

        Returns
        -------
        X_leaves : ndarray, shape = (n_samples,), dtype=int
            For each datapoint x in X, return the index of the leaf x
            ends up in. Leaves are numbered within
            ``[0; self.tree_.node_count)``, possibly with gaps in the
            numbering.
        """
        check_is_fitted(self, "tree_")
        X = self._validate_X_predict(X, check_input)
        return self.tree_.apply(X)

    def decision_path(self, X, check_input=True):
        """Return the decision path in the tree.

        Parameters
        ----------
        X : array-like or sparse matrix, shape = (n_samples, n_features)
            The input samples. Internally, it will be converted to
            ``dtype=np.float32``, and a sparse matrix will be converted
            to a sparse ``csr_matrix``.
            If ``splitter='best'``, `X` is allowed to contain missing
            values and decisions are made as described in
            :ref:`tree_missing_value_support`.

        check_input : bool, default=True
            Allows bypassing several input checks.
            Don't use this parameter unless you know what you're doing.

        Returns
        -------
        indicator : sparse matrix, shape = (n_samples, n_nodes)
            Return a node indicator CSR matrix, where non-zero elements
            indicate that the sample goes through the corresponding nodes.
        """
        X = self._validate_X_predict(X, check_input)
        return self.tree_.decision_path(X)

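Leaf assignments and traversal paths for the first two samples of the running example; samples that land in the same leaf share the same Nelson-Aalen and Kaplan-Meier estimates:

leaves = est.apply(X.iloc[:2])
print(leaves)            # node indices of the two leaves, possibly with gaps

paths = est.decision_path(X.iloc[:2])
# CSR indicator: row i marks every node sample i traverses, root to leaf.
print(paths.toarray())
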
class ExtraSurvivalTree(SurvivalTree):
    """An Extremely Randomized Survival Tree.

    This class implements an Extremely Randomized Tree for survival analysis.
    It differs from :class:`SurvivalTree` in how splits are chosen:
    instead of searching for the optimal split, it considers a random subset
    of features and random thresholds for each feature, then picks the best
    among these random candidates.

    Parameters
    ----------
    splitter : {'best', 'random'}, default: 'random'
        The strategy used to choose the split at each node. Supported
        strategies are 'best' to choose the best split and 'random' to choose
        the best random split.

    max_depth : int or None, optional, default: None
        The maximum depth of the tree. If None, then nodes are expanded until
        all leaves are pure or until all leaves contain less than
        `min_samples_split` samples.

    min_samples_split : int or float, optional, default: 6
        The minimum number of samples required to split an internal node:

        - If int, then consider `min_samples_split` as the minimum number.
        - If float, then `min_samples_split` is a fraction and
          `ceil(min_samples_split * n_samples)` is the minimum
          number of samples for each split.

    min_samples_leaf : int or float, optional, default: 3
        The minimum number of samples required to be at a leaf node.
        A split point at any depth will only be considered if it leaves at
        least ``min_samples_leaf`` training samples in each of the left and
        right branches. This may have the effect of smoothing the model,
        especially in regression.

        - If int, then consider `min_samples_leaf` as the minimum number.
        - If float, then `min_samples_leaf` is a fraction and
          `ceil(min_samples_leaf * n_samples)` is the minimum
          number of samples for each node.

    min_weight_fraction_leaf : float, optional, default: 0.
        The minimum weighted fraction of the sum total of weights (of all
        the input samples) required to be at a leaf node. Samples have
        equal weight when sample_weight is not provided.

    max_features : int, float, {'sqrt', 'log2'} or None, optional, default: None
        The number of features to consider when looking for the best split:

        - If int, then consider `max_features` features at each split.
        - If float, then `max_features` is a fraction and
          `max(1, int(max_features * n_features_in_))` features are considered at
          each split.
        - If "sqrt", then `max_features=sqrt(n_features)`.
        - If "log2", then `max_features=log2(n_features)`.
        - If None, then `max_features=n_features`.

        Note: the search for a split does not stop until at least one
        valid partition of the node samples is found, even if it requires
        effectively inspecting more than ``max_features`` features.

    random_state : int, RandomState instance or None, optional, default: None
        Controls the randomness of the estimator. The features are always
        randomly permuted at each split, even if ``splitter`` is set to
        ``"best"``. When ``max_features < n_features``, the algorithm will
        select ``max_features`` at random at each split before finding the best
        split among them. But the best found split may vary across different
        runs, even if ``max_features=n_features``. That is the case if the
        improvement of the criterion is identical for several splits and one
        split has to be selected at random. To obtain deterministic behavior
        during fitting, ``random_state`` has to be fixed to an integer.

    max_leaf_nodes : int or None, optional, default: None
        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
        Best nodes are defined as relative reduction in impurity.
        If None, the number of leaf nodes is unlimited.

    low_memory : bool, optional, default: False
        If set, :meth:`predict` computations use reduced memory, but :meth:`predict_cumulative_hazard_function`
        and :meth:`predict_survival_function` are not implemented.

    Attributes
    ----------
    unique_times_ : ndarray, shape = (n_unique_times,), dtype = float
        Unique time points.

    max_features_ : int
        The inferred value of max_features.

    n_features_in_ : int
        Number of features seen during ``fit``.

    feature_names_in_ : ndarray, shape = (`n_features_in_`,), dtype = object
        Names of features seen during ``fit``. Defined only when `X`
        has feature names that are all strings.

    tree_ : Tree object
        The underlying Tree object. Please refer to
        ``help(sklearn.tree._tree.Tree)`` for attributes of the Tree object.

    See also
    --------
    sksurv.ensemble.ExtraSurvivalTrees : An ensemble of ExtraSurvivalTrees.
    """

    def __init__(
        self,
        *,
        splitter="random",
        max_depth=None,
        min_samples_split=6,
        min_samples_leaf=3,
        min_weight_fraction_leaf=0.0,
        max_features=None,
        random_state=None,
        max_leaf_nodes=None,
        low_memory=False,
    ):
        super().__init__(
            splitter=splitter,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            random_state=random_state,
            max_leaf_nodes=max_leaf_nodes,
            low_memory=low_memory,
        )

    def predict_cumulative_hazard_function(self, X, check_input=True, return_array=False):
        ExtraSurvivalTree.predict_cumulative_hazard_function.__doc__ = (
            SurvivalTree.predict_cumulative_hazard_function.__doc__.replace("SurvivalTree", "ExtraSurvivalTree")
        )
        return super().predict_cumulative_hazard_function(X, check_input=check_input, return_array=return_array)

    def predict_survival_function(self, X, check_input=True, return_array=False):
        ExtraSurvivalTree.predict_survival_function.__doc__ = SurvivalTree.predict_survival_function.__doc__.replace(
            "SurvivalTree", "ExtraSurvivalTree"
        )
        return super().predict_survival_function(X, check_input=check_input, return_array=return_array)
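A closing sketch contrasting the randomized variant with the deterministic tree, reusing `est`, `X`, and `y` from the `fit` sketch; `score` comes from `SurvivalAnalysisMixin` and reports Harrell's concordance index:

from sksurv.tree import ExtraSurvivalTree

extra = ExtraSurvivalTree(random_state=0)
extra.fit(X, y)

# Identical API; only split selection is randomized, so the concordance
# on the training data is typically close to, but not exactly, that of
# the deterministic tree.
print(est.score(X, y), extra.score(X, y))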