mcp-stata 1.6.8__py3-none-any.whl → 1.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcp-stata might be problematic.
- mcp_stata/discovery.py +151 -31
- mcp_stata/models.py +1 -0
- mcp_stata/stata_client.py +323 -264
- mcp_stata/ui_http.py +37 -1
- {mcp_stata-1.6.8.dist-info → mcp_stata-1.7.6.dist-info}/METADATA +60 -2
- mcp_stata-1.7.6.dist-info/RECORD +14 -0
- mcp_stata-1.6.8.dist-info/RECORD +0 -14
- {mcp_stata-1.6.8.dist-info → mcp_stata-1.7.6.dist-info}/WHEEL +0 -0
- {mcp_stata-1.6.8.dist-info → mcp_stata-1.7.6.dist-info}/entry_points.txt +0 -0
- {mcp_stata-1.6.8.dist-info → mcp_stata-1.7.6.dist-info}/licenses/LICENSE +0 -0
mcp_stata/stata_client.py
CHANGED
@@ -120,16 +120,48 @@ class StataClient:
         return inst

     @contextmanager
-    def _redirect_io(self):
+    def _redirect_io(self, out_buf, err_buf):
         """Safely redirect stdout/stderr for the duration of a Stata call."""
-        out_buf, err_buf = StringIO(), StringIO()
         backup_stdout, backup_stderr = sys.stdout, sys.stderr
         sys.stdout, sys.stderr = out_buf, err_buf
         try:
-            yield
+            yield
         finally:
             sys.stdout, sys.stderr = backup_stdout, backup_stderr

+    def _select_stata_error_message(self, text: str, fallback: str) -> str:
+        """
+        Helper for tests and legacy callers to extract the clean error message.
+        """
+        if not text:
+            return fallback
+
+        lines = text.splitlines()
+        trace_pattern = re.compile(r'^\s*[-=.]')
+        noise_pattern = re.compile(r'^(?:\}|\{txt\}|\{com\}|end of do-file)')
+
+        for line in reversed(lines):
+            stripped = line.strip()
+            if not stripped:
+                continue
+            if trace_pattern.match(line):
+                continue
+            if noise_pattern.match(stripped):
+                continue
+            if stripped.startswith("r(") and stripped.endswith(");"):
+                # If we hit r(123); we might want the line ABOVE it if it's not noise
+                continue
+
+            # Preserve SMCL tags
+            return stripped
+
+        # If we couldn't find a better message, try to find r(N);
+        match = re.search(r"r\(\d+\);", text)
+        if match:
+            return match.group(0)
+
+        return fallback
+
     @staticmethod
     def _stata_quote(value: str) -> str:
         """Return a Stata double-quoted string literal for value."""
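The signature change above moves buffer ownership to the caller: the execution paths now construct the StringIO objects themselves and hand them to `_redirect_io`, so captured output stays readable even when the redirected call raises. A minimal standalone sketch of the new calling pattern (plain Python, with `print` standing in for PyStata writing to stdout):

    from contextlib import contextmanager
    from io import StringIO
    import sys

    @contextmanager
    def redirect_io(out_buf, err_buf):
        # Same shape as the new helper: the caller owns both buffers.
        backup_stdout, backup_stderr = sys.stdout, sys.stderr
        sys.stdout, sys.stderr = out_buf, err_buf
        try:
            yield
        finally:
            sys.stdout, sys.stderr = backup_stdout, backup_stderr

    out, err = StringIO(), StringIO()
    with redirect_io(out, err):
        print("captured line")  # stand-in for Stata output
    assert out.getvalue() == "captured line\n"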
@@ -170,6 +202,7 @@ class StataClient:
             logger.error(f"Failed to notify about graph cache: {e}")

         return graph_cache_callback
+
     def _request_break_in(self) -> None:
         """
         Attempt to interrupt a running Stata command when cancellation is requested.
@@ -380,15 +413,32 @@ class StataClient:
         try:
             from sfi import Macro  # type: ignore[import-not-found]
             rc_val = Macro.getCValue("rc")  # type: ignore[attr-defined]
+            if rc_val is not None:
+                return int(float(rc_val))
+            # If getCValue returns None, fall through to the alternative approach
+        except Exception:
+            pass
+
+        # Alternative approach: use a global macro
+        # CRITICAL: This must be done carefully to avoid mutating c(rc)
+        try:
+            self.stata.run("global MCP_RC = c(rc)")
+            from sfi import Macro as Macro2  # type: ignore[import-not-found]
+            rc_val = Macro2.getGlobal("MCP_RC")
             return int(float(rc_val))
         except Exception:
-            … (6 lines not shown)
+            return -1
+
+    def _get_rc_from_scalar(self, Scalar) -> int:
+        """Safely get return code, handling None values."""
+        try:
+            from sfi import Macro
+            rc_val = Macro.getCValue("rc")
+            if rc_val is None:
                 return -1
+            return int(float(rc_val))
+        except Exception:
+            return -1

     def _parse_rc_from_text(self, text: str) -> Optional[int]:
         match = re.search(r"r\((\d+)\)", text)
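When `sfi` is unavailable or returns nothing, return-code detection falls back to text parsing: `_parse_rc_from_text` pulls the number out of Stata's `r(N)` marker. A small self-contained check of that regex (the sample log lines are illustrative):

    import re
    from typing import Optional

    def parse_rc_from_text(text: str) -> Optional[int]:
        # Same pattern as in the diff: capture N from an "r(N)" marker.
        match = re.search(r"r\((\d+)\)", text)
        return int(match.group(1)) if match else None

    print(parse_rc_from_text("variable foo not found\nr(111);"))  # 111
    print(parse_rc_from_text("end of do-file"))                   # None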
@@ -408,6 +458,20 @@ class StataClient:
                 return None
         return None

+    def _read_log_tail(self, path: str, max_chars: int) -> str:
+        try:
+            with open(path, "rb") as f:
+                f.seek(0, os.SEEK_END)
+                size = f.tell()
+                if size <= 0:
+                    return ""
+                read_size = min(size, max_chars)
+                f.seek(-read_size, os.SEEK_END)
+                data = f.read(read_size)
+                return data.decode("utf-8", errors="replace")
+        except Exception:
+            return ""
+
     def _smcl_to_text(self, smcl: str) -> str:
         """Convert simple SMCL markup into plain text for LLM-friendly help."""
         # First, keep inline directive content if present (e.g., {bf:word} -> word)
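`_read_log_tail` seeks from the end of the file so only the last `max_chars` bytes of a potentially huge log are ever read. A self-contained version of the same pattern, runnable against any text file (the filename is illustrative):

    import os

    def read_log_tail(path: str, max_chars: int) -> str:
        """Return up to the last max_chars bytes of a file, decoded leniently."""
        try:
            with open(path, "rb") as f:
                f.seek(0, os.SEEK_END)
                size = f.tell()
                if size <= 0:
                    return ""
                read_size = min(size, max_chars)
                f.seek(-read_size, os.SEEK_END)
                return f.read(read_size).decode("utf-8", errors="replace")
        except Exception:
            return ""

    # Grab roughly the last 20 KB of a Stata session log.
    tail = read_log_tail("session.log", 20_000)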
@@ -419,150 +483,126 @@ class StataClient:
         lines = [line.rstrip() for line in cleaned.splitlines()]
         return "\n".join(lines).strip()

-    def …
-    … (24 lines not shown)
+    def _extract_error_and_context(self, log_content: str, rc: int) -> Tuple[str, str]:
+        """
+        Extracts the error message and trace context using {err} SMCL tags.
+        """
+        if not log_content:
+            return f"Stata error r({rc})", ""
+
+        lines = log_content.splitlines()
+
+        # Search backwards for the {err} tag
+        for i in range(len(lines) - 1, -1, -1):
+            line = lines[i]
+            if '{err}' in line:
+                # Found the (last) error line.
+                # Walk backwards to find the start of the error block (consecutive {err} lines)
+                start_idx = i
+                while start_idx > 0 and '{err}' in lines[start_idx-1]:
+                    start_idx -= 1
+
+                # The full error message is the concatenation of all {err} lines in this block
+                error_lines = []
+                for j in range(start_idx, i + 1):
+                    error_lines.append(lines[j].strip())
+
+                clean_msg = " ".join(filter(None, error_lines)) or f"Stata error r({rc})"
+
+                # Capture everything from the start of the error block to the end
+                context_str = "\n".join(lines[start_idx:])
+                return clean_msg, context_str
+
+        # Fallback: grab the last 30 lines
+        context_start = max(0, len(lines) - 30)
+        context_str = "\n".join(lines[context_start:])
+
+        return f"Stata error r({rc})", context_str

     def _exec_with_capture(self, code: str, echo: bool = True, trace: bool = False, cwd: Optional[str] = None) -> CommandResponse:
-        """Execute Stata code with stdout/stderr capture and rc detection."""
         if not self._initialized:
             self.init()

+        # Rewrite graph names with special characters to internal aliases
         code = self._maybe_rewrite_graph_name_in_command(code)

-        … (4 lines not shown)
-                stdout="",
-                stderr=None,
-                success=False,
-                error=ErrorEnvelope(
-                    message=f"cwd not found: {cwd}",
-                    rc=601,
-                    command=code,
-                ),
-            )
+        output_buffer = StringIO()
+        error_buffer = StringIO()
+        rc = 0
+        sys_error = None

-        start_time = time.time()
-        exc: Optional[Exception] = None
-        ret_text: Optional[str] = None
         with self._exec_lock:
-            # Set execution flag to prevent recursive Stata calls
-            self._is_executing = True
             try:
+                from sfi import Scalar, SFIToolkit  # Import SFI tools inside execution block
                 with self._temp_cwd(cwd):
-                    with self._redirect_io(
-                    … (33 lines not shown)
-                success = rc == 0 and exc is None
-                error = None
-                if not success:
-                    error = self._build_error_envelope(code, rc, stdout, stderr, exc, trace)
-                duration = time.time() - start_time
-                code_preview = code.replace("\n", "\\n")
-                logger.info(
-                    "stata.run rc=%s success=%s trace=%s duration_ms=%.2f code_preview=%s",
-                    rc,
-                    success,
-                    trace,
-                    duration * 1000,
-                    code_preview[:120],
-                )
-                # Mutually exclusive - when error, output is in ErrorEnvelope only
+                    with self._redirect_io(output_buffer, error_buffer):
+                        if trace:
+                            self.stata.run("set trace on")
+
+                        # 1. Run the user code
+                        self.stata.run(code, echo=echo)
+
+            except Exception as e:
+                sys_error = str(e)
+                # Try to parse RC from exception message
+                parsed_rc = self._parse_rc_from_text(sys_error)
+                rc = parsed_rc if parsed_rc is not None else 1
+
+        stdout_content = output_buffer.getvalue()
+        stderr_content = error_buffer.getvalue()
+        full_log = stdout_content + "\n" + stderr_content
+
+        # 2. Extract RC from log tail (primary error detection method)
+        if rc == 1 and not sys_error:  # No exception but might have error in log
+            parsed_rc = self._parse_rc_from_text(full_log)
+            if parsed_rc is not None:
+                rc = parsed_rc
+
+        error_envelope = None
+        if rc != 0:
+            if sys_error:
+                msg = sys_error
+                snippet = sys_error  # Include the exception message as snippet
+            else:
+                # Extract error message from log tail
+                msg, context = self._extract_error_and_context(full_log, rc)
+
+            error_envelope = ErrorEnvelope(message=msg, rc=rc, context=context, snippet=full_log[-800:])
+
         return CommandResponse(
             command=code,
             rc=rc,
-            stdout=…
-            stderr=…
-            success=…
-            error=…
+            stdout=stdout_content,
+            stderr=stderr_content,
+            success=(rc == 0),
+            error=error_envelope,
         )

     def _exec_no_capture(self, code: str, echo: bool = False, trace: bool = False) -> CommandResponse:
-        """Execute Stata code while leaving stdout/stderr alone.
-
-        PyStata's output bridge uses its own thread and can misbehave on Windows
-        when we redirect stdio (e.g., graph export). This path keeps the normal
-        handlers and just reads rc afterward.
-        """
+        """Execute Stata code while leaving stdout/stderr alone."""
         if not self._initialized:
             self.init()

         exc: Optional[Exception] = None
         ret_text: Optional[str] = None
+        rc = 0
+
         with self._exec_lock:
             try:
+                from sfi import Scalar  # Import SFI tools
                 if trace:
                     self.stata.run("set trace on")
                 ret = self.stata.run(code, echo=echo)
                 if isinstance(ret, str) and ret:
                     ret_text = ret
+
+                # Robust RC check even for no-capture
+                rc = self._read_return_code()
+
             except Exception as e:
                 exc = e
+                rc = 1
             finally:
-                rc = self._read_return_code()
-                # If Stata returned an r(#) in text, prefer it.
-                combined = "\n".join(filter(None, [ret_text or "", str(exc) if exc else ""])).strip()
-                rc_hint = self._parse_rc_from_text(combined) if combined else None
-                if exc is None and rc_hint is not None and rc_hint != 0:
-                    rc = rc_hint
-                if exc is None and (rc is None or rc == -1) and rc_hint is None:
-                    # Normalize spurious rc reads only when missing/invalid
-                    rc = 0
                 if trace:
                     try:
                         self.stata.run("set trace off")
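The new `_extract_error_and_context` keys off SMCL's `{err}` tag instead of guessing from the last non-blank line, which is what the removed heuristic did. A standalone rendition of the same logic, exercised on a made-up log fragment (the sample text is illustrative, not real Stata output):

    from typing import Tuple

    def extract_error_and_context(log_content: str, rc: int) -> Tuple[str, str]:
        if not log_content:
            return f"Stata error r({rc})", ""
        lines = log_content.splitlines()
        # Walk backwards to the last {err} line, then widen to the whole {err} block.
        for i in range(len(lines) - 1, -1, -1):
            if "{err}" in lines[i]:
                start = i
                while start > 0 and "{err}" in lines[start - 1]:
                    start -= 1
                message = " ".join(filter(None, (l.strip() for l in lines[start:i + 1])))
                return message or f"Stata error r({rc})", "\n".join(lines[start:])
        # Fallback: last 30 lines as context.
        return f"Stata error r({rc})", "\n".join(lines[-30:])

    sample = "{com}. regress y x\n{err}variable y not found\n{txt}r(111);"
    msg, context = extract_error_and_context(sample, 111)
    print(msg)      # {err}variable y not found
    print(context)  # {err}variable y not found / {txt}r(111);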
@@ -574,8 +614,13 @@ class StataClient:
         success = rc == 0 and exc is None
         error = None
         if not success:
-            … (1 line not shown)
-            error = …
+            msg = str(exc) if exc else f"Stata error r({rc})"
+            error = ErrorEnvelope(
+                message=msg,
+                rc=rc,
+                command=code,
+                stdout=ret_text,
+            )

         return CommandResponse(
             command=code,
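Across these paths the `ErrorEnvelope` is filled from different sources: `message`, `rc`, and `command` always, plus `stdout`, `context`, `snippet`, or `log_path` depending on the execution mode. A hedged sketch of a compatible model, assuming a Pydantic-style class with optional fields (the real definition lives in `mcp_stata/models.py`, which this diff lists as changed but does not show):

    from typing import Optional
    from pydantic import BaseModel

    class ErrorEnvelope(BaseModel):
        # Field set inferred from the constructor calls in this diff; optionality is assumed.
        message: str
        rc: Optional[int] = None
        command: Optional[str] = None
        stdout: Optional[str] = None
        context: Optional[str] = None
        snippet: Optional[str] = None
        log_path: Optional[str] = None
        trace: Optional[bool] = None

    env = ErrorEnvelope(message="variable y not found", rc=111, command="regress y x")
    print(env.model_dump(exclude_none=True))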
@@ -640,7 +685,7 @@ class StataClient:
                 buffering=1,
             )
             log_path = log_file.name
-            tail = TailBuffer(max_chars=…
+            tail = TailBuffer(max_chars=200000 if trace else 20000)
             tee = FileTeeIO(log_file, tail)

             # Inform the MCP client immediately where to read/tail the output.
@@ -653,6 +698,7 @@ class StataClient:
         with self._exec_lock:
             self._is_executing = True
             try:
+                from sfi import Scalar, SFIToolkit  # Import SFI tools
                 with self._temp_cwd(cwd):
                     with self._redirect_io_streaming(tee, tee):
                         try:
@@ -665,10 +711,14 @@ class StataClient:
                                 tee.write(ret)
                             except Exception:
                                 pass
+
+                            # ROBUST DETECTION & OUTPUT
+                            rc = self._read_return_code()
+
                         except Exception as e:
                             exc = e
+                            if rc == 0: rc = 1
                         finally:
-                            rc = self._read_return_code()
                             if trace:
                                 try:
                                     self.stata.run("set trace off")
@@ -705,36 +755,25 @@ class StataClient:
                 logger.warning(f"Failed to cache detected graphs: {e}")

         tail_text = tail.get_value()
+        log_tail = self._read_log_tail(log_path, 200000 if trace else 20000)
+        if log_tail and len(log_tail) > len(tail_text):
+            tail_text = log_tail
         combined = (tail_text or "") + (f"\n{exc}" if exc else "")
-        … (2 lines not shown)
-            rc = rc_hint
-        if exc is None and rc_hint is None:
-            rc = 0 if rc is None or rc != 0 else rc
-        success = rc == 0 and exc is None
+
+        success = (rc == 0 and exc is None)
         error = None
+
         if not success:
-            … (3 lines not shown)
-            line_no = self._parse_line_from_text(combined) if combined else None
-            message = "Stata error"
-            if tail_text and tail_text.strip():
-                for line in reversed(tail_text.splitlines()):
-                    if line.strip():
-                        message = line.strip()
-                        break
-            elif exc is not None:
-                message = str(exc).strip() or message
-
+            # Use robust extractor
+            msg, context = self._extract_error_and_context(combined, rc)
+
             error = ErrorEnvelope(
-                message=…
-                … (2 lines not shown)
+                message=msg,
+                context=context,
+                rc=rc,
                 command=code,
                 log_path=log_path,
-                snippet=snippet…
-                trace=trace or None,
+                snippet=combined[-800:]  # Keep snippet for backward compat
             )

         duration = time.time() - start_time
@@ -876,7 +915,7 @@ class StataClient:
                 buffering=1,
             )
             log_path = log_file.name
-            tail = TailBuffer(max_chars=…
+            tail = TailBuffer(max_chars=200000 if trace else 20000)
             tee = FileTeeIO(log_file, tail)

             # Inform the MCP client immediately where to read/tail the output.
@@ -887,7 +926,6 @@ class StataClient:
         command = f'do "{path_for_stata}"'

         # Capture initial graph state BEFORE execution starts
-        # This allows post-execution detection to identify new graphs
         if graph_cache:
             try:
                 graph_cache._initial_graphs = set(self.list_graphs())
@@ -902,6 +940,7 @@ class StataClient:
             # Set execution flag to prevent recursive Stata calls
             self._is_executing = True
             try:
+                from sfi import Scalar, SFIToolkit  # Import SFI tools
                 with self._temp_cwd(cwd):
                     with self._redirect_io_streaming(tee, tee):
                         try:
@@ -914,15 +953,17 @@ class StataClient:
                                 tee.write(ret)
                             except Exception:
                                 pass
+
+                            # ROBUST DETECTION & OUTPUT
+                            rc = self._read_return_code()
+
                         except Exception as e:
                             exc = e
+                            if rc == 0: rc = 1
                         finally:
-                            rc = self._read_return_code()
                             if trace:
-                                try:
-                                    … (1 line not shown)
-                                except Exception:
-                                    pass
+                                try: self.stata.run("set trace off")
+                                except: pass
             finally:
                 # Clear execution flag
                 self._is_executing = False
@@ -970,65 +1011,33 @@ class StataClient:
         tee.close()

         # Robust post-execution graph detection and caching
-        # This is the ONLY place where graphs are detected and cached
-        # Runs after execution completes, when it's safe to call list_graphs()
         if graph_cache and graph_cache.auto_cache:
-            cached_graphs = []
             try:
-                # …
+                # [Existing graph cache logic kept identical]
+                cached_graphs = []
                 initial_graphs = getattr(graph_cache, '_initial_graphs', set())
-
-                # Get current state (after execution)
-                logger.debug("Post-execution: Querying graph state via list_graphs()")
                 current_graphs = set(self.list_graphs())
-
-                # Detect new graphs (created during execution)
                 new_graphs = current_graphs - initial_graphs - graph_cache._cached_graphs

                 if new_graphs:
                     logger.info(f"Detected {len(new_graphs)} new graph(s): {sorted(new_graphs)}")

-                    # Cache each detected graph
                     for graph_name in new_graphs:
                         try:
-                            logger.debug(f"Caching graph: {graph_name}")
                             cache_result = await anyio.to_thread.run_sync(
                                 self.cache_graph_on_creation,
                                 graph_name
                             )
-
                             if cache_result:
                                 cached_graphs.append(graph_name)
                                 graph_cache._cached_graphs.add(graph_name)
-
-                            else:
-                                logger.warning(f"Failed to cache graph: {graph_name}")
-
-                            # Trigger callbacks
+
                             for callback in graph_cache._cache_callbacks:
                                 try:
                                     await anyio.to_thread.run_sync(callback, graph_name, cache_result)
-                                except Exception…
-                                    logger.debug(f"Callback failed for {graph_name}: {e}")
-
+                                except Exception: pass
                         except Exception as e:
                             logger.error(f"Error caching graph {graph_name}: {e}")
-                            # Trigger callbacks with failure
-                            for callback in graph_cache._cache_callbacks:
-                                try:
-                                    await anyio.to_thread.run_sync(callback, graph_name, False)
-                                except Exception:
-                                    pass
-
-                # Check for dropped graphs (for completeness)
-                dropped_graphs = initial_graphs - current_graphs
-                if dropped_graphs:
-                    logger.debug(f"Graphs dropped during execution: {sorted(dropped_graphs)}")
-                    for graph_name in dropped_graphs:
-                        try:
-                            self.invalidate_graph_cache(graph_name)
-                        except Exception:
-                            pass

                 # Notify progress if graphs were cached
                 if cached_graphs and notify_progress:
@@ -1037,41 +1046,29 @@ class StataClient:
                         float(total_lines) if total_lines > 0 else 1,
                         f"Do-file completed. Cached {len(cached_graphs)} graph(s): {', '.join(cached_graphs)}"
                     )
-
             except Exception as e:
                 logger.error(f"Post-execution graph detection failed: {e}")

         tail_text = tail.get_value()
+        log_tail = self._read_log_tail(log_path, 200000 if trace else 20000)
+        if log_tail and len(log_tail) > len(tail_text):
+            tail_text = log_tail
         combined = (tail_text or "") + (f"\n{exc}" if exc else "")
-        … (2 lines not shown)
-            rc = rc_hint
-        if exc is None and rc_hint is None:
-            rc = 0 if rc is None or rc != 0 else rc
-        success = rc == 0 and exc is None
+
+        success = (rc == 0 and exc is None)
         error = None
+
         if not success:
-            … (2 lines not shown)
-            rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
-            line_no = self._parse_line_from_text(combined) if combined else None
-            message = "Stata error"
-            if tail_text and tail_text.strip():
-                for line in reversed(tail_text.splitlines()):
-                    if line.strip():
-                        message = line.strip()
-                        break
-            elif exc is not None:
-                message = str(exc).strip() or message
+            # Robust extraction
+            msg, context = self._extract_error_and_context(combined, rc)

             error = ErrorEnvelope(
-                message=…
-                … (2 lines not shown)
+                message=msg,
+                context=context,
+                rc=rc,
                 command=command,
                 log_path=log_path,
-                snippet=…
-                trace=trace or None,
+                snippet=combined[-800:]
             )

         duration = time.time() - start_time
@@ -1425,6 +1422,65 @@ class StataClient:

         return indices

+    def apply_sort(self, sort_spec: List[str]) -> None:
+        """
+        Apply sorting to the dataset using gsort.
+
+        Args:
+            sort_spec: List of variables to sort by, with optional +/- prefix.
+                e.g., ["-price", "+mpg"] sorts by price descending, then mpg ascending.
+                No prefix is treated as ascending (+).
+
+        Raises:
+            ValueError: If sort_spec is invalid or contains invalid variables
+            RuntimeError: If no data in memory or sort command fails
+        """
+        if not self._initialized:
+            self.init()
+
+        state = self.get_dataset_state()
+        if int(state.get("k", 0) or 0) == 0 and int(state.get("n", 0) or 0) == 0:
+            raise RuntimeError("No data in memory")
+
+        if not sort_spec or not isinstance(sort_spec, list):
+            raise ValueError("sort_spec must be a non-empty list")
+
+        # Validate all variables exist
+        var_map = self._get_var_index_map()
+        for spec in sort_spec:
+            if not isinstance(spec, str) or not spec:
+                raise ValueError(f"Invalid sort specification: {spec!r}")
+
+            # Extract variable name (remove +/- prefix if present)
+            varname = spec.lstrip("+-")
+            if not varname:
+                raise ValueError(f"Invalid sort specification: {spec!r}")
+
+            if varname not in var_map:
+                raise ValueError(f"Variable not found: {varname}")
+
+        # Build gsort command
+        # gsort uses - for descending, + or nothing for ascending
+        gsort_args = []
+        for spec in sort_spec:
+            if spec.startswith("-") or spec.startswith("+"):
+                gsort_args.append(spec)
+            else:
+                # No prefix means ascending, add + explicitly for clarity
+                gsort_args.append(f"+{spec}")
+
+        cmd = f"gsort {' '.join(gsort_args)}"
+
+        try:
+            result = self.run_command_structured(cmd, echo=False)
+            if not result.success:
+                error_msg = result.error.message if result.error else "Sort failed"
+                raise RuntimeError(f"Failed to sort dataset: {error_msg}")
+        except Exception as e:
+            if isinstance(e, RuntimeError):
+                raise
+            raise RuntimeError(f"Failed to sort dataset: {e}")
+
     def get_variable_details(self, varname: str) -> str:
         """Returns codebook/summary for a specific variable."""
         resp = self.run_command_structured(f"codebook {varname}", echo=True)
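`apply_sort` validates the variables and then collapses the spec into a single `gsort` call, where `-var` means descending and a bare name gets an explicit `+`. The command-building rule is easy to check in isolation (standalone sketch):

    from typing import List

    def build_gsort_command(sort_spec: List[str]) -> str:
        # "-var" sorts descending; unprefixed names become "+var" (ascending), as in the diff.
        args = [s if s.startswith(("+", "-")) else f"+{s}" for s in sort_spec]
        return f"gsort {' '.join(args)}"

    print(build_gsort_command(["-price", "mpg"]))    # gsort -price +mpg
    print(build_gsort_command(["foreign", "-mpg"]))  # gsort +foreign -mpg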
@@ -2247,68 +2303,72 @@ class StataClient:
                 buffering=1,
             )
             log_path = log_file.name
-            tail = TailBuffer(max_chars=…
+            tail = TailBuffer(max_chars=200000 if trace else 20000)
             tee = FileTeeIO(log_file, tail)

             rc = -1

             with self._exec_lock:
-                … (21 lines not shown)
+                try:
+                    from sfi import Scalar, SFIToolkit  # Import SFI tools
+                    with self._temp_cwd(cwd):
+                        with self._redirect_io_streaming(tee, tee):
+                            try:
+                                if trace:
+                                    self.stata.run("set trace on")
+                                ret = self.stata.run(command, echo=echo)
+                                # Some PyStata builds return output as a string rather than printing.
+                                if isinstance(ret, str) and ret:
+                                    try:
+                                        tee.write(ret)
+                                    except Exception:
+                                        pass
+
+                            except Exception as e:
+                                exc = e
+                                rc = 1
+                            finally:
+                                if trace:
+                                    try:
+                                        self.stata.run("set trace off")
+                                    except Exception:
+                                        pass
+                except Exception as e:
+                    # Outer catch in case imports or locks fail
+                    exc = e
+                    rc = 1

             tee.close()

             tail_text = tail.get_value()
+            log_tail = self._read_log_tail(log_path, 200000 if trace else 20000)
+            if log_tail and len(log_tail) > len(tail_text):
+                tail_text = log_tail
             combined = (tail_text or "") + (f"\n{exc}" if exc else "")
-            rc_hint = self._parse_rc_from_text(combined) if combined else None
-            if exc is None and rc_hint is not None and rc_hint != 0:
-                rc = rc_hint
-            if exc is None and rc_hint is None:
-                rc = 0 if rc is None or rc != 0 else rc
-            success = rc == 0 and exc is None

+            # Parse RC from log tail if no exception occurred
+            if rc == -1 and not exc:
+                parsed_rc = self._parse_rc_from_text(combined)
+                rc = parsed_rc if parsed_rc is not None else 0
+            elif exc:
+                # Try to parse RC from exception message
+                parsed_rc = self._parse_rc_from_text(str(exc))
+                if parsed_rc is not None:
+                    rc = parsed_rc
+
+            success = (rc == 0 and exc is None)
             error = None
+
             if not success:
-                … (2 lines not shown)
-                rc_final = rc_hint if (rc_hint is not None and rc_hint != 0) else (rc if rc not in (-1, None) else rc_hint)
-                line_no = self._parse_line_from_text(combined) if combined else None
-                message = "Stata error"
-                if tail_text and tail_text.strip():
-                    for line in reversed(tail_text.splitlines()):
-                        if line.strip():
-                            message = line.strip()
-                            break
-                elif exc is not None:
-                    message = str(exc).strip() or message
+                # Robust extraction
+                msg, context = self._extract_error_and_context(combined, rc)

                 error = ErrorEnvelope(
-                    message=…
-                    rc=…
-                    … (1 line not shown)
+                    message=msg,
+                    rc=rc,
+                    snippet=context,
                     command=command,
-                    log_path=log_path…
-                    snippet=snippet,
-                    trace=trace or None,
+                    log_path=log_path
                 )

             duration = time.time() - start_time
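In this streaming path the return code now resolves in a fixed order: an `r(N)` parsed from the exception text wins, otherwise an `r(N)` marker found in the combined log, otherwise success. A compact standalone restatement of that precedence (helper name is illustrative):

    import re
    from typing import Optional

    def resolve_rc(exc: Optional[BaseException], combined_log: str) -> int:
        def parse(text: str) -> Optional[int]:
            m = re.search(r"r\((\d+)\)", text)
            return int(m.group(1)) if m else None

        if exc is not None:
            parsed = parse(str(exc))
            return parsed if parsed is not None else 1  # generic failure
        parsed = parse(combined_log)
        return parsed if parsed is not None else 0      # no marker means success

    print(resolve_rc(None, "end of do-file"))                  # 0
    print(resolve_rc(None, "no; dataset has changed\nr(4);"))  # 4
    print(resolve_rc(RuntimeError("failed: r(198);"), ""))     # 198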
@@ -2383,5 +2443,4 @@ class StataClient:
             error=result.error,
         )

-        return result
-
+        return result