scikit-survival 0.26.0__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.26.0.dist-info/METADATA +185 -0
- scikit_survival-0.26.0.dist-info/RECORD +58 -0
- scikit_survival-0.26.0.dist-info/WHEEL +6 -0
- scikit_survival-0.26.0.dist-info/licenses/COPYING +674 -0
- scikit_survival-0.26.0.dist-info/top_level.txt +1 -0
- sksurv/__init__.py +183 -0
- sksurv/base.py +115 -0
- sksurv/bintrees/__init__.py +15 -0
- sksurv/bintrees/_binarytrees.cpython-311-x86_64-linux-gnu.so +0 -0
- sksurv/column.py +204 -0
- sksurv/compare.py +123 -0
- sksurv/datasets/__init__.py +12 -0
- sksurv/datasets/base.py +614 -0
- sksurv/datasets/data/GBSG2.arff +700 -0
- sksurv/datasets/data/actg320.arff +1169 -0
- sksurv/datasets/data/bmt.arff +46 -0
- sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
- sksurv/datasets/data/cgvhd.arff +118 -0
- sksurv/datasets/data/flchain.arff +7887 -0
- sksurv/datasets/data/veteran.arff +148 -0
- sksurv/datasets/data/whas500.arff +520 -0
- sksurv/docstrings.py +99 -0
- sksurv/ensemble/__init__.py +2 -0
- sksurv/ensemble/_coxph_loss.cpython-311-x86_64-linux-gnu.so +0 -0
- sksurv/ensemble/boosting.py +1564 -0
- sksurv/ensemble/forest.py +902 -0
- sksurv/ensemble/survival_loss.py +151 -0
- sksurv/exceptions.py +18 -0
- sksurv/functions.py +114 -0
- sksurv/io/__init__.py +2 -0
- sksurv/io/arffread.py +91 -0
- sksurv/io/arffwrite.py +181 -0
- sksurv/kernels/__init__.py +1 -0
- sksurv/kernels/_clinical_kernel.cpython-311-x86_64-linux-gnu.so +0 -0
- sksurv/kernels/clinical.py +348 -0
- sksurv/linear_model/__init__.py +3 -0
- sksurv/linear_model/_coxnet.cpython-311-x86_64-linux-gnu.so +0 -0
- sksurv/linear_model/aft.py +208 -0
- sksurv/linear_model/coxnet.py +592 -0
- sksurv/linear_model/coxph.py +637 -0
- sksurv/meta/__init__.py +4 -0
- sksurv/meta/base.py +35 -0
- sksurv/meta/ensemble_selection.py +724 -0
- sksurv/meta/stacking.py +370 -0
- sksurv/metrics.py +1028 -0
- sksurv/nonparametric.py +911 -0
- sksurv/preprocessing.py +195 -0
- sksurv/svm/__init__.py +11 -0
- sksurv/svm/_minlip.cpython-311-x86_64-linux-gnu.so +0 -0
- sksurv/svm/_prsvm.cpython-311-x86_64-linux-gnu.so +0 -0
- sksurv/svm/minlip.py +695 -0
- sksurv/svm/naive_survival_svm.py +249 -0
- sksurv/svm/survival_svm.py +1236 -0
- sksurv/testing.py +155 -0
- sksurv/tree/__init__.py +1 -0
- sksurv/tree/_criterion.cpython-311-x86_64-linux-gnu.so +0 -0
- sksurv/tree/tree.py +790 -0
- sksurv/util.py +416 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
from abc import ABCMeta
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from sklearn._loss.link import IdentityLink
|
|
17
|
+
from sklearn._loss.loss import BaseLoss
|
|
18
|
+
from sklearn.utils.extmath import squared_norm
|
|
19
|
+
|
|
20
|
+
from ..nonparametric import ipc_weights
|
|
21
|
+
from ._coxph_loss import coxph_loss, coxph_negative_gradient
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SurvivalLossFunction(BaseLoss, metaclass=ABCMeta):  # noqa: B024
    """Abstract base class for survival loss functions.

    Adapts scikit-learn's :class:`BaseLoss` interface for survival losses:
    no compiled loss is attached and the identity link is used.
    """

    def __init__(self, sample_weight=None):
        # ``sample_weight`` is accepted for API compatibility only; it is
        # not forwarded to the parent class.
        super().__init__(closs=None, link=IdentityLink())
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CoxPH(SurvivalLossFunction):
    """Cox Partial Likelihood loss.

    ``y_true`` is a structured array with fields ``"event"`` (boolean event
    indicator) and ``"time"`` (survival/censoring time); the numerical work
    is delegated to the compiled helpers ``coxph_loss`` and
    ``coxph_negative_gradient``.
    """

    # pylint: disable=no-self-use

    def __call__(self, y_true, raw_prediction, sample_weight=None):  # pylint: disable=unused-argument
        """Compute the partial likelihood loss of ``raw_prediction`` and ``y_true``.

        Parameters
        ----------
        y_true : structured array, shape = (n_samples,)
            Fields ``"event"`` (bool) and ``"time"``.
        raw_prediction : ndarray, shape = (n_samples,)
            The predictions; flattened before use.
        sample_weight : ignored
            Present for API compatibility only.
        """
        # TODO add support for sample weights
        return coxph_loss(y_true["event"].astype(np.uint8), y_true["time"], raw_prediction.ravel())

    def gradient(self, y_true, raw_prediction, sample_weight=None, **kwargs):  # pylint: disable=unused-argument
        """Negative gradient of the partial likelihood.

        Parameters
        ----------
        y_true : structured array, shape = (n_samples,)
            Fields ``"event"`` (bool event indicator) and ``"time"``
            (survival/censoring time).
        raw_prediction : ndarray, shape = (n_samples,)
            The predictions.
        sample_weight : ndarray or None
            If given, the gradient is scaled elementwise by it.
        """
        ret = coxph_negative_gradient(y_true["event"].astype(np.uint8), y_true["time"], raw_prediction.ravel())
        if sample_weight is not None:
            ret *= sample_weight
        return ret

    def update_terminal_regions(
        self, tree, X, y, residual, raw_predictions, sample_weight, sample_mask, learning_rate=0.1, k=0
    ):
        """Least squares does not need to update terminal regions.

        But it has to update the predictions.
        """
        # update predictions in place for output ``k``
        raw_predictions[:, k] += learning_rate * tree.predict(X).ravel()

    def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, residual, raw_predictions, sample_weight):
        """Least squares does not need to update terminal regions"""

    def _scale_raw_prediction(self, raw_predictions):
        # raw predictions are returned unchanged (identity transform)
        return raw_predictions
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class CensoredSquaredLoss(SurvivalLossFunction):
    """Censoring-aware squared loss.

    Censoring is taken into account by only considering the residuals
    of samples that are not censored, or the predicted survival time
    is before the time of censoring.
    """

    # pylint: disable=no-self-use
    def __call__(self, y_true, raw_prediction, sample_weight=None):
        """Compute the censoring-aware squared loss of ``raw_prediction`` and ``y_true``."""
        residual = y_true["time"] - raw_prediction.ravel()
        # keep samples with an observed event, or whose prediction
        # falls before the censoring time (positive residual)
        keep = (residual > 0) | y_true["event"]
        return 0.5 * squared_norm(residual[keep])

    def gradient(self, y_true, raw_prediction, **kwargs):  # pylint: disable=unused-argument
        """Negative gradient of the censoring-aware squared loss.

        Parameters
        ----------
        y_true : structured array, shape = (n_samples,)
            Fields ``"event"`` (bool event indicator) and ``"time"``
            (survival/censoring time).
        raw_prediction : ndarray, shape = (n_samples,)
            The predictions.
        """
        residual = y_true["time"] - raw_prediction.ravel()
        keep = (residual > 0) | y_true["event"]
        # contributions of masked-out (censored, over-predicted) samples are zero
        return np.where(keep, residual, 0.0)

    def update_terminal_regions(
        self, tree, X, y, residual, raw_predictions, sample_weight, sample_mask, learning_rate=0.1, k=0
    ):
        """Least squares does not need to update terminal regions.

        But it has to update the predictions.
        """
        # update predictions in place for output ``k``
        raw_predictions[:, k] += learning_rate * tree.predict(X).ravel()

    def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, residual, raw_predictions, sample_weight):
        """Least squares does not need to update terminal regions"""

    def _scale_raw_prediction(self, raw_predictions):
        # exponentiate in place and return the same array
        np.exp(raw_predictions, out=raw_predictions)
        return raw_predictions
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class IPCWLeastSquaresError(SurvivalLossFunction):
    """Inverse probability of censoring weighted least squares error"""

    # pylint: disable=no-self-use

    def __call__(self, y_true, raw_prediction, sample_weight=None):
        # the weights are always recomputed from the censoring
        # distribution; the ``sample_weight`` argument is ignored
        weights = ipc_weights(y_true["event"], y_true["time"])
        squared_residuals = (y_true["time"] - raw_prediction.ravel()) ** 2.0
        return 1.0 / weights.sum() * np.sum(weights * squared_residuals)

    def gradient(self, y_true, raw_prediction, **kwargs):  # pylint: disable=unused-argument
        """Return the residuals ``time - prediction`` as the negative gradient."""
        return y_true["time"] - raw_prediction.ravel()

    def update_terminal_regions(self, tree, X, y, residual, y_pred, sample_weight, sample_mask, learning_rate=0.1, k=0):
        """Add the learning-rate-scaled tree predictions for output ``k`` in place."""
        y_pred[:, k] += learning_rate * tree.predict(X).ravel()

    def _update_terminal_region(
        self, tree, terminal_regions, leaf, X, y, residual, pred, sample_weight
    ):  # pragma: no cover
        pass

    def _scale_raw_prediction(self, raw_predictions):
        # exponentiate in place and return the same array
        np.exp(raw_predictions, out=raw_predictions)
        return raw_predictions
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Registry mapping loss names to their implementation classes.
LOSS_FUNCTIONS = {
    "coxph": CoxPH,
    "squared": CensoredSquaredLoss,
    "ipcwls": IPCWLeastSquaresError,
}
|
sksurv/exceptions.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class NoComparablePairException(ValueError):
    """Raised when censored event-time data contains no comparable pair.

    Without at least one comparable pair of samples, rank-based
    quantities cannot be computed from the data.
    """
|
sksurv/functions.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
from sklearn.utils.validation import check_consistent_length
|
|
16
|
+
|
|
17
|
+
__all__ = ["StepFunction"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class StepFunction:
    r"""A callable step function.

    The function is defined by a set of points :math:`(x_i, y_i)` and is
    evaluated as:

    .. math::

        f(z) = a \cdot y_i + b \quad \text{if} \quad x_i \leq z < x_{i + 1}

    Parameters
    ----------
    x : ndarray, shape = (n_points,)
        The values on the x-axis, must be in ascending order.
    y : ndarray, shape = (n_points,)
        The corresponding values on the y-axis.
    a : float, optional, default: 1.0
        A constant factor to scale ``y`` by.
    b : float, optional, default: 0.0
        A constant offset term.
    domain : tuple, optional, default: (0, None)
        A tuple ``(lower, upper)`` that defines the domain of the step function.
        If ``lower`` or ``upper`` is ``None``, the first or last value of ``x`` is
        used as the limit, respectively.
    """

    def __init__(self, x, y, *, a=1.0, b=0.0, domain=(0, None)):
        check_consistent_length(x, y)
        self.x = x
        self.y = y
        self.a = a
        self.b = b
        domain_lower = self.x[0] if domain[0] is None else domain[0]
        domain_upper = self.x[-1] if domain[1] is None else domain[1]
        self._domain = (float(domain_lower), float(domain_upper))

    @property
    def domain(self):
        """The domain of the function.

        The domain is the range of values that the function accepts.

        Returns
        -------
        lower_limit : float
            Lower limit of the domain.

        upper_limit : float
            Upper limit of the domain.
        """
        return self._domain

    def __call__(self, x):
        """Evaluate the step function at given values.

        Parameters
        ----------
        x : float or array-like, shape=(n_values,)
            The values at which to evaluate the step function.
            Values must be within the function's ``domain``.

        Returns
        -------
        y : float or array-like, shape=(n_values,)
            The value of the step function at ``x``.

        Raises
        ------
        ValueError
            If ``x`` contains values outside the function's ``domain``.
        """
        x = np.atleast_1d(x)
        if not np.isfinite(x).all():
            raise ValueError("x must be finite")
        # use the cached tuple for both bounds
        # (previously mixed ``self._domain`` and the ``domain`` property)
        if np.min(x) < self._domain[0] or np.max(x) > self._domain[1]:
            raise ValueError(f"x must be within [{self.domain[0]:f}; {self.domain[1]:f}]")

        # x is within the domain, but we need to account for self.domain[0] <= x < self.x[0]
        x = np.clip(x, a_min=self.x[0], a_max=None)

        i = np.searchsorted(self.x, x, side="left")
        # searchsorted finds the insertion point; step back one interval
        # for values that are not exact support points
        not_exact = self.x[i] != x
        i[not_exact] -= 1
        value = self.a * self.y[i] + self.b
        if value.shape[0] == 1:
            return value[0]
        return value

    def __repr__(self):
        return f"StepFunction(x={self.x!r}, y={self.y!r}, a={self.a!r}, b={self.b!r})"

    def __eq__(self, other):
        if isinstance(other, type(self)):
            # np.array_equal is safe if x/y have different lengths, where a
            # plain elementwise comparison would raise a broadcast ValueError
            return (
                np.array_equal(self.x, other.x)
                and np.array_equal(self.y, other.y)
                and self.a == other.a
                and self.b == other.b
            )
        return False
|
sksurv/io/__init__.py
ADDED
sksurv/io/arffread.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
from pandas.api.types import is_string_dtype
|
|
16
|
+
from scipy.io.arff import loadarff as scipy_loadarff
|
|
17
|
+
|
|
18
|
+
__all__ = ["loadarff"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _to_pandas(data, meta):
    """Convert the ``(data, meta)`` pair of :func:`scipy.io.arff.loadarff`
    into a dict of columns and build a :class:`pandas.DataFrame` from it.

    Columns are ordered alphabetically by attribute name. Nominal
    attributes become categoricals; missing values (``b"?"``) become NaN.
    """
    columns = {}
    for name in sorted(meta.names()):
        attr_type, attr_format = meta[name]
        if attr_type == "nominal":
            # decode byte strings; '?' marks a missing value in ARFF
            decoded = [np.nan if value == b"?" else value.decode() for value in data[name]]
            columns[name] = pd.Categorical(decoded, categories=attr_format, ordered=False)
        else:
            values = data[name]
            series_dtype = "str" if is_string_dtype(values.dtype) else values.dtype
            columns[name] = pd.Series(values, dtype=series_dtype)

    # currently, this step converts all pandas.Categorical columns back to pandas.Series
    return pd.DataFrame.from_dict(columns)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def loadarff(filename):
    """Load an ARFF file into a :class:`pandas.DataFrame`.

    Parameters
    ----------
    filename : str or file-like
        Path to ARFF file, or file-like object to read from.

    Returns
    -------
    data_frame : :class:`pandas.DataFrame`
        DataFrame containing data of ARFF file

    See Also
    --------
    scipy.io.arff.loadarff : The underlying function that reads the ARFF file.

    Examples
    --------
    >>> from io import StringIO
    >>> from sksurv.io import loadarff
    >>>
    >>> # Create a dummy ARFF file
    >>> arff_content = '''
    ... @relation test_data
    ... @attribute feature1 numeric
    ... @attribute feature2 numeric
    ... @attribute class {A,B,C}
    ... @data
    ... 1.0,2.0,A
    ... 3.0,4.0,B
    ... 5.0,6.0,C
    ... '''
    >>>
    >>> # Load the ARFF file
    >>> with StringIO(arff_content) as f:
    ...     data = loadarff(f)
    >>>
    >>> print(data)
      class  feature1  feature2
    0     A       1.0       2.0
    1     B       3.0       4.0
    2     C       5.0       6.0
    """
    # parse with scipy, then convert the structured array to a DataFrame
    contents, metadata = scipy_loadarff(filename)
    return _to_pandas(contents, metadata)
|
sksurv/io/arffwrite.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import os.path
|
|
14
|
+
import re
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
from pandas.api.types import CategoricalDtype, is_string_dtype
|
|
19
|
+
|
|
20
|
+
_ILLEGAL_CHARACTER_PAT = re.compile(r"[^-_=\w\d\(\)<>\.]")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def writearff(data, filename, relation_name=None, index=True):
    """Write ARFF file

    Parameters
    ----------
    data : :class:`pandas.DataFrame`
        DataFrame containing data

    filename : str, path-like, or file-like object
        Path to ARFF file or file-like object. In the latter case,
        the handle is closed by calling this function.

    relation_name : str, optional, default: 'pandas'
        Name of relation in ARFF file.

    index : boolean, optional, default: True
        Write row names (index)

    See Also
    --------
    loadarff : Function to read ARFF files.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from sksurv.io import writearff
    >>>
    >>> # Create a dummy DataFrame
    >>> data = pd.DataFrame({
    ...     'feature1': [1.0, 3.0, 5.0],
    ...     'feature2': [2.0, np.nan, 6.0],
    ...     'class': ['A', 'B', 'C']
    ... }, index=['One', 'Two', 'Three'])
    >>>
    >>> # Write to ARFF file
    >>> writearff(data, 'test_output.arff', relation_name='test_data')
    >>>
    >>> # Read contents of ARFF file
    >>> with open('test_output.arff') as f:
    ...     arff_contents = "".join(f.readlines())
    >>> print(arff_contents)
    @relation test_data
    <BLANKLINE>
    @attribute index {One,Three,Two}
    @attribute feature1 real
    @attribute feature2 real
    @attribute class {A,B,C}
    <BLANKLINE>
    @data
    One,1.0,2.0,A
    Two,3.0,?,B
    Three,5.0,6.0,C
    """
    # accept pathlib.Path and other path-like objects in addition to str;
    # previously a Path fell into the file-like branch and crashed on write
    if isinstance(filename, (str, os.PathLike)):
        path = os.fspath(filename)
        fp = open(path, "w")

        if relation_name is None:
            relation_name = os.path.basename(path)
    else:
        fp = filename

        if relation_name is None:
            relation_name = "pandas"

    try:
        data = _write_header(data, fp, relation_name, index)
        fp.write("\n")
        _write_data(data, fp)
    finally:
        # the handle is always closed, including a caller-provided one
        fp.close()
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _write_header(data, fp, relation_name, index):
    """Write header containing attribute names and types.

    Returns the (possibly index-augmented) DataFrame so the caller can
    write the matching @data section.
    """
    fp.write(f"@relation {relation_name}\n\n")

    if index:
        # turn the index into a regular column so it is written as an attribute
        data = data.reset_index()

    attribute_names = _sanitize_column_names(data)

    for column, series in data.items():
        name = attribute_names[column]
        fp.write(f"@attribute {name}\t")

        # categorical/string columns become nominal attributes; the dtype
        # checks below map numeric and datetime dtypes to ARFF types
        if isinstance(series.dtype, CategoricalDtype) or is_string_dtype(series.dtype):
            _write_attribute_categorical(series, fp)
        elif np.issubdtype(series.dtype, np.floating):
            fp.write("real")
        elif np.issubdtype(series.dtype, np.integer):
            fp.write("integer")
        elif np.issubdtype(series.dtype, np.datetime64):
            fp.write("date 'yyyy-MM-dd HH:mm:ss'")
        else:
            raise TypeError(f"unsupported type {series.dtype}")

        fp.write("\n")
    return data
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _sanitize_column_names(data):
    """Map each column name to a copy with illegal characters replaced by underscores."""
    return {name: _ILLEGAL_CHARACTER_PAT.sub("_", name) for name in data.columns}
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _check_str_value(x):
    """If string has a space, wrap it in double quotes and remove/escape illegal characters"""
    if not isinstance(x, str):
        return str(x)

    # commas would break the comma-separated @data rows, and curly
    # quotation marks confuse loadarff, so replace both
    cleaned = x.replace(",", ".").replace(chr(0x2018), "'").replace(chr(0x2019), "'")

    if " " not in cleaned:
        return cleaned

    # drop a single pre-existing surrounding quote before re-quoting
    if cleaned.startswith(('"', "'")):
        cleaned = cleaned[1:]
    if cleaned.endswith(('"', "'")):
        cleaned = cleaned[:-1]
    escaped = cleaned.replace('"', '\\"')
    return f'"{escaped}"'
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
# Elementwise (vectorized) version of _check_str_value; returns an object array.
_check_str_array = np.frompyfunc(_check_str_value, 1, 1)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def _write_attribute_categorical(series, fp):
    """Write the ``{a,b,...}`` category list of a categorical/nominal attribute."""
    if isinstance(series.dtype, CategoricalDtype):
        # categorical dtype already defines the category order
        labels = _check_str_array(series.cat.categories)
    else:
        # plain strings: unique non-missing values, sorted while
        # ignoring any surrounding double quotes
        unique_values = series.dropna().unique()
        labels = sorted(_check_str_array(unique_values), key=lambda v: v.strip('"'))

    fp.write("{" + ",".join(labels) + "}")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _write_data(data, fp):
    """Write the @data section: one comma-separated line per row."""
    fp.write("@data\n")

    def to_str(x):
        # ARFF encodes missing values as '?'
        if pd.isna(x):
            return "?"
        return str(x)

    # DataFrame.applymap was deprecated in favor of DataFrame.map in
    # pandas 2.1; prefer the new name when it is available
    elementwise_map = getattr(data, "map", data.applymap)
    data = elementwise_map(to_str)
    n_rows = data.shape[0]
    for i in range(n_rows):
        str_values = list(data.iloc[i, :].apply(_check_str_array))
        line = ",".join(str_values)
        fp.write(line)
        fp.write("\n")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .clinical import ClinicalKernelTransform, clinical_kernel # noqa: F401
|
|
Binary file
|