scikit-survival 0.26.0__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.26.0.dist-info/METADATA +185 -0
- scikit_survival-0.26.0.dist-info/RECORD +58 -0
- scikit_survival-0.26.0.dist-info/WHEEL +5 -0
- scikit_survival-0.26.0.dist-info/licenses/COPYING +674 -0
- scikit_survival-0.26.0.dist-info/top_level.txt +1 -0
- sksurv/__init__.py +183 -0
- sksurv/base.py +115 -0
- sksurv/bintrees/__init__.py +15 -0
- sksurv/bintrees/_binarytrees.cp314-win_amd64.pyd +0 -0
- sksurv/column.py +204 -0
- sksurv/compare.py +123 -0
- sksurv/datasets/__init__.py +12 -0
- sksurv/datasets/base.py +614 -0
- sksurv/datasets/data/GBSG2.arff +700 -0
- sksurv/datasets/data/actg320.arff +1169 -0
- sksurv/datasets/data/bmt.arff +46 -0
- sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
- sksurv/datasets/data/cgvhd.arff +118 -0
- sksurv/datasets/data/flchain.arff +7887 -0
- sksurv/datasets/data/veteran.arff +148 -0
- sksurv/datasets/data/whas500.arff +520 -0
- sksurv/docstrings.py +99 -0
- sksurv/ensemble/__init__.py +2 -0
- sksurv/ensemble/_coxph_loss.cp314-win_amd64.pyd +0 -0
- sksurv/ensemble/boosting.py +1564 -0
- sksurv/ensemble/forest.py +902 -0
- sksurv/ensemble/survival_loss.py +151 -0
- sksurv/exceptions.py +18 -0
- sksurv/functions.py +114 -0
- sksurv/io/__init__.py +2 -0
- sksurv/io/arffread.py +91 -0
- sksurv/io/arffwrite.py +181 -0
- sksurv/kernels/__init__.py +1 -0
- sksurv/kernels/_clinical_kernel.cp314-win_amd64.pyd +0 -0
- sksurv/kernels/clinical.py +348 -0
- sksurv/linear_model/__init__.py +3 -0
- sksurv/linear_model/_coxnet.cp314-win_amd64.pyd +0 -0
- sksurv/linear_model/aft.py +208 -0
- sksurv/linear_model/coxnet.py +592 -0
- sksurv/linear_model/coxph.py +637 -0
- sksurv/meta/__init__.py +4 -0
- sksurv/meta/base.py +35 -0
- sksurv/meta/ensemble_selection.py +724 -0
- sksurv/meta/stacking.py +370 -0
- sksurv/metrics.py +1028 -0
- sksurv/nonparametric.py +911 -0
- sksurv/preprocessing.py +195 -0
- sksurv/svm/__init__.py +11 -0
- sksurv/svm/_minlip.cp314-win_amd64.pyd +0 -0
- sksurv/svm/_prsvm.cp314-win_amd64.pyd +0 -0
- sksurv/svm/minlip.py +695 -0
- sksurv/svm/naive_survival_svm.py +249 -0
- sksurv/svm/survival_svm.py +1236 -0
- sksurv/testing.py +155 -0
- sksurv/tree/__init__.py +1 -0
- sksurv/tree/_criterion.cp314-win_amd64.pyd +0 -0
- sksurv/tree/tree.py +790 -0
- sksurv/util.py +416 -0
sksurv/__init__.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
2
|
+
import platform
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from sklearn.pipeline import Pipeline, _final_estimator_has
|
|
6
|
+
from sklearn.utils.metaestimators import available_if
|
|
7
|
+
|
|
8
|
+
from .util import property_available_if
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _get_version(name):
|
|
12
|
+
try:
|
|
13
|
+
pkg_version = version(name)
|
|
14
|
+
except ImportError:
|
|
15
|
+
pkg_version = None
|
|
16
|
+
return pkg_version
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def show_versions():
    """Print debugging information."""
    sys_info = {
        "Platform": platform.platform(),
        "Python version": f"{platform.python_implementation()} {platform.python_version()}",
        "Python interpreter": sys.executable,
    }

    deps = [
        "scikit-survival",
        "scikit-learn",
        "numpy",
        "scipy",
        "pandas",
        "numexpr",
        "ecos",
        "osqp",
        "joblib",
        "matplotlib",
        "pytest",
        "sphinx",
        "Cython",
        "pip",
        "setuptools",
    ]
    # Column width is the longest label among dependencies and system keys.
    minwidth = max(len(label) for label in (*deps, *sys_info))
    fmt = f"{{0:<{minwidth}s}}: {{1}}"

    print("SYSTEM")
    print("------")
    for name, version_string in sys_info.items():
        print(fmt.format(name, version_string))

    print("\nDEPENDENCIES")
    print("------------")
    for dep in deps:
        print(fmt.format(dep, _get_version(dep)))
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@available_if(_final_estimator_has("predict_cumulative_hazard_function"))
def predict_cumulative_hazard_function(self, X, **kwargs):
    r"""Predict cumulative hazard function for a pipeline.

    For an individual with feature vector :math:`x`, the
    cumulative hazard function is

    .. math::

        H(t \mid x) = \exp(x^\top \beta) H_0(t) ,

    with :math:`H_0(t)` the baseline hazard function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    cum_hazard : ndarray, shape = (n_samples,)
        Predicted cumulative hazard functions. Each element is an instance
        of :class:`sksurv.functions.StepFunction`.

    See Also
    --------
    predict_survival_function : Predict survival function for a pipeline.

    Examples
    --------
    >>> from sksurv.datasets import load_whas500
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.preprocessing import OneHotEncoder
    >>> from sklearn.pipeline import Pipeline
    >>>
    >>> X, y = load_whas500()
    >>> pipe = Pipeline([('encode', OneHotEncoder()),
    ...                  ('cox', CoxPHSurvivalAnalysis())])
    >>> pipe.fit(X, y)
    Pipeline(...)
    >>> chf = pipe.predict_cumulative_hazard_function(X.iloc[:5])
    >>> for fn in chf:
    ...     print(fn.x, fn.y)
    [...]
    """
    # Push the data through every transformer step, then delegate to the
    # final estimator of the pipeline.
    transformed = X
    for _, _, step in self._iter(with_final=False):
        transformed = step.transform(transformed)
    final_estimator = self.steps[-1][-1]
    return final_estimator.predict_cumulative_hazard_function(transformed, **kwargs)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@available_if(_final_estimator_has("predict_survival_function"))
def predict_survival_function(self, X, **kwargs):
    r"""Predict survival function for a pipeline.

    For an individual with feature vector :math:`x`, the
    survival function is

    .. math::

        S(t \mid x) = S_0(t)^{\exp(x^\top \beta)} ,

    with :math:`S_0(t)` the baseline survival function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    survival : ndarray, shape = (n_samples,)
        Predicted survival functions. Each element is an instance
        of :class:`sksurv.functions.StepFunction`.

    See Also
    --------
    predict_cumulative_hazard_function : Predict cumulative hazard function for a pipeline.

    Examples
    --------
    >>> from sksurv.datasets import load_whas500
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.preprocessing import OneHotEncoder
    >>> from sklearn.pipeline import Pipeline
    >>>
    >>> X, y = load_whas500()
    >>> pipe = Pipeline([('encode', OneHotEncoder()),
    ...                  ('cox', CoxPHSurvivalAnalysis())])
    >>> pipe.fit(X, y)
    Pipeline(...)
    >>> surv_fn = pipe.predict_survival_function(X.iloc[:5])
    >>> for fn in surv_fn:
    ...     print(fn.x, fn.y)
    [...]
    """
    # Push the data through every transformer step, then delegate to the
    # final estimator of the pipeline.
    transformed = X
    for _, _, step in self._iter(with_final=False):
        transformed = step.transform(transformed)
    final_estimator = self.steps[-1][-1]
    return final_estimator.predict_survival_function(transformed, **kwargs)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@property_available_if(_final_estimator_has("_predict_risk_score"))
def _predict_risk_score(self):
    # Delegate to the final estimator of the pipeline.
    final_estimator = self.steps[-1][-1]
    return final_estimator._predict_risk_score
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
def patch_pipeline():
    """Attach survival-specific prediction attributes to sklearn's Pipeline."""
    for attr_name, attr_value in (
        ("predict_survival_function", predict_survival_function),
        ("predict_cumulative_hazard_function", predict_cumulative_hazard_function),
        ("_predict_risk_score", _predict_risk_score),
    ):
        setattr(Pipeline, attr_name, attr_value)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
try:
    # Single-source the package version from the installed distribution metadata.
    __version__ = version("scikit-survival")
except PackageNotFoundError:  # pragma: no cover
    # package is not installed
    __version__ = "unknown"

# Monkey-patch sklearn.pipeline.Pipeline with survival-specific prediction
# methods as a side effect of importing this package.
patch_pipeline()
|
sksurv/base.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SurvivalAnalysisMixin:
    """Mixin with shared prediction helpers and scoring for survival estimators."""

    def _predict_function(self, func_name, baseline_model, prediction, return_array):
        # Ask the baseline model for one step function per sample.
        step_functions = getattr(baseline_model, func_name)(prediction)

        if not return_array:
            return step_functions

        # Evaluate each step function at the unique event times and collect
        # the values in a dense (n_samples, n_times) float matrix.
        times = baseline_model.unique_times_
        values = np.empty((prediction.shape[0], times.shape[0]), dtype=float)
        for row_idx, fn in enumerate(step_functions):
            values[row_idx, :] = fn(times)
        return values

    def _predict_survival_function(self, baseline_model, prediction, return_array):
        """Return survival functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of the baseline survival function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If True, return a float array of the survival function
            evaluated at the unique event times; otherwise return
            an array of :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        survival : ndarray of StepFunction
            If `return_array` is True, an array of shape (n_samples, n_unique_times)
            with the survival function values. Otherwise, a list of
            :class:`sksurv.functions.StepFunction` instances.
        """
        return self._predict_function("get_survival_function", baseline_model, prediction, return_array)

    def _predict_cumulative_hazard_function(self, baseline_model, prediction, return_array):
        """Return cumulative hazard functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of the baseline cumulative hazard function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If True, return a float array of the cumulative hazard function
            evaluated at the unique event times; otherwise return
            an array of :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        cum_hazard : ndarray of StepFunction
            If `return_array` is True, an array of shape (n_samples, n_unique_times)
            with the cumulative hazard function values. Otherwise, a list of
            :class:`sksurv.functions.StepFunction` instances.
        """
        return self._predict_function("get_cumulative_hazard_function", baseline_model, prediction, return_array)

    def score(self, X, y):
        """Returns the concordance index of the prediction.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Test samples.

        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        cindex : float
            Estimated concordance index.

        See also
        --------
        sksurv.metrics.concordance_index_censored : Computes the concordance index.
        """
        from .metrics import concordance_index_censored

        event_field, time_field = y.dtype.names

        estimate = self.predict(X)
        # Estimators that predict on the time scale (larger = longer survival)
        # set _predict_risk_score = False; negate to obtain a risk score.
        if not getattr(self, "_predict_risk_score", True):
            estimate = estimate * -1

        cindex = concordance_index_censored(y[event_field], y[time_field], estimate)
        return cindex[0]

    def __sklearn_tags__(self):
        # Survival estimators always need the target `y` during fit.
        tags = super().__sklearn_tags__()
        tags.target_tags.required = True
        return tags
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
from ._binarytrees import AATree, AVLTree, RBTree
|
|
14
|
+
|
|
15
|
+
__all__ = ["RBTree", "AVLTree", "AATree"]
|
|
Binary file
|
sksurv/column.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import logging
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import pandas as pd
|
|
17
|
+
from pandas.api.types import CategoricalDtype, is_string_dtype
|
|
18
|
+
|
|
19
|
+
__all__ = ["categorical_to_numeric", "encode_categorical", "standardize"]
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _apply_along_column(array, func1d, **kwargs):
|
|
23
|
+
if isinstance(array, pd.DataFrame):
|
|
24
|
+
return array.apply(func1d, **kwargs)
|
|
25
|
+
return np.apply_along_axis(func1d, 0, array, **kwargs)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def standardize_column(series_or_array, with_std=True):
    """Center (and optionally scale to unit variance) one numeric column.

    Non-numeric columns are returned unchanged.
    """
    dtype = series_or_array.dtype
    if not issubclass(dtype.type, np.number):
        # Leave non-numeric columns untouched.
        return series_or_array

    centered = series_or_array.astype(float) - series_or_array.mean()
    if with_std:
        # Sample standard deviation (ddof=1), matching pandas' default.
        centered = centered / series_or_array.std(ddof=1)
    return centered
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def standardize(table, with_std=True):
    """Standardize numeric features by removing the mean and scaling to unit variance.

    Performs Z-Normalization on every numeric column of the given table.

    If `table` is a :class:`pandas.DataFrame`, only numeric columns are
    modified; all other columns are left as-is. If `table` is a
    :class:`numpy.ndarray`, it is only modified if it has a numeric dtype,
    and the returned array will have a floating-point dtype.

    Parameters
    ----------
    table : pandas.DataFrame or numpy.ndarray
        Data to standardize.
    with_std : bool, optional, default: True
        If ``False``, data is only centered (mean removed) and not scaled to
        unit variance.

    Returns
    -------
    normalized : pandas.DataFrame or numpy.ndarray
        The standardized data. The output type will be the same as the input type.
    """
    # Delegate the per-column work to standardize_column.
    return _apply_along_column(table, standardize_column, with_std=with_std)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _encode_categorical_series(series, allow_drop=True):
    """Expand one categorical series into dummy columns, dropping the first level.

    Returns ``None`` if the column was dropped entirely.
    """
    values = _get_dummies_1d(series, allow_drop=allow_drop)
    if values is None:
        # Column was dropped because it holds fewer than two levels.
        return None

    enc, levels = values
    if enc is None:
        # Every value is missing: keep an empty placeholder series.
        return pd.Series(index=series.index, name=series.name, dtype=series.dtype)
    if not allow_drop and enc.shape[1] == 1:
        # Single-level column that must be kept: return it unchanged.
        return series

    # Skip the first level (reference category) when naming dummy columns.
    column_names = [f"{series.name}={levels[key]}" for key in range(1, enc.shape[1])]
    return pd.DataFrame(enc[:, 1:], columns=column_names, index=series.index)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def encode_categorical(table, columns=None, **kwargs):
    """One-hot encode categorical features.

    Creates one binary column per category and, by default, drops one
    category per feature: a column with `M` categories is encoded as
    `M-1` integer columns following the one-hot scheme.

    Parameters
    ----------
    table : pandas.DataFrame or pandas.Series
        Data with categorical columns to encode.
    columns : list-like, optional, default: None
        Column names in the DataFrame to be encoded.
        If `columns` is `None`, all columns with `object` or `category`
        dtype will be converted. This parameter is ignored if `table` is a
        pandas.Series.
    allow_drop : bool, optional, default: True
        Whether to allow dropping categorical columns that only consist
        of a single category.

    Returns
    -------
    encoded : pandas.DataFrame
        The transformed data with categorical columns encoded as numeric.
        Numeric columns in the input table remain unchanged.
    """
    if isinstance(table, pd.Series):
        if not (isinstance(table.dtype, CategoricalDtype) or is_string_dtype(table.dtype)):
            raise TypeError(f"series must be of categorical dtype, but was {table.dtype}")
        return _encode_categorical_series(table, **kwargs)

    def _needs_encoding(col):
        # Columns holding categories or strings are candidates for encoding.
        return isinstance(col.dtype, CategoricalDtype) or is_string_dtype(col.dtype)

    if columns is None:
        columns_to_encode = {name for name, col in table.items() if _needs_encoding(col)}
    else:
        columns_to_encode = set(columns)

    pieces = []
    for name, col in table.items():
        if name in columns_to_encode:
            col = _encode_categorical_series(col, **kwargs)
            if col is None:
                # Column was dropped entirely (single-level category).
                continue
        pieces.append(col)

    # Concatenate all columns back into a single table.
    return pd.concat(pieces, axis=1, copy=False)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _get_dummies_1d(data, allow_drop=True):
|
|
148
|
+
# Series avoids inconsistent NaN handling
|
|
149
|
+
cat = pd.Categorical(data)
|
|
150
|
+
levels = cat.categories
|
|
151
|
+
number_of_cols = len(levels)
|
|
152
|
+
|
|
153
|
+
# if all NaN or only one level
|
|
154
|
+
if allow_drop and number_of_cols < 2:
|
|
155
|
+
logging.getLogger(__package__).warning(
|
|
156
|
+
f"dropped categorical variable {data.name!r}, because it has only {number_of_cols} values"
|
|
157
|
+
)
|
|
158
|
+
return
|
|
159
|
+
if number_of_cols == 0:
|
|
160
|
+
return None, levels
|
|
161
|
+
|
|
162
|
+
dummy_mat = np.eye(number_of_cols).take(cat.codes, axis=0)
|
|
163
|
+
|
|
164
|
+
# reset NaN GH4446
|
|
165
|
+
dummy_mat[cat.codes == -1] = np.nan
|
|
166
|
+
|
|
167
|
+
return dummy_mat, levels
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def categorical_to_numeric(table):
    """Encode categorical features as integers.

    Converts each category to a unique integer value.

    Parameters
    ----------
    table : pandas.DataFrame or pandas.Series
        Data with categorical columns to encode.

    Returns
    -------
    encoded : pandas.DataFrame or pandas.Series
        The transformed data with categorical columns encoded as integers.
        The output type will be the same as the input type.
    """

    def _to_numeric(column):
        if isinstance(column.dtype, CategoricalDtype):
            # Categorical columns already carry integer codes.
            return column.cat.codes
        if is_string_dtype(column.dtype):
            try:
                return column.astype(np.int64)
            except ValueError:
                # Strings are not numeric: map sorted unique labels to 0..K-1.
                labels = column.dropna().unique()
                mapping = {label: code for code, label in enumerate(sorted(labels))}
                return column.map(mapping)
        if column.dtype == bool:
            return column.astype(np.int64)
        return column

    if isinstance(table, pd.Series):
        return pd.Series(_to_numeric(table), name=table.name, index=table.index)
    return table.apply(_to_numeric, axis=0, result_type="expand")
|
sksurv/compare.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from collections import OrderedDict
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import pandas as pd
|
|
5
|
+
from scipy import stats
|
|
6
|
+
from sklearn.utils.validation import check_array
|
|
7
|
+
|
|
8
|
+
from .util import check_array_survival
|
|
9
|
+
|
|
10
|
+
__all__ = ["compare_survival"]
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def compare_survival(y, group_indicator, return_stats=False):
    """Compare survival functions of two or more groups using the log-rank test.

    The log-rank test is a non-parametric hypothesis test for comparing the
    survival functions of two or more independent groups. The null hypothesis is
    that the survival functions of the groups are identical. The alternative
    hypothesis is that at least one survival function differs from the others.

    The test statistic is approximately chi-squared distributed with :math:`K-1`
    degrees of freedom, where :math:`K` is the number of groups.

    See [1]_ for more details.

    Parameters
    ----------
    y : structured array, shape = (n_samples,)
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    group_indicator : array-like, shape = (n_samples,)
        Group membership of each sample.
    return_stats : bool, optional, default: False
        Whether to return a data frame with statistics for each group and the
        covariance matrix of the test statistic.

    Returns
    -------
    chisq : float
        The test statistic.
    pvalue : float
        The two-sided p-value for the test.
    stats : pandas.DataFrame, optional
        A DataFrame with summary statistics for each group. This includes the
        number of samples, observed number of events, expected number of events,
        and the test statistic. Only returned if ``return_stats`` is ``True``.
    covariance : ndarray, shape=(n_groups, n_groups), optional
        The covariance matrix of the test statistic. Only returned if
        ``return_stats`` is ``True``.

    References
    ----------
    .. [1] Fleming, T. R. and Harrington, D. P.
           A Class of Hypothesis Tests for One and Two Samples of Censored Survival Data.
           Communications In Statistics 10 (1981): 763-794.
    """

    # Validate y and extract event indicator / time; group_indicator plays the
    # role of the data matrix X for the survival check here.
    event, time = check_array_survival(group_indicator, y)
    group_indicator = check_array(
        group_indicator,
        dtype="O",
        ensure_2d=False,
        estimator="compare_survival",
        input_name="group_indicator",
    )

    n_samples = time.shape[0]
    # np.unique returns groups in sorted order; searchsorted below relies on this.
    groups, group_counts = np.unique(group_indicator, return_counts=True)
    n_groups = groups.shape[0]
    if n_groups == 1:
        raise ValueError("At least two groups must be specified, but only one was provided.")

    # sort descending
    # With times sorted in descending order, after processing the first k samples
    # the number at risk at the current time is exactly k.
    # mergesort keeps the order stable for tied times.
    o = np.argsort(-time, kind="mergesort")
    x = group_indicator[o]
    event = event[o]
    time = time[o]

    # Per-group accumulators over all distinct event times.
    at_risk = np.zeros(n_groups, dtype=int)
    observed = np.zeros(n_groups, dtype=int)
    expected = np.zeros(n_groups, dtype=float)
    covar = np.zeros((n_groups, n_groups), dtype=float)

    covar_indices = np.diag_indices(n_groups)

    k = 0
    while k < n_samples:
        # Process all samples that share the current (largest remaining) time.
        ti = time[k]
        total_events = 0
        while k < n_samples and ti == time[k]:
            # Map the group label to its index in the sorted `groups` array.
            idx = np.searchsorted(groups, x[k])
            if event[k]:
                observed[idx] += 1
                total_events += 1
            at_risk[idx] += 1
            k += 1

        if total_events != 0:
            # Because of the descending sort, k equals the total number at risk
            # at time ti (everything processed so far has time >= ti).
            total_at_risk = k
            # Expected events per group under the null: proportional to at-risk counts.
            expected += at_risk * (total_events / total_at_risk)
            if total_at_risk > 1:
                # Hypergeometric variance term for tied event times.
                multiplier = total_events * (total_at_risk - total_events) / (total_at_risk * (total_at_risk - 1))
                temp = at_risk * multiplier
                covar[covar_indices] += temp
                covar -= np.outer(temp, at_risk) / total_at_risk

    # Drop the last group: the covariance matrix of all K groups is singular,
    # so the statistic uses the first K-1 groups only.
    df = n_groups - 1
    zz = observed[:df] - expected[:df]
    chisq = np.linalg.solve(covar[:df, :df], zz).dot(zz)
    pval = stats.chi2.sf(chisq, df)

    if return_stats:
        table = OrderedDict()
        table["counts"] = group_counts
        table["observed"] = observed
        table["expected"] = expected
        table["statistic"] = observed - expected
        table = pd.DataFrame.from_dict(table)
        table.index = pd.Index(groups, name="group")
        return chisq, pval, table, covar

    return chisq, pval
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from .base import (
|
|
2
|
+
get_x_y, # noqa: F401
|
|
3
|
+
load_aids, # noqa: F401
|
|
4
|
+
load_arff_files_standardized, # noqa: F401
|
|
5
|
+
load_bmt, # noqa: F401
|
|
6
|
+
load_breast_cancer, # noqa: F401
|
|
7
|
+
load_cgvhd, # noqa: F401
|
|
8
|
+
load_flchain, # noqa: F401
|
|
9
|
+
load_gbsg2, # noqa: F401
|
|
10
|
+
load_veterans_lung_cancer, # noqa: F401
|
|
11
|
+
load_whas500, # noqa: F401
|
|
12
|
+
)
|