pySEQTarget 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pySEQTarget/SEQopts.py +197 -0
- pySEQTarget/SEQoutput.py +163 -0
- pySEQTarget/SEQuential.py +375 -0
- pySEQTarget/__init__.py +5 -0
- pySEQTarget/analysis/__init__.py +8 -0
- pySEQTarget/analysis/_hazard.py +211 -0
- pySEQTarget/analysis/_outcome_fit.py +75 -0
- pySEQTarget/analysis/_risk_estimates.py +136 -0
- pySEQTarget/analysis/_subgroup_fit.py +30 -0
- pySEQTarget/analysis/_survival_pred.py +372 -0
- pySEQTarget/data/__init__.py +19 -0
- pySEQTarget/error/__init__.py +2 -0
- pySEQTarget/error/_datachecker.py +38 -0
- pySEQTarget/error/_param_checker.py +50 -0
- pySEQTarget/expansion/__init__.py +5 -0
- pySEQTarget/expansion/_binder.py +98 -0
- pySEQTarget/expansion/_diagnostics.py +53 -0
- pySEQTarget/expansion/_dynamic.py +73 -0
- pySEQTarget/expansion/_mapper.py +44 -0
- pySEQTarget/expansion/_selection.py +31 -0
- pySEQTarget/helpers/__init__.py +8 -0
- pySEQTarget/helpers/_bootstrap.py +111 -0
- pySEQTarget/helpers/_col_string.py +6 -0
- pySEQTarget/helpers/_format_time.py +6 -0
- pySEQTarget/helpers/_output_files.py +167 -0
- pySEQTarget/helpers/_pad.py +7 -0
- pySEQTarget/helpers/_predict_model.py +9 -0
- pySEQTarget/helpers/_prepare_data.py +19 -0
- pySEQTarget/initialization/__init__.py +5 -0
- pySEQTarget/initialization/_censoring.py +53 -0
- pySEQTarget/initialization/_denominator.py +39 -0
- pySEQTarget/initialization/_numerator.py +37 -0
- pySEQTarget/initialization/_outcome.py +56 -0
- pySEQTarget/plot/__init__.py +1 -0
- pySEQTarget/plot/_survival_plot.py +104 -0
- pySEQTarget/weighting/__init__.py +8 -0
- pySEQTarget/weighting/_weight_bind.py +86 -0
- pySEQTarget/weighting/_weight_data.py +47 -0
- pySEQTarget/weighting/_weight_fit.py +99 -0
- pySEQTarget/weighting/_weight_pred.py +192 -0
- pySEQTarget/weighting/_weight_stats.py +23 -0
- pyseqtarget-0.10.0.dist-info/METADATA +98 -0
- pyseqtarget-0.10.0.dist-info/RECORD +46 -0
- pyseqtarget-0.10.0.dist-info/WHEEL +5 -0
- pyseqtarget-0.10.0.dist-info/licenses/LICENSE +21 -0
- pyseqtarget-0.10.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
def _cense_numerator(self) -> str:
    """Assemble the right-hand side of the censoring numerator formula.

    The result is a ``+``-joined list of terms that always starts with
    ``tx_lag``. When ``self.weight_preexpansion`` is truthy the remaining
    terms are the time column with its squared counterpart and the fixed
    covariates. Otherwise they are the optional trial and followup terms
    (each with its squared counterpart), the fixed covariates, and the
    baseline versions of the time-varying covariates. Empty pieces are
    skipped.

    Returns:
        The ``+``-separated term string.
    """
    trial_terms = None
    if self.trial_include:
        trial_terms = f"trial+trial{self.indicator_squared}"

    followup_terms = None
    if self.followup_include:
        followup_terms = f"followup+followup{self.indicator_squared}"

    # Built unconditionally with str.join, so a non-string time column
    # raises here even when the term ends up unused.
    time_terms = "+".join(
        (self.time_col, f"{self.time_col}{self.indicator_squared}")
    )

    baseline_terms = None
    if self.time_varying_cols:
        baseline_terms = "+".join(
            f"{col}{self.indicator_baseline}" for col in self.time_varying_cols
        )

    fixed_terms = None
    if self.fixed_cols:
        fixed_terms = "+".join(self.fixed_cols)

    if self.weight_preexpansion:
        pieces = ("tx_lag", time_terms, fixed_terms)
    else:
        pieces = ("tx_lag", trial_terms, followup_terms, fixed_terms, baseline_terms)

    return "+".join(piece for piece in pieces if piece)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _cense_denominator(self) -> str:
    """Assemble the right-hand side of the censoring denominator formula.

    The result is a ``+``-joined list of terms that always starts with
    ``tx_lag``. When ``self.weight_preexpansion`` is truthy the remaining
    terms are the time column with its squared counterpart, the fixed
    covariates and the time-varying covariates. Otherwise they are the
    optional trial and followup terms (each with its squared counterpart),
    the fixed covariates, the time-varying covariates and their baseline
    versions. Empty pieces are skipped.

    Returns:
        The ``+``-separated term string.
    """
    trial_terms = None
    if self.trial_include:
        trial_terms = f"trial+trial{self.indicator_squared}"

    followup_terms = None
    if self.followup_include:
        followup_terms = f"followup+followup{self.indicator_squared}"

    # Built unconditionally with str.join, so a non-string time column
    # raises here even when the term ends up unused.
    time_terms = "+".join(
        (self.time_col, f"{self.time_col}{self.indicator_squared}")
    )

    varying_terms = None
    if self.time_varying_cols:
        varying_terms = "+".join(self.time_varying_cols)

    baseline_terms = None
    if self.time_varying_cols:
        baseline_terms = "+".join(
            f"{col}{self.indicator_baseline}" for col in self.time_varying_cols
        )

    fixed_terms = None
    if self.fixed_cols:
        fixed_terms = "+".join(self.fixed_cols)

    if self.weight_preexpansion:
        pieces = ("tx_lag", time_terms, fixed_terms, varying_terms)
    else:
        pieces = (
            "tx_lag",
            trial_terms,
            followup_terms,
            fixed_terms,
            varying_terms,
            baseline_terms,
        )

    return "+".join(piece for piece in pieces if piece)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
def _denominator(self) -> str:
    """Build the right-hand side of the treatment-weight denominator formula.

    Returns:
        ``None`` when ``self.method`` is ``"ITT"``. Otherwise the selected
        terms joined with ``+``, skipping empty ones. With
        ``self.weight_preexpansion`` the terms are the fixed covariates, the
        time-varying covariates and the time column with its squared
        counterpart. Without it they are the fixed covariates, the
        time-varying covariates, their baseline versions, and the optional
        followup and trial terms (each with its squared counterpart).

    Raises:
        ValueError: If ``self.method`` is not ``"ITT"``, ``"dose-response"``
            or ``"censoring"``.
    """
    if self.method == "ITT":
        return None
    if self.method not in ("dose-response", "censoring"):
        # An unrecognised method used to fall through every branch and fail
        # with an UnboundLocalError on the final return.
        raise ValueError(
            f"Unsupported method for weight denominator: {self.method!r}"
        )

    squared = self.indicator_squared
    fixed = "+".join(self.fixed_cols) if self.fixed_cols else None
    tv = "+".join(self.time_varying_cols) if self.time_varying_cols else None

    # Both supported methods (and both excused settings) share the same terms;
    # only the pre-/post-expansion choice changes the formula.
    if self.weight_preexpansion:
        time = "+".join([self.time_col, f"{self.time_col}{squared}"])
        terms = [fixed, tv, time]
    else:
        tv_bas = (
            "+".join(f"{v}{self.indicator_baseline}" for v in self.time_varying_cols)
            if self.time_varying_cols
            else None
        )
        followup = f"followup+followup{squared}" if self.followup_include else None
        trial = f"trial+trial{squared}" if self.trial_include else None
        terms = [fixed, tv, tv_bas, followup, trial]

    return "+".join(filter(None, terms))
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
def _numerator(self) -> str:
    """Build the right-hand side of the treatment-weight numerator formula.

    Returns:
        ``None`` when ``self.method`` is ``"ITT"``, and also for
        ``"censoring"`` with ``self.excused`` under
        ``self.weight_preexpansion``. Otherwise the selected terms joined
        with ``+``, skipping empty ones. With ``self.weight_preexpansion``
        the terms are the fixed covariates and the time column with its
        squared counterpart. Without it they are the fixed covariates, the
        baseline versions of the time-varying covariates, and the optional
        followup and trial terms (each with its squared counterpart).

    Raises:
        ValueError: If ``self.method`` is not ``"ITT"``, ``"dose-response"``
            or ``"censoring"``.
    """
    if self.method == "ITT":
        return None
    if self.method not in ("dose-response", "censoring"):
        # An unrecognised method used to fall through every branch and fail
        # with an UnboundLocalError on the final return.
        raise ValueError(
            f"Unsupported method for weight numerator: {self.method!r}"
        )

    squared = self.indicator_squared
    fixed = "+".join(self.fixed_cols) if self.fixed_cols else None

    if self.weight_preexpansion:
        if self.method == "censoring" and self.excused:
            # No numerator formula is produced for this combination.
            return None
        time = "+".join([self.time_col, f"{self.time_col}{squared}"])
        terms = [fixed, time]
    else:
        tv_bas = (
            "+".join(f"{v}{self.indicator_baseline}" for v in self.time_varying_cols)
            if self.time_varying_cols
            else None
        )
        followup = f"followup+followup{squared}" if self.followup_include else None
        trial = f"trial+trial{squared}" if self.trial_include else None
        terms = [fixed, tv_bas, followup, trial]

    return "+".join(filter(None, terms))
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
def _outcome(self) -> str:
    """Build the right-hand side of the outcome model formula.

    The result is a ``+``-joined list, in this order: exposure term, followup
    term, trial term, covariates, exposure-by-followup interaction. Empty
    pieces are skipped.

    * Exposure is ``dose`` plus its squared term for ``"dose-response"`` and
      the baseline treatment column otherwise.
    * The interaction is dropped when ``self.hazard`` is truthy or
      ``self.km_curves`` is falsy.
    * The followup term is ``followup`` plus its squared term when
      ``self.followup_include`` is set, plain ``followup`` when only
      ``self.followup_spline`` or ``self.followup_class`` is set, and absent
      otherwise.
    * Covariates are the fixed columns and the baseline time-varying columns.
      For non-ITT methods with ``self.weighted`` and
      ``self.weight_preexpansion`` the baseline time-varying columns are
      omitted, and for excused censoring the fixed columns are omitted too.

    Returns:
        The ``+``-separated term string.

    Raises:
        ValueError: If ``self.method`` is not ``"ITT"``, ``"dose-response"``
            or ``"censoring"``.
    """
    if self.method not in ("ITT", "dose-response", "censoring"):
        # An unrecognised method used to leave the term list unassigned and
        # fail with an UnboundLocalError at the final join.
        raise ValueError(f"Unsupported method for outcome formula: {self.method!r}")

    squared = self.indicator_squared
    baseline = self.indicator_baseline

    if self.method == "dose-response":
        exposure = f"dose+dose{squared}"
        interaction = f"followup*dose+followup*dose{squared}"
    else:
        exposure = f"{self.treatment_col}{baseline}"
        interaction = f"{exposure}*followup"

    if self.hazard or not self.km_curves:
        interaction = None

    # Use the configured baseline suffix, matching the treatment column
    # above; this was previously hard-coded to "_bas".
    tv_bas = (
        "+".join(f"{v}{baseline}" for v in self.time_varying_cols)
        if self.time_varying_cols
        else None
    )
    fixed = "+".join(self.fixed_cols) if self.fixed_cols else None
    trial = f"trial+trial{squared}" if self.trial_include else None

    if self.followup_include:
        followup = f"followup+followup{squared}"
    elif self.followup_spline or self.followup_class:
        followup = "followup"
    else:
        followup = None

    covariates = [fixed, tv_bas]
    if self.method != "ITT" and self.weighted and self.weight_preexpansion:
        if self.method == "censoring" and self.excused:
            covariates = []
        else:
            covariates = [fixed]

    parts = [exposure, followup, trial, *covariates, interaction]
    return "+".join(filter(None, parts))
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from ._survival_plot import _survival_plot as _survival_plot
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import numpy as np
|
|
5
|
+
import polars as pl
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _survival_plot(self):
    """Select the rows for the requested estimate and draw them.

    Rows of ``self.km_data`` are kept when their ``estimate`` column equals
    ``"risk"`` or ``"survival"`` (matching ``self.plot_type``); any other
    plot type selects ``"incidence"``. The selection is drawn on a single
    axes when ``self.subgroup_colname`` is ``None`` and on one panel per
    subgroup otherwise.

    Returns:
        The figure produced by ``_plot_single`` or ``_plot_subgroups``.
    """
    estimate = "incidence"
    for known in ("risk", "survival"):
        if self.plot_type == known:
            estimate = known
            break

    selected = self.km_data.filter(pl.col("estimate") == estimate)

    draw = _plot_single if self.subgroup_colname is None else _plot_subgroups
    return draw(self, selected)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _plot_single(self, plot_data):
    """Draw every treatment curve on one axes and return the figure.

    Args:
        plot_data: Estimates to draw; passed through to ``_plot_data``.

    Returns:
        The matplotlib figure created for the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    _plot_data(self, plot_data, ax)

    # Compute the default title locally instead of writing it back to
    # ``self.plot_title``: storing it made the default sticky, so a later
    # change of ``plot_type`` kept the stale title.
    title = self.plot_title
    if title is None:
        title = f"Cumulative {self.plot_type.title()}"

    ax.set_xlabel("Followup")
    ax.set_ylabel(self.plot_type.title())
    ax.set_title(title)
    ax.legend()
    ax.grid()

    return fig
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _plot_subgroups(self, plot_data):
    """Draw one panel per subgroup value and return the figure.

    Panels are laid out in a grid of at most three columns, ordered by the
    sorted distinct values of ``self.subgroup_colname`` in ``plot_data``.
    Grid cells without a subgroup are hidden. The figure title is
    ``self.plot_title`` when it is truthy, otherwise a default built from
    ``self.plot_type``.

    Args:
        plot_data: Estimates to draw; filtered per subgroup and passed to
            ``_plot_data``.

    Returns:
        The matplotlib figure holding all panels.
    """
    levels = plot_data[self.subgroup_colname].unique().to_list()
    levels.sort()

    n_subgroups = len(levels)
    n_cols = min(3, n_subgroups)
    # Ceiling division: a partially filled last row still needs a row.
    n_rows, leftover = divmod(n_subgroups, n_cols)
    if leftover:
        n_rows += 1

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 6 * n_rows))
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array.
    axes = np.atleast_1d(axes).flatten()

    for ax, level in zip(axes, levels):
        panel_data = plot_data.filter(pl.col(self.subgroup_colname) == level)
        _plot_data(self, panel_data, ax)

        shown_level = str(level).title() if isinstance(level, str) else level

        ax.set_xlabel("Followup")
        ax.set_ylabel(self.plot_type.title())
        ax.set_title(
            f"{self.subgroup_colname.title()}: {shown_level}",
            fontsize=10,
            style="italic",
        )
        ax.legend()
        ax.grid()

    for spare in axes[n_subgroups:]:
        spare.set_visible(False)

    heading = self.plot_title or f"Cumulative {self.plot_type.title()}"
    fig.suptitle(heading, fontsize=14)

    plt.tight_layout()
    return fig
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _plot_data(self, plot_data, ax):
    """Plot one curve per treatment level onto ``ax``.

    For each value in ``self.treatment_level`` the matching rows of
    ``plot_data`` are drawn as ``pred`` against ``followup``. Levels with no
    rows are skipped. When both ``LCI`` and ``UCI`` columns are present a
    shaded band in the curve's colour is added and kept out of the legend.

    Labels come from ``self.plot_labels`` by the level's position, falling
    back to ``"treatment = <level>"``. Colours cycle through
    ``self.plot_colors`` when it is set, advancing only for levels that are
    actually drawn.

    Args:
        plot_data: Estimates to draw.
        ax: Axes to draw on.
    """
    palette = itertools.cycle(self.plot_colors) if self.plot_colors else None

    for position, level in enumerate(self.treatment_level):
        rows = plot_data.filter(pl.col(self.treatment_col) == level)
        if rows.is_empty():
            continue

        if self.plot_labels and position < len(self.plot_labels):
            label = self.plot_labels[position]
        else:
            label = f"treatment = {level}"

        color = None if palette is None else next(palette)

        drawn = ax.plot(rows["followup"], rows["pred"], "-", label=label, color=color)
        (curve,) = drawn

        if {"LCI", "UCI"}.issubset(rows.columns):
            ax.fill_between(
                rows["followup"],
                rows["LCI"],
                rows["UCI"],
                color=curve.get_color(),
                alpha=0.2,
                label="_nolegend_",
            )
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from ._weight_bind import _weight_bind as _weight_bind
|
|
2
|
+
from ._weight_data import _weight_setup as _weight_setup
|
|
3
|
+
from ._weight_fit import _fit_denominator as _fit_denominator
|
|
4
|
+
from ._weight_fit import _fit_LTFU as _fit_LTFU
|
|
5
|
+
from ._weight_fit import _fit_numerator as _fit_numerator
|
|
6
|
+
from ._weight_fit import _fit_visit as _fit_visit
|
|
7
|
+
from ._weight_pred import _weight_predict as _weight_predict
|
|
8
|
+
from ._weight_stats import _weight_stats as _weight_stats
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import polars as pl
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _weight_bind(self, WDT):
    """Join weight components onto ``self.DT`` and compute the final weight.

    ``WDT`` is joined to ``self.DT`` (inner join on id and ``period`` when
    ``self.weight_preexpansion`` is set, after renaming the time column to
    ``period``; left join on id, ``trial`` and ``followup`` otherwise). A
    per-row weight ``wt`` is ``numerator / denominator`` unless an override
    condition holds, in which case it is 1.0. ``weight`` is the cumulative
    product of ``wt`` within each id/trial (nulls treated as 1.0), multiplied
    by ``_cense`` and ``_visit`` (nulls treated as 1.0), which are then
    dropped.

    Args:
        WDT: Frame carrying ``numerator``, ``denominator``, ``_cense`` and
            ``_visit`` columns alongside the join keys.

    Returns:
        None. The result is stored in ``self.DT``.
    """
    if self.weight_preexpansion:
        how, keys = "inner", [self.id_col, "period"]
        WDT = WDT.rename({self.time_col: "period"})
    else:
        how, keys = "left", [self.id_col, "trial", "followup"]

    joined = self.DT.join(WDT, on=keys, how=how)

    per_trial = [self.id_col, "trial"]
    numerator = pl.col("numerator")
    denominator = pl.col("denominator")
    missing_outcome = pl.col(self.outcome_col).is_null()

    if self.visit_colname is None:
        no_visit = pl.lit(False)
    else:
        no_visit = pl.col(self.visit_colname) == 0

    if self.excused:
        # True from the first excused row onwards within each id/trial.
        ever_excused = (
            pl.col("isExcused").fill_null(False).cum_sum().over(per_trial) > 0
        )
        if self.weight_preexpansion:
            first_row = (pl.col("trial") == 0) & (pl.col("period") == 0)
            degenerate = [denominator < 1e-7]
        else:
            first_row = pl.col("followup") == 0
            degenerate = [denominator < 1e-7, numerator < 1e-7]
    else:
        first_row = (pl.col("trial") == pl.col("trial").min().over(self.id_col)) & (
            pl.col("followup") == 0
        )
        ever_excused = pl.lit(False)
        degenerate = [denominator < 1e-15, numerator.is_null()]

    override = first_row | ever_excused | no_visit | missing_outcome
    for condition in degenerate:
        override = override | condition

    row_weight = (
        pl.when(override)
        .then(pl.lit(1.0))
        .otherwise(numerator / denominator)
        .alias("wt")
    )
    running_weight = (
        pl.col("wt").fill_null(1.0).cum_prod().over(per_trial).alias("weight")
    )
    combined_weight = (
        pl.col("weight")
        * pl.col("_cense").fill_null(1.0)
        * pl.col("_visit").fill_null(1.0)
    ).alias("weight")

    weighted = joined.with_columns(row_weight)
    # Sort before the cumulative product so it runs in followup order.
    weighted = weighted.sort([self.id_col, "trial", "followup"])
    weighted = weighted.with_columns(running_weight)
    weighted = weighted.with_columns(combined_weight)

    self.DT = weighted.drop(["_cense", "_visit"])
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import polars as pl
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _weight_setup(self):
    """Build the frame used for weight models, adding a ``tx_lag`` column.

    ``tx_lag`` is the treatment column shifted by one row within a group,
    with ``self.treatment_level[0]`` filling the first row.

    * With ``self.weight_preexpansion`` the lag is taken per id on
      ``self.data`` (in its existing row order), and a squared time column is
      added.
    * Otherwise rows of ``self.DT`` with ``followup == 0`` get their lag from
      ``self.data`` (per id, ordered by time, matched on id and ``period``),
      while later rows get the lag per id/trial within ``self.DT``. The two
      parts are concatenated and sorted by id, ``trial`` and ``followup``.

    Returns:
        The frame described above.
    """
    expanded = self.DT
    raw = self.data
    shifted_tx = pl.col(self.treatment_col).shift(
        fill_value=self.treatment_level[0]
    )

    if self.weight_preexpansion:
        squared_time = (pl.col(self.time_col) ** 2).alias(
            f"{self.time_col}{self.indicator_squared}"
        )
        return raw.with_columns(
            shifted_tx.over(self.id_col).alias("tx_lag"),
            squared_time,
        )

    row_order = [self.id_col, "trial", "followup"]

    lag_by_period = raw.select([self.treatment_col, self.id_col, self.time_col])
    lag_by_period = lag_by_period.sort([self.id_col, self.time_col])
    lag_by_period = lag_by_period.with_columns(
        shifted_tx.over(self.id_col).alias("tx_lag")
    )
    lag_by_period = lag_by_period.drop(self.treatment_col)
    lag_by_period = lag_by_period.rename({self.time_col: "period"})

    baseline_rows = expanded.filter(pl.col("followup") == 0).join(
        lag_by_period, on=[self.id_col, "period"], how="inner"
    )

    later_rows = expanded.sort(row_order)
    later_rows = later_rows.with_columns(
        shifted_tx.over([self.id_col, "trial"]).alias("tx_lag")
    )
    later_rows = later_rows.filter(pl.col("followup") != 0)

    return pl.concat([baseline_rows, later_rows]).sort(row_order)
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import statsmodels.api as sm
|
|
2
|
+
import statsmodels.formula.api as smf
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def _fit_pair(
    self, WDT, outcome_attr, formula_attr, output_attrs, eligible_colname_attr=None
):
    """Fit one binomial GLM per formula and store each result on ``self``.

    Args:
        WDT: Data passed to the model; indexed with a boolean mask, so
            presumably a pandas DataFrame (confirm against the caller).
        outcome_attr: Name of the attribute on ``self`` holding the response
            column name.
        formula_attr: Iterable of right-hand-side formula strings.
        output_attrs: Iterable of attribute names; the fitted result for the
            formula at the same position is stored under each name.
        eligible_colname_attr: Optional name of an attribute on ``self``
            holding an eligibility column. When that column name is not
            ``None``, only rows where it equals 1 are used.

    Returns:
        None.
    """
    response = getattr(self, outcome_attr)

    eligible_col = None
    if eligible_colname_attr is not None:
        eligible_col = getattr(self, eligible_colname_attr)
    if eligible_col is not None:
        WDT = WDT[WDT[eligible_col] == 1]

    for rhs, target in zip(formula_attr, output_attrs):
        glm = smf.glm(f"{response}~{rhs}", WDT, family=sm.families.Binomial())
        fitted = glm.fit(disp=0)
        setattr(self, target, fitted)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _fit_LTFU(self, WDT):
    """Fit the censoring numerator and denominator models.

    Does nothing when ``self.cense_colname`` is ``None``. Otherwise the
    formulas currently held in ``self.cense_numerator`` and
    ``self.cense_denominator`` are fitted via ``_fit_pair`` and the fitted
    results are stored back under those same two attribute names, with
    rows restricted by ``self.cense_eligible_colname`` when it is set.

    Args:
        WDT: Data passed through to ``_fit_pair``.

    Returns:
        None.
    """
    if self.cense_colname is not None:
        formulas = [self.cense_numerator, self.cense_denominator]
        _fit_pair(
            self,
            WDT,
            outcome_attr="cense_colname",
            formula_attr=formulas,
            output_attrs=["cense_numerator", "cense_denominator"],
            eligible_colname_attr="cense_eligible_colname",
        )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _fit_visit(self, WDT):
    """Fit the visit numerator and denominator models.

    Does nothing when ``self.visit_colname`` is ``None``. Otherwise the
    formulas held in ``self.cense_numerator`` and ``self.cense_denominator``
    are fitted via ``_fit_pair`` against the visit column, and the results
    are stored as ``self.visit_numerator`` and ``self.visit_denominator``.

    Args:
        WDT: Data passed through to ``_fit_pair``.

    Returns:
        None.
    """
    if self.visit_colname is not None:
        # NOTE(review): _fit_LTFU stores fitted results under
        # cense_numerator / cense_denominator, so if it has already run these
        # attributes no longer hold formula strings - confirm the call order.
        formulas = [self.cense_numerator, self.cense_denominator]
        _fit_pair(
            self,
            WDT,
            outcome_attr="visit_colname",
            formula_attr=formulas,
            output_attrs=["visit_numerator", "visit_denominator"],
        )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _fit_numerator(self, WDT):
    """Fit one multinomial logit numerator model per treatment level.

    Nothing is fitted when ``self.method`` is ``"ITT"`` or when both
    ``self.weight_preexpansion`` and ``self.excused`` are set. Otherwise the
    response is ``switch`` when ``self.excused`` is set and the treatment
    column when it is not, modelled on ``self.numerator``. For the level at
    each position the rows are narrowed, in order, to:

    * rows where that position's excused column equals 0 (only when
      ``self.excused`` is set and the column name is not ``None``);
    * rows whose lag column equals the level (only when
      ``self.weight_lag_condition`` is set); the lag column is the baseline
      treatment column when excused, ``tx_lag`` otherwise;
    * rows where that position's weight-eligibility column equals 1 (when
      the column name is not ``None``).

    Args:
        WDT: Data to fit on; indexed with boolean masks, so presumably a
            pandas DataFrame (confirm against the caller).

    Returns:
        None. The list of fitted results is stored in
        ``self.numerator_model``.
    """
    if (self.weight_preexpansion and self.excused) or self.method == "ITT":
        return

    response = "switch" if self.excused else self.treatment_col
    formula = f"{response}~{self.numerator}"
    if self.excused:
        lag_col = f"{self.treatment_col}{self.indicator_baseline}"
    else:
        lag_col = "tx_lag"

    fitted = []
    for position, tx_level in enumerate(self.treatment_level):
        rows = WDT

        excused_col = self.excused_colnames[position] if self.excused else None
        if excused_col is not None:
            rows = rows[rows[excused_col] == 0]

        if self.weight_lag_condition:
            rows = rows[rows[lag_col] == tx_level]

        eligible_col = self.weight_eligible_colnames[position]
        if eligible_col is not None:
            rows = rows[rows[eligible_col] == 1]

        fitted.append(smf.mnlogit(formula, rows).fit(disp=0))

    self.numerator_model = fitted
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _fit_denominator(self, WDT):
    """Fit one multinomial logit denominator model per treatment level.

    Nothing is fitted when ``self.method`` is ``"ITT"``. Otherwise the
    response is ``switch`` when ``self.excused`` is set without
    ``self.weight_preexpansion`` and the treatment column in every other
    case, modelled on ``self.denominator``. For the level at each position
    the rows are narrowed, in order, to:

    * rows where that position's excused column equals 0 (only when
      ``self.excused`` is set and the column name is not ``None``);
    * rows whose ``tx_lag`` equals the level (only when
      ``self.weight_lag_condition`` is set);
    * rows with ``followup != 0`` (only without
      ``self.weight_preexpansion``);
    * rows where that position's weight-eligibility column equals 1 (when
      the column name is not ``None``).

    Args:
        WDT: Data to fit on; indexed with boolean masks, so presumably a
            pandas DataFrame (confirm against the caller).

    Returns:
        None. The list of fitted results is stored in
        ``self.denominator_model``.
    """
    if self.method == "ITT":
        return

    if self.excused and not self.weight_preexpansion:
        response = "switch"
    else:
        response = self.treatment_col
    formula = f"{response}~{self.denominator}"

    fitted = []
    for position, tx_level in enumerate(self.treatment_level):
        rows = WDT

        excused_col = self.excused_colnames[position] if self.excused else None
        if excused_col is not None:
            rows = rows[rows[excused_col] == 0]

        if self.weight_lag_condition:
            rows = rows[rows["tx_lag"] == tx_level]

        if not self.weight_preexpansion:
            rows = rows[rows["followup"] != 0]

        eligible_col = self.weight_eligible_colnames[position]
        if eligible_col is not None:
            rows = rows[rows[eligible_col] == 1]

        fitted.append(smf.mnlogit(formula, rows).fit(disp=0))

    self.denominator_model = fitted
|