mcp-stata 1.20.0__cp311-abi3-macosx_11_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcp-stata might be problematic. Click here for more details.

@@ -0,0 +1,468 @@
1
+ from __future__ import annotations
2
+ """
3
+ Graph creation detection for streaming Stata output.
4
+
5
+ This module provides functionality to detect when graphs are created
6
+ during Stata command execution and automatically cache them.
7
+ """
8
+
9
+ import asyncio
10
+ import contextlib
11
+ import inspect
12
+ import re
13
+ import threading
14
+ import time
15
+ from typing import List, Set, Callable, Dict, Any
16
+ import logging
17
+
18
+
19
+ # SFI is always available
20
+ SFI_AVAILABLE = True
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class GraphCreationDetector:
    """Detects graph creation using SFI-only detection with pystata integration.

    Shared detection state is guarded by a *reentrant* lock so that public
    methods may safely call one another while holding it (see
    ``process_modifications``).
    """

    def __init__(self, stata_client=None):
        # RLock, not Lock: process_modifications() acquires the lock and then
        # calls mark_graph_removed()/mark_all_cleared(), which acquire it
        # again.  With a plain (non-reentrant) Lock that nested acquisition
        # deadlocks the calling thread.
        self._lock = threading.RLock()
        self._detected_graphs: Set[str] = set()   # graphs already reported for caching
        self._removed_graphs: Set[str] = set()    # graphs known to have been dropped
        self._unnamed_graph_counter = 0  # Track unnamed graphs for identification
        self._stata_client = stata_client
        self._last_graph_state: Dict[str, Any] = {}  # Track graph state changes

    def _describe_graph_signature(self, graph_name: str) -> str:
        """Return a stable signature for a graph.

        We avoid using Stata calls like 'graph describe' here because they are slow
        (each call takes ~35ms) and would be called for every graph on every poll,
        bottlenecking the streaming output.

        Instead, we use name-based tracking tied to the Stata command execution
        context. The signature is stable within a single command execution but
        changes when a new command starts, allowing us to detect modifications
        between commands without any Stata overhead.
        """
        if not self._stata_client:
            return ""

        # Access command_idx from stata_client if available
        cmd_idx = getattr(self._stata_client, "_command_idx", 0)
        # Always include command_idx to ensure modifications to named graphs are
        # detected when a new command starts. Within a single command, the
        # detected_graphs set in the cache will prevent duplicates.
        return f"{graph_name}_{cmd_idx}"

    def _get_graph_timestamp(self, graph_name: str) -> str:
        """Get the creation/modification timestamp of a graph using graph describe.

        Returns "" when the timestamp cannot be read.
        """
        results = self._get_graph_timestamps([graph_name])
        return results.get(graph_name, "")

    def _get_graph_timestamps(self, graph_names: List[str]) -> Dict[str, str]:
        """Get timestamps for multiple graphs in a single Stata call to minimize overhead.

        Returns a mapping of graph name -> "date_time" string built from
        r(command_date)/r(command_time); names whose timestamp could not be
        read are omitted.  Any failure is logged at debug level and yields an
        empty mapping (best-effort semantics).
        """
        if not graph_names or not self._stata_client or not hasattr(self._stata_client, "stata"):
            return {}

        try:
            # Use the lock from client to prevent concurrency issues with pystata
            exec_lock = getattr(self._stata_client, "_exec_lock", None)
            ctx = exec_lock if exec_lock else contextlib.nullcontext()

            with ctx:
                # Preserve the caller's r() results around our 'graph describe' calls.
                hold_name = f"_mcp_detector_thold_{int(time.time() * 1000 % 1000000)}"
                self._stata_client.stata.run(f"capture _return hold {hold_name}", echo=False)
                try:
                    # Build a single Stata command to fetch all timestamps
                    stata_cmd = ""
                    for i, name in enumerate(graph_names):
                        resolved = self._stata_client._resolve_graph_name_for_stata(name)
                        stata_cmd += f"quietly graph describe {resolved}\n"
                        stata_cmd += f"macro define mcp_ts_{i} \"`r(command_date)'_`r(command_time)'\"\n"

                    self._stata_client.stata.run(stata_cmd, echo=False)

                    from sfi import Macro
                    results = {}
                    for i, name in enumerate(graph_names):
                        ts = Macro.getGlobal(f"mcp_ts_{i}")
                        if ts:
                            results[name] = ts
                    return results
                finally:
                    self._stata_client.stata.run(f"capture _return restore {hold_name}", echo=False)
        except Exception as e:
            logger.debug(f"Failed to get timestamps: {e}")
            return {}

    def _detect_graphs_via_pystata(self) -> List[str]:
        """Detect newly created graphs using direct pystata state access.

        Compares the current graph names/signatures against the cached
        ``_last_graph_state`` and returns names that are new or whose
        signature changed.  Returns [] when no client is set or on failure.
        """
        if not self._stata_client:
            return []

        with self._lock:
            try:
                # Get current graph state using pystata's sfi interface
                current_graphs = self._get_current_graphs_from_pystata()
                current_state = self._get_graph_state_from_pystata()

                # Compare with last known state to detect new graphs
                new_graphs = []

                # Check for new graph names
                for graph_name in current_graphs:
                    if graph_name not in self._last_graph_state and graph_name not in self._removed_graphs:
                        new_graphs.append(graph_name)

                # Check for state changes in existing graphs (modifications)
                for graph_name, state in current_state.items():
                    if graph_name in self._last_graph_state:
                        last_state = self._last_graph_state[graph_name]
                        # Compare stable signature.
                        if state.get("signature") != last_state.get("signature"):
                            if graph_name not in self._removed_graphs:
                                new_graphs.append(graph_name)

                # Update cached state
                self._last_graph_state = current_state.copy()

                return new_graphs

            except (ImportError, RuntimeError, ValueError, AttributeError) as e:
                # These are expected exceptions when SFI is not available or Stata state is inaccessible
                logger.debug(f"Failed to detect graphs via pystata (expected): {e}")
                return []
            except Exception as e:
                # Unexpected errors should be logged as errors
                logger.error(f"Unexpected error in pystata graph detection: {e}")
                return []

    def _get_current_graphs_from_pystata(self) -> List[str]:
        """Get current list of graphs using pystata's sfi interface.

        Prefers the client's ``list_graphs()`` helper; otherwise falls back to
        running 'graph dir, memory' and reading r(list) through a global
        macro.  Returns [] on any failure (best-effort).
        """
        try:
            # Use pystata to get graph list directly
            if self._stata_client and hasattr(self._stata_client, 'list_graphs'):
                return self._stata_client.list_graphs(force_refresh=True)
            else:
                # Fallback to sfi Macro interface - only if stata is available
                if self._stata_client and hasattr(self._stata_client, 'stata'):
                    # Access the lock from client to prevent concurrency issues with pystata
                    exec_lock = getattr(self._stata_client, "_exec_lock", None)
                    ctx = exec_lock if exec_lock else contextlib.nullcontext()

                    with ctx:
                        try:
                            from sfi import Macro
                            # Preserve caller's r() results around 'graph dir'.
                            hold_name = f"_mcp_detector_hold_{int(time.time() * 1000 % 1000000)}"
                            self._stata_client.stata.run(f"capture _return hold {hold_name}", echo=False)
                            try:
                                self._stata_client.stata.run("macro define mcp_graph_list \"\"", echo=False)
                                self._stata_client.stata.run("quietly graph dir, memory", echo=False)
                                self._stata_client.stata.run("macro define mcp_graph_list `r(list)'", echo=False)
                                graph_list_str = Macro.getGlobal("mcp_graph_list")
                            finally:
                                self._stata_client.stata.run(f"capture _return restore {hold_name}", echo=False)
                            return graph_list_str.split() if graph_list_str else []
                        except ImportError:
                            logger.warning("sfi.Macro not available for fallback graph detection")
                            return []
                else:
                    return []
        except Exception as e:
            logger.warning(f"Failed to get current graphs from pystata: {e}")
            return []

    def _get_graph_state_from_pystata(self) -> Dict[str, Any]:
        """Get detailed graph state information using pystata's sfi interface.

        For each in-memory graph, builds a dict with name/exists/valid/
        signature/cmd_idx/timestamp fields.  Timestamps are fetched in one
        batched Stata call and used to distinguish a genuine modification
        from a mere command-index jump.
        """
        graph_state = {}

        try:
            current_graphs = self._get_current_graphs_from_pystata()
            cmd_idx = getattr(self._stata_client, "_command_idx", 0)

            # PRE-FETCH: Get timestamps for all current graphs in a single batch
            # to minimize Stata-Python boundary crossings.
            timestamps = self._get_graph_timestamps(current_graphs)

            for graph_name in current_graphs:
                try:
                    # Signature logic:
                    # 1. Start with name+cmd_idx (fast)
                    fast_sig = self._describe_graph_signature(graph_name)

                    # 2. If it's a new command for this graph, verify with timestamp
                    prev = self._last_graph_state.get(graph_name)
                    timestamp = timestamps.get(graph_name)
                    sig = fast_sig

                    if prev and prev.get("cmd_idx") == cmd_idx:
                        # Already processed in this command. Keep the signature we decided on.
                        sig = prev.get("signature")
                    elif prev and prev.get("cmd_idx") != cmd_idx:
                        # Command jumped. We need to know if it's a REAL modification.
                        # We use the timestamp from graph describe.
                        if timestamp:
                            prev_ts = prev.get("timestamp_val")
                            if prev_ts and prev_ts == timestamp:
                                # Timestamp match! Reuse the OLD signature to avoid
                                # reporting a modification just because cmd_idx jumped.
                                sig = prev.get("signature")
                            else:
                                # Timestamp changed or was missing. Use new fast_sig.
                                sig = fast_sig
                        else:
                            # Failed to get timestamp, fall back to fast_sig (safe default)
                            sig = fast_sig

                    state_info = {
                        "name": graph_name,
                        "exists": True,
                        "valid": bool(sig),
                        "signature": sig,
                        "cmd_idx": cmd_idx,
                        "timestamp_val": timestamp,
                    }

                    # Only update visual timestamps when the signature changes.
                    if prev is None or prev.get("signature") != sig:
                        state_info["timestamp"] = time.time()
                    else:
                        state_info["timestamp"] = prev.get("timestamp", time.time())

                    graph_state[graph_name] = state_info

                except Exception as e:
                    logger.warning(f"Failed to get state for graph {graph_name}: {e}")
                    graph_state[graph_name] = {"name": graph_name, "timestamp": time.time(), "exists": False, "signature": "", "cmd_idx": cmd_idx}

        except Exception as e:
            logger.warning(f"Failed to get graph state from pystata: {e}")

        return graph_state

    def detect_graph_modifications(self, text: str | None = None) -> dict:
        """Detect graph modification/removal using SFI state comparison.

        Args:
            text: Unused; kept for backward compatibility with existing callers.

        Returns:
            dict with keys "dropped" (list of graph names), "renamed"
            (list of pairs; not populated by this detector) and "cleared"
            (True when all previously-known graphs disappeared at once).
        """
        modifications = {"dropped": [], "renamed": [], "cleared": False}

        if not self._stata_client:
            return modifications

        # Hold the (reentrant) state lock across the compare-and-swap of
        # _last_graph_state so a concurrent _detect_graphs_via_pystata()
        # cannot interleave with us.
        with self._lock:
            try:
                # Use the more sophisticated state retrieval that handles timestamp verification
                new_state = self._get_graph_state_from_pystata()
                current_graphs = set(new_state.keys())

                # Compare with last known state to detect modifications
                if self._last_graph_state:
                    last_graphs = set(self._last_graph_state.keys())

                    # Detect dropped graphs (in last state but not current)
                    dropped_graphs = last_graphs - current_graphs
                    modifications["dropped"].extend(dropped_graphs)

                    # Detect clear all (no graphs remain when there were some before)
                    if last_graphs and not current_graphs:
                        modifications["cleared"] = True

                # Update last known state
                self._last_graph_state = new_state

            except Exception as e:
                logger.debug(f"SFI modification detection failed: {e}")

        return modifications

    def should_cache_graph(self, graph_name: str) -> bool:
        """Determine if a graph should be cached.

        Returns True exactly once per graph name (marks it detected);
        False for graphs already detected or known to be removed.
        """
        with self._lock:
            # Don't cache if already detected or removed
            if graph_name in self._detected_graphs or graph_name in self._removed_graphs:
                return False

            # Mark as detected
            self._detected_graphs.add(graph_name)
            return True

    def mark_graph_removed(self, graph_name: str) -> None:
        """Mark a graph as removed so it is not (re-)cached."""
        with self._lock:
            self._removed_graphs.add(graph_name)
            self._detected_graphs.discard(graph_name)

    def mark_all_cleared(self) -> None:
        """Mark all graphs as cleared (forget both detected and removed sets)."""
        with self._lock:
            self._detected_graphs.clear()
            self._removed_graphs.clear()

    def clear_detection_state(self) -> None:
        """Clear all detection state, including the unnamed-graph counter."""
        with self._lock:
            self._detected_graphs.clear()
            self._removed_graphs.clear()
            self._unnamed_graph_counter = 0

    def process_modifications(self, modifications: dict) -> None:
        """Process detected modifications (drops, renames, clear-all).

        Holds the state lock for the whole batch; the nested calls to
        mark_graph_removed()/mark_all_cleared() re-acquire it, which is safe
        because the lock is reentrant (see __init__).
        """
        with self._lock:
            # Handle dropped graphs
            for graph_name in modifications.get("dropped", []):
                self.mark_graph_removed(graph_name)

            # Handle renamed graphs
            for old_name, new_name in modifications.get("renamed", []):
                self.mark_graph_removed(old_name)
                self._detected_graphs.discard(new_name)  # Allow re-detection with new name

            # Handle clear all
            if modifications.get("cleared", False):
                self.mark_all_cleared()
327
+
328
+
329
class StreamingGraphCache:
    """Integrates graph detection with caching during streaming.

    Queue state (``_graphs_to_cache``, ``_cached_graphs``) is guarded by
    ``_lock``; both public caching entry points drain the queue through the
    shared ``_drain_cache_queue`` helper.
    """

    def __init__(self, stata_client, auto_cache: bool = False):
        self.stata_client = stata_client
        self.auto_cache = auto_cache
        # Use persistent detector from client if available, else create local one
        if hasattr(stata_client, "_graph_detector"):
            self.detector = stata_client._graph_detector
        else:
            self.detector = GraphCreationDetector(stata_client)
        self._lock = threading.Lock()
        self._cache_callbacks: List[Callable[[str, bool], None]] = []
        self._graphs_to_cache: List[str] = []
        self._cached_graphs: Set[str] = set()
        self._removed_graphs: Set[str] = set()  # Track removed graphs directly
        self._initial_graphs: Set[str] = set()  # Captured before execution starts

    def add_cache_callback(self, callback: Callable[[str, bool], None]) -> None:
        """Add callback for graph cache events (called with (name, success))."""
        with self._lock:
            self._cache_callbacks.append(callback)

    async def _notify_cache_callbacks(self, graph_name: str, success: bool) -> None:
        """Invoke every registered callback; awaits awaitable results.

        Callback failures are logged and never propagate.
        """
        for callback in self._cache_callbacks:
            try:
                result = callback(graph_name, success)
                if inspect.isawaitable(result):
                    await result
            except Exception as e:
                logger.warning(f"Cache callback failed for {graph_name}: {e}")

    async def _drain_cache_queue(self) -> List[str]:
        """Drain the pending queue and cache each graph that still exists.

        Shared by both public caching entry points.  Returns the names that
        were cached successfully; callbacks are notified for every attempt
        (success or failure).
        """
        cached_names: List[str] = []

        with self._lock:
            graphs_to_process = self._graphs_to_cache.copy()
            self._graphs_to_cache.clear()

        # Get current graph list for verification
        try:
            current_graphs = self.stata_client.list_graphs()
        except Exception as e:
            logger.warning(f"Failed to get current graph list: {e}")
            return cached_names

        for graph_name in graphs_to_process:
            if graph_name in current_graphs and graph_name not in self._cached_graphs:
                try:
                    # Run the (blocking) cache operation off the event loop.
                    success = await asyncio.to_thread(self.stata_client.cache_graph_on_creation, graph_name)
                    if success:
                        cached_names.append(graph_name)
                        with self._lock:
                            self._cached_graphs.add(graph_name)

                    # Notify callbacks
                    await self._notify_cache_callbacks(graph_name, success)

                except Exception as e:
                    logger.warning(f"Failed to cache graph {graph_name}: {e}")
                    # Still notify callbacks of failure
                    await self._notify_cache_callbacks(graph_name, False)

        return cached_names

    async def cache_detected_graphs_with_pystata(self) -> List[str]:
        """Enhanced caching method that uses pystata for real-time graph detection.

        Queries the detector for newly created/modified graphs, enqueues
        them, then drains the queue.  No-op (returns []) when auto-caching
        is disabled.
        """
        if not self.auto_cache:
            return []

        # First, try to get any newly detected graphs via pystata state
        if self.stata_client:
            try:
                pystata_detected = self.detector._detect_graphs_via_pystata()

                # Enqueue under the lock: the drain path mutates the same
                # list concurrently, so unsynchronized appends would race.
                with self._lock:
                    for graph_name in pystata_detected:
                        if graph_name not in self._cached_graphs and graph_name not in self._removed_graphs:
                            self._graphs_to_cache.append(graph_name)

            except Exception as e:
                logger.warning(f"Failed to get pystata graph updates: {e}")

        return await self._drain_cache_queue()

    async def cache_detected_graphs(self) -> List[str]:
        """Cache all graphs currently queued for caching.

        No-op (returns []) when auto-caching is disabled.
        """
        if not self.auto_cache:
            return []

        return await self._drain_cache_queue()

    def get_cache_stats(self) -> dict:
        """Get caching statistics as a plain dict (counts only)."""
        with self._lock:
            return {
                "auto_cache_enabled": self.auto_cache,
                "pending_cache_count": len(self._graphs_to_cache),
                "cached_graphs_count": len(self._cached_graphs),
                "detected_graphs_count": len(self.detector._detected_graphs),
                "removed_graphs_count": len(self.detector._removed_graphs),
            }

    def reset(self) -> None:
        """Reset the cache state and the underlying detector."""
        with self._lock:
            self._graphs_to_cache.clear()
            self._cached_graphs.clear()
            self.detector.clear_detection_state()
mcp_stata/models.py ADDED
@@ -0,0 +1,63 @@
1
+ from __future__ import annotations
2
+ from typing import List, Optional, Dict, Any
3
+ from pydantic import BaseModel
4
+
5
+
6
class ErrorEnvelope(BaseModel):
    """Structured details for a failed Stata command.

    Only ``message`` is required; the remaining fields are optional because
    the amount of context available depends on where the failure surfaced.
    """

    message: str  # human-readable error description
    rc: Optional[int] = None  # Stata return code, when known
    line: Optional[int] = None  # line number associated with the error, if known
    command: Optional[str] = None  # the command text that failed
    log_path: Optional[str] = None  # path to the log file covering this run
    context: Optional[str] = None  # surrounding context for the error
    stdout: Optional[str] = None  # captured standard output, if any
    stderr: Optional[str] = None  # captured standard error, if any
    snippet: Optional[str] = None  # short excerpt near the failure — presumably from the log; confirm with producer
    trace: Optional[bool] = None  # whether trace output was enabled/included — TODO confirm semantics
    smcl_output: Optional[str] = None  # raw SMCL-formatted output, if retained
18
+
19
+
20
class CommandResponse(BaseModel):
    """Result of executing a Stata command."""

    command: str  # the command that was executed
    rc: int  # Stata return code
    stdout: str  # captured standard output
    stderr: Optional[str] = None  # captured standard error, if any
    log_path: Optional[str] = None  # path to the log file for this run, if any
    success: bool  # overall success flag
    error: Optional[ErrorEnvelope] = None  # populated on failure with structured details
    smcl_output: Optional[str] = None  # raw SMCL-formatted output, if retained
29
+
30
+
31
class DataResponse(BaseModel):
    """A page of dataset rows."""

    start: int  # offset of the first returned row
    count: int  # number of rows returned
    data: List[Dict[str, Any]]  # one dict per row — presumably keyed by variable name; confirm with producer
35
+
36
+
37
class VariableInfo(BaseModel):
    """Metadata for a single dataset variable."""

    name: str  # variable name
    label: Optional[str] = None  # variable label, if defined
    type: Optional[str] = None  # storage type, if known
41
+
42
+
43
class VariablesResponse(BaseModel):
    """The variables present in the current dataset."""

    variables: List[VariableInfo]
45
+
46
+
47
class GraphInfo(BaseModel):
    """A single named Stata graph."""

    name: str  # graph name
    active: bool = False  # presumably marks the currently-active graph — verify against producer
50
+
51
+
52
class GraphListResponse(BaseModel):
    """The list of known graphs."""

    graphs: List[GraphInfo]
54
+
55
+
56
class GraphExport(BaseModel):
    """An exported graph and where it was written."""

    name: str  # graph name
    file_path: Optional[str] = None  # path of the exported file; None if export did not produce one
59
+
60
+
61
class GraphExportResponse(BaseModel):
    """The result of exporting one or more graphs."""

    graphs: List[GraphExport]
63
+
@@ -0,0 +1,87 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Iterable, Any, Tuple
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ try:
9
+ from mcp_stata import _native_ops as _native
10
+ except Exception: # pragma: no cover - optional module
11
+ _native = None
12
+
13
+
14
+ def argsort_numeric(
15
+ columns: Iterable["numpy.ndarray"],
16
+ descending: list[bool],
17
+ nulls_last: list[bool],
18
+ ) -> list[int] | None:
19
+ if _native is None:
20
+ return None
21
+ cols = list(columns)
22
+ if not cols:
23
+ return []
24
+ try:
25
+ return _native.argsort_numeric(cols, descending, nulls_last)
26
+ except Exception as e:
27
+ logger.warning(f"Native numeric sort failed: {e}")
28
+ return None
29
+
30
+
31
+ def argsort_mixed(
32
+ columns: Iterable[object],
33
+ is_string: list[bool],
34
+ descending: list[bool],
35
+ nulls_last: list[bool],
36
+ ) -> list[int] | None:
37
+ if _native is None:
38
+ return None
39
+ cols = list(columns)
40
+ if not cols:
41
+ return []
42
+ try:
43
+ return _native.argsort_mixed(cols, is_string, descending, nulls_last)
44
+ except Exception as e:
45
+ logger.warning(f"Native mixed sort failed: {e}")
46
+ return None
47
+
48
+
49
+ def smcl_to_markdown(smcl_text: str) -> str | None:
50
+ if _native is None:
51
+ return None
52
+ try:
53
+ return _native.smcl_to_markdown(smcl_text)
54
+ except Exception as e:
55
+ logger.warning(f"Native SMCL conversion failed: {e}")
56
+ return None
57
+
58
+
59
+ def fast_scan_log(smcl_content: str, rc_default: int) -> Tuple[str, str, int | None] | None:
60
+ if _native is None:
61
+ return None
62
+ try:
63
+ return _native.fast_scan_log(smcl_content, rc_default)
64
+ except Exception as e:
65
+ logger.warning(f"Native log scanning failed: {e}")
66
+ return None
67
+
68
+
69
+ def compute_filter_indices(
70
+ data_numeric: dict[str, "numpy.ndarray"],
71
+ data_string: dict[str, list[str]],
72
+ filter_json: str,
73
+ row_count: int,
74
+ ) -> list[int] | None:
75
+ if _native is None:
76
+ return None
77
+ try:
78
+ return _native.compute_filter_indices(
79
+ data_numeric,
80
+ data_string,
81
+ filter_json,
82
+ row_count
83
+ )
84
+ except Exception as e:
85
+ logger.warning(f"Native filtering failed: {e}")
86
+ return None
87
+