openstat-cli 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. openstat/__init__.py +3 -0
  2. openstat/__main__.py +4 -0
  3. openstat/backends/__init__.py +16 -0
  4. openstat/backends/duckdb_backend.py +70 -0
  5. openstat/backends/polars_backend.py +52 -0
  6. openstat/cli.py +92 -0
  7. openstat/commands/__init__.py +82 -0
  8. openstat/commands/adv_stat_cmds.py +1255 -0
  9. openstat/commands/advanced_ml_cmds.py +576 -0
  10. openstat/commands/advreg_cmds.py +207 -0
  11. openstat/commands/alias_cmds.py +135 -0
  12. openstat/commands/arch_cmds.py +82 -0
  13. openstat/commands/arules_cmds.py +111 -0
  14. openstat/commands/automodel_cmds.py +212 -0
  15. openstat/commands/backend_cmds.py +82 -0
  16. openstat/commands/base.py +170 -0
  17. openstat/commands/bayes_cmds.py +71 -0
  18. openstat/commands/causal_cmds.py +269 -0
  19. openstat/commands/cluster_cmds.py +152 -0
  20. openstat/commands/data_cmds.py +996 -0
  21. openstat/commands/datamanip_cmds.py +672 -0
  22. openstat/commands/dataquality_cmds.py +174 -0
  23. openstat/commands/datetime_cmds.py +176 -0
  24. openstat/commands/dimreduce_cmds.py +184 -0
  25. openstat/commands/discrete_cmds.py +149 -0
  26. openstat/commands/dsl_cmds.py +143 -0
  27. openstat/commands/epi_cmds.py +93 -0
  28. openstat/commands/equiv_tobit_cmds.py +94 -0
  29. openstat/commands/esttab_cmds.py +196 -0
  30. openstat/commands/export_beamer_cmds.py +142 -0
  31. openstat/commands/export_cmds.py +201 -0
  32. openstat/commands/export_extra_cmds.py +240 -0
  33. openstat/commands/factor_cmds.py +180 -0
  34. openstat/commands/groupby_cmds.py +155 -0
  35. openstat/commands/help_cmds.py +237 -0
  36. openstat/commands/i18n_cmds.py +43 -0
  37. openstat/commands/import_extra_cmds.py +561 -0
  38. openstat/commands/influence_cmds.py +134 -0
  39. openstat/commands/iv_cmds.py +106 -0
  40. openstat/commands/manova_cmds.py +105 -0
  41. openstat/commands/mediate_cmds.py +233 -0
  42. openstat/commands/meta_cmds.py +284 -0
  43. openstat/commands/mi_cmds.py +228 -0
  44. openstat/commands/mixed_cmds.py +79 -0
  45. openstat/commands/mixture_changepoint_cmds.py +166 -0
  46. openstat/commands/ml_adv_cmds.py +147 -0
  47. openstat/commands/ml_cmds.py +178 -0
  48. openstat/commands/model_eval_cmds.py +142 -0
  49. openstat/commands/network_cmds.py +288 -0
  50. openstat/commands/nlquery_cmds.py +161 -0
  51. openstat/commands/nonparam_cmds.py +149 -0
  52. openstat/commands/outreg_cmds.py +247 -0
  53. openstat/commands/panel_cmds.py +141 -0
  54. openstat/commands/pdf_cmds.py +226 -0
  55. openstat/commands/pipeline_cmds.py +319 -0
  56. openstat/commands/plot_cmds.py +189 -0
  57. openstat/commands/plugin_cmds.py +79 -0
  58. openstat/commands/posthoc_cmds.py +153 -0
  59. openstat/commands/power_cmds.py +172 -0
  60. openstat/commands/profile_cmds.py +246 -0
  61. openstat/commands/rbridge_cmds.py +81 -0
  62. openstat/commands/regex_cmds.py +104 -0
  63. openstat/commands/report_cmds.py +48 -0
  64. openstat/commands/repro_cmds.py +129 -0
  65. openstat/commands/resampling_cmds.py +109 -0
  66. openstat/commands/reshape_cmds.py +223 -0
  67. openstat/commands/sem_cmds.py +177 -0
  68. openstat/commands/stat_cmds.py +1040 -0
  69. openstat/commands/stata_import_cmds.py +215 -0
  70. openstat/commands/string_cmds.py +124 -0
  71. openstat/commands/surv_cmds.py +145 -0
  72. openstat/commands/survey_cmds.py +153 -0
  73. openstat/commands/textanalysis_cmds.py +192 -0
  74. openstat/commands/ts_adv_cmds.py +136 -0
  75. openstat/commands/ts_cmds.py +195 -0
  76. openstat/commands/tui_cmds.py +111 -0
  77. openstat/commands/ux_cmds.py +191 -0
  78. openstat/commands/validate_cmds.py +270 -0
  79. openstat/commands/viz_adv_cmds.py +312 -0
  80. openstat/commands/viz_extra_cmds.py +251 -0
  81. openstat/commands/watch_cmds.py +69 -0
  82. openstat/config.py +106 -0
  83. openstat/dsl/__init__.py +0 -0
  84. openstat/dsl/parser.py +332 -0
  85. openstat/dsl/tokenizer.py +105 -0
  86. openstat/i18n.py +120 -0
  87. openstat/io/__init__.py +0 -0
  88. openstat/io/loader.py +187 -0
  89. openstat/jupyter/__init__.py +18 -0
  90. openstat/jupyter/display.py +18 -0
  91. openstat/jupyter/magic.py +60 -0
  92. openstat/logging_config.py +59 -0
  93. openstat/plots/__init__.py +0 -0
  94. openstat/plots/plotter.py +437 -0
  95. openstat/plots/surv_plots.py +32 -0
  96. openstat/plots/ts_plots.py +59 -0
  97. openstat/plugins/__init__.py +5 -0
  98. openstat/plugins/manager.py +69 -0
  99. openstat/repl.py +457 -0
  100. openstat/reporting/__init__.py +0 -0
  101. openstat/reporting/eda.py +208 -0
  102. openstat/reporting/report.py +67 -0
  103. openstat/script_runner.py +319 -0
  104. openstat/session.py +133 -0
  105. openstat/stats/__init__.py +0 -0
  106. openstat/stats/advanced_regression.py +269 -0
  107. openstat/stats/arch_garch.py +84 -0
  108. openstat/stats/bayesian.py +103 -0
  109. openstat/stats/causal.py +258 -0
  110. openstat/stats/clustering.py +206 -0
  111. openstat/stats/discrete.py +311 -0
  112. openstat/stats/epidemiology.py +119 -0
  113. openstat/stats/equiv_tobit.py +163 -0
  114. openstat/stats/factor.py +174 -0
  115. openstat/stats/imputation.py +282 -0
  116. openstat/stats/influence.py +78 -0
  117. openstat/stats/iv.py +131 -0
  118. openstat/stats/manova.py +124 -0
  119. openstat/stats/mixed.py +128 -0
  120. openstat/stats/ml.py +275 -0
  121. openstat/stats/ml_advanced.py +117 -0
  122. openstat/stats/model_eval.py +183 -0
  123. openstat/stats/models.py +1342 -0
  124. openstat/stats/nonparametric.py +130 -0
  125. openstat/stats/panel.py +179 -0
  126. openstat/stats/power.py +295 -0
  127. openstat/stats/resampling.py +203 -0
  128. openstat/stats/survey.py +213 -0
  129. openstat/stats/survival.py +196 -0
  130. openstat/stats/timeseries.py +142 -0
  131. openstat/stats/ts_advanced.py +114 -0
  132. openstat/types.py +11 -0
  133. openstat/web/__init__.py +1 -0
  134. openstat/web/app.py +117 -0
  135. openstat/web/session_manager.py +73 -0
  136. openstat/web/static/app.js +117 -0
  137. openstat/web/static/index.html +38 -0
  138. openstat/web/static/style.css +103 -0
  139. openstat_cli-1.0.0.dist-info/METADATA +748 -0
  140. openstat_cli-1.0.0.dist-info/RECORD +143 -0
  141. openstat_cli-1.0.0.dist-info/WHEEL +4 -0
  142. openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
  143. openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,174 @@
+ """Data quality commands: duplicates, winsor, standardize, normalize, mdpattern."""
+
+ from __future__ import annotations
+
+ import re
+
+ import polars as pl
+
+ from openstat.commands.base import command
+ from openstat.session import Session
+
+
+ def _stata_opts(raw: str) -> tuple[list[str], dict[str, str]]:
+     opts: dict[str, str] = {}
+     for m in re.finditer(r'(\w+)\(([^)]*)\)', raw):
+         opts[m.group(1).lower()] = m.group(2)
+     rest = re.sub(r'\w+\([^)]*\)', '', raw)
+     positional = [t.strip(',') for t in rest.split() if t.strip(',')]
+     return positional, opts
+
+
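For orientation, `_stata_opts` splits a Stata-style argument string into positional tokens plus a name-to-value option map; a quick sketch of its expected behaviour (illustrative, not part of the package source):

    >>> _stata_opts("price mpg p(0.01) gen(price_w)")
    (['price', 'mpg'], {'p': '0.01', 'gen': 'price_w'})
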
+ @command("duplicates", usage="duplicates [report|drop|list] [varlist]")
+ def cmd_duplicates(session: Session, args: str) -> str:
+     """Report, list, or drop duplicate observations."""
+     df = session.require_data()
+     positional, opts = _stata_opts(args)
+     action = positional[0].lower() if positional else "report"
+     subset = [c for c in positional[1:] if c in df.columns] or None
+
+     if action == "report":
+         if subset:
+             dup_mask = df.select(subset).is_duplicated()
+         else:
+             dup_mask = df.is_duplicated()
+         n_dup = int(dup_mask.sum())
+         n_surplus = int(df.filter(dup_mask).height - df.filter(dup_mask).unique(subset=subset).height) if n_dup > 0 else 0
+         return (
+             f"Duplicates report:\n"
+             f" Total observations: {df.height}\n"
+             f" Duplicate rows: {n_dup}\n"
+             f" Surplus copies (beyond first occurrence): {n_surplus}\n"
+             f" Subset: {subset or 'all columns'}"
+         )
+     elif action in ("drop", "list"):
+         if subset:
+             clean_df = df.unique(subset=subset, keep="first", maintain_order=True)
+         else:
+             clean_df = df.unique(keep="first", maintain_order=True)
+         n_dropped = df.height - clean_df.height
+         if action == "drop":
+             session.snapshot()
+             session.df = clean_df
+             return f"Dropped {n_dropped} duplicate rows. {clean_df.height} rows remain."
+         else:
+             return f"Found {n_dropped} duplicate rows (use 'duplicates drop' to remove)."
+     else:
+         return "Usage: duplicates [report|drop|list] [varlist]"
+
+
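The report branch leans on polars' `is_duplicated`, which flags every member of a duplicate group (including the first occurrence), so surplus copies are the flagged rows minus the distinct ones. A standalone sketch of that arithmetic (illustrative data):

    import polars as pl

    df = pl.DataFrame({"id": [1, 1, 2, 3, 3, 3]})
    mask = df.is_duplicated()                 # [T, T, F, T, T, T]
    n_dup = int(mask.sum())                   # 5 rows belong to duplicate groups
    n_surplus = df.filter(mask).height - df.filter(mask).unique().height  # 5 - 2 = 3
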
+ @command("winsor", usage="winsor varname [p(0.05) gen(newvar)]")
+ def cmd_winsor(session: Session, args: str) -> str:
+     """Winsorize a variable at specified percentile (both tails)."""
+     df = session.require_data()
+     positional, opts = _stata_opts(args)
+     if not positional:
+         return "Usage: winsor varname [p(0.05) gen(newvar)]"
+     var = positional[0]
+     if var not in df.columns:
+         return f"Column '{var}' not found."
+     p = float(opts.get("p", 0.05))
+     new_var = opts.get("gen", f"{var}_w")
+     session.snapshot()
+     try:
+         series = df[var].cast(pl.Float64)
+         lo = float(series.quantile(p))
+         hi = float(series.quantile(1 - p))
+         winsorized = series.clip(lo, hi)
+         session.df = df.with_columns(winsorized.alias(new_var))
+         n_lo = int((series < lo).sum())
+         n_hi = int((series > hi).sum())
+         return (
+             f"Winsorized '{var}' → '{new_var}'\n"
+             f" Lower cutoff ({p*100:.1f}%): {lo:.4f} ({n_lo} obs clipped)\n"
+             f" Upper cutoff ({(1-p)*100:.1f}%): {hi:.4f} ({n_hi} obs clipped)"
+         )
+     except Exception as exc:
+         return f"winsor error: {exc}"
+
+
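Winsorization here is two quantile cuts plus a clip; the equivalent standalone polars steps, sketched for a 5% trim (illustrative only):

    import polars as pl

    s = pl.Series([1.0, 2.0, 3.0, 4.0, 100.0])
    lo, hi = s.quantile(0.05), s.quantile(0.95)
    winsorized = s.clip(lo, hi)   # tails pulled in to the 5th/95th percentiles
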
+ @command("standardize", usage="standardize var1 [var2 ...] [gen(prefix_)]")
+ def cmd_standardize(session: Session, args: str) -> str:
+     """Z-score standardize variables: (x - mean) / std."""
+     df = session.require_data()
+     positional, opts = _stata_opts(args)
+     cols = [c for c in positional if c in df.columns]
+     if not cols:
+         return "No valid numeric variables found."
+     prefix = opts.get("gen", "")
+     session.snapshot()
+     try:
+         new_df = df
+         new_cols = []
+         for col in cols:
+             s = df[col].cast(pl.Float64)
+             m = float(s.mean())
+             sd = float(s.std())
+             new_name = f"{prefix}{col}" if prefix else f"{col}_z"
+             new_df = new_df.with_columns(((s - m) / max(sd, 1e-10)).alias(new_name))
+             new_cols.append(new_name)
+         session.df = new_df
+         return f"Standardized {len(cols)} variable(s): {new_cols}"
+     except Exception as exc:
+         return f"standardize error: {exc}"
+
+
+ @command("normalize", usage="normalize var1 [var2 ...] [gen(prefix_)]")
+ def cmd_normalize(session: Session, args: str) -> str:
+     """Min-max normalize variables to [0, 1]."""
+     df = session.require_data()
+     positional, opts = _stata_opts(args)
+     cols = [c for c in positional if c in df.columns]
+     if not cols:
+         return "No valid numeric variables found."
+     prefix = opts.get("gen", "")
+     session.snapshot()
+     try:
+         new_df = df
+         new_cols = []
+         for col in cols:
+             s = df[col].cast(pl.Float64)
+             lo = float(s.min())
+             hi = float(s.max())
+             new_name = f"{prefix}{col}" if prefix else f"{col}_norm"
+             new_df = new_df.with_columns(((s - lo) / max(hi - lo, 1e-10)).alias(new_name))
+             new_cols.append(new_name)
+         session.df = new_df
+         return f"Normalized {len(cols)} variable(s) to [0,1]: {new_cols}"
+     except Exception as exc:
+         return f"normalize error: {exc}"
+
+
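Both transforms are one polars expression each; a compact sketch of the same math via expressions (illustrative, column name `x` assumed; the commands above additionally floor the std/range at 1e-10 to avoid division by zero):

    import polars as pl

    df = pl.DataFrame({"x": [1.0, 2.0, 3.0, 4.0]})
    df = df.with_columns(
        ((pl.col("x") - pl.col("x").mean()) / pl.col("x").std()).alias("x_z"),   # z-score
        ((pl.col("x") - pl.col("x").min())
         / (pl.col("x").max() - pl.col("x").min())).alias("x_norm"),             # min-max to [0, 1]
    )
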
+ @command("mdpattern", usage="mdpattern [var1 var2 ...]")
+ def cmd_mdpattern(session: Session, args: str) -> str:
+     """Display missing data pattern for all (or specified) variables."""
+     df = session.require_data()
+     positional, opts = _stata_opts(args)
+     cols = [c for c in positional if c in df.columns] or df.columns
+
+     lines = ["\nMissing Data Pattern", "=" * 60]
+     lines.append(f" N = {df.height} observations, {len(cols)} variables\n")
+
+     col_w = max(len(c) for c in cols) + 2
+     header = f" {'Variable':<{col_w}} {'Missing':>8} {'%Missing':>10} {'Complete':>10}"
+     lines.append(header)
+     lines.append(" " + "-" * (col_w + 32))
+
+     total_missing = 0
+     for col in cols:
+         n_miss = int(df[col].is_null().sum())
+         pct = 100.0 * n_miss / df.height if df.height > 0 else 0.0
+         n_complete = df.height - n_miss
+         total_missing += n_miss
+         bar = "░" * int(pct / 5)  # bar grows in 5% increments
+         lines.append(f" {col:<{col_w}} {n_miss:>8} {pct:>9.1f}% {n_complete:>10} {bar}")
+
+     lines.append(" " + "-" * (col_w + 32))
+     overall_pct = 100.0 * total_missing / (df.height * len(cols)) if df.height > 0 else 0.0
+     lines.append(f" {'Total missing cells':<{col_w}} {total_missing:>8} {overall_pct:>9.1f}%")
+
+     # Complete cases
+     n_complete_rows = int(df.select(cols).drop_nulls().height)
+     lines.append(f"\n Complete rows (no missing in any selected var): {n_complete_rows} ({(100.0 * n_complete_rows / df.height if df.height > 0 else 0.0):.1f}%)")
+
+     return "\n".join(lines)
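The per-column counts that drive the table come straight from polars; for a quick interactive check, the same numbers are available in one expression (illustrative data):

    import polars as pl

    df = pl.DataFrame({"a": [1, None, 3], "b": [None, None, 6]})
    df.select(pl.all().null_count())   # one row of per-column null counts: a=1, b=2
    complete = df.drop_nulls().height  # rows with no missing value in any column
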
@@ -0,0 +1,176 @@
+ """Datetime operations: extract, arithmetic, format."""
+
+ from __future__ import annotations
+
+ from openstat.commands.base import command, CommandArgs, friendly_error
+ from openstat.session import Session
+
+
+ @command("datetime", usage="datetime extract|diff|format|parse|shift <col> [options]")
+ def cmd_datetime(session: Session, args: str) -> str:
+     """Datetime column operations.
+
+     Sub-commands:
+       datetime extract <col> [into(<prefix>)]
+           — extract year, month, day, hour, minute, weekday, quarter
+       datetime diff <col1> <col2> [unit=days|hours|minutes] [into(<newcol>)]
+           — compute difference between two date columns
+       datetime format <col> <fmt> [into(<newcol>)]
+           — reformat datetime as string (strftime format)
+       datetime parse <col> [fmt=<format>] [into(<newcol>)]
+           — parse a string column as datetime
+       datetime shift <col> <N> <unit> [into(<newcol>)]
+           — add/subtract time (e.g. shift date 7 days)
+
+     Examples:
+       datetime extract created_at into(dt)
+           → creates dt_year, dt_month, dt_day, dt_weekday, dt_quarter
+       datetime diff end_date start_date unit=days into(duration)
+       datetime format created_at "%Y-%m" into(year_month)
+       datetime parse date_str fmt="%d/%m/%Y" into(date)
+       datetime shift order_date 30 days into(delivery_date)
+     """
+     import polars as pl
+
+     ca = CommandArgs(args)
+     if not ca.positional:
+         return "Usage: datetime extract|diff|format|parse|shift <col> ..."
+
+     subcmd = ca.positional[0].lower()
+
+     try:
+         df = session.require_data()
+
+         if subcmd == "extract":
+             if len(ca.positional) < 2:
+                 return "Usage: datetime extract <col> [into(<prefix>)]"
+             col = ca.positional[1]
+             if col not in df.columns:
+                 return f"Column not found: {col}"
+
+             into_raw = ca.rest_after("into")
+             prefix = into_raw.strip().strip("()") if into_raw else col
+
+             # Cast to datetime if string
+             series = df[col]
+             if series.dtype == pl.Utf8:
+                 series = series.str.to_datetime(strict=False)
+
+             dt = series.dt
+             new_cols = {
+                 f"{prefix}_year": dt.year(),
+                 f"{prefix}_month": dt.month(),
+                 f"{prefix}_day": dt.day(),
+                 f"{prefix}_hour": dt.hour(),
+                 f"{prefix}_minute": dt.minute(),
+                 f"{prefix}_weekday": dt.weekday(),
+                 f"{prefix}_quarter": dt.quarter(),
+             }
+             session.snapshot()
+             session.df = df.with_columns([
+                 pl.Series(name, vals) for name, vals in new_cols.items()
+             ])
+             parts = ", ".join(new_cols.keys())
+             return f"Extracted datetime components: {parts}"
+
+         elif subcmd == "diff":
+             if len(ca.positional) < 3:
+                 return "Usage: datetime diff <col1> <col2> [unit=days] [into(<newcol>)]"
+             c1, c2 = ca.positional[1], ca.positional[2]
+             unit = ca.options.get("unit", "days")
+             into_raw = ca.rest_after("into")
+             newcol = into_raw.strip().strip("()") if into_raw else f"{c1}_minus_{c2}"
+
+             for c in [c1, c2]:
+                 if c not in df.columns:
+                     return f"Column not found: {c}"
+
+             def _to_dt(s):
+                 if s.dtype == pl.Utf8:
+                     return s.str.to_datetime(strict=False)
+                 return s.cast(pl.Datetime)
+
+             s1 = _to_dt(df[c1])
+             s2 = _to_dt(df[c2])
+             diff_dur = s1 - s2
+
+             unit_map = {  # microseconds per unit
+                 "days": 86_400_000_000,
+                 "hours": 3_600_000_000,
+                 "minutes": 60_000_000,
+                 "seconds": 1_000_000,
+             }
+             divisor = unit_map.get(unit, 86_400_000_000)
+             diff_num = (diff_dur.dt.total_microseconds() / divisor).cast(pl.Float64)
+
+             session.snapshot()
+             session.df = df.with_columns(diff_num.alias(newcol))
+             return f"Date difference stored in '{newcol}' ({unit}). Mean: {diff_num.mean():.2f}"
+
+         elif subcmd == "format":
+             if len(ca.positional) < 3:
+                 return "Usage: datetime format <col> <fmt> [into(<newcol>)]"
+             col, fmt = ca.positional[1], ca.positional[2]
+             if col not in df.columns:
+                 return f"Column not found: {col}"
+             into_raw = ca.rest_after("into")
+             newcol = into_raw.strip().strip("()") if into_raw else f"{col}_fmt"
+
+             series = df[col]
+             if series.dtype == pl.Utf8:
+                 series = series.str.to_datetime(strict=False)
+             formatted = series.dt.strftime(fmt)
+
+             session.snapshot()
+             session.df = df.with_columns(formatted.alias(newcol))
+             return f"Formatted '{col}' → '{newcol}' using format '{fmt}'"
+
+         elif subcmd == "parse":
+             if len(ca.positional) < 2:
+                 return "Usage: datetime parse <col> [fmt=<format>] [into(<newcol>)]"
+             col = ca.positional[1]
+             if col not in df.columns:
+                 return f"Column not found: {col}"
+             fmt = ca.options.get("fmt")
+             into_raw = ca.rest_after("into")
+             newcol = into_raw.strip().strip("()") if into_raw else f"{col}_dt"
+
+             if fmt:
+                 parsed = df[col].str.to_datetime(format=fmt, strict=False)
+             else:
+                 parsed = df[col].str.to_datetime(strict=False)
+
+             session.snapshot()
+             session.df = df.with_columns(parsed.alias(newcol))
+             n_ok = parsed.drop_nulls().len()
+             return f"Parsed '{col}' → '{newcol}': {n_ok}/{df.height} rows parsed."
+
+         elif subcmd == "shift":
+             if len(ca.positional) < 4:
+                 return "Usage: datetime shift <col> <N> <unit> [into(<newcol>)]"
+             col = ca.positional[1]
+             n = int(ca.positional[2])
+             unit = ca.positional[3].lower().rstrip("s")  # days→day, hours→hour
+             if col not in df.columns:
+                 return f"Column not found: {col}"
+             into_raw = ca.rest_after("into")
+             newcol = into_raw.strip().strip("()") if into_raw else f"{col}_shifted"
+
+             series = df[col]
+             if series.dtype == pl.Utf8:
+                 series = series.str.to_datetime(strict=False)
+
+             from datetime import timedelta
+             unit_map2 = {"day": "days", "hour": "hours", "minute": "minutes", "second": "seconds", "week": "weeks"}
+             td_key = unit_map2.get(unit, "days")
+             shifted = series + timedelta(**{td_key: n})  # Series + timedelta broadcasts elementwise
+
+             session.snapshot()
+             session.df = df.with_columns(shifted.alias(newcol))
+             return f"Shifted '{col}' by {n} {unit}(s) → '{newcol}'"
+
+         else:
+             return f"Unknown sub-command: {subcmd}. Use extract, diff, format, parse, or shift."
+
+     except Exception as e:
+         return friendly_error(e, "datetime")
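The `diff` arithmetic converts a polars Duration to a float count of units via `total_microseconds()`, and `shift` relies on a `datetime.timedelta` broadcasting over a Datetime series. A self-contained sketch of both (illustrative data):

    import polars as pl
    from datetime import datetime, timedelta

    df = pl.DataFrame({
        "start": [datetime(2024, 1, 1), datetime(2024, 1, 10)],
        "end":   [datetime(2024, 1, 3), datetime(2024, 2, 9)],
    })

    # diff: Duration -> float days (86_400_000_000 microseconds per day)
    days = (df["end"] - df["start"]).dt.total_microseconds() / 86_400_000_000

    # shift: Datetime series + timedelta applies elementwise
    delivery = df["start"] + timedelta(days=30)
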
@@ -0,0 +1,184 @@
+ """Dimensionality reduction: t-SNE, UMAP, PCA plot."""
+
+ from __future__ import annotations
+ from openstat.commands.base import command, CommandArgs, friendly_error
+ from openstat.session import Session
+
+
+ @command("tsne", usage="tsne [cols...] [--n=2] [--perplexity=30] [--out=tsne.png] [--color=col]")
+ def cmd_tsne(session: Session, args: str) -> str:
+     """t-SNE dimensionality reduction and visualization.
+
+     Options:
+       --n=<dim>          output dimensions (2 or 3, default: 2)
+       --perplexity=<p>   perplexity (5–50, default: 30)
+       --iter=<n>         iterations (default: 1000)
+       --color=<col>      column to colour points by
+       --out=<path>       output image path
+
+     Examples:
+       tsne x1 x2 x3 x4 --color=label
+       tsne --perplexity=20 --iter=2000 --out=tsne_result.png
+     """
+     try:
+         from sklearn.manifold import TSNE
+     except ImportError:
+         return "scikit-learn required. Install: pip install scikit-learn"
+
+     import matplotlib
+     matplotlib.use("Agg")
+     import matplotlib.pyplot as plt
+     import numpy as np
+     import polars as pl
+
+     ca = CommandArgs(args)
+     n_dim = int(ca.options.get("n", 2))
+     perplexity = float(ca.options.get("perplexity", 30))
+     n_iter = int(ca.options.get("iter", 1000))
+     color_col = ca.options.get("color")
+     out_path = ca.options.get("out", str(session.output_dir / "tsne.png"))
+
+     try:
+         df = session.require_data()
+         NUMERIC = (pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64,
+                    pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64)
+         if ca.positional:
+             cols = [c for c in ca.positional if c in df.columns]
+         else:
+             cols = [c for c in df.columns if df[c].dtype in NUMERIC]
+
+         if len(cols) < 2:
+             return "Need at least 2 numeric columns for t-SNE."
+
+         sub_cols = cols[:]
+         if color_col and color_col in df.columns and color_col not in sub_cols:
+             sub_cols.append(color_col)
+
+         sub = df.select(sub_cols).drop_nulls()
+         X = sub.select(cols).to_numpy().astype(float)
+
+         if len(X) < 5:
+             return "Need at least 5 rows for t-SNE."
+
+         perplexity = min(perplexity, len(X) - 1)  # sklearn requires perplexity < n_samples
+         tsne = TSNE(n_components=n_dim, perplexity=perplexity, max_iter=n_iter,
+                     random_state=42)
+         embedding = tsne.fit_transform(X)
+
+         fig, ax = plt.subplots(figsize=(8, 6))
+         if color_col and color_col in sub.columns:
+             cats = sub[color_col].cast(pl.Utf8).to_list()
+             unique_cats = sorted(set(cats))
+             cmap = plt.colormaps.get_cmap("tab10")
+             for i, cat in enumerate(unique_cats):
+                 mask = np.array([c == cat for c in cats])
+                 ax.scatter(embedding[mask, 0], embedding[mask, 1],
+                            label=str(cat), alpha=0.7, s=20,
+                            color=cmap(i / max(len(unique_cats), 1)))
+             ax.legend(title=color_col, markerscale=2, fontsize=8)
+         else:
+             ax.scatter(embedding[:, 0], embedding[:, 1], alpha=0.6, s=20, color="#4C72B0")
+
+         ax.set_xlabel("t-SNE 1")
+         ax.set_ylabel("t-SNE 2")
+         ax.set_title(f"t-SNE ({len(cols)} features, perplexity={perplexity:.0f})")
+         fig.tight_layout()
+
+         session.output_dir.mkdir(parents=True, exist_ok=True)
+         from pathlib import Path
+         Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+         fig.savefig(out_path, dpi=150)
+         plt.close(fig)
+         session.plot_paths.append(out_path)
+         return f"t-SNE plot saved: {out_path} (n={len(X)}, features={len(cols)})"
+     except Exception as e:
+         return friendly_error(e, "tsne")
+
+
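The command clamps perplexity to n−1 because scikit-learn requires perplexity < n_samples; a minimal standalone call, assuming scikit-learn ≥ 1.5 (where `n_iter` was renamed `max_iter`):

    import numpy as np
    from sklearn.manifold import TSNE

    X = np.random.default_rng(42).normal(size=(100, 4))
    emb = TSNE(n_components=2, perplexity=min(30.0, len(X) - 1),
               max_iter=1000, random_state=42).fit_transform(X)
    # emb.shape == (100, 2)
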
+ @command("umap", usage="umap [cols...] [--n=2] [--neighbors=15] [--out=umap.png] [--color=col]")
+ def cmd_umap(session: Session, args: str) -> str:
+     """UMAP dimensionality reduction and visualization.
+
+     Options:
+       --n=<dim>          output dimensions (2 or 3, default: 2)
+       --neighbors=<k>    number of neighbors (default: 15)
+       --mindist=<d>      minimum distance (default: 0.1)
+       --color=<col>      column to colour points by
+       --out=<path>       output image path
+
+     Examples:
+       umap x1 x2 x3 x4 --color=label
+       umap --neighbors=20 --mindist=0.05
+     """
+     try:
+         import umap as umap_lib
+     except ImportError:
+         return "umap-learn required. Install: pip install umap-learn"
+
+     import matplotlib
+     matplotlib.use("Agg")
+     import matplotlib.pyplot as plt
+     import polars as pl
+
+     ca = CommandArgs(args)
+     n_dim = int(ca.options.get("n", 2))
+     n_neighbors = int(ca.options.get("neighbors", 15))
+     min_dist = float(ca.options.get("mindist", 0.1))
+     color_col = ca.options.get("color")
+     out_path = ca.options.get("out", str(session.output_dir / "umap.png"))
+
+     try:
+         df = session.require_data()
+         NUMERIC = (pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64,
+                    pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64)
+         if ca.positional:
+             cols = [c for c in ca.positional if c in df.columns]
+         else:
+             cols = [c for c in df.columns if df[c].dtype in NUMERIC]
+
+         if len(cols) < 2:
+             return "Need at least 2 numeric columns for UMAP."
+
+         sub_cols = cols[:]
+         if color_col and color_col in df.columns and color_col not in sub_cols:
+             sub_cols.append(color_col)
+
+         sub = df.select(sub_cols).drop_nulls()
+         X = sub.select(cols).to_numpy().astype(float)
+
+         if len(X) < 4:
+             return "Need at least 4 rows for UMAP."
+
+         n_neighbors = min(n_neighbors, len(X) - 1)
+         reducer = umap_lib.UMAP(n_components=n_dim, n_neighbors=n_neighbors,
+                                 min_dist=min_dist, random_state=42)
+         embedding = reducer.fit_transform(X)
+
+         fig, ax = plt.subplots(figsize=(8, 6))
+         if color_col and color_col in sub.columns:
+             cats = sub[color_col].cast(pl.Utf8).to_list()
+             unique_cats = sorted(set(cats))
+             cmap = plt.colormaps.get_cmap("tab10")
+             for i, cat in enumerate(unique_cats):
+                 mask = [c == cat for c in cats]
+                 ax.scatter(embedding[mask, 0], embedding[mask, 1],
+                            label=str(cat), alpha=0.7, s=20,
+                            color=cmap(i / max(len(unique_cats), 1)))
+             ax.legend(title=color_col, markerscale=2, fontsize=8)
+         else:
+             ax.scatter(embedding[:, 0], embedding[:, 1], alpha=0.6, s=20, color="#4C72B0")
+
+         ax.set_xlabel("UMAP 1")
+         ax.set_ylabel("UMAP 2")
+         ax.set_title(f"UMAP ({len(cols)} features, neighbors={n_neighbors})")
+         fig.tight_layout()
+
+         session.output_dir.mkdir(parents=True, exist_ok=True)
+         from pathlib import Path
+         Path(out_path).parent.mkdir(parents=True, exist_ok=True)
+         fig.savefig(out_path, dpi=150)
+         plt.close(fig)
+         session.plot_paths.append(out_path)
+         return f"UMAP plot saved: {out_path} (n={len(X)}, features={len(cols)})"
+     except Exception as e:
+         return friendly_error(e, "umap")
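`n_neighbors` is clamped the same way, since UMAP needs fewer neighbors than rows; a minimal standalone call (assumes umap-learn is installed):

    import numpy as np
    import umap

    X = np.random.default_rng(42).normal(size=(200, 8))
    emb = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1,
                    random_state=42).fit_transform(X)
    # emb.shape == (200, 2)
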
@@ -0,0 +1,149 @@
+ """Discrete / censored model commands: tobit, mlogit, ologit, oprobit."""
+
+ from __future__ import annotations
+
+ import re
+
+ from openstat.session import Session, ModelResult
+ from openstat.dsl.parser import parse_formula, ParseError
+ from openstat.stats.discrete import fit_tobit, fit_mlogit, fit_ordered
+ from openstat.commands.base import command, CommandArgs, friendly_error
+
+
+ def _store_model(session, result, raw_model, dep, indeps):
+     """Store model in session state, return summary output."""
+     session._last_model = raw_model
+     session._last_model_vars = (dep, indeps)
+     session._last_fit_result = result
+     session._last_fit_kwargs = {}
+     md = result.to_markdown()
+     details: dict = {
+         "n_obs": result.n_obs,
+         "params": dict(result.params),
+         "std_errors": dict(result.std_errors),
+     }
+     if result.aic is not None:
+         details["aic"] = result.aic
+     if result.bic is not None:
+         details["bic"] = result.bic
+     if result.pseudo_r2 is not None:
+         details["pseudo_r2"] = result.pseudo_r2
+     if result.log_likelihood is not None:
+         details["log_likelihood"] = result.log_likelihood
+     session.results.append(ModelResult(
+         name=result.model_type, formula=result.formula,
+         table=md, details=details,
+     ))
+     output = result.summary_table()
+     if result.warnings:
+         output += "\n" + "\n".join(result.warnings)
+     return output
+
+
+ @command("tobit", usage="tobit y ~ x1 + x2 [, ll(0)] [, ul(100)] [--robust]")
44
+ def cmd_tobit(session: Session, args: str) -> str:
45
+ """Fit a Tobit (censored) regression model."""
46
+ df = session.require_data()
47
+ ca = CommandArgs(args)
48
+ robust = ca.has_flag("--robust")
49
+ cluster_col = ca.get_option("cluster")
50
+
51
+ # Parse limits from args: ll(value) and ul(value)
52
+ lower_limit = None
53
+ upper_limit = None
54
+ ll_match = re.search(r'(?:,\s*)?ll\(([^)]+)\)', args)
55
+ ul_match = re.search(r'(?:,\s*)?ul\(([^)]+)\)', args)
56
+ if ll_match:
57
+ try:
58
+ lower_limit = float(ll_match.group(1))
59
+ except ValueError:
60
+ return f"Invalid lower limit: {ll_match.group(1)}"
61
+ if ul_match:
62
+ try:
63
+ upper_limit = float(ul_match.group(1))
64
+ except ValueError:
65
+ return f"Invalid upper limit: {ul_match.group(1)}"
66
+
67
+ # Strip limit specs from formula
68
+ formula_str = args
69
+ for pattern in [r',?\s*ll\([^)]+\)', r',?\s*ul\([^)]+\)']:
70
+ formula_str = re.sub(pattern, '', formula_str)
71
+ formula_str = CommandArgs(formula_str).strip_flags_and_options()
72
+
73
+ if not formula_str or "~" not in formula_str:
74
+ return "Usage: tobit y ~ x1 + x2 [, ll(0)] [, ul(100)] [--robust]"
75
+
76
+ try:
77
+ dep, indeps = parse_formula(formula_str)
78
+ result, raw_model = fit_tobit(
79
+ df, dep, indeps,
80
+ lower_limit=lower_limit,
81
+ upper_limit=upper_limit,
82
+ robust=robust,
83
+ cluster_col=cluster_col,
84
+ )
85
+ return _store_model(session, result, raw_model, dep, indeps)
86
+ except ParseError as e:
87
+ return f"Formula error: {e}"
88
+ except Exception as e:
89
+ return friendly_error(e, "Tobit error")
90
+
91
+
92
+ @command("mlogit", usage="mlogit y ~ x1 + x2 [--robust] [--cluster=col]")
93
+ def cmd_mlogit(session: Session, args: str) -> str:
94
+ """Fit a Multinomial Logit model."""
95
+ df = session.require_data()
96
+ ca = CommandArgs(args)
97
+ robust = ca.has_flag("--robust")
98
+ cluster_col = ca.get_option("cluster")
99
+ formula_str = ca.strip_flags_and_options()
100
+ if not formula_str:
101
+ return "Usage: mlogit y ~ x1 + x2 [--robust] [--cluster=col]"
102
+ try:
103
+ dep, indeps = parse_formula(formula_str)
104
+ result, raw_model = fit_mlogit(df, dep, indeps, robust=robust, cluster_col=cluster_col)
105
+ return _store_model(session, result, raw_model, dep, indeps)
106
+ except ParseError as e:
107
+ return f"Formula error: {e}"
108
+ except Exception as e:
109
+ return friendly_error(e, "MNLogit error")
110
+
111
+
112
+ @command("ologit", usage="ologit y ~ x1 + x2 [--robust]")
113
+ def cmd_ologit(session: Session, args: str) -> str:
114
+ """Fit an Ordered Logit model."""
115
+ df = session.require_data()
116
+ ca = CommandArgs(args)
117
+ robust = ca.has_flag("--robust")
118
+ cluster_col = ca.get_option("cluster")
119
+ formula_str = ca.strip_flags_and_options()
120
+ if not formula_str:
121
+ return "Usage: ologit y ~ x1 + x2 [--robust]"
122
+ try:
123
+ dep, indeps = parse_formula(formula_str)
124
+ result, raw_model = fit_ordered(df, dep, indeps, link="logit", robust=robust, cluster_col=cluster_col)
125
+ return _store_model(session, result, raw_model, dep, indeps)
126
+ except ParseError as e:
127
+ return f"Formula error: {e}"
128
+ except Exception as e:
129
+ return friendly_error(e, "Ordered Logit error")
130
+
131
+
132
+ @command("oprobit", usage="oprobit y ~ x1 + x2 [--robust]")
133
+ def cmd_oprobit(session: Session, args: str) -> str:
134
+ """Fit an Ordered Probit model."""
135
+ df = session.require_data()
136
+ ca = CommandArgs(args)
137
+ robust = ca.has_flag("--robust")
138
+ cluster_col = ca.get_option("cluster")
139
+ formula_str = ca.strip_flags_and_options()
140
+ if not formula_str:
141
+ return "Usage: oprobit y ~ x1 + x2 [--robust]"
142
+ try:
143
+ dep, indeps = parse_formula(formula_str)
144
+ result, raw_model = fit_ordered(df, dep, indeps, link="probit", robust=robust, cluster_col=cluster_col)
145
+ return _store_model(session, result, raw_model, dep, indeps)
146
+ except ParseError as e:
147
+ return f"Formula error: {e}"
148
+ except Exception as e:
149
+ return friendly_error(e, "Ordered Probit error")
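Taken together, the four commands share the same formula-plus-flags shape; representative invocations based on the usage strings above (column names illustrative):

    tobit wage ~ educ + exper, ll(0) --robust
    mlogit choice ~ income + age --cluster=region
    ologit rating ~ income + age
    oprobit rating ~ income + age --robust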