scikit-survival 0.23.1__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. scikit_survival-0.23.1.dist-info/COPYING +674 -0
  2. scikit_survival-0.23.1.dist-info/METADATA +888 -0
  3. scikit_survival-0.23.1.dist-info/RECORD +55 -0
  4. scikit_survival-0.23.1.dist-info/WHEEL +5 -0
  5. scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +138 -0
  7. sksurv/base.py +103 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
  10. sksurv/column.py +201 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +10 -0
  13. sksurv/datasets/base.py +436 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  17. sksurv/datasets/data/flchain.arff +7887 -0
  18. sksurv/datasets/data/veteran.arff +148 -0
  19. sksurv/datasets/data/whas500.arff +520 -0
  20. sksurv/ensemble/__init__.py +2 -0
  21. sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
  22. sksurv/ensemble/boosting.py +1610 -0
  23. sksurv/ensemble/forest.py +947 -0
  24. sksurv/ensemble/survival_loss.py +151 -0
  25. sksurv/exceptions.py +18 -0
  26. sksurv/functions.py +114 -0
  27. sksurv/io/__init__.py +2 -0
  28. sksurv/io/arffread.py +58 -0
  29. sksurv/io/arffwrite.py +145 -0
  30. sksurv/kernels/__init__.py +1 -0
  31. sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
  32. sksurv/kernels/clinical.py +328 -0
  33. sksurv/linear_model/__init__.py +3 -0
  34. sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
  35. sksurv/linear_model/aft.py +205 -0
  36. sksurv/linear_model/coxnet.py +543 -0
  37. sksurv/linear_model/coxph.py +618 -0
  38. sksurv/meta/__init__.py +4 -0
  39. sksurv/meta/base.py +35 -0
  40. sksurv/meta/ensemble_selection.py +642 -0
  41. sksurv/meta/stacking.py +349 -0
  42. sksurv/metrics.py +996 -0
  43. sksurv/nonparametric.py +588 -0
  44. sksurv/preprocessing.py +155 -0
  45. sksurv/svm/__init__.py +11 -0
  46. sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
  47. sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
  48. sksurv/svm/minlip.py +606 -0
  49. sksurv/svm/naive_survival_svm.py +221 -0
  50. sksurv/svm/survival_svm.py +1228 -0
  51. sksurv/testing.py +108 -0
  52. sksurv/tree/__init__.py +1 -0
  53. sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
  54. sksurv/tree/tree.py +703 -0
  55. sksurv/util.py +333 -0
@@ -0,0 +1,55 @@
1
+ sksurv/__init__.py,sha256=ylL6ChkAkHWnuIFWG5lky_d_ZNFfukUUt5tU0HVJfgo,3710
2
+ sksurv/base.py,sha256=0HXvIivawJMsby7UxnLUFOKbKvurrJEM0KrMIYmqyBU,3801
3
+ sksurv/column.py,sha256=hojN9vVDw2ka5LRueyc3f1g12AjwuzSIiPqGP5js2Nk,6670
4
+ sksurv/compare.py,sha256=WbJb2NFXMncgkcrNKq3Xgo5YqW_klgkDkKJc219aJiw,4212
5
+ sksurv/exceptions.py,sha256=WopRsYNia5MQTvvVXL3U_nf418t29pplfBz6HFtmt1M,819
6
+ sksurv/functions.py,sha256=x63NGGVQVqtFiEG-wXW569q1U_kKJFwbOb9-8BcfpaM,3780
7
+ sksurv/metrics.py,sha256=rAA_UmN-ddNm52OE3pcixZcfg_XKFig7CeQ2eQw7ArI,38379
8
+ sksurv/nonparametric.py,sha256=aZo5tXkv8a6Fup7NKZzSqW9GEtpCfwk82B2RCWvD0QI,19339
9
+ sksurv/preprocessing.py,sha256=WGrL3ngfskFuMqdMxr7Dkw5Bj8NWij_xZuLRqDdz5Eo,5315
10
+ sksurv/testing.py,sha256=9qoUAeRfp3tb4Zg7oCHkBS0DJXkIuxdflRLoV7NEfek,4403
11
+ sksurv/util.py,sha256=S8GnCvsaOUqjbzFXQylXCUnQSfoy057H8FQ4XidYqgI,12276
12
+ sksurv/bintrees/__init__.py,sha256=z0GwaTPCzww2H2aXF28ubppw0Oc4umNTAlFAKu1VBJc,742
13
+ sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd,sha256=MxoXOQkQdRrPLmM5jRTZk-OWEeQayd2NQT74Zcwd82A,66560
14
+ sksurv/datasets/__init__.py,sha256=ZIU1vPRhO2CYpTf-Lbr64S73JkR1CAP7aVBO_wGHHAY,313
15
+ sksurv/datasets/base.py,sha256=l4xiUUNkYbY-JhpU2En6Kr7AZ29cb-KVUxS6SL-RxAk,15126
16
+ sksurv/datasets/data/GBSG2.arff,sha256=oX_UM7Qy841xBOArXBkUPLzIxNTvdtIJqpxXsqGGw9Q,26904
17
+ sksurv/datasets/data/actg320.arff,sha256=BwIq5q_i_75G2rPFQ6TjO0bsiR8MwA6wPouG-SX7TUo,46615
18
+ sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff,sha256=1dNSJczfgZOjp4Ya0SKREreyGorzWd-rm1ryihwIAck,265026
19
+ sksurv/datasets/data/flchain.arff,sha256=4LVUyEe-45ozaWPy0VkN-1js_MNsKw1gs2E-JRyjU4o,350945
20
+ sksurv/datasets/data/veteran.arff,sha256=LxZtbmq4I82rcB24JeJTYRtlgwPc3vM2OX5hg-q7xTw,5408
21
+ sksurv/datasets/data/whas500.arff,sha256=dvqRzx-nwgSVJZxNVE2zelnt7l3xgzFtMucB7Wux574,28292
22
+ sksurv/ensemble/__init__.py,sha256=aBjRTFm8UE5sTew292-qcplLUCc6owAfY6osWlj-VSM,193
23
+ sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd,sha256=ZEE6Cfc4Ubd-psYEVQflNGIaDngL6W2QJCnW7qKfiUE,153600
24
+ sksurv/ensemble/boosting.py,sha256=T2DWIPGMK9S_BX9dIuhnptH3v2P5jfCv_d-wXR0l3zE,63690
25
+ sksurv/ensemble/forest.py,sha256=su6MFuBhD2mMM6SAAq8_ixEv4i7kYm40gcBdNeXmw14,36330
26
+ sksurv/ensemble/survival_loss.py,sha256=v3tSou5t1YY6lBydAZYZ66DLqAirvRhErqW1dZYrTWE,6093
27
+ sksurv/io/__init__.py,sha256=dalzZGTrvekCM8wwsB636rg1dwDkQtDWaBOw7TpHr5U,94
28
+ sksurv/io/arffread.py,sha256=47rFxaZuavV0cbFUrZ_NjSmSByWXqkwZ4MkpFK4jw_w,1897
29
+ sksurv/io/arffwrite.py,sha256=Dkvnr_zz4i23xGPDTwxTAOOggbL1GxtgSi1em6kVp0Q,4609
30
+ sksurv/kernels/__init__.py,sha256=R1if2sVd_0_f6LniIGUR0tipIfzRKpzgGYnvrVZZvHM,78
31
+ sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd,sha256=cVv6Nb9vPfQqiSfLAFY7XhHjgpmO0beiN77ZmKKUQY4,161280
32
+ sksurv/kernels/clinical.py,sha256=kVAMx95rgrerfvGlmuNgW4hc3HojTruKoNnhuaeLkGw,11046
33
+ sksurv/linear_model/__init__.py,sha256=dO6Mr3wXk6Q-KQEuhpdgMeY3ji8ZVdgC-SeSRnrJdmw,155
34
+ sksurv/linear_model/_coxnet.cp313-win_amd64.pyd,sha256=PTpLr_NU8shli8Kl9--DaYTr-TePEZWhppi3opFNjXQ,92160
35
+ sksurv/linear_model/aft.py,sha256=oWtLF-WCCrdmjJtb9e91SHA-be_IQPTf7Kf9VWoXQB8,7616
36
+ sksurv/linear_model/coxnet.py,sha256=ZJAsJImSoGCL6HJ9uXQdph8sDLAZakcIvMuzRy28F4M,20727
37
+ sksurv/linear_model/coxph.py,sha256=PeYw7I_EHSgV3kriUnEhJg6CIdISkFk9ewlzijHC8Nw,21475
38
+ sksurv/meta/__init__.py,sha256=vw8vn6gR50a8O_7VhUw4aRgbSP2k3erVBgYS7ifwOcI,220
39
+ sksurv/meta/base.py,sha256=AdhIkZi9PvucZ3B2lhhFQpQbwp8EUCDUVOiaev_UsX8,1472
40
+ sksurv/meta/ensemble_selection.py,sha256=NTXzq6xvwL5IwdxyjNVkPOgzU_9lzdsyFKCq0YbdOXo,24673
41
+ sksurv/meta/stacking.py,sha256=md3ZKYYlS23YKbbLjWZqtPRSJYAHfLzN9RlC5RgzYEQ,12629
42
+ sksurv/svm/__init__.py,sha256=CSceYEcBPGKRcJZ4R0u7DzwictGln_weLIsbt2i5xeU,339
43
+ sksurv/svm/_minlip.cp313-win_amd64.pyd,sha256=G4o8dbyW7Az9bPcgGmMh2yQaOOwIl9SaCz9pg501lgc,157696
44
+ sksurv/svm/_prsvm.cp313-win_amd64.pyd,sha256=yJqDmx7ybvPy5Vh0MLLgDuloZJInYtqWY4OI-ESr39w,155136
45
+ sksurv/svm/minlip.py,sha256=Elej96P4s6jMNCbGIbNm4ehWI3ew9o2IyF5tfZuTEqg,22416
46
+ sksurv/svm/naive_survival_svm.py,sha256=D6l8ctlpyUPBhcMlF_6KJq9d_tq1Xmtb9w--n8pGcfM,8209
47
+ sksurv/svm/survival_svm.py,sha256=oGac8mWlDOlZPRoK1MmbcigKUq2bWPw1Uc8vYSsSlwI,44730
48
+ sksurv/tree/__init__.py,sha256=ozb0fhURX-lpiSiHZd0DMnjkqhC0XOC2CTZq0hEZLPw,65
49
+ sksurv/tree/_criterion.cp313-win_amd64.pyd,sha256=l0JJbug1cS8lM46DwH_rGrEz5G5GGnfqVDYP4c3vreQ,176128
50
+ sksurv/tree/tree.py,sha256=PPkqq8jDeAHCd9L283-SLCk7lTBfH-HObTOuOXOvqHs,27248
51
+ scikit_survival-0.23.1.dist-info/COPYING,sha256=Czg9WmPaZE9ijZnDOXbqZIftiaqlnwsyV5kt6sEXHms,35821
52
+ scikit_survival-0.23.1.dist-info/METADATA,sha256=alYA45glhQaH9X_MzelejmubVqXGWVaPiiMPqRa6Yns,49841
53
+ scikit_survival-0.23.1.dist-info/WHEEL,sha256=UJbDlEYuWWwgv9Hu0As4Rgv2Qpdka2YFe6UlEKs4AoE,101
54
+ scikit_survival-0.23.1.dist-info/top_level.txt,sha256=fPkcFA-XQGbwnD_ZXOvaOWmSd34Qezr26Mn99nYPvAg,7
55
+ scikit_survival-0.23.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (75.3.0)
3
+ Root-Is-Purelib: false
4
+ Tag: cp313-cp313-win_amd64
5
+
@@ -0,0 +1 @@
1
+ sksurv
sksurv/__init__.py ADDED
@@ -0,0 +1,138 @@
1
+ from importlib.metadata import PackageNotFoundError, version
2
+ import platform
3
+ import sys
4
+
5
+ from sklearn.pipeline import Pipeline, _final_estimator_has
6
+ from sklearn.utils.metaestimators import available_if
7
+
8
+ from .util import property_available_if
9
+
10
+
11
+ def _get_version(name):
12
+ try:
13
+ pkg_version = version(name)
14
+ except ImportError:
15
+ pkg_version = None
16
+ return pkg_version
17
+
18
+
19
def show_versions():
    """Print versions of Python, the platform, and relevant dependencies.

    Output is written to stdout in two aligned sections: ``SYSTEM``
    (platform, Python version, interpreter path) and ``DEPENDENCIES``
    (one line per package; ``None`` when a package is not installed).
    """
    sys_info = {
        "Platform": platform.platform(),
        # Fixed: the original interpolated the `platform` module object itself
        # (printing "<module 'platform' ...>"); python_version() is the
        # intended interpreter version string.
        "Python version": f"{platform.python_implementation()} {platform.python_version()}",
        "Python interpreter": sys.executable,
    }

    deps = [
        "scikit-survival",
        "scikit-learn",
        "numpy",
        "scipy",
        "pandas",
        "numexpr",
        "ecos",
        "osqp",
        "joblib",
        "matplotlib",
        "pytest",
        "sphinx",
        "Cython",
        "pip",
        "setuptools",
    ]
    # Pad the name column to the longest key so all colons line up.
    minwidth = max(
        max(map(len, deps)),
        max(map(len, sys_info.keys())),
    )
    fmt = "{0:<%ds}: {1}" % minwidth

    print("SYSTEM")
    print("------")
    for name, version_string in sys_info.items():
        print(fmt.format(name, version_string))

    print("\nDEPENDENCIES")
    print("------------")
    for dep in deps:
        version_string = _get_version(dep)
        print(fmt.format(dep, version_string))
59
+
60
+
61
@available_if(_final_estimator_has("predict_cumulative_hazard_function"))
def predict_cumulative_hazard_function(self, X, **kwargs):
    """Predict cumulative hazard function.

    The cumulative hazard function for an individual
    with feature vector :math:`x` is defined as

    .. math::

        H(t \\mid x) = \\exp(x^\\top \\beta) H_0(t) ,

    where :math:`H_0(t)` is the baseline hazard function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    cum_hazard : ndarray, shape = (n_samples,)
        Predicted cumulative hazard functions.
    """
    # Run the data through every transformer step, then delegate to the
    # final estimator of the pipeline.
    transformed = X
    for _, _, step in self._iter(with_final=False):
        transformed = step.transform(transformed)
    final_estimator = self.steps[-1][-1]
    return final_estimator.predict_cumulative_hazard_function(transformed, **kwargs)
89
+
90
+
91
@available_if(_final_estimator_has("predict_survival_function"))
def predict_survival_function(self, X, **kwargs):
    """Predict survival function.

    The survival function for an individual
    with feature vector :math:`x` is defined as

    .. math::

        S(t \\mid x) = S_0(t)^{\\exp(x^\\top \\beta)} ,

    where :math:`S_0(t)` is the baseline survival function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    survival : ndarray, shape = (n_samples,)
        Predicted survival functions.
    """
    # Run the data through every transformer step, then delegate to the
    # final estimator of the pipeline.
    transformed = X
    for _, _, step in self._iter(with_final=False):
        transformed = step.transform(transformed)
    final_estimator = self.steps[-1][-1]
    return final_estimator.predict_survival_function(transformed, **kwargs)
119
+
120
+
121
@property_available_if(_final_estimator_has("_predict_risk_score"))
def _predict_risk_score(self):
    # Forward the flag from the pipeline's final estimator; it tells
    # callers whether predict() returns risk scores or time estimates.
    final_estimator = self.steps[-1][-1]
    return final_estimator._predict_risk_score
124
+
125
+
126
def patch_pipeline():
    """Attach survival-specific prediction attributes to sklearn's Pipeline.

    After this call, a Pipeline whose final estimator supports the
    corresponding method exposes it directly on the pipeline object.
    """
    Pipeline.predict_cumulative_hazard_function = predict_cumulative_hazard_function
    Pipeline.predict_survival_function = predict_survival_function
    Pipeline._predict_risk_score = _predict_risk_score
130
+
131
+
132
# Resolve the package version from installed distribution metadata.
try:
    __version__ = version("scikit-survival")
except PackageNotFoundError:  # pragma: no cover
    # package is not installed
    __version__ = "unknown"

# Monkey-patch sklearn.pipeline.Pipeline at import time so pipelines gain
# survival-specific prediction methods.
patch_pipeline()
sksurv/base.py ADDED
@@ -0,0 +1,103 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ import numpy as np
14
+
15
+
16
class SurvivalAnalysisMixin:
    """Mixin for survival estimators.

    Provides helpers to turn risk scores into survival / cumulative hazard
    functions via a baseline model, and a ``score`` method based on the
    concordance index.
    """

    def _predict_function(self, func_name, baseline_model, prediction, return_array):
        # Ask the baseline model for one step function per sample.
        funcs = getattr(baseline_model, func_name)(prediction)

        if not return_array:
            return funcs

        # Evaluate every step function at the unique event times and
        # collect the results in a dense (n_samples, n_times) array.
        eval_times = baseline_model.unique_times_
        out = np.empty((prediction.shape[0], eval_times.shape[0]), dtype=float)
        for row, fn in enumerate(funcs):
            out[row, :] = fn(eval_times)
        return out

    def _predict_survival_function(self, baseline_model, prediction, return_array):
        """Return survival functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of baseline survival function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If True, return a float array of the survival function
            evaluated at the unique event times, otherwise return
            an array of :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        survival : ndarray
        """
        return self._predict_function("get_survival_function", baseline_model, prediction, return_array)

    def _predict_cumulative_hazard_function(self, baseline_model, prediction, return_array):
        """Return cumulative hazard functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of baseline cumulative hazard function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If True, return a float array of the cumulative hazard function
            evaluated at the unique event times, otherwise return
            an array of :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        cum_hazard : ndarray
        """
        return self._predict_function("get_cumulative_hazard_function", baseline_model, prediction, return_array)

    def score(self, X, y):
        """Returns the concordance index of the prediction.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Test samples.

        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        cindex : float
            Estimated concordance index.
        """
        from .metrics import concordance_index_censored

        event_field, time_field = y.dtype.names

        risk_score = self.predict(X)
        # Estimators that predict on the time scale flag themselves via
        # _predict_risk_score = False; negate to obtain a risk ordering.
        if not getattr(self, "_predict_risk_score", True):
            risk_score *= -1

        return concordance_index_censored(y[event_field], y[time_field], risk_score)[0]

    def _more_tags(self):
        # sklearn estimator tag: fitting always requires a target.
        return {"requires_y": True}
@@ -0,0 +1,15 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ from ._binarytrees import AATree, AVLTree, RBTree
14
+
15
+ __all__ = ["RBTree", "AVLTree", "AATree"]
sksurv/column.py ADDED
@@ -0,0 +1,201 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ import logging
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+ from pandas.api.types import CategoricalDtype, is_object_dtype
18
+
19
+ __all__ = ["categorical_to_numeric", "encode_categorical", "standardize"]
20
+
21
+
22
+ def _apply_along_column(array, func1d, **kwargs):
23
+ if isinstance(array, pd.DataFrame):
24
+ return array.apply(func1d, **kwargs)
25
+ return np.apply_along_axis(func1d, 0, array, **kwargs)
26
+
27
+
28
+ def standardize_column(series_or_array, with_std=True):
29
+ d = series_or_array.dtype
30
+ if issubclass(d.type, np.number):
31
+ output = series_or_array.astype(float)
32
+ m = series_or_array.mean()
33
+ output -= m
34
+
35
+ if with_std:
36
+ s = series_or_array.std(ddof=1)
37
+ output /= s
38
+
39
+ return output
40
+
41
+ return series_or_array
42
+
43
+
44
+ def standardize(table, with_std=True):
45
+ """
46
+ Perform Z-Normalization on each numeric column of the given table.
47
+
48
+ If `table` is a pandas.DataFrame, only numeric columns are modified,
49
+ all other columns remain unchanged. If `table` is a numpy.ndarray,
50
+ it is only modified if it has numeric dtype, in which case the returned
51
+ array will have floating point dtype.
52
+
53
+ Parameters
54
+ ----------
55
+ table : pandas.DataFrame or numpy.ndarray
56
+ Data to standardize.
57
+
58
+ with_std : bool, optional, default: True
59
+ If ``False`` data is only centered and not converted to unit variance.
60
+
61
+ Returns
62
+ -------
63
+ normalized : pandas.DataFrame
64
+ Table with numeric columns normalized.
65
+ Categorical columns in the input table remain unchanged.
66
+ """
67
+ new_frame = _apply_along_column(table, standardize_column, with_std=with_std)
68
+
69
+ return new_frame
70
+
71
+
72
+ def _encode_categorical_series(series, allow_drop=True):
73
+ values = _get_dummies_1d(series, allow_drop=allow_drop)
74
+ if values is None:
75
+ return
76
+
77
+ enc, levels = values
78
+ if enc is None:
79
+ return pd.Series(index=series.index, name=series.name, dtype=series.dtype)
80
+
81
+ if not allow_drop and enc.shape[1] == 1:
82
+ return series
83
+
84
+ names = []
85
+ for key in range(1, enc.shape[1]):
86
+ names.append(f"{series.name}={levels[key]}")
87
+ series = pd.DataFrame(enc[:, 1:], columns=names, index=series.index)
88
+
89
+ return series
90
+
91
+
92
+ def encode_categorical(table, columns=None, **kwargs):
93
+ """
94
+ Encode categorical columns with `M` categories into `M-1` columns according
95
+ to the one-hot scheme.
96
+
97
+ Parameters
98
+ ----------
99
+ table : pandas.DataFrame
100
+ Table with categorical columns to encode.
101
+
102
+ columns : list-like, optional, default: None
103
+ Column names in the DataFrame to be encoded.
104
+ If `columns` is None then all the columns with
105
+ `object` or `category` dtype will be converted.
106
+
107
+ allow_drop : boolean, optional, default: True
108
+ Whether to allow dropping categorical columns that only consist
109
+ of a single category.
110
+
111
+ Returns
112
+ -------
113
+ encoded : pandas.DataFrame
114
+ Table with categorical columns encoded as numeric.
115
+ Numeric columns in the input table remain unchanged.
116
+ """
117
+ if isinstance(table, pd.Series):
118
+ if not isinstance(table.dtype, CategoricalDtype) and not is_object_dtype(table.dtype):
119
+ raise TypeError(f"series must be of categorical dtype, but was {table.dtype}")
120
+ return _encode_categorical_series(table, **kwargs)
121
+
122
+ def _is_categorical_or_object(series):
123
+ return isinstance(series.dtype, CategoricalDtype) or is_object_dtype(series.dtype)
124
+
125
+ if columns is None:
126
+ # for columns containing categories
127
+ columns_to_encode = {nam for nam, s in table.items() if _is_categorical_or_object(s)}
128
+ else:
129
+ columns_to_encode = set(columns)
130
+
131
+ items = []
132
+ for name, series in table.items():
133
+ if name in columns_to_encode:
134
+ series = _encode_categorical_series(series, **kwargs)
135
+ if series is None:
136
+ continue
137
+ items.append(series)
138
+
139
+ # concat columns of tables
140
+ new_table = pd.concat(items, axis=1, copy=False)
141
+ return new_table
142
+
143
+
144
+ def _get_dummies_1d(data, allow_drop=True):
145
+ # Series avoids inconsistent NaN handling
146
+ cat = pd.Categorical(data)
147
+ levels = cat.categories
148
+ number_of_cols = len(levels)
149
+
150
+ # if all NaN or only one level
151
+ if allow_drop and number_of_cols < 2:
152
+ logging.getLogger(__package__).warning(
153
+ f"dropped categorical variable {data.name!r}, because it has only {number_of_cols} values"
154
+ )
155
+ return
156
+ if number_of_cols == 0:
157
+ return None, levels
158
+
159
+ dummy_mat = np.eye(number_of_cols).take(cat.codes, axis=0)
160
+
161
+ # reset NaN GH4446
162
+ dummy_mat[cat.codes == -1] = np.nan
163
+
164
+ return dummy_mat, levels
165
+
166
+
167
+ def categorical_to_numeric(table):
168
+ """Encode categorical columns to numeric by converting each category to
169
+ an integer value.
170
+
171
+ Parameters
172
+ ----------
173
+ table : pandas.DataFrame
174
+ Table with categorical columns to encode.
175
+
176
+ Returns
177
+ -------
178
+ encoded : pandas.DataFrame
179
+ Table with categorical columns encoded as numeric.
180
+ Numeric columns in the input table remain unchanged.
181
+ """
182
+
183
+ def transform(column):
184
+ if isinstance(column.dtype, CategoricalDtype):
185
+ return column.cat.codes
186
+ if is_object_dtype(column.dtype):
187
+ try:
188
+ nc = column.astype(np.int64)
189
+ except ValueError:
190
+ classes = column.dropna().unique()
191
+ classes.sort(kind="mergesort")
192
+ nc = column.map(dict(zip(classes, range(classes.shape[0]))))
193
+ return nc
194
+ if column.dtype == bool:
195
+ return column.astype(np.int64)
196
+
197
+ return column
198
+
199
+ if isinstance(table, pd.Series):
200
+ return pd.Series(transform(table), name=table.name, index=table.index)
201
+ return table.apply(transform, axis=0, result_type="expand")
sksurv/compare.py ADDED
@@ -0,0 +1,123 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from scipy import stats
6
+ from sklearn.utils.validation import check_array
7
+
8
+ from .util import check_array_survival
9
+
10
+ __all__ = ["compare_survival"]
11
+
12
+
13
def compare_survival(y, group_indicator, return_stats=False):
    """K-sample log-rank hypothesis test of identical survival functions.

    Compares the pooled hazard rate with each group-specific
    hazard rate. The alternative hypothesis is that the hazard
    rate of at least one group differs from the others at some time.

    See [1]_ for more details.

    Parameters
    ----------
    y : structured array, shape = (n_samples,)
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    group_indicator : array-like, shape = (n_samples,)
        Group membership of each sample.

    return_stats : bool, optional, default: False
        Whether to return a data frame with statistics for each group
        and the covariance matrix of the test statistic.

    Returns
    -------
    chisq : float
        Test statistic.
    pvalue : float
        Two-sided p-value with respect to the null hypothesis
        that the hazard rates across all groups are equal.
    stats : pandas.DataFrame
        Summary statistics for each group: number of samples,
        observed number of events, expected number of events,
        and test statistic.
        Only provided if `return_stats` is True.
    covariance : array, shape=(n_groups, n_groups)
        Covariance matrix of the test statistic.
        Only provided if `return_stats` is True.

    References
    ----------
    .. [1] Fleming, T. R. and Harrington, D. P.
           A Class of Hypothesis Tests for One and Two Samples of Censored Survival Data.
           Communications In Statistics 10 (1981): 763-794.
    """

    # NOTE(review): group_indicator is passed as the first (X) argument of
    # check_array_survival; presumably it is only used to validate that y has
    # a matching first dimension — confirm against sksurv.util.check_array_survival.
    event, time = check_array_survival(group_indicator, y)
    group_indicator = check_array(
        group_indicator,
        dtype="O",
        ensure_2d=False,
        estimator="compare_survival",
        input_name="group_indicator",
    )

    n_samples = time.shape[0]
    groups, group_counts = np.unique(group_indicator, return_counts=True)
    n_groups = groups.shape[0]
    if n_groups == 1:
        raise ValueError("At least two groups must be specified, but only one was provided.")

    # sort descending
    # (stable mergesort so tied times keep a deterministic order; after this
    # sort, position k in the array means k samples have time >= time[k])
    o = np.argsort(-time, kind="mergesort")
    x = group_indicator[o]
    event = event[o]
    time = time[o]

    # Per-group accumulators across event times:
    # at_risk  - number of samples with time >= current time (grows as we
    #            walk down the descending-sorted array)
    # observed - observed event counts per group
    # expected - expected event counts per group under the null hypothesis
    # covar    - covariance matrix of (observed - expected)
    at_risk = np.zeros(n_groups, dtype=int)
    observed = np.zeros(n_groups, dtype=int)
    expected = np.zeros(n_groups, dtype=float)
    covar = np.zeros((n_groups, n_groups), dtype=float)

    covar_indices = np.diag_indices(n_groups)

    k = 0
    while k < n_samples:
        ti = time[k]
        total_events = 0
        # Consume the whole block of samples tied at time ti, counting
        # events per group and enlarging the risk sets.
        while k < n_samples and ti == time[k]:
            # groups is sorted (np.unique), so searchsorted maps a group
            # label to its index.
            idx = np.searchsorted(groups, x[k])
            if event[k]:
                observed[idx] += 1
                total_events += 1
            at_risk[idx] += 1
            k += 1

        if total_events != 0:
            # With descending times, exactly k samples have time >= ti,
            # so k is the size of the pooled risk set at ti.
            total_at_risk = k
            expected += at_risk * (total_events / total_at_risk)
            if total_at_risk > 1:
                # Hypergeometric variance term for (possibly tied) events.
                multiplier = total_events * (total_at_risk - total_events) / (total_at_risk * (total_at_risk - 1))
                temp = at_risk * multiplier
                covar[covar_indices] += temp
                covar -= np.outer(temp, at_risk) / total_at_risk

    # The n_groups statistics sum to zero, so the test uses only the first
    # n_groups - 1 components and degrees of freedom.
    df = n_groups - 1
    zz = observed[:df] - expected[:df]
    chisq = np.linalg.solve(covar[:df, :df], zz).dot(zz)
    pval = stats.chi2.sf(chisq, df)

    if return_stats:
        table = OrderedDict()
        table["counts"] = group_counts
        table["observed"] = observed
        table["expected"] = expected
        table["statistic"] = observed - expected
        table = pd.DataFrame.from_dict(table)
        table.index = pd.Index(groups, name="group", dtype=groups.dtype)
        return chisq, pval, table, covar

    return chisq, pval
@@ -0,0 +1,10 @@
1
+ from .base import (
2
+ get_x_y, # noqa: F401
3
+ load_aids, # noqa: F401
4
+ load_arff_files_standardized, # noqa: F401
5
+ load_breast_cancer, # noqa: F401
6
+ load_flchain, # noqa: F401
7
+ load_gbsg2, # noqa: F401
8
+ load_veterans_lung_cancer, # noqa: F401
9
+ load_whas500, # noqa: F401
10
+ )