pyseqtarget-0.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. pySEQTarget/SEQopts.py +197 -0
  2. pySEQTarget/SEQoutput.py +163 -0
  3. pySEQTarget/SEQuential.py +375 -0
  4. pySEQTarget/__init__.py +5 -0
  5. pySEQTarget/analysis/__init__.py +8 -0
  6. pySEQTarget/analysis/_hazard.py +211 -0
  7. pySEQTarget/analysis/_outcome_fit.py +75 -0
  8. pySEQTarget/analysis/_risk_estimates.py +136 -0
  9. pySEQTarget/analysis/_subgroup_fit.py +30 -0
  10. pySEQTarget/analysis/_survival_pred.py +372 -0
  11. pySEQTarget/data/__init__.py +19 -0
  12. pySEQTarget/error/__init__.py +2 -0
  13. pySEQTarget/error/_datachecker.py +38 -0
  14. pySEQTarget/error/_param_checker.py +50 -0
  15. pySEQTarget/expansion/__init__.py +5 -0
  16. pySEQTarget/expansion/_binder.py +98 -0
  17. pySEQTarget/expansion/_diagnostics.py +53 -0
  18. pySEQTarget/expansion/_dynamic.py +73 -0
  19. pySEQTarget/expansion/_mapper.py +44 -0
  20. pySEQTarget/expansion/_selection.py +31 -0
  21. pySEQTarget/helpers/__init__.py +8 -0
  22. pySEQTarget/helpers/_bootstrap.py +111 -0
  23. pySEQTarget/helpers/_col_string.py +6 -0
  24. pySEQTarget/helpers/_format_time.py +6 -0
  25. pySEQTarget/helpers/_output_files.py +167 -0
  26. pySEQTarget/helpers/_pad.py +7 -0
  27. pySEQTarget/helpers/_predict_model.py +9 -0
  28. pySEQTarget/helpers/_prepare_data.py +19 -0
  29. pySEQTarget/initialization/__init__.py +5 -0
  30. pySEQTarget/initialization/_censoring.py +53 -0
  31. pySEQTarget/initialization/_denominator.py +39 -0
  32. pySEQTarget/initialization/_numerator.py +37 -0
  33. pySEQTarget/initialization/_outcome.py +56 -0
  34. pySEQTarget/plot/__init__.py +1 -0
  35. pySEQTarget/plot/_survival_plot.py +104 -0
  36. pySEQTarget/weighting/__init__.py +8 -0
  37. pySEQTarget/weighting/_weight_bind.py +86 -0
  38. pySEQTarget/weighting/_weight_data.py +47 -0
  39. pySEQTarget/weighting/_weight_fit.py +99 -0
  40. pySEQTarget/weighting/_weight_pred.py +192 -0
  41. pySEQTarget/weighting/_weight_stats.py +23 -0
  42. pyseqtarget-0.10.0.dist-info/METADATA +98 -0
  43. pyseqtarget-0.10.0.dist-info/RECORD +46 -0
  44. pyseqtarget-0.10.0.dist-info/WHEEL +5 -0
  45. pyseqtarget-0.10.0.dist-info/licenses/LICENSE +21 -0
  46. pyseqtarget-0.10.0.dist-info/top_level.txt +1 -0
pySEQTarget/expansion/_binder.py
@@ -0,0 +1,98 @@
+import polars as pl
+
+from ._mapper import _mapper
+
+
+def _binder(self, kept_cols):
+    """
+    Internal function to bind data to the map created by __mapper
+    """
+    excluded = {
+        "dose",
+        f"dose{self.indicator_squared}",
+        "followup",
+        f"followup{self.indicator_squared}",
+        "tx_lag",
+        "trial",
+        f"trial{self.indicator_squared}",
+        self.time_col,
+        f"{self.time_col}{self.indicator_squared}",
+    }
+
+    cols = kept_cols.union({self.eligible_col, self.outcome_col, self.treatment_col})
+    cols = {col for col in cols if col is not None}
+
+    regular = {
+        col
+        for col in cols
+        if not (self.indicator_baseline in col or self.indicator_squared in col)
+        and col not in excluded
+    }
+
+    baseline = {
+        col for col in cols if self.indicator_baseline in col and col not in excluded
+    }
+    bas_kept = {col.replace(self.indicator_baseline, "") for col in baseline}
+
+    squared = {
+        col for col in cols if self.indicator_squared in col and col not in excluded
+    }
+    sq_kept = {col.replace(self.indicator_squared, "") for col in squared}
+
+    kept = list(regular.union(bas_kept).union(sq_kept))
+
+    if self.selection_first_trial:
+        DT = (
+            self.data.sort([self.id_col, self.time_col])
+            .with_columns(
+                [
+                    pl.col(self.time_col).alias("period"),
+                    pl.col(self.time_col).alias("followup"),
+                    pl.lit(0).alias("trial"),
+                ]
+            )
+            .drop(self.time_col)
+        )
+    else:
+        DT = _mapper(
+            self.data, self.id_col, self.time_col, self.followup_min, self.followup_max
+        )
+    DT = DT.join(
+        self.data.select([self.id_col, self.time_col] + kept),
+        left_on=[self.id_col, "period"],
+        right_on=[self.id_col, self.time_col],
+        how="left",
+    )
+    DT = DT.sort([self.id_col, "trial", "followup"]).with_columns(
+        [
+            (pl.col("trial") ** 2).alias(f"trial{self.indicator_squared}"),
+            (pl.col("followup") ** 2).alias(f"followup{self.indicator_squared}"),
+        ]
+    )
+
+    if squared:
+        squares = []
+        for sq in squared:
+            col = sq.replace(self.indicator_squared, "")
+            squares.append((pl.col(col) ** 2).alias(f"{col}{self.indicator_squared}"))
+        DT = DT.with_columns(squares)
+
+    baseline_cols = {bas.replace(self.indicator_baseline, "") for bas in baseline}
+    needed = {self.eligible_col, self.treatment_col}
+    baseline_cols.update({c for c in needed})
+
+    bas = [
+        pl.col(c)
+        .first()
+        .over([self.id_col, "trial"])
+        .alias(f"{c}{self.indicator_baseline}")
+        for c in baseline_cols
+    ]
+
+    DT = (
+        DT.with_columns(bas)
+        .filter(pl.col(f"{self.eligible_col}{self.indicator_baseline}") == 1)
+        .drop([f"{self.eligible_col}{self.indicator_baseline}", self.eligible_col])
+    )
+
+    return DT
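
Note: the baseline step above broadcasts the first value per (id, trial) to every
follow-up row under the suffixed name. A minimal sketch with hypothetical data,
assuming indicator_baseline is "_bas":

    import polars as pl

    DT = pl.DataFrame(
        {"id": ["A"] * 3, "trial": [0] * 3, "followup": [0, 1, 2], "age": [50, 51, 52]}
    )
    print(DT.with_columns(pl.col("age").first().over(["id", "trial"]).alias("age_bas")))
    # age_bas is 50 on all three rows of the trial
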
pySEQTarget/expansion/_diagnostics.py
@@ -0,0 +1,53 @@
+import polars as pl
+
+
+def _diagnostics(self):
+    unique_out = _outcome_diag(self, unique=True)
+    nonunique_out = _outcome_diag(self, unique=False)
+    out = {"unique_outcomes": unique_out, "nonunique_outcomes": nonunique_out}
+
+    if self.method == "censoring":
+        unique_switch = _switch_diag(self, unique=True)
+        nonunique_switch = _switch_diag(self, unique=False)
+        out.update(
+            {"unique_switches": unique_switch, "nonunique_switches": nonunique_switch}
+        )
+
+    self.diagnostics = out
+
+
+def _outcome_diag(self, unique):
+    if unique:
+        data = (
+            self.DT.select([self.id_col, self.treatment_col, self.outcome_col])
+            .group_by(self.id_col)
+            .last()
+        )
+    else:
+        data = self.DT
+    out = data.group_by([self.treatment_col, self.outcome_col]).len()
+
+    return out
+
+
+def _switch_diag(self, unique):
+    if not self.excused:
+        data = self.DT.with_columns(pl.lit(False).alias("isExcused"))
+    else:
+        data = self.DT
+
+    if unique:
+        data = (
+            data.select([self.id_col, self.treatment_col, "switch", "isExcused"])
+            .with_columns(
+                pl.when((pl.col("switch") == 0) & (pl.col("isExcused")))
+                .then(1)
+                .otherwise(pl.col("switch"))
+                .alias("switch")
+            )
+            .group_by(self.id_col)
+            .last()
+        )
+
+    out = data.group_by([self.treatment_col, "isExcused", "switch"]).len()
+    return out
pySEQTarget/expansion/_dynamic.py
@@ -0,0 +1,73 @@
+import polars as pl
+
+
+def _dynamic(self):
+    """
+    Handles special cases for the data from the __mapper -> __binder pipeline
+    """
+    if self.method == "dose-response":
+        DT = self.DT.with_columns(
+            pl.col(self.treatment_col)
+            .cum_sum()
+            .over([self.id_col, "trial"])
+            .alias("dose")
+        ).with_columns([(pl.col("dose") ** 2).alias(f"dose{self.indicator_squared}")])
+        self.DT = DT
+
+    elif self.method == "censoring":
+        DT = self.DT.sort([self.id_col, "trial", "followup"]).with_columns(
+            pl.col(self.treatment_col)
+            .shift(1)
+            .over([self.id_col, "trial"])
+            .alias("tx_lag")
+        )
+
+        switch = (
+            pl.when(pl.col("followup") == 0)
+            .then(pl.lit(False))
+            .otherwise(pl.col("tx_lag") != pl.col(self.treatment_col))
+        )
+        is_excused = pl.lit(False)
+        if self.excused:
+            conditions = []
+            for i, val in enumerate(self.treatment_level):
+                colname = self.excused_colnames[i]
+                if colname is not None:
+                    conditions.append(
+                        (pl.col(colname) == 1) & (pl.col(self.treatment_col) == val)
+                    )
+
+            if conditions:
+                excused = pl.any_horizontal(conditions)
+                is_excused = switch & excused
+
+        DT = DT.with_columns(
+            [switch.alias("switch"), is_excused.alias("isExcused")]
+        ).sort([self.id_col, "trial", "followup"])
+
+        if self.excused:
+            DT = (
+                DT.with_columns(
+                    pl.col("isExcused")
+                    .cast(pl.Int8)
+                    .cum_sum()
+                    .over([self.id_col, "trial"])
+                    .alias("_excused_tmp")
+                )
+                .with_columns(
+                    pl.when(pl.col("_excused_tmp") > 0)
+                    .then(pl.lit(False))
+                    .otherwise(pl.col("switch"))
+                    .alias("switch")
+                )
+                .drop("_excused_tmp")
+            )
+
+        DT = DT.filter(
+            (pl.col("switch").cum_max().shift(1, fill_value=False)).over(
+                [self.id_col, "trial"]
+            )
+            == 0
+        ).with_columns(pl.col("switch").cast(pl.Int8).alias("switch"))
+
+        self.DT = DT.drop(["tx_lag"])
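
Note: a minimal sketch of the censoring rule above for one (id, trial), on
hypothetical data. A switch at followup 2 keeps rows up to and including the
switch row and drops everything after it:

    import polars as pl

    DT = pl.DataFrame({"followup": [0, 1, 2, 3], "tx": [1, 1, 0, 0]})
    DT = DT.with_columns(pl.col("tx").shift(1).alias("tx_lag")).with_columns(
        pl.when(pl.col("followup") == 0)
        .then(pl.lit(False))
        .otherwise(pl.col("tx_lag") != pl.col("tx"))
        .alias("switch")
    )
    print(DT.filter(pl.col("switch").cast(pl.Int8).cum_max().shift(1, fill_value=0) == 0))
    # keeps followup 0, 1, 2; followup 3 is dropped
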
pySEQTarget/expansion/_mapper.py
@@ -0,0 +1,44 @@
+import math
+
+import polars as pl
+
+
+def _mapper(data, id_col, time_col, min_followup=-math.inf, max_followup=math.inf):
+    """
+    Internal function to create the expanded map to bind data to.
+    """
+
+    DT = (
+        data.select([pl.col(id_col), pl.col(time_col)])
+        .with_columns([pl.col(id_col).cum_count().over(id_col).sub(1).alias("trial")])
+        .with_columns(
+            [
+                pl.struct(
+                    [
+                        pl.col(time_col),
+                        pl.col(time_col).max().over(id_col).alias("max_time"),
+                    ]
+                )
+                .map_elements(
+                    lambda x: list(range(x[time_col], x["max_time"] + 1)),
+                    return_dtype=pl.List(pl.Int64),
+                )
+                .alias("period")
+            ]
+        )
+        .explode("period")
+        .drop(pl.col(time_col))
+        .with_columns(
+            [
+                pl.col(id_col)
+                .cum_count()
+                .over([id_col, "trial"])
+                .sub(1)
+                .alias("followup")
+            ]
+        )
+        .filter(
+            (pl.col("followup") >= min_followup) & (pl.col("followup") <= max_followup)
+        )
+    )
+    return DT
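
Note: the expansion is easiest to see on a toy input. A minimal sketch of the
semantics with hypothetical data (times here are consecutive from 0, so each
trial's index coincides with its start time): every observed time opens a trial
that is followed to the subject's last observed time.

    import polars as pl

    data = pl.DataFrame({"id": ["A", "A", "A"], "time": [0, 1, 2]})
    max_time = data["time"].max()
    rows = [
        {"id": "A", "trial": t, "period": p, "followup": p - t}
        for t in data["time"].to_list()
        for p in range(t, max_time + 1)
    ]
    print(pl.DataFrame(rows))
    # trial 0 -> periods 0,1,2; trial 1 -> periods 1,2; trial 2 -> period 2
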
pySEQTarget/expansion/_selection.py
@@ -0,0 +1,31 @@
+import polars as pl
+
+
+def _random_selection(self):
+    """
+    Handles the case where random selection is applied for data from
+    the __mapper -> __binder -> optionally __dynamic pipeline
+    """
+    UIDs = (
+        self.DT.select(
+            [self.id_col, "trial", f"{self.treatment_col}{self.indicator_baseline}"]
+        )
+        .with_columns((pl.col(self.id_col) + "_" + pl.col("trial").cast(pl.Utf8)).alias("trialID"))
+        .filter(pl.col(f"{self.treatment_col}{self.indicator_baseline}") == 0)
+        .unique("trialID")
+        .get_column("trialID")  # the trialID key column, not the frame's first column
+        .to_list()
+    )
+
+    NIDs = len(UIDs)
+    sample = self._rng.choice(
+        UIDs, size=int(self.selection_sample * NIDs), replace=False
+    )
+
+    self.DT = (
+        self.DT.with_columns(
+            (pl.col(self.id_col) + "_" + pl.col("trial").cast(pl.Utf8)).alias("trialID")
+        )
+        .filter(pl.col("trialID").is_in(sample))
+        .drop("trialID")
+    )
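
Note: a minimal sketch of the keying scheme above on hypothetical data, assuming
an rng like numpy's RandomState: each (id, trial) pair becomes a string key that
can be sampled without replacement.

    import numpy as np
    import polars as pl

    rng = np.random.RandomState(1)
    DT = pl.DataFrame({"id": ["A", "A", "B", "B"], "trial": [0, 1, 0, 1]})
    keys = (
        DT.with_columns(
            (pl.col("id") + "_" + pl.col("trial").cast(pl.Utf8)).alias("trialID")
        )
        .get_column("trialID")
        .unique()
        .to_list()
    )
    sample = rng.choice(keys, size=int(0.5 * len(keys)), replace=False)
    print(sorted(sample.tolist()))  # two of the four trial keys
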
pySEQTarget/helpers/__init__.py
@@ -0,0 +1,8 @@
+from ._bootstrap import bootstrap_loop as bootstrap_loop
+from ._col_string import _col_string as _col_string
+from ._format_time import _format_time as _format_time
+from ._output_files import _build_md as _build_md
+from ._output_files import _build_pdf as _build_pdf
+from ._pad import _pad as _pad
+from ._predict_model import _predict_model as _predict_model
+from ._prepare_data import _prepare_data as _prepare_data
pySEQTarget/helpers/_bootstrap.py
@@ -0,0 +1,111 @@
+import copy
+import time
+from concurrent.futures import ProcessPoolExecutor, as_completed
+from functools import wraps
+
+import numpy as np
+import polars as pl
+from tqdm import tqdm
+
+from ._format_time import _format_time
+
+
+def _prepare_boot_data(self, data, boot_id):
+    id_counts = self._boot_samples[boot_id]
+
+    counts = pl.DataFrame(
+        {self.id_col: list(id_counts.keys()), "count": list(id_counts.values())}
+    )
+
+    bootstrapped = data.join(counts, on=self.id_col, how="inner")
+    bootstrapped = (
+        bootstrapped.with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate"))
+        .explode("replicate")
+        .with_columns(
+            (
+                pl.col(self.id_col).cast(pl.Utf8)
+                + "_"
+                + pl.col("replicate").cast(pl.Utf8)
+            ).alias(self.id_col)
+        )
+        .drop("count", "replicate")
+    )
+
+    return bootstrapped
+
+
+def _bootstrap_worker(obj, method_name, original_DT, i, seed, args, kwargs):
+    obj = copy.deepcopy(obj)
+    obj._rng = (
+        np.random.RandomState(seed + i) if seed is not None else np.random.RandomState()
+    )
+    obj.DT = _prepare_boot_data(obj, original_DT, i)
+
+    # Disable bootstrapping to prevent recursion
+    obj.bootstrap_nboot = 0
+
+    method = getattr(obj, method_name)
+    result = method(*args, **kwargs)
+    obj._rng = None
+    return result
+
+
+def bootstrap_loop(method):
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        if not hasattr(self, "outcome_model"):
+            self.outcome_model = []
+        start = time.perf_counter()
+
+        results = []
+        full = method(self, *args, **kwargs)
+        results.append(full)
+
+        if getattr(self, "bootstrap_nboot") > 0 and getattr(
+            self, "_boot_samples", None
+        ):
+            original_DT = self.DT
+            nboot = self.bootstrap_nboot
+            ncores = self.ncores
+            seed = getattr(self, "seed", None)
+            method_name = method.__name__
+
+            if getattr(self, "parallel", False):
+                original_rng = getattr(self, "_rng", None)
+                self._rng = None
+
+                with ProcessPoolExecutor(max_workers=ncores) as executor:
+                    futures = [
+                        executor.submit(
+                            _bootstrap_worker,
+                            self,
+                            method_name,
+                            original_DT,
+                            i,
+                            seed,
+                            args,
+                            kwargs,
+                        )
+                        for i in range(nboot)
+                    ]
+                    for j in tqdm(
+                        as_completed(futures), total=nboot, desc="Bootstrapping..."
+                    ):
+                        results.append(j.result())
+
+                self._rng = original_rng
+            else:
+                for i in tqdm(range(nboot), desc="Bootstrapping..."):
+                    self.DT = _prepare_boot_data(self, original_DT, i)
+                    boot_fit = method(self, *args, **kwargs)
+                    results.append(boot_fit)

+                self.DT = original_DT
+
+        end = time.perf_counter()
+        self._model_time = _format_time(start, end)
+
+        self.outcome_model = results
+        return results
+
+    return wrapper
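
Note: a minimal sketch of driving the decorator, with a hypothetical Estimator
class whose attribute names mirror what wrapper and _prepare_boot_data read.
results[0] is the full-data fit, followed by one entry per bootstrap replicate.

    from collections import Counter

    import numpy as np
    import polars as pl

    from pySEQTarget.helpers import bootstrap_loop

    class Estimator:
        def __init__(self, data, nboot, seed=0):
            self.id_col = "id"
            self.DT = data
            self.bootstrap_nboot = nboot
            self.ncores = 1
            self.parallel = False
            self.seed = seed
            ids = data["id"].unique().to_list()
            rng = np.random.RandomState(seed)
            # one {id: resample count} mapping per replicate
            self._boot_samples = [
                Counter(rng.choice(ids, size=len(ids)).tolist()) for _ in range(nboot)
            ]

        @bootstrap_loop
        def fit(self):
            return self.DT.height  # stand-in for a real model fit

    data = pl.DataFrame({"id": ["A", "B", "C"], "y": [1, 0, 1]})
    print(Estimator(data, nboot=2).fit())  # [3, 3, 3]: each replicate has len(ids) rows
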
pySEQTarget/helpers/_col_string.py
@@ -0,0 +1,6 @@
+def _col_string(expressions):
+    cols = set()
+    for expression in expressions:
+        if expression is not None:
+            cols.update(expression.replace("+", " ").replace("*", " ").split())
+    return cols
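
Note: for example, model-formula strings are split on "+" and "*", and None
entries are skipped:

    print(_col_string(["age + sex*site", None, "bmi"]))
    # {'age', 'sex', 'site', 'bmi'} (a set, so order is arbitrary)
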
pySEQTarget/helpers/_format_time.py
@@ -0,0 +1,6 @@
+def _format_time(start, end):
+    elapsed = end - start
+    days, rem = divmod(elapsed, 86400)
+    hours, rem = divmod(rem, 3600)
+    minutes, seconds = divmod(rem, 60)
+    return f"{int(days)}-{int(hours):02d}:{int(minutes):02d}:{seconds:05.2f}"
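
Note: for example, 90061.5 elapsed seconds render as days-hours:minutes:seconds:

    print(_format_time(0, 90061.5))  # 1-01:01:01.50
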
pySEQTarget/helpers/_output_files.py
@@ -0,0 +1,167 @@
+import datetime
+
+
+def _build_md(self, img_path: str = None) -> str:
+    """
+    Builds markdown content for SEQuential analysis results.
+
+    :param self: SEQoutput instance
+    :param img_path: Path to saved KM graph image (if any)
+    :return: Markdown string
+    """
+
+    lines = []
+
+    lines.append(f"# SEQuential Analysis: {datetime.date.today()}: {self.method}")
+    lines.append("")
+
+    if self.options.weighted:
+        lines.append("## Weighting")
+        lines.append("")
+
+        lines.append("### Numerator Model")
+        lines.append("")
+        lines.append("```")
+        lines.append(str(self.numerator_models[0].summary()))
+        lines.append("```")
+        lines.append("")
+
+        lines.append("### Denominator Model")
+        lines.append("")
+        lines.append("```")
+        lines.append(str(self.denominator_models[0].summary()))
+        lines.append("```")
+        lines.append("")
+
+        if self.options.compevent_colname is not None and self.compevent_models:
+            lines.append("### Competing Event Model")
+            lines.append("")
+            lines.append("```")
+            lines.append(str(self.compevent_models[0].summary()))
+            lines.append("```")
+            lines.append("")
+
+        lines.append("### Weighting Statistics")
+        lines.append("")
+        lines.append(self.weight_statistics.to_pandas().to_markdown(index=False))
+        lines.append("")
+
+    lines.append("## Outcome")
+    lines.append("")
+
+    lines.append("### Outcome Model")
+    lines.append("")
+    lines.append("```")
+    lines.append(str(self.outcome_models[0].summary()))
+    lines.append("```")
+    lines.append("")
+
+    if self.options.hazard_estimate and self.hazard is not None:
+        lines.append("### Hazard")
+        lines.append("")
+        lines.append(self.hazard.to_pandas().to_markdown(index=False))
+        lines.append("")
+
+    if self.options.km_curves:
+        lines.append("### Survival")
+        lines.append("")
+
+        if self.risk_difference is not None:
+            lines.append("#### Risk Differences")
+            lines.append("")
+            lines.append(self.risk_difference.to_pandas().to_markdown(index=False))
+            lines.append("")
+
+        if self.risk_ratio is not None:
+            lines.append("#### Risk Ratios")
+            lines.append("")
+            lines.append(self.risk_ratio.to_pandas().to_markdown(index=False))
+            lines.append("")
+
+        if self.km_graph is not None and img_path is not None:
+            lines.append("#### Survival Curves")
+            lines.append("")
+            lines.append(f"![Kaplan-Meier Survival Curves]({img_path})")
+            lines.append("")
+
+    if self.diagnostic_tables:
+        lines.append("## Diagnostic Tables")
+        lines.append("")
+        for name, table in self.diagnostic_tables.items():
+            lines.append(f"### {name.replace('_', ' ').title()}")
+            lines.append("")
+            lines.append(table.to_pandas().to_markdown(index=False))
+            lines.append("")
+
+    return "\n".join(lines)
+
+
+def _build_pdf(md_content: str, filename: str, img_path: str = None) -> None:
+    """
+    Converts markdown content to PDF.
+
+    :param md_content: Markdown string
+    :param filename: Output PDF path
+    :param img_path: Absolute path to image file (if any)
+    """
+    try:
+        import markdown
+        from weasyprint import CSS, HTML
+    except ImportError:
+        raise ImportError(
+            "PDF generation requires 'markdown' and 'weasyprint'. "
+            "Install with: pip install markdown weasyprint"
+        )
+
+    html_content = markdown.markdown(md_content, extensions=["tables", "fenced_code"])
+
+    if img_path:
+        img_name = img_path.split("/")[-1]
+        html_content = html_content.replace(
+            f'src="{img_name}"', f'src="file://{img_path}"'
+        )
+
+    css = CSS(
+        string="""
+        body {
+            font-family: Arial, sans-serif;
+            font-size: 11pt;
+            line-height: 1.4;
+            margin: 2cm;
+        }
+        h1 { color: #2c3e50; border-bottom: 2px solid #2c3e50; padding-bottom: 0.3em; }
+        h2 { color: #34495e; border-bottom: 1px solid #bdc3c7; padding-bottom: 0.2em; }
+        h3 { color: #7f8c8d; }
+        table {
+            border-collapse: collapse;
+            width: 100%;
+            margin: 1em 0;
+        }
+        th, td {
+            border: 1px solid #bdc3c7;
+            padding: 8px;
+            text-align: left;
+        }
+        th { background-color: #ecf0f1; }
+        tr:nth-child(even) { background-color: #f9f9f9; }
+        pre {
+            background-color: #f4f4f4;
+            padding: 1em;
+            border-radius: 4px;
+            overflow-x: auto;
+            font-size: 9pt;
+        }
+        code { font-family: 'Courier New', monospace; }
+        img { max-width: 100%; height: auto; }
+        """
+    )
+
+    full_html = f"""
+    <!DOCTYPE html>
+    <html>
+    <head><meta charset="utf-8"></head>
+    <body>{html_content}</body>
+    </html>
+    """
+
+    HTML(string=full_html).write_pdf(filename, stylesheets=[css])
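
Note: a minimal usage sketch with hypothetical paths, assuming a fitted
SEQoutput-like instance res:

    md = _build_md(res, img_path="km_curves.png")
    with open("report.md", "w") as f:
        f.write(md)
    _build_pdf(md, "report.pdf", img_path="/abs/path/to/km_curves.png")
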
pySEQTarget/helpers/_pad.py
@@ -0,0 +1,7 @@
+def _pad(a, b):
+    len_a, len_b = len(a), len(b)
+    if len_a < len_b:
+        a = a + [None] * (len_b - len_a)
+    elif len_b < len_a:
+        b = b + [None] * (len_a - len_b)
+    return a, b
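
Note: for example, the shorter list is right-padded with None:

    print(_pad([1, 2, 3], ["a"]))  # ([1, 2, 3], ['a', None, None])
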
pySEQTarget/helpers/_predict_model.py
@@ -0,0 +1,9 @@
+import numpy as np
+
+
+def _predict_model(self, model, newdata):
+    newdata = newdata.to_pandas()
+    for col in self.fixed_cols:
+        if col in newdata.columns:
+            newdata[col] = newdata[col].astype("category")
+    return np.array(model.predict(newdata))
pySEQTarget/helpers/_prepare_data.py
@@ -0,0 +1,19 @@
+import polars as pl
+
+
+def _prepare_data(self, DT):
+    binaries = [
+        self.eligible_col,
+        self.outcome_col,
+        self.cense_colname,
+    ]  # self.excused_colnames + self.weight_eligible_colnames
+    binary_colnames = [col for col in binaries if col is not None and col in DT.columns]
+
+    DT = DT.with_columns(
+        [
+            *[pl.col(col).cast(pl.Categorical) for col in self.fixed_cols],
+            *[pl.col(col).cast(pl.Int8) for col in binary_colnames],
+            pl.col(self.id_col).cast(pl.Utf8),
+        ]
+    )
+    return DT
pySEQTarget/initialization/__init__.py
@@ -0,0 +1,5 @@
+from ._censoring import _cense_denominator as _cense_denominator
+from ._censoring import _cense_numerator as _cense_numerator
+from ._denominator import _denominator as _denominator
+from ._numerator import _numerator as _numerator
+from ._outcome import _outcome as _outcome
+ from ._outcome import _outcome as _outcome