scikit-survival 0.24.1__cp312-cp312-win_amd64.whl → 0.26.0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.26.0.dist-info/METADATA +185 -0
- scikit_survival-0.26.0.dist-info/RECORD +58 -0
- {scikit_survival-0.24.1.dist-info → scikit_survival-0.26.0.dist-info}/WHEEL +1 -1
- sksurv/__init__.py +51 -6
- sksurv/base.py +12 -2
- sksurv/bintrees/_binarytrees.cp312-win_amd64.pyd +0 -0
- sksurv/column.py +38 -35
- sksurv/compare.py +23 -23
- sksurv/datasets/base.py +52 -27
- sksurv/docstrings.py +99 -0
- sksurv/ensemble/_coxph_loss.cp312-win_amd64.pyd +0 -0
- sksurv/ensemble/boosting.py +116 -168
- sksurv/ensemble/forest.py +94 -151
- sksurv/functions.py +29 -29
- sksurv/io/arffread.py +37 -4
- sksurv/io/arffwrite.py +41 -5
- sksurv/kernels/_clinical_kernel.cp312-win_amd64.pyd +0 -0
- sksurv/kernels/clinical.py +36 -16
- sksurv/linear_model/_coxnet.cp312-win_amd64.pyd +0 -0
- sksurv/linear_model/aft.py +14 -11
- sksurv/linear_model/coxnet.py +138 -89
- sksurv/linear_model/coxph.py +102 -83
- sksurv/meta/ensemble_selection.py +91 -9
- sksurv/meta/stacking.py +47 -26
- sksurv/metrics.py +257 -224
- sksurv/nonparametric.py +150 -81
- sksurv/preprocessing.py +74 -34
- sksurv/svm/_minlip.cp312-win_amd64.pyd +0 -0
- sksurv/svm/_prsvm.cp312-win_amd64.pyd +0 -0
- sksurv/svm/minlip.py +171 -85
- sksurv/svm/naive_survival_svm.py +63 -34
- sksurv/svm/survival_svm.py +103 -103
- sksurv/testing.py +47 -0
- sksurv/tree/_criterion.cp312-win_amd64.pyd +0 -0
- sksurv/tree/tree.py +170 -84
- sksurv/util.py +85 -30
- scikit_survival-0.24.1.dist-info/METADATA +0 -889
- scikit_survival-0.24.1.dist-info/RECORD +0 -57
- {scikit_survival-0.24.1.dist-info → scikit_survival-0.26.0.dist-info}/licenses/COPYING +0 -0
- {scikit_survival-0.24.1.dist-info → scikit_survival-0.26.0.dist-info}/top_level.txt +0 -0
sksurv/ensemble/forest.py
CHANGED
|
@@ -18,6 +18,7 @@ from sklearn.utils._tags import get_tags
|
|
|
18
18
|
from sklearn.utils.validation import check_is_fitted, check_random_state, validate_data
|
|
19
19
|
|
|
20
20
|
from ..base import SurvivalAnalysisMixin
|
|
21
|
+
from ..docstrings import append_cumulative_hazard_example, append_survival_function_example
|
|
21
22
|
from ..metrics import concordance_index_censored
|
|
22
23
|
from ..tree import ExtraSurvivalTree, SurvivalTree
|
|
23
24
|
from ..tree._criterion import get_unique_times
|
|
@@ -96,9 +97,9 @@ class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
|
|
|
96
97
|
Data matrix
|
|
97
98
|
|
|
98
99
|
y : structured array, shape = (n_samples,)
|
|
99
|
-
A structured array
|
|
100
|
-
|
|
101
|
-
second field.
|
|
100
|
+
A structured array with two fields. The first field is a boolean
|
|
101
|
+
where ``True`` indicates an event and ``False`` indicates right-censoring.
|
|
102
|
+
The second field is a float with the time of event or time of censoring.
|
|
102
103
|
|
|
103
104
|
Returns
|
|
104
105
|
-------
|
|
@@ -266,15 +267,15 @@ class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
|
|
|
266
267
|
return y_hat
|
|
267
268
|
|
|
268
269
|
def predict(self, X):
|
|
269
|
-
"""Predict risk score.
|
|
270
|
+
r"""Predict risk score.
|
|
270
271
|
|
|
271
272
|
The ensemble risk score is the total number of events,
|
|
272
273
|
which can be estimated by the sum of the estimated
|
|
273
|
-
ensemble cumulative hazard function :math
|
|
274
|
+
ensemble cumulative hazard function :math:`\hat{H}_e`.
|
|
274
275
|
|
|
275
276
|
.. math::
|
|
276
277
|
|
|
277
|
-
|
|
278
|
+
\sum_{j=1}^{n} \hat{H}_e(T_{j} \mid x) ,
|
|
278
279
|
|
|
279
280
|
where :math:`n` denotes the total number of distinct
|
|
280
281
|
event times in the training data.
|
|
@@ -322,7 +323,7 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
322
323
|
|
|
323
324
|
Parameters
|
|
324
325
|
----------
|
|
325
|
-
n_estimators :
|
|
326
|
+
n_estimators : int, optional, default: 100
|
|
326
327
|
The number of trees in the forest.
|
|
327
328
|
|
|
328
329
|
max_depth : int or None, optional, default: None
|
|
@@ -355,7 +356,7 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
355
356
|
the input samples) required to be at a leaf node. Samples have
|
|
356
357
|
equal weight when sample_weight is not provided.
|
|
357
358
|
|
|
358
|
-
max_features : int, float,
|
|
359
|
+
max_features : int, float, {'sqrt', 'log2'} or None, optional, default: 'sqrt'
|
|
359
360
|
The number of features to consider when looking for the best split:
|
|
360
361
|
|
|
361
362
|
- If int, then consider `max_features` features at each split.
|
|
@@ -375,11 +376,11 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
375
376
|
Best nodes are defined as relative reduction in impurity.
|
|
376
377
|
If None then unlimited number of leaf nodes.
|
|
377
378
|
|
|
378
|
-
bootstrap :
|
|
379
|
+
bootstrap : bool, optional, default: True
|
|
379
380
|
Whether bootstrap samples are used when building trees. If False, the
|
|
380
381
|
whole dataset is used to build each tree.
|
|
381
382
|
|
|
382
|
-
oob_score : bool, default: False
|
|
383
|
+
oob_score : bool, optional, default: False
|
|
383
384
|
Whether to use out-of-bag samples to estimate
|
|
384
385
|
the generalization accuracy.
|
|
385
386
|
|
|
@@ -412,22 +413,22 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
412
413
|
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
|
|
413
414
|
`max_samples` should be in the interval `(0.0, 1.0]`.
|
|
414
415
|
|
|
415
|
-
low_memory :
|
|
416
|
-
If set,
|
|
417
|
-
and
|
|
416
|
+
low_memory : bool, optional, default: False
|
|
417
|
+
If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
|
|
418
|
+
and :meth:`predict_survival_function` are not implemented.
|
|
418
419
|
|
|
419
420
|
Attributes
|
|
420
421
|
----------
|
|
421
422
|
estimators_ : list of SurvivalTree instances
|
|
422
423
|
The collection of fitted sub-estimators.
|
|
423
424
|
|
|
424
|
-
unique_times_ :
|
|
425
|
+
unique_times_ : ndarray, shape = (n_unique_times,)
|
|
425
426
|
Unique time points.
|
|
426
427
|
|
|
427
428
|
n_features_in_ : int
|
|
428
429
|
Number of features seen during ``fit``.
|
|
429
430
|
|
|
430
|
-
feature_names_in_ : ndarray
|
|
431
|
+
feature_names_in_ : ndarray, shape = (`n_features_in_`,)
|
|
431
432
|
Names of features seen during ``fit``. Defined only when `X`
|
|
432
433
|
has feature names that are all strings.
|
|
433
434
|
|
|
@@ -527,6 +528,7 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
527
528
|
self.max_leaf_nodes = max_leaf_nodes
|
|
528
529
|
self.low_memory = low_memory
|
|
529
530
|
|
|
531
|
+
@append_cumulative_hazard_example(estimator_mod="ensemble", estimator_class="RandomSurvivalForest")
|
|
530
532
|
def predict_cumulative_hazard_function(self, X, return_array=False):
|
|
531
533
|
"""Predict cumulative hazard function.
|
|
532
534
|
|
|
@@ -544,47 +546,33 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
544
546
|
X : array-like, shape = (n_samples, n_features)
|
|
545
547
|
Data matrix.
|
|
546
548
|
|
|
547
|
-
return_array :
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
549
|
+
return_array : bool, default: False
|
|
550
|
+
Whether to return a single array of cumulative hazard values
|
|
551
|
+
or a list of step functions.
|
|
552
|
+
|
|
553
|
+
If `False`, a list of :class:`sksurv.functions.StepFunction`
|
|
554
|
+
objects is returned.
|
|
555
|
+
|
|
556
|
+
If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
|
|
557
|
+
returned, where `n_unique_times` is the number of unique
|
|
558
|
+
event times in the training data. Each row represents the cumulative
|
|
559
|
+
hazard function of an individual evaluated at `unique_times_`.
|
|
551
560
|
|
|
552
561
|
Returns
|
|
553
562
|
-------
|
|
554
563
|
cum_hazard : ndarray
|
|
555
|
-
If `return_array` is
|
|
556
|
-
|
|
557
|
-
|
|
564
|
+
If `return_array` is `False`, an array of `n_samples`
|
|
565
|
+
:class:`sksurv.functions.StepFunction` instances is returned.
|
|
566
|
+
|
|
567
|
+
If `return_array` is `True`, a numeric array of shape
|
|
568
|
+
`(n_samples, n_unique_times)` is returned.
|
|
558
569
|
|
|
559
570
|
Examples
|
|
560
571
|
--------
|
|
561
|
-
>>> import matplotlib.pyplot as plt
|
|
562
|
-
>>> from sksurv.datasets import load_whas500
|
|
563
|
-
>>> from sksurv.ensemble import RandomSurvivalForest
|
|
564
|
-
|
|
565
|
-
Load and prepare the data.
|
|
566
|
-
|
|
567
|
-
>>> X, y = load_whas500()
|
|
568
|
-
>>> X = X.astype(float)
|
|
569
|
-
|
|
570
|
-
Fit the model.
|
|
571
|
-
|
|
572
|
-
>>> estimator = RandomSurvivalForest().fit(X, y)
|
|
573
|
-
|
|
574
|
-
Estimate the cumulative hazard function for the first 5 samples.
|
|
575
|
-
|
|
576
|
-
>>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])
|
|
577
|
-
|
|
578
|
-
Plot the estimated cumulative hazard functions.
|
|
579
|
-
|
|
580
|
-
>>> for fn in chf_funcs:
|
|
581
|
-
... plt.step(fn.x, fn(fn.x), where="post")
|
|
582
|
-
...
|
|
583
|
-
>>> plt.ylim(0, 1)
|
|
584
|
-
>>> plt.show()
|
|
585
572
|
"""
|
|
586
573
|
return super().predict_cumulative_hazard_function(X, return_array)
|
|
587
574
|
|
|
575
|
+
@append_survival_function_example(estimator_mod="ensemble", estimator_class="RandomSurvivalForest")
|
|
588
576
|
def predict_survival_function(self, X, return_array=False):
|
|
589
577
|
"""Predict survival function.
|
|
590
578
|
|
|
@@ -602,45 +590,29 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
602
590
|
X : array-like, shape = (n_samples, n_features)
|
|
603
591
|
Data matrix.
|
|
604
592
|
|
|
605
|
-
return_array :
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
593
|
+
return_array : bool, default: False
|
|
594
|
+
Whether to return a single array of survival probabilities
|
|
595
|
+
or a list of step functions.
|
|
596
|
+
|
|
597
|
+
If `False`, a list of :class:`sksurv.functions.StepFunction`
|
|
598
|
+
objects is returned.
|
|
599
|
+
|
|
600
|
+
If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
|
|
601
|
+
returned, where `n_unique_times` is the number of unique
|
|
602
|
+
event times in the training data. Each row represents the survival
|
|
603
|
+
function of an individual evaluated at `unique_times_`.
|
|
609
604
|
|
|
610
605
|
Returns
|
|
611
606
|
-------
|
|
612
607
|
survival : ndarray
|
|
613
|
-
If `return_array` is
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
608
|
+
If `return_array` is `False`, an array of `n_samples`
|
|
609
|
+
:class:`sksurv.functions.StepFunction` instances is returned.
|
|
610
|
+
|
|
611
|
+
If `return_array` is `True`, a numeric array of shape
|
|
612
|
+
`(n_samples, n_unique_times)` is returned.
|
|
617
613
|
|
|
618
614
|
Examples
|
|
619
615
|
--------
|
|
620
|
-
>>> import matplotlib.pyplot as plt
|
|
621
|
-
>>> from sksurv.datasets import load_whas500
|
|
622
|
-
>>> from sksurv.ensemble import RandomSurvivalForest
|
|
623
|
-
|
|
624
|
-
Load and prepare the data.
|
|
625
|
-
|
|
626
|
-
>>> X, y = load_whas500()
|
|
627
|
-
>>> X = X.astype(float)
|
|
628
|
-
|
|
629
|
-
Fit the model.
|
|
630
|
-
|
|
631
|
-
>>> estimator = RandomSurvivalForest().fit(X, y)
|
|
632
|
-
|
|
633
|
-
Estimate the survival function for the first 5 samples.
|
|
634
|
-
|
|
635
|
-
>>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])
|
|
636
|
-
|
|
637
|
-
Plot the estimated survival functions.
|
|
638
|
-
|
|
639
|
-
>>> for fn in surv_funcs:
|
|
640
|
-
... plt.step(fn.x, fn(fn.x), where="post")
|
|
641
|
-
...
|
|
642
|
-
>>> plt.ylim(0, 1)
|
|
643
|
-
>>> plt.show()
|
|
644
616
|
"""
|
|
645
617
|
return super().predict_survival_function(X, return_array)
|
|
646
618
|
|
|
@@ -667,7 +639,7 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
667
639
|
|
|
668
640
|
Parameters
|
|
669
641
|
----------
|
|
670
|
-
n_estimators :
|
|
642
|
+
n_estimators : int, optional, default: 100
|
|
671
643
|
The number of trees in the forest.
|
|
672
644
|
|
|
673
645
|
max_depth : int or None, optional, default: None
|
|
@@ -700,7 +672,7 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
700
672
|
the input samples) required to be at a leaf node. Samples have
|
|
701
673
|
equal weight when sample_weight is not provided.
|
|
702
674
|
|
|
703
|
-
max_features : int, float,
|
|
675
|
+
max_features : int, float, {'sqrt', 'log2'} or None, optional, default: 'sqrt'
|
|
704
676
|
The number of features to consider when looking for the best split:
|
|
705
677
|
|
|
706
678
|
- If int, then consider `max_features` features at each split.
|
|
@@ -720,11 +692,11 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
720
692
|
Best nodes are defined as relative reduction in impurity.
|
|
721
693
|
If None then unlimited number of leaf nodes.
|
|
722
694
|
|
|
723
|
-
bootstrap :
|
|
695
|
+
bootstrap : bool, optional, default: True
|
|
724
696
|
Whether bootstrap samples are used when building trees. If False, the
|
|
725
697
|
whole dataset is used to build each tree.
|
|
726
698
|
|
|
727
|
-
oob_score : bool, default: False
|
|
699
|
+
oob_score : bool, optional, default: False
|
|
728
700
|
Whether to use out-of-bag samples to estimate
|
|
729
701
|
the generalization accuracy.
|
|
730
702
|
|
|
@@ -757,22 +729,22 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
757
729
|
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
|
|
758
730
|
`max_samples` should be in the interval `(0.0, 1.0]`.
|
|
759
731
|
|
|
760
|
-
low_memory :
|
|
761
|
-
If set,
|
|
762
|
-
and
|
|
732
|
+
low_memory : bool, optional, default: False
|
|
733
|
+
If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
|
|
734
|
+
and :meth:`predict_survival_function` are not implemented.
|
|
763
735
|
|
|
764
736
|
Attributes
|
|
765
737
|
----------
|
|
766
738
|
estimators_ : list of SurvivalTree instances
|
|
767
739
|
The collection of fitted sub-estimators.
|
|
768
740
|
|
|
769
|
-
unique_times_ :
|
|
741
|
+
unique_times_ : ndarray, shape = (n_unique_times,)
|
|
770
742
|
Unique time points.
|
|
771
743
|
|
|
772
744
|
n_features_in_ : int
|
|
773
|
-
|
|
745
|
+
Number of features seen during ``fit``.
|
|
774
746
|
|
|
775
|
-
feature_names_in_ : ndarray
|
|
747
|
+
feature_names_in_ : ndarray, shape = (`n_features_in_`,)
|
|
776
748
|
Names of features seen during ``fit``. Defined only when `X`
|
|
777
749
|
has feature names that are all strings.
|
|
778
750
|
|
|
@@ -841,6 +813,7 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
841
813
|
self.max_leaf_nodes = max_leaf_nodes
|
|
842
814
|
self.low_memory = low_memory
|
|
843
815
|
|
|
816
|
+
@append_cumulative_hazard_example(estimator_mod="ensemble", estimator_class="ExtraSurvivalTrees")
|
|
844
817
|
def predict_cumulative_hazard_function(self, X, return_array=False):
|
|
845
818
|
"""Predict cumulative hazard function.
|
|
846
819
|
|
|
@@ -858,47 +831,33 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
858
831
|
X : array-like, shape = (n_samples, n_features)
|
|
859
832
|
Data matrix.
|
|
860
833
|
|
|
861
|
-
return_array :
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
834
|
+
return_array : bool, default: False
|
|
835
|
+
Whether to return a single array of cumulative hazard values
|
|
836
|
+
or a list of step functions.
|
|
837
|
+
|
|
838
|
+
If `False`, a list of :class:`sksurv.functions.StepFunction`
|
|
839
|
+
objects is returned.
|
|
840
|
+
|
|
841
|
+
If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
|
|
842
|
+
returned, where `n_unique_times` is the number of unique
|
|
843
|
+
event times in the training data. Each row represents the cumulative
|
|
844
|
+
hazard function of an individual evaluated at `unique_times_`.
|
|
865
845
|
|
|
866
846
|
Returns
|
|
867
847
|
-------
|
|
868
848
|
cum_hazard : ndarray
|
|
869
|
-
If `return_array` is
|
|
870
|
-
|
|
871
|
-
|
|
849
|
+
If `return_array` is `False`, an array of `n_samples`
|
|
850
|
+
:class:`sksurv.functions.StepFunction` instances is returned.
|
|
851
|
+
|
|
852
|
+
If `return_array` is `True`, a numeric array of shape
|
|
853
|
+
`(n_samples, n_unique_times)` is returned.
|
|
872
854
|
|
|
873
855
|
Examples
|
|
874
856
|
--------
|
|
875
|
-
>>> import matplotlib.pyplot as plt
|
|
876
|
-
>>> from sksurv.datasets import load_whas500
|
|
877
|
-
>>> from sksurv.ensemble import ExtraSurvivalTrees
|
|
878
|
-
|
|
879
|
-
Load and prepare the data.
|
|
880
|
-
|
|
881
|
-
>>> X, y = load_whas500()
|
|
882
|
-
>>> X = X.astype(float)
|
|
883
|
-
|
|
884
|
-
Fit the model.
|
|
885
|
-
|
|
886
|
-
>>> estimator = ExtraSurvivalTrees().fit(X, y)
|
|
887
|
-
|
|
888
|
-
Estimate the cumulative hazard function for the first 5 samples.
|
|
889
|
-
|
|
890
|
-
>>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])
|
|
891
|
-
|
|
892
|
-
Plot the estimated cumulative hazard functions.
|
|
893
|
-
|
|
894
|
-
>>> for fn in chf_funcs:
|
|
895
|
-
... plt.step(fn.x, fn(fn.x), where="post")
|
|
896
|
-
...
|
|
897
|
-
>>> plt.ylim(0, 1)
|
|
898
|
-
>>> plt.show()
|
|
899
857
|
"""
|
|
900
858
|
return super().predict_cumulative_hazard_function(X, return_array)
|
|
901
859
|
|
|
860
|
+
@append_survival_function_example(estimator_mod="ensemble", estimator_class="ExtraSurvivalTrees")
|
|
902
861
|
def predict_survival_function(self, X, return_array=False):
|
|
903
862
|
"""Predict survival function.
|
|
904
863
|
|
|
@@ -916,44 +875,28 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
|
|
|
916
875
|
X : array-like, shape = (n_samples, n_features)
|
|
917
876
|
Data matrix.
|
|
918
877
|
|
|
919
|
-
return_array :
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
878
|
+
return_array : bool, default: False
|
|
879
|
+
Whether to return a single array of survival probabilities
|
|
880
|
+
or a list of step functions.
|
|
881
|
+
|
|
882
|
+
If `False`, a list of :class:`sksurv.functions.StepFunction`
|
|
883
|
+
objects is returned.
|
|
884
|
+
|
|
885
|
+
If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
|
|
886
|
+
returned, where `n_unique_times` is the number of unique
|
|
887
|
+
event times in the training data. Each row represents the survival
|
|
888
|
+
function of an individual evaluated at `unique_times_`.
|
|
923
889
|
|
|
924
890
|
Returns
|
|
925
891
|
-------
|
|
926
892
|
survival : ndarray
|
|
927
|
-
If `return_array` is
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
893
|
+
If `return_array` is `False`, an array of `n_samples`
|
|
894
|
+
:class:`sksurv.functions.StepFunction` instances is returned.
|
|
895
|
+
|
|
896
|
+
If `return_array` is `True`, a numeric array of shape
|
|
897
|
+
`(n_samples, n_unique_times)` is returned.
|
|
931
898
|
|
|
932
899
|
Examples
|
|
933
900
|
--------
|
|
934
|
-
>>> import matplotlib.pyplot as plt
|
|
935
|
-
>>> from sksurv.datasets import load_whas500
|
|
936
|
-
>>> from sksurv.ensemble import ExtraSurvivalTrees
|
|
937
|
-
|
|
938
|
-
Load and prepare the data.
|
|
939
|
-
|
|
940
|
-
>>> X, y = load_whas500()
|
|
941
|
-
>>> X = X.astype(float)
|
|
942
|
-
|
|
943
|
-
Fit the model.
|
|
944
|
-
|
|
945
|
-
>>> estimator = ExtraSurvivalTrees().fit(X, y)
|
|
946
|
-
|
|
947
|
-
Estimate the survival function for the first 5 samples.
|
|
948
|
-
|
|
949
|
-
>>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])
|
|
950
|
-
|
|
951
|
-
Plot the estimated survival functions.
|
|
952
|
-
|
|
953
|
-
>>> for fn in surv_funcs:
|
|
954
|
-
... plt.step(fn.x, fn(fn.x), where="post")
|
|
955
|
-
...
|
|
956
|
-
>>> plt.ylim(0, 1)
|
|
957
|
-
>>> plt.show()
|
|
958
901
|
"""
|
|
959
902
|
return super().predict_survival_function(X, return_array)
|
sksurv/functions.py
CHANGED
|
@@ -18,31 +18,29 @@ __all__ = ["StepFunction"]
|
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class StepFunction:
|
|
21
|
-
"""
|
|
21
|
+
r"""A callable step function.
|
|
22
|
+
|
|
23
|
+
The function is defined by a set of points :math:`(x_i, y_i)` and is
|
|
24
|
+
evaluated as:
|
|
22
25
|
|
|
23
26
|
.. math::
|
|
24
27
|
|
|
25
|
-
f(z) = a
|
|
26
|
-
x_i \\leq z < x_{i + 1}
|
|
28
|
+
f(z) = a \cdot y_i + b \quad \text{if} \quad x_i \leq z < x_{i + 1}
|
|
27
29
|
|
|
28
30
|
Parameters
|
|
29
31
|
----------
|
|
30
32
|
x : ndarray, shape = (n_points,)
|
|
31
|
-
|
|
32
|
-
|
|
33
|
+
The values on the x-axis, must be in ascending order.
|
|
33
34
|
y : ndarray, shape = (n_points,)
|
|
34
|
-
|
|
35
|
-
|
|
35
|
+
The corresponding values on the y-axis.
|
|
36
36
|
a : float, optional, default: 1.0
|
|
37
|
-
|
|
38
|
-
|
|
37
|
+
A constant factor to scale ``y`` by.
|
|
39
38
|
b : float, optional, default: 0.0
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
If entry is `None`, use the first/last value of `x` as limit.
|
|
39
|
+
A constant offset term.
|
|
40
|
+
domain : tuple, optional, default: (0, None)
|
|
41
|
+
A tuple ``(lower, upper)`` that defines the domain of the step function.
|
|
42
|
+
If ``lower`` or ``upper`` is ``None``, the first or last value of ``x`` is
|
|
43
|
+
used as the limit, respectively.
|
|
46
44
|
"""
|
|
47
45
|
|
|
48
46
|
def __init__(self, x, y, *, a=1.0, b=0.0, domain=(0, None)):
|
|
@@ -57,36 +55,38 @@ class StepFunction:
|
|
|
57
55
|
|
|
58
56
|
@property
|
|
59
57
|
def domain(self):
|
|
60
|
-
"""
|
|
61
|
-
|
|
58
|
+
"""The domain of the function.
|
|
59
|
+
|
|
60
|
+
The domain is the range of values that the function accepts.
|
|
62
61
|
|
|
63
62
|
Returns
|
|
64
63
|
-------
|
|
65
64
|
lower_limit : float
|
|
66
|
-
Lower limit of
|
|
65
|
+
Lower limit of the domain.
|
|
67
66
|
|
|
68
67
|
upper_limit : float
|
|
69
|
-
Upper limit of domain.
|
|
68
|
+
Upper limit of the domain.
|
|
70
69
|
"""
|
|
71
70
|
return self._domain
|
|
72
71
|
|
|
73
72
|
def __call__(self, x):
|
|
74
|
-
"""Evaluate step function.
|
|
75
|
-
|
|
76
|
-
Values outside the interval specified by `self.domain`
|
|
77
|
-
will raise an exception.
|
|
78
|
-
Values in `x` that are in the interval `[self.domain[0]; self.x[0]]`
|
|
79
|
-
get mapped to `self.y[0]`.
|
|
73
|
+
"""Evaluate the step function at given values.
|
|
80
74
|
|
|
81
75
|
Parameters
|
|
82
76
|
----------
|
|
83
|
-
x : float
|
|
84
|
-
|
|
77
|
+
x : float or array-like, shape=(n_values,)
|
|
78
|
+
The values at which to evaluate the step function.
|
|
79
|
+
Values must be within the function's ``domain``.
|
|
85
80
|
|
|
86
81
|
Returns
|
|
87
82
|
-------
|
|
88
|
-
y : float
|
|
89
|
-
|
|
83
|
+
y : float or array-like, shape=(n_values,)
|
|
84
|
+
The value of the step function at ``x``.
|
|
85
|
+
|
|
86
|
+
Raises
|
|
87
|
+
------
|
|
88
|
+
ValueError
|
|
89
|
+
If ``x`` contains values outside the function's ``domain``.
|
|
90
90
|
"""
|
|
91
91
|
x = np.atleast_1d(x)
|
|
92
92
|
if not np.isfinite(x).all():
|
sksurv/io/arffread.py
CHANGED
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
13
13
|
import numpy as np
|
|
14
14
|
import pandas as pd
|
|
15
|
+
from pandas.api.types import is_string_dtype
|
|
15
16
|
from scipy.io.arff import loadarff as scipy_loadarff
|
|
16
17
|
|
|
17
18
|
__all__ = ["loadarff"]
|
|
@@ -34,7 +35,8 @@ def _to_pandas(data, meta):
|
|
|
34
35
|
data_dict[name] = pd.Categorical(raw, categories=attr_format, ordered=False)
|
|
35
36
|
else:
|
|
36
37
|
arr = data[name]
|
|
37
|
-
|
|
38
|
+
dtype = "str" if is_string_dtype(arr.dtype) else arr.dtype
|
|
39
|
+
p = pd.Series(arr, dtype=dtype)
|
|
38
40
|
data_dict[name] = p
|
|
39
41
|
|
|
40
42
|
# currently, this step converts all pandas.Categorical columns back to pandas.Series
|
|
@@ -42,17 +44,48 @@ def _to_pandas(data, meta):
|
|
|
42
44
|
|
|
43
45
|
|
|
44
46
|
def loadarff(filename):
|
|
45
|
-
"""Load ARFF file
|
|
47
|
+
"""Load ARFF file.
|
|
46
48
|
|
|
47
49
|
Parameters
|
|
48
50
|
----------
|
|
49
|
-
filename :
|
|
50
|
-
Path to ARFF file
|
|
51
|
+
filename : str or file-like
|
|
52
|
+
Path to ARFF file, or file-like object to read from.
|
|
51
53
|
|
|
52
54
|
Returns
|
|
53
55
|
-------
|
|
54
56
|
data_frame : :class:`pandas.DataFrame`
|
|
55
57
|
DataFrame containing data of ARFF file
|
|
58
|
+
|
|
59
|
+
See Also
|
|
60
|
+
--------
|
|
61
|
+
scipy.io.arff.loadarff : The underlying function that reads the ARFF file.
|
|
62
|
+
|
|
63
|
+
Examples
|
|
64
|
+
--------
|
|
65
|
+
>>> from io import StringIO
|
|
66
|
+
>>> from sksurv.io import loadarff
|
|
67
|
+
>>>
|
|
68
|
+
>>> # Create a dummy ARFF file
|
|
69
|
+
>>> arff_content = '''
|
|
70
|
+
... @relation test_data
|
|
71
|
+
... @attribute feature1 numeric
|
|
72
|
+
... @attribute feature2 numeric
|
|
73
|
+
... @attribute class {A,B,C}
|
|
74
|
+
... @data
|
|
75
|
+
... 1.0,2.0,A
|
|
76
|
+
... 3.0,4.0,B
|
|
77
|
+
... 5.0,6.0,C
|
|
78
|
+
... '''
|
|
79
|
+
>>>
|
|
80
|
+
>>> # Load the ARFF file
|
|
81
|
+
>>> with StringIO(arff_content) as f:
|
|
82
|
+
... data = loadarff(f)
|
|
83
|
+
>>>
|
|
84
|
+
>>> print(data)
|
|
85
|
+
class feature1 feature2
|
|
86
|
+
0 A 1.0 2.0
|
|
87
|
+
1 B 3.0 4.0
|
|
88
|
+
2 C 5.0 6.0
|
|
56
89
|
"""
|
|
57
90
|
data, meta = scipy_loadarff(filename)
|
|
58
91
|
return _to_pandas(data, meta)
|
sksurv/io/arffwrite.py
CHANGED
|
@@ -15,7 +15,7 @@ import re
|
|
|
15
15
|
|
|
16
16
|
import numpy as np
|
|
17
17
|
import pandas as pd
|
|
18
|
-
from pandas.api.types import CategoricalDtype,
|
|
18
|
+
from pandas.api.types import CategoricalDtype, is_string_dtype
|
|
19
19
|
|
|
20
20
|
_ILLEGAL_CHARACTER_PAT = re.compile(r"[^-_=\w\d\(\)<>\.]")
|
|
21
21
|
|
|
@@ -28,15 +28,51 @@ def writearff(data, filename, relation_name=None, index=True):
|
|
|
28
28
|
data : :class:`pandas.DataFrame`
|
|
29
29
|
DataFrame containing data
|
|
30
30
|
|
|
31
|
-
filename :
|
|
31
|
+
filename : str or file-like object
|
|
32
32
|
Path to ARFF file or file-like object. In the latter case,
|
|
33
33
|
the handle is closed by calling this function.
|
|
34
34
|
|
|
35
|
-
relation_name :
|
|
35
|
+
relation_name : str, optional, default: 'pandas'
|
|
36
36
|
Name of relation in ARFF file.
|
|
37
37
|
|
|
38
38
|
index : boolean, optional, default: True
|
|
39
39
|
Write row names (index)
|
|
40
|
+
|
|
41
|
+
See Also
|
|
42
|
+
--------
|
|
43
|
+
loadarff : Function to read ARFF files.
|
|
44
|
+
|
|
45
|
+
Examples
|
|
46
|
+
--------
|
|
47
|
+
>>> import numpy as np
|
|
48
|
+
>>> import pandas as pd
|
|
49
|
+
>>> from sksurv.io import writearff
|
|
50
|
+
>>>
|
|
51
|
+
>>> # Create a dummy DataFrame
|
|
52
|
+
>>> data = pd.DataFrame({
|
|
53
|
+
... 'feature1': [1.0, 3.0, 5.0],
|
|
54
|
+
... 'feature2': [2.0, np.nan, 6.0],
|
|
55
|
+
... 'class': ['A', 'B', 'C']
|
|
56
|
+
... }, index=['One', 'Two', 'Three'])
|
|
57
|
+
>>>
|
|
58
|
+
>>> # Write to ARFF file
|
|
59
|
+
>>> writearff(data, 'test_output.arff', relation_name='test_data')
|
|
60
|
+
>>>
|
|
61
|
+
>>> # Read contents of ARFF file
|
|
62
|
+
>>> with open('test_output.arff') as f:
|
|
63
|
+
... arff_contents = "".join(f.readlines())
|
|
64
|
+
>>> print(arff_contents)
|
|
65
|
+
@relation test_data
|
|
66
|
+
<BLANKLINE>
|
|
67
|
+
@attribute index {One,Three,Two}
|
|
68
|
+
@attribute feature1 real
|
|
69
|
+
@attribute feature2 real
|
|
70
|
+
@attribute class {A,B,C}
|
|
71
|
+
<BLANKLINE>
|
|
72
|
+
@data
|
|
73
|
+
One,1.0,2.0,A
|
|
74
|
+
Two,3.0,?,B
|
|
75
|
+
Three,5.0,6.0,C
|
|
40
76
|
"""
|
|
41
77
|
if isinstance(filename, str):
|
|
42
78
|
fp = open(filename, "w")
|
|
@@ -70,7 +106,7 @@ def _write_header(data, fp, relation_name, index):
|
|
|
70
106
|
name = attribute_names[column]
|
|
71
107
|
fp.write(f"@attribute {name}\t")
|
|
72
108
|
|
|
73
|
-
if isinstance(series.dtype, CategoricalDtype) or
|
|
109
|
+
if isinstance(series.dtype, CategoricalDtype) or is_string_dtype(series.dtype):
|
|
74
110
|
_write_attribute_categorical(series, fp)
|
|
75
111
|
elif np.issubdtype(series.dtype, np.floating):
|
|
76
112
|
fp.write("real")
|
|
@@ -132,7 +168,7 @@ def _write_data(data, fp):
|
|
|
132
168
|
fp.write("@data\n")
|
|
133
169
|
|
|
134
170
|
def to_str(x):
|
|
135
|
-
if pd.
|
|
171
|
+
if pd.isna(x):
|
|
136
172
|
return "?"
|
|
137
173
|
return str(x)
|
|
138
174
|
|
|
Binary file
|