scikit-survival 0.25.0__cp311-cp311-win_amd64.whl → 0.27.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scikit_survival-0.25.0.dist-info → scikit_survival-0.27.0.dist-info}/METADATA +13 -17
- {scikit_survival-0.25.0.dist-info → scikit_survival-0.27.0.dist-info}/RECORD +24 -24
- {scikit_survival-0.25.0.dist-info → scikit_survival-0.27.0.dist-info}/WHEEL +1 -1
- sksurv/bintrees/_binarytrees.cp311-win_amd64.pyd +0 -0
- sksurv/column.py +5 -6
- sksurv/compare.py +1 -1
- sksurv/datasets/base.py +7 -7
- sksurv/ensemble/_coxph_loss.cp311-win_amd64.pyd +0 -0
- sksurv/io/arffread.py +3 -1
- sksurv/io/arffwrite.py +4 -4
- sksurv/kernels/_clinical_kernel.cp311-win_amd64.pyd +0 -0
- sksurv/kernels/clinical.py +3 -3
- sksurv/linear_model/_coxnet.cp311-win_amd64.pyd +0 -0
- sksurv/metrics.py +2 -2
- sksurv/nonparametric.py +3 -3
- sksurv/preprocessing.py +19 -7
- sksurv/svm/_minlip.cp311-win_amd64.pyd +0 -0
- sksurv/svm/_prsvm.cp311-win_amd64.pyd +0 -0
- sksurv/svm/minlip.py +11 -6
- sksurv/testing.py +52 -0
- sksurv/tree/_criterion.cp311-win_amd64.pyd +0 -0
- sksurv/util.py +5 -4
- {scikit_survival-0.25.0.dist-info → scikit_survival-0.27.0.dist-info}/licenses/COPYING +0 -0
- {scikit_survival-0.25.0.dist-info → scikit_survival-0.27.0.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: scikit-survival
|
|
3
|
-
Version: 0.25.0
|
|
3
|
+
Version: 0.27.0
|
|
4
4
|
Summary: Survival analysis built on top of scikit-learn
|
|
5
5
|
Author-email: Sebastian Pölsterl <sebp@k-d-w.org>
|
|
6
6
|
License-Expression: GPL-3.0-or-later
|
|
@@ -19,28 +19,28 @@ Classifier: Programming Language :: C++
|
|
|
19
19
|
Classifier: Programming Language :: Cython
|
|
20
20
|
Classifier: Programming Language :: Python
|
|
21
21
|
Classifier: Programming Language :: Python :: 3
|
|
22
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
23
22
|
Classifier: Programming Language :: Python :: 3.11
|
|
24
23
|
Classifier: Programming Language :: Python :: 3.12
|
|
25
24
|
Classifier: Programming Language :: Python :: 3.13
|
|
25
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
26
26
|
Classifier: Topic :: Software Development
|
|
27
27
|
Classifier: Topic :: Scientific/Engineering
|
|
28
|
-
Requires-Python: >=3.10
|
|
28
|
+
Requires-Python: >=3.11
|
|
29
29
|
Description-Content-Type: text/x-rst
|
|
30
30
|
License-File: COPYING
|
|
31
31
|
Requires-Dist: ecos
|
|
32
32
|
Requires-Dist: joblib
|
|
33
33
|
Requires-Dist: numexpr
|
|
34
|
-
Requires-Dist: numpy
|
|
35
|
-
Requires-Dist: osqp
|
|
36
|
-
Requires-Dist: pandas>=
|
|
37
|
-
Requires-Dist: scipy>=1.
|
|
38
|
-
Requires-Dist: scikit-learn<1.
|
|
34
|
+
Requires-Dist: numpy>=2.0.0
|
|
35
|
+
Requires-Dist: osqp>=1.0.2
|
|
36
|
+
Requires-Dist: pandas>=2.2.0
|
|
37
|
+
Requires-Dist: scipy>=1.13.0
|
|
38
|
+
Requires-Dist: scikit-learn<1.9,>=1.8.0
|
|
39
39
|
Dynamic: license-file
|
|
40
40
|
|
|
41
41
|
|License| |Docs| |DOI|
|
|
42
42
|
|
|
43
|
-
|build-tests| |
|
|
43
|
+
|build-tests| |Codecov| |Codacy|
|
|
44
44
|
|
|
45
45
|
***************
|
|
46
46
|
scikit-survival
|
|
@@ -72,14 +72,14 @@ this unique characteristic of such a dataset into account.
|
|
|
72
72
|
Requirements
|
|
73
73
|
============
|
|
74
74
|
|
|
75
|
-
- Python 3.10 or later
|
|
75
|
+
- Python 3.11 or later
|
|
76
76
|
- ecos
|
|
77
77
|
- joblib
|
|
78
78
|
- numexpr
|
|
79
|
-
- numpy
|
|
79
|
+
- numpy 2.0.0 or later
|
|
80
80
|
- osqp
|
|
81
|
-
- pandas
|
|
82
|
-
- scikit-learn 1.
|
|
81
|
+
- pandas 2.2.0 or later
|
|
82
|
+
- scikit-learn 1.8
|
|
83
83
|
- scipy
|
|
84
84
|
- C/C++ compiler
|
|
85
85
|
|
|
@@ -178,8 +178,4 @@ Please cite the following paper if you are using **scikit-survival**.
|
|
|
178
178
|
:target: https://github.com/sebp/scikit-survival/actions?query=workflow%3Atests+branch%3Amaster
|
|
179
179
|
:alt: GitHub Actions Tests Status
|
|
180
180
|
|
|
181
|
-
.. |build-windows| image:: https://ci.appveyor.com/api/projects/status/github/sebp/scikit-survival?branch=master&svg=true
|
|
182
|
-
:target: https://ci.appveyor.com/project/sebp/scikit-survival
|
|
183
|
-
:alt: Windows Build Status on AppVeyor
|
|
184
|
-
|
|
185
181
|
.. _survival analysis: https://en.wikipedia.org/wiki/Survival_analysis
|
|
@@ -1,20 +1,20 @@
|
|
|
1
|
-
scikit_survival-0.25.0.dist-info/licenses/COPYING,sha256=Czg9WmPaZE9ijZnDOXbqZIftiaqlnwsyV5kt6sEXHms,35821
|
|
1
|
+
scikit_survival-0.27.0.dist-info/licenses/COPYING,sha256=Czg9WmPaZE9ijZnDOXbqZIftiaqlnwsyV5kt6sEXHms,35821
|
|
2
2
|
sksurv/__init__.py,sha256=y_H9kC05lnKq8z8qPnZJZVt79ILJhMH4bkCsohCbvV8,5336
|
|
3
3
|
sksurv/base.py,sha256=FqiHfSFH6fdRzi49Eu3HT08hYxd9yrbdvX0-wSse9qc,4488
|
|
4
|
-
sksurv/column.py,sha256=
|
|
5
|
-
sksurv/compare.py,sha256=
|
|
4
|
+
sksurv/column.py,sha256=Kuh1P2mQhHsZgohl2V1iLEGyjIZ3809QqIZa6XDZU7M,7073
|
|
5
|
+
sksurv/compare.py,sha256=XmVNu-cnXk8M9rwm-qgWnxO_pgzDIH4lcZWkd7sVHJM,4543
|
|
6
6
|
sksurv/docstrings.py,sha256=FGFChjoUjOMWxBz1_mZ2ffgond4e-cGMq8NfOM8S7oc,3400
|
|
7
7
|
sksurv/exceptions.py,sha256=WopRsYNia5MQTvvVXL3U_nf418t29pplfBz6HFtmt1M,819
|
|
8
8
|
sksurv/functions.py,sha256=lOquoM6P6B8WCSBWZFyT0uuKBRXl7zSRkzWwQuHsxTU,4012
|
|
9
|
-
sksurv/metrics.py,sha256=
|
|
10
|
-
sksurv/nonparametric.py,sha256=
|
|
11
|
-
sksurv/preprocessing.py,sha256=
|
|
12
|
-
sksurv/testing.py,sha256=
|
|
13
|
-
sksurv/util.py,sha256=
|
|
9
|
+
sksurv/metrics.py,sha256=RZBRQaMBSYxSNdcfPrrxWnlQgBssjEj-3lKT7ul4gOU,42277
|
|
10
|
+
sksurv/nonparametric.py,sha256=bxig1gErSkrhvY1HUP99C5wzPS6JapVYBvOc5b7yUFw,32659
|
|
11
|
+
sksurv/preprocessing.py,sha256=zjvxrpxZxBgA8QOP5yBc_dAE5QWLRSMerbKJBE5So2o,7169
|
|
12
|
+
sksurv/testing.py,sha256=wxHys1bjqzZKrZ5-rp2sXKjlcTNh-fHmSks-fT_sPjs,6257
|
|
13
|
+
sksurv/util.py,sha256=Ml7t7u7qONmohFTEAPMQI6IClv3IpJa6MIs_bL5FoME,16225
|
|
14
14
|
sksurv/bintrees/__init__.py,sha256=z0GwaTPCzww2H2aXF28ubppw0Oc4umNTAlFAKu1VBJc,742
|
|
15
|
-
sksurv/bintrees/_binarytrees.cp311-win_amd64.pyd,sha256=
|
|
15
|
+
sksurv/bintrees/_binarytrees.cp311-win_amd64.pyd,sha256=gWmNLCajX18I5YuNdori_qWhL0RAyOBy163Ql3Xr_A0,64000
|
|
16
16
|
sksurv/datasets/__init__.py,sha256=lDHdxi0FcMHWKyWTovynHnV5B3B5Y4qh8pNRk9to9nE,373
|
|
17
|
-
sksurv/datasets/base.py,sha256=
|
|
17
|
+
sksurv/datasets/base.py,sha256=TKMgBrhD0bB0pldgzVgaW58eEa1WSfIq4vId3o2XlQI,26236
|
|
18
18
|
sksurv/datasets/data/GBSG2.arff,sha256=oX_UM7Qy841xBOArXBkUPLzIxNTvdtIJqpxXsqGGw9Q,26904
|
|
19
19
|
sksurv/datasets/data/actg320.arff,sha256=BwIq5q_i_75G2rPFQ6TjO0bsiR8MwA6wPouG-SX7TUo,46615
|
|
20
20
|
sksurv/datasets/data/bmt.arff,sha256=3cF6Vrjkc5891_0fsCJ4d_4aMXJlZ6VmFygfMmOOmYM,555
|
|
@@ -24,18 +24,18 @@ sksurv/datasets/data/flchain.arff,sha256=4LVUyEe-45ozaWPy0VkN-1js_MNsKw1gs2E-JRy
|
|
|
24
24
|
sksurv/datasets/data/veteran.arff,sha256=LxZtbmq4I82rcB24JeJTYRtlgwPc3vM2OX5hg-q7xTw,5408
|
|
25
25
|
sksurv/datasets/data/whas500.arff,sha256=dvqRzx-nwgSVJZxNVE2zelnt7l3xgzFtMucB7Wux574,28292
|
|
26
26
|
sksurv/ensemble/__init__.py,sha256=aBjRTFm8UE5sTew292-qcplLUCc6owAfY6osWlj-VSM,193
|
|
27
|
-
sksurv/ensemble/_coxph_loss.cp311-win_amd64.pyd,sha256=
|
|
27
|
+
sksurv/ensemble/_coxph_loss.cp311-win_amd64.pyd,sha256=OIwWQh6SbrAfwcSieB71mqGsDJfP8dLBqouYn64VRtk,137728
|
|
28
28
|
sksurv/ensemble/boosting.py,sha256=PierryBGYz_RBHuhBpwImQpxiLCsuZm73FrSPTTBgUg,63136
|
|
29
29
|
sksurv/ensemble/forest.py,sha256=esiN6cUajfzdw-eMiXmwm9Eljs59Z0y1hmf83s2xx5Q,36170
|
|
30
30
|
sksurv/ensemble/survival_loss.py,sha256=v3tSou5t1YY6lBydAZYZ66DLqAirvRhErqW1dZYrTWE,6093
|
|
31
31
|
sksurv/io/__init__.py,sha256=dalzZGTrvekCM8wwsB636rg1dwDkQtDWaBOw7TpHr5U,94
|
|
32
|
-
sksurv/io/arffread.py,sha256
|
|
33
|
-
sksurv/io/arffwrite.py,sha256=
|
|
32
|
+
sksurv/io/arffread.py,sha256=qg76GNSN0OwMTsjsjoIxAVORyn3QKx6St-h4kinQOGM,2841
|
|
33
|
+
sksurv/io/arffwrite.py,sha256=1s2XoXRaTkVj7XyJ8FDHz2uCg2oqmC9V0PMW3_AkJ6A,5611
|
|
34
34
|
sksurv/kernels/__init__.py,sha256=R1if2sVd_0_f6LniIGUR0tipIfzRKpzgGYnvrVZZvHM,78
|
|
35
|
-
sksurv/kernels/_clinical_kernel.cp311-win_amd64.pyd,sha256=
|
|
36
|
-
sksurv/kernels/clinical.py,sha256=
|
|
35
|
+
sksurv/kernels/_clinical_kernel.cp311-win_amd64.pyd,sha256=aA6LrZMPHL92ELbcYbCHcs_ulU6Z4q5zIFGVzIVJwWE,144896
|
|
36
|
+
sksurv/kernels/clinical.py,sha256=8FzyMVLdj68YXdEjj_e1IA9_fR2-9yOabfEbPv4VdJk,11800
|
|
37
37
|
sksurv/linear_model/__init__.py,sha256=dO6Mr3wXk6Q-KQEuhpdgMeY3ji8ZVdgC-SeSRnrJdmw,155
|
|
38
|
-
sksurv/linear_model/_coxnet.cp311-win_amd64.pyd,sha256=
|
|
38
|
+
sksurv/linear_model/_coxnet.cp311-win_amd64.pyd,sha256=r6goQA8ZCiZUn3uwi5Pgqm3gmS2jou_-iBbMBUu-hdw,93184
|
|
39
39
|
sksurv/linear_model/aft.py,sha256=z3GhH6YztcF25v3A2CJAqIK19XEfqUPV-nI0bKgb6h0,7766
|
|
40
40
|
sksurv/linear_model/coxnet.py,sha256=M5IMDA73LHjqxM8jzA4cya7dIKUsOa6FlonLlnFKdDU,23148
|
|
41
41
|
sksurv/linear_model/coxph.py,sha256=ckZSYTl6e_NIBzGi0WvbF9AyfZyvZ30x1VX-HpHdzbU,22826
|
|
@@ -44,15 +44,15 @@ sksurv/meta/base.py,sha256=AdhIkZi9PvucZ3B2lhhFQpQbwp8EUCDUVOiaev_UsX8,1472
|
|
|
44
44
|
sksurv/meta/ensemble_selection.py,sha256=MBFTcvO9FKJMJkr5Ys6MHp-BQPSq0T9RcrteZjZRHWA,27292
|
|
45
45
|
sksurv/meta/stacking.py,sha256=t5tDtSDKEHURJqdLpo1-L7cQZv6d5N9fbMi87xMj_Dw,13534
|
|
46
46
|
sksurv/svm/__init__.py,sha256=CSceYEcBPGKRcJZ4R0u7DzwictGln_weLIsbt2i5xeU,339
|
|
47
|
-
sksurv/svm/_minlip.cp311-win_amd64.pyd,sha256=
|
|
48
|
-
sksurv/svm/_prsvm.cp311-win_amd64.pyd,sha256=
|
|
49
|
-
sksurv/svm/minlip.py,sha256=
|
|
47
|
+
sksurv/svm/_minlip.cp311-win_amd64.pyd,sha256=sqN9QAIyWj1mcg0OH857FU7XVpAgEeuSlmKqaNm1XFs,143360
|
|
48
|
+
sksurv/svm/_prsvm.cp311-win_amd64.pyd,sha256=QsAxnNLL1aY4LlejE76abe4sziImh5WrzeVh8d1bjRo,141312
|
|
49
|
+
sksurv/svm/minlip.py,sha256=ZgUP_Zx5oduh1VhC-OhYjlJ28J1VhPUHReFirHnEcJw,25671
|
|
50
50
|
sksurv/svm/naive_survival_svm.py,sha256=B7NHFE_BNpyzQKJjUmLZBqmw6sbqlHrGS_6iQLKu2yA,9568
|
|
51
51
|
sksurv/svm/survival_svm.py,sha256=fgF1lSbt_jTy9Ykigg4fEWrsmqae_NyO401KTNG2Kpo,46222
|
|
52
52
|
sksurv/tree/__init__.py,sha256=ozb0fhURX-lpiSiHZd0DMnjkqhC0XOC2CTZq0hEZLPw,65
|
|
53
|
-
sksurv/tree/_criterion.cp311-win_amd64.pyd,sha256=
|
|
53
|
+
sksurv/tree/_criterion.cp311-win_amd64.pyd,sha256=6IHsFXblpIEy1AohNoFgUeZybWVHmYKqW0-zyuyUgTY,162816
|
|
54
54
|
sksurv/tree/tree.py,sha256=4FMKf862hGaZnObIytz_XiLL2fPD11Z0w5zUO2zp7cM,32782
|
|
55
|
-
scikit_survival-0.
|
|
56
|
-
scikit_survival-0.
|
|
57
|
-
scikit_survival-0.25.0.dist-info/top_level.txt,sha256=fPkcFA-XQGbwnD_ZXOvaOWmSd34Qezr26Mn99nYPvAg,7
|
|
58
|
-
scikit_survival-0.25.0.dist-info/RECORD,,
|
|
55
|
+
scikit_survival-0.27.0.dist-info/METADATA,sha256=xaC_67dBglseZ5obMIN9AOyYtlMJSwf7dMHNUrUBQS8,7131
|
|
56
|
+
scikit_survival-0.27.0.dist-info/WHEEL,sha256=bVDps4tOb0nHmbMU4uK7h9NZsZJmyTqcRbrvK3TZLY4,102
|
|
57
|
+
scikit_survival-0.27.0.dist-info/top_level.txt,sha256=fPkcFA-XQGbwnD_ZXOvaOWmSd34Qezr26Mn99nYPvAg,7
|
|
58
|
+
scikit_survival-0.27.0.dist-info/RECORD,,
|
|
Binary file
|
sksurv/column.py
CHANGED
|
@@ -14,7 +14,7 @@ import logging
|
|
|
14
14
|
|
|
15
15
|
import numpy as np
|
|
16
16
|
import pandas as pd
|
|
17
|
-
from pandas.api.types import CategoricalDtype,
|
|
17
|
+
from pandas.api.types import CategoricalDtype, is_string_dtype
|
|
18
18
|
|
|
19
19
|
__all__ = ["categorical_to_numeric", "encode_categorical", "standardize"]
|
|
20
20
|
|
|
@@ -118,12 +118,12 @@ def encode_categorical(table, columns=None, **kwargs):
|
|
|
118
118
|
Numeric columns in the input table remain unchanged.
|
|
119
119
|
"""
|
|
120
120
|
if isinstance(table, pd.Series):
|
|
121
|
-
if not isinstance(table.dtype, CategoricalDtype) and not
|
|
121
|
+
if not isinstance(table.dtype, CategoricalDtype) and not is_string_dtype(table.dtype):
|
|
122
122
|
raise TypeError(f"series must be of categorical dtype, but was {table.dtype}")
|
|
123
123
|
return _encode_categorical_series(table, **kwargs)
|
|
124
124
|
|
|
125
125
|
def _is_categorical_or_object(series):
|
|
126
|
-
return isinstance(series.dtype, CategoricalDtype) or
|
|
126
|
+
return isinstance(series.dtype, CategoricalDtype) or is_string_dtype(series.dtype)
|
|
127
127
|
|
|
128
128
|
if columns is None:
|
|
129
129
|
# for columns containing categories
|
|
@@ -187,13 +187,12 @@ def categorical_to_numeric(table):
|
|
|
187
187
|
def transform(column):
|
|
188
188
|
if isinstance(column.dtype, CategoricalDtype):
|
|
189
189
|
return column.cat.codes
|
|
190
|
-
if
|
|
190
|
+
if is_string_dtype(column.dtype):
|
|
191
191
|
try:
|
|
192
192
|
nc = column.astype(np.int64)
|
|
193
193
|
except ValueError:
|
|
194
194
|
classes = column.dropna().unique()
|
|
195
|
-
|
|
196
|
-
nc = column.map(dict(zip(classes, range(classes.shape[0]))))
|
|
195
|
+
nc = column.map(dict(zip(sorted(classes), range(classes.shape[0]))))
|
|
197
196
|
return nc
|
|
198
197
|
if column.dtype == bool:
|
|
199
198
|
return column.astype(np.int64)
|
sksurv/compare.py
CHANGED
|
@@ -117,7 +117,7 @@ def compare_survival(y, group_indicator, return_stats=False):
|
|
|
117
117
|
table["expected"] = expected
|
|
118
118
|
table["statistic"] = observed - expected
|
|
119
119
|
table = pd.DataFrame.from_dict(table)
|
|
120
|
-
table.index = pd.Index(groups, name="group"
|
|
120
|
+
table.index = pd.Index(groups, name="group")
|
|
121
121
|
return chisq, pval, table, covar
|
|
122
122
|
|
|
123
123
|
return chisq, pval
|
sksurv/datasets/base.py
CHANGED
|
@@ -36,10 +36,10 @@ def _get_x_y_survival(dataset, col_event, col_time, val_outcome, competing_risks
|
|
|
36
36
|
event_type = np.int64 if competing_risks else bool
|
|
37
37
|
y = np.empty(dtype=[(col_event, event_type), (col_time, np.float64)], shape=dataset.shape[0])
|
|
38
38
|
if competing_risks:
|
|
39
|
-
y[col_event] = dataset[col_event].
|
|
39
|
+
y[col_event] = dataset[col_event].to_numpy()
|
|
40
40
|
else:
|
|
41
|
-
y[col_event] = (dataset[col_event] == val_outcome).
|
|
42
|
-
y[col_time] = dataset[col_time].
|
|
41
|
+
y[col_event] = (dataset[col_event] == val_outcome).to_numpy()
|
|
42
|
+
y[col_time] = dataset[col_time].to_numpy()
|
|
43
43
|
|
|
44
44
|
x_frame = dataset.drop([col_event, col_time], axis=1)
|
|
45
45
|
|
|
@@ -116,7 +116,7 @@ def _loadarff_with_index(filename):
|
|
|
116
116
|
if isinstance(dataset["index"].dtype, CategoricalDtype):
|
|
117
117
|
# concatenating categorical index may raise TypeError
|
|
118
118
|
# see https://github.com/pandas-dev/pandas/issues/14586
|
|
119
|
-
dataset
|
|
119
|
+
dataset = dataset.astype({"index": "str"})
|
|
120
120
|
dataset.set_index("index", inplace=True)
|
|
121
121
|
return dataset
|
|
122
122
|
|
|
@@ -512,7 +512,7 @@ def load_bmt():
|
|
|
512
512
|
"""
|
|
513
513
|
full_path = _get_data_path("bmt.arff")
|
|
514
514
|
data = loadarff(full_path)
|
|
515
|
-
data
|
|
515
|
+
data = data.astype({"ftime": int})
|
|
516
516
|
return get_x_y(data, attr_labels=["status", "ftime"], competing_risks=True)
|
|
517
517
|
|
|
518
518
|
|
|
@@ -603,8 +603,8 @@ def load_cgvhd():
|
|
|
603
603
|
"""
|
|
604
604
|
full_path = _get_data_path("cgvhd.arff")
|
|
605
605
|
data = loadarff(full_path)
|
|
606
|
-
data["ftime"] = data[["survtime", "reltime", "cgvhtime"]].min(axis=1)
|
|
607
|
-
data["status"] = (
|
|
606
|
+
data.loc[:, "ftime"] = data[["survtime", "reltime", "cgvhtime"]].min(axis=1)
|
|
607
|
+
data.loc[:, "status"] = (
|
|
608
608
|
((data["ftime"] == data["cgvhtime"]) & (data["cgvh"] == "1")).astype(int)
|
|
609
609
|
+ 2 * ((data["ftime"] == data["reltime"]) & (data["rcens"] == "1")).astype(int)
|
|
610
610
|
+ 3 * ((data["ftime"] == data["survtime"]) & (data["stat"] == "1")).astype(int)
|
|
Binary file
|
sksurv/io/arffread.py
CHANGED
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
13
|
import numpy as np
|
|
14
14
|
import pandas as pd
|
|
15
|
+
from pandas.api.types import is_string_dtype
|
|
15
16
|
from scipy.io.arff import loadarff as scipy_loadarff
|
|
16
17
|
|
|
17
18
|
__all__ = ["loadarff"]
|
|
@@ -34,7 +35,8 @@ def _to_pandas(data, meta):
|
|
|
34
35
|
data_dict[name] = pd.Categorical(raw, categories=attr_format, ordered=False)
|
|
35
36
|
else:
|
|
36
37
|
arr = data[name]
|
|
37
|
-
|
|
38
|
+
dtype = "str" if is_string_dtype(arr.dtype) else arr.dtype
|
|
39
|
+
p = pd.Series(arr, dtype=dtype)
|
|
38
40
|
data_dict[name] = p
|
|
39
41
|
|
|
40
42
|
# currently, this step converts all pandas.Categorial columns back to pandas.Series
|
sksurv/io/arffwrite.py
CHANGED
|
@@ -15,7 +15,7 @@ import re
|
|
|
15
15
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
import pandas as pd
|
|
18
|
-
from pandas.api.types import CategoricalDtype,
|
|
18
|
+
from pandas.api.types import CategoricalDtype, is_string_dtype
|
|
19
19
|
|
|
20
20
|
_ILLEGAL_CHARACTER_PAT = re.compile(r"[^-_=\w\d\(\)<>\.]")
|
|
21
21
|
|
|
@@ -106,7 +106,7 @@ def _write_header(data, fp, relation_name, index):
|
|
|
106
106
|
name = attribute_names[column]
|
|
107
107
|
fp.write(f"@attribute {name}\t")
|
|
108
108
|
|
|
109
|
-
if isinstance(series.dtype, CategoricalDtype) or
|
|
109
|
+
if isinstance(series.dtype, CategoricalDtype) or is_string_dtype(series.dtype):
|
|
110
110
|
_write_attribute_categorical(series, fp)
|
|
111
111
|
elif np.issubdtype(series.dtype, np.floating):
|
|
112
112
|
fp.write("real")
|
|
@@ -168,11 +168,11 @@ def _write_data(data, fp):
|
|
|
168
168
|
fp.write("@data\n")
|
|
169
169
|
|
|
170
170
|
def to_str(x):
|
|
171
|
-
if pd.
|
|
171
|
+
if pd.isna(x):
|
|
172
172
|
return "?"
|
|
173
173
|
return str(x)
|
|
174
174
|
|
|
175
|
-
data = data.
|
|
175
|
+
data = data.map(to_str)
|
|
176
176
|
n_rows = data.shape[0]
|
|
177
177
|
for i in range(n_rows):
|
|
178
178
|
str_values = list(data.iloc[i, :].apply(_check_str_array))
|
|
Binary file
|
sksurv/kernels/clinical.py
CHANGED
|
@@ -41,7 +41,7 @@ def _get_continuous_and_ordinal_array(x):
|
|
|
41
41
|
ordinal_columns = pd.Index([v for v in nominal_columns if x[v].cat.ordered])
|
|
42
42
|
continuous_columns = x.select_dtypes(include=[np.number]).columns
|
|
43
43
|
|
|
44
|
-
x_num = x.loc[:, continuous_columns].
|
|
44
|
+
x_num = x.loc[:, continuous_columns].to_numpy(dtype=np.float64)
|
|
45
45
|
if len(ordinal_columns) > 0:
|
|
46
46
|
x = _ordinal_as_numeric(x, ordinal_columns)
|
|
47
47
|
|
|
@@ -123,7 +123,7 @@ def clinical_kernel(x, y=None):
|
|
|
123
123
|
y_numeric = x_numeric
|
|
124
124
|
|
|
125
125
|
continuous_ordinal_kernel(x_numeric, y_numeric, mat)
|
|
126
|
-
_nominal_kernel(x.loc[:, nominal_columns].
|
|
126
|
+
_nominal_kernel(x.loc[:, nominal_columns].to_numpy(), y.loc[:, nominal_columns].to_numpy(), mat)
|
|
127
127
|
mat /= x.shape[1]
|
|
128
128
|
return mat
|
|
129
129
|
|
|
@@ -210,7 +210,7 @@ class ClinicalKernelTransform(BaseEstimator, TransformerMixin):
|
|
|
210
210
|
else:
|
|
211
211
|
raise TypeError(f"unsupported dtype: {dt!r}")
|
|
212
212
|
|
|
213
|
-
fit_data[:, i] = col.
|
|
213
|
+
fit_data[:, i] = col.to_numpy()
|
|
214
214
|
|
|
215
215
|
self._numeric_columns = np.asarray(numeric_columns)
|
|
216
216
|
self._nominal_columns = np.asarray(nominal_columns)
|
|
Binary file
|
sksurv/metrics.py
CHANGED
|
@@ -510,7 +510,7 @@ def cumulative_dynamic_auc(survival_train, survival_test, estimate, times, tied_
|
|
|
510
510
|
# to make sure that the curve starts at (0, 0)
|
|
511
511
|
tp_no_ties = np.r_[0, tp_no_ties]
|
|
512
512
|
fp_no_ties = np.r_[0, fp_no_ties]
|
|
513
|
-
scores[i] = np.
|
|
513
|
+
scores[i] = np.trapezoid(tp_no_ties, fp_no_ties)
|
|
514
514
|
|
|
515
515
|
if n_times == 1:
|
|
516
516
|
mean_auc = scores[0]
|
|
@@ -780,7 +780,7 @@ def integrated_brier_score(survival_train, survival_test, estimate, times):
|
|
|
780
780
|
raise ValueError("At least two time points must be given")
|
|
781
781
|
|
|
782
782
|
# Computing the IBS
|
|
783
|
-
ibs_value = np.
|
|
783
|
+
ibs_value = np.trapezoid(brier_scores, times) / (times[-1] - times[0])
|
|
784
784
|
|
|
785
785
|
return ibs_value
|
|
786
786
|
|
sksurv/nonparametric.py
CHANGED
|
@@ -321,7 +321,7 @@ def kaplan_meier_estimator(
|
|
|
321
321
|
>>> plt.step(time, prob_surv, where="post")
|
|
322
322
|
[...]
|
|
323
323
|
>>> plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step="post")
|
|
324
|
-
<matplotlib.collections.
|
|
324
|
+
<matplotlib.collections.FillBetweenPolyCollection object at 0x...>
|
|
325
325
|
>>> plt.ylim(0, 1)
|
|
326
326
|
(0.0, 1.0)
|
|
327
327
|
>>> plt.show() # doctest: +SKIP
|
|
@@ -757,12 +757,12 @@ def cumulative_incidence_competing_risks(
|
|
|
757
757
|
>>> plt.step(x, y[0], where="post", label="Total risk")
|
|
758
758
|
[...]
|
|
759
759
|
>>> plt.fill_between(x, conf_int[0, 0], conf_int[0, 1], alpha=0.25, step="post")
|
|
760
|
-
<matplotlib.collections.
|
|
760
|
+
<matplotlib.collections.FillBetweenPolyCollection object at 0x...>
|
|
761
761
|
>>> for i in range(1, n_risks + 1):
|
|
762
762
|
... plt.step(x, y[i], where="post", label=f"{i}-risk")
|
|
763
763
|
... plt.fill_between(x, conf_int[i, 0], conf_int[i, 1], alpha=0.25, step="post")
|
|
764
764
|
[...]
|
|
765
|
-
<matplotlib.collections.
|
|
765
|
+
<matplotlib.collections.FillBetweenPolyCollection object at 0x...>
|
|
766
766
|
>>> plt.ylim(0, 1)
|
|
767
767
|
(0.0, 1.0)
|
|
768
768
|
>>> plt.legend()
|
sksurv/preprocessing.py
CHANGED
|
@@ -10,6 +10,8 @@
|
|
|
10
10
|
#
|
|
11
11
|
# You should have received a copy of the GNU General Public License
|
|
12
12
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
import pandas as pd
|
|
14
|
+
from pandas.api.types import CategoricalDtype, is_string_dtype
|
|
13
15
|
from sklearn.base import BaseEstimator, TransformerMixin
|
|
14
16
|
from sklearn.utils.validation import _check_feature_names, _check_feature_names_in, _check_n_features, check_is_fitted
|
|
15
17
|
|
|
@@ -127,12 +129,24 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
|
|
|
127
129
|
"""
|
|
128
130
|
_check_feature_names(self, X, reset=True)
|
|
129
131
|
_check_n_features(self, X, reset=True)
|
|
130
|
-
|
|
132
|
+
|
|
133
|
+
def is_string_or_categorical_dtype(dtype):
|
|
134
|
+
return is_string_dtype(dtype) or isinstance(dtype, CategoricalDtype)
|
|
135
|
+
|
|
136
|
+
columns_to_encode = pd.Index(
|
|
137
|
+
[name for name, dtype in X.dtypes.items() if is_string_or_categorical_dtype(dtype)]
|
|
138
|
+
)
|
|
131
139
|
x_dummy = self._encode(X, columns_to_encode)
|
|
132
140
|
|
|
133
141
|
self.feature_names_ = columns_to_encode
|
|
134
|
-
|
|
135
|
-
|
|
142
|
+
cat_cols = {}
|
|
143
|
+
for col_name in columns_to_encode:
|
|
144
|
+
col = X[col_name]
|
|
145
|
+
if not isinstance(col.dtype, CategoricalDtype):
|
|
146
|
+
col = col.astype("category")
|
|
147
|
+
cat_cols[col_name] = col.cat.categories
|
|
148
|
+
self.categories_ = cat_cols
|
|
149
|
+
self.encoded_columns_ = x_dummy.columns.copy()
|
|
136
150
|
return x_dummy
|
|
137
151
|
|
|
138
152
|
def transform(self, X):
|
|
@@ -152,9 +166,7 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
|
|
|
152
166
|
_check_n_features(self, X, reset=False)
|
|
153
167
|
check_columns_exist(X.columns, self.feature_names_)
|
|
154
168
|
|
|
155
|
-
Xt = X.
|
|
156
|
-
for col, cat in self.categories_.items():
|
|
157
|
-
Xt[col] = Xt[col].cat.set_categories(cat)
|
|
169
|
+
Xt = X.astype({col: CategoricalDtype(cat) for col, cat in self.categories_.items()})
|
|
158
170
|
|
|
159
171
|
new_data = self._encode(Xt, self.feature_names_)
|
|
160
172
|
return new_data.loc[:, self.encoded_columns_]
|
|
@@ -180,4 +192,4 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
|
|
|
180
192
|
check_is_fitted(self, "encoded_columns_")
|
|
181
193
|
input_features = _check_feature_names_in(self, input_features)
|
|
182
194
|
|
|
183
|
-
return self.encoded_columns_.
|
|
195
|
+
return self.encoded_columns_.to_numpy(copy=True)
|
|
Binary file
|
|
Binary file
|
sksurv/svm/minlip.py
CHANGED
|
@@ -81,17 +81,22 @@ class OsqpSolver(QPSolver):
|
|
|
81
81
|
|
|
82
82
|
solver_opts = self._get_options()
|
|
83
83
|
m = osqp.OSQP()
|
|
84
|
-
m.setup(P=sparse.csc_matrix(P), q=q, A=G, u=h, **solver_opts) # noqa: E741
|
|
85
|
-
results = m.solve()
|
|
84
|
+
m.setup(P=sparse.csc_matrix(P), q=q, A=G, l=None, u=h, **solver_opts) # noqa: E741
|
|
85
|
+
results = m.solve(raise_error=False)
|
|
86
86
|
|
|
87
|
-
|
|
87
|
+
solved_codes = (
|
|
88
|
+
osqp.SolverStatus.OSQP_SOLVED,
|
|
89
|
+
osqp.SolverStatus.OSQP_SOLVED_INACCURATE,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
if results.info.status_val == osqp.SolverStatus.OSQP_MAX_ITER_REACHED: # max iter reached
|
|
88
93
|
warnings.warn(
|
|
89
94
|
(f"OSQP solver did not converge: {results.info.status}"),
|
|
90
95
|
category=ConvergenceWarning,
|
|
91
96
|
stacklevel=2,
|
|
92
97
|
)
|
|
93
|
-
elif results.info.status_val not in
|
|
94
|
-
#
|
|
98
|
+
elif results.info.status_val not in solved_codes: # pragma: no cover
|
|
99
|
+
# none of SOLVED, SOLVED_INACCURATE
|
|
95
100
|
raise RuntimeError(f"OSQP solver failed: {results.info.status}")
|
|
96
101
|
|
|
97
102
|
n_iter = results.info.iter
|
|
@@ -103,7 +108,7 @@ class OsqpSolver(QPSolver):
|
|
|
103
108
|
"eps_abs": 1e-5,
|
|
104
109
|
"eps_rel": 1e-5,
|
|
105
110
|
"max_iter": self.max_iter or 4000,
|
|
106
|
-
"
|
|
111
|
+
"polishing": True,
|
|
107
112
|
"verbose": self.verbose,
|
|
108
113
|
}
|
|
109
114
|
return solver_opts
|
sksurv/testing.py
CHANGED
|
@@ -10,13 +10,17 @@
|
|
|
10
10
|
#
|
|
11
11
|
# You should have received a copy of the GNU General Public License
|
|
12
12
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
|
+
from contextlib import nullcontext
|
|
13
14
|
from importlib import import_module
|
|
15
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
14
16
|
import inspect
|
|
15
17
|
from pathlib import Path
|
|
16
18
|
import pkgutil
|
|
17
19
|
|
|
18
20
|
import numpy as np
|
|
19
21
|
from numpy.testing import assert_almost_equal, assert_array_equal
|
|
22
|
+
from packaging.version import parse
|
|
23
|
+
import pandas as pd
|
|
20
24
|
import pytest
|
|
21
25
|
from sklearn.base import BaseEstimator, TransformerMixin
|
|
22
26
|
|
|
@@ -106,3 +110,51 @@ class FixtureParameterFactory:
|
|
|
106
110
|
values = func()
|
|
107
111
|
cases.append(pytest.param(*values, id=name))
|
|
108
112
|
return cases
|
|
113
|
+
|
|
114
|
+
def get_cases_func(self):
|
|
115
|
+
cases = []
|
|
116
|
+
for name, func in inspect.getmembers(self):
|
|
117
|
+
if name.startswith("data_"):
|
|
118
|
+
cases.append(pytest.param(func, id=name))
|
|
119
|
+
return cases
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def check_module_minimum_version(module, min_version_str, max_version_str=None):
|
|
123
|
+
"""
|
|
124
|
+
Check whether a module of a specified minimum version is available.
|
|
125
|
+
|
|
126
|
+
Parameters
|
|
127
|
+
----------
|
|
128
|
+
module : str
|
|
129
|
+
Name of the module.
|
|
130
|
+
min_version_str : str
|
|
131
|
+
Minimum version of the module.
|
|
132
|
+
max_version_str : str, optional
|
|
133
|
+
Maximum version of the module (excluding).
|
|
134
|
+
|
|
135
|
+
Returns
|
|
136
|
+
-------
|
|
137
|
+
available : bool
|
|
138
|
+
True if the module is available and its version is >= `version_str`.
|
|
139
|
+
"""
|
|
140
|
+
try:
|
|
141
|
+
module_version = parse(version(module))
|
|
142
|
+
required_min_version = parse(min_version_str)
|
|
143
|
+
if max_version_str is None:
|
|
144
|
+
return module_version >= required_min_version
|
|
145
|
+
required_max_version = parse(max_version_str)
|
|
146
|
+
return required_min_version <= module_version < required_max_version
|
|
147
|
+
except PackageNotFoundError: # pragma: no cover
|
|
148
|
+
return False
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def get_pandas_infer_string_context():
|
|
152
|
+
if check_module_minimum_version("pandas", "2.3.0", "3.0.0"):
|
|
153
|
+
return (
|
|
154
|
+
pytest.param(pd.option_context("future.infer_string", False), id="infer_string=False"),
|
|
155
|
+
pytest.param(pd.option_context("future.infer_string", True), id="infer_string=True"),
|
|
156
|
+
)
|
|
157
|
+
return (
|
|
158
|
+
pytest.param(nullcontext(), id="pandas default options"),
|
|
159
|
+
pytest.param(nullcontext(), marks=pytest.mark.skip("no pandas 2.3.0")),
|
|
160
|
+
)
|
|
Binary file
|
sksurv/util.py
CHANGED
|
@@ -142,7 +142,7 @@ class Surv:
|
|
|
142
142
|
raise TypeError(f"expected pandas.DataFrame, but got {type(data)!r}")
|
|
143
143
|
|
|
144
144
|
return Surv.from_arrays(
|
|
145
|
-
data.loc[:, event].
|
|
145
|
+
data.loc[:, event].to_numpy(), data.loc[:, time].to_numpy(), name_event=str(event), name_time=str(time)
|
|
146
146
|
)
|
|
147
147
|
|
|
148
148
|
|
|
@@ -337,6 +337,7 @@ def safe_concat(objs, *args, **kwargs):
|
|
|
337
337
|
categories[df.name] = {"categories": df.cat.categories, "ordered": df.cat.ordered}
|
|
338
338
|
else:
|
|
339
339
|
dfc = df.select_dtypes(include=["category"])
|
|
340
|
+
new_dtypes = {}
|
|
340
341
|
for name, s in dfc.items():
|
|
341
342
|
if name in categories:
|
|
342
343
|
if axis == 1:
|
|
@@ -345,12 +346,12 @@ def safe_concat(objs, *args, **kwargs):
|
|
|
345
346
|
raise ValueError(f"categories for column {name} do not match")
|
|
346
347
|
else:
|
|
347
348
|
categories[name] = {"categories": s.cat.categories, "ordered": s.cat.ordered}
|
|
348
|
-
|
|
349
|
+
new_dtypes[name] = "str"
|
|
350
|
+
df = df.astype(new_dtypes)
|
|
349
351
|
|
|
350
352
|
concatenated = pd.concat(objs, *args, axis=axis, **kwargs)
|
|
351
353
|
|
|
352
|
-
for name, params in categories.items()
|
|
353
|
-
concatenated[name] = pd.Categorical(concatenated[name], **params)
|
|
354
|
+
concatenated = concatenated.astype({name: pd.CategoricalDtype(**params) for name, params in categories.items()})
|
|
354
355
|
|
|
355
356
|
return concatenated
|
|
356
357
|
|
|
File without changes
|
|
File without changes
|