pySEQTarget 0.9.0__tar.gz

@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 CAUSALab
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,91 @@
+ Metadata-Version: 2.4
+ Name: pySEQTarget
+ Version: 0.9.0
+ Summary: Sequentially Nested Target Trial Emulation
+ Author-email: Ryan O'Dea <ryan.odea@psi.ch>, Alejandro Szmulewicz <aszmulewicz@hsph.harvard.edu>, Tom Palmer <tom.palmer@bristol.ac.uk>, Miguel Hernan <mhernan@hsph.harvard.edu>
+ Maintainer-email: Ryan O'Dea <ryan.odea@psi.ch>
+ License: MIT
+ Project-URL: Homepage, https://github.com/CausalInference/pySEQTarget
+ Project-URL: Repository, https://github.com/CausalInference/pySEQTarget
+ Project-URL: Bug Tracker, https://github.com/CausalInference/pySEQTarget/issues
+ Project-URL: Ryan O'Dea (ORCID), https://orcid.org/0009-0000-0103-9546
+ Project-URL: Alejandro Szmulewicz (ORCID), https://orcid.org/0000-0002-2664-802X
+ Project-URL: Tom Palmer (ORCID), https://orcid.org/0000-0003-4655-4511
+ Project-URL: Miguel Hernan (ORCID), https://orcid.org/0000-0003-1619-8456
+ Project-URL: University of Bristol (ROR), https://ror.org/0524sp257
+ Project-URL: Harvard University (ROR), https://ror.org/03vek6s52
+ Keywords: causal inference,sequential trial emulation,target trial,observational studies
+ Classifier: Development Status :: 4 - Beta
+ Classifier: Intended Audience :: Science/Research
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: numpy
+ Requires-Dist: polars
+ Requires-Dist: tqdm
+ Requires-Dist: statsmodels
+ Requires-Dist: matplotlib
+ Requires-Dist: pyarrow
+ Requires-Dist: lifelines
+ Dynamic: license-file
+
+ # pySEQTarget - Sequentially Nested Target Trial Emulation
+
+ Implementation of sequential trial emulation for the analysis of
+ observational databases. The SEQTaRget software accommodates
+ time-varying treatments and confounders, as well as binary and
+ failure-time outcomes. It supports comparing both static and dynamic
+ strategies, can be used to estimate observational analogs of
+ intention-to-treat and per-protocol effects, and can adjust for
+ potential selection bias.
+
+ ## Installation
+ You can install the development version of pySEQTarget from GitHub with:
+ ```shell
+ pip install git+https://github.com/CausalInference/pySEQTarget
+ ```
+ or from PyPI with:
+ ```shell
+ pip install pySEQTarget
+ ```
+
+ ## Setting up your Analysis
+ The primary API, `SEQuential`, uses a dataclass system to handle function input. You can then recover elements as they are built by interacting with the `SEQuential` object you create.
+
+ From the user side, this amounts to creating a `SEQopts` dataclass and passing it to `SEQuential`. If you did not set an option at instantiation, you can, in some cases, supply it when you call the respective class method (as `bootstrap` does below). In the example, the treatment and outcome column names are placeholders; point them at the corresponding columns in your data.
+
+ ```python
+ import polars as pl
+ from pySEQTarget import SEQuential, SEQopts
+
+ data = pl.from_pandas(SEQdata)  # SEQdata: your source data as a pandas DataFrame
+ options = SEQopts(km_curves=True)
+
+ # Initialize the class
+ model = SEQuential(
+     data,
+     id_col="ID",
+     time_col="time",
+     eligible_col="eligible",
+     treatment_col="treatment",  # placeholder column name
+     outcome_col="outcome",      # placeholder column name
+     time_varying_cols=["N", "L", "P"],
+     fixed_cols=["sex"],
+     method="ITT",
+     parameters=options,
+ )
+ model.expand()                       # Construct the nested trial structure
+ model.bootstrap(bootstrap_nboot=20)  # Draw 20 bootstrap samples
+ model.fit()                          # Fit the model(s)
+ model.survival()                     # Estimate survival curves
+ model.plot()                         # Create and show a plot of the survival curves
+ model.collect()                      # Collect results into a SEQoutput object
+ ```
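+
+ `model.collect()` returns a `SEQoutput` dataclass (defined in `SEQoutput.py`). The sketch below shows one way to inspect it; the field and method names are those of `SEQoutput` as shipped in this package, and which tables exist depends on the steps run above.
+
+ ```python
+ output = model.collect()
+
+ # Result tables created by survival(), returned as polars DataFrames
+ risk_ratio = output.retrieve_data(type="risk_ratio")
+ km_data = output.retrieve_data(type="km_data")
+ print(risk_ratio)
+
+ # Fitted outcome models (statsmodels results) and, when weighted, weight diagnostics
+ outcome_models = output.outcome_models
+ weight_stats = output.weight_statistics
+
+ # The matplotlib Figure built by model.plot() is also carried along
+ figure = output.km_graph
+ ```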
+
+ ## Assumptions
+ There are several key assumptions in this package; a minimal sketch for checking them follows the list.
+ 1. The user-provided `time_col` begins at 0 for each unique `id_col`, contains only integers, and increases by 1 at every time step: (0, 1, 2, 3, 4, ...) is allowed, while (0, 1, 2, 2.5, ...) and (0, 1, 4, 5) are not.
+ 2. `time_col` entries may be out of order at intake, since a sort is enforced at expansion.
+ 3. `eligible_col`, `excused_colnames`, and [TODO] are "once 1, only 1" flag variables with respect to `time_col` (once set to 1, they stay 1).
+
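+ A minimal sketch for checking assumptions 1 and 3 before constructing `SEQuential` (this helper is not part of the package API; it assumes a recent polars and reads "once 1, only 1" as "once the flag is 1, it stays 1"):
+
+ ```python
+ import polars as pl
+
+ def check_seq_assumptions(df: pl.DataFrame, id_col: str = "ID",
+                           time_col: str = "time",
+                           flag_cols: tuple = ("eligible",)) -> None:
+     df = df.sort([id_col, time_col])
+
+     # Assumption 1: time_col starts at 0 and increases by exactly 1 within each id
+     per_id = df.group_by(id_col).agg(
+         starts_at_zero=(pl.col(time_col).min() == 0),
+         consecutive=(pl.col(time_col).diff().drop_nulls() == 1).all(),
+     )
+     if not per_id["starts_at_zero"].all() or not per_id["consecutive"].all():
+         raise ValueError(f"{time_col} must be 0, 1, 2, ... within each {id_col}")
+
+     # Assumption 3: flag columns never drop back to 0 after becoming 1
+     for col in flag_cols:
+         drops = df.group_by(id_col).agg(
+             dropped=(pl.col(col).cum_max() > pl.col(col)).any()
+         )
+         if drops["dropped"].any():
+             raise ValueError(f"{col} must remain 1 once it has been set to 1")
+ ```
+
+ With the example above, this would be called as `check_seq_assumptions(data, flag_cols=("eligible",))` before `model.expand()`.
+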
@@ -0,0 +1,56 @@
+ # pySEQTarget - Sequentially Nested Target Trial Emulation
+
+ Implementation of sequential trial emulation for the analysis of
+ observational databases. The SEQTaRget software accommodates
+ time-varying treatments and confounders, as well as binary and
+ failure-time outcomes. It supports comparing both static and dynamic
+ strategies, can be used to estimate observational analogs of
+ intention-to-treat and per-protocol effects, and can adjust for
+ potential selection bias.
+
+ ## Installation
+ You can install the development version of pySEQTarget from GitHub with:
+ ```shell
+ pip install git+https://github.com/CausalInference/pySEQTarget
+ ```
+ or from PyPI with:
+ ```shell
+ pip install pySEQTarget
+ ```
+
+ ## Setting up your Analysis
+ The primary API, `SEQuential`, uses a dataclass system to handle function input. You can then recover elements as they are built by interacting with the `SEQuential` object you create.
+
+ From the user side, this amounts to creating a `SEQopts` dataclass and passing it to `SEQuential`. If you did not set an option at instantiation, you can, in some cases, supply it when you call the respective class method (as `bootstrap` does below). In the example, the treatment and outcome column names are placeholders; point them at the corresponding columns in your data.
+
+ ```python
+ import polars as pl
+ from pySEQTarget import SEQuential, SEQopts
+
+ data = pl.from_pandas(SEQdata)  # SEQdata: your source data as a pandas DataFrame
+ options = SEQopts(km_curves=True)
+
+ # Initialize the class
+ model = SEQuential(
+     data,
+     id_col="ID",
+     time_col="time",
+     eligible_col="eligible",
+     treatment_col="treatment",  # placeholder column name
+     outcome_col="outcome",      # placeholder column name
+     time_varying_cols=["N", "L", "P"],
+     fixed_cols=["sex"],
+     method="ITT",
+     parameters=options,
+ )
+ model.expand()                       # Construct the nested trial structure
+ model.bootstrap(bootstrap_nboot=20)  # Draw 20 bootstrap samples
+ model.fit()                          # Fit the model(s)
+ model.survival()                     # Estimate survival curves
+ model.plot()                         # Create and show a plot of the survival curves
+ model.collect()                      # Collect results into a SEQoutput object
+ ```
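+
+ `model.collect()` returns a `SEQoutput` dataclass (defined in `SEQoutput.py`). The sketch below shows one way to inspect it; the field and method names are those of `SEQoutput` as shipped in this package, and which tables exist depends on the steps run above.
+
+ ```python
+ output = model.collect()
+
+ # Result tables created by survival(), returned as polars DataFrames
+ risk_ratio = output.retrieve_data(type="risk_ratio")
+ km_data = output.retrieve_data(type="km_data")
+ print(risk_ratio)
+
+ # Fitted outcome models (statsmodels results) and, when weighted, weight diagnostics
+ outcome_models = output.outcome_models
+ weight_stats = output.weight_statistics
+
+ # The matplotlib Figure built by model.plot() is also carried along
+ figure = output.km_graph
+ ```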
+
+ ## Assumptions
+ There are several key assumptions in this package; a minimal sketch for checking them follows the list.
+ 1. The user-provided `time_col` begins at 0 for each unique `id_col`, contains only integers, and increases by 1 at every time step: (0, 1, 2, 3, 4, ...) is allowed, while (0, 1, 2, 2.5, ...) and (0, 1, 4, 5) are not.
+ 2. `time_col` entries may be out of order at intake, since a sort is enforced at expansion.
+ 3. `eligible_col`, `excused_colnames`, and [TODO] are "once 1, only 1" flag variables with respect to `time_col` (once set to 1, they stay 1).
+
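+ A minimal sketch for checking assumptions 1 and 3 before constructing `SEQuential` (this helper is not part of the package API; it assumes a recent polars and reads "once 1, only 1" as "once the flag is 1, it stays 1"):
+
+ ```python
+ import polars as pl
+
+ def check_seq_assumptions(df: pl.DataFrame, id_col: str = "ID",
+                           time_col: str = "time",
+                           flag_cols: tuple = ("eligible",)) -> None:
+     df = df.sort([id_col, time_col])
+
+     # Assumption 1: time_col starts at 0 and increases by exactly 1 within each id
+     per_id = df.group_by(id_col).agg(
+         starts_at_zero=(pl.col(time_col).min() == 0),
+         consecutive=(pl.col(time_col).diff().drop_nulls() == 1).all(),
+     )
+     if not per_id["starts_at_zero"].all() or not per_id["consecutive"].all():
+         raise ValueError(f"{time_col} must be 0, 1, 2, ... within each {id_col}")
+
+     # Assumption 3: flag columns never drop back to 0 after becoming 1
+     for col in flag_cols:
+         drops = df.group_by(id_col).agg(
+             dropped=(pl.col(col).cum_max() > pl.col(col)).any()
+         )
+         if drops["dropped"].any():
+             raise ValueError(f"{col} must remain 1 once it has been set to 1")
+ ```
+
+ With the example above, this would be called as `check_seq_assumptions(data, flag_cols=("eligible",))` before `model.expand()`.
+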
@@ -0,0 +1,105 @@
+ import multiprocessing
+ from dataclasses import dataclass, field
+ from typing import List, Literal, Optional
+
+
+ @dataclass
+ class SEQopts:
+     bootstrap_nboot: int = 0
+     bootstrap_sample: float = 0.8
+     bootstrap_CI: float = 0.95
+     bootstrap_CI_method: Literal["se", "percentile"] = "se"
+     cense_colname: Optional[str] = None
+     cense_denominator: Optional[str] = None
+     cense_numerator: Optional[str] = None
+     cense_eligible_colname: Optional[str] = None
+     compevent_colname: Optional[str] = None
+     covariates: Optional[str] = None
+     denominator: Optional[str] = None
+     excused: bool = False
+     excused_colnames: List[str] = field(default_factory=list)
+     followup_class: bool = False
+     followup_include: bool = True
+     followup_max: Optional[int] = None
+     followup_min: int = 0
+     followup_spline: bool = False
+     hazard_estimate: bool = False
+     indicator_baseline: str = "_bas"
+     indicator_squared: str = "_sq"
+     km_curves: bool = False
+     ncores: int = multiprocessing.cpu_count()
+     numerator: Optional[str] = None
+     parallel: bool = False
+     plot_colors: List[str] = field(
+         default_factory=lambda: ["#F8766D", "#00BFC4", "#555555"]
+     )
+     plot_labels: List[str] = field(default_factory=list)
+     plot_title: Optional[str] = None
+     plot_type: Literal["risk", "survival", "incidence"] = "risk"
+     seed: Optional[int] = None
+     selection_first_trial: bool = False
+     selection_probability: float = 0.8
+     selection_random: bool = False
+     subgroup_colname: Optional[str] = None
+     treatment_level: List[int] = field(default_factory=lambda: [0, 1])
+     trial_include: bool = True
+     weight_eligible_colnames: List[str] = field(default_factory=list)
+     weight_min: float = 0.0
+     weight_max: Optional[float] = None
+     weight_lag_condition: bool = True
+     weight_p99: bool = False
+     weight_preexpansion: bool = False
+     weighted: bool = False
+
+     def __post_init__(self):
+         bools = [
+             "excused",
+             "followup_class",
+             "followup_include",
+             "followup_spline",
+             "hazard_estimate",
+             "km_curves",
+             "parallel",
+             "selection_first_trial",
+             "selection_random",
+             "trial_include",
+             "weight_lag_condition",
+             "weight_p99",
+             "weight_preexpansion",
+             "weighted",
+         ]
+         for i in bools:
+             if not isinstance(getattr(self, i), bool):
+                 raise TypeError(f"{i} must be a boolean value.")
+
+         if not isinstance(self.bootstrap_nboot, int) or self.bootstrap_nboot < 0:
+             raise ValueError("bootstrap_nboot must be a non-negative integer.")
+
+         if self.ncores < 1 or not isinstance(self.ncores, int):
+             raise ValueError("ncores must be a positive integer.")
+
+         if not (0.0 <= self.bootstrap_sample <= 1.0):
+             raise ValueError("bootstrap_sample must be between 0 and 1.")
+         if not (0.0 < self.bootstrap_CI < 1.0):
+             raise ValueError("bootstrap_CI must be between 0 and 1.")
+         if not (0.0 <= self.selection_probability <= 1.0):
+             raise ValueError("selection_probability must be between 0 and 1.")
+
+         if self.plot_type not in ["risk", "survival", "incidence"]:
+             raise ValueError(
+                 "plot_type must be either 'risk', 'survival', or 'incidence'."
+             )
+
+         if self.bootstrap_CI_method not in ["se", "percentile"]:
+             raise ValueError("bootstrap_CI_method must be one of 'se' or 'percentile'.")
+
+         # Formula strings: strip all whitespace so downstream parsing is consistent
+         for i in (
+             "covariates",
+             "numerator",
+             "denominator",
+             "cense_numerator",
+             "cense_denominator",
+         ):
+             attr = getattr(self, i)
+             if attr is not None and not isinstance(attr, list):
+                 setattr(self, i, "".join(attr.split()))
@@ -0,0 +1,86 @@
+ from dataclasses import dataclass
+ from typing import List, Literal, Optional
+
+ import matplotlib.figure
+ import polars as pl
+ from statsmodels.base.wrapper import ResultsWrapper
+
+ from .SEQopts import SEQopts
+
+
+ @dataclass
+ class SEQoutput:
+     options: Optional[SEQopts] = None
+     method: Optional[str] = None
+     numerator_models: Optional[List[ResultsWrapper]] = None
+     denominator_models: Optional[List[ResultsWrapper]] = None
+     outcome_models: Optional[List[List[ResultsWrapper]]] = None
+     compevent_models: Optional[List[List[ResultsWrapper]]] = None
+     weight_statistics: Optional[dict] = None
+     hazard: Optional[pl.DataFrame] = None
+     km_data: Optional[pl.DataFrame] = None
+     km_graph: Optional[matplotlib.figure.Figure] = None
+     risk_ratio: Optional[pl.DataFrame] = None
+     risk_difference: Optional[pl.DataFrame] = None
+     time: Optional[dict] = None
+     diagnostic_tables: Optional[dict] = None
+
+     def plot(self):
+         print(self.km_graph)
+
+     def summary(
+         self,
+         type: Optional[
+             Literal["numerator", "denominator", "outcome", "compevent"]
+         ] = None,
+     ):
+         match type:
+             case "numerator":
+                 models = self.numerator_models
+             case "denominator":
+                 models = self.denominator_models
+             case "compevent":
+                 models = self.compevent_models
+             case _:
+                 models = self.outcome_models
+
+         return [model.summary() for model in models]
+
+     def retrieve_data(
+         self,
+         type: Optional[
+             Literal[
+                 "km_data",
+                 "hazard",
+                 "risk_ratio",
+                 "risk_difference",
+                 "unique_outcomes",
+                 "nonunique_outcomes",
+                 "unique_switches",
+                 "nonunique_switches",
+             ]
+         ] = None,
+     ):
+         match type:
+             case "hazard":
+                 data = self.hazard
+             case "risk_ratio":
+                 data = self.risk_ratio
+             case "risk_difference":
+                 data = self.risk_difference
+             case "unique_outcomes":
+                 data = self.diagnostic_tables["unique_outcomes"]
+             case "nonunique_outcomes":
+                 data = self.diagnostic_tables["nonunique_outcomes"]
+             case "unique_switches":
+                 data = self.diagnostic_tables.get("unique_switches")
+             case "nonunique_switches":
+                 data = self.diagnostic_tables.get("nonunique_switches")
+             case _:
+                 data = self.km_data
+         if data is None:
+             raise ValueError(f"Data {type} was not created in the SEQuential process")
+         return data
@@ -0,0 +1,315 @@
+ import datetime
+ import time
+ from collections import Counter
+ from dataclasses import asdict
+ from typing import List, Literal, Optional
+
+ import numpy as np
+ import polars as pl
+
+ from .analysis import (_calculate_hazard, _calculate_survival, _outcome_fit,
+                        _pred_risk, _risk_estimates, _subgroup_fit)
+ from .error import _datachecker, _param_checker
+ from .expansion import _binder, _diagnostics, _dynamic, _random_selection
+ from .helpers import _col_string, _format_time, bootstrap_loop
+ from .initialization import (_cense_denominator, _cense_numerator,
+                              _denominator, _numerator, _outcome)
+ from .plot import _survival_plot
+ from .SEQopts import SEQopts
+ from .SEQoutput import SEQoutput
+ from .weighting import (_fit_denominator, _fit_LTFU, _fit_numerator,
+                         _weight_bind, _weight_predict, _weight_setup,
+                         _weight_stats)
+
+
+ class SEQuential:
+     def __init__(
+         self,
+         data: pl.DataFrame,
+         id_col: str,
+         time_col: str,
+         eligible_col: str,
+         treatment_col: str,
+         outcome_col: str,
+         time_varying_cols: Optional[List[str]] = None,
+         fixed_cols: Optional[List[str]] = None,
+         method: Literal["ITT", "dose-response", "censoring"] = "ITT",
+         parameters: Optional[SEQopts] = None,
+     ) -> None:
+         self.data = data
+         self.id_col = id_col
+         self.time_col = time_col
+         self.eligible_col = eligible_col
+         self.treatment_col = treatment_col
+         self.outcome_col = outcome_col
+         self.time_varying_cols = time_varying_cols
+         self.fixed_cols = fixed_cols
+         self.method = method
+
+         self._time_initialized = datetime.datetime.now()
+
+         if parameters is None:
+             parameters = SEQopts()
+
+         # Copy every SEQopts field onto the instance as an attribute
+         for name, value in asdict(parameters).items():
+             setattr(self, name, value)
+
+         # RNG used for bootstrap resampling; seeded when `seed` is provided
+         self._rng = (
+             np.random.RandomState(self.seed) if self.seed is not None else np.random
+         )
+
+         # Build default model formulas where the user did not supply them
+         if self.covariates is None:
+             self.covariates = _outcome(self)
+
+         if self.weighted:
+             if self.numerator is None:
+                 self.numerator = _numerator(self)
+
+             if self.denominator is None:
+                 self.denominator = _denominator(self)
+
+         if self.cense_colname is not None:
+             if self.cense_numerator is None:
+                 self.cense_numerator = _cense_numerator(self)
+
+             if self.cense_denominator is None:
+                 self.cense_denominator = _cense_denominator(self)
+
+         _param_checker(self)
+         _datachecker(self)
+
+     def expand(self):
+         start = time.perf_counter()
+         # Columns that must be carried through the trial expansion
+         kept = [
+             self.cense_colname,
+             self.cense_eligible_colname,
+             self.compevent_colname,
+             *self.weight_eligible_colnames,
+             *self.excused_colnames,
+         ]
+
+         # Zero out eligibility outside the allowed treatment levels and flag
+         # person-times where treatment switched relative to the previous time
+         self.data = self.data.with_columns(
+             [
+                 pl.when(pl.col(self.treatment_col).is_in(self.treatment_level))
+                 .then(self.eligible_col)
+                 .otherwise(0)
+                 .alias(self.eligible_col),
+                 pl.col(self.treatment_col).shift(1).over([self.id_col]).alias("tx_lag"),
+                 pl.lit(False).alias("switch"),
+             ]
+         ).with_columns(
+             [
+                 pl.when(pl.col(self.time_col) == 0)
+                 .then(pl.lit(False))
+                 .otherwise(
+                     (pl.col("tx_lag").is_not_null())
+                     & (pl.col("tx_lag") != pl.col(self.treatment_col))
+                 )
+                 .cast(pl.Int8)
+                 .alias("switch")
+             ]
+         )
+
+         self.DT = _binder(
+             self,
+             kept_cols=_col_string(
+                 [
+                     self.covariates,
+                     self.numerator,
+                     self.denominator,
+                     self.cense_numerator,
+                     self.cense_denominator,
+                 ]
+             ).union(kept),
+         ).with_columns(pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col))
+
+         self.data = self.data.with_columns(
+             pl.col(self.id_col).cast(pl.Utf8).alias(self.id_col)
+         )
+
+         if self.method != "ITT":
+             _dynamic(self)
+         if self.selection_random:
+             _random_selection(self)
+         _diagnostics(self)
+
+         end = time.perf_counter()
+         self._expansion_time = _format_time(start, end)
+
+     def bootstrap(self, **kwargs):
+         allowed = {
+             "bootstrap_nboot",
+             "bootstrap_sample",
+             "bootstrap_CI",
+             "bootstrap_CI_method",
+         }
+         for key, value in kwargs.items():
+             if key in allowed:
+                 setattr(self, key, value)
+             else:
+                 raise ValueError(f"Unknown argument: {key}")
+
+         UIDs = self.DT.select(pl.col(self.id_col)).unique().to_series().to_list()
+         NIDs = len(UIDs)
+
+         # Pre-draw bootstrap samples; each entry maps id -> number of times drawn
+         self._boot_samples = []
+         for _ in range(self.bootstrap_nboot):
+             sampled_IDs = self._rng.choice(
+                 UIDs, size=int(self.bootstrap_sample * NIDs), replace=True
+             )
+             id_counts = Counter(sampled_IDs)
+             self._boot_samples.append(id_counts)
+         return self
+
+     @bootstrap_loop
+     def fit(self):
+         if self.bootstrap_nboot > 0 and not hasattr(self, "_boot_samples"):
+             raise ValueError(
+                 "Bootstrap sampling not found. Please run the 'bootstrap' method before fitting with bootstrapping."
+             )
+
+         if self.weighted:
+             WDT = _weight_setup(self)
+             if not self.weight_preexpansion and not self.excused:
+                 WDT = WDT.filter(pl.col("followup") > 0)
+
+             # Weight models are fit on pandas; treat fixed covariates as categoricals
+             WDT = WDT.to_pandas()
+             for col in self.fixed_cols:
+                 if col in WDT.columns:
+                     WDT[col] = WDT[col].astype("category")
+
+             _fit_LTFU(self, WDT)
+             _fit_numerator(self, WDT)
+             _fit_denominator(self, WDT)
+
+             WDT = pl.from_pandas(WDT)
+             WDT = _weight_predict(self, WDT)
+             _weight_bind(self, WDT)
+             self.weight_stats = _weight_stats(self)
+
+         if self.subgroup_colname is not None:
+             return _subgroup_fit(self)
+
+         models = {
+             "outcome": _outcome_fit(
+                 self,
+                 self.DT,
+                 self.outcome_col,
+                 self.covariates,
+                 self.weighted,
+                 "weight",
+             )
+         }
+         if self.compevent_colname is not None:
+             models["compevent"] = _outcome_fit(
+                 self,
+                 self.DT,
+                 self.compevent_colname,
+                 self.covariates,
+                 self.weighted,
+                 "weight",
+             )
+         return models
+
+     def survival(self):
+         if not hasattr(self, "outcome_model") or not self.outcome_model:
+             raise ValueError(
+                 "Outcome model not found. Please run the 'fit' method before calculating survival."
+             )
+
+         start = time.perf_counter()
+
+         risk_data = _pred_risk(self)
+         surv_data = _calculate_survival(self, risk_data)
+         self.km_data = pl.concat([risk_data, surv_data])
+         self.risk_estimates = _risk_estimates(self)
+
+         end = time.perf_counter()
+         self._survival_time = _format_time(start, end)
+
+     def hazard(self):
+         start = time.perf_counter()
+
+         if not hasattr(self, "outcome_model") or not self.outcome_model:
+             raise ValueError(
+                 "Outcome model not found. Please run the 'fit' method before calculating hazard ratio."
+             )
+         self.hazard_ratio = _calculate_hazard(self)
+
+         end = time.perf_counter()
+         self._hazard_time = _format_time(start, end)
+
+     def plot(self):
+         self.km_graph = _survival_plot(self)
+
+     def collect(self):
+         self._time_collected = datetime.datetime.now()
+
+         # Ensure every optional artifact exists, defaulting to None when a step was skipped
+         generated = [
+             "numerator_model",
+             "denominator_model",
+             "outcome_model",
+             "hazard_ratio",
+             "risk_estimates",
+             "km_data",
+             "km_graph",
+             "diagnostics",
+             "_survival_time",
+             "_hazard_time",
+             "_model_time",
+             "_expansion_time",
+             "weight_stats",
+         ]
+         for attr in generated:
+             if not hasattr(self, attr):
+                 setattr(self, attr, None)
+
+         # Options ==========================
+         base = SEQopts()
+
+         for name, value in vars(self).items():
+             if name in asdict(base).keys():
+                 setattr(base, name, value)
+
+         # Timing =========================
+         timing = {
+             "start_time": self._time_initialized,
+             "expansion_time": self._expansion_time,
+             "model_time": self._model_time,
+             "survival_time": self._survival_time,
+             "hazard_time": self._hazard_time,
+             "collection_time": self._time_collected,
+         }
+
+         if self.compevent_colname is not None:
+             compevent_models = [model["compevent"] for model in self.outcome_model]
+         else:
+             compevent_models = None
+
+         if self.outcome_model is not None:
+             outcome_models = [model["outcome"] for model in self.outcome_model]
+         else:
+             outcome_models = None
+
+         if self.risk_estimates is None:
+             risk_ratio = risk_difference = None
+         else:
+             risk_ratio = self.risk_estimates["risk_ratio"]
+             risk_difference = self.risk_estimates["risk_difference"]
+
+         output = SEQoutput(
+             options=base,
+             method=self.method,
+             numerator_models=self.numerator_model,
+             denominator_models=self.denominator_model,
+             outcome_models=outcome_models,
+             compevent_models=compevent_models,
+             weight_statistics=self.weight_stats,
+             hazard=self.hazard_ratio,
+             km_data=self.km_data,
+             km_graph=self.km_graph,
+             risk_ratio=risk_ratio,
+             risk_difference=risk_difference,
+             time=timing,
+             diagnostic_tables=self.diagnostics,
+         )
+
+         return output
@@ -0,0 +1,5 @@
+ from .SEQopts import SEQopts
+ from .SEQoutput import SEQoutput
+ from .SEQuential import SEQuential
+
+ __all__ = ["SEQuential", "SEQopts", "SEQoutput"]