pySEQTarget 0.10.1__tar.gz → 0.12.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/PKG-INFO +6 -1
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/SEQopts.py +17 -2
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/SEQuential.py +17 -4
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_hazard.py +9 -4
- pyseqtarget-0.12.0/pySEQTarget/analysis/_risk_estimates.py +138 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_survival_pred.py +13 -12
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_mapper.py +4 -11
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/__init__.py +1 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_bootstrap.py +39 -4
- pyseqtarget-0.12.0/pySEQTarget/helpers/_fix_categories.py +21 -0
- pyseqtarget-0.12.0/pySEQTarget/helpers/_offloader.py +82 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_output_files.py +2 -4
- pyseqtarget-0.12.0/pySEQTarget/helpers/_predict_model.py +57 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/__init__.py +1 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_bind.py +3 -0
- pyseqtarget-0.12.0/pySEQTarget/weighting/_weight_fit.py +137 -0
- pyseqtarget-0.12.0/pySEQTarget/weighting/_weight_offload.py +19 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_pred.py +74 -52
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/PKG-INFO +6 -1
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/SOURCES.txt +4 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/requires.txt +6 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pyproject.toml +10 -8
- pyseqtarget-0.12.0/tests/test_offload.py +41 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_survival.py +13 -0
- pyseqtarget-0.10.1/pySEQTarget/analysis/_risk_estimates.py +0 -136
- pyseqtarget-0.10.1/pySEQTarget/helpers/_predict_model.py +0 -9
- pyseqtarget-0.10.1/pySEQTarget/weighting/_weight_fit.py +0 -99
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/LICENSE +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/README.md +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/SEQoutput.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/__init__.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/__init__.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_outcome_fit.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_subgroup_fit.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/data/__init__.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/error/__init__.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/error/_data_checker.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/error/_param_checker.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/__init__.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_binder.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_diagnostics.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_dynamic.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_selection.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_col_string.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_format_time.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_pad.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_prepare_data.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/__init__.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_censoring.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_denominator.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_numerator.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_outcome.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/plot/__init__.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/plot/_survival_plot.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_data.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_stats.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/dependency_links.txt +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/top_level.txt +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/setup.cfg +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_accessor.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_coefficients.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_covariates.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_followup_options.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_hazard.py +0 -0
- {pyseqtarget-0.10.1 → pyseqtarget-0.12.0}/tests/test_parallel.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pySEQTarget
|
|
3
|
-
Version: 0.10.1
|
|
3
|
+
Version: 0.12.0
|
|
4
4
|
Summary: Sequentially Nested Target Trial Emulation
|
|
5
5
|
Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernan <mhernan@hsph.harvard.edu>
|
|
6
6
|
Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>
|
|
@@ -33,6 +33,11 @@ Requires-Dist: statsmodels
|
|
|
33
33
|
Requires-Dist: matplotlib
|
|
34
34
|
Requires-Dist: pyarrow
|
|
35
35
|
Requires-Dist: lifelines
|
|
36
|
+
Requires-Dist: joblib
|
|
37
|
+
Provides-Extra: output
|
|
38
|
+
Requires-Dist: markdown; extra == "output"
|
|
39
|
+
Requires-Dist: weasyprint; extra == "output"
|
|
40
|
+
Requires-Dist: tabulate; extra == "output"
|
|
36
41
|
Dynamic: license-file
|
|
37
42
|
|
|
38
43
|
# pySEQTarget - Sequentially Nested Target Trial Emulation
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import multiprocessing
|
|
2
|
+
import os
|
|
2
3
|
from dataclasses import dataclass, field
|
|
3
4
|
from typing import List, Literal, Optional
|
|
4
5
|
|
|
@@ -18,7 +19,7 @@ class SEQopts:
|
|
|
18
19
|
:type bootstrap_CI_method: str
|
|
19
20
|
:param cense_colname: Column name for censoring effect (LTFU, etc.)
|
|
20
21
|
:type cense_colname: str
|
|
21
|
-
:param cense_denominator: Override to specify denominator patsy formula for censoring models
|
|
22
|
+
:param cense_denominator: Override to specify denominator patsy formula for censoring models; "1" or "" indicate intercept only model
|
|
22
23
|
:type cense_denominator: Optional[str] or None
|
|
23
24
|
:param cense_numerator: Override to specify numerator patsy formula for censoring models
|
|
24
25
|
:type cense_numerator: Optional[str] or None
|
|
@@ -54,8 +55,12 @@ class SEQopts:
|
|
|
54
55
|
:type km_curves: bool
|
|
55
56
|
:param ncores: Number of cores to use if running in parallel
|
|
56
57
|
:type ncores: int
|
|
57
|
-
:param numerator: Override to specify the outcome patsy formula for numerator models
|
|
58
|
+
:param numerator: Override to specify the outcome patsy formula for numerator models; "1" or "" indicate intercept only model
|
|
58
59
|
:type numerator: str
|
|
60
|
+
:param offload: Boolean to offload intermediate model data to disk
|
|
61
|
+
:type offload: bool
|
|
62
|
+
:param offload_dir: Directory to offload intermediate model data
|
|
63
|
+
:type offload_dir: str
|
|
59
64
|
:param parallel: Boolean to run model fitting in parallel
|
|
60
65
|
:type parallel: bool
|
|
61
66
|
:param plot_colors: List of colors for KM plots, if applicable
|
|
@@ -80,8 +85,12 @@ class SEQopts:
|
|
|
80
85
|
:type treatment_level: List[int]
|
|
81
86
|
:param trial_include: Boolean to force trial values into model covariates
|
|
82
87
|
:type trial_include: bool
|
|
88
|
+
:param visit_colname: Column name specifying visit number
|
|
89
|
+
:type visit_colname: str
|
|
83
90
|
:param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting
|
|
84
91
|
:type weight_eligible_colnames: List[str]
|
|
92
|
+
:param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton"
|
|
93
|
+
:type weight_fit_method: str
|
|
85
94
|
:param weight_min: Minimum weight
|
|
86
95
|
:type weight_min: float
|
|
87
96
|
:param weight_max: Maximum weight
|
|
@@ -120,6 +129,8 @@ class SEQopts:
|
|
|
120
129
|
km_curves: bool = False
|
|
121
130
|
ncores: int = multiprocessing.cpu_count()
|
|
122
131
|
numerator: Optional[str] = None
|
|
132
|
+
offload: bool = False
|
|
133
|
+
offload_dir: str = "_seq_models"
|
|
123
134
|
parallel: bool = False
|
|
124
135
|
plot_colors: List[str] = field(
|
|
125
136
|
default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]
|
|
@@ -136,6 +147,7 @@ class SEQopts:
|
|
|
136
147
|
trial_include: bool = True
|
|
137
148
|
visit_colname: str = None
|
|
138
149
|
weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
|
|
150
|
+
weight_fit_method: Literal["newton", "bfgs", "lbfgs", "nm"] = "newton"
|
|
139
151
|
weight_min: float = 0.0
|
|
140
152
|
weight_max: float = None
|
|
141
153
|
weight_lag_condition: bool = True
|
|
@@ -195,3 +207,6 @@ class SEQopts:
|
|
|
195
207
|
attr = getattr(self, i)
|
|
196
208
|
if attr is not None and not isinstance(attr, list):
|
|
197
209
|
setattr(self, i, "".join(attr.split()))
|
|
210
|
+
|
|
211
|
+
if self.offload:
|
|
212
|
+
os.makedirs(self.offload_dir, exist_ok=True)
|
|
@@ -12,15 +12,15 @@ from .analysis import (_calculate_hazard, _calculate_survival, _clamp,
|
|
|
12
12
|
_subgroup_fit)
|
|
13
13
|
from .error import _data_checker, _param_checker
|
|
14
14
|
from .expansion import _binder, _diagnostics, _dynamic, _random_selection
|
|
15
|
-
from .helpers import _col_string, _format_time, bootstrap_loop
|
|
15
|
+
from .helpers import Offloader, _col_string, _format_time, bootstrap_loop
|
|
16
16
|
from .initialization import (_cense_denominator, _cense_numerator,
|
|
17
17
|
_denominator, _numerator, _outcome)
|
|
18
18
|
from .plot import _survival_plot
|
|
19
19
|
from .SEQopts import SEQopts
|
|
20
20
|
from .SEQoutput import SEQoutput
|
|
21
21
|
from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
|
|
22
|
-
_fit_visit,
|
|
23
|
-
_weight_setup, _weight_stats)
|
|
22
|
+
_fit_visit, _offload_weights, _weight_bind,
|
|
23
|
+
_weight_predict, _weight_setup, _weight_stats)
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
class SEQuential:
|
|
@@ -84,6 +84,8 @@ class SEQuential:
|
|
|
84
84
|
np.random.RandomState(self.seed) if self.seed is not None else np.random
|
|
85
85
|
)
|
|
86
86
|
|
|
87
|
+
self._offloader = Offloader(enabled=self.offload, dir=self.offload_dir)
|
|
88
|
+
|
|
87
89
|
if self.covariates is None:
|
|
88
90
|
self.covariates = _outcome(self)
|
|
89
91
|
|
|
@@ -201,6 +203,9 @@ class SEQuential:
|
|
|
201
203
|
raise ValueError(
|
|
202
204
|
"Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
|
|
203
205
|
)
|
|
206
|
+
boot_idx = None
|
|
207
|
+
if hasattr(self, "_current_boot_idx"):
|
|
208
|
+
boot_idx = self._current_boot_idx
|
|
204
209
|
|
|
205
210
|
if self.weighted:
|
|
206
211
|
WDT = _weight_setup(self)
|
|
@@ -217,6 +222,9 @@ class SEQuential:
|
|
|
217
222
|
_fit_numerator(self, WDT)
|
|
218
223
|
_fit_denominator(self, WDT)
|
|
219
224
|
|
|
225
|
+
if self.offload:
|
|
226
|
+
_offload_weights(self, boot_idx)
|
|
227
|
+
|
|
220
228
|
WDT = pl.from_pandas(WDT)
|
|
221
229
|
WDT = _weight_predict(self, WDT)
|
|
222
230
|
_weight_bind(self, WDT)
|
|
@@ -244,6 +252,11 @@ class SEQuential:
|
|
|
244
252
|
self.weighted,
|
|
245
253
|
"weight",
|
|
246
254
|
)
|
|
255
|
+
if self.offload:
|
|
256
|
+
offloaded_models = {}
|
|
257
|
+
for key, model in models.items():
|
|
258
|
+
offloaded_models[key] = self._offloader.save_model(model, key, boot_idx)
|
|
259
|
+
return offloaded_models
|
|
247
260
|
return models
|
|
248
261
|
|
|
249
262
|
def survival(self, **kwargs) -> None:
|
|
@@ -342,7 +355,7 @@ class SEQuential:
|
|
|
342
355
|
}
|
|
343
356
|
|
|
344
357
|
if self.compevent_colname is not None:
|
|
345
|
-
compevent_models = [model["compevent"] for model in self.
|
|
358
|
+
compevent_models = [model["compevent"] for model in self.outcome_model]
|
|
346
359
|
else:
|
|
347
360
|
compevent_models = None
|
|
348
361
|
|
|
@@ -4,6 +4,8 @@ import numpy as np
|
|
|
4
4
|
import polars as pl
|
|
5
5
|
from lifelines import CoxPHFitter
|
|
6
6
|
|
|
7
|
+
from ..helpers._predict_model import _safe_predict
|
|
8
|
+
|
|
7
9
|
|
|
8
10
|
def _calculate_hazard(self):
|
|
9
11
|
if self.subgroup_colname is None:
|
|
@@ -93,8 +95,10 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
|
|
|
93
95
|
else:
|
|
94
96
|
model_dict = self.outcome_model[boot_idx]
|
|
95
97
|
|
|
96
|
-
outcome_model = model_dict["outcome"]
|
|
97
|
-
ce_model =
|
|
98
|
+
outcome_model = self._offloader.load_model(model_dict["outcome"])
|
|
99
|
+
ce_model = None
|
|
100
|
+
if self.compevent_colname and "compevent" in model_dict:
|
|
101
|
+
ce_model = self._offloader.load_model(model_dict["compevent"])
|
|
98
102
|
|
|
99
103
|
all_treatments = []
|
|
100
104
|
for val in self.treatment_level:
|
|
@@ -103,13 +107,14 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
|
|
|
103
107
|
)
|
|
104
108
|
|
|
105
109
|
tmp_pd = tmp.to_pandas()
|
|
106
|
-
outcome_prob = outcome_model
|
|
110
|
+
outcome_prob = _safe_predict(outcome_model, tmp_pd)
|
|
107
111
|
outcome_sim = rng.binomial(1, outcome_prob)
|
|
108
112
|
|
|
109
113
|
tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)])
|
|
110
114
|
|
|
111
115
|
if ce_model is not None:
|
|
112
|
-
|
|
116
|
+
ce_tmp_pd = tmp.to_pandas()
|
|
117
|
+
ce_prob = _safe_predict(ce_model, ce_tmp_pd)
|
|
113
118
|
ce_sim = rng.binomial(1, ce_prob)
|
|
114
119
|
tmp = tmp.with_columns([pl.Series("ce", ce_sim)])
|
|
115
120
|
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import polars as pl
|
|
2
|
+
from scipy import stats
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None):
    """
    Build the Risk Difference and Risk Ratio tables for one comparison frame.

    ``comp`` is expected to carry ``risk_x`` and ``risk_y`` columns and, when
    ``has_bootstrap`` is true, ``se_x`` and ``se_y`` as well. With bootstrap
    standard errors available, each table also gets lower/upper interval
    columns built from the multiplier ``z``.

    :param comp: Comparison dataframe holding the paired risks (and SEs)
    :param has_bootstrap: Whether standard-error columns are present
    :param z: Multiplier applied to the standard errors for the intervals
    :param group_cols: Leading columns to keep in front of ``A_x``/``A_y``
    :return: Tuple ``(rd_comp, rr_comp)``
    """
    leading = [] if group_cols is None else group_cols

    risk_gap = pl.col("risk_x") - pl.col("risk_y")
    risk_quotient = pl.col("risk_x") / pl.col("risk_y")

    def _finish(frame, scratch_cols, value_cols):
        # Drop the working columns, then keep only the reporting columns
        # (in reporting order) that are actually present.
        trimmed = frame.drop(scratch_cols)
        wanted = leading + ["A_x", "A_y"] + value_cols
        return trimmed.select([name for name in wanted if name in trimmed.columns])

    if not has_bootstrap:
        scratch_cols = ["risk_x", "risk_y"]
        rd_comp = _finish(
            comp.with_columns(risk_gap.alias("Risk Difference")),
            scratch_cols,
            ["Risk Difference"],
        )
        rr_comp = _finish(
            comp.with_columns(risk_quotient.alias("Risk Ratio")),
            scratch_cols,
            ["Risk Ratio"],
        )
        return rd_comp, rr_comp

    scratch_cols = ["risk_x", "risk_y", "se_x", "se_y"]

    # Standard error of the difference: square root of the summed squared SEs.
    gap_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
    rd_frame = comp.with_columns(
        [
            risk_gap.alias("Risk Difference"),
            (risk_gap - z * gap_se).alias("RD 95% LCI"),
            (risk_gap + z * gap_se).alias("RD 95% UCI"),
        ]
    )
    rd_comp = _finish(
        rd_frame, scratch_cols, ["Risk Difference", "RD 95% LCI", "RD 95% UCI"]
    )

    # Log-scale standard error of the ratio, built from the relative SEs;
    # the interval bounds are obtained by exponentiating back.
    log_ratio_se = (
        (pl.col("se_x") / pl.col("risk_x")).pow(2)
        + (pl.col("se_y") / pl.col("risk_y")).pow(2)
    ).sqrt()
    rr_frame = comp.with_columns(
        [
            risk_quotient.alias("Risk Ratio"),
            (risk_quotient * (-z * log_ratio_se).exp()).alias("RR 95% LCI"),
            (risk_quotient * (z * log_ratio_se).exp()).alias("RR 95% UCI"),
        ]
    )
    rr_comp = _finish(
        rr_frame, scratch_cols, ["Risk Ratio", "RR 95% LCI", "RR 95% UCI"]
    )

    return rd_comp, rr_comp
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _risk_estimates(self):
    """
    Compare end-of-followup risk between every ordered pair of treatment levels.

    Rows of ``self.km_data`` at the maximum ``followup`` with
    ``estimate == "risk"`` are split by treatment level, paired up (per
    subgroup when ``self.subgroup_colname`` is set), and handed to
    ``_compute_rd_rr``.

    :return: Dict with ``"risk_difference"`` and ``"risk_ratio"`` dataframes
    """
    final_followup = self.km_data["followup"].max()
    end_risk = self.km_data.filter(
        (pl.col("followup") == final_followup) & (pl.col("estimate") == "risk")
    )

    group_cols = [self.subgroup_colname] if self.subgroup_colname else []
    has_bootstrap = self.bootstrap_nboot > 0

    z = None
    if has_bootstrap:
        alpha = 1 - self.bootstrap_CI
        z = stats.norm.ppf(1 - alpha / 2)

    def _attach(left, right):
        # Match on the subgroup column when there is one, otherwise pair
        # every row of ``left`` with every row of ``right``.
        if group_cols:
            return left.join(right, on=group_cols, how="left")
        return left.join(right, how="cross")

    # Slice the filtered data once per treatment level and reuse the slices
    # for every pairing below.
    pred_by_level = {}
    se_by_level = {}
    for level in self.treatment_level:
        level_rows = end_risk.filter(pl.col(self.treatment_col) == level)
        pred_by_level[level] = level_rows.select(group_cols + ["pred"])
        if has_bootstrap:
            se_by_level[level] = level_rows.select(group_cols + ["SE"])

    ordered_pairs = [
        (first, second)
        for first in self.treatment_level
        for second in self.treatment_level
        if not first == second
    ]

    rd_tables = []
    rr_tables = []
    for tx_x, tx_y in ordered_pairs:
        comp = _attach(
            pred_by_level[tx_x].rename({"pred": "risk_x"}),
            pred_by_level[tx_y].rename({"pred": "risk_y"}),
        ).with_columns([pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")])

        if has_bootstrap:
            comp = _attach(comp, se_by_level[tx_x].rename({"SE": "se_x"}))
            comp = _attach(comp, se_by_level[tx_y].rename({"SE": "se_y"}))

        rd_table, rr_table = _compute_rd_rr(comp, has_bootstrap, z, group_cols)
        rd_tables.append(rd_table)
        rr_tables.append(rr_table)

    return {
        "risk_difference": pl.concat(rd_tables) if rd_tables else pl.DataFrame(),
        "risk_ratio": pl.concat(rr_tables) if rr_tables else pl.DataFrame(),
    }
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import polars as pl
|
|
2
2
|
|
|
3
|
+
from ..helpers._predict_model import _safe_predict
|
|
4
|
+
|
|
3
5
|
|
|
4
6
|
def _get_outcome_predictions(self, TxDT, idx=None):
|
|
5
7
|
data = TxDT.to_pandas()
|
|
@@ -9,9 +11,12 @@ def _get_outcome_predictions(self, TxDT, idx=None):
|
|
|
9
11
|
|
|
10
12
|
for boot_model in self.outcome_model:
|
|
11
13
|
model_dict = boot_model[idx] if idx is not None else boot_model
|
|
12
|
-
|
|
14
|
+
outcome_model = self._offloader.load_model(model_dict["outcome"])
|
|
15
|
+
predictions["outcome"].append(_safe_predict(outcome_model, data.copy()))
|
|
16
|
+
|
|
13
17
|
if self.compevent_colname is not None:
|
|
14
|
-
|
|
18
|
+
compevent_model = self._offloader.load_model(model_dict["compevent"])
|
|
19
|
+
predictions["compevent"].append(_safe_predict(compevent_model, data.copy()))
|
|
15
20
|
|
|
16
21
|
return predictions
|
|
17
22
|
|
|
@@ -41,24 +46,20 @@ def _calculate_risk(self, data, idx=None, val=None):
|
|
|
41
46
|
lci = a / 2
|
|
42
47
|
uci = 1 - lci
|
|
43
48
|
|
|
49
|
+
# Pre-compute the followup range once (starts at 1, not 0)
|
|
50
|
+
followup_range = list(range(1, self.followup_max + 1))
|
|
51
|
+
|
|
44
52
|
SDT = (
|
|
45
53
|
data.with_columns(
|
|
46
|
-
[
|
|
47
|
-
(
|
|
48
|
-
pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)
|
|
49
|
-
).alias("TID")
|
|
50
|
-
]
|
|
54
|
+
[pl.concat_str([pl.col(self.id_col), pl.col("trial")]).alias("TID")]
|
|
51
55
|
)
|
|
52
56
|
.group_by("TID")
|
|
53
57
|
.first()
|
|
54
58
|
.drop(["followup", f"followup{self.indicator_squared}"])
|
|
55
|
-
.with_columns([pl.lit(
|
|
59
|
+
.with_columns([pl.lit(followup_range).alias("followup")])
|
|
56
60
|
.explode("followup")
|
|
57
61
|
.with_columns(
|
|
58
|
-
[
|
|
59
|
-
(pl.col("followup") + 1).alias("followup"),
|
|
60
|
-
(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
|
|
61
|
-
]
|
|
62
|
+
[(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
|
|
62
63
|
)
|
|
63
64
|
).sort([self.id_col, "trial", "followup"])
|
|
64
65
|
|
|
@@ -13,17 +13,10 @@ def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.in
|
|
|
13
13
|
.with_columns([pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")])
|
|
14
14
|
.with_columns(
|
|
15
15
|
[
|
|
16
|
-
pl.
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
]
|
|
21
|
-
)
|
|
22
|
-
.map_elements(
|
|
23
|
-
lambda x: list(range(x[time_col], x["max_time"] + 1)),
|
|
24
|
-
return_dtype=pl.List(pl.Int64),
|
|
25
|
-
)
|
|
26
|
-
.alias("period")
|
|
16
|
+
pl.int_ranges(
|
|
17
|
+
pl.col(time_col),
|
|
18
|
+
pl.col(time_col).max().over(id_col) + 1,
|
|
19
|
+
).alias("period")
|
|
27
20
|
]
|
|
28
21
|
)
|
|
29
22
|
.explode("period")
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from ._bootstrap import bootstrap_loop as bootstrap_loop
|
|
2
2
|
from ._col_string import _col_string as _col_string
|
|
3
3
|
from ._format_time import _format_time as _format_time
|
|
4
|
+
from ._offloader import Offloader as Offloader
|
|
4
5
|
from ._output_files import _build_md as _build_md
|
|
5
6
|
from ._output_files import _build_pdf as _build_pdf
|
|
6
7
|
from ._pad import _pad as _pad
|
|
@@ -35,11 +35,28 @@ def _prepare_boot_data(self, data, boot_id):
|
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs):
|
|
38
|
-
|
|
38
|
+
# Shallow copy the object and only deep copy mutable state that changes per-bootstrap
|
|
39
|
+
obj = copy.copy(obj)
|
|
40
|
+
# Deep copy only the mutable attributes that get modified during fitting
|
|
41
|
+
obj.outcome_model = []
|
|
42
|
+
obj.numerator_model = (
|
|
43
|
+
copy.copy(obj.numerator_model)
|
|
44
|
+
if hasattr(obj, "numerator_model") and obj.numerator_model
|
|
45
|
+
else []
|
|
46
|
+
)
|
|
47
|
+
obj.denominator_model = (
|
|
48
|
+
copy.copy(obj.denominator_model)
|
|
49
|
+
if hasattr(obj, "denominator_model") and obj.denominator_model
|
|
50
|
+
else []
|
|
51
|
+
)
|
|
52
|
+
|
|
39
53
|
obj._rng = (
|
|
40
54
|
np.random.RandomState(seed + i) if seed is not None else np.random.RandomState()
|
|
41
55
|
)
|
|
56
|
+
original_DT = obj._offloader.load_dataframe(original_DT)
|
|
42
57
|
obj.DT = _prepare_boot_data(obj, original_DT, i)
|
|
58
|
+
del original_DT
|
|
59
|
+
obj._current_boot_idx = i + 1
|
|
43
60
|
|
|
44
61
|
# Disable bootstrapping to prevent recursion
|
|
45
62
|
obj.bootstrap_nboot = 0
|
|
@@ -60,6 +77,7 @@ def bootstrap_loop(method):
|
|
|
60
77
|
results = []
|
|
61
78
|
original_DT = self.DT
|
|
62
79
|
|
|
80
|
+
self._current_boot_idx = None
|
|
63
81
|
full = method(self, *args, **kwargs)
|
|
64
82
|
results.append(full)
|
|
65
83
|
|
|
@@ -71,9 +89,12 @@ def bootstrap_loop(method):
|
|
|
71
89
|
seed = getattr(self, "seed", None)
|
|
72
90
|
method_name = method.__name__
|
|
73
91
|
|
|
92
|
+
original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
|
|
93
|
+
|
|
74
94
|
if getattr(self, "parallel", False):
|
|
75
95
|
original_rng = getattr(self, "_rng", None)
|
|
76
96
|
self._rng = None
|
|
97
|
+
self.DT = None
|
|
77
98
|
|
|
78
99
|
with ProcessPoolExecutor(max_workers=ncores) as executor:
|
|
79
100
|
futures = [
|
|
@@ -81,7 +102,7 @@ def bootstrap_loop(method):
|
|
|
81
102
|
_bootstrap_worker,
|
|
82
103
|
self,
|
|
83
104
|
method_name,
|
|
84
|
-
|
|
105
|
+
original_DT_ref,
|
|
85
106
|
i,
|
|
86
107
|
seed,
|
|
87
108
|
args,
|
|
@@ -95,13 +116,27 @@ def bootstrap_loop(method):
|
|
|
95
116
|
results.append(j.result())
|
|
96
117
|
|
|
97
118
|
self._rng = original_rng
|
|
119
|
+
self.DT = self._offloader.load_dataframe(original_DT_ref)
|
|
98
120
|
else:
|
|
121
|
+
# Keep original data in memory if offloading is disabled to avoid unnecessary I/O
|
|
122
|
+
if self._offloader.enabled:
|
|
123
|
+
original_DT_ref = self._offloader.save_dataframe(original_DT, "_DT")
|
|
124
|
+
del original_DT
|
|
125
|
+
else:
|
|
126
|
+
original_DT_ref = original_DT
|
|
127
|
+
|
|
99
128
|
for i in tqdm(range(nboot), desc="Bootstrapping..."):
|
|
100
|
-
self.
|
|
129
|
+
self._current_boot_idx = i + 1
|
|
130
|
+
tmp = self._offloader.load_dataframe(original_DT_ref)
|
|
131
|
+
self.DT = _prepare_boot_data(self, tmp, i)
|
|
132
|
+
if self._offloader.enabled:
|
|
133
|
+
del tmp
|
|
134
|
+
self.bootstrap_nboot = 0
|
|
101
135
|
boot_fit = method(self, *args, **kwargs)
|
|
102
136
|
results.append(boot_fit)
|
|
103
137
|
|
|
104
|
-
|
|
138
|
+
self.bootstrap_nboot = nboot
|
|
139
|
+
self.DT = self._offloader.load_dataframe(original_DT_ref)
|
|
105
140
|
|
|
106
141
|
end = time.perf_counter()
|
|
107
142
|
self._model_time = _format_time(start, end)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
def _fix_categories_for_predict(model, newdata):
|
|
2
|
+
"""
|
|
3
|
+
Fix categorical column ordering in newdata to match what the model expects.
|
|
4
|
+
"""
|
|
5
|
+
if (
|
|
6
|
+
hasattr(model, "model")
|
|
7
|
+
and hasattr(model.model, "data")
|
|
8
|
+
and hasattr(model.model.data, "design_info")
|
|
9
|
+
):
|
|
10
|
+
design_info = model.model.data.design_info
|
|
11
|
+
for factor, factor_info in design_info.factor_infos.items():
|
|
12
|
+
if factor_info.type == "categorical":
|
|
13
|
+
col_name = factor.name()
|
|
14
|
+
if col_name in newdata.columns:
|
|
15
|
+
expected_categories = list(factor_info.categories)
|
|
16
|
+
newdata[col_name] = newdata[col_name].astype(str)
|
|
17
|
+
newdata[col_name] = newdata[col_name].astype("category")
|
|
18
|
+
newdata[col_name] = newdata[col_name].cat.set_categories(
|
|
19
|
+
expected_categories
|
|
20
|
+
)
|
|
21
|
+
return newdata
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Optional, Union
|
|
4
|
+
|
|
5
|
+
import joblib
|
|
6
|
+
import polars as pl
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Offloader:
    """Manages disk-based storage for models and intermediate data.

    When ``enabled`` is false every save/load method is a pass-through that
    returns its argument unchanged. When enabled, models are written with
    joblib and dataframes as parquet files under ``dir``, and the returned
    string path serves as the reference to load from later.
    """

    def __init__(self, enabled: bool, dir: str, compression: int = 3):
        """
        :param enabled: Whether objects are actually written to disk
        :param dir: Directory that receives the offloaded files
        :param compression: joblib compression level used for models
        """
        self.enabled = enabled
        self.dir = Path(dir)
        self.compression = compression
        # Create a cached loader bound to this instance
        self._init_cache()

    def _init_cache(self):
        """Initialize the LRU cache for model loading."""
        self._cached_load = lru_cache(maxsize=32)(self._load_from_disk)

    def __getstate__(self):
        """Prepare state for pickling - exclude the unpicklable cache."""
        state = self.__dict__.copy()
        # Remove the cache wrapper which can't be pickled; tolerate it
        # already being absent.
        state.pop("_cached_load", None)
        return state

    def __setstate__(self, state):
        """Restore state after unpickling - recreate the cache."""
        self.__dict__.update(state)
        # Recreate the cache after unpickling
        self._init_cache()

    def _target_path(self, filename: str) -> Path:
        """Return the path for ``filename`` under ``dir``, creating ``dir`` if needed."""
        self.dir.mkdir(parents=True, exist_ok=True)
        return self.dir / filename

    def save_model(
        self, model: Any, name: str, boot_idx: Optional[int] = None
    ) -> Union[Any, str]:
        """Save a fitted model to disk and return a reference.

        :param model: Object to persist
        :param name: Base file name (without extension)
        :param boot_idx: Optional index appended to the name as ``_boot{idx}``
        :return: ``model`` itself when offloading is disabled, otherwise the
            path of the written file as a string
        """
        if not self.enabled:
            return model

        filename = (
            f"{name}_boot{boot_idx}.pkl" if boot_idx is not None else f"{name}.pkl"
        )
        filepath = self._target_path(filename)

        joblib.dump(model, filepath, compress=self.compression)
        # The load cache is keyed by path only; writing may have replaced a
        # file that was already loaded, so drop cached entries to avoid
        # handing back a stale model.
        self._cached_load.cache_clear()

        return str(filepath)

    def _load_from_disk(self, filepath: str) -> Any:
        """Internal method to load a model from disk (cached)."""
        return joblib.load(filepath)

    def load_model(self, ref: Union[Any, str]) -> Any:
        """Load a model, using cache for repeated loads of the same file.

        :param ref: Either an in-memory model or a string path returned by
            ``save_model``
        :return: ``ref`` unchanged when offloading is disabled or ``ref`` is
            not a string; otherwise the object loaded from that path
        """
        if not self.enabled or not isinstance(ref, str):
            return ref

        return self._cached_load(ref)

    def clear_cache(self) -> None:
        """Clear the model loading cache. Call between bootstrap iterations if needed."""
        self._cached_load.cache_clear()

    def save_dataframe(
        self, df: "pl.DataFrame", name: str
    ) -> Union["pl.DataFrame", str]:
        """Write a dataframe to ``{name}.parquet`` and return a reference.

        :param df: Dataframe to persist
        :param name: Base file name (without extension)
        :return: ``df`` itself when offloading is disabled, otherwise the
            path of the written file as a string
        """
        if not self.enabled:
            return df

        filepath = self._target_path(f"{name}.parquet")

        df.write_parquet(filepath, compression="zstd")

        return str(filepath)

    def load_dataframe(self, ref: Union["pl.DataFrame", str]) -> "pl.DataFrame":
        """Load a dataframe from a reference returned by ``save_dataframe``.

        :param ref: Either an in-memory dataframe or a string path
        :return: ``ref`` unchanged when offloading is disabled or ``ref`` is
            not a string; otherwise the dataframe read from that path
        """
        if not self.enabled or not isinstance(ref, str):
            return ref

        return pl.read_parquet(ref)
|
|
@@ -121,8 +121,7 @@ def _build_pdf(md_content: str, filename: str, img_path: str = None) -> None:
|
|
|
121
121
|
f'src="{img_name}"', f'src="file://{img_path}"'
|
|
122
122
|
)
|
|
123
123
|
|
|
124
|
-
css = CSS(
|
|
125
|
-
string="""
|
|
124
|
+
css = CSS(string="""
|
|
126
125
|
body {
|
|
127
126
|
font-family: Arial, sans-serif;
|
|
128
127
|
font-size: 11pt;
|
|
@@ -153,8 +152,7 @@ def _build_pdf(md_content: str, filename: str, img_path: str = None) -> None:
|
|
|
153
152
|
}
|
|
154
153
|
code { font-family: 'Courier New', monospace; }
|
|
155
154
|
img { max-width: 100%; height: auto; }
|
|
156
|
-
"""
|
|
157
|
-
)
|
|
155
|
+
""")
|
|
158
156
|
|
|
159
157
|
full_html = f"""
|
|
160
158
|
<!DOCTYPE html>
|