mcp-stata 1.2.2__py3-none-any.whl → 1.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcp-stata might be problematic. Click here for more details.
- mcp_stata/discovery.py +243 -54
- mcp_stata/graph_detector.py +385 -0
- mcp_stata/models.py +4 -1
- mcp_stata/server.py +265 -44
- mcp_stata/stata_client.py +2114 -263
- mcp_stata/streaming_io.py +261 -0
- mcp_stata/ui_http.py +559 -0
- mcp_stata-1.6.8.dist-info/METADATA +388 -0
- mcp_stata-1.6.8.dist-info/RECORD +14 -0
- mcp_stata-1.2.2.dist-info/METADATA +0 -240
- mcp_stata-1.2.2.dist-info/RECORD +0 -11
- {mcp_stata-1.2.2.dist-info → mcp_stata-1.6.8.dist-info}/WHEEL +0 -0
- {mcp_stata-1.2.2.dist-info → mcp_stata-1.6.8.dist-info}/entry_points.txt +0 -0
- {mcp_stata-1.2.2.dist-info → mcp_stata-1.6.8.dist-info}/licenses/LICENSE +0 -0
mcp_stata/stata_client.py
CHANGED
|
@@ -1,17 +1,22 @@
|
|
|
1
|
-
import sys
|
|
2
|
-
import os
|
|
3
|
-
import json
|
|
4
|
-
import re
|
|
5
1
|
import base64
|
|
2
|
+
import json
|
|
6
3
|
import logging
|
|
4
|
+
import os
|
|
5
|
+
import re
|
|
6
|
+
import subprocess
|
|
7
|
+
import sys
|
|
7
8
|
import threading
|
|
9
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
10
|
+
import tempfile
|
|
8
11
|
import time
|
|
9
|
-
from io import StringIO
|
|
10
12
|
from contextlib import contextmanager
|
|
11
|
-
from
|
|
12
|
-
import
|
|
13
|
+
from io import StringIO
|
|
14
|
+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
|
15
|
+
|
|
16
|
+
import anyio
|
|
17
|
+
from anyio import get_cancelled_exc_class
|
|
18
|
+
|
|
13
19
|
from .discovery import find_stata_path
|
|
14
|
-
from .smcl.smcl2html import smcl_to_markdown
|
|
15
20
|
from .models import (
|
|
16
21
|
CommandResponse,
|
|
17
22
|
ErrorEnvelope,
|
|
@@ -22,21 +27,97 @@ from .models import (
|
|
|
22
27
|
VariableInfo,
|
|
23
28
|
VariablesResponse,
|
|
24
29
|
)
|
|
30
|
+
from .smcl.smcl2html import smcl_to_markdown
|
|
31
|
+
from .streaming_io import FileTeeIO, TailBuffer
|
|
32
|
+
from .graph_detector import StreamingGraphCache
|
|
25
33
|
|
|
26
34
|
logger = logging.getLogger("mcp_stata")
|
|
27
35
|
|
|
36
|
+
|
|
37
|
+
# ============================================================================
|
|
38
|
+
# MODULE-LEVEL DISCOVERY CACHE
|
|
39
|
+
# ============================================================================
|
|
40
|
+
# This cache ensures Stata discovery runs exactly once per process lifetime
|
|
41
|
+
_discovery_lock = threading.Lock()
|
|
42
|
+
_discovery_result: Optional[Tuple[str, str]] = None # (path, edition)
|
|
43
|
+
_discovery_attempted = False
|
|
44
|
+
_discovery_error: Optional[Exception] = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _get_discovered_stata() -> Tuple[str, str]:
|
|
48
|
+
"""
|
|
49
|
+
Get the discovered Stata path and edition, running discovery only once.
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
Tuple of (stata_executable_path, edition)
|
|
53
|
+
|
|
54
|
+
Raises:
|
|
55
|
+
RuntimeError: If Stata discovery fails
|
|
56
|
+
"""
|
|
57
|
+
global _discovery_result, _discovery_attempted, _discovery_error
|
|
58
|
+
|
|
59
|
+
with _discovery_lock:
|
|
60
|
+
# If we've already successfully discovered Stata, return cached result
|
|
61
|
+
if _discovery_result is not None:
|
|
62
|
+
return _discovery_result
|
|
63
|
+
|
|
64
|
+
# If we've already attempted and failed, re-raise the cached error
|
|
65
|
+
if _discovery_attempted and _discovery_error is not None:
|
|
66
|
+
raise RuntimeError(f"Stata binary not found: {_discovery_error}") from _discovery_error
|
|
67
|
+
|
|
68
|
+
# This is the first attempt - run discovery
|
|
69
|
+
_discovery_attempted = True
|
|
70
|
+
|
|
71
|
+
try:
|
|
72
|
+
# Log environment state once at first discovery
|
|
73
|
+
env_path = os.getenv("STATA_PATH")
|
|
74
|
+
if env_path:
|
|
75
|
+
logger.info("STATA_PATH env provided (raw): %s", env_path)
|
|
76
|
+
else:
|
|
77
|
+
logger.info("STATA_PATH env not set; attempting auto-discovery")
|
|
78
|
+
|
|
79
|
+
try:
|
|
80
|
+
pkg_version = version("mcp-stata")
|
|
81
|
+
except PackageNotFoundError:
|
|
82
|
+
pkg_version = "unknown"
|
|
83
|
+
logger.info("mcp-stata version: %s", pkg_version)
|
|
84
|
+
|
|
85
|
+
# Run discovery
|
|
86
|
+
stata_exec_path, edition = find_stata_path()
|
|
87
|
+
|
|
88
|
+
# Cache the successful result
|
|
89
|
+
_discovery_result = (stata_exec_path, edition)
|
|
90
|
+
logger.info("Discovery found Stata at: %s (%s)", stata_exec_path, edition)
|
|
91
|
+
|
|
92
|
+
return _discovery_result
|
|
93
|
+
|
|
94
|
+
except FileNotFoundError as e:
|
|
95
|
+
_discovery_error = e
|
|
96
|
+
raise RuntimeError(f"Stata binary not found: {e}") from e
|
|
97
|
+
except PermissionError as e:
|
|
98
|
+
_discovery_error = e
|
|
99
|
+
raise RuntimeError(
|
|
100
|
+
f"Stata binary is not executable: {e}. "
|
|
101
|
+
"Point STATA_PATH directly to the Stata binary (e.g., .../Contents/MacOS/stata-mp)."
|
|
102
|
+
) from e
|
|
103
|
+
|
|
104
|
+
|
|
28
105
|
class StataClient:
|
|
29
|
-
_instance = None
|
|
30
106
|
_initialized = False
|
|
31
107
|
_exec_lock: threading.Lock
|
|
108
|
+
_cache_init_lock = threading.Lock() # Class-level lock for cache initialization
|
|
109
|
+
_is_executing = False # Flag to prevent recursive Stata calls
|
|
32
110
|
MAX_DATA_ROWS = 500
|
|
33
|
-
MAX_GRAPH_BYTES = 50 * 1024 * 1024 #
|
|
111
|
+
MAX_GRAPH_BYTES = 50 * 1024 * 1024 # Maximum graph exports (~50MB)
|
|
112
|
+
MAX_CACHE_SIZE = 100 # Maximum number of graphs to cache
|
|
113
|
+
MAX_CACHE_BYTES = 500 * 1024 * 1024 # Maximum cache size in bytes (~500MB)
|
|
114
|
+
LIST_GRAPHS_TTL = 0.075 # TTL for list_graphs cache (75ms)
|
|
34
115
|
|
|
35
116
|
def __new__(cls):
    """Allocate a client that carries its own execution lock and busy flag."""
    client = super(StataClient, cls).__new__(cls)
    # Per-instance state is attached at allocation time, before any other
    # method of the instance can run.
    client._is_executing = False
    client._exec_lock = threading.Lock()
    return client
|
|
40
121
|
|
|
41
122
|
@contextmanager
|
|
42
123
|
def _redirect_io(self):
|
|
@@ -49,95 +130,261 @@ class StataClient:
|
|
|
49
130
|
finally:
|
|
50
131
|
sys.stdout, sys.stderr = backup_stdout, backup_stderr
|
|
51
132
|
|
|
133
|
+
@staticmethod
|
|
134
|
+
def _stata_quote(value: str) -> str:
|
|
135
|
+
"""Return a Stata double-quoted string literal for value."""
|
|
136
|
+
# Stata uses doubled quotes to represent a quote character inside a string.
|
|
137
|
+
v = (value or "")
|
|
138
|
+
v = v.replace('"', '""')
|
|
139
|
+
# Use compound double quotes to avoid tokenization issues with spaces and
|
|
140
|
+
# punctuation in contexts like graph names.
|
|
141
|
+
return f'`"{v}"\''
|
|
142
|
+
|
|
143
|
+
@contextmanager
|
|
144
|
+
def _redirect_io_streaming(self, out_stream, err_stream):
|
|
145
|
+
backup_stdout, backup_stderr = sys.stdout, sys.stderr
|
|
146
|
+
sys.stdout, sys.stderr = out_stream, err_stream
|
|
147
|
+
try:
|
|
148
|
+
yield
|
|
149
|
+
finally:
|
|
150
|
+
sys.stdout, sys.stderr = backup_stdout, backup_stderr
|
|
151
|
+
|
|
152
|
+
@staticmethod
|
|
153
|
+
def _create_graph_cache_callback(on_graph_cached, notify_log):
|
|
154
|
+
"""Create a standardized graph cache callback with proper error handling."""
|
|
155
|
+
async def graph_cache_callback(graph_name: str, success: bool) -> None:
|
|
156
|
+
try:
|
|
157
|
+
if on_graph_cached:
|
|
158
|
+
await on_graph_cached(graph_name, success)
|
|
159
|
+
except Exception as e:
|
|
160
|
+
logger.error(f"Graph cache callback failed: {e}")
|
|
161
|
+
|
|
162
|
+
try:
|
|
163
|
+
# Also notify via log channel
|
|
164
|
+
await notify_log(json.dumps({
|
|
165
|
+
"event": "graph_cached",
|
|
166
|
+
"graph": graph_name,
|
|
167
|
+
"success": success
|
|
168
|
+
}))
|
|
169
|
+
except Exception as e:
|
|
170
|
+
logger.error(f"Failed to notify about graph cache: {e}")
|
|
171
|
+
|
|
172
|
+
return graph_cache_callback
|
|
173
|
+
def _request_break_in(self) -> None:
    """
    Ask a running Stata command to stop when cancellation is requested.

    Looks up ``breakIn`` (or ``break_in``) on the ``sfi`` module and calls it
    when it is callable. Every failure is logged instead of raised, because
    cancellation should never crash the host process.
    """
    try:
        import sfi  # type: ignore[import-not-found]

        interrupt = getattr(sfi, "breakIn", None) or getattr(sfi, "break_in", None)
        if not callable(interrupt):  # pragma: no cover - environment without Stata runtime
            logger.debug("sfi.breakIn not available; cannot interrupt Stata")
            return

        try:
            interrupt()
            logger.info("Sent breakIn() to Stata for cancellation")
        except Exception as e:  # pragma: no cover - best-effort
            logger.warning(f"Failed to send breakIn() to Stata: {e}")
    except Exception as e:  # pragma: no cover - import failure or other
        logger.debug(f"Unable to import sfi for cancellation: {e}")
|
|
194
|
+
|
|
195
|
+
async def _wait_for_stata_stop(self, timeout: float = 2.0) -> bool:
|
|
196
|
+
"""
|
|
197
|
+
After requesting a break, poll the Stata interface so it can surface BreakError
|
|
198
|
+
and return control. This is best-effort and time-bounded.
|
|
199
|
+
"""
|
|
200
|
+
deadline = time.monotonic() + timeout
|
|
201
|
+
try:
|
|
202
|
+
import sfi # type: ignore[import-not-found]
|
|
203
|
+
|
|
204
|
+
toolkit = getattr(sfi, "SFIToolkit", None)
|
|
205
|
+
poll = getattr(toolkit, "pollnow", None) or getattr(toolkit, "pollstd", None)
|
|
206
|
+
BreakError = getattr(sfi, "BreakError", None)
|
|
207
|
+
except Exception: # pragma: no cover
|
|
208
|
+
return False
|
|
209
|
+
|
|
210
|
+
if not callable(poll):
|
|
211
|
+
return False
|
|
212
|
+
|
|
213
|
+
last_exc: Optional[Exception] = None
|
|
214
|
+
while time.monotonic() < deadline:
|
|
215
|
+
try:
|
|
216
|
+
poll()
|
|
217
|
+
except Exception as e: # pragma: no cover - depends on Stata runtime
|
|
218
|
+
last_exc = e
|
|
219
|
+
if BreakError is not None and isinstance(e, BreakError):
|
|
220
|
+
logger.info("Stata BreakError detected; cancellation acknowledged by Stata")
|
|
221
|
+
return True
|
|
222
|
+
# If Stata already stopped, break on any other exception.
|
|
223
|
+
break
|
|
224
|
+
await anyio.sleep(0.05)
|
|
225
|
+
|
|
226
|
+
if last_exc:
|
|
227
|
+
logger.debug(f"Cancellation poll exited with {last_exc}")
|
|
228
|
+
return False
|
|
229
|
+
|
|
230
|
+
@contextmanager
|
|
231
|
+
def _temp_cwd(self, cwd: Optional[str]):
|
|
232
|
+
if cwd is None:
|
|
233
|
+
yield
|
|
234
|
+
return
|
|
235
|
+
prev = os.getcwd()
|
|
236
|
+
os.chdir(cwd)
|
|
237
|
+
try:
|
|
238
|
+
yield
|
|
239
|
+
finally:
|
|
240
|
+
os.chdir(prev)
|
|
241
|
+
|
|
52
242
|
def init(self):
    """Initialize PyStata once, using the cached Stata discovery result.

    Does nothing when the client is already initialized. Otherwise candidate
    installation paths are derived from the discovered binary and handed to
    ``stata_setup.config`` one after another until a call succeeds.

    Raises:
        RuntimeError: If ``stata_setup`` or ``pystata`` cannot be imported, or
            if ``stata_setup.config`` rejects every candidate path (the last
            rejection is attached as ``__cause__``). Errors raised by
            ``_get_discovered_stata`` propagate unchanged.
    """
    if self._initialized:
        return

    try:
        import stata_setup

        # Get discovered Stata path (cached from first call)
        stata_exec_path, edition = _get_discovered_stata()

        candidates = []

        # The directory containing the binary (documented input for stata_setup).
        bin_dir = os.path.dirname(stata_exec_path)
        if bin_dir:
            candidates.append(bin_dir)

        # App bundle: .../StataMP.app (macOS only). Walk upwards from the
        # binary directory looking for a path component ending in ".app".
        curr = bin_dir
        app_bundle = None
        while len(curr) > 1:
            if curr.endswith(".app"):
                app_bundle = curr
                break
            parent = os.path.dirname(curr)
            if parent == curr:  # Reached root directory, prevent infinite loop on Windows
                break
            curr = parent

        if app_bundle:
            # When a bundle exists, its parent directory and the bundle itself
            # are tried before the binary directory.
            candidates.insert(0, os.path.dirname(app_bundle))
            candidates.insert(1, app_bundle)

        # Deduplicate preserving order
        candidates = list(dict.fromkeys(candidates))

        success = False
        last_error: Optional[Exception] = None
        for path in candidates:
            try:
                stata_setup.config(path, edition)
            except Exception as e:
                # Keep trying other candidates, but remember why this one
                # failed so the final error is diagnosable.
                last_error = e
                logger.debug("stata_setup.config failed with path %s: %s", path, e)
                continue
            success = True
            logger.debug("stata_setup.config succeeded with path: %s", path)
            break

        if not success:
            detail = f" Last error: {last_error}" if last_error is not None else ""
            raise RuntimeError(
                f"stata_setup.config failed. Tried: {candidates}. "
                f"Derived from binary: {stata_exec_path}" + detail
            ) from last_error

        # Cache the binary path for later use (e.g., PNG export on Windows)
        self._stata_exec_path = os.path.abspath(stata_exec_path)

        from pystata import stata  # type: ignore[import-not-found]
        self.stata = stata

        # Initialize list_graphs TTL cache
        self._list_graphs_cache = None
        self._list_graphs_cache_time = 0
        self._list_graphs_cache_lock = threading.Lock()

        # Map user-facing graph names (may include spaces/punctuation) to valid
        # internal Stata graph names.
        self._graph_name_aliases: Dict[str, str] = {}
        self._graph_name_reverse: Dict[str, str] = {}

        # Flip the flag last so it is never True while the attributes above
        # are still missing.
        self._initialized = True

        logger.info("StataClient initialized successfully with %s (%s)", stata_exec_path, edition)

    except ImportError as e:
        raise RuntimeError(
            f"Failed to import stata_setup or pystata: {e}. "
            "Ensure they are installed (pip install pystata stata-setup)."
        ) from e
|
|
326
|
+
|
|
327
|
+
def _make_valid_stata_name(self, name: str) -> str:
|
|
328
|
+
"""Create a valid Stata name (<=32 chars, [A-Za-z_][A-Za-z0-9_]*)."""
|
|
329
|
+
base = re.sub(r"[^A-Za-z0-9_]", "_", name or "")
|
|
330
|
+
if not base:
|
|
331
|
+
base = "Graph"
|
|
332
|
+
if not re.match(r"^[A-Za-z_]", base):
|
|
333
|
+
base = f"G_{base}"
|
|
334
|
+
base = base[:32]
|
|
335
|
+
|
|
336
|
+
# Avoid collisions.
|
|
337
|
+
candidate = base
|
|
338
|
+
i = 1
|
|
339
|
+
while candidate in getattr(self, "_graph_name_reverse", {}):
|
|
340
|
+
suffix = f"_{i}"
|
|
341
|
+
candidate = (base[: max(0, 32 - len(suffix))] + suffix)[:32]
|
|
342
|
+
i += 1
|
|
343
|
+
return candidate
|
|
344
|
+
|
|
345
|
+
def _resolve_graph_name_for_stata(self, name: str) -> str:
|
|
346
|
+
"""Return internal Stata graph name for a user-facing name."""
|
|
347
|
+
if not name:
|
|
348
|
+
return name
|
|
349
|
+
aliases = getattr(self, "_graph_name_aliases", None)
|
|
350
|
+
if aliases and name in aliases:
|
|
351
|
+
return aliases[name]
|
|
352
|
+
return name
|
|
353
|
+
|
|
354
|
+
def _maybe_rewrite_graph_name_in_command(self, code: str) -> str:
|
|
355
|
+
"""Rewrite name("...") to a valid Stata name and store alias mapping."""
|
|
356
|
+
if not code:
|
|
357
|
+
return code
|
|
358
|
+
if not hasattr(self, "_graph_name_aliases"):
|
|
359
|
+
self._graph_name_aliases = {}
|
|
360
|
+
self._graph_name_reverse = {}
|
|
361
|
+
|
|
362
|
+
# Handle common patterns: name("..." ...) or name(`"..."' ...)
|
|
363
|
+
pat = re.compile(r"name\(\s*(?:`\"(?P<cq>[^\"]*)\"'|\"(?P<dq>[^\"]*)\")\s*(?P<rest>[^)]*)\)")
|
|
364
|
+
|
|
365
|
+
def repl(m: re.Match) -> str:
|
|
366
|
+
original = m.group("cq") if m.group("cq") is not None else m.group("dq")
|
|
367
|
+
original = original or ""
|
|
368
|
+
internal = self._graph_name_aliases.get(original)
|
|
369
|
+
if not internal:
|
|
370
|
+
internal = self._make_valid_stata_name(original)
|
|
371
|
+
self._graph_name_aliases[original] = internal
|
|
372
|
+
self._graph_name_reverse[internal] = original
|
|
373
|
+
rest = m.group("rest") or ""
|
|
374
|
+
return f"name({internal}{rest})"
|
|
375
|
+
|
|
376
|
+
return pat.sub(repl, code)
|
|
130
377
|
|
|
131
378
|
def _read_return_code(self) -> int:
|
|
132
379
|
"""Read the last Stata return code without mutating rc."""
|
|
133
380
|
try:
|
|
134
|
-
from sfi import Macro
|
|
381
|
+
from sfi import Macro # type: ignore[import-not-found]
|
|
135
382
|
rc_val = Macro.getCValue("rc") # type: ignore[attr-defined]
|
|
136
383
|
return int(float(rc_val))
|
|
137
384
|
except Exception:
|
|
138
385
|
try:
|
|
139
386
|
self.stata.run("global MCP_RC = c(rc)")
|
|
140
|
-
from sfi import Macro as Macro2
|
|
387
|
+
from sfi import Macro as Macro2 # type: ignore[import-not-found]
|
|
141
388
|
rc_val = Macro2.getGlobal("MCP_RC")
|
|
142
389
|
return int(float(rc_val))
|
|
143
390
|
except Exception:
|
|
@@ -183,7 +430,7 @@ class StataClient:
|
|
|
183
430
|
) -> ErrorEnvelope:
|
|
184
431
|
combined = "\n".join(filter(None, [stdout, stderr, str(exc) if exc else ""])).strip()
|
|
185
432
|
rc_hint = self._parse_rc_from_text(combined) if combined else None
|
|
186
|
-
rc_final = rc if rc not in (-1, None) else rc_hint
|
|
433
|
+
rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
|
|
187
434
|
line_no = self._parse_line_from_text(combined) if combined else None
|
|
188
435
|
snippet = combined[-800:] if combined else None
|
|
189
436
|
message = (stderr or (str(exc) if exc else "") or stdout or "Stata error").strip()
|
|
@@ -198,33 +445,68 @@ class StataClient:
|
|
|
198
445
|
trace=trace or None,
|
|
199
446
|
)
|
|
200
447
|
|
|
201
|
-
def _exec_with_capture(self, code: str, echo: bool = True, trace: bool = False) -> CommandResponse:
|
|
448
|
+
def _exec_with_capture(self, code: str, echo: bool = True, trace: bool = False, cwd: Optional[str] = None) -> CommandResponse:
|
|
202
449
|
"""Execute Stata code with stdout/stderr capture and rc detection."""
|
|
203
450
|
if not self._initialized:
|
|
204
451
|
self.init()
|
|
205
452
|
|
|
453
|
+
code = self._maybe_rewrite_graph_name_in_command(code)
|
|
454
|
+
|
|
455
|
+
if cwd is not None and not os.path.isdir(cwd):
|
|
456
|
+
return CommandResponse(
|
|
457
|
+
command=code,
|
|
458
|
+
rc=601,
|
|
459
|
+
stdout="",
|
|
460
|
+
stderr=None,
|
|
461
|
+
success=False,
|
|
462
|
+
error=ErrorEnvelope(
|
|
463
|
+
message=f"cwd not found: {cwd}",
|
|
464
|
+
rc=601,
|
|
465
|
+
command=code,
|
|
466
|
+
),
|
|
467
|
+
)
|
|
468
|
+
|
|
206
469
|
start_time = time.time()
|
|
207
470
|
exc: Optional[Exception] = None
|
|
471
|
+
ret_text: Optional[str] = None
|
|
208
472
|
with self._exec_lock:
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
self.
|
|
214
|
-
except Exception as e:
|
|
215
|
-
exc = e
|
|
216
|
-
finally:
|
|
217
|
-
rc = self._read_return_code()
|
|
218
|
-
if trace:
|
|
473
|
+
# Set execution flag to prevent recursive Stata calls
|
|
474
|
+
self._is_executing = True
|
|
475
|
+
try:
|
|
476
|
+
with self._temp_cwd(cwd):
|
|
477
|
+
with self._redirect_io() as (out_buf, err_buf):
|
|
219
478
|
try:
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
479
|
+
if trace:
|
|
480
|
+
self.stata.run("set trace on")
|
|
481
|
+
ret = self.stata.run(code, echo=echo)
|
|
482
|
+
if isinstance(ret, str) and ret:
|
|
483
|
+
ret_text = ret
|
|
484
|
+
except Exception as e:
|
|
485
|
+
exc = e
|
|
486
|
+
finally:
|
|
487
|
+
rc = self._read_return_code()
|
|
488
|
+
if trace:
|
|
489
|
+
try:
|
|
490
|
+
self.stata.run("set trace off")
|
|
491
|
+
except Exception:
|
|
492
|
+
pass
|
|
493
|
+
finally:
|
|
494
|
+
# Clear execution flag
|
|
495
|
+
self._is_executing = False
|
|
223
496
|
|
|
224
497
|
stdout = out_buf.getvalue()
|
|
498
|
+
# Some PyStata builds return output as a string rather than printing.
|
|
499
|
+
if (not stdout or not stdout.strip()) and ret_text:
|
|
500
|
+
stdout = ret_text
|
|
225
501
|
stderr = err_buf.getvalue()
|
|
226
|
-
|
|
227
|
-
|
|
502
|
+
combined = "\n".join(filter(None, [stdout, stderr, str(exc) if exc else ""])).strip()
|
|
503
|
+
rc_hint = self._parse_rc_from_text(combined) if combined else None
|
|
504
|
+
if exc is None and rc_hint is not None and rc_hint != 0:
|
|
505
|
+
# Prefer r(#) parsed from the current command output when present.
|
|
506
|
+
rc = rc_hint
|
|
507
|
+
# If no exception and stderr is empty and no r(#) is present, treat rc anomalies as success
|
|
508
|
+
# (e.g., stale/spurious c(rc) reads).
|
|
509
|
+
if exc is None and (not stderr or not stderr.strip()) and rc_hint is None:
|
|
228
510
|
rc = 0 if rc is None or rc != 0 else rc
|
|
229
511
|
success = rc == 0 and exc is None
|
|
230
512
|
error = None
|
|
@@ -240,176 +522,1156 @@ class StataClient:
|
|
|
240
522
|
duration * 1000,
|
|
241
523
|
code_preview[:120],
|
|
242
524
|
)
|
|
525
|
+
# Mutually exclusive - when error, output is in ErrorEnvelope only
|
|
243
526
|
return CommandResponse(
|
|
244
527
|
command=code,
|
|
245
528
|
rc=rc,
|
|
246
|
-
stdout=stdout,
|
|
247
|
-
stderr=
|
|
529
|
+
stdout="" if not success else stdout,
|
|
530
|
+
stderr=None,
|
|
248
531
|
success=success,
|
|
249
532
|
error=error,
|
|
250
533
|
)
|
|
251
534
|
|
|
252
|
-
def
|
|
253
|
-
"""
|
|
254
|
-
result = self._exec_with_capture(code, echo=echo)
|
|
255
|
-
if result.success:
|
|
256
|
-
return result.stdout
|
|
257
|
-
if result.error:
|
|
258
|
-
return f"Error executing Stata code (r({result.error.rc})):\n{result.error.message}"
|
|
259
|
-
return result.stdout or "Unknown Stata error"
|
|
535
|
+
def _exec_no_capture(self, code: str, echo: bool = False, trace: bool = False) -> CommandResponse:
    """Run Stata code without redirecting stdout/stderr.

    PyStata's output bridge uses its own thread and can misbehave on Windows
    when we redirect stdio (e.g., graph export). This path keeps the normal
    handlers and just reads rc afterward.

    The response always has an empty ``stdout``. On failure, any text
    returned by ``stata.run`` is passed to the error envelope instead.
    """
    if not self._initialized:
        self.init()

    failure: Optional[Exception] = None
    returned_text: Optional[str] = None
    with self._exec_lock:
        try:
            if trace:
                self.stata.run("set trace on")
            result = self.stata.run(code, echo=echo)
            if isinstance(result, str) and result:
                returned_text = result
        except Exception as e:
            failure = e
        finally:
            rc = self._read_return_code()
            # If Stata returned an r(#) in text, prefer it.
            pieces = [returned_text or "", str(failure) if failure else ""]
            combined = "\n".join(filter(None, pieces)).strip()
            rc_hint = self._parse_rc_from_text(combined) if combined else None
            if failure is None:
                if rc_hint is not None and rc_hint != 0:
                    rc = rc_hint
                elif rc_hint is None and (rc is None or rc == -1):
                    # Normalize spurious rc reads only when missing/invalid
                    rc = 0
            if trace:
                try:
                    self.stata.run("set trace off")
                except Exception as e:
                    logger.warning("Failed to turn off Stata trace mode: %s", e)

    success = rc == 0 and failure is None
    # Pass the returned text as stdout for snippet parsing.
    error = (
        None
        if success
        else self._build_error_envelope(code, rc, returned_text or "", "", failure, trace)
    )

    return CommandResponse(
        command=code,
        rc=rc,
        stdout="",
        stderr=None,
        success=success,
        error=error,
    )
|
|
284
588
|
|
|
285
|
-
def
|
|
286
|
-
|
|
589
|
+
async def run_command_streaming(
|
|
590
|
+
self,
|
|
591
|
+
code: str,
|
|
592
|
+
*,
|
|
593
|
+
notify_log: Callable[[str], Awaitable[None]],
|
|
594
|
+
notify_progress: Optional[Callable[[float, Optional[float], Optional[str]], Awaitable[None]]] = None,
|
|
595
|
+
echo: bool = True,
|
|
596
|
+
trace: bool = False,
|
|
597
|
+
max_output_lines: Optional[int] = None,
|
|
598
|
+
cwd: Optional[str] = None,
|
|
599
|
+
auto_cache_graphs: bool = False,
|
|
600
|
+
on_graph_cached: Optional[Callable[[str, bool], Awaitable[None]]] = None,
|
|
601
|
+
) -> CommandResponse:
|
|
287
602
|
if not self._initialized:
|
|
288
603
|
self.init()
|
|
289
|
-
|
|
290
|
-
# We can use sfi to be efficient
|
|
291
|
-
from sfi import Data
|
|
292
|
-
vars_info = []
|
|
293
|
-
for i in range(Data.getVarCount()):
|
|
294
|
-
var_index = i # 0-based
|
|
295
|
-
name = Data.getVarName(var_index)
|
|
296
|
-
label = Data.getVarLabel(var_index)
|
|
297
|
-
type_str = Data.getVarType(var_index) # Returns int
|
|
298
|
-
|
|
299
|
-
vars_info.append({
|
|
300
|
-
"name": name,
|
|
301
|
-
"label": label,
|
|
302
|
-
"type": str(type_str),
|
|
303
|
-
})
|
|
304
|
-
return vars_info
|
|
305
604
|
|
|
306
|
-
|
|
307
|
-
"""Returns codebook/summary for a specific variable."""
|
|
308
|
-
return self.run_command(f"codebook {varname}")
|
|
605
|
+
code = self._maybe_rewrite_graph_name_in_command(code)
|
|
309
606
|
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
607
|
+
if cwd is not None and not os.path.isdir(cwd):
|
|
608
|
+
return CommandResponse(
|
|
609
|
+
command=code,
|
|
610
|
+
rc=601,
|
|
611
|
+
stdout="",
|
|
612
|
+
stderr=None,
|
|
613
|
+
success=False,
|
|
614
|
+
error=ErrorEnvelope(
|
|
615
|
+
message=f"cwd not found: {cwd}",
|
|
616
|
+
rc=601,
|
|
617
|
+
command=code,
|
|
618
|
+
),
|
|
319
619
|
)
|
|
320
|
-
return VariablesResponse(variables=vars_info)
|
|
321
620
|
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
if not self._initialized:
|
|
325
|
-
self.init()
|
|
326
|
-
|
|
327
|
-
# 'graph dir' returns list in r(list)
|
|
328
|
-
# We need to ensure we run it quietly so we don't spam.
|
|
329
|
-
self.stata.run("quietly graph dir, memory")
|
|
330
|
-
|
|
331
|
-
# Accessing r-class results in Python can be tricky via pystata's run command.
|
|
332
|
-
# We stash the result in a global macro that python sfi can easily read.
|
|
333
|
-
from sfi import Macro
|
|
334
|
-
self.stata.run("global mcp_graph_list `r(list)'")
|
|
335
|
-
graph_list_str = Macro.getGlobal("mcp_graph_list")
|
|
336
|
-
if not graph_list_str:
|
|
337
|
-
return []
|
|
338
|
-
|
|
339
|
-
return graph_list_str.split()
|
|
621
|
+
start_time = time.time()
|
|
622
|
+
exc: Optional[Exception] = None
|
|
340
623
|
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
624
|
+
# Setup streaming graph cache if enabled
|
|
625
|
+
graph_cache = None
|
|
626
|
+
if auto_cache_graphs:
|
|
627
|
+
graph_cache = StreamingGraphCache(self, auto_cache=True)
|
|
628
|
+
|
|
629
|
+
graph_cache_callback = self._create_graph_cache_callback(on_graph_cached, notify_log)
|
|
630
|
+
|
|
631
|
+
graph_cache.add_cache_callback(graph_cache_callback)
|
|
346
632
|
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
633
|
+
log_file = tempfile.NamedTemporaryFile(
|
|
634
|
+
prefix="mcp_stata_",
|
|
635
|
+
suffix=".log",
|
|
636
|
+
delete=False,
|
|
637
|
+
mode="w",
|
|
638
|
+
encoding="utf-8",
|
|
639
|
+
errors="replace",
|
|
640
|
+
buffering=1,
|
|
641
|
+
)
|
|
642
|
+
log_path = log_file.name
|
|
643
|
+
tail = TailBuffer(max_chars=8000)
|
|
644
|
+
tee = FileTeeIO(log_file, tail)
|
|
350
645
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
raise ValueError(f"Unsupported graph export format: {format}. Allowed: pdf, png.")
|
|
646
|
+
# Inform the MCP client immediately where to read/tail the output.
|
|
647
|
+
await notify_log(json.dumps({"event": "log_path", "path": log_path}))
|
|
354
648
|
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
if os.path.exists(filename):
|
|
362
|
-
try:
|
|
363
|
-
os.remove(filename)
|
|
364
|
-
except Exception:
|
|
365
|
-
pass
|
|
366
|
-
|
|
367
|
-
cmd = "graph export"
|
|
368
|
-
if graph_name:
|
|
369
|
-
cmd += f' "{filename}", name("{graph_name}") replace as({fmt})'
|
|
370
|
-
else:
|
|
371
|
-
cmd += f' "{filename}", replace as({fmt})'
|
|
372
|
-
|
|
373
|
-
output = self.run_command(cmd)
|
|
374
|
-
|
|
375
|
-
if os.path.exists(filename):
|
|
376
|
-
try:
|
|
377
|
-
size = os.path.getsize(filename)
|
|
378
|
-
if size == 0:
|
|
379
|
-
raise RuntimeError(f"Graph export failed: produced empty file {filename}")
|
|
380
|
-
if size > self.MAX_GRAPH_BYTES:
|
|
381
|
-
raise RuntimeError(
|
|
382
|
-
f"Graph export failed: file too large (> {self.MAX_GRAPH_BYTES} bytes): {filename}"
|
|
383
|
-
)
|
|
384
|
-
except Exception as size_err:
|
|
385
|
-
# Clean up oversized or unreadable files
|
|
649
|
+
rc = -1
|
|
650
|
+
|
|
651
|
+
def _run_blocking() -> None:
|
|
652
|
+
nonlocal rc, exc
|
|
653
|
+
with self._exec_lock:
|
|
654
|
+
self._is_executing = True
|
|
386
655
|
try:
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
656
|
+
with self._temp_cwd(cwd):
|
|
657
|
+
with self._redirect_io_streaming(tee, tee):
|
|
658
|
+
try:
|
|
659
|
+
if trace:
|
|
660
|
+
self.stata.run("set trace on")
|
|
661
|
+
ret = self.stata.run(code, echo=echo)
|
|
662
|
+
# Some PyStata builds return output as a string rather than printing.
|
|
663
|
+
if isinstance(ret, str) and ret:
|
|
664
|
+
try:
|
|
665
|
+
tee.write(ret)
|
|
666
|
+
except Exception:
|
|
667
|
+
pass
|
|
668
|
+
except Exception as e:
|
|
669
|
+
exc = e
|
|
670
|
+
finally:
|
|
671
|
+
rc = self._read_return_code()
|
|
672
|
+
if trace:
|
|
673
|
+
try:
|
|
674
|
+
self.stata.run("set trace off")
|
|
675
|
+
except Exception:
|
|
676
|
+
pass
|
|
677
|
+
finally:
|
|
678
|
+
self._is_executing = False
|
|
395
679
|
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
if
|
|
680
|
+
try:
|
|
681
|
+
if notify_progress is not None:
|
|
682
|
+
await notify_progress(0, None, "Running Stata command")
|
|
683
|
+
|
|
684
|
+
await anyio.to_thread.run_sync(_run_blocking, abandon_on_cancel=True)
|
|
685
|
+
except get_cancelled_exc_class():
|
|
686
|
+
# Best-effort cancellation: signal Stata to break, wait briefly, then propagate.
|
|
687
|
+
self._request_break_in()
|
|
688
|
+
await self._wait_for_stata_stop()
|
|
689
|
+
raise
|
|
690
|
+
finally:
|
|
691
|
+
tee.close()
|
|
692
|
+
|
|
693
|
+
# Cache detected graphs after command completes
|
|
694
|
+
if graph_cache:
|
|
411
695
|
try:
|
|
412
|
-
|
|
696
|
+
# Use the enhanced pystata-integrated caching method
|
|
697
|
+
if hasattr(graph_cache, 'cache_detected_graphs_with_pystata'):
|
|
698
|
+
cached_graphs = await graph_cache.cache_detected_graphs_with_pystata()
|
|
699
|
+
else:
|
|
700
|
+
cached_graphs = await graph_cache.cache_detected_graphs()
|
|
701
|
+
|
|
702
|
+
if cached_graphs and notify_progress:
|
|
703
|
+
await notify_progress(1, 1, f"Command completed. Cached {len(cached_graphs)} graphs: {', '.join(cached_graphs)}")
|
|
704
|
+
except Exception as e:
|
|
705
|
+
logger.warning(f"Failed to cache detected graphs: {e}")
|
|
706
|
+
|
|
707
|
+
tail_text = tail.get_value()
|
|
708
|
+
combined = (tail_text or "") + (f"\n{exc}" if exc else "")
|
|
709
|
+
rc_hint = self._parse_rc_from_text(combined) if combined else None
|
|
710
|
+
if exc is None and rc_hint is not None and rc_hint != 0:
|
|
711
|
+
rc = rc_hint
|
|
712
|
+
if exc is None and rc_hint is None:
|
|
713
|
+
rc = 0 if rc is None or rc != 0 else rc
|
|
714
|
+
success = rc == 0 and exc is None
|
|
715
|
+
error = None
|
|
716
|
+
if not success:
|
|
717
|
+
snippet = (tail_text[-800:] if tail_text else None) or (str(exc) if exc else None)
|
|
718
|
+
rc_hint = self._parse_rc_from_text(combined) if combined else None
|
|
719
|
+
rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
|
|
720
|
+
line_no = self._parse_line_from_text(combined) if combined else None
|
|
721
|
+
message = "Stata error"
|
|
722
|
+
if tail_text and tail_text.strip():
|
|
723
|
+
for line in reversed(tail_text.splitlines()):
|
|
724
|
+
if line.strip():
|
|
725
|
+
message = line.strip()
|
|
726
|
+
break
|
|
727
|
+
elif exc is not None:
|
|
728
|
+
message = str(exc).strip() or message
|
|
729
|
+
|
|
730
|
+
error = ErrorEnvelope(
|
|
731
|
+
message=message,
|
|
732
|
+
rc=rc_final,
|
|
733
|
+
line=line_no,
|
|
734
|
+
command=code,
|
|
735
|
+
log_path=log_path,
|
|
736
|
+
snippet=snippet,
|
|
737
|
+
trace=trace or None,
|
|
738
|
+
)
|
|
739
|
+
|
|
740
|
+
duration = time.time() - start_time
|
|
741
|
+
code_preview = code.replace("\n", "\\n")
|
|
742
|
+
logger.info(
|
|
743
|
+
"stata.run(stream) rc=%s success=%s trace=%s duration_ms=%.2f code_preview=%s",
|
|
744
|
+
rc,
|
|
745
|
+
success,
|
|
746
|
+
trace,
|
|
747
|
+
duration * 1000,
|
|
748
|
+
code_preview[:120],
|
|
749
|
+
)
|
|
750
|
+
|
|
751
|
+
result = CommandResponse(
|
|
752
|
+
command=code,
|
|
753
|
+
rc=rc,
|
|
754
|
+
stdout="",
|
|
755
|
+
stderr=None,
|
|
756
|
+
log_path=log_path,
|
|
757
|
+
success=success,
|
|
758
|
+
error=error,
|
|
759
|
+
)
|
|
760
|
+
|
|
761
|
+
if notify_progress is not None:
|
|
762
|
+
await notify_progress(1, 1, "Finished")
|
|
763
|
+
|
|
764
|
+
return result
|
|
765
|
+
|
|
766
|
+
def _count_do_file_lines(self, path: str) -> int:
|
|
767
|
+
try:
|
|
768
|
+
with open(path, "r", encoding="utf-8", errors="replace") as f:
|
|
769
|
+
lines = f.read().splitlines()
|
|
770
|
+
except Exception:
|
|
771
|
+
return 0
|
|
772
|
+
|
|
773
|
+
total = 0
|
|
774
|
+
for line in lines:
|
|
775
|
+
s = line.strip()
|
|
776
|
+
if not s:
|
|
777
|
+
continue
|
|
778
|
+
if s.startswith("*"):
|
|
779
|
+
continue
|
|
780
|
+
if s.startswith("//"):
|
|
781
|
+
continue
|
|
782
|
+
total += 1
|
|
783
|
+
return total
|
|
784
|
+
|
|
785
|
+
async def run_do_file_streaming(
    self,
    path: str,
    *,
    notify_log: Callable[[str], Awaitable[None]],
    notify_progress: Optional[Callable[[float, Optional[float], Optional[str]], Awaitable[None]]] = None,
    echo: bool = True,
    trace: bool = False,
    max_output_lines: Optional[int] = None,
    cwd: Optional[str] = None,
    auto_cache_graphs: bool = False,
    on_graph_cached: Optional[Callable[[str, bool], Awaitable[None]]] = None,
) -> CommandResponse:
    """Execute a do-file, streaming its output to a temporary log file.

    The log file path is announced through ``notify_log`` as a JSON
    ``log_path`` event before execution starts.  Output is not returned in
    ``stdout`` (it is always ``""``); callers read it from ``log_path``.

    Args:
        path: Do-file to run.  A relative path is resolved against ``cwd``
            when ``cwd`` is given.
        notify_log: Awaited once with the JSON ``log_path`` event; also
            handed to ``_create_graph_cache_callback`` when
            ``auto_cache_graphs`` is set.
        notify_progress: Optional awaitable ``(progress, total, message)``
            callback.  Progress is estimated by counting ``. command``
            prompt lines in the log against ``_count_do_file_lines``.
        echo: Forwarded to ``self.stata.run``.
        trace: When true, ``set trace on`` is issued before the do-file and
            ``set trace off`` afterwards.
        max_output_lines: Accepted but not used by this method.
        cwd: Directory to run in (via ``self._temp_cwd``); must exist.
        auto_cache_graphs: When true, graphs that appear during execution
            are cached after the run completes.
        on_graph_cached: Passed to ``_create_graph_cache_callback``.

    Returns:
        A ``CommandResponse``.  When ``cwd`` or the do-file does not exist
        a failure response with ``rc=601`` is returned without running
        anything.

    Raises:
        The backend's cancellation exception, re-raised after asking Stata
        to break; any exception raised by ``notify_log``.
    """
    if cwd is not None and not os.path.isdir(cwd):
        return CommandResponse(
            command=f'do "{path}"',
            rc=601,
            stdout="",
            stderr=None,
            success=False,
            error=ErrorEnvelope(
                message=f"cwd not found: {cwd}",
                rc=601,
                command=path,
            ),
        )

    effective_path = path
    if cwd is not None and not os.path.isabs(path):
        effective_path = os.path.abspath(os.path.join(cwd, path))

    if not os.path.exists(effective_path):
        return CommandResponse(
            command=f'do "{effective_path}"',
            rc=601,
            stdout="",
            stderr=None,
            success=False,
            error=ErrorEnvelope(
                message=f"Do-file not found: {effective_path}",
                rc=601,
                command=effective_path,
            ),
        )

    total_lines = self._count_do_file_lines(effective_path)
    executed_lines = 0
    last_progress_time = 0.0
    # A log line of the form ". <something>" is taken as one echoed command.
    dot_prompt = re.compile(r"^\.\s+\S")

    async def on_chunk_for_progress(chunk: str) -> None:
        """Update the executed-line estimate from a chunk of log text."""
        nonlocal executed_lines, last_progress_time
        if total_lines <= 0 or notify_progress is None:
            return
        for line in chunk.splitlines():
            if dot_prompt.match(line):
                executed_lines += 1
                if executed_lines > total_lines:
                    executed_lines = total_lines

        # Throttle notifications to at most one every 0.25 seconds.
        now = time.monotonic()
        if executed_lines > 0 and (now - last_progress_time) >= 0.25:
            last_progress_time = now
            await notify_progress(
                float(executed_lines),
                float(total_lines),
                f"Executing do-file: {executed_lines}/{total_lines}",
            )

    if not self._initialized:
        self.init()

    start_time = time.time()
    exc: Optional[Exception] = None

    # Setup streaming graph cache if enabled
    graph_cache = None
    if auto_cache_graphs:
        graph_cache = StreamingGraphCache(self, auto_cache=True)
        graph_cache_callback = self._create_graph_cache_callback(on_graph_cached, notify_log)
        graph_cache.add_cache_callback(graph_cache_callback)

    log_file = tempfile.NamedTemporaryFile(
        prefix="mcp_stata_",
        suffix=".log",
        delete=False,
        mode="w",
        encoding="utf-8",
        errors="replace",
        buffering=1,
    )
    log_path = log_file.name
    tail = TailBuffer(max_chars=8000)
    tee = FileTeeIO(log_file, tail)

    # Inform the MCP client immediately where to read/tail the output.
    try:
        await notify_log(json.dumps({"event": "log_path", "path": log_path}))
    except BaseException:
        # Execution never starts on this path, so the ``finally`` below that
        # normally closes the tee is not reached; close it here instead.
        tee.close()
        raise

    rc = -1
    path_for_stata = effective_path.replace("\\", "/")
    command = f'do "{path_for_stata}"'

    # Capture initial graph state BEFORE execution starts so that
    # post-execution detection can identify new graphs.  The TTL cache is
    # bypassed so the snapshot reflects the graphs that exist right now.
    if graph_cache:
        try:
            graph_cache._initial_graphs = set(self.list_graphs(force_refresh=True))
            logger.debug(f"Initial graph state captured: {graph_cache._initial_graphs}")
        except Exception as e:
            logger.debug(f"Failed to capture initial graph state: {e}")
            graph_cache._initial_graphs = set()

    def _run_blocking() -> None:
        """Run the do-file synchronously; intended for a worker thread."""
        nonlocal rc, exc
        with self._exec_lock:
            # Set execution flag to prevent recursive Stata calls
            self._is_executing = True
            try:
                with self._temp_cwd(cwd):
                    with self._redirect_io_streaming(tee, tee):
                        try:
                            if trace:
                                self.stata.run("set trace on")
                            ret = self.stata.run(command, echo=echo)
                            # Some PyStata builds return output as a string rather than printing.
                            if isinstance(ret, str) and ret:
                                try:
                                    tee.write(ret)
                                except Exception:
                                    pass
                        except Exception as e:
                            exc = e
                        finally:
                            rc = self._read_return_code()
                            if trace:
                                try:
                                    self.stata.run("set trace off")
                                except Exception:
                                    pass
            finally:
                # Clear execution flag
                self._is_executing = False

    done = anyio.Event()

    async def _monitor_progress_from_log() -> None:
        """Poll the log file and feed new text to the progress estimator."""
        if notify_progress is None or total_lines <= 0:
            return
        last_pos = 0
        try:
            with open(log_path, "r", encoding="utf-8", errors="replace") as f:
                while not done.is_set():
                    f.seek(last_pos)
                    chunk = f.read()
                    if chunk:
                        last_pos = f.tell()
                        await on_chunk_for_progress(chunk)
                    await anyio.sleep(0.05)

                # Final read for anything written just before completion.
                f.seek(last_pos)
                chunk = f.read()
                if chunk:
                    await on_chunk_for_progress(chunk)
        except Exception:
            return

    async with anyio.create_task_group() as tg:
        tg.start_soon(_monitor_progress_from_log)

        if notify_progress is not None:
            if total_lines > 0:
                await notify_progress(0, float(total_lines), f"Executing do-file: 0/{total_lines}")
            else:
                await notify_progress(0, None, "Running do-file")

        try:
            await anyio.to_thread.run_sync(_run_blocking, abandon_on_cancel=True)
        except get_cancelled_exc_class():
            # Best-effort cancellation: signal Stata to break, wait, then propagate.
            self._request_break_in()
            await self._wait_for_stata_stop()
            raise
        finally:
            done.set()
            tee.close()

    # Post-execution graph detection and caching.  Runs after execution
    # completes, when ``list_graphs`` is no longer short-circuited by the
    # ``_is_executing`` flag.
    if graph_cache and graph_cache.auto_cache:
        cached_graphs = []
        try:
            # Get initial state (before execution)
            initial_graphs = getattr(graph_cache, '_initial_graphs', set())

            # Get current state (after execution).  force_refresh is required:
            # a still-valid TTL cache entry would return the pre-run list and
            # hide every graph the do-file just created.
            logger.debug("Post-execution: Querying graph state via list_graphs()")
            current_graphs = set(self.list_graphs(force_refresh=True))

            # Detect new graphs (created during execution)
            new_graphs = current_graphs - initial_graphs - graph_cache._cached_graphs

            if new_graphs:
                logger.info(f"Detected {len(new_graphs)} new graph(s): {sorted(new_graphs)}")

                # Cache each detected graph
                for graph_name in new_graphs:
                    try:
                        logger.debug(f"Caching graph: {graph_name}")
                        cache_result = await anyio.to_thread.run_sync(
                            self.cache_graph_on_creation,
                            graph_name
                        )

                        if cache_result:
                            cached_graphs.append(graph_name)
                            graph_cache._cached_graphs.add(graph_name)
                            logger.debug(f"Successfully cached graph: {graph_name}")
                        else:
                            logger.warning(f"Failed to cache graph: {graph_name}")

                        # Trigger callbacks
                        # NOTE(review): callbacks are run in a worker thread,
                        # which presumes they are plain callables — confirm.
                        for callback in graph_cache._cache_callbacks:
                            try:
                                await anyio.to_thread.run_sync(callback, graph_name, cache_result)
                            except Exception as e:
                                logger.debug(f"Callback failed for {graph_name}: {e}")

                    except Exception as e:
                        logger.error(f"Error caching graph {graph_name}: {e}")
                        # Trigger callbacks with failure
                        for callback in graph_cache._cache_callbacks:
                            try:
                                await anyio.to_thread.run_sync(callback, graph_name, False)
                            except Exception:
                                pass

            # Check for dropped graphs (for completeness)
            dropped_graphs = initial_graphs - current_graphs
            if dropped_graphs:
                logger.debug(f"Graphs dropped during execution: {sorted(dropped_graphs)}")
                for graph_name in dropped_graphs:
                    try:
                        self.invalidate_graph_cache(graph_name)
                    except Exception:
                        pass

            # Notify progress if graphs were cached
            if cached_graphs and notify_progress:
                await notify_progress(
                    float(total_lines) if total_lines > 0 else 1,
                    float(total_lines) if total_lines > 0 else 1,
                    f"Do-file completed. Cached {len(cached_graphs)} graph(s): {', '.join(cached_graphs)}"
                )

        except Exception as e:
            logger.error(f"Post-execution graph detection failed: {e}")

    tail_text = tail.get_value()
    combined = (tail_text or "") + (f"\n{exc}" if exc else "")
    rc_hint = self._parse_rc_from_text(combined) if combined else None
    if exc is None and rc_hint is not None and rc_hint != 0:
        # The output reports a non-zero return code; prefer it.
        rc = rc_hint
    if exc is None and rc_hint is None:
        # No exception and no return code in the output: treat as success.
        rc = 0
    success = rc == 0 and exc is None
    error = None
    if not success:
        snippet = (tail_text[-800:] if tail_text else None) or (str(exc) if exc else None)
        rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
        line_no = self._parse_line_from_text(combined) if combined else None
        message = "Stata error"
        if tail_text and tail_text.strip():
            # Use the last non-blank output line as the error message.
            for line in reversed(tail_text.splitlines()):
                if line.strip():
                    message = line.strip()
                    break
        elif exc is not None:
            message = str(exc).strip() or message

        error = ErrorEnvelope(
            message=message,
            rc=rc_final,
            line=line_no,
            command=command,
            log_path=log_path,
            snippet=snippet,
            trace=trace or None,
        )

    duration = time.time() - start_time
    logger.info(
        "stata.run(do stream) rc=%s success=%s trace=%s duration_ms=%.2f path=%s",
        rc,
        success,
        trace,
        duration * 1000,
        effective_path,
    )

    result = CommandResponse(
        command=command,
        rc=rc,
        stdout="",
        stderr=None,
        log_path=log_path,
        success=success,
        error=error,
    )

    if notify_progress is not None:
        if total_lines > 0:
            await notify_progress(float(total_lines), float(total_lines), f"Executing do-file: {total_lines}/{total_lines}")
        else:
            await notify_progress(1, 1, "Finished")

    return result
|
|
1104
|
+
|
|
1105
|
+
def run_command_structured(self, code: str, echo: bool = True, trace: bool = False, max_output_lines: Optional[int] = None, cwd: Optional[str] = None) -> CommandResponse:
    """Execute Stata code and return the structured response envelope.

    Args:
        code: The Stata command to execute.
        echo: Forwarded to ``_exec_with_capture``.
        trace: Forwarded to ``_exec_with_capture``.
        max_output_lines: When set and stdout has more lines than this, only
            the first ``max_output_lines`` lines are kept and a truncation
            notice is appended.
        cwd: Forwarded to ``_exec_with_capture``.

    Returns:
        The response from ``_exec_with_capture``, or a new ``CommandResponse``
        carrying the shortened stdout when truncation applied.
    """
    response = self._exec_with_capture(code, echo=echo, trace=trace, cwd=cwd)

    # Nothing to trim when no limit was requested or there is no stdout.
    if max_output_lines is None or not response.stdout:
        return response

    all_lines = response.stdout.splitlines()
    if len(all_lines) <= max_output_lines:
        return response

    notice = f"\n... (output truncated: showing {max_output_lines} of {len(all_lines)} lines)"
    shortened = "\n".join(all_lines[:max_output_lines] + [notice])
    return CommandResponse(
        command=response.command,
        rc=response.rc,
        stdout=shortened,
        stderr=response.stderr,
        success=response.success,
        error=response.error,
    )
|
|
1132
|
+
|
|
1133
|
+
def get_data(self, start: int = 0, count: int = 50) -> List[Dict[str, Any]]:
    """Return a window of the in-memory dataset as a list of row dicts.

    Args:
        start: Zero-based position of the first row to return.  Negative
            values are treated as 0.
        count: Maximum number of rows to return.  Negative values are
            treated as 0 and the value is capped at ``self.MAX_DATA_ROWS``.

    Returns:
        One ``{column: value}`` dict per row, as produced by
        ``DataFrame.to_dict(orient="records")``.  On any failure a
        single-element list ``[{"error": "..."}]`` is returned instead of
        raising.
    """
    if not self._initialized:
        self.init()

    # ``iloc`` reads negative bounds as offsets from the end of the frame,
    # which would silently return the wrong rows; clamp both values first.
    start = max(start, 0)
    count = min(max(count, 0), self.MAX_DATA_ROWS)

    try:
        # Use pystata integration to retrieve data
        frame = self.stata.pdataframe_from_data()
        window = frame.iloc[start : start + count]
        return window.to_dict(orient="records")
    except Exception as e:
        return [{"error": f"Failed to retrieve data: {e}"}]
|
|
1152
|
+
|
|
1153
|
+
def list_variables(self) -> List[Dict[str, str]]:
    """Describe every variable of the current dataset.

    Returns:
        One dict per variable, in dataset order, with the keys ``name``,
        ``label`` and ``type`` (the value of ``Data.getVarType`` passed
        through ``str``).
    """
    if not self._initialized:
        self.init()

    # Imported after init() rather than at module level — presumably sfi is
    # only importable once Stata is initialised; confirm.
    from sfi import Data  # type: ignore[import-not-found]

    def describe(position: int) -> Dict[str, str]:
        # Positions are zero-based variable indices.
        var_name = Data.getVarName(position)
        var_label = Data.getVarLabel(position)
        storage = Data.getVarType(position)
        return {
            "name": var_name,
            "label": var_label,
            "type": str(storage),
        }

    return [describe(position) for position in range(Data.getVarCount())]
|
|
1173
|
+
|
|
1174
|
+
def get_dataset_state(self) -> Dict[str, Any]:
    """Collect basic facts about the dataset without modifying it.

    Returns:
        A dict with the keys ``frame``, ``n`` (``Data.getObsTotal``),
        ``k`` (``Data.getVarCount``), ``sortlist`` and ``changed``.  The
        last three come from ``Macro.getCValue``; when a lookup or its
        conversion fails the defaults ``"default"``, ``""`` and ``False``
        are used respectively.
    """
    if not self._initialized:
        self.init()

    from sfi import Data, Macro  # type: ignore[import-not-found]

    obs_total = int(Data.getObsTotal())
    var_total = int(Data.getVarCount())

    def c_value(name, convert, fallback):
        # Any failure in the lookup or the conversion yields the fallback.
        try:
            return convert(Macro.getCValue(name))
        except Exception:
            return fallback

    frame_name = c_value("frame", lambda raw: str(raw or "default"), "default")
    sort_order = c_value("sortlist", lambda raw: str(raw or ""), "")
    is_changed = c_value("changed", lambda raw: bool(int(float(raw or "0"))), False)

    return {
        "frame": frame_name,
        "n": obs_total,
        "k": var_total,
        "sortlist": sort_order,
        "changed": is_changed,
    }
|
|
1201
|
+
|
|
1202
|
+
def _require_data_in_memory(self) -> None:
|
|
1203
|
+
state = self.get_dataset_state()
|
|
1204
|
+
if int(state.get("k", 0) or 0) == 0 and int(state.get("n", 0) or 0) == 0:
|
|
1205
|
+
# Stata empty dataset could still have k>0 n==0; treat that as ok.
|
|
1206
|
+
raise RuntimeError("No data in memory")
|
|
1207
|
+
|
|
1208
|
+
def _get_var_index_map(self) -> Dict[str, int]:
|
|
1209
|
+
from sfi import Data # type: ignore[import-not-found]
|
|
1210
|
+
|
|
1211
|
+
out: Dict[str, int] = {}
|
|
1212
|
+
for i in range(int(Data.getVarCount())):
|
|
1213
|
+
try:
|
|
1214
|
+
out[str(Data.getVarName(i))] = i
|
|
1215
|
+
except Exception:
|
|
1216
|
+
continue
|
|
1217
|
+
return out
|
|
1218
|
+
|
|
1219
|
+
def list_variables_rich(self) -> List[Dict[str, Any]]:
    """Return variable metadata (name/type/label/format/valueLabel) without modifying the dataset.

    Returns:
        One dict per variable, in dataset order, with the keys ``name``,
        ``type``, ``label``, ``format`` and ``valueLabel``.  Every field
        except ``name`` is ``None`` when its lookup fails or yields an
        empty value.
    """
    if not self._initialized:
        self.init()

    from sfi import Data  # type: ignore[import-not-found]

    # The value-label name is looked up through sfi.ValueLabel; if that class
    # cannot be imported the field stays None, as it always was before.
    try:
        from sfi import ValueLabel  # type: ignore[import-not-found]
    except Exception:
        ValueLabel = None

    def attempt(lookup):
        # Metadata lookups are best-effort: a failing one yields None.
        try:
            return lookup()
        except Exception:
            return None

    vars_info: List[Dict[str, Any]] = []
    for i in range(int(Data.getVarCount())):
        name = str(Data.getVarName(i))
        label = attempt(lambda: Data.getVarLabel(i))
        fmt = attempt(lambda: Data.getVarFormat(i))
        vtype = attempt(lambda: Data.getVarType(i))
        value_label = None
        if ValueLabel is not None:
            value_label = attempt(lambda: ValueLabel.getVarValueLabel(i))

        vars_info.append(
            {
                "name": name,
                "type": str(vtype) if vtype is not None else None,
                "label": label if label else None,
                "format": fmt if fmt else None,
                "valueLabel": value_label if value_label else None,
            }
        )
    return vars_info
|
|
1256
|
+
|
|
1257
|
+
@staticmethod
|
|
1258
|
+
def _is_stata_missing(value: Any) -> bool:
|
|
1259
|
+
if value is None:
|
|
1260
|
+
return True
|
|
1261
|
+
if isinstance(value, float):
|
|
1262
|
+
# Stata missing values typically show up as very large floats via sfi.Data.get
|
|
1263
|
+
return value > 8.0e307
|
|
1264
|
+
return False
|
|
1265
|
+
|
|
1266
|
+
def _normalize_cell(self, value: Any, *, max_chars: int) -> tuple[Any, bool]:
|
|
1267
|
+
if self._is_stata_missing(value):
|
|
1268
|
+
return ".", False
|
|
1269
|
+
if isinstance(value, str):
|
|
1270
|
+
if len(value) > max_chars:
|
|
1271
|
+
return value[:max_chars], True
|
|
1272
|
+
return value, False
|
|
1273
|
+
return value, False
|
|
1274
|
+
|
|
1275
|
+
def get_page(
    self,
    *,
    offset: int,
    limit: int,
    vars: List[str],
    include_obs_no: bool,
    max_chars: int,
    obs_indices: Optional[List[int]] = None,
) -> Dict[str, Any]:
    """Read one page of observations for the requested variables.

    Args:
        offset: Zero-based position of the first row of the page.  It is a
            position in the dataset, or in ``obs_indices`` when that is
            given.  Negative values are treated as 0.
        limit: Maximum number of rows in the page.  Negative values are
            treated as 0.
        vars: Names of the variables to read; each must exist.
        include_obs_no: When true, each row is prefixed with the 1-based
            observation number and ``"_n"`` is prepended to ``vars``.
        max_chars: Maximum length kept for string cells.
        obs_indices: Optional list of zero-based observation numbers to
            page through instead of the whole dataset.

    Returns:
        A dict with ``vars``, ``rows``, ``returned`` (number of rows read)
        and ``truncated_cells`` (number of cells cut to ``max_chars``).

    Raises:
        RuntimeError: The dataset has no variables and no observations.
        ValueError: A requested variable does not exist.
    """
    if not self._initialized:
        self.init()

    from sfi import Data  # type: ignore[import-not-found]

    state = self.get_dataset_state()
    n = int(state.get("n", 0) or 0)
    k = int(state.get("k", 0) or 0)
    if k == 0 and n == 0:
        raise RuntimeError("No data in memory")

    var_map = self._get_var_index_map()
    for v in vars:
        if v not in var_map:
            raise ValueError(f"Invalid variable: {v}")

    # Negative values would turn into negative observation numbers or
    # wrap-around slices; clamp them so the page is simply empty or starts
    # at the first row.
    offset = max(offset, 0)
    limit = max(limit, 0)

    if obs_indices is None:
        obs_list = list(range(offset, min(offset + limit, n)))
    else:
        obs_list = obs_indices[offset : offset + limit]

    # Data.get is only called when there is at least one observation to
    # read, in both the filtered and the unfiltered case.
    rows = Data.get(var=vars, obs=obs_list) if obs_list else []
    returned = len(rows)

    out_vars = list(vars)
    if include_obs_no:
        out_vars = ["_n"] + out_vars

    out_rows: list[list[Any]] = []
    truncated_cells = 0
    for idx, raw in enumerate(rows):
        norm_row: list[Any] = []
        if include_obs_no:
            # Observation numbers are reported 1-based.
            norm_row.append(int(obs_list[idx]) + 1)
        for cell in raw:
            norm, truncated = self._normalize_cell(cell, max_chars=max_chars)
            if truncated:
                truncated_cells += 1
            norm_row.append(norm)
        out_rows.append(norm_row)

    return {
        "vars": out_vars,
        "rows": out_rows,
        "returned": returned,
        "truncated_cells": truncated_cells,
    }
|
|
1345
|
+
|
|
1346
|
+
# Matches identifier-like tokens: a letter or underscore followed by letters,
# digits or underscores.  _extract_filter_vars uses it to find candidate
# variable names in a filter expression.
_FILTER_IDENT = re.compile(r"\b[A-Za-z_][A-Za-z0-9_]*\b")
|
|
1347
|
+
|
|
1348
|
+
def _extract_filter_vars(self, filter_expr: str) -> List[str]:
|
|
1349
|
+
tokens = set(self._FILTER_IDENT.findall(filter_expr or ""))
|
|
1350
|
+
# Exclude python keywords we might inject.
|
|
1351
|
+
exclude = {"and", "or", "not", "True", "False", "None"}
|
|
1352
|
+
var_map = self._get_var_index_map()
|
|
1353
|
+
vars_used = [t for t in tokens if t not in exclude and t in var_map]
|
|
1354
|
+
return sorted(vars_used)
|
|
1355
|
+
|
|
1356
|
+
def _compile_filter_expr(self, filter_expr: str) -> Any:
|
|
1357
|
+
expr = (filter_expr or "").strip()
|
|
1358
|
+
if not expr:
|
|
1359
|
+
raise ValueError("Empty filter")
|
|
1360
|
+
|
|
1361
|
+
# Stata boolean operators.
|
|
1362
|
+
expr = expr.replace("&", " and ").replace("|", " or ")
|
|
1363
|
+
|
|
1364
|
+
# Replace missing literal '.' (but not numeric decimals like 0.5).
|
|
1365
|
+
expr = re.sub(r"(?<![0-9])\.(?![0-9A-Za-z_])", "None", expr)
|
|
1366
|
+
|
|
1367
|
+
try:
|
|
1368
|
+
return compile(expr, "<filterExpr>", "eval")
|
|
1369
|
+
except Exception as e:
|
|
1370
|
+
raise ValueError(f"Invalid filter expression: {e}")
|
|
1371
|
+
|
|
1372
|
+
def validate_filter_expr(self, filter_expr: str) -> None:
    """Check that a filter expression compiles against the current dataset.

    Args:
        filter_expr: The filter text to validate.

    Raises:
        RuntimeError: The dataset has no variables and no observations.
        ValueError: Raised by ``_compile_filter_expr`` when the filter is
            empty or does not compile.
    """
    if not self._initialized:
        self.init()

    snapshot = self.get_dataset_state()
    has_vars = int(snapshot.get("k", 0) or 0) != 0
    if not has_vars and int(snapshot.get("n", 0) or 0) == 0:
        raise RuntimeError("No data in memory")

    # Constant expressions such as "1" or "True" reference no variables and
    # are still allowed, so the expression is compiled whatever the
    # variable scan finds.
    self._extract_filter_vars(filter_expr)
    self._compile_filter_expr(filter_expr)
|
|
1385
|
+
|
|
1386
|
+
def compute_view_indices(self, filter_expr: str, *, chunk_size: int = 5000) -> List[int]:
    """Return the zero-based observation numbers that satisfy a filter.

    The dataset is read in chunks of ``chunk_size`` observations; only the
    variables named in the filter are read.  For each observation the
    compiled filter is evaluated with those variables bound to their
    values, missing values being bound to ``None``.

    Args:
        filter_expr: Stata-style filter text (see ``_compile_filter_expr``).
        chunk_size: Number of observations read per ``Data.get`` call; must
            be at least 1.

    Returns:
        Matching observation numbers in ascending order.

    Raises:
        RuntimeError: The dataset has no variables and no observations.
        ValueError: ``chunk_size`` is not positive, the filter does not
            compile, or evaluating it fails for some observation (the
            original exception is attached as ``__cause__``).
    """
    if chunk_size < 1:
        # range() rejects a zero step and a negative one would silently
        # yield no chunks at all, i.e. an empty result for any filter.
        raise ValueError("chunk_size must be a positive integer")

    if not self._initialized:
        self.init()

    from sfi import Data  # type: ignore[import-not-found]

    state = self.get_dataset_state()
    n = int(state.get("n", 0) or 0)
    k = int(state.get("k", 0) or 0)
    if k == 0 and n == 0:
        raise RuntimeError("No data in memory")

    vars_used = self._extract_filter_vars(filter_expr)
    code = self._compile_filter_expr(filter_expr)

    # NOTE(security): an empty ``__builtins__`` mapping limits what the
    # expression can name but is not a sandbox; review before accepting
    # filters from untrusted callers.
    eval_globals: Dict[str, Any] = {"__builtins__": {}}

    indices: List[int] = []
    for start in range(0, n, chunk_size):
        obs_list = list(range(start, min(start + chunk_size, n)))
        # Constant filters reference no variables, so nothing is read.
        raw_rows = Data.get(var=vars_used, obs=obs_list) if vars_used else None

        for row_i, obs in enumerate(obs_list):
            env: Dict[str, Any] = {}
            if raw_rows is not None:
                row = raw_rows[row_i]
                for j, v in enumerate(vars_used):
                    val = row[j]
                    env[v] = None if self._is_stata_missing(val) else val

            try:
                matched = bool(eval(code, eval_globals, env))
            except Exception as e:
                raise ValueError(f"Invalid filter: {e}") from e

            if matched:
                indices.append(int(obs))

    return indices
|
|
1427
|
+
|
|
1428
|
+
def get_variable_details(self, varname: str) -> str:
    """Run ``codebook`` for a variable and return the resulting text.

    Args:
        varname: Text placed after ``codebook`` in the command.

    Returns:
        The command's stdout on success; otherwise the error message when
        the response carries an error, else an empty string.
    """
    # NOTE(review): varname is interpolated into the command verbatim —
    # confirm callers only pass trusted or validated names.
    outcome = self.run_command_structured(f"codebook {varname}", echo=True)
    if not outcome.success:
        return outcome.error.message if outcome.error else ""
    return outcome.stdout
|
|
1436
|
+
|
|
1437
|
+
def list_variables_structured(self) -> VariablesResponse:
    """Wrap the output of ``list_variables`` in a ``VariablesResponse``.

    Each dict from ``list_variables`` becomes a ``VariableInfo``; a missing
    ``name`` falls back to ``""`` while ``label`` and ``type`` fall back
    to ``None``.
    """
    entries = [
        VariableInfo(
            name=raw.get("name", ""),
            label=raw.get("label"),
            type=raw.get("type"),
        )
        for raw in self.list_variables()
    ]
    return VariablesResponse(variables=entries)
|
|
1448
|
+
|
|
1449
|
+
def list_graphs(self, *, force_refresh: bool = False) -> List[str]:
    """Return the names of the graphs in memory, using a TTL cache.

    Args:
        force_refresh: When true, the cached list is ignored and Stata is
            queried again (unless a command is currently executing).

    Returns:
        The graph names, with internal names mapped back through
        ``self._graph_name_reverse`` where an alias exists.  While
        ``self._is_executing`` is set, or when the query fails, the cached
        list is returned if there is one, otherwise an empty list.
    """
    if not self._initialized:
        self.init()

    import time

    # Prevent recursive Stata calls: while a command is executing, answer
    # from the cache only.
    if self._is_executing:
        with self._list_graphs_cache_lock:
            cached = self._list_graphs_cache
            if cached is None:
                logger.debug("Recursive list_graphs call prevented, returning empty list")
                return []
            logger.debug("Recursive list_graphs call prevented, returning cached value")
            return cached

    # Serve from the cache while it is younger than the TTL.
    now = time.time()
    with self._list_graphs_cache_lock:
        cached = self._list_graphs_cache
        if not force_refresh and cached is not None:
            age = now - self._list_graphs_cache_time
            if age < self.LIST_GRAPHS_TTL:
                return cached

    # Cache miss or expired: ask Stata.
    try:
        # 'graph dir' leaves the list in r(list); run quietly to avoid output.
        self.stata.run("quietly graph dir, memory")

        # r(list) is copied into a global macro that sfi can read back.
        from sfi import Macro  # type: ignore[import-not-found]
        self.stata.run("global mcp_graph_list `r(list)'")
        listing = Macro.getGlobal("mcp_graph_list")
        internal_names = listing.split() if listing else []

        # Map internal Stata names back to user-facing names when we have an alias.
        alias_of = getattr(self, "_graph_name_reverse", {})
        names = [alias_of.get(item, item) for item in internal_names]

        with self._list_graphs_cache_lock:
            self._list_graphs_cache = names
            self._list_graphs_cache_time = now

        return names

    except Exception as e:
        # On error, fall back to the cached result if there is one.
        with self._list_graphs_cache_lock:
            stale = self._list_graphs_cache
            if stale is not None:
                logger.warning(f"list_graphs failed, returning cached result: {e}")
                return stale
        logger.warning(f"list_graphs failed, no cache available: {e}")
        return []
|
|
1507
|
+
|
|
1508
|
+
def list_graphs_structured(self) -> GraphListResponse:
    """Return ``list_graphs()`` as a ``GraphListResponse``.

    The last name in the list is the one flagged ``active``; every entry whose
    name equals it gets ``active=True``.
    """
    names = self.list_graphs()
    last_name = names[-1] if names else None
    entries = []
    for graph in names:
        entries.append(GraphInfo(name=graph, active=(graph == last_name)))
    return GraphListResponse(graphs=entries)
|
|
1513
|
+
|
|
1514
|
+
def invalidate_list_graphs_cache(self) -> None:
    """Drop the cached ``list_graphs`` result so the next call queries Stata."""
    with self._list_graphs_cache_lock:
        # Reset the value and its timestamp together under the lock.
        self._list_graphs_cache, self._list_graphs_cache_time = None, 0
|
|
1519
|
+
|
|
1520
|
+
def export_graph(self, graph_name: Optional[str] = None, filename: Optional[str] = None, format: str = "pdf") -> str:
    """Export a graph to a PDF or PNG file and return its absolute path.

    On Windows, PyStata can crash when exporting PNGs directly, so PNG export
    on Windows saves the graph to a .gph file and runs the Stata executable in
    batch mode to render the PNG out-of-process.

    Args:
        graph_name: Graph to export; the current graph when omitted.
        filename: Target path. A temp file is created when omitted; an
            existing file at this path is removed first.
        format: ``"pdf"`` (default) or ``"png"``, case-insensitive.

    Returns:
        Absolute path of the exported file.

    Raises:
        ValueError: If ``format`` is not pdf or png.
        RuntimeError: If the export fails, or the resulting file is missing,
            empty, or larger than ``self.MAX_GRAPH_BYTES``.
    """
    import tempfile

    fmt = (format or "pdf").strip().lower()
    if fmt not in {"pdf", "png"}:
        raise ValueError(f"Unsupported graph export format: {format}. Allowed: pdf, png.")

    if not filename:
        with tempfile.NamedTemporaryFile(prefix="mcp_stata_", suffix=f".{fmt}", delete=False) as tmp:
            filename = tmp.name
    elif os.path.exists(filename):
        # Ensure a fresh start; Stata's `replace` will overwrite it anyway,
        # so failing to delete here is not fatal.
        try:
            os.remove(filename)
        except OSError:
            pass

    # Keep the user-facing path as a normal absolute path.
    user_filename = os.path.abspath(filename)

    # Response of the in-process export. It stays None on the out-of-process
    # path, which has no CommandResponse to report from.
    resp = None

    if fmt == "png" and os.name == "nt":
        self._export_png_out_of_process(graph_name, user_filename)
    else:
        # Stata prefers forward slashes in its command parser on Windows.
        filename_for_stata = user_filename.replace("\\", "/")
        name_opt = ""
        if graph_name:
            resolved = self._resolve_graph_name_for_stata(graph_name)
            name_opt = f'name("{resolved}") '
        cmd = f'graph export "{filename_for_stata}", {name_opt}replace as({fmt})'

        # Avoid stdout/stderr redirection for graph export because PyStata's
        # output thread can crash on Windows when we swap stdio handles.
        resp = self._exec_no_capture(cmd, echo=False)
        if not resp.success:
            # Retry once after a short pause in case Stata had a transient
            # file handle issue.
            time.sleep(0.2)
            resp = self._exec_no_capture(cmd, echo=False)
            if not resp.success:
                msg = resp.error.message if resp.error else f"graph export failed (rc={resp.rc})"
                raise RuntimeError(msg)

    if not os.path.exists(user_filename):
        # No file means the export failed; report Stata's message when one
        # is available.
        if resp is not None and resp.error:
            raise RuntimeError(resp.error.message)
        raise RuntimeError("graph export failed: file missing")

    try:
        size = os.path.getsize(user_filename)
        if size == 0:
            raise RuntimeError(f"Graph export failed: produced empty file {user_filename}")
        if size > self.MAX_GRAPH_BYTES:
            raise RuntimeError(
                f"Graph export failed: file too large (> {self.MAX_GRAPH_BYTES} bytes): {user_filename}"
            )
    except Exception:
        # Clean up empty, oversized or unreadable files before re-raising.
        try:
            os.remove(user_filename)
        except OSError:
            pass
        raise
    return user_filename

def _export_png_out_of_process(self, graph_name: Optional[str], user_filename: str) -> None:
    """Save the graph as .gph and render it to PNG with a batch-mode Stata process.

    All temporary files (.gph, .do and the batch .log) are removed on every
    exit path. Raises ``RuntimeError`` when saving fails, the Stata executable
    is unavailable, or the external process times out or exits non-zero.
    """
    import tempfile

    # (label, path) pairs removed in the finally block below.
    temp_paths = []
    try:
        # 1) Save the graph to a .gph file from the embedded session.
        with tempfile.NamedTemporaryFile(prefix="mcp_stata_graph_", suffix=".gph", delete=False) as gph_tmp:
            gph_path = gph_tmp.name
        temp_paths.append(("graph file", gph_path))
        gph_path_for_stata = gph_path.replace("\\", "/")

        # Make the target graph current, then save without name() (which
        # `graph save` does not accept there).
        if graph_name:
            resolved = self._resolve_graph_name_for_stata(graph_name)
            display_resp = self._exec_no_capture(f'graph display "{resolved}"', echo=False)
            if not display_resp.success:
                # Best effort: the save below then uses whatever graph is
                # current, so make that visible in the log.
                logger.warning("graph display failed for %s before PNG export", graph_name)
        save_resp = self._exec_no_capture(f'graph save "{gph_path_for_stata}", replace', echo=False)
        if not save_resp.success:
            msg = save_resp.error.message if save_resp.error else f"graph save failed (rc={save_resp.rc})"
            raise RuntimeError(msg)

        stata_exe = getattr(self, "_stata_exec_path", None)
        if not stata_exe or not os.path.exists(stata_exe):
            raise RuntimeError("Stata executable path unavailable for PNG export")

        # 2) Prepare a do-file that exports the PNG externally.
        user_filename_fwd = user_filename.replace("\\", "/")
        do_lines = [
            f'graph use "{gph_path_for_stata}"',
            f'graph export "{user_filename_fwd}", replace as(png)',
            "exit",
        ]
        # UTF-8 rather than ASCII so non-ASCII paths (e.g. user profile
        # directories) do not raise UnicodeEncodeError; ASCII-only content
        # is byte-identical either way.
        with tempfile.NamedTemporaryFile(
            prefix="mcp_stata_export_", suffix=".do", delete=False, mode="w", encoding="utf-8"
        ) as do_tmp:
            do_path = do_tmp.name
            temp_paths.append(("do-file", do_path))
            do_tmp.write("\n".join(do_lines))

        # Batch mode writes a log next to the do-file; clean it up as well.
        temp_paths.append(("log file", os.path.splitext(do_path)[0] + ".log"))
        workdir = os.path.dirname(do_path) or None

        try:
            completed = subprocess.run(
                [stata_exe, "/e", "do", do_path],
                capture_output=True,
                text=True,
                timeout=30,
                cwd=workdir,
            )
        except subprocess.TimeoutExpired as exc:
            raise RuntimeError("External Stata export timed out") from exc
    finally:
        for label, path in temp_paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError:
                logger.warning("Failed to remove temporary %s: %s", label, path, exc_info=True)

    if completed.returncode != 0:
        err = completed.stderr.strip() or completed.stdout.strip() or str(completed.returncode)
        raise RuntimeError(f"External Stata export failed: {err}")
|
|
1657
|
+
|
|
1658
|
+
def get_help(self, topic: str, plain_text: bool = False) -> str:
|
|
1659
|
+
"""Returns help text as Markdown (default) or plain text."""
|
|
1660
|
+
if not self._initialized:
|
|
1661
|
+
self.init()
|
|
1662
|
+
|
|
1663
|
+
# Try to locate the .sthlp help file
|
|
1664
|
+
# We use 'capture' to avoid crashing if not found
|
|
1665
|
+
self.stata.run(f"capture findfile {topic}.sthlp")
|
|
1666
|
+
|
|
1667
|
+
# Retrieve the found path from r(fn)
|
|
1668
|
+
from sfi import Macro # type: ignore[import-not-found]
|
|
1669
|
+
self.stata.run("global mcp_help_file `r(fn)'")
|
|
1670
|
+
fn = Macro.getGlobal("mcp_help_file")
|
|
1671
|
+
|
|
1672
|
+
if fn and os.path.exists(fn):
|
|
1673
|
+
try:
|
|
1674
|
+
with open(fn, 'r', encoding='utf-8', errors='replace') as f:
|
|
413
1675
|
smcl = f.read()
|
|
414
1676
|
if plain_text:
|
|
415
1677
|
return self._smcl_to_text(smcl)
|
|
@@ -428,78 +1690,514 @@ class StataClient:
|
|
|
428
1690
|
"""Returns e() and r() results."""
|
|
429
1691
|
if not self._initialized:
|
|
430
1692
|
self.init()
|
|
431
|
-
|
|
432
|
-
from sfi import Scalar, Macro
|
|
433
|
-
|
|
1693
|
+
|
|
434
1694
|
results = {"r": {}, "e": {}}
|
|
435
|
-
|
|
1695
|
+
|
|
436
1696
|
# We parse 'return list' output as there is no direct bulk export of stored results
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
1697
|
+
raw_r_resp = self.run_command_structured("return list", echo=True)
|
|
1698
|
+
raw_e_resp = self.run_command_structured("ereturn list", echo=True)
|
|
1699
|
+
raw_r = raw_r_resp.stdout if raw_r_resp.success else (raw_r_resp.error.snippet if raw_r_resp.error else "")
|
|
1700
|
+
raw_e = raw_e_resp.stdout if raw_e_resp.success else (raw_e_resp.error.snippet if raw_e_resp.error else "")
|
|
1701
|
+
|
|
440
1702
|
# Simple parser
|
|
441
1703
|
def parse_list(text):
|
|
442
1704
|
data = {}
|
|
443
1705
|
# We don't strictly need to track sections if we check patterns
|
|
444
1706
|
for line in text.splitlines():
|
|
445
1707
|
line = line.strip()
|
|
446
|
-
if not line:
|
|
447
|
-
|
|
1708
|
+
if not line:
|
|
1709
|
+
continue
|
|
1710
|
+
|
|
448
1711
|
# scalars: r(name) = value
|
|
449
1712
|
if "=" in line and ("r(" in line or "e(" in line):
|
|
450
1713
|
try:
|
|
451
1714
|
name_part, val_part = line.split("=", 1)
|
|
452
1715
|
name_part = name_part.strip() # "r(mean)"
|
|
453
1716
|
val_part = val_part.strip() # "6165.2..."
|
|
454
|
-
|
|
455
|
-
# Extract just the name inside r(...) if desired,
|
|
456
|
-
# or keep full key "r(mean)".
|
|
1717
|
+
|
|
1718
|
+
# Extract just the name inside r(...) if desired,
|
|
1719
|
+
# or keep full key "r(mean)".
|
|
457
1720
|
# User likely wants "mean" inside "r" dict.
|
|
458
|
-
|
|
1721
|
+
|
|
459
1722
|
if "(" in name_part and name_part.endswith(")"):
|
|
460
1723
|
# r(mean) -> mean
|
|
461
1724
|
start = name_part.find("(") + 1
|
|
462
1725
|
end = name_part.find(")")
|
|
463
1726
|
key = name_part[start:end]
|
|
464
1727
|
data[key] = val_part
|
|
465
|
-
except:
|
|
466
|
-
|
|
1728
|
+
except Exception:
|
|
1729
|
+
pass
|
|
1730
|
+
|
|
467
1731
|
# macros: r(name) : "value"
|
|
468
1732
|
elif ":" in line and ("r(" in line or "e(" in line):
|
|
469
|
-
|
|
1733
|
+
try:
|
|
470
1734
|
name_part, val_part = line.split(":", 1)
|
|
471
1735
|
name_part = name_part.strip()
|
|
472
1736
|
val_part = val_part.strip().strip('"')
|
|
473
|
-
|
|
1737
|
+
|
|
474
1738
|
if "(" in name_part and name_part.endswith(")"):
|
|
475
1739
|
start = name_part.find("(") + 1
|
|
476
1740
|
end = name_part.find(")")
|
|
477
1741
|
key = name_part[start:end]
|
|
478
1742
|
data[key] = val_part
|
|
479
|
-
|
|
1743
|
+
except Exception:
|
|
1744
|
+
pass
|
|
480
1745
|
return data
|
|
481
|
-
|
|
1746
|
+
|
|
482
1747
|
results["r"] = parse_list(raw_r)
|
|
483
1748
|
results["e"] = parse_list(raw_e)
|
|
484
|
-
|
|
1749
|
+
|
|
485
1750
|
return results
|
|
486
1751
|
|
|
487
|
-
def
|
|
488
|
-
"""
|
|
1752
|
+
def invalidate_graph_cache(self, graph_name: Optional[str] = None) -> None:
    """Invalidate the cache for one graph or for all graphs.

    Besides the cache entry itself (path and ``<name>_hash``), the LRU
    bookkeeping (access time, size, running byte total) is cleared too, so
    that the size budget used for eviction stays accurate. Cached files on
    disk are left untouched.

    Args:
        graph_name: Specific graph name to invalidate. If None, clears all cache.
    """
    self._initialize_cache()

    with self._cache_lock:
        if graph_name is None:
            # Clear the whole cache together with its bookkeeping.
            self._preemptive_cache.clear()
            self._cache_access_times.clear()
            self._cache_sizes.clear()
            self._total_cache_size = 0
            return

        # Clear the specific graph, its content hash and its bookkeeping.
        self._preemptive_cache.pop(graph_name, None)
        self._preemptive_cache.pop(f"{graph_name}_hash", None)
        self._cache_access_times.pop(graph_name, None)
        self._total_cache_size -= self._cache_sizes.pop(graph_name, 0)
|
|
1772
|
+
|
|
1773
|
+
def _initialize_cache(self) -> None:
    """Create the graph-cache state once per instance, under a class-level lock.

    The first call creates the cache dictionaries, a temporary cache
    directory and the instance lock, and registers ``_cleanup_cache`` with
    ``atexit``. Later calls only recreate the cache directory if it is unset
    or no longer exists on disk.
    """
    import atexit
    import os
    import tempfile
    import threading
    import uuid

    def _fresh_cache_dir() -> str:
        # Unique prefix (random hex + pid) to avoid clashing with other caches.
        prefix = f"preemptive_cache_{uuid.uuid4().hex[:8]}_{os.getpid()}"
        return tempfile.mkdtemp(prefix=prefix)

    with StataClient._cache_init_lock:  # class-level lock
        if hasattr(self, '_cache_initialized'):
            # Already initialized, but the directory might have been removed.
            current_dir = getattr(self, '_preemptive_cache_dir', None)
            if not current_dir or not os.path.isdir(current_dir):
                self._preemptive_cache_dir = _fresh_cache_dir()
            return

        self._preemptive_cache = {}
        self._cache_access_times = {}  # last access time per graph, for LRU
        self._cache_sizes = {}  # size in bytes per cached graph
        self._total_cache_size = 0  # running total of cached bytes
        self._preemptive_cache_dir = _fresh_cache_dir()
        self._cache_lock = threading.Lock()
        self._cache_initialized = True

        # Remove the cache directory when the interpreter exits.
        atexit.register(self._cleanup_cache)
|
|
1802
|
+
|
|
1803
|
+
def _cleanup_cache(self) -> None:
    """Delete the cache directory (best effort) and empty the cache mapping."""
    import os
    import shutil

    cache_dir = getattr(self, '_preemptive_cache_dir', None)
    if cache_dir:
        try:
            shutil.rmtree(cache_dir, ignore_errors=True)
        except Exception:
            pass  # Best effort cleanup

    if hasattr(self, '_preemptive_cache'):
        self._preemptive_cache.clear()
|
|
1816
|
+
|
|
1817
|
+
def _evict_cache_if_needed(self, new_item_size: int = 0) -> None:
    """Evict least recently used cache items if cache size limits are exceeded.

    Bookkeeping for graphs that are no longer in ``_preemptive_cache`` is
    purged first, so entries dropped elsewhere cannot keep
    ``_total_cache_size`` inflated and trigger needless evictions.

    NOTE: The caller is responsible for holding ``self._cache_lock`` while
    invoking this method, so that eviction and subsequent cache insertion
    (if any) occur within a single critical section.

    Args:
        new_item_size: Size in bytes of the item about to be inserted.
    """
    # Purge tracking entries whose graph is no longer cached.
    tracked = set(self._cache_access_times) | set(self._cache_sizes)
    for orphan in tracked - set(self._preemptive_cache):
        self._cache_access_times.pop(orphan, None)
        self._total_cache_size -= self._cache_sizes.pop(orphan, 0)

    max_entries = self.MAX_CACHE_SIZE
    max_bytes = self.MAX_CACHE_BYTES

    # Nothing to do while both the entry count and the byte budget hold.
    if (len(self._preemptive_cache) <= max_entries
            and self._total_cache_size + new_item_size <= max_bytes):
        return

    evicted_count = 0
    # Oldest access time first.
    oldest_first = sorted(self._cache_access_times, key=self._cache_access_times.get)
    for graph_name in oldest_first:
        if (len(self._preemptive_cache) < max_entries
                and self._total_cache_size + new_item_size <= max_bytes):
            break

        cache_path = self._preemptive_cache.pop(graph_name, None)
        if cache_path is None:
            continue

        # Remove the cached file; a leftover file is harmless, so only note it.
        try:
            if os.path.exists(cache_path):
                os.remove(cache_path)
        except OSError:
            logger.debug("Could not remove evicted cache file: %s", cache_path)

        # Update tracking and drop the matching content hash.
        self._cache_access_times.pop(graph_name, None)
        self._total_cache_size -= self._cache_sizes.pop(graph_name, 0)
        self._preemptive_cache.pop(f"{graph_name}_hash", None)
        evicted_count += 1

    if evicted_count > 0:
        logger.debug(f"Evicted {evicted_count} items from graph cache due to size limits")
|
|
1875
|
+
|
|
1876
|
+
def _get_content_hash(self, data: bytes) -> str:
    """Return the hexadecimal MD5 digest of ``data``, used for cache validation."""
    import hashlib

    digest = hashlib.md5()
    digest.update(data)
    return digest.hexdigest()
|
|
1880
|
+
|
|
1881
|
+
def _sanitize_filename(self, name: str) -> str:
    """Turn a graph name into a string that is safe to use as a file name.

    Reserved path characters, and then anything that is not a word character,
    ``-``, ``_`` or ``.``, are replaced by ``_``; the result is capped at 100
    characters.
    """
    import re

    cleaned = name
    for pattern in (r'[<>:"/\\|?*]', r'[^\w\-_.]'):
        cleaned = re.sub(pattern, '_', cleaned)
    # Slicing is a no-op for names that are already short enough.
    return cleaned[:100]
|
|
1889
|
+
|
|
1890
|
+
def _validate_graph_exists(self, graph_name: str) -> bool:
    """Return True if the graph is listed by Stata and can be displayed.

    The graph list is refreshed first; a listed graph is then confirmed by
    running ``graph display`` on its resolved name. Any exception counts as
    "does not exist".
    """
    try:
        if graph_name not in self.list_graphs(force_refresh=True):
            return False

        # Second check: the graph must actually be displayable.
        stata_name = self._resolve_graph_name_for_stata(graph_name)
        outcome = self._exec_no_capture(f'graph display {stata_name}', echo=False)
        return outcome.success
    except Exception:
        return False
|
|
1905
|
+
|
|
1906
|
+
def _is_cache_valid(self, graph_name: str, cache_path: str) -> bool:
    """Check whether the cached SVG for ``graph_name`` still matches the graph.

    The graph is re-exported to a temporary SVG file and the hash of that
    export is compared with the hash stored under ``<graph_name>_hash`` in
    ``_preemptive_cache``.

    Args:
        graph_name: User-facing graph name.
        cache_path: Path of the cached file. Not read here; validation relies
            on the stored content hash only.

    Returns:
        True when the hashes match; False when they differ or when the check
        cannot be carried out (export failure or any exception).
    """
    import os
    import tempfile

    temp_file = None
    try:
        # mkstemp yields a unique path that does not depend on the graph
        # name, which may contain characters that are invalid in file names.
        fd, temp_file = tempfile.mkstemp(prefix="mcp_stata_cache_check_", suffix=".svg")
        os.close(fd)

        # Built outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        temp_file_for_stata = temp_file.replace("\\", "/")

        resolved = self._resolve_graph_name_for_stata(graph_name)
        export_cmd = f'graph export "{temp_file_for_stata}", name({resolved}) replace as(svg)'
        resp = self._exec_no_capture(export_cmd, echo=False)

        if resp.success and os.path.exists(temp_file) and os.path.getsize(temp_file) > 0:
            with open(temp_file, 'rb') as f:
                current_data = f.read()

            current_hash = self._get_content_hash(current_data)
            cached_hash = self._preemptive_cache.get(f"{graph_name}_hash")
            return cached_hash == current_hash
    except Exception:
        # Deliberately best effort: an unverifiable cache entry is treated
        # as invalid below.
        logger.debug("Cache validation failed for graph %s", graph_name, exc_info=True)
    finally:
        if temp_file is not None:
            try:
                os.remove(temp_file)
            except OSError:
                pass

    return False  # Assume invalid if we can't verify
|
|
1933
|
+
|
|
1934
|
+
def export_graphs_all(self, use_base64: bool = False) -> GraphExportResponse:
    """Exports all graphs to file paths (default) or base64-encoded strings.

    Graphs with a valid cached SVG are served from the cache. All others are
    exported from Stata, written into the cache directory and recorded in the
    cache. Failures for individual graphs are logged as warnings and those
    graphs are left out of the response.

    Args:
        use_base64: If True, returns base64-encoded images. If False (default),
            returns file paths to exported SVG files.

    Returns:
        A ``GraphExportResponse`` with one entry per successfully exported graph.
    """
    import base64
    import hashlib
    import logging
    import os
    import tempfile
    import time
    import uuid

    # Bound up front under its own name: assigning to `logger` later in this
    # function would make it a local and leave the nested helper below with
    # an unbound free variable.
    log = logging.getLogger(__name__)

    exports: List[GraphExport] = []
    graph_names = self.list_graphs(force_refresh=True)

    if not graph_names:
        return GraphExportResponse(graphs=exports)

    # Initialize cache in thread-safe manner
    self._initialize_cache()

    def _cache_keyed_svg_path(name: str) -> str:
        # Sanitized name plus a short hash of the raw name, so that different
        # names which sanitize identically do not share a file.
        safe_name = self._sanitize_filename(name)
        suffix = hashlib.md5((name or "").encode("utf-8")).hexdigest()[:8]
        return os.path.join(self._preemptive_cache_dir, f"{safe_name}_{suffix}.svg")

    def _export_svg_bytes(name: str) -> bytes:
        # Export one graph to a throwaway SVG file and return its content.
        resolved = self._resolve_graph_name_for_stata(name)

        safe_temp_name = self._sanitize_filename(name)
        unique_filename = f"{safe_temp_name}_{uuid.uuid4().hex[:8]}_{os.getpid()}_{int(time.time())}.svg"
        svg_path = os.path.join(tempfile.gettempdir(), unique_filename)
        svg_path_for_stata = svg_path.replace("\\", "/")

        try:
            export_cmd = f'graph export "{svg_path_for_stata}", name({resolved}) replace as(svg)'
            export_resp = self._exec_no_capture(export_cmd, echo=False)

            if not export_resp.success:
                # Fallback: make the graph current, then export without name().
                display_resp = self._exec_no_capture(f'graph display {resolved}', echo=False)
                if display_resp.success:
                    export_cmd2 = f'graph export "{svg_path_for_stata}", replace as(svg)'
                    export_resp = self._exec_no_capture(export_cmd2, echo=False)
                else:
                    export_resp = display_resp

            if export_resp.success and os.path.exists(svg_path) and os.path.getsize(svg_path) > 0:
                with open(svg_path, "rb") as f:
                    return f.read()
            error_msg = getattr(export_resp, 'error', 'Unknown error')
            raise RuntimeError(f"Failed to export graph {name}: {error_msg}")
        finally:
            if os.path.exists(svg_path):
                try:
                    os.remove(svg_path)
                except OSError as e:
                    log.warning(f"Failed to cleanup temp file {svg_path}: {e}")

    def _to_export(name: str, data: bytes, path: str) -> GraphExport:
        # Build the response entry in the representation the caller asked for.
        if use_base64:
            return GraphExport(name=name, image_base64=base64.b64encode(data).decode("ascii"))
        return GraphExport(name=name, file_path=path)

    cached_graphs = {}
    uncached_graphs = []
    cache_errors = []

    # Split the graphs into those with a usable cache entry and the rest.
    with self._cache_lock:
        for name in graph_names:
            if name not in self._preemptive_cache:
                uncached_graphs.append(name)
                continue

            cached_path = self._preemptive_cache[name]
            if (os.path.exists(cached_path) and os.path.getsize(cached_path) > 0
                    and self._is_cache_valid(name, cached_path)):
                cached_graphs[name] = cached_path
                continue

            # Missing, empty or outdated: drop the entry with its bookkeeping.
            uncached_graphs.append(name)
            self._preemptive_cache.pop(name, None)
            self._preemptive_cache.pop(f"{name}_hash", None)
            self._cache_access_times.pop(name, None)
            self._total_cache_size -= self._cache_sizes.pop(name, 0)

    for name, cached_path in cached_graphs.items():
        try:
            if use_base64:
                with open(cached_path, "rb") as f:
                    svg_b64 = base64.b64encode(f.read()).decode("ascii")
                exports.append(GraphExport(name=name, image_base64=svg_b64))
            else:
                exports.append(GraphExport(name=name, file_path=cached_path))
        except Exception as e:
            cache_errors.append(f"Failed to read cached graph {name}: {e}")
            # Fall back to uncached processing
            uncached_graphs.append(name)

    for name in uncached_graphs:
        try:
            svg_data = _export_svg_bytes(name)
        except Exception as e:
            cache_errors.append(f"Failed to cache graph {name}: {e}")
            continue

        cache_path = _cache_keyed_svg_path(name)
        try:
            with open(cache_path, 'wb') as f:
                f.write(svg_data)

            item_size = len(svg_data)
            # Eviction and insertion form one critical section, as
            # _evict_cache_if_needed requires.
            with self._cache_lock:
                # Retire old bookkeeping for this name first: the size must
                # not be counted twice, and eviction must not pick this graph
                # and delete the file that was just written.
                self._cache_access_times.pop(name, None)
                self._total_cache_size -= self._cache_sizes.pop(name, 0)
                self._evict_cache_if_needed(item_size)

                self._preemptive_cache[name] = cache_path
                # Store content hash for validation
                self._preemptive_cache[f"{name}_hash"] = self._get_content_hash(svg_data)
                self._cache_access_times[name] = time.time()
                self._cache_sizes[name] = item_size
                self._total_cache_size += item_size

            exports.append(_to_export(name, svg_data, cache_path))
        except Exception as e:
            cache_errors.append(f"Failed to cache graph {name}: {e}")
            # Still return the result even if caching fails
            if use_base64:
                exports.append(_to_export(name, svg_data, ""))
            else:
                # Create temp file for immediate use
                safe_name = self._sanitize_filename(name)
                temp_path = os.path.join(tempfile.gettempdir(), f"{safe_name}_{uuid.uuid4().hex[:8]}.svg")
                with open(temp_path, 'wb') as f:
                    f.write(svg_data)
                exports.append(GraphExport(name=name, file_path=temp_path))

    # Log errors if any occurred
    for error in cache_errors:
        log.warning(error)

    return GraphExportResponse(graphs=exports)
|
|
500
2098
|
|
|
501
|
-
def
|
|
502
|
-
|
|
2099
|
+
def cache_graph_on_creation(self, graph_name: str) -> bool:
    """Export a graph to the SVG cache right after it has been created.

    Call this method right after creating a graph to pre-emptively cache it,
    so that later exports of the unchanged graph can be served from the cache.

    Args:
        graph_name: Name of the graph to cache

    Returns:
        True if caching succeeded, False otherwise
    """
    import logging
    import os
    import time

    log = logging.getLogger(__name__)

    # Initialize cache in thread-safe manner
    self._initialize_cache()

    # Invalidate list_graphs cache since a new graph was created
    self.invalidate_list_graphs_cache()

    # Check if already cached and valid
    with self._cache_lock:
        if graph_name in self._preemptive_cache:
            cache_path = self._preemptive_cache[graph_name]
            if (os.path.exists(cache_path) and os.path.getsize(cache_path) > 0
                    and self._is_cache_valid(graph_name, cache_path)):
                # Update access time for LRU
                self._cache_access_times[graph_name] = time.time()
                return True

            # Missing, empty or stale: remove the entry with all of its
            # bookkeeping so a failed re-export cannot leave a dangling entry.
            self._preemptive_cache.pop(graph_name, None)
            self._preemptive_cache.pop(f"{graph_name}_hash", None)
            self._cache_access_times.pop(graph_name, None)
            self._total_cache_size -= self._cache_sizes.pop(graph_name, 0)

    try:
        # Sanitize graph name for file system
        safe_name = self._sanitize_filename(graph_name)
        cache_path = os.path.join(self._preemptive_cache_dir, f"{safe_name}.svg")
        cache_path_for_stata = cache_path.replace("\\", "/")

        resolved_graph_name = self._resolve_graph_name_for_stata(graph_name)
        graph_name_q = self._stata_quote(resolved_graph_name)

        export_cmd = f'graph export "{cache_path_for_stata}", name({graph_name_q}) replace as(svg)'
        resp = self._exec_no_capture(export_cmd, echo=False)

        # Fallback: some graph names (spaces, slashes, backslashes) can confuse
        # Stata's parser in name() even when the graph exists. In that case,
        # make the graph current, then export without name().
        if not resp.success:
            try:
                display_resp = self._exec_no_capture(f'graph display {graph_name_q}', echo=False)
                if display_resp.success:
                    export_cmd2 = f'graph export "{cache_path_for_stata}", replace as(svg)'
                    resp = self._exec_no_capture(export_cmd2, echo=False)
            except Exception:
                # Best effort: the failure of the first export is reported below.
                log.debug("Fallback export failed for graph %s", graph_name, exc_info=True)

        if resp.success and os.path.exists(cache_path) and os.path.getsize(cache_path) > 0:
            # Read the data to compute hash
            with open(cache_path, 'rb') as f:
                data = f.read()
            item_size = len(data)

            # Eviction and insertion form one critical section, as
            # _evict_cache_if_needed requires.
            with self._cache_lock:
                # Retire leftover bookkeeping for this name first: the size
                # must not be counted twice, and eviction must not pick this
                # graph and delete the file that was just written.
                self._cache_access_times.pop(graph_name, None)
                self._total_cache_size -= self._cache_sizes.pop(graph_name, 0)
                self._evict_cache_if_needed(item_size)

                self._preemptive_cache[graph_name] = cache_path
                # Store content hash for validation
                self._preemptive_cache[f"{graph_name}_hash"] = self._get_content_hash(data)
                self._cache_access_times[graph_name] = time.time()
                self._cache_sizes[graph_name] = item_size
                self._total_cache_size += item_size

            return True

        error_msg = getattr(resp, 'error', 'Unknown error')
        log.warning(f"Failed to cache graph {graph_name}: {error_msg}")

    except Exception as e:
        log.warning(f"Exception caching graph {graph_name}: {e}")

    return False
|
|
2198
|
+
|
|
2199
|
+
def run_do_file(self, path: str, echo: bool = True, trace: bool = False, max_output_lines: Optional[int] = None, cwd: Optional[str] = None) -> CommandResponse:
|
|
2200
|
+
if cwd is not None and not os.path.isdir(cwd):
|
|
503
2201
|
return CommandResponse(
|
|
504
2202
|
command=f'do "{path}"',
|
|
505
2203
|
rc=601,
|
|
@@ -507,14 +2205,133 @@ class StataClient:
|
|
|
507
2205
|
stderr=None,
|
|
508
2206
|
success=False,
|
|
509
2207
|
error=ErrorEnvelope(
|
|
510
|
-
message=f"
|
|
2208
|
+
message=f"cwd not found: {cwd}",
|
|
511
2209
|
rc=601,
|
|
512
2210
|
command=path,
|
|
513
2211
|
),
|
|
514
2212
|
)
|
|
515
|
-
return self._exec_with_capture(f'do "{path}"', echo=echo, trace=trace)
|
|
516
2213
|
|
|
517
|
-
|
|
2214
|
+
effective_path = path
|
|
2215
|
+
if cwd is not None and not os.path.isabs(path):
|
|
2216
|
+
effective_path = os.path.abspath(os.path.join(cwd, path))
|
|
2217
|
+
|
|
2218
|
+
if not os.path.exists(effective_path):
|
|
2219
|
+
return CommandResponse(
|
|
2220
|
+
command=f'do "{effective_path}"',
|
|
2221
|
+
rc=601,
|
|
2222
|
+
stdout="",
|
|
2223
|
+
stderr=None,
|
|
2224
|
+
success=False,
|
|
2225
|
+
error=ErrorEnvelope(
|
|
2226
|
+
message=f"Do-file not found: {effective_path}",
|
|
2227
|
+
rc=601,
|
|
2228
|
+
command=effective_path,
|
|
2229
|
+
),
|
|
2230
|
+
)
|
|
2231
|
+
|
|
2232
|
+
if not self._initialized:
|
|
2233
|
+
self.init()
|
|
2234
|
+
|
|
2235
|
+
start_time = time.time()
|
|
2236
|
+
exc: Optional[Exception] = None
|
|
2237
|
+
path_for_stata = effective_path.replace("\\", "/")
|
|
2238
|
+
command = f'do "{path_for_stata}"'
|
|
2239
|
+
|
|
2240
|
+
log_file = tempfile.NamedTemporaryFile(
|
|
2241
|
+
prefix="mcp_stata_",
|
|
2242
|
+
suffix=".log",
|
|
2243
|
+
delete=False,
|
|
2244
|
+
mode="w",
|
|
2245
|
+
encoding="utf-8",
|
|
2246
|
+
errors="replace",
|
|
2247
|
+
buffering=1,
|
|
2248
|
+
)
|
|
2249
|
+
log_path = log_file.name
|
|
2250
|
+
tail = TailBuffer(max_chars=8000)
|
|
2251
|
+
tee = FileTeeIO(log_file, tail)
|
|
2252
|
+
|
|
2253
|
+
rc = -1
|
|
2254
|
+
|
|
2255
|
+
with self._exec_lock:
|
|
2256
|
+
with self._temp_cwd(cwd):
|
|
2257
|
+
with self._redirect_io_streaming(tee, tee):
|
|
2258
|
+
try:
|
|
2259
|
+
if trace:
|
|
2260
|
+
self.stata.run("set trace on")
|
|
2261
|
+
ret = self.stata.run(command, echo=echo)
|
|
2262
|
+
# Some PyStata builds return output as a string rather than printing.
|
|
2263
|
+
if isinstance(ret, str) and ret:
|
|
2264
|
+
try:
|
|
2265
|
+
tee.write(ret)
|
|
2266
|
+
except Exception:
|
|
2267
|
+
pass
|
|
2268
|
+
except Exception as e:
|
|
2269
|
+
exc = e
|
|
2270
|
+
finally:
|
|
2271
|
+
rc = self._read_return_code()
|
|
2272
|
+
if trace:
|
|
2273
|
+
try:
|
|
2274
|
+
self.stata.run("set trace off")
|
|
2275
|
+
except Exception:
|
|
2276
|
+
pass
|
|
2277
|
+
|
|
2278
|
+
tee.close()
|
|
2279
|
+
|
|
2280
|
+
tail_text = tail.get_value()
|
|
2281
|
+
combined = (tail_text or "") + (f"\n{exc}" if exc else "")
|
|
2282
|
+
rc_hint = self._parse_rc_from_text(combined) if combined else None
|
|
2283
|
+
if exc is None and rc_hint is not None and rc_hint != 0:
|
|
2284
|
+
rc = rc_hint
|
|
2285
|
+
if exc is None and rc_hint is None:
|
|
2286
|
+
rc = 0 if rc is None or rc != 0 else rc
|
|
2287
|
+
success = rc == 0 and exc is None
|
|
2288
|
+
|
|
2289
|
+
error = None
|
|
2290
|
+
if not success:
|
|
2291
|
+
snippet = (tail_text[-800:] if tail_text else None) or (str(exc) if exc else None)
|
|
2292
|
+
rc_hint = self._parse_rc_from_text(combined) if combined else None
|
|
2293
|
+
rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
|
|
2294
|
+
line_no = self._parse_line_from_text(combined) if combined else None
|
|
2295
|
+
message = "Stata error"
|
|
2296
|
+
if tail_text and tail_text.strip():
|
|
2297
|
+
for line in reversed(tail_text.splitlines()):
|
|
2298
|
+
if line.strip():
|
|
2299
|
+
message = line.strip()
|
|
2300
|
+
break
|
|
2301
|
+
elif exc is not None:
|
|
2302
|
+
message = str(exc).strip() or message
|
|
2303
|
+
|
|
2304
|
+
error = ErrorEnvelope(
|
|
2305
|
+
message=message,
|
|
2306
|
+
rc=rc_final,
|
|
2307
|
+
line=line_no,
|
|
2308
|
+
command=command,
|
|
2309
|
+
log_path=log_path,
|
|
2310
|
+
snippet=snippet,
|
|
2311
|
+
trace=trace or None,
|
|
2312
|
+
)
|
|
2313
|
+
|
|
2314
|
+
duration = time.time() - start_time
|
|
2315
|
+
logger.info(
|
|
2316
|
+
"stata.run(do) rc=%s success=%s trace=%s duration_ms=%.2f path=%s",
|
|
2317
|
+
rc,
|
|
2318
|
+
success,
|
|
2319
|
+
trace,
|
|
2320
|
+
duration * 1000,
|
|
2321
|
+
effective_path,
|
|
2322
|
+
)
|
|
2323
|
+
|
|
2324
|
+
return CommandResponse(
|
|
2325
|
+
command=command,
|
|
2326
|
+
rc=rc,
|
|
2327
|
+
stdout="",
|
|
2328
|
+
stderr=None,
|
|
2329
|
+
log_path=log_path,
|
|
2330
|
+
success=success,
|
|
2331
|
+
error=error,
|
|
2332
|
+
)
|
|
2333
|
+
|
|
2334
|
+
def load_data(self, source: str, clear: bool = True, max_output_lines: Optional[int] = None) -> CommandResponse:
|
|
518
2335
|
src = source.strip()
|
|
519
2336
|
clear_suffix = ", clear" if clear else ""
|
|
520
2337
|
|
|
@@ -529,8 +2346,42 @@ class StataClient:
|
|
|
529
2346
|
else:
|
|
530
2347
|
cmd = f"sysuse {src}{clear_suffix}"
|
|
531
2348
|
|
|
532
|
-
|
|
2349
|
+
result = self._exec_with_capture(cmd, echo=True, trace=False)
|
|
2350
|
+
|
|
2351
|
+
# Truncate stdout if requested
|
|
2352
|
+
if max_output_lines is not None and result.stdout:
|
|
2353
|
+
lines = result.stdout.splitlines()
|
|
2354
|
+
if len(lines) > max_output_lines:
|
|
2355
|
+
truncated_lines = lines[:max_output_lines]
|
|
2356
|
+
truncated_lines.append(f"\n... (output truncated: showing {max_output_lines} of {len(lines)} lines)")
|
|
2357
|
+
result = CommandResponse(
|
|
2358
|
+
command=result.command,
|
|
2359
|
+
rc=result.rc,
|
|
2360
|
+
stdout="\n".join(truncated_lines),
|
|
2361
|
+
stderr=result.stderr,
|
|
2362
|
+
success=result.success,
|
|
2363
|
+
error=result.error,
|
|
2364
|
+
)
|
|
2365
|
+
|
|
2366
|
+
return result
|
|
2367
|
+
|
|
2368
|
+
def codebook(self, varname: str, trace: bool = False, max_output_lines: Optional[int] = None) -> CommandResponse:
    """Run Stata's ``codebook`` for ``varname`` and return the captured response.

    Args:
        varname: Text appended verbatim after ``codebook`` in the command.
        trace: Forwarded unchanged to ``_exec_with_capture``.
        max_output_lines: When not ``None``, keep at most this many lines of
            ``stdout`` and append a truncation notice. Negative values are
            treated as ``0`` so the slice never counts from the end.

    Returns:
        The response from ``_exec_with_capture``. If the output exceeded the
        limit, a new ``CommandResponse`` with the shortened ``stdout`` and all
        other fields carried over.
    """
    result = self._exec_with_capture(f"codebook {varname}", trace=trace)

    # No limit requested, or nothing to trim: hand back the response as-is.
    if max_output_lines is None or not result.stdout:
        return result

    # A negative limit would make lines[:limit] drop lines from the end and
    # produce a "showing -N of M" notice; clamp it to zero instead.
    limit = max(max_output_lines, 0)
    lines = result.stdout.splitlines()
    if len(lines) <= limit:
        return result

    kept = lines[:limit]
    # The leading "\n" leaves a blank separator line before the notice.
    kept.append(f"\n... (output truncated: showing {limit} of {len(lines)} lines)")
    return CommandResponse(
        command=result.command,
        rc=result.rc,
        stdout="\n".join(kept),
        stderr=result.stderr,
        # Carry the log path over so truncating stdout does not lose the
        # pointer to the full output; getattr keeps this safe if it is absent.
        log_path=getattr(result, "log_path", None),
        success=result.success,
        error=result.error,
    )
|
|
536
2387
|
|