scikit-survival 0.23.1__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. scikit_survival-0.23.1.dist-info/COPYING +674 -0
  2. scikit_survival-0.23.1.dist-info/METADATA +888 -0
  3. scikit_survival-0.23.1.dist-info/RECORD +55 -0
  4. scikit_survival-0.23.1.dist-info/WHEEL +5 -0
  5. scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +138 -0
  7. sksurv/base.py +103 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cpython-313-darwin.so +0 -0
  10. sksurv/column.py +201 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +10 -0
  13. sksurv/datasets/base.py +436 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  17. sksurv/datasets/data/flchain.arff +7887 -0
  18. sksurv/datasets/data/veteran.arff +148 -0
  19. sksurv/datasets/data/whas500.arff +520 -0
  20. sksurv/ensemble/__init__.py +2 -0
  21. sksurv/ensemble/_coxph_loss.cpython-313-darwin.so +0 -0
  22. sksurv/ensemble/boosting.py +1610 -0
  23. sksurv/ensemble/forest.py +947 -0
  24. sksurv/ensemble/survival_loss.py +151 -0
  25. sksurv/exceptions.py +18 -0
  26. sksurv/functions.py +114 -0
  27. sksurv/io/__init__.py +2 -0
  28. sksurv/io/arffread.py +58 -0
  29. sksurv/io/arffwrite.py +145 -0
  30. sksurv/kernels/__init__.py +1 -0
  31. sksurv/kernels/_clinical_kernel.cpython-313-darwin.so +0 -0
  32. sksurv/kernels/clinical.py +328 -0
  33. sksurv/linear_model/__init__.py +3 -0
  34. sksurv/linear_model/_coxnet.cpython-313-darwin.so +0 -0
  35. sksurv/linear_model/aft.py +205 -0
  36. sksurv/linear_model/coxnet.py +543 -0
  37. sksurv/linear_model/coxph.py +618 -0
  38. sksurv/meta/__init__.py +4 -0
  39. sksurv/meta/base.py +35 -0
  40. sksurv/meta/ensemble_selection.py +642 -0
  41. sksurv/meta/stacking.py +349 -0
  42. sksurv/metrics.py +996 -0
  43. sksurv/nonparametric.py +588 -0
  44. sksurv/preprocessing.py +155 -0
  45. sksurv/svm/__init__.py +11 -0
  46. sksurv/svm/_minlip.cpython-313-darwin.so +0 -0
  47. sksurv/svm/_prsvm.cpython-313-darwin.so +0 -0
  48. sksurv/svm/minlip.py +606 -0
  49. sksurv/svm/naive_survival_svm.py +221 -0
  50. sksurv/svm/survival_svm.py +1228 -0
  51. sksurv/testing.py +108 -0
  52. sksurv/tree/__init__.py +1 -0
  53. sksurv/tree/_criterion.cpython-313-darwin.so +0 -0
  54. sksurv/tree/tree.py +703 -0
  55. sksurv/util.py +333 -0
sksurv/metrics.py ADDED
@@ -0,0 +1,996 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ import numpy as np
14
+ from sklearn.base import BaseEstimator
15
+ from sklearn.utils import check_array, check_consistent_length
16
+ from sklearn.utils.metaestimators import available_if
17
+ from sklearn.utils.validation import check_is_fitted
18
+
19
+ from .exceptions import NoComparablePairException
20
+ from .nonparametric import CensoringDistributionEstimator, SurvivalFunctionEstimator
21
+ from .util import check_y_survival
22
+
23
# Public API of this module.
__all__ = [
    "as_concordance_index_ipcw_scorer",
    "as_cumulative_dynamic_auc_scorer",
    "as_integrated_brier_score_scorer",
    "brier_score",
    "concordance_index_censored",
    "concordance_index_ipcw",
    "cumulative_dynamic_auc",
    "integrated_brier_score",
]
33
+
34
+
35
def _check_estimate_1d(estimate, test_time):
    """Coerce ``estimate`` to a validated 1-d array of the same length as ``test_time``.

    Raises ``ValueError`` if the array is not one-dimensional.
    """
    estimate = check_array(estimate, ensure_2d=False, input_name="estimate")
    if estimate.ndim != 1:
        msg = f"Expected 1D array, got {estimate.ndim}D array instead:\narray={estimate}.\n"
        raise ValueError(msg)
    check_consistent_length(test_time, estimate)
    return estimate
41
+
42
+
43
def _check_inputs(event_indicator, event_time, estimate):
    """Validate event indicators, event times, and risk estimates.

    Returns the three inputs as validated arrays. Raises ``ValueError``
    for a non-boolean event indicator, fewer than two samples, or data
    where every sample is censored.
    """
    check_consistent_length(event_indicator, event_time, estimate)
    event_indicator = check_array(event_indicator, ensure_2d=False, input_name="event_indicator")
    event_time = check_array(event_time, ensure_2d=False, input_name="event_time")
    estimate = _check_estimate_1d(estimate, event_time)

    # the event indicator must be boolean, not 0/1 integers
    if not np.issubdtype(event_indicator.dtype, np.bool_):
        msg = "only boolean arrays are supported as class labels for survival analysis, " f"got {event_indicator.dtype}"
        raise ValueError(msg)

    if len(event_time) < 2:
        raise ValueError("Need a minimum of two samples")

    # at least one event is required, otherwise no pair is comparable
    if not event_indicator.any():
        raise ValueError("All samples are censored")

    return event_indicator, event_time, estimate
61
+
62
+
63
def _check_times(test_time, times):
    """Return the unique, sorted time points, verified to lie within follow-up.

    Valid time points must fall in the half-open interval
    ``[min(test_time); max(test_time))``.
    """
    times = np.unique(check_array(np.atleast_1d(times), ensure_2d=False, input_name="times"))

    t_min = test_time.min()
    t_max = test_time.max()
    if times.max() >= t_max or times.min() < t_min:
        raise ValueError(f"all times must be within follow-up time of test data: [{t_min}; {t_max}[")

    return times
73
+
74
+
75
def _check_estimate_2d(estimate, test_time, time_points, estimator):
    """Validate risk estimates evaluated at one or more time points.

    Returns the validated estimate array and the unique, range-checked
    time points.
    """
    estimate = check_array(estimate, ensure_2d=False, allow_nd=False, input_name="estimate", estimator=estimator)
    time_points = _check_times(test_time, time_points)
    check_consistent_length(test_time, estimate)

    if estimate.ndim == 2:
        n_expected = time_points.shape[0]
        n_actual = estimate.shape[1]
        # a 2-d estimate must provide one column per evaluation time point
        if n_actual != n_expected:
            raise ValueError(f"expected estimate with {n_expected} columns, but got {n_actual}")

    return estimate, time_points
84
+
85
+
86
+ def _iter_comparable(event_indicator, event_time, order):
87
+ n_samples = len(event_time)
88
+ tied_time = 0
89
+ i = 0
90
+ while i < n_samples - 1:
91
+ time_i = event_time[order[i]]
92
+ end = i + 1
93
+ while end < n_samples and event_time[order[end]] == time_i:
94
+ end += 1
95
+
96
+ # check for tied event times
97
+ event_at_same_time = event_indicator[order[i:end]]
98
+ censored_at_same_time = ~event_at_same_time
99
+
100
+ mask = np.zeros(n_samples, dtype=bool)
101
+ mask[end:] = True
102
+ # an event is comparable to censored samples at same time point
103
+ mask[i:end] = censored_at_same_time
104
+
105
+ for j in range(i, end):
106
+ if event_indicator[order[j]]:
107
+ tied_time += censored_at_same_time.sum()
108
+ yield (j, mask, tied_time)
109
+ i = end
110
+
111
+
112
def _estimate_concordance_index(event_indicator, event_time, estimate, weights, tied_tol=1e-8):
    """Compute the weighted concordance index over all comparable pairs.

    Returns the tuple ``(cindex, concordant, discordant, tied_risk, tied_time)``.
    Raises :exc:`NoComparablePairException` if the data contains no
    comparable pairs at all.
    """
    order = np.argsort(event_time)

    # stays None iff the generator below never yields
    tied_time = None

    concordant = 0
    discordant = 0
    tied_risk = 0
    numerator = 0.0
    denominator = 0.0
    for ind, mask, tied_time in _iter_comparable(event_indicator, event_time, order):
        score_i = estimate[order[ind]]
        w_i = weights[order[ind]]

        assert event_indicator[order[ind]], f"got censored sample at index {order[ind]}, but expected uncensored"

        other_scores = estimate[order[mask]]

        is_tie = np.absolute(other_scores - score_i) <= tied_tol
        n_ties = is_tie.sum()
        # the sample that experienced the event should have the higher risk score
        is_con = other_scores < score_i
        n_con = is_con[~is_tie].sum()

        # ties in risk contribute half a concordant pair
        numerator += w_i * n_con + 0.5 * w_i * n_ties
        denominator += w_i * mask.sum()

        tied_risk += n_ties
        concordant += n_con
        discordant += other_scores.size - n_con - n_ties

    if tied_time is None:
        raise NoComparablePairException("Data has no comparable pairs, cannot estimate concordance index.")

    cindex = numerator / denominator
    return cindex, concordant, discordant, tied_risk, tied_time
149
+
150
+
151
def concordance_index_censored(event_indicator, event_time, estimate, tied_tol=1e-8):
    """Concordance index for right-censored data

    The concordance index is defined as the proportion of all comparable pairs
    in which the predictions and outcomes are concordant.

    Two samples are comparable if (i) both of them experienced an event (at different times),
    or (ii) the one with a shorter observed survival time experienced an event, in which case
    the event-free subject "outlived" the other. A pair is not comparable if they experienced
    events at the same time.

    Concordance intuitively means that two samples were ordered correctly by the model.
    More specifically, two samples are concordant, if the one with a higher estimated
    risk score has a shorter actual survival time.
    When predicted risks are identical for a pair, 0.5 rather than 1 is added to the count
    of concordant pairs.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb>`
    and [1]_ for further description.

    Parameters
    ----------
    event_indicator : array-like, shape = (n_samples,)
        Boolean array denotes whether an event occurred

    event_time : array-like, shape = (n_samples,)
        Array containing the time of an event or time of censoring

    estimate : array-like, shape = (n_samples,)
        Estimated risk of experiencing an event

    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties.
        If the absolute difference between risk scores is smaller
        or equal than `tied_tol`, risk scores are considered tied.

    Returns
    -------
    cindex : float
        Concordance index

    concordant : int
        Number of concordant pairs

    discordant : int
        Number of discordant pairs

    tied_risk : int
        Number of pairs having tied estimated risks

    tied_time : int
        Number of comparable pairs sharing the same time

    See also
    --------
    concordance_index_ipcw
        Alternative estimator of the concordance index with less bias.

    References
    ----------
    .. [1] Harrell, F.E., Califf, R.M., Pryor, D.B., Lee, K.L., Rosati, R.A,
           "Multivariable prognostic models: issues in developing models,
           evaluating assumptions and adequacy, and measuring and reducing errors",
           Statistics in Medicine, 15(4), 361-87, 1996.
    """
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)

    # Harrell's estimator weighs every comparable pair equally.
    w = np.ones_like(estimate)

    return _estimate_concordance_index(event_indicator, event_time, estimate, w, tied_tol)
221
+
222
+
223
def concordance_index_ipcw(survival_train, survival_test, estimate, tau=None, tied_tol=1e-8):
    """Concordance index for right-censored data based on inverse probability of censoring weights.

    This is an alternative to the estimator in :func:`concordance_index_censored`
    that does not depend on the distribution of censoring times in the test data.
    Therefore, the estimate is unbiased and consistent for a population concordance
    measure that is free of censoring.

    It is based on inverse probability of censoring weights, thus requires
    access to survival times from the training data to estimate the censoring
    distribution. Note that this requires that survival times `survival_test`
    lie within the range of survival times `survival_train`. This can be
    achieved by specifying the truncation time `tau`.
    The resulting `cindex` tells how well the given prediction model works in
    predicting events that occur in the time range from 0 to `tau`.

    The estimator uses the Kaplan-Meier estimator to estimate the
    censoring survivor function. Therefore, it is restricted to
    situations where the random censoring assumption holds and
    censoring is independent of the features.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb>`
    and [1]_ for further description.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for training data to estimate the censoring
        distribution from.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    survival_test : structured array, shape = (n_samples,)
        Survival times of test data.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    estimate : array-like, shape = (n_samples,)
        Estimated risk of experiencing an event of test data.

    tau : float, optional
        Truncation time. The survival function for the underlying
        censoring time distribution :math:`D` needs to be positive
        at `tau`, i.e., `tau` should be chosen such that the
        probability of being censored after time `tau` is non-zero:
        :math:`P(D > \\tau) > 0`. If `None`, no truncation is performed.

    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties.
        If the absolute difference between risk scores is smaller
        or equal than `tied_tol`, risk scores are considered tied.

    Returns
    -------
    cindex : float
        Concordance index

    concordant : int
        Number of concordant pairs

    discordant : int
        Number of discordant pairs

    tied_risk : int
        Number of pairs having tied estimated risks

    tied_time : int
        Number of comparable pairs sharing the same time

    See also
    --------
    concordance_index_censored
        Simpler estimator of the concordance index.

    as_concordance_index_ipcw_scorer
        Wrapper class that uses :func:`concordance_index_ipcw`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] Uno, H., Cai, T., Pencina, M. J., D’Agostino, R. B., & Wei, L. J. (2011).
           "On the C-statistics for evaluating overall adequacy of risk prediction
           procedures with censored survival data".
           Statistics in Medicine, 30(10), 1105–1117.
    """
    test_event, test_time = check_y_survival(survival_test)

    if tau is not None:
        # restrict IPCW estimation to samples observed before the truncation time
        mask = test_time < tau
        survival_test = survival_test[mask]

    estimate = _check_estimate_1d(estimate, test_time)

    # estimate the censoring distribution from the training data (Kaplan-Meier)
    cens = CensoringDistributionEstimator()
    cens.fit(survival_train)
    ipcw_test = cens.predict_ipcw(survival_test)
    if tau is None:
        ipcw = ipcw_test
    else:
        # samples at or beyond tau receive zero weight and thus do not
        # contribute to the concordance estimate
        ipcw = np.empty(estimate.shape[0], dtype=ipcw_test.dtype)
        ipcw[mask] = ipcw_test
        ipcw[~mask] = 0

    # squared IPCW weights, following Uno et al. (2011) (reference [1])
    w = np.square(ipcw)

    return _estimate_concordance_index(test_event, test_time, estimate, w, tied_tol)
332
+
333
+
334
def cumulative_dynamic_auc(survival_train, survival_test, estimate, times, tied_tol=1e-8):
    """Estimator of cumulative/dynamic AUC for right-censored time-to-event data.

    The receiver operating characteristic (ROC) curve and the area under the
    ROC curve (AUC) can be extended to survival data by defining
    sensitivity (true positive rate) and specificity (true negative rate)
    as time-dependent measures. *Cumulative cases* are all individuals that
    experienced an event prior to or at time :math:`t` (:math:`t_i \\leq t`),
    whereas *dynamic controls* are those with :math:`t_i > t`.
    The associated cumulative/dynamic AUC quantifies how well a model can
    distinguish subjects who fail by a given time (:math:`t_i \\leq t`) from
    subjects who fail after this time (:math:`t_i > t`).

    Given an estimator of the :math:`i`-th individual's risk score
    :math:`\\hat{f}(\\mathbf{x}_i)`, the cumulative/dynamic AUC at time
    :math:`t` is defined as

    .. math::

        \\widehat{\\mathrm{AUC}}(t) =
        \\frac{\\sum_{i=1}^n \\sum_{j=1}^n I(y_j > t) I(y_i \\leq t) \\omega_i
        I(\\hat{f}(\\mathbf{x}_j) \\leq \\hat{f}(\\mathbf{x}_i))}
        {(\\sum_{i=1}^n I(y_i > t)) (\\sum_{i=1}^n I(y_i \\leq t) \\omega_i)}

    where :math:`\\omega_i` are inverse probability of censoring weights (IPCW).

    To estimate IPCW, access to survival times from the training data is required
    to estimate the censoring distribution. Note that this requires that survival
    times `survival_test` lie within the range of survival times `survival_train`.
    This can be achieved by specifying `times` accordingly, e.g. by setting
    `times[-1]` slightly below the maximum expected follow-up time.
    IPCW are computed using the Kaplan-Meier estimator, which is
    restricted to situations where the random censoring assumption holds and
    censoring is independent of the features.

    This function can also be used to evaluate models with time-dependent predictions
    :math:`\\hat{f}(\\mathbf{x}_i, t)`, such as :class:`sksurv.ensemble.RandomSurvivalForest`
    (see :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Time-dependent-Risk-Scores>`).
    In this case, `estimate` must be a 2-d array where ``estimate[i, j]`` is the
    predicted risk score for the i-th instance at time point ``times[j]``.

    Finally, the function also provides a single summary measure that refers to the mean
    of the :math:`\\mathrm{AUC}(t)` over the time range :math:`(\\tau_1, \\tau_2)`.

    .. math::

        \\overline{\\mathrm{AUC}}(\\tau_1, \\tau_2) =
        \\frac{1}{\\hat{S}(\\tau_1) - \\hat{S}(\\tau_2)}
        \\int_{\\tau_1}^{\\tau_2} \\widehat{\\mathrm{AUC}}(t)\\,d \\hat{S}(t)

    where :math:`\\hat{S}(t)` is the Kaplan–Meier estimator of the survival function.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Area-under-the-ROC>`,
    [1]_, [2]_, [3]_ for further description.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for training data to estimate the censoring
        distribution from.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    survival_test : structured array, shape = (n_samples,)
        Survival times of test data.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    estimate : array-like, shape = (n_samples,) or (n_samples, n_times)
        Estimated risk of experiencing an event of test data.
        If `estimate` is a 1-d array, the same risk score across all time
        points is used. If `estimate` is a 2-d array, the risk scores in the
        j-th column are used to evaluate the j-th time point.

    times : array-like, shape = (n_times,)
        The time points for which the area under the
        time-dependent ROC curve is computed. Values must be
        within the range of follow-up times of the test data
        `survival_test`.

    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties.
        If the absolute difference between risk scores is smaller
        or equal than `tied_tol`, risk scores are considered tied.

    Returns
    -------
    auc : array, shape = (n_times,)
        The cumulative/dynamic AUC estimates (evaluated at `times`).
    mean_auc : float
        Summary measure referring to the mean cumulative/dynamic AUC
        over the specified time range `(times[0], times[-1])`.

    See also
    --------
    as_cumulative_dynamic_auc_scorer
        Wrapper class that uses :func:`cumulative_dynamic_auc`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] H. Uno, T. Cai, L. Tian, and L. J. Wei,
           "Evaluating prediction rules for t-year survivors with censored regression models,"
           Journal of the American Statistical Association, vol. 102, pp. 527–537, 2007.
    .. [2] H. Hung and C. T. Chiang,
           "Estimation methods for time-dependent AUC models with survival data,"
           Canadian Journal of Statistics, vol. 38, no. 1, pp. 8–26, 2010.
    .. [3] J. Lambert and S. Chevret,
           "Summary measure of discrimination in survival models based on cumulative/dynamic time-dependent ROC curves,"
           Statistical Methods in Medical Research, 2014.
    """
    test_event, test_time = check_y_survival(survival_test)
    estimate, times = _check_estimate_2d(estimate, test_time, times, estimator="cumulative_dynamic_auc")

    n_samples = estimate.shape[0]
    n_times = times.shape[0]
    # a 1-d estimate means the same risk score is used at every time point
    if estimate.ndim == 1:
        estimate = np.broadcast_to(estimate[:, np.newaxis], (n_samples, n_times))

    # fit and transform IPCW
    cens = CensoringDistributionEstimator()
    cens.fit(survival_train)
    ipcw = cens.predict_ipcw(survival_test)

    # expand arrays to (n_samples, n_times) shape
    test_time = np.broadcast_to(test_time[:, np.newaxis], (n_samples, n_times))
    test_event = np.broadcast_to(test_event[:, np.newaxis], (n_samples, n_times))
    times_2d = np.broadcast_to(times, (n_samples, n_times))
    ipcw = np.broadcast_to(ipcw[:, np.newaxis], (n_samples, n_times))

    # sort each time point (columns) by risk score (descending)
    o = np.argsort(-estimate, axis=0)
    test_time = np.take_along_axis(test_time, o, axis=0)
    test_event = np.take_along_axis(test_event, o, axis=0)
    estimate = np.take_along_axis(estimate, o, axis=0)
    ipcw = np.take_along_axis(ipcw, o, axis=0)

    # cases: event at or before t; controls: event-free beyond t
    is_case = (test_time <= times_2d) & test_event
    is_control = test_time > times_2d
    n_controls = is_control.sum(axis=0)

    # prepend row of infinity values
    # so that is_tied marks entries whose score equals the one above them
    estimate_diff = np.concatenate((np.broadcast_to(np.inf, (1, n_times)), estimate))
    is_tied = np.absolute(np.diff(estimate_diff, axis=0)) <= tied_tol

    # cumulative TPR (IPCW-weighted) and FPR along descending score order
    cumsum_tp = np.cumsum(is_case * ipcw, axis=0)
    cumsum_fp = np.cumsum(is_control, axis=0)
    true_pos = cumsum_tp / cumsum_tp[-1]
    false_pos = cumsum_fp / n_controls

    scores = np.empty(n_times, dtype=float)
    # order="F" with external_loop yields one full column (time point) per iteration
    it = np.nditer((true_pos, false_pos, is_tied), order="F", flags=["external_loop"])
    with it:
        for i, (tp, fp, mask) in enumerate(it):
            idx = np.flatnonzero(mask) - 1
            # only keep the last estimate for tied risk scores
            tp_no_ties = np.delete(tp, idx)
            fp_no_ties = np.delete(fp, idx)
            # Add an extra threshold position
            # to make sure that the curve starts at (0, 0)
            tp_no_ties = np.r_[0, tp_no_ties]
            fp_no_ties = np.r_[0, fp_no_ties]
            # AUC at this time point via the trapezoidal rule
            scores[i] = np.trapz(tp_no_ties, fp_no_ties)

    if n_times == 1:
        mean_auc = scores[0]
    else:
        surv = SurvivalFunctionEstimator()
        surv.fit(survival_test)
        s_times = surv.predict_proba(times)
        # compute integral of AUC over survival function
        d = -np.diff(np.r_[1.0, s_times])
        integral = (scores * d).sum()
        mean_auc = integral / (1.0 - s_times[-1])

    return scores, mean_auc
513
+
514
+
515
def brier_score(survival_train, survival_test, estimate, times):
    """Estimate the time-dependent Brier score for right censored data.

    The time-dependent Brier score is the mean squared error at time point :math:`t`:

    .. math::

        \\mathrm{BS}^c(t) = \\frac{1}{n} \\sum_{i=1}^n I(y_i \\leq t \\land \\delta_i = 1)
        \\frac{(0 - \\hat{\\pi}(t | \\mathbf{x}_i))^2}{\\hat{G}(y_i)} + I(y_i > t)
        \\frac{(1 - \\hat{\\pi}(t | \\mathbf{x}_i))^2}{\\hat{G}(t)} ,

    where :math:`\\hat{\\pi}(t | \\mathbf{x})` is the predicted probability of
    remaining event-free up to time point :math:`t` for a feature vector :math:`\\mathbf{x}`,
    and :math:`1/\\hat{G}(t)` is a inverse probability of censoring weight, estimated by
    the Kaplan-Meier estimator.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Brier-Score>`
    and [1]_ for details.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for training data to estimate the censoring
        distribution from.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    survival_test : structured array, shape = (n_samples,)
        Survival times of test data.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    estimate : array-like, shape = (n_samples, n_times)
        Estimated probability of remaining event-free at time points
        specified by `times`. The value of ``estimate[i]`` must correspond to
        the estimated probability of remaining event-free up to the time point
        ``times[i]``. Typically, estimated probabilities are obtained via the
        survival function returned by an estimator's
        ``predict_survival_function`` method.

    times : array-like, shape = (n_times,)
        The time points for which to estimate the Brier score.
        Values must be within the range of follow-up times of
        the test data `survival_test`.

    Returns
    -------
    times : array, shape = (n_times,)
        Unique time points at which the brier scores was estimated.

    brier_scores : array , shape = (n_times,)
        Values of the brier score.

    Examples
    --------
    >>> from sksurv.datasets import load_gbsg2
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.metrics import brier_score
    >>> from sksurv.preprocessing import OneHotEncoder

    Load and prepare data.

    >>> X, y = load_gbsg2()
    >>> X.loc[:, "tgrade"] = X.loc[:, "tgrade"].map(len).astype(int)
    >>> Xt = OneHotEncoder().fit_transform(X)

    Fit a Cox model.

    >>> est = CoxPHSurvivalAnalysis(ties="efron").fit(Xt, y)

    Retrieve individual survival functions and get probability
    of remaining event free up to 5 years (=1825 days).

    >>> survs = est.predict_survival_function(Xt)
    >>> preds = [fn(1825) for fn in survs]

    Compute the Brier score at 5 years.

    >>> times, score = brier_score(y, y, preds, 1825)
    >>> print(score)
    [0.20881843]

    See also
    --------
    integrated_brier_score
        Computes the average Brier score over all time points.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher,
           "Assessment and comparison of prognostic classification schemes for survival data,"
           Statistics in Medicine, vol. 18, no. 17-18, pp. 2529–2545, 1999.
    """
    test_event, test_time = check_y_survival(survival_test)
    estimate, times = _check_estimate_2d(estimate, test_time, times, estimator="brier_score")
    # a single time point may arrive as a 1-d estimate; make it a column
    if estimate.ndim == 1 and times.shape[0] == 1:
        estimate = estimate.reshape(-1, 1)

    # fit IPCW estimator
    cens = CensoringDistributionEstimator().fit(survival_train)
    # calculate inverse probability of censoring weight at current time point t.
    # zero probabilities become inf so the corresponding terms contribute 0 below
    prob_cens_t = cens.predict_proba(times)
    prob_cens_t[prob_cens_t == 0] = np.inf
    # calculate inverse probability of censoring weights at observed time point
    prob_cens_y = cens.predict_proba(test_time)
    prob_cens_y[prob_cens_y == 0] = np.inf

    # Calculating the brier scores at each time point
    brier_scores = np.empty(times.shape[0], dtype=float)
    for i, t in enumerate(times):
        est = estimate[:, i]
        is_case = (test_time <= t) & test_event
        is_control = test_time > t

        # cases: squared distance to 0; controls: squared distance to 1,
        # each weighted by the respective IPCW term
        brier_scores[i] = np.mean(
            np.square(est) * is_case.astype(int) / prob_cens_y
            + np.square(1.0 - est) * is_control.astype(int) / prob_cens_t[i]
        )

    return times, brier_scores
637
+
638
+
639
def integrated_brier_score(survival_train, survival_test, estimate, times):
    """The Integrated Brier Score (IBS) provides an overall calculation of
    the model performance at all available times :math:`t_1 \\leq t \\leq t_\\text{max}`.

    The integrated time-dependent Brier score over the interval
    :math:`[t_1; t_\\text{max}]` is defined as

    .. math::

        \\mathrm{IBS} = \\int_{t_1}^{t_\\text{max}} \\mathrm{BS}^c(t) d w(t)

    where the weighting function is :math:`w(t) = t / t_\\text{max}`.
    The integral is estimated via the trapezoidal rule.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Brier-Score>`
    and [1]_ for further details.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for training data to estimate the censoring
        distribution from.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    survival_test : structured array, shape = (n_samples,)
        Survival times of test data.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    estimate : array-like, shape = (n_samples, n_times)
        Estimated probability of remaining event-free at time points
        specified by `times`. The value of ``estimate[i]`` must correspond to
        the estimated probability of remaining event-free up to the time point
        ``times[i]``. Typically, estimated probabilities are obtained via the
        survival function returned by an estimator's
        ``predict_survival_function`` method.

    times : array-like, shape = (n_times,)
        The time points for which to estimate the Brier score.
        Values must be within the range of follow-up times of
        the test data `survival_test`.

    Returns
    -------
    ibs : float
        The integrated Brier score.

    Examples
    --------
    >>> import numpy as np
    >>> from sksurv.datasets import load_gbsg2
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.metrics import integrated_brier_score
    >>> from sksurv.preprocessing import OneHotEncoder

    Load and prepare data.

    >>> X, y = load_gbsg2()
    >>> X.loc[:, "tgrade"] = X.loc[:, "tgrade"].map(len).astype(int)
    >>> Xt = OneHotEncoder().fit_transform(X)

    Fit a Cox model.

    >>> est = CoxPHSurvivalAnalysis(ties="efron").fit(Xt, y)

    Retrieve individual survival functions and get probability
    of remaining event free from 1 year to 5 years (=1825 days).

    >>> survs = est.predict_survival_function(Xt)
    >>> times = np.arange(365, 1826)
    >>> preds = np.asarray([[fn(t) for t in times] for fn in survs])

    Compute the integrated Brier score from 1 to 5 years.

    >>> score = integrated_brier_score(y, y, preds, times)
    >>> print(score)
    0.1815853064627424

    See also
    --------
    brier_score
        Computes the Brier score at specified time points.

    as_integrated_brier_score_scorer
        Wrapper class that uses :func:`integrated_brier_score`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher,
           "Assessment and comparison of prognostic classification schemes for survival data,"
           Statistics in Medicine, vol. 18, no. 17-18, pp. 2529–2545, 1999.
    """
    # Computing the brier scores
    times, brier_scores = brier_score(survival_train, survival_test, estimate, times)

    # trapezoidal integration needs at least two distinct time points
    if times.shape[0] < 2:
        raise ValueError("At least two time points must be given")

    # Computing the IBS: integrate and normalize by the length of the interval
    ibs_value = np.trapz(brier_scores, times) / (times[-1] - times[0])

    return ibs_value
746
+
747
+
748
+ def _estimator_has(attr):
749
+ """Check that meta_estimator has `attr`.
750
+
751
+ Used together with `available_if`."""
752
+
753
+ def check(self):
754
+ # raise original `AttributeError` if `attr` does not exist
755
+ getattr(self.estimator_, attr)
756
+ return True
757
+
758
+ return check
759
+
760
+
761
class _ScoreOverrideMixin:
    """Mixin that replaces an estimator's default ``score`` with a survival metric.

    Predictions are obtained from the wrapped estimator via the method named
    by ``predict_func`` and handed to ``score_func`` together with the stored
    training outcomes and the test outcomes. ``score_index`` selects one entry
    if the metric returns a tuple; ``greater_is_better`` controls the sign so
    that maximizing the score is always correct.
    """

    def __init__(self, estimator, predict_func, score_func, score_index, greater_is_better):
        # Fail early if the wrapped estimator cannot produce the
        # predictions the score function needs.
        if not hasattr(estimator, predict_func):
            raise AttributeError(f"{estimator!r} object has no attribute {predict_func!r}")

        self.estimator = estimator
        # Model selection maximizes scores; flip the sign for metrics
        # where smaller values indicate better performance.
        self._sign = 1 if greater_is_better else -1
        self._score_index = score_index
        self._score_func = score_func
        self._predict_func = predict_func

    def _get_score_params(self):
        """Collect keyword arguments to forward to ``score_func``."""
        kwargs = self.get_params(deep=False)
        # The estimator itself is not a parameter of the metric.
        kwargs.pop("estimator")
        return kwargs

    def fit(self, X, y, **fit_params):
        """Fit the wrapped estimator and keep a copy of the training outcomes.

        The copy of ``y`` is later passed as ``survival_train`` to the
        score function.
        """
        self._train_y = np.array(y, copy=True)
        self.estimator_ = self.estimator.fit(X, y, **fit_params)
        return self

    def _do_predict(self, X):
        # Delegate to the configured prediction method of the fitted estimator.
        return getattr(self.estimator_, self._predict_func)(X)

    def score(self, X, y):
        """Returns the score on the given data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data, where n_samples is the number of samples and
            n_features is the number of features.

        y : structured array, shape = (n_samples,)
            Survival outcomes of the test data, passed to the score
            function as ``survival_test``.

        Returns
        -------
        score : float
        """
        result = self._score_func(
            survival_train=self._train_y,
            survival_test=y,
            estimate=self._do_predict(X),
            **self._get_score_params(),
        )
        # Metrics may return a tuple; pick the configured component.
        value = result if self._score_index is None else result[self._score_index]
        return self._sign * value

    @available_if(_estimator_has("predict"))
    def predict(self, X):
        """Call predict on the estimator.

        Only available if estimator supports ``predict``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict(X)

    @available_if(_estimator_has("predict_cumulative_hazard_function"))
    def predict_cumulative_hazard_function(self, X):
        """Call predict_cumulative_hazard_function on the estimator.

        Only available if estimator supports ``predict_cumulative_hazard_function``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict_cumulative_hazard_function(X)

    @available_if(_estimator_has("predict_survival_function"))
    def predict_survival_function(self, X):
        """Call predict_survival_function on the estimator.

        Only available if estimator supports ``predict_survival_function``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict_survival_function(X)
859
+
860
+
861
class as_cumulative_dynamic_auc_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Wraps an estimator to use :func:`cumulative_dynamic_auc` as ``score`` function.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`
    for using it for hyper-parameter optimization.

    Parameters
    ----------
    estimator : object
        Instance of an estimator.

    times : array-like, shape = (n_times,)
        Time points at which the area under the time-dependent ROC
        curve is evaluated. All values have to lie within the range
        of follow-up times of the test data `survival_test`.

    tied_tol : float, optional, default: 1e-8
        Tolerance for deciding whether two risk scores are tied:
        scores whose absolute difference is at most `tied_tol`
        are considered tied.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    cumulative_dynamic_auc
    """

    def __init__(self, estimator, times, tied_tol=1e-8):
        # cumulative_dynamic_auc returns (auc, mean_auc); score_index=1
        # selects the mean AUC as the scalar score (larger is better).
        super().__init__(
            estimator=estimator,
            predict_func="predict",
            score_func=cumulative_dynamic_auc,
            score_index=1,
            greater_is_better=True,
        )
        self.times = times
        self.tied_tol = tied_tol
903
+
904
+
905
class as_concordance_index_ipcw_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Wraps an estimator to use :func:`concordance_index_ipcw` as ``score`` function.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`
    for using it for hyper-parameter optimization.

    Parameters
    ----------
    estimator : object
        Instance of an estimator.

    tau : float, optional
        Truncation time. The survival function of the underlying
        censoring time distribution :math:`D` must be positive at
        `tau`, i.e., choose `tau` such that the probability of
        being censored after time `tau` is non-zero:
        :math:`P(D > \\tau) > 0`. If `None`, no truncation is performed.

    tied_tol : float, optional, default: 1e-8
        Tolerance for deciding whether two risk scores are tied:
        scores whose absolute difference is at most `tied_tol`
        are considered tied.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    concordance_index_ipcw
    """

    def __init__(self, estimator, tau=None, tied_tol=1e-8):
        # concordance_index_ipcw returns a tuple whose first entry is the
        # concordance estimate; score_index=0 selects it (larger is better).
        super().__init__(
            estimator=estimator,
            predict_func="predict",
            score_func=concordance_index_ipcw,
            score_index=0,
            greater_is_better=True,
        )
        self.tau = tau
        self.tied_tol = tied_tol
948
+
949
+
950
class as_integrated_brier_score_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Wraps an estimator to use the negative of :func:`integrated_brier_score` as ``score`` function.

    The wrapped estimator must be able to estimate survival functions
    via a ``predict_survival_function`` method.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`
    for using it for hyper-parameter optimization.

    Parameters
    ----------
    estimator : object
        Instance of an estimator that provides ``predict_survival_function``.

    times : array-like, shape = (n_times,)
        Time points at which the Brier score is evaluated.
        All values have to lie within the range of follow-up times
        of the test data `survival_test`.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    integrated_brier_score
    """

    def __init__(self, estimator, times):
        # Smaller Brier scores are better, hence greater_is_better=False:
        # score() negates the metric so maximization remains correct.
        super().__init__(
            estimator=estimator,
            predict_func="predict_survival_function",
            score_func=integrated_brier_score,
            score_index=None,
            greater_is_better=False,
        )
        self.times = times

    def _do_predict(self, X):
        """Evaluate each predicted survival function at ``self.times``.

        Returns an array of shape (n_samples, n_times) of event-free
        probabilities, as expected by :func:`integrated_brier_score`.
        """
        eval_times = self.times
        predict_func = getattr(self.estimator_, self._predict_func)
        surv_fns = predict_func(X)
        estimates = np.empty((len(surv_fns), len(eval_times)))
        # Fill one row per sample with the survival probabilities
        # evaluated at the configured time points.
        for row, fn in zip(estimates, surv_fns):
            row[:] = fn(eval_times)
        return estimates