scikit-survival 0.23.1__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.23.1.dist-info/COPYING +674 -0
- scikit_survival-0.23.1.dist-info/METADATA +888 -0
- scikit_survival-0.23.1.dist-info/RECORD +55 -0
- scikit_survival-0.23.1.dist-info/WHEEL +5 -0
- scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
- sksurv/__init__.py +138 -0
- sksurv/base.py +103 -0
- sksurv/bintrees/__init__.py +15 -0
- sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
- sksurv/column.py +201 -0
- sksurv/compare.py +123 -0
- sksurv/datasets/__init__.py +10 -0
- sksurv/datasets/base.py +436 -0
- sksurv/datasets/data/GBSG2.arff +700 -0
- sksurv/datasets/data/actg320.arff +1169 -0
- sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
- sksurv/datasets/data/flchain.arff +7887 -0
- sksurv/datasets/data/veteran.arff +148 -0
- sksurv/datasets/data/whas500.arff +520 -0
- sksurv/ensemble/__init__.py +2 -0
- sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
- sksurv/ensemble/boosting.py +1610 -0
- sksurv/ensemble/forest.py +947 -0
- sksurv/ensemble/survival_loss.py +151 -0
- sksurv/exceptions.py +18 -0
- sksurv/functions.py +114 -0
- sksurv/io/__init__.py +2 -0
- sksurv/io/arffread.py +58 -0
- sksurv/io/arffwrite.py +145 -0
- sksurv/kernels/__init__.py +1 -0
- sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
- sksurv/kernels/clinical.py +328 -0
- sksurv/linear_model/__init__.py +3 -0
- sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
- sksurv/linear_model/aft.py +205 -0
- sksurv/linear_model/coxnet.py +543 -0
- sksurv/linear_model/coxph.py +618 -0
- sksurv/meta/__init__.py +4 -0
- sksurv/meta/base.py +35 -0
- sksurv/meta/ensemble_selection.py +642 -0
- sksurv/meta/stacking.py +349 -0
- sksurv/metrics.py +996 -0
- sksurv/nonparametric.py +588 -0
- sksurv/preprocessing.py +155 -0
- sksurv/svm/__init__.py +11 -0
- sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
- sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
- sksurv/svm/minlip.py +606 -0
- sksurv/svm/naive_survival_svm.py +221 -0
- sksurv/svm/survival_svm.py +1228 -0
- sksurv/testing.py +108 -0
- sksurv/tree/__init__.py +1 -0
- sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
- sksurv/tree/tree.py +703 -0
- sksurv/util.py +333 -0
sksurv/metrics.py
ADDED
|
@@ -0,0 +1,996 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import numpy as np
|
|
14
|
+
from sklearn.base import BaseEstimator
|
|
15
|
+
from sklearn.utils import check_array, check_consistent_length
|
|
16
|
+
from sklearn.utils.metaestimators import available_if
|
|
17
|
+
from sklearn.utils.validation import check_is_fitted
|
|
18
|
+
|
|
19
|
+
from .exceptions import NoComparablePairException
|
|
20
|
+
from .nonparametric import CensoringDistributionEstimator, SurvivalFunctionEstimator
|
|
21
|
+
from .util import check_y_survival
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"as_concordance_index_ipcw_scorer",
|
|
25
|
+
"as_cumulative_dynamic_auc_scorer",
|
|
26
|
+
"as_integrated_brier_score_scorer",
|
|
27
|
+
"brier_score",
|
|
28
|
+
"concordance_index_censored",
|
|
29
|
+
"concordance_index_ipcw",
|
|
30
|
+
"cumulative_dynamic_auc",
|
|
31
|
+
"integrated_brier_score",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _check_estimate_1d(estimate, test_time):
    """Validate that `estimate` is one-dimensional and matches `test_time` in length.

    Parameters
    ----------
    estimate : array-like
        Predicted values to validate.

    test_time : array-like
        Reference array; `estimate` is checked for a consistent length
        against it.

    Returns
    -------
    estimate : ndarray
        The array returned by ``check_array``.

    Raises
    ------
    ValueError
        If `estimate` does not have exactly one dimension.
    """
    estimate = check_array(estimate, ensure_2d=False, input_name="estimate")
    if estimate.ndim == 1:
        # only a 1-d estimate is compared against the number of test samples
        check_consistent_length(test_time, estimate)
        return estimate
    raise ValueError(f"Expected 1D array, got {estimate.ndim}D array instead:\narray={estimate}.\n")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _check_inputs(event_indicator, event_time, estimate):
    """Validate the three arrays needed to compute a concordance index.

    Parameters
    ----------
    event_indicator : array-like
        Boolean array, ``True`` where an event was observed.

    event_time : array-like
        Time of event or time of censoring.

    estimate : array-like
        Predicted risk scores; must be one-dimensional.

    Returns
    -------
    event_indicator, event_time, estimate : tuple of ndarray
        The validated arrays, in the order they were passed.

    Raises
    ------
    ValueError
        If `event_indicator` is not of boolean dtype, if fewer than two
        samples are given, or if no sample experienced an event.
    """
    check_consistent_length(event_indicator, event_time, estimate)
    event_indicator = check_array(event_indicator, ensure_2d=False, input_name="event_indicator")
    event_time = check_array(event_time, ensure_2d=False, input_name="event_time")
    estimate = _check_estimate_1d(estimate, event_time)

    is_boolean = np.issubdtype(event_indicator.dtype, np.bool_)
    if not is_boolean:
        raise ValueError(
            f"only boolean arrays are supported as class labels for survival analysis, got {event_indicator.dtype}"
        )

    n_samples = len(event_time)
    if n_samples < 2:
        raise ValueError("Need a minimum of two samples")

    has_event = event_indicator.any()
    if not has_event:
        raise ValueError("All samples are censored")

    return event_indicator, event_time, estimate
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _check_times(test_time, times):
    """Validate evaluation time points against the follow-up range of the test data.

    Parameters
    ----------
    test_time : ndarray
        Observed times of the test data.

    times : scalar or array-like
        Time points at which a metric is to be evaluated.

    Returns
    -------
    times : ndarray
        Sorted unique time points.

    Raises
    ------
    ValueError
        If any time point is smaller than the minimum of `test_time`, or
        greater than or equal to the maximum of `test_time`.
    """
    times = np.unique(check_array(np.atleast_1d(times), ensure_2d=False, input_name="times"))

    # valid range is half-open: [min(test_time); max(test_time)[
    outside_follow_up = times.max() >= test_time.max() or times.min() < test_time.min()
    if outside_follow_up:
        raise ValueError(
            f"all times must be within follow-up time of test data: [{test_time.min()}; {test_time.max()}["
        )

    return times
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _check_estimate_2d(estimate, test_time, time_points, estimator):
    """Validate a (possibly time-dependent) estimate together with its time points.

    Parameters
    ----------
    estimate : array-like
        Predictions; either 1-d, or 2-d with one column per time point.

    test_time : ndarray
        Observed times of the test data.

    time_points : scalar or array-like
        Time points the columns of `estimate` refer to.

    estimator : str
        Name forwarded to ``check_array`` for use in its error messages.

    Returns
    -------
    estimate : ndarray
        The validated estimate.

    time_points : ndarray
        Sorted unique time points as returned by ``_check_times``.

    Raises
    ------
    ValueError
        If `estimate` is 2-d and its number of columns differs from the
        number of unique time points.
    """
    estimate = check_array(estimate, ensure_2d=False, allow_nd=False, input_name="estimate", estimator=estimator)
    time_points = _check_times(test_time, time_points)
    check_consistent_length(test_time, estimate)

    # the column count can only be verified for a 2-d estimate
    if estimate.ndim == 2:
        if estimate.shape[1] != time_points.shape[0]:
            raise ValueError(f"expected estimate with {time_points.shape[0]} columns, but got {estimate.shape[1]}")

    return estimate, time_points
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _iter_comparable(event_indicator, event_time, order):
    """Iterate over uncensored samples together with the samples comparable to them.

    Samples are visited in the sequence given by `order` and grouped into runs
    of identical time. For every uncensored sample in a run, the comparable
    samples are all samples after the run plus the censored samples inside
    the run.

    Parameters
    ----------
    event_indicator : ndarray of bool, shape = (n_samples,)
        ``True`` where an event was observed.

    event_time : ndarray, shape = (n_samples,)
        Time of event or time of censoring.

    order : ndarray of int, shape = (n_samples,)
        Indices that sort `event_time` in ascending order.

    Yields
    ------
    j : int
        Position within `order` of an uncensored sample.

    mask : ndarray of bool, shape = (n_samples,)
        Marks the positions within `order` that are comparable to sample `j`.
        The same array object is yielded for all events of one run.

    tied_time : int
        Running total of comparable pairs that share the same time.
    """
    n_samples = len(event_time)
    tied_time = 0
    i = 0
    # a trailing run consisting of a single sample has nothing after it to compare with
    while i < n_samples - 1:
        time_i = event_time[order[i]]
        end = i + 1
        while end < n_samples and event_time[order[end]] == time_i:
            end += 1

        # check for tied event times
        censored_at_same_time = ~event_indicator[order[i:end]]
        # constant for the whole run, so compute it once instead of per event
        n_censored_at_same_time = censored_at_same_time.sum()

        mask = np.zeros(n_samples, dtype=bool)
        mask[end:] = True
        # an event is comparable to censored samples at same time point
        mask[i:end] = censored_at_same_time

        for j in range(i, end):
            if event_indicator[order[j]]:
                tied_time += n_censored_at_same_time
                yield (j, mask, tied_time)
        i = end
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _estimate_concordance_index(event_indicator, event_time, estimate, weights, tied_tol=1e-8):
    """Compute a weighted concordance index and the associated pair counts.

    Parameters
    ----------
    event_indicator : ndarray of bool, shape = (n_samples,)
        ``True`` where an event was observed.

    event_time : ndarray, shape = (n_samples,)
        Time of event or time of censoring.

    estimate : ndarray, shape = (n_samples,)
        Estimated risk of experiencing an event.

    weights : ndarray, shape = (n_samples,)
        Weight applied to all pairs anchored at the respective sample.

    tied_tol : float, optional, default: 1e-8
        Risk scores whose absolute difference is at most `tied_tol`
        are treated as tied.

    Returns
    -------
    cindex : float
        Weighted share of concordant pairs, counting tied risk scores as one half.

    concordant : int
        Number of concordant pairs.

    discordant : int
        Number of discordant pairs.

    tied_risk : int
        Number of pairs having tied estimated risks.

    tied_time : int
        Number of comparable pairs sharing the same time.

    Raises
    ------
    NoComparablePairException
        If iterating over comparable pairs yields nothing.
    """
    order = np.argsort(event_time)

    numerator = 0.0
    denominator = 0.0
    concordant = 0
    discordant = 0
    tied_risk = 0
    # remains None if the loop below never executes
    tied_time = None

    for ind, mask, tied_time in _iter_comparable(event_indicator, event_time, order):
        sample = order[ind]
        est_i = estimate[sample]
        event_i = event_indicator[sample]
        w_i = weights[sample]

        est = estimate[order[mask]]

        assert event_i, f"got censored sample at index {order[ind]}, but expected uncensored"

        ties = np.absolute(est - est_i) <= tied_tol
        n_ties = ties.sum()
        # an event should have a higher score
        n_con = ((est < est_i) & ~ties).sum()

        numerator += w_i * n_con + 0.5 * w_i * n_ties
        denominator += w_i * mask.sum()

        tied_risk += n_ties
        concordant += n_con
        discordant += est.size - n_con - n_ties

    if tied_time is None:
        raise NoComparablePairException("Data has no comparable pairs, cannot estimate concordance index.")

    return numerator / denominator, concordant, discordant, tied_risk, tied_time
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def concordance_index_censored(event_indicator, event_time, estimate, tied_tol=1e-8):
    """Concordance index for right-censored data

    The concordance index is the fraction of all comparable pairs of samples
    for which predictions and outcomes are concordant.

    A pair of samples is comparable if (i) both samples experienced an event
    (at different times), or (ii) the sample with the shorter observed survival
    time experienced an event, which means the event-free subject "outlived"
    the other one. Two samples that experienced events at the same time
    are not comparable.

    A pair is concordant if the model ordered the two samples correctly, i.e.,
    the sample with the higher estimated risk score has the shorter actual
    survival time. A pair with identical predicted risks adds 0.5 instead of 1
    to the count of concordant pairs.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb>`
    and [1]_ for further description.

    Parameters
    ----------
    event_indicator : array-like, shape = (n_samples,)
        Boolean array that denotes whether an event occurred

    event_time : array-like, shape = (n_samples,)
        Array containing the time of an event or time of censoring

    estimate : array-like, shape = (n_samples,)
        Estimated risk of experiencing an event

    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties.
        If the absolute difference between risk scores is less than
        or equal to `tied_tol`, risk scores are considered tied.

    Returns
    -------
    cindex : float
        Concordance index

    concordant : int
        Number of concordant pairs

    discordant : int
        Number of discordant pairs

    tied_risk : int
        Number of pairs having tied estimated risks

    tied_time : int
        Number of comparable pairs sharing the same time

    See also
    --------
    concordance_index_ipcw
        Alternative estimator of the concordance index with less bias.

    References
    ----------
    .. [1] Harrell, F.E., Califf, R.M., Pryor, D.B., Lee, K.L., Rosati, R.A,
           "Multivariable prognostic models: issues in developing models,
           evaluating assumptions and adequacy, and measuring and reducing errors",
           Statistics in Medicine, 15(4), 361-87, 1996.
    """
    event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)

    # every sample enters the estimator with the same weight of one
    unit_weights = np.ones_like(estimate)
    return _estimate_concordance_index(event_indicator, event_time, estimate, unit_weights, tied_tol)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def concordance_index_ipcw(survival_train, survival_test, estimate, tau=None, tied_tol=1e-8):
    """Concordance index for right-censored data based on inverse probability of censoring weights.

    In contrast to :func:`concordance_index_censored`, this estimator does
    not depend on the distribution of censoring times in the test data.
    Therefore, the estimate is unbiased and consistent for a population
    concordance measure that is free of censoring.

    Because it relies on inverse probability of censoring weights, survival
    times from the training data are needed to estimate the censoring
    distribution. This in turn requires that the survival times in
    `survival_test` lie within the range of the survival times in
    `survival_train`, which can be achieved by specifying the truncation
    time `tau`.
    The resulting `cindex` tells how well the given prediction model works in
    predicting events that occur in the time range from 0 to `tau`.

    The censoring survivor function is estimated with the Kaplan-Meier
    estimator. The estimator is therefore restricted to situations where
    the random censoring assumption holds and censoring is independent
    of the features.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb>`
    and [1]_ for further description.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for training data to estimate the censoring
        distribution from.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    survival_test : structured array, shape = (n_samples,)
        Survival times of test data.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    estimate : array-like, shape = (n_samples,)
        Estimated risk of experiencing an event of test data.

    tau : float, optional
        Truncation time. The survival function for the underlying
        censoring time distribution :math:`D` needs to be positive
        at `tau`, i.e., `tau` should be chosen such that the
        probability of being censored after time `tau` is non-zero:
        :math:`P(D > \\tau) > 0`. If `None`, no truncation is performed.

    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties.
        If the absolute difference between risk scores is less than
        or equal to `tied_tol`, risk scores are considered tied.

    Returns
    -------
    cindex : float
        Concordance index

    concordant : int
        Number of concordant pairs

    discordant : int
        Number of discordant pairs

    tied_risk : int
        Number of pairs having tied estimated risks

    tied_time : int
        Number of comparable pairs sharing the same time

    See also
    --------
    concordance_index_censored
        Simpler estimator of the concordance index.

    as_concordance_index_ipcw_scorer
        Wrapper class that uses :func:`concordance_index_ipcw`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] Uno, H., Cai, T., Pencina, M. J., D’Agostino, R. B., & Wei, L. J. (2011).
           "On the C-statistics for evaluating overall adequacy of risk prediction
           procedures with censored survival data".
           Statistics in Medicine, 30(10), 1105–1117.
    """
    test_event, test_time = check_y_survival(survival_test)

    truncate = tau is not None
    if truncate:
        # weights are only estimated for samples observed before tau
        mask = test_time < tau
        survival_test = survival_test[mask]

    estimate = _check_estimate_1d(estimate, test_time)

    cens = CensoringDistributionEstimator()
    cens.fit(survival_train)
    ipcw_test = cens.predict_ipcw(survival_test)

    if truncate:
        # samples at or beyond tau keep a weight of zero
        ipcw = np.zeros(estimate.shape[0], dtype=ipcw_test.dtype)
        ipcw[mask] = ipcw_test
    else:
        ipcw = ipcw_test

    squared_ipcw = np.square(ipcw)

    return _estimate_concordance_index(test_event, test_time, estimate, squared_ipcw, tied_tol)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def cumulative_dynamic_auc(survival_train, survival_test, estimate, times, tied_tol=1e-8):
    """Estimator of cumulative/dynamic AUC for right-censored time-to-event data.

    The receiver operating characteristic (ROC) curve and the area under the
    ROC curve (AUC) can be extended to survival data by defining
    sensitivity (true positive rate) and specificity (true negative rate)
    as time-dependent measures. *Cumulative cases* are all individuals that
    experienced an event prior to or at time :math:`t` (:math:`t_i \\leq t`),
    whereas *dynamic controls* are those with :math:`t_i > t`.
    The associated cumulative/dynamic AUC quantifies how well a model can
    distinguish subjects who fail by a given time (:math:`t_i \\leq t`) from
    subjects who fail after this time (:math:`t_i > t`).

    Given an estimator of the :math:`i`-th individual's risk score
    :math:`\\hat{f}(\\mathbf{x}_i)`, the cumulative/dynamic AUC at time
    :math:`t` is defined as

    .. math::

        \\widehat{\\mathrm{AUC}}(t) =
        \\frac{\\sum_{i=1}^n \\sum_{j=1}^n I(y_j > t) I(y_i \\leq t) \\omega_i
        I(\\hat{f}(\\mathbf{x}_j) \\leq \\hat{f}(\\mathbf{x}_i))}
        {(\\sum_{i=1}^n I(y_i > t)) (\\sum_{i=1}^n I(y_i \\leq t) \\omega_i)}

    where :math:`\\omega_i` are inverse probability of censoring weights (IPCW).

    To estimate IPCW, access to survival times from the training data is required
    to estimate the censoring distribution. Note that this requires that survival
    times `survival_test` lie within the range of survival times `survival_train`.
    This can be achieved by specifying `times` accordingly, e.g. by setting
    `times[-1]` slightly below the maximum expected follow-up time.
    IPCW are computed using the Kaplan-Meier estimator, which is
    restricted to situations where the random censoring assumption holds and
    censoring is independent of the features.

    This function can also be used to evaluate models with time-dependent predictions
    :math:`\\hat{f}(\\mathbf{x}_i, t)`, such as :class:`sksurv.ensemble.RandomSurvivalForest`
    (see :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Time-dependent-Risk-Scores>`).
    In this case, `estimate` must be a 2-d array where ``estimate[i, j]`` is the
    predicted risk score for the i-th instance at time point ``times[j]``.

    Finally, the function also provides a single summary measure that refers to the mean
    of the :math:`\\mathrm{AUC}(t)` over the time range :math:`(\\tau_1, \\tau_2)`.

    .. math::

        \\overline{\\mathrm{AUC}}(\\tau_1, \\tau_2) =
        \\frac{1}{\\hat{S}(\\tau_1) - \\hat{S}(\\tau_2)}
        \\int_{\\tau_1}^{\\tau_2} \\widehat{\\mathrm{AUC}}(t)\\,d \\hat{S}(t)

    where :math:`\\hat{S}(t)` is the Kaplan–Meier estimator of the survival function.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Area-under-the-ROC>`,
    [1]_, [2]_, [3]_ for further description.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for training data to estimate the censoring
        distribution from.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    survival_test : structured array, shape = (n_samples,)
        Survival times of test data.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    estimate : array-like, shape = (n_samples,) or (n_samples, n_times)
        Estimated risk of experiencing an event of test data.
        If `estimate` is a 1-d array, the same risk score across all time
        points is used. If `estimate` is a 2-d array, the risk scores in the
        j-th column are used to evaluate the j-th time point.

    times : array-like, shape = (n_times,)
        The time points for which the area under the
        time-dependent ROC curve is computed. Values must be
        within the range of follow-up times of the test data
        `survival_test`.

    tied_tol : float, optional, default: 1e-8
        The tolerance value for considering ties.
        If the absolute difference between risk scores is smaller
        or equal than `tied_tol`, risk scores are considered tied.

    Returns
    -------
    auc : array, shape = (n_times,)
        The cumulative/dynamic AUC estimates (evaluated at `times`).
    mean_auc : float
        Summary measure referring to the mean cumulative/dynamic AUC
        over the specified time range `(times[0], times[-1])`.

    See also
    --------
    as_cumulative_dynamic_auc_scorer
        Wrapper class that uses :func:`cumulative_dynamic_auc`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] H. Uno, T. Cai, L. Tian, and L. J. Wei,
           "Evaluating prediction rules for t-year survivors with censored regression models,"
           Journal of the American Statistical Association, vol. 102, pp. 527–537, 2007.
    .. [2] H. Hung and C. T. Chiang,
           "Estimation methods for time-dependent AUC models with survival data,"
           Canadian Journal of Statistics, vol. 38, no. 1, pp. 8–26, 2010.
    .. [3] J. Lambert and S. Chevret,
           "Summary measure of discrimination in survival models based on cumulative/dynamic time-dependent ROC curves,"
           Statistical Methods in Medical Research, 2014.
    """
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``;
    # prefer the new name and fall back to the old one on NumPy 1.x.
    trapezoid = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    test_event, test_time = check_y_survival(survival_test)
    estimate, times = _check_estimate_2d(estimate, test_time, times, estimator="cumulative_dynamic_auc")

    n_samples = estimate.shape[0]
    n_times = times.shape[0]
    if estimate.ndim == 1:
        # time-independent risk score: use the same column for every time point
        estimate = np.broadcast_to(estimate[:, np.newaxis], (n_samples, n_times))

    # fit and transform IPCW
    cens = CensoringDistributionEstimator()
    cens.fit(survival_train)
    ipcw = cens.predict_ipcw(survival_test)

    # expand arrays to (n_samples, n_times) shape
    test_time = np.broadcast_to(test_time[:, np.newaxis], (n_samples, n_times))
    test_event = np.broadcast_to(test_event[:, np.newaxis], (n_samples, n_times))
    times_2d = np.broadcast_to(times, (n_samples, n_times))
    ipcw = np.broadcast_to(ipcw[:, np.newaxis], (n_samples, n_times))

    # sort each time point (columns) by risk score (descending)
    o = np.argsort(-estimate, axis=0)
    test_time = np.take_along_axis(test_time, o, axis=0)
    test_event = np.take_along_axis(test_event, o, axis=0)
    estimate = np.take_along_axis(estimate, o, axis=0)
    ipcw = np.take_along_axis(ipcw, o, axis=0)

    is_case = (test_time <= times_2d) & test_event
    is_control = test_time > times_2d
    n_controls = is_control.sum(axis=0)

    # prepend row of infinity values
    # (so the first row of each column is never flagged as tied with a predecessor)
    estimate_diff = np.concatenate((np.broadcast_to(np.inf, (1, n_times)), estimate))
    is_tied = np.absolute(np.diff(estimate_diff, axis=0)) <= tied_tol

    cumsum_tp = np.cumsum(is_case * ipcw, axis=0)
    cumsum_fp = np.cumsum(is_control, axis=0)
    true_pos = cumsum_tp / cumsum_tp[-1]
    false_pos = cumsum_fp / n_controls

    scores = np.empty(n_times, dtype=float)
    # order="F" with external_loop hands out one column (time point) per iteration
    it = np.nditer((true_pos, false_pos, is_tied), order="F", flags=["external_loop"])
    with it:
        for i, (tp, fp, mask) in enumerate(it):
            # a row flagged as tied supersedes the row directly before it
            idx = np.flatnonzero(mask) - 1
            # only keep the last estimate for tied risk scores
            tp_no_ties = np.delete(tp, idx)
            fp_no_ties = np.delete(fp, idx)
            # Add an extra threshold position
            # to make sure that the curve starts at (0, 0)
            tp_no_ties = np.r_[0, tp_no_ties]
            fp_no_ties = np.r_[0, fp_no_ties]
            scores[i] = trapezoid(tp_no_ties, fp_no_ties)

    if n_times == 1:
        mean_auc = scores[0]
    else:
        surv = SurvivalFunctionEstimator()
        surv.fit(survival_test)
        s_times = surv.predict_proba(times)
        # compute integral of AUC over survival function
        d = -np.diff(np.r_[1.0, s_times])
        integral = (scores * d).sum()
        mean_auc = integral / (1.0 - s_times[-1])

    return scores, mean_auc
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def brier_score(survival_train, survival_test, estimate, times):
    """Estimate the time-dependent Brier score for right censored data.

    At a time point :math:`t`, the time-dependent Brier score is the mean squared error:

    .. math::

        \\mathrm{BS}^c(t) = \\frac{1}{n} \\sum_{i=1}^n I(y_i \\leq t \\land \\delta_i = 1)
        \\frac{(0 - \\hat{\\pi}(t | \\mathbf{x}_i))^2}{\\hat{G}(y_i)} + I(y_i > t)
        \\frac{(1 - \\hat{\\pi}(t | \\mathbf{x}_i))^2}{\\hat{G}(t)} ,

    where :math:`\\hat{\\pi}(t | \\mathbf{x})` denotes the predicted probability of
    remaining event-free up to time point :math:`t` for a feature vector :math:`\\mathbf{x}`,
    and :math:`1/\\hat{G}(t)` is an inverse probability of censoring weight, estimated by
    the Kaplan-Meier estimator.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Brier-Score>`
    and [1]_ for details.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for training data to estimate the censoring
        distribution from.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    survival_test : structured array, shape = (n_samples,)
        Survival times of test data.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    estimate : array-like, shape = (n_samples, n_times)
        Estimated probability of remaining event-free at time points
        specified by `times`. The column ``estimate[:, i]`` must hold
        the estimated probability of remaining event-free up to the time point
        ``times[i]``. Typically, estimated probabilities are obtained via the
        survival function returned by an estimator's
        ``predict_survival_function`` method.
        A 1-d array is accepted when a single time point is evaluated.

    times : array-like, shape = (n_times,)
        The time points for which to estimate the Brier score.
        Values must be within the range of follow-up times of
        the test data `survival_test`.

    Returns
    -------
    times : array, shape = (n_times,)
        Unique time points at which the Brier score was estimated.

    brier_scores : array, shape = (n_times,)
        Values of the Brier score.

    Examples
    --------
    >>> from sksurv.datasets import load_gbsg2
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.metrics import brier_score
    >>> from sksurv.preprocessing import OneHotEncoder

    Load and prepare data.

    >>> X, y = load_gbsg2()
    >>> X.loc[:, "tgrade"] = X.loc[:, "tgrade"].map(len).astype(int)
    >>> Xt = OneHotEncoder().fit_transform(X)

    Fit a Cox model.

    >>> est = CoxPHSurvivalAnalysis(ties="efron").fit(Xt, y)

    Retrieve individual survival functions and get probability
    of remaining event free up to 5 years (=1825 days).

    >>> survs = est.predict_survival_function(Xt)
    >>> preds = [fn(1825) for fn in survs]

    Compute the Brier score at 5 years.

    >>> times, score = brier_score(y, y, preds, 1825)
    >>> print(score)
    [0.20881843]

    See also
    --------
    integrated_brier_score
        Computes the average Brier score over all time points.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher,
           "Assessment and comparison of prognostic classification schemes for survival data,"
           Statistics in Medicine, vol. 18, no. 17-18, pp. 2529–2545, 1999.
    """
    test_event, test_time = check_y_survival(survival_test)
    estimate, times = _check_estimate_2d(estimate, test_time, times, estimator="brier_score")
    n_times = times.shape[0]
    if estimate.ndim == 1 and n_times == 1:
        # a flat vector for a single time point is treated as one column
        estimate = estimate.reshape(-1, 1)

    # fit IPCW estimator
    cens = CensoringDistributionEstimator().fit(survival_train)

    # censoring survival probability at each evaluated time point t;
    # zeros become inf so that dividing by them contributes nothing
    prob_cens_t = cens.predict_proba(times)
    prob_cens_t[prob_cens_t == 0] = np.inf

    # censoring survival probability at each observed time point, same zero handling
    prob_cens_y = cens.predict_proba(test_time)
    prob_cens_y[prob_cens_y == 0] = np.inf

    # Calculating the brier scores at each time point
    brier_scores = np.empty(n_times, dtype=float)
    for col, t in enumerate(times):
        surv_prob = estimate[:, col]
        had_event_by_t = (test_time <= t) & test_event
        still_at_risk = test_time > t

        case_term = np.square(surv_prob) * had_event_by_t.astype(int) / prob_cens_y
        control_term = np.square(1.0 - surv_prob) * still_at_risk.astype(int) / prob_cens_t[col]
        brier_scores[col] = np.mean(case_term + control_term)

    return times, brier_scores
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def integrated_brier_score(survival_train, survival_test, estimate, times):
    """The Integrated Brier Score (IBS) provides an overall calculation of
    the model performance at all available times :math:`t_1 \\leq t \\leq t_\\text{max}`.

    The integrated time-dependent Brier score over the interval
    :math:`[t_1; t_\\text{max}]` is defined as

    .. math::

        \\mathrm{IBS} = \\int_{t_1}^{t_\\text{max}} \\mathrm{BS}^c(t) d w(t)

    where the weighting function is :math:`w(t) = t / t_\\text{max}`.
    The integral is estimated via the trapezoidal rule.

    See the :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Time-dependent-Brier-Score>`
    and [1]_ for further details.

    Parameters
    ----------
    survival_train : structured array, shape = (n_train_samples,)
        Survival times for training data to estimate the censoring
        distribution from.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    survival_test : structured array, shape = (n_samples,)
        Survival times of test data.
        A structured array containing the binary event indicator
        as first field, and time of event or time of censoring as
        second field.

    estimate : array-like, shape = (n_samples, n_times)
        Estimated probability of remaining event-free at time points
        specified by `times`. The value of ``estimate[i]`` must correspond to
        the estimated probability of remaining event-free up to the time point
        ``times[i]``. Typically, estimated probabilities are obtained via the
        survival function returned by an estimator's
        ``predict_survival_function`` method.

    times : array-like, shape = (n_times,)
        The time points for which to estimate the Brier score.
        Values must be within the range of follow-up times of
        the test data `survival_test`.

    Returns
    -------
    ibs : float
        The integrated Brier score.

    Raises
    ------
    ValueError
        If fewer than two time points remain after the time points have
        been processed by :func:`brier_score`.

    Examples
    --------
    >>> import numpy as np
    >>> from sksurv.datasets import load_gbsg2
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.metrics import integrated_brier_score
    >>> from sksurv.preprocessing import OneHotEncoder

    Load and prepare data.

    >>> X, y = load_gbsg2()
    >>> X.loc[:, "tgrade"] = X.loc[:, "tgrade"].map(len).astype(int)
    >>> Xt = OneHotEncoder().fit_transform(X)

    Fit a Cox model.

    >>> est = CoxPHSurvivalAnalysis(ties="efron").fit(Xt, y)

    Retrieve individual survival functions and get probability
    of remaining event free from 1 year to 5 years (=1825 days).

    >>> survs = est.predict_survival_function(Xt)
    >>> times = np.arange(365, 1826)
    >>> preds = np.asarray([[fn(t) for t in times] for fn in survs])

    Compute the integrated Brier score from 1 to 5 years.

    >>> score = integrated_brier_score(y, y, preds, times)
    >>> print(score)
    0.1815853064627424

    See also
    --------
    brier_score
        Computes the Brier score at specified time points.

    as_integrated_brier_score_scorer
        Wrapper class that uses :func:`integrated_brier_score`
        in its ``score`` method instead of the default
        :func:`concordance_index_censored`.

    References
    ----------
    .. [1] E. Graf, C. Schmoor, W. Sauerbrei, and M. Schumacher,
           "Assessment and comparison of prognostic classification schemes for survival data,"
           Statistics in Medicine, vol. 18, no. 17-18, pp. 2529–2545, 1999.
    """
    # Computing the brier scores
    times, brier_scores = brier_score(survival_train, survival_test, estimate, times)

    # An integral needs an interval, i.e. at least two distinct end points.
    if times.shape[0] < 2:
        raise ValueError("At least two time points must be given")

    # ``np.trapz`` was renamed to ``np.trapezoid`` in NumPy 2.0 and the old
    # name is deprecated there. Prefer the new name when it exists and fall
    # back to the old one on earlier NumPy releases.
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    # Computing the IBS: area under the Brier score curve, normalized by the
    # length of the integration interval.
    ibs_value = integrate(brier_scores, times) / (times[-1] - times[0])

    return ibs_value
|
|
746
|
+
|
|
747
|
+
|
|
748
|
+
def _estimator_has(attr):
    """Create a predicate telling whether the fitted estimator exposes `attr`.

    Intended to be passed to `available_if`. The returned callable takes
    the meta-estimator instance, looks up ``estimator_`` on it and then
    `attr` on that object. It returns ``True`` on success; a failed lookup
    propagates the original ``AttributeError``.
    """

    def check(self):
        fitted = self.estimator_
        # Plain attribute access on purpose: a missing attribute must
        # surface as the original `AttributeError`.
        getattr(fitted, attr)
        return True

    return check
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
class _ScoreOverrideMixin:
    """Mixin that routes ``score`` of a wrapped estimator through a custom metric.

    The concrete class is expected to also provide ``get_params`` (the
    wrappers in this module inherit it from ``BaseEstimator``). Every
    parameter returned by ``get_params(deep=False)`` except ``estimator``
    is forwarded to the metric as a keyword argument.

    Parameters
    ----------
    estimator : object
        Estimator to wrap. Must have an attribute named `predict_func`.

    predict_func : str
        Name of the method of the fitted estimator that is called to
        obtain the estimates passed to `score_func`.

    score_func : callable
        Metric, called with the keyword arguments ``survival_train``,
        ``survival_test``, ``estimate`` and the parameters returned by
        ``_get_score_params``.

    score_index : int or None
        If not None, the value returned by `score_func` is indexed with
        `score_index` to obtain the score.

    greater_is_better : bool
        If False, the sign of the score is flipped so that larger
        values returned by ``score`` always indicate a better model.

    Raises
    ------
    AttributeError
        If `estimator` has no attribute named `predict_func`.
    """

    def __init__(self, estimator, predict_func, score_func, score_index, greater_is_better):
        if not hasattr(estimator, predict_func):
            raise AttributeError(f"{estimator!r} object has no attribute {predict_func!r}")

        self.estimator = estimator
        self._predict_func = predict_func
        self._score_func = score_func
        self._score_index = score_index
        self._sign = 1 if greater_is_better else -1

    def _get_score_params(self):
        """Return dict of parameters passed to ``score_func``."""
        params = self.get_params(deep=False)
        # The wrapped estimator is not an argument of the metric.
        del params["estimator"]
        return params

    def fit(self, X, y, **fit_params):
        """Fit the wrapped estimator and remember the training targets.

        Parameters
        ----------
        X : array-like
            Training data, passed on to the wrapped estimator's ``fit``.

        y : array-like
            Training targets, passed on to the wrapped estimator's ``fit``.
            A copy is stored and later handed to the metric as
            ``survival_train``.

        **fit_params
            Additional keyword arguments passed to the wrapped
            estimator's ``fit``.

        Returns
        -------
        self
        """
        # Copy so that later modifications of `y` by the caller do not
        # affect the data used by `score`.
        self._train_y = np.array(y, copy=True)
        self.estimator_ = self.estimator.fit(X, y, **fit_params)
        return self

    def _do_predict(self, X):
        """Return the estimates for `X` that are passed to ``score_func``."""
        predict_func = getattr(self.estimator_, self._predict_func)
        return predict_func(X)

    def score(self, X, y):
        """Returns the score on the given data.

        The estimates for `X` are evaluated by the configured metric
        against `y`, using the targets seen during ``fit`` as
        ``survival_train``. If the wrapper was created with
        ``greater_is_better=False``, the negated metric is returned.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data, where n_samples is the number of samples and
            n_features is the number of features.

        y : array-like of shape (n_samples,)
            Test targets for `X`, passed to the metric as
            ``survival_test``.

        Returns
        -------
        score : float

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If ``fit`` has not been called yet.
        """
        # Fail early with a descriptive error, consistent with the
        # ``predict*`` methods. NotFittedError derives from AttributeError,
        # which is what an unfitted wrapper raised here before.
        check_is_fitted(self, "estimator_")
        estimate = self._do_predict(X)
        score = self._score_func(
            survival_train=self._train_y,
            survival_test=y,
            estimate=estimate,
            **self._get_score_params(),
        )
        if self._score_index is not None:
            # The metric returns several values; pick the configured one.
            score = score[self._score_index]
        return self._sign * score

    @available_if(_estimator_has("predict"))
    def predict(self, X):
        """Call predict on the estimator.

        Only available if estimator supports ``predict``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict(X)

    @available_if(_estimator_has("predict_cumulative_hazard_function"))
    def predict_cumulative_hazard_function(self, X):
        """Call predict_cumulative_hazard_function on the estimator.

        Only available if estimator supports ``predict_cumulative_hazard_function``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict_cumulative_hazard_function(X)

    @available_if(_estimator_has("predict_survival_function"))
    def predict_survival_function(self, X):
        """Call predict_survival_function on the estimator.

        Only available if estimator supports ``predict_survival_function``.

        Parameters
        ----------
        X : indexable, length n_samples
            Must fulfill the input assumptions of the
            underlying estimator.
        """
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict_survival_function(X)
|
|
859
|
+
|
|
860
|
+
|
|
861
|
+
class as_cumulative_dynamic_auc_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Estimator wrapper whose ``score`` is based on :func:`cumulative_dynamic_auc`.

    Predictions are obtained from the wrapped estimator's ``predict``
    method, and the score is the element at index 1 of the value returned
    by :func:`cumulative_dynamic_auc`.

    How to use it for hyper-parameter optimization is described in the
    :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`.

    Parameters
    ----------
    estimator : object
        Instance of an estimator.

    times : array-like, shape = (n_times,)
        Time points at which the area under the time-dependent
        ROC curve is computed. Values must lie within the range
        of follow-up times of the test data `survival_test`.

    tied_tol : float, optional, default: 1e-8
        Tolerance for treating risk scores as tied: two risk scores
        whose absolute difference is smaller or equal than `tied_tol`
        are considered tied.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    cumulative_dynamic_auc
    """

    def __init__(self, estimator, times, tied_tol=1e-8):
        mixin_config = {
            "estimator": estimator,
            "predict_func": "predict",
            "score_func": cumulative_dynamic_auc,
            "score_index": 1,
            "greater_is_better": True,
        }
        super().__init__(**mixin_config)
        # Parameters forwarded to the metric by ``score``.
        self.times = times
        self.tied_tol = tied_tol
|
|
903
|
+
|
|
904
|
+
|
|
905
|
+
class as_concordance_index_ipcw_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Estimator wrapper whose ``score`` is based on :func:`concordance_index_ipcw`.

    Predictions are obtained from the wrapped estimator's ``predict``
    method, and the score is the element at index 0 of the value returned
    by :func:`concordance_index_ipcw`.

    How to use it for hyper-parameter optimization is described in the
    :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`.

    Parameters
    ----------
    estimator : object
        Instance of an estimator.

    tau : float, optional
        Truncation time. The survival function for the underlying
        censoring time distribution :math:`D` needs to be positive
        at `tau`, i.e., `tau` should be chosen such that the
        probability of being censored after time `tau` is non-zero:
        :math:`P(D > \\tau) > 0`. If `None`, no truncation is performed.

    tied_tol : float, optional, default: 1e-8
        Tolerance for treating risk scores as tied: two risk scores
        whose absolute difference is smaller or equal than `tied_tol`
        are considered tied.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    concordance_index_ipcw
    """

    def __init__(self, estimator, tau=None, tied_tol=1e-8):
        mixin_config = {
            "estimator": estimator,
            "predict_func": "predict",
            "score_func": concordance_index_ipcw,
            "score_index": 0,
            "greater_is_better": True,
        }
        super().__init__(**mixin_config)
        # Parameters forwarded to the metric by ``score``.
        self.tau = tau
        self.tied_tol = tied_tol
|
|
948
|
+
|
|
949
|
+
|
|
950
|
+
class as_integrated_brier_score_scorer(_ScoreOverrideMixin, BaseEstimator):
    """Estimator wrapper whose ``score`` is the negative of :func:`integrated_brier_score`.

    The wrapped estimator must be able to estimate survival functions,
    i.e. it has to offer a ``predict_survival_function`` method.

    How to use it for hyper-parameter optimization is described in the
    :ref:`User Guide </user_guide/evaluating-survival-models.ipynb#Using-Metrics-in-Hyper-parameter-Search>`.

    Parameters
    ----------
    estimator : object
        Instance of an estimator that provides ``predict_survival_function``.

    times : array-like, shape = (n_times,)
        Time points at which the Brier score is estimated.
        Values must lie within the range of follow-up times of
        the test data `survival_test`.

    Attributes
    ----------
    estimator_ : estimator
        Estimator that was fit.

    See also
    --------
    integrated_brier_score
    """

    def __init__(self, estimator, times):
        mixin_config = {
            "estimator": estimator,
            "predict_func": "predict_survival_function",
            "score_func": integrated_brier_score,
            "score_index": None,
            "greater_is_better": False,
        }
        super().__init__(**mixin_config)
        # Parameter forwarded to the metric by ``score``.
        self.times = times

    def _do_predict(self, X):
        # Evaluate every predicted survival function at ``self.times``;
        # the result has one row per returned function and one column
        # per time point.
        surv_fns = getattr(self.estimator_, self._predict_func)(X)
        time_points = self.times
        estimates = np.empty((len(surv_fns), len(time_points)))
        # Each ``row`` is a view into ``estimates``, so the slice
        # assignment writes straight into the result array.
        for row, fn in zip(estimates, surv_fns):
            row[:] = fn(time_points)
        return estimates
|