scikit-survival 0.25.0__cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.25.0.dist-info/METADATA +185 -0
- scikit_survival-0.25.0.dist-info/RECORD +58 -0
- scikit_survival-0.25.0.dist-info/WHEEL +6 -0
- scikit_survival-0.25.0.dist-info/licenses/COPYING +674 -0
- scikit_survival-0.25.0.dist-info/top_level.txt +1 -0
- sksurv/__init__.py +183 -0
- sksurv/base.py +115 -0
- sksurv/bintrees/__init__.py +15 -0
- sksurv/bintrees/_binarytrees.cpython-312-x86_64-linux-gnu.so +0 -0
- sksurv/column.py +205 -0
- sksurv/compare.py +123 -0
- sksurv/datasets/__init__.py +12 -0
- sksurv/datasets/base.py +614 -0
- sksurv/datasets/data/GBSG2.arff +700 -0
- sksurv/datasets/data/actg320.arff +1169 -0
- sksurv/datasets/data/bmt.arff +46 -0
- sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
- sksurv/datasets/data/cgvhd.arff +118 -0
- sksurv/datasets/data/flchain.arff +7887 -0
- sksurv/datasets/data/veteran.arff +148 -0
- sksurv/datasets/data/whas500.arff +520 -0
- sksurv/docstrings.py +99 -0
- sksurv/ensemble/__init__.py +2 -0
- sksurv/ensemble/_coxph_loss.cpython-312-x86_64-linux-gnu.so +0 -0
- sksurv/ensemble/boosting.py +1564 -0
- sksurv/ensemble/forest.py +902 -0
- sksurv/ensemble/survival_loss.py +151 -0
- sksurv/exceptions.py +18 -0
- sksurv/functions.py +114 -0
- sksurv/io/__init__.py +2 -0
- sksurv/io/arffread.py +89 -0
- sksurv/io/arffwrite.py +181 -0
- sksurv/kernels/__init__.py +1 -0
- sksurv/kernels/_clinical_kernel.cpython-312-x86_64-linux-gnu.so +0 -0
- sksurv/kernels/clinical.py +348 -0
- sksurv/linear_model/__init__.py +3 -0
- sksurv/linear_model/_coxnet.cpython-312-x86_64-linux-gnu.so +0 -0
- sksurv/linear_model/aft.py +208 -0
- sksurv/linear_model/coxnet.py +592 -0
- sksurv/linear_model/coxph.py +637 -0
- sksurv/meta/__init__.py +4 -0
- sksurv/meta/base.py +35 -0
- sksurv/meta/ensemble_selection.py +724 -0
- sksurv/meta/stacking.py +370 -0
- sksurv/metrics.py +1028 -0
- sksurv/nonparametric.py +911 -0
- sksurv/preprocessing.py +183 -0
- sksurv/svm/__init__.py +11 -0
- sksurv/svm/_minlip.cpython-312-x86_64-linux-gnu.so +0 -0
- sksurv/svm/_prsvm.cpython-312-x86_64-linux-gnu.so +0 -0
- sksurv/svm/minlip.py +690 -0
- sksurv/svm/naive_survival_svm.py +249 -0
- sksurv/svm/survival_svm.py +1236 -0
- sksurv/testing.py +108 -0
- sksurv/tree/__init__.py +1 -0
- sksurv/tree/_criterion.cpython-312-x86_64-linux-gnu.so +0 -0
- sksurv/tree/tree.py +790 -0
- sksurv/util.py +415 -0
sksurv/metrics.py
ADDED
|
@@ -0,0 +1,1028 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import numpy as np
|
|
14
|
+
from sklearn.base import BaseEstimator
|
|
15
|
+
from sklearn.utils.metaestimators import available_if
|
|
16
|
+
from sklearn.utils.validation import check_array, check_consistent_length, check_is_fitted
|
|
17
|
+
|
|
18
|
+
from .exceptions import NoComparablePairException
|
|
19
|
+
from .nonparametric import CensoringDistributionEstimator, SurvivalFunctionEstimator
|
|
20
|
+
from .util import check_y_survival
|
|
21
|
+
|
|
22
|
+
__all__ = [
|
|
23
|
+
"as_concordance_index_ipcw_scorer",
|
|
24
|
+
"as_cumulative_dynamic_auc_scorer",
|
|
25
|
+
"as_integrated_brier_score_scorer",
|
|
26
|
+
"brier_score",
|
|
27
|
+
"concordance_index_censored",
|
|
28
|
+
"concordance_index_ipcw",
|
|
29
|
+
"cumulative_dynamic_auc",
|
|
30
|
+
"integrated_brier_score",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _check_estimate_1d(estimate, test_time):
    """Validate that a risk-score array is one-dimensional and matches *test_time* in length.

    Returns the validated array; raises ValueError for non-1D input.
    """
    checked = check_array(estimate, ensure_2d=False, input_name="estimate")
    if checked.ndim != 1:
        raise ValueError(f"Expected 1D array, got {checked.ndim}D array instead:\narray={checked}.\n")
    check_consistent_length(test_time, checked)
    return checked
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _check_inputs(event_indicator, event_time, estimate):
    """Validate event indicator, event time, and risk-score arrays for concordance estimation.

    Returns the three inputs converted to validated arrays. Raises ValueError
    if the event indicator is not boolean, fewer than two samples are given,
    or every sample is censored.
    """
    check_consistent_length(event_indicator, event_time, estimate)
    event = check_array(event_indicator, ensure_2d=False, input_name="event_indicator")
    time = check_array(event_time, ensure_2d=False, input_name="event_time")
    scores = _check_estimate_1d(estimate, time)

    # survival labels must be boolean event indicators
    if not np.issubdtype(event.dtype, np.bool_):
        raise ValueError(
            f"only boolean arrays are supported as class labels for survival analysis, got {event.dtype}"
        )

    if len(time) < 2:
        raise ValueError("Need a minimum of two samples")

    if not event.any():
        raise ValueError("All samples are censored")

    return event, time, scores
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _check_times(test_time, times):
    """Return sorted, unique evaluation time points, all within the follow-up range of the test data.

    Raises ValueError if any requested time falls outside
    ``[test_time.min(), test_time.max())``.
    """
    unique_times = np.unique(check_array(np.atleast_1d(times), ensure_2d=False, input_name="times"))

    t_min = test_time.min()
    t_max = test_time.max()
    if not (unique_times.min() >= t_min and unique_times.max() < t_max):
        raise ValueError(
            f"all times must be within follow-up time of test data: [{t_min}; {t_max}["
        )

    return unique_times
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _check_estimate_2d(estimate, test_time, time_points, estimator):
    """Validate a (possibly time-dependent) estimate against test times and evaluation time points.

    Accepts a 1D array (one risk score per sample) or a 2D array with one
    column per time point. Returns the validated estimate and time points.
    """
    checked = check_array(estimate, ensure_2d=False, allow_nd=False, input_name="estimate", estimator=estimator)
    checked_times = _check_times(test_time, time_points)
    check_consistent_length(test_time, checked)

    # a 2D estimate must provide one column per evaluation time point
    if checked.ndim == 2 and checked.shape[1] != checked_times.shape[0]:
        raise ValueError(f"expected estimate with {checked_times.shape[0]} columns, but got {checked.shape[1]}")

    return checked, checked_times
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _iter_comparable(event_indicator, event_time, order):
|
|
86
|
+
n_samples = len(event_time)
|
|
87
|
+
tied_time = 0
|
|
88
|
+
i = 0
|
|
89
|
+
while i < n_samples - 1:
|
|
90
|
+
time_i = event_time[order[i]]
|
|
91
|
+
end = i + 1
|
|
92
|
+
while end < n_samples and event_time[order[end]] == time_i:
|
|
93
|
+
end += 1
|
|
94
|
+
|
|
95
|
+
# check for tied event times
|
|
96
|
+
event_at_same_time = event_indicator[order[i:end]]
|
|
97
|
+
censored_at_same_time = ~event_at_same_time
|
|
98
|
+
|
|
99
|
+
mask = np.zeros(n_samples, dtype=bool)
|
|
100
|
+
mask[end:] = True
|
|
101
|
+
# an event is comparable to censored samples at same time point
|
|
102
|
+
mask[i:end] = censored_at_same_time
|
|
103
|
+
|
|
104
|
+
for j in range(i, end):
|
|
105
|
+
if event_indicator[order[j]]:
|
|
106
|
+
tied_time += censored_at_same_time.sum()
|
|
107
|
+
yield (j, mask, tied_time)
|
|
108
|
+
i = end
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _estimate_concordance_index(event_indicator, event_time, estimate, weights, tied_tol=1e-8):
    """Compute a weighted concordance index over all comparable pairs.

    For every event sample, counts concordant, discordant, and risk-tied
    comparable pairs (ties within ``tied_tol``), accumulating a weighted
    numerator and denominator using per-sample ``weights``. Raises
    :exc:`NoComparablePairException` if the data contains no comparable pair.

    Returns ``(cindex, concordant, discordant, tied_risk, tied_time)``.
    """
    order = np.argsort(event_time)

    concordant = 0
    discordant = 0
    tied_risk = 0
    numerator = 0.0
    denominator = 0.0
    # stays None iff the generator below yields nothing
    tied_time = None

    for idx, comparable, tied_time in _iter_comparable(event_indicator, event_time, order):
        sample = order[idx]
        risk_i = estimate[sample]
        weight_i = weights[sample]

        assert event_indicator[sample], f"got censored sample at index {sample}, but expected uncensored"

        risk_others = estimate[order[comparable]]

        is_tie = np.absolute(risk_others - risk_i) <= tied_tol
        n_tied = is_tie.sum()
        # concordant: the event sample carries the higher risk score
        n_con = (risk_others < risk_i)[~is_tie].sum()

        # tied pairs count half
        numerator += weight_i * n_con + 0.5 * weight_i * n_tied
        denominator += weight_i * comparable.sum()

        concordant += n_con
        discordant += risk_others.size - n_con - n_tied
        tied_risk += n_tied

    if tied_time is None:
        raise NoComparablePairException("Data has no comparable pairs, cannot estimate concordance index.")

    return numerator / denominator, concordant, discordant, tied_risk, tied_time
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def concordance_index_censored(event_indicator, event_time, estimate, tied_tol=1e-8):
    """Harrell's concordance index for right-censored survival data.

    The concordance index measures the rank correlation between predicted
    risk scores and observed time points, defined as the proportion of all
    comparable pairs whose predicted order agrees with the observed order.
    A pair is comparable if the sample with the shorter observed time
    experienced an event; a comparable pair is concordant if that sample
    also received the higher risk score. Pairs of samples that are both
    censored, or that experienced an event at the same time, are not
    comparable. Pairs with tied risk scores contribute 0.5 instead of 1 to
    the count of concordant pairs. Higher values indicate better performance.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb>`
    and [1]_ for further description.

    Parameters
    ----------
    event_indicator : array-like, shape = (n_samples,)
        A boolean array where ``True`` indicates an event and ``False`` indicates
        censoring.
    event_time : array-like, shape = (n_samples,)
        Array containing the time of an event or time of censoring.
    estimate : array-like, shape = (n_samples,)
        The predicted risk score for each sample (e.g., from ``estimator.predict(X)``).
        A higher value indicates a higher risk of experiencing an event.
    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties in risk scores. If the
        absolute difference between two risk scores is smaller than or equal to
        ``tied_tol``, they are considered tied.

    Returns
    -------
    cindex : float
        The concordance index.
    concordant : int
        The number of concordant pairs.
    discordant : int
        The number of discordant pairs.
    tied_risk : int
        The number of pairs with tied risk scores.
    tied_time : int
        The number of comparable pairs with tied survival times.

    Notes
    -----
    This metric expects risk scores, which are typically returned by
    ``estimator.predict(X)``. It *does not accept* survival probabilities.

    See also
    --------
    concordance_index_ipcw
        A less biased estimator of the concordance index.

    References
    ----------
    .. [1] Harrell, F.E., Califf, R.M., Pryor, D.B., Lee, K.L., Rosati, R.A,
           "Multivariable prognostic models: issues in developing models,
           evaluating assumptions and adequacy, and measuring and reducing errors",
           Statistics in Medicine, 15(4), 361-87, 1996.
    """
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)

    # Harrell's estimator weighs every comparable pair equally
    uniform_weights = np.ones_like(estimate)

    return _estimate_concordance_index(event_indicator, event_time, estimate, uniform_weights, tied_tol)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def concordance_index_ipcw(survival_train, survival_test, estimate, tau=None, tied_tol=1e-8):
    r"""Concordance index for right-censored data based on inverse probability of censoring weights.

    This is an alternative to :func:`concordance_index_censored` that does
    not depend on the distribution of censoring times in the test data: by
    weighting comparable pairs with inverse probability of censoring weights
    (IPCW), it yields an unbiased and consistent estimate of the population
    concordance measure.

    The censoring distribution is estimated from ``survival_train`` using the
    Kaplan-Meier estimator, which assumes that censoring is random and
    independent of the features. Survival times in ``survival_test`` must lie
    within the range of survival times in ``survival_train``; this can be
    achieved by specifying the truncation time ``tau``. The resulting `cindex`
    tells how well the given prediction model works in predicting events that
    occur in the time range from 0 to ``tau``.

    For time points in ``survival_test`` that lie outside of the range
    specified by values in ``survival_train``, the probability of censoring
    is unknown and an exception will be raised::

        ValueError: time must be smaller than largest observed time point

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb>`
    and [1]_ for further description.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for the training data, used to estimate the censoring
        distribution.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    survival_test : structured array, shape = (n_samples,)
        Survival times for the test data.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    estimate : array-like, shape = (n_samples,)
        Predicted risk scores for the test data (e.g., from ``estimator.predict(X)``).
        A higher value indicates a higher risk of experiencing an event.
    tau : float, optional
        Truncation time. The survival function for the underlying
        censoring time distribution :math:`D` needs to be positive
        at `tau`, i.e., `tau` should be chosen such that the
        probability of being censored after time `tau` is non-zero:
        :math:`P(D > \tau) > 0`. If `None`, no truncation is performed.
    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties in risk scores.
        If the absolute difference between two risk scores is smaller than
        or equal to ``tied_tol``, they are considered tied.

    Returns
    -------
    cindex : float
        The concordance index.
    concordant : int
        The number of concordant pairs.
    discordant : int
        The number of discordant pairs.
    tied_risk : int
        The number of pairs with tied risk scores.
    tied_time : int
        The number of comparable pairs with tied survival times.

    Notes
    -----
    This metric expects risk scores, which are typically returned by
    ``estimator.predict(X)``. It *does not accept* survival probabilities.

    See also
    --------
    concordance_index_censored
        A simpler, but potentially biased, estimator of the concordance index.
    as_concordance_index_ipcw_scorer
        A wrapper class that uses :func:`concordance_index_ipcw`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] Uno, H., Cai, T., Pencina, M. J., D’Agostino, R. B., & Wei, L. J. (2011).
           "On the C-statistics for evaluating overall adequacy of risk prediction
           procedures with censored survival data".
           Statistics in Medicine, 30(10), 1105–1117.
    """
    test_event, test_time = check_y_survival(survival_test)

    before_tau = None
    if tau is not None:
        # estimate censoring weights only for samples observed before tau
        before_tau = test_time < tau
        survival_test = survival_test[before_tau]

    estimate = _check_estimate_1d(estimate, test_time)

    cens = CensoringDistributionEstimator()
    cens.fit(survival_train)
    ipcw_test = cens.predict_ipcw(survival_test)

    if before_tau is None:
        ipcw = ipcw_test
    else:
        # samples at or beyond tau receive zero weight
        ipcw = np.empty(estimate.shape[0], dtype=ipcw_test.dtype)
        ipcw[before_tau] = ipcw_test
        ipcw[~before_tau] = 0

    return _estimate_concordance_index(test_event, test_time, estimate, np.square(ipcw), tied_tol)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def cumulative_dynamic_auc(survival_train, survival_test, estimate, times, tied_tol=1e-8):
    r"""Computes the cumulative/dynamic area under the ROC curve (AUC) for right-censored data.

    This metric evaluates a model's performance at specific time points.
    The cumulative/dynamic AUC at time :math:`t` quantifies how well a model can
    distinguish subjects who experience an event by time :math:`t` (cases) from
    those who do not (controls). A higher AUC indicates better model performance.

    This function can also evaluate models with time-dependent predictions, such as
    :class:`sksurv.ensemble.RandomSurvivalForest`
    (see :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Time-dependent-Risk-Scores>`).
    In this case, ``estimate`` must be a 2D array where ``estimate[i, j]`` is the
    predicted risk score for the :math:`i`-th instance at time point ``times[j]``.

    *Cumulative cases* at time :math:`t` are all individuals that experienced
    an event prior to or at time :math:`t` (:math:`t_i \leq t`), whereas
    *dynamic controls* are those with :math:`t_i > t`. Given an estimator of
    the :math:`i`-th individual's risk score :math:`\hat{f}(\mathbf{x}_i)`,
    the cumulative/dynamic AUC at time :math:`t` is defined as

    .. math::

        \widehat{\mathrm{AUC}}(t) =
        \frac{\sum_{i=1}^n \sum_{j=1}^n I(y_j > t) I(y_i \leq t) \omega_i
        I(\hat{f}(\mathbf{x}_j) \leq \hat{f}(\mathbf{x}_i))}
        {(\sum_{i=1}^n I(y_i > t)) (\sum_{i=1}^n I(y_i \leq t) \omega_i)}

    where :math:`\omega_i` are inverse probability of censoring weights (IPCW).

    The IPCW are estimated from ``survival_train`` using the Kaplan-Meier
    estimator, which assumes that censoring is random and independent of the
    features. Note that survival times in ``survival_test`` must lie within the
    range of survival times in ``survival_train``. This can be achieved by
    specifying ``times`` accordingly, e.g. by setting ``times[-1]`` slightly
    below the maximum expected follow-up time. For time points in
    ``survival_test`` that lie outside of the range specified by values in
    ``survival_train``, the probability of censoring is unknown and an
    exception will be raised::

        ValueError: time must be smaller than largest observed time point

    The function also returns a summary measure, which is the mean of the
    :math:`\mathrm{AUC}(t)` over the specified time range, weighted by the
    estimated survival function:

    .. math::

        \overline{\mathrm{AUC}}(\tau_1, \tau_2) =
        \frac{1}{\hat{S}(\tau_1) - \hat{S}(\tau_2)}
        \int_{\tau_1}^{\tau_2} \widehat{\mathrm{AUC}}(t)\,d \hat{S}(t)

    where :math:`\hat{S}(t)` is the Kaplan–Meier estimator of the survival function.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Area-under-the-ROC>`,
    [1]_, [2]_, [3]_ for further description.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for the training data, used to estimate the censoring
        distribution.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    survival_test : structured array, shape = (n_samples,)
        Survival times for the test data.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    estimate : array-like, shape = (n_samples,) or (n_samples, n_times)
        Predicted risk scores for the test data (e.g., from ``estimator.predict(X)``.
        A higher value indicates a higher risk of experiencing an event.
        If a 1D array is provided, the same risk score is used for all time points.
        If a 2D array is provided, ``estimate[:, j]`` is used for the :math:`j`-th
        time point.
    times : array-like, shape = (n_times,)
        The time points at which to compute the AUC. Values must be within the
        range of follow-up times in ``survival_test``.
    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties in risk scores. If the
        absolute difference between two risk scores is smaller than or equal to
        ``tied_tol``, they are considered tied.

    Returns
    -------
    auc : ndarray, shape = (n_times,)
        The cumulative/dynamic AUC estimates at each time point in ``times``.
    mean_auc : float
        The mean cumulative/dynamic AUC over the specified time range ``(times[0], times[-1])``.

    Notes
    -----
    This metric expects risk scores, which are typically returned by ``estimator.predict(X)``
    (for time-independent risks), or ``estimator.predict_cumulative_hazard_function(X)``
    (for time-dependent risks). It *does not accept* survival probabilities.

    See also
    --------
    as_cumulative_dynamic_auc_scorer
        A wrapper class that uses :func:`cumulative_dynamic_auc`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] H. Uno, T. Cai, L. Tian, and L. J. Wei,
           "Evaluating prediction rules for t-year survivors with censored regression models,"
           Journal of the American Statistical Association, vol. 102, pp. 527–537, 2007.
    .. [2] H. Hung and C. T. Chiang,
           "Estimation methods for time-dependent AUC models with survival data,"
           Canadian Journal of Statistics, vol. 38, no. 1, pp. 8–26, 2010.
    .. [3] J. Lambert and S. Chevret,
           "Summary measure of discrimination in survival models based on cumulative/dynamic time-dependent ROC curves,"
           Statistical Methods in Medical Research, 2014.
    """
    test_event, test_time = check_y_survival(survival_test)
    estimate, times = _check_estimate_2d(estimate, test_time, times, estimator="cumulative_dynamic_auc")

    n_samples = estimate.shape[0]
    n_times = times.shape[0]
    if estimate.ndim == 1:
        # time-independent risk scores: reuse the same column for every time point
        estimate = np.broadcast_to(estimate[:, np.newaxis], (n_samples, n_times))

    # fit and transform IPCW
    cens = CensoringDistributionEstimator()
    cens.fit(survival_train)
    ipcw = cens.predict_ipcw(survival_test)

    # expand arrays to (n_samples, n_times) shape
    test_time = np.broadcast_to(test_time[:, np.newaxis], (n_samples, n_times))
    test_event = np.broadcast_to(test_event[:, np.newaxis], (n_samples, n_times))
    times_2d = np.broadcast_to(times, (n_samples, n_times))
    ipcw = np.broadcast_to(ipcw[:, np.newaxis], (n_samples, n_times))

    # sort each time point (columns) by risk score (descending)
    o = np.argsort(-estimate, axis=0)
    test_time = np.take_along_axis(test_time, o, axis=0)
    test_event = np.take_along_axis(test_event, o, axis=0)
    estimate = np.take_along_axis(estimate, o, axis=0)
    ipcw = np.take_along_axis(ipcw, o, axis=0)

    is_case = (test_time <= times_2d) & test_event
    is_control = test_time > times_2d
    n_controls = is_control.sum(axis=0)

    # prepend row of infinity values so the first row is never flagged as tied
    estimate_diff = np.concatenate((np.broadcast_to(np.inf, (1, n_times)), estimate))
    is_tied = np.absolute(np.diff(estimate_diff, axis=0)) <= tied_tol

    # weighted true-positive and unweighted false-positive rates per threshold
    cumsum_tp = np.cumsum(is_case * ipcw, axis=0)
    cumsum_fp = np.cumsum(is_control, axis=0)
    true_pos = cumsum_tp / cumsum_tp[-1]
    false_pos = cumsum_fp / n_controls

    # np.trapz was removed in NumPy 2.0 in favor of np.trapezoid;
    # support both old and new NumPy versions
    _trapezoid = getattr(np, "trapezoid", None)
    if _trapezoid is None:
        _trapezoid = np.trapz

    scores = np.empty(n_times, dtype=float)
    it = np.nditer((true_pos, false_pos, is_tied), order="F", flags=["external_loop"])
    with it:
        for i, (tp, fp, mask) in enumerate(it):
            idx = np.flatnonzero(mask) - 1
            # only keep the last estimate for tied risk scores
            tp_no_ties = np.delete(tp, idx)
            fp_no_ties = np.delete(fp, idx)
            # Add an extra threshold position
            # to make sure that the curve starts at (0, 0)
            tp_no_ties = np.r_[0, tp_no_ties]
            fp_no_ties = np.r_[0, fp_no_ties]
            scores[i] = _trapezoid(tp_no_ties, fp_no_ties)

    if n_times == 1:
        mean_auc = scores[0]
    else:
        surv = SurvivalFunctionEstimator()
        surv.fit(survival_test)
        s_times = surv.predict_proba(times)
        # compute integral of AUC over survival function
        d = -np.diff(np.r_[1.0, s_times])
        integral = (scores * d).sum()
        mean_auc = integral / (1.0 - s_times[-1])

    return scores, mean_auc
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
def brier_score(survival_train, survival_test, estimate, times):
    r"""The time-dependent Brier score for right-censored data.

    The time-dependent Brier score measures the inaccuracy of
    predicted survival probabilities at a given time point.
    It is the mean squared error between the true survival status
    and the predicted survival probability at time point :math:`t`.
    A lower Brier score indicates better model performance.

    To account for censoring, this metric uses inverse probability of censoring
    weights (IPCW), which requires access to survival times from the training
    data to estimate the censoring distribution. Note that survival times in
    ``survival_test`` must lie within the range of survival times in ``survival_train``.
    This can be achieved by specifying ``times`` accordingly, e.g. by setting
    ``times[-1]`` slightly below the maximum expected follow-up time.

    For time points in ``survival_test`` that lie outside of the range specified by
    values in ``survival_train``, the probability of censoring is unknown and an
    exception will be raised::

        ValueError: time must be smaller than largest observed time point

    The censoring distribution is estimated using the Kaplan-Meier estimator, which
    assumes that censoring is random and independent of the features.

    The time-dependent Brier score at time :math:`t` is defined as

    .. math::

        \mathrm{BS}^c(t) = \frac{1}{n} \sum_{i=1}^n I(y_i \leq t \land \delta_i = 1)
        \frac{(0 - \hat{\pi}(t | \mathbf{x}_i))^2}{\hat{G}(y_i)} + I(y_i > t)
        \frac{(1 - \hat{\pi}(t | \mathbf{x}_i))^2}{\hat{G}(t)} ,

    where :math:`\hat{\pi}(t | \mathbf{x})` is the predicted survival probability
    up to the time point :math:`t` for a feature vector :math:`\mathbf{x}`,
    and :math:`1/\hat{G}(t)` is a inverse probability of censoring weight.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Brier-Score>`
    and [1]_ for details.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for the training data, used to estimate the censoring
        distribution.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    survival_test : structured array, shape = (n_samples,)
        Survival times for the test data.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    estimate : array-like, shape = (n_samples, n_times)
        Predicted survival probabilities for the test data at the time points
        specified by ``times``, typically obtained from
        ``estimator.predict_survival_function(X)``. The value of ``estimate[:, i]``
        must correspond to the estimated survival probability up to
        the time point ``times[i]``.
    times : array-like, shape = (n_times,)
        The time points at which to compute the Brier score. Values must be
        within the range of follow-up times in ``survival_test``.

    Returns
    -------
    times : ndarray, shape = (n_times,)
        The unique time points at which the Brier score was estimated.
    brier_scores : ndarray, shape = (n_times,)
        The Brier score at each time point in ``times``.

    Notes
    -----
    This metric expects survival probabilities, which are typically returned by
    ``estimator.predict_survival_function(X)``.
    It *does not accept* risk scores.

    Examples
    --------
    >>> from sksurv.datasets import load_gbsg2
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.metrics import brier_score
    >>> from sksurv.preprocessing import OneHotEncoder

    Load and prepare data.

    >>> X, y = load_gbsg2()
    >>> X["tgrade"] = X.loc[:, "tgrade"].map(len).astype(int)
    >>> Xt = OneHotEncoder().fit_transform(X)

    Fit a Cox model.

    >>> est = CoxPHSurvivalAnalysis(ties="efron").fit(Xt, y)

    Retrieve individual survival functions and get probability
    of remaining event free up to 5 years (=1825 days).

    >>> survs = est.predict_survival_function(Xt)
    >>> preds = [fn(1825) for fn in survs]

    Compute the Brier score at 5 years.

    >>> times, score = brier_score(y, y, preds, 1825)
    >>> print(score)
    [0.20881843]

    See also
    --------
    integrated_brier_score
        Computes the average Brier score over all time points.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher,
           "Assessment and comparison of prognostic classification schemes for survival data,"
           Statistics in Medicine, vol. 18, no. 17-18, pp. 2529–2545, 1999.
    """
    test_event, test_time = check_y_survival(survival_test)
    estimate, times = _check_estimate_2d(estimate, test_time, times, estimator="brier_score")
    if estimate.ndim == 1 and times.shape[0] == 1:
        estimate = estimate.reshape(-1, 1)

    # Estimate the censoring distribution G from the training data (IPCW).
    cens_dist = CensoringDistributionEstimator().fit(survival_train)
    # Censoring survival probability at each requested evaluation time point.
    g_times = cens_dist.predict_proba(times)
    g_times[g_times == 0] = np.inf  # a zero probability yields a zero weight
    # Censoring survival probability at each observed event/censoring time.
    g_obs = cens_dist.predict_proba(test_time)
    g_obs[g_obs == 0] = np.inf

    n_times = times.shape[0]
    brier_scores = np.empty(n_times, dtype=float)
    for k in range(n_times):
        t = times[k]
        surv_prob = estimate[:, k]
        # Event occurred at or before t vs. still at risk after t.
        event_before_t = ((test_time <= t) & test_event).astype(int)
        at_risk_after_t = (test_time > t).astype(int)

        # Graf et al. (1999): squared error weighted by 1/G(y_i) for subjects
        # with an event before t, and by 1/G(t) for subjects at risk at t.
        squared_err = (
            np.square(surv_prob) * event_before_t / g_obs
            + np.square(1.0 - surv_prob) * at_risk_after_t / g_times[k]
        )
        brier_scores[k] = squared_err.mean()

    return times, brier_scores
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def integrated_brier_score(survival_train, survival_test, estimate, times):
    r"""Computes the integrated Brier score (IBS).

    The IBS is an overall measure of the model's performance across all
    available time points :math:`t_1 \leq t \leq t_\text{max}`.
    It is the average Brier score, integrated over time.
    A lower IBS indicates better model performance.

    The integrated time-dependent Brier score over the interval
    :math:`[t_1; t_\text{max}]` is defined as

    .. math::

        \mathrm{IBS} = \int_{t_1}^{t_\text{max}} \mathrm{BS}^c(t) d w(t)

    where the weighting function is :math:`w(t) = t / t_\text{max}`.
    The integral is estimated via the trapezoidal rule.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Brier-Score>`
    and [1]_ for further details.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for the training data, used to estimate the censoring
        distribution.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    survival_test : structured array, shape = (n_samples,)
        Survival times for the test data.
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    estimate : array-like, shape = (n_samples, n_times)
        Predicted survival probabilities for the test data at the time points
        specified by ``times``, typically obtained from
        ``estimator.predict_survival_function(X)``. The value of ``estimate[:, i]``
        must correspond to the estimated survival probability up to
        the time point ``times[i]``.
    times : array-like, shape = (n_times,)
        The time points at which to compute the Brier score. Values must be
        within the range of follow-up times in ``survival_test``.

    Returns
    -------
    ibs : float
        The integrated Brier score.

    Notes
    -----
    This metric expects survival probabilities, which are typically returned by
    ``estimator.predict_survival_function(X)``.
    It *does not accept* risk scores.

    Examples
    --------
    >>> import numpy as np
    >>> from sksurv.datasets import load_gbsg2
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.metrics import integrated_brier_score
    >>> from sksurv.preprocessing import OneHotEncoder

    Load and prepare data.

    >>> X, y = load_gbsg2()
    >>> X["tgrade"] = X.loc[:, "tgrade"].map(len).astype(int)
    >>> Xt = OneHotEncoder().fit_transform(X)

    Fit a Cox model.

    >>> est = CoxPHSurvivalAnalysis(ties="efron").fit(Xt, y)

    Retrieve individual survival functions and get probability
    of remaining event free from 1 year to 5 years (=1825 days).

    >>> survs = est.predict_survival_function(Xt)
    >>> times = np.arange(365, 1826)
    >>> preds = np.asarray([[fn(t) for t in times] for fn in survs])

    Compute the integrated Brier score from 1 to 5 years.

    >>> score = integrated_brier_score(y, y, preds, times)
    >>> print(round(score, 4))
    0.1816

    See also
    --------
    brier_score
        Computes the Brier score at specified time points.

    as_integrated_brier_score_scorer
        Wrapper class that uses :func:`integrated_brier_score`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher,
           "Assessment and comparison of prognostic classification schemes for survival data,"
           Statistics in Medicine, vol. 18, no. 17-18, pp. 2529–2545, 1999.
    """
    # Brier score at every requested (validated and unique-sorted) time point.
    eval_times, scores = brier_score(survival_train, survival_test, estimate, times)

    if eval_times.shape[0] < 2:
        raise ValueError("At least two time points must be given")

    # Average the score over [t_1, t_max] using the trapezoidal rule.
    time_span = eval_times[-1] - eval_times[0]
    return np.trapz(scores, eval_times) / time_span
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
def _estimator_has(attr):
|
|
789
|
+
"""Check that meta_estimator has `attr`.
|
|
790
|
+
|
|
791
|
+
Used together with `available_if`."""
|
|
792
|
+
|
|
793
|
+
def check(self):
|
|
794
|
+
# raise original `AttributeError` if `attr` does not exist
|
|
795
|
+
getattr(self.estimator_, attr)
|
|
796
|
+
return True
|
|
797
|
+
|
|
798
|
+
return check
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
class _ScoreOverrideMixin:
    def __init__(self, estimator, predict_func, score_func, score_index, greater_is_better):
        if not hasattr(estimator, predict_func):
            raise AttributeError(f"{estimator!r} object has no attribute {predict_func!r}")

        self.estimator = estimator
        self._predict_func = predict_func
        self._score_func = score_func
        self._score_index = score_index
        # Losses are sign-flipped so a greater score is always better.
        self._sign = 1 if greater_is_better else -1

    def _get_score_params(self):
        """Return dict of parameters passed to ``score_func``."""
        score_params = self.get_params(deep=False)
        score_params.pop("estimator")
        return score_params

    def fit(self, X, y, **fit_params):
        # Keep a copy of the training outcomes; IPCW-based metrics need them
        # at scoring time to estimate the censoring distribution.
        self._train_y = np.array(y, copy=True)
        self.estimator_ = self.estimator.fit(X, y, **fit_params)
        return self

    def _do_predict(self, X):
        return getattr(self.estimator_, self._predict_func)(X)

    def score(self, X, y):
        """Returns the score on the given data.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Input data, where n_samples is the number of samples and
            n_features is the number of features.

        y : array-like, shape = (n_samples,)
            Target relative to X for classification or regression;
            None for unsupervised learning.

        Returns
        -------
        score : float
        """
        estimate = self._do_predict(X)
        result = self._score_func(
            survival_train=self._train_y,
            survival_test=y,
            estimate=estimate,
            **self._get_score_params(),
        )
        if self._score_index is not None:
            result = result[self._score_index]
        return self._sign * result

    @available_if(_estimator_has("predict"))
    def predict(self, X):
        """Call predict on the estimator.

        Only available if estimator supports ``predict``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict(X)

    @available_if(_estimator_has("predict_cumulative_hazard_function"))
    def predict_cumulative_hazard_function(self, X):
        """Call predict_cumulative_hazard_function on the estimator.

        Only available if estimator supports ``predict_cumulative_hazard_function``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict_cumulative_hazard_function(X)

    @available_if(_estimator_has("predict_survival_function"))
    def predict_survival_function(self, X):
        """Call predict_survival_function on the estimator.

        Only available if estimator supports ``predict_survival_function``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict_survival_function(X)
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
class as_cumulative_dynamic_auc_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Wraps an estimator to use :func:`cumulative_dynamic_auc` as ``score`` function.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`
    for using it for hyper-parameter optimization.

    Parameters
    ----------
    estimator : object
        Instance of an estimator.
    times : array-like, shape = (n_times,)
        The time points at which to compute the AUC. Values must be within the
        range of follow-up times of the test data.
    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties in risk scores. If the
        absolute difference between two risk scores is smaller than or equal to
        ``tied_tol``, they are considered tied.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    cumulative_dynamic_auc
    """

    def __init__(self, estimator, times, tied_tol=1e-8):
        # score_index=1 picks the mean AUC from the (scores, mean_auc) pair.
        super().__init__(estimator, "predict", cumulative_dynamic_auc, 1, True)
        self.times = times
        self.tied_tol = tied_tol
|
|
939
|
+
|
|
940
|
+
|
|
941
|
+
class as_concordance_index_ipcw_scorer(_ScoreOverrideMixin, BaseEstimator):
    r"""Wraps an estimator to use :func:`concordance_index_ipcw` as ``score`` function.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`
    for using it for hyper-parameter optimization.

    Parameters
    ----------
    estimator : object
        Instance of an estimator.
    tau : float, optional
        Truncation time. The survival function for the underlying
        censoring time distribution :math:`D` needs to be positive
        at `tau`, i.e., `tau` should be chosen such that the
        probability of being censored after time `tau` is non-zero:
        :math:`P(D > \tau) > 0`. If `None`, no truncation is performed.
    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties in risk scores.
        If the absolute difference between two risk scores is smaller than
        or equal to ``tied_tol``, they are considered tied.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    concordance_index_ipcw
    """

    def __init__(self, estimator, tau=None, tied_tol=1e-8):
        # score_index=0 picks the concordance index from the result tuple.
        super().__init__(estimator, "predict", concordance_index_ipcw, 0, True)
        self.tau = tau
        self.tied_tol = tied_tol
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
class as_integrated_brier_score_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Wraps an estimator to use the negative of :func:`integrated_brier_score` as ``score`` function.

    The estimator needs to be able to estimate survival functions via
    a ``predict_survival_function`` method.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`
    for using it for hyper-parameter optimization.

    Parameters
    ----------
    estimator : object
        Instance of an estimator that provides ``predict_survival_function``.
    times : array-like, shape = (n_times,)
        The time points at which to compute the Brier score. Values must be
        within the range of follow-up times of the test data.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    integrated_brier_score
    """

    def __init__(self, estimator, times):
        # greater_is_better=False: the IBS is a loss, so the sign is flipped.
        super().__init__(estimator, "predict_survival_function", integrated_brier_score, None, False)
        self.times = times

    def _do_predict(self, X):
        # Evaluate every predicted survival function at ``self.times`` and
        # collect the results into an (n_samples, n_times) matrix.
        surv_fns = getattr(self.estimator_, self._predict_func)(X)
        eval_times = self.times
        estimates = np.empty((len(surv_fns), len(eval_times)), dtype=float)
        for row, fn in enumerate(surv_fns):
            estimates[row] = fn(eval_times)
        return estimates
|