scikit-survival 0.26.0__cp314-cp314-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. scikit_survival-0.26.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.26.0.dist-info/RECORD +58 -0
  3. scikit_survival-0.26.0.dist-info/WHEEL +6 -0
  4. scikit_survival-0.26.0.dist-info/licenses/COPYING +674 -0
  5. scikit_survival-0.26.0.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +183 -0
  7. sksurv/base.py +115 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cpython-314-darwin.so +0 -0
  10. sksurv/column.py +204 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +12 -0
  13. sksurv/datasets/base.py +614 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/bmt.arff +46 -0
  17. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  18. sksurv/datasets/data/cgvhd.arff +118 -0
  19. sksurv/datasets/data/flchain.arff +7887 -0
  20. sksurv/datasets/data/veteran.arff +148 -0
  21. sksurv/datasets/data/whas500.arff +520 -0
  22. sksurv/docstrings.py +99 -0
  23. sksurv/ensemble/__init__.py +2 -0
  24. sksurv/ensemble/_coxph_loss.cpython-314-darwin.so +0 -0
  25. sksurv/ensemble/boosting.py +1564 -0
  26. sksurv/ensemble/forest.py +902 -0
  27. sksurv/ensemble/survival_loss.py +151 -0
  28. sksurv/exceptions.py +18 -0
  29. sksurv/functions.py +114 -0
  30. sksurv/io/__init__.py +2 -0
  31. sksurv/io/arffread.py +91 -0
  32. sksurv/io/arffwrite.py +181 -0
  33. sksurv/kernels/__init__.py +1 -0
  34. sksurv/kernels/_clinical_kernel.cpython-314-darwin.so +0 -0
  35. sksurv/kernels/clinical.py +348 -0
  36. sksurv/linear_model/__init__.py +3 -0
  37. sksurv/linear_model/_coxnet.cpython-314-darwin.so +0 -0
  38. sksurv/linear_model/aft.py +208 -0
  39. sksurv/linear_model/coxnet.py +592 -0
  40. sksurv/linear_model/coxph.py +637 -0
  41. sksurv/meta/__init__.py +4 -0
  42. sksurv/meta/base.py +35 -0
  43. sksurv/meta/ensemble_selection.py +724 -0
  44. sksurv/meta/stacking.py +370 -0
  45. sksurv/metrics.py +1028 -0
  46. sksurv/nonparametric.py +911 -0
  47. sksurv/preprocessing.py +195 -0
  48. sksurv/svm/__init__.py +11 -0
  49. sksurv/svm/_minlip.cpython-314-darwin.so +0 -0
  50. sksurv/svm/_prsvm.cpython-314-darwin.so +0 -0
  51. sksurv/svm/minlip.py +695 -0
  52. sksurv/svm/naive_survival_svm.py +249 -0
  53. sksurv/svm/survival_svm.py +1236 -0
  54. sksurv/testing.py +155 -0
  55. sksurv/tree/__init__.py +1 -0
  56. sksurv/tree/_criterion.cpython-314-darwin.so +0 -0
  57. sksurv/tree/tree.py +790 -0
  58. sksurv/util.py +416 -0
@@ -0,0 +1,151 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ from abc import ABCMeta
14
+
15
+ import numpy as np
16
+ from sklearn._loss.link import IdentityLink
17
+ from sklearn._loss.loss import BaseLoss
18
+ from sklearn.utils.extmath import squared_norm
19
+
20
+ from ..nonparametric import ipc_weights
21
+ from ._coxph_loss import coxph_loss, coxph_negative_gradient
22
+
23
+
24
class SurvivalLossFunction(BaseLoss, metaclass=ABCMeta):  # noqa: B024
    """Base class for survival loss functions.

    Subclasses implement ``__call__`` and ``gradient`` directly instead of
    relying on sklearn's C-level loss implementation.
    """

    def __init__(self, sample_weight=None):
        # ``sample_weight`` is accepted only for API compatibility with
        # sklearn's BaseLoss subclasses; it is not stored or used here.
        # closs=None because the loss is computed in Python/Cython by the
        # subclasses; IdentityLink: raw predictions are used as-is.
        super().__init__(closs=None, link=IdentityLink())
29
+
30
+
31
class CoxPH(SurvivalLossFunction):
    """Cox Partial Likelihood loss for gradient boosting."""

    # pylint: disable=no-self-use

    def __call__(self, y_true, raw_prediction, sample_weight=None):  # pylint: disable=unused-argument
        """Compute the negative Cox partial likelihood.

        Parameters
        ----------
        y_true : structured array
            Record array with fields ``event`` (boolean event indicator)
            and ``time`` (survival/censoring time).
        raw_prediction : np.ndarray, shape=(n,)
            The raw predictions (linear predictor).

        Returns
        -------
        loss : float
        """
        # TODO add support for sample weights
        return coxph_loss(y_true["event"].astype(np.uint8), y_true["time"], raw_prediction.ravel())

    def gradient(self, y_true, raw_prediction, sample_weight=None, **kwargs):  # pylint: disable=unused-argument
        """Negative gradient of the partial likelihood.

        Parameters
        ----------
        y_true : structured array
            Record array with fields ``event`` (boolean event indicator)
            and ``time`` (survival/censoring time).
        raw_prediction : np.ndarray, shape=(n,)
            The raw predictions (linear predictor).
        sample_weight : np.ndarray, optional
            If provided, the gradient is scaled element-wise by the weights.

        Returns
        -------
        gradient : np.ndarray, shape=(n,)
        """
        ret = coxph_negative_gradient(y_true["event"].astype(np.uint8), y_true["time"], raw_prediction.ravel())
        if sample_weight is not None:
            ret *= sample_weight
        return ret

    def update_terminal_regions(
        self, tree, X, y, residual, raw_predictions, sample_weight, sample_mask, learning_rate=0.1, k=0
    ):
        """Add the scaled tree predictions to the raw predictions.

        Terminal regions themselves do not need updating for this loss,
        only the accumulated predictions of ensemble member ``k``.
        """
        # update predictions
        raw_predictions[:, k] += learning_rate * tree.predict(X).ravel()

    def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, residual, raw_predictions, sample_weight):
        """Terminal regions do not need to be updated for this loss."""

    def _scale_raw_prediction(self, raw_predictions):
        # Raw predictions are the linear predictor (log risk score);
        # no transformation is applied.
        return raw_predictions
71
+
72
+
73
class CensoredSquaredLoss(SurvivalLossFunction):
    """Censoring-aware squared loss.

    Censoring is taken into account by only considering the residuals
    of samples that are not censored, or the predicted survival time
    is before the time of censoring.
    """

    # pylint: disable=no-self-use
    def __call__(self, y_true, raw_prediction, sample_weight=None):
        """Compute the censoring-aware squared loss.

        Parameters
        ----------
        y_true : structured array
            Record array with fields ``event`` (boolean event indicator)
            and ``time`` (survival/censoring time).
        raw_prediction : np.ndarray, shape=(n,)
            The predictions.

        Returns
        -------
        loss : float
        """
        pred_time = y_true["time"] - raw_prediction.ravel()
        # A residual contributes if the event was observed, or the
        # prediction undershoots the censoring time (pred_time > 0).
        mask = (pred_time > 0) | y_true["event"]
        return 0.5 * squared_norm(pred_time.compress(mask, axis=0))

    def gradient(self, y_true, raw_prediction, **kwargs):  # pylint: disable=unused-argument
        """Negative gradient of the censoring-aware squared loss.

        Parameters
        ----------
        y_true : structured array
            Record array with fields ``event`` (boolean event indicator)
            and ``time`` (survival/censoring time).
        raw_prediction : np.ndarray, shape=(n,)
            The predictions.

        Returns
        -------
        gradient : np.ndarray, shape=(n,)
            Zero for censored samples whose prediction exceeds the
            censoring time; the residual for all other samples.
        """
        pred_time = y_true["time"] - raw_prediction.ravel()
        mask = (pred_time > 0) | y_true["event"]
        ret = np.zeros(y_true["event"].shape[0])
        ret[mask] = pred_time.compress(mask, axis=0)
        return ret

    def update_terminal_regions(
        self, tree, X, y, residual, raw_predictions, sample_weight, sample_mask, learning_rate=0.1, k=0
    ):
        """Add the scaled tree predictions to the raw predictions.

        Terminal regions themselves do not need updating for this loss,
        only the accumulated predictions of ensemble member ``k``.
        """
        # update predictions
        raw_predictions[:, k] += learning_rate * tree.predict(X).ravel()

    def _update_terminal_region(self, tree, terminal_regions, leaf, X, y, residual, raw_predictions, sample_weight):
        """Terminal regions do not need to be updated for this loss."""

    def _scale_raw_prediction(self, raw_predictions):
        # Exponentiate in place; presumably raw predictions are on a log
        # (time) scale and are mapped back to positive values — see callers.
        np.exp(raw_predictions, out=raw_predictions)
        return raw_predictions
120
+
121
+
122
class IPCWLeastSquaresError(SurvivalLossFunction):
    """Inverse probability of censoring weighted least squares error."""

    # pylint: disable=no-self-use

    def __call__(self, y_true, raw_prediction, sample_weight=None):
        """Weighted mean squared error of time predictions.

        NOTE(review): the ``sample_weight`` argument is ignored; weights are
        always recomputed from the censoring distribution via ``ipc_weights``.

        Parameters
        ----------
        y_true : structured array
            Record array with fields ``event`` (boolean event indicator)
            and ``time`` (survival/censoring time).
        raw_prediction : np.ndarray, shape=(n,)
            The predictions.

        Returns
        -------
        loss : float
        """
        sample_weight = ipc_weights(y_true["event"], y_true["time"])
        return 1.0 / sample_weight.sum() * np.sum(sample_weight * ((y_true["time"] - raw_prediction.ravel()) ** 2.0))

    def gradient(self, y_true, raw_prediction, **kwargs):  # pylint: disable=unused-argument
        """Negative gradient: the plain (unweighted) residuals."""
        return y_true["time"] - raw_prediction.ravel()

    def update_terminal_regions(self, tree, X, y, residual, y_pred, sample_weight, sample_mask, learning_rate=0.1, k=0):
        # Only the accumulated predictions of ensemble member ``k`` are
        # updated; terminal regions themselves need no update.
        y_pred[:, k] += learning_rate * tree.predict(X).ravel()

    def _update_terminal_region(
        self, tree, terminal_regions, leaf, X, y, residual, pred, sample_weight
    ):  # pragma: no cover
        # Terminal regions do not need to be updated for this loss.
        pass

    def _scale_raw_prediction(self, raw_predictions):
        # Exponentiate in place; presumably raw predictions are on a log
        # (time) scale and are mapped back to positive values — see callers.
        np.exp(raw_predictions, out=raw_predictions)
        return raw_predictions
145
+
146
+
147
# Mapping from the loss name (as accepted by the gradient boosting
# estimators' ``loss`` parameter) to the implementing class.
LOSS_FUNCTIONS = {
    "coxph": CoxPH,
    "squared": CensoredSquaredLoss,
    "ipcwls": IPCWLeastSquaresError,
}
sksurv/exceptions.py ADDED
@@ -0,0 +1,18 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+
14
+
15
class NoComparablePairException(ValueError):
    """Error raised when censored event-time data lacks comparable pairs.

    Indicates that the provided data of censored event times does not
    contain one or more pairs of samples that can be compared.
    """
sksurv/functions.py ADDED
@@ -0,0 +1,114 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+
14
+ import numpy as np
15
+ from sklearn.utils.validation import check_consistent_length
16
+
17
+ __all__ = ["StepFunction"]
18
+
19
+
20
class StepFunction:
    r"""A callable step function.

    The function is defined by a set of points :math:`(x_i, y_i)` and is
    evaluated as:

    .. math::

        f(z) = a \cdot y_i + b \quad \text{if} \quad x_i \leq z < x_{i + 1}

    Parameters
    ----------
    x : ndarray, shape = (n_points,)
        The values on the x-axis, must be in ascending order.
    y : ndarray, shape = (n_points,)
        The corresponding values on the y-axis.
    a : float, optional, default: 1.0
        A constant factor to scale ``y`` by.
    b : float, optional, default: 0.0
        A constant offset term.
    domain : tuple, optional, default: (0, None)
        A tuple ``(lower, upper)`` that defines the domain of the step function.
        If ``lower`` or ``upper`` is ``None``, the first or last value of ``x`` is
        used as the limit, respectively.
    """

    def __init__(self, x, y, *, a=1.0, b=0.0, domain=(0, None)):
        check_consistent_length(x, y)
        self.x = x
        self.y = y
        self.a = a
        self.b = b
        # None means "use the first/last point of x" as the respective limit
        domain_lower = self.x[0] if domain[0] is None else domain[0]
        domain_upper = self.x[-1] if domain[1] is None else domain[1]
        self._domain = (float(domain_lower), float(domain_upper))

    @property
    def domain(self):
        """The domain of the function.

        The domain is the range of values that the function accepts.

        Returns
        -------
        lower_limit : float
            Lower limit of the domain.

        upper_limit : float
            Upper limit of the domain.
        """
        return self._domain

    def __call__(self, x):
        """Evaluate the step function at given values.

        Parameters
        ----------
        x : float or array-like, shape=(n_values,)
            The values at which to evaluate the step function.
            Values must be within the function's ``domain``.

        Returns
        -------
        y : float or array-like, shape=(n_values,)
            The value of the step function at ``x``.

        Raises
        ------
        ValueError
            If ``x`` contains values outside the function's ``domain``.
        """
        x = np.atleast_1d(x)
        if not np.isfinite(x).all():
            raise ValueError("x must be finite")
        if np.min(x) < self._domain[0] or np.max(x) > self.domain[1]:
            raise ValueError(f"x must be within [{self.domain[0]:f}; {self.domain[1]:f}]")

        # x is within the domain, but we need to account for self.domain[0] <= x < self.x[0]
        x = np.clip(x, a_min=self.x[0], a_max=None)

        # find the step each query falls on: searchsorted gives the insertion
        # point; step back one index where x is strictly between two points
        i = np.searchsorted(self.x, x, side="left")
        not_exact = self.x[i] != x
        i[not_exact] -= 1
        value = self.a * self.y[i] + self.b
        # scalar in -> scalar out (single-element arrays are also unwrapped)
        if value.shape[0] == 1:
            return value[0]
        return value

    def __repr__(self):
        return f"StepFunction(x={self.x!r}, y={self.y!r}, a={self.a!r}, b={self.b!r})"

    def __eq__(self, other):
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable (Python sets __hash__ to None) — confirm this is intended.
        if isinstance(other, type(self)):
            return all(self.x == other.x) and all(self.y == other.y) and self.a == other.a and self.b == other.b
        return False
sksurv/io/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .arffread import loadarff # noqa: F401
2
+ from .arffwrite import writearff # noqa: F401
sksurv/io/arffread.py ADDED
@@ -0,0 +1,91 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ import numpy as np
14
+ import pandas as pd
15
+ from pandas.api.types import is_string_dtype
16
+ from scipy.io.arff import loadarff as scipy_loadarff
17
+
18
+ __all__ = ["loadarff"]
19
+
20
+
21
+ def _to_pandas(data, meta):
22
+ data_dict = {}
23
+ attrnames = sorted(meta.names())
24
+ for name in attrnames:
25
+ tp, attr_format = meta[name]
26
+ if tp == "nominal":
27
+ raw = []
28
+ for b in data[name]:
29
+ # replace missing values with NaN
30
+ if b == b"?":
31
+ raw.append(np.nan)
32
+ else:
33
+ raw.append(b.decode())
34
+
35
+ data_dict[name] = pd.Categorical(raw, categories=attr_format, ordered=False)
36
+ else:
37
+ arr = data[name]
38
+ dtype = "str" if is_string_dtype(arr.dtype) else arr.dtype
39
+ p = pd.Series(arr, dtype=dtype)
40
+ data_dict[name] = p
41
+
42
+ # currently, this step converts all pandas.Categorial columns back to pandas.Series
43
+ return pd.DataFrame.from_dict(data_dict)
44
+
45
+
46
def loadarff(filename):
    """Load ARFF file.

    Parameters
    ----------
    filename : str or file-like
        Path to ARFF file, or file-like object to read from.

    Returns
    -------
    data_frame : :class:`pandas.DataFrame`
        DataFrame containing the data of the ARFF file.

    See Also
    --------
    scipy.io.arff.loadarff : The underlying function that reads the ARFF file.

    Examples
    --------
    >>> from io import StringIO
    >>> from sksurv.io import loadarff
    >>>
    >>> # Create a dummy ARFF file
    >>> arff_content = '''
    ... @relation test_data
    ... @attribute feature1 numeric
    ... @attribute feature2 numeric
    ... @attribute class {A,B,C}
    ... @data
    ... 1.0,2.0,A
    ... 3.0,4.0,B
    ... 5.0,6.0,C
    ... '''
    >>>
    >>> # Load the ARFF file
    >>> with StringIO(arff_content) as f:
    ...     data = loadarff(f)
    >>>
    >>> print(data)
      class  feature1  feature2
    0     A       1.0       2.0
    1     B       3.0       4.0
    2     C       5.0       6.0
    """
    # delegate parsing to scipy, then reshape the result into a DataFrame
    raw_data, metadata = scipy_loadarff(filename)
    return _to_pandas(raw_data, metadata)
sksurv/io/arffwrite.py ADDED
@@ -0,0 +1,181 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ import os.path
14
+ import re
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ from pandas.api.types import CategoricalDtype, is_string_dtype
19
+
20
+ _ILLEGAL_CHARACTER_PAT = re.compile(r"[^-_=\w\d\(\)<>\.]")
21
+
22
+
23
def writearff(data, filename, relation_name=None, index=True):
    """Write ARFF file

    Parameters
    ----------
    data : :class:`pandas.DataFrame`
        DataFrame containing data

    filename : str, path-like, or file-like object
        Path to ARFF file or file-like object. In the latter case,
        the handle is closed by calling this function.

    relation_name : str, optional, default: 'pandas'
        Name of relation in ARFF file.

    index : boolean, optional, default: True
        Write row names (index)

    See Also
    --------
    loadarff : Function to read ARFF files.

    Examples
    --------
    >>> import numpy as np
    >>> import pandas as pd
    >>> from sksurv.io import writearff
    >>>
    >>> # Create a dummy DataFrame
    >>> data = pd.DataFrame({
    ...     'feature1': [1.0, 3.0, 5.0],
    ...     'feature2': [2.0, np.nan, 6.0],
    ...     'class': ['A', 'B', 'C']
    ... }, index=['One', 'Two', 'Three'])
    >>>
    >>> # Write to ARFF file
    >>> writearff(data, 'test_output.arff', relation_name='test_data')
    >>>
    >>> # Read contents of ARFF file
    >>> with open('test_output.arff') as f:
    ...     arff_contents = "".join(f.readlines())
    >>> print(arff_contents)
    @relation test_data
    <BLANKLINE>
    @attribute index {One,Three,Two}
    @attribute feature1 real
    @attribute feature2 real
    @attribute class {A,B,C}
    <BLANKLINE>
    @data
    One,1.0,2.0,A
    Two,3.0,?,B
    Three,5.0,6.0,C
    """
    # Accept os.PathLike (e.g. pathlib.Path) in addition to plain strings;
    # previously such objects fell through to the file-like branch and
    # crashed with an AttributeError.
    if isinstance(filename, (str, os.PathLike)):
        path = os.fspath(filename)
        fp = open(path, "w")

        if relation_name is None:
            relation_name = os.path.basename(path)
    else:
        fp = filename

        if relation_name is None:
            relation_name = "pandas"

    try:
        # header may add the index as a column; use the returned frame
        data = _write_header(data, fp, relation_name, index)
        fp.write("\n")
        _write_data(data, fp)
    finally:
        # always close, even for caller-provided handles (documented behavior)
        fp.close()
94
+
95
+
96
def _write_header(data, fp, relation_name, index):
    """Write the ARFF header: the @relation line and one @attribute per column.

    Returns the DataFrame the data section should be written from, i.e.
    with the index turned into a regular column when ``index`` is True.
    """
    fp.write(f"@relation {relation_name}\n\n")

    if index:
        data = data.reset_index()

    attribute_names = _sanitize_column_names(data)

    for column, series in data.items():
        fp.write(f"@attribute {attribute_names[column]}\t")

        dtype = series.dtype
        if isinstance(dtype, CategoricalDtype) or is_string_dtype(dtype):
            _write_attribute_categorical(series, fp)
        elif np.issubdtype(dtype, np.floating):
            fp.write("real")
        elif np.issubdtype(dtype, np.integer):
            fp.write("integer")
        elif np.issubdtype(dtype, np.datetime64):
            fp.write("date 'yyyy-MM-dd HH:mm:ss'")
        else:
            raise TypeError(f"unsupported type {dtype}")

        fp.write("\n")
    return data
122
+
123
+
124
def _sanitize_column_names(data):
    """Return a mapping from each column name to a copy with every
    character that is illegal in an ARFF attribute name replaced by '_'."""
    return {name: _ILLEGAL_CHARACTER_PAT.sub("_", name) for name in data.columns}
130
+
131
+
132
+ def _check_str_value(x):
133
+ """If string has a space, wrap it in double quotes and remove/escape illegal characters"""
134
+ if isinstance(x, str):
135
+ # remove commas, and single quotation marks since loadarff cannot deal with it
136
+ x = x.replace(",", ".").replace(chr(0x2018), "'").replace(chr(0x2019), "'")
137
+
138
+ # put string in double quotes
139
+ if " " in x:
140
+ if x[0] in ('"', "'"):
141
+ x = x[1:]
142
+ if x[-1] in ('"', "'"):
143
+ x = x[: len(x) - 1]
144
+ x = '"' + x.replace('"', '\\"') + '"'
145
+ return str(x)
146
+
147
+
148
# Element-wise (vectorized) form of _check_str_value for arrays/Series.
_check_str_array = np.frompyfunc(_check_str_value, 1, 1)
149
+
150
+
151
def _write_attribute_categorical(series, fp):
    """Write the value set of a categorical/nominal attribute, e.g. ``{A,B,C}``."""
    if isinstance(series.dtype, CategoricalDtype):
        # categorical dtype: keep the declared category order
        string_values = _check_str_array(series.cat.categories)
    else:
        # plain strings: unique non-missing values, sorted ignoring quoting
        unique_values = series.dropna().unique()
        string_values = sorted(_check_str_array(unique_values), key=lambda s: s.strip('"'))

    fp.write("{")
    fp.write(",".join(string_values))
    fp.write("}")
164
+
165
+
166
def _write_data(data, fp):
    """Write the @data section: one comma-separated line per row,
    with missing values encoded as '?'."""
    fp.write("@data\n")

    def to_str(x):
        # ARFF encodes missing values as '?'
        if pd.isna(x):
            return "?"
        return str(x)

    # DataFrame.applymap was deprecated in pandas 2.1 (renamed to
    # DataFrame.map); prefer the new name, fall back for older pandas.
    elementwise_map = getattr(data, "map", None)
    if elementwise_map is None:
        elementwise_map = data.applymap
    data = elementwise_map(to_str)

    n_rows = data.shape[0]
    for i in range(n_rows):
        str_values = list(data.iloc[i, :].apply(_check_str_array))
        line = ",".join(str_values)
        fp.write(line)
        fp.write("\n")
@@ -0,0 +1 @@
1
+ from .clinical import ClinicalKernelTransform, clinical_kernel # noqa: F401