openstat-cli 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openstat/__init__.py +3 -0
- openstat/__main__.py +4 -0
- openstat/backends/__init__.py +16 -0
- openstat/backends/duckdb_backend.py +70 -0
- openstat/backends/polars_backend.py +52 -0
- openstat/cli.py +92 -0
- openstat/commands/__init__.py +82 -0
- openstat/commands/adv_stat_cmds.py +1255 -0
- openstat/commands/advanced_ml_cmds.py +576 -0
- openstat/commands/advreg_cmds.py +207 -0
- openstat/commands/alias_cmds.py +135 -0
- openstat/commands/arch_cmds.py +82 -0
- openstat/commands/arules_cmds.py +111 -0
- openstat/commands/automodel_cmds.py +212 -0
- openstat/commands/backend_cmds.py +82 -0
- openstat/commands/base.py +170 -0
- openstat/commands/bayes_cmds.py +71 -0
- openstat/commands/causal_cmds.py +269 -0
- openstat/commands/cluster_cmds.py +152 -0
- openstat/commands/data_cmds.py +996 -0
- openstat/commands/datamanip_cmds.py +672 -0
- openstat/commands/dataquality_cmds.py +174 -0
- openstat/commands/datetime_cmds.py +176 -0
- openstat/commands/dimreduce_cmds.py +184 -0
- openstat/commands/discrete_cmds.py +149 -0
- openstat/commands/dsl_cmds.py +143 -0
- openstat/commands/epi_cmds.py +93 -0
- openstat/commands/equiv_tobit_cmds.py +94 -0
- openstat/commands/esttab_cmds.py +196 -0
- openstat/commands/export_beamer_cmds.py +142 -0
- openstat/commands/export_cmds.py +201 -0
- openstat/commands/export_extra_cmds.py +240 -0
- openstat/commands/factor_cmds.py +180 -0
- openstat/commands/groupby_cmds.py +155 -0
- openstat/commands/help_cmds.py +237 -0
- openstat/commands/i18n_cmds.py +43 -0
- openstat/commands/import_extra_cmds.py +561 -0
- openstat/commands/influence_cmds.py +134 -0
- openstat/commands/iv_cmds.py +106 -0
- openstat/commands/manova_cmds.py +105 -0
- openstat/commands/mediate_cmds.py +233 -0
- openstat/commands/meta_cmds.py +284 -0
- openstat/commands/mi_cmds.py +228 -0
- openstat/commands/mixed_cmds.py +79 -0
- openstat/commands/mixture_changepoint_cmds.py +166 -0
- openstat/commands/ml_adv_cmds.py +147 -0
- openstat/commands/ml_cmds.py +178 -0
- openstat/commands/model_eval_cmds.py +142 -0
- openstat/commands/network_cmds.py +288 -0
- openstat/commands/nlquery_cmds.py +161 -0
- openstat/commands/nonparam_cmds.py +149 -0
- openstat/commands/outreg_cmds.py +247 -0
- openstat/commands/panel_cmds.py +141 -0
- openstat/commands/pdf_cmds.py +226 -0
- openstat/commands/pipeline_cmds.py +319 -0
- openstat/commands/plot_cmds.py +189 -0
- openstat/commands/plugin_cmds.py +79 -0
- openstat/commands/posthoc_cmds.py +153 -0
- openstat/commands/power_cmds.py +172 -0
- openstat/commands/profile_cmds.py +246 -0
- openstat/commands/rbridge_cmds.py +81 -0
- openstat/commands/regex_cmds.py +104 -0
- openstat/commands/report_cmds.py +48 -0
- openstat/commands/repro_cmds.py +129 -0
- openstat/commands/resampling_cmds.py +109 -0
- openstat/commands/reshape_cmds.py +223 -0
- openstat/commands/sem_cmds.py +177 -0
- openstat/commands/stat_cmds.py +1040 -0
- openstat/commands/stata_import_cmds.py +215 -0
- openstat/commands/string_cmds.py +124 -0
- openstat/commands/surv_cmds.py +145 -0
- openstat/commands/survey_cmds.py +153 -0
- openstat/commands/textanalysis_cmds.py +192 -0
- openstat/commands/ts_adv_cmds.py +136 -0
- openstat/commands/ts_cmds.py +195 -0
- openstat/commands/tui_cmds.py +111 -0
- openstat/commands/ux_cmds.py +191 -0
- openstat/commands/validate_cmds.py +270 -0
- openstat/commands/viz_adv_cmds.py +312 -0
- openstat/commands/viz_extra_cmds.py +251 -0
- openstat/commands/watch_cmds.py +69 -0
- openstat/config.py +106 -0
- openstat/dsl/__init__.py +0 -0
- openstat/dsl/parser.py +332 -0
- openstat/dsl/tokenizer.py +105 -0
- openstat/i18n.py +120 -0
- openstat/io/__init__.py +0 -0
- openstat/io/loader.py +187 -0
- openstat/jupyter/__init__.py +18 -0
- openstat/jupyter/display.py +18 -0
- openstat/jupyter/magic.py +60 -0
- openstat/logging_config.py +59 -0
- openstat/plots/__init__.py +0 -0
- openstat/plots/plotter.py +437 -0
- openstat/plots/surv_plots.py +32 -0
- openstat/plots/ts_plots.py +59 -0
- openstat/plugins/__init__.py +5 -0
- openstat/plugins/manager.py +69 -0
- openstat/repl.py +457 -0
- openstat/reporting/__init__.py +0 -0
- openstat/reporting/eda.py +208 -0
- openstat/reporting/report.py +67 -0
- openstat/script_runner.py +319 -0
- openstat/session.py +133 -0
- openstat/stats/__init__.py +0 -0
- openstat/stats/advanced_regression.py +269 -0
- openstat/stats/arch_garch.py +84 -0
- openstat/stats/bayesian.py +103 -0
- openstat/stats/causal.py +258 -0
- openstat/stats/clustering.py +206 -0
- openstat/stats/discrete.py +311 -0
- openstat/stats/epidemiology.py +119 -0
- openstat/stats/equiv_tobit.py +163 -0
- openstat/stats/factor.py +174 -0
- openstat/stats/imputation.py +282 -0
- openstat/stats/influence.py +78 -0
- openstat/stats/iv.py +131 -0
- openstat/stats/manova.py +124 -0
- openstat/stats/mixed.py +128 -0
- openstat/stats/ml.py +275 -0
- openstat/stats/ml_advanced.py +117 -0
- openstat/stats/model_eval.py +183 -0
- openstat/stats/models.py +1342 -0
- openstat/stats/nonparametric.py +130 -0
- openstat/stats/panel.py +179 -0
- openstat/stats/power.py +295 -0
- openstat/stats/resampling.py +203 -0
- openstat/stats/survey.py +213 -0
- openstat/stats/survival.py +196 -0
- openstat/stats/timeseries.py +142 -0
- openstat/stats/ts_advanced.py +114 -0
- openstat/types.py +11 -0
- openstat/web/__init__.py +1 -0
- openstat/web/app.py +117 -0
- openstat/web/session_manager.py +73 -0
- openstat/web/static/app.js +117 -0
- openstat/web/static/index.html +38 -0
- openstat/web/static/style.css +103 -0
- openstat_cli-1.0.0.dist-info/METADATA +748 -0
- openstat_cli-1.0.0.dist-info/RECORD +143 -0
- openstat_cli-1.0.0.dist-info/WHEEL +4 -0
- openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
- openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
"""Causal inference commands: did, psmatch."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
from openstat.session import Session, ModelResult
|
|
8
|
+
from openstat.dsl.parser import ParseError
|
|
9
|
+
from openstat.stats.causal import fit_did, fit_psm
|
|
10
|
+
from openstat.commands.base import command, CommandArgs, friendly_error
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _store_model(session, result, raw_model, dep, indeps):
    """Record a fitted model on the session and build its text summary.

    Saves the raw fitted model plus fit metadata on the session (so that
    follow-up commands such as predict/margins can reuse it), appends a
    ModelResult record to ``session.results``, and returns the model's
    summary table with any fit warnings appended.
    """
    # Stash the raw fit and its context for downstream commands.
    session._last_model = raw_model
    session._last_model_vars = (dep, indeps)
    session._last_fit_result = result
    session._last_fit_kwargs = {}

    detail_map: dict = {
        "n_obs": result.n_obs,
        "params": dict(result.params),
        "std_errors": dict(result.std_errors),
    }
    # R-squared is only meaningful for some model families.
    if result.r_squared is not None:
        detail_map["r_squared"] = result.r_squared

    record = ModelResult(
        name=result.model_type,
        formula=result.formula,
        table=result.to_markdown(),
        details=detail_map,
    )
    session.results.append(record)

    summary = result.summary_table()
    if result.warnings:
        summary = summary + "\n" + "\n".join(result.warnings)
    return summary
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@command("did", usage="did y ~ treatment_var time_var [--robust] [--cluster=col]")
|
|
38
|
+
def cmd_did(session: Session, args: str) -> str:
|
|
39
|
+
"""Difference-in-Differences estimation."""
|
|
40
|
+
df = session.require_data()
|
|
41
|
+
ca = CommandArgs(args)
|
|
42
|
+
robust = ca.has_flag("--robust")
|
|
43
|
+
cluster_col = ca.get_option("cluster")
|
|
44
|
+
formula_str = ca.strip_flags_and_options()
|
|
45
|
+
|
|
46
|
+
if not formula_str or "~" not in formula_str:
|
|
47
|
+
return "Usage: did y ~ treatment_var time_var [--robust] [--cluster=col]"
|
|
48
|
+
|
|
49
|
+
try:
|
|
50
|
+
# Parse: y ~ treatment_var time_var
|
|
51
|
+
left, right = formula_str.split("~", 1)
|
|
52
|
+
dep = left.strip()
|
|
53
|
+
if not dep:
|
|
54
|
+
return "Usage: did y ~ treatment_var time_var"
|
|
55
|
+
|
|
56
|
+
rhs_vars = right.strip().split()
|
|
57
|
+
if len(rhs_vars) < 2:
|
|
58
|
+
return "Usage: did y ~ treatment_var time_var (need both treatment and time variables)"
|
|
59
|
+
|
|
60
|
+
treatment_col_name = rhs_vars[0]
|
|
61
|
+
time_col_name = rhs_vars[1]
|
|
62
|
+
|
|
63
|
+
result, raw_model = fit_did(
|
|
64
|
+
df, dep, treatment_col_name, time_col_name,
|
|
65
|
+
robust=robust, cluster_col=cluster_col,
|
|
66
|
+
)
|
|
67
|
+
return _store_model(session, result, raw_model, dep, [treatment_col_name, time_col_name])
|
|
68
|
+
except Exception as e:
|
|
69
|
+
return friendly_error(e, "DiD error")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
@command("psmatch", usage="psmatch outcome ~ covars, treatment(tvar) [caliper(0.1)] [nn(3)]")
|
|
73
|
+
def cmd_psmatch(session: Session, args: str) -> str:
|
|
74
|
+
"""Propensity Score Matching."""
|
|
75
|
+
df = session.require_data()
|
|
76
|
+
|
|
77
|
+
# Parse treatment(var)
|
|
78
|
+
treat_match = re.search(r'treatment\((\w+)\)', args)
|
|
79
|
+
if not treat_match:
|
|
80
|
+
return "Usage: psmatch outcome ~ x1 x2, treatment(tvar) [caliper(0.1)] [nn(3)]"
|
|
81
|
+
treatment_col = treat_match.group(1)
|
|
82
|
+
|
|
83
|
+
# Parse optional caliper(value)
|
|
84
|
+
caliper = None
|
|
85
|
+
caliper_match = re.search(r'caliper\(([^)]+)\)', args)
|
|
86
|
+
if caliper_match:
|
|
87
|
+
try:
|
|
88
|
+
caliper = float(caliper_match.group(1))
|
|
89
|
+
except ValueError:
|
|
90
|
+
return f"Invalid caliper value: {caliper_match.group(1)}"
|
|
91
|
+
|
|
92
|
+
# Parse optional nn(value)
|
|
93
|
+
n_neighbors = 1
|
|
94
|
+
nn_match = re.search(r'nn\((\d+)\)', args)
|
|
95
|
+
if nn_match:
|
|
96
|
+
n_neighbors = int(nn_match.group(1))
|
|
97
|
+
|
|
98
|
+
# Strip options to get formula part
|
|
99
|
+
formula_part = args
|
|
100
|
+
for pattern in [r',?\s*treatment\([^)]+\)', r',?\s*caliper\([^)]+\)', r',?\s*nn\(\d+\)']:
|
|
101
|
+
formula_part = re.sub(pattern, '', formula_part)
|
|
102
|
+
formula_part = formula_part.strip().rstrip(',').strip()
|
|
103
|
+
|
|
104
|
+
if not formula_part or "~" not in formula_part:
|
|
105
|
+
return "Usage: psmatch outcome ~ x1 x2, treatment(tvar) [caliper(0.1)] [nn(3)]"
|
|
106
|
+
|
|
107
|
+
try:
|
|
108
|
+
left, right = formula_part.split("~", 1)
|
|
109
|
+
outcome = left.strip()
|
|
110
|
+
covariates = right.strip().split()
|
|
111
|
+
|
|
112
|
+
if not outcome or not covariates:
|
|
113
|
+
return "Usage: psmatch outcome ~ x1 x2, treatment(tvar)"
|
|
114
|
+
|
|
115
|
+
result_str = fit_psm(
|
|
116
|
+
df, outcome, covariates, treatment_col,
|
|
117
|
+
n_neighbors=n_neighbors, caliper=caliper,
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
# Store a simple record
|
|
121
|
+
session.results.append(ModelResult(
|
|
122
|
+
name="PSM", formula=f"{outcome} ~ {' + '.join(covariates)}",
|
|
123
|
+
table=result_str, details={"treatment": treatment_col},
|
|
124
|
+
))
|
|
125
|
+
return result_str
|
|
126
|
+
except Exception as e:
|
|
127
|
+
return friendly_error(e, "PSM error")
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
@command("iptw", usage="iptw <outcome> ~ <covars>, treatment(<tvar>) [--ate|--att] [--trim=0.01]")
|
|
131
|
+
def cmd_iptw(session: Session, args: str) -> str:
|
|
132
|
+
"""Inverse Probability Treatment Weighting (IPTW) for causal inference.
|
|
133
|
+
|
|
134
|
+
Estimates propensity scores via logistic regression, then uses IPT weights
|
|
135
|
+
to estimate the Average Treatment Effect (ATE) or Average Treatment Effect
|
|
136
|
+
on the Treated (ATT) via weighted OLS.
|
|
137
|
+
|
|
138
|
+
Examples:
|
|
139
|
+
iptw score ~ age + income, treatment(employed)
|
|
140
|
+
iptw score ~ age + income, treatment(employed) --att
|
|
141
|
+
iptw score ~ age + income, treatment(employed) --ate --trim=0.05
|
|
142
|
+
"""
|
|
143
|
+
import re
|
|
144
|
+
import numpy as np
|
|
145
|
+
import polars as pl
|
|
146
|
+
import statsmodels.api as sm
|
|
147
|
+
from sklearn.linear_model import LogisticRegression
|
|
148
|
+
|
|
149
|
+
df = session.require_data()
|
|
150
|
+
|
|
151
|
+
# Parse treatment(var)
|
|
152
|
+
treat_m = re.search(r'treatment\((\w+)\)', args)
|
|
153
|
+
if not treat_m:
|
|
154
|
+
return "Usage: iptw outcome ~ x1 + x2, treatment(tvar) [--ate|--att] [--trim=0.01]"
|
|
155
|
+
treatment_col = treat_m.group(1)
|
|
156
|
+
|
|
157
|
+
# Parse estimand
|
|
158
|
+
att = "--att" in args
|
|
159
|
+
estimand = "ATT" if att else "ATE"
|
|
160
|
+
|
|
161
|
+
# Parse trim
|
|
162
|
+
trim = 0.01
|
|
163
|
+
trim_m = re.search(r'--trim[= ]([\d.]+)', args)
|
|
164
|
+
if trim_m:
|
|
165
|
+
trim = float(trim_m.group(1))
|
|
166
|
+
|
|
167
|
+
# Clean formula
|
|
168
|
+
formula_part = re.sub(r',?\s*treatment\([^)]+\)', '', args)
|
|
169
|
+
formula_part = re.sub(r'--\w+(?:[= ][\d.]+)?', '', formula_part).strip()
|
|
170
|
+
|
|
171
|
+
if "~" not in formula_part:
|
|
172
|
+
return "Usage: iptw outcome ~ x1 + x2, treatment(tvar)"
|
|
173
|
+
|
|
174
|
+
lhs, rhs = formula_part.split("~", 1)
|
|
175
|
+
outcome = lhs.strip()
|
|
176
|
+
covars = [c.strip() for c in rhs.replace("+", " ").split() if c.strip()]
|
|
177
|
+
|
|
178
|
+
needed = [outcome, treatment_col] + covars
|
|
179
|
+
missing = [c for c in needed if c not in df.columns]
|
|
180
|
+
if missing:
|
|
181
|
+
return f"Columns not found: {', '.join(missing)}"
|
|
182
|
+
|
|
183
|
+
sub = df.select(needed).drop_nulls()
|
|
184
|
+
y = sub[outcome].to_numpy().astype(float)
|
|
185
|
+
treat = sub[treatment_col].to_numpy().astype(float)
|
|
186
|
+
X = sub.select(covars).to_numpy().astype(float)
|
|
187
|
+
|
|
188
|
+
# Step 1: Propensity score model
|
|
189
|
+
try:
|
|
190
|
+
ps_model = LogisticRegression(max_iter=1000, C=1e6)
|
|
191
|
+
ps_model.fit(X, treat)
|
|
192
|
+
ps = ps_model.predict_proba(X)[:, 1]
|
|
193
|
+
except Exception as exc:
|
|
194
|
+
return f"Propensity score estimation failed: {exc}"
|
|
195
|
+
|
|
196
|
+
# Trim extreme propensity scores
|
|
197
|
+
ps_clipped = np.clip(ps, trim, 1 - trim)
|
|
198
|
+
|
|
199
|
+
# Step 2: Compute IPTW weights
|
|
200
|
+
if estimand == "ATE":
|
|
201
|
+
weights = treat / ps_clipped + (1 - treat) / (1 - ps_clipped)
|
|
202
|
+
else: # ATT
|
|
203
|
+
weights = treat + (1 - treat) * ps_clipped / (1 - ps_clipped)
|
|
204
|
+
|
|
205
|
+
# Step 3: Weighted OLS of outcome on treatment
|
|
206
|
+
X_ols = sm.add_constant(treat)
|
|
207
|
+
try:
|
|
208
|
+
wls_res = sm.WLS(y, X_ols, weights=weights).fit()
|
|
209
|
+
except Exception as exc:
|
|
210
|
+
return f"Weighted regression failed: {exc}"
|
|
211
|
+
|
|
212
|
+
ate_est = wls_res.params[1]
|
|
213
|
+
ate_se = wls_res.bse[1]
|
|
214
|
+
ate_t = wls_res.tvalues[1]
|
|
215
|
+
ate_p = wls_res.pvalues[1]
|
|
216
|
+
ci_low, ci_high = wls_res.conf_int()[1]
|
|
217
|
+
|
|
218
|
+
def _sig(p):
|
|
219
|
+
if p < 0.001: return "***"
|
|
220
|
+
if p < 0.01: return "**"
|
|
221
|
+
if p < 0.05: return "*"
|
|
222
|
+
return ""
|
|
223
|
+
|
|
224
|
+
# Balance assessment (standardized mean differences before/after)
|
|
225
|
+
n_treat = int(treat.sum())
|
|
226
|
+
n_ctrl = int((1 - treat).sum())
|
|
227
|
+
|
|
228
|
+
lines = [
|
|
229
|
+
f"Outcome: {outcome} Treatment: {treatment_col} Estimand: {estimand}",
|
|
230
|
+
f"N = {len(y)} (treated={n_treat}, control={n_ctrl})",
|
|
231
|
+
f"Propensity score trim: [{trim:.3f}, {1-trim:.3f}]",
|
|
232
|
+
"",
|
|
233
|
+
f"IPTW {estimand} Estimate:",
|
|
234
|
+
f" {'Coef':>10} {'SE':>8} {'t':>7} {'p-value':>9} {'95% CI':>20}",
|
|
235
|
+
" " + "-" * 58,
|
|
236
|
+
f" {ate_est:>10.4f} {ate_se:>8.4f} {ate_t:>7.3f} "
|
|
237
|
+
f"{ate_p:>9.4f}{_sig(ate_p)} [{ci_low:.4f}, {ci_high:.4f}]",
|
|
238
|
+
"",
|
|
239
|
+
"Weight Summary:",
|
|
240
|
+
f" Min={weights.min():.3f} Mean={weights.mean():.3f} "
|
|
241
|
+
f"Max={weights.max():.3f} SD={weights.std():.3f}",
|
|
242
|
+
"",
|
|
243
|
+
"Propensity Score Summary:",
|
|
244
|
+
f" Treated — mean={ps[treat==1].mean():.3f} "
|
|
245
|
+
f"min={ps[treat==1].min():.3f} max={ps[treat==1].max():.3f}",
|
|
246
|
+
f" Control — mean={ps[treat==0].mean():.3f} "
|
|
247
|
+
f"min={ps[treat==0].min():.3f} max={ps[treat==0].max():.3f}",
|
|
248
|
+
]
|
|
249
|
+
|
|
250
|
+
# Standardized mean differences (covariate balance)
|
|
251
|
+
lines += ["", "Covariate Balance (Standardized Mean Differences):"]
|
|
252
|
+
lines.append(f" {'Variable':<20} {'Before SMD':>12} {'After SMD':>11} {'Balanced?':>10}")
|
|
253
|
+
lines.append(" " + "-" * 56)
|
|
254
|
+
for j, cname in enumerate(covars):
|
|
255
|
+
xj = X[:, j]
|
|
256
|
+
mu_t = xj[treat == 1].mean()
|
|
257
|
+
mu_c = xj[treat == 0].mean()
|
|
258
|
+
sd_pool = np.sqrt((xj[treat == 1].var() + xj[treat == 0].var()) / 2 + 1e-10)
|
|
259
|
+
smd_before = abs(mu_t - mu_c) / sd_pool
|
|
260
|
+
# After weighting
|
|
261
|
+
mu_t_w = np.average(xj[treat == 1], weights=weights[treat == 1])
|
|
262
|
+
mu_c_w = np.average(xj[treat == 0], weights=weights[treat == 0])
|
|
263
|
+
smd_after = abs(mu_t_w - mu_c_w) / sd_pool
|
|
264
|
+
balanced = "✓" if smd_after < 0.1 else "✗"
|
|
265
|
+
lines.append(
|
|
266
|
+
f" {cname:<20} {smd_before:>12.4f} {smd_after:>11.4f} {balanced:>10}"
|
|
267
|
+
)
|
|
268
|
+
|
|
269
|
+
return "\n" + "=" * 60 + "\nIPTW Causal Estimate\n" + "=" * 60 + "\n" + "\n".join(lines) + "\n" + "=" * 60
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
"""Clustering, MDS, and discriminant analysis commands."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
|
|
7
|
+
from openstat.commands.base import command
|
|
8
|
+
from openstat.session import Session
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _stata_opts(raw: str) -> tuple[list[str], dict[str, str]]:
|
|
12
|
+
opts: dict[str, str] = {}
|
|
13
|
+
for m in re.finditer(r'(\w+)\(([^)]*)\)', raw):
|
|
14
|
+
opts[m.group(1).lower()] = m.group(2)
|
|
15
|
+
rest = re.sub(r'\w+\([^)]*\)', '', raw)
|
|
16
|
+
positional = [t.strip(',') for t in rest.split() if t.strip(',')]
|
|
17
|
+
return positional, opts
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@command("cluster", usage="cluster kmeans|hierarchical varlist [, k(3) linkage(ward)]")
|
|
21
|
+
def cmd_cluster(session: Session, args: str) -> str:
|
|
22
|
+
"""K-means or hierarchical clustering."""
|
|
23
|
+
df = session.require_data()
|
|
24
|
+
positional, opts = _stata_opts(args)
|
|
25
|
+
if len(positional) < 2:
|
|
26
|
+
return (
|
|
27
|
+
"Usage: cluster kmeans varlist [, k(3)]\n"
|
|
28
|
+
" cluster hierarchical varlist [, k(3) linkage(ward)]"
|
|
29
|
+
)
|
|
30
|
+
sub = positional[0].lower()
|
|
31
|
+
cols = [c for c in positional[1:] if c in df.columns]
|
|
32
|
+
if not cols:
|
|
33
|
+
return "No valid columns found."
|
|
34
|
+
|
|
35
|
+
k = int(opts.get("k", 3))
|
|
36
|
+
|
|
37
|
+
try:
|
|
38
|
+
if sub in ("kmeans", "k-means"):
|
|
39
|
+
from openstat.stats.clustering import fit_kmeans
|
|
40
|
+
result = fit_kmeans(df, cols, k=k)
|
|
41
|
+
lines = [f"\nK-Means Clustering (k={k})", "=" * 50]
|
|
42
|
+
lines.append(f" {'N observations':<25} {result['n_obs']}")
|
|
43
|
+
lines.append(f" {'Inertia':<25} {result['inertia']:.4f}")
|
|
44
|
+
lines.append(f" {'Silhouette score':<25} {result['silhouette_score']:.4f}")
|
|
45
|
+
lines.append(f" {'Calinski-Harabasz':<25} {result['calinski_harabasz']:.4f}")
|
|
46
|
+
lines.append("\nCluster sizes:")
|
|
47
|
+
for cl, n in result["cluster_sizes"].items():
|
|
48
|
+
lines.append(f" Cluster {cl + 1}: {n} obs ({n/result['n_obs']*100:.1f}%)")
|
|
49
|
+
lines.append("\nCentroids (original scale):")
|
|
50
|
+
lines.append(" " + f"{'Cluster':<10}" + "".join(f" {c[:8]:>8}" for c in cols))
|
|
51
|
+
for i, cent in enumerate(result["centroids"]):
|
|
52
|
+
row = f" {'Cl.' + str(i+1):<10}"
|
|
53
|
+
for v in cent:
|
|
54
|
+
row += f" {v:>8.3f}"
|
|
55
|
+
lines.append(row)
|
|
56
|
+
lines.append("=" * 50)
|
|
57
|
+
session._last_model = result
|
|
58
|
+
return "\n".join(lines)
|
|
59
|
+
|
|
60
|
+
elif sub in ("hierarchical", "hier", "agglomerative"):
|
|
61
|
+
linkage = opts.get("linkage", "ward")
|
|
62
|
+
from openstat.stats.clustering import fit_hierarchical
|
|
63
|
+
result = fit_hierarchical(df, cols, k=k, linkage=linkage)
|
|
64
|
+
lines = [f"\nHierarchical Clustering (k={k}, linkage={linkage})", "=" * 50]
|
|
65
|
+
lines.append(f" {'N observations':<25} {result['n_obs']}")
|
|
66
|
+
lines.append(f" {'Silhouette score':<25} {result['silhouette_score']:.4f}")
|
|
67
|
+
lines.append(f" {'Calinski-Harabasz':<25} {result['calinski_harabasz']:.4f}")
|
|
68
|
+
lines.append("\nCluster sizes:")
|
|
69
|
+
for cl, n in result["cluster_sizes"].items():
|
|
70
|
+
lines.append(f" Cluster {cl + 1}: {n} obs")
|
|
71
|
+
lines.append("=" * 50)
|
|
72
|
+
session._last_model = result
|
|
73
|
+
return "\n".join(lines)
|
|
74
|
+
|
|
75
|
+
else:
|
|
76
|
+
return f"Unknown cluster method: {sub}. Use 'kmeans' or 'hierarchical'."
|
|
77
|
+
|
|
78
|
+
except ImportError as e:
|
|
79
|
+
return str(e)
|
|
80
|
+
except Exception as exc:
|
|
81
|
+
return f"cluster error: {exc}"
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@command("mds", usage="mds varlist [, n(2) metric]")
|
|
85
|
+
def cmd_mds(session: Session, args: str) -> str:
|
|
86
|
+
"""Multidimensional scaling."""
|
|
87
|
+
df = session.require_data()
|
|
88
|
+
positional, opts = _stata_opts(args)
|
|
89
|
+
cols = [c for c in positional if c in df.columns]
|
|
90
|
+
if len(cols) < 2:
|
|
91
|
+
return "mds requires at least 2 numeric variables."
|
|
92
|
+
n_comp = int(opts.get("n", 2))
|
|
93
|
+
metric = "nonmetric" not in positional
|
|
94
|
+
|
|
95
|
+
try:
|
|
96
|
+
from openstat.stats.clustering import fit_mds
|
|
97
|
+
result = fit_mds(df, cols, n_components=n_comp, metric=metric)
|
|
98
|
+
lines = [f"\nMDS ({'metric' if metric else 'non-metric'})", "=" * 50]
|
|
99
|
+
lines.append(f" {'N observations':<25} {result['n_obs']}")
|
|
100
|
+
lines.append(f" {'N components':<25} {n_comp}")
|
|
101
|
+
lines.append(f" {'Stress':<25} {result['stress']:.6f}")
|
|
102
|
+
lines.append("\nFirst 5 coordinates (Dim 1, Dim 2):")
|
|
103
|
+
for i, coord in enumerate(result["coordinates"][:5]):
|
|
104
|
+
lines.append(" " + " ".join(f"{v:>8.4f}" for v in coord))
|
|
105
|
+
if len(result["coordinates"]) > 5:
|
|
106
|
+
lines.append(f" ... ({len(result['coordinates'])} total rows)")
|
|
107
|
+
lines.append("=" * 50)
|
|
108
|
+
session._last_model = result
|
|
109
|
+
return "\n".join(lines)
|
|
110
|
+
except ImportError as e:
|
|
111
|
+
return str(e)
|
|
112
|
+
except Exception as exc:
|
|
113
|
+
return f"mds error: {exc}"
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@command("discriminant", usage="discriminant groupvar indepvars [, method(lda|qda)]")
|
|
117
|
+
def cmd_discriminant(session: Session, args: str) -> str:
|
|
118
|
+
"""Linear (LDA) or Quadratic (QDA) Discriminant Analysis."""
|
|
119
|
+
df = session.require_data()
|
|
120
|
+
positional, opts = _stata_opts(args)
|
|
121
|
+
if len(positional) < 2:
|
|
122
|
+
return "Usage: discriminant groupvar indepvar1 indepvar2 ... [, method(lda)]"
|
|
123
|
+
|
|
124
|
+
dep = positional[0]
|
|
125
|
+
indeps = [c for c in positional[1:] if c in df.columns]
|
|
126
|
+
method = opts.get("method", "lda").lower()
|
|
127
|
+
|
|
128
|
+
try:
|
|
129
|
+
from openstat.stats.clustering import fit_discriminant
|
|
130
|
+
result = fit_discriminant(df, dep, indeps, method=method)
|
|
131
|
+
lines = [f"\n{result['method']}: {dep}", "=" * 55]
|
|
132
|
+
lines.append(f" {'N observations':<25} {result['n_obs']}")
|
|
133
|
+
lines.append(f" {'N classes':<25} {result['n_classes']}")
|
|
134
|
+
lines.append(f" {'Classes':<25} {', '.join(str(c) for c in result['classes'])}")
|
|
135
|
+
lines.append(f" {'Accuracy (train)':<25} {result['accuracy']:>12.4f}")
|
|
136
|
+
if result.get("priors"):
|
|
137
|
+
lines.append("\nClass priors:")
|
|
138
|
+
for cls, p in zip(result["classes"], result["priors"]):
|
|
139
|
+
lines.append(f" {str(cls):<20} {p:.4f}")
|
|
140
|
+
if "coefficients" in result:
|
|
141
|
+
lines.append("\nDiscriminant function coefficients:")
|
|
142
|
+
for func, coefs in result["coefficients"].items():
|
|
143
|
+
lines.append(f" Function ({func}):")
|
|
144
|
+
for var, val in coefs.items():
|
|
145
|
+
lines.append(f" {var:<22} {val:>10.4f}")
|
|
146
|
+
lines.append("=" * 55)
|
|
147
|
+
session._last_model = result
|
|
148
|
+
return "\n".join(lines)
|
|
149
|
+
except ImportError as e:
|
|
150
|
+
return str(e)
|
|
151
|
+
except Exception as exc:
|
|
152
|
+
return f"discriminant error: {exc}"
|