pySEQTarget 0.10.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pySEQTarget/SEQopts.py +197 -0
- pySEQTarget/SEQoutput.py +163 -0
- pySEQTarget/SEQuential.py +375 -0
- pySEQTarget/__init__.py +5 -0
- pySEQTarget/analysis/__init__.py +8 -0
- pySEQTarget/analysis/_hazard.py +211 -0
- pySEQTarget/analysis/_outcome_fit.py +75 -0
- pySEQTarget/analysis/_risk_estimates.py +136 -0
- pySEQTarget/analysis/_subgroup_fit.py +30 -0
- pySEQTarget/analysis/_survival_pred.py +372 -0
- pySEQTarget/data/__init__.py +19 -0
- pySEQTarget/error/__init__.py +2 -0
- pySEQTarget/error/_datachecker.py +38 -0
- pySEQTarget/error/_param_checker.py +50 -0
- pySEQTarget/expansion/__init__.py +5 -0
- pySEQTarget/expansion/_binder.py +98 -0
- pySEQTarget/expansion/_diagnostics.py +53 -0
- pySEQTarget/expansion/_dynamic.py +73 -0
- pySEQTarget/expansion/_mapper.py +44 -0
- pySEQTarget/expansion/_selection.py +31 -0
- pySEQTarget/helpers/__init__.py +8 -0
- pySEQTarget/helpers/_bootstrap.py +111 -0
- pySEQTarget/helpers/_col_string.py +6 -0
- pySEQTarget/helpers/_format_time.py +6 -0
- pySEQTarget/helpers/_output_files.py +167 -0
- pySEQTarget/helpers/_pad.py +7 -0
- pySEQTarget/helpers/_predict_model.py +9 -0
- pySEQTarget/helpers/_prepare_data.py +19 -0
- pySEQTarget/initialization/__init__.py +5 -0
- pySEQTarget/initialization/_censoring.py +53 -0
- pySEQTarget/initialization/_denominator.py +39 -0
- pySEQTarget/initialization/_numerator.py +37 -0
- pySEQTarget/initialization/_outcome.py +56 -0
- pySEQTarget/plot/__init__.py +1 -0
- pySEQTarget/plot/_survival_plot.py +104 -0
- pySEQTarget/weighting/__init__.py +8 -0
- pySEQTarget/weighting/_weight_bind.py +86 -0
- pySEQTarget/weighting/_weight_data.py +47 -0
- pySEQTarget/weighting/_weight_fit.py +99 -0
- pySEQTarget/weighting/_weight_pred.py +192 -0
- pySEQTarget/weighting/_weight_stats.py +23 -0
- pyseqtarget-0.10.0.dist-info/METADATA +98 -0
- pyseqtarget-0.10.0.dist-info/RECORD +46 -0
- pyseqtarget-0.10.0.dist-info/WHEEL +5 -0
- pyseqtarget-0.10.0.dist-info/licenses/LICENSE +21 -0
- pyseqtarget-0.10.0.dist-info/top_level.txt +1 -0
pySEQTarget/analysis/_risk_estimates.py
@@ -0,0 +1,136 @@
import polars as pl
from scipy import stats


def _risk_estimates(self):
    last_followup = self.km_data["followup"].max()
    risk = self.km_data.filter(
        (pl.col("followup") == last_followup) & (pl.col("estimate") == "risk")
    )

    group_cols = [self.subgroup_colname] if self.subgroup_colname else []
    rd_comparisons = []
    rr_comparisons = []

    if self.bootstrap_nboot > 0:
        alpha = 1 - self.bootstrap_CI
        z = stats.norm.ppf(1 - alpha / 2)

    for tx_x in self.treatment_level:
        for tx_y in self.treatment_level:
            if tx_x == tx_y:
                continue

            risk_x = (
                risk.filter(pl.col("tx_init") == tx_x)
                .select(group_cols + ["pred"])
                .rename({"pred": "risk_x"})
            )

            risk_y = (
                risk.filter(pl.col("tx_init") == tx_y)
                .select(group_cols + ["pred"])
                .rename({"pred": "risk_y"})
            )

            if group_cols:
                comp = risk_x.join(risk_y, on=group_cols, how="left")
            else:
                comp = risk_x.join(risk_y, how="cross")

            comp = comp.with_columns(
                [pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
            )

            if self.bootstrap_nboot > 0:
                se_x = (
                    risk.filter(pl.col("tx_init") == tx_x)
                    .select(group_cols + ["SE"])
                    .rename({"SE": "se_x"})
                )

                se_y = (
                    risk.filter(pl.col("tx_init") == tx_y)
                    .select(group_cols + ["SE"])
                    .rename({"SE": "se_y"})
                )

                if group_cols:
                    comp = comp.join(se_x, on=group_cols, how="left")
                    comp = comp.join(se_y, on=group_cols, how="left")
                else:
                    comp = comp.join(se_x, how="cross")
                    comp = comp.join(se_y, how="cross")

                rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
                rd_comp = comp.with_columns(
                    [
                        (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference"),
                        (pl.col("risk_x") - pl.col("risk_y") - z * rd_se).alias(
                            "RD 95% LCI"
                        ),
                        (pl.col("risk_x") - pl.col("risk_y") + z * rd_se).alias(
                            "RD 95% UCI"
                        ),
                    ]
                )
                rd_comp = rd_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
                col_order = group_cols + [
                    "A_x",
                    "A_y",
                    "Risk Difference",
                    "RD 95% LCI",
                    "RD 95% UCI",
                ]
                rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
                rd_comparisons.append(rd_comp)

                rr_log_se = (
                    (pl.col("se_x") / pl.col("risk_x")).pow(2)
                    + (pl.col("se_y") / pl.col("risk_y")).pow(2)
                ).sqrt()
                rr_comp = comp.with_columns(
                    [
                        (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio"),
                        (
                            (pl.col("risk_x") / pl.col("risk_y"))
                            * (-z * rr_log_se).exp()
                        ).alias("RR 95% LCI"),
                        (
                            (pl.col("risk_x") / pl.col("risk_y"))
                            * (z * rr_log_se).exp()
                        ).alias("RR 95% UCI"),
                    ]
                )
                rr_comp = rr_comp.drop(["risk_x", "risk_y", "se_x", "se_y"])
                col_order = group_cols + [
                    "A_x",
                    "A_y",
                    "Risk Ratio",
                    "RR 95% LCI",
                    "RR 95% UCI",
                ]
                rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
                rr_comparisons.append(rr_comp)

            else:
                rd_comp = comp.with_columns(
                    (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference")
                )
                rd_comp = rd_comp.drop(["risk_x", "risk_y"])
                col_order = group_cols + ["A_x", "A_y", "Risk Difference"]
                rd_comp = rd_comp.select([c for c in col_order if c in rd_comp.columns])
                rd_comparisons.append(rd_comp)

                rr_comp = comp.with_columns(
                    (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio")
                )
                rr_comp = rr_comp.drop(["risk_x", "risk_y"])
                col_order = group_cols + ["A_x", "A_y", "Risk Ratio"]
                rr_comp = rr_comp.select([c for c in col_order if c in rr_comp.columns])
                rr_comparisons.append(rr_comp)

    risk_difference = pl.concat(rd_comparisons) if rd_comparisons else pl.DataFrame()
    risk_ratio = pl.concat(rr_comparisons) if rr_comparisons else pl.DataFrame()

    return {"risk_difference": risk_difference, "risk_ratio": risk_ratio}
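
For readers checking the arithmetic, the intervals built above follow standard normal approximations: the risk-difference SE combines the two bootstrap SEs in quadrature, and the risk-ratio interval is formed with a delta-method SE on the log scale. A self-contained sketch of the same calculations with made-up numbers (illustrative values only, not package output):

from math import exp, sqrt

from scipy.stats import norm

# Illustrative end-of-follow-up risks and bootstrap SEs (assumed values)
risk_x, se_x = 0.30, 0.030
risk_y, se_y = 0.22, 0.025
z = norm.ppf(0.975)  # corresponds to bootstrap_CI = 0.95

# Risk difference: SEs combined in quadrature
rd = risk_x - risk_y
rd_se = sqrt(se_x**2 + se_y**2)
rd_lci, rd_uci = rd - z * rd_se, rd + z * rd_se

# Risk ratio: interval built on the log scale, then exponentiated
rr = risk_x / risk_y
rr_log_se = sqrt((se_x / risk_x) ** 2 + (se_y / risk_y) ** 2)
rr_lci, rr_uci = rr * exp(-z * rr_log_se), rr * exp(z * rr_log_se)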

pySEQTarget/analysis/_subgroup_fit.py
@@ -0,0 +1,30 @@
import polars as pl

from ._outcome_fit import _outcome_fit


def _subgroup_fit(self):
    subgroups = sorted(self.DT[self.subgroup_colname].unique().to_list())
    self._unique_subgroups = subgroups

    models_list = []
    for val in subgroups:
        subDT = self.DT.filter(pl.col(self.subgroup_colname) == val)

        models = {
            "outcome": _outcome_fit(
                self, subDT, self.outcome_col, self.covariates, self.weighted, "weight"
            )
        }

        if self.compevent_colname is not None:
            models["compevent"] = _outcome_fit(
                self,
                subDT,
                self.compevent_colname,
                self.covariates,
                self.weighted,
                "weight",
            )
        models_list.append(models)
    return models_list

pySEQTarget/analysis/_survival_pred.py
@@ -0,0 +1,372 @@
import polars as pl


def _get_outcome_predictions(self, TxDT, idx=None):
    data = TxDT.to_pandas()
    predictions = {"outcome": []}
    if self.compevent_colname is not None:
        predictions["compevent"] = []

    for boot_model in self.outcome_model:
        model_dict = boot_model[idx] if idx is not None else boot_model
        predictions["outcome"].append(model_dict["outcome"].predict(data))
        if self.compevent_colname is not None:
            predictions["compevent"].append(model_dict["compevent"].predict(data))

    return predictions


def _pred_risk(self):
    has_subgroups = (
        isinstance(self.outcome_model[0], list) if self.outcome_model else False
    )

    if not has_subgroups:
        return _calculate_risk(self, self.DT, idx=None, val=None)

    all_risks = []
    original_DT = self.DT

    for i, val in enumerate(self._unique_subgroups):
        subgroup_DT = original_DT.filter(pl.col(self.subgroup_colname) == val)
        risk = _calculate_risk(self, subgroup_DT, i, val)
        all_risks.append(risk)

    self.DT = original_DT
    return pl.concat(all_risks)


def _calculate_risk(self, data, idx=None, val=None):
    a = 1 - self.bootstrap_CI
    lci = a / 2
    uci = 1 - lci

    SDT = (
        data.with_columns(
            [
                (
                    pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)
                ).alias("TID")
            ]
        )
        .group_by("TID")
        .first()
        .drop(["followup", f"followup{self.indicator_squared}"])
        .with_columns([pl.lit(list(range(self.followup_max))).alias("followup")])
        .explode("followup")
        .with_columns(
            [
                (pl.col("followup") + 1).alias("followup"),
                (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
            ]
        )
    ).sort([self.id_col, "trial", "followup"])

    risks = []
    for treatment_val in self.treatment_level:
        TxDT = SDT.with_columns(
            [
                pl.lit(treatment_val).alias(
                    f"{self.treatment_col}{self.indicator_baseline}"
                )
            ]
        )

        if self.method == "dose-response":
            if treatment_val == self.treatment_level[0]:
                TxDT = TxDT.with_columns(
                    [pl.lit(0.0).alias("dose"), pl.lit(0.0).alias("dose_sq")]
                )
            else:
                TxDT = TxDT.with_columns(
                    [
                        pl.col("followup").alias("dose"),
                        pl.col(f"followup{self.indicator_squared}").alias("dose_sq"),
                    ]
                )

        preds = _get_outcome_predictions(self, TxDT, idx=idx)
        pred_series = [pl.Series("pred_outcome", preds["outcome"][0])]

        if self.bootstrap_nboot > 0:
            for boot_idx, pred in enumerate(preds["outcome"][1:], start=1):
                pred_series.append(pl.Series(f"pred_outcome_{boot_idx}", pred))

        if self.compevent_colname is not None:
            pred_series.append(pl.Series("pred_ce", preds["compevent"][0]))
            if self.bootstrap_nboot > 0:
                for boot_idx, pred in enumerate(preds["compevent"][1:], start=1):
                    pred_series.append(pl.Series(f"pred_ce_{boot_idx}", pred))

        outcome_names = [s.name for s in pred_series if "outcome" in s.name]
        ce_names = [s.name for s in pred_series if "ce" in s.name]

        TxDT = TxDT.with_columns(pred_series)

        if self.compevent_colname is not None:
            for out_col, ce_col in zip(outcome_names, ce_names):
                surv_col = out_col.replace("pred_outcome", "surv")
                cce_col = out_col.replace("pred_outcome", "cce")
                inc_col = out_col.replace("pred_outcome", "inc")

                TxDT = TxDT.with_columns(
                    [
                        (1 - pl.col(out_col)).cum_prod().over("TID").alias(surv_col),
                        ((1 - pl.col(out_col)) * (1 - pl.col(ce_col)))
                        .cum_prod()
                        .over("TID")
                        .alias(cce_col),
                    ]
                ).with_columns(
                    [
                        (pl.col(out_col) * (1 - pl.col(ce_col)) * pl.col(cce_col))
                        .cum_sum()
                        .over("TID")
                        .alias(inc_col)
                    ]
                )

            surv_names = [n.replace("pred_outcome", "surv") for n in outcome_names]
            inc_names = [n.replace("pred_outcome", "inc") for n in outcome_names]
            TxDT = (
                TxDT.group_by("followup")
                .agg([pl.col(col).mean() for col in surv_names + inc_names])
                .sort("followup")
            )
            main_col = "surv"
            boot_cols = [col for col in surv_names if col != "surv"]
        else:
            TxDT = (
                TxDT.with_columns(
                    [
                        (1 - pl.col(col)).cum_prod().over("TID").alias(col)
                        for col in outcome_names
                    ]
                )
                .group_by("followup")
                .agg([pl.col(col).mean() for col in outcome_names])
                .sort("followup")
                .with_columns([(1 - pl.col(col)).alias(col) for col in outcome_names])
            )
            main_col = "pred_outcome"
            boot_cols = [col for col in outcome_names if col != "pred_outcome"]

        if boot_cols:
            risk = (
                TxDT.select(["followup"] + boot_cols)
                .unpivot(
                    index="followup",
                    on=boot_cols,
                    variable_name="bootID",
                    value_name="risk",
                )
                .group_by("followup")
                .agg(
                    [
                        pl.col("risk").std().cast(pl.Float64).alias("SE"),
                        pl.col("risk").quantile(lci).cast(pl.Float64).alias("LCI"),
                        pl.col("risk").quantile(uci).cast(pl.Float64).alias("UCI"),
                    ]
                )
                .join(TxDT.select(["followup", main_col]), on="followup")
            )

            if self.bootstrap_CI_method == "se":
                from scipy.stats import norm

                z = norm.ppf(1 - a / 2)
                risk = risk.with_columns(
                    [
                        (pl.col(main_col) - z * pl.col("SE")).alias("LCI"),
                        (pl.col(main_col) + z * pl.col("SE")).alias("UCI"),
                    ]
                )

            fup0_val = 1.0 if self.compevent_colname else 0.0

            if self.compevent_colname is not None:
                inc_boot_cols = [col for col in inc_names if col != "inc"]
                if inc_boot_cols:
                    inc_risk = (
                        TxDT.select(["followup"] + inc_boot_cols)
                        .unpivot(
                            index="followup",
                            on=inc_boot_cols,
                            variable_name="bootID",
                            value_name="inc_val",
                        )
                        .group_by("followup")
                        .agg(
                            [
                                pl.col("inc_val")
                                .std()
                                .cast(pl.Float64)
                                .alias("inc_SE"),
                                pl.col("inc_val")
                                .quantile(lci)
                                .cast(pl.Float64)
                                .alias("inc_LCI"),
                                pl.col("inc_val")
                                .quantile(uci)
                                .cast(pl.Float64)
                                .alias("inc_UCI"),
                            ]
                        )
                        .join(TxDT.select(["followup", "inc"]), on="followup")
                    )
                    risk = risk.join(inc_risk, on="followup")
                    final_cols = [
                        "followup",
                        main_col,
                        "SE",
                        "LCI",
                        "UCI",
                        "inc",
                        "inc_SE",
                        "inc_LCI",
                        "inc_UCI",
                    ]
                    risk = risk.select(final_cols).with_columns(
                        pl.lit(treatment_val).alias(self.treatment_col)
                    )

                    fup0 = pl.DataFrame(
                        {
                            "followup": [0],
                            main_col: [fup0_val],
                            "SE": [0.0],
                            "LCI": [fup0_val],
                            "UCI": [fup0_val],
                            "inc": [0.0],
                            "inc_SE": [0.0],
                            "inc_LCI": [0.0],
                            "inc_UCI": [0.0],
                            self.treatment_col: [treatment_val],
                        }
                    ).with_columns(
                        [
                            pl.col("followup").cast(pl.Int64),
                            pl.col(self.treatment_col).cast(pl.Int32),
                        ]
                    )
                else:
                    risk = risk.select(
                        ["followup", main_col, "SE", "LCI", "UCI"]
                    ).with_columns(pl.lit(treatment_val).alias(self.treatment_col))
                    fup0 = pl.DataFrame(
                        {
                            "followup": [0],
                            main_col: [fup0_val],
                            "SE": [0.0],
                            "LCI": [fup0_val],
                            "UCI": [fup0_val],
                            self.treatment_col: [treatment_val],
                        }
                    ).with_columns(
                        [
                            pl.col("followup").cast(pl.Int64),
                            pl.col(self.treatment_col).cast(pl.Int32),
                        ]
                    )
            else:
                risk = risk.select(
                    ["followup", main_col, "SE", "LCI", "UCI"]
                ).with_columns(pl.lit(treatment_val).alias(self.treatment_col))
                fup0 = pl.DataFrame(
                    {
                        "followup": [0],
                        main_col: [fup0_val],
                        "SE": [0.0],
                        "LCI": [fup0_val],
                        "UCI": [fup0_val],
                        self.treatment_col: [treatment_val],
                    }
                ).with_columns(
                    [
                        pl.col("followup").cast(pl.Int64),
                        pl.col(self.treatment_col).cast(pl.Int32),
                    ]
                )
        else:
            fup0_val = 1.0 if self.compevent_colname else 0.0
            risk = TxDT.select(["followup", main_col]).with_columns(
                pl.lit(treatment_val).alias(self.treatment_col)
            )
            fup0 = pl.DataFrame(
                {
                    "followup": [0],
                    main_col: [fup0_val],
                    self.treatment_col: [treatment_val],
                }
            ).with_columns(
                [
                    pl.col("followup").cast(pl.Int64),
                    pl.col(self.treatment_col).cast(pl.Int32),
                ]
            )

            if self.compevent_colname is not None:
                risk = risk.join(TxDT.select(["followup", "inc"]), on="followup")
                fup0 = fup0.with_columns([pl.lit(0.0).alias("inc")])

        risks.append(pl.concat([fup0, risk]))
    out = pl.concat(risks)

    if self.compevent_colname is not None:
        has_ci = "SE" in out.columns

        surv_cols = ["followup", self.treatment_col, "surv"]
        if has_ci:
            surv_cols.extend(["SE", "LCI", "UCI"])
        surv_out = (
            out.select(surv_cols)
            .rename({"surv": "pred"})
            .with_columns(pl.lit("survival").alias("estimate"))
        )

        risk_cols = ["followup", self.treatment_col, (1 - pl.col("surv")).alias("pred")]
        if has_ci:
            risk_cols.extend(
                [
                    pl.col("SE"),
                    (1 - pl.col("UCI")).alias("LCI"),
                    (1 - pl.col("LCI")).alias("UCI"),
                ]
            )
        risk_out = out.select(risk_cols).with_columns(pl.lit("risk").alias("estimate"))

        inc_cols = ["followup", self.treatment_col, pl.col("inc").alias("pred")]
        if has_ci:
            inc_cols.extend(
                [
                    pl.col("inc_SE").alias("SE"),
                    pl.col("inc_LCI").alias("LCI"),
                    pl.col("inc_UCI").alias("UCI"),
                ]
            )
        inc_out = out.select(inc_cols).with_columns(
            pl.lit("incidence").alias("estimate")
        )

        out = pl.concat([surv_out, risk_out, inc_out])
    else:
        out = out.rename({"pred_outcome": "pred"}).with_columns(
            pl.lit("risk").alias("estimate")
        )

    if val is not None:
        out = out.with_columns(pl.lit(val).alias(self.subgroup_colname))

    return out


def _calculate_survival(self, risk_data):
    if self.bootstrap_nboot > 0:
        surv = risk_data.with_columns(
            [(1 - pl.col(col)).alias(col) for col in ["pred", "LCI", "UCI"]]
        ).with_columns(pl.lit("survival").alias("estimate"))
    else:
        surv = risk_data.with_columns(
            [(1 - pl.col("pred")).alias("pred"), pl.lit("survival").alias("estimate")]
        )
    return surv
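
As a plain-Python illustration of the curve construction above: within each trial-person ("TID"), survival is the running product of (1 - outcome hazard); with a competing event, the code also tracks the running product of both event-free probabilities ("cce") and accumulates incidence from it. A small sketch with assumed per-interval hazards (illustrative values, not package output):

# Assumed discrete-time hazards for a single trial-person (illustrative only)
h_outcome = [0.05, 0.06, 0.07]     # predicted outcome hazard at follow-up 1, 2, 3
h_compevent = [0.01, 0.01, 0.02]   # predicted competing-event hazard at follow-up 1, 2, 3

surv = 1.0   # running product of (1 - outcome hazard)          -> "surv" column
cce = 1.0    # running product of both event-free probabilities -> "cce" column
inc = 0.0    # running sum of incidence contributions           -> "inc" column
for h_o, h_c in zip(h_outcome, h_compevent):
    surv *= 1 - h_o
    cce *= (1 - h_o) * (1 - h_c)
    inc += h_o * (1 - h_c) * cce
print(surv, inc)

In the package these per-TID quantities are then averaged over trial-persons at each follow-up time (the group_by("followup").agg(mean) step), and a follow-up-0 row is prepended before stacking the treatment levels.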

pySEQTarget/data/__init__.py
@@ -0,0 +1,19 @@
from importlib.resources import files

import polars as pl


def load_data(name: str = "SEQdata") -> pl.DataFrame:
    loc = files("pySEQTarget.data")
    if name in ["SEQdata", "SEQdata_multitreatment", "SEQdata_LTFU"]:
        if name == "SEQdata":
            data_path = loc.joinpath("SEQdata.csv")
        elif name == "SEQdata_multitreatment":
            data_path = loc.joinpath("SEQdata_multitreatment.csv")
        else:
            data_path = loc.joinpath("SEQdata_LTFU.csv")
        return pl.read_csv(data_path)
    else:
        raise ValueError(
            f"Dataset '{name}' not available. Options: ['SEQdata', 'SEQdata_multitreatment', 'SEQdata_LTFU']"
        )
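
A minimal usage sketch for the loader above, importing it from the subpackage where it is defined (it may also be re-exported at the top level, which is not shown in this hunk):

from pySEQTarget.data import load_data

df = load_data("SEQdata")  # also available: "SEQdata_multitreatment", "SEQdata_LTFU"
print(df.shape)

# load_data("missing")  # would raise ValueError listing the available options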

pySEQTarget/error/_datachecker.py
@@ -0,0 +1,38 @@
import polars as pl


def _datachecker(self):
    check = self.data.group_by(self.id_col).agg(
        [pl.len().alias("row_count"), pl.col(self.time_col).max().alias("max_time")]
    )

    invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1)
    if len(invalid) > 0:
        raise ValueError(
            f"Data validation failed: {len(invalid)} ID(s) have mismatched "
            f"This suggests invalid times"
            f"Invalid IDs:\n{invalid}"
        )

    for col in self.excused_colnames:
        violations = (
            self.data.sort([self.id_col, self.time_col])
            .group_by(self.id_col)
            .agg(
                [
                    (
                        (pl.col(col).cum_sum().shift(1, fill_value=0) > 0)
                        & (pl.col(col) == 0)
                    )
                    .any()
                    .alias("has_violation")
                ]
            )
            .filter(pl.col("has_violation"))
        )

        if len(violations) > 0:
            raise ValueError(
                f"Column '{col}' violates 'once one, always one' rule for excusing treatment "
                f"{len(violations)} ID(s) have zeros after ones."
            )
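
The second check above enforces a "once one, always one" rule on each excused column: after the indicator first turns 1 for an ID, it may never return to 0. A toy sketch of the same expression on a hypothetical frame (the column names `id`, `time`, `excused` are illustrative, not from the package):

import polars as pl

toy = pl.DataFrame(
    {
        "id": [1, 1, 1, 2, 2, 2],
        "time": [0, 1, 2, 0, 1, 2],
        "excused": [0, 1, 0, 0, 1, 1],
    }
)

violations = (
    toy.sort(["id", "time"])
    .group_by("id")
    .agg(
        (
            (pl.col("excused").cum_sum().shift(1, fill_value=0) > 0)
            & (pl.col("excused") == 0)
        )
        .any()
        .alias("has_violation")
    )
    .filter(pl.col("has_violation"))
)
print(violations)  # id 1 is flagged: its excused flag drops back to 0 after a 1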

pySEQTarget/error/_param_checker.py
@@ -0,0 +1,50 @@
from ..helpers import _pad


def _param_checker(self):
    if (
        self.subgroup_colname is not None
        and self.subgroup_colname not in self.fixed_cols
    ):
        raise ValueError("subgroup_colname must be included in fixed_cols.")

    if self.followup_max is None:
        self.followup_max = self.data.select(self.time_col).to_series().max()

    if len(self.excused_colnames) == 0 and self.excused:
        self.excused = False
        raise Warning(
            "Excused column names not provided but excused is set to True. Automatically set excused to False"
        )

    if len(self.excused_colnames) > 0 and not self.excused:
        self.excused = True
        raise Warning(
            "Excused column names provided but excused is set to False. Automatically set excused to True"
        )

    if self.km_curves and self.hazard_estimate:
        raise ValueError("km_curves and hazard cannot both be set to True.")

    if sum([self.followup_class, self.followup_include, self.followup_spline]) > 1:
        raise ValueError(
            "Only one of followup_class or followup_include can be set to True."
        )

    if (
        self.weighted
        and self.method == "ITT"
        and self.cense_colname is None
        and self.visit_colname is None
    ):
        raise ValueError(
            "For weighted ITT analyses, cense_colname or visit_colname must be provided."
        )

    if self.excused:
        _, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames)
        _, self.weight_eligible_colnames = _pad(
            self.treatment_level, self.weight_eligible_colnames
        )

    return