lifelines 0.27.8__tar.gz → 0.28.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. {lifelines-0.27.8/lifelines.egg-info → lifelines-0.28.0}/PKG-INFO +9 -4
  2. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/exceptions.py +4 -0
  3. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/__init__.py +4 -2
  4. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/breslow_fleming_harrington_fitter.py +9 -1
  5. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/coxph_fitter.py +1 -1
  6. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/generalized_gamma_fitter.py +6 -5
  7. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/kaplan_meier_fitter.py +9 -3
  8. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/mixins.py +8 -3
  9. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/nelson_aalen_fitter.py +2 -2
  10. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/tests/test_estimation.py +77 -10
  11. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/tests/utils/test_utils.py +15 -118
  12. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/utils/__init__.py +3 -5
  13. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/version.py +1 -1
  14. {lifelines-0.27.8 → lifelines-0.28.0/lifelines.egg-info}/PKG-INFO +9 -4
  15. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines.egg-info/SOURCES.txt +1 -3
  16. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines.egg-info/requires.txt +1 -1
  17. {lifelines-0.27.8 → lifelines-0.28.0}/reqs/base-requirements.txt +1 -1
  18. lifelines-0.28.0/reqs/docs-requirements.txt +7 -0
  19. {lifelines-0.27.8 → lifelines-0.28.0}/setup.py +1 -3
  20. lifelines-0.27.8/lifelines/utils/sklearn_adapter.py +0 -135
  21. lifelines-0.27.8/reqs/docs-requirements.txt +0 -7
  22. lifelines-0.27.8/reqs/travis-requirements.txt +0 -5
  23. {lifelines-0.27.8 → lifelines-0.28.0}/LICENSE +0 -0
  24. {lifelines-0.27.8 → lifelines-0.28.0}/MANIFEST.in +0 -0
  25. {lifelines-0.27.8 → lifelines-0.28.0}/README.md +0 -0
  26. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/__init__.py +0 -0
  27. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/calibration.py +0 -0
  28. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/CuZn-LeftCensoredDataset.csv +0 -0
  29. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/__init__.py +0 -0
  30. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/anderson.csv +0 -0
  31. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/c_botulinum_lag_phase.csv +0 -0
  32. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/canadian_senators.csv +0 -0
  33. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/dd.csv +0 -0
  34. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/dfcv_dataset.py +0 -0
  35. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/divorce.dat +0 -0
  36. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/g3.csv +0 -0
  37. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/gbsg2.csv +0 -0
  38. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/gehan.dat +0 -0
  39. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/holly_molly_polly.tsv +0 -0
  40. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/interval_diabetes.csv +0 -0
  41. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/kidney_transplant.csv +0 -0
  42. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/larynx.csv +0 -0
  43. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/lung.csv +0 -0
  44. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/lymph_node.csv +0 -0
  45. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/lymphoma.csv +0 -0
  46. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/merrell1955.csv +0 -0
  47. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/mice.csv +0 -0
  48. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/multicenter_aids_cohort.tsv +0 -0
  49. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/nh4.csv +0 -0
  50. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/panel_test.csv +0 -0
  51. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/psychiatric_patients.csv +0 -0
  52. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/recur.csv +0 -0
  53. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/regression.csv +0 -0
  54. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/rossi.csv +0 -0
  55. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/stanford_heart.csv +0 -0
  56. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/static_test.csv +0 -0
  57. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/datasets/waltons_dataset.csv +0 -0
  58. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/aalen_additive_fitter.py +0 -0
  59. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/aalen_johansen_fitter.py +0 -0
  60. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/cox_time_varying_fitter.py +0 -0
  61. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/crc_spline_fitter.py +0 -0
  62. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/exponential_fitter.py +0 -0
  63. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/generalized_gamma_regression_fitter.py +0 -0
  64. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/log_logistic_aft_fitter.py +0 -0
  65. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/log_logistic_fitter.py +0 -0
  66. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/log_normal_aft_fitter.py +0 -0
  67. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/log_normal_fitter.py +0 -0
  68. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/mixture_cure_fitter.py +0 -0
  69. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/npmle.py +0 -0
  70. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/piecewise_exponential_fitter.py +0 -0
  71. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/piecewise_exponential_regression_fitter.py +0 -0
  72. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/spline_fitter.py +0 -0
  73. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/weibull_aft_fitter.py +0 -0
  74. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/fitters/weibull_fitter.py +0 -0
  75. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/generate_datasets.py +0 -0
  76. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/plotting.py +0 -0
  77. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/statistics.py +0 -0
  78. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/tests/__init__.py +0 -0
  79. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/tests/test_generate_datasets.py +0 -0
  80. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/tests/test_npmle.py +0 -0
  81. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/tests/test_plotting.py +0 -0
  82. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/tests/test_statistics.py +0 -0
  83. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/tests/utils/test_btree.py +0 -0
  84. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/tests/utils/test_concordance.py +0 -0
  85. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/utils/btree.py +0 -0
  86. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/utils/concordance.py +0 -0
  87. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/utils/lowess.py +0 -0
  88. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/utils/printer.py +0 -0
  89. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines/utils/safe_exp.py +0 -0
  90. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines.egg-info/dependency_links.txt +0 -0
  91. {lifelines-0.27.8 → lifelines-0.28.0}/lifelines.egg-info/top_level.txt +0 -0
  92. {lifelines-0.27.8 → lifelines-0.28.0}/reqs/dev-requirements.txt +0 -0
  93. {lifelines-0.27.8 → lifelines-0.28.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lifelines
3
- Version: 0.27.8
3
+ Version: 0.28.0
4
4
  Summary: Survival analysis in Python, including Kaplan Meier, Nelson Aalen and regression
5
5
  Home-page: https://github.com/CamDavidsonPilon/lifelines
6
6
  Author: Cameron Davidson-Pilon
@@ -9,15 +9,20 @@ License: MIT
9
9
  Classifier: Development Status :: 4 - Beta
10
10
  Classifier: License :: OSI Approved :: MIT License
11
11
  Classifier: Programming Language :: Python
12
- Classifier: Programming Language :: Python :: 3.7
13
- Classifier: Programming Language :: Python :: 3.8
14
12
  Classifier: Programming Language :: Python :: 3.9
15
13
  Classifier: Programming Language :: Python :: 3.10
16
14
  Classifier: Programming Language :: Python :: 3.11
17
15
  Classifier: Topic :: Scientific/Engineering
18
- Requires-Python: >=3.7
16
+ Requires-Python: >=3.9
19
17
  Description-Content-Type: text/markdown
20
18
  License-File: LICENSE
19
+ Requires-Dist: numpy<2.0,>=1.14.0
20
+ Requires-Dist: scipy>=1.2.0
21
+ Requires-Dist: pandas>=1.2.0
22
+ Requires-Dist: matplotlib>=3.0
23
+ Requires-Dist: autograd>=1.5
24
+ Requires-Dist: autograd-gamma>=0.3
25
+ Requires-Dist: formulaic>=0.2.2
21
26
 
22
27
  ![](http://i.imgur.com/EOowdSD.png)
23
28
 
@@ -5,6 +5,10 @@ class StatError(Exception):
5
5
  pass
6
6
 
7
7
 
8
+ class ProportionalHazardAssumptionError(Exception):
9
+ pass
10
+
11
+
8
12
  class ConvergenceError(ValueError):
9
13
  # inherits from ValueError for backwards compatibility reasons
10
14
  def __init__(self, msg, original_exception=""):
@@ -550,7 +550,7 @@ class ParametricUnivariateFitter(UnivariateFitter):
550
550
  minimizing_results, previous_results, minimizing_ll = None, None, np.inf
551
551
  for method, option in zip(
552
552
  ["Nelder-Mead", self._scipy_fit_method],
553
- [{"maxiter": 100}, {**{"disp": show_progress}, **self._scipy_fit_options, **fit_options}],
553
+ [{"maxiter": 400}, {**{"disp": show_progress}, **self._scipy_fit_options, **fit_options}],
554
554
  ):
555
555
 
556
556
  initial_value = self._initial_values if previous_results is None else utils._to_1d_array(previous_results.x)
@@ -1409,7 +1409,7 @@ class ParametricRegressionFitter(RegressionFitter):
1409
1409
  def _survival_function(self, params, T, Xs):
1410
1410
  return anp.clip(anp.exp(-self._cumulative_hazard(params, T, Xs)), 1e-12, 1 - 1e-12)
1411
1411
 
1412
- def _log_likelihood_right_censoring(self, params, Ts, E, W, entries, Xs) -> float:
1412
+ def _log_likelihood_right_censoring(self, params, Ts: tuple, E, W, entries, Xs) -> float:
1413
1413
 
1414
1414
  T = Ts[0]
1415
1415
  non_zero_entries = entries > 0
@@ -3365,6 +3365,8 @@ class ParametericAFTRegressionFitter(ParametricRegressionFitter):
3365
3365
  also display the baseline survival, defined as the survival at the mean of the original dataset.
3366
3366
  times: iterable
3367
3367
  pass in a times to plot
3368
+ y: str
3369
+ one of "survival_function", "hazard", "cumulative_hazard". Default "survival_function"
3368
3370
  kwargs:
3369
3371
  pass in additional plotting commands
3370
3372
 
@@ -72,7 +72,14 @@ class BreslowFlemingHarringtonFitter(NonParametricUnivariateFitter):
72
72
  alpha = coalesce(alpha, self.alpha)
73
73
 
74
74
  naf = NelsonAalenFitter(alpha=alpha)
75
- naf.fit(durations, event_observed=event_observed, timeline=timeline, label=self._label, entry=entry, ci_labels=ci_labels)
75
+ naf.fit(
76
+ durations,
77
+ event_observed=event_observed,
78
+ timeline=timeline,
79
+ label=self._label,
80
+ entry=entry,
81
+ ci_labels=ci_labels,
82
+ )
76
83
  self.durations, self.event_observed, self.timeline, self.entry, self.event_table, self.weights = (
77
84
  naf.durations,
78
85
  naf.event_observed,
@@ -87,6 +94,7 @@ class BreslowFlemingHarringtonFitter(NonParametricUnivariateFitter):
87
94
  self.confidence_interval_ = np.exp(-naf.confidence_interval_)
88
95
  self.confidence_interval_survival_function_ = self.confidence_interval_
89
96
  self.confidence_interval_cumulative_density = 1 - self.confidence_interval_
97
+ self.confidence_interval_cumulative_density[:] = np.fliplr(self.confidence_interval_cumulative_density.values)
90
98
 
91
99
  # estimation methods
92
100
  self._estimation_method = "survival_function_"
@@ -80,7 +80,7 @@ class CoxPHFitter(RegressionFitter, ProportionalHazardMixin):
80
80
  When ``baseline_estimation_method="spline"``, this allows customizing the points in the time axis for the baseline hazard curve.
81
81
  To use evenly-spaced points in time, the ``n_baseline_knots`` parameter can be employed instead.
82
82
 
83
- breakpoints: int
83
+ breakpoints: list, optional
84
84
  Used when ``baseline_estimation_method="piecewise"``. Set the positions of the baseline hazard breakpoints.
85
85
 
86
86
  Examples
@@ -105,6 +105,7 @@ class GeneralizedGammaFitter(KnownModelParametricUnivariateFitter):
105
105
  """
106
106
 
107
107
  _scipy_fit_method = "SLSQP"
108
+ _scipy_fit_options = {"maxiter": 10_000, "maxfev": 10_000}
108
109
  _fitted_parameter_names = ["mu_", "ln_sigma_", "lambda_"]
109
110
  _bounds = [(None, None), (None, None), (None, None)]
110
111
  _compare_to_values = np.array([0.0, 0.0, 1.0])
@@ -117,14 +118,14 @@ class GeneralizedGammaFitter(KnownModelParametricUnivariateFitter):
117
118
  elif CensoringType.is_interval_censoring(self):
118
119
  # this fails if Ts[1] == Ts[0], so we add a some fudge factors.
119
120
  log_data = log(Ts[1] - Ts[0] + 0.1)
120
- return np.array([log_data.mean(), log(log_data.std() + 0.01), 0.1])
121
+ return np.array([log_data.mean() * 1.5, log(log_data.std() + 0.1), 1.0])
121
122
 
122
123
  def _cumulative_hazard(self, params, times):
123
124
  mu_, ln_sigma_, lambda_ = params
124
125
 
125
126
  sigma_ = safe_exp(ln_sigma_)
126
127
  Z = (log(times) - mu_) / sigma_
127
- ilambda_2 = 1 / lambda_ ** 2
128
+ ilambda_2 = 1 / lambda_**2
128
129
  clipped_exp = np.clip(safe_exp(lambda_ * Z) * ilambda_2, 1e-300, 1e20)
129
130
 
130
131
  if lambda_ > 0:
@@ -137,7 +138,7 @@ class GeneralizedGammaFitter(KnownModelParametricUnivariateFitter):
137
138
 
138
139
  def _log_hazard(self, params, times):
139
140
  mu_, ln_sigma_, lambda_ = params
140
- ilambda_2 = 1 / lambda_ ** 2
141
+ ilambda_2 = 1 / lambda_**2
141
142
  Z = (log(times) - mu_) / safe_exp(ln_sigma_)
142
143
  clipped_exp = np.clip(safe_exp(lambda_ * Z) * ilambda_2, 1e-300, 1e20)
143
144
  if lambda_ > 0:
@@ -171,5 +172,5 @@ class GeneralizedGammaFitter(KnownModelParametricUnivariateFitter):
171
172
  sigma_ = exp(self.ln_sigma_)
172
173
 
173
174
  if lambda_ > 0:
174
- return exp(sigma_ * log(gammainccinv(1 / lambda_ ** 2, p) * lambda_ ** 2) / lambda_) * exp(self.mu_)
175
- return exp(sigma_ * log(gammaincinv(1 / lambda_ ** 2, p) * lambda_ ** 2) / lambda_) * exp(self.mu_)
175
+ return exp(sigma_ * log(gammainccinv(1 / lambda_**2, p) * lambda_**2) / lambda_) * exp(self.mu_)
176
+ return exp(sigma_ * log(gammaincinv(1 / lambda_**2, p) * lambda_**2) / lambda_) * exp(self.mu_)
@@ -351,9 +351,14 @@ class KaplanMeierFitter(NonParametricUnivariateFitter):
351
351
  primary_estimate_name = "survival_function_"
352
352
  secondary_estimate_name = "cumulative_density_"
353
353
 
354
- (self.durations, self.event_observed, self.timeline, self.entry, self.event_table, self.weights) = _preprocess_inputs(
355
- durations, event_observed, timeline, entry, weights
356
- )
354
+ (
355
+ self.durations,
356
+ self.event_observed,
357
+ self.timeline,
358
+ self.entry,
359
+ self.event_table,
360
+ self.weights,
361
+ ) = _preprocess_inputs(durations, event_observed, timeline, entry, weights)
357
362
 
358
363
  alpha = alpha if alpha else self.alpha
359
364
  log_estimate, cumulative_sq_ = _additive_estimate(
@@ -386,6 +391,7 @@ class KaplanMeierFitter(NonParametricUnivariateFitter):
386
391
 
387
392
  self.confidence_interval_survival_function_ = self.confidence_interval_
388
393
  self.confidence_interval_cumulative_density_ = 1 - self.confidence_interval_
394
+ self.confidence_interval_cumulative_density_[:] = np.fliplr(self.confidence_interval_cumulative_density_.values)
389
395
  self._median = median_survival_times(self.survival_function_)
390
396
  self._cumulative_sq_ = cumulative_sq_
391
397
 
@@ -4,6 +4,7 @@ from textwrap import dedent, fill
4
4
  from autograd import numpy as anp
5
5
  import numpy as np
6
6
  from pandas import DataFrame, Series
7
+ from lifelines.exceptions import ProportionalHazardAssumptionError
7
8
  from lifelines.statistics import proportional_hazard_test, TimeTransformers
8
9
  from lifelines.utils import format_p_value
9
10
  from lifelines.utils.lowess import lowess
@@ -28,6 +29,7 @@ class ProportionalHazardMixin:
28
29
  p_value_threshold: float = 0.01,
29
30
  plot_n_bootstraps: int = 15,
30
31
  columns: Optional[List[str]] = None,
32
+ raise_on_fail: bool = False,
31
33
  ) -> None:
32
34
  """
33
35
  Use this function to test the proportional hazards assumption. See usage example at
@@ -51,6 +53,8 @@ class ProportionalHazardMixin:
51
53
  the function significantly.
52
54
  columns: list, optional
53
55
  specify a subset of columns to test.
56
+ raise_on_fail: bool, optional
57
+ throw a ``ProportionalHazardAssumptionError`` if the test fails. Default: False.
54
58
 
55
59
  Returns
56
60
  --------
@@ -107,7 +111,7 @@ class ProportionalHazardMixin:
107
111
 
108
112
  for variable in self.params_.index.intersection(columns or self.params_.index):
109
113
  minumum_observed_p_value = test_results.summary.loc[variable, "p"].min()
110
-
114
+
111
115
  # plot is done (regardless of test result) whenever `show_plots = True`
112
116
  if show_plots:
113
117
  axes.append([])
@@ -224,9 +228,8 @@ class ProportionalHazardMixin:
224
228
  ),
225
229
  end="\n\n",
226
230
  )
227
- #################
231
+ #################
228
232
 
229
-
230
233
  if advice and counter > 0:
231
234
  print(
232
235
  dedent(
@@ -243,6 +246,8 @@ class ProportionalHazardMixin:
243
246
 
244
247
  if counter == 0:
245
248
  print("Proportional hazard assumption looks okay.")
249
+ elif raise_on_fail:
250
+ raise ProportionalHazardAssumptionError()
246
251
  return axes
247
252
 
248
253
  @property
@@ -183,7 +183,7 @@ class NelsonAalenFitter(UnivariateFitter):
183
183
  )
184
184
 
185
185
  def _variance_f_discrete(self, population, deaths):
186
- return (population - deaths) * deaths / population ** 3
186
+ return (1 - deaths / population) * (deaths / population) * (1.0 / population)
187
187
 
188
188
  def _additive_f_smooth(self, population, deaths):
189
189
  cum_ = np.cumsum(1.0 / np.arange(1, np.max(population) + 1))
@@ -239,7 +239,7 @@ class NelsonAalenFitter(UnivariateFitter):
239
239
  C = var_hazard_.values != 0.0 # only consider the points with jumps
240
240
  std_hazard_ = np.sqrt(
241
241
  1.0
242
- / (bandwidth ** 2)
242
+ / (bandwidth**2)
243
243
  * np.dot(epanechnikov_kernel(timeline[:, None], timeline[C][None, :], bandwidth) ** 2, var_hazard_.values[C])
244
244
  )
245
245
  values = {
@@ -34,7 +34,14 @@ from lifelines.utils import (
34
34
  qth_survival_time,
35
35
  )
36
36
 
37
- from lifelines.exceptions import StatisticalWarning, ApproximationWarning, StatError, ConvergenceWarning, ConvergenceError
37
+ from lifelines.exceptions import (
38
+ ProportionalHazardAssumptionError,
39
+ StatisticalWarning,
40
+ ApproximationWarning,
41
+ StatError,
42
+ ConvergenceWarning,
43
+ ConvergenceError,
44
+ )
38
45
  from lifelines.fitters import BaseFitter, ParametricUnivariateFitter, ParametricRegressionFitter, RegressionFitter
39
46
  from lifelines.fitters.coxph_fitter import SemiParametricPHFitter
40
47
 
@@ -475,6 +482,19 @@ class TestUnivariateFitters:
475
482
  assert not isinstance(fitter.predict(1), Iterable)
476
483
  assert isinstance(fitter.predict([1, 2]), Iterable)
477
484
 
485
+ def test_cumulative_density_ci_is_ordered_correctly(self, positive_sample_lifetimes, univariate_fitters):
486
+ T = positive_sample_lifetimes[0]
487
+ for f in univariate_fitters:
488
+ fitter = f()
489
+ fitter.fit(T)
490
+ if not hasattr(fitter, "confidence_interval_cumulative_density_"):
491
+ continue
492
+ lower, upper = f"{fitter.label}_lower_0.95", f"{fitter.label}_upper_0.95"
493
+ assert np.all(
494
+ (fitter.confidence_interval_cumulative_density_[upper] - fitter.confidence_interval_cumulative_density_[lower])
495
+ >= 0
496
+ )
497
+
478
498
  def test_predict_method_returns_exact_value_if_given_an_observed_time(self):
479
499
  T = [1, 2, 3]
480
500
  kmf = KaplanMeierFitter()
@@ -574,9 +594,9 @@ class TestUnivariateFitters:
574
594
  assert_frame_equal(with_list, with_array)
575
595
  assert_frame_equal(with_tuple, with_array)
576
596
 
577
- with_array = fitter.fit_left_censoring(T, C).survival_function_
578
- with_list = fitter.fit_left_censoring(list(T), list(C)).survival_function_
579
- with_tuple = fitter.fit_left_censoring(tuple(T), tuple(C)).survival_function_
597
+ with_array = fitter.fit_left_censoring(T).survival_function_
598
+ with_list = fitter.fit_left_censoring(list(T)).survival_function_
599
+ with_tuple = fitter.fit_left_censoring(tuple(T)).survival_function_
580
600
  assert_frame_equal(with_list, with_array)
581
601
  assert_frame_equal(with_tuple, with_array)
582
602
 
@@ -1607,6 +1627,13 @@ class TestNelsonAalenFitter:
1607
1627
 
1608
1628
  assert_frame_equal(naf_w_weights.cumulative_hazard_, naf_no_weights.cumulative_hazard_)
1609
1629
 
1630
+ def test_variance_calculation_does_not_overflow(self):
1631
+
1632
+ y = np.random.randint(1, 1000, 100000000)
1633
+ naf = NelsonAalenFitter(nelson_aalen_smoothing=False)
1634
+ naf.fit(y, event_observed=None, timeline=range(0, int(y.max())))
1635
+ assert (naf._cumulative_sq >= 0).all()
1636
+
1610
1637
 
1611
1638
  class TestBreslowFlemingHarringtonFitter:
1612
1639
  def test_BHF_fit_when_KMF_throws_an_error(self):
@@ -2893,6 +2920,38 @@ class TestCoxPHFitter_SemiParametric:
2893
2920
 
2894
2921
  assert np.abs(newton(X, T, E, W, entries)[0] - -0.0335) < 0.0001
2895
2922
 
2923
+ def test_baseline_prediction_with_extreme_means(self, rossi):
2924
+ cph = CoxPHFitter()
2925
+ cph.fit(rossi, "week", "arrest")
2926
+
2927
+ rossi_shifted = rossi.copy()
2928
+ rossi_shifted["prio"] += 100
2929
+ cph_shifted = CoxPHFitter()
2930
+ cph_shifted.fit(rossi_shifted, "week", "arrest")
2931
+
2932
+ # make sure summary stats are equal
2933
+ assert_frame_equal(cph_shifted.summary, cph.summary)
2934
+
2935
+ # confirm hazards are equal
2936
+ assert_frame_equal(cph.baseline_hazard_, cph_shifted.baseline_hazard_)
2937
+ assert_frame_equal(cph.baseline_cumulative_hazard_, cph_shifted.baseline_cumulative_hazard_)
2938
+
2939
+ def test_baseline_prediction_with_extreme_scaling(self, rossi):
2940
+ cph = CoxPHFitter()
2941
+ cph.fit(rossi, "week", "arrest")
2942
+
2943
+ rossi_scaled = rossi.copy()
2944
+ rossi_scaled["prio"] *= 100
2945
+ cph_scaled = CoxPHFitter()
2946
+ cph_scaled.fit(rossi_scaled, "week", "arrest")
2947
+
2948
+ # make sure summary stats are equal - note that CI and coefs are unequal since we scaled params.
2949
+ assert_frame_equal(cph_scaled.summary[["z", "p"]], cph.summary[["z", "p"]])
2950
+
2951
+ # confirm hazards are equal
2952
+ assert_frame_equal(cph.baseline_hazard_, cph_scaled.baseline_hazard_)
2953
+ assert_frame_equal(cph.baseline_cumulative_hazard_, cph_scaled.baseline_cumulative_hazard_)
2954
+
2896
2955
 
2897
2956
  class TestCoxPHFitterPeices:
2898
2957
  @pytest.fixture
@@ -3027,7 +3086,7 @@ class TestCoxPHFitter:
3027
3086
  def test_formula_can_accept_numpy_functions(self, cph, rossi):
3028
3087
  cph.fit(rossi, "week", "arrest", formula="fin + log10(prio+1) + np.sqrt(age)")
3029
3088
  assert "fin" in cph.summary.index
3030
- assert "log10(prio+1)" in cph.summary.index
3089
+ assert "log10(prio + 1)" in cph.summary.index
3031
3090
  assert "np.sqrt(age)" in cph.summary.index
3032
3091
 
3033
3092
  @pytest.mark.xfail
@@ -3119,9 +3178,14 @@ class TestCoxPHFitter:
3119
3178
 
3120
3179
  def test_formulas_handles_categories_at_inference(self, cph):
3121
3180
  # Create a dummy dataset with some one continuous and one categorical features
3122
- df = pd.DataFrame({
3123
- 'time': [1, 2, 3, 1, 2, 3], 'event': [0, 1, 1, 1, 0, 0],
3124
- 'cov_cont':[0.1, 0.2, 0.3, 0.1, 0.2, 0.3], 'cov_categ': ['A', 'A', 'B', 'B', 'C', 'C']})
3181
+ df = pd.DataFrame(
3182
+ {
3183
+ "time": [1, 2, 3, 1, 2, 3],
3184
+ "event": [0, 1, 1, 1, 0, 0],
3185
+ "cov_cont": [0.1, 0.2, 0.3, 0.1, 0.2, 0.3],
3186
+ "cov_categ": ["A", "A", "B", "B", "C", "C"],
3187
+ }
3188
+ )
3125
3189
  cph.fit(df, "time", "event", formula="cov_cont + C(cov_categ)")
3126
3190
  cph.predict_survival_function(df.iloc[:4])
3127
3191
 
@@ -3402,6 +3466,11 @@ class TestCoxPHFitter:
3402
3466
  cph.fit(rossi, "week", "arrest")
3403
3467
  cph.check_assumptions(rossi)
3404
3468
 
3469
+ def test_check_assumptions_thows_if_raise_on_fail_enalbed(self, cph, rossi):
3470
+ cph.fit(rossi, "week", "arrest")
3471
+ with pytest.raises(ProportionalHazardAssumptionError):
3472
+ cph.check_assumptions(rossi, p_value_threshold=0.05, raise_on_fail=True)
3473
+
3405
3474
  def test_check_assumptions_for_subset_of_columns(self, cph, rossi):
3406
3475
  cph.fit(rossi, "week", "arrest")
3407
3476
  cph.check_assumptions(rossi, columns=["age"])
@@ -5688,6 +5757,4 @@ class TestMixtureCureFitter:
5688
5757
  T, E = load_kidney_transplant()["time"], load_kidney_transplant()["death"]
5689
5758
  wmc.fit(T, E)
5690
5759
  mcfitter.fit(T, E)
5691
- print(wmc.summary)
5692
- print(mcfitter.summary)
5693
5760
  assert_frame_equal(wmc.summary.reset_index(drop=True), mcfitter.summary.reset_index(drop=True), rtol=0.25)
@@ -15,7 +15,6 @@ from lifelines import CoxPHFitter, WeibullAFTFitter, KaplanMeierFitter, Exponent
15
15
  from lifelines.datasets import load_regression_dataset, load_larynx, load_waltons, load_rossi
16
16
  from lifelines import utils
17
17
  from lifelines import exceptions
18
- from lifelines.utils.sklearn_adapter import sklearn_adapter
19
18
  from lifelines.utils.safe_exp import safe_exp
20
19
 
21
20
 
@@ -303,6 +302,13 @@ def test_survival_table_from_events_binned_with_empty_bin():
303
302
  assert not pd.isnull(event_table).any().any()
304
303
 
305
304
 
305
+ def test_survival_table_from_events_with_future_bins():
306
+ df = load_waltons()
307
+ event_table = utils.survival_table_from_events(df["T"], df["E"], collapse=True, intervals=np.arange(10, 100).tolist())
308
+ assert not pd.isnull(event_table).any().any()
309
+ assert event_table.iloc[-1].sum() == 0
310
+
311
+
306
312
  def test_survival_table_from_events_at_risk_column():
307
313
  df = load_waltons()
308
314
  # from R
@@ -885,122 +891,6 @@ class TestStepSizer:
885
891
  assert ss.next() < start
886
892
 
887
893
 
888
- class TestSklearnAdapter:
889
- @pytest.fixture
890
- def X(self):
891
- return load_regression_dataset().drop("T", axis=1)
892
-
893
- @pytest.fixture
894
- def Y(self):
895
- return load_regression_dataset().pop("T")
896
-
897
- def test_model_has_correct_api(self, X, Y):
898
- base_model = sklearn_adapter(CoxPHFitter, event_col="E")
899
- cph = base_model()
900
- assert hasattr(cph, "fit")
901
- cph.fit(X, Y)
902
- assert hasattr(cph, "predict")
903
- cph.predict(X)
904
- assert hasattr(cph, "score")
905
- cph.score(X, Y)
906
-
907
- def test_sklearn_cross_val_score_accept_model(self, X, Y):
908
- from sklearn.model_selection import cross_val_score
909
- from sklearn.model_selection import GridSearchCV
910
-
911
- base_model = sklearn_adapter(WeibullAFTFitter, event_col="E")
912
- wf = base_model(penalizer=1.0)
913
- assert len(cross_val_score(wf, X, Y, cv=3)) == 3
914
-
915
- def test_sklearn_GridSearchCV_accept_model(self, X, Y):
916
- from sklearn.model_selection import cross_val_score
917
- from sklearn.model_selection import GridSearchCV
918
-
919
- base_model = sklearn_adapter(WeibullAFTFitter, event_col="E")
920
-
921
- grid_params = {"penalizer": 10.0 ** np.arange(-2, 3), "model_ancillary": [True, False]}
922
- clf = GridSearchCV(base_model(), grid_params, cv=4)
923
- clf.fit(X, Y)
924
-
925
- assert clf.best_params_ == {"model_ancillary": True, "penalizer": 100.0}
926
- assert clf.predict(X).shape[0] == X.shape[0]
927
-
928
- def test_model_can_accept_things_like_strata(self, X, Y):
929
- X["strata"] = np.random.randint(0, 2, size=X.shape[0])
930
- base_model = sklearn_adapter(CoxPHFitter, event_col="E")
931
- cph = base_model(strata="strata")
932
- cph.fit(X, Y)
933
-
934
- def test_we_can_user_other_prediction_methods(self, X, Y):
935
-
936
- base_model = sklearn_adapter(WeibullAFTFitter, event_col="E", predict_method="predict_median")
937
- wf = base_model(strata="strata")
938
- wf.fit(X, Y)
939
- assert wf.predict(X).shape[0] == X.shape[0]
940
-
941
- def test_dill(self, X, Y):
942
- import dill
943
-
944
- base_model = sklearn_adapter(CoxPHFitter, event_col="E")
945
- cph = base_model()
946
- cph.fit(X, Y)
947
-
948
- s = dill.dumps(cph)
949
- s = dill.loads(s)
950
- assert cph.predict(X).shape[0] == X.shape[0]
951
-
952
- def test_pickle(self, X, Y):
953
- import pickle
954
-
955
- base_model = sklearn_adapter(CoxPHFitter, event_col="E")
956
- cph = base_model()
957
- cph.fit(X, Y)
958
-
959
- s = pickle.dumps(cph, protocol=-1)
960
- s = pickle.loads(s)
961
- assert cph.predict(X).shape[0] == X.shape[0]
962
-
963
- def test_isinstance(self):
964
- from sklearn.base import BaseEstimator, RegressorMixin, MetaEstimatorMixin, MultiOutputMixin
965
-
966
- base_model = sklearn_adapter(CoxPHFitter, event_col="E")
967
- assert isinstance(base_model(), BaseEstimator)
968
- assert isinstance(base_model(), RegressorMixin)
969
- assert isinstance(base_model(), MetaEstimatorMixin)
970
-
971
- @pytest.mark.xfail
972
- def test_sklearn_GridSearchCV_accept_model_with_parallelization(self, X, Y):
973
- from sklearn.model_selection import cross_val_score
974
- from sklearn.model_selection import GridSearchCV
975
-
976
- base_model = sklearn_adapter(WeibullAFTFitter, event_col="E")
977
-
978
- grid_params = {"penalizer": 10.0 ** np.arange(-2, 3), "l1_ratio": [0.05, 0.5, 0.95], "model_ancillary": [True, False]}
979
- # note the n_jobs
980
- clf = GridSearchCV(base_model(), grid_params, cv=4, n_jobs=-1)
981
- clf.fit(X, Y)
982
-
983
- assert clf.best_params_ == {"l1_ratio": 0.5, "model_ancillary": False, "penalizer": 0.01}
984
- assert clf.predict(X).shape[0] == X.shape[0]
985
-
986
- def test_joblib(self, X, Y):
987
- from joblib import dump, load
988
-
989
- base_model = sklearn_adapter(WeibullAFTFitter, event_col="E")
990
-
991
- clf = base_model()
992
- clf.fit(X, Y)
993
- dump(clf, "filename.joblib")
994
- clf = load("filename.joblib")
995
-
996
- @pytest.mark.xfail
997
- def test_sklearn_check(self):
998
- from sklearn.utils.estimator_checks import check_estimator
999
-
1000
- base_model = sklearn_adapter(WeibullAFTFitter, event_col="E")
1001
- check_estimator(base_model())
1002
-
1003
-
1004
894
  def test_rmst_works_at_kaplan_meier_edge_case():
1005
895
 
1006
896
  T = [1, 2, 3, 4, 10]
@@ -1025,7 +915,14 @@ def test_rmst_works_at_kaplan_meier_with_left_censoring():
1025
915
  assert abs(results[1] - 0) < 0.0001
1026
916
 
1027
917
 
1028
- def test_rmst_exactely_with_known_solution():
918
+ def test_rmst_works_with_return_variance():
919
+ # issue 1578
920
+ T = [1, 2, 3, 4, 10]
921
+ kmf = KaplanMeierFitter().fit(T)
922
+ result = utils.restricted_mean_survival_time(kmf.survival_function_, t=10, return_variance=True)
923
+
924
+
925
+ def test_rmst_exactly_with_known_solution():
1029
926
  T = np.random.exponential(2, 100)
1030
927
  exp = ExponentialFitter().fit(T)
1031
928
  lambda_ = exp.lambda_
@@ -311,7 +311,7 @@ def _expected_value_of_survival_squared_up_to_t(
311
311
 
312
312
  if isinstance(model_or_survival_function, pd.DataFrame):
313
313
  sf = model_or_survival_function.loc[:t]
314
- sf = sf.append(pd.DataFrame([1], index=[0], columns=sf.columns)).sort_index()
314
+ sf = pd.concat((sf, pd.DataFrame([1], index=[0], columns=sf.columns))).sort_index()
315
315
  sf_tau = sf * sf.index.values[:, None]
316
316
  return 2 * trapz(y=sf_tau.values[:, 0], x=sf_tau.index)
317
317
  elif isinstance(model_or_survival_function, lifelines.fitters.UnivariateFitter):
@@ -561,7 +561,7 @@ def _group_event_table_by_intervals(event_table, intervals) -> pd.DataFrame:
561
561
  )
562
562
  # convert columns from multiindex
563
563
  event_table.columns = event_table.columns.droplevel(1)
564
- return event_table.bfill()
564
+ return event_table.bfill().fillna(0)
565
565
 
566
566
 
567
567
  def survival_events_from_table(survival_table, observed_deaths_col="observed", censored_col="censored"):
@@ -744,9 +744,6 @@ def k_fold_cross_validation(
744
744
  results: list
745
745
  (k,1) list of scores for each fold. The scores can be anything.
746
746
 
747
- See Also
748
- ---------
749
- lifelines.utils.sklearn_adapter.sklearn_adapter
750
747
 
751
748
  """
752
749
  # Make sure fitters is a list
@@ -884,6 +881,7 @@ def _additive_estimate(events, timeline, _additive_f, _additive_var, reverse):
884
881
  population = events["at_risk"] - entrances
885
882
 
886
883
  estimate_ = np.cumsum(_additive_f(population, deaths))
884
+
887
885
  var_ = np.cumsum(_additive_var(population, deaths))
888
886
 
889
887
  timeline = sorted(timeline)
@@ -1,4 +1,4 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  from __future__ import unicode_literals
3
3
 
4
- __version__ = "0.27.8"
4
+ __version__ = "0.28.0"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: lifelines
3
- Version: 0.27.8
3
+ Version: 0.28.0
4
4
  Summary: Survival analysis in Python, including Kaplan Meier, Nelson Aalen and regression
5
5
  Home-page: https://github.com/CamDavidsonPilon/lifelines
6
6
  Author: Cameron Davidson-Pilon
@@ -9,15 +9,20 @@ License: MIT
9
9
  Classifier: Development Status :: 4 - Beta
10
10
  Classifier: License :: OSI Approved :: MIT License
11
11
  Classifier: Programming Language :: Python
12
- Classifier: Programming Language :: Python :: 3.7
13
- Classifier: Programming Language :: Python :: 3.8
14
12
  Classifier: Programming Language :: Python :: 3.9
15
13
  Classifier: Programming Language :: Python :: 3.10
16
14
  Classifier: Programming Language :: Python :: 3.11
17
15
  Classifier: Topic :: Scientific/Engineering
18
- Requires-Python: >=3.7
16
+ Requires-Python: >=3.9
19
17
  Description-Content-Type: text/markdown
20
18
  License-File: LICENSE
19
+ Requires-Dist: numpy<2.0,>=1.14.0
20
+ Requires-Dist: scipy>=1.2.0
21
+ Requires-Dist: pandas>=1.2.0
22
+ Requires-Dist: matplotlib>=3.0
23
+ Requires-Dist: autograd>=1.5
24
+ Requires-Dist: autograd-gamma>=0.3
25
+ Requires-Dist: formulaic>=0.2.2
21
26
 
22
27
  ![](http://i.imgur.com/EOowdSD.png)
23
28
 
@@ -83,8 +83,6 @@ lifelines/utils/concordance.py
83
83
  lifelines/utils/lowess.py
84
84
  lifelines/utils/printer.py
85
85
  lifelines/utils/safe_exp.py
86
- lifelines/utils/sklearn_adapter.py
87
86
  reqs/base-requirements.txt
88
87
  reqs/dev-requirements.txt
89
- reqs/docs-requirements.txt
90
- reqs/travis-requirements.txt
88
+ reqs/docs-requirements.txt
@@ -1,6 +1,6 @@
1
1
  numpy<2.0,>=1.14.0
2
2
  scipy>=1.2.0
3
- pandas>=1.0.0
3
+ pandas>=1.2.0
4
4
  matplotlib>=3.0
5
5
  autograd>=1.5
6
6
  autograd-gamma>=0.3
@@ -1,6 +1,6 @@
1
1
  numpy>=1.14.0,<2.0
2
2
  scipy>=1.2.0
3
- pandas>=1.0.0
3
+ pandas>=1.2.0
4
4
  matplotlib>=3.0
5
5
  autograd>=1.5
6
6
  autograd-gamma>=0.3
@@ -0,0 +1,7 @@
1
+ -r dev-requirements.txt
2
+ sphinx==7.2.6
3
+ sphinx_rtd_theme==2.0.0
4
+ nbsphinx==0.9.3
5
+ jupyter_client==8.6.0
6
+ nbconvert>=6.5.1
7
+ ipykernel==6.28.0
@@ -17,8 +17,6 @@ CLASSIFIERS = [
17
17
  "Development Status :: 4 - Beta",
18
18
  "License :: OSI Approved :: MIT License",
19
19
  "Programming Language :: Python",
20
- "Programming Language :: Python :: 3.7",
21
- "Programming Language :: Python :: 3.8",
22
20
  "Programming Language :: Python :: 3.9",
23
21
  "Programming Language :: Python :: 3.10",
24
22
  "Programming Language :: Python :: 3.11",
@@ -28,7 +26,7 @@ LICENSE = "MIT"
28
26
  PACKAGE_DATA = {"lifelines": ["datasets/*"]}
29
27
  DESCRIPTION = "Survival analysis in Python, including Kaplan Meier, Nelson Aalen and regression"
30
28
  URL = "https://github.com/CamDavidsonPilon/lifelines"
31
- PYTHON_REQ = ">=3.7"
29
+ PYTHON_REQ = ">=3.9"
32
30
 
33
31
  setup(
34
32
  name=NAME,
@@ -1,135 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- import inspect
3
- import pandas as pd
4
-
5
- try:
6
- from sklearn.base import BaseEstimator, RegressorMixin, MetaEstimatorMixin
7
- except ImportError:
8
- raise ImportError("scikit-learn must be installed on the local system to use this utility class.")
9
- from . import concordance_index
10
-
11
- __all__ = ["sklearn_adapter"]
12
-
13
-
14
- def filter_kwargs(f, kwargs):
15
- s = inspect.signature(f)
16
- res = {k: kwargs[k] for k in s.parameters if k in kwargs}
17
- return res
18
-
19
-
20
- class _SklearnModel(BaseEstimator, MetaEstimatorMixin, RegressorMixin):
21
- def __init__(self, **kwargs):
22
- self._params = kwargs
23
- self.lifelines_model = self.lifelines_model(**filter_kwargs(self.lifelines_model.__init__, self._params))
24
- self._params["duration_col"] = "duration_col"
25
- self._params["event_col"] = self._event_col
26
-
27
- @property
28
- def _yColumn(self):
29
- return self._params["duration_col"]
30
-
31
- @property
32
- def _eventColumn(self):
33
- return self._params["event_col"]
34
-
35
- def fit(self, X, y=None):
36
- """
37
-
38
- Parameters
39
- -----------
40
-
41
- X: DataFrame
42
- must be a pandas DataFrame (with event_col included, if applicable)
43
-
44
- """
45
- if not isinstance(X, pd.DataFrame):
46
- raise ValueError("X must be a DataFrame. Got type: {}".format(type(X)))
47
-
48
- X = X.copy()
49
-
50
- if y is not None:
51
- X.insert(len(X.columns), self._yColumn, y, allow_duplicates=False)
52
-
53
- fit = getattr(self.lifelines_model, self._fit_method)
54
- self.lifelines_model = fit(df=X, **filter_kwargs(fit, self._params))
55
- return self
56
-
57
- def set_params(self, **params):
58
- for key, value in params.items():
59
- setattr(self.lifelines_model, key, value)
60
- return self
61
-
62
- def get_params(self, deep=True):
63
- out = {}
64
- for name, p in inspect.signature(self.lifelines_model.__init__).parameters.items():
65
- if p.kind < 4: # ignore kwargs
66
- out[name] = getattr(self.lifelines_model, name)
67
- return out
68
-
69
- def predict(self, X, **kwargs):
70
- """
71
- Parameters
72
- ------------
73
- X: DataFrame or numpy array
74
-
75
- """
76
- predictions = getattr(self.lifelines_model, self._predict_method)(X, **kwargs).squeeze().values
77
- return predictions
78
-
79
- def score(self, X, y, **kwargs):
80
- """
81
-
82
- Parameters
83
- -----------
84
-
85
- X: DataFrame
86
- must be a pandas DataFrame (with event_col included, if applicable)
87
-
88
- """
89
- rest_columns = list(set(X.columns) - {self._yColumn, self._eventColumn})
90
-
91
- x = X.loc[:, rest_columns]
92
- e = X.loc[:, self._eventColumn] if self._eventColumn else None
93
-
94
- if y is None:
95
- y = X.loc[:, self._yColumn]
96
-
97
- if callable(self._scoring_method):
98
- res = self._scoring_method(y, self.predict(x, **kwargs), event_observed=e)
99
- else:
100
- raise ValueError()
101
- return res
102
-
103
-
104
- def sklearn_adapter(fitter, event_col=None, predict_method="predict_expectation", scoring_method=concordance_index):
105
- """
106
- This function wraps lifelines models into a scikit-learn compatible API. The function returns a
107
- class that can be instantiated with parameters (similar to a scikit-learn class).
108
-
109
- Parameters
110
- ----------
111
-
112
- fitter: class
113
- The class (not an instance) to be wrapper. Example: ``CoxPHFitter`` or ``WeibullAFTFitter``
114
- event_col: string
115
- The column in your DataFrame that represents (if applicable) the event column
116
- predict_method: string
117
- Can be the string ``"predict_median", "predict_expectation"``
118
- scoring_method: function
119
- Provide a way to produce a ``score`` on the scikit-learn model. Signature should look like (durations, predictions, event_observed=None)
120
-
121
- """
122
- name = "SkLearn" + fitter.__name__
123
- klass = type(
124
- name,
125
- (_SklearnModel,),
126
- {
127
- "lifelines_model": fitter,
128
- "_event_col": event_col,
129
- "_predict_method": predict_method,
130
- "_fit_method": "fit",
131
- "_scoring_method": staticmethod(scoring_method),
132
- },
133
- )
134
- globals()[klass.__name__] = klass
135
- return klass
@@ -1,7 +0,0 @@
1
- -r dev-requirements.txt
2
- sphinx
3
- sphinx_rtd_theme
4
- nbsphinx
5
- jupyter_client
6
- nbconvert>=6.5.1
7
- ipykernel
@@ -1,5 +0,0 @@
1
- python-coveralls
2
- seaborn
3
- pytest-travis-fold
4
- dill
5
- -r dev-requirements.txt
File without changes
File without changes
File without changes
File without changes