scikit-survival 0.24.0__cp312-cp312-win_amd64.whl → 0.25.0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.25.0.dist-info/METADATA +185 -0
- scikit_survival-0.25.0.dist-info/RECORD +58 -0
- {scikit_survival-0.24.0.dist-info → scikit_survival-0.25.0.dist-info}/WHEEL +1 -1
- sksurv/__init__.py +51 -6
- sksurv/base.py +12 -2
- sksurv/bintrees/_binarytrees.cp312-win_amd64.pyd +0 -0
- sksurv/column.py +33 -29
- sksurv/compare.py +22 -22
- sksurv/datasets/base.py +45 -20
- sksurv/docstrings.py +99 -0
- sksurv/ensemble/_coxph_loss.cp312-win_amd64.pyd +0 -0
- sksurv/ensemble/boosting.py +116 -168
- sksurv/ensemble/forest.py +94 -151
- sksurv/functions.py +29 -29
- sksurv/io/arffread.py +34 -3
- sksurv/io/arffwrite.py +38 -2
- sksurv/kernels/_clinical_kernel.cp312-win_amd64.pyd +0 -0
- sksurv/kernels/clinical.py +33 -13
- sksurv/linear_model/_coxnet.cp312-win_amd64.pyd +0 -0
- sksurv/linear_model/aft.py +14 -11
- sksurv/linear_model/coxnet.py +138 -89
- sksurv/linear_model/coxph.py +102 -83
- sksurv/meta/ensemble_selection.py +91 -9
- sksurv/meta/stacking.py +47 -26
- sksurv/metrics.py +257 -224
- sksurv/nonparametric.py +150 -81
- sksurv/preprocessing.py +55 -27
- sksurv/svm/_minlip.cp312-win_amd64.pyd +0 -0
- sksurv/svm/_prsvm.cp312-win_amd64.pyd +0 -0
- sksurv/svm/minlip.py +160 -79
- sksurv/svm/naive_survival_svm.py +63 -34
- sksurv/svm/survival_svm.py +104 -104
- sksurv/tree/_criterion.cp312-win_amd64.pyd +0 -0
- sksurv/tree/tree.py +170 -84
- sksurv/util.py +80 -26
- scikit_survival-0.24.0.dist-info/METADATA +0 -888
- scikit_survival-0.24.0.dist-info/RECORD +0 -57
- {scikit_survival-0.24.0.dist-info → scikit_survival-0.25.0.dist-info/licenses}/COPYING +0 -0
- {scikit_survival-0.24.0.dist-info → scikit_survival-0.25.0.dist-info}/top_level.txt +0 -0
sksurv/tree/tree.py
CHANGED
|
@@ -20,6 +20,7 @@ from sklearn.utils.validation import (
|
|
|
20
20
|
)
|
|
21
21
|
|
|
22
22
|
from ..base import SurvivalAnalysisMixin
|
|
23
|
+
from ..docstrings import append_cumulative_hazard_example, append_survival_function_example
|
|
23
24
|
from ..functions import StepFunction
|
|
24
25
|
from ..util import check_array_survival
|
|
25
26
|
from ._criterion import LogrankCriterion, get_unique_times
|
|
@@ -38,10 +39,9 @@ def _array_to_step_function(x, array):
|
|
|
38
39
|
|
|
39
40
|
|
|
40
41
|
class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
41
|
-
"""A survival tree.
|
|
42
|
+
"""A single survival tree.
|
|
42
43
|
|
|
43
|
-
The quality of a split is measured by the
|
|
44
|
-
log-rank splitting rule.
|
|
44
|
+
The quality of a split is measured by the log-rank splitting rule.
|
|
45
45
|
|
|
46
46
|
If ``splitter='best'``, fit and predict methods support
|
|
47
47
|
missing values. See :ref:`tree_missing_value_support` for details.
|
|
@@ -85,7 +85,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
85
85
|
the input samples) required to be at a leaf node. Samples have
|
|
86
86
|
equal weight when sample_weight is not provided.
|
|
87
87
|
|
|
88
|
-
max_features : int, float,
|
|
88
|
+
max_features : int, float or {'sqrt', 'log2'} or None, optional, default: None
|
|
89
89
|
The number of features to consider when looking for the best split:
|
|
90
90
|
|
|
91
91
|
- If int, then consider `max_features` features at each split.
|
|
@@ -116,22 +116,22 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
116
116
|
Best nodes are defined as relative reduction in impurity.
|
|
117
117
|
If None then unlimited number of leaf nodes.
|
|
118
118
|
|
|
119
|
-
low_memory :
|
|
120
|
-
If set,
|
|
121
|
-
and
|
|
119
|
+
low_memory : bool, optional, default: False
|
|
120
|
+
If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
|
|
121
|
+
and :meth:`predict_survival_function` are not implemented.
|
|
122
122
|
|
|
123
123
|
Attributes
|
|
124
124
|
----------
|
|
125
|
-
unique_times_ :
|
|
125
|
+
unique_times_ : ndarray, shape = (n_unique_times,), dtype = float
|
|
126
126
|
Unique time points.
|
|
127
127
|
|
|
128
|
-
max_features_ : int
|
|
128
|
+
max_features_ : int
|
|
129
129
|
The inferred value of max_features.
|
|
130
130
|
|
|
131
131
|
n_features_in_ : int
|
|
132
132
|
Number of features seen during ``fit``.
|
|
133
133
|
|
|
134
|
-
feature_names_in_ : ndarray
|
|
134
|
+
feature_names_in_ : ndarray, shape = (`n_features_in_`,), dtype = object
|
|
135
135
|
Names of features seen during ``fit``. Defined only when `X`
|
|
136
136
|
has feature names that are all strings.
|
|
137
137
|
|
|
@@ -141,8 +141,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
141
141
|
|
|
142
142
|
See also
|
|
143
143
|
--------
|
|
144
|
-
sksurv.ensemble.RandomSurvivalForest
|
|
145
|
-
An ensemble of SurvivalTrees.
|
|
144
|
+
sksurv.ensemble.RandomSurvivalForest : An ensemble of SurvivalTrees.
|
|
146
145
|
|
|
147
146
|
References
|
|
148
147
|
----------
|
|
@@ -219,7 +218,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
219
218
|
|
|
220
219
|
Parameter
|
|
221
220
|
---------
|
|
222
|
-
X : array-like
|
|
221
|
+
X : array-like, shape = (n_samples, n_features), dtype = DOUBLE
|
|
223
222
|
Input data.
|
|
224
223
|
|
|
225
224
|
estimator_name : str or None, default=None
|
|
@@ -267,9 +266,9 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
267
266
|
Data matrix
|
|
268
267
|
|
|
269
268
|
y : structured array, shape = (n_samples,)
|
|
270
|
-
A structured array
|
|
271
|
-
|
|
272
|
-
second field.
|
|
269
|
+
A structured array with two fields. The first field is a boolean
|
|
270
|
+
where ``True`` indicates an event and ``False`` indicates right-censoring.
|
|
271
|
+
The second field is a float with the time of event or time of censoring.
|
|
273
272
|
|
|
274
273
|
check_input : boolean, default: True
|
|
275
274
|
Allow to bypass several input checking.
|
|
@@ -440,15 +439,15 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
440
439
|
return X
|
|
441
440
|
|
|
442
441
|
def predict(self, X, check_input=True):
|
|
443
|
-
"""Predict risk score.
|
|
442
|
+
r"""Predict risk score.
|
|
444
443
|
|
|
445
444
|
The risk score is the total number of events, which can
|
|
446
445
|
be estimated by the sum of the estimated cumulative
|
|
447
|
-
hazard function :math
|
|
446
|
+
hazard function :math:`\hat{H}_h` in terminal node :math:`h`.
|
|
448
447
|
|
|
449
448
|
.. math::
|
|
450
449
|
|
|
451
|
-
|
|
450
|
+
\sum_{j=1}^{n(h)} \hat{H}_h(T_{j} \mid x) ,
|
|
452
451
|
|
|
453
452
|
where :math:`n(h)` denotes the number of distinct event times
|
|
454
453
|
of samples belonging to the same terminal node as :math:`x`.
|
|
@@ -467,7 +466,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
467
466
|
|
|
468
467
|
Returns
|
|
469
468
|
-------
|
|
470
|
-
risk_scores : ndarray, shape = (n_samples,)
|
|
469
|
+
risk_scores : ndarray, shape = (n_samples,), dtype=float
|
|
471
470
|
Predicted risk scores.
|
|
472
471
|
"""
|
|
473
472
|
|
|
@@ -480,6 +479,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
480
479
|
chf = self.predict_cumulative_hazard_function(X, check_input, return_array=True)
|
|
481
480
|
return chf[:, self.is_event_time_].sum(1)
|
|
482
481
|
|
|
482
|
+
@append_cumulative_hazard_example(estimator_mod="tree", estimator_class="SurvivalTree")
|
|
483
483
|
def predict_cumulative_hazard_function(self, X, check_input=True, return_array=False):
|
|
484
484
|
"""Predict cumulative hazard function.
|
|
485
485
|
|
|
@@ -501,44 +501,29 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
501
501
|
Allow to bypass several input checking.
|
|
502
502
|
Don't use this parameter unless you know what you do.
|
|
503
503
|
|
|
504
|
-
return_array :
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
504
|
+
return_array : bool, default: False
|
|
505
|
+
Whether to return a single array of cumulative hazard values
|
|
506
|
+
or a list of step functions.
|
|
507
|
+
|
|
508
|
+
If `False`, a list of :class:`sksurv.functions.StepFunction`
|
|
509
|
+
objects is returned.
|
|
510
|
+
|
|
511
|
+
If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
|
|
512
|
+
returned, where `n_unique_times` is the number of unique
|
|
513
|
+
event times in the training data. Each row represents the cumulative
|
|
514
|
+
hazard function of an individual evaluated at `unique_times_`.
|
|
508
515
|
|
|
509
516
|
Returns
|
|
510
517
|
-------
|
|
511
518
|
cum_hazard : ndarray
|
|
512
|
-
If `return_array` is
|
|
513
|
-
|
|
514
|
-
|
|
519
|
+
If `return_array` is `False`, an array of `n_samples`
|
|
520
|
+
:class:`sksurv.functions.StepFunction` instances is returned.
|
|
521
|
+
|
|
522
|
+
If `return_array` is `True`, a numeric array of shape
|
|
523
|
+
`(n_samples, n_unique_times_)` is returned.
|
|
515
524
|
|
|
516
525
|
Examples
|
|
517
526
|
--------
|
|
518
|
-
>>> import matplotlib.pyplot as plt
|
|
519
|
-
>>> from sksurv.datasets import load_whas500
|
|
520
|
-
>>> from sksurv.tree import SurvivalTree
|
|
521
|
-
|
|
522
|
-
Load and prepare the data.
|
|
523
|
-
|
|
524
|
-
>>> X, y = load_whas500()
|
|
525
|
-
>>> X = X.astype(float)
|
|
526
|
-
|
|
527
|
-
Fit the model.
|
|
528
|
-
|
|
529
|
-
>>> estimator = SurvivalTree().fit(X, y)
|
|
530
|
-
|
|
531
|
-
Estimate the cumulative hazard function for the first 5 samples.
|
|
532
|
-
|
|
533
|
-
>>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])
|
|
534
|
-
|
|
535
|
-
Plot the estimated cumulative hazard functions.
|
|
536
|
-
|
|
537
|
-
>>> for fn in chf_funcs:
|
|
538
|
-
... plt.step(fn.x, fn(fn.x), where="post")
|
|
539
|
-
...
|
|
540
|
-
>>> plt.ylim(0, 1)
|
|
541
|
-
>>> plt.show()
|
|
542
527
|
"""
|
|
543
528
|
self._check_low_memory("predict_cumulative_hazard_function")
|
|
544
529
|
check_is_fitted(self, "tree_")
|
|
@@ -550,6 +535,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
550
535
|
return arr
|
|
551
536
|
return _array_to_step_function(self.unique_times_, arr)
|
|
552
537
|
|
|
538
|
+
@append_survival_function_example(estimator_mod="tree", estimator_class="SurvivalTree")
|
|
553
539
|
def predict_survival_function(self, X, check_input=True, return_array=False):
|
|
554
540
|
"""Predict survival function.
|
|
555
541
|
|
|
@@ -571,45 +557,29 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
571
557
|
Allow to bypass several input checking.
|
|
572
558
|
Don't use this parameter unless you know what you do.
|
|
573
559
|
|
|
574
|
-
return_array :
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
560
|
+
return_array : bool, default: False
|
|
561
|
+
Whether to return a single array of survival probabilities
|
|
562
|
+
or a list of step functions.
|
|
563
|
+
|
|
564
|
+
If `False`, a list of :class:`sksurv.functions.StepFunction`
|
|
565
|
+
objects is returned.
|
|
566
|
+
|
|
567
|
+
If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
|
|
568
|
+
returned, where `n_unique_times` is the number of unique
|
|
569
|
+
event times in the training data. Each row represents the survival
|
|
570
|
+
function of an individual evaluated at `unique_times_`.
|
|
578
571
|
|
|
579
572
|
Returns
|
|
580
573
|
-------
|
|
581
574
|
survival : ndarray
|
|
582
|
-
If `return_array` is
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
575
|
+
If `return_array` is `False`, an array of `n_samples`
|
|
576
|
+
:class:`sksurv.functions.StepFunction` instances is returned.
|
|
577
|
+
|
|
578
|
+
If `return_array` is `True`, a numeric array of shape
|
|
579
|
+
`(n_samples, n_unique_times_)` is returned.
|
|
586
580
|
|
|
587
581
|
Examples
|
|
588
582
|
--------
|
|
589
|
-
>>> import matplotlib.pyplot as plt
|
|
590
|
-
>>> from sksurv.datasets import load_whas500
|
|
591
|
-
>>> from sksurv.tree import SurvivalTree
|
|
592
|
-
|
|
593
|
-
Load and prepare the data.
|
|
594
|
-
|
|
595
|
-
>>> X, y = load_whas500()
|
|
596
|
-
>>> X = X.astype(float)
|
|
597
|
-
|
|
598
|
-
Fit the model.
|
|
599
|
-
|
|
600
|
-
>>> estimator = SurvivalTree().fit(X, y)
|
|
601
|
-
|
|
602
|
-
Estimate the survival function for the first 5 samples.
|
|
603
|
-
|
|
604
|
-
>>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])
|
|
605
|
-
|
|
606
|
-
Plot the estimated survival functions.
|
|
607
|
-
|
|
608
|
-
>>> for fn in surv_funcs:
|
|
609
|
-
... plt.step(fn.x, fn(fn.x), where="post")
|
|
610
|
-
...
|
|
611
|
-
>>> plt.ylim(0, 1)
|
|
612
|
-
>>> plt.show()
|
|
613
583
|
"""
|
|
614
584
|
self._check_low_memory("predict_survival_function")
|
|
615
585
|
check_is_fitted(self, "tree_")
|
|
@@ -640,7 +610,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
640
610
|
|
|
641
611
|
Returns
|
|
642
612
|
-------
|
|
643
|
-
X_leaves :
|
|
613
|
+
X_leaves : ndarray, shape = (n_samples,), dtype=int
|
|
644
614
|
For each datapoint x in X, return the index of the leaf x
|
|
645
615
|
ends up in. Leaves are numbered within
|
|
646
616
|
``[0; self.tree_.node_count)``, possibly with gaps in the
|
|
@@ -678,6 +648,110 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
678
648
|
|
|
679
649
|
|
|
680
650
|
class ExtraSurvivalTree(SurvivalTree):
|
|
651
|
+
"""An Extremely Randomized Survival Tree.
|
|
652
|
+
|
|
653
|
+
This class implements an Extremely Randomized Tree for survival analysis.
|
|
654
|
+
It differs from :class:`SurvivalTree` in how splits are chosen:
|
|
655
|
+
instead of searching for the optimal split, it considers a random subset
|
|
656
|
+
of features and random thresholds for each feature, then picks the best
|
|
657
|
+
among these random candidates.
|
|
658
|
+
|
|
659
|
+
Parameters
|
|
660
|
+
----------
|
|
661
|
+
splitter : {'best', 'random'}, default: 'random'
|
|
662
|
+
The strategy used to choose the split at each node. Supported
|
|
663
|
+
strategies are 'best' to choose the best split and 'random' to choose
|
|
664
|
+
the best random split.
|
|
665
|
+
|
|
666
|
+
max_depth : int or None, optional, default: None
|
|
667
|
+
The maximum depth of the tree. If None, then nodes are expanded until
|
|
668
|
+
all leaves are pure or until all leaves contain less than
|
|
669
|
+
`min_samples_split` samples.
|
|
670
|
+
|
|
671
|
+
min_samples_split : int, float, optional, default: 6
|
|
672
|
+
The minimum number of samples required to split an internal node:
|
|
673
|
+
|
|
674
|
+
- If int, then consider `min_samples_split` as the minimum number.
|
|
675
|
+
- If float, then `min_samples_split` is a fraction and
|
|
676
|
+
`ceil(min_samples_split * n_samples)` are the minimum
|
|
677
|
+
number of samples for each split.
|
|
678
|
+
|
|
679
|
+
min_samples_leaf : int, float, optional, default: 3
|
|
680
|
+
The minimum number of samples required to be at a leaf node.
|
|
681
|
+
A split point at any depth will only be considered if it leaves at
|
|
682
|
+
least ``min_samples_leaf`` training samples in each of the left and
|
|
683
|
+
right branches. This may have the effect of smoothing the model,
|
|
684
|
+
especially in regression.
|
|
685
|
+
|
|
686
|
+
- If int, then consider `min_samples_leaf` as the minimum number.
|
|
687
|
+
- If float, then `min_samples_leaf` is a fraction and
|
|
688
|
+
`ceil(min_samples_leaf * n_samples)` are the minimum
|
|
689
|
+
number of samples for each node.
|
|
690
|
+
|
|
691
|
+
min_weight_fraction_leaf : float, optional, default: 0.
|
|
692
|
+
The minimum weighted fraction of the sum total of weights (of all
|
|
693
|
+
the input samples) required to be at a leaf node. Samples have
|
|
694
|
+
equal weight when sample_weight is not provided.
|
|
695
|
+
|
|
696
|
+
max_features : int, float or {'sqrt', 'log2'} or None, optional, default: None
|
|
697
|
+
The number of features to consider when looking for the best split:
|
|
698
|
+
|
|
699
|
+
- If int, then consider `max_features` features at each split.
|
|
700
|
+
- If float, then `max_features` is a fraction and
|
|
701
|
+
`max(1, int(max_features * n_features_in_))` features are considered at
|
|
702
|
+
each split.
|
|
703
|
+
- If "sqrt", then `max_features=sqrt(n_features)`.
|
|
704
|
+
- If "log2", then `max_features=log2(n_features)`.
|
|
705
|
+
- If None, then `max_features=n_features`.
|
|
706
|
+
|
|
707
|
+
Note: the search for a split does not stop until at least one
|
|
708
|
+
valid partition of the node samples is found, even if it requires to
|
|
709
|
+
effectively inspect more than ``max_features`` features.
|
|
710
|
+
|
|
711
|
+
random_state : int, RandomState instance or None, optional, default: None
|
|
712
|
+
Controls the randomness of the estimator. The features are always
|
|
713
|
+
randomly permuted at each split, even if ``splitter`` is set to
|
|
714
|
+
``"best"``. When ``max_features < n_features``, the algorithm will
|
|
715
|
+
select ``max_features`` at random at each split before finding the best
|
|
716
|
+
split among them. But the best found split may vary across different
|
|
717
|
+
runs, even if ``max_features=n_features``. That is the case, if the
|
|
718
|
+
improvement of the criterion is identical for several splits and one
|
|
719
|
+
split has to be selected at random. To obtain a deterministic behavior
|
|
720
|
+
during fitting, ``random_state`` has to be fixed to an integer.
|
|
721
|
+
|
|
722
|
+
max_leaf_nodes : int or None, optional, default: None
|
|
723
|
+
Grow a tree with ``max_leaf_nodes`` in best-first fashion.
|
|
724
|
+
Best nodes are defined as relative reduction in impurity.
|
|
725
|
+
If None then unlimited number of leaf nodes.
|
|
726
|
+
|
|
727
|
+
low_memory : bool, optional, default: False
|
|
728
|
+
If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
|
|
729
|
+
and :meth:`predict_survival_function` are not implemented.
|
|
730
|
+
|
|
731
|
+
Attributes
|
|
732
|
+
----------
|
|
733
|
+
unique_times_ : ndarray, shape = (n_unique_times,), dtype = float
|
|
734
|
+
Unique time points.
|
|
735
|
+
|
|
736
|
+
max_features_ : int
|
|
737
|
+
The inferred value of max_features.
|
|
738
|
+
|
|
739
|
+
n_features_in_ : int
|
|
740
|
+
Number of features seen during ``fit``.
|
|
741
|
+
|
|
742
|
+
feature_names_in_ : ndarray, shape = (`n_features_in_`,), dtype = object
|
|
743
|
+
Names of features seen during ``fit``. Defined only when `X`
|
|
744
|
+
has feature names that are all strings.
|
|
745
|
+
|
|
746
|
+
tree_ : Tree object
|
|
747
|
+
The underlying Tree object. Please refer to
|
|
748
|
+
``help(sklearn.tree._tree.Tree)`` for attributes of Tree object.
|
|
749
|
+
|
|
750
|
+
See also
|
|
751
|
+
--------
|
|
752
|
+
sksurv.ensemble.ExtraSurvivalTrees : An ensemble of ExtraSurvivalTrees.
|
|
753
|
+
"""
|
|
754
|
+
|
|
681
755
|
def __init__(
|
|
682
756
|
self,
|
|
683
757
|
*,
|
|
@@ -702,3 +776,15 @@ class ExtraSurvivalTree(SurvivalTree):
|
|
|
702
776
|
max_leaf_nodes=max_leaf_nodes,
|
|
703
777
|
low_memory=low_memory,
|
|
704
778
|
)
|
|
779
|
+
|
|
780
|
+
def predict_cumulative_hazard_function(self, X, check_input=True, return_array=False):
|
|
781
|
+
ExtraSurvivalTree.predict_cumulative_hazard_function.__doc__ = (
|
|
782
|
+
SurvivalTree.predict_cumulative_hazard_function.__doc__.replace("SurvivalTree", "ExtraSurvivalTree")
|
|
783
|
+
)
|
|
784
|
+
return super().predict_cumulative_hazard_function(X, check_input=check_input, return_array=return_array)
|
|
785
|
+
|
|
786
|
+
def predict_survival_function(self, X, check_input=True, return_array=False):
|
|
787
|
+
ExtraSurvivalTree.predict_survival_function.__doc__ = SurvivalTree.predict_survival_function.__doc__.replace(
|
|
788
|
+
"SurvivalTree", "ExtraSurvivalTree"
|
|
789
|
+
)
|
|
790
|
+
return super().predict_survival_function(X, check_input=check_input, return_array=return_array)
|
sksurv/util.py
CHANGED
|
@@ -19,29 +19,52 @@ __all__ = ["check_array_survival", "check_y_survival", "safe_concat", "Surv"]
|
|
|
19
19
|
|
|
20
20
|
|
|
21
21
|
class Surv:
|
|
22
|
-
"""
|
|
23
|
-
|
|
22
|
+
"""A helper class to create a structured array for survival analysis.
|
|
23
|
+
|
|
24
|
+
This class provides helper functions to create a structured array that
|
|
25
|
+
encapsulates the event indicator and the observed time. The resulting
|
|
26
|
+
structured array is the recommended format for the ``y`` argument in
|
|
27
|
+
scikit-survival's estimators.
|
|
24
28
|
"""
|
|
25
29
|
|
|
26
30
|
@staticmethod
|
|
27
31
|
def from_arrays(event, time, name_event=None, name_time=None):
|
|
28
|
-
"""Create structured array.
|
|
32
|
+
"""Create structured array from event indicator and time arrays.
|
|
29
33
|
|
|
30
34
|
Parameters
|
|
31
35
|
----------
|
|
32
|
-
event : array-like
|
|
33
|
-
Event indicator. A boolean array or array with values 0/1
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
Name of
|
|
36
|
+
event : array-like, shape=(n_samples,)
|
|
37
|
+
Event indicator. A boolean array or array with values 0/1,
|
|
38
|
+
where ``True`` or 1 indicates an event and ``False`` or 0
|
|
39
|
+
indicates right-censoring.
|
|
40
|
+
time : array-like, shape=(n_samples,)
|
|
41
|
+
Observed time. Time to event or time of censoring.
|
|
42
|
+
name_event : str, optional, default: 'event'
|
|
43
|
+
Name of the event field in the structured array.
|
|
44
|
+
name_time : str, optional, default: 'time'
|
|
45
|
+
Name of the observed time field in the structured array.
|
|
40
46
|
|
|
41
47
|
Returns
|
|
42
48
|
-------
|
|
43
|
-
y :
|
|
44
|
-
|
|
49
|
+
y : numpy.ndarray
|
|
50
|
+
A structured array with two fields. The first field is a boolean
|
|
51
|
+
where ``True`` indicates an event and ``False`` indicates right-censoring.
|
|
52
|
+
The second field is a float with the time of event or time of censoring.
|
|
53
|
+
The names of the fields are set to the values of `name_event` and `name_time`.
|
|
54
|
+
|
|
55
|
+
Examples
|
|
56
|
+
--------
|
|
57
|
+
>>> from sksurv.util import Surv
|
|
58
|
+
>>>
|
|
59
|
+
>>> y = Surv.from_arrays(event=[True, False, True],
|
|
60
|
+
... time=[10, 25, 15])
|
|
61
|
+
>>> y
|
|
62
|
+
array([( True, 10.), (False, 25.), ( True, 15.)],
|
|
63
|
+
dtype=[('event', '?'), ('time', '<f8')])
|
|
64
|
+
>>> y['event']
|
|
65
|
+
array([ True, False, True])
|
|
66
|
+
>>> y['time']
|
|
67
|
+
array([10., 25., 15.])
|
|
45
68
|
"""
|
|
46
69
|
name_event = name_event or "event"
|
|
47
70
|
name_time = name_time or "time"
|
|
@@ -72,21 +95,48 @@ class Surv:
|
|
|
72
95
|
|
|
73
96
|
@staticmethod
|
|
74
97
|
def from_dataframe(event, time, data):
|
|
75
|
-
"""Create structured array from
|
|
98
|
+
"""Create structured array from columns in a pandas DataFrame.
|
|
76
99
|
|
|
77
100
|
Parameters
|
|
78
101
|
----------
|
|
79
|
-
event :
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
102
|
+
event : str
|
|
103
|
+
Name of the column in ``data`` containing the event indicator.
|
|
104
|
+
It must be a boolean or have values 0/1,
|
|
105
|
+
where ``True`` or 1 indicates an event and ``False`` or 0
|
|
106
|
+
indicates right-censoring.
|
|
107
|
+
time : str
|
|
108
|
+
Name of the column in ``data`` containing the observed time
|
|
109
|
+
(time to event or time of censoring).
|
|
83
110
|
data : pandas.DataFrame
|
|
84
|
-
|
|
111
|
+
A DataFrame with columns for event and time.
|
|
85
112
|
|
|
86
113
|
Returns
|
|
87
114
|
-------
|
|
88
|
-
y :
|
|
89
|
-
|
|
115
|
+
y : numpy.ndarray
|
|
116
|
+
A structured array with two fields. The first field is a boolean
|
|
117
|
+
where ``True`` indicates an event and ``False`` indicates right-censoring.
|
|
118
|
+
The second field is a float with the time of event or time of censoring.
|
|
119
|
+
The names of the fields are the respective column names.
|
|
120
|
+
|
|
121
|
+
Examples
|
|
122
|
+
--------
|
|
123
|
+
>>> import pandas as pd
|
|
124
|
+
>>> from sksurv.util import Surv
|
|
125
|
+
>>>
|
|
126
|
+
>>> df = pd.DataFrame({
|
|
127
|
+
... 'status': [True, False, True],
|
|
128
|
+
... 'followup_time': [10, 25, 15],
|
|
129
|
+
... })
|
|
130
|
+
>>> y = Surv.from_dataframe(
|
|
131
|
+
... event='status', time='followup_time', data=df,
|
|
132
|
+
... )
|
|
133
|
+
>>> y
|
|
134
|
+
array([( True, 10.), (False, 25.), ( True, 15.)],
|
|
135
|
+
dtype=[('status', '?'), ('followup_time', '<f8')])
|
|
136
|
+
>>> y['status']
|
|
137
|
+
array([ True, False, True])
|
|
138
|
+
>>> y['followup_time']
|
|
139
|
+
array([10., 25., 15.])
|
|
90
140
|
"""
|
|
91
141
|
if not isinstance(data, pd.DataFrame):
|
|
92
142
|
raise TypeError(f"expected pandas.DataFrame, but got {type(data)!r}")
|
|
@@ -180,15 +230,19 @@ def check_y_survival(y_or_event, *args, allow_all_censored=False, allow_time_zer
|
|
|
180
230
|
|
|
181
231
|
|
|
182
232
|
def check_event_dtype(event, competing_risks=False):
|
|
183
|
-
"""Check that the event array has the correct
|
|
184
|
-
|
|
233
|
+
"""Check that the event array has the correct dtype.
|
|
234
|
+
|
|
235
|
+
For single-event survival analysis, the event indicator must be a
|
|
236
|
+
boolean array. For competing risk analysis, it must be an integer
|
|
237
|
+
array.
|
|
185
238
|
|
|
186
239
|
Parameters
|
|
187
240
|
----------
|
|
188
|
-
event :
|
|
189
|
-
|
|
241
|
+
event : ndarray, shape=(n_samples,), dtype=bool | int
|
|
242
|
+
Array containing the event indicator.
|
|
243
|
+
|
|
190
244
|
competing_risks : bool, optional, default: False
|
|
191
|
-
|
|
245
|
+
Whether `event` is for a competing risks analysis.
|
|
192
246
|
"""
|
|
193
247
|
if competing_risks:
|
|
194
248
|
if not np.issubdtype(event.dtype, np.integer):
|