lifelines 0.27.8__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
lifelines/exceptions.py CHANGED
@@ -5,6 +5,10 @@ class StatError(Exception):
     pass
 
 
+class ProportionalHazardAssumptionError(Exception):
+    pass
+
+
 class ConvergenceError(ValueError):
     # inherits from ValueError for backwards compatibility reasons
     def __init__(self, msg, original_exception=""):
@@ -550,7 +550,7 @@ class ParametricUnivariateFitter(UnivariateFitter):
         minimizing_results, previous_results, minimizing_ll = None, None, np.inf
         for method, option in zip(
             ["Nelder-Mead", self._scipy_fit_method],
-            [{"maxiter": 100}, {**{"disp": show_progress}, **self._scipy_fit_options, **fit_options}],
+            [{"maxiter": 400}, {**{"disp": show_progress}, **self._scipy_fit_options, **fit_options}],
         ):
 
             initial_value = self._initial_values if previous_results is None else utils._to_1d_array(previous_results.x)
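Note: the Nelder-Mead warm-up pass now gets 400 iterations instead of 100 before its result is handed to the fitter's main scipy method, and user-supplied options are still merged in last, so they override the defaults. A minimal sketch, assuming the public `fit_options` keyword of `fit()` and the bundled Waltons dataset:

    from lifelines import WeibullFitter
    from lifelines.datasets import load_waltons

    df = load_waltons()

    wf = WeibullFitter()
    # options passed here are merged after the library defaults,
    # so they win on any conflicting keys (e.g. maxiter)
    wf.fit(df["T"], df["E"], fit_options={"maxiter": 1000})
    print(wf.lambda_, wf.rho_)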
@@ -1409,7 +1409,7 @@ class ParametricRegressionFitter(RegressionFitter):
     def _survival_function(self, params, T, Xs):
         return anp.clip(anp.exp(-self._cumulative_hazard(params, T, Xs)), 1e-12, 1 - 1e-12)
 
-    def _log_likelihood_right_censoring(self, params, Ts, E, W, entries, Xs) -> float:
+    def _log_likelihood_right_censoring(self, params, Ts: tuple, E, W, entries, Xs) -> float:
 
         T = Ts[0]
         non_zero_entries = entries > 0
@@ -3365,6 +3365,8 @@ class ParametericAFTRegressionFitter(ParametricRegressionFitter):
             also display the baseline survival, defined as the survival at the mean of the original dataset.
         times: iterable
             pass in a times to plot
+        y: str
+            one of "survival_function", "hazard", "cumulative_hazard". Default "survival_function"
         kwargs:
             pass in additional plotting commands
 
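Note: this hunk documents a `y` argument for one of the AFT plotting methods; the surrounding parameters suggest `plot_partial_effects_on_outcome`, though the method name is not shown here. A hedged sketch assuming that method and the bundled Rossi dataset:

    from lifelines import WeibullAFTFitter
    from lifelines.datasets import load_rossi

    rossi = load_rossi()
    aft = WeibullAFTFitter().fit(rossi, duration_col="week", event_col="arrest")

    # plot the partial effect of `prio` on the cumulative hazard rather than
    # the default survival function, via the newly documented `y` argument
    aft.plot_partial_effects_on_outcome("prio", values=[0, 2, 5, 10], y="cumulative_hazard")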
 
@@ -72,7 +72,14 @@ class BreslowFlemingHarringtonFitter(NonParametricUnivariateFitter):
         alpha = coalesce(alpha, self.alpha)
 
         naf = NelsonAalenFitter(alpha=alpha)
-        naf.fit(durations, event_observed=event_observed, timeline=timeline, label=self._label, entry=entry, ci_labels=ci_labels)
+        naf.fit(
+            durations,
+            event_observed=event_observed,
+            timeline=timeline,
+            label=self._label,
+            entry=entry,
+            ci_labels=ci_labels,
+        )
         self.durations, self.event_observed, self.timeline, self.entry, self.event_table, self.weights = (
             naf.durations,
             naf.event_observed,
@@ -87,6 +94,7 @@ class BreslowFlemingHarringtonFitter(NonParametricUnivariateFitter):
         self.confidence_interval_ = np.exp(-naf.confidence_interval_)
         self.confidence_interval_survival_function_ = self.confidence_interval_
         self.confidence_interval_cumulative_density = 1 - self.confidence_interval_
+        self.confidence_interval_cumulative_density[:] = np.fliplr(self.confidence_interval_cumulative_density.values)
 
         # estimation methods
         self._estimation_method = "survival_function_"
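Note: the new `np.fliplr` line addresses the column order of the cumulative-density confidence interval. Subtracting a (lower, upper) interval from 1 leaves the larger values in the column labelled as the lower bound; flipping the columns restores the pairing. A toy illustration of the arithmetic (the column names are made up):

    import numpy as np
    import pandas as pd

    # pretend survival CI: first column = lower bound, second = upper bound
    ci = pd.DataFrame({"lower_0.95": [0.80, 0.60], "upper_0.95": [0.90, 0.75]})

    cd = 1 - ci                    # bounds are now crossed: the "lower" column holds the larger values
    cd[:] = np.fliplr(cd.values)   # swap the values back so each row reads (lower, upper) again
    print(cd)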
@@ -80,7 +80,7 @@ class CoxPHFitter(RegressionFitter, ProportionalHazardMixin):
        When ``baseline_estimation_method="spline"``, this allows customizing the points in the time axis for the baseline hazard curve.
        To use evenly-spaced points in time, the ``n_baseline_knots`` parameter can be employed instead.
 
-    breakpoints: int
+    breakpoints: list, optional
        Used when ``baseline_estimation_method="piecewise"``. Set the positions of the baseline hazard breakpoints.
 
     Examples
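Note: with the corrected type, `breakpoints` is a list of times rather than an int. A minimal sketch of the piecewise baseline usage (the breakpoint values and dataset are chosen arbitrarily for illustration):

    from lifelines import CoxPHFitter
    from lifelines.datasets import load_rossi

    rossi = load_rossi()
    # breakpoints are positions on the time axis, not a count
    cph = CoxPHFitter(baseline_estimation_method="piecewise", breakpoints=[15, 30])
    cph.fit(rossi, duration_col="week", event_col="arrest")
    cph.print_summary()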
@@ -105,6 +105,7 @@ class GeneralizedGammaFitter(KnownModelParametricUnivariateFitter):
     """
 
     _scipy_fit_method = "SLSQP"
+    _scipy_fit_options = {"maxiter": 10_000, "maxfev": 10_000}
     _fitted_parameter_names = ["mu_", "ln_sigma_", "lambda_"]
     _bounds = [(None, None), (None, None), (None, None)]
     _compare_to_values = np.array([0.0, 0.0, 1.0])
@@ -117,14 +118,14 @@ class GeneralizedGammaFitter(KnownModelParametricUnivariateFitter):
         elif CensoringType.is_interval_censoring(self):
             # this fails if Ts[1] == Ts[0], so we add a some fudge factors.
             log_data = log(Ts[1] - Ts[0] + 0.1)
-            return np.array([log_data.mean(), log(log_data.std() + 0.01), 0.1])
+            return np.array([log_data.mean() * 1.5, log(log_data.std() + 0.1), 1.0])
 
     def _cumulative_hazard(self, params, times):
         mu_, ln_sigma_, lambda_ = params
 
         sigma_ = safe_exp(ln_sigma_)
         Z = (log(times) - mu_) / sigma_
-        ilambda_2 = 1 / lambda_ ** 2
+        ilambda_2 = 1 / lambda_**2
         clipped_exp = np.clip(safe_exp(lambda_ * Z) * ilambda_2, 1e-300, 1e20)
 
         if lambda_ > 0:
@@ -137,7 +138,7 @@ class GeneralizedGammaFitter(KnownModelParametricUnivariateFitter):
 
     def _log_hazard(self, params, times):
         mu_, ln_sigma_, lambda_ = params
-        ilambda_2 = 1 / lambda_ ** 2
+        ilambda_2 = 1 / lambda_**2
         Z = (log(times) - mu_) / safe_exp(ln_sigma_)
         clipped_exp = np.clip(safe_exp(lambda_ * Z) * ilambda_2, 1e-300, 1e20)
         if lambda_ > 0:
@@ -171,5 +172,5 @@ class GeneralizedGammaFitter(KnownModelParametricUnivariateFitter):
         sigma_ = exp(self.ln_sigma_)
 
         if lambda_ > 0:
-            return exp(sigma_ * log(gammainccinv(1 / lambda_ ** 2, p) * lambda_ ** 2) / lambda_) * exp(self.mu_)
-        return exp(sigma_ * log(gammaincinv(1 / lambda_ ** 2, p) * lambda_ ** 2) / lambda_) * exp(self.mu_)
+            return exp(sigma_ * log(gammainccinv(1 / lambda_**2, p) * lambda_**2) / lambda_) * exp(self.mu_)
+        return exp(sigma_ * log(gammaincinv(1 / lambda_**2, p) * lambda_**2) / lambda_) * exp(self.mu_)
@@ -351,9 +351,14 @@ class KaplanMeierFitter(NonParametricUnivariateFitter):
         primary_estimate_name = "survival_function_"
         secondary_estimate_name = "cumulative_density_"
 
-        (self.durations, self.event_observed, self.timeline, self.entry, self.event_table, self.weights) = _preprocess_inputs(
-            durations, event_observed, timeline, entry, weights
-        )
+        (
+            self.durations,
+            self.event_observed,
+            self.timeline,
+            self.entry,
+            self.event_table,
+            self.weights,
+        ) = _preprocess_inputs(durations, event_observed, timeline, entry, weights)
 
         alpha = alpha if alpha else self.alpha
         log_estimate, cumulative_sq_ = _additive_estimate(
@@ -386,6 +391,7 @@ class KaplanMeierFitter(NonParametricUnivariateFitter):
 
         self.confidence_interval_survival_function_ = self.confidence_interval_
         self.confidence_interval_cumulative_density_ = 1 - self.confidence_interval_
+        self.confidence_interval_cumulative_density_[:] = np.fliplr(self.confidence_interval_cumulative_density_.values)
         self._median = median_survival_times(self.survival_function_)
         self._cumulative_sq_ = cumulative_sq_
 
@@ -4,6 +4,7 @@ from textwrap import dedent, fill
 from autograd import numpy as anp
 import numpy as np
 from pandas import DataFrame, Series
+from lifelines.exceptions import ProportionalHazardAssumptionError
 from lifelines.statistics import proportional_hazard_test, TimeTransformers
 from lifelines.utils import format_p_value
 from lifelines.utils.lowess import lowess
@@ -28,6 +29,7 @@ class ProportionalHazardMixin:
         p_value_threshold: float = 0.01,
         plot_n_bootstraps: int = 15,
         columns: Optional[List[str]] = None,
+        raise_on_fail: bool = False,
     ) -> None:
         """
         Use this function to test the proportional hazards assumption. See usage example at
@@ -51,6 +53,8 @@ class ProportionalHazardMixin:
            the function significantly.
        columns: list, optional
            specify a subset of columns to test.
+       raise_on_fail: bool, optional
+           throw a ``ProportionalHazardAssumptionError`` if the test fails. Default: False.
 
        Returns
        --------
@@ -107,7 +111,7 @@ class ProportionalHazardMixin:
 
         for variable in self.params_.index.intersection(columns or self.params_.index):
             minumum_observed_p_value = test_results.summary.loc[variable, "p"].min()
-
+
             # plot is done (regardless of test result) whenever `show_plots = True`
             if show_plots:
                 axes.append([])
@@ -224,9 +228,8 @@ class ProportionalHazardMixin:
                     ),
                     end="\n\n",
                 )
-            #################
+                #################
 
-
         if advice and counter > 0:
             print(
                 dedent(
@@ -243,6 +246,8 @@ class ProportionalHazardMixin:
 
         if counter == 0:
             print("Proportional hazard assumption looks okay.")
+        elif raise_on_fail:
+            raise ProportionalHazardAssumptionError()
         return axes
 
     @property
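Note: together with the new exception in lifelines/exceptions.py above, `raise_on_fail=True` turns the printed advice into something callers can branch on. A small sketch using well-known lifelines APIs (the strata choice is arbitrary):

    from lifelines import CoxPHFitter
    from lifelines.datasets import load_rossi
    from lifelines.exceptions import ProportionalHazardAssumptionError

    rossi = load_rossi()
    cph = CoxPHFitter().fit(rossi, duration_col="week", event_col="arrest")

    try:
        cph.check_assumptions(rossi, p_value_threshold=0.05, raise_on_fail=True)
    except ProportionalHazardAssumptionError:
        # react programmatically (e.g. refit with strata) instead of parsing printed advice
        cph = CoxPHFitter().fit(rossi, duration_col="week", event_col="arrest", strata=["wexp"])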
@@ -183,7 +183,7 @@ class NelsonAalenFitter(UnivariateFitter):
         )
 
     def _variance_f_discrete(self, population, deaths):
-        return (population - deaths) * deaths / population ** 3
+        return (1 - deaths / population) * (deaths / population) * (1.0 / population)
 
     def _additive_f_smooth(self, population, deaths):
         cum_ = np.cumsum(1.0 / np.arange(1, np.max(population) + 1))
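Note: the two expressions for the discrete-time variance are algebraically identical, (p - d) * d / p**3 == (1 - d/p) * (d/p) * (1/p); the rewritten form presumably avoids forming p**3, which can overflow integer risk-set counts. A quick check of the identity and the failure mode it sidesteps (the sample numbers are arbitrary):

    import numpy as np

    population = np.array([3_000_000], dtype=np.int64)   # large risk set
    deaths = np.array([1_500], dtype=np.int64)

    old = (population - deaths) * deaths / population ** 3
    new = (1 - deaths / population) * (deaths / population) * (1.0 / population)

    # the first result is wrong because population**3 overflowed int64; the second is ~1.67e-10
    print(old, new)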
@@ -239,7 +239,7 @@ class NelsonAalenFitter(UnivariateFitter):
         C = var_hazard_.values != 0.0  # only consider the points with jumps
         std_hazard_ = np.sqrt(
             1.0
-            / (bandwidth ** 2)
+            / (bandwidth**2)
             * np.dot(epanechnikov_kernel(timeline[:, None], timeline[C][None, :], bandwidth) ** 2, var_hazard_.values[C])
         )
         values = {
@@ -311,7 +311,7 @@ def _expected_value_of_survival_squared_up_to_t(
 
     if isinstance(model_or_survival_function, pd.DataFrame):
         sf = model_or_survival_function.loc[:t]
-        sf = sf.append(pd.DataFrame([1], index=[0], columns=sf.columns)).sort_index()
+        sf = pd.concat((sf, pd.DataFrame([1], index=[0], columns=sf.columns))).sort_index()
         sf_tau = sf * sf.index.values[:, None]
         return 2 * trapz(y=sf_tau.values[:, 0], x=sf_tau.index)
     elif isinstance(model_or_survival_function, lifelines.fitters.UnivariateFitter):
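Note: `DataFrame.append` was removed in pandas 2.0, so prepending the row at t=0 now goes through `pd.concat`. The equivalent standalone pattern, with a made-up survival column for illustration:

    import pandas as pd

    sf = pd.DataFrame({"KM_estimate": [0.9, 0.7]}, index=[2.0, 5.0])
    row_at_zero = pd.DataFrame([1], index=[0], columns=sf.columns)

    # pandas >= 2.0 has no DataFrame.append; pd.concat is the drop-in replacement here
    sf = pd.concat((sf, row_at_zero)).sort_index()
    print(sf)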
@@ -561,7 +561,7 @@ def _group_event_table_by_intervals(event_table, intervals) -> pd.DataFrame:
     )
     # convert columns from multiindex
     event_table.columns = event_table.columns.droplevel(1)
-    return event_table.bfill()
+    return event_table.bfill().fillna(0)
 
 
 def survival_events_from_table(survival_table, observed_deaths_col="observed", censored_col="censored"):
@@ -744,9 +744,6 @@ def k_fold_cross_validation(
     results: list
        (k,1) list of scores for each fold. The scores can be anything.
 
-    See Also
-    ---------
-    lifelines.utils.sklearn_adapter.sklearn_adapter
 
     """
     # Make sure fitters is a list
@@ -884,6 +881,7 @@ def _additive_estimate(events, timeline, _additive_f, _additive_var, reverse):
     population = events["at_risk"] - entrances
 
     estimate_ = np.cumsum(_additive_f(population, deaths))
+
     var_ = np.cumsum(_additive_var(population, deaths))
 
     timeline = sorted(timeline)
lifelines/version.py CHANGED
@@ -1,4 +1,4 @@
 # -*- coding: utf-8 -*-
 from __future__ import unicode_literals
 
-__version__ = "0.27.8"
+__version__ = "0.28.0"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: lifelines
-Version: 0.27.8
+Version: 0.28.0
 Summary: Survival analysis in Python, including Kaplan Meier, Nelson Aalen and regression
 Home-page: https://github.com/CamDavidsonPilon/lifelines
 Author: Cameron Davidson-Pilon
@@ -9,18 +9,16 @@ License: MIT
 Classifier: Development Status :: 4 - Beta
 Classifier: License :: OSI Approved :: MIT License
 Classifier: Programming Language :: Python
-Classifier: Programming Language :: Python :: 3.7
-Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Topic :: Scientific/Engineering
-Requires-Python: >=3.7
+Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: numpy <2.0,>=1.14.0
 Requires-Dist: scipy >=1.2.0
-Requires-Dist: pandas >=1.0.0
+Requires-Dist: pandas >=1.2.0
 Requires-Dist: matplotlib >=3.0
 Requires-Dist: autograd >=1.5
 Requires-Dist: autograd-gamma >=0.3
@@ -1,10 +1,10 @@
 lifelines/__init__.py,sha256=F_sKrawq6L4GwTPgOu4FjoGUKQ2gfelAVIQOW1Ee8Ao,2241
 lifelines/calibration.py,sha256=Luii7bkJ2YB0jpuOYYhI22aUyEc1gLsS10Pno6Sqo98,4113
-lifelines/exceptions.py,sha256=Kf6GN2vB-SHde2mbPomj2PhpnCvCBOSTUZLY1jwOw-U,514
+lifelines/exceptions.py,sha256=8T1vQuI6Fnf_5OfiJahksn5Soe-SmU9Y2IA7HYen460,577
 lifelines/generate_datasets.py,sha256=WsvvrZt0jEYQ7-Fp53vrCq7MzmAM2pPUSoCaiQRwN5g,10155
 lifelines/plotting.py,sha256=sQmwpSziHzVQoWoe_gll4LInrjg-E4FpeWMp07wurNo,35069
 lifelines/statistics.py,sha256=cOszUYz87elnbTAt6V3fTrHwPjB9HFI1hxjLvKypS6k,35129
-lifelines/version.py,sha256=i_7hf_ze0p4EVBems7NVh3EH_hYHAeUYD_J_LE6mR50,88
+lifelines/version.py,sha256=a1Hb9Vpjox_LGg1wAeN82zL0efHoU5UicQ_OpuCT9_E,88
 lifelines/datasets/CuZn-LeftCensoredDataset.csv,sha256=PxTdZcJPPbhtaadpHjhMFVcUxmSn84BuDarujZIJpm4,1996
 lifelines/datasets/__init__.py,sha256=dhFp0uvLVBoAPBNSziknPpNc-ML9Ega6X2yL2UJHQ1M,19976
 lifelines/datasets/anderson.csv,sha256=nTAtTK8mf0ymU88nKvO2Fj0WL9SE9o4S0GVujmX8Cl4,580
@@ -35,39 +35,38 @@ lifelines/datasets/rossi.csv,sha256=AhRAAXDgfzAVooXtyiAUysDa6KrBJfy6rWQkkOBfiSw,
 lifelines/datasets/stanford_heart.csv,sha256=HWS9SqJjQ6gDmvxxKCJLR1cOIJ8XKuwTNu4bW8tKWVM,8859
 lifelines/datasets/static_test.csv,sha256=w2PtSkXknCZfciwqcOZGlA8znBO7jTcq_AJ5e6NStAk,101
 lifelines/datasets/waltons_dataset.csv,sha256=Fd4UX6tGYxgGhXtH3T-S81wIGIbVohv5yom4aw0kXL8,2449
-lifelines/fitters/__init__.py,sha256=_bW0VgluvRFHfd_wn4NX4nTqSL2F0O7V8YeK12rhpos,151518
+lifelines/fitters/__init__.py,sha256=ZqJoIOtP_-esET_V_SEnlHhOWyVw-JKnYgj7CiF10pA,151639
 lifelines/fitters/aalen_additive_fitter.py,sha256=vRQb38weMcknyxC9bJwiALwCzxmJ5DsEZwHkz2zV93k,21518
 lifelines/fitters/aalen_johansen_fitter.py,sha256=w_2MV7Bbtr0swJ0VdySqirhlGsjbYyqduRx9iLKd6XA,14172
-lifelines/fitters/breslow_fleming_harrington_fitter.py,sha256=Te1Y73lIIKhTC6yMADe35RVHI4XOLF17ub-N8oudS4I,4091
+lifelines/fitters/breslow_fleming_harrington_fitter.py,sha256=_86qU3wMHEyuCKLjhHLERP_ymNnlSvi7chWgi8Kygxg,4293
 lifelines/fitters/cox_time_varying_fitter.py,sha256=i8_mmJZm0VjHnX7wZYeLwMgpJryr1hfd69iRwaBn33Q,34656
-lifelines/fitters/coxph_fitter.py,sha256=uv_e6wR0o3gyZcaNOYbeeZhgfaPXQhtcM-nc3-9kxAg,136853
+lifelines/fitters/coxph_fitter.py,sha256=U8k0mEHn0xsZ-akhUHeBhLsOZvKubAsyK6_JDyO5heE,136864
 lifelines/fitters/crc_spline_fitter.py,sha256=FUaiz4O-Hdke7T5dV8RCl-27oWxrMJLBSXxnRN4QkGQ,3126
 lifelines/fitters/exponential_fitter.py,sha256=Fbb1rtBOrHb_YxFYidzqXcFw7aWsqet_2vqi7s8WJ4U,2857
-lifelines/fitters/generalized_gamma_fitter.py,sha256=FUGff4DBhTqZ4woAhXpcH-YMF0L_CbdzGkAoT6TXcGI,6426
+lifelines/fitters/generalized_gamma_fitter.py,sha256=OiXO9onvYtI2gNvUoxF4mjEjbj7IRZl5R4UZ_RzrSjo,6482
 lifelines/fitters/generalized_gamma_regression_fitter.py,sha256=UzG3dVau0UNdQtM6yW63wabDf7j--rxrdE9AlaVB8Vk,7955
-lifelines/fitters/kaplan_meier_fitter.py,sha256=079URtbAAwae1SkzgTi-nCOhfJklfogMiOFiyUdUaYg,24027
+lifelines/fitters/kaplan_meier_fitter.py,sha256=UYPJi4BYcn54F26fc_lkkYzcZV-yUomsBB59ufdLRF8,24209
 lifelines/fitters/log_logistic_aft_fitter.py,sha256=cw179z0_IqvuWgOORHSZ1lBiidHcYkiO4hDi4YDEqRo,7074
 lifelines/fitters/log_logistic_fitter.py,sha256=iTH97i9TrLp5IVBIZHC8nx5rvSn2-KM-wfv1wR_YSPU,4004
 lifelines/fitters/log_normal_aft_fitter.py,sha256=aOcdMR8T4vhy2BKGebrpEJD_lTZIQQ5VsrnuuKkU0RA,7890
 lifelines/fitters/log_normal_fitter.py,sha256=NLn1DCxJ9WJrVaairJPcOu_lShko_-vwoXw6goRR42w,3557
-lifelines/fitters/mixins.py,sha256=6k5-g8cit8ODbU7PbVD9tfYsY0jpde0HID3wJQ5kz1k,12527
+lifelines/fitters/mixins.py,sha256=mLfRxHv_Mgyyp_Lw6HiQNitI9gJyTRIFSG2OjvvcFnk,12827
 lifelines/fitters/mixture_cure_fitter.py,sha256=UetFlv9EfFYMDt95M2iR354lna5RKeWtO_lkoaMmoZE,5416
-lifelines/fitters/nelson_aalen_fitter.py,sha256=UNlEX5wR6xsUmEmJ2n2MEqblz-KvGmvlh8eGHfuQf6Y,10666
+lifelines/fitters/nelson_aalen_fitter.py,sha256=QSE6E0ia6-TeHMIoMyo6nTmq8MHM21CgoUOdH7d1QFE,10686
 lifelines/fitters/npmle.py,sha256=HV3yeu1byVv5oBSdv5TuLUg2X5NUxydxj8-h_xYScB0,10143
 lifelines/fitters/piecewise_exponential_fitter.py,sha256=j48sXaEODClFmfFP3THb0qJ3_Q7ctJz19j50Uo1QJME,3357
 lifelines/fitters/piecewise_exponential_regression_fitter.py,sha256=JuGm93cKQBu6KBTHEOoheLJfMqP0h1ckeQjMIpC8aQo,4978
 lifelines/fitters/spline_fitter.py,sha256=TnkXPBabgZVqtI90T1-gm6C8k73WhQMrhbEAZw1OX0c,4214
 lifelines/fitters/weibull_aft_fitter.py,sha256=6wtU499AvXxZAE9PdnNQnbzh_NpPcdAEL6zd3xRV8hU,7772
 lifelines/fitters/weibull_fitter.py,sha256=CcII_V5ns00jP5sqv0dn8Yo0T3kdyc4Rkpb2bBuTvjU,3771
-lifelines/utils/__init__.py,sha256=e_hkwdPsn3SWoDyJeXoRh7oVU2TZwg2iSUUJZjoLKyM,70490
+lifelines/utils/__init__.py,sha256=qEAVyYZCAvNInTSp4qvFxGFDhz_aKPp2NcRAqvr_1xA,70428
 lifelines/utils/btree.py,sha256=yevaIsGw_tQsGauXmwBHTMgCBjuuMZQgdHa-nCB-q2I,4369
 lifelines/utils/concordance.py,sha256=hWXrmg1BiK2Hqu9CRzlvkPlnlmZqZcAxH7L1PjaqdC8,12245
 lifelines/utils/lowess.py,sha256=MMydVcnbxqIgsiNcIgVUFtlFycD7v3ezwEGpituvBHs,2541
 lifelines/utils/printer.py,sha256=-nXxu02gs0kaKfoQQ65sH-I45tGmgoFeOOIUSEc53iE,5861
 lifelines/utils/safe_exp.py,sha256=HCCAkwQTx6G2qRC03v9Q_GWqVj8at1Eac1JVrMgS9hg,4350
-lifelines/utils/sklearn_adapter.py,sha256=S5qotbZ1hf1fhFBsx39Yd_NpA31jB9HhRiLjE8TWlhw,4202
-lifelines-0.27.8.dist-info/LICENSE,sha256=AasDeD139SnTdfXbKgN4BMyMgBlRy9YFs60tNrB4wf0,1079
-lifelines-0.27.8.dist-info/METADATA,sha256=_F3epTvxvgQlotdOsNhcL05k_jcDz3WofMBJE-xrXf0,3288
-lifelines-0.27.8.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
-lifelines-0.27.8.dist-info/top_level.txt,sha256=3i57Z4mtpc6jWrsW0n-_o9Y7CpzytMTeLMPJBHYAo0o,10
-lifelines-0.27.8.dist-info/RECORD,,
+lifelines-0.28.0.dist-info/LICENSE,sha256=AasDeD139SnTdfXbKgN4BMyMgBlRy9YFs60tNrB4wf0,1079
+lifelines-0.28.0.dist-info/METADATA,sha256=6zMJUnOh6TWoVtj87HGTmlPCsIopdr9jOeu_2wKmoP4,3188
+lifelines-0.28.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+lifelines-0.28.0.dist-info/top_level.txt,sha256=3i57Z4mtpc6jWrsW0n-_o9Y7CpzytMTeLMPJBHYAo0o,10
+lifelines-0.28.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.41.2)
+Generator: bdist_wheel (0.42.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
@@ -1,135 +0,0 @@
-# -*- coding: utf-8 -*-
-import inspect
-import pandas as pd
-
-try:
-    from sklearn.base import BaseEstimator, RegressorMixin, MetaEstimatorMixin
-except ImportError:
-    raise ImportError("scikit-learn must be installed on the local system to use this utility class.")
-from . import concordance_index
-
-__all__ = ["sklearn_adapter"]
-
-
-def filter_kwargs(f, kwargs):
-    s = inspect.signature(f)
-    res = {k: kwargs[k] for k in s.parameters if k in kwargs}
-    return res
-
-
-class _SklearnModel(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
-    def __init__(self, **kwargs):
-        self._params = kwargs
-        self.lifelines_model = self.lifelines_model(**filter_kwargs(self.lifelines_model.__init__, self._params))
-        self._params["duration_col"] = "duration_col"
-        self._params["event_col"] = self._event_col
-
-    @property
-    def _yColumn(self):
-        return self._params["duration_col"]
-
-    @property
-    def _eventColumn(self):
-        return self._params["event_col"]
-
-    def fit(self, X, y=None):
-        """
-
-        Parameters
-        -----------
-
-        X: DataFrame
-            must be a pandas DataFrame (with event_col included, if applicable)
-
-        """
-        if not isinstance(X, pd.DataFrame):
-            raise ValueError("X must be a DataFrame. Got type: {}".format(type(X)))
-
-        X = X.copy()
-
-        if y is not None:
-            X.insert(len(X.columns), self._yColumn, y, allow_duplicates=False)
-
-        fit = getattr(self.lifelines_model, self._fit_method)
-        self.lifelines_model = fit(df=X, **filter_kwargs(fit, self._params))
-        return self
-
-    def set_params(self, **params):
-        for key, value in params.items():
-            setattr(self.lifelines_model, key, value)
-        return self
-
-    def get_params(self, deep=True):
-        out = {}
-        for name, p in inspect.signature(self.lifelines_model.__init__).parameters.items():
-            if p.kind < 4:  # ignore kwargs
-                out[name] = getattr(self.lifelines_model, name)
-        return out
-
-    def predict(self, X, **kwargs):
-        """
-        Parameters
-        ------------
-        X: DataFrame or numpy array
-
-        """
-        predictions = getattr(self.lifelines_model, self._predict_method)(X, **kwargs).squeeze().values
-        return predictions
-
-    def score(self, X, y, **kwargs):
-        """
-
-        Parameters
-        -----------
-
-        X: DataFrame
-            must be a pandas DataFrame (with event_col included, if applicable)
-
-        """
-        rest_columns = list(set(X.columns) - {self._yColumn, self._eventColumn})
-
-        x = X.loc[:, rest_columns]
-        e = X.loc[:, self._eventColumn] if self._eventColumn else None
-
-        if y is None:
-            y = X.loc[:, self._yColumn]
-
-        if callable(self._scoring_method):
-            res = self._scoring_method(y, self.predict(x, **kwargs), event_observed=e)
-        else:
-            raise ValueError()
-        return res
-
-
-def sklearn_adapter(fitter, event_col=None, predict_method="predict_expectation", scoring_method=concordance_index):
-    """
-    This function wraps lifelines models into a scikit-learn compatible API. The function returns a
-    class that can be instantiated with parameters (similar to a scikit-learn class).
-
-    Parameters
-    ----------
-
-    fitter: class
-        The class (not an instance) to be wrapper. Example: ``CoxPHFitter`` or ``WeibullAFTFitter``
-    event_col: string
-        The column in your DataFrame that represents (if applicable) the event column
-    predict_method: string
-        Can be the string ``"predict_median", "predict_expectation"``
-    scoring_method: function
-        Provide a way to produce a ``score`` on the scikit-learn model. Signature should look like (durations, predictions, event_observed=None)
-
-    """
-    name = "SkLearn" + fitter.__name__
-    klass = type(
-        name,
-        (_SklearnModel,),
-        {
-            "lifelines_model": fitter,
-            "_event_col": event_col,
-            "_predict_method": predict_method,
-            "_fit_method": "fit",
-            "_scoring_method": staticmethod(scoring_method),
-        },
-    )
-    globals()[klass.__name__] = klass
-    return klass
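Note: the `sklearn_adapter` module is deleted in 0.28.0 with no replacement file added (see the RECORD changes above), and its `See Also` reference in `k_fold_cross_validation` was dropped earlier in this diff. For cross-validated scoring without scikit-learn, one in-library alternative is `lifelines.utils.k_fold_cross_validation`; a hedged sketch (the choice of `scoring_method` value is assumed to still be supported):

    from lifelines import CoxPHFitter
    from lifelines.datasets import load_rossi
    from lifelines.utils import k_fold_cross_validation

    rossi = load_rossi()
    scores = k_fold_cross_validation(
        CoxPHFitter(),
        rossi,
        duration_col="week",
        event_col="arrest",
        k=3,
        scoring_method="concordance_index",
    )
    print(scores)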