scikit-survival 0.23.1__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. scikit_survival-0.23.1.dist-info/COPYING +674 -0
  2. scikit_survival-0.23.1.dist-info/METADATA +888 -0
  3. scikit_survival-0.23.1.dist-info/RECORD +55 -0
  4. scikit_survival-0.23.1.dist-info/WHEEL +5 -0
  5. scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +138 -0
  7. sksurv/base.py +103 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
  10. sksurv/column.py +201 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +10 -0
  13. sksurv/datasets/base.py +436 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  17. sksurv/datasets/data/flchain.arff +7887 -0
  18. sksurv/datasets/data/veteran.arff +148 -0
  19. sksurv/datasets/data/whas500.arff +520 -0
  20. sksurv/ensemble/__init__.py +2 -0
  21. sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
  22. sksurv/ensemble/boosting.py +1610 -0
  23. sksurv/ensemble/forest.py +947 -0
  24. sksurv/ensemble/survival_loss.py +151 -0
  25. sksurv/exceptions.py +18 -0
  26. sksurv/functions.py +114 -0
  27. sksurv/io/__init__.py +2 -0
  28. sksurv/io/arffread.py +58 -0
  29. sksurv/io/arffwrite.py +145 -0
  30. sksurv/kernels/__init__.py +1 -0
  31. sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
  32. sksurv/kernels/clinical.py +328 -0
  33. sksurv/linear_model/__init__.py +3 -0
  34. sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
  35. sksurv/linear_model/aft.py +205 -0
  36. sksurv/linear_model/coxnet.py +543 -0
  37. sksurv/linear_model/coxph.py +618 -0
  38. sksurv/meta/__init__.py +4 -0
  39. sksurv/meta/base.py +35 -0
  40. sksurv/meta/ensemble_selection.py +642 -0
  41. sksurv/meta/stacking.py +349 -0
  42. sksurv/metrics.py +996 -0
  43. sksurv/nonparametric.py +588 -0
  44. sksurv/preprocessing.py +155 -0
  45. sksurv/svm/__init__.py +11 -0
  46. sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
  47. sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
  48. sksurv/svm/minlip.py +606 -0
  49. sksurv/svm/naive_survival_svm.py +221 -0
  50. sksurv/svm/survival_svm.py +1228 -0
  51. sksurv/testing.py +108 -0
  52. sksurv/tree/__init__.py +1 -0
  53. sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
  54. sksurv/tree/tree.py +703 -0
  55. sksurv/util.py +333 -0
@@ -0,0 +1,1228 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ from abc import ABCMeta, abstractmethod
14
+ from numbers import Integral, Real
15
+ import warnings
16
+
17
+ import numexpr
18
+ import numpy as np
19
+ from scipy.optimize import minimize
20
+ from sklearn.base import BaseEstimator
21
+ from sklearn.exceptions import ConvergenceWarning
22
+ from sklearn.metrics.pairwise import PAIRWISE_KERNEL_FUNCTIONS, pairwise_kernels
23
+ from sklearn.utils import check_array, check_consistent_length, check_random_state, check_X_y
24
+ from sklearn.utils._param_validation import Interval, StrOptions
25
+ from sklearn.utils.extmath import safe_sparse_dot, squared_norm
26
+ from sklearn.utils.validation import check_is_fitted
27
+
28
+ from ..base import SurvivalAnalysisMixin
29
+ from ..bintrees import AVLTree, RBTree
30
+ from ..exceptions import NoComparablePairException
31
+ from ..util import check_array_survival
32
+ from ._prsvm import survival_constraints_simple, survival_constraints_with_support_vectors
33
+
34
+
35
class Counter(metaclass=ABCMeta):
    """Abstract base class for pair-counting strategies used by the rank SVM optimizers.

    Validates the training data once and exposes :meth:`update_sort_order`,
    which subclasses' :meth:`calculate` implementations rely on.
    """

    @abstractmethod
    def __init__(self, x, y, status, time=None):
        self.x, self.y = check_X_y(x, y)

        assert np.issubdtype(y.dtype, np.integer), f"y vector must have integer type, but was {y.dtype}"
        assert y.min() == 0, "minimum element of y vector must be 0"

        # event indicator is always required; survival times only for the regression objective
        self.status = check_array(status, dtype=bool, ensure_2d=False)
        if time is None:
            check_consistent_length(self.x, self.status)
        else:
            self.time = check_array(time, ensure_2d=False)
            check_consistent_length(self.x, self.status, self.time)

        self.eps = np.finfo(self.x.dtype).eps

    def update_sort_order(self, w):
        """Project samples onto ``w`` and remember the sorted order.

        Stores the sorted projections in ``self.xw`` and the sort permutation
        in ``self.order``; returns the (unsorted) projections.
        """
        projections = np.dot(self.x, w)
        # mergesort keeps ties in a stable order
        perm = projections.argsort(kind="mergesort")
        self.xw = projections[perm]
        self.order = perm
        return projections

    @abstractmethod
    def calculate(self, v):
        """Return l_plus, xv_plus, l_minus, xv_minus"""
63
+
64
+
65
class OrderStatisticTreeSurvivalCounter(Counter):
    """Counting method used by :class:`LargeScaleOptimizer` for survival analysis.

    Parameters
    ----------
    x : array, shape = (n_samples, n_features)
        Feature matrix

    y : array of int, shape = (n_samples,)
        Unique ranks of samples, starting with 0.

    status : array of bool, shape = (n_samples,)
        Event indicator of samples.

    tree_class : type
        Which class to use as order statistic tree

    time : array, shape = (n_samples,)
        Survival times.
    """

    def __init__(self, x, y, status, tree_class, time=None):
        super().__init__(x, y, status, time)
        # either AVLTree or RBTree; both expose insert/count_larger_with_event/count_smaller
        self._tree_class = tree_class

    def calculate(self, v):
        """Return ``l_plus, xv_plus, l_minus, xv_minus`` for projection vector ``v``.

        Uses two sweeps over the samples sorted by ``self.xw`` (set by
        ``update_sort_order``), maintaining an order statistic tree keyed by
        rank so each pair count is obtained in logarithmic time.
        """
        # only self.xw is sorted, for everything else use self.order
        # the order of return values is with respect to original order of samples, NOT self.order
        xv = np.dot(self.x, v)

        od = self.order

        n_samples = self.x.shape[0]
        l_plus = np.zeros(n_samples, dtype=int)
        l_minus = np.zeros(n_samples, dtype=int)
        xv_plus = np.zeros(n_samples, dtype=float)
        xv_minus = np.zeros(n_samples, dtype=float)

        # forward sweep: for each i, insert all j whose margin 1 - xw[j] + xw[i]
        # is still violated, then count inserted samples with larger rank
        j = 0
        tree = self._tree_class(n_samples)
        for i in range(n_samples):
            while j < n_samples and 1 - self.xw[j] + self.xw[i] > 0:
                tree.insert(self.y[od[j]], xv[od[j]])
                j += 1

            # larger (root of t, y[od[i]])
            count, vec_sum = tree.count_larger_with_event(self.y[od[i]], self.status[od[i]])
            l_plus[od[i]] = count
            xv_plus[od[i]] = vec_sum

        # backward sweep: only samples with an observed event are inserted,
        # then count inserted samples with smaller rank
        tree = self._tree_class(n_samples)
        j = n_samples - 1
        for i in range(j, -1, -1):
            while j >= 0 and 1 - self.xw[i] + self.xw[j] > 0:
                if self.status[od[j]]:
                    tree.insert(self.y[od[j]], xv[od[j]])
                j -= 1

            # smaller (root of T, y[od[i]])
            count, vec_sum = tree.count_smaller(self.y[od[i]])
            l_minus[od[i]] = count
            xv_minus[od[i]] = vec_sum

        return l_plus, xv_plus, l_minus, xv_minus
129
+
130
+
131
class SurvivalCounter(Counter):
    """Direct pair-counting method used by :class:`LargeScaleOptimizer` ('direct-count').

    Iterates over all relevance levels explicitly instead of using an order
    statistic tree; see :class:`OrderStatisticTreeSurvivalCounter` for the
    tree-based alternative.
    """

    def __init__(self, x, y, status, n_relevance_levels, time=None):
        super().__init__(x, y, status, time)
        # total number of distinct ranks; calculate() loops over all of them
        self.n_relevance_levels = n_relevance_levels

    def _count_values(self):
        """Return dict mapping relevance level to sample index"""
        # only samples with an observed event contribute an entry
        indices = {yi: [i] for i, yi in enumerate(self.y) if self.status[i]}

        return indices

    def calculate(self, v):
        """Return ``l_plus, xv_plus, l_minus, xv_minus`` for projection vector ``v``."""
        n_samples = self.x.shape[0]
        l_plus = np.zeros(n_samples, dtype=int)
        l_minus = np.zeros(n_samples, dtype=int)
        xv_plus = np.zeros(n_samples, dtype=float)
        xv_minus = np.zeros(n_samples, dtype=float)
        indices = self._count_values()

        od = self.order

        for relevance in range(self.n_relevance_levels):
            j = 0
            count_plus = 0
            # relevance levels are unique, therefore count can only be 1 or 0
            count_minus = 1 if relevance in indices else 0
            xv_count_plus = 0
            # running sum of projections for the sample(s) at this relevance level
            xv_count_minus = np.dot(self.x.take(indices.get(relevance, []), axis=0), v).sum()

            for i in range(n_samples):
                # only process the event sample(s) at the current relevance level
                if self.y[od[i]] != relevance or not self.status[od[i]]:
                    continue

                # advance j over all samples whose margin w.r.t. sample i is violated
                while j < n_samples and 1 - self.xw[j] + self.xw[i] > 0:
                    if self.y[od[j]] > relevance:
                        count_plus += 1
                        xv_count_plus += np.dot(self.x[od[j], :], v)
                        l_minus[od[j]] += count_minus
                        xv_minus[od[j]] += xv_count_minus

                    j += 1

                l_plus[od[i]] = count_plus
                xv_plus[od[i]] += xv_count_plus
                # sample i no longer counts towards pairs at its own level
                count_minus -= 1
                xv_count_minus -= np.dot(self.x.take(od[i], axis=0), v)

        return l_plus, xv_plus, l_minus, xv_minus
179
+
180
+
181
class RankSVMOptimizer(metaclass=ABCMeta):
    """Abstract base class for all optimizers.

    Wraps :func:`scipy.optimize.minimize` (Newton-CG) with caching of
    constraint updates and gradient evaluations, since the objective,
    gradient, and Hessian-vector product are typically requested for the
    same ``w`` in succession.

    Parameters
    ----------
    alpha : float
        Regularization parameter.

    rank_ratio : float
        Trade-off between regression and ranking objectives.

    timeit : False or int
        If non-zero, ``run`` additionally measures optimization time with
        the given number of repetitions.
    """

    def __init__(self, alpha, rank_ratio, timeit=False):
        self.alpha = alpha
        self.rank_ratio = rank_ratio
        self.timeit = timeit

        self._last_w = None
        # cache gradient computations
        self._last_gradient_w = None
        self._last_gradient = None

    @abstractmethod
    def _objective_func(self, w):
        """Evaluate objective function at w"""

    @abstractmethod
    def _update_constraints(self, w):
        """Update constraints"""

    @abstractmethod
    def _gradient_func(self, w):
        """Evaluate gradient at w"""

    @abstractmethod
    def _hessian_func(self, w, s):
        """Evaluate Hessian at w"""

    @property
    @abstractmethod
    def n_coefficients(self):
        """Return number of coefficients (includes intercept)"""

    def _update_constraints_if_necessary(self, w):
        """Refresh cached constraints only if ``w`` changed since the last call."""
        needs_update = (w != self._last_w).any()
        if needs_update:
            self._update_constraints(w)
            self._last_w = w.copy()
        return needs_update

    def _do_objective_func(self, w):
        """Objective wrapper that keeps constraints in sync with ``w``."""
        self._update_constraints_if_necessary(w)
        return self._objective_func(w)

    def _do_gradient_func(self, w):
        """Gradient wrapper with memoization of the last evaluated ``w``."""
        if self._last_gradient_w is not None and (w == self._last_gradient_w).all():
            return self._last_gradient

        self._update_constraints_if_necessary(w)
        self._last_gradient_w = w.copy()
        self._last_gradient = self._gradient_func(w)
        return self._last_gradient

    def _init_coefficients(self):
        """Return the initial coefficient vector (zeros) with constraints set up."""
        w = np.zeros(self.n_coefficients)
        self._update_constraints(w)
        self._last_w = w.copy()
        return w

    def run(self, **kwargs):
        """Run Newton-CG optimization and return the :class:`scipy.optimize.OptimizeResult`.

        If ``self.timeit`` is non-zero, the optimization is additionally
        repeated that many times and the measured durations are stored under
        the ``"timings"`` key of the result (``None`` otherwise).
        """
        w = self._init_coefficients()

        # Single definition of the minimize call, shared by the timing loop
        # and the final run (previously duplicated verbatim).
        def _minimize():
            return minimize(
                self._do_objective_func,
                w,
                method="newton-cg",
                jac=self._do_gradient_func,
                hessp=self._hessian_func,
                **kwargs,
            )

        timings = None
        if self.timeit:
            import timeit

            timer = timeit.Timer(_minimize)
            timings = timer.repeat(self.timeit, number=1)

        opt_result = _minimize()
        opt_result["timings"] = timings

        return opt_result
272
+
273
+
274
class SimpleOptimizer(RankSVMOptimizer):
    """Simple optimizer, which explicitly constructs matrix of all pairs of samples"""

    def __init__(self, x, y, alpha, rank_ratio, timeit=False):
        super().__init__(alpha, rank_ratio, timeit)
        self.data_x = x
        self.constraints = survival_constraints_simple(np.asarray(y, dtype=np.uint8))

        n_pairs = self.constraints.shape[0]
        if n_pairs == 0:
            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

        # hinge losses, one entry per comparable pair
        self.L = np.ones(n_pairs)

    @property
    def n_coefficients(self):
        return self.data_x.shape[1]

    def _objective_func(self, w):
        # ridge penalty on w plus squared hinge loss over all pairs
        return 0.5 * squared_norm(w) + 0.5 * self.alpha * squared_norm(self.L)

    def _update_constraints(self, w):
        self.xw = np.dot(self.data_x, w)
        self.L = 1 - self.constraints.dot(self.xw)
        np.maximum(0, self.L, out=self.L)
        # keep only rows of the constraint matrix with a positive hinge loss
        sv_rows = np.nonzero(self.L > 0)[0]
        self.Asv = self.constraints[sv_rows, :]

    def _gradient_func(self, w):
        # sum over columns without running into overflow problems
        counts = self.Asv.sum(axis=0, dtype=int).A.squeeze()

        residual = self.Asv.T.dot(self.Asv.dot(self.xw)) - counts
        return w + self.alpha * np.dot(self.data_x.T, residual)

    def _hessian_func(self, w, s):
        Axs = self.alpha * self.Asv.dot(np.dot(self.data_x, s))
        return s + np.dot(safe_sparse_dot(Axs.T, self.Asv), self.data_x).T
313
+
314
+
315
class PRSVMOptimizer(RankSVMOptimizer):
    """PRSVM optimizer that after each iteration of Newton's method
    constructs matrix of support vector pairs"""

    def __init__(self, x, y, alpha, rank_ratio, timeit=False):
        super().__init__(alpha, rank_ratio, timeit)
        self.data_x = x
        self.data_y = np.asarray(y, dtype=np.uint8)

        # fail early if no pair is comparable even at w = 0
        Aw = self._constraints(np.zeros(x.shape[1]))
        if Aw.shape[0] == 0:
            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

    def _constraints(self, xw):
        # rebuild the constraint matrix from pairs that are support vectors
        # at the current predictions
        return survival_constraints_with_support_vectors(self.data_y, xw)

    @property
    def n_coefficients(self):
        return self.data_x.shape[1]

    def _objective_func(self, w):
        # squared hinge loss expanded: |SV| + ||A X w||^2 - 2 * sum(A X w)
        hinge_part = self.Aw.shape[0] + squared_norm(self.AXw) - 2.0 * self.AXw.sum()
        return 0.5 * squared_norm(w) + 0.5 * self.alpha * hinge_part

    def _update_constraints(self, w):
        predictions = np.dot(self.data_x, w)
        self.Aw = self._constraints(predictions)
        self.AXw = self.Aw.dot(predictions)

    def _gradient_func(self, w):
        # sum over columns without running into overflow problems
        counts = self.Aw.sum(axis=0, dtype=int).A.squeeze()
        return w + self.alpha * np.dot(self.data_x.T, self.Aw.T.dot(self.AXw) - counts)

    def _hessian_func(self, w, s):
        Axs = self.Aw.dot(np.dot(self.data_x, s))
        return s + self.alpha * np.dot(self.data_x.T, self.Aw.T.dot(Axs))
354
+
355
+
356
class LargeScaleOptimizer(RankSVMOptimizer):
    """Optimizer that does not explicitly create matrix of constraints

    Parameters
    ----------
    alpha : float
        Regularization parameter.

    rank_ratio : float
        Trade-off between regression and ranking objectives.

    fit_intercept : bool
        Whether to fit an intercept. Only used if regression objective
        is optimized (rank_ratio < 1.0).

    counter : object
        Instance of :class:`Counter` subclass.

    References
    ----------
    Lee, C.-P., & Lin, C.-J. (2014). Supplement Materials for "Large-scale linear RankSVM". Neural Computation, 26(4),
    781–817. doi:10.1162/NECO_a_00571
    """

    def __init__(self, alpha, rank_ratio, fit_intercept, counter, timeit=False):
        super().__init__(alpha, rank_ratio, timeit)

        self._counter = counter
        # total penalty alpha is split between regression and ranking parts
        self._regr_penalty = (1.0 - rank_ratio) * alpha
        self._rank_penalty = rank_ratio * alpha
        # regression part (and hence intercept) only applies if survival times are available
        self._has_time = hasattr(self._counter, "time") and self._regr_penalty > 0
        self._fit_intercept = fit_intercept if self._has_time else False

    @property
    def n_coefficients(self):
        """Number of coefficients, including the intercept if one is fit."""
        n = self._counter.x.shape[1]
        if self._fit_intercept:
            n += 1
        return n

    def _init_coefficients(self):
        # start the intercept at the mean survival time; feature weights start at zero
        w = super()._init_coefficients()
        n = w.shape[0]
        if self._fit_intercept:
            w[0] = self._counter.time.mean()
            n -= 1

        # at w = 0 every comparable pair is counted; all-zero counts mean no pairs exist
        l_plus, _, l_minus, _ = self._counter.calculate(np.zeros(n))
        if np.all(l_plus == 0) and np.all(l_minus == 0):
            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

        return w

    def _split_coefficents(self, w):
        """Split into intercept/bias and feature-specific coefficients"""
        if self._fit_intercept:
            bias = w[0]
            wf = w[1:]
        else:
            bias = 0.0
            wf = w
        return bias, wf

    def _objective_func(self, w):
        """Evaluate the combined ranking + regression objective at ``w``."""
        bias, wf = self._split_coefficents(w)

        l_plus, xv_plus, l_minus, xv_minus = self._counter.calculate(wf)  # pylint: disable=unused-variable

        xw = self._xw
        val = 0.5 * squared_norm(wf)
        if self._has_time:
            # squared regression loss over samples selected by regr_mask
            val += (
                0.5 * self._regr_penalty * squared_norm(self.y_compressed - bias - xw.compress(self.regr_mask, axis=0))
            )

        # ranking loss evaluated via pair counts; numexpr reads the local arrays by name
        val += (
            0.5
            * self._rank_penalty
            * numexpr.evaluate(
                "sum(xw * ((l_plus + l_minus) * xw - xv_plus - xv_minus - 2 * (l_minus - l_plus)) + l_minus)"
            )
        )

        return val

    def _update_constraints(self, w):
        """Refresh cached projections and the regression mask for ``w``."""
        bias, wf = self._split_coefficents(w)

        self._xw = self._counter.update_sort_order(wf)

        if self._has_time:
            # censored samples only contribute if their predicted time is too small
            pred_time = self._counter.time - self._xw - bias
            self.regr_mask = (pred_time > 0) | self._counter.status
            self.y_compressed = self._counter.time.compress(self.regr_mask, axis=0)

    def _gradient_func(self, w):
        """Evaluate the gradient at ``w`` (intercept component first, if fit)."""
        bias, wf = self._split_coefficents(w)

        l_plus, xv_plus, l_minus, xv_minus = self._counter.calculate(wf)  # pylint: disable=unused-variable
        x = self._counter.x

        xw = self._xw  # noqa: F841; # pylint: disable=unused-variable
        z = numexpr.evaluate("(l_plus + l_minus) * xw - xv_plus - xv_minus - l_minus + l_plus")

        grad = wf + self._rank_penalty * np.dot(x.T, z)
        if self._has_time:
            xc = x.compress(self.regr_mask, axis=0)
            xcs = np.dot(xc, wf)
            grad += self._regr_penalty * (np.dot(xc.T, xcs) + xc.sum(axis=0) * bias - np.dot(xc.T, self.y_compressed))

            # intercept
            if self._fit_intercept:
                grad_intercept = self._regr_penalty * (xcs.sum() + xc.shape[0] * bias - self.y_compressed.sum())
                grad = np.r_[grad_intercept, grad]

        return grad

    def _hessian_func(self, w, s):
        """Evaluate the Hessian-vector product at direction ``s``."""
        s_bias, s_feat = self._split_coefficents(s)

        l_plus, xv_plus, l_minus, xv_minus = self._counter.calculate(s_feat)  # pylint: disable=unused-variable
        x = self._counter.x

        xs = np.dot(x, s_feat)  # pylint: disable=unused-variable
        xs = numexpr.evaluate("(l_plus + l_minus) * xs - xv_plus - xv_minus")

        hessp = s_feat + self._rank_penalty * np.dot(x.T, xs)
        if self._has_time:
            xc = x.compress(self.regr_mask, axis=0)
            hessp += self._regr_penalty * np.dot(xc.T, np.dot(xc, s_feat))

            # intercept
            if self._fit_intercept:
                xsum = xc.sum(axis=0)
                hessp += self._regr_penalty * xsum * s_bias
                hessp_intercept = self._regr_penalty * xc.shape[0] * s_bias + self._regr_penalty * np.dot(xsum, s_feat)
                hessp = np.r_[hessp_intercept, hessp]

        return hessp
495
+
496
+
497
class NonlinearLargeScaleOptimizer(RankSVMOptimizer):
    """Optimizer that does not explicitly create matrix of constraints

    Parameters
    ----------
    alpha : float
        Regularization parameter.

    rank_ratio : float
        Trade-off between regression and ranking objectives.

    counter : object
        Instance of :class:`Counter` subclass.

    References
    ----------
    Lee, C.-P., & Lin, C.-J. (2014). Supplement Materials for "Large-scale linear RankSVM". Neural Computation, 26(4),
    781–817. doi:10.1162/NECO_a_00571
    """

    def __init__(self, alpha, rank_ratio, fit_intercept, counter, timeit=False):
        super().__init__(alpha, rank_ratio, timeit)

        self._counter = counter
        # NOTE(review): this assignment is redundant — it is unconditionally
        # overwritten a few lines below once _has_time is known.
        self._fit_intercept = fit_intercept
        self._rank_penalty = rank_ratio * alpha
        self._regr_penalty = (1.0 - rank_ratio) * alpha
        self._has_time = hasattr(self._counter, "time") and self._regr_penalty > 0
        self._fit_intercept = fit_intercept if self._has_time else False

    @property
    def n_coefficients(self):
        """Number of coefficients, including the intercept if one is fit.

        Uses ``x.shape[0]``: presumably ``counter.x`` holds a square kernel
        matrix here, giving one coefficient per sample — confirm against callers.
        """
        n = self._counter.x.shape[0]
        if self._fit_intercept:
            n += 1
        return n

    def _init_coefficients(self):
        # start the intercept at the mean survival time; beta starts at zero
        w = super()._init_coefficients()
        n = w.shape[0]
        if self._fit_intercept:
            w[0] = self._counter.time.mean()
            n -= 1

        # at beta = 0 every comparable pair is counted; all-zero counts mean no pairs exist
        l_plus, _, l_minus, _ = self._counter.calculate(np.zeros(n))
        if np.all(l_plus == 0) and np.all(l_minus == 0):
            raise NoComparablePairException("Data has no comparable pairs, cannot fit model.")

        return w

    def _split_coefficents(self, w):
        """Split into intercept/bias and feature-specific coefficients"""
        if self._fit_intercept:
            bias = w[0]
            wf = w[1:]
        else:
            bias = 0.0
            wf = w
        return bias, wf

    def _update_constraints(self, beta_bias):
        """Refresh cached kernel projections and the regression mask."""
        bias, beta = self._split_coefficents(beta_bias)

        self._Kw = self._counter.update_sort_order(beta)

        if self._has_time:
            # censored samples only contribute if their predicted time is too small
            pred_time = self._counter.time - self._Kw - bias
            self.regr_mask = (pred_time > 0) | self._counter.status
            self.y_compressed = self._counter.time.compress(self.regr_mask, axis=0)

    def _objective_func(self, beta_bias):
        """Evaluate the combined ranking + regression objective at ``beta_bias``."""
        bias, beta = self._split_coefficents(beta_bias)

        Kw = self._Kw

        # 0.5 * beta' K beta is the RKHS norm penalty
        val = 0.5 * np.dot(beta, Kw)
        if self._has_time:
            val += (
                0.5 * self._regr_penalty * squared_norm(self.y_compressed - bias - Kw.compress(self.regr_mask, axis=0))
            )

        l_plus, xv_plus, l_minus, xv_minus = self._counter.calculate(beta)  # pylint: disable=unused-variable
        # ranking loss via pair counts; numexpr reads the local arrays by name
        val += (
            0.5
            * self._rank_penalty
            * numexpr.evaluate(
                "sum(Kw * ((l_plus + l_minus) * Kw - xv_plus - xv_minus - 2 * (l_minus - l_plus)) + l_minus)"
            )
        )

        return val

    def _gradient_func(self, beta_bias):
        """Evaluate the gradient at ``beta_bias`` (intercept component first, if fit)."""
        bias, beta = self._split_coefficents(beta_bias)

        K = self._counter.x
        Kw = self._Kw

        l_plus, xv_plus, l_minus, xv_minus = self._counter.calculate(beta)  # pylint: disable=unused-variable
        z = numexpr.evaluate("(l_plus + l_minus) * Kw - xv_plus - xv_minus - l_minus + l_plus")

        gradient = Kw + self._rank_penalty * np.dot(K, z)
        if self._has_time:
            K_comp = K.compress(self.regr_mask, axis=0)
            K_comp_beta = np.dot(K_comp, beta)
            gradient += self._regr_penalty * (
                np.dot(K_comp.T, K_comp_beta) + K_comp.sum(axis=0) * bias - np.dot(K_comp.T, self.y_compressed)
            )

            # intercept
            if self._fit_intercept:
                grad_intercept = self._regr_penalty * (
                    K_comp_beta.sum() + K_comp.shape[0] * bias - self.y_compressed.sum()
                )
                gradient = np.r_[grad_intercept, gradient]

        return gradient

    def _hessian_func(self, _beta, s):
        """Evaluate the Hessian-vector product at direction ``s``."""
        s_bias, s_feat = self._split_coefficents(s)

        K = self._counter.x
        Ks = np.dot(K, s_feat)

        l_plus, xv_plus, l_minus, xv_minus = self._counter.calculate(s_feat)  # pylint: disable=unused-variable
        xs = numexpr.evaluate("(l_plus + l_minus) * Ks - xv_plus - xv_minus")

        hessian = Ks + self._rank_penalty * np.dot(K, xs)
        if self._has_time:
            K_comp = K.compress(self.regr_mask, axis=0)
            hessian += self._regr_penalty * np.dot(K_comp.T, np.dot(K_comp, s_feat))

            # intercept
            if self._fit_intercept:
                xsum = K_comp.sum(axis=0)
                hessian += self._regr_penalty * xsum * s_bias
                hessian_intercept = self._regr_penalty * K_comp.shape[0] * s_bias + self._regr_penalty * np.dot(
                    xsum, s_feat
                )
                hessian = np.r_[hessian_intercept, hessian]

        return hessian
639
+
640
+
641
class BaseSurvivalSVM(BaseEstimator, metaclass=ABCMeta):
    """Abstract base class shared by the survival SVM estimators in this module."""

    # sklearn parameter validation; note 'optimizer' is not constrained here —
    # presumably concrete subclasses add their own constraint for it (confirm).
    _parameter_constraints = {
        "alpha": [Interval(Real, 0.0, None, closed="neither")],
        "rank_ratio": [Interval(Real, 0.0, 1.0, closed="both")],
        "fit_intercept": ["boolean"],
        "max_iter": [Interval(Integral, 1, None, closed="left")],
        "verbose": ["verbose"],
        "tol": [Interval(Real, 0.0, None, closed="neither"), None],
        "random_state": ["random_state"],
        "timeit": [Interval(Integral, 1, None, closed="left"), "boolean"],
    }

    @abstractmethod
    def __init__(
        self,
        alpha=1,
        rank_ratio=1.0,
        fit_intercept=False,
        max_iter=20,
        verbose=False,
        tol=None,
        optimizer=None,
        random_state=None,
        timeit=False,
    ):
        self.alpha = alpha
        self.rank_ratio = rank_ratio
        self.fit_intercept = fit_intercept
        self.max_iter = max_iter
        self.verbose = verbose
        self.tol = tol
        self.optimizer = optimizer
        self.random_state = random_state
        self.timeit = timeit

        self.coef_ = None
        self.optimizer_result_ = None

    def _create_optimizer(self, X, y, status):
        """Samples are ordered by relevance"""
        # NOTE(review): mutating self.optimizer during fit deviates from the
        # sklearn convention of leaving constructor params untouched.
        if self.optimizer is None:
            self.optimizer = "avltree"

        times, ranks = y

        # NOTE(review): an unrecognized optimizer name leaves `optimizer`
        # unbound and raises UnboundLocalError at the return below — presumably
        # subclasses validate the value beforehand; confirm.
        if self.optimizer == "simple":
            optimizer = SimpleOptimizer(X, status, self.alpha, self.rank_ratio, timeit=self.timeit)
        elif self.optimizer == "PRSVM":
            optimizer = PRSVMOptimizer(X, status, self.alpha, self.rank_ratio, timeit=self.timeit)
        elif self.optimizer == "direct-count":
            optimizer = LargeScaleOptimizer(
                self.alpha,
                self.rank_ratio,
                self.fit_intercept,
                SurvivalCounter(X, ranks, status, len(ranks), times),
                timeit=self.timeit,
            )
        elif self.optimizer == "rbtree":
            optimizer = LargeScaleOptimizer(
                self.alpha,
                self.rank_ratio,
                self.fit_intercept,
                OrderStatisticTreeSurvivalCounter(X, ranks, status, RBTree, times),
                timeit=self.timeit,
            )
        elif self.optimizer == "avltree":
            optimizer = LargeScaleOptimizer(
                self.alpha,
                self.rank_ratio,
                self.fit_intercept,
                OrderStatisticTreeSurvivalCounter(X, ranks, status, AVLTree, times),
                timeit=self.timeit,
            )

        return optimizer

    @property
    def _predict_risk_score(self):
        # pure ranking model predicts risk scores; otherwise predictions are times
        return self.rank_ratio == 1

    @abstractmethod
    def _fit(self, X, time, event, samples_order):
        """Create and run optimizer"""

    @abstractmethod
    def predict(self, X):
        """Predict risk score"""

    def _validate_for_fit(self, X):
        # at least 2 samples are needed to form a comparable pair
        return self._validate_data(X, ensure_min_samples=2)

    def fit(self, X, y):
        """Build a survival support vector machine model from training data.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Data matrix.

        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        self
        """
        X = self._validate_for_fit(X)
        event, time = check_array_survival(X, y, allow_time_zero=False)

        self._validate_params()

        if self.fit_intercept and self.rank_ratio == 1.0:
            raise ValueError("fit_intercept=True is only meaningful if rank_ratio < 1.0")

        if self.rank_ratio < 1.0:
            if self.optimizer in {"simple", "PRSVM"}:
                raise ValueError(f"optimizer {self.optimizer!r} does not implement regression objective")

            # log-transform time
            time = np.log(time)
            assert np.isfinite(time).all()

        random_state = check_random_state(self.random_state)
        samples_order = BaseSurvivalSVM._argsort_and_resolve_ties(time, random_state)

        opt_result = self._fit(X, time, event, samples_order)
        # with an intercept, the first coefficient is the bias term
        coef = opt_result.x
        if self.fit_intercept:
            self.coef_ = coef[1:]
            self.intercept_ = coef[0]
        else:
            self.coef_ = coef

        if not opt_result.success:
            warnings.warn(
                ("Optimization did not converge: " + opt_result.message), category=ConvergenceWarning, stacklevel=2
            )
        self.optimizer_result_ = opt_result

        return self

    @property
    def n_iter_(self):
        # number of iterations reported by scipy's optimizer
        return self.optimizer_result_.nit

    @staticmethod
    def _argsort_and_resolve_ties(time, random_state):
        """Like np.argsort, but resolves ties uniformly at random"""
        n_samples = len(time)
        order = np.argsort(time, kind="mergesort")

        i = 0
        while i < n_samples - 1:
            # find the end of the run of equal survival times starting at i
            inext = i + 1
            while inext < n_samples and time[order[i]] == time[order[inext]]:
                inext += 1

            if i + 1 != inext:
                # resolve ties randomly
                random_state.shuffle(order[i:inext])
            i = inext
        return order
805
+
806
+
807
+ class FastSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin):
808
+ """Efficient Training of linear Survival Support Vector Machine
809
+
810
+ Training data consists of *n* triplets :math:`(\\mathbf{x}_i, y_i, \\delta_i)`,
811
+ where :math:`\\mathbf{x}_i` is a *d*-dimensional feature vector, :math:`y_i > 0`
812
+ the survival time or time of censoring, and :math:`\\delta_i \\in \\{0,1\\}`
813
+ the binary event indicator. Using the training data, the objective is to
814
+ minimize the following function:
815
+
816
+ .. math::
817
+
818
+ \\arg \\min_{\\mathbf{w}, b} \\frac{1}{2} \\mathbf{w}^\\top \\mathbf{w}
819
+ + \\frac{\\alpha}{2} \\left[ r \\sum_{i,j \\in \\mathcal{P}}
820
+ \\max(0, 1 - (\\mathbf{w}^\\top \\mathbf{x}_i - \\mathbf{w}^\\top \\mathbf{x}_j))^2
821
+ + (1 - r) \\sum_{i=0}^n \\left( \\zeta_{\\mathbf{w}, b} (y_i, x_i, \\delta_i)
822
+ \\right)^2 \\right]
823
+
824
+ \\zeta_{\\mathbf{w},b} (y_i, \\mathbf{x}_i, \\delta_i) =
825
+ \\begin{cases}
826
+ \\max(0, y_i - \\mathbf{w}^\\top \\mathbf{x}_i - b) \\quad \\text{if $\\delta_i = 0$,} \\\\
827
+ y_i - \\mathbf{w}^\\top \\mathbf{x}_i - b \\quad \\text{if $\\delta_i = 1$,} \\\\
828
+ \\end{cases}
829
+
830
+ \\mathcal{P} = \\{ (i, j) \\mid y_i > y_j \\land \\delta_j = 1 \\}_{i,j=1,\\dots,n}
831
+
832
+ The hyper-parameter :math:`\\alpha > 0` determines the amount of regularization
833
+ to apply: a smaller value increases the amount of regularization and a
834
+ higher value reduces the amount of regularization. The hyper-parameter
835
+ :math:`r \\in [0; 1]` determines the trade-off between the ranking objective
836
+ and the regression objective. If :math:`r = 1` it reduces to the ranking
837
+ objective, and if :math:`r = 0` to the regression objective. If the regression
838
+ objective is used, survival/censoring times are log-transform and thus cannot be
839
+ zero or negative.
840
+
841
+ See the :ref:`User Guide </user_guide/survival-svm.ipynb>` and [1]_ for further description.
842
+
843
+ Parameters
844
+ ----------
845
+ alpha : float, positive, default: 1
846
+ Weight of penalizing the squared hinge loss in the objective function
847
+
848
+ rank_ratio : float, optional, default: 1.0
849
+ Mixing parameter between regression and ranking objective with ``0 <= rank_ratio <= 1``.
850
+ If ``rank_ratio = 1``, only ranking is performed, if ``rank_ratio = 0``, only regression
851
+ is performed. A non-zero value is only allowed if optimizer is one of 'avltree', 'rbtree',
852
+ or 'direct-count'.
853
+
854
+ fit_intercept : boolean, optional, default: False
855
+ Whether to calculate an intercept for the regression model. If set to ``False``, no intercept
856
+ will be calculated. Has no effect if ``rank_ratio = 1``, i.e., only ranking is performed.
857
+
858
+ max_iter : int, optional, default: 20
859
+ Maximum number of iterations to perform in Newton optimization
860
+
861
+ verbose : bool, optional, default: False
862
+ Whether to print messages during optimization
863
+
864
+ tol : float or None, optional, default: None
865
+ Tolerance for termination. For detailed control, use solver-specific
866
+ options.
867
+
868
+ optimizer : {'avltree', 'direct-count', 'PRSVM', 'rbtree', 'simple'}, optional, default: 'avltree'
869
+ Which optimizer to use.
870
+
871
+ random_state : int or :class:`numpy.random.RandomState` instance, optional
872
+ Random number generator (used to resolve ties in survival times).
873
+
874
+ timeit : False, int or None, default: False
875
+ If non-zero value is provided the time it takes for optimization is measured.
876
+ The given number of repetitions are performed. Results can be accessed from the
877
+ ``optimizer_result_`` attribute.
878
+
879
+ Attributes
880
+ ----------
881
+ coef_ : ndarray, shape = (n_features,)
882
+ Coefficients of the features in the decision function.
883
+
884
+ optimizer_result_ : :class:`scipy.optimize.OptimizeResult`
885
+ Stats returned by the optimizer. See :class:`scipy.optimize.OptimizeResult`.
886
+
887
+ n_features_in_ : int
888
+ Number of features seen during ``fit``.
889
+
890
+ feature_names_in_ : ndarray of shape (`n_features_in_`,)
891
+ Names of features seen during ``fit``. Defined only when `X`
892
+ has feature names that are all strings.
893
+
894
+ n_iter_ : int
895
+ Number of iterations run by the optimization routine to fit the model.
896
+
897
+ See also
898
+ --------
899
+ FastKernelSurvivalSVM
900
+ Fast implementation for arbitrary kernel functions.
901
+
902
+ References
903
+ ----------
904
+ .. [1] Pölsterl, S., Navab, N., and Katouzian, A.,
905
+ "Fast Training of Support Vector Machines for Survival Analysis",
906
+ Machine Learning and Knowledge Discovery in Databases: European Conference,
907
+ ECML PKDD 2015, Porto, Portugal,
908
+ Lecture Notes in Computer Science, vol. 9285, pp. 243-259 (2015)
909
+ """
910
+
911
    # Allowed values for each hyper-parameter; merges the base class'
    # constraints and is checked by scikit-learn's parameter validation
    # machinery when ``fit`` is called.
    _parameter_constraints = {
        **BaseSurvivalSVM._parameter_constraints,
        "optimizer": [StrOptions({"simple", "PRSVM", "direct-count", "rbtree", "avltree"}), None],
    }
915
+
916
+ def __init__(
917
+ self,
918
+ alpha=1,
919
+ *,
920
+ rank_ratio=1.0,
921
+ fit_intercept=False,
922
+ max_iter=20,
923
+ verbose=False,
924
+ tol=None,
925
+ optimizer=None,
926
+ random_state=None,
927
+ timeit=False,
928
+ ):
929
+ super().__init__(
930
+ alpha=alpha,
931
+ rank_ratio=rank_ratio,
932
+ fit_intercept=fit_intercept,
933
+ max_iter=max_iter,
934
+ verbose=verbose,
935
+ tol=tol,
936
+ optimizer=optimizer,
937
+ random_state=random_state,
938
+ timeit=timeit,
939
+ )
940
+
941
+ def _fit(self, X, time, event, samples_order):
942
+ data_y = (time[samples_order], np.arange(len(samples_order)))
943
+ status = event[samples_order]
944
+
945
+ optimizer = self._create_optimizer(X[samples_order], data_y, status)
946
+ opt_result = optimizer.run(tol=self.tol, options={"maxiter": self.max_iter, "disp": self.verbose})
947
+ return opt_result
948
+
949
+ def predict(self, X):
950
+ """Rank samples according to survival times
951
+
952
+ Lower ranks indicate shorter survival, higher ranks longer survival.
953
+
954
+ Parameters
955
+ ----------
956
+ X : array-like, shape = (n_samples, n_features)
957
+ The input samples.
958
+
959
+ Returns
960
+ -------
961
+ y : ndarray, shape = (n_samples,)
962
+ Predicted ranks.
963
+ """
964
+ check_is_fitted(self, "coef_")
965
+ X = self._validate_data(X, reset=False)
966
+
967
+ val = np.dot(X, self.coef_)
968
+ if hasattr(self, "intercept_"):
969
+ val += self.intercept_
970
+
971
+ # Order by increasing survival time if objective is pure ranking
972
+ if self.rank_ratio == 1:
973
+ val *= -1
974
+ else:
975
+ # model was fitted on log(time), transform to original scale
976
+ val = np.exp(val)
977
+
978
+ return val
979
+
980
+
981
+ class FastKernelSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin):
982
+ """Efficient Training of kernel Survival Support Vector Machine.
983
+
984
+ See the :ref:`User Guide </user_guide/survival-svm.ipynb>` and [1]_ for further description.
985
+
986
+ Parameters
987
+ ----------
988
+ alpha : float, positive, default: 1
989
+ Weight of penalizing the squared hinge loss in the objective function
990
+
991
+ rank_ratio : float, optional, default: 1.0
992
+ Mixing parameter between regression and ranking objective with ``0 <= rank_ratio <= 1``.
993
+ If ``rank_ratio = 1``, only ranking is performed, if ``rank_ratio = 0``, only regression
994
+ is performed. A non-zero value is only allowed if optimizer is one of 'avltree', 'PRSVM',
995
+ or 'rbtree'.
996
+
997
+ fit_intercept : boolean, optional, default: False
998
+ Whether to calculate an intercept for the regression model. If set to ``False``, no intercept
999
+ will be calculated. Has no effect if ``rank_ratio = 1``, i.e., only ranking is performed.
1000
+
1001
+ kernel : str or callable, default: 'rbf'.
1002
+ Kernel mapping used internally. This parameter is directly passed to
1003
+ :func:`sklearn.metrics.pairwise.pairwise_kernels`.
1004
+ If `kernel` is a string, it must be one of the metrics
1005
+ in `pairwise.PAIRWISE_KERNEL_FUNCTIONS` or "precomputed".
1006
+ If `kernel` is "precomputed", X is assumed to be a kernel matrix.
1007
+ Alternatively, if `kernel` is a callable function, it is called on
1008
+ each pair of instances (rows) and the resulting value recorded. The
1009
+ callable should take two rows from X as input and return the
1010
+ corresponding kernel value as a single number. This means that
1011
+ callables from :mod:`sklearn.metrics.pairwise` are not allowed, as
1012
+ they operate on matrices, not single samples. Use the string
1013
+ identifying the kernel instead.
1014
+
1015
+ gamma : float, optional, default: None
1016
+ Gamma parameter for the RBF, laplacian, polynomial, exponential chi2
1017
+ and sigmoid kernels. Interpretation of the default value is left to
1018
+ the kernel; see the documentation for :mod:`sklearn.metrics.pairwise`.
1019
+ Ignored by other kernels.
1020
+
1021
+ degree : int, default: 3
1022
+ Degree of the polynomial kernel. Ignored by other kernels.
1023
+
1024
+ coef0 : float, optional
1025
+ Zero coefficient for polynomial and sigmoid kernels.
1026
+ Ignored by other kernels.
1027
+
1028
+ kernel_params : mapping of string to any, optional
1029
+ Additional parameters (keyword arguments) for kernel function passed
1030
+ as callable object.
1031
+
1032
+ max_iter : int, optional, default: 20
1033
+ Maximum number of iterations to perform in Newton optimization
1034
+
1035
+ verbose : bool, optional, default: False
1036
+ Whether to print messages during optimization
1037
+
1038
+ tol : float or None, optional, default: None
1039
+ Tolerance for termination. For detailed control, use solver-specific
1040
+ options.
1041
+
1042
+ optimizer : {'avltree', 'rbtree'}, optional, default: 'rbtree'
1043
+ Which optimizer to use.
1044
+
1045
+ random_state : int or :class:`numpy.random.RandomState` instance, optional
1046
+ Random number generator (used to resolve ties in survival times).
1047
+
1048
+ timeit : False, int or None, default: False
1049
+ If non-zero value is provided the time it takes for optimization is measured.
1050
+ The given number of repetitions are performed. Results can be accessed from the
1051
+ ``optimizer_result_`` attribute.
1052
+
1053
+ Attributes
1054
+ ----------
1055
+ coef_ : ndarray, shape = (n_samples,)
1056
+ Weights assigned to the samples in training data to represent
1057
+ the decision function in kernel space.
1058
+
1059
+ fit_X_ : ndarray
1060
+ Training data.
1061
+
1062
+ optimizer_result_ : :class:`scipy.optimize.OptimizeResult`
1063
+ Stats returned by the optimizer. See :class:`scipy.optimize.OptimizeResult`.
1064
+
1065
+ n_features_in_ : int
1066
+ Number of features seen during ``fit``.
1067
+
1068
+ feature_names_in_ : ndarray of shape (`n_features_in_`,)
1069
+ Names of features seen during ``fit``. Defined only when `X`
1070
+ has feature names that are all strings.
1071
+
1072
+ n_iter_ : int
1073
+ Number of iterations run by the optimization routine to fit the model.
1074
+
1075
+ See also
1076
+ --------
1077
+ FastSurvivalSVM
1078
+ Fast implementation for linear kernel.
1079
+
1080
+ References
1081
+ ----------
1082
+ .. [1] Pölsterl, S., Navab, N., and Katouzian, A.,
1083
+ *An Efficient Training Algorithm for Kernel Survival Support Vector Machines*
1084
+ 4th Workshop on Machine Learning in Life Sciences,
1085
+ 23 September 2016, Riva del Garda, Italy. arXiv:1611.07054
1086
+ """
1087
+
1088
    # Allowed values for each hyper-parameter; extends the linear model's
    # constraints with kernel-related parameters and restricts the optimizer
    # to the tree-based variants supported in kernel space. Checked by
    # scikit-learn's parameter validation machinery when ``fit`` is called.
    _parameter_constraints = {
        **FastSurvivalSVM._parameter_constraints,
        "kernel": [
            StrOptions(set(PAIRWISE_KERNEL_FUNCTIONS.keys()) | {"precomputed"}),
            callable,
        ],
        "gamma": [Interval(Real, 0.0, None, closed="left"), None],
        "degree": [Interval(Integral, 0, None, closed="left")],
        "coef0": [Interval(Real, None, None, closed="neither")],
        "kernel_params": [dict, None],
        "optimizer": [StrOptions({"rbtree", "avltree"}), None],
    }
1100
+
1101
+ def __init__(
1102
+ self,
1103
+ alpha=1,
1104
+ *,
1105
+ rank_ratio=1.0,
1106
+ fit_intercept=False,
1107
+ kernel="rbf",
1108
+ gamma=None,
1109
+ degree=3,
1110
+ coef0=1,
1111
+ kernel_params=None,
1112
+ max_iter=20,
1113
+ verbose=False,
1114
+ tol=None,
1115
+ optimizer=None,
1116
+ random_state=None,
1117
+ timeit=False,
1118
+ ):
1119
+ super().__init__(
1120
+ alpha=alpha,
1121
+ rank_ratio=rank_ratio,
1122
+ fit_intercept=fit_intercept,
1123
+ max_iter=max_iter,
1124
+ verbose=verbose,
1125
+ tol=tol,
1126
+ optimizer=optimizer,
1127
+ random_state=random_state,
1128
+ timeit=timeit,
1129
+ )
1130
+ self.kernel = kernel
1131
+ self.gamma = gamma
1132
+ self.degree = degree
1133
+ self.coef0 = coef0
1134
+ self.kernel_params = kernel_params
1135
+
1136
+ def _more_tags(self):
1137
+ # tell sklearn.utils.metaestimators._safe_split function that we expect kernel matrix
1138
+ return {"pairwise": self.kernel == "precomputed"}
1139
+
1140
+ def _get_kernel(self, X, Y=None):
1141
+ if callable(self.kernel):
1142
+ params = self.kernel_params or {}
1143
+ else:
1144
+ params = {"gamma": self.gamma, "degree": self.degree, "coef0": self.coef0}
1145
+ return pairwise_kernels(X, Y, metric=self.kernel, filter_params=True, **params)
1146
+
1147
+ def _create_optimizer(self, kernel_mat, y, status):
1148
+ if self.optimizer is None:
1149
+ self.optimizer = "rbtree"
1150
+
1151
+ times, ranks = y
1152
+
1153
+ if self.optimizer == "rbtree":
1154
+ optimizer = NonlinearLargeScaleOptimizer(
1155
+ self.alpha,
1156
+ self.rank_ratio,
1157
+ self.fit_intercept,
1158
+ OrderStatisticTreeSurvivalCounter(kernel_mat, ranks, status, RBTree, times),
1159
+ timeit=self.timeit,
1160
+ )
1161
+ elif self.optimizer == "avltree":
1162
+ optimizer = NonlinearLargeScaleOptimizer(
1163
+ self.alpha,
1164
+ self.rank_ratio,
1165
+ self.fit_intercept,
1166
+ OrderStatisticTreeSurvivalCounter(kernel_mat, ranks, status, AVLTree, times),
1167
+ timeit=self.timeit,
1168
+ )
1169
+
1170
+ return optimizer
1171
+
1172
+ def _validate_for_fit(self, X):
1173
+ if self.kernel != "precomputed":
1174
+ return super()._validate_for_fit(X)
1175
+ return X
1176
+
1177
    def _fit(self, X, time, event, samples_order):
        """Fit the model in kernel space on samples sorted by survival time.

        Computes the kernel matrix of the training data, runs the optimizer
        on the rows/columns permuted according to ``samples_order``, and
        undoes that permutation on the resulting coefficients so they align
        with the original ordering of ``X``.
        """
        # don't reorder X here, because it might be a precomputed kernel matrix
        kernel_mat = self._get_kernel(X)
        # The solver assumes a symmetric kernel matrix; reject anything
        # beyond small numerical asymmetry.
        if (np.abs(kernel_mat.T - kernel_mat) > 1e-12).any():
            raise ValueError("kernel matrix is not symmetric")

        # Times paired with their ranks after sorting; events in sorted order.
        data_y = (time[samples_order], np.arange(len(samples_order)))
        status = event[samples_order]

        # Permute both axes of the kernel matrix to match the sorted samples.
        optimizer = self._create_optimizer(kernel_mat[np.ix_(samples_order, samples_order)], data_y, status)
        opt_result = optimizer.run(tol=self.tol, options={"maxiter": self.max_iter, "disp": self.verbose})

        # reorder coefficients according to order in original training data,
        # i.e., reverse ordering according to samples_order
        self.fit_X_ = X
        if self.fit_intercept:
            # opt_result.x[0] is the intercept; only the coefficients at
            # indices 1..n are permuted back.
            opt_result.x[samples_order + 1] = opt_result.x[1:].copy()
        else:
            opt_result.x[samples_order] = opt_result.x.copy()

        return opt_result
1198
+
1199
+ def predict(self, X):
1200
+ """Rank samples according to survival times
1201
+
1202
+ Lower ranks indicate shorter survival, higher ranks longer survival.
1203
+
1204
+ Parameters
1205
+ ----------
1206
+ X : array-like, shape = (n_samples, n_features)
1207
+ The input samples.
1208
+
1209
+ Returns
1210
+ -------
1211
+ y : ndarray, shape = (n_samples,)
1212
+ Predicted ranks.
1213
+ """
1214
+ X = self._validate_data(X, reset=False)
1215
+ kernel_mat = self._get_kernel(X, self.fit_X_)
1216
+
1217
+ val = np.dot(kernel_mat, self.coef_)
1218
+ if hasattr(self, "intercept_"):
1219
+ val += self.intercept_
1220
+
1221
+ # Order by increasing survival time if objective is pure ranking
1222
+ if self.rank_ratio == 1:
1223
+ val *= -1
1224
+ else:
1225
+ # model was fitted on log(time), transform to original scale
1226
+ val = np.exp(val)
1227
+
1228
+ return val