scikit-survival 0.23.1__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. scikit_survival-0.23.1.dist-info/COPYING +674 -0
  2. scikit_survival-0.23.1.dist-info/METADATA +888 -0
  3. scikit_survival-0.23.1.dist-info/RECORD +55 -0
  4. scikit_survival-0.23.1.dist-info/WHEEL +5 -0
  5. scikit_survival-0.23.1.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +138 -0
  7. sksurv/base.py +103 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
  10. sksurv/column.py +201 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +10 -0
  13. sksurv/datasets/base.py +436 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  17. sksurv/datasets/data/flchain.arff +7887 -0
  18. sksurv/datasets/data/veteran.arff +148 -0
  19. sksurv/datasets/data/whas500.arff +520 -0
  20. sksurv/ensemble/__init__.py +2 -0
  21. sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
  22. sksurv/ensemble/boosting.py +1610 -0
  23. sksurv/ensemble/forest.py +947 -0
  24. sksurv/ensemble/survival_loss.py +151 -0
  25. sksurv/exceptions.py +18 -0
  26. sksurv/functions.py +114 -0
  27. sksurv/io/__init__.py +2 -0
  28. sksurv/io/arffread.py +58 -0
  29. sksurv/io/arffwrite.py +145 -0
  30. sksurv/kernels/__init__.py +1 -0
  31. sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
  32. sksurv/kernels/clinical.py +328 -0
  33. sksurv/linear_model/__init__.py +3 -0
  34. sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
  35. sksurv/linear_model/aft.py +205 -0
  36. sksurv/linear_model/coxnet.py +543 -0
  37. sksurv/linear_model/coxph.py +618 -0
  38. sksurv/meta/__init__.py +4 -0
  39. sksurv/meta/base.py +35 -0
  40. sksurv/meta/ensemble_selection.py +642 -0
  41. sksurv/meta/stacking.py +349 -0
  42. sksurv/metrics.py +996 -0
  43. sksurv/nonparametric.py +588 -0
  44. sksurv/preprocessing.py +155 -0
  45. sksurv/svm/__init__.py +11 -0
  46. sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
  47. sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
  48. sksurv/svm/minlip.py +606 -0
  49. sksurv/svm/naive_survival_svm.py +221 -0
  50. sksurv/svm/survival_svm.py +1228 -0
  51. sksurv/testing.py +108 -0
  52. sksurv/tree/__init__.py +1 -0
  53. sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
  54. sksurv/tree/tree.py +703 -0
  55. sksurv/util.py +333 -0
@@ -0,0 +1,436 @@
1
+ import warnings
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from pandas.api.types import CategoricalDtype
6
+
7
+ from ..column import categorical_to_numeric, standardize
8
+ from ..io import loadarff
9
+ from ..util import safe_concat
10
+
11
# Names exported by a star import of this module.
__all__ = [
    "get_x_y",
    "load_arff_files_standardized",
    "load_aids",
    "load_breast_cancer",
    "load_flchain",
    "load_gbsg2",
    "load_whas500",
    "load_veterans_lung_cancer",
]
21
+
22
+
23
def _get_data_path(name):
    """Return a handle to a data file bundled with this package.

    Parameters
    ----------
    name : str
        File name inside this package's ``data`` directory.

    Returns
    -------
    importlib.resources.abc.Traversable
        The resource ``<package>/data/<name>``, located via
        :func:`importlib.resources.files`.
    """
    import importlib.resources as resources

    data_dir = resources.files(__package__) / "data"
    return data_dir / name
27
+
28
+
29
def _get_x_y_survival(dataset, col_event, col_time, val_outcome):
    """Separate survival labels from the feature columns of ``dataset``.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Data containing features and, optionally, the two label columns.
    col_event : str or None
        Column holding the event indicator.
    col_time : str or None
        Column holding the survival/censoring time.
    val_outcome : any
        Value in ``col_event`` that marks an observed event.

    Returns
    -------
    x_frame : pandas.DataFrame
        ``dataset`` without the two label columns, or ``dataset`` itself
        (not a copy) when either column name is ``None``.
    y : numpy structured array or None
        Array with a boolean field named ``col_event`` and a float64 field
        named ``col_time``; ``None`` when either column name is ``None``.
    """
    if col_event is None or col_time is None:
        # No labels requested: hand back the input frame untouched.
        return dataset, None

    label_dtype = [(col_event, bool), (col_time, np.float64)]
    y = np.empty(len(dataset), dtype=label_dtype)
    y[col_event] = (dataset[col_event] == val_outcome).values
    y[col_time] = dataset[col_time].values

    features = dataset.drop([col_event, col_time], axis=1)
    return features, y
41
+
42
+
43
def _get_x_y_other(dataset, col_label):
    """Separate non-survival label column(s) from ``dataset``.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Data containing features and, optionally, the label column(s).
    col_label : str, list of str, or None
        Label column name(s). If ``None``, no labels are extracted.

    Returns
    -------
    x_frame : pandas.DataFrame
        ``dataset`` without the label column(s), or ``dataset`` itself
        when ``col_label`` is ``None``.
    y : pandas.Series, pandas.DataFrame or None
        The selected label column(s), or ``None``.
    """
    if col_label is not None:
        labels = dataset.loc[:, col_label]
        features = dataset.drop(col_label, axis=1)
        return features, labels
    return dataset, None
52
+
53
+
54
def get_x_y(data_frame, attr_labels, pos_label=None, survival=True):
    """Split data frame into features and labels.

    Parameters
    ----------
    data_frame : pandas.DataFrame, shape = (n_samples, n_columns)
        A data frame.

    attr_labels : sequence of str or None
        A list of one or more columns that are considered the label.
        If `survival` is `True`, then attr_labels has two elements:
        1) the name of the column denoting the event indicator, and
        2) the name of the column denoting the survival time.
        If the sequence contains `None`, then labels are not retrieved
        and only a data frame with features is returned.

    pos_label : any, optional
        Which value of the event indicator column denotes that a
        patient experienced an event. Required if `survival` is `True`
        and labels are retrieved; ignored otherwise.

    survival : bool, optional, default: True
        Whether to return `y` that can be used for survival analysis.

    Returns
    -------
    X : pandas.DataFrame, shape = (n_samples, n_columns - len(attr_labels))
        Data frame containing features.

    y : None, structured array, pandas.Series or pandas.DataFrame
        Supervised information.
        If `survival` is `True`, a structured array with two fields named
        after the label columns: a boolean event indicator and a float
        survival time.
        Otherwise, the column(s) of `data_frame` selected by `attr_labels`.
        If labels are not retrieved (see `attr_labels`), y is `None`.

    Raises
    ------
    ValueError
        If `survival` is `True` and `attr_labels` does not have two elements,
        or if labels are to be retrieved but `pos_label` is `None`.
    """
    if not survival:
        return _get_x_y_other(data_frame, attr_labels)

    if len(attr_labels) != 2:
        raise ValueError(f"expected sequence of length two for attr_labels, but got {len(attr_labels)}")
    col_event, col_time = attr_labels[0], attr_labels[1]

    # pos_label is only needed to build the event indicator. When either label
    # is None no labels are extracted, so it must not be demanded then.
    wants_labels = col_event is not None and col_time is not None
    if wants_labels and pos_label is None:
        raise ValueError("pos_label needs to be specified if survival=True")
    return _get_x_y_survival(data_frame, col_event, col_time, pos_label)
97
+
98
+
99
def _loadarff_with_index(filename):
    """Load an ARFF file and promote its ``index`` column, if any, to the row index.

    Parameters
    ----------
    filename : str or path-like
        File passed on to ``loadarff``.

    Returns
    -------
    pandas.DataFrame
        The loaded data. If it has a column named ``"index"``, that column is
        removed from the columns and used as the row index (categorical values
        are converted to ``object`` first).
    """
    dataset = loadarff(filename)
    if "index" not in dataset.columns:
        return dataset

    index_col = dataset["index"]
    if isinstance(index_col.dtype, CategoricalDtype):
        # concatenating categorical index may raise TypeError
        # see https://github.com/pandas-dev/pandas/issues/14586
        dataset["index"] = index_col.astype(object)
    return dataset.set_index("index")
108
+
109
+
110
def load_arff_files_standardized(
    path_training,
    attr_labels,
    pos_label=None,
    path_testing=None,
    survival=True,
    standardize_numeric=True,
    to_numeric=True,
):
    """Load dataset in ARFF format.

    Parameters
    ----------
    path_training : str
        Path to ARFF file containing data.

    attr_labels : sequence of str
        Names of attributes denoting dependent variables.
        If ``survival`` is set, it must be a sequence with two items:
        the name of the event indicator and the name of the survival/censoring time.

    pos_label : any type, optional
        Value corresponding to an event in survival analysis.
        Only considered if ``survival`` is ``True``.

    path_testing : str, optional
        Path to ARFF file containing hold-out data. Only columns that are available in both
        training and testing are considered (excluding dependent variables).
        If ``standardize_numeric`` is set, data is normalized by considering both training
        and testing data.

    survival : bool, optional, default: True
        Whether the dependent variables denote event indicator and survival/censoring time.

    standardize_numeric : bool, optional, default: True
        Whether to standardize data to zero mean and unit variance.
        See :func:`sksurv.column.standardize`.

    to_numeric : bool, optional, default: True
        Whether to convert categorical variables to numeric values.
        See :func:`sksurv.column.categorical_to_numeric`.

    Returns
    -------
    x_train : pandas.DataFrame, shape = (n_train, n_features)
        Training data.

    y_train : structured array, pandas.Series or pandas.DataFrame
        Dependent variables of training data, as returned by :func:`get_x_y`
        (a structured array if ``survival`` is ``True``).

    x_test : None or pandas.DataFrame, shape = (n_test, n_features)
        Testing data if `path_testing` was provided.

    y_test : None, structured array, pandas.Series or pandas.DataFrame
        Dependent variables of testing data if `path_testing` was provided
        and the testing data contains every column in `attr_labels`.

    Raises
    ------
    ValueError
        If training and testing data have no feature columns in common.

    Warns
    -----
    UserWarning
        If training and testing data have different feature columns; only
        their intersection is kept.
    """
    dataset = _loadarff_with_index(path_training)

    x_train, y_train = get_x_y(dataset, attr_labels, pos_label, survival)

    if path_testing is not None:
        x_test, y_test = _load_arff_testing(path_testing, attr_labels, pos_label, survival)

        if len(x_train.columns.symmetric_difference(x_test.columns)) > 0:
            warnings.warn("Restricting columns to intersection between training and testing data", stacklevel=2)

        cols = x_train.columns.intersection(x_test.columns)
        if len(cols) == 0:
            raise ValueError("columns of training and test data do not intersect")

        x_train = x_train.loc[:, cols]
        x_test = x_test.loc[:, cols]

        # Transform training and testing rows together (normalization considers
        # both, as documented above), then split back by row position.
        x = safe_concat((x_train, x_test), axis=0)
        if standardize_numeric:
            x = standardize(x)
        if to_numeric:
            x = categorical_to_numeric(x)

        n_train = x_train.shape[0]
        x_train = x.iloc[:n_train, :]
        x_test = x.iloc[n_train:, :]
    else:
        if standardize_numeric:
            x_train = standardize(x_train)
        if to_numeric:
            x_train = categorical_to_numeric(x_train)

        x_test = None
        y_test = None

    return x_train, y_train, x_test, y_test
202
+
203
+
204
def _load_arff_testing(path_testing, attr_labels, pos_label, survival):
    """Load hold-out data and split it into features and labels via ``get_x_y``.

    If not every column in ``attr_labels`` is present in the testing data,
    labels are not requested: ``[None, None]`` is passed for survival data
    and ``None`` otherwise.

    Returns
    -------
    x_test, y_test
        Whatever ``get_x_y`` returns for the testing data.
    """
    test_dataset = _loadarff_with_index(path_testing)

    missing_labels = not pd.Index(attr_labels).isin(test_dataset.columns).all()
    if missing_labels:
        # Express "no labels" in the form get_x_y expects for each mode.
        attr_labels = [None, None] if survival else None
    return get_x_y(test_dataset, attr_labels, pos_label, survival)
214
+
215
+
216
def load_whas500():
    """Load and return the Worcester Heart Attack Study dataset

    The dataset consists of 500 samples with 14 features each.
    The endpoint is death, observed for 215 patients (43.0%).

    See [1]_, [2]_ for further description.

    Returns
    -------
    x : pandas.DataFrame
        The measurements for each patient.

    y : structured array with 2 fields
        *fstat*: boolean indicating whether the endpoint has been reached
        or the event time is right censored.

        *lenfol*: total length of follow-up (days from hospital admission date
        to date of last follow-up)

    References
    ----------
    .. [1] https://web.archive.org/web/20170114043458/http://www.umass.edu/statdata/statdata/data/

    .. [2] Hosmer, D., Lemeshow, S., May, S.:
        "Applied Survival Analysis: Regression Modeling of Time to Event Data."
        John Wiley & Sons, Inc. (2008)
    """
    data = loadarff(_get_data_path("whas500.arff"))
    labels = ["fstat", "lenfol"]
    return get_x_y(data, attr_labels=labels, pos_label="1")
246
+
247
+
248
def load_gbsg2():
    """Load and return the German Breast Cancer Study Group 2 dataset

    The dataset has 686 samples and 8 features.
    The endpoint is recurrence free survival, which occurred for 299 patients (43.6%).

    See [1]_, [2]_ for further description.

    Returns
    -------
    x : pandas.DataFrame
        The measurements for each patient.

    y : structured array with 2 fields
        *cens*: boolean indicating whether the endpoint has been reached
        or the event time is right censored.

        *time*: total length of follow-up

    References
    ----------
    .. [1] http://ascopubs.org/doi/abs/10.1200/jco.1994.12.10.2086

    .. [2] Schumacher, M., Bastert, G., Bojar, H., et al.
        "Randomized 2 × 2 trial evaluating hormonal treatment and the duration of chemotherapy
        in node-positive breast cancer patients."
        Journal of Clinical Oncology 12, 2086–2093. (1994)
    """
    fn = _get_data_path("GBSG2.arff")
    return get_x_y(loadarff(fn), attr_labels=["cens", "time"], pos_label="1")
278
+
279
+
280
def load_veterans_lung_cancer():
    """Load and return data from the Veterans' Administration
    Lung Cancer Trial

    The dataset consists of 137 samples with 6 features each.
    The endpoint is death, observed for 128 patients (93.4%).

    See [1]_ for further description.

    Returns
    -------
    x : pandas.DataFrame
        The measurements for each patient.

    y : structured array with 2 fields
        *Status*: boolean indicating whether the endpoint has been reached
        or the event time is right censored.

        *Survival_in_days*: total length of follow-up

    References
    ----------
    .. [1] Kalbfleisch, J.D., Prentice, R.L.:
        "The Statistical Analysis of Failure Time Data." John Wiley & Sons, Inc. (2002)
    """
    data = loadarff(_get_data_path("veteran.arff"))
    labels = ["Status", "Survival_in_days"]
    return get_x_y(data, attr_labels=labels, pos_label="dead")
307
+
308
+
309
def load_aids(endpoint="aids"):
    """Load and return the AIDS Clinical Trial dataset

    The dataset consists of 1,151 samples with 11 features each.
    Two endpoints are available:

    1. AIDS defining event, which occurred for 96 patients (8.3%)
    2. Death, which occurred for 26 patients (2.3%)

    See [1]_, [2]_ for further description.

    Parameters
    ----------
    endpoint : aids|death
        The endpoint. The label columns of the other endpoint are
        removed from the returned features.

    Returns
    -------
    x : pandas.DataFrame
        The measurements for each patient.

    y : structured array with 2 fields
        *censor*: boolean indicating whether the endpoint has been reached
        or the event time is right censored.

        *time*: total length of follow-up

        If ``endpoint`` is death, the fields are named *censor_d* and *time_d*.

    Raises
    ------
    ValueError
        If ``endpoint`` is neither ``"aids"`` nor ``"death"``.

    References
    ----------
    .. [1] https://web.archive.org/web/20170114043458/http://www.umass.edu/statdata/statdata/data/

    .. [2] Hosmer, D., Lemeshow, S., May, S.:
        "Applied Survival Analysis: Regression Modeling of Time to Event Data."
        John Wiley & Sons, Inc. (2008)
    """
    label_sets = {
        "aids": ["censor", "time"],
        "death": ["censor_d", "time_d"],
    }
    # Tuple membership compares with ==, exactly like the validation it replaces.
    if endpoint not in ("aids", "death"):
        raise ValueError("endpoint must be 'aids' or 'death'")
    other_endpoint = "death" if endpoint == "aids" else "aids"
    attr_labels = label_sets[endpoint]
    drop_columns = label_sets[other_endpoint]

    dataset = loadarff(_get_data_path("actg320.arff"))
    x, y = get_x_y(dataset, attr_labels=attr_labels, pos_label="1")
    # The competing endpoint's columns must not remain as features.
    x = x.drop(drop_columns, axis=1)
    return x, y
361
+
362
+
363
def load_breast_cancer():
    """Load and return the breast cancer dataset

    The dataset has 198 samples and 80 features.
    The endpoint is the presence of distant metastases, which occurred for 51 patients (25.8%).

    See [1]_, [2]_ for further description.

    Returns
    -------
    x : pandas.DataFrame
        The measurements for each patient.

    y : structured array with 2 fields
        *e.tdm*: boolean indicating whether the endpoint has been reached
        or the event time is right censored.

        *t.tdm*: time to distant metastasis (days)

    References
    ----------
    .. [1] https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE7390

    .. [2] Desmedt, C., Piette, F., Loi, S., et al.:
        "Strong Time Dependence of the 76-Gene Prognostic Signature for Node-Negative Breast Cancer
        Patients in the TRANSBIG Multicenter Independent Validation Series."
        Clin. Cancer Res. 13(11), 3207–14 (2007)
    """
    fn = _get_data_path("breast_cancer_GSE7390-metastasis.arff")
    return get_x_y(loadarff(fn), attr_labels=["e.tdm", "t.tdm"], pos_label="1")
393
+
394
+
395
def load_flchain():
    """Load and return assay of serum free light chain for 7874 subjects.

    The dataset has 7874 samples and 9 features:

        1. age: age in years
        2. sex: F=female, M=male
        3. sample.yr: the calendar year in which a blood sample was obtained
        4. kappa: serum free light chain, kappa portion
        5. lambda: serum free light chain, lambda portion
        6. flc.grp: the serum free light chain group for the subject, as used in the original analysis
        7. creatinine: serum creatinine
        8. mgus: whether the subject had been diagnosed with monoclonal gammopathy
           of undetermined significance (MGUS)
        9. chapter: for those who died, a grouping of their primary cause of death by chapter headings
           of the International Classification of Diseases (ICD-9)

    The endpoint is death, which occurred for 2169 subjects (27.5%).

    See [1]_, [2]_ for further description.

    Returns
    -------
    x : pandas.DataFrame
        The measurements for each patient.

    y : structured array with 2 fields
        *death*: boolean indicating whether the subject died
        or the event time is right censored.

        *futime*: total length of follow-up or time of death.

    References
    ----------
    .. [1] https://doi.org/10.1016/j.mayocp.2012.03.009

    .. [2] Dispenzieri, A., Katzmann, J., Kyle, R., Larson, D., Therneau, T., Colby, C., Clark, R.,
        Mead, G., Kumar, S., Melton III, LJ. and Rajkumar, SV.
        Use of nonclonal serum immunoglobulin free light chains to predict overall survival in
        the general population, Mayo Clinic Proceedings 87:512-523. (2012)
    """
    fn = _get_data_path("flchain.arff")
    return get_x_y(loadarff(fn), attr_labels=["death", "futime"], pos_label="dead")