mcp-stata 1.22.1__cp311-abi3-macosx_11_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_stata/sessions.py ADDED
@@ -0,0 +1,264 @@
1
+ from __future__ import annotations
2
+ import os
3
+ import uuid
4
+ import logging
5
+ import asyncio
6
+ import atexit
7
+ import multiprocessing
8
+ from typing import Any, Dict, List, Optional, Callable, Awaitable
9
+ from multiprocessing.connection import Connection
10
+ from datetime import datetime, timezone
11
+
12
+ from mcp_stata.models import SessionInfo, CommandResponse
13
+
14
logger = logging.getLogger("mcp_stata.sessions")

# Worker processes come from an explicit 'spawn' context (a fresh interpreter
# rather than a fork of this one) for thread-safety and Stata/Rust
# compatibility. Process and Pipe are bound as module attributes so tests can
# patch them in one place.
_ctx = multiprocessing.get_context("spawn")
Process, Pipe = _ctx.Process, _ctx.Pipe
21
+
22
class StataSession:
    """A single Stata worker process plus the async bridge used to talk to it.

    The worker runs ``mcp_stata.worker.main`` in a process created from the
    module-level ``Process``/``Pipe`` factories and exchanges dict messages
    with this object over the pipe. Each request sent by :meth:`call` carries a
    unique ``id``; a background listener task routes incoming ``log`` and
    ``progress`` events to the callbacks registered for that id and resolves
    the request's future on a ``result`` or ``error`` event.

    ``status`` starts as ``"starting"``, becomes ``"running"`` when the worker
    sends ``ready``, and ends as ``"stopped"`` or ``"error"``.
    """

    def __init__(self, session_id: str):
        """Start the worker process and the listener task.

        Must be called while an event loop is running: ``asyncio.create_task``
        raises ``RuntimeError`` otherwise.
        """
        self.id = session_id
        self.status = "starting"
        self.created_at = datetime.now(timezone.utc).isoformat()
        self.pid: Optional[int] = None

        self._parent_conn, self._child_conn = Pipe()
        self._process = Process(target=self._run_worker, args=(self._child_conn,))
        # Daemon: multiprocessing terminates it when the parent exits normally.
        self._process.daemon = True
        self._process.start()

        # Per-request state, keyed by the request id sent to the worker.
        self._pending_requests: Dict[str, asyncio.Future] = {}
        self._log_listeners: Dict[str, List[Callable[[str], Awaitable[None]]]] = {}
        self._progress_listeners: Dict[str, List[Callable[[float, Optional[float], Optional[str]], Awaitable[None]]]] = {}

        self._listener_running = True
        self._listener_task = asyncio.create_task(self._listen_to_worker())

    def _run_worker(self, conn: Connection):
        """Process entry point: hand the child end of the pipe to the worker."""
        from mcp_stata.worker import main
        main(conn)

    def _worker_alive(self) -> bool:
        """Return False only when the worker process is known to have exited."""
        process = getattr(self, "_process", None)
        return process is None or process.is_alive()

    async def _listen_to_worker(self):
        """Receive and dispatch worker messages until the session ends.

        The loop exits when ``_listener_running`` is cleared, the pipe reports
        EOF/reset, or the worker process is found dead with no data left to
        read. On any exit other than cancellation, every request still waiting
        for a response is failed with ``RuntimeError`` so callers of
        :meth:`call` do not hang forever.
        """
        loop = asyncio.get_running_loop()
        try:
            while self._listener_running:
                # Sample liveness *before* polling: anything the worker wrote
                # before exiting is then guaranteed to be visible to poll(), so
                # no final messages are dropped when we bail out below.
                worker_alive = self._worker_alive()
                # Blocking poll runs in a thread with a timeout so the event loop
                # stays responsive and the running flag is re-checked regularly.
                if await loop.run_in_executor(None, self._parent_conn.poll, 0.2):
                    try:
                        msg = await loop.run_in_executor(None, self._parent_conn.recv)
                        await self._handle_worker_msg(msg)
                    except (EOFError, ConnectionResetError, BrokenPipeError):
                        logger.info(f"Session {self.id} worker connection closed.")
                        break
                elif not worker_alive and self._listener_running:
                    # The parent still holds its copy of the child pipe end, so
                    # EOF never arrives when the worker dies; detect it here.
                    logger.warning(f"Session {self.id} worker process exited unexpectedly.")
                    if self.status in ("starting", "running"):
                        self.status = "error"
                    break
                else:
                    # Give the event loop a chance to run other tasks and deliver cancellation.
                    await asyncio.sleep(0.01)
        except asyncio.CancelledError:
            # Cancellation (stop() or a listener restart on another loop) must not
            # fail pending requests; stop() handles those itself.
            self._listener_running = False
            raise
        except Exception as e:
            logger.error(f"Error in session {self.id} listener: {e}")
            self.status = "error"
            self._fail_pending(f"Session {self.id} listener failed: {e}")
        else:
            self._fail_pending(f"Session {self.id} worker is no longer available.")
        finally:
            self._listener_running = False
            if self.status != "error":
                self.status = "stopped"

    async def _notify_listener(self, callback: Callable[..., Awaitable[None]], *args: Any) -> None:
        """Await a log/progress callback, logging instead of propagating its errors.

        A failing client callback must not take down the listener loop, which
        would leave every in-flight request without a response.
        """
        try:
            await callback(*args)
        except Exception:
            logger.exception(f"Listener callback for session {self.id} failed; continuing.")

    async def _handle_worker_msg(self, msg: Dict[str, Any]):
        """Dispatch one worker message according to its ``event`` field.

        ``ready`` records the worker PID and marks the session running;
        ``log``/``progress`` fan out to the callbacks registered for ``id``;
        ``result``/``error`` resolve the matching pending future. An ``error``
        without a known ``id`` is treated as a session-level failure.
        """
        event = msg.get("event")
        msg_id = msg.get("id")

        if event == "ready":
            self.pid = msg.get("pid")
            self.status = "running"
            logger.info(f"Session {self.id} ready (PID: {self.pid})")

        elif event == "log":
            for cb in self._log_listeners.get(msg_id, ()):
                await self._notify_listener(cb, msg.get("text"))

        elif event == "progress":
            for cb in self._progress_listeners.get(msg_id, ()):
                await self._notify_listener(cb, msg.get("progress"), msg.get("total"), msg.get("message"))

        elif event == "result":
            future = self._pending_requests.get(msg_id)
            if future is not None:
                if not future.done():
                    future.set_result(msg.get("result"))
                self._cleanup_listeners(msg_id)

        elif event == "error":
            future = self._pending_requests.get(msg_id)
            if future is not None:
                if not future.done():
                    future.set_exception(RuntimeError(msg.get("message")))
                self._cleanup_listeners(msg_id)
            else:
                logger.error(f"Global worker error in session {self.id}: {msg.get('message')}")
                # Don't update status if already stopped or error
                if self.status not in ("stopped", "error"):
                    self.status = "error"

    def _cleanup_listeners(self, msg_id: str):
        """Forget all per-request state for ``msg_id`` (safe to call repeatedly)."""
        self._log_listeners.pop(msg_id, None)
        self._progress_listeners.pop(msg_id, None)
        self._pending_requests.pop(msg_id, None)

    def _fail_pending(self, reason: str) -> None:
        """Reject every in-flight request with ``RuntimeError(reason)`` and drop its listeners."""
        for msg_id, future in list(self._pending_requests.items()):
            if not future.done():
                try:
                    # A fresh exception per future: awaiters must not share one traceback.
                    future.set_exception(RuntimeError(reason))
                except RuntimeError:
                    # The future's event loop is already closed; nobody can await it.
                    logger.debug(f"Could not fail request {msg_id} of session {self.id}: loop closed.")
            self._cleanup_listeners(msg_id)

    async def _ensure_listener(self):
        """(Re)start the listener task if it is missing, finished, or bound to another event loop."""
        current_loop = asyncio.get_running_loop()
        if self._listener_task is None or self._listener_task.done() or (hasattr(self._listener_task, "get_loop") and self._listener_task.get_loop() != current_loop):
            if self._listener_task and not self._listener_task.done():
                self._listener_task.cancel()
            self._listener_running = True
            self._listener_task = current_loop.create_task(self._listen_to_worker())

    async def call(self, method: str, args: Dict[str, Any],
                   notify_log: Optional[Callable[[str], Awaitable[None]]] = None,
                   notify_progress: Optional[Callable[[float, Optional[float], Optional[str]], Awaitable[None]]] = None) -> Any:
        """Send ``method`` with ``args`` to the worker and await its result.

        Args:
            method: Request type, sent as the message's ``type`` field.
            args: Request arguments, sent as the message's ``args`` field.
            notify_log: Optional coroutine callback for ``log`` events of this request.
            notify_progress: Optional coroutine callback for ``progress`` events.

        Returns:
            The ``result`` payload the worker sends back for this request.

        Raises:
            RuntimeError: If the message cannot be sent, the worker reports an
                error for this request, or the session/worker goes away first.
        """
        await self._ensure_listener()
        msg_id = uuid.uuid4().hex
        future = asyncio.get_running_loop().create_future()
        self._pending_requests[msg_id] = future

        if notify_log:
            self._log_listeners.setdefault(msg_id, []).append(notify_log)
        if notify_progress:
            self._progress_listeners.setdefault(msg_id, []).append(notify_progress)

        try:
            try:
                self._parent_conn.send({
                    "type": method,
                    "id": msg_id,
                    "args": args
                })
            except (AttributeError, OSError) as e:
                # OSError covers BrokenPipeError, ConnectionResetError and a closed handle.
                raise RuntimeError(f"Failed to send command to worker: {e}") from e
            return await future
        finally:
            # Covers send failures and caller cancellation; after a normal
            # response the handler has already cleaned up, so this is a no-op.
            self._cleanup_listeners(msg_id)

    async def stop(self, timeout: float = 5.0):
        """Stop the worker and the listener, failing any requests still pending.

        Sends a best-effort ``stop`` message, terminates the process, waits up
        to ``timeout`` seconds for it to exit, and kills it if it is still alive.
        """
        self._listener_running = False
        if self.status != "stopped":
            try:
                self._parent_conn.send({"type": "stop"})
            except Exception:
                # Best effort: the pipe may already be broken if the worker died.
                pass

        if self._process and self._process.is_alive():
            # NOTE(review): terminate() follows the stop message immediately, so the
            # worker gets no grace period to act on it — confirm this is intended.
            self._process.terminate()

            # Join in an executor so the event loop is not blocked while waiting.
            loop = asyncio.get_running_loop()
            try:
                await loop.run_in_executor(None, self._process.join, timeout)
            except Exception:
                pass

            if self._process.is_alive():
                logger.warning(f"Session {self.id} worker (PID {self._process.pid}) did not exit after {timeout}s; killing.")
                try:
                    self._process.kill()
                    await loop.run_in_executor(None, self._process.join)
                except Exception as e:
                    logger.error(f"Failed to kill session {self.id} worker: {e}")

        self.status = "stopped"
        if self._listener_task:
            self._listener_task.cancel()
            try:
                await self._listener_task
            except asyncio.CancelledError:
                pass
            except Exception:
                pass

        # Nothing will ever answer requests that are still waiting; fail them
        # instead of leaving their callers suspended forever.
        self._fail_pending(f"Session {self.id} was stopped.")

    def get_info(self) -> SessionInfo:
        """Return a ``SessionInfo`` snapshot of this session's id, status, creation time and PID."""
        return SessionInfo(
            id=self.id,
            status=self.status,
            created_at=self.created_at,
            pid=self.pid
        )
189
+
190
# Module-level registry: every SessionManager appends itself here so the
# atexit hook below can reach all of them.
_all_managers: List[SessionManager] = []
# Set once _global_shutdown has been registered with atexit, so it is only
# registered a single time no matter how many managers are created.
_atexit_registered = False


def _global_shutdown():
    """Last-chance atexit hook: have each registered SessionManager kill its workers."""
    for registered_manager in _all_managers:
        registered_manager._shutdown()
198
+
199
class SessionManager:
    """Registry of named ``StataSession`` workers.

    Each manager appends itself to the module-level ``_all_managers`` list and
    registers ``_global_shutdown`` with ``atexit`` (once per process), so that
    worker processes are killed when the interpreter exits.
    """

    def __init__(self):
        global _atexit_registered
        self._sessions: Dict[str, StataSession] = {}
        self._default_session_id = "default"
        _all_managers.append(self)
        if not _atexit_registered:
            atexit.register(_global_shutdown)
            _atexit_registered = True

    def _shutdown(self) -> None:
        """Emergency synchronous cleanup for atexit: kill workers and close pipes."""
        for session in list(self._sessions.values()):
            try:
                if session._process and session._process.is_alive():
                    # Be very aggressive in atexit, we don't have much time
                    session._process.kill()
                    session._process.join(timeout=0.1)
            except Exception:
                pass
            try:
                session._parent_conn.close()
            except Exception:
                pass
        self._sessions.clear()

    async def start(self):
        """Create (and wait for) the default session."""
        await self.get_or_create_session(self._default_session_id)

    async def get_or_create_session(self, session_id: str) -> StataSession:
        """Return the session named ``session_id``, creating it if needed.

        A newly created session is awaited until it leaves the ``"starting"``
        status, its process dies (status becomes ``"error"``), or 30 seconds
        pass; the session is returned in whatever state it reached.
        """
        if session_id not in self._sessions:
            logger.info(f"Creating new Stata session: {session_id}")
            session = StataSession(session_id)
            self._sessions[session_id] = session
            # Generous timeout for CI (especially Stata's first init), but stop
            # waiting as soon as the process dies or the status changes.
            loop = asyncio.get_running_loop()
            deadline = loop.time() + 30.0
            while session.status == "starting" and loop.time() < deadline:
                if not session._process.is_alive():
                    # Process died before reaching ready
                    session.status = "error"
                    break
                await asyncio.sleep(0.1)

        return self._sessions[session_id]

    def get_session(self, session_id: str) -> StataSession:
        """Return an existing session; raise ``ValueError`` if it is unknown."""
        if session_id not in self._sessions:
            raise ValueError(f"Session {session_id} not found.")
        return self._sessions[session_id]

    def list_sessions(self) -> List[SessionInfo]:
        """Return a ``SessionInfo`` snapshot for every registered session."""
        return [s.get_info() for s in self._sessions.values()]

    async def stop_session(self, session_id: str):
        """Stop and unregister ``session_id``; unknown ids are ignored."""
        if session_id in self._sessions:
            await self._sessions[session_id].stop()
            del self._sessions[session_id]

    async def stop_all(self):
        """Stop every session concurrently and clear the registry.

        Every session gets the chance to stop even if some fail, and the
        registry is cleared regardless. If any ``stop()`` raised, the first
        such exception is re-raised after cleanup; further failures are logged.
        """
        sessions = list(self._sessions.values())
        results = []
        if sessions:
            # return_exceptions=True: one failing stop must not leave the others
            # unawaited or skip clearing the registry.
            results = await asyncio.gather(*(s.stop() for s in sessions), return_exceptions=True)
        self._sessions.clear()

        failures = [r for r in results if isinstance(r, BaseException)]
        if failures:
            for extra in failures[1:]:
                logger.error(f"Failed to stop a Stata session: {extra!r}")
            raise failures[0]
@@ -0,0 +1,88 @@
1
+ """Convert a SMCL file into Markdown.
2
+
3
+ Adapted from https://github.com/sergiocorreia/parse-smcl (MIT). Simplified into
4
+ a single module geared toward MCP, emitting Markdown by default.
5
+ """
6
+
7
+ import os
8
+ import re
9
+ from mcp_stata.native_ops import smcl_to_markdown as rust_smcl_to_markdown
10
+
11
+
12
def expand_includes(lines, adopath):
    """Expand ``INCLUDE help <name>`` directives in place from ``.ihlp`` files.

    Each directive line is replaced by the lines of
    ``<adopath>/<first letter of name>/<name>.ihlp`` (the ``.ihlp`` suffix is
    added unless already present), minus a leading ``{* *! version`` stamp line.
    Directives whose name is empty or whose file cannot be read are left as-is.

    Args:
        lines: List of SMCL lines; modified in place.
        adopath: Root directory holding the include files; falsy or missing
            paths disable expansion.

    Returns:
        The same ``lines`` list.
    """
    if not adopath or not os.path.exists(adopath):
        return lines

    prefix = "INCLUDE help "
    includes = [(i, line[len(prefix):].strip()) for i, line in enumerate(lines) if line.startswith(prefix)]
    # Splice from the end so multi-line replacements don't shift indices still to process.
    for i, name in reversed(includes):
        if not name:
            # "INCLUDE help " with no target: nothing to look up (name[0] would fail).
            continue
        filename = name if name.endswith(".ihlp") else name + ".ihlp"
        path = os.path.join(adopath, name[0], filename)
        try:
            # errors="replace": a stray non-UTF-8 byte must not abort the whole conversion.
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                content = f.readlines()
        except OSError:
            # Missing file, directory, permission problem, ...: keep the directive line.
            continue
        if content and content[0].startswith("{* *! version"):
            content.pop(0)
        lines[i:i + 1] = content
    return lines
29
+
30
+
31
def _inline_to_markdown(text: str) -> str:
    """Convert common inline SMCL directives to Markdown.

    ``{tag:content}`` becomes bold (``bf``/``strong``), italic (``it``/``em``),
    inline code (``cmd``, ``cmdab``, ``code``, ``inp``, ``input``, ``res``,
    ``err``, ``txt``) or the bare ``content`` for any other tag; tag names are
    matched case-insensitively. Nested directives such as ``{bf:{cmd:x}}`` are
    resolved innermost-first. Any brace group left afterwards (e.g. ``{p_end}``
    or directives without a ``:``) is removed entirely.
    """

    def repl(match: re.Match) -> str:
        tag = match.group(1).lower()
        content = match.group(2) or ""
        if tag in ("bf", "strong"):
            return f"**{content}**"
        if tag in ("it", "em"):
            return f"*{content}*"
        if tag in ("cmd", "cmdab", "code", "inp", "input", "res", "err", "txt"):
            return f"`{content}`"
        return content

    # Match only directives whose content contains no braces (the innermost
    # ones) and repeat until nothing changes, so nesting unwinds one level per
    # pass. Every substitution removes a brace pair, so the loop terminates.
    previous = None
    while previous != text:
        previous = text
        text = re.sub(r"\{([a-zA-Z0-9_]+):([^{}]*)\}", repl, text)
    text = re.sub(r"\{[^}]*\}", "", text)
    return text
48
+
49
+
50
def smcl_to_markdown(smcl_text: str, adopath: str = None, current_file: str = "help") -> str:
    """Convert SMCL text to lightweight Markdown suitable for LLM consumption.

    The native ``rust_smcl_to_markdown`` converter is tried first unless
    ``INCLUDE help`` directives have to be expanded from ``adopath``; if it is
    skipped or returns an empty result, a pure-Python fallback is used.

    Args:
        smcl_text: Raw SMCL source.
        adopath: Optional directory used to expand ``INCLUDE help`` directives.
        current_file: Name shown in the ``# Help for ...`` header.

    Returns:
        Markdown starting with ``# Help for <current_file>``, or ``""`` when
        ``smcl_text`` is empty.
    """
    if not smcl_text:
        return ""

    header = f"# Help for {current_file}"

    # Fast path: native converter, unless includes must first be expanded from adopath.
    if not adopath or "INCLUDE help" not in smcl_text:
        res = rust_smcl_to_markdown(smcl_text)
        if res:
            return f"{header}\n" + res

    lines = smcl_text.splitlines()
    if lines and lines[0].strip() == "{smcl}":
        lines = lines[1:]

    lines = expand_includes(lines, adopath)

    md_lines = []
    for raw in lines:
        line = raw.strip()
        if not line:
            continue
        if line.startswith("{title:"):
            # Drop only the directive's own closing brace (rstrip("}") would also
            # eat braces of nested directives), then convert the title text.
            heading = line[len("{title:"):]
            if heading.endswith("}"):
                heading = heading[:-1]
            heading = _inline_to_markdown(heading).strip()
            if heading:
                # Keep every section heading in document order instead of only the last one.
                md_lines.extend(["", f"## {heading}", ""])
            continue
        # Paragraph markers
        line = line.replace("{p_end}", "")
        line = re.sub(r"\{p[^}]*\}", "", line)
        converted = _inline_to_markdown(line).strip()
        if converted:
            md_lines.append(converted)

    # Collapse the blank-line runs produced by consecutive headings.
    body = re.sub(r"\n{3,}", "\n\n", "\n".join(md_lines)).strip()
    return f"{header}\n\n{body}\n" if body else f"{header}\n"