mcp-stata 1.22.1__cp311-abi3-macosx_11_0_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mcp_stata/utils.py ADDED
@@ -0,0 +1,159 @@
1
+ from __future__ import annotations
2
+ import os
3
+ import tempfile
4
+ import pathlib
5
+ import uuid
6
+ import logging
7
+ import threading
8
+ import shutil
9
+ import atexit
10
+ import signal
11
+ import sys
12
+ from typing import Optional, List
13
+
14
# Logger used by the temp-directory helpers in this module.
logger = logging.getLogger("mcp_stata")

# One lock guards every piece of mutable state below: the cached temp-dir
# choice and the two cleanup registries.
_temp_dir_lock = threading.Lock()

# Result of get_writable_temp_dir(); stays None until a directory is validated.
_temp_dir_cache: Optional[str] = None

# Absolute paths scheduled for deletion by _cleanup_temp_resources().
_files_to_cleanup: set[pathlib.Path] = set()
_dirs_to_cleanup: set[pathlib.Path] = set()
20
+
21
+ def register_temp_file(path: str | pathlib.Path) -> None:
22
+ """
23
+ Register a file to be deleted on process exit.
24
+ Using this instead of NamedTemporaryFile(delete=True) because on Windows,
25
+ delete=True prevents Stata from opening the file simultaneously.
26
+ """
27
+ with _temp_dir_lock:
28
+ p = pathlib.Path(path).absolute()
29
+ _files_to_cleanup.add(p)
30
+
31
+ def register_temp_dir(path: str | pathlib.Path) -> None:
32
+ """Register a directory to be recursively deleted on process exit."""
33
+ with _temp_dir_lock:
34
+ p = pathlib.Path(path).absolute()
35
+ _dirs_to_cleanup.add(p)
36
+
37
def is_windows() -> bool:
    """Report whether the interpreter is running on Windows (``os.name == "nt"``)."""
    platform_family = os.name
    return platform_family == "nt"
40
+
41
@atexit.register
def _cleanup_temp_resources():
    """Delete every registered temporary file and directory (best effort).

    Registered with ``atexit`` and also invoked by the signal handler. Files
    are unlinked first (a missing file is not an error), then directories
    are removed recursively. Entries that were handled, or that no longer
    exist, are dropped from the registries, so repeated calls are cheap
    no-ops. Errors are swallowed on purpose: this runs during shutdown, where
    raising would only mask the real exit reason.
    """
    # ``_temp_dir_lock`` is a plain, non-reentrant Lock. The signal handler
    # runs on the main thread and may interrupt code that already holds it,
    # so a blocking acquire here could deadlock shutdown. Wait only briefly,
    # then proceed anyway. The registries are snapshotted with sorted()
    # before iteration, so discarding entries below is safe either way.
    acquired = _temp_dir_lock.acquire(timeout=0.1)
    try:
        for p in sorted(_files_to_cleanup, reverse=True):
            try:
                p.unlink(missing_ok=True)
            except Exception:
                continue  # best effort: keep the entry and try the rest
            _files_to_cleanup.discard(p)

        # Reverse order visits nested directories before their parents.
        for p in sorted(_dirs_to_cleanup, reverse=True):
            try:
                # is_dir() is False for paths that are already gone; those
                # are simply forgotten below.
                if p.is_dir():
                    shutil.rmtree(p, ignore_errors=True)
            except Exception:
                continue
            _dirs_to_cleanup.discard(p)
    finally:
        if acquired:
            _temp_dir_lock.release()
64
+
65
def _signal_handler(signum, frame):
    """Clean up registered temp resources, then exit the process.

    Exits with status ``128 + signum``, the conventional code for
    "terminated by signal N" (143 for SIGTERM, 130 for SIGINT), instead of 0.
    This way a parent process or supervisor does not mistake a kill for a
    clean finish. ``frame`` is unused; it is part of the handler signature
    that ``signal.signal`` requires.
    """
    _cleanup_temp_resources()
    sys.exit(128 + int(signum))
69
+
70
# Install SIGTERM/SIGINT handlers so registered temp resources are removed
# when the process is terminated by a signal. Installation is skipped under
# pytest, leaving the test runner's own handling intact, and off the main
# thread, where signal.signal() is not allowed.
try:
    is_pytest = "PYTEST_CURRENT_TEST" in os.environ or "pytest" in sys.modules
    if not is_pytest and threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGTERM, _signal_handler)
        signal.signal(signal.SIGINT, _signal_handler)
except (RuntimeError, ValueError):
    # Signal handling is not available in this context; nothing to install.
    pass
80
+
81
def _iter_temp_dir_candidates(errors: List[str]):
    """Yield ``(path, description)`` temp-dir candidates in priority order.

    Each location is computed lazily inside its own ``try``. One candidate
    that cannot be resolved is recorded in *errors* and skipped, rather than
    aborting the whole search. Examples: ``tempfile.gettempdir()`` raises
    when it finds no usable system temp directory, ``Path.home()`` raises
    when the home directory cannot be determined, and ``Path.cwd()`` raises
    when the working directory has been deleted.
    """
    resolvers = []
    env_temp = os.getenv("MCP_STATA_TEMP")
    if env_temp:
        resolvers.append(
            ("MCP_STATA_TEMP environment variable", lambda: pathlib.Path(env_temp))
        )
    resolvers.append(
        ("System temp directory", lambda: pathlib.Path(tempfile.gettempdir()))
    )
    resolvers.append(
        ("User home directory", lambda: pathlib.Path.home() / ".mcp-stata" / "temp")
    )
    resolvers.append(
        ("Working directory (.tmp)", lambda: pathlib.Path.cwd() / ".tmp")
    )

    for description, resolve in resolvers:
        try:
            path = resolve()
        except Exception as e:  # OSError / RuntimeError from the lookups above
            errors.append(f"<unresolved> ({description}): {e}")
            continue
        yield path, description


def _probe_writable_dir(path: pathlib.Path) -> None:
    """Create *path* if needed and prove that a file can be created inside it.

    Raises:
        OSError: if the directory cannot be created or written to.
    """
    path.mkdir(parents=True, exist_ok=True)
    fd, probe = tempfile.mkstemp(
        prefix=".mcp_write_test_",
        suffix=".tmp",
        dir=str(path),
    )
    os.close(fd)
    os.unlink(probe)


def get_writable_temp_dir() -> str:
    """
    Finds a writable temporary directory by trying multiple fallback locations.
    Priority:
    1. MCP_STATA_TEMP environment variable
    2. System Temp (tempfile.gettempdir())
    3. User Home subdirectory (~/.mcp-stata/temp)
    4. Current Working Directory subdirectory (.tmp)

    Results are cached after the first successful identification. The chosen
    directory is also assigned to ``tempfile.tempdir``, so later ``tempfile``
    calls in this process default to it.

    Raises:
        RuntimeError: if no candidate is usable; the message lists every
            failure.
    """
    global _temp_dir_cache

    with _temp_dir_lock:
        if _temp_dir_cache is not None:
            return _temp_dir_cache

        errors: List[str] = []
        for path, description in _iter_temp_dir_candidates(errors):
            try:
                _probe_writable_dir(path)
                validated_path = str(path.absolute())
            except OSError as e:
                errors.append(f"{path} ({description}): {e}")
                continue

            # Any recorded error means a higher-priority location was unusable
            # (including an explicitly configured MCP_STATA_TEMP), so make the
            # fallback visible.
            if errors:
                logger.warning(
                    "Falling back to temporary directory: %s (%s)",
                    validated_path,
                    description,
                )
            else:
                logger.debug(
                    "Using temporary directory: %s (%s)", validated_path, description
                )

            _temp_dir_cache = validated_path
            # Globally set tempfile.tempdir so other parts of the app and
            # libraries use our validated writable path by default.
            tempfile.tempdir = validated_path
            return validated_path

        error_msg = "Failed to find any writable temporary directory. Errors:\n" + "\n".join(errors)
        logger.error(error_msg)
        raise RuntimeError(error_msg)
mcp_stata/worker.py ADDED
@@ -0,0 +1,167 @@
1
+ from __future__ import annotations
2
+ import os
3
+ import sys
4
+ import threading
5
+ import logging
6
+ import json
7
+ import traceback
8
+ from typing import Any, Dict, Optional
9
+ from multiprocessing.connection import Connection
10
+ import asyncio
11
+
12
# Put the directory that contains the ``mcp_stata`` package on sys.path so
# the absolute ``mcp_stata`` import that follows resolves even when this file
# is loaded outside the normal package import machinery.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
)
14
+
15
+ from mcp_stata.stata_client import StataClient
16
+
17
# Child of the package-wide "mcp_stata" logger, dedicated to the worker process.
logger = logging.getLogger(name="mcp_stata.worker")
18
+
19
class StataWorker:
    """Serve StataClient requests that arrive over a multiprocessing pipe.

    Each request is a dict with a ``type``, an ``id`` that is echoed back on
    the replies for that request, and optional ``args``. Replies are dicts
    tagged by ``event``: ``"ready"`` once at startup, ``"log"`` and
    ``"progress"`` while a streaming command runs, then a ``"result"`` or an
    ``"error"``.
    """

    def __init__(self, conn: Connection):
        self.conn = conn
        self.client: Optional[StataClient] = None
        # Dedicated loop used to drive the client's async streaming methods
        # from the synchronous receive loop in run(); closed when run() ends.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def run(self):
        """Initialise Stata, announce ``ready``, then serve requests.

        Returns when a ``stop`` request arrives or when the parent closes its
        end of the pipe. Any other failure is logged and, where the pipe
        still works, reported to the parent as an ``error`` event. The event
        loop and the connection are always closed on the way out.
        """
        try:
            # Initialise Stata in this process eagerly rather than lazily on
            # the first command.
            self.client = StataClient()
            self.client.init()

            logger.info("StataWorker initialized and ready.")
            self.conn.send({"event": "ready", "pid": os.getpid()})

            while True:
                if not self.conn.poll(0.1):
                    continue
                msg = self.conn.recv()
                if not isinstance(msg, dict):
                    # Report malformed requests instead of dying on msg.get().
                    self.conn.send({
                        "event": "error",
                        "id": None,
                        "message": f"Malformed message of type {type(msg).__name__}; expected a dict",
                    })
                    continue
                if msg.get("type") == "stop":
                    break

                # Handle command
                self.loop.run_until_complete(self.handle_message(msg))
        except EOFError:
            # recv() raises EOFError once the parent has closed its end of the
            # pipe; nobody is left to report to, so shut down quietly.
            logger.info("Parent closed the connection; stopping worker.")
        except Exception as e:
            logger.exception("Worker process failed: %s", e)
            self._send_quietly({"event": "error", "message": str(e), "traceback": traceback.format_exc()})
        finally:
            logger.info("Worker process exiting.")
            self.loop.close()
            self.conn.close()

    def _send_quietly(self, payload: Dict[str, Any]) -> None:
        """Send *payload*, ignoring pipe failures.

        Used only for error reports, where the connection itself may be what
        broke. A second exception here would otherwise replace the original
        one and escape the worker.
        """
        try:
            self.conn.send(payload)
        except OSError:
            logger.debug("Could not deliver error report to parent", exc_info=True)

    async def handle_message(self, msg: Dict[str, Any]):
        """Execute one request and send its ``result`` (or ``error``) reply.

        Streaming commands also emit ``log`` and ``progress`` events tagged
        with the request id. Exceptions raised while serving the request are
        reported to the parent rather than propagated, so one bad request
        does not stop the worker.
        """
        msg_type = msg.get("type")
        msg_id = msg.get("id")
        # Treat an explicit ``"args": None`` the same as omitted args.
        args = msg.get("args") or {}

        async def notify_log(text: str):
            self.conn.send({"event": "log", "id": msg_id, "text": text})

        async def notify_progress(progress: float, total: Optional[float], message: Optional[str]):
            self.conn.send({"event": "progress", "id": msg_id, "progress": progress, "total": total, "message": message})

        client = self.client
        try:
            if msg_type == "run_command":
                outcome = await client.run_command_streaming(
                    args["code"],
                    notify_log=notify_log,
                    notify_progress=notify_progress,
                    **args.get("options", {}),
                )
                result = outcome.model_dump()
            elif msg_type == "run_do_file":
                outcome = await client.run_do_file_streaming(
                    args["path"],
                    notify_log=notify_log,
                    notify_progress=notify_progress,
                    **args.get("options", {}),
                )
                result = outcome.model_dump()
            elif msg_type == "get_data":
                result = client.get_data(args.get("start", 0), args.get("count", 50))
            elif msg_type == "list_graphs":
                result = client.list_graphs_structured().model_dump()
            elif msg_type == "export_graph":
                result = client.export_graph(args.get("graph_name"), format=args.get("format", "pdf"))
            elif msg_type == "get_help":
                result = client.get_help(args["topic"], plain_text=args.get("plain_text", False))
            elif msg_type == "run_command_structured":
                result = client.run_command_structured(args["code"], **args.get("options", {})).model_dump()
            elif msg_type == "load_data":
                result = client.load_data(args["source"], **args.get("options", {})).model_dump()
            elif msg_type == "codebook":
                result = client.codebook(args["variable"], **args.get("options", {})).model_dump()
            elif msg_type == "get_dataset_state":
                result = client.get_dataset_state()
            elif msg_type == "get_arrow_stream":
                # StataClient.get_arrow_stream supports offset, limit, vars, etc.
                result = client.get_arrow_stream(**args)
            elif msg_type == "list_variables_rich":
                result = client.list_variables_rich()
            elif msg_type == "compute_view_indices":
                result = client.compute_view_indices(args["filter_expr"])
            elif msg_type == "validate_filter_expr":
                client.validate_filter_expr(args["filter_expr"])
                result = None
            elif msg_type == "get_page":
                result = client.get_page(**args)
            elif msg_type == "list_variables_structured":
                result = client.list_variables_structured().model_dump()
            elif msg_type == "export_graphs_all":
                result = client.export_graphs_all().model_dump()
            elif msg_type == "get_stored_results":
                result = client.get_stored_results()
            else:
                self.conn.send({"event": "error", "id": msg_id, "message": f"Unknown message type: {msg_type}"})
                return

            # Sent inside the try so that an unpicklable result is reported
            # as an error event instead of escaping.
            self.conn.send({"event": "result", "id": msg_id, "result": result})

        except Exception as e:
            logger.exception("Error handling message %s: %s", msg_type, e)
            self._send_quietly({
                "event": "error",
                "id": msg_id,
                "message": str(e),
                "traceback": traceback.format_exc(),
            })
159
+
160
def main(conn):
    """Worker process entry point: serve requests arriving on *conn* until told to stop."""
    StataWorker(conn).run()
163
+
164
if __name__ == "__main__":
    # Intentionally a no-op. The worker needs a multiprocessing Connection,
    # which the parent supplies by calling main(conn); executing this file
    # directly has nothing to serve.
    ...