stata-code 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stata_code/__init__.py +100 -0
- stata_code/core/__init__.py +73 -0
- stata_code/core/_pool.py +808 -0
- stata_code/core/_refs.py +97 -0
- stata_code/core/_runtime.py +179 -0
- stata_code/core/errors.py +447 -0
- stata_code/core/runner.py +1092 -0
- stata_code/core/schema.py +317 -0
- stata_code/kernel/__init__.py +5 -0
- stata_code/kernel/__main__.py +6 -0
- stata_code/kernel/kernel.py +331 -0
- stata_code/mcp/__init__.py +3 -0
- stata_code/mcp/__main__.py +6 -0
- stata_code/mcp/server.py +360 -0
- stata_code-0.3.0.dist-info/METADATA +389 -0
- stata_code-0.3.0.dist-info/RECORD +20 -0
- stata_code-0.3.0.dist-info/WHEEL +4 -0
- stata_code-0.3.0.dist-info/entry_points.txt +3 -0
- stata_code-0.3.0.dist-info/licenses/LICENSE +21 -0
- stata_code-0.3.0.dist-info/licenses/LICENSE-POLICY.md +125 -0
stata_code/core/_pool.py
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
1
|
+
"""Subprocess pool for hard timeout enforcement.
|
|
2
|
+
|
|
3
|
+
`runner.execute()` runs pystata in-process, which is fast but offers no
|
|
4
|
+
clean way to cancel a long-running Stata command (`bayes`, `bootstrap,
|
|
5
|
+
reps(10000)`, infinite loop, ...). pystata holds the GIL inside C land
|
|
6
|
+
and ignores Python signals until it returns. v0.2 documented this:
|
|
7
|
+
`timeout_ms` was accepted but not enforced.
|
|
8
|
+
|
|
9
|
+
This module fills that gap. `pool_execute()` is a drop-in for
|
|
10
|
+
`runner.execute()` that routes the call through a per-session
|
|
11
|
+
subprocess. The parent enforces `timeout_ms` by reading from the
|
|
12
|
+
worker's stdout with a deadline; on overrun, the worker is SIGTERMed
|
|
13
|
+
(grace period) then SIGKILLed, and the parent returns a synthetic
|
|
14
|
+
`RunResult(rc=-2, error.kind="timeout")`. The dead worker is dropped
|
|
15
|
+
from the pool and respawned on the next call to that `session_id`.
|
|
16
|
+
|
|
17
|
+
Design choices:
|
|
18
|
+
|
|
19
|
+
- **One worker per session_id.** Stata is single-threaded; serialize per
|
|
20
|
+
session at the worker level. Different sessions get different
|
|
21
|
+
workers and run truly in parallel.
|
|
22
|
+
- **Workers are warm.** First call to a new session pays the pystata
|
|
23
|
+
init cost (~3-10s); subsequent calls are pipe-roundtrip + JSON only
|
|
24
|
+
(typically <50ms overhead).
|
|
25
|
+
- **Refs are ferried.** The worker's `_refs` store is local to that
|
|
26
|
+
process. After each `execute()`, the parent harvests any newly-
|
|
27
|
+
created refs and re-puts them in its OWN `_refs` so that
|
|
28
|
+
`get_log/get_graph/get_matrix` calls served by the parent (the MCP
|
|
29
|
+
server, typically) hit the parent's store directly without IPC.
|
|
30
|
+
- **Wire protocol.** Newline-delimited JSON. Request: one line in
|
|
31
|
+
`{id, op, code, options}` (`op` defaults to `execute`). Response: one line in
|
|
32
|
+
`{id, ok, result, ref_blobs}` (or `{id, ok=false, error}`).
|
|
33
|
+
|
|
34
|
+
Not exposed in the public API. Frontends import `pool_execute` from
|
|
35
|
+
this module if they want the timeout guarantee.
|
|
36
|
+
"""
|
|
37
|
+
|
|
38
|
+
from __future__ import annotations
|
|
39
|
+
|
|
40
|
+
import base64
|
|
41
|
+
import json
|
|
42
|
+
import os
|
|
43
|
+
import subprocess
|
|
44
|
+
import sys
|
|
45
|
+
import threading
|
|
46
|
+
import time
|
|
47
|
+
import uuid
|
|
48
|
+
from datetime import datetime, timezone
|
|
49
|
+
from typing import Any
|
|
50
|
+
|
|
51
|
+
from stata_code.core import _refs
|
|
52
|
+
from stata_code.core.schema import (
|
|
53
|
+
Backend,
|
|
54
|
+
DatasetInfo,
|
|
55
|
+
ErrorContext,
|
|
56
|
+
ErrorInfo,
|
|
57
|
+
ErrorKind,
|
|
58
|
+
LogInfo,
|
|
59
|
+
ResultsInfo,
|
|
60
|
+
RunResult,
|
|
61
|
+
StataInfo,
|
|
62
|
+
StataReturns,
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
# Grace period after SIGTERM before SIGKILL.
|
|
66
|
+
_KILL_GRACE_S = 2.0  # seconds; used as the wait() timeout after terminate() and again after kill()
|
|
67
|
+
|
|
68
|
+
# Read-poll interval while waiting for worker stdout.
|
|
69
|
+
_POLL_INTERVAL_S = 0.05  # seconds; upper bound on each join() slice while waiting for the reader thread
|
|
70
|
+
|
|
71
|
+
# Default pool capacity: how many distinct sessions to keep warm.
|
|
72
|
+
_DEFAULT_POOL_CAPACITY = 8  # default value of SessionPool(capacity=...)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
76
|
+
# Worker side: runs in the SUBPROCESS, not the parent.
|
|
77
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
# Sole key of the single-entry dict that stands in for a bytes value on the wire.
_BYTES_MARKER = "__bytes_b64__"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _to_jsonable(obj: Any) -> Any:
|
|
84
|
+
"""Recursively rewrite bytes/bytearray as `{__bytes_b64__: <base64>}`.
|
|
85
|
+
|
|
86
|
+
Other primitives (str, int, float, bool, None) pass through. Nested
|
|
87
|
+
dicts and lists are walked. This lets us ferry the runner's graph
|
|
88
|
+
payload — which is a dict that *contains* bytes (`{"format": "png",
|
|
89
|
+
"bytes": <data>, "width": ..., "height": ...}`) — over a JSON pipe
|
|
90
|
+
without losing structure.
|
|
91
|
+
"""
|
|
92
|
+
if isinstance(obj, (bytes, bytearray)):
|
|
93
|
+
return {_BYTES_MARKER: base64.b64encode(bytes(obj)).decode("ascii")}
|
|
94
|
+
if isinstance(obj, dict):
|
|
95
|
+
return {k: _to_jsonable(v) for k, v in obj.items()}
|
|
96
|
+
if isinstance(obj, list):
|
|
97
|
+
return [_to_jsonable(v) for v in obj]
|
|
98
|
+
if isinstance(obj, tuple):
|
|
99
|
+
return [_to_jsonable(v) for v in obj]
|
|
100
|
+
return obj
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _from_jsonable(obj: Any) -> Any:
|
|
104
|
+
"""Inverse of `_to_jsonable`. Restores bytes from the marker form."""
|
|
105
|
+
if isinstance(obj, dict):
|
|
106
|
+
if _BYTES_MARKER in obj and len(obj) == 1:
|
|
107
|
+
data = obj[_BYTES_MARKER]
|
|
108
|
+
return base64.b64decode(data) if isinstance(data, str) else b""
|
|
109
|
+
return {k: _from_jsonable(v) for k, v in obj.items()}
|
|
110
|
+
if isinstance(obj, list):
|
|
111
|
+
return [_from_jsonable(v) for v in obj]
|
|
112
|
+
return obj
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _payload_to_wire(payload: Any) -> dict[str, Any]:
|
|
116
|
+
"""Serialize a `_refs` payload to a JSON-safe envelope.
|
|
117
|
+
|
|
118
|
+
`_refs` stores three payload shapes today:
|
|
119
|
+
- bytes: graph data (legacy / direct path)
|
|
120
|
+
- str: full log text (legacy / direct path)
|
|
121
|
+
- dict: log info `{text, lines_total, bytes_total}`,
|
|
122
|
+
matrix info `{rows, cols, values}`,
|
|
123
|
+
graph info `{format, bytes, width, height}` — note the
|
|
124
|
+
nested bytes inside this last shape, which is why dict
|
|
125
|
+
payloads go through `_to_jsonable`.
|
|
126
|
+
"""
|
|
127
|
+
if isinstance(payload, (bytes, bytearray)):
|
|
128
|
+
return {"kind": "bytes", "data": base64.b64encode(bytes(payload)).decode("ascii")}
|
|
129
|
+
if isinstance(payload, str):
|
|
130
|
+
return {"kind": "text", "data": payload}
|
|
131
|
+
if isinstance(payload, dict):
|
|
132
|
+
return {"kind": "json", "data": _to_jsonable(payload)}
|
|
133
|
+
# Conservative fallback — wrap repr; the parent will ignore unknown kinds.
|
|
134
|
+
return {"kind": "unknown", "data": repr(payload)}
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _payload_from_wire(envelope: dict[str, Any]) -> Any:
|
|
138
|
+
kind = envelope.get("kind")
|
|
139
|
+
data = envelope.get("data")
|
|
140
|
+
if kind == "bytes":
|
|
141
|
+
return base64.b64decode(data) if isinstance(data, str) else b""
|
|
142
|
+
if kind == "text":
|
|
143
|
+
return data if isinstance(data, str) else ""
|
|
144
|
+
if kind == "json":
|
|
145
|
+
return _from_jsonable(data)
|
|
146
|
+
return None
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def _worker_main() -> int:
    """Run the worker's request loop until the parent closes the pipe.

    Reads one JSON request per line from the inherited stdin and writes one
    JSON response per line to the inherited stdout. Supported ops are
    `ping`, `list_sessions` and `execute` (the default when `op` is
    absent); any other op gets an `ok=false` reply.

    Every failure while handling a request -- malformed JSON, an exception
    from the runner, or a response that cannot be JSON-encoded -- is
    reported to the parent as `{id, ok=false, error}`. The loop itself only
    ends on EOF.

    Returns:
        0 once EOF is read on the protocol input.
    """
    # CRITICAL: dup FDs 0 and 1 BEFORE pystata gets imported / initialized.
    # `pystata.config.init()` reaches into the runtime's FDs (closes /
    # redirects stdin, redirects stdout to its own buffer). If we read /
    # write through `sys.stdin` / `sys.stdout` the protocol breaks the
    # moment Stata initializes — subsequent reads see EOF and the worker
    # exits returncode=0 without responding to the second request.
    #
    # Duping the file descriptors gives us a reader/writer rooted at a
    # *separate* FD that pystata cannot reach via the sys.* indirection.
    saved_stdin_fd = os.dup(0)
    saved_stdout_fd = os.dup(1)
    proto_in = os.fdopen(saved_stdin_fd, "r", buffering=1, encoding="utf-8")
    proto_out = os.fdopen(saved_stdout_fd, "w", buffering=1, encoding="utf-8")

    # Imported inside the function (after the FD dup above) rather than at
    # module top level, so merely importing this module in the parent does
    # not pull in the runner.
    from stata_code.core.runner import execute

    # Explicit readline() loop instead of `for line in proto_in` — the latter
    # uses the io module's buffered iterator, which read-aheads more bytes
    # than are available on a pipe and breaks the request/response cadence
    # after pystata init.
    while True:
        line = proto_in.readline()
        if not line:
            break  # EOF — parent closed the pipe
        line = line.strip()
        if not line:
            continue
        req_id: str | None = None
        try:
            req = json.loads(line)
            req_id = req.get("id")
            op = req.get("op", "execute")
            if op == "ping":
                response = {"id": req_id, "ok": True, "pong": True}
            elif op == "list_sessions":
                # Imported lazily — calling list_sessions() on a worker that
                # hasn't yet had any execute() request still triggers pystata
                # init, which is the price of an honest answer.
                from stata_code.core.runner import list_sessions as _ls

                response = {"id": req_id, "ok": True, "sessions": _ls()}
            elif op == "execute":
                code = req["code"]
                options = req.get("options", {})
                # Snapshot ref keys before so we can ferry only the *new* ones.
                # _refs._store is private but we own this codebase.
                with _refs._lock:  # noqa: SLF001
                    keys_before = set(_refs._store.keys())  # noqa: SLF001
                result = execute(code, **options)
                with _refs._lock:  # noqa: SLF001
                    keys_after = set(_refs._store.keys())  # noqa: SLF001
                ref_blobs: dict[str, dict[str, Any]] = {}
                for key in keys_after - keys_before:
                    payload = _refs.get(key)
                    if payload is None:
                        continue
                    ref_blobs[key] = _payload_to_wire(payload)
                response = {
                    "id": req_id,
                    "ok": True,
                    "result": json.loads(result.model_dump_json()),
                    "ref_blobs": ref_blobs,
                }
            else:
                response = {"id": req_id, "ok": False, "error": f"unknown op: {op}"}
            # Encode inside the try: a payload json cannot serialize must be
            # reported as a failed request, not crash the worker.
            wire = json.dumps(response)
        except Exception as exc:  # noqa: BLE001
            # `req_id` is either None or a value that itself came out of
            # json.loads, so this fallback envelope always serializes.
            wire = json.dumps(
                {
                    "id": req_id,
                    "ok": False,
                    "error": f"{type(exc).__name__}: {exc}",
                }
            )
        proto_out.write(wire + "\n")
        proto_out.flush()
    return 0
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
234
|
+
# Parent side: WorkerProcess + SessionPool + pool_execute().
|
|
235
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def _default_worker_cmd() -> list[str]:
|
|
239
|
+
return [sys.executable, "-u", "-m", "stata_code.core._pool"]
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class _WorkerError(RuntimeError):
|
|
243
|
+
"""Raised on worker-side execution failure (non-Stata, e.g., crash)."""
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class _WorkerTimeout(TimeoutError):
|
|
247
|
+
"""Raised when a request exceeds its deadline. Parent kills the worker."""
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class WorkerProcess:
    """Parent-side handle for one subprocess worker.

    Construct via `WorkerProcess(session_id, ...)` — the subprocess is
    spawned lazily on first `execute()`. Use `kill()` to terminate.

    All request/response traffic and `kill()` are serialized by one
    per-worker lock, so `kill()` called from another thread waits for an
    in-flight request to return (or time out) before it acts.

    The worker's stderr is drained continuously by a daemon thread so the
    child can never block on a full stderr pipe; the last few lines are
    kept and appended to the error raised when the worker dies or breaks
    the protocol.
    """

    # How many trailing stderr chunks are retained for crash diagnostics.
    _STDERR_TAIL_LINES = 50
    # Upper bound on the stderr text appended to an error message.
    _STDERR_EXCERPT_CHARS = 2000

    def __init__(
        self,
        session_id: str,
        *,
        worker_cmd: list[str] | None = None,
    ) -> None:
        """Create an unstarted handle.

        Args:
            session_id: Session this worker serves (used in error messages).
            worker_cmd: argv for the subprocess; defaults to
                `_default_worker_cmd()`. The list is copied.
        """
        self.session_id = session_id
        self._cmd = list(worker_cmd) if worker_cmd is not None else _default_worker_cmd()
        self._proc: subprocess.Popen[str] | None = None
        self._lock = threading.Lock()
        self.last_used: float = time.monotonic()
        self._stderr_tail: list[str] = []
        self._stderr_thread: threading.Thread | None = None

    def _spawn(self) -> subprocess.Popen[str]:
        """Start the subprocess and its stderr drain thread."""
        env = os.environ.copy()
        # Force unbuffered I/O even if PYTHONUNBUFFERED isn't already set.
        env.setdefault("PYTHONUNBUFFERED", "1")
        proc = subprocess.Popen(
            self._cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            # Undecodable bytes on stderr must not kill the drain thread;
            # protocol traffic is ASCII-only JSON and is unaffected.
            errors="replace",
            bufsize=1,  # line-buffered
            env=env,
        )
        # Fresh tail per process so a respawn never shows a previous
        # incarnation's output.
        tail: list[str] = []
        drainer = threading.Thread(
            target=self._drain_stderr, args=(proc, tail), daemon=True
        )
        drainer.start()
        self._stderr_tail = tail
        self._stderr_thread = drainer
        return proc

    @staticmethod
    def _drain_stderr(proc: subprocess.Popen[str], tail: list[str]) -> None:
        """Read `proc.stderr` to EOF, keeping only the most recent chunks.

        Runs in a daemon thread. Without it, a worker that writes more than
        one pipe buffer to stderr would block forever, because nothing else
        reads that pipe while the worker is running.
        """
        stream = proc.stderr
        if stream is None:
            return
        try:
            while True:
                # Bounded readline so a newline-free flood cannot grow memory.
                chunk = stream.readline(8192)
                if not chunk:
                    return
                tail.append(chunk)
                if len(tail) > WorkerProcess._STDERR_TAIL_LINES:
                    del tail[0]
        except (OSError, ValueError):
            # Pipe torn down underneath us; nothing left to drain.
            return

    def _stderr_excerpt(self, proc: subprocess.Popen[str]) -> str:
        """Return the retained stderr tail (possibly empty), best effort."""
        drainer = self._stderr_thread
        if drainer is not None and proc.poll() is not None:
            # The process is gone, so EOF is normally imminent; give the
            # drain thread a moment to pick up the final lines.
            drainer.join(timeout=0.2)
        text = "".join(list(self._stderr_tail)).strip()
        return text[-self._STDERR_EXCERPT_CHARS:]

    def _ensure_alive(self) -> subprocess.Popen[str]:
        """Return the running subprocess, spawning one if none is running."""
        if self._proc is None or self._proc.poll() is not None:
            self._proc = self._spawn()
        return self._proc

    def is_alive(self) -> bool:
        """Return True if a subprocess has been spawned and has not exited."""
        return self._proc is not None and self._proc.poll() is None

    def _roundtrip(
        self,
        proc: subprocess.Popen[str],
        request: dict[str, Any],
        timeout_ms: int | None,
    ) -> dict[str, Any]:
        """Send `request` and return the validated response dict.

        The caller must hold `self._lock`. The deadline starts after the
        request has been written.

        Raises:
            _WorkerTimeout: no response line arrived within `timeout_ms`.
            _WorkerError: write failure, worker exit, non-JSON or non-object
                response, id mismatch, or an `ok=false` reply.
        """
        assert proc.stdin is not None and proc.stdout is not None  # for mypy
        try:
            proc.stdin.write(json.dumps(request) + "\n")
            proc.stdin.flush()
        except OSError as exc:
            # BrokenPipeError is the usual case; other OSErrors from a dead
            # pipe are treated the same way instead of escaping uncaught.
            raise _WorkerError(f"worker pipe broken on write: {exc}") from exc

        deadline: float | None
        deadline = None if timeout_ms is None else time.monotonic() + timeout_ms / 1000.0

        try:
            line = self._readline_with_deadline(proc, deadline)
        except _WorkerError as exc:
            excerpt = self._stderr_excerpt(proc)
            if not excerpt:
                raise
            raise _WorkerError(f"{exc}; worker stderr: {excerpt}") from exc
        self.last_used = time.monotonic()

        try:
            response = json.loads(line)
        except json.JSONDecodeError as exc:
            raise _WorkerError(f"worker emitted non-JSON: {line!r}") from exc
        if not isinstance(response, dict):
            raise _WorkerError(f"worker emitted a non-object response: {line!r}")

        req_id = request["id"]
        if response.get("id") != req_id:
            raise _WorkerError(
                f"worker response id mismatch: expected {req_id}, got {response.get('id')}"
            )
        if not response.get("ok"):
            raise _WorkerError(
                f"worker reported failure: {response.get('error', '<no error>')}"
            )
        return response

    def execute(
        self,
        code: str,
        options: dict[str, Any],
        *,
        timeout_ms: int | None,
    ) -> tuple[dict[str, Any], dict[str, dict[str, Any]]]:
        """Send one execute request and return (result_dict, ref_blobs).

        Spawns (or respawns) the subprocess if it is not running.

        Raises `_WorkerTimeout` on timeout (caller is responsible for
        killing the worker), `_WorkerError` on protocol or worker-side
        crash.
        """
        with self._lock:
            proc = self._ensure_alive()
            request = {
                "id": uuid.uuid4().hex,
                "op": "execute",
                "code": code,
                "options": options,
            }
            response = self._roundtrip(proc, request, timeout_ms)
            if "result" not in response:
                raise _WorkerError("worker response is missing 'result'")
            return response["result"], response.get("ref_blobs", {})

    @staticmethod
    def _readline_with_deadline(
        proc: subprocess.Popen[str],
        deadline: float | None,
    ) -> str:
        """Read one line from `proc.stdout` honoring an optional wall-clock
        deadline (a `time.monotonic()` value). Raises `_WorkerTimeout` on
        overrun and `_WorkerError` if the worker exits without answering.

        Implementation note: the blocking `readline()` runs in a daemon
        thread, which is joined in slices of at most `_POLL_INTERVAL_S`
        while the deadline and `proc.poll()` are re-checked. This is
        portable (no select on Windows pipes) and robust for the
        line-oriented protocol.
        """
        assert proc.stdout is not None
        result: dict[str, str | BaseException] = {}

        def _reader() -> None:
            try:
                result["line"] = proc.stdout.readline()  # type: ignore[union-attr]
            except BaseException as exc:  # noqa: BLE001
                result["err"] = exc

        thread = threading.Thread(target=_reader, daemon=True)
        thread.start()

        while True:
            wait_s = _POLL_INTERVAL_S
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise _WorkerTimeout("deadline exceeded waiting for worker response")
                wait_s = min(_POLL_INTERVAL_S, remaining)
            thread.join(timeout=wait_s)
            if not thread.is_alive():
                if "err" in result:
                    raise _WorkerError(f"reader thread error: {result['err']!r}")
                line = result.get("line", "")
                assert isinstance(line, str)
                if not line:
                    # EOF — worker exited or pipe closed unexpectedly.
                    rc = proc.poll()
                    raise _WorkerError(f"worker exited (returncode={rc}) before responding")
                return line
            # Reader still blocked. If the process died, surface that.
            if proc.poll() is not None:
                # Wait briefly for any final bytes the reader thread might catch.
                thread.join(timeout=0.1)
                last = result.get("line")
                if isinstance(last, str) and last:
                    return last
                rc = proc.returncode
                raise _WorkerError(f"worker exited (returncode={rc}) before responding")

    def send_simple_op(
        self,
        op: str,
        *,
        timeout_ms: int | None,
    ) -> dict[str, Any]:
        """Send a no-payload op (e.g., ``ping``, ``list_sessions``) and return
        the full response dict.

        Unlike :meth:`execute`, this does **not** respawn a dead worker —
        if the subprocess isn't running, raises :class:`_WorkerError`.
        Caller should treat that as "this worker has nothing to report"
        rather than block on a fresh pystata init for a status query.
        """
        with self._lock:
            proc = self._proc
            if proc is None or proc.poll() is not None:
                raise _WorkerError(f"worker for {self.session_id!r} not running")
            request = {"id": uuid.uuid4().hex, "op": op}
            return self._roundtrip(proc, request, timeout_ms)

    def kill(self) -> None:
        """Terminate the worker. SIGTERM with grace, then SIGKILL."""
        with self._lock:
            self._kill_locked()

    def _kill_locked(self) -> None:
        """Terminate and forget the subprocess. Caller holds `self._lock`.

        stderr is deliberately not read here: the drain thread already
        consumes it, and a blocking read could hang for as long as some
        descendant of the worker keeps the pipe's write end open.
        """
        proc = self._proc
        if proc is None:
            return
        self._proc = None
        if proc.poll() is not None:
            return
        try:
            proc.terminate()
            try:
                proc.wait(timeout=_KILL_GRACE_S)
            except subprocess.TimeoutExpired:
                proc.kill()
                try:
                    proc.wait(timeout=_KILL_GRACE_S)
                except subprocess.TimeoutExpired:
                    pass
        except ProcessLookupError:
            pass
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
class SessionPool:
    """LRU pool of subprocess workers, keyed by `session_id`.

    Holds at most `capacity` worker handles. When adding a handle for a new
    session pushes the pool over capacity, the handles with the oldest
    `last_used` are removed and killed (never the one just added).

    The pool lock only protects the `session_id -> worker` mapping; workers
    are always killed after it has been released, so one slow or busy
    worker cannot stall requests for other sessions.
    """

    def __init__(
        self,
        *,
        capacity: int = _DEFAULT_POOL_CAPACITY,
        worker_cmd: list[str] | None = None,
    ) -> None:
        """Create an empty pool.

        Args:
            capacity: Maximum number of worker handles kept; must be >= 1.
            worker_cmd: argv forwarded to every `WorkerProcess`.

        Raises:
            ValueError: if `capacity` is less than 1.
        """
        if capacity < 1:
            raise ValueError("capacity must be ≥ 1")
        self._capacity = capacity
        self._worker_cmd = worker_cmd
        self._workers: dict[str, WorkerProcess] = {}
        self._lock = threading.Lock()

    @property
    def capacity(self) -> int:
        """Maximum number of worker handles the pool keeps."""
        return self._capacity

    def _get_or_spawn(self, session_id: str) -> WorkerProcess:
        """Return the handle for `session_id`, creating one if needed.

        A handle whose subprocess was started and has since exited is
        replaced. A handle that has simply not been started yet is reused:
        workers spawn lazily, so treating "not started" as "dead" would let
        two concurrent first calls for the same session each install their
        own handle and leave one subprocess untracked.
        """
        retired: list[WorkerProcess] = []
        with self._lock:
            worker = self._workers.get(session_id)
            if (
                worker is not None
                and worker._proc is not None  # noqa: SLF001
                and not worker.is_alive()
            ):
                # Existing-but-dead — clean up before respawn.
                retired.append(worker)
                worker = None
            if worker is None:
                worker = WorkerProcess(session_id, worker_cmd=self._worker_cmd)
                self._workers[session_id] = worker
                retired.extend(self._pop_lru_locked(keep=session_id))
        # kill() can wait on a busy worker's lock and then on process exit;
        # do that without holding the pool lock.
        for old in retired:
            old.kill()
        return worker

    def _pop_lru_locked(self, *, keep: str) -> list[WorkerProcess]:
        """Remove and return the least-recently-used handles over capacity.

        Caller holds `self._lock`. The session named by `keep` is never
        selected. Returned workers are NOT killed here.
        """
        excess = len(self._workers) - self._capacity
        if excess <= 0:
            return []
        oldest_first = sorted(
            (sid for sid in self._workers if sid != keep),
            key=lambda sid: self._workers[sid].last_used,
        )
        return [self._workers.pop(sid) for sid in oldest_first[:excess]]

    def _evict_to_capacity_locked(self, *, keep: str) -> None:
        """Evict and kill LRU workers until the pool is within capacity.

        Caller holds `self._lock`. Kept for callers that want eviction and
        kill in one step; `_get_or_spawn` uses `_pop_lru_locked` instead so
        that the kills happen outside the pool lock.
        """
        for victim in self._pop_lru_locked(keep=keep):
            victim.kill()

    def _discard(self, session_id: str, worker: WorkerProcess) -> None:
        """Kill `worker` and drop it from the pool if it is still registered.

        The identity check keeps a failure on an old handle from removing a
        newer handle that another thread registered for the same session.
        """
        worker.kill()
        with self._lock:
            if self._workers.get(session_id) is worker:
                del self._workers[session_id]

    def execute(
        self,
        code: str,
        *,
        session_id: str = "main",
        timeout_ms: int | None = 600_000,
        **options: Any,
    ) -> RunResult:
        """Execute `code` in the session's worker, enforcing `timeout_ms`.

        On timeout: SIGTERM/SIGKILL the worker and return a synthetic
        `RunResult(rc=-2, error.kind="timeout")`. On any other worker
        failure: kill the worker and return a synthetic adapter-crash
        result (`rc=-1`). Subsequent calls to the same `session_id` will
        respawn a fresh worker.

        On success, refs created by the worker are copied into this
        process's `_refs` store before the result is returned.
        """
        # The worker receives `session_id` so it can route to the right
        # session, and `timeout_ms` purely as an informational option. The
        # timeout itself is enforced here in the parent, by the deadline
        # applied while waiting for the worker's response. `timeout_ms` is a
        # named parameter, so `options` can never already contain that key.
        worker_options = {**options, "session_id": session_id, "timeout_ms": timeout_ms}
        worker = self._get_or_spawn(session_id)
        started = time.monotonic()
        try:
            result_dict, ref_blobs = worker.execute(
                code, worker_options, timeout_ms=timeout_ms
            )
        except _WorkerTimeout:
            self._discard(session_id, worker)
            elapsed_ms = int((time.monotonic() - started) * 1000)
            return _build_timeout_result(
                session_id=session_id,
                elapsed_ms=elapsed_ms,
                timeout_ms=timeout_ms or 0,
            )
        except _WorkerError as exc:
            self._discard(session_id, worker)
            elapsed_ms = int((time.monotonic() - started) * 1000)
            return _build_adapter_crash_result(
                session_id=session_id,
                elapsed_ms=elapsed_ms,
                message=str(exc),
            )

        # Ferry refs into the parent's _refs store.
        for ref_id, envelope in ref_blobs.items():
            payload = _payload_from_wire(envelope)
            if payload is not None:
                _refs.put(ref_id, payload)

        return RunResult.model_validate(result_dict)

    def kill_session(self, session_id: str) -> bool:
        """Terminate a session's worker. Returns True if a worker existed."""
        with self._lock:
            worker = self._workers.pop(session_id, None)
        if worker is None:
            return False
        worker.kill()
        return True

    def shutdown(self) -> None:
        """Kill all workers."""
        with self._lock:
            workers = list(self._workers.values())
            self._workers.clear()
        for worker in workers:
            worker.kill()

    def session_ids(self) -> list[str]:
        """Return the session ids that currently have a handle in the pool."""
        with self._lock:
            return list(self._workers)

    def list_session_info(
        self,
        *,
        per_worker_timeout_ms: int = 5000,
    ) -> list[dict[str, Any]]:
        """Aggregate live-session info across all workers.

        For each worker that's alive, sends ``op=list_sessions`` and pulls
        back ``[{session_id, frame, n_obs}, ...]`` from that worker's
        pystata. The pool dedupes by ``session_id`` (first writer wins) and
        returns the union.

        Workers that are dead, that fail to respond within
        ``per_worker_timeout_ms``, or that raise a protocol error are
        silently skipped — partial information is better than failing the
        whole list call. Workers that haven't yet served an ``execute``
        will pay the pystata-init cost on the next ``stata_run``, not here:
        :meth:`WorkerProcess.send_simple_op` deliberately does **not**
        respawn dead workers.
        """
        with self._lock:
            workers = list(self._workers.values())
        sessions: list[dict[str, Any]] = []
        seen: set[str] = set()
        for worker in workers:
            if not worker.is_alive():
                continue
            try:
                response = worker.send_simple_op(
                    "list_sessions", timeout_ms=per_worker_timeout_ms
                )
            except (_WorkerError, _WorkerTimeout):
                continue
            for entry in response.get("sessions") or []:
                sid = entry.get("session_id")
                if sid is None or sid in seen:
                    continue
                seen.add(sid)
                sessions.append(entry)
        return sessions
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
635
|
+
# Synthetic-result builders for timeout / adapter crash.
|
|
636
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
637
|
+
|
|
638
|
+
|
|
639
|
+
def _utc_iso_ms() -> str:
|
|
640
|
+
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.") + (
|
|
641
|
+
f"{datetime.now(timezone.utc).microsecond // 1000:03d}Z"
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
|
|
645
|
+
def _empty_returns() -> StataReturns:
    """Return a `StataReturns` whose scalars, macros and matrices are all empty."""
    empty_fields = {"scalars": {}, "macros": {}, "matrices": {}}
    return StataReturns(**empty_fields)
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def _empty_dataset() -> DatasetInfo:
    """Return a placeholder `DatasetInfo`: frame "default", zero obs/vars, unchanged."""
    placeholder = {
        "frame": "default",
        "n_obs": 0,
        "n_vars": 0,
        "changed": False,
        "filename": None,
        "variables": None,
    }
    return DatasetInfo(**placeholder)
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
def _build_timeout_result(
    *,
    session_id: str,
    elapsed_ms: int,
    timeout_ms: int,
) -> RunResult:
    """Build the synthetic `RunResult` returned when a request times out.

    Args:
        session_id: Session whose worker was terminated.
        elapsed_ms: Wall-clock time spent on the request; also reported as
            `stata_elapsed_ms`.
        timeout_ms: The limit that was exceeded (quoted in the message).

    Returns:
        A failed result with `rc=-2`, `error.kind=ErrorKind.TIMEOUT`, an
        empty incomplete log, empty returns and a placeholder dataset.
    """
    from datetime import timedelta  # the module imports only datetime/timezone

    # This runs after the request has already used up `elapsed_ms`, so
    # back-date the stamp to when the request actually began.
    began = datetime.now(timezone.utc) - timedelta(milliseconds=elapsed_ms)
    started_at = f"{began:%Y-%m-%dT%H:%M:%S}.{began.microsecond // 1000:03d}Z"

    err = ErrorInfo(
        kind=ErrorKind.TIMEOUT,
        rc=-2,
        rc_label="timeout",
        message=(
            f"Execution exceeded the configured timeout of {timeout_ms} ms. "
            f"The worker process for session_id={session_id!r} was terminated."
        ),
        command=None,
        line=None,
        context=ErrorContext(before=[], failing="", after=[]),
        commands_executed=None,
        path=None,
        varname=None,
        name=None,
        suggestions=[],
    )
    return RunResult(
        ok=False,
        rc=-2,
        session_id=session_id,
        request_id=uuid.uuid4().hex,
        started_at=started_at,
        elapsed_ms=elapsed_ms,
        stata_elapsed_ms=elapsed_ms,
        stata=StataInfo(version="unknown", edition="unknown", backend=Backend.PYSTATA),
        log=LogInfo(
            head="",
            tail="",
            lines_total=0,
            bytes_total=0,
            truncated=False,
            complete=False,
            error_window=None,
            ref=None,
        ),
        results=ResultsInfo(r=_empty_returns(), e=_empty_returns(), last_estimation_cmd=None),
        dataset=_empty_dataset(),
        graphs=[],
        warnings=[],
        error=err,
        schema_version="1.0",
        capabilities=["pystata", "subprocess_timeout"],
    )
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
def _build_adapter_crash_result(
    *,
    session_id: str,
    elapsed_ms: int,
    message: str,
) -> RunResult:
    """Build the synthetic `RunResult` returned when the worker fails.

    Args:
        session_id: Session whose worker failed.
        elapsed_ms: Wall-clock time spent on the request; also reported as
            `stata_elapsed_ms`.
        message: Failure description, embedded in the error message.

    Returns:
        A failed result with `rc=-1`, `error.kind=ErrorKind.ADAPTER_CRASH`,
        an empty incomplete log, empty returns and a placeholder dataset.
    """
    from datetime import timedelta  # the module imports only datetime/timezone

    # This runs after the request has already used up `elapsed_ms`, so
    # back-date the stamp to when the request actually began.
    began = datetime.now(timezone.utc) - timedelta(milliseconds=elapsed_ms)
    started_at = f"{began:%Y-%m-%dT%H:%M:%S}.{began.microsecond // 1000:03d}Z"

    err = ErrorInfo(
        kind=ErrorKind.ADAPTER_CRASH,
        rc=-1,
        rc_label="adapter_crash",
        message=f"Subprocess worker crashed: {message}",
        command=None,
        line=None,
        context=ErrorContext(before=[], failing="", after=[]),
        commands_executed=None,
        path=None,
        varname=None,
        name=None,
        suggestions=[],
    )
    return RunResult(
        ok=False,
        rc=-1,
        session_id=session_id,
        request_id=uuid.uuid4().hex,
        started_at=started_at,
        elapsed_ms=elapsed_ms,
        stata_elapsed_ms=elapsed_ms,
        stata=StataInfo(version="unknown", edition="unknown", backend=Backend.PYSTATA),
        log=LogInfo(
            head="",
            tail="",
            lines_total=0,
            bytes_total=0,
            truncated=False,
            complete=False,
            error_window=None,
            ref=None,
        ),
        results=ResultsInfo(r=_empty_returns(), e=_empty_returns(), last_estimation_cmd=None),
        dataset=_empty_dataset(),
        graphs=[],
        warnings=[],
        error=err,
        schema_version="1.0",
        capabilities=["pystata", "subprocess_timeout"],
    )
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
762
|
+
# Module-level convenience: lazy default pool + pool_execute().
|
|
763
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
# Created on first use by get_default_pool(); reset to None by shutdown_default_pool().
_default_pool: SessionPool | None = None
|
|
767
|
+
# Held while `_default_pool` is created or torn down.
_default_pool_lock = threading.Lock()
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def get_default_pool() -> SessionPool:
    """Return the module-wide `SessionPool`, creating it on first use.

    The unlocked check keeps the common path lock-free; the second check
    under `_default_pool_lock` ensures only one pool is ever created.
    """
    global _default_pool
    if _default_pool is None:
        with _default_pool_lock:
            if _default_pool is None:
                _default_pool = SessionPool()
    return _default_pool
|
|
778
|
+
|
|
779
|
+
|
|
780
|
+
def pool_execute(
    code: str,
    *,
    session_id: str = "main",
    timeout_ms: int | None = 600_000,
    **options: Any,
) -> RunResult:
    """Run `code` through the default pool with `timeout_ms` enforced.

    Intended as a drop-in for `runner.execute()`. All arguments are passed
    straight to `SessionPool.execute` on the pool returned by
    `get_default_pool()`; see that method for timeout and failure behavior.
    """
    pool = get_default_pool()
    return pool.execute(
        code,
        session_id=session_id,
        timeout_ms=timeout_ms,
        **options,
    )
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
def shutdown_default_pool() -> None:
    """Shut down the default pool, if one exists, and forget it.

    Handy for clean process exit and for tests. A later
    `get_default_pool()` call creates a fresh pool.
    """
    global _default_pool
    with _default_pool_lock:
        pool = _default_pool
        if pool is None:
            return
        pool.shutdown()
        _default_pool = None
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
# Worker entry-point. `python -m stata_code.core._pool` lands here.
|
|
807
|
+
if __name__ == "__main__":
    # Run the worker loop and exit with its return value as the status.
    raise SystemExit(_worker_main())
|