csrlite 0.1.0.tar.gz → 0.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. {csrlite-0.1.0 → csrlite-0.2.0}/MANIFEST.in +1 -0
  2. {csrlite-0.1.0/src/csrlite.egg-info → csrlite-0.2.0}/PKG-INFO +7 -7
  3. {csrlite-0.1.0 → csrlite-0.2.0}/pyproject.toml +7 -7
  4. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/__init__.py +16 -8
  5. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/ae/ae_listing.py +2 -0
  6. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/ae/ae_specific.py +10 -5
  7. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/ae/ae_summary.py +4 -2
  8. csrlite-0.2.0/src/csrlite/ae/ae_utils.py +62 -0
  9. csrlite-0.2.0/src/csrlite/common/config.py +34 -0
  10. csrlite-0.2.0/src/csrlite/common/count.py +293 -0
  11. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/common/plan.py +79 -67
  12. csrlite-0.1.0/src/csrlite/ae/ae_utils.py → csrlite-0.2.0/src/csrlite/common/rtf.py +16 -63
  13. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/common/utils.py +4 -4
  14. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/disposition/disposition.py +126 -95
  15. {csrlite-0.1.0 → csrlite-0.2.0/src/csrlite.egg-info}/PKG-INFO +7 -7
  16. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite.egg-info/SOURCES.txt +2 -0
  17. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite.egg-info/requires.txt +6 -6
  18. csrlite-0.1.0/src/csrlite/common/count.py +0 -199
  19. {csrlite-0.1.0 → csrlite-0.2.0}/README.md +0 -0
  20. {csrlite-0.1.0 → csrlite-0.2.0}/setup.cfg +0 -0
  21. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/ae/__init__.py +0 -0
  22. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/common/parse.py +0 -0
  23. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/common/yaml_loader.py +0 -0
  24. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/disposition/__init__.py +0 -0
  25. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite.egg-info/dependency_links.txt +0 -0
  26. {csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite.egg-info/top_level.txt +0 -0

{csrlite-0.1.0 → csrlite-0.2.0}/MANIFEST.in

@@ -18,6 +18,7 @@ prune docs
  prune data
  prune studies
  prune .github
+ prune scripts
  exclude .gitignore
  exclude .pyre_configuration
  exclude _quarto.yml

{csrlite-0.1.0/src/csrlite.egg-info → csrlite-0.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: csrlite
- Version: 0.1.0
+ Version: 0.2.0
  Summary: A hierarchical YAML-based framework for generating Tables, Listings, and Figures in clinical trials
  Author-email: Clinical Biostatistics Team <biostat@example.com>
  License: MIT
@@ -28,17 +28,17 @@ Provides-Extra: plotting
  Requires-Dist: matplotlib>=3.5.0; extra == "plotting"
  Requires-Dist: plotly>=5.0.0; extra == "plotting"
  Provides-Extra: dev
- Requires-Dist: pytest>=7.0.0; extra == "dev"
  Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
- Requires-Dist: black>=22.0.0; extra == "dev"
- Requires-Dist: isort>=5.0.0; extra == "dev"
- Requires-Dist: mypy>=1.0.0; extra == "dev"
  Requires-Dist: pytest>=9.0.1; extra == "dev"
+ Requires-Dist: black>=22.0.0; extra == "dev"
+ Requires-Dist: isort>=7.0.0; extra == "dev"
+ Requires-Dist: ruff>=0.14.8; extra == "dev"
+ Requires-Dist: mypy>=1.19.0; extra == "dev"
+ Requires-Dist: quarto>=0.1.0; extra == "dev"
+ Requires-Dist: pyre-check>=0.9.18; extra == "dev"
  Requires-Dist: jupyter>=1.1.1; extra == "dev"
  Requires-Dist: jupyter-cache>=1.0.1; extra == "dev"
  Requires-Dist: nbformat>=5.10.4; extra == "dev"
- Requires-Dist: ruff>=0.1.0; extra == "dev"
- Requires-Dist: pyre-check>=0.9.18; extra == "dev"
  Provides-Extra: all
  Requires-Dist: rtflite; extra == "all"
  Requires-Dist: matplotlib>=3.5.0; extra == "all"
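
For context only: with the reworked dev and all extras above, the released 0.2.0 distribution can be installed the usual way. This is generic pip usage, not something the diff itself adds:

    # pull in the new dev toolchain pins alongside the package
    pip install "csrlite[dev]==0.2.0"

    # runtime-only install with the optional plotting and RTF backends
    pip install "csrlite[all]==0.2.0"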

{csrlite-0.1.0 → csrlite-0.2.0}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "csrlite"
- version = "0.1.0"
+ version = "0.2.0"
  description = "A hierarchical YAML-based framework for generating Tables, Listings, and Figures in clinical trials"
  authors = [{name = "Clinical Biostatistics Team", email = "biostat@example.com"}]
  license = {text = "MIT"}
@@ -32,17 +32,17 @@ dependencies = [
  rtf = ["rtflite"]
  plotting = ["matplotlib>=3.5.0", "plotly>=5.0.0"]
  dev = [
- "pytest>=7.0.0",
  "pytest-cov>=4.0.0",
- "black>=22.0.0",
- "isort>=5.0.0",
- "mypy>=1.0.0",
  "pytest>=9.0.1",
+ "black>=22.0.0",
+ "isort>=7.0.0",
+ "ruff>=0.14.8",
+ "mypy>=1.19.0",
+ "quarto>=0.1.0",
+ "pyre-check>=0.9.18",
  "jupyter>=1.1.1",
  "jupyter-cache>=1.0.1",
  "nbformat>=5.10.4",
- "ruff>=0.1.0",
- "pyre-check>=0.9.18",
  ]
  all = ["rtflite", "matplotlib>=3.5.0", "plotly>=5.0.0"]


{csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/__init__.py

@@ -1,18 +1,19 @@
- from .ae.ae_listing import (
- # AE listing functions
+ import logging
+ import sys
+
+ from .ae.ae_listing import ( # AE listing functions
  ae_listing,
  study_plan_to_ae_listing,
  )
- from .ae.ae_specific import (
- # AE specific functions
+ from .ae.ae_specific import ( # AE specific functions
  ae_specific,
  study_plan_to_ae_specific,
  )
- from .ae.ae_summary import (
- # AE summary functions
+ from .ae.ae_summary import ( # AE summary functions
  ae_summary,
  study_plan_to_ae_summary,
  )
+ from .common.config import config
  from .common.count import (
  count_subject,
  count_subject_with_observation,
@@ -21,12 +22,19 @@ from .common.parse import (
  StudyPlanParser,
  parse_filter_to_sql,
  )
- from .common.plan import (
- # Core classes
+ from .common.plan import ( # Core classes
  load_plan,
  )
  from .disposition.disposition import study_plan_to_disposition_summary

+ # Configure logging
+ logging.basicConfig(
+ level=config.logging_level,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ stream=sys.stdout,
+ )
+ logger = logging.getLogger("csrlite")
+
  # Main exports for common usage
  __all__ = [
  # Primary user interface
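
A minimal sketch of what the new import-time logging setup implies for downstream code, assuming only the defaults shown above (logging_level "INFO" from the new config module):

    import logging
    import csrlite  # importing runs logging.basicConfig(level="INFO", stream=sys.stdout)

    # Messages from the "csrlite" logger now land on stdout at INFO and above
    logging.getLogger("csrlite").info("study plan loaded")

    # Note: logging.basicConfig is a no-op if the root logger already has handlers,
    # so applications that configure logging first keep their own setup.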

{csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/ae/ae_listing.py

@@ -71,6 +71,8 @@ def ae_listing_ard(
  parameter_filter=parameter_filter,
  )

+ assert observation_to_filter is not None
+
  # Filter observation to include only subjects in filtered population
  observation_filtered = observation_to_filter.filter(
  pl.col(id_var_name).is_in(population_filtered[id_var_name].to_list())

{csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/ae/ae_specific.py

@@ -24,8 +24,9 @@ from rtflite import RTFDocument
  from ..common.count import count_subject, count_subject_with_observation
  from ..common.parse import StudyPlanParser
  from ..common.plan import StudyPlan
+ from ..common.rtf import create_rtf_table_n_pct
  from ..common.utils import apply_common_filters
- from .ae_utils import create_ae_rtf_table, get_ae_parameter_row_labels, get_ae_parameter_title
+ from .ae_utils import get_ae_parameter_row_labels, get_ae_parameter_title


  def ae_specific_ard(
@@ -80,6 +81,8 @@ def ae_specific_ard(
  parameter_filter=parameter_filter,
  )

+ assert observation_to_filter is not None
+
  # Filter observation to include only subjects in filtered population
  observation_filtered = observation_to_filter.filter(
  pl.col(id_var_name).is_in(population_filtered[id_var_name].to_list())
@@ -114,7 +117,9 @@

  # Get population with event indicator
  pop_with_indicator = population_filtered.with_columns(
- pl.col(id_var_name).is_in(subjects_with_events[id_var_name]).alias("__has_event__")
+ pl.col(id_var_name)
+ .is_in(subjects_with_events[id_var_name].to_list())
+ .alias("__has_event__")
  )

  # Count subjects with and without events using count_subject_with_observation
@@ -129,7 +134,7 @@
  )

  # Extract 'with' counts
- n_with = event_counts.filter(pl.col("__has_event__")).select(
+ n_with = event_counts.filter(pl.col("__has_event__") == "true").select(
  [
  pl.lit(n_with_label).alias("__index__"),
  pl.col(group_var_name).cast(pl.String).alias("__group__"),
@@ -138,7 +143,7 @@
  )

  # Extract 'without' counts
- n_without = event_counts.filter(~pl.col("__has_event__")).select(
+ n_without = event_counts.filter(pl.col("__has_event__") == "false").select(
  [
  pl.lit(n_without_label).alias("__index__"),
  pl.col(group_var_name).cast(pl.String).alias("__group__"),
@@ -254,7 +259,7 @@ def ae_specific_rtf(
  else:
  col_widths = col_rel_width

- return create_ae_rtf_table(
+ return create_rtf_table_n_pct(
  df=df_rtf,
  col_header_1=col_header_1,
  col_header_2=col_header_2,
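
The switch from boolean filtering to comparing against "true"/"false" follows from the new count_summary_data casting every analysis variable to pl.String (see the new count.py below), so the __has_event__ indicator comes back as strings. A small illustration of the Polars behavior being relied on here; the DataFrame is made up:

    import polars as pl

    df = pl.DataFrame({"__has_event__": [True, False]})
    print(df.with_columns(pl.col("__has_event__").cast(pl.String)))
    # Boolean values become the strings "true" and "false" after the cast,
    # which is why the ARD code now filters on the string literals instead of a boolean column.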

{csrlite-0.1.0 → csrlite-0.2.0}/src/csrlite/ae/ae_summary.py

@@ -21,8 +21,8 @@ from rtflite import RTFDocument
  from ..common.count import count_subject, count_subject_with_observation
  from ..common.parse import StudyPlanParser
  from ..common.plan import StudyPlan
+ from ..common.rtf import create_rtf_table_n_pct
  from ..common.utils import apply_common_filters
- from .ae_utils import create_ae_rtf_table


  def study_plan_to_ae_summary(
@@ -258,6 +258,8 @@ def ae_summary_ard(
  observation_filter=observation_filter,
  )

+ assert observation_to_filter is not None
+
  # Filter observation data to include only subjects in the filtered population
  # Process all variables in the list
  observation_filtered_list = []
@@ -388,7 +390,7 @@ def ae_summary_rtf(
  else:
  col_widths = col_rel_width

- return create_ae_rtf_table(
+ return create_rtf_table_n_pct(
  df=df_rtf,
  col_header_1=col_header_1,
  col_header_2=col_header_2,

csrlite-0.2.0/src/csrlite/ae/ae_utils.py

@@ -0,0 +1,62 @@
+ from typing import Any
+
+
+ def get_ae_parameter_title(param: Any, prefix: str = "Participants With") -> str:
+ """
+ Extract title from parameter for ae_* title generation.
+
+ Args:
+ param: Parameter object with terms attribute
+ prefix: Prefix for the title (e.g. "Participants With", "Listing of Participants With")
+
+ Returns:
+ Title string for the analysis
+ """
+ default_suffix = "Adverse Events"
+
+ if not param:
+ return f"{prefix} {default_suffix}"
+
+ # Check for terms attribute
+ if hasattr(param, "terms") and param.terms and isinstance(param.terms, dict):
+ terms = param.terms
+
+ # Preprocess to empty strings (avoiding None)
+ before = terms.get("before", "").title()
+ after = terms.get("after", "").title()
+
+ # Build title and clean up extra spaces
+ title = f"{prefix} {before} {default_suffix} {after}"
+ return " ".join(title.split()) # Remove extra spaces
+
+ # Fallback to default
+ return f"{prefix} {default_suffix}"
+
+
+ def get_ae_parameter_row_labels(param: Any) -> tuple[str, str]:
+ """
+ Generate n_with and n_without row labels based on parameter terms.
+
+ Returns:
+ Tuple of (n_with_label, n_without_label)
+ """
+ # Default labels
+ default_with = " with one or more adverse events"
+ default_without = " with no adverse events"
+
+ if not param or not hasattr(param, "terms") or not param.terms:
+ return (default_with, default_without)
+
+ terms = param.terms
+ before = terms.get("before", "").lower()
+ after = terms.get("after", "").lower()
+
+ # Build dynamic labels with leading indentation
+ with_label = f"with one or more {before} adverse events {after}"
+ without_label = f"with no {before} adverse events {after}"
+
+ # Clean up extra spaces and add back the 4-space indentation
+ with_label = " " + " ".join(with_label.split())
+ without_label = " " + " ".join(without_label.split())
+
+ return (with_label, without_label)
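
A quick sketch of how these helpers behave. The SimpleNamespace object is a stand-in for whatever parameter object the study plan parser produces, and the terms values are made up for illustration:

    from types import SimpleNamespace

    from csrlite.ae.ae_utils import get_ae_parameter_row_labels, get_ae_parameter_title

    param = SimpleNamespace(terms={"before": "serious", "after": "related to study drug"})

    get_ae_parameter_title(param)
    # 'Participants With Serious Adverse Events Related To Study Drug'

    get_ae_parameter_row_labels(param)
    # returns the indented row-label pair, e.g.
    # "with one or more serious adverse events related to study drug" and
    # "with no serious adverse events related to study drug"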

csrlite-0.2.0/src/csrlite/common/config.py

@@ -0,0 +1,34 @@
+ # pyre-strict
+ """
+ Central configuration for csrlite.
+ """
+
+ from typing import Literal, Optional
+
+ from pydantic import BaseModel, ConfigDict, Field
+
+
+ class CsrLiteConfig(BaseModel):
+ """
+ Global configuration for csrlite library.
+ """
+
+ # Column Name Defaults
+ id_col: str = Field(default="USUBJID", description="Subject Identifier Column")
+ group_col: Optional[str] = Field(default=None, description="Treatment Group Column")
+
+ # Missing Value Handling
+ missing_str: str = Field(
+ default="__missing__", description="String to represent missing string values"
+ )
+
+ # Logging
+ logging_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(
+ default="INFO", description="Default logging level"
+ )
+
+ model_config = ConfigDict(validate_assignment=True)
+
+
+ # Global configuration instance
+ config = CsrLiteConfig()
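
A short sketch of how the new global config behaves, assuming nothing beyond what the model above declares:

    from csrlite.common.config import config  # also re-exported as csrlite.config

    config.id_col              # 'USUBJID' by default
    config.id_col = "SUBJID"   # validate_assignment=True re-validates on assignment

    # Invalid values are rejected because logging_level is a Literal field
    try:
        config.logging_level = "VERBOSE"
    except Exception as err:   # pydantic.ValidationError
        print(err)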

csrlite-0.2.0/src/csrlite/common/count.py

@@ -0,0 +1,293 @@
+ # pyre-strict
+ import polars as pl
+
+ from .config import config
+
+
+ def _to_pop(
+ population: pl.DataFrame,
+ id: str,
+ group: str,
+ total: bool = True,
+ missing_group: str = "error",
+ ) -> pl.DataFrame:
+ # prepare data
+ pop = population.select(id, group)
+
+ # validate data
+ if pop[id].is_duplicated().any():
+ raise ValueError(f"The '{id}' column in the population DataFrame is not unique.")
+
+ if missing_group == "error" and pop[group].is_null().any():
+ raise ValueError(
+ f"Missing values found in the '{group}' column of the population DataFrame, "
+ "and 'missing_group' is set to 'error'."
+ )
+
+ # Convert group to Enum for consistent categorical ordering
+ u_pop = pop[group].unique().sort().to_list()
+
+ # handle total column
+ if total:
+ pop_total = pop.with_columns(pl.lit("Total").alias(group))
+ pop = pl.concat([pop, pop_total]).with_columns(
+ pl.col(group).cast(pl.Enum(u_pop + ["Total"]))
+ )
+ else:
+ pop = pop.with_columns(pl.col(group).cast(pl.Enum(u_pop)))
+
+ return pop
+
+
+ def count_subject(
+ population: pl.DataFrame,
+ id: str,
+ group: str,
+ total: bool = True,
+ missing_group: str = "error",
+ ) -> pl.DataFrame:
+ """
+ Counts subjects by group and optionally includes a 'Total' column.
+
+ Args:
+ population (pl.DataFrame): DataFrame containing subject population data.
+ id (str): The name of the subject ID column.
+ group (str): The name of the treatment group column.
+ total (bool, optional): If True, adds a 'Total' group. Defaults to True.
+ missing_group (str, optional): How to handle missing values ("error", "ignore").
+
+ Returns:
+ pl.DataFrame: A DataFrame with subject counts ('n_subj_pop') for each group.
+ """
+
+ pop = _to_pop(
+ population=population,
+ id=id,
+ group=group,
+ total=total,
+ missing_group=missing_group,
+ )
+
+ return pop.group_by(group).agg(pl.len().alias("n_subj_pop")).sort(group)
+
+
+ def count_summary_data(
+ population: pl.DataFrame,
+ observation: pl.DataFrame,
+ id: str,
+ group: str,
+ variable: str | list[str],
+ total: bool = True,
+ missing_group: str = "error",
+ ) -> pl.DataFrame:
+ """
+ Generates numeric summary data (counts and percentages) for observations.
+ Does NOT perform string formatting.
+
+ Returns:
+ pl.DataFrame: DataFrame with columns:
+ - [group]: Group column
+ - [variable]: Variable columns
+ - n_obs: Count of observations
+ - n_subj: Count of unique subjects with observation
+ - n_subj_pop: Total subjects in group
+ - pct_subj: Percentage of subjects (0-100)
+ """
+ # Normalize variable to list
+ if isinstance(variable, str):
+ variables = [variable]
+ else:
+ variables = variable
+
+ # prepare data
+ pop = _to_pop(
+ population=population,
+ id=id,
+ group=group,
+ total=total,
+ missing_group=missing_group,
+ )
+
+ # Select all required columns (id + all variables)
+ obs = observation.select(id, *variables).join(pop, on=id, how="left")
+
+ for var in variables:
+ obs = obs.with_columns(pl.col(var).cast(pl.String).fill_null(config.missing_str))
+
+ # Check for IDs in observation that are not in population
+ if not obs[id].is_in(pop[id].to_list()).all():
+ missing_ids = (
+ obs.filter(~pl.col(id).is_in(pop[id].to_list()))
+ .select(id)
+ .unique()
+ .to_series()
+ .to_list()
+ )
+ raise ValueError(
+ f"Some '{id}' values in the observation DataFrame are not present in the population: "
+ f"{missing_ids}"
+ )
+
+ df_pop = count_subject(
+ population=population,
+ id=id,
+ group=group,
+ total=total,
+ missing_group=missing_group,
+ )
+
+ all_levels_df = []
+
+ # Iterate through hierarchies
+ for i in range(1, len(variables) + 1):
+ current_vars = variables[:i]
+
+ # Aggregation
+ df_obs_counts = obs.group_by(group, *current_vars).agg(
+ pl.len().alias("n_obs"), pl.n_unique(id).alias("n_subj")
+ )
+
+ # Cross join for all combinations
+ unique_groups = df_pop.select(group)
+ unique_variables = obs.select(current_vars).unique()
+ all_combinations = unique_groups.join(unique_variables, how="cross")
+
+ # Join back
+ df_level = (
+ all_combinations.join(df_obs_counts, on=[group, *current_vars], how="left")
+ .join(df_pop, on=group, how="left")
+ .with_columns([pl.col("n_obs").fill_null(0), pl.col("n_subj").fill_null(0)])
+ )
+
+ df_level = df_level.with_columns([pl.col(c).cast(pl.String) for c in current_vars])
+
+ # Add missing columns with "__all__"
+ for var in variables:
+ if var not in df_level.columns:
+ df_level = df_level.with_columns(pl.lit("__all__").cast(pl.String).alias(var))
+
+ all_levels_df.append(df_level)
+
+ # Stack
+ df_obs = pl.concat(all_levels_df, how="diagonal")
+
+ # Calculate percentage
+ df_obs = df_obs.with_columns(pct_subj=(pl.col("n_subj") / pl.col("n_subj_pop") * 100))
+
+ return df_obs
+
+
+ def format_summary_table(
+ df: pl.DataFrame,
+ group: str,
+ variable: str | list[str],
+ pct_digit: int = 1,
+ max_n_width: int | None = None,
+ ) -> pl.DataFrame:
+ """
+ Formats numeric summary data into display strings (e.g., "n ( pct)").
+ Adds indentation and sorting.
+ """
+ if isinstance(variable, str):
+ variables = [variable]
+ else:
+ variables = variable
+
+ df_fmt = df.with_columns(
+ pct_subj_fmt=(
+ pl.when(pl.col("pct_subj").is_null() | pl.col("pct_subj").is_nan())
+ .then(0.0)
+ .otherwise(pl.col("pct_subj"))
+ .round(pct_digit, mode="half_away_from_zero")
+ .cast(pl.String)
+ )
+ )
+
+ if max_n_width is None:
+ max_n_width = df_fmt.select(pl.col("n_subj").cast(pl.String).str.len_chars().max()).item()
+
+ max_pct_width = 3 if pct_digit == 0 else 4 + pct_digit
+
+ df_fmt = df_fmt.with_columns(
+ [
+ pl.col("pct_subj_fmt").str.pad_start(max_pct_width, " "),
+ pl.col("n_subj").cast(pl.String).str.pad_start(max_n_width, " ").alias("n_subj_fmt"),
+ ]
+ ).with_columns(
+ n_pct_subj_fmt=pl.concat_str(
+ [pl.col("n_subj_fmt"), pl.lit(" ("), pl.col("pct_subj_fmt"), pl.lit(")")]
+ )
+ )
+
+ # Sorting Logic
+ sort_exprs = [pl.col(group)]
+ for var in variables:
+ # 0 for __all__, 1 for values, 2 for config.missing_str
+ sort_key_col = f"__sort_key_{var}__"
+ df_fmt = df_fmt.with_columns(
+ pl.when(pl.col(var) == "__all__")
+ .then(0)
+ .when(pl.col(var) == config.missing_str)
+ .then(2)
+ .otherwise(1)
+ .alias(sort_key_col)
+ )
+ sort_exprs.append(pl.col(sort_key_col))
+ sort_exprs.append(pl.col(var))
+
+ df_fmt = df_fmt.sort(sort_exprs).select(pl.exclude(r"^__sort_key_.*$"))
+
+ # Indentation logic
+ if len(variables) > 0:
+ var_expr = (
+ pl.when(pl.col(variables[0]) == config.missing_str)
+ .then(pl.lit("Missing"))
+ .otherwise(pl.col(variables[0]))
+ )
+
+ for i in range(1, len(variables)):
+ var_expr = (
+ pl.when(pl.col(variables[i]) == "__all__")
+ .then(var_expr)
+ .when(pl.col(variables[i]) == config.missing_str)
+ .then(pl.lit(" " * 4 * i) + pl.lit("Missing"))
+ .otherwise(pl.lit(" " * 4 * i) + pl.col(variables[i]))
+ )
+ df_fmt = df_fmt.with_columns(var_expr.alias("__variable__"))
+
+ df_fmt = df_fmt.with_row_index(name="__id__", offset=1)
+ return df_fmt
+
+
+ def count_subject_with_observation(
+ population: pl.DataFrame,
+ observation: pl.DataFrame,
+ id: str,
+ group: str,
+ variable: str | list[str],
+ total: bool = True,
+ missing_group: str = "error",
+ pct_digit: int = 1,
+ max_n_width: int | None = None,
+ ) -> pl.DataFrame:
+ """
+ Legacy wrapper for backward compatibility (mostly for tests that rely on the old signature),
+ but now strictly composing the new functions.
+ """
+ df_raw = count_summary_data(
+ population=population,
+ observation=observation,
+ id=id,
+ group=group,
+ variable=variable,
+ total=total,
+ missing_group=missing_group,
+ )
+
+ return format_summary_table(
+ df=df_raw,
+ group=group,
+ variable=variable,
+ pct_digit=pct_digit,
+ max_n_width=max_n_width,
+ )
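
To make the new counting API concrete, a minimal sketch with made-up data. The USUBJID/TRT01P column names follow common ADaM conventions but are only illustrative; any column names work:

    import polars as pl

    from csrlite.common.count import count_subject, count_subject_with_observation

    adsl = pl.DataFrame({
        "USUBJID": ["01", "02", "03", "04"],
        "TRT01P": ["Drug", "Drug", "Placebo", "Placebo"],
    })
    adae = pl.DataFrame({
        "USUBJID": ["01", "01", "03"],
        "AEDECOD": ["Headache", "Nausea", "Headache"],
    })

    count_subject(adsl, id="USUBJID", group="TRT01P")
    # TRT01P    n_subj_pop
    # Drug      2
    # Placebo   2
    # Total     4

    count_subject_with_observation(
        adsl, adae, id="USUBJID", group="TRT01P", variable="AEDECOD"
    )
    # one row per (group, AEDECOD) combination, carrying n_obs, n_subj, n_subj_pop,
    # pct_subj, and the formatted "n (pct)" string in n_pct_subj_fmt; the wrapper simply
    # chains count_summary_data (numbers) into format_summary_table (display strings)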