scikit-survival 0.24.1__cp313-cp313-macosx_11_0_arm64.whl → 0.26.0__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. scikit_survival-0.26.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.26.0.dist-info/RECORD +58 -0
  3. {scikit_survival-0.24.1.dist-info → scikit_survival-0.26.0.dist-info}/WHEEL +1 -1
  4. sksurv/__init__.py +51 -6
  5. sksurv/base.py +12 -2
  6. sksurv/bintrees/_binarytrees.cpython-313-darwin.so +0 -0
  7. sksurv/column.py +38 -35
  8. sksurv/compare.py +23 -23
  9. sksurv/datasets/base.py +52 -27
  10. sksurv/docstrings.py +99 -0
  11. sksurv/ensemble/_coxph_loss.cpython-313-darwin.so +0 -0
  12. sksurv/ensemble/boosting.py +116 -168
  13. sksurv/ensemble/forest.py +94 -151
  14. sksurv/functions.py +29 -29
  15. sksurv/io/arffread.py +37 -4
  16. sksurv/io/arffwrite.py +41 -5
  17. sksurv/kernels/_clinical_kernel.cpython-313-darwin.so +0 -0
  18. sksurv/kernels/clinical.py +36 -16
  19. sksurv/linear_model/_coxnet.cpython-313-darwin.so +0 -0
  20. sksurv/linear_model/aft.py +14 -11
  21. sksurv/linear_model/coxnet.py +138 -89
  22. sksurv/linear_model/coxph.py +102 -83
  23. sksurv/meta/ensemble_selection.py +91 -9
  24. sksurv/meta/stacking.py +47 -26
  25. sksurv/metrics.py +257 -224
  26. sksurv/nonparametric.py +150 -81
  27. sksurv/preprocessing.py +74 -34
  28. sksurv/svm/_minlip.cpython-313-darwin.so +0 -0
  29. sksurv/svm/_prsvm.cpython-313-darwin.so +0 -0
  30. sksurv/svm/minlip.py +171 -85
  31. sksurv/svm/naive_survival_svm.py +63 -34
  32. sksurv/svm/survival_svm.py +103 -103
  33. sksurv/testing.py +47 -0
  34. sksurv/tree/_criterion.cpython-313-darwin.so +0 -0
  35. sksurv/tree/tree.py +170 -84
  36. sksurv/util.py +85 -30
  37. scikit_survival-0.24.1.dist-info/METADATA +0 -889
  38. scikit_survival-0.24.1.dist-info/RECORD +0 -57
  39. {scikit_survival-0.24.1.dist-info → scikit_survival-0.26.0.dist-info}/licenses/COPYING +0 -0
  40. {scikit_survival-0.24.1.dist-info → scikit_survival-0.26.0.dist-info}/top_level.txt +0 -0
sksurv/datasets/base.py CHANGED
@@ -36,10 +36,10 @@ def _get_x_y_survival(dataset, col_event, col_time, val_outcome, competing_risks
36
36
  event_type = np.int64 if competing_risks else bool
37
37
  y = np.empty(dtype=[(col_event, event_type), (col_time, np.float64)], shape=dataset.shape[0])
38
38
  if competing_risks:
39
- y[col_event] = dataset[col_event].values
39
+ y[col_event] = dataset[col_event].to_numpy()
40
40
  else:
41
- y[col_event] = (dataset[col_event] == val_outcome).values
42
- y[col_time] = dataset[col_time].values
41
+ y[col_event] = (dataset[col_event] == val_outcome).to_numpy()
42
+ y[col_time] = dataset[col_time].to_numpy()
43
43
 
44
44
  x_frame = dataset.drop([col_event, col_time], axis=1)
45
45
 
@@ -82,18 +82,23 @@ def get_x_y(data_frame, attr_labels, pos_label=None, survival=True, competing_ri
82
82
  Whether to return `y` that can be used for survival analysis.
83
83
 
84
84
  competing_risks : bool, optional, default: False
85
- Whether `y` refers to competing risks situation. Only used if `survival` is True
85
+ Whether `y` refers to competing risks situation. Only used if `survival` is `True`.
86
86
 
87
87
  Returns
88
88
  -------
89
89
  X : pandas.DataFrame, shape = (n_samples, n_columns - len(attr_labels))
90
90
  Data frame containing features.
91
91
 
92
- y : None or pandas.DataFrame, shape = (n_samples, len(attr_labels))
93
- Data frame containing columns with supervised information.
94
- If `survival` was `True`, then the column denoting the event
95
- indicator will be boolean and survival times will be float.
96
- If `attr_labels` contains `None`, y is set to `None`.
92
+ y : structured array, shape = (n_samples,), or pandas.DataFrame, shape = (n_samples, len(attr_labels)), or None
93
+ If `survival` is `True`, a structured array with two fields.
94
+ The first field is a boolean where ``True`` indicates an event and ``False``
95
+ indicates right-censoring. The second field is a float with the time of
96
+ event or time of censoring.
97
+
98
+ If `survival` is `False` and `attr_labels` not `None`, a :class:`pandas.DataFrame`
99
+ with columns specified by `attr_labels`.
100
+
101
+ If `survival` is `False` and `attr_labels` is `None`, `y` is set to `None`.
97
102
  """
98
103
  if survival:
99
104
  if len(attr_labels) != 2:
@@ -111,7 +116,7 @@ def _loadarff_with_index(filename):
111
116
  if isinstance(dataset["index"].dtype, CategoricalDtype):
112
117
  # concatenating categorical index may raise TypeError
113
118
  # see https://github.com/pandas-dev/pandas/issues/14586
114
- dataset["index"] = dataset["index"].astype(object)
119
+ dataset = dataset.astype({"index": "str"})
115
120
  dataset.set_index("index", inplace=True)
116
121
  return dataset
117
122
 
@@ -154,7 +159,7 @@ def load_arff_files_standardized(
154
159
  Whether to standardize data to zero mean and unit variance.
155
160
  See :func:`sksurv.column.standardize`.
156
161
 
157
- to_numeric : boo, optional, default: True
162
+ to_numeric : bool, optional, default: True
158
163
  Whether to convert categorical variables to numeric values.
159
164
  See :func:`sksurv.column.categorical_to_numeric`.
160
165
 
@@ -163,14 +168,34 @@ def load_arff_files_standardized(
163
168
  x_train : pandas.DataFrame, shape = (n_train, n_features)
164
169
  Training data.
165
170
 
166
- y_train : pandas.DataFrame, shape = (n_train, n_labels)
171
+ y_train : structured array, shape = (n_train,), or pandas.DataFrame, shape = (n_train, len(attr_labels))
167
172
  Dependent variables of training data.
168
173
 
169
- x_test : None or pandas.DataFrame, shape = (n_train, n_features)
174
+ If `survival` is `True`, a structured array with two fields.
175
+ The first field is a boolean where ``True`` indicates an event and ``False``
176
+ indicates right-censoring. The second field is a float with the time of
177
+ event or time of censoring.
178
+
179
+ If `survival` is `False` and `attr_labels` not `None`, a :class:`pandas.DataFrame`
180
+ with columns specified by `attr_labels`.
181
+
182
+ If `survival` is `False` and `attr_labels` is `None`, `y_train` is set to `None`.
183
+
184
+ x_test : None or pandas.DataFrame, shape = (n_test, n_features)
170
185
  Testing data if `path_testing` was provided.
171
186
 
172
- y_test : None or pandas.DataFrame, shape = (n_train, n_labels)
187
+ y_test : None or structured array, shape = (n_test,)
173
188
  Dependent variables of testing data if `path_testing` was provided.
189
+
190
+ If `survival` is `True`, a structured array with two fields.
191
+ The first field is a boolean where ``True`` indicates an event and ``False``
192
+ indicates right-censoring. The second field is a float with the time of
193
+ event or time of censoring.
194
+
195
+ If `survival` is `False` and `attr_labels` not `None`, a :class:`pandas.DataFrame`
196
+ with columns specified by `attr_labels`.
197
+
198
+ If `survival` is `False` and `attr_labels` is `None`, `y_test` is set to `None`.
174
199
  """
175
200
  dataset = _loadarff_with_index(path_training)
176
201
 
@@ -237,7 +262,7 @@ def load_whas500():
237
262
 
238
263
  y : structured array with 2 fields
239
264
  *fstat*: boolean indicating whether the endpoint has been reached
240
- or the event time is right censored.
265
+ or the event time is right-censored.
241
266
 
242
267
  *lenfol*: total length of follow-up (days from hospital admission date
243
268
  to date of last follow-up)
@@ -269,7 +294,7 @@ def load_gbsg2():
269
294
 
270
295
  y : structured array with 2 fields
271
296
  *cens*: boolean indicating whether the endpoint has been reached
272
- or the event time is right censored.
297
+ or the event time is right-censored.
273
298
 
274
299
  *time*: total length of follow-up
275
300
 
@@ -302,7 +327,7 @@ def load_veterans_lung_cancer():
302
327
 
303
328
  y : structured array with 2 fields
304
329
  *Status*: boolean indicating whether the endpoint has been reached
305
- or the event time is right censored.
330
+ or the event time is right-censored.
306
331
 
307
332
  *Survival_in_days*: total length of follow-up
308
333
 
@@ -328,8 +353,8 @@ def load_aids(endpoint="aids"):
328
353
 
329
354
  Parameters
330
355
  ----------
331
- endpoint : aids|death
332
- The endpoint
356
+ endpoint : {'aids', 'death'}, default: 'aids'
357
+ The endpoint.
333
358
 
334
359
  Returns
335
360
  -------
@@ -338,7 +363,7 @@ def load_aids(endpoint="aids"):
338
363
 
339
364
  y : structured array with 2 fields
340
365
  *censor*: boolean indicating whether the endpoint has been reached
341
- or the event time is right censored.
366
+ or the event time is right-censored.
342
367
 
343
368
  *time*: total length of follow-up
344
369
 
@@ -384,7 +409,7 @@ def load_breast_cancer():
384
409
 
385
410
  y : structured array with 2 fields
386
411
  *e.tdm*: boolean indicating whether the endpoint has been reached
387
- or the event time is right censored.
412
+ or the event time is right-censored.
388
413
 
389
414
  *t.tdm*: time to distant metastasis (days)
390
415
 
@@ -428,7 +453,7 @@ def load_flchain():
428
453
 
429
454
  y : structured array with 2 fields
430
455
  *death*: boolean indicating whether the subject died
431
- or the event time is right censored.
456
+ or the event time is right-censored.
432
457
 
433
458
  *futime*: total length of follow-up or time of death.
434
459
 
@@ -473,7 +498,7 @@ def load_bmt():
473
498
  The measurements for each patient.
474
499
 
475
500
  y : structured array with 2 fields
476
- *status*: Integer indicating the endpoint: 0-(survival i.e. right censored data), 1-(TRM), 2-(relapse)
501
+ *status*: Integer indicating the endpoint: 0-(survival i.e. right-censored data), 1-(TRM), 2-(relapse)
477
502
 
478
503
  *ftime*: total length of follow-up or time of event.
479
504
 
@@ -487,7 +512,7 @@ def load_bmt():
487
512
  """
488
513
  full_path = _get_data_path("bmt.arff")
489
514
  data = loadarff(full_path)
490
- data["ftime"] = data["ftime"].astype(int)
515
+ data = data.astype({"ftime": int})
491
516
  return get_x_y(data, attr_labels=["status", "ftime"], competing_risks=True)
492
517
 
493
518
 
@@ -566,7 +591,7 @@ def load_cgvhd():
566
591
  The measurements for each patient.
567
592
 
568
593
  y : structured array with 2 fields
569
- *status*: Integer indicating the endpoint: 0: right censored data; 1: CGVHD; 2: relapse; 3: death.
594
+ *status*: Integer indicating the endpoint: 0: right-censored data; 1: CGVHD; 2: relapse; 3: death.
570
595
 
571
596
  *ftime*: total length of follow-up or time of event.
572
597
 
@@ -578,8 +603,8 @@ def load_cgvhd():
578
603
  """
579
604
  full_path = _get_data_path("cgvhd.arff")
580
605
  data = loadarff(full_path)
581
- data["ftime"] = data[["survtime", "reltime", "cgvhtime"]].min(axis=1)
582
- data["status"] = (
606
+ data.loc[:, "ftime"] = data[["survtime", "reltime", "cgvhtime"]].min(axis=1)
607
+ data.loc[:, "status"] = (
583
608
  ((data["ftime"] == data["cgvhtime"]) & (data["cgvh"] == "1")).astype(int)
584
609
  + 2 * ((data["ftime"] == data["reltime"]) & (data["rcens"] == "1")).astype(int)
585
610
  + 3 * ((data["ftime"] == data["survtime"]) & (data["stat"] == "1")).astype(int)
sksurv/docstrings.py ADDED
@@ -0,0 +1,99 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ _PRED_SURV_FN_EXAMPLE_DOC = """
14
+ .. plot::
15
+
16
+ >>> import matplotlib.pyplot as plt
17
+ >>> from sksurv.datasets import load_veterans_lung_cancer
18
+ >>> from sksurv.preprocessing import OneHotEncoder
19
+ >>> from sksurv.{estimator_mod} import {estimator_class}
20
+
21
+ Load the data and encode categorical features.
22
+
23
+ >>> X, y = load_veterans_lung_cancer()
24
+ >>> Xt = OneHotEncoder().fit_transform(X)
25
+
26
+ Fit the model.
27
+
28
+ >>> estimator = {estimator_class}().fit(Xt, y)
29
+
30
+ Estimate the survival function for the first 10 samples.
31
+
32
+ >>> surv_funcs = estimator.predict_survival_function(Xt.iloc[:10])
33
+
34
+ Plot the estimated survival functions.
35
+
36
+ >>> for fn in surv_funcs:
37
+ ... plt.step(fn.x, fn(fn.x), where="post")
38
+ ...
39
+ [...]
40
+ >>> plt.ylim(0, 1)
41
+ (0.0, 1.0)
42
+ >>> plt.show() # doctest: +SKIP
43
+ """
44
+
45
+ _PRED_CUMHAZ_FN_EXAMPLE_DOC = """
46
+ .. plot::
47
+
48
+ >>> import matplotlib.pyplot as plt
49
+ >>> from sksurv.datasets import load_veterans_lung_cancer
50
+ >>> from sksurv.preprocessing import OneHotEncoder
51
+ >>> from sksurv.{estimator_mod} import {estimator_class}
52
+
53
+ Load the data and encode categorical features.
54
+
55
+ >>> X, y = load_veterans_lung_cancer()
56
+ >>> Xt = OneHotEncoder().fit_transform(X)
57
+
58
+ Fit the model.
59
+
60
+ >>> estimator = {estimator_class}().fit(Xt, y)
61
+
62
+ Estimate the cumulative hazard function for the first 10 samples.
63
+
64
+ >>> chf_funcs = estimator.predict_cumulative_hazard_function(Xt.iloc[:10])
65
+
66
+ Plot the estimated cumulative hazard functions.
67
+
68
+ >>> for fn in chf_funcs:
69
+ ... plt.step(fn.x, fn(fn.x), where="post")
70
+ ...
71
+ [...]
72
+ >>> plt.show() # doctest: +SKIP
73
+ """
74
+
75
+
76
+ def append_survival_function_example(*, estimator_mod, estimator_class):
77
+ """Append example of using predict_survival_function to API doc"""
78
+
79
+ def func(f):
80
+ f.__doc__ += _PRED_SURV_FN_EXAMPLE_DOC.format(
81
+ estimator_mod=estimator_mod,
82
+ estimator_class=estimator_class,
83
+ )
84
+ return f
85
+
86
+ return func
87
+
88
+
89
+ def append_cumulative_hazard_example(*, estimator_mod, estimator_class):
90
+ """Append example of using predict_cumulative_hazard_function to API doc"""
91
+
92
+ def func(f):
93
+ f.__doc__ += _PRED_CUMHAZ_FN_EXAMPLE_DOC.format(
94
+ estimator_mod=estimator_mod,
95
+ estimator_class=estimator_class,
96
+ )
97
+ return f
98
+
99
+ return func