scikit-survival 0.24.1__cp311-cp311-win_amd64.whl → 0.25.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.25.0.dist-info/METADATA +185 -0
- scikit_survival-0.25.0.dist-info/RECORD +58 -0
- {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/WHEEL +1 -1
- sksurv/__init__.py +51 -6
- sksurv/base.py +12 -2
- sksurv/bintrees/_binarytrees.cp311-win_amd64.pyd +0 -0
- sksurv/column.py +33 -29
- sksurv/compare.py +22 -22
- sksurv/datasets/base.py +45 -20
- sksurv/docstrings.py +99 -0
- sksurv/ensemble/_coxph_loss.cp311-win_amd64.pyd +0 -0
- sksurv/ensemble/boosting.py +116 -168
- sksurv/ensemble/forest.py +94 -151
- sksurv/functions.py +29 -29
- sksurv/io/arffread.py +34 -3
- sksurv/io/arffwrite.py +38 -2
- sksurv/kernels/_clinical_kernel.cp311-win_amd64.pyd +0 -0
- sksurv/kernels/clinical.py +33 -13
- sksurv/linear_model/_coxnet.cp311-win_amd64.pyd +0 -0
- sksurv/linear_model/aft.py +14 -11
- sksurv/linear_model/coxnet.py +138 -89
- sksurv/linear_model/coxph.py +102 -83
- sksurv/meta/ensemble_selection.py +91 -9
- sksurv/meta/stacking.py +47 -26
- sksurv/metrics.py +257 -224
- sksurv/nonparametric.py +150 -81
- sksurv/preprocessing.py +55 -27
- sksurv/svm/_minlip.cp311-win_amd64.pyd +0 -0
- sksurv/svm/_prsvm.cp311-win_amd64.pyd +0 -0
- sksurv/svm/minlip.py +160 -79
- sksurv/svm/naive_survival_svm.py +63 -34
- sksurv/svm/survival_svm.py +103 -103
- sksurv/tree/_criterion.cp311-win_amd64.pyd +0 -0
- sksurv/tree/tree.py +170 -84
- sksurv/util.py +80 -26
- scikit_survival-0.24.1.dist-info/METADATA +0 -889
- scikit_survival-0.24.1.dist-info/RECORD +0 -57
- {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/licenses/COPYING +0 -0
- {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/top_level.txt +0 -0
sksurv/linear_model/coxph.py
CHANGED
|
@@ -21,6 +21,7 @@ from sklearn.utils._param_validation import Interval, StrOptions
|
|
|
21
21
|
from sklearn.utils.validation import check_array, check_is_fitted, validate_data
|
|
22
22
|
|
|
23
23
|
from ..base import SurvivalAnalysisMixin
|
|
24
|
+
from ..docstrings import append_cumulative_hazard_example, append_survival_function_example
|
|
24
25
|
from ..functions import StepFunction
|
|
25
26
|
from ..nonparametric import _compute_counts
|
|
26
27
|
from ..util import check_array_survival
|
|
@@ -29,17 +30,21 @@ __all__ = ["CoxPHSurvivalAnalysis"]
|
|
|
29
30
|
|
|
30
31
|
|
|
31
32
|
class BreslowEstimator:
|
|
32
|
-
"""Breslow's estimator
|
|
33
|
+
"""Breslow's non-parametric estimator for the cumulative baseline hazard.
|
|
34
|
+
|
|
35
|
+
This class is used by :class:`CoxPHSurvivalAnalysis` to estimate the
|
|
36
|
+
cumulative baseline hazard and baseline survival function after the
|
|
37
|
+
coefficients of the Cox model have been fitted.
|
|
33
38
|
|
|
34
39
|
Attributes
|
|
35
40
|
----------
|
|
36
41
|
cum_baseline_hazard_ : :class:`sksurv.functions.StepFunction`
|
|
37
|
-
|
|
42
|
+
Estimated cumulative baseline hazard function.
|
|
38
43
|
|
|
39
44
|
baseline_survival_ : :class:`sksurv.functions.StepFunction`
|
|
40
|
-
|
|
45
|
+
Estimated baseline survival function.
|
|
41
46
|
|
|
42
|
-
unique_times_ : ndarray
|
|
47
|
+
unique_times_ : ndarray, shape=(n_unique_times,)
|
|
43
48
|
Unique event times.
|
|
44
49
|
"""
|
|
45
50
|
|
|
@@ -126,7 +131,29 @@ class BreslowEstimator:
|
|
|
126
131
|
|
|
127
132
|
|
|
128
133
|
class CoxPHOptimizer:
|
|
129
|
-
"""
|
|
134
|
+
"""Helper class for fitting the Cox proportional hazards model.
|
|
135
|
+
|
|
136
|
+
This class computes the negative log-likelihood, its gradient, and the
|
|
137
|
+
Hessian matrix for the Cox model. It is used internally by
|
|
138
|
+
:class:`CoxPHSurvivalAnalysis`.
|
|
139
|
+
|
|
140
|
+
Parameters
|
|
141
|
+
----------
|
|
142
|
+
X : ndarray, shape=(n_samples, n_features)
|
|
143
|
+
The feature matrix.
|
|
144
|
+
|
|
145
|
+
event : ndarray, shape=(n_samples,)
|
|
146
|
+
The event indicator.
|
|
147
|
+
|
|
148
|
+
time : ndarray, shape=(n_samples,)
|
|
149
|
+
The event/censoring times.
|
|
150
|
+
|
|
151
|
+
alpha : ndarray, shape=(n_features,)
|
|
152
|
+
The regularization parameters.
|
|
153
|
+
|
|
154
|
+
ties : {'breslow', 'efron'}
|
|
155
|
+
The method to handle tied event times.
|
|
156
|
+
"""
|
|
130
157
|
|
|
131
158
|
def __init__(self, X, event, time, alpha, ties):
|
|
132
159
|
# sort descending
|
|
@@ -270,6 +297,17 @@ class CoxPHOptimizer:
|
|
|
270
297
|
|
|
271
298
|
|
|
272
299
|
class VerboseReporter:
|
|
300
|
+
"""Helper class to report optimization progress.
|
|
301
|
+
|
|
302
|
+
This class is used by :class:`CoxPHSurvivalAnalysis` to print
|
|
303
|
+
optimization progress depending on the verbosity level.
|
|
304
|
+
|
|
305
|
+
Parameters
|
|
306
|
+
----------
|
|
307
|
+
verbose : int
|
|
308
|
+
The verbosity level.
|
|
309
|
+
"""
|
|
310
|
+
|
|
273
311
|
def __init__(self, verbose):
|
|
274
312
|
self.verbose = verbose
|
|
275
313
|
|
|
@@ -293,20 +331,25 @@ class VerboseReporter:
|
|
|
293
331
|
|
|
294
332
|
|
|
295
333
|
class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
296
|
-
"""Cox proportional hazards model.
|
|
334
|
+
"""The Cox proportional hazards model, also known as Cox regression.
|
|
335
|
+
|
|
336
|
+
This model is a semi-parametric model that can be used to model the
|
|
337
|
+
relationship between a set of features and the time to an event.
|
|
338
|
+
The model is fitted by maximizing the partial likelihood
|
|
339
|
+
using Newton-Raphson optimization.
|
|
297
340
|
|
|
298
341
|
There are two possible choices for handling tied event times.
|
|
299
342
|
The default is Breslow's method, which considers each of the
|
|
300
343
|
events at a given time as distinct. Efron's method is more
|
|
301
344
|
accurate if there are a large number of ties. When the number
|
|
302
345
|
of ties is small, the estimated coefficients by Breslow's and
|
|
303
|
-
Efron's method are quite close.
|
|
346
|
+
Efron's method are quite close.
|
|
304
347
|
|
|
305
348
|
See [1]_, [2]_, [3]_ for further description.
|
|
306
349
|
|
|
307
350
|
Parameters
|
|
308
351
|
----------
|
|
309
|
-
alpha : float
|
|
352
|
+
alpha : float or ndarray, shape = (n_features,), optional, default: 0
|
|
310
353
|
Regularization parameter for ridge regression penalty.
|
|
311
354
|
If a single float, the same penalty is used for all features.
|
|
312
355
|
If an array, there must be one penalty for each feature.
|
|
@@ -318,7 +361,7 @@ class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
318
361
|
no tied event times all the methods are equivalent.
|
|
319
362
|
|
|
320
363
|
n_iter : int, optional, default: 100
|
|
321
|
-
|
|
364
|
+
The maximum number of iterations taken for the solver to converge.
|
|
322
365
|
|
|
323
366
|
tol : float, optional, default: 1e-9
|
|
324
367
|
Convergence criteria. Convergence is based on the negative log-likelihood::
|
|
@@ -332,7 +375,7 @@ class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
332
375
|
Attributes
|
|
333
376
|
----------
|
|
334
377
|
coef_ : ndarray, shape = (n_features,)
|
|
335
|
-
Coefficients of the model
|
|
378
|
+
Coefficients of the model.
|
|
336
379
|
|
|
337
380
|
cum_baseline_hazard_ : :class:`sksurv.functions.StepFunction`
|
|
338
381
|
Estimated baseline cumulative hazard function.
|
|
@@ -343,11 +386,11 @@ class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
343
386
|
n_features_in_ : int
|
|
344
387
|
Number of features seen during ``fit``.
|
|
345
388
|
|
|
346
|
-
feature_names_in_ : ndarray
|
|
389
|
+
feature_names_in_ : ndarray, shape = (`n_features_in_`,)
|
|
347
390
|
Names of features seen during ``fit``. Defined only when `X`
|
|
348
391
|
has feature names that are all strings.
|
|
349
392
|
|
|
350
|
-
unique_times_ :
|
|
393
|
+
unique_times_ : ndarray, shape = (n_unique_times,)
|
|
351
394
|
Unique time points.
|
|
352
395
|
|
|
353
396
|
See also
|
|
@@ -395,7 +438,7 @@ class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
395
438
|
return self._baseline_model.unique_times_
|
|
396
439
|
|
|
397
440
|
def fit(self, X, y):
|
|
398
|
-
"""
|
|
441
|
+
"""Fit the model to the given data.
|
|
399
442
|
|
|
400
443
|
Parameters
|
|
401
444
|
----------
|
|
@@ -403,9 +446,9 @@ class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
403
446
|
Data matrix
|
|
404
447
|
|
|
405
448
|
y : structured array, shape = (n_samples,)
|
|
406
|
-
A structured array
|
|
407
|
-
|
|
408
|
-
second field.
|
|
449
|
+
A structured array with two fields. The first field is a boolean
|
|
450
|
+
where ``True`` indicates an event and ``False`` indicates right-censoring.
|
|
451
|
+
The second field is a float with the time of event or time of censoring.
|
|
409
452
|
|
|
410
453
|
Returns
|
|
411
454
|
-------
|
|
@@ -482,6 +525,11 @@ class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
482
525
|
def predict(self, X):
|
|
483
526
|
"""Predict risk scores.
|
|
484
527
|
|
|
528
|
+
The risk score is the linear predictor of the model,
|
|
529
|
+
computed as the dot product of the input features `X` and the
|
|
530
|
+
estimated coefficients `coef_`. A higher score indicates a
|
|
531
|
+
higher risk of experiencing the event.
|
|
532
|
+
|
|
485
533
|
Parameters
|
|
486
534
|
----------
|
|
487
535
|
X : array-like, shape = (n_samples, n_features)
|
|
@@ -498,15 +546,16 @@ class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
498
546
|
|
|
499
547
|
return np.dot(X, self.coef_)
|
|
500
548
|
|
|
549
|
+
@append_cumulative_hazard_example(estimator_mod="linear_model", estimator_class="CoxPHSurvivalAnalysis")
|
|
501
550
|
def predict_cumulative_hazard_function(self, X, return_array=False):
|
|
502
|
-
"""Predict cumulative hazard function.
|
|
551
|
+
r"""Predict cumulative hazard function.
|
|
503
552
|
|
|
504
553
|
The cumulative hazard function for an individual
|
|
505
554
|
with feature vector :math:`x` is defined as
|
|
506
555
|
|
|
507
556
|
.. math::
|
|
508
557
|
|
|
509
|
-
H(t
|
|
558
|
+
H(t \mid x) = \exp(x^\top \beta) H_0(t) ,
|
|
510
559
|
|
|
511
560
|
where :math:`H_0(t)` is the cumulative baseline hazard function,
|
|
512
561
|
estimated by Breslow's estimator.
|
|
@@ -516,56 +565,42 @@ class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
516
565
|
X : array-like, shape = (n_samples, n_features)
|
|
517
566
|
Data matrix.
|
|
518
567
|
|
|
519
|
-
return_array :
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
568
|
+
return_array : bool, default: False
|
|
569
|
+
Whether to return a single array of cumulative hazard values
|
|
570
|
+
or a list of step functions.
|
|
571
|
+
|
|
572
|
+
If `False`, a list of :class:`sksurv.functions.StepFunction`
|
|
573
|
+
objects is returned.
|
|
574
|
+
|
|
575
|
+
If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
|
|
576
|
+
returned, where `n_unique_times` is the number of unique
|
|
577
|
+
event times in the training data. Each row represents the cumulative
|
|
578
|
+
hazard function of an individual evaluated at `unique_times_`.
|
|
523
579
|
|
|
524
580
|
Returns
|
|
525
581
|
-------
|
|
526
582
|
cum_hazard : ndarray
|
|
527
|
-
If `return_array` is
|
|
528
|
-
|
|
529
|
-
|
|
583
|
+
If `return_array` is `False`, an array of `n_samples`
|
|
584
|
+
:class:`sksurv.functions.StepFunction` instances is returned.
|
|
585
|
+
|
|
586
|
+
If `return_array` is `True`, a numeric array of shape
|
|
587
|
+
`(n_samples, n_unique_times)` is returned.
|
|
530
588
|
|
|
531
589
|
Examples
|
|
532
590
|
--------
|
|
533
|
-
>>> import matplotlib.pyplot as plt
|
|
534
|
-
>>> from sksurv.datasets import load_whas500
|
|
535
|
-
>>> from sksurv.linear_model import CoxPHSurvivalAnalysis
|
|
536
|
-
|
|
537
|
-
Load the data.
|
|
538
|
-
|
|
539
|
-
>>> X, y = load_whas500()
|
|
540
|
-
>>> X = X.astype(float)
|
|
541
|
-
|
|
542
|
-
Fit the model.
|
|
543
|
-
|
|
544
|
-
>>> estimator = CoxPHSurvivalAnalysis().fit(X, y)
|
|
545
|
-
|
|
546
|
-
Estimate the cumulative hazard function for the first 10 samples.
|
|
547
|
-
|
|
548
|
-
>>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:10])
|
|
549
|
-
|
|
550
|
-
Plot the estimated cumulative hazard functions.
|
|
551
|
-
|
|
552
|
-
>>> for fn in chf_funcs:
|
|
553
|
-
... plt.step(fn.x, fn(fn.x), where="post")
|
|
554
|
-
...
|
|
555
|
-
>>> plt.ylim(0, 1)
|
|
556
|
-
>>> plt.show()
|
|
557
591
|
"""
|
|
558
592
|
return self._predict_cumulative_hazard_function(self._baseline_model, self.predict(X), return_array)
|
|
559
593
|
|
|
594
|
+
@append_survival_function_example(estimator_mod="linear_model", estimator_class="CoxPHSurvivalAnalysis")
|
|
560
595
|
def predict_survival_function(self, X, return_array=False):
|
|
561
|
-
"""Predict survival function.
|
|
596
|
+
r"""Predict survival function.
|
|
562
597
|
|
|
563
598
|
The survival function for an individual
|
|
564
599
|
with feature vector :math:`x` is defined as
|
|
565
600
|
|
|
566
601
|
.. math::
|
|
567
602
|
|
|
568
|
-
S(t
|
|
603
|
+
S(t \mid x) = S_0(t)^{\exp(x^\top \beta)} ,
|
|
569
604
|
|
|
570
605
|
where :math:`S_0(t)` is the baseline survival function,
|
|
571
606
|
estimated by Breslow's estimator.
|
|
@@ -575,44 +610,28 @@ class CoxPHSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
|
|
|
575
610
|
X : array-like, shape = (n_samples, n_features)
|
|
576
611
|
Data matrix.
|
|
577
612
|
|
|
578
|
-
return_array :
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
613
|
+
return_array : bool, default: False
|
|
614
|
+
Whether to return a single array of survival probabilities
|
|
615
|
+
or a list of step functions.
|
|
616
|
+
|
|
617
|
+
If `False`, a list of :class:`sksurv.functions.StepFunction`
|
|
618
|
+
objects is returned.
|
|
619
|
+
|
|
620
|
+
If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
|
|
621
|
+
returned, where `n_unique_times` is the number of unique
|
|
622
|
+
event times in the training data. Each row represents the survival
|
|
623
|
+
function of an individual evaluated at `unique_times_`.
|
|
582
624
|
|
|
583
625
|
Returns
|
|
584
626
|
-------
|
|
585
627
|
survival : ndarray
|
|
586
|
-
If `return_array` is
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
628
|
+
If `return_array` is `False`, an array of `n_samples`
|
|
629
|
+
:class:`sksurv.functions.StepFunction` instances is returned.
|
|
630
|
+
|
|
631
|
+
If `return_array` is `True`, a numeric array of shape
|
|
632
|
+
`(n_samples, n_unique_times)` is returned.
|
|
590
633
|
|
|
591
634
|
Examples
|
|
592
635
|
--------
|
|
593
|
-
>>> import matplotlib.pyplot as plt
|
|
594
|
-
>>> from sksurv.datasets import load_whas500
|
|
595
|
-
>>> from sksurv.linear_model import CoxPHSurvivalAnalysis
|
|
596
|
-
|
|
597
|
-
Load the data.
|
|
598
|
-
|
|
599
|
-
>>> X, y = load_whas500()
|
|
600
|
-
>>> X = X.astype(float)
|
|
601
|
-
|
|
602
|
-
Fit the model.
|
|
603
|
-
|
|
604
|
-
>>> estimator = CoxPHSurvivalAnalysis().fit(X, y)
|
|
605
|
-
|
|
606
|
-
Estimate the survival function for the first 10 samples.
|
|
607
|
-
|
|
608
|
-
>>> surv_funcs = estimator.predict_survival_function(X.iloc[:10])
|
|
609
|
-
|
|
610
|
-
Plot the estimated survival functions.
|
|
611
|
-
|
|
612
|
-
>>> for fn in surv_funcs:
|
|
613
|
-
... plt.step(fn.x, fn(fn.x), where="post")
|
|
614
|
-
...
|
|
615
|
-
>>> plt.ylim(0, 1)
|
|
616
|
-
>>> plt.show()
|
|
617
636
|
"""
|
|
618
637
|
return self._predict_survival_function(self._baseline_model, self.predict(X), return_array)
|
|
@@ -37,18 +37,55 @@ def _corr_kendalltau(X):
|
|
|
37
37
|
|
|
38
38
|
|
|
39
39
|
class EnsembleAverage(BaseEstimator):
|
|
40
|
+
"""A meta-estimator that averages the predictions of base estimators.
|
|
41
|
+
|
|
42
|
+
This estimator is for internal use by :class:`BaseEnsembleSelection`.
|
|
43
|
+
It takes a list of estimators that have already been fitted and
|
|
44
|
+
averages their predictions.
|
|
45
|
+
|
|
46
|
+
Parameters
|
|
47
|
+
----------
|
|
48
|
+
base_estimators : list of estimators
|
|
49
|
+
The base estimators to average. The estimators must be fitted.
|
|
50
|
+
|
|
51
|
+
name : str, optional, default: None
|
|
52
|
+
The name of the ensemble.
|
|
53
|
+
"""
|
|
54
|
+
|
|
40
55
|
def __init__(self, base_estimators, name=None):
|
|
41
56
|
self.base_estimators = base_estimators
|
|
42
57
|
self.name = name
|
|
43
58
|
assert not hasattr(self.base_estimators[0], "classes_"), "base estimator cannot be a classifier"
|
|
44
59
|
|
|
45
60
|
def get_base_params(self):
|
|
61
|
+
"""Get parameters for this estimator's first base estimator.
|
|
62
|
+
|
|
63
|
+
Returns
|
|
64
|
+
-------
|
|
65
|
+
params : dict
|
|
66
|
+
Parameter names mapped to their values.
|
|
67
|
+
"""
|
|
46
68
|
return self.base_estimators[0].get_params()
|
|
47
69
|
|
|
48
70
|
def fit(self, X, y=None, **kwargs): # pragma: no cover; # pylint: disable=unused-argument
|
|
49
71
|
return self
|
|
50
72
|
|
|
51
73
|
def predict(self, X):
|
|
74
|
+
"""Predict using the ensemble of estimators.
|
|
75
|
+
|
|
76
|
+
The prediction is the average of the predictions of all base
|
|
77
|
+
estimators.
|
|
78
|
+
|
|
79
|
+
Parameters
|
|
80
|
+
----------
|
|
81
|
+
X : array-like, shape = (n_samples, n_features)
|
|
82
|
+
Data to predict on.
|
|
83
|
+
|
|
84
|
+
Returns
|
|
85
|
+
-------
|
|
86
|
+
y_pred : ndarray, shape = (n_samples,)
|
|
87
|
+
The predicted values.
|
|
88
|
+
"""
|
|
52
89
|
prediction = np.zeros(X.shape[0])
|
|
53
90
|
for est in self.base_estimators:
|
|
54
91
|
prediction += est.predict(X)
|
|
@@ -57,18 +94,59 @@ class EnsembleAverage(BaseEstimator):
|
|
|
57
94
|
|
|
58
95
|
|
|
59
96
|
class MeanEstimator(BaseEstimator):
|
|
97
|
+
"""A meta-estimator that averages predictions.
|
|
98
|
+
|
|
99
|
+
This estimator computes the mean of an array along its last axis.
|
|
100
|
+
It is intended to be used as a ``meta_estimator`` in an ensemble model,
|
|
101
|
+
where it averages the predictions of the base estimators.
|
|
102
|
+
"""
|
|
103
|
+
|
|
60
104
|
def fit(self, X, y=None, **kwargs): # pragma: no cover; # pylint: disable=unused-argument
|
|
61
105
|
return self
|
|
62
106
|
|
|
63
107
|
def predict(self, X): # pylint: disable=no-self-use
|
|
108
|
+
"""Return the mean of an array along its last axis.
|
|
109
|
+
|
|
110
|
+
Parameters
|
|
111
|
+
----------
|
|
112
|
+
X : array-like, shape = (n_samples, n_estimators)
|
|
113
|
+
The predictions of base estimators.
|
|
114
|
+
|
|
115
|
+
Returns
|
|
116
|
+
-------
|
|
117
|
+
y_pred : ndarray, shape = (n_samples,)
|
|
118
|
+
The averaged predictions.
|
|
119
|
+
"""
|
|
64
120
|
return X.mean(axis=X.ndim - 1)
|
|
65
121
|
|
|
66
122
|
|
|
67
123
|
class MeanRankEstimator(BaseEstimator):
|
|
124
|
+
"""A meta-estimator that averages the ranks of predictions of base estimators.
|
|
125
|
+
|
|
126
|
+
This estimator first converts the predictions of each base estimator
|
|
127
|
+
into ranks and then averages the ranks. It is intended to be used as
|
|
128
|
+
a ``meta_estimator`` in an ensemble model.
|
|
129
|
+
"""
|
|
130
|
+
|
|
68
131
|
def fit(self, X, y=None, **kwargs): # pragma: no cover; # pylint: disable=unused-argument
|
|
69
132
|
return self
|
|
70
133
|
|
|
71
134
|
def predict(self, X): # pylint: disable=no-self-use
|
|
135
|
+
"""Return the mean of ranks.
|
|
136
|
+
|
|
137
|
+
The predictions of each base estimator are first converted into
|
|
138
|
+
ranks and then averaged.
|
|
139
|
+
|
|
140
|
+
Parameters
|
|
141
|
+
----------
|
|
142
|
+
X : array-like, shape = (n_samples, n_estimators)
|
|
143
|
+
The predictions of base estimators.
|
|
144
|
+
|
|
145
|
+
Returns
|
|
146
|
+
-------
|
|
147
|
+
y_pred : ndarray, shape = (n_samples,)
|
|
148
|
+
The averaged ranks.
|
|
149
|
+
"""
|
|
72
150
|
# convert predictions of individual models into ranks
|
|
73
151
|
ranks = np.apply_along_axis(rankdata, 0, X)
|
|
74
152
|
# average predicted ranks
|
|
@@ -134,6 +212,7 @@ class BaseEnsembleSelection(Stacking):
|
|
|
134
212
|
self._extra_params.extend(["scorer", "n_estimators", "min_score", "min_correlation", "cv", "n_jobs", "verbose"])
|
|
135
213
|
|
|
136
214
|
def __len__(self):
|
|
215
|
+
"""Return the number of fitted models."""
|
|
137
216
|
if hasattr(self, "fitted_models_"):
|
|
138
217
|
return len(self.fitted_models_)
|
|
139
218
|
return 0
|
|
@@ -300,16 +379,19 @@ class BaseEnsembleSelection(Stacking):
|
|
|
300
379
|
raise NotImplementedError()
|
|
301
380
|
|
|
302
381
|
def fit(self, X, y=None, **fit_params):
|
|
303
|
-
"""Fit ensemble of models
|
|
382
|
+
"""Fit ensemble of models.
|
|
304
383
|
|
|
305
384
|
Parameters
|
|
306
385
|
----------
|
|
307
386
|
X : array-like, shape = (n_samples, n_features)
|
|
308
387
|
Training data.
|
|
309
388
|
|
|
310
|
-
y : array-like, optional
|
|
389
|
+
y : array-like, shape = (n_samples,), optional
|
|
311
390
|
Target data if base estimators are supervised.
|
|
312
391
|
|
|
392
|
+
**fit_params : dict
|
|
393
|
+
Parameters passed to the ``fit`` method of each base estimator.
|
|
394
|
+
|
|
313
395
|
Returns
|
|
314
396
|
-------
|
|
315
397
|
self
|
|
@@ -347,7 +429,7 @@ class EnsembleSelection(BaseEnsembleSelection):
|
|
|
347
429
|
If a float, the percentage of estimators in the ensemble to retain, if an int the
|
|
348
430
|
absolute number of estimators to retain.
|
|
349
431
|
|
|
350
|
-
min_score : float, optional, default: 0.
|
|
432
|
+
min_score : float, optional, default: 0.2
|
|
351
433
|
Threshold for pruning estimators based on scoring metric. After `fit`, only estimators
|
|
352
434
|
with a score above `min_score` are retained.
|
|
353
435
|
|
|
@@ -379,7 +461,7 @@ class EnsembleSelection(BaseEnsembleSelection):
|
|
|
379
461
|
n_features_in_ : int
|
|
380
462
|
Number of features seen during ``fit``.
|
|
381
463
|
|
|
382
|
-
feature_names_in_ : ndarray
|
|
464
|
+
feature_names_in_ : ndarray, shape = (`n_features_in_`,)
|
|
383
465
|
Names of features seen during ``fit``. Defined only when `X`
|
|
384
466
|
has feature names that are all strings.
|
|
385
467
|
|
|
@@ -473,14 +555,14 @@ class EnsembleSelection(BaseEnsembleSelection):
|
|
|
473
555
|
|
|
474
556
|
|
|
475
557
|
class EnsembleSelectionRegressor(BaseEnsembleSelection):
|
|
476
|
-
"""Ensemble selection for regression that accounts for the accuracy and correlation of errors.
|
|
558
|
+
r"""Ensemble selection for regression that accounts for the accuracy and correlation of errors.
|
|
477
559
|
|
|
478
560
|
The ensemble is pruned during training according to estimators' accuracy and the correlation
|
|
479
561
|
between prediction errors per sample. The accuracy of the *i*-th estimator is defined as
|
|
480
|
-
:math
|
|
562
|
+
:math:`\frac{ \min_{i=1,\ldots, n}(error_i) }{ error_i }`.
|
|
481
563
|
In addition to the accuracy, models are selected based on the correlation between residuals
|
|
482
564
|
of different models (diversity). The diversity of the *i*-th estimator is defined as
|
|
483
|
-
:math
|
|
565
|
+
:math:`\frac{n-count}{n}`, where *count* is the number of estimators for whom the correlation
|
|
484
566
|
of residuals exceeds `min_correlation`.
|
|
485
567
|
|
|
486
568
|
The hillclimbing is based on cross-validation to avoid having to create a separate validation set.
|
|
@@ -504,7 +586,7 @@ class EnsembleSelectionRegressor(BaseEnsembleSelection):
|
|
|
504
586
|
|
|
505
587
|
min_score : float, optional, default: 0.66
|
|
506
588
|
Threshold for pruning estimators based on scoring metric. After `fit`, only estimators
|
|
507
|
-
with
|
|
589
|
+
with an accuracy above `min_score` are retained.
|
|
508
590
|
|
|
509
591
|
min_correlation : float, optional, default: 0.6
|
|
510
592
|
Threshold for Pearson's correlation coefficient that determines when residuals of
|
|
@@ -534,7 +616,7 @@ class EnsembleSelectionRegressor(BaseEnsembleSelection):
|
|
|
534
616
|
n_features_in_ : int
|
|
535
617
|
Number of features seen during ``fit``.
|
|
536
618
|
|
|
537
|
-
feature_names_in_ : ndarray
|
|
619
|
+
feature_names_in_ : ndarray, shape = (`n_features_in_`,)
|
|
538
620
|
Names of features seen during ``fit``. Defined only when `X`
|
|
539
621
|
has feature names that are all strings.
|
|
540
622
|
|