openstat-cli 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. openstat/__init__.py +3 -0
  2. openstat/__main__.py +4 -0
  3. openstat/backends/__init__.py +16 -0
  4. openstat/backends/duckdb_backend.py +70 -0
  5. openstat/backends/polars_backend.py +52 -0
  6. openstat/cli.py +92 -0
  7. openstat/commands/__init__.py +82 -0
  8. openstat/commands/adv_stat_cmds.py +1255 -0
  9. openstat/commands/advanced_ml_cmds.py +576 -0
  10. openstat/commands/advreg_cmds.py +207 -0
  11. openstat/commands/alias_cmds.py +135 -0
  12. openstat/commands/arch_cmds.py +82 -0
  13. openstat/commands/arules_cmds.py +111 -0
  14. openstat/commands/automodel_cmds.py +212 -0
  15. openstat/commands/backend_cmds.py +82 -0
  16. openstat/commands/base.py +170 -0
  17. openstat/commands/bayes_cmds.py +71 -0
  18. openstat/commands/causal_cmds.py +269 -0
  19. openstat/commands/cluster_cmds.py +152 -0
  20. openstat/commands/data_cmds.py +996 -0
  21. openstat/commands/datamanip_cmds.py +672 -0
  22. openstat/commands/dataquality_cmds.py +174 -0
  23. openstat/commands/datetime_cmds.py +176 -0
  24. openstat/commands/dimreduce_cmds.py +184 -0
  25. openstat/commands/discrete_cmds.py +149 -0
  26. openstat/commands/dsl_cmds.py +143 -0
  27. openstat/commands/epi_cmds.py +93 -0
  28. openstat/commands/equiv_tobit_cmds.py +94 -0
  29. openstat/commands/esttab_cmds.py +196 -0
  30. openstat/commands/export_beamer_cmds.py +142 -0
  31. openstat/commands/export_cmds.py +201 -0
  32. openstat/commands/export_extra_cmds.py +240 -0
  33. openstat/commands/factor_cmds.py +180 -0
  34. openstat/commands/groupby_cmds.py +155 -0
  35. openstat/commands/help_cmds.py +237 -0
  36. openstat/commands/i18n_cmds.py +43 -0
  37. openstat/commands/import_extra_cmds.py +561 -0
  38. openstat/commands/influence_cmds.py +134 -0
  39. openstat/commands/iv_cmds.py +106 -0
  40. openstat/commands/manova_cmds.py +105 -0
  41. openstat/commands/mediate_cmds.py +233 -0
  42. openstat/commands/meta_cmds.py +284 -0
  43. openstat/commands/mi_cmds.py +228 -0
  44. openstat/commands/mixed_cmds.py +79 -0
  45. openstat/commands/mixture_changepoint_cmds.py +166 -0
  46. openstat/commands/ml_adv_cmds.py +147 -0
  47. openstat/commands/ml_cmds.py +178 -0
  48. openstat/commands/model_eval_cmds.py +142 -0
  49. openstat/commands/network_cmds.py +288 -0
  50. openstat/commands/nlquery_cmds.py +161 -0
  51. openstat/commands/nonparam_cmds.py +149 -0
  52. openstat/commands/outreg_cmds.py +247 -0
  53. openstat/commands/panel_cmds.py +141 -0
  54. openstat/commands/pdf_cmds.py +226 -0
  55. openstat/commands/pipeline_cmds.py +319 -0
  56. openstat/commands/plot_cmds.py +189 -0
  57. openstat/commands/plugin_cmds.py +79 -0
  58. openstat/commands/posthoc_cmds.py +153 -0
  59. openstat/commands/power_cmds.py +172 -0
  60. openstat/commands/profile_cmds.py +246 -0
  61. openstat/commands/rbridge_cmds.py +81 -0
  62. openstat/commands/regex_cmds.py +104 -0
  63. openstat/commands/report_cmds.py +48 -0
  64. openstat/commands/repro_cmds.py +129 -0
  65. openstat/commands/resampling_cmds.py +109 -0
  66. openstat/commands/reshape_cmds.py +223 -0
  67. openstat/commands/sem_cmds.py +177 -0
  68. openstat/commands/stat_cmds.py +1040 -0
  69. openstat/commands/stata_import_cmds.py +215 -0
  70. openstat/commands/string_cmds.py +124 -0
  71. openstat/commands/surv_cmds.py +145 -0
  72. openstat/commands/survey_cmds.py +153 -0
  73. openstat/commands/textanalysis_cmds.py +192 -0
  74. openstat/commands/ts_adv_cmds.py +136 -0
  75. openstat/commands/ts_cmds.py +195 -0
  76. openstat/commands/tui_cmds.py +111 -0
  77. openstat/commands/ux_cmds.py +191 -0
  78. openstat/commands/validate_cmds.py +270 -0
  79. openstat/commands/viz_adv_cmds.py +312 -0
  80. openstat/commands/viz_extra_cmds.py +251 -0
  81. openstat/commands/watch_cmds.py +69 -0
  82. openstat/config.py +106 -0
  83. openstat/dsl/__init__.py +0 -0
  84. openstat/dsl/parser.py +332 -0
  85. openstat/dsl/tokenizer.py +105 -0
  86. openstat/i18n.py +120 -0
  87. openstat/io/__init__.py +0 -0
  88. openstat/io/loader.py +187 -0
  89. openstat/jupyter/__init__.py +18 -0
  90. openstat/jupyter/display.py +18 -0
  91. openstat/jupyter/magic.py +60 -0
  92. openstat/logging_config.py +59 -0
  93. openstat/plots/__init__.py +0 -0
  94. openstat/plots/plotter.py +437 -0
  95. openstat/plots/surv_plots.py +32 -0
  96. openstat/plots/ts_plots.py +59 -0
  97. openstat/plugins/__init__.py +5 -0
  98. openstat/plugins/manager.py +69 -0
  99. openstat/repl.py +457 -0
  100. openstat/reporting/__init__.py +0 -0
  101. openstat/reporting/eda.py +208 -0
  102. openstat/reporting/report.py +67 -0
  103. openstat/script_runner.py +319 -0
  104. openstat/session.py +133 -0
  105. openstat/stats/__init__.py +0 -0
  106. openstat/stats/advanced_regression.py +269 -0
  107. openstat/stats/arch_garch.py +84 -0
  108. openstat/stats/bayesian.py +103 -0
  109. openstat/stats/causal.py +258 -0
  110. openstat/stats/clustering.py +206 -0
  111. openstat/stats/discrete.py +311 -0
  112. openstat/stats/epidemiology.py +119 -0
  113. openstat/stats/equiv_tobit.py +163 -0
  114. openstat/stats/factor.py +174 -0
  115. openstat/stats/imputation.py +282 -0
  116. openstat/stats/influence.py +78 -0
  117. openstat/stats/iv.py +131 -0
  118. openstat/stats/manova.py +124 -0
  119. openstat/stats/mixed.py +128 -0
  120. openstat/stats/ml.py +275 -0
  121. openstat/stats/ml_advanced.py +117 -0
  122. openstat/stats/model_eval.py +183 -0
  123. openstat/stats/models.py +1342 -0
  124. openstat/stats/nonparametric.py +130 -0
  125. openstat/stats/panel.py +179 -0
  126. openstat/stats/power.py +295 -0
  127. openstat/stats/resampling.py +203 -0
  128. openstat/stats/survey.py +213 -0
  129. openstat/stats/survival.py +196 -0
  130. openstat/stats/timeseries.py +142 -0
  131. openstat/stats/ts_advanced.py +114 -0
  132. openstat/types.py +11 -0
  133. openstat/web/__init__.py +1 -0
  134. openstat/web/app.py +117 -0
  135. openstat/web/session_manager.py +73 -0
  136. openstat/web/static/app.js +117 -0
  137. openstat/web/static/index.html +38 -0
  138. openstat/web/static/style.css +103 -0
  139. openstat_cli-1.0.0.dist-info/METADATA +748 -0
  140. openstat_cli-1.0.0.dist-info/RECORD +143 -0
  141. openstat_cli-1.0.0.dist-info/WHEEL +4 -0
  142. openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
  143. openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,269 @@
1
+ """Causal inference commands: did, psmatch."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+
7
+ from openstat.session import Session, ModelResult
8
+ from openstat.dsl.parser import ParseError
9
+ from openstat.stats.causal import fit_did, fit_psm
10
+ from openstat.commands.base import command, CommandArgs, friendly_error
11
+
12
+
13
def _store_model(session, result, raw_model, dep, indeps):
    """Record a fitted model on the session and build its printable summary.

    Stashes the raw fit for follow-up commands, appends a ModelResult record
    to ``session.results``, and returns the text summary (plus any warnings).
    """
    # Remember the most recent fit so downstream commands can reuse it.
    session._last_model = raw_model
    session._last_model_vars = (dep, indeps)
    session._last_fit_result = result
    session._last_fit_kwargs = {}

    markdown_table = result.to_markdown()

    detail_map: dict = {
        "n_obs": result.n_obs,
        "params": dict(result.params),
        "std_errors": dict(result.std_errors),
    }
    if result.r_squared is not None:
        detail_map["r_squared"] = result.r_squared

    session.results.append(ModelResult(
        name=result.model_type,
        formula=result.formula,
        table=markdown_table,
        details=detail_map,
    ))

    summary = result.summary_table()
    if result.warnings:
        summary = summary + "\n" + "\n".join(result.warnings)
    return summary
35
+
36
+
37
@command("did", usage="did y ~ treatment_var time_var [--robust] [--cluster=col]")
def cmd_did(session: Session, args: str) -> str:
    """Difference-in-Differences estimation."""
    df = session.require_data()

    parsed = CommandArgs(args)
    use_robust = parsed.has_flag("--robust")
    cluster = parsed.get_option("cluster")
    formula = parsed.strip_flags_and_options()

    # A formula of the shape "y ~ treatment time" is mandatory.
    if not formula or "~" not in formula:
        return "Usage: did y ~ treatment_var time_var [--robust] [--cluster=col]"

    try:
        lhs, rhs = formula.split("~", 1)
        dep = lhs.strip()
        if not dep:
            return "Usage: did y ~ treatment_var time_var"

        rhs_tokens = rhs.strip().split()
        if len(rhs_tokens) < 2:
            return "Usage: did y ~ treatment_var time_var (need both treatment and time variables)"

        # First RHS token is the treatment indicator, second the time indicator.
        treat_var, time_var = rhs_tokens[0], rhs_tokens[1]

        result, raw_model = fit_did(
            df, dep, treat_var, time_var,
            robust=use_robust, cluster_col=cluster,
        )
        return _store_model(session, result, raw_model, dep, [treat_var, time_var])
    except Exception as e:
        return friendly_error(e, "DiD error")
70
+
71
+
72
@command("psmatch", usage="psmatch outcome ~ covars, treatment(tvar) [caliper(0.1)] [nn(3)]")
def cmd_psmatch(session: Session, args: str) -> str:
    """Propensity Score Matching."""
    df = session.require_data()

    # treatment(var) is the one mandatory option.
    treat_m = re.search(r'treatment\((\w+)\)', args)
    if treat_m is None:
        return "Usage: psmatch outcome ~ x1 x2, treatment(tvar) [caliper(0.1)] [nn(3)]"
    treatment_col = treat_m.group(1)

    # Optional caliper(value) — maximum propensity-score distance for a match.
    caliper = None
    cal_m = re.search(r'caliper\(([^)]+)\)', args)
    if cal_m:
        try:
            caliper = float(cal_m.group(1))
        except ValueError:
            return f"Invalid caliper value: {cal_m.group(1)}"

    # Optional nn(value) — number of nearest neighbors (default 1).
    n_neighbors = 1
    nn_m = re.search(r'nn\((\d+)\)', args)
    if nn_m:
        n_neighbors = int(nn_m.group(1))

    # Remove every option clause so only "outcome ~ covariates" remains.
    formula_part = args
    for pat in (r',?\s*treatment\([^)]+\)', r',?\s*caliper\([^)]+\)', r',?\s*nn\(\d+\)'):
        formula_part = re.sub(pat, '', formula_part)
    formula_part = formula_part.strip().rstrip(',').strip()

    if not formula_part or "~" not in formula_part:
        return "Usage: psmatch outcome ~ x1 x2, treatment(tvar) [caliper(0.1)] [nn(3)]"

    try:
        lhs, rhs = formula_part.split("~", 1)
        outcome = lhs.strip()
        covariates = rhs.strip().split()

        if not outcome or not covariates:
            return "Usage: psmatch outcome ~ x1 x2, treatment(tvar)"

        result_str = fit_psm(
            df, outcome, covariates, treatment_col,
            n_neighbors=n_neighbors, caliper=caliper,
        )

        # Record a lightweight entry; fit_psm returns pre-formatted text.
        session.results.append(ModelResult(
            name="PSM", formula=f"{outcome} ~ {' + '.join(covariates)}",
            table=result_str, details={"treatment": treatment_col},
        ))
        return result_str
    except Exception as e:
        return friendly_error(e, "PSM error")
128
+
129
+
130
@command("iptw", usage="iptw <outcome> ~ <covars>, treatment(<tvar>) [--ate|--att] [--trim=0.01]")
def cmd_iptw(session: Session, args: str) -> str:
    """Inverse Probability Treatment Weighting (IPTW) for causal inference.

    Estimates propensity scores via logistic regression, then uses IPT weights
    to estimate the Average Treatment Effect (ATE) or Average Treatment Effect
    on the Treated (ATT) via weighted OLS.

    Examples:
        iptw score ~ age + income, treatment(employed)
        iptw score ~ age + income, treatment(employed) --att
        iptw score ~ age + income, treatment(employed) --ate --trim=0.05
    """
    # Heavy dependencies are imported inside the handler so the command
    # registry stays cheap to load.
    # NOTE(review): `pl` (polars) appears unused below — confirm and drop.
    import re
    import numpy as np
    import polars as pl
    import statsmodels.api as sm
    from sklearn.linear_model import LogisticRegression

    df = session.require_data()

    # Parse treatment(var) — required option naming the treatment column.
    treat_m = re.search(r'treatment\((\w+)\)', args)
    if not treat_m:
        return "Usage: iptw outcome ~ x1 + x2, treatment(tvar) [--ate|--att] [--trim=0.01]"
    treatment_col = treat_m.group(1)

    # Parse estimand: --att switches to effect-on-the-treated; ATE is the default.
    att = "--att" in args
    estimand = "ATT" if att else "ATE"

    # Parse trim: propensity scores are clipped to [trim, 1-trim] to bound weights.
    trim = 0.01
    trim_m = re.search(r'--trim[= ]([\d.]+)', args)
    if trim_m:
        trim = float(trim_m.group(1))

    # Clean formula: strip the treatment() clause and any --flags,
    # leaving just "outcome ~ x1 + x2".
    formula_part = re.sub(r',?\s*treatment\([^)]+\)', '', args)
    formula_part = re.sub(r'--\w+(?:[= ][\d.]+)?', '', formula_part).strip()

    if "~" not in formula_part:
        return "Usage: iptw outcome ~ x1 + x2, treatment(tvar)"

    lhs, rhs = formula_part.split("~", 1)
    outcome = lhs.strip()
    # Covariates may be separated by "+" or whitespace.
    covars = [c.strip() for c in rhs.replace("+", " ").split() if c.strip()]

    needed = [outcome, treatment_col] + covars
    missing = [c for c in needed if c not in df.columns]
    if missing:
        return f"Columns not found: {', '.join(missing)}"

    # Listwise deletion, then plain float arrays for modeling.
    sub = df.select(needed).drop_nulls()
    y = sub[outcome].to_numpy().astype(float)
    treat = sub[treatment_col].to_numpy().astype(float)
    X = sub.select(covars).to_numpy().astype(float)

    # Step 1: Propensity score model (C=1e6 makes the logit near-unpenalized).
    try:
        ps_model = LogisticRegression(max_iter=1000, C=1e6)
        ps_model.fit(X, treat)
        ps = ps_model.predict_proba(X)[:, 1]
    except Exception as exc:
        return f"Propensity score estimation failed: {exc}"

    # Trim extreme propensity scores so no single weight explodes.
    ps_clipped = np.clip(ps, trim, 1 - trim)

    # Step 2: Compute IPTW weights.
    if estimand == "ATE":
        weights = treat / ps_clipped + (1 - treat) / (1 - ps_clipped)
    else:  # ATT: treated units get weight 1; controls get the odds ps/(1-ps)
        weights = treat + (1 - treat) * ps_clipped / (1 - ps_clipped)

    # Step 3: Weighted OLS of outcome on treatment.
    X_ols = sm.add_constant(treat)
    try:
        wls_res = sm.WLS(y, X_ols, weights=weights).fit()
    except Exception as exc:
        return f"Weighted regression failed: {exc}"

    # Index 1 is the treatment coefficient (index 0 is the constant).
    ate_est = wls_res.params[1]
    ate_se = wls_res.bse[1]
    ate_t = wls_res.tvalues[1]
    ate_p = wls_res.pvalues[1]
    ci_low, ci_high = wls_res.conf_int()[1]

    def _sig(p):
        # Conventional significance stars.
        if p < 0.001: return "***"
        if p < 0.01: return "**"
        if p < 0.05: return "*"
        return ""

    # Balance assessment (standardized mean differences before/after).
    n_treat = int(treat.sum())
    n_ctrl = int((1 - treat).sum())

    lines = [
        f"Outcome: {outcome} Treatment: {treatment_col} Estimand: {estimand}",
        f"N = {len(y)} (treated={n_treat}, control={n_ctrl})",
        f"Propensity score trim: [{trim:.3f}, {1-trim:.3f}]",
        "",
        f"IPTW {estimand} Estimate:",
        f" {'Coef':>10} {'SE':>8} {'t':>7} {'p-value':>9} {'95% CI':>20}",
        " " + "-" * 58,
        f" {ate_est:>10.4f} {ate_se:>8.4f} {ate_t:>7.3f} "
        f"{ate_p:>9.4f}{_sig(ate_p)} [{ci_low:.4f}, {ci_high:.4f}]",
        "",
        "Weight Summary:",
        f" Min={weights.min():.3f} Mean={weights.mean():.3f} "
        f"Max={weights.max():.3f} SD={weights.std():.3f}",
        "",
        "Propensity Score Summary:",
        f" Treated — mean={ps[treat==1].mean():.3f} "
        f"min={ps[treat==1].min():.3f} max={ps[treat==1].max():.3f}",
        f" Control — mean={ps[treat==0].mean():.3f} "
        f"min={ps[treat==0].min():.3f} max={ps[treat==0].max():.3f}",
    ]

    # Standardized mean differences (covariate balance).
    lines += ["", "Covariate Balance (Standardized Mean Differences):"]
    lines.append(f" {'Variable':<20} {'Before SMD':>12} {'After SMD':>11} {'Balanced?':>10}")
    lines.append(" " + "-" * 56)
    for j, cname in enumerate(covars):
        xj = X[:, j]
        mu_t = xj[treat == 1].mean()
        mu_c = xj[treat == 0].mean()
        # Pooled SD with a tiny epsilon to avoid divide-by-zero for constant covariates.
        sd_pool = np.sqrt((xj[treat == 1].var() + xj[treat == 0].var()) / 2 + 1e-10)
        smd_before = abs(mu_t - mu_c) / sd_pool
        # After weighting: weighted group means against the same pooled SD.
        mu_t_w = np.average(xj[treat == 1], weights=weights[treat == 1])
        mu_c_w = np.average(xj[treat == 0], weights=weights[treat == 0])
        smd_after = abs(mu_t_w - mu_c_w) / sd_pool
        # SMD < 0.1 is used here as the balance threshold.
        balanced = "✓" if smd_after < 0.1 else "✗"
        lines.append(
            f" {cname:<20} {smd_before:>12.4f} {smd_after:>11.4f} {balanced:>10}"
        )

    return "\n" + "=" * 60 + "\nIPTW Causal Estimate\n" + "=" * 60 + "\n" + "\n".join(lines) + "\n" + "=" * 60
@@ -0,0 +1,152 @@
1
+ """Clustering, MDS, and discriminant analysis commands."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+
7
+ from openstat.commands.base import command
8
+ from openstat.session import Session
9
+
10
+
11
+ def _stata_opts(raw: str) -> tuple[list[str], dict[str, str]]:
12
+ opts: dict[str, str] = {}
13
+ for m in re.finditer(r'(\w+)\(([^)]*)\)', raw):
14
+ opts[m.group(1).lower()] = m.group(2)
15
+ rest = re.sub(r'\w+\([^)]*\)', '', raw)
16
+ positional = [t.strip(',') for t in rest.split() if t.strip(',')]
17
+ return positional, opts
18
+
19
+
20
@command("cluster", usage="cluster kmeans|hierarchical varlist [, k(3) linkage(ward)]")
def cmd_cluster(session: Session, args: str) -> str:
    """K-means or hierarchical clustering.

    Dispatches on the first positional token ("kmeans" or "hierarchical");
    the remaining tokens name the variables to cluster on. Options: k(<int>)
    sets the number of clusters (default 3), linkage(<name>) the hierarchical
    linkage (default "ward").
    """
    df = session.require_data()
    positional, opts = _stata_opts(args)
    if len(positional) < 2:
        return (
            "Usage: cluster kmeans varlist [, k(3)]\n"
            " cluster hierarchical varlist [, k(3) linkage(ward)]"
        )
    sub = positional[0].lower()
    # Tokens that are not columns of the loaded data are silently dropped.
    cols = [c for c in positional[1:] if c in df.columns]
    if not cols:
        return "No valid columns found."

    k = int(opts.get("k", 3))

    try:
        if sub in ("kmeans", "k-means"):
            # Imported lazily so the backing dependency is only needed on use.
            from openstat.stats.clustering import fit_kmeans
            result = fit_kmeans(df, cols, k=k)
            lines = [f"\nK-Means Clustering (k={k})", "=" * 50]
            lines.append(f" {'N observations':<25} {result['n_obs']}")
            lines.append(f" {'Inertia':<25} {result['inertia']:.4f}")
            lines.append(f" {'Silhouette score':<25} {result['silhouette_score']:.4f}")
            lines.append(f" {'Calinski-Harabasz':<25} {result['calinski_harabasz']:.4f}")
            lines.append("\nCluster sizes:")
            # NOTE(review): assumes cluster_sizes keys are 0-based ints (cl + 1
            # renders 1-based labels) — confirm against fit_kmeans.
            for cl, n in result["cluster_sizes"].items():
                lines.append(f" Cluster {cl + 1}: {n} obs ({n/result['n_obs']*100:.1f}%)")
            lines.append("\nCentroids (original scale):")
            # Column names truncated to 8 chars to keep the table aligned.
            lines.append(" " + f"{'Cluster':<10}" + "".join(f" {c[:8]:>8}" for c in cols))
            for i, cent in enumerate(result["centroids"]):
                row = f" {'Cl.' + str(i+1):<10}"
                for v in cent:
                    row += f" {v:>8.3f}"
                lines.append(row)
            lines.append("=" * 50)
            # Keep the raw result so follow-up commands can inspect the last model.
            session._last_model = result
            return "\n".join(lines)

        elif sub in ("hierarchical", "hier", "agglomerative"):
            linkage = opts.get("linkage", "ward")
            from openstat.stats.clustering import fit_hierarchical
            result = fit_hierarchical(df, cols, k=k, linkage=linkage)
            lines = [f"\nHierarchical Clustering (k={k}, linkage={linkage})", "=" * 50]
            lines.append(f" {'N observations':<25} {result['n_obs']}")
            lines.append(f" {'Silhouette score':<25} {result['silhouette_score']:.4f}")
            lines.append(f" {'Calinski-Harabasz':<25} {result['calinski_harabasz']:.4f}")
            lines.append("\nCluster sizes:")
            for cl, n in result["cluster_sizes"].items():
                lines.append(f" Cluster {cl + 1}: {n} obs")
            lines.append("=" * 50)
            session._last_model = result
            return "\n".join(lines)

        else:
            return f"Unknown cluster method: {sub}. Use 'kmeans' or 'hierarchical'."

    except ImportError as e:
        # Surface the ImportError text verbatim (e.g. a missing optional dependency).
        return str(e)
    except Exception as exc:
        return f"cluster error: {exc}"
82
+
83
+
84
@command("mds", usage="mds varlist [, n(2) metric]")
def cmd_mds(session: Session, args: str) -> str:
    """Multidimensional scaling."""
    df = session.require_data()
    positional, opts = _stata_opts(args)

    # Keep only tokens that are actual columns of the loaded frame.
    cols = [name for name in positional if name in df.columns]
    if len(cols) < 2:
        return "mds requires at least 2 numeric variables."

    n_components = int(opts.get("n", 2))
    # The bare token "nonmetric" switches to non-metric MDS.
    use_metric = "nonmetric" not in positional

    try:
        from openstat.stats.clustering import fit_mds
        result = fit_mds(df, cols, n_components=n_components, metric=use_metric)

        kind_label = 'metric' if use_metric else 'non-metric'
        lines = [f"\nMDS ({kind_label})", "=" * 50]
        lines.append(f" {'N observations':<25} {result['n_obs']}")
        lines.append(f" {'N components':<25} {n_components}")
        lines.append(f" {'Stress':<25} {result['stress']:.6f}")
        lines.append("\nFirst 5 coordinates (Dim 1, Dim 2):")

        coords = result["coordinates"]
        for point in coords[:5]:
            lines.append(" " + " ".join(f"{v:>8.4f}" for v in point))
        if len(coords) > 5:
            lines.append(f" ... ({len(coords)} total rows)")
        lines.append("=" * 50)

        # Stash the raw result for follow-up inspection.
        session._last_model = result
        return "\n".join(lines)
    except ImportError as e:
        return str(e)
    except Exception as exc:
        return f"mds error: {exc}"
114
+
115
+
116
@command("discriminant", usage="discriminant groupvar indepvars [, method(lda|qda)]")
def cmd_discriminant(session: Session, args: str) -> str:
    """Linear (LDA) or Quadratic (QDA) Discriminant Analysis.

    The first positional token is the grouping variable; the remaining tokens
    are predictors. Option method(lda|qda) selects the model (default "lda").
    """
    df = session.require_data()
    positional, opts = _stata_opts(args)
    if len(positional) < 2:
        return "Usage: discriminant groupvar indepvar1 indepvar2 ... [, method(lda)]"

    dep = positional[0]
    # Predictors that are not columns of the loaded data are silently dropped.
    indeps = [c for c in positional[1:] if c in df.columns]
    method = opts.get("method", "lda").lower()

    try:
        # Imported lazily so the backing dependency is only needed on use.
        from openstat.stats.clustering import fit_discriminant
        result = fit_discriminant(df, dep, indeps, method=method)
        lines = [f"\n{result['method']}: {dep}", "=" * 55]
        lines.append(f" {'N observations':<25} {result['n_obs']}")
        lines.append(f" {'N classes':<25} {result['n_classes']}")
        lines.append(f" {'Classes':<25} {', '.join(str(c) for c in result['classes'])}")
        # In-sample (training) accuracy, as reported by fit_discriminant.
        lines.append(f" {'Accuracy (train)':<25} {result['accuracy']:>12.4f}")
        if result.get("priors"):
            lines.append("\nClass priors:")
            # NOTE(review): assumes priors are ordered to match result['classes'] — confirm.
            for cls, p in zip(result["classes"], result["priors"]):
                lines.append(f" {str(cls):<20} {p:.4f}")
        if "coefficients" in result:
            lines.append("\nDiscriminant function coefficients:")
            for func, coefs in result["coefficients"].items():
                lines.append(f" Function ({func}):")
                for var, val in coefs.items():
                    lines.append(f" {var:<22} {val:>10.4f}")
        lines.append("=" * 55)
        # Keep the raw result so follow-up commands can inspect the last model.
        session._last_model = result
        return "\n".join(lines)
    except ImportError as e:
        # Surface the ImportError text verbatim (e.g. a missing optional dependency).
        return str(e)
    except Exception as exc:
        return f"discriminant error: {exc}"