mcp-stata 1.22.1__cp311-abi3-macosx_11_0_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_stata/__init__.py +3 -0
- mcp_stata/__main__.py +4 -0
- mcp_stata/_native_ops.abi3.so +0 -0
- mcp_stata/config.py +20 -0
- mcp_stata/discovery.py +548 -0
- mcp_stata/graph_detector.py +601 -0
- mcp_stata/models.py +74 -0
- mcp_stata/native_ops.py +87 -0
- mcp_stata/server.py +1333 -0
- mcp_stata/sessions.py +264 -0
- mcp_stata/smcl/smcl2html.py +88 -0
- mcp_stata/stata_client.py +4710 -0
- mcp_stata/streaming_io.py +264 -0
- mcp_stata/test_stata.py +56 -0
- mcp_stata/ui_http.py +1034 -0
- mcp_stata/utils.py +159 -0
- mcp_stata/worker.py +167 -0
- mcp_stata-1.22.1.dist-info/METADATA +488 -0
- mcp_stata-1.22.1.dist-info/RECORD +22 -0
- mcp_stata-1.22.1.dist-info/WHEEL +4 -0
- mcp_stata-1.22.1.dist-info/entry_points.txt +2 -0
- mcp_stata-1.22.1.dist-info/licenses/LICENSE +661 -0
mcp_stata/server.py
ADDED
|
@@ -0,0 +1,1333 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import anyio
|
|
3
|
+
import asyncio
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
7
|
+
from mcp.server.fastmcp import Context, FastMCP
|
|
8
|
+
from mcp.server.fastmcp.utilities import logging as fastmcp_logging
|
|
9
|
+
import mcp.types as types
|
|
10
|
+
from .stata_client import StataClient
|
|
11
|
+
from .models import (
|
|
12
|
+
ErrorEnvelope,
|
|
13
|
+
CommandResponse,
|
|
14
|
+
DataResponse,
|
|
15
|
+
GraphListResponse,
|
|
16
|
+
VariableInfo,
|
|
17
|
+
VariablesResponse,
|
|
18
|
+
GraphInfo,
|
|
19
|
+
GraphExport,
|
|
20
|
+
GraphExportResponse,
|
|
21
|
+
SessionInfo,
|
|
22
|
+
SessionListResponse,
|
|
23
|
+
)
|
|
24
|
+
from .sessions import SessionManager
|
|
25
|
+
import logging
|
|
26
|
+
import sys
|
|
27
|
+
import json
|
|
28
|
+
import os
|
|
29
|
+
import multiprocessing
|
|
30
|
+
import re
|
|
31
|
+
import traceback
|
|
32
|
+
import uuid
|
|
33
|
+
from functools import wraps
|
|
34
|
+
from typing import Any, Dict, Optional
|
|
35
|
+
import threading
|
|
36
|
+
|
|
37
|
+
from .ui_http import UIChannelManager
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Configure logging
|
|
41
|
+
logger = logging.getLogger("mcp_stata")
|
|
42
|
+
payload_logger = logging.getLogger("mcp_stata.payloads")
|
|
43
|
+
_LOGGING_CONFIGURED = False
|
|
44
|
+
|
|
45
|
+
def get_server_version() -> str:
    """Determine the server version.

    Resolution order:
    1. Installed distribution metadata for ``mcp-stata``.
    2. The ``version = "..."`` field of a ``pyproject.toml`` two directories
       above this file (source checkouts where the package is not installed).
    3. The literal string ``"unknown"``.

    Returns:
        The resolved version string. Never raises.
    """
    try:
        return version("mcp-stata")
    except PackageNotFoundError:
        # Not installed: we are likely running from a source tree at
        # src/mcp_stata/server.py, so pyproject.toml is at ../../pyproject.toml.
        try:
            base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            pyproject_path = os.path.join(base_dir, "pyproject.toml")
            if os.path.exists(pyproject_path):
                # Explicit encoding: pyproject.toml is UTF-8 by specification.
                with open(pyproject_path, "r", encoding="utf-8") as f:
                    content = f.read()
                # Module-level `re` is already imported; the old local
                # `import re` shadowed it redundantly.
                match = re.search(r'^version\s*=\s*["\']([^"\']+)["\']', content, re.MULTILINE)
                if match:
                    return match.group(1)
        except Exception:
            # Best effort only; version detection must never break startup.
            pass
        return "unknown"

SERVER_VERSION = get_server_version()
|
|
67
|
+
|
|
68
|
+
def setup_logging():
    """Configure process-wide logging exactly once.

    All handlers write to stderr because stdout is reserved for JSON-RPC
    traffic. Every pre-existing logger is stripped of its handlers and
    detached from the root so stray libraries cannot write to stdout; the
    app, payload, and mcp loggers then each get a dedicated stderr handler.
    """
    global _LOGGING_CONFIGURED
    if _LOGGING_CONFIGURED:
        return
    _LOGGING_CONFIGURED = True
    log_level = os.getenv("MCP_STATA_LOGLEVEL", "DEBUG").upper()
    # Resolve the level name once; fall back to DEBUG for unknown names.
    level = getattr(logging, log_level, logging.DEBUG)

    def _stderr_handler() -> logging.StreamHandler:
        # Shared factory: the old code configured three identical handlers
        # by hand (app / mcp / payload).
        handler = logging.StreamHandler(sys.stderr)
        handler.setLevel(level)
        handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
        return handler

    app_handler = _stderr_handler()
    mcp_handler = _stderr_handler()
    payload_handler = _stderr_handler()

    root_logger = logging.getLogger()
    root_logger.handlers = []
    root_logger.setLevel(logging.WARNING)

    # Silence every already-registered logger and stop propagation so no
    # third-party handler can leak output onto stdout.
    for name, item in logging.root.manager.loggerDict.items():
        if not isinstance(item, logging.Logger):
            continue
        item.handlers = []
        item.propagate = False
        if item.level == logging.NOTSET:
            item.setLevel(level)

    logger.handlers = [app_handler]
    logger.propagate = False

    payload_logger.handlers = [payload_handler]
    payload_logger.propagate = False

    # The mcp loggers share one handler; configure all three consistently.
    for mcp_name in ("mcp.server", "mcp.server.lowlevel.server", "mcp"):
        mcp_log = logging.getLogger(mcp_name)
        mcp_log.handlers = [mcp_handler]
        mcp_log.propagate = False
        mcp_log.setLevel(level)

    if logger.level == logging.NOTSET:
        logger.setLevel(level)

    logger.info("=== mcp-stata server starting ===")
    logger.info("mcp-stata version: %s", SERVER_VERSION)
    logger.info("STATA_PATH env at startup: %s", os.getenv("STATA_PATH", "<not set>"))
    logger.info("LOG_LEVEL: %s", log_level)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# Initialize FastMCP
mcp = FastMCP("mcp_stata")
# Set version on the underlying server to expose it in InitializeResult
mcp._mcp_server.version = SERVER_VERSION

# Process-wide registry of Stata worker sessions, keyed by session_id.
session_manager = SessionManager()
|
|
134
|
+
|
|
135
|
+
class StataClientProxy:
    """Proxy for StataClient that routes calls to a StataSession (via worker process).

    Each method forwards a named RPC to the worker that owns ``session_id``,
    resolved through the module-level ``session_manager``. The proxy is
    callable both from plain synchronous code and from threads running
    alongside the asyncio event loop (e.g. UI HTTP handler threads).
    """

    def __init__(self, session_id: str = "default"):
        # Identifier of the worker-backed session this proxy talks to.
        self.session_id = session_id

    def _call_sync(self, method: str, args: dict[str, Any]) -> Any:
        """Invoke ``method`` with ``args`` on the session worker and block for the result."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        async def _run():
            session = await session_manager.get_or_create_session(self.session_id)
            return await session.call(method, args)

        if loop and loop.is_running():
            # A loop is running. If we're in a thread different from the
            # loop's thread (true for UI HTTP handler threads) we may block
            # on the loop's result. Uses the module-level `threading` import;
            # the old local `import threading` was redundant.
            if threading.current_thread() != threading.main_thread():  # Simplified check
                future = asyncio.run_coroutine_threadsafe(_run(), loop)
                return future.result()
            else:
                # On the loop's own thread we can't block. This case
                # shouldn't happen for UIChannelManager but might for tests;
                # delegate to anyio's thread bridge.
                return anyio.from_thread.run(_run)
        else:
            # No running loop: drive the coroutine to completion ourselves.
            return asyncio.run(_run())

    def get_dataset_state(self) -> dict[str, Any]:
        """Return the worker's view of the current dataset state."""
        return self._call_sync("get_dataset_state", {})

    def get_arrow_stream(self, **kwargs) -> bytes:
        """Return dataset rows serialized as an Arrow IPC byte stream."""
        return self._call_sync("get_arrow_stream", kwargs)

    def list_variables_rich(self) -> list[dict[str, Any]]:
        """Return per-variable metadata for the loaded dataset."""
        return self._call_sync("list_variables_rich", {})

    def compute_view_indices(self, filter_expr: str) -> list[int]:
        """Return the row indices matching ``filter_expr``."""
        return self._call_sync("compute_view_indices", {"filter_expr": filter_expr})

    def validate_filter_expr(self, filter_expr: str):
        """Validate ``filter_expr`` without materializing a filtered view."""
        return self._call_sync("validate_filter_expr", {"filter_expr": filter_expr})

    def get_page(self, **kwargs):
        """Return one page of dataset rows."""
        return self._call_sync("get_page", kwargs)
|
|
182
|
+
|
|
183
|
+
# Default proxy used by module-level helpers and the UI HTTP channel.
client = StataClientProxy()
# Lazily created by _ensure_ui_channel(); remains None if initialization fails.
ui_channel = None
|
|
185
|
+
|
|
186
|
+
def _ensure_ui_channel():
    """Lazily create the global UIChannelManager; log (never raise) on failure."""
    global ui_channel
    if ui_channel is not None:
        return
    try:
        from .ui_http import UIChannelManager
        ui_channel = UIChannelManager(client)
    except Exception:
        logger.exception("Failed to initialize UI channel")
|
|
194
|
+
|
|
195
|
+
@mcp.tool()
async def create_session(session_id: str) -> str:
    """Create a new Stata session.

    Args:
        session_id: A unique identifier for the new session.
    """
    await session_manager.get_or_create_session(session_id)
    payload = {"status": "created", "session_id": session_id}
    return json.dumps(payload)
|
|
204
|
+
|
|
205
|
+
@mcp.tool()
async def stop_session(session_id: str) -> str:
    """Stop and terminate a Stata session.

    Args:
        session_id: The identifier of the session to stop.
    """
    await session_manager.stop_session(session_id)
    payload = {"status": "stopped", "session_id": session_id}
    return json.dumps(payload)
|
|
214
|
+
|
|
215
|
+
@mcp.tool()
def list_sessions() -> str:
    """List all active Stata sessions and their status."""
    response = SessionListResponse(sessions=session_manager.list_sessions())
    return response.model_dump_json()
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
async def _noop_log(_text: str) -> None:
|
|
223
|
+
return
|
|
224
|
+
|
|
225
|
+
@dataclass
class BackgroundTask:
    """Bookkeeping record for one background do-file/command execution."""

    task_id: str                    # opaque hex id handed back to the client
    kind: str                       # e.g. "do_file" or "command"
    # Fixed: was annotated `asyncio.Task`, but callers pass None when the
    # work runs inline (no MCP session to notify).
    task: Optional[asyncio.Task]
    created_at: datetime            # creation time (UTC)
    log_path: Optional[str] = None  # log path once the worker announces it
    result: Optional[str] = None    # formatted result payload when done
    error: Optional[str] = None     # error message when the run failed
    done: bool = False              # True once finished (successfully or not)
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
# Registry of background tasks keyed by task_id (pruned by _register_task).
_background_tasks: Dict[str, BackgroundTask] = {}
# Maps str(request_id) -> log path announced for that request; consulted by
# _should_stream_smcl_chunk to suppress duplicate streaming.
_request_log_paths: Dict[str, str] = {}
# Log paths a client has started reading directly (streaming suppressed).
_read_log_paths: set[str] = set()
# Presumably byte offsets already consumed per log path — maintained by code
# outside this chunk; confirm against the log-reading tools.
_read_log_offsets: Dict[str, int] = {}
# Guard so the stdout JSON-RPC filter is only installed once per process.
_STDOUT_FILTER_INSTALLED = False
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def _install_stdout_filter() -> None:
    """
    Redirect process stdout to a pipe and forward only JSON-RPC lines to the
    original stdout. Any non-JSON output (e.g., Stata noise) is sent to stderr.
    """
    global _STDOUT_FILTER_INSTALLED
    if _STDOUT_FILTER_INSTALLED:
        return
    _STDOUT_FILTER_INSTALLED = True

    try:
        # Flush any pending output before redirecting.
        try:
            sys.stdout.flush()
        except Exception:
            pass

        original_stdout_fd = os.dup(1)
        read_fd, write_fd = os.pipe()
        os.dup2(write_fd, 1)
        os.close(write_fd)

        def _route(data: bytes) -> None:
            # Classify one chunk of output: JSON-RPC payloads (single object
            # or batch) go to the real stdout; everything else goes to
            # stderr (fd 2). Whitespace-only chunks are dropped entirely.
            # This helper replaces two near-identical ~20-line copies of the
            # same classification logic in the original.
            stripped = data.lstrip()
            if not stripped:
                return
            try:
                payload = json.loads(stripped)
                is_rpc = (isinstance(payload, dict) and payload.get("jsonrpc")) or (
                    isinstance(payload, list)
                    and any(isinstance(item, dict) and item.get("jsonrpc") for item in payload)
                )
            except Exception:
                is_rpc = False
            os.write(original_stdout_fd if is_rpc else 2, data)

        def _forward_stdout() -> None:
            # Pump the pipe line-by-line until the write end closes.
            buffer = b""
            while True:
                try:
                    chunk = os.read(read_fd, 4096)
                except Exception:
                    break
                if not chunk:
                    break
                buffer += chunk
                while b"\n" in buffer:
                    line, buffer = buffer.split(b"\n", 1)
                    _route(line + b"\n")
            if buffer:
                # Trailing data without a newline once the pipe closed.
                _route(buffer)
            try:
                os.close(read_fd)
            except Exception:
                pass

        t = threading.Thread(target=_forward_stdout, name="mcp-stdout-filter", daemon=True)
        t.start()
    except Exception:
        # Installation failed: leave stdout as-is and allow a retry.
        _STDOUT_FILTER_INSTALLED = False
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def _register_task(task_info: BackgroundTask, max_tasks: int = 100) -> None:
    """Store a task in the registry, evicting the oldest completed tasks
    once the registry exceeds ``max_tasks`` entries."""
    _background_tasks[task_info.task_id] = task_info
    overflow = len(_background_tasks) - max_tasks
    if overflow <= 0:
        return
    finished = sorted(
        (t for t in _background_tasks.values() if t.done),
        key=lambda item: item.created_at,
    )
    for stale in finished[:overflow]:
        _background_tasks.pop(stale.task_id, None)
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _format_command_result(result, raw: bool, as_json: bool) -> str:
|
|
331
|
+
if raw:
|
|
332
|
+
if result.success:
|
|
333
|
+
return result.log_path or ""
|
|
334
|
+
if result.error:
|
|
335
|
+
msg = result.error.message
|
|
336
|
+
if result.error.rc is not None:
|
|
337
|
+
msg = f"{msg}\nrc={result.error.rc}"
|
|
338
|
+
return msg
|
|
339
|
+
return result.log_path or ""
|
|
340
|
+
|
|
341
|
+
# Note: we used to clear result.stdout here for token efficiency,
|
|
342
|
+
# but that conflicts with requirements and breaks E2E tests that
|
|
343
|
+
# expect results in the return value.
|
|
344
|
+
|
|
345
|
+
if as_json:
|
|
346
|
+
return result.model_dump_json()
|
|
347
|
+
return result.model_dump_json()
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
async def _wait_for_log_path(task_info: BackgroundTask) -> None:
    """Poll until the task announces a log path or finishes."""
    while not (task_info.done or task_info.log_path is not None):
        await anyio.sleep(0.01)
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
async def _notify_task_done(session: object | None, task_info: BackgroundTask, request_id: object | None) -> None:
    """Emit a task_done logMessage for a finished background task (best effort)."""
    if session is None:
        return
    status = "done" if task_info.done else "unknown"
    payload = {
        "event": "task_done",
        "task_id": task_info.task_id,
        "status": status,
        "log_path": task_info.log_path,
        "error": task_info.error,
    }
    try:
        await session.send_log_message(level="info", data=json.dumps(payload), related_request_id=request_id)
    except Exception:
        # Notification delivery is best effort; swallow send failures.
        return
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _debug_notification(kind: str, payload: object, request_id: object | None = None) -> None:
    """Log an outgoing MCP notification payload for debugging."""
    if isinstance(payload, str):
        serialized = payload
    else:
        try:
            serialized = json.dumps(payload, ensure_ascii=False)
        except Exception:
            serialized = str(payload)
    payload_logger.info("MCP notify %s request_id=%s payload=%s", kind, request_id, serialized)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
async def _notify_tool_error(ctx: Context | None, tool_name: str, exc: Exception) -> None:
    """Send a tool_error logMessage for a failed tool invocation (best effort)."""
    if ctx is None:
        return
    session = ctx.request_context.session
    if session is None:
        return
    # Pick up the background task id if one was attached to the request meta.
    meta = ctx.request_context.meta
    task_id = None
    if meta is not None:
        task_id = getattr(meta, "task_id", None) or getattr(meta, "taskId", None)
    payload = {
        "event": "tool_error",
        "tool": tool_name,
        "error": str(exc),
        "traceback": traceback.format_exc(),
    }
    if task_id is not None:
        payload["task_id"] = task_id
    try:
        await session.send_log_message(
            level="error",
            data=json.dumps(payload),
            related_request_id=ctx.request_id,
        )
    except Exception:
        logger.exception("Failed to emit tool_error notification for %s", tool_name)
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _log_tool_call(tool_name: str, ctx: Context | None = None) -> None:
    """Log a tool/resource invocation together with its MCP request id (if any)."""
    request_id = getattr(ctx, "request_id", None) if ctx is not None else None
    logger.info("MCP tool call: %s request_id=%s", tool_name, request_id)
|
|
412
|
+
|
|
413
|
+
def _should_stream_smcl_chunk(text: str, request_id: object | None) -> bool:
    """Decide whether an SMCL chunk should be streamed to the client.

    Structured event payloads always stream; plain log chunks are suppressed
    once the client has started reading the request's log file directly.
    """
    if request_id is None:
        return True
    try:
        payload = json.loads(text)
    except Exception:
        payload = None
    if isinstance(payload, dict) and payload.get("event"):
        return True
    log_path = _request_log_paths.get(str(request_id))
    return not (log_path and log_path in _read_log_paths)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _attach_task_id(ctx: Context | None, task_id: str) -> None:
    """Stash the background task id on the request metadata (best effort)."""
    if ctx is None:
        return
    meta = ctx.request_context.meta
    if meta is None:
        # Create an empty meta object so there is somewhere to attach the id.
        meta = types.RequestParams.Meta()
        ctx.request_context.meta = meta
    try:
        setattr(meta, "task_id", task_id)
    except Exception:
        logger.debug("Unable to attach task_id to request meta", exc_info=True)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def _extract_ctx(args: tuple[object, ...], kwargs: dict[str, object]) -> Context | None:
    """Return the Context argument from the ctx kwarg or positional args, if any."""
    candidate = kwargs.get("ctx")
    if isinstance(candidate, Context):
        return candidate
    return next((arg for arg in args if isinstance(arg, Context)), None)
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def log_call(func):
    """Decorator to log tool and resource calls."""
    if asyncio.iscoroutinefunction(func):
        @wraps(func)
        async def async_inner(*args, **kwargs):
            _log_tool_call(func.__name__, _extract_ctx(args, kwargs))
            return await func(*args, **kwargs)
        return async_inner

    @wraps(func)
    def sync_inner(*args, **kwargs):
        _log_tool_call(func.__name__, _extract_ctx(args, kwargs))
        return func(*args, **kwargs)
    return sync_inner
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
@mcp.tool()
@log_call
async def run_do_file_background(
    path: str,
    ctx: Context | None = None,
    echo: bool = True,
    as_json: bool = True,
    trace: bool = False,
    raw: bool = False,
    max_output_lines: int | None = None,
    cwd: str | None = None,
    session_id: str = "default",
) -> str:
    """Run a Stata do-file in the background and return a task id.

    Notifications:
    - logMessage: {"event":"log_path","path":"..."}
    - logMessage: {"event":"task_done","task_id":"...","status":"done","log_path":"...","error":null}
    """
    # Session/request context exists only when invoked through MCP.
    session = getattr(getattr(ctx, "request_context", None), "session", None) if ctx is not None else None
    request_id = ctx.request_id if ctx is not None else None
    task_id = uuid.uuid4().hex
    _attach_task_id(ctx, task_id)
    task_info = BackgroundTask(
        task_id=task_id,
        kind="do_file",
        task=None,
        created_at=datetime.now(timezone.utc),
    )

    async def notify_log(text: str) -> None:
        # Stream worker output to the client unless the client is already
        # tailing the log file (see _should_stream_smcl_chunk).
        if session is not None:
            if not _should_stream_smcl_chunk(text, ctx.request_id):
                return
            _debug_notification("logMessage", text, ctx.request_id)
            try:
                await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
            except Exception as e:
                logger.warning("Failed to send logMessage notification: %s", e)
                sys.stderr.write(f"[mcp_stata] ERROR: logMessage send failed: {e!r}\n")
                sys.stderr.flush()
        # Capture the log path announced by the worker so callers can tail it.
        try:
            payload = json.loads(text)
            if isinstance(payload, dict) and payload.get("event") == "log_path":
                task_info.log_path = payload.get("path")
                if ctx.request_id is not None and task_info.log_path:
                    _request_log_paths[str(ctx.request_id)] = task_info.log_path
        except Exception:
            return

    # Progress notifications are only sent when the client supplied a token.
    progress_token = None
    if ctx is not None and getattr(ctx, "request_context", None) is not None and getattr(ctx.request_context, "meta", None) is not None:
        progress_token = getattr(ctx.request_context.meta, "progressToken", None)

    async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
        if session is None or progress_token is None:
            return
        _debug_notification(
            "progress",
            {"progress": progress, "total": total, "message": message},
            ctx.request_id,
        )
        try:
            await session.send_progress_notification(
                progress_token=progress_token,
                progress=progress,
                total=total,
                message=message,
                related_request_id=ctx.request_id,
            )
        except Exception as exc:
            logger.debug("Progress notification failed: %s", exc)

    async def _run() -> None:
        # Execute the do-file on the session worker and record the outcome
        # on task_info; always mark done and notify, even on failure.
        try:
            stata_session = await session_manager.get_or_create_session(session_id)
            result_dict = await stata_session.call(
                "run_do_file",
                {
                    "path": path,
                    "options": {
                        "echo": echo,
                        "trace": trace,
                        "max_output_lines": max_output_lines,
                        "cwd": cwd,
                        "emit_graph_ready": True,
                        "graph_ready_task_id": task_id,
                        "graph_ready_format": "svg",
                    }
                },
                notify_log=notify_log,
                notify_progress=notify_progress if progress_token is not None else None,
            )
            result = CommandResponse.model_validate(result_dict)
            if not task_info.log_path and result.log_path:
                task_info.log_path = result.log_path
            if result.error:
                task_info.error = result.error.message
            task_info.result = _format_command_result(result, raw=raw, as_json=as_json)
            task_info.done = True
            await _notify_task_done(session, task_info, request_id)

            # The do-file may have modified the dataset; let the UI know.
            _ensure_ui_channel()
            if ui_channel:
                ui_channel.notify_potential_dataset_change(session_id)
        except Exception as exc:  # pragma: no cover - defensive
            task_info.done = True
            task_info.error = str(exc)
            await _notify_task_done(session, task_info, request_id)

    if session is None:
        # No MCP session to notify asynchronously: run inline before returning.
        await _run()
        task_info.task = None
    else:
        task_info.task = asyncio.create_task(_run())
    _register_task(task_info)
    # Block briefly until the worker announces the log path (or finishes).
    await _wait_for_log_path(task_info)
    return json.dumps({"task_id": task_id, "status": "started", "log_path": task_info.log_path})
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
@mcp.tool()
@log_call
def get_task_status(task_id: str, allow_polling: bool = False) -> str:
    """Return task status for background executions.

    Polling is disabled by default; set allow_polling=True for legacy callers.
    """
    notice = "Prefer task_done logMessage notifications over polling get_task_status."
    if not allow_polling:
        logger.warning(
            "get_task_status called without allow_polling; clients must use task_done logMessage notifications"
        )
        refusal = {
            "task_id": task_id,
            "status": "polling_not_allowed",
            "error": "Polling is disabled; use task_done logMessage notifications.",
            "notice": notice,
        }
        return json.dumps(refusal)
    logger.warning("get_task_status called; clients should use task_done logMessage notifications instead of polling")
    task_info = _background_tasks.get(task_id)
    if task_info is None:
        return json.dumps({"task_id": task_id, "status": "not_found", "notice": notice})
    status = "done" if task_info.done else "running"
    payload = {
        "task_id": task_id,
        "status": status,
        "kind": task_info.kind,
        "created_at": task_info.created_at.isoformat(),
        "log_path": task_info.log_path,
        "error": task_info.error,
        "notice": notice,
    }
    return json.dumps(payload)
|
|
620
|
+
|
|
621
|
+
|
|
622
|
+
@mcp.tool()
@log_call
def get_task_result(task_id: str, allow_polling: bool = False) -> str:
    """Return task result for background executions.

    Polling is disabled by default; set allow_polling=True for legacy callers.
    """
    notice = "Prefer task_done logMessage notifications over polling get_task_result."
    if not allow_polling:
        logger.warning(
            "get_task_result called without allow_polling; clients must use task_done logMessage notifications"
        )
        refusal = {
            "task_id": task_id,
            "status": "polling_not_allowed",
            "error": "Polling is disabled; use task_done logMessage notifications.",
            "notice": notice,
        }
        return json.dumps(refusal)
    logger.warning("get_task_result called; clients should use task_done logMessage notifications instead of polling")
    task_info = _background_tasks.get(task_id)
    if task_info is None:
        return json.dumps({"task_id": task_id, "status": "not_found", "notice": notice})
    if not task_info.done:
        payload = {
            "task_id": task_id,
            "status": "running",
            "log_path": task_info.log_path,
            "notice": notice,
        }
    else:
        payload = {
            "task_id": task_id,
            "status": "done",
            "log_path": task_info.log_path,
            "error": task_info.error,
            "notice": notice,
            "result": task_info.result,
        }
    return json.dumps(payload)
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
@mcp.tool()
@log_call
def cancel_task(task_id: str) -> str:
    """Request cancellation of a background task."""
    task_info = _background_tasks.get(task_id)
    if task_info is None:
        return json.dumps({"task_id": task_id, "status": "not_found"})
    still_running = task_info.task is not None and not task_info.task.done()
    if still_running:
        task_info.task.cancel()
        return json.dumps({"task_id": task_id, "status": "cancelling"})
    return json.dumps({"task_id": task_id, "status": "done", "log_path": task_info.log_path})
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
@mcp.tool()
@log_call
async def run_command_background(
    code: str,
    ctx: Context | None = None,
    echo: bool = True,
    as_json: bool = True,
    trace: bool = False,
    raw: bool = False,
    max_output_lines: int | None = None,
    cwd: str | None = None,
    session_id: str = "default",
) -> str:
    """Run a Stata command in the background and return a task id.

    Notifications:
    - logMessage: {"event":"log_path","path":"..."}
    - logMessage: {"event":"task_done","task_id":"...","status":"done","log_path":"...","error":null}
    """
    # Session/request context exists only when invoked through MCP.
    session = getattr(getattr(ctx, "request_context", None), "session", None) if ctx is not None else None
    request_id = ctx.request_id if ctx is not None else None
    task_id = uuid.uuid4().hex
    _attach_task_id(ctx, task_id)
    task_info = BackgroundTask(
        task_id=task_id,
        kind="command",
        task=None,
        created_at=datetime.now(timezone.utc),
    )

    async def notify_log(text: str) -> None:
        # NOTE(review): unlike run_do_file_background, this send_log_message
        # is not wrapped in try/except — a failed send propagates into _run's
        # handler and aborts the task. Confirm whether that is intended.
        if session is not None:
            if not _should_stream_smcl_chunk(text, ctx.request_id):
                return
            _debug_notification("logMessage", text, ctx.request_id)
            await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
        # Capture the log path announced by the worker so callers can tail it.
        try:
            payload = json.loads(text)
            if isinstance(payload, dict) and payload.get("event") == "log_path":
                task_info.log_path = payload.get("path")
                if ctx.request_id is not None and task_info.log_path:
                    _request_log_paths[str(ctx.request_id)] = task_info.log_path
        except Exception:
            return

    # Progress notifications are only sent when the client supplied a token.
    progress_token = None
    if ctx is not None and getattr(ctx, "request_context", None) is not None and getattr(ctx.request_context, "meta", None) is not None:
        progress_token = getattr(ctx.request_context.meta, "progressToken", None)

    async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
        if session is None or progress_token is None:
            return
        await session.send_progress_notification(
            progress_token=progress_token,
            progress=progress,
            total=total,
            message=message,
            related_request_id=ctx.request_id,
        )

    async def _run() -> None:
        # Execute the command on the session worker and record the outcome
        # on task_info; always mark done and notify, even on failure.
        try:
            stata_session = await session_manager.get_or_create_session(session_id)
            result_dict = await stata_session.call(
                "run_command",
                {
                    "code": code,
                    "options": {
                        "echo": echo,
                        "trace": trace,
                        "max_output_lines": max_output_lines,
                        "cwd": cwd,
                        "emit_graph_ready": True,
                        "graph_ready_task_id": task_id,
                        "graph_ready_format": "svg",
                    }
                },
                notify_log=notify_log,
                notify_progress=notify_progress if progress_token is not None else None,
            )
            result = CommandResponse.model_validate(result_dict)
            if not task_info.log_path and result.log_path:
                task_info.log_path = result.log_path
            if result.error:
                task_info.error = result.error.message
            task_info.result = _format_command_result(result, raw=raw, as_json=as_json)
            task_info.done = True
            await _notify_task_done(session, task_info, request_id)

            # The command may have modified the dataset; let the UI know.
            _ensure_ui_channel()
            if ui_channel:
                ui_channel.notify_potential_dataset_change(session_id)
        except Exception as exc:  # pragma: no cover - defensive
            task_info.done = True
            task_info.error = str(exc)
            await _notify_task_done(session, task_info, request_id)

    if session is None:
        # No MCP session to notify asynchronously: run inline before returning.
        await _run()
        task_info.task = None
    else:
        task_info.task = asyncio.create_task(_run())
    _register_task(task_info)
    # Block briefly until the worker announces the log path (or finishes).
    await _wait_for_log_path(task_info)
    return json.dumps({"task_id": task_id, "status": "started", "log_path": task_info.log_path})
|
|
779
|
+
|
|
780
|
+
@mcp.tool()
@log_call
async def run_command(
    code: str,
    ctx: Context | None = None,
    echo: bool = True,
    as_json: bool = True,
    trace: bool = False,
    raw: bool = False,
    max_output_lines: int | None = None,
    cwd: str | None = None,
    session_id: str = "default",
) -> str:
    """
    Executes Stata code.

    This is the primary tool for interacting with Stata.

    Stata output is written to a temporary log file on disk.
    The server emits a single `notifications/logMessage` event containing the log file path
    (JSON payload: {"event":"log_path","path":"..."}) so the client can tail it locally.
    If the client supplies a progress callback/token, progress updates may also be emitted
    via `notifications/progress`.

    Args:
        code: The Stata command(s) to execute (e.g., "sysuse auto", "regress price mpg", "summarize").
        ctx: FastMCP-injected request context (used to send MCP notifications). Optional for direct Python calls.
        echo: If True, the command itself is included in the output. Default is True.
        as_json: If True, returns a JSON envelope with rc/stdout/stderr/error.
        trace: If True, enables `set trace on` for deeper error diagnostics (automatically disabled after).
        raw: If True, return raw output/error message rather than a JSON envelope.
        max_output_lines: If set, truncates stdout to this many lines for token efficiency.
            Useful for verbose commands (regress, codebook, etc.).
        cwd: Working directory for the command, forwarded to the worker.
        session_id: ID of the Stata session to run in (created on demand).
        Note: This tool always uses log-file streaming semantics; there is no non-streaming mode.
    """
    # MCP session used for notifications; None when called directly from Python.
    session = getattr(getattr(ctx, "request_context", None), "session", None) if ctx is not None else None

    async def notify_log(text: str) -> None:
        # Forward one chunk of Stata log output to the client, then sniff it for
        # the {"event":"log_path",...} payload so read_log can locate the file later.
        # `ctx` is non-None here: `session` is only set when `ctx` is.
        if session is None:
            return
        if not _should_stream_smcl_chunk(text, ctx.request_id):
            return
        _debug_notification("logMessage", text, ctx.request_id)
        await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
        try:
            payload = json.loads(text)
            if isinstance(payload, dict) and payload.get("event") == "log_path":
                if ctx.request_id is not None:
                    _request_log_paths[str(ctx.request_id)] = payload.get("path")
        except Exception:
            # Most chunks are plain SMCL text, not JSON -- ignore parse failures.
            return

    # Progress notifications are only emitted when the client supplied a progress token.
    progress_token = None
    if ctx is not None and getattr(ctx, "request_context", None) is not None and getattr(ctx.request_context, "meta", None) is not None:
        progress_token = getattr(ctx.request_context.meta, "progressToken", None)

    async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
        if session is None or progress_token is None:
            return
        await session.send_progress_notification(
            progress_token=progress_token,
            progress=progress,
            total=total,
            message=message,
            related_request_id=ctx.request_id,
        )


    stata_session = await session_manager.get_or_create_session(session_id)
    result_dict = await stata_session.call(
        "run_command",
        {
            "code": code,
            "options": {
                "echo": echo,
                "trace": trace,
                "max_output_lines": max_output_lines,
                "cwd": cwd,
                # Ask the worker to announce finished graphs (exported as SVG),
                # keyed by this request id so the client can correlate them.
                "emit_graph_ready": True,
                "graph_ready_task_id": ctx.request_id if ctx else None,
                "graph_ready_format": "svg",
            }
        },
        notify_log=notify_log if session is not None else _noop_log,
        notify_progress=notify_progress if progress_token is not None else None,
    )

    result = CommandResponse.model_validate(result_dict)
    # The command may have mutated the dataset; let any attached UI refresh.
    _ensure_ui_channel()
    if ui_channel:
        ui_channel.notify_potential_dataset_change(session_id)
    return _format_command_result(result, raw=raw, as_json=as_json)
|
|
872
|
+
|
|
873
|
+
@mcp.tool()
@log_call
def read_log(path: str, offset: int = 0, max_bytes: int = 65536) -> str:
    """Read a slice of a log file.

    Intended for clients that want to display a terminal-like view without pushing MBs of
    output through MCP log notifications.

    Reads are forced to be monotonic per path: if the caller asks for an offset
    earlier than the last position already served for this path, the read resumes
    from that last position instead (prevents duplicated tail output).

    Args:
        path: Absolute path to the log file previously provided by the server.
        offset: Byte offset to start reading from (negative values are clamped to 0).
        max_bytes: Maximum bytes to read (negative values are clamped to 0).

    Returns a compact JSON string: {"path":..., "offset":..., "next_offset":..., "data":...}
    """
    try:
        if path:
            # Remember every path the client has tailed (used elsewhere for cleanup).
            _read_log_paths.add(path)
        if offset < 0:
            offset = 0
        # Fix: a negative max_bytes would make f.read() consume the WHOLE file,
        # defeating the slicing contract of this tool. Clamp to zero instead.
        if max_bytes < 0:
            max_bytes = 0
        if path:
            # Enforce monotonic reads per path.
            last_offset = _read_log_offsets.get(path, 0)
            if offset < last_offset:
                offset = last_offset
        with open(path, "rb") as f:
            f.seek(offset)
            data = f.read(max_bytes)
            next_offset = f.tell()
        if path:
            _read_log_offsets[path] = next_offset
        # Stata logs should be UTF-8; replace undecodable bytes rather than fail.
        text = data.decode("utf-8", errors="replace")
        return json.dumps({"path": path, "offset": offset, "next_offset": next_offset, "data": text})
    except FileNotFoundError:
        # Missing file is treated as "no new data" so clients can poll safely.
        return json.dumps({"path": path, "offset": offset, "next_offset": offset, "data": ""})
    except Exception as e:
        return json.dumps({"path": path, "offset": offset, "next_offset": offset, "data": f"ERROR: {e}"})
|
|
909
|
+
|
|
910
|
+
|
|
911
|
+
@mcp.tool()
@log_call
def find_in_log(
    path: str,
    query: str,
    start_offset: int = 0,
    max_bytes: int = 5_000_000,
    before: int = 2,
    after: int = 2,
    case_sensitive: bool = False,
    regex: bool = False,
    max_matches: int = 50,
) -> str:
    """Find text within a log file and return context windows.

    Args:
        path: Absolute path to the log file previously provided by the server.
        query: Text or regex pattern to search for.
        start_offset: Byte offset to start searching from.
        max_bytes: Maximum bytes to read from the log.
        before: Number of context lines to include before each match.
        after: Number of context lines to include after each match.
        case_sensitive: If True, match case-sensitively.
        regex: If True, treat query as a regular expression.
        max_matches: Maximum number of matches to return.

    Returns a JSON string with matches and offsets:
    {"path":..., "query":..., "start_offset":..., "next_offset":..., "truncated":..., "matches":[...]}.
    """
    def _envelope(next_offset, truncated, matches, error=None):
        # Single place that shapes the response payload (key order matters for clients).
        body = {
            "path": path,
            "query": query,
            "start_offset": start_offset,
            "next_offset": next_offset,
            "truncated": truncated,
            "matches": matches,
        }
        if error is not None:
            body["error"] = error
        return json.dumps(body)

    try:
        if start_offset < 0:
            start_offset = 0
        if max_bytes <= 0:
            return _envelope(start_offset, False, [])

        with open(path, "rb") as fh:
            fh.seek(start_offset)
            blob = fh.read(max_bytes)
            next_offset = fh.tell()

        lines = blob.decode("utf-8", errors="replace").splitlines()

        # Build a line predicate once, outside the scan loop.
        if regex:
            pattern = re.compile(query, flags=0 if case_sensitive else re.IGNORECASE)
            matcher = lambda ln: pattern.search(ln) is not None
        else:
            needle = query if case_sensitive else query.lower()
            matcher = lambda ln: needle in (ln if case_sensitive else ln.lower())

        hits = []
        for lineno, line in enumerate(lines):
            if not matcher(line):
                continue
            lo = max(0, lineno - max(0, before))
            hi = min(len(lines), lineno + max(0, after) + 1)
            hits.append({
                "line_index": lineno,
                "context_start": lo,
                "context_end": hi,
                "context": lines[lo:hi],
            })
            if len(hits) >= max_matches:
                break

        return _envelope(next_offset, len(hits) >= max_matches, hits)
    except FileNotFoundError:
        return _envelope(start_offset, False, [])
    except Exception as exc:
        return _envelope(start_offset, False, [], error=f"ERROR: {exc}")
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
@mcp.tool()
@log_call
async def get_data(start: int = 0, count: int = 50, session_id: str = "default") -> str:
    """
    Returns a slice of the active dataset as a JSON-formatted list of dictionaries.

    Use this to inspect the actual data values in memory. Useful for checking data quality or content.

    Args:
        start: The zero-based index of the first observation to retrieve.
        count: The number of observations to retrieve. Defaults to 50.
        session_id: The ID of the Stata session.
    """
    stata = await session_manager.get_or_create_session(session_id)
    rows = await stata.call("get_data", {"start": start, "count": count})
    return DataResponse(start=start, count=count, data=rows).model_dump_json()
|
|
1034
|
+
|
|
1035
|
+
def _ensure_ui_channel():
    """Lazily construct the module-level UI channel manager (idempotent, best-effort)."""
    global ui_channel
    if ui_channel is not None:
        return
    try:
        from .ui_http import UIChannelManager
        # Seed with the default client proxy; UIChannelManager creates
        # session-specific proxies on demand.
        ui_channel = UIChannelManager(client)
    except Exception:
        # UI channel is optional -- log and carry on without it.
        logger.exception("Failed to initialize UI channel")
|
|
1045
|
+
|
|
1046
|
+
@mcp.tool()
@log_call
def get_ui_channel(session_id: str = "default") -> str:
    """Return localhost HTTP endpoint + bearer token for the extension UI data plane.

    Args:
        session_id: Stata session ID to connect the UI to (default is "default").
    """
    _ensure_ui_channel()
    if ui_channel is None:
        return json.dumps({"error": "UI channel not initialized"})
    channel = ui_channel.get_channel()
    return json.dumps({
        "baseUrl": channel.base_url,
        "token": channel.token,
        "expiresAt": channel.expires_at,
        "capabilities": ui_channel.capabilities(),
        "sessionId": session_id,
    })
|
|
1066
|
+
|
|
1067
|
+
@mcp.tool()
@log_call
async def describe(session_id: str = "default") -> str:
    """Returns the descriptive metadata of the dataset."""
    stata = await session_manager.get_or_create_session(session_id)
    raw_result = await stata.call("run_command_structured", {"code": "describe", "options": {"echo": True}})

    resp = CommandResponse.model_validate(raw_result)
    if resp.success:
        return resp.stdout
    # On failure, surface the error message when one exists.
    return resp.error.message if resp.error else ""
|
|
1080
|
+
|
|
1081
|
+
@mcp.tool()
@log_call
async def list_graphs(session_id: str = "default") -> str:
    """Lists graphs in memory."""
    stata = await session_manager.get_or_create_session(session_id)
    raw_graphs = await stata.call("list_graphs", {})
    return GraphListResponse.model_validate(raw_graphs).model_dump_json()
|
|
1090
|
+
|
|
1091
|
+
@mcp.tool()
@log_call
async def export_graph(graph_name: str | None = None, format: str = "pdf", session_id: str = "default") -> str:
    """Exports a graph to a file.

    Args:
        graph_name: Name of the graph in memory; None lets the worker pick its default.
        format: Output format passed through to the worker (default "pdf").
        session_id: The ID of the Stata session.

    Raises:
        RuntimeError: If the worker call fails for any reason (original cause chained).
    """
    session = await session_manager.get_or_create_session(session_id)
    try:
        return await session.call("export_graph", {"graph_name": graph_name, "format": format})
    except Exception as e:
        # Chain the cause so the original traceback is preserved for debugging.
        raise RuntimeError(f"Failed to export graph: {e}") from e
|
|
1100
|
+
|
|
1101
|
+
@mcp.tool()
@log_call
async def get_help(topic: str, plain_text: bool = False, session_id: str = "default") -> str:
    """Returns help for a Stata command."""
    stata = await session_manager.get_or_create_session(session_id)
    request = {"topic": topic, "plain_text": plain_text}
    return await stata.call("get_help", request)
|
|
1107
|
+
|
|
1108
|
+
@mcp.tool()
async def get_stored_results(session_id: str = "default") -> str:
    """Returns stored r() and e() results.

    Args:
        session_id: The ID of the Stata session.
    """
    # Note: json is imported at module level (used by sibling tools); the
    # previous function-local `import json` was redundant and has been removed.
    session = await session_manager.get_or_create_session(session_id)
    results = await session.call("get_stored_results", {})
    return json.dumps(results)
|
|
1115
|
+
|
|
1116
|
+
@mcp.tool()
async def load_data(source: str, clear: bool = True, as_json: bool = True, raw: bool = False, max_output_lines: int | None = None, session_id: str = "default") -> str:
    """Loads a dataset."""
    # NOTE(review): `as_json` is accepted but not consulted here -- the non-raw
    # branch always returns the JSON envelope. Kept as-is for interface stability.
    stata = await session_manager.get_or_create_session(session_id)
    payload = {"source": source, "options": {"clear": clear, "max_output_lines": max_output_lines}}
    resp = CommandResponse.model_validate(await stata.call("load_data", payload))

    if not raw:
        return resp.model_dump_json()
    # Raw mode: plain stdout on success; error message only when one is present.
    if resp.success or resp.error is None:
        return resp.stdout
    return resp.error.message
|
|
1127
|
+
|
|
1128
|
+
@mcp.tool()
async def codebook(variable: str, as_json: bool = True, trace: bool = False, raw: bool = False, max_output_lines: int | None = None, session_id: str = "default") -> str:
    """Returns codebook for a variable."""
    # NOTE(review): `as_json` is accepted but not consulted -- non-raw mode
    # always returns the JSON envelope. Preserved for interface stability.
    stata = await session_manager.get_or_create_session(session_id)
    payload = {"variable": variable, "options": {"trace": trace, "max_output_lines": max_output_lines}}
    resp = CommandResponse.model_validate(await stata.call("codebook", payload))

    if not raw:
        return resp.model_dump_json()
    if resp.success or resp.error is None:
        return resp.stdout
    return resp.error.message
|
|
1138
|
+
|
|
1139
|
+
@mcp.tool()
@log_call
async def run_do_file(
    path: str,
    ctx: Context | None = None,
    echo: bool = True,
    as_json: bool = True,
    trace: bool = False,
    raw: bool = False,
    max_output_lines: int | None = None,
    cwd: str | None = None,
    session_id: str = "default",
) -> str:
    """
    Executes a .do file.

    Stata output is written to a temporary log file on disk.
    The server emits a single `notifications/logMessage` event containing the log file path
    (JSON payload: {"event":"log_path","path":"..."}) so the client can tail it locally.
    If the client supplies a progress callback/token, progress updates are emitted via
    `notifications/progress`.

    Args:
        path: Path to the .do file to execute.
        ctx: FastMCP-injected request context (used to send MCP notifications). Optional for direct Python calls.
        echo: If True, includes command in output.
        as_json: If True, returns JSON envelope.
        trace: If True, enables trace mode.
        raw: If True, returns raw output only.
        max_output_lines: If set, truncates stdout to this many lines for token efficiency.
        cwd: Working directory for the run, forwarded to the worker.
        session_id: ID of the Stata session to run in (created on demand).
        Note: This tool always uses log-file streaming semantics; there is no non-streaming mode.
    """
    # MCP session used for notifications; None when called directly from Python.
    session = getattr(getattr(ctx, "request_context", None), "session", None) if ctx is not None else None

    async def notify_log(text: str) -> None:
        # Forward one chunk of Stata log output to the client, then sniff it
        # for the {"event":"log_path",...} payload. `ctx` is non-None whenever
        # `session` is set.
        if session is None:
            return
        if not _should_stream_smcl_chunk(text, ctx.request_id):
            return
        await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
        try:
            payload = json.loads(text)
            if isinstance(payload, dict) and payload.get("event") == "log_path":
                if ctx.request_id is not None:
                    _request_log_paths[str(ctx.request_id)] = payload.get("path")
        except Exception:
            # Most chunks are plain SMCL text, not JSON -- ignore parse failures.
            return

    # Progress notifications are only emitted when the client supplied a progress token.
    progress_token = None
    if ctx is not None and getattr(ctx, "request_context", None) is not None and getattr(ctx.request_context, "meta", None) is not None:
        progress_token = getattr(ctx.request_context.meta, "progressToken", None)

    async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
        if session is None or progress_token is None:
            return
        await session.send_progress_notification(
            progress_token=progress_token,
            progress=progress,
            total=total,
            message=message,
            related_request_id=ctx.request_id,
        )

    stata_session = await session_manager.get_or_create_session(session_id)
    result_dict = await stata_session.call(
        "run_do_file",
        {
            "path": path,
            "options": {
                "echo": echo,
                "trace": trace,
                "max_output_lines": max_output_lines,
                "cwd": cwd,
                # Ask the worker to announce finished graphs (exported as SVG),
                # keyed by this request id so the client can correlate them.
                "emit_graph_ready": True,
                "graph_ready_task_id": ctx.request_id if ctx else None,
                "graph_ready_format": "svg",
            }
        },
        notify_log=notify_log if session is not None else _noop_log,
        notify_progress=notify_progress if progress_token is not None else None,
    )

    result = CommandResponse.model_validate(result_dict)

    # ui_channel.notify_potential_dataset_change()

    return _format_command_result(result, raw=raw, as_json=as_json)
|
|
1226
|
+
|
|
1227
|
+
@mcp.resource("stata://data/summary")
async def get_summary() -> str:
    """Returns output of summarize."""
    stata = await session_manager.get_or_create_session("default")
    raw_result = await stata.call("run_command_structured", {"code": "summarize", "options": {"echo": True}})

    resp = CommandResponse.model_validate(raw_result)
    if resp.success:
        return resp.stdout
    return resp.error.message if resp.error else ""
|
|
1239
|
+
|
|
1240
|
+
@mcp.resource("stata://data/metadata")
async def get_metadata() -> str:
    """Returns output of describe."""
    stata = await session_manager.get_or_create_session("default")
    raw_result = await stata.call("run_command_structured", {"code": "describe", "options": {"echo": True}})

    resp = CommandResponse.model_validate(raw_result)
    if resp.success:
        return resp.stdout
    return resp.error.message if resp.error else ""
|
|
1252
|
+
|
|
1253
|
+
@mcp.resource("stata://graphs/list")
@log_call
async def list_graphs_resource() -> str:
    """Resource wrapper for the graph list (uses tool list_graphs)."""
    return await list_graphs(session_id="default")
|
|
1258
|
+
|
|
1259
|
+
@mcp.tool()
async def get_variable_list(session_id: str = "default") -> str:
    """Returns JSON list of all variables."""
    stata = await session_manager.get_or_create_session(session_id)
    raw_vars = await stata.call("list_variables_structured", {})
    return VariablesResponse.model_validate(raw_vars).model_dump_json()
|
|
1267
|
+
|
|
1268
|
+
@mcp.resource("stata://variables/list")
async def get_variable_list_resource() -> str:
    """Resource wrapper for the variable list."""
    return await get_variable_list(session_id="default")
|
|
1272
|
+
|
|
1273
|
+
@mcp.resource("stata://results/stored")
async def get_stored_results_resource() -> str:
    """Returns stored r() and e() results."""
    stata = await session_manager.get_or_create_session("default")
    stored = await stata.call("get_stored_results", {})
    return json.dumps(stored)
|
|
1279
|
+
|
|
1280
|
+
@mcp.tool()
async def export_graphs_all(session_id: str = "default") -> str:
    """
    Exports all graphs in memory to file paths.

    Returns a JSON envelope listing graph names and file paths.
    The agent can open SVG files directly to verify visuals (titles/labels/colors/legends).
    """
    stata = await session_manager.get_or_create_session(session_id)
    raw_exports = await stata.call("export_graphs_all", {})
    envelope = GraphExportResponse.model_validate(raw_exports)
    return envelope.model_dump_json(exclude_none=False)
|
|
1293
|
+
|
|
1294
|
+
def main():
    """CLI entry point: handle --version, patch the environment, then run the MCP server."""
    if "--version" in sys.argv:
        print(SERVER_VERSION)
        return

    # Fix for macOS environments where sys.executable might be a shim that calls 'realpath'.
    # On some macOS versions (pre-Monterey) or minimal environments, 'realpath' is missing,
    # causing shims (like those from uv or pyenv) to fail.
    if sys.platform == "darwin":
        try:
            real_py = os.path.realpath(sys.executable)
            if real_py != sys.executable:
                # Ensure spawned worker processes use the resolved interpreter.
                multiprocessing.set_executable(real_py)
        except Exception:
            # Best-effort only; fall back to the default executable.
            pass

    # Filter non-JSON output off stdout to keep stdio transport clean.
    _install_stdout_filter()

    setup_logging()

    # Initialize UI channel with default session proxy logic if needed
    # (Simplified for now, UI might only show default session)
    # NOTE(review): `ui_channel` is declared global here but never assigned in
    # this function -- presumably left over from an earlier design; confirm.
    global ui_channel

    async def init_sessions():
        # Start the session manager before the MCP server's own event loop takes over.
        await session_manager.start()
        # We need a client-like object for UIChannelManager.
        # This is a bit tricky since it's now multi-session.
        # For now, we'll try to find a way to make UIChannelManager work or disable it.
        # Let's use the default session's worker proxy if it was a real client.
        # But for now, we'll skip UIChannelManager integration or keep it limited.
        pass

    # NOTE(review): asyncio.run() creates and closes its own loop before
    # mcp.run() starts another -- verify session_manager state survives the
    # loop handoff.
    asyncio.run(init_sessions())

    mcp.run()

if __name__ == "__main__":
    main()
|