scikit-survival 0.25.0 (cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. scikit_survival-0.25.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.25.0.dist-info/RECORD +58 -0
  3. scikit_survival-0.25.0.dist-info/WHEEL +6 -0
  4. scikit_survival-0.25.0.dist-info/licenses/COPYING +674 -0
  5. scikit_survival-0.25.0.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +183 -0
  7. sksurv/base.py +115 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cpython-313-x86_64-linux-gnu.so +0 -0
  10. sksurv/column.py +205 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +12 -0
  13. sksurv/datasets/base.py +614 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/bmt.arff +46 -0
  17. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  18. sksurv/datasets/data/cgvhd.arff +118 -0
  19. sksurv/datasets/data/flchain.arff +7887 -0
  20. sksurv/datasets/data/veteran.arff +148 -0
  21. sksurv/datasets/data/whas500.arff +520 -0
  22. sksurv/docstrings.py +99 -0
  23. sksurv/ensemble/__init__.py +2 -0
  24. sksurv/ensemble/_coxph_loss.cpython-313-x86_64-linux-gnu.so +0 -0
  25. sksurv/ensemble/boosting.py +1564 -0
  26. sksurv/ensemble/forest.py +902 -0
  27. sksurv/ensemble/survival_loss.py +151 -0
  28. sksurv/exceptions.py +18 -0
  29. sksurv/functions.py +114 -0
  30. sksurv/io/__init__.py +2 -0
  31. sksurv/io/arffread.py +89 -0
  32. sksurv/io/arffwrite.py +181 -0
  33. sksurv/kernels/__init__.py +1 -0
  34. sksurv/kernels/_clinical_kernel.cpython-313-x86_64-linux-gnu.so +0 -0
  35. sksurv/kernels/clinical.py +348 -0
  36. sksurv/linear_model/__init__.py +3 -0
  37. sksurv/linear_model/_coxnet.cpython-313-x86_64-linux-gnu.so +0 -0
  38. sksurv/linear_model/aft.py +208 -0
  39. sksurv/linear_model/coxnet.py +592 -0
  40. sksurv/linear_model/coxph.py +637 -0
  41. sksurv/meta/__init__.py +4 -0
  42. sksurv/meta/base.py +35 -0
  43. sksurv/meta/ensemble_selection.py +724 -0
  44. sksurv/meta/stacking.py +370 -0
  45. sksurv/metrics.py +1028 -0
  46. sksurv/nonparametric.py +911 -0
  47. sksurv/preprocessing.py +183 -0
  48. sksurv/svm/__init__.py +11 -0
  49. sksurv/svm/_minlip.cpython-313-x86_64-linux-gnu.so +0 -0
  50. sksurv/svm/_prsvm.cpython-313-x86_64-linux-gnu.so +0 -0
  51. sksurv/svm/minlip.py +690 -0
  52. sksurv/svm/naive_survival_svm.py +249 -0
  53. sksurv/svm/survival_svm.py +1236 -0
  54. sksurv/testing.py +108 -0
  55. sksurv/tree/__init__.py +1 -0
  56. sksurv/tree/_criterion.cpython-313-x86_64-linux-gnu.so +0 -0
  57. sksurv/tree/tree.py +790 -0
  58. sksurv/util.py +415 -0
sksurv/svm/naive_survival_svm.py
@@ -0,0 +1,249 @@
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ import itertools
+
+ import numpy as np
+ import pandas as pd
+ from scipy.special import comb
+ from sklearn.svm import LinearSVC
+ from sklearn.utils.validation import _get_feature_names, check_random_state, validate_data
+
+ from ..base import SurvivalAnalysisMixin
+ from ..exceptions import NoComparablePairException
+ from ..util import check_array_survival
+
+
+ class NaiveSurvivalSVM(SurvivalAnalysisMixin, LinearSVC):
+     r"""Naive implementation of linear Survival Support Vector Machine.
+
+     This class uses a regular linear support vector classifier (liblinear)
+     to implement a survival SVM. It constructs a new dataset by computing
+     the difference between feature vectors of comparable pairs from the
+     original data. This approach results in a space complexity of
+     :math:`O(\text{n_samples}^2)`.
+
+     The optimization problem is formulated as:
+
+     .. math::
+
+           \min_{\mathbf{w}}\quad
+           \frac{1}{2} \lVert \mathbf{w} \rVert_2^2
+           + \gamma \sum_{(i, j) \in \mathcal{P}} \xi_{ij} \\
+           \text{subject to}\quad
+           \mathbf{w}^\top \mathbf{x}_i - \mathbf{w}^\top \mathbf{x}_j \geq 1 - \xi_{ij},\quad
+           \forall (i, j) \in \mathcal{P}, \\
+           \xi_{ij} \geq 0,\quad \forall (i, j) \in \mathcal{P}.
+
+           \mathcal{P} = \{ (i, j) \mid y_i > y_j \land \delta_j = 1 \}_{i,j=1,\dots,n}.
+
+     See [1]_, [2]_ for further description.
+
+     Parameters
+     ----------
+     alpha : float, optional, default: 1.0
+         Weight of penalizing the squared hinge loss in the objective function. Must be greater than 0.
+
+     loss : {'hinge', 'squared_hinge'}, optional, default: 'squared_hinge'
+         Specifies the loss function. 'hinge' is the standard SVM loss
+         (used e.g. by the SVC class) while 'squared_hinge' is the
+         square of the hinge loss.
+
+     penalty : {'l1', 'l2'}, optional, default: 'l2'
+         Specifies the norm used in the penalization. The 'l2'
+         penalty is the standard used in SVC. The 'l1' leads to `coef_`
+         vectors that are sparse.
+
+     dual : bool, optional, default: False
+         Select the algorithm to either solve the dual or primal
+         optimization problem. Prefer dual=False when n_samples > n_features.
+
+     tol : float, optional, default: 1e-4
+         Tolerance for stopping criteria.
+
+     verbose : int, optional, default: 0
+         Enable verbose output. Note that this setting takes advantage of a
+         per-process runtime setting in liblinear that, if enabled, may not work
+         properly in a multithreaded context.
+
+     random_state : int, :class:`numpy.random.RandomState` instance, or None, optional, default: None
+         Used to resolve ties in survival times. Pass an int for reproducible output across
+         multiple :meth:`fit` calls.
+
+     max_iter : int, optional, default: 1000
+         The maximum number of iterations taken for the solver to converge.
+
+     Attributes
+     ----------
+     n_iter_ : int
+         Number of iterations run by the optimization routine to fit the model.
+
+     See also
+     --------
+     sksurv.svm.FastSurvivalSVM : Alternative implementation with reduced time complexity for training.
+     sksurv.svm.HingeLossSurvivalSVM : Non-linear version of the naive survival SVM based on kernel functions.
+
+     References
+     ----------
+     .. [1] Van Belle, V., Pelckmans, K., Suykens, J. A., & Van Huffel, S.
+            Support Vector Machines for Survival Analysis. In Proc. of the 3rd Int. Conf.
+            on Computational Intelligence in Medicine and Healthcare (CIMED). 1-8. 2007
+
+     .. [2] Evers, L., Messow, C.M.,
+            "Sparse kernel methods for high-dimensional survival data",
+            Bioinformatics 24(14), 1632-8, 2008.
+
+     """
+
+     _parameter_constraints = {
+         "penalty": LinearSVC._parameter_constraints["penalty"],
+         "loss": LinearSVC._parameter_constraints["loss"],
+         "dual": LinearSVC._parameter_constraints["dual"],
+         "tol": LinearSVC._parameter_constraints["tol"],
+         "alpha": LinearSVC._parameter_constraints["C"],
+         "verbose": LinearSVC._parameter_constraints["verbose"],
+         "random_state": LinearSVC._parameter_constraints["random_state"],
+         "max_iter": LinearSVC._parameter_constraints["max_iter"],
+     }
+
+     def __init__(
+         self,
+         penalty="l2",
+         loss="squared_hinge",
+         *,
+         dual=False,
+         tol=1e-4,
+         alpha=1.0,
+         verbose=0,
+         random_state=None,
+         max_iter=1000,
+     ):
+         super().__init__(
+             penalty=penalty,
+             loss=loss,
+             dual=dual,
+             tol=tol,
+             verbose=verbose,
+             random_state=random_state,
+             max_iter=max_iter,
+             fit_intercept=False,
+         )
+         self.alpha = alpha
+
+     def _get_survival_pairs(self, X, y, random_state):  # pylint: disable=no-self-use
+         """Generates comparable pairs from survival data.
+
+         Parameters
+         ----------
+         X : array-like, shape = (n_samples, n_features)
+             Data matrix.
+         y : structured array, shape = (n_samples,)
+             A structured array containing the binary event indicator
+             and time of event or time of censoring.
+         random_state : RandomState instance
+             Random number generator used for shuffling.
+
+         Returns
+         -------
+         x_pairs : ndarray, shape = (n_pairs, n_features)
+             Feature differences for comparable pairs.
+         y_pairs : ndarray, shape = (n_pairs,)
+             Labels for comparable pairs (1 or -1).
+
+         Raises
+         ------
+         NoComparablePairException
+             If no comparable pairs can be formed from the input data.
+         """
+         feature_names = _get_feature_names(X)
+
+         X = validate_data(self, X, ensure_min_samples=2)
+         event, time = check_array_survival(X, y)
+
+         idx = np.arange(X.shape[0], dtype=int)
+         random_state.shuffle(idx)
+
+         n_pairs = int(comb(X.shape[0], 2))
+         x_pairs = np.empty((n_pairs, X.shape[1]), dtype=float)
+         y_pairs = np.empty(n_pairs, dtype=np.int8)
+         k = 0
+         for xi, xj in itertools.combinations(idx, 2):
+             if time[xi] > time[xj] and event[xj]:
+                 np.subtract(X[xi, :], X[xj, :], out=x_pairs[k, :])
+                 y_pairs[k] = 1
+                 k += 1
+             elif time[xi] < time[xj] and event[xi]:
+                 np.subtract(X[xi, :], X[xj, :], out=x_pairs[k, :])
+                 y_pairs[k] = -1
+                 k += 1
+             elif time[xi] == time[xj] and (event[xi] or event[xj]):
+                 np.subtract(X[xi, :], X[xj, :], out=x_pairs[k, :])
+                 y_pairs[k] = 1 if event[xj] else -1
+                 k += 1
+
+         x_pairs.resize((k, X.shape[1]), refcheck=False)
+         y_pairs.resize(k, refcheck=False)
+
+         if feature_names is not None:
+             x_pairs = pd.DataFrame(x_pairs, columns=feature_names)
+         return x_pairs, y_pairs
+
+     def fit(self, X, y, sample_weight=None):
+         """Build a survival support vector machine model from training data.
+
+         Parameters
+         ----------
+         X : array-like, shape = (n_samples, n_features)
+             Data matrix.
+
+         y : structured array, shape = (n_samples,)
+             A structured array with two fields. The first field is a boolean
+             where ``True`` indicates an event and ``False`` indicates right-censoring.
+             The second field is a float with the time of event or time of censoring.
+
+         sample_weight : array-like, shape = (n_samples,), optional
+             Array of weights that are assigned to individual
+             samples. If not provided,
+             then each sample is given unit weight.
+
+         Returns
+         -------
+         self
+         """
+         random_state = check_random_state(self.random_state)
+
+         x_pairs, y_pairs = self._get_survival_pairs(X, y, random_state)
+         if x_pairs.shape[0] == 0:
+             raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")
+
+         self.C = self.alpha
+         return super().fit(x_pairs, y_pairs, sample_weight=sample_weight)
+
+     def predict(self, X):
+         """Predict risk scores.
+
+         Predictions are risk scores (i.e. higher values indicate an
+         increased risk of experiencing an event). The scores have no
+         unit and are only meaningful to rank samples by their risk
+         of experiencing an event.
+
+         Parameters
+         ----------
+         X : array-like, shape = (n_samples, n_features)
+             The input samples.
+
+         Returns
+         -------
+         y : ndarray, shape = (n_samples,), dtype = float
+             Predicted risk scores.
+         """
+         return -self.decision_function(X)
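
The comparable-pairs construction in `_get_survival_pairs` above is the heart of the reduction. Below is a minimal standalone sketch of that logic on made-up toy data (editorial illustration only, not part of the packaged file; it omits the index shuffling, input validation, and pandas handling that the method performs):

import itertools

import numpy as np

# Toy right-censored survival data: observed time and event indicator
# (True = event observed, False = censored). Values are made up.
time = np.array([5.0, 3.0, 8.0, 3.0])
event = np.array([True, True, False, False])
X = np.arange(8, dtype=float).reshape(4, 2)

x_pairs, y_pairs = [], []
for i, j in itertools.combinations(range(len(time)), 2):
    if time[i] > time[j] and event[j]:
        # j experienced an event before i: the pair is comparable and the
        # feature difference x_i - x_j is labeled +1.
        x_pairs.append(X[i] - X[j])
        y_pairs.append(1)
    elif time[i] < time[j] and event[i]:
        x_pairs.append(X[i] - X[j])
        y_pairs.append(-1)
    elif time[i] == time[j] and (event[i] or event[j]):
        # Tied times with at least one event: orient the pair by which
        # sample had the observed event.
        x_pairs.append(X[i] - X[j])
        y_pairs.append(1 if event[j] else -1)

x_pairs = np.asarray(x_pairs)  # shape (n_pairs, n_features), here (4, 2)
y_pairs = np.asarray(y_pairs)  # +1/-1 labels for the binary SVM

Fitting a linear SVC without an intercept on these differences is the reduction that `fit` performs, which is why the space requirement grows with the number of pairs, i.e. on the order of n_samples squared.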
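
And a hedged end-to-end usage sketch of the estimator on synthetic data (again editorial, not taken from the package's documentation; the dataset and parameter values are arbitrary, and `Surv.from_arrays` is the helper from `sksurv.util` for building the structured target):

import numpy as np

from sksurv.svm import NaiveSurvivalSVM
from sksurv.util import Surv

rng = np.random.RandomState(0)
X = rng.standard_normal((100, 5))
time = rng.exponential(scale=10.0, size=100)
event = rng.binomial(1, 0.7, size=100).astype(bool)
y = Surv.from_arrays(event=event, time=time)

# alpha plays the role of LinearSVC's C (see `self.C = self.alpha` in fit);
# random_state is only used to resolve ties in survival times.
estimator = NaiveSurvivalSVM(alpha=1.0, max_iter=1000, random_state=0)
estimator.fit(X, y)

risk_scores = estimator.predict(X)  # higher value = higher estimated risk
print(risk_scores[:5])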