mcp-stata 1.21.0__cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_stata/__init__.py +3 -0
- mcp_stata/__main__.py +4 -0
- mcp_stata/_native_ops.abi3.so +0 -0
- mcp_stata/config.py +20 -0
- mcp_stata/discovery.py +548 -0
- mcp_stata/graph_detector.py +601 -0
- mcp_stata/models.py +64 -0
- mcp_stata/native_ops.py +87 -0
- mcp_stata/server.py +1233 -0
- mcp_stata/smcl/smcl2html.py +88 -0
- mcp_stata/stata_client.py +4638 -0
- mcp_stata/streaming_io.py +264 -0
- mcp_stata/test_stata.py +56 -0
- mcp_stata/ui_http.py +999 -0
- mcp_stata/utils.py +159 -0
- mcp_stata-1.21.0.dist-info/METADATA +486 -0
- mcp_stata-1.21.0.dist-info/RECORD +20 -0
- mcp_stata-1.21.0.dist-info/WHEEL +5 -0
- mcp_stata-1.21.0.dist-info/entry_points.txt +2 -0
- mcp_stata-1.21.0.dist-info/licenses/LICENSE +661 -0
|
@@ -0,0 +1,601 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
"""
|
|
3
|
+
Graph creation detection for streaming Stata output.
|
|
4
|
+
|
|
5
|
+
This module provides functionality to detect when graphs are created
|
|
6
|
+
during Stata command execution and automatically cache them.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import contextlib
|
|
11
|
+
import inspect
|
|
12
|
+
import re
|
|
13
|
+
import threading
|
|
14
|
+
import time
|
|
15
|
+
from typing import List, Set, Callable, Dict, Any
|
|
16
|
+
import logging
|
|
17
|
+
import uuid
|
|
18
|
+
import shlex
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# SFI is always available
|
|
22
|
+
SFI_AVAILABLE = True
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class GraphCreationDetector:
    """Detects graph creation using SFI-only detection with pystata integration."""

    def __init__(self, stata_client=None):
        # Client handle used for every Stata round trip (may be None in tests).
        self._stata_client = stata_client
        # Guards all reads/writes of the detection bookkeeping below.
        self._lock = threading.Lock()
        # Graphs already claimed for caching / graphs known to be gone.
        self._detected_graphs: Set[str] = set()
        self._removed_graphs: Set[str] = set()
        # Counter reserved for graphs created without an explicit name.
        self._unnamed_graph_counter = 0
        # Snapshot of the most recently observed graph state, keyed by name.
        self._last_graph_state: Dict[str, Any] = {}
        # Short-lived inventory cache to avoid back-to-back Stata calls.
        self._inventory_cache: Dict[str, Any] = {
            "timestamp": 0.0,
            "graphs": [],
            "timestamps": {},
        }
        self._inventory_cache_ttl = 0.5
        self._inventory_cache_enabled = False
|
|
44
|
+
|
|
45
|
+
def _describe_graph_signature(self, graph_name: str) -> str:
|
|
46
|
+
"""Return a stable signature for a graph.
|
|
47
|
+
|
|
48
|
+
We use name-based tracking tied to the Stata command execution
|
|
49
|
+
context, enriched with timestamps where available.
|
|
50
|
+
"""
|
|
51
|
+
if not self._stata_client:
|
|
52
|
+
return ""
|
|
53
|
+
|
|
54
|
+
# Try to find timestamp in client's cache first for cross-command stability
|
|
55
|
+
cache = getattr(self._stata_client, "_list_graphs_cache", None)
|
|
56
|
+
if cache:
|
|
57
|
+
for g in cache:
|
|
58
|
+
# Cache might contain GraphInfo objects
|
|
59
|
+
name = getattr(g, "name", g if isinstance(g, str) else None)
|
|
60
|
+
created = getattr(g, "created", None)
|
|
61
|
+
if name == graph_name and created:
|
|
62
|
+
return f"{graph_name}_{created}"
|
|
63
|
+
|
|
64
|
+
# Fallback to command_idx
|
|
65
|
+
cmd_idx = getattr(self._stata_client, "_command_idx", 0)
|
|
66
|
+
return f"{graph_name}_{cmd_idx}"
|
|
67
|
+
|
|
68
|
+
def _get_graph_inventory(self, *, need_timestamps: bool = True) -> tuple[List[str], Dict[str, str]]:
|
|
69
|
+
"""Get both the list of graphs and their timestamps in a single Stata call."""
|
|
70
|
+
if not self._stata_client or not hasattr(self._stata_client, "stata"):
|
|
71
|
+
return [], {}
|
|
72
|
+
if self._inventory_cache_enabled:
|
|
73
|
+
now = time.monotonic()
|
|
74
|
+
cached = self._inventory_cache
|
|
75
|
+
if cached["graphs"] and (now - cached["timestamp"]) < self._inventory_cache_ttl:
|
|
76
|
+
if need_timestamps:
|
|
77
|
+
if cached["timestamps"]:
|
|
78
|
+
return list(cached["graphs"]), dict(cached["timestamps"])
|
|
79
|
+
else:
|
|
80
|
+
return list(cached["graphs"]), {}
|
|
81
|
+
|
|
82
|
+
try:
|
|
83
|
+
# Use the lock from client to prevent concurrency issues with pystata
|
|
84
|
+
exec_lock = getattr(self._stata_client, "_exec_lock", None)
|
|
85
|
+
ctx = exec_lock if exec_lock else contextlib.nullcontext()
|
|
86
|
+
|
|
87
|
+
with ctx:
|
|
88
|
+
hold_name = f"_mcp_detector_inv_{int(time.time() * 1000 % 1000000)}"
|
|
89
|
+
from sfi import Macro
|
|
90
|
+
|
|
91
|
+
# Bundle to get everything in one round trip
|
|
92
|
+
# 1. Hold results
|
|
93
|
+
# 2. Get list of graphs in memory
|
|
94
|
+
# 3. Store list in global
|
|
95
|
+
# 4. Loop over list to get timestamps
|
|
96
|
+
# 5. Restore results
|
|
97
|
+
bundle = [
|
|
98
|
+
f"capture _return hold {hold_name}",
|
|
99
|
+
"quietly graph dir, memory",
|
|
100
|
+
"local list `r(list)'",
|
|
101
|
+
"macro define mcpinvlist \"`r(list)'\"",
|
|
102
|
+
"local i = 0",
|
|
103
|
+
]
|
|
104
|
+
|
|
105
|
+
if need_timestamps:
|
|
106
|
+
bundle.extend([
|
|
107
|
+
"foreach g of local list {",
|
|
108
|
+
" capture quietly graph describe `g'",
|
|
109
|
+
" macro define mcpinvts`i' \"`r(command_date)'_`r(command_time)'\"",
|
|
110
|
+
" local i = `i' + 1",
|
|
111
|
+
"}",
|
|
112
|
+
])
|
|
113
|
+
|
|
114
|
+
bundle.extend([
|
|
115
|
+
"macro define mcpinvcount \"`i'\"",
|
|
116
|
+
f"capture _return restore {hold_name}",
|
|
117
|
+
])
|
|
118
|
+
|
|
119
|
+
self._stata_client.stata.run("\n".join(bundle), echo=False)
|
|
120
|
+
|
|
121
|
+
# Fetch result list
|
|
122
|
+
raw_list_str = Macro.getGlobal("mcpinvlist")
|
|
123
|
+
count_str = Macro.getGlobal("mcpinvcount")
|
|
124
|
+
|
|
125
|
+
if not raw_list_str:
|
|
126
|
+
return [], {}
|
|
127
|
+
|
|
128
|
+
# Handle quoted names if any (spaces in names)
|
|
129
|
+
try:
|
|
130
|
+
graph_names = shlex.split(raw_list_str)
|
|
131
|
+
except Exception:
|
|
132
|
+
graph_names = raw_list_str.split()
|
|
133
|
+
|
|
134
|
+
# Map internal names back to user-facing names if aliases exist
|
|
135
|
+
reverse = getattr(self._stata_client, "_graph_name_reverse", {})
|
|
136
|
+
user_names = [reverse.get(n, n) for n in graph_names]
|
|
137
|
+
|
|
138
|
+
# Fetch timestamps and map them to user names
|
|
139
|
+
count = int(float(count_str)) if count_str else 0
|
|
140
|
+
|
|
141
|
+
timestamps = {}
|
|
142
|
+
if need_timestamps:
|
|
143
|
+
for i in range(count):
|
|
144
|
+
ts = Macro.getGlobal(f"mcpinvts{i}")
|
|
145
|
+
if ts and i < len(user_names):
|
|
146
|
+
# Use user_names to match what the rest of the system expects
|
|
147
|
+
timestamps[user_names[i]] = ts
|
|
148
|
+
|
|
149
|
+
self._inventory_cache = {
|
|
150
|
+
"timestamp": time.monotonic(),
|
|
151
|
+
"graphs": list(user_names),
|
|
152
|
+
"timestamps": dict(timestamps),
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
return user_names, timestamps
|
|
156
|
+
except Exception as e:
|
|
157
|
+
logger.debug(f"Inventory fetch failed: {e}")
|
|
158
|
+
return [], {}
|
|
159
|
+
except Exception as e:
|
|
160
|
+
logger.debug(f"Inventory fetch failed: {e}")
|
|
161
|
+
return [], {}
|
|
162
|
+
|
|
163
|
+
def _get_graph_timestamp(self, graph_name: str) -> str:
|
|
164
|
+
"""Get the creation/modification timestamp of a graph using graph describe.
|
|
165
|
+
|
|
166
|
+
The result is cached per command to optimize performance during streaming.
|
|
167
|
+
"""
|
|
168
|
+
results = self._get_graph_timestamps([graph_name])
|
|
169
|
+
return results.get(graph_name, "")
|
|
170
|
+
|
|
171
|
+
def _get_graph_timestamps(self, graph_names: List[str]) -> Dict[str, str]:
    """Get timestamps for multiple graphs in a single Stata call to minimize overhead.

    Returns a mapping of graph name -> "<command_date>_<command_time>" for
    every graph whose describe succeeded; names that failed are simply
    absent. Returns {} on empty input, missing client, or any error.
    """
    if not graph_names or not self._stata_client or not hasattr(self._stata_client, "stata"):
        return {}

    try:
        # Use the lock from client to prevent concurrency issues with pystata
        exec_lock = getattr(self._stata_client, "_exec_lock", None)
        ctx = exec_lock if exec_lock else contextlib.nullcontext()

        with ctx:
            # Preserve the caller's r() results around our describe calls;
            # the unique-ish suffix avoids colliding with other holds.
            hold_name = f"_mcp_detector_thold_{int(time.time() * 1000 % 1000000)}"
            self._stata_client.stata.run(f"capture _return hold {hold_name}", echo=False)
            try:
                # Build a single Stata command to fetch all timestamps
                stata_cmd = ""
                for i, name in enumerate(graph_names):
                    # Translate user-facing names to internal Stata names.
                    resolved = self._stata_client._resolve_graph_name_for_stata(name)
                    stata_cmd += f"quietly graph describe {resolved}\n"
                    stata_cmd += f"macro define mcp_ts_{i} \"`r(command_date)'_`r(command_time)'\"\n"

                self._stata_client.stata.run(stata_cmd, echo=False)

                from sfi import Macro
                results = {}
                for i, name in enumerate(graph_names):
                    # Globals are positional (mcp_ts_<i>), matched back by index.
                    ts = Macro.getGlobal(f"mcp_ts_{i}")
                    if ts:
                        results[name] = ts
                return results
            finally:
                # Always restore held results, even when describe failed.
                self._stata_client.stata.run(f"capture _return restore {hold_name}", echo=False)
    except Exception as e:
        logger.debug(f"Failed to get timestamps: {e}")
        return {}
|
|
206
|
+
|
|
207
|
+
def _detect_graphs_via_pystata(self) -> List[str]:
|
|
208
|
+
"""Detect newly created graphs using direct pystata state access."""
|
|
209
|
+
if not self._stata_client:
|
|
210
|
+
return []
|
|
211
|
+
|
|
212
|
+
with self._lock:
|
|
213
|
+
try:
|
|
214
|
+
# Get current graph state - this now uses a single bundle (1 round trip)
|
|
215
|
+
current_state = self._get_graph_state_from_pystata()
|
|
216
|
+
current_graphs = list(current_state.keys())
|
|
217
|
+
|
|
218
|
+
# Compare with last known state to detect new graphs
|
|
219
|
+
new_graphs = []
|
|
220
|
+
|
|
221
|
+
# Check for new graph names
|
|
222
|
+
for graph_name in current_graphs:
|
|
223
|
+
if graph_name not in self._last_graph_state and graph_name not in self._removed_graphs:
|
|
224
|
+
new_graphs.append(graph_name)
|
|
225
|
+
|
|
226
|
+
# Check for state changes in existing graphs (modifications)
|
|
227
|
+
for graph_name, state in current_state.items():
|
|
228
|
+
if graph_name in self._last_graph_state:
|
|
229
|
+
last_state = self._last_graph_state[graph_name]
|
|
230
|
+
# Compare stable signature.
|
|
231
|
+
if state.get("signature") != last_state.get("signature"):
|
|
232
|
+
if graph_name not in self._removed_graphs:
|
|
233
|
+
new_graphs.append(graph_name)
|
|
234
|
+
|
|
235
|
+
# Update cached state
|
|
236
|
+
self._last_graph_state = current_state.copy()
|
|
237
|
+
|
|
238
|
+
return new_graphs
|
|
239
|
+
|
|
240
|
+
except (ImportError, RuntimeError, ValueError, AttributeError) as e:
|
|
241
|
+
# These are expected exceptions when SFI is not available or Stata state is inaccessible
|
|
242
|
+
logger.debug(f"Failed to detect graphs via pystata (expected): {e}")
|
|
243
|
+
return []
|
|
244
|
+
except Exception as e:
|
|
245
|
+
# Unexpected errors should be logged as errors
|
|
246
|
+
logger.error(f"Unexpected error in pystata graph detection: {e}")
|
|
247
|
+
return []
|
|
248
|
+
|
|
249
|
+
def _get_current_graphs_from_pystata(self) -> List[str]:
    """Get current list of graphs using pystata's sfi interface.

    Resolution order: client.list_graphs() -> bundled inventory fetch ->
    short sleep + retry of both -> raw `graph dir` via sfi.Macro when the
    client exposes no list_graphs. Returns [] on any failure.
    """
    try:
        # Use pystata to get graph list directly
        if self._stata_client and hasattr(self._stata_client, 'list_graphs'):
            graphs = self._stata_client.list_graphs(force_refresh=True)
            if graphs:
                return graphs
            # Fallback to inventory if list_graphs is empty
            try:
                inventory, _timestamps = self._get_graph_inventory(need_timestamps=False)
                if inventory:
                    return inventory
            except Exception:
                return []
            # Brief retry to allow graph registration to settle
            time.sleep(0.05)
            graphs = self._stata_client.list_graphs(force_refresh=True)
            if graphs:
                return graphs
            try:
                inventory, _timestamps = self._get_graph_inventory(need_timestamps=False)
                return inventory
            except Exception:
                return []
        else:
            # Fallback to sfi Macro interface - only if stata is available
            if self._stata_client and hasattr(self._stata_client, 'stata'):
                # Access the lock from client to prevent concurrency issues with pystata
                exec_lock = getattr(self._stata_client, "_exec_lock", None)
                ctx = exec_lock if exec_lock else contextlib.nullcontext()

                with ctx:
                    try:
                        from sfi import Macro
                        # Hold/restore r() so our graph dir does not clobber
                        # the caller's pending results.
                        hold_name = f"_mcp_det_{int(time.time() * 1000 % 1000000)}"
                        self._stata_client.stata.run(f"capture _return hold {hold_name}", echo=False)
                        try:
                            # Run graph dir quietly
                            self._stata_client.stata.run("quietly graph dir, memory", echo=False)
                            # Get r(list) DIRECTLY via SFI Macro interface to avoid parsing issues
                            # and syntax errors with empty results.
                            self._stata_client.stata.run("macro define mcp_detector_list `r(list)'", echo=False)
                            graph_list_str = Macro.getGlobal("mcp_detector_list")
                        finally:
                            self._stata_client.stata.run(f"capture _return restore {hold_name}", echo=False)

                        if not graph_list_str:
                            return []

                        # Handle quoted names from r(list) - Stata quotes names with spaces
                        # NOTE: shlex is already imported at module level; this
                        # local import is redundant but harmless.
                        import shlex
                        try:
                            return shlex.split(graph_list_str)
                        except Exception:
                            return graph_list_str.split()
                    except ImportError:
                        logger.warning("sfi.Macro not available for fallback graph detection")
                        return []
            else:
                return []
    except Exception as e:
        logger.warning(f"Failed to get current graphs from pystata: {e}")
        return []
|
|
313
|
+
|
|
314
|
+
def _get_graph_state_from_pystata(self) -> Dict[str, Any]:
    """Get detailed graph state information using pystata's sfi interface.

    Returns a mapping of graph name -> state dict with keys: name, exists,
    valid, signature, cmd_idx, timestamp_val (Stata-side timestamp or
    None) and timestamp (local wall-clock of the last signature change).
    Best-effort: failures yield a partial (possibly empty) mapping.
    """
    graph_state = {}

    try:
        # Combined fetch for both list and timestamps (1 round trip)
        current_graphs, timestamps = self._get_graph_inventory()
        cmd_idx = getattr(self._stata_client, "_command_idx", 0)

        for graph_name in current_graphs:
            try:
                # Signature logic:
                # Prefer stable timestamps across commands to avoid duplicate notifications.
                fast_sig = self._describe_graph_signature(graph_name)

                prev = self._last_graph_state.get(graph_name)
                timestamp = timestamps.get(graph_name)

                if prev and prev.get("cmd_idx") == cmd_idx:
                    # Already processed in this command context.
                    sig = prev.get("signature")
                elif timestamp:
                    # Use timestamp-stable signature across commands when available.
                    sig = f"{graph_name}_{timestamp}"
                else:
                    # Fallback to command-index-based signature.
                    sig = fast_sig

                state_info = {
                    "name": graph_name,
                    "exists": True,
                    "valid": bool(sig),
                    "signature": sig,
                    "cmd_idx": cmd_idx,
                    "timestamp_val": timestamp,
                }

                # Only update visual timestamps when the signature changes.
                if prev is None or prev.get("signature") != sig:
                    state_info["timestamp"] = time.time()
                else:
                    state_info["timestamp"] = prev.get("timestamp", time.time())

                graph_state[graph_name] = state_info

            except Exception as e:
                # Record the graph as present-but-unreadable so callers can
                # still see it in the state mapping.
                logger.warning(f"Failed to get state for graph {graph_name}: {e}")
                graph_state[graph_name] = {"name": graph_name, "timestamp": time.time(), "exists": False, "signature": "", "cmd_idx": cmd_idx}

    except Exception as e:
        logger.warning(f"Failed to get graph state from pystata: {e}")

    return graph_state
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def detect_graph_modifications(self, text: str = None) -> dict:
|
|
371
|
+
"""Detect graph modification/removal using SFI state comparison."""
|
|
372
|
+
modifications = {"dropped": [], "renamed": [], "cleared": False}
|
|
373
|
+
|
|
374
|
+
if not self._stata_client:
|
|
375
|
+
return modifications
|
|
376
|
+
|
|
377
|
+
try:
|
|
378
|
+
# Use the more sophisticated state retrieval that handles timestamp verification
|
|
379
|
+
new_state = self._get_graph_state_from_pystata()
|
|
380
|
+
current_graphs = set(new_state.keys())
|
|
381
|
+
|
|
382
|
+
# Compare with last known state to detect modifications
|
|
383
|
+
if self._last_graph_state:
|
|
384
|
+
last_graphs = set(self._last_graph_state.keys())
|
|
385
|
+
|
|
386
|
+
# Detect dropped graphs (in last state but not current)
|
|
387
|
+
dropped_graphs = last_graphs - current_graphs
|
|
388
|
+
modifications["dropped"].extend(dropped_graphs)
|
|
389
|
+
|
|
390
|
+
# Detect clear all (no graphs remain when there were some before)
|
|
391
|
+
if last_graphs and not current_graphs:
|
|
392
|
+
modifications["cleared"] = True
|
|
393
|
+
|
|
394
|
+
# Update last known state
|
|
395
|
+
self._last_graph_state = new_state
|
|
396
|
+
|
|
397
|
+
except Exception as e:
|
|
398
|
+
logger.debug(f"SFI modification detection failed: {e}")
|
|
399
|
+
|
|
400
|
+
return modifications
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
def should_cache_graph(self, graph_name: str) -> bool:
|
|
404
|
+
"""Determine if a graph should be cached."""
|
|
405
|
+
with self._lock:
|
|
406
|
+
# Don't cache if already detected or removed
|
|
407
|
+
if graph_name in self._detected_graphs or graph_name in self._removed_graphs:
|
|
408
|
+
return False
|
|
409
|
+
|
|
410
|
+
# Mark as detected
|
|
411
|
+
self._detected_graphs.add(graph_name)
|
|
412
|
+
return True
|
|
413
|
+
|
|
414
|
+
def mark_graph_removed(self, graph_name: str) -> None:
|
|
415
|
+
"""Mark a graph as removed."""
|
|
416
|
+
with self._lock:
|
|
417
|
+
self._removed_graphs.add(graph_name)
|
|
418
|
+
self._detected_graphs.discard(graph_name)
|
|
419
|
+
|
|
420
|
+
def mark_all_cleared(self) -> None:
|
|
421
|
+
"""Mark all graphs as cleared."""
|
|
422
|
+
with self._lock:
|
|
423
|
+
self._detected_graphs.clear()
|
|
424
|
+
self._removed_graphs.clear()
|
|
425
|
+
|
|
426
|
+
def clear_detection_state(self) -> None:
|
|
427
|
+
"""Clear all detection state."""
|
|
428
|
+
with self._lock:
|
|
429
|
+
self._detected_graphs.clear()
|
|
430
|
+
self._removed_graphs.clear()
|
|
431
|
+
self._unnamed_graph_counter = 0
|
|
432
|
+
|
|
433
|
+
def process_modifications(self, modifications: dict) -> None:
|
|
434
|
+
"""Process detected modifications."""
|
|
435
|
+
with self._lock:
|
|
436
|
+
# Handle dropped graphs
|
|
437
|
+
for graph_name in modifications.get("dropped", []):
|
|
438
|
+
self.mark_graph_removed(graph_name)
|
|
439
|
+
|
|
440
|
+
# Handle renamed graphs
|
|
441
|
+
for old_name, new_name in modifications.get("renamed", []):
|
|
442
|
+
self.mark_graph_removed(old_name)
|
|
443
|
+
self._detected_graphs.discard(new_name) # Allow re-detection with new name
|
|
444
|
+
|
|
445
|
+
# Handle clear all
|
|
446
|
+
if modifications.get("cleared", False):
|
|
447
|
+
self.mark_all_cleared()
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
class StreamingGraphCache:
    """Integrates graph detection with caching during streaming."""

    def __init__(self, stata_client, auto_cache: bool = False):
        self.stata_client = stata_client
        self.auto_cache = auto_cache
        # Reuse the client's long-lived detector when it has one so
        # detection state survives across streaming sessions; otherwise
        # build a private detector bound to this client.
        if not hasattr(stata_client, "_graph_detector"):
            self.detector = GraphCreationDetector(stata_client)
        else:
            self.detector = stata_client._graph_detector
        # Guards the queues, callback list and bookkeeping sets below.
        self._lock = threading.Lock()
        self._cache_callbacks: List[Callable[[str, bool], None]] = []
        # Names awaiting a cache attempt / names already cached.
        self._graphs_to_cache: List[str] = []
        self._cached_graphs: Set[str] = set()
        # Removed graphs tracked directly on the cache itself.
        self._removed_graphs = set()
        # Graph names captured before execution starts.
        self._initial_graphs: Set[str] = set()
|
|
467
|
+
|
|
468
|
+
def add_cache_callback(self, callback: Callable[[str, bool], None]) -> None:
|
|
469
|
+
"""Add callback for graph cache events."""
|
|
470
|
+
with self._lock:
|
|
471
|
+
self._cache_callbacks.append(callback)
|
|
472
|
+
|
|
473
|
+
async def _notify_cache_callbacks(self, graph_name: str, success: bool) -> None:
|
|
474
|
+
for callback in self._cache_callbacks:
|
|
475
|
+
try:
|
|
476
|
+
result = callback(graph_name, success)
|
|
477
|
+
if inspect.isawaitable(result):
|
|
478
|
+
await result
|
|
479
|
+
except Exception as e:
|
|
480
|
+
logger.warning(f"Cache callback failed for {graph_name}: {e}")
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
async def cache_detected_graphs_with_pystata(self) -> List[str]:
    """Enhanced caching method that uses pystata for real-time graph detection.

    Runs detection, queues any new graphs, then drains the queue: each
    queued graph that still exists is cached and every registered
    callback is notified of the outcome. Returns the names that were
    successfully cached. No-op when auto_cache is disabled.
    """
    if not self.auto_cache:
        return []

    cached_names = []

    # First, try to get any newly detected graphs via pystata state
    if self.stata_client:
        try:
            # Get current state and check for new graphs
            # _detect_graphs_via_pystata is sync and uses _exec_lock, must run in thread
            import anyio
            # Enable the detector's short-lived inventory cache only for
            # the duration of this detection pass.
            self.detector._inventory_cache_enabled = True
            try:
                pystata_detected = await anyio.to_thread.run_sync(self.detector._detect_graphs_via_pystata)
            finally:
                self.detector._inventory_cache_enabled = False

            # Add any newly detected graphs to cache queue
            for graph_name in pystata_detected:
                if graph_name not in self._cached_graphs and graph_name not in self._removed_graphs:
                    self._graphs_to_cache.append(graph_name)

        except Exception as e:
            logger.warning(f"Failed to get pystata graph updates: {e}")

    # Process the cache queue (snapshot under lock, then release before I/O).
    with self._lock:
        graphs_to_process = self._graphs_to_cache.copy()
        self._graphs_to_cache.clear()

    if not graphs_to_process:
        return cached_names

    # Get current graph list for verification
    try:
        # list_graphs is sync and uses _exec_lock, must run in thread
        import anyio
        current_graphs = await anyio.to_thread.run_sync(self.stata_client.list_graphs)
    except Exception as e:
        logger.warning(f"Failed to get current graph list: {e}")
        return cached_names

    for graph_name in graphs_to_process:
        # Only cache graphs that still exist and were not cached meanwhile.
        if graph_name in current_graphs and graph_name not in self._cached_graphs:
            try:
                success = await asyncio.to_thread(self.stata_client.cache_graph_on_creation, graph_name)
                if success:
                    cached_names.append(graph_name)
                    with self._lock:
                        self._cached_graphs.add(graph_name)

                # Notify callbacks
                await self._notify_cache_callbacks(graph_name, success)

            except Exception as e:
                logger.warning(f"Failed to cache graph {graph_name}: {e}")
                # Still notify callbacks of failure
                await self._notify_cache_callbacks(graph_name, False)

    return cached_names
|
|
545
|
+
|
|
546
|
+
async def cache_detected_graphs(self) -> List[str]:
    """Cache all detected graphs.

    Drains the pending queue without running a fresh detection pass
    (unlike cache_detected_graphs_with_pystata). Returns the names that
    were successfully cached; no-op when auto_cache is disabled.
    """
    if not self.auto_cache:
        return []

    cached_names = []

    # Snapshot the queue under the lock, then release before doing I/O.
    with self._lock:
        graphs_to_process = self._graphs_to_cache.copy()
        self._graphs_to_cache.clear()

    # Get current graph list for verification
    try:
        # list_graphs is sync and uses _exec_lock, must run in thread
        import anyio
        current_graphs = await anyio.to_thread.run_sync(self.stata_client.list_graphs)
    except Exception as e:
        logger.warning(f"Failed to get current graph list: {e}")
        return cached_names

    for graph_name in graphs_to_process:
        # Only cache graphs that still exist and were not cached meanwhile.
        if graph_name in current_graphs and graph_name not in self._cached_graphs:
            try:
                success = await asyncio.to_thread(self.stata_client.cache_graph_on_creation, graph_name)
                if success:
                    cached_names.append(graph_name)
                    with self._lock:
                        self._cached_graphs.add(graph_name)

                # Notify callbacks
                await self._notify_cache_callbacks(graph_name, success)

            except Exception as e:
                logger.warning(f"Failed to cache graph {graph_name}: {e}")
                # Still notify callbacks of failure
                await self._notify_cache_callbacks(graph_name, False)

    return cached_names
|
|
584
|
+
|
|
585
|
+
def get_cache_stats(self) -> dict:
|
|
586
|
+
"""Get caching statistics."""
|
|
587
|
+
with self._lock:
|
|
588
|
+
return {
|
|
589
|
+
"auto_cache_enabled": self.auto_cache,
|
|
590
|
+
"pending_cache_count": len(self._graphs_to_cache),
|
|
591
|
+
"cached_graphs_count": len(self._cached_graphs),
|
|
592
|
+
"detected_graphs_count": len(self.detector._detected_graphs),
|
|
593
|
+
"removed_graphs_count": len(self.detector._removed_graphs),
|
|
594
|
+
}
|
|
595
|
+
|
|
596
|
+
def reset(self) -> None:
|
|
597
|
+
"""Reset the cache state."""
|
|
598
|
+
with self._lock:
|
|
599
|
+
self._graphs_to_cache.clear()
|
|
600
|
+
self._cached_graphs.clear()
|
|
601
|
+
self.detector.clear_detection_state()
|
mcp_stata/models.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import List, Optional, Dict, Any
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ErrorEnvelope(BaseModel):
    """Structured description of a failed Stata command.

    Attached to ``CommandResponse.error``; only the fields that could be
    determined for a given failure are populated.
    """

    message: str  # Human-readable error text (always present).
    rc: Optional[int] = None  # Stata return code, when known.
    line: Optional[int] = None  # Line number the error occurred on — TODO confirm 1-based and what it indexes.
    command: Optional[str] = None  # The command associated with the failure.
    log_path: Optional[str] = None  # Path to the log file for this run, if any.
    context: Optional[str] = None  # Additional context about where the failure arose.
    stdout: Optional[str] = None  # Captured standard output, if any.
    stderr: Optional[str] = None  # Captured standard error, if any.
    snippet: Optional[str] = None  # Output excerpt around the failure — presumably; verify against producer.
    trace: Optional[bool] = None  # Trace flag — NOTE(review): likely whether Stata tracing was on; confirm.
    smcl_output: Optional[str] = None  # Raw SMCL-formatted output, when available.
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CommandResponse(BaseModel):
    """Result of executing a Stata command."""

    command: str  # The command that was executed.
    rc: int  # Stata return code (0 on success, by Stata convention — confirm).
    stdout: str  # Captured standard output.
    stderr: Optional[str] = None  # Captured standard error, if any.
    log_path: Optional[str] = None  # Path to the log file for this run, if any.
    success: bool  # Whether the command completed successfully.
    error: Optional[ErrorEnvelope] = None  # Structured error details when success is False.
    smcl_output: Optional[str] = None  # Raw SMCL-formatted output, when available.
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class DataResponse(BaseModel):
    """A page of dataset rows."""

    start: int  # Index of the first returned row — TODO confirm 0- vs 1-based.
    count: int  # Number of rows returned.
    data: List[Dict[str, Any]]  # Rows as variable-name -> value mappings.
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class VariableInfo(BaseModel):
    """Metadata for a single dataset variable."""

    name: str  # Variable name.
    label: Optional[str] = None  # Variable label, if set.
    type: Optional[str] = None  # Storage type — presumably Stata's (byte/int/float/str#); confirm.
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class VariablesResponse(BaseModel):
    """List of variables in the current dataset."""

    variables: List[VariableInfo]  # One entry per variable.
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class GraphInfo(BaseModel):
    """Metadata for a graph held in Stata memory."""

    name: str  # Graph name.
    active: bool = False  # Whether this is the currently active graph.
    created: Optional[str] = None  # Creation timestamp string; used by graph change detection for stable signatures.
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class GraphListResponse(BaseModel):
    """List of graphs currently in memory."""

    graphs: List[GraphInfo]  # One entry per graph.
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class GraphExport(BaseModel):
    """An exported graph and where it was written."""

    name: str  # Graph name.
    file_path: Optional[str] = None  # Path of the exported file; None if export failed — confirm.
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class GraphExportResponse(BaseModel):
    """Result of exporting one or more graphs."""

    graphs: List[GraphExport]  # One entry per export attempt.
|
|
64
|
+
|