csrlite 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,399 @@
1
+ # pyre-strict
2
+ """
3
+ Adverse Event (AE) Analysis Functions
4
+
5
+ This module provides a three-step pipeline for AE summary analysis:
6
+ - ae_summary_ard: Generate Analysis Results Data (ARD) in long format
7
+ - ae_summary_df: Transform ARD to wide display format
8
+ - ae_summary_rtf: Generate formatted RTF output
9
+ - ae_summary: Complete pipeline wrapper
10
+ - study_plan_to_ae_summary: Batch generation from StudyPlan
11
+
12
+ Uses Polars native SQL capabilities for data manipulation, count.py utilities for subject counting,
13
+ and parse.py utilities for StudyPlan parsing.
14
+ """
15
+
16
+ from pathlib import Path
17
+
18
+ import polars as pl
19
+ from rtflite import RTFDocument
20
+
21
+ from ..common.count import count_subject, count_subject_with_observation
22
+ from ..common.parse import StudyPlanParser
23
+ from ..common.plan import StudyPlan
24
+ from ..common.utils import apply_common_filters
25
+ from .ae_utils import create_ae_rtf_table
26
+
27
+
28
def study_plan_to_ae_summary(
    study_plan: StudyPlan,
) -> list[str]:
    """
    Generate AE summary RTF outputs for all analyses defined in StudyPlan.

    Reads the expanded plan from ``study_plan.get_plan_df()``, keeps the rows
    whose ``analysis`` column equals ``"ae_summary"``, and writes one RTF table
    per row into ``study_plan.output_dir`` (created if it does not exist).

    Args:
        study_plan: StudyPlan object with loaded datasets and analysis specifications

    Returns:
        list[str]: Paths to the generated RTF files, in plan-row order.

    Raises:
        ValueError: If a plan row does not specify a ``group``.
    """
    # Metadata shared by every AE summary output
    analysis = "ae_summary"
    analysis_label = "Analysis of Adverse Event Summary"
    output_dir = study_plan.output_dir
    footnote = ["Every participant is counted a single time for each applicable row and column."]
    source = None

    population_df_name = "adsl"
    observation_df_name = "adae"

    # Named ``id_spec`` rather than ``id`` so the builtin is not shadowed.
    id_spec = ("USUBJID", "Subject ID")
    total = True
    missing_group = "error"

    # Create output directory if it doesn't exist
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    parser = StudyPlanParser(study_plan)

    # Expanded plan, restricted to AE summary analyses
    plan_df = study_plan.get_plan_df()
    ae_plans = plan_df.filter(pl.col("analysis") == analysis)

    rtf_files: list[str] = []

    for row in ae_plans.iter_rows(named=True):
        population = row["population"]
        observation = row.get("observation")
        parameter = row["parameter"]
        group = row.get("group")

        # A grouping variable is mandatory: it defines the table's columns.
        if group is None:
            raise ValueError(
                "Group not specified in YAML for "
                f"population={population}, observation={observation}, parameter={parameter}. "
                "Please add group to your YAML plan."
            )

        population_df, observation_df = parser.get_datasets(population_df_name, observation_df_name)

        # Filters and configuration resolved from the plan keywords.
        population_filter = parser.get_population_filter(population)
        # Parameter names and indent (1st and 4th items) are not used for AE summaries.
        _, param_filters, param_labels, _ = parser.get_parameter_info(parameter)
        obs_filter = parser.get_observation_filter(observation)
        group_var_name, group_labels = parser.get_group_info(group)

        # Analysis variables as [(filter, label), ...]
        variables_list = list(zip(param_filters, param_labels))

        # Group tuple (variable_name, label); fall back to the variable name as label.
        group_var_label = group_labels[0] if group_labels else group_var_name
        group_tuple = (group_var_name, group_var_label)

        # Title: analysis label, then observation and population labels when available.
        title_parts = [analysis_label]
        if observation:
            obs_kw = study_plan.keywords.observations.get(observation)
            if obs_kw and obs_kw.label:
                title_parts.append(obs_kw.label)

        pop_kw = study_plan.keywords.populations.get(population)
        if pop_kw and pop_kw.label:
            title_parts.append(pop_kw.label)

        # Output filename: <analysis>_<population>[_<observation>]_<parameter>.rtf,
        # with ';' in multi-parameter specs replaced so the name stays filesystem-safe.
        filename = f"{analysis}_{population}"
        if observation:
            filename += f"_{observation}"
        filename += f"_{parameter.replace(';', '_')}.rtf"
        output_file = str(Path(output_dir) / filename)

        rtf_path = ae_summary(
            population=population_df,
            observation=observation_df,
            population_filter=population_filter,
            observation_filter=obs_filter,
            id=id_spec,
            group=group_tuple,
            variables=variables_list,
            title=title_parts,
            footnote=footnote,
            source=source,
            output_file=output_file,
            total=total,
            missing_group=missing_group,
        )

        rtf_files.append(rtf_path)

    return rtf_files
143
+
144
+
145
def ae_summary(
    population: pl.DataFrame,
    observation: pl.DataFrame,
    population_filter: str | None,
    observation_filter: str | None,
    id: tuple[str, str],
    group: tuple[str, str],
    variables: list[tuple[str, str]],
    title: list[str],
    footnote: list[str] | None,
    source: list[str] | None,
    output_file: str,
    total: bool = True,
    col_rel_width: list[float] | None = None,
    missing_group: str = "error",
) -> str:
    """
    Run the full AE summary pipeline and write the result to disk.

    The steps are chained as follows:
    ae_summary_ard (long-format ARD) -> ae_summary_df (wide display table)
    -> ae_summary_rtf (RTF document), which is then written to ``output_file``.

    Args:
        population: Population DataFrame (subject-level data, e.g., ADSL)
        observation: Observation DataFrame (event data, e.g., ADAE)
        population_filter: SQL WHERE clause for population (can be None)
        observation_filter: SQL WHERE clause for observation (can be None)
        id: Tuple (variable_name, label) for ID column
        group: Tuple (variable_name, label) for grouping variable
        variables: List of tuples [(filter, label)] for analysis variables
        title: Title for RTF output as list of strings
        footnote: Optional footnote for RTF output as list of strings
        source: Optional source for RTF output as list of strings
        output_file: File path to write RTF output
        total: Whether to include total column (default: True)
        col_rel_width: Optional column widths for RTF output
        missing_group: How to handle missing group values (default: "error")

    Returns:
        str: ``output_file``, the path the RTF document was written to
    """
    ard_long = ae_summary_ard(
        population=population,
        observation=observation,
        population_filter=population_filter,
        observation_filter=observation_filter,
        id=id,
        group=group,
        variables=variables,
        total=total,
        missing_group=missing_group,
    )

    # Pivot to display shape and lay it out as an RTF table in one go.
    document = ae_summary_rtf(
        df=ae_summary_df(ard_long),
        title=title,
        footnote=footnote,
        source=source,
        col_rel_width=col_rel_width,
    )
    document.write_rtf(output_file)

    return output_file
215
+
216
+
217
def ae_summary_ard(
    population: pl.DataFrame,
    observation: pl.DataFrame,
    population_filter: str | None,
    observation_filter: str | None,
    id: tuple[str, str],
    group: tuple[str, str],
    variables: list[tuple[str, str]],
    total: bool,
    missing_group: str,
) -> pl.DataFrame:
    """
    Generate Analysis Results Data (ARD) for AE summary analysis.

    Creates a long-format DataFrame with standardized structure (__index__, __group__, __value__)
    containing population counts and observation statistics for each analysis variable.
    Rows are ordered by ``__index__`` as follows: the population-count row, a blank spacer
    row (label ""), then one row per entry of ``variables`` in the given order. Ties are
    ordered by ``__group__``.

    Args:
        population: Population DataFrame (subject-level data, e.g., ADSL)
        observation: Observation DataFrame (event data, e.g., ADAE)
        population_filter: SQL WHERE clause for population (can be None)
        observation_filter: SQL WHERE clause for observation (can be None)
        id: Tuple (variable_name, label) for ID column
        group: Tuple (variable_name, label) for grouping variable
        variables: List of tuples [(filter, label)] for analysis variables
        total: Whether to include total column in counts
        missing_group: How to handle missing group values: "error", "ignore", or "fill"

    Returns:
        pl.DataFrame: Long-format ARD with columns __index__, __group__, __value__

    Raises:
        ValueError: If ``variables`` is empty, or if its labels are not unique or collide
            with the reserved population-row label or the blank spacer label "".
    """
    pop_var_name = "Participants in population"
    # Only the variable names are needed here; the labels are unused for counting.
    id_var_name, _ = id
    group_var_name, _ = group

    # Row labels become Enum categories below, so validate them before any heavy work.
    variable_labels = [label for _, label in variables]
    if not variable_labels:
        raise ValueError("variables must contain at least one (filter, label) pair")
    ordered_categories = [pop_var_name, ""] + variable_labels
    if len(set(ordered_categories)) != len(ordered_categories):
        raise ValueError(
            f"Variable labels must be unique and must not be {pop_var_name!r} or an "
            f"empty string; got {variable_labels}"
        )

    # Apply common filters (per-variable filters are applied below)
    population_filtered, observation_to_filter = apply_common_filters(
        population=population,
        observation=observation,
        population_filter=population_filter,
        observation_filter=observation_filter,
    )

    # Restrict observations to subjects in the filtered population once, instead of
    # rebuilding the ID list for every analysis variable.
    population_ids = population_filtered[id_var_name].to_list()
    observation_in_population = observation_to_filter.filter(
        pl.col(id_var_name).is_in(population_ids)
    )

    # One slice per analysis variable, tagged with its label in __index__
    observation_filtered = pl.concat(
        [
            observation_in_population.filter(pl.sql_expr(variable_filter)).with_columns(
                pl.lit(variable_label).alias("__index__")
            )
            for variable_filter, variable_label in variables
        ]
    )

    # Population counts per group
    n_pop = count_subject(
        population=population_filtered,
        id=id_var_name,
        group=group_var_name,
        total=total,
        missing_group=missing_group,
    )

    n_pop = n_pop.select(
        pl.lit(pop_var_name).alias("__index__"),
        pl.col(group_var_name).alias("__group__"),
        pl.col("n_subj_pop").cast(pl.String).alias("__value__"),
    )

    # Blank spacer row with the same groups as n_pop
    n_empty = n_pop.select(
        pl.lit("").alias("__index__"), pl.col("__group__"), pl.lit("").alias("__value__")
    )

    # Subjects with at least one matching observation, per variable and group
    n_obs = count_subject_with_observation(
        population=population_filtered,
        observation=observation_filtered,
        id=id_var_name,
        group=group_var_name,
        total=total,
        variable="__index__",
        missing_group=missing_group,
    )

    n_obs = n_obs.select(
        pl.col("__index__"),
        pl.col(group_var_name).alias("__group__"),
        pl.col("n_pct_subj_fmt").alias("__value__"),
    )

    res = pl.concat([n_pop, n_empty, n_obs])

    # An ordered Enum makes the sort follow the display order rather than alphabetical order.
    res = res.with_columns(pl.col("__index__").cast(pl.Enum(ordered_categories))).sort(
        "__index__", "__group__"
    )

    return res
327
+
328
+
329
def ae_summary_df(ard: pl.DataFrame) -> pl.DataFrame:
    """
    Reshape a long-format AE summary ARD into a wide display table.

    Each distinct ``__group__`` value becomes its own column holding the
    matching ``__value__`` entries, while ``__index__`` is kept as the row-key
    column.

    Args:
        ard: Long-format ARD with ``__index__``, ``__group__`` and ``__value__`` columns

    Returns:
        pl.DataFrame: Wide table with ``__index__`` plus one column per group
    """
    return ard.pivot(
        on="__group__",
        index="__index__",
        values="__value__",
    )
346
+
347
+
348
def ae_summary_rtf(
    df: pl.DataFrame,
    title: list[str],
    footnote: list[str] | None,
    source: list[str] | None,
    col_rel_width: list[float] | None = None,
) -> RTFDocument:
    """
    Generate RTF table from AE summary display DataFrame.

    Creates a formatted RTF table with two-level column headers. The first
    level holds the group column names, and the second level holds "n (%)"
    under each group.

    Args:
        df: Display DataFrame from ae_summary_df (wide format with __index__ column)
        title: Title(s) for the table as list of strings
        footnote: Optional footnote(s) as list of strings
        source: Optional source note(s) as list of strings
        col_rel_width: Optional list of relative column widths. If None, auto-calculated
            as [max(n_cols-1, 1), 1, 1, ...] where n_cols is total column count, so the
            label column is as wide as all group columns together and never zero-width

    Returns:
        RTFDocument: RTF document object that can be written to file
    """
    # The row-label column gets a blank header in the rendered table.
    df_rtf = df.rename({"__index__": ""})

    n_cols = len(df_rtf.columns)

    # First-level headers are the column names themselves.
    col_header_1 = list(df_rtf.columns)

    # Second-level headers are empty over the label column and "n (%)" over each group.
    col_header_2 = [""] + ["n (%)"] * (n_cols - 1)

    if col_rel_width is None:
        # Clamp to 1 so a table with no group columns does not get a 0-width label column.
        col_widths = [float(max(n_cols - 1, 1))] + [1.0] * (n_cols - 1)
    else:
        # Copy so the caller's list is not shared with the RTF components.
        col_widths = list(col_rel_width)

    return create_ae_rtf_table(
        df=df_rtf,
        col_header_1=col_header_1,
        col_header_2=col_header_2,
        col_widths=col_widths,
        title=title,
        footnote=footnote,
        source=source,
    )
csrlite/ae/ae_utils.py ADDED
@@ -0,0 +1,132 @@
1
+ # pyre-strict
2
+ from typing import Any
3
+
4
+ import polars as pl
5
+ from rtflite import RTFBody, RTFColumnHeader, RTFDocument, RTFFootnote, RTFPage, RTFSource, RTFTitle
6
+
7
+
8
def get_ae_parameter_title(param: Any, prefix: str = "Participants With") -> str:
    """
    Extract title from parameter for ae_* title generation.

    When ``param.terms`` is a non-empty dict, its optional "before" and "after"
    strings are title-cased and placed around "Adverse Events". Runs of
    whitespace are then collapsed. Otherwise the default
    "<prefix> Adverse Events" is returned.

    Args:
        param: Parameter object with terms attribute (may be None/falsy)
        prefix: Prefix for the title (e.g. "Participants With", "Listing of Participants With")

    Returns:
        Title string for the analysis
    """
    default_suffix = "Adverse Events"

    terms = getattr(param, "terms", None) if param else None
    if not terms or not isinstance(terms, dict):
        return f"{prefix} {default_suffix}"

    # `or ""` also covers keys that are present but explicitly set to None.
    before = (terms.get("before") or "").title()
    after = (terms.get("after") or "").title()

    # Empty terms leave doubled spaces; collapse them.
    title = f"{prefix} {before} {default_suffix} {after}"
    return " ".join(title.split())
38
+
39
+
40
def get_ae_parameter_row_labels(param: Any) -> tuple[str, str]:
    """
    Generate n_with and n_without row labels based on parameter terms.

    The optional "before" and "after" strings in ``param.terms`` are lowercased
    and placed around "adverse events". Whitespace is collapsed, and a single
    leading space is prepended, matching the default labels.

    Args:
        param: Parameter object with a ``terms`` mapping (may be None/falsy)

    Returns:
        Tuple of (n_with_label, n_without_label)
    """
    default_with = " with one or more adverse events"
    default_without = " with no adverse events"

    if not param or not hasattr(param, "terms") or not param.terms:
        return (default_with, default_without)

    terms = param.terms
    # `or ""` also covers keys that are present but explicitly set to None.
    before = (terms.get("before") or "").lower()
    after = (terms.get("after") or "").lower()

    with_label = f"with one or more {before} adverse events {after}"
    without_label = f"with no {before} adverse events {after}"

    # Collapse doubled spaces left by empty terms, then restore the leading indent.
    # NOTE(review): an earlier comment described this as 4-space indentation, but the
    # literals here use a single leading space; confirm the intended width.
    with_label = " " + " ".join(with_label.split())
    without_label = " " + " ".join(without_label.split())

    return (with_label, without_label)
67
+
68
+
69
def create_ae_rtf_table(
    df: pl.DataFrame,
    col_header_1: list[str],
    col_header_2: list[str] | None,
    col_widths: list[float] | None,
    title: list[str] | str,
    footnote: list[str] | str | None,
    source: list[str] | str | None,
    borders_2: bool = True,
    orientation: str = "landscape",
) -> RTFDocument:
    """
    Create a standardized RTF table document with 1 or 2 header rows.

    The first column is left-justified and all other columns are centred, both
    in the headers and in the body.

    Args:
        df: Table contents; one RTF column per DataFrame column
        col_header_1: Text for the first header row
        col_header_2: Text for an optional second header row (skipped when None/empty)
        col_widths: Relative column widths; equal widths when None
        title: Title line(s); a single string is wrapped in a list
        footnote: Optional footnote line(s); omitted when None/empty
        source: Optional source line(s); omitted when None/empty
        borders_2: When True, the second header row gets a single left border
            and an empty top border
        orientation: Page orientation passed to RTFPage

    Returns:
        RTFDocument: Assembled document, ready for ``write_rtf``
    """
    n_cols = len(df.columns)

    # Equal relative widths by default (floats, to match the declared list[float] type).
    widths: list[float] = [1.0] * n_cols if col_widths is None else col_widths

    # Label column left-aligned, every data column centred.
    justification = ["l"] + ["c"] * (n_cols - 1)

    # Normalize metadata to lists
    title_list = [title] if isinstance(title, str) else title
    footnote_list = [footnote] if isinstance(footnote, str) else (footnote or [])
    source_list = [source] if isinstance(source, str) else (source or [])

    headers = [
        RTFColumnHeader(
            text=col_header_1,
            col_rel_width=widths,
            text_justification=justification,
        )
    ]

    if col_header_2:
        h2_kwargs: dict[str, Any] = {
            "text": col_header_2,
            "col_rel_width": widths,
            "text_justification": justification,
        }
        if borders_2:
            h2_kwargs["border_left"] = ["single"]
            h2_kwargs["border_top"] = [""]

        headers.append(RTFColumnHeader(**h2_kwargs))

    rtf_components: dict[str, Any] = {
        "df": df,
        "rtf_page": RTFPage(orientation=orientation),
        "rtf_title": RTFTitle(text=title_list),
        "rtf_column_header": headers,
        "rtf_body": RTFBody(
            col_rel_width=widths,
            text_justification=justification,
            border_left=["single"] * n_cols,
        ),
    }

    # Footnote and source sections are only added when there is text for them.
    if footnote_list:
        rtf_components["rtf_footnote"] = RTFFootnote(text=footnote_list)

    if source_list:
        rtf_components["rtf_source"] = RTFSource(text=source_list)

    return RTFDocument(**rtf_components)