lifelines 0.27.7__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lifelines/datasets/__init__.py +2 -2
- lifelines/exceptions.py +4 -0
- lifelines/fitters/__init__.py +33 -20
- lifelines/fitters/aalen_johansen_fitter.py +44 -0
- lifelines/fitters/breslow_fleming_harrington_fitter.py +9 -1
- lifelines/fitters/cox_time_varying_fitter.py +15 -10
- lifelines/fitters/coxph_fitter.py +17 -13
- lifelines/fitters/generalized_gamma_fitter.py +6 -5
- lifelines/fitters/kaplan_meier_fitter.py +9 -3
- lifelines/fitters/mixins.py +8 -3
- lifelines/fitters/nelson_aalen_fitter.py +2 -2
- lifelines/plotting.py +163 -91
- lifelines/utils/__init__.py +5 -7
- lifelines/version.py +1 -1
- {lifelines-0.27.7.dist-info → lifelines-0.28.0.dist-info}/METADATA +9 -11
- {lifelines-0.27.7.dist-info → lifelines-0.28.0.dist-info}/RECORD +19 -22
- {lifelines-0.27.7.dist-info → lifelines-0.28.0.dist-info}/WHEEL +1 -1
- lifelines/datasets/ACTG175.csv +0 -2140
- lifelines/metrics.py +0 -60
- lifelines/utils/sklearn_adapter.py +0 -135
- {lifelines-0.27.7.dist-info → lifelines-0.28.0.dist-info}/LICENSE +0 -0
- {lifelines-0.27.7.dist-info → lifelines-0.28.0.dist-info}/top_level.txt +0 -0
lifelines/plotting.py
CHANGED
|
@@ -46,26 +46,20 @@ def create_scipy_stats_model_from_lifelines_model(model):
|
|
|
46
46
|
raise TypeError(
|
|
47
47
|
"Cannot use qq-plot with this model. See notes here: https://lifelines.readthedocs.io/en/latest/Examples.html?highlight=qq_plot#selecting-a-parametric-model-using-qq-plots"
|
|
48
48
|
)
|
|
49
|
-
|
|
50
49
|
if dist == "weibull":
|
|
51
50
|
scipy_dist = "weibull_min"
|
|
52
51
|
sparams = (model.rho_, 0, model.lambda_)
|
|
53
|
-
|
|
54
52
|
elif dist == "lognormal":
|
|
55
53
|
scipy_dist = "lognorm"
|
|
56
54
|
sparams = (model.sigma_, 0, np.exp(model.mu_))
|
|
57
|
-
|
|
58
55
|
elif dist == "loglogistic":
|
|
59
56
|
scipy_dist = "fisk"
|
|
60
57
|
sparams = (model.beta_, 0, model.alpha_)
|
|
61
|
-
|
|
62
58
|
elif dist == "exponential":
|
|
63
59
|
scipy_dist = "expon"
|
|
64
60
|
sparams = (0, model.lambda_)
|
|
65
|
-
|
|
66
61
|
else:
|
|
67
62
|
raise NotImplementedError("Distribution not implemented in SciPy")
|
|
68
|
-
|
|
69
63
|
return getattr(stats, scipy_dist)(*sparams)
|
|
70
64
|
|
|
71
65
|
|
|
@@ -85,30 +79,44 @@ def cdf_plot(model, timeline=None, ax=None, **plot_kwargs):
|
|
|
85
79
|
|
|
86
80
|
if ax is None:
|
|
87
81
|
ax = plt.gca()
|
|
88
|
-
|
|
89
82
|
if timeline is None:
|
|
90
83
|
timeline = model.timeline
|
|
91
|
-
|
|
92
84
|
COL_EMP = "empirical CDF"
|
|
93
85
|
|
|
94
86
|
if CensoringType.is_left_censoring(model):
|
|
95
87
|
empirical_kmf = KaplanMeierFitter().fit_left_censoring(
|
|
96
|
-
model.durations,
|
|
88
|
+
model.durations,
|
|
89
|
+
model.event_observed,
|
|
90
|
+
label=COL_EMP,
|
|
91
|
+
timeline=timeline,
|
|
92
|
+
weights=model.weights,
|
|
93
|
+
entry=model.entry,
|
|
97
94
|
)
|
|
98
95
|
elif CensoringType.is_right_censoring(model):
|
|
99
96
|
empirical_kmf = KaplanMeierFitter().fit_right_censoring(
|
|
100
|
-
model.durations,
|
|
97
|
+
model.durations,
|
|
98
|
+
model.event_observed,
|
|
99
|
+
label=COL_EMP,
|
|
100
|
+
timeline=timeline,
|
|
101
|
+
weights=model.weights,
|
|
102
|
+
entry=model.entry,
|
|
101
103
|
)
|
|
102
104
|
elif CensoringType.is_interval_censoring(model):
|
|
103
105
|
empirical_kmf = KaplanMeierFitter().fit_interval_censoring(
|
|
104
|
-
model.lower_bound,
|
|
106
|
+
model.lower_bound,
|
|
107
|
+
model.upper_bound,
|
|
108
|
+
label=COL_EMP,
|
|
109
|
+
timeline=timeline,
|
|
110
|
+
weights=model.weights,
|
|
111
|
+
entry=model.entry,
|
|
105
112
|
)
|
|
106
|
-
|
|
107
113
|
empirical_kmf.plot_cumulative_density(ax=ax, **plot_kwargs)
|
|
108
114
|
|
|
109
115
|
dist = get_distribution_name_of_lifelines_model(model)
|
|
110
116
|
dist_object = create_scipy_stats_model_from_lifelines_model(model)
|
|
111
|
-
ax.plot(
|
|
117
|
+
ax.plot(
|
|
118
|
+
timeline, dist_object.cdf(timeline), label="fitted %s" % dist, **plot_kwargs
|
|
119
|
+
)
|
|
112
120
|
ax.legend()
|
|
113
121
|
return ax
|
|
114
122
|
|
|
@@ -166,14 +174,12 @@ def rmst_plot(model, model2=None, t=np.inf, ax=None, text_position=None, **plot_
|
|
|
166
174
|
|
|
167
175
|
if ax is None:
|
|
168
176
|
ax = plt.gca()
|
|
169
|
-
|
|
170
177
|
rmst = restricted_mean_survival_time(model, t=t)
|
|
171
178
|
c = ax._get_lines.get_next_color()
|
|
172
179
|
model.plot_survival_function(ax=ax, color=c, ci_show=False, **plot_kwargs)
|
|
173
180
|
|
|
174
181
|
if text_position is None:
|
|
175
182
|
text_position = (np.percentile(model.timeline, 10), 0.15)
|
|
176
|
-
|
|
177
183
|
if model2 is not None:
|
|
178
184
|
c2 = ax._get_lines.get_next_color()
|
|
179
185
|
rmst2 = restricted_mean_survival_time(model2, t=t)
|
|
@@ -206,14 +212,26 @@ def rmst_plot(model, model2=None, t=np.inf, ax=None, text_position=None, **plot_
|
|
|
206
212
|
)
|
|
207
213
|
|
|
208
214
|
ax.text(
|
|
209
|
-
text_position[0],
|
|
215
|
+
text_position[0],
|
|
216
|
+
text_position[1],
|
|
217
|
+
"RMST(%s) -\n RMST(%s)=%.3f"
|
|
218
|
+
% (model._label, model2._label, rmst - rmst2),
|
|
210
219
|
) # dynamically pick this.
|
|
211
220
|
else:
|
|
212
221
|
rmst = restricted_mean_survival_time(model, t=t)
|
|
213
|
-
sf_exp_at_limit =
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
222
|
+
sf_exp_at_limit = (
|
|
223
|
+
model.predict(np.append(model.timeline, t)).sort_index().loc[:t]
|
|
224
|
+
)
|
|
225
|
+
ax.fill_between(
|
|
226
|
+
sf_exp_at_limit.index,
|
|
227
|
+
sf_exp_at_limit.values,
|
|
228
|
+
step="post",
|
|
229
|
+
color=c,
|
|
230
|
+
alpha=0.25,
|
|
231
|
+
)
|
|
232
|
+
ax.text(
|
|
233
|
+
text_position[0], text_position[1], "RMST=%.3f" % rmst
|
|
234
|
+
) # dynamically pick this.
|
|
217
235
|
ax.axvline(t, ls="--", color="k")
|
|
218
236
|
ax.set_ylim(0, 1)
|
|
219
237
|
return ax
|
|
@@ -259,7 +277,6 @@ def qq_plot(model, ax=None, scatter_color="k", **plot_kwargs):
|
|
|
259
277
|
|
|
260
278
|
if ax is None:
|
|
261
279
|
ax = plt.gca()
|
|
262
|
-
|
|
263
280
|
dist = get_distribution_name_of_lifelines_model(model)
|
|
264
281
|
dist_object = create_scipy_stats_model_from_lifelines_model(model)
|
|
265
282
|
|
|
@@ -268,21 +285,34 @@ def qq_plot(model, ax=None, scatter_color="k", **plot_kwargs):
|
|
|
268
285
|
|
|
269
286
|
if CensoringType.is_left_censoring(model):
|
|
270
287
|
kmf = KaplanMeierFitter().fit_left_censoring(
|
|
271
|
-
model.durations,
|
|
288
|
+
model.durations,
|
|
289
|
+
model.event_observed,
|
|
290
|
+
label=COL_EMP,
|
|
291
|
+
weights=model.weights,
|
|
292
|
+
entry=model.entry,
|
|
272
293
|
)
|
|
273
294
|
sf, cdf = kmf.survival_function_[COL_EMP], kmf.cumulative_density_[COL_EMP]
|
|
274
295
|
elif CensoringType.is_right_censoring(model):
|
|
275
296
|
kmf = KaplanMeierFitter().fit_right_censoring(
|
|
276
|
-
model.durations,
|
|
297
|
+
model.durations,
|
|
298
|
+
model.event_observed,
|
|
299
|
+
label=COL_EMP,
|
|
300
|
+
weights=model.weights,
|
|
301
|
+
entry=model.entry,
|
|
277
302
|
)
|
|
278
303
|
sf, cdf = kmf.survival_function_[COL_EMP], kmf.cumulative_density_[COL_EMP]
|
|
279
|
-
|
|
280
304
|
elif CensoringType.is_interval_censoring(model):
|
|
281
305
|
kmf = KaplanMeierFitter().fit_interval_censoring(
|
|
282
|
-
model.lower_bound,
|
|
306
|
+
model.lower_bound,
|
|
307
|
+
model.upper_bound,
|
|
308
|
+
label=COL_EMP,
|
|
309
|
+
weights=model.weights,
|
|
310
|
+
entry=model.entry,
|
|
311
|
+
)
|
|
312
|
+
sf, cdf = (
|
|
313
|
+
kmf.survival_function_.mean(1),
|
|
314
|
+
kmf.cumulative_density_[COL_EMP + "_lower"],
|
|
283
315
|
)
|
|
284
|
-
sf, cdf = kmf.survival_function_.mean(1), kmf.cumulative_density_[COL_EMP + "_lower"]
|
|
285
|
-
|
|
286
316
|
q = np.unique(cdf.values)
|
|
287
317
|
|
|
288
318
|
quantiles = qth_survival_times(1 - q, sf)
|
|
@@ -291,7 +321,9 @@ def qq_plot(model, ax=None, scatter_color="k", **plot_kwargs):
|
|
|
291
321
|
|
|
292
322
|
max_, min_ = quantiles[COL_EMP].max(), quantiles[COL_EMP].min()
|
|
293
323
|
|
|
294
|
-
quantiles.plot.scatter(
|
|
324
|
+
quantiles.plot.scatter(
|
|
325
|
+
COL_THEO, COL_EMP, c="none", edgecolor=scatter_color, lw=0.5, ax=ax
|
|
326
|
+
)
|
|
295
327
|
ax.plot([min_, max_], [min_, max_], c="k", ls=":", lw=1.0)
|
|
296
328
|
ax.set_ylim(min_, max_)
|
|
297
329
|
ax.set_xlim(min_, max_)
|
|
@@ -446,30 +478,28 @@ def add_at_risk_counts(
|
|
|
446
478
|
|
|
447
479
|
if ax is None:
|
|
448
480
|
ax = plt.gca()
|
|
449
|
-
|
|
450
481
|
fig = kwargs.pop("fig", None)
|
|
451
482
|
if fig is None:
|
|
452
483
|
fig = plt.gcf()
|
|
453
|
-
|
|
454
484
|
if labels is None:
|
|
455
485
|
labels = [f._label for f in fitters]
|
|
456
486
|
elif labels is False:
|
|
457
487
|
labels = [None] * len(fitters)
|
|
458
|
-
|
|
459
488
|
if rows_to_show is None:
|
|
460
489
|
rows_to_show = ["At risk", "Censored", "Events"]
|
|
461
490
|
else:
|
|
462
491
|
assert all(
|
|
463
492
|
row in ["At risk", "Censored", "Events"] for row in rows_to_show
|
|
464
493
|
), 'must be one of ["At risk", "Censored", "Events"]'
|
|
465
|
-
|
|
466
494
|
n_rows = len(rows_to_show)
|
|
467
495
|
|
|
468
496
|
# Create another axes where we can put size ticks
|
|
469
497
|
ax2 = plt.twiny(ax=ax)
|
|
470
498
|
# Move the ticks below existing axes
|
|
471
499
|
# Appropriate length scaled for 6 inches. Adjust for figure size.
|
|
472
|
-
ax_height = (
|
|
500
|
+
ax_height = (
|
|
501
|
+
ax.get_position().y1 - ax.get_position().y0
|
|
502
|
+
) * fig.get_figheight() # axis height
|
|
473
503
|
ax2_ypos = ypos / ax_height
|
|
474
504
|
|
|
475
505
|
move_spines(ax2, ["bottom"], [ax2_ypos])
|
|
@@ -502,34 +532,56 @@ def add_at_risk_counts(
|
|
|
502
532
|
if at_risk_count_from_start_of_period:
|
|
503
533
|
event_table_slice = f.event_table.assign(at_risk=lambda x: x.at_risk)
|
|
504
534
|
else:
|
|
505
|
-
event_table_slice = f.event_table.assign(
|
|
506
|
-
|
|
535
|
+
event_table_slice = f.event_table.assign(
|
|
536
|
+
at_risk=lambda x: x.at_risk - x.removed
|
|
537
|
+
)
|
|
507
538
|
if not event_table_slice.loc[:tick].empty:
|
|
508
539
|
event_table_slice = (
|
|
509
540
|
event_table_slice.loc[:tick, ["at_risk", "censored", "observed"]]
|
|
510
|
-
.agg(
|
|
511
|
-
|
|
541
|
+
.agg(
|
|
542
|
+
{
|
|
543
|
+
"at_risk": lambda x: x.tail(1).values,
|
|
544
|
+
"censored": "sum",
|
|
545
|
+
"observed": "sum",
|
|
546
|
+
}
|
|
547
|
+
) # see #1385
|
|
548
|
+
.rename(
|
|
549
|
+
{
|
|
550
|
+
"at_risk": "At risk",
|
|
551
|
+
"censored": "Censored",
|
|
552
|
+
"observed": "Events",
|
|
553
|
+
}
|
|
554
|
+
)
|
|
512
555
|
.fillna(0)
|
|
513
556
|
)
|
|
514
557
|
counts.extend([int(c) for c in event_table_slice.loc[rows_to_show]])
|
|
515
558
|
else:
|
|
516
559
|
counts.extend([0 for _ in range(n_rows)])
|
|
517
|
-
|
|
518
560
|
if n_rows > 1:
|
|
519
561
|
if tick == ax2.get_xticks()[0]:
|
|
520
562
|
max_length = len(str(max(counts)))
|
|
521
563
|
for i, c in enumerate(counts):
|
|
522
564
|
if i % n_rows == 0:
|
|
523
565
|
if is_latex_enabled():
|
|
524
|
-
lbl += (
|
|
566
|
+
lbl += (
|
|
567
|
+
("\n" if i > 0 else "")
|
|
568
|
+
+ r"\textbf{%s}" % labels[int(i / n_rows)]
|
|
569
|
+
+ "\n"
|
|
570
|
+
)
|
|
525
571
|
else:
|
|
526
|
-
lbl += (
|
|
527
|
-
|
|
572
|
+
lbl += (
|
|
573
|
+
("\n" if i > 0 else "")
|
|
574
|
+
+ r"%s" % labels[int(i / n_rows)]
|
|
575
|
+
+ "\n"
|
|
576
|
+
)
|
|
528
577
|
l = rows_to_show[i % n_rows]
|
|
529
|
-
s =
|
|
578
|
+
s = (
|
|
579
|
+
"{}".format(l.rjust(10, " "))
|
|
580
|
+
+ (" " * (max_length - len(str(c)) + 3))
|
|
581
|
+
+ "{{:>{}d}}\n".format(max_length)
|
|
582
|
+
)
|
|
530
583
|
|
|
531
584
|
lbl += s.format(c)
|
|
532
|
-
|
|
533
585
|
else:
|
|
534
586
|
# Create tick label
|
|
535
587
|
lbl += ""
|
|
@@ -538,7 +590,6 @@ def add_at_risk_counts(
|
|
|
538
590
|
lbl += "\n\n"
|
|
539
591
|
s = "\n{}"
|
|
540
592
|
lbl += s.format(c)
|
|
541
|
-
|
|
542
593
|
else:
|
|
543
594
|
# if only one row to show, show in "condensed" version
|
|
544
595
|
if tick == ax2.get_xticks()[0]:
|
|
@@ -553,16 +604,13 @@ def add_at_risk_counts(
|
|
|
553
604
|
+ "{{:>{}d}}\n".format(max_length)
|
|
554
605
|
)
|
|
555
606
|
lbl += s.format(c)
|
|
556
|
-
|
|
557
607
|
else:
|
|
558
608
|
# Create tick label
|
|
559
609
|
lbl += ""
|
|
560
610
|
for i, c in enumerate(counts):
|
|
561
611
|
s = "\n{}"
|
|
562
612
|
lbl += s.format(c)
|
|
563
|
-
|
|
564
613
|
ticklabels.append(lbl)
|
|
565
|
-
|
|
566
614
|
# Align labels to the right so numbers can be compared easily
|
|
567
615
|
ax2.set_xticklabels(ticklabels, ha="right", **kwargs)
|
|
568
616
|
|
|
@@ -620,14 +668,16 @@ def plot_interval_censored_lifetimes(
|
|
|
620
668
|
|
|
621
669
|
if ax is None:
|
|
622
670
|
ax = plt.gca()
|
|
623
|
-
|
|
624
671
|
# If lower_bounds is pd.Series with non-default index, then use index values as y-axis labels.
|
|
625
|
-
label_plot_bars =
|
|
672
|
+
label_plot_bars = (
|
|
673
|
+
type(lower_bound) is pd.Series and type(lower_bound.index) is not pd.RangeIndex
|
|
674
|
+
)
|
|
626
675
|
|
|
627
676
|
N = lower_bound.shape[0]
|
|
628
677
|
if N > 25:
|
|
629
|
-
warnings.warn(
|
|
630
|
-
|
|
678
|
+
warnings.warn(
|
|
679
|
+
"For less visual clutter, you may want to subsample to less than 25 individuals."
|
|
680
|
+
)
|
|
631
681
|
assert upper_bound.shape[0] == N
|
|
632
682
|
|
|
633
683
|
if sort_by_lower_bound:
|
|
@@ -636,10 +686,8 @@ def plot_interval_censored_lifetimes(
|
|
|
636
686
|
lower_bound = _iloc(lower_bound, ix)
|
|
637
687
|
if entry is not None:
|
|
638
688
|
entry = _iloc(entry, ix)
|
|
639
|
-
|
|
640
689
|
if entry is None:
|
|
641
690
|
entry = np.zeros(N)
|
|
642
|
-
|
|
643
691
|
for i in range(N):
|
|
644
692
|
if np.isposinf(_iloc(upper_bound, i)):
|
|
645
693
|
c = event_right_censored_color
|
|
@@ -652,10 +700,8 @@ def plot_interval_censored_lifetimes(
|
|
|
652
700
|
else:
|
|
653
701
|
ax.scatter(_iloc(lower_bound, i), i, color=c, marker=">", s=13)
|
|
654
702
|
ax.scatter(_iloc(upper_bound, i), i, color=c, marker="<", s=13)
|
|
655
|
-
|
|
656
703
|
if left_truncated:
|
|
657
704
|
ax.hlines(i, 0, _iloc(entry, i), color=c, lw=1.0, linestyle="--")
|
|
658
|
-
|
|
659
705
|
if label_plot_bars:
|
|
660
706
|
ax.set_yticks(range(0, N))
|
|
661
707
|
ax.set_yticklabels(lower_bound.index)
|
|
@@ -663,7 +709,6 @@ def plot_interval_censored_lifetimes(
|
|
|
663
709
|
from matplotlib.ticker import MaxNLocator
|
|
664
710
|
|
|
665
711
|
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
|
|
666
|
-
|
|
667
712
|
ax.set_xlim(0)
|
|
668
713
|
ax.set_ylim(-0.5, N)
|
|
669
714
|
return ax
|
|
@@ -718,20 +763,20 @@ def plot_lifetimes(
|
|
|
718
763
|
|
|
719
764
|
if ax is None:
|
|
720
765
|
ax = plt.gca()
|
|
721
|
-
|
|
722
766
|
# If durations is pd.Series with non-default index, then use index values as y-axis labels.
|
|
723
|
-
label_plot_bars =
|
|
767
|
+
label_plot_bars = (
|
|
768
|
+
type(durations) is pd.Series and type(durations.index) is not pd.RangeIndex
|
|
769
|
+
)
|
|
724
770
|
|
|
725
771
|
N = durations.shape[0]
|
|
726
772
|
if N > 25:
|
|
727
|
-
warnings.warn(
|
|
728
|
-
|
|
773
|
+
warnings.warn(
|
|
774
|
+
"For less visual clutter, you may want to subsample to less than 25 individuals."
|
|
775
|
+
)
|
|
729
776
|
if event_observed is None:
|
|
730
777
|
event_observed = np.ones(N, dtype=bool)
|
|
731
|
-
|
|
732
778
|
if entry is None:
|
|
733
779
|
entry = np.zeros(N)
|
|
734
|
-
|
|
735
780
|
assert durations.shape[0] == N
|
|
736
781
|
assert event_observed.shape[0] == N
|
|
737
782
|
|
|
@@ -741,7 +786,6 @@ def plot_lifetimes(
|
|
|
741
786
|
durations = _iloc(durations, ix)
|
|
742
787
|
event_observed = _iloc(event_observed, ix)
|
|
743
788
|
entry = _iloc(entry, ix)
|
|
744
|
-
|
|
745
789
|
for i in range(N):
|
|
746
790
|
c = event_observed_color if _iloc(event_observed, i) else event_censored_color
|
|
747
791
|
ax.hlines(i, _iloc(entry, i), _iloc(durations, i), color=c, lw=1.5)
|
|
@@ -749,7 +793,6 @@ def plot_lifetimes(
|
|
|
749
793
|
ax.hlines(i, 0, _iloc(entry, i), color=c, lw=1.0, linestyle="--")
|
|
750
794
|
m = "" if not _iloc(event_observed, i) else "o"
|
|
751
795
|
ax.scatter(_iloc(durations, i), i, color=c, marker=m, s=13)
|
|
752
|
-
|
|
753
796
|
if label_plot_bars:
|
|
754
797
|
ax.set_yticks(range(0, N))
|
|
755
798
|
ax.set_yticklabels(durations.index)
|
|
@@ -757,14 +800,17 @@ def plot_lifetimes(
|
|
|
757
800
|
from matplotlib.ticker import MaxNLocator
|
|
758
801
|
|
|
759
802
|
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
|
|
760
|
-
|
|
761
803
|
ax.set_xlim(0)
|
|
762
804
|
ax.set_ylim(-0.5, N)
|
|
763
805
|
return ax
|
|
764
806
|
|
|
765
807
|
|
|
766
808
|
def set_kwargs_color(kwargs):
|
|
767
|
-
kwargs["color"] = coalesce(
|
|
809
|
+
kwargs["color"] = coalesce(
|
|
810
|
+
kwargs.pop("c", None),
|
|
811
|
+
kwargs.pop("color", None),
|
|
812
|
+
kwargs["ax"]._get_lines.get_next_color(),
|
|
813
|
+
)
|
|
768
814
|
|
|
769
815
|
|
|
770
816
|
def set_kwargs_drawstyle(kwargs, default="steps-post"):
|
|
@@ -778,15 +824,20 @@ def set_kwargs_label(kwargs, cls):
|
|
|
778
824
|
def create_dataframe_slicer(iloc, loc, timeline):
|
|
779
825
|
if (loc is not None) and (iloc is not None):
|
|
780
826
|
raise ValueError("Cannot set both loc and iloc in call to .plot().")
|
|
781
|
-
|
|
782
827
|
user_did_not_specify_certain_indexes = (iloc is None) and (loc is None)
|
|
783
|
-
user_submitted_slice =
|
|
828
|
+
user_submitted_slice = (
|
|
829
|
+
slice(timeline.min(), timeline.max())
|
|
830
|
+
if user_did_not_specify_certain_indexes
|
|
831
|
+
else coalesce(loc, iloc)
|
|
832
|
+
)
|
|
784
833
|
|
|
785
834
|
get_method = "iloc" if iloc is not None else "loc"
|
|
786
835
|
return lambda df: getattr(df, get_method)[user_submitted_slice]
|
|
787
836
|
|
|
788
837
|
|
|
789
|
-
def loglogs_plot(
|
|
838
|
+
def loglogs_plot(
|
|
839
|
+
cls, loc=None, iloc=None, show_censors=False, censor_styles=None, ax=None, **kwargs
|
|
840
|
+
):
|
|
790
841
|
"""
|
|
791
842
|
Specifies a plot of the log(-log(SV)) versus log(time) where SV is the estimated survival function.
|
|
792
843
|
"""
|
|
@@ -797,13 +848,10 @@ def loglogs_plot(cls, loc=None, iloc=None, show_censors=False, censor_styles=Non
|
|
|
797
848
|
|
|
798
849
|
if (loc is not None) and (iloc is not None):
|
|
799
850
|
raise ValueError("Cannot set both loc and iloc in call to .plot().")
|
|
800
|
-
|
|
801
851
|
if censor_styles is None:
|
|
802
852
|
censor_styles = {}
|
|
803
|
-
|
|
804
853
|
if ax is None:
|
|
805
854
|
ax = plt.gca()
|
|
806
|
-
|
|
807
855
|
kwargs["ax"] = ax
|
|
808
856
|
set_kwargs_color(kwargs)
|
|
809
857
|
set_kwargs_drawstyle(kwargs)
|
|
@@ -816,10 +864,11 @@ def loglogs_plot(cls, loc=None, iloc=None, show_censors=False, censor_styles=Non
|
|
|
816
864
|
if show_censors and cls.event_table["censored"].sum() > 0:
|
|
817
865
|
cs = {"marker": "|", "ms": 12, "mew": 1}
|
|
818
866
|
cs.update(censor_styles)
|
|
819
|
-
times = dataframe_slicer(
|
|
867
|
+
times = dataframe_slicer(
|
|
868
|
+
cls.event_table.loc[(cls.event_table["censored"] > 0)]
|
|
869
|
+
).index.values.astype(float)
|
|
820
870
|
v = cls.predict(times)
|
|
821
871
|
ax.plot(np.log(times), loglog(v), linestyle="None", color=colour, **cs)
|
|
822
|
-
|
|
823
872
|
# plot estimate
|
|
824
873
|
sliced_estimates = dataframe_slicer(loglog(cls.survival_function_))
|
|
825
874
|
sliced_estimates["log(timeline)"] = np.log(sliced_estimates.index)
|
|
@@ -848,7 +897,6 @@ def _plot_estimate(
|
|
|
848
897
|
ax=None,
|
|
849
898
|
**kwargs
|
|
850
899
|
):
|
|
851
|
-
|
|
852
900
|
"""
|
|
853
901
|
Plots a pretty figure of estimates
|
|
854
902
|
|
|
@@ -904,21 +952,28 @@ def _plot_estimate(
|
|
|
904
952
|
DeprecationWarning,
|
|
905
953
|
)
|
|
906
954
|
ci_only_lines = ci_force_lines
|
|
907
|
-
|
|
908
|
-
|
|
955
|
+
if "point_in_time" in kwargs: # marker for given point
|
|
956
|
+
point_in_time = kwargs["point_in_time"]
|
|
957
|
+
kwargs.pop("point_in_time")
|
|
958
|
+
plot_estimate_config = PlotEstimateConfig(
|
|
959
|
+
cls, estimate, loc, iloc, show_censors, censor_styles, logx, ax, **kwargs
|
|
960
|
+
)
|
|
909
961
|
|
|
910
962
|
dataframe_slicer = create_dataframe_slicer(iloc, loc, cls.timeline)
|
|
911
963
|
|
|
912
964
|
if show_censors and cls.event_table["censored"].sum() > 0:
|
|
913
965
|
cs = {"marker": "+", "ms": 12, "mew": 1}
|
|
914
966
|
cs.update(plot_estimate_config.censor_styles)
|
|
915
|
-
censored_times = dataframe_slicer(
|
|
967
|
+
censored_times = dataframe_slicer(
|
|
968
|
+
cls.event_table.loc[(cls.event_table["censored"] > 0)]
|
|
969
|
+
).index.values.astype(float)
|
|
916
970
|
v = plot_estimate_config.predict_at_times(censored_times).values
|
|
917
|
-
plot_estimate_config.ax.plot(
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
971
|
+
plot_estimate_config.ax.plot(
|
|
972
|
+
censored_times, v, linestyle="None", color=plot_estimate_config.colour, **cs
|
|
973
|
+
)
|
|
974
|
+
dataframe_slicer(plot_estimate_config.estimate_).rename(
|
|
975
|
+
columns=lambda _: plot_estimate_config.kwargs.pop("label")
|
|
976
|
+
).plot(logx=plot_estimate_config.logx, **plot_estimate_config.kwargs)
|
|
922
977
|
|
|
923
978
|
# plot confidence intervals
|
|
924
979
|
if ci_show:
|
|
@@ -940,15 +995,20 @@ def _plot_estimate(
|
|
|
940
995
|
)
|
|
941
996
|
)
|
|
942
997
|
else:
|
|
943
|
-
x = dataframe_slicer(
|
|
944
|
-
|
|
945
|
-
|
|
998
|
+
x = dataframe_slicer(
|
|
999
|
+
plot_estimate_config.confidence_interval_
|
|
1000
|
+
).index.values.astype(float)
|
|
1001
|
+
lower = dataframe_slicer(
|
|
1002
|
+
plot_estimate_config.confidence_interval_.iloc[:, [0]]
|
|
1003
|
+
).values[:, 0]
|
|
1004
|
+
upper = dataframe_slicer(
|
|
1005
|
+
plot_estimate_config.confidence_interval_.iloc[:, [1]]
|
|
1006
|
+
).values[:, 0]
|
|
946
1007
|
|
|
947
1008
|
if plot_estimate_config.kwargs["drawstyle"] == "default":
|
|
948
1009
|
step = None
|
|
949
1010
|
elif plot_estimate_config.kwargs["drawstyle"].startswith("step"):
|
|
950
1011
|
step = plot_estimate_config.kwargs["drawstyle"].replace("steps-", "")
|
|
951
|
-
|
|
952
1012
|
plot_estimate_config.ax.fill_between(
|
|
953
1013
|
x,
|
|
954
1014
|
lower,
|
|
@@ -958,23 +1018,35 @@ def _plot_estimate(
|
|
|
958
1018
|
linewidth=0.0 if ci_no_lines else 1.0,
|
|
959
1019
|
step=step,
|
|
960
1020
|
)
|
|
961
|
-
|
|
962
1021
|
if at_risk_counts:
|
|
963
1022
|
add_at_risk_counts(cls, ax=plot_estimate_config.ax)
|
|
964
1023
|
plt.tight_layout()
|
|
965
|
-
|
|
1024
|
+
if "point_in_time" in locals():
|
|
1025
|
+
plot_estimate_config.ax.scatter(
|
|
1026
|
+
point_in_time, cls.survival_function_at_times(point_in_time)
|
|
1027
|
+
)
|
|
966
1028
|
return plot_estimate_config.ax
|
|
967
1029
|
|
|
968
1030
|
|
|
969
1031
|
class PlotEstimateConfig:
|
|
970
|
-
def __init__(
|
|
1032
|
+
def __init__(
|
|
1033
|
+
self,
|
|
1034
|
+
cls,
|
|
1035
|
+
estimate: Union[str, pd.DataFrame],
|
|
1036
|
+
loc,
|
|
1037
|
+
iloc,
|
|
1038
|
+
show_censors,
|
|
1039
|
+
censor_styles,
|
|
1040
|
+
logx,
|
|
1041
|
+
ax,
|
|
1042
|
+
**kwargs
|
|
1043
|
+
):
|
|
971
1044
|
from matplotlib import pyplot as plt
|
|
972
1045
|
|
|
973
1046
|
self.censor_styles = coalesce(censor_styles, {})
|
|
974
1047
|
|
|
975
1048
|
if ax is None:
|
|
976
1049
|
ax = plt.gca()
|
|
977
|
-
|
|
978
1050
|
kwargs["ax"] = ax
|
|
979
1051
|
set_kwargs_color(kwargs)
|
|
980
1052
|
set_kwargs_drawstyle(kwargs)
|
lifelines/utils/__init__.py
CHANGED
|
@@ -311,7 +311,7 @@ def _expected_value_of_survival_squared_up_to_t(
|
|
|
311
311
|
|
|
312
312
|
if isinstance(model_or_survival_function, pd.DataFrame):
|
|
313
313
|
sf = model_or_survival_function.loc[:t]
|
|
314
|
-
sf =
|
|
314
|
+
sf = pd.concat((sf, pd.DataFrame([1], index=[0], columns=sf.columns))).sort_index()
|
|
315
315
|
sf_tau = sf * sf.index.values[:, None]
|
|
316
316
|
return 2 * trapz(y=sf_tau.values[:, 0], x=sf_tau.index)
|
|
317
317
|
elif isinstance(model_or_survival_function, lifelines.fitters.UnivariateFitter):
|
|
@@ -561,7 +561,7 @@ def _group_event_table_by_intervals(event_table, intervals) -> pd.DataFrame:
|
|
|
561
561
|
)
|
|
562
562
|
# convert columns from multiindex
|
|
563
563
|
event_table.columns = event_table.columns.droplevel(1)
|
|
564
|
-
return event_table.bfill()
|
|
564
|
+
return event_table.bfill().fillna(0)
|
|
565
565
|
|
|
566
566
|
|
|
567
567
|
def survival_events_from_table(survival_table, observed_deaths_col="observed", censored_col="censored"):
|
|
@@ -744,9 +744,6 @@ def k_fold_cross_validation(
|
|
|
744
744
|
results: list
|
|
745
745
|
(k,1) list of scores for each fold. The scores can be anything.
|
|
746
746
|
|
|
747
|
-
See Also
|
|
748
|
-
---------
|
|
749
|
-
lifelines.utils.sklearn_adapter.sklearn_adapter
|
|
750
747
|
|
|
751
748
|
"""
|
|
752
749
|
# Make sure fitters is a list
|
|
@@ -884,6 +881,7 @@ def _additive_estimate(events, timeline, _additive_f, _additive_var, reverse):
|
|
|
884
881
|
population = events["at_risk"] - entrances
|
|
885
882
|
|
|
886
883
|
estimate_ = np.cumsum(_additive_f(population, deaths))
|
|
884
|
+
|
|
887
885
|
var_ = np.cumsum(_additive_var(population, deaths))
|
|
888
886
|
|
|
889
887
|
timeline = sorted(timeline)
|
|
@@ -1908,7 +1906,7 @@ class CovariateParameterMappings:
|
|
|
1908
1906
|
|
|
1909
1907
|
Xs = {}
|
|
1910
1908
|
for param_name, transform in self.mappings.items():
|
|
1911
|
-
if isinstance(transform, formulaic.
|
|
1909
|
+
if isinstance(transform, formulaic.ModelSpec):
|
|
1912
1910
|
X = transform.get_model_matrix(df)
|
|
1913
1911
|
elif isinstance(transform, list):
|
|
1914
1912
|
if self.force_intercept:
|
|
@@ -1954,6 +1952,6 @@ class CovariateParameterMappings:
|
|
|
1954
1952
|
if self.force_intercept:
|
|
1955
1953
|
formula += "+ 1"
|
|
1956
1954
|
|
|
1957
|
-
design_info = formulaic.Formula(formula)
|
|
1955
|
+
design_info = formulaic.ModelSpec.from_spec(formulaic.Formula(formula).get_model_matrix(df))
|
|
1958
1956
|
|
|
1959
1957
|
return design_info
|
lifelines/version.py
CHANGED
{lifelines-0.27.7.dist-info → lifelines-0.28.0.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: lifelines
|
|
3
|
-
Version: 0.27.7
|
|
3
|
+
Version: 0.28.0
|
|
4
4
|
Summary: Survival analysis in Python, including Kaplan Meier, Nelson Aalen and regression
|
|
5
5
|
Home-page: https://github.com/CamDavidsonPilon/lifelines
|
|
6
6
|
Author: Cameron Davidson-Pilon
|
|
@@ -9,22 +9,20 @@ License: MIT
|
|
|
9
9
|
Classifier: Development Status :: 4 - Beta
|
|
10
10
|
Classifier: License :: OSI Approved :: MIT License
|
|
11
11
|
Classifier: Programming Language :: Python
|
|
12
|
-
Classifier: Programming Language :: Python :: 3.7
|
|
13
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
14
12
|
Classifier: Programming Language :: Python :: 3.9
|
|
15
13
|
Classifier: Programming Language :: Python :: 3.10
|
|
16
14
|
Classifier: Programming Language :: Python :: 3.11
|
|
17
15
|
Classifier: Topic :: Scientific/Engineering
|
|
18
|
-
Requires-Python: >=3.7
|
|
16
|
+
Requires-Python: >=3.9
|
|
19
17
|
Description-Content-Type: text/markdown
|
|
20
18
|
License-File: LICENSE
|
|
21
|
-
Requires-Dist: numpy
|
|
22
|
-
Requires-Dist: scipy
|
|
23
|
-
Requires-Dist: pandas
|
|
24
|
-
Requires-Dist: matplotlib
|
|
25
|
-
Requires-Dist: autograd
|
|
26
|
-
Requires-Dist: autograd-gamma
|
|
27
|
-
Requires-Dist: formulaic
|
|
19
|
+
Requires-Dist: numpy <2.0,>=1.14.0
|
|
20
|
+
Requires-Dist: scipy >=1.2.0
|
|
21
|
+
Requires-Dist: pandas >=1.2.0
|
|
22
|
+
Requires-Dist: matplotlib >=3.0
|
|
23
|
+
Requires-Dist: autograd >=1.5
|
|
24
|
+
Requires-Dist: autograd-gamma >=0.3
|
|
25
|
+
Requires-Dist: formulaic >=0.2.2
|
|
28
26
|
|
|
29
27
|

|
|
30
28
|
|