csrlite 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,333 @@
1
+ # pyre-strict
2
+ """
3
+ Medical History (MH) Summary Analysis Functions
4
+ """
5
+
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import polars as pl
10
+
11
+ from ..common.parse import StudyPlanParser
12
+ from ..common.plan import StudyPlan
13
+ from ..common.rtf import create_rtf_table_n_pct
14
+ from ..common.utils import apply_common_filters
15
+
16
+
17
def mh_summary(
    population: pl.DataFrame,
    observation: pl.DataFrame,
    population_filter: str | None = "SAFFL = 'Y'",
    observation_filter: str | None = "MHOCCUR = 'Y'",
    id: tuple[str, str] = ("USUBJID", "Subject ID"),
    group: tuple[str, str] = ("TRT01A", "Treatment"),
    variables: list[tuple[str, str]] | None = None,
    title: list[str] | None = None,
    footnote: list[str] | None = None,
    source: list[str] | None = None,
    output_file: str = "mh_summary.rtf",
) -> str:
    """Generate a Medical History summary table and write it to an RTF file.

    Pipeline: raw data -> analysis results dataset (ARD) -> display
    DataFrame -> RTF document.

    Args:
        population: Subject-level population data (e.g. ADSL).
        observation: Medical-history records (e.g. ADMH).
        population_filter: Filter expression for ``population`` (None keeps all).
        observation_filter: Filter expression for ``observation`` (None keeps all).
        id: (column, label) pair identifying subjects.
        group: (column, label) pair for the treatment grouping.
        variables: Hierarchy of (column, label) pairs; defaults to
            body system -> preferred term.
        title: Title line(s); a standard title is used when None.
        footnote: Optional footnote line(s).
        source: Optional source line(s).
        output_file: Destination RTF path.

    Returns:
        The path the RTF file was written to (``output_file``).
    """
    # Fall back to the conventional SOC -> PT hierarchy.
    if variables is None:
        variables = [("MHBODSYS", "System Organ Class"), ("MHDECOD", "Preferred Term")]

    if title is None:
        title = ["Summary of Medical History by Body System and Preferred Term"]

    # Step 1: compute counts/percentages per group.
    ard = mh_summary_ard(
        population=population,
        observation=observation,
        population_filter=population_filter,
        observation_filter=observation_filter,
        group_col=group[0],
        id_col=id[0],
        variables=variables,
    )

    # Step 2: format into display-ready strings.
    display_df = mh_summary_df(ard)

    # Step 3: render the RTF table to disk.
    mh_summary_rtf(
        df=display_df,
        output_path=output_file,
        title=title,
        footnote=footnote,
        source=source,
    )

    return output_file
58
+
59
+
60
def mh_summary_ard(
    population: pl.DataFrame,
    observation: pl.DataFrame,
    population_filter: str | None,
    observation_filter: str | None,
    group_col: str,
    id_col: str,
    variables: list[tuple[str, str]],
) -> pl.DataFrame:
    """Build the analysis results dataset (ARD) for the MH summary.

    Rows are produced in display order:
      * "Any Medical History" (all filtered records),
      * each body system (MHBODSYS), sorted,
      * each preferred term (MHDECOD) nested under its body system.

    Each row carries per-group subject counts and percentages; the
    denominator is the filtered population for that group.
    """
    # Apply the standard population/observation filters.
    adsl, adq = apply_common_filters(
        population=population,
        observation=observation,
        population_filter=population_filter,
        observation_filter=observation_filter,
    )

    if adq is None:
        # Defensive: an observation frame was supplied, so this should not occur.
        raise ValueError("Observation data is missing")

    # Row specifications: (filter expression, display label, indent level).
    # Every row carries counts, so none is a pure header row.
    row_specs: list[tuple[pl.Expr, str, int]] = [
        (pl.lit(True), "Any Medical History", 0),
    ]

    # Distinct body systems, sorted; nulls are skipped below.
    distinct_systems: list[str | None] = (
        adq.select("MHBODSYS").unique().sort("MHBODSYS").to_series().to_list()
    )

    for body_system in distinct_systems:
        if body_system is None:
            continue

        # Body-system row (indent level 1).
        row_specs.append((pl.col("MHBODSYS") == body_system, body_system, 1))

        # Distinct preferred terms within this body system, sorted.
        distinct_terms: list[str | None] = (
            adq.filter(pl.col("MHBODSYS") == body_system)
            .select("MHDECOD")
            .unique()
            .sort("MHDECOD")
            .to_series()
            .to_list()
        )

        for preferred_term in distinct_terms:
            if preferred_term is None:
                continue
            # Preferred-term row (indent level 2), nested under the system.
            row_specs.append(
                (
                    (pl.col("MHBODSYS") == body_system)
                    & (pl.col("MHDECOD") == preferred_term),
                    preferred_term,
                    2,
                )
            )

    # Population denominators per treatment group, in sorted group order.
    pop_counts = adsl.group_by(group_col).count().sort(group_col)
    group_values: list[Any] = pop_counts.select(group_col).to_series().to_list()
    denominators: dict[Any, int] = {
        rec[group_col]: rec["count"] for rec in pop_counts.iter_rows(named=True)
    }

    rows: list[dict[str, Any]] = []
    for expr, label, indent in row_specs:
        # Subjects with at least one qualifying record, restricted to the
        # analysis population via an inner join on the subject id.
        matched = adq.filter(expr).join(
            adsl.select([id_col, group_col]), on=id_col, how="inner"
        )

        # Count distinct subjects per group.
        per_group = (
            matched.select(id_col, group_col).unique().group_by(group_col).count()
        )
        numerators = {
            rec[group_col]: rec["count"] for rec in per_group.iter_rows(named=True)
        }

        row: dict[str, Any] = {"label": label, "indent": indent, "is_header": False}
        for g in group_values:
            n = numerators.get(g, 0)
            denom = denominators.get(g, 0)
            row[f"count_{g}"] = n
            row[f"pct_{g}"] = (n / denom * 100.0) if denom > 0 else 0.0
        rows.append(row)

    return pl.DataFrame(rows)
198
+
199
+
200
def mh_summary_df(ard: pl.DataFrame) -> pl.DataFrame:
    """Transform the ARD into a display-ready DataFrame.

    Produces one "Medical History" label column (indented with one leading
    space per ARD ``indent`` level) plus one "n (pct)" string column per
    treatment group.

    Args:
        ard: Output of :func:`mh_summary_ard`; expected to contain ``label``,
            ``indent`` and paired ``count_<group>`` / ``pct_<group>`` columns.

    Returns:
        Display DataFrame, or an empty DataFrame when ``ard`` is empty.
    """
    if ard.is_empty():
        return pl.DataFrame()

    # Recover the group names from the count_* columns.
    group_cols = [c for c in ard.columns if c.startswith("count_")]
    # BUG FIX: str.replace("count_", "") strips *every* occurrence of the
    # marker, corrupting any group name that itself contains "count_";
    # removeprefix strips only the leading marker.
    groups = [c.removeprefix("count_") for c in group_cols]

    select_exprs = [
        (pl.lit(" ").repeat_by(pl.col("indent")).list.join("") + pl.col("label")).alias(
            "Medical History"
        )
    ]

    for g in groups:
        col_n = pl.col(f"count_{g}")
        col_pct = pl.col(f"pct_{g}")

        # Format as "n (p.p)" with one decimal place of percent.
        fmt = (
            col_n.cast(pl.Utf8)
            + " ("
            + col_pct.map_elements(lambda x: f"{x:.1f}", return_dtype=pl.Utf8)
            + ")"
        ).alias(g)

        select_exprs.append(fmt)

    return ard.select(select_exprs)
232
+
233
+
234
def mh_summary_rtf(
    df: pl.DataFrame,
    output_path: str,
    title: list[str] | str,
    footnote: list[str] | None,
    source: list[str] | None,
) -> None:
    """Render the display DataFrame as an RTF table on disk.

    Args:
        df: Display DataFrame; first column holds labels, the rest hold
            per-group "n (%)" strings.
        output_path: Destination file path.
        title: Table title line(s).
        footnote: Optional footnote line(s).
        source: Optional source line(s).

    Notes:
        Writes nothing when ``df`` is empty.
    """
    if df.is_empty():
        # Nothing to render; skip quietly so batch runs are not aborted.
        return

    n_cols = len(df.columns)
    col_width_first = 2.5  # width reserved for the label column
    remaining_width = 7.0  # approximate remaining page width

    # BUG FIX: guard the single-column case, which previously raised
    # ZeroDivisionError on `remaining_width / (n_cols - 1)`.
    if n_cols > 1:
        col_width_others = remaining_width / (n_cols - 1)
        col_widths = [col_width_first] + [col_width_others] * (n_cols - 1)
    else:
        col_widths = [col_width_first]

    col_header_1 = list(df.columns)
    col_header_2 = [""] + ["n (%)"] * (n_cols - 1)

    rtf_doc = create_rtf_table_n_pct(
        df=df,
        col_header_1=col_header_1,
        col_header_2=col_header_2,
        col_widths=col_widths,
        title=title,
        footnote=footnote,
        source=source,
    )

    rtf_doc.write_rtf(output_path)
268
+
269
+
270
def study_plan_to_mh_summary(study_plan: StudyPlan) -> list[str]:
    """Batch-generate MH summary tables for every matching study-plan entry.

    Expands the study plan into concrete analysis specifications, selects
    the entries whose ``analysis`` is ``"mh_summary"`` and produces one RTF
    per entry. Failures on an individual entry are reported and skipped so
    the rest of the batch still runs.

    Args:
        study_plan: Parsed study plan holding datasets, plans and output dir.

    Returns:
        Paths of the RTF files that were generated successfully.
    """
    analysis_type = "mh_summary"
    output_dir = study_plan.output_dir

    parser = StudyPlanParser(study_plan)

    # Expand plan templates into concrete analysis specifications.
    plans = study_plan.study_data.get("plans", [])
    all_specs = []
    for plan_data in plans:
        expanded = study_plan.expander.expand_plan(plan_data)
        for p in expanded:
            all_specs.append(study_plan.expander.create_analysis_spec(p))

    plan_df = pl.DataFrame(all_specs)

    if "analysis" in plan_df.columns:
        mh_plans = plan_df.filter(pl.col("analysis") == analysis_type)
    else:
        # No specs at all (or no "analysis" field) -> nothing to generate.
        mh_plans = pl.DataFrame()

    generated_files = []

    for analysis in mh_plans.iter_rows(named=True):
        pop_name = analysis.get("population", "enrolled")
        group_kw = analysis.get("group", "trt01a")

        try:
            # Population data, already filtered by the parser.
            adsl, group_col = parser.get_population_data(pop_name, group_kw)

            # MH analysis dataset is assumed to be registered as 'admh'.
            (admh,) = parser.get_datasets("admh")

            filename = f"{analysis_type}_{pop_name}_{group_kw}.rtf".lower()
            # BUG FIX: the computed filename was previously discarded and a
            # hard-coded path was used instead, so every plan entry wrote to
            # (and overwrote) the same file.
            output_path = f"{output_dir}/{filename}"
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)

            mh_summary(
                population=adsl,
                observation=admh,
                population_filter=None,  # parser already applied the population filter
                observation_filter="MHOCCUR = 'Y'",
                group=(group_col, group_col),  # actual column name, used as label too
                output_file=output_path,
                title=[
                    "Summary of Medical History by System Organ Class and Preferred Term",
                    f"({pop_name} Population)",
                ],
                source=["Source: ADSL, ADMH"],
            )

            generated_files.append(output_path)

        except Exception as e:  # best-effort batch: report the entry and move on
            print(f"Error generating MH summary: {e}")
            continue

    return generated_files