scikit-survival 0.24.1__cp310-cp310-macosx_11_0_arm64.whl → 0.25.0__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. scikit_survival-0.25.0.dist-info/METADATA +185 -0
  2. scikit_survival-0.25.0.dist-info/RECORD +58 -0
  3. {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/WHEEL +1 -1
  4. sksurv/__init__.py +51 -6
  5. sksurv/base.py +12 -2
  6. sksurv/bintrees/_binarytrees.cpython-310-darwin.so +0 -0
  7. sksurv/column.py +33 -29
  8. sksurv/compare.py +22 -22
  9. sksurv/datasets/base.py +45 -20
  10. sksurv/docstrings.py +99 -0
  11. sksurv/ensemble/_coxph_loss.cpython-310-darwin.so +0 -0
  12. sksurv/ensemble/boosting.py +116 -168
  13. sksurv/ensemble/forest.py +94 -151
  14. sksurv/functions.py +29 -29
  15. sksurv/io/arffread.py +34 -3
  16. sksurv/io/arffwrite.py +38 -2
  17. sksurv/kernels/_clinical_kernel.cpython-310-darwin.so +0 -0
  18. sksurv/kernels/clinical.py +33 -13
  19. sksurv/linear_model/_coxnet.cpython-310-darwin.so +0 -0
  20. sksurv/linear_model/aft.py +14 -11
  21. sksurv/linear_model/coxnet.py +138 -89
  22. sksurv/linear_model/coxph.py +102 -83
  23. sksurv/meta/ensemble_selection.py +91 -9
  24. sksurv/meta/stacking.py +47 -26
  25. sksurv/metrics.py +257 -224
  26. sksurv/nonparametric.py +150 -81
  27. sksurv/preprocessing.py +55 -27
  28. sksurv/svm/_minlip.cpython-310-darwin.so +0 -0
  29. sksurv/svm/_prsvm.cpython-310-darwin.so +0 -0
  30. sksurv/svm/minlip.py +160 -79
  31. sksurv/svm/naive_survival_svm.py +63 -34
  32. sksurv/svm/survival_svm.py +103 -103
  33. sksurv/tree/_criterion.cpython-310-darwin.so +0 -0
  34. sksurv/tree/tree.py +170 -84
  35. sksurv/util.py +80 -26
  36. scikit_survival-0.24.1.dist-info/METADATA +0 -889
  37. scikit_survival-0.24.1.dist-info/RECORD +0 -57
  38. {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/licenses/COPYING +0 -0
  39. {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/top_level.txt +0 -0
sksurv/nonparametric.py CHANGED
@@ -31,36 +31,36 @@ __all__ = [
31
31
 
32
32
 
33
33
  def _compute_counts(event, time, order=None):
34
- """Count right censored and uncensored samples at each unique time point.
34
+ """Count right-censored and uncensored samples at each unique time point.
35
35
 
36
36
  Parameters
37
37
  ----------
38
- event : array
38
+ event : ndarray
39
39
  Boolean event indicator.
40
40
  Integer in the case of multiple risks.
41
41
  Zero means right-censored event.
42
42
  Positive values for each of the possible risk events.
43
43
 
44
- time : array
44
+ time : ndarray
45
45
  Survival time or time of censoring.
46
46
 
47
- order : array or None
47
+ order : ndarray or None
48
48
  Indices to order time in ascending order.
49
49
  If None, order will be computed.
50
50
 
51
51
  Returns
52
52
  -------
53
- times : array
53
+ times : ndarray
54
54
  Unique time points.
55
55
 
56
- n_events : array
56
+ n_events : ndarray
57
57
  Number of events at each time point.
58
58
  2D array with shape `(n_unique_time_points, n_risks + 1)` in the case of competing risks.
59
59
 
60
- n_at_risk : array
60
+ n_at_risk : ndarray
61
61
  Number of samples that have not been censored or have not had an event at each time point.
62
62
 
63
- n_censored : array
63
+ n_censored : ndarray
64
64
  Number of censored samples at each time point.
65
65
  """
66
66
  n_samples = event.shape[0]
@@ -116,29 +116,29 @@ def _compute_counts(event, time, order=None):
116
116
 
117
117
 
118
118
  def _compute_counts_truncated(event, time_enter, time_exit):
119
- """Compute counts for left truncated and right censored survival data.
119
+ """Compute counts for left truncated and right-censored survival data.
120
120
 
121
121
  Parameters
122
122
  ----------
123
- event : array
123
+ event : ndarray
124
124
  Boolean event indicator.
125
125
 
126
- time_start : array
126
+ time_enter : ndarray
127
127
  Time when a subject entered the study.
128
128
 
129
- time_exit : array
129
+ time_exit : ndarray
130
130
  Time when a subject left the study due to an
131
131
  event or censoring.
132
132
 
133
133
  Returns
134
134
  -------
135
- times : array
135
+ times : ndarray
136
136
  Unique time points.
137
137
 
138
- n_events : array
138
+ n_events : ndarray
139
139
  Number of events at each time point.
140
140
 
141
- n_at_risk : array
141
+ n_at_risk : ndarray
142
142
  Number of samples that are censored or have an event at each time point.
143
143
  """
144
144
  if (time_enter > time_exit).any():
@@ -212,6 +212,27 @@ def _ci_logmlog(s, sigma_t, conf_level):
212
212
 
213
213
 
214
214
  def _km_ci_estimator(prob_survival, ratio_var, conf_level, conf_type):
215
+ """Helper to compute confidence intervals for the Kaplan-Meier estimate.
216
+
217
+ Parameters
218
+ ----------
219
+ prob_survival : ndarray, shape = (n_times,)
220
+ Survival probability at each unique time point.
221
+
222
+ ratio_var : ndarray, shape = (n_times,)
223
+ The variance ratio term for each unique time point.
224
+
225
+ conf_level : float
226
+ The level for a two-sided confidence interval.
227
+
228
+ conf_type : {'log-log'}
229
+ The type of confidence intervals to estimate.
230
+
231
+ Returns
232
+ -------
233
+ ci : ndarray, shape = (2, n_times)
234
+ Pointwise confidence interval.
235
+ """
215
236
  if conf_type not in {"log-log"}:
216
237
  raise ValueError(f"conf_type must be None or a str among {{'log-log'}}, but was {conf_type!r}")
217
238
 
@@ -232,17 +253,18 @@ def kaplan_meier_estimator(
232
253
  conf_level=0.95,
233
254
  conf_type=None,
234
255
  ):
235
- """Kaplan-Meier estimator of survival function.
256
+ """Computes the Kaplan-Meier estimate of the survival function.
236
257
 
237
258
  See [1]_ for further description.
238
259
 
239
260
  Parameters
240
261
  ----------
241
262
  event : array-like, shape = (n_samples,)
242
- Contains binary event indicators.
263
+ A boolean array where ``True`` indicates an event and ``False`` indicates
264
+ right-censoring.
243
265
 
244
266
  time_exit : array-like, shape = (n_samples,)
245
- Contains event/censoring times.
267
+ Time of event or censoring.
246
268
 
247
269
  time_enter : array-like, shape = (n_samples,), optional
248
270
  Contains time when each individual entered the study for
@@ -270,14 +292,14 @@ def kaplan_meier_estimator(
270
292
 
271
293
  Returns
272
294
  -------
273
- time : array, shape = (n_times,)
295
+ time : ndarray, shape = (n_times,)
274
296
  Unique times.
275
297
 
276
- prob_survival : array, shape = (n_times,)
298
+ prob_survival : ndarray, shape = (n_times,)
277
299
  Survival probability at each unique time point.
278
300
  If `time_enter` is provided, estimates are conditional probabilities.
279
301
 
280
- conf_int : array, shape = (2, n_times)
302
+ conf_int : ndarray, shape = (2, n_times)
281
303
  Pointwise confidence interval of the Kaplan-Meier estimator
282
304
  at each unique time point.
283
305
  Only provided if `conf_type` is not None.
@@ -286,11 +308,23 @@ def kaplan_meier_estimator(
286
308
  --------
287
309
  Creating a Kaplan-Meier curve:
288
310
 
289
- >>> x, y, conf_int = kaplan_meier_estimator(event, time, conf_type="log-log")
290
- >>> plt.step(x, y, where="post")
291
- >>> plt.fill_between(x, conf_int[0], conf_int[1], alpha=0.25, step="post")
292
- >>> plt.ylim(0, 1)
293
- >>> plt.show()
311
+ .. plot::
312
+
313
+ >>> import matplotlib.pyplot as plt
314
+ >>> from sksurv.datasets import load_veterans_lung_cancer
315
+ >>> from sksurv.nonparametric import kaplan_meier_estimator
316
+ >>>
317
+ >>> _, y = load_veterans_lung_cancer()
318
+ >>> time, prob_surv, conf_int = kaplan_meier_estimator(
319
+ ... y["Status"], y["Survival_in_days"], conf_type="log-log"
320
+ ... )
321
+ >>> plt.step(time, prob_surv, where="post")
322
+ [...]
323
+ >>> plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step="post")
324
+ <matplotlib.collections.PolyCollection object at 0x...>
325
+ >>> plt.ylim(0, 1)
326
+ (0.0, 1.0)
327
+ >>> plt.show() # doctest: +SKIP
294
328
 
295
329
  See also
296
330
  --------
@@ -359,26 +393,44 @@ def kaplan_meier_estimator(
359
393
 
360
394
 
361
395
  def nelson_aalen_estimator(event, time):
362
- """Nelson-Aalen estimator of cumulative hazard function.
396
+ """Computes the Nelson-Aalen estimate of the cumulative hazard function.
363
397
 
364
398
  See [1]_, [2]_ for further description.
365
399
 
366
400
  Parameters
367
401
  ----------
368
402
  event : array-like, shape = (n_samples,)
369
- Contains binary event indicators.
403
+ A boolean array where ``True`` indicates an event and ``False`` indicates
404
+ right-censoring.
370
405
 
371
406
  time : array-like, shape = (n_samples,)
372
- Contains event/censoring times.
407
+ Time of event or censoring.
373
408
 
374
409
  Returns
375
410
  -------
376
- time : array, shape = (n_times,)
411
+ time : ndarray, shape = (n_times,)
377
412
  Unique times.
378
413
 
379
- cum_hazard : array, shape = (n_times,)
414
+ cum_hazard : ndarray, shape = (n_times,)
380
415
  Cumulative hazard at each unique time point.
381
416
 
417
+ Examples
418
+ --------
419
+ Creating a cumulative hazard curve:
420
+
421
+ .. plot::
422
+
423
+ >>> import matplotlib.pyplot as plt
424
+ >>> from sksurv.datasets import load_aids
425
+ >>> from sksurv.nonparametric import nelson_aalen_estimator
426
+ >>>
427
+ >>> _, y = load_aids(endpoint="death")
428
+ >>> time, cum_hazard = nelson_aalen_estimator(y["censor_d"], y["time_d"])
429
+ >>>
430
+ >>> plt.step(time, cum_hazard, where="post")
431
+ [...]
432
+ >>> plt.show() # doctest: +SKIP
433
+
382
434
  References
383
435
  ----------
384
436
  .. [1] Nelson, W., "Theory and applications of hazard plotting for censored failure data",
@@ -401,15 +453,16 @@ def ipc_weights(event, time):
401
453
 
402
454
  Parameters
403
455
  ----------
404
- event : array, shape = (n_samples,)
405
- Boolean event indicator.
456
+ event : array-like, shape = (n_samples,)
457
+ A boolean array where ``True`` indicates an event and ``False`` indicates
458
+ right-censoring.
406
459
 
407
- time : array, shape = (n_samples,)
460
+ time : array-like, shape = (n_samples,)
408
461
  Time when a subject experienced an event or was censored.
409
462
 
410
463
  Returns
411
464
  -------
412
- weights : array, shape = (n_samples,)
465
+ weights : ndarray, shape = (n_samples,)
413
466
  inverse probability of censoring weights
414
467
 
415
468
  See also
@@ -469,9 +522,9 @@ class SurvivalFunctionEstimator(BaseEstimator):
469
522
  Parameters
470
523
  ----------
471
524
  y : structured array, shape = (n_samples,)
472
- A structured array containing the binary event indicator
473
- as first field, and time of event or time of censoring as
474
- second field.
525
+ A structured array with two fields. The first field is a boolean
526
+ where ``True`` indicates an event and ``False`` indicates right-censoring.
527
+ The second field is a float with the time of event or time of censoring.
475
528
 
476
529
  Returns
477
530
  -------
@@ -493,13 +546,13 @@ class SurvivalFunctionEstimator(BaseEstimator):
493
546
  return self
494
547
 
495
548
  def predict_proba(self, time, return_conf_int=False):
496
- """Return probability of an event after given time point.
549
+ r"""Return probability of remaining event-free at given time points.
497
550
 
498
- :math:`\\hat{S}(t) = P(T > t)`
551
+ :math:`\hat{S}(t) = P(T > t)`
499
552
 
500
553
  Parameters
501
554
  ----------
502
- time : array, shape = (n_samples,)
555
+ time : array-like, shape = (n_samples,)
503
556
  Time to estimate probability at.
504
557
 
505
558
  return_conf_int : bool, optional, default: False
@@ -510,10 +563,10 @@ class SurvivalFunctionEstimator(BaseEstimator):
510
563
 
511
564
  Returns
512
565
  -------
513
- prob : array, shape = (n_samples,)
514
- Probability of an event at the passed time points.
566
+ prob : ndarray, shape = (n_samples,)
567
+ Probability of remaining event-free at the given time points.
515
568
 
516
- conf_int : array, shape = (2, n_samples)
569
+ conf_int : ndarray, shape = (2, n_samples)
517
570
  Pointwise confidence interval at the passed time points.
518
571
  Only provided if `return_conf_int` is True.
519
572
  """
@@ -561,9 +614,9 @@ class CensoringDistributionEstimator(SurvivalFunctionEstimator):
561
614
  Parameters
562
615
  ----------
563
616
  y : structured array, shape = (n_samples,)
564
- A structured array containing the binary event indicator
565
- as first field, and time of event or time of censoring as
566
- second field.
617
+ A structured array with two fields. The first field is a boolean
618
+ where ``True`` indicates an event and ``False`` indicates right-censoring.
619
+ The second field is a float with the time of event or time of censoring.
567
620
 
568
621
  Returns
569
622
  -------
@@ -581,20 +634,20 @@ class CensoringDistributionEstimator(SurvivalFunctionEstimator):
581
634
  return self
582
635
 
583
636
  def predict_ipcw(self, y):
584
- """Return inverse probability of censoring weights at given time points.
637
+ r"""Return inverse probability of censoring weights at given time points.
585
638
 
586
- :math:`\\omega_i = \\delta_i / \\hat{G}(y_i)`
639
+ :math:`\omega_i = \delta_i / \hat{G}(y_i)`
587
640
 
588
641
  Parameters
589
642
  ----------
590
643
  y : structured array, shape = (n_samples,)
591
- A structured array containing the binary event indicator
592
- as first field, and time of event or time of censoring as
593
- second field.
644
+ A structured array with two fields. The first field is a boolean
645
+ where ``True`` indicates an event and ``False`` indicates right-censoring.
646
+ The second field is a float with the time of event or time of censoring.
594
647
 
595
648
  Returns
596
649
  -------
597
- ipcw : array, shape = (n_samples,)
650
+ ipcw : ndarray, shape = (n_samples,)
598
651
  Inverse probability of censoring weights.
599
652
  """
600
653
  event, time = check_y_survival(y)
@@ -638,14 +691,14 @@ def cumulative_incidence_competing_risks(
638
691
 
639
692
  Parameters
640
693
  ----------
641
- event : array-like, shape = (n_samples,)
642
- Contains event indicators.
694
+ event : array-like, shape = (n_samples,), dtype = int
695
+ Contains event indicators. A value of 0 indicates right-censoring,
696
+ while a positive integer from 1 to `n_risks` corresponds to a specific risk.
697
+ `n_risks` is the total number of different risks.
698
+ It assumes there are events for all possible risks.
643
699
 
644
700
  time_exit : array-like, shape = (n_samples,)
645
- Contains event/censoring times. '0' indicates right-censoring.
646
- Positive integers (between 1 and n_risks, n_risks being the total number of different risks)
647
- indicate the possible different risks.
648
- It assumes there are events for all possible risks.
701
+ Contains event or censoring times.
649
702
 
650
703
  time_min : float, optional, default: None
651
704
  Compute estimator conditional on survival at least up to
@@ -660,23 +713,24 @@ def cumulative_incidence_competing_risks(
660
713
  If "log-log", estimate confidence intervals using
661
714
  the log hazard or :math:`log(-log(S(t)))`.
662
715
 
663
- var_type : None or one of {'Aalen', 'Dinse', 'Dinse_Approx'}, optional, default: 'Aalen'
716
+ var_type : {'Aalen', 'Dinse', 'Dinse_Approx'}, optional, default: 'Aalen'
664
717
  The method for estimating the variance of the estimator.
665
718
  See [2]_, [3]_ and [4]_ for each of the methods.
666
719
  Only used if `conf_type` is not None.
667
720
 
668
721
  Returns
669
722
  -------
670
- time : array, shape = (n_times,)
723
+ time : ndarray, shape = (n_times,)
671
724
  Unique times.
672
725
 
673
- cum_incidence : array, shape = (n_risks + 1, n_times)
674
- Cumulative incidence at each unique time point.
675
- The first dimension indicates total risk (``cum_incidence[0]``),
676
- the dimension `i=1,...,n_risks` the incidence for each competing risk.
726
+ cum_incidence : ndarray, shape = (n_risks + 1, n_times)
727
+ Cumulative incidence for each risk. The first row (``cum_incidence[0]``)
728
+ is the cumulative incidence of any risk (total risk). The remaining
729
+ rows (``cum_incidence[1:]``) are the cumulative incidences for each
730
+ competing risk.
677
731
 
678
- conf_int : array, shape = (n_risks + 1, 2, n_times)
679
- Pointwise confidence interval (second axis) of the Kaplan-Meier estimator
732
+ conf_int : ndarray, shape = (n_risks + 1, 2, n_times)
733
+ Pointwise confidence interval (second axis) of the cumulative incidence function
680
734
  at each unique time point (last axis)
681
735
  for all possible risks (first axis), including overall risk (``conf_int[0]``).
682
736
  Only provided if `conf_type` is not None.
@@ -685,20 +739,35 @@ def cumulative_incidence_competing_risks(
685
739
  --------
686
740
  Creating cumulative incidence curves:
687
741
 
688
- >>> from sksurv.datasets import load_bmt
689
- >>> dis, bmt_df = load_bmt()
690
- >>> event = bmt_df["status"]
691
- >>> time = bmt_df["ftime"]
692
- >>> n_risks = event.max()
693
- >>> x, y, conf_int = cumulative_incidence_competing_risks(event, time, conf_type="log-log")
694
- >>> plt.step(x, y[0], where="post", label="Total risk")
695
- >>> plt.fill_between(x, conf_int[0, 0], conf_int[0, 1], alpha=0.25, step="post")
696
- >>> for i in range(1, n_risks + 1):
697
- >>> plt.step(x, y[i], where="post", label=f"{i}-risk")
698
- >>> plt.fill_between(x, conf_int[i, 0], conf_int[i, 1], alpha=0.25, step="post")
699
- >>> plt.ylim(0, 1)
700
- >>> plt.legend()
701
- >>> plt.show()
742
+ .. plot::
743
+
744
+ >>> import matplotlib.pyplot as plt
745
+ >>> from sksurv.datasets import load_bmt
746
+ >>> from sksurv.nonparametric import cumulative_incidence_competing_risks
747
+ >>>
748
+ >>> dis, bmt_df = load_bmt()
749
+ >>> event = bmt_df["status"]
750
+ >>> time = bmt_df["ftime"]
751
+ >>> n_risks = event.max()
752
+ >>>
753
+ >>> x, y, conf_int = cumulative_incidence_competing_risks(
754
+ ... event, time, conf_type="log-log"
755
+ ... )
756
+ >>>
757
+ >>> plt.step(x, y[0], where="post", label="Total risk")
758
+ [...]
759
+ >>> plt.fill_between(x, conf_int[0, 0], conf_int[0, 1], alpha=0.25, step="post")
760
+ <matplotlib.collections.PolyCollection object at 0x...>
761
+ >>> for i in range(1, n_risks + 1):
762
+ ... plt.step(x, y[i], where="post", label=f"{i}-risk")
763
+ ... plt.fill_between(x, conf_int[i, 0], conf_int[i, 1], alpha=0.25, step="post")
764
+ [...]
765
+ <matplotlib.collections.PolyCollection object at 0x...>
766
+ >>> plt.ylim(0, 1)
767
+ (0.0, 1.0)
768
+ >>> plt.legend()
769
+ <matplotlib.legend.Legend object at 0x...>
770
+ >>> plt.show() # doctest: +SKIP
702
771
 
703
772
  References
704
773
  ----------
sksurv/preprocessing.py CHANGED
@@ -19,40 +19,60 @@ __all__ = ["OneHotEncoder"]
19
19
 
20
20
 
21
21
  def check_columns_exist(actual, expected):
22
+ """Check if all expected columns are present in a dataframe.
23
+
24
+ Parameters
25
+ ----------
26
+ actual : pandas.Index
27
+ The actual columns of a dataframe.
28
+ expected : pandas.Index
29
+ The expected columns.
30
+
31
+ Raises
32
+ ------
33
+ ValueError
34
+ If any of the expected columns are missing from the actual columns.
35
+ """
22
36
  missing_features = expected.difference(actual)
23
37
  if len(missing_features) != 0:
24
38
  raise ValueError(f"{len(missing_features)} features are missing from data: {missing_features.tolist()}")
25
39
 
26
40
 
27
41
  class OneHotEncoder(BaseEstimator, TransformerMixin):
28
- """Encode categorical columns with `M` categories into `M-1` columns according
29
- to the one-hot scheme.
42
+ """Encode categorical features using a one-hot scheme.
30
43
 
31
- The order of non-categorical columns is preserved, encoded columns are inserted
32
- inplace of the original column.
44
+ This transformer only works on pandas DataFrames. It identifies columns
45
+ with `category` or `object` data type as categorical features.
46
+ The features are encoded using a one-hot (or dummy) encoding scheme, which
47
+ creates a binary column for each category. By default, one category per feature
48
+ is dropped. A column with `M` categories is encoded as `M-1` integer columns
49
+ according to the one-hot scheme.
50
+
51
+ The order of non-categorical columns is preserved. Encoded columns are inserted
52
+ in place of the original column.
33
53
 
34
54
  Parameters
35
55
  ----------
36
- allow_drop : boolean, optional, default: True
56
+ allow_drop : bool, optional, default: True
37
57
  Whether to allow dropping categorical columns that only consist
38
58
  of a single category.
39
59
 
40
60
  Attributes
41
61
  ----------
42
62
  feature_names_ : pandas.Index
43
- List of encoded columns.
63
+ Names of categorical features that were encoded.
44
64
 
45
65
  categories_ : dict
46
- Categories of encoded columns.
66
+ A dictionary mapping each categorical feature name to a list of its
67
+ categories.
47
68
 
48
- encoded_columns_ : list
49
- Name of columns after encoding.
50
- Includes names of non-categorical columns.
69
+ encoded_columns_ : pandas.Index
70
+ The full list of feature names in the transformed output.
51
71
 
52
72
  n_features_in_ : int
53
73
  Number of features seen during ``fit``.
54
74
 
55
- feature_names_in_ : ndarray of shape (`n_features_in_`,)
75
+ feature_names_in_ : ndarray, shape = (`n_features_in_`,)
56
76
  Names of features seen during ``fit``. Defined only when `X`
57
77
  has feature names that are all strings.
58
78
  """
@@ -61,18 +81,20 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
61
81
  self.allow_drop = allow_drop
62
82
 
63
83
  def fit(self, X, y=None): # pylint: disable=unused-argument
64
- """Retrieve categorical columns.
84
+ """Determine which features are categorical and should be one-hot encoded.
65
85
 
66
86
  Parameters
67
87
  ----------
68
88
  X : pandas.DataFrame
69
- Data to encode.
70
- y :
71
- Ignored. For compatibility with Pipeline.
89
+ The data to determine categorical features from.
90
+ y : None
91
+ Ignored. This parameter exists only for compatibility with
92
+ :class:`sklearn.pipeline.Pipeline`.
93
+
72
94
  Returns
73
95
  -------
74
96
  self : object
75
- Returns self
97
+ Returns the instance itself.
76
98
  """
77
99
  self.fit_transform(X)
78
100
  return self
@@ -81,21 +103,27 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
81
103
  return encode_categorical(X, columns=columns_to_encode, allow_drop=self.allow_drop)
82
104
 
83
105
  def fit_transform(self, X, y=None, **fit_params): # pylint: disable=unused-argument
84
- """Convert categorical columns to numeric values.
106
+ """Fit to data, then transform it.
107
+
108
+ Fits the transformer to ``X`` by identifying categorical features and
109
+ then returns a transformed version of ``X`` with categorical features
110
+ one-hot encoded.
85
111
 
86
112
  Parameters
87
113
  ----------
88
114
  X : pandas.DataFrame
89
- Data to encode.
90
- y :
91
- Ignored. For compatibility with TransformerMixin.
92
- fit_params :
93
- Ignored. For compatibility with TransformerMixin.
115
+ The data to fit and transform.
116
+ y : None, optional
117
+ Ignored. This parameter exists only for compatibility with
118
+ :class:`sklearn.pipeline.Pipeline`.
119
+ fit_params : dict, optional
120
+ Ignored. This parameter exists only for compatibility with
121
+ :class:`sklearn.pipeline.Pipeline`.
94
122
 
95
123
  Returns
96
124
  -------
97
125
  Xt : pandas.DataFrame
98
- Encoded data.
126
+ The transformed data.
99
127
  """
100
128
  _check_feature_names(self, X, reset=True)
101
129
  _check_n_features(self, X, reset=True)
@@ -108,17 +136,17 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
108
136
  return x_dummy
109
137
 
110
138
  def transform(self, X):
111
- """Convert categorical columns to numeric values.
139
+ """Transform ``X`` by one-hot encoding categorical features.
112
140
 
113
141
  Parameters
114
142
  ----------
115
143
  X : pandas.DataFrame
116
- Data to encode.
144
+ The data to transform.
117
145
 
118
146
  Returns
119
147
  -------
120
148
  Xt : pandas.DataFrame
121
- Encoded data.
149
+ The transformed data.
122
150
  """
123
151
  check_is_fitted(self, "encoded_columns_")
124
152
  _check_n_features(self, X, reset=False)
@@ -136,7 +164,7 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
136
164
 
137
165
  Parameters
138
166
  ----------
139
- input_features : array-like of str or None, default=None
167
+ input_features : array-like of str or None, default: None
140
168
  Input features.
141
169
 
142
170
  - If `input_features` is `None`, then `feature_names_in_` is
Binary file
Binary file