openstat-cli 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. openstat/__init__.py +3 -0
  2. openstat/__main__.py +4 -0
  3. openstat/backends/__init__.py +16 -0
  4. openstat/backends/duckdb_backend.py +70 -0
  5. openstat/backends/polars_backend.py +52 -0
  6. openstat/cli.py +92 -0
  7. openstat/commands/__init__.py +82 -0
  8. openstat/commands/adv_stat_cmds.py +1255 -0
  9. openstat/commands/advanced_ml_cmds.py +576 -0
  10. openstat/commands/advreg_cmds.py +207 -0
  11. openstat/commands/alias_cmds.py +135 -0
  12. openstat/commands/arch_cmds.py +82 -0
  13. openstat/commands/arules_cmds.py +111 -0
  14. openstat/commands/automodel_cmds.py +212 -0
  15. openstat/commands/backend_cmds.py +82 -0
  16. openstat/commands/base.py +170 -0
  17. openstat/commands/bayes_cmds.py +71 -0
  18. openstat/commands/causal_cmds.py +269 -0
  19. openstat/commands/cluster_cmds.py +152 -0
  20. openstat/commands/data_cmds.py +996 -0
  21. openstat/commands/datamanip_cmds.py +672 -0
  22. openstat/commands/dataquality_cmds.py +174 -0
  23. openstat/commands/datetime_cmds.py +176 -0
  24. openstat/commands/dimreduce_cmds.py +184 -0
  25. openstat/commands/discrete_cmds.py +149 -0
  26. openstat/commands/dsl_cmds.py +143 -0
  27. openstat/commands/epi_cmds.py +93 -0
  28. openstat/commands/equiv_tobit_cmds.py +94 -0
  29. openstat/commands/esttab_cmds.py +196 -0
  30. openstat/commands/export_beamer_cmds.py +142 -0
  31. openstat/commands/export_cmds.py +201 -0
  32. openstat/commands/export_extra_cmds.py +240 -0
  33. openstat/commands/factor_cmds.py +180 -0
  34. openstat/commands/groupby_cmds.py +155 -0
  35. openstat/commands/help_cmds.py +237 -0
  36. openstat/commands/i18n_cmds.py +43 -0
  37. openstat/commands/import_extra_cmds.py +561 -0
  38. openstat/commands/influence_cmds.py +134 -0
  39. openstat/commands/iv_cmds.py +106 -0
  40. openstat/commands/manova_cmds.py +105 -0
  41. openstat/commands/mediate_cmds.py +233 -0
  42. openstat/commands/meta_cmds.py +284 -0
  43. openstat/commands/mi_cmds.py +228 -0
  44. openstat/commands/mixed_cmds.py +79 -0
  45. openstat/commands/mixture_changepoint_cmds.py +166 -0
  46. openstat/commands/ml_adv_cmds.py +147 -0
  47. openstat/commands/ml_cmds.py +178 -0
  48. openstat/commands/model_eval_cmds.py +142 -0
  49. openstat/commands/network_cmds.py +288 -0
  50. openstat/commands/nlquery_cmds.py +161 -0
  51. openstat/commands/nonparam_cmds.py +149 -0
  52. openstat/commands/outreg_cmds.py +247 -0
  53. openstat/commands/panel_cmds.py +141 -0
  54. openstat/commands/pdf_cmds.py +226 -0
  55. openstat/commands/pipeline_cmds.py +319 -0
  56. openstat/commands/plot_cmds.py +189 -0
  57. openstat/commands/plugin_cmds.py +79 -0
  58. openstat/commands/posthoc_cmds.py +153 -0
  59. openstat/commands/power_cmds.py +172 -0
  60. openstat/commands/profile_cmds.py +246 -0
  61. openstat/commands/rbridge_cmds.py +81 -0
  62. openstat/commands/regex_cmds.py +104 -0
  63. openstat/commands/report_cmds.py +48 -0
  64. openstat/commands/repro_cmds.py +129 -0
  65. openstat/commands/resampling_cmds.py +109 -0
  66. openstat/commands/reshape_cmds.py +223 -0
  67. openstat/commands/sem_cmds.py +177 -0
  68. openstat/commands/stat_cmds.py +1040 -0
  69. openstat/commands/stata_import_cmds.py +215 -0
  70. openstat/commands/string_cmds.py +124 -0
  71. openstat/commands/surv_cmds.py +145 -0
  72. openstat/commands/survey_cmds.py +153 -0
  73. openstat/commands/textanalysis_cmds.py +192 -0
  74. openstat/commands/ts_adv_cmds.py +136 -0
  75. openstat/commands/ts_cmds.py +195 -0
  76. openstat/commands/tui_cmds.py +111 -0
  77. openstat/commands/ux_cmds.py +191 -0
  78. openstat/commands/validate_cmds.py +270 -0
  79. openstat/commands/viz_adv_cmds.py +312 -0
  80. openstat/commands/viz_extra_cmds.py +251 -0
  81. openstat/commands/watch_cmds.py +69 -0
  82. openstat/config.py +106 -0
  83. openstat/dsl/__init__.py +0 -0
  84. openstat/dsl/parser.py +332 -0
  85. openstat/dsl/tokenizer.py +105 -0
  86. openstat/i18n.py +120 -0
  87. openstat/io/__init__.py +0 -0
  88. openstat/io/loader.py +187 -0
  89. openstat/jupyter/__init__.py +18 -0
  90. openstat/jupyter/display.py +18 -0
  91. openstat/jupyter/magic.py +60 -0
  92. openstat/logging_config.py +59 -0
  93. openstat/plots/__init__.py +0 -0
  94. openstat/plots/plotter.py +437 -0
  95. openstat/plots/surv_plots.py +32 -0
  96. openstat/plots/ts_plots.py +59 -0
  97. openstat/plugins/__init__.py +5 -0
  98. openstat/plugins/manager.py +69 -0
  99. openstat/repl.py +457 -0
  100. openstat/reporting/__init__.py +0 -0
  101. openstat/reporting/eda.py +208 -0
  102. openstat/reporting/report.py +67 -0
  103. openstat/script_runner.py +319 -0
  104. openstat/session.py +133 -0
  105. openstat/stats/__init__.py +0 -0
  106. openstat/stats/advanced_regression.py +269 -0
  107. openstat/stats/arch_garch.py +84 -0
  108. openstat/stats/bayesian.py +103 -0
  109. openstat/stats/causal.py +258 -0
  110. openstat/stats/clustering.py +206 -0
  111. openstat/stats/discrete.py +311 -0
  112. openstat/stats/epidemiology.py +119 -0
  113. openstat/stats/equiv_tobit.py +163 -0
  114. openstat/stats/factor.py +174 -0
  115. openstat/stats/imputation.py +282 -0
  116. openstat/stats/influence.py +78 -0
  117. openstat/stats/iv.py +131 -0
  118. openstat/stats/manova.py +124 -0
  119. openstat/stats/mixed.py +128 -0
  120. openstat/stats/ml.py +275 -0
  121. openstat/stats/ml_advanced.py +117 -0
  122. openstat/stats/model_eval.py +183 -0
  123. openstat/stats/models.py +1342 -0
  124. openstat/stats/nonparametric.py +130 -0
  125. openstat/stats/panel.py +179 -0
  126. openstat/stats/power.py +295 -0
  127. openstat/stats/resampling.py +203 -0
  128. openstat/stats/survey.py +213 -0
  129. openstat/stats/survival.py +196 -0
  130. openstat/stats/timeseries.py +142 -0
  131. openstat/stats/ts_advanced.py +114 -0
  132. openstat/types.py +11 -0
  133. openstat/web/__init__.py +1 -0
  134. openstat/web/app.py +117 -0
  135. openstat/web/session_manager.py +73 -0
  136. openstat/web/static/app.js +117 -0
  137. openstat/web/static/index.html +38 -0
  138. openstat/web/static/style.css +103 -0
  139. openstat_cli-1.0.0.dist-info/METADATA +748 -0
  140. openstat_cli-1.0.0.dist-info/RECORD +143 -0
  141. openstat_cli-1.0.0.dist-info/WHEEL +4 -0
  142. openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
  143. openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,104 @@
1
+ """Regex column commands: regex extract/replace/match/split."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re as _re
6
+
7
+ from openstat.commands.base import command, CommandArgs, friendly_error
8
+ from openstat.session import Session
9
+
10
+
11
@command("regex", usage="regex extract|replace|match|split <col> <pattern> [options]")
def cmd_regex(session: Session, args: str) -> str:
    """Apply regex operations to a string column.

    Sub-commands:
        regex extract <col> <pattern> [into(<newcol>)]
            — extract first capture group; stores in new column
        regex replace <col> <pattern> <replacement> [into(<newcol>)]
            — replace all matches with replacement string
        regex match <col> <pattern> [into(<newcol>)]
            — add boolean column: 1 if row matches, 0 otherwise
        regex split <col> <pattern> [into(<newcol>)]
            — split on pattern, store list as string repr

    Examples:
        regex extract email "([^@]+)@" into(username)
        regex replace phone "[^0-9]" "" into(phone_clean)
        regex match address "\\bStreet\\b" into(is_street)
        regex split tags "," into(tag_list)
    """
    import polars as pl

    parsed = CommandArgs(args)
    if len(parsed.positional) < 3:
        return "Usage: regex extract|replace|match|split <col> <pattern> ..."

    action = parsed.positional[0].lower()
    column = parsed.positional[1]
    pattern = parsed.positional[2]

    try:
        df = session.require_data()
        if column not in df.columns:
            return f"Column not found: {column}"

        # Reject malformed patterns up front with a friendly message.
        try:
            rx = _re.compile(pattern)
        except _re.error as exc:
            return f"Invalid regex: {exc}"

        into_raw = parsed.rest_after("into")
        target = into_raw.strip().strip("()") if into_raw else None

        if action == "extract":
            target = target or f"{column}_extracted"
            strings = df[column].cast(pl.Utf8).to_list()

            def _first(text):
                # First capture group if the pattern has one, else whole match.
                if text is None:
                    return None
                hit = rx.search(text)
                if hit is None:
                    return None
                return hit.group(1) if hit.lastindex else hit.group(0)

            extracted = [_first(s) for s in strings]
            session.snapshot()
            session.df = df.with_columns(pl.Series(target, extracted))
            hits = sum(x is not None for x in extracted)
            return f"Extracted to '{target}': {hits}/{df.height} rows matched."

        if action == "replace":
            repl = parsed.positional[3] if len(parsed.positional) > 3 else ""
            target = target or column
            strings = df[column].cast(pl.Utf8).to_list()
            replaced = [rx.sub(repl, s) if s is not None else None for s in strings]
            session.snapshot()
            session.df = df.with_columns(pl.Series(target, replaced))
            changed = sum(before != after for before, after in zip(strings, replaced))
            return f"Replaced in '{target}': {changed}/{df.height} rows changed."

        if action == "match":
            target = target or f"{column}_match"
            strings = df[column].cast(pl.Utf8).to_list()
            flags = [1 if (s is not None and rx.search(s)) else 0 for s in strings]
            session.snapshot()
            session.df = df.with_columns(pl.Series(target, flags, dtype=pl.Int8))
            return f"Match column '{target}': {sum(flags)}/{df.height} rows matched."

        if action == "split":
            target = target or f"{column}_split"
            strings = df[column].cast(pl.Utf8).to_list()
            pieces = [str(rx.split(s)) if s is not None else None for s in strings]
            session.snapshot()
            session.df = df.with_columns(pl.Series(target, pieces))
            return f"Split column '{target}' created."

        return f"Unknown sub-command: {action}. Use extract, replace, match, or split."

    except Exception as e:
        return friendly_error(e, "regex")
@@ -0,0 +1,48 @@
1
+ """Report and help commands."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from openstat.session import Session
6
+ from openstat.reporting.report import generate_report
7
+ from openstat.commands.base import command, get_registry
8
+
9
+
10
@command("report", usage="report [eda [path.html] | path.md]")
def cmd_report(session: Session, args: str) -> str:
    """Generate a Markdown report or automated EDA HTML report."""
    text = args.strip()

    if text.startswith("eda"):
        # 'eda' sub-command: automated HTML profile of the current data.
        target = text[3:].strip() or "outputs/eda_report.html"
        try:
            from openstat.reporting.eda import generate_eda_report
            written = generate_eda_report(session, target)
        except Exception as e:
            return f"EDA report error: {e}"
        return f"EDA report saved: {written}"

    # Default: Markdown session report.
    target = text or "outputs/report.md"
    try:
        written = generate_report(session, target)
    except Exception as e:
        return f"Report error: {e}"
    return f"Report saved: {written}"
29
+
30
+
31
@command("help", usage="help [command]")
def cmd_help(session: Session, args: str) -> str:
    """Show available commands or help for a specific command."""
    registry = get_registry()
    wanted = args.strip()

    # Detailed help for one registered command.
    if wanted and wanted in registry:
        from openstat.commands.base import get_usage
        doc = registry[wanted].__doc__ or "No description."
        return f"{wanted}: {doc}\nUsage: {get_usage(wanted)}"

    # Otherwise: one-line summary per command, alphabetically.
    out = ["Available commands:", ""]
    for cmd_name in sorted(registry):
        summary = (registry[cmd_name].__doc__ or "").split("\n")[0]
        out.append(f" {cmd_name:<25} {summary}")
    out.append("")
    out.append("Type 'help <command>' for details. Type 'quit' to exit.")
    return "\n".join(out)
@@ -0,0 +1,129 @@
1
+ """Reproducibility commands: set seed, session save/replay/info."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import datetime
6
+ import re
7
+ from pathlib import Path
8
+
9
+ from openstat.commands.base import command, get_registry
10
+ from openstat.session import Session
11
+ from openstat import __version__
12
+
13
+
14
+ # Module-level seed tracking
15
+ _current_seed: int | None = None
16
+
17
+
18
+ def get_current_seed() -> int | None:
19
+ return _current_seed
20
+
21
+
22
+
23
@command("session", usage="session info | session save <path> | session replay <path>")
def cmd_session(session: Session, args: str) -> str:
    """Session management: view info, save commands to script, replay a script.

    Examples:
        session info — show session details
        session save analysis.ost — save all commands to a script file
        session replay script.ost — run an .ost script in current session
    """
    parts = args.strip().split(None, 1)
    # Bare 'session' behaves like 'session info'.
    action = parts[0].lower() if parts else "info"

    if action == "info":
        seed = getattr(session, "_repro_seed", _current_seed)
        report = [
            "Session Information",
            "=" * 50,
            f" OpenStat version : {__version__}",
            f" Dataset : {session.dataset_name or '(none)'}",
            f" Shape : {session.shape_str}",
            f" Random seed : {seed if seed is not None else '(not set)'}",
            f" Commands run : {len(session.history)}",
            f" Models fitted : {len(session.results)}",
            f" Plots generated : {len(session.plot_paths)}",
            f" Output dir : {session.output_dir}",
            f" Log file : {session._log_path or '(none)'}",
            f" Timestamp : {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        ]
        return "\n".join(report)

    if action == "save":
        target = parts[1].strip() if len(parts) > 1 else "session_script.ost"
        script_path = Path(target)
        script_path.parent.mkdir(parents=True, exist_ok=True)
        seed = getattr(session, "_repro_seed", _current_seed)

        header = [
            f"# OpenStat script — saved {datetime.datetime.now().isoformat()}\n",
            f"# OpenStat version: {__version__}\n",
        ]
        if seed is not None:
            header.append(f"# Random seed: {seed}\n")
            header.append(f"set seed {seed}\n")
        header.append(f"# Dataset: {session.dataset_name or '(none)'}\n")
        header.append("\n")

        # Write every recorded command except the 'session save' that
        # produced this file (replaying it would recurse).
        body = [
            f"{entry}\n"
            for entry in session.history
            if not entry.strip().startswith("session save")
        ]
        with open(script_path, "w", encoding="utf-8") as handle:
            handle.writelines(header + body)

        return f"Session saved to: {script_path.absolute()} ({len(session.history)} commands)"

    if action == "replay":
        target = parts[1].strip() if len(parts) > 1 else None
        if not target:
            return "Usage: session replay <path.ost>"
        if not Path(target).exists():
            return f"File not found: {target}"
        # Delegate to the shared script runner.
        from openstat.repl import run_script
        try:
            run_script(target, session)
            return f"Replayed: {target}"
        except SystemExit:
            return f"Replay stopped due to error in: {target}"
        except Exception as exc:
            return f"Replay error: {exc}"

    return (
        "Usage:\n"
        " session info — view session details\n"
        " session save <path.ost> — save commands to script\n"
        " session replay <path.ost> — run a script file"
    )
+ )
98
+
99
+
100
@command("version", usage="version")
def cmd_version(session: Session, args: str) -> str:
    """Show OpenStat version and environment information.

    Reports the OpenStat, Python, and platform versions, the versions of
    key optional dependencies, and the active random seed when one is set.
    """
    import sys
    import platform

    lines = [
        f"OpenStat {__version__}",
        f"Python {sys.version.split()[0]}",
        f"Platform {platform.system()} {platform.machine()}",
    ]
    deps = [
        ("polars", "polars"),
        ("numpy", "numpy"),
        ("statsmodels", "statsmodels"),
        ("scipy", "scipy"),
        ("matplotlib", "matplotlib"),
    ]
    for name, mod in deps:
        try:
            m = __import__(mod)
            # Some distributions don't expose __version__; report gracefully
            # instead of raising AttributeError.
            lines.append(f" {name:<15} {getattr(m, '__version__', '(unknown)')}")
        except ImportError:
            lines.append(f" {name:<15} (not installed)")

    # Prefer the session-scoped seed; fall back to the module-level tracker.
    seed = getattr(session, "_repro_seed", _current_seed)
    if seed is not None:
        lines.append(f"Random seed: {seed}")

    return "\n".join(lines)
@@ -0,0 +1,109 @@
1
+ """Bootstrap and permutation test commands."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+
7
+ from openstat.commands.base import command
8
+ from openstat.session import Session
9
+
10
+
11
+ def _stata_opts(raw: str) -> tuple[list[str], dict[str, str]]:
12
+ opts: dict[str, str] = {}
13
+ for m in re.finditer(r'(\w+)\(([^)]*)\)', raw):
14
+ opts[m.group(1).lower()] = m.group(2)
15
+ rest = re.sub(r'\w+\([^)]*\)', '', raw)
16
+ positional = [t.strip(',') for t in rest.split() if t.strip(',')]
17
+ return positional, opts
18
+
19
+
20
+ def _fmt(r: dict) -> str:
21
+ lines = [f"\n{r.get('test', 'Result')}", "=" * 55]
22
+ skip = {"test", "groups", "_model"}
23
+ for k, v in r.items():
24
+ if k in skip:
25
+ continue
26
+ if isinstance(v, float):
27
+ lines.append(f" {k:<35} {v:.6f}")
28
+ elif isinstance(v, list):
29
+ lines.append(f" {k:<35} {v}")
30
+ else:
31
+ lines.append(f" {k:<35} {v}")
32
+ lines.append("=" * 55)
33
+ return "\n".join(lines)
34
+
35
+
36
@command("bootstrap", usage="bootstrap var [by(groupvar)] [stat(mean)] [n(2000)] [ci(0.95)]")
def cmd_bootstrap(session: Session, args: str) -> str:
    """Bootstrap confidence interval. With by(): tests difference between groups."""
    from openstat.stats.resampling import bootstrap_ci, bootstrap_diff

    df = session.require_data()
    tokens, opts = _stata_opts(args)
    if not tokens:
        return "Usage: bootstrap var [by(group)] [stat(mean)] [n(2000)] [ci(0.95)]"

    var = tokens[0]
    if var not in df.columns:
        return f"Column '{var}' not found."

    group = opts.get("by")
    statistic = opts.get("stat", "mean")
    reps = int(opts.get("n", 2000))
    level = float(opts.get("ci", 0.95))

    try:
        if group:
            # Two-sample mode: bootstrap the group difference.
            if group not in df.columns:
                return f"Group column '{group}' not found."
            result = bootstrap_diff(df, var, group, stat=statistic, n_boot=reps, ci=level)
        else:
            result = bootstrap_ci(df, var, stat=statistic, n_boot=reps, ci=level)
        return _fmt(result)
    except Exception as exc:
        return f"bootstrap error: {exc}"
61
+
62
+
63
@command("permtest", usage="permtest var by(groupvar) [stat(mean)] [n(2000)] [--greater|--less]")
def cmd_permtest(session: Session, args: str) -> str:
    """Permutation test for difference between two groups."""
    from openstat.stats.resampling import permutation_test

    df = session.require_data()
    tokens, opts = _stata_opts(args)
    if not tokens:
        return "Usage: permtest var by(groupvar) [stat(mean)] [n(2000)]"

    var = tokens[0]
    if var not in df.columns:
        return f"Column '{var}' not found."

    group = opts.get("by")
    if not group:
        return "Specify group variable: permtest var by(groupvar)"
    if group not in df.columns:
        return f"Group column '{group}' not found."

    statistic = opts.get("stat", "mean")
    reps = int(opts.get("n", 2000))

    # Direction flags are read directly from the raw argument string.
    if "--greater" in args:
        tail = "greater"
    elif "--less" in args:
        tail = "less"
    else:
        tail = "two-sided"

    try:
        outcome = permutation_test(df, var, group, stat=statistic, n_perm=reps, alternative=tail)
        return _fmt(outcome)
    except Exception as exc:
        return f"permtest error: {exc}"
91
+
92
+
93
@command("jackknife", usage="jackknife var [stat(mean)]")
def cmd_jackknife(session: Session, args: str) -> str:
    """Jackknife bias and standard error estimation."""
    from openstat.stats.resampling import jackknife_ci

    df = session.require_data()
    tokens, opts = _stata_opts(args)
    if not tokens:
        return "Usage: jackknife var [stat(mean)]"

    var = tokens[0]
    if var not in df.columns:
        return f"Column '{var}' not found."

    try:
        return _fmt(jackknife_ci(df, var, stat=opts.get("stat", "mean")))
    except Exception as exc:
        return f"jackknife error: {exc}"
@@ -0,0 +1,223 @@
1
+ """reshape, collapse, encode, decode commands."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+
7
+ import polars as pl
8
+
9
+ from openstat.commands.base import command
10
+ from openstat.session import Session
11
+
12
+
13
+ def _stata_opts(raw: str) -> tuple[list[str], dict[str, str]]:
14
+ opts: dict[str, str] = {}
15
+ for m in re.finditer(r'(\w+)\(([^)]*)\)', raw):
16
+ opts[m.group(1).lower()] = m.group(2)
17
+ rest = re.sub(r'\w+\([^)]*\)', '', raw)
18
+ positional = [t.strip(',') for t in rest.split() if t.strip(',')]
19
+ return positional, opts
20
+
21
+
22
+ # ── reshape ────────────────────────────────────────────────────────────────
23
+
24
@command("reshape", usage="reshape wide|long varlist, i(id) j(timevar) [stub(prefix)]")
def cmd_reshape(session: Session, args: str) -> str:
    """Reshape data between wide and long format (Stata-style).

    Wide→Long: reshape long prefix, i(id) j(time)
    Long→Wide: reshape wide varlist, i(id) j(timevar)

    Replaces session.df on success; a snapshot is taken immediately before
    the replacement so 'undo' restores the previous shape.
    """
    df = session.require_data()
    positional, opts = _stata_opts(args)

    if len(positional) < 2:
        return (
            "Usage:\n"
            " reshape long stubname, i(idvar) j(timevar)\n"
            " reshape wide varlist, i(idvar) j(timevar)"
        )

    direction = positional[0].lower()
    id_var = opts.get("i")
    j_var = opts.get("j")
    if not id_var or not j_var:
        return "Specify: i(idvar) j(timevar)"

    if direction == "long":
        # wide → long: gather every column sharing the stub prefix.
        stub = positional[1] if len(positional) > 1 else ""
        value_vars = [c for c in df.columns if c.startswith(stub) and c != id_var]
        if not value_vars:
            return f"No columns found with prefix '{stub}'"
        try:
            long_df = df.unpivot(
                on=value_vars,
                index=[id_var],
                variable_name=j_var,
                value_name=stub or "value",
            )
            # Strip the stub prefix so j holds just the suffix
            # (e.g. inc1990 -> 1990).
            long_df = long_df.with_columns(
                pl.col(j_var).str.replace(stub, "", literal=True).alias(j_var)
            )
            # Snapshot only once the reshape succeeded, so failed commands
            # do not push a no-op entry onto the undo stack.
            session.snapshot()
            session.df = long_df
            return f"Reshaped wide→long: {df.shape} → {long_df.shape}. {j_var} = {long_df[j_var].unique().to_list()}"
        except Exception as exc:
            return f"reshape long error: {exc}"

    elif direction == "wide":
        # long → wide: pivot the value variable(s) across j_var categories.
        var_list = [c for c in positional[1:] if c in df.columns]
        if not var_list:
            return "No valid value variables found."
        try:
            wide_df = df.pivot(
                on=j_var,
                index=id_var,
                values=var_list[0] if len(var_list) == 1 else var_list,
                aggregate_function="first",
            )
            # Snapshot only after the pivot succeeded (see note above).
            session.snapshot()
            session.df = wide_df
            return f"Reshaped long→wide: {df.shape} → {wide_df.shape}"
        except Exception as exc:
            return f"reshape wide error: {exc}"

    else:
        return f"Unknown reshape direction: {direction}. Use 'wide' or 'long'."
90
+
91
+
92
+ # ── collapse ───────────────────────────────────────────────────────────────
93
+
94
def _collapse_expr(method: str):
    """Return a builder mapping a column name to pl.col(name).<method>()."""
    return lambda c: getattr(pl.col(c), method)()


# Aggregation builders keyed by the Stata-style statistic name; each call
# is equivalent to e.g. pl.col(c).mean().
_COLLAPSE_FUNS = {
    name: _collapse_expr(name)
    for name in (
        "mean", "sum", "count", "min", "max",
        "median", "std", "var", "first", "last",
    )
}
+ }
106
+
107
+
108
@command("collapse", usage="collapse (stat) varlist [, by(groupvars)]")
def cmd_collapse(session: Session, args: str) -> str:
    """Collapse dataset to group-level aggregates (replaces df).

    Examples:
        collapse (mean) income age, by(region)
        collapse (sum) sales, by(year region)
    """
    df = session.require_data()

    # Parse the (stat) token; default to mean when omitted.
    stat_m = re.search(r'\((\w+)\)', args)
    stat = stat_m.group(1).lower() if stat_m else "mean"
    if stat not in _COLLAPSE_FUNS:
        return f"Unknown statistic: {stat}. Use: {', '.join(_COLLAPSE_FUNS)}"

    # _stata_opts extracts key(value) pairs (including by(...)) from full args
    positional_raw, opts = _stata_opts(args)
    # positional_raw may include the "(mean)" token — filtering against
    # df.columns drops it along with any unknown names.
    value_vars = [c for c in positional_raw if c in df.columns]
    by_raw = opts.get("by", "")
    by_vars = [c.strip() for c in by_raw.split() if c.strip() in df.columns]

    if not value_vars:
        return "No valid value variables found."

    agg_fn = _COLLAPSE_FUNS[stat]
    agg_exprs = [agg_fn(c).alias(c) for c in value_vars]

    try:
        if by_vars:
            result = df.group_by(by_vars).agg(agg_exprs).sort(by_vars)
        else:
            # No grouping: collapse the whole dataset to a single row.
            result = df.select(agg_exprs)

        # Snapshot only now that the aggregation succeeded, so a failed
        # collapse does not push a no-op entry onto the undo stack.
        session.snapshot()
        session.df = result
        return (
            f"Collapsed to {result.shape[0]} rows × {result.shape[1]} cols "
            f"using {stat}({', '.join(value_vars)})"
            + (f" by {', '.join(by_vars)}" if by_vars else "")
        )
    except Exception as exc:
        return f"collapse error: {exc}"
152
+
153
+
154
+ # ── encode ─────────────────────────────────────────────────────────────────
155
+
156
@command("encode", usage="encode varname [, gen(newvar)]")
def cmd_encode(session: Session, args: str) -> str:
    """Encode a string/categorical column to integer codes (0-based)."""
    df = session.require_data()
    tokens, opts = _stata_opts(args)
    if not tokens:
        return "Usage: encode varname [, gen(newvar)]"

    source = tokens[0]
    if source not in df.columns:
        return f"Column '{source}' not found."

    target = opts.get("gen", source + "_encoded")
    session.snapshot()

    try:
        # Stable code assignment: unique non-null values, sorted as strings.
        categories = sorted(df[source].drop_nulls().unique().to_list(), key=str)
        lookup = {value: code for code, value in enumerate(categories)}
        codes = df[source].map_elements(
            lambda x: lookup.get(x), return_dtype=pl.Int64
        )
        session.df = df.with_columns(codes.alias(target))
        # Show at most the first 20 code/value pairs in the summary.
        preview = "\n".join(f" {code} = {value}" for code, value in enumerate(categories[:20]))
        if len(categories) > 20:
            preview += f"\n ... ({len(categories)} total)"
        return f"Encoded '{source}' → '{target}' ({len(categories)} categories)\n{preview}"
    except Exception as exc:
        return f"encode error: {exc}"
185
+
186
+
187
+ # ── decode ─────────────────────────────────────────────────────────────────
188
+
189
@command("decode", usage="decode encodedvar origvar [, gen(newvar)]")
def cmd_decode(session: Session, args: str) -> str:
    """Decode integer codes back to string labels.

    Requires a reference (original) string column with the same row order.
    decode encoded_col orig_col [, gen(newvar)]
    """
    df = session.require_data()
    tokens, opts = _stata_opts(args)
    if len(tokens) < 2:
        return "Usage: decode encoded_col orig_col [, gen(newvar)]"

    enc_var, orig_var = tokens[0], tokens[1]
    new_var = opts.get("gen", enc_var + "_decoded")

    missing = next((v for v in (enc_var, orig_var) if v not in df.columns), None)
    if missing is not None:
        return f"Column '{missing}' not found."

    session.snapshot()
    try:
        # Pair each observed code with its label, skipping nulls on either side.
        codes = df[enc_var].cast(pl.Int64).to_list()
        labels = df[orig_var].cast(pl.Utf8).to_list()
        code_map = {
            code: label
            for code, label in zip(codes, labels)
            if code is not None and label is not None
        }

        decoded = df[enc_var].map_elements(
            lambda x: code_map.get(x, None), return_dtype=pl.Utf8
        )
        session.df = df.with_columns(decoded.alias(new_var))
        return f"Decoded '{enc_var}' → '{new_var}' ({len(code_map)} unique codes)"
    except Exception as exc:
        return f"decode error: {exc}"