pySEQTarget 0.10.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. pySEQTarget/SEQopts.py +197 -0
  2. pySEQTarget/SEQoutput.py +163 -0
  3. pySEQTarget/SEQuential.py +375 -0
  4. pySEQTarget/__init__.py +5 -0
  5. pySEQTarget/analysis/__init__.py +8 -0
  6. pySEQTarget/analysis/_hazard.py +211 -0
  7. pySEQTarget/analysis/_outcome_fit.py +75 -0
  8. pySEQTarget/analysis/_risk_estimates.py +136 -0
  9. pySEQTarget/analysis/_subgroup_fit.py +30 -0
  10. pySEQTarget/analysis/_survival_pred.py +372 -0
  11. pySEQTarget/data/__init__.py +19 -0
  12. pySEQTarget/error/__init__.py +2 -0
  13. pySEQTarget/error/_datachecker.py +38 -0
  14. pySEQTarget/error/_param_checker.py +50 -0
  15. pySEQTarget/expansion/__init__.py +5 -0
  16. pySEQTarget/expansion/_binder.py +98 -0
  17. pySEQTarget/expansion/_diagnostics.py +53 -0
  18. pySEQTarget/expansion/_dynamic.py +73 -0
  19. pySEQTarget/expansion/_mapper.py +44 -0
  20. pySEQTarget/expansion/_selection.py +31 -0
  21. pySEQTarget/helpers/__init__.py +8 -0
  22. pySEQTarget/helpers/_bootstrap.py +111 -0
  23. pySEQTarget/helpers/_col_string.py +6 -0
  24. pySEQTarget/helpers/_format_time.py +6 -0
  25. pySEQTarget/helpers/_output_files.py +167 -0
  26. pySEQTarget/helpers/_pad.py +7 -0
  27. pySEQTarget/helpers/_predict_model.py +9 -0
  28. pySEQTarget/helpers/_prepare_data.py +19 -0
  29. pySEQTarget/initialization/__init__.py +5 -0
  30. pySEQTarget/initialization/_censoring.py +53 -0
  31. pySEQTarget/initialization/_denominator.py +39 -0
  32. pySEQTarget/initialization/_numerator.py +37 -0
  33. pySEQTarget/initialization/_outcome.py +56 -0
  34. pySEQTarget/plot/__init__.py +1 -0
  35. pySEQTarget/plot/_survival_plot.py +104 -0
  36. pySEQTarget/weighting/__init__.py +8 -0
  37. pySEQTarget/weighting/_weight_bind.py +86 -0
  38. pySEQTarget/weighting/_weight_data.py +47 -0
  39. pySEQTarget/weighting/_weight_fit.py +99 -0
  40. pySEQTarget/weighting/_weight_pred.py +192 -0
  41. pySEQTarget/weighting/_weight_stats.py +23 -0
  42. pyseqtarget-0.10.0.dist-info/METADATA +98 -0
  43. pyseqtarget-0.10.0.dist-info/RECORD +46 -0
  44. pyseqtarget-0.10.0.dist-info/WHEEL +5 -0
  45. pyseqtarget-0.10.0.dist-info/licenses/LICENSE +21 -0
  46. pyseqtarget-0.10.0.dist-info/top_level.txt +1 -0
pySEQTarget/SEQuential.py
@@ -0,0 +1,375 @@
+ import datetime
+ import time
+ from collections import Counter
+ from dataclasses import asdict
+ from typing import List, Literal, Optional
+
+ import numpy as np
+ import polars as pl
+
+ from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit,
+                        _pred_risk, _risk_estimates, _subgroup_fit)
+ from .error import _datachecker, _param_checker
+ from .expansion import _binder, _diagnostics, _dynamic, _random_selection
+ from .helpers import _col_string, _format_time, bootstrap_loop
+ from .initialization import (_cense_denominator, _cense_numerator,
+                              _denominator, _numerator, _outcome)
+ from .plot import _survival_plot
+ from .SEQopts import SEQopts
+ from .SEQoutput import SEQoutput
+ from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
+                         _fit_visit, _weight_bind, _weight_predict,
+                         _weight_setup, _weight_stats)
+
+
+ class SEQuential:
+     """
+     Primary class initializer for SEQuentially nested target trial emulation
+
+     :param data: Data for analysis
+     :type data: pl.DataFrame
+     :param id_col: Column name for unique patient IDs
+     :type id_col: str
+     :param time_col: Column name for observational time points
+     :type time_col: str
+     :param eligible_col: Column name for analytical eligibility
+     :type eligible_col: str
+     :param treatment_col: Column name specifying treatment per time_col
+     :type treatment_col: str
+     :param outcome_col: Column name specifying outcome per time_col
+     :type outcome_col: str
+     :param time_varying_cols: Time-varying column names as covariates (BMI, Age, etc.)
+     :type time_varying_cols: Optional[List[str]] or None
+     :param fixed_cols: Fixed column names as covariates (Sex, YOB, etc.)
+     :type fixed_cols: Optional[List[str]] or None
+     :param method: Method for analysis ['ITT', 'dose-response', or 'censoring']
+     :type method: str
+     :param parameters: Parameters to augment analysis, specified with ``pySEQTarget.SEQopts``
+     :type parameters: Optional[SEQopts] or None
+     """
+
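+     # A minimal end-to-end usage sketch (illustrative only; the DataFrame
+     # and column names below are hypothetical, not part of the package):
+     #
+     #     seq = SEQuential(data=df, id_col="id", time_col="month",
+     #                      eligible_col="eligible", treatment_col="tx",
+     #                      outcome_col="event", method="ITT")
+     #     seq.expand()
+     #     seq.bootstrap(bootstrap_nboot=100)   # optional, before fit()
+     #     seq.fit()
+     #     seq.survival()
+     #     results = seq.collect()
+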
+     def __init__(
+         self,
+         data: pl.DataFrame,
+         id_col: str,
+         time_col: str,
+         eligible_col: str,
+         treatment_col: str,
+         outcome_col: str,
+         time_varying_cols: Optional[List[str]] = None,
+         fixed_cols: Optional[List[str]] = None,
+         method: Literal["ITT", "dose-response", "censoring"] = "ITT",
+         parameters: Optional[SEQopts] = None,
+     ) -> None:
+         self.data = data
+         self.id_col = id_col
+         self.time_col = time_col
+         self.eligible_col = eligible_col
+         self.treatment_col = treatment_col
+         self.outcome_col = outcome_col
+         self.time_varying_cols = time_varying_cols
+         self.fixed_cols = fixed_cols
+         self.method = method
+
+         self._time_initialized = datetime.datetime.now()
+
+         if parameters is None:
+             parameters = SEQopts()
+
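+         # Every field of the SEQopts dataclass is copied onto the instance,
+         # so options are read as plain attributes (e.g. self.seed) below;
+         # a seeded RandomState is used when a seed is provided.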
+         for name, value in asdict(parameters).items():
+             setattr(self, name, value)
+
+         self._rng = (
+             np.random.RandomState(self.seed) if self.seed is not None else np.random
+         )
+
+         if self.covariates is None:
+             self.covariates = _outcome(self)
+
+         if self.weighted:
+             if self.numerator is None:
+                 self.numerator = _numerator(self)
+
+             if self.denominator is None:
+                 self.denominator = _denominator(self)
+
+         if self.cense_colname is not None or self.visit_colname is not None:
+             if self.cense_numerator is None:
+                 self.cense_numerator = _cense_numerator(self)
+
+             if self.cense_denominator is None:
+                 self.cense_denominator = _cense_denominator(self)
+
+         _param_checker(self)
+         _datachecker(self)
+
+     def expand(self) -> None:
+         """
+         Creates the sequentially nested, emulated target trial structure
+         """
+         start = time.perf_counter()
+         kept = [
+             self.cense_colname,
+             self.cense_eligible_colname,
+             self.compevent_colname,
+             self.visit_colname,
+             *self.weight_eligible_colnames,
+             *self.excused_colnames,
+         ]
+
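+         # Restrict eligibility to rows at an analyzed treatment level, lag
+         # treatment within each ID, and flag a "switch" whenever treatment
+         # differs from the previous time point (never at time 0).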
+         self.data = self.data.with_columns(
+             [
+                 pl.when(pl.col(self.treatment_col).is_in(self.treatment_level))
+                 .then(self.eligible_col)
+                 .otherwise(0)
+                 .alias(self.eligible_col),
+                 pl.col(self.treatment_col).shift(1).over([self.id_col]).alias("tx_lag"),
+                 pl.lit(False).alias("switch"),
+             ]
+         ).with_columns(
+             [
+                 pl.when(pl.col(self.time_col) == 0)
+                 .then(pl.lit(False))
+                 .otherwise(
+                     (pl.col("tx_lag").is_not_null())
+                     & (pl.col("tx_lag") != pl.col(self.treatment_col))
+                 )
+                 .cast(pl.Int8)
+                 .alias("switch")
+             ]
+         )
+
+         self.DT = _binder(
+             self,
+             kept_cols=_col_string(
+                 [
+                     self.covariates,
+                     self.numerator,
+                     self.denominator,
+                     self.cense_numerator,
+                     self.cense_denominator,
+                 ]
+             ).union(kept),
+         ).with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))
+
+         self.data = self.data.with_columns(
+             pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
+         )
+
+         if self.method != "ITT":
+             _dynamic(self)
+         if self.selection_random:
+             _random_selection(self)
+         _diagnostics(self)
+
+         end = time.perf_counter()
+         self._expansion_time = _format_time(start, end)
+
+     def bootstrap(self, **kwargs) -> "SEQuential":
+         """
+         Internally sets up bootstrapping by creating the sampled IDs to use in each iteration
+         """
+         allowed = {
+             "bootstrap_nboot",
+             "bootstrap_sample",
+             "bootstrap_CI",
+             "bootstrap_method",
+         }
+         for key, value in kwargs.items():
+             if key in allowed:
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Unknown argument: {key}")
+         UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list()
+         NIDs = len(UIDs)
+
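+         # Each bootstrap iteration stores a Counter of sampled IDs (with
+         # multiplicity), so later fits can reuse the same resamples.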
+         self._boot_samples = []
+         for _ in range(self.bootstrap_nboot):
+             sampled_IDs = self._rng.choice(
+                 UIDs, size=int(self.bootstrap_sample * NIDs), replace=True
+             )
+             id_counts = Counter(sampled_IDs)
+             self._boot_samples.append(id_counts)
+         return self
+
+     @bootstrap_loop
+     def fit(self) -> None:
+         """
+         Fits weight models (numerator, denominator, censoring) and outcome models (outcome, competing event)
+         """
+         if self.bootstrap_nboot > 0 and not hasattr(self, "_boot_samples"):
+             raise ValueError(
+                 "Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
+             )
+
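+         # Weight estimation: build the weight-model dataset, fit the
+         # censoring (LTFU/visit) and numerator/denominator models, then
+         # predict and bind the per-row weights back onto the expanded data.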
+         if self.weighted:
+             WDT = _weight_setup(self)
+             if not self.weight_preexpansion and not self.excused:
+                 WDT = WDT.filter(pl.col("followup") > 0)
+
+             WDT = WDT.to_pandas()
+             for col in self.fixed_cols:
+                 if col in WDT.columns:
+                     WDT[col] = WDT[col].astype("category")
+
+             _fit_LTFU(self, WDT)
+             _fit_visit(self, WDT)
+             _fit_numerator(self, WDT)
+             _fit_denominator(self, WDT)
+
+             WDT = pl.from_pandas(WDT)
+             WDT = _weight_predict(self, WDT)
+             _weight_bind(self, WDT)
+             self.weight_stats = _weight_stats(self)
+
+         if self.subgroup_colname is not None:
+             return _subgroup_fit(self)
+
+         models = {
+             "outcome": _outcome_fit(
+                 self,
+                 self.DT,
+                 self.outcome_col,
+                 self.covariates,
+                 self.weighted,
+                 "weight",
+             )
+         }
+         if self.compevent_colname is not None:
+             models["compevent"] = _outcome_fit(
+                 self,
+                 self.DT,
+                 self.compevent_colname,
+                 self.covariates,
+                 self.weighted,
+                 "weight",
+             )
+         return models
+
+     def survival(self, **kwargs) -> None:
+         """
+         Uses the fitted outcome models (outcome, competing event) to estimate risk, survival, and incidence curves
+         """
+         allowed = {"bootstrap_CI", "bootstrap_CI_method"}
+         for key, val in kwargs.items():
+             if key in allowed:
+                 setattr(self, key, val)
+             else:
+                 raise ValueError(f"Unknown or misplaced argument: {key}")
+
+         if not hasattr(self, "outcome_model") or not self.outcome_model:
+             raise ValueError(
+                 "Outcome model not found. Please run the 'fit' method before calculating survival."
+             )
+
+         start = time.perf_counter()
+
+         risk_data = _pred_risk(self)
+         surv_data = _calculate_survival(self, risk_data)
+         self.km_data = pl.concat([risk_data, surv_data])
+         self.risk_estimates = _risk_estimates(self)
+
+         end = time.perf_counter()
+         self._survival_time = _format_time(start, end)
+
+     def hazard(self) -> None:
+         """
+         Uses the fitted outcome models (outcome, competing event) to estimate hazard ratios
+         """
+         start = time.perf_counter()
+
+         if not hasattr(self, "outcome_model") or not self.outcome_model:
+             raise ValueError(
+                 "Outcome model not found. Please run the 'fit' method before calculating hazard ratio."
+             )
+         self.hazard_ratio = _calculate_hazard(self)
+
+         end = time.perf_counter()
+         self._hazard_time = _format_time(start, end)
+
+     def plot(self, **kwargs) -> None:
+         """
+         Shows a plot specific to plot_type
+         """
+         allowed = {"plot_type", "plot_colors", "plot_title", "plot_labels"}
+         for key, val in kwargs.items():
+             if key in allowed:
+                 setattr(self, key, val)
+             else:
+                 raise ValueError(f"Unknown or misplaced argument: {key}")
+         self.km_graph = _survival_plot(self)
+
+     def collect(self) -> SEQoutput:
+         """
+         Collects all results created so far into the ``SEQoutput`` class
+         """
+         self._time_collected = datetime.datetime.now()
+
+         generated = [
+             "numerator_model",
+             "denominator_model",
+             "outcome_model",
+             "hazard_ratio",
+             "risk_estimates",
+             "km_data",
+             "km_graph",
+             "diagnostics",
+             "_survival_time",
+             "_hazard_time",
+             "_model_time",
+             "_expansion_time",
+             "weight_stats",
+         ]
+         for attr in generated:
+             if not hasattr(self, attr):
+                 setattr(self, attr, None)
+
+         # Options ==========================
+         base = SEQopts()
+
+         for name, value in vars(self).items():
+             if name in asdict(base).keys():
+                 setattr(base, name, value)
+
+         # Timing =========================
+         time = {
+             "start_time": self._time_initialized,
+             "expansion_time": self._expansion_time,
+             "model_time": self._model_time,
+             "survival_time": self._survival_time,
+             "hazard_time": self._hazard_time,
+             "collection_time": self._time_collected,
+         }
+
+         if self.compevent_colname is not None:
+             compevent_models = [model["compevent"] for model in self.outcome_model]
+         else:
+             compevent_models = None
+
+         if self.outcome_model is not None:
+             outcome_models = [model["outcome"] for model in self.outcome_model]
+         else:
+             outcome_models = None
+
+         if self.risk_estimates is None:
+             risk_ratio = risk_difference = None
+         else:
+             risk_ratio = self.risk_estimates["risk_ratio"]
+             risk_difference = self.risk_estimates["risk_difference"]
+
+         output = SEQoutput(
+             options=base,
+             method=self.method,
+             numerator_models=self.numerator_model,
+             denominator_models=self.denominator_model,
+             outcome_models=outcome_models,
+             compevent_models=compevent_models,
+             weight_statistics=self.weight_stats,
+             hazard=self.hazard_ratio,
+             km_data=self.km_data,
+             km_graph=self.km_graph,
+             risk_ratio=risk_ratio,
+             risk_difference=risk_difference,
+             time=time,
+             diagnostic_tables=self.diagnostics,
+         )
+
+         return output
pySEQTarget/__init__.py
@@ -0,0 +1,5 @@
+ from .SEQopts import SEQopts
+ from .SEQoutput import SEQoutput
+ from .SEQuential import SEQuential
+
+ __all__ = ["SEQuential", "SEQopts", "SEQoutput"]
pySEQTarget/analysis/__init__.py
@@ -0,0 +1,8 @@
+ from ._hazard import _calculate_hazard as _calculate_hazard
+ from ._outcome_fit import _outcome_fit as _outcome_fit
+ from ._risk_estimates import _risk_estimates as _risk_estimates
+ from ._subgroup_fit import _subgroup_fit as _subgroup_fit
+ from ._survival_pred import _calculate_survival as _calculate_survival
+ from ._survival_pred import \
+     _get_outcome_predictions as _get_outcome_predictions
+ from ._survival_pred import _pred_risk as _pred_risk
pySEQTarget/analysis/_hazard.py
@@ -0,0 +1,211 @@
+ import warnings
+
+ import numpy as np
+ import polars as pl
+ from lifelines import CoxPHFitter
+
+
+ def _calculate_hazard(self):
+     if self.subgroup_colname is None:
+         return _calculate_hazard_single(self, self.DT, idx=None, val=None)
+
+     all_hazards = []
+     original_DT = self.DT
+
+     for i, val in enumerate(self._unique_subgroups):
+         subgroup_DT = original_DT.filter(pl.col(self.subgroup_colname) == val)
+         hazard = _calculate_hazard_single(self, subgroup_DT, i, val)
+         all_hazards.append(hazard)
+
+     self.DT = original_DT
+     return pl.concat(all_hazards)
+
+
+ def _calculate_hazard_single(self, data, idx=None, val=None):
+     full_hr = _hazard_handler(self, data, idx, 0, self._rng)
+
+     if full_hr is None or np.isnan(full_hr):
+         return _create_hazard_output(None, None, None, val, self)
+
+     if self.bootstrap_nboot > 0:
+         boot_hrs = []
+
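+         # Each bootstrap replicate rebuilds the data from the stored ID
+         # Counters (repeating an ID as many times as it was sampled) and
+         # re-estimates the hazard ratio from the corresponding refit models.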
+         for boot_idx in range(len(self._boot_samples)):
+             id_counts = self._boot_samples[boot_idx]
+
+             boot_data_list = []
+             for id_val, count in id_counts.items():
+                 id_data = data.filter(pl.col(self.id_col) == id_val)
+                 for _ in range(count):
+                     boot_data_list.append(id_data)
+
+             boot_data = pl.concat(boot_data_list)
+
+             boot_hr = _hazard_handler(self, boot_data, idx, boot_idx + 1, self._rng)
+             if boot_hr is not None and not np.isnan(boot_hr):
+                 boot_hrs.append(boot_hr)
+
+         if len(boot_hrs) == 0:
+             return _create_hazard_output(full_hr, None, None, val, self)
+
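+         # Confidence limits: "se" uses a normal approximation around the
+         # point estimate with the bootstrap standard error; otherwise the
+         # bootstrap percentile interval is taken.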
+         if self.bootstrap_CI_method == "se":
+             from scipy.stats import norm
+
+             z = norm.ppf(1 - (1 - self.bootstrap_CI) / 2)
+             se = np.std(boot_hrs)
+             lci = full_hr - z * se
+             uci = full_hr + z * se
+         else:
+             lci = np.quantile(boot_hrs, (1 - self.bootstrap_CI) / 2)
+             uci = np.quantile(boot_hrs, 1 - (1 - self.bootstrap_CI) / 2)
+     else:
+         lci, uci = None, None
+
+     return _create_hazard_output(full_hr, lci, uci, val, self)
+
+
+ def _hazard_handler(self, data, idx, boot_idx, rng):
+     exclude_cols = [
+         "followup",
+         f"followup{self.indicator_squared}",
+         self.treatment_col,
+         f"{self.treatment_col}{self.indicator_baseline}",
+         "period",
+         self.outcome_col,
+     ]
+     if self.compevent_colname:
+         exclude_cols.append(self.compevent_colname)
+     keep_cols = [col for col in data.columns if col not in exclude_cols]
+
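+     # Expand each (ID, trial) pair to follow-up times 0..followup_max, then
+     # simulate outcomes (and competing events, if modelled) under every
+     # treatment level, truncating each trial at its first simulated event.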
+     trials = (
+         data.select(keep_cols)
+         .group_by([self.id_col, "trial"])
+         .first()
+         .with_columns([pl.lit(list(range(self.followup_max + 1))).alias("followup")])
+         .explode("followup")
+         .with_columns(
+             [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
+         )
+     )
+
+     if idx is not None:
+         model_dict = self.outcome_model[boot_idx][idx]
+     else:
+         model_dict = self.outcome_model[boot_idx]
+
+     outcome_model = model_dict["outcome"]
+     ce_model = model_dict.get("compevent", None) if self.compevent_colname else None
+
+     all_treatments = []
+     for val in self.treatment_level:
+         tmp = trials.with_columns(
+             [pl.lit(val).alias(f"{self.treatment_col}{self.indicator_baseline}")]
+         )
+
+         tmp_pd = tmp.to_pandas()
+         outcome_prob = outcome_model.predict(tmp_pd)
+         outcome_sim = rng.binomial(1, outcome_prob)
+
+         tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)])
+
+         if ce_model is not None:
+             ce_prob = ce_model.predict(tmp_pd)
+             ce_sim = rng.binomial(1, ce_prob)
+             tmp = tmp.with_columns([pl.Series("ce", ce_sim)])
+
+             tmp = (
+                 tmp.with_columns(
+                     [
+                         pl.when((pl.col("outcome") == 1) | (pl.col("ce") == 1))
+                         .then(1)
+                         .otherwise(0)
+                         .alias("any_event")
+                     ]
+                 )
+                 .with_columns(
+                     [
+                         pl.col("any_event")
+                         .cum_sum()
+                         .over([self.id_col, "trial"])
+                         .alias("event_cumsum")
+                     ]
+                 )
+                 .filter(pl.col("event_cumsum") <= 1)
+             )
+         else:
+             tmp = tmp.with_columns(
+                 [
+                     pl.col("outcome")
+                     .cum_sum()
+                     .over([self.id_col, "trial"])
+                     .alias("event_cumsum")
+                 ]
+             ).filter(pl.col("event_cumsum") <= 1)
+
+         tmp = tmp.group_by([self.id_col, "trial"]).last()
+         all_treatments.append(tmp)
+
+     sim_data = pl.concat(all_treatments)
+
+     if ce_model is not None:
+         sim_data = sim_data.with_columns(
+             [
+                 pl.when(pl.col("outcome") == 1)
+                 .then(pl.lit(1))
+                 .when(pl.col("ce") == 1)
+                 .then(pl.lit(2))
+                 .otherwise(pl.lit(0))
+                 .alias("event")
+             ]
+         )
+     else:
+         sim_data = sim_data.with_columns([pl.col("outcome").alias("event")])
+
+     sim_data_pd = sim_data.to_pandas()
+
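+     # Fit a Cox proportional hazards model on the simulated data with the
+     # baseline treatment indicator as the only regressor; with a competing
+     # event, rows where the competing event occurred first are dropped and
+     # the outcome is recoded to a binary event. The HR is exp(coefficient).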
+     try:
+         # CoxPHFitter currently triggers the deprecated datetime.datetime.utcnow() warning
+         warnings.filterwarnings("ignore", message=".*datetime.datetime.utcnow.*")
+         if ce_model is not None:
+             cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy()
+             cox_data["event_binary"] = (cox_data["event"] == 1).astype(int)
+
+             cph = CoxPHFitter()
+             cph.fit(
+                 cox_data,
+                 duration_col="followup",
+                 event_col="event_binary",
+                 formula=f"`{self.treatment_col}{self.indicator_baseline}`",
+             )
+         else:
+             cph = CoxPHFitter()
+             cph.fit(
+                 sim_data_pd,
+                 duration_col="followup",
+                 event_col="event",
+                 formula=f"`{self.treatment_col}{self.indicator_baseline}`",
+             )
+
+         hr = np.exp(cph.params_.values[0])
+         return hr
+     except Exception as e:
+         print(f"Cox model fitting failed: {e}")
+         return None
+
+
+ def _create_hazard_output(hr, lci, uci, val, self):
+     if lci is not None and uci is not None:
+         output = pl.DataFrame(
+             {
+                 "Hazard": [hr if hr is not None else float("nan")],
+                 "LCI": [lci],
+                 "UCI": [uci],
+             }
+         )
+     else:
+         output = pl.DataFrame({"Hazard": [hr if hr is not None else float("nan")]})
+
+     if val is not None:
+         output = output.with_columns(pl.lit(val).alias(self.subgroup_colname))
+
+     return output
pySEQTarget/analysis/_outcome_fit.py
@@ -0,0 +1,75 @@
+ import re
+
+ import polars as pl
+ import statsmodels.api as sm
+ import statsmodels.formula.api as smf
+
+
+ def _outcome_fit(
+     self,
+     df: pl.DataFrame,
+     outcome: str,
+     formula: str,
+     weighted: bool = False,
+     weight_col: str = "weight",
+ ):
+     if weighted:
+         df = df.with_columns(
+             pl.col(weight_col).clip(
+                 lower_bound=self.weight_min, upper_bound=self.weight_max
+             )
+         )
+
+     if self.method == "censoring":
+         df = df.filter(pl.col("switch") != 1)
+
+     df_pd = df.to_pandas()
+
+     df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category")
+     tx_bas = f"{self.treatment_col}{self.indicator_baseline}"
+     df_pd[tx_bas] = df_pd[tx_bas].astype("category")
+
+     if self.followup_class and not self.followup_spline:
+         df_pd["followup"] = df_pd["followup"].astype("category")
+         squared_col = f"followup{self.indicator_squared}"
+         if squared_col in df_pd.columns:
+             df_pd[squared_col] = df_pd[squared_col].astype("category")
+
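+     # When a follow-up spline is requested, the plain followup and
+     # followup-squared terms in the formula are rewritten via regex into a
+     # cubic regression spline basis, cr(followup, df=3), plus its square.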
+     if self.followup_spline:
+         spline = "cr(followup, df=3)"
+
+         formula = re.sub(r"(\w+)\s*\*\s*followup\b", rf"\1*{spline}", formula)
+         formula = re.sub(r"\bfollowup\s*\*\s*(\w+)", rf"{spline}*\1", formula)
+         formula = re.sub(
+             rf"\bfollowup{re.escape(self.indicator_squared)}\b", "", formula
+         )
+         formula = re.sub(r"\bfollowup\b", "", formula)
+
+         formula = re.sub(r"\s+", " ", formula)
+         formula = re.sub(r"\+\s*\+", "+", formula)
+         formula = re.sub(r"^\s*\+\s*|\s*\+\s*$", "", formula).strip()
+
+         if formula:
+             formula = f"{formula} + I({spline}**2)"
+         else:
+             formula = f"I({spline}**2)"
+
+     if self.fixed_cols:
+         for col in self.fixed_cols:
+             if col in df_pd.columns:
+                 df_pd[col] = df_pd[col].astype("category")
+
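+     # Pooled logistic regression: a binomial-family GLM fit through the
+     # statsmodels formula API; when weighted, the (clipped) weights enter
+     # as var_weights.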
+     full_formula = f"{outcome} ~ {formula}"
+
+     glm_kwargs = {
+         "formula": full_formula,
+         "data": df_pd,
+         "family": sm.families.Binomial(),
+     }
+
+     if weighted:
+         glm_kwargs["var_weights"] = df_pd[weight_col]
+
+     model = smf.glm(**glm_kwargs)
+     model_fit = model.fit()
+     return model_fit