openstat-cli 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. openstat/__init__.py +3 -0
  2. openstat/__main__.py +4 -0
  3. openstat/backends/__init__.py +16 -0
  4. openstat/backends/duckdb_backend.py +70 -0
  5. openstat/backends/polars_backend.py +52 -0
  6. openstat/cli.py +92 -0
  7. openstat/commands/__init__.py +82 -0
  8. openstat/commands/adv_stat_cmds.py +1255 -0
  9. openstat/commands/advanced_ml_cmds.py +576 -0
  10. openstat/commands/advreg_cmds.py +207 -0
  11. openstat/commands/alias_cmds.py +135 -0
  12. openstat/commands/arch_cmds.py +82 -0
  13. openstat/commands/arules_cmds.py +111 -0
  14. openstat/commands/automodel_cmds.py +212 -0
  15. openstat/commands/backend_cmds.py +82 -0
  16. openstat/commands/base.py +170 -0
  17. openstat/commands/bayes_cmds.py +71 -0
  18. openstat/commands/causal_cmds.py +269 -0
  19. openstat/commands/cluster_cmds.py +152 -0
  20. openstat/commands/data_cmds.py +996 -0
  21. openstat/commands/datamanip_cmds.py +672 -0
  22. openstat/commands/dataquality_cmds.py +174 -0
  23. openstat/commands/datetime_cmds.py +176 -0
  24. openstat/commands/dimreduce_cmds.py +184 -0
  25. openstat/commands/discrete_cmds.py +149 -0
  26. openstat/commands/dsl_cmds.py +143 -0
  27. openstat/commands/epi_cmds.py +93 -0
  28. openstat/commands/equiv_tobit_cmds.py +94 -0
  29. openstat/commands/esttab_cmds.py +196 -0
  30. openstat/commands/export_beamer_cmds.py +142 -0
  31. openstat/commands/export_cmds.py +201 -0
  32. openstat/commands/export_extra_cmds.py +240 -0
  33. openstat/commands/factor_cmds.py +180 -0
  34. openstat/commands/groupby_cmds.py +155 -0
  35. openstat/commands/help_cmds.py +237 -0
  36. openstat/commands/i18n_cmds.py +43 -0
  37. openstat/commands/import_extra_cmds.py +561 -0
  38. openstat/commands/influence_cmds.py +134 -0
  39. openstat/commands/iv_cmds.py +106 -0
  40. openstat/commands/manova_cmds.py +105 -0
  41. openstat/commands/mediate_cmds.py +233 -0
  42. openstat/commands/meta_cmds.py +284 -0
  43. openstat/commands/mi_cmds.py +228 -0
  44. openstat/commands/mixed_cmds.py +79 -0
  45. openstat/commands/mixture_changepoint_cmds.py +166 -0
  46. openstat/commands/ml_adv_cmds.py +147 -0
  47. openstat/commands/ml_cmds.py +178 -0
  48. openstat/commands/model_eval_cmds.py +142 -0
  49. openstat/commands/network_cmds.py +288 -0
  50. openstat/commands/nlquery_cmds.py +161 -0
  51. openstat/commands/nonparam_cmds.py +149 -0
  52. openstat/commands/outreg_cmds.py +247 -0
  53. openstat/commands/panel_cmds.py +141 -0
  54. openstat/commands/pdf_cmds.py +226 -0
  55. openstat/commands/pipeline_cmds.py +319 -0
  56. openstat/commands/plot_cmds.py +189 -0
  57. openstat/commands/plugin_cmds.py +79 -0
  58. openstat/commands/posthoc_cmds.py +153 -0
  59. openstat/commands/power_cmds.py +172 -0
  60. openstat/commands/profile_cmds.py +246 -0
  61. openstat/commands/rbridge_cmds.py +81 -0
  62. openstat/commands/regex_cmds.py +104 -0
  63. openstat/commands/report_cmds.py +48 -0
  64. openstat/commands/repro_cmds.py +129 -0
  65. openstat/commands/resampling_cmds.py +109 -0
  66. openstat/commands/reshape_cmds.py +223 -0
  67. openstat/commands/sem_cmds.py +177 -0
  68. openstat/commands/stat_cmds.py +1040 -0
  69. openstat/commands/stata_import_cmds.py +215 -0
  70. openstat/commands/string_cmds.py +124 -0
  71. openstat/commands/surv_cmds.py +145 -0
  72. openstat/commands/survey_cmds.py +153 -0
  73. openstat/commands/textanalysis_cmds.py +192 -0
  74. openstat/commands/ts_adv_cmds.py +136 -0
  75. openstat/commands/ts_cmds.py +195 -0
  76. openstat/commands/tui_cmds.py +111 -0
  77. openstat/commands/ux_cmds.py +191 -0
  78. openstat/commands/validate_cmds.py +270 -0
  79. openstat/commands/viz_adv_cmds.py +312 -0
  80. openstat/commands/viz_extra_cmds.py +251 -0
  81. openstat/commands/watch_cmds.py +69 -0
  82. openstat/config.py +106 -0
  83. openstat/dsl/__init__.py +0 -0
  84. openstat/dsl/parser.py +332 -0
  85. openstat/dsl/tokenizer.py +105 -0
  86. openstat/i18n.py +120 -0
  87. openstat/io/__init__.py +0 -0
  88. openstat/io/loader.py +187 -0
  89. openstat/jupyter/__init__.py +18 -0
  90. openstat/jupyter/display.py +18 -0
  91. openstat/jupyter/magic.py +60 -0
  92. openstat/logging_config.py +59 -0
  93. openstat/plots/__init__.py +0 -0
  94. openstat/plots/plotter.py +437 -0
  95. openstat/plots/surv_plots.py +32 -0
  96. openstat/plots/ts_plots.py +59 -0
  97. openstat/plugins/__init__.py +5 -0
  98. openstat/plugins/manager.py +69 -0
  99. openstat/repl.py +457 -0
  100. openstat/reporting/__init__.py +0 -0
  101. openstat/reporting/eda.py +208 -0
  102. openstat/reporting/report.py +67 -0
  103. openstat/script_runner.py +319 -0
  104. openstat/session.py +133 -0
  105. openstat/stats/__init__.py +0 -0
  106. openstat/stats/advanced_regression.py +269 -0
  107. openstat/stats/arch_garch.py +84 -0
  108. openstat/stats/bayesian.py +103 -0
  109. openstat/stats/causal.py +258 -0
  110. openstat/stats/clustering.py +206 -0
  111. openstat/stats/discrete.py +311 -0
  112. openstat/stats/epidemiology.py +119 -0
  113. openstat/stats/equiv_tobit.py +163 -0
  114. openstat/stats/factor.py +174 -0
  115. openstat/stats/imputation.py +282 -0
  116. openstat/stats/influence.py +78 -0
  117. openstat/stats/iv.py +131 -0
  118. openstat/stats/manova.py +124 -0
  119. openstat/stats/mixed.py +128 -0
  120. openstat/stats/ml.py +275 -0
  121. openstat/stats/ml_advanced.py +117 -0
  122. openstat/stats/model_eval.py +183 -0
  123. openstat/stats/models.py +1342 -0
  124. openstat/stats/nonparametric.py +130 -0
  125. openstat/stats/panel.py +179 -0
  126. openstat/stats/power.py +295 -0
  127. openstat/stats/resampling.py +203 -0
  128. openstat/stats/survey.py +213 -0
  129. openstat/stats/survival.py +196 -0
  130. openstat/stats/timeseries.py +142 -0
  131. openstat/stats/ts_advanced.py +114 -0
  132. openstat/types.py +11 -0
  133. openstat/web/__init__.py +1 -0
  134. openstat/web/app.py +117 -0
  135. openstat/web/session_manager.py +73 -0
  136. openstat/web/static/app.js +117 -0
  137. openstat/web/static/index.html +38 -0
  138. openstat/web/static/style.css +103 -0
  139. openstat_cli-1.0.0.dist-info/METADATA +748 -0
  140. openstat_cli-1.0.0.dist-info/RECORD +143 -0
  141. openstat_cli-1.0.0.dist-info/WHEEL +4 -0
  142. openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
  143. openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1342 @@
1
+ """Statistical models: OLS, Logit/Probit, Poisson, Negative Binomial, quantile regression, marginal effects, hypothesis tests.
2
+
3
+ Includes diagnostics:
4
+ - Multicollinearity detection (condition number)
5
+ - Convergence check (logit)
6
+ - Minimum observations check
7
+ - Stepwise variable selection (forward/backward)
8
+ - Residual diagnostics
9
+
10
+ Hypothesis tests:
11
+ - t-test (one-sample, two-sample, paired)
12
+ - Chi-square test of independence
13
+ - One-way ANOVA
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import io
19
+ from dataclasses import dataclass, field
20
+
21
+ import numpy as np
22
+ import polars as pl
23
+ import statsmodels.api as sm
24
+ from scipy import stats as sp_stats
25
+ from rich.table import Table
26
+ from rich.console import Console
27
+
28
+ from openstat.config import get_config
29
+
30
+
31
# Characters with a special meaning in LaTeX text mode, mapped to safe spellings.
# Applied with str.translate so each character is rewritten exactly once.
_LATEX_SPECIALS = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
})


@dataclass
class FitResult:
    """Holds a fitted model result for reporting.

    The per-coefficient dictionaries (``params``, ``std_errors``, ``t_values``,
    ``p_values``, ``conf_int_low``, ``conf_int_high``) are expected to share the
    keys of ``params``; every renderer iterates ``params`` and looks the same
    key up in the others.  Optional fit statistics are only rendered when they
    are not ``None``.
    """

    model_type: str  # "OLS" or "Logit"
    formula: str
    dep_var: str
    indep_vars: list[str]
    n_obs: int
    params: dict[str, float]
    std_errors: dict[str, float]
    t_values: dict[str, float]
    p_values: dict[str, float]
    conf_int_low: dict[str, float]
    conf_int_high: dict[str, float]
    r_squared: float | None = None  # OLS only
    adj_r_squared: float | None = None  # OLS only
    f_statistic: float | None = None  # OLS only
    f_pvalue: float | None = None  # OLS only
    aic: float | None = None
    bic: float | None = None
    pseudo_r2: float | None = None  # Logit/Probit only
    log_likelihood: float | None = None
    dispersion: float | None = None  # Negative Binomial alpha
    warnings: list[str] = field(default_factory=list)

    @staticmethod
    def _star_count(p_value: float) -> int:
        """Return 3, 2, 1 or 0 for p < 0.001, p < 0.01, p < 0.05, otherwise."""
        for stars, cutoff in ((3, 0.001), (2, 0.01), (1, 0.05)):
            if p_value < cutoff:
                return stars
        return 0

    @staticmethod
    def _tex(text: str) -> str:
        """Escape LaTeX special characters in *text*."""
        return text.translate(_LATEX_SPECIALS)

    def _rich_table(self) -> Table:
        """Build the coefficient table shared by the text and HTML renderers."""
        table = Table(title=f"{self.model_type}: {self.formula}")
        table.add_column("Variable", style="cyan")
        for heading in (
            "Coef", "Std.Err", "t/z", "P>|t|", "[95% CI Low]", "[95% CI High]",
        ):
            table.add_column(heading, justify="right")

        for var in self.params:
            pv = self.p_values[var]
            stars = self._star_count(pv)
            sig = " " + "*" * stars if stars else ""
            table.add_row(
                var,
                f"{self.params[var]:.4f}",
                f"{self.std_errors[var]:.4f}",
                f"{self.t_values[var]:.3f}",
                f"{pv:.4f}{sig}",
                f"{self.conf_int_low[var]:.4f}",
                f"{self.conf_int_high[var]:.4f}",
            )
        return table

    def summary_table(self) -> str:
        """Return a Rich-formatted summary table as string.

        The output is the coefficient table, one line of fit statistics
        separated by ``" | "``, and the significance legend.
        """
        console = Console(file=io.StringIO(), width=100, record=True)
        console.print(self._rich_table())

        parts = [f"N = {self.n_obs}"]
        if self.r_squared is not None:
            parts.append(f"R² = {self.r_squared:.4f}")
        if self.adj_r_squared is not None:
            parts.append(f"Adj.R² = {self.adj_r_squared:.4f}")
        if self.f_statistic is not None and self.f_pvalue is not None:
            # Degrees of freedom: k predictors, n - k - 1 residual (one intercept).
            k = len(self.indep_vars)
            parts.append(
                f"F({k}, {self.n_obs - k - 1}) = {self.f_statistic:.2f} "
                f"(p={self.f_pvalue:.4f})"
            )
        if self.aic is not None:
            parts.append(f"AIC = {self.aic:.1f}")
        if self.bic is not None:
            parts.append(f"BIC = {self.bic:.1f}")
        if self.pseudo_r2 is not None:
            parts.append(f"Pseudo R² = {self.pseudo_r2:.4f}")
        if self.log_likelihood is not None:
            parts.append(f"LL = {self.log_likelihood:.1f}")
        if self.dispersion is not None:
            parts.append(f"alpha = {self.dispersion:.4f}")

        console.print(" | ".join(parts))
        console.print("Significance: * p<0.05 ** p<0.01 *** p<0.001")
        return console.export_text()

    def to_markdown(self) -> str:
        """Return a Markdown-formatted summary.

        Fit statistics come first (one per line), then any warnings as block
        quotes, then a pipe table with one row per coefficient.
        """
        lines = [
            f"### {self.model_type}: {self.formula}",
            "",
            f"N = {self.n_obs}",
        ]
        if self.r_squared is not None:
            lines.append(f"R² = {self.r_squared:.4f}")
        if self.adj_r_squared is not None:
            lines.append(f"Adj. R² = {self.adj_r_squared:.4f}")
        if self.f_statistic is not None and self.f_pvalue is not None:
            lines.append(
                f"F-statistic = {self.f_statistic:.2f} (p = {self.f_pvalue:.4f})"
            )
        if self.aic is not None:
            lines.append(f"AIC = {self.aic:.1f}")
        if self.bic is not None:
            lines.append(f"BIC = {self.bic:.1f}")
        if self.pseudo_r2 is not None:
            lines.append(f"Pseudo R² = {self.pseudo_r2:.4f}")
        if self.log_likelihood is not None:
            lines.append(f"Log-Likelihood = {self.log_likelihood:.1f}")
        if self.dispersion is not None:
            lines.append(f"Dispersion (alpha) = {self.dispersion:.4f}")
        if self.warnings:
            lines.append("")
            lines.extend(f"> {w}" for w in self.warnings)
        lines.append("")
        lines.append("| Variable | Coef | Std.Err | t/z | P>\\|t\\| | 95% CI |")
        lines.append("|----------|------|---------|-----|---------|--------|")
        for var in self.params:
            ci = f"[{self.conf_int_low[var]:.4f}, {self.conf_int_high[var]:.4f}]"
            lines.append(
                f"| {var} | {self.params[var]:.4f} | {self.std_errors[var]:.4f} "
                f"| {self.t_values[var]:.3f} | {self.p_values[var]:.4f} | {ci} |"
            )
        lines.append("")
        return "\n".join(lines)

    def to_latex(self) -> str:
        """Return a LaTeX-formatted regression table.

        The caption and the variable names are escaped for LaTeX text mode, so
        the ``~`` of the formula and characters such as ``_``, ``%`` or ``&``
        are typeset literally instead of being interpreted as markup.
        """
        caption = self._tex(f"{self.model_type}: {self.formula}")
        lines = [
            r"\begin{table}[htbp]",
            r"\centering",
            f"\\caption{{{caption}}}",
            r"\begin{tabular}{lcccccc}",
            r"\hline",
            r"Variable & Coef & Std.Err & t/z & P$>|$t$|$ & [95\% CI Low] & [95\% CI High] \\",
            r"\hline",
        ]
        for var in self.params:
            stars = self._star_count(self.p_values[var])
            sig = "$^{" + "*" * stars + "}$" if stars else ""
            var_tex = self._tex(var)
            lines.append(
                f"{var_tex} & {self.params[var]:.4f}{sig} & "
                f"{self.std_errors[var]:.4f} & {self.t_values[var]:.3f} & "
                f"{self.p_values[var]:.4f} & {self.conf_int_low[var]:.4f} & "
                f"{self.conf_int_high[var]:.4f} \\\\"
            )
        lines.append(r"\hline")
        footer = f"N = {self.n_obs}"
        if self.r_squared is not None:
            footer += f", $R^2$ = {self.r_squared:.4f}"
        if self.adj_r_squared is not None:
            footer += f", Adj.$R^2$ = {self.adj_r_squared:.4f}"
        if self.f_statistic is not None and self.f_pvalue is not None:
            footer += f", F = {self.f_statistic:.2f} (p = {self.f_pvalue:.4f})"
        if self.aic is not None:
            footer += f", AIC = {self.aic:.1f}"
        if self.bic is not None:
            footer += f", BIC = {self.bic:.1f}"
        if self.pseudo_r2 is not None:
            footer += f", Pseudo $R^2$ = {self.pseudo_r2:.4f}"
        if self.log_likelihood is not None:
            footer += f", LL = {self.log_likelihood:.1f}"
        if self.dispersion is not None:
            footer += f", $\\alpha$ = {self.dispersion:.4f}"
        lines.append(f"\\multicolumn{{7}}{{l}}{{{footer}}} \\\\")
        lines.append(r"\hline")
        lines.append(r"\multicolumn{7}{l}{\footnotesize $^{*}$p$<$0.05, $^{**}$p$<$0.01, $^{***}$p$<$0.001} \\")
        lines.append(r"\end{tabular}")
        lines.append(r"\end{table}")
        return "\n".join(lines)

    def to_html(self) -> str:
        """Return HTML-formatted summary for Jupyter display.

        Only the coefficient table is rendered; fit statistics are not included.
        """
        console = Console(file=io.StringIO(), width=120, record=True)
        console.print(self._rich_table())
        return console.export_html(inline_styles=True)
235
+
236
+
237
def _prepare_data(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    *,
    cluster_col: str | None = None,
) -> tuple[np.ndarray, np.ndarray, list[str], list[str], np.ndarray | None]:
    """Extract y and X (with constant) as numpy arrays.

    Handles interaction terms (e.g. ``"x1:x2"`` in *indeps*) by creating
    product columns on the fly.  Rows with a null in any required column
    (dependent, predictors, interaction components, cluster column) are
    dropped before the arrays are built.

    Returns ``(y, X, warnings, var_names, cluster_groups)``.
    *var_names* is ``["_cons"] + expanded_indeps`` and always has exactly one
    entry per column of *X*.
    *cluster_groups* is ``None`` when *cluster_col* is not given.

    Raises ``ValueError`` when a column is missing, when no rows survive the
    null filter, or when there are fewer than ``n_params + 2`` observations.
    """
    # Separate base variables from interaction terms, keeping the caller's order.
    base_vars = [v for v in indeps if ":" not in v]
    interactions = [v for v in indeps if ":" in v]

    # Collect all raw columns needed.
    all_base: set[str] = set(base_vars)
    for inter in interactions:
        all_base.update(inter.split(":"))

    cols_needed = [dep] + sorted(all_base)
    if cluster_col:
        if cluster_col not in df.columns:
            raise ValueError(f"Cluster column not found: {cluster_col}")
        # The cluster column may also be a predictor; select it only once.
        cols_needed = list(dict.fromkeys(cols_needed + [cluster_col]))

    missing = [c for c in cols_needed if c not in df.columns]
    if missing:
        raise ValueError(f"Columns not found: {', '.join(missing)}")

    sub = df.select(cols_needed).drop_nulls()
    if sub.height == 0:
        raise ValueError("No observations after dropping missing values")

    warnings: list[str] = []
    n_dropped = df.height - sub.height
    if n_dropped > 0:
        warnings.append(
            f"Note: {n_dropped} observation(s) dropped due to missing values."
        )

    y = sub[dep].to_numpy().astype(float)

    # Build X: base vars first, then interaction columns.
    if base_vars:
        blocks = [sub.select(base_vars).to_numpy().astype(float)]
    else:
        blocks = [np.empty((sub.height, 0))]
    var_names_x: list[str] = list(base_vars)

    for inter in interactions:
        product = np.ones(sub.height)
        for component in inter.split(":"):
            product = product * sub[component].to_numpy().astype(float)
        blocks.append(product)
        var_names_x.append(inter)
    X_parts = np.column_stack(blocks)

    # Check minimum observations.
    n_params = len(var_names_x) + 1  # +1 for constant
    if sub.height < n_params + 2:
        raise ValueError(
            f"Too few observations ({sub.height}) for {n_params} parameters. "
            f"Need at least {n_params + 2}."
        )

    cfg = get_config()
    if sub.height < n_params * cfg.min_obs_per_predictor:
        warnings.append(
            f"Warning: Low observations-to-predictors ratio "
            f"({sub.height} obs / {n_params} params = {sub.height / n_params:.1f}). "
            f"Results may be unreliable."
        )

    # has_constant="add" forces the intercept column even when a predictor is
    # itself constant in the retained sample.  With the default ("skip") the
    # intercept would be omitted and X would have one column fewer than
    # var_names, so callers zipping names with estimates would mislabel every
    # coefficient.  A constant predictor now shows up as collinearity instead.
    X = sm.add_constant(X_parts, has_constant="add")

    # Check multicollinearity via condition number.
    try:
        cond = np.linalg.cond(X)
        if cond > cfg.condition_threshold:
            warnings.append(
                f"Warning: Possible multicollinearity detected "
                f"(condition number = {cond:.0f}, threshold = {cfg.condition_threshold}). "
                f"Consider removing correlated predictors."
            )
    except np.linalg.LinAlgError:
        warnings.append("Warning: Could not compute condition number.")

    # Cluster groups, aligned with the rows kept in y and X.
    groups: np.ndarray | None = None
    if cluster_col:
        groups = sub[cluster_col].to_numpy()

    var_names = ["_cons"] + var_names_x
    return y, X, warnings, var_names, groups
339
+
340
+
341
+ def _cov_args(
342
+ robust: bool, groups: np.ndarray | None
343
+ ) -> tuple[str, dict]:
344
+ """Return (cov_type, cov_kwds) for statsmodels fit()."""
345
+ if groups is not None:
346
+ return "cluster", {"groups": groups}
347
+ if robust:
348
+ return "HC1", {}
349
+ return "nonrobust", {}
350
+
351
+
352
+ def _model_type_suffix(robust: bool, cluster: bool) -> str:
353
+ if cluster:
354
+ return " (cluster-robust)"
355
+ if robust:
356
+ return " (robust)"
357
+ return ""
358
+
359
+
360
def fit_ols(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    *,
    robust: bool = False,
    cluster_col: str | None = None,
) -> tuple[FitResult, object]:
    """Fit an OLS regression. Returns (FitResult, raw_model).

    Standard errors are cluster-robust when *cluster_col* is given, HC1 when
    *robust* is set, and classical otherwise.  Two best-effort diagnostics
    append to the result's warnings: a Breusch-Pagan test (classical standard
    errors only, flagged at p < 0.05) and a Durbin-Watson statistic (flagged
    outside [1.5, 2.5]).  A failing diagnostic never blocks the fit.
    """
    y, X, notes, names, groups = _prepare_data(
        df, dep, indeps, cluster_col=cluster_col,
    )
    clustered = groups is not None
    cov_type, cov_kwds = _cov_args(robust, groups)
    fitted = sm.OLS(y, X).fit(cov_type=cov_type, cov_kwds=cov_kwds)
    bounds = fitted.conf_int()

    # Heteroscedasticity check (Breusch-Pagan) — pointless once robust or
    # clustered standard errors are already in use.
    if not (robust or clustered):
        try:
            from statsmodels.stats.diagnostic import het_breuschpagan
            _, bp_pval, _, _ = het_breuschpagan(fitted.resid, fitted.model.exog)
            if bp_pval < 0.05:
                notes.append(
                    f"Warning: Heteroscedasticity detected (Breusch-Pagan p={bp_pval:.4f}). "
                    f"Consider using --robust for heteroscedasticity-robust standard errors."
                )
        except Exception:
            pass  # diagnostic failure should not block results

    # Autocorrelation check (Durbin-Watson).
    try:
        from statsmodels.stats.stattools import durbin_watson
        dw = durbin_watson(fitted.resid)
        if dw < 1.5 or dw > 2.5:
            notes.append(
                f"Warning: Possible autocorrelation (Durbin-Watson = {dw:.3f}). "
                f"Values far from 2.0 suggest serial correlation in residuals."
            )
    except Exception:
        pass

    def by_name(values) -> dict:
        """Label one estimate per design-matrix column with its variable name."""
        return dict(zip(names, values))

    f_stat = getattr(fitted, "fvalue", None)
    f_pval = getattr(fitted, "f_pvalue", None)

    summary = FitResult(
        model_type="OLS" + _model_type_suffix(robust, clustered),
        formula=f"{dep} ~ {' + '.join(indeps)}",
        dep_var=dep,
        indep_vars=indeps,
        n_obs=int(fitted.nobs),
        params=by_name(fitted.params),
        std_errors=by_name(fitted.bse),
        t_values=by_name(fitted.tvalues),
        p_values=by_name(fitted.pvalues),
        conf_int_low=by_name(bounds[:, 0]),
        conf_int_high=by_name(bounds[:, 1]),
        r_squared=float(fitted.rsquared),
        adj_r_squared=float(fitted.rsquared_adj),
        f_statistic=None if f_stat is None else float(f_stat),
        f_pvalue=None if f_pval is None else float(f_pval),
        aic=float(fitted.aic),
        bic=float(fitted.bic),
        warnings=notes,
    )
    return summary, fitted
423
+
424
+
425
def fit_logit(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    *,
    robust: bool = False,
    cluster_col: str | None = None,
) -> tuple[FitResult, object]:
    """Fit a logistic regression (binary). Returns (FitResult, raw_model).

    Raises ``ValueError`` when the dependent variable holds anything other
    than 0/1 after missing rows are dropped.  A non-converged optimiser is
    reported through the result's warnings rather than raised.
    """
    y, X, notes, names, groups = _prepare_data(
        df, dep, indeps, cluster_col=cluster_col,
    )

    outcomes = set(y)
    if not outcomes.issubset({0.0, 1.0}):
        raise ValueError(
            f"Logit requires binary (0/1) dependent variable. "
            f"Found values: {sorted(outcomes)[:10]}"
        )

    cov_type, cov_kwds = _cov_args(robust, groups)
    fitted = sm.Logit(y, X).fit(disp=0, cov_type=cov_type, cov_kwds=cov_kwds)

    # Check convergence.
    if hasattr(fitted, "mle_retvals") and not fitted.mle_retvals.get("converged", True):
        notes.append(
            "WARNING: Model did NOT converge. Results may be unreliable. "
            "Consider rescaling variables or reducing predictors."
        )

    bounds = fitted.conf_int()

    def by_name(values) -> dict:
        """Label one estimate per design-matrix column with its variable name."""
        return dict(zip(names, values))

    summary = FitResult(
        model_type="Logit" + _model_type_suffix(robust, groups is not None),
        formula=f"{dep} ~ {' + '.join(indeps)}",
        dep_var=dep,
        indep_vars=indeps,
        n_obs=int(fitted.nobs),
        params=by_name(fitted.params),
        std_errors=by_name(fitted.bse),
        t_values=by_name(fitted.tvalues),
        p_values=by_name(fitted.pvalues),
        conf_int_low=by_name(bounds[:, 0]),
        conf_int_high=by_name(bounds[:, 1]),
        pseudo_r2=float(fitted.prsquared),
        aic=float(fitted.aic),
        bic=float(fitted.bic),
        warnings=notes,
    )
    return summary, fitted
476
+
477
+
478
def fit_probit(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    *,
    robust: bool = False,
    cluster_col: str | None = None,
) -> tuple[FitResult, object]:
    """Fit a probit regression (binary). Returns (FitResult, raw_model).

    Raises ``ValueError`` when the dependent variable holds anything other
    than 0/1 after missing rows are dropped.  A non-converged optimiser is
    reported through the result's warnings rather than raised.
    """
    y, X, notes, names, groups = _prepare_data(
        df, dep, indeps, cluster_col=cluster_col,
    )

    outcomes = set(y)
    if not outcomes.issubset({0.0, 1.0}):
        raise ValueError(
            f"Probit requires binary (0/1) dependent variable. "
            f"Found values: {sorted(outcomes)[:10]}"
        )

    cov_type, cov_kwds = _cov_args(robust, groups)
    fitted = sm.Probit(y, X).fit(disp=0, cov_type=cov_type, cov_kwds=cov_kwds)

    if hasattr(fitted, "mle_retvals") and not fitted.mle_retvals.get("converged", True):
        notes.append(
            "WARNING: Model did NOT converge. Results may be unreliable."
        )

    bounds = fitted.conf_int()

    def by_name(values) -> dict:
        """Label one estimate per design-matrix column with its variable name."""
        return dict(zip(names, values))

    summary = FitResult(
        model_type="Probit" + _model_type_suffix(robust, groups is not None),
        formula=f"{dep} ~ {' + '.join(indeps)}",
        dep_var=dep,
        indep_vars=indeps,
        n_obs=int(fitted.nobs),
        params=by_name(fitted.params),
        std_errors=by_name(fitted.bse),
        t_values=by_name(fitted.tvalues),
        p_values=by_name(fitted.pvalues),
        conf_int_low=by_name(bounds[:, 0]),
        conf_int_high=by_name(bounds[:, 1]),
        pseudo_r2=float(fitted.prsquared),
        aic=float(fitted.aic),
        bic=float(fitted.bic),
        warnings=notes,
    )
    return summary, fitted
527
+
528
+
529
def fit_poisson(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    *,
    robust: bool = False,
    cluster_col: str | None = None,
    exposure_col: str | None = None,
) -> tuple[FitResult, object]:
    """Fit a Poisson regression. Returns (FitResult, raw_model).

    When *exposure_col* is given, ``log(exposure)`` enters the model as an
    offset.

    Raises ``ValueError`` when *exposure_col* is not a column of *df*, or when
    any exposure value in the estimation sample is not a positive, finite
    number (its logarithm would be ``-inf`` or ``NaN`` and silently corrupt
    the fit).
    """
    # Pre-filter: drop rows where exposure is null so _prepare_data and
    # offset computation use the same subset of rows.
    if exposure_col:
        if exposure_col not in df.columns:
            raise ValueError(f"Exposure column not found: {exposure_col}")
        df = df.filter(pl.col(exposure_col).is_not_null())

    y, X, warnings, var_names, groups = _prepare_data(
        df, dep, indeps, cluster_col=cluster_col,
    )

    # Handle exposure (offset = log(exposure)).
    offset = None
    if exposure_col:
        # Reconstruct the same subset as _prepare_data (exposure nulls already
        # removed above, so drop_nulls here matches _prepare_data exactly).
        raw_cols: set[str] = set()
        for term in indeps:
            raw_cols.update(term.split(":") if ":" in term else [term])
        cols_needed = [dep] + sorted(raw_cols) + [exposure_col]
        if cluster_col:
            cols_needed.append(cluster_col)
        # A column may play several roles; select each only once.
        sub = df.select(list(dict.fromkeys(cols_needed))).drop_nulls()

        exposure = sub[exposure_col].to_numpy().astype(float)
        # Checked on the estimation sample only: rows dropped for other
        # reasons cannot affect the fit.
        if not np.all(np.isfinite(exposure) & (exposure > 0)):
            raise ValueError(
                f"Exposure column '{exposure_col}' must contain only "
                f"positive, finite values."
            )
        offset = np.log(exposure)

    cov_type, cov_kwds = _cov_args(robust, groups)
    model = sm.Poisson(y, X, offset=offset).fit(
        disp=0, cov_type=cov_type, cov_kwds=cov_kwds,
    )

    if hasattr(model, "mle_retvals") and not model.mle_retvals.get("converged", True):
        warnings.append("WARNING: Model did NOT converge. Results may be unreliable.")

    ci = model.conf_int()
    suffix = _model_type_suffix(robust, groups is not None)

    result = FitResult(
        model_type="Poisson" + suffix,
        formula=f"{dep} ~ {' + '.join(indeps)}",
        dep_var=dep,
        indep_vars=indeps,
        n_obs=int(model.nobs),
        params=dict(zip(var_names, model.params)),
        std_errors=dict(zip(var_names, model.bse)),
        t_values=dict(zip(var_names, model.tvalues)),
        p_values=dict(zip(var_names, model.pvalues)),
        conf_int_low=dict(zip(var_names, ci[:, 0])),
        conf_int_high=dict(zip(var_names, ci[:, 1])),
        pseudo_r2=float(model.prsquared) if hasattr(model, "prsquared") else None,
        log_likelihood=float(model.llf),
        aic=float(model.aic),
        bic=float(model.bic),
        warnings=warnings,
    )
    return result, model
596
+
597
+
598
def fit_negbin(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    *,
    robust: bool = False,
    cluster_col: str | None = None,
) -> tuple[FitResult, object]:
    """Fit a Negative Binomial regression. Returns (FitResult, raw_model).

    The last fitted parameter is reported as ``dispersion`` (alpha) and is
    left out of the coefficient dictionaries.  A non-converged optimiser is
    reported through the result's warnings rather than raised.
    """
    y, X, notes, names, groups = _prepare_data(
        df, dep, indeps, cluster_col=cluster_col,
    )

    cov_type, cov_kwds = _cov_args(robust, groups)
    fitted = sm.NegativeBinomial(y, X).fit(
        disp=0, cov_type=cov_type, cov_kwds=cov_kwds,
    )

    if hasattr(fitted, "mle_retvals") and not fitted.mle_retvals.get("converged", True):
        notes.append("WARNING: Model did NOT converge. Results may be unreliable.")

    # NegativeBinomial includes alpha as the last parameter; keep only the
    # leading entries that correspond to design-matrix columns.
    keep = slice(0, len(names))
    bounds = fitted.conf_int()[keep]

    def by_name(values) -> dict:
        """Label one regression estimate per variable name."""
        return dict(zip(names, values))

    pseudo_r2 = float(fitted.prsquared) if hasattr(fitted, "prsquared") else None

    summary = FitResult(
        model_type="NegBin" + _model_type_suffix(robust, groups is not None),
        formula=f"{dep} ~ {' + '.join(indeps)}",
        dep_var=dep,
        indep_vars=indeps,
        n_obs=int(fitted.nobs),
        params=by_name(fitted.params[keep]),
        std_errors=by_name(fitted.bse[keep]),
        t_values=by_name(fitted.tvalues[keep]),
        p_values=by_name(fitted.pvalues[keep]),
        conf_int_low=by_name(bounds[:, 0]),
        conf_int_high=by_name(bounds[:, 1]),
        pseudo_r2=pseudo_r2,
        log_likelihood=float(fitted.llf),
        dispersion=float(fitted.params[-1]),  # alpha
        aic=float(fitted.aic),
        bic=float(fitted.bic),
        warnings=notes,
    )
    return summary, fitted
650
+
651
+
652
def fit_quantreg(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    *,
    tau: float = 0.5,
) -> tuple[FitResult, object]:
    """Fit a quantile regression. Returns (FitResult, raw_model).

    *tau* is the quantile to estimate and must lie strictly between 0 and 1;
    anything else (including NaN) raises ``ValueError`` before the data are
    touched.  Clustering and robust covariance options are not offered here.
    """
    tau_is_valid = 0 < tau < 1
    if not tau_is_valid:
        raise ValueError(f"tau must be between 0 and 1 (exclusive), got {tau}")

    y, X, notes, names, _ = _prepare_data(df, dep, indeps)
    fitted = sm.QuantReg(y, X).fit(q=tau)
    bounds = fitted.conf_int()

    def by_name(values) -> dict:
        """Label one estimate per design-matrix column with its variable name."""
        return dict(zip(names, values))

    pseudo_r2 = float(fitted.prsquared) if hasattr(fitted, "prsquared") else None

    summary = FitResult(
        model_type=f"QuantReg(tau={tau})",
        formula=f"{dep} ~ {' + '.join(indeps)}",
        dep_var=dep,
        indep_vars=indeps,
        n_obs=int(fitted.nobs),
        params=by_name(fitted.params),
        std_errors=by_name(fitted.bse),
        t_values=by_name(fitted.tvalues),
        p_values=by_name(fitted.pvalues),
        conf_int_low=by_name(bounds[:, 0]),
        conf_int_high=by_name(bounds[:, 1]),
        pseudo_r2=pseudo_r2,
        warnings=notes,
    )
    return summary, fitted
683
+
684
+
685
+ # ---------------------------------------------------------------------------
686
+ # Marginal effects
687
+ # ---------------------------------------------------------------------------
688
+
689
@dataclass
class MarginalEffectsResult:
    """Holds marginal effects for a discrete-choice model.

    Every dictionary is keyed by predictor name; ``summary_table`` iterates
    ``effects`` and looks the same key up in the other dictionaries.
    """

    method: str  # "at_means" or "average"
    effects: dict[str, float]
    std_errors: dict[str, float]
    z_values: dict[str, float]
    p_values: dict[str, float]
    conf_int_low: dict[str, float]
    conf_int_high: dict[str, float]

    def summary_table(self) -> str:
        """Render the effects as a Rich table and return it as plain text."""
        table = Table(title=f"Marginal Effects ({self.method})")
        table.add_column("Variable", style="cyan")
        for heading in (
            "dy/dx", "Std.Err", "z", "P>|z|", "[95% CI Low]", "[95% CI High]",
        ):
            table.add_column(heading, justify="right")

        for name, effect in self.effects.items():
            p_value = self.p_values[name]
            # Conventional significance markers, strictest threshold first.
            if p_value < 0.001:
                marker = " ***"
            elif p_value < 0.01:
                marker = " **"
            elif p_value < 0.05:
                marker = " *"
            else:
                marker = ""
            table.add_row(
                name,
                f"{effect:.6f}",
                f"{self.std_errors[name]:.6f}",
                f"{self.z_values[name]:.3f}",
                f"{p_value:.4f}{marker}",
                f"{self.conf_int_low[name]:.6f}",
                f"{self.conf_int_high[name]:.6f}",
            )

        out = Console(file=io.StringIO(), width=100, record=True)
        out.print(table)
        return out.export_text()
733
+
734
+
735
def compute_margins(
    raw_model: object, var_names: list[str], method: str = "average"
) -> MarginalEffectsResult:
    """Compute marginal effects from a fitted logit/probit model.

    Parameters
    ----------
    raw_model:
        Fitted model object exposing ``get_margeff(at=...)``.
    var_names:
        Predictor names. A ``"_cons"`` entry, if present, is ignored because
        no marginal effect is reported for the constant.
    method:
        ``"average"`` (alias ``"overall"``) for average marginal effects, or
        ``"means"`` (aliases ``"at_means"``, ``"mean"``) for effects
        evaluated at the sample means.

    Returns
    -------
    MarginalEffectsResult
        Effects, standard errors, z statistics, p-values and confidence
        bounds keyed by predictor name. ``method`` is stored as passed in.

    Raises
    ------
    ValueError
        If *method* is not one of the recognised names. Previously an
        unknown name silently produced average effects under a wrong label.
    """
    # Map every accepted spelling onto the ``at`` value the model expects.
    at_aliases = {
        "average": "overall",
        "overall": "overall",
        "means": "mean",
        "at_means": "mean",
        "mean": "mean",
    }
    if method not in at_aliases:
        raise ValueError(
            f"Unknown marginal-effects method {method!r}; "
            f"expected one of: {', '.join(sorted(at_aliases))}"
        )
    at = at_aliases[method]

    mfx = raw_model.get_margeff(at=at)  # type: ignore[attr-defined]
    margeff = mfx.margeff
    margeff_se = mfx.margeff_se
    z_vals = margeff / margeff_se
    p_vals = mfx.pvalues
    ci = mfx.conf_int()

    # The constant has no marginal effect, so it is dropped from the labels.
    names = [v for v in var_names if v != "_cons"]

    return MarginalEffectsResult(
        method=method,
        effects=dict(zip(names, margeff)),
        std_errors=dict(zip(names, margeff_se)),
        z_values=dict(zip(names, z_vals)),
        p_values=dict(zip(names, p_vals)),
        conf_int_low=dict(zip(names, ci[:, 0])),
        conf_int_high=dict(zip(names, ci[:, 1])),
    )
765
+
766
+
767
+ # ---------------------------------------------------------------------------
768
+ # Bootstrap confidence intervals
769
+ # ---------------------------------------------------------------------------
770
+
771
@dataclass
class BootstrapResult:
    """Bootstrap summary statistics for a set of model parameters.

    All mappings are keyed by parameter name; ``summary_table`` renders one
    row per key of ``original_params``.
    """

    original_params: dict[str, float]  # estimates from the un-resampled fit
    boot_means: dict[str, float]
    boot_std: dict[str, float]
    ci_low: dict[str, float]
    ci_high: dict[str, float]
    n_boot: int  # replication count shown in the title
    ci_level: float  # interval level shown in the title and headers
    n_failed: int = 0  # reported in a trailing note when positive

    def summary_table(self) -> str:
        """Render the bootstrap results as plain text and return it."""
        out = Console(file=io.StringIO(), width=100, record=True)
        level = self.ci_level
        grid = Table(
            title=f"Bootstrap Confidence Intervals ({self.n_boot} replications, {level}%)"
        )
        grid.add_column("Variable", style="cyan")
        for heading in (
            "Original",
            "Boot.Mean",
            "Boot.SE",
            f"[{level}% CI Low]",
            f"[{level}% CI High]",
        ):
            grid.add_column(heading, justify="right")

        # One mapping per numeric column, in display order.
        numeric_columns = (
            self.original_params,
            self.boot_means,
            self.boot_std,
            self.ci_low,
            self.ci_high,
        )
        for name in self.original_params:
            cells = [f"{column[name]:.4f}" for column in numeric_columns]
            grid.add_row(name, *cells)

        out.print(grid)
        if self.n_failed > 0:
            out.print(f"Note: {self.n_failed} bootstrap iteration(s) failed and were skipped.")
        return out.export_text()
808
+
809
+
810
def _boot_one_iter(args):
    """Fit one bootstrap replicate; usable as a worker for parallel execution.

    *args* is the tuple ``(df, dep, indeps, fit_fn, fit_kwargs, seed)``. The
    frame is resampled with replacement to its own height using *seed*.

    Returns a dict copy of the fitted parameters, or ``None`` when fitting
    the replicate raised any exception.
    """
    frame, dep, indeps, fit_fn, fit_kwargs, seed = args
    resampled = frame.sample(n=frame.height, with_replacement=True, seed=seed)
    try:
        fit_result, _raw = fit_fn(resampled, dep, indeps, **fit_kwargs)
        estimates = dict(fit_result.params)
    except Exception:
        # A failed replicate is reported as None rather than aborting the run.
        estimates = None
    return estimates
819
+
820
+
821
def bootstrap_model(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    fit_fn,
    n_boot: int = 1000,
    ci: float = 95.0,
    **fit_kwargs,
) -> BootstrapResult:
    """Generic percentile bootstrap for any model fit function.

    Parameters
    ----------
    df, dep, indeps:
        Passed unchanged to *fit_fn* for the original fit; each replicate
        is fitted on a resample of *df*.
    fit_fn:
        Callable with signature
        ``(df, dep, indeps, **kwargs) -> (FitResult, raw_model)``.
    n_boot:
        Number of replicates (at least 1). Replicate ``i`` uses seed ``i``.
    ci:
        Confidence level in percent, ``0 < ci <= 100``.
    **fit_kwargs:
        Forwarded to *fit_fn* on every call.

    Returns
    -------
    BootstrapResult
        Original estimates plus bootstrap mean, standard deviation and
        percentile bounds per parameter. A parameter with no successful
        draws gets NaN statistics.

    Raises
    ------
    ValueError
        If *n_boot* or *ci* is out of range.

    Notes
    -----
    Up to 100 replicates run serially; larger runs use a thread pool.
    Results are always consumed in seed order, so the aggregated statistics
    do not depend on thread scheduling.
    """
    from concurrent.futures import ThreadPoolExecutor
    import os

    # Reject bad settings before paying for any model fit.
    if n_boot < 1:
        raise ValueError(f"n_boot must be at least 1, got {n_boot}")
    if not 0 < ci <= 100:
        raise ValueError(f"ci must be in the range (0, 100], got {ci}")

    original_result, _ = fit_fn(df, dep, indeps, **fit_kwargs)

    boot_params: dict[str, list[float]] = {var: [] for var in original_result.params}

    def _record(outcomes) -> int:
        """Store each replicate's parameters; return the number that failed."""
        failed = 0
        for outcome in outcomes:
            if outcome is None:
                failed += 1
                continue
            for var, draws in boot_params.items():
                if var in outcome:
                    draws.append(outcome[var])
        return failed

    tasks = [(df, dep, indeps, fit_fn, fit_kwargs, seed) for seed in range(n_boot)]

    # For small n_boot, serial is faster due to thread overhead.
    if n_boot <= 100:
        n_failed = _record(map(_boot_one_iter, tasks))
    else:
        max_workers = min(os.cpu_count() or 4, 8)
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Executor.map yields in submission order, unlike as_completed.
            n_failed = _record(executor.map(_boot_one_iter, tasks))

    alpha = (100 - ci) / 2
    nan = float("nan")
    boot_means: dict[str, float] = {}
    boot_std: dict[str, float] = {}
    ci_low: dict[str, float] = {}
    ci_high: dict[str, float] = {}

    for var in original_result.params:
        arr = np.array(boot_params[var])
        if len(arr) > 0:
            boot_means[var] = float(np.mean(arr))
            # The sample SD is undefined for a single draw; report NaN
            # directly instead of triggering a NumPy runtime warning.
            boot_std[var] = float(np.std(arr, ddof=1)) if len(arr) > 1 else nan
            ci_low[var] = float(np.percentile(arr, alpha))
            ci_high[var] = float(np.percentile(arr, 100 - alpha))
        else:
            boot_means[var] = nan
            boot_std[var] = nan
            ci_low[var] = nan
            ci_high[var] = nan

    return BootstrapResult(
        original_params=dict(original_result.params),
        boot_means=boot_means,
        boot_std=boot_std,
        ci_low=ci_low,
        ci_high=ci_high,
        n_boot=n_boot,
        ci_level=ci,
        n_failed=n_failed,
    )
897
+
898
+
899
def _build_X_from_indeps(df: pl.DataFrame, indeps: list[str]) -> np.ndarray:
    """Assemble a float design matrix for *indeps* from the rows of *df*.

    Terms without ``":"`` are taken as columns of *df* and come first, in
    the order given. Each interaction term (``"a:b"``) is appended after
    them as the element-wise product of its component columns.

    No constant column is added and no rows are removed here; null handling
    is left to the caller.
    """
    n_rows = df.height

    # Split the terms while keeping the caller's relative order in each group.
    plain_terms: list[str] = []
    crossed_terms: list[str] = []
    for term in indeps:
        if ":" in term:
            crossed_terms.append(term)
        else:
            plain_terms.append(term)

    if plain_terms:
        design = df.select(plain_terms).to_numpy().astype(float)
    else:
        design = np.empty((n_rows, 0))

    for term in crossed_terms:
        product = np.ones(n_rows)
        for factor in term.split(":"):
            product = product * df[factor].to_numpy().astype(float)
        if design.size > 0:
            design = np.column_stack([design, product])
        else:
            # Nothing accumulated yet: the product becomes the only column.
            design = product.reshape(-1, 1)
    return design
916
+
917
+
918
def compute_vif(df: pl.DataFrame, indeps: list[str]) -> list[tuple[str, float]]:
    """Compute the Variance Inflation Factor for each predictor.

    Handles interaction terms (e.g. ``"x1:x2"`` in *indeps*). Rows with a
    null in any needed column are dropped first.

    Returns a list of ``(term, vif)`` pairs in the order of *indeps*. The
    VIF is ``inf`` when a term is perfectly explained by the others.

    Raises
    ------
    ValueError
        If a needed column is missing from *df*, or if there are fewer
        observations than design-matrix columns plus one.
    """
    # Collect all base columns needed (interaction components resolved).
    all_base: set[str] = set()
    for v in indeps:
        if ":" in v:
            all_base.update(v.split(":"))
        else:
            all_base.add(v)

    cols_needed = sorted(all_base)
    # Sorted so the error message is deterministic.
    missing = [c for c in cols_needed if c not in df.columns]
    if missing:
        raise ValueError(f"Columns not found: {', '.join(missing)}")

    sub = df.select(cols_needed).drop_nulls()

    # Build the full X matrix (handles interactions).
    X_full = _build_X_from_indeps(sub, indeps)
    if X_full.shape[0] < X_full.shape[1] + 1:
        raise ValueError("Too few observations for VIF calculation")

    # _build_X_from_indeps puts plain terms first and interactions last, so
    # the column of a term is not its position in *indeps* unless every
    # interaction is listed last. Map positions to columns explicitly.
    plain_positions = [i for i, v in enumerate(indeps) if ":" not in v]
    inter_positions = [i for i, v in enumerate(indeps) if ":" in v]
    column_of = {
        pos: col for col, pos in enumerate(plain_positions + inter_positions)
    }

    vifs = []
    for pos, var in enumerate(indeps):
        col = column_of[pos]
        # Regress this term on all the others (plus a constant).
        y_i = X_full[:, col]
        X_i = sm.add_constant(np.delete(X_full, col, axis=1))
        r2 = sm.OLS(y_i, X_i).fit().rsquared
        vif_val = 1.0 / (1.0 - r2) if r2 < 1.0 else float("inf")
        vifs.append((var, vif_val))
    return vifs
952
+
953
+
954
+ # ---------------------------------------------------------------------------
955
+ # Hypothesis tests
956
+ # ---------------------------------------------------------------------------
957
+
958
@dataclass
class TestResult:
    """Outcome of one hypothesis test, renderable as a plain-text table."""

    test_name: str  # used as the table title
    statistic: float
    p_value: float
    df: float | int | None = None  # row is omitted from the table when None
    details: dict[str, object] = field(default_factory=dict)  # extra label -> value rows
    interpretation: str = ""  # optional closing line; skipped when empty

    def summary_table(self) -> str:
        """Render the result as text.

        The table is followed by a verdict at alpha = 0.05 and, when set,
        the interpretation. Trailing whitespace is stripped.
        """
        out = Console(file=io.StringIO(), width=100, record=True)
        grid = Table(title=self.test_name)
        grid.add_column("Metric", style="cyan")
        grid.add_column("Value", justify="right")

        rows = [("Test statistic", f"{self.statistic:.4f}")]
        if self.df is not None:
            rows.append(("Degrees of freedom", str(self.df)))
        rows.append(("p-value", f"{self.p_value:.6f}"))
        for label, value in self.details.items():
            # Floats get fixed precision; everything else is shown via str().
            text = f"{value:.4f}" if isinstance(value, float) else str(value)
            rows.append((label, text))
        for label, text in rows:
            grid.add_row(label, text)

        out.print(grid)

        if self.p_value < 0.05:
            verdict = "significant"
        else:
            verdict = "not significant"
        out.print(f"Result: {verdict} at alpha = 0.05")
        if self.interpretation:
            out.print(self.interpretation)
        return out.export_text().rstrip()
993
+
994
+
995
def run_ttest(
    df: pl.DataFrame,
    col: str,
    *,
    by: str | None = None,
    mu: float = 0.0,
    paired_col: str | None = None,
) -> TestResult:
    """Run a t-test on *col*; the variant depends on the optional arguments.

    - ``paired_col`` given: paired test of *col* against *paired_col*
      (takes precedence over ``by``).
    - ``by`` given: Welch two-sample test, splitting *col* by the two
      values of *by*.
    - otherwise: one-sample test of the mean of *col* against *mu*.

    Rows with nulls in the columns involved are dropped first.

    Raises ``ValueError`` when a named column is missing or when *by* does
    not contain exactly two distinct values.
    """
    if col not in df.columns:
        raise ValueError(f"Column not found: {col}")

    # --- Paired test ------------------------------------------------------
    if paired_col is not None:
        if paired_col not in df.columns:
            raise ValueError(f"Column not found: {paired_col}")
        pairs = df.select([col, paired_col]).drop_nulls()
        first = pairs[col].to_numpy().astype(float)
        second = pairs[paired_col].to_numpy().astype(float)
        t_stat, p_val = sp_stats.ttest_rel(first, second)
        n_pairs = len(first)
        return TestResult(
            test_name=f"Paired t-test: {col} vs {paired_col}",
            statistic=float(t_stat),
            p_value=float(p_val),
            df=n_pairs - 1,
            details={
                "N (pairs)": n_pairs,
                "Mean difference": float(np.mean(first - second)),
                f"Mean({col})": float(np.mean(first)),
                f"Mean({paired_col})": float(np.mean(second)),
            },
        )

    # --- Two-sample (Welch) test --------------------------------------------
    if by is not None:
        if by not in df.columns:
            raise ValueError(f"Column not found: {by}")
        complete = df.select([col, by]).drop_nulls()
        levels = complete[by].unique().sort().to_list()
        if len(levels) != 2:
            raise ValueError(
                f"Two-sample t-test requires exactly 2 groups in '{by}', "
                f"found {len(levels)}: {levels[:5]}"
            )
        low_level, high_level = levels
        sample_a = complete.filter(pl.col(by) == low_level)[col].to_numpy().astype(float)
        sample_b = complete.filter(pl.col(by) == high_level)[col].to_numpy().astype(float)
        t_stat, p_val = sp_stats.ttest_ind(sample_a, sample_b, equal_var=False)

        # Welch-Satterthwaite degrees of freedom; falls back to the pooled
        # value when the denominator is not positive.
        size_a, size_b = len(sample_a), len(sample_b)
        var_a, var_b = np.var(sample_a, ddof=1), np.var(sample_b, ddof=1)
        ws_top = (var_a / size_a + var_b / size_b) ** 2
        ws_bottom = (var_a / size_a) ** 2 / (size_a - 1) + (var_b / size_b) ** 2 / (size_b - 1)
        if ws_bottom > 0:
            welch_df = ws_top / ws_bottom
        else:
            welch_df = size_a + size_b - 2
        return TestResult(
            test_name=f"Two-sample t-test: {col} by {by} (Welch)",
            statistic=float(t_stat),
            p_value=float(p_val),
            df=round(float(welch_df), 2),
            details={
                f"N({low_level})": size_a,
                f"N({high_level})": size_b,
                f"Mean({low_level})": float(np.mean(sample_a)),
                f"Mean({high_level})": float(np.mean(sample_b)),
            },
        )

    # --- One-sample test ----------------------------------------------------
    values = df[col].drop_nulls().to_numpy().astype(float)
    t_stat, p_val = sp_stats.ttest_1samp(values, mu)
    n_obs = len(values)
    return TestResult(
        test_name=f"One-sample t-test: {col} (H0: mu = {mu})",
        statistic=float(t_stat),
        p_value=float(p_val),
        df=n_obs - 1,
        details={
            "N": n_obs,
            "Sample mean": float(np.mean(values)),
            "Sample SD": float(np.std(values, ddof=1)),
            "H0 mean": mu,
        },
    )
1081
+
1082
+
1083
def run_chi2(df: pl.DataFrame, col1: str, col2: str) -> TestResult:
    """Chi-square test of independence between *col1* and *col2*.

    Rows with a null in either column are dropped, a contingency table of
    counts is built (levels ordered by their string form), and the table is
    passed to ``chi2_contingency``. Cramér's V is reported in the details,
    or 0.0 when either variable has a single level.

    Raises ``ValueError`` when a named column is missing.
    """
    for name in (col1, col2):
        if name not in df.columns:
            raise ValueError(f"Column not found: {name}")

    complete = df.select([col1, col2]).drop_nulls()
    n_obs = complete.height

    # Cell counts in long form: one record per observed level combination.
    cell_counts = complete.group_by([col1, col2]).len().rename({"len": "count"})
    row_levels = sorted(complete[col1].unique().to_list(), key=str)
    col_levels = sorted(complete[col2].unique().to_list(), key=str)
    row_pos = {level: i for i, level in enumerate(row_levels)}
    col_pos = {level: j for j, level in enumerate(col_levels)}

    # Combinations never observed keep their zero count.
    observed = np.zeros((len(row_levels), len(col_levels)), dtype=int)
    for record in cell_counts.iter_rows(named=True):
        observed[row_pos[record[col1]], col_pos[record[col2]]] = record["count"]

    chi2, p_val, dof, _expected = sp_stats.chi2_contingency(observed)

    smaller_dim = min(len(row_levels), len(col_levels))
    if smaller_dim > 1:
        cramers_v = float(np.sqrt(chi2 / (n_obs * (smaller_dim - 1))))
    else:
        cramers_v = 0.0

    return TestResult(
        test_name=f"Chi-square test: {col1} x {col2}",
        statistic=float(chi2),
        p_value=float(p_val),
        df=int(dof),
        details={
            "N": int(n_obs),
            f"Unique({col1})": len(row_levels),
            f"Unique({col2})": len(col_levels),
            "Cramér's V": cramers_v,
        },
    )
1117
+
1118
+
1119
def run_anova(df: pl.DataFrame, col: str, by: str) -> TestResult:
    """One-way ANOVA (F-test) of *col* across the groups defined by *by*.

    Rows with a null in either column are dropped. The details include the
    total N, the group count, both degrees of freedom, and one entry per
    group with its size, mean and sample standard deviation.

    Raises ``ValueError`` when a named column is missing or when fewer than
    two groups are present.
    """
    for name in (col, by):
        if name not in df.columns:
            raise ValueError(f"Column not found: {name}")

    complete = df.select([col, by]).drop_nulls()
    levels = complete[by].unique().sort().to_list()
    n_groups = len(levels)

    if n_groups < 2:
        raise ValueError(f"ANOVA requires at least 2 groups, found {n_groups}")

    # One float array per group, in sorted level order.
    samples = [
        complete.filter(pl.col(by) == level)[col].to_numpy().astype(float)
        for level in levels
    ]
    per_group = [
        (str(level), len(values), float(np.mean(values)), float(np.std(values, ddof=1)))
        for level, values in zip(levels, samples)
    ]

    f_stat, p_val = sp_stats.f_oneway(*samples)
    n_total = complete.height

    details: dict[str, object] = {
        "N (total)": n_total,
        "Groups": n_groups,
        "df (between)": n_groups - 1,
        "df (within)": n_total - n_groups,
    }
    details.update(
        (f" {label}: N={size}", f"mean={mean:.4f}, sd={sd:.4f}")
        for label, size, mean, sd in per_group
    )

    return TestResult(
        test_name=f"One-way ANOVA: {col} by {by}",
        statistic=float(f_stat),
        p_value=float(p_val),
        df=n_groups - 1,
        details=details,
    )
1158
+
1159
+
1160
+ # ---------------------------------------------------------------------------
1161
+ # Stepwise regression
1162
+ # ---------------------------------------------------------------------------
1163
+
1164
@dataclass
class StepwiseResult:
    """Outcome of a stepwise variable-selection run."""

    direction: str  # "forward" or "backward"
    selected: list[str]  # variables kept in the final model
    dropped: list[str]  # candidates left out of the final model
    steps: list[dict[str, object]]  # entries with "step", "action", "variable", "aic"
    final_fit: FitResult  # must expose ``formula`` and ``summary_table()``

    def summary(self) -> str:
        """Return a human-readable report of the selection.

        The report has a header, the selected (and, if any, dropped)
        variables, one line per step, and the final model's summary table.
        """
        fit = self.final_fit
        chosen = ", ".join(self.selected)
        report = [
            f"Stepwise ({self.direction}) variable selection",
            f"Final model: {fit.formula}",
            f"Selected {len(self.selected)} variable(s): {chosen}",
        ]
        if self.dropped:
            report.append("Dropped: " + ", ".join(self.dropped))
        report.append("")
        report.extend(
            f" Step {entry['step']}: {entry['action']} '{entry['variable']}' "
            f"(AIC={entry['aic']:.2f})"
            for entry in self.steps
        )
        report.append("")
        report.append(fit.summary_table())
        return "\n".join(report)
1188
+
1189
+
1190
def stepwise_ols(
    df: pl.DataFrame,
    dep: str,
    candidates: list[str],
    *,
    direction: str = "forward",
    p_enter: float = 0.05,
    p_remove: float = 0.10,
) -> StepwiseResult:
    """Run stepwise OLS regression driven by coefficient p-values.

    Parameters
    ----------
    df, dep, candidates:
        Data, dependent variable and candidate predictors. Candidates may
        be interaction terms written as ``"a:b"``.
    direction:
        ``"forward"`` adds, at each step, the candidate with the smallest
        p-value while that p-value is at most *p_enter*. ``"backward"``
        starts from all candidates and removes the one with the largest
        p-value while it exceeds *p_remove*; the last remaining variable
        is never removed.
    p_enter:
        p-value threshold to add a variable (forward).
    p_remove:
        p-value threshold to remove a variable (backward).

    Returns
    -------
    StepwiseResult
        Selected and dropped variables, the step log, and the final fit.

    Raises
    ------
    ValueError
        If *direction* is not ``"forward"`` or ``"backward"``, if a needed
        column is missing, or if no variable ends up selected.
    """
    # Previously any value other than "forward" silently ran backward
    # elimination; reject typos and unsupported modes explicitly.
    if direction not in ("forward", "backward"):
        raise ValueError(
            f"direction must be 'forward' or 'backward', got {direction!r}"
        )

    # Collect all base columns needed (interaction components resolved).
    all_base: set[str] = {dep}
    for v in candidates:
        if ":" in v:
            all_base.update(v.split(":"))
        else:
            all_base.add(v)
    # Sorted so the error message is deterministic.
    missing = sorted(c for c in all_base if c not in df.columns)
    if missing:
        raise ValueError(f"Columns not found: {', '.join(missing)}")

    steps: list[dict[str, object]] = []

    if direction == "forward":
        selected: list[str] = []
        remaining = list(candidates)

        while remaining:
            best_var = None
            best_pval = 1.0
            best_aic = float("inf")

            for var in remaining:
                try:
                    trial = selected + [var]
                    _, model = fit_ols(df, dep, trial)
                    var_idx = trial.index(var) + 1  # +1 for constant
                    pval = float(model.pvalues[var_idx])
                    aic = float(model.aic)
                    if pval < best_pval:
                        best_var = var
                        best_pval = pval
                        best_aic = aic
                except Exception:
                    # Best effort: a candidate whose trial fit fails is
                    # skipped for this round rather than aborting the search.
                    continue

            if best_var is None or best_pval > p_enter:
                break

            selected.append(best_var)
            remaining.remove(best_var)
            steps.append({
                "step": len(steps) + 1, "action": "add", "variable": best_var,
                "p_value": best_pval, "aic": best_aic,
            })

    else:  # backward
        selected = list(candidates)

        while len(selected) > 1:
            _, model = fit_ols(df, dep, selected)
            # Find the variable with the highest p-value, skipping the
            # constant at index 0.
            pvals = model.pvalues[1:]
            worst_idx = int(np.argmax(pvals))
            worst_pval = float(pvals[worst_idx])
            worst_var = selected[worst_idx]

            if worst_pval <= p_remove:
                break

            selected.remove(worst_var)
            # The logged AIC is that of the model before the removal.
            steps.append({
                "step": len(steps) + 1, "action": "remove", "variable": worst_var,
                "p_value": worst_pval, "aic": float(model.aic),
            })

    dropped = [c for c in candidates if c not in selected]

    if not selected:
        raise ValueError("No variables selected. Consider relaxing p_enter threshold.")

    final_result, _ = fit_ols(df, dep, selected)
    return StepwiseResult(
        direction=direction,
        selected=selected,
        dropped=dropped,
        steps=steps,
        final_fit=final_result,
    )
1289
+
1290
+
1291
+ # ---------------------------------------------------------------------------
1292
+ # Residual diagnostics
1293
+ # ---------------------------------------------------------------------------
1294
+
1295
def compute_residuals(
    model: object, df: pl.DataFrame, dep: str, indeps: list[str]
) -> dict[str, np.ndarray]:
    """Compute residual diagnostics from a fitted OLS model.

    Rows with a null in *dep* or in any column used by *indeps* are dropped
    before the design matrix (with a constant) is rebuilt and passed to
    ``model.predict``.

    Returns a dict with:

    - ``"residuals"``: observed minus fitted values
    - ``"fitted"``: model predictions
    - ``"std_residuals"``: internally studentized residuals
    - ``"leverage"``: hat-matrix diagonal
    - ``"y"``: the observed dependent values

    If ``X'X`` cannot be inverted, the residuals are divided by their sample
    standard deviation instead and the leverage is all zeros.
    """
    # Collect all raw columns needed (including interaction components).
    all_base: set[str] = set()
    for v in indeps:
        if ":" in v:
            all_base.update(v.split(":"))
        else:
            all_base.add(v)
    cols_needed = [dep] + sorted(all_base)
    sub = df.select(cols_needed).drop_nulls()
    y = sub[dep].to_numpy().astype(float)
    X = sm.add_constant(_build_X_from_indeps(sub, indeps))

    fitted = model.predict(X)
    resid = y - fitted

    # Internally studentized residuals using leverage (hat-matrix diagonal).
    leverage = np.zeros_like(resid)
    try:
        xtx_inv = np.linalg.inv(X.T @ X)
        # h_ii = x_i' (X'X)^-1 x_i, computed row by row. This avoids forming
        # the full n x n hat matrix just to read its diagonal.
        leverage = np.sum((X @ xtx_inv) * X, axis=1)
        n = len(resid)
        p = X.shape[1]
        # Residual variance estimate with n - p degrees of freedom.
        mse = np.sum(resid ** 2) / (n - p)
        s = np.sqrt(mse)
        # Clamp 1 - h away from zero so points with leverage ~1 do not
        # divide by zero.
        denom = s * np.sqrt(np.maximum(1 - leverage, 1e-10))
        resid_std = resid / denom
    except np.linalg.LinAlgError:
        # Fall back to simple standardization if the inversion fails.
        s = np.std(resid, ddof=1)
        resid_std = resid / s if s > 0 else resid

    return {
        "residuals": resid,
        "fitted": fitted,
        "std_residuals": resid_std,
        "leverage": leverage,
        "y": y,
    }