pySEQTarget 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. pySEQTarget/SEQopts.py +197 -0
  2. pySEQTarget/SEQoutput.py +163 -0
  3. pySEQTarget/SEQuential.py +375 -0
  4. pySEQTarget/__init__.py +5 -0
  5. pySEQTarget/analysis/__init__.py +8 -0
  6. pySEQTarget/analysis/_hazard.py +211 -0
  7. pySEQTarget/analysis/_outcome_fit.py +75 -0
  8. pySEQTarget/analysis/_risk_estimates.py +136 -0
  9. pySEQTarget/analysis/_subgroup_fit.py +30 -0
  10. pySEQTarget/analysis/_survival_pred.py +372 -0
  11. pySEQTarget/data/__init__.py +19 -0
  12. pySEQTarget/error/__init__.py +2 -0
  13. pySEQTarget/error/_datachecker.py +38 -0
  14. pySEQTarget/error/_param_checker.py +50 -0
  15. pySEQTarget/expansion/__init__.py +5 -0
  16. pySEQTarget/expansion/_binder.py +98 -0
  17. pySEQTarget/expansion/_diagnostics.py +53 -0
  18. pySEQTarget/expansion/_dynamic.py +73 -0
  19. pySEQTarget/expansion/_mapper.py +44 -0
  20. pySEQTarget/expansion/_selection.py +31 -0
  21. pySEQTarget/helpers/__init__.py +8 -0
  22. pySEQTarget/helpers/_bootstrap.py +111 -0
  23. pySEQTarget/helpers/_col_string.py +6 -0
  24. pySEQTarget/helpers/_format_time.py +6 -0
  25. pySEQTarget/helpers/_output_files.py +167 -0
  26. pySEQTarget/helpers/_pad.py +7 -0
  27. pySEQTarget/helpers/_predict_model.py +9 -0
  28. pySEQTarget/helpers/_prepare_data.py +19 -0
  29. pySEQTarget/initialization/__init__.py +5 -0
  30. pySEQTarget/initialization/_censoring.py +53 -0
  31. pySEQTarget/initialization/_denominator.py +39 -0
  32. pySEQTarget/initialization/_numerator.py +37 -0
  33. pySEQTarget/initialization/_outcome.py +56 -0
  34. pySEQTarget/plot/__init__.py +1 -0
  35. pySEQTarget/plot/_survival_plot.py +104 -0
  36. pySEQTarget/weighting/__init__.py +8 -0
  37. pySEQTarget/weighting/_weight_bind.py +86 -0
  38. pySEQTarget/weighting/_weight_data.py +47 -0
  39. pySEQTarget/weighting/_weight_fit.py +99 -0
  40. pySEQTarget/weighting/_weight_pred.py +192 -0
  41. pySEQTarget/weighting/_weight_stats.py +23 -0
  42. pyseqtarget-0.10.0.dist-info/METADATA +98 -0
  43. pyseqtarget-0.10.0.dist-info/RECORD +46 -0
  44. pyseqtarget-0.10.0.dist-info/WHEEL +5 -0
  45. pyseqtarget-0.10.0.dist-info/licenses/LICENSE +21 -0
  46. pyseqtarget-0.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,53 @@
1
def _cense_numerator(self) -> str:
    """Return the right-hand side of the censoring-model numerator formula.

    Terms are joined with ``+``; components that are disabled or empty are
    omitted.

    * ``self.weight_preexpansion`` true: ``tx_lag``, the time column followed
      by ``<time_col><indicator_squared>``, then the fixed covariates.
    * otherwise: ``tx_lag``, the optional ``trial`` and ``followup`` terms
      (each followed by its squared companion), the fixed covariates, and the
      baseline copies ``<col><indicator_baseline>`` of the time-varying
      covariates.
    """
    squared = self.indicator_squared

    def with_square(name):
        # "x" -> "x+x<indicator_squared>"
        return "+".join([name, f"{name}{squared}"])

    trial_terms = with_square("trial") if self.trial_include else None
    followup_terms = with_square("followup") if self.followup_include else None
    time_terms = with_square(self.time_col)

    baseline_tv_terms = None
    if self.time_varying_cols:
        baseline_tv_terms = "+".join(
            f"{col}{self.indicator_baseline}" for col in self.time_varying_cols
        )

    fixed_terms = "+".join(self.fixed_cols) if self.fixed_cols else None

    if self.weight_preexpansion:
        chosen = ("tx_lag", time_terms, fixed_terms)
    else:
        chosen = (
            "tx_lag",
            trial_terms,
            followup_terms,
            fixed_terms,
            baseline_tv_terms,
        )
    return "+".join(term for term in chosen if term)
26
+
27
+
28
def _cense_denominator(self) -> str:
    """Return the right-hand side of the censoring-model denominator formula.

    Terms are joined with ``+``; components that are disabled or empty are
    omitted.

    * ``self.weight_preexpansion`` true: ``tx_lag``, the time column followed
      by ``<time_col><indicator_squared>``, the fixed covariates, then the
      time-varying covariates.
    * otherwise: ``tx_lag``, the optional ``trial`` and ``followup`` terms
      (each followed by its squared companion), the fixed covariates, the
      time-varying covariates and their baseline copies
      ``<col><indicator_baseline>``.
    """
    squared = self.indicator_squared

    def with_square(name):
        # "x" -> "x+x<indicator_squared>"
        return "+".join([name, f"{name}{squared}"])

    trial_terms = with_square("trial") if self.trial_include else None
    followup_terms = with_square("followup") if self.followup_include else None
    time_terms = with_square(self.time_col)

    current_tv_terms = None
    baseline_tv_terms = None
    if self.time_varying_cols:
        current_tv_terms = "+".join(self.time_varying_cols)
        baseline_tv_terms = "+".join(
            f"{col}{self.indicator_baseline}" for col in self.time_varying_cols
        )

    fixed_terms = "+".join(self.fixed_cols) if self.fixed_cols else None

    if self.weight_preexpansion:
        chosen = ("tx_lag", time_terms, fixed_terms, current_tv_terms)
    else:
        chosen = (
            "tx_lag",
            trial_terms,
            followup_terms,
            fixed_terms,
            current_tv_terms,
            baseline_tv_terms,
        )
    return "+".join(term for term in chosen if term)
@@ -0,0 +1,39 @@
1
def _denominator(self) -> "str | None":
    """Return the right-hand side of the treatment-weight denominator formula.

    Returns ``None`` for ``method == "ITT"``, for which no treatment weights
    are modelled. For ``"dose-response"`` and ``"censoring"`` the formula does
    not depend on ``method`` or ``excused``:

    * pre-expansion weighting: fixed covariates, time-varying covariates, then
      the time column and its squared companion;
    * otherwise: fixed covariates, time-varying covariates, their baseline
      copies ``<col><indicator_baseline>``, then the optional ``followup`` and
      ``trial`` terms (each followed by its squared companion).

    Terms are joined with ``+``; disabled or empty components are omitted.

    Raises:
        ValueError: if ``self.method`` is not ``"ITT"``, ``"dose-response"``
            or ``"censoring"``.
    """
    if self.method == "ITT":
        return None
    if self.method not in ("dose-response", "censoring"):
        # Fail with a clear message instead of an unbound local variable.
        raise ValueError(
            f"Unsupported method {self.method!r}; expected 'ITT', "
            "'dose-response' or 'censoring'."
        )

    trial = (
        "+".join(["trial", f"trial{self.indicator_squared}"])
        if self.trial_include
        else None
    )
    followup = (
        "+".join(["followup", f"followup{self.indicator_squared}"])
        if self.followup_include
        else None
    )
    time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"])

    tv = "+".join(self.time_varying_cols) if self.time_varying_cols else None
    tv_bas = (
        "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols])
        if self.time_varying_cols
        else None
    )
    fixed = "+".join(self.fixed_cols) if self.fixed_cols else None

    if self.weight_preexpansion:
        parts = [fixed, tv, time]
    else:
        parts = [fixed, tv, tv_bas, followup, trial]
    return "+".join(filter(None, parts))
@@ -0,0 +1,37 @@
1
def _numerator(self) -> "str | None":
    """Return the right-hand side of the treatment-weight numerator formula.

    Returns ``None`` when no numerator model is used. That is the case for
    ``method == "ITT"``, and for excused ``"censoring"`` with pre-expansion
    weighting. Otherwise the formula does not depend on ``method`` or
    ``excused``:

    * pre-expansion weighting: fixed covariates, then the time column and its
      squared companion;
    * otherwise: fixed covariates, baseline copies
      ``<col><indicator_baseline>`` of the time-varying covariates, then the
      optional ``followup`` and ``trial`` terms (each followed by its squared
      companion).

    Terms are joined with ``+``; disabled or empty components are omitted.

    Raises:
        ValueError: if ``self.method`` is not ``"ITT"``, ``"dose-response"``
            or ``"censoring"``.
    """
    if self.method == "ITT":
        return None
    if self.method not in ("dose-response", "censoring"):
        # Fail with a clear message instead of an unbound local variable.
        raise ValueError(
            f"Unsupported method {self.method!r}; expected 'ITT', "
            "'dose-response' or 'censoring'."
        )

    trial = (
        "+".join(["trial", f"trial{self.indicator_squared}"])
        if self.trial_include
        else None
    )
    followup = (
        "+".join(["followup", f"followup{self.indicator_squared}"])
        if self.followup_include
        else None
    )
    time = "+".join([self.time_col, f"{self.time_col}{self.indicator_squared}"])

    tv_bas = (
        "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols])
        if self.time_varying_cols
        else None
    )
    fixed = "+".join(self.fixed_cols) if self.fixed_cols else None

    if self.weight_preexpansion:
        if self.method == "censoring" and self.excused:
            return None
        parts = [fixed, time]
    else:
        parts = [fixed, tv_bas, followup, trial]
    return "+".join(filter(None, parts))
@@ -0,0 +1,56 @@
1
def _outcome(self) -> str:
    """Return the right-hand side of the outcome model formula.

    The terms are joined with ``+``, and empty ones are dropped. They are
    drawn from:

    * the baseline treatment ``<treatment_col><indicator_baseline>``, or
      ``dose`` and its squared term for ``"dose-response"``;
    * follow-up and trial terms;
    * fixed covariates;
    * baseline copies of the time-varying covariates;
    * a follow-up interaction.

    Interaction terms are dropped when ``self.hazard`` is set or
    ``self.km_curves`` is not. With weighting and pre-expansion weights, the
    time-varying baseline terms are left out. For excused ``"censoring"`` the
    fixed covariates are left out as well.

    The baseline time-varying terms use ``self.indicator_baseline`` as their
    suffix, the same as the treatment term and the weight formulas. The
    suffix was previously hard-coded to ``"_bas"``, which referenced
    non-existent columns whenever a custom suffix was configured.

    Raises:
        ValueError: if ``self.method`` is not ``"ITT"``, ``"dose-response"``
            or ``"censoring"``.
    """
    tx_bas = f"{self.treatment_col}{self.indicator_baseline}"
    dose = "+".join(["dose", f"dose{self.indicator_squared}"])
    interaction = f"{tx_bas}*followup"
    interaction_dose = "+".join(
        ["followup*dose", f"followup*dose{self.indicator_squared}"]
    )

    if self.hazard or not self.km_curves:
        interaction = interaction_dose = None

    tv_bas = (
        "+".join([f"{v}{self.indicator_baseline}" for v in self.time_varying_cols])
        if self.time_varying_cols
        else None
    )
    fixed = "+".join(self.fixed_cols) if self.fixed_cols else None
    trial = (
        "+".join(["trial", f"trial{self.indicator_squared}"])
        if self.trial_include
        else None
    )

    if self.followup_include:
        followup = "+".join(["followup", f"followup{self.indicator_squared}"])
    elif self.followup_spline or self.followup_class:
        # Only the bare column is listed; presumably the spline/class
        # transform is applied elsewhere — confirm against the caller.
        followup = "followup"
    else:
        followup = None

    if self.method == "ITT":
        parts = [tx_bas, followup, trial, fixed, tv_bas, interaction]
    elif self.method == "dose-response":
        if self.weighted and self.weight_preexpansion:
            parts = [dose, followup, trial, fixed, interaction_dose]
        else:
            parts = [dose, followup, trial, fixed, tv_bas, interaction_dose]
    elif self.method == "censoring":
        if self.weighted and self.weight_preexpansion:
            if self.excused:
                parts = [tx_bas, followup, trial, interaction]
            else:
                parts = [tx_bas, followup, trial, fixed, interaction]
        else:
            parts = [tx_bas, followup, trial, fixed, tv_bas, interaction]
    else:
        raise ValueError(
            f"Unsupported method {self.method!r}; expected 'ITT', "
            "'dose-response' or 'censoring'."
        )

    return "+".join(filter(None, parts))
@@ -0,0 +1 @@
1
+ from ._survival_plot import _survival_plot as _survival_plot
@@ -0,0 +1,104 @@
1
+ import itertools
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import polars as pl
6
+
7
+
8
def _survival_plot(self):
    """Plot the cumulative estimate selected by ``self.plot_type``.

    Rows of ``self.km_data`` whose ``estimate`` column is ``"risk"`` or
    ``"survival"`` are selected when ``plot_type`` equals that name. Any
    other ``plot_type`` selects the ``"incidence"`` rows.

    A single panel is drawn when ``self.subgroup_colname`` is None; otherwise
    one panel is drawn per subgroup. Returns the figure built by the drawing
    helper.
    """
    estimate = next(
        (name for name in ("risk", "survival") if self.plot_type == name),
        "incidence",
    )
    plot_data = self.km_data.filter(pl.col("estimate") == estimate)

    draw = _plot_single if self.subgroup_colname is None else _plot_subgroups
    return draw(self, plot_data)
22
+
23
+
24
def _plot_single(self, plot_data):
    """Draw every treatment curve on one axes and return the figure.

    The title is ``self.plot_title`` unless that is ``None``, in which case
    ``"Cumulative <Plot_type>"`` is used. The default is kept in a local
    variable rather than written back to ``self.plot_title``. Persisting it
    would make a later call with a different ``plot_type`` reuse the stale
    title.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    _plot_data(self, plot_data, ax)

    title = self.plot_title
    if title is None:
        title = f"Cumulative {self.plot_type.title()}"

    ax.set_xlabel("Followup")
    ax.set_ylabel(self.plot_type.title())
    ax.set_title(title)
    ax.legend()
    ax.grid()

    return fig
38
+
39
+
40
def _plot_subgroups(self, plot_data):
    """Draw one panel per subgroup value and return the figure.

    Subgroup values are the sorted unique values of ``self.subgroup_colname``.
    Panels are laid out at most three per row, and unused grid cells are
    hidden. The figure title is ``self.plot_title`` when it is truthy,
    otherwise ``"Cumulative <Plot_type>"``.

    When ``plot_data`` contains no subgroup values, an empty figure carrying
    only the title is returned. Previously this case failed with
    ``ZeroDivisionError`` in the row computation.
    """
    subgroups = sorted(plot_data[self.subgroup_colname].unique().to_list())
    n_subgroups = len(subgroups)
    # Clamp to at least a 1x1 grid so the layout arithmetic and
    # plt.subplots stay valid when there is nothing to plot.
    n_cols = max(1, min(3, n_subgroups))
    n_rows = max(1, (n_subgroups + n_cols - 1) // n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(7 * n_cols, 6 * n_rows))
    axes = np.atleast_1d(axes).flatten()

    for ax, subgroup_val in zip(axes, subgroups):
        subgroup_data = plot_data.filter(pl.col(self.subgroup_colname) == subgroup_val)
        _plot_data(self, subgroup_data, ax)
        subgroup_label = (
            subgroup_val.title() if isinstance(subgroup_val, str) else subgroup_val
        )
        ax.set_xlabel("Followup")
        ax.set_ylabel(self.plot_type.title())
        ax.set_title(
            f"{self.subgroup_colname.title()}: {subgroup_label}",
            fontsize=10,
            style="italic",
        )
        ax.legend()
        ax.grid()

    # Hide grid cells beyond the last subgroup.
    for ax in axes[n_subgroups:]:
        ax.set_visible(False)

    title = self.plot_title or f"Cumulative {self.plot_type.title()}"
    fig.suptitle(title, fontsize=14)

    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    return fig
76
+
77
+
78
def _plot_data(self, plot_data, ax):
    """Draw one curve per treatment level of ``plot_data`` onto ``ax``.

    For each level in ``self.treatment_level``, the rows whose treatment
    column equals that level are plotted as ``pred`` against ``followup``.
    Levels with no rows are skipped. When ``LCI`` and ``UCI`` columns exist,
    a shaded band in the line's colour is added and kept out of the legend.

    The label is ``self.plot_labels[idx]`` when available, otherwise
    ``"treatment = <level>"``. The colour comes from ``self.plot_colors``
    indexed by the same ``idx``, wrapping around. A level therefore keeps
    its colour across subgroup panels even when another level is missing
    from a panel. The previous cycle advanced only for plotted levels,
    which shifted the colours whenever a level was skipped.
    """
    palette = list(self.plot_colors) if self.plot_colors else []

    for idx, level in enumerate(self.treatment_level):
        subset = plot_data.filter(pl.col(self.treatment_col) == level)
        if subset.is_empty():
            continue

        label = f"treatment = {level}"
        if self.plot_labels and idx < len(self.plot_labels):
            label = self.plot_labels[idx]

        # Index by treatment position, not by the number of curves drawn so
        # far, so that skipped levels do not shift the palette.
        color = palette[idx % len(palette)] if palette else None

        (line,) = ax.plot(
            subset["followup"], subset["pred"], "-", label=label, color=color
        )

        if "LCI" in subset.columns and "UCI" in subset.columns:
            ax.fill_between(
                subset["followup"],
                subset["LCI"],
                subset["UCI"],
                color=line.get_color(),
                alpha=0.2,
                label="_nolegend_",
            )
@@ -0,0 +1,8 @@
1
+ from ._weight_bind import _weight_bind as _weight_bind
2
+ from ._weight_data import _weight_setup as _weight_setup
3
+ from ._weight_fit import _fit_denominator as _fit_denominator
4
+ from ._weight_fit import _fit_LTFU as _fit_LTFU
5
+ from ._weight_fit import _fit_numerator as _fit_numerator
6
+ from ._weight_fit import _fit_visit as _fit_visit
7
+ from ._weight_pred import _weight_predict as _weight_predict
8
+ from ._weight_stats import _weight_stats as _weight_stats
@@ -0,0 +1,86 @@
1
+ import polars as pl
2
+
3
+
4
def _weight_bind(self, WDT):
    """Join weight components onto ``self.DT`` and compute the ``weight`` column.

    How ``WDT`` is joined onto ``self.DT``:

    * pre-expansion weighting: ``self.time_col`` is renamed to ``"period"``,
      then an inner join on ``[id, "period"]``;
    * otherwise: a left join on ``[id, "trial", "followup"]``.

    The joined frame must provide the columns read below: ``numerator``,
    ``denominator``, ``_cense``, ``_visit``, the outcome column, ``trial``
    and ``followup``. The excused paths also need ``isExcused``, and the
    pre-expansion excused path also needs ``period``.
    NOTE(review): presumably these come from the weight-prediction step —
    confirm against the caller.

    How the weight is computed:

    * Per row, ``wt = numerator / denominator``, unless an override condition
      holds, in which case ``wt = 1.0``.
    * ``weight`` is the cumulative product of ``wt`` (nulls treated as 1.0)
      within each ``(id, trial)``.
    * ``weight`` is then multiplied by ``_cense`` and ``_visit`` (nulls
      treated as 1.0), and those two columns are dropped.

    The result replaces ``self.DT``; nothing is returned.

    The joined frame is sorted by ``(id, trial, followup)`` *before* the
    override is evaluated. The excused override uses a running ``cum_sum``
    of ``isExcused``, which depends on row order. It was previously computed
    on the raw join output, whose row order is not guaranteed (notably for
    the inner join), and the frame was only sorted afterwards.
    """
    if self.weight_preexpansion:
        join = "inner"
        on = [self.id_col, "period"]
        WDT = WDT.rename({self.time_col: "period"})
    else:
        join = "left"
        on = [self.id_col, "trial", "followup"]

    WDT = self.DT.join(WDT, on=on, how=join).sort([self.id_col, "trial", "followup"])

    # Rows whose visit column equals 0 always get wt = 1.0.
    if self.visit_colname is not None:
        visit = pl.col(self.visit_colname) == 0
    else:
        visit = pl.lit(False)

    # True from the first row flagged isExcused onwards within (id, trial).
    # This relies on the (id, trial, followup) ordering established above.
    ever_excused = (
        pl.col("isExcused").fill_null(False).cum_sum().over([self.id_col, "trial"])
        > 0
    )

    if self.weight_preexpansion and self.excused:
        first_row = (pl.col("trial") == 0) & (pl.col("period") == 0)
        override = (
            first_row
            | ever_excused
            | visit
            | pl.col(self.outcome_col).is_null()
            | (pl.col("denominator") < 1e-7)
        )
    elif self.excused:
        first_row = pl.col("followup") == 0
        override = (
            first_row
            | ever_excused
            | visit
            | pl.col(self.outcome_col).is_null()
            | (pl.col("denominator") < 1e-7)
            | (pl.col("numerator") < 1e-7)
        )
    else:
        # First follow-up row of each id's earliest trial.
        first_row = (pl.col("trial") == pl.col("trial").min().over(self.id_col)) & (
            pl.col("followup") == 0
        )
        # NOTE(review): this branch uses a 1e-15 denominator threshold while
        # the excused branches use 1e-7 — confirm this is intentional.
        override = (
            first_row
            | visit
            | pl.col(self.outcome_col).is_null()
            | (pl.col("denominator") < 1e-15)
            | pl.col("numerator").is_null()
        )

    self.DT = (
        WDT.with_columns(
            pl.when(override)
            .then(pl.lit(1.0))
            .otherwise(pl.col("numerator") / pl.col("denominator"))
            .alias("wt")
        )
        .with_columns(
            pl.col("wt")
            .fill_null(1.0)
            .cum_prod()
            .over([self.id_col, "trial"])
            .alias("weight")
        )
        .with_columns(
            (
                pl.col("weight")
                * pl.col("_cense").fill_null(1.0)
                * pl.col("_visit").fill_null(1.0)
            ).alias("weight")
        )
        .drop(["_cense", "_visit"])
    )
@@ -0,0 +1,47 @@
1
+ import polars as pl
2
+
3
+
4
def _weight_setup(self):
    """Build the frame used to fit the weight models and return it.

    Adds a ``tx_lag`` column holding the previous row's treatment within a
    group. The first row of each group is filled with
    ``self.treatment_level[0]``.

    * Not pre-expansion:
      * Rows of ``self.DT`` with ``followup == 0`` take their lag from the
        original ``self.data``. That data is sorted by id and time, lagged
        within id, and matched on id and ``period``.
      * Later follow-up rows are lagged within ``(id, trial)``.
      * The result is sorted by id, trial and followup.
    * Pre-expansion: ``self.data`` is lagged within id and gains a squared
      time column ``<time_col><indicator_squared>``.

    The pre-expansion branch used to shift without sorting, so ``tx_lag``
    depended on the incoming row order. It now sorts by ``(id, time)`` first,
    matching how the baseline lag is computed in the other branch.
    """
    DT = self.DT
    data = self.data
    if not self.weight_preexpansion:
        # Lag of treatment in the unexpanded data, per id in time order;
        # joined onto the followup == 0 rows below.
        baseline_lag = (
            data.select([self.treatment_col, self.id_col, self.time_col])
            .sort([self.id_col, self.time_col])
            .with_columns(
                pl.col(self.treatment_col)
                .shift(fill_value=self.treatment_level[0])
                .over(self.id_col)
                .alias("tx_lag")
            )
            .drop(self.treatment_col)
            .rename({self.time_col: "period"})
        )

        fup0 = DT.filter(pl.col("followup") == 0).join(
            baseline_lag, on=[self.id_col, "period"], how="inner"
        )

        # Later follow-up rows are lagged within their own (id, trial).
        fup = (
            DT.sort([self.id_col, "trial", "followup"])
            .with_columns(
                pl.col(self.treatment_col)
                .shift(fill_value=self.treatment_level[0])
                .over([self.id_col, "trial"])
                .alias("tx_lag")
            )
            .filter(pl.col("followup") != 0)
        )

        WDT = pl.concat([fup0, fup]).sort([self.id_col, "trial", "followup"])
    else:
        # Sort so shift() lags by time within each id regardless of the
        # incoming row order.
        WDT = data.sort([self.id_col, self.time_col]).with_columns(
            pl.col(self.treatment_col)
            .shift(fill_value=self.treatment_level[0])
            .over(self.id_col)
            .alias("tx_lag"),
            (pl.col(self.time_col) ** 2).alias(
                f"{self.time_col}{self.indicator_squared}"
            ),
        )
    return WDT
@@ -0,0 +1,99 @@
1
+ import statsmodels.api as sm
2
+ import statsmodels.formula.api as smf
3
+
4
+
5
def _fit_pair(
    self, WDT, outcome_attr, formula_attr, output_attrs, eligible_colname_attr=None
):
    """Fit binomial GLMs for several right-hand sides and store them on ``self``.

    Args:
        WDT: Model frame. It is filtered with a boolean mask and passed to
            the statsmodels formula API (presumably a pandas DataFrame —
            confirm against the caller).
        outcome_attr: Name of the attribute on ``self`` that holds the
            response column name.
        formula_attr: Right-hand-side formula strings. Despite the parameter
            name, these are the strings themselves, not attribute names.
        output_attrs: Attribute names, paired positionally with
            ``formula_attr``, under which each fitted result is stored.
        eligible_colname_attr: Optional name of an attribute on ``self``.
            When that attribute is not None, only rows where its column
            equals 1 are used.
    """
    response = getattr(self, outcome_attr)

    eligible_col = None
    if eligible_colname_attr is not None:
        eligible_col = getattr(self, eligible_colname_attr)
    if eligible_col is not None:
        WDT = WDT[WDT[eligible_col] == 1]

    for rhs, target in zip(formula_attr, output_attrs):
        fitted = smf.glm(
            f"{response}~{rhs}", WDT, family=sm.families.Binomial()
        ).fit(disp=0)
        setattr(self, target, fitted)
19
+
20
+
21
def _fit_LTFU(self, WDT):
    """Fit the censoring numerator/denominator GLMs when a censor column is set.

    Does nothing when ``self.cense_colname`` is None. Otherwise it models
    ``self.cense_colname`` on the formulas currently held in
    ``self.cense_numerator`` and ``self.cense_denominator``. Only rows with
    ``cense_eligible_colname == 1`` are used, when that attribute is set.
    The fitted results are stored back under the same two attribute names,
    replacing the formula strings.
    """
    if self.cense_colname is not None:
        _fit_pair(
            self,
            WDT,
            "cense_colname",
            [self.cense_numerator, self.cense_denominator],
            ["cense_numerator", "cense_denominator"],
            eligible_colname_attr="cense_eligible_colname",
        )
32
+
33
+
34
def _fit_visit(self, WDT):
    """Fit the visit numerator/denominator GLMs when a visit column is set.

    Does nothing when ``self.visit_colname`` is None. Otherwise it models
    ``self.visit_colname`` on the censoring formulas
    ``self.cense_numerator`` / ``self.cense_denominator``. The fitted results
    are stored as ``self.visit_numerator`` and ``self.visit_denominator``.
    """
    if self.visit_colname is None:
        return
    # NOTE(review): _fit_LTFU stores its fitted models under
    # cense_numerator / cense_denominator. If it has already run with a
    # censor column set, these attributes are no longer formula strings —
    # confirm the call order in the caller.
    formulas = [self.cense_numerator, self.cense_denominator]
    _fit_pair(self, WDT, "visit_colname", formulas, ["visit_numerator", "visit_denominator"])
44
+
45
+
46
def _fit_numerator(self, WDT):
    """Fit one multinomial-logit numerator model per treatment level.

    Nothing is fitted for excused pre-expansion weighting or for
    ``method == "ITT"``.

    The model is ``<response>~<self.numerator>``. The response is
    ``switch`` when ``excused`` is set, otherwise the treatment column.

    For the i-th treatment level, rows are restricted in this order:

    * ``excused_colnames[i] == 0``, when excused and that column is set;
    * ``<lag column> == level``, when ``weight_lag_condition`` is set. The
      lag column is ``<treatment_col><indicator_baseline>`` when excused,
      otherwise ``tx_lag``;
    * ``weight_eligible_colnames[i] == 1``, when that column is set.

    The fits are stored as a list, in treatment-level order, in
    ``self.numerator_model``.
    """
    if (self.weight_preexpansion and self.excused) or self.method == "ITT":
        return

    response = "switch" if self.excused else self.treatment_col
    formula = f"{response}~{self.numerator}"
    lag_col = (
        f"{self.treatment_col}{self.indicator_baseline}" if self.excused else "tx_lag"
    )

    def rows_for(i, level):
        # Apply the per-level row filters described in the docstring.
        rows = WDT
        if self.excused and self.excused_colnames[i] is not None:
            rows = rows[rows[self.excused_colnames[i]] == 0]
        if self.weight_lag_condition:
            rows = rows[rows[lag_col] == level]
        eligible_col = self.weight_eligible_colnames[i]
        if eligible_col is not None:
            rows = rows[rows[eligible_col] == 1]
        return rows

    self.numerator_model = [
        smf.mnlogit(formula, rows_for(i, level)).fit(disp=0)
        for i, level in enumerate(self.treatment_level)
    ]
71
+
72
+
73
def _fit_denominator(self, WDT):
    """Fit one multinomial-logit denominator model per treatment level.

    Nothing is fitted for ``method == "ITT"``.

    The model is ``<response>~<self.denominator>``. The response is
    ``switch`` when ``excused`` is set without pre-expansion weighting,
    otherwise the treatment column.

    For the i-th treatment level, rows are restricted in this order:

    * ``excused_colnames[i] == 0``, when excused and that column is set;
    * ``tx_lag == level``, when ``weight_lag_condition`` is set;
    * ``followup != 0``, when not using pre-expansion weighting;
    * ``weight_eligible_colnames[i] == 1``, when that column is set.

    The fits are stored as a list, in treatment-level order, in
    ``self.denominator_model``.
    """
    if self.method == "ITT":
        return

    use_switch = self.excused and not self.weight_preexpansion
    response = "switch" if use_switch else self.treatment_col
    formula = f"{response}~{self.denominator}"

    def rows_for(i, level):
        # Apply the per-level row filters described in the docstring.
        rows = WDT
        if self.excused and self.excused_colnames[i] is not None:
            rows = rows[rows[self.excused_colnames[i]] == 0]
        if self.weight_lag_condition:
            rows = rows[rows["tx_lag"] == level]
        if not self.weight_preexpansion:
            rows = rows[rows["followup"] != 0]
        eligible_col = self.weight_eligible_colnames[i]
        if eligible_col is not None:
            rows = rows[rows[eligible_col] == 1]
        return rows

    self.denominator_model = [
        smf.mnlogit(formula, rows_for(i, level)).fit(disp=0)
        for i, level in enumerate(self.treatment_level)
    ]