mcp-stata 1.6.2__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcp-stata might be problematic. Click here for more details.

mcp_stata/stata_client.py CHANGED
@@ -6,13 +6,15 @@ import re
6
6
  import subprocess
7
7
  import sys
8
8
  import threading
9
+ from importlib.metadata import PackageNotFoundError, version
9
10
  import tempfile
10
11
  import time
11
12
  from contextlib import contextmanager
12
13
  from io import StringIO
13
- from typing import Any, Awaitable, Callable, Dict, List, Optional
14
+ from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
14
15
 
15
16
  import anyio
17
+ from anyio import get_cancelled_exc_class
16
18
 
17
19
  from .discovery import find_stata_path
18
20
  from .models import (
@@ -32,6 +34,74 @@ from .graph_detector import StreamingGraphCache
32
34
  logger = logging.getLogger("mcp_stata")
33
35
 
34
36
 
37
+ # ============================================================================
38
+ # MODULE-LEVEL DISCOVERY CACHE
39
+ # ============================================================================
40
+ # Caches the outcome of Stata discovery (the found path/edition, or a FileNotFoundError/PermissionError) so it is not repeated for every client in this process
41
+ _discovery_lock = threading.Lock()
42
+ _discovery_result: Optional[Tuple[str, str]] = None # (path, edition)
43
+ _discovery_attempted = False
44
+ _discovery_error: Optional[Exception] = None
45
+
46
+
47
def _discovery_failure(error: Exception) -> RuntimeError:
    """Build the caller-facing RuntimeError for a failed Stata discovery.

    Used for both the first failure and every replay of the cached failure so
    that the message does not depend on which call observed the problem.
    """
    if isinstance(error, PermissionError):
        return RuntimeError(
            f"Stata binary is not executable: {error}. "
            "Point STATA_PATH directly to the Stata binary (e.g., .../Contents/MacOS/stata-mp)."
        )
    return RuntimeError(f"Stata binary not found: {error}")


def _get_discovered_stata() -> Tuple[str, str]:
    """
    Get the discovered Stata path and edition, reusing the cached outcome.

    The first call runs ``find_stata_path()`` while holding ``_discovery_lock``.
    A successful result is cached in ``_discovery_result``; a
    ``FileNotFoundError`` or ``PermissionError`` is cached in
    ``_discovery_error`` and replayed on later calls. Any other exception
    propagates unchanged and is not cached.

    Returns:
        Tuple of (stata_executable_path, edition)

    Raises:
        RuntimeError: If Stata discovery fails (now or on a cached attempt)
    """
    global _discovery_result, _discovery_attempted, _discovery_error

    with _discovery_lock:
        # Fast path: a previous call already found Stata.
        if _discovery_result is not None:
            return _discovery_result

        # A previous call failed: replay the same message it produced,
        # including the "not executable" hint for permission problems.
        if _discovery_attempted and _discovery_error is not None:
            raise _discovery_failure(_discovery_error) from _discovery_error

        _discovery_attempted = True

        try:
            # Log environment state once at first discovery
            env_path = os.getenv("STATA_PATH")
            if env_path:
                logger.info("STATA_PATH env provided (raw): %s", env_path)
            else:
                logger.info("STATA_PATH env not set; attempting auto-discovery")

            try:
                pkg_version = version("mcp-stata")
            except PackageNotFoundError:
                pkg_version = "unknown"
            logger.info("mcp-stata version: %s", pkg_version)

            stata_exec_path, edition = find_stata_path()

            _discovery_result = (stata_exec_path, edition)
            logger.info("Discovery found Stata at: %s (%s)", stata_exec_path, edition)
            return _discovery_result

        except (FileNotFoundError, PermissionError) as e:
            # Remember the original exception so later calls can replay it.
            _discovery_error = e
            raise _discovery_failure(e) from e
103
+
104
+
35
105
  class StataClient:
36
106
  _initialized = False
37
107
  _exec_lock: threading.Lock
@@ -100,6 +170,62 @@ class StataClient:
100
170
  logger.error(f"Failed to notify about graph cache: {e}")
101
171
 
102
172
  return graph_cache_callback
173
def _request_break_in(self) -> None:
    """
    Best-effort attempt to interrupt the Stata command that is currently running.

    Looks up a ``breakIn`` (or ``break_in``) callable on the ``sfi`` module and
    invokes it. Every failure is logged and swallowed, so a cancellation
    request can never raise out of this method.
    """
    try:
        import sfi  # type: ignore[import-not-found]

        interrupt = None
        for attr in ("breakIn", "break_in"):
            interrupt = getattr(sfi, attr, None)
            if interrupt:
                break

        if not callable(interrupt):  # pragma: no cover - environment without Stata runtime
            logger.debug("sfi.breakIn not available; cannot interrupt Stata")
            return

        try:
            interrupt()
            logger.info("Sent breakIn() to Stata for cancellation")
        except Exception as err:  # pragma: no cover - best-effort
            logger.warning(f"Failed to send breakIn() to Stata: {err}")
    except Exception as err:  # pragma: no cover - import failure or other
        logger.debug(f"Unable to import sfi for cancellation: {err}")
194
+
195
+ async def _wait_for_stata_stop(self, timeout: float = 2.0) -> bool:
196
+ """
197
+ After requesting a break, poll the Stata interface so it can surface BreakError
198
+ and return control. This is best-effort and time-bounded.
199
+ """
200
+ deadline = time.monotonic() + timeout
201
+ try:
202
+ import sfi # type: ignore[import-not-found]
203
+
204
+ toolkit = getattr(sfi, "SFIToolkit", None)
205
+ poll = getattr(toolkit, "pollnow", None) or getattr(toolkit, "pollstd", None)
206
+ BreakError = getattr(sfi, "BreakError", None)
207
+ except Exception: # pragma: no cover
208
+ return False
209
+
210
+ if not callable(poll):
211
+ return False
212
+
213
+ last_exc: Optional[Exception] = None
214
+ while time.monotonic() < deadline:
215
+ try:
216
+ poll()
217
+ except Exception as e: # pragma: no cover - depends on Stata runtime
218
+ last_exc = e
219
+ if BreakError is not None and isinstance(e, BreakError):
220
+ logger.info("Stata BreakError detected; cancellation acknowledged by Stata")
221
+ return True
222
+ # If Stata already stopped, break on any other exception.
223
+ break
224
+ await anyio.sleep(0.05)
225
+
226
+ if last_exc:
227
+ logger.debug(f"Cancellation poll exited with {last_exc}")
228
+ return False
103
229
 
104
230
  @contextmanager
105
231
  def _temp_cwd(self, cwd: Optional[str]):
@@ -114,24 +240,15 @@ class StataClient:
114
240
  os.chdir(prev)
115
241
 
116
242
  def init(self):
117
- """Initializes usage of pystata."""
243
+ """Initializes usage of pystata using cached discovery results."""
118
244
  if self._initialized:
119
245
  return
120
246
 
121
247
  try:
122
248
  import stata_setup
123
-
124
- try:
125
- stata_exec_path, edition = find_stata_path()
126
- except FileNotFoundError as e:
127
- raise RuntimeError(f"Stata binary not found: {e}") from e
128
- except PermissionError as e:
129
- raise RuntimeError(
130
- f"Stata binary is not executable: {e}. "
131
- "Point STATA_PATH directly to the Stata binary (e.g., .../Contents/MacOS/stata-mp)."
132
- ) from e
133
249
 
134
- logger.info(f"Discovery found Stata at: {stata_exec_path} ({edition})")
250
+ # Get discovered Stata path (cached from first call)
251
+ stata_exec_path, edition = _get_discovered_stata()
135
252
 
136
253
  candidates = []
137
254
 
@@ -171,6 +288,7 @@ class StataClient:
171
288
  try:
172
289
  stata_setup.config(path, edition)
173
290
  success = True
291
+ logger.debug("stata_setup.config succeeded with path: %s", path)
174
292
  break
175
293
  except Exception:
176
294
  continue
@@ -187,14 +305,6 @@ class StataClient:
187
305
  from pystata import stata # type: ignore[import-not-found]
188
306
  self.stata = stata
189
307
  self._initialized = True
190
-
191
- # Ensure a clean graph state for a fresh client. PyStata's backend is
192
- # effectively global, so graph memory can otherwise leak across tests
193
- # and separate StataClient instances.
194
- try:
195
- self.stata.run("capture graph drop _all", quietly=True)
196
- except Exception:
197
- pass
198
308
 
199
309
  # Initialize list_graphs TTL cache
200
310
  self._list_graphs_cache = None
@@ -205,11 +315,14 @@ class StataClient:
205
315
  # internal Stata graph names.
206
316
  self._graph_name_aliases: Dict[str, str] = {}
207
317
  self._graph_name_reverse: Dict[str, str] = {}
318
+
319
+ logger.info("StataClient initialized successfully with %s (%s)", stata_exec_path, edition)
208
320
 
209
- except ImportError:
210
- # Fallback for when stata_setup isn't in PYTHONPATH yet?
211
- # Usually users must have it installed. We rely on discovery logic.
212
- raise RuntimeError("Could not import `stata_setup`. Ensure pystata is installed.")
321
+ except ImportError as e:
322
+ raise RuntimeError(
323
+ f"Failed to import stata_setup or pystata: {e}. "
324
+ "Ensure they are installed (pip install pystata stata-setup)."
325
+ ) from e
213
326
 
214
327
  def _make_valid_stata_name(self, name: str) -> str:
215
328
  """Create a valid Stata name (<=32 chars, [A-Za-z_][A-Za-z0-9_]*)."""
@@ -295,6 +408,73 @@ class StataClient:
295
408
  return None
296
409
  return None
297
410
 
411
+ def _read_log_tail(self, path: str, max_chars: int) -> str:
412
+ try:
413
+ with open(path, "rb") as f:
414
+ f.seek(0, os.SEEK_END)
415
+ size = f.tell()
416
+ if size <= 0:
417
+ return ""
418
+ read_size = min(size, max_chars)
419
+ f.seek(-read_size, os.SEEK_END)
420
+ data = f.read(read_size)
421
+ return data.decode("utf-8", errors="replace")
422
+ except Exception:
423
+ return ""
424
+
425
+ def _select_stata_error_message(self, text: str, fallback: str) -> str:
426
+ if not text:
427
+ return fallback
428
+ ignore_patterns = (
429
+ r"^r\(\d+\);?$",
430
+ r"^end of do-file$",
431
+ r"^execution terminated$",
432
+ r"^[-=*]{3,}.*$",
433
+ )
434
+ rc_pattern = r"^r\(\d+\);?$"
435
+ error_patterns = (
436
+ r"\btype mismatch\b",
437
+ r"\bnot found\b",
438
+ r"\bnot allowed\b",
439
+ r"\bno observations\b",
440
+ r"\bconformability error\b",
441
+ r"\binvalid\b",
442
+ r"\bsyntax error\b",
443
+ r"\berror\b",
444
+ )
445
+ lines = text.splitlines()
446
+ for raw in reversed(lines):
447
+ line = raw.strip()
448
+ if not line:
449
+ continue
450
+ if any(re.search(pat, line, re.IGNORECASE) for pat in error_patterns):
451
+ return line
452
+ for i in range(len(lines) - 1, -1, -1):
453
+ line = lines[i].strip()
454
+ if not line:
455
+ continue
456
+ if re.match(rc_pattern, line, re.IGNORECASE):
457
+ for j in range(i - 1, -1, -1):
458
+ prev_line = lines[j].strip()
459
+ if not prev_line:
460
+ continue
461
+ if prev_line.startswith((".", ">", "-", "=")):
462
+ continue
463
+ if any(re.match(pat, prev_line, re.IGNORECASE) for pat in ignore_patterns):
464
+ continue
465
+ return prev_line
466
+ return line
467
+ for raw in reversed(lines):
468
+ line = raw.strip()
469
+ if not line:
470
+ continue
471
+ if line.startswith((".", ">", "-", "=")):
472
+ continue
473
+ if any(re.match(pat, line, re.IGNORECASE) for pat in ignore_patterns):
474
+ continue
475
+ return line
476
+ return fallback
477
+
298
478
  def _smcl_to_text(self, smcl: str) -> str:
299
479
  """Convert simple SMCL markup into plain text for LLM-friendly help."""
300
480
  # First, keep inline directive content if present (e.g., {bf:word} -> word)
@@ -320,7 +500,10 @@ class StataClient:
320
500
  rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
321
501
  line_no = self._parse_line_from_text(combined) if combined else None
322
502
  snippet = combined[-800:] if combined else None
323
- message = (stderr or (str(exc) if exc else "") or stdout or "Stata error").strip()
503
+ fallback = (stderr or (str(exc) if exc else "") or stdout or "Stata error").strip()
504
+ if fallback == "Stata error" and rc_final is not None:
505
+ fallback = f"Stata error r({rc_final})"
506
+ message = self._select_stata_error_message(combined, fallback)
324
507
  return ErrorEnvelope(
325
508
  message=message,
326
509
  rc=rc_final,
@@ -527,7 +710,7 @@ class StataClient:
527
710
  buffering=1,
528
711
  )
529
712
  log_path = log_file.name
530
- tail = TailBuffer(max_chars=8000)
713
+ tail = TailBuffer(max_chars=200000 if trace else 20000)
531
714
  tee = FileTeeIO(log_file, tail)
532
715
 
533
716
  # Inform the MCP client immediately where to read/tail the output.
@@ -538,33 +721,42 @@ class StataClient:
538
721
  def _run_blocking() -> None:
539
722
  nonlocal rc, exc
540
723
  with self._exec_lock:
541
- with self._temp_cwd(cwd):
542
- with self._redirect_io_streaming(tee, tee):
543
- try:
544
- if trace:
545
- self.stata.run("set trace on")
546
- ret = self.stata.run(code, echo=echo)
547
- # Some PyStata builds return output as a string rather than printing.
548
- if isinstance(ret, str) and ret:
549
- try:
550
- tee.write(ret)
551
- except Exception:
552
- pass
553
- except Exception as e:
554
- exc = e
555
- finally:
556
- rc = self._read_return_code()
557
- if trace:
558
- try:
559
- self.stata.run("set trace off")
560
- except Exception:
561
- pass
724
+ self._is_executing = True
725
+ try:
726
+ with self._temp_cwd(cwd):
727
+ with self._redirect_io_streaming(tee, tee):
728
+ try:
729
+ if trace:
730
+ self.stata.run("set trace on")
731
+ ret = self.stata.run(code, echo=echo)
732
+ # Some PyStata builds return output as a string rather than printing.
733
+ if isinstance(ret, str) and ret:
734
+ try:
735
+ tee.write(ret)
736
+ except Exception:
737
+ pass
738
+ except Exception as e:
739
+ exc = e
740
+ finally:
741
+ rc = self._read_return_code()
742
+ if trace:
743
+ try:
744
+ self.stata.run("set trace off")
745
+ except Exception:
746
+ pass
747
+ finally:
748
+ self._is_executing = False
562
749
 
563
750
  try:
564
751
  if notify_progress is not None:
565
752
  await notify_progress(0, None, "Running Stata command")
566
753
 
567
- await anyio.to_thread.run_sync(_run_blocking)
754
+ await anyio.to_thread.run_sync(_run_blocking, abandon_on_cancel=True)
755
+ except get_cancelled_exc_class():
756
+ # Best-effort cancellation: signal Stata to break, wait briefly, then propagate.
757
+ self._request_break_in()
758
+ await self._wait_for_stata_stop()
759
+ raise
568
760
  finally:
569
761
  tee.close()
570
762
 
@@ -583,6 +775,9 @@ class StataClient:
583
775
  logger.warning(f"Failed to cache detected graphs: {e}")
584
776
 
585
777
  tail_text = tail.get_value()
778
+ log_tail = self._read_log_tail(log_path, 200000 if trace else 20000)
779
+ if log_tail and len(log_tail) > len(tail_text):
780
+ tail_text = log_tail
586
781
  combined = (tail_text or "") + (f"\n{exc}" if exc else "")
587
782
  rc_hint = self._parse_rc_from_text(combined) if combined else None
588
783
  if exc is None and rc_hint is not None and rc_hint != 0:
@@ -596,14 +791,10 @@ class StataClient:
596
791
  rc_hint = self._parse_rc_from_text(combined) if combined else None
597
792
  rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
598
793
  line_no = self._parse_line_from_text(combined) if combined else None
599
- message = "Stata error"
600
- if tail_text and tail_text.strip():
601
- for line in reversed(tail_text.splitlines()):
602
- if line.strip():
603
- message = line.strip()
604
- break
605
- elif exc is not None:
606
- message = str(exc).strip() or message
794
+ fallback = (str(exc).strip() if exc is not None else "") or "Stata error"
795
+ if fallback == "Stata error" and rc_final is not None:
796
+ fallback = f"Stata error r({rc_final})"
797
+ message = self._select_stata_error_message(combined, fallback)
607
798
 
608
799
  error = ErrorEnvelope(
609
800
  message=message,
@@ -754,7 +945,7 @@ class StataClient:
754
945
  buffering=1,
755
946
  )
756
947
  log_path = log_file.name
757
- tail = TailBuffer(max_chars=8000)
948
+ tail = TailBuffer(max_chars=200000 if trace else 20000)
758
949
  tee = FileTeeIO(log_file, tail)
759
950
 
760
951
  # Inform the MCP client immediately where to read/tail the output.
@@ -838,7 +1029,11 @@ class StataClient:
838
1029
  await notify_progress(0, None, "Running do-file")
839
1030
 
840
1031
  try:
841
- await anyio.to_thread.run_sync(_run_blocking)
1032
+ await anyio.to_thread.run_sync(_run_blocking, abandon_on_cancel=True)
1033
+ except get_cancelled_exc_class():
1034
+ self._request_break_in()
1035
+ await self._wait_for_stata_stop()
1036
+ raise
842
1037
  finally:
843
1038
  done.set()
844
1039
  tee.close()
@@ -916,6 +1111,9 @@ class StataClient:
916
1111
  logger.error(f"Post-execution graph detection failed: {e}")
917
1112
 
918
1113
  tail_text = tail.get_value()
1114
+ log_tail = self._read_log_tail(log_path, 200000 if trace else 20000)
1115
+ if log_tail and len(log_tail) > len(tail_text):
1116
+ tail_text = log_tail
919
1117
  combined = (tail_text or "") + (f"\n{exc}" if exc else "")
920
1118
  rc_hint = self._parse_rc_from_text(combined) if combined else None
921
1119
  if exc is None and rc_hint is not None and rc_hint != 0:
@@ -929,14 +1127,10 @@ class StataClient:
929
1127
  rc_hint = self._parse_rc_from_text(combined) if combined else None
930
1128
  rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
931
1129
  line_no = self._parse_line_from_text(combined) if combined else None
932
- message = "Stata error"
933
- if tail_text and tail_text.strip():
934
- for line in reversed(tail_text.splitlines()):
935
- if line.strip():
936
- message = line.strip()
937
- break
938
- elif exc is not None:
939
- message = str(exc).strip() or message
1130
+ fallback = (str(exc).strip() if exc is not None else "") or "Stata error"
1131
+ if fallback == "Stata error" and rc_final is not None:
1132
+ fallback = f"Stata error r({rc_final})"
1133
+ message = self._select_stata_error_message(combined, fallback)
940
1134
 
941
1135
  error = ErrorEnvelope(
942
1136
  message=message,
@@ -1299,6 +1493,65 @@ class StataClient:
1299
1493
 
1300
1494
  return indices
1301
1495
 
1496
def apply_sort(self, sort_spec: List[str]) -> None:
    """
    Apply sorting to the dataset using gsort.

    Args:
        sort_spec: List of variables to sort by, each with at most one
            optional +/- prefix.
            e.g., ["-price", "+mpg"] sorts by price descending, then mpg ascending.
            No prefix is treated as ascending (+).

    Raises:
        ValueError: If sort_spec is invalid or contains invalid variables
        RuntimeError: If no data in memory or sort command fails
    """
    if not self._initialized:
        self.init()

    state = self.get_dataset_state()
    if int(state.get("k", 0) or 0) == 0 and int(state.get("n", 0) or 0) == 0:
        raise RuntimeError("No data in memory")

    if not sort_spec or not isinstance(sort_spec, list):
        raise ValueError("sort_spec must be a non-empty list")

    # Validate each entry and build the gsort argument from the parsed parts,
    # so the command can only contain a single direction prefix followed by a
    # variable name that exists in the dataset.
    var_map = self._get_var_index_map()
    gsort_args = []
    for spec in sort_spec:
        if not isinstance(spec, str) or not spec:
            raise ValueError(f"Invalid sort specification: {spec!r}")

        if spec[0] in "+-":
            direction, varname = spec[0], spec[1:]
        else:
            # No prefix means ascending; make it explicit in the command.
            direction, varname = "+", spec

        # Reject a bare prefix ("-") and stacked prefixes ("--price").
        if not varname or varname[0] in "+-":
            raise ValueError(f"Invalid sort specification: {spec!r}")

        if varname not in var_map:
            raise ValueError(f"Variable not found: {varname}")

        gsort_args.append(f"{direction}{varname}")

    cmd = f"gsort {' '.join(gsort_args)}"

    try:
        result = self.run_command_structured(cmd, echo=False)
    except RuntimeError:
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to sort dataset: {e}") from e

    if not result.success:
        error_msg = result.error.message if result.error else "Sort failed"
        raise RuntimeError(f"Failed to sort dataset: {error_msg}")
1554
+
1302
1555
  def get_variable_details(self, varname: str) -> str:
1303
1556
  """Returns codebook/summary for a specific variable."""
1304
1557
  resp = self.run_command_structured(f"codebook {varname}", echo=True)
@@ -2121,7 +2374,7 @@ class StataClient:
2121
2374
  buffering=1,
2122
2375
  )
2123
2376
  log_path = log_file.name
2124
- tail = TailBuffer(max_chars=8000)
2377
+ tail = TailBuffer(max_chars=200000 if trace else 20000)
2125
2378
  tee = FileTeeIO(log_file, tail)
2126
2379
 
2127
2380
  rc = -1
@@ -2152,6 +2405,9 @@ class StataClient:
2152
2405
  tee.close()
2153
2406
 
2154
2407
  tail_text = tail.get_value()
2408
+ log_tail = self._read_log_tail(log_path, 200000 if trace else 20000)
2409
+ if log_tail and len(log_tail) > len(tail_text):
2410
+ tail_text = log_tail
2155
2411
  combined = (tail_text or "") + (f"\n{exc}" if exc else "")
2156
2412
  rc_hint = self._parse_rc_from_text(combined) if combined else None
2157
2413
  if exc is None and rc_hint is not None and rc_hint != 0:
@@ -2166,14 +2422,10 @@ class StataClient:
2166
2422
  rc_hint = self._parse_rc_from_text(combined) if combined else None
2167
2423
  rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
2168
2424
  line_no = self._parse_line_from_text(combined) if combined else None
2169
- message = "Stata error"
2170
- if tail_text and tail_text.strip():
2171
- for line in reversed(tail_text.splitlines()):
2172
- if line.strip():
2173
- message = line.strip()
2174
- break
2175
- elif exc is not None:
2176
- message = str(exc).strip() or message
2425
+ fallback = (str(exc).strip() if exc is not None else "") or "Stata error"
2426
+ if fallback == "Stata error" and rc_final is not None:
2427
+ fallback = f"Stata error r({rc_final})"
2428
+ message = self._select_stata_error_message(combined, fallback)
2177
2429
 
2178
2430
  error = ErrorEnvelope(
2179
2431
  message=message,
@@ -2258,4 +2510,3 @@ class StataClient:
2258
2510
  )
2259
2511
 
2260
2512
  return result
2261
-