pySEQTarget 0.10.0__tar.gz → 0.12.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/PKG-INFO +11 -5
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/README.md +5 -4
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/SEQopts.py +17 -2
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/SEQuential.py +23 -10
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/__init__.py +1 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_hazard.py +9 -4
- pyseqtarget-0.12.0/pySEQTarget/analysis/_risk_estimates.py +138 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_survival_pred.py +21 -12
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/error/__init__.py +1 -1
- pyseqtarget-0.10.0/pySEQTarget/error/_datachecker.py → pyseqtarget-0.12.0/pySEQTarget/error/_data_checker.py +1 -1
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_mapper.py +4 -11
- pyseqtarget-0.12.0/pySEQTarget/expansion/_selection.py +44 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/__init__.py +1 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_bootstrap.py +41 -5
- pyseqtarget-0.12.0/pySEQTarget/helpers/_fix_categories.py +21 -0
- pyseqtarget-0.12.0/pySEQTarget/helpers/_offloader.py +82 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_output_files.py +2 -4
- pyseqtarget-0.12.0/pySEQTarget/helpers/_predict_model.py +57 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/__init__.py +1 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_bind.py +3 -0
- pyseqtarget-0.12.0/pySEQTarget/weighting/_weight_fit.py +137 -0
- pyseqtarget-0.12.0/pySEQTarget/weighting/_weight_offload.py +19 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_pred.py +74 -52
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/PKG-INFO +11 -5
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/SOURCES.txt +5 -1
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/requires.txt +6 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pyproject.toml +10 -8
- pyseqtarget-0.12.0/tests/test_offload.py +41 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_survival.py +13 -0
- pyseqtarget-0.10.0/pySEQTarget/analysis/_risk_estimates.py +0 -136
- pyseqtarget-0.10.0/pySEQTarget/expansion/_selection.py +0 -31
- pyseqtarget-0.10.0/pySEQTarget/helpers/_predict_model.py +0 -9
- pyseqtarget-0.10.0/pySEQTarget/weighting/_weight_fit.py +0 -99
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/LICENSE +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/SEQoutput.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/__init__.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_outcome_fit.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/analysis/_subgroup_fit.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/data/__init__.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/error/_param_checker.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/__init__.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_binder.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_diagnostics.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/expansion/_dynamic.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_col_string.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_format_time.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_pad.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/helpers/_prepare_data.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/__init__.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_censoring.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_denominator.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_numerator.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/initialization/_outcome.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/plot/__init__.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/plot/_survival_plot.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_data.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget/weighting/_weight_stats.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/dependency_links.txt +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/pySEQTarget.egg-info/top_level.txt +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/setup.cfg +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_accessor.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_coefficients.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_covariates.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_followup_options.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_hazard.py +0 -0
- {pyseqtarget-0.10.0 → pyseqtarget-0.12.0}/tests/test_parallel.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: pySEQTarget
|
|
3
|
-
Version: 0.10.0
|
|
3
|
+
Version: 0.12.0
|
|
4
4
|
Summary: Sequentially Nested Target Trial Emulation
|
|
5
5
|
Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernan <mhernan@hsph.harvard.edu>
|
|
6
6
|
Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>
|
|
@@ -33,6 +33,11 @@ Requires-Dist: statsmodels
|
|
|
33
33
|
Requires-Dist: matplotlib
|
|
34
34
|
Requires-Dist: pyarrow
|
|
35
35
|
Requires-Dist: lifelines
|
|
36
|
+
Requires-Dist: joblib
|
|
37
|
+
Provides-Extra: output
|
|
38
|
+
Requires-Dist: markdown; extra == "output"
|
|
39
|
+
Requires-Dist: weasyprint; extra == "output"
|
|
40
|
+
Requires-Dist: tabulate; extra == "output"
|
|
36
41
|
Dynamic: license-file
|
|
37
42
|
|
|
38
43
|
# pySEQTarget - Sequentially Nested Target Trial Emulation
|
|
@@ -68,8 +73,9 @@ From the user side, this amounts to creating a dataclass, `SEQopts`, and then fe
|
|
|
68
73
|
```python
|
|
69
74
|
import polars as pl
|
|
70
75
|
from pySEQTarget import SEQuential, SEQopts
|
|
76
|
+
from pySEQTarget.data import load_data
|
|
71
77
|
|
|
72
|
-
data =
|
|
78
|
+
data = load_data("SEQdata")
|
|
73
79
|
options = SEQopts(km_curves = True)
|
|
74
80
|
|
|
75
81
|
# Initiate the class
|
|
@@ -77,17 +83,18 @@ model = SEQuential(data,
|
|
|
77
83
|
id_col = "ID",
|
|
78
84
|
time_col = "time",
|
|
79
85
|
eligible_col = "eligible",
|
|
86
|
+
treatment_col = "tx_init",
|
|
87
|
+
outcome_col = "outcome",
|
|
80
88
|
time_varying_cols = ["N", "L", "P"],
|
|
81
89
|
fixed_cols = ["sex"],
|
|
82
90
|
method = "ITT",
|
|
83
|
-
|
|
91
|
+
parameters = options)
|
|
84
92
|
model.expand() # Construct the nested structure
|
|
85
93
|
model.bootstrap(bootstrap_nboot = 20) # Run 20 bootstrap samples
|
|
86
94
|
model.fit() # Fit the model
|
|
87
95
|
model.survival() # Create survival curves
|
|
88
96
|
model.plot() # Create and show a plot of the survival curves
|
|
89
97
|
model.collect() # Collection of important information
|
|
90
|
-
|
|
91
98
|
```
|
|
92
99
|
|
|
93
100
|
## Assumptions
|
|
@@ -95,4 +102,3 @@ There are several key assumptions in this package -
|
|
|
95
102
|
1. User provided `time_col` begins at 0 per unique `id_col`, we also assume this column contains only integers and continues by 1 for every time step, e.g. (0, 1, 2, 3, 4, ...) is allowed and (0, 1, 2, 2.5, ...) or (0, 1, 4, 5) are not
|
|
96
103
|
1. Provided `time_col` entries may be out of order at intake as a sort is enforced at expansion.
|
|
97
104
|
2. `eligible_col` and the elements of `excused_colnames` are flag variables that, once they take the value 1, remain 1 at every later `time_col`.
|
|
98
|
-
|
|
@@ -31,8 +31,9 @@ From the user side, this amounts to creating a dataclass, `SEQopts`, and then fe
|
|
|
31
31
|
```python
|
|
32
32
|
import polars as pl
|
|
33
33
|
from pySEQTarget import SEQuential, SEQopts
|
|
34
|
+
from pySEQTarget.data import load_data
|
|
34
35
|
|
|
35
|
-
data =
|
|
36
|
+
data = load_data("SEQdata")
|
|
36
37
|
options = SEQopts(km_curves = True)
|
|
37
38
|
|
|
38
39
|
# Initiate the class
|
|
@@ -40,17 +41,18 @@ model = SEQuential(data,
|
|
|
40
41
|
id_col = "ID",
|
|
41
42
|
time_col = "time",
|
|
42
43
|
eligible_col = "eligible",
|
|
44
|
+
treatment_col = "tx_init",
|
|
45
|
+
outcome_col = "outcome",
|
|
43
46
|
time_varying_cols = ["N", "L", "P"],
|
|
44
47
|
fixed_cols = ["sex"],
|
|
45
48
|
method = "ITT",
|
|
46
|
-
|
|
49
|
+
parameters = options)
|
|
47
50
|
model.expand() # Construct the nested structure
|
|
48
51
|
model.bootstrap(bootstrap_nboot = 20) # Run 20 bootstrap samples
|
|
49
52
|
model.fit() # Fit the model
|
|
50
53
|
model.survival() # Create survival curves
|
|
51
54
|
model.plot() # Create and show a plot of the survival curves
|
|
52
55
|
model.collect() # Collection of important information
|
|
53
|
-
|
|
54
56
|
```
|
|
55
57
|
|
|
56
58
|
## Assumptions
|
|
@@ -58,4 +60,3 @@ There are several key assumptions in this package -
|
|
|
58
60
|
1. User provided `time_col` begins at 0 per unique `id_col`, we also assume this column contains only integers and continues by 1 for every time step, e.g. (0, 1, 2, 3, 4, ...) is allowed and (0, 1, 2, 2.5, ...) or (0, 1, 4, 5) are not
|
|
59
61
|
1. Provided `time_col` entries may be out of order at intake as a sort is enforced at expansion.
|
|
60
62
|
2. `eligible_col` and the elements of `excused_colnames` are flag variables that, once they take the value 1, remain 1 at every later `time_col`.
|
|
61
|
-
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import multiprocessing
|
|
2
|
+
import os
|
|
2
3
|
from dataclasses import dataclass, field
|
|
3
4
|
from typing import List, Literal, Optional
|
|
4
5
|
|
|
@@ -18,7 +19,7 @@ class SEQopts:
|
|
|
18
19
|
:type bootstrap_CI_method: str
|
|
19
20
|
:param cense_colname: Column name for censoring effect (LTFU, etc.)
|
|
20
21
|
:type cense_colname: str
|
|
21
|
-
:param cense_denominator: Override to specify denominator patsy formula for censoring models
|
|
22
|
+
:param cense_denominator: Override to specify denominator patsy formula for censoring models; "1" or "" indicate intercept only model
|
|
22
23
|
:type cense_denominator: Optional[str] or None
|
|
23
24
|
:param cense_numerator: Override to specify numerator patsy formula for censoring models
|
|
24
25
|
:type cense_numerator: Optional[str] or None
|
|
@@ -54,8 +55,12 @@ class SEQopts:
|
|
|
54
55
|
:type km_curves: bool
|
|
55
56
|
:param ncores: Number of cores to use if running in parallel
|
|
56
57
|
:type ncores: int
|
|
57
|
-
:param numerator: Override to specify the outcome patsy formula for numerator models
|
|
58
|
+
:param numerator: Override to specify the outcome patsy formula for numerator models; "1" or "" indicate intercept only model
|
|
58
59
|
:type numerator: str
|
|
60
|
+
:param offload: Boolean to offload intermediate model data to disk
|
|
61
|
+
:type offload: bool
|
|
62
|
+
:param offload_dir: Directory to offload intermediate model data
|
|
63
|
+
:type offload_dir: str
|
|
59
64
|
:param parallel: Boolean to run model fitting in parallel
|
|
60
65
|
:type parallel: bool
|
|
61
66
|
:param plot_colors: List of colors for KM plots, if applicable
|
|
@@ -80,8 +85,12 @@ class SEQopts:
|
|
|
80
85
|
:type treatment_level: List[int]
|
|
81
86
|
:param trial_include: Boolean to force trial values into model covariates
|
|
82
87
|
:type trial_include: bool
|
|
88
|
+
:param visit_colname: Column name specifying visit number
|
|
89
|
+
:type visit_colname: str
|
|
83
90
|
:param weight_eligible_colnames: List of column names of length treatment_level to identify which rows are eligible for weight fitting
|
|
84
91
|
:type weight_eligible_colnames: List[str]
|
|
92
|
+
:param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton"
|
|
93
|
+
:type weight_fit_method: str
|
|
85
94
|
:param weight_min: Minimum weight
|
|
86
95
|
:type weight_min: float
|
|
87
96
|
:param weight_max: Maximum weight
|
|
@@ -120,6 +129,8 @@ class SEQopts:
|
|
|
120
129
|
km_curves: bool = False
|
|
121
130
|
ncores: int = multiprocessing.cpu_count()
|
|
122
131
|
numerator: Optional[str] = None
|
|
132
|
+
offload: bool = False
|
|
133
|
+
offload_dir: str = "_seq_models"
|
|
123
134
|
parallel: bool = False
|
|
124
135
|
plot_colors: List[str] = field(
|
|
125
136
|
default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]
|
|
@@ -136,6 +147,7 @@ class SEQopts:
|
|
|
136
147
|
trial_include: bool = True
|
|
137
148
|
visit_colname: str = None
|
|
138
149
|
weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
|
|
150
|
+
weight_fit_method: Literal["newton", "bfgs", "lbfgs", "nm"] = "newton"
|
|
139
151
|
weight_min: float = 0.0
|
|
140
152
|
weight_max: float = None
|
|
141
153
|
weight_lag_condition: bool = True
|
|
@@ -195,3 +207,6 @@ class SEQopts:
|
|
|
195
207
|
attr = getattr(self, i)
|
|
196
208
|
if attr is not None and not isinstance(attr, list):
|
|
197
209
|
setattr(self, i, "".join(attr.split()))
|
|
210
|
+
|
|
211
|
+
if self.offload:
|
|
212
|
+
os.makedirs(self.offload_dir, exist_ok=True)
|
|
@@ -7,19 +7,20 @@ from typing import List, Literal, Optional
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import polars as pl
|
|
9
9
|
|
|
10
|
-
from .analysis import (_calculate_hazard, _calculate_survival,
|
|
11
|
-
_pred_risk, _risk_estimates,
|
|
12
|
-
|
|
10
|
+
from .analysis import (_calculate_hazard, _calculate_survival, _clamp,
|
|
11
|
+
_outcome_fit, _pred_risk, _risk_estimates,
|
|
12
|
+
_subgroup_fit)
|
|
13
|
+
from .error import _data_checker, _param_checker
|
|
13
14
|
from .expansion import _binder, _diagnostics, _dynamic, _random_selection
|
|
14
|
-
from .helpers import _col_string, _format_time, bootstrap_loop
|
|
15
|
+
from .helpers import Offloader, _col_string, _format_time, bootstrap_loop
|
|
15
16
|
from .initialization import (_cense_denominator, _cense_numerator,
|
|
16
17
|
_denominator, _numerator, _outcome)
|
|
17
18
|
from .plot import _survival_plot
|
|
18
19
|
from .SEQopts import SEQopts
|
|
19
20
|
from .SEQoutput import SEQoutput
|
|
20
21
|
from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
|
|
21
|
-
_fit_visit,
|
|
22
|
-
_weight_setup, _weight_stats)
|
|
22
|
+
_fit_visit, _offload_weights, _weight_bind,
|
|
23
|
+
_weight_predict, _weight_setup, _weight_stats)
|
|
23
24
|
|
|
24
25
|
|
|
25
26
|
class SEQuential:
|
|
@@ -83,6 +84,8 @@ class SEQuential:
|
|
|
83
84
|
np.random.RandomState(self.seed) if self.seed is not None else np.random
|
|
84
85
|
)
|
|
85
86
|
|
|
87
|
+
self._offloader = Offloader(enabled=self.offload, dir=self.offload_dir)
|
|
88
|
+
|
|
86
89
|
if self.covariates is None:
|
|
87
90
|
self.covariates = _outcome(self)
|
|
88
91
|
|
|
@@ -101,7 +104,7 @@ class SEQuential:
|
|
|
101
104
|
self.cense_denominator = _cense_denominator(self)
|
|
102
105
|
|
|
103
106
|
_param_checker(self)
|
|
104
|
-
|
|
107
|
+
_data_checker(self)
|
|
105
108
|
|
|
106
109
|
def expand(self) -> None:
|
|
107
110
|
"""
|
|
@@ -190,7 +193,6 @@ class SEQuential:
|
|
|
190
193
|
)
|
|
191
194
|
id_counts = Counter(sampled_IDs)
|
|
192
195
|
self._boot_samples.append(id_counts)
|
|
193
|
-
return self
|
|
194
196
|
|
|
195
197
|
@bootstrap_loop
|
|
196
198
|
def fit(self) -> None:
|
|
@@ -201,6 +203,9 @@ class SEQuential:
|
|
|
201
203
|
raise ValueError(
|
|
202
204
|
"Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
|
|
203
205
|
)
|
|
206
|
+
boot_idx = None
|
|
207
|
+
if hasattr(self, "_current_boot_idx"):
|
|
208
|
+
boot_idx = self._current_boot_idx
|
|
204
209
|
|
|
205
210
|
if self.weighted:
|
|
206
211
|
WDT = _weight_setup(self)
|
|
@@ -217,6 +222,9 @@ class SEQuential:
|
|
|
217
222
|
_fit_numerator(self, WDT)
|
|
218
223
|
_fit_denominator(self, WDT)
|
|
219
224
|
|
|
225
|
+
if self.offload:
|
|
226
|
+
_offload_weights(self, boot_idx)
|
|
227
|
+
|
|
220
228
|
WDT = pl.from_pandas(WDT)
|
|
221
229
|
WDT = _weight_predict(self, WDT)
|
|
222
230
|
_weight_bind(self, WDT)
|
|
@@ -244,6 +252,11 @@ class SEQuential:
|
|
|
244
252
|
self.weighted,
|
|
245
253
|
"weight",
|
|
246
254
|
)
|
|
255
|
+
if self.offload:
|
|
256
|
+
offloaded_models = {}
|
|
257
|
+
for key, model in models.items():
|
|
258
|
+
offloaded_models[key] = self._offloader.save_model(model, key, boot_idx)
|
|
259
|
+
return offloaded_models
|
|
247
260
|
return models
|
|
248
261
|
|
|
249
262
|
def survival(self, **kwargs) -> None:
|
|
@@ -266,7 +279,7 @@ class SEQuential:
|
|
|
266
279
|
|
|
267
280
|
risk_data = _pred_risk(self)
|
|
268
281
|
surv_data = _calculate_survival(self, risk_data)
|
|
269
|
-
self.km_data = pl.concat([risk_data, surv_data])
|
|
282
|
+
self.km_data = _clamp(pl.concat([risk_data, surv_data]))
|
|
270
283
|
self.risk_estimates = _risk_estimates(self)
|
|
271
284
|
|
|
272
285
|
end = time.perf_counter()
|
|
@@ -342,7 +355,7 @@ class SEQuential:
|
|
|
342
355
|
}
|
|
343
356
|
|
|
344
357
|
if self.compevent_colname is not None:
|
|
345
|
-
compevent_models = [model["compevent"] for model in self.
|
|
358
|
+
compevent_models = [model["compevent"] for model in self.outcome_model]
|
|
346
359
|
else:
|
|
347
360
|
compevent_models = None
|
|
348
361
|
|
|
@@ -3,6 +3,7 @@ from ._outcome_fit import _outcome_fit as _outcome_fit
|
|
|
3
3
|
from ._risk_estimates import _risk_estimates as _risk_estimates
|
|
4
4
|
from ._subgroup_fit import _subgroup_fit as _subgroup_fit
|
|
5
5
|
from ._survival_pred import _calculate_survival as _calculate_survival
|
|
6
|
+
from ._survival_pred import _clamp as _clamp
|
|
6
7
|
from ._survival_pred import \
|
|
7
8
|
_get_outcome_predictions as _get_outcome_predictions
|
|
8
9
|
from ._survival_pred import _pred_risk as _pred_risk
|
|
@@ -4,6 +4,8 @@ import numpy as np
|
|
|
4
4
|
import polars as pl
|
|
5
5
|
from lifelines import CoxPHFitter
|
|
6
6
|
|
|
7
|
+
from ..helpers._predict_model import _safe_predict
|
|
8
|
+
|
|
7
9
|
|
|
8
10
|
def _calculate_hazard(self):
|
|
9
11
|
if self.subgroup_colname is None:
|
|
@@ -93,8 +95,10 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
|
|
|
93
95
|
else:
|
|
94
96
|
model_dict = self.outcome_model[boot_idx]
|
|
95
97
|
|
|
96
|
-
outcome_model = model_dict["outcome"]
|
|
97
|
-
ce_model =
|
|
98
|
+
outcome_model = self._offloader.load_model(model_dict["outcome"])
|
|
99
|
+
ce_model = None
|
|
100
|
+
if self.compevent_colname and "compevent" in model_dict:
|
|
101
|
+
ce_model = self._offloader.load_model(model_dict["compevent"])
|
|
98
102
|
|
|
99
103
|
all_treatments = []
|
|
100
104
|
for val in self.treatment_level:
|
|
@@ -103,13 +107,14 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
|
|
|
103
107
|
)
|
|
104
108
|
|
|
105
109
|
tmp_pd = tmp.to_pandas()
|
|
106
|
-
outcome_prob = outcome_model
|
|
110
|
+
outcome_prob = _safe_predict(outcome_model, tmp_pd)
|
|
107
111
|
outcome_sim = rng.binomial(1, outcome_prob)
|
|
108
112
|
|
|
109
113
|
tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)])
|
|
110
114
|
|
|
111
115
|
if ce_model is not None:
|
|
112
|
-
|
|
116
|
+
ce_tmp_pd = tmp.to_pandas()
|
|
117
|
+
ce_prob = _safe_predict(ce_model, ce_tmp_pd)
|
|
113
118
|
ce_sim = rng.binomial(1, ce_prob)
|
|
114
119
|
tmp = tmp.with_columns([pl.Series("ce", ce_sim)])
|
|
115
120
|
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import polars as pl
|
|
2
|
+
from scipy import stats
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None):
    """
    Build the risk-difference and risk-ratio tables for one comparison frame.

    :param comp: frame holding ``risk_x`` and ``risk_y`` (plus ``se_x`` and
        ``se_y`` when ``has_bootstrap`` is true); ``A_x``, ``A_y`` and the
        grouping columns are carried through when present
    :param has_bootstrap: when true, lower/upper interval columns are added
    :param z: multiplier applied to the standard errors for the interval
        bounds; only used when ``has_bootstrap`` is true
    :param group_cols: column names placed first in both outputs; ``None``
        is treated as no grouping columns
    :return: tuple ``(rd_comp, rr_comp)``
    """
    lead_cols = ([] if group_cols is None else group_cols) + ["A_x", "A_y"]

    risk_x, risk_y = pl.col("risk_x"), pl.col("risk_y")
    difference = risk_x - risk_y
    ratio = risk_x / risk_y

    def _tidy(frame, inputs, measures):
        # Remove the raw inputs, then keep whichever expected columns exist,
        # in the fixed output order.
        frame = frame.drop(inputs)
        wanted = lead_cols + measures
        return frame.select([name for name in wanted if name in frame.columns])

    if not has_bootstrap:
        # Point estimates only: no standard errors are available.
        point_inputs = ["risk_x", "risk_y"]
        rd_comp = _tidy(
            comp.with_columns(difference.alias("Risk Difference")),
            point_inputs,
            ["Risk Difference"],
        )
        rr_comp = _tidy(
            comp.with_columns(ratio.alias("Risk Ratio")),
            point_inputs,
            ["Risk Ratio"],
        )
        return rd_comp, rr_comp

    se_x, se_y = pl.col("se_x"), pl.col("se_y")
    boot_inputs = ["risk_x", "risk_y", "se_x", "se_y"]

    # Risk difference: the two standard errors are combined in quadrature.
    rd_se = (se_x.pow(2) + se_y.pow(2)).sqrt()
    rd_comp = _tidy(
        comp.with_columns(
            [
                difference.alias("Risk Difference"),
                (difference - z * rd_se).alias("RD 95% LCI"),
                (difference + z * rd_se).alias("RD 95% UCI"),
            ]
        ),
        boot_inputs,
        ["Risk Difference", "RD 95% LCI", "RD 95% UCI"],
    )

    # Risk ratio: bounds are formed multiplicatively from exp(+/- z * SE),
    # where the SE combines the relative errors se / risk of both arms.
    rr_log_se = ((se_x / risk_x).pow(2) + (se_y / risk_y).pow(2)).sqrt()
    rr_comp = _tidy(
        comp.with_columns(
            [
                ratio.alias("Risk Ratio"),
                (ratio * (-z * rr_log_se).exp()).alias("RR 95% LCI"),
                (ratio * (z * rr_log_se).exp()).alias("RR 95% UCI"),
            ]
        ),
        boot_inputs,
        ["Risk Ratio", "RR 95% LCI", "RR 95% UCI"],
    )

    return rd_comp, rr_comp
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _risk_estimates(self):
    """
    Compare risks at the final follow-up between every ordered pair of
    distinct treatment levels.

    Reads ``self.km_data`` (rows with ``estimate == "risk"`` at the largest
    ``followup``), ``self.treatment_level``, ``self.treatment_col``,
    ``self.subgroup_colname``, ``self.bootstrap_nboot`` and, when bootstrap
    samples exist, ``self.bootstrap_CI``.

    :return: dict with keys ``"risk_difference"`` and ``"risk_ratio"``, each
        the concatenation of the per-pair tables from ``_compute_rd_rr``
        (an empty frame when there are no pairs)
    """
    final_followup = self.km_data["followup"].max()
    final_risk = self.km_data.filter(
        (pl.col("followup") == final_followup) & (pl.col("estimate") == "risk")
    )

    by = [self.subgroup_colname] if self.subgroup_colname else []
    boot = self.bootstrap_nboot > 0
    # Two-sided normal quantile for the requested interval level.
    z = stats.norm.ppf(1 - (1 - self.bootstrap_CI) / 2) if boot else None

    def _attach(frame, extra):
        # Match on the subgroup columns when there are any; otherwise the
        # frames have no key to join on and are cross-joined.
        if by:
            return frame.join(extra, on=by, how="left")
        return frame.join(extra, how="cross")

    # Slice the final-followup risks once per treatment level so the pair
    # loop below does not have to filter again.
    value_cols = ["pred", "SE"] if boot else ["pred"]
    per_level = {}
    for level in self.treatment_level:
        rows = final_risk.filter(pl.col(self.treatment_col) == level)
        per_level[level] = {name: rows.select(by + [name]) for name in value_cols}

    rd_parts = []
    rr_parts = []
    ordered_pairs = [
        (tx_x, tx_y)
        for tx_x in self.treatment_level
        for tx_y in self.treatment_level
        if tx_x != tx_y
    ]
    for tx_x, tx_y in ordered_pairs:
        comp = _attach(
            per_level[tx_x]["pred"].rename({"pred": "risk_x"}),
            per_level[tx_y]["pred"].rename({"pred": "risk_y"}),
        ).with_columns([pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")])

        if boot:
            for se_name, level in (("se_x", tx_x), ("se_y", tx_y)):
                comp = _attach(comp, per_level[level]["SE"].rename({"SE": se_name}))

        rd_comp, rr_comp = _compute_rd_rr(comp, boot, z, by)
        rd_parts.append(rd_comp)
        rr_parts.append(rr_comp)

    return {
        "risk_difference": pl.concat(rd_parts) if rd_parts else pl.DataFrame(),
        "risk_ratio": pl.concat(rr_parts) if rr_parts else pl.DataFrame(),
    }
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
import polars as pl
|
|
2
2
|
|
|
3
|
+
from ..helpers._predict_model import _safe_predict
|
|
4
|
+
|
|
3
5
|
|
|
4
6
|
def _get_outcome_predictions(self, TxDT, idx=None):
|
|
5
7
|
data = TxDT.to_pandas()
|
|
@@ -9,9 +11,12 @@ def _get_outcome_predictions(self, TxDT, idx=None):
|
|
|
9
11
|
|
|
10
12
|
for boot_model in self.outcome_model:
|
|
11
13
|
model_dict = boot_model[idx] if idx is not None else boot_model
|
|
12
|
-
|
|
14
|
+
outcome_model = self._offloader.load_model(model_dict["outcome"])
|
|
15
|
+
predictions["outcome"].append(_safe_predict(outcome_model, data.copy()))
|
|
16
|
+
|
|
13
17
|
if self.compevent_colname is not None:
|
|
14
|
-
|
|
18
|
+
compevent_model = self._offloader.load_model(model_dict["compevent"])
|
|
19
|
+
predictions["compevent"].append(_safe_predict(compevent_model, data.copy()))
|
|
15
20
|
|
|
16
21
|
return predictions
|
|
17
22
|
|
|
@@ -41,24 +46,20 @@ def _calculate_risk(self, data, idx=None, val=None):
|
|
|
41
46
|
lci = a / 2
|
|
42
47
|
uci = 1 - lci
|
|
43
48
|
|
|
49
|
+
# Pre-compute the followup range once (starts at 1, not 0)
|
|
50
|
+
followup_range = list(range(1, self.followup_max + 1))
|
|
51
|
+
|
|
44
52
|
SDT = (
|
|
45
53
|
data.with_columns(
|
|
46
|
-
[
|
|
47
|
-
(
|
|
48
|
-
pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)
|
|
49
|
-
).alias("TID")
|
|
50
|
-
]
|
|
54
|
+
[pl.concat_str([pl.col(self.id_col), pl.col("trial")]).alias("TID")]
|
|
51
55
|
)
|
|
52
56
|
.group_by("TID")
|
|
53
57
|
.first()
|
|
54
58
|
.drop(["followup", f"followup{self.indicator_squared}"])
|
|
55
|
-
.with_columns([pl.lit(
|
|
59
|
+
.with_columns([pl.lit(followup_range).alias("followup")])
|
|
56
60
|
.explode("followup")
|
|
57
61
|
.with_columns(
|
|
58
|
-
[
|
|
59
|
-
(pl.col("followup") + 1).alias("followup"),
|
|
60
|
-
(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
|
|
61
|
-
]
|
|
62
|
+
[(pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}")]
|
|
62
63
|
)
|
|
63
64
|
).sort([self.id_col, "trial", "followup"])
|
|
64
65
|
|
|
@@ -370,3 +371,11 @@ def _calculate_survival(self, risk_data):
|
|
|
370
371
|
[(1 - pl.col("pred")).alias("pred"), pl.lit("survival").alias("estimate")]
|
|
371
372
|
)
|
|
372
373
|
return surv
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def _clamp(data):
    """
    Restrict the ``pred``, ``LCI`` and ``UCI`` columns to the interval [0, 1].

    Columns from that list that are absent from ``data`` are skipped; all
    other columns are returned unchanged.
    """
    bounded = [
        pl.col(name).clip(0.0, 1.0)
        for name in ("pred", "LCI", "UCI")
        if name in data.columns
    ]
    return data.with_columns(bounded)
|
|
@@ -1,2 +1,2 @@
|
|
|
1
|
-
from .
|
|
1
|
+
from ._data_checker import _data_checker as _data_checker
|
|
2
2
|
from ._param_checker import _param_checker as _param_checker
|
|
@@ -13,17 +13,10 @@ def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.in
|
|
|
13
13
|
.with_columns([pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")])
|
|
14
14
|
.with_columns(
|
|
15
15
|
[
|
|
16
|
-
pl.
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
]
|
|
21
|
-
)
|
|
22
|
-
.map_elements(
|
|
23
|
-
lambda x: list(range(x[time_col], x["max_time"] + 1)),
|
|
24
|
-
return_dtype=pl.List(pl.Int64),
|
|
25
|
-
)
|
|
26
|
-
.alias("period")
|
|
16
|
+
pl.int_ranges(
|
|
17
|
+
pl.col(time_col),
|
|
18
|
+
pl.col(time_col).max().over(id_col) + 1,
|
|
19
|
+
).alias("period")
|
|
27
20
|
]
|
|
28
21
|
)
|
|
29
22
|
.explode("period")
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
import polars as pl
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _random_selection(self):
    """
    Randomly subsample the trials whose baseline treatment equals the first
    treatment level, for data coming out of the
    __mapper -> __binder -> optionally __dynamic pipeline.

    A fraction ``self.selection_sample`` of the distinct (id, trial) pairs at
    ``self.treatment_level[0]`` is drawn without replacement using
    ``self._rng``. ``self.DT`` is then replaced by the rows that belong to a
    drawn pair or whose baseline treatment differs from
    ``self.treatment_level[0]``.

    :return: None; ``self.DT`` is reassigned
    """
    baseline_name = f"{self.treatment_col}{self.indicator_baseline}"
    baseline_tx = pl.col(baseline_name)
    reference_level = self.treatment_level[0]

    # One string key per (id, trial) pair, e.g. "12_3".
    trial_id = (
        pl.col(self.id_col).cast(pl.Utf8) + "_" + pl.col("trial").cast(pl.Utf8)
    ).alias("trialID")

    # maintain_order=True keeps the candidate list in a stable order, so a
    # seeded self._rng draws the same trials on every run; without it the
    # order returned by unique() is unspecified.
    UIDs = (
        self.DT.select([self.id_col, "trial", baseline_name])
        .with_columns(trial_id)
        .filter(baseline_tx == reference_level)
        .unique("trialID", maintain_order=True)
        .get_column("trialID")
        .to_list()
    )

    NIDs = len(UIDs)
    sample = self._rng.choice(
        UIDs, size=int(self.selection_sample * NIDs), replace=False
    )

    # The comparison must be parenthesised: `|` binds tighter than `!=`, so
    # `a | b != c` would be evaluated as `(a | b) != c`.
    keep_row = pl.col("trialID").is_in(sample) | (baseline_tx != reference_level)

    self.DT = self.DT.with_columns(trial_id).filter(keep_row).drop("trialID")
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from ._bootstrap import bootstrap_loop as bootstrap_loop
|
|
2
2
|
from ._col_string import _col_string as _col_string
|
|
3
3
|
from ._format_time import _format_time as _format_time
|
|
4
|
+
from ._offloader import Offloader as Offloader
|
|
4
5
|
from ._output_files import _build_md as _build_md
|
|
5
6
|
from ._output_files import _build_pdf as _build_pdf
|
|
6
7
|
from ._pad import _pad as _pad
|