scikit-survival 0.24.1__cp313-cp313-win_amd64.whl → 0.26.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. scikit_survival-0.26.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.26.0.dist-info/RECORD +58 -0
  3. {scikit_survival-0.24.1.dist-info → scikit_survival-0.26.0.dist-info}/WHEEL +1 -1
  4. sksurv/__init__.py +51 -6
  5. sksurv/base.py +12 -2
  6. sksurv/bintrees/_binarytrees.cp313-win_amd64.pyd +0 -0
  7. sksurv/column.py +38 -35
  8. sksurv/compare.py +23 -23
  9. sksurv/datasets/base.py +52 -27
  10. sksurv/docstrings.py +99 -0
  11. sksurv/ensemble/_coxph_loss.cp313-win_amd64.pyd +0 -0
  12. sksurv/ensemble/boosting.py +116 -168
  13. sksurv/ensemble/forest.py +94 -151
  14. sksurv/functions.py +29 -29
  15. sksurv/io/arffread.py +37 -4
  16. sksurv/io/arffwrite.py +41 -5
  17. sksurv/kernels/_clinical_kernel.cp313-win_amd64.pyd +0 -0
  18. sksurv/kernels/clinical.py +36 -16
  19. sksurv/linear_model/_coxnet.cp313-win_amd64.pyd +0 -0
  20. sksurv/linear_model/aft.py +14 -11
  21. sksurv/linear_model/coxnet.py +138 -89
  22. sksurv/linear_model/coxph.py +102 -83
  23. sksurv/meta/ensemble_selection.py +91 -9
  24. sksurv/meta/stacking.py +47 -26
  25. sksurv/metrics.py +257 -224
  26. sksurv/nonparametric.py +150 -81
  27. sksurv/preprocessing.py +74 -34
  28. sksurv/svm/_minlip.cp313-win_amd64.pyd +0 -0
  29. sksurv/svm/_prsvm.cp313-win_amd64.pyd +0 -0
  30. sksurv/svm/minlip.py +171 -85
  31. sksurv/svm/naive_survival_svm.py +63 -34
  32. sksurv/svm/survival_svm.py +103 -103
  33. sksurv/testing.py +47 -0
  34. sksurv/tree/_criterion.cp313-win_amd64.pyd +0 -0
  35. sksurv/tree/tree.py +170 -84
  36. sksurv/util.py +85 -30
  37. scikit_survival-0.24.1.dist-info/METADATA +0 -889
  38. scikit_survival-0.24.1.dist-info/RECORD +0 -57
  39. {scikit_survival-0.24.1.dist-info → scikit_survival-0.26.0.dist-info}/licenses/COPYING +0 -0
  40. {scikit_survival-0.24.1.dist-info → scikit_survival-0.26.0.dist-info}/top_level.txt +0 -0
sksurv/ensemble/forest.py CHANGED
@@ -18,6 +18,7 @@ from sklearn.utils._tags import get_tags
  from sklearn.utils.validation import check_is_fitted, check_random_state, validate_data

  from ..base import SurvivalAnalysisMixin
+ from ..docstrings import append_cumulative_hazard_example, append_survival_function_example
  from ..metrics import concordance_index_censored
  from ..tree import ExtraSurvivalTree, SurvivalTree
  from ..tree._criterion import get_unique_times
@@ -96,9 +97,9 @@ class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
  Data matrix

  y : structured array, shape = (n_samples,)
- A structured array containing the binary event indicator
- as first field, and time of event or time of censoring as
- second field.
+ A structured array with two fields. The first field is a boolean
+ where ``True`` indicates an event and ``False`` indicates right-censoring.
+ The second field is a float with the time of event or time of censoring.

  Returns
  -------
@@ -266,15 +267,15 @@ class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
  return y_hat

  def predict(self, X):
- """Predict risk score.
+ r"""Predict risk score.

  The ensemble risk score is the total number of events,
  which can be estimated by the sum of the estimated
- ensemble cumulative hazard function :math:`\\hat{H}_e`.
+ ensemble cumulative hazard function :math:`\hat{H}_e`.

  .. math::

- \\sum_{j=1}^{n} \\hat{H}_e(T_{j} \\mid x) ,
+ \sum_{j=1}^{n} \hat{H}_e(T_{j} \mid x) ,

  where :math:`n` denotes the total number of distinct
  event times in the training data.
@@ -322,7 +323,7 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):

  Parameters
  ----------
- n_estimators : integer, optional, default: 100
+ n_estimators : int, optional, default: 100
  The number of trees in the forest.

  max_depth : int or None, optional, default: None
@@ -355,7 +356,7 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
  the input samples) required to be at a leaf node. Samples have
  equal weight when sample_weight is not provided.

- max_features : int, float, string or None, optional, default: None
+ max_features : int, float, {'sqrt', 'log2'} or None, optional, default: 'sqrt'
  The number of features to consider when looking for the best split:

  - If int, then consider `max_features` features at each split.
@@ -375,11 +376,11 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
  Best nodes are defined as relative reduction in impurity.
  If None then unlimited number of leaf nodes.

- bootstrap : boolean, optional, default: True
+ bootstrap : bool, optional, default: True
  Whether bootstrap samples are used when building trees. If False, the
  whole dataset is used to build each tree.

- oob_score : bool, default: False
+ oob_score : bool, optional, default: False
  Whether to use out-of-bag samples to estimate
  the generalization accuracy.

@@ -412,22 +413,22 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
  - If float, then draw `max_samples * X.shape[0]` samples. Thus,
  `max_samples` should be in the interval `(0.0, 1.0]`.

- low_memory : boolean, default: False
- If set, ``predict`` computations use reduced memory but ``predict_cumulative_hazard_function``
- and ``predict_survival_function`` are not implemented.
+ low_memory : bool, optional, default: False
+ If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
+ and :meth:`predict_survival_function` are not implemented.

  Attributes
  ----------
  estimators_ : list of SurvivalTree instances
  The collection of fitted sub-estimators.

- unique_times_ : array of shape = (n_unique_times,)
+ unique_times_ : ndarray, shape = (n_unique_times,)
  Unique time points.

  n_features_in_ : int
  Number of features seen during ``fit``.

- feature_names_in_ : ndarray of shape (`n_features_in_`,)
+ feature_names_in_ : ndarray, shape = (`n_features_in_`,)
  Names of features seen during ``fit``. Defined only when `X`
  has feature names that are all strings.

@@ -527,6 +528,7 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
  self.max_leaf_nodes = max_leaf_nodes
  self.low_memory = low_memory

+ @append_cumulative_hazard_example(estimator_mod="ensemble", estimator_class="RandomSurvivalForest")
  def predict_cumulative_hazard_function(self, X, return_array=False):
  """Predict cumulative hazard function.

@@ -544,47 +546,33 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
  X : array-like, shape = (n_samples, n_features)
  Data matrix.

- return_array : boolean, default: False
- If set, return an array with the cumulative hazard rate
- for each `self.unique_times_`, otherwise an array of
- :class:`sksurv.functions.StepFunction`.
+ return_array : bool, default: False
+ Whether to return a single array of cumulative hazard values
+ or a list of step functions.
+
+ If `False`, a list of :class:`sksurv.functions.StepFunction`
+ objects is returned.
+
+ If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+ returned, where `n_unique_times` is the number of unique
+ event times in the training data. Each row represents the cumulative
+ hazard function of an individual evaluated at `unique_times_`.

  Returns
  -------
  cum_hazard : ndarray
- If `return_array` is set, an array with the cumulative hazard rate
- for each `self.unique_times_`, otherwise an array of length `n_samples`
- of :class:`sksurv.functions.StepFunction` instances will be returned.
+ If `return_array` is `False`, an array of `n_samples`
+ :class:`sksurv.functions.StepFunction` instances is returned.
+
+ If `return_array` is `True`, a numeric array of shape
+ `(n_samples, n_unique_times_)` is returned.

  Examples
  --------
- >>> import matplotlib.pyplot as plt
- >>> from sksurv.datasets import load_whas500
- >>> from sksurv.ensemble import RandomSurvivalForest
-
- Load and prepare the data.
-
- >>> X, y = load_whas500()
- >>> X = X.astype(float)
-
- Fit the model.
-
- >>> estimator = RandomSurvivalForest().fit(X, y)
-
- Estimate the cumulative hazard function for the first 5 samples.
-
- >>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])
-
- Plot the estimated cumulative hazard functions.
-
- >>> for fn in chf_funcs:
- ... plt.step(fn.x, fn(fn.x), where="post")
- ...
- >>> plt.ylim(0, 1)
- >>> plt.show()
  """
  return super().predict_cumulative_hazard_function(X, return_array)

+ @append_survival_function_example(estimator_mod="ensemble", estimator_class="RandomSurvivalForest")
  def predict_survival_function(self, X, return_array=False):
  """Predict survival function.

@@ -602,45 +590,29 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
  X : array-like, shape = (n_samples, n_features)
  Data matrix.

- return_array : boolean
- If set, return an array with the probability
- of survival for each `self.unique_times_`,
- otherwise an array of :class:`sksurv.functions.StepFunction`.
+ return_array : bool, default: False
+ Whether to return a single array of survival probabilities
+ or a list of step functions.
+
+ If `False`, a list of :class:`sksurv.functions.StepFunction`
+ objects is returned.
+
+ If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+ returned, where `n_unique_times` is the number of unique
+ event times in the training data. Each row represents the survival
+ function of an individual evaluated at `unique_times_`.

  Returns
  -------
  survival : ndarray
- If `return_array` is set, an array with the probability
- of survival for each `self.unique_times_`,
- otherwise an array of :class:`sksurv.functions.StepFunction`
- will be returned.
+ If `return_array` is `False`, an array of `n_samples`
+ :class:`sksurv.functions.StepFunction` instances is returned.
+
+ If `return_array` is `True`, a numeric array of shape
+ `(n_samples, n_unique_times_)` is returned.

  Examples
  --------
- >>> import matplotlib.pyplot as plt
- >>> from sksurv.datasets import load_whas500
- >>> from sksurv.ensemble import RandomSurvivalForest
-
- Load and prepare the data.
-
- >>> X, y = load_whas500()
- >>> X = X.astype(float)
-
- Fit the model.
-
- >>> estimator = RandomSurvivalForest().fit(X, y)
-
- Estimate the survival function for the first 5 samples.
-
- >>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])
-
- Plot the estimated survival functions.
-
- >>> for fn in surv_funcs:
- ... plt.step(fn.x, fn(fn.x), where="post")
- ...
- >>> plt.ylim(0, 1)
- >>> plt.show()
  """
  return super().predict_survival_function(X, return_array)

@@ -667,7 +639,7 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):

  Parameters
  ----------
- n_estimators : integer, optional, default: 100
+ n_estimators : int, optional, default: 100
  The number of trees in the forest.

  max_depth : int or None, optional, default: None
@@ -700,7 +672,7 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
  the input samples) required to be at a leaf node. Samples have
  equal weight when sample_weight is not provided.

- max_features : int, float, string or None, optional, default: None
+ max_features : int, float, {'sqrt', 'log2'} or None, optional, default: 'sqrt'
  The number of features to consider when looking for the best split:

  - If int, then consider `max_features` features at each split.
@@ -720,11 +692,11 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
  Best nodes are defined as relative reduction in impurity.
  If None then unlimited number of leaf nodes.

- bootstrap : boolean, optional, default: True
+ bootstrap : bool, optional, default: True
  Whether bootstrap samples are used when building trees. If False, the
  whole dataset is used to build each tree.

- oob_score : bool, default: False
+ oob_score : bool, optional, default: False
  Whether to use out-of-bag samples to estimate
  the generalization accuracy.

@@ -757,22 +729,22 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
  - If float, then draw `max_samples * X.shape[0]` samples. Thus,
  `max_samples` should be in the interval `(0.0, 1.0]`.

- low_memory : boolean, default: False
- If set, ``predict`` computations use reduced memory but ``predict_cumulative_hazard_function``
- and ``predict_survival_function`` are not implemented.
+ low_memory : bool, optional, default: False
+ If set, :meth:`predict` computations use reduced memory but :meth:`predict_cumulative_hazard_function`
+ and :meth:`predict_survival_function` are not implemented.

  Attributes
  ----------
  estimators_ : list of SurvivalTree instances
  The collection of fitted sub-estimators.

- unique_times_ : array of shape = (n_unique_times,)
+ unique_times_ : ndarray, shape = (n_unique_times,)
  Unique time points.

  n_features_in_ : int
- The number of features when ``fit`` is performed.
+ Number of features seen during ``fit``.

- feature_names_in_ : ndarray of shape (`n_features_in_`,)
+ feature_names_in_ : ndarray, shape = (`n_features_in_`,)
  Names of features seen during ``fit``. Defined only when `X`
  has feature names that are all strings.

@@ -841,6 +813,7 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
  self.max_leaf_nodes = max_leaf_nodes
  self.low_memory = low_memory

+ @append_cumulative_hazard_example(estimator_mod="ensemble", estimator_class="ExtraSurvivalTrees")
  def predict_cumulative_hazard_function(self, X, return_array=False):
  """Predict cumulative hazard function.

@@ -858,47 +831,33 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
  X : array-like, shape = (n_samples, n_features)
  Data matrix.

- return_array : boolean, default: False
- If set, return an array with the cumulative hazard rate
- for each `self.unique_times_`, otherwise an array of
- :class:`sksurv.functions.StepFunction`.
+ return_array : bool, default: False
+ Whether to return a single array of cumulative hazard values
+ or a list of step functions.
+
+ If `False`, a list of :class:`sksurv.functions.StepFunction`
+ objects is returned.
+
+ If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+ returned, where `n_unique_times` is the number of unique
+ event times in the training data. Each row represents the cumulative
+ hazard function of an individual evaluated at `unique_times_`.

  Returns
  -------
  cum_hazard : ndarray
- If `return_array` is set, an array with the cumulative hazard rate
- for each `self.unique_times_`, otherwise an array of length `n_samples`
- of :class:`sksurv.functions.StepFunction` instances will be returned.
+ If `return_array` is `False`, an array of `n_samples`
+ :class:`sksurv.functions.StepFunction` instances is returned.
+
+ If `return_array` is `True`, a numeric array of shape
+ `(n_samples, n_unique_times_)` is returned.

  Examples
  --------
- >>> import matplotlib.pyplot as plt
- >>> from sksurv.datasets import load_whas500
- >>> from sksurv.ensemble import ExtraSurvivalTrees
-
- Load and prepare the data.
-
- >>> X, y = load_whas500()
- >>> X = X.astype(float)
-
- Fit the model.
-
- >>> estimator = ExtraSurvivalTrees().fit(X, y)
-
- Estimate the cumulative hazard function for the first 5 samples.
-
- >>> chf_funcs = estimator.predict_cumulative_hazard_function(X.iloc[:5])
-
- Plot the estimated cumulative hazard functions.
-
- >>> for fn in chf_funcs:
- ... plt.step(fn.x, fn(fn.x), where="post")
- ...
- >>> plt.ylim(0, 1)
- >>> plt.show()
  """
  return super().predict_cumulative_hazard_function(X, return_array)

+ @append_survival_function_example(estimator_mod="ensemble", estimator_class="ExtraSurvivalTrees")
  def predict_survival_function(self, X, return_array=False):
  """Predict survival function.

@@ -916,44 +875,28 @@ class ExtraSurvivalTrees(SurvivalAnalysisMixin, _BaseSurvivalForest):
  X : array-like, shape = (n_samples, n_features)
  Data matrix.

- return_array : boolean, default: False
- If set, return an array with the probability
- of survival for each `self.unique_times_`,
- otherwise an array of :class:`sksurv.functions.StepFunction`.
+ return_array : bool, default: False
+ Whether to return a single array of survival probabilities
+ or a list of step functions.
+
+ If `False`, a list of :class:`sksurv.functions.StepFunction`
+ objects is returned.
+
+ If `True`, a 2d-array of shape `(n_samples, n_unique_times)` is
+ returned, where `n_unique_times` is the number of unique
+ event times in the training data. Each row represents the survival
+ function of an individual evaluated at `unique_times_`.

  Returns
  -------
  survival : ndarray
- If `return_array` is set, an array with the probability of
- survival for each `self.unique_times_`, otherwise an array of
- length `n_samples` of :class:`sksurv.functions.StepFunction`
- instances will be returned.
+ If `return_array` is `False`, an array of `n_samples`
+ :class:`sksurv.functions.StepFunction` instances is returned.
+
+ If `return_array` is `True`, a numeric array of shape
+ `(n_samples, n_unique_times_)` is returned.

  Examples
  --------
- >>> import matplotlib.pyplot as plt
- >>> from sksurv.datasets import load_whas500
- >>> from sksurv.ensemble import ExtraSurvivalTrees
-
- Load and prepare the data.
-
- >>> X, y = load_whas500()
- >>> X = X.astype(float)
-
- Fit the model.
-
- >>> estimator = ExtraSurvivalTrees().fit(X, y)
-
- Estimate the survival function for the first 5 samples.
-
- >>> surv_funcs = estimator.predict_survival_function(X.iloc[:5])
-
- Plot the estimated survival functions.
-
- >>> for fn in surv_funcs:
- ... plt.step(fn.x, fn(fn.x), where="post")
- ...
- >>> plt.ylim(0, 1)
- >>> plt.show()
  """
  return super().predict_survival_function(X, return_array)
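The inline matplotlib examples removed above are now appended to the docstrings via the append_cumulative_hazard_example and append_survival_function_example decorators from the new sksurv/docstrings.py module. For reference, a minimal sketch of the two return modes described by the rewritten docstrings, assuming only the public API shown in this diff; the estimator settings and variable names are illustrative:

    from sksurv.datasets import load_whas500
    from sksurv.ensemble import RandomSurvivalForest

    # Fit a small forest on the WHAS500 data (sizes kept small for illustration).
    X, y = load_whas500()
    X = X.astype(float)
    est = RandomSurvivalForest(n_estimators=10, random_state=0).fit(X, y)

    # Default return_array=False: one StepFunction per sample.
    chf_funcs = est.predict_cumulative_hazard_function(X.iloc[:5])
    print(chf_funcs[0](est.unique_times_[:3]))

    # return_array=True: 2d array of shape (n_samples, n_unique_times).
    surv = est.predict_survival_function(X.iloc[:5], return_array=True)
    print(surv.shape)  # (5, number of unique training event times)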
sksurv/functions.py CHANGED
@@ -18,31 +18,29 @@ __all__ = ["StepFunction"]


  class StepFunction:
- """Callable step function.
+ r"""A callable step function.
+
+ The function is defined by a set of points :math:`(x_i, y_i)` and is
+ evaluated as:

  .. math::

- f(z) = a * y_i + b,
- x_i \\leq z < x_{i + 1}
+ f(z) = a \cdot y_i + b \quad \text{if} \quad x_i \leq z < x_{i + 1}

  Parameters
  ----------
  x : ndarray, shape = (n_points,)
- Values on the x axis in ascending order.
-
+ The values on the x-axis, must be in ascending order.
  y : ndarray, shape = (n_points,)
- Corresponding values on the y axis.
-
+ The corresponding values on the y-axis.
  a : float, optional, default: 1.0
- Constant to multiply by.
-
+ A constant factor to scale ``y`` by.
  b : float, optional, default: 0.0
- Constant offset term.
-
- domain : tuple, optional
- A tuple with two entries that sets the limits of the
- domain of the step function.
- If entry is `None`, use the first/last value of `x` as limit.
+ A constant offset term.
+ domain : tuple, optional, default: (0, None)
+ A tuple ``(lower, upper)`` that defines the domain of the step function.
+ If ``lower`` or ``upper`` is ``None``, the first or last value of ``x`` is
+ used as the limit, respectively.
  """

  def __init__(self, x, y, *, a=1.0, b=0.0, domain=(0, None)):
@@ -57,36 +55,38 @@ class StepFunction:

  @property
  def domain(self):
- """Returns the domain of the function, that means
- the range of values that the function accepts.
+ """The domain of the function.
+
+ The domain is the range of values that the function accepts.

  Returns
  -------
  lower_limit : float
- Lower limit of domain.
+ Lower limit of the domain.

  upper_limit : float
- Upper limit of domain.
+ Upper limit of the domain.
  """
  return self._domain

  def __call__(self, x):
- """Evaluate step function.
-
- Values outside the interval specified by `self.domain`
- will raise an exception.
- Values in `x` that are in the interval `[self.domain[0]; self.x[0]]`
- get mapped to `self.y[0]`.
+ """Evaluate the step function at given values.

  Parameters
  ----------
- x : float|array-like, shape=(n_values,)
- Values to evaluate step function at.
+ x : float or array-like, shape=(n_values,)
+ The values at which to evaluate the step function.
+ Values must be within the function's ``domain``.

  Returns
  -------
- y : float|array-like, shape=(n_values,)
- Values of step function at `x`.
+ y : float or array-like, shape=(n_values,)
+ The value of the step function at ``x``.
+
+ Raises
+ ------
+ ValueError
+ If ``x`` contains values outside the function's ``domain``.
  """
  x = np.atleast_1d(x)
  if not np.isfinite(x).all():
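A minimal sketch of the evaluation rule documented above, f(z) = a * y_i + b for x_i <= z < x_{i+1}, using made-up points; only the constructor and call signatures shown in this diff are assumed:

    import numpy as np
    from sksurv.functions import StepFunction

    x = np.array([1.0, 2.0, 3.0])  # x_i in ascending order
    y = np.array([0.9, 0.5, 0.2])  # corresponding y_i
    fn = StepFunction(x, y, a=1.0, b=0.0)

    print(fn(1.5))         # 0.9, because x_0 <= 1.5 < x_1 and f(z) = a * y_0 + b
    print(fn([1.0, 2.5]))  # [0.9, 0.5]
    print(fn.domain)       # (lower, upper); upper defaults to x[-1]

    # Values outside the domain raise ValueError, as the rewritten docstring states.
    try:
        fn(4.0)
    except ValueError as err:
        print(err)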
sksurv/io/arffread.py CHANGED
@@ -12,6 +12,7 @@
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
  import numpy as np
  import pandas as pd
+ from pandas.api.types import is_string_dtype
  from scipy.io.arff import loadarff as scipy_loadarff

  __all__ = ["loadarff"]
@@ -34,7 +35,8 @@ def _to_pandas(data, meta):
  data_dict[name] = pd.Categorical(raw, categories=attr_format, ordered=False)
  else:
  arr = data[name]
- p = pd.Series(arr, dtype=arr.dtype)
+ dtype = "str" if is_string_dtype(arr.dtype) else arr.dtype
+ p = pd.Series(arr, dtype=dtype)
  data_dict[name] = p

  # currently, this step converts all pandas.Categorial columns back to pandas.Series
@@ -42,17 +44,48 @@ def _to_pandas(data, meta):


  def loadarff(filename):
- """Load ARFF file
+ """Load ARFF file.

  Parameters
  ----------
- filename : string
- Path to ARFF file
+ filename : str or file-like
+ Path to ARFF file, or file-like object to read from.

  Returns
  -------
  data_frame : :class:`pandas.DataFrame`
  DataFrame containing data of ARFF file
+
+ See Also
+ --------
+ scipy.io.arff.loadarff : The underlying function that reads the ARFF file.
+
+ Examples
+ --------
+ >>> from io import StringIO
+ >>> from sksurv.io import loadarff
+ >>>
+ >>> # Create a dummy ARFF file
+ >>> arff_content = '''
+ ... @relation test_data
+ ... @attribute feature1 numeric
+ ... @attribute feature2 numeric
+ ... @attribute class {A,B,C}
+ ... @data
+ ... 1.0,2.0,A
+ ... 3.0,4.0,B
+ ... 5.0,6.0,C
+ ... '''
+ >>>
+ >>> # Load the ARFF file
+ >>> with StringIO(arff_content) as f:
+ ... data = loadarff(f)
+ >>>
+ >>> print(data)
+ class feature1 feature2
+ 0 A 1.0 2.0
+ 1 B 3.0 4.0
+ 2 C 5.0 6.0
  """
  data, meta = scipy_loadarff(filename)
  return _to_pandas(data, meta)
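The _to_pandas change above routes string-typed columns through pandas' string dtype instead of keeping the raw NumPy dtype. A self-contained sketch of that pattern; the sample array is illustrative and not part of the package:

    import numpy as np
    import pandas as pd
    from pandas.api.types import is_string_dtype

    arr = np.array(["low", "high", "low"])  # e.g. a nominal ARFF column
    dtype = "str" if is_string_dtype(arr.dtype) else arr.dtype
    column = pd.Series(arr, dtype=dtype)
    print(column.dtype)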
sksurv/io/arffwrite.py CHANGED
@@ -15,7 +15,7 @@ import re

  import numpy as np
  import pandas as pd
- from pandas.api.types import CategoricalDtype, is_object_dtype
+ from pandas.api.types import CategoricalDtype, is_string_dtype

  _ILLEGAL_CHARACTER_PAT = re.compile(r"[^-_=\w\d\(\)<>\.]")

@@ -28,15 +28,51 @@ def writearff(data, filename, relation_name=None, index=True):
  data : :class:`pandas.DataFrame`
  DataFrame containing data

- filename : string or file-like object
+ filename : str or file-like object
  Path to ARFF file or file-like object. In the latter case,
  the handle is closed by calling this function.

- relation_name : string, optional, default: "pandas"
+ relation_name : str, optional, default: 'pandas'
  Name of relation in ARFF file.

  index : boolean, optional, default: True
  Write row names (index)
+
+ See Also
+ --------
+ loadarff : Function to read ARFF files.
+
+ Examples
+ --------
+ >>> import numpy as np
+ >>> import pandas as pd
+ >>> from sksurv.io import writearff
+ >>>
+ >>> # Create a dummy DataFrame
+ >>> data = pd.DataFrame({
+ ... 'feature1': [1.0, 3.0, 5.0],
+ ... 'feature2': [2.0, np.nan, 6.0],
+ ... 'class': ['A', 'B', 'C']
+ ... }, index=['One', 'Two', 'Three'])
+ >>>
+ >>> # Write to ARFF file
+ >>> writearff(data, 'test_output.arff', relation_name='test_data')
+ >>>
+ >>> # Read contents of ARFF file
+ >>> with open('test_output.arff') as f:
+ ... arff_contents = "".join(f.readlines())
+ >>> print(arff_contents)
+ @relation test_data
+ <BLANKLINE>
+ @attribute index {One,Three,Two}
+ @attribute feature1 real
+ @attribute feature2 real
+ @attribute class {A,B,C}
+ <BLANKLINE>
+ @data
+ One,1.0,2.0,A
+ Two,3.0,?,B
+ Three,5.0,6.0,C
  """
  if isinstance(filename, str):
  fp = open(filename, "w")
@@ -70,7 +106,7 @@ def _write_header(data, fp, relation_name, index):
  name = attribute_names[column]
  fp.write(f"@attribute {name}\t")

- if isinstance(series.dtype, CategoricalDtype) or is_object_dtype(series):
+ if isinstance(series.dtype, CategoricalDtype) or is_string_dtype(series.dtype):
  _write_attribute_categorical(series, fp)
  elif np.issubdtype(series.dtype, np.floating):
  fp.write("real")
@@ -132,7 +168,7 @@ def _write_data(data, fp):
  fp.write("@data\n")

  def to_str(x):
- if pd.isnull(x):
+ if pd.isna(x):
  return "?"
  return str(x)
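The final hunk swaps pd.isnull for its alias pd.isna in the to_str helper that renders missing values as ? in the @data section (see the Two,3.0,?,B row in the writearff example above). A standalone sketch of that behaviour, with illustrative input values:

    import numpy as np
    import pandas as pd

    def to_str(x):
        # Missing values are written as "?" in ARFF; everything else as plain text.
        if pd.isna(x):
            return "?"
        return str(x)

    print([to_str(v) for v in [1.0, np.nan, "B"]])  # ['1.0', '?', 'B']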