stata-code 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stata_code/__init__.py +100 -0
- stata_code/core/__init__.py +73 -0
- stata_code/core/_pool.py +808 -0
- stata_code/core/_refs.py +97 -0
- stata_code/core/_runtime.py +179 -0
- stata_code/core/errors.py +447 -0
- stata_code/core/runner.py +1092 -0
- stata_code/core/schema.py +317 -0
- stata_code/kernel/__init__.py +5 -0
- stata_code/kernel/__main__.py +6 -0
- stata_code/kernel/kernel.py +331 -0
- stata_code/mcp/__init__.py +3 -0
- stata_code/mcp/__main__.py +6 -0
- stata_code/mcp/server.py +360 -0
- stata_code-0.3.0.dist-info/METADATA +389 -0
- stata_code-0.3.0.dist-info/RECORD +20 -0
- stata_code-0.3.0.dist-info/WHEEL +4 -0
- stata_code-0.3.0.dist-info/entry_points.txt +3 -0
- stata_code-0.3.0.dist-info/licenses/LICENSE +21 -0
- stata_code-0.3.0.dist-info/licenses/LICENSE-POLICY.md +125 -0
|
@@ -0,0 +1,1092 @@
|
|
|
1
|
+
"""High-level execute() — runs Stata code and returns a v1.0 RunResult.
|
|
2
|
+
|
|
3
|
+
This is the only place that touches Stata. The MCP server and Jupyter
|
|
4
|
+
kernel both import from here and only translate transports.
|
|
5
|
+
|
|
6
|
+
Implements the v1.0 envelope from SCHEMA.md: ok / rc / error / log /
|
|
7
|
+
results / dataset / graphs / warnings / capabilities. r() and e() are
|
|
8
|
+
collected via sfi (native types). Multi-session is implemented through
|
|
9
|
+
Stata frames (session_id="main" ↔ default frame). Per-line error
|
|
10
|
+
attribution comes from parsing pystata's transcript.
|
|
11
|
+
|
|
12
|
+
For deferred items (hard timeout, cooperative cancellation, get_matrix
|
|
13
|
+
ref mode, console fallback for Stata 11–16, streaming logs), see
|
|
14
|
+
SCHEMA.md §8.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import io
|
|
20
|
+
import re
|
|
21
|
+
import tempfile
|
|
22
|
+
import threading
|
|
23
|
+
import time
|
|
24
|
+
import uuid
|
|
25
|
+
from contextlib import redirect_stdout
|
|
26
|
+
from datetime import datetime, timezone
|
|
27
|
+
from pathlib import Path
|
|
28
|
+
from typing import Any
|
|
29
|
+
|
|
30
|
+
from stata_code.core import _refs
|
|
31
|
+
from stata_code.core._runtime import PystataNotAvailable, get_runtime
|
|
32
|
+
from stata_code.core.errors import classify_rc, suggestions_for
|
|
33
|
+
from stata_code.core.schema import (
|
|
34
|
+
Backend,
|
|
35
|
+
DatasetInfo,
|
|
36
|
+
ErrorContext,
|
|
37
|
+
ErrorInfo,
|
|
38
|
+
ErrorKind,
|
|
39
|
+
GraphFormat,
|
|
40
|
+
GraphInfo,
|
|
41
|
+
LogInfo,
|
|
42
|
+
Matrix,
|
|
43
|
+
ResultsInfo,
|
|
44
|
+
RunResult,
|
|
45
|
+
StataEdition,
|
|
46
|
+
StataInfo,
|
|
47
|
+
StataReturns,
|
|
48
|
+
VariableInfo,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
52
|
+
# Helpers
|
|
53
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Lower-cased edition string -> schema enum member. `_stata_info` looks the
# runtime's edition up here and falls back to StataEdition.UNKNOWN for any
# string that is not a key of this table.
_EDITION_MAP: dict[str, StataEdition] = {
    "mp": StataEdition.MP,
    "se": StataEdition.SE,
    "ic": StataEdition.IC,
    "be": StataEdition.BE,
}

# Captures the member name from a listing row shaped like `r(name) = ...` or
# `e(name) : ...` (leading whitespace allowed).
_ERETURN_NAME_RE = re.compile(r"^\s*(?:e|r)\(([A-Za-z_][A-Za-z0-9_]*)\)\s*[=:]")
# Captures X from "variable X not found" / "variable X already defined".
_VARNAME_RE = re.compile(r"variable (\w+) (?:not found|already defined)")
# Captures the (whitespace-free) path token from "file <path> not found",
# "file <path> already exists" and "file <path> could not ...".
_FILE_PATH_RE = re.compile(
    r"file\s+(\S+?)\s+(?:not\s+found|already\s+exists|could\s+not)"
)
# Captures X from "X already defined" / "X already exists".
_NAME_CONFLICT_RE = re.compile(r"(\w+)\s+already\s+(?:defined|exists)")
# Captures X from "X unrecognized command" / "X is unrecognized command".
_UNRECOGNIZED_CMD_RE = re.compile(r"(\S+)\s+(?:is\s+)?unrecognized\s+command")

# Cooperative cancellation: a per-session "cancel-pending" flag, settable
# from any thread via `cancel(session_id)`. The flag is consumed by the
# next `execute()` call for that session, which short-circuits and returns
# a RunResult with `error.kind="cancelled"` instead of forwarding the code
# to Stata. Cooperative semantics — does NOT interrupt code that is
# already mid-`stata.run()`. Hard interruption requires the subprocess-
# based runtime planned for v0.3+ (see SCHEMA.md §8).
# `_cancel_lock` guards every read and write of `_cancel_pending`.
_cancel_lock = threading.Lock()
# Session ids that currently have a cancel request waiting to be consumed.
_cancel_pending: set[str] = set()

# Cap on `dataset.variables` to avoid pathological return sizes (per SCHEMA §3.5).
_DATASET_VAR_CAP = 200

# Cap on inlined matrix cells (rows × cols). Above this, `values` is omitted
# from the envelope and a `matrix://...` ref is stored instead, retrievable
# via `get_matrix(ref)`. Per SCHEMA.md §3.4: "Producers SHOULD do this when
# a matrix would inline more than ~10,000 cells."
MATRIX_INLINE_CELL_CAP = 10_000
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _utc_iso_ms() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH:MM:SS.mmmZ``.

    The fractional part is truncated (not rounded) to milliseconds.
    """
    stamp = datetime.now(timezone.utc).isoformat(timespec="milliseconds")
    # isoformat() spells the UTC offset as "+00:00"; the envelope wants "Z".
    return stamp.replace("+00:00", "Z")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _new_request_id() -> str:
    """Return a fresh random request id: 32 lowercase hexadecimal characters."""
    # A random UUID is unique enough for this purpose; a ULID would sort by
    # creation time but would add a dependency.
    return str(uuid.uuid4()).replace("-", "")
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _split_log(
    log: str,
    head_lines: int,
    tail_lines: int,
    include_full: bool,
    request_id: str,
) -> LogInfo:
    """Build a LogInfo per SCHEMA §3.3.

    Line endings are normalized to ``\\n`` and a single trailing empty line is
    dropped before counting.

    Args:
        log: Raw captured output.
        head_lines: Number of leading lines to keep when truncating. Negative
            values are treated as 0.
        tail_lines: Number of trailing lines to keep when truncating. Negative
            values are treated as 0.
        include_full: When true, never truncate.
        request_id: Used to mint the ``log://<request_id>`` ref.

    Returns:
        An untruncated LogInfo (whole text in ``head``, empty ``tail``, no ref)
        when ``include_full`` is set or the log fits in ``head_lines +
        tail_lines`` lines. Otherwise a truncated LogInfo; in that case the
        full text is stored under ``log://<request_id>`` so that
        ``get_log(ref)`` can retrieve it later within the producer's lifetime.
    """
    norm = log.replace("\r\n", "\n").replace("\r", "\n")
    lines = norm.split("\n")
    if lines and lines[-1] == "":
        lines = lines[:-1]
    lines_total = len(lines)
    # `bytes_total` reflects the byte count of what `get_log(ref)` would
    # return — i.e., the normalized text without trailing newline.
    full_text = "\n".join(lines)
    bytes_total = len(full_text.encode("utf-8"))

    # Clamp so that a negative count cannot turn into a "drop from the other
    # end" slice below.
    head_n = max(head_lines, 0)
    tail_n = max(tail_lines, 0)

    if include_full or lines_total <= head_n + tail_n:
        return LogInfo(
            head=full_text,
            tail="",
            lines_total=lines_total,
            bytes_total=bytes_total,
            truncated=False,
            complete=True,
            ref=None,
        )

    head = "\n".join(lines[:head_n])
    # Slice from an explicit start index: `lines[-0:]` would be the WHOLE
    # list, so `tail_lines == 0` must not be expressed as a negative index.
    # Here lines_total > head_n + tail_n, so the start index is positive.
    tail = "\n".join(lines[lines_total - tail_n:])
    ref = f"log://{request_id}"
    _refs.put(
        ref,
        {"text": full_text, "lines_total": lines_total, "bytes_total": bytes_total},
    )
    return LogInfo(
        head=head,
        tail=tail,
        lines_total=lines_total,
        bytes_total=bytes_total,
        truncated=True,
        complete=True,
        ref=ref,
    )
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def get_log(ref: str) -> dict[str, Any]:
    """Auxiliary tool: return the full log stored behind a ``log.ref``.

    Per SCHEMA.md §5. The result carries exactly the keys ``text``,
    ``lines_total`` and ``bytes_total``.

    Raises:
        KeyError: if nothing is stored under ``ref``.
    """
    stored = _refs.get(ref)
    if stored is None:
        raise KeyError(f"unknown log ref: {ref!r}")
    # Copy only the documented fields rather than handing out the stored
    # payload itself.
    return {field: stored[field] for field in ("text", "lines_total", "bytes_total")}
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def cancel(session_id: str = "main") -> bool:
    """Request cancellation of the next ``execute()`` call for ``session_id``.

    Cooperative: does **not** interrupt code that is currently mid-execution
    inside pystata. The flag is consumed (and the run short-circuited)
    when ``execute(session_id=...)`` is next invoked for the same session.
    The short-circuit returns a ``RunResult`` with ``ok=False``, ``rc=-3``
    (the synthetic code used for cooperative cancellation), and
    ``error.kind=cancelled``.

    Returns ``True`` if a new cancel was registered, ``False`` if one was
    already pending (idempotent).
    """
    with _cancel_lock:
        newly_registered = session_id not in _cancel_pending
        # Adding an id that is already present is a no-op, so the set ends up
        # holding the id either way.
        _cancel_pending.add(session_id)
    return newly_registered
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def is_cancel_pending(session_id: str = "main") -> bool:
    """Report whether the next ``execute()`` for this session would be cancelled.

    Read-only: the pending flag is left in place.
    """
    with _cancel_lock:
        pending = session_id in _cancel_pending
    return pending
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def clear_cancel(session_id: str = "main") -> bool:
    """Withdraw a pending cancel for ``session_id`` without firing it.

    Returns ``True`` if there was a pending cancel to remove, ``False``
    otherwise.
    """
    with _cancel_lock:
        was_pending = session_id in _cancel_pending
        # discard() is a no-op when the id is absent.
        _cancel_pending.discard(session_id)
    return was_pending
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _consume_cancel(session_id: str) -> bool:
    """Atomically test-and-clear the pending cancel for ``session_id``.

    Returns ``True`` when a cancel was pending (and has now been removed),
    ``False`` when there was nothing to consume.
    """
    with _cancel_lock:
        try:
            _cancel_pending.remove(session_id)
        except KeyError:
            return False
        return True
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _build_cancelled_result(
    *,
    rt: Any,
    session_id: str,
    request_id: str,
    started_at: str,
    started: float,
    include_dataset_variables: bool,
) -> RunResult:
    """Build the RunResult returned when a run is cancelled before Stata sees it.

    No user code ran, so log / results / graphs / warnings are empty and
    ``stata_elapsed_ms`` is 0. The dataset block is still collected and
    therefore reflects the current state of the session. Both the envelope
    and the error carry rc=-3, the synthetic code reserved for cooperative
    cancellation (distinct from -1 adapter_crash and -2 timeout, per
    SCHEMA.md §3.7).

    Args:
        rt: Runtime handle, passed on to ``_stata_info`` / ``_collect_dataset``.
        session_id: Session the cancel applied to; echoed in the error message.
        request_id: Id minted for this (short-circuited) request.
        started_at: Wall-clock start timestamp to report.
        started: ``time.monotonic()`` reading taken at request start.
        include_dataset_variables: Forwarded to ``_collect_dataset``.
    """
    elapsed_total_ms = max(1, int((time.monotonic() - started) * 1000))
    cancelled_rc = -3

    # Collected in the same order the envelope lists them.
    stata_info = _stata_info(rt)
    empty_log = LogInfo(
        head="",
        tail="",
        lines_total=0,
        bytes_total=0,
        truncated=False,
        complete=True,
        ref=None,
    )
    empty_results = ResultsInfo()
    dataset = _collect_dataset(rt, include_dataset_variables)
    cancel_error = ErrorInfo(
        kind=ErrorKind.CANCELLED,
        rc=cancelled_rc,
        rc_label="cancelled",
        message=(
            "Execution cancelled before Stata received the code "
            f"(session_id={session_id!r})."
        ),
        command=None,
        line=None,
        context=ErrorContext(before=[], failing="", after=[]),
        commands_executed=0,
        path=None,
        varname=None,
        name=None,
        suggestions=[],
    )

    return RunResult(
        ok=False,
        rc=cancelled_rc,
        session_id=session_id,
        request_id=request_id,
        started_at=started_at,
        elapsed_ms=elapsed_total_ms,
        stata_elapsed_ms=0,
        stata=stata_info,
        log=empty_log,
        results=empty_results,
        dataset=dataset,
        graphs=[],
        warnings=[],
        error=cancel_error,
        capabilities=["cancel", "multi_session"],
    )
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def _parse_return_list(text: str) -> dict[str, list[str]]:
    """Group the member names in `return list` / `ereturn list` output.

    The result always has the keys 'scalars', 'macros' and 'matrices', each
    mapping to the names found under that section header, in order of
    appearance. Rows under any other header (e.g. 'functions') and rows that
    appear before the first recognized header are ignored.
    """
    found: dict[str, list[str]] = {"scalars": [], "macros": [], "matrices": []}
    section: str | None = None

    for raw in text.splitlines():
        row = raw.rstrip()
        content = row.strip()
        if not content:
            continue

        # A header sits at the left margin and ends with a colon:
        # "scalars:", "macros:", ...
        is_header = content.endswith(":") and not row.startswith(" ")
        if is_header:
            heading = content[:-1].strip().lower()
            # An unrecognized heading switches collection off until the next
            # recognized one.
            section = heading if heading in found else None
            continue

        if section is None:
            continue

        hit = _ERETURN_NAME_RE.match(row)
        if hit:
            found[section].append(hit.group(1))

    return found
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def _list_returns(rt: Any, prefix: str) -> dict[str, list[str]]:
    """Return the names of the current r() or e() members, grouped by category.

    Runs `ereturn list` when ``prefix`` is "e" and `return list` otherwise,
    with stdout redirected into a private buffer so the listing stays out of
    the user log, then parses the captured text. If running the command
    raises, every category comes back empty.
    """
    listing_cmd = "return list" if prefix != "e" else "ereturn list"
    captured = io.StringIO()
    try:
        with redirect_stdout(captured):
            rt.stata.run(listing_cmd, quietly=False, echo=False)
    except Exception:  # noqa: BLE001
        # Best effort: a failed listing is reported as "no members".
        return {category: [] for category in ("scalars", "macros", "matrices")}
    return _parse_return_list(captured.getvalue())
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def _collect_returns(rt: Any, prefix: str, request_id: str) -> StataReturns:
    """Read the current r() or e() results through sfi into a StataReturns.

    Member names come from ``_list_returns``; values are then fetched one by
    one. Failures are absorbed per member: a scalar that cannot be read
    becomes ``None``, a macro becomes ``""``, and a matrix is left out.

    Matrices with more than ``MATRIX_INLINE_CELL_CAP`` cells are emitted with
    ``values=None`` and a ``matrix://<request_id>/<prefix>/<name>`` ref whose
    payload is stored via ``_refs``; callers fetch the values via
    :func:`get_matrix`.
    """
    member_names = _list_returns(rt, prefix)
    sfi = rt.sfi

    scalars: dict[str, float | None] = {}
    for member in member_names["scalars"]:
        try:
            raw = sfi.Scalar.getValue(f"{prefix}({member})")
            scalars[member] = None if raw is None else float(raw)
        except Exception:  # noqa: BLE001
            scalars[member] = None

    macros: dict[str, str] = {}
    for member in member_names["macros"]:
        try:
            text = sfi.Macro.getGlobal(f"{prefix}({member})")
            macros[member] = "" if text is None else text
        except Exception:  # noqa: BLE001
            macros[member] = ""

    matrices: dict[str, Matrix] = {}
    for member in member_names["matrices"]:
        qualified = f"{prefix}({member})"
        try:
            raw_values = sfi.Matrix.get(qualified)
            row_names = list(sfi.Matrix.getRowNames(qualified) or [])
            col_names = list(sfi.Matrix.getColNames(qualified) or [])
            cells: list[list[float | None]] = [
                [None if cell is None else float(cell) for cell in raw_row]
                for raw_row in raw_values
            ]
            n_rows = len(cells)
            n_cols = len(cells[0]) if n_rows else 0

            if n_rows * n_cols <= MATRIX_INLINE_CELL_CAP:
                # Small enough to travel inside the envelope.
                matrices[member] = Matrix(
                    rows=row_names, cols=col_names, values=cells, ref=None
                )
                continue

            # Too large to inline: park the values and hand out a ref.
            ref = f"matrix://{request_id}/{prefix}/{member}"
            _refs.put(
                ref,
                {"rows": row_names, "cols": col_names, "values": cells},
            )
            matrices[member] = Matrix(
                rows=row_names, cols=col_names, values=None, ref=ref
            )
        except Exception:  # noqa: BLE001
            continue

    return StataReturns(scalars=scalars, macros=macros, matrices=matrices)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _collect_dataset(rt: Any, include_variables: bool) -> DatasetInfo:
    """Snapshot the dataset in the current frame as a DatasetInfo.

    Reports the frame name (``"default"`` when it cannot be read), the
    observation and variable counts, the changed flag, the filename (``None``
    when empty or unreadable) and, when ``include_variables`` is true and the
    dataset has at least one variable, name/type/label for the first
    ``_DATASET_VAR_CAP`` variables. Otherwise ``variables`` is ``None``.
    """
    sfi = rt.sfi
    data_api = sfi.Data
    toolkit = sfi.SFIToolkit
    scalar_api = sfi.Scalar

    var_count = int(data_api.getVarCount())
    obs_count = int(data_api.getObsTotal())

    def _creturn_text(name: str) -> str | None:
        # Expand `c(<name>)' as a macro; empty text and failures both map to
        # None.
        try:
            expanded = toolkit.macroExpand(f"`c({name})'")
            return expanded or None
        except Exception:  # noqa: BLE001
            return None

    # c(changed) is read as a scalar; an unreadable value counts as unchanged.
    try:
        changed = bool(float(scalar_api.getValue("c(changed)") or 0.0))
    except Exception:  # noqa: BLE001
        changed = False

    filename = _creturn_text("filename")
    frame_name = _creturn_text("frame") or "default"

    variables: list[VariableInfo] | None = None
    if include_variables and var_count > 0:
        variables = []
        for idx in range(min(var_count, _DATASET_VAR_CAP)):
            variables.append(
                VariableInfo(
                    name=data_api.getVarName(idx),
                    type=data_api.getVarType(idx),
                    label=data_api.getVarLabel(idx) or "",
                )
            )

    return DatasetInfo(
        frame=frame_name,
        n_obs=obs_count,
        n_vars=var_count,
        changed=changed,
        filename=filename,
        variables=variables,
    )
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def _stata_info(rt: Any) -> StataInfo:
    """Describe the running Stata: version, edition and backend.

    The version is the expansion of `c(stata_version)' (``None`` when empty
    or when expansion raises). The edition is looked up from ``rt.edition``,
    case-insensitively, in ``_EDITION_MAP`` and defaults to
    ``StataEdition.UNKNOWN``. The backend is always ``Backend.PYSTATA``.
    """
    toolkit = rt.sfi.SFIToolkit
    try:
        version = toolkit.macroExpand("`c(stata_version)'") or None
    except Exception:  # noqa: BLE001
        version = None

    edition_key = (rt.edition or "").lower()
    return StataInfo(
        version=version,
        edition=_EDITION_MAP.get(edition_key, StataEdition.UNKNOWN),
        backend=Backend.PYSTATA,
    )
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
def _extract_typed_fields(kind: ErrorKind, message: str) -> dict[str, str | None]:
    """Pull the structured detail that matches ``kind`` out of an error message.

    Returns a dict with the keys ``varname``, ``path``, ``name`` and
    ``command``; each is the text captured by the regex relevant to ``kind``,
    or ``None`` when the kind does not apply or the message does not match.
    """
    extracted: dict[str, str | None] = dict.fromkeys(
        ("varname", "path", "name", "command")
    )

    def _first_group(pattern: re.Pattern[str]) -> str | None:
        # First capture group of the first match in the message, else None.
        found = pattern.search(message)
        return found.group(1) if found else None

    if kind in (ErrorKind.VARNAME_NOT_FOUND, ErrorKind.NAME_CONFLICT):
        # The "variable X ..." wording serves both kinds; only the
        # destination field differs.
        target = "varname" if kind == ErrorKind.VARNAME_NOT_FOUND else "name"
        extracted[target] = _first_group(_VARNAME_RE)

    file_kinds = (
        ErrorKind.FILE_NOT_FOUND,
        ErrorKind.FILE_EXISTS,
        ErrorKind.FILE_IO,
        ErrorKind.FILE_CORRUPT,
    )
    if kind in file_kinds:
        extracted["path"] = _first_group(_FILE_PATH_RE)

    # Name conflicts that are not phrased as "variable X already defined".
    if kind == ErrorKind.NAME_CONFLICT and extracted["name"] is None:
        extracted["name"] = _first_group(_NAME_CONFLICT_RE)

    if kind == ErrorKind.COMMAND_NOT_FOUND:
        extracted["command"] = _first_group(_UNRECOGNIZED_CMD_RE)

    return extracted
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def _parse_failure_transcript(
    error_text: str, user_code: str
) -> dict[str, Any]:
    """Pinpoint the failing command in multi-line user code.

    pystata's SystemError for multi-line input contains the full Stata
    transcript with `. <cmd>` echoes for each line. We parse it to recover:

    - `failing`: the failing command's text (or "" if not isolatable)
    - `command`: same text as `failing` (or None)
    - `line`: 1-indexed line in the original user code (or None)
    - `commands_executed`: count of *non-comment* commands that completed
      successfully before the failure (or None)
    - `before` / `after`: surrounding lines in the user code (up to 3 / 1)

    When the failing command's text occurs on several user lines, the line
    reported is the occurrence matching how many times that command was
    echoed (the second echo of a command maps to its second occurrence in
    the code), capped at the last occurrence.
    """
    out: dict[str, Any] = {
        "failing": "",
        "line": None,
        "commands_executed": None,
        "before": [],
        "after": [],
        "command": None,
    }
    user_lines = user_code.splitlines()
    non_empty_user_lines = [ln for ln in user_lines if ln.strip()]

    # Single-line case — no transcript, just the error message.
    if "\n. " not in error_text and not error_text.startswith(". "):
        if len(non_empty_user_lines) == 1:
            failing = non_empty_user_lines[0].strip()
            out["failing"] = failing
            out["command"] = failing
            # Its line number in the original (blank lines included).
            out["line"] = next(
                i for i, ln in enumerate(user_lines, 1) if ln.strip() == failing
            )
            out["commands_executed"] = 0
        return out

    # Multi-line case — collect the `. <cmd>` echoes.
    cmd_echoes: list[str] = []
    for ln in error_text.split("\n"):
        if not ln.startswith(". "):
            continue
        body = ln[2:].strip()
        if not body:
            continue  # empty `. ` is just a prompt
        if body.startswith(("*", "//")):
            continue  # comment-only line — Stata echoes but doesn't "run"
        cmd_echoes.append(body)

    if not cmd_echoes:
        return out

    # The last echoed command is the one that failed; all earlier ones ran.
    failing = cmd_echoes[-1]
    out["failing"] = failing
    out["command"] = failing
    out["commands_executed"] = len(cmd_echoes) - 1

    # Map back to the user code. A command may legitimately appear more than
    # once (and succeed the first time), so pick the occurrence whose ordinal
    # equals the number of times it was echoed, not simply the first match.
    matches = [i for i, ln in enumerate(user_lines, 1) if ln.strip() == failing]
    if not matches:
        return out
    occurrence = min(cmd_echoes.count(failing), len(matches))
    line_no = matches[occurrence - 1]

    out["line"] = line_no
    # Non-blank lines among the three physical lines above the failure.
    out["before"] = [
        ln for ln in user_lines[max(0, line_no - 4):line_no - 1] if ln.strip()
    ]
    # The physical line right after the failure, if it is not blank.
    out["after"] = [ln for ln in user_lines[line_no:line_no + 1] if ln.strip()]
    return out
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
def _build_error(
    rc: int,
    error_message: str,
    user_code: str,
    available_varnames: list[str] | None,
) -> ErrorInfo:
    """Assemble the ErrorInfo for a failed run.

    Combines the rc classification, the short diagnostic line taken from the
    error text, the typed fields extracted from it (varname / path / name /
    command), the suggestions derived from those, and the failing-command
    location recovered from the transcript.

    Args:
        rc: Return code reported for the failure.
        error_message: Error text (possibly a full transcript).
        user_code: The code that was submitted, used for line attribution.
        available_varnames: Variable names passed on to ``suggestions_for``,
            or ``None``.
    """
    kind = classify_rc(rc)

    if error_message:
        short_msg = _last_error_line(error_message)
    else:
        short_msg = ""

    typed = _extract_typed_fields(kind, error_message)
    hints = suggestions_for(
        kind,
        varname=typed["varname"],
        name=typed["name"],
        command=typed["command"],
        path=typed["path"],
        available_varnames=available_varnames,
    )

    located = _parse_failure_transcript(error_message, user_code)
    surrounding = ErrorContext(
        before=located["before"],
        failing=located["failing"],
        after=located["after"],
    )

    # `command` is the failing command located in the transcript; the
    # `command` entry of `typed` only feeds the suggestions above.
    return ErrorInfo(
        kind=kind,
        rc=rc,
        message=short_msg,
        command=located["command"],
        line=located["line"],
        context=surrounding,
        commands_executed=located["commands_executed"],
        path=typed["path"],
        varname=typed["varname"],
        name=typed["name"],
        suggestions=hints,
    )
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def _last_error_line(error_text: str) -> str:
    """Pick the most informative line out of a Stata error text.

    Without any `. <cmd>` echo the text is a plain message and its first
    non-empty line is returned. With echoes it is a transcript: the
    diagnosis ("variable X not found") sits after the last echo and before
    the `r(NN);` line, so the text is scanned from the bottom for the first
    line that is neither an rc line nor an echo. Falls back to the first
    non-empty line, and to "" for empty input. The result is stripped.
    """
    kept = [ln for ln in error_text.splitlines() if ln]
    if not kept:
        return ""

    first = kept[0].strip()
    if all(not ln.startswith(". ") for ln in kept):
        return first

    for ln in reversed(kept):
        text = ln.strip()
        is_rc_line = text.startswith("r(") and text.endswith(");")
        is_echo = ln.startswith(". ")
        if text and not is_rc_line and not is_echo:
            return text

    # Only echoes / rc lines / whitespace were present.
    return first
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
614
|
+
# Public entrypoint
|
|
615
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
def execute(
    code: str,
    *,
    session_id: str = "main",
    log_lines_head: int = 20,
    log_lines_tail: int = 20,
    include_full_log: bool = False,
    include_graphs: str = "ref",  # "ref" | "inline" | "none"
    graph_format: str = "png",
    include_dataset_variables: bool = True,
    timeout_ms: int | None = 600_000,  # accepted but not yet enforced (v0.1)
) -> RunResult:
    """Execute Stata code and return a v1.0 RunResult.

    Raises PystataNotAvailable if Stata cannot be initialized.

    Multi-session: `session_id="main"` routes to Stata's master frame
    (`default`); any other valid Stata-name routes to a same-named Stata
    frame, created on demand. Frames isolate **data** (variables and
    observations), but `r()`, `e()`, scalars, and macros remain global
    across frames — agents needing full isolation should use separate
    processes.

    Args:
        code: Stata code to run.
        session_id: Session to run in (see above).
        log_lines_head: Leading log lines kept when the log is truncated.
        log_lines_tail: Trailing log lines kept when the log is truncated.
        include_full_log: Return the whole log instead of head/tail.
        include_graphs: "ref", "inline" or "none"; "none" skips graph
            snapshotting and collection entirely.
        graph_format: Must be a valid ``GraphFormat`` value.
        include_dataset_variables: Forwarded to ``_collect_dataset``.
        timeout_ms: Not referenced by this function's body.

    Raises:
        ValueError: if ``include_graphs`` or ``graph_format`` is invalid, or
            (via ``_ensure_session``) if ``session_id`` is not a usable frame
            name.
    """
    # Validate arguments before touching the runtime.
    if include_graphs not in ("ref", "inline", "none"):
        raise ValueError(
            f"include_graphs must be 'ref' | 'inline' | 'none'; got {include_graphs!r}"
        )
    try:
        gfmt = GraphFormat(graph_format)
    except ValueError as exc:
        raise ValueError(
            f"graph_format must be 'png' | 'svg' | 'pdf'; got {graph_format!r}"
        ) from exc

    rt = get_runtime()  # may raise PystataNotAvailable
    # The frame switch happens before the cancel check, so a cancelled call
    # still leaves Stata on the session's frame.
    _ensure_session(rt, session_id)

    request_id = _new_request_id()
    started_at = _utc_iso_ms()
    started = time.monotonic()

    # A pending cancel is consumed here and short-circuits the run: the code
    # is never forwarded to Stata.
    if _consume_cancel(session_id):
        return _build_cancelled_result(
            rt=rt,
            session_id=session_id,
            request_id=request_id,
            started_at=started_at,
            started=started,
            include_dataset_variables=include_dataset_variables,
        )

    # Snapshot existing graph names before user code so we can take a delta
    # afterward. This itself calls `graph dir`, which clobbers r(); user code
    # will overwrite r() if they care about return values.
    pre_graphs = (
        _list_graph_names(rt) if include_graphs != "none" else []
    )

    stdout_text, rc, err_msg = rt.run_capture(code)

    # Sampled right after run_capture, so the result / dataset / graph
    # collection below is not part of the reported time. Floored at 1 ms.
    elapsed_total_ms = max(1, int((time.monotonic() - started) * 1000))
    # v0.1: stata_elapsed_ms is the same as elapsed_ms (no IPC overhead to
    # subtract; pystata is in-process). We still report it separately so the
    # field is exercised end-to-end.
    stata_elapsed_ms = elapsed_total_ms

    log = _split_log(
        stdout_text,
        log_lines_head,
        log_lines_tail,
        include_full_log,
        request_id,
    )

    # On Stata error, we still surface results/dataset state — they reflect
    # whatever state existed before the failing command (per SCHEMA §3.7
    # commands_executed semantics).
    results = ResultsInfo(
        r=_collect_returns(rt, "r", request_id),
        e=_collect_returns(rt, "e", request_id),
        last_estimation_cmd=_last_estimation_cmd(rt),
    )
    dataset = _collect_dataset(rt, include_dataset_variables)

    # Handed to the error builder for suggestions; None when the variable
    # list was not collected or is empty.
    available_varnames = (
        [v.name for v in dataset.variables] if dataset.variables else None
    )

    if err_msg is not None:
        error = _build_error(rc, err_msg, code, available_varnames)
        # Build an error_window: prefer log tail; fall back to the error message
        # itself when the log is empty (pystata can raise before any stdout
        # gets flushed for short failures).
        log_lines = [
            ln for ln in stdout_text.replace("\r\n", "\n").split("\n") if ln
        ]
        if log_lines:
            tail_n = min(len(log_lines), 10)
            error_window = "\n".join(log_lines[-tail_n:])
        else:
            error_window = err_msg.strip()
        # Rebuild the LogInfo with the same fields plus the error window.
        log = LogInfo(
            head=log.head,
            tail=log.tail,
            lines_total=log.lines_total,
            bytes_total=log.bytes_total,
            truncated=log.truncated,
            complete=log.complete,
            error_window=error_window,
            ref=log.ref,
        )
    else:
        error = None

    # Graph capture happens AFTER r/e collection so that `graph dir` /
    # `graph display` / `graph export` (all r-class) don't clobber user r().
    if include_graphs != "none":
        graphs = _collect_graphs(
            rt,
            request_id=request_id,
            pre_existing=pre_graphs,
            fmt=gfmt,
            inline=(include_graphs == "inline"),
        )
    else:
        graphs = []

    capabilities = ["log_truncation", "multi_session"]
    if include_graphs != "none":
        capabilities.append("graph_ref")
    if include_graphs == "inline":
        capabilities.append("inline_graphs")

    # NOTE(review): if run_capture ever returns a non-zero rc together with
    # err_msg=None, the envelope reports ok=False with rc=0 and error=None —
    # confirm that combination cannot occur.
    return RunResult(
        ok=(error is None and rc == 0),
        rc=rc if error is not None else 0,
        session_id=session_id,
        request_id=request_id,
        started_at=started_at,
        elapsed_ms=elapsed_total_ms,
        stata_elapsed_ms=stata_elapsed_ms,
        stata=_stata_info(rt),
        log=log,
        results=results,
        dataset=dataset,
        graphs=graphs,
        warnings=_extract_warnings(stdout_text),
        error=error,
        capabilities=capabilities,
    )
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def _last_estimation_cmd(rt: Any) -> str | None:
    """Return the current e(cmd) value for callers.

    Gives ``None`` when the value is empty (no estimation has run) or when
    reading it raises.
    """
    try:
        cmd_name = rt.sfi.Macro.getGlobal("e(cmd)")
        return cmd_name if cmd_name else None
    except Exception:  # noqa: BLE001
        return None
|
|
777
|
+
|
|
778
|
+
|
|
779
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
780
|
+
# Multi-session via Stata frames (Module 4)
|
|
781
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
# Shape of a name accepted as a session/frame name: a letter or underscore
# followed by letters, digits or underscores.
_STATA_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _frame_for_session(session_id: str) -> str:
    """Map a session_id to a Stata frame name.

    `"main"` → Stata's master frame `"default"`. Any other id must be a
    Stata-valid name (`[A-Za-z_][A-Za-z0-9_]*`). The schema permits `-` in
    session_id, but Stata frame names disallow it; v0.1 rejects.

    Raises:
        ValueError: if ``session_id`` is not ``"main"`` and is not a valid
            name in its entirety.
    """
    if session_id == "main":
        return "default"
    # fullmatch, not match: with match() the pattern's `$` also accepts a
    # trailing newline, so "abc\n" would pass validation and then be
    # interpolated into the `frame create` / `frame change` commands.
    if not _STATA_NAME_RE.fullmatch(session_id):
        raise ValueError(
            f"session_id {session_id!r} is not a valid Stata frame name. "
            "Use only letters, digits, and underscore; first char must be "
            "a letter or underscore. (v0.1 limitation.)"
        )
    return session_id
|
|
803
|
+
|
|
804
|
+
|
|
805
|
+
def _list_frame_names(rt: Any) -> list[str]:
    """Return the names of all frames known to ``rt.sfi.Frame``, in index order.

    Asks the frame API for the frame count, then fetches the name at each
    index from 0 up to (but excluding) that count.
    """
    frame_api = rt.sfi.Frame
    total = frame_api.getFrameCount()
    return list(map(frame_api.getFrameAt, range(total)))
|
|
809
|
+
|
|
810
|
+
|
|
811
|
+
def _ensure_session(rt: Any, session_id: str) -> None:
    """Make the frame backing `session_id` the current working frame.

    The frame name comes from `_frame_for_session`, so an invalid id raises
    ValueError before anything is sent to Stata. If no frame of that name is
    listed yet it is created first; the switch is then issued unconditionally.
    Anything Stata prints while doing so is discarded.
    """
    frame = _frame_for_session(session_id)
    commands = [f"frame change {frame}"]
    if frame not in _list_frame_names(rt):
        # Creation has to precede the switch.
        commands.insert(0, f"frame create {frame}")
    for command in commands:
        with redirect_stdout(io.StringIO()):
            rt.stata.run(command, quietly=True, echo=False)
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
def list_sessions() -> list[dict[str, Any]]:
    """Auxiliary tool: enumerate live sessions (mapped from Stata frames).

    Returns one dict per frame with keys `session_id` (`"main"` for the
    `default` frame, otherwise the frame name), `frame` and `n_obs`.
    Returns an empty list when the runtime is unavailable
    (`PystataNotAvailable`).

    Counting observations requires switching into each frame in turn, so the
    frame that was current on entry is recorded first and restored before
    returning; listing sessions must not move the caller to another frame.
    """
    try:
        rt = get_runtime()
    except PystataNotAvailable:
        return []

    # Remember where we started. Best-effort: if the lookup fails we still
    # list the sessions, we just cannot switch back afterwards.
    try:
        original_frame = rt.sfi.SFIToolkit.macroExpand("`c(frame)'") or ""
    except Exception:  # noqa: BLE001
        original_frame = ""

    sessions: list[dict[str, Any]] = []
    try:
        for fname in _list_frame_names(rt):
            sid = "main" if fname == "default" else fname
            # Data.getObsTotal() reports on the current working frame, so
            # switch into each frame before reading its observation count.
            with redirect_stdout(io.StringIO()):
                rt.stata.run(f"frame change {fname}", quietly=True, echo=False)
            n_obs = int(rt.sfi.Data.getObsTotal())
            sessions.append({"session_id": sid, "frame": fname, "n_obs": n_obs})
    finally:
        if original_frame:
            with redirect_stdout(io.StringIO()):
                rt.stata.run(
                    f"frame change {original_frame}", quietly=True, echo=False
                )
    return sessions
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
def reset_session(session_id: str = "main") -> dict[str, Any]:
    """Auxiliary tool: reset one session.

    * `main` maps to Stata's master `default` frame, which cannot be dropped;
      it is wiped in place by switching to `default` and running `clear all`.
      The result reports `dropped_frame: False`.
    * Any other session: switch to `default` (a frame cannot be dropped while
      it is current), run `capture frame drop <frame>`, then clear the
      ref-store entries under `log://<session_id>-` and
      `graph://<session_id>-`. The result reports `dropped_frame: True`.

    Raises ValueError (from `_frame_for_session`) for an invalid session id.

    NOTE(review): Stata documents `clear all` as also removing frames, so
    resetting `main` presumably discards the frames of other sessions too —
    confirm this is intended.
    """
    rt = get_runtime()
    target = _frame_for_session(session_id)
    resetting_main = session_id == "main"

    follow_up = "clear all" if resetting_main else f"capture frame drop {target}"
    with redirect_stdout(io.StringIO()):
        # Always step onto the master frame first.
        rt.stata.run("frame change default", quietly=True, echo=False)
        rt.stata.run(follow_up, quietly=True, echo=False)

    if resetting_main:
        return {"session_id": "main", "dropped_frame": False}

    # Best-effort removal of stored logs/graphs scoped to this session.
    _refs.clear_prefix(f"log://{session_id}-")
    _refs.clear_prefix(f"graph://{session_id}-")
    return {"session_id": session_id, "dropped_frame": True}
|
|
864
|
+
|
|
865
|
+
|
|
866
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
867
|
+
# Warning extraction (Module 3)
|
|
868
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
869
|
+
|
|
870
|
+
|
|
871
|
+
# Patterns are ordered: more specific kinds first. Each pattern produces at
# most one warning per distinct matched text (see `_extract_warnings`).
_WARNING_PATTERNS: tuple[tuple[str, re.Pattern[str]], ...] = (
    # Stata's "omitted because of collinearity" note — shows up under
    # `regress`, `logit`, etc. when factor levels or duplicate vars are
    # dropped from the design matrix.
    (
        "omitted_collinear",
        re.compile(
            r"note:\s+(.+?)\s+omitted because of collinearity\.?",
            re.IGNORECASE,
        ),
    ),
    # Convergence not achieved (MLE-family commands)
    (
        "convergence",
        re.compile(
            r"convergence (?:not achieved|not reached|failed)", re.IGNORECASE
        ),
    ),
    # Matrix not pos. def. / singular — typically reported in MLE diagnostics
    (
        "singular",
        re.compile(
            r"(?:matrix\s+)?(?:not symmetric|not positive definite|"
            r"is\s+singular)",
            re.IGNORECASE,
        ),
    ),
    # Boundary / could-not-find-feasible — softer than rc 491
    (
        "boundary",
        re.compile(r"could not find feasible (?:starting )?values", re.IGNORECASE),
    ),
)

# Generic Stata "note:" lines that don't match a more specific pattern.
# Because of the leading `^\s*`, a match may begin on preceding blank lines.
_NOTE_RE = re.compile(r"^\s*note:\s*(.+?)\s*$", re.MULTILINE)


def _extract_warnings(log: str) -> list:  # list[StataWarning]
    """Scan the captured log for known Stata warning patterns.

    Two passes are made over `log`:

    1. Every pattern in `_WARNING_PATTERNS`, in order. The stripped matched
       text becomes the warning message.
    2. Generic `note: ...` lines (`_NOTE_RE`), reported with kind `"note"`,
       but only when the line does not overlap any text already claimed by a
       specific pattern in pass 1.

    Warnings are de-duplicated by `(kind, message)`; the first occurrence
    wins and order of first appearance is kept within each pass.

    Returns:
        A list of `StataWarning` entries (specific kinds first, then notes).
    """
    from stata_code.core.schema import StataWarning

    out: list = []
    seen: set[tuple[str, str]] = set()
    matched_spans: list[tuple[int, int]] = []

    for kind, pat in _WARNING_PATTERNS:
        for m in pat.finditer(log):
            # Record the span *before* de-duplicating: a repeated warning is
            # not emitted again, but its text is still claimed and must not
            # resurface as a generic note below.
            matched_spans.append(m.span())
            msg = m.group(0).strip()
            key = (kind, msg)
            if key in seen:
                continue
            seen.add(key)
            out.append(StataWarning(kind=kind, message=msg))

    # Generic notes: any `note: ...` line not already matched by a specific
    # pattern. Avoid double-counting.
    for m in _NOTE_RE.finditer(log):
        start, end = m.span()
        # True interval overlap. Checking only whether the note's start lies
        # inside a claimed span misses notes whose match begins on a leading
        # blank line, and notes where the specific pattern matched mid-line.
        if any(s < end and start < e for s, e in matched_spans):
            continue
        msg = m.group(0).strip()
        key = ("note", msg)
        if key in seen:
            continue
        seen.add(key)
        out.append(StataWarning(kind="note", message=msg))

    return out
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
949
|
+
# Graph capture (Module 1)
|
|
950
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
def _png_dimensions(data: bytes) -> tuple[int | None, int | None]:
    """Best-effort width/height from a PNG IHDR chunk.

    Layout relied on: 8-byte PNG signature, 4-byte chunk length, 4-byte chunk
    type, then the IHDR payload whose first two fields are big-endian 32-bit
    width and height (bytes 16-20 and 20-24).

    Returns:
        `(width, height)`, or `(None, None)` when `data` is shorter than 24
        bytes, lacks the PNG signature, or its first chunk is not `IHDR`
        (in which case bytes 16-24 would not be dimensions at all).
    """
    signature = b"\x89PNG\r\n\x1a\n"
    if len(data) < 24 or data[:8] != signature or data[12:16] != b"IHDR":
        return None, None
    return (
        int.from_bytes(data[16:20], "big"),
        int.from_bytes(data[20:24], "big"),
    )
|
|
961
|
+
|
|
962
|
+
|
|
963
|
+
def _list_graph_names(rt: Any) -> list[str]:
    """Return the graph names Stata reports via `graph dir`.

    `graph dir` is run with its printed output discarded; the names are then
    read back by expanding the `r(list)` macro and splitting on whitespace.
    Any exception along the way yields an empty list instead of propagating.
    """
    try:
        sink = io.StringIO()
        with redirect_stdout(sink):
            rt.stata.run("graph dir", quietly=False, echo=False)
        listing = rt.sfi.SFIToolkit.macroExpand("`r(list)'")
        return (listing or "").split()
    except Exception:  # noqa: BLE001
        # Best-effort: treat any failure as "no graphs".
        return []
|
|
972
|
+
|
|
973
|
+
|
|
974
|
+
def _collect_graphs(
    rt: Any,
    request_id: str,
    pre_existing: list[str],
    fmt: GraphFormat,
    inline: bool,
) -> list[GraphInfo]:
    """Export every graph that appeared since the `pre_existing` snapshot.

    The current graph names are listed and those not in `pre_existing` are
    treated as new. Each new graph is made active with `graph display`,
    written to a scratch directory with `graph export`, read back as bytes
    and stored in the ref store under `graph://<request_id>/<index>`, where
    the index is the graph's position among the new names. A graph is skipped
    (leaving a gap in the indices) when Stata raises `SystemError` or no file
    is produced. Width/height are filled in for PNG only. With `inline` set,
    the base64-encoded bytes are also embedded in the returned `GraphInfo`.

    The scratch directory and its files are removed afterwards, ignoring
    `OSError`.

    NOTE(review): detection is by name only, so a graph redrawn under a name
    that already existed before the run would not be picked up — confirm
    whether that is acceptable.
    """
    fresh = [name for name in _list_graph_names(rt) if name not in pre_existing]
    if not fresh:
        return []

    extension = fmt.value
    is_png = fmt == GraphFormat.PNG
    collected: list[GraphInfo] = []
    scratch = Path(tempfile.mkdtemp(prefix="stata_code_graph_"))
    try:
        for position, graph_name in enumerate(fresh):
            export_path = scratch / f"{position}.{extension}"
            commands = (
                f"graph display {graph_name}",
                f'graph export "{export_path}", as({extension}) replace',
            )
            try:
                with redirect_stdout(io.StringIO()):
                    for command in commands:
                        rt.stata.run(command, quietly=True, echo=False)
            except SystemError:
                # Stata refused — skip this graph (e.g., window not found)
                continue
            if not export_path.exists():
                continue

            payload = export_path.read_bytes()
            ref = f"graph://{request_id}/{position}"
            width, height = _png_dimensions(payload) if is_png else (None, None)
            _refs.put(
                ref,
                {
                    "format": extension,
                    "bytes": payload,
                    "width": width,
                    "height": height,
                },
            )
            info = GraphInfo(
                ref=ref,
                name=graph_name,
                format=fmt,
                width=width,
                height=height,
                source_command=None,  # v0.1: not yet pinpointing
                source_line=None,
                inline=_b64(payload) if inline else None,
            )
            collected.append(info)
    finally:
        # Best-effort cleanup; a leftover temp file must not fail the run.
        try:
            for leftover in scratch.iterdir():
                try:
                    leftover.unlink()
                except OSError:
                    pass
            scratch.rmdir()
        except OSError:
            pass

    return collected
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def _b64(data: bytes) -> str:
    """Encode `data` as standard base64 and return it as an ASCII string."""
    from base64 import b64encode

    return str(b64encode(data), "ascii")
|
|
1056
|
+
|
|
1057
|
+
|
|
1058
|
+
def get_graph(ref: str, format: str | None = None) -> dict[str, Any]:
    """Auxiliary tool: look up a stored graph by ref.

    Per SCHEMA.md §5. The result has the keys `format`, `bytes_b64` (the
    stored bytes, base64-encoded), `width` and `height`. The `format`
    argument is accepted but not read by this function; the stored format is
    always the one returned.

    Raises KeyError if the ref is unknown (expired, never existed, or session
    reset).
    """
    stored = _refs.get(ref)
    if stored is None:
        raise KeyError(f"unknown graph ref: {ref!r}")

    result: dict[str, Any] = {"format": stored["format"]}
    result["bytes_b64"] = _b64(stored["bytes"])
    result["width"] = stored["width"]
    result["height"] = stored["height"]
    return result
|
|
1074
|
+
|
|
1075
|
+
|
|
1076
|
+
def get_matrix(ref: str) -> dict[str, Any]:
    """Auxiliary tool: look up a stored matrix by ref.

    Per SCHEMA.md §5. Used when ``run()`` returns a Matrix with ``values=None``
    and a ``matrix://...`` ref because the matrix exceeded the inline cell
    cap (``MATRIX_INLINE_CELL_CAP`` = 10,000 cells by default). The result
    carries the stored ``rows``, ``cols`` and ``values`` unchanged.

    Raises ``KeyError`` if the ref is unknown (expired, never existed, or
    session reset).
    """
    stored = _refs.get(ref)
    if stored is None:
        raise KeyError(f"unknown matrix ref: {ref!r}")
    return {field: stored[field] for field in ("rows", "cols", "values")}
|