stata-cli 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
stata_cli/engine.py ADDED
@@ -0,0 +1,461 @@
1
+ """Stata execution engine wrapping PyStata."""
2
+
3
+ import io
4
+ import json
5
+ import os
6
+ import sys
7
+ import re
8
+ import time
9
+ import tempfile
10
+ import platform
11
+ from dataclasses import dataclass, field, asdict
12
+ from typing import Any, Dict, List, Optional
13
+
14
+ from .utils import join_line_continuations
15
+ from .graph_artifacts import (
16
+ build_graph_record,
17
+ cleanup_graph_batches,
18
+ create_batch_context,
19
+ ensure_graphs_root,
20
+ get_graphs_root,
21
+ write_batch_manifest,
22
+ )
23
+
24
+
25
@dataclass
class Result:
    """Captures the outcome of one Stata command execution.

    Field order matters: positional construction is
    ``Result(success, output, error, execution_time, ...)``.
    """
    success: bool
    output: str
    error: str
    execution_time: float
    return_code: int = 0
    extra: dict = field(default_factory=dict)

    def to_json(self) -> str:
        """Serialize this result as a pretty-printed JSON string."""
        payload = asdict(self)
        return json.dumps(payload, ensure_ascii=False, indent=2)
37
+
38
+
39
# ── Graph command regex for auto-naming ──────────────────────────────────

# Matches a graph-producing command line. Capture groups: (1) leading
# indentation, (2) the command word(s), (3) the remainder of the line.
_GRAPH_CMD_RE = re.compile(
    r"^(\s*)(scatter|histogram|twoway|kdensity|graph\s+"
    r"(?:bar|box|dot|pie|matrix|hbar|hbox|combine))\s+(.*)$",
    re.IGNORECASE,
)
# Detects an explicit name() option anywhere in the command tail, so already
# named graph commands are left untouched.
_GRAPH_NAME_RE = re.compile(r"\bname\s*\(", re.IGNORECASE)
# Extracts N from an existing name(graphN...) so auto-numbering continues
# past numbers already used in the do-file.
_EXISTING_GRAPHN_RE = re.compile(r"\bname\s*\(\s*graph(\d+)", re.IGNORECASE)
48
+
49
+
50
class StataEngine:
    """Thin wrapper around PyStata for single-process command execution."""

    def __init__(self, stata_path: str, edition: str = "mp", graphs_dir: Optional[str] = None):
        """Record configuration; the Stata runtime itself starts lazily.

        stata_path: Stata installation directory (exported as SYSDIR_STATA).
        edition:    Stata edition tag such as "mp"/"se"/"be"; lowercased.
        graphs_dir: root directory for exported graphs; falls back to the
                    configured default graphs root when omitted.
        """
        # Lazy PyStata handles — populated on first use by _ensure_initialized().
        self._stata = None
        self._stlib = None
        self._initialized = False
        self._stop_sent = False
        self.stata_path = stata_path
        self.edition = edition.lower()
        self.graphs_dir = graphs_dir or get_graphs_root()
61
+
62
    def _ensure_initialized(self) -> None:
        """Start the embedded Stata runtime on first use (idempotent).

        Side effects: sets SYSDIR_STATA, extends sys.path so ``pystata``
        becomes importable, and (on macOS) forces headless Java. Stores the
        ``pystata.stata`` module in self._stata and the low-level ``stlib``
        handle (if available) in self._stlib.
        """
        if self._initialized:
            return

        os.environ["SYSDIR_STATA"] = self.stata_path

        # Imported here (not at module top) because it is only needed once,
        # during runtime bring-up.
        from .utils import get_pystata_path

        # Make the bundled pystata package importable from the Stata install.
        pystata_dir = get_pystata_path(self.stata_path)
        if pystata_dir and pystata_dir not in sys.path:
            sys.path.insert(0, pystata_dir)
        utilities_parent = os.path.join(self.stata_path, "utilities")
        if os.path.isdir(utilities_parent) and utilities_parent not in sys.path:
            sys.path.insert(0, utilities_parent)

        # Prevent Java GUI components from grabbing focus on macOS.
        if platform.system() == "Darwin":
            os.environ["_JAVA_OPTIONS"] = "-Djava.awt.headless=true"

        # config.init prints a startup banner; swallow it so CLI output stays
        # clean. stdout is always restored, even if init raises.
        old_stdout = sys.stdout
        sys.stdout = io.StringIO()
        try:
            from pystata import config  # type: ignore[import-untyped]
            config.init(self.edition)
        finally:
            sys.stdout = old_stdout

        # Must be imported only after config.init() has run.
        from pystata import stata as stata_module  # type: ignore[import-untyped]

        self._stata = stata_module

        # stlib is an internal handle used for break signals and raw command
        # execution; older PyStata builds may not expose it, so best-effort.
        try:
            from pystata.config import stlib as stlib_module  # type: ignore[import-untyped]
            self._stlib = stlib_module
        except Exception:
            self._stlib = None

        self._initialized = True
99
+
100
+ # ── public API ────────────────────────────────────────────────────────
101
+
102
    def run(self, code: str, timeout: float = 600.0) -> Result:
        """Execute a Stata code string and return captured output.

        The code is wrapped in a temporary Stata text log so the full output
        can be recovered even when PyStata's stdout echo is incomplete.

        Parameters
        ----------
        code : str
            Stata commands; ``///`` line continuations are joined first.
        timeout : float
            NOTE(review): accepted for API compatibility but never referenced
            in this method — no timeout is actually enforced. TODO confirm
            whether enforcement happens in a caller.

        Returns
        -------
        Result
            Success flag, captured output, elapsed time, Stata return code,
            and any exported graphs under ``extra["graphs"]``.
        """
        self._ensure_initialized()
        code = join_line_continuations(code)
        self._stop_sent = False

        # Clear graph tracking so only graphs created by THIS run are
        # detected and exported afterwards.
        self._reset_graph_tracking()

        # Unique per-process, per-call log path; Stata wants forward slashes.
        log_file = os.path.join(
            tempfile.gettempdir(),
            f"stata_cli_{os.getpid()}_{int(time.time() * 1000)}.log",
        )
        log_file_stata = log_file.replace("\\", "/")

        # "capture log close _all" on both sides guards against logs left
        # open by earlier commands or opened by the user code itself.
        wrapped = (
            f'capture log close _all\n'
            f'log using "{log_file_stata}", replace text\n'
            f'{code}\n'
            f'capture log close _all\n'
        )

        start = time.time()
        try:
            # Capture PyStata's stdout echo as a fallback source of output.
            old_stdout = sys.stdout
            sys.stdout = io.StringIO()
            try:
                self._stata.run(wrapped, echo=True, inline=False)
            finally:
                captured_stdout = sys.stdout.getvalue()
                sys.stdout = old_stdout

            # Prefer the log file; fall back to the captured echo.
            output = self._read_log(log_file) or captured_stdout
            elapsed = time.time() - start

            # Collapse repeated "--Break--/r(1);" pairs into one.
            output = _deduplicate_breaks(output)

            graphs = self._detect_and_export_graphs()
            extra: dict = {}
            if graphs:
                extra["graphs"] = graphs

            if "--Break--" in output:
                return Result(False, output, "Execution cancelled", elapsed, return_code=1, extra=extra)

            # A trailing "r(N);" line signals a Stata-level error even when
            # the Python call itself did not raise.
            rc = _extract_return_code(output)
            if rc:
                return Result(False, output, "", elapsed, return_code=rc, extra=extra)

            return Result(True, output, "", elapsed, return_code=0, extra=extra)

        except Exception as exc:
            elapsed = time.time() - start
            error_str = str(exc)
            if "--Break--" in error_str:
                return Result(False, "", "Execution cancelled", elapsed, return_code=1)
            # NOTE(review): on exception the message is placed in .output
            # (with .error left empty); callers presumably display output.
            rc = _extract_return_code(error_str)
            return Result(False, error_str, "", elapsed, return_code=rc or 1)
        finally:
            # Always delete the temp log, success or failure.
            self._remove(log_file)
161
+
162
+ def run_file(self, path: str, timeout: float = 600.0) -> Result:
163
+ """Execute a .do file with graph auto-naming and ``///`` preprocessing."""
164
+ path = os.path.abspath(path)
165
+ if not os.path.isfile(path):
166
+ return Result(False, "", f"File not found: {path}", 0.0)
167
+
168
+ preprocessed = self._preprocess_do_file(path)
169
+ file_dir = os.path.dirname(path)
170
+ stata_path = preprocessed.replace("\\", "/")
171
+ code = f'cd "{file_dir}"\ndo "{stata_path}"'
172
+ try:
173
+ return self.run(code, timeout=timeout)
174
+ finally:
175
+ if preprocessed != path:
176
+ self._remove(preprocessed)
177
+
178
    def get_data(self, if_condition: Optional[str] = None, max_rows: int = 10000) -> Dict[str, Any]:
        """Return the current dataset as a dict (columns, data, dtypes, row counts).

        Parameters
        ----------
        if_condition : Optional[str]
            A Stata ``if`` expression used to filter rows. NOTE(review): it is
            interpolated directly into a Stata command — callers must pass
            trusted input only.
        max_rows : int
            Cap on rows returned; values below 100 are raised to 100.

        Returns a dict with ``status`` "success" or "error"; on success it
        carries ``data`` (list of rows), ``columns``, ``dtypes``, ``rows``,
        ``total_rows``, ``displayed_rows``, ``max_rows`` and (when non-empty)
        ``index`` mapping displayed rows back to original 0-based positions.
        """
        self._ensure_initialized()
        try:
            import sfi  # type: ignore[import-untyped]
            import numpy as np  # type: ignore[import-untyped]
        except ImportError as exc:
            return {"status": "error", "error": f"Missing dependency: {exc}"}

        total_obs = sfi.Data.getObsTotal()
        if total_obs == 0:
            # Empty dataset: short-circuit with an empty success payload.
            return {
                "status": "success", "data": [], "columns": [], "dtypes": {},
                "rows": 0, "total_rows": 0, "displayed_rows": 0, "max_rows": max_rows,
            }

        max_rows = max(100, max_rows)
        # Per-process scratch frame name to avoid clashes between processes.
        frame_name = f"_stata_cli_flt_{os.getpid()}"

        try:
            if if_condition:
                # Filter in a copied frame so the user's working frame is
                # never mutated. _orig_obs remembers each row's original
                # 0-based position before "keep if" drops rows.
                self._stata.run(f"capture frame drop {frame_name}", inline=False, echo=False)
                self._stata.run(f"frame copy `c(frame)' {frame_name}", inline=False, echo=False)
                self._stata.run(f"frame {frame_name}: quietly gen long _orig_obs = _n - 1", inline=False, echo=False)
                self._stata.run(f"frame {frame_name}: quietly keep if {if_condition}", inline=False, echo=False)

                df = self._stata.pdataframe_from_frame(frame_name)
                filtered_obs = len(df) if df is not None else 0
                if filtered_obs > max_rows:
                    df = df.head(max_rows)

                # Map displayed rows back to original observation numbers,
                # then drop the helper column from the payload.
                orig_index = df["_orig_obs"].tolist() if df is not None and not df.empty else []
                if df is not None and "_orig_obs" in df.columns:
                    df = df.drop(columns=["_orig_obs"])

                total_matching = filtered_obs
                displayed = min(filtered_obs, max_rows)
            else:
                # No filter: read straight from the active dataset, only
                # materializing the first max_rows observations when large.
                total_matching = total_obs
                displayed = min(total_obs, max_rows)
                if total_obs > max_rows:
                    df = self._stata.pdataframe_from_data(obs=range(max_rows))
                else:
                    df = self._stata.pdataframe_from_data()
                orig_index = list(range(len(df))) if df is not None else []

            if df is None or df.empty:
                return {
                    "status": "success", "data": [], "columns": [], "dtypes": {},
                    "rows": 0, "total_rows": total_matching, "displayed_rows": 0, "max_rows": max_rows,
                }

            # NaN is not JSON-serializable; substitute None.
            df_clean = df.replace({np.nan: None})
            return {
                "status": "success",
                "data": df_clean.values.tolist(),
                "columns": df_clean.columns.tolist(),
                "dtypes": {col: str(df[col].dtype) for col in df.columns},
                "rows": len(df),
                "index": orig_index,
                "total_rows": total_matching,
                "displayed_rows": displayed,
                "max_rows": max_rows,
            }
        except Exception as exc:
            return {"status": "error", "error": str(exc)}
        finally:
            # Best-effort cleanup of the scratch frame.
            if if_condition:
                try:
                    self._stata.run(f"capture frame drop {frame_name}", inline=False, echo=False)
                except Exception:
                    pass
250
+
251
    def help(self, topic: str) -> Result:
        """Return Stata help text for *topic*.

        The topic is sanitized to a single identifier-like token, then located
        via ``findfile`` with a manual fallback scan across the standard
        sysdirs (help files live under ``<sysdir>/<first-letter>/``). The raw
        SMCL help is converted to plain text before returning.
        """
        self._ensure_initialized()
        # Normalize: strip whitespace/leading '#', spaces->underscores, and
        # keep only the part before the first comma.
        topic = topic.strip().lstrip("#").replace(" ", "_").split(",")[0].strip()
        if not topic or not re.match(r"^[a-zA-Z0-9_.\-]+$", topic):
            return Result(False, "", "Invalid topic name", 0.0)

        first_letter = topic[0].lower()
        sysdirs = ["base", "plus", "site", "personal", "stata", "oldplace"]
        # Build one probe block per (sysdir, extension) pair; each only fires
        # when no earlier probe has found a help file.
        fallback_blocks = ""
        for sd in sysdirs:
            for ext in ["sthlp", "hlp"]:
                fallback_blocks += (
                    f'if "`_helpfn\'" == "" {{\n'
                    f' capture confirm file "`c(sysdir_{sd})\'{first_letter}/{topic}.{ext}"\n'
                    f' if _rc == 0 local _helpfn "`c(sysdir_{sd})\'{first_letter}/{topic}.{ext}"\n'
                    f'}}\n'
                )

        # Widen linesize while typing the help file so lines are not wrapped,
        # then restore the user's original setting.
        stata_code = (
            f'quietly set more off\n'
            f'local _stata_help_old_linesize = c(linesize)\n'
            f'quietly set linesize 255\n'
            f'local _helpfn ""\n'
            f'capture findfile {topic}.sthlp\n'
            f'if _rc == 0 local _helpfn "`r(fn)\'"\n'
            f'if "`_helpfn\'" == "" {{\n'
            f' capture findfile {topic}.hlp\n'
            f' if _rc == 0 local _helpfn "`r(fn)\'"\n'
            f'}}\n'
            f'{fallback_blocks}'
            f'if "`_helpfn\'" != "" {{\n'
            f' type "`_helpfn\'", starbang\n'
            f'}}\n'
            f'else {{\n'
            f' display as error "help file not found for: {topic}"\n'
            f'}}\n'
            f'quietly set linesize `_stata_help_old_linesize\'\n'
        )

        result = self.run(stata_code)

        if result.output:
            # Strip the log wrapper and command echo, then render SMCL markup
            # as plain text for terminal display.
            from .output_filter import clean_log_wrapper, apply_compact_filter
            from .smcl_parser import smcl_to_text
            output = clean_log_wrapper(result.output)
            output = apply_compact_filter(output, filter_command_echo=True)
            output = smcl_to_text(output)
            result.output = output

        return result
302
+
303
+ def stop(self) -> bool:
304
+ """Interrupt a running Stata command. Returns True if signal sent."""
305
+ if self._stop_sent or self._stlib is None:
306
+ return False
307
+ self._stop_sent = True
308
+ try:
309
+ self._stlib.StataSO_SetBreak()
310
+ return True
311
+ except Exception:
312
+ return False
313
+
314
    def close(self) -> None:
        """Cleanup placeholder.

        Currently a no-op; kept so callers can follow a uniform open/close
        lifecycle around the engine.
        """
316
+
317
+ # ── graph detection ──────────────────────────────────────────────────
318
+
319
    def _reset_graph_tracking(self) -> None:
        """Reset PyStata's internal graph-tracking list.

        Toggling ``_gr_list`` off then on appears to clear the tracked list,
        so a subsequent ``_gr_list list`` reports only graphs created after
        this point — TODO confirm against PyStata internals. Best-effort:
        any failure (e.g. a PyStata build without these hooks) is ignored.
        """
        if self._stlib is None:
            return
        try:
            from pystata.config import get_encode_str  # type: ignore[import-untyped]
            self._stlib.StataSO_Execute(get_encode_str("qui _gr_list off"), False)
            self._stlib.StataSO_Execute(get_encode_str("qui _gr_list on"), False)
        except Exception:
            pass
328
+
329
    def _detect_and_export_graphs(self) -> List[Dict[str, Any]]:
        """Export graphs created since the last reset; return their records.

        Reads PyStata's tracked graph names (``_gr_list list`` →
        ``r(_grlist)``), exports each as an 800x600 PNG into a fresh batch
        directory, writes a manifest, and prunes old batches. Entirely
        best-effort: any failure yields an empty list (or skips that graph).
        """
        if self._stlib is None:
            return []
        try:
            import sfi  # type: ignore[import-untyped]
            from pystata.config import get_encode_str  # type: ignore[import-untyped]

            self._stlib.StataSO_Execute(get_encode_str("qui _gr_list list"), False)
            gnamelist = sfi.Macro.getGlobal("r(_grlist)")
            if not gnamelist or not gnamelist.strip():
                return []

            graph_names = gnamelist.strip().split()
            graphs_root = ensure_graphs_root(self.graphs_dir)
            batch = create_batch_context(graphs_root)
            graphs: List[Dict[str, Any]] = []

            for order, gname in enumerate(graph_names):
                try:
                    # Bring the graph into the current graph window so that
                    # export targets the right one.
                    self._stlib.StataSO_Execute(
                        get_encode_str(f"quietly graph display {gname}"), False
                    )
                    graph_file = os.path.join(batch["batch_dir"], f"{gname}.png")
                    graph_file_stata = graph_file.replace("\\", "/")
                    export_cmd = (
                        f'quietly graph export "{graph_file_stata}", '
                        f"name({gname}) replace width(800) height(600)"
                    )
                    rc = self._stlib.StataSO_Execute(get_encode_str(export_cmd), False)
                    if rc != 0:
                        continue
                    # Only record exports that actually produced a file.
                    if os.path.isfile(graph_file) and os.path.getsize(graph_file) > 0:
                        graphs.append(build_graph_record(batch, gname, graph_file, order))
                except Exception:
                    continue

            if graphs:
                write_batch_manifest(batch, graphs)
                # Prune older batches, always keeping the one just written.
                cleanup_graph_batches(graphs_root, keep_ids=[batch["batch_id"]])

            return graphs
        except Exception:
            return []
372
+
373
+ # ── do-file preprocessing ────────────────────────────────────────────
374
+
375
    @staticmethod
    def _preprocess_do_file(path: str) -> str:
        """Join ``///`` continuations and auto-name unnamed graph commands.

        Returns the path to a temp file (or the original if no changes needed).
        Any failure — unreadable file, regex trouble, temp-file error — falls
        back to returning the original path untouched.
        """
        try:
            with open(path, "r", encoding="utf-8", errors="replace") as fh:
                content = fh.read()

            joined_lines = join_line_continuations(content).splitlines()

            # Collect numbers already used in explicit name(graphN ...) options
            # so generated names never collide with them.
            existing_nums: set[int] = set()
            for line in joined_lines:
                for m in _EXISTING_GRAPHN_RE.findall(line):
                    try:
                        existing_nums.add(int(m))
                    except ValueError:
                        pass

            counter = max(existing_nums) if existing_nums else 0
            modified = False
            out_lines: list[str] = []

            for line in joined_lines:
                gm = _GRAPH_CMD_RE.match(line)
                # Only rewrite graph commands that have no name() option yet.
                if gm and not _GRAPH_NAME_RE.search(gm.group(3)):
                    indent, cmd, rest = gm.group(1), gm.group(2), gm.group(3)
                    counter += 1
                    name_opt = f"name(graph{counter}, replace)"
                    if "," in rest:
                        # Inject right after the first comma (assumed to open
                        # the options list). NOTE(review): a comma inside
                        # parentheses before the options comma — e.g.
                        # twoway (scatter y x, msize(small)) — would be hit
                        # first and mis-place the option; confirm acceptable.
                        rest = rest.replace(",", f", {name_opt}", 1)
                    else:
                        rest = rest.rstrip() + f", {name_opt}"
                    out_lines.append(f"{indent}{cmd} {rest}")
                    modified = True
                else:
                    out_lines.append(line)

            if not modified:
                return path

            # Write the rewritten script to a temp .do file; caller deletes it.
            tmp = tempfile.NamedTemporaryFile(suffix=".do", delete=False, mode="w", encoding="utf-8")
            tmp.write("\n".join(out_lines) + "\n")
            tmp.close()
            return tmp.name
        except Exception:
            return path
423
+
424
+ # ── internals ─────────────────────────────────────────────────────────
425
+
426
+ @staticmethod
427
+ def _read_log(path: str) -> str:
428
+ if not os.path.isfile(path):
429
+ return ""
430
+ try:
431
+ with open(path, "r", encoding="utf-8", errors="replace") as fh:
432
+ return fh.read()
433
+ except OSError:
434
+ return ""
435
+
436
+ @staticmethod
437
+ def _remove(path: str) -> None:
438
+ try:
439
+ if os.path.isfile(path):
440
+ os.unlink(path)
441
+ except OSError:
442
+ pass
443
+
444
+
445
+ def _deduplicate_breaks(output: str) -> str:
446
+ if not output or "--Break--" not in output:
447
+ return output
448
+ return re.sub(
449
+ r"(--Break--\s*\n\s*r\(1\);\s*\n?)+",
450
+ "--Break--\nr(1);\n",
451
+ output,
452
+ )
453
+
454
+
455
+ _STATA_ERROR_RE = re.compile(r"^r\((\d+)\);\s*$", re.MULTILINE)
456
+
457
+
458
+ def _extract_return_code(output: str) -> int:
459
+ """Extract the last Stata return code from output. 0 if none found."""
460
+ matches = _STATA_ERROR_RE.findall(output)
461
+ return int(matches[-1]) if matches else 0
@@ -0,0 +1,95 @@
1
+ """Graph artifact storage and batch management."""
2
+
3
+ import json
4
+ import os
5
+ import shutil
6
+ import tempfile
7
+ import time
8
+ import uuid
9
+ from typing import Any, Dict, Iterable, List, Optional
10
+
11
+
12
+ DEFAULT_KEEP_BATCHES = 2
13
+ MANIFEST_FILENAME = "manifest.json"
14
+
15
+
16
+ def get_graphs_root(configured: Optional[str] = None) -> str:
17
+ root = configured or os.environ.get("STATA_CLI_GRAPHS_DIR")
18
+ if root:
19
+ return os.path.abspath(root)
20
+ return os.path.join(os.path.expanduser("~"), ".stata-cli", "graphs")
21
+
22
+
23
+ def ensure_graphs_root(graphs_root: str) -> str:
24
+ os.makedirs(graphs_root, exist_ok=True)
25
+ return os.path.abspath(graphs_root)
26
+
27
+
28
+ def create_batch_context(graphs_root: str, execution_id: Optional[str] = None) -> Dict[str, Any]:
29
+ graphs_root = ensure_graphs_root(graphs_root)
30
+ execution_id = execution_id or f"exec-{int(time.time() * 1000)}-{uuid.uuid4().hex[:8]}"
31
+ batch_dir = os.path.join(graphs_root, execution_id)
32
+ os.makedirs(batch_dir, exist_ok=True)
33
+ return {
34
+ "execution_id": execution_id,
35
+ "batch_id": execution_id,
36
+ "batch_dir": batch_dir,
37
+ "graphs_root": graphs_root,
38
+ "created_at": int(time.time() * 1000),
39
+ }
40
+
41
+
42
+ def build_graph_record(
43
+ batch_context: Dict[str, Any],
44
+ name: str,
45
+ file_path: str,
46
+ order: int,
47
+ fmt: str = "png",
48
+ ) -> Dict[str, Any]:
49
+ return {
50
+ "name": name,
51
+ "path": file_path.replace("\\", "/"),
52
+ "filename": os.path.basename(file_path),
53
+ "format": fmt,
54
+ "order": order,
55
+ "batchId": batch_context["batch_id"],
56
+ }
57
+
58
+
59
+ def write_batch_manifest(batch_context: Dict[str, Any], graphs: List[Dict[str, Any]]) -> str:
60
+ manifest = {
61
+ "executionId": batch_context["execution_id"],
62
+ "batchId": batch_context["batch_id"],
63
+ "createdAt": batch_context["created_at"],
64
+ "graphs": graphs,
65
+ }
66
+ path = os.path.join(batch_context["batch_dir"], MANIFEST_FILENAME)
67
+ with open(path, "w", encoding="utf-8") as fh:
68
+ json.dump(manifest, fh, ensure_ascii=False, indent=2)
69
+ return path
70
+
71
+
72
+ def cleanup_graph_batches(
73
+ graphs_root: str,
74
+ keep_ids: Optional[Iterable[str]] = None,
75
+ keep_latest: int = DEFAULT_KEEP_BATCHES,
76
+ ) -> List[str]:
77
+ if not os.path.isdir(graphs_root):
78
+ return []
79
+ protected = set(keep_ids or [])
80
+ batch_dirs = sorted(
81
+ (e.path for e in os.scandir(graphs_root) if e.is_dir()),
82
+ key=lambda p: os.path.getmtime(p),
83
+ reverse=True,
84
+ )
85
+ removed: list[str] = []
86
+ for idx, bdir in enumerate(batch_dirs):
87
+ bid = os.path.basename(bdir)
88
+ if bid in protected or idx < keep_latest:
89
+ continue
90
+ try:
91
+ shutil.rmtree(bdir, ignore_errors=True)
92
+ removed.append(bid)
93
+ except OSError:
94
+ pass
95
+ return removed