pySEQTarget 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. pySEQTarget/SEQopts.py +197 -0
  2. pySEQTarget/SEQoutput.py +163 -0
  3. pySEQTarget/SEQuential.py +375 -0
  4. pySEQTarget/__init__.py +5 -0
  5. pySEQTarget/analysis/__init__.py +8 -0
  6. pySEQTarget/analysis/_hazard.py +211 -0
  7. pySEQTarget/analysis/_outcome_fit.py +75 -0
  8. pySEQTarget/analysis/_risk_estimates.py +136 -0
  9. pySEQTarget/analysis/_subgroup_fit.py +30 -0
  10. pySEQTarget/analysis/_survival_pred.py +372 -0
  11. pySEQTarget/data/__init__.py +19 -0
  12. pySEQTarget/error/__init__.py +2 -0
  13. pySEQTarget/error/_datachecker.py +38 -0
  14. pySEQTarget/error/_param_checker.py +50 -0
  15. pySEQTarget/expansion/__init__.py +5 -0
  16. pySEQTarget/expansion/_binder.py +98 -0
  17. pySEQTarget/expansion/_diagnostics.py +53 -0
  18. pySEQTarget/expansion/_dynamic.py +73 -0
  19. pySEQTarget/expansion/_mapper.py +44 -0
  20. pySEQTarget/expansion/_selection.py +31 -0
  21. pySEQTarget/helpers/__init__.py +8 -0
  22. pySEQTarget/helpers/_bootstrap.py +111 -0
  23. pySEQTarget/helpers/_col_string.py +6 -0
  24. pySEQTarget/helpers/_format_time.py +6 -0
  25. pySEQTarget/helpers/_output_files.py +167 -0
  26. pySEQTarget/helpers/_pad.py +7 -0
  27. pySEQTarget/helpers/_predict_model.py +9 -0
  28. pySEQTarget/helpers/_prepare_data.py +19 -0
  29. pySEQTarget/initialization/__init__.py +5 -0
  30. pySEQTarget/initialization/_censoring.py +53 -0
  31. pySEQTarget/initialization/_denominator.py +39 -0
  32. pySEQTarget/initialization/_numerator.py +37 -0
  33. pySEQTarget/initialization/_outcome.py +56 -0
  34. pySEQTarget/plot/__init__.py +1 -0
  35. pySEQTarget/plot/_survival_plot.py +104 -0
  36. pySEQTarget/weighting/__init__.py +8 -0
  37. pySEQTarget/weighting/_weight_bind.py +86 -0
  38. pySEQTarget/weighting/_weight_data.py +47 -0
  39. pySEQTarget/weighting/_weight_fit.py +99 -0
  40. pySEQTarget/weighting/_weight_pred.py +192 -0
  41. pySEQTarget/weighting/_weight_stats.py +23 -0
  42. pyseqtarget-0.10.0.dist-info/METADATA +98 -0
  43. pyseqtarget-0.10.0.dist-info/RECORD +46 -0
  44. pyseqtarget-0.10.0.dist-info/WHEEL +5 -0
  45. pyseqtarget-0.10.0.dist-info/licenses/LICENSE +21 -0
  46. pyseqtarget-0.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,136 @@
1
+ import polars as pl
2
+ from scipy import stats
3
+
4
+
5
def _risk_estimates(self):
    """Compute pairwise risk differences and risk ratios between treatment levels.

    Works from the final-followup "risk" rows of ``self.km_data``. When
    bootstrap replicates exist (``self.bootstrap_nboot > 0``), normal
    approximation confidence limits are attached.

    NOTE(review): the CI column labels say "95%" even though the level comes
    from ``self.bootstrap_CI`` — confirm whether the label should be dynamic.

    Returns a dict with polars DataFrames under ``"risk_difference"`` and
    ``"risk_ratio"``.
    """
    final_fup = self.km_data["followup"].max()
    final_risk = self.km_data.filter(
        (pl.col("followup") == final_fup) & (pl.col("estimate") == "risk")
    )

    keys = [self.subgroup_colname] if self.subgroup_colname else []
    use_ci = self.bootstrap_nboot > 0
    if use_ci:
        z = stats.norm.ppf(1 - (1 - self.bootstrap_CI) / 2)

    def _arm(level, col, new_name):
        # One treatment arm's column, keyed by subgroup (if any).
        return (
            final_risk.filter(pl.col("tx_init") == level)
            .select(keys + [col])
            .rename({col: new_name})
        )

    def _merge(left, right):
        # Align on subgroup keys when present; otherwise pair the single rows.
        if keys:
            return left.join(right, on=keys, how="left")
        return left.join(right, how="cross")

    rd_parts = []
    rr_parts = []

    for a in self.treatment_level:
        for b in self.treatment_level:
            if a == b:
                continue

            comp = _merge(_arm(a, "pred", "risk_x"), _arm(b, "pred", "risk_y"))
            comp = comp.with_columns(
                [pl.lit(a).alias("A_x"), pl.lit(b).alias("A_y")]
            )

            if use_ci:
                comp = _merge(comp, _arm(a, "SE", "se_x"))
                comp = _merge(comp, _arm(b, "SE", "se_y"))

                # RD standard error: sqrt(se_x^2 + se_y^2).
                rd_se = (pl.col("se_x").pow(2) + pl.col("se_y").pow(2)).sqrt()
                diff = pl.col("risk_x") - pl.col("risk_y")
                rd = comp.with_columns(
                    [
                        diff.alias("Risk Difference"),
                        (diff - z * rd_se).alias("RD 95% LCI"),
                        (diff + z * rd_se).alias("RD 95% UCI"),
                    ]
                ).drop(["risk_x", "risk_y", "se_x", "se_y"])
                wanted = keys + [
                    "A_x",
                    "A_y",
                    "Risk Difference",
                    "RD 95% LCI",
                    "RD 95% UCI",
                ]
                rd_parts.append(rd.select([c for c in wanted if c in rd.columns]))

                # RR CI on the log scale via the delta method.
                ratio = pl.col("risk_x") / pl.col("risk_y")
                rr_log_se = (
                    (pl.col("se_x") / pl.col("risk_x")).pow(2)
                    + (pl.col("se_y") / pl.col("risk_y")).pow(2)
                ).sqrt()
                rr = comp.with_columns(
                    [
                        ratio.alias("Risk Ratio"),
                        (ratio * (-z * rr_log_se).exp()).alias("RR 95% LCI"),
                        (ratio * (z * rr_log_se).exp()).alias("RR 95% UCI"),
                    ]
                ).drop(["risk_x", "risk_y", "se_x", "se_y"])
                wanted = keys + [
                    "A_x",
                    "A_y",
                    "Risk Ratio",
                    "RR 95% LCI",
                    "RR 95% UCI",
                ]
                rr_parts.append(rr.select([c for c in wanted if c in rr.columns]))
            else:
                rd = comp.with_columns(
                    (pl.col("risk_x") - pl.col("risk_y")).alias("Risk Difference")
                ).drop(["risk_x", "risk_y"])
                wanted = keys + ["A_x", "A_y", "Risk Difference"]
                rd_parts.append(rd.select([c for c in wanted if c in rd.columns]))

                rr = comp.with_columns(
                    (pl.col("risk_x") / pl.col("risk_y")).alias("Risk Ratio")
                ).drop(["risk_x", "risk_y"])
                wanted = keys + ["A_x", "A_y", "Risk Ratio"]
                rr_parts.append(rr.select([c for c in wanted if c in rr.columns]))

    risk_difference = pl.concat(rd_parts) if rd_parts else pl.DataFrame()
    risk_ratio = pl.concat(rr_parts) if rr_parts else pl.DataFrame()

    return {"risk_difference": risk_difference, "risk_ratio": risk_ratio}
@@ -0,0 +1,30 @@
1
+ import polars as pl
2
+
3
+ from ._outcome_fit import _outcome_fit
4
+
5
+
6
def _subgroup_fit(self):
    """Fit outcome (and optional competing-event) models within each subgroup.

    Stores the sorted unique subgroup values on ``self._unique_subgroups``
    and returns one model dict per subgroup, in the same order.
    """
    levels = sorted(self.DT[self.subgroup_colname].unique().to_list())
    self._unique_subgroups = levels

    fitted = []
    for level in levels:
        sub = self.DT.filter(pl.col(self.subgroup_colname) == level)

        # Primary outcome model for this subgroup.
        models = {
            "outcome": _outcome_fit(
                self, sub, self.outcome_col, self.covariates, self.weighted, "weight"
            )
        }

        # Competing-event model, only when a competing event column is set.
        if self.compevent_colname is not None:
            models["compevent"] = _outcome_fit(
                self,
                sub,
                self.compevent_colname,
                self.covariates,
                self.weighted,
                "weight",
            )

        fitted.append(models)

    return fitted
@@ -0,0 +1,372 @@
1
+ import polars as pl
2
+
3
+
4
def _get_outcome_predictions(self, TxDT, idx=None):
    """Score ``TxDT`` with every fitted outcome model (one per bootstrap fit).

    When ``idx`` is given, entries of ``self.outcome_model`` are per-subgroup
    lists and ``idx`` selects the subgroup's model dict. Returns a dict with
    an ``"outcome"`` list of prediction vectors, plus a ``"compevent"`` list
    when a competing-event column is configured.
    """
    frame = TxDT.to_pandas()
    has_compevent = self.compevent_colname is not None

    predictions = {"outcome": []}
    if has_compevent:
        predictions["compevent"] = []

    for fitted in self.outcome_model:
        models = fitted if idx is None else fitted[idx]
        predictions["outcome"].append(models["outcome"].predict(frame))
        if has_compevent:
            predictions["compevent"].append(models["compevent"].predict(frame))

    return predictions
17
+
18
+
19
def _pred_risk(self):
    """Produce risk predictions, dispatching over subgroup-specific fits.

    A nested list in ``self.outcome_model`` signals per-subgroup models; in
    that case risks are computed per subgroup and stacked, otherwise a single
    overall curve is returned.
    """
    subgrouped = bool(self.outcome_model) and isinstance(self.outcome_model[0], list)
    if not subgrouped:
        return _calculate_risk(self, self.DT, idx=None, val=None)

    full_DT = self.DT
    per_subgroup = []

    for i, level in enumerate(self._unique_subgroups):
        subset = full_DT.filter(pl.col(self.subgroup_colname) == level)
        per_subgroup.append(_calculate_risk(self, subset, i, level))

    # Restore the full dataset reference after subgroup processing.
    self.DT = full_DT
    return pl.concat(per_subgroup)
37
+
38
+
39
def _calculate_risk(self, data, idx=None, val=None):
    """Build per-followup risk/survival/incidence curves for each treatment level.

    Expands each (id, trial) pair to a full followup grid, scores it with the
    fitted outcome (and optional competing-event) models, aggregates to
    per-followup means, and attaches bootstrap SE/CI columns when bootstrap
    replicates exist. ``idx``/``val`` carry the subgroup index and label when
    called per subgroup.

    Returns a long-format polars DataFrame with columns
    followup / treatment / pred / estimate (+ SE, LCI, UCI when bootstrapped).
    """
    # CI tail probabilities from the configured confidence level.
    a = 1 - self.bootstrap_CI
    lci = a / 2
    uci = 1 - lci

    # One row per (id, trial) expanded to followup 1..followup_max, with the
    # squared-followup indicator column rebuilt to match.
    SDT = (
        data.with_columns(
            [
                (
                    pl.col(self.id_col).cast(pl.Utf8) + pl.col("trial").cast(pl.Utf8)
                ).alias("TID")
            ]
        )
        .group_by("TID")
        .first()
        .drop(["followup", f"followup{self.indicator_squared}"])
        .with_columns([pl.lit(list(range(self.followup_max))).alias("followup")])
        .explode("followup")
        .with_columns(
            [
                (pl.col("followup") + 1).alias("followup"),
                (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
            ]
        )
    ).sort([self.id_col, "trial", "followup"])

    risks = []
    for treatment_val in self.treatment_level:
        # Counterfactual dataset: everyone assigned this baseline treatment.
        TxDT = SDT.with_columns(
            [
                pl.lit(treatment_val).alias(
                    f"{self.treatment_col}{self.indicator_baseline}"
                )
            ]
        )

        if self.method == "dose-response":
            # Reference level gets zero dose; other levels accumulate dose
            # with followup time.
            if treatment_val == self.treatment_level[0]:
                TxDT = TxDT.with_columns(
                    [pl.lit(0.0).alias("dose"), pl.lit(0.0).alias("dose_sq")]
                )
            else:
                TxDT = TxDT.with_columns(
                    [
                        pl.col("followup").alias("dose"),
                        pl.col(f"followup{self.indicator_squared}").alias("dose_sq"),
                    ]
                )

        # Per-model hazard predictions; first model is the point estimate,
        # subsequent ones are bootstrap replicates.
        preds = _get_outcome_predictions(self, TxDT, idx=idx)
        pred_series = [pl.Series("pred_outcome", preds["outcome"][0])]

        if self.bootstrap_nboot > 0:
            for boot_idx, pred in enumerate(preds["outcome"][1:], start=1):
                pred_series.append(pl.Series(f"pred_outcome_{boot_idx}", pred))

        if self.compevent_colname is not None:
            pred_series.append(pl.Series("pred_ce", preds["compevent"][0]))
            if self.bootstrap_nboot > 0:
                for boot_idx, pred in enumerate(preds["compevent"][1:], start=1):
                    pred_series.append(pl.Series(f"pred_ce_{boot_idx}", pred))

        outcome_names = [s.name for s in pred_series if "outcome" in s.name]
        ce_names = [s.name for s in pred_series if "ce" in s.name]

        TxDT = TxDT.with_columns(pred_series)

        if self.compevent_colname is not None:
            # Competing-event path: per (id, trial), cumulate event-free
            # survival ("surv"), combined event-free survival ("cce"), and
            # cumulative incidence ("inc") for each replicate.
            for out_col, ce_col in zip(outcome_names, ce_names):
                surv_col = out_col.replace("pred_outcome", "surv")
                cce_col = out_col.replace("pred_outcome", "cce")
                inc_col = out_col.replace("pred_outcome", "inc")

                TxDT = TxDT.with_columns(
                    [
                        (1 - pl.col(out_col)).cum_prod().over("TID").alias(surv_col),
                        ((1 - pl.col(out_col)) * (1 - pl.col(ce_col)))
                        .cum_prod()
                        .over("TID")
                        .alias(cce_col),
                    ]
                ).with_columns(
                    [
                        (pl.col(out_col) * (1 - pl.col(ce_col)) * pl.col(cce_col))
                        .cum_sum()
                        .over("TID")
                        .alias(inc_col)
                    ]
                )

            surv_names = [n.replace("pred_outcome", "surv") for n in outcome_names]
            inc_names = [n.replace("pred_outcome", "inc") for n in outcome_names]
            # Average over (id, trial) at each followup.
            TxDT = (
                TxDT.group_by("followup")
                .agg([pl.col(col).mean() for col in surv_names + inc_names])
                .sort("followup")
            )
            main_col = "surv"
            boot_cols = [col for col in surv_names if col != "surv"]
        else:
            # No competing event: cumulate survival, average per followup,
            # then flip back to the risk scale (1 - mean survival).
            TxDT = (
                TxDT.with_columns(
                    [
                        (1 - pl.col(col)).cum_prod().over("TID").alias(col)
                        for col in outcome_names
                    ]
                )
                .group_by("followup")
                .agg([pl.col(col).mean() for col in outcome_names])
                .sort("followup")
                .with_columns([(1 - pl.col(col)).alias(col) for col in outcome_names])
            )
            main_col = "pred_outcome"
            boot_cols = [col for col in outcome_names if col != "pred_outcome"]

        if boot_cols:
            # Bootstrap SE and percentile CI from the replicate columns,
            # joined back to the point-estimate column.
            risk = (
                TxDT.select(["followup"] + boot_cols)
                .unpivot(
                    index="followup",
                    on=boot_cols,
                    variable_name="bootID",
                    value_name="risk",
                )
                .group_by("followup")
                .agg(
                    [
                        pl.col("risk").std().cast(pl.Float64).alias("SE"),
                        pl.col("risk").quantile(lci).cast(pl.Float64).alias("LCI"),
                        pl.col("risk").quantile(uci).cast(pl.Float64).alias("UCI"),
                    ]
                )
                .join(TxDT.select(["followup", main_col]), on="followup")
            )

            if self.bootstrap_CI_method == "se":
                from scipy.stats import norm

                # Overwrite percentile CI with normal-approximation CI.
                z = norm.ppf(1 - a / 2)
                risk = risk.with_columns(
                    [
                        (pl.col(main_col) - z * pl.col("SE")).alias("LCI"),
                        (pl.col(main_col) + z * pl.col("SE")).alias("UCI"),
                    ]
                )

            # At followup 0 survival starts at 1, risk at 0.
            fup0_val = 1.0 if self.compevent_colname else 0.0

            if self.compevent_colname is not None:
                inc_boot_cols = [col for col in inc_names if col != "inc"]
                if inc_boot_cols:
                    # Same bootstrap summary for cumulative incidence.
                    inc_risk = (
                        TxDT.select(["followup"] + inc_boot_cols)
                        .unpivot(
                            index="followup",
                            on=inc_boot_cols,
                            variable_name="bootID",
                            value_name="inc_val",
                        )
                        .group_by("followup")
                        .agg(
                            [
                                pl.col("inc_val")
                                .std()
                                .cast(pl.Float64)
                                .alias("inc_SE"),
                                pl.col("inc_val")
                                .quantile(lci)
                                .cast(pl.Float64)
                                .alias("inc_LCI"),
                                pl.col("inc_val")
                                .quantile(uci)
                                .cast(pl.Float64)
                                .alias("inc_UCI"),
                            ]
                        )
                        .join(TxDT.select(["followup", "inc"]), on="followup")
                    )
                    risk = risk.join(inc_risk, on="followup")
                    final_cols = [
                        "followup",
                        main_col,
                        "SE",
                        "LCI",
                        "UCI",
                        "inc",
                        "inc_SE",
                        "inc_LCI",
                        "inc_UCI",
                    ]
                    risk = risk.select(final_cols).with_columns(
                        pl.lit(treatment_val).alias(self.treatment_col)
                    )

                    # Synthetic followup-0 row anchoring the curves.
                    fup0 = pl.DataFrame(
                        {
                            "followup": [0],
                            main_col: [fup0_val],
                            "SE": [0.0],
                            "LCI": [fup0_val],
                            "UCI": [fup0_val],
                            "inc": [0.0],
                            "inc_SE": [0.0],
                            "inc_LCI": [0.0],
                            "inc_UCI": [0.0],
                            self.treatment_col: [treatment_val],
                        }
                    ).with_columns(
                        [
                            pl.col("followup").cast(pl.Int64),
                            pl.col(self.treatment_col).cast(pl.Int32),
                        ]
                    )
                else:
                    risk = risk.select(
                        ["followup", main_col, "SE", "LCI", "UCI"]
                    ).with_columns(pl.lit(treatment_val).alias(self.treatment_col))
                    fup0 = pl.DataFrame(
                        {
                            "followup": [0],
                            main_col: [fup0_val],
                            "SE": [0.0],
                            "LCI": [fup0_val],
                            "UCI": [fup0_val],
                            self.treatment_col: [treatment_val],
                        }
                    ).with_columns(
                        [
                            pl.col("followup").cast(pl.Int64),
                            pl.col(self.treatment_col).cast(pl.Int32),
                        ]
                    )
            else:
                risk = risk.select(
                    ["followup", main_col, "SE", "LCI", "UCI"]
                ).with_columns(pl.lit(treatment_val).alias(self.treatment_col))
                fup0 = pl.DataFrame(
                    {
                        "followup": [0],
                        main_col: [fup0_val],
                        "SE": [0.0],
                        "LCI": [fup0_val],
                        "UCI": [fup0_val],
                        self.treatment_col: [treatment_val],
                    }
                ).with_columns(
                    [
                        pl.col("followup").cast(pl.Int64),
                        pl.col(self.treatment_col).cast(pl.Int32),
                    ]
                )
        else:
            # No bootstrap replicates: point estimate only.
            fup0_val = 1.0 if self.compevent_colname else 0.0
            risk = TxDT.select(["followup", main_col]).with_columns(
                pl.lit(treatment_val).alias(self.treatment_col)
            )
            fup0 = pl.DataFrame(
                {
                    "followup": [0],
                    main_col: [fup0_val],
                    self.treatment_col: [treatment_val],
                }
            ).with_columns(
                [
                    pl.col("followup").cast(pl.Int64),
                    pl.col(self.treatment_col).cast(pl.Int32),
                ]
            )

            if self.compevent_colname is not None:
                risk = risk.join(TxDT.select(["followup", "inc"]), on="followup")
                fup0 = fup0.with_columns([pl.lit(0.0).alias("inc")])

        risks.append(pl.concat([fup0, risk]))
    out = pl.concat(risks)

    if self.compevent_colname is not None:
        # Reshape the wide survival/incidence frame into long format with an
        # "estimate" label column (survival / risk / incidence).
        has_ci = "SE" in out.columns

        surv_cols = ["followup", self.treatment_col, "surv"]
        if has_ci:
            surv_cols.extend(["SE", "LCI", "UCI"])
        surv_out = (
            out.select(surv_cols)
            .rename({"surv": "pred"})
            .with_columns(pl.lit("survival").alias("estimate"))
        )

        # Risk = 1 - survival; the CI bounds swap when complemented.
        risk_cols = ["followup", self.treatment_col, (1 - pl.col("surv")).alias("pred")]
        if has_ci:
            risk_cols.extend(
                [
                    pl.col("SE"),
                    (1 - pl.col("UCI")).alias("LCI"),
                    (1 - pl.col("LCI")).alias("UCI"),
                ]
            )
        risk_out = out.select(risk_cols).with_columns(pl.lit("risk").alias("estimate"))

        inc_cols = ["followup", self.treatment_col, pl.col("inc").alias("pred")]
        if has_ci:
            inc_cols.extend(
                [
                    pl.col("inc_SE").alias("SE"),
                    pl.col("inc_LCI").alias("LCI"),
                    pl.col("inc_UCI").alias("UCI"),
                ]
            )
        inc_out = out.select(inc_cols).with_columns(
            pl.lit("incidence").alias("estimate")
        )

        out = pl.concat([surv_out, risk_out, inc_out])
    else:
        out = out.rename({"pred_outcome": "pred"}).with_columns(
            pl.lit("risk").alias("estimate")
        )

    # Tag the subgroup label when computed per subgroup.
    if val is not None:
        out = out.with_columns(pl.lit(val).alias(self.subgroup_colname))

    return out
361
+
362
+
363
def _calculate_survival(self, risk_data):
    """Convert risk-scale estimates into survival-scale ones (1 - value).

    With bootstrap CIs present, LCI/UCI are complemented in place as well.
    NOTE(review): bounds are not swapped here (unlike the risk/survival
    reshaping in ``_calculate_risk``), so after transformation the "LCI"
    column holds the larger value — confirm this is intended.
    """
    if self.bootstrap_nboot > 0:
        return risk_data.with_columns(
            [(1 - pl.col(name)).alias(name) for name in ("pred", "LCI", "UCI")]
        ).with_columns(pl.lit("survival").alias("estimate"))

    return risk_data.with_columns(
        [(1 - pl.col("pred")).alias("pred"), pl.lit("survival").alias("estimate")]
    )
@@ -0,0 +1,19 @@
1
+ from importlib.resources import files
2
+
3
+ import polars as pl
4
+
5
+
6
+ def load_data(name: str = "SEQdata") -> pl.DataFrame:
7
+ loc = files("pySEQTarget.data")
8
+ if name in ["SEQdata", "SEQdata_multitreatment", "SEQdata_LTFU"]:
9
+ if name == "SEQdata":
10
+ data_path = loc.joinpath("SEQdata.csv")
11
+ elif name == "SEQdata_multitreatment":
12
+ data_path = loc.joinpath("SEQdata_multitreatment.csv")
13
+ else:
14
+ data_path = loc.joinpath("SEQdata_LTFU.csv")
15
+ return pl.read_csv(data_path)
16
+ else:
17
+ raise ValueError(
18
+ f"Dataset '{name}' not available. Options: ['SEQdata', 'SEQdata_multitreatment', 'SEQdata_LTFU']"
19
+ )
@@ -0,0 +1,2 @@
1
+ from ._datachecker import _datachecker as _datachecker
2
+ from ._param_checker import _param_checker as _param_checker
@@ -0,0 +1,38 @@
1
+ import polars as pl
2
+
3
+
4
def _datachecker(self):
    """Validate the input data's per-ID structure before expansion.

    Two checks, each raising ``ValueError`` on failure:
    1. Every ID must contribute exactly ``max(time) + 1`` rows, i.e. a
       complete time sequence starting at 0.
    2. Every excused column must be monotone within an ID ("once one,
       always one") — no zeros after the first one.
    """
    check = self.data.group_by(self.id_col).agg(
        [pl.len().alias("row_count"), pl.col(self.time_col).max().alias("max_time")]
    )

    invalid = check.filter(pl.col("row_count") != pl.col("max_time") + 1)
    if len(invalid) > 0:
        # Original message was garbled by missing separators between the
        # concatenated f-strings; now readable.
        raise ValueError(
            f"Data validation failed: {len(invalid)} ID(s) have a row count "
            f"that does not match max_time + 1. This suggests invalid times. "
            f"Invalid IDs:\n{invalid}"
        )

    for col in self.excused_colnames:
        violations = (
            self.data.sort([self.id_col, self.time_col])
            .group_by(self.id_col)
            .agg(
                [
                    (
                        # A previous 1 exists (cumulative sum shifted down is
                        # positive) while the current value is 0.
                        (pl.col(col).cum_sum().shift(1, fill_value=0) > 0)
                        & (pl.col(col) == 0)
                    )
                    .any()
                    .alias("has_violation")
                ]
            )
            .filter(pl.col("has_violation"))
        )

        if len(violations) > 0:
            raise ValueError(
                f"Column '{col}' violates 'once one, always one' rule for excusing treatment: "
                f"{len(violations)} ID(s) have zeros after ones."
            )
@@ -0,0 +1,50 @@
1
+ from ..helpers import _pad
2
+
3
+
4
def _param_checker(self):
    """Validate and reconcile analysis parameters prior to fitting.

    Mutates ``self`` in place: fills a missing ``followup_max`` from the
    data, reconciles the ``excused`` flag with ``excused_colnames`` (warning
    on mismatch), and pads the excused/weight-eligible column lists.
    Raises ``ValueError`` on incompatible option combinations.
    """
    import warnings

    if (
        self.subgroup_colname is not None
        and self.subgroup_colname not in self.fixed_cols
    ):
        raise ValueError("subgroup_colname must be included in fixed_cols.")

    # Default the maximum follow-up to the largest observed time.
    if self.followup_max is None:
        self.followup_max = self.data.select(self.time_col).to_series().max()

    # Reconcile the excused flag with the presence of excused column names.
    # The original `raise Warning(...)` aborted execution, which contradicted
    # the "automatically set" message and made the reassignment pointless;
    # emit a real (non-fatal) warning instead.
    if len(self.excused_colnames) == 0 and self.excused:
        self.excused = False
        warnings.warn(
            "Excused column names not provided but excused is set to True. Automatically set excused to False"
        )

    if len(self.excused_colnames) > 0 and not self.excused:
        self.excused = True
        warnings.warn(
            "Excused column names provided but excused is set to False. Automatically set excused to True"
        )

    if self.km_curves and self.hazard_estimate:
        raise ValueError("km_curves and hazard cannot both be set to True.")

    # All three followup-handling modes are mutually exclusive; the original
    # message omitted followup_spline even though it was checked.
    if sum([self.followup_class, self.followup_include, self.followup_spline]) > 1:
        raise ValueError(
            "Only one of followup_class, followup_include or followup_spline can be set to True."
        )

    if (
        self.weighted
        and self.method == "ITT"
        and self.cense_colname is None
        and self.visit_colname is None
    ):
        raise ValueError(
            "For weighted ITT analyses, cense_colname or visit_colname must be provided."
        )

    if self.excused:
        # Pad the per-treatment column lists to match the treatment levels.
        _, self.excused_colnames = _pad(self.treatment_level, self.excused_colnames)
        _, self.weight_eligible_colnames = _pad(
            self.treatment_level, self.weight_eligible_colnames
        )

    return
@@ -0,0 +1,5 @@
1
+ from ._binder import _binder as _binder
2
+ from ._diagnostics import _diagnostics as _diagnostics
3
+ from ._dynamic import _dynamic as _dynamic
4
+ from ._mapper import _mapper as _mapper
5
+ from ._selection import _random_selection as _random_selection