pySEQTarget 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pySEQTarget/SEQopts.py +197 -0
- pySEQTarget/SEQoutput.py +163 -0
- pySEQTarget/SEQuential.py +375 -0
- pySEQTarget/__init__.py +5 -0
- pySEQTarget/analysis/__init__.py +8 -0
- pySEQTarget/analysis/_hazard.py +211 -0
- pySEQTarget/analysis/_outcome_fit.py +75 -0
- pySEQTarget/analysis/_risk_estimates.py +136 -0
- pySEQTarget/analysis/_subgroup_fit.py +30 -0
- pySEQTarget/analysis/_survival_pred.py +372 -0
- pySEQTarget/data/__init__.py +19 -0
- pySEQTarget/error/__init__.py +2 -0
- pySEQTarget/error/_datachecker.py +38 -0
- pySEQTarget/error/_param_checker.py +50 -0
- pySEQTarget/expansion/__init__.py +5 -0
- pySEQTarget/expansion/_binder.py +98 -0
- pySEQTarget/expansion/_diagnostics.py +53 -0
- pySEQTarget/expansion/_dynamic.py +73 -0
- pySEQTarget/expansion/_mapper.py +44 -0
- pySEQTarget/expansion/_selection.py +31 -0
- pySEQTarget/helpers/__init__.py +8 -0
- pySEQTarget/helpers/_bootstrap.py +111 -0
- pySEQTarget/helpers/_col_string.py +6 -0
- pySEQTarget/helpers/_format_time.py +6 -0
- pySEQTarget/helpers/_output_files.py +167 -0
- pySEQTarget/helpers/_pad.py +7 -0
- pySEQTarget/helpers/_predict_model.py +9 -0
- pySEQTarget/helpers/_prepare_data.py +19 -0
- pySEQTarget/initialization/__init__.py +5 -0
- pySEQTarget/initialization/_censoring.py +53 -0
- pySEQTarget/initialization/_denominator.py +39 -0
- pySEQTarget/initialization/_numerator.py +37 -0
- pySEQTarget/initialization/_outcome.py +56 -0
- pySEQTarget/plot/__init__.py +1 -0
- pySEQTarget/plot/_survival_plot.py +104 -0
- pySEQTarget/weighting/__init__.py +8 -0
- pySEQTarget/weighting/_weight_bind.py +86 -0
- pySEQTarget/weighting/_weight_data.py +47 -0
- pySEQTarget/weighting/_weight_fit.py +99 -0
- pySEQTarget/weighting/_weight_pred.py +192 -0
- pySEQTarget/weighting/_weight_stats.py +23 -0
- pyseqtarget-0.10.0.dist-info/METADATA +98 -0
- pyseqtarget-0.10.0.dist-info/RECORD +46 -0
- pyseqtarget-0.10.0.dist-info/WHEEL +5 -0
- pyseqtarget-0.10.0.dist-info/licenses/LICENSE +21 -0
- pyseqtarget-0.10.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,375 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
import time
|
|
3
|
+
from collections import Counter
|
|
4
|
+
from dataclasses import asdict
|
|
5
|
+
from typing import List, Literal, Optional
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import polars as pl
|
|
9
|
+
|
|
10
|
+
from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit,
|
|
11
|
+
_pred_risk, _risk_estimates, _subgroup_fit)
|
|
12
|
+
from .error import _datachecker, _param_checker
|
|
13
|
+
from .expansion import _binder, _diagnostics, _dynamic, _random_selection
|
|
14
|
+
from .helpers import _col_string, _format_time, bootstrap_loop
|
|
15
|
+
from .initialization import (_cense_denominator, _cense_numerator,
|
|
16
|
+
_denominator, _numerator, _outcome)
|
|
17
|
+
from .plot import _survival_plot
|
|
18
|
+
from .SEQopts import SEQopts
|
|
19
|
+
from .SEQoutput import SEQoutput
|
|
20
|
+
from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
|
|
21
|
+
_fit_visit, _weight_bind, _weight_predict,
|
|
22
|
+
_weight_setup, _weight_stats)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class SEQuential:
    """
    Primary class initializer for SEQuentially nested target trial emulation

    :param data: Data for analysis
    :type data: pl.DataFrame
    :param id_col: Column name for unique patient IDs
    :type id_col: str
    :param time_col: Column name for observational time points
    :type time_col: str
    :param eligible_col: Column name for analytical eligibility
    :type eligible_col: str
    :param treatment_col: Column name specifying treatment per time_col
    :type treatment_col: str
    :param outcome_col: Column name specifying outcome per time_col
    :type outcome_col: str
    :param time_varying_cols: Time-varying column names as covariates (BMI, Age, etc.)
    :type time_varying_cols: Optional[List[str]] or None
    :param fixed_cols: Fixed column names as covariates (Sex, YOB, etc.)
    :type fixed_cols: Optional[List[str]] or None
    :param method: Method for analysis ['ITT', 'dose-response', or 'censoring']
    :type method: str
    :param parameters: Parameters to augment analysis, specified with ``pySEQTarget.SEQopts``
    :type parameters: Optional[SEQopts] or None
    """

    def __init__(
        self,
        data: pl.DataFrame,
        id_col: str,
        time_col: str,
        eligible_col: str,
        treatment_col: str,
        outcome_col: str,
        time_varying_cols: Optional[List[str]] = None,
        fixed_cols: Optional[List[str]] = None,
        method: Literal["ITT", "dose-response", "censoring"] = "ITT",
        parameters: Optional[SEQopts] = None,
    ) -> None:
        self.data = data
        self.id_col = id_col
        self.time_col = time_col
        self.eligible_col = eligible_col
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.time_varying_cols = time_varying_cols
        self.fixed_cols = fixed_cols
        self.method = method

        self._time_initialized = datetime.datetime.now()

        if parameters is None:
            parameters = SEQopts()

        # Flatten every SEQopts field onto the instance so downstream helpers
        # can read options as plain attributes (e.g. self.seed, self.weighted).
        for name, value in asdict(parameters).items():
            setattr(self, name, value)

        # Seeded RandomState for reproducibility; fall back to the global
        # numpy RNG module when no seed was supplied.
        self._rng = (
            np.random.RandomState(self.seed) if self.seed is not None else np.random
        )

        # Build default model formulas for anything the user left unset.
        if self.covariates is None:
            self.covariates = _outcome(self)

        if self.weighted:
            if self.numerator is None:
                self.numerator = _numerator(self)

            if self.denominator is None:
                self.denominator = _denominator(self)

            # Censoring-weight formulas are only needed when a censoring or
            # visit-process column is part of the weighting specification.
            if self.cense_colname is not None or self.visit_colname is not None:
                if self.cense_numerator is None:
                    self.cense_numerator = _cense_numerator(self)

                if self.cense_denominator is None:
                    self.cense_denominator = _cense_denominator(self)

        _param_checker(self)
        _datachecker(self)

    def expand(self) -> None:
        """
        Creates the sequentially nested, emulated target trial structure
        """
        start = time.perf_counter()
        # Optional columns carried through trial expansion when configured.
        kept = [
            self.cense_colname,
            self.cense_eligible_colname,
            self.compevent_colname,
            self.visit_colname,
            *self.weight_eligible_colnames,
            *self.excused_colnames,
        ]

        self.data = self.data.with_columns(
            [
                # Rows on a treatment level outside the analysis set are
                # marked ineligible for trial enrollment.
                pl.when(pl.col(self.treatment_col).is_in(self.treatment_level))
                .then(self.eligible_col)
                .otherwise(0)
                .alias(self.eligible_col),
                pl.col(self.treatment_col).shift(1).over([self.id_col]).alias("tx_lag"),
                pl.lit(False).alias("switch"),
            ]
        ).with_columns(
            [
                # A switch is any change from the prior period's treatment;
                # baseline (time 0) has no prior period and is never a switch.
                pl.when(pl.col(self.time_col) == 0)
                .then(pl.lit(False))
                .otherwise(
                    (pl.col("tx_lag").is_not_null())
                    & (pl.col("tx_lag") != pl.col(self.treatment_col))
                )
                .cast(pl.Int8)
                .alias("switch")
            ]
        )

        self.DT = _binder(
            self,
            kept_cols=_col_string(
                [
                    self.covariates,
                    self.numerator,
                    self.denominator,
                    self.cense_numerator,
                    self.cense_denominator,
                ]
            ).union(kept),
        ).with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))

        # Keep the ID dtype consistent between the raw and expanded data.
        self.data = self.data.with_columns(
            pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
        )

        if self.method != "ITT":
            _dynamic(self)
        if self.selection_random:
            _random_selection(self)
        _diagnostics(self)

        end = time.perf_counter()
        self._expansion_time = _format_time(start, end)

    def bootstrap(self, **kwargs) -> "SEQuential":
        """
        Internally sets up bootstrapping - creating a list of IDs to use per iteration

        :returns: self, allowing method chaining
        """
        allowed = {
            "bootstrap_nboot",
            "bootstrap_sample",
            "bootstrap_CI",
            "bootstrap_method",
        }
        for key, value in kwargs.items():
            if key in allowed:
                setattr(self, key, value)
            else:
                raise ValueError(f"Unknown argument: {key}")
        UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list()
        NIDs = len(UIDs)

        # Pre-draw every bootstrap sample (with replacement) as a Counter of
        # ID -> multiplicity so later fits can rebuild samples cheaply.
        self._boot_samples = []
        for _ in range(self.bootstrap_nboot):
            sampled_IDs = self._rng.choice(
                UIDs, size=int(self.bootstrap_sample * NIDs), replace=True
            )
            id_counts = Counter(sampled_IDs)
            self._boot_samples.append(id_counts)
        # NOTE: annotation was previously ``-> None`` despite returning self.
        return self

    @bootstrap_loop
    def fit(self) -> None:
        """
        Fits weight models (numerator, denominator, censoring) and outcome models (outcome, competing event)

        The undecorated method returns the fitted model dict (or the result of
        ``_subgroup_fit``); ``bootstrap_loop`` handles collection per iteration.
        """
        if self.bootstrap_nboot > 0 and not hasattr(self, "_boot_samples"):
            raise ValueError(
                "Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
            )

        if self.weighted:
            WDT = _weight_setup(self)
            if not self.weight_preexpansion and not self.excused:
                WDT = WDT.filter(pl.col("followup") > 0)

            # statsmodels/patsy formulas need pandas categoricals.
            WDT = WDT.to_pandas()
            # ``or []`` guards against fixed_cols being None (previously a
            # TypeError when no fixed covariates were supplied).
            for col in self.fixed_cols or []:
                if col in WDT.columns:
                    WDT[col] = WDT[col].astype("category")

            _fit_LTFU(self, WDT)
            _fit_visit(self, WDT)
            _fit_numerator(self, WDT)
            _fit_denominator(self, WDT)

            WDT = pl.from_pandas(WDT)
            WDT = _weight_predict(self, WDT)
            _weight_bind(self, WDT)
            self.weight_stats = _weight_stats(self)

        if self.subgroup_colname is not None:
            return _subgroup_fit(self)

        models = {
            "outcome": _outcome_fit(
                self,
                self.DT,
                self.outcome_col,
                self.covariates,
                self.weighted,
                "weight",
            )
        }
        if self.compevent_colname is not None:
            models["compevent"] = _outcome_fit(
                self,
                self.DT,
                self.compevent_colname,
                self.covariates,
                self.weighted,
                "weight",
            )
        return models

    def survival(self, **kwargs) -> None:
        """
        Uses fit outcome models (outcome, competing event) to estimate risk, survival, and incidence curves
        """
        # NOTE(review): ``bootstrap()`` accepts "bootstrap_method" while this
        # accepts "bootstrap_CI_method" — confirm against SEQopts field names.
        allowed = {"bootstrap_CI", "bootstrap_CI_method"}
        for key, val in kwargs.items():
            if key in allowed:
                setattr(self, key, val)
            else:
                raise ValueError(f"Unknown or misplaced argument: {key}")

        if not hasattr(self, "outcome_model") or not self.outcome_model:
            raise ValueError(
                "Outcome model not found. Please run the 'fit' method before calculating survival."
            )

        start = time.perf_counter()

        risk_data = _pred_risk(self)
        surv_data = _calculate_survival(self, risk_data)
        self.km_data = pl.concat([risk_data, surv_data])
        self.risk_estimates = _risk_estimates(self)

        end = time.perf_counter()
        self._survival_time = _format_time(start, end)

    def hazard(self) -> None:
        """
        Uses fit outcome models (outcome, competing event) to estimate hazard ratios
        """
        start = time.perf_counter()

        if not hasattr(self, "outcome_model") or not self.outcome_model:
            raise ValueError(
                "Outcome model not found. Please run the 'fit' method before calculating hazard ratio."
            )
        self.hazard_ratio = _calculate_hazard(self)

        end = time.perf_counter()
        self._hazard_time = _format_time(start, end)

    def plot(self, **kwargs) -> None:
        """
        Shows a plot specific to plot_type
        """
        allowed = {"plot_type", "plot_colors", "plot_title", "plot_labels"}
        for key, val in kwargs.items():
            if key in allowed:
                setattr(self, key, val)
            else:
                raise ValueError(f"Unknown or misplaced argument: {key}")
        self.km_graph = _survival_plot(self)

    def collect(self) -> SEQoutput:
        """
        Collects all results currently created into ``SEQoutput`` class
        """
        self._time_collected = datetime.datetime.now()

        # Backfill any result attribute that was never produced, so the
        # output assembly below can rely on plain attribute access.
        generated = [
            "numerator_model",
            "denominator_model",
            "outcome_model",
            "hazard_ratio",
            "risk_estimates",
            "km_data",
            "km_graph",
            "diagnostics",
            "_survival_time",
            "_hazard_time",
            "_model_time",
            "_expansion_time",
            "weight_stats",
        ]
        for attr in generated:
            if not hasattr(self, attr):
                setattr(self, attr, None)

        # Options ==========================
        # Re-hydrate a SEQopts snapshot from the flattened instance attrs.
        base = SEQopts()
        for name, value in vars(self).items():
            if name in asdict(base).keys():
                setattr(base, name, value)

        # Timing =========================
        # Named ``timing`` (not ``time``) to avoid shadowing the time module.
        timing = {
            "start_time": self._time_initialized,
            "expansion_time": self._expansion_time,
            "model_time": self._model_time,
            "survival_time": self._survival_time,
            "hazard_time": self._hazard_time,
            "collection_time": self._time_collected,
        }

        if self.compevent_colname is not None:
            # BUG FIX: previously read ``self.outcome_models`` (never set),
            # raising AttributeError whenever a competing event was configured.
            compevent_models = [model["compevent"] for model in self.outcome_model]
        else:
            compevent_models = None

        if self.outcome_model is not None:
            outcome_models = [model["outcome"] for model in self.outcome_model]
        else:
            # BUG FIX: ``outcome_models`` was unbound (NameError) when no
            # outcome model had been fit.
            outcome_models = None

        if self.risk_estimates is None:
            risk_ratio = risk_difference = None
        else:
            risk_ratio = self.risk_estimates["risk_ratio"]
            risk_difference = self.risk_estimates["risk_difference"]

        output = SEQoutput(
            options=base,
            method=self.method,
            numerator_models=self.numerator_model,
            denominator_models=self.denominator_model,
            outcome_models=outcome_models,
            compevent_models=compevent_models,
            weight_statistics=self.weight_stats,
            hazard=self.hazard_ratio,
            km_data=self.km_data,
            km_graph=self.km_graph,
            risk_ratio=risk_ratio,
            risk_difference=risk_difference,
            time=timing,
            diagnostic_tables=self.diagnostics,
        )

        return output
|
pySEQTarget/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from ._hazard import _calculate_hazard as _calculate_hazard
|
|
2
|
+
from ._outcome_fit import _outcome_fit as _outcome_fit
|
|
3
|
+
from ._risk_estimates import _risk_estimates as _risk_estimates
|
|
4
|
+
from ._subgroup_fit import _subgroup_fit as _subgroup_fit
|
|
5
|
+
from ._survival_pred import _calculate_survival as _calculate_survival
|
|
6
|
+
from ._survival_pred import \
|
|
7
|
+
_get_outcome_predictions as _get_outcome_predictions
|
|
8
|
+
from ._survival_pred import _pred_risk as _pred_risk
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
import polars as pl
|
|
5
|
+
from lifelines import CoxPHFitter
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _calculate_hazard(self):
    """Estimate hazard ratios: one overall, or one row per subgroup level.

    Returns a polars DataFrame produced by ``_create_hazard_output`` (via
    ``_calculate_hazard_single``); subgroup results are vertically stacked.
    """
    if self.subgroup_colname is None:
        # No subgrouping: a single hazard ratio over the full expanded data.
        return _calculate_hazard_single(self, self.DT, idx=None, val=None)

    full_data = self.DT
    level_filter = pl.col(self.subgroup_colname)

    per_level = [
        _calculate_hazard_single(
            self, full_data.filter(level_filter == level), position, level
        )
        for position, level in enumerate(self._unique_subgroups)
    ]

    # Restore the full dataset in case any callee reassigned self.DT.
    self.DT = full_data
    return pl.concat(per_level)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _calculate_hazard_single(self, data, idx=None, val=None):
    """Hazard ratio (with optional bootstrap CI) for one dataset.

    :param data: expanded trial data to estimate on (full data or one subgroup)
    :param idx: subgroup index into each per-boot outcome-model list, or None
    :param val: subgroup level used to label the output row, or None
    :returns: one-row polars DataFrame from ``_create_hazard_output``
    """
    # Point estimate on the non-resampled data; boot index 0 selects the
    # models fit without resampling.
    full_hr = _hazard_handler(self, data, idx, 0, self._rng)

    if full_hr is None or np.isnan(full_hr):
        # Model fitting failed entirely: emit a NaN row with no CI.
        return _create_hazard_output(None, None, None, val, self)

    if self.bootstrap_nboot > 0:
        boot_hrs = []

        for boot_idx in range(len(self._boot_samples)):
            id_counts = self._boot_samples[boot_idx]

            # Rebuild the bootstrap sample, repeating each ID's rows as many
            # times as it was drawn (with replacement) in bootstrap().
            boot_data_list = []
            for id_val, count in id_counts.items():
                id_data = data.filter(pl.col(self.id_col) == id_val)
                for _ in range(count):
                    boot_data_list.append(id_data)

            boot_data = pl.concat(boot_data_list)

            # boot_idx + 1 because index 0 holds the non-resampled fit.
            boot_hr = _hazard_handler(self, boot_data, idx, boot_idx + 1, self._rng)
            if boot_hr is not None and not np.isnan(boot_hr):
                boot_hrs.append(boot_hr)

        if len(boot_hrs) == 0:
            # Every bootstrap fit failed; report the point estimate alone.
            return _create_hazard_output(full_hr, None, None, val, self)

        if self.bootstrap_CI_method == "se":
            from scipy.stats import norm

            # Normal-approximation CI built directly on the HR scale.
            # NOTE(review): np.std defaults to ddof=0 (population SD) —
            # confirm this matches the intended bootstrap SE.
            z = norm.ppf(1 - (1 - self.bootstrap_CI) / 2)
            se = np.std(boot_hrs)
            lci = full_hr - z * se
            uci = full_hr + z * se
        else:
            # Percentile bootstrap CI.
            lci = np.quantile(boot_hrs, (1 - self.bootstrap_CI) / 2)
            uci = np.quantile(boot_hrs, 1 - (1 - self.bootstrap_CI) / 2)
    else:
        lci, uci = None, None

    return _create_hazard_output(full_hr, lci, uci, val, self)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _hazard_handler(self, data, idx, boot_idx, rng):
    """Simulate per-trial outcomes under each treatment level and fit a Cox model.

    Builds a counterfactual follow-up grid per (id, trial), simulates outcome
    (and competing-event) indicators from the fitted GLMs, truncates each
    trajectory at its first event, then fits a Cox PH model of the simulated
    event times on baseline treatment.

    :param data: expanded trial data for this (sub)group / bootstrap sample
    :param idx: subgroup index into the per-boot model list, or None
    :param boot_idx: 0 for the non-resampled fit, k for bootstrap sample k-1
    :param rng: numpy RandomState used for the binomial outcome simulation
    :returns: hazard ratio (float) or None when Cox fitting fails
    """
    # Drop follow-up/treatment/outcome columns; they are regenerated below
    # on a fresh 0..followup_max grid per (id, trial).
    exclude_cols = [
        "followup",
        f"followup{self.indicator_squared}",
        self.treatment_col,
        f"{self.treatment_col}{self.indicator_baseline}",
        "period",
        self.outcome_col,
    ]
    if self.compevent_colname:
        exclude_cols.append(self.compevent_colname)
    keep_cols = [col for col in data.columns if col not in exclude_cols]

    # One baseline row per (id, trial), expanded to every follow-up time.
    trials = (
        data.select(keep_cols)
        .group_by([self.id_col, "trial"])
        .first()
        .with_columns([pl.lit(list(range(self.followup_max + 1))).alias("followup")])
        .explode("followup")
        .with_columns(
            [(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
        )
    )

    # Select the model(s) for this bootstrap iteration (and subgroup if any).
    if idx is not None:
        model_dict = self.outcome_model[boot_idx][idx]
    else:
        model_dict = self.outcome_model[boot_idx]

    outcome_model = model_dict["outcome"]
    ce_model = model_dict.get("compevent", None) if self.compevent_colname else None

    all_treatments = []
    for val in self.treatment_level:
        # Force everyone onto baseline treatment `val` (counterfactual arm).
        tmp = trials.with_columns(
            [pl.lit(val).alias(f"{self.treatment_col}{self.indicator_baseline}")]
        )

        tmp_pd = tmp.to_pandas()
        outcome_prob = outcome_model.predict(tmp_pd)
        # Draw Bernoulli outcomes from the predicted probabilities.
        outcome_sim = rng.binomial(1, outcome_prob)

        tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)])

        if ce_model is not None:
            ce_prob = ce_model.predict(tmp_pd)
            ce_sim = rng.binomial(1, ce_prob)
            tmp = tmp.with_columns([pl.Series("ce", ce_sim)])

            # Keep rows up to and including the first event of either type;
            # cum_sum <= 1 retains the first event row itself.
            tmp = (
                tmp.with_columns(
                    [
                        pl.when((pl.col("outcome") == 1) | (pl.col("ce") == 1))
                        .then(1)
                        .otherwise(0)
                        .alias("any_event")
                    ]
                )
                .with_columns(
                    [
                        pl.col("any_event")
                        .cum_sum()
                        .over([self.id_col, "trial"])
                        .alias("event_cumsum")
                    ]
                )
                .filter(pl.col("event_cumsum") <= 1)
            )
        else:
            # No competing event: truncate at the first simulated outcome.
            tmp = tmp.with_columns(
                [
                    pl.col("outcome")
                    .cum_sum()
                    .over([self.id_col, "trial"])
                    .alias("event_cumsum")
                ]
            ).filter(pl.col("event_cumsum") <= 1)

        # Last remaining row per (id, trial) = event time or end of follow-up.
        tmp = tmp.group_by([self.id_col, "trial"]).last()
        all_treatments.append(tmp)

    sim_data = pl.concat(all_treatments)

    if ce_model is not None:
        # Encode: 1 = outcome, 2 = competing event, 0 = censored at end.
        sim_data = sim_data.with_columns(
            [
                pl.when(pl.col("outcome") == 1)
                .then(pl.lit(1))
                .when(pl.col("ce") == 1)
                .then(pl.lit(2))
                .otherwise(pl.lit(0))
                .alias("event")
            ]
        )
    else:
        sim_data = sim_data.with_columns([pl.col("outcome").alias("event")])

    sim_data_pd = sim_data.to_pandas()

    try:
        # COXPHFITER CURRENTLY HAS DEPRECATED datetime.datetime.utcnow()
        # NOTE(review): this filter mutates global warning state and is
        # never reset.
        warnings.filterwarnings("ignore", message=".*datetime.datetime.utcnow.*")
        if ce_model is not None:
            # Cause-specific Cox model: competing events (2) are treated as
            # censoring for the outcome of interest.
            cox_data = sim_data_pd[sim_data_pd["event"].isin([0, 1])].copy()
            cox_data["event_binary"] = (cox_data["event"] == 1).astype(int)

            cph = CoxPHFitter()
            cph.fit(
                cox_data,
                duration_col="followup",
                event_col="event_binary",
                formula=f"`{self.treatment_col}{self.indicator_baseline}`",
            )
        else:
            cph = CoxPHFitter()
            cph.fit(
                sim_data_pd,
                duration_col="followup",
                event_col="event",
                formula=f"`{self.treatment_col}{self.indicator_baseline}`",
            )

        # params_ holds log-hazard coefficients; exponentiate for the HR.
        hr = np.exp(cph.params_.values[0])
        return hr
    except Exception as e:
        print(f"Cox model fitting failed: {e}")
        return None
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def _create_hazard_output(hr, lci, uci, val, self):
    """Package a hazard ratio, optional CI bounds, and optional subgroup label
    into a one-row polars DataFrame (columns: Hazard [, LCI, UCI] [, subgroup]).
    """
    # A missing point estimate is reported as NaN rather than null.
    point = float("nan") if hr is None else hr

    columns = {"Hazard": [point]}
    if not (lci is None or uci is None):
        # Only attach CI columns when both bounds were computed.
        columns["LCI"] = [lci]
        columns["UCI"] = [uci]

    frame = pl.DataFrame(columns)

    if val is not None:
        # Label the row with its subgroup level.
        frame = frame.with_columns(pl.lit(val).alias(self.subgroup_colname))

    return frame
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
import polars as pl
|
|
4
|
+
import statsmodels.api as sm
|
|
5
|
+
import statsmodels.formula.api as smf
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _outcome_fit(
    self,
    df: pl.DataFrame,
    outcome: str,
    formula: str,
    weighted: bool = False,
    weight_col: str = "weight",
):
    """Fit a binomial GLM of `outcome` on `formula` over the expanded data.

    :param df: expanded trial data
    :param outcome: name of the binary outcome column to model
    :param formula: right-hand-side patsy formula (covariates)
    :param weighted: clip and apply `weight_col` as variance weights
    :param weight_col: column holding per-row IP weights
    :returns: fitted statsmodels GLMResults object
    """
    if weighted:
        # Truncate extreme weights to the configured [weight_min, weight_max].
        df = df.with_columns(
            pl.col(weight_col).clip(
                lower_bound=self.weight_min, upper_bound=self.weight_max
            )
        )

    if self.method == "censoring":
        # Artificial censoring: drop person-time after treatment switching.
        df = df.filter(pl.col("switch") != 1)

    df_pd = df.to_pandas()

    # Treatment columns enter the formula as categorical factors.
    df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category")
    tx_bas = f"{self.treatment_col}{self.indicator_baseline}"
    df_pd[tx_bas] = df_pd[tx_bas].astype("category")

    if self.followup_class and not self.followup_spline:
        # Model follow-up time as a factor rather than a numeric trend.
        df_pd["followup"] = df_pd["followup"].astype("category")
        squared_col = f"followup{self.indicator_squared}"
        if squared_col in df_pd.columns:
            df_pd[squared_col] = df_pd[squared_col].astype("category")

    if self.followup_spline:
        # Replace raw followup terms with a cubic-regression-spline basis.
        spline = "cr(followup, df=3)"

        # Rewrite interactions with followup to use the spline basis.
        formula = re.sub(r"(\w+)\s*\*\s*followup\b", rf"\1*{spline}", formula)
        formula = re.sub(r"\bfollowup\s*\*\s*(\w+)", rf"{spline}*\1", formula)
        # Drop the remaining standalone followup and followup-squared terms.
        formula = re.sub(
            rf"\bfollowup{re.escape(self.indicator_squared)}\b", "", formula
        )
        formula = re.sub(r"\bfollowup\b", "", formula)

        # Clean up whitespace and dangling '+' left by the removals.
        formula = re.sub(r"\s+", " ", formula)
        formula = re.sub(r"\+\s*\+", "+", formula)
        formula = re.sub(r"^\s*\+\s*|\s*\+\s*$", "", formula).strip()

        if formula:
            formula = f"{formula} + I({spline}**2)"
        else:
            formula = f"I({spline}**2)"

    if self.fixed_cols:
        # Baseline covariates are modeled as categorical factors.
        for col in self.fixed_cols:
            if col in df_pd.columns:
                df_pd[col] = df_pd[col].astype("category")

    full_formula = f"{outcome} ~ {formula}"

    glm_kwargs = {
        "formula": full_formula,
        "data": df_pd,
        "family": sm.families.Binomial(),
    }

    if weighted:
        # IP weights enter the pseudo-likelihood as variance weights.
        glm_kwargs["var_weights"] = df_pd[weight_col]

    model = smf.glm(**glm_kwargs)
    model_fit = model.fit()
    return model_fit
|