mcp-stata 1.2.2__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff shows the published contents of two versions of the package as they appear in their public registry; it is provided for informational purposes only.

This version of mcp-stata has been flagged as potentially problematic.

@@ -0,0 +1,385 @@
+ """
+ Graph creation detection for streaming Stata output.
+
+ This module provides functionality to detect when graphs are created
+ during Stata command execution and automatically cache them.
+ """
+
+ import asyncio
+ import re
+ import threading
+ import time
+ from typing import List, Set, Callable, Dict, Any
+ import logging
+
+
+ # SFI is always available
+ SFI_AVAILABLE = True
+
+ logger = logging.getLogger(__name__)
+
+
+ class GraphCreationDetector:
+     """Detects graph creation using SFI-only detection with pystata integration."""
+
+     def __init__(self, stata_client=None):
+         # Reentrant lock: process_modifications() calls mark_graph_removed() /
+         # mark_all_cleared(), which re-acquire this same lock.
+         self._lock = threading.RLock()
+         self._detected_graphs: Set[str] = set()
+         self._removed_graphs: Set[str] = set()
+         self._unnamed_graph_counter = 0  # Track unnamed graphs for identification
+         self._stata_client = stata_client
+         self._last_graph_state: Dict[str, Any] = {}  # Track graph state changes
+
+     def _describe_graph_signature(self, graph_name: str) -> str:
+         """Return a stable signature for a graph.
+
+         We intentionally avoid using timestamps as the signature, since that makes
+         every poll look like a modification.
+         """
+         if not self._stata_client or not hasattr(self._stata_client, "stata"):
+             return ""
+         try:
+             # Capture output so we can hash it deterministically.
+             resp = self._stata_client.run_command_structured(f"graph describe {graph_name}", echo=False)
+             if resp.success and resp.stdout:
+                 return resp.stdout
+             if resp.error and resp.error.snippet:
+                 return resp.error.snippet
+         except Exception:
+             return ""
+         return ""
+
+     def _detect_graphs_via_pystata(self) -> List[str]:
+         """Detect newly created graphs using direct pystata state access."""
+         if not self._stata_client:
+             return []
+
+         try:
+             # Get current graph state using pystata's sfi interface
+             current_graphs = self._get_current_graphs_from_pystata()
+             current_state = self._get_graph_state_from_pystata()
+
+             # Compare with last known state to detect new graphs
+             new_graphs = []
+
+             # Check for new graph names
+             for graph_name in current_graphs:
+                 if graph_name not in self._last_graph_state and graph_name not in self._removed_graphs:
+                     new_graphs.append(graph_name)
+
+             # Check for state changes in existing graphs (modifications)
+             for graph_name, state in current_state.items():
+                 if graph_name in self._last_graph_state:
+                     last_state = self._last_graph_state[graph_name]
+                     # Compare stable signature only.
+                     if state.get("signature") != last_state.get("signature"):
+                         if graph_name not in self._removed_graphs:
+                             new_graphs.append(graph_name)
+
+             # Update cached state
+             self._last_graph_state = current_state.copy()
+
+             return new_graphs
+
+         except (ImportError, RuntimeError, ValueError, AttributeError) as e:
+             # These are expected exceptions when SFI is not available or Stata state is inaccessible
+             logger.debug(f"Failed to detect graphs via pystata (expected): {e}")
+             return []
+         except Exception as e:
+             # Unexpected errors should be logged as errors
+             logger.error(f"Unexpected error in pystata graph detection: {e}")
+             return []
+
+     def _get_current_graphs_from_pystata(self) -> List[str]:
+         """Get current list of graphs using pystata's sfi interface."""
+         try:
+             # Use pystata to get graph list directly
+             if self._stata_client and hasattr(self._stata_client, 'list_graphs'):
+                 return self._stata_client.list_graphs()
+             else:
+                 # Fallback to sfi Macro interface - only if stata is available
+                 if self._stata_client and hasattr(self._stata_client, 'stata'):
+                     try:
+                         from sfi import Macro
+                         self._stata_client.stata.run("quietly graph dir, memory")
+                         self._stata_client.stata.run("global mcp_graph_list `r(list)'")
+                         graph_list_str = Macro.getGlobal("mcp_graph_list")
+                         return graph_list_str.split() if graph_list_str else []
+                     except ImportError:
+                         logger.warning("sfi.Macro not available for fallback graph detection")
+                         return []
+                 else:
+                     return []
+         except Exception as e:
+             logger.warning(f"Failed to get current graphs from pystata: {e}")
+             return []
+
+     def _get_graph_state_from_pystata(self) -> Dict[str, Any]:
+         """Get detailed graph state information using pystata's sfi interface."""
+         graph_state = {}
+
+         try:
+             current_graphs = self._get_current_graphs_from_pystata()
+
+             for graph_name in current_graphs:
+                 try:
+                     signature = self._describe_graph_signature(graph_name)
+                     state_info = {
+                         "name": graph_name,
+                         "exists": True,
+                         "valid": bool(signature),
+                         "signature": signature,
+                     }
+
+                     # Only update timestamps when the signature changes.
+                     prev = self._last_graph_state.get(graph_name)
+                     if prev is None or prev.get("signature") != signature:
+                         state_info["timestamp"] = time.time()
+                     else:
+                         state_info["timestamp"] = prev.get("timestamp", time.time())
+
+                     graph_state[graph_name] = state_info
+
+                 except Exception as e:
+                     logger.warning(f"Failed to get state for graph {graph_name}: {e}")
+                     graph_state[graph_name] = {"name": graph_name, "timestamp": time.time(), "exists": False, "signature": ""}
+
+         except Exception as e:
+             logger.warning(f"Failed to get graph state from pystata: {e}")
+
+         return graph_state
+
+     def detect_graph_modifications(self, text: str = None) -> dict:
+         """Detect graph modification/removal using SFI state comparison."""
+         modifications = {"dropped": [], "renamed": [], "cleared": False}
+
+         if not self._stata_client:
+             return modifications
+
+         try:
+             # Get current graph state via SFI
+             current_graphs = set(self._get_current_graphs_from_pystata())
+
+             # Compare with last known state to detect modifications
+             if self._last_graph_state:
+                 last_graphs = set(self._last_graph_state.keys())
+
+                 # Detect dropped graphs (in last state but not current)
+                 dropped_graphs = last_graphs - current_graphs
+                 modifications["dropped"].extend(dropped_graphs)
+
+                 # Detect clear all (no graphs remain when there were some before)
+                 if last_graphs and not current_graphs:
+                     modifications["cleared"] = True
+
+             # Update last known state for next comparison (stable signatures)
+             new_state: Dict[str, Any] = {}
+             for graph in current_graphs:
+                 sig = self._describe_graph_signature(graph)
+                 new_state[graph] = {
+                     "name": graph,
+                     "exists": True,
+                     "valid": bool(sig),
+                     "signature": sig,
+                     "timestamp": time.time(),
+                 }
+             self._last_graph_state = new_state
+
+         except Exception as e:
+             logger.debug(f"SFI modification detection failed: {e}")
+
+         return modifications
+
+     def should_cache_graph(self, graph_name: str) -> bool:
+         """Determine if a graph should be cached."""
+         with self._lock:
+             # Don't cache if already detected or removed
+             if graph_name in self._detected_graphs or graph_name in self._removed_graphs:
+                 return False
+
+             # Mark as detected
+             self._detected_graphs.add(graph_name)
+             return True
+
+     def mark_graph_removed(self, graph_name: str) -> None:
+         """Mark a graph as removed."""
+         with self._lock:
+             self._removed_graphs.add(graph_name)
+             self._detected_graphs.discard(graph_name)
+
+     def mark_all_cleared(self) -> None:
+         """Mark all graphs as cleared."""
+         with self._lock:
+             self._detected_graphs.clear()
+             self._removed_graphs.clear()
+
+     def clear_detection_state(self) -> None:
+         """Clear all detection state."""
+         with self._lock:
+             self._detected_graphs.clear()
+             self._removed_graphs.clear()
+             self._unnamed_graph_counter = 0
+
+     def process_modifications(self, modifications: dict) -> None:
+         """Process detected modifications."""
+         # mark_graph_removed()/mark_all_cleared() re-acquire self._lock,
+         # which is safe because the lock is reentrant.
+         with self._lock:
+             # Handle dropped graphs
+             for graph_name in modifications.get("dropped", []):
+                 self.mark_graph_removed(graph_name)
+
+             # Handle renamed graphs
+             for old_name, new_name in modifications.get("renamed", []):
+                 self.mark_graph_removed(old_name)
+                 self._detected_graphs.discard(new_name)  # Allow re-detection with new name
+
+             # Handle clear all
+             if modifications.get("cleared", False):
+                 self.mark_all_cleared()
+
+
+ class StreamingGraphCache:
+     """Integrates graph detection with caching during streaming."""
+
+     def __init__(self, stata_client, auto_cache: bool = False):
+         self.stata_client = stata_client
+         self.auto_cache = auto_cache
+         self.detector = GraphCreationDetector(stata_client)
+         self._lock = threading.Lock()
+         self._cache_callbacks: List[Callable[[str, bool], None]] = []
+         self._graphs_to_cache: List[str] = []
+         self._cached_graphs: Set[str] = set()
+         self._removed_graphs = set()  # Track removed graphs directly
+         self._initial_graphs: Set[str] = set()  # Captured before execution starts
+
+     def add_cache_callback(self, callback: Callable[[str, bool], None]) -> None:
+         """Add callback for graph cache events."""
+         with self._lock:
+             self._cache_callbacks.append(callback)
+
+     async def cache_detected_graphs_with_pystata(self) -> List[str]:
+         """Enhanced caching method that uses pystata for real-time graph detection."""
+         if not self.auto_cache:
+             return []
+
+         cached_names = []
+
+         # First, try to get any newly detected graphs via pystata state
+         if self.stata_client:
+             try:
+                 # Get current state and check for new graphs
+                 pystata_detected = self.detector._detect_graphs_via_pystata()
+
+                 # Add any newly detected graphs to cache queue
+                 for graph_name in pystata_detected:
+                     if graph_name not in self._cached_graphs and graph_name not in self._removed_graphs:
+                         self._graphs_to_cache.append(graph_name)
+
+             except Exception as e:
+                 logger.warning(f"Failed to get pystata graph updates: {e}")
+
+         # Process the cache queue
+         with self._lock:
+             graphs_to_process = self._graphs_to_cache.copy()
+             self._graphs_to_cache.clear()
+
+         # Get current graph list for verification
+         try:
+             current_graphs = self.stata_client.list_graphs()
+         except Exception as e:
+             logger.warning(f"Failed to get current graph list: {e}")
+             return cached_names
+
+         for graph_name in graphs_to_process:
+             if graph_name in current_graphs and graph_name not in self._cached_graphs:
+                 try:
+                     success = await asyncio.to_thread(self.stata_client.cache_graph_on_creation, graph_name)
+                     if success:
+                         cached_names.append(graph_name)
+                         with self._lock:
+                             self._cached_graphs.add(graph_name)
+
+                     # Notify callbacks
+                     for callback in self._cache_callbacks:
+                         try:
+                             callback(graph_name, success)
+                         except Exception as e:
+                             logger.warning(f"Cache callback failed for {graph_name}: {e}")
+
+                 except Exception as e:
+                     logger.warning(f"Failed to cache graph {graph_name}: {e}")
+                     # Still notify callbacks of failure
+                     for callback in self._cache_callbacks:
+                         try:
+                             callback(graph_name, False)
+                         except Exception:
+                             pass
+
+         return cached_names
+
+     async def cache_detected_graphs(self) -> List[str]:
+         """Cache all detected graphs."""
+         if not self.auto_cache:
+             return []
+
+         cached_names = []
+
+         with self._lock:
+             graphs_to_process = self._graphs_to_cache.copy()
+             self._graphs_to_cache.clear()
+
+         # Get current graph list for verification
+         try:
+             current_graphs = self.stata_client.list_graphs()
+         except Exception as e:
+             logger.warning(f"Failed to get current graph list: {e}")
+             return cached_names
+
+         for graph_name in graphs_to_process:
+             if graph_name in current_graphs and graph_name not in self._cached_graphs:
+                 try:
+                     success = await asyncio.to_thread(self.stata_client.cache_graph_on_creation, graph_name)
+                     if success:
+                         cached_names.append(graph_name)
+                         with self._lock:
+                             self._cached_graphs.add(graph_name)
+
+                     # Notify callbacks
+                     for callback in self._cache_callbacks:
+                         try:
+                             callback(graph_name, success)
+                         except Exception as e:
+                             logger.warning(f"Cache callback failed for {graph_name}: {e}")
+
+                 except Exception as e:
+                     logger.warning(f"Failed to cache graph {graph_name}: {e}")
+                     # Still notify callbacks of failure
+                     for callback in self._cache_callbacks:
+                         try:
+                             callback(graph_name, False)
+                         except Exception:
+                             pass
+
+         return cached_names
+
+     def get_cache_stats(self) -> dict:
+         """Get caching statistics."""
+         with self._lock:
+             return {
+                 "auto_cache_enabled": self.auto_cache,
+                 "pending_cache_count": len(self._graphs_to_cache),
+                 "cached_graphs_count": len(self._cached_graphs),
+                 "detected_graphs_count": len(self.detector._detected_graphs),
+                 "removed_graphs_count": len(self.detector._removed_graphs),
+             }
+
+     def reset(self) -> None:
+         """Reset the cache state."""
+         with self._lock:
+             self._graphs_to_cache.clear()
+             self._cached_graphs.clear()
+             self.detector.clear_detection_state()
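For orientation, here is a minimal sketch of how the new `StreamingGraphCache` might be driven. The module path `mcp_stata.graph_detection` and the stub client are assumptions made for this example; only `list_graphs()` and `cache_graph_on_creation()` come from the code above, which expects them on whatever client object it is given.

```python
# Sketch only: drives the streaming cache against a stand-in client.
import asyncio

from mcp_stata.graph_detection import StreamingGraphCache  # module path assumed; the new file's name is not shown in this diff


class StataClientStub:
    """Hypothetical stand-in exposing the two methods the cache relies on."""

    def list_graphs(self):
        # A real client would return the names reported by `graph dir, memory`.
        return ["Graph", "scatter1"]

    def cache_graph_on_creation(self, name: str) -> bool:
        print(f"caching {name}")
        return True


async def main() -> None:
    client = StataClientStub()
    cache = StreamingGraphCache(client, auto_cache=True)

    # Report each cache attempt as it happens.
    cache.add_cache_callback(lambda name, ok: print(f"{name}: {'cached' if ok else 'failed'}"))

    # In the server this would presumably be polled while command output streams;
    # a single call is enough to show the detect -> queue -> cache flow.
    cached = await cache.cache_detected_graphs_with_pystata()
    print("cached graphs:", cached)
    print(cache.get_cache_stats())


if __name__ == "__main__":
    asyncio.run(main())
```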
mcp_stata/models.py CHANGED
@@ -7,6 +7,7 @@ class ErrorEnvelope(BaseModel):
      rc: Optional[int] = None
      line: Optional[int] = None
      command: Optional[str] = None
+     log_path: Optional[str] = None
      stdout: Optional[str] = None
      stderr: Optional[str] = None
      snippet: Optional[str] = None
@@ -18,6 +19,7 @@ class CommandResponse(BaseModel):
      rc: int
      stdout: str
      stderr: Optional[str] = None
+     log_path: Optional[str] = None
      success: bool
      error: Optional[ErrorEnvelope] = None
 
@@ -49,7 +51,8 @@ class GraphListResponse(BaseModel):
 
  class GraphExport(BaseModel):
      name: str
-     image_base64: str
+     file_path: Optional[str] = None
+     image_base64: Optional[str] = None
 
 
  class GraphExportResponse(BaseModel):
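The models.py changes loosen the graph payload: `GraphExport.image_base64` is no longer required and a `file_path` alternative is added, while `ErrorEnvelope` and `CommandResponse` gain an optional `log_path`. A hedged sketch of how a consumer might handle either shape follows; the helper function is illustrative, not part of mcp-stata.

```python
# Sketch only: a GraphExport may now carry a file path, an inline base64
# payload, or both, so clients should not assume image_base64 is set.
import base64
from pathlib import Path

from mcp_stata.models import GraphExport


def export_image_bytes(export: GraphExport) -> bytes:
    """Return the image bytes regardless of which field was populated."""
    if export.image_base64:
        return base64.b64decode(export.image_base64)
    if export.file_path:
        return Path(export.file_path).read_bytes()
    raise ValueError(f"graph {export.name!r} has neither file_path nor image_base64")


# Both forms now validate; older clients that required image_base64
# should be prepared for the path-only variant.
inline = GraphExport(name="scatter1", image_base64=base64.b64encode(b"...").decode())
on_disk = GraphExport(name="scatter1", file_path="/tmp/scatter1.png")
```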