mcp-stata 1.7.6__py3-none-any.whl → 1.16.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcp-stata might be problematic. Click here for more details.

mcp_stata/server.py CHANGED
@@ -1,6 +1,10 @@
1
1
  import anyio
2
+ import asyncio
3
+ from dataclasses import dataclass
4
+ from datetime import datetime
2
5
  from importlib.metadata import PackageNotFoundError, version
3
6
  from mcp.server.fastmcp import Context, FastMCP
7
+ from mcp.server.fastmcp.utilities import logging as fastmcp_logging
4
8
  import mcp.types as types
5
9
  from .stata_client import StataClient
6
10
  from .models import (
@@ -10,26 +14,575 @@ from .models import (
10
14
  GraphExportResponse,
11
15
  )
12
16
  import logging
17
+ import sys
13
18
  import json
14
19
  import os
20
+ import re
21
+ import traceback
22
+ import uuid
23
+ from functools import wraps
24
+ from typing import Optional, Dict
15
25
 
16
26
  from .ui_http import UIChannelManager
17
27
 
18
28
 
19
- LOG_LEVEL = os.getenv("MCP_STATA_LOGLEVEL", "INFO").upper()
20
- logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(levelname)s %(name)s - %(message)s")
21
- try:
22
- _mcp_stata_version = version("mcp-stata")
23
- except PackageNotFoundError:
24
- _mcp_stata_version = "unknown"
25
- logging.info("mcp-stata version: %s", _mcp_stata_version)
26
- logging.info("STATA_PATH env at startup: %s", os.getenv("STATA_PATH", "<not set>"))
29
+ # Configure logging
30
+ logger = logging.getLogger("mcp_stata")
31
+ payload_logger = logging.getLogger("mcp_stata.payloads")
32
+ _LOGGING_CONFIGURED = False
33
+
34
+ def setup_logging():
35
+ global _LOGGING_CONFIGURED
36
+ if _LOGGING_CONFIGURED:
37
+ return
38
+ _LOGGING_CONFIGURED = True
39
+ log_level = os.getenv("MCP_STATA_LOGLEVEL", "DEBUG").upper()
40
+ app_handler = logging.StreamHandler(sys.stderr)
41
+ app_handler.setLevel(getattr(logging, log_level, logging.DEBUG))
42
+ app_handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
43
+
44
+ mcp_handler = logging.StreamHandler(sys.stderr)
45
+ mcp_handler.setLevel(getattr(logging, log_level, logging.DEBUG))
46
+ mcp_handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
47
+
48
+ payload_handler = logging.StreamHandler(sys.stderr)
49
+ payload_handler.setLevel(getattr(logging, log_level, logging.DEBUG))
50
+ payload_handler.setFormatter(logging.Formatter("[%(name)s] %(levelname)s: %(message)s"))
51
+
52
+ root_logger = logging.getLogger()
53
+ root_logger.handlers = []
54
+ root_logger.setLevel(logging.WARNING)
55
+
56
+ for name, item in logging.root.manager.loggerDict.items():
57
+ if not isinstance(item, logging.Logger):
58
+ continue
59
+ item.handlers = []
60
+ item.propagate = False
61
+ if item.level == logging.NOTSET:
62
+ item.setLevel(getattr(logging, log_level, logging.DEBUG))
63
+
64
+ logger.handlers = [app_handler]
65
+ logger.propagate = False
66
+
67
+ payload_logger.handlers = [payload_handler]
68
+ payload_logger.propagate = False
69
+
70
+ mcp_logger = logging.getLogger("mcp.server")
71
+ mcp_logger.handlers = [mcp_handler]
72
+ mcp_logger.propagate = False
73
+ mcp_logger.setLevel(getattr(logging, log_level, logging.DEBUG))
74
+
75
+ mcp_lowlevel = logging.getLogger("mcp.server.lowlevel.server")
76
+ mcp_lowlevel.handlers = [mcp_handler]
77
+ mcp_lowlevel.propagate = False
78
+ mcp_lowlevel.setLevel(getattr(logging, log_level, logging.DEBUG))
79
+
80
+ mcp_root = logging.getLogger("mcp")
81
+ mcp_root.handlers = [mcp_handler]
82
+ mcp_root.propagate = False
83
+ mcp_root.setLevel(getattr(logging, log_level, logging.DEBUG))
84
+ if logger.level == logging.NOTSET:
85
+ logger.setLevel(getattr(logging, log_level, logging.DEBUG))
86
+
87
+ try:
88
+ _mcp_stata_version = version("mcp-stata")
89
+ except PackageNotFoundError:
90
+ _mcp_stata_version = "unknown"
91
+
92
+ logger.info("=== mcp-stata server starting ===")
93
+ logger.info("mcp-stata version: %s", _mcp_stata_version)
94
+ logger.info("STATA_PATH env at startup: %s", os.getenv("STATA_PATH", "<not set>"))
95
+ logger.info("LOG_LEVEL: %s", log_level)
96
+
97
+
27
98
 
28
99
  # Initialize FastMCP
29
100
  mcp = FastMCP("mcp_stata")
30
101
  client = StataClient()
31
102
  ui_channel = UIChannelManager(client)
32
103
 
104
+
105
+ @dataclass
106
+ class BackgroundTask:
107
+ task_id: str
108
+ kind: str
109
+ task: asyncio.Task
110
+ created_at: datetime
111
+ log_path: Optional[str] = None
112
+ result: Optional[str] = None
113
+ error: Optional[str] = None
114
+ done: bool = False
115
+
116
+
117
+ _background_tasks: Dict[str, BackgroundTask] = {}
118
+ _request_log_paths: Dict[str, str] = {}
119
+ _read_log_paths: set[str] = set()
120
+ _read_log_offsets: Dict[str, int] = {}
121
+
122
+
123
+ def _register_task(task_info: BackgroundTask, max_tasks: int = 100) -> None:
124
+ _background_tasks[task_info.task_id] = task_info
125
+ if len(_background_tasks) <= max_tasks:
126
+ return
127
+ completed = [task for task in _background_tasks.values() if task.done]
128
+ completed.sort(key=lambda item: item.created_at)
129
+ for task in completed[: max(0, len(_background_tasks) - max_tasks)]:
130
+ _background_tasks.pop(task.task_id, None)
131
+
132
+
133
+ def _format_command_result(result, raw: bool, as_json: bool) -> str:
134
+ if raw:
135
+ if result.success:
136
+ return result.log_path or ""
137
+ if result.error:
138
+ msg = result.error.message
139
+ if result.error.rc is not None:
140
+ msg = f"{msg}\nrc={result.error.rc}"
141
+ return msg
142
+ return result.log_path or ""
143
+ if as_json:
144
+ return result.model_dump_json()
145
+ return result.model_dump_json()
146
+
147
+
148
+ async def _wait_for_log_path(task_info: BackgroundTask) -> None:
149
+ while task_info.log_path is None and not task_info.done:
150
+ await anyio.sleep(0.01)
151
+
152
+
153
+ async def _notify_task_done(session: object | None, task_info: BackgroundTask, request_id: object | None) -> None:
154
+ if session is None:
155
+ return
156
+ payload = {
157
+ "event": "task_done",
158
+ "task_id": task_info.task_id,
159
+ "status": "done" if task_info.done else "unknown",
160
+ "log_path": task_info.log_path,
161
+ "error": task_info.error,
162
+ }
163
+ try:
164
+ await session.send_log_message(level="info", data=json.dumps(payload), related_request_id=request_id)
165
+ except Exception:
166
+ return
167
+
168
+
169
+ def _debug_notification(kind: str, payload: object, request_id: object | None = None) -> None:
170
+ try:
171
+ serialized = payload if isinstance(payload, str) else json.dumps(payload, ensure_ascii=False)
172
+ except Exception:
173
+ serialized = str(payload)
174
+ payload_logger.info("MCP notify %s request_id=%s payload=%s", kind, request_id, serialized)
175
+
176
+
177
+ async def _notify_tool_error(ctx: Context | None, tool_name: str, exc: Exception) -> None:
178
+ if ctx is None:
179
+ return
180
+ session = ctx.request_context.session
181
+ if session is None:
182
+ return
183
+ task_id = None
184
+ meta = ctx.request_context.meta
185
+ if meta is not None:
186
+ task_id = getattr(meta, "task_id", None) or getattr(meta, "taskId", None)
187
+ payload = {
188
+ "event": "tool_error",
189
+ "tool": tool_name,
190
+ "error": str(exc),
191
+ "traceback": traceback.format_exc(),
192
+ }
193
+ if task_id is not None:
194
+ payload["task_id"] = task_id
195
+ try:
196
+ await session.send_log_message(
197
+ level="error",
198
+ data=json.dumps(payload),
199
+ related_request_id=ctx.request_id,
200
+ )
201
+ except Exception:
202
+ logger.exception("Failed to emit tool_error notification for %s", tool_name)
203
+
204
+
205
+ def _log_tool_call(tool_name: str, ctx: Context | None = None) -> None:
206
+ request_id = None
207
+ if ctx is not None:
208
+ request_id = getattr(ctx, "request_id", None)
209
+ logger.info("MCP tool call: %s request_id=%s", tool_name, request_id)
210
+
211
+ def _should_stream_smcl_chunk(text: str, request_id: object | None) -> bool:
212
+ if request_id is None:
213
+ return True
214
+ try:
215
+ payload = json.loads(text)
216
+ if isinstance(payload, dict) and payload.get("event"):
217
+ return True
218
+ except Exception:
219
+ pass
220
+ log_path = _request_log_paths.get(str(request_id))
221
+ if log_path and log_path in _read_log_paths:
222
+ return False
223
+ return True
224
+
225
+
226
+ def _attach_task_id(ctx: Context | None, task_id: str) -> None:
227
+ if ctx is None:
228
+ return
229
+ meta = ctx.request_context.meta
230
+ if meta is None:
231
+ meta = types.RequestParams.Meta()
232
+ ctx.request_context.meta = meta
233
+ try:
234
+ setattr(meta, "task_id", task_id)
235
+ except Exception:
236
+ logger.debug("Unable to attach task_id to request meta", exc_info=True)
237
+
238
+
239
+ def _extract_ctx(args: tuple[object, ...], kwargs: dict[str, object]) -> Context | None:
240
+ ctx = kwargs.get("ctx")
241
+ if isinstance(ctx, Context):
242
+ return ctx
243
+ for arg in args:
244
+ if isinstance(arg, Context):
245
+ return arg
246
+ return None
247
+
248
+
249
+ _mcp_tool = mcp.tool
250
+ _mcp_resource = mcp.resource
251
+
252
+
253
+ def tool(*tool_args, **tool_kwargs):
254
+ decorator = _mcp_tool(*tool_args, **tool_kwargs)
255
+
256
+ def outer(func):
257
+ if asyncio.iscoroutinefunction(func):
258
+ @wraps(func)
259
+ async def async_inner(*args, **kwargs):
260
+ ctx = _extract_ctx(args, kwargs)
261
+ _log_tool_call(func.__name__, ctx)
262
+ try:
263
+ return await func(*args, **kwargs)
264
+ except Exception as exc:
265
+ await _notify_tool_error(ctx, func.__name__, exc)
266
+ raise
267
+
268
+ return decorator(async_inner)
269
+
270
+ @wraps(func)
271
+ def sync_inner(*args, **kwargs):
272
+ ctx = _extract_ctx(args, kwargs)
273
+ _log_tool_call(func.__name__, ctx)
274
+ try:
275
+ return func(*args, **kwargs)
276
+ except Exception:
277
+ logger.exception("Tool %s failed", func.__name__)
278
+ raise
279
+
280
+ return decorator(sync_inner)
281
+
282
+ return outer
283
+
284
+
285
+ mcp.tool = tool
286
+
287
+
288
+ def resource(*resource_args, **resource_kwargs):
289
+ decorator = _mcp_resource(*resource_args, **resource_kwargs)
290
+
291
+ def outer(func):
292
+ if asyncio.iscoroutinefunction(func):
293
+ @wraps(func)
294
+ async def async_inner(*args, **kwargs):
295
+ _log_tool_call(func.__name__, _extract_ctx(args, kwargs))
296
+ return await func(*args, **kwargs)
297
+
298
+ return decorator(async_inner)
299
+
300
+ @wraps(func)
301
+ def sync_inner(*args, **kwargs):
302
+ _log_tool_call(func.__name__, _extract_ctx(args, kwargs))
303
+ return func(*args, **kwargs)
304
+
305
+ return decorator(sync_inner)
306
+
307
+ return outer
308
+
309
+
310
+ mcp.resource = resource
311
+
312
+
313
+ @mcp.tool()
314
+ async def run_do_file_background(
315
+ path: str,
316
+ ctx: Context | None = None,
317
+ echo: bool = True,
318
+ as_json: bool = True,
319
+ trace: bool = False,
320
+ raw: bool = False,
321
+ max_output_lines: int = None,
322
+ cwd: str | None = None,
323
+ ) -> str:
324
+ """Run a Stata do-file in the background and return a task id.
325
+
326
+ Notifications:
327
+ - logMessage: {"event":"log_path","path":"..."}
328
+ - logMessage: {"event":"task_done","task_id":"...","status":"done","log_path":"...","error":null}
329
+ """
330
+ session = ctx.request_context.session if ctx is not None else None
331
+ request_id = ctx.request_id if ctx is not None else None
332
+ task_id = uuid.uuid4().hex
333
+ _attach_task_id(ctx, task_id)
334
+ task_info = BackgroundTask(
335
+ task_id=task_id,
336
+ kind="do_file",
337
+ task=None,
338
+ created_at=datetime.utcnow(),
339
+ )
340
+
341
+ async def notify_log(text: str) -> None:
342
+ if session is not None:
343
+ if not _should_stream_smcl_chunk(text, ctx.request_id):
344
+ return
345
+ _debug_notification("logMessage", text, ctx.request_id)
346
+ try:
347
+ await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
348
+ except Exception as e:
349
+ logger.warning("Failed to send logMessage notification: %s", e)
350
+ sys.stderr.write(f"[mcp_stata] ERROR: logMessage send failed: {e!r}\n")
351
+ sys.stderr.flush()
352
+ try:
353
+ payload = json.loads(text)
354
+ if isinstance(payload, dict) and payload.get("event") == "log_path":
355
+ task_info.log_path = payload.get("path")
356
+ if ctx.request_id is not None and task_info.log_path:
357
+ _request_log_paths[str(ctx.request_id)] = task_info.log_path
358
+ except Exception:
359
+ return
360
+
361
+ progress_token = None
362
+ if ctx is not None and ctx.request_context.meta is not None:
363
+ progress_token = ctx.request_context.meta.progressToken
364
+
365
+ async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
366
+ if session is None or progress_token is None:
367
+ return
368
+ _debug_notification(
369
+ "progress",
370
+ {"progress": progress, "total": total, "message": message},
371
+ ctx.request_id,
372
+ )
373
+ await session.send_progress_notification(
374
+ progress_token=progress_token,
375
+ progress=progress,
376
+ total=total,
377
+ message=message,
378
+ related_request_id=ctx.request_id,
379
+ )
380
+
381
+ async def _run() -> None:
382
+ try:
383
+ result = await client.run_do_file_streaming(
384
+ path,
385
+ notify_log=notify_log,
386
+ notify_progress=notify_progress if progress_token is not None else None,
387
+ echo=echo,
388
+ trace=trace,
389
+ max_output_lines=max_output_lines,
390
+ cwd=cwd,
391
+ emit_graph_ready=True,
392
+ graph_ready_task_id=task_id,
393
+ graph_ready_format="svg",
394
+ )
395
+ # Notify task completion as soon as the core operation is finished
396
+ task_info.done = True
397
+ if result.error:
398
+ task_info.error = result.error.message
399
+ await _notify_task_done(session, task_info, request_id)
400
+
401
+ ui_channel.notify_potential_dataset_change()
402
+ task_info.result = _format_command_result(result, raw=raw, as_json=as_json)
403
+ except Exception as exc: # pragma: no cover - defensive
404
+ task_info.done = True
405
+ task_info.error = str(exc)
406
+ await _notify_task_done(session, task_info, request_id)
407
+
408
+ task_info.task = asyncio.create_task(_run())
409
+ _register_task(task_info)
410
+ await _wait_for_log_path(task_info)
411
+ return json.dumps({"task_id": task_id, "status": "started", "log_path": task_info.log_path})
412
+
413
+
414
+ @mcp.tool()
415
+ def get_task_status(task_id: str, allow_polling: bool = False) -> str:
416
+ """Return task status for background executions.
417
+
418
+ Polling is disabled by default; set allow_polling=True for legacy callers.
419
+ """
420
+ notice = "Prefer task_done logMessage notifications over polling get_task_status."
421
+ if not allow_polling:
422
+ logger.warning(
423
+ "get_task_status called without allow_polling; clients must use task_done logMessage notifications"
424
+ )
425
+ return json.dumps({
426
+ "task_id": task_id,
427
+ "status": "polling_not_allowed",
428
+ "error": "Polling is disabled; use task_done logMessage notifications.",
429
+ "notice": notice,
430
+ })
431
+ logger.warning("get_task_status called; clients should use task_done logMessage notifications instead of polling")
432
+ task_info = _background_tasks.get(task_id)
433
+ if task_info is None:
434
+ return json.dumps({"task_id": task_id, "status": "not_found", "notice": notice})
435
+ return json.dumps({
436
+ "task_id": task_id,
437
+ "status": "done" if task_info.done else "running",
438
+ "kind": task_info.kind,
439
+ "created_at": task_info.created_at.isoformat(),
440
+ "log_path": task_info.log_path,
441
+ "error": task_info.error,
442
+ "notice": notice,
443
+ })
444
+
445
+
446
+ @mcp.tool()
447
+ def get_task_result(task_id: str, allow_polling: bool = False) -> str:
448
+ """Return task result for background executions.
449
+
450
+ Polling is disabled by default; set allow_polling=True for legacy callers.
451
+ """
452
+ notice = "Prefer task_done logMessage notifications over polling get_task_result."
453
+ if not allow_polling:
454
+ logger.warning(
455
+ "get_task_result called without allow_polling; clients must use task_done logMessage notifications"
456
+ )
457
+ return json.dumps({
458
+ "task_id": task_id,
459
+ "status": "polling_not_allowed",
460
+ "error": "Polling is disabled; use task_done logMessage notifications.",
461
+ "notice": notice,
462
+ })
463
+ logger.warning("get_task_result called; clients should use task_done logMessage notifications instead of polling")
464
+ task_info = _background_tasks.get(task_id)
465
+ if task_info is None:
466
+ return json.dumps({"task_id": task_id, "status": "not_found", "notice": notice})
467
+ if not task_info.done:
468
+ return json.dumps({
469
+ "task_id": task_id,
470
+ "status": "running",
471
+ "log_path": task_info.log_path,
472
+ "notice": notice,
473
+ })
474
+ return json.dumps({
475
+ "task_id": task_id,
476
+ "status": "done",
477
+ "log_path": task_info.log_path,
478
+ "error": task_info.error,
479
+ "notice": notice,
480
+ "result": task_info.result,
481
+ })
482
+
483
+
484
+ @mcp.tool()
485
+ def cancel_task(task_id: str) -> str:
486
+ """Request cancellation of a background task."""
487
+ task_info = _background_tasks.get(task_id)
488
+ if task_info is None:
489
+ return json.dumps({"task_id": task_id, "status": "not_found"})
490
+ if task_info.task and not task_info.task.done():
491
+ task_info.task.cancel()
492
+ return json.dumps({"task_id": task_id, "status": "cancelling"})
493
+ return json.dumps({"task_id": task_id, "status": "done", "log_path": task_info.log_path})
494
+
495
+
496
+ @mcp.tool()
497
+ async def run_command_background(
498
+ code: str,
499
+ ctx: Context | None = None,
500
+ echo: bool = True,
501
+ as_json: bool = True,
502
+ trace: bool = False,
503
+ raw: bool = False,
504
+ max_output_lines: int = None,
505
+ cwd: str | None = None,
506
+ ) -> str:
507
+ """Run a Stata command in the background and return a task id.
508
+
509
+ Notifications:
510
+ - logMessage: {"event":"log_path","path":"..."}
511
+ - logMessage: {"event":"task_done","task_id":"...","status":"done","log_path":"...","error":null}
512
+ """
513
+ session = ctx.request_context.session if ctx is not None else None
514
+ request_id = ctx.request_id if ctx is not None else None
515
+ task_id = uuid.uuid4().hex
516
+ _attach_task_id(ctx, task_id)
517
+ task_info = BackgroundTask(
518
+ task_id=task_id,
519
+ kind="command",
520
+ task=None,
521
+ created_at=datetime.utcnow(),
522
+ )
523
+
524
+ async def notify_log(text: str) -> None:
525
+ if session is not None:
526
+ if not _should_stream_smcl_chunk(text, ctx.request_id):
527
+ return
528
+ _debug_notification("logMessage", text, ctx.request_id)
529
+ await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
530
+ try:
531
+ payload = json.loads(text)
532
+ if isinstance(payload, dict) and payload.get("event") == "log_path":
533
+ task_info.log_path = payload.get("path")
534
+ if ctx.request_id is not None and task_info.log_path:
535
+ _request_log_paths[str(ctx.request_id)] = task_info.log_path
536
+ except Exception:
537
+ return
538
+
539
+ progress_token = None
540
+ if ctx is not None and ctx.request_context.meta is not None:
541
+ progress_token = ctx.request_context.meta.progressToken
542
+
543
+ async def notify_progress(progress: float, total: float | None, message: str | None) -> None:
544
+ if session is None or progress_token is None:
545
+ return
546
+ await session.send_progress_notification(
547
+ progress_token=progress_token,
548
+ progress=progress,
549
+ total=total,
550
+ message=message,
551
+ related_request_id=ctx.request_id,
552
+ )
553
+
554
+ async def _run() -> None:
555
+ try:
556
+ result = await client.run_command_streaming(
557
+ code,
558
+ notify_log=notify_log,
559
+ notify_progress=notify_progress if progress_token is not None else None,
560
+ echo=echo,
561
+ trace=trace,
562
+ max_output_lines=max_output_lines,
563
+ cwd=cwd,
564
+ emit_graph_ready=True,
565
+ graph_ready_task_id=task_id,
566
+ graph_ready_format="svg",
567
+ )
568
+ # Notify task completion as soon as the core operation is finished
569
+ task_info.done = True
570
+ if result.error:
571
+ task_info.error = result.error.message
572
+ await _notify_task_done(session, task_info, request_id)
573
+
574
+ ui_channel.notify_potential_dataset_change()
575
+ task_info.result = _format_command_result(result, raw=raw, as_json=as_json)
576
+ except Exception as exc: # pragma: no cover - defensive
577
+ task_info.done = True
578
+ task_info.error = str(exc)
579
+ await _notify_task_done(session, task_info, request_id)
580
+
581
+ task_info.task = asyncio.create_task(_run())
582
+ _register_task(task_info)
583
+ await _wait_for_log_path(task_info)
584
+ return json.dumps({"task_id": task_id, "status": "started", "log_path": task_info.log_path})
585
+
33
586
  @mcp.tool()
34
587
  async def run_command(
35
588
  code: str,
@@ -68,7 +621,17 @@ async def run_command(
68
621
  async def notify_log(text: str) -> None:
69
622
  if session is None:
70
623
  return
624
+ if not _should_stream_smcl_chunk(text, ctx.request_id):
625
+ return
626
+ _debug_notification("logMessage", text, ctx.request_id)
71
627
  await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
628
+ try:
629
+ payload = json.loads(text)
630
+ if isinstance(payload, dict) and payload.get("event") == "log_path":
631
+ if ctx.request_id is not None:
632
+ _request_log_paths[str(ctx.request_id)] = payload.get("path")
633
+ except Exception:
634
+ return
72
635
 
73
636
  progress_token = None
74
637
  if ctx is not None and ctx.request_context.meta is not None:
@@ -96,6 +659,9 @@ async def run_command(
96
659
  trace=trace,
97
660
  max_output_lines=max_output_lines,
98
661
  cwd=cwd,
662
+ emit_graph_ready=True,
663
+ graph_ready_task_id=ctx.request_id if ctx else None,
664
+ graph_ready_format="svg",
99
665
  )
100
666
 
101
667
  # Conservative invalidation: arbitrary Stata commands may change data.
@@ -128,12 +694,20 @@ def read_log(path: str, offset: int = 0, max_bytes: int = 65536) -> str:
128
694
  Returns a compact JSON string: {"path":..., "offset":..., "next_offset":..., "data":...}
129
695
  """
130
696
  try:
697
+ if path:
698
+ _read_log_paths.add(path)
131
699
  if offset < 0:
132
700
  offset = 0
701
+ if path:
702
+ last_offset = _read_log_offsets.get(path, 0)
703
+ if offset < last_offset:
704
+ offset = last_offset
133
705
  with open(path, "rb") as f:
134
706
  f.seek(offset)
135
707
  data = f.read(max_bytes)
136
708
  next_offset = f.tell()
709
+ if path:
710
+ _read_log_offsets[path] = next_offset
137
711
  text = data.decode("utf-8", errors="replace")
138
712
  return json.dumps({"path": path, "offset": offset, "next_offset": next_offset, "data": text})
139
713
  except FileNotFoundError:
@@ -142,6 +716,111 @@ def read_log(path: str, offset: int = 0, max_bytes: int = 65536) -> str:
142
716
  return json.dumps({"path": path, "offset": offset, "next_offset": offset, "data": f"ERROR: {e}"})
143
717
 
144
718
 
719
+ @mcp.tool()
720
+ def find_in_log(
721
+ path: str,
722
+ query: str,
723
+ start_offset: int = 0,
724
+ max_bytes: int = 5_000_000,
725
+ before: int = 2,
726
+ after: int = 2,
727
+ case_sensitive: bool = False,
728
+ regex: bool = False,
729
+ max_matches: int = 50,
730
+ ) -> str:
731
+ """Find text within a log file and return context windows.
732
+
733
+ Args:
734
+ path: Absolute path to the log file previously provided by the server.
735
+ query: Text or regex pattern to search for.
736
+ start_offset: Byte offset to start searching from.
737
+ max_bytes: Maximum bytes to read from the log.
738
+ before: Number of context lines to include before each match.
739
+ after: Number of context lines to include after each match.
740
+ case_sensitive: If True, match case-sensitively.
741
+ regex: If True, treat query as a regular expression.
742
+ max_matches: Maximum number of matches to return.
743
+
744
+ Returns a JSON string with matches and offsets:
745
+ {"path":..., "query":..., "start_offset":..., "next_offset":..., "truncated":..., "matches":[...]}.
746
+ """
747
+ try:
748
+ if start_offset < 0:
749
+ start_offset = 0
750
+ if max_bytes <= 0:
751
+ return json.dumps({
752
+ "path": path,
753
+ "query": query,
754
+ "start_offset": start_offset,
755
+ "next_offset": start_offset,
756
+ "truncated": False,
757
+ "matches": [],
758
+ })
759
+ with open(path, "rb") as f:
760
+ f.seek(start_offset)
761
+ data = f.read(max_bytes)
762
+ next_offset = f.tell()
763
+
764
+ text = data.decode("utf-8", errors="replace")
765
+ lines = text.splitlines()
766
+
767
+ if regex:
768
+ flags = 0 if case_sensitive else re.IGNORECASE
769
+ pattern = re.compile(query, flags=flags)
770
+ def is_match(line: str) -> bool:
771
+ return pattern.search(line) is not None
772
+ else:
773
+ needle = query if case_sensitive else query.lower()
774
+ def is_match(line: str) -> bool:
775
+ haystack = line if case_sensitive else line.lower()
776
+ return needle in haystack
777
+
778
+ matches = []
779
+ for idx, line in enumerate(lines):
780
+ if not is_match(line):
781
+ continue
782
+ start_idx = max(0, idx - max(0, before))
783
+ end_idx = min(len(lines), idx + max(0, after) + 1)
784
+ context = lines[start_idx:end_idx]
785
+ matches.append({
786
+ "line_index": idx,
787
+ "context_start": start_idx,
788
+ "context_end": end_idx,
789
+ "context": context,
790
+ })
791
+ if len(matches) >= max_matches:
792
+ break
793
+
794
+ truncated = len(matches) >= max_matches
795
+ return json.dumps({
796
+ "path": path,
797
+ "query": query,
798
+ "start_offset": start_offset,
799
+ "next_offset": next_offset,
800
+ "truncated": truncated,
801
+ "matches": matches,
802
+ })
803
+ except FileNotFoundError:
804
+ return json.dumps({
805
+ "path": path,
806
+ "query": query,
807
+ "start_offset": start_offset,
808
+ "next_offset": start_offset,
809
+ "truncated": False,
810
+ "matches": [],
811
+ })
812
+ except Exception as e:
813
+ return json.dumps({
814
+ "path": path,
815
+ "query": query,
816
+ "start_offset": start_offset,
817
+ "next_offset": start_offset,
818
+ "truncated": False,
819
+ "matches": [],
820
+ "error": f"ERROR: {e}",
821
+ })
822
+
823
+
145
824
  @mcp.tool()
146
825
  def get_data(start: int = 0, count: int = 50) -> str:
147
826
  """
@@ -306,7 +985,16 @@ async def run_do_file(
306
985
  async def notify_log(text: str) -> None:
307
986
  if session is None:
308
987
  return
988
+ if not _should_stream_smcl_chunk(text, ctx.request_id):
989
+ return
309
990
  await session.send_log_message(level="info", data=text, related_request_id=ctx.request_id)
991
+ try:
992
+ payload = json.loads(text)
993
+ if isinstance(payload, dict) and payload.get("event") == "log_path":
994
+ if ctx.request_id is not None:
995
+ _request_log_paths[str(ctx.request_id)] = payload.get("path")
996
+ except Exception:
997
+ return
310
998
 
311
999
  progress_token = None
312
1000
  if ctx is not None and ctx.request_context.meta is not None:
@@ -334,6 +1022,9 @@ async def run_do_file(
334
1022
  trace=trace,
335
1023
  max_output_lines=max_output_lines,
336
1024
  cwd=cwd,
1025
+ emit_graph_ready=True,
1026
+ graph_ready_task_id=ctx.request_id if ctx else None,
1027
+ graph_ready_format="svg",
337
1028
  )
338
1029
 
339
1030
  ui_channel.notify_potential_dataset_change()
@@ -395,35 +1086,45 @@ def get_stored_results_resource() -> str:
395
1086
  return json.dumps(client.get_stored_results())
396
1087
 
397
1088
  @mcp.tool()
398
- def export_graphs_all(use_base64: bool = False) -> str:
1089
+ def export_graphs_all() -> str:
399
1090
  """
400
- Exports all graphs in memory to file paths (default) or base64-encoded SVGs.
1091
+ Exports all graphs in memory to file paths.
401
1092
 
402
- Args:
403
- use_base64: If True, returns base64-encoded images (token-intensive).
404
- If False (default), returns file paths to SVG files (token-efficient).
405
- Use file paths unless you need to embed images directly.
406
-
407
- Returns a JSON envelope listing graph names and either file paths or base64 images.
1093
+ Returns a JSON envelope listing graph names and file paths.
408
1094
  The agent can open SVG files directly to verify visuals (titles/labels/colors/legends).
409
1095
  """
410
- exports = client.export_graphs_all(use_base64=use_base64)
1096
+ exports = client.export_graphs_all()
411
1097
  return exports.model_dump_json(exclude_none=False)
412
1098
 
413
1099
  def main():
414
- # On Windows, Stata automation relies on COM, which is sensitive to threading models.
415
- # The FastMCP server executes tool calls in a thread pool. If Stata is initialized
416
- # lazily inside a worker thread, it may fail or hang due to COM/UI limitations.
417
- # We explicitly initialize Stata here on the main thread to ensure the COM server
418
- # is properly registered and accessible.
419
- if os.name == "nt":
1100
+ if "--version" in sys.argv:
420
1101
  try:
421
- client.init()
422
- except Exception as e:
423
- # Log error but let the server start; specific tools will fail gracefully later
424
- logging.error(f"Stata initialization failed: {e}")
1102
+ from importlib.metadata import version
1103
+ print(version("mcp-stata"))
1104
+ except Exception:
1105
+ print("unknown")
1106
+ return
1107
+
1108
+ setup_logging()
1109
+
1110
+ # Initialize Stata here on the main thread to ensure any issues are logged early.
1111
+ # On Windows, this is critical for COM registration. On other platforms, it helps
1112
+ # catch license or installation errors before the first tool call.
1113
+ try:
1114
+ client.init()
1115
+ except BaseException as e:
1116
+ # Use sys.stderr.write and flush to ensure visibility before exit
1117
+ msg = f"\n{'='*60}\n[mcp_stata] FATAL: STATA INITIALIZATION FAILED\n{'='*60}\nError: {repr(e)}\n"
1118
+ sys.stderr.write(msg)
1119
+ if isinstance(e, SystemExit):
1120
+ sys.stderr.write(f"Stata triggered a SystemExit (code: {e.code}). This is usually a license error.\n")
1121
+ sys.stderr.write(f"{'='*60}\n\n")
1122
+ sys.stderr.flush()
1123
+
1124
+ # We exit here because the user wants a clear failure when Stata cannot be loaded.
1125
+ sys.exit(1)
425
1126
 
426
1127
  mcp.run()
427
1128
 
428
1129
  if __name__ == "__main__":
429
- main()
1130
+ main()