mcp_stata-1.18.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcp-stata might be problematic.
- mcp_stata/__init__.py +4 -0
- mcp_stata/_native_ops.cpython-312-aarch64-linux-gnu.so +0 -0
- mcp_stata/config.py +20 -0
- mcp_stata/discovery.py +550 -0
- mcp_stata/graph_detector.py +401 -0
- mcp_stata/models.py +62 -0
- mcp_stata/native_ops.py +87 -0
- mcp_stata/server.py +1130 -0
- mcp_stata/smcl/smcl2html.py +88 -0
- mcp_stata/stata_client.py +3692 -0
- mcp_stata/streaming_io.py +263 -0
- mcp_stata/test_stata.py +54 -0
- mcp_stata/ui_http.py +998 -0
- mcp_stata-1.18.0.dist-info/METADATA +471 -0
- mcp_stata-1.18.0.dist-info/RECORD +18 -0
- mcp_stata-1.18.0.dist-info/WHEEL +5 -0
- mcp_stata-1.18.0.dist-info/entry_points.txt +2 -0
- mcp_stata-1.18.0.dist-info/licenses/LICENSE +661 -0
mcp_stata/stata_client.py
@@ -0,0 +1,3692 @@
import asyncio
import io
import inspect
import json
import logging
import os
import platform
import re
import subprocess
import sys
import tempfile
import threading
import time
import uuid
from contextlib import contextmanager, redirect_stdout, redirect_stderr
from importlib.metadata import PackageNotFoundError, version
from io import StringIO
from typing import Any, Awaitable, Callable, Dict, Generator, List, Optional, Tuple

import anyio
from anyio import get_cancelled_exc_class

from .discovery import find_stata_candidates
from .config import MAX_LIMIT
from .models import (
    CommandResponse,
    ErrorEnvelope,
    GraphExport,
    GraphExportResponse,
    GraphInfo,
    GraphListResponse,
    VariableInfo,
    VariablesResponse,
)
from .smcl.smcl2html import smcl_to_markdown
from .streaming_io import FileTeeIO, TailBuffer
from .graph_detector import StreamingGraphCache
from .native_ops import fast_scan_log, compute_filter_indices

logger = logging.getLogger("mcp_stata")

_POLARS_AVAILABLE: Optional[bool] = None

def _check_polars_available() -> bool:
    """
    Check if Polars can be safely imported.
    Must detect problematic platforms BEFORE attempting import,
    since the crash is a fatal signal, not a catchable exception.
    """
    if sys.platform == "win32" and platform.machine().lower() in ("arm64", "aarch64"):
        return False

    try:
        import polars  # noqa: F401
        return True
    except ImportError:
        return False


def _get_polars_available() -> bool:
    global _POLARS_AVAILABLE
    if _POLARS_AVAILABLE is None:
        _POLARS_AVAILABLE = _check_polars_available()
    return _POLARS_AVAILABLE
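
# The two helpers above implement a lazy, cached capability probe: the
# platform check runs at most once per process and callers simply gate on the
# boolean, e.g. (illustrative call pattern, not a snippet from this package):
#
#     if _get_polars_available():
#         import polars as pl   # safe: Windows-on-ARM is already excluded
#     else:
#         pl = None             # fall back to a non-Polars code path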

# ============================================================================
# MODULE-LEVEL DISCOVERY CACHE
# ============================================================================
# This cache ensures Stata discovery runs exactly once per process lifetime
_discovery_lock = threading.Lock()
_discovery_result: Optional[Tuple[str, str]] = None  # (path, edition)
_discovery_candidates: Optional[List[Tuple[str, str]]] = None
_discovery_attempted = False
_discovery_error: Optional[Exception] = None


def _get_discovery_candidates() -> List[Tuple[str, str]]:
    """
    Get ordered discovery candidates, running discovery only once.

    Returns:
        List of (stata_executable_path, edition) ordered by preference.

    Raises:
        RuntimeError: If Stata discovery fails
    """
    global _discovery_result, _discovery_candidates, _discovery_attempted, _discovery_error

    with _discovery_lock:
        # If we've already successfully discovered Stata, return cached result
        if _discovery_result is not None:
            return _discovery_candidates or [_discovery_result]

        if _discovery_candidates is not None:
            return _discovery_candidates

        # If we've already attempted and failed, re-raise the cached error
        if _discovery_attempted and _discovery_error is not None:
            raise RuntimeError(f"Stata binary not found: {_discovery_error}") from _discovery_error

        # This is the first attempt - run discovery
        _discovery_attempted = True

        try:
            # Log environment state once at first discovery
            env_path = os.getenv("STATA_PATH")
            if env_path:
                logger.info("STATA_PATH env provided (raw): %s", env_path)
            else:
                logger.info("STATA_PATH env not set; attempting auto-discovery")

            try:
                pkg_version = version("mcp-stata")
            except PackageNotFoundError:
                pkg_version = "unknown"
            logger.info("mcp-stata version: %s", pkg_version)

            # Run discovery
            candidates = find_stata_candidates()

            # Cache the successful result
            _discovery_candidates = candidates
            if candidates:
                _discovery_result = candidates[0]
                logger.info("Discovery found Stata at: %s (%s)", _discovery_result[0], _discovery_result[1])
            else:
                raise FileNotFoundError("No Stata candidates discovered")

            return candidates

        except FileNotFoundError as e:
            _discovery_error = e
            raise RuntimeError(f"Stata binary not found: {e}") from e
        except PermissionError as e:
            _discovery_error = e
            raise RuntimeError(
                f"Stata binary is not executable: {e}. "
                "Point STATA_PATH directly to the Stata binary (e.g., .../Contents/MacOS/stata-mp)."
            ) from e


def _get_discovered_stata() -> Tuple[str, str]:
    """
    Preserve existing API: return the highest-priority discovered Stata candidate.
    """
    candidates = _get_discovery_candidates()
    if not candidates:
        raise RuntimeError("Stata binary not found: no candidates discovered")
    return candidates[0]
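
# Discovery consults the STATA_PATH environment variable (see the logging
# above) before auto-detection, so a typical override looks like this, with an
# illustrative path and edition:
#
#     export STATA_PATH=/usr/local/stata18/stata-mp
#
# after which _get_discovered_stata() returns a (binary_path, edition) tuple
# such as ("/usr/local/stata18/stata-mp", "mp").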


class StataClient:
    _initialized = False
    _exec_lock: threading.Lock
    _cache_init_lock = threading.Lock()  # Class-level lock for cache initialization
    _is_executing = False  # Flag to prevent recursive Stata calls
    MAX_DATA_ROWS = MAX_LIMIT
    MAX_GRAPH_BYTES = 50 * 1024 * 1024  # Maximum graph exports (~50MB)
    MAX_CACHE_SIZE = 100  # Maximum number of graphs to cache
    MAX_CACHE_BYTES = 500 * 1024 * 1024  # Maximum cache size in bytes (~500MB)
    LIST_GRAPHS_TTL = 0.075  # TTL for list_graphs cache (75ms)

    def __init__(self):
        self._exec_lock = threading.RLock()
        self._is_executing = False
        self._command_idx = 0  # Counter for user-initiated commands
        self._initialized = False
        from .graph_detector import GraphCreationDetector
        self._graph_detector = GraphCreationDetector(self)

    def __new__(cls):
        inst = super(StataClient, cls).__new__(cls)
        inst._exec_lock = threading.RLock()
        inst._is_executing = False
        inst._command_idx = 0
        from .graph_detector import GraphCreationDetector
        inst._graph_detector = GraphCreationDetector(inst)
        return inst

    def _increment_command_idx(self) -> int:
        """Increment and return the command counter."""
        self._command_idx += 1
        return self._command_idx

    @contextmanager
    def _redirect_io(self, out_buf, err_buf):
        """Safely redirect stdout/stderr for the duration of a Stata call."""
        backup_stdout, backup_stderr = sys.stdout, sys.stderr
        sys.stdout, sys.stderr = out_buf, err_buf
        try:
            yield
        finally:
            sys.stdout, sys.stderr = backup_stdout, backup_stderr


    @staticmethod
    def _stata_quote(value: str) -> str:
        """Return a Stata double-quoted string literal for value."""
        # Stata uses doubled quotes to represent a quote character inside a string.
        v = (value or "")
        v = v.replace('"', '""')
        # Use compound double quotes to avoid tokenization issues with spaces and
        # punctuation in contexts like graph names.
        return f'`"{v}"\''
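
    # For example, _stata_quote('say "hi"') yields `"say ""hi"""' : the inner
    # quotes are doubled and the whole literal is wrapped in Stata's compound
    # double quotes `"..."'.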

    @contextmanager
    def _redirect_io_streaming(self, out_stream, err_stream):
        backup_stdout, backup_stderr = sys.stdout, sys.stderr
        sys.stdout, sys.stderr = out_stream, err_stream
        try:
            yield
        finally:
            sys.stdout, sys.stderr = backup_stdout, backup_stderr

    @staticmethod
    def _safe_unlink(path: str) -> None:
        if not path:
            return
        try:
            if os.path.exists(path):
                os.unlink(path)
        except Exception:
            pass

    def _create_smcl_log_path(
        self,
        *,
        prefix: str = "mcp_smcl_",
        max_hex: Optional[int] = None,
        base_dir: Optional[str] = None,
    ) -> str:
        hex_id = uuid.uuid4().hex if max_hex is None else uuid.uuid4().hex[:max_hex]
        base = os.path.realpath(tempfile.gettempdir())
        smcl_path = os.path.join(base, f"{prefix}{hex_id}.smcl")
        self._safe_unlink(smcl_path)
        return smcl_path

    @staticmethod
    def _make_smcl_log_name() -> str:
        return f"_mcp_smcl_{uuid.uuid4().hex[:8]}"

    def _open_smcl_log(self, smcl_path: str, log_name: str, *, quiet: bool = False) -> bool:
        path_for_stata = smcl_path.replace("\\", "/")
        base_cmd = f"log using \"{path_for_stata}\", replace smcl name({log_name})"
        unnamed_cmd = f"log using \"{path_for_stata}\", replace smcl"
        for attempt in range(4):
            try:
                logger.debug(
                    "_open_smcl_log attempt=%s log_name=%s path=%s",
                    attempt + 1,
                    log_name,
                    smcl_path,
                )
                logger.warning(
                    "SMCL open attempt %s cwd=%s path=%s",
                    attempt + 1,
                    os.getcwd(),
                    smcl_path,
                )
                logger.debug(
                    "SMCL open attempt=%s cwd=%s path=%s cmd=%s",
                    attempt + 1,
                    os.getcwd(),
                    smcl_path,
                    base_cmd,
                )
                try:
                    close_ret = self.stata.run("capture log close _all", echo=False)
                    if close_ret:
                        logger.warning("SMCL close_all output: %s", close_ret)
                except Exception:
                    pass
                cmd = f"{'quietly ' if quiet else ''}{base_cmd}"
                try:
                    output_buf = StringIO()
                    with redirect_stdout(output_buf), redirect_stderr(output_buf):
                        self.stata.run(cmd, echo=False)
                    ret = output_buf.getvalue().strip()
                    if ret:
                        logger.warning("SMCL log open output: %s", ret)
                except Exception as e:
                    logger.warning("SMCL log open failed (attempt %s): %s", attempt + 1, e)
                    logger.warning("SMCL log open failed: %r", e)
                    try:
                        retry_buf = StringIO()
                        with redirect_stdout(retry_buf), redirect_stderr(retry_buf):
                            self.stata.run(base_cmd, echo=False)
                        ret = retry_buf.getvalue().strip()
                        if ret:
                            logger.warning("SMCL log open output (no quiet): %s", ret)
                    except Exception as inner:
                        logger.warning("SMCL log open retry failed: %s", inner)
                query_buf = StringIO()
                try:
                    with redirect_stdout(query_buf), redirect_stderr(query_buf):
                        self.stata.run("log query", echo=False)
                except Exception as query_err:
                    query_buf.write(f"log query failed: {query_err!r}")
                query_ret = query_buf.getvalue().strip()
                logger.warning("SMCL log query output: %s", query_ret)

                if query_ret:
                    query_lower = query_ret.lower()
                    log_confirmed = "log:" in query_lower and "smcl" in query_lower and " on" in query_lower
                    if log_confirmed:
                        self._last_smcl_log_named = True
                        logger.info("SMCL log confirmed: %s", path_for_stata)
                        return True
                logger.warning("SMCL log not confirmed after open; query_ret=%s", query_ret)
                try:
                    unnamed_output = StringIO()
                    with redirect_stdout(unnamed_output), redirect_stderr(unnamed_output):
                        self.stata.run(unnamed_cmd, echo=False)
                    unnamed_ret = unnamed_output.getvalue().strip()
                    if unnamed_ret:
                        logger.warning("SMCL log open output (unnamed): %s", unnamed_ret)
                except Exception as e:
                    logger.warning("SMCL log open failed (unnamed, attempt %s): %s", attempt + 1, e)
                unnamed_query_buf = StringIO()
                try:
                    with redirect_stdout(unnamed_query_buf), redirect_stderr(unnamed_query_buf):
                        self.stata.run("log query", echo=False)
                except Exception as query_err:
                    unnamed_query_buf.write(f"log query failed: {query_err!r}")
                unnamed_query = unnamed_query_buf.getvalue().strip()
                if unnamed_query:
                    unnamed_lower = unnamed_query.lower()
                    unnamed_confirmed = "log:" in unnamed_lower and "smcl" in unnamed_lower and " on" in unnamed_lower
                    if unnamed_confirmed:
                        self._last_smcl_log_named = False
                        logger.info("SMCL log confirmed (unnamed): %s", path_for_stata)
                        return True
            except Exception as e:
                logger.warning("Failed to open SMCL log (attempt %s): %s", attempt + 1, e)
            if attempt < 3:
                time.sleep(0.1)
        logger.warning("Failed to open SMCL log with cmd: %s", cmd)
        return False

    def _close_smcl_log(self, log_name: str) -> None:
        try:
            use_named = getattr(self, "_last_smcl_log_named", None)
            if use_named is False:
                self.stata.run("capture log close", echo=False)
            else:
                self.stata.run(f"capture log close {log_name}", echo=False)
        except Exception:
            pass

    def _restore_results_from_hold(self, hold_attr: str) -> None:
        if not hasattr(self, hold_attr):
            return
        hold_name = getattr(self, hold_attr)
        try:
            self.stata.run(f"capture _return restore {hold_name}", echo=False)
            self._last_results = self.get_stored_results(force_fresh=True)
        except Exception:
            pass
        finally:
            try:
                delattr(self, hold_attr)
            except Exception:
                pass

    def _create_streaming_log(self, *, trace: bool) -> tuple[tempfile.NamedTemporaryFile, str, TailBuffer, FileTeeIO]:
        log_file = tempfile.NamedTemporaryFile(
            prefix="mcp_stata_",
            suffix=".log",
            delete=False,
            mode="w",
            encoding="utf-8",
            errors="replace",
            buffering=1,
        )
        log_path = log_file.name
        tail = TailBuffer(max_chars=200000 if trace else 20000)
        tee = FileTeeIO(log_file, tail)
        return log_file, log_path, tail, tee

    def _init_streaming_graph_cache(
        self,
        auto_cache_graphs: bool,
        on_graph_cached: Optional[Callable[[str, bool], Awaitable[None]]],
        notify_log: Callable[[str], Awaitable[None]],
    ) -> Optional[StreamingGraphCache]:
        if not auto_cache_graphs:
            return None
        graph_cache = StreamingGraphCache(self, auto_cache=True)
        graph_cache_callback = self._create_graph_cache_callback(on_graph_cached, notify_log)
        graph_cache.add_cache_callback(graph_cache_callback)
        return graph_cache

    def _capture_graph_state(
        self,
        graph_cache: Optional[StreamingGraphCache],
        emit_graph_ready: bool,
    ) -> Optional[dict[str, str]]:
        # Capture initial graph state BEFORE execution starts
        if graph_cache:
            # Clear detection state for the new command (detected/removed sets)
            # but preserve _last_graph_state signatures for modification detection.
            graph_cache.detector.clear_detection_state()
            try:
                graph_cache._initial_graphs = set(self.list_graphs(force_refresh=True))
                logger.debug(f"Initial graph state captured: {graph_cache._initial_graphs}")
            except Exception as e:
                logger.debug(f"Failed to capture initial graph state: {e}")
                graph_cache._initial_graphs = set()

        graph_ready_initial = None
        if emit_graph_ready:
            try:
                graph_ready_initial = {}
                for graph_name in self.list_graphs(force_refresh=True):
                    graph_ready_initial[graph_name] = self._get_graph_signature(graph_name)
                logger.debug("Graph-ready initial state captured: %s", set(graph_ready_initial))
            except Exception as e:
                logger.debug("Failed to capture graph-ready state: %s", e)
                graph_ready_initial = {}
        return graph_ready_initial

    async def _cache_new_graphs(
        self,
        graph_cache: Optional[StreamingGraphCache],
        *,
        notify_progress: Optional[Callable[[float, Optional[float], Optional[str]], Awaitable[None]]],
        total_lines: int,
        completed_label: str,
    ) -> None:
        if not graph_cache or not graph_cache.auto_cache:
            return
        try:
            cached_graphs = []
            # Use detector to find new OR modified graphs
            pystata_detected = await anyio.to_thread.run_sync(graph_cache.detector._detect_graphs_via_pystata)

            # Combine with any pending graphs in queue
            with graph_cache._lock:
                to_process = set(pystata_detected) | set(graph_cache._graphs_to_cache)
                graph_cache._graphs_to_cache.clear()

            if to_process:
                logger.info(f"Detected {len(to_process)} new or modified graph(s): {sorted(to_process)}")

            for graph_name in to_process:
                if graph_name in graph_cache._cached_graphs:
                    continue

                try:
                    cache_result = await anyio.to_thread.run_sync(
                        self.cache_graph_on_creation,
                        graph_name,
                    )
                    if cache_result:
                        cached_graphs.append(graph_name)
                        graph_cache._cached_graphs.add(graph_name)

                    for callback in graph_cache._cache_callbacks:
                        try:
                            result = callback(graph_name, cache_result)
                            if inspect.isawaitable(result):
                                await result
                        except Exception:
                            pass
                except Exception as e:
                    logger.error(f"Error caching graph {graph_name}: {e}")

            if cached_graphs and notify_progress:
                await notify_progress(
                    float(total_lines) if total_lines > 0 else 1,
                    float(total_lines) if total_lines > 0 else 1,
                    f"{completed_label} completed. Cached {len(cached_graphs)} graph(s): {', '.join(cached_graphs)}",
                )
        except Exception as e:
            logger.error(f"Post-execution graph detection failed: {e}")

    def _emit_graph_ready_task(
        self,
        *,
        emit_graph_ready: bool,
        graph_ready_initial: Optional[dict[str, str]],
        notify_log: Callable[[str], Awaitable[None]],
        graph_ready_task_id: Optional[str],
        graph_ready_format: str,
    ) -> None:
        if emit_graph_ready and graph_ready_initial is not None:
            try:
                asyncio.create_task(
                    self._emit_graph_ready_events(
                        graph_ready_initial,
                        notify_log,
                        graph_ready_task_id,
                        graph_ready_format,
                    )
                )
            except Exception as e:
                logger.warning("graph_ready emission failed to start: %s", e)

    async def _stream_smcl_log(
        self,
        *,
        smcl_path: str,
        notify_log: Callable[[str], Awaitable[None]],
        done: anyio.Event,
        on_chunk: Optional[Callable[[str], Awaitable[None]]] = None,
    ) -> None:
        last_pos = 0
        emitted_debug_chunks = 0
        # Wait for Stata to create the SMCL file
        while not done.is_set() and not os.path.exists(smcl_path):
            await anyio.sleep(0.05)

        try:
            def _read_content() -> str:
                try:
                    with open(smcl_path, "r", encoding="utf-8", errors="replace") as f:
                        f.seek(last_pos)
                        return f.read()
                except PermissionError:
                    if os.name == "nt":
                        try:
                            res = subprocess.run(f'type "{smcl_path}"', shell=True, capture_output=True)
                            full_content = res.stdout.decode("utf-8", errors="replace")
                            if len(full_content) > last_pos:
                                return full_content[last_pos:]
                            return ""
                        except Exception:
                            return ""
                    return ""
                except FileNotFoundError:
                    return ""

            while not done.is_set():
                chunk = await anyio.to_thread.run_sync(_read_content)
                if chunk:
                    last_pos += len(chunk)
                    try:
                        await notify_log(chunk)
                    except Exception as exc:
                        logger.debug("notify_log failed: %s", exc)
                    if on_chunk is not None:
                        try:
                            await on_chunk(chunk)
                        except Exception as exc:
                            logger.debug("on_chunk callback failed: %s", exc)
                await anyio.sleep(0.05)

            chunk = await anyio.to_thread.run_sync(_read_content)
            if on_chunk is not None:
                # Final check even if last chunk is empty, to ensure
                # graphs created at the very end are detected.
                try:
                    await on_chunk(chunk or "")
                except Exception as exc:
                    logger.debug("final on_chunk check failed: %s", exc)

            if chunk:
                last_pos += len(chunk)
                try:
                    await notify_log(chunk)
                except Exception as exc:
                    logger.debug("notify_log failed: %s", exc)

        except Exception as e:
            logger.warning(f"Log streaming failed: {e}")

    def _run_streaming_blocking(
        self,
        *,
        command: str,
        tee: FileTeeIO,
        cwd: Optional[str],
        trace: bool,
        echo: bool,
        smcl_path: str,
        smcl_log_name: str,
        hold_attr: str,
        require_smcl_log: bool = False,
    ) -> tuple[int, Optional[Exception]]:
        rc = -1
        exc: Optional[Exception] = None
        with self._exec_lock:
            self._is_executing = True
            try:
                from sfi import Scalar, SFIToolkit  # Import SFI tools
                with self._temp_cwd(cwd):
                    logger.debug(
                        "opening SMCL log name=%s path=%s cwd=%s",
                        smcl_log_name,
                        smcl_path,
                        os.getcwd(),
                    )
                    try:
                        log_opened = self._open_smcl_log(smcl_path, smcl_log_name, quiet=True)
                    except Exception as e:
                        log_opened = False
                        logger.warning("_open_smcl_log raised: %r", e)
                    logger.info("SMCL log_opened=%s path=%s", log_opened, smcl_path)
                    if require_smcl_log and not log_opened:
                        exc = RuntimeError("Failed to open SMCL log")
                        logger.error("SMCL log open failed for %s", smcl_path)
                        rc = 1
                    if exc is None:
                        try:
                            with self._redirect_io_streaming(tee, tee):
                                try:
                                    if trace:
                                        self.stata.run("set trace on")
                                    logger.debug("running Stata command echo=%s: %s", echo, command)
                                    ret = self.stata.run(command, echo=echo)
                                    if ret:
                                        logger.debug("stata.run output: %s", ret)

                                    setattr(self, hold_attr, f"mcp_hold_{uuid.uuid4().hex[:8]}")
                                    self.stata.run(
                                        f"capture _return hold {getattr(self, hold_attr)}",
                                        echo=False,
                                    )

                                    if isinstance(ret, str) and ret:
                                        try:
                                            tee.write(ret)
                                        except Exception:
                                            pass
                                    try:
                                        rc = self._get_rc_from_scalar(Scalar)
                                    except Exception:
                                        pass
                                except Exception as e:
                                    exc = e
                                    logger.error("stata.run failed: %r", e)
                                    if rc in (-1, 0):
                                        rc = 1
                                finally:
                                    if trace:
                                        try:
                                            self.stata.run("set trace off")
                                        except Exception:
                                            pass
                        finally:
                            self._close_smcl_log(smcl_log_name)
                            self._restore_results_from_hold(hold_attr)
                        return rc, exc
                    # If we get here, SMCL log failed and we're required to stop.
                    return rc, exc
            finally:
                self._is_executing = False
            return rc, exc
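
    # The helper above returns (rc, exc): rc is Stata's numeric return code
    # (-1 when it could not be read, coerced to 1 when stata.run raised) and
    # exc is any Python-level exception. When the command actually runs, the
    # finally block closes the per-command SMCL log and restores the held
    # r()-results before returning.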

    def _resolve_do_file_path(
        self,
        path: str,
        cwd: Optional[str],
    ) -> tuple[Optional[str], Optional[str], Optional[CommandResponse]]:
        if cwd is not None and not os.path.isdir(cwd):
            return None, None, CommandResponse(
                command=f'do "{path}"',
                rc=601,
                stdout="",
                stderr=None,
                success=False,
                error=ErrorEnvelope(
                    message=f"cwd not found: {cwd}",
                    rc=601,
                    command=path,
                ),
            )

        effective_path = path
        if cwd is not None and not os.path.isabs(path):
            effective_path = os.path.abspath(os.path.join(cwd, path))

        if not os.path.exists(effective_path):
            return None, None, CommandResponse(
                command=f'do "{effective_path}"',
                rc=601,
                stdout="",
                stderr=None,
                success=False,
                error=ErrorEnvelope(
                    message=f"Do-file not found: {effective_path}",
                    rc=601,
                    command=effective_path,
                ),
            )

        path_for_stata = effective_path.replace("\\", "/")
        command = f'do "{path_for_stata}"'
        return effective_path, command, None

    @contextmanager
    def _smcl_log_capture(self) -> "Generator[Tuple[str, str], None, None]":
        """
        Context manager that wraps command execution in a named SMCL log.

        This runs alongside any user logs (named logs can coexist).
        Yields (log_name, log_path) tuple for use within the context.
        The SMCL file is NOT deleted automatically - caller should clean up.

        Usage:
            with self._smcl_log_capture() as (log_name, smcl_path):
                self.stata.run(cmd)
            # After context, read smcl_path for raw SMCL output
        """
        # Use a unique name but DO NOT join start with mkstemp to avoid existing file locks.
        # Stata will create the file.
        smcl_path = self._create_smcl_log_path()
        # Unique log name to avoid collisions with user logs
        log_name = self._make_smcl_log_name()

        try:
            # Open named SMCL log (quietly to avoid polluting output)
            log_opened = self._open_smcl_log(smcl_path, log_name, quiet=True)
            if not log_opened:
                # Still yield, consumer might see empty file or handle error,
                # but we can't do much if Stata refuses to log.
                pass

            yield log_name, smcl_path
        finally:
            # Always close our named log
            self._close_smcl_log(log_name)

    def _read_smcl_file(self, path: str) -> str:
        """Read SMCL file contents, handling encoding issues and Windows file locks."""
        try:
            with open(path, 'r', encoding='utf-8', errors='replace') as f:
                return f.read()
        except PermissionError:
            if os.name == "nt":
                # Windows Fallback: Try to use 'type' command to bypass exclusive lock
                try:
                    res = subprocess.run(f'type "{path}"', shell=True, capture_output=True)
                    if res.returncode == 0:
                        return res.stdout.decode('utf-8', errors='replace')
                except Exception as e:
                    logger.debug(f"Combined fallback read failed: {e}")
            logger.warning(f"Failed to read SMCL file {path} due to lock")
            return ""
        except Exception as e:
            logger.warning(f"Failed to read SMCL file {path}: {e}")
            return ""

    def _extract_error_from_smcl(self, smcl_content: str, rc: int) -> Tuple[str, str]:
        """
        Extract error message and context from raw SMCL output.

        Uses {err} tags as the authoritative source for error detection.

        Returns:
            Tuple of (error_message, context_string)
        """
        if not smcl_content:
            return f"Stata error r({rc})", ""

        # Try Rust optimization
        native_res = fast_scan_log(smcl_content, rc)
        if native_res:
            error_msg, context, _ = native_res
            return error_msg, context

        lines = smcl_content.splitlines()

        # Search backwards for {err} tags - they indicate error lines
        error_lines = []
        error_start_idx = -1

        for i in range(len(lines) - 1, -1, -1):
            line = lines[i]
            if '{err}' in line:
                if error_start_idx == -1:
                    error_start_idx = i
                # Walk backwards to find consecutive {err} lines
                j = i
                while j >= 0 and '{err}' in lines[j]:
                    error_lines.insert(0, lines[j])
                    j -= 1
                break

        if error_lines:
            # Clean SMCL tags from error message
            clean_lines = []
            for line in error_lines:
                # Remove SMCL tags but keep the text content
                cleaned = re.sub(r'\{[^}]*\}', '', line).strip()
                if cleaned:
                    clean_lines.append(cleaned)

            error_msg = " ".join(clean_lines) or f"Stata error r({rc})"

            # Context is everything from error start to end
            context_start = max(0, error_start_idx - 5)  # Include 5 lines before error
            context = "\n".join(lines[context_start:])

            return error_msg, context

        # Fallback: no {err} found, return last 30 lines as context
        context_start = max(0, len(lines) - 30)
        context = "\n".join(lines[context_start:])

        return f"Stata error r({rc})", context

    def _parse_rc_from_smcl(self, smcl_content: str) -> Optional[int]:
        """Parse return code from SMCL content using specific structural patterns."""
        if not smcl_content:
            return None

        # Try Rust optimization
        native_res = fast_scan_log(smcl_content, 0)
        if native_res:
            _, _, rc = native_res
            if rc is not None:
                return rc

        # 1. Primary check: SMCL search tag {search r(N), ...}
        # This is the most authoritative interactive indicator
        matches = list(re.finditer(r'\{search r\((\d+)\)', smcl_content))
        if matches:
            try:
                return int(matches[-1].group(1))
            except Exception:
                pass

        # 2. Secondary check: Standalone r(N); pattern
        # This appears at the end of command blocks
        matches = list(re.finditer(r'(?<!\w)r\((\d+)\);?', smcl_content))
        if matches:
            try:
                return int(matches[-1].group(1))
            except Exception:
                pass

        return None
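
    # Illustration: content containing "{search r(198)" parses to 198 via the
    # primary pattern, a bare "r(601);" at the end of a block parses to 601
    # via the secondary pattern, and with no match the method returns None.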

    @staticmethod
    def _create_graph_cache_callback(on_graph_cached, notify_log):
        """Create a standardized graph cache callback with proper error handling."""
        async def graph_cache_callback(graph_name: str, success: bool) -> None:
            try:
                if on_graph_cached:
                    await on_graph_cached(graph_name, success)
            except Exception as e:
                logger.error(f"Graph cache callback failed: {e}")

            try:
                # Also notify via log channel
                await notify_log(json.dumps({
                    "event": "graph_cached",
                    "graph": graph_name,
                    "success": success
                }))
            except Exception as e:
                logger.error(f"Failed to notify about graph cache: {e}")

        return graph_cache_callback

    def _get_cached_graph_path(self, graph_name: str) -> Optional[str]:
        if not hasattr(self, "_cache_lock") or not hasattr(self, "_preemptive_cache"):
            return None
        try:
            with self._cache_lock:
                cache_path = self._preemptive_cache.get(graph_name)
                if not cache_path:
                    return None

                # Double-check validity (e.g. signature match for current command)
                if not self._is_cache_valid(graph_name, cache_path):
                    return None

                return cache_path
        except Exception:
            return None

    async def _emit_graph_ready_for_graphs(
        self,
        graph_names: List[str],
        *,
        notify_log: Callable[[str], Awaitable[None]],
        task_id: Optional[str],
        export_format: str,
        graph_ready_initial: Optional[dict[str, str]],
    ) -> None:
        if not graph_names:
            return
        fmt = (export_format or "svg").strip().lower()
        for graph_name in graph_names:
            signature = self._get_graph_signature(graph_name)
            if graph_ready_initial is not None:
                previous = graph_ready_initial.get(graph_name)
                if previous is not None and previous == signature:
                    continue
            try:
                export_path = None
                if fmt == "svg":
                    export_path = self._get_cached_graph_path(graph_name)
                if not export_path:
                    export_path = await anyio.to_thread.run_sync(
                        lambda: self.export_graph(graph_name, format=fmt)
                    )
                payload = {
                    "event": "graph_ready",
                    "task_id": task_id,
                    "graph": {
                        "name": graph_name,
                        "path": export_path,
                        "label": graph_name,
                    },
                }
                await notify_log(json.dumps(payload))
                if graph_ready_initial is not None:
                    graph_ready_initial[graph_name] = signature
            except Exception as e:
                logger.warning("graph_ready export failed for %s: %s", graph_name, e)

    async def _maybe_cache_graphs_on_chunk(
        self,
        *,
        graph_cache: Optional[StreamingGraphCache],
        emit_graph_ready: bool,
        notify_log: Callable[[str], Awaitable[None]],
        graph_ready_task_id: Optional[str],
        graph_ready_format: str,
        graph_ready_initial: Optional[dict[str, str]],
        last_check: List[float],
        force: bool = False,
    ) -> None:
        if not graph_cache or not graph_cache.auto_cache:
            return
        if self._is_executing and not force:
            # Skip polling if Stata is busy; it will block on _exec_lock anyway.
            # During final check (force=True), we know it's safe because _run_streaming_blocking has finished.
            return
        now = time.monotonic()
        if not force and last_check and now - last_check[0] < 0.25:
            return
        if last_check:
            last_check[0] = now
        try:
            cached_names = await graph_cache.cache_detected_graphs_with_pystata()
        except Exception as e:
            logger.debug("graph_ready polling failed: %s", e)
            return
        if emit_graph_ready and cached_names:
            await self._emit_graph_ready_for_graphs(
                cached_names,
                notify_log=notify_log,
                task_id=graph_ready_task_id,
                export_format=graph_ready_format,
                graph_ready_initial=graph_ready_initial,
            )

    async def _emit_graph_ready_events(
        self,
        initial_graphs: dict[str, str],
        notify_log: Callable[[str], Awaitable[None]],
        task_id: Optional[str],
        export_format: str,
    ) -> None:
        try:
            current_graphs = list(self.list_graphs(force_refresh=True))
        except Exception as e:
            logger.warning("graph_ready: list_graphs failed: %s", e)
            return

        if not current_graphs:
            return

        for graph_name in current_graphs:
            signature = self._get_graph_signature(graph_name)
            previous = initial_graphs.get(graph_name)
            if previous is not None and previous == signature:
                continue
            try:
                export_path = None
                if export_format == "svg":
                    export_path = self._get_cached_graph_path(graph_name)

                if not export_path:
                    export_path = await anyio.to_thread.run_sync(
                        lambda: self.export_graph(graph_name, format=export_format)
                    )
                payload = {
                    "event": "graph_ready",
                    "task_id": task_id,
                    "graph": {
                        "name": graph_name,
                        "path": export_path,
                        "label": graph_name,
                    },
                }
                await notify_log(json.dumps(payload))
                initial_graphs[graph_name] = signature
            except Exception as e:
                logger.warning("graph_ready export failed for %s: %s", graph_name, e)

    def _get_graph_signature(self, graph_name: str) -> str:
        """
        Get a stable signature for a graph without calling Stata.
        Consistent with GraphCreationDetector implementation.
        """
        if not graph_name:
            return ""
        cmd_idx = getattr(self, "_command_idx", 0)
        # Only include command index for default 'Graph' to detect modifications.
        # For named graphs, we only want to detect them when they are new or renamed.
        if graph_name.lower() == "graph":
            return f"{graph_name}_{cmd_idx}"
        return graph_name
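
    # Example: with _command_idx == 7 the unnamed default window signs as
    # "Graph_7" (so re-running a graph command registers as a modification),
    # while a named graph such as "scatter1" signs simply as "scatter1".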

    def _request_break_in(self) -> None:
        """
        Attempt to interrupt a running Stata command when cancellation is requested.

        Uses the Stata sfi.breakIn hook when available; errors are swallowed because
        cancellation should never crash the host process.
        """
        try:
            import sfi  # type: ignore[import-not-found]

            break_fn = getattr(sfi, "breakIn", None) or getattr(sfi, "break_in", None)
            if callable(break_fn):
                try:
                    break_fn()
                    logger.info("Sent breakIn() to Stata for cancellation")
                except Exception as e:  # pragma: no cover - best-effort
                    logger.warning(f"Failed to send breakIn() to Stata: {e}")
            else:  # pragma: no cover - environment without Stata runtime
                logger.debug("sfi.breakIn not available; cannot interrupt Stata")
        except Exception as e:  # pragma: no cover - import failure or other
            logger.debug(f"Unable to import sfi for cancellation: {e}")

    async def _wait_for_stata_stop(self, timeout: float = 2.0) -> bool:
        """
        After requesting a break, poll the Stata interface so it can surface BreakError
        and return control. This is best-effort and time-bounded.
        """
        deadline = time.monotonic() + timeout
        try:
            import sfi  # type: ignore[import-not-found]

            toolkit = getattr(sfi, "SFIToolkit", None)
            poll = getattr(toolkit, "pollnow", None) or getattr(toolkit, "pollstd", None)
            BreakError = getattr(sfi, "BreakError", None)
        except Exception:  # pragma: no cover
            return False

        if not callable(poll):
            return False

        last_exc: Optional[Exception] = None
        while time.monotonic() < deadline:
            try:
                poll()
            except Exception as e:  # pragma: no cover - depends on Stata runtime
                last_exc = e
                if BreakError is not None and isinstance(e, BreakError):
                    logger.info("Stata BreakError detected; cancellation acknowledged by Stata")
                    return True
                # If Stata already stopped, break on any other exception.
                break
            await anyio.sleep(0.05)

        if last_exc:
            logger.debug(f"Cancellation poll exited with {last_exc}")
        return False

    @contextmanager
    def _temp_cwd(self, cwd: Optional[str]):
        if cwd is None:
            yield
            return
        prev = os.getcwd()
        os.chdir(cwd)
        try:
            yield
        finally:
            os.chdir(prev)

    @contextmanager
    def _safe_redirect_fds(self):
        """Redirects fd 1 (stdout) to fd 2 (stderr) at the OS level."""
        # Save original stdout fd
        try:
            stdout_fd = os.dup(1)
        except Exception:
            # Fallback if we can't dup (e.g. strange environment)
            yield
            return

        try:
            # Redirect OS-level stdout to stderr
            os.dup2(2, 1)
            yield
        finally:
            # Restore stdout
            try:
                os.dup2(stdout_fd, 1)
                os.close(stdout_fd)
            except Exception:
                pass

    def init(self):
        """Initializes usage of pystata using cached discovery results."""
        if self._initialized:
            return

        # Suppress any non-UTF8 banner output from PyStata on stdout, which breaks MCP stdio transport
        from contextlib import redirect_stdout, redirect_stderr

        try:
            import stata_setup

            # Get discovered Stata paths (cached from first call)
            discovery_candidates = _get_discovery_candidates()
            if not discovery_candidates:
                raise RuntimeError("No Stata candidates found during discovery")

            logger.info("Initializing Stata engine (attempting up to %d candidate binaries)...", len(discovery_candidates))

            # Diagnostic: force faulthandler to output to stderr for C crashes
            import faulthandler
            faulthandler.enable(file=sys.stderr)
            import subprocess

            success = False
            last_error = None
            chosen_exec: Optional[Tuple[str, str]] = None

            for stata_exec_path, edition in discovery_candidates:
                candidates = []
                # Prefer the binary directory first (documented input for stata_setup)
                bin_dir = os.path.dirname(stata_exec_path)

                # 2. App Bundle: .../StataMP.app (macOS only)
                curr = bin_dir
                app_bundle = None
                while len(curr) > 1:
                    if curr.endswith(".app"):
                        app_bundle = curr
                        break
                    parent = os.path.dirname(curr)
                    if parent == curr:
                        break
                    curr = parent

                ordered_candidates = []
                if app_bundle:
                    # On macOS, the parent of the .app is often the correct install path
                    # (e.g., /Applications/StataNow containing StataMP.app)
                    parent_dir = os.path.dirname(app_bundle)
                    if parent_dir and parent_dir != "/":
                        ordered_candidates.append(parent_dir)
                    ordered_candidates.append(app_bundle)

                if bin_dir:
                    ordered_candidates.append(bin_dir)

                # Deduplicate preserving order
                seen = set()
                candidates = []
                for c in ordered_candidates:
                    if c not in seen:
                        seen.add(c)
                        candidates.append(c)

                for path in candidates:
                    try:
                        # 1. Pre-flight check in a subprocess to capture hard exits/crashes
                        sys.stderr.write(f"[mcp_stata] DEBUG: Pre-flight check for path '{path}'\n")
                        sys.stderr.flush()

                        preflight_code = f"""
import sys
import stata_setup
from contextlib import redirect_stdout, redirect_stderr
with redirect_stdout(sys.stderr), redirect_stderr(sys.stderr):
    try:
        stata_setup.config({repr(path)}, {repr(edition)})
        from pystata import stata
        # Minimal verification of engine health
        stata.run('display 1', echo=False)
        print('PREFLIGHT_OK')
    except Exception as e:
        print(f'PREFLIGHT_FAIL: {{e}}', file=sys.stderr)
        sys.exit(1)
"""

                        try:
                            # Use shorter timeout for pre-flight if feasible,
                            # but keep it safe for slow environments. 15s is usually enough for a ping.
                            res = subprocess.run(
                                [sys.executable, "-c", preflight_code],
                                capture_output=True, text=True, timeout=20
                            )
                            if res.returncode != 0:
                                sys.stderr.write(f"[mcp_stata] Pre-flight failed (rc={res.returncode}) for '{path}'\n")
                                if res.stdout.strip():
                                    sys.stderr.write(f"--- Pre-flight stdout ---\n{res.stdout.strip()}\n")
                                if res.stderr.strip():
                                    sys.stderr.write(f"--- Pre-flight stderr ---\n{res.stderr.strip()}\n")
                                sys.stderr.flush()
                                last_error = f"Pre-flight failed: {res.stdout.strip()} {res.stderr.strip()}"
                                continue
                            else:
                                sys.stderr.write(f"[mcp_stata] Pre-flight succeeded for '{path}'. Proceeding to in-process init.\n")
                                sys.stderr.flush()
                        except Exception as pre_e:
                            sys.stderr.write(f"[mcp_stata] Pre-flight execution error for '{path}': {repr(pre_e)}\n")
                            sys.stderr.flush()
                            last_error = pre_e
                            continue

                        msg = f"[mcp_stata] DEBUG: In-process stata_setup.config('{path}', '{edition}')\n"
                        sys.stderr.write(msg)
                        sys.stderr.flush()
                        # Redirect both sys.stdout/err AND the raw fds to our stderr pipe.
                        with redirect_stdout(sys.stderr), redirect_stderr(sys.stderr), self._safe_redirect_fds():
                            stata_setup.config(path, edition)

                        sys.stderr.write(f"[mcp_stata] DEBUG: stata_setup.config succeeded for path: {path}\n")
                        sys.stderr.flush()
                        success = True
                        chosen_exec = (stata_exec_path, edition)
                        logger.info("stata_setup.config succeeded with path: %s", path)
                        break
                    except BaseException as e:
                        last_error = e
                        sys.stderr.write(f"[mcp_stata] WARNING: In-process stata_setup.config caught: {repr(e)}\n")
                        sys.stderr.flush()
                        logger.warning("stata_setup.config failed for path '%s': %s", path, e)
                        if isinstance(e, SystemExit):
                            break
                        continue

                if success:
                    # Cache winning candidate for subsequent lookups
                    global _discovery_result
                    if chosen_exec:
                        _discovery_result = chosen_exec
                    break

            if not success:
                error_msg = (
                    f"stata_setup.config failed to initialize Stata. "
                    f"Tried candidates: {discovery_candidates}. "
                    f"Last error: {repr(last_error)}"
                )
                sys.stderr.write(f"[mcp_stata] ERROR: {error_msg}\n")
                sys.stderr.flush()
                logger.error(error_msg)
                raise RuntimeError(error_msg)

            # Cache the binary path for later use (e.g., PNG export on Windows)
            self._stata_exec_path = os.path.abspath(stata_exec_path)

            try:
                sys.stderr.write("[mcp_stata] DEBUG: Importing pystata and warming up...\n")
                sys.stderr.flush()
                with redirect_stdout(sys.stderr), redirect_stderr(sys.stderr), self._safe_redirect_fds():
                    from pystata import stata  # type: ignore[import-not-found]
                    # Warm up the engine and swallow any late splash screen output
                    stata.run("display 1", echo=False)
                self.stata = stata
                self._initialized = True
                sys.stderr.write("[mcp_stata] DEBUG: pystata warmed up successfully\n")
                sys.stderr.flush()
            except BaseException as e:
                sys.stderr.write(f"[mcp_stata] ERROR: Failed to load pystata or run initial command: {repr(e)}\n")
                sys.stderr.flush()
                logger.error("Failed to load pystata or run initial command: %s", e)
                raise

            # Initialize list_graphs TTL cache
            self._list_graphs_cache = None
            self._list_graphs_cache_time = 0
            self._list_graphs_cache_lock = threading.Lock()

            # Map user-facing graph names (may include spaces/punctuation) to valid
            # internal Stata graph names.
            self._graph_name_aliases: Dict[str, str] = {}
            self._graph_name_reverse: Dict[str, str] = {}

            logger.info("StataClient initialized successfully with %s (%s)", stata_exec_path, edition)

        except ImportError as e:
            raise RuntimeError(
                f"Failed to import stata_setup or pystata: {e}. "
                "Ensure they are installed (pip install pystata stata-setup)."
            ) from e

    def _make_valid_stata_name(self, name: str) -> str:
        """Create a valid Stata name (<=32 chars, [A-Za-z_][A-Za-z0-9_]*)."""
        base = re.sub(r"[^A-Za-z0-9_]", "_", name or "")
        if not base:
            base = "Graph"
        if not re.match(r"^[A-Za-z_]", base):
            base = f"G_{base}"
        base = base[:32]

        # Avoid collisions.
        candidate = base
        i = 1
        while candidate in getattr(self, "_graph_name_reverse", {}):
            suffix = f"_{i}"
            candidate = (base[: max(0, 32 - len(suffix))] + suffix)[:32]
            i += 1
        return candidate
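
    # Example: "My Plot (2024)!" sanitizes to "My_Plot__2024__", while a name
    # starting with a digit such as "2024 plot" becomes "G_2024_plot"; results
    # are capped at 32 characters and suffixed _1, _2, ... on collision.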

    def _resolve_graph_name_for_stata(self, name: str) -> str:
        """Return internal Stata graph name for a user-facing name."""
        if not name:
            return name
        aliases = getattr(self, "_graph_name_aliases", None)
        if aliases and name in aliases:
            return aliases[name]
        return name

    def _maybe_rewrite_graph_name_in_command(self, code: str) -> str:
        """Rewrite name("...") to a valid Stata name and store alias mapping."""
        if not code:
            return code
        if not hasattr(self, "_graph_name_aliases"):
            self._graph_name_aliases = {}
            self._graph_name_reverse = {}

        # Handle common patterns: name("..." ...) or name(`"..."' ...)
        pat = re.compile(r"name\(\s*(?:`\"(?P<cq>[^\"]*)\"'|\"(?P<dq>[^\"]*)\")\s*(?P<rest>[^)]*)\)")

        def repl(m: re.Match) -> str:
            original = m.group("cq") if m.group("cq") is not None else m.group("dq")
            original = original or ""
            internal = self._graph_name_aliases.get(original)
            if not internal:
                internal = self._make_valid_stata_name(original)
                self._graph_name_aliases[original] = internal
                self._graph_name_reverse[internal] = original
            rest = m.group("rest") or ""
            return f"name({internal}{rest})"

        return pat.sub(repl, code)
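
    # Example: the command  scatter y x, name("My Plot", replace)  is rewritten
    # to  scatter y x, name(My_Plot, replace)  and the alias {"My Plot":
    # "My_Plot"} is remembered so later lookups by the user-facing name resolve
    # to the internal Stata name.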
|
|
1341
|
+
|
|
1342
|
+
    def _get_rc_from_scalar(self, Scalar) -> int:
        """Safely get return code, handling None values."""
        try:
            from sfi import Macro
            rc_val = Macro.getGlobal("_rc")
            if rc_val is None:
                return -1
            return int(float(rc_val))
        except Exception:
            return -1

    def _parse_rc_from_text(self, text: str) -> Optional[int]:
        """Parse return code from plain text using structural patterns."""
        if not text:
            return None

        # 1. Primary check: 'search r(N)' pattern (SMCL tag potentially stripped)
        matches = list(re.finditer(r'search r\((\d+)\)', text))
        if matches:
            try:
                return int(matches[-1].group(1))
            except Exception:
                pass

        # 2. Secondary check: Standalone r(N); pattern
        # This appears at the end of command blocks
        matches = list(re.finditer(r'(?<!\w)r\((\d+)\);?', text))
        if matches:
            try:
                return int(matches[-1].group(1))
            except Exception:
                pass

        return None

    def _parse_line_from_text(self, text: str) -> Optional[int]:
        match = re.search(r"line\s+(\d+)", text, re.IGNORECASE)
        if match:
            try:
                return int(match.group(1))
            except Exception:
                return None
        return None

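    # Illustrative sketch of the return-code parsing above, on hypothetical log
    # fragments (the texts are examples only):
    #   _parse_rc_from_text("last estimates not found\nr(301);")   -> 301
    #   _parse_rc_from_text("variable foo not found\nsearch r(111)") -> 111
    #   _parse_rc_from_text("clean run, no error marker")            -> None
    # The last r(N) occurrence wins, so a trailing error code overrides earlier ones.
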
    def _read_log_backwards_until_error(self, path: str, max_bytes: int = 5_000_000) -> str:
        """
        Read log file backwards in chunks, stopping when we find {err} tags or reach the start.

        This is more efficient and robust than reading huge fixed tails, as we only read
        what we need to find the error.

        Args:
            path: Path to the log file
            max_bytes: Maximum total bytes to read (safety limit, default 5MB)

        Returns:
            The relevant portion of the log containing the error and context
        """
        try:
            chunk_size = 50_000  # Read 50KB chunks at a time
            total_read = 0
            chunks = []

            with open(path, 'rb') as f:
                # Get file size
                f.seek(0, os.SEEK_END)
                file_size = f.tell()

                if file_size == 0:
                    return ""

                # Start from the end
                position = file_size

                while position > 0 and total_read < max_bytes:
                    # Calculate how much to read in this chunk
                    read_size = min(chunk_size, position, max_bytes - total_read)
                    position -= read_size

                    # Seek and read
                    f.seek(position)
                    chunk = f.read(read_size)
                    chunks.insert(0, chunk)
                    total_read += read_size

                    # Decode and check for error tags
                    try:
                        accumulated = b''.join(chunks).decode('utf-8', errors='replace')

                        # Check if we've found an error tag
                        if '{err}' in accumulated:
                            # Found it! Read one more chunk for context before the error
                            if position > 0 and total_read < max_bytes:
                                extra_read = min(chunk_size, position, max_bytes - total_read)
                                position -= extra_read
                                f.seek(position)
                                extra_chunk = f.read(extra_read)
                                chunks.insert(0, extra_chunk)

                            return b''.join(chunks).decode('utf-8', errors='replace')

                    except UnicodeDecodeError:
                        # Continue reading if we hit a decode error (might be mid-character)
                        continue

            # Read everything we've accumulated
            return b''.join(chunks).decode('utf-8', errors='replace')

        except Exception as e:
            logger.warning(f"Error reading log backwards: {e}")
            # Fallback to regular tail read
            return self._read_log_tail(path, 200_000)

    def _read_log_tail_smart(self, path: str, rc: int, trace: bool = False) -> str:
        """
        Smart log tail reader that adapts based on whether an error occurred.

        - If rc == 0: Read normal tail (20KB without trace, 200KB with trace)
        - If rc != 0: Search backwards dynamically to find the error

        Args:
            path: Path to the log file
            rc: Return code from Stata
            trace: Whether trace mode was enabled

        Returns:
            Relevant log content
        """
        if rc != 0:
            # Error occurred - search backwards for {err} tags
            return self._read_log_backwards_until_error(path)
        else:
            # Success - just read normal tail
            tail_size = 200_000 if trace else 20_000
            return self._read_log_tail(path, tail_size)

    def _read_log_tail(self, path: str, max_chars: int) -> str:
        try:
            with open(path, "rb") as f:
                f.seek(0, os.SEEK_END)
                size = f.tell()

                if size <= 0:
                    return ""
                read_size = min(size, max_chars)
                f.seek(-read_size, os.SEEK_END)
                data = f.read(read_size)
                return data.decode("utf-8", errors="replace")
        except Exception:
            return ""

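    # Reading strategy of the helpers above, summarised: on failure the log is
    # scanned backwards in 50KB chunks until an {err} tag (or the 5MB cap) is
    # reached, plus one extra chunk of leading context; on success only a fixed
    # tail is read (roughly 20KB, or 200KB when trace is on). The figures simply
    # restate the defaults defined above.
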
    def _build_combined_log(
        self,
        tail: TailBuffer,
        path: str,
        rc: int,
        trace: bool,
        exc: Optional[Exception],
    ) -> str:
        tail_text = tail.get_value()
        log_tail = self._read_log_tail_smart(path, rc, trace)
        if log_tail and len(log_tail) > len(tail_text):
            tail_text = log_tail
        return (tail_text or "") + (f"\n{exc}" if exc else "")

    def _truncate_command_output(
        self,
        result: CommandResponse,
        max_output_lines: Optional[int],
    ) -> CommandResponse:
        if max_output_lines is None or not result.stdout:
            return result
        lines = result.stdout.splitlines()
        if len(lines) <= max_output_lines:
            return result
        truncated_lines = lines[:max_output_lines]
        truncated_lines.append(
            f"\n... (output truncated: showing {max_output_lines} of {len(lines)} lines)"
        )
        truncated_stdout = "\n".join(truncated_lines)
        if hasattr(result, "model_copy"):
            return result.model_copy(update={"stdout": truncated_stdout})
        return result.copy(update={"stdout": truncated_stdout})

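    # Illustrative sketch of the truncation above (hypothetical counts): with
    # max_output_lines=50 and a 500-line stdout, the returned copy keeps the first
    # 50 lines plus a final "... (output truncated: showing 50 of 500 lines)"
    # marker. model_copy() is used when the response model exposes it (pydantic v2),
    # with copy() as the fallback.
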
    def _run_plain_capture(self, code: str) -> str:
        """
        Run a Stata command while capturing output using a named SMCL log.
        This is the most reliable way to capture output (like return list)
        without interfering with user logs or being affected by stdout redirection issues.
        """
        if not self._initialized:
            self.init()

        with self._exec_lock:
            hold_name = f"mcp_hold_{uuid.uuid4().hex[:8]}"
            # Hold results BEFORE opening the capture log
            self.stata.run(f"capture _return hold {hold_name}", echo=False)

            try:
                with self._smcl_log_capture() as (log_name, smcl_path):
                    # Restore results INSIDE the capture log so return list can see them
                    self.stata.run(f"capture _return restore {hold_name}", echo=False)
                    try:
                        self.stata.run(code, echo=True)
                    except Exception:
                        pass
            except Exception:
                # Cleanup hold if log capture failed to open
                self.stata.run(f"capture _return drop {hold_name}", echo=False)
                content = ""
                smcl_path = None
            else:
                # Read SMCL content and convert to text
                content = self._read_smcl_file(smcl_path)
                # Remove the temp file
                self._safe_unlink(smcl_path)

        return self._smcl_to_text(content)

    def _count_do_file_lines(self, path: str) -> int:
        """
        Count the number of executable lines in a .do file for progress inference.

        Blank lines and comment-only lines (starting with * or //) are ignored.
        """
        try:
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                lines = f.read().splitlines()
        except Exception:
            return 0

        total = 0
        for line in lines:
            s = line.strip()
            if not s:
                continue
            if s.startswith("*"):
                continue
            if s.startswith("//"):
                continue
            total += 1
        return total

    def _smcl_to_text(self, smcl: str) -> str:
        """Convert simple SMCL markup into plain text for LLM-friendly help."""
        # First, keep inline directive content if present (e.g., {bf:word} -> word)
        cleaned = re.sub(r"\{[^}:]+:([^}]*)\}", r"\1", smcl)
        # Remove remaining SMCL brace commands like {smcl}, {vieweralsosee ...}, {txt}, {p}
        cleaned = re.sub(r"\{[^}]*\}", "", cleaned)
        # Normalize whitespace
        cleaned = cleaned.replace("\r", "")
        lines = [line.rstrip() for line in cleaned.splitlines()]
        return "\n".join(lines).strip()

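    # Illustrative sketch of _smcl_to_text on a hypothetical SMCL fragment:
    #   "{smcl}\n{txt}see {bf:regress} for details\n"
    # becomes
    #   "see regress for details"
    # ({bf:regress} keeps its inner text, bare directives like {smcl} and {txt}
    # are dropped, and surrounding whitespace is stripped).
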
    def _extract_error_and_context(self, log_content: str, rc: int) -> Tuple[str, str]:
        """
        Extracts the error message and trace context using {err} SMCL tags.
        """
        if not log_content:
            return f"Stata error r({rc})", ""

        lines = log_content.splitlines()

        # Search backwards for the {err} tag
        for i in range(len(lines) - 1, -1, -1):
            line = lines[i]
            if '{err}' in line:
                # Found the (last) error line.
                # Walk backwards to find the start of the error block (consecutive {err} lines)
                start_idx = i
                while start_idx > 0 and '{err}' in lines[start_idx-1]:
                    start_idx -= 1

                # The full error message is the concatenation of all {err} lines in this block
                error_lines = []
                for j in range(start_idx, i + 1):
                    error_lines.append(lines[j].strip())

                clean_msg = " ".join(filter(None, error_lines)) or f"Stata error r({rc})"

                # Capture everything from the start of the error block to the end
                context_str = "\n".join(lines[start_idx:])
                return clean_msg, context_str

        # Fallback: grab the last 30 lines
        context_start = max(0, len(lines) - 30)
        context_str = "\n".join(lines[context_start:])

        return f"Stata error r({rc})", context_str

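    # Illustrative sketch of the extraction above on a hypothetical log tail:
    #   ". summarize foo"
    #   "{err}variable foo not found"
    #   "r(111);"
    # The returned message is the last block of consecutive {err} lines joined
    # together (here "{err}variable foo not found"; the tag itself is not removed),
    # and the context is everything from that block to the end of the log.
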
def _exec_with_capture(self, code: str, echo: bool = True, trace: bool = False, cwd: Optional[str] = None) -> CommandResponse:
|
|
1633
|
+
if not self._initialized:
|
|
1634
|
+
self.init()
|
|
1635
|
+
|
|
1636
|
+
self._increment_command_idx()
|
|
1637
|
+
# Rewrite graph names with special characters to internal aliases
|
|
1638
|
+
code = self._maybe_rewrite_graph_name_in_command(code)
|
|
1639
|
+
|
|
1640
|
+
output_buffer = StringIO()
|
|
1641
|
+
error_buffer = StringIO()
|
|
1642
|
+
rc = 0
|
|
1643
|
+
sys_error = None
|
|
1644
|
+
error_envelope = None
|
|
1645
|
+
smcl_content = ""
|
|
1646
|
+
smcl_path = None
|
|
1647
|
+
|
|
1648
|
+
with self._exec_lock:
|
|
1649
|
+
try:
|
|
1650
|
+
from sfi import Scalar, SFIToolkit
|
|
1651
|
+
with self._temp_cwd(cwd):
|
|
1652
|
+
# Create SMCL log for authoritative output capture
|
|
1653
|
+
# Use shorter unique path to avoid Windows path issues
|
|
1654
|
+
smcl_path = self._create_smcl_log_path(prefix="mcp_", max_hex=16, base_dir=cwd)
|
|
1655
|
+
log_name = self._make_smcl_log_name()
|
|
1656
|
+
self._open_smcl_log(smcl_path, log_name)
|
|
1657
|
+
|
|
1658
|
+
try:
|
|
1659
|
+
with self._redirect_io(output_buffer, error_buffer):
|
|
1660
|
+
try:
|
|
1661
|
+
if trace:
|
|
1662
|
+
self.stata.run("set trace on")
|
|
1663
|
+
|
|
1664
|
+
# Run the user code
|
|
1665
|
+
self.stata.run(code, echo=echo)
|
|
1666
|
+
|
|
1667
|
+
# Hold results IMMEDIATELY to prevent clobbering by cleanup
|
|
1668
|
+
self._hold_name = f"mcp_hold_{uuid.uuid4().hex[:8]}"
|
|
1669
|
+
self.stata.run(f"capture _return hold {self._hold_name}", echo=False)
|
|
1670
|
+
|
|
1671
|
+
finally:
|
|
1672
|
+
if trace:
|
|
1673
|
+
try:
|
|
1674
|
+
self.stata.run("set trace off")
|
|
1675
|
+
except Exception:
|
|
1676
|
+
pass
|
|
1677
|
+
finally:
|
|
1678
|
+
# Close SMCL log AFTER output redirection
|
|
1679
|
+
self._close_smcl_log(log_name)
|
|
1680
|
+
# Restore and capture results while still inside the lock
|
|
1681
|
+
self._restore_results_from_hold("_hold_name")
|
|
1682
|
+
|
|
1683
|
+
except Exception as e:
|
|
1684
|
+
sys_error = str(e)
|
|
1685
|
+
# Try to parse RC from exception message
|
|
1686
|
+
parsed_rc = self._parse_rc_from_text(sys_error)
|
|
1687
|
+
rc = parsed_rc if parsed_rc is not None else 1
|
|
1688
|
+
|
|
1689
|
+
# Read SMCL content as the authoritative source
|
|
1690
|
+
if smcl_path:
|
|
1691
|
+
smcl_content = self._read_smcl_file(smcl_path)
|
|
1692
|
+
# Clean up SMCL file
|
|
1693
|
+
self._safe_unlink(smcl_path)
|
|
1694
|
+
|
|
1695
|
+
stdout_content = output_buffer.getvalue()
|
|
1696
|
+
stderr_content = error_buffer.getvalue()
|
|
1697
|
+
|
|
1698
|
+
# If RC wasn't captured or is generic, try to parse from SMCL
|
|
1699
|
+
if rc in (0, 1, -1) and smcl_content:
|
|
1700
|
+
parsed_rc = self._parse_rc_from_smcl(smcl_content)
|
|
1701
|
+
if parsed_rc is not None and parsed_rc != 0:
|
|
1702
|
+
rc = parsed_rc
|
|
1703
|
+
elif rc == -1:
|
|
1704
|
+
rc = 0
|
|
1705
|
+
|
|
1706
|
+
# If stdout is empty but SMCL has content AND command succeeded, use SMCL as stdout
|
|
1707
|
+
# This handles cases where Stata writes to log but not to redirected stdout
|
|
1708
|
+
# For errors, we keep stdout empty and error info goes to ErrorEnvelope
|
|
1709
|
+
if rc == 0 and not stdout_content and smcl_content:
|
|
1710
|
+
# Convert SMCL to plain text for stdout
|
|
1711
|
+
stdout_content = self._smcl_to_text(smcl_content)
|
|
1712
|
+
|
|
1713
|
+
if rc != 0:
|
|
1714
|
+
if sys_error:
|
|
1715
|
+
msg = sys_error
|
|
1716
|
+
context = sys_error
|
|
1717
|
+
else:
|
|
1718
|
+
# Extract error from SMCL (authoritative source)
|
|
1719
|
+
msg, context = self._extract_error_from_smcl(smcl_content, rc)
|
|
1720
|
+
|
|
1721
|
+
error_envelope = ErrorEnvelope(
|
|
1722
|
+
message=msg,
|
|
1723
|
+
rc=rc,
|
|
1724
|
+
context=context,
|
|
1725
|
+
snippet=smcl_content[-800:] if smcl_content else (stdout_content + stderr_content)[-800:],
|
|
1726
|
+
smcl_output=smcl_content # Include raw SMCL for debugging
|
|
1727
|
+
)
|
|
1728
|
+
stderr_content = context
|
|
1729
|
+
|
|
1730
|
+
resp = CommandResponse(
|
|
1731
|
+
command=code,
|
|
1732
|
+
rc=rc,
|
|
1733
|
+
stdout=stdout_content,
|
|
1734
|
+
stderr=stderr_content,
|
|
1735
|
+
success=(rc == 0),
|
|
1736
|
+
error=error_envelope,
|
|
1737
|
+
log_path=smcl_path if smcl_path else None,
|
|
1738
|
+
smcl_output=smcl_content,
|
|
1739
|
+
)
|
|
1740
|
+
|
|
1741
|
+
# Capture results immediately after execution, INSIDE the lock
|
|
1742
|
+
try:
|
|
1743
|
+
self._last_results = self.get_stored_results(force_fresh=True)
|
|
1744
|
+
except Exception:
|
|
1745
|
+
self._last_results = None
|
|
1746
|
+
|
|
1747
|
+
return resp
|
|
1748
|
+
|
|
1749
|
+
def _exec_no_capture(self, code: str, echo: bool = False, trace: bool = False) -> CommandResponse:
|
|
1750
|
+
"""Execute Stata code while leaving stdout/stderr alone."""
|
|
1751
|
+
if not self._initialized:
|
|
1752
|
+
self.init()
|
|
1753
|
+
|
|
1754
|
+
exc: Optional[Exception] = None
|
|
1755
|
+
ret_text: Optional[str] = None
|
|
1756
|
+
rc = 0
|
|
1757
|
+
|
|
1758
|
+
with self._exec_lock:
|
|
1759
|
+
try:
|
|
1760
|
+
from sfi import Scalar # Import SFI tools
|
|
1761
|
+
if trace:
|
|
1762
|
+
self.stata.run("set trace on")
|
|
1763
|
+
ret = self.stata.run(code, echo=echo)
|
|
1764
|
+
if isinstance(ret, str) and ret:
|
|
1765
|
+
ret_text = ret
|
|
1766
|
+
|
|
1767
|
+
|
|
1768
|
+
except Exception as e:
|
|
1769
|
+
exc = e
|
|
1770
|
+
rc = 1
|
|
1771
|
+
finally:
|
|
1772
|
+
if trace:
|
|
1773
|
+
try:
|
|
1774
|
+
self.stata.run("set trace off")
|
|
1775
|
+
except Exception as e:
|
|
1776
|
+
logger.warning("Failed to turn off Stata trace mode: %s", e)
|
|
1777
|
+
|
|
1778
|
+
stdout = ""
|
|
1779
|
+
stderr = ""
|
|
1780
|
+
success = rc == 0 and exc is None
|
|
1781
|
+
error = None
|
|
1782
|
+
if not success:
|
|
1783
|
+
msg = str(exc) if exc else f"Stata error r({rc})"
|
|
1784
|
+
error = ErrorEnvelope(
|
|
1785
|
+
message=msg,
|
|
1786
|
+
rc=rc,
|
|
1787
|
+
command=code,
|
|
1788
|
+
stdout=ret_text,
|
|
1789
|
+
)
|
|
1790
|
+
|
|
1791
|
+
return CommandResponse(
|
|
1792
|
+
command=code,
|
|
1793
|
+
rc=rc,
|
|
1794
|
+
stdout=stdout,
|
|
1795
|
+
stderr=None,
|
|
1796
|
+
success=success,
|
|
1797
|
+
error=error,
|
|
1798
|
+
)
|
|
1799
|
+
|
|
1800
|
+
def _exec_no_capture_silent(self, code: str, echo: bool = False, trace: bool = False) -> CommandResponse:
|
|
1801
|
+
"""Execute Stata code while suppressing stdout/stderr output."""
|
|
1802
|
+
if not self._initialized:
|
|
1803
|
+
self.init()
|
|
1804
|
+
|
|
1805
|
+
exc: Optional[Exception] = None
|
|
1806
|
+
ret_text: Optional[str] = None
|
|
1807
|
+
rc = 0
|
|
1808
|
+
|
|
1809
|
+
with self._exec_lock:
|
|
1810
|
+
try:
|
|
1811
|
+
from sfi import Scalar # Import SFI tools
|
|
1812
|
+
if trace:
|
|
1813
|
+
self.stata.run("set trace on")
|
|
1814
|
+
output_buf = StringIO()
|
|
1815
|
+
with redirect_stdout(output_buf), redirect_stderr(output_buf):
|
|
1816
|
+
ret = self.stata.run(code, echo=echo)
|
|
1817
|
+
if isinstance(ret, str) and ret:
|
|
1818
|
+
ret_text = ret
|
|
1819
|
+
except Exception as e:
|
|
1820
|
+
exc = e
|
|
1821
|
+
rc = 1
|
|
1822
|
+
finally:
|
|
1823
|
+
if trace:
|
|
1824
|
+
try:
|
|
1825
|
+
self.stata.run("set trace off")
|
|
1826
|
+
except Exception as e:
|
|
1827
|
+
logger.warning("Failed to turn off Stata trace mode: %s", e)
|
|
1828
|
+
|
|
1829
|
+
stdout = ""
|
|
1830
|
+
stderr = ""
|
|
1831
|
+
success = rc == 0 and exc is None
|
|
1832
|
+
error = None
|
|
1833
|
+
if not success:
|
|
1834
|
+
msg = str(exc) if exc else f"Stata error r({rc})"
|
|
1835
|
+
error = ErrorEnvelope(
|
|
1836
|
+
message=msg,
|
|
1837
|
+
rc=rc,
|
|
1838
|
+
command=code,
|
|
1839
|
+
stdout=ret_text,
|
|
1840
|
+
)
|
|
1841
|
+
|
|
1842
|
+
return CommandResponse(
|
|
1843
|
+
command=code,
|
|
1844
|
+
rc=rc,
|
|
1845
|
+
stdout=stdout,
|
|
1846
|
+
stderr=None,
|
|
1847
|
+
success=success,
|
|
1848
|
+
error=error,
|
|
1849
|
+
)
|
|
1850
|
+
|
|
1851
|
+
def exec_lightweight(self, code: str) -> CommandResponse:
|
|
1852
|
+
"""
|
|
1853
|
+
Executes a command using simple stdout redirection (no SMCL logs).
|
|
1854
|
+
Much faster on Windows as it avoids FS operations.
|
|
1855
|
+
LIMITED: Does not support error envelopes or complex return code parsing.
|
|
1856
|
+
"""
|
|
1857
|
+
if not self._initialized:
|
|
1858
|
+
self.init()
|
|
1859
|
+
|
|
1860
|
+
code = self._maybe_rewrite_graph_name_in_command(code)
|
|
1861
|
+
|
|
1862
|
+
output_buffer = StringIO()
|
|
1863
|
+
error_buffer = StringIO()
|
|
1864
|
+
rc = 0
|
|
1865
|
+
exc = None
|
|
1866
|
+
|
|
1867
|
+
with self._exec_lock:
|
|
1868
|
+
with self._redirect_io(output_buffer, error_buffer):
|
|
1869
|
+
try:
|
|
1870
|
+
self.stata.run(code, echo=False)
|
|
1871
|
+
except Exception as e:
|
|
1872
|
+
exc = e
|
|
1873
|
+
rc = 1
|
|
1874
|
+
|
|
1875
|
+
stdout = output_buffer.getvalue()
|
|
1876
|
+
stderr = error_buffer.getvalue()
|
|
1877
|
+
|
|
1878
|
+
return CommandResponse(
|
|
1879
|
+
command=code,
|
|
1880
|
+
rc=rc,
|
|
1881
|
+
stdout=stdout,
|
|
1882
|
+
stderr=stderr if not exc else str(exc),
|
|
1883
|
+
success=(rc == 0),
|
|
1884
|
+
error=None
|
|
1885
|
+
)
|
|
1886
|
+
|
|
1887
|
+
async def run_command_streaming(
|
|
1888
|
+
self,
|
|
1889
|
+
code: str,
|
|
1890
|
+
*,
|
|
1891
|
+
notify_log: Callable[[str], Awaitable[None]],
|
|
1892
|
+
notify_progress: Optional[Callable[[float, Optional[float], Optional[str]], Awaitable[None]]] = None,
|
|
1893
|
+
echo: bool = True,
|
|
1894
|
+
trace: bool = False,
|
|
1895
|
+
max_output_lines: Optional[int] = None,
|
|
1896
|
+
cwd: Optional[str] = None,
|
|
1897
|
+
auto_cache_graphs: bool = False,
|
|
1898
|
+
on_graph_cached: Optional[Callable[[str, bool], Awaitable[None]]] = None,
|
|
1899
|
+
emit_graph_ready: bool = False,
|
|
1900
|
+
graph_ready_task_id: Optional[str] = None,
|
|
1901
|
+
graph_ready_format: str = "svg",
|
|
1902
|
+
) -> CommandResponse:
|
|
1903
|
+
if not self._initialized:
|
|
1904
|
+
self.init()
|
|
1905
|
+
|
|
1906
|
+
code = self._maybe_rewrite_graph_name_in_command(code)
|
|
1907
|
+
auto_cache_graphs = auto_cache_graphs or emit_graph_ready
|
|
1908
|
+
total_lines = 0 # Commands (not do-files) do not have line-based progress
|
|
1909
|
+
|
|
1910
|
+
if cwd is not None and not os.path.isdir(cwd):
|
|
1911
|
+
return CommandResponse(
|
|
1912
|
+
command=code,
|
|
1913
|
+
rc=601,
|
|
1914
|
+
stdout="",
|
|
1915
|
+
stderr=None,
|
|
1916
|
+
success=False,
|
|
1917
|
+
error=ErrorEnvelope(
|
|
1918
|
+
message=f"cwd not found: {cwd}",
|
|
1919
|
+
rc=601,
|
|
1920
|
+
command=code,
|
|
1921
|
+
),
|
|
1922
|
+
)
|
|
1923
|
+
|
|
1924
|
+
start_time = time.time()
|
|
1925
|
+
exc: Optional[Exception] = None
|
|
1926
|
+
smcl_content = ""
|
|
1927
|
+
smcl_path = None
|
|
1928
|
+
|
|
1929
|
+
# Setup streaming graph cache if enabled
|
|
1930
|
+
graph_cache = self._init_streaming_graph_cache(auto_cache_graphs, on_graph_cached, notify_log)
|
|
1931
|
+
|
|
1932
|
+
_log_file, log_path, tail, tee = self._create_streaming_log(trace=trace)
|
|
1933
|
+
|
|
1934
|
+
# Create SMCL log path for authoritative output capture
|
|
1935
|
+
smcl_path = self._create_smcl_log_path(base_dir=cwd)
|
|
1936
|
+
smcl_log_name = self._make_smcl_log_name()
|
|
1937
|
+
|
|
1938
|
+
# Inform the MCP client immediately where to read/tail the output.
|
|
1939
|
+
await notify_log(json.dumps({"event": "log_path", "path": smcl_path}))
|
|
1940
|
+
|
|
1941
|
+
rc = -1
|
|
1942
|
+
path_for_stata = code.replace("\\", "/")
|
|
1943
|
+
command = f'{path_for_stata}'
|
|
1944
|
+
|
|
1945
|
+
graph_ready_initial = self._capture_graph_state(graph_cache, emit_graph_ready)
|
|
1946
|
+
|
|
1947
|
+
# Increment AFTER capture so detected modifications are based on state BEFORE this command
|
|
1948
|
+
self._increment_command_idx()
|
|
1949
|
+
|
|
1950
|
+
graph_poll_state = [0.0]
|
|
1951
|
+
|
|
1952
|
+
async def on_chunk_for_graphs(_chunk: str) -> None:
|
|
1953
|
+
# Background the graph check so we don't block SMCL streaming or task completion
|
|
1954
|
+
asyncio.create_task(
|
|
1955
|
+
self._maybe_cache_graphs_on_chunk(
|
|
1956
|
+
graph_cache=graph_cache,
|
|
1957
|
+
emit_graph_ready=emit_graph_ready,
|
|
1958
|
+
notify_log=notify_log,
|
|
1959
|
+
graph_ready_task_id=graph_ready_task_id,
|
|
1960
|
+
graph_ready_format=graph_ready_format,
|
|
1961
|
+
graph_ready_initial=graph_ready_initial,
|
|
1962
|
+
last_check=graph_poll_state,
|
|
1963
|
+
)
|
|
1964
|
+
)
|
|
1965
|
+
|
|
1966
|
+
done = anyio.Event()
|
|
1967
|
+
|
|
1968
|
+
try:
|
|
1969
|
+
async with anyio.create_task_group() as tg:
|
|
1970
|
+
async def stream_smcl() -> None:
|
|
1971
|
+
try:
|
|
1972
|
+
await self._stream_smcl_log(
|
|
1973
|
+
smcl_path=smcl_path,
|
|
1974
|
+
notify_log=notify_log,
|
|
1975
|
+
done=done,
|
|
1976
|
+
on_chunk=on_chunk_for_graphs if graph_cache else None,
|
|
1977
|
+
)
|
|
1978
|
+
except Exception as exc:
|
|
1979
|
+
logger.debug("SMCL streaming failed: %s", exc)
|
|
1980
|
+
|
|
1981
|
+
tg.start_soon(stream_smcl)
|
|
1982
|
+
|
|
1983
|
+
if notify_progress is not None:
|
|
1984
|
+
if total_lines > 0:
|
|
1985
|
+
await notify_progress(0, float(total_lines), f"Executing command: 0/{total_lines}")
|
|
1986
|
+
else:
|
|
1987
|
+
await notify_progress(0, None, "Running command")
|
|
1988
|
+
|
|
1989
|
+
try:
|
|
1990
|
+
run_blocking = lambda: self._run_streaming_blocking(
|
|
1991
|
+
command=command,
|
|
1992
|
+
tee=tee,
|
|
1993
|
+
cwd=cwd,
|
|
1994
|
+
trace=trace,
|
|
1995
|
+
echo=echo,
|
|
1996
|
+
smcl_path=smcl_path,
|
|
1997
|
+
smcl_log_name=smcl_log_name,
|
|
1998
|
+
hold_attr="_hold_name_stream",
|
|
1999
|
+
require_smcl_log=True,
|
|
2000
|
+
)
|
|
2001
|
+
try:
|
|
2002
|
+
rc, exc = await anyio.to_thread.run_sync(
|
|
2003
|
+
run_blocking,
|
|
2004
|
+
abandon_on_cancel=True,
|
|
2005
|
+
)
|
|
2006
|
+
except TypeError:
|
|
2007
|
+
rc, exc = await anyio.to_thread.run_sync(run_blocking)
|
|
2008
|
+
except Exception as e:
|
|
2009
|
+
exc = e
|
|
2010
|
+
if rc in (-1, 0):
|
|
2011
|
+
rc = 1
|
|
2012
|
+
except get_cancelled_exc_class():
|
|
2013
|
+
self._request_break_in()
|
|
2014
|
+
await self._wait_for_stata_stop()
|
|
2015
|
+
raise
|
|
2016
|
+
finally:
|
|
2017
|
+
done.set()
|
|
2018
|
+
tee.close()
|
|
2019
|
+
except* Exception as exc_group:
|
|
2020
|
+
logger.debug("SMCL streaming task group failed: %s", exc_group)
|
|
2021
|
+
|
|
2022
|
+
# Read SMCL content as the authoritative source
|
|
2023
|
+
smcl_content = self._read_smcl_file(smcl_path)
|
|
2024
|
+
|
|
2025
|
+
if graph_cache:
|
|
2026
|
+
asyncio.create_task(
|
|
2027
|
+
self._cache_new_graphs(
|
|
2028
|
+
graph_cache,
|
|
2029
|
+
notify_progress=notify_progress,
|
|
2030
|
+
total_lines=total_lines,
|
|
2031
|
+
completed_label="Command",
|
|
2032
|
+
)
|
|
2033
|
+
)
|
|
2034
|
+
|
|
2035
|
+
combined = self._build_combined_log(tail, smcl_path, rc, trace, exc)
|
|
2036
|
+
|
|
2037
|
+
# Use SMCL content as primary source for RC detection
|
|
2038
|
+
if not exc or rc in (1, -1):
|
|
2039
|
+
parsed_rc = self._parse_rc_from_smcl(smcl_content)
|
|
2040
|
+
if parsed_rc is not None and parsed_rc != 0:
|
|
2041
|
+
rc = parsed_rc
|
|
2042
|
+
elif rc in (-1, 0, 1): # Also check text if rc is generic 1 or unset
|
|
2043
|
+
parsed_rc_text = self._parse_rc_from_text(combined)
|
|
2044
|
+
if parsed_rc_text is not None:
|
|
2045
|
+
rc = parsed_rc_text
|
|
2046
|
+
elif rc == -1:
|
|
2047
|
+
rc = 0 # Default to success if no error trace found
|
|
2048
|
+
|
|
2049
|
+
success = (rc == 0 and exc is None)
|
|
2050
|
+
stderr_final = None
|
|
2051
|
+
error = None
|
|
2052
|
+
|
|
2053
|
+
if not success:
|
|
2054
|
+
# Use SMCL as authoritative source for error extraction
|
|
2055
|
+
if smcl_content:
|
|
2056
|
+
msg, context = self._extract_error_from_smcl(smcl_content, rc)
|
|
2057
|
+
else:
|
|
2058
|
+
# Fallback to combined log
|
|
2059
|
+
msg, context = self._extract_error_and_context(combined, rc)
|
|
2060
|
+
|
|
2061
|
+
error = ErrorEnvelope(
|
|
2062
|
+
message=msg,
|
|
2063
|
+
context=context,
|
|
2064
|
+
rc=rc,
|
|
2065
|
+
command=command,
|
|
2066
|
+
log_path=log_path,
|
|
2067
|
+
snippet=smcl_content[-800:] if smcl_content else combined[-800:],
|
|
2068
|
+
smcl_output=smcl_content,
|
|
2069
|
+
)
|
|
2070
|
+
stderr_final = context
|
|
2071
|
+
|
|
2072
|
+
duration = time.time() - start_time
|
|
2073
|
+
logger.info(
|
|
2074
|
+
"stata.run(stream) rc=%s success=%s trace=%s duration_ms=%.2f code_preview=%s",
|
|
2075
|
+
rc,
|
|
2076
|
+
success,
|
|
2077
|
+
trace,
|
|
2078
|
+
duration * 1000,
|
|
2079
|
+
code.replace("\n", "\\n")[:120],
|
|
2080
|
+
)
|
|
2081
|
+
|
|
2082
|
+
result = CommandResponse(
|
|
2083
|
+
command=code,
|
|
2084
|
+
rc=rc,
|
|
2085
|
+
stdout="",
|
|
2086
|
+
stderr=stderr_final,
|
|
2087
|
+
log_path=log_path,
|
|
2088
|
+
success=success,
|
|
2089
|
+
error=error,
|
|
2090
|
+
smcl_output=smcl_content,
|
|
2091
|
+
)
|
|
2092
|
+
|
|
2093
|
+
if notify_progress is not None:
|
|
2094
|
+
await notify_progress(1, 1, "Finished")
|
|
2095
|
+
|
|
2096
|
+
return result
|
|
2097
|
+
|
|
2098
|
+
async def run_do_file_streaming(
|
|
2099
|
+
self,
|
|
2100
|
+
path: str,
|
|
2101
|
+
*,
|
|
2102
|
+
notify_log: Callable[[str], Awaitable[None]],
|
|
2103
|
+
notify_progress: Optional[Callable[[float, Optional[float], Optional[str]], Awaitable[None]]] = None,
|
|
2104
|
+
echo: bool = True,
|
|
2105
|
+
trace: bool = False,
|
|
2106
|
+
max_output_lines: Optional[int] = None,
|
|
2107
|
+
cwd: Optional[str] = None,
|
|
2108
|
+
auto_cache_graphs: bool = False,
|
|
2109
|
+
on_graph_cached: Optional[Callable[[str, bool], Awaitable[None]]] = None,
|
|
2110
|
+
emit_graph_ready: bool = False,
|
|
2111
|
+
graph_ready_task_id: Optional[str] = None,
|
|
2112
|
+
graph_ready_format: str = "svg",
|
|
2113
|
+
) -> CommandResponse:
|
|
2114
|
+
effective_path, command, error_response = self._resolve_do_file_path(path, cwd)
|
|
2115
|
+
if error_response is not None:
|
|
2116
|
+
return error_response
|
|
2117
|
+
|
|
2118
|
+
total_lines = self._count_do_file_lines(effective_path)
|
|
2119
|
+
executed_lines = 0
|
|
2120
|
+
last_progress_time = 0.0
|
|
2121
|
+
dot_prompt = re.compile(r"^\.\s+\S")
|
|
2122
|
+
|
|
2123
|
+
async def on_chunk_for_progress(chunk: str) -> None:
|
|
2124
|
+
nonlocal executed_lines, last_progress_time
|
|
2125
|
+
if total_lines <= 0 or notify_progress is None:
|
|
2126
|
+
return
|
|
2127
|
+
for line in chunk.splitlines():
|
|
2128
|
+
if dot_prompt.match(line):
|
|
2129
|
+
executed_lines += 1
|
|
2130
|
+
if executed_lines > total_lines:
|
|
2131
|
+
executed_lines = total_lines
|
|
2132
|
+
|
|
2133
|
+
now = time.monotonic()
|
|
2134
|
+
if executed_lines > 0 and (now - last_progress_time) >= 0.25:
|
|
2135
|
+
last_progress_time = now
|
|
2136
|
+
await notify_progress(
|
|
2137
|
+
float(executed_lines),
|
|
2138
|
+
float(total_lines),
|
|
2139
|
+
f"Executing do-file: {executed_lines}/{total_lines}",
|
|
2140
|
+
)
|
|
2141
|
+
|
|
2142
|
+
if not self._initialized:
|
|
2143
|
+
self.init()
|
|
2144
|
+
|
|
2145
|
+
auto_cache_graphs = auto_cache_graphs or emit_graph_ready
|
|
2146
|
+
|
|
2147
|
+
start_time = time.time()
|
|
2148
|
+
exc: Optional[Exception] = None
|
|
2149
|
+
smcl_content = ""
|
|
2150
|
+
smcl_path = None
|
|
2151
|
+
|
|
2152
|
+
graph_cache = self._init_streaming_graph_cache(auto_cache_graphs, on_graph_cached, notify_log)
|
|
2153
|
+
_log_file, log_path, tail, tee = self._create_streaming_log(trace=trace)
|
|
2154
|
+
|
|
2155
|
+
base_dir = cwd or os.path.dirname(effective_path)
|
|
2156
|
+
smcl_path = self._create_smcl_log_path(base_dir=base_dir)
|
|
2157
|
+
smcl_log_name = self._make_smcl_log_name()
|
|
2158
|
+
|
|
2159
|
+
# Inform the MCP client immediately where to read/tail the output.
|
|
2160
|
+
await notify_log(json.dumps({"event": "log_path", "path": smcl_path}))
|
|
2161
|
+
|
|
2162
|
+
rc = -1
|
|
2163
|
+
graph_ready_initial = self._capture_graph_state(graph_cache, emit_graph_ready)
|
|
2164
|
+
|
|
2165
|
+
# Increment AFTER capture
|
|
2166
|
+
self._increment_command_idx()
|
|
2167
|
+
|
|
2168
|
+
graph_poll_state = [0.0]
|
|
2169
|
+
|
|
2170
|
+
async def on_chunk_for_graphs(_chunk: str) -> None:
|
|
2171
|
+
# Background the graph check so we don't block SMCL streaming or task completion
|
|
2172
|
+
asyncio.create_task(
|
|
2173
|
+
self._maybe_cache_graphs_on_chunk(
|
|
2174
|
+
graph_cache=graph_cache,
|
|
2175
|
+
emit_graph_ready=emit_graph_ready,
|
|
2176
|
+
notify_log=notify_log,
|
|
2177
|
+
graph_ready_task_id=graph_ready_task_id,
|
|
2178
|
+
graph_ready_format=graph_ready_format,
|
|
2179
|
+
graph_ready_initial=graph_ready_initial,
|
|
2180
|
+
last_check=graph_poll_state,
|
|
2181
|
+
)
|
|
2182
|
+
)
|
|
2183
|
+
|
|
2184
|
+
on_chunk_callback = on_chunk_for_progress
|
|
2185
|
+
if graph_cache:
|
|
2186
|
+
async def on_chunk_callback(chunk: str) -> None:
|
|
2187
|
+
await on_chunk_for_progress(chunk)
|
|
2188
|
+
await on_chunk_for_graphs(chunk)
|
|
2189
|
+
|
|
2190
|
+
done = anyio.Event()
|
|
2191
|
+
|
|
2192
|
+
try:
|
|
2193
|
+
async with anyio.create_task_group() as tg:
|
|
2194
|
+
async def stream_smcl() -> None:
|
|
2195
|
+
try:
|
|
2196
|
+
await self._stream_smcl_log(
|
|
2197
|
+
smcl_path=smcl_path,
|
|
2198
|
+
notify_log=notify_log,
|
|
2199
|
+
done=done,
|
|
2200
|
+
on_chunk=on_chunk_callback,
|
|
2201
|
+
)
|
|
2202
|
+
except Exception as exc:
|
|
2203
|
+
logger.debug("SMCL streaming failed: %s", exc)
|
|
2204
|
+
|
|
2205
|
+
tg.start_soon(stream_smcl)
|
|
2206
|
+
|
|
2207
|
+
if notify_progress is not None:
|
|
2208
|
+
if total_lines > 0:
|
|
2209
|
+
await notify_progress(0, float(total_lines), f"Executing do-file: 0/{total_lines}")
|
|
2210
|
+
else:
|
|
2211
|
+
await notify_progress(0, None, "Running do-file")
|
|
2212
|
+
|
|
2213
|
+
try:
|
|
2214
|
+
run_blocking = lambda: self._run_streaming_blocking(
|
|
2215
|
+
command=command,
|
|
2216
|
+
tee=tee,
|
|
2217
|
+
cwd=cwd,
|
|
2218
|
+
trace=trace,
|
|
2219
|
+
echo=echo,
|
|
2220
|
+
smcl_path=smcl_path,
|
|
2221
|
+
smcl_log_name=smcl_log_name,
|
|
2222
|
+
hold_attr="_hold_name_do",
|
|
2223
|
+
require_smcl_log=True,
|
|
2224
|
+
)
|
|
2225
|
+
try:
|
|
2226
|
+
rc, exc = await anyio.to_thread.run_sync(
|
|
2227
|
+
run_blocking,
|
|
2228
|
+
abandon_on_cancel=True,
|
|
2229
|
+
)
|
|
2230
|
+
except TypeError:
|
|
2231
|
+
rc, exc = await anyio.to_thread.run_sync(run_blocking)
|
|
2232
|
+
except Exception as e:
|
|
2233
|
+
exc = e
|
|
2234
|
+
if rc in (-1, 0):
|
|
2235
|
+
rc = 1
|
|
2236
|
+
except get_cancelled_exc_class():
|
|
2237
|
+
self._request_break_in()
|
|
2238
|
+
await self._wait_for_stata_stop()
|
|
2239
|
+
raise
|
|
2240
|
+
finally:
|
|
2241
|
+
done.set()
|
|
2242
|
+
tee.close()
|
|
2243
|
+
except* Exception as exc_group:
|
|
2244
|
+
logger.debug("SMCL streaming task group failed: %s", exc_group)
|
|
2245
|
+
|
|
2246
|
+
# Read SMCL content as the authoritative source
|
|
2247
|
+
smcl_content = self._read_smcl_file(smcl_path)
|
|
2248
|
+
|
|
2249
|
+
if graph_cache:
|
|
2250
|
+
asyncio.create_task(
|
|
2251
|
+
self._cache_new_graphs(
|
|
2252
|
+
graph_cache,
|
|
2253
|
+
notify_progress=notify_progress,
|
|
2254
|
+
total_lines=total_lines,
|
|
2255
|
+
completed_label="Do-file",
|
|
2256
|
+
)
|
|
2257
|
+
)
|
|
2258
|
+
|
|
2259
|
+
combined = self._build_combined_log(tail, log_path, rc, trace, exc)
|
|
2260
|
+
|
|
2261
|
+
# Use SMCL content as primary source for RC detection
|
|
2262
|
+
if not exc or rc in (1, -1):
|
|
2263
|
+
parsed_rc = self._parse_rc_from_smcl(smcl_content)
|
|
2264
|
+
if parsed_rc is not None and parsed_rc != 0:
|
|
2265
|
+
rc = parsed_rc
|
|
2266
|
+
elif rc in (-1, 0, 1):
|
|
2267
|
+
parsed_rc_text = self._parse_rc_from_text(combined)
|
|
2268
|
+
if parsed_rc_text is not None:
|
|
2269
|
+
rc = parsed_rc_text
|
|
2270
|
+
elif rc == -1:
|
|
2271
|
+
rc = 0 # Default to success if no error found
|
|
2272
|
+
|
|
2273
|
+
success = (rc == 0 and exc is None)
|
|
2274
|
+
stderr_final = None
|
|
2275
|
+
error = None
|
|
2276
|
+
|
|
2277
|
+
if not success:
|
|
2278
|
+
# Use SMCL as authoritative source for error extraction
|
|
2279
|
+
if smcl_content:
|
|
2280
|
+
msg, context = self._extract_error_from_smcl(smcl_content, rc)
|
|
2281
|
+
else:
|
|
2282
|
+
# Fallback to combined log
|
|
2283
|
+
msg, context = self._extract_error_and_context(combined, rc)
|
|
2284
|
+
|
|
2285
|
+
error = ErrorEnvelope(
|
|
2286
|
+
message=msg,
|
|
2287
|
+
context=context,
|
|
2288
|
+
rc=rc,
|
|
2289
|
+
command=command,
|
|
2290
|
+
log_path=log_path,
|
|
2291
|
+
snippet=smcl_content[-800:] if smcl_content else combined[-800:],
|
|
2292
|
+
smcl_output=smcl_content,
|
|
2293
|
+
)
|
|
2294
|
+
stderr_final = context
|
|
2295
|
+
|
|
2296
|
+
duration = time.time() - start_time
|
|
2297
|
+
logger.info(
|
|
2298
|
+
"stata.run(do stream) rc=%s success=%s trace=%s duration_ms=%.2f path=%s",
|
|
2299
|
+
rc,
|
|
2300
|
+
success,
|
|
2301
|
+
trace,
|
|
2302
|
+
duration * 1000,
|
|
2303
|
+
effective_path,
|
|
2304
|
+
)
|
|
2305
|
+
|
|
2306
|
+
result = CommandResponse(
|
|
2307
|
+
command=command,
|
|
2308
|
+
rc=rc,
|
|
2309
|
+
stdout="",
|
|
2310
|
+
stderr=stderr_final,
|
|
2311
|
+
log_path=log_path,
|
|
2312
|
+
success=success,
|
|
2313
|
+
error=error,
|
|
2314
|
+
smcl_output=smcl_content,
|
|
2315
|
+
)
|
|
2316
|
+
|
|
2317
|
+
if notify_progress is not None:
|
|
2318
|
+
if total_lines > 0:
|
|
2319
|
+
await notify_progress(float(total_lines), float(total_lines), f"Executing do-file: {total_lines}/{total_lines}")
|
|
2320
|
+
else:
|
|
2321
|
+
await notify_progress(1, 1, "Finished")
|
|
2322
|
+
|
|
2323
|
+
return result
|
|
2324
|
+
|
|
2325
|
+
def run_command_structured(self, code: str, echo: bool = True, trace: bool = False, max_output_lines: Optional[int] = None, cwd: Optional[str] = None) -> CommandResponse:
|
|
2326
|
+
"""Runs a Stata command and returns a structured envelope.
|
|
2327
|
+
|
|
2328
|
+
Args:
|
|
2329
|
+
code: The Stata command to execute.
|
|
2330
|
+
echo: If True, the command itself is included in the output.
|
|
2331
|
+
trace: If True, enables trace mode for debugging.
|
|
2332
|
+
max_output_lines: If set, truncates stdout to this many lines (token efficiency).
|
|
2333
|
+
"""
|
|
2334
|
+
result = self._exec_with_capture(code, echo=echo, trace=trace, cwd=cwd)
|
|
2335
|
+
|
|
2336
|
+
return self._truncate_command_output(result, max_output_lines)
|
|
2337
|
+
|
|
2338
|
+
def get_data(self, start: int = 0, count: int = 50) -> List[Dict[str, Any]]:
|
|
2339
|
+
"""Returns valid JSON-serializable data."""
|
|
2340
|
+
if not self._initialized:
|
|
2341
|
+
self.init()
|
|
2342
|
+
|
|
2343
|
+
if count > self.MAX_DATA_ROWS:
|
|
2344
|
+
count = self.MAX_DATA_ROWS
|
|
2345
|
+
|
|
2346
|
+
with self._exec_lock:
|
|
2347
|
+
try:
|
|
2348
|
+
# Use pystata integration to retrieve data
|
|
2349
|
+
df = self.stata.pdataframe_from_data()
|
|
2350
|
+
|
|
2351
|
+
# Slice
|
|
2352
|
+
sliced = df.iloc[start : start + count]
|
|
2353
|
+
|
|
2354
|
+
# Convert to dict
|
|
2355
|
+
return sliced.to_dict(orient="records")
|
|
2356
|
+
except Exception as e:
|
|
2357
|
+
return [{"error": f"Failed to retrieve data: {e}"}]
|
|
2358
|
+
|
|
2359
|
+
def list_variables(self) -> List[Dict[str, str]]:
|
|
2360
|
+
"""Returns list of variables with labels."""
|
|
2361
|
+
if not self._initialized:
|
|
2362
|
+
self.init()
|
|
2363
|
+
|
|
2364
|
+
# We can use sfi to be efficient
|
|
2365
|
+
from sfi import Data # type: ignore[import-not-found]
|
|
2366
|
+
vars_info = []
|
|
2367
|
+
with self._exec_lock:
|
|
2368
|
+
for i in range(Data.getVarCount()):
|
|
2369
|
+
var_index = i # 0-based
|
|
2370
|
+
name = Data.getVarName(var_index)
|
|
2371
|
+
label = Data.getVarLabel(var_index)
|
|
2372
|
+
type_str = Data.getVarType(var_index) # Returns int
|
|
2373
|
+
|
|
2374
|
+
vars_info.append({
|
|
2375
|
+
"name": name,
|
|
2376
|
+
"label": label,
|
|
2377
|
+
"type": str(type_str),
|
|
2378
|
+
})
|
|
2379
|
+
return vars_info
|
|
2380
|
+
|
|
2381
|
+
def get_dataset_state(self) -> Dict[str, Any]:
|
|
2382
|
+
"""Return basic dataset state without mutating the dataset."""
|
|
2383
|
+
if not self._initialized:
|
|
2384
|
+
self.init()
|
|
2385
|
+
|
|
2386
|
+
from sfi import Data, Macro # type: ignore[import-not-found]
|
|
2387
|
+
|
|
2388
|
+
with self._exec_lock:
|
|
2389
|
+
n = int(Data.getObsTotal())
|
|
2390
|
+
k = int(Data.getVarCount())
|
|
2391
|
+
|
|
2392
|
+
frame = "default"
|
|
2393
|
+
sortlist = ""
|
|
2394
|
+
changed = False
|
|
2395
|
+
try:
|
|
2396
|
+
frame = str(Macro.getGlobal("frame") or "default")
|
|
2397
|
+
except Exception:
|
|
2398
|
+
logger.debug("Failed to get 'frame' macro", exc_info=True)
|
|
2399
|
+
frame = "default"
|
|
2400
|
+
try:
|
|
2401
|
+
sortlist = str(Macro.getGlobal("sortlist") or "")
|
|
2402
|
+
except Exception:
|
|
2403
|
+
logger.debug("Failed to get 'sortlist' macro", exc_info=True)
|
|
2404
|
+
sortlist = ""
|
|
2405
|
+
try:
|
|
2406
|
+
changed = bool(int(float(Macro.getGlobal("changed") or "0")))
|
|
2407
|
+
except Exception:
|
|
2408
|
+
logger.debug("Failed to get 'changed' macro", exc_info=True)
|
|
2409
|
+
changed = False
|
|
2410
|
+
|
|
2411
|
+
return {"frame": frame, "n": n, "k": k, "sortlist": sortlist, "changed": changed}
|
|
2412
|
+
|
|
2413
|
+
def _require_data_in_memory(self) -> None:
|
|
2414
|
+
state = self.get_dataset_state()
|
|
2415
|
+
if int(state.get("k", 0) or 0) == 0 and int(state.get("n", 0) or 0) == 0:
|
|
2416
|
+
# Stata empty dataset could still have k>0 n==0; treat that as ok.
|
|
2417
|
+
raise RuntimeError("No data in memory")
|
|
2418
|
+
|
|
2419
|
+
def _get_var_index_map(self) -> Dict[str, int]:
|
|
2420
|
+
from sfi import Data # type: ignore[import-not-found]
|
|
2421
|
+
|
|
2422
|
+
out: Dict[str, int] = {}
|
|
2423
|
+
with self._exec_lock:
|
|
2424
|
+
for i in range(int(Data.getVarCount())):
|
|
2425
|
+
try:
|
|
2426
|
+
out[str(Data.getVarName(i))] = i
|
|
2427
|
+
except Exception:
|
|
2428
|
+
continue
|
|
2429
|
+
return out
|
|
2430
|
+
|
|
2431
|
+
def list_variables_rich(self) -> List[Dict[str, Any]]:
|
|
2432
|
+
"""Return variable metadata (name/type/label/format/valueLabel) without modifying the dataset."""
|
|
2433
|
+
if not self._initialized:
|
|
2434
|
+
self.init()
|
|
2435
|
+
|
|
2436
|
+
from sfi import Data # type: ignore[import-not-found]
|
|
2437
|
+
|
|
2438
|
+
vars_info: List[Dict[str, Any]] = []
|
|
2439
|
+
for i in range(int(Data.getVarCount())):
|
|
2440
|
+
name = str(Data.getVarName(i))
|
|
2441
|
+
label = None
|
|
2442
|
+
fmt = None
|
|
2443
|
+
vtype = None
|
|
2444
|
+
value_label = None
|
|
2445
|
+
try:
|
|
2446
|
+
label = Data.getVarLabel(i)
|
|
2447
|
+
except Exception:
|
|
2448
|
+
label = None
|
|
2449
|
+
try:
|
|
2450
|
+
fmt = Data.getVarFormat(i)
|
|
2451
|
+
except Exception:
|
|
2452
|
+
fmt = None
|
|
2453
|
+
try:
|
|
2454
|
+
vtype = Data.getVarType(i)
|
|
2455
|
+
except Exception:
|
|
2456
|
+
vtype = None
|
|
2457
|
+
|
|
2458
|
+
vars_info.append(
|
|
2459
|
+
{
|
|
2460
|
+
"name": name,
|
|
2461
|
+
"type": str(vtype) if vtype is not None else None,
|
|
2462
|
+
"label": label if label else None,
|
|
2463
|
+
"format": fmt if fmt else None,
|
|
2464
|
+
"valueLabel": value_label,
|
|
2465
|
+
}
|
|
2466
|
+
)
|
|
2467
|
+
return vars_info
|
|
2468
|
+
|
|
2469
|
+
    @staticmethod
    def _is_stata_missing(value: Any) -> bool:
        if value is None:
            return True
        if isinstance(value, float):
            # Stata missing values typically show up as very large floats via sfi.Data.get
            return value > 8.0e307
        return False

    def _normalize_cell(self, value: Any, *, max_chars: int) -> tuple[Any, bool]:
        if self._is_stata_missing(value):
            return ".", False
        if isinstance(value, str):
            if len(value) > max_chars:
                return value[:max_chars], True
            return value, False
        return value, False

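    # Illustrative sketch of the cell normalisation above (hypothetical values):
    #   _normalize_cell(8.99e307, max_chars=80)   -> (".", False)       # numeric missing sentinel
    #   _normalize_cell("x" * 200, max_chars=80)  -> ("x" * 80, True)   # long string truncated
    #   _normalize_cell(3.14, max_chars=80)       -> (3.14, False)      # ordinary value untouched
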
def get_page(
|
|
2488
|
+
self,
|
|
2489
|
+
*,
|
|
2490
|
+
offset: int,
|
|
2491
|
+
limit: int,
|
|
2492
|
+
vars: List[str],
|
|
2493
|
+
include_obs_no: bool,
|
|
2494
|
+
max_chars: int,
|
|
2495
|
+
obs_indices: Optional[List[int]] = None,
|
|
2496
|
+
) -> Dict[str, Any]:
|
|
2497
|
+
if not self._initialized:
|
|
2498
|
+
self.init()
|
|
2499
|
+
|
|
2500
|
+
from sfi import Data # type: ignore[import-not-found]
|
|
2501
|
+
|
|
2502
|
+
state = self.get_dataset_state()
|
|
2503
|
+
n = int(state.get("n", 0) or 0)
|
|
2504
|
+
k = int(state.get("k", 0) or 0)
|
|
2505
|
+
if k == 0 and n == 0:
|
|
2506
|
+
raise RuntimeError("No data in memory")
|
|
2507
|
+
|
|
2508
|
+
var_map = self._get_var_index_map()
|
|
2509
|
+
for v in vars:
|
|
2510
|
+
if v not in var_map:
|
|
2511
|
+
raise ValueError(f"Invalid variable: {v}")
|
|
2512
|
+
|
|
2513
|
+
if obs_indices is None:
|
|
2514
|
+
start = offset
|
|
2515
|
+
end = min(offset + limit, n)
|
|
2516
|
+
if start >= n:
|
|
2517
|
+
rows: list[list[Any]] = []
|
|
2518
|
+
returned = 0
|
|
2519
|
+
obs_list: list[int] = []
|
|
2520
|
+
else:
|
|
2521
|
+
obs_list = list(range(start, end))
|
|
2522
|
+
raw_rows = Data.get(var=vars, obs=obs_list)
|
|
2523
|
+
rows = raw_rows
|
|
2524
|
+
returned = len(rows)
|
|
2525
|
+
else:
|
|
2526
|
+
start = offset
|
|
2527
|
+
end = min(offset + limit, len(obs_indices))
|
|
2528
|
+
obs_list = obs_indices[start:end]
|
|
2529
|
+
raw_rows = Data.get(var=vars, obs=obs_list) if obs_list else []
|
|
2530
|
+
rows = raw_rows
|
|
2531
|
+
returned = len(rows)
|
|
2532
|
+
|
|
2533
|
+
out_vars = list(vars)
|
|
2534
|
+
out_rows: list[list[Any]] = []
|
|
2535
|
+
truncated_cells = 0
|
|
2536
|
+
|
|
2537
|
+
if include_obs_no:
|
|
2538
|
+
out_vars = ["_n"] + out_vars
|
|
2539
|
+
|
|
2540
|
+
for idx, raw in enumerate(rows):
|
|
2541
|
+
norm_row: list[Any] = []
|
|
2542
|
+
if include_obs_no:
|
|
2543
|
+
norm_row.append(int(obs_list[idx]) + 1)
|
|
2544
|
+
for cell in raw:
|
|
2545
|
+
norm, truncated = self._normalize_cell(cell, max_chars=max_chars)
|
|
2546
|
+
if truncated:
|
|
2547
|
+
truncated_cells += 1
|
|
2548
|
+
norm_row.append(norm)
|
|
2549
|
+
out_rows.append(norm_row)
|
|
2550
|
+
|
|
2551
|
+
return {
|
|
2552
|
+
"vars": out_vars,
|
|
2553
|
+
"rows": out_rows,
|
|
2554
|
+
"returned": returned,
|
|
2555
|
+
"truncated_cells": truncated_cells,
|
|
2556
|
+
}
|
|
2557
|
+
|
|
2558
|
+
def get_arrow_stream(
|
|
2559
|
+
self,
|
|
2560
|
+
*,
|
|
2561
|
+
offset: int,
|
|
2562
|
+
limit: int,
|
|
2563
|
+
vars: List[str],
|
|
2564
|
+
include_obs_no: bool,
|
|
2565
|
+
obs_indices: Optional[List[int]] = None,
|
|
2566
|
+
) -> bytes:
|
|
2567
|
+
"""
|
|
2568
|
+
Returns an Apache Arrow IPC stream (as bytes) for the requested data page.
|
|
2569
|
+
Uses Polars if available (faster), falls back to Pandas.
|
|
2570
|
+
"""
|
|
2571
|
+
if not self._initialized:
|
|
2572
|
+
self.init()
|
|
2573
|
+
|
|
2574
|
+
import pyarrow as pa
|
|
2575
|
+
from sfi import Data # type: ignore[import-not-found]
|
|
2576
|
+
|
|
2577
|
+
use_polars = _get_polars_available()
|
|
2578
|
+
if use_polars:
|
|
2579
|
+
import polars as pl
|
|
2580
|
+
else:
|
|
2581
|
+
import pandas as pd
|
|
2582
|
+
|
|
2583
|
+
state = self.get_dataset_state()
|
|
2584
|
+
n = int(state.get("n", 0) or 0)
|
|
2585
|
+
k = int(state.get("k", 0) or 0)
|
|
2586
|
+
if k == 0 and n == 0:
|
|
2587
|
+
raise RuntimeError("No data in memory")
|
|
2588
|
+
|
|
2589
|
+
var_map = self._get_var_index_map()
|
|
2590
|
+
for v in vars:
|
|
2591
|
+
if v not in var_map:
|
|
2592
|
+
raise ValueError(f"Invalid variable: {v}")
|
|
2593
|
+
|
|
2594
|
+
# Determine observations to fetch
|
|
2595
|
+
if obs_indices is None:
|
|
2596
|
+
start = offset
|
|
2597
|
+
end = min(offset + limit, n)
|
|
2598
|
+
obs_list = list(range(start, end)) if start < n else []
|
|
2599
|
+
else:
|
|
2600
|
+
start = offset
|
|
2601
|
+
end = min(offset + limit, len(obs_indices))
|
|
2602
|
+
obs_list = obs_indices[start:end]
|
|
2603
|
+
|
|
2604
|
+
try:
|
|
2605
|
+
if not obs_list:
|
|
2606
|
+
# Empty schema-only table
|
|
2607
|
+
if use_polars:
|
|
2608
|
+
schema_cols = {}
|
|
2609
|
+
if include_obs_no:
|
|
2610
|
+
schema_cols["_n"] = pl.Int64
|
|
2611
|
+
for v in vars:
|
|
2612
|
+
schema_cols[v] = pl.Utf8
|
|
2613
|
+
table = pl.DataFrame(schema=schema_cols).to_arrow()
|
|
2614
|
+
else:
|
|
2615
|
+
columns = {}
|
|
2616
|
+
if include_obs_no:
|
|
2617
|
+
columns["_n"] = pa.array([], type=pa.int64())
|
|
2618
|
+
for v in vars:
|
|
2619
|
+
columns[v] = pa.array([], type=pa.string())
|
|
2620
|
+
table = pa.table(columns)
|
|
2621
|
+
else:
|
|
2622
|
+
# Fetch all data in one C-call
|
|
2623
|
+
raw_data = Data.get(var=vars, obs=obs_list, valuelabel=False)
|
|
2624
|
+
|
|
2625
|
+
if use_polars:
|
|
2626
|
+
df = pl.DataFrame(raw_data, schema=vars, orient="row")
|
|
2627
|
+
if include_obs_no:
|
|
2628
|
+
obs_nums = [i + 1 for i in obs_list]
|
|
2629
|
+
df = df.with_columns(pl.Series("_n", obs_nums, dtype=pl.Int64))
|
|
2630
|
+
df = df.select(["_n"] + vars)
|
|
2631
|
+
table = df.to_arrow()
|
|
2632
|
+
else:
|
|
2633
|
+
df = pd.DataFrame(raw_data, columns=vars)
|
|
2634
|
+
if include_obs_no:
|
|
2635
|
+
df.insert(0, "_n", [i + 1 for i in obs_list])
|
|
2636
|
+
table = pa.Table.from_pandas(df, preserve_index=False)
|
|
2637
|
+
|
|
2638
|
+
# Serialize to IPC Stream
|
|
2639
|
+
sink = pa.BufferOutputStream()
|
|
2640
|
+
with pa.RecordBatchStreamWriter(sink, table.schema) as writer:
|
|
2641
|
+
writer.write_table(table)
|
|
2642
|
+
|
|
2643
|
+
return sink.getvalue().to_pybytes()
|
|
2644
|
+
|
|
2645
|
+
except Exception as e:
|
|
2646
|
+
raise RuntimeError(f"Failed to generate Arrow stream: {e}")
|
|
2647
|
+
|
|
2648
|
+
    _FILTER_IDENT = re.compile(r"\b[A-Za-z_][A-Za-z0-9_]*\b")

    def _extract_filter_vars(self, filter_expr: str) -> List[str]:
        tokens = set(self._FILTER_IDENT.findall(filter_expr or ""))
        # Exclude python keywords we might inject.
        exclude = {"and", "or", "not", "True", "False", "None"}
        var_map = self._get_var_index_map()
        vars_used = [t for t in tokens if t not in exclude and t in var_map]
        return sorted(vars_used)

    def _compile_filter_expr(self, filter_expr: str) -> Any:
        expr = (filter_expr or "").strip()
        if not expr:
            raise ValueError("Empty filter")

        # Stata boolean operators.
        expr = expr.replace("&", " and ").replace("|", " or ")

        # Replace missing literal '.' (but not numeric decimals like 0.5).
        expr = re.sub(r"(?<![0-9])\.(?![0-9A-Za-z_])", "None", expr)

        try:
            return compile(expr, "<filterExpr>", "eval")
        except Exception as e:
            raise ValueError(f"Invalid filter expression: {e}")

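    # Illustrative sketch of the filter translation above (hypothetical filter):
    #   "price > 5000 & rep78 != ."
    # is rewritten to the Python expression
    #   "price > 5000  and  rep78 != None"
    # before compilation, so Stata-style & / | and the bare missing-value dot map
    # onto Python's `and` / `or` / None when the filter is later evaluated per row.
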
def validate_filter_expr(self, filter_expr: str) -> None:
|
|
2675
|
+
if not self._initialized:
|
|
2676
|
+
self.init()
|
|
2677
|
+
state = self.get_dataset_state()
|
|
2678
|
+
if int(state.get("k", 0) or 0) == 0 and int(state.get("n", 0) or 0) == 0:
|
|
2679
|
+
raise RuntimeError("No data in memory")
|
|
2680
|
+
|
|
2681
|
+
vars_used = self._extract_filter_vars(filter_expr)
|
|
2682
|
+
if not vars_used:
|
|
2683
|
+
# still allow constant expressions like "1" or "True"
|
|
2684
|
+
self._compile_filter_expr(filter_expr)
|
|
2685
|
+
return
|
|
2686
|
+
self._compile_filter_expr(filter_expr)
|
|
2687
|
+
|
|
2688
|
+
def compute_view_indices(self, filter_expr: str, *, chunk_size: int = 5000) -> List[int]:
|
|
2689
|
+
if not self._initialized:
|
|
2690
|
+
self.init()
|
|
2691
|
+
|
|
2692
|
+
from sfi import Data # type: ignore[import-not-found]
|
|
2693
|
+
|
|
2694
|
+
state = self.get_dataset_state()
|
|
2695
|
+
n = int(state.get("n", 0) or 0)
|
|
2696
|
+
k = int(state.get("k", 0) or 0)
|
|
2697
|
+
if k == 0 and n == 0:
|
|
2698
|
+
raise RuntimeError("No data in memory")
|
|
2699
|
+
|
|
2700
|
+
vars_used = self._extract_filter_vars(filter_expr)
|
|
2701
|
+
code = self._compile_filter_expr(filter_expr)
|
|
2702
|
+
_ = self._get_var_index_map()
|
|
2703
|
+
|
|
2704
|
+
is_string_vars = []
|
|
2705
|
+
if vars_used:
|
|
2706
|
+
from sfi import Variable # type: ignore
|
|
2707
|
+
is_string_vars = [Variable.isString(v) for v in vars_used]
|
|
2708
|
+
|
|
2709
|
+
indices: List[int] = []
|
|
2710
|
+
for start in range(0, n, chunk_size):
|
|
2711
|
+
end = min(start + chunk_size, n)
|
|
2712
|
+
obs_list = list(range(start, end))
|
|
2713
|
+
raw_rows = Data.get(var=vars_used, obs=obs_list) if vars_used else [[None] for _ in obs_list]
|
|
2714
|
+
|
|
2715
|
+
# Try Rust optimization for the chunk
|
|
2716
|
+
if vars_used and raw_rows:
|
|
2717
|
+
# Transpose rows to columns for Rust
|
|
2718
|
+
cols = []
|
|
2719
|
+
# Extract columns
|
|
2720
|
+
for j in range(len(vars_used)):
|
|
2721
|
+
col_data_list = [row[j] for row in raw_rows]
|
|
2722
|
+
if not is_string_vars[j]:
|
|
2723
|
+
import numpy as np
|
|
2724
|
+
col_data = np.array(col_data_list, dtype=np.float64)
|
|
2725
|
+
else:
|
|
2726
|
+
col_data = col_data_list
|
|
2727
|
+
cols.append(col_data)
|
|
2728
|
+
|
|
2729
|
+
rust_indices = compute_filter_indices(filter_expr, vars_used, cols, is_string_vars)
|
|
2730
|
+
if rust_indices is not None:
|
|
2731
|
+
indices.extend([int(obs_list[i]) for i in rust_indices])
|
|
2732
|
+
continue
|
|
2733
|
+
|
|
2734
|
+
for row_i, obs in enumerate(obs_list):
|
|
2735
|
+
env: Dict[str, Any] = {}
|
|
2736
|
+
if vars_used:
|
|
2737
|
+
for j, v in enumerate(vars_used):
|
|
2738
|
+
val = raw_rows[row_i][j]
|
|
2739
|
+
env[v] = None if self._is_stata_missing(val) else val
|
|
2740
|
+
|
|
2741
|
+
ok = False
|
|
2742
|
+
try:
|
|
2743
|
+
ok = bool(eval(code, {"__builtins__": {}}, env))
|
|
2744
|
+
except NameError as e:
|
|
2745
|
+
raise ValueError(f"Invalid filter: {e}")
|
|
2746
|
+
except Exception as e:
|
|
2747
|
+
raise ValueError(f"Invalid filter: {e}")
|
|
2748
|
+
|
|
2749
|
+
if ok:
|
|
2750
|
+
indices.append(int(obs))
|
|
2751
|
+
|
|
2752
|
+
return indices
|
|
2753
|
+
|
|
2754
|
+
def apply_sort(self, sort_spec: List[str]) -> None:
|
|
2755
|
+
"""
|
|
2756
|
+
Apply sorting to the dataset using gsort.
|
|
2757
|
+
|
|
2758
|
+
Args:
|
|
2759
|
+
sort_spec: List of variables to sort by, with optional +/- prefix.
|
|
2760
|
+
e.g., ["-price", "+mpg"] sorts by price descending, then mpg ascending.
|
|
2761
|
+
No prefix is treated as ascending (+).
|
|
2762
|
+
|
|
2763
|
+
Raises:
|
|
2764
|
+
ValueError: If sort_spec is invalid or contains invalid variables
|
|
2765
|
+
RuntimeError: If no data in memory or sort command fails
|
|
2766
|
+
"""
|
|
2767
|
+
if not self._initialized:
|
|
2768
|
+
self.init()
|
|
2769
|
+
|
|
2770
|
+
state = self.get_dataset_state()
|
|
2771
|
+
if int(state.get("k", 0) or 0) == 0 and int(state.get("n", 0) or 0) == 0:
|
|
2772
|
+
raise RuntimeError("No data in memory")
|
|
2773
|
+
|
|
2774
|
+
if not sort_spec or not isinstance(sort_spec, list):
|
|
2775
|
+
raise ValueError("sort_spec must be a non-empty list")
|
|
2776
|
+
|
|
2777
|
+
# Validate all variables exist
|
|
2778
|
+
var_map = self._get_var_index_map()
|
|
2779
|
+
for spec in sort_spec:
|
|
2780
|
+
if not isinstance(spec, str) or not spec:
|
|
2781
|
+
raise ValueError(f"Invalid sort specification: {spec!r}")
|
|
2782
|
+
# Extract variable name (remove +/- prefix if present)
|
|
2783
|
+
varname = spec.lstrip("+-")
|
|
2784
|
+
if not varname:
|
|
2785
|
+
raise ValueError(f"Invalid sort specification: {spec!r}")
|
|
2786
|
+
|
|
2787
|
+
if varname not in var_map:
|
|
2788
|
+
raise ValueError(f"Variable not found: {varname}")
|
|
2789
|
+
|
|
2790
|
+
# Build gsort command
|
|
2791
|
+
# gsort uses - for descending, + or nothing for ascending
|
|
2792
|
+
gsort_args = []
|
|
2793
|
+
for spec in sort_spec:
|
|
2794
|
+
if spec.startswith("-") or spec.startswith("+"):
|
|
2795
|
+
gsort_args.append(spec)
|
|
2796
|
+
else:
|
|
2797
|
+
# No prefix means ascending, add + explicitly for clarity
|
|
2798
|
+
gsort_args.append(f"+{spec}")
|
|
2799
|
+
|
|
2800
|
+
cmd = f"gsort {' '.join(gsort_args)}"
|
|
2801
|
+
|
|
2802
|
+
try:
|
|
2803
|
+
# Sorting is hot-path for UI paging; use lightweight execution.
|
|
2804
|
+
result = self.exec_lightweight(cmd)
|
|
2805
|
+
if not result.success:
|
|
2806
|
+
error_msg = result.stderr or "Sort failed"
|
|
2807
|
+
raise RuntimeError(f"Failed to sort dataset: {error_msg}")
|
|
2808
|
+
except Exception as e:
|
|
2809
|
+
if isinstance(e, RuntimeError):
|
|
2810
|
+
raise
|
|
2811
|
+
raise RuntimeError(f"Failed to sort dataset: {e}")
|
|
2812
|
+
|
|
2813
|
+
def get_variable_details(self, varname: str) -> str:
|
|
2814
|
+
"""Returns codebook/summary for a specific variable."""
|
|
2815
|
+
resp = self.run_command_structured(f"codebook {varname}", echo=True)
|
|
2816
|
+
if resp.success:
|
|
2817
|
+
return resp.stdout
|
|
2818
|
+
if resp.error:
|
|
2819
|
+
return resp.error.message
|
|
2820
|
+
return ""
|
|
2821
|
+
|
|
2822
|
+
def list_variables_structured(self) -> VariablesResponse:
|
|
2823
|
+
vars_info: List[VariableInfo] = []
|
|
2824
|
+
for item in self.list_variables():
|
|
2825
|
+
vars_info.append(
|
|
2826
|
+
VariableInfo(
|
|
2827
|
+
name=item.get("name", ""),
|
|
2828
|
+
label=item.get("label"),
|
|
2829
|
+
type=item.get("type"),
|
|
2830
|
+
)
|
|
2831
|
+
)
|
|
2832
|
+
return VariablesResponse(variables=vars_info)
|
|
2833
|
+
|
|
2834
|
+
def list_graphs(self, *, force_refresh: bool = False) -> List[str]:
|
|
2835
|
+
"""Returns list of graphs in memory with TTL caching."""
|
|
2836
|
+
if not self._initialized:
|
|
2837
|
+
self.init()
|
|
2838
|
+
|
|
2839
|
+
import time
|
|
2840
|
+
|
|
2841
|
+
# Prevent recursive Stata calls - if we're already executing, return cached or empty
|
|
2842
|
+
if self._is_executing:
|
|
2843
|
+
with self._list_graphs_cache_lock:
|
|
2844
|
+
if self._list_graphs_cache is not None:
|
|
2845
|
+
logger.debug("Recursive list_graphs call prevented, returning cached value")
|
|
2846
|
+
return self._list_graphs_cache
|
|
2847
|
+
else:
|
|
2848
|
+
logger.debug("Recursive list_graphs call prevented, returning empty list")
|
|
2849
|
+
return []
|
|
2850
|
+
|
|
2851
|
+
# Check if cache is valid
|
|
2852
|
+
current_time = time.time()
|
|
2853
|
+
with self._list_graphs_cache_lock:
|
|
2854
|
+
if (not force_refresh and self._list_graphs_cache is not None and
|
|
2855
|
+
current_time - self._list_graphs_cache_time < self.LIST_GRAPHS_TTL):
|
|
2856
|
+
return self._list_graphs_cache
|
|
2857
|
+
|
|
2858
|
+
# Cache miss or expired, fetch fresh data
|
|
2859
|
+
with self._exec_lock:
|
|
2860
|
+
try:
|
|
2861
|
+
# Preservation of r() results is critical because this can be called
|
|
2862
|
+
# automatically after every user command (e.g., during streaming).
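# We stash the caller's r() with `_return hold` and restore it in the finally block below.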
|
|
2863
|
+
import time
|
|
2864
|
+
hold_name = f"_mcp_ghold_{int(time.time() * 1000 % 1000000)}"
|
|
2865
|
+
self.stata.run(f"capture _return hold {hold_name}", echo=False)
|
|
2866
|
+
|
|
2867
|
+
try:
|
|
2868
|
+
self.stata.run("macro define mcp_graph_list \"\"", echo=False)
|
|
2869
|
+
self.stata.run("quietly graph dir, memory", echo=False)
|
|
2870
|
+
from sfi import Macro # type: ignore[import-not-found]
|
|
2871
|
+
self.stata.run("macro define mcp_graph_list `r(list)'", echo=False)
|
|
2872
|
+
graph_list_str = Macro.getGlobal("mcp_graph_list")
|
|
2873
|
+
finally:
|
|
2874
|
+
self.stata.run(f"capture _return restore {hold_name}", echo=False)
|
|
2875
|
+
|
|
2876
|
+
raw_list = graph_list_str.split() if graph_list_str else []
|
|
2877
|
+
|
|
2878
|
+
# Map internal Stata names back to user-facing names when we have an alias.
|
|
2879
|
+
reverse = getattr(self, "_graph_name_reverse", {})
|
|
2880
|
+
graph_list = [reverse.get(n, n) for n in raw_list]
|
|
2881
|
+
|
|
2882
|
+
result = graph_list
|
|
2883
|
+
|
|
2884
|
+
# Update cache
|
|
2885
|
+
with self._list_graphs_cache_lock:
|
|
2886
|
+
self._list_graphs_cache = result
|
|
2887
|
+
self._list_graphs_cache_time = time.time()
|
|
2888
|
+
|
|
2889
|
+
return result
|
|
2890
|
+
|
|
2891
|
+
except Exception as e:
|
|
2892
|
+
# On error, return cached result if available, otherwise empty list
|
|
2893
|
+
with self._list_graphs_cache_lock:
|
|
2894
|
+
if self._list_graphs_cache is not None:
|
|
2895
|
+
logger.warning(f"list_graphs failed, returning cached result: {e}")
|
|
2896
|
+
return self._list_graphs_cache
|
|
2897
|
+
logger.warning(f"list_graphs failed, no cache available: {e}")
|
|
2898
|
+
return []
|
|
2899
|
+
|
|
2900
|
+
def list_graphs_structured(self) -> GraphListResponse:
|
|
2901
|
+
names = self.list_graphs()
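# Heuristic: treat the last name returned as the currently active graph.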
|
|
2902
|
+
active_name = names[-1] if names else None
|
|
2903
|
+
graphs = [GraphInfo(name=n, active=(n == active_name)) for n in names]
|
|
2904
|
+
return GraphListResponse(graphs=graphs)
|
|
2905
|
+
|
|
2906
|
+
def invalidate_list_graphs_cache(self) -> None:
|
|
2907
|
+
"""Invalidate the list_graphs cache to force fresh data on next call."""
|
|
2908
|
+
with self._list_graphs_cache_lock:
|
|
2909
|
+
self._list_graphs_cache = None
|
|
2910
|
+
self._list_graphs_cache_time = 0
|
|
2911
|
+
|
|
2912
|
+
def export_graph(self, graph_name: Optional[str] = None, filename: Optional[str] = None, format: str = "pdf") -> str:
|
|
2913
|
+
"""Exports graph to a temp file (pdf or png) and returns the path.
|
|
2914
|
+
|
|
2915
|
+
On Windows, PyStata can crash when exporting PNGs directly. For PNG on
|
|
2916
|
+
Windows, we save the graph to .gph and invoke the Stata executable in
|
|
2917
|
+
batch mode to export the PNG out-of-process.
|
|
2918
|
+
"""
|
|
2919
|
+
import tempfile
|
|
2920
|
+
|
|
2921
|
+
fmt = (format or "pdf").strip().lower()
|
|
2922
|
+
if fmt not in {"pdf", "png", "svg"}:
|
|
2923
|
+
raise ValueError(f"Unsupported graph export format: {format}. Allowed: pdf, png, svg.")
|
|
2924
|
+
|
|
2925
|
+
|
|
2926
|
+
if not filename:
|
|
2927
|
+
suffix = f".{fmt}"
|
|
2928
|
+
with tempfile.NamedTemporaryFile(prefix="mcp_stata_", suffix=suffix, delete=False) as tmp:
|
|
2929
|
+
filename = tmp.name
|
|
2930
|
+
else:
|
|
2931
|
+
# Ensure fresh start
|
|
2932
|
+
if os.path.exists(filename):
|
|
2933
|
+
try:
|
|
2934
|
+
os.remove(filename)
|
|
2935
|
+
except Exception:
|
|
2936
|
+
pass
|
|
2937
|
+
|
|
2938
|
+
# Keep the user-facing path as a normal absolute Windows path
|
|
2939
|
+
user_filename = os.path.abspath(filename)
|
|
2940
|
+
|
|
2941
|
+
if fmt == "png" and os.name == "nt":
|
|
2942
|
+
# 1) Save graph to a .gph file from the embedded session
|
|
2943
|
+
with tempfile.NamedTemporaryFile(prefix="mcp_stata_graph_", suffix=".gph", delete=False) as gph_tmp:
|
|
2944
|
+
gph_path = gph_tmp.name
|
|
2945
|
+
gph_path_for_stata = gph_path.replace("\\", "/")
|
|
2946
|
+
# Make the target graph current, then save without name() (which isn't accepted there)
|
|
2947
|
+
if graph_name:
|
|
2948
|
+
self._exec_no_capture_silent(f'quietly graph display "{graph_name}"', echo=False)
|
|
2949
|
+
save_cmd = f'quietly graph save "{gph_path_for_stata}", replace'
|
|
2950
|
+
save_resp = self._exec_no_capture_silent(save_cmd, echo=False)
|
|
2951
|
+
if not save_resp.success:
|
|
2952
|
+
msg = save_resp.error.message if save_resp.error else f"graph save failed (rc={save_resp.rc})"
|
|
2953
|
+
raise RuntimeError(msg)
|
|
2954
|
+
|
|
2955
|
+
# 2) Prepare a do-file to export PNG externally
|
|
2956
|
+
user_filename_fwd = user_filename.replace("\\", "/")
|
|
2957
|
+
do_lines = [
|
|
2958
|
+
f'quietly graph use "{gph_path_for_stata}"',
|
|
2959
|
+
f'quietly graph export "{user_filename_fwd}", replace as(png)',
|
|
2960
|
+
"exit",
|
|
2961
|
+
]
|
|
2962
|
+
with tempfile.NamedTemporaryFile(prefix="mcp_stata_export_", suffix=".do", delete=False, mode="w", encoding="ascii") as do_tmp:
|
|
2963
|
+
do_tmp.write("\n".join(do_lines))
|
|
2964
|
+
do_path = do_tmp.name
|
|
2965
|
+
|
|
2966
|
+
stata_exe = getattr(self, "_stata_exec_path", None)
|
|
2967
|
+
if not stata_exe or not os.path.exists(stata_exe):
|
|
2968
|
+
raise RuntimeError("Stata executable path unavailable for PNG export")
|
|
2969
|
+
|
|
2970
|
+
workdir = os.path.dirname(do_path) or None
|
|
2971
|
+
log_path = os.path.splitext(do_path)[0] + ".log"
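# "/e" runs Stata in batch mode: execute the do-file, write a plain-text .log next to it, then exit.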
|
|
2972
|
+
|
|
2973
|
+
cmd = [stata_exe, "/e", "do", do_path]
|
|
2974
|
+
try:
|
|
2975
|
+
completed = subprocess.run(
|
|
2976
|
+
cmd,
|
|
2977
|
+
capture_output=True,
|
|
2978
|
+
text=True,
|
|
2979
|
+
timeout=30,
|
|
2980
|
+
cwd=workdir,
|
|
2981
|
+
)
|
|
2982
|
+
except subprocess.TimeoutExpired:
|
|
2983
|
+
raise RuntimeError("External Stata export timed out")
|
|
2984
|
+
finally:
|
|
2985
|
+
try:
|
|
2986
|
+
os.remove(do_path)
|
|
2987
|
+
except Exception:
|
|
2988
|
+
# Ignore errors during temporary do-file cleanup (file may not exist or be locked)
|
|
2989
|
+
logger.warning("Failed to remove temporary do-file: %s", do_path, exc_info=True)
|
|
2990
|
+
|
|
2991
|
+
try:
|
|
2992
|
+
os.remove(gph_path)
|
|
2993
|
+
except Exception:
|
|
2994
|
+
logger.warning("Failed to remove temporary graph file: %s", gph_path, exc_info=True)
|
|
2995
|
+
|
|
2996
|
+
try:
|
|
2997
|
+
if os.path.exists(log_path):
|
|
2998
|
+
os.remove(log_path)
|
|
2999
|
+
except Exception:
|
|
3000
|
+
logger.warning("Failed to remove temporary log file: %s", log_path, exc_info=True)
|
|
3001
|
+
|
|
3002
|
+
if completed.returncode != 0:
|
|
3003
|
+
err = completed.stderr.strip() or completed.stdout.strip() or str(completed.returncode)
|
|
3004
|
+
raise RuntimeError(f"External Stata export failed: {err}")
|
|
3005
|
+
|
|
3006
|
+
else:
|
|
3007
|
+
# Stata prefers forward slashes in its command parser on Windows
|
|
3008
|
+
filename_for_stata = user_filename.replace("\\", "/")
|
|
3009
|
+
|
|
3010
|
+
if graph_name:
|
|
3011
|
+
resolved = self._resolve_graph_name_for_stata(graph_name)
|
|
3012
|
+
# Use display + export without name() for maximum compatibility.
|
|
3013
|
+
# name(NAME) often fails in PyStata for non-active graphs (r(693)).
|
|
3014
|
+
self._exec_no_capture_silent(f'quietly graph display "{resolved}"', echo=False)
|
|
3015
|
+
|
|
3016
|
+
cmd = f'quietly graph export "{filename_for_stata}", replace as({fmt})'
|
|
3017
|
+
|
|
3018
|
+
# Avoid stdout/stderr redirection for graph export because PyStata's
|
|
3019
|
+
# output thread can crash on Windows when we swap stdio handles.
|
|
3020
|
+
resp = self._exec_no_capture_silent(cmd, echo=False)
|
|
3021
|
+
if not resp.success:
|
|
3022
|
+
# Retry once after a short pause in case Stata had a transient file handle issue
|
|
3023
|
+
time.sleep(0.2)
|
|
3024
|
+
resp_retry = self._exec_no_capture_silent(cmd, echo=False)
|
|
3025
|
+
if not resp_retry.success:
|
|
3026
|
+
msg = resp_retry.error.message if resp_retry.error else f"graph export failed (rc={resp_retry.rc})"
|
|
3027
|
+
raise RuntimeError(msg)
|
|
3028
|
+
resp = resp_retry
|
|
3029
|
+
|
|
3030
|
+
if os.path.exists(user_filename):
|
|
3031
|
+
try:
|
|
3032
|
+
size = os.path.getsize(user_filename)
|
|
3033
|
+
if size == 0:
|
|
3034
|
+
raise RuntimeError(f"Graph export failed: produced empty file {user_filename}")
|
|
3035
|
+
if size > self.MAX_GRAPH_BYTES:
|
|
3036
|
+
raise RuntimeError(
|
|
3037
|
+
f"Graph export failed: file too large (> {self.MAX_GRAPH_BYTES} bytes): {user_filename}"
|
|
3038
|
+
)
|
|
3039
|
+
except Exception as size_err:
|
|
3040
|
+
# Clean up oversized or unreadable files
|
|
3041
|
+
try:
|
|
3042
|
+
os.remove(user_filename)
|
|
3043
|
+
except Exception:
|
|
3044
|
+
pass
|
|
3045
|
+
raise size_err
|
|
3046
|
+
return user_filename
|
|
3047
|
+
|
|
3048
|
+
# If file missing, it failed. Check output for details.
|
|
3049
|
+
msg = resp.error.message if resp is not None and resp.error else "graph export failed: file missing"
|
|
3050
|
+
raise RuntimeError(msg)
|
|
3051
|
+
|
|
3052
|
+
def get_help(self, topic: str, plain_text: bool = False) -> str:
|
|
3053
|
+
"""Returns help text as Markdown (default) or plain text."""
|
|
3054
|
+
if not self._initialized:
|
|
3055
|
+
self.init()
|
|
3056
|
+
|
|
3057
|
+
with self._exec_lock:
|
|
3058
|
+
# Try to locate the .sthlp help file
|
|
3059
|
+
# We use 'capture' to avoid crashing if not found
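# findfile leaves the resolved path in r(fn); with capture, a miss simply leaves the macro empty.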
|
|
3060
|
+
self.stata.run(f"capture findfile {topic}.sthlp")
|
|
3061
|
+
|
|
3062
|
+
# Retrieve the found path from r(fn)
|
|
3063
|
+
from sfi import Macro # type: ignore[import-not-found]
|
|
3064
|
+
self.stata.run("global mcp_help_file `r(fn)'")
|
|
3065
|
+
fn = Macro.getGlobal("mcp_help_file")
|
|
3066
|
+
|
|
3067
|
+
if fn and os.path.exists(fn):
|
|
3068
|
+
try:
|
|
3069
|
+
with open(fn, 'r', encoding='utf-8', errors='replace') as f:
|
|
3070
|
+
smcl = f.read()
|
|
3071
|
+
if plain_text:
|
|
3072
|
+
return self._smcl_to_text(smcl)
|
|
3073
|
+
try:
|
|
3074
|
+
return smcl_to_markdown(smcl, adopath=os.path.dirname(fn), current_file=os.path.splitext(os.path.basename(fn))[0])
|
|
3075
|
+
except Exception as parse_err:
|
|
3076
|
+
logger.warning("SMCL to Markdown failed, falling back to plain text: %s", parse_err)
|
|
3077
|
+
return self._smcl_to_text(smcl)
|
|
3078
|
+
except Exception as e:
|
|
3079
|
+
logger.warning("Help file read failed for %s: %s", topic, e)
|
|
3080
|
+
|
|
3081
|
+
# If no help file found, return a fallback message
|
|
3082
|
+
return f"Help file for '{topic}' not found."
|
|
3083
|
+
|
|
3084
|
+
def get_stored_results(self, force_fresh: bool = False) -> Dict[str, Any]:
|
|
3085
|
+
"""Returns e() and r() results using SFI for maximum reliability."""
|
|
3086
|
+
if not force_fresh and self._last_results is not None:
|
|
3087
|
+
return self._last_results
|
|
3088
|
+
|
|
3089
|
+
if not self._initialized:
|
|
3090
|
+
self.init()
|
|
3091
|
+
|
|
3092
|
+
with self._exec_lock:
|
|
3093
|
+
# We must be extremely careful not to clobber r()/e() while fetching their names.
|
|
3094
|
+
# We use a hold to peek at the results.
|
|
3095
|
+
hold_name = f"mcp_peek_{uuid.uuid4().hex[:8]}"
|
|
3096
|
+
self.stata.run(f"capture _return hold {hold_name}", echo=False)
|
|
3097
|
+
|
|
3098
|
+
try:
|
|
3099
|
+
from sfi import Scalar, Macro
|
|
3100
|
+
results = {"r": {}, "e": {}}
|
|
3101
|
+
|
|
3102
|
+
for rclass in ["r", "e"]:
|
|
3103
|
+
# Restore with 'hold' to peek at results without losing them from the hold
|
|
3104
|
+
# Note: Stata 18+ supports 'restore ..., hold' which is ideal.
|
|
3105
|
+
self.stata.run(f"capture _return restore {hold_name}, hold", echo=False)
|
|
3106
|
+
|
|
3107
|
+
# Fetch names using backtick expansion (which we verified works better than colon)
|
|
3108
|
+
# and avoid leading underscores which were causing syntax errors with 'global'
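# e.g. after `summarize`, `: r(scalars)' expands to names like "N sum_w mean Var sd min max sum".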
|
|
3109
|
+
self.stata.run(f"macro define mcp_scnames `: {rclass}(scalars)'", echo=False)
|
|
3110
|
+
self.stata.run(f"macro define mcp_macnames `: {rclass}(macros)'", echo=False)
|
|
3111
|
+
|
|
3112
|
+
# 1. Capture Scalars
|
|
3113
|
+
names_str = Macro.getGlobal("mcp_scnames")
|
|
3114
|
+
if names_str:
|
|
3115
|
+
for name in names_str.split():
|
|
3116
|
+
try:
|
|
3117
|
+
val = Scalar.getValue(f"{rclass}({name})")
|
|
3118
|
+
results[rclass][name] = val
|
|
3119
|
+
except Exception:
|
|
3120
|
+
pass
|
|
3121
|
+
|
|
3122
|
+
# 2. Capture Macros (strings)
|
|
3123
|
+
macros_str = Macro.getGlobal("mcp_macnames")
|
|
3124
|
+
if macros_str:
|
|
3125
|
+
for name in macros_str.split():
|
|
3126
|
+
try:
|
|
3127
|
+
# Restore/Hold again to be safe before fetching each macro
|
|
3128
|
+
self.stata.run(f"capture _return restore {hold_name}, hold", echo=False)
|
|
3129
|
+
# Capture the string value into a macro
|
|
3130
|
+
self.stata.run(f"macro define mcp_mval `{rclass}({name})'", echo=False)
|
|
3131
|
+
val = Macro.getGlobal("mcp_mval")
|
|
3132
|
+
results[rclass][name] = val
|
|
3133
|
+
except Exception:
|
|
3134
|
+
pass
|
|
3135
|
+
|
|
3136
|
+
# Cleanup
|
|
3137
|
+
self.stata.run("macro drop mcp_scnames mcp_macnames mcp_mval", echo=False)
|
|
3138
|
+
self.stata.run(f"capture _return restore {hold_name}", echo=False) # Restore one last time to leave Stata in correct state
|
|
3139
|
+
|
|
3140
|
+
self._last_results = results
|
|
3141
|
+
return results
|
|
3142
|
+
except Exception as e:
|
|
3143
|
+
logger.error(f"SFI-based get_stored_results failed: {e}")
|
|
3144
|
+
# Try to clean up hold if we failed
|
|
3145
|
+
try:
|
|
3146
|
+
self.stata.run(f"capture _return drop {hold_name}", echo=False)
|
|
3147
|
+
except Exception:
|
|
3148
|
+
pass
|
|
3149
|
+
return {"r": {}, "e": {}}
|
|
3150
|
+
|
|
3151
|
+
def invalidate_graph_cache(self, graph_name: Optional[str] = None) -> None:
|
|
3152
|
+
"""Invalidate cache for specific graph or all graphs.
|
|
3153
|
+
|
|
3154
|
+
Args:
|
|
3155
|
+
graph_name: Specific graph name to invalidate. If None, clears all cache.
|
|
3156
|
+
"""
|
|
3157
|
+
self._initialize_cache()
|
|
3158
|
+
|
|
3159
|
+
with self._cache_lock:
|
|
3160
|
+
if graph_name is None:
|
|
3161
|
+
# Clear all cache
|
|
3162
|
+
self._preemptive_cache.clear()
|
|
3163
|
+
else:
|
|
3164
|
+
# Clear specific graph cache
|
|
3165
|
+
if graph_name in self._preemptive_cache:
|
|
3166
|
+
del self._preemptive_cache[graph_name]
|
|
3167
|
+
# Also clear hash if present
|
|
3168
|
+
hash_key = f"{graph_name}_hash"
|
|
3169
|
+
if hash_key in self._preemptive_cache:
|
|
3170
|
+
del self._preemptive_cache[hash_key]
|
|
3171
|
+
|
|
3172
|
+
def _initialize_cache(self) -> None:
|
|
3173
|
+
"""Initialize cache in a thread-safe manner."""
|
|
3174
|
+
import tempfile
|
|
3175
|
+
import threading
|
|
3176
|
+
import os
|
|
3177
|
+
import uuid
|
|
3178
|
+
|
|
3179
|
+
with StataClient._cache_init_lock: # Use class-level lock
|
|
3180
|
+
if not hasattr(self, '_cache_initialized'):
|
|
3181
|
+
self._preemptive_cache = {}
|
|
3182
|
+
self._cache_access_times = {} # Track access times for LRU
|
|
3183
|
+
self._cache_sizes = {} # Track individual cache item sizes
|
|
3184
|
+
self._total_cache_size = 0 # Track total cache size in bytes
|
|
3185
|
+
# Use unique identifier to avoid conflicts
|
|
3186
|
+
unique_id = f"preemptive_cache_{uuid.uuid4().hex[:8]}_{os.getpid()}"
|
|
3187
|
+
self._preemptive_cache_dir = tempfile.mkdtemp(prefix=unique_id)
|
|
3188
|
+
self._cache_lock = threading.Lock()
|
|
3189
|
+
self._cache_initialized = True
|
|
3190
|
+
|
|
3191
|
+
# Register cleanup function
|
|
3192
|
+
import atexit
|
|
3193
|
+
atexit.register(self._cleanup_cache)
|
|
3194
|
+
else:
|
|
3195
|
+
# Cache already initialized, but directory might have been removed.
|
|
3196
|
+
if (not hasattr(self, '_preemptive_cache_dir') or
|
|
3197
|
+
not self._preemptive_cache_dir or
|
|
3198
|
+
not os.path.isdir(self._preemptive_cache_dir)):
|
|
3199
|
+
unique_id = f"preemptive_cache_{uuid.uuid4().hex[:8]}_{os.getpid()}"
|
|
3200
|
+
self._preemptive_cache_dir = tempfile.mkdtemp(prefix=unique_id)
|
|
3201
|
+
|
|
3202
|
+
def _cleanup_cache(self) -> None:
|
|
3203
|
+
"""Clean up cache directory and files."""
|
|
3204
|
+
import os
|
|
3205
|
+
import shutil
|
|
3206
|
+
|
|
3207
|
+
if hasattr(self, '_preemptive_cache_dir') and self._preemptive_cache_dir:
|
|
3208
|
+
try:
|
|
3209
|
+
shutil.rmtree(self._preemptive_cache_dir, ignore_errors=True)
|
|
3210
|
+
except Exception:
|
|
3211
|
+
pass # Best effort cleanup
|
|
3212
|
+
|
|
3213
|
+
if hasattr(self, '_preemptive_cache'):
|
|
3214
|
+
self._preemptive_cache.clear()
|
|
3215
|
+
|
|
3216
|
+
def _evict_cache_if_needed(self, new_item_size: int = 0) -> None:
|
|
3217
|
+
"""
|
|
3218
|
+
Evict least recently used cache items if cache size limits are exceeded.
|
|
3219
|
+
|
|
3220
|
+
NOTE: The caller is responsible for holding ``self._cache_lock`` while
|
|
3221
|
+
invoking this method, so that eviction and subsequent cache insertion
|
|
3222
|
+
(if any) occur within a single critical section.
|
|
3223
|
+
"""
|
|
3224
|
+
import time
|
|
3225
|
+
|
|
3226
|
+
# Check if we need to evict based on count or size
|
|
3227
|
+
needs_eviction = (
|
|
3228
|
+
len(self._preemptive_cache) > StataClient.MAX_CACHE_SIZE or
|
|
3229
|
+
self._total_cache_size + new_item_size > StataClient.MAX_CACHE_BYTES
|
|
3230
|
+
)
|
|
3231
|
+
|
|
3232
|
+
if not needs_eviction:
|
|
3233
|
+
return
|
|
3234
|
+
|
|
3235
|
+
# Sort by access time (oldest first)
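# Classic LRU: evict the least recently used graphs until both the count and byte limits are satisfied.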
|
|
3236
|
+
items_by_access = sorted(
|
|
3237
|
+
self._cache_access_times.items(),
|
|
3238
|
+
key=lambda x: x[1]
|
|
3239
|
+
)
|
|
3240
|
+
|
|
3241
|
+
evicted_count = 0
|
|
3242
|
+
for graph_name, access_time in items_by_access:
|
|
3243
|
+
if (len(self._preemptive_cache) < StataClient.MAX_CACHE_SIZE and
|
|
3244
|
+
self._total_cache_size + new_item_size <= StataClient.MAX_CACHE_BYTES):
|
|
3245
|
+
break
|
|
3246
|
+
|
|
3247
|
+
# Remove from cache
|
|
3248
|
+
if graph_name in self._preemptive_cache:
|
|
3249
|
+
cache_path = self._preemptive_cache[graph_name]
|
|
3250
|
+
|
|
3251
|
+
# Remove file
|
|
3252
|
+
try:
|
|
3253
|
+
if os.path.exists(cache_path):
|
|
3254
|
+
os.remove(cache_path)
|
|
3255
|
+
except Exception:
|
|
3256
|
+
pass
|
|
3257
|
+
|
|
3258
|
+
# Update tracking
|
|
3259
|
+
item_size = self._cache_sizes.get(graph_name, 0)
|
|
3260
|
+
del self._preemptive_cache[graph_name]
|
|
3261
|
+
del self._cache_access_times[graph_name]
|
|
3262
|
+
if graph_name in self._cache_sizes:
|
|
3263
|
+
del self._cache_sizes[graph_name]
|
|
3264
|
+
self._total_cache_size -= item_size
|
|
3265
|
+
evicted_count += 1
|
|
3266
|
+
|
|
3267
|
+
# Remove hash entry if exists
|
|
3268
|
+
hash_key = f"{graph_name}_hash"
|
|
3269
|
+
if hash_key in self._preemptive_cache:
|
|
3270
|
+
del self._preemptive_cache[hash_key]
|
|
3271
|
+
|
|
3272
|
+
if evicted_count > 0:
|
|
3273
|
+
logger.debug(f"Evicted {evicted_count} items from graph cache due to size limits")
|
|
3274
|
+
|
|
3275
|
+
def _get_content_hash(self, data: bytes) -> str:
|
|
3276
|
+
"""Generate content hash for cache validation."""
|
|
3277
|
+
import hashlib
|
|
3278
|
+
return hashlib.md5(data).hexdigest()
|
|
3279
|
+
|
|
3280
|
+
def _sanitize_filename(self, name: str) -> str:
|
|
3281
|
+
"""Sanitize graph name for safe file system usage."""
|
|
3282
|
+
import re
|
|
3283
|
+
# Remove or replace problematic characters
|
|
3284
|
+
safe_name = re.sub(r'[<>:"/\\|?*]', '_', name)
|
|
3285
|
+
safe_name = re.sub(r'[^\w\-_.]', '_', safe_name)
|
|
3286
|
+
# Limit length
|
|
3287
|
+
return safe_name[:100] if len(safe_name) > 100 else safe_name
|
|
3288
|
+
|
|
3289
|
+
def _validate_graph_exists(self, graph_name: str) -> bool:
|
|
3290
|
+
"""Validate that graph still exists in Stata."""
|
|
3291
|
+
try:
|
|
3292
|
+
# First try to get graph list to verify existence
|
|
3293
|
+
graph_list = self.list_graphs(force_refresh=True)
|
|
3294
|
+
if graph_name not in graph_list:
|
|
3295
|
+
return False
|
|
3296
|
+
|
|
3297
|
+
# Additional validation by attempting to display the graph
|
|
3298
|
+
resolved = self._resolve_graph_name_for_stata(graph_name)
|
|
3299
|
+
cmd = f'quietly graph display {resolved}'
|
|
3300
|
+
resp = self._exec_no_capture_silent(cmd, echo=False)
|
|
3301
|
+
return resp.success
|
|
3302
|
+
except Exception:
|
|
3303
|
+
return False
|
|
3304
|
+
|
|
3305
|
+
def _is_cache_valid(self, graph_name: str, cache_path: str) -> bool:
|
|
3306
|
+
"""Check if cached content is still valid using internal signatures."""
|
|
3307
|
+
try:
|
|
3308
|
+
if not os.path.exists(cache_path) or os.path.getsize(cache_path) == 0:
|
|
3309
|
+
return False
|
|
3310
|
+
|
|
3311
|
+
current_sig = self._get_graph_signature(graph_name)
|
|
3312
|
+
cached_sig = self._preemptive_cache.get(f"{graph_name}_sig")
|
|
3313
|
+
|
|
3314
|
+
# If we have a signature match, it's valid for the current command session
|
|
3315
|
+
if cached_sig and cached_sig == current_sig:
|
|
3316
|
+
return True
|
|
3317
|
+
|
|
3318
|
+
# Otherwise it's invalid (needs refresh for new command)
|
|
3319
|
+
return False
|
|
3320
|
+
except Exception:
|
|
3321
|
+
return False
|
|
3322
|
+
|
|
3323
|
+
def export_graphs_all(self) -> GraphExportResponse:
|
|
3324
|
+
"""Exports all graphs to file paths."""
|
|
3325
|
+
exports: List[GraphExport] = []
|
|
3326
|
+
graph_names = self.list_graphs(force_refresh=True)
|
|
3327
|
+
|
|
3328
|
+
if not graph_names:
|
|
3329
|
+
return GraphExportResponse(graphs=exports)
|
|
3330
|
+
|
|
3331
|
+
import tempfile
|
|
3332
|
+
import os
|
|
3333
|
+
import threading
|
|
3334
|
+
import uuid
|
|
3335
|
+
import time
|
|
3336
|
+
import logging
|
|
3337
|
+
|
|
3338
|
+
# Initialize cache in thread-safe manner
|
|
3339
|
+
self._initialize_cache()
|
|
3340
|
+
|
|
3341
|
+
def _cache_keyed_svg_path(name: str) -> str:
|
|
3342
|
+
import hashlib
|
|
3343
|
+
safe_name = self._sanitize_filename(name)
|
|
3344
|
+
suffix = hashlib.md5((name or "").encode("utf-8")).hexdigest()[:8]
|
|
3345
|
+
return os.path.join(self._preemptive_cache_dir, f"{safe_name}_{suffix}.svg")
|
|
3346
|
+
|
|
3347
|
+
def _export_svg_bytes(name: str) -> bytes:
|
|
3348
|
+
resolved = self._resolve_graph_name_for_stata(name)
|
|
3349
|
+
|
|
3350
|
+
temp_dir = tempfile.gettempdir()
|
|
3351
|
+
safe_temp_name = self._sanitize_filename(name)
|
|
3352
|
+
unique_filename = f"{safe_temp_name}_{uuid.uuid4().hex[:8]}_{os.getpid()}_{int(time.time())}.svg"
|
|
3353
|
+
svg_path = os.path.join(temp_dir, unique_filename)
|
|
3354
|
+
svg_path_for_stata = svg_path.replace("\\", "/")
|
|
3355
|
+
|
|
3356
|
+
try:
|
|
3357
|
+
export_cmd = f'quietly graph export "{svg_path_for_stata}", name({resolved}) replace as(svg)'
|
|
3358
|
+
export_resp = self._exec_no_capture_silent(export_cmd, echo=False)
|
|
3359
|
+
|
|
3360
|
+
if not export_resp.success:
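# Fallback: make the graph current with `graph display`, then export without name(),
# which PyStata accepts more reliably for non-active graphs.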
|
|
3361
|
+
display_cmd = f'quietly graph display {resolved}'
|
|
3362
|
+
display_resp = self._exec_no_capture_silent(display_cmd, echo=False)
|
|
3363
|
+
if display_resp.success:
|
|
3364
|
+
export_cmd2 = f'quietly graph export "{svg_path_for_stata}", replace as(svg)'
|
|
3365
|
+
export_resp = self._exec_no_capture_silent(export_cmd2, echo=False)
|
|
3366
|
+
else:
|
|
3367
|
+
export_resp = display_resp
|
|
3368
|
+
|
|
3369
|
+
if export_resp.success and os.path.exists(svg_path) and os.path.getsize(svg_path) > 0:
|
|
3370
|
+
with open(svg_path, "rb") as f:
|
|
3371
|
+
return f.read()
|
|
3372
|
+
error_msg = getattr(export_resp, 'error', 'Unknown error')
|
|
3373
|
+
raise RuntimeError(f"Failed to export graph {name}: {error_msg}")
|
|
3374
|
+
finally:
|
|
3375
|
+
if os.path.exists(svg_path):
|
|
3376
|
+
try:
|
|
3377
|
+
os.remove(svg_path)
|
|
3378
|
+
except OSError as e:
|
|
3379
|
+
logger.warning(f"Failed to cleanup temp file {svg_path}: {e}")
|
|
3380
|
+
|
|
3381
|
+
cached_graphs = {}
|
|
3382
|
+
uncached_graphs = []
|
|
3383
|
+
cache_errors = []
|
|
3384
|
+
|
|
3385
|
+
with self._cache_lock:
|
|
3386
|
+
for name in graph_names:
|
|
3387
|
+
if name in self._preemptive_cache:
|
|
3388
|
+
cached_path = self._preemptive_cache[name]
|
|
3389
|
+
if os.path.exists(cached_path) and os.path.getsize(cached_path) > 0:
|
|
3390
|
+
# Additional validation: check if graph content has changed
|
|
3391
|
+
if self._is_cache_valid(name, cached_path):
|
|
3392
|
+
cached_graphs[name] = cached_path
|
|
3393
|
+
else:
|
|
3394
|
+
uncached_graphs.append(name)
|
|
3395
|
+
# Remove stale cache entry
|
|
3396
|
+
del self._preemptive_cache[name]
|
|
3397
|
+
else:
|
|
3398
|
+
uncached_graphs.append(name)
|
|
3399
|
+
# Remove invalid cache entry
|
|
3400
|
+
if name in self._preemptive_cache:
|
|
3401
|
+
del self._preemptive_cache[name]
|
|
3402
|
+
else:
|
|
3403
|
+
uncached_graphs.append(name)
|
|
3404
|
+
|
|
3405
|
+
for name, cached_path in cached_graphs.items():
|
|
3406
|
+
try:
|
|
3407
|
+
exports.append(GraphExport(name=name, file_path=cached_path))
|
|
3408
|
+
except Exception as e:
|
|
3409
|
+
cache_errors.append(f"Failed to read cached graph {name}: {e}")
|
|
3410
|
+
# Fall back to uncached processing
|
|
3411
|
+
uncached_graphs.append(name)
|
|
3412
|
+
|
|
3413
|
+
if uncached_graphs:
|
|
3414
|
+
successful_graphs = []
|
|
3415
|
+
failed_graphs = []
|
|
3416
|
+
memory_results = {}
|
|
3417
|
+
|
|
3418
|
+
for name in uncached_graphs:
|
|
3419
|
+
try:
|
|
3420
|
+
svg_data = _export_svg_bytes(name)
|
|
3421
|
+
memory_results[name] = svg_data
|
|
3422
|
+
successful_graphs.append(name)
|
|
3423
|
+
except Exception as e:
|
|
3424
|
+
failed_graphs.append(name)
|
|
3425
|
+
cache_errors.append(f"Failed to cache graph {name}: {e}")
|
|
3426
|
+
|
|
3427
|
+
for name in successful_graphs:
|
|
3428
|
+
result = memory_results[name]
|
|
3429
|
+
|
|
3430
|
+
cache_path = _cache_keyed_svg_path(name)
|
|
3431
|
+
|
|
3432
|
+
try:
|
|
3433
|
+
with open(cache_path, 'wb') as f:
|
|
3434
|
+
f.write(result)
|
|
3435
|
+
|
|
3436
|
+
# Update cache with size tracking and eviction
|
|
3437
|
+
import time
|
|
3438
|
+
item_size = len(result)
|
|
3439
|
+
self._evict_cache_if_needed(item_size)
|
|
3440
|
+
|
|
3441
|
+
with self._cache_lock:
|
|
3442
|
+
self._preemptive_cache[name] = cache_path
|
|
3443
|
+
# Store content hash for validation
|
|
3444
|
+
self._preemptive_cache[f"{name}_hash"] = self._get_content_hash(result)
|
|
3445
|
+
# Update tracking
|
|
3446
|
+
self._cache_access_times[name] = time.time()
|
|
3447
|
+
self._cache_sizes[name] = item_size
|
|
3448
|
+
self._total_cache_size += item_size
|
|
3449
|
+
|
|
3450
|
+
exports.append(GraphExport(name=name, file_path=cache_path))
|
|
3451
|
+
except Exception as e:
|
|
3452
|
+
cache_errors.append(f"Failed to cache graph {name}: {e}")
|
|
3453
|
+
# Still return the result even if caching fails
|
|
3454
|
+
# Create temp file for immediate use
|
|
3455
|
+
safe_name = self._sanitize_filename(name)
|
|
3456
|
+
temp_path = os.path.join(tempfile.gettempdir(), f"{safe_name}_{uuid.uuid4().hex[:8]}.svg")
|
|
3457
|
+
with open(temp_path, 'wb') as f:
|
|
3458
|
+
f.write(result)
|
|
3459
|
+
exports.append(GraphExport(name=name, file_path=temp_path))
|
|
3460
|
+
|
|
3461
|
+
# Log errors if any occurred
|
|
3462
|
+
if cache_errors:
|
|
3463
|
+
logger = logging.getLogger(__name__)
|
|
3464
|
+
for error in cache_errors:
|
|
3465
|
+
logger.warning(error)
|
|
3466
|
+
|
|
3467
|
+
return GraphExportResponse(graphs=exports)
|
|
3468
|
+
|
|
3469
|
+
def cache_graph_on_creation(self, graph_name: str) -> bool:
|
|
3470
|
+
"""Revolutionary method to cache a graph immediately after creation.
|
|
3471
|
+
|
|
3472
|
+
Call this method right after creating a graph to pre-emptively cache it.
|
|
3473
|
+
This avoids the export latency on subsequent accesses.
|
|
3474
|
+
|
|
3475
|
+
Args:
|
|
3476
|
+
graph_name: Name of the graph to cache
|
|
3477
|
+
|
|
3478
|
+
Returns:
|
|
3479
|
+
True if caching succeeded, False otherwise
|
|
3480
|
+
"""
|
|
3481
|
+
import os
|
|
3482
|
+
import logging
|
|
3483
|
+
|
|
3484
|
+
# Initialize cache in thread-safe manner
|
|
3485
|
+
self._initialize_cache()
|
|
3486
|
+
|
|
3487
|
+
# Invalidate list_graphs cache since a new graph was created
|
|
3488
|
+
self.invalidate_list_graphs_cache()
|
|
3489
|
+
|
|
3490
|
+
# Check if already cached and valid
|
|
3491
|
+
with self._cache_lock:
|
|
3492
|
+
if graph_name in self._preemptive_cache:
|
|
3493
|
+
cache_path = self._preemptive_cache[graph_name]
|
|
3494
|
+
if os.path.exists(cache_path) and os.path.getsize(cache_path) > 0:
|
|
3495
|
+
if self._is_cache_valid(graph_name, cache_path):
|
|
3496
|
+
# Update access time for LRU
|
|
3497
|
+
import time
|
|
3498
|
+
self._cache_access_times[graph_name] = time.time()
|
|
3499
|
+
return True
|
|
3500
|
+
else:
|
|
3501
|
+
# Remove stale cache entry
|
|
3502
|
+
del self._preemptive_cache[graph_name]
|
|
3503
|
+
if graph_name in self._cache_access_times:
|
|
3504
|
+
del self._cache_access_times[graph_name]
|
|
3505
|
+
if graph_name in self._cache_sizes:
|
|
3506
|
+
self._total_cache_size -= self._cache_sizes[graph_name]
|
|
3507
|
+
del self._cache_sizes[graph_name]
|
|
3508
|
+
# Remove hash entry if exists
|
|
3509
|
+
hash_key = f"{graph_name}_hash"
|
|
3510
|
+
if hash_key in self._preemptive_cache:
|
|
3511
|
+
del self._preemptive_cache[hash_key]
|
|
3512
|
+
|
|
3513
|
+
try:
|
|
3514
|
+
# Include signature in filename to force client-side refresh
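# A changed signature produces a new cache filename, so clients keyed on the old path cannot serve stale SVGs.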
|
|
3515
|
+
import hashlib
|
|
3516
|
+
sig = self._get_graph_signature(graph_name)
|
|
3517
|
+
safe_name = self._sanitize_filename(sig)
|
|
3518
|
+
suffix = hashlib.md5((sig or "").encode("utf-8")).hexdigest()[:8]
|
|
3519
|
+
cache_path = os.path.join(self._preemptive_cache_dir, f"{safe_name}_{suffix}.svg")
|
|
3520
|
+
cache_path_for_stata = cache_path.replace("\\", "/")
|
|
3521
|
+
|
|
3522
|
+
resolved_graph_name = self._resolve_graph_name_for_stata(graph_name)
|
|
3523
|
+
# Use display + export without name() for maximum compatibility.
|
|
3524
|
+
# name(NAME) often fails in PyStata for non-active graphs (r(693)).
|
|
3525
|
+
# Quoting the name helps with spaces/special characters.
|
|
3526
|
+
display_cmd = f'quietly graph display "{resolved_graph_name}"'
|
|
3527
|
+
self._exec_no_capture_silent(display_cmd, echo=False)
|
|
3528
|
+
|
|
3529
|
+
export_cmd = f'quietly graph export "{cache_path_for_stata}", replace as(svg)'
|
|
3530
|
+
resp = self._exec_no_capture_silent(export_cmd, echo=False)
|
|
3531
|
+
|
|
3532
|
+
if resp.success and os.path.exists(cache_path) and os.path.getsize(cache_path) > 0:
|
|
3533
|
+
# Read the data to compute hash
|
|
3534
|
+
with open(cache_path, 'rb') as f:
|
|
3535
|
+
data = f.read()
|
|
3536
|
+
|
|
3537
|
+
# Update cache with size tracking and eviction
|
|
3538
|
+
import time
|
|
3539
|
+
item_size = len(data)
|
|
3540
|
+
self._evict_cache_if_needed(item_size)
|
|
3541
|
+
|
|
3542
|
+
with self._cache_lock:
|
|
3543
|
+
# Clear any old versions of this graph from the path cache
|
|
3544
|
+
# (Optional but keeps it clean)
|
|
3545
|
+
old_path = self._preemptive_cache.get(graph_name)
|
|
3546
|
+
if old_path and old_path != cache_path:
|
|
3547
|
+
try:
|
|
3548
|
+
os.remove(old_path)
|
|
3549
|
+
except Exception:
|
|
3550
|
+
pass
|
|
3551
|
+
|
|
3552
|
+
self._preemptive_cache[graph_name] = cache_path
|
|
3553
|
+
# Store content hash for validation
|
|
3554
|
+
self._preemptive_cache[f"{graph_name}_hash"] = self._get_content_hash(data)
|
|
3555
|
+
# Store signature for fast validation
|
|
3556
|
+
self._preemptive_cache[f"{graph_name}_sig"] = self._get_graph_signature(graph_name)
|
|
3557
|
+
# Update tracking
|
|
3558
|
+
self._cache_access_times[graph_name] = time.time()
|
|
3559
|
+
self._cache_sizes[graph_name] = item_size
|
|
3560
|
+
self._total_cache_size += item_size
|
|
3561
|
+
|
|
3562
|
+
return True
|
|
3563
|
+
else:
|
|
3564
|
+
error_msg = getattr(resp, 'error', 'Unknown error')
|
|
3565
|
+
logger = logging.getLogger(__name__)
|
|
3566
|
+
logger.warning(f"Failed to cache graph {graph_name}: {error_msg}")
|
|
3567
|
+
|
|
3568
|
+
except Exception as e:
|
|
3569
|
+
logger = logging.getLogger(__name__)
|
|
3570
|
+
logger.warning(f"Exception caching graph {graph_name}: {e}")
|
|
3571
|
+
|
|
3572
|
+
return False
|
|
3573
|
+
|
|
3574
|
+
def run_do_file(self, path: str, echo: bool = True, trace: bool = False, max_output_lines: Optional[int] = None, cwd: Optional[str] = None) -> CommandResponse:
|
|
3575
|
+
effective_path, command, error_response = self._resolve_do_file_path(path, cwd)
|
|
3576
|
+
if error_response is not None:
|
|
3577
|
+
return error_response
|
|
3578
|
+
|
|
3579
|
+
if not self._initialized:
|
|
3580
|
+
self.init()
|
|
3581
|
+
|
|
3582
|
+
start_time = time.time()
|
|
3583
|
+
exc: Optional[Exception] = None
|
|
3584
|
+
smcl_content = ""
|
|
3585
|
+
smcl_path = None
|
|
3586
|
+
|
|
3587
|
+
_log_file, log_path, tail, tee = self._create_streaming_log(trace=trace)
|
|
3588
|
+
base_dir = cwd or os.path.dirname(effective_path)
|
|
3589
|
+
smcl_path = self._create_smcl_log_path(base_dir=base_dir)
|
|
3590
|
+
smcl_log_name = self._make_smcl_log_name()
|
|
3591
|
+
|
|
3592
|
+
rc = -1
|
|
3593
|
+
try:
|
|
3594
|
+
rc, exc = self._run_streaming_blocking(
|
|
3595
|
+
command=command,
|
|
3596
|
+
tee=tee,
|
|
3597
|
+
cwd=cwd,
|
|
3598
|
+
trace=trace,
|
|
3599
|
+
echo=echo,
|
|
3600
|
+
smcl_path=smcl_path,
|
|
3601
|
+
smcl_log_name=smcl_log_name,
|
|
3602
|
+
hold_attr="_hold_name_do_sync",
|
|
3603
|
+
require_smcl_log=True,
|
|
3604
|
+
)
|
|
3605
|
+
except Exception as e:
|
|
3606
|
+
exc = e
|
|
3607
|
+
rc = 1
|
|
3608
|
+
finally:
|
|
3609
|
+
tee.close()
|
|
3610
|
+
|
|
3611
|
+
# Read SMCL content as the authoritative source
|
|
3612
|
+
smcl_content = self._read_smcl_file(smcl_path)
|
|
3613
|
+
|
|
3614
|
+
combined = self._build_combined_log(tail, log_path, rc, trace, exc)
|
|
3615
|
+
|
|
3616
|
+
# Use SMCL content as primary source for RC detection if not already captured
|
|
3617
|
+
if rc == -1 and not exc:
|
|
3618
|
+
parsed_rc = self._parse_rc_from_smcl(smcl_content)
|
|
3619
|
+
if parsed_rc is not None:
|
|
3620
|
+
rc = parsed_rc
|
|
3621
|
+
else:
|
|
3622
|
+
# Fallback to text parsing
|
|
3623
|
+
parsed_rc = self._parse_rc_from_text(combined)
|
|
3624
|
+
rc = parsed_rc if parsed_rc is not None else 0
|
|
3625
|
+
elif exc and rc == 1:
|
|
3626
|
+
# Try to parse more specific RC from exception message
|
|
3627
|
+
parsed_rc = self._parse_rc_from_text(str(exc))
|
|
3628
|
+
if parsed_rc is not None:
|
|
3629
|
+
rc = parsed_rc
|
|
3630
|
+
|
|
3631
|
+
success = (rc == 0 and exc is None)
|
|
3632
|
+
error = None
|
|
3633
|
+
|
|
3634
|
+
if not success:
|
|
3635
|
+
# Use SMCL as authoritative source for error extraction
|
|
3636
|
+
if smcl_content:
|
|
3637
|
+
msg, context = self._extract_error_from_smcl(smcl_content, rc)
|
|
3638
|
+
else:
|
|
3639
|
+
# Fallback to combined log
|
|
3640
|
+
msg, context = self._extract_error_and_context(combined, rc)
|
|
3641
|
+
|
|
3642
|
+
error = ErrorEnvelope(
|
|
3643
|
+
message=msg,
|
|
3644
|
+
rc=rc,
|
|
3645
|
+
snippet=context,
|
|
3646
|
+
command=command,
|
|
3647
|
+
log_path=log_path,
|
|
3648
|
+
smcl_output=smcl_content,
|
|
3649
|
+
)
|
|
3650
|
+
|
|
3651
|
+
duration = time.time() - start_time
|
|
3652
|
+
logger.info(
|
|
3653
|
+
"stata.run(do) rc=%s success=%s trace=%s duration_ms=%.2f path=%s",
|
|
3654
|
+
rc,
|
|
3655
|
+
success,
|
|
3656
|
+
trace,
|
|
3657
|
+
duration * 1000,
|
|
3658
|
+
effective_path,
|
|
3659
|
+
)
|
|
3660
|
+
|
|
3661
|
+
return CommandResponse(
|
|
3662
|
+
command=command,
|
|
3663
|
+
rc=rc,
|
|
3664
|
+
stdout="",
|
|
3665
|
+
stderr=None,
|
|
3666
|
+
log_path=log_path,
|
|
3667
|
+
success=success,
|
|
3668
|
+
error=error,
|
|
3669
|
+
smcl_output=smcl_content,
|
|
3670
|
+
)
|
|
3671
|
+
|
|
3672
|
+
def load_data(self, source: str, clear: bool = True, max_output_lines: Optional[int] = None) -> CommandResponse:
|
|
3673
|
+
src = source.strip()
|
|
3674
|
+
clear_suffix = ", clear" if clear else ""
|
|
3675
|
+
|
|
3676
|
+
if src.startswith("sysuse "):
|
|
3677
|
+
cmd = f"{src}{clear_suffix}"
|
|
3678
|
+
elif src.startswith("webuse "):
|
|
3679
|
+
cmd = f"{src}{clear_suffix}"
|
|
3680
|
+
elif src.startswith("use "):
|
|
3681
|
+
cmd = f"{src}{clear_suffix}"
|
|
3682
|
+
elif "://" in src or src.endswith(".dta") or os.path.sep in src:
|
|
3683
|
+
cmd = f'use "{src}"{clear_suffix}'
|
|
3684
|
+
else:
|
|
3685
|
+
cmd = f"sysuse {src}{clear_suffix}"
|
|
3686
|
+
|
|
3687
|
+
result = self._exec_with_capture(cmd, echo=True, trace=False)
|
|
3688
|
+
return self._truncate_command_output(result, max_output_lines)
|
|
3689
|
+
|
|
3690
|
+
def codebook(self, varname: str, trace: bool = False, max_output_lines: Optional[int] = None) -> CommandResponse:
|
|
3691
|
+
result = self._exec_with_capture(f"codebook {varname}", trace=trace)
|
|
3692
|
+
return self._truncate_command_output(result, max_output_lines)
|