pysofra 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. pysofra/__init__.py +82 -0
  2. pysofra/core/__init__.py +14 -0
  3. pysofra/core/compose.py +167 -0
  4. pysofra/core/format.py +155 -0
  5. pysofra/core/frames.py +69 -0
  6. pysofra/core/schema.py +128 -0
  7. pysofra/core/table.py +924 -0
  8. pysofra/io/__init__.py +1 -0
  9. pysofra/models/__init__.py +6 -0
  10. pysofra/models/extract.py +249 -0
  11. pysofra/models/pool.py +119 -0
  12. pysofra/models/regression.py +507 -0
  13. pysofra/models/survival.py +395 -0
  14. pysofra/models/uvregression.py +438 -0
  15. pysofra/notebook/__init__.py +6 -0
  16. pysofra/plot/__init__.py +23 -0
  17. pysofra/plot/_backend.py +32 -0
  18. pysofra/plot/forest.py +159 -0
  19. pysofra/plot/inline.py +171 -0
  20. pysofra/plot/km.py +249 -0
  21. pysofra/render/__init__.py +28 -0
  22. pysofra/render/_zip_determinism.py +57 -0
  23. pysofra/render/base.py +22 -0
  24. pysofra/render/docx.py +286 -0
  25. pysofra/render/html.py +442 -0
  26. pysofra/render/image.py +130 -0
  27. pysofra/render/latex.py +253 -0
  28. pysofra/render/markdown.py +128 -0
  29. pysofra/render/pptx.py +340 -0
  30. pysofra/render/xlsx.py +226 -0
  31. pysofra/summary/__init__.py +6 -0
  32. pysofra/summary/calibrate.py +214 -0
  33. pysofra/summary/design.py +246 -0
  34. pysofra/summary/effect_size.py +187 -0
  35. pysofra/summary/extras.py +745 -0
  36. pysofra/summary/smd.py +133 -0
  37. pysofra/summary/stats.py +135 -0
  38. pysofra/summary/tbl_cross.py +339 -0
  39. pysofra/summary/tbl_one.py +1220 -0
  40. pysofra/summary/tbl_summary.py +51 -0
  41. pysofra/summary/tests.py +370 -0
  42. pysofra/summary/typing.py +129 -0
  43. pysofra/summary/weights.py +161 -0
  44. pysofra/themes/__init__.py +5 -0
  45. pysofra/themes/registry.py +272 -0
  46. pysofra-0.1.0a1.dist-info/METADATA +301 -0
  47. pysofra-0.1.0a1.dist-info/RECORD +50 -0
  48. pysofra-0.1.0a1.dist-info/WHEEL +4 -0
  49. pysofra-0.1.0a1.dist-info/licenses/LICENSE +674 -0
  50. pysofra-0.1.0a1.dist-info/licenses/NOTICE +18 -0
pysofra/summary/smd.py ADDED
@@ -0,0 +1,133 @@
1
+ """Standardized mean differences (SMDs) across groups.
2
+
3
+ For two groups, the SMD is computed as |mean_1 - mean_2| / sd_pool for
4
+ continuous variables and as the multivariate categorical SMD for factors.
5
+
6
+ For three or more groups we report the *maximum pairwise* SMD, which is
7
+ the convention used by ``tableone``. Users who want a different summary
8
+ (mean pairwise, single-reference) can post-process the metadata.
9
+
10
+ References
11
+ ----------
12
+ Yang, D., & Dalton, J. E. (2012). A unified approach to measuring the
13
+ effect size between two groups using SAS. SAS Global Forum.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from itertools import combinations
19
+
20
+ import numpy as np
21
+ import pandas as pd
22
+
23
+
24
+ def continuous_smd_pair(a: np.ndarray, b: np.ndarray) -> float | None:
25
+ """SMD between two continuous samples using pooled SD."""
26
+ a = a[~np.isnan(a)]
27
+ b = b[~np.isnan(b)]
28
+ na, nb = a.size, b.size
29
+ if na < 2 or nb < 2:
30
+ return None
31
+ # Same ``inf``-safety wrap as ``continuous_stats``: numpy's mean / var
32
+ # emit ``RuntimeWarning`` on inf-bearing inputs which escalates to
33
+ # an exception under the project's ``filterwarnings = error`` gate.
34
+ # The resulting ``mean = inf`` / ``var = nan`` are handled by the
35
+ # explicit ``sd_pool == 0.0`` and ``ma == mb`` checks below.
36
+ import warnings as _w
37
+ with np.errstate(invalid="ignore", over="ignore"), _w.catch_warnings():
38
+ _w.simplefilter("ignore", RuntimeWarning)
39
+ ma, mb = float(np.mean(a)), float(np.mean(b))
40
+ va, vb = float(np.var(a, ddof=1)), float(np.var(b, ddof=1))
41
+ sd_pool = float(np.sqrt((va + vb) / 2.0))
42
+ if sd_pool == 0.0:
43
+ return 0.0 if ma == mb else float("inf")
44
+ return abs(ma - mb) / sd_pool
45
+
46
+
47
+ def categorical_smd_pair(p1: np.ndarray, p2: np.ndarray) -> float | None:
48
+ """Multivariate categorical SMD between two proportion vectors.
49
+
50
+ ``p1`` and ``p2`` are proportion vectors over the same K levels.
51
+ Uses the Yang & Dalton (2012) formulation with K-1 dimensions.
52
+
53
+ Edge cases. When the (K-1) average covariance ``S`` is the zero
54
+ matrix the multivariate quadratic form is undefined; this happens
55
+ when both groups have all mass on a single (possibly different)
56
+ category, including the *complete-separation* case
57
+ (e.g. group 1 = "A", group 2 = "B" only). Returning the
58
+ Mahalanobis distance via the pseudo-inverse silently yields zero
59
+ in that case, which would report perfect balance — the opposite
60
+ of the truth. We therefore return ``inf`` when the contrast is
61
+ nonzero and the covariance is degenerate, and ``0`` when both
62
+ are zero (groups truly identical).
63
+ """
64
+ if p1.size != p2.size or p1.size < 2:
65
+ return None
66
+ # Use K-1 categories to avoid singular covariance.
67
+ p1 = p1[:-1]
68
+ p2 = p2[:-1]
69
+ diff = p1 - p2
70
+ # Mean covariance matrix S = (S1 + S2) / 2
71
+ s1 = np.diag(p1) - np.outer(p1, p1)
72
+ s2 = np.diag(p2) - np.outer(p2, p2)
73
+ s = (s1 + s2) / 2.0
74
+ # Degenerate covariance: either no variability within either group
75
+ # (each group concentrates on one category) or the K-1 contrast
76
+ # space is empty. ``pinv`` returns a near-zero matrix here, which
77
+ # would falsely report a zero SMD even under complete separation.
78
+ if np.allclose(s, 0.0):
79
+ return 0.0 if np.allclose(diff, 0.0) else float("inf")
80
+ try:
81
+ s_inv = np.linalg.pinv(s)
82
+ except np.linalg.LinAlgError:
83
+ return None
84
+ val = float(diff @ s_inv @ diff)
85
+ if val < 0: # numerical
86
+ val = 0.0
87
+ return float(np.sqrt(val))
88
+
89
+
90
+ def continuous_smd(values: pd.Series, groups: pd.Series) -> float | None:
91
+ """Maximum pairwise continuous SMD across groups."""
92
+ df = pd.DataFrame({"v": pd.to_numeric(values, errors="coerce"), "g": groups}).dropna()
93
+ if df.empty:
94
+ return None
95
+ by_group = {g: x["v"].to_numpy() for g, x in df.groupby("g", observed=True)}
96
+ keys = list(by_group)
97
+ if len(keys) < 2:
98
+ return None
99
+ if len(keys) == 2:
100
+ return continuous_smd_pair(by_group[keys[0]], by_group[keys[1]])
101
+ pair_results = [
102
+ continuous_smd_pair(by_group[a], by_group[b])
103
+ for a, b in combinations(keys, 2)
104
+ ]
105
+ pairs_cont: list[float] = [p for p in pair_results if p is not None]
106
+ return max(pairs_cont) if pairs_cont else None
107
+
108
+
109
+ def categorical_smd(
110
+ values: pd.Series,
111
+ groups: pd.Series,
112
+ levels: list[object] | tuple[object, ...] | None = None,
113
+ ) -> float | None:
114
+ """Maximum pairwise categorical SMD across groups."""
115
+ df = pd.DataFrame({"v": values, "g": groups}).dropna()
116
+ if df.empty:
117
+ return None
118
+ ctab = pd.crosstab(df["v"], df["g"])
119
+ if levels is not None:
120
+ ctab = ctab.reindex(index=list(levels), fill_value=0)
121
+ if ctab.shape[0] < 2 or ctab.shape[1] < 2:
122
+ return None
123
+ col_totals = ctab.sum(axis=0).replace(0, np.nan)
124
+ props = (ctab / col_totals).fillna(0.0)
125
+ keys = list(props.columns)
126
+ if len(keys) == 2:
127
+ return categorical_smd_pair(props[keys[0]].to_numpy(), props[keys[1]].to_numpy())
128
+ cat_pair_results = [
129
+ categorical_smd_pair(props[a].to_numpy(), props[b].to_numpy())
130
+ for a, b in combinations(keys, 2)
131
+ ]
132
+ pairs_cat: list[float] = [p for p in cat_pair_results if p is not None]
133
+ return max(pairs_cat) if pairs_cat else None
@@ -0,0 +1,135 @@
1
+ """Summary statistic computations for continuous and categorical variables.
2
+
3
+ Pure functions that take a pandas Series (or groupby slice) and return a
4
+ small dataclass of statistics. Format-free — formatting belongs to
5
+ :mod:`pysofra.core.format`.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import warnings
11
+ from dataclasses import dataclass
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+
16
+
17
+ @dataclass(frozen=True)
18
+ class ContinuousStats:
19
+ n: int
20
+ n_missing: int
21
+ mean: float
22
+ sd: float
23
+ median: float
24
+ q1: float
25
+ q3: float
26
+ min: float
27
+ max: float
28
+
29
+
30
def continuous_stats(series: pd.Series) -> ContinuousStats:
    """Summarise a numeric series; all-missing input is tolerated.

    The series is coerced to numeric (unparseable entries become missing).
    ``n`` counts the non-missing values and ``n_missing`` the remainder.
    With no valid values every statistic is NaN. The standard deviation is
    the sample SD (``ddof=1``) and is NaN for a single observation.
    """
    numeric = pd.to_numeric(series, errors="coerce")
    observed = numeric.dropna()
    n_obs = int(observed.size)
    n_missing = int(len(numeric) - n_obs)

    if n_obs == 0:
        undefined = float("nan")
        return ContinuousStats(
            n=0,
            n_missing=n_missing,
            mean=undefined,
            sd=undefined,
            median=undefined,
            q1=undefined,
            q3=undefined,
            min=undefined,
            max=undefined,
        )

    data = observed.to_numpy(dtype=float)
    # ``inf`` / ``-inf`` make numpy emit ``RuntimeWarning`` from mean / std /
    # quantile. Under a warnings-as-errors configuration those would turn
    # into exceptions, so they are silenced here and the statistics simply
    # come out as ``nan`` / ``inf``.
    with np.errstate(invalid="ignore", over="ignore"), warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        average = float(np.mean(data))
        # Sample SD is undefined for one observation: report NaN rather
        # than a misleading zero.
        spread = float(np.std(data, ddof=1)) if n_obs > 1 else float("nan")
        middle = float(np.median(data))
        lower_q, upper_q = (float(q) for q in np.quantile(data, [0.25, 0.75]))
        smallest = float(np.min(data))
        largest = float(np.max(data))

    return ContinuousStats(
        n=n_obs,
        n_missing=n_missing,
        mean=average,
        sd=spread,
        median=middle,
        q1=lower_q,
        q3=upper_q,
        min=smallest,
        max=largest,
    )
75
+
76
+
77
@dataclass(frozen=True)
class CategoricalStats:
    """Immutable per-level counts for one categorical variable."""

    # Number of non-missing observations.
    n: int
    # Number of missing observations.
    n_missing: int
    # Count for each level, in level order.
    counts: dict[object, int]
    # The levels, in display order.
    levels: tuple[object, ...]
83
+
84
+
85
def categorical_stats(
    series: pd.Series,
    levels: list[object] | tuple[object, ...] | None = None,
) -> CategoricalStats:
    """Count the occurrences of each level in ``series``.

    Level order is, in priority: the explicit ``levels`` argument; the
    categories of a categorical dtype; otherwise the observed unique values
    sorted with ``_safe_sort_key``. Levels that never occur keep a count of
    0, so grouped tables align across strata. Observed values that are not
    in the level list are appended to it, after the listed levels.
    """
    present = series.dropna()

    if levels is not None:
        ordered: list[object] = list(levels)
    elif isinstance(series.dtype, pd.CategoricalDtype):
        ordered = list(series.cat.categories)
    else:
        ordered = sorted(present.unique(), key=_safe_sort_key)

    tally: dict[object, int] = dict.fromkeys(ordered, 0)
    for level, freq in present.value_counts(dropna=False).items():
        if level not in tally:
            # Out-of-spec level (only possible with explicit ``levels``).
            ordered.append(level)
        tally[level] = int(freq)

    return CategoricalStats(
        n=int(present.size),
        n_missing=int(series.isna().sum()),
        counts=tally,
        levels=tuple(ordered),
    )
122
+
123
+
124
+ def _safe_sort_key(x: object) -> tuple[int, object]:
125
+ """Sort key that puts numerics first, then strings, then everything else.
126
+
127
+ Avoids ``TypeError`` from mixed-type uniques (e.g. ``[1, "a"]``).
128
+ """
129
+ if isinstance(x, bool):
130
+ return (0, int(x))
131
+ if isinstance(x, (int, float, np.integer, np.floating)):
132
+ return (0, float(x))
133
+ if isinstance(x, str):
134
+ return (1, x)
135
+ return (2, repr(x))
@@ -0,0 +1,339 @@
1
+ """Cross-tabulation tables — equivalent to ``gtsummary::tbl_cross``.
2
+
3
+ ``tbl_cross`` builds a two-way contingency table with selectable cell
4
+ content:
5
+
6
+ * ``n`` — raw count (default)
7
+ * ``row_pct`` — row-percent
8
+ * ``col_pct`` — column-percent
9
+ * ``total_pct`` — overall percent
10
+ * ``n_row_pct`` — n with row-% in parens (the "n (row %)" style)
11
+ * ``n_col_pct`` — n with col-% in parens
12
+ * ``n_total_pct`` — n with total-% in parens
13
+
14
+ Margins (row totals / column totals / grand total) are added by default
15
+ and can be turned off with ``margins=False``.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Any
21
+
22
+ import pandas as pd
23
+
24
+ from ..core.format import fmt_n_pct, fmt_p_value
25
+ from ..core.frames import to_pandas
26
+ from ..core.schema import HeaderCell, HeaderRow, Row, make_cell
27
+ from ..core.table import SofraTable, TableSpec
28
+
29
+ _CELL_STYLES = (
30
+ "n", "row_pct", "col_pct", "total_pct",
31
+ "n_row_pct", "n_col_pct", "n_total_pct",
32
+ )
33
+
34
+
35
def tbl_cross(
    data: Any,
    *,
    row: str,
    column: str,
    cell: str = "n_col_pct",
    margins: bool = True,
    digits: int = 1,
    labels: dict[str, str] | None = None,
) -> SofraTable:
    """Cross-tabulate ``row`` against ``column``.

    Parameters
    ----------
    data
        Source dataframe (converted with ``to_pandas``).
    row
        Name of the variable shown down the rows.
    column
        Name of the variable shown across the columns.
    cell
        Interior cell style; one of ``_CELL_STYLES`` (see module docstring).
    margins
        Whether to include row, column and grand totals.
    digits
        Decimal places used for percentages.
    labels
        Optional mapping from a name to its display label. It is consulted
        for the two variable names and for the row and column levels.

    Raises
    ------
    ValueError
        If ``cell`` is not a recognised style.
    KeyError
        If ``row`` or ``column`` is not a column of ``data``.

    Notes
    -----
    The returned :class:`SofraTable` carries a rebuild closure over the
    source data, so the statistical modifiers can re-run the build:

    * ``.add_p()`` appends a *p*-value footnote from the auto-selected
      categorical test (Fisher's exact for 2x2, Pearson χ² otherwise).
    * ``.add_overall()`` forces the margins on (a no-op under the default
      ``margins=True``).
    * ``.add_smd()`` raises :class:`NotImplementedError`; SMD is undefined
      on a contingency table. Use :func:`tbl_one` for SMD between arms.
    """
    if cell not in _CELL_STYLES:
        raise ValueError(
            f"cell must be one of {_CELL_STYLES}; got {cell!r}."
        )

    df = to_pandas(data)
    for role, name in (("row", row), ("column", column)):
        if name not in df.columns:
            raise KeyError(f"{role} column {name!r} not in data")

    options = {
        "row": row,
        "column": column,
        "cell": cell,
        "margins": margins,
        "digits": digits,
        "labels": dict(labels or {}),
        # Modifier flags, switched on later by .add_p() / .add_overall() /
        # .add_smd() through SofraTable._with_option().
        "p_value": False,
        "overall": False,
        "smd": False,
    }
    return _build_cross(df, TableSpec(builder="tbl_cross", options=options))
109
+
110
+
111
def _build_cross(df: pd.DataFrame, spec: TableSpec) -> SofraTable:
    """Build a tbl_cross SofraTable from a frozen source df and spec.

    Reads ``row``, ``column``, ``cell``, ``margins``, ``digits``,
    ``labels`` and the modifier flags ``overall``, ``smd`` and ``p_value``
    from ``spec.options``. Rows of ``df`` missing either dimension are
    dropped before tabulating. The returned table keeps ``spec`` and a
    closure that rebuilds from the same ``df`` with a new spec.

    Raises
    ------
    NotImplementedError
        If the ``smd`` flag is set; SMD is undefined for a cross-tab.
    """
    row: str = spec.options["row"]
    column: str = spec.options["column"]
    cell: str = spec.options["cell"]
    # add_overall() forces margins on; the user-passed margins= remains
    # otherwise.
    margins: bool = bool(spec.options["margins"] or spec.options.get("overall"))
    digits: int = int(spec.options["digits"])
    labels: dict[str, str] = dict(spec.options.get("labels") or {})

    # add_smd() is meaningless on a cross-tab. Surface that explicitly
    # rather than emitting a spurious "SMD" column or silently dropping
    # the flag.
    if spec.options.get("smd"):
        raise NotImplementedError(
            "add_smd() is not defined for tbl_cross — SMD measures the "
            "standardised difference between two distributions of a "
            "single variable. For SMD between two arms on the same "
            "variable, use tbl_one(df, by=...).add_smd().",
        )

    # Drop rows missing either dimension.
    sub = df[[row, column]].dropna()
    if sub.empty:
        # Emit a minimal placeholder table. Carry the spec + rebuild so
        # the table can still respond to modifiers (no-ops in practice,
        # but the contract is "rebuild always works").
        def _empty_rebuild(new_spec: TableSpec) -> SofraTable:  # pragma: no cover — modifier on empty cross-tab
            return _build_cross(df, new_spec)
        return SofraTable(
            rows=(Row(cells=(make_cell(row, align="left", bold=True),
                             make_cell("—", value=None,
                                       kind="numeric", align="right"))),),
            headers=(HeaderRow(cells=(
                HeaderCell(text=labels.get(row, row), align="left"),
                HeaderCell(text=labels.get(column, column)),
            )),),
            footnotes=("No non-missing rows for the requested cross-tabulation.",),
            metadata={"builder": "tbl_cross"},
            _spec=spec,
            _rebuild=_empty_rebuild,
        )

    # Preserve categorical ordering where present, keeping only the
    # categories that actually occur. The set of observed values is built
    # once per dimension rather than once per category.
    if isinstance(df[row].dtype, pd.CategoricalDtype):
        seen_rows = set(sub[row])
        row_levels = [lvl for lvl in df[row].cat.categories if lvl in seen_rows]
    else:
        row_levels = sorted(sub[row].unique(), key=_sort_key)
    if isinstance(df[column].dtype, pd.CategoricalDtype):
        seen_cols = set(sub[column])
        col_levels = [lvl for lvl in df[column].cat.categories if lvl in seen_cols]
    else:
        col_levels = sorted(sub[column].unique(), key=_sort_key)

    ctab = pd.crosstab(sub[row], sub[column])
    ctab = ctab.reindex(index=row_levels, columns=col_levels, fill_value=0)

    row_totals = ctab.sum(axis=1)
    col_totals = ctab.sum(axis=0)
    grand_total = float(ctab.values.sum())

    # ------------------------------------------------------------------
    # Headers
    # ------------------------------------------------------------------
    header_cells = [HeaderCell(text=labels.get(row, row), align="left")]
    for lvl in col_levels:
        header_cells.append(HeaderCell(text=labels.get(lvl, str(lvl))))
    if margins:
        header_cells.append(HeaderCell(text="Total"))
    headers = (HeaderRow(cells=tuple(header_cells)),)

    # spanning header naming the column variable
    from ..core.schema import SpanningHeader
    spanning: tuple[SpanningHeader, ...]
    if len(col_levels) > 0:
        spanning = (SpanningHeader(
            label=labels.get(column, column),
            start=1,
            end=len(col_levels) + (1 if margins else 0),
        ),)
    else:  # pragma: no cover — guarded by the empty-sub short-circuit above
        spanning = ()

    # ------------------------------------------------------------------
    # Body rows
    # ------------------------------------------------------------------
    rows: list[Row] = []
    for r_lvl in row_levels:
        body = [make_cell(labels.get(r_lvl, str(r_lvl)), align="left")]
        for c_lvl in col_levels:
            n = int(ctab.loc[r_lvl, c_lvl])
            body.append(_fmt_cross_cell(
                n=n,
                row_total=int(row_totals.loc[r_lvl]),
                col_total=int(col_totals.loc[c_lvl]),
                grand_total=grand_total,
                style=cell,
                digits=digits,
            ))
        if margins:
            rt = int(row_totals.loc[r_lvl])
            body.append(make_cell(
                _fmt_margin(rt, grand_total, style=cell, digits=digits),
                value=rt, kind="numeric", align="right",
            ))
        rows.append(Row(cells=tuple(body)))

    # Margin row
    if margins:
        body = [make_cell("Total", align="left", bold=True)]
        for c_lvl in col_levels:
            ct = int(col_totals.loc[c_lvl])
            body.append(make_cell(
                _fmt_margin(ct, grand_total, style=cell, digits=digits),
                value=ct, kind="numeric", align="right",
            ))
        body.append(make_cell(
            f"{int(grand_total):,}",
            value=int(grand_total),
            kind="numeric", align="right", bold=True,
        ))
        rows.append(Row(cells=tuple(body), is_group_header=True))

    footnotes = [_footnote_for(cell)]

    # ------------------------------------------------------------------
    # add_p() — auto-selected categorical test on the full contingency.
    # Reported as a footnote so the table body stays a clean grid.
    # ------------------------------------------------------------------
    metadata: dict[str, Any] = {"builder": "tbl_cross"}
    if spec.options.get("p_value"):
        from .tests import categorical_test
        res = categorical_test(sub[row], sub[column])
        if res.p_value is not None:
            footnotes.append(
                f"{res.test}: p = {fmt_p_value(res.p_value)}",
            )
            # Surface the raw p-value + test name in metadata for
            # programmatic consumers (e.g. golden tests, downstream
            # reports that want the numeric value).
            metadata["p_value"] = float(res.p_value)
            metadata["p_test"] = res.test

    def _rebuild(new_spec: TableSpec) -> SofraTable:
        return _build_cross(df, new_spec)

    return SofraTable(
        rows=tuple(rows),
        headers=headers,
        spanning_headers=spanning,
        footnotes=tuple(footnotes),
        metadata=metadata,
        _spec=spec,
        _rebuild=_rebuild,
    )
266
+
267
+
268
+ # ----------------------------------------------------------------------
269
+ # Helpers
270
+ # ----------------------------------------------------------------------
271
+
272
def _fmt_cross_cell(
    *, n: int, row_total: int, col_total: int, grand_total: float,
    style: str, digits: int,
) -> Any:
    """Format one interior cell of the cross-tab according to ``style``.

    * ``"n"`` shows the count; the cell value is the count.
    * ``"row_pct"`` / ``"col_pct"`` / ``"total_pct"`` show the count as a
      percentage of the row, column or grand total; the cell value is that
      percentage (NaN when the total is zero).
    * ``"n_row_pct"`` / ``"n_col_pct"`` / ``"n_total_pct"`` show the count
      with the matching percentage via ``fmt_n_pct``; the cell value is
      the count.

    Raises ``ValueError`` for any other style.
    """
    if style == "n":
        return make_cell(f"{n:,}", value=n, kind="numeric", align="right")

    if style in ("row_pct", "col_pct", "total_pct"):
        if style == "row_pct":
            total = row_total
        elif style == "col_pct":
            total = col_total
        else:
            total = grand_total
        share = 100.0 * n / total if total else float("nan")
        return make_cell(_pct(share, digits), value=share,
                         kind="numeric", align="right")

    if style in ("n_row_pct", "n_col_pct", "n_total_pct"):
        if style == "n_row_pct":
            total = row_total
        elif style == "n_col_pct":
            total = col_total
        else:
            total = int(grand_total)
        return make_cell(fmt_n_pct(n, total, digits=digits),
                         value=n, kind="numeric", align="right")

    raise ValueError(f"unknown cell style {style!r}")  # pragma: no cover — guarded by top-level cell-style validation
301
+
302
+
303
+ def _fmt_margin(n: int, grand_total: float, *, style: str, digits: int) -> str:
304
+ """Margin cell formatting — always 'n (overall %)' for n-style cells."""
305
+ if style.startswith("n_"):
306
+ return fmt_n_pct(n, int(grand_total), digits=digits)
307
+ if style in ("row_pct", "col_pct", "total_pct"):
308
+ return _pct(100.0 * n / grand_total if grand_total else float("nan"),
309
+ digits)
310
+ return f"{n:,}"
311
+
312
+
313
+ def _pct(p: float, digits: int) -> str:
314
+ import math
315
+ if p is None or (isinstance(p, float) and math.isnan(p)):
316
+ return "—"
317
+ return f"{p:.{digits}f}%"
318
+
319
+
320
+ def _footnote_for(style: str) -> str:
321
+ return {
322
+ "n": "Cells: raw counts.",
323
+ "row_pct": "Cells: row-percent.",
324
+ "col_pct": "Cells: column-percent.",
325
+ "total_pct": "Cells: overall percent.",
326
+ "n_row_pct": "Cells: n (row-%).",
327
+ "n_col_pct": "Cells: n (column-%).",
328
+ "n_total_pct": "Cells: n (overall-%).",
329
+ }[style]
330
+
331
+
332
+ def _sort_key(x: Any) -> tuple[int, Any]:
333
+ if isinstance(x, bool):
334
+ return (0, int(x))
335
+ if isinstance(x, (int, float)):
336
+ return (0, float(x))
337
+ if isinstance(x, str):
338
+ return (1, x)
339
+ return (2, repr(x))