scikit-survival 0.24.1-cp310-cp310-macosx_11_0_arm64.whl → 0.25.0-cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scikit_survival-0.25.0.dist-info/METADATA +185 -0
- scikit_survival-0.25.0.dist-info/RECORD +58 -0
- {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/WHEEL +1 -1
- sksurv/__init__.py +51 -6
- sksurv/base.py +12 -2
- sksurv/bintrees/_binarytrees.cpython-310-darwin.so +0 -0
- sksurv/column.py +33 -29
- sksurv/compare.py +22 -22
- sksurv/datasets/base.py +45 -20
- sksurv/docstrings.py +99 -0
- sksurv/ensemble/_coxph_loss.cpython-310-darwin.so +0 -0
- sksurv/ensemble/boosting.py +116 -168
- sksurv/ensemble/forest.py +94 -151
- sksurv/functions.py +29 -29
- sksurv/io/arffread.py +34 -3
- sksurv/io/arffwrite.py +38 -2
- sksurv/kernels/_clinical_kernel.cpython-310-darwin.so +0 -0
- sksurv/kernels/clinical.py +33 -13
- sksurv/linear_model/_coxnet.cpython-310-darwin.so +0 -0
- sksurv/linear_model/aft.py +14 -11
- sksurv/linear_model/coxnet.py +138 -89
- sksurv/linear_model/coxph.py +102 -83
- sksurv/meta/ensemble_selection.py +91 -9
- sksurv/meta/stacking.py +47 -26
- sksurv/metrics.py +257 -224
- sksurv/nonparametric.py +150 -81
- sksurv/preprocessing.py +55 -27
- sksurv/svm/_minlip.cpython-310-darwin.so +0 -0
- sksurv/svm/_prsvm.cpython-310-darwin.so +0 -0
- sksurv/svm/minlip.py +160 -79
- sksurv/svm/naive_survival_svm.py +63 -34
- sksurv/svm/survival_svm.py +103 -103
- sksurv/tree/_criterion.cpython-310-darwin.so +0 -0
- sksurv/tree/tree.py +170 -84
- sksurv/util.py +80 -26
- scikit_survival-0.24.1.dist-info/METADATA +0 -889
- scikit_survival-0.24.1.dist-info/RECORD +0 -57
- {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/licenses/COPYING +0 -0
- {scikit_survival-0.24.1.dist-info → scikit_survival-0.25.0.dist-info}/top_level.txt +0 -0
sksurv/nonparametric.py
CHANGED
@@ -31,36 +31,36 @@ __all__ = [
 
 
 def _compute_counts(event, time, order=None):
-    """Count right
+    """Count right-censored and uncensored samples at each unique time point.
 
     Parameters
     ----------
-    event :
+    event : ndarray
         Boolean event indicator.
         Integer in the case of multiple risks.
         Zero means right-censored event.
         Positive values for each of the possible risk events.
 
-    time :
+    time : ndarray
         Survival time or time of censoring.
 
-    order :
+    order : ndarray or None
         Indices to order time in ascending order.
         If None, order will be computed.
 
     Returns
     -------
-    times :
+    times : ndarray
         Unique time points.
 
-    n_events :
+    n_events : ndarray
         Number of events at each time point.
         2D array with shape `(n_unique_time_points, n_risks + 1)` in the case of competing risks.
 
-    n_at_risk :
+    n_at_risk : ndarray
         Number of samples that have not been censored or have not had an event at each time point.
 
-    n_censored :
+    n_censored : ndarray
         Number of censored samples at each time point.
     """
     n_samples = event.shape[0]
@@ -116,29 +116,29 @@ def _compute_counts(event, time, order=None):
 
 
 def _compute_counts_truncated(event, time_enter, time_exit):
-    """Compute counts for left truncated and right
+    """Compute counts for left truncated and right-censored survival data.
 
     Parameters
     ----------
-    event :
+    event : ndarray
         Boolean event indicator.
 
-
+    time_enter : ndarray
         Time when a subject entered the study.
 
-    time_exit :
+    time_exit : ndarray
         Time when a subject left the study due to an
         event or censoring.
 
     Returns
     -------
-    times :
+    times : ndarray
         Unique time points.
 
-    n_events :
+    n_events : ndarray
         Number of events at each time point.
 
-    n_at_risk :
+    n_at_risk : ndarray
         Number of samples that are censored or have an event at each time point.
     """
     if (time_enter > time_exit).any():
@@ -212,6 +212,27 @@ def _ci_logmlog(s, sigma_t, conf_level):
 
 
 def _km_ci_estimator(prob_survival, ratio_var, conf_level, conf_type):
+    """Helper to compute confidence intervals for the Kaplan-Meier estimate.
+
+    Parameters
+    ----------
+    prob_survival : ndarray, shape = (n_times,)
+        Survival probability at each unique time point.
+
+    ratio_var : ndarray, shape = (n_times,)
+        The variance ratio term for each unique time point.
+
+    conf_level : float
+        The level for a two-sided confidence interval.
+
+    conf_type : {'log-log'}
+        The type of confidence intervals to estimate.
+
+    Returns
+    -------
+    ci : ndarray, shape = (2, n_times)
+        Pointwise confidence interval.
+    """
     if conf_type not in {"log-log"}:
         raise ValueError(f"conf_type must be None or a str among {{'log-log'}}, but was {conf_type!r}")
 
@@ -232,17 +253,18 @@ def kaplan_meier_estimator(
     conf_level=0.95,
     conf_type=None,
 ):
-    """Kaplan-Meier
+    """Computes the Kaplan-Meier estimate of the survival function.
 
     See [1]_ for further description.
 
     Parameters
     ----------
     event : array-like, shape = (n_samples,)
-
+        A boolean array where ``True`` indicates an event and ``False`` indicates
+        right-censoring.
 
     time_exit : array-like, shape = (n_samples,)
-
+        Time of event or censoring.
 
     time_enter : array-like, shape = (n_samples,), optional
         Contains time when each individual entered the study for
@@ -270,14 +292,14 @@ def kaplan_meier_estimator(
 
     Returns
     -------
-    time :
+    time : ndarray, shape = (n_times,)
         Unique times.
 
-    prob_survival :
+    prob_survival : ndarray, shape = (n_times,)
         Survival probability at each unique time point.
         If `time_enter` is provided, estimates are conditional probabilities.
 
-    conf_int :
+    conf_int : ndarray, shape = (2, n_times)
         Pointwise confidence interval of the Kaplan-Meier estimator
         at each unique time point.
         Only provided if `conf_type` is not None.
@@ -286,11 +308,23 @@ def kaplan_meier_estimator(
     --------
     Creating a Kaplan-Meier curve:
 
-
-
-
-
-
+    .. plot::
+
+        >>> import matplotlib.pyplot as plt
+        >>> from sksurv.datasets import load_veterans_lung_cancer
+        >>> from sksurv.nonparametric import kaplan_meier_estimator
+        >>>
+        >>> _, y = load_veterans_lung_cancer()
+        >>> time, prob_surv, conf_int = kaplan_meier_estimator(
+        ...     y["Status"], y["Survival_in_days"], conf_type="log-log"
+        ... )
+        >>> plt.step(time, prob_surv, where="post")
+        [...]
+        >>> plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step="post")
+        <matplotlib.collections.PolyCollection object at 0x...>
+        >>> plt.ylim(0, 1)
+        (0.0, 1.0)
+        >>> plt.show()  # doctest: +SKIP
 
     See also
     --------
@@ -359,26 +393,44 @@ def kaplan_meier_estimator(
 
 
 def nelson_aalen_estimator(event, time):
-    """Nelson-Aalen
+    """Computes the Nelson-Aalen estimate of the cumulative hazard function.
 
     See [1]_, [2]_ for further description.
 
     Parameters
     ----------
     event : array-like, shape = (n_samples,)
-
+        A boolean array where ``True`` indicates an event and ``False`` indicates
+        right-censoring.
 
     time : array-like, shape = (n_samples,)
-
+        Time of event or censoring.
 
     Returns
     -------
-    time :
+    time : ndarray, shape = (n_times,)
         Unique times.
 
-    cum_hazard :
+    cum_hazard : ndarray, shape = (n_times,)
         Cumulative hazard at each unique time point.
 
+    Examples
+    --------
+    Creating a cumulative hazard curve:
+
+    .. plot::
+
+        >>> import matplotlib.pyplot as plt
+        >>> from sksurv.datasets import load_aids
+        >>> from sksurv.nonparametric import nelson_aalen_estimator
+        >>>
+        >>> _, y = load_aids(endpoint="death")
+        >>> time, cum_hazard = nelson_aalen_estimator(y["censor_d"], y["time_d"])
+        >>>
+        >>> plt.step(time, cum_hazard, where="post")
+        [...]
+        >>> plt.show()  # doctest: +SKIP
+
     References
     ----------
     .. [1] Nelson, W., "Theory and applications of hazard plotting for censored failure data",
@@ -401,15 +453,16 @@ def ipc_weights(event, time):
 
     Parameters
     ----------
-    event : array, shape = (n_samples,)
-
+    event : array-like, shape = (n_samples,)
+        A boolean array where ``True`` indicates an event and ``False`` indicates
+        right-censoring.
 
-    time : array, shape = (n_samples,)
+    time : array-like, shape = (n_samples,)
         Time when a subject experienced an event or was censored.
 
     Returns
     -------
-    weights :
+    weights : ndarray, shape = (n_samples,)
         inverse probability of censoring weights
 
     See also
@@ -469,9 +522,9 @@ class SurvivalFunctionEstimator(BaseEstimator):
         Parameters
         ----------
         y : structured array, shape = (n_samples,)
-            A structured array
-
-            second field.
+            A structured array with two fields. The first field is a boolean
+            where ``True`` indicates an event and ``False`` indicates right-censoring.
+            The second field is a float with the time of event or time of censoring.
 
         Returns
         -------
@@ -493,13 +546,13 @@ class SurvivalFunctionEstimator(BaseEstimator):
         return self
 
     def predict_proba(self, time, return_conf_int=False):
-        """Return probability of
+        r"""Return probability of remaining event-free at given time points.
 
-        :math
+        :math:`\hat{S}(t) = P(T > t)`
 
         Parameters
         ----------
-        time : array, shape = (n_samples,)
+        time : array-like, shape = (n_samples,)
             Time to estimate probability at.
 
         return_conf_int : bool, optional, default: False
@@ -510,10 +563,10 @@ class SurvivalFunctionEstimator(BaseEstimator):
 
         Returns
         -------
-        prob :
-            Probability of
+        prob : ndarray, shape = (n_samples,)
+            Probability of remaining event-free at the given time points.
 
-        conf_int :
+        conf_int : ndarray, shape = (2, n_samples)
             Pointwise confidence interval at the passed time points.
             Only provided if `return_conf_int` is True.
         """
@@ -561,9 +614,9 @@ class CensoringDistributionEstimator(SurvivalFunctionEstimator):
         Parameters
         ----------
         y : structured array, shape = (n_samples,)
-            A structured array
-
-            second field.
+            A structured array with two fields. The first field is a boolean
+            where ``True`` indicates an event and ``False`` indicates right-censoring.
+            The second field is a float with the time of event or time of censoring.
 
         Returns
        -------
@@ -581,20 +634,20 @@ class CensoringDistributionEstimator(SurvivalFunctionEstimator):
         return self
 
     def predict_ipcw(self, y):
-        """Return inverse probability of censoring weights at given time points.
+        r"""Return inverse probability of censoring weights at given time points.
 
-        :math
+        :math:`\omega_i = \delta_i / \hat{G}(y_i)`
 
         Parameters
         ----------
         y : structured array, shape = (n_samples,)
-            A structured array
-
-            second field.
+            A structured array with two fields. The first field is a boolean
+            where ``True`` indicates an event and ``False`` indicates right-censoring.
+            The second field is a float with the time of event or time of censoring.
 
         Returns
         -------
-        ipcw :
+        ipcw : ndarray, shape = (n_samples,)
             Inverse probability of censoring weights.
         """
         event, time = check_y_survival(y)
@@ -638,14 +691,14 @@ def cumulative_incidence_competing_risks(
 
     Parameters
    ----------
-    event : array-like, shape = (n_samples,)
-        Contains event indicators.
+    event : array-like, shape = (n_samples,), dtype = int
+        Contains event indicators. A value of 0 indicates right-censoring,
+        while a positive integer from 1 to `n_risks` corresponds to a specific risk.
+        `n_risks` is the total number of different risks.
+        It assumes there are events for all possible risks.
 
     time_exit : array-like, shape = (n_samples,)
-        Contains event
-        Positive integers (between 1 and n_risks, n_risks being the total number of different risks)
-        indicate the possible different risks.
-        It assumes there are events for all possible risks.
+        Contains event or censoring times.
 
     time_min : float, optional, default: None
         Compute estimator conditional on survival at least up to
@@ -660,23 +713,24 @@ def cumulative_incidence_competing_risks(
         If "log-log", estimate confidence intervals using
         the log hazard or :math:`log(-log(S(t)))`.
 
-    var_type :
+    var_type : {'Aalen', 'Dinse', 'Dinse_Approx'}, optional, default: 'Aalen'
         The method for estimating the variance of the estimator.
         See [2]_, [3]_ and [4]_ for each of the methods.
         Only used if `conf_type` is not None.
 
     Returns
     -------
-    time :
+    time : ndarray, shape = (n_times,)
         Unique times.
 
-    cum_incidence :
-        Cumulative incidence
-
-
+    cum_incidence : ndarray, shape = (n_risks + 1, n_times)
+        Cumulative incidence for each risk. The first row (``cum_incidence[0]``)
+        is the cumulative incidence of any risk (total risk). The remaining
+        rows (``cum_incidence[1:]``) are the cumulative incidences for each
+        competing risk.
 
-    conf_int :
-        Pointwise confidence interval (second axis) of the
+    conf_int : ndarray, shape = (n_risks + 1, 2, n_times)
+        Pointwise confidence interval (second axis) of the cumulative incidence function
         at each unique time point (last axis)
         for all possible risks (first axis), including overall risk (``conf_int[0]``).
         Only provided if `conf_type` is not None.
@@ -685,20 +739,35 @@ def cumulative_incidence_competing_risks(
     --------
     Creating cumulative incidence curves:
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    .. plot::
+
+        >>> import matplotlib.pyplot as plt
+        >>> from sksurv.datasets import load_bmt
+        >>> from sksurv.nonparametric import cumulative_incidence_competing_risks
+        >>>
+        >>> dis, bmt_df = load_bmt()
+        >>> event = bmt_df["status"]
+        >>> time = bmt_df["ftime"]
+        >>> n_risks = event.max()
+        >>>
+        >>> x, y, conf_int = cumulative_incidence_competing_risks(
+        ...     event, time, conf_type="log-log"
+        ... )
+        >>>
+        >>> plt.step(x, y[0], where="post", label="Total risk")
+        [...]
+        >>> plt.fill_between(x, conf_int[0, 0], conf_int[0, 1], alpha=0.25, step="post")
+        <matplotlib.collections.PolyCollection object at 0x...>
+        >>> for i in range(1, n_risks + 1):
+        ...     plt.step(x, y[i], where="post", label=f"{i}-risk")
+        ...     plt.fill_between(x, conf_int[i, 0], conf_int[i, 1], alpha=0.25, step="post")
+        [...]
+        <matplotlib.collections.PolyCollection object at 0x...>
+        >>> plt.ylim(0, 1)
+        (0.0, 1.0)
+        >>> plt.legend()
+        <matplotlib.legend.Legend object at 0x...>
+        >>> plt.show()  # doctest: +SKIP
 
     References
     ----------
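The `predict_ipcw` docstring above now spells out the weight formula ω_i = δ_i / Ĝ(y_i). As a quick illustration of that API, here is a minimal sketch (not part of the diff): it uses only the public classes shown above plus `sksurv.util.Surv`, and the toy data is made up.

import numpy as np

from sksurv.nonparametric import CensoringDistributionEstimator
from sksurv.util import Surv

# Toy right-censored data: True marks an observed event, False marks censoring.
event = np.array([True, False, True, True, False, True])
time = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = Surv.from_arrays(event=event, time=time)

# Fit the Kaplan-Meier estimate G(t) of the censoring distribution, then
# compute w_i = delta_i / G(y_i); censored samples receive weight 0.
cens = CensoringDistributionEstimator().fit(y)
weights = cens.predict_ipcw(y)
print(weights.shape)  # one weight per sample

The same `Surv.from_arrays` structured array can be passed to `SurvivalFunctionEstimator.fit` and `predict_proba` documented above.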
sksurv/preprocessing.py
CHANGED
@@ -19,40 +19,60 @@ __all__ = ["OneHotEncoder"]
 
 
 def check_columns_exist(actual, expected):
+    """Check if all expected columns are present in a dataframe.
+
+    Parameters
+    ----------
+    actual : pandas.Index
+        The actual columns of a dataframe.
+    expected : pandas.Index
+        The expected columns.
+
+    Raises
+    ------
+    ValueError
+        If any of the expected columns are missing from the actual columns.
+    """
     missing_features = expected.difference(actual)
     if len(missing_features) != 0:
         raise ValueError(f"{len(missing_features)} features are missing from data: {missing_features.tolist()}")
 
 
 class OneHotEncoder(BaseEstimator, TransformerMixin):
-    """Encode categorical
-    to the one-hot scheme.
+    """Encode categorical features using a one-hot scheme.
 
-
-
+    This transformer only works on pandas DataFrames. It identifies columns
+    with `category` or `object` data type as categorical features.
+    The features are encoded using a one-hot (or dummy) encoding scheme, which
+    creates a binary column for each category. By default, one category per feature
+    is dropped. a column with `M` categories is encoded as `M-1` integer columns
+    according to the one-hot scheme.
+
+    The order of non-categorical columns is preserved. Encoded columns are inserted
+    in place of the original column.
 
     Parameters
     ----------
-    allow_drop :
+    allow_drop : bool, optional, default: True
         Whether to allow dropping categorical columns that only consist
         of a single category.
 
     Attributes
     ----------
     feature_names_ : pandas.Index
-
+        Names of categorical features that were encoded.
 
     categories_ : dict
-
+        A dictionary mapping each categorical feature name to a list of its
+        categories.
 
-    encoded_columns_ :
-
-        Includes names of non-categorical columns.
+    encoded_columns_ : pandas.Index
+        The full list of feature names in the transformed output.
 
     n_features_in_ : int
         Number of features seen during ``fit``.
 
-    feature_names_in_ : ndarray
+    feature_names_in_ : ndarray, shape = (`n_features_in_`,)
         Names of features seen during ``fit``. Defined only when `X`
         has feature names that are all strings.
     """
@@ -61,18 +81,20 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
         self.allow_drop = allow_drop
 
     def fit(self, X, y=None):  # pylint: disable=unused-argument
-        """
+        """Determine which features are categorical and should be one-hot encoded.
 
         Parameters
         ----------
         X : pandas.DataFrame
-
-        y :
-            Ignored.
+            The data to determine categorical features from.
+        y : None
+            Ignored. This parameter exists only for compatibility with
+            :class:`sklearn.pipeline.Pipeline`.
+
         Returns
         -------
         self : object
-            Returns
+            Returns the instance itself.
         """
         self.fit_transform(X)
         return self
@@ -81,21 +103,27 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
         return encode_categorical(X, columns=columns_to_encode, allow_drop=self.allow_drop)
 
     def fit_transform(self, X, y=None, **fit_params):  # pylint: disable=unused-argument
-        """
+        """Fit to data, then transform it.
+
+        Fits the transformer to ``X`` by identifying categorical features and
+        then returns a transformed version of ``X`` with categorical features
+        one-hot encoded.
 
         Parameters
         ----------
         X : pandas.DataFrame
-
-        y :
-            Ignored.
-
-
+            The data to fit and transform.
+        y : None, optional
+            Ignored. This parameter exists only for compatibility with
+            :class:`sklearn.pipeline.Pipeline`.
+        fit_params : dict, optional
+            Ignored. This parameter exists only for compatibility with
+            :class:`sklearn.pipeline.Pipeline`.
 
         Returns
         -------
         Xt : pandas.DataFrame
-
+            The transformed data.
         """
         _check_feature_names(self, X, reset=True)
         _check_n_features(self, X, reset=True)
@@ -108,17 +136,17 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
         return x_dummy
 
     def transform(self, X):
-        """
+        """Transform ``X`` by one-hot encoding categorical features.
 
         Parameters
         ----------
         X : pandas.DataFrame
-
+            The data to transform.
 
         Returns
         -------
         Xt : pandas.DataFrame
-
+            The transformed data.
         """
         check_is_fitted(self, "encoded_columns_")
         _check_n_features(self, X, reset=False)
@@ -136,7 +164,7 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
 
         Parameters
         ----------
-        input_features : array-like of str or None, default
+        input_features : array-like of str or None, default: None
            Input features.
 
            - If `input_features` is `None`, then `feature_names_in_` is

Binary files changed (contents not shown).
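For context on the reworked OneHotEncoder docstrings above, here is a minimal usage sketch (not part of the diff); the column names and values are made up, and it relies only on the behaviour described in the class docstring.

import pandas as pd

from sksurv.preprocessing import OneHotEncoder

frame = pd.DataFrame(
    {
        "age": [55.0, 61.0, 47.0, 52.0],
        # Columns with `category` (or `object`) dtype are treated as categorical.
        "stage": pd.Categorical(["I", "II", "III", "II"]),
    }
)

encoder = OneHotEncoder()  # allow_drop=True is the default
encoded = encoder.fit_transform(frame)

# "age" keeps its position; "stage" (M=3 categories) is replaced in place
# by M-1 = 2 indicator columns, since one category per feature is dropped.
print(encoded.columns.tolist())
print(encoder.encoded_columns_)  # full list of output feature names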