pySEQTarget 0.9.0 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyseqtarget-0.9.0/LICENSE +21 -0
- pyseqtarget-0.9.0/PKG-INFO +91 -0
- pyseqtarget-0.9.0/README.md +56 -0
- pyseqtarget-0.9.0/pySEQTarget/SEQopts.py +105 -0
- pyseqtarget-0.9.0/pySEQTarget/SEQoutput.py +86 -0
- pyseqtarget-0.9.0/pySEQTarget/SEQuential.py +315 -0
- pyseqtarget-0.9.0/pySEQTarget/__init__.py +5 -0
- pyseqtarget-0.9.0/pySEQTarget/data/__init__.py +19 -0
- pyseqtarget-0.9.0/pySEQTarget.egg-info/PKG-INFO +91 -0
- pyseqtarget-0.9.0/pySEQTarget.egg-info/SOURCES.txt +20 -0
- pyseqtarget-0.9.0/pySEQTarget.egg-info/dependency_links.txt +1 -0
- pyseqtarget-0.9.0/pySEQTarget.egg-info/requires.txt +7 -0
- pyseqtarget-0.9.0/pySEQTarget.egg-info/top_level.txt +1 -0
- pyseqtarget-0.9.0/pyproject.toml +63 -0
- pyseqtarget-0.9.0/setup.cfg +4 -0
- pyseqtarget-0.9.0/tests/test_accessor.py +27 -0
- pyseqtarget-0.9.0/tests/test_coefficients.py +369 -0
- pyseqtarget-0.9.0/tests/test_covariates.py +180 -0
- pyseqtarget-0.9.0/tests/test_followup_options.py +103 -0
- pyseqtarget-0.9.0/tests/test_hazard.py +64 -0
- pyseqtarget-0.9.0/tests/test_parallel.py +42 -0
- pyseqtarget-0.9.0/tests/test_survival.py +165 -0

pyseqtarget-0.9.0/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 CAUSALab
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

pyseqtarget-0.9.0/PKG-INFO
@@ -0,0 +1,91 @@
+Metadata-Version: 2.4
+Name: pySEQTarget
+Version: 0.9.0
+Summary: Sequentially Nested Target Trial Emulation
+Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernan <mhernan@hsph.harvard.edu>
+Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>
+License: MIT
+Project-URL: Homepage, https://github.com/CausalInference/pySEQTarget
+Project-URL: Repository, https://github.com/CausalInference/pySEQTarget
+Project-URL: Bug Tracker, https://github.com/CausalInference/pySEQTarget/issues
+Project-URL: Ryan O'Dea (ORCID), https://orcid.org/0009-0000-0103-9546
+Project-URL: Alejandro Szmulewicz (ORCID), https://orcid.org/0000-0002-2664-802X
+Project-URL: Tom Palmer (ORCID), https://orcid.org/0000-0003-4655-4511
+Project-URL: Miguel Hernan (ORCID), https://orcid.org/0000-0003-1619-8456
+Project-URL: University of Bristol (ROR), https://ror.org/0524sp257
+Project-URL: Harvard University (ROR), https://ror.org/03vek6s52
+Keywords: causal inference,sequential trial emulation,target trial,observational studies
+Classifier: Development Status :: 4 - Beta
+Classifier: Intended Audience :: Science/Research
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: numpy
+Requires-Dist: polars
+Requires-Dist: tqdm
+Requires-Dist: statsmodels
+Requires-Dist: matplotlib
+Requires-Dist: pyarrow
+Requires-Dist: lifelines
+Dynamic: license-file
+
+# pySEQTarget - Sequentially Nested Target Trial Emulation
+
+Implementation of sequential trial emulation for the analysis of
+observational databases. The ‘SEQTaRget’ software accommodates
+time-varying treatments and confounders, as well as binary and
+failure-time outcomes. ‘SEQTaRget’ allows comparison of both static and
+dynamic strategies, can be used to estimate observational analogs of
+intention-to-treat and per-protocol effects, and can adjust for
+potential selection bias.
+
+## Installation
+You can install the development version of pySEQTarget from GitHub with:
+```shell
+pip install git+https://github.com/CausalInference/pySEQTarget
+```
+Or from PyPI with:
+```shell
+pip install pySEQTarget
+```
+
+## Setting up your Analysis
+The primary API, `SEQuential`, uses a dataclass system to handle function input. You can then recover elements as they are built by interacting with the `SEQuential` object you create.
+
+From the user side, this amounts to creating a dataclass, `SEQopts`, and then feeding it into `SEQuential`. If you forget to set something at class instantiation, you can, in some cases, supply it when you call the respective class method.
+
+```python
+import polars as pl
+from pySEQTarget import SEQuential, SEQopts
+
+data = pl.from_pandas(SEQdata)
+options = SEQopts(km_curves = True)
+
+# Initiate the class
+model = SEQuential(data,
+                   id_col = "ID",
+                   time_col = "time",
+                   eligible_col = "eligible",
+                   time_varying_cols = ["N", "L", "P"],
+                   fixed_cols = ["sex"],
+                   method = "ITT",
+                   parameters = options)
+model.expand() # Construct the nested structure
+model.bootstrap(bootstrap_nboot = 20) # Run 20 bootstrap samples
+model.fit() # Fit the model
+model.survival() # Create survival curves
+model.plot() # Create and show a plot of the survival curves
+model.collect() # Collection of important information
+
+```
+
+## Assumptions
+There are several key assumptions in this package:
+1. The user-provided `time_col` begins at 0 for each unique `id_col`. We also assume this column contains only integers and increases by 1 at every time step, e.g. (0, 1, 2, 3, 4, ...) is allowed while (0, 1, 2, 2.5, ...) and (0, 1, 4, 5) are not.
+2. Provided `time_col` entries may be out of order at intake, as a sort is enforced at expansion.
+3. `eligible_col`, `excused_colnames`, and [TODO] are flag variables that, once 1, remain 1 (with respect to `time_col`).
+

pyseqtarget-0.9.0/README.md
@@ -0,0 +1,56 @@
+# pySEQTarget - Sequentially Nested Target Trial Emulation
+
+Implementation of sequential trial emulation for the analysis of
+observational databases. The ‘SEQTaRget’ software accommodates
+time-varying treatments and confounders, as well as binary and
+failure-time outcomes. ‘SEQTaRget’ allows comparison of both static and
+dynamic strategies, can be used to estimate observational analogs of
+intention-to-treat and per-protocol effects, and can adjust for
+potential selection bias.
+
+## Installation
+You can install the development version of pySEQTarget from GitHub with:
+```shell
+pip install git+https://github.com/CausalInference/pySEQTarget
+```
+Or from PyPI with:
+```shell
+pip install pySEQTarget
+```
+
+## Setting up your Analysis
+The primary API, `SEQuential`, uses a dataclass system to handle function input. You can then recover elements as they are built by interacting with the `SEQuential` object you create.
+
+From the user side, this amounts to creating a dataclass, `SEQopts`, and then feeding it into `SEQuential`. If you forget to set something at class instantiation, you can, in some cases, supply it when you call the respective class method.
+
+```python
+import polars as pl
+from pySEQTarget import SEQuential, SEQopts
+
+data = pl.from_pandas(SEQdata)
+options = SEQopts(km_curves = True)
+
+# Initiate the class
+model = SEQuential(data,
+                   id_col = "ID",
+                   time_col = "time",
+                   eligible_col = "eligible",
+                   time_varying_cols = ["N", "L", "P"],
+                   fixed_cols = ["sex"],
+                   method = "ITT",
+                   parameters = options)
+model.expand() # Construct the nested structure
+model.bootstrap(bootstrap_nboot = 20) # Run 20 bootstrap samples
+model.fit() # Fit the model
+model.survival() # Create survival curves
+model.plot() # Create and show a plot of the survival curves
+model.collect() # Collection of important information
+
+```
+
+## Assumptions
+There are several key assumptions in this package:
+1. The user-provided `time_col` begins at 0 for each unique `id_col`. We also assume this column contains only integers and increases by 1 at every time step, e.g. (0, 1, 2, 3, 4, ...) is allowed while (0, 1, 2, 2.5, ...) and (0, 1, 4, 5) are not. A minimal data check illustrating this assumption is sketched after the pySEQTarget.egg-info/README content below.
+2. Provided `time_col` entries may be out of order at intake, as a sort is enforced at expansion.
+3. `eligible_col`, `excused_colnames`, and [TODO] are flag variables that, once 1, remain 1 (with respect to `time_col`).
+
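The contiguity assumption in item 1 above is easy to violate when the source data come from irregular visit schedules. As a package-independent illustration, a polars check along the following lines flags IDs whose time grid does not start at 0 or skips a step; the helper name `check_time_grid` and the `ID`/`time` column names follow the README example and are not part of pySEQTarget itself.

```python
import polars as pl

# Illustrative helper (not part of pySEQTarget): flag IDs whose time grid
# does not start at 0 or does not increase by exactly 1 at every step.
def check_time_grid(df: pl.DataFrame, id_col: str = "ID", time_col: str = "time") -> pl.DataFrame:
    return (
        df.sort([id_col, time_col])
        .group_by(id_col)
        .agg(
            [
                (pl.col(time_col).min() == 0).alias("starts_at_zero"),
                (pl.col(time_col).diff().drop_nulls() == 1).all().alias("unit_steps"),
            ]
        )
        .filter(~pl.col("starts_at_zero") | ~pl.col("unit_steps"))
    )

# ID 2 skips from 1 to 4 and is flagged, so it would need to be re-indexed
# or dropped before being handed to SEQuential.
df = pl.DataFrame({"ID": [1, 1, 1, 2, 2, 2], "time": [0, 1, 2, 0, 1, 4]})
print(check_time_grid(df))
```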

pyseqtarget-0.9.0/pySEQTarget/SEQopts.py
@@ -0,0 +1,105 @@
+import multiprocessing
+from dataclasses import dataclass, field
+from typing import List, Literal, Optional
+
+
+@dataclass
+class SEQopts:
+    bootstrap_nboot: int = 0
+    bootstrap_sample: float = 0.8
+    bootstrap_CI: float = 0.95
+    bootstrap_CI_method: Literal["se", "percentile"] = "se"
+    cense_colname: Optional[str] = None
+    cense_denominator: Optional[str] = None
+    cense_numerator: Optional[str] = None
+    cense_eligible_colname: Optional[str] = None
+    compevent_colname: Optional[str] = None
+    covariates: Optional[str] = None
+    denominator: Optional[str] = None
+    excused: bool = False
+    excused_colnames: List[str] = field(default_factory=lambda: [])
+    followup_class: bool = False
+    followup_include: bool = True
+    followup_max: Optional[int] = None
+    followup_min: int = 0
+    followup_spline: bool = False
+    hazard_estimate: bool = False
+    indicator_baseline: str = "_bas"
+    indicator_squared: str = "_sq"
+    km_curves: bool = False
+    ncores: int = multiprocessing.cpu_count()
+    numerator: Optional[str] = None
+    parallel: bool = False
+    plot_colors: List[str] = field(
+        default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]
+    )
+    plot_labels: List[str] = field(default_factory=lambda: [])
+    plot_title: Optional[str] = None
+    plot_type: Literal["risk", "survival", "incidence"] = "risk"
+    seed: Optional[int] = None
+    selection_first_trial: bool = False
+    selection_probability: float = 0.8
+    selection_random: bool = False
+    subgroup_colname: Optional[str] = None
+    treatment_level: List[int] = field(default_factory=lambda: [0, 1])
+    trial_include: bool = True
+    weight_eligible_colnames: List[str] = field(default_factory=lambda: [])
+    weight_min: float = 0.0
+    weight_max: Optional[float] = None
+    weight_lag_condition: bool = True
+    weight_p99: bool = False
+    weight_preexpansion: bool = False
+    weighted: bool = False
+
+    def __post_init__(self):
+        # Validate flag types and value ranges at construction time.
+        bools = [
+            "excused",
+            "followup_class",
+            "followup_include",
+            "followup_spline",
+            "hazard_estimate",
+            "km_curves",
+            "parallel",
+            "selection_first_trial",
+            "selection_random",
+            "trial_include",
+            "weight_lag_condition",
+            "weight_p99",
+            "weight_preexpansion",
+            "weighted",
+        ]
+        for i in bools:
+            if not isinstance(getattr(self, i), bool):
+                raise TypeError(f"{i} must be a boolean value.")
+
+        if not isinstance(self.bootstrap_nboot, int) or self.bootstrap_nboot < 0:
+            raise ValueError("bootstrap_nboot must be a non-negative integer.")
+
+        if self.ncores < 1 or not isinstance(self.ncores, int):
+            raise ValueError("ncores must be a positive integer.")
+
+        if not (0.0 <= self.bootstrap_sample <= 1.0):
+            raise ValueError("bootstrap_sample must be between 0 and 1.")
+        if not (0.0 < self.bootstrap_CI < 1.0):
+            raise ValueError("bootstrap_CI must be between 0 and 1.")
+        if not (0.0 <= self.selection_probability <= 1.0):
+            raise ValueError("selection_probability must be between 0 and 1.")
+
+        if self.plot_type not in ["risk", "survival", "incidence"]:
+            raise ValueError(
+                "plot_type must be either 'risk', 'survival', or 'incidence'."
+            )
+
+        if self.bootstrap_CI_method not in ["se", "percentile"]:
+            raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'")
+
+        # Formula-style string options have all internal whitespace stripped.
+        for i in (
+            "covariates",
+            "numerator",
+            "denominator",
+            "cense_numerator",
+            "cense_denominator",
+        ):
+            attr = getattr(self, i)
+            if attr is not None and not isinstance(attr, list):
+                setattr(self, i, "".join(attr.split()))
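Because `SEQopts.__post_init__` validates types and ranges up front, misconfigured options fail at construction rather than partway through an expansion or fit. A minimal sketch of that behaviour, assuming the package is installed and using arbitrary option values:

```python
from pySEQTarget import SEQopts

# Well-formed options pass validation silently.
opts = SEQopts(km_curves=True, bootstrap_nboot=100, bootstrap_CI=0.95, seed=42)

# An unknown plot type is rejected with a ValueError...
try:
    SEQopts(plot_type="cumulative")
except ValueError as err:
    print(err)  # plot_type must be either 'risk', 'survival', or 'incidence'.

# ...and a non-boolean flag is rejected with a TypeError.
try:
    SEQopts(km_curves="yes")
except TypeError as err:
    print(err)  # km_curves must be a boolean value.
```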

pyseqtarget-0.9.0/pySEQTarget/SEQoutput.py
@@ -0,0 +1,86 @@
+from dataclasses import dataclass
+from typing import List, Literal, Optional
+
+import matplotlib.figure
+import polars as pl
+from statsmodels.base.wrapper import ResultsWrapper
+
+from .SEQopts import SEQopts
+
+
+@dataclass
+class SEQoutput:
+    options: Optional[SEQopts] = None
+    method: Optional[str] = None
+    numerator_models: Optional[List[ResultsWrapper]] = None
+    denominator_models: Optional[List[ResultsWrapper]] = None
+    outcome_models: Optional[List[List[ResultsWrapper]]] = None
+    compevent_models: Optional[List[List[ResultsWrapper]]] = None
+    weight_statistics: Optional[dict] = None
+    hazard: Optional[pl.DataFrame] = None
+    km_data: Optional[pl.DataFrame] = None
+    km_graph: Optional[matplotlib.figure.Figure] = None
+    risk_ratio: Optional[pl.DataFrame] = None
+    risk_difference: Optional[pl.DataFrame] = None
+    time: Optional[dict] = None
+    diagnostic_tables: Optional[dict] = None
+
+    def plot(self):
+        print(self.km_graph)
+
+    def summary(
+        self, type: Optional[Literal["numerator", "denominator", "outcome", "compevent"]] = None
+    ):
+        match type:
+            case "numerator":
+                models = self.numerator_models
+            case "denominator":
+                models = self.denominator_models
+            case "compevent":
+                models = self.compevent_models
+            case _:
+                models = self.outcome_models
+
+        return [model.summary() for model in models]
+
+    def retrieve_data(
+        self,
+        type: Optional[
+            Literal[
+                "km_data",
+                "hazard",
+                "risk_ratio",
+                "risk_difference",
+                "unique_outcomes",
+                "nonunique_outcomes",
+                "unique_switches",
+                "nonunique_switches",
+            ]
+        ] = None,
+    ):
+        match type:
+            case "hazard":
+                data = self.hazard
+            case "risk_ratio":
+                data = self.risk_ratio
+            case "risk_difference":
+                data = self.risk_difference
+            case "unique_outcomes":
+                data = self.diagnostic_tables["unique_outcomes"]
+            case "nonunique_outcomes":
+                data = self.diagnostic_tables["nonunique_outcomes"]
+            case "unique_switches":
+                if "unique_switches" in self.diagnostic_tables:
+                    data = self.diagnostic_tables["unique_switches"]
+                else:
+                    data = None
+            case "nonunique_switches":
+                if "nonunique_switches" in self.diagnostic_tables:
+                    data = self.diagnostic_tables["nonunique_switches"]
+                else:
+                    data = None
+            case _:
+                data = self.km_data
+        if data is None:
+            raise ValueError(f"Data {type} was not created in the SEQuential process")
+        return data
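`SEQoutput` is the container returned by `SEQuential.collect()`, so downstream code typically interacts with it rather than with the fitted `SEQuential` object. A sketch of typical access, assuming `model` is a `SEQuential` instance that has already been expanded, fitted, and had `survival()` run (as in the README example):

```python
output = model.collect()

# statsmodels summaries of the fitted outcome models
outcome_summaries = output.summary(type="outcome")

# polars DataFrames stored on the output object
km = output.retrieve_data(type="km_data")
rr = output.retrieve_data(type="risk_ratio")  # raises ValueError if survival() was never run

# plot() prints the stored matplotlib figure; time holds the recorded stage durations
output.plot()
print(output.time["expansion_time"], output.time["survival_time"])
```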

pyseqtarget-0.9.0/pySEQTarget/SEQuential.py
@@ -0,0 +1,315 @@
+import datetime
+import time
+from collections import Counter
+from dataclasses import asdict
+from typing import List, Literal, Optional
+
+import numpy as np
+import polars as pl
+
+from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit,
+                       _pred_risk, _risk_estimates, _subgroup_fit)
+from .error import _datachecker, _param_checker
+from .expansion import _binder, _diagnostics, _dynamic, _random_selection
+from .helpers import _col_string, _format_time, bootstrap_loop
+from .initialization import (_cense_denominator, _cense_numerator,
+                             _denominator, _numerator, _outcome)
+from .plot import _survival_plot
+from .SEQopts import SEQopts
+from .SEQoutput import SEQoutput
+from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
+                        _weight_bind, _weight_predict, _weight_setup,
+                        _weight_stats)
+
+
+class SEQuential:
+    def __init__(
+        self,
+        data: pl.DataFrame,
+        id_col: str,
+        time_col: str,
+        eligible_col: str,
+        treatment_col: str,
+        outcome_col: str,
+        time_varying_cols: Optional[List[str]] = None,
+        fixed_cols: Optional[List[str]] = None,
+        method: Literal["ITT", "dose-response", "censoring"] = "ITT",
+        parameters: Optional[SEQopts] = None,
+    ) -> None:
+        self.data = data
+        self.id_col = id_col
+        self.time_col = time_col
+        self.eligible_col = eligible_col
+        self.treatment_col = treatment_col
+        self.outcome_col = outcome_col
+        self.time_varying_cols = time_varying_cols
+        self.fixed_cols = fixed_cols
+        self.method = method
+
+        self._time_initialized = datetime.datetime.now()
+
+        if parameters is None:
+            parameters = SEQopts()
+
+        for name, value in asdict(parameters).items():
+            setattr(self, name, value)
+
+        self._rng = (
+            np.random.RandomState(self.seed) if self.seed is not None else np.random
+        )
+
+        if self.covariates is None:
+            self.covariates = _outcome(self)
+
+        if self.weighted:
+            if self.numerator is None:
+                self.numerator = _numerator(self)
+
+            if self.denominator is None:
+                self.denominator = _denominator(self)
+
+        if self.cense_colname is not None:
+            if self.cense_numerator is None:
+                self.cense_numerator = _cense_numerator(self)
+
+            if self.cense_denominator is None:
+                self.cense_denominator = _cense_denominator(self)
+
+        _param_checker(self)
+        _datachecker(self)
+
+    def expand(self):
+        start = time.perf_counter()
+        kept = [
+            self.cense_colname,
+            self.cense_eligible_colname,
+            self.compevent_colname,
+            *self.weight_eligible_colnames,
+            *self.excused_colnames,
+        ]
+
+        self.data = self.data.with_columns(
+            [
+                pl.when(pl.col(self.treatment_col).is_in(self.treatment_level))
+                .then(self.eligible_col)
+                .otherwise(0)
+                .alias(self.eligible_col),
+                pl.col(self.treatment_col).shift(1).over([self.id_col]).alias("tx_lag"),
+                pl.lit(False).alias("switch"),
+            ]
+        ).with_columns(
+            [
+                pl.when(pl.col(self.time_col) == 0)
+                .then(pl.lit(False))
+                .otherwise(
+                    (pl.col("tx_lag").is_not_null())
+                    & (pl.col("tx_lag") != pl.col(self.treatment_col))
+                )
+                .cast(pl.Int8)
+                .alias("switch")
+            ]
+        )
+
+        self.DT = _binder(
+            self,
+            kept_cols=_col_string(
+                [
+                    self.covariates,
+                    self.numerator,
+                    self.denominator,
+                    self.cense_numerator,
+                    self.cense_denominator,
+                ]
+            ).union(kept),
+        ).with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))
+
+        self.data = self.data.with_columns(
+            pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
+        )
+
+        if self.method != "ITT":
+            _dynamic(self)
+        if self.selection_random:
+            _random_selection(self)
+        _diagnostics(self)
+
+        end = time.perf_counter()
+        self._expansion_time = _format_time(start, end)
+
+    def bootstrap(self, **kwargs):
+        allowed = {
+            "bootstrap_nboot",
+            "bootstrap_sample",
+            "bootstrap_CI",
+            "bootstrap_method",
+        }
+        for key, value in kwargs.items():
+            if key in allowed:
+                setattr(self, key, value)
+            else:
+                raise ValueError(f"Unknown argument: {key}")
+
+        UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list()
+        NIDs = len(UIDs)
+
+        self._boot_samples = []
+        for _ in range(self.bootstrap_nboot):
+            sampled_IDs = self._rng.choice(
+                UIDs, size=int(self.bootstrap_sample * NIDs), replace=True
+            )
+            id_counts = Counter(sampled_IDs)
+            self._boot_samples.append(id_counts)
+        return self
+
+    @bootstrap_loop
+    def fit(self):
+        if self.bootstrap_nboot > 0 and not hasattr(self, "_boot_samples"):
+            raise ValueError(
+                "Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
+            )
+
+        if self.weighted:
+            WDT = _weight_setup(self)
+            if not self.weight_preexpansion and not self.excused:
+                WDT = WDT.filter(pl.col("followup") > 0)
+
+            WDT = WDT.to_pandas()
+            for col in self.fixed_cols:
+                if col in WDT.columns:
+                    WDT[col] = WDT[col].astype("category")
+
+            _fit_LTFU(self, WDT)
+            _fit_numerator(self, WDT)
+            _fit_denominator(self, WDT)
+
+            WDT = pl.from_pandas(WDT)
+            WDT = _weight_predict(self, WDT)
+            _weight_bind(self, WDT)
+            self.weight_stats = _weight_stats(self)
+
+        if self.subgroup_colname is not None:
+            return _subgroup_fit(self)
+
+        models = {
+            "outcome": _outcome_fit(
+                self,
+                self.DT,
+                self.outcome_col,
+                self.covariates,
+                self.weighted,
+                "weight",
+            )
+        }
+        if self.compevent_colname is not None:
+            models["compevent"] = _outcome_fit(
+                self,
+                self.DT,
+                self.compevent_colname,
+                self.covariates,
+                self.weighted,
+                "weight",
+            )
+        return models
+
+    def survival(self):
+        if not hasattr(self, "outcome_model") or not self.outcome_model:
+            raise ValueError(
+                "Outcome model not found. Please run the 'fit' method before calculating survival."
+            )
+
+        start = time.perf_counter()
+
+        risk_data = _pred_risk(self)
+        surv_data = _calculate_survival(self, risk_data)
+        self.km_data = pl.concat([risk_data, surv_data])
+        self.risk_estimates = _risk_estimates(self)
+
+        end = time.perf_counter()
+        self._survival_time = _format_time(start, end)
+
+    def hazard(self):
+        start = time.perf_counter()
+
+        if not hasattr(self, "outcome_model") or not self.outcome_model:
+            raise ValueError(
+                "Outcome model not found. Please run the 'fit' method before calculating hazard ratio."
+            )
+        self.hazard_ratio = _calculate_hazard(self)
+
+        end = time.perf_counter()
+        self._hazard_time = _format_time(start, end)
+
+    def plot(self):
+        self.km_graph = _survival_plot(self)
+
+    def collect(self):
+        self._time_collected = datetime.datetime.now()
+
+        generated = [
+            "numerator_model",
+            "denominator_model",
+            "outcome_model",
+            "hazard_ratio",
+            "risk_estimates",
+            "km_data",
+            "km_graph",
+            "diagnostics",
+            "_survival_time",
+            "_hazard_time",
+            "_model_time",
+            "_expansion_time",
+            "weight_stats",
+        ]
+        for attr in generated:
+            if not hasattr(self, attr):
+                setattr(self, attr, None)
+
+        # Options ==========================
+        base = SEQopts()
+
+        for name, value in vars(self).items():
+            if name in asdict(base).keys():
+                setattr(base, name, value)
+
+        # Timing =========================
+        time = {
+            "start_time": self._time_initialized,
+            "expansion_time": self._expansion_time,
+            "model_time": self._model_time,
+            "survival_time": self._survival_time,
+            "hazard_time": self._hazard_time,
+            "collection_time": self._time_collected,
+        }
+
+        if self.compevent_colname is not None:
+            compevent_models = [model["compevent"] for model in self.outcome_models]
+        else:
+            compevent_models = None
+
+        if self.outcome_model is not None:
+            outcome_models = [model["outcome"] for model in self.outcome_model]
+        else:
+            outcome_models = None
+
+        if self.risk_estimates is None:
+            risk_ratio = risk_difference = None
+        else:
+            risk_ratio = self.risk_estimates["risk_ratio"]
+            risk_difference = self.risk_estimates["risk_difference"]
+
+        output = SEQoutput(
+            options=base,
+            method=self.method,
+            numerator_models=self.numerator_model,
+            denominator_models=self.denominator_model,
+            outcome_models=outcome_models,
+            compevent_models=compevent_models,
+            weight_statistics=self.weight_stats,
+            hazard=self.hazard_ratio,
+            km_data=self.km_data,
+            km_graph=self.km_graph,
+            risk_ratio=risk_ratio,
+            risk_difference=risk_difference,
+            time=time,
+            diagnostic_tables=self.diagnostics,
+        )
+
+        return output
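Putting the pieces of `SEQuential` together, the call order implied by the guards in `fit()`, `survival()`, and `hazard()` is expand, then bootstrap (when `bootstrap_nboot > 0`), then fit, then survival/hazard/plot, with `collect()` bundling the results into a `SEQoutput`. The sketch below assumes a long-format polars DataFrame; the file name and the `tx`/`outcome` column names are placeholders for illustration, not values shipped with the package.

```python
import polars as pl
from pySEQTarget import SEQuential, SEQopts

data = pl.read_csv("cohort.csv")  # placeholder path; any long-format person-time table

opts = SEQopts(km_curves=True, hazard_estimate=True, bootstrap_nboot=50, seed=2024)

model = SEQuential(
    data,
    id_col="ID",
    time_col="time",
    eligible_col="eligible",
    treatment_col="tx",        # placeholder column name
    outcome_col="outcome",     # placeholder column name
    time_varying_cols=["N", "L", "P"],
    fixed_cols=["sex"],
    method="ITT",
    parameters=opts,
)

model.expand()            # build the sequence of nested trials
model.bootstrap()         # draw the 50 bootstrap samples requested in opts
model.fit()               # fit the outcome model(s), looping over bootstrap samples
model.survival()          # predicted risk/survival curves and risk estimates
model.hazard()            # hazard ratio
model.plot()              # store the survival figure on the model
output = model.collect()  # bundle everything into a SEQoutput
```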