mcp-stata 1.21.0__cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_stata/server.py ADDED
@@ -0,0 +1,1233 @@
+ from __future__ import annotations
+ import anyio
+ import asyncio
+ from dataclasses import dataclass
+ from datetime import datetime
+ from importlib.metadata import PackageNotFoundError, version
+ from mcp.server.fastmcp import Context, FastMCP
+ from mcp.server.fastmcp.utilities import logging as fastmcp_logging
+ import mcp.types as types
+ from .stata_client import StataClient
+ from .models import (
+     DataResponse,
+     GraphListResponse,
+     VariablesResponse,
+     GraphExportResponse,
+ )
+ import logging
+ import sys
+ import json
+ import os
+ import re
+ import traceback
+ import uuid
+ from functools import wraps
+ from typing import Optional, Dict
+ import threading
+
+ from .ui_http import UIChannelManager
+
+
+ # Configure logging
+ logger = logging.getLogger("mcp_stata")
+ payload_logger = logging.getLogger("mcp_stata.payloads")
+ _LOGGING_CONFIGURED = False
+
+ def get_server_version() -> str:
+     """Determine the server version from package metadata or a local pyproject.toml fallback."""
+     try:
+         return version("mcp-stata")
+     except PackageNotFoundError:
+         # Not installed as a package: look for pyproject.toml near this file.
+         try:
+             # We are in src/mcp_stata/server.py; pyproject.toml is at ../../pyproject.toml.
+             base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+             pyproject_path = os.path.join(base_dir, "pyproject.toml")
+             if os.path.exists(pyproject_path):
+                 with open(pyproject_path, "r") as f:
+                     content = f.read()
+                 match = re.search(r'^version\s*=\s*["\']([^"\']+)["\']', content, re.MULTILINE)
+                 if match:
+                     return match.group(1)
+         except Exception:
+             pass
+     return "unknown"
+
+ SERVER_VERSION = get_server_version()
+
+ def setup_logging():
+     global _LOGGING_CONFIGURED
+     if _LOGGING_CONFIGURED:
+         return
+     _LOGGING_CONFIGURED = True
+     log_level = os.getenv("MCP_STATA_LOGLEVEL", "DEBUG").upper()
+     app_handler = logging.StreamHandler(sys.stderr)
+     app_handler.setLevel(getattr(logging, log_level, logging.DEBUG))
+     app_handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
+
+     mcp_handler = logging.StreamHandler(sys.stderr)
+     mcp_handler.setLevel(getattr(logging, log_level, logging.DEBUG))
+     mcp_handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
+
+     payload_handler = logging.StreamHandler(sys.stderr)
+     payload_handler.setLevel(getattr(logging, log_level, logging.DEBUG))
+     payload_handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
+
+     root_logger = logging.getLogger()
+     root_logger.handlers = []
+     root_logger.setLevel(logging.WARNING)
+
+     for name, item in logging.root.manager.loggerDict.items():
+         if not isinstance(item, logging.Logger):
+             continue
+         item.handlers = []
+         item.propagate = False
+         if item.level == logging.NOTSET:
+             item.setLevel(getattr(logging, log_level, logging.DEBUG))
+
+     logger.handlers = [app_handler]
+     logger.propagate = False
+
+     payload_logger.handlers = [payload_handler]
+     payload_logger.propagate = False
+
+     mcp_logger = logging.getLogger("mcp.server")
+     mcp_logger.handlers = [mcp_handler]
+     mcp_logger.propagate = False
+     mcp_logger.setLevel(getattr(logging, log_level, logging.DEBUG))
+
+     mcp_lowlevel = logging.getLogger("mcp.server.lowlevel.server")
+     mcp_lowlevel.handlers = [mcp_handler]
+     mcp_lowlevel.propagate = False
+     mcp_lowlevel.setLevel(getattr(logging, log_level, logging.DEBUG))
+
+     mcp_root = logging.getLogger("mcp")
+     mcp_root.handlers = [mcp_handler]
+     mcp_root.propagate = False
+     mcp_root.setLevel(getattr(logging, log_level, logging.DEBUG))
+     if logger.level == logging.NOTSET:
+         logger.setLevel(getattr(logging, log_level, logging.DEBUG))
+
+     logger.info("=== mcp-stata server starting ===")
+     logger.info("mcp-stata version: %s", SERVER_VERSION)
+     logger.info("STATA_PATH env at startup: %s", os.getenv("STATA_PATH", "<not set>"))
+     logger.info("LOG_LEVEL: %s", log_level)
+
+
+ # Initialize FastMCP
+ mcp = FastMCP("mcp_stata")
+ # Set version on the underlying server to expose it in InitializeResult
+ mcp._mcp_server.version = SERVER_VERSION
+
+ client = StataClient()
+ ui_channel = UIChannelManager(client)
+
+
+ @dataclass
+ class BackgroundTask:
+     task_id: str
+     kind: str
+     task: Optional[asyncio.Task]
+     created_at: datetime
+     log_path: Optional[str] = None
+     result: Optional[str] = None
+     error: Optional[str] = None
+     done: bool = False
+
+
+ _background_tasks: Dict[str, BackgroundTask] = {}
+ _request_log_paths: Dict[str, str] = {}
+ _read_log_paths: set[str] = set()
+ _read_log_offsets: Dict[str, int] = {}
+ _STDOUT_FILTER_INSTALLED = False
+
+
+ def _install_stdout_filter() -> None:
+     """
+     Redirect process stdout to a pipe and forward only JSON-RPC lines to the
+     original stdout. Any non-JSON output (e.g., Stata noise) is sent to stderr.
+     """
+     global _STDOUT_FILTER_INSTALLED
+     if _STDOUT_FILTER_INSTALLED:
+         return
+     _STDOUT_FILTER_INSTALLED = True
+
+     try:
+         # Flush any pending output before redirecting.
+         try:
+             sys.stdout.flush()
+         except Exception:
+             pass
+
+         original_stdout_fd = os.dup(1)
+         read_fd, write_fd = os.pipe()
+         os.dup2(write_fd, 1)
+         os.close(write_fd)
+
+         def _forward_stdout() -> None:
+             buffer = b""
+             while True:
+                 try:
+                     chunk = os.read(read_fd, 4096)
+                 except Exception:
+                     break
+                 if not chunk:
+                     break
+                 buffer += chunk
+                 while b"\n" in buffer:
+                     line, buffer = buffer.split(b"\n", 1)
+                     line_with_nl = line + b"\n"
+                     stripped = line.lstrip()
+                     if stripped:
+                         try:
+                             payload = json.loads(stripped)
+                             if isinstance(payload, dict) and payload.get("jsonrpc"):
+                                 os.write(original_stdout_fd, line_with_nl)
+                             elif isinstance(payload, list) and any(
+                                 isinstance(item, dict) and item.get("jsonrpc") for item in payload
+                             ):
+                                 os.write(original_stdout_fd, line_with_nl)
+                             else:
+                                 os.write(2, line_with_nl)
+                         except Exception:
+                             os.write(2, line_with_nl)
+             if buffer:
+                 stripped = buffer.lstrip()
+                 if stripped:
+                     try:
+                         payload = json.loads(stripped)
+                         if isinstance(payload, dict) and payload.get("jsonrpc"):
+                             os.write(original_stdout_fd, buffer)
+                         elif isinstance(payload, list) and any(
+                             isinstance(item, dict) and item.get("jsonrpc") for item in payload
+                         ):
+                             os.write(original_stdout_fd, buffer)
+                         else:
+                             os.write(2, buffer)
+                     except Exception:
+                         os.write(2, buffer)
+
+             try:
+                 os.close(read_fd)
+             except Exception:
+                 pass
+
+         t = threading.Thread(target=_forward_stdout, name="mcp-stdout-filter", daemon=True)
+         t.start()
+     except Exception:
+         _STDOUT_FILTER_INSTALLED = False
+
+
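+ # A quick sanity check of the forwarding rule above (illustrative only; it
+ # mirrors the json.loads() gate in _forward_stdout rather than exercising it):
+ #
+ #     line = b'{"jsonrpc": "2.0", "id": 1, "result": {}}'
+ #     payload = json.loads(line)
+ #     assert isinstance(payload, dict) and payload.get("jsonrpc")  # forwarded to stdout
+ #
+ #     # A bare Stata banner line fails json.loads() and is diverted to fd 2 (stderr).
+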
+ def _register_task(task_info: BackgroundTask, max_tasks: int = 100) -> None:
+     _background_tasks[task_info.task_id] = task_info
+     if len(_background_tasks) <= max_tasks:
+         return
+     completed = [task for task in _background_tasks.values() if task.done]
+     completed.sort(key=lambda item: item.created_at)
+     for task in completed[: max(0, len(_background_tasks) - max_tasks)]:
+         _background_tasks.pop(task.task_id, None)
+
+
+ def _format_command_result(result, raw: bool, as_json: bool) -> str:
+     if raw:
+         if result.success:
+             return result.log_path or ""
+         if result.error:
+             msg = result.error.message
+             if result.error.rc is not None:
+                 msg = f"{msg}\nrc={result.error.rc}"
+             return msg
+         return result.log_path or ""
+
+     # Note: we used to clear result.stdout here for token efficiency,
+     # but that conflicts with requirements and breaks E2E tests that
+     # expect results in the return value.
+
+     # as_json currently has no effect: the JSON envelope is returned either way.
+     return result.model_dump_json()
+
+
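+ # Shape of the two return modes (illustrative; field names follow the result
+ # model in .models and may differ in detail): with raw=True a successful run
+ # returns just the log path string, otherwise the serialized envelope, roughly
+ #     {"success": true, "stdout": "...", "log_path": "...", "error": null}
+ # with error carrying {"message": ..., "rc": ...} on failure.
+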
+ async def _wait_for_log_path(task_info: BackgroundTask) -> None:
+     while task_info.log_path is None and not task_info.done:
+         await anyio.sleep(0.01)
+
+
+ async def _notify_task_done(session: object | None, task_info: BackgroundTask, request_id: object | None) -> None:
+     if session is None:
+         return
+     payload = {
+         "event": "task_done",
+         "task_id": task_info.task_id,
+         "status": "done" if task_info.done else "unknown",
+         "log_path": task_info.log_path,
+         "error": task_info.error,
+     }
+     try:
+         await session.send_log_message(level="info", data=json.dumps(payload), related_request_id=request_id)
+     except Exception:
+         return
+
+
+ def _debug_notification(kind: str, payload: object, request_id: object | None = None) -> None:
+     try:
+         serialized = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
+     except Exception:
+         serialized = str(payload)
+     payload_logger.info("MCP notify %s request_id=%s payload=%s", kind, request_id, serialized)
+
+
+ async def _notify_tool_error(ctx: Context | None, tool_name: str, exc: Exception) -> None:
+     if ctx is None:
+         return
+     session = ctx.request_context.session
+     if session is None:
+         return
+     task_id = None
+     meta = ctx.request_context.meta
+     if meta is not None:
+         task_id = getattr(meta, "task_id", None) or getattr(meta, "taskId", None)
+     payload = {
+         "event": "tool_error",
+         "tool": tool_name,
+         "error": str(exc),
+         "traceback": traceback.format_exc(),
+     }
+     if task_id is not None:
+         payload["task_id"] = task_id
+     try:
+         await session.send_log_message(
+             level="error",
+             data=json.dumps(payload),
+             related_request_id=ctx.request_id,
+         )
+     except Exception:
+         logger.exception("Failed to emit tool_error notification for %s", tool_name)
+
+
+ def _log_tool_call(tool_name: str, ctx: Context | None = None) -> None:
+     request_id = None
+     if ctx is not None:
+         request_id = getattr(ctx, "request_id", None)
+     logger.info("MCP tool call: %s request_id=%s", tool_name, request_id)
+
+ def _should_stream_smcl_chunk(text: str, request_id: object | None) -> bool:
+     if request_id is None:
+         return True
+     try:
+         payload = json.loads(text)
+         if isinstance(payload, dict) and payload.get("event"):
+             return True
+     except Exception:
+         pass
+     log_path = _request_log_paths.get(str(request_id))
+     if log_path and log_path in _read_log_paths:
+         return False
+     return True
+
+
+ def _attach_task_id(ctx: Context | None, task_id: str) -> None:
+     if ctx is None:
+         return
+     meta = ctx.request_context.meta
+     if meta is None:
+         meta = types.RequestParams.Meta()
+         ctx.request_context.meta = meta
+     try:
+         setattr(meta, "task_id", task_id)
+     except Exception:
+         logger.debug("Unable to attach task_id to request meta", exc_info=True)
+
+
+ def _extract_ctx(args: tuple[object, ...], kwargs: dict[str, object]) -> Context | None:
+     ctx = kwargs.get("ctx")
+     if isinstance(ctx, Context):
+         return ctx
+     for arg in args:
+         if isinstance(arg, Context):
+             return arg
+     return None
+
+
+ _mcp_tool = mcp.tool
+ _mcp_resource = mcp.resource
+
+
+ def tool(*tool_args, **tool_kwargs):
+     decorator = _mcp_tool(*tool_args, **tool_kwargs)
+
+     def outer(func):
+         if asyncio.iscoroutinefunction(func):
+             @wraps(func)
+             async def async_inner(*args, **kwargs):
+                 ctx = _extract_ctx(args, kwargs)
+                 _log_tool_call(func.__name__, ctx)
+                 try:
+                     return await func(*args, **kwargs)
+                 except Exception as exc:
+                     await _notify_tool_error(ctx, func.__name__, exc)
+                     raise
+
+             return decorator(async_inner)
+
+         @wraps(func)
+         def sync_inner(*args, **kwargs):
+             ctx = _extract_ctx(args, kwargs)
+             _log_tool_call(func.__name__, ctx)
+             try:
+                 return func(*args, **kwargs)
+             except Exception:
+                 logger.exception("Tool %s failed", func.__name__)
+                 raise
+
+         return decorator(sync_inner)
+
+     return outer
+
+
+ mcp.tool = tool
+
+
+ def resource(*resource_args, **resource_kwargs):
+     decorator = _mcp_resource(*resource_args, **resource_kwargs)
+
+     def outer(func):
+         if asyncio.iscoroutinefunction(func):
+             @wraps(func)
+             async def async_inner(*args, **kwargs):
+                 _log_tool_call(func.__name__, _extract_ctx(args, kwargs))
+                 return await func(*args, **kwargs)
+
+             return decorator(async_inner)
+
+         @wraps(func)
+         def sync_inner(*args, **kwargs):
+             _log_tool_call(func.__name__, _extract_ctx(args, kwargs))
+             return func(*args, **kwargs)
+
+         return decorator(sync_inner)
+
+     return outer
+
+
+ mcp.resource = resource
+
+
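+ # With the patched decorators above, every @mcp.tool() / @mcp.resource()
+ # registration below is transparently instrumented: each call is logged via
+ # _log_tool_call, and async tool failures also emit a "tool_error" logMessage
+ # notification before the exception propagates. Illustrative only:
+ #
+ #     @mcp.tool()
+ #     def ping() -> str:   # hypothetical tool, not registered by this module
+ #         return "pong"    # call logs "MCP tool call: ping request_id=..."
+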
+ @mcp.tool()
+ async def run_do_file_background(
+     path: str,
+     ctx: Context | None = None,
+     echo: bool = True,
+     as_json: bool = True,
+     trace: bool = False,
+     raw: bool = False,
+     max_output_lines: int | None = None,
+     cwd: str | None = None,
+ ) -> str:
+     """Run a Stata do-file in the background and return a task id.
+
+     Notifications:
+     - logMessage: {"event":"log_path","path":"..."}
+     - logMessage: {"event":"task_done","task_id":"...","status":"done","log_path":"...","error":null}
+     """
+     session = ctx.request_context.session if ctx is not None else None
+     request_id = ctx.request_id if ctx is not None else None
+     task_id = uuid.uuid4().hex
+     _attach_task_id(ctx, task_id)
+     task_info = BackgroundTask(
+         task_id=task_id,
+         kind="do_file",
+         task=None,
+         created_at=datetime.utcnow(),
+     )
+
+     async def notify_log(text: str) -> None:
+         if session is not None:
+             if not _should_stream_smcl_chunk(text, ctx.request_id):
+                 return
+             _debug_notification("logMessage", text, ctx.request_id)
+             try:
+                 await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
+             except Exception as e:
+                 logger.warning("Failed to send logMessage notification: %s", e)
+                 sys.stderr.write(f"[mcp_stata] ERROR: logMessage send failed: {e!r}\n")
+                 sys.stderr.flush()
+         try:
+             payload = json.loads(text)
+             if isinstance(payload, dict) and payload.get("event") == "log_path":
+                 task_info.log_path = payload.get("path")
+                 if request_id is not None and task_info.log_path:
+                     _request_log_paths[str(request_id)] = task_info.log_path
+         except Exception:
+             return
+
+     progress_token = None
+     if ctx is not None and ctx.request_context.meta is not None:
+         progress_token = ctx.request_context.meta.progressToken
+
+     async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
+         if session is None or progress_token is None:
+             return
+         _debug_notification(
+             "progress",
+             {"progress": progress, "total": total, "message": message},
+             ctx.request_id,
+         )
+         try:
+             await session.send_progress_notification(
+                 progress_token=progress_token,
+                 progress=progress,
+                 total=total,
+                 message=message,
+                 related_request_id=ctx.request_id,
+             )
+         except Exception as exc:
+             logger.debug("Progress notification failed: %s", exc)
+
+     async def _run() -> None:
+         try:
+             result = await client.run_do_file_streaming(
+                 path,
+                 notify_log=notify_log,
+                 notify_progress=notify_progress if progress_token is not None else None,
+                 echo=echo,
+                 trace=trace,
+                 max_output_lines=max_output_lines,
+                 cwd=cwd,
+                 emit_graph_ready=True,
+                 graph_ready_task_id=task_id,
+                 graph_ready_format="svg",
+             )
+             task_info.result = _format_command_result(result, raw=raw, as_json=as_json)
+             if not task_info.log_path and result.log_path:
+                 task_info.log_path = result.log_path
+             if result.error:
+                 task_info.error = result.error.message
+             # Notify task completion after the result is available.
+             task_info.done = True
+             await _notify_task_done(session, task_info, request_id)
+
+             ui_channel.notify_potential_dataset_change()
+         except Exception as exc:  # pragma: no cover - defensive
+             task_info.done = True
+             task_info.error = str(exc)
+             await _notify_task_done(session, task_info, request_id)
+
+     if session is None:
+         await _run()
+         task_info.task = None
+     else:
+         task_info.task = asyncio.create_task(_run())
+     _register_task(task_info)
+     await _wait_for_log_path(task_info)
+     return json.dumps({"task_id": task_id, "status": "started", "log_path": task_info.log_path})
+
+
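+ # Typical client flow (illustrative sketch; assumes an MCP client `session`
+ # from the Python SDK with a logging callback registered):
+ #
+ #     started = json.loads((await session.call_tool(
+ #         "run_do_file_background", {"path": "analysis.do"})).content[0].text)
+ #     # started == {"task_id": "...", "status": "started", "log_path": "..."}
+ #     # Completion then arrives as a logMessage notification whose JSON data is
+ #     # {"event": "task_done", "task_id": started["task_id"], ...}; polling the
+ #     # status tools below is discouraged.
+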
+ @mcp.tool()
+ def get_task_status(task_id: str, allow_polling: bool = False) -> str:
+     """Return task status for background executions.
+
+     Polling is disabled by default; set allow_polling=True for legacy callers.
+     """
+     notice = "Prefer task_done logMessage notifications over polling get_task_status."
+     if not allow_polling:
+         logger.warning(
+             "get_task_status called without allow_polling; clients must use task_done logMessage notifications"
+         )
+         return json.dumps({
+             "task_id": task_id,
+             "status": "polling_not_allowed",
+             "error": "Polling is disabled; use task_done logMessage notifications.",
+             "notice": notice,
+         })
+     logger.warning("get_task_status called; clients should use task_done logMessage notifications instead of polling")
+     task_info = _background_tasks.get(task_id)
+     if task_info is None:
+         return json.dumps({"task_id": task_id, "status": "not_found", "notice": notice})
+     return json.dumps({
+         "task_id": task_id,
+         "status": "done" if task_info.done else "running",
+         "kind": task_info.kind,
+         "created_at": task_info.created_at.isoformat(),
+         "log_path": task_info.log_path,
+         "error": task_info.error,
+         "notice": notice,
+     })
+
+
+ @mcp.tool()
+ def get_task_result(task_id: str, allow_polling: bool = False) -> str:
+     """Return task result for background executions.
+
+     Polling is disabled by default; set allow_polling=True for legacy callers.
+     """
+     notice = "Prefer task_done logMessage notifications over polling get_task_result."
+     if not allow_polling:
+         logger.warning(
+             "get_task_result called without allow_polling; clients must use task_done logMessage notifications"
+         )
+         return json.dumps({
+             "task_id": task_id,
+             "status": "polling_not_allowed",
+             "error": "Polling is disabled; use task_done logMessage notifications.",
+             "notice": notice,
+         })
+     logger.warning("get_task_result called; clients should use task_done logMessage notifications instead of polling")
+     task_info = _background_tasks.get(task_id)
+     if task_info is None:
+         return json.dumps({"task_id": task_id, "status": "not_found", "notice": notice})
+     if not task_info.done:
+         return json.dumps({
+             "task_id": task_id,
+             "status": "running",
+             "log_path": task_info.log_path,
+             "notice": notice,
+         })
+     return json.dumps({
+         "task_id": task_id,
+         "status": "done",
+         "log_path": task_info.log_path,
+         "error": task_info.error,
+         "notice": notice,
+         "result": task_info.result,
+     })
+
+
+ @mcp.tool()
+ def cancel_task(task_id: str) -> str:
+     """Request cancellation of a background task."""
+     task_info = _background_tasks.get(task_id)
+     if task_info is None:
+         return json.dumps({"task_id": task_id, "status": "not_found"})
+     if task_info.task and not task_info.task.done():
+         task_info.task.cancel()
+         return json.dumps({"task_id": task_id, "status": "cancelling"})
+     return json.dumps({"task_id": task_id, "status": "done", "log_path": task_info.log_path})
+
+
+ @mcp.tool()
+ async def run_command_background(
+     code: str,
+     ctx: Context | None = None,
+     echo: bool = True,
+     as_json: bool = True,
+     trace: bool = False,
+     raw: bool = False,
+     max_output_lines: int | None = None,
+     cwd: str | None = None,
+ ) -> str:
+     """Run a Stata command in the background and return a task id.
+
+     Notifications:
+     - logMessage: {"event":"log_path","path":"..."}
+     - logMessage: {"event":"task_done","task_id":"...","status":"done","log_path":"...","error":null}
+     """
+     session = ctx.request_context.session if ctx is not None else None
+     request_id = ctx.request_id if ctx is not None else None
+     task_id = uuid.uuid4().hex
+     _attach_task_id(ctx, task_id)
+     task_info = BackgroundTask(
+         task_id=task_id,
+         kind="command",
+         task=None,
+         created_at=datetime.utcnow(),
+     )
+
+     async def notify_log(text: str) -> None:
+         if session is not None:
+             if not _should_stream_smcl_chunk(text, ctx.request_id):
+                 return
+             _debug_notification("logMessage", text, ctx.request_id)
+             await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
+         try:
+             payload = json.loads(text)
+             if isinstance(payload, dict) and payload.get("event") == "log_path":
+                 task_info.log_path = payload.get("path")
+                 if request_id is not None and task_info.log_path:
+                     _request_log_paths[str(request_id)] = task_info.log_path
+         except Exception:
+             return
+
+     progress_token = None
+     if ctx is not None and ctx.request_context.meta is not None:
+         progress_token = ctx.request_context.meta.progressToken
+
+     async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
+         if session is None or progress_token is None:
+             return
+         await session.send_progress_notification(
+             progress_token=progress_token,
+             progress=progress,
+             total=total,
+             message=message,
+             related_request_id=ctx.request_id,
+         )
+
+     async def _run() -> None:
+         try:
+             result = await client.run_command_streaming(
+                 code,
+                 notify_log=notify_log,
+                 notify_progress=notify_progress if progress_token is not None else None,
+                 echo=echo,
+                 trace=trace,
+                 max_output_lines=max_output_lines,
+                 cwd=cwd,
+                 emit_graph_ready=True,
+                 graph_ready_task_id=task_id,
+                 graph_ready_format="svg",
+             )
+             task_info.result = _format_command_result(result, raw=raw, as_json=as_json)
+             if not task_info.log_path and result.log_path:
+                 task_info.log_path = result.log_path
+             if result.error:
+                 task_info.error = result.error.message
+             # Notify task completion after the result is available.
+             task_info.done = True
+             await _notify_task_done(session, task_info, request_id)
+
+             ui_channel.notify_potential_dataset_change()
+         except Exception as exc:  # pragma: no cover - defensive
+             task_info.done = True
+             task_info.error = str(exc)
+             await _notify_task_done(session, task_info, request_id)
+
+     if session is None:
+         await _run()
+         task_info.task = None
+     else:
+         task_info.task = asyncio.create_task(_run())
+     _register_task(task_info)
+     await _wait_for_log_path(task_info)
+     return json.dumps({"task_id": task_id, "status": "started", "log_path": task_info.log_path})
+
+ @mcp.tool()
+ async def run_command(
+     code: str,
+     ctx: Context | None = None,
+     echo: bool = True,
+     as_json: bool = True,
+     trace: bool = False,
+     raw: bool = False,
+     max_output_lines: int | None = None,
+     cwd: str | None = None,
+ ) -> str:
+     """
+     Executes Stata code.
+
+     This is the primary tool for interacting with Stata.
+
+     Stata output is written to a temporary log file on disk.
+     The server emits a single `notifications/logMessage` event containing the log file path
+     (JSON payload: {"event":"log_path","path":"..."}) so the client can tail it locally.
+     If the client supplies a progress callback/token, progress updates may also be emitted
+     via `notifications/progress`.
+
+     Args:
+         code: The Stata command(s) to execute (e.g., "sysuse auto", "regress price mpg", "summarize").
+         ctx: FastMCP-injected request context (used to send MCP notifications). Optional for direct Python calls.
+         echo: If True, the command itself is included in the output. Default is True.
+         as_json: If True, returns a JSON envelope with rc/stdout/stderr/error.
+         trace: If True, enables `set trace on` for deeper error diagnostics (automatically disabled after).
+         raw: If True, returns the raw output/error message rather than a JSON envelope.
+         max_output_lines: If set, truncates stdout to this many lines for token efficiency.
+             Useful for verbose commands (regress, codebook, etc.).
+
+     Note: This tool always uses log-file streaming semantics; there is no non-streaming mode.
+     """
+     session = ctx.request_context.session if ctx is not None else None
+
+     async def notify_log(text: str) -> None:
+         if session is None:
+             return
+         if not _should_stream_smcl_chunk(text, ctx.request_id):
+             return
+         _debug_notification("logMessage", text, ctx.request_id)
+         await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
+         try:
+             payload = json.loads(text)
+             if isinstance(payload, dict) and payload.get("event") == "log_path":
+                 if ctx.request_id is not None:
+                     _request_log_paths[str(ctx.request_id)] = payload.get("path")
+         except Exception:
+             return
+
+     progress_token = None
+     if ctx is not None and ctx.request_context.meta is not None:
+         progress_token = ctx.request_context.meta.progressToken
+
+     async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
+         if session is None or progress_token is None:
+             return
+         await session.send_progress_notification(
+             progress_token=progress_token,
+             progress=progress,
+             total=total,
+             message=message,
+             related_request_id=ctx.request_id,
+         )
+
+     async def _noop_log(_text: str) -> None:
+         return
+
+     result = await client.run_command_streaming(
+         code,
+         notify_log=notify_log if session is not None else _noop_log,
+         notify_progress=notify_progress if progress_token is not None else None,
+         echo=echo,
+         trace=trace,
+         max_output_lines=max_output_lines,
+         cwd=cwd,
+         emit_graph_ready=True,
+         graph_ready_task_id=ctx.request_id if ctx else None,
+         graph_ready_format="svg",
+     )
+
+     # Conservative invalidation: arbitrary Stata commands may change data.
+     ui_channel.notify_potential_dataset_change()
+
+     return _format_command_result(result, raw=raw, as_json=as_json)
+
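+ # Illustrative call (assumes a connected MCP client `session`):
+ #
+ #     res = await session.call_tool("run_command", {"code": "sysuse auto"})
+ #     envelope = json.loads(res.content[0].text)   # rc/stdout/log_path/error
+ #     # With raw=True the tool instead returns just the log path on success,
+ #     # or the error message on failure.
+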
+ @mcp.tool()
+ def read_log(path: str, offset: int = 0, max_bytes: int = 65536) -> str:
+     """Read a slice of a log file.
+
+     Intended for clients that want to display a terminal-like view without pushing MBs of
+     output through MCP log notifications.
+
+     Args:
+         path: Absolute path to the log file previously provided by the server.
+         offset: Byte offset to start reading from.
+         max_bytes: Maximum bytes to read.
+
+     Returns a compact JSON string: {"path":..., "offset":..., "next_offset":..., "data":...}
+     """
+     try:
+         if path:
+             _read_log_paths.add(path)
+         if offset < 0:
+             offset = 0
+         if path:
+             last_offset = _read_log_offsets.get(path, 0)
+             if offset < last_offset:
+                 offset = last_offset
+         with open(path, "rb") as f:
+             f.seek(offset)
+             data = f.read(max_bytes)
+             next_offset = f.tell()
+         if path:
+             _read_log_offsets[path] = next_offset
+         text = data.decode("utf-8", errors="replace")
+         return json.dumps({"path": path, "offset": offset, "next_offset": next_offset, "data": text})
+     except FileNotFoundError:
+         return json.dumps({"path": path, "offset": offset, "next_offset": offset, "data": ""})
+     except Exception as e:
+         return json.dumps({"path": path, "offset": offset, "next_offset": offset, "data": f"ERROR: {e}"})
+
+
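+ # Cursor-style tailing with read_log (illustrative; note the server clamps
+ # offsets forward, so an already-consumed region is never re-read):
+ #
+ #     offset = 0
+ #     while not task_done:  # e.g. flipped by a task_done notification handler
+ #         chunk = json.loads(read_log(log_path, offset=offset))
+ #         if chunk["data"]:
+ #             print(chunk["data"], end="")
+ #         offset = chunk["next_offset"]
+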
+ @mcp.tool()
+ def find_in_log(
+     path: str,
+     query: str,
+     start_offset: int = 0,
+     max_bytes: int = 5_000_000,
+     before: int = 2,
+     after: int = 2,
+     case_sensitive: bool = False,
+     regex: bool = False,
+     max_matches: int = 50,
+ ) -> str:
+     """Find text within a log file and return context windows.
+
+     Args:
+         path: Absolute path to the log file previously provided by the server.
+         query: Text or regex pattern to search for.
+         start_offset: Byte offset to start searching from.
+         max_bytes: Maximum bytes to read from the log.
+         before: Number of context lines to include before each match.
+         after: Number of context lines to include after each match.
+         case_sensitive: If True, match case-sensitively.
+         regex: If True, treat query as a regular expression.
+         max_matches: Maximum number of matches to return.
+
+     Returns a JSON string with matches and offsets:
+     {"path":..., "query":..., "start_offset":..., "next_offset":..., "truncated":..., "matches":[...]}.
+     """
+     try:
+         if start_offset < 0:
+             start_offset = 0
+         if max_bytes <= 0:
+             return json.dumps({
+                 "path": path,
+                 "query": query,
+                 "start_offset": start_offset,
+                 "next_offset": start_offset,
+                 "truncated": False,
+                 "matches": [],
+             })
+         with open(path, "rb") as f:
+             f.seek(start_offset)
+             data = f.read(max_bytes)
+             next_offset = f.tell()
+
+         text = data.decode("utf-8", errors="replace")
+         lines = text.splitlines()
+
+         if regex:
+             flags = 0 if case_sensitive else re.IGNORECASE
+             pattern = re.compile(query, flags=flags)
+
+             def is_match(line: str) -> bool:
+                 return pattern.search(line) is not None
+         else:
+             needle = query if case_sensitive else query.lower()
+
+             def is_match(line: str) -> bool:
+                 haystack = line if case_sensitive else line.lower()
+                 return needle in haystack
+
+         matches = []
+         for idx, line in enumerate(lines):
+             if not is_match(line):
+                 continue
+             start_idx = max(0, idx - max(0, before))
+             end_idx = min(len(lines), idx + max(0, after) + 1)
+             context = lines[start_idx:end_idx]
+             matches.append({
+                 "line_index": idx,
+                 "context_start": start_idx,
+                 "context_end": end_idx,
+                 "context": context,
+             })
+             if len(matches) >= max_matches:
+                 break
+
+         truncated = len(matches) >= max_matches
+         return json.dumps({
+             "path": path,
+             "query": query,
+             "start_offset": start_offset,
+             "next_offset": next_offset,
+             "truncated": truncated,
+             "matches": matches,
+         })
+     except FileNotFoundError:
+         return json.dumps({
+             "path": path,
+             "query": query,
+             "start_offset": start_offset,
+             "next_offset": start_offset,
+             "truncated": False,
+             "matches": [],
+         })
+     except Exception as e:
+         return json.dumps({
+             "path": path,
+             "query": query,
+             "start_offset": start_offset,
+             "next_offset": start_offset,
+             "truncated": False,
+             "matches": [],
+             "error": f"ERROR: {e}",
+         })
+
+
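+ # Example (illustrative): scanning a run log for Stata return codes such as
+ # r(111):
+ #
+ #     out = json.loads(find_in_log("/path/to/run.log", r"r\(\d+\)", regex=True))
+ #     # out["matches"][0] ~ {"line_index": 42, "context_start": 40,
+ #     #                      "context_end": 45, "context": ["...", "r(111);", "..."]}
+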
+ @mcp.tool()
+ def get_data(start: int = 0, count: int = 50) -> str:
+     """
+     Returns a slice of the active dataset as a JSON-formatted list of dictionaries.
+
+     Use this to inspect the actual data values in memory. Useful for checking data quality or content.
+
+     Args:
+         start: The zero-based index of the first observation to retrieve.
+         count: The number of observations to retrieve. Defaults to 50.
+     """
+     data = client.get_data(start, count)
+     resp = DataResponse(start=start, count=count, data=data)
+     return resp.model_dump_json()
+
+
+ @mcp.tool()
+ def get_ui_channel() -> str:
+     """Return the localhost HTTP endpoint + bearer token for the extension UI data plane."""
+     info = ui_channel.get_channel()
+     payload = {
+         "baseUrl": info.base_url,
+         "token": info.token,
+         "expiresAt": info.expires_at,
+         "capabilities": ui_channel.capabilities(),
+     }
+     return json.dumps(payload)
+
+ @mcp.tool()
+ def describe() -> str:
+     """
+     Returns variable descriptions, storage types, and labels (equivalent to Stata's `describe` command).
+
+     Use this to understand the structure of the dataset, variable names, and their formats before running analysis.
+     """
+     result = client.run_command_structured("describe", echo=True)
+     if result.success:
+         return result.stdout
+     if result.error:
+         return result.error.message
+     return ""
+
+ @mcp.tool()
+ def list_graphs() -> str:
+     """
+     Lists the names of all graphs currently stored in Stata's memory.
+
+     Use this to see which graphs are available for export via `export_graph`. The
+     response marks the active graph so the agent knows which one will export by
+     default.
+     """
+     graphs = client.list_graphs_structured()
+     return graphs.model_dump_json()
+
+ @mcp.tool()
+ def export_graph(graph_name: str | None = None, format: str = "pdf") -> str:
+     """
+     Exports a stored Stata graph to a file and returns its path.
+
+     Args:
+         graph_name: The name of the graph to export (as seen in `list_graphs`).
+             If None, exports the currently active graph.
+         format: Output format, defaults to "pdf". Supported: "pdf", "png". Use
+             "png" to view the plot directly so the agent can visually check
+             titles, labels, legends, colors, and other user requirements.
+     """
+     try:
+         return client.export_graph(graph_name, format=format)
+     except Exception as e:
+         raise RuntimeError(f"Failed to export graph: {e}") from e
+
+ @mcp.tool()
+ def get_help(topic: str, plain_text: bool = False) -> str:
+     """
+     Returns the official Stata help text for a given command or topic.
+
+     Args:
+         topic: The command name or help topic (e.g., "regress", "graph", "options").
+         plain_text: If True, returns plain text; otherwise Markdown (the default).
+     """
+     return client.get_help(topic, plain_text=plain_text)
+
+ @mcp.tool()
+ def get_stored_results() -> str:
+     """
+     Returns the current stored results (r-class and e-class scalars/macros) as a JSON-formatted string.
+
+     Use this after running a command (like `summarize` or `regress`) to programmatically retrieve
+     specific values (e.g., means, coefficients, sample sizes) for validation or further calculation.
+     """
+     return json.dumps(client.get_stored_results())
+
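+ # Illustrative sequence: after `summarize price` (sysuse auto), Stata stores
+ # r-class scalars such as r(N)=74 and r(mean)~6165.26, so
+ #
+ #     results = json.loads(get_stored_results())
+ #
+ # yields a mapping containing those values (exact key formatting is determined
+ # by client.get_stored_results()).
+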
+ @mcp.tool()
+ def load_data(source: str, clear: bool = True, as_json: bool = True, raw: bool = False, max_output_lines: int | None = None) -> str:
+     """
+     Loads data using sysuse/webuse/use heuristics based on the source string.
+     Automatically appends ", clear" unless clear=False.
+
+     Args:
+         source: Dataset source (e.g., "auto", "auto.dta", "/path/to/file.dta").
+         clear: If True, clears data in memory before loading.
+         as_json: If True, returns a JSON envelope.
+         raw: If True, returns raw output only.
+         max_output_lines: If set, truncates stdout to this many lines for token efficiency.
+     """
+     result = client.load_data(source, clear=clear, max_output_lines=max_output_lines)
+     ui_channel.notify_potential_dataset_change()
+     if raw:
+         return result.stdout if result.success else (result.error.message if result.error else result.stdout)
+     return result.model_dump_json()
+
+ @mcp.tool()
+ def codebook(variable: str, as_json: bool = True, trace: bool = False, raw: bool = False, max_output_lines: int | None = None) -> str:
+     """
+     Returns the codebook/summary for a specific variable.
+
+     Args:
+         variable: The variable name to analyze.
+         as_json: If True, returns a JSON envelope.
+         trace: If True, enables trace mode.
+         raw: If True, returns raw output only.
+         max_output_lines: If set, truncates stdout to this many lines for token efficiency.
+     """
+     result = client.codebook(variable, trace=trace, max_output_lines=max_output_lines)
+     if raw:
+         return result.stdout if result.success else (result.error.message if result.error else result.stdout)
+     return result.model_dump_json()
+
+ @mcp.tool()
+ async def run_do_file(
+     path: str,
+     ctx: Context | None = None,
+     echo: bool = True,
+     as_json: bool = True,
+     trace: bool = False,
+     raw: bool = False,
+     max_output_lines: int | None = None,
+     cwd: str | None = None,
+ ) -> str:
+     """
+     Executes a .do file.
+
+     Stata output is written to a temporary log file on disk.
+     The server emits a single `notifications/logMessage` event containing the log file path
+     (JSON payload: {"event":"log_path","path":"..."}) so the client can tail it locally.
+     If the client supplies a progress callback/token, progress updates are emitted via
+     `notifications/progress`.
+
+     Args:
+         path: Path to the .do file to execute.
+         ctx: FastMCP-injected request context (used to send MCP notifications). Optional for direct Python calls.
+         echo: If True, includes the command in the output.
+         as_json: If True, returns a JSON envelope.
+         trace: If True, enables trace mode.
+         raw: If True, returns raw output only.
+         max_output_lines: If set, truncates stdout to this many lines for token efficiency.
+
+     Note: This tool always uses log-file streaming semantics; there is no non-streaming mode.
+     """
+     session = ctx.request_context.session if ctx is not None else None
+
+     async def notify_log(text: str) -> None:
+         if session is None:
+             return
+         if not _should_stream_smcl_chunk(text, ctx.request_id):
+             return
+         await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
+         try:
+             payload = json.loads(text)
+             if isinstance(payload, dict) and payload.get("event") == "log_path":
+                 if ctx.request_id is not None:
+                     _request_log_paths[str(ctx.request_id)] = payload.get("path")
+         except Exception:
+             return
+
+     progress_token = None
+     if ctx is not None and ctx.request_context.meta is not None:
+         progress_token = ctx.request_context.meta.progressToken
+
+     async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
+         if session is None or progress_token is None:
+             return
+         await session.send_progress_notification(
+             progress_token=progress_token,
+             progress=progress,
+             total=total,
+             message=message,
+             related_request_id=ctx.request_id,
+         )
+
+     async def _noop_log(_text: str) -> None:
+         return
+
+     result = await client.run_do_file_streaming(
+         path,
+         notify_log=notify_log if session is not None else _noop_log,
+         notify_progress=notify_progress if progress_token is not None else None,
+         echo=echo,
+         trace=trace,
+         max_output_lines=max_output_lines,
+         cwd=cwd,
+         emit_graph_ready=True,
+         graph_ready_task_id=ctx.request_id if ctx else None,
+         graph_ready_format="svg",
+     )
+
+     ui_channel.notify_potential_dataset_change()
+
+     return _format_command_result(result, raw=raw, as_json=as_json)
+
+ @mcp.resource("stata://data/summary")
+ def get_summary() -> str:
+     """
+     Returns the output of the `summarize` command for the dataset in memory.
+     Provides descriptive statistics (obs, mean, std. dev., min, max) for all variables.
+     """
+     result = client.run_command_structured("summarize", echo=True)
+     if result.success:
+         return result.stdout
+     if result.error:
+         return result.error.message
+     return ""
+
+ @mcp.resource("stata://data/metadata")
+ def get_metadata() -> str:
+     """
+     Returns the output of the `describe` command.
+     Provides metadata about the dataset, including variable names, storage types, display formats, and labels.
+     """
+     result = client.run_command_structured("describe", echo=True)
+     if result.success:
+         return result.stdout
+     if result.error:
+         return result.error.message
+     return ""
+
+ @mcp.resource("stata://graphs/list")
+ def list_graphs_resource() -> str:
+     """Resource wrapper for the graph list (uses the list_graphs tool)."""
+     return list_graphs()
+
+ @mcp.tool()
+ def get_variable_list() -> str:
+     """Returns a JSON list of all variables."""
+     variables = client.list_variables_structured()
+     return variables.model_dump_json()
+
+ @mcp.resource("stata://variables/list")
+ def get_variable_list_resource() -> str:
+     """Resource wrapper for the variable list."""
+     return get_variable_list()
+
+ @mcp.resource("stata://results/stored")
+ def get_stored_results_resource() -> str:
+     """Returns stored r() and e() results."""
+     return json.dumps(client.get_stored_results())
+
+ @mcp.tool()
+ def export_graphs_all() -> str:
+     """
+     Exports all graphs in memory to file paths.
+
+     Returns a JSON envelope listing graph names and file paths.
+     The agent can open SVG files directly to verify visuals (titles/labels/colors/legends).
+     """
+     exports = client.export_graphs_all()
+     return exports.model_dump_json(exclude_none=False)
+
+ def main():
+     if "--version" in sys.argv:
+         print(SERVER_VERSION)
+         return
+
+     # Filter non-JSON output off stdout to keep the stdio transport clean.
+     _install_stdout_filter()
+
+     setup_logging()
+
+     # Initialize Stata here on the main thread to ensure any issues are logged early.
+     # On Windows, this is critical for COM registration. On other platforms, it helps
+     # catch license or installation errors before the first tool call.
+     try:
+         client.init()
+     except BaseException as e:
+         # Use sys.stderr.write and flush to ensure visibility before exit.
+         msg = f"\n{'='*60}\n[mcp_stata] FATAL: STATA INITIALIZATION FAILED\n{'='*60}\nError: {repr(e)}\n"
+         sys.stderr.write(msg)
+         if isinstance(e, SystemExit):
+             sys.stderr.write(f"Stata triggered a SystemExit (code: {e.code}). This is usually a license error.\n")
+         sys.stderr.write(f"{'='*60}\n\n")
+         sys.stderr.flush()
+
+         # Exit non-zero so clients see a clear failure when Stata cannot be loaded.
+         sys.exit(1)
+
+     mcp.run()
+
+ if __name__ == "__main__":
+     main()