mcp_stata-1.22.1-cp311-abi3-macosx_11_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_stata/server.py ADDED
@@ -0,0 +1,1333 @@
1
+ from __future__ import annotations
2
+ import anyio
3
+ import asyncio
4
+ from dataclasses import dataclass
5
+ from datetime import datetime, timezone
6
+ from importlib.metadata import PackageNotFoundError, version
7
+ from mcp.server.fastmcp import Context, FastMCP
8
+ from mcp.server.fastmcp.utilities import logging as fastmcp_logging
9
+ import mcp.types as types
10
+ from .stata_client import StataClient
11
+ from .models import (
12
+ ErrorEnvelope,
13
+ CommandResponse,
14
+ DataResponse,
15
+ GraphListResponse,
16
+ VariableInfo,
17
+ VariablesResponse,
18
+ GraphInfo,
19
+ GraphExport,
20
+ GraphExportResponse,
21
+ SessionInfo,
22
+ SessionListResponse,
23
+ )
24
+ from .sessions import SessionManager
25
+ import logging
26
+ import sys
27
+ import json
28
+ import os
29
+ import multiprocessing
30
+ import re
31
+ import traceback
32
+ import uuid
33
+ from functools import wraps
34
+ from typing import Any, Dict, Optional
35
+ import threading
36
+
37
+ from .ui_http import UIChannelManager
38
+
39
+
40
+ # Configure logging
41
+ logger = logging.getLogger("mcp_stata")
42
+ payload_logger = logging.getLogger("mcp_stata.payloads")
43
+ _LOGGING_CONFIGURED = False
44
+
45
+ def get_server_version() -> str:
46
+ """Determine the server version from package metadata or fallback."""
47
+ try:
48
+ return version("mcp-stata")
49
+ except PackageNotFoundError:
50
+ # If not installed, try to find version in pyproject.toml near this file
51
+ try:
52
+ # We are in src/mcp_stata/server.py, pyproject.toml is at ../../pyproject.toml
53
+ base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
54
+ pyproject_path = os.path.join(base_dir, "pyproject.toml")
55
+ if os.path.exists(pyproject_path):
56
+ with open(pyproject_path, "r") as f:
58
+ content = f.read()
59
+ match = re.search(r'^version\s*=\s*["\']([^"\']+)["\']', content, re.MULTILINE)
60
+ if match:
61
+ return match.group(1)
62
+ except Exception:
63
+ pass
64
+ return "unknown"
65
+
66
+ SERVER_VERSION = get_server_version()
67
+
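As a quick illustration (not part of the shipped file), the pyproject fallback above resolves the version with the same regular expression used in get_server_version:

    import re

    content = 'name = "mcp-stata"\nversion = "1.22.1"\n'
    match = re.search(r'^version\s*=\s*["\']([^"\']+)["\']', content, re.MULTILINE)
    assert match is not None and match.group(1) == "1.22.1"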
68
+ def setup_logging():
69
+ global _LOGGING_CONFIGURED
70
+ if _LOGGING_CONFIGURED:
71
+ return
72
+ _LOGGING_CONFIGURED = True
73
+ log_level = os.getenv("MCP_STATA_LOGLEVEL", "DEBUG").upper()
74
+ app_handler = logging.StreamHandler(sys.stderr)
75
+ app_handler.setLevel(getattr(logging, log_level, logging.DEBUG))
76
+ app_handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
77
+
78
+ mcp_handler = logging.StreamHandler(sys.stderr)
79
+ mcp_handler.setLevel(getattr(logging, log_level, logging.DEBUG))
80
+ mcp_handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
81
+
82
+ payload_handler = logging.StreamHandler(sys.stderr)
83
+ payload_handler.setLevel(getattr(logging, log_level, logging.DEBUG))
84
+ payload_handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
85
+
86
+ root_logger = logging.getLogger()
87
+ root_logger.handlers = []
88
+ root_logger.setLevel(logging.WARNING)
89
+
90
+ for name, item in logging.root.manager.loggerDict.items():
91
+ if not isinstance(item, logging.Logger):
92
+ continue
93
+ item.handlers = []
94
+ item.propagate = False
95
+ if item.level == logging.NOTSET:
96
+ item.setLevel(getattr(logging, log_level, logging.DEBUG))
97
+
98
+ logger.handlers = [app_handler]
99
+ logger.propagate = False
100
+
101
+ payload_logger.handlers = [payload_handler]
102
+ payload_logger.propagate = False
103
+
104
+ mcp_logger = logging.getLogger("mcp.server")
105
+ mcp_logger.handlers = [mcp_handler]
106
+ mcp_logger.propagate = False
107
+ mcp_logger.setLevel(getattr(logging, log_level, logging.DEBUG))
108
+
109
+ mcp_lowlevel = logging.getLogger("mcp.server.lowlevel.server")
110
+ mcp_lowlevel.handlers = [mcp_handler]
111
+ mcp_lowlevel.propagate = False
112
+ mcp_lowlevel.setLevel(getattr(logging, log_level, logging.DEBUG))
113
+
114
+ mcp_root = logging.getLogger("mcp")
115
+ mcp_root.handlers = [mcp_handler]
116
+ mcp_root.propagate = False
117
+ mcp_root.setLevel(getattr(logging, log_level, logging.DEBUG))
118
+ if logger.level == logging.NOTSET:
119
+ logger.setLevel(getattr(logging, log_level, logging.DEBUG))
120
+
121
+ logger.info("=== mcp-stata server starting ===")
122
+ logger.info("mcp-stata version: %s", SERVER_VERSION)
123
+ logger.info("STATA_PATH env at startup: %s", os.getenv("STATA_PATH", "<not set>"))
124
+ logger.info("LOG_LEVEL: %s", log_level)
125
+
126
+
127
+
128
+ # Initialize FastMCP
129
+ mcp = FastMCP("mcp_stata")
130
+ # Set version on the underlying server to expose it in InitializeResult
131
+ mcp._mcp_server.version = SERVER_VERSION
132
+
133
+ session_manager = SessionManager()
134
+
135
+ class StataClientProxy:
136
+ """Proxy for StataClient that routes calls to a StataSession (via worker process)."""
137
+ def __init__(self, session_id: str = "default"):
138
+ self.session_id = session_id
139
+
140
+ def _call_sync(self, method: str, args: dict[str, Any]) -> Any:
141
+ try:
142
+ loop = asyncio.get_running_loop()
143
+ except RuntimeError:
144
+ loop = None
145
+
146
+ async def _run():
147
+ session = await session_manager.get_or_create_session(self.session_id)
148
+ return await session.call(method, args)
149
+
150
+ if loop and loop.is_running():
151
+ # If we're in a thread different from the loop's thread
152
+ # (which is true for UI HTTP handler threads)
153
+ if threading.current_thread() is not threading.main_thread():  # simplified: assumes the loop runs on the main thread
155
+ future = asyncio.run_coroutine_threadsafe(_run(), loop)
156
+ return future.result()
157
+ else:
158
+ # If we're on the main thread but in a loop, we can't block.
159
+ # This case shouldn't happen for UIChannelManager but might for tests.
160
+ # For tests, delegate to anyio.from_thread.run (requires an anyio worker-thread context).
161
+ return anyio.from_thread.run(_run)
162
+ else:
163
+ return asyncio.run(_run())
164
+
165
+ def get_dataset_state(self) -> dict[str, Any]:
166
+ return self._call_sync("get_dataset_state", {})
167
+
168
+ def get_arrow_stream(self, **kwargs) -> bytes:
169
+ return self._call_sync("get_arrow_stream", kwargs)
170
+
171
+ def list_variables_rich(self) -> list[dict[str, Any]]:
172
+ return self._call_sync("list_variables_rich", {})
173
+
174
+ def compute_view_indices(self, filter_expr: str) -> list[int]:
175
+ return self._call_sync("compute_view_indices", {"filter_expr": filter_expr})
176
+
177
+ def validate_filter_expr(self, filter_expr: str):
178
+ return self._call_sync("validate_filter_expr", {"filter_expr": filter_expr})
179
+
180
+ def get_page(self, **kwargs):
181
+ return self._call_sync("get_page", kwargs)
182
+
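Illustrative sketch (not from the package) of the pattern StataClientProxy relies on when a UI HTTP handler thread must call into the running event loop: run_coroutine_threadsafe hands the coroutine to the loop's thread and returns a concurrent future the caller can block on.

    import asyncio
    import threading

    async def fetch() -> dict:
        return {"nobs": 74}  # stand-in for `await session.call(method, args)`

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    # From a non-loop thread: schedule on the loop, block for the result.
    future = asyncio.run_coroutine_threadsafe(fetch(), loop)
    print(future.result())  # {'nobs': 74}
    loop.call_soon_threadsafe(loop.stop)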
183
+ client = StataClientProxy()
184
+ ui_channel = None
185
+
186
+ def _ensure_ui_channel():
187
+ global ui_channel
188
+ if ui_channel is None:
189
+ try:
190
+ from .ui_http import UIChannelManager
191
+ # Pass the default client proxy. UIChannelManager will create
+ # session-specific proxies as needed.
+ ui_channel = UIChannelManager(client)
192
+ except Exception:
193
+ logger.exception("Failed to initialize UI channel")
194
+
195
+ @mcp.tool()
196
+ async def create_session(session_id: str) -> str:
197
+ """Create a new Stata session.
198
+
199
+ Args:
200
+ session_id: A unique identifier for the new session.
201
+ """
202
+ await session_manager.get_or_create_session(session_id)
203
+ return json.dumps({"status": "created", "session_id": session_id})
204
+
205
+ @mcp.tool()
206
+ async def stop_session(session_id: str) -> str:
207
+ """Stop and terminate a Stata session.
208
+
209
+ Args:
210
+ session_id: The identifier of the session to stop.
211
+ """
212
+ await session_manager.stop_session(session_id)
213
+ return json.dumps({"status": "stopped", "session_id": session_id})
214
+
215
+ @mcp.tool()
216
+ def list_sessions() -> str:
217
+ """List all active Stata sessions and their status."""
218
+ sessions = session_manager.list_sessions()
219
+ return SessionListResponse(sessions=sessions).model_dump_json()
220
+
221
+
222
+ async def _noop_log(_text: str) -> None:
223
+ return
224
+
225
+ @dataclass
226
+ class BackgroundTask:
227
+ task_id: str
228
+ kind: str
229
+ task: Optional[asyncio.Task]
230
+ created_at: datetime
231
+ log_path: Optional[str] = None
232
+ result: Optional[str] = None
233
+ error: Optional[str] = None
234
+ done: bool = False
235
+
236
+
237
+ _background_tasks: Dict[str, BackgroundTask] = {}
238
+ _request_log_paths: Dict[str, str] = {}
239
+ _read_log_paths: set[str] = set()
240
+ _read_log_offsets: Dict[str, int] = {}
241
+ _STDOUT_FILTER_INSTALLED = False
242
+
243
+
244
+ def _install_stdout_filter() -> None:
245
+ """
246
+ Redirect process stdout to a pipe and forward only JSON-RPC lines to the
247
+ original stdout. Any non-JSON output (e.g., Stata noise) is sent to stderr.
248
+ """
249
+ global _STDOUT_FILTER_INSTALLED
250
+ if _STDOUT_FILTER_INSTALLED:
251
+ return
252
+ _STDOUT_FILTER_INSTALLED = True
253
+
254
+ try:
255
+ # Flush any pending output before redirecting.
256
+ try:
257
+ sys.stdout.flush()
258
+ except Exception:
259
+ pass
260
+
261
+ original_stdout_fd = os.dup(1)
262
+ read_fd, write_fd = os.pipe()
263
+ os.dup2(write_fd, 1)
264
+ os.close(write_fd)
265
+
266
+ def _forward_stdout() -> None:
267
+ buffer = b""
268
+ while True:
269
+ try:
270
+ chunk = os.read(read_fd, 4096)
271
+ except Exception:
272
+ break
273
+ if not chunk:
274
+ break
275
+ buffer += chunk
276
+ while b"\n" in buffer:
277
+ line, buffer = buffer.split(b"\n", 1)
278
+ line_with_nl = line + b"\n"
279
+ stripped = line.lstrip()
280
+ if stripped:
281
+ try:
282
+ payload = json.loads(stripped)
283
+ if isinstance(payload, dict) and payload.get("jsonrpc"):
284
+ os.write(original_stdout_fd, line_with_nl)
285
+ elif isinstance(payload, list) and any(
286
+ isinstance(item, dict) and item.get("jsonrpc") for item in payload
287
+ ):
288
+ os.write(original_stdout_fd, line_with_nl)
289
+ else:
290
+ os.write(2, line_with_nl)
291
+ except Exception:
292
+ os.write(2, line_with_nl)
293
+ if buffer:
294
+ stripped = buffer.lstrip()
295
+ if stripped:
296
+ try:
297
+ payload = json.loads(stripped)
298
+ if isinstance(payload, dict) and payload.get("jsonrpc"):
299
+ os.write(original_stdout_fd, buffer)
300
+ elif isinstance(payload, list) and any(
301
+ isinstance(item, dict) and item.get("jsonrpc") for item in payload
302
+ ):
303
+ os.write(original_stdout_fd, buffer)
304
+ else:
305
+ os.write(2, buffer)
306
+ except Exception:
307
+ os.write(2, buffer)
308
+
309
+ try:
310
+ os.close(read_fd)
311
+ except Exception:
312
+ pass
313
+
314
+ t = threading.Thread(target=_forward_stdout, name="mcp-stdout-filter", daemon=True)
315
+ t.start()
316
+ except Exception:
317
+ _STDOUT_FILTER_INSTALLED = False
318
+
319
+
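A minimal sketch (illustrative, mirroring the checks above) of the classification rule: a line survives on stdout only if it parses as a JSON-RPC message or batch; anything else, such as stray Stata output, is diverted to stderr.

    import json

    def is_jsonrpc(line: bytes) -> bool:
        stripped = line.lstrip()
        if not stripped:
            return False
        try:
            payload = json.loads(stripped)
        except Exception:
            return False
        if isinstance(payload, dict):
            return bool(payload.get("jsonrpc"))
        if isinstance(payload, list):
            return any(isinstance(i, dict) and i.get("jsonrpc") for i in payload)
        return False

    assert is_jsonrpc(b'{"jsonrpc": "2.0", "id": 1, "result": {}}')
    assert not is_jsonrpc(b". sysuse auto")  # Stata noise -> stderr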
320
+ def _register_task(task_info: BackgroundTask, max_tasks: int = 100) -> None:
321
+ _background_tasks[task_info.task_id] = task_info
322
+ if len(_background_tasks) <= max_tasks:
323
+ return
324
+ completed = [task for task in _background_tasks.values() if task.done]
325
+ completed.sort(key=lambda item: item.created_at)
326
+ for task in completed[: max(0, len(_background_tasks) - max_tasks)]:
327
+ _background_tasks.pop(task.task_id, None)
328
+
329
+
330
+ def _format_command_result(result, raw: bool, as_json: bool) -> str:
331
+ if raw:
332
+ if result.success:
333
+ return result.log_path or ""
334
+ if result.error:
335
+ msg = result.error.message
336
+ if result.error.rc is not None:
337
+ msg = f"{msg}\nrc={result.error.rc}"
338
+ return msg
339
+ return result.log_path or ""
340
+
341
+ # Note: we used to clear result.stdout here for token efficiency,
342
+ # but that conflicts with requirements and breaks E2E tests that
343
+ # expect results in the return value.
344
+
345
+ # Both branches return the same JSON envelope; as_json is retained
346
+ # for signature compatibility.
347
+ return result.model_dump_json()
348
+
349
+
350
+ async def _wait_for_log_path(task_info: BackgroundTask) -> None:
351
+ while task_info.log_path is None and not task_info.done:
352
+ await anyio.sleep(0.01)
353
+
354
+
355
+ async def _notify_task_done(session: object | None, task_info: BackgroundTask, request_id: object | None) -> None:
356
+ if session is None:
357
+ return
358
+ payload = {
359
+ "event": "task_done",
360
+ "task_id": task_info.task_id,
361
+ "status": "done" if task_info.done else "unknown",
362
+ "log_path": task_info.log_path,
363
+ "error": task_info.error,
364
+ }
365
+ try:
366
+ await session.send_log_message(level="info", data=json.dumps(payload), related_request_id=request_id)
367
+ except Exception:
368
+ return
369
+
370
+
371
+ def _debug_notification(kind: str, payload: object, request_id: object | None = None) -> None:
372
+ try:
373
+ serialized = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
374
+ except Exception:
375
+ serialized = str(payload)
376
+ payload_logger.info("MCP notify %s request_id=%s payload=%s", kind, request_id, serialized)
377
+
378
+
379
+ async def _notify_tool_error(ctx: Context | None, tool_name: str, exc: Exception) -> None:
380
+ if ctx is None:
381
+ return
382
+ session = ctx.request_context.session
383
+ if session is None:
384
+ return
385
+ task_id = None
386
+ meta = ctx.request_context.meta
387
+ if meta is not None:
388
+ task_id = getattr(meta, "task_id", None) or getattr(meta, "taskId", None)
389
+ payload = {
390
+ "event": "tool_error",
391
+ "tool": tool_name,
392
+ "error": str(exc),
393
+ "traceback": traceback.format_exc(),
394
+ }
395
+ if task_id is not None:
396
+ payload["task_id"] = task_id
397
+ try:
398
+ await session.send_log_message(
399
+ level="error",
400
+ data=json.dumps(payload),
401
+ related_request_id=ctx.request_id,
402
+ )
403
+ except Exception:
404
+ logger.exception("Failed to emit tool_error notification for %s", tool_name)
405
+
406
+
407
+ def _log_tool_call(tool_name: str, ctx: Context | None = None) -> None:
408
+ request_id = None
409
+ if ctx is not None:
410
+ request_id = getattr(ctx, "request_id", None)
411
+ logger.info("MCP tool call: %s request_id=%s", tool_name, request_id)
412
+
413
+ def _should_stream_smcl_chunk(text: str, request_id: object | None) -> bool:
414
+ if request_id is None:
415
+ return True
416
+ try:
417
+ payload = json.loads(text)
418
+ if isinstance(payload, dict) and payload.get("event"):
419
+ return True
420
+ except Exception:
421
+ pass
422
+ log_path = _request_log_paths.get(str(request_id))
423
+ if log_path and log_path in _read_log_paths:
424
+ return False
425
+ return True
426
+
427
+
428
+ def _attach_task_id(ctx: Context | None, task_id: str) -> None:
429
+ if ctx is None:
430
+ return
431
+ meta = ctx.request_context.meta
432
+ if meta is None:
433
+ meta = types.RequestParams.Meta()
434
+ ctx.request_context.meta = meta
435
+ try:
436
+ setattr(meta, "task_id", task_id)
437
+ except Exception:
438
+ logger.debug("Unable to attach task_id to request meta", exc_info=True)
439
+
440
+
441
+ def _extract_ctx(args: tuple[object, ...], kwargs: dict[str, object]) -> Context | None:
442
+ ctx = kwargs.get("ctx")
443
+ if isinstance(ctx, Context):
444
+ return ctx
445
+ for arg in args:
446
+ if isinstance(arg, Context):
447
+ return arg
448
+ return None
449
+
450
+
451
+ def log_call(func):
452
+ """Decorator to log tool and resource calls."""
453
+ if asyncio.iscoroutinefunction(func):
454
+ @wraps(func)
455
+ async def async_inner(*args, **kwargs):
456
+ ctx = _extract_ctx(args, kwargs)
457
+ _log_tool_call(func.__name__, ctx)
458
+ return await func(*args, **kwargs)
459
+ return async_inner
460
+ else:
461
+ @wraps(func)
462
+ def sync_inner(*args, **kwargs):
463
+ ctx = _extract_ctx(args, kwargs)
464
+ _log_tool_call(func.__name__, ctx)
465
+ return func(*args, **kwargs)
466
+ return sync_inner
467
+
468
+
469
+ @mcp.tool()
470
+ @log_call
471
+ async def run_do_file_background(
472
+ path: str,
473
+ ctx: Context | None = None,
474
+ echo: bool = True,
475
+ as_json: bool = True,
476
+ trace: bool = False,
477
+ raw: bool = False,
478
+ max_output_lines: int | None = None,
479
+ cwd: str | None = None,
480
+ session_id: str = "default",
481
+ ) -> str:
482
+ """Run a Stata do-file in the background and return a task id.
483
+
484
+ Notifications:
485
+ - logMessage: {"event":"log_path","path":"..."}
486
+ - logMessage: {"event":"task_done","task_id":"...","status":"done","log_path":"...","error":null}
487
+ """
488
+ session = getattr(getattr(ctx, "request_context", None), "session", None) if ctx is not None else None
489
+ request_id = ctx.request_id if ctx is not None else None
490
+ task_id = uuid.uuid4().hex
491
+ _attach_task_id(ctx, task_id)
492
+ task_info = BackgroundTask(
493
+ task_id=task_id,
494
+ kind="do_file",
495
+ task=None,
496
+ created_at=datetime.now(timezone.utc),
497
+ )
498
+
499
+ async def notify_log(text: str) -> None:
500
+ if session is not None:
501
+ if not _should_stream_smcl_chunk(text, ctx.request_id):
502
+ return
503
+ _debug_notification("logMessage", text, ctx.request_id)
504
+ try:
505
+ await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
506
+ except Exception as e:
507
+ logger.warning("Failed to send logMessage notification: %s", e)
508
+ sys.stderr.write(f"[mcp_stata] ERROR: logMessage send failed: {e!r}\n")
509
+ sys.stderr.flush()
510
+ try:
511
+ payload = json.loads(text)
512
+ if isinstance(payload, dict) and payload.get("event") == "log_path":
513
+ task_info.log_path = payload.get("path")
514
+ if ctx.request_id is not None and task_info.log_path:
515
+ _request_log_paths[str(ctx.request_id)] = task_info.log_path
516
+ except Exception:
517
+ return
518
+
519
+ progress_token = None
520
+ if ctx is not None and getattr(ctx, "request_context", None) is not None and getattr(ctx.request_context, "meta", None) is not None:
521
+ progress_token = getattr(ctx.request_context.meta, "progressToken", None)
522
+
523
+ async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
524
+ if session is None or progress_token is None:
525
+ return
526
+ _debug_notification(
527
+ "progress",
528
+ {"progress": progress, "total": total, "message": message},
529
+ ctx.request_id,
530
+ )
531
+ try:
532
+ await session.send_progress_notification(
533
+ progress_token=progress_token,
534
+ progress=progress,
535
+ total=total,
536
+ message=message,
537
+ related_request_id=ctx.request_id,
538
+ )
539
+ except Exception as exc:
540
+ logger.debug("Progress notification failed: %s", exc)
541
+
542
+ async def _run() -> None:
543
+ try:
544
+ stata_session = await session_manager.get_or_create_session(session_id)
545
+ result_dict = await stata_session.call(
546
+ "run_do_file",
547
+ {
548
+ "path": path,
549
+ "options": {
550
+ "echo": echo,
551
+ "trace": trace,
552
+ "max_output_lines": max_output_lines,
553
+ "cwd": cwd,
554
+ "emit_graph_ready": True,
555
+ "graph_ready_task_id": task_id,
556
+ "graph_ready_format": "svg",
557
+ }
558
+ },
559
+ notify_log=notify_log,
560
+ notify_progress=notify_progress if progress_token is not None else None,
561
+ )
562
+ result = CommandResponse.model_validate(result_dict)
563
+ if not task_info.log_path and result.log_path:
564
+ task_info.log_path = result.log_path
565
+ if result.error:
566
+ task_info.error = result.error.message
567
+ task_info.result = _format_command_result(result, raw=raw, as_json=as_json)
568
+ task_info.done = True
569
+ await _notify_task_done(session, task_info, request_id)
570
+
571
+ _ensure_ui_channel()
572
+ if ui_channel:
573
+ ui_channel.notify_potential_dataset_change(session_id)
574
+ except Exception as exc: # pragma: no cover - defensive
575
+ task_info.done = True
576
+ task_info.error = str(exc)
577
+ await _notify_task_done(session, task_info, request_id)
578
+
579
+ if session is None:
580
+ await _run()
581
+ task_info.task = None
582
+ else:
583
+ task_info.task = asyncio.create_task(_run())
584
+ _register_task(task_info)
585
+ await _wait_for_log_path(task_info)
586
+ return json.dumps({"task_id": task_id, "status": "started", "log_path": task_info.log_path})
587
+
588
+
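On the receiving side, a client might route these notifications as in the following sketch; only the payload shapes come from the tool above, the handler itself is hypothetical:

    import json

    def handle_log_message(data: str) -> None:
        try:
            event = json.loads(data)
        except (TypeError, ValueError):
            return  # raw SMCL log text, not a structured event
        if not isinstance(event, dict):
            return
        if event.get("event") == "log_path":
            print("tail locally:", event["path"])
        elif event.get("event") == "task_done":
            print("task", event["task_id"], "done; log:", event["log_path"])
            if event.get("error"):
                print("error:", event["error"])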
589
+ @mcp.tool()
590
+ @log_call
591
+ def get_task_status(task_id: str, allow_polling: bool = False) -> str:
592
+ """Return task status for background executions.
593
+
594
+ Polling is disabled by default; set allow_polling=True for legacy callers.
595
+ """
596
+ notice = "Prefer task_done logMessage notifications over polling get_task_status."
597
+ if not allow_polling:
598
+ logger.warning(
599
+ "get_task_status called without allow_polling; clients must use task_done logMessage notifications"
600
+ )
601
+ return json.dumps({
602
+ "task_id": task_id,
603
+ "status": "polling_not_allowed",
604
+ "error": "Polling is disabled; use task_done logMessage notifications.",
605
+ "notice": notice,
606
+ })
607
+ logger.warning("get_task_status called; clients should use task_done logMessage notifications instead of polling")
608
+ task_info = _background_tasks.get(task_id)
609
+ if task_info is None:
610
+ return json.dumps({"task_id": task_id, "status": "not_found", "notice": notice})
611
+ return json.dumps({
612
+ "task_id": task_id,
613
+ "status": "done" if task_info.done else "running",
614
+ "kind": task_info.kind,
615
+ "created_at": task_info.created_at.isoformat(),
616
+ "log_path": task_info.log_path,
617
+ "error": task_info.error,
618
+ "notice": notice,
619
+ })
620
+
621
+
622
+ @mcp.tool()
623
+ @log_call
624
+ def get_task_result(task_id: str, allow_polling: bool = False) -> str:
625
+ """Return task result for background executions.
626
+
627
+ Polling is disabled by default; set allow_polling=True for legacy callers.
628
+ """
629
+ notice = "Prefer task_done logMessage notifications over polling get_task_result."
630
+ if not allow_polling:
631
+ logger.warning(
632
+ "get_task_result called without allow_polling; clients must use task_done logMessage notifications"
633
+ )
634
+ return json.dumps({
635
+ "task_id": task_id,
636
+ "status": "polling_not_allowed",
637
+ "error": "Polling is disabled; use task_done logMessage notifications.",
638
+ "notice": notice,
639
+ })
640
+ logger.warning("get_task_result called; clients should use task_done logMessage notifications instead of polling")
641
+ task_info = _background_tasks.get(task_id)
642
+ if task_info is None:
643
+ return json.dumps({"task_id": task_id, "status": "not_found", "notice": notice})
644
+ if not task_info.done:
645
+ return json.dumps({
646
+ "task_id": task_id,
647
+ "status": "running",
648
+ "log_path": task_info.log_path,
649
+ "notice": notice,
650
+ })
651
+ return json.dumps({
652
+ "task_id": task_id,
653
+ "status": "done",
654
+ "log_path": task_info.log_path,
655
+ "error": task_info.error,
656
+ "notice": notice,
657
+ "result": task_info.result,
658
+ })
659
+
660
+
661
+ @mcp.tool()
662
+ @log_call
663
+ def cancel_task(task_id: str) -> str:
664
+ """Request cancellation of a background task."""
665
+ task_info = _background_tasks.get(task_id)
666
+ if task_info is None:
667
+ return json.dumps({"task_id": task_id, "status": "not_found"})
668
+ if task_info.task and not task_info.task.done():
669
+ task_info.task.cancel()
670
+ return json.dumps({"task_id": task_id, "status": "cancelling"})
671
+ return json.dumps({"task_id": task_id, "status": "done", "log_path": task_info.log_path})
672
+
673
+
674
+ @mcp.tool()
675
+ @log_call
676
+ async def run_command_background(
677
+ code: str,
678
+ ctx: Context | None = None,
679
+ echo: bool = True,
680
+ as_json: bool = True,
681
+ trace: bool = False,
682
+ raw: bool = False,
683
+ max_output_lines: int | None = None,
684
+ cwd: str | None = None,
685
+ session_id: str = "default",
686
+ ) -> str:
687
+ """Run a Stata command in the background and return a task id.
688
+
689
+ Notifications:
690
+ - logMessage: {"event":"log_path","path":"..."}
691
+ - logMessage: {"event":"task_done","task_id":"...","status":"done","log_path":"...","error":null}
692
+ """
693
+ session = getattr(getattr(ctx, "request_context", None), "session", None) if ctx is not None else None
694
+ request_id = ctx.request_id if ctx is not None else None
695
+ task_id = uuid.uuid4().hex
696
+ _attach_task_id(ctx, task_id)
697
+ task_info = BackgroundTask(
698
+ task_id=task_id,
699
+ kind="command",
700
+ task=None,
701
+ created_at=datetime.now(timezone.utc),
702
+ )
703
+
704
+ async def notify_log(text: str) -> None:
705
+ if session is not None:
706
+ if not _should_stream_smcl_chunk(text, ctx.request_id):
707
+ return
708
+ _debug_notification("logMessage", text, ctx.request_id)
709
+ try:
+ await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
+ except Exception as e:
+ logger.warning("Failed to send logMessage notification: %s", e)
710
+ try:
711
+ payload = json.loads(text)
712
+ if isinstance(payload, dict) and payload.get("event") == "log_path":
713
+ task_info.log_path = payload.get("path")
714
+ if ctx.request_id is not None and task_info.log_path:
715
+ _request_log_paths[str(ctx.request_id)] = task_info.log_path
716
+ except Exception:
717
+ return
718
+
719
+ progress_token = None
720
+ if ctx is not None and getattr(ctx, "request_context", None) is not None and getattr(ctx.request_context, "meta", None) is not None:
721
+ progress_token = getattr(ctx.request_context.meta, "progressToken", None)
722
+
723
+ async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
724
+ if session is None or progress_token is None:
725
+ return
726
+ await session.send_progress_notification(
727
+ progress_token=progress_token,
728
+ progress=progress,
729
+ total=total,
730
+ message=message,
731
+ related_request_id=ctx.request_id,
732
+ )
733
+
734
+ async def _run() -> None:
735
+ try:
736
+ stata_session = await session_manager.get_or_create_session(session_id)
737
+ result_dict = await stata_session.call(
738
+ "run_command",
739
+ {
740
+ "code": code,
741
+ "options": {
742
+ "echo": echo,
743
+ "trace": trace,
744
+ "max_output_lines": max_output_lines,
745
+ "cwd": cwd,
746
+ "emit_graph_ready": True,
747
+ "graph_ready_task_id": task_id,
748
+ "graph_ready_format": "svg",
749
+ }
750
+ },
751
+ notify_log=notify_log,
752
+ notify_progress=notify_progress if progress_token is not None else None,
753
+ )
754
+ result = CommandResponse.model_validate(result_dict)
755
+ if not task_info.log_path and result.log_path:
756
+ task_info.log_path = result.log_path
757
+ if result.error:
758
+ task_info.error = result.error.message
759
+ task_info.result = _format_command_result(result, raw=raw, as_json=as_json)
760
+ task_info.done = True
761
+ await _notify_task_done(session, task_info, request_id)
762
+
763
+ _ensure_ui_channel()
764
+ if ui_channel:
765
+ ui_channel.notify_potential_dataset_change(session_id)
766
+ except Exception as exc: # pragma: no cover - defensive
767
+ task_info.done = True
768
+ task_info.error = str(exc)
769
+ await _notify_task_done(session, task_info, request_id)
770
+
771
+ if session is None:
772
+ await _run()
773
+ task_info.task = None
774
+ else:
775
+ task_info.task = asyncio.create_task(_run())
776
+ _register_task(task_info)
777
+ await _wait_for_log_path(task_info)
778
+ return json.dumps({"task_id": task_id, "status": "started", "log_path": task_info.log_path})
779
+
780
+ @mcp.tool()
781
+ @log_call
782
+ async def run_command(
783
+ code: str,
784
+ ctx: Context | None = None,
785
+ echo: bool = True,
786
+ as_json: bool = True,
787
+ trace: bool = False,
788
+ raw: bool = False,
789
+ max_output_lines: int | None = None,
790
+ cwd: str | None = None,
791
+ session_id: str = "default",
792
+ ) -> str:
793
+ """
794
+ Executes Stata code.
795
+
796
+ This is the primary tool for interacting with Stata.
797
+
798
+ Stata output is written to a temporary log file on disk.
799
+ The server emits a single `notifications/logMessage` event containing the log file path
800
+ (JSON payload: {"event":"log_path","path":"..."}) so the client can tail it locally.
801
+ If the client supplies a progress callback/token, progress updates may also be emitted
802
+ via `notifications/progress`.
803
+
804
+ Args:
805
+ code: The Stata command(s) to execute (e.g., "sysuse auto", "regress price mpg", "summarize").
806
+ ctx: FastMCP-injected request context (used to send MCP notifications). Optional for direct Python calls.
807
+ echo: If True, the command itself is included in the output. Default is True.
808
+ as_json: If True, returns a JSON envelope with rc/stdout/stderr/error.
809
+ trace: If True, enables `set trace on` for deeper error diagnostics (automatically disabled after).
810
+ raw: If True, return raw output/error message rather than a JSON envelope.
811
+ max_output_lines: If set, truncates stdout to this many lines for token efficiency.
812
+ Useful for verbose commands (regress, codebook, etc.).
813
+ Note: This tool always uses log-file streaming semantics; there is no non-streaming mode.
814
+ """
815
+ session = getattr(getattr(ctx, "request_context", None), "session", None) if ctx is not None else None
816
+
817
+ async def notify_log(text: str) -> None:
818
+ if session is None:
819
+ return
820
+ if not _should_stream_smcl_chunk(text, ctx.request_id):
821
+ return
822
+ _debug_notification("logMessage", text, ctx.request_id)
823
+ await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
824
+ try:
825
+ payload = json.loads(text)
826
+ if isinstance(payload, dict) and payload.get("event") == "log_path":
827
+ if ctx.request_id is not None:
828
+ _request_log_paths[str(ctx.request_id)] = payload.get("path")
829
+ except Exception:
830
+ return
831
+
832
+ progress_token = None
833
+ if ctx is not None and getattr(ctx, "request_context", None) is not None and getattr(ctx.request_context, "meta", None) is not None:
834
+ progress_token = getattr(ctx.request_context.meta, "progressToken", None)
835
+
836
+ async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
837
+ if session is None or progress_token is None:
838
+ return
839
+ await session.send_progress_notification(
840
+ progress_token=progress_token,
841
+ progress=progress,
842
+ total=total,
843
+ message=message,
844
+ related_request_id=ctx.request_id,
845
+ )
846
+
847
+
848
+ stata_session = await session_manager.get_or_create_session(session_id)
849
+ result_dict = await stata_session.call(
850
+ "run_command",
851
+ {
852
+ "code": code,
853
+ "options": {
854
+ "echo": echo,
855
+ "trace": trace,
856
+ "max_output_lines": max_output_lines,
857
+ "cwd": cwd,
858
+ "emit_graph_ready": True,
859
+ "graph_ready_task_id": ctx.request_id if ctx else None,
860
+ "graph_ready_format": "svg",
861
+ }
862
+ },
863
+ notify_log=notify_log if session is not None else _noop_log,
864
+ notify_progress=notify_progress if progress_token is not None else None,
865
+ )
866
+
867
+ result = CommandResponse.model_validate(result_dict)
868
+ _ensure_ui_channel()
869
+ if ui_channel:
870
+ ui_channel.notify_potential_dataset_change(session_id)
871
+ return _format_command_result(result, raw=raw, as_json=as_json)
872
+
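For orientation, a minimal caller using the MCP Python client over stdio might look like this sketch; the `mcp-stata` launch command is an assumption, not taken from this file:

    import asyncio
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client

    async def main() -> None:
        params = StdioServerParameters(command="mcp-stata")  # assumed entry point
        async with stdio_client(params) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                result = await session.call_tool(
                    "run_command", {"code": "sysuse auto, clear"}
                )
                print(result.content)

    asyncio.run(main())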
873
+ @mcp.tool()
874
+ @log_call
875
+ def read_log(path: str, offset: int = 0, max_bytes: int = 65536) -> str:
876
+ """Read a slice of a log file.
877
+
878
+ Intended for clients that want to display a terminal-like view without pushing MBs of
879
+ output through MCP log notifications.
880
+
881
+ Args:
882
+ path: Absolute path to the log file previously provided by the server.
883
+ offset: Byte offset to start reading from. Offsets are monotonic per path; an offset below the last position already read is bumped forward to it.
884
+ max_bytes: Maximum bytes to read.
885
+
886
+ Returns a compact JSON string: {"path":..., "offset":..., "next_offset":..., "data":...}
887
+ """
888
+ try:
889
+ if path:
890
+ _read_log_paths.add(path)
891
+ if offset < 0:
892
+ offset = 0
893
+ if path:
894
+ last_offset = _read_log_offsets.get(path, 0)
895
+ if offset < last_offset:
896
+ offset = last_offset
897
+ with open(path, "rb") as f:
898
+ f.seek(offset)
899
+ data = f.read(max_bytes)
900
+ next_offset = f.tell()
901
+ if path:
902
+ _read_log_offsets[path] = next_offset
903
+ text = data.decode("utf-8", errors="replace")
904
+ return json.dumps({"path": path, "offset": offset, "next_offset": next_offset, "data": text})
905
+ except FileNotFoundError:
906
+ return json.dumps({"path": path, "offset": offset, "next_offset": offset, "data": ""})
907
+ except Exception as e:
908
+ return json.dumps({"path": path, "offset": offset, "next_offset": offset, "data": f"ERROR: {e}"})
909
+
910
+
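A hedged sketch of the intended client-side tail loop; call_tool stands in for whatever MCP client invocation is in use and is not defined in this file:

    import asyncio
    import json

    async def tail_log(call_tool, path: str) -> None:
        offset = 0
        while True:  # break once the task_done notification arrives
            raw = await call_tool("read_log", {"path": path, "offset": offset})
            reply = json.loads(raw)
            if reply["data"]:
                print(reply["data"], end="")
            if reply["next_offset"] == offset:
                await asyncio.sleep(0.2)  # nothing new yet
            offset = reply["next_offset"]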
911
+ @mcp.tool()
912
+ @log_call
913
+ def find_in_log(
914
+ path: str,
915
+ query: str,
916
+ start_offset: int = 0,
917
+ max_bytes: int = 5_000_000,
918
+ before: int = 2,
919
+ after: int = 2,
920
+ case_sensitive: bool = False,
921
+ regex: bool = False,
922
+ max_matches: int = 50,
923
+ ) -> str:
924
+ """Find text within a log file and return context windows.
925
+
926
+ Args:
927
+ path: Absolute path to the log file previously provided by the server.
928
+ query: Text or regex pattern to search for.
929
+ start_offset: Byte offset to start searching from.
930
+ max_bytes: Maximum bytes to read from the log.
931
+ before: Number of context lines to include before each match.
932
+ after: Number of context lines to include after each match.
933
+ case_sensitive: If True, match case-sensitively.
934
+ regex: If True, treat query as a regular expression.
935
+ max_matches: Maximum number of matches to return.
936
+
937
+ Returns a JSON string with matches and offsets:
938
+ {"path":..., "query":..., "start_offset":..., "next_offset":..., "truncated":..., "matches":[...]}.
939
+ """
940
+ try:
941
+ if start_offset < 0:
942
+ start_offset = 0
943
+ if max_bytes <= 0:
944
+ return json.dumps({
945
+ "path": path,
946
+ "query": query,
947
+ "start_offset": start_offset,
948
+ "next_offset": start_offset,
949
+ "truncated": False,
950
+ "matches": [],
951
+ })
952
+ with open(path, "rb") as f:
953
+ f.seek(start_offset)
954
+ data = f.read(max_bytes)
955
+ next_offset = f.tell()
956
+
957
+ text = data.decode("utf-8", errors="replace")
958
+ lines = text.splitlines()
959
+
960
+ if regex:
961
+ flags = 0 if case_sensitive else re.IGNORECASE
962
+ pattern = re.compile(query, flags=flags)
963
+ def is_match(line: str) -> bool:
964
+ return pattern.search(line) is not None
965
+ else:
966
+ needle = query if case_sensitive else query.lower()
967
+ def is_match(line: str) -> bool:
968
+ haystack = line if case_sensitive else line.lower()
969
+ return needle in haystack
970
+
971
+ matches = []
972
+ for idx, line in enumerate(lines):
973
+ if not is_match(line):
974
+ continue
975
+ start_idx = max(0, idx - max(0, before))
976
+ end_idx = min(len(lines), idx + max(0, after) + 1)
977
+ context = lines[start_idx:end_idx]
978
+ matches.append({
979
+ "line_index": idx,
980
+ "context_start": start_idx,
981
+ "context_end": end_idx,
982
+ "context": context,
983
+ })
984
+ if len(matches) >= max_matches:
985
+ break
986
+
987
+ truncated = len(matches) >= max_matches
988
+ return json.dumps({
989
+ "path": path,
990
+ "query": query,
991
+ "start_offset": start_offset,
992
+ "next_offset": next_offset,
993
+ "truncated": truncated,
994
+ "matches": matches,
995
+ })
996
+ except FileNotFoundError:
997
+ return json.dumps({
998
+ "path": path,
999
+ "query": query,
1000
+ "start_offset": start_offset,
1001
+ "next_offset": start_offset,
1002
+ "truncated": False,
1003
+ "matches": [],
1004
+ })
1005
+ except Exception as e:
1006
+ return json.dumps({
1007
+ "path": path,
1008
+ "query": query,
1009
+ "start_offset": start_offset,
1010
+ "next_offset": start_offset,
1011
+ "truncated": False,
1012
+ "matches": [],
1013
+ "error": f"ERROR: {e}",
1014
+ })
1015
+
1016
+
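For example (illustrative arguments; call_tool is again a stand-in), scanning a log for Stata return-code lines such as r(111); with three lines of leading context:

    import json

    async def show_errors(call_tool, log_path: str) -> None:
        reply = json.loads(await call_tool("find_in_log", {
            "path": log_path,
            "query": r"^r\(\d+\);",  # Stata return-code lines
            "regex": True,
            "before": 3,
            "after": 0,
        }))
        for match in reply["matches"]:
            print("\n".join(match["context"]))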
1017
+ @mcp.tool()
1018
+ @log_call
1019
+ async def get_data(start: int = 0, count: int = 50, session_id: str = "default") -> str:
1020
+ """
1021
+ Returns a slice of the active dataset as a JSON-formatted list of dictionaries.
1022
+
1023
+ Use this to inspect the actual data values in memory. Useful for checking data quality or content.
1024
+
1025
+ Args:
1026
+ start: The zero-based index of the first observation to retrieve.
1027
+ count: The number of observations to retrieve. Defaults to 50.
1028
+ session_id: The ID of the Stata session.
1029
+ """
1030
+ session = await session_manager.get_or_create_session(session_id)
1031
+ data = await session.call("get_data", {"start": start, "count": count})
1032
+ resp = DataResponse(start=start, count=count, data=data)
1033
+ return resp.model_dump_json()
1034
+
+
1046
+ @mcp.tool()
1047
+ @log_call
1048
+ def get_ui_channel(session_id: str = "default") -> str:
1049
+ """Return localhost HTTP endpoint + bearer token for the extension UI data plane.
1050
+
1051
+ Args:
1052
+ session_id: Stata session ID to connect the UI to (default is "default").
1053
+ """
1054
+ _ensure_ui_channel()
1055
+ if ui_channel is None:
1056
+ return json.dumps({"error": "UI channel not initialized"})
1057
+ info = ui_channel.get_channel()
1058
+ payload = {
1059
+ "baseUrl": info.base_url,
1060
+ "token": info.token,
1061
+ "expiresAt": info.expires_at,
1062
+ "capabilities": ui_channel.capabilities(),
1063
+ "sessionId": session_id,
1064
+ }
1065
+ return json.dumps(payload)
1066
+
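Illustrative only: a UI consumer would parse this payload and attach the token as a bearer credential against baseUrl. The /health path below is hypothetical; only baseUrl and token come from the payload above.

    import json
    import urllib.request

    def probe_channel(channel_json: str) -> None:
        info = json.loads(channel_json)
        req = urllib.request.Request(
            info["baseUrl"] + "/health",  # hypothetical endpoint
            headers={"Authorization": f"Bearer {info['token']}"},
        )
        with urllib.request.urlopen(req) as resp:
            print(resp.status)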
1067
+ @mcp.tool()
1068
+ @log_call
1069
+ async def describe(session_id: str = "default") -> str:
1070
+ """Returns the descriptive metadata of the dataset."""
1071
+ session = await session_manager.get_or_create_session(session_id)
1072
+ result_dict = await session.call("run_command_structured", {"code": "describe", "options": {"echo": True}})
1073
+
1074
+ result = CommandResponse.model_validate(result_dict)
1075
+ if result.success:
1076
+ return result.stdout
1077
+ if result.error:
1078
+ return result.error.message
1079
+ return ""
1080
+
1081
+ @mcp.tool()
1082
+ @log_call
1083
+ async def list_graphs(session_id: str = "default") -> str:
1084
+ """Lists graphs in memory."""
1085
+ session = await session_manager.get_or_create_session(session_id)
1086
+ graphs_dict = await session.call("list_graphs", {})
1087
+
1088
+ graphs = GraphListResponse.model_validate(graphs_dict)
1089
+ return graphs.model_dump_json()
1090
+
1091
+ @mcp.tool()
1092
+ @log_call
1093
+ async def export_graph(graph_name: str | None = None, format: str = "pdf", session_id: str = "default") -> str:
1094
+ """Exports a graph to a file."""
1095
+ session = await session_manager.get_or_create_session(session_id)
1096
+ try:
1097
+ return await session.call("export_graph", {"graph_name": graph_name, "format": format})
1098
+ except Exception as e:
1099
+ raise RuntimeError(f"Failed to export graph: {e}")
1100
+
1101
+ @mcp.tool()
1102
+ @log_call
1103
+ async def get_help(topic: str, plain_text: bool = False, session_id: str = "default") -> str:
1104
+ """Returns help for a Stata command."""
1105
+ session = await session_manager.get_or_create_session(session_id)
1106
+ return await session.call("get_help", {"topic": topic, "plain_text": plain_text})
1107
+
1108
+ @mcp.tool()
1109
+ async def get_stored_results(session_id: str = "default") -> str:
1110
+ """Returns stored r() and e() results."""
1112
+ session = await session_manager.get_or_create_session(session_id)
1113
+ results = await session.call("get_stored_results", {})
1114
+ return json.dumps(results)
1115
+
1116
+ @mcp.tool()
1117
+ async def load_data(source: str, clear: bool = True, as_json: bool = True, raw: bool = False, max_output_lines: int | None = None, session_id: str = "default") -> str:
1118
+ """Loads a dataset."""
1119
+ session = await session_manager.get_or_create_session(session_id)
1120
+ result_dict = await session.call("load_data", {"source": source, "options": {"clear": clear, "max_output_lines": max_output_lines}})
1121
+
1122
+ result = CommandResponse.model_validate(result_dict)
1123
+ # ui_channel.notify_potential_dataset_change()
1124
+ if raw:
1125
+ return result.stdout if result.success else (result.error.message if result.error else result.stdout)
1126
+ return result.model_dump_json()
1127
+
1128
+ @mcp.tool()
1129
+ async def codebook(variable: str, as_json: bool = True, trace: bool = False, raw: bool = False, max_output_lines: int | None = None, session_id: str = "default") -> str:
1130
+ """Returns codebook for a variable."""
1131
+ session = await session_manager.get_or_create_session(session_id)
1132
+ result_dict = await session.call("codebook", {"variable": variable, "options": {"trace": trace, "max_output_lines": max_output_lines}})
1133
+
1134
+ result = CommandResponse.model_validate(result_dict)
1135
+ if raw:
1136
+ return result.stdout if result.success else (result.error.message if result.error else result.stdout)
1137
+ return result.model_dump_json()
1138
+
1139
+ @mcp.tool()
1140
+ @log_call
1141
+ async def run_do_file(
1142
+ path: str,
1143
+ ctx: Context | None = None,
1144
+ echo: bool = True,
1145
+ as_json: bool = True,
1146
+ trace: bool = False,
1147
+ raw: bool = False,
1148
+ max_output_lines: int | None = None,
1149
+ cwd: str | None = None,
1150
+ session_id: str = "default",
1151
+ ) -> str:
1152
+ """
1153
+ Executes a .do file.
1154
+
1155
+ Stata output is written to a temporary log file on disk.
1156
+ The server emits a single `notifications/logMessage` event containing the log file path
1157
+ (JSON payload: {"event":"log_path","path":"..."}) so the client can tail it locally.
1158
+ If the client supplies a progress callback/token, progress updates are emitted via
1159
+ `notifications/progress`.
1160
+
1161
+ Args:
1162
+ path: Path to the .do file to execute.
1163
+ ctx: FastMCP-injected request context (used to send MCP notifications). Optional for direct Python calls.
1164
+ echo: If True, includes command in output.
1165
+ as_json: If True, returns JSON envelope.
1166
+ trace: If True, enables trace mode.
1167
+ raw: If True, returns raw output only.
1168
+ max_output_lines: If set, truncates stdout to this many lines for token efficiency.
1169
+ Note: This tool always uses log-file streaming semantics; there is no non-streaming mode.
1170
+ """
1171
+ session = getattr(getattr(ctx, "request_context", None), "session", None) if ctx is not None else None
1172
+
1173
+ async def notify_log(text: str) -> None:
1174
+ if session is None:
1175
+ return
1176
+ if not _should_stream_smcl_chunk(text, ctx.request_id):
1177
+ return
1178
+ await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
1179
+ try:
1180
+ payload = json.loads(text)
1181
+ if isinstance(payload, dict) and payload.get("event") == "log_path":
1182
+ if ctx.request_id is not None:
1183
+ _request_log_paths[str(ctx.request_id)] = payload.get("path")
1184
+ except Exception:
1185
+ return
1186
+
1187
+ progress_token = None
1188
+ if ctx is not None and getattr(ctx, "request_context", None) is not None and getattr(ctx.request_context, "meta", None) is not None:
1189
+ progress_token = getattr(ctx.request_context.meta, "progressToken", None)
1190
+
1191
+ async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
1192
+ if session is None or progress_token is None:
1193
+ return
1194
+ await session.send_progress_notification(
1195
+ progress_token=progress_token,
1196
+ progress=progress,
1197
+ total=total,
1198
+ message=message,
1199
+ related_request_id=ctx.request_id,
1200
+ )
1201
+
1202
+ stata_session = await session_manager.get_or_create_session(session_id)
1203
+ result_dict = await stata_session.call(
1204
+ "run_do_file",
1205
+ {
1206
+ "path": path,
1207
+ "options": {
1208
+ "echo": echo,
1209
+ "trace": trace,
1210
+ "max_output_lines": max_output_lines,
1211
+ "cwd": cwd,
1212
+ "emit_graph_ready": True,
1213
+ "graph_ready_task_id": ctx.request_id if ctx else None,
1214
+ "graph_ready_format": "svg",
1215
+ }
1216
+ },
1217
+ notify_log=notify_log if session is not None else _noop_log,
1218
+ notify_progress=notify_progress if progress_token is not None else None,
1219
+ )
1220
+
1221
+ result = CommandResponse.model_validate(result_dict)
1222
+
1223
+ # ui_channel.notify_potential_dataset_change()
1224
+
1225
+ return _format_command_result(result, raw=raw, as_json=as_json)
1226
+
1227
+ @mcp.resource("stata://data/summary")
1228
+ async def get_summary() -> str:
1229
+ """Returns output of summarize."""
1230
+ session = await session_manager.get_or_create_session("default")
1231
+ result_dict = await session.call("run_command_structured", {"code": "summarize", "options": {"echo": True}})
1232
+
1233
+ result = CommandResponse.model_validate(result_dict)
1234
+ if result.success:
1235
+ return result.stdout
1236
+ if result.error:
1237
+ return result.error.message
1238
+ return ""
1239
+
1240
+ @mcp.resource("stata://data/metadata")
1241
+ async def get_metadata() -> str:
1242
+ """Returns output of describe."""
1243
+ session = await session_manager.get_or_create_session("default")
1244
+ result_dict = await session.call("run_command_structured", {"code": "describe", "options": {"echo": True}})
1245
+
1246
+ result = CommandResponse.model_validate(result_dict)
1247
+ if result.success:
1248
+ return result.stdout
1249
+ if result.error:
1250
+ return result.error.message
1251
+ return ""
1252
+
1253
+ @mcp.resource("stata://graphs/list")
1254
+ @log_call
1255
+ async def list_graphs_resource() -> str:
1256
+ """Resource wrapper for the graph list (uses tool list_graphs)."""
1257
+ return await list_graphs("default")
1258
+
1259
+ @mcp.tool()
1260
+ async def get_variable_list(session_id: str = "default") -> str:
1261
+ """Returns JSON list of all variables."""
1262
+ session = await session_manager.get_or_create_session(session_id)
1263
+ variables_dict = await session.call("list_variables_structured", {})
1264
+
1265
+ variables = VariablesResponse.model_validate(variables_dict)
1266
+ return variables.model_dump_json()
1267
+
1268
+ @mcp.resource("stata://variables/list")
1269
+ async def get_variable_list_resource() -> str:
1270
+ """Resource wrapper for the variable list."""
1271
+ return await get_variable_list("default")
1272
+
1273
+ @mcp.resource("stata://results/stored")
1274
+ async def get_stored_results_resource() -> str:
1275
+ """Returns stored r() and e() results."""
1276
+ session = await session_manager.get_or_create_session("default")
1277
+ results = await session.call("get_stored_results", {})
1278
+ return json.dumps(results)
1279
+
1280
+ @mcp.tool()
1281
+ async def export_graphs_all(session_id: str = "default") -> str:
1282
+ """
1283
+ Exports all graphs in memory to file paths.
1284
+
1285
+ Returns a JSON envelope listing graph names and file paths.
1286
+ The agent can open SVG files directly to verify visuals (titles/labels/colors/legends).
1287
+ """
1288
+ session = await session_manager.get_or_create_session(session_id)
1289
+ exports_dict = await session.call("export_graphs_all", {})
1290
+
1291
+ exports = GraphExportResponse.model_validate(exports_dict)
1292
+ return exports.model_dump_json(exclude_none=False)
1293
+
1294
+ def main():
1295
+ if "--version" in sys.argv:
1296
+ print(SERVER_VERSION)
1297
+ return
1298
+
1299
+ # Fix for macOS environments where sys.executable might be a shim that calls 'realpath'.
1300
+ # On some macOS versions (pre-Monterey) or minimal environments, 'realpath' is missing,
1301
+ # causing shims (like those from uv or pyenv) to fail.
1302
+ if sys.platform == "darwin":
1303
+ try:
1304
+ real_py = os.path.realpath(sys.executable)
1305
+ if real_py != sys.executable:
1306
+ multiprocessing.set_executable(real_py)
1307
+ except Exception:
1308
+ pass
1309
+
1310
+ # Filter non-JSON output off stdout to keep stdio transport clean.
1311
+ _install_stdout_filter()
1312
+
1313
+ setup_logging()
1314
+
1315
+ # Initialize UI channel with default session proxy logic if needed
1316
+ # (Simplified for now, UI might only show default session)
1317
+ global ui_channel
1318
+
1319
+ async def init_sessions():
1320
+ await session_manager.start()
1321
+ # UIChannelManager needs a client-like object. With multi-session
1322
+ # support that mapping is no longer one-to-one, so the UI channel
1323
+ # is initialized lazily via _ensure_ui_channel rather than here.
1327
+
1328
+ asyncio.run(init_sessions())
1329
+
1330
+ mcp.run()
1331
+
1332
+ if __name__ == "__main__":
1333
+ main()