stata-code 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1092 @@
1
"""High-level execute() — runs Stata code and returns a v1.0 RunResult.

This is the only place that touches Stata. The MCP server and Jupyter
kernel both import from here and only translate transports.

Implements the v1.0 envelope from SCHEMA.md: ok / rc / error / log /
results / dataset / graphs / warnings / capabilities. r() and e() are
collected via sfi (native types). Multi-session is implemented through
Stata frames (session_id="main" ↔ default frame). Per-line error
attribution comes from parsing pystata's transcript. Cancellation is
cooperative only: `cancel()` short-circuits the *next* execute() for a
session but cannot interrupt code already running inside Stata.

For deferred items (hard timeout, hard/interrupting cancellation,
get_matrix ref mode, console fallback for Stata 11–16, streaming logs),
see SCHEMA.md §8.
"""
16
+
17
+ from __future__ import annotations
18
+
19
+ import io
20
+ import re
21
+ import tempfile
22
+ import threading
23
+ import time
24
+ import uuid
25
+ from contextlib import redirect_stdout
26
+ from datetime import datetime, timezone
27
+ from pathlib import Path
28
+ from typing import Any
29
+
30
+ from stata_code.core import _refs
31
+ from stata_code.core._runtime import PystataNotAvailable, get_runtime
32
+ from stata_code.core.errors import classify_rc, suggestions_for
33
+ from stata_code.core.schema import (
34
+ Backend,
35
+ DatasetInfo,
36
+ ErrorContext,
37
+ ErrorInfo,
38
+ ErrorKind,
39
+ GraphFormat,
40
+ GraphInfo,
41
+ LogInfo,
42
+ Matrix,
43
+ ResultsInfo,
44
+ RunResult,
45
+ StataEdition,
46
+ StataInfo,
47
+ StataReturns,
48
+ VariableInfo,
49
+ )
50
+
51
+ # ─────────────────────────────────────────────────────────────────────────────
52
+ # Helpers
53
+ # ─────────────────────────────────────────────────────────────────────────────
54
+
55
+
56
# Lower-cased edition strings -> schema enum. `_stata_info` looks up
# `rt.edition.lower()` here and falls back to StataEdition.UNKNOWN.
_EDITION_MAP: dict[str, StataEdition] = dict(
    mp=StataEdition.MP,
    se=StataEdition.SE,
    ic=StataEdition.IC,
    be=StataEdition.BE,
)
62
+
63
+ _ERETURN_NAME_RE = re.compile(r"^\s*(?:e|r)\(([A-Za-z_][A-Za-z0-9_]*)\)\s*[=:]")
64
+ _VARNAME_RE = re.compile(r"variable (\w+) (?:not found|already defined)")
65
+ _FILE_PATH_RE = re.compile(
66
+ r"file\s+(\S+?)\s+(?:not\s+found|already\s+exists|could\s+not)"
67
+ )
68
+ _NAME_CONFLICT_RE = re.compile(r"(\w+)\s+already\s+(?:defined|exists)")
69
+ _UNRECOGNIZED_CMD_RE = re.compile(r"(\S+)\s+(?:is\s+)?unrecognized\s+command")
70
+
71
+ # Cooperative cancellation: a per-session "cancel-pending" flag, settable
72
+ # from any thread via `cancel(session_id)`. The flag is consumed by the
73
+ # next `execute()` call for that session, which short-circuits and returns
74
+ # a RunResult with `error.kind="cancelled"` instead of forwarding the code
75
+ # to Stata. Cooperative semantics — does NOT interrupt code that is
76
+ # already mid-`stata.run()`. Hard interruption requires the subprocess-
77
+ # based runtime planned for v0.3+ (see SCHEMA.md §8).
78
+ _cancel_lock = threading.Lock()
79
+ _cancel_pending: set[str] = set()
80
+
81
+ # Cap on `dataset.variables` to avoid pathological return sizes (per SCHEMA §3.5).
82
+ _DATASET_VAR_CAP = 200
83
+
84
+ # Cap on inlined matrix cells (rows × cols). Above this, `values` is omitted
85
+ # from the envelope and a `matrix://...` ref is stored instead, retrievable
86
+ # via `get_matrix(ref)`. Per SCHEMA.md §3.4: "Producers SHOULD do this when
87
+ # a matrix would inline more than ~10,000 cells."
88
+ MATRIX_INLINE_CELL_CAP = 10_000
89
+
90
+
91
+ def _utc_iso_ms() -> str:
92
+ now = datetime.now(timezone.utc)
93
+ return now.strftime("%Y-%m-%dT%H:%M:%S.") + f"{now.microsecond // 1000:03d}Z"
94
+
95
+
96
+ def _new_request_id() -> str:
97
+ # uuid4 hex is unique enough; ULID would be sortable but adds a dep.
98
+ return uuid.uuid4().hex
99
+
100
+
101
def _split_log(
    log: str,
    head_lines: int,
    tail_lines: int,
    include_full: bool,
    request_id: str,
) -> LogInfo:
    """Build a LogInfo per SCHEMA §3.3.

    Line endings (CRLF and lone CR) are normalized to LF, and one trailing
    empty line is dropped before counting. If the log fits within
    ``head_lines + tail_lines`` lines, or ``include_full`` is set, the whole
    text goes into ``head``.

    Otherwise only the first ``head_lines`` and last ``tail_lines`` lines are
    inlined. The full text is stored under ``log://<request_id>`` so that
    ``get_log(ref)`` can retrieve it later within the producer's lifetime.

    Negative ``head_lines`` / ``tail_lines`` are treated as 0.
    """
    head_lines = max(0, head_lines)
    tail_lines = max(0, tail_lines)

    norm = log.replace("\r\n", "\n").replace("\r", "\n")
    lines = norm.split("\n")
    if lines and lines[-1] == "":
        lines = lines[:-1]
    lines_total = len(lines)
    # `bytes_total` reflects the byte count of what `get_log(ref)` would
    # return — i.e., the normalized text without trailing newline.
    full_text = "\n".join(lines)
    bytes_total = len(full_text.encode("utf-8"))

    if include_full or lines_total <= head_lines + tail_lines:
        return LogInfo(
            head=full_text,
            tail="",
            lines_total=lines_total,
            bytes_total=bytes_total,
            truncated=False,
            complete=True,
            ref=None,
        )

    head = "\n".join(lines[:head_lines])
    # `lines[-0:]` is the WHOLE list, so tail_lines == 0 needs a guard.
    tail = "\n".join(lines[-tail_lines:]) if tail_lines else ""
    ref = f"log://{request_id}"
    _refs.put(
        ref,
        {"text": full_text, "lines_total": lines_total, "bytes_total": bytes_total},
    )
    return LogInfo(
        head=head,
        tail=tail,
        lines_total=lines_total,
        bytes_total=bytes_total,
        truncated=True,
        complete=True,
        ref=ref,
    )
150
+
151
+
152
def get_log(ref: str) -> dict[str, Any]:
    """Auxiliary tool: fetch the full log behind a `log.ref`.

    Per SCHEMA.md §5. Returns ``text`` (the normalized full log),
    ``lines_total`` and ``bytes_total``.

    Raises:
        KeyError: if ``ref`` is unknown, or names a stored payload that is
            not a log (e.g. a ``matrix://`` ref, which has no ``text``).
    """
    payload = _refs.get(ref)
    # The ref store is shared with matrix refs. Without this check, a non-log
    # ref would surface as an opaque KeyError('text') instead of this message.
    if not isinstance(payload, dict) or "text" not in payload:
        raise KeyError(f"unknown log ref: {ref!r}")
    return {
        "text": payload["text"],
        "lines_total": payload["lines_total"],
        "bytes_total": payload["bytes_total"],
    }
165
+
166
+
167
def cancel(session_id: str = "main") -> bool:
    """Request cancellation of the next ``execute()`` call for ``session_id``.

    Cooperative: does **not** interrupt code that is currently mid-execution
    inside pystata. The flag is consumed (and the run short-circuited)
    when ``execute(session_id=...)`` is next invoked for the same session.
    The short-circuit returns a ``RunResult`` with ``ok=False``, ``rc=-3``
    (the synthetic code reserved for cooperative cancellation), and
    ``error.kind=cancelled``.

    The flag set is guarded by a lock, so this may be called from any thread.

    Returns ``True`` if a new cancel was registered, ``False`` if one was
    already pending (idempotent).
    """
    with _cancel_lock:
        already_pending = session_id in _cancel_pending
        _cancel_pending.add(session_id)
    return not already_pending
184
+
185
+
186
def is_cancel_pending(session_id: str = "main") -> bool:
    """Report whether the next ``execute()`` for this session will be cancelled.

    Read-only: the pending flag is left in place.
    """
    with _cancel_lock:
        pending = session_id in _cancel_pending
    return pending
190
+
191
+
192
def clear_cancel(session_id: str = "main") -> bool:
    """Withdraw a pending cancel for ``session_id`` without triggering it.

    Returns ``True`` when a pending cancel existed and was removed, and
    ``False`` when there was nothing to clear.
    """
    with _cancel_lock:
        if session_id not in _cancel_pending:
            return False
        _cancel_pending.discard(session_id)
        return True
202
+
203
+
204
def _consume_cancel(session_id: str) -> bool:
    """Test-and-clear the pending cancel flag for ``session_id`` under the lock.

    ``execute()`` calls this before forwarding code to Stata. Returns whether
    a cancel was pending.
    """
    with _cancel_lock:
        was_pending = session_id in _cancel_pending
        if was_pending:
            _cancel_pending.remove(session_id)
    return was_pending
211
+
212
+
213
def _build_cancelled_result(
    *,
    rt: Any,
    session_id: str,
    request_id: str,
    started_at: str,
    started: float,
    include_dataset_variables: bool,
) -> RunResult:
    """Synthesize the RunResult for a run short-circuited by a pending cancel.

    No user code reaches Stata, so log / results / graphs / warnings are
    empty. The dataset block is still a fresh snapshot of the current state.
    ``rc=-3`` is the synthetic code reserved for cooperative cancellation
    (distinct from -1 adapter_crash and -2 timeout, per SCHEMA.md §3.7).
    """
    cancel_rc = -3
    elapsed_ms = max(1, int((time.monotonic() - started) * 1000))

    stata_info = _stata_info(rt)
    empty_log = LogInfo(
        head="",
        tail="",
        lines_total=0,
        bytes_total=0,
        truncated=False,
        complete=True,
        ref=None,
    )
    dataset = _collect_dataset(rt, include_dataset_variables)
    cancel_message = (
        "Execution cancelled before Stata received the code "
        f"(session_id={session_id!r})."
    )
    cancel_error = ErrorInfo(
        kind=ErrorKind.CANCELLED,
        rc=cancel_rc,
        rc_label="cancelled",
        message=cancel_message,
        command=None,
        line=None,
        context=ErrorContext(before=[], failing="", after=[]),
        commands_executed=0,
        path=None,
        varname=None,
        name=None,
        suggestions=[],
    )

    return RunResult(
        ok=False,
        rc=cancel_rc,
        session_id=session_id,
        request_id=request_id,
        started_at=started_at,
        elapsed_ms=elapsed_ms,
        stata_elapsed_ms=0,
        stata=stata_info,
        log=empty_log,
        results=ResultsInfo(),
        dataset=dataset,
        graphs=[],
        warnings=[],
        error=cancel_error,
        capabilities=["cancel", "multi_session"],
    )
266
+
267
+
268
def _parse_return_list(text: str) -> dict[str, list[str]]:
    """Map each section of `return list` / `ereturn list` output to its names.

    Recognised sections are 'scalars', 'macros' and 'matrices'. Any other
    left-margin ``label:`` header (e.g. 'functions') turns collection off
    until the next recognised header. Inside a recognised section, every
    line matching ``_ERETURN_NAME_RE`` (``r(name) =`` / ``e(name) :``)
    contributes ``name``.
    """
    sections: dict[str, list[str]] = {"scalars": [], "macros": [], "matrices": []}
    active: list[str] | None = None
    for raw in text.splitlines():
        line = raw.rstrip()
        body = line.strip()
        if not body:
            continue
        # Headers sit at the left margin and end with a colon.
        if not line.startswith(" ") and body.endswith(":"):
            active = sections.get(body[:-1].strip().lower())
            continue
        if active is None:
            continue
        match = _ERETURN_NAME_RE.match(line)
        if match:
            active.append(match.group(1))
    return sections
295
+
296
+
297
def _list_returns(rt: Any, prefix: str) -> dict[str, list[str]]:
    """Return the member names of e() (``prefix == "e"``) or otherwise r().

    Runs ``ereturn list`` / ``return list`` with stdout redirected into a
    private buffer, so the listing never reaches the user's log. The output
    is then parsed with ``_parse_return_list``. If running the command fails
    for any reason, empty name lists are returned (best-effort).
    """
    command = {"e": "ereturn list"}.get(prefix, "return list")
    captured = io.StringIO()
    try:
        with redirect_stdout(captured):
            rt.stata.run(command, quietly=False, echo=False)
    except Exception:  # noqa: BLE001
        return {"scalars": [], "macros": [], "matrices": []}
    return _parse_return_list(captured.getvalue())
311
+
312
+
313
# Stata stores numeric missing values (., .a-.z) as doubles >= 2**1023,
# above the largest non-missing double (c(maxdouble)). sfi values at or
# above this threshold are therefore missing and are reported as None.
_STATA_MISSING_MIN = 2.0**1023


def _sfi_float_or_none(value: Any) -> float | None:
    """Convert an sfi numeric to ``float``, mapping ``None``/Stata missing to ``None``."""
    if value is None:
        return None
    number = float(value)
    return None if number >= _STATA_MISSING_MIN else number


def _collect_returns(rt: Any, prefix: str, request_id: str) -> StataReturns:
    """Build a StataReturns for r() or e() using sfi for typed access.

    Member names come from ``_list_returns``, and values are read through
    sfi. Stata missing values in scalars and matrix cells are reported as
    ``None``. Members that sfi fails to read do not fail the run: they become
    ``None`` (scalars) or ``""`` (macros), or are omitted (matrices).

    Matrices larger than ``MATRIX_INLINE_CELL_CAP`` cells are emitted with
    ``values=None`` and a ``matrix://<request_id>/<prefix>/<name>`` ref;
    callers fetch the values via :func:`get_matrix`.
    """
    names = _list_returns(rt, prefix)
    sfi = rt.sfi

    scalars: dict[str, float | None] = {}
    for name in names["scalars"]:
        try:
            scalars[name] = _sfi_float_or_none(sfi.Scalar.getValue(f"{prefix}({name})"))
        except Exception:  # noqa: BLE001
            scalars[name] = None

    macros: dict[str, str] = {}
    for name in names["macros"]:
        try:
            v = sfi.Macro.getGlobal(f"{prefix}({name})")
            macros[name] = v if v is not None else ""
        except Exception:  # noqa: BLE001
            macros[name] = ""

    matrices: dict[str, Matrix] = {}
    for name in names["matrices"]:
        key = f"{prefix}({name})"
        try:
            values = sfi.Matrix.get(key)
            rows = list(sfi.Matrix.getRowNames(key) or [])
            cols = list(sfi.Matrix.getColNames(key) or [])
            norm_values: list[list[float | None]] = [
                [_sfi_float_or_none(v) for v in row] for row in values
            ]
            n_rows = len(norm_values)
            n_cols = len(norm_values[0]) if n_rows else 0
            if n_rows * n_cols > MATRIX_INLINE_CELL_CAP:
                # Too big to inline: park the values in the ref store.
                ref = f"matrix://{request_id}/{prefix}/{name}"
                _refs.put(
                    ref,
                    {"rows": rows, "cols": cols, "values": norm_values},
                )
                matrices[name] = Matrix(rows=rows, cols=cols, values=None, ref=ref)
            else:
                matrices[name] = Matrix(
                    rows=rows, cols=cols, values=norm_values, ref=None
                )
        except Exception:  # noqa: BLE001
            continue

    return StataReturns(scalars=scalars, macros=macros, matrices=matrices)
369
+
370
+
371
def _collect_dataset(rt: Any, include_variables: bool) -> DatasetInfo:
    """Snapshot the dataset visible through sfi in the current frame.

    Reports observation and variable counts, ``c(changed)`` (False if it
    cannot be read), ``c(filename)`` (None if empty or unreadable) and
    ``c(frame)`` (``"default"`` if empty or unreadable). When
    ``include_variables`` is true and there are variables, the first
    ``_DATASET_VAR_CAP`` of them are listed with name, storage type and label.
    """
    sfi = rt.sfi
    data = sfi.Data
    toolkit = sfi.SFIToolkit
    scalar = sfi.Scalar

    n_vars = int(data.getVarCount())
    n_obs = int(data.getObsTotal())

    def read_c_macro(name: str) -> str | None:
        # Expand `c(name)'. Empty expansions and sfi errors both yield None.
        try:
            expanded = toolkit.macroExpand(f"`c({name})'")
        except Exception:  # noqa: BLE001
            return None
        return expanded or None

    try:
        changed = bool(float(scalar.getValue("c(changed)") or 0.0))
    except Exception:  # noqa: BLE001
        changed = False

    filename = read_c_macro("filename")
    frame_name = read_c_macro("frame") or "default"

    variables: list[VariableInfo] | None = None
    if include_variables and n_vars > 0:
        variables = []
        for idx in range(min(n_vars, _DATASET_VAR_CAP)):
            variables.append(
                VariableInfo(
                    name=data.getVarName(idx),
                    type=data.getVarType(idx),
                    label=data.getVarLabel(idx) or "",
                )
            )

    return DatasetInfo(
        frame=frame_name,
        n_obs=n_obs,
        n_vars=n_vars,
        changed=changed,
        filename=filename,
        variables=variables,
    )
421
+
422
+
423
def _stata_info(rt: Any) -> StataInfo:
    """Describe the attached Stata: version, edition and backend.

    ``version`` is the expansion of ``c(stata_version)``, or ``None`` if that
    is empty or fails. ``edition`` looks up ``rt.edition`` case-insensitively
    in ``_EDITION_MAP`` and falls back to ``StataEdition.UNKNOWN``.
    """
    toolkit = rt.sfi.SFIToolkit
    try:
        version = toolkit.macroExpand("`c(stata_version)'") or None
    except Exception:  # noqa: BLE001
        version = None
    edition_key = (rt.edition or "").lower()
    return StataInfo(
        version=version,
        edition=_EDITION_MAP.get(edition_key, StataEdition.UNKNOWN),
        backend=Backend.PYSTATA,
    )
433
+
434
+
435
def _extract_typed_fields(kind: ErrorKind, message: str) -> dict[str, str | None]:
    """Pull structured fields out of a Stata error message.

    Returns a dict with keys ``varname``, ``path``, ``name`` and ``command``.
    The regex tried depends on ``kind``; fields that do not apply or do not
    match stay ``None``.
    """
    fields: dict[str, str | None] = dict.fromkeys(("varname", "path", "name", "command"))

    def first_group(pattern: re.Pattern) -> str | None:
        found = pattern.search(message)
        return found.group(1) if found else None

    if kind in (ErrorKind.VARNAME_NOT_FOUND, ErrorKind.NAME_CONFLICT):
        var = first_group(_VARNAME_RE)
        if var is not None:
            slot = "varname" if kind == ErrorKind.VARNAME_NOT_FOUND else "name"
            fields[slot] = var

    file_kinds = (
        ErrorKind.FILE_NOT_FOUND,
        ErrorKind.FILE_EXISTS,
        ErrorKind.FILE_IO,
        ErrorKind.FILE_CORRUPT,
    )
    if kind in file_kinds:
        fields["path"] = first_group(_FILE_PATH_RE)

    # Fall back to the looser "<name> already defined/exists" wording.
    if kind == ErrorKind.NAME_CONFLICT and fields["name"] is None:
        fields["name"] = first_group(_NAME_CONFLICT_RE)

    if kind == ErrorKind.COMMAND_NOT_FOUND:
        fields["command"] = first_group(_UNRECOGNIZED_CMD_RE)

    return fields
467
+
468
+
469
+ def _parse_failure_transcript(
470
+ error_text: str, user_code: str
471
+ ) -> dict[str, Any]:
472
+ """Pinpoint the failing command in multi-line user code.
473
+
474
+ pystata's SystemError for multi-line input contains the full Stata
475
+ transcript with `. <cmd>` echoes for each line. We parse it to recover:
476
+
477
+ - `failing`: the failing command's text (or "" if not isolatable)
478
+ - `line`: 1-indexed line in the original user code (or None)
479
+ - `commands_executed`: count of *non-comment* commands that completed
480
+ successfully before the failure (or None)
481
+ - `before` / `after`: surrounding lines in the user code (up to 3 / 1)
482
+ """
483
+ out: dict[str, Any] = {
484
+ "failing": "",
485
+ "line": None,
486
+ "commands_executed": None,
487
+ "before": [],
488
+ "after": [],
489
+ "command": None,
490
+ }
491
+ user_lines = user_code.splitlines()
492
+ non_empty_user_lines = [ln for ln in user_lines if ln.strip()]
493
+
494
+ # Single-line case — no transcript, just the error message.
495
+ if "\n. " not in error_text and not error_text.startswith(". "):
496
+ if len(non_empty_user_lines) == 1:
497
+ failing = non_empty_user_lines[0].strip()
498
+ out["failing"] = failing
499
+ out["command"] = failing
500
+ # Find its line number in the original (with blanks)
501
+ for i, ln in enumerate(user_lines, 1):
502
+ if ln.strip() == failing:
503
+ out["line"] = i
504
+ break
505
+ out["commands_executed"] = 0
506
+ return out
507
+
508
+ # Multi-line case — parse `. <cmd>` lines.
509
+ transcript_lines = error_text.split("\n")
510
+ cmd_echoes: list[str] = []
511
+ for ln in transcript_lines:
512
+ if not ln.startswith(". "):
513
+ continue
514
+ body = ln[2:].strip()
515
+ if not body:
516
+ continue # empty `. ` is just a prompt
517
+ if body.startswith("*") or body.startswith("//"):
518
+ continue # comment-only line — Stata echoes but doesn't "run"
519
+ cmd_echoes.append(body)
520
+
521
+ if not cmd_echoes:
522
+ return out
523
+
524
+ failing = cmd_echoes[-1]
525
+ out["failing"] = failing
526
+ out["command"] = failing
527
+ out["commands_executed"] = len(cmd_echoes) - 1
528
+
529
+ # Match against original user code lines (with blanks) for line number.
530
+ for i, ln in enumerate(user_lines, 1):
531
+ if ln.strip() == failing:
532
+ out["line"] = i
533
+ out["before"] = [
534
+ user_lines[j] for j in range(max(0, i - 4), i - 1) if user_lines[j].strip()
535
+ ][-3:]
536
+ if i < len(user_lines):
537
+ next_lines = [
538
+ user_lines[j] for j in range(i, min(len(user_lines), i + 1))
539
+ if user_lines[j].strip()
540
+ ]
541
+ out["after"] = next_lines[:1]
542
+ break
543
+
544
+ return out
545
+
546
+
547
def _build_error(
    rc: int,
    error_message: str,
    user_code: str,
    available_varnames: list[str] | None,
) -> ErrorInfo:
    """Assemble the ErrorInfo for a failed run.

    Combines the rc classification (``classify_rc``), the most informative
    message line, typed fields parsed from the message, suggestions, and
    the failing command's location recovered from the transcript.
    """
    kind = classify_rc(rc)
    headline = _last_error_line(error_message) if error_message else ""
    typed = _extract_typed_fields(kind, error_message)
    suggestions = suggestions_for(
        kind,
        varname=typed["varname"],
        name=typed["name"],
        command=typed["command"],
        path=typed["path"],
        available_varnames=available_varnames,
    )
    where = _parse_failure_transcript(error_message, user_code)
    context = ErrorContext(
        before=where["before"],
        failing=where["failing"],
        after=where["after"],
    )
    return ErrorInfo(
        kind=kind,
        rc=rc,
        message=headline,
        command=where["command"],
        line=where["line"],
        context=context,
        commands_executed=where["commands_executed"],
        path=typed["path"],
        varname=typed["varname"],
        name=typed["name"],
        suggestions=suggestions,
    )
584
+
585
+
586
+ def _last_error_line(error_text: str) -> str:
587
+ """Extract the most informative line from a Stata error transcript.
588
+
589
+ For single-line errors the text is short; we just take the first line.
590
+ For multi-line transcripts the actual error sentence ("variable X not
591
+ found") sits AFTER the last `. <cmd>` echo and BEFORE the `r(NN);` rc
592
+ line. Return that sentence so agents see the diagnosis, not the echoed
593
+ command.
594
+ """
595
+ lines = [ln for ln in error_text.splitlines() if ln]
596
+ if not lines:
597
+ return ""
598
+ if not any(ln.startswith(". ") for ln in lines):
599
+ return lines[0].strip()
600
+ # Walk from bottom: skip rc lines, take next non-rc, non-`.` line.
601
+ for ln in reversed(lines):
602
+ s = ln.strip()
603
+ if not s:
604
+ continue
605
+ if s.startswith("r(") and s.endswith(");"):
606
+ continue
607
+ if ln.startswith(". "):
608
+ continue
609
+ return s
610
+ return lines[0].strip()
611
+
612
+
613
+ # ─────────────────────────────────────────────────────────────────────────────
614
+ # Public entrypoint
615
+ # ─────────────────────────────────────────────────────────────────────────────
616
+
617
+
618
def execute(
    code: str,
    *,
    session_id: str = "main",
    log_lines_head: int = 20,
    log_lines_tail: int = 20,
    include_full_log: bool = False,
    include_graphs: str = "ref",  # "ref" | "inline" | "none"
    graph_format: str = "png",
    include_dataset_variables: bool = True,
    timeout_ms: int | None = 600_000,  # accepted but not yet enforced (v0.1)
) -> RunResult:
    """Execute Stata code and return a v1.0 RunResult.

    Multi-session: `session_id="main"` routes to Stata's master frame
    (`default`); any other valid Stata-name routes to a same-named Stata
    frame, created on demand. Frames isolate **data** (variables and
    observations), but `r()`, `e()`, scalars, and macros remain global
    across frames — agents needing full isolation should use separate
    processes.

    Args:
        code: Stata code to run (may span multiple lines).
        session_id: Session to run in (see above).
        log_lines_head / log_lines_tail: Lines inlined in ``log.head`` /
            ``log.tail`` when the log is truncated.
        include_full_log: Inline the whole log instead of head/tail.
        include_graphs: ``"ref"``, ``"inline"`` or ``"none"``.
        graph_format: ``"png"``, ``"svg"`` or ``"pdf"``.
        include_dataset_variables: List variables in ``dataset.variables``.
        timeout_ms: Accepted for forward compatibility; not enforced yet.

    Raises:
        ValueError: Invalid ``include_graphs`` / ``graph_format``, or a
            ``session_id`` that is not a usable frame name.
        PystataNotAvailable: If Stata cannot be initialized.
    """
    if include_graphs not in ("ref", "inline", "none"):
        raise ValueError(
            f"include_graphs must be 'ref' | 'inline' | 'none'; got {include_graphs!r}"
        )
    try:
        gfmt = GraphFormat(graph_format)
    except ValueError as exc:
        raise ValueError(
            f"graph_format must be 'png' | 'svg' | 'pdf'; got {graph_format!r}"
        ) from exc

    rt = get_runtime()  # may raise PystataNotAvailable
    _ensure_session(rt, session_id)

    request_id = _new_request_id()
    started_at = _utc_iso_ms()
    started = time.monotonic()

    # A pending cooperative cancel short-circuits before any code reaches Stata.
    if _consume_cancel(session_id):
        return _build_cancelled_result(
            rt=rt,
            session_id=session_id,
            request_id=request_id,
            started_at=started_at,
            started=started,
            include_dataset_variables=include_dataset_variables,
        )

    # Snapshot existing graph names before user code so we can take a delta
    # afterward. This itself calls `graph dir`, which clobbers r(); user code
    # will overwrite r() if they care about return values.
    pre_graphs = _list_graph_names(rt) if include_graphs != "none" else []

    stdout_text, rc, err_msg = rt.run_capture(code)

    elapsed_total_ms = max(1, int((time.monotonic() - started) * 1000))
    # v0.1: stata_elapsed_ms is the same as elapsed_ms (no IPC overhead to
    # subtract; pystata is in-process). We still report it separately so the
    # field is exercised end-to-end.
    stata_elapsed_ms = elapsed_total_ms

    log = _split_log(
        stdout_text,
        log_lines_head,
        log_lines_tail,
        include_full_log,
        request_id,
    )

    # On Stata error, we still surface results/dataset state — they reflect
    # whatever state existed before the failing command (per SCHEMA §3.7
    # commands_executed semantics).
    results = ResultsInfo(
        r=_collect_returns(rt, "r", request_id),
        e=_collect_returns(rt, "e", request_id),
        last_estimation_cmd=_last_estimation_cmd(rt),
    )
    dataset = _collect_dataset(rt, include_dataset_variables)

    available_varnames = (
        [v.name for v in dataset.variables] if dataset.variables else None
    )

    if err_msg is not None:
        error = _build_error(rc, err_msg, code, available_varnames)
        # Build an error_window: prefer the last (up to 10) non-empty log
        # lines; fall back to the error message itself when the log is empty
        # (pystata can raise before any stdout gets flushed for short
        # failures). Line endings are normalized the same way as _split_log
        # (CRLF and lone CR), so CR-only output is not collapsed into one line.
        normalized = stdout_text.replace("\r\n", "\n").replace("\r", "\n")
        log_lines = [ln for ln in normalized.split("\n") if ln]
        if log_lines:
            error_window = "\n".join(log_lines[-10:])
        else:
            error_window = err_msg.strip()
        log = LogInfo(
            head=log.head,
            tail=log.tail,
            lines_total=log.lines_total,
            bytes_total=log.bytes_total,
            truncated=log.truncated,
            complete=log.complete,
            error_window=error_window,
            ref=log.ref,
        )
    else:
        error = None

    # Graph capture happens AFTER r/e collection so that `graph dir` /
    # `graph display` / `graph export` (all r-class) don't clobber user r().
    if include_graphs != "none":
        graphs = _collect_graphs(
            rt,
            request_id=request_id,
            pre_existing=pre_graphs,
            fmt=gfmt,
            inline=(include_graphs == "inline"),
        )
    else:
        graphs = []

    capabilities = ["log_truncation", "multi_session"]
    if include_graphs != "none":
        capabilities.append("graph_ref")
    if include_graphs == "inline":
        capabilities.append("inline_graphs")

    return RunResult(
        ok=(error is None and rc == 0),
        # With no error message from run_capture, rc is reported as 0.
        rc=rc if error is not None else 0,
        session_id=session_id,
        request_id=request_id,
        started_at=started_at,
        elapsed_ms=elapsed_total_ms,
        stata_elapsed_ms=stata_elapsed_ms,
        stata=_stata_info(rt),
        log=log,
        results=results,
        dataset=dataset,
        graphs=graphs,
        warnings=_extract_warnings(stdout_text),
        error=error,
        capabilities=capabilities,
    )
768
+
769
+
770
def _last_estimation_cmd(rt: Any) -> str | None:
    """Return ``e(cmd)``, or ``None`` if it is empty, unset, or unreadable.

    This mirrors e(cmd) for callers; ``None`` means no estimation has run
    (or sfi could not report it).
    """
    try:
        cmd = rt.sfi.Macro.getGlobal("e(cmd)")
    except Exception:  # noqa: BLE001
        return None
    return cmd if cmd else None
777
+
778
+
779
+ # ─────────────────────────────────────────────────────────────────────────────
780
+ # Multi-session via Stata frames (Module 4)
781
+ # ─────────────────────────────────────────────────────────────────────────────
782
+
783
+
784
+ _STATA_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
785
+
786
+
787
+ def _frame_for_session(session_id: str) -> str:
788
+ """Map a session_id to a Stata frame name.
789
+
790
+ `"main"` → Stata's master frame `"default"`. Any other id must be a
791
+ Stata-valid name (`[A-Za-z_][A-Za-z0-9_]*`). The schema permits `-` in
792
+ session_id, but Stata frame names disallow it; v0.1 rejects.
793
+ """
794
+ if session_id == "main":
795
+ return "default"
796
+ if not _STATA_NAME_RE.match(session_id):
797
+ raise ValueError(
798
+ f"session_id {session_id!r} is not a valid Stata frame name. "
799
+ "Use only letters, digits, and underscore; first char must be "
800
+ "a letter or underscore. (v0.1 limitation.)"
801
+ )
802
+ return session_id
803
+
804
+
805
def _list_frame_names(rt: Any) -> list[str]:
    """Return the names of all frames in memory, in sfi index order."""
    frames = rt.sfi.Frame
    return list(map(frames.getFrameAt, range(frames.getFrameCount())))
809
+
810
+
811
def _ensure_session(rt: Any, session_id: str) -> None:
    """Make the frame backing ``session_id`` current, creating it if absent.

    Stata's output from these commands is discarded. Invalid session ids
    propagate the ValueError raised by ``_frame_for_session``.
    """
    frame = _frame_for_session(session_id)
    commands = []
    if frame not in _list_frame_names(rt):
        commands.append(f"frame create {frame}")
    commands.append(f"frame change {frame}")  # no-op if already current
    for command in commands:
        with redirect_stdout(io.StringIO()):
            rt.stata.run(command, quietly=True, echo=False)
821
+
822
+
823
def list_sessions() -> list[dict[str, Any]]:
    """Auxiliary tool: enumerate live sessions (mapped from Stata frames).

    Returns one ``{"session_id", "frame", "n_obs"}`` dict per frame in
    memory, reporting the ``default`` frame as session ``"main"``. Returns
    ``[]`` when pystata is not available.

    Counting observations requires switching into each frame. The frame that
    was current on entry is restored afterwards (also on failure), so listing
    sessions does not change Stata's working frame.
    """
    try:
        rt = get_runtime()
    except PystataNotAvailable:
        return []

    def switch_to(frame: str) -> None:
        with redirect_stdout(io.StringIO()):
            rt.stata.run(f"frame change {frame}", quietly=True, echo=False)

    try:
        original_frame = rt.sfi.SFIToolkit.macroExpand("`c(frame)'") or None
    except Exception:  # noqa: BLE001
        original_frame = None  # unknown: cannot restore, behave as before

    sessions: list[dict[str, Any]] = []
    try:
        for fname in _list_frame_names(rt):
            # Data.getObsTotal() counts the current working frame, so switch
            # into each frame before reading it.
            switch_to(fname)
            n_obs = int(rt.sfi.Data.getObsTotal())
            sid = "main" if fname == "default" else fname
            sessions.append({"session_id": sid, "frame": fname, "n_obs": n_obs})
    finally:
        if original_frame is not None:
            switch_to(original_frame)
    return sessions
840
+
841
+
842
def reset_session(session_id: str = "main") -> dict[str, Any]:
    """Auxiliary tool: drop a session's data (and its frame, except `main`).

    `main` cannot be dropped — it maps to Stata's master `default` frame.
    For `main`, this switches to `default` and performs `clear all` to wipe
    data in place.
    NOTE(review): `clear all` may also discard other frames (and thus other
    sessions) in Stata 16+; confirm this is intended.

    For any other session, this switches to `default` and drops the
    session's frame. ``dropped_frame`` is True only if the frame existed
    beforehand and is gone afterwards.

    Raises:
        ValueError: if ``session_id`` is not a usable frame name.
    """
    rt = get_runtime()
    target = _frame_for_session(session_id)

    def run_quietly(*commands: str) -> None:
        with redirect_stdout(io.StringIO()):
            for command in commands:
                rt.stata.run(command, quietly=True, echo=False)

    if session_id == "main":
        run_quietly("frame change default", "clear all")
        return {"session_id": "main", "dropped_frame": False}

    existed = target in _list_frame_names(rt)
    # A frame cannot be dropped while current, so switch off it first.
    # `capture` swallows the error if the drop is impossible.
    run_quietly("frame change default", f"capture frame drop {target}")
    dropped = existed and target not in _list_frame_names(rt)

    # Drop ref-store entries scoped to this session (best-effort).
    # NOTE(review): refs produced in this module are keyed by request_id
    # (e.g. `log://<request_id>`), which does not embed session_id, so these
    # prefixes may never match — confirm the intended keying.
    _refs.clear_prefix(f"log://{session_id}-")
    _refs.clear_prefix(f"graph://{session_id}-")
    return {"session_id": session_id, "dropped_frame": dropped}
864
+
865
+
866
+ # ─────────────────────────────────────────────────────────────────────────────
867
+ # Warning extraction (Module 3)
868
+ # ─────────────────────────────────────────────────────────────────────────────
869
+
870
+
871
+ # Patterns are ordered: more specific kinds first. Each pattern produces one
872
+ # warning per match (de-duped at the schema level).
873
+ _WARNING_PATTERNS: tuple[tuple[str, re.Pattern[str]], ...] = (
874
+ # Stata's "omitted because of collinearity" note — shows up under
875
+ # `regress`, `logit`, etc. when factor levels or duplicate vars are
876
+ # dropped from the design matrix.
877
+ (
878
+ "omitted_collinear",
879
+ re.compile(
880
+ r"note:\s+(.+?)\s+omitted because of collinearity\.?",
881
+ re.IGNORECASE,
882
+ ),
883
+ ),
884
+ # Convergence not achieved (MLE-family commands)
885
+ (
886
+ "convergence",
887
+ re.compile(
888
+ r"convergence (?:not achieved|not reached|failed)", re.IGNORECASE
889
+ ),
890
+ ),
891
+ # Matrix not pos. def. / singular — typically reported in MLE diagnostics
892
+ (
893
+ "singular",
894
+ re.compile(
895
+ r"(?:matrix\s+)?(?:not symmetric|not positive definite|"
896
+ r"is\s+singular)",
897
+ re.IGNORECASE,
898
+ ),
899
+ ),
900
+ # Boundary / could-not-find-feasible — softer than rc 491
901
+ (
902
+ "boundary",
903
+ re.compile(r"could not find feasible (?:starting )?values", re.IGNORECASE),
904
+ ),
905
+ )
906
+
907
+ # Generic Stata "note:" lines that don't match a more specific pattern.
908
+ _NOTE_RE = re.compile(r"^\s*note:\s*(.+?)\s*$", re.MULTILINE)
909
+
910
+
911
def _extract_warnings(log: str) -> list:  # list[StataWarning]
    """Scan the captured log for known Stata warning patterns.

    Every ``_WARNING_PATTERNS`` match becomes a warning of that kind. Each
    remaining ``note: ...`` line becomes a generic ``"note"`` warning, unless
    its text overlaps a span already matched by a specific pattern. Warnings
    are de-duplicated here by ``(kind, message)``, where ``message`` is the
    stripped matched text.
    """
    from stata_code.core.schema import StataWarning

    out: list = []
    seen: set[tuple[str, str]] = set()
    claimed_spans: list[tuple[int, int]] = []

    for kind, pat in _WARNING_PATTERNS:
        for m in pat.finditer(log):
            # Claim the span even for duplicate messages, so a repeated
            # specific note is not re-reported below as a generic "note".
            claimed_spans.append(m.span())
            msg = m.group(0).strip()
            key = (kind, msg)
            if key in seen:
                continue
            seen.add(key)
            out.append(StataWarning(kind=kind, message=msg))

    # Generic notes. Test for any overlap rather than "starts inside a
    # claimed span": `_NOTE_RE` can begin at indentation or at a preceding
    # blank line, i.e. before the `note:` text a specific pattern matched.
    for m in _NOTE_RE.finditer(log):
        start, end = m.span()
        if any(s < end and start < e for s, e in claimed_spans):
            continue
        msg = m.group(0).strip()
        key = ("note", msg)
        if key in seen:
            continue
        seen.add(key)
        out.append(StataWarning(kind="note", message=msg))

    return out
946
+
947
+
948
+ # ─────────────────────────────────────────────────────────────────────────────
949
+ # Graph capture (Module 1)
950
+ # ─────────────────────────────────────────────────────────────────────────────
951
+
952
+
953
def _png_dimensions(data: bytes) -> tuple[int | None, int | None]:
    """Return ``(width, height)`` read from a PNG's IHDR chunk.

    Best effort: input shorter than 24 bytes, or not starting with the PNG
    signature, yields ``(None, None)``. The chunk type at bytes 12-16 is not
    checked.
    """
    signature = b"\x89PNG\r\n\x1a\n"
    is_png = len(data) >= 24 and data[:8] == signature
    if not is_png:
        return None, None
    # Layout: 8-byte signature, 4-byte chunk length, 4-byte chunk type, then
    # big-endian unsigned width (offset 16) and height (offset 20).
    width, height = (int.from_bytes(data[pos:pos + 4], "big") for pos in (16, 20))
    return width, height
961
+
962
+
963
def _list_graph_names(rt: Any) -> list[str]:
    """Return the names of the graphs currently held in Stata's memory.

    Issues ``graph dir`` with its console output discarded, then splits the
    ``r(list)`` macro read back afterwards on whitespace. Best effort: any
    exception along the way yields an empty list.

    NOTE(review): running ``graph dir`` presumably replaces whatever r()
    results user code left behind — confirm callers harvest r() first.
    """
    try:
        discarded = io.StringIO()
        with redirect_stdout(discarded):
            rt.stata.run("graph dir", quietly=False, echo=False)
        listing = rt.sfi.SFIToolkit.macroExpand("`r(list)'")
        return (listing or "").split()
    except Exception:  # noqa: BLE001 - failures are deliberately swallowed
        return []
972
+
973
+
974
def _collect_graphs(
    rt: Any,
    request_id: str,
    pre_existing: list[str],
    fmt: GraphFormat,
    inline: bool,
) -> list[GraphInfo]:
    """Capture graphs that user code newly created.

    Strategy: compare the in-memory graph names after user code ran against
    the snapshot ``pre_existing`` taken before it. Each name not in the
    snapshot is made active with ``graph display``, exported to a scratch
    file in ``fmt``, and its bytes stored in the ref store under
    ``graph://<request_id>/<index>``. The scratch directory is always
    removed afterwards.

    NOTE(review): a graph whose name already existed before the run (e.g. a
    default-named graph redrawn by user code) is not in the set difference
    and so is not captured — confirm callers clear graphs first or accept
    this.

    Args:
        rt: Runtime handle; this function uses ``rt.stata.run`` (and, via
            ``_list_graph_names``, ``rt.sfi``).
        request_id: Id used to namespace the graph refs.
        pre_existing: Graph names that existed before user code ran.
        fmt: Export format applied to every captured graph.
        inline: When True, also embed the base64-encoded bytes in each
            ``GraphInfo``.

    Returns:
        One ``GraphInfo`` per graph that exported successfully, in
        ``graph dir`` order. Graphs Stata refuses to display/export, or whose
        export file cannot be read, are skipped.
    """
    # Set membership instead of repeated list scans; tolerate a None snapshot.
    already_present = set(pre_existing or ())
    new_names = [n for n in _list_graph_names(rt) if n not in already_present]
    if not new_names:
        return []

    fmt_str = fmt.value
    out: list[GraphInfo] = []
    tmpdir = Path(tempfile.mkdtemp(prefix="stata_code_graph_"))
    try:
        for idx, gname in enumerate(new_names):
            data = _export_one_graph(rt, gname, tmpdir / f"{idx}.{fmt_str}", fmt_str)
            if data is None:
                continue
            # idx is the position in new_names, so refs stay unique even when
            # some graphs are skipped.
            ref = f"graph://{request_id}/{idx}"
            width = height = None
            if fmt == GraphFormat.PNG:
                width, height = _png_dimensions(data)
            _refs.put(
                ref,
                {
                    "format": fmt_str,
                    "bytes": data,
                    "width": width,
                    "height": height,
                },
            )
            out.append(
                GraphInfo(
                    ref=ref,
                    name=gname,
                    format=fmt,
                    width=width,
                    height=height,
                    source_command=None,  # v0.1: not yet pinpointing
                    source_line=None,
                    inline=_b64(data) if inline else None,
                )
            )
    finally:
        _cleanup_graph_tmpdir(tmpdir)

    return out


def _export_one_graph(rt: Any, gname: str, target: Path, fmt_str: str) -> bytes | None:
    """Export in-memory graph ``gname`` to ``target``; return its bytes or None on failure."""
    try:
        with redirect_stdout(io.StringIO()):
            rt.stata.run(f"graph display {gname}", quietly=True, echo=False)
            rt.stata.run(
                f'graph export "{target}", as({fmt_str}) replace',
                quietly=True,
                echo=False,
            )
    except SystemError:
        # Stata refused — skip this graph (e.g., window not found).
        return None
    try:
        return target.read_bytes()
    except OSError:
        # Missing or unreadable export file: skip this graph rather than
        # aborting capture of the remaining ones.
        return None


def _cleanup_graph_tmpdir(tmpdir: Path) -> None:
    """Best-effort removal of the scratch directory used by _collect_graphs."""
    import shutil

    shutil.rmtree(tmpdir, ignore_errors=True)
1050
+
1051
+
1052
def _b64(data: bytes) -> str:
    """Return *data* encoded as standard (padded) base64 text."""
    import base64 as _base64

    encoded = _base64.b64encode(data)
    return encoded.decode("ascii")
1056
+
1057
+
1058
def get_graph(ref: str, format: str | None = None) -> dict[str, Any]:
    """Auxiliary tool: fetch a graph's bytes and dimensions by ref.

    Per SCHEMA.md §5.

    Args:
        ref: A ``graph://...`` ref as produced for ``GraphInfo.ref``.
        format: Accepted for interface compatibility but currently ignored.
            No conversion is performed: the bytes come back in the format the
            graph was captured in, reported under ``"format"``.

    Returns:
        A dict with ``format``, ``bytes_b64`` (the stored bytes,
        base64-encoded), ``width`` and ``height`` (``None`` when they were
        not determined at capture time).

    Raises:
        KeyError: if the ref is unknown (expired, never existed, or session
            reset), or its payload is not a graph (e.g. a ``matrix://`` ref).
    """
    payload = _refs.get(ref)
    # Validate the payload shape up front so a wrong-kind ref is reported as
    # an unknown graph ref rather than a bare KeyError('format') below.
    if payload is None or any(
        key not in payload for key in ("format", "bytes", "width", "height")
    ):
        raise KeyError(f"unknown graph ref: {ref!r}")
    return {
        "format": payload["format"],
        "bytes_b64": _b64(payload["bytes"]),
        "width": payload["width"],
        "height": payload["height"],
    }
1074
+
1075
+
1076
def get_matrix(ref: str) -> dict[str, Any]:
    """Auxiliary tool: fetch a matrix's values, rows, cols by ref.

    Per SCHEMA.md §5. Used when ``run()`` returns a Matrix with ``values=None``
    and a ``matrix://...`` ref because the matrix exceeded the inline cell
    cap (``MATRIX_INLINE_CELL_CAP`` = 10,000 cells by default).

    Args:
        ref: A ``matrix://...`` ref taken from a run result.

    Returns:
        A dict with ``rows``, ``cols`` and ``values`` taken from the stored
        payload.

    Raises:
        KeyError: if the ref is unknown (expired, never existed, or session
            reset), or its payload is not a matrix (e.g. a ``graph://`` ref).
    """
    payload = _refs.get(ref)
    # Validate the payload shape up front so a wrong-kind ref is reported as
    # an unknown matrix ref rather than a bare KeyError('rows') below.
    if payload is None or any(key not in payload for key in ("rows", "cols", "values")):
        raise KeyError(f"unknown matrix ref: {ref!r}")
    return {
        "rows": payload["rows"],
        "cols": payload["cols"],
        "values": payload["values"],
    }
+ }