csrlite 0.2.0-py3-none-any.whl → 0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,327 @@
+ # pyre-strict
+ """
+ Concomitant Medications (CM) Summary Functions
+
+ This module provides a three-step pipeline for CM summary analysis:
+ - cm_summary_ard: Generate Analysis Results Data (ARD) in long format
+ - cm_summary_df: Transform ARD to wide display format
+ - cm_summary_rtf: Generate formatted RTF output
+ - cm_summary: Complete pipeline wrapper
+ - study_plan_to_cm_summary: Batch generation from StudyPlan
+
+ Applications:
+ - Summary of Concomitant Medications
+ - Summary of Prior Medications
+ """
+
+ from pathlib import Path
+
+ import polars as pl
+ from rtflite import RTFDocument
+
+ from ..common.count import count_subject, count_subject_with_observation
+ from ..common.parse import StudyPlanParser
+ from ..common.plan import StudyPlan
+ from ..common.rtf import create_rtf_table_n_pct
+ from ..common.utils import apply_common_filters
+
+
+ def study_plan_to_cm_summary(
+     study_plan: StudyPlan,
+ ) -> list[str]:
+     """
+     Generate CM summary RTF outputs for all analyses defined in StudyPlan.
+
+     Args:
+         study_plan: StudyPlan object with loaded datasets and analysis specifications
+
+     Returns:
+         list[str]: List of paths to generated RTF files
+     """
+
+     # Meta data
+     analysis = "cm_summary"
+     analysis_label = "Summary of Concomitant Medications"
+     output_dir = study_plan.output_dir
+     footnote = ["Every participant is counted a single time for each applicable row and column."]
+     source = None
+
+     population_df_name = "adsl"
+     observation_df_name = "adcm"
+
+     id = ("USUBJID", "Subject ID")
+     total = True
+     missing_group = "error"
+
+     # Create output directory if it doesn't exist
+     Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+     # Initialize parser
+     parser = StudyPlanParser(study_plan)
+
+     # Get expanded plan DataFrame
+     plan_df = study_plan.get_plan_df()
+
+     # Filter for CM summary analyses
+     cm_plans = plan_df.filter(pl.col("analysis") == analysis)
+
+     rtf_files = []
+
+     # Generate RTF for each analysis
+     for row in cm_plans.iter_rows(named=True):
+         population = row["population"]
+         observation = row.get("observation")
+         parameter = row.get("parameter")
+         group = row.get("group")
+
+         # Validate group is specified
+         if group is None:
+             raise ValueError(
+                 f"Group not specified in YAML for "
+                 f"population={population}, observation={observation}, parameter={parameter}. "
+                 "Please add group to your YAML plan."
+             )
+
+         # Get datasets using parser
+         population_df, observation_df = parser.get_datasets(population_df_name, observation_df_name)
+
+         # Get filters and configuration using parser
+         population_filter = parser.get_population_filter(population)
+
+         # Handle parameters (variables to summarize)
+         if parameter:
+             param_names, param_filters, param_labels, _ = parser.get_parameter_info(parameter)
+         else:
+             # No parameter specified: default to summarizing "Any Medication".
+             # cm_summary usually expects parameters defining what to count, so
+             # fall back to a generic catch-all filter when none are provided.
+             param_filters = ["1=1"]
+             param_labels = ["Any Medication"]
+
+         obs_filter = parser.get_observation_filter(observation)
+         group_var_name, group_labels = parser.get_group_info(group)
+
+         # Build variables as list of tuples [(filter, label)]
+         variables_list = list(zip(param_filters, param_labels))
+
+         # Build group tuple (variable_name, label)
+         group_var_label = group_labels[0] if group_labels else group_var_name
+         group_tuple = (group_var_name, group_var_label)
+
+         # Build title
+         title_parts = [analysis_label]
+         if observation:
+             obs_kw = study_plan.keywords.observations.get(observation)
+             if obs_kw and obs_kw.label:
+                 title_parts.append(obs_kw.label)
+
+         pop_kw = study_plan.keywords.populations.get(population)
+         if pop_kw and pop_kw.label:
+             title_parts.append(pop_kw.label)
+
+         # Build output filename
+         filename = f"{analysis}_{population}"
+         if observation:
+             filename += f"_{observation}"
+         if parameter:
+             filename += f"_{parameter.replace(';', '_')}"
+         filename += ".rtf"
+         output_file = str(Path(output_dir) / filename)
+
+         # Generate RTF
+         rtf_path = cm_summary(
+             population=population_df,
+             observation=observation_df,
+             population_filter=population_filter,
+             observation_filter=obs_filter,
+             id=id,
+             group=group_tuple,
+             variables=variables_list,
+             title=title_parts,
+             footnote=footnote,
+             source=source,
+             output_file=output_file,
+             total=total,
+             missing_group=missing_group,
+         )
+
+         rtf_files.append(rtf_path)
+
+     return rtf_files
+
+
+ def cm_summary_ard(
+     population: pl.DataFrame,
+     observation: pl.DataFrame,
+     population_filter: str | None,
+     observation_filter: str | None,
+     id: tuple[str, str],
+     group: tuple[str, str],
+     variables: list[tuple[str, str]],
+     total: bool,
+     missing_group: str,
+ ) -> pl.DataFrame:
+     """
+     Generate Analysis Results Data (ARD) for CM summary analysis.
+     """
+     # Reuses the same logic as ae_summary_ard, since this is generic subject counting.
+     # The code is duplicated rather than imported so that this module stays
+     # independent (e.g. in case CM-specific logic is needed later).
+
+     pop_var_name = "Participants in population"
+     id_var_name, id_var_label = id
+     group_var_name, group_var_label = group
+
+     population_filtered, observation_to_filter = apply_common_filters(
+         population=population,
+         observation=observation,
+         population_filter=population_filter,
+         observation_filter=observation_filter,
+     )
+
+     assert observation_to_filter is not None
+
+     observation_filtered_list = []
+     for variable_filter, variable_label in variables:
+         obs_filtered = (
+             observation_to_filter.filter(
+                 pl.col(id_var_name).is_in(population_filtered[id_var_name].to_list())
+             )
+             .filter(pl.sql_expr(variable_filter))
+             .with_columns(pl.lit(variable_label).alias("__index__"))
+         )
+         observation_filtered_list.append(obs_filtered)
+
+     if observation_filtered_list:
+         observation_filtered = pl.concat(observation_filtered_list)
+     else:
+         # Handle case with no variables (empty df with correct schema)
+         observation_filtered = observation_to_filter.clear().with_columns(
+             pl.lit("").alias("__index__")
+         )
+
+     # Population counts
+     n_pop = count_subject(
+         population=population_filtered,
+         id=id_var_name,
+         group=group_var_name,
+         total=total,
+         missing_group=missing_group,
+     )
+
+     n_pop = n_pop.select(
+         pl.lit(pop_var_name).alias("__index__"),
+         pl.col(group_var_name).alias("__group__"),
+         pl.col("n_subj_pop").cast(pl.String).alias("__value__"),
+     )
+
+     n_empty = n_pop.select(
+         pl.lit("").alias("__index__"), pl.col("__group__"), pl.lit("").alias("__value__")
+     )
+
+     # Observation counts
+     n_obs = count_subject_with_observation(
+         population=population_filtered,
+         observation=observation_filtered,
+         id=id_var_name,
+         group=group_var_name,
+         total=total,
+         variable="__index__",
+         missing_group=missing_group,
+     )
+
+     n_obs = n_obs.select(
+         pl.col("__index__"),
+         pl.col(group_var_name).alias("__group__"),
+         pl.col("n_pct_subj_fmt").alias("__value__"),
+     )
+
+     res = pl.concat([n_pop, n_empty, n_obs])
+
+     variable_labels = [label for _, label in variables]
+     ordered_categories = [pop_var_name, ""] + variable_labels
+
+     # Ensure all categories are present in Enum
+     res = res.with_columns(pl.col("__index__").cast(pl.Enum(ordered_categories))).sort(
+         "__index__", "__group__"
+     )
+
+     return res
+
+
+ def cm_summary_df(ard: pl.DataFrame) -> pl.DataFrame:
+     """Transform CM summary ARD into display-ready DataFrame."""
+     df_wide = ard.pivot(index="__index__", on="__group__", values="__value__")
+     return df_wide
+
+
+ def cm_summary_rtf(
+     df: pl.DataFrame,
+     title: list[str],
+     footnote: list[str] | None,
+     source: list[str] | None,
+     col_rel_width: list[float] | None = None,
+ ) -> RTFDocument:
+     """Generate RTF table from CM summary display DataFrame."""
+     df_rtf = df.rename({"__index__": ""})
+     n_cols = len(df_rtf.columns)
+     col_header_1 = list(df_rtf.columns)
+     col_header_2 = [""] + ["n (%)"] * (n_cols - 1)
+
+     if col_rel_width is None:
+         col_widths = [float(n_cols - 1)] + [1.0] * (n_cols - 1)
+     else:
+         col_widths = col_rel_width
+
+     return create_rtf_table_n_pct(
+         df=df_rtf,
+         col_header_1=col_header_1,
+         col_header_2=col_header_2,
+         col_widths=col_widths,
+         title=title,
+         footnote=footnote,
+         source=source,
+     )
+
+
+ def cm_summary(
+     population: pl.DataFrame,
+     observation: pl.DataFrame,
+     population_filter: str | None,
+     observation_filter: str | None,
+     id: tuple[str, str],
+     group: tuple[str, str],
+     variables: list[tuple[str, str]],
+     title: list[str],
+     footnote: list[str] | None,
+     source: list[str] | None,
+     output_file: str,
+     total: bool = True,
+     col_rel_width: list[float] | None = None,
+     missing_group: str = "error",
+ ) -> str:
+     """Complete CM summary pipeline wrapper."""
+     ard = cm_summary_ard(
+         population=population,
+         observation=observation,
+         population_filter=population_filter,
+         observation_filter=observation_filter,
+         id=id,
+         group=group,
+         variables=variables,
+         total=total,
+         missing_group=missing_group,
+     )
+
+     df = cm_summary_df(ard)
+
+     rtf_doc = cm_summary_rtf(
+         df=df,
+         title=title,
+         footnote=footnote,
+         source=source,
+         col_rel_width=col_rel_width,
+     )
+     rtf_doc.write_rtf(output_file)
+
+     return output_file
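
For orientation, here is a minimal usage sketch of the new three-step pipeline added above. It is illustrative only: the `csrlite.cm` import path, the `TRT01A`/`CMDECOD` column names, and the SQL-style variable filter are assumptions; only the function names and signatures come from the diff.

```python
import polars as pl

# Assumed import path; the diff does not show the module's location in the package.
from csrlite.cm import cm_summary_ard, cm_summary_df, cm_summary_rtf

# Toy ADSL (population) and ADCM (observation) datasets with assumed column names.
adsl = pl.DataFrame(
    {"USUBJID": ["01", "02", "03", "04"], "TRT01A": ["Drug", "Drug", "Placebo", "Placebo"]}
)
adcm = pl.DataFrame({"USUBJID": ["01", "03", "03"], "CMDECOD": ["ASPIRIN", "IBUPROFEN", "ASPIRIN"]})

# Step 1: long-format ARD (one row per summary row and treatment group)
ard = cm_summary_ard(
    population=adsl,
    observation=adcm,
    population_filter=None,
    observation_filter=None,
    id=("USUBJID", "Subject ID"),
    group=("TRT01A", "Treatment"),
    variables=[("CMDECOD IS NOT NULL", "Any Medication")],  # SQL expression, per pl.sql_expr
    total=True,
    missing_group="error",
)

# Step 2: pivot to one column per treatment group
df = cm_summary_df(ard)

# Step 3: render and write the RTF table
doc = cm_summary_rtf(df=df, title=["Summary of Concomitant Medications"], footnote=None, source=None)
doc.write_rtf("cm_summary_example.rtf")
```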
csrlite/common/rtf.py CHANGED
@@ -83,3 +83,55 @@ def create_rtf_table_n_pct(
          rtf_components["rtf_source"] = RTFSource(text=source_list)

      return RTFDocument(**rtf_components)
+
+
+ def create_rtf_listing(
+     df: pl.DataFrame,
+     col_header: list[str],
+     col_widths: list[float] | None,
+     title: list[str] | str,
+     footnote: list[str] | str | None,
+     source: list[str] | str | None,
+     orientation: str = "landscape",
+ ) -> RTFDocument:
+     """
+     Create a standardized RTF listing document.
+     """
+     n_cols = len(df.columns)
+
+     # Calculate column widths if None
+     if col_widths is None:
+         col_widths = [1.0] * n_cols
+
+     # Normalize metadata
+     title_list = [title] if isinstance(title, str) else title
+     footnote_list = [footnote] if isinstance(footnote, str) else (footnote or [])
+     source_list = [source] if isinstance(source, str) else (source or [])
+
+     headers = [
+         RTFColumnHeader(
+             text=col_header,
+             col_rel_width=col_widths,
+             text_justification=["l"] * n_cols,
+         )
+     ]
+
+     rtf_components: dict[str, Any] = {
+         "df": df,
+         "rtf_page": RTFPage(orientation=orientation),
+         "rtf_title": RTFTitle(text=title_list),
+         "rtf_column_header": headers,
+         "rtf_body": RTFBody(
+             col_rel_width=col_widths,
+             text_justification=["l"] * n_cols,
+             border_left=["single"] * n_cols,
+         ),
+     }
+
+     if footnote_list:
+         rtf_components["rtf_footnote"] = RTFFootnote(text=footnote_list)
+
+     if source_list:
+         rtf_components["rtf_source"] = RTFSource(text=source_list)
+
+     return RTFDocument(**rtf_components)
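
A quick sketch of calling the new `create_rtf_listing` helper directly. The sample data and labels below are invented for illustration; the signature is taken from the hunk above.

```python
import polars as pl

from csrlite.common.rtf import create_rtf_listing

# Two-column toy listing; any polars DataFrame works.
df = pl.DataFrame(
    {"USUBJID": ["01-001", "01-002"], "DCSREAS": ["Adverse Event", "Withdrawal by Subject"]}
)

doc = create_rtf_listing(
    df=df,
    col_header=["Subject ID", "Discontinued Reason"],
    col_widths=None,  # None falls back to equal widths, one per column
    title="Listing of Study Discontinuations",  # str or list[str] accepted
    footnote=None,
    source="Source: ADSL",
)
doc.write_rtf("listing_example.rtf")
```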
@@ -39,7 +39,7 @@ def study_plan_to_disposition_summary(
 
      id = ("USUBJID", "Subject ID")
      ds_term = ("EOSSTT", "Disposition Status")
-     dist_reason_term = ("DCREASCD", "Discontinued Reason")
+     dist_reason_term = ("DCSREAS", "Discontinued Reason")
 
      total = True
      missing_group = "error"
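
The one-line change above switches the disposition summary from `DCREASCD` to `DCSREAS` as the discontinuation-reason variable. For datasets that still carry the old name, a small guard like the hypothetical helper below (not part of csrlite) can pick whichever column is present:

```python
import polars as pl


def resolve_discontinuation_column(adsl: pl.DataFrame) -> str:
    """Return the discontinuation-reason column present in this ADSL (hypothetical helper)."""
    for candidate in ("DCSREAS", "DCREASCD"):
        if candidate in adsl.columns:
            return candidate
    raise KeyError("ADSL has neither DCSREAS nor DCREASCD")


adsl = pl.DataFrame({"USUBJID": ["01-001"], "DCSREAS": ["Adverse Event"]})
dist_reason_term = (resolve_discontinuation_column(adsl), "Discontinued Reason")
```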
@@ -0,0 +1,109 @@
+ # pyre-strict
+ """
+ Inclusion/Exclusion (IE) Listing Analysis Functions
+ """
+
+ from pathlib import Path
+
+ import polars as pl
+
+ from ..common.parse import StudyPlanParser
+ from ..common.plan import StudyPlan
+ from ..common.rtf import create_rtf_listing
+ from ..common.utils import apply_common_filters
+
+
+ def study_plan_to_ie_listing(
+     study_plan: StudyPlan,
+ ) -> list[str]:
+     """
+     Generate IE Listing outputs.
+     """
+     # Meta data
+     analysis_type = "ie_listing"
+     output_dir = study_plan.output_dir
+     title = "Listing of Protocol Deviations"
+
+     # Ensure output directory exists
+     Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+     # Initialize parser
+     parser = StudyPlanParser(study_plan)
+
+     # Get expanded plan (manual expansion to avoid AttributeError)
+     plans = study_plan.study_data.get("plans", [])
+     all_specs = []
+     for plan_data in plans:
+         expanded = study_plan.expander.expand_plan(plan_data)
+         for p in expanded:
+             all_specs.append(study_plan.expander.create_analysis_spec(p))
+
+     plan_df = pl.DataFrame(all_specs)
+
+     if "analysis" in plan_df.columns:
+         listing_plans = plan_df.filter(pl.col("analysis") == analysis_type)
+     else:
+         listing_plans = pl.DataFrame()
+
+     generated_files = []
+
+     # If listing_plans is empty, create a dummy row to force generation
+     if listing_plans.height == 0:
+         listing_plans = pl.DataFrame([{"population": "enrolled", "analysis": analysis_type}])
+
+     for analysis in listing_plans.iter_rows(named=True):
+         # Load ADSL
+         pop_name = analysis.get("population", "enrolled")
+
+         try:
+             (adsl_raw,) = parser.get_datasets("adsl")
+             pop_filter = parser.get_population_filter(pop_name)
+
+             adsl, _ = apply_common_filters(
+                 population=adsl_raw,
+                 observation=None,
+                 population_filter=pop_filter,
+                 observation_filter=None,
+             )
+
+         except ValueError as e:
+             print(f"Error loading population: {e}")
+             continue
+
+         # Output filename
+         filename = f"{analysis_type}_{pop_name}.rtf".lower()
+         output_path = f"{output_dir}/{filename}"
+
+         # Generate DF
+         df = ie_listing_df(adsl)
+
+         # Generate RTF
+         ie_listing_rtf(df, output_path, title=title)
+
+         generated_files.append(output_path)
+
+     return generated_files
+
+
+ def ie_listing_df(adsl: pl.DataFrame) -> pl.DataFrame:
+     """Select columns for Listing."""
+     # Check if DCSREAS exists
+     cols = ["USUBJID", "DCSREAS"]
+     available = [c for c in cols if c in adsl.columns]
+     return adsl.select(available)
+
+
+ def ie_listing_rtf(df: pl.DataFrame, output_path: str, title: str | list[str] = "") -> None:
+     """Generate RTF Listing."""
+     col_widths = [1.5, 3.5]  # Approximate ratio
+
+     rtf_doc = create_rtf_listing(
+         df=df,
+         col_header=list(df.columns),
+         col_widths=col_widths,
+         title=title,
+         footnote=[],
+         source=[],
+     )
+
+     rtf_doc.write_rtf(output_path)
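
Finally, a small sketch of driving the two lower-level listing functions directly. The import path and the sample ADSL values are assumptions; only the signatures come from the file above.

```python
import polars as pl

# Assumed import path; the diff does not show where this module lives in the package.
from csrlite.ie_listing import ie_listing_df, ie_listing_rtf

adsl = pl.DataFrame(
    {
        "USUBJID": ["01-001", "01-002"],
        "DCSREAS": ["Adverse Event", None],
        "AGE": [54, 61],  # extra columns are dropped by ie_listing_df
    }
)

df = ie_listing_df(adsl)  # keeps only USUBJID and DCSREAS
ie_listing_rtf(df, "ie_listing_example.rtf", title="Listing of Protocol Deviations")
```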