scikit-survival 0.26.0__cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. scikit_survival-0.26.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.26.0.dist-info/RECORD +58 -0
  3. scikit_survival-0.26.0.dist-info/WHEEL +6 -0
  4. scikit_survival-0.26.0.dist-info/licenses/COPYING +674 -0
  5. scikit_survival-0.26.0.dist-info/top_level.txt +1 -0
  6. sksurv/__init__.py +183 -0
  7. sksurv/base.py +115 -0
  8. sksurv/bintrees/__init__.py +15 -0
  9. sksurv/bintrees/_binarytrees.cpython-312-x86_64-linux-gnu.so +0 -0
  10. sksurv/column.py +204 -0
  11. sksurv/compare.py +123 -0
  12. sksurv/datasets/__init__.py +12 -0
  13. sksurv/datasets/base.py +614 -0
  14. sksurv/datasets/data/GBSG2.arff +700 -0
  15. sksurv/datasets/data/actg320.arff +1169 -0
  16. sksurv/datasets/data/bmt.arff +46 -0
  17. sksurv/datasets/data/breast_cancer_GSE7390-metastasis.arff +283 -0
  18. sksurv/datasets/data/cgvhd.arff +118 -0
  19. sksurv/datasets/data/flchain.arff +7887 -0
  20. sksurv/datasets/data/veteran.arff +148 -0
  21. sksurv/datasets/data/whas500.arff +520 -0
  22. sksurv/docstrings.py +99 -0
  23. sksurv/ensemble/__init__.py +2 -0
  24. sksurv/ensemble/_coxph_loss.cpython-312-x86_64-linux-gnu.so +0 -0
  25. sksurv/ensemble/boosting.py +1564 -0
  26. sksurv/ensemble/forest.py +902 -0
  27. sksurv/ensemble/survival_loss.py +151 -0
  28. sksurv/exceptions.py +18 -0
  29. sksurv/functions.py +114 -0
  30. sksurv/io/__init__.py +2 -0
  31. sksurv/io/arffread.py +91 -0
  32. sksurv/io/arffwrite.py +181 -0
  33. sksurv/kernels/__init__.py +1 -0
  34. sksurv/kernels/_clinical_kernel.cpython-312-x86_64-linux-gnu.so +0 -0
  35. sksurv/kernels/clinical.py +348 -0
  36. sksurv/linear_model/__init__.py +3 -0
  37. sksurv/linear_model/_coxnet.cpython-312-x86_64-linux-gnu.so +0 -0
  38. sksurv/linear_model/aft.py +208 -0
  39. sksurv/linear_model/coxnet.py +592 -0
  40. sksurv/linear_model/coxph.py +637 -0
  41. sksurv/meta/__init__.py +4 -0
  42. sksurv/meta/base.py +35 -0
  43. sksurv/meta/ensemble_selection.py +724 -0
  44. sksurv/meta/stacking.py +370 -0
  45. sksurv/metrics.py +1028 -0
  46. sksurv/nonparametric.py +911 -0
  47. sksurv/preprocessing.py +195 -0
  48. sksurv/svm/__init__.py +11 -0
  49. sksurv/svm/_minlip.cpython-312-x86_64-linux-gnu.so +0 -0
  50. sksurv/svm/_prsvm.cpython-312-x86_64-linux-gnu.so +0 -0
  51. sksurv/svm/minlip.py +695 -0
  52. sksurv/svm/naive_survival_svm.py +249 -0
  53. sksurv/svm/survival_svm.py +1236 -0
  54. sksurv/testing.py +155 -0
  55. sksurv/tree/__init__.py +1 -0
  56. sksurv/tree/_criterion.cpython-312-x86_64-linux-gnu.so +0 -0
  57. sksurv/tree/tree.py +790 -0
  58. sksurv/util.py +416 -0
sksurv/__init__.py ADDED
@@ -0,0 +1,183 @@
1
+ from importlib.metadata import PackageNotFoundError, version
2
+ import platform
3
+ import sys
4
+
5
+ from sklearn.pipeline import Pipeline, _final_estimator_has
6
+ from sklearn.utils.metaestimators import available_if
7
+
8
+ from .util import property_available_if
9
+
10
+
11
+ def _get_version(name):
12
+ try:
13
+ pkg_version = version(name)
14
+ except ImportError:
15
+ pkg_version = None
16
+ return pkg_version
17
+
18
+
19
def show_versions():
    """Print debugging information about the system and key dependencies."""
    sys_info = {
        "Platform": platform.platform(),
        "Python version": f"{platform.python_implementation()} {platform.python_version()}",
        "Python interpreter": sys.executable,
    }

    deps = [
        "scikit-survival",
        "scikit-learn",
        "numpy",
        "scipy",
        "pandas",
        "numexpr",
        "ecos",
        "osqp",
        "joblib",
        "matplotlib",
        "pytest",
        "sphinx",
        "Cython",
        "pip",
        "setuptools",
    ]
    # Pad every label to the longest name so values line up in one column.
    minwidth = max(len(label) for label in [*sys_info, *deps])
    fmt = f"{{0:<{minwidth}s}}: {{1}}"

    print("SYSTEM")
    print("------")
    for label, value in sys_info.items():
        print(fmt.format(label, value))

    print("\nDEPENDENCIES")
    print("------------")
    for dep in deps:
        print(fmt.format(dep, _get_version(dep)))
60
+
61
+
62
@available_if(_final_estimator_has("predict_cumulative_hazard_function"))
def predict_cumulative_hazard_function(self, X, **kwargs):
    r"""Predict cumulative hazard function for a pipeline.

    The cumulative hazard function of an individual with
    feature vector :math:`x` is

    .. math::

        H(t \mid x) = \exp(x^\top \beta) H_0(t) ,

    with :math:`H_0(t)` denoting the baseline hazard function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    cum_hazard : ndarray, shape = (n_samples,)
        Predicted cumulative hazard functions. Each element is an instance
        of :class:`sksurv.functions.StepFunction`.

    See Also
    --------
    predict_survival_function : Predict survival function for a pipeline.

    Examples
    --------
    >>> from sksurv.datasets import load_whas500
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.preprocessing import OneHotEncoder
    >>> from sklearn.pipeline import Pipeline
    >>>
    >>> X, y = load_whas500()
    >>> pipe = Pipeline([('encode', OneHotEncoder()),
    ...                  ('cox', CoxPHSurvivalAnalysis())])
    >>> pipe.fit(X, y)
    Pipeline(...)
    >>> chf = pipe.predict_cumulative_hazard_function(X.iloc[:5])
    >>> for fn in chf:
    ...     print(fn.x, fn.y)
    [...]
    """
    # Push the data through every transformer step, then delegate to the
    # final estimator of the pipeline.
    transformed = X
    for _, _, step in self._iter(with_final=False):
        transformed = step.transform(transformed)
    final_estimator = self.steps[-1][-1]
    return final_estimator.predict_cumulative_hazard_function(transformed, **kwargs)
112
+
113
+
114
@available_if(_final_estimator_has("predict_survival_function"))
def predict_survival_function(self, X, **kwargs):
    r"""Predict survival function for a pipeline.

    The survival function of an individual with
    feature vector :math:`x` is

    .. math::

        S(t \mid x) = S_0(t)^{\exp(x^\top \beta)} ,

    with :math:`S_0(t)` denoting the baseline survival function,
    estimated by Breslow's estimator.

    Parameters
    ----------
    X : array-like, shape = (n_samples, n_features)
        Data matrix.

    Returns
    -------
    survival : ndarray, shape = (n_samples,)
        Predicted survival functions. Each element is an instance
        of :class:`sksurv.functions.StepFunction`.

    See Also
    --------
    predict_cumulative_hazard_function : Predict cumulative hazard function for a pipeline.

    Examples
    --------
    >>> from sksurv.datasets import load_whas500
    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
    >>> from sksurv.preprocessing import OneHotEncoder
    >>> from sklearn.pipeline import Pipeline
    >>>
    >>> X, y = load_whas500()
    >>> pipe = Pipeline([('encode', OneHotEncoder()),
    ...                  ('cox', CoxPHSurvivalAnalysis())])
    >>> pipe.fit(X, y)
    Pipeline(...)
    >>> surv_fn = pipe.predict_survival_function(X.iloc[:5])
    >>> for fn in surv_fn:
    ...     print(fn.x, fn.y)
    [...]
    """
    # Push the data through every transformer step, then delegate to the
    # final estimator of the pipeline.
    transformed = X
    for _, _, step in self._iter(with_final=False):
        transformed = step.transform(transformed)
    final_estimator = self.steps[-1][-1]
    return final_estimator.predict_survival_function(transformed, **kwargs)
164
+
165
+
166
@available_if(_final_estimator_has("_predict_risk_score"))
def _predict_risk_score(self):
    # Forward the final estimator's ``_predict_risk_score`` flag, so that
    # ``SurvivalAnalysisMixin.score`` can tell whether ``predict`` returns
    # a risk score or a time-scale value that must be negated.
    return self.steps[-1][-1]._predict_risk_score
169
+
170
+
171
def patch_pipeline():
    """Attach survival-specific prediction attributes to :class:`sklearn.pipeline.Pipeline`."""
    extra_attrs = {
        "predict_survival_function": predict_survival_function,
        "predict_cumulative_hazard_function": predict_cumulative_hazard_function,
        "_predict_risk_score": _predict_risk_score,
    }
    for attr_name, attr_value in extra_attrs.items():
        setattr(Pipeline, attr_name, attr_value)
175
+
176
+
177
try:
    # Single source of truth for the version: installed package metadata.
    __version__ = version("scikit-survival")
except PackageNotFoundError:  # pragma: no cover
    # package is not installed
    __version__ = "unknown"

# Monkey-patch sklearn's Pipeline with survival-specific prediction
# methods as an import-time side effect of this package.
patch_pipeline()
sksurv/base.py ADDED
@@ -0,0 +1,115 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ import numpy as np
14
+
15
+
16
class SurvivalAnalysisMixin:
    """Mixin that adds survival/cumulative-hazard prediction helpers and a
    concordance-index based ``score`` method to survival estimators."""

    def _predict_function(self, func_name, baseline_model, prediction, return_array):
        # Ask the baseline model for one step function per sample.
        step_functions = getattr(baseline_model, func_name)(prediction)

        if not return_array:
            return step_functions

        # Evaluate each step function on the grid of unique event times.
        time_grid = baseline_model.unique_times_
        values = np.empty((prediction.shape[0], time_grid.shape[0]), dtype=float)
        for row, step_fn in enumerate(step_functions):
            values[row, :] = step_fn(time_grid)
        return values

    def _predict_survival_function(self, baseline_model, prediction, return_array):
        """Return survival functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of the baseline survival function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If ``True``, return a float array of the survival function
            evaluated at the unique event times, otherwise return
            :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        survival : ndarray
            An array of shape (n_samples, n_unique_times) with survival
            function values if ``return_array`` is ``True``, otherwise
            an array of :class:`sksurv.functions.StepFunction` instances.
        """
        return self._predict_function("get_survival_function", baseline_model, prediction, return_array)

    def _predict_cumulative_hazard_function(self, baseline_model, prediction, return_array):
        """Return cumulative hazard functions.

        Parameters
        ----------
        baseline_model : sksurv.linear_model.coxph.BreslowEstimator
            Estimator of the baseline cumulative hazard function.

        prediction : array-like, shape=(n_samples,)
            Predicted risk scores.

        return_array : bool
            If ``True``, return a float array of the cumulative hazard
            evaluated at the unique event times, otherwise return
            :class:`sksurv.functions.StepFunction` instances.

        Returns
        -------
        cum_hazard : ndarray
            An array of shape (n_samples, n_unique_times) with cumulative
            hazard values if ``return_array`` is ``True``, otherwise
            an array of :class:`sksurv.functions.StepFunction` instances.
        """
        return self._predict_function("get_cumulative_hazard_function", baseline_model, prediction, return_array)

    def score(self, X, y):
        """Returns the concordance index of the prediction.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)
            Test samples.

        y : structured array, shape = (n_samples,)
            A structured array containing the binary event indicator
            as first field, and time of event or time of censoring as
            second field.

        Returns
        -------
        cindex : float
            Estimated concordance index.

        See also
        --------
        sksurv.metrics.concordance_index_censored : Computes the concordance index.
        """
        from .metrics import concordance_index_censored

        event_field, time_field = y.dtype.names

        estimate = self.predict(X)
        # Estimators whose predictions live on the time scale signal this
        # via a falsy ``_predict_risk_score`` attribute.
        if not getattr(self, "_predict_risk_score", True):
            estimate *= -1  # convert prediction on time scale to risk scale

        return concordance_index_censored(y[event_field], y[time_field], estimate)[0]

    def __sklearn_tags__(self):
        # Survival estimators always require the target ``y`` during fit.
        tags = super().__sklearn_tags__()
        tags.target_tags.required = True
        return tags
@@ -0,0 +1,15 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ from ._binarytrees import AATree, AVLTree, RBTree
14
+
15
+ __all__ = ["RBTree", "AVLTree", "AATree"]
sksurv/column.py ADDED
@@ -0,0 +1,204 @@
1
+ # This program is free software: you can redistribute it and/or modify
2
+ # it under the terms of the GNU General Public License as published by
3
+ # the Free Software Foundation, either version 3 of the License, or
4
+ # (at your option) any later version.
5
+ #
6
+ # This program is distributed in the hope that it will be useful,
7
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
8
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9
+ # GNU General Public License for more details.
10
+ #
11
+ # You should have received a copy of the GNU General Public License
12
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
13
+ import logging
14
+
15
+ import numpy as np
16
+ import pandas as pd
17
+ from pandas.api.types import CategoricalDtype, is_string_dtype
18
+
19
+ __all__ = ["categorical_to_numeric", "encode_categorical", "standardize"]
20
+
21
+
22
+ def _apply_along_column(array, func1d, **kwargs):
23
+ if isinstance(array, pd.DataFrame):
24
+ return array.apply(func1d, **kwargs)
25
+ return np.apply_along_axis(func1d, 0, array, **kwargs)
26
+
27
+
28
def standardize_column(series_or_array, with_std=True):
    """Center a numeric column at zero mean and optionally scale it to unit variance.

    Non-numeric columns are returned unchanged. Numeric input is converted
    to float before centering.
    """
    if not issubclass(series_or_array.dtype.type, np.number):
        return series_or_array

    centered = series_or_array.astype(float) - series_or_array.mean()
    if with_std:
        # Unbiased (ddof=1) standard deviation.
        centered = centered / series_or_array.std(ddof=1)
    return centered
42
+
43
+
44
def standardize(table, with_std=True):
    """Apply Z-normalization to all numeric columns.

    Each numeric column is shifted to zero mean and, if ``with_std`` is
    set, divided by its standard deviation.

    For a :class:`pandas.DataFrame`, non-numeric columns are passed through
    untouched. A :class:`numpy.ndarray` is only modified if it has a
    numeric dtype, in which case the result has a floating-point dtype.

    Parameters
    ----------
    table : pandas.DataFrame or numpy.ndarray
        Data to standardize.
    with_std : bool, optional, default: True
        If ``False``, data is only centered (mean removed) and not scaled to
        unit variance.

    Returns
    -------
    normalized : pandas.DataFrame or numpy.ndarray
        The standardized data, of the same type as the input.
    """
    return _apply_along_column(table, standardize_column, with_std=with_std)
71
+
72
+
73
def _encode_categorical_series(series, allow_drop=True):
    """One-hot encode a single Series, dropping the first (reference) level.

    Returns ``None`` if the column was dropped entirely, an empty Series if
    there are no levels, the unmodified Series if it has a single level and
    ``allow_drop`` is False, and otherwise a DataFrame with one indicator
    column per remaining level.
    """
    result = _get_dummies_1d(series, allow_drop=allow_drop)
    if result is None:
        return None

    dummies, levels = result
    if dummies is None:
        # No levels at all: keep an empty placeholder column.
        return pd.Series(index=series.index, name=series.name, dtype=series.dtype)

    if not allow_drop and dummies.shape[1] == 1:
        return series

    # Drop the first level to avoid a redundant (collinear) column.
    column_names = [f"{series.name}={levels[key]}" for key in range(1, dummies.shape[1])]
    return pd.DataFrame(dummies[:, 1:], columns=column_names, index=series.index)
91
+
92
+
93
def encode_categorical(table, columns=None, **kwargs):
    """One-hot encode categorical features.

    A binary indicator column is created for each category; the first
    category of every feature is dropped, so a column with `M` categories
    is represented by `M-1` indicator columns.

    Parameters
    ----------
    table : pandas.DataFrame or pandas.Series
        Data with categorical columns to encode.
    columns : list-like, optional, default: None
        Column names in the DataFrame to be encoded.
        If `columns` is `None`, all columns with `object` or `category`
        dtype will be converted. This parameter is ignored if `table` is a
        pandas.Series.
    allow_drop : bool, optional, default: True
        Whether to allow dropping categorical columns that only consist
        of a single category.

    Returns
    -------
    encoded : pandas.DataFrame
        The transformed data with categorical columns encoded as numeric.
        Numeric columns in the input table remain unchanged.
    """
    if isinstance(table, pd.Series):
        if not isinstance(table.dtype, CategoricalDtype) and not is_string_dtype(table.dtype):
            raise TypeError(f"series must be of categorical dtype, but was {table.dtype}")
        return _encode_categorical_series(table, **kwargs)

    def _is_categorical_or_object(series):
        return isinstance(series.dtype, CategoricalDtype) or is_string_dtype(series.dtype)

    if columns is None:
        # Default: encode every column that holds categories or strings.
        columns_to_encode = {name for name, col in table.items() if _is_categorical_or_object(col)}
    else:
        columns_to_encode = set(columns)

    pieces = []
    for name, col in table.items():
        if name in columns_to_encode:
            col = _encode_categorical_series(col, **kwargs)
            if col is None:
                # Column was dropped (single category).
                continue
        pieces.append(col)

    # Stitch encoded and untouched columns back together.
    return pd.concat(pieces, axis=1, copy=False)
145
+
146
+
147
+ def _get_dummies_1d(data, allow_drop=True):
148
+ # Series avoids inconsistent NaN handling
149
+ cat = pd.Categorical(data)
150
+ levels = cat.categories
151
+ number_of_cols = len(levels)
152
+
153
+ # if all NaN or only one level
154
+ if allow_drop and number_of_cols < 2:
155
+ logging.getLogger(__package__).warning(
156
+ f"dropped categorical variable {data.name!r}, because it has only {number_of_cols} values"
157
+ )
158
+ return
159
+ if number_of_cols == 0:
160
+ return None, levels
161
+
162
+ dummy_mat = np.eye(number_of_cols).take(cat.codes, axis=0)
163
+
164
+ # reset NaN GH4446
165
+ dummy_mat[cat.codes == -1] = np.nan
166
+
167
+ return dummy_mat, levels
168
+
169
+
170
def categorical_to_numeric(table):
    """Encode categorical features as integers.

    Each category is mapped to a unique integer value. Numeric columns are
    passed through unchanged.

    Parameters
    ----------
    table : pandas.DataFrame or pandas.Series
        Data with categorical columns to encode.

    Returns
    -------
    encoded : pandas.DataFrame or pandas.Series
        The transformed data with categorical columns encoded as integers.
        The output type will be the same as the input type.
    """

    def _to_numeric(column):
        if isinstance(column.dtype, CategoricalDtype):
            # Reuse pandas' category codes directly.
            return column.cat.codes
        if is_string_dtype(column.dtype):
            try:
                return column.astype(np.int64)
            except ValueError:
                # Not numeric strings: map sorted unique values to 0..n-1.
                levels = column.dropna().unique()
                mapping = dict(zip(sorted(levels), range(levels.shape[0])))
                return column.map(mapping)
        if column.dtype == bool:
            return column.astype(np.int64)

        return column

    if isinstance(table, pd.Series):
        return pd.Series(_to_numeric(table), name=table.name, index=table.index)
    return table.apply(_to_numeric, axis=0, result_type="expand")
sksurv/compare.py ADDED
@@ -0,0 +1,123 @@
1
+ from collections import OrderedDict
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from scipy import stats
6
+ from sklearn.utils.validation import check_array
7
+
8
+ from .util import check_array_survival
9
+
10
+ __all__ = ["compare_survival"]
11
+
12
+
13
def compare_survival(y, group_indicator, return_stats=False):
    """Compare survival functions of two or more groups using the log-rank test.

    The log-rank test is a non-parametric hypothesis test for comparing the
    survival functions of two or more independent groups. The null hypothesis is
    that the survival functions of the groups are identical. The alternative
    hypothesis is that at least one survival function differs from the others.

    The test statistic is approximately chi-squared distributed with :math:`K-1`
    degrees of freedom, where :math:`K` is the number of groups.

    See [1]_ for more details.

    Parameters
    ----------
    y : structured array, shape = (n_samples,)
        A structured array with two fields. The first field is a boolean
        where ``True`` indicates an event and ``False`` indicates right-censoring.
        The second field is a float with the time of event or time of censoring.
    group_indicator : array-like, shape = (n_samples,)
        Group membership of each sample.
    return_stats : bool, optional, default: False
        Whether to return a data frame with statistics for each group and the
        covariance matrix of the test statistic.

    Returns
    -------
    chisq : float
        The test statistic.
    pvalue : float
        The two-sided p-value for the test.
    stats : pandas.DataFrame, optional
        A DataFrame with summary statistics for each group. This includes the
        number of samples, observed number of events, expected number of events,
        and the test statistic. Only returned if ``return_stats`` is ``True``.
    covariance : ndarray, shape=(n_groups, n_groups), optional
        The covariance matrix of the test statistic. Only returned if
        ``return_stats`` is ``True``.

    References
    ----------
    .. [1] Fleming, T. R. and Harrington, D. P.
           A Class of Hypothesis Tests for One and Two Samples of Censored Survival Data.
           Communications In Statistics 10 (1981): 763-794.
    """

    event, time = check_array_survival(group_indicator, y)
    group_indicator = check_array(
        group_indicator,
        dtype="O",
        ensure_2d=False,
        estimator="compare_survival",
        input_name="group_indicator",
    )

    n_samples = time.shape[0]
    groups, group_counts = np.unique(group_indicator, return_counts=True)
    n_groups = groups.shape[0]
    if n_groups == 1:
        raise ValueError("At least two groups must be specified, but only one was provided.")

    # sort descending
    # Processing samples in order of decreasing time means that, once all
    # samples tied at time ti have been consumed, the first k samples are
    # exactly the risk set at ti (all samples with time >= ti).
    o = np.argsort(-time, kind="mergesort")
    x = group_indicator[o]
    event = event[o]
    time = time[o]

    at_risk = np.zeros(n_groups, dtype=int)      # group sizes in the current risk set
    observed = np.zeros(n_groups, dtype=int)     # observed events per group
    expected = np.zeros(n_groups, dtype=float)   # expected events per group under H0
    covar = np.zeros((n_groups, n_groups), dtype=float)

    covar_indices = np.diag_indices(n_groups)

    k = 0
    while k < n_samples:
        ti = time[k]
        total_events = 0
        # Consume all samples tied at time ti, growing the risk set and
        # counting the events observed at ti per group.
        while k < n_samples and ti == time[k]:
            idx = np.searchsorted(groups, x[k])
            if event[k]:
                observed[idx] += 1
                total_events += 1
            at_risk[idx] += 1
            k += 1

        if total_events != 0:
            # After the inner loop, k samples have been processed and all of
            # them have time >= ti, so k is the total size of the risk set.
            total_at_risk = k
            expected += at_risk * (total_events / total_at_risk)
            if total_at_risk > 1:
                # Variance/covariance contribution of the event counts at
                # this time point (see [1]).
                multiplier = total_events * (total_at_risk - total_events) / (total_at_risk * (total_at_risk - 1))
                temp = at_risk * multiplier
                covar[covar_indices] += temp
                covar -= np.outer(temp, at_risk) / total_at_risk

    # The per-group statistics sum to zero across groups, so the last group
    # is redundant and dropped before forming the quadratic form.
    df = n_groups - 1
    zz = observed[:df] - expected[:df]
    chisq = np.linalg.solve(covar[:df, :df], zz).dot(zz)
    pval = stats.chi2.sf(chisq, df)

    if return_stats:
        table = OrderedDict()
        table["counts"] = group_counts
        table["observed"] = observed
        table["expected"] = expected
        table["statistic"] = observed - expected
        table = pd.DataFrame.from_dict(table)
        table.index = pd.Index(groups, name="group")
        return chisq, pval, table, covar

    return chisq, pval
@@ -0,0 +1,12 @@
1
+ from .base import (
2
+ get_x_y, # noqa: F401
3
+ load_aids, # noqa: F401
4
+ load_arff_files_standardized, # noqa: F401
5
+ load_bmt, # noqa: F401
6
+ load_breast_cancer, # noqa: F401
7
+ load_cgvhd, # noqa: F401
8
+ load_flchain, # noqa: F401
9
+ load_gbsg2, # noqa: F401
10
+ load_veterans_lung_cancer, # noqa: F401
11
+ load_whas500, # noqa: F401
12
+ )