openstat-cli 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openstat/__init__.py +3 -0
- openstat/__main__.py +4 -0
- openstat/backends/__init__.py +16 -0
- openstat/backends/duckdb_backend.py +70 -0
- openstat/backends/polars_backend.py +52 -0
- openstat/cli.py +92 -0
- openstat/commands/__init__.py +82 -0
- openstat/commands/adv_stat_cmds.py +1255 -0
- openstat/commands/advanced_ml_cmds.py +576 -0
- openstat/commands/advreg_cmds.py +207 -0
- openstat/commands/alias_cmds.py +135 -0
- openstat/commands/arch_cmds.py +82 -0
- openstat/commands/arules_cmds.py +111 -0
- openstat/commands/automodel_cmds.py +212 -0
- openstat/commands/backend_cmds.py +82 -0
- openstat/commands/base.py +170 -0
- openstat/commands/bayes_cmds.py +71 -0
- openstat/commands/causal_cmds.py +269 -0
- openstat/commands/cluster_cmds.py +152 -0
- openstat/commands/data_cmds.py +996 -0
- openstat/commands/datamanip_cmds.py +672 -0
- openstat/commands/dataquality_cmds.py +174 -0
- openstat/commands/datetime_cmds.py +176 -0
- openstat/commands/dimreduce_cmds.py +184 -0
- openstat/commands/discrete_cmds.py +149 -0
- openstat/commands/dsl_cmds.py +143 -0
- openstat/commands/epi_cmds.py +93 -0
- openstat/commands/equiv_tobit_cmds.py +94 -0
- openstat/commands/esttab_cmds.py +196 -0
- openstat/commands/export_beamer_cmds.py +142 -0
- openstat/commands/export_cmds.py +201 -0
- openstat/commands/export_extra_cmds.py +240 -0
- openstat/commands/factor_cmds.py +180 -0
- openstat/commands/groupby_cmds.py +155 -0
- openstat/commands/help_cmds.py +237 -0
- openstat/commands/i18n_cmds.py +43 -0
- openstat/commands/import_extra_cmds.py +561 -0
- openstat/commands/influence_cmds.py +134 -0
- openstat/commands/iv_cmds.py +106 -0
- openstat/commands/manova_cmds.py +105 -0
- openstat/commands/mediate_cmds.py +233 -0
- openstat/commands/meta_cmds.py +284 -0
- openstat/commands/mi_cmds.py +228 -0
- openstat/commands/mixed_cmds.py +79 -0
- openstat/commands/mixture_changepoint_cmds.py +166 -0
- openstat/commands/ml_adv_cmds.py +147 -0
- openstat/commands/ml_cmds.py +178 -0
- openstat/commands/model_eval_cmds.py +142 -0
- openstat/commands/network_cmds.py +288 -0
- openstat/commands/nlquery_cmds.py +161 -0
- openstat/commands/nonparam_cmds.py +149 -0
- openstat/commands/outreg_cmds.py +247 -0
- openstat/commands/panel_cmds.py +141 -0
- openstat/commands/pdf_cmds.py +226 -0
- openstat/commands/pipeline_cmds.py +319 -0
- openstat/commands/plot_cmds.py +189 -0
- openstat/commands/plugin_cmds.py +79 -0
- openstat/commands/posthoc_cmds.py +153 -0
- openstat/commands/power_cmds.py +172 -0
- openstat/commands/profile_cmds.py +246 -0
- openstat/commands/rbridge_cmds.py +81 -0
- openstat/commands/regex_cmds.py +104 -0
- openstat/commands/report_cmds.py +48 -0
- openstat/commands/repro_cmds.py +129 -0
- openstat/commands/resampling_cmds.py +109 -0
- openstat/commands/reshape_cmds.py +223 -0
- openstat/commands/sem_cmds.py +177 -0
- openstat/commands/stat_cmds.py +1040 -0
- openstat/commands/stata_import_cmds.py +215 -0
- openstat/commands/string_cmds.py +124 -0
- openstat/commands/surv_cmds.py +145 -0
- openstat/commands/survey_cmds.py +153 -0
- openstat/commands/textanalysis_cmds.py +192 -0
- openstat/commands/ts_adv_cmds.py +136 -0
- openstat/commands/ts_cmds.py +195 -0
- openstat/commands/tui_cmds.py +111 -0
- openstat/commands/ux_cmds.py +191 -0
- openstat/commands/validate_cmds.py +270 -0
- openstat/commands/viz_adv_cmds.py +312 -0
- openstat/commands/viz_extra_cmds.py +251 -0
- openstat/commands/watch_cmds.py +69 -0
- openstat/config.py +106 -0
- openstat/dsl/__init__.py +0 -0
- openstat/dsl/parser.py +332 -0
- openstat/dsl/tokenizer.py +105 -0
- openstat/i18n.py +120 -0
- openstat/io/__init__.py +0 -0
- openstat/io/loader.py +187 -0
- openstat/jupyter/__init__.py +18 -0
- openstat/jupyter/display.py +18 -0
- openstat/jupyter/magic.py +60 -0
- openstat/logging_config.py +59 -0
- openstat/plots/__init__.py +0 -0
- openstat/plots/plotter.py +437 -0
- openstat/plots/surv_plots.py +32 -0
- openstat/plots/ts_plots.py +59 -0
- openstat/plugins/__init__.py +5 -0
- openstat/plugins/manager.py +69 -0
- openstat/repl.py +457 -0
- openstat/reporting/__init__.py +0 -0
- openstat/reporting/eda.py +208 -0
- openstat/reporting/report.py +67 -0
- openstat/script_runner.py +319 -0
- openstat/session.py +133 -0
- openstat/stats/__init__.py +0 -0
- openstat/stats/advanced_regression.py +269 -0
- openstat/stats/arch_garch.py +84 -0
- openstat/stats/bayesian.py +103 -0
- openstat/stats/causal.py +258 -0
- openstat/stats/clustering.py +206 -0
- openstat/stats/discrete.py +311 -0
- openstat/stats/epidemiology.py +119 -0
- openstat/stats/equiv_tobit.py +163 -0
- openstat/stats/factor.py +174 -0
- openstat/stats/imputation.py +282 -0
- openstat/stats/influence.py +78 -0
- openstat/stats/iv.py +131 -0
- openstat/stats/manova.py +124 -0
- openstat/stats/mixed.py +128 -0
- openstat/stats/ml.py +275 -0
- openstat/stats/ml_advanced.py +117 -0
- openstat/stats/model_eval.py +183 -0
- openstat/stats/models.py +1342 -0
- openstat/stats/nonparametric.py +130 -0
- openstat/stats/panel.py +179 -0
- openstat/stats/power.py +295 -0
- openstat/stats/resampling.py +203 -0
- openstat/stats/survey.py +213 -0
- openstat/stats/survival.py +196 -0
- openstat/stats/timeseries.py +142 -0
- openstat/stats/ts_advanced.py +114 -0
- openstat/types.py +11 -0
- openstat/web/__init__.py +1 -0
- openstat/web/app.py +117 -0
- openstat/web/session_manager.py +73 -0
- openstat/web/static/app.js +117 -0
- openstat/web/static/index.html +38 -0
- openstat/web/static/style.css +103 -0
- openstat_cli-1.0.0.dist-info/METADATA +748 -0
- openstat_cli-1.0.0.dist-info/RECORD +143 -0
- openstat_cli-1.0.0.dist-info/WHEEL +4 -0
- openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
- openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1040 @@
|
|
|
1
|
+
"""Statistics commands: summarize, tabulate, groupby, corr, ols, logit, ttest, chi2, anova."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
import polars as pl
|
|
8
|
+
from rich.console import Console
|
|
9
|
+
from rich.table import Table
|
|
10
|
+
|
|
11
|
+
from openstat.session import Session, ModelResult
|
|
12
|
+
from openstat.config import get_config
|
|
13
|
+
from openstat.dsl.parser import parse_formula, ParseError
|
|
14
|
+
from openstat.stats.models import (
|
|
15
|
+
fit_ols, fit_logit, fit_probit, fit_poisson, fit_negbin, fit_quantreg,
|
|
16
|
+
compute_margins, bootstrap_model,
|
|
17
|
+
run_ttest, run_chi2, run_anova,
|
|
18
|
+
compute_vif, stepwise_ols, compute_residuals,
|
|
19
|
+
)
|
|
20
|
+
from openstat.commands.base import command, CommandArgs, rich_to_str, friendly_error
|
|
21
|
+
from openstat.types import NUMERIC_DTYPES
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _store_model(
    session: Session, result, raw_model, dep: str, indeps: list[str],
    fit_kwargs: dict | None = None,
) -> str:
    """Store a fitted model in session state and return its summary text.

    Records the raw statsmodels object and the (dep, indeps) pair so that
    follow-up commands (predict, residuals, vif, latex) can reuse them,
    appends a ModelResult entry to ``session.results``, and returns the
    result's summary table (plus any fit warnings) for display.

    Args:
        session: Active session whose ``_last_*`` slots are overwritten.
        raw_model: The underlying fitted model object (e.g. statsmodels).
        dep: Dependent variable name.
        indeps: Independent variable names as parsed from the formula.
        fit_kwargs: Extra fit options to remember (e.g. exposure column);
            stored as ``{}`` when None.
    """
    # Overwrite the "last model" slots — only one model is remembered.
    session._last_model = raw_model
    session._last_model_vars = (dep, indeps)
    session._last_fit_result = result
    session._last_fit_kwargs = fit_kwargs or {}
    md = result.to_markdown()
    # Always-present scalar diagnostics for the stored result record.
    details: dict = {
        "n_obs": result.n_obs,
        "params": dict(result.params),
        "std_errors": dict(result.std_errors),
        "aic": result.aic,
        "bic": result.bic,
    }
    # Optional metrics: only included when the model type provides them
    # (e.g. r_squared for OLS, pseudo_r2/log_likelihood for MLE models).
    if result.r_squared is not None:
        details["r_squared"] = result.r_squared
    if result.adj_r_squared is not None:
        details["adj_r_squared"] = result.adj_r_squared
    if result.pseudo_r2 is not None:
        details["pseudo_r2"] = result.pseudo_r2
    if result.log_likelihood is not None:
        details["log_likelihood"] = result.log_likelihood
    if result.dispersion is not None:
        details["dispersion"] = result.dispersion
    session.results.append(ModelResult(
        name=result.model_type, formula=result.formula,
        table=md, details=details,
    ))
    output = result.summary_table()
    # Surface any fit warnings beneath the table rather than dropping them.
    if result.warnings:
        output += "\n" + "\n".join(result.warnings)
    return output
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _parse_agg(token: str) -> tuple[str, str | None]:
|
|
62
|
+
m = re.match(r"(\w+)\((\w*)\)", token)
|
|
63
|
+
if not m:
|
|
64
|
+
raise ValueError(f"Invalid aggregation: {token}. Use e.g. mean(col), count()")
|
|
65
|
+
func, col = m.group(1), m.group(2) or None
|
|
66
|
+
return func, col
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@command("summarize", usage="summarize [col1 col2 ...]")
|
|
70
|
+
def cmd_summarize(session: Session, args: str) -> str:
|
|
71
|
+
"""Compute summary statistics for numeric columns (SD = sample, ddof=1)."""
|
|
72
|
+
df = session.require_data()
|
|
73
|
+
cols = args.split() if args.strip() else None
|
|
74
|
+
|
|
75
|
+
if cols:
|
|
76
|
+
missing = [c for c in cols if c not in df.columns]
|
|
77
|
+
if missing:
|
|
78
|
+
return f"Columns not found: {', '.join(missing)}"
|
|
79
|
+
num_cols = [c for c in cols if df[c].dtype in NUMERIC_DTYPES]
|
|
80
|
+
else:
|
|
81
|
+
num_cols = [c for c in df.columns if df[c].dtype in NUMERIC_DTYPES]
|
|
82
|
+
|
|
83
|
+
if not num_cols:
|
|
84
|
+
return "No numeric columns to summarize."
|
|
85
|
+
|
|
86
|
+
def render(console: Console) -> None:
|
|
87
|
+
table = Table(title="Summary Statistics")
|
|
88
|
+
table.add_column("Variable", style="cyan")
|
|
89
|
+
for stat in ["N", "Mean", "SD (sample)", "Min", "P25", "P50", "P75", "Max"]:
|
|
90
|
+
table.add_column(stat, justify="right")
|
|
91
|
+
|
|
92
|
+
for c in num_cols:
|
|
93
|
+
col = df[c].drop_nulls()
|
|
94
|
+
n = col.len()
|
|
95
|
+
if n == 0:
|
|
96
|
+
table.add_row(c, "0", *["—"] * 7)
|
|
97
|
+
continue
|
|
98
|
+
sd_val = col.std() if n > 1 else 0.0
|
|
99
|
+
table.add_row(
|
|
100
|
+
c, str(n),
|
|
101
|
+
f"{col.mean():.4f}", f"{sd_val:.4f}", f"{col.min():.4f}",
|
|
102
|
+
f"{col.quantile(0.25):.4f}", f"{col.quantile(0.50):.4f}",
|
|
103
|
+
f"{col.quantile(0.75):.4f}", f"{col.max():.4f}",
|
|
104
|
+
)
|
|
105
|
+
console.print(table)
|
|
106
|
+
|
|
107
|
+
return rich_to_str(render)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@command("tabulate", usage="tabulate <column>")
|
|
111
|
+
def cmd_tabulate(session: Session, args: str) -> str:
|
|
112
|
+
"""Show frequency table for a column (top 50 values by default)."""
|
|
113
|
+
df = session.require_data()
|
|
114
|
+
col = args.strip()
|
|
115
|
+
if not col:
|
|
116
|
+
return "Usage: tabulate <column>"
|
|
117
|
+
if col not in df.columns:
|
|
118
|
+
return f"Column not found: {col}"
|
|
119
|
+
|
|
120
|
+
tab_limit = get_config().tabulate_limit
|
|
121
|
+
counts = (
|
|
122
|
+
df.group_by(col).len()
|
|
123
|
+
.sort("len", descending=True)
|
|
124
|
+
.rename({"len": "count"})
|
|
125
|
+
)
|
|
126
|
+
total = counts["count"].sum()
|
|
127
|
+
total_unique = counts.height
|
|
128
|
+
truncated = total_unique > tab_limit
|
|
129
|
+
|
|
130
|
+
if truncated:
|
|
131
|
+
counts = counts.head(tab_limit)
|
|
132
|
+
|
|
133
|
+
counts = counts.with_columns(
|
|
134
|
+
(pl.col("count") / total * 100).round(1).alias("percent")
|
|
135
|
+
)
|
|
136
|
+
|
|
137
|
+
def render(console: Console) -> None:
|
|
138
|
+
title = f"Frequency: {col}"
|
|
139
|
+
if truncated:
|
|
140
|
+
title += f" (top {tab_limit} of {total_unique} unique values)"
|
|
141
|
+
table = Table(title=title)
|
|
142
|
+
table.add_column(col, style="cyan")
|
|
143
|
+
table.add_column("Count", justify="right")
|
|
144
|
+
table.add_column("Percent", justify="right")
|
|
145
|
+
|
|
146
|
+
for row in counts.iter_rows(named=True):
|
|
147
|
+
table.add_row(str(row[col]), str(row["count"]), f"{row['percent']:.1f}%")
|
|
148
|
+
table.add_row("Total", str(total), "100.0%", style="bold")
|
|
149
|
+
console.print(table)
|
|
150
|
+
|
|
151
|
+
return rich_to_str(render)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@command("corr", usage="corr [col1 col2 ...]")
|
|
155
|
+
def cmd_corr(session: Session, args: str) -> str:
|
|
156
|
+
"""Show correlation matrix for numeric columns."""
|
|
157
|
+
df = session.require_data()
|
|
158
|
+
cols = args.split() if args.strip() else None
|
|
159
|
+
|
|
160
|
+
if cols:
|
|
161
|
+
missing = [c for c in cols if c not in df.columns]
|
|
162
|
+
if missing:
|
|
163
|
+
return f"Columns not found: {', '.join(missing)}"
|
|
164
|
+
num_cols = [c for c in cols if df[c].dtype in NUMERIC_DTYPES]
|
|
165
|
+
else:
|
|
166
|
+
num_cols = [c for c in df.columns if df[c].dtype in NUMERIC_DTYPES]
|
|
167
|
+
|
|
168
|
+
if len(num_cols) < 2:
|
|
169
|
+
return "Need at least 2 numeric columns for correlation."
|
|
170
|
+
|
|
171
|
+
sub = df.select(num_cols).drop_nulls()
|
|
172
|
+
# Compute pairwise correlation
|
|
173
|
+
corr_data: dict[str, list[float]] = {}
|
|
174
|
+
for c1 in num_cols:
|
|
175
|
+
row = []
|
|
176
|
+
for c2 in num_cols:
|
|
177
|
+
r = sub.select(pl.corr(c1, c2)).item()
|
|
178
|
+
row.append(r if r is not None else 0.0)
|
|
179
|
+
corr_data[c1] = row
|
|
180
|
+
|
|
181
|
+
def render(console: Console) -> None:
|
|
182
|
+
table = Table(title="Correlation Matrix (Pearson)")
|
|
183
|
+
table.add_column("", style="cyan")
|
|
184
|
+
for c in num_cols:
|
|
185
|
+
table.add_column(c, justify="right")
|
|
186
|
+
for i, c1 in enumerate(num_cols):
|
|
187
|
+
vals = []
|
|
188
|
+
for j, c2 in enumerate(num_cols):
|
|
189
|
+
r = corr_data[c1][j]
|
|
190
|
+
# Highlight strong correlations
|
|
191
|
+
if i == j:
|
|
192
|
+
vals.append("1.0000")
|
|
193
|
+
elif abs(r) > 0.7:
|
|
194
|
+
vals.append(f"[bold]{r:.4f}[/bold]")
|
|
195
|
+
else:
|
|
196
|
+
vals.append(f"{r:.4f}")
|
|
197
|
+
table.add_row(c1, *vals)
|
|
198
|
+
console.print(table)
|
|
199
|
+
|
|
200
|
+
return rich_to_str(render)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@command("groupby", usage="groupby <cols> summarize <agg(col)> ...")
|
|
204
|
+
def cmd_groupby(session: Session, args: str) -> str:
|
|
205
|
+
"""Group-by and summarize."""
|
|
206
|
+
df = session.require_data()
|
|
207
|
+
|
|
208
|
+
ca = CommandArgs(args)
|
|
209
|
+
summarize_rest = ca.rest_after("summarize")
|
|
210
|
+
if summarize_rest is None:
|
|
211
|
+
return "Usage: groupby <col1> <col2> summarize mean(x) sd(x) count()"
|
|
212
|
+
|
|
213
|
+
# Group cols are positional tokens before "summarize"
|
|
214
|
+
group_cols = []
|
|
215
|
+
for p in ca.positional:
|
|
216
|
+
if p.lower() == "summarize":
|
|
217
|
+
break
|
|
218
|
+
group_cols.append(p)
|
|
219
|
+
agg_tokens = summarize_rest.split()
|
|
220
|
+
|
|
221
|
+
if not group_cols:
|
|
222
|
+
return "No grouping columns specified."
|
|
223
|
+
if not agg_tokens:
|
|
224
|
+
return "No aggregation functions specified."
|
|
225
|
+
|
|
226
|
+
missing = [c for c in group_cols if c not in df.columns]
|
|
227
|
+
if missing:
|
|
228
|
+
return f"Columns not found: {', '.join(missing)}"
|
|
229
|
+
|
|
230
|
+
AGG_MAP = {
|
|
231
|
+
"mean": lambda c: pl.col(c).mean().alias(f"mean_{c}"),
|
|
232
|
+
"sd": lambda c: pl.col(c).std().alias(f"sd_{c}"),
|
|
233
|
+
"sum": lambda c: pl.col(c).sum().alias(f"sum_{c}"),
|
|
234
|
+
"min": lambda c: pl.col(c).min().alias(f"min_{c}"),
|
|
235
|
+
"max": lambda c: pl.col(c).max().alias(f"max_{c}"),
|
|
236
|
+
"median": lambda c: pl.col(c).median().alias(f"median_{c}"),
|
|
237
|
+
"count": lambda _: pl.len().alias("count"),
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
agg_exprs = []
|
|
241
|
+
for tok in agg_tokens:
|
|
242
|
+
func_name, col_name = _parse_agg(tok)
|
|
243
|
+
if func_name not in AGG_MAP:
|
|
244
|
+
return f"Unknown aggregation: {func_name}. Available: {', '.join(AGG_MAP)}"
|
|
245
|
+
if func_name != "count" and col_name is None:
|
|
246
|
+
return f"{func_name}() requires a column name, e.g. {func_name}(col)"
|
|
247
|
+
if col_name and col_name not in df.columns:
|
|
248
|
+
return f"Column not found: {col_name}"
|
|
249
|
+
agg_exprs.append(AGG_MAP[func_name](col_name))
|
|
250
|
+
|
|
251
|
+
result = df.group_by(group_cols).agg(agg_exprs).sort(group_cols)
|
|
252
|
+
|
|
253
|
+
def render(console: Console) -> None:
|
|
254
|
+
table = Table(title="Group Summary")
|
|
255
|
+
for col_name in result.columns:
|
|
256
|
+
table.add_column(col_name, justify="right" if col_name not in group_cols else "left")
|
|
257
|
+
for row in result.iter_rows():
|
|
258
|
+
table.add_row(*[f"{v:.4f}" if isinstance(v, float) else str(v) for v in row])
|
|
259
|
+
console.print(table)
|
|
260
|
+
|
|
261
|
+
return rich_to_str(render)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@command("ols", usage="ols y ~ x1 + x2 [--robust] [--cluster=col]")
|
|
265
|
+
def cmd_ols(session: Session, args: str) -> str:
|
|
266
|
+
"""Fit OLS regression."""
|
|
267
|
+
df = session.require_data()
|
|
268
|
+
ca = CommandArgs(args)
|
|
269
|
+
robust = ca.has_flag("--robust")
|
|
270
|
+
cluster_col = ca.get_option("cluster")
|
|
271
|
+
formula_str = ca.strip_flags_and_options()
|
|
272
|
+
if not formula_str:
|
|
273
|
+
return "Usage: ols y ~ x1 + x2 [--robust] [--cluster=col]"
|
|
274
|
+
try:
|
|
275
|
+
dep, indeps = parse_formula(formula_str)
|
|
276
|
+
result, raw_model = fit_ols(df, dep, indeps, robust=robust, cluster_col=cluster_col)
|
|
277
|
+
return _store_model(session, result, raw_model, dep, indeps)
|
|
278
|
+
except ParseError as e:
|
|
279
|
+
return f"Formula error: {e}"
|
|
280
|
+
except Exception as e:
|
|
281
|
+
return friendly_error(e, "OLS error")
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
@command("logit", usage="logit y ~ x1 + x2 [--robust] [--cluster=col]")
|
|
285
|
+
def cmd_logit(session: Session, args: str) -> str:
|
|
286
|
+
"""Fit logistic regression (binary dependent variable)."""
|
|
287
|
+
df = session.require_data()
|
|
288
|
+
ca = CommandArgs(args)
|
|
289
|
+
robust = ca.has_flag("--robust")
|
|
290
|
+
cluster_col = ca.get_option("cluster")
|
|
291
|
+
formula_str = ca.strip_flags_and_options()
|
|
292
|
+
if not formula_str:
|
|
293
|
+
return "Usage: logit y ~ x1 + x2 [--robust] [--cluster=col]"
|
|
294
|
+
try:
|
|
295
|
+
dep, indeps = parse_formula(formula_str)
|
|
296
|
+
result, raw_model = fit_logit(df, dep, indeps, robust=robust, cluster_col=cluster_col)
|
|
297
|
+
return _store_model(session, result, raw_model, dep, indeps)
|
|
298
|
+
except ParseError as e:
|
|
299
|
+
return f"Formula error: {e}"
|
|
300
|
+
except Exception as e:
|
|
301
|
+
return friendly_error(e, "Logit error")
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@command("ttest", usage="ttest <col> [by <group>] [mu=<value>] [paired <col2>]")
|
|
305
|
+
def cmd_ttest(session: Session, args: str) -> str:
|
|
306
|
+
"""T-test: one-sample, two-sample (Welch), or paired."""
|
|
307
|
+
df = session.require_data()
|
|
308
|
+
ca = CommandArgs(args)
|
|
309
|
+
if not ca.positional:
|
|
310
|
+
return (
|
|
311
|
+
"Usage:\n"
|
|
312
|
+
" ttest <col> One-sample (H0: mean=0)\n"
|
|
313
|
+
" ttest <col> mu=5 One-sample (H0: mean=5)\n"
|
|
314
|
+
" ttest <col> by <group> Two-sample (Welch)\n"
|
|
315
|
+
" ttest <col> paired <col2> Paired t-test"
|
|
316
|
+
)
|
|
317
|
+
|
|
318
|
+
col = ca.positional[0]
|
|
319
|
+
mu = ca.get_option_float("mu", 0.0)
|
|
320
|
+
|
|
321
|
+
by = None
|
|
322
|
+
by_rest = ca.rest_after("by")
|
|
323
|
+
if by_rest:
|
|
324
|
+
by = by_rest.split()[0]
|
|
325
|
+
|
|
326
|
+
paired_col = None
|
|
327
|
+
paired_rest = ca.rest_after("paired")
|
|
328
|
+
if paired_rest:
|
|
329
|
+
paired_col = paired_rest.split()[0]
|
|
330
|
+
|
|
331
|
+
try:
|
|
332
|
+
result = run_ttest(df, col, by=by, mu=mu, paired_col=paired_col)
|
|
333
|
+
return result.summary_table()
|
|
334
|
+
except Exception as e:
|
|
335
|
+
return friendly_error(e, "T-test error")
|
|
336
|
+
|
|
337
|
+
|
|
338
|
+
@command("chi2", usage="chi2 <col1> <col2>")
|
|
339
|
+
def cmd_chi2(session: Session, args: str) -> str:
|
|
340
|
+
"""Chi-square test of independence between two categorical columns."""
|
|
341
|
+
df = session.require_data()
|
|
342
|
+
parts = args.split()
|
|
343
|
+
if len(parts) < 2:
|
|
344
|
+
return "Usage: chi2 <col1> <col2>"
|
|
345
|
+
|
|
346
|
+
try:
|
|
347
|
+
result = run_chi2(df, parts[0], parts[1])
|
|
348
|
+
return result.summary_table()
|
|
349
|
+
except Exception as e:
|
|
350
|
+
return friendly_error(e, "Chi-square error")
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
@command("anova", usage="anova <col> by <group>")
|
|
354
|
+
def cmd_anova(session: Session, args: str) -> str:
|
|
355
|
+
"""One-way ANOVA: test if group means differ."""
|
|
356
|
+
df = session.require_data()
|
|
357
|
+
ca = CommandArgs(args)
|
|
358
|
+
by_str = ca.rest_after("by")
|
|
359
|
+
if not by_str:
|
|
360
|
+
return "Usage: anova <col> by <group_col>"
|
|
361
|
+
col = ca.positional[0] if ca.positional else ""
|
|
362
|
+
by_col = by_str.split()[0]
|
|
363
|
+
if not col or not by_col:
|
|
364
|
+
return "Usage: anova <col> by <group_col>"
|
|
365
|
+
try:
|
|
366
|
+
result = run_anova(df, col, by_col)
|
|
367
|
+
return result.summary_table()
|
|
368
|
+
except Exception as e:
|
|
369
|
+
return friendly_error(e, "ANOVA error")
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
@command("crosstab", usage="crosstab <row_col> <col_col>")
|
|
373
|
+
def cmd_crosstab(session: Session, args: str) -> str:
|
|
374
|
+
"""Two-way frequency table (contingency table) with row percentages."""
|
|
375
|
+
df = session.require_data()
|
|
376
|
+
parts = args.split()
|
|
377
|
+
if len(parts) < 2:
|
|
378
|
+
return "Usage: crosstab <row_col> <col_col>"
|
|
379
|
+
|
|
380
|
+
row_col, col_col = parts[0], parts[1]
|
|
381
|
+
for c in (row_col, col_col):
|
|
382
|
+
if c not in df.columns:
|
|
383
|
+
return f"Column not found: {c}"
|
|
384
|
+
|
|
385
|
+
sub = df.select([row_col, col_col]).drop_nulls()
|
|
386
|
+
ct = sub.group_by([row_col, col_col]).len().rename({"len": "count"})
|
|
387
|
+
|
|
388
|
+
rows = sorted(sub[row_col].unique().to_list(), key=str)
|
|
389
|
+
cols = sorted(sub[col_col].unique().to_list(), key=str)
|
|
390
|
+
|
|
391
|
+
# Build count matrix
|
|
392
|
+
count_map: dict[tuple, int] = {}
|
|
393
|
+
for r in ct.iter_rows(named=True):
|
|
394
|
+
count_map[(r[row_col], r[col_col])] = r["count"]
|
|
395
|
+
|
|
396
|
+
def render(console: Console) -> None:
|
|
397
|
+
table = Table(title=f"Cross-tabulation: {row_col} x {col_col}")
|
|
398
|
+
table.add_column(row_col, style="cyan")
|
|
399
|
+
for c in cols:
|
|
400
|
+
table.add_column(str(c), justify="right")
|
|
401
|
+
table.add_column("Total", justify="right", style="bold")
|
|
402
|
+
|
|
403
|
+
for row_val in rows:
|
|
404
|
+
row_total = sum(count_map.get((row_val, c), 0) for c in cols)
|
|
405
|
+
cells = []
|
|
406
|
+
for c in cols:
|
|
407
|
+
cnt = count_map.get((row_val, c), 0)
|
|
408
|
+
pct = cnt / row_total * 100 if row_total > 0 else 0
|
|
409
|
+
cells.append(f"{cnt} ({pct:.0f}%)")
|
|
410
|
+
table.add_row(str(row_val), *cells, str(row_total))
|
|
411
|
+
|
|
412
|
+
# Total row
|
|
413
|
+
col_totals = [sum(count_map.get((r, c), 0) for r in rows) for c in cols]
|
|
414
|
+
grand_total = sum(col_totals)
|
|
415
|
+
table.add_row(
|
|
416
|
+
"Total",
|
|
417
|
+
*[str(t) for t in col_totals],
|
|
418
|
+
str(grand_total),
|
|
419
|
+
style="bold",
|
|
420
|
+
)
|
|
421
|
+
console.print(table)
|
|
422
|
+
|
|
423
|
+
return rich_to_str(render)
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
@command("probit", usage="probit y ~ x1 + x2 [--robust] [--cluster=col]")
|
|
427
|
+
def cmd_probit(session: Session, args: str) -> str:
|
|
428
|
+
"""Fit probit regression (binary dependent variable)."""
|
|
429
|
+
df = session.require_data()
|
|
430
|
+
ca = CommandArgs(args)
|
|
431
|
+
robust = ca.has_flag("--robust")
|
|
432
|
+
cluster_col = ca.get_option("cluster")
|
|
433
|
+
formula_str = ca.strip_flags_and_options()
|
|
434
|
+
if not formula_str:
|
|
435
|
+
return "Usage: probit y ~ x1 + x2 [--robust] [--cluster=col]"
|
|
436
|
+
try:
|
|
437
|
+
dep, indeps = parse_formula(formula_str)
|
|
438
|
+
result, raw_model = fit_probit(df, dep, indeps, robust=robust, cluster_col=cluster_col)
|
|
439
|
+
return _store_model(session, result, raw_model, dep, indeps)
|
|
440
|
+
except ParseError as e:
|
|
441
|
+
return f"Formula error: {e}"
|
|
442
|
+
except Exception as e:
|
|
443
|
+
return friendly_error(e, "Probit error")
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
@command("predict", usage="predict [<col_name>]")
|
|
447
|
+
def cmd_predict(session: Session, args: str) -> str:
|
|
448
|
+
"""Generate predictions from the last fitted model, add as a new column."""
|
|
449
|
+
import statsmodels.api as sm
|
|
450
|
+
from openstat.stats.models import _build_X_from_indeps
|
|
451
|
+
|
|
452
|
+
df = session.require_data()
|
|
453
|
+
if session._last_model is None or session._last_model_vars is None:
|
|
454
|
+
return "No model fitted yet. Run ols, logit, or probit first."
|
|
455
|
+
|
|
456
|
+
col_name = args.strip() or "yhat"
|
|
457
|
+
dep, indeps = session._last_model_vars
|
|
458
|
+
|
|
459
|
+
# Collect all base columns needed (including interaction components)
|
|
460
|
+
all_base: set[str] = set()
|
|
461
|
+
for v in indeps:
|
|
462
|
+
if ":" in v:
|
|
463
|
+
all_base.update(v.split(":"))
|
|
464
|
+
else:
|
|
465
|
+
all_base.add(v)
|
|
466
|
+
missing = [c for c in all_base if c not in df.columns]
|
|
467
|
+
if missing:
|
|
468
|
+
return f"Predictor columns not found in current data: {', '.join(missing)}"
|
|
469
|
+
|
|
470
|
+
model = session._last_model
|
|
471
|
+
X = _build_X_from_indeps(df, indeps)
|
|
472
|
+
X = sm.add_constant(X)
|
|
473
|
+
preds = model.predict(X)
|
|
474
|
+
|
|
475
|
+
session.snapshot()
|
|
476
|
+
session.df = df.with_columns(pl.Series(col_name, preds.tolist()).cast(pl.Float64))
|
|
477
|
+
return f"Predictions added as '{col_name}'. {session.shape_str}. Use 'undo' to revert."
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
@command("vif", usage="vif")
|
|
481
|
+
def cmd_vif(session: Session, args: str) -> str:
|
|
482
|
+
"""Show Variance Inflation Factor for the last fitted model's predictors."""
|
|
483
|
+
df = session.require_data()
|
|
484
|
+
if session._last_model_vars is None:
|
|
485
|
+
return "No model fitted yet. Run ols first, then vif."
|
|
486
|
+
|
|
487
|
+
dep, indeps = session._last_model_vars
|
|
488
|
+
if len(indeps) < 2:
|
|
489
|
+
return "VIF requires at least 2 predictors."
|
|
490
|
+
|
|
491
|
+
try:
|
|
492
|
+
vifs = compute_vif(df, indeps)
|
|
493
|
+
|
|
494
|
+
def render(console: Console) -> None:
|
|
495
|
+
table = Table(title="Variance Inflation Factor")
|
|
496
|
+
table.add_column("Variable", style="cyan")
|
|
497
|
+
table.add_column("VIF", justify="right")
|
|
498
|
+
table.add_column("Status", style="green")
|
|
499
|
+
|
|
500
|
+
for var, vif_val in vifs:
|
|
501
|
+
if vif_val > 10:
|
|
502
|
+
status = "[red]HIGH[/red]"
|
|
503
|
+
elif vif_val > 5:
|
|
504
|
+
status = "[yellow]moderate[/yellow]"
|
|
505
|
+
else:
|
|
506
|
+
status = "ok"
|
|
507
|
+
table.add_row(var, f"{vif_val:.2f}", status)
|
|
508
|
+
console.print(table)
|
|
509
|
+
console.print("Rule of thumb: VIF > 10 indicates serious multicollinearity")
|
|
510
|
+
|
|
511
|
+
return rich_to_str(render)
|
|
512
|
+
except Exception as e:
|
|
513
|
+
return friendly_error(e, "VIF error")
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
@command("stepwise", usage="stepwise y ~ x1 + x2 + x3 [--backward] [--p_enter=0.05] [--p_remove=0.10]")
|
|
517
|
+
def cmd_stepwise(session: Session, args: str) -> str:
|
|
518
|
+
"""Run stepwise OLS regression for variable selection."""
|
|
519
|
+
df = session.require_data()
|
|
520
|
+
if not args.strip():
|
|
521
|
+
return (
|
|
522
|
+
"Usage: stepwise y ~ x1 + x2 + x3 [--backward]\n"
|
|
523
|
+
"Options: --p_enter=0.05 --p_remove=0.10"
|
|
524
|
+
)
|
|
525
|
+
|
|
526
|
+
ca = CommandArgs(args)
|
|
527
|
+
direction = "backward" if ca.has_flag("--backward") else "forward"
|
|
528
|
+
p_enter = ca.get_option_float("p_enter", 0.05)
|
|
529
|
+
p_remove = ca.get_option_float("p_remove", 0.10)
|
|
530
|
+
formula_str = ca.strip_flags_and_options()
|
|
531
|
+
try:
|
|
532
|
+
dep, indeps = parse_formula(formula_str)
|
|
533
|
+
result = stepwise_ols(
|
|
534
|
+
df, dep, indeps, direction=direction,
|
|
535
|
+
p_enter=p_enter, p_remove=p_remove,
|
|
536
|
+
)
|
|
537
|
+
session._last_model = None # stepwise doesn't store a single model
|
|
538
|
+
session._last_model_vars = (dep, result.selected)
|
|
539
|
+
return result.summary()
|
|
540
|
+
except ParseError as e:
|
|
541
|
+
return f"Formula error: {e}"
|
|
542
|
+
except Exception as e:
|
|
543
|
+
return friendly_error(e, "Stepwise error")
|
|
544
|
+
|
|
545
|
+
|
|
546
|
+
@command("residuals", usage="residuals [<col_name>]")
|
|
547
|
+
def cmd_residuals(session: Session, args: str) -> str:
|
|
548
|
+
"""Add residuals from the last model as a new column. Generates diagnostic plots."""
|
|
549
|
+
df = session.require_data()
|
|
550
|
+
if session._last_model is None or session._last_model_vars is None:
|
|
551
|
+
return "No model fitted yet. Run ols first, then residuals."
|
|
552
|
+
|
|
553
|
+
col_name = args.strip() or "residuals"
|
|
554
|
+
dep, indeps = session._last_model_vars
|
|
555
|
+
|
|
556
|
+
try:
|
|
557
|
+
diag = compute_residuals(session._last_model, df, dep, indeps)
|
|
558
|
+
except Exception as e:
|
|
559
|
+
return friendly_error(e, "Residuals error")
|
|
560
|
+
|
|
561
|
+
session.snapshot()
|
|
562
|
+
session.df = df.with_columns(
|
|
563
|
+
pl.Series(col_name, diag["residuals"].tolist()).cast(pl.Float64)
|
|
564
|
+
)
|
|
565
|
+
|
|
566
|
+
# Generate diagnostic plots
|
|
567
|
+
from openstat.plots.plotter import plot_residuals_vs_fitted, plot_qq, plot_scale_location
|
|
568
|
+
paths = []
|
|
569
|
+
try:
|
|
570
|
+
paths.append(plot_residuals_vs_fitted(diag["fitted"], diag["residuals"], session.output_dir))
|
|
571
|
+
paths.append(plot_qq(diag["std_residuals"], session.output_dir))
|
|
572
|
+
paths.append(plot_scale_location(diag["fitted"], diag["std_residuals"], session.output_dir))
|
|
573
|
+
session.plot_paths.extend(str(p) for p in paths)
|
|
574
|
+
except Exception:
|
|
575
|
+
pass # plots are optional
|
|
576
|
+
|
|
577
|
+
lines = [f"Residuals added as '{col_name}'. {session.shape_str}."]
|
|
578
|
+
if paths:
|
|
579
|
+
lines.append("Diagnostic plots saved:")
|
|
580
|
+
for p in paths:
|
|
581
|
+
lines.append(f" {p}")
|
|
582
|
+
lines.append("Use 'undo' to revert.")
|
|
583
|
+
return "\n".join(lines)
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
@command("latex", usage="latex [<path.tex>]")
|
|
587
|
+
def cmd_latex(session: Session, args: str) -> str:
|
|
588
|
+
"""Export the last model result as a LaTeX table."""
|
|
589
|
+
if session._last_fit_result is None:
|
|
590
|
+
return "No model results to export. Run ols, logit, or probit first."
|
|
591
|
+
|
|
592
|
+
from openstat.stats.models import FitResult
|
|
593
|
+
result: FitResult = session._last_fit_result # type: ignore[assignment]
|
|
594
|
+
latex_str = result.to_latex()
|
|
595
|
+
|
|
596
|
+
path = args.strip()
|
|
597
|
+
if path:
|
|
598
|
+
from pathlib import Path as _Path
|
|
599
|
+
p = _Path(path)
|
|
600
|
+
p.parent.mkdir(parents=True, exist_ok=True)
|
|
601
|
+
p.write_text(latex_str, encoding="utf-8")
|
|
602
|
+
return f"LaTeX table saved to {p}"
|
|
603
|
+
return latex_str
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
@command("poisson", usage="poisson y ~ x1 + x2 [--robust] [--cluster=col] [--exposure=col]")
|
|
607
|
+
def cmd_poisson(session: Session, args: str) -> str:
|
|
608
|
+
"""Fit Poisson regression for count data."""
|
|
609
|
+
df = session.require_data()
|
|
610
|
+
ca = CommandArgs(args)
|
|
611
|
+
robust = ca.has_flag("--robust")
|
|
612
|
+
cluster_col = ca.get_option("cluster")
|
|
613
|
+
exposure_col = ca.get_option("exposure")
|
|
614
|
+
formula_str = ca.strip_flags_and_options()
|
|
615
|
+
if not formula_str:
|
|
616
|
+
return "Usage: poisson y ~ x1 + x2 [--robust] [--exposure=col]"
|
|
617
|
+
try:
|
|
618
|
+
dep, indeps = parse_formula(formula_str)
|
|
619
|
+
result, raw_model = fit_poisson(
|
|
620
|
+
df, dep, indeps, robust=robust,
|
|
621
|
+
cluster_col=cluster_col, exposure_col=exposure_col,
|
|
622
|
+
)
|
|
623
|
+
kw = {"exposure_col": exposure_col} if exposure_col else {}
|
|
624
|
+
return _store_model(session, result, raw_model, dep, indeps, fit_kwargs=kw)
|
|
625
|
+
except ParseError as e:
|
|
626
|
+
return f"Formula error: {e}"
|
|
627
|
+
except Exception as e:
|
|
628
|
+
return friendly_error(e, "Poisson error")
|
|
629
|
+
|
|
630
|
+
|
|
631
|
+
@command("negbin", usage="negbin y ~ x1 + x2 [--robust] [--cluster=col]")
|
|
632
|
+
def cmd_negbin(session: Session, args: str) -> str:
|
|
633
|
+
"""Fit Negative Binomial regression for overdispersed count data."""
|
|
634
|
+
df = session.require_data()
|
|
635
|
+
ca = CommandArgs(args)
|
|
636
|
+
robust = ca.has_flag("--robust")
|
|
637
|
+
cluster_col = ca.get_option("cluster")
|
|
638
|
+
formula_str = ca.strip_flags_and_options()
|
|
639
|
+
if not formula_str:
|
|
640
|
+
return "Usage: negbin y ~ x1 + x2 [--robust]"
|
|
641
|
+
try:
|
|
642
|
+
dep, indeps = parse_formula(formula_str)
|
|
643
|
+
result, raw_model = fit_negbin(
|
|
644
|
+
df, dep, indeps, robust=robust, cluster_col=cluster_col,
|
|
645
|
+
)
|
|
646
|
+
return _store_model(session, result, raw_model, dep, indeps)
|
|
647
|
+
except ParseError as e:
|
|
648
|
+
return f"Formula error: {e}"
|
|
649
|
+
except Exception as e:
|
|
650
|
+
return friendly_error(e, "NegBin error")
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
@command("quantreg", usage="quantreg y ~ x1 + x2 [tau=0.5]")
|
|
654
|
+
def cmd_quantreg(session: Session, args: str) -> str:
|
|
655
|
+
"""Fit quantile regression (default: median, tau=0.5)."""
|
|
656
|
+
df = session.require_data()
|
|
657
|
+
ca = CommandArgs(args)
|
|
658
|
+
tau = ca.get_option_float("tau", 0.5)
|
|
659
|
+
formula_str = ca.strip_flags_and_options()
|
|
660
|
+
if not formula_str:
|
|
661
|
+
return "Usage: quantreg y ~ x1 + x2 [tau=0.5]"
|
|
662
|
+
try:
|
|
663
|
+
dep, indeps = parse_formula(formula_str)
|
|
664
|
+
result, raw_model = fit_quantreg(df, dep, indeps, tau=tau)
|
|
665
|
+
return _store_model(session, result, raw_model, dep, indeps, fit_kwargs={"tau": tau})
|
|
666
|
+
except ParseError as e:
|
|
667
|
+
return f"Formula error: {e}"
|
|
668
|
+
except Exception as e:
|
|
669
|
+
return friendly_error(e, "QuantReg error")
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
@command("margins", usage="margins [--at=means|average]")
|
|
673
|
+
def cmd_margins(session: Session, args: str) -> str:
|
|
674
|
+
"""Compute marginal effects after logit or probit."""
|
|
675
|
+
if session._last_model is None:
|
|
676
|
+
return "No model fitted. Run logit or probit first."
|
|
677
|
+
if not hasattr(session._last_model, "get_margeff"):
|
|
678
|
+
return "Marginal effects only available for logit/probit models."
|
|
679
|
+
|
|
680
|
+
ca = CommandArgs(args)
|
|
681
|
+
method = ca.get_option("at", "average") or "average"
|
|
682
|
+
|
|
683
|
+
try:
|
|
684
|
+
# Build var_names from last fit result
|
|
685
|
+
fit_result = session._last_fit_result
|
|
686
|
+
var_names = list(fit_result.params.keys()) if fit_result else [] # type: ignore[union-attr]
|
|
687
|
+
result = compute_margins(session._last_model, var_names, method)
|
|
688
|
+
session._last_margins = result
|
|
689
|
+
return result.summary_table()
|
|
690
|
+
except Exception as e:
|
|
691
|
+
return friendly_error(e, "Margins error")
|
|
692
|
+
|
|
693
|
+
|
|
694
|
+
@command("bootstrap", usage="bootstrap [n=1000] [ci=95]")
|
|
695
|
+
def cmd_bootstrap(session: Session, args: str) -> str:
|
|
696
|
+
"""Bootstrap confidence intervals for the last fitted model."""
|
|
697
|
+
if session._last_fit_result is None:
|
|
698
|
+
return "No model fitted. Run a model command first."
|
|
699
|
+
if session._last_model_vars is None:
|
|
700
|
+
return "No model fitted. Run a model command first."
|
|
701
|
+
|
|
702
|
+
ca = CommandArgs(args)
|
|
703
|
+
n_boot = int(ca.get_option_float("n", float(get_config().bootstrap_iterations)))
|
|
704
|
+
ci = ca.get_option_float("ci", 95.0)
|
|
705
|
+
|
|
706
|
+
dep, indeps = session._last_model_vars
|
|
707
|
+
|
|
708
|
+
# Determine which fit function to use from last model type
|
|
709
|
+
fit_fn_map = {
|
|
710
|
+
"OLS": fit_ols,
|
|
711
|
+
"Logit": fit_logit,
|
|
712
|
+
"Probit": fit_probit,
|
|
713
|
+
"Poisson": fit_poisson,
|
|
714
|
+
"NegBin": fit_negbin,
|
|
715
|
+
}
|
|
716
|
+
model_type = session._last_fit_result.model_type.split()[0] # type: ignore[union-attr]
|
|
717
|
+
# Handle QuantReg(tau=0.5) format
|
|
718
|
+
if model_type.startswith("QuantReg"):
|
|
719
|
+
fit_fn = fit_quantreg
|
|
720
|
+
else:
|
|
721
|
+
fit_fn = fit_fn_map.get(model_type) # type: ignore[assignment]
|
|
722
|
+
|
|
723
|
+
if fit_fn is None:
|
|
724
|
+
return f"Bootstrap not supported for model type: {model_type}"
|
|
725
|
+
|
|
726
|
+
try:
|
|
727
|
+
result = bootstrap_model(
|
|
728
|
+
session.require_data(), dep, indeps, fit_fn, n_boot, ci,
|
|
729
|
+
**session._last_fit_kwargs,
|
|
730
|
+
)
|
|
731
|
+
return result.summary_table()
|
|
732
|
+
except Exception as e:
|
|
733
|
+
return friendly_error(e, "Bootstrap error")
|
|
734
|
+
|
|
735
|
+
|
|
736
|
+
# ---------------------------------------------------------------------------
|
|
737
|
+
# estat — post-estimation diagnostics
|
|
738
|
+
# ---------------------------------------------------------------------------
|
|
739
|
+
|
|
740
|
+
@command("estat", usage="estat <subcommand> (hettest | ovtest | linktest | ic | icc | firststage | overid | endogtest | phtest | deff | all)")
|
|
741
|
+
def cmd_estat(session: Session, args: str) -> str:
|
|
742
|
+
"""Post-estimation diagnostics (Stata-style).
|
|
743
|
+
|
|
744
|
+
Subcommands:
|
|
745
|
+
hettest — Breusch-Pagan / Cook-Weisberg heteroscedasticity test
|
|
746
|
+
ovtest — Ramsey RESET specification test
|
|
747
|
+
linktest — link test for model specification
|
|
748
|
+
ic — Information criteria (AIC, BIC, log-likelihood)
|
|
749
|
+
icc — Intraclass Correlation Coefficient (after mixed)
|
|
750
|
+
firststage — First-stage diagnostics (after ivregress)
|
|
751
|
+
overid — Overidentification test (after ivregress)
|
|
752
|
+
endogtest — Endogeneity test (after ivregress)
|
|
753
|
+
phtest — Proportional hazards test (after stcox)
|
|
754
|
+
deff — Design effect (after svy:)
|
|
755
|
+
all — Run all OLS diagnostics
|
|
756
|
+
"""
|
|
757
|
+
import statsmodels.api as sm
|
|
758
|
+
from openstat.stats.models import _build_X_from_indeps
|
|
759
|
+
|
|
760
|
+
if session._last_model is None or session._last_model_vars is None:
|
|
761
|
+
return "No model fitted. Run a model command first."
|
|
762
|
+
|
|
763
|
+
sub_cmd = args.strip().lower()
|
|
764
|
+
if not sub_cmd:
|
|
765
|
+
return (
|
|
766
|
+
"Usage: estat <subcommand>\n"
|
|
767
|
+
" hettest Breusch-Pagan heteroscedasticity test\n"
|
|
768
|
+
" ovtest Ramsey RESET specification test\n"
|
|
769
|
+
" linktest Link test for model specification\n"
|
|
770
|
+
" ic Information criteria (AIC, BIC)\n"
|
|
771
|
+
" icc Intraclass Correlation Coefficient (after mixed)\n"
|
|
772
|
+
" firststage First-stage diagnostics (after ivregress)\n"
|
|
773
|
+
" overid Overidentification test (after ivregress)\n"
|
|
774
|
+
" endogtest Endogeneity test (after ivregress)\n"
|
|
775
|
+
" phtest Proportional hazards test (after stcox)\n"
|
|
776
|
+
" deff Design effect (after svy:)\n"
|
|
777
|
+
" all Run all OLS diagnostics"
|
|
778
|
+
)
|
|
779
|
+
|
|
780
|
+
# --- New feature-specific subcommands ---
|
|
781
|
+
if sub_cmd == "icc":
|
|
782
|
+
try:
|
|
783
|
+
from openstat.stats.mixed import compute_icc
|
|
784
|
+
icc_val = compute_icc(session._last_model)
|
|
785
|
+
return f"Intraclass Correlation Coefficient (ICC): {icc_val:.4f}\n ICC = var(random effect) / (var(random effect) + var(residual))"
|
|
786
|
+
except Exception as e:
|
|
787
|
+
return friendly_error(e, "estat icc (requires a mixed model)")
|
|
788
|
+
|
|
789
|
+
if sub_cmd == "firststage":
|
|
790
|
+
try:
|
|
791
|
+
from openstat.stats.iv import first_stage_diagnostics
|
|
792
|
+
return first_stage_diagnostics(session._last_model)
|
|
793
|
+
except Exception as e:
|
|
794
|
+
return friendly_error(e, "estat firststage (requires an IV model)")
|
|
795
|
+
|
|
796
|
+
if sub_cmd == "overid":
|
|
797
|
+
try:
|
|
798
|
+
from openstat.stats.iv import overidentification_test
|
|
799
|
+
return overidentification_test(session._last_model)
|
|
800
|
+
except Exception as e:
|
|
801
|
+
return friendly_error(e, "estat overid (requires an IV model)")
|
|
802
|
+
|
|
803
|
+
if sub_cmd == "endogtest":
|
|
804
|
+
try:
|
|
805
|
+
from openstat.stats.iv import endogeneity_test
|
|
806
|
+
return endogeneity_test(session._last_model)
|
|
807
|
+
except Exception as e:
|
|
808
|
+
return friendly_error(e, "estat endogtest (requires an IV model)")
|
|
809
|
+
|
|
810
|
+
if sub_cmd == "phtest":
|
|
811
|
+
try:
|
|
812
|
+
from openstat.stats.survival import schoenfeld_test
|
|
813
|
+
return schoenfeld_test(session._last_model)
|
|
814
|
+
except Exception as e:
|
|
815
|
+
return friendly_error(e, "estat phtest (requires a Cox PH model)")
|
|
816
|
+
|
|
817
|
+
if sub_cmd == "deff":
|
|
818
|
+
df = session.require_data()
|
|
819
|
+
if session._svy_weight_var is None:
|
|
820
|
+
return "Survey design not set. Use svyset first."
|
|
821
|
+
dep, _ = session._last_model_vars
|
|
822
|
+
try:
|
|
823
|
+
from openstat.stats.survey import compute_deff
|
|
824
|
+
deff = compute_deff(df, dep, session._svy_weight_var,
|
|
825
|
+
session._svy_psu_var, session._svy_strata_var)
|
|
826
|
+
return f"Design Effect (DEFF) for {dep}: {deff:.4f}\n DEFF > 1 indicates clustering increases variance relative to SRS"
|
|
827
|
+
except Exception as e:
|
|
828
|
+
return friendly_error(e, "estat deff")
|
|
829
|
+
|
|
830
|
+
model = session._last_model
|
|
831
|
+
dep, indeps = session._last_model_vars
|
|
832
|
+
df = session.require_data()
|
|
833
|
+
|
|
834
|
+
# Build data aligned with model
|
|
835
|
+
all_base: set[str] = set()
|
|
836
|
+
for v in indeps:
|
|
837
|
+
if ":" in v:
|
|
838
|
+
all_base.update(v.split(":"))
|
|
839
|
+
else:
|
|
840
|
+
all_base.add(v)
|
|
841
|
+
cols_needed = [dep] + sorted(all_base)
|
|
842
|
+
sub_df = df.select(cols_needed).drop_nulls()
|
|
843
|
+
y = sub_df[dep].to_numpy().astype(float)
|
|
844
|
+
X = _build_X_from_indeps(sub_df, indeps)
|
|
845
|
+
X = sm.add_constant(X)
|
|
846
|
+
|
|
847
|
+
results: list[str] = []
|
|
848
|
+
|
|
849
|
+
# ── hettest ────────────────────────────────────────────────────────
|
|
850
|
+
if sub_cmd in ("hettest", "all"):
|
|
851
|
+
try:
|
|
852
|
+
from statsmodels.stats.diagnostic import het_breuschpagan
|
|
853
|
+
resid = y - model.predict(X)
|
|
854
|
+
bp_stat, bp_pval, f_stat, f_pval = het_breuschpagan(resid, X)
|
|
855
|
+
|
|
856
|
+
def render_het(console: Console) -> None:
|
|
857
|
+
t = Table(title="Breusch-Pagan / Cook-Weisberg Test for Heteroscedasticity")
|
|
858
|
+
t.add_column("Metric", style="cyan")
|
|
859
|
+
t.add_column("Value", justify="right")
|
|
860
|
+
t.add_row("LM statistic", f"{bp_stat:.4f}")
|
|
861
|
+
t.add_row("LM p-value", f"{bp_pval:.6f}")
|
|
862
|
+
t.add_row("F statistic", f"{f_stat:.4f}")
|
|
863
|
+
t.add_row("F p-value", f"{f_pval:.6f}")
|
|
864
|
+
console.print(t)
|
|
865
|
+
sig = "Heteroscedasticity detected" if bp_pval < 0.05 else "No significant heteroscedasticity"
|
|
866
|
+
console.print(f"Result: {sig} at alpha = 0.05")
|
|
867
|
+
|
|
868
|
+
results.append(rich_to_str(render_het))
|
|
869
|
+
except Exception as e:
|
|
870
|
+
results.append(f"hettest failed: {e}")
|
|
871
|
+
|
|
872
|
+
# ── ovtest (Ramsey RESET) ──────────────────────────────────────────
|
|
873
|
+
if sub_cmd in ("ovtest", "all"):
|
|
874
|
+
try:
|
|
875
|
+
from statsmodels.stats.diagnostic import linear_reset
|
|
876
|
+
fitted_model = sm.OLS(y, X).fit()
|
|
877
|
+
reset_test = linear_reset(fitted_model, power=3, use_f=True)
|
|
878
|
+
|
|
879
|
+
def render_ov(console: Console) -> None:
|
|
880
|
+
t = Table(title="Ramsey RESET Test (powers 2-3)")
|
|
881
|
+
t.add_column("Metric", style="cyan")
|
|
882
|
+
t.add_column("Value", justify="right")
|
|
883
|
+
t.add_row("F statistic", f"{reset_test.fvalue:.4f}")
|
|
884
|
+
t.add_row("p-value", f"{reset_test.pvalue:.6f}")
|
|
885
|
+
t.add_row("df", f"({int(reset_test.df_num)}, {int(reset_test.df_denom)})")
|
|
886
|
+
console.print(t)
|
|
887
|
+
sig = "Specification error detected" if reset_test.pvalue < 0.05 else "No specification error"
|
|
888
|
+
console.print(f"Result: {sig} at alpha = 0.05")
|
|
889
|
+
|
|
890
|
+
results.append(rich_to_str(render_ov))
|
|
891
|
+
except Exception as e:
|
|
892
|
+
results.append(f"ovtest failed: {e}")
|
|
893
|
+
|
|
894
|
+
# ── linktest ───────────────────────────────────────────────────────
|
|
895
|
+
if sub_cmd in ("linktest", "all"):
|
|
896
|
+
try:
|
|
897
|
+
yhat = model.predict(X)
|
|
898
|
+
yhat_sq = yhat ** 2
|
|
899
|
+
import numpy as _np
|
|
900
|
+
X_link = sm.add_constant(_np.column_stack([yhat, yhat_sq]))
|
|
901
|
+
link_model = sm.OLS(y, X_link).fit()
|
|
902
|
+
p_hatsq = float(link_model.pvalues[2])
|
|
903
|
+
|
|
904
|
+
def render_link(console: Console) -> None:
|
|
905
|
+
t = Table(title="Link Test for Model Specification")
|
|
906
|
+
t.add_column("Variable", style="cyan")
|
|
907
|
+
t.add_column("Coef", justify="right")
|
|
908
|
+
t.add_column("Std.Err", justify="right")
|
|
909
|
+
t.add_column("t", justify="right")
|
|
910
|
+
t.add_column("P>|t|", justify="right")
|
|
911
|
+
names = ["_cons", "_hat", "_hatsq"]
|
|
912
|
+
for i, name in enumerate(names):
|
|
913
|
+
t.add_row(
|
|
914
|
+
name,
|
|
915
|
+
f"{link_model.params[i]:.4f}",
|
|
916
|
+
f"{link_model.bse[i]:.4f}",
|
|
917
|
+
f"{link_model.tvalues[i]:.3f}",
|
|
918
|
+
f"{link_model.pvalues[i]:.4f}",
|
|
919
|
+
)
|
|
920
|
+
console.print(t)
|
|
921
|
+
if p_hatsq < 0.05:
|
|
922
|
+
console.print("Note: _hatsq is significant — possible specification error.")
|
|
923
|
+
else:
|
|
924
|
+
console.print("Note: _hatsq is not significant — model appears well-specified.")
|
|
925
|
+
|
|
926
|
+
results.append(rich_to_str(render_link))
|
|
927
|
+
except Exception as e:
|
|
928
|
+
results.append(f"linktest failed: {e}")
|
|
929
|
+
|
|
930
|
+
# ── ic (information criteria) ──────────────────────────────────────
|
|
931
|
+
if sub_cmd in ("ic", "all"):
|
|
932
|
+
try:
|
|
933
|
+
def render_ic(console: Console) -> None:
|
|
934
|
+
t = Table(title="Information Criteria")
|
|
935
|
+
t.add_column("Criterion", style="cyan")
|
|
936
|
+
t.add_column("Value", justify="right")
|
|
937
|
+
if hasattr(model, "aic"):
|
|
938
|
+
t.add_row("AIC", f"{model.aic:.2f}")
|
|
939
|
+
if hasattr(model, "bic"):
|
|
940
|
+
t.add_row("BIC", f"{model.bic:.2f}")
|
|
941
|
+
if hasattr(model, "llf"):
|
|
942
|
+
t.add_row("Log-Likelihood", f"{model.llf:.2f}")
|
|
943
|
+
if hasattr(model, "nobs"):
|
|
944
|
+
t.add_row("N", str(int(model.nobs)))
|
|
945
|
+
if hasattr(model, "df_model"):
|
|
946
|
+
t.add_row("df (model)", str(int(model.df_model)))
|
|
947
|
+
console.print(t)
|
|
948
|
+
|
|
949
|
+
results.append(rich_to_str(render_ic))
|
|
950
|
+
except Exception as e:
|
|
951
|
+
results.append(f"ic failed: {e}")
|
|
952
|
+
|
|
953
|
+
if not results:
|
|
954
|
+
return f"Unknown estat subcommand: {sub_cmd}. Use: hettest, ovtest, linktest, ic, icc, firststage, overid, endogtest, phtest, deff, all"
|
|
955
|
+
|
|
956
|
+
return "\n\n".join(results)
|
|
957
|
+
|
|
958
|
+
|
|
959
|
+
# ---------------------------------------------------------------------------
|
|
960
|
+
# estimates table — model comparison
|
|
961
|
+
# ---------------------------------------------------------------------------
|
|
962
|
+
|
|
963
|
+
@command("estimates", usage="estimates table")
|
|
964
|
+
def cmd_estimates(session: Session, args: str) -> str:
|
|
965
|
+
"""Compare stored model results side-by-side."""
|
|
966
|
+
sub = args.strip().lower()
|
|
967
|
+
if sub != "table":
|
|
968
|
+
return "Usage: estimates table"
|
|
969
|
+
|
|
970
|
+
model_results = [r for r in session.results if hasattr(r, "formula")]
|
|
971
|
+
if len(model_results) < 2:
|
|
972
|
+
return "Need at least 2 stored model results. Run multiple models first."
|
|
973
|
+
|
|
974
|
+
# Collect all variable names across all models (preserving order)
|
|
975
|
+
all_vars: list[str] = []
|
|
976
|
+
seen_vars: set[str] = set()
|
|
977
|
+
for mr in model_results:
|
|
978
|
+
params = mr.details.get("params", {})
|
|
979
|
+
for var in params:
|
|
980
|
+
if var not in seen_vars:
|
|
981
|
+
seen_vars.add(var)
|
|
982
|
+
all_vars.append(var)
|
|
983
|
+
|
|
984
|
+
def render(console: Console) -> None:
|
|
985
|
+
table = Table(title="Model Comparison")
|
|
986
|
+
table.add_column("", style="cyan")
|
|
987
|
+
for i, mr in enumerate(model_results):
|
|
988
|
+
label = f"({i + 1}) {mr.name}"
|
|
989
|
+
table.add_column(label, justify="right")
|
|
990
|
+
|
|
991
|
+
# Coefficient rows with SE in parentheses
|
|
992
|
+
for var in all_vars:
|
|
993
|
+
vals = []
|
|
994
|
+
for mr in model_results:
|
|
995
|
+
params = mr.details.get("params", {})
|
|
996
|
+
se = mr.details.get("std_errors", {})
|
|
997
|
+
if var in params:
|
|
998
|
+
coef = params[var]
|
|
999
|
+
se_val = se.get(var)
|
|
1000
|
+
cell = f"{coef:.4f}"
|
|
1001
|
+
if se_val is not None:
|
|
1002
|
+
cell += f"\n({se_val:.4f})"
|
|
1003
|
+
vals.append(cell)
|
|
1004
|
+
else:
|
|
1005
|
+
vals.append("—")
|
|
1006
|
+
table.add_row(var, *vals)
|
|
1007
|
+
|
|
1008
|
+
# Separator
|
|
1009
|
+
table.add_section()
|
|
1010
|
+
|
|
1011
|
+
# Model statistics
|
|
1012
|
+
stat_keys = [
|
|
1013
|
+
("n_obs", "N"), ("r_squared", "R²"), ("adj_r_squared", "Adj. R²"),
|
|
1014
|
+
("pseudo_r2", "Pseudo R²"), ("log_likelihood", "Log-Lik."),
|
|
1015
|
+
("aic", "AIC"), ("bic", "BIC"), ("dispersion", "Dispersion (α)"),
|
|
1016
|
+
]
|
|
1017
|
+
|
|
1018
|
+
for key, label in stat_keys:
|
|
1019
|
+
vals = []
|
|
1020
|
+
any_present = False
|
|
1021
|
+
for mr in model_results:
|
|
1022
|
+
v = mr.details.get(key)
|
|
1023
|
+
if v is not None:
|
|
1024
|
+
any_present = True
|
|
1025
|
+
if isinstance(v, float):
|
|
1026
|
+
if key == "n_obs":
|
|
1027
|
+
vals.append(str(int(v)))
|
|
1028
|
+
else:
|
|
1029
|
+
vals.append(f"{v:.4f}")
|
|
1030
|
+
else:
|
|
1031
|
+
vals.append(str(v))
|
|
1032
|
+
else:
|
|
1033
|
+
vals.append("—")
|
|
1034
|
+
if any_present:
|
|
1035
|
+
table.add_row(label, *vals)
|
|
1036
|
+
|
|
1037
|
+
console.print(table)
|
|
1038
|
+
console.print("Standard errors in parentheses")
|
|
1039
|
+
|
|
1040
|
+
return rich_to_str(render)
|