scikit-survival 0.23.1__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.23.1.dist-info/COPYING +674 -0
- scikit_survival-0.23.1.dist-info/METADATA +888 -0
- scikit_survival-0.23.1.dist-info/RECORD +55 -0
- scikit_survival-0.23.1.dist-info/WHEEL +5 -0
- scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
- sksurv/__init__.py +138 -0
- sksurv/base.py +103 -0
- sksurv/bintrees/__init__.py +15 -0
- sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
- sksurv/column.py +201 -0
- sksurv/compare.py +123 -0
- sksurv/datasets/__init__.py +10 -0
- sksurv/datasets/base.py +436 -0
- sksurv/datasets/data/GBSG2.arff +700 -0
- sksurv/datasets/data/actg320.arff +1169 -0
- sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
- sksurv/datasets/data/flchain.arff +7887 -0
- sksurv/datasets/data/veteran.arff +148 -0
- sksurv/datasets/data/whas500.arff +520 -0
- sksurv/ensemble/__init__.py +2 -0
- sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
- sksurv/ensemble/boosting.py +1610 -0
- sksurv/ensemble/forest.py +947 -0
- sksurv/ensemble/survival_loss.py +151 -0
- sksurv/exceptions.py +18 -0
- sksurv/functions.py +114 -0
- sksurv/io/__init__.py +2 -0
- sksurv/io/arffread.py +58 -0
- sksurv/io/arffwrite.py +145 -0
- sksurv/kernels/__init__.py +1 -0
- sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
- sksurv/kernels/clinical.py +328 -0
- sksurv/linear_model/__init__.py +3 -0
- sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
- sksurv/linear_model/aft.py +205 -0
- sksurv/linear_model/coxnet.py +543 -0
- sksurv/linear_model/coxph.py +618 -0
- sksurv/meta/__init__.py +4 -0
- sksurv/meta/base.py +35 -0
- sksurv/meta/ensemble_selection.py +642 -0
- sksurv/meta/stacking.py +349 -0
- sksurv/metrics.py +996 -0
- sksurv/nonparametric.py +588 -0
- sksurv/preprocessing.py +155 -0
- sksurv/svm/__init__.py +11 -0
- sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
- sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
- sksurv/svm/minlip.py +606 -0
- sksurv/svm/naive_survival_svm.py +221 -0
- sksurv/svm/survival_svm.py +1228 -0
- sksurv/testing.py +108 -0
- sksurv/tree/__init__.py +1 -0
- sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
- sksurv/tree/tree.py +703 -0
- sksurv/util.py +333 -0
sksurv/nonparametric.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
1
|
+
# This program is free software: you can redistribute it and/or modify
|
|
2
|
+
# it under the terms of the GNU General Public License as published by
|
|
3
|
+
# the Free Software Foundation, either version 3 of the License, or
|
|
4
|
+
# (at your option) any later version.
|
|
5
|
+
#
|
|
6
|
+
# This program is distributed in the hope that it will be useful,
|
|
7
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
8
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
9
|
+
# GNU General Public License for more details.
|
|
10
|
+
#
|
|
11
|
+
# You should have received a copy of the GNU General Public License
|
|
12
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import numbers
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
from scipy import stats
|
|
17
|
+
from sklearn.base import BaseEstimator
|
|
18
|
+
from sklearn.utils._param_validation import Interval, StrOptions
|
|
19
|
+
from sklearn.utils.validation import check_array, check_consistent_length, check_is_fitted
|
|
20
|
+
|
|
21
|
+
from .util import check_y_survival
|
|
22
|
+
|
|
23
|
+
# Names exported by ``from sksurv.nonparametric import *``.
__all__ = [
    "CensoringDistributionEstimator",
    "kaplan_meier_estimator",
    "nelson_aalen_estimator",
    "ipc_weights",
    "SurvivalFunctionEstimator",
]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _compute_counts(event, time, order=None):
|
|
33
|
+
"""Count right censored and uncensored samples at each unique time point.
|
|
34
|
+
|
|
35
|
+
Parameters
|
|
36
|
+
----------
|
|
37
|
+
event : array
|
|
38
|
+
Boolean event indicator.
|
|
39
|
+
|
|
40
|
+
time : array
|
|
41
|
+
Survival time or time of censoring.
|
|
42
|
+
|
|
43
|
+
order : array or None
|
|
44
|
+
Indices to order time in ascending order.
|
|
45
|
+
If None, order will be computed.
|
|
46
|
+
|
|
47
|
+
Returns
|
|
48
|
+
-------
|
|
49
|
+
times : array
|
|
50
|
+
Unique time points.
|
|
51
|
+
|
|
52
|
+
n_events : array
|
|
53
|
+
Number of events at each time point.
|
|
54
|
+
|
|
55
|
+
n_at_risk : array
|
|
56
|
+
Number of samples that have not been censored or have not had an event at each time point.
|
|
57
|
+
|
|
58
|
+
n_censored : array
|
|
59
|
+
Number of censored samples at each time point.
|
|
60
|
+
"""
|
|
61
|
+
n_samples = event.shape[0]
|
|
62
|
+
|
|
63
|
+
if order is None:
|
|
64
|
+
order = np.argsort(time, kind="mergesort")
|
|
65
|
+
|
|
66
|
+
uniq_times = np.empty(n_samples, dtype=time.dtype)
|
|
67
|
+
uniq_events = np.empty(n_samples, dtype=int)
|
|
68
|
+
uniq_counts = np.empty(n_samples, dtype=int)
|
|
69
|
+
|
|
70
|
+
i = 0
|
|
71
|
+
prev_val = time[order[0]]
|
|
72
|
+
j = 0
|
|
73
|
+
while True:
|
|
74
|
+
count_event = 0
|
|
75
|
+
count = 0
|
|
76
|
+
while i < n_samples and prev_val == time[order[i]]:
|
|
77
|
+
if event[order[i]]:
|
|
78
|
+
count_event += 1
|
|
79
|
+
|
|
80
|
+
count += 1
|
|
81
|
+
i += 1
|
|
82
|
+
|
|
83
|
+
uniq_times[j] = prev_val
|
|
84
|
+
uniq_events[j] = count_event
|
|
85
|
+
uniq_counts[j] = count
|
|
86
|
+
j += 1
|
|
87
|
+
|
|
88
|
+
if i == n_samples:
|
|
89
|
+
break
|
|
90
|
+
|
|
91
|
+
prev_val = time[order[i]]
|
|
92
|
+
|
|
93
|
+
times = np.resize(uniq_times, j)
|
|
94
|
+
n_events = np.resize(uniq_events, j)
|
|
95
|
+
total_count = np.resize(uniq_counts, j)
|
|
96
|
+
n_censored = total_count - n_events
|
|
97
|
+
|
|
98
|
+
# offset cumulative sum by one
|
|
99
|
+
total_count = np.r_[0, total_count]
|
|
100
|
+
n_at_risk = n_samples - np.cumsum(total_count)
|
|
101
|
+
|
|
102
|
+
return times, n_events, n_at_risk[:-1], n_censored
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _compute_counts_truncated(event, time_enter, time_exit):
|
|
106
|
+
"""Compute counts for left truncated and right censored survival data.
|
|
107
|
+
|
|
108
|
+
Parameters
|
|
109
|
+
----------
|
|
110
|
+
event : array
|
|
111
|
+
Boolean event indicator.
|
|
112
|
+
|
|
113
|
+
time_start : array
|
|
114
|
+
Time when a subject entered the study.
|
|
115
|
+
|
|
116
|
+
time_exit : array
|
|
117
|
+
Time when a subject left the study due to an
|
|
118
|
+
event or censoring.
|
|
119
|
+
|
|
120
|
+
Returns
|
|
121
|
+
-------
|
|
122
|
+
times : array
|
|
123
|
+
Unique time points.
|
|
124
|
+
|
|
125
|
+
n_events : array
|
|
126
|
+
Number of events at each time point.
|
|
127
|
+
|
|
128
|
+
n_at_risk : array
|
|
129
|
+
Number of samples that are censored or have an event at each time point.
|
|
130
|
+
"""
|
|
131
|
+
if (time_enter > time_exit).any():
|
|
132
|
+
raise ValueError("exit time must be larger start time for all samples")
|
|
133
|
+
|
|
134
|
+
n_samples = event.shape[0]
|
|
135
|
+
|
|
136
|
+
uniq_times = np.sort(np.unique(np.r_[time_enter, time_exit]), kind="mergesort")
|
|
137
|
+
total_counts = np.empty(len(uniq_times), dtype=int)
|
|
138
|
+
event_counts = np.empty(len(uniq_times), dtype=int)
|
|
139
|
+
|
|
140
|
+
order_enter = np.argsort(time_enter, kind="mergesort")
|
|
141
|
+
order_exit = np.argsort(time_exit, kind="mergesort")
|
|
142
|
+
s_time_enter = time_enter[order_enter]
|
|
143
|
+
s_time_exit = time_exit[order_exit]
|
|
144
|
+
|
|
145
|
+
t0 = uniq_times[0]
|
|
146
|
+
# everything larger is included
|
|
147
|
+
idx_enter = np.searchsorted(s_time_enter, t0, side="right")
|
|
148
|
+
# everything smaller is excluded
|
|
149
|
+
idx_exit = np.searchsorted(s_time_exit, t0, side="left")
|
|
150
|
+
|
|
151
|
+
total_counts[0] = idx_enter
|
|
152
|
+
# except people die on the day they enter
|
|
153
|
+
event_counts[0] = 0
|
|
154
|
+
|
|
155
|
+
for i in range(1, len(uniq_times)):
|
|
156
|
+
ti = uniq_times[i]
|
|
157
|
+
|
|
158
|
+
while idx_enter < n_samples and s_time_enter[idx_enter] < ti:
|
|
159
|
+
idx_enter += 1
|
|
160
|
+
|
|
161
|
+
while idx_exit < n_samples and s_time_exit[idx_exit] < ti:
|
|
162
|
+
idx_exit += 1
|
|
163
|
+
|
|
164
|
+
risk_set = np.setdiff1d(order_enter[:idx_enter], order_exit[:idx_exit], assume_unique=True)
|
|
165
|
+
total_counts[i] = len(risk_set)
|
|
166
|
+
|
|
167
|
+
count_event = 0
|
|
168
|
+
k = idx_exit
|
|
169
|
+
while k < n_samples and s_time_exit[k] == ti:
|
|
170
|
+
if event[order_exit[k]]:
|
|
171
|
+
count_event += 1
|
|
172
|
+
k += 1
|
|
173
|
+
event_counts[i] = count_event
|
|
174
|
+
|
|
175
|
+
return uniq_times, event_counts, total_counts
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _ci_logmlog(prob_survival, sigma_t, z):
|
|
179
|
+
"""Compute the pointwise log-minus-log transformed confidence intervals"""
|
|
180
|
+
eps = np.finfo(prob_survival.dtype).eps
|
|
181
|
+
log_p = np.zeros_like(prob_survival)
|
|
182
|
+
np.log(prob_survival, where=prob_survival > eps, out=log_p)
|
|
183
|
+
theta = np.zeros_like(prob_survival)
|
|
184
|
+
np.true_divide(sigma_t, log_p, where=log_p < -eps, out=theta)
|
|
185
|
+
theta = np.array([[-1], [1]]) * theta * z
|
|
186
|
+
ci = np.exp(np.exp(theta) * log_p)
|
|
187
|
+
ci[:, prob_survival <= eps] = 0.0
|
|
188
|
+
ci[:, 1.0 - prob_survival <= eps] = 1.0
|
|
189
|
+
return ci
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def _km_ci_estimator(prob_survival, ratio_var, conf_level, conf_type):
|
|
193
|
+
if conf_type not in {"log-log"}:
|
|
194
|
+
raise ValueError(f"conf_type must be None or a str among {{'log-log'}}, but was {conf_type!r}")
|
|
195
|
+
|
|
196
|
+
if not isinstance(conf_level, numbers.Real) or not np.isfinite(conf_level) or conf_level <= 0 or conf_level >= 1.0:
|
|
197
|
+
raise ValueError(f"conf_level must be a float in the range (0.0, 1.0), but was {conf_level!r}")
|
|
198
|
+
|
|
199
|
+
z = stats.norm.isf((1.0 - conf_level) / 2.0)
|
|
200
|
+
sigma = np.sqrt(np.cumsum(ratio_var))
|
|
201
|
+
ci = _ci_logmlog(prob_survival, sigma, z)
|
|
202
|
+
return ci
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def kaplan_meier_estimator(
    event,
    time_exit,
    time_enter=None,
    time_min=None,
    reverse=False,
    conf_level=0.95,
    conf_type=None,
):
    """Kaplan-Meier estimator of survival function.

    See [1]_ for further description.

    Parameters
    ----------
    event : array-like, shape = (n_samples,)
        Contains binary event indicators.

    time_exit : array-like, shape = (n_samples,)
        Contains event/censoring times.

    time_enter : array-like, shape = (n_samples,), optional
        Contains time when each individual entered the study for
        left truncated survival data.

    time_min : float, optional
        Compute estimator conditional on survival at least up to
        the specified time.

    reverse : bool, optional, default: False
        Whether to estimate the censoring distribution.
        When there are ties between times at which events are observed,
        then events come first and are subtracted from the denominator.
        Only available for right-censored data, i.e. `time_enter` must
        be None.

    conf_level : float, optional, default: 0.95
        The level for a two-sided confidence interval on the survival curves.

    conf_type : None or {'log-log'}, optional, default: None.
        The type of confidence intervals to estimate.
        If `None`, no confidence intervals are estimated.
        If "log-log", estimate confidence intervals using
        the log hazard or :math:`log(-log(S(t)))` as described in [2]_.

    Returns
    -------
    time : array, shape = (n_times,)
        Unique times.

    prob_survival : array, shape = (n_times,)
        Survival probability at each unique time point.
        If `time_enter` is provided, estimates are conditional probabilities.

    conf_int : array, shape = (2, n_times)
        Pointwise confidence interval of the Kaplan-Meier estimator
        at each unique time point.
        Only provided if `conf_type` is not None.

    Raises
    ------
    NotImplementedError
        If `conf_type` is not None and `reverse` is True.

    ValueError
        If `reverse` is True and `time_enter` is provided.

    Examples
    --------
    Creating a Kaplan-Meier curve:

    >>> x, y, conf_int = kaplan_meier_estimator(event, time, conf_type="log-log")
    >>> plt.step(x, y, where="post")
    >>> plt.fill_between(x, conf_int[0], conf_int[1], alpha=0.25, step="post")
    >>> plt.ylim(0, 1)
    >>> plt.show()

    See also
    --------
    sksurv.nonparametric.SurvivalFunctionEstimator
        Estimator API of the Kaplan-Meier estimator.

    References
    ----------
    .. [1] Kaplan, E. L. and Meier, P., "Nonparametric estimation from incomplete observations",
           Journal of The American Statistical Association, vol. 53, pp. 457-481, 1958.
    .. [2] Borgan Ø. and Liestøl K., "A Note on Confidence Intervals and Bands for the
           Survival Function Based on Transformations", Scandinavian Journal of
           Statistics. 1990;17(1):35–41.
    """
    event, time_enter, time_exit = check_y_survival(event, time_enter, time_exit, allow_all_censored=True)
    check_consistent_length(event, time_enter, time_exit)

    with_conf_int = conf_type is not None
    if with_conf_int and reverse:
        raise NotImplementedError("Confidence intervals of the censoring distribution is not implemented.")

    if time_enter is not None:
        if reverse:
            raise ValueError("The censoring distribution cannot be estimated from left truncated data")

        uniq_times, n_events, n_at_risk = _compute_counts_truncated(event, time_enter, time_exit)
    else:
        uniq_times, n_events, n_at_risk, n_censored = _compute_counts(event, time_exit)

        if reverse:
            # Events come first: remove them from the risk set, then
            # treat censoring as the outcome of interest.
            n_at_risk -= n_events
            n_events = n_censored

    n_times = uniq_times.shape[0]
    has_events = n_events != 0

    # account for 0/0 = nan
    ratio = np.divide(
        n_events,
        n_at_risk,
        out=np.zeros(n_times, dtype=float),
        where=has_events,
    )
    values = 1.0 - ratio

    if with_conf_int:
        ratio_var = np.divide(
            n_events,
            n_at_risk * (n_at_risk - n_events),
            out=np.zeros(n_times, dtype=float),
            where=has_events & (n_at_risk != n_events),
        )

    if time_min is not None:
        # Condition on survival up to time_min by dropping earlier time points.
        mask = uniq_times >= time_min
        uniq_times = np.compress(mask, uniq_times)
        values = np.compress(mask, values)
        if with_conf_int:
            ratio_var = np.compress(mask, ratio_var)

    prob_survival = np.cumprod(values)

    if not with_conf_int:
        return uniq_times, prob_survival

    ci = _km_ci_estimator(prob_survival, ratio_var, conf_level, conf_type)
    return uniq_times, prob_survival, ci
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def nelson_aalen_estimator(event, time):
    """Nelson-Aalen estimator of cumulative hazard function.

    See [1]_, [2]_ for further description.

    Parameters
    ----------
    event : array-like, shape = (n_samples,)
        Contains binary event indicators.

    time : array-like, shape = (n_samples,)
        Contains event/censoring times.

    Returns
    -------
    time : array, shape = (n_times,)
        Unique times.

    cum_hazard : array, shape = (n_times,)
        Cumulative hazard at each unique time point.

    References
    ----------
    .. [1] Nelson, W., "Theory and applications of hazard plotting for censored failure data",
           Technometrics, vol. 14, pp. 945-965, 1972.

    .. [2] Aalen, O. O., "Nonparametric inference for a family of counting processes",
           Annals of Statistics, vol. 6, pp. 701–726, 1978.
    """
    event, time = check_y_survival(event, time)
    check_consistent_length(event, time)

    uniq_times, n_events, n_at_risk, _ = _compute_counts(event, time)

    # Hazard increment at each time point, accumulated over time.
    hazard_increments = n_events / n_at_risk
    cum_hazard = np.cumsum(hazard_increments)
    return uniq_times, cum_hazard
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def ipc_weights(event, time):
    """Compute inverse probability of censoring weights

    Parameters
    ----------
    event : array, shape = (n_samples,)
        Boolean event indicator.

    time : array, shape = (n_samples,)
        Time when a subject experienced an event or was censored.

    Returns
    -------
    weights : array, shape = (n_samples,)
        Inverse probability of censoring weights.
        Censored samples receive a weight of zero.

    Raises
    ------
    ValueError
        If the estimated censoring survival function is not strictly
        positive at the time of one or more events.

    See also
    --------
    CensoringDistributionEstimator
        An estimator interface for estimating inverse probability
        of censoring weights for unseen time points.
    """
    if event.all():
        # Without censoring, every sample has unit weight.
        return np.ones(time.shape[0])

    unique_time, p = kaplan_meier_estimator(event, time, reverse=True)

    idx = np.searchsorted(unique_time, time[event])
    Ghat = p[idx]

    # Raise explicitly rather than assert: assertions are removed under
    # ``python -O``, which would silently yield infinite weights.
    if not (Ghat > 0).all():
        raise ValueError("censoring survival function is zero at one or more time points")

    weights = np.zeros(time.shape[0])
    weights[event] = 1.0 / Ghat

    return weights
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
class SurvivalFunctionEstimator(BaseEstimator):
    """Kaplan–Meier estimate of the survival function.

    Parameters
    ----------
    conf_level : float, optional, default: 0.95
        The level for a two-sided confidence interval on the survival curves.

    conf_type : None or {'log-log'}, optional, default: None.
        The type of confidence intervals to estimate.
        If `None`, no confidence intervals are estimated.
        If "log-log", estimate confidence intervals using
        the log hazard or :math:`log(-log(S(t)))`.

    See also
    --------
    sksurv.nonparametric.kaplan_meier_estimator
        Functional API of the Kaplan-Meier estimator.
    """

    _parameter_constraints = {
        "conf_level": [Interval(numbers.Real, 0.0, 1.0, closed="neither")],
        "conf_type": [None, StrOptions({"log-log"})],
    }

    def __init__(self, conf_level=0.95, conf_type=None):
        self.conf_level = conf_level
        self.conf_type = conf_type

    def fit(self, y):
        """Estimate survival distribution from training data.

        Parameters
        ----------
        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        self
        """
        self._validate_params()
        event, time = check_y_survival(y, allow_all_censored=True)

        values = kaplan_meier_estimator(event, time, conf_level=self.conf_level, conf_type=self.conf_type)
        if self.conf_type is None:
            unique_time, prob = values
            # Drop intervals left over from an earlier fit so they cannot
            # be returned alongside the newly estimated probabilities.
            if hasattr(self, "conf_int_"):
                del self.conf_int_
        else:
            unique_time, prob, conf_int = values
            # Prepend the interval [1, 1] belonging to the time point -inf.
            self.conf_int_ = np.column_stack((np.ones((2, 1)), conf_int))

        # Prepend time -inf with survival probability 1.
        self.unique_time_ = np.r_[-np.inf, unique_time]
        self.prob_ = np.r_[1.0, prob]

        return self

    def predict_proba(self, time, return_conf_int=False):
        """Return probability of an event after given time point.

        :math:`\\hat{S}(t) = P(T > t)`

        Parameters
        ----------
        time : array, shape = (n_samples,)
            Time to estimate probability at.

        return_conf_int : bool, optional, default: False
            Whether to return the pointwise confidence interval
            of the survival function.
            Only available if :meth:`fit()` has been called
            with the `conf_type` parameter set.

        Returns
        -------
        prob : array, shape = (n_samples,)
            Probability of an event at the passed time points.

        conf_int : array, shape = (2, n_samples)
            Pointwise confidence interval at the passed time points.
            Entries for time points beyond the largest observed
            time point are NaN.
            Only provided if `return_conf_int` is True.

        Raises
        ------
        ValueError
            If `return_conf_int` is True but no confidence intervals were
            estimated during fit, or if the estimate at the largest observed
            time point is non-zero and a passed time point exceeds it.
        """
        check_is_fitted(self, "unique_time_")
        if return_conf_int and not hasattr(self, "conf_int_"):
            raise ValueError(
                "If return_conf_int is True, SurvivalFunctionEstimator must be fitted with conf_int != None"
            )

        time = check_array(time, ensure_2d=False, estimator=self, input_name="time")

        # K-M is undefined if estimate at last time point is non-zero
        extends = time > self.unique_time_[-1]
        if self.prob_[-1] > 0 and extends.any():
            raise ValueError(f"time must be smaller than largest observed time point: {self.unique_time_[-1]}")

        # beyond last time point is zero probability
        Shat = np.empty(time.shape, dtype=float)
        Shat[extends] = 0.0

        valid = ~extends
        # Keep `time` intact: the confidence-interval array below must be
        # sized by the number of queried time points, not only the valid ones.
        time_valid = time[valid]
        idx = np.searchsorted(self.unique_time_, time_valid)
        # for non-exact matches, we need to shift the index to left
        eps = np.finfo(self.unique_time_.dtype).eps
        exact = np.absolute(self.unique_time_[idx] - time_valid) < eps
        idx[~exact] -= 1
        Shat[valid] = self.prob_[idx]

        if not return_conf_int:
            return Shat

        ci = np.empty((2, extends.size), dtype=float)
        ci[:, extends] = np.nan
        ci[:, valid] = self.conf_int_[:, idx]
        return Shat, ci
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
class CensoringDistributionEstimator(SurvivalFunctionEstimator):
    """Kaplan–Meier estimator for the censoring distribution."""

    def fit(self, y):
        """Estimate censoring distribution from training data.

        Parameters
        ----------
        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        self
        """
        event, time = check_y_survival(y)
        if event.all():
            # No censoring observed: the censoring survival function is 1 everywhere.
            unique_time = np.unique(time)
            # predict_proba calls np.finfo on this array's dtype, which only
            # accepts floating-point types; promote e.g. integer times.
            if not np.issubdtype(unique_time.dtype, np.floating):
                unique_time = unique_time.astype(float)
            self.unique_time_ = unique_time
            self.prob_ = np.ones(self.unique_time_.shape[0])
        else:
            unique_time, prob = kaplan_meier_estimator(event, time, reverse=True)
            # Prepend time -inf with probability 1.
            self.unique_time_ = np.r_[-np.inf, unique_time]
            self.prob_ = np.r_[1.0, prob]

        return self

    def predict_ipcw(self, y):
        """Return inverse probability of censoring weights at given time points.

        :math:`\\omega_i = \\delta_i / \\hat{G}(y_i)`

        Parameters
        ----------
        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        ipcw : array, shape = (n_samples,)
            Inverse probability of censoring weights.
            Censored samples receive a weight of zero.

        Raises
        ------
        ValueError
            If the estimated censoring survival function is zero at
            the time of one or more events.
        """
        event, time = check_y_survival(y)
        Ghat = self.predict_proba(time[event])

        if (Ghat == 0.0).any():
            raise ValueError("censoring survival function is zero at one or more time points")

        weights = np.zeros(time.shape[0])
        weights[event] = 1.0 / Ghat

        return weights
|