scikit-survival 0.23.1__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.23.1.dist-info/COPYING +674 -0
- scikit_survival-0.23.1.dist-info/METADATA +888 -0
- scikit_survival-0.23.1.dist-info/RECORD +55 -0
- scikit_survival-0.23.1.dist-info/WHEEL +5 -0
- scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
- sksurv/__init__.py +138 -0
- sksurv/base.py +103 -0
- sksurv/bintrees/__init__.py +15 -0
- sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
- sksurv/column.py +201 -0
- sksurv/compare.py +123 -0
- sksurv/datasets/__init__.py +10 -0
- sksurv/datasets/base.py +436 -0
- sksurv/datasets/data/GBSG2.arff +700 -0
- sksurv/datasets/data/actg320.arff +1169 -0
- sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
- sksurv/datasets/data/flchain.arff +7887 -0
- sksurv/datasets/data/veteran.arff +148 -0
- sksurv/datasets/data/whas500.arff +520 -0
- sksurv/ensemble/__init__.py +2 -0
- sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
- sksurv/ensemble/boosting.py +1610 -0
- sksurv/ensemble/forest.py +947 -0
- sksurv/ensemble/survival_loss.py +151 -0
- sksurv/exceptions.py +18 -0
- sksurv/functions.py +114 -0
- sksurv/io/__init__.py +2 -0
- sksurv/io/arffread.py +58 -0
- sksurv/io/arffwrite.py +145 -0
- sksurv/kernels/__init__.py +1 -0
- sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
- sksurv/kernels/clinical.py +328 -0
- sksurv/linear_model/__init__.py +3 -0
- sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
- sksurv/linear_model/aft.py +205 -0
- sksurv/linear_model/coxnet.py +543 -0
- sksurv/linear_model/coxph.py +618 -0
- sksurv/meta/__init__.py +4 -0
- sksurv/meta/base.py +35 -0
- sksurv/meta/ensemble_selection.py +642 -0
- sksurv/meta/stacking.py +349 -0
- sksurv/metrics.py +996 -0
- sksurv/nonparametric.py +588 -0
- sksurv/preprocessing.py +155 -0
- sksurv/svm/__init__.py +11 -0
- sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
- sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
- sksurv/svm/minlip.py +606 -0
- sksurv/svm/naive_survival_svm.py +221 -0
- sksurv/svm/survival_svm.py +1228 -0
- sksurv/testing.py +108 -0
- sksurv/tree/__init__.py +1 -0
- sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
- sksurv/tree/tree.py +703 -0
- sksurv/util.py +333 -0
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
sksurv/__init__.py,sha256=ylL6ChkAkHWnuIFWG5lky_d_ZNFfukUUt5tU0HVJfgo,3710
|
|
2
|
+
sksurv/base.py,sha256=0HXvIivawJMsby7UxnLUFOKbKvurrJEM0KrMIYmqyBU,3801
|
|
3
|
+
sksurv/column.py,sha256=hojN9vVDw2ka5LRueyc3f1g12AjwuzSIiPqGP5js2Nk,6670
|
|
4
|
+
sksurv/compare.py,sha256=WbJb2NFXMncgkcrNKq3Xgo5YqW_klgkDkKJc219aJiw,4212
|
|
5
|
+
sksurv/exceptions.py,sha256=WopRsYNia5MQTvvVXL3U_nf418t29pplfBz6HFtmt1M,819
|
|
6
|
+
sksurv/functions.py,sha256=x63NGGVQVqtFiEG-wXW569q1U_kKJFwbOb9-8BcfpaM,3780
|
|
7
|
+
sksurv/metrics.py,sha256=rAA_UmN-ddNm52OE3pcixZcfg_XKFig7CeQ2eQw7ArI,38379
|
|
8
|
+
sksurv/nonparametric.py,sha256=aZo5tXkv8a6Fup7NKZzSqW9GEtpCfwk82B2RCWvD0QI,19339
|
|
9
|
+
sksurv/preprocessing.py,sha256=WGrL3ngfskFuMqdMxr7Dkw5Bj8NWij_xZuLRqDdz5Eo,5315
|
|
10
|
+
sksurv/testing.py,sha256=9qoUAeRfp3tb4Zg7oCHkBS0DJXkIuxdflRLoV7NEfek,4403
|
|
11
|
+
sksurv/util.py,sha256=S8GnCvsaOUqjbzFXQylXCUnQSfoy057H8FQ4XidYqgI,12276
|
|
12
|
+
sksurv/bintrees/__init__.py,sha256=z0GwaTPCzww2H2aXF28ubppw0Oc4umNTAlFAKu1VBJc,742
|
|
13
|
+
sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd,sha256=MxoXOQkQdRrPLmM5jRTZk-OWEeQayd2NQT74Zcwd82A,66560
|
|
14
|
+
sksurv/datasets/__init__.py,sha256=ZIU1vPRhO2CYpTf-Lbr64S73JkR1CAP7aVBO_wGHHAY,313
|
|
15
|
+
sksurv/datasets/base.py,sha256=l4xiUUNkYbY-JhpU2En6Kr7AZ29cb-KVUxS6SL-RxAk,15126
|
|
16
|
+
sksurv/datasets/data/GBSG2.arff,sha256=oX_UM7Qy841xBOArXBkUPLzIxNTvdtIJqpxXsqGGw9Q,26904
|
|
17
|
+
sksurv/datasets/data/actg320.arff,sha256=BwIq5q_i_75G2rPFQ6TjO0bsiR8MwA6wPouG-SX7TUo,46615
|
|
18
|
+
sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff,sha256=1dNSJczfgZOjp4Ya0SKREreyGorzWd-rm1ryihwIAck,265026
|
|
19
|
+
sksurv/datasets/data/flchain.arff,sha256=4LVUyEe-45ozaWPy0VkN-1js_MNsKw1gs2E-JRyjU4o,350945
|
|
20
|
+
sksurv/datasets/data/veteran.arff,sha256=LxZtbmq4I82rcB24JeJTYRtlgwPc3vM2OX5hg-q7xTw,5408
|
|
21
|
+
sksurv/datasets/data/whas500.arff,sha256=dvqRzx-nwgSVJZxNVE2zelnt7l3xgzFtMucB7Wux574,28292
|
|
22
|
+
sksurv/ensemble/__init__.py,sha256=aBjRTFm8UE5sTew292-qcplLUCc6owAfY6osWlj-VSM,193
|
|
23
|
+
sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd,sha256=ZEE6Cfc4Ubd-psYEVQflNGIaDngL6W2QJCnW7qKfiUE,153600
|
|
24
|
+
sksurv/ensemble/boosting.py,sha256=T2DWIPGMK9S_BX9dIuhnptH3v2P5jfCv_d-wXR0l3zE,63690
|
|
25
|
+
sksurv/ensemble/forest.py,sha256=su6MFuBhD2mMM6SAAq8_ixEv4i7kYm40gcBdNeXmw14,36330
|
|
26
|
+
sksurv/ensemble/survival_loss.py,sha256=v3tSou5t1YY6lBydAZYZ66DLqAirvRhErqW1dZYrTWE,6093
|
|
27
|
+
sksurv/io/__init__.py,sha256=dalzZGTrvekCM8wwsB636rg1dwDkQtDWaBOw7TpHr5U,94
|
|
28
|
+
sksurv/io/arffread.py,sha256=47rFxaZuavV0cbFUrZ_NjSmSByWXqkwZ4MkpFK4jw_w,1897
|
|
29
|
+
sksurv/io/arffwrite.py,sha256=Dkvnr_zz4i23xGPDTwxTAOOggbL1GxtgSi1em6kVp0Q,4609
|
|
30
|
+
sksurv/kernels/__init__.py,sha256=R1if2sVd_0_f6LniIGUR0tipIfzRKpzgGYnvrVZZvHM,78
|
|
31
|
+
sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd,sha256=cVv6Nb9vPfQqiSfLAFY7XhHjgpmO0beiN77ZmKKUQY4,161280
|
|
32
|
+
sksurv/kernels/clinical.py,sha256=kVAMx95rgrerfvGlmuNgW4hc3HojTruKoNnhuaeLkGw,11046
|
|
33
|
+
sksurv/linear_model/__init__.py,sha256=dO6Mr3wXk6Q-KQEuhpdgMeY3ji8ZVdgC-SeSRnrJdmw,155
|
|
34
|
+
sksurv/linear_model/_coxnet.cp313-win_amd64.pyd,sha256=PTpLr_NU8shli8Kl9--DaYTr-TePEZWhppi3opFNjXQ,92160
|
|
35
|
+
sksurv/linear_model/aft.py,sha256=oWtLF-WCCrdmjJtb9e91SHA-be_IQPTf7Kf9VWoXQB8,7616
|
|
36
|
+
sksurv/linear_model/coxnet.py,sha256=ZJAsJImSoGCL6HJ9uXQdph8sDLAZakcIvMuzRy28F4M,20727
|
|
37
|
+
sksurv/linear_model/coxph.py,sha256=PeYw7I_EHSgV3kriUnEhJg6CIdISkFk9ewlzijHC8Nw,21475
|
|
38
|
+
sksurv/meta/__init__.py,sha256=vw8vn6gR50a8O_7VhUw4aRgbSP2k3erVBgYS7ifwOcI,220
|
|
39
|
+
sksurv/meta/base.py,sha256=AdhIkZi9PvucZ3B2lhhFQpQbwp8EUCDUVOiaev_UsX8,1472
|
|
40
|
+
sksurv/meta/ensemble_selection.py,sha256=NTXzq6xvwL5IwdxyjNVkPOgzU_9lzdsyFKCq0YbdOXo,24673
|
|
41
|
+
sksurv/meta/stacking.py,sha256=md3ZKYYlS23YKbbLjWZqtPRSJYAHfLzN9RlC5RgzYEQ,12629
|
|
42
|
+
sksurv/svm/__init__.py,sha256=CSceYEcBPGKRcJZ4R0u7DzwictGln_weLIsbt2i5xeU,339
|
|
43
|
+
sksurv/svm/_minlip.cp313-win_amd64.pyd,sha256=G4o8dbyW7Az9bPcgGmMh2yQaOOwIl9SaCz9pg501lgc,157696
|
|
44
|
+
sksurv/svm/_prsvm.cp313-win_amd64.pyd,sha256=yJqDmx7ybvPy5Vh0MLLgDuloZJInYtqWY4OI-ESr39w,155136
|
|
45
|
+
sksurv/svm/minlip.py,sha256=Elej96P4s6jMNCbGIbNm4ehWI3ew9o2IyF5tfZuTEqg,22416
|
|
46
|
+
sksurv/svm/naive_survival_svm.py,sha256=D6l8ctlpyUPBhcMlF_6KJq9d_tq1Xmtb9w--n8pGcfM,8209
|
|
47
|
+
sksurv/svm/survival_svm.py,sha256=oGac8mWlDOlZPRoK1MmbcigKUq2bWPw1Uc8vYSsSlwI,44730
|
|
48
|
+
sksurv/tree/__init__.py,sha256=ozb0fhURX-lpiSiHZd0DMnjkqhC0XOC2CTZq0hEZLPw,65
|
|
49
|
+
sksurv/tree/_criterion.cp313-win_amd64.pyd,sha256=l0JJbug1cS8lM46DwH_rGrEz5G5GGnfqVDYP4c3vreQ,176128
|
|
50
|
+
sksurv/tree/tree.py,sha256=PPkqq8jDeAHCd9L283-SLCk7lTBfH-HObTOuOXOvqHs,27248
|
|
51
|
+
scikit_survival-0.23.1.dist-info/COPYING,sha256=Czg9WmPaZE9ijZnDOXbqZIftiaqlnwsyV5kt6sEXHms,35821
|
|
52
|
+
scikit_survival-0.23.1.dist-info/METADATA,sha256=alYA45glhQaH9X_MzelejmubVqXGWVaPiiMPqRa6Yns,49841
|
|
53
|
+
scikit_survival-0.23.1.dist-info/WHEEL,sha256=UJbDlEYuWWwgv9Hu0As4Rgv2Qpdka2YFe6UlEKs4AoE,101
|
|
54
|
+
scikit_survival-0.23.1.dist-info/top_level.txt,sha256=fPkcFA-XQGbwnD_ZXOvaOWmSd34Qezr26Mn99nYPvAg,7
|
|
55
|
+
scikit_survival-0.23.1.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
sksurv
|
sksurv/__init__.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
2
|
+
import platform
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from sklearn.pipeline import Pipeline, _final_estimator_has
|
|
6
|
+
from sklearn.utils.metaestimators import available_if
|
|
7
|
+
|
|
8
|
+
from .util import property_available_if
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _get_version(name):
|
|
12
|
+
try:
|
|
13
|
+
pkg_version = version(name)
|
|
14
|
+
except ImportError:
|
|
15
|
+
pkg_version = None
|
|
16
|
+
return pkg_version
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def show_versions():
    """Print system information and versions of scikit-survival's dependencies.

    Output goes to stdout: a SYSTEM section (platform, interpreter) followed
    by a DEPENDENCIES section listing the installed version of each key
    package (``None`` if a package is not installed).
    """
    sys_info = {
        "Platform": platform.platform(),
        # BUG FIX: the original f-string interpolated the `platform` module
        # object itself; use platform.python_version() for the version string.
        "Python version": f"{platform.python_implementation()} {platform.python_version()}",
        "Python interpreter": sys.executable,
    }

    deps = [
        "scikit-survival",
        "scikit-learn",
        "numpy",
        "scipy",
        "pandas",
        "numexpr",
        "ecos",
        "osqp",
        "joblib",
        "matplotlib",
        "pytest",
        "sphinx",
        "Cython",
        "pip",
        "setuptools",
    ]
    # Pad names so the ":" separators line up across both sections.
    minwidth = max(
        max(map(len, deps)),
        max(map(len, sys_info.keys())),
    )
    fmt = "{0:<%ds}: {1}" % minwidth

    print("SYSTEM")
    print("------")
    for name, version_string in sys_info.items():
        print(fmt.format(name, version_string))

    print("\nDEPENDENCIES")
    print("------------")
    for dep in deps:
        version_string = _get_version(dep)
        print(fmt.format(dep, version_string))
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@available_if(_final_estimator_has("predict_cumulative_hazard_function"))
def predict_cumulative_hazard_function(self, X, **kwargs):
    """Predict cumulative hazard function.

    The cumulative hazard function for an individual
    with feature vector :math:`x` is defined as

    .. math::

        H(t \\mid x) = \\exp(x^\\top \\beta) H_0(t) ,

    where :math:`H_0(t)` is the baseline hazard function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    cum_hazard : ndarray, shape = (n_samples,)
        Predicted cumulative hazard functions.
    """
    # Run X through every intermediate transformer of the pipeline, then
    # delegate to the final estimator's implementation.
    transformed = X
    for _, _, step in self._iter(with_final=False):
        transformed = step.transform(transformed)
    final_estimator = self.steps[-1][-1]
    return final_estimator.predict_cumulative_hazard_function(transformed, **kwargs)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@available_if(_final_estimator_has("predict_survival_function"))
def predict_survival_function(self, X, **kwargs):
    """Predict survival function.

    The survival function for an individual
    with feature vector :math:`x` is defined as

    .. math::

        S(t \\mid x) = S_0(t)^{\\exp(x^\\top \\beta)} ,

    where :math:`S_0(t)` is the baseline survival function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    survival : ndarray, shape = (n_samples,)
        Predicted survival functions.
    """
    # Apply all transformation steps of the pipeline before delegating to
    # the final estimator's implementation.
    transformed = X
    for _, _, step in self._iter(with_final=False):
        transformed = step.transform(transformed)
    final_estimator = self.steps[-1][-1]
    return final_estimator.predict_survival_function(transformed, **kwargs)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@property_available_if(_final_estimator_has("_predict_risk_score"))
def _predict_risk_score(self):
    # Delegate to the pipeline's final estimator; the property is only
    # available when that estimator itself defines `_predict_risk_score`.
    return self.steps[-1][-1]._predict_risk_score
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def patch_pipeline():
    """Attach survival-analysis prediction helpers to sklearn's Pipeline.

    Monkey-patches :class:`sklearn.pipeline.Pipeline` so a pipeline whose
    final estimator supports survival predictions exposes the same API.
    """
    extras = {
        "predict_survival_function": predict_survival_function,
        "predict_cumulative_hazard_function": predict_cumulative_hazard_function,
        "_predict_risk_score": _predict_risk_score,
    }
    for attribute, value in extras.items():
        setattr(Pipeline, attribute, value)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# Resolve the installed distribution's version string at import time.
try:
    __version__ = version("scikit-survival")
except PackageNotFoundError:  # pragma: no cover
    # package is not installed
    __version__ = "unknown"

# Attach survival-specific prediction methods to sklearn's Pipeline.
patch_pipeline()
|
sksurv/base.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SurvivalAnalysisMixin:
    """Mixin adding baseline-model prediction helpers and concordance scoring."""

    def _predict_function(self, func_name, baseline_model, prediction, return_array):
        # Ask the baseline model for one step function per predicted risk score.
        step_functions = getattr(baseline_model, func_name)(prediction)

        if not return_array:
            return step_functions

        # Evaluate each step function at the unique event times and gather
        # the results in a dense (n_samples, n_times) float matrix.
        times = baseline_model.unique_times_
        values = np.empty((prediction.shape[0], times.shape[0]), dtype=float)
        for row, fn in zip(values, step_functions):
            row[:] = fn(times)
        return values

    def _predict_survival_function(self, baseline_model, prediction, return_array):
        """Return survival functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of baseline survival function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If True, return a float array of the survival function
            evaluated at the unique event times, otherwise return
            an array of :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        survival : ndarray
        """
        return self._predict_function("get_survival_function", baseline_model, prediction, return_array)

    def _predict_cumulative_hazard_function(self, baseline_model, prediction, return_array):
        """Return cumulative hazard functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of baseline cumulative hazard function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If True, return a float array of the cumulative hazard function
            evaluated at the unique event times, otherwise return
            an array of :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        cum_hazard : ndarray
        """
        return self._predict_function("get_cumulative_hazard_function", baseline_model, prediction, return_array)

    def score(self, X, y):
        """Return the concordance index of the prediction.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Test samples.

        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        cindex : float
            Estimated concordance index.
        """
        from .metrics import concordance_index_censored

        event_field, time_field = y.dtype.names

        estimate = self.predict(X)
        # Estimators that predict on the time scale (larger value = longer
        # survival) mark it via _predict_risk_score = False; flip the sign
        # to obtain a risk score for the concordance computation.
        if not getattr(self, "_predict_risk_score", True):
            estimate *= -1  # convert prediction on time scale to risk scale

        cindex = concordance_index_censored(y[event_field], y[time_field], estimate)
        return cindex[0]

    def _more_tags(self):
        # Signal to scikit-learn's estimator checks that fit() requires y.
        return {"requires_y": True}
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
from ._binarytrees import AATree, AVLTree, RBTree
|
|
14
|
+
|
|
15
|
+
__all__ = ["RBTree", "AVLTree", "AATree"]
|
|
Binary file
|
sksurv/column.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import logging
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pandas as pd
|
|
17
|
+
from pandas.api.types import CategoricalDtype, is_object_dtype
|
|
18
|
+
|
|
19
|
+
__all__ = ["categorical_to_numeric", "encode_categorical", "standardize"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _apply_along_column(array, func1d, **kwargs):
|
|
23
|
+
if isinstance(array, pd.DataFrame):
|
|
24
|
+
return array.apply(func1d, **kwargs)
|
|
25
|
+
return np.apply_along_axis(func1d, 0, array, **kwargs)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def standardize_column(series_or_array, with_std=True):
    """Center (and optionally scale to unit variance) one numeric column.

    Non-numeric columns are returned unchanged; numeric input is converted
    to float before normalization.
    """
    dtype = series_or_array.dtype
    if not issubclass(dtype.type, np.number):
        # Categorical / object columns pass through untouched.
        return series_or_array

    centered = series_or_array.astype(float) - series_or_array.mean()
    if with_std:
        # Sample standard deviation (ddof=1), matching pandas' default.
        centered = centered / series_or_array.std(ddof=1)
    return centered
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def standardize(table, with_std=True):
    """
    Perform Z-Normalization on each numeric column of the given table.

    If `table` is a pandas.DataFrame, only numeric columns are modified,
    all other columns remain unchanged. If `table` is a numpy.ndarray,
    it is only modified if it has numeric dtype, in which case the returned
    array will have floating point dtype.

    Parameters
    ----------
    table : pandas.DataFrame or numpy.ndarray
        Data to standardize.

    with_std : bool, optional, default: True
        If ``False`` data is only centered and not converted to unit variance.

    Returns
    -------
    normalized : pandas.DataFrame
        Table with numeric columns normalized.
        Categorical columns in the input table remain unchanged.
    """
    # Per-column work is delegated to standardize_column, which leaves
    # non-numeric columns unchanged.
    new_frame = _apply_along_column(table, standardize_column, with_std=with_std)

    return new_frame
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _encode_categorical_series(series, allow_drop=True):
    # One-hot encode a single categorical series, dropping the first
    # category as the reference level. Returns None when the series was
    # dropped entirely (fewer than two categories and allow_drop=True).
    values = _get_dummies_1d(series, allow_drop=allow_drop)
    if values is None:
        return

    enc, levels = values
    if enc is None:
        # No categories at all (e.g. all values missing): return an empty
        # column preserving the index, name, and dtype.
        return pd.Series(index=series.index, name=series.name, dtype=series.dtype)

    if not allow_drop and enc.shape[1] == 1:
        # Single-category column that must be kept: return it unmodified.
        return series

    # Drop the first level's column and name the rest "<series name>=<level>".
    names = []
    for key in range(1, enc.shape[1]):
        names.append(f"{series.name}={levels[key]}")
    series = pd.DataFrame(enc[:, 1:], columns=names, index=series.index)

    return series
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def encode_categorical(table, columns=None, **kwargs):
    """
    Encode categorical columns with `M` categories into `M-1` columns according
    to the one-hot scheme.

    Parameters
    ----------
    table : pandas.DataFrame
        Table with categorical columns to encode.

    columns : list-like, optional, default: None
        Column names in the DataFrame to be encoded.
        If `columns` is None then all the columns with
        `object` or `category` dtype will be converted.

    allow_drop : boolean, optional, default: True
        Whether to allow dropping categorical columns that only consist
        of a single category.

    Returns
    -------
    encoded : pandas.DataFrame
        Table with categorical columns encoded as numeric.
        Numeric columns in the input table remain unchanged.
    """
    if isinstance(table, pd.Series):
        if not isinstance(table.dtype, CategoricalDtype) and not is_object_dtype(table.dtype):
            raise TypeError(f"series must be of categorical dtype, but was {table.dtype}")
        return _encode_categorical_series(table, **kwargs)

    def _is_categorical_or_object(series):
        return isinstance(series.dtype, CategoricalDtype) or is_object_dtype(series.dtype)

    if columns is None:
        # for columns containing categories
        columns_to_encode = {nam for nam, s in table.items() if _is_categorical_or_object(s)}
    else:
        columns_to_encode = set(columns)

    items = []
    for name, series in table.items():
        if name in columns_to_encode:
            # A None result means the column was dropped (single category).
            series = _encode_categorical_series(series, **kwargs)
            if series is None:
                continue
        items.append(series)

    # concat columns of tables
    new_table = pd.concat(items, axis=1, copy=False)
    return new_table
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _get_dummies_1d(data, allow_drop=True):
|
|
145
|
+
# Series avoids inconsistent NaN handling
|
|
146
|
+
cat = pd.Categorical(data)
|
|
147
|
+
levels = cat.categories
|
|
148
|
+
number_of_cols = len(levels)
|
|
149
|
+
|
|
150
|
+
# if all NaN or only one level
|
|
151
|
+
if allow_drop and number_of_cols < 2:
|
|
152
|
+
logging.getLogger(__package__).warning(
|
|
153
|
+
f"dropped categorical variable {data.name!r}, because it has only {number_of_cols} values"
|
|
154
|
+
)
|
|
155
|
+
return
|
|
156
|
+
if number_of_cols == 0:
|
|
157
|
+
return None, levels
|
|
158
|
+
|
|
159
|
+
dummy_mat = np.eye(number_of_cols).take(cat.codes, axis=0)
|
|
160
|
+
|
|
161
|
+
# reset NaN GH4446
|
|
162
|
+
dummy_mat[cat.codes == -1] = np.nan
|
|
163
|
+
|
|
164
|
+
return dummy_mat, levels
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def categorical_to_numeric(table):
    """Encode categorical columns to numeric by converting each category to
    an integer value.

    Parameters
    ----------
    table : pandas.DataFrame
        Table with categorical columns to encode.

    Returns
    -------
    encoded : pandas.DataFrame
        Table with categorical columns encoded as numeric.
        Numeric columns in the input table remain unchanged.
    """

    def _to_numeric(col):
        if isinstance(col.dtype, CategoricalDtype):
            # Categorical dtype: use the integer codes (-1 marks missing).
            return col.cat.codes
        if is_object_dtype(col.dtype):
            try:
                # Object columns holding numeric strings convert directly.
                return col.astype(np.int64)
            except ValueError:
                # Otherwise map sorted unique values to 0..n-1.
                classes = col.dropna().unique()
                classes.sort(kind="mergesort")
                lookup = {value: code for code, value in enumerate(classes)}
                return col.map(lookup)
        if col.dtype == bool:
            return col.astype(np.int64)

        return col

    if isinstance(table, pd.Series):
        return pd.Series(_to_numeric(table), name=table.name, index=table.index)
    return table.apply(_to_numeric, axis=0, result_type="expand")
|
sksurv/compare.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from collections import OrderedDict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from scipy import stats
|
|
6
|
+
from sklearn.utils.validation import check_array
|
|
7
|
+
|
|
8
|
+
from .util import check_array_survival
|
|
9
|
+
|
|
10
|
+
__all__ = ["compare_survival"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def compare_survival(y, group_indicator, return_stats=False):
    """K-sample log-rank hypothesis test of identical survival functions.

    Compares the pooled hazard rate with each group-specific
    hazard rate. The alternative hypothesis is that the hazard
    rate of at least one group differs from the others at some time.

    See [1]_ for more details.

    Parameters
    ----------
    y : structured array, shape = (n_samples,)
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    group_indicator : array-like, shape = (n_samples,)
        Group membership of each sample.

    return_stats : bool, optional, default: False
        Whether to return a data frame with statistics for each group
        and the covariance matrix of the test statistic.

    Returns
    -------
    chisq : float
        Test statistic.
    pvalue : float
        Two-sided p-value with respect to the null hypothesis
        that the hazard rates across all groups are equal.
    stats : pandas.DataFrame
        Summary statistics for each group: number of samples,
        observed number of events, expected number of events,
        and test statistic.
        Only provided if `return_stats` is True.
    covariance : array, shape=(n_groups, n_groups)
        Covariance matrix of the test statistic.
        Only provided if `return_stats` is True.

    References
    ----------
    .. [1] Fleming, T. R. and Harrington, D. P.
           A Class of Hypothesis Tests for One and Two Samples of Censored Survival Data.
           Communications In Statistics 10 (1981): 763-794.
    """

    # NOTE(review): group_indicator is passed as the first (X-like) argument
    # of check_array_survival — verify against its signature.
    event, time = check_array_survival(group_indicator, y)
    group_indicator = check_array(
        group_indicator,
        dtype="O",
        ensure_2d=False,
        estimator="compare_survival",
        input_name="group_indicator",
    )

    n_samples = time.shape[0]
    groups, group_counts = np.unique(group_indicator, return_counts=True)
    n_groups = groups.shape[0]
    if n_groups == 1:
        raise ValueError("At least two groups must be specified, but only one was provided.")

    # sort descending
    o = np.argsort(-time, kind="mergesort")
    x = group_indicator[o]
    event = event[o]
    time = time[o]

    # Per-group accumulators: number at risk, observed events, and the
    # expected events under the pooled-hazard null hypothesis.
    at_risk = np.zeros(n_groups, dtype=int)
    observed = np.zeros(n_groups, dtype=int)
    expected = np.zeros(n_groups, dtype=float)
    covar = np.zeros((n_groups, n_groups), dtype=float)

    covar_indices = np.diag_indices(n_groups)

    k = 0
    while k < n_samples:
        ti = time[k]
        total_events = 0
        # Consume all samples tied at time ti, tallying events per group.
        while k < n_samples and ti == time[k]:
            idx = np.searchsorted(groups, x[k])
            if event[k]:
                observed[idx] += 1
                total_events += 1
            at_risk[idx] += 1
            k += 1

        if total_events != 0:
            # Times are processed in descending order, so after the inner
            # loop k equals the total number of samples at risk at time ti.
            total_at_risk = k
            expected += at_risk * (total_events / total_at_risk)
            if total_at_risk > 1:
                # Hypergeometric variance contribution of this event time.
                multiplier = total_events * (total_at_risk - total_events) / (total_at_risk * (total_at_risk - 1))
                temp = at_risk * multiplier
                covar[covar_indices] += temp
                covar -= np.outer(temp, at_risk) / total_at_risk

    # The statistic uses only the first n_groups - 1 components; the last
    # is determined by the others (the rows of covar sum to zero).
    df = n_groups - 1
    zz = observed[:df] - expected[:df]
    chisq = np.linalg.solve(covar[:df, :df], zz).dot(zz)
    pval = stats.chi2.sf(chisq, df)

    if return_stats:
        table = OrderedDict()
        table["counts"] = group_counts
        table["observed"] = observed
        table["expected"] = expected
        table["statistic"] = observed - expected
        table = pd.DataFrame.from_dict(table)
        table.index = pd.Index(groups, name="group", dtype=groups.dtype)
        return chisq, pval, table, covar

    return chisq, pval
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from .base import (
|
|
2
|
+
get_x_y, # noqa: F401
|
|
3
|
+
load_aids, # noqa: F401
|
|
4
|
+
load_arff_files_standardized, # noqa: F401
|
|
5
|
+
load_breast_cancer, # noqa: F401
|
|
6
|
+
load_flchain, # noqa: F401
|
|
7
|
+
load_gbsg2, # noqa: F401
|
|
8
|
+
load_veterans_lung_cancer, # noqa: F401
|
|
9
|
+
load_whas500, # noqa: F401
|
|
10
|
+
)
|