lifelines 0.27.7__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those versions as they appear in their respective public registries.
lifelines/plotting.py CHANGED
@@ -46,26 +46,20 @@ def create_scipy_stats_model_from_lifelines_model(model):
46
46
  raise TypeError(
47
47
  "Cannot use qq-plot with this model. See notes here: https://lifelines.readthedocs.io/en/latest/Examples.html?highlight=qq_plot#selecting-a-parametric-model-using-qq-plots"
48
48
  )
49
-
50
49
  if dist == "weibull":
51
50
  scipy_dist = "weibull_min"
52
51
  sparams = (model.rho_, 0, model.lambda_)
53
-
54
52
  elif dist == "lognormal":
55
53
  scipy_dist = "lognorm"
56
54
  sparams = (model.sigma_, 0, np.exp(model.mu_))
57
-
58
55
  elif dist == "loglogistic":
59
56
  scipy_dist = "fisk"
60
57
  sparams = (model.beta_, 0, model.alpha_)
61
-
62
58
  elif dist == "exponential":
63
59
  scipy_dist = "expon"
64
60
  sparams = (0, model.lambda_)
65
-
66
61
  else:
67
62
  raise NotImplementedError("Distribution not implemented in SciPy")
68
-
69
63
  return getattr(stats, scipy_dist)(*sparams)
70
64
 
71
65
 
@@ -85,30 +79,44 @@ def cdf_plot(model, timeline=None, ax=None, **plot_kwargs):
85
79
 
86
80
  if ax is None:
87
81
  ax = plt.gca()
88
-
89
82
  if timeline is None:
90
83
  timeline = model.timeline
91
-
92
84
  COL_EMP = "empirical CDF"
93
85
 
94
86
  if CensoringType.is_left_censoring(model):
95
87
  empirical_kmf = KaplanMeierFitter().fit_left_censoring(
96
- model.durations, model.event_observed, label=COL_EMP, timeline=timeline, weights=model.weights, entry=model.entry
88
+ model.durations,
89
+ model.event_observed,
90
+ label=COL_EMP,
91
+ timeline=timeline,
92
+ weights=model.weights,
93
+ entry=model.entry,
97
94
  )
98
95
  elif CensoringType.is_right_censoring(model):
99
96
  empirical_kmf = KaplanMeierFitter().fit_right_censoring(
100
- model.durations, model.event_observed, label=COL_EMP, timeline=timeline, weights=model.weights, entry=model.entry
97
+ model.durations,
98
+ model.event_observed,
99
+ label=COL_EMP,
100
+ timeline=timeline,
101
+ weights=model.weights,
102
+ entry=model.entry,
101
103
  )
102
104
  elif CensoringType.is_interval_censoring(model):
103
105
  empirical_kmf = KaplanMeierFitter().fit_interval_censoring(
104
- model.lower_bound, model.upper_bound, label=COL_EMP, timeline=timeline, weights=model.weights, entry=model.entry
106
+ model.lower_bound,
107
+ model.upper_bound,
108
+ label=COL_EMP,
109
+ timeline=timeline,
110
+ weights=model.weights,
111
+ entry=model.entry,
105
112
  )
106
-
107
113
  empirical_kmf.plot_cumulative_density(ax=ax, **plot_kwargs)
108
114
 
109
115
  dist = get_distribution_name_of_lifelines_model(model)
110
116
  dist_object = create_scipy_stats_model_from_lifelines_model(model)
111
- ax.plot(timeline, dist_object.cdf(timeline), label="fitted %s" % dist, **plot_kwargs)
117
+ ax.plot(
118
+ timeline, dist_object.cdf(timeline), label="fitted %s" % dist, **plot_kwargs
119
+ )
112
120
  ax.legend()
113
121
  return ax
114
122
 
@@ -166,14 +174,12 @@ def rmst_plot(model, model2=None, t=np.inf, ax=None, text_position=None, **plot_
166
174
 
167
175
  if ax is None:
168
176
  ax = plt.gca()
169
-
170
177
  rmst = restricted_mean_survival_time(model, t=t)
171
178
  c = ax._get_lines.get_next_color()
172
179
  model.plot_survival_function(ax=ax, color=c, ci_show=False, **plot_kwargs)
173
180
 
174
181
  if text_position is None:
175
182
  text_position = (np.percentile(model.timeline, 10), 0.15)
176
-
177
183
  if model2 is not None:
178
184
  c2 = ax._get_lines.get_next_color()
179
185
  rmst2 = restricted_mean_survival_time(model2, t=t)
@@ -206,14 +212,26 @@ def rmst_plot(model, model2=None, t=np.inf, ax=None, text_position=None, **plot_
206
212
  )
207
213
 
208
214
  ax.text(
209
- text_position[0], text_position[1], "RMST(%s) -\n RMST(%s)=%.3f" % (model._label, model2._label, rmst - rmst2)
215
+ text_position[0],
216
+ text_position[1],
217
+ "RMST(%s) -\n RMST(%s)=%.3f"
218
+ % (model._label, model2._label, rmst - rmst2),
210
219
  ) # dynamically pick this.
211
220
  else:
212
221
  rmst = restricted_mean_survival_time(model, t=t)
213
- sf_exp_at_limit = model.predict(np.append(model.timeline, t)).sort_index().loc[:t]
214
- ax.fill_between(sf_exp_at_limit.index, sf_exp_at_limit.values, step="post", color=c, alpha=0.25)
215
- ax.text(text_position[0], text_position[1], "RMST=%.3f" % rmst) # dynamically pick this.
216
-
222
+ sf_exp_at_limit = (
223
+ model.predict(np.append(model.timeline, t)).sort_index().loc[:t]
224
+ )
225
+ ax.fill_between(
226
+ sf_exp_at_limit.index,
227
+ sf_exp_at_limit.values,
228
+ step="post",
229
+ color=c,
230
+ alpha=0.25,
231
+ )
232
+ ax.text(
233
+ text_position[0], text_position[1], "RMST=%.3f" % rmst
234
+ ) # dynamically pick this.
217
235
  ax.axvline(t, ls="--", color="k")
218
236
  ax.set_ylim(0, 1)
219
237
  return ax
@@ -259,7 +277,6 @@ def qq_plot(model, ax=None, scatter_color="k", **plot_kwargs):
259
277
 
260
278
  if ax is None:
261
279
  ax = plt.gca()
262
-
263
280
  dist = get_distribution_name_of_lifelines_model(model)
264
281
  dist_object = create_scipy_stats_model_from_lifelines_model(model)
265
282
 
@@ -268,21 +285,34 @@ def qq_plot(model, ax=None, scatter_color="k", **plot_kwargs):
268
285
 
269
286
  if CensoringType.is_left_censoring(model):
270
287
  kmf = KaplanMeierFitter().fit_left_censoring(
271
- model.durations, model.event_observed, label=COL_EMP, weights=model.weights, entry=model.entry
288
+ model.durations,
289
+ model.event_observed,
290
+ label=COL_EMP,
291
+ weights=model.weights,
292
+ entry=model.entry,
272
293
  )
273
294
  sf, cdf = kmf.survival_function_[COL_EMP], kmf.cumulative_density_[COL_EMP]
274
295
  elif CensoringType.is_right_censoring(model):
275
296
  kmf = KaplanMeierFitter().fit_right_censoring(
276
- model.durations, model.event_observed, label=COL_EMP, weights=model.weights, entry=model.entry
297
+ model.durations,
298
+ model.event_observed,
299
+ label=COL_EMP,
300
+ weights=model.weights,
301
+ entry=model.entry,
277
302
  )
278
303
  sf, cdf = kmf.survival_function_[COL_EMP], kmf.cumulative_density_[COL_EMP]
279
-
280
304
  elif CensoringType.is_interval_censoring(model):
281
305
  kmf = KaplanMeierFitter().fit_interval_censoring(
282
- model.lower_bound, model.upper_bound, label=COL_EMP, weights=model.weights, entry=model.entry
306
+ model.lower_bound,
307
+ model.upper_bound,
308
+ label=COL_EMP,
309
+ weights=model.weights,
310
+ entry=model.entry,
311
+ )
312
+ sf, cdf = (
313
+ kmf.survival_function_.mean(1),
314
+ kmf.cumulative_density_[COL_EMP + "_lower"],
283
315
  )
284
- sf, cdf = kmf.survival_function_.mean(1), kmf.cumulative_density_[COL_EMP + "_lower"]
285
-
286
316
  q = np.unique(cdf.values)
287
317
 
288
318
  quantiles = qth_survival_times(1 - q, sf)
@@ -291,7 +321,9 @@ def qq_plot(model, ax=None, scatter_color="k", **plot_kwargs):
291
321
 
292
322
  max_, min_ = quantiles[COL_EMP].max(), quantiles[COL_EMP].min()
293
323
 
294
- quantiles.plot.scatter(COL_THEO, COL_EMP, c="none", edgecolor=scatter_color, lw=0.5, ax=ax)
324
+ quantiles.plot.scatter(
325
+ COL_THEO, COL_EMP, c="none", edgecolor=scatter_color, lw=0.5, ax=ax
326
+ )
295
327
  ax.plot([min_, max_], [min_, max_], c="k", ls=":", lw=1.0)
296
328
  ax.set_ylim(min_, max_)
297
329
  ax.set_xlim(min_, max_)
@@ -446,30 +478,28 @@ def add_at_risk_counts(
446
478
 
447
479
  if ax is None:
448
480
  ax = plt.gca()
449
-
450
481
  fig = kwargs.pop("fig", None)
451
482
  if fig is None:
452
483
  fig = plt.gcf()
453
-
454
484
  if labels is None:
455
485
  labels = [f._label for f in fitters]
456
486
  elif labels is False:
457
487
  labels = [None] * len(fitters)
458
-
459
488
  if rows_to_show is None:
460
489
  rows_to_show = ["At risk", "Censored", "Events"]
461
490
  else:
462
491
  assert all(
463
492
  row in ["At risk", "Censored", "Events"] for row in rows_to_show
464
493
  ), 'must be one of ["At risk", "Censored", "Events"]'
465
-
466
494
  n_rows = len(rows_to_show)
467
495
 
468
496
  # Create another axes where we can put size ticks
469
497
  ax2 = plt.twiny(ax=ax)
470
498
  # Move the ticks below existing axes
471
499
  # Appropriate length scaled for 6 inches. Adjust for figure size.
472
- ax_height = (ax.get_position().y1 - ax.get_position().y0) * fig.get_figheight() # axis height
500
+ ax_height = (
501
+ ax.get_position().y1 - ax.get_position().y0
502
+ ) * fig.get_figheight() # axis height
473
503
  ax2_ypos = ypos / ax_height
474
504
 
475
505
  move_spines(ax2, ["bottom"], [ax2_ypos])
@@ -502,34 +532,56 @@ def add_at_risk_counts(
502
532
  if at_risk_count_from_start_of_period:
503
533
  event_table_slice = f.event_table.assign(at_risk=lambda x: x.at_risk)
504
534
  else:
505
- event_table_slice = f.event_table.assign(at_risk=lambda x: x.at_risk - x.removed)
506
-
535
+ event_table_slice = f.event_table.assign(
536
+ at_risk=lambda x: x.at_risk - x.removed
537
+ )
507
538
  if not event_table_slice.loc[:tick].empty:
508
539
  event_table_slice = (
509
540
  event_table_slice.loc[:tick, ["at_risk", "censored", "observed"]]
510
- .agg({"at_risk": lambda x: x.tail(1).values, "censored": "sum", "observed": "sum"}) # see #1385
511
- .rename({"at_risk": "At risk", "censored": "Censored", "observed": "Events"})
541
+ .agg(
542
+ {
543
+ "at_risk": lambda x: x.tail(1).values,
544
+ "censored": "sum",
545
+ "observed": "sum",
546
+ }
547
+ ) # see #1385
548
+ .rename(
549
+ {
550
+ "at_risk": "At risk",
551
+ "censored": "Censored",
552
+ "observed": "Events",
553
+ }
554
+ )
512
555
  .fillna(0)
513
556
  )
514
557
  counts.extend([int(c) for c in event_table_slice.loc[rows_to_show]])
515
558
  else:
516
559
  counts.extend([0 for _ in range(n_rows)])
517
-
518
560
  if n_rows > 1:
519
561
  if tick == ax2.get_xticks()[0]:
520
562
  max_length = len(str(max(counts)))
521
563
  for i, c in enumerate(counts):
522
564
  if i % n_rows == 0:
523
565
  if is_latex_enabled():
524
- lbl += ("\n" if i > 0 else "") + r"\textbf{%s}" % labels[int(i / n_rows)] + "\n"
566
+ lbl += (
567
+ ("\n" if i > 0 else "")
568
+ + r"\textbf{%s}" % labels[int(i / n_rows)]
569
+ + "\n"
570
+ )
525
571
  else:
526
- lbl += ("\n" if i > 0 else "") + r"%s" % labels[int(i / n_rows)] + "\n"
527
-
572
+ lbl += (
573
+ ("\n" if i > 0 else "")
574
+ + r"%s" % labels[int(i / n_rows)]
575
+ + "\n"
576
+ )
528
577
  l = rows_to_show[i % n_rows]
529
- s = "{}".format(l.rjust(10, " ")) + (" " * (max_length - len(str(c)) + 3)) + "{{:>{}d}}\n".format(max_length)
578
+ s = (
579
+ "{}".format(l.rjust(10, " "))
580
+ + (" " * (max_length - len(str(c)) + 3))
581
+ + "{{:>{}d}}\n".format(max_length)
582
+ )
530
583
 
531
584
  lbl += s.format(c)
532
-
533
585
  else:
534
586
  # Create tick label
535
587
  lbl += ""
@@ -538,7 +590,6 @@ def add_at_risk_counts(
538
590
  lbl += "\n\n"
539
591
  s = "\n{}"
540
592
  lbl += s.format(c)
541
-
542
593
  else:
543
594
  # if only one row to show, show in "condensed" version
544
595
  if tick == ax2.get_xticks()[0]:
@@ -553,16 +604,13 @@ def add_at_risk_counts(
553
604
  + "{{:>{}d}}\n".format(max_length)
554
605
  )
555
606
  lbl += s.format(c)
556
-
557
607
  else:
558
608
  # Create tick label
559
609
  lbl += ""
560
610
  for i, c in enumerate(counts):
561
611
  s = "\n{}"
562
612
  lbl += s.format(c)
563
-
564
613
  ticklabels.append(lbl)
565
-
566
614
  # Align labels to the right so numbers can be compared easily
567
615
  ax2.set_xticklabels(ticklabels, ha="right", **kwargs)
568
616
 
@@ -620,14 +668,16 @@ def plot_interval_censored_lifetimes(
620
668
 
621
669
  if ax is None:
622
670
  ax = plt.gca()
623
-
624
671
  # If lower_bounds is pd.Series with non-default index, then use index values as y-axis labels.
625
- label_plot_bars = type(lower_bound) is pd.Series and type(lower_bound.index) is not pd.RangeIndex
672
+ label_plot_bars = (
673
+ type(lower_bound) is pd.Series and type(lower_bound.index) is not pd.RangeIndex
674
+ )
626
675
 
627
676
  N = lower_bound.shape[0]
628
677
  if N > 25:
629
- warnings.warn("For less visual clutter, you may want to subsample to less than 25 individuals.")
630
-
678
+ warnings.warn(
679
+ "For less visual clutter, you may want to subsample to less than 25 individuals."
680
+ )
631
681
  assert upper_bound.shape[0] == N
632
682
 
633
683
  if sort_by_lower_bound:
@@ -636,10 +686,8 @@ def plot_interval_censored_lifetimes(
636
686
  lower_bound = _iloc(lower_bound, ix)
637
687
  if entry is not None:
638
688
  entry = _iloc(entry, ix)
639
-
640
689
  if entry is None:
641
690
  entry = np.zeros(N)
642
-
643
691
  for i in range(N):
644
692
  if np.isposinf(_iloc(upper_bound, i)):
645
693
  c = event_right_censored_color
@@ -652,10 +700,8 @@ def plot_interval_censored_lifetimes(
652
700
  else:
653
701
  ax.scatter(_iloc(lower_bound, i), i, color=c, marker=">", s=13)
654
702
  ax.scatter(_iloc(upper_bound, i), i, color=c, marker="<", s=13)
655
-
656
703
  if left_truncated:
657
704
  ax.hlines(i, 0, _iloc(entry, i), color=c, lw=1.0, linestyle="--")
658
-
659
705
  if label_plot_bars:
660
706
  ax.set_yticks(range(0, N))
661
707
  ax.set_yticklabels(lower_bound.index)
@@ -663,7 +709,6 @@ def plot_interval_censored_lifetimes(
663
709
  from matplotlib.ticker import MaxNLocator
664
710
 
665
711
  ax.yaxis.set_major_locator(MaxNLocator(integer=True))
666
-
667
712
  ax.set_xlim(0)
668
713
  ax.set_ylim(-0.5, N)
669
714
  return ax
@@ -718,20 +763,20 @@ def plot_lifetimes(
718
763
 
719
764
  if ax is None:
720
765
  ax = plt.gca()
721
-
722
766
  # If durations is pd.Series with non-default index, then use index values as y-axis labels.
723
- label_plot_bars = type(durations) is pd.Series and type(durations.index) is not pd.RangeIndex
767
+ label_plot_bars = (
768
+ type(durations) is pd.Series and type(durations.index) is not pd.RangeIndex
769
+ )
724
770
 
725
771
  N = durations.shape[0]
726
772
  if N > 25:
727
- warnings.warn("For less visual clutter, you may want to subsample to less than 25 individuals.")
728
-
773
+ warnings.warn(
774
+ "For less visual clutter, you may want to subsample to less than 25 individuals."
775
+ )
729
776
  if event_observed is None:
730
777
  event_observed = np.ones(N, dtype=bool)
731
-
732
778
  if entry is None:
733
779
  entry = np.zeros(N)
734
-
735
780
  assert durations.shape[0] == N
736
781
  assert event_observed.shape[0] == N
737
782
 
@@ -741,7 +786,6 @@ def plot_lifetimes(
741
786
  durations = _iloc(durations, ix)
742
787
  event_observed = _iloc(event_observed, ix)
743
788
  entry = _iloc(entry, ix)
744
-
745
789
  for i in range(N):
746
790
  c = event_observed_color if _iloc(event_observed, i) else event_censored_color
747
791
  ax.hlines(i, _iloc(entry, i), _iloc(durations, i), color=c, lw=1.5)
@@ -749,7 +793,6 @@ def plot_lifetimes(
749
793
  ax.hlines(i, 0, _iloc(entry, i), color=c, lw=1.0, linestyle="--")
750
794
  m = "" if not _iloc(event_observed, i) else "o"
751
795
  ax.scatter(_iloc(durations, i), i, color=c, marker=m, s=13)
752
-
753
796
  if label_plot_bars:
754
797
  ax.set_yticks(range(0, N))
755
798
  ax.set_yticklabels(durations.index)
@@ -757,14 +800,17 @@ def plot_lifetimes(
757
800
  from matplotlib.ticker import MaxNLocator
758
801
 
759
802
  ax.yaxis.set_major_locator(MaxNLocator(integer=True))
760
-
761
803
  ax.set_xlim(0)
762
804
  ax.set_ylim(-0.5, N)
763
805
  return ax
764
806
 
765
807
 
766
808
  def set_kwargs_color(kwargs):
767
- kwargs["color"] = coalesce(kwargs.pop("c", None), kwargs.pop("color", None), kwargs["ax"]._get_lines.get_next_color())
809
+ kwargs["color"] = coalesce(
810
+ kwargs.pop("c", None),
811
+ kwargs.pop("color", None),
812
+ kwargs["ax"]._get_lines.get_next_color(),
813
+ )
768
814
 
769
815
 
770
816
  def set_kwargs_drawstyle(kwargs, default="steps-post"):
@@ -778,15 +824,20 @@ def set_kwargs_label(kwargs, cls):
778
824
  def create_dataframe_slicer(iloc, loc, timeline):
779
825
  if (loc is not None) and (iloc is not None):
780
826
  raise ValueError("Cannot set both loc and iloc in call to .plot().")
781
-
782
827
  user_did_not_specify_certain_indexes = (iloc is None) and (loc is None)
783
- user_submitted_slice = slice(timeline.min(), timeline.max()) if user_did_not_specify_certain_indexes else coalesce(loc, iloc)
828
+ user_submitted_slice = (
829
+ slice(timeline.min(), timeline.max())
830
+ if user_did_not_specify_certain_indexes
831
+ else coalesce(loc, iloc)
832
+ )
784
833
 
785
834
  get_method = "iloc" if iloc is not None else "loc"
786
835
  return lambda df: getattr(df, get_method)[user_submitted_slice]
787
836
 
788
837
 
789
- def loglogs_plot(cls, loc=None, iloc=None, show_censors=False, censor_styles=None, ax=None, **kwargs):
838
+ def loglogs_plot(
839
+ cls, loc=None, iloc=None, show_censors=False, censor_styles=None, ax=None, **kwargs
840
+ ):
790
841
  """
791
842
  Specifies a plot of the log(-log(SV)) versus log(time) where SV is the estimated survival function.
792
843
  """
@@ -797,13 +848,10 @@ def loglogs_plot(cls, loc=None, iloc=None, show_censors=False, censor_styles=Non
797
848
 
798
849
  if (loc is not None) and (iloc is not None):
799
850
  raise ValueError("Cannot set both loc and iloc in call to .plot().")
800
-
801
851
  if censor_styles is None:
802
852
  censor_styles = {}
803
-
804
853
  if ax is None:
805
854
  ax = plt.gca()
806
-
807
855
  kwargs["ax"] = ax
808
856
  set_kwargs_color(kwargs)
809
857
  set_kwargs_drawstyle(kwargs)
@@ -816,10 +864,11 @@ def loglogs_plot(cls, loc=None, iloc=None, show_censors=False, censor_styles=Non
816
864
  if show_censors and cls.event_table["censored"].sum() > 0:
817
865
  cs = {"marker": "|", "ms": 12, "mew": 1}
818
866
  cs.update(censor_styles)
819
- times = dataframe_slicer(cls.event_table.loc[(cls.event_table["censored"] > 0)]).index.values.astype(float)
867
+ times = dataframe_slicer(
868
+ cls.event_table.loc[(cls.event_table["censored"] > 0)]
869
+ ).index.values.astype(float)
820
870
  v = cls.predict(times)
821
871
  ax.plot(np.log(times), loglog(v), linestyle="None", color=colour, **cs)
822
-
823
872
  # plot estimate
824
873
  sliced_estimates = dataframe_slicer(loglog(cls.survival_function_))
825
874
  sliced_estimates["log(timeline)"] = np.log(sliced_estimates.index)
@@ -848,7 +897,6 @@ def _plot_estimate(
848
897
  ax=None,
849
898
  **kwargs
850
899
  ):
851
-
852
900
  """
853
901
  Plots a pretty figure of estimates
854
902
 
@@ -904,21 +952,28 @@ def _plot_estimate(
904
952
  DeprecationWarning,
905
953
  )
906
954
  ci_only_lines = ci_force_lines
907
-
908
- plot_estimate_config = PlotEstimateConfig(cls, estimate, loc, iloc, show_censors, censor_styles, logx, ax, **kwargs)
955
+ if "point_in_time" in kwargs: # marker for given point
956
+ point_in_time = kwargs["point_in_time"]
957
+ kwargs.pop("point_in_time")
958
+ plot_estimate_config = PlotEstimateConfig(
959
+ cls, estimate, loc, iloc, show_censors, censor_styles, logx, ax, **kwargs
960
+ )
909
961
 
910
962
  dataframe_slicer = create_dataframe_slicer(iloc, loc, cls.timeline)
911
963
 
912
964
  if show_censors and cls.event_table["censored"].sum() > 0:
913
965
  cs = {"marker": "+", "ms": 12, "mew": 1}
914
966
  cs.update(plot_estimate_config.censor_styles)
915
- censored_times = dataframe_slicer(cls.event_table.loc[(cls.event_table["censored"] > 0)]).index.values.astype(float)
967
+ censored_times = dataframe_slicer(
968
+ cls.event_table.loc[(cls.event_table["censored"] > 0)]
969
+ ).index.values.astype(float)
916
970
  v = plot_estimate_config.predict_at_times(censored_times).values
917
- plot_estimate_config.ax.plot(censored_times, v, linestyle="None", color=plot_estimate_config.colour, **cs)
918
-
919
- dataframe_slicer(plot_estimate_config.estimate_).rename(columns=lambda _: plot_estimate_config.kwargs.pop("label")).plot(
920
- logx=plot_estimate_config.logx, **plot_estimate_config.kwargs
921
- )
971
+ plot_estimate_config.ax.plot(
972
+ censored_times, v, linestyle="None", color=plot_estimate_config.colour, **cs
973
+ )
974
+ dataframe_slicer(plot_estimate_config.estimate_).rename(
975
+ columns=lambda _: plot_estimate_config.kwargs.pop("label")
976
+ ).plot(logx=plot_estimate_config.logx, **plot_estimate_config.kwargs)
922
977
 
923
978
  # plot confidence intervals
924
979
  if ci_show:
@@ -940,15 +995,20 @@ def _plot_estimate(
940
995
  )
941
996
  )
942
997
  else:
943
- x = dataframe_slicer(plot_estimate_config.confidence_interval_).index.values.astype(float)
944
- lower = dataframe_slicer(plot_estimate_config.confidence_interval_.iloc[:, [0]]).values[:, 0]
945
- upper = dataframe_slicer(plot_estimate_config.confidence_interval_.iloc[:, [1]]).values[:, 0]
998
+ x = dataframe_slicer(
999
+ plot_estimate_config.confidence_interval_
1000
+ ).index.values.astype(float)
1001
+ lower = dataframe_slicer(
1002
+ plot_estimate_config.confidence_interval_.iloc[:, [0]]
1003
+ ).values[:, 0]
1004
+ upper = dataframe_slicer(
1005
+ plot_estimate_config.confidence_interval_.iloc[:, [1]]
1006
+ ).values[:, 0]
946
1007
 
947
1008
  if plot_estimate_config.kwargs["drawstyle"] == "default":
948
1009
  step = None
949
1010
  elif plot_estimate_config.kwargs["drawstyle"].startswith("step"):
950
1011
  step = plot_estimate_config.kwargs["drawstyle"].replace("steps-", "")
951
-
952
1012
  plot_estimate_config.ax.fill_between(
953
1013
  x,
954
1014
  lower,
@@ -958,23 +1018,35 @@ def _plot_estimate(
958
1018
  linewidth=0.0 if ci_no_lines else 1.0,
959
1019
  step=step,
960
1020
  )
961
-
962
1021
  if at_risk_counts:
963
1022
  add_at_risk_counts(cls, ax=plot_estimate_config.ax)
964
1023
  plt.tight_layout()
965
-
1024
+ if "point_in_time" in locals():
1025
+ plot_estimate_config.ax.scatter(
1026
+ point_in_time, cls.survival_function_at_times(point_in_time)
1027
+ )
966
1028
  return plot_estimate_config.ax
967
1029
 
968
1030
 
969
1031
  class PlotEstimateConfig:
970
- def __init__(self, cls, estimate: Union[str, pd.DataFrame], loc, iloc, show_censors, censor_styles, logx, ax, **kwargs):
1032
+ def __init__(
1033
+ self,
1034
+ cls,
1035
+ estimate: Union[str, pd.DataFrame],
1036
+ loc,
1037
+ iloc,
1038
+ show_censors,
1039
+ censor_styles,
1040
+ logx,
1041
+ ax,
1042
+ **kwargs
1043
+ ):
971
1044
  from matplotlib import pyplot as plt
972
1045
 
973
1046
  self.censor_styles = coalesce(censor_styles, {})
974
1047
 
975
1048
  if ax is None:
976
1049
  ax = plt.gca()
977
-
978
1050
  kwargs["ax"] = ax
979
1051
  set_kwargs_color(kwargs)
980
1052
  set_kwargs_drawstyle(kwargs)
@@ -311,7 +311,7 @@ def _expected_value_of_survival_squared_up_to_t(
311
311
 
312
312
  if isinstance(model_or_survival_function, pd.DataFrame):
313
313
  sf = model_or_survival_function.loc[:t]
314
- sf = sf.append(pd.DataFrame([1], index=[0], columns=sf.columns)).sort_index()
314
+ sf = pd.concat((sf, pd.DataFrame([1], index=[0], columns=sf.columns))).sort_index()
315
315
  sf_tau = sf * sf.index.values[:, None]
316
316
  return 2 * trapz(y=sf_tau.values[:, 0], x=sf_tau.index)
317
317
  elif isinstance(model_or_survival_function, lifelines.fitters.UnivariateFitter):
@@ -561,7 +561,7 @@ def _group_event_table_by_intervals(event_table, intervals) -> pd.DataFrame:
561
561
  )
562
562
  # convert columns from multiindex
563
563
  event_table.columns = event_table.columns.droplevel(1)
564
- return event_table.bfill()
564
+ return event_table.bfill().fillna(0)
565
565
 
566
566
 
567
567
  def survival_events_from_table(survival_table, observed_deaths_col="observed", censored_col="censored"):
@@ -744,9 +744,6 @@ def k_fold_cross_validation(
744
744
  results: list
745
745
  (k,1) list of scores for each fold. The scores can be anything.
746
746
 
747
- See Also
748
- ---------
749
- lifelines.utils.sklearn_adapter.sklearn_adapter
750
747
 
751
748
  """
752
749
  # Make sure fitters is a list
@@ -884,6 +881,7 @@ def _additive_estimate(events, timeline, _additive_f, _additive_var, reverse):
884
881
  population = events["at_risk"] - entrances
885
882
 
886
883
  estimate_ = np.cumsum(_additive_f(population, deaths))
884
+
887
885
  var_ = np.cumsum(_additive_var(population, deaths))
888
886
 
889
887
  timeline = sorted(timeline)
@@ -1908,7 +1906,7 @@ class CovariateParameterMappings:
1908
1906
 
1909
1907
  Xs = {}
1910
1908
  for param_name, transform in self.mappings.items():
1911
- if isinstance(transform, formulaic.formula.Formula):
1909
+ if isinstance(transform, formulaic.ModelSpec):
1912
1910
  X = transform.get_model_matrix(df)
1913
1911
  elif isinstance(transform, list):
1914
1912
  if self.force_intercept:
@@ -1954,6 +1952,6 @@ class CovariateParameterMappings:
1954
1952
  if self.force_intercept:
1955
1953
  formula += "+ 1"
1956
1954
 
1957
- design_info = formulaic.Formula(formula)
1955
+ design_info = formulaic.ModelSpec.from_spec(formulaic.Formula(formula).get_model_matrix(df))
1958
1956
 
1959
1957
  return design_info
lifelines/version.py CHANGED
@@ -1,4 +1,4 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  from __future__ import unicode_literals
3
3
 
4
- __version__ = "0.27.7"
4
+ __version__ = "0.28.0"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lifelines
3
- Version: 0.27.7
3
+ Version: 0.28.0
4
4
  Summary: Survival analysis in Python, including Kaplan Meier, Nelson Aalen and regression
5
5
  Home-page: https://github.com/CamDavidsonPilon/lifelines
6
6
  Author: Cameron Davidson-Pilon
@@ -9,22 +9,20 @@ License: MIT
9
9
  Classifier: Development Status :: 4 - Beta
10
10
  Classifier: License :: OSI Approved :: MIT License
11
11
  Classifier: Programming Language :: Python
12
- Classifier: Programming Language :: Python :: 3.7
13
- Classifier: Programming Language :: Python :: 3.8
14
12
  Classifier: Programming Language :: Python :: 3.9
15
13
  Classifier: Programming Language :: Python :: 3.10
16
14
  Classifier: Programming Language :: Python :: 3.11
17
15
  Classifier: Topic :: Scientific/Engineering
18
- Requires-Python: >=3.7
16
+ Requires-Python: >=3.9
19
17
  Description-Content-Type: text/markdown
20
18
  License-File: LICENSE
21
- Requires-Dist: numpy (>=1.14.0)
22
- Requires-Dist: scipy (>=1.2.0)
23
- Requires-Dist: pandas (>=1.0.0)
24
- Requires-Dist: matplotlib (>=3.0)
25
- Requires-Dist: autograd (>=1.5)
26
- Requires-Dist: autograd-gamma (>=0.3)
27
- Requires-Dist: formulaic (>=0.2.2)
19
+ Requires-Dist: numpy <2.0,>=1.14.0
20
+ Requires-Dist: scipy >=1.2.0
21
+ Requires-Dist: pandas >=1.2.0
22
+ Requires-Dist: matplotlib >=3.0
23
+ Requires-Dist: autograd >=1.5
24
+ Requires-Dist: autograd-gamma >=0.3
25
+ Requires-Dist: formulaic >=0.2.2
28
26
 
29
27
  ![](http://i.imgur.com/EOowdSD.png)
30
28