scikit-survival 0.23.1__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.23.1.dist-info/COPYING +674 -0
- scikit_survival-0.23.1.dist-info/METADATA +888 -0
- scikit_survival-0.23.1.dist-info/RECORD +55 -0
- scikit_survival-0.23.1.dist-info/WHEEL +5 -0
- scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
- sksurv/__init__.py +138 -0
- sksurv/base.py +103 -0
- sksurv/bintrees/__init__.py +15 -0
- sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
- sksurv/column.py +201 -0
- sksurv/compare.py +123 -0
- sksurv/datasets/__init__.py +10 -0
- sksurv/datasets/base.py +436 -0
- sksurv/datasets/data/GBSG2.arff +700 -0
- sksurv/datasets/data/actg320.arff +1169 -0
- sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
- sksurv/datasets/data/flchain.arff +7887 -0
- sksurv/datasets/data/veteran.arff +148 -0
- sksurv/datasets/data/whas500.arff +520 -0
- sksurv/ensemble/__init__.py +2 -0
- sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
- sksurv/ensemble/boosting.py +1610 -0
- sksurv/ensemble/forest.py +947 -0
- sksurv/ensemble/survival_loss.py +151 -0
- sksurv/exceptions.py +18 -0
- sksurv/functions.py +114 -0
- sksurv/io/__init__.py +2 -0
- sksurv/io/arffread.py +58 -0
- sksurv/io/arffwrite.py +145 -0
- sksurv/kernels/__init__.py +1 -0
- sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
- sksurv/kernels/clinical.py +328 -0
- sksurv/linear_model/__init__.py +3 -0
- sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
- sksurv/linear_model/aft.py +205 -0
- sksurv/linear_model/coxnet.py +543 -0
- sksurv/linear_model/coxph.py +618 -0
- sksurv/meta/__init__.py +4 -0
- sksurv/meta/base.py +35 -0
- sksurv/meta/ensemble_selection.py +642 -0
- sksurv/meta/stacking.py +349 -0
- sksurv/metrics.py +996 -0
- sksurv/nonparametric.py +588 -0
- sksurv/preprocessing.py +155 -0
- sksurv/svm/__init__.py +11 -0
- sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
- sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
- sksurv/svm/minlip.py +606 -0
- sksurv/svm/naive_survival_svm.py +221 -0
- sksurv/svm/survival_svm.py +1228 -0
- sksurv/testing.py +108 -0
- sksurv/tree/__init__.py +1 -0
- sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
- sksurv/tree/tree.py +703 -0
- sksurv/util.py +333 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
from abc import ABCMeta
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from sklearn._loss.link import IdentityLink
|
|
17
|
+
from sklearn._loss.loss import BaseLoss
|
|
18
|
+
from sklearn.utils.extmath import squared_norm
|
|
19
|
+
|
|
20
|
+
from ..nonparametric import ipc_weights
|
|
21
|
+
from ._coxph_loss import coxph_loss, coxph_negative_gradient
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SurvivalLossFunction(BaseLoss, metaclass=ABCMeta):  # noqa: B024
    """Base class for survival loss functions."""

    def __init__(self, sample_weight=None):
        # NOTE(review): ``sample_weight`` is accepted for API compatibility but
        # is not used here; weighting is applied by subclasses (e.g. in
        # ``CoxPH.gradient``).
        # ``closs=None`` because subclasses implement ``__call__`` themselves
        # instead of delegating to a compiled sklearn loss object.
        super().__init__(closs=None, link=IdentityLink())
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CoxPH(SurvivalLossFunction):
    """Cox Partial Likelihood loss."""

    # pylint: disable=no-self-use

    def __call__(self, y_true, raw_prediction, sample_weight=None):  # pylint: disable=unused-argument
        """Compute the negative Cox partial likelihood of ``raw_prediction`` given ``y_true``."""
        # TODO add support for sample weights
        event = y_true["event"].astype(np.uint8)
        time = y_true["time"]
        return coxph_loss(event, time, raw_prediction.ravel())

    def gradient(self, y_true, raw_prediction, sample_weight=None, **kwargs):  # pylint: disable=unused-argument
        """Negative gradient of the partial likelihood.

        Parameters
        ----------
        y_true : structured array
            Fields ``event`` (boolean event indicator) and ``time``
            (survival/censoring time).
        raw_prediction : np.ndarray, shape=(n,)
            The predictions.
        """
        event = y_true["event"].astype(np.uint8)
        grad = coxph_negative_gradient(event, y_true["time"], raw_prediction.ravel())
        if sample_weight is not None:
            grad *= sample_weight
        return grad

    def update_terminal_regions(
        self, tree, X, y, residual, raw_predictions, sample_weight, sample_mask, learning_rate=0.1, k=0
    ):
        """Terminal regions need no update, but the predictions must be refreshed."""
        # add the new tree's contribution to the running predictions
        raw_predictions[:, k] += learning_rate * tree.predict(X).ravel()

    def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, residual, raw_predictions, sample_weight):
        """No-op: least squares does not need to update terminal regions."""

    def _scale_raw_prediction(self, raw_predictions):
        # Cox predictions stay on the linear (log-hazard) scale; no transform.
        return raw_predictions
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class CensoredSquaredLoss(SurvivalLossFunction):
    """Censoring-aware squared loss.

    A residual contributes to the loss only if the sample experienced an
    event, or if the predicted survival time falls before the time of
    censoring.
    """

    # pylint: disable=no-self-use
    def __call__(self, y_true, raw_prediction, sample_weight=None):
        """Compute the censoring-aware squared loss of ``raw_prediction`` given ``y_true``."""
        residual = y_true["time"] - raw_prediction.ravel()
        # keep residuals of uncensored samples, and of censored samples
        # whose prediction falls before the censoring time
        considered = (residual > 0) | y_true["event"]
        return 0.5 * squared_norm(residual.compress(considered, axis=0))

    def gradient(self, y_true, raw_prediction, **kwargs):  # pylint: disable=unused-argument
        """Negative gradient of the loss.

        Parameters
        ----------
        y_true : structured array
            Fields ``event`` (boolean event indicator) and ``time``
            (survival/censoring time).
        raw_prediction : np.ndarray, shape=(n,)
            The predictions.
        """
        residual = y_true["time"] - raw_prediction.ravel()
        considered = (residual > 0) | y_true["event"]
        grad = np.zeros(y_true["event"].shape[0])
        grad[considered] = residual.compress(considered, axis=0)
        return grad

    def update_terminal_regions(
        self, tree, X, y, residual, raw_predictions, sample_weight, sample_mask, learning_rate=0.1, k=0
    ):
        """Terminal regions need no update, but the predictions must be refreshed."""
        # add the new tree's contribution to the running predictions
        raw_predictions[:, k] += learning_rate * tree.predict(X).ravel()

    def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, residual, raw_predictions, sample_weight):
        """No-op: least squares does not need to update terminal regions."""

    def _scale_raw_prediction(self, raw_predictions):
        # predictions are log survival times; map back to the time scale in place
        np.exp(raw_predictions, out=raw_predictions)
        return raw_predictions
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class IPCWLeastSquaresError(SurvivalLossFunction):
    """Inverse probability of censoring weighted least squares error"""

    # pylint: disable=no-self-use

    def __call__(self, y_true, raw_prediction, sample_weight=None):
        """Weighted mean squared error of predicted and observed time.

        NOTE(review): the ``sample_weight`` argument is ignored; weights are
        always recomputed from the censoring distribution via ``ipc_weights``.
        """
        weights = ipc_weights(y_true["event"], y_true["time"])
        squared_error = (y_true["time"] - raw_prediction.ravel()) ** 2.0
        return 1.0 / weights.sum() * np.sum(weights * squared_error)

    def gradient(self, y_true, raw_prediction, **kwargs):  # pylint: disable=unused-argument
        """Residual of observed time and prediction (negative gradient)."""
        return y_true["time"] - raw_prediction.ravel()

    def update_terminal_regions(self, tree, X, y, residual, y_pred, sample_weight, sample_mask, learning_rate=0.1, k=0):
        """Refresh predictions with the new tree's contribution."""
        y_pred[:, k] += learning_rate * tree.predict(X).ravel()

    def _update_terminal_region(
        self, tree, terminal_regions, leaf, X, y, residual, pred, sample_weight
    ):  # pragma: no cover
        pass

    def _scale_raw_prediction(self, raw_predictions):
        # predictions are on the log scale; transform to the time scale in place
        np.exp(raw_predictions, out=raw_predictions)
        return raw_predictions
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Mapping from the ``loss`` string accepted by the gradient boosting models
# to the corresponding loss-function class defined above.
LOSS_FUNCTIONS = {
    "coxph": CoxPH,
    "squared": CensoredSquaredLoss,
    "ipcwls": IPCWLeastSquaresError,
}
|
sksurv/exceptions.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class NoComparablePairException(ValueError):
    """Raised when data of censored event times does not
    contain one or more comparable pairs.
    """
|
sksurv/functions.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
from sklearn.utils import check_consistent_length
|
|
16
|
+
|
|
17
|
+
__all__ = ["StepFunction"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class StepFunction:
    """Callable step function.

    .. math::

        f(z) = a * y_i + b,
        x_i \\leq z < x_{i + 1}

    Parameters
    ----------
    x : ndarray, shape = (n_points,)
        Values on the x axis in ascending order.

    y : ndarray, shape = (n_points,)
        Corresponding values on the y axis.

    a : float, optional, default: 1.0
        Constant to multiply by.

    b : float, optional, default: 0.0
        Constant offset term.

    domain : tuple, optional
        Two entries that set the limits of the step function's domain.
        An entry of `None` means the first/last value of `x` is used as limit.
    """

    def __init__(self, x, y, *, a=1.0, b=0.0, domain=(0, None)):
        check_consistent_length(x, y)
        self.x = x
        self.y = y
        self.a = a
        self.b = b
        lower, upper = domain
        if lower is None:
            lower = x[0]
        if upper is None:
            upper = x[-1]
        self._domain = (float(lower), float(upper))

    @property
    def domain(self):
        """The range of values that the function accepts.

        Returns
        -------
        lower_limit : float
            Lower limit of domain.

        upper_limit : float
            Upper limit of domain.
        """
        return self._domain

    def __call__(self, x):
        """Evaluate step function.

        Values outside the interval specified by `self.domain`
        will raise an exception.
        Values in `x` that are in the interval `[self.domain[0]; self.x[0]]`
        get mapped to `self.y[0]`.

        Parameters
        ----------
        x : float|array-like, shape=(n_values,)
            Values to evaluate step function at.

        Returns
        -------
        y : float|array-like, shape=(n_values,)
            Values of step function at `x`.
        """
        points = np.atleast_1d(x)
        if not np.isfinite(points).all():
            raise ValueError("x must be finite")
        if np.min(points) < self._domain[0] or np.max(points) > self.domain[1]:
            raise ValueError(f"x must be within [{self.domain[0]:f}; {self.domain[1]:f}]")

        # values in [domain[0]; x[0]) are mapped onto the first step
        points = np.clip(points, a_min=self.x[0], a_max=None)

        idx = np.searchsorted(self.x, points, side="left")
        # for points between two steps, move back to the step on the left
        idx[self.x[idx] != points] -= 1
        result = self.a * self.y[idx] + self.b
        return result[0] if result.shape[0] == 1 else result

    def __repr__(self):
        return f"StepFunction(x={self.x!r}, y={self.y!r}, a={self.a!r}, b={self.b!r})"

    def __eq__(self, other):
        if not isinstance(other, type(self)):
            return False
        return (
            all(self.x == other.x)
            and all(self.y == other.y)
            and self.a == other.a
            and self.b == other.b
        )
|
sksurv/io/__init__.py
ADDED
sksurv/io/arffread.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
from scipy.io.arff import loadarff as scipy_loadarff
|
|
16
|
+
|
|
17
|
+
__all__ = ["loadarff"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _to_pandas(data, meta):
|
|
21
|
+
data_dict = {}
|
|
22
|
+
attrnames = sorted(meta.names())
|
|
23
|
+
for name in attrnames:
|
|
24
|
+
tp, attr_format = meta[name]
|
|
25
|
+
if tp == "nominal":
|
|
26
|
+
raw = []
|
|
27
|
+
for b in data[name]:
|
|
28
|
+
# replace missing values with NaN
|
|
29
|
+
if b == b"?":
|
|
30
|
+
raw.append(np.nan)
|
|
31
|
+
else:
|
|
32
|
+
raw.append(b.decode())
|
|
33
|
+
|
|
34
|
+
data_dict[name] = pd.Categorical(raw, categories=attr_format, ordered=False)
|
|
35
|
+
else:
|
|
36
|
+
arr = data[name]
|
|
37
|
+
p = pd.Series(arr, dtype=arr.dtype)
|
|
38
|
+
data_dict[name] = p
|
|
39
|
+
|
|
40
|
+
# currently, this step converts all pandas.Categorial columns back to pandas.Series
|
|
41
|
+
return pd.DataFrame.from_dict(data_dict)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def loadarff(filename):
    """Load ARFF file

    Parameters
    ----------
    filename : string
        Path to ARFF file

    Returns
    -------
    data_frame : :class:`pandas.DataFrame`
        DataFrame containing data of ARFF file
    """
    contents, metadata = scipy_loadarff(filename)
    return _to_pandas(contents, metadata)
|
sksurv/io/arffwrite.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import os.path
|
|
14
|
+
import re
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
from pandas.api.types import CategoricalDtype, is_object_dtype
|
|
19
|
+
|
|
20
|
+
_ILLEGAL_CHARACTER_PAT = re.compile(r"[^-_=\w\d\(\)<>\.]")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def writearff(data, filename, relation_name=None, index=True):
    """Write ARFF file

    Parameters
    ----------
    data : :class:`pandas.DataFrame`
        DataFrame containing data

    filename : string, path-like, or file-like object
        Path to ARFF file or file-like object. In the latter case,
        the handle is closed by calling this function.

    relation_name : string, optional, default: "pandas"
        Name of relation in ARFF file.
        If `None` and `filename` is a path, the file's base name is used.

    index : boolean, optional, default: True
        Write row names (index)
    """
    # accept pathlib.Path (and any os.PathLike) in addition to plain strings
    if isinstance(filename, (str, os.PathLike)):
        fp = open(filename, "w")

        if relation_name is None:
            relation_name = os.path.basename(filename)
    else:
        fp = filename

        if relation_name is None:
            relation_name = "pandas"

    try:
        data = _write_header(data, fp, relation_name, index)
        fp.write("\n")
        _write_data(data, fp)
    finally:
        # also closes a caller-supplied file-like object (documented behavior)
        fp.close()
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _write_header(data, fp, relation_name, index):
    """Write header containing attribute names and types.

    Returns the (possibly index-reset) DataFrame whose column order
    matches the emitted ``@attribute`` lines.
    """
    fp.write(f"@relation {relation_name}\n\n")

    if index:
        # turn the index into regular columns so it is written as data
        data = data.reset_index()

    attribute_names = _sanitize_column_names(data)

    for column, series in data.items():
        fp.write(f"@attribute {attribute_names[column]}\t")

        dtype = series.dtype
        if isinstance(dtype, CategoricalDtype) or is_object_dtype(series):
            _write_attribute_categorical(series, fp)
        elif np.issubdtype(dtype, np.floating):
            fp.write("real")
        elif np.issubdtype(dtype, np.integer):
            fp.write("integer")
        elif np.issubdtype(dtype, np.datetime64):
            fp.write("date 'yyyy-MM-dd HH:mm:ss'")
        else:
            raise TypeError(f"unsupported type {dtype}")

        fp.write("\n")
    return data
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _sanitize_column_names(data):
    """Replace illegal characters with underscore"""
    return {name: _ILLEGAL_CHARACTER_PAT.sub("_", name) for name in data.columns}
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _check_str_value(x):
|
|
97
|
+
"""If string has a space, wrap it in double quotes and remove/escape illegal characters"""
|
|
98
|
+
if isinstance(x, str):
|
|
99
|
+
# remove commas, and single quotation marks since loadarff cannot deal with it
|
|
100
|
+
x = x.replace(",", ".").replace(chr(0x2018), "'").replace(chr(0x2019), "'")
|
|
101
|
+
|
|
102
|
+
# put string in double quotes
|
|
103
|
+
if " " in x:
|
|
104
|
+
if x[0] in ('"', "'"):
|
|
105
|
+
x = x[1:]
|
|
106
|
+
if x[-1] in ('"', "'"):
|
|
107
|
+
x = x[: len(x) - 1]
|
|
108
|
+
x = '"' + x.replace('"', '\\"') + '"'
|
|
109
|
+
return str(x)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# Vectorized form of _check_str_value: applies it element-wise to arrays,
# yielding an object-dtype array of sanitized strings.
_check_str_array = np.frompyfunc(_check_str_value, 1, 1)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _write_attribute_categorical(series, fp):
    """Write categories of a categorical/nominal attribute"""
    if isinstance(series.dtype, CategoricalDtype):
        # categorical dtype: use the declared category order
        string_values = _check_str_array(series.cat.categories)
    else:
        # object dtype: collect unique non-missing values, sorted
        # without regard to surrounding quotes
        unique_values = series.dropna().unique()
        string_values = sorted(_check_str_array(unique_values), key=lambda s: s.strip('"'))

    fp.write("{")
    fp.write(",".join(string_values))
    fp.write("}")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def _write_data(data, fp):
    """Write the data section"""
    fp.write("@data\n")

    def to_str(x):
        # ARFF encodes missing values as "?"
        if pd.isnull(x):
            return "?"
        return str(x)

    data = data.applymap(to_str)
    for _, row in data.iterrows():
        fp.write(",".join(row.apply(_check_str_array)))
        fp.write("\n")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .clinical import ClinicalKernelTransform, clinical_kernel # noqa: F401
|
|
Binary file
|