mcp-stata 1.2.2__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mcp-stata might be problematic. Click here for more details.
- mcp_stata/discovery.py +96 -25
- mcp_stata/graph_detector.py +385 -0
- mcp_stata/models.py +4 -1
- mcp_stata/server.py +258 -44
- mcp_stata/stata_client.py +1990 -265
- mcp_stata/streaming_io.py +261 -0
- mcp_stata/ui_http.py +540 -0
- mcp_stata-1.6.2.dist-info/METADATA +380 -0
- mcp_stata-1.6.2.dist-info/RECORD +14 -0
- mcp_stata-1.2.2.dist-info/METADATA +0 -240
- mcp_stata-1.2.2.dist-info/RECORD +0 -11
- {mcp_stata-1.2.2.dist-info → mcp_stata-1.6.2.dist-info}/WHEEL +0 -0
- {mcp_stata-1.2.2.dist-info → mcp_stata-1.6.2.dist-info}/entry_points.txt +0 -0
- {mcp_stata-1.2.2.dist-info → mcp_stata-1.6.2.dist-info}/licenses/LICENSE +0 -0
mcp_stata/discovery.py
CHANGED
|
@@ -4,11 +4,30 @@ import glob
|
|
|
4
4
|
import logging
|
|
5
5
|
import shutil
|
|
6
6
|
|
|
7
|
-
from typing import Tuple,
|
|
7
|
+
from typing import Tuple, List
|
|
8
8
|
|
|
9
9
|
logger = logging.getLogger("mcp_stata.discovery")
|
|
10
10
|
|
|
11
11
|
|
|
12
|
+
def _normalize_env_path(raw: str) -> str:
|
|
13
|
+
"""Strip quotes/whitespace and expand variables for STATA_PATH."""
|
|
14
|
+
cleaned = raw.strip()
|
|
15
|
+
if (cleaned.startswith("\"") and cleaned.endswith("\"")) or (
|
|
16
|
+
cleaned.startswith("'") and cleaned.endswith("'")
|
|
17
|
+
):
|
|
18
|
+
cleaned = cleaned[1:-1].strip()
|
|
19
|
+
return os.path.expandvars(os.path.expanduser(cleaned))
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _is_executable(path: str, system: str) -> bool:
|
|
23
|
+
if not os.path.exists(path):
|
|
24
|
+
return False
|
|
25
|
+
if system == "Windows":
|
|
26
|
+
# On Windows, check if it's a file and has .exe extension
|
|
27
|
+
return os.path.isfile(path) and path.lower().endswith('.exe')
|
|
28
|
+
return os.access(path, os.X_OK)
|
|
29
|
+
|
|
30
|
+
|
|
12
31
|
def _dedupe_preserve(items: List[tuple]) -> List[tuple]:
|
|
13
32
|
seen = set()
|
|
14
33
|
unique = []
|
|
@@ -27,9 +46,60 @@ def find_stata_path() -> Tuple[str, str]:
|
|
|
27
46
|
"""
|
|
28
47
|
system = platform.system()
|
|
29
48
|
|
|
30
|
-
|
|
49
|
+
windows_binaries = [
|
|
50
|
+
("StataMP-64.exe", "mp"),
|
|
51
|
+
("StataMP.exe", "mp"),
|
|
52
|
+
("StataSE-64.exe", "se"),
|
|
53
|
+
("StataSE.exe", "se"),
|
|
54
|
+
("Stata-64.exe", "be"),
|
|
55
|
+
("Stata.exe", "be"),
|
|
56
|
+
]
|
|
57
|
+
|
|
58
|
+
linux_binaries = [
|
|
59
|
+
("stata-mp", "mp"),
|
|
60
|
+
("stata-se", "se"),
|
|
61
|
+
("stata-ic", "be"),
|
|
62
|
+
("stata", "be"),
|
|
63
|
+
("xstata-mp", "mp"),
|
|
64
|
+
("xstata-se", "se"),
|
|
65
|
+
("xstata-ic", "be"),
|
|
66
|
+
("xstata", "be"),
|
|
67
|
+
]
|
|
68
|
+
|
|
69
|
+
# 1. Check Environment Variable (supports quoted values and directory targets)
|
|
31
70
|
if os.environ.get("STATA_PATH"):
|
|
32
|
-
|
|
71
|
+
raw_path = os.environ["STATA_PATH"]
|
|
72
|
+
path = _normalize_env_path(raw_path)
|
|
73
|
+
logger.info("Using STATA_PATH override (normalized): %s", path)
|
|
74
|
+
|
|
75
|
+
# If a directory is provided, try standard binaries for the platform
|
|
76
|
+
if os.path.isdir(path):
|
|
77
|
+
search_set = []
|
|
78
|
+
if system == "Windows":
|
|
79
|
+
search_set = windows_binaries
|
|
80
|
+
elif system == "Linux":
|
|
81
|
+
search_set = linux_binaries
|
|
82
|
+
elif system == "Darwin":
|
|
83
|
+
search_set = [
|
|
84
|
+
("Contents/MacOS/stata-mp", "mp"),
|
|
85
|
+
("Contents/MacOS/stata-se", "se"),
|
|
86
|
+
("Contents/MacOS/stata", "be"),
|
|
87
|
+
("stata-mp", "mp"),
|
|
88
|
+
("stata-se", "se"),
|
|
89
|
+
("stata", "be"),
|
|
90
|
+
]
|
|
91
|
+
|
|
92
|
+
for binary, edition in search_set:
|
|
93
|
+
candidate = os.path.join(path, binary)
|
|
94
|
+
if _is_executable(candidate, system):
|
|
95
|
+
logger.info("Found Stata via STATA_PATH directory: %s (%s)", candidate, edition)
|
|
96
|
+
return candidate, edition
|
|
97
|
+
|
|
98
|
+
raise FileNotFoundError(
|
|
99
|
+
f"STATA_PATH points to directory '{path}', but no Stata executable was found within. "
|
|
100
|
+
"Point STATA_PATH directly to the Stata binary (e.g., C:\\Program Files\\Stata18\\StataMP-64.exe)."
|
|
101
|
+
)
|
|
102
|
+
|
|
33
103
|
edition = "be"
|
|
34
104
|
lower_path = path.lower()
|
|
35
105
|
if "mp" in lower_path:
|
|
@@ -44,7 +114,7 @@ def find_stata_path() -> Tuple[str, str]:
|
|
|
44
114
|
"Update STATA_PATH to your Stata binary (e.g., "
|
|
45
115
|
"/Applications/StataNow/StataMP.app/Contents/MacOS/stata-mp or /usr/local/stata18/stata-mp)."
|
|
46
116
|
)
|
|
47
|
-
if not
|
|
117
|
+
if not _is_executable(path, system):
|
|
48
118
|
raise PermissionError(
|
|
49
119
|
f"STATA_PATH points to '{path}', but it is not executable. "
|
|
50
120
|
"Ensure this is the Stata binary, not the .app directory."
|
|
@@ -84,29 +154,13 @@ def find_stata_path() -> Tuple[str, str]:
|
|
|
84
154
|
|
|
85
155
|
for base_dir in base_dirs:
|
|
86
156
|
for stata_dir in glob.glob(os.path.join(base_dir, "Stata*")):
|
|
87
|
-
for exe, edition in
|
|
88
|
-
("StataMP-64.exe", "mp"),
|
|
89
|
-
("StataMP.exe", "mp"),
|
|
90
|
-
("StataSE-64.exe", "se"),
|
|
91
|
-
("StataSE.exe", "se"),
|
|
92
|
-
("Stata-64.exe", "be"),
|
|
93
|
-
("Stata.exe", "be"),
|
|
94
|
-
]:
|
|
157
|
+
for exe, edition in windows_binaries:
|
|
95
158
|
full_path = os.path.join(stata_dir, exe)
|
|
96
159
|
if os.path.exists(full_path):
|
|
97
160
|
candidates.append((full_path, edition))
|
|
98
161
|
|
|
99
162
|
elif system == "Linux":
|
|
100
|
-
|
|
101
|
-
("stata-mp", "mp"),
|
|
102
|
-
("stata-se", "se"),
|
|
103
|
-
("stata-ic", "be"),
|
|
104
|
-
("stata", "be"),
|
|
105
|
-
("xstata-mp", "mp"),
|
|
106
|
-
("xstata-se", "se"),
|
|
107
|
-
("xstata-ic", "be"),
|
|
108
|
-
("xstata", "be"),
|
|
109
|
-
]
|
|
163
|
+
home_base = os.environ.get("HOME") or os.path.expanduser("~")
|
|
110
164
|
|
|
111
165
|
# 2a. Try binaries available on PATH first
|
|
112
166
|
for binary, edition in linux_binaries:
|
|
@@ -118,8 +172,8 @@ def find_stata_path() -> Tuple[str, str]:
|
|
|
118
172
|
linux_roots = [
|
|
119
173
|
"/usr/local",
|
|
120
174
|
"/opt",
|
|
121
|
-
os.path.
|
|
122
|
-
os.path.
|
|
175
|
+
os.path.join(home_base, "stata"),
|
|
176
|
+
os.path.join(home_base, "Stata"),
|
|
123
177
|
]
|
|
124
178
|
|
|
125
179
|
for root in linux_roots:
|
|
@@ -143,13 +197,14 @@ def find_stata_path() -> Tuple[str, str]:
|
|
|
143
197
|
if os.path.exists(full_path):
|
|
144
198
|
candidates.append((full_path, edition))
|
|
145
199
|
|
|
200
|
+
|
|
146
201
|
candidates = _dedupe_preserve(candidates)
|
|
147
202
|
|
|
148
203
|
for path, edition in candidates:
|
|
149
204
|
if not os.path.exists(path):
|
|
150
205
|
logger.warning("Discovered candidate missing on disk: %s", path)
|
|
151
206
|
continue
|
|
152
|
-
if not
|
|
207
|
+
if not _is_executable(path, system):
|
|
153
208
|
logger.warning("Discovered candidate is not executable: %s", path)
|
|
154
209
|
continue
|
|
155
210
|
logger.info("Auto-discovered Stata at %s (%s)", path, edition)
|
|
@@ -160,3 +215,19 @@ def find_stata_path() -> Tuple[str, str]:
|
|
|
160
215
|
"Set STATA_PATH to your Stata executable (e.g., "
|
|
161
216
|
"/Applications/StataNow/StataMP.app/Contents/MacOS/stata-mp, /usr/local/stata18/stata-mp, or C:\\Program Files\\Stata18\\StataMP-64.exe)."
|
|
162
217
|
)
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def main() -> int:
    """Command-line helper that reports the discovered Stata binary.

    Runs ``find_stata_path()`` and prints the executable path and edition on
    stdout, returning 0. If anything raises along the way, the exception text
    is printed (also on stdout) and 1 is returned.
    """
    try:
        exe_path, exe_edition = find_stata_path()
        report = "\n".join((f"Stata executable: {exe_path}", f"Edition: {exe_edition}"))
        # Emit on stdout so both CLI users and tests can capture it.
        print(report)
    except Exception as exc:  # pragma: no cover - exercised via tests with env
        print(f"Discovery failed: {exc}")
        return 1
    return 0


if __name__ == "__main__":  # pragma: no cover - manual utility
    raise SystemExit(main())
|
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Graph creation detection for streaming Stata output.
|
|
3
|
+
|
|
4
|
+
This module provides functionality to detect when graphs are created
|
|
5
|
+
during Stata command execution and automatically cache them.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import re
|
|
10
|
+
import threading
|
|
11
|
+
import time
|
|
12
|
+
from typing import List, Set, Callable, Dict, Any
|
|
13
|
+
import logging
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# SFI is always available
SFI_AVAILABLE = True

logger = logging.getLogger(__name__)


class GraphCreationDetector:
    """Detects graph creation using SFI-only detection with pystata integration.

    State kept by the detector:

    * ``_detected_graphs`` / ``_removed_graphs``: graph names already handed
      out for caching, and names known to have been dropped. Guarded by
      ``_lock``.
    * ``_last_graph_state``: the previous ``{name: state_info}`` snapshot,
      diffed against each new poll of Stata's in-memory graph list.
    """

    def __init__(self, stata_client=None):
        # ``threading.Lock`` is not re-entrant: code that already holds it must
        # use the ``*_unlocked`` helpers instead of the public locking methods.
        self._lock = threading.Lock()
        self._detected_graphs: Set[str] = set()
        self._removed_graphs: Set[str] = set()
        self._unnamed_graph_counter = 0  # Track unnamed graphs for identification
        self._stata_client = stata_client
        self._last_graph_state: Dict[str, Any] = {}  # Track graph state changes

    def _describe_graph_signature(self, graph_name: str) -> str:
        """Return a stable signature for a graph.

        The signature is the captured stdout of ``graph describe <name>`` (or
        the error snippet when the command fails). Timestamps are deliberately
        not used, since that would make every poll look like a modification.

        Returns:
            The signature text, or "" when no client with a ``stata`` attribute
            is attached or the command could not be run.
        """
        if not self._stata_client or not hasattr(self._stata_client, "stata"):
            return ""
        try:
            # Capture output so it can be compared deterministically.
            resp = self._stata_client.run_command_structured(
                f"graph describe {graph_name}", echo=False
            )
            if resp.success and resp.stdout:
                return resp.stdout
            if resp.error and resp.error.snippet:
                return resp.error.snippet
        except Exception:
            return ""
        return ""

    def _detect_graphs_via_pystata(self) -> List[str]:
        """Return graphs that are new or whose signature changed since the last poll.

        Graphs previously marked as removed are never reported. The
        ``_last_graph_state`` snapshot is replaced by the current one.
        Returns [] when no client is attached or detection fails.
        """
        if not self._stata_client:
            return []

        try:
            # Query the graph list once and derive the per-graph state from that
            # same list, so both views describe one moment in Stata's memory.
            current_graphs = self._get_current_graphs_from_pystata()
            current_state = self._get_graph_state_from_pystata(current_graphs)

            previous = self._last_graph_state
            new_graphs: List[str] = []

            # Names never seen before.
            for graph_name in current_graphs:
                if graph_name not in previous and graph_name not in self._removed_graphs:
                    new_graphs.append(graph_name)

            # Known names whose stable signature changed (i.e. the graph was redrawn).
            for graph_name, state in current_state.items():
                last_state = previous.get(graph_name)
                if last_state is None:
                    continue
                if (
                    state.get("signature") != last_state.get("signature")
                    and graph_name not in self._removed_graphs
                ):
                    new_graphs.append(graph_name)

            self._last_graph_state = current_state.copy()
            return new_graphs

        except (ImportError, RuntimeError, ValueError, AttributeError) as e:
            # These are expected exceptions when SFI is not available or Stata state is inaccessible
            logger.debug("Failed to detect graphs via pystata (expected): %s", e)
            return []
        except Exception as e:
            # Unexpected errors should be logged as errors
            logger.error("Unexpected error in pystata graph detection: %s", e)
            return []

    def _get_current_graphs_from_pystata(self) -> List[str]:
        """Return the names of graphs currently held in Stata memory.

        Prefers ``stata_client.list_graphs()``. Otherwise falls back to running
        ``graph dir, memory`` and reading ``r(list)`` through the SFI ``Macro``
        API; note that this sets the Stata global ``mcp_graph_list``.
        Returns [] when no usable client is attached or on any failure.
        """
        client = self._stata_client
        try:
            if client and hasattr(client, "list_graphs"):
                return client.list_graphs()
            if not (client and hasattr(client, "stata")):
                return []
            try:
                from sfi import Macro
            except ImportError:
                logger.warning("sfi.Macro not available for fallback graph detection")
                return []
            client.stata.run("quietly graph dir, memory")
            client.stata.run("global mcp_graph_list `r(list)'")
            graph_list_str = Macro.getGlobal("mcp_graph_list")
            return graph_list_str.split() if graph_list_str else []
        except Exception as e:
            logger.warning("Failed to get current graphs from pystata: %s", e)
            return []

    def _get_graph_state_from_pystata(self, graphs=None) -> Dict[str, Any]:
        """Build a ``{name: state_info}`` snapshot for graphs in Stata memory.

        Args:
            graphs: Optional list of graph names already fetched by the caller.
                When omitted (``None``), the list is queried from Stata.

        Each ``state_info`` holds ``name``, ``exists``, ``valid`` (signature is
        non-empty), ``signature`` and ``timestamp``. The timestamp is carried
        over from the previous snapshot while the signature is unchanged, so it
        tracks the last observed modification rather than the poll time.
        """
        graph_state: Dict[str, Any] = {}

        try:
            current_graphs = (
                self._get_current_graphs_from_pystata() if graphs is None else graphs
            )

            for graph_name in current_graphs:
                try:
                    signature = self._describe_graph_signature(graph_name)
                    state_info = {
                        "name": graph_name,
                        "exists": True,
                        "valid": bool(signature),
                        "signature": signature,
                    }

                    # Only refresh the timestamp when the signature changes.
                    prev = self._last_graph_state.get(graph_name)
                    if prev is None or prev.get("signature") != signature:
                        state_info["timestamp"] = time.time()
                    else:
                        state_info["timestamp"] = prev.get("timestamp", time.time())

                    graph_state[graph_name] = state_info

                except Exception as e:
                    logger.warning("Failed to get state for graph %s: %s", graph_name, e)
                    graph_state[graph_name] = {
                        "name": graph_name,
                        "timestamp": time.time(),
                        "exists": False,
                        "signature": "",
                    }

        except Exception as e:
            logger.warning("Failed to get graph state from pystata: %s", e)

        return graph_state

    def detect_graph_modifications(self, text: str = None) -> dict:
        """Detect graph removal by comparing Stata's graph list to the last snapshot.

        Args:
            text: Unused; accepted for backward compatibility with callers
                that pass command output.

        Returns:
            ``{"dropped": [...], "renamed": [], "cleared": bool}``. ``renamed``
            is never populated here. ``cleared`` is True when the previous
            snapshot had graphs and none remain now.

        As a side effect, the snapshot is replaced by one built from the
        graphs currently in memory (with fresh timestamps).
        """
        modifications = {"dropped": [], "renamed": [], "cleared": False}

        if not self._stata_client:
            return modifications

        try:
            current_graphs = set(self._get_current_graphs_from_pystata())

            if self._last_graph_state:
                last_graphs = set(self._last_graph_state.keys())

                # Graphs present last time but gone now.
                modifications["dropped"].extend(last_graphs - current_graphs)

                # "graph drop _all" / "clear all": everything disappeared.
                if last_graphs and not current_graphs:
                    modifications["cleared"] = True

            # Update last known state for next comparison (stable signatures).
            new_state: Dict[str, Any] = {}
            for graph in current_graphs:
                sig = self._describe_graph_signature(graph)
                new_state[graph] = {
                    "name": graph,
                    "exists": True,
                    "valid": bool(sig),
                    "signature": sig,
                    "timestamp": time.time(),
                }
            self._last_graph_state = new_state

        except Exception as e:
            logger.debug("SFI modification detection failed: %s", e)

        return modifications

    def should_cache_graph(self, graph_name: str) -> bool:
        """Return True exactly once per graph name, unless it was marked removed.

        The name is recorded as detected, so later calls return False until
        the detection state is cleared or the name is re-enabled.
        """
        with self._lock:
            if graph_name in self._detected_graphs or graph_name in self._removed_graphs:
                return False
            self._detected_graphs.add(graph_name)
            return True

    def _mark_graph_removed_unlocked(self, graph_name: str) -> None:
        """Record ``graph_name`` as removed; the caller must hold ``self._lock``."""
        self._removed_graphs.add(graph_name)
        self._detected_graphs.discard(graph_name)

    def _mark_all_cleared_unlocked(self) -> None:
        """Forget all detected/removed names; the caller must hold ``self._lock``."""
        self._detected_graphs.clear()
        self._removed_graphs.clear()

    def mark_graph_removed(self, graph_name: str) -> None:
        """Mark a graph as removed so it is no longer offered for caching."""
        with self._lock:
            self._mark_graph_removed_unlocked(graph_name)

    def mark_all_cleared(self) -> None:
        """Forget all detected and removed graph names."""
        with self._lock:
            self._mark_all_cleared_unlocked()

    def clear_detection_state(self) -> None:
        """Clear all detection state, including the unnamed-graph counter."""
        with self._lock:
            self._detected_graphs.clear()
            self._removed_graphs.clear()
            self._unnamed_graph_counter = 0

    def process_modifications(self, modifications: dict) -> None:
        """Apply a dict shaped like ``detect_graph_modifications()`` output.

        Dropped graphs are marked removed. For each ``(old, new)`` rename, the
        old name is marked removed and the new name may be detected again.
        ``cleared`` resets detected/removed names. Everything happens under a
        single acquisition of the lock; the unlocked helpers are used because
        ``threading.Lock`` is not re-entrant and the public methods would
        deadlock here.
        """
        with self._lock:
            for graph_name in modifications.get("dropped", []):
                self._mark_graph_removed_unlocked(graph_name)

            for old_name, new_name in modifications.get("renamed", []):
                self._mark_graph_removed_unlocked(old_name)
                self._detected_graphs.discard(new_name)  # Allow re-detection with new name

            if modifications.get("cleared", False):
                self._mark_all_cleared_unlocked()
|
242
|
+
|
|
243
|
+
class StreamingGraphCache:
    """Integrates graph detection with caching during streaming.

    Graph names are queued in ``_graphs_to_cache``. The queue is drained by
    :meth:`cache_detected_graphs` and :meth:`cache_detected_graphs_with_pystata`,
    which call ``stata_client.cache_graph_on_creation`` in a worker thread
    (``asyncio.to_thread``) and notify registered callbacks with
    ``(graph_name, success)``.
    """

    def __init__(self, stata_client, auto_cache: bool = False):
        self.stata_client = stata_client
        self.auto_cache = auto_cache
        self.detector = GraphCreationDetector(stata_client)
        self._lock = threading.Lock()
        self._cache_callbacks: List[Callable[[str, bool], None]] = []
        self._graphs_to_cache: List[str] = []
        self._cached_graphs: Set[str] = set()
        self._removed_graphs: Set[str] = set()  # Track removed graphs directly
        self._initial_graphs: Set[str] = set()  # Captured before execution starts

    def add_cache_callback(self, callback: Callable[[str, bool], None]) -> None:
        """Register ``callback(graph_name, success)`` for graph cache events."""
        with self._lock:
            self._cache_callbacks.append(callback)

    async def cache_detected_graphs_with_pystata(self) -> List[str]:
        """Poll the detector for new/changed graphs, then cache the queue.

        Returns the names successfully cached during this call, or [] when
        ``auto_cache`` is disabled.
        """
        if not self.auto_cache:
            return []

        if self.stata_client:
            try:
                pystata_detected = self.detector._detect_graphs_via_pystata()
            except Exception as e:
                logger.warning("Failed to get pystata graph updates: %s", e)
            else:
                # Queue newly detected graphs that are neither cached nor removed.
                with self._lock:
                    for graph_name in pystata_detected:
                        if (
                            graph_name not in self._cached_graphs
                            and graph_name not in self._removed_graphs
                        ):
                            self._graphs_to_cache.append(graph_name)

        return await self._cache_pending_graphs()

    async def cache_detected_graphs(self) -> List[str]:
        """Cache all queued graphs.

        Returns the names successfully cached during this call, or [] when
        ``auto_cache`` is disabled.
        """
        if not self.auto_cache:
            return []
        return await self._cache_pending_graphs()

    async def _cache_pending_graphs(self) -> List[str]:
        """Drain the queue and cache each queued graph still present in Stata.

        Queued names missing from ``stata_client.list_graphs()`` or already
        cached are skipped; the queue is emptied either way. Callbacks are
        notified with ``True`` on a successful cache, and with ``False`` when
        caching raises.
        """
        cached_names: List[str] = []

        with self._lock:
            graphs_to_process = self._graphs_to_cache.copy()
            self._graphs_to_cache.clear()

        # Current graph list, used to verify each queued name still exists.
        try:
            current_graphs = set(self.stata_client.list_graphs())
        except Exception as e:
            logger.warning("Failed to get current graph list: %s", e)
            return cached_names

        for graph_name in graphs_to_process:
            with self._lock:
                already_cached = graph_name in self._cached_graphs
            if graph_name not in current_graphs or already_cached:
                continue

            try:
                success = await asyncio.to_thread(
                    self.stata_client.cache_graph_on_creation, graph_name
                )
            except Exception as e:
                logger.warning("Failed to cache graph %s: %s", graph_name, e)
                # Still notify callbacks of failure (callback errors ignored).
                self._notify_callbacks(graph_name, False, log_errors=False)
                continue

            if success:
                cached_names.append(graph_name)
                with self._lock:
                    self._cached_graphs.add(graph_name)
                self._notify_callbacks(graph_name, success)

        return cached_names

    def _notify_callbacks(self, graph_name: str, success: bool, log_errors: bool = True) -> None:
        """Invoke every registered callback, isolating callback exceptions.

        The callback list is snapshotted under the lock so concurrent
        registrations do not affect this notification round.
        """
        with self._lock:
            callbacks = list(self._cache_callbacks)
        for callback in callbacks:
            try:
                callback(graph_name, success)
            except Exception as e:
                if log_errors:
                    logger.warning("Cache callback failed for %s: %s", graph_name, e)

    def get_cache_stats(self) -> dict:
        """Return counters describing queue, cache and detector state."""
        with self._lock:
            return {
                "auto_cache_enabled": self.auto_cache,
                "pending_cache_count": len(self._graphs_to_cache),
                "cached_graphs_count": len(self._cached_graphs),
                "detected_graphs_count": len(self.detector._detected_graphs),
                "removed_graphs_count": len(self.detector._removed_graphs),
            }

    def reset(self) -> None:
        """Reset the cache queue, cached names and detector state."""
        with self._lock:
            self._graphs_to_cache.clear()
            self._cached_graphs.clear()
            self.detector.clear_detection_state()
|
mcp_stata/models.py
CHANGED
|
@@ -7,6 +7,7 @@ class ErrorEnvelope(BaseModel):
|
|
|
7
7
|
rc: Optional[int] = None
|
|
8
8
|
line: Optional[int] = None
|
|
9
9
|
command: Optional[str] = None
|
|
10
|
+
log_path: Optional[str] = None
|
|
10
11
|
stdout: Optional[str] = None
|
|
11
12
|
stderr: Optional[str] = None
|
|
12
13
|
snippet: Optional[str] = None
|
|
@@ -18,6 +19,7 @@ class CommandResponse(BaseModel):
|
|
|
18
19
|
rc: int
|
|
19
20
|
stdout: str
|
|
20
21
|
stderr: Optional[str] = None
|
|
22
|
+
log_path: Optional[str] = None
|
|
21
23
|
success: bool
|
|
22
24
|
error: Optional[ErrorEnvelope] = None
|
|
23
25
|
|
|
@@ -49,7 +51,8 @@ class GraphListResponse(BaseModel):
|
|
|
49
51
|
|
|
50
52
|
class GraphExport(BaseModel):
|
|
51
53
|
name: str
|
|
52
|
-
|
|
54
|
+
file_path: Optional[str] = None
|
|
55
|
+
image_base64: Optional[str] = None
|
|
53
56
|
|
|
54
57
|
|
|
55
58
|
class GraphExportResponse(BaseModel):
|