pysofra 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pysofra/__init__.py +82 -0
- pysofra/core/__init__.py +14 -0
- pysofra/core/compose.py +167 -0
- pysofra/core/format.py +155 -0
- pysofra/core/frames.py +69 -0
- pysofra/core/schema.py +128 -0
- pysofra/core/table.py +924 -0
- pysofra/io/__init__.py +1 -0
- pysofra/models/__init__.py +6 -0
- pysofra/models/extract.py +249 -0
- pysofra/models/pool.py +119 -0
- pysofra/models/regression.py +507 -0
- pysofra/models/survival.py +395 -0
- pysofra/models/uvregression.py +438 -0
- pysofra/notebook/__init__.py +6 -0
- pysofra/plot/__init__.py +23 -0
- pysofra/plot/_backend.py +32 -0
- pysofra/plot/forest.py +159 -0
- pysofra/plot/inline.py +171 -0
- pysofra/plot/km.py +249 -0
- pysofra/render/__init__.py +28 -0
- pysofra/render/_zip_determinism.py +57 -0
- pysofra/render/base.py +22 -0
- pysofra/render/docx.py +286 -0
- pysofra/render/html.py +442 -0
- pysofra/render/image.py +130 -0
- pysofra/render/latex.py +253 -0
- pysofra/render/markdown.py +128 -0
- pysofra/render/pptx.py +340 -0
- pysofra/render/xlsx.py +226 -0
- pysofra/summary/__init__.py +6 -0
- pysofra/summary/calibrate.py +214 -0
- pysofra/summary/design.py +246 -0
- pysofra/summary/effect_size.py +187 -0
- pysofra/summary/extras.py +745 -0
- pysofra/summary/smd.py +133 -0
- pysofra/summary/stats.py +135 -0
- pysofra/summary/tbl_cross.py +339 -0
- pysofra/summary/tbl_one.py +1220 -0
- pysofra/summary/tbl_summary.py +51 -0
- pysofra/summary/tests.py +370 -0
- pysofra/summary/typing.py +129 -0
- pysofra/summary/weights.py +161 -0
- pysofra/themes/__init__.py +5 -0
- pysofra/themes/registry.py +272 -0
- pysofra-0.1.0a1.dist-info/METADATA +301 -0
- pysofra-0.1.0a1.dist-info/RECORD +50 -0
- pysofra-0.1.0a1.dist-info/WHEEL +4 -0
- pysofra-0.1.0a1.dist-info/licenses/LICENSE +674 -0
- pysofra-0.1.0a1.dist-info/licenses/NOTICE +18 -0
pysofra/summary/smd.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
"""Standardized mean differences (SMDs) across groups.
|
|
2
|
+
|
|
3
|
+
For two groups, the SMD is computed as |mean_1 - mean_2| / sd_pool for
|
|
4
|
+
continuous variables and as the multivariate categorical SMD for factors.
|
|
5
|
+
|
|
6
|
+
For three or more groups we report the *maximum pairwise* SMD, which is
|
|
7
|
+
the convention used by ``tableone``. Users who want a different summary
|
|
8
|
+
(mean pairwise, single-reference) can post-process the metadata.
|
|
9
|
+
|
|
10
|
+
References
|
|
11
|
+
----------
|
|
12
|
+
Yang, D., & Dalton, J. E. (2012). A unified approach to measuring the
|
|
13
|
+
effect size between two groups using SAS. SAS Global Forum.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from itertools import combinations
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pandas as pd
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def continuous_smd_pair(a: np.ndarray, b: np.ndarray) -> float | None:
    """SMD between two continuous samples using pooled SD.

    Computes ``|mean(a) - mean(b)| / sqrt((var(a) + var(b)) / 2)`` with
    sample variances (``ddof=1``). NaN entries are dropped first.

    Parameters
    ----------
    a, b
        Numeric samples. Anything ``np.asarray(..., dtype=float)`` accepts
        (arrays, lists, integer / boolean arrays) is allowed.

    Returns
    -------
    float or None
        * ``None`` when either sample has fewer than two non-NaN values, or
          when the SMD is undefined (e.g. a sample contains ``inf``, which
          makes its variance NaN).
        * ``0.0`` when both samples are constant and equal.
        * ``inf`` when both samples are constant but differ.
        * Otherwise the non-negative SMD.
    """
    import warnings as _w

    # Coerce to float so integer/boolean/object-typed arrays and plain
    # lists go through ``np.isnan`` / ``np.var`` without a TypeError.
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a = a[~np.isnan(a)]
    b = b[~np.isnan(b)]
    if a.size < 2 or b.size < 2:
        return None
    # Same ``inf``-safety wrap as ``continuous_stats``: numpy's mean / var
    # emit ``RuntimeWarning`` on inf-bearing inputs, which escalates to an
    # exception under the project's ``filterwarnings = error`` gate.
    with np.errstate(invalid="ignore", over="ignore"), _w.catch_warnings():
        _w.simplefilter("ignore", RuntimeWarning)
        ma, mb = float(np.mean(a)), float(np.mean(b))
        va, vb = float(np.var(a, ddof=1)), float(np.var(b, ddof=1))
        sd_pool = float(np.sqrt((va + vb) / 2.0))
    if sd_pool == 0.0:
        return 0.0 if ma == mb else float("inf")
    smd = abs(ma - mb) / sd_pool
    # inf-bearing inputs give ``var = nan`` and hence a NaN SMD. Report it
    # as undefined (None) rather than leaking NaN into callers that take a
    # ``max`` over pairs, where NaN makes the result order-dependent.
    if np.isnan(smd):
        return None
    return smd
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def categorical_smd_pair(p1: np.ndarray, p2: np.ndarray) -> float | None:
    """Multivariate categorical SMD between two proportion vectors.

    ``p1`` and ``p2`` hold the proportions of the same K levels in two
    groups. Following Yang & Dalton (2012), the last level is dropped and
    the SMD is the Mahalanobis length ``sqrt(d' S^+ d)``. Here ``d`` is the
    (K-1) difference in proportions and ``S`` is the mean of the two
    multinomial covariance matrices ``diag(p) - p p'``.

    Returns ``None`` when the vectors differ in length, have fewer than two
    levels, or the pseudo-inverse fails to converge.

    Degenerate covariance: if ``S`` is (numerically) the zero matrix, each
    group puts all of its mass on a single (possibly different) category.
    This includes complete separation (group 1 = "A", group 2 = "B" only).
    A pseudo-inverse would then report zero, i.e. perfect balance, which is
    the opposite of the truth. In that case ``inf`` is returned if the
    contrast is nonzero and ``0`` if the groups are identical.
    """
    if p1.size < 2 or p2.size != p1.size:
        return None

    # Proportions sum to one, so the full K x K covariance is singular;
    # work in the first K-1 coordinates instead.
    head1, head2 = p1[:-1], p2[:-1]
    contrast = head1 - head2
    cov1 = np.diag(head1) - np.outer(head1, head1)
    cov2 = np.diag(head2) - np.outer(head2, head2)
    cov = (cov1 + cov2) / 2.0

    # pinv of a zero matrix is ~zero and would hide complete separation.
    if np.allclose(cov, 0.0):
        return float("inf") if not np.allclose(contrast, 0.0) else 0.0

    try:
        cov_inv = np.linalg.pinv(cov)
    except np.linalg.LinAlgError:
        return None

    quad = float(contrast @ cov_inv @ contrast)
    # Clamp tiny negative round-off before taking the square root.
    return float(np.sqrt(max(quad, 0.0)))
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def continuous_smd(values: pd.Series, groups: pd.Series) -> float | None:
    """Maximum pairwise continuous SMD across groups.

    ``values`` is coerced to numeric (non-numeric entries become NaN), and
    rows missing either the value or the group are dropped. The SMD is
    computed with ``continuous_smd_pair`` for every pair of groups, and the
    largest defined value is returned.

    Returns
    -------
    float or None
        The maximum pairwise SMD, or ``None`` when there are fewer than two
        groups or no pair has a defined SMD. Pairs that yield ``None`` or
        NaN are ignored, so the result does not depend on group order.
    """
    df = pd.DataFrame({"v": pd.to_numeric(values, errors="coerce"), "g": groups}).dropna()
    if df.empty:
        return None
    # Float arrays keep ``np.isnan`` / ``np.var`` happy even when ``values``
    # was a nullable (e.g. ``Int64``) or boolean series.
    by_group = {
        g: x["v"].to_numpy(dtype=float)
        for g, x in df.groupby("g", observed=True)
    }
    if len(by_group) < 2:
        return None
    # One path for two or more groups. NaN must be filtered here: ``max``
    # with a NaN element returns NaN or not depending on its position.
    pair_results = (
        continuous_smd_pair(by_group[a], by_group[b])
        for a, b in combinations(by_group, 2)
    )
    pairs_cont: list[float] = [
        p for p in pair_results if p is not None and not np.isnan(p)
    ]
    return max(pairs_cont) if pairs_cont else None
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def categorical_smd(
    values: pd.Series,
    groups: pd.Series,
    levels: list[object] | tuple[object, ...] | None = None,
) -> float | None:
    """Maximum pairwise categorical SMD across groups.

    Parameters
    ----------
    values
        Categorical values. Rows where ``values`` or ``groups`` is missing
        are dropped.
    groups
        Group labels.
    levels
        Optional level order for ``values``. Listed levels that are never
        observed contribute zero counts. Observed values that are *not*
        listed are kept and appended after the listed ones, the same way
        ``categorical_stats`` appends out-of-spec levels, so the SMD
        describes the same distribution the table displays.

    Returns
    -------
    float or None
        The largest pairwise ``categorical_smd_pair`` value (``inf`` under
        complete separation). ``None`` when there are fewer than two levels
        or two non-empty groups, or no pair is defined.
    """
    df = pd.DataFrame({"v": values, "g": groups}).dropna()
    if df.empty:
        return None
    ctab = pd.crosstab(df["v"], df["g"])
    if levels is not None:
        ordered = list(levels)
        known = set(ordered)
        extras = [lvl for lvl in ctab.index if lvl not in known]
        ctab = ctab.reindex(index=ordered + extras, fill_value=0)
    # A group with no observations has no distribution to compare; an
    # all-zero column would otherwise be read as "all mass on the last level".
    ctab = ctab.loc[:, ctab.sum(axis=0) > 0]
    if ctab.shape[0] < 2 or ctab.shape[1] < 2:
        return None
    props = ctab / ctab.sum(axis=0)
    cat_pair_results = (
        categorical_smd_pair(props[a].to_numpy(), props[b].to_numpy())
        for a, b in combinations(props.columns, 2)
    )
    pairs_cat: list[float] = [p for p in cat_pair_results if p is not None]
    return max(pairs_cat) if pairs_cat else None
|
pysofra/summary/stats.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Summary statistic computations for continuous and categorical variables.
|
|
2
|
+
|
|
3
|
+
Pure functions that take a pandas Series (or groupby slice) and return a
|
|
4
|
+
small dataclass of statistics. Format-free — formatting belongs to
|
|
5
|
+
:mod:`pysofra.core.format`.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import warnings
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
import pandas as pd
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True)
class ContinuousStats:
    """Summary statistics for one numeric variable (or one group's slice).

    Every float field is NaN when there are no valid values. ``sd`` is also
    NaN when only a single value is present.
    """

    n: int  # number of valid (non-missing, numeric) values
    n_missing: int  # entries that were missing or not coercible to a number
    mean: float
    sd: float  # sample standard deviation (ddof=1)
    median: float
    q1: float  # 25th percentile (numpy.quantile default method)
    q3: float  # 75th percentile
    min: float
    max: float


def continuous_stats(series: pd.Series) -> ContinuousStats:
    """Compute continuous summary statistics. Tolerates all-NaN slices.

    Non-numeric entries are coerced to NaN and counted as missing.
    """
    numeric = pd.to_numeric(series, errors="coerce")
    observed = numeric.dropna()
    n_obs = int(observed.size)
    missing = int(len(numeric) - n_obs)

    if not n_obs:
        nan = float("nan")
        return ContinuousStats(
            n=0, n_missing=missing,
            mean=nan, sd=nan, median=nan, q1=nan, q3=nan, min=nan, max=nan,
        )

    values = observed.to_numpy(dtype=float)
    # ``inf`` / ``-inf`` make numpy's mean / std / quantile emit
    # ``RuntimeWarning`` (e.g. "invalid value encountered in subtract").
    # Under ``filterwarnings = error``, which this project's pyproject.toml
    # sets and users often run with via ``-W error``, those warnings become
    # exceptions and would crash ``tbl_one`` on legal data. Silence them so
    # the statistics come out as plain ``nan`` / ``inf``, which the
    # formatters render as an em-dash.
    with np.errstate(invalid="ignore", over="ignore"), warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        lower, upper = np.quantile(values, [0.25, 0.75])
        return ContinuousStats(
            n=n_obs,
            n_missing=missing,
            mean=float(np.mean(values)),
            # Sample SD is undefined for a single value: report NaN so
            # renderers show an em-dash instead of a misleading "0.00".
            sd=float(np.std(values, ddof=1)) if n_obs > 1 else float("nan"),
            median=float(np.median(values)),
            q1=float(lower),
            q3=float(upper),
            min=float(values.min()),
            max=float(values.max()),
        )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass(frozen=True)
class CategoricalStats:
    """Per-level counts for one categorical variable (or one group's slice)."""

    n: int  # number of non-missing values
    n_missing: int  # number of missing values (per ``Series.isna``)
    counts: dict[object, int]  # ordered by level
    levels: tuple[object, ...]  # level display order; same order as ``counts``


def categorical_stats(
    series: pd.Series,
    levels: list[object] | tuple[object, ...] | None = None,
) -> CategoricalStats:
    """Compute counts per level.

    Level order:

    * ``levels`` given: that order. Levels missing from the series are
      included with count 0, so grouped tables align across strata.
      Observed values that are not in ``levels`` are appended after them
      with their counts. Unused categories of a categorical dtype are *not*
      appended, because they were never observed.
    * ``levels`` is None and the series is categorical: the declared
      categories, including unused ones (count 0).
    * Otherwise: the observed unique values, sorted numerics first, then
      strings, then everything else.
    """
    s = series
    n_missing = int(s.isna().sum())
    valid = s.dropna()

    if levels is None:
        if isinstance(s.dtype, pd.CategoricalDtype):
            level_list: list[object] = list(s.cat.categories)
        else:
            level_list = sorted(valid.unique(), key=_safe_sort_key)
    else:
        level_list = list(levels)

    counts: dict[object, int] = {lvl: 0 for lvl in level_list}
    for lvl, c in valid.value_counts(dropna=False).items():
        count = int(c)
        if lvl in counts:
            counts[lvl] = count
        elif count > 0:
            # Out-of-spec level (only when caller passed explicit levels).
            # ``value_counts`` on a categorical also reports unused
            # categories with count 0; skipping those avoids phantom rows.
            counts[lvl] = count
            level_list.append(lvl)

    return CategoricalStats(
        n=int(valid.size),
        n_missing=n_missing,
        counts=counts,
        levels=tuple(level_list),
    )


def _safe_sort_key(x: object) -> tuple[int, object]:
    """Sort key that puts numerics first, then strings, then everything else.

    Avoids ``TypeError`` from mixed-type uniques (e.g. ``[1, "a"]``).
    Python booleans and Python / NumPy numeric scalars are ordered
    numerically; anything else that is not a string is ordered by ``repr``.
    """
    if isinstance(x, bool):
        return (0, int(x))
    if isinstance(x, (int, float, np.integer, np.floating)):
        return (0, float(x))
    if isinstance(x, str):
        return (1, x)
    return (2, repr(x))
|
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
"""Cross-tabulation tables — equivalent to ``gtsummary::tbl_cross``.
|
|
2
|
+
|
|
3
|
+
``tbl_cross`` builds a two-way contingency table with selectable cell
|
|
4
|
+
content:
|
|
5
|
+
|
|
6
|
+
* ``n`` — raw count
* ``row_pct`` — row-percent
* ``col_pct`` — column-percent
* ``total_pct`` — overall percent
* ``n_row_pct`` — n with row-% in parens (the "n (row %)" style)
* ``n_col_pct`` — n with col-% in parens (default)
* ``n_total_pct`` — n with total-% in parens
|
|
13
|
+
|
|
14
|
+
Margins (row totals / column totals / grand total) are added by default
|
|
15
|
+
and can be turned off with ``margins=False``.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
import pandas as pd
|
|
23
|
+
|
|
24
|
+
from ..core.format import fmt_n_pct, fmt_p_value
|
|
25
|
+
from ..core.frames import to_pandas
|
|
26
|
+
from ..core.schema import HeaderCell, HeaderRow, Row, make_cell
|
|
27
|
+
from ..core.table import SofraTable, TableSpec
|
|
28
|
+
|
|
29
|
+
# Accepted values for ``tbl_cross(cell=...)``: the plain count, the three
# percentage bases, then the "n (pct)" combination for each base.
_CELL_STYLES = (
    "n",
    "row_pct",
    "col_pct",
    "total_pct",
    "n_row_pct",
    "n_col_pct",
    "n_total_pct",
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def tbl_cross(
    data: Any,
    *,
    row: str,
    column: str,
    cell: str = "n_col_pct",
    margins: bool = True,
    digits: int = 1,
    labels: dict[str, str] | None = None,
) -> SofraTable:
    """Cross-tabulate ``row`` against ``column``.

    Parameters
    ----------
    data
        Source dataframe (converted with ``to_pandas``).
    row
        Name of the variable whose levels become the table rows.
    column
        Name of the variable whose levels become the table columns.
    cell
        Interior-cell content, one of ``_CELL_STYLES``. The default
        ``"n_col_pct"`` shows "n (column %)". See the module docstring.
    margins
        Add row / column / grand totals.
    digits
        Decimal places for percentages.
    labels
        Optional mapping of name or level to display label. It is used for
        the row and column variable names and for their levels.

    Returns
    -------
    SofraTable
        The cross-tab, built from a frozen ``TableSpec``.

    Raises
    ------
    ValueError
        If ``cell`` is not one of ``_CELL_STYLES``.
    KeyError
        If ``row`` or ``column`` is not a column of ``data``.

    Notes
    -----
    The returned table keeps a rebuild closure over the source ``data``, so
    the modifier methods work directly:

    * ``.add_p()`` re-runs the cross-tab and adds a *p*-value footnote
      from the auto-selected categorical test (Fisher's exact for 2x2,
      Pearson χ² otherwise).
    * ``.add_overall()`` forces ``margins=True`` so the row, column and
      grand totals are shown. It is a no-op when margins are already on,
      which is the default.
    * ``.add_smd()`` raises :class:`NotImplementedError`. SMD compares two
      distributions of a single variable and has no meaning for a
      contingency table; use :func:`tbl_one` for SMD between arms.
    """
    if cell not in _CELL_STYLES:
        raise ValueError(
            f"cell must be one of {_CELL_STYLES}; got {cell!r}."
        )

    frame = to_pandas(data)
    if row not in frame.columns:
        raise KeyError(f"row column {row!r} not in data")
    if column not in frame.columns:
        raise KeyError(f"column column {column!r} not in data")

    options: dict[str, Any] = dict(
        row=row,
        column=column,
        cell=cell,
        margins=margins,
        digits=digits,
        labels=dict(labels or {}),
    )
    # Modifier flags, flipped later by .add_p() / .add_overall() / .add_smd()
    # through SofraTable._with_option(). All start off.
    options.update(dict.fromkeys(("p_value", "overall", "smd"), False))

    return _build_cross(frame, TableSpec(builder="tbl_cross", options=options))
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _build_cross(df: pd.DataFrame, spec: TableSpec) -> SofraTable:
    """Build a tbl_cross SofraTable from a frozen source df and spec.

    ``spec.options`` must carry ``row``, ``column``, ``cell``, ``margins``
    and ``digits``. ``labels`` and the modifier flags ``p_value``,
    ``overall`` and ``smd`` are optional. The returned table holds a
    ``_rebuild`` closure over ``df``, so modifiers can re-run this builder
    with an updated spec.

    Raises
    ------
    NotImplementedError
        If the ``smd`` flag is set. SMD is undefined for a cross-tab.
    """
    row: str = spec.options["row"]
    column: str = spec.options["column"]
    cell: str = spec.options["cell"]
    # add_overall() forces margins on; otherwise the user-passed margins=
    # value is used.
    margins: bool = bool(spec.options["margins"] or spec.options.get("overall"))
    digits: int = int(spec.options["digits"])
    labels: dict[str, str] = dict(spec.options.get("labels") or {})

    # add_smd() is meaningless on a cross-tab. Surface that explicitly
    # rather than emitting a spurious "SMD" column or silently dropping
    # the flag.
    if spec.options.get("smd"):
        raise NotImplementedError(
            "add_smd() is not defined for tbl_cross — SMD measures the "
            "standardised difference between two distributions of a "
            "single variable. For SMD between two arms on the same "
            "variable, use tbl_one(df, by=...).add_smd().",
        )

    def _rebuild(new_spec: TableSpec) -> SofraTable:
        # Modifiers re-run the builder on the same source frame.
        return _build_cross(df, new_spec)

    # Drop rows missing either dimension.
    sub = df[[row, column]].dropna()
    if sub.empty:
        # Minimal placeholder table. It still carries the spec + rebuild so
        # the "rebuild always works" contract holds for modifiers.
        return SofraTable(
            rows=(Row(cells=(
                make_cell(row, align="left", bold=True),
                make_cell("—", value=None, kind="numeric", align="right"),
            )),),
            headers=(HeaderRow(cells=(
                HeaderCell(text=labels.get(row, row), align="left"),
                HeaderCell(text=labels.get(column, column)),
            )),),
            footnotes=("No non-missing rows for the requested cross-tabulation.",),
            metadata={"builder": "tbl_cross"},
            _spec=spec,
            _rebuild=_rebuild,
        )

    row_levels = _ordered_levels(df[row], sub[row])
    col_levels = _ordered_levels(df[column], sub[column])

    ctab = pd.crosstab(sub[row], sub[column])
    ctab = ctab.reindex(index=row_levels, columns=col_levels, fill_value=0)

    row_totals = ctab.sum(axis=1)
    col_totals = ctab.sum(axis=0)
    grand_total = float(ctab.values.sum())

    # ------------------------------------------------------------------
    # Headers
    # ------------------------------------------------------------------
    header_cells = [HeaderCell(text=labels.get(row, row), align="left")]
    for lvl in col_levels:
        header_cells.append(HeaderCell(text=labels.get(lvl, str(lvl))))
    if margins:
        header_cells.append(HeaderCell(text="Total"))
    headers = (HeaderRow(cells=tuple(header_cells)),)

    # Spanning header naming the column variable.
    from ..core.schema import SpanningHeader

    spanning: tuple[SpanningHeader, ...]
    if len(col_levels) > 0:
        spanning = (SpanningHeader(
            label=labels.get(column, column),
            start=1,
            end=len(col_levels) + (1 if margins else 0),
        ),)
    else:  # pragma: no cover — guarded by the empty-sub short-circuit above
        spanning = ()

    # ------------------------------------------------------------------
    # Body rows
    # ------------------------------------------------------------------
    rows: list[Row] = []
    for r_lvl in row_levels:
        # The row total does not change across the inner loop.
        rt = int(row_totals.loc[r_lvl])
        body = [make_cell(labels.get(r_lvl, str(r_lvl)), align="left")]
        for c_lvl in col_levels:
            body.append(_fmt_cross_cell(
                n=int(ctab.loc[r_lvl, c_lvl]),
                row_total=rt,
                col_total=int(col_totals.loc[c_lvl]),
                grand_total=grand_total,
                style=cell,
                digits=digits,
            ))
        if margins:
            body.append(make_cell(
                _fmt_margin(rt, grand_total, style=cell, digits=digits),
                value=rt, kind="numeric", align="right",
            ))
        rows.append(Row(cells=tuple(body)))

    # Margin row
    if margins:
        total_row = [make_cell("Total", align="left", bold=True)]
        for c_lvl in col_levels:
            ct = int(col_totals.loc[c_lvl])
            total_row.append(make_cell(
                _fmt_margin(ct, grand_total, style=cell, digits=digits),
                value=ct, kind="numeric", align="right",
            ))
        total_row.append(make_cell(
            f"{int(grand_total):,}",
            value=int(grand_total),
            kind="numeric", align="right", bold=True,
        ))
        rows.append(Row(cells=tuple(total_row), is_group_header=True))

    footnotes = [_footnote_for(cell)]

    # ------------------------------------------------------------------
    # add_p() — auto-selected categorical test on the full contingency.
    # Reported as a footnote so the table body stays a clean grid.
    # ------------------------------------------------------------------
    metadata: dict[str, Any] = {"builder": "tbl_cross"}
    if spec.options.get("p_value"):
        from .tests import categorical_test

        res = categorical_test(sub[row], sub[column])
        if res.p_value is not None:
            footnotes.append(
                f"{res.test}: p = {fmt_p_value(res.p_value)}",
            )
            # Surface the raw p-value + test name in metadata for
            # programmatic consumers (e.g. golden tests, downstream
            # reports that want the numeric value).
            metadata["p_value"] = float(res.p_value)
            metadata["p_test"] = res.test

    return SofraTable(
        rows=tuple(rows),
        headers=headers,
        spanning_headers=spanning,
        footnotes=tuple(footnotes),
        metadata=metadata,
        _spec=spec,
        _rebuild=_rebuild,
    )


def _ordered_levels(full: pd.Series, observed: pd.Series) -> list[Any]:
    """Return the levels present in ``observed``, in display order.

    If ``full`` is categorical, its declared category order is kept,
    restricted to the levels that actually occur in ``observed``. Otherwise
    the observed unique values are sorted with ``_sort_key``.
    """
    if isinstance(full.dtype, pd.CategoricalDtype):
        # Build the membership set once (the original rebuilt it per category).
        present = set(observed)
        return [lvl for lvl in full.cat.categories if lvl in present]
    return sorted(observed.unique(), key=_sort_key)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
# ----------------------------------------------------------------------
|
|
269
|
+
# Helpers
|
|
270
|
+
# ----------------------------------------------------------------------
|
|
271
|
+
|
|
272
|
+
def _fmt_cross_cell(
    *, n: int, row_total: int, col_total: int, grand_total: float,
    style: str, digits: int,
) -> Any:
    """Format one body cell of the cross-tab according to ``style``.

    Percent-only styles keep the percentage as the cell value (NaN when the
    base is zero). Count-based styles keep ``n`` as the cell value.
    """
    if style == "n":
        text, value = f"{n:,}", n
    elif style in ("row_pct", "col_pct", "total_pct"):
        base = {
            "row_pct": row_total,
            "col_pct": col_total,
            "total_pct": grand_total,
        }[style]
        value = 100.0 * n / base if base else float("nan")
        text = _pct(value, digits)
    elif style in ("n_row_pct", "n_col_pct"):
        base = row_total if style == "n_row_pct" else col_total
        text, value = fmt_n_pct(n, base, digits=digits), n
    elif style == "n_total_pct":
        text, value = fmt_n_pct(n, int(grand_total), digits=digits), n
    else:
        raise ValueError(f"unknown cell style {style!r}")  # pragma: no cover — guarded by top-level cell-style validation
    return make_cell(text, value=value, kind="numeric", align="right")
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _fmt_margin(n: int, grand_total: float, *, style: str, digits: int) -> str:
    """Format a margin (row or column total) cell.

    Percent-only styles show the total's share of the grand total. "n_*"
    styles show "n (overall %)". The plain "n" style shows the count.
    """
    if style in ("row_pct", "col_pct", "total_pct"):
        share = 100.0 * n / grand_total if grand_total else float("nan")
        return _pct(share, digits)
    if style.startswith("n_"):
        return fmt_n_pct(n, int(grand_total), digits=digits)
    return f"{n:,}"
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def _pct(p: float, digits: int) -> str:
    """Render ``p`` as a percentage with ``digits`` decimals.

    ``None`` and float NaN (an undefined percentage) render as an em-dash.
    """
    import math

    undefined = p is None or (isinstance(p, float) and math.isnan(p))
    return "—" if undefined else f"{p:.{digits}f}%"
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def _footnote_for(style: str) -> str:
    """Return the footnote explaining what the cells of ``style`` contain.

    Raises ``KeyError`` for an unknown style.
    """
    footnotes = dict(
        n="Cells: raw counts.",
        row_pct="Cells: row-percent.",
        col_pct="Cells: column-percent.",
        total_pct="Cells: overall percent.",
        n_row_pct="Cells: n (row-%).",
        n_col_pct="Cells: n (column-%).",
        n_total_pct="Cells: n (overall-%).",
    )
    return footnotes[style]
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _sort_key(x: Any) -> tuple[int, Any]:
    """Sort key that puts numerics first, then strings, then everything else.

    Avoids ``TypeError`` when sorting mixed-type level values. NumPy numeric
    scalars (e.g. ``np.int64`` elements from ``Series.unique()``) count as
    numbers. Previously they fell into the ``repr`` branch, so integer
    levels sorted lexicographically (1, 10, 2, ...).
    """
    import numbers

    if isinstance(x, bool):
        return (0, int(x))
    # NumPy registers its integer / floating scalars with the ``numbers``
    # ABCs, so this covers Python and NumPy numerics alike.
    if isinstance(x, numbers.Real):
        return (0, float(x))
    if isinstance(x, str):
        return (1, x)
    return (2, repr(x))
|