pysofra 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pysofra/__init__.py +82 -0
- pysofra/core/__init__.py +14 -0
- pysofra/core/compose.py +167 -0
- pysofra/core/format.py +155 -0
- pysofra/core/frames.py +69 -0
- pysofra/core/schema.py +128 -0
- pysofra/core/table.py +924 -0
- pysofra/io/__init__.py +1 -0
- pysofra/models/__init__.py +6 -0
- pysofra/models/extract.py +249 -0
- pysofra/models/pool.py +119 -0
- pysofra/models/regression.py +507 -0
- pysofra/models/survival.py +395 -0
- pysofra/models/uvregression.py +438 -0
- pysofra/notebook/__init__.py +6 -0
- pysofra/plot/__init__.py +23 -0
- pysofra/plot/_backend.py +32 -0
- pysofra/plot/forest.py +159 -0
- pysofra/plot/inline.py +171 -0
- pysofra/plot/km.py +249 -0
- pysofra/render/__init__.py +28 -0
- pysofra/render/_zip_determinism.py +57 -0
- pysofra/render/base.py +22 -0
- pysofra/render/docx.py +286 -0
- pysofra/render/html.py +442 -0
- pysofra/render/image.py +130 -0
- pysofra/render/latex.py +253 -0
- pysofra/render/markdown.py +128 -0
- pysofra/render/pptx.py +340 -0
- pysofra/render/xlsx.py +226 -0
- pysofra/summary/__init__.py +6 -0
- pysofra/summary/calibrate.py +214 -0
- pysofra/summary/design.py +246 -0
- pysofra/summary/effect_size.py +187 -0
- pysofra/summary/extras.py +745 -0
- pysofra/summary/smd.py +133 -0
- pysofra/summary/stats.py +135 -0
- pysofra/summary/tbl_cross.py +339 -0
- pysofra/summary/tbl_one.py +1220 -0
- pysofra/summary/tbl_summary.py +51 -0
- pysofra/summary/tests.py +370 -0
- pysofra/summary/typing.py +129 -0
- pysofra/summary/weights.py +161 -0
- pysofra/themes/__init__.py +5 -0
- pysofra/themes/registry.py +272 -0
- pysofra-0.1.0a1.dist-info/METADATA +301 -0
- pysofra-0.1.0a1.dist-info/RECORD +50 -0
- pysofra-0.1.0a1.dist-info/WHEEL +4 -0
- pysofra-0.1.0a1.dist-info/licenses/LICENSE +674 -0
- pysofra-0.1.0a1.dist-info/licenses/NOTICE +18 -0
|
@@ -0,0 +1,507 @@
|
|
|
1
|
+
"""Regression tables — equivalent to ``gtsummary::tbl_regression``.
|
|
2
|
+
|
|
3
|
+
Supports any of:
|
|
4
|
+
|
|
5
|
+
* **statsmodels** — ``OLS``, ``GLM``, ``Logit``, ``Probit``, ``Poisson``,
|
|
6
|
+
``NegativeBinomial``, etc. (anything that exposes
|
|
7
|
+
``.params`` / ``.pvalues`` / ``.conf_int()``).
|
|
8
|
+
* **lifelines** — ``CoxPHFitter``, ``WeibullAFTFitter``,
|
|
9
|
+
``LogNormalAFTFitter``, and similar regression fitters with ``.summary``.
|
|
10
|
+
* **sklearn** — ``LinearRegression``, ``LogisticRegression`` (binary),
|
|
11
|
+
``Lasso``, ``Ridge`` etc. Point estimates only; sklearn does not expose
|
|
12
|
+
confidence intervals.
|
|
13
|
+
|
|
14
|
+
Pass a single model for a one-model table, or a list for a side-by-side
|
|
15
|
+
multi-model comparison.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
|
|
24
|
+
from ..core.format import fmt_number, fmt_p_value
|
|
25
|
+
from ..core.schema import HeaderCell, HeaderRow, Row, make_cell
|
|
26
|
+
from ..core.table import SofraTable, TableSpec
|
|
27
|
+
from .extract import ModelSummary, extract
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def tbl_regression(
    model: Any | list[Any],
    *,
    exponentiate: bool | None = None,
    conf_level: float = 0.95,
    digits: int = 2,
    labels: dict[str, str] | None = None,
    intercept: bool = False,
    estimate_label: str | None = None,
    model_labels: list[str] | None = None,
    design: Any = None,
    data: Any = None,
) -> SofraTable:
    """Build a regression results table.

    Parameters
    ----------
    model
        A fitted model, or a list/tuple of fitted models for a multi-model
        side-by-side table.
    exponentiate
        If ``True``, exponentiate point estimates and CI bounds (ORs / HRs
        / IRRs). ``None`` (default) defers to each model's
        ``natural_exponentiate`` flag as reported by :func:`extract`
        (intended to be ``True`` for log-link models such as Logit /
        Poisson / Cox, ``False`` otherwise).
    conf_level
        Confidence level for the CI column; must lie strictly between 0
        and 1 (default 0.95).
    digits
        Decimal places for estimates and CI bounds.
    labels
        Mapping from coefficient name → display label. Shared across all
        models in a multi-model table.
    intercept
        Whether to include the intercept row.
    estimate_label
        Custom header label for the estimate column. By default a label is
        derived from the detected model family (e.g. ``OR``, ``HR``,
        ``IRR``, ``β``, ``Estimate``).
    model_labels
        For multi-model tables, the spanning-header label for each model
        (defaults to ``Model 1``, ``Model 2``, ...). Ignored when a single
        model is given.
    design
        Optional :class:`~pysofra.SurveyDesign`. When provided, every fit
        is re-estimated with the design's sampling weights and a robust
        variance estimator (cluster-robust when the design has clusters,
        HC1 otherwise), matching ``survey::svyglm`` to first order. Only
        statsmodels-style results are supported.
    data
        Source data holding the design's weight / cluster columns. It is
        required whenever ``design`` has weights or cluster columns. Pass
        either one DataFrame shared by every model, or a list/tuple with
        one DataFrame per model when each model was fit on a different
        slice. Each DataFrame must contain the same rows, in the same
        order, as the data its model was fit on.

    Returns
    -------
    SofraTable
        For a single model, the table metadata also carries the fitted
        model (after any design refit) under ``"model"`` and the
        ``design`` under ``"design"``, so ``add_global_p()`` can run Wald
        F-tests.

    Raises
    ------
    ValueError
        If no model is given, ``conf_level`` is outside (0, 1), ``data``
        is a list whose length differs from the number of models, a
        design needs ``data=`` that was not supplied, or ``model_labels``
        has the wrong length.
    """
    models = list(model) if isinstance(model, (list, tuple)) else [model]
    if not models:
        raise ValueError("tbl_regression requires at least one model.")
    # Reject out-of-range levels up front rather than passing them on to
    # extract() and the CI header.
    if not 0 < conf_level < 1:
        raise ValueError(
            f"conf_level must be strictly between 0 and 1 (got {conf_level!r})."
        )

    if design is not None:
        # ``data`` may be a single DataFrame (shared by every model) or a
        # list of one DataFrame per model when each fit was on a different
        # slice.
        if isinstance(data, (list, tuple)):
            if len(data) != len(models):
                raise ValueError(
                    "When data= is a list it must have one DataFrame per model "
                    f"(got {len(data)} for {len(models)} models)."
                )
            datas = list(data)
        else:
            datas = [data] * len(models)
        models = [
            _refit_with_design(m, design, d)
            for m, d in zip(models, datas, strict=True)
        ]

    summaries = [extract(m, conf_level=conf_level) for m in models]
    labels = dict(labels or {})

    if len(summaries) == 1:
        tbl = _build_single(
            summaries[0],
            exponentiate=exponentiate,
            conf_level=conf_level,
            digits=digits,
            labels=labels,
            intercept=intercept,
            estimate_label=estimate_label,
        )
        # Attach the fitted model so add_global_p() can run Wald F-tests.
        from dataclasses import replace as _replace

        new_md = dict(tbl.metadata)
        new_md["model"] = models[0]
        new_md["design"] = design
        return _replace(tbl, metadata=new_md)

    return _build_multi(
        summaries,
        exponentiate=exponentiate,
        conf_level=conf_level,
        digits=digits,
        labels=labels,
        intercept=intercept,
        estimate_label=estimate_label,
        model_labels=model_labels,
    )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
# ----------------------------------------------------------------------
|
|
135
|
+
# Single-model
|
|
136
|
+
# ----------------------------------------------------------------------
|
|
137
|
+
|
|
138
|
+
def _build_single(
    summary: ModelSummary,
    *,
    exponentiate: bool | None,
    conf_level: float,
    digits: int,
    labels: dict[str, str],
    intercept: bool,
    estimate_label: str | None,
) -> SofraTable:
    """Lay out a one-model table with columns Variable | estimate | CI | p-value.

    ``exponentiate=None`` falls back to ``summary.natural_exponentiate``.
    Intercept rows are dropped unless ``intercept`` is true. The estimate
    header is ``estimate_label`` when given, otherwise a family-derived
    default.
    """
    if exponentiate is None:
        use_exp = summary.natural_exponentiate
    else:
        use_exp = bool(exponentiate)
    est_header = estimate_label or _default_estimate_label(summary.family, use_exp)
    ci_header = f"{int(round(conf_level * 100))}% CI"

    header_row = HeaderRow(
        cells=(
            HeaderCell(text="Variable", align="left"),
            HeaderCell(text=est_header),
            HeaderCell(text=ci_header),
            HeaderCell(text="p-value"),
        )
    )

    # One formatted row per coefficient, skipping intercepts on request.
    body_rows = tuple(
        _render_coef_row(summary, coef, exp=use_exp, digits=digits, labels=labels)
        for coef in summary.estimates.index
        if intercept or not _is_intercept(coef)
    )

    return SofraTable(
        rows=body_rows,
        headers=(header_row,),
        footnotes=tuple(
            _footnotes(summary.family, use_exp, conf_level, est_header, has_ci=True)
        ),
        metadata={"builder": "tbl_regression", "family": summary.family},
        _spec=TableSpec(
            builder="tbl_regression",
            options={
                "exponentiate": use_exp,
                "conf_level": conf_level,
                "digits": digits,
                "intercept": intercept,
            },
        ),
        _rebuild=None,
    )
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# ----------------------------------------------------------------------
|
|
188
|
+
# Multi-model
|
|
189
|
+
# ----------------------------------------------------------------------
|
|
190
|
+
|
|
191
|
+
def _build_multi(
    summaries: list[ModelSummary],
    *,
    exponentiate: bool | None,
    conf_level: float,
    digits: int,
    labels: dict[str, str],
    intercept: bool,
    estimate_label: str | None,
    model_labels: list[str] | None,
) -> SofraTable:
    """Lay out several models side by side.

    Each model gets an estimate / CI / p column triple under a spanning
    header. Rows are the union of coefficient names in order of first
    appearance across models. A coefficient missing from a model gets
    three ``"—"`` placeholder cells in that model's columns. Coefficient
    cells are produced by :func:`_render_coef_row`, so they are formatted
    exactly like the single-model table.

    Raises
    ------
    ValueError
        If ``model_labels`` is given with a length different from
        ``summaries``.
    """
    from ..core.schema import SpanningHeader

    if model_labels is not None and len(model_labels) != len(summaries):
        raise ValueError(
            f"model_labels has {len(model_labels)} entries but {len(summaries)} models."
        )
    model_labels = model_labels or [f"Model {i + 1}" for i in range(len(summaries))]
    ci_header = f"{int(round(conf_level * 100))}% CI"

    # Union of coefficient names, ordered by first appearance across models
    # (dict preserves insertion order and ignores repeats).
    coef_order = list(
        dict.fromkeys(
            n
            for s in summaries
            for n in s.estimates.index
            if intercept or not _is_intercept(n)
        )
    )

    # Per-model exponentiate decision (each model may have a different link).
    exp_per = [
        s.natural_exponentiate if exponentiate is None else bool(exponentiate)
        for s in summaries
    ]
    labels_per = [
        estimate_label or _default_estimate_label(s.family, e)
        for s, e in zip(summaries, exp_per, strict=True)
    ]

    # Header: Variable, then for each model: {label}, CI, p.
    header_cells = [HeaderCell(text="Variable", align="left")]
    spanning = []
    col = 1
    for label, ml in zip(labels_per, model_labels, strict=True):
        # (label, first column, last column) of this model's triple.
        spanning.append((ml, col, col + 2))
        header_cells.append(HeaderCell(text=label))
        header_cells.append(HeaderCell(text=ci_header))
        header_cells.append(HeaderCell(text="p"))
        col += 3
    headers = (HeaderRow(cells=tuple(header_cells)),)
    spanning_headers = tuple(
        SpanningHeader(label=ml, start=s, end=e) for ml, s, e in spanning
    )

    rows: list[Row] = []
    for name in coef_order:
        cells = [make_cell(labels.get(name, name), align="left")]
        for s, e in zip(summaries, exp_per, strict=True):
            if name in s.estimates.index:
                # Reuse the single-model formatter and drop its label cell.
                coef_row = _render_coef_row(
                    s, name, exp=e, digits=digits, labels=labels,
                )
                cells.extend(coef_row.cells[1:])
            else:
                cells.extend(
                    make_cell("—", value=None, align="right") for _ in range(3)
                )
        rows.append(Row(cells=tuple(cells)))

    footnotes: list[str] = [
        f"{ml}: {s.family}{' (exponentiated)' if e else ''}."
        for s, e, ml in zip(summaries, exp_per, model_labels, strict=True)
    ]
    footnotes.append(f"CI = {int(round(conf_level * 100))}% confidence interval.")

    return SofraTable(
        rows=tuple(rows),
        headers=headers,
        spanning_headers=spanning_headers,
        footnotes=tuple(footnotes),
        metadata={"builder": "tbl_regression", "n_models": len(summaries)},
    )
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
# ----------------------------------------------------------------------
|
|
287
|
+
# Helpers
|
|
288
|
+
# ----------------------------------------------------------------------
|
|
289
|
+
|
|
290
|
+
def _render_coef_row(
    summary: ModelSummary,
    name: str,
    *,
    exp: bool,
    digits: int,
    labels: dict[str, str],
) -> Row:
    """Format one coefficient as a row: label | estimate | CI | p-value.

    The CI bounds and the p-value fall back to NaN when ``summary`` has no
    entry for ``name``. When ``exp`` is true, the estimate and both bounds
    are shown on the exponentiated scale. The p-value is never
    transformed. The display label is ``labels[name]`` when present,
    otherwise ``name`` itself.
    """

    def _lookup(series):
        # NaN placeholder when this model does not report the coefficient.
        return float(series[name]) if name in series.index else float("nan")

    point = float(summary.estimates[name])
    lower = _lookup(summary.ci_lo)
    upper = _lookup(summary.ci_hi)
    pval = _lookup(summary.pvalues)

    if exp:
        # Extreme coefficients (e.g. perfect-separation logits) can overflow
        # to inf; silence numpy's overflow warning rather than spamming it.
        with np.errstate(over="ignore"):
            point, lower, upper = np.exp(point), np.exp(lower), np.exp(upper)

    display_name = labels.get(name, name)
    label_cell = make_cell(display_name, align="left")
    estimate_cell = make_cell(
        fmt_number(point, digits), value=point, kind="numeric", align="right",
    )
    ci_cell = make_cell(
        f"{fmt_number(lower, digits)}, {fmt_number(upper, digits)}",
        value=(lower, upper),
        kind="ci",
        align="right",
    )
    p_cell = make_cell(fmt_p_value(pval), value=pval, kind="p_value", align="right")
    return Row(cells=(label_cell, estimate_cell, ci_cell, p_cell))
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def _footnotes(family: str, exp: bool, conf_level: float, label: str,
               has_ci: bool) -> list[str]:
    """Build the footnote lines for a single-model regression table.

    Parameters
    ----------
    family
        Model family label; a ``"Model: <family>."`` line is added when it
        is non-empty.
    exp
        Whether the estimate column is exponentiated; if so ``label`` is
        defined as the exponentiated coefficient.
    conf_level
        Confidence level, rendered as a whole-number percentage.
    label
        Header text of the estimate column.
    has_ci
        Whether the table shows confidence intervals. When false, the CI
        definition is omitted (there is no CI column to define) and a note
        says CIs are unavailable.

    Returns
    -------
    list[str]
        Footnote strings in display order.
    """
    # Column definitions share one footnote line, joined by "; ".
    definitions: list[str] = []
    if exp:
        definitions.append(f"{label} = exponentiated coefficient")
    if has_ci:
        definitions.append(
            f"CI = {int(round(conf_level * 100))}% confidence interval"
        )

    out: list[str] = []
    if definitions:
        out.append("; ".join(definitions) + ".")
    if family:
        out.append(f"Model: {family}.")
    if not has_ci:
        out.append("Note: CIs not available for this model type.")
    return out
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def _is_intercept(name: str) -> bool:
    """Report whether *name* is a conventional intercept/constant label.

    Matching is case-insensitive against ``intercept``, ``const`` and
    ``(intercept)``. Non-string labels are compared through ``str()``.
    """
    lowered = str(name).lower()
    return lowered == "intercept" or lowered == "const" or lowered == "(intercept)"
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def _default_estimate_label(family_label: str, exponentiate: bool) -> str:
    """Pick the default header text for the estimate column.

    Parameters
    ----------
    family_label
        Free-text model family description, matched case-insensitively by
        substring.
    exponentiate
        Whether estimates are displayed on the exponentiated scale.

    Returns
    -------
    str
        Exponentiated: ``HR`` (Cox), ``TR`` (AFT time ratio), ``IRR``
        (Poisson / negative binomial), ``OR`` (logit / binomial /
        logistic), otherwise ``exp(β)``. Not exponentiated: ``β`` for
        linear-model families, otherwise ``Estimate``.
    """
    fl = family_label.lower()
    # Collapse separators so "Negative Binomial", "negative_binomial" and
    # "NegativeBinomial" all match the same token.
    compact = fl.replace(" ", "").replace("_", "").replace("-", "")
    if exponentiate:
        if "cox" in fl or "phreg" in fl:
            return "HR"
        if "weibull" in fl or "lognormal" in fl or "loglogistic" in fl:
            # In accelerated-failure-time models exp(coef) is a time ratio
            # (acceleration factor), not a hazard ratio.
            return "TR"
        # Must run before the generic "binomial" test below:
        # "negativebinomial" contains "binomial" and was previously
        # mislabelled as an odds ratio.
        if "poisson" in fl or "negativebinomial" in compact:
            return "IRR"
        if "probit" in fl:
            # exp(probit coefficient) has no odds-ratio interpretation.
            return "exp(β)"
        if "logit" in fl or "binomial" in fl or "logistic" in fl:
            return "OR"
        return "exp(β)"
    if "ols" in fl or "linear" in fl or "gls" in fl or "wls" in fl:
        return "β"
    return "Estimate"
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
# ----------------------------------------------------------------------
|
|
362
|
+
# Design-aware refit (svyglm parity, first-order)
|
|
363
|
+
# ----------------------------------------------------------------------
|
|
364
|
+
|
|
365
|
+
def _refit_with_design(model: Any, design: Any, data: Any) -> Any:
    """Re-fit a statsmodels model using design-aware point estimates + SEs.

    Approximates R's ``survey::svyglm`` to first order:

    * If ``design.weights`` is set, the model is re-fit with the sampling
      weights folded in: ``WLS(weights=)`` for OLS, and
      ``GLM(freq_weights=)`` for GLM / Logit / Poisson. Point estimates
      therefore match a weighted analysis (not the original unweighted
      fit). Any offset / exposure on the original model is carried into
      the GLM refit, so the weighted fit estimates the same mean model.
    * If ``design.cluster`` is set, the variance estimator switches to
      cluster-robust (``cov_type='cluster'``) keyed by
      ``design.primary_cluster``. Otherwise HC1 is used.

    ``design.strata`` and ``design.fpc`` are not used by this refit.
    Full strata-aware Taylor linearisation for GLM-family regression is
    outside the scope of statsmodels' built-in variance estimators. This
    function neither refuses to run nor warns when they are set; the
    cluster / HC1 SE is used as a (generally conservative) approximation.
    For the exact stratified estimator use R's ``survey::svyglm``
    directly.

    Parameters
    ----------
    model
        A fitted statsmodels-style results object exposing ``.model`` with
        a ``.fit`` method.
    design
        Survey design; ``weights``, ``cluster`` and ``primary_cluster`` are
        read from it.
    data
        DataFrame holding the design columns, row-aligned with the data the
        model was fit on. Required when the design has weights or clusters.

    Returns
    -------
    The results object returned by the refit's ``.fit`` call.

    Raises
    ------
    ValueError
        If ``model`` is not statsmodels-style, ``data`` is missing, or a
        design column's length differs from the model's observations.
    NotImplementedError
        If weights are requested for a model class other than OLS, GLM,
        Logit or Poisson.
    """
    import pandas as pd

    inner = getattr(model, "model", None)
    if inner is None or not hasattr(inner, "fit"):
        raise ValueError(
            "design= currently supports statsmodels-style results only."
        )

    # Use the original DataFrames where available so coefficient names
    # survive the refit (statsmodels uses the column index as the params
    # index). Fall back to the plain arrays when the model has no data
    # container or it lacks the original objects.
    model_data = getattr(inner, "data", None)
    endog = getattr(model_data, "orig_endog", None)
    if endog is None:
        endog = inner.endog
    exog = getattr(model_data, "orig_exog", None)
    if exog is None:
        exog = inner.exog

    # ------------------------------------------------------------------
    # Build the weight + cluster vectors from `data`. They must be in the
    # same row order as endog/exog; we trust the caller to pass the same
    # rows the model was fit on and only verify the lengths.
    # ------------------------------------------------------------------
    weights_arr = None
    if design.weights is not None:
        if data is None:
            raise ValueError(
                "Pass data= to tbl_regression when design has weights."
            )
        w = pd.to_numeric(data[design.weights], errors="coerce").to_numpy(dtype=float)
        if len(w) != len(endog):
            raise ValueError(
                f"design.weights length {len(w)} does not match the model's "
                f"endog length {len(endog)}; pass the same DataFrame the "
                f"model was fit on."
            )
        # Non-finite (NaN / inf) and non-positive weights are silently set
        # to 0, so those rows carry no weight in the refit.
        weights_arr = np.where(np.isfinite(w) & (w > 0), w, 0.0)

    cluster_arr = None
    if design.cluster is not None:
        # These raises duplicate checks that the weights branch performs
        # whenever weights are set; they matter only for weight-free designs.
        if data is None:  # pragma: no cover — guarded by required design.weights
            raise ValueError(
                "Pass data= to tbl_regression when design has cluster columns."
            )
        clust = data[design.primary_cluster].to_numpy()
        if len(clust) != len(endog):  # pragma: no cover — guarded by upstream length check
            raise ValueError(
                f"design.cluster length {len(clust)} does not match the model's "
                f"endog length {len(endog)}."
            )
        cluster_arr = clust

    cov_kwds: dict[str, Any] = {}
    if cluster_arr is not None:
        cov_type = "cluster"
        cov_kwds["groups"] = cluster_arr
    else:
        cov_type = "HC1"

    if weights_arr is None:
        # Weight-free path: keep the original model (including any offset /
        # exposure) and only swap in the design-based variance estimator.
        return inner.fit(cov_type=cov_type, cov_kwds=cov_kwds)

    # ------------------------------------------------------------------
    # Weighted path. Dispatch on the original model class so the user's
    # choice of OLS / Logit / Poisson / GLM is preserved.
    # ------------------------------------------------------------------
    try:
        import statsmodels.api as sm
    except ImportError as e:  # pragma: no cover — guarded by upstream import
        raise ImportError("design= requires statsmodels.") from e

    # statsmodels emits SpecificationWarning when combining freq_weights
    # with a robust cov_type in GLM. Design-based regression (survey::svyglm,
    # Stata svy) uses exactly this combination, so silence it locally.
    import warnings as _w
    try:
        from statsmodels.tools.sm_exceptions import SpecificationWarning
    except ImportError:  # pragma: no cover
        SpecificationWarning = Warning

    def _fit(refit: Any) -> Any:
        with _w.catch_warnings():
            _w.simplefilter("ignore", SpecificationWarning)
            return refit.fit(cov_type=cov_type, cov_kwds=cov_kwds)

    inner_name = type(inner).__name__
    if inner_name == "OLS":
        return _fit(sm.WLS(endog, exog, weights=weights_arr))

    if inner_name == "GLM":
        family = inner.family
    elif inner_name == "Logit":
        family = sm.families.Binomial()
    elif inner_name == "Poisson":
        family = sm.families.Poisson()
    else:
        raise NotImplementedError(  # pragma: no cover — exotic statsmodels model
            f"design= with weights does not yet support {inner_name!r}. "
            f"Supported model classes are OLS, GLM, Logit, Poisson. "
            f"For other models, either drop the weights from the design or "
            f"open an issue describing your model."
        )

    # Carry the original offset / exposure into the GLM refit; a bare
    # GLM(endog, exog) would silently fit a different model. statsmodels
    # keeps ``exposure`` on the model in log form (the GLM and count-model
    # constructors apply np.log), so both terms enter the linear predictor
    # additively and fold into one offset.
    offset_arr = None
    for term in (getattr(inner, "offset", None), getattr(inner, "exposure", None)):
        if term is not None:
            term = np.asarray(term, dtype=float)
            offset_arr = term if offset_arr is None else offset_arr + term

    return _fit(sm.GLM(endog, exog, family=family, offset=offset_arr,
                       freq_weights=weights_arr))
|