stata-code 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,808 @@
1
+ """Subprocess pool for hard timeout enforcement.
2
+
3
+ `runner.execute()` runs pystata in-process, which is fast but offers no
4
+ clean way to cancel a long-running Stata command (`bayes`, `bootstrap,
5
+ reps(10000)`, infinite loop, ...). pystata holds the GIL inside C land
6
+ and ignores Python signals until it returns. v0.2 documented this:
7
+ `timeout_ms` was accepted but not enforced.
8
+
9
+ This module fills that gap. `pool_execute()` is a drop-in for
10
+ `runner.execute()` that routes the call through a per-session
11
+ subprocess. The parent enforces `timeout_ms` by reading from the
12
+ worker's stdout with a deadline; on overrun, the worker is SIGTERMed
13
+ (grace period) then SIGKILLed, and the parent returns a synthetic
14
+ `RunResult(rc=-2, error.kind="timeout")`. The dead worker is dropped
15
+ from the pool and respawned on the next call to that `session_id`.
16
+
17
+ Design choices:
18
+
19
+ - **One worker per session_id.** Stata is single-threaded; serialize per
20
+ session at the worker level. Different sessions get different
21
+ workers and run truly in parallel.
22
+ - **Workers are warm.** First call to a new session pays the pystata
23
+ init cost (~3-10s); subsequent calls are pipe-roundtrip + JSON only
24
+ (typically <50ms overhead).
25
+ - **Refs are ferried.** The worker's `_refs` store is local to that
26
+ process. After each `execute()`, the parent harvests any newly-
27
+ created refs and re-puts them in its OWN `_refs` so that
28
+ `get_log/get_graph/get_matrix` calls served by the parent (the MCP
29
+ server, typically) hit the parent's store directly without IPC.
30
+ - **Wire protocol.** Newline-delimited JSON. Request: one line in
31
+ `{id, code, options}`. Response: one line in
32
+ `{id, ok, result, ref_blobs}` (or `{id, ok=false, error}`).
33
+
34
+ Not exposed in the public API. Frontends import `pool_execute` from
35
+ this module if they want the timeout guarantee.
36
+ """
37
+
38
+ from __future__ import annotations
39
+
40
+ import base64
41
+ import json
42
+ import os
43
+ import subprocess
44
+ import sys
45
+ import threading
46
+ import time
47
+ import uuid
48
+ from datetime import datetime, timezone
49
+ from typing import Any
50
+
51
+ from stata_code.core import _refs
52
+ from stata_code.core.schema import (
53
+ Backend,
54
+ DatasetInfo,
55
+ ErrorContext,
56
+ ErrorInfo,
57
+ ErrorKind,
58
+ LogInfo,
59
+ ResultsInfo,
60
+ RunResult,
61
+ StataInfo,
62
+ StataReturns,
63
+ )
64
+
65
# Seconds to wait for the worker to exit after SIGTERM, and again after the
# follow-up SIGKILL.
_KILL_GRACE_S: float = 2.0

# Upper bound (seconds) on each wait for the worker's response line; between
# waits the parent re-checks its deadline and whether the worker has died.
_POLL_INTERVAL_S: float = 0.05

# How many distinct session workers a SessionPool keeps warm by default
# before evicting the least recently used one.
_DEFAULT_POOL_CAPACITY: int = 8
73
+
74
+
75
+ # ─────────────────────────────────────────────────────────────────────────────
76
+ # Worker side: runs in the SUBPROCESS, not the parent.
77
+ # ─────────────────────────────────────────────────────────────────────────────
78
+
79
+
80
_BYTES_MARKER = "__bytes_b64__"


def _to_jsonable(obj: Any) -> Any:
    """Make `obj` JSON-encodable by wrapping every binary blob it contains.

    `bytes` / `bytearray` become a one-key dict ``{__bytes_b64__: <base64>}``.
    Dicts are rebuilt value-by-value. Lists and tuples are both rebuilt as
    lists, because JSON has no tuple type. Anything else is returned
    untouched. This is what lets a graph payload such as
    ``{"format": "png", "bytes": <data>, "width": ..., "height": ...}``
    cross the JSON pipe with its structure intact.
    """
    if isinstance(obj, dict):
        return {key: _to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(item) for item in obj]
    if isinstance(obj, (bytes, bytearray)):
        encoded = base64.b64encode(bytes(obj)).decode("ascii")
        return {_BYTES_MARKER: encoded}
    return obj


def _from_jsonable(obj: Any) -> Any:
    """Undo `_to_jsonable`: turn every marker dict back into `bytes`.

    Only a dict whose *sole* key is the marker counts as a blob. A marker
    whose value is not a string decodes to ``b""``. Dicts and lists are
    walked recursively; all other values are returned unchanged.
    """
    if isinstance(obj, list):
        return [_from_jsonable(item) for item in obj]
    if not isinstance(obj, dict):
        return obj
    is_blob = len(obj) == 1 and _BYTES_MARKER in obj
    if not is_blob:
        return {key: _from_jsonable(value) for key, value in obj.items()}
    encoded = obj[_BYTES_MARKER]
    if not isinstance(encoded, str):
        return b""
    return base64.b64decode(encoded)
113
+
114
+
115
def _payload_to_wire(payload: Any) -> dict[str, Any]:
    """Wrap one `_refs` payload in a JSON-safe ``{"kind", "data"}`` envelope.

    Envelope kinds:
      - ``"text"``  — a `str` payload (full log text), sent as-is.
      - ``"bytes"`` — a `bytes`/`bytearray` payload (graph data), base64-encoded.
      - ``"json"``  — a `dict` payload (log info, matrix info, or graph info).
        It is passed through `_to_jsonable` because the graph-info shape
        nests raw bytes.
      - ``"unknown"`` — anything else, sent as its `repr`. `_payload_from_wire`
        maps this kind to ``None``, and the parent skips ``None`` payloads.
    """
    if isinstance(payload, str):
        return {"kind": "text", "data": payload}
    if isinstance(payload, (bytes, bytearray)):
        encoded = base64.b64encode(bytes(payload)).decode("ascii")
        return {"kind": "bytes", "data": encoded}
    if isinstance(payload, dict):
        return {"kind": "json", "data": _to_jsonable(payload)}
    # Unrecognized shape: ship a diagnostic string rather than fail the request.
    return {"kind": "unknown", "data": repr(payload)}
135
+
136
+
137
def _payload_from_wire(envelope: dict[str, Any]) -> Any:
    """Rebuild a `_refs` payload from a `_payload_to_wire` envelope.

    Returns `bytes` for ``"bytes"`` (``b""`` if the data is not a string),
    `str` for ``"text"`` (``""`` if the data is not a string), and the
    `_from_jsonable` result for ``"json"``. Any other kind, including a
    missing one, yields ``None``, which callers treat as "nothing to store".
    """
    kind = envelope.get("kind")
    data = envelope.get("data")
    if kind == "json":
        return _from_jsonable(data)
    if kind == "text":
        return data if isinstance(data, str) else ""
    if kind == "bytes":
        if not isinstance(data, str):
            return b""
        return base64.b64decode(data)
    return None
147
+
148
+
149
def _worker_encode_response(response: dict[str, Any]) -> str:
    """Serialize one response line, degrading to an error response if needed.

    A handler result that cannot be JSON-encoded (for example an unexpected
    type coming back from ``list_sessions``) must not crash the worker,
    because the worker process holds the session's Stata state. Such a
    result is reported to the parent as ``ok: false`` instead.
    """
    try:
        return json.dumps(response)
    except (TypeError, ValueError) as exc:
        return json.dumps(
            {
                "id": response.get("id"),
                "ok": False,
                "error": f"unserializable response: {type(exc).__name__}: {exc}",
            }
        )


def _worker_run_execute(req_id: str | None, req: dict[str, Any]) -> dict[str, Any]:
    """Handle ``op=execute``: run the code and ferry back any newly created refs."""
    # Imported on first use, so a worker that has only answered pings has not
    # imported the runner yet (and an import failure becomes an error
    # response instead of a silent worker exit).
    from stata_code.core.runner import execute

    code = req["code"]
    options = req.get("options", {})
    # Snapshot ref keys before so we can ferry only the *new* ones.
    # _refs._store is private but we own this codebase.
    with _refs._lock:  # noqa: SLF001
        keys_before = set(_refs._store.keys())  # noqa: SLF001
    result = execute(code, **options)
    with _refs._lock:  # noqa: SLF001
        new_keys = set(_refs._store.keys()) - keys_before  # noqa: SLF001
    ref_blobs: dict[str, dict[str, Any]] = {}
    for key in new_keys:
        payload = _refs.get(key)
        if payload is not None:
            ref_blobs[key] = _payload_to_wire(payload)
    return {
        "id": req_id,
        "ok": True,
        "result": json.loads(result.model_dump_json()),
        "ref_blobs": ref_blobs,
    }


def _worker_handle_line(line: str) -> dict[str, Any]:
    """Decode one request line and dispatch it. Never raises.

    Any exception is turned into ``{"id", "ok": False, "error"}`` so that
    the parent always gets exactly one response line per request.
    """
    req_id: str | None = None
    try:
        req = json.loads(line)
        if not isinstance(req, dict):
            raise ValueError(f"request must be a JSON object, got {type(req).__name__}")
        req_id = req.get("id")
        op = req.get("op", "execute")
        if op == "ping":
            return {"id": req_id, "ok": True, "pong": True}
        if op == "list_sessions":
            # Imported lazily — calling list_sessions() on a worker that
            # hasn't yet had any execute() request still triggers pystata
            # init, which is the price of an honest answer.
            from stata_code.core.runner import list_sessions as _ls

            return {"id": req_id, "ok": True, "sessions": _ls()}
        if op == "execute":
            return _worker_run_execute(req_id, req)
        return {"id": req_id, "ok": False, "error": f"unknown op: {op}"}
    except Exception as exc:  # noqa: BLE001
        return {"id": req_id, "ok": False, "error": f"{type(exc).__name__}: {exc}"}


def _worker_main() -> int:
    """Worker entry point: serve newline-delimited JSON requests until EOF.

    Reads one request per line from a duplicate of stdin and writes exactly
    one response line per request to a duplicate of stdout. Supported ops
    are ``ping``, ``list_sessions`` and ``execute`` (the default when
    ``op`` is absent).

    Returns:
        0 once the parent closes the pipe.
    """
    # CRITICAL: dup FDs 0 and 1 BEFORE pystata gets imported / initialized.
    # `pystata.config.init()` reaches into the runtime's FDs (closes /
    # redirects stdin, redirects stdout to its own buffer). If we read /
    # write through `sys.stdin` / `sys.stdout` the protocol breaks the
    # moment Stata initializes — subsequent reads see EOF and the worker
    # exits returncode=0 without responding to the second request.
    #
    # Duping the file descriptors gives us a reader/writer rooted at a
    # *separate* FD that pystata cannot reach via the sys.* indirection.
    saved_stdin_fd = os.dup(0)
    saved_stdout_fd = os.dup(1)
    proto_in = os.fdopen(saved_stdin_fd, "r", buffering=1, encoding="utf-8")
    proto_out = os.fdopen(saved_stdout_fd, "w", buffering=1, encoding="utf-8")

    # Explicit readline() loop instead of `for line in proto_in` — the latter
    # uses the io module's buffered iterator, which read-aheads more bytes
    # than are available on a pipe and breaks the request/response cadence
    # after pystata init.
    while True:
        line = proto_in.readline()
        if not line:
            break  # EOF — parent closed the pipe
        line = line.strip()
        if not line:
            continue
        response = _worker_handle_line(line)
        proto_out.write(_worker_encode_response(response) + "\n")
        proto_out.flush()
    return 0
231
+
232
+
233
+ # ─────────────────────────────────────────────────────────────────────────────
234
+ # Parent side: WorkerProcess + SessionPool + pool_execute().
235
+ # ─────────────────────────────────────────────────────────────────────────────
236
+
237
+
238
def _default_worker_cmd() -> list[str]:
    """Return the argv used to launch this module as a worker subprocess.

    Runs the parent's own interpreter unbuffered (``-u``) and executes the
    module with ``-m``, so its ``__main__`` guard dispatches to
    `_worker_main`. A new list is built on every call.
    """
    interpreter = sys.executable
    return [interpreter, "-u", "-m", "stata_code.core._pool"]
240
+
241
+
242
class _WorkerError(RuntimeError):
    """A worker request failed for a reason other than a Stata error.

    This covers two cases:
      - transport failures: broken pipe, worker exit, non-JSON output, or a
        mismatched response id;
      - requests the worker itself answered with ``ok: false``.
    """


class _WorkerTimeout(TimeoutError):
    """A worker request got no response before its deadline.

    What happens to the worker afterwards is up to the caller.
    `SessionPool.execute` kills it. `SessionPool.list_session_info` merely
    skips that worker.
    """
248
+
249
+
250
class WorkerProcess:
    """Parent-side handle for one subprocess worker.

    Construct via `WorkerProcess(session_id, ...)`. The subprocess is spawned
    lazily on the first `execute()` and respawned by a later `execute()` if
    it has died. Use `kill()` to terminate it. Every request holds
    `self._lock` for its whole write/read round trip, so concurrent callers
    on one worker are serialized.

    The worker's stderr is consumed by a daemon thread for the whole life of
    the process. An unread stderr PIPE fills up (typically around 64 KiB) and
    then blocks the worker on its next write. The parent could only observe
    that as a spurious timeout. The last `_STDERR_TAIL_LINES` lines are kept
    and attached to crash diagnostics.
    """

    # Trailing stderr lines retained per process for `_WorkerError` messages.
    _STDERR_TAIL_LINES = 20

    def __init__(
        self,
        session_id: str,
        *,
        worker_cmd: list[str] | None = None,
    ) -> None:
        self.session_id = session_id
        self._cmd = list(worker_cmd) if worker_cmd is not None else _default_worker_cmd()
        self._proc: subprocess.Popen[str] | None = None
        self._lock = threading.Lock()
        # Rolling stderr tail of the current process (replaced on every spawn).
        self._stderr_tail: list[str] = []
        self._stderr_thread: threading.Thread | None = None
        self.last_used: float = time.monotonic()

    def _spawn(self) -> subprocess.Popen[str]:
        """Start a worker process and the thread that drains its stderr."""
        env = os.environ.copy()
        # Force unbuffered I/O even if PYTHONUNBUFFERED isn't already set.
        env.setdefault("PYTHONUNBUFFERED", "1")
        proc = subprocess.Popen(
            self._cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            # Both sides emit json.dumps() output (ASCII-only by default), so
            # "replace" can only ever touch stderr text. It keeps a stray
            # undecodable byte from killing the stderr drain thread.
            errors="replace",
            bufsize=1,  # line-buffered
            env=env,
        )
        # A fresh list per process so diagnostics never mix two workers' output.
        tail: list[str] = []
        self._stderr_tail = tail
        self._stderr_thread = threading.Thread(
            target=self._drain_stderr,
            args=(proc, tail, self._STDERR_TAIL_LINES),
            name=f"stata-worker-stderr-{self.session_id}",
            daemon=True,
        )
        self._stderr_thread.start()
        return proc

    @staticmethod
    def _drain_stderr(
        proc: subprocess.Popen[str],
        tail: list[str],
        limit: int,
    ) -> None:
        """Read `proc.stderr` until EOF, keeping only the last `limit` lines."""
        stream = proc.stderr
        if stream is None:
            return
        try:
            for line in iter(stream.readline, ""):
                tail.append(line.rstrip("\n"))
                if len(tail) > limit:
                    del tail[0]
        except (OSError, ValueError):
            # The pipe was closed underneath us; nothing more to read.
            return

    def _ensure_alive(self) -> subprocess.Popen[str]:
        """Return the running process, (re)spawning it if absent or dead.

        Caller holds `self._lock`.
        """
        if self._proc is None or self._proc.poll() is not None:
            # Reap a dead predecessor and release its pipes before replacing it.
            self._kill_locked()
            self._proc = self._spawn()
        return self._proc

    def is_alive(self) -> bool:
        return self._proc is not None and self._proc.poll() is None

    def execute(
        self,
        code: str,
        options: dict[str, Any],
        *,
        timeout_ms: int | None,
    ) -> tuple[dict[str, Any], dict[str, dict[str, Any]]]:
        """Send one execute request and return (result_dict, ref_blobs).

        Raises:
            _WorkerTimeout: no response within ``timeout_ms``. The worker has
                already been killed, while still holding the lock, so no
                queued caller can slip a request into the doomed process.
                Calling `kill()` afterwards is a harmless no-op.
            _WorkerError: either a transport failure (broken pipe, EOF,
                non-JSON, mismatched id), after which the worker has been
                killed, or the worker answered ``ok: false``, in which case
                it is left running.
        """
        with self._lock:
            proc = self._ensure_alive()
            deadline = None if timeout_ms is None else time.monotonic() + timeout_ms / 1000.0
            response = self._roundtrip_locked(
                proc,
                {"op": "execute", "code": code, "options": options},
                deadline=deadline,
                kill_on_timeout=True,
            )
        return response["result"], response.get("ref_blobs", {})

    def _roundtrip_locked(
        self,
        proc: subprocess.Popen[str],
        request: dict[str, Any],
        *,
        deadline: float | None,
        kill_on_timeout: bool,
    ) -> dict[str, Any]:
        """Write one request line to `proc` and return its decoded response.

        Caller holds `self._lock`. A fresh request id is added to `request`.
        Transport failures leave the line protocol in an unknown state, so
        the worker is killed before `_WorkerError` propagates. On timeout the
        worker is killed only if `kill_on_timeout` is set.
        """
        assert proc.stdin is not None and proc.stdout is not None  # for mypy
        req_id = uuid.uuid4().hex
        try:
            proc.stdin.write(json.dumps({"id": req_id, **request}) + "\n")
            proc.stdin.flush()
        except (OSError, ValueError) as exc:
            # BrokenPipeError on POSIX; Windows can report a closed pipe as a
            # plain OSError (EINVAL); ValueError if the file object is closed.
            raise self._crash_locked(f"worker pipe broken on write: {exc}") from exc

        try:
            line = self._readline_with_deadline(proc, deadline)
        except _WorkerTimeout:
            if kill_on_timeout:
                self._kill_locked()
            raise
        except _WorkerError as exc:
            raise self._crash_locked(str(exc)) from exc
        self.last_used = time.monotonic()

        try:
            response = json.loads(line)
        except json.JSONDecodeError as exc:
            raise self._crash_locked(f"worker emitted non-JSON: {line!r}") from exc

        got_id = response.get("id") if isinstance(response, dict) else None
        if got_id != req_id:
            raise self._crash_locked(
                f"worker response id mismatch: expected {req_id}, got {got_id}"
            )
        if not response.get("ok"):
            # The worker handled the request and is still in sync; keep it.
            raise _WorkerError(
                f"worker reported failure: {response.get('error', '<no error>')}"
            )
        return response

    def _crash_locked(self, message: str) -> _WorkerError:
        """Kill the worker and build a `_WorkerError` that carries its stderr tail.

        Caller holds `self._lock`. Killing first lets the drain thread reach
        EOF, so the tail includes whatever the worker printed while dying.
        """
        self._kill_locked()
        tail = "\n".join(self._stderr_tail)
        if tail:
            message = f"{message}\n--- worker stderr (last lines) ---\n{tail}"
        return _WorkerError(message)

    @staticmethod
    def _readline_with_deadline(
        proc: subprocess.Popen[str],
        deadline: float | None,
    ) -> str:
        """Read one line from `proc.stdout` honoring an optional wall-clock
        deadline. Raises `_WorkerTimeout` on overrun and `_WorkerError` if
        the worker exits first.

        Implementation note: a daemon thread does the blocking readline while
        this thread joins it in `_POLL_INTERVAL_S` slices, checking the
        deadline and `proc.poll()` in between. This is portable (no select on
        Windows pipes). On timeout the reader thread is abandoned, still
        blocked on readline, and it will consume the next line the worker
        writes.
        """
        assert proc.stdout is not None
        result: dict[str, str | BaseException] = {}

        def _reader() -> None:
            try:
                line = proc.stdout.readline()  # type: ignore[union-attr]
                result["line"] = line
            except BaseException as exc:  # noqa: BLE001
                result["err"] = exc

        thread = threading.Thread(target=_reader, daemon=True)
        thread.start()

        while True:
            if deadline is None:
                remaining = None
            else:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise _WorkerTimeout("deadline exceeded waiting for worker response")
            thread.join(
                timeout=_POLL_INTERVAL_S if remaining is None else min(_POLL_INTERVAL_S, remaining)
            )
            if not thread.is_alive():
                if "err" in result:
                    raise _WorkerError(f"reader thread error: {result['err']!r}")
                line = result.get("line", "")
                assert isinstance(line, str)
                if not line:
                    # EOF — worker exited or pipe closed unexpectedly.
                    rc = proc.poll()
                    raise _WorkerError(f"worker exited (returncode={rc}) before responding")
                return line
            # Worker still running but no line yet. If the process died, surface that.
            if proc.poll() is not None:
                # Wait briefly for any final bytes the reader thread might catch.
                thread.join(timeout=0.1)
                if "line" in result and isinstance(result["line"], str) and result["line"]:
                    return result["line"]
                rc = proc.returncode
                raise _WorkerError(f"worker exited (returncode={rc}) before responding")

    def send_simple_op(
        self,
        op: str,
        *,
        timeout_ms: int | None,
    ) -> dict[str, Any]:
        """Send a no-payload op (e.g., ``ping``, ``list_sessions``) and return
        the full response dict.

        Unlike :meth:`execute`, this does **not** respawn a dead worker. If
        the subprocess isn't running, it raises :class:`_WorkerError`.

        ``timeout_ms`` bounds the whole call, including the wait for the lock
        while an ``execute`` is in flight on this worker. A busy worker
        therefore raises :class:`_WorkerTimeout` instead of blocking for the
        length of that command.

        On a read timeout the worker is NOT killed, so its session state is
        kept. The reader thread abandoned by `_readline_with_deadline`
        consumes the late response when the worker writes it.

        NOTE(review): that abandoned thread can then share `proc.stdout`
        with the next request's reader; confirm this is acceptable.
        """
        deadline = None if timeout_ms is None else time.monotonic() + timeout_ms / 1000.0
        lock_wait = -1 if deadline is None else max(0.0, deadline - time.monotonic())
        if not self._lock.acquire(timeout=lock_wait):
            raise _WorkerTimeout(
                f"worker for {self.session_id!r} busy; lock not acquired within {timeout_ms} ms"
            )
        try:
            proc = self._proc
            if proc is None or proc.poll() is not None:
                raise _WorkerError(f"worker for {self.session_id!r} not running")
            return self._roundtrip_locked(
                proc, {"op": op}, deadline=deadline, kill_on_timeout=False
            )
        finally:
            self._lock.release()

    def kill(self) -> None:
        """Terminate the worker: SIGTERM, wait `_KILL_GRACE_S`, then SIGKILL.

        Takes the worker lock, so it first waits for any in-flight request.
        """
        with self._lock:
            self._kill_locked()

    def _kill_locked(self) -> None:
        """Terminate and reap the current process, then release its pipes.

        Caller holds `self._lock`. This is safe when no process exists or it
        has already exited. It never blocks indefinitely on a process that
        survived SIGKILL.
        """
        proc = self._proc
        if proc is None:
            return
        self._proc = None
        if proc.poll() is None:
            try:
                proc.terminate()
                try:
                    proc.wait(timeout=_KILL_GRACE_S)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    try:
                        proc.wait(timeout=_KILL_GRACE_S)
                    except subprocess.TimeoutExpired:
                        pass
            except ProcessLookupError:
                pass
        if proc.poll() is None:
            # Survived SIGKILL + grace (e.g. stuck in an uninterruptible
            # syscall). Closing its pipes could block behind reader threads
            # still waiting for data, so leave them to the garbage collector.
            return

        drain = self._stderr_thread
        self._stderr_thread = None
        stderr_done = True
        if drain is not None:
            # With the process gone the drain thread normally hits EOF at
            # once. Joining it makes the stderr tail complete for crash
            # messages. It may stay alive if a child process inherited the pipe.
            drain.join(timeout=_KILL_GRACE_S)
            stderr_done = not drain.is_alive()
        # stdout is deliberately left to the garbage collector. After a
        # timeout, the reader thread abandoned by `_readline_with_deadline`
        # can still be blocked on it, and close() would wait for that read.
        for stream in (proc.stdin, proc.stderr if stderr_done else None):
            if stream is None:
                continue
            try:
                stream.close()
            except (OSError, ValueError):
                pass
468
+
469
+
470
class SessionPool:
    """LRU pool of subprocess workers, keyed by `session_id`.

    `self._lock` only guards the `_workers` map. A worker's `kill()` waits
    for that worker's own lock (i.e. for any in-flight request) and then for
    the process to exit. It is therefore always called after the pool lock
    has been released, so one slow or busy session cannot stall dispatch
    for the others.
    """

    def __init__(
        self,
        *,
        capacity: int = _DEFAULT_POOL_CAPACITY,
        worker_cmd: list[str] | None = None,
    ) -> None:
        if capacity < 1:
            raise ValueError("capacity must be ≥ 1")
        self._capacity = capacity
        self._worker_cmd = worker_cmd
        self._workers: dict[str, WorkerProcess] = {}
        self._lock = threading.Lock()

    @property
    def capacity(self) -> int:
        return self._capacity

    def _get_or_spawn(self, session_id: str) -> WorkerProcess:
        """Return the session's worker, registering a new (unspawned) one if needed.

        A worker whose process has died is reused as-is, because
        `WorkerProcess.execute` respawns it on demand. Replacing it with a new
        object instead would let a concurrent caller still holding the old
        object spawn a process the pool no longer tracks.
        """
        victims: list[WorkerProcess] = []
        with self._lock:
            worker = self._workers.get(session_id)
            if worker is None:
                worker = WorkerProcess(session_id, worker_cmd=self._worker_cmd)
                self._workers[session_id] = worker
                victims = self._evict_to_capacity_locked(keep=session_id)
            # Stamp checkout time. Otherwise a worker busy with a long command
            # looks least-recently-used (its last *response* is old) and
            # becomes the first eviction victim.
            worker.last_used = time.monotonic()
        for victim in victims:
            victim.kill()
        return worker

    def _evict_to_capacity_locked(self, *, keep: str) -> list[WorkerProcess]:
        """Remove least-recently-used workers until the pool fits its capacity.

        Caller holds `self._lock`. The removed workers are returned, not
        killed, so the caller can kill them after releasing the lock. The
        `keep` session (the one just added) is never evicted.
        """
        overflow = len(self._workers) - self._capacity
        if overflow <= 0:
            return []
        by_age = sorted(
            ((sid, w) for sid, w in self._workers.items() if sid != keep),
            key=lambda kv: kv[1].last_used,
        )
        victims: list[WorkerProcess] = []
        for sid, w in by_age[:overflow]:
            del self._workers[sid]
            victims.append(w)
        return victims

    def _discard(self, session_id: str, worker: WorkerProcess) -> None:
        """Forget `worker` (only if it is still the one mapped) and kill it."""
        with self._lock:
            if self._workers.get(session_id) is worker:
                del self._workers[session_id]
        worker.kill()

    def _reap_if_orphaned(self, session_id: str, worker: WorkerProcess) -> None:
        """Kill `worker` if it was evicted or removed while a call was using it.

        Such a worker may have (re)spawned a process after it left the map.
        `shutdown()` would never see that process, so it is stopped here.
        """
        with self._lock:
            orphaned = self._workers.get(session_id) is not worker
        if orphaned:
            worker.kill()

    def execute(
        self,
        code: str,
        *,
        session_id: str = "main",
        timeout_ms: int | None = 600_000,
        **options: Any,
    ) -> RunResult:
        """Execute `code` in the session's worker, enforcing `timeout_ms`.

        On timeout: SIGTERM/SIGKILL the worker and return a synthetic
        `RunResult(rc=-2, error.kind="timeout")`. On a worker error: kill it
        and return `RunResult(rc=-1, error.kind="adapter_crash")`. Either way
        the session is dropped from the pool, and the next call to the same
        `session_id` starts a fresh worker.
        """
        # Normalize: pass session_id through to the worker so it routes
        # to the right Stata frame. timeout_ms is enforced HERE — the
        # worker doesn't see it.
        worker_options = {**options, "session_id": session_id}
        # Forward timeout_ms verbatim so the worker stores it on the result
        # for observability, even though the real enforcement is parent-side.
        worker_options.setdefault("timeout_ms", timeout_ms)
        worker = self._get_or_spawn(session_id)
        started = time.monotonic()
        try:
            result_dict, ref_blobs = worker.execute(
                code, worker_options, timeout_ms=timeout_ms
            )
        except _WorkerTimeout:
            self._discard(session_id, worker)
            elapsed_ms = int((time.monotonic() - started) * 1000)
            return _build_timeout_result(
                session_id=session_id,
                elapsed_ms=elapsed_ms,
                timeout_ms=timeout_ms or 0,
            )
        except _WorkerError as exc:
            self._discard(session_id, worker)
            elapsed_ms = int((time.monotonic() - started) * 1000)
            return _build_adapter_crash_result(
                session_id=session_id,
                elapsed_ms=elapsed_ms,
                message=str(exc),
            )

        self._reap_if_orphaned(session_id, worker)

        # Ferry refs into the parent's _refs store.
        for ref_id, envelope in ref_blobs.items():
            payload = _payload_from_wire(envelope)
            if payload is not None:
                _refs.put(ref_id, payload)

        return RunResult.model_validate(result_dict)

    def kill_session(self, session_id: str) -> bool:
        """Terminate a session's worker. Returns True if a worker existed."""
        with self._lock:
            w = self._workers.pop(session_id, None)
        if w is None:
            return False
        w.kill()
        return True

    def shutdown(self) -> None:
        """Kill all workers."""
        with self._lock:
            workers = list(self._workers.values())
            self._workers.clear()
        for w in workers:
            w.kill()

    def session_ids(self) -> list[str]:
        with self._lock:
            return list(self._workers)

    def list_session_info(
        self,
        *,
        per_worker_timeout_ms: int = 5000,
    ) -> list[dict[str, Any]]:
        """Aggregate live-session info across all workers.

        For each worker that's alive, sends ``op=list_sessions`` and pulls
        back ``[{session_id, frame, n_obs}, ...]`` from that worker's
        pystata. The pool dedupes by ``session_id`` (first writer wins) and
        returns the union.

        Workers that are dead, that fail to respond within
        ``per_worker_timeout_ms``, or that raise a protocol error are
        silently skipped; partial information is better than failing the
        whole list call. The same applies to malformed replies (a
        non-list ``sessions`` or non-dict entries). Dead workers are not
        respawned here: :meth:`WorkerProcess.send_simple_op` deliberately
        does not do that.
        """
        with self._lock:
            workers = list(self._workers.values())
        sessions: list[dict[str, Any]] = []
        seen: set[str] = set()
        for worker in workers:
            if not worker.is_alive():
                continue
            try:
                response = worker.send_simple_op(
                    "list_sessions", timeout_ms=per_worker_timeout_ms
                )
            except (_WorkerError, _WorkerTimeout):
                continue
            entries = response.get("sessions")
            if not isinstance(entries, list):
                continue
            for entry in entries:
                if not isinstance(entry, dict):
                    continue
                sid = entry.get("session_id")
                if sid is None or sid in seen:
                    continue
                seen.add(sid)
                sessions.append(entry)
        return sessions
632
+
633
+
634
+ # ─────────────────────────────────────────────────────────────────────────────
635
+ # Synthetic-result builders for timeout / adapter crash.
636
+ # ─────────────────────────────────────────────────────────────────────────────
637
+
638
+
639
def _utc_iso_ms() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SS.mmmZ``.

    The clock is read exactly once. Reading it separately for the seconds
    part and the millisecond part can straddle a second (or day) boundary
    and produce a timestamp that is almost a full second off.
    """
    now = datetime.now(timezone.utc)
    return now.strftime("%Y-%m-%dT%H:%M:%S.") + f"{now.microsecond // 1000:03d}Z"
643
+
644
+
645
def _empty_returns() -> StataReturns:
    """Return a `StataReturns` whose scalar, macro and matrix maps are all empty.

    Fresh dicts are created on every call, so results never share state.
    """
    scalars: dict[str, Any] = {}
    macros: dict[str, Any] = {}
    matrices: dict[str, Any] = {}
    return StataReturns(scalars=scalars, macros=macros, matrices=matrices)
647
+
648
+
649
def _empty_dataset() -> DatasetInfo:
    """Describe "no data": the ``default`` frame with zero observations and
    variables, not marked as changed, and not backed by a file."""
    return DatasetInfo(
        frame="default",
        filename=None,
        changed=False,
        n_obs=0,
        n_vars=0,
        variables=None,
    )
658
+
659
+
660
def _started_at_iso(elapsed_ms: int) -> str:
    """Return, as ``YYYY-MM-DDTHH:MM:SS.mmmZ`` (UTC), the instant ``elapsed_ms``
    milliseconds ago, i.e. when a request that has just ended began."""
    start = datetime.fromtimestamp(time.time() - elapsed_ms / 1000.0, tz=timezone.utc)
    return start.strftime("%Y-%m-%dT%H:%M:%S.") + f"{start.microsecond // 1000:03d}Z"


def _build_failure_result(
    *,
    session_id: str,
    elapsed_ms: int,
    kind: ErrorKind,
    rc: int,
    rc_label: str,
    message: str,
) -> RunResult:
    """Build a failed `RunResult` for a request that produced no Stata output.

    The log, returns, dataset and graph sections are empty placeholders, and
    the Stata version/edition are reported as ``"unknown"``. ``started_at``
    is back-dated by ``elapsed_ms``, so it marks when the request started
    rather than when this result was built.
    """
    err = ErrorInfo(
        kind=kind,
        rc=rc,
        rc_label=rc_label,
        message=message,
        command=None,
        line=None,
        context=ErrorContext(before=[], failing="", after=[]),
        commands_executed=None,
        path=None,
        varname=None,
        name=None,
        suggestions=[],
    )
    return RunResult(
        ok=False,
        rc=rc,
        session_id=session_id,
        request_id=uuid.uuid4().hex,
        started_at=_started_at_iso(elapsed_ms),
        elapsed_ms=elapsed_ms,
        stata_elapsed_ms=elapsed_ms,
        stata=StataInfo(version="unknown", edition="unknown", backend=Backend.PYSTATA),
        log=LogInfo(
            head="",
            tail="",
            lines_total=0,
            bytes_total=0,
            truncated=False,
            complete=False,
            error_window=None,
            ref=None,
        ),
        results=ResultsInfo(r=_empty_returns(), e=_empty_returns(), last_estimation_cmd=None),
        dataset=_empty_dataset(),
        graphs=[],
        warnings=[],
        error=err,
        schema_version="1.0",
        capabilities=["pystata", "subprocess_timeout"],
    )


def _build_timeout_result(
    *,
    session_id: str,
    elapsed_ms: int,
    timeout_ms: int,
) -> RunResult:
    """Synthetic result (rc=-2, kind=timeout) for a request whose worker was
    terminated after exceeding ``timeout_ms``."""
    return _build_failure_result(
        session_id=session_id,
        elapsed_ms=elapsed_ms,
        kind=ErrorKind.TIMEOUT,
        rc=-2,
        rc_label="timeout",
        message=(
            f"Execution exceeded the configured timeout of {timeout_ms} ms. "
            f"The worker process for session_id={session_id!r} was terminated."
        ),
    )


def _build_adapter_crash_result(
    *,
    session_id: str,
    elapsed_ms: int,
    message: str,
) -> RunResult:
    """Synthetic result (rc=-1, kind=adapter_crash) for a request whose worker
    failed for a non-Stata reason; ``message`` describes the failure."""
    return _build_failure_result(
        session_id=session_id,
        elapsed_ms=elapsed_ms,
        kind=ErrorKind.ADAPTER_CRASH,
        rc=-1,
        rc_label="adapter_crash",
        message=f"Subprocess worker crashed: {message}",
    )
759
+
760
+
761
+ # ─────────────────────────────────────────────────────────────────────────────
762
+ # Module-level convenience: lazy default pool + pool_execute().
763
+ # ─────────────────────────────────────────────────────────────────────────────
764
+
765
+
766
# Process-wide pool, created on demand by get_default_pool() and reset to
# None by shutdown_default_pool().
_default_pool: SessionPool | None = None
# Serializes creation and teardown of `_default_pool`.
_default_pool_lock: threading.Lock = threading.Lock()
768
+
769
+
770
def get_default_pool() -> SessionPool:
    """Return the process-wide `SessionPool`, creating it on first use.

    The unlocked read is a fast path. Creation is re-checked under
    `_default_pool_lock` so that concurrent first callers share one pool.
    The pool that is returned is the one observed under the lock, so a
    concurrent `shutdown_default_pool()` can never make this return None.
    """
    global _default_pool
    pool = _default_pool
    if pool is not None:
        return pool
    with _default_pool_lock:
        if _default_pool is None:
            _default_pool = SessionPool()
        pool = _default_pool
    return pool
778
+
779
+
780
def pool_execute(
    code: str,
    *,
    session_id: str = "main",
    timeout_ms: int | None = 600_000,
    **options: Any,
) -> RunResult:
    """Run `code` like `runner.execute()`, but with `timeout_ms` enforced.

    Delegates to `SessionPool.execute` on the module's default pool; see
    that method for timeout and crash semantics.
    """
    pool = get_default_pool()
    return pool.execute(
        code,
        session_id=session_id,
        timeout_ms=timeout_ms,
        **options,
    )
795
+
796
+
797
def shutdown_default_pool() -> None:
    """Kill all default-pool workers and forget the pool. Useful for clean
    shutdown / tests.

    The pool is detached under `_default_pool_lock`, but its workers are
    killed after the lock is released. Each kill can wait for an in-flight
    request plus SIGTERM/SIGKILL grace periods, and holding the lock that
    long would stall every concurrent `get_default_pool()` caller.
    """
    global _default_pool
    with _default_pool_lock:
        pool, _default_pool = _default_pool, None
    if pool is not None:
        pool.shutdown()
804
+
805
+
806
# Worker entry point: `python -m stata_code.core._pool` runs the request loop
# and exits with its return code.
if __name__ == "__main__":
    raise SystemExit(_worker_main())