scikit-survival 0.26.0__cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. scikit_survival-0.26.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.26.0.dist-info/RECORD +58 -0
  3. scikit_survival-0.26.0.dist-info/WHEEL +6 -0
  4. scikit_survival-0.26.0.dist-info/licenses/COPYING +674 -0
  5. scikit_survival-0.26.0.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +183 -0
  7. sksurv/base.py +115 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cpython-311-x86_64-linux-gnu.so +0 -0
  10. sksurv/column.py +204 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +12 -0
  13. sksurv/datasets/base.py +614 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/bmt.arff +46 -0
  17. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  18. sksurv/datasets/data/cgvhd.arff +118 -0
  19. sksurv/datasets/data/flchain.arff +7887 -0
  20. sksurv/datasets/data/veteran.arff +148 -0
  21. sksurv/datasets/data/whas500.arff +520 -0
  22. sksurv/docstrings.py +99 -0
  23. sksurv/ensemble/__init__.py +2 -0
  24. sksurv/ensemble/_coxph_loss.cpython-311-x86_64-linux-gnu.so +0 -0
  25. sksurv/ensemble/boosting.py +1564 -0
  26. sksurv/ensemble/forest.py +902 -0
  27. sksurv/ensemble/survival_loss.py +151 -0
  28. sksurv/exceptions.py +18 -0
  29. sksurv/functions.py +114 -0
  30. sksurv/io/__init__.py +2 -0
  31. sksurv/io/arffread.py +91 -0
  32. sksurv/io/arffwrite.py +181 -0
  33. sksurv/kernels/__init__.py +1 -0
  34. sksurv/kernels/_clinical_kernel.cpython-311-x86_64-linux-gnu.so +0 -0
  35. sksurv/kernels/clinical.py +348 -0
  36. sksurv/linear_model/__init__.py +3 -0
  37. sksurv/linear_model/_coxnet.cpython-311-x86_64-linux-gnu.so +0 -0
  38. sksurv/linear_model/aft.py +208 -0
  39. sksurv/linear_model/coxnet.py +592 -0
  40. sksurv/linear_model/coxph.py +637 -0
  41. sksurv/meta/__init__.py +4 -0
  42. sksurv/meta/base.py +35 -0
  43. sksurv/meta/ensemble_selection.py +724 -0
  44. sksurv/meta/stacking.py +370 -0
  45. sksurv/metrics.py +1028 -0
  46. sksurv/nonparametric.py +911 -0
  47. sksurv/preprocessing.py +195 -0
  48. sksurv/svm/__init__.py +11 -0
  49. sksurv/svm/_minlip.cpython-311-x86_64-linux-gnu.so +0 -0
  50. sksurv/svm/_prsvm.cpython-311-x86_64-linux-gnu.so +0 -0
  51. sksurv/svm/minlip.py +695 -0
  52. sksurv/svm/naive_survival_svm.py +249 -0
  53. sksurv/svm/survival_svm.py +1236 -0
  54. sksurv/testing.py +155 -0
  55. sksurv/tree/__init__.py +1 -0
  56. sksurv/tree/_criterion.cpython-311-x86_64-linux-gnu.so +0 -0
  57. sksurv/tree/tree.py +790 -0
  58. sksurv/util.py +416 -0
sksurv/metrics.py ADDED
@@ -0,0 +1,1028 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ import numpy as np
14
+ from sklearn.base import BaseEstimator
15
+ from sklearn.utils.metaestimators import available_if
16
+ from sklearn.utils.validation import check_array, check_consistent_length, check_is_fitted
17
+
18
+ from .exceptions import NoComparablePairException
19
+ from .nonparametric import CensoringDistributionEstimator, SurvivalFunctionEstimator
20
+ from .util import check_y_survival
21
+
22
+ __all__ = [
23
+ "as_concordance_index_ipcw_scorer",
24
+ "as_cumulative_dynamic_auc_scorer",
25
+ "as_integrated_brier_score_scorer",
26
+ "brier_score",
27
+ "concordance_index_censored",
28
+ "concordance_index_ipcw",
29
+ "cumulative_dynamic_auc",
30
+ "integrated_brier_score",
31
+ ]
32
+
33
+
34
def _check_estimate_1d(estimate, test_time):
    """Validate that ``estimate`` is one-dimensional and matches ``test_time`` in length.

    Parameters
    ----------
    estimate : array-like
        Predicted values to validate.
    test_time : array-like
        Reference array whose length ``estimate`` must match.

    Returns
    -------
    estimate : ndarray
        The array returned by :func:`sklearn.utils.validation.check_array`.

    Raises
    ------
    ValueError
        If the validated array is not one-dimensional.
    """
    validated = check_array(estimate, ensure_2d=False, input_name="estimate")
    n_dims = validated.ndim
    if n_dims != 1:
        raise ValueError(f"Expected 1D array, got {n_dims}D array instead:\narray={validated}.\n")
    check_consistent_length(test_time, validated)
    return validated
40
+
41
+
42
def _check_inputs(event_indicator, event_time, estimate):
    """Validate the three inputs shared by the concordance index estimators.

    Parameters
    ----------
    event_indicator : array-like
        Event indicator; must have a boolean dtype after conversion.
    event_time : array-like
        Observed times.
    estimate : array-like
        Predicted risk scores; must be one-dimensional.

    Returns
    -------
    event_indicator : ndarray
    event_time : ndarray
    estimate : ndarray
        The validated arrays.

    Raises
    ------
    ValueError
        If ``event_indicator`` is not boolean, if fewer than two samples
        are given, or if no sample experienced an event.
    """
    check_consistent_length(event_indicator, event_time, estimate)

    events = check_array(event_indicator, ensure_2d=False, input_name="event_indicator")
    observed = check_array(event_time, ensure_2d=False, input_name="event_time")
    scores = _check_estimate_1d(estimate, observed)

    is_boolean = np.issubdtype(events.dtype, np.bool_)
    if not is_boolean:
        raise ValueError(
            f"only boolean arrays are supported as class labels for survival analysis, got {events.dtype}"
        )

    if len(observed) < 2:
        raise ValueError("Need a minimum of two samples")

    if not events.any():
        raise ValueError("All samples are censored")

    return events, observed, scores
60
+
61
+
62
def _check_times(test_time, times):
    """Validate evaluation time points against the follow-up range of the test data.

    Parameters
    ----------
    test_time : ndarray
        Observed times of the test data.
    times : float or array-like
        Requested evaluation time points.

    Returns
    -------
    times : ndarray
        Sorted, de-duplicated time points.

    Raises
    ------
    ValueError
        If any time point lies outside the half-open interval
        ``[test_time.min(), test_time.max())``.
    """
    candidates = check_array(np.atleast_1d(times), ensure_2d=False, input_name="times")
    unique_times = np.unique(candidates)

    too_late = unique_times.max() >= test_time.max()
    too_early = unique_times.min() < test_time.min()
    if too_late or too_early:
        raise ValueError(
            f"all times must be within follow-up time of test data: [{test_time.min()}; {test_time.max()}["
        )

    return unique_times
72
+
73
+
74
def _check_estimate_2d(estimate, test_time, time_points, estimator):
    """Validate time-dependent predictions together with their evaluation time points.

    Parameters
    ----------
    estimate : array-like
        Predictions; a 1D or 2D array is accepted.
    test_time : ndarray
        Observed times of the test data.
    time_points : float or array-like
        Requested evaluation time points.
    estimator : str
        Name forwarded to ``check_array`` for use in its error messages.

    Returns
    -------
    estimate : ndarray
        The validated predictions.
    time_points : ndarray
        The validated, unique time points.

    Raises
    ------
    ValueError
        If ``estimate`` is 2D and its number of columns differs from the
        number of unique time points.
    """
    predictions = check_array(
        estimate, ensure_2d=False, allow_nd=False, input_name="estimate", estimator=estimator
    )
    checked_times = _check_times(test_time, time_points)
    check_consistent_length(test_time, predictions)

    n_expected = checked_times.shape[0]
    # the column count is only checked for 2D input
    if predictions.ndim == 2 and predictions.shape[1] != n_expected:
        raise ValueError(f"expected estimate with {n_expected} columns, but got {predictions.shape[1]}")

    return predictions, checked_times
83
+
84
+
85
+ def _iter_comparable(event_indicator, event_time, order):
86
+ n_samples = len(event_time)
87
+ tied_time = 0
88
+ i = 0
89
+ while i < n_samples - 1:
90
+ time_i = event_time[order[i]]
91
+ end = i + 1
92
+ while end < n_samples and event_time[order[end]] == time_i:
93
+ end += 1
94
+
95
+ # check for tied event times
96
+ event_at_same_time = event_indicator[order[i:end]]
97
+ censored_at_same_time = ~event_at_same_time
98
+
99
+ mask = np.zeros(n_samples, dtype=bool)
100
+ mask[end:] = True
101
+ # an event is comparable to censored samples at same time point
102
+ mask[i:end] = censored_at_same_time
103
+
104
+ for j in range(i, end):
105
+ if event_indicator[order[j]]:
106
+ tied_time += censored_at_same_time.sum()
107
+ yield (j, mask, tied_time)
108
+ i = end
109
+
110
+
111
def _estimate_concordance_index(event_indicator, event_time, estimate, weights, tied_tol=1e-8):
    """Compute a weighted concordance index and the underlying pair counts.

    Parameters
    ----------
    event_indicator : ndarray of bool
        ``True`` where a sample experienced an event.
    event_time : ndarray
        Observed times.
    estimate : ndarray
        Predicted risk scores.
    weights : ndarray
        Per-sample weight applied to every pair anchored at that sample.
    tied_tol : float, optional, default: 1e-8
        Risk scores whose absolute difference is at most ``tied_tol``
        are treated as tied.

    Returns
    -------
    cindex : float
        Weighted concordance index.
    concordant : int
        Number of concordant pairs.
    discordant : int
        Number of discordant pairs.
    tied_risk : int
        Number of pairs with tied risk scores.
    tied_time : int
        Number of comparable pairs with tied observed times.

    Raises
    ------
    NoComparablePairException
        If the data yields no comparable pair.
    """
    order = np.argsort(event_time)

    n_concordant = 0
    n_discordant = 0
    n_tied_risk = 0
    weighted_hits = 0.0
    weighted_pairs = 0.0

    # stays ``None`` only if the iterator below never yields
    tied_time = None
    for ind, mask, tied_time in _iter_comparable(event_indicator, event_time, order):
        anchor = order[ind]
        anchor_score = estimate[anchor]
        anchor_weight = weights[anchor]

        others = estimate[order[mask]]

        assert event_indicator[anchor], f"got censored sample at index {order[ind]}, but expected uncensored"

        is_tied = np.absolute(others - anchor_score) <= tied_tol
        n_ties = is_tied.sum()
        # an event should have a higher score
        is_lower = others < anchor_score
        n_con = is_lower[~is_tied].sum()

        weighted_hits += anchor_weight * n_con + 0.5 * anchor_weight * n_ties
        weighted_pairs += anchor_weight * mask.sum()

        n_tied_risk += n_ties
        n_concordant += n_con
        n_discordant += others.size - n_con - n_ties

    if tied_time is None:
        raise NoComparablePairException("Data has no comparable pairs, cannot estimate concordance index.")

    return weighted_hits / weighted_pairs, n_concordant, n_discordant, n_tied_risk, tied_time
148
+
149
+
150
def concordance_index_censored(event_indicator, event_time, estimate, tied_tol=1e-8):
    """Concordance index between predicted risk scores and observed time-to-event.

    The concordance index is a rank-correlation measure between predicted
    risk scores and observed time points: the proportion of all comparable
    pairs whose predictions and outcomes are concordant. A pair is
    concordant if the sample with the higher risk score has the shorter
    time-to-event. A higher concordance index indicates better model
    performance.

    Two samples form a comparable pair if the sample with the shorter
    survival time experienced an event, because only then is it certain
    that this individual had the worse outcome. Pairs in which both samples
    are censored, or in which both experienced an event at the same time,
    are not comparable.

    If the predicted risks of a pair are identical, 0.5 rather than 1 is
    added to the count of concordant pairs.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb>`
    and [1]_ for further description.

    Parameters
    ----------
    event_indicator : array-like, shape = (n_samples,)
        Boolean array; ``True`` marks an event, ``False`` marks censoring.
    event_time : array-like, shape = (n_samples,)
        Time of the event or time of censoring.
    estimate : array-like, shape = (n_samples,)
        Predicted risk score per sample (e.g., from ``estimator.predict(X)``).
        A higher value indicates a higher risk of experiencing an event.
    tied_tol : float, optional, default: 1e-8
        Tolerance for ties in risk scores. Two risk scores are considered
        tied if their absolute difference is smaller than or equal to
        ``tied_tol``.

    Returns
    -------
    cindex : float
        The concordance index.
    concordant : int
        The number of concordant pairs.
    discordant : int
        The number of discordant pairs.
    tied_risk : int
        The number of pairs with tied risk scores.
    tied_time : int
        The number of comparable pairs with tied survival times.

    Notes
    -----
    This metric expects risk scores, which are typically returned by
    ``estimator.predict(X)``. It *does not accept* survival probabilities.

    See also
    --------
    concordance_index_ipcw
        A less biased estimator of the concordance index.

    References
    ----------
    .. [1] Harrell, F.E., Califf, R.M., Pryor, D.B., Lee, K.L., Rosati, R.A,
           "Multivariable prognostic models: issues in developing models,
           evaluating assumptions and adequacy, and measuring and reducing errors",
           Statistics in Medicine, 15(4), 361-87, 1996.
    """
    events, observed, scores = _check_inputs(event_indicator, event_time, estimate)

    # every sample carries the same weight in the unweighted estimator
    uniform_weights = np.ones_like(scores)

    return _estimate_concordance_index(events, observed, scores, uniform_weights, tied_tol)
221
+
222
+
223
def concordance_index_ipcw(survival_train, survival_test, estimate, tau=None, tied_tol=1e-8):
    r"""Concordance index for right-censored data using inverse probability of censoring weights.

    An alternative to :func:`concordance_index_censored` that does not
    depend on the distribution of censoring times in the test data.
    Using inverse probability of censoring weights (IPCW), it provides an
    unbiased and consistent estimate of the population concordance measure.

    The estimator needs survival times from the training data to estimate
    the censoring distribution. Survival times in `survival_test` must lie
    within the range of survival times in `survival_train`, which can be
    achieved by specifying the truncation time `tau`. The resulting
    `cindex` tells how well the given prediction model works in predicting
    events that occur in the time range from 0 to `tau`.

    For time points in `survival_test` that lie outside of the range
    specified by values in `survival_train`, the probability of censoring
    is unknown and an exception will be raised::

        ValueError: time must be smaller than largest observed time point

    The censoring distribution is estimated using the Kaplan-Meier
    estimator, which assumes that censoring is random and independent of
    the features.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb>`
    and [1]_ for further description.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for the training data, used to estimate the
        censoring distribution.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates
        right-censoring. The second field is a float with the time of
        event or time of censoring.
    survival_test : structured array, shape = (n_samples,)
        Survival times for the test data, in the same two-field layout as
        ``survival_train``.
    estimate : array-like, shape = (n_samples,)
        Predicted risk scores for the test data (e.g., from
        ``estimator.predict(X)``). A higher value indicates a higher risk
        of experiencing an event.
    tau : float, optional
        Truncation time. The survival function for the underlying
        censoring time distribution :math:`D` needs to be positive
        at `tau`, i.e., `tau` should be chosen such that the
        probability of being censored after time `tau` is non-zero:
        :math:`P(D > \tau) > 0`. If `None`, no truncation is performed.
    tied_tol : float, optional, default: 1e-8
        Tolerance for ties in risk scores. Two risk scores are considered
        tied if their absolute difference is smaller than or equal to
        ``tied_tol``.

    Returns
    -------
    cindex : float
        The concordance index.
    concordant : int
        The number of concordant pairs.
    discordant : int
        The number of discordant pairs.
    tied_risk : int
        The number of pairs with tied risk scores.
    tied_time : int
        The number of comparable pairs with tied survival times.

    Notes
    -----
    This metric expects risk scores, which are typically returned by
    ``estimator.predict(X)``. It *does not accept* survival probabilities.

    See also
    --------
    concordance_index_censored
        A simpler, but potentially biased, estimator of the concordance index.
    as_concordance_index_ipcw_scorer
        A wrapper class that uses :func:`concordance_index_ipcw`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] Uno, H., Cai, T., Pencina, M. J., D’Agostino, R. B., & Wei, L. J. (2011).
           "On the C-statistics for evaluating overall adequacy of risk prediction
           procedures with censored survival data".
           Statistics in Medicine, 30(10), 1105–1117.
    """
    test_event, test_time = check_y_survival(survival_test)

    truncate = tau is not None
    if truncate:
        # weights are only estimated for samples observed before tau
        before_tau = test_time < tau
        survival_test = survival_test[before_tau]

    estimate = _check_estimate_1d(estimate, test_time)

    censoring = CensoringDistributionEstimator()
    censoring.fit(survival_train)
    ipcw_kept = censoring.predict_ipcw(survival_test)

    if truncate:
        # samples at or beyond tau get a weight of zero
        ipcw = np.empty(estimate.shape[0], dtype=ipcw_kept.dtype)
        ipcw[~before_tau] = 0
        ipcw[before_tau] = ipcw_kept
    else:
        ipcw = ipcw_kept

    squared_weights = np.square(ipcw)

    return _estimate_concordance_index(test_event, test_time, estimate, squared_weights, tied_tol)
333
+
334
+
335
def cumulative_dynamic_auc(survival_train, survival_test, estimate, times, tied_tol=1e-8):
    r"""Computes the cumulative/dynamic area under the ROC curve (AUC) for right-censored data.

    This metric evaluates a model's performance at specific time points.
    The cumulative/dynamic AUC at time :math:`t` quantifies how well a model can
    distinguish subjects who experience an event by time :math:`t` (cases) from
    those who do not (controls). A higher AUC indicates better model performance.

    This function can also evaluate models with time-dependent predictions, such as
    :class:`sksurv.ensemble.RandomSurvivalForest`
    (see :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Time-dependent-Risk-Scores>`).
    In this case, ``estimate`` must be a 2D array where ``estimate[i, j]`` is the
    predicted risk score for the :math:`i`-th instance at time point ``times[j]``.

    The receiver operating characteristic (ROC) curve and the area under the
    ROC curve (AUC) are metrics to evaluate a binary classifier. Each point on
    the ROC denotes the performance of a binary classifier at a specific
    threshold with respect to the sensitivity (true positive rate) on the
    y-axis and the specificity (true negative rate) on the x-axis.

    ROC and AUC can be extended to survival analysis by defining cases and
    controls based on a time point :math:`t`. *Cumulative cases* are all
    individuals that experienced an event prior to or at time
    :math:`t` (:math:`t_i \leq t`), whereas *dynamic controls* are those
    with :math:`t_i > t`. Given an estimator of the :math:`i`-th individual's
    risk score :math:`\hat{f}(\mathbf{x}_i)`, the cumulative/dynamic AUC at
    time :math:`t` is defined as

    .. math::

        \widehat{\mathrm{AUC}}(t) =
        \frac{\sum_{i=1}^n \sum_{j=1}^n I(y_j > t) I(y_i \leq t) \omega_i
        I(\hat{f}(\mathbf{x}_j) \leq \hat{f}(\mathbf{x}_i))}
        {(\sum_{i=1}^n I(y_i > t)) (\sum_{i=1}^n I(y_i \leq t) \omega_i)}

    where :math:`\omega_i` are inverse probability of censoring weights (IPCW).

    To account for censoring, this metric uses inverse probability of censoring
    weights (IPCW), which requires access to survival times from the training
    data to estimate the censoring distribution. Note that survival times in
    ``survival_test`` must lie within the range of survival times in ``survival_train``.
    This can be achieved by specifying ``times`` accordingly, e.g. by setting
    ``times[-1]`` slightly below the maximum expected follow-up time.

    For time points in ``survival_test`` that lie outside of the range specified by
    values in ``survival_train``, the probability of censoring is unknown and an
    exception will be raised::

        ValueError: time must be smaller than largest observed time point

    The censoring distribution is estimated using the Kaplan-Meier estimator, which
    assumes that censoring is random and independent of the features.

    The function also returns a summary measure, which is the mean of the
    :math:`\mathrm{AUC}(t)` over the specified time range, weighted by the
    estimated survival function:

    .. math::

        \overline{\mathrm{AUC}}(\tau_1, \tau_2) =
        \frac{1}{\hat{S}(\tau_1) - \hat{S}(\tau_2)}
        \int_{\tau_1}^{\tau_2} \widehat{\mathrm{AUC}}(t)\,d \hat{S}(t)

    where :math:`\hat{S}(t)` is the Kaplan–Meier estimator of the survival function.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Area-under-the-ROC>`,
    [1]_, [2]_, [3]_ for further description.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for the training data, used to estimate the censoring
        distribution.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    survival_test : structured array, shape = (n_samples,)
        Survival times for the test data.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    estimate : array-like, shape = (n_samples,) or (n_samples, n_times)
        Predicted risk scores for the test data (e.g., from ``estimator.predict(X)``).
        A higher value indicates a higher risk of experiencing an event.
        If a 1D array is provided, the same risk score is used for all time points.
        If a 2D array is provided, ``estimate[:, j]`` is used for the :math:`j`-th
        time point.
    times : array-like, shape = (n_times,)
        The time points at which to compute the AUC. Values must be within the
        range of follow-up times in ``survival_test``.
    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties in risk scores. If the
        absolute difference between two risk scores is smaller than or equal to
        ``tied_tol``, they are considered tied.

    Returns
    -------
    auc : ndarray, shape = (n_times,)
        The cumulative/dynamic AUC estimates at each time point in ``times``.
    mean_auc : float
        The mean cumulative/dynamic AUC over the specified time range ``(times[0], times[-1])``.

    Notes
    -----
    This metric expects risk scores, which are typically returned by ``estimator.predict(X)``
    (for time-independent risks), or ``estimator.predict_cumulative_hazard_function(X)``
    (for time-dependent risks). It *does not accept* survival probabilities.

    See also
    --------
    as_cumulative_dynamic_auc_scorer
        A wrapper class that uses :func:`cumulative_dynamic_auc`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] H. Uno, T. Cai, L. Tian, and L. J. Wei,
           "Evaluating prediction rules for t-year survivors with censored regression models,"
           Journal of the American Statistical Association, vol. 102, pp. 527–537, 2007.
    .. [2] H. Hung and C. T. Chiang,
           "Estimation methods for time-dependent AUC models with survival data,"
           Canadian Journal of Statistics, vol. 38, no. 1, pp. 8–26, 2010.
    .. [3] J. Lambert and S. Chevret,
           "Summary measure of discrimination in survival models based on cumulative/dynamic time-dependent ROC curves,"
           Statistical Methods in Medical Research, 2014.
    """
    test_event, test_time = check_y_survival(survival_test)
    estimate, times = _check_estimate_2d(estimate, test_time, times, estimator="cumulative_dynamic_auc")

    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
    # prefer the new name and only fall back on NumPy versions that lack it.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz

    n_samples = estimate.shape[0]
    n_times = times.shape[0]
    if estimate.ndim == 1:
        # time-independent risk: reuse the same score for every time point
        estimate = np.broadcast_to(estimate[:, np.newaxis], (n_samples, n_times))

    # fit and transform IPCW
    cens = CensoringDistributionEstimator()
    cens.fit(survival_train)
    ipcw = cens.predict_ipcw(survival_test)

    # expand arrays to (n_samples, n_times) shape
    test_time = np.broadcast_to(test_time[:, np.newaxis], (n_samples, n_times))
    test_event = np.broadcast_to(test_event[:, np.newaxis], (n_samples, n_times))
    times_2d = np.broadcast_to(times, (n_samples, n_times))
    ipcw = np.broadcast_to(ipcw[:, np.newaxis], (n_samples, n_times))

    # sort each time point (columns) by risk score (descending)
    o = np.argsort(-estimate, axis=0)
    test_time = np.take_along_axis(test_time, o, axis=0)
    test_event = np.take_along_axis(test_event, o, axis=0)
    estimate = np.take_along_axis(estimate, o, axis=0)
    ipcw = np.take_along_axis(ipcw, o, axis=0)

    is_case = (test_time <= times_2d) & test_event
    is_control = test_time > times_2d
    n_controls = is_control.sum(axis=0)

    # prepend row of infinity values
    estimate_diff = np.concatenate((np.broadcast_to(np.inf, (1, n_times)), estimate))
    is_tied = np.absolute(np.diff(estimate_diff, axis=0)) <= tied_tol

    # cases are weighted by IPCW, controls are counted unweighted
    cumsum_tp = np.cumsum(is_case * ipcw, axis=0)
    cumsum_fp = np.cumsum(is_control, axis=0)
    true_pos = cumsum_tp / cumsum_tp[-1]
    false_pos = cumsum_fp / n_controls

    scores = np.empty(n_times, dtype=float)
    # Fortran order with an external loop hands out one column (time point) per iteration
    it = np.nditer((true_pos, false_pos, is_tied), order="F", flags=["external_loop"])
    with it:
        for i, (tp, fp, mask) in enumerate(it):
            idx = np.flatnonzero(mask) - 1
            # only keep the last estimate for tied risk scores
            tp_no_ties = np.delete(tp, idx)
            fp_no_ties = np.delete(fp, idx)
            # Add an extra threshold position
            # to make sure that the curve starts at (0, 0)
            tp_no_ties = np.r_[0, tp_no_ties]
            fp_no_ties = np.r_[0, fp_no_ties]
            scores[i] = trapezoid(tp_no_ties, fp_no_ties)

    if n_times == 1:
        mean_auc = scores[0]
    else:
        surv = SurvivalFunctionEstimator()
        surv.fit(survival_test)
        s_times = surv.predict_proba(times)
        # compute integral of AUC over survival function
        d = -np.diff(np.r_[1.0, s_times])
        integral = (scores * d).sum()
        mean_auc = integral / (1.0 - s_times[-1])

    return scores, mean_auc
527
+
528
+
529
def brier_score(survival_train, survival_test, estimate, times):
    r"""Time-dependent Brier score for right-censored data.

    The time-dependent Brier score measures the inaccuracy of predicted
    survival probabilities at a given time point. It is the mean squared
    error between the true survival status and the predicted survival
    probability at time point :math:`t`. A lower Brier score indicates
    better model performance.

    To account for censoring, this metric uses inverse probability of
    censoring weights (IPCW), which requires access to survival times from
    the training data to estimate the censoring distribution. Survival
    times in ``survival_test`` must lie within the range of survival times
    in ``survival_train``. This can be achieved by specifying ``times``
    accordingly, e.g. by setting ``times[-1]`` slightly below the maximum
    expected follow-up time.

    For time points in ``survival_test`` that lie outside of the range
    specified by values in ``survival_train``, the probability of censoring
    is unknown and an exception will be raised::

        ValueError: time must be smaller than largest observed time point

    The censoring distribution is estimated using the Kaplan-Meier
    estimator, which assumes that censoring is random and independent of
    the features.

    The time-dependent Brier score at time :math:`t` is defined as

    .. math::

        \mathrm{BS}^c(t) = \frac{1}{n} \sum_{i=1}^n I(y_i \leq t \land \delta_i = 1)
        \frac{(0 - \hat{\pi}(t | \mathbf{x}_i))^2}{\hat{G}(y_i)} + I(y_i > t)
        \frac{(1 - \hat{\pi}(t | \mathbf{x}_i))^2}{\hat{G}(t)} ,

    where :math:`\hat{\pi}(t | \mathbf{x})` is the predicted survival probability
    up to the time point :math:`t` for a feature vector :math:`\mathbf{x}`,
    and :math:`1/\hat{G}(t)` is an inverse probability of censoring weight.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Brier-Score>`
    and [1]_ for details.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for the training data, used to estimate the
        censoring distribution.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates
        right-censoring. The second field is a float with the time of
        event or time of censoring.
    survival_test : structured array, shape = (n_samples,)
        Survival times for the test data, in the same two-field layout as
        ``survival_train``.
    estimate : array-like, shape = (n_samples, n_times)
        Predicted survival probabilities for the test data at the time
        points specified by ``times``, typically obtained from
        ``estimator.predict_survival_function(X)``. The value of
        ``estimate[:, i]`` must correspond to the estimated survival
        probability up to the time point ``times[i]``.
    times : array-like, shape = (n_times,)
        The time points at which to compute the Brier score. Values must
        be within the range of follow-up times in ``survival_test``.

    Returns
    -------
    times : ndarray, shape = (n_times,)
        The unique time points at which the Brier score was estimated.
    brier_scores : ndarray, shape = (n_times,)
        The Brier score at each time point in ``times``.

    Notes
    -----
    This metric expects survival probabilities, which are typically
    returned by ``estimator.predict_survival_function(X)``.
    It *does not accept* risk scores.

    Examples
    --------
    >>> from sksurv.datasets import load_gbsg2
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.metrics import brier_score
    >>> from sksurv.preprocessing import OneHotEncoder

    Load and prepare data.

    >>> X, y = load_gbsg2()
    >>> X["tgrade"] = X.loc[:, "tgrade"].map(len).astype(int)
    >>> Xt = OneHotEncoder().fit_transform(X)

    Fit a Cox model.

    >>> est = CoxPHSurvivalAnalysis(ties="efron").fit(Xt, y)

    Retrieve individual survival functions and get probability
    of remaining event free up to 5 years (=1825 days).

    >>> survs = est.predict_survival_function(Xt)
    >>> preds = [fn(1825) for fn in survs]

    Compute the Brier score at 5 years.

    >>> times, score = brier_score(y, y, preds, 1825)
    >>> print(score)
    [0.20881843]

    See also
    --------
    integrated_brier_score
        Computes the average Brier score over all time points.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher,
           "Assessment and comparison of prognostic classification schemes for survival data,"
           Statistics in Medicine, vol. 18, no. 17-18, pp. 2529–2545, 1999.
    """
    test_event, test_time = check_y_survival(survival_test)
    estimate, times = _check_estimate_2d(estimate, test_time, times, estimator="brier_score")
    n_times = times.shape[0]
    if estimate.ndim == 1 and n_times == 1:
        # a single time point may be passed as a flat vector; treat it as one column
        estimate = estimate.reshape(-1, 1)

    # fit IPCW estimator
    censoring = CensoringDistributionEstimator().fit(survival_train)

    # censoring survival probability at each evaluation time point t;
    # zeros become inf so that the corresponding terms divide down to zero
    g_at_t = censoring.predict_proba(times)
    g_at_t[g_at_t == 0] = np.inf

    # censoring survival probability at each observed test time
    g_at_y = censoring.predict_proba(test_time)
    g_at_y[g_at_y == 0] = np.inf

    brier_scores = np.empty(n_times, dtype=float)
    for col, t in enumerate(times):
        surv_prob = estimate[:, col]
        had_event = (test_time <= t) & test_event
        still_at_risk = test_time > t

        event_term = np.square(surv_prob) * had_event.astype(int) / g_at_y
        at_risk_term = np.square(1.0 - surv_prob) * still_at_risk.astype(int) / g_at_t[col]
        brier_scores[col] = np.mean(event_term + at_risk_term)

    return times, brier_scores
672
+
673
+
674
def integrated_brier_score(survival_train, survival_test, estimate, times):
    r"""Computes the integrated Brier score (IBS).

    The IBS is an overall measure of the model's performance across all
    available time points :math:`t_1 \leq t \leq t_\text{max}`.
    It is the average Brier score, integrated over time.
    A lower IBS indicates better model performance.

    The integrated time-dependent Brier score over the interval
    :math:`[t_1; t_\text{max}]` is defined as

    .. math::

        \mathrm{IBS} = \int_{t_1}^{t_\text{max}} \mathrm{BS}^c(t) d w(t)

    where the weighting function is :math:`w(t) = t / t_\text{max}`.
    The integral is estimated via the trapezoidal rule.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Brier-Score>`
    and [1]_ for further details.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for the training data, used to estimate the censoring
        distribution.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    survival_test : structured array, shape = (n_samples,)
        Survival times for the test data.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    estimate : array-like, shape = (n_samples, n_times)
        Predicted survival probabilities for the test data at the time points
        specified by ``times``, typically obtained from
        ``estimator.predict_survival_function(X)``. The value of ``estimate[:, i]``
        must correspond to the estimated survival probability up to
        the time point ``times[i]``.
    times : array-like, shape = (n_times,)
        The time points at which to compute the Brier score. Values must be
        within the range of follow-up times in ``survival_test``.
        At least two time points are required.

    Returns
    -------
    ibs : float
        The integrated Brier score.

    Raises
    ------
    ValueError
        If fewer than two time points are given.

    Notes
    -----
    This metric expects survival probabilities, which are typically returned by
    ``estimator.predict_survival_function(X)``.
    It *does not accept* risk scores.

    Examples
    --------
    >>> import numpy as np
    >>> from sksurv.datasets import load_gbsg2
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.metrics import integrated_brier_score
    >>> from sksurv.preprocessing import OneHotEncoder

    Load and prepare data.

    >>> X, y = load_gbsg2()
    >>> X["tgrade"] = X.loc[:, "tgrade"].map(len).astype(int)
    >>> Xt = OneHotEncoder().fit_transform(X)

    Fit a Cox model.

    >>> est = CoxPHSurvivalAnalysis(ties="efron").fit(Xt, y)

    Retrieve individual survival functions and get probability
    of remaining event free from 1 year to 5 years (=1825 days).

    >>> survs = est.predict_survival_function(Xt)
    >>> times = np.arange(365, 1826)
    >>> preds = np.asarray([[fn(t) for t in times] for fn in survs])

    Compute the integrated Brier score from 1 to 5 years.

    >>> score = integrated_brier_score(y, y, preds, times)
    >>> print(round(score, 4))
    0.1816

    See also
    --------
    brier_score
        Computes the Brier score at specified time points.

    as_integrated_brier_score_scorer
        Wrapper class that uses :func:`integrated_brier_score`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher,
           "Assessment and comparison of prognostic classification schemes for survival data,"
           Statistics in Medicine, vol. 18, no. 17-18, pp. 2529–2545, 1999.
    """
    # Computing the brier scores
    times, brier_scores = brier_score(survival_train, survival_test, estimate, times)

    if times.shape[0] < 2:
        raise ValueError("At least two time points must be given")

    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``.
    # Look the function up by name so that both old and new NumPy releases
    # work without emitting a DeprecationWarning on the new ones.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz

    # Computing the IBS: area under the Brier score curve, normalised by the
    # length of the integration interval.
    ibs_value = trapezoid(brier_scores, times) / (times[-1] - times[0])

    return ibs_value
786
+
787
+
788
+ def _estimator_has(attr):
789
+ """Check that meta_estimator has `attr`.
790
+
791
+ Used together with `available_if`."""
792
+
793
+ def check(self):
794
+ # raise original `AttributeError` if `attr` does not exist
795
+ getattr(self.estimator_, attr)
796
+ return True
797
+
798
+ return check
799
+
800
+
801
class _ScoreOverrideMixin:
    # Mixin that wraps an estimator and computes ``score`` with a
    # user-selected metric function instead of the estimator's own ``score``.

    def __init__(self, estimator, predict_func, score_func, score_index, greater_is_better):
        # Fail early when the wrapped estimator cannot produce the kind of
        # prediction the metric needs.
        if not hasattr(estimator, predict_func):
            raise AttributeError(f"{estimator!r} object has no attribute {predict_func!r}")

        self.estimator = estimator
        self._predict_func = predict_func
        self._score_func = score_func
        self._score_index = score_index
        # Metrics where lower is better are negated in ``score``.
        if greater_is_better:
            self._sign = 1
        else:
            self._sign = -1

    def _get_score_params(self):
        """Collect the keyword arguments forwarded to ``score_func``.

        These are the shallow parameters of this wrapper, without the
        wrapped ``estimator`` itself.
        """
        score_params = self.get_params(deep=False)
        score_params.pop("estimator")
        return score_params

    def fit(self, X, y, **fit_params):
        # Keep an independent copy of the training targets; it is passed
        # to the metric as ``survival_train`` when scoring.
        self._train_y = np.array(y, copy=True)
        fitted = self.estimator.fit(X, y, **fit_params)
        self.estimator_ = fitted
        return self

    def _do_predict(self, X):
        # Dispatch to the prediction method selected at construction time.
        return getattr(self.estimator_, self._predict_func)(X)

    def score(self, X, y):
        """Evaluate the wrapped estimator on the given data with ``score_func``.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Input data, where n_samples is the number of samples and
            n_features is the number of features.

        y : array-like, shape = (n_samples,)
            Target relative to X; passed to the metric as ``survival_test``.

        Returns
        -------
        score : float
            The metric value, negated if the metric was registered with
            ``greater_is_better=False``.
        """
        prediction = self._do_predict(X)
        metric_kwargs = self._get_score_params()
        result = self._score_func(
            survival_train=self._train_y,
            survival_test=y,
            estimate=prediction,
            **metric_kwargs,
        )
        # Some metrics return several values; pick the configured one.
        index = self._score_index
        value = result if index is None else result[index]
        return self._sign * value

    @available_if(_estimator_has("predict"))
    def predict(self, X):
        """Delegate to ``predict`` of the fitted estimator.

        Only available if estimator supports ``predict``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        fitted = self.estimator_
        return fitted.predict(X)

    @available_if(_estimator_has("predict_cumulative_hazard_function"))
    def predict_cumulative_hazard_function(self, X):
        """Delegate to ``predict_cumulative_hazard_function`` of the fitted estimator.

        Only available if estimator supports ``predict_cumulative_hazard_function``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        fitted = self.estimator_
        return fitted.predict_cumulative_hazard_function(X)

    @available_if(_estimator_has("predict_survival_function"))
    def predict_survival_function(self, X):
        """Delegate to ``predict_survival_function`` of the fitted estimator.

        Only available if estimator supports ``predict_survival_function``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        fitted = self.estimator_
        return fitted.predict_survival_function(X)
899
+
900
+
901
class as_cumulative_dynamic_auc_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Meta-estimator whose ``score`` method is based on :func:`cumulative_dynamic_auc`.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`
    for using it for hyper-parameter optimization.

    Parameters
    ----------
    estimator : object
        Instance of an estimator. It must provide a ``predict`` method.
    times : array-like, shape = (n_times,)
        The time points at which to compute the AUC. Values must be within the
        range of follow-up times of the test data.
    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties in risk scores. If the
        absolute difference between two risk scores is smaller than or equal to
        ``tied_tol``, they are considered tied.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    cumulative_dynamic_auc
    """

    def __init__(self, estimator, times, tied_tol=1e-8):
        # Fixed wiring of the mixin: score on ``predict`` output and use the
        # element at index 1 of the metric's result, without negation.
        mixin_config = {
            "estimator": estimator,
            "predict_func": "predict",
            "score_func": cumulative_dynamic_auc,
            "score_index": 1,
            "greater_is_better": True,
        }
        super().__init__(**mixin_config)
        # Remaining constructor arguments are forwarded to the metric.
        self.times = times
        self.tied_tol = tied_tol
939
+
940
+
941
class as_concordance_index_ipcw_scorer(_ScoreOverrideMixin, BaseEstimator):
    r"""Meta-estimator whose ``score`` method is based on :func:`concordance_index_ipcw`.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`
    for using it for hyper-parameter optimization.

    Parameters
    ----------
    estimator : object
        Instance of an estimator. It must provide a ``predict`` method.
    tau : float, optional
        Truncation time. The survival function for the underlying
        censoring time distribution :math:`D` needs to be positive
        at `tau`, i.e., `tau` should be chosen such that the
        probability of being censored after time `tau` is non-zero:
        :math:`P(D > \tau) > 0`. If `None`, no truncation is performed.
    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties in risk scores.
        If the absolute difference between two risk scores is smaller than
        or equal to ``tied_tol``, they are considered tied.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    concordance_index_ipcw
    """

    def __init__(self, estimator, tau=None, tied_tol=1e-8):
        # Fixed wiring of the mixin: score on ``predict`` output and use the
        # element at index 0 of the metric's result, without negation.
        mixin_config = {
            "estimator": estimator,
            "predict_func": "predict",
            "score_func": concordance_index_ipcw,
            "score_index": 0,
            "greater_is_better": True,
        }
        super().__init__(**mixin_config)
        # Remaining constructor arguments are forwarded to the metric.
        self.tau = tau
        self.tied_tol = tied_tol
982
+
983
+
984
class as_integrated_brier_score_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Meta-estimator whose ``score`` method is the negative of :func:`integrated_brier_score`.

    The estimator needs to be able to estimate survival functions via
    a ``predict_survival_function`` method.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`
    for using it for hyper-parameter optimization.

    Parameters
    ----------
    estimator : object
        Instance of an estimator that provides ``predict_survival_function``.
    times : array-like, shape = (n_times,)
        The time points at which to compute the Brier score. Values must be
        within the range of follow-up times of the test data.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    integrated_brier_score
    """

    def __init__(self, estimator, times):
        # Fixed wiring of the mixin: the metric result is used as a whole
        # (no indexing) and negated, because ``greater_is_better`` is False.
        mixin_config = {
            "estimator": estimator,
            "predict_func": "predict_survival_function",
            "score_func": integrated_brier_score,
            "score_index": None,
            "greater_is_better": False,
        }
        super().__init__(**mixin_config)
        self.times = times

    def _do_predict(self, X):
        # Evaluate every predicted survival function at ``self.times`` and
        # stack the results into an array with one row per function.
        eval_times = self.times
        surv_fns = getattr(self.estimator_, self._predict_func)(X)
        estimates = np.empty((len(surv_fns), len(eval_times)))
        for row, surv_fn in zip(estimates, surv_fns):
            # ``row`` is a view into ``estimates``; fill it in place
            row[:] = surv_fn(eval_times)
        return estimates