csrlite 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- csrlite/__init__.py +110 -71
- csrlite/ae/__init__.py +1 -1
- csrlite/ae/ae_listing.py +494 -494
- csrlite/ae/ae_specific.py +483 -483
- csrlite/ae/ae_summary.py +401 -401
- csrlite/ae/ae_utils.py +62 -62
- csrlite/cm/cm_listing.py +497 -0
- csrlite/cm/cm_summary.py +327 -0
- csrlite/common/config.py +34 -34
- csrlite/common/count.py +293 -293
- csrlite/common/parse.py +308 -308
- csrlite/common/plan.py +365 -365
- csrlite/common/rtf.py +137 -137
- csrlite/common/utils.py +33 -33
- csrlite/common/yaml_loader.py +71 -71
- csrlite/disposition/__init__.py +2 -2
- csrlite/disposition/disposition.py +332 -332
- csrlite/ie/ie_listing.py +109 -0
- csrlite/ie/{ie.py → ie_summary.py} +292 -405
- csrlite/mh/mh_listing.py +209 -0
- csrlite/mh/mh_summary.py +333 -0
- csrlite/pd/pd_listing.py +461 -0
- {csrlite-0.2.1.dist-info → csrlite-0.3.0.dist-info}/METADATA +68 -68
- csrlite-0.3.0.dist-info/RECORD +26 -0
- csrlite-0.2.1.dist-info/RECORD +0 -20
- {csrlite-0.2.1.dist-info → csrlite-0.3.0.dist-info}/WHEEL +0 -0
- {csrlite-0.2.1.dist-info → csrlite-0.3.0.dist-info}/top_level.txt +0 -0
csrlite/common/count.py
CHANGED
|
@@ -1,293 +1,293 @@
|
|
|
1
|
-
# pyre-strict
|
|
2
|
-
import polars as pl
|
|
3
|
-
|
|
4
|
-
from .config import config
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def _to_pop(
    population: pl.DataFrame,
    id: str,
    group: str,
    total: bool = True,
    missing_group: str = "error",
) -> pl.DataFrame:
    """
    Build the subject-by-group frame that the counting functions work from.

    Args:
        population (pl.DataFrame): DataFrame containing subject population data.
        id (str): The name of the subject ID column; values must be unique.
        group (str): The name of the treatment group column.
        total (bool, optional): If True, every subject is repeated once more
            under an extra "Total" group. Defaults to True.
        missing_group (str, optional): With "error", a null in the group column
            raises. Any other value lets null groups pass through unchanged.

    Returns:
        pl.DataFrame: The `id` and `group` columns, with `group` cast to an Enum
        whose categories are the sorted observed levels (followed by "Total"
        when `total` is True).

    Raises:
        ValueError: If `id` contains duplicates, or if `group` contains nulls
            while `missing_group` is "error".
    """
    pop = population.select(id, group)

    # validate data
    if pop[id].is_duplicated().any():
        raise ValueError(f"The '{id}' column in the population DataFrame is not unique.")

    if missing_group == "error" and pop[group].is_null().any():
        raise ValueError(
            f"Missing values found in the '{group}' column of the population DataFrame, "
            "and 'missing_group' is set to 'error'."
        )

    # Enum categories cannot contain nulls, so nulls that were allowed through
    # above are left out of the level list; those rows simply stay null after
    # the cast instead of making the Enum construction fail.
    levels = pop[group].drop_nulls().unique().sort().to_list()

    if total:
        # "Total" is appended last so it sorts after every real group.
        levels = levels + ["Total"]
        pop = pl.concat([pop, pop.with_columns(pl.lit("Total").alias(group))])

    return pop.with_columns(pl.col(group).cast(pl.Enum(levels)))
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def count_subject(
    population: pl.DataFrame,
    id: str,
    group: str,
    total: bool = True,
    missing_group: str = "error",
) -> pl.DataFrame:
    """
    Tally the number of subjects in each group.

    Args:
        population (pl.DataFrame): DataFrame containing subject population data.
        id (str): The name of the subject ID column.
        group (str): The name of the treatment group column.
        total (bool, optional): If True, a 'Total' group is counted as well.
            Defaults to True.
        missing_group (str, optional): How to handle missing values ("error", "ignore").

    Returns:
        pl.DataFrame: One row per group with its subject count in 'n_subj_pop',
        sorted by the group column.
    """
    # Validation and the optional "Total" rows are handled by _to_pop.
    prepared = _to_pop(
        population=population,
        id=id,
        group=group,
        total=total,
        missing_group=missing_group,
    )

    per_group = prepared.group_by(group).agg(n_subj_pop=pl.len())
    return per_group.sort(group)
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
def count_summary_data(
    population: pl.DataFrame,
    observation: pl.DataFrame,
    id: str,
    group: str,
    variable: str | list[str],
    total: bool = True,
    missing_group: str = "error",
) -> pl.DataFrame:
    """
    Generates numeric summary data (counts and percentages) for observations.
    Does NOT perform string formatting.

    When several variables are given they are treated as a hierarchy: one set of
    rows is produced per prefix of the variable list, and the variables below
    the current level are filled with the placeholder "__all__".

    Args:
        population (pl.DataFrame): Subject-level population data.
        observation (pl.DataFrame): Observation records keyed by `id`.
        id (str): The name of the subject ID column in both frames.
        group (str): The name of the treatment group column in `population`.
        variable (str | list[str]): Observation column(s) to summarize.
        total (bool, optional): If True, a 'Total' group is included.
        missing_group (str, optional): How to handle missing values ("error", "ignore").

    Returns:
        pl.DataFrame: DataFrame with columns:
            - [group]: Group column
            - [variable]: Variable columns
            - n_obs: Count of observations
            - n_subj: Count of unique subjects with observation
            - n_subj_pop: Total subjects in group
            - pct_subj: Percentage of subjects (0-100)

    Raises:
        ValueError: If no variable is given, or if `observation` contains IDs
            that are absent from `population`.
    """
    # Normalize variable to list
    variables = [variable] if isinstance(variable, str) else list(variable)
    if not variables:
        # Without this guard the failure only surfaces later as an opaque
        # "cannot concat empty list" error.
        raise ValueError("At least one variable is required to summarize observations.")

    pop = _to_pop(
        population=population,
        id=id,
        group=group,
        total=total,
        missing_group=missing_group,
    )

    # Select all required columns (id + all variables)
    obs = observation.select(id, *variables)

    # Check for IDs in observation that are not in population. The ID list is
    # materialized once and the check runs before the join, so bad input is
    # rejected without doing the join and casts first.
    known_ids = pop[id].unique().to_list()
    unknown = obs.filter(~pl.col(id).is_in(known_ids))
    if unknown.height > 0:
        missing_ids = unknown.select(id).unique().to_series().to_list()
        raise ValueError(
            f"Some '{id}' values in the observation DataFrame are not present in the population: "
            f"{missing_ids}"
        )

    # With total=True each subject appears twice in `pop`, so the join also
    # produces the observation rows that belong to the "Total" group.
    obs = obs.join(pop, on=id, how="left").with_columns(
        [pl.col(var).cast(pl.String).fill_null(config.missing_str) for var in variables]
    )

    df_pop = count_subject(
        population=population,
        id=id,
        group=group,
        total=total,
        missing_group=missing_group,
    )
    unique_groups = df_pop.select(group)

    all_levels_df = []

    # Iterate through hierarchies
    for depth in range(1, len(variables) + 1):
        current_vars = variables[:depth]

        # Aggregation
        df_obs_counts = obs.group_by(group, *current_vars).agg(
            pl.len().alias("n_obs"), pl.n_unique(id).alias("n_subj")
        )

        # Cross join so every group gets a row for every observed value
        # combination, including combinations with no observations.
        unique_variables = obs.select(current_vars).unique()
        all_combinations = unique_groups.join(unique_variables, how="cross")

        df_level = (
            all_combinations.join(df_obs_counts, on=[group, *current_vars], how="left")
            .join(df_pop, on=group, how="left")
            .with_columns([pl.col("n_obs").fill_null(0), pl.col("n_subj").fill_null(0)])
        )

        # Variables below the current level are marked with "__all__"
        for var in variables[depth:]:
            df_level = df_level.with_columns(pl.lit("__all__").cast(pl.String).alias(var))

        all_levels_df.append(df_level)

    # Stack; column order differs between levels, hence the diagonal concat
    df_obs = pl.concat(all_levels_df, how="diagonal")

    # Calculate percentage
    df_obs = df_obs.with_columns(pct_subj=(pl.col("n_subj") / pl.col("n_subj_pop") * 100))

    return df_obs
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
def format_summary_table(
    df: pl.DataFrame,
    group: str,
    variable: str | list[str],
    pct_digit: int = 1,
    max_n_width: int | None = None,
) -> pl.DataFrame:
    """
    Formats numeric summary data into display strings of the form "n (pct)".
    Adds indentation and sorting.

    Args:
        df (pl.DataFrame): Numeric summary containing `group`, the variable
            column(s), 'n_subj' and 'pct_subj'.
        group (str): The name of the group column.
        variable (str | list[str]): Variable column(s), outermost level first.
        pct_digit (int, optional): Decimal places used when rounding the
            percentage. Defaults to 1.
        max_n_width (int | None, optional): Width the subject count is
            left-padded to. If None, the widest count in `df` is used.

    Returns:
        pl.DataFrame: `df` sorted by group and variable hierarchy, with the
        added columns 'pct_subj_fmt', 'n_subj_fmt', 'n_pct_subj_fmt',
        '__variable__' (indented row label, only when a variable is given)
        and '__id__' (1-based row index).
    """
    variables = [variable] if isinstance(variable, str) else list(variable)

    # Null/NaN percentages are shown as zero.
    pct_expr = (
        pl.when(pl.col("pct_subj").is_null() | pl.col("pct_subj").is_nan())
        .then(0.0)
        .otherwise(pl.col("pct_subj"))
        .round(pct_digit, mode="half_away_from_zero")
    )
    if pct_digit == 0:
        # A float rounded to 0 digits still renders with a trailing ".0",
        # which does not fit the 3-character width used below; render it as
        # an integer instead.
        pct_expr = pct_expr.cast(pl.Int64)

    df_fmt = df.with_columns(pct_subj_fmt=pct_expr.cast(pl.String))

    if max_n_width is None:
        widest = df_fmt.select(pl.col("n_subj").cast(pl.String).str.len_chars().max()).item()
        # max() yields None on an empty frame; there is nothing to pad then.
        max_n_width = widest if widest is not None else 0

    # Wide enough for "100" plus, when decimals are shown, "." and the digits.
    max_pct_width = 3 if pct_digit == 0 else 4 + pct_digit

    df_fmt = df_fmt.with_columns(
        [
            pl.col("pct_subj_fmt").str.pad_start(max_pct_width, " "),
            pl.col("n_subj").cast(pl.String).str.pad_start(max_n_width, " ").alias("n_subj_fmt"),
        ]
    ).with_columns(
        n_pct_subj_fmt=pl.concat_str(
            [pl.col("n_subj_fmt"), pl.lit(" ("), pl.col("pct_subj_fmt"), pl.lit(")")]
        )
    )

    # Sorting Logic
    sort_exprs = [pl.col(group)]
    for var in variables:
        # 0 for __all__, 1 for values, 2 for config.missing_str
        sort_key_col = f"__sort_key_{var}__"
        df_fmt = df_fmt.with_columns(
            pl.when(pl.col(var) == "__all__")
            .then(0)
            .when(pl.col(var) == config.missing_str)
            .then(2)
            .otherwise(1)
            .alias(sort_key_col)
        )
        sort_exprs.append(pl.col(sort_key_col))
        sort_exprs.append(pl.col(var))

    # The helper sort keys are dropped again once the rows are ordered.
    df_fmt = df_fmt.sort(sort_exprs).select(pl.exclude(r"^__sort_key_.*$"))

    # Indentation logic
    if variables:
        var_expr = (
            pl.when(pl.col(variables[0]) == config.missing_str)
            .then(pl.lit("Missing"))
            .otherwise(pl.col(variables[0]))
        )

        # Each deeper level falls back to the label built so far when it holds
        # "__all__"; otherwise it shows its own value behind a level-dependent indent.
        for i in range(1, len(variables)):
            var_expr = (
                pl.when(pl.col(variables[i]) == "__all__")
                .then(var_expr)
                .when(pl.col(variables[i]) == config.missing_str)
                .then(pl.lit(" " * 4 * i) + pl.lit("Missing"))
                .otherwise(pl.lit(" " * 4 * i) + pl.col(variables[i]))
            )
        df_fmt = df_fmt.with_columns(var_expr.alias("__variable__"))

    df_fmt = df_fmt.with_row_index(name="__id__", offset=1)
    return df_fmt
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
def count_subject_with_observation(
    population: pl.DataFrame,
    observation: pl.DataFrame,
    id: str,
    group: str,
    variable: str | list[str],
    total: bool = True,
    missing_group: str = "error",
    pct_digit: int = 1,
    max_n_width: int | None = None,
) -> pl.DataFrame:
    """
    Backward-compatible entry point that keeps the old one-call signature.

    It only chains the two newer functions: `count_summary_data` produces the
    numeric summary, and `format_summary_table` turns it into display strings.
    """
    counting_args = {
        "population": population,
        "observation": observation,
        "id": id,
        "group": group,
        "variable": variable,
        "total": total,
        "missing_group": missing_group,
    }
    numeric_summary = count_summary_data(**counting_args)

    formatted = format_summary_table(
        df=numeric_summary,
        group=group,
        variable=variable,
        pct_digit=pct_digit,
        max_n_width=max_n_width,
    )
    return formatted
|
|
1
|
+
# pyre-strict
|
|
2
|
+
import polars as pl
|
|
3
|
+
|
|
4
|
+
from .config import config
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _to_pop(
    population: pl.DataFrame,
    id: str,
    group: str,
    total: bool = True,
    missing_group: str = "error",
) -> pl.DataFrame:
    """
    Build the subject-by-group frame that the counting functions work from.

    Args:
        population (pl.DataFrame): DataFrame containing subject population data.
        id (str): The name of the subject ID column; values must be unique.
        group (str): The name of the treatment group column.
        total (bool, optional): If True, every subject is repeated once more
            under an extra "Total" group. Defaults to True.
        missing_group (str, optional): With "error", a null in the group column
            raises. Any other value lets null groups pass through unchanged.

    Returns:
        pl.DataFrame: The `id` and `group` columns, with `group` cast to an Enum
        whose categories are the sorted observed levels (followed by "Total"
        when `total` is True).

    Raises:
        ValueError: If `id` contains duplicates, or if `group` contains nulls
            while `missing_group` is "error".
    """
    pop = population.select(id, group)

    # validate data
    if pop[id].is_duplicated().any():
        raise ValueError(f"The '{id}' column in the population DataFrame is not unique.")

    if missing_group == "error" and pop[group].is_null().any():
        raise ValueError(
            f"Missing values found in the '{group}' column of the population DataFrame, "
            "and 'missing_group' is set to 'error'."
        )

    # Enum categories cannot contain nulls, so nulls that were allowed through
    # above are left out of the level list; those rows simply stay null after
    # the cast instead of making the Enum construction fail.
    levels = pop[group].drop_nulls().unique().sort().to_list()

    if total:
        # "Total" is appended last so it sorts after every real group.
        levels = levels + ["Total"]
        pop = pl.concat([pop, pop.with_columns(pl.lit("Total").alias(group))])

    return pop.with_columns(pl.col(group).cast(pl.Enum(levels)))
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def count_subject(
    population: pl.DataFrame,
    id: str,
    group: str,
    total: bool = True,
    missing_group: str = "error",
) -> pl.DataFrame:
    """
    Tally the number of subjects in each group.

    Args:
        population (pl.DataFrame): DataFrame containing subject population data.
        id (str): The name of the subject ID column.
        group (str): The name of the treatment group column.
        total (bool, optional): If True, a 'Total' group is counted as well.
            Defaults to True.
        missing_group (str, optional): How to handle missing values ("error", "ignore").

    Returns:
        pl.DataFrame: One row per group with its subject count in 'n_subj_pop',
        sorted by the group column.
    """
    # Validation and the optional "Total" rows are handled by _to_pop.
    prepared = _to_pop(
        population=population,
        id=id,
        group=group,
        total=total,
        missing_group=missing_group,
    )

    per_group = prepared.group_by(group).agg(n_subj_pop=pl.len())
    return per_group.sort(group)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def count_summary_data(
    population: pl.DataFrame,
    observation: pl.DataFrame,
    id: str,
    group: str,
    variable: str | list[str],
    total: bool = True,
    missing_group: str = "error",
) -> pl.DataFrame:
    """
    Generates numeric summary data (counts and percentages) for observations.
    Does NOT perform string formatting.

    When several variables are given they are treated as a hierarchy: one set of
    rows is produced per prefix of the variable list, and the variables below
    the current level are filled with the placeholder "__all__".

    Args:
        population (pl.DataFrame): Subject-level population data.
        observation (pl.DataFrame): Observation records keyed by `id`.
        id (str): The name of the subject ID column in both frames.
        group (str): The name of the treatment group column in `population`.
        variable (str | list[str]): Observation column(s) to summarize.
        total (bool, optional): If True, a 'Total' group is included.
        missing_group (str, optional): How to handle missing values ("error", "ignore").

    Returns:
        pl.DataFrame: DataFrame with columns:
            - [group]: Group column
            - [variable]: Variable columns
            - n_obs: Count of observations
            - n_subj: Count of unique subjects with observation
            - n_subj_pop: Total subjects in group
            - pct_subj: Percentage of subjects (0-100)

    Raises:
        ValueError: If no variable is given, or if `observation` contains IDs
            that are absent from `population`.
    """
    # Normalize variable to list
    variables = [variable] if isinstance(variable, str) else list(variable)
    if not variables:
        # Without this guard the failure only surfaces later as an opaque
        # "cannot concat empty list" error.
        raise ValueError("At least one variable is required to summarize observations.")

    pop = _to_pop(
        population=population,
        id=id,
        group=group,
        total=total,
        missing_group=missing_group,
    )

    # Select all required columns (id + all variables)
    obs = observation.select(id, *variables)

    # Check for IDs in observation that are not in population. The ID list is
    # materialized once and the check runs before the join, so bad input is
    # rejected without doing the join and casts first.
    known_ids = pop[id].unique().to_list()
    unknown = obs.filter(~pl.col(id).is_in(known_ids))
    if unknown.height > 0:
        missing_ids = unknown.select(id).unique().to_series().to_list()
        raise ValueError(
            f"Some '{id}' values in the observation DataFrame are not present in the population: "
            f"{missing_ids}"
        )

    # With total=True each subject appears twice in `pop`, so the join also
    # produces the observation rows that belong to the "Total" group.
    obs = obs.join(pop, on=id, how="left").with_columns(
        [pl.col(var).cast(pl.String).fill_null(config.missing_str) for var in variables]
    )

    df_pop = count_subject(
        population=population,
        id=id,
        group=group,
        total=total,
        missing_group=missing_group,
    )
    unique_groups = df_pop.select(group)

    all_levels_df = []

    # Iterate through hierarchies
    for depth in range(1, len(variables) + 1):
        current_vars = variables[:depth]

        # Aggregation
        df_obs_counts = obs.group_by(group, *current_vars).agg(
            pl.len().alias("n_obs"), pl.n_unique(id).alias("n_subj")
        )

        # Cross join so every group gets a row for every observed value
        # combination, including combinations with no observations.
        unique_variables = obs.select(current_vars).unique()
        all_combinations = unique_groups.join(unique_variables, how="cross")

        df_level = (
            all_combinations.join(df_obs_counts, on=[group, *current_vars], how="left")
            .join(df_pop, on=group, how="left")
            .with_columns([pl.col("n_obs").fill_null(0), pl.col("n_subj").fill_null(0)])
        )

        # Variables below the current level are marked with "__all__"
        for var in variables[depth:]:
            df_level = df_level.with_columns(pl.lit("__all__").cast(pl.String).alias(var))

        all_levels_df.append(df_level)

    # Stack; column order differs between levels, hence the diagonal concat
    df_obs = pl.concat(all_levels_df, how="diagonal")

    # Calculate percentage
    df_obs = df_obs.with_columns(pct_subj=(pl.col("n_subj") / pl.col("n_subj_pop") * 100))

    return df_obs
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def format_summary_table(
    df: pl.DataFrame,
    group: str,
    variable: str | list[str],
    pct_digit: int = 1,
    max_n_width: int | None = None,
) -> pl.DataFrame:
    """
    Formats numeric summary data into display strings of the form "n (pct)".
    Adds indentation and sorting.

    Args:
        df (pl.DataFrame): Numeric summary containing `group`, the variable
            column(s), 'n_subj' and 'pct_subj'.
        group (str): The name of the group column.
        variable (str | list[str]): Variable column(s), outermost level first.
        pct_digit (int, optional): Decimal places used when rounding the
            percentage. Defaults to 1.
        max_n_width (int | None, optional): Width the subject count is
            left-padded to. If None, the widest count in `df` is used.

    Returns:
        pl.DataFrame: `df` sorted by group and variable hierarchy, with the
        added columns 'pct_subj_fmt', 'n_subj_fmt', 'n_pct_subj_fmt',
        '__variable__' (indented row label, only when a variable is given)
        and '__id__' (1-based row index).
    """
    variables = [variable] if isinstance(variable, str) else list(variable)

    # Null/NaN percentages are shown as zero.
    pct_expr = (
        pl.when(pl.col("pct_subj").is_null() | pl.col("pct_subj").is_nan())
        .then(0.0)
        .otherwise(pl.col("pct_subj"))
        .round(pct_digit, mode="half_away_from_zero")
    )
    if pct_digit == 0:
        # A float rounded to 0 digits still renders with a trailing ".0",
        # which does not fit the 3-character width used below; render it as
        # an integer instead.
        pct_expr = pct_expr.cast(pl.Int64)

    df_fmt = df.with_columns(pct_subj_fmt=pct_expr.cast(pl.String))

    if max_n_width is None:
        widest = df_fmt.select(pl.col("n_subj").cast(pl.String).str.len_chars().max()).item()
        # max() yields None on an empty frame; there is nothing to pad then.
        max_n_width = widest if widest is not None else 0

    # Wide enough for "100" plus, when decimals are shown, "." and the digits.
    max_pct_width = 3 if pct_digit == 0 else 4 + pct_digit

    df_fmt = df_fmt.with_columns(
        [
            pl.col("pct_subj_fmt").str.pad_start(max_pct_width, " "),
            pl.col("n_subj").cast(pl.String).str.pad_start(max_n_width, " ").alias("n_subj_fmt"),
        ]
    ).with_columns(
        n_pct_subj_fmt=pl.concat_str(
            [pl.col("n_subj_fmt"), pl.lit(" ("), pl.col("pct_subj_fmt"), pl.lit(")")]
        )
    )

    # Sorting Logic
    sort_exprs = [pl.col(group)]
    for var in variables:
        # 0 for __all__, 1 for values, 2 for config.missing_str
        sort_key_col = f"__sort_key_{var}__"
        df_fmt = df_fmt.with_columns(
            pl.when(pl.col(var) == "__all__")
            .then(0)
            .when(pl.col(var) == config.missing_str)
            .then(2)
            .otherwise(1)
            .alias(sort_key_col)
        )
        sort_exprs.append(pl.col(sort_key_col))
        sort_exprs.append(pl.col(var))

    # The helper sort keys are dropped again once the rows are ordered.
    df_fmt = df_fmt.sort(sort_exprs).select(pl.exclude(r"^__sort_key_.*$"))

    # Indentation logic
    if variables:
        var_expr = (
            pl.when(pl.col(variables[0]) == config.missing_str)
            .then(pl.lit("Missing"))
            .otherwise(pl.col(variables[0]))
        )

        # Each deeper level falls back to the label built so far when it holds
        # "__all__"; otherwise it shows its own value behind a level-dependent indent.
        for i in range(1, len(variables)):
            var_expr = (
                pl.when(pl.col(variables[i]) == "__all__")
                .then(var_expr)
                .when(pl.col(variables[i]) == config.missing_str)
                .then(pl.lit(" " * 4 * i) + pl.lit("Missing"))
                .otherwise(pl.lit(" " * 4 * i) + pl.col(variables[i]))
            )
        df_fmt = df_fmt.with_columns(var_expr.alias("__variable__"))

    df_fmt = df_fmt.with_row_index(name="__id__", offset=1)
    return df_fmt
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def count_subject_with_observation(
    population: pl.DataFrame,
    observation: pl.DataFrame,
    id: str,
    group: str,
    variable: str | list[str],
    total: bool = True,
    missing_group: str = "error",
    pct_digit: int = 1,
    max_n_width: int | None = None,
) -> pl.DataFrame:
    """
    Backward-compatible entry point that keeps the old one-call signature.

    It only chains the two newer functions: `count_summary_data` produces the
    numeric summary, and `format_summary_table` turns it into display strings.
    """
    counting_args = {
        "population": population,
        "observation": observation,
        "id": id,
        "group": group,
        "variable": variable,
        "total": total,
        "missing_group": missing_group,
    }
    numeric_summary = count_summary_data(**counting_args)

    formatted = format_summary_table(
        df=numeric_summary,
        group=group,
        variable=variable,
        pct_digit=pct_digit,
        max_n_width=max_n_width,
    )
    return formatted
|