openstat-cli 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. openstat/__init__.py +3 -0
  2. openstat/__main__.py +4 -0
  3. openstat/backends/__init__.py +16 -0
  4. openstat/backends/duckdb_backend.py +70 -0
  5. openstat/backends/polars_backend.py +52 -0
  6. openstat/cli.py +92 -0
  7. openstat/commands/__init__.py +82 -0
  8. openstat/commands/adv_stat_cmds.py +1255 -0
  9. openstat/commands/advanced_ml_cmds.py +576 -0
  10. openstat/commands/advreg_cmds.py +207 -0
  11. openstat/commands/alias_cmds.py +135 -0
  12. openstat/commands/arch_cmds.py +82 -0
  13. openstat/commands/arules_cmds.py +111 -0
  14. openstat/commands/automodel_cmds.py +212 -0
  15. openstat/commands/backend_cmds.py +82 -0
  16. openstat/commands/base.py +170 -0
  17. openstat/commands/bayes_cmds.py +71 -0
  18. openstat/commands/causal_cmds.py +269 -0
  19. openstat/commands/cluster_cmds.py +152 -0
  20. openstat/commands/data_cmds.py +996 -0
  21. openstat/commands/datamanip_cmds.py +672 -0
  22. openstat/commands/dataquality_cmds.py +174 -0
  23. openstat/commands/datetime_cmds.py +176 -0
  24. openstat/commands/dimreduce_cmds.py +184 -0
  25. openstat/commands/discrete_cmds.py +149 -0
  26. openstat/commands/dsl_cmds.py +143 -0
  27. openstat/commands/epi_cmds.py +93 -0
  28. openstat/commands/equiv_tobit_cmds.py +94 -0
  29. openstat/commands/esttab_cmds.py +196 -0
  30. openstat/commands/export_beamer_cmds.py +142 -0
  31. openstat/commands/export_cmds.py +201 -0
  32. openstat/commands/export_extra_cmds.py +240 -0
  33. openstat/commands/factor_cmds.py +180 -0
  34. openstat/commands/groupby_cmds.py +155 -0
  35. openstat/commands/help_cmds.py +237 -0
  36. openstat/commands/i18n_cmds.py +43 -0
  37. openstat/commands/import_extra_cmds.py +561 -0
  38. openstat/commands/influence_cmds.py +134 -0
  39. openstat/commands/iv_cmds.py +106 -0
  40. openstat/commands/manova_cmds.py +105 -0
  41. openstat/commands/mediate_cmds.py +233 -0
  42. openstat/commands/meta_cmds.py +284 -0
  43. openstat/commands/mi_cmds.py +228 -0
  44. openstat/commands/mixed_cmds.py +79 -0
  45. openstat/commands/mixture_changepoint_cmds.py +166 -0
  46. openstat/commands/ml_adv_cmds.py +147 -0
  47. openstat/commands/ml_cmds.py +178 -0
  48. openstat/commands/model_eval_cmds.py +142 -0
  49. openstat/commands/network_cmds.py +288 -0
  50. openstat/commands/nlquery_cmds.py +161 -0
  51. openstat/commands/nonparam_cmds.py +149 -0
  52. openstat/commands/outreg_cmds.py +247 -0
  53. openstat/commands/panel_cmds.py +141 -0
  54. openstat/commands/pdf_cmds.py +226 -0
  55. openstat/commands/pipeline_cmds.py +319 -0
  56. openstat/commands/plot_cmds.py +189 -0
  57. openstat/commands/plugin_cmds.py +79 -0
  58. openstat/commands/posthoc_cmds.py +153 -0
  59. openstat/commands/power_cmds.py +172 -0
  60. openstat/commands/profile_cmds.py +246 -0
  61. openstat/commands/rbridge_cmds.py +81 -0
  62. openstat/commands/regex_cmds.py +104 -0
  63. openstat/commands/report_cmds.py +48 -0
  64. openstat/commands/repro_cmds.py +129 -0
  65. openstat/commands/resampling_cmds.py +109 -0
  66. openstat/commands/reshape_cmds.py +223 -0
  67. openstat/commands/sem_cmds.py +177 -0
  68. openstat/commands/stat_cmds.py +1040 -0
  69. openstat/commands/stata_import_cmds.py +215 -0
  70. openstat/commands/string_cmds.py +124 -0
  71. openstat/commands/surv_cmds.py +145 -0
  72. openstat/commands/survey_cmds.py +153 -0
  73. openstat/commands/textanalysis_cmds.py +192 -0
  74. openstat/commands/ts_adv_cmds.py +136 -0
  75. openstat/commands/ts_cmds.py +195 -0
  76. openstat/commands/tui_cmds.py +111 -0
  77. openstat/commands/ux_cmds.py +191 -0
  78. openstat/commands/validate_cmds.py +270 -0
  79. openstat/commands/viz_adv_cmds.py +312 -0
  80. openstat/commands/viz_extra_cmds.py +251 -0
  81. openstat/commands/watch_cmds.py +69 -0
  82. openstat/config.py +106 -0
  83. openstat/dsl/__init__.py +0 -0
  84. openstat/dsl/parser.py +332 -0
  85. openstat/dsl/tokenizer.py +105 -0
  86. openstat/i18n.py +120 -0
  87. openstat/io/__init__.py +0 -0
  88. openstat/io/loader.py +187 -0
  89. openstat/jupyter/__init__.py +18 -0
  90. openstat/jupyter/display.py +18 -0
  91. openstat/jupyter/magic.py +60 -0
  92. openstat/logging_config.py +59 -0
  93. openstat/plots/__init__.py +0 -0
  94. openstat/plots/plotter.py +437 -0
  95. openstat/plots/surv_plots.py +32 -0
  96. openstat/plots/ts_plots.py +59 -0
  97. openstat/plugins/__init__.py +5 -0
  98. openstat/plugins/manager.py +69 -0
  99. openstat/repl.py +457 -0
  100. openstat/reporting/__init__.py +0 -0
  101. openstat/reporting/eda.py +208 -0
  102. openstat/reporting/report.py +67 -0
  103. openstat/script_runner.py +319 -0
  104. openstat/session.py +133 -0
  105. openstat/stats/__init__.py +0 -0
  106. openstat/stats/advanced_regression.py +269 -0
  107. openstat/stats/arch_garch.py +84 -0
  108. openstat/stats/bayesian.py +103 -0
  109. openstat/stats/causal.py +258 -0
  110. openstat/stats/clustering.py +206 -0
  111. openstat/stats/discrete.py +311 -0
  112. openstat/stats/epidemiology.py +119 -0
  113. openstat/stats/equiv_tobit.py +163 -0
  114. openstat/stats/factor.py +174 -0
  115. openstat/stats/imputation.py +282 -0
  116. openstat/stats/influence.py +78 -0
  117. openstat/stats/iv.py +131 -0
  118. openstat/stats/manova.py +124 -0
  119. openstat/stats/mixed.py +128 -0
  120. openstat/stats/ml.py +275 -0
  121. openstat/stats/ml_advanced.py +117 -0
  122. openstat/stats/model_eval.py +183 -0
  123. openstat/stats/models.py +1342 -0
  124. openstat/stats/nonparametric.py +130 -0
  125. openstat/stats/panel.py +179 -0
  126. openstat/stats/power.py +295 -0
  127. openstat/stats/resampling.py +203 -0
  128. openstat/stats/survey.py +213 -0
  129. openstat/stats/survival.py +196 -0
  130. openstat/stats/timeseries.py +142 -0
  131. openstat/stats/ts_advanced.py +114 -0
  132. openstat/types.py +11 -0
  133. openstat/web/__init__.py +1 -0
  134. openstat/web/app.py +117 -0
  135. openstat/web/session_manager.py +73 -0
  136. openstat/web/static/app.js +117 -0
  137. openstat/web/static/index.html +38 -0
  138. openstat/web/static/style.css +103 -0
  139. openstat_cli-1.0.0.dist-info/METADATA +748 -0
  140. openstat_cli-1.0.0.dist-info/RECORD +143 -0
  141. openstat_cli-1.0.0.dist-info/WHEEL +4 -0
  142. openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
  143. openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,258 @@
1
+ """Causal inference models: Difference-in-Differences, Propensity Score Matching."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import polars as pl
7
+ import statsmodels.api as sm
8
+ from scipy import stats as sp_stats
9
+ from scipy.spatial import KDTree
10
+
11
+ from openstat.stats.models import FitResult, _model_type_suffix
12
+
13
+
14
+ # ── Difference-in-Differences ────────────────────────────────────────
15
+
16
def fit_did(
    df: pl.DataFrame,
    dep: str,
    treatment_col: str,
    time_col: str,
    *,
    robust: bool = False,
    cluster_col: str | None = None,
) -> tuple[FitResult, object]:
    """Fit a Difference-in-Differences model.

    Model: y = b0 + b1*treatment + b2*post + b3*(treatment*post) + e
    The DiD estimate is b3.

    Args:
        df: Input data.
        dep: Dependent (outcome) variable column.
        treatment_col: Binary (0/1) treatment-group indicator column.
        time_col: Binary (0/1) post-period indicator column.
        robust: Use HC1 heteroskedasticity-robust standard errors.
        cluster_col: Cluster standard errors on this column. Ignored when
            ``robust`` is also set (a note is appended to the warnings).

    Returns:
        Tuple of (FitResult summary, fitted statsmodels OLS results object).

    Raises:
        ValueError: If columns are missing, no complete observations remain,
            or the treatment/time indicators are not coded 0/1.
    """
    cols_needed = [dep, treatment_col, time_col]
    if cluster_col:
        cols_needed.append(cluster_col)
    cols_needed = list(dict.fromkeys(cols_needed))  # de-duplicate, keep order

    missing = [c for c in cols_needed if c not in df.columns]
    if missing:
        raise ValueError(f"Columns not found: {', '.join(missing)}")

    sub = df.select(cols_needed).drop_nulls()
    if sub.height == 0:
        raise ValueError("No observations after dropping missing values")

    warnings_list: list[str] = []
    n_dropped = df.height - sub.height
    if n_dropped > 0:
        warnings_list.append(f"Note: {n_dropped} observation(s) dropped due to missing values.")

    y = sub[dep].to_numpy().astype(float)
    treat = sub[treatment_col].to_numpy().astype(float)
    post = sub[time_col].to_numpy().astype(float)

    # Validate 0/1 coding (consistent with fit_psm): the b3 interaction
    # estimate and the group-mean diagnostics below assume binary dummies.
    for col_name, arr in ((treatment_col, treat), (time_col, post)):
        vals = set(np.unique(arr))
        if not vals.issubset({0.0, 1.0}):
            raise ValueError(
                f"Column '{col_name}' must be binary (0/1). Found: {sorted(vals)[:10]}"
            )

    interact = treat * post

    # Design matrix: intercept, treatment dummy, post dummy, interaction.
    X = np.column_stack([np.ones(len(y)), treat, post, interact])
    var_names = ["_cons", treatment_col, time_col, f"{treatment_col}:{time_col}"]

    # Cluster SE
    groups = None
    if cluster_col:
        groups = sub[cluster_col].to_numpy()

    if robust:
        cov_type = "HC1"
        cov_kwds: dict = {}
        if groups is not None:
            # Previously the cluster request was discarded silently when
            # robust was also set; surface the precedence to the user.
            warnings_list.append(
                "Note: both robust and cluster SEs requested; using HC1 robust SEs."
            )
    elif groups is not None:
        cov_type = "cluster"
        cov_kwds = {"groups": groups}
    else:
        cov_type = "nonrobust"
        cov_kwds = {}

    model = sm.OLS(y, X).fit(cov_type=cov_type, cov_kwds=cov_kwds)
    ci = model.conf_int()

    # Compute group means for diagnostics (the classic 2x2 DiD table).
    treat_mask = treat == 1
    control_mask = treat == 0
    pre_mask = post == 0
    post_mask = post == 1

    means = {}
    for label, t_mask, p_mask in [
        ("Control, Pre", control_mask, pre_mask),
        ("Control, Post", control_mask, post_mask),
        ("Treatment, Pre", treat_mask, pre_mask),
        ("Treatment, Post", treat_mask, post_mask),
    ]:
        mask = t_mask & p_mask
        if mask.sum() > 0:
            means[label] = float(np.mean(y[mask]))

    if means:
        warnings_list.append("Group means:")
        for label, mean in means.items():
            warnings_list.append(f" {label}: {mean:.4f}")

    # The DiD estimate is the interaction coefficient (column index 3).
    did_coef = float(model.params[3])
    did_se = float(model.bse[3])
    did_p = float(model.pvalues[3])
    warnings_list.append(f"DiD estimate: {did_coef:.4f} (SE={did_se:.4f}, p={did_p:.4f})")

    suffix = _model_type_suffix(robust, groups is not None)

    fit = FitResult(
        model_type="DiD" + suffix,
        formula=f"{dep} ~ {treatment_col} + {time_col} + {treatment_col}:{time_col}",
        dep_var=dep,
        indep_vars=[treatment_col, time_col],
        n_obs=int(model.nobs),
        params=dict(zip(var_names, model.params)),
        std_errors=dict(zip(var_names, model.bse)),
        t_values=dict(zip(var_names, model.tvalues)),
        p_values=dict(zip(var_names, model.pvalues)),
        conf_int_low=dict(zip(var_names, ci[:, 0])),
        conf_int_high=dict(zip(var_names, ci[:, 1])),
        r_squared=float(model.rsquared),
        adj_r_squared=float(model.rsquared_adj),
        f_statistic=float(model.fvalue) if model.fvalue is not None else None,
        f_pvalue=float(model.f_pvalue) if model.f_pvalue is not None else None,
        warnings=warnings_list,
    )
    return fit, model
122
+
123
+
124
+ # ── Propensity Score Matching ────────────────────────────────────────
125
+
126
def fit_psm(
    df: pl.DataFrame,
    outcome: str,
    covariates: list[str],
    treatment_col: str,
    *,
    n_neighbors: int = 1,
    caliper: float | None = None,
) -> str:
    """Propensity Score Matching: estimate Average Treatment Effect on Treated (ATT).

    Steps:
    1. Logit model for propensity score P(T=1 | X)
    2. KDTree nearest-neighbor matching
    3. ATT = mean(Y_treated - Y_matched_control)
    4. Bootstrap SE

    Args:
        df: Input data.
        outcome: Outcome variable column.
        covariates: Covariate columns used in the propensity model.
        treatment_col: Binary (0/1) treatment indicator column.
        n_neighbors: Number of nearest control neighbors per treated unit.
        caliper: Maximum propensity-score distance for a valid match.
            Defaults to 0.2 * SD of the estimated propensity scores.

    Returns:
        A formatted, multi-line text report (counts, ATT, SE, balance table).

    Raises:
        ValueError: On missing columns, too few observations, a non-binary
            treatment, missing treated/control units, or too few matches.
    """
    # De-duplicated column list, preserving order.
    all_cols = list(dict.fromkeys([outcome, treatment_col] + covariates))
    missing = [c for c in all_cols if c not in df.columns]
    if missing:
        raise ValueError(f"Columns not found: {', '.join(missing)}")

    # Complete-case analysis only.
    sub = df.select(all_cols).drop_nulls()
    if sub.height < 20:
        raise ValueError(f"Too few observations ({sub.height}) for propensity score matching.")

    y = sub[outcome].to_numpy().astype(float)
    treat = sub[treatment_col].to_numpy().astype(float)

    # The logit and the 0/1 masks below require strictly binary treatment.
    unique_t = set(treat)
    if not unique_t.issubset({0.0, 1.0}):
        raise ValueError(
            f"Treatment variable must be binary (0/1). Found: {sorted(unique_t)[:10]}"
        )

    X = sub.select(covariates).to_numpy().astype(float)
    X_with_const = sm.add_constant(X)

    # Step 1: Propensity score via logit
    logit_model = sm.Logit(treat, X_with_const).fit(disp=0)
    pscore = logit_model.predict(X_with_const)

    # Default caliper: 0.2 SD of the propensity score.
    # NOTE(review): the common rule of thumb is 0.2 SD of the *logit* of the
    # propensity score — confirm which scale was intended here.
    if caliper is None:
        caliper = 0.2 * np.std(pscore)

    # Step 2: KDTree matching (1-D tree over control-unit propensity scores)
    treated_idx = np.where(treat == 1)[0]
    control_idx = np.where(treat == 0)[0]

    if len(treated_idx) == 0 or len(control_idx) == 0:
        raise ValueError("Need both treated and control observations.")

    control_ps = pscore[control_idx].reshape(-1, 1)
    tree = KDTree(control_ps)

    matched_treated = []            # outcomes of matched treated units
    matched_control_outcomes = []   # mean outcome of each unit's matched controls
    unmatched = 0                   # treated units with no control inside the caliper

    for t_i in treated_idx:
        ps_t = pscore[t_i]
        # Matching is with replacement: each treated unit queries the full tree.
        dists, idxs = tree.query([[ps_t]], k=n_neighbors)
        dists = dists.flatten()
        idxs = idxs.flatten()

        # Apply caliper: discard neighbors farther than the allowed distance.
        valid = dists <= caliper
        if not valid.any():
            unmatched += 1
            continue

        matched_treated.append(y[t_i])
        # Average the outcomes of all in-caliper neighbors for this unit.
        control_outcomes = [y[control_idx[idx]] for idx, v in zip(idxs, valid) if v]
        matched_control_outcomes.append(np.mean(control_outcomes))

    if len(matched_treated) < 5:
        raise ValueError(
            f"Only {len(matched_treated)} treated units matched. "
            f"Try increasing caliper or reducing n_neighbors."
        )

    matched_treated = np.array(matched_treated)
    matched_control_outcomes = np.array(matched_control_outcomes)

    # Step 3: ATT — mean within-pair outcome difference.
    att = float(np.mean(matched_treated - matched_control_outcomes))

    # Step 4: Bootstrap SE — resample matched pairs with a fixed seed so the
    # report is reproducible across runs.
    n_boot = 50
    rng = np.random.RandomState(42)
    boot_atts = []
    for _ in range(n_boot):
        boot_idx = rng.choice(len(matched_treated), size=len(matched_treated), replace=True)
        boot_att = float(np.mean(matched_treated[boot_idx] - matched_control_outcomes[boot_idx]))
        boot_atts.append(boot_att)

    se_att = float(np.std(boot_atts, ddof=1))
    # Normal-approximation two-sided p-value; NaN when SE is degenerate.
    t_stat = att / se_att if se_att > 0 else np.nan
    p_value = float(2 * (1 - sp_stats.norm.cdf(np.abs(t_stat))))

    # Balance table: mean difference before/after matching
    # (only the pre-matching comparison is actually computed below).
    balance_lines = []
    for i, cov in enumerate(covariates):
        mean_t = float(np.mean(X[treated_idx, i]))
        mean_c_all = float(np.mean(X[control_idx, i]))
        # Matched controls (approximate via pscore-matched indices)
        balance_lines.append(
            f" {cov:20s} Treated: {mean_t:8.4f} Control: {mean_c_all:8.4f} "
            f"Diff: {mean_t - mean_c_all:8.4f}"
        )

    lines = [
        "Propensity Score Matching",
        f" Treatment variable: {treatment_col}",
        f" Outcome variable: {outcome}",
        f" Covariates: {', '.join(covariates)}",
        f" Neighbors: {n_neighbors}, Caliper: {caliper:.4f}",
        "",
        f" N treated: {len(treated_idx)}",
        f" N control: {len(control_idx)}",
        f" Matched: {len(matched_treated)}",
        f" Unmatched: {unmatched}",
        "",
        f" ATT: {att:.4f}",
        f" SE: {se_att:.4f}",
        f" t-stat: {t_stat:.4f}",
        f" p-value: {p_value:.4f}",
        "",
        "Covariate Balance (before matching):",
    ] + balance_lines

    return "\n".join(lines)
@@ -0,0 +1,206 @@
1
+ """Clustering, MDS, and discriminant analysis."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import polars as pl
7
+
8
+ try:
9
+ from sklearn.cluster import KMeans, AgglomerativeClustering # type: ignore[import]
10
+ from sklearn.manifold import MDS # type: ignore[import]
11
+ from sklearn.discriminant_analysis import ( # type: ignore[import]
12
+ LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis,
13
+ )
14
+ from sklearn.preprocessing import StandardScaler, LabelEncoder # type: ignore[import]
15
+ from sklearn.metrics import ( # type: ignore[import]
16
+ silhouette_score, calinski_harabasz_score, accuracy_score,
17
+ )
18
+ _HAS_SKLEARN = True
19
+ except ImportError:
20
+ _HAS_SKLEARN = False
21
+
22
+
23
def _require_sklearn() -> None:
    """Fail fast with an install hint when scikit-learn is unavailable."""
    if _HAS_SKLEARN:
        return
    raise ImportError(
        "scikit-learn is required for clustering commands.\n"
        "Install: pip install scikit-learn"
    )
29
+
30
+
31
def _std(df: pl.DataFrame, cols: list[str]) -> tuple[np.ndarray, np.ndarray]:
    """Select *cols*, drop rows with nulls, and z-score standardize.

    Returns:
        (standardized, raw): two float arrays of shape (n_rows, len(cols)) —
        the z-scored matrix and the untransformed values.

    Note: the return annotation previously claimed a single ndarray although
    a 2-tuple was always returned; it is corrected here.  The redundant local
    re-import of StandardScaler (already imported at module top under the
    sklearn guard) is removed; callers invoke _require_sklearn() first.
    """
    X = df.select(cols).drop_nulls().to_numpy().astype(float)
    return StandardScaler().fit_transform(X), X
35
+
36
+
37
+ # ── K-Means ────────────────────────────────────────────────────────────────
38
+
39
def fit_kmeans(
    df: pl.DataFrame,
    cols: list[str],
    *,
    k: int = 3,
    n_init: int = 10,
    max_iter: int = 300,
    random_state: int = 42,
) -> dict:
    """K-means clustering."""
    _require_sklearn()
    scaled, raw = _std(df, cols)

    km = KMeans(n_clusters=k, n_init=n_init, max_iter=max_iter, random_state=random_state)
    assignments = km.fit_predict(scaled)

    # Internal validity indices are undefined for a single cluster.
    multi = k > 1
    sil = float(silhouette_score(scaled, assignments)) if multi else float("nan")
    ch = float(calinski_harabasz_score(scaled, assignments)) if multi else float("nan")

    sizes = {int(c): int((assignments == c).sum()) for c in range(k)}

    # Undo the z-scoring so centroids are reported on the original scale;
    # the epsilon guards against zero-variance columns.
    col_means = raw.mean(axis=0)
    col_stds = raw.std(axis=0) + 1e-15
    centers = km.cluster_centers_ * col_stds + col_means

    return {
        "method": "K-Means",
        "cols": cols,
        "k": k,
        "n_obs": len(scaled),
        "inertia": float(km.inertia_),
        "silhouette_score": sil,
        "calinski_harabasz": ch,
        "cluster_sizes": sizes,
        "centroids": centers.tolist(),
        "labels": assignments.tolist(),
        "_model": km,
    }
82
+
83
+
84
+ # ── Hierarchical (Agglomerative) ───────────────────────────────────────────
85
+
86
def fit_hierarchical(
    df: pl.DataFrame,
    cols: list[str],
    *,
    k: int = 3,
    linkage: str = "ward",
    metric: str = "euclidean",
) -> dict:
    """Agglomerative hierarchical clustering.

    Args:
        df: Input data.
        cols: Numeric columns to cluster on (standardized internally).
        k: Number of clusters to extract.
        linkage: Linkage criterion. Ward linkage requires Euclidean distance,
            so requesting ward with any other metric falls back to average.
        metric: Pairwise distance metric passed to the estimator.

    Returns:
        dict with fit diagnostics, cluster sizes, labels, and the model.
    """
    _require_sklearn()
    X_s, _ = _std(df, cols)
    n = len(X_s)

    # Ward linkage is only defined for Euclidean distance; fall back.
    link = linkage if linkage != "ward" or metric == "euclidean" else "average"
    # Bug fix: `metric` was accepted but never forwarded to the estimator,
    # so non-Euclidean metrics were silently ignored.
    model = AgglomerativeClustering(n_clusters=k, linkage=link, metric=metric)
    labels = model.fit_predict(X_s)

    sil = float(silhouette_score(X_s, labels)) if k > 1 else float("nan")
    ch = float(calinski_harabasz_score(X_s, labels)) if k > 1 else float("nan")
    cluster_sizes = {int(i): int((labels == i).sum()) for i in range(k)}

    return {
        "method": "Hierarchical",
        "cols": cols,
        "k": k,
        "linkage": linkage,
        "metric": metric,
        "n_obs": n,
        "silhouette_score": sil,
        "calinski_harabasz": ch,
        "cluster_sizes": cluster_sizes,
        "labels": labels.tolist(),
        "_model": model,
    }
119
+
120
+
121
+ # ── MDS ────────────────────────────────────────────────────────────────────
122
+
123
def fit_mds(
    df: pl.DataFrame,
    cols: list[str],
    *,
    n_components: int = 2,
    metric: bool = True,
    random_state: int = 42,
) -> dict:
    """Multidimensional Scaling."""
    _require_sklearn()
    scaled, _ = _std(df, cols)

    mds = MDS(
        n_components=n_components,
        metric=metric,
        random_state=random_state,
        normalized_stress="auto",
    )
    embedding = mds.fit_transform(scaled)

    return {
        "method": "MDS",
        "cols": cols,
        "n_components": n_components,
        "metric": metric,
        "stress": float(mds.stress_),
        "n_obs": len(scaled),
        "coordinates": embedding.tolist(),
        "_model": mds,
    }
154
+
155
+
156
+ # ── Discriminant Analysis ──────────────────────────────────────────────────
157
+
158
def fit_discriminant(
    df: pl.DataFrame,
    dep: str,
    indeps: list[str],
    *,
    method: str = "lda",
) -> dict:
    """Linear or Quadratic Discriminant Analysis.

    Args:
        df: Input data.
        dep: Class-label column (any dtype; label-encoded internally).
        indeps: Numeric predictor columns.
        method: "lda" (default) or "qda".

    Returns:
        dict with classes, priors, in-sample accuracy, the fitted model,
        the label encoder, and (LDA only) per-class coefficient mappings.
    """
    _require_sklearn()
    sub = df.select([dep] + indeps).drop_nulls()
    y_raw = sub[dep].to_numpy()
    X = sub.select(indeps).to_numpy().astype(float)

    # Encode labels as 0..k-1; cast to str so mixed dtypes encode stably.
    le = LabelEncoder()
    y = le.fit_transform(y_raw.astype(str))

    if method.lower() == "qda":
        model = QuadraticDiscriminantAnalysis()
    else:
        model = LinearDiscriminantAnalysis()

    model.fit(X, y)
    y_pred = model.predict(X)
    # In-sample (resubstitution) accuracy — optimistic by construction.
    acc = float(accuracy_score(y, y_pred))

    classes = le.classes_.tolist()
    prior = model.priors_.tolist() if hasattr(model, "priors_") else []

    result = {
        "method": method.upper(),
        "dep": dep,
        "indeps": indeps,
        "classes": classes,
        "n_classes": len(classes),
        "priors": prior,
        "accuracy": acc,
        "n_obs": len(y),
        "_model": model,
        "_le": le,
    }

    # LDA-specific: decision-function coefficients.
    # Bug fix: sklearn's LDA coef_ has one row per class in the multiclass
    # case (and a single row for binary problems) — not k-1 discriminant
    # functions. The previous code zipped rows against classes[1:], which
    # mislabeled and dropped rows for k > 2. Label rows by actual shape.
    if method.lower() == "lda" and hasattr(model, "coef_"):
        coef_rows = model.coef_
        row_labels = classes if len(coef_rows) == len(classes) else classes[1:]
        result["coefficients"] = {
            cls: dict(zip(indeps, coef_rows[i].tolist()))
            for i, cls in enumerate(row_labels)
        }

    return result