openstat-cli 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. openstat/__init__.py +3 -0
  2. openstat/__main__.py +4 -0
  3. openstat/backends/__init__.py +16 -0
  4. openstat/backends/duckdb_backend.py +70 -0
  5. openstat/backends/polars_backend.py +52 -0
  6. openstat/cli.py +92 -0
  7. openstat/commands/__init__.py +82 -0
  8. openstat/commands/adv_stat_cmds.py +1255 -0
  9. openstat/commands/advanced_ml_cmds.py +576 -0
  10. openstat/commands/advreg_cmds.py +207 -0
  11. openstat/commands/alias_cmds.py +135 -0
  12. openstat/commands/arch_cmds.py +82 -0
  13. openstat/commands/arules_cmds.py +111 -0
  14. openstat/commands/automodel_cmds.py +212 -0
  15. openstat/commands/backend_cmds.py +82 -0
  16. openstat/commands/base.py +170 -0
  17. openstat/commands/bayes_cmds.py +71 -0
  18. openstat/commands/causal_cmds.py +269 -0
  19. openstat/commands/cluster_cmds.py +152 -0
  20. openstat/commands/data_cmds.py +996 -0
  21. openstat/commands/datamanip_cmds.py +672 -0
  22. openstat/commands/dataquality_cmds.py +174 -0
  23. openstat/commands/datetime_cmds.py +176 -0
  24. openstat/commands/dimreduce_cmds.py +184 -0
  25. openstat/commands/discrete_cmds.py +149 -0
  26. openstat/commands/dsl_cmds.py +143 -0
  27. openstat/commands/epi_cmds.py +93 -0
  28. openstat/commands/equiv_tobit_cmds.py +94 -0
  29. openstat/commands/esttab_cmds.py +196 -0
  30. openstat/commands/export_beamer_cmds.py +142 -0
  31. openstat/commands/export_cmds.py +201 -0
  32. openstat/commands/export_extra_cmds.py +240 -0
  33. openstat/commands/factor_cmds.py +180 -0
  34. openstat/commands/groupby_cmds.py +155 -0
  35. openstat/commands/help_cmds.py +237 -0
  36. openstat/commands/i18n_cmds.py +43 -0
  37. openstat/commands/import_extra_cmds.py +561 -0
  38. openstat/commands/influence_cmds.py +134 -0
  39. openstat/commands/iv_cmds.py +106 -0
  40. openstat/commands/manova_cmds.py +105 -0
  41. openstat/commands/mediate_cmds.py +233 -0
  42. openstat/commands/meta_cmds.py +284 -0
  43. openstat/commands/mi_cmds.py +228 -0
  44. openstat/commands/mixed_cmds.py +79 -0
  45. openstat/commands/mixture_changepoint_cmds.py +166 -0
  46. openstat/commands/ml_adv_cmds.py +147 -0
  47. openstat/commands/ml_cmds.py +178 -0
  48. openstat/commands/model_eval_cmds.py +142 -0
  49. openstat/commands/network_cmds.py +288 -0
  50. openstat/commands/nlquery_cmds.py +161 -0
  51. openstat/commands/nonparam_cmds.py +149 -0
  52. openstat/commands/outreg_cmds.py +247 -0
  53. openstat/commands/panel_cmds.py +141 -0
  54. openstat/commands/pdf_cmds.py +226 -0
  55. openstat/commands/pipeline_cmds.py +319 -0
  56. openstat/commands/plot_cmds.py +189 -0
  57. openstat/commands/plugin_cmds.py +79 -0
  58. openstat/commands/posthoc_cmds.py +153 -0
  59. openstat/commands/power_cmds.py +172 -0
  60. openstat/commands/profile_cmds.py +246 -0
  61. openstat/commands/rbridge_cmds.py +81 -0
  62. openstat/commands/regex_cmds.py +104 -0
  63. openstat/commands/report_cmds.py +48 -0
  64. openstat/commands/repro_cmds.py +129 -0
  65. openstat/commands/resampling_cmds.py +109 -0
  66. openstat/commands/reshape_cmds.py +223 -0
  67. openstat/commands/sem_cmds.py +177 -0
  68. openstat/commands/stat_cmds.py +1040 -0
  69. openstat/commands/stata_import_cmds.py +215 -0
  70. openstat/commands/string_cmds.py +124 -0
  71. openstat/commands/surv_cmds.py +145 -0
  72. openstat/commands/survey_cmds.py +153 -0
  73. openstat/commands/textanalysis_cmds.py +192 -0
  74. openstat/commands/ts_adv_cmds.py +136 -0
  75. openstat/commands/ts_cmds.py +195 -0
  76. openstat/commands/tui_cmds.py +111 -0
  77. openstat/commands/ux_cmds.py +191 -0
  78. openstat/commands/validate_cmds.py +270 -0
  79. openstat/commands/viz_adv_cmds.py +312 -0
  80. openstat/commands/viz_extra_cmds.py +251 -0
  81. openstat/commands/watch_cmds.py +69 -0
  82. openstat/config.py +106 -0
  83. openstat/dsl/__init__.py +0 -0
  84. openstat/dsl/parser.py +332 -0
  85. openstat/dsl/tokenizer.py +105 -0
  86. openstat/i18n.py +120 -0
  87. openstat/io/__init__.py +0 -0
  88. openstat/io/loader.py +187 -0
  89. openstat/jupyter/__init__.py +18 -0
  90. openstat/jupyter/display.py +18 -0
  91. openstat/jupyter/magic.py +60 -0
  92. openstat/logging_config.py +59 -0
  93. openstat/plots/__init__.py +0 -0
  94. openstat/plots/plotter.py +437 -0
  95. openstat/plots/surv_plots.py +32 -0
  96. openstat/plots/ts_plots.py +59 -0
  97. openstat/plugins/__init__.py +5 -0
  98. openstat/plugins/manager.py +69 -0
  99. openstat/repl.py +457 -0
  100. openstat/reporting/__init__.py +0 -0
  101. openstat/reporting/eda.py +208 -0
  102. openstat/reporting/report.py +67 -0
  103. openstat/script_runner.py +319 -0
  104. openstat/session.py +133 -0
  105. openstat/stats/__init__.py +0 -0
  106. openstat/stats/advanced_regression.py +269 -0
  107. openstat/stats/arch_garch.py +84 -0
  108. openstat/stats/bayesian.py +103 -0
  109. openstat/stats/causal.py +258 -0
  110. openstat/stats/clustering.py +206 -0
  111. openstat/stats/discrete.py +311 -0
  112. openstat/stats/epidemiology.py +119 -0
  113. openstat/stats/equiv_tobit.py +163 -0
  114. openstat/stats/factor.py +174 -0
  115. openstat/stats/imputation.py +282 -0
  116. openstat/stats/influence.py +78 -0
  117. openstat/stats/iv.py +131 -0
  118. openstat/stats/manova.py +124 -0
  119. openstat/stats/mixed.py +128 -0
  120. openstat/stats/ml.py +275 -0
  121. openstat/stats/ml_advanced.py +117 -0
  122. openstat/stats/model_eval.py +183 -0
  123. openstat/stats/models.py +1342 -0
  124. openstat/stats/nonparametric.py +130 -0
  125. openstat/stats/panel.py +179 -0
  126. openstat/stats/power.py +295 -0
  127. openstat/stats/resampling.py +203 -0
  128. openstat/stats/survey.py +213 -0
  129. openstat/stats/survival.py +196 -0
  130. openstat/stats/timeseries.py +142 -0
  131. openstat/stats/ts_advanced.py +114 -0
  132. openstat/types.py +11 -0
  133. openstat/web/__init__.py +1 -0
  134. openstat/web/app.py +117 -0
  135. openstat/web/session_manager.py +73 -0
  136. openstat/web/static/app.js +117 -0
  137. openstat/web/static/index.html +38 -0
  138. openstat/web/static/style.css +103 -0
  139. openstat_cli-1.0.0.dist-info/METADATA +748 -0
  140. openstat_cli-1.0.0.dist-info/RECORD +143 -0
  141. openstat_cli-1.0.0.dist-info/WHEEL +4 -0
  142. openstat_cli-1.0.0.dist-info/entry_points.txt +2 -0
  143. openstat_cli-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,672 @@
1
+ """Extended data manipulation commands:
2
+ cast (multi-col), lag, lead, cumulative, bin, antijoin, semijoin,
3
+ sample stratified, anonymize.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import re
9
+
10
+ import polars as pl
11
+
12
+ from openstat.commands.base import command, CommandArgs, friendly_error
13
+ from openstat.session import Session
14
+
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Helpers
18
+ # ---------------------------------------------------------------------------
19
+
20
+ _TYPE_MAP: dict[str, pl.DataType] = {
21
+ "int": pl.Int64,
22
+ "float": pl.Float64,
23
+ "str": pl.Utf8,
24
+ "bool": pl.Boolean,
25
+ "date": pl.Date,
26
+ "datetime": pl.Datetime,
27
+ }
28
+
29
+ _VALID_TYPES = ", ".join(_TYPE_MAP)
30
+
31
+
32
+ def _parse_into(args: str) -> tuple[str, str | None]:
33
+ """Strip trailing into(<newcol>) from an arg string.
34
+
35
+ Returns (remaining_args, new_col_or_None).
36
+ """
37
+ m = re.search(r'\binto\(\s*(\w+)\s*\)', args, re.IGNORECASE)
38
+ if m:
39
+ new_col = m.group(1)
40
+ remaining = (args[: m.start()] + args[m.end():]).strip()
41
+ return remaining, new_col
42
+ return args, None
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # cast — multi-column version
47
+ # ---------------------------------------------------------------------------
48
+
49
+ @command("cast", usage="cast <col> <type> [col2 type2 ...]")
50
+ def cmd_cast(session: Session, args: str) -> str:
51
+ """Cast one or more columns to a new type.
52
+
53
+ Types: int, float, str, bool, date, datetime
54
+ Example: cast age int income float gender str
55
+ """
56
+ df = session.require_data()
57
+ tokens = args.split()
58
+ if len(tokens) < 2:
59
+ return (
60
+ "Usage: cast <col> <type> [col2 type2 ...]\n"
61
+ f"Types: {_VALID_TYPES}"
62
+ )
63
+
64
+ # Pair up tokens: (col, type), (col, type), ...
65
+ if len(tokens) % 2 != 0:
66
+ return (
67
+ "Provide pairs of <col> <type>.\n"
68
+ f"Types: {_VALID_TYPES}"
69
+ )
70
+
71
+ pairs: list[tuple[str, str]] = [
72
+ (tokens[i], tokens[i + 1].lower()) for i in range(0, len(tokens), 2)
73
+ ]
74
+
75
+ # Validate all before touching the data
76
+ for col, tname in pairs:
77
+ if col not in df.columns:
78
+ return f"Column not found: '{col}'. Use 'describe' to list columns."
79
+ if tname not in _TYPE_MAP:
80
+ return f"Unknown type '{tname}'. Valid types: {_VALID_TYPES}"
81
+
82
+ session.snapshot()
83
+ lines: list[str] = []
84
+ try:
85
+ for col, tname in pairs:
86
+ old_dtype = str(df[col].dtype)
87
+ pl_type = _TYPE_MAP[tname]
88
+ session.df = session.df.with_columns(pl.col(col).cast(pl_type))
89
+ lines.append(f" '{col}': {old_dtype} -> {tname}")
90
+ lines.insert(0, f"Cast {len(pairs)} column(s):")
91
+ lines.append("Use 'undo' to revert.")
92
+ return "\n".join(lines)
93
+ except Exception as exc:
94
+ session.undo()
95
+ return friendly_error(exc, "cast")
96
+
97
+
98
+ # ---------------------------------------------------------------------------
99
+ # lag
100
+ # ---------------------------------------------------------------------------
101
+
102
+ @command("lag", usage="lag <col> [n=1] [into(<newcol>)]")
103
+ def cmd_lag(session: Session, args: str) -> str:
104
+ """Create a lagged column (shift values forward by n rows, default n=1).
105
+
106
+ Example: lag price 2 into(price_lag2)
107
+ """
108
+ df = session.require_data()
109
+
110
+ # Strip into() first
111
+ rest, new_col = _parse_into(args)
112
+ tokens = rest.split()
113
+
114
+ if not tokens:
115
+ return "Usage: lag <col> [n=1] [into(<newcol>)]"
116
+
117
+ col = tokens[0]
118
+ if col not in df.columns:
119
+ return friendly_error(KeyError(col), "lag")
120
+
121
+ n = 1
122
+ if len(tokens) >= 2:
123
+ try:
124
+ n = int(tokens[1])
125
+ except ValueError:
126
+ return f"n must be an integer, got '{tokens[1]}'"
127
+
128
+ if new_col is None:
129
+ new_col = f"{col}_lag{n}"
130
+
131
+ session.snapshot()
132
+ try:
133
+ session.df = df.with_columns(
134
+ pl.col(col).shift(n).alias(new_col)
135
+ )
136
+ return (
137
+ f"Created '{new_col}' as lag({col}, {n}). "
138
+ "Use 'undo' to revert."
139
+ )
140
+ except Exception as exc:
141
+ session.undo()
142
+ return friendly_error(exc, "lag")
143
+
144
+
145
+ # ---------------------------------------------------------------------------
146
+ # lead
147
+ # ---------------------------------------------------------------------------
148
+
149
+ @command("lead", usage="lead <col> [n=1] [into(<newcol>)]")
150
+ def cmd_lead(session: Session, args: str) -> str:
151
+ """Create a lead column (shift values backward by n rows, default n=1).
152
+
153
+ Example: lead gdp 3 into(gdp_lead3)
154
+ """
155
+ df = session.require_data()
156
+
157
+ rest, new_col = _parse_into(args)
158
+ tokens = rest.split()
159
+
160
+ if not tokens:
161
+ return "Usage: lead <col> [n=1] [into(<newcol>)]"
162
+
163
+ col = tokens[0]
164
+ if col not in df.columns:
165
+ return friendly_error(KeyError(col), "lead")
166
+
167
+ n = 1
168
+ if len(tokens) >= 2:
169
+ try:
170
+ n = int(tokens[1])
171
+ except ValueError:
172
+ return f"n must be an integer, got '{tokens[1]}'"
173
+
174
+ if new_col is None:
175
+ new_col = f"{col}_lead{n}"
176
+
177
+ session.snapshot()
178
+ try:
179
+ session.df = df.with_columns(
180
+ pl.col(col).shift(-n).alias(new_col)
181
+ )
182
+ return (
183
+ f"Created '{new_col}' as lead({col}, {n}). "
184
+ "Use 'undo' to revert."
185
+ )
186
+ except Exception as exc:
187
+ session.undo()
188
+ return friendly_error(exc, "lead")
189
+
190
+
191
+ # ---------------------------------------------------------------------------
192
+ # cumulative
193
+ # ---------------------------------------------------------------------------
194
+
195
+ _CUM_FUNCS: dict[str, str] = {
196
+ "sum": "cum_sum",
197
+ "prod": "cum_prod",
198
+ "max": "cum_max",
199
+ "min": "cum_min",
200
+ "count": "cum_count",
201
+ }
202
+
203
+
204
+ @command("cumulative", usage="cumulative <col> <func> [into(<newcol>)]")
205
+ def cmd_cumulative(session: Session, args: str) -> str:
206
+ """Compute a cumulative statistic along a column.
207
+
208
+ func: sum, prod, max, min, count
209
+ Example: cumulative sales sum into(sales_cumsum)
210
+ """
211
+ df = session.require_data()
212
+
213
+ rest, new_col = _parse_into(args)
214
+ tokens = rest.split()
215
+
216
+ if len(tokens) < 2:
217
+ valid = ", ".join(_CUM_FUNCS)
218
+ return f"Usage: cumulative <col> <func> [into(<newcol>)] (func: {valid})"
219
+
220
+ col, func_name = tokens[0], tokens[1].lower()
221
+ if col not in df.columns:
222
+ return friendly_error(KeyError(col), "cumulative")
223
+
224
+ if func_name not in _CUM_FUNCS:
225
+ valid = ", ".join(_CUM_FUNCS)
226
+ return f"Unknown function '{func_name}'. Valid: {valid}"
227
+
228
+ if new_col is None:
229
+ new_col = f"{col}_cum{func_name}"
230
+
231
+ pl_method = _CUM_FUNCS[func_name]
232
+
233
+ session.snapshot()
234
+ try:
235
+ expr = getattr(pl.col(col), pl_method)()
236
+ session.df = df.with_columns(expr.alias(new_col))
237
+ return (
238
+ f"Created '{new_col}' as cumulative {func_name}({col}). "
239
+ "Use 'undo' to revert."
240
+ )
241
+ except Exception as exc:
242
+ session.undo()
243
+ return friendly_error(exc, "cumulative")
244
+
245
+
246
+ # ---------------------------------------------------------------------------
247
+ # bin
248
+ # ---------------------------------------------------------------------------
249
+
250
+ @command("bin", usage="bin <col> <n_bins> [into(<newcol>)] [--labels=a,b,c] [--equal-freq]")
251
+ def cmd_bin(session: Session, args: str) -> str:
252
+ """Discretize a continuous column into n_bins bins.
253
+
254
+ --equal-freq use quantile-based (equal-frequency) bins; default is equal-width.
255
+ --labels=a,b comma-separated bin labels (must match n_bins).
256
+ Example: bin income 5 into(income_cat)
257
+ Example: bin age 3 --labels=young,middle,senior --equal-freq
258
+ """
259
+ df = session.require_data()
260
+ ca = CommandArgs(args)
261
+
262
+ # Strip into() from the raw string before positional parsing
263
+ raw_no_into, new_col = _parse_into(ca.strip_flags_and_options())
264
+ pos_tokens = raw_no_into.split()
265
+
266
+ if len(pos_tokens) < 2:
267
+ return "Usage: bin <col> <n_bins> [into(<newcol>)] [--labels=a,b,c] [--equal-freq]"
268
+
269
+ col = pos_tokens[0]
270
+ try:
271
+ n_bins = int(pos_tokens[1])
272
+ except ValueError:
273
+ return f"n_bins must be an integer, got '{pos_tokens[1]}'"
274
+
275
+ if n_bins < 2:
276
+ return "n_bins must be >= 2"
277
+
278
+ if col not in df.columns:
279
+ return friendly_error(KeyError(col), "bin")
280
+
281
+ equal_freq = ca.has_flag("--equal-freq")
282
+ labels_raw = ca.options.get("labels")
283
+ labels: list[str] | None = None
284
+ if labels_raw:
285
+ labels = [lb.strip() for lb in labels_raw.split(",")]
286
+ if len(labels) != n_bins:
287
+ return (
288
+ f"Number of labels ({len(labels)}) must match n_bins ({n_bins})."
289
+ )
290
+
291
+ if new_col is None:
292
+ new_col = f"{col}_bin"
293
+
294
+ series = df[col].cast(pl.Float64)
295
+ non_null = series.drop_nulls()
296
+
297
+ if non_null.len() == 0:
298
+ return f"Column '{col}' has no non-null values."
299
+
300
+ try:
301
+ if equal_freq:
302
+ # Quantile breakpoints
303
+ quantiles = [i / n_bins for i in range(n_bins + 1)]
304
+ breaks = [float(non_null.quantile(q, interpolation="linear")) for q in quantiles]
305
+ breaks[0] -= 1e-10 # include minimum
306
+ else:
307
+ lo = float(non_null.min())
308
+ hi = float(non_null.max())
309
+ step = (hi - lo) / n_bins
310
+ breaks = [lo + i * step for i in range(n_bins + 1)]
311
+ breaks[0] -= 1e-10
312
+
313
+ def _bin_value(v: float | None) -> str | None:
314
+ if v is None:
315
+ return None
316
+ for i in range(n_bins):
317
+ if breaks[i] < v <= breaks[i + 1]:
318
+ if labels:
319
+ return labels[i]
320
+ lo_s = f"{breaks[i]:.4g}"
321
+ hi_s = f"{breaks[i + 1]:.4g}"
322
+ return f"({lo_s}, {hi_s}]"
323
+ # Edge: value equals the minimum
324
+ if labels:
325
+ return labels[0]
326
+ lo_s = f"{breaks[0]:.4g}"
327
+ hi_s = f"{breaks[1]:.4g}"
328
+ return f"({lo_s}, {hi_s}]"
329
+
330
+ bin_col = pl.Series(
331
+ name=new_col,
332
+ values=[_bin_value(v) for v in series.to_list()],
333
+ dtype=pl.Utf8,
334
+ )
335
+ except Exception as exc:
336
+ return friendly_error(exc, "bin")
337
+
338
+ session.snapshot()
339
+ session.df = df.with_columns(bin_col)
340
+ method = "equal-frequency" if equal_freq else "equal-width"
341
+ return (
342
+ f"Binned '{col}' into {n_bins} {method} bins -> '{new_col}'. "
343
+ "Use 'undo' to revert."
344
+ )
345
+
346
+
347
+ # ---------------------------------------------------------------------------
348
+ # antijoin
349
+ # ---------------------------------------------------------------------------
350
+
351
+ @command("antijoin", usage="antijoin <file> on(<col>) [how=left|right]")
352
+ def cmd_antijoin(session: Session, args: str) -> str:
353
+ """Keep rows from current dataset that are NOT found in another file.
354
+
355
+ Example: antijoin excluded_ids.csv on(id)
356
+ """
357
+ from openstat.io.loader import load_file
358
+
359
+ df = session.require_data()
360
+ ca = CommandArgs(args)
361
+
362
+ # Parse on(<col>)
363
+ on_m = re.search(r'\bon\(\s*(\w+)\s*\)', args, re.IGNORECASE)
364
+ if not on_m:
365
+ return "Usage: antijoin <file> on(<col>)"
366
+ key_col = on_m.group(1)
367
+
368
+ # File is everything before the on(...) token
369
+ file_part = args[: on_m.start()].strip()
370
+ # Remove any --options or flags from the file part
371
+ file_path = re.sub(r'--\S+', '', file_part).strip()
372
+
373
+ if not file_path:
374
+ return "Usage: antijoin <file> on(<col>)"
375
+
376
+ if key_col not in df.columns:
377
+ return f"Key column '{key_col}' not found in current dataset."
378
+
379
+ try:
380
+ other = load_file(file_path)
381
+ except Exception as exc:
382
+ return f"Cannot load file: {exc}"
383
+
384
+ if key_col not in other.columns:
385
+ return f"Key column '{key_col}' not found in '{file_path}'."
386
+
387
+ session.snapshot()
388
+ try:
389
+ # Anti-join: left join then filter for nulls from right side
390
+ right_keys = other.select(pl.col(key_col).alias("__right_key__")).unique()
391
+ merged = df.join(
392
+ right_keys,
393
+ left_on=key_col,
394
+ right_on="__right_key__",
395
+ how="left",
396
+ )
397
+ mask = merged["__right_key__"].is_null()
398
+ session.df = df.filter(mask)
399
+ kept = session.df.height
400
+ total = df.height
401
+ return (
402
+ f"Anti-join on '{key_col}': kept {kept:,} of {total:,} rows "
403
+ f"(removed {total - kept:,} matching rows). Use 'undo' to revert."
404
+ )
405
+ except Exception as exc:
406
+ session.undo()
407
+ return friendly_error(exc, "antijoin")
408
+
409
+
410
+ # ---------------------------------------------------------------------------
411
+ # semijoin
412
+ # ---------------------------------------------------------------------------
413
+
414
+ @command("semijoin", usage="semijoin <file> on(<col>)")
415
+ def cmd_semijoin(session: Session, args: str) -> str:
416
+ """Keep rows from current dataset that ARE found in another file.
417
+
418
+ Example: semijoin valid_ids.csv on(id)
419
+ """
420
+ from openstat.io.loader import load_file
421
+
422
+ df = session.require_data()
423
+
424
+ on_m = re.search(r'\bon\(\s*(\w+)\s*\)', args, re.IGNORECASE)
425
+ if not on_m:
426
+ return "Usage: semijoin <file> on(<col>)"
427
+ key_col = on_m.group(1)
428
+
429
+ file_part = re.sub(r'--\S+', '', args[: on_m.start()]).strip()
430
+ if not file_part:
431
+ return "Usage: semijoin <file> on(<col>)"
432
+
433
+ if key_col not in df.columns:
434
+ return f"Key column '{key_col}' not found in current dataset."
435
+
436
+ try:
437
+ other = load_file(file_part)
438
+ except Exception as exc:
439
+ return f"Cannot load file: {exc}"
440
+
441
+ if key_col not in other.columns:
442
+ return f"Key column '{key_col}' not found in '{file_part}'."
443
+
444
+ session.snapshot()
445
+ try:
446
+ right_keys = other.select(pl.col(key_col).alias("__right_key__")).unique()
447
+ merged = df.join(
448
+ right_keys,
449
+ left_on=key_col,
450
+ right_on="__right_key__",
451
+ how="left",
452
+ )
453
+ mask = merged["__right_key__"].is_not_null()
454
+ session.df = df.filter(mask)
455
+ kept = session.df.height
456
+ total = df.height
457
+ return (
458
+ f"Semi-join on '{key_col}': kept {kept:,} of {total:,} rows "
459
+ f"(removed {total - kept:,} non-matching rows). Use 'undo' to revert."
460
+ )
461
+ except Exception as exc:
462
+ session.undo()
463
+ return friendly_error(exc, "semijoin")
464
+
465
+
466
+ # ---------------------------------------------------------------------------
467
+ # sample stratified
468
+ # ---------------------------------------------------------------------------
469
+
470
+ @command("sample stratified", usage="sample stratified <n_or_frac> by(<stratum_col>) [--seed=N]")
471
+ def cmd_sample_stratified(session: Session, args: str) -> str:
472
+ """Stratified random sample: draw proportionally from each stratum.
473
+
474
+ n_or_frac: integer (absolute per stratum) or float < 1 (fraction of each stratum).
475
+ Example: sample stratified 100 by(region)
476
+ Example: sample stratified 0.2 by(gender) --seed=42
477
+ """
478
+ df = session.require_data()
479
+ ca = CommandArgs(args)
480
+
481
+ # Parse by(<col>)
482
+ by_m = re.search(r'\bby\(\s*(\w+)\s*\)', args, re.IGNORECASE)
483
+ if not by_m:
484
+ return "Usage: sample stratified <n_or_frac> by(<stratum_col>) [--seed=N]"
485
+ stratum_col = by_m.group(1)
486
+
487
+ if stratum_col not in df.columns:
488
+ return f"Stratum column '{stratum_col}' not found."
489
+
490
+ # First positional token is n_or_frac
491
+ pos_tokens = ca.positional
492
+ if not pos_tokens:
493
+ return "Provide a sample size or fraction as the first argument."
494
+
495
+ try:
496
+ n_or_frac_raw = pos_tokens[0]
497
+ if "." in n_or_frac_raw:
498
+ n_or_frac: float | int = float(n_or_frac_raw)
499
+ use_frac = True
500
+ else:
501
+ n_or_frac = int(n_or_frac_raw)
502
+ use_frac = False
503
+ except ValueError:
504
+ return f"n_or_frac must be a number, got '{pos_tokens[0]}'"
505
+
506
+ if use_frac and not (0 < n_or_frac < 1):
507
+ return "Fraction must be between 0 and 1 (exclusive)."
508
+ if not use_frac and n_or_frac <= 0:
509
+ return "Sample size must be a positive integer."
510
+
511
+ seed_str = ca.options.get("seed")
512
+ seed: int | None = None
513
+ if seed_str is not None:
514
+ try:
515
+ seed = int(seed_str)
516
+ except ValueError:
517
+ return f"--seed must be an integer, got '{seed_str}'"
518
+
519
+ session.snapshot()
520
+ try:
521
+ strata = df[stratum_col].unique().to_list()
522
+ parts: list[pl.DataFrame] = []
523
+ for stratum_val in strata:
524
+ stratum_df = df.filter(pl.col(stratum_col) == stratum_val)
525
+ stratum_n = stratum_df.height
526
+ if stratum_n == 0:
527
+ continue
528
+ if use_frac:
529
+ take = max(1, int(stratum_n * float(n_or_frac)))
530
+ else:
531
+ take = min(int(n_or_frac), stratum_n)
532
+ parts.append(stratum_df.sample(n=take, shuffle=True, seed=seed))
533
+
534
+ if not parts:
535
+ session.undo()
536
+ return "No strata found — dataset may be empty."
537
+
538
+ session.df = pl.concat(parts).sample(fraction=1.0, shuffle=True, seed=seed)
539
+ total_sampled = session.df.height
540
+ mode_desc = f"{n_or_frac} per stratum" if not use_frac else f"{int(float(n_or_frac)*100)}% per stratum"
541
+ return (
542
+ f"Stratified sample ({mode_desc}, {len(strata)} strata): "
543
+ f"{total_sampled:,} rows drawn from {df.height:,}. "
544
+ "Use 'undo' to revert."
545
+ )
546
+ except Exception as exc:
547
+ session.undo()
548
+ return friendly_error(exc, "sample stratified")
549
+
550
+
551
+ # ---------------------------------------------------------------------------
552
+ # anonymize
553
+ # ---------------------------------------------------------------------------
554
+
555
+ @command("anonymize", usage="anonymize <col> [col2 ...] [--method=mask|hash|noise|drop]")
556
+ def cmd_anonymize(session: Session, args: str) -> str:
557
+ """Anonymize columns using a chosen method.
558
+
559
+ Methods:
560
+ mask Replace string values with asterisks, keeping first and last char.
561
+ hash Replace values with their SHA-256 hex digest.
562
+ noise Add Gaussian noise to numeric columns (--noise_std=0.1 controls scale).
563
+ drop Drop the column(s) entirely.
564
+
565
+ Example: anonymize name email --method=mask
566
+ Example: anonymize ssn --method=hash
567
+ Example: anonymize income --method=noise --noise_std=0.05
568
+ """
569
+ import hashlib
570
+
571
+ df = session.require_data()
572
+ ca = CommandArgs(args)
573
+
574
+ cols = ca.positional
575
+ if not cols:
576
+ return "Usage: anonymize <col> [col2 ...] [--method=mask|hash|noise|drop]"
577
+
578
+ method = ca.options.get("method", "mask").lower()
579
+ valid_methods = {"mask", "hash", "noise", "drop"}
580
+ if method not in valid_methods:
581
+ return f"Unknown method '{method}'. Valid: {', '.join(sorted(valid_methods))}"
582
+
583
+ missing = [c for c in cols if c not in df.columns]
584
+ if missing:
585
+ return f"Columns not found: {', '.join(missing)}"
586
+
587
+ noise_std = 0.1
588
+ if "noise_std" in ca.options:
589
+ try:
590
+ noise_std = float(ca.options["noise_std"])
591
+ except ValueError:
592
+ return f"--noise_std must be a float, got '{ca.options['noise_std']}'"
593
+
594
+ session.snapshot()
595
+ try:
596
+ if method == "drop":
597
+ session.df = df.drop(cols)
598
+ return (
599
+ f"Dropped {len(cols)} column(s): {', '.join(cols)}. "
600
+ "Use 'undo' to revert."
601
+ )
602
+
603
+ work_df = df
604
+ report_lines: list[str] = []
605
+
606
+ for col in cols:
607
+ dtype = df[col].dtype
608
+ is_numeric = dtype in (
609
+ pl.Float32, pl.Float64,
610
+ pl.Int8, pl.Int16, pl.Int32, pl.Int64,
611
+ pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64,
612
+ )
613
+
614
+ if method == "mask":
615
+ def _mask(v: object) -> str | None:
616
+ if v is None:
617
+ return None
618
+ s = str(v)
619
+ if len(s) <= 2:
620
+ return "*" * len(s)
621
+ return s[0] + "*" * (len(s) - 2) + s[-1]
622
+
623
+ masked = pl.Series(
624
+ name=col,
625
+ values=[_mask(v) for v in work_df[col].to_list()],
626
+ dtype=pl.Utf8,
627
+ )
628
+ work_df = work_df.with_columns(masked)
629
+ report_lines.append(f" '{col}': masked (first+last char kept)")
630
+
631
+ elif method == "hash":
632
+ def _hash(v: object) -> str | None:
633
+ if v is None:
634
+ return None
635
+ raw = str(v).encode("utf-8")
636
+ return hashlib.sha256(raw).hexdigest()
637
+
638
+ hashed = pl.Series(
639
+ name=col,
640
+ values=[_hash(v) for v in work_df[col].to_list()],
641
+ dtype=pl.Utf8,
642
+ )
643
+ work_df = work_df.with_columns(hashed)
644
+ report_lines.append(f" '{col}': SHA-256 hashed")
645
+
646
+ elif method == "noise":
647
+ if not is_numeric:
648
+ report_lines.append(
649
+ f" '{col}': skipped (noise requires numeric column, got {dtype})"
650
+ )
651
+ continue
652
+ import random as _random
653
+ vals = work_df[col].cast(pl.Float64).to_list()
654
+ noisy = [
655
+ v + _random.gauss(0, noise_std) if v is not None else None
656
+ for v in vals
657
+ ]
658
+ work_df = work_df.with_columns(
659
+ pl.Series(name=col, values=noisy, dtype=pl.Float64)
660
+ )
661
+ report_lines.append(
662
+ f" '{col}': Gaussian noise added (std={noise_std})"
663
+ )
664
+
665
+ session.df = work_df
666
+ header = f"Anonymized {len(cols)} column(s) using method='{method}':"
667
+ report_lines.append("Use 'undo' to revert.")
668
+ return "\n".join([header] + report_lines)
669
+
670
+ except Exception as exc:
671
+ session.undo()
672
+ return friendly_error(exc, "anonymize")