stata-code 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stata_code/__init__.py +100 -0
- stata_code/core/__init__.py +73 -0
- stata_code/core/_pool.py +808 -0
- stata_code/core/_refs.py +97 -0
- stata_code/core/_runtime.py +179 -0
- stata_code/core/errors.py +447 -0
- stata_code/core/runner.py +1092 -0
- stata_code/core/schema.py +317 -0
- stata_code/kernel/__init__.py +5 -0
- stata_code/kernel/__main__.py +6 -0
- stata_code/kernel/kernel.py +331 -0
- stata_code/mcp/__init__.py +3 -0
- stata_code/mcp/__main__.py +6 -0
- stata_code/mcp/server.py +360 -0
- stata_code-0.3.0.dist-info/METADATA +389 -0
- stata_code-0.3.0.dist-info/RECORD +20 -0
- stata_code-0.3.0.dist-info/WHEEL +4 -0
- stata_code-0.3.0.dist-info/entry_points.txt +3 -0
- stata_code-0.3.0.dist-info/licenses/LICENSE +21 -0
- stata_code-0.3.0.dist-info/licenses/LICENSE-POLICY.md +125 -0
|
@@ -0,0 +1,317 @@
|
|
|
1
|
+
"""Pydantic v2 models for the stata_code v1.0 result schema (see SCHEMA.md)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Literal
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
|
10
|
+
|
|
11
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
12
|
+
# Enums (closed at v1.0; new values are minor-version additive)
|
|
13
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ErrorKind(str, Enum):
    """Machine-readable category of a failed run (``ErrorInfo.kind``).

    The set is closed at schema v1.0; new values may only be added in minor
    versions. Members subclass ``str``, so each compares equal to its wire
    value (``ErrorKind.SYNTAX == "syntax"``).
    """

    # Parsing and naming problems
    SYNTAX = "syntax"
    COMMAND_NOT_FOUND = "command_not_found"
    VARNAME_NOT_FOUND = "varname_not_found"
    INVALID_NAME = "invalid_name"
    TYPE_MISMATCH = "type_mismatch"
    NAME_CONFLICT = "name_conflict"
    NOT_SORTED = "not_sorted"

    # Estimation
    CONVERGENCE = "convergence"
    INFEASIBLE = "infeasible"
    ESTIMATION_SAMPLE_EMPTY = "estimation_sample_empty"
    ESTIMATION_FAILURE = "estimation_failure"
    NO_ESTIMATION_RESULTS = "no_estimation_results"

    # Data / matrices
    NO_OBSERVATIONS = "no_observations"
    DATA_IN_MEMORY = "data_in_memory"
    MATRIX_SINGULAR = "matrix_singular"
    MATRIX_CONFORMABILITY = "matrix_conformability"
    MATRIX_MISSING = "matrix_missing"

    # Files, network and environment
    FILE_NOT_FOUND = "file_not_found"
    FILE_EXISTS = "file_exists"
    FILE_CORRUPT = "file_corrupt"
    FILE_IO = "file_io"
    NETWORK = "network"
    PERMISSION = "permission"
    ENCODING = "encoding"
    STATA_LIMIT = "stata_limit"
    OUT_OF_MEMORY = "out_of_memory"

    # Run lifecycle
    INTERRUPT = "interrupt"
    CANCELLED = "cancelled"
    TIMEOUT = "timeout"
    ADAPTER_CRASH = "adapter_crash"

    # Fallback when no specific category applies
    UNKNOWN = "unknown"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class StataEdition(str, Enum):
    """Stata edition reported in ``StataInfo.edition``.

    ``UNKNOWN`` is the value used when the edition has not been determined.
    """

    MP = "MP"
    SE = "SE"
    IC = "IC"
    BE = "BE"
    UNKNOWN = "unknown"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class GraphFormat(str, Enum):
    """File format of an exported graph (``GraphInfo.format``)."""

    PNG = "png"
    SVG = "svg"
    PDF = "pdf"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class IncludeGraphs(str, Enum):
    """How graphs are attached to a result.

    The member names suggest: by reference, inline, or omitted. The exact
    semantics are defined by the runner (not visible here).
    """

    REF = "ref"
    INLINE = "inline"
    NONE = "none"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class Backend(str, Enum):
    """Execution backend recorded in ``StataInfo.backend``."""

    PYSTATA = "pystata"
    CONSOLE = "console"
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
76
|
+
# Base config — every model is forward-compat (tolerates unknown fields)
|
|
77
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class _Base(BaseModel):
    """Common parent of every schema model.

    ``extra="allow"`` keeps unknown fields instead of rejecting them, which is
    the forward-compatibility rule of SCHEMA.md §6. ``validate_assignment=True``
    re-runs validation whenever a field is assigned after construction.
    """

    model_config = ConfigDict(validate_assignment=True, extra="allow")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
87
|
+
# Sub-models
|
|
88
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class StataInfo(_Base):
    """Identity of the Stata installation that produced a result."""

    # Version string, when known.
    version: str | None = Field(default=None)
    edition: StataEdition = Field(default=StataEdition.UNKNOWN)
    # Required: which adapter ran the code.
    backend: Backend
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class LogInfo(_Base):
    """Captured Stata log text plus truncation bookkeeping.

    Invariants (checked after validation and on every assignment):

    * ``truncated=True`` requires ``ref`` to be set.
    * ``truncated=False`` requires ``tail`` to be empty (SCHEMA.md §3.3).
    * ``lines_total`` and ``bytes_total`` are non-negative.
    """

    head: str = ""
    tail: str = ""  # only allowed to be non-empty when truncated
    lines_total: int = 0
    bytes_total: int = 0
    truncated: bool = False
    complete: bool = True
    error_window: str | None = None
    ref: str | None = None

    @model_validator(mode="after")
    def _check_invariants(self) -> LogInfo:
        """Reject logs that break the truncation contract or have negative counts."""
        if self.truncated:
            if self.ref is None:
                raise ValueError("log.truncated=True requires log.ref to be set")
        elif self.tail:
            raise ValueError(
                "log.truncated=False requires log.tail to be empty "
                "(see SCHEMA.md §3.3)"
            )
        if min(self.lines_total, self.bytes_total) < 0:
            raise ValueError("log.lines_total and log.bytes_total must be ≥ 0")
        return self
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class Matrix(_Base):
    """Labelled matrix returned in r()/e(), carried inline and/or by reference.

    At least one of ``values`` or ``ref`` must be present. When ``values`` is
    given, it must hold one row per entry of ``rows``, and each row must have
    one cell per entry of ``cols``. ``None`` cells are permitted.
    """

    rows: list[str]
    cols: list[str]
    values: list[list[float | None]] | None = None
    ref: str | None = None

    @model_validator(mode="after")
    def _check_shape(self) -> Matrix:
        """Require a payload and, for inline values, a ``rows x cols`` grid."""
        grid = self.values
        if grid is None:
            if self.ref is None:
                raise ValueError("matrix must have either values or ref set (or both)")
            return self

        expected_rows = len(self.rows)
        if len(grid) != expected_rows:
            raise ValueError(
                f"matrix.values has {len(grid)} rows, "
                f"expected {expected_rows}"
            )

        expected_cols = len(self.cols)
        mismatch = next(
            ((idx, len(cells)) for idx, cells in enumerate(grid) if len(cells) != expected_cols),
            None,
        )
        if mismatch is not None:
            idx, width = mismatch
            raise ValueError(
                f"matrix.values row {idx} has {width} cols, "
                f"expected {expected_cols}"
            )
        return self
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
class StataReturns(_Base):
    """Scalars, macros and matrices of one Stata return class.

    The same shape is used for both r() and e(); each gets its own instance
    at ``RunResult.results.r`` / ``RunResult.results.e``.
    """

    # None presumably stands for a Stata missing value — TODO confirm.
    scalars: dict[str, float | None] = Field(default_factory=dict)
    macros: dict[str, str] = Field(default_factory=dict)
    matrices: dict[str, Matrix] = Field(default_factory=dict)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class ResultsInfo(_Base):
    """Stored results after a run: separate r() and e() snapshots."""

    r: StataReturns = Field(default_factory=StataReturns)
    e: StataReturns = Field(default_factory=StataReturns)
    # Name of the last estimation command, when one is known.
    last_estimation_cmd: str | None = Field(default=None)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class VariableInfo(_Base):
    """Description of one variable in the active dataset."""

    name: str
    # Stata storage type: byte/int/long/float/double/str#/strL.
    # The field keeps the wire name ``type`` even though it shadows the builtin.
    type: str
    label: str = Field(default="")
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class DatasetInfo(_Base):
    """Summary of the dataset in memory after the run."""

    frame: str = Field(default="default")
    n_obs: int = Field(default=0)
    n_vars: int = Field(default=0)
    # Presumably mirrors Stata's "data changed since last save" flag — TODO confirm.
    changed: bool = Field(default=False)
    filename: str | None = Field(default=None)
    # None means the variable list was not included; an empty list means no variables.
    variables: list[VariableInfo] | None = Field(default=None)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class GraphInfo(_Base):
    """One graph produced by the run.

    ``ref`` is always required. ``inline`` carries the base64-encoded graph
    bytes only when they were explicitly requested.
    """

    ref: str
    name: str = Field(default="Graph")
    format: GraphFormat = Field(default=GraphFormat.PNG)
    width: int | None = Field(default=None)
    height: int | None = Field(default=None)
    # Command and line number that produced the graph, when known.
    source_command: str | None = Field(default=None)
    source_line: int | None = Field(default=None)
    inline: str | None = Field(default=None)  # base64 of the bytes when explicitly requested
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
class Suggestion(_Base):
    """A remediation hint attached to an error."""

    action: str  # human-readable description of what to do
    command: str | None = Field(default=None)  # optional Stata command for that action
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
class ErrorContext(_Base):
    """Source lines surrounding the failing command.

    ``before`` and ``after`` hold the neighbouring lines; ``failing`` holds
    the line that raised the error.
    """

    before: list[str] = Field(default_factory=list)
    failing: str = Field(default="")
    after: list[str] = Field(default_factory=list)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
_MESSAGE_MAX = 4096
|
|
199
|
+
_COMMAND_MAX = 1024
|
|
200
|
+
_WARNING_MAX = 1024
|
|
201
|
+
_TRUNC_MARK = "…"
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _truncate(text: str, limit: int) -> str:
|
|
205
|
+
if len(text) <= limit:
|
|
206
|
+
return text
|
|
207
|
+
return text[: limit - len(_TRUNC_MARK)] + _TRUNC_MARK
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
class ErrorInfo(_Base):
    """Structured description of why a run failed.

    ``message`` is capped at ``_MESSAGE_MAX`` characters and ``command`` at
    ``_COMMAND_MAX`` characters; both are cut via ``_truncate``.
    """

    kind: ErrorKind
    rc: int  # Stata return code
    rc_label: str = ""
    message: str = ""
    command: str | None = None
    line: int | None = None
    context: ErrorContext = Field(default_factory=ErrorContext)
    commands_executed: int | None = None
    # Optional detail fields, populated when relevant to the error.
    path: str | None = None
    varname: str | None = None
    name: str | None = None
    suggestions: list[Suggestion] = Field(default_factory=list)

    @field_validator("message")
    @classmethod
    def _truncate_message(cls, v: str) -> str:
        """Cap ``message`` at ``_MESSAGE_MAX`` characters."""
        return _truncate(v, _MESSAGE_MAX)

    @field_validator("command")
    @classmethod
    def _truncate_command(cls, v: str | None) -> str | None:
        """Cap ``command`` at ``_COMMAND_MAX`` characters; ``None`` passes through."""
        if v is None:
            return v
        return _truncate(v, _COMMAND_MAX)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class StataWarning(_Base):
    """One entry of the ``warnings`` list of a RunResult.

    The JSON wire name is ``warnings``; the class is called ``StataWarning``
    so it does not shadow the builtin ``Warning``.
    """

    kind: str = "unknown"
    message: str = ""

    @field_validator("message")
    @classmethod
    def _truncate_message(cls, v: str) -> str:
        """Cap ``message`` at ``_WARNING_MAX`` characters.

        Named like ``ErrorInfo._truncate_message`` rather than ``_truncate``,
        so that the validator does not share its name with the module-level
        ``_truncate`` helper it delegates to.
        """
        return _truncate(v, _WARNING_MAX)
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
248
|
+
# Top-level
|
|
249
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
# Allowed session_id alphabet; ':' is deliberately excluded (reserved).
_SESSION_ID_RE = re.compile(r"[A-Za-z0-9_-]+")


class RunResult(_Base):
    """Top-level v1.0 result envelope.

    SCHEMA.md is the normative definition; this model is derived from it.
    Cross-field rules are enforced in ``_consistency``:

    * ``ok=True`` requires ``error is None`` and ``rc == 0``.
    * ``ok=False`` requires an ``error`` whose ``rc`` equals the top-level ``rc``.
    """

    ok: bool
    rc: int
    session_id: str = "main"
    request_id: str
    started_at: str  # ISO 8601 UTC with millisecond precision
    elapsed_ms: int
    stata_elapsed_ms: int | None = None

    stata: StataInfo
    log: LogInfo = Field(default_factory=LogInfo)
    results: ResultsInfo = Field(default_factory=ResultsInfo)
    dataset: DatasetInfo = Field(default_factory=DatasetInfo)
    graphs: list[GraphInfo] = Field(default_factory=list)
    warnings: list[StataWarning] = Field(default_factory=list)
    error: ErrorInfo | None = None

    schema_version: Literal["1.0"] = "1.0"
    capabilities: list[str] = Field(default_factory=list)

    @field_validator("session_id")
    @classmethod
    def _check_session_id(cls, v: str) -> str:
        """Require the whole id to consist of ``[A-Za-z0-9_-]`` characters."""
        if _SESSION_ID_RE.fullmatch(v) is None:
            raise ValueError(
                f"session_id must match [A-Za-z0-9_-]+; got {v!r}. "
                "':' is reserved for future remote prefixing."
            )
        return v

    @field_validator("elapsed_ms")
    @classmethod
    def _nonneg_elapsed(cls, v: int) -> int:
        """Reject negative wall-clock durations."""
        if v >= 0:
            return v
        raise ValueError(f"elapsed_ms must be ≥ 0; got {v}")

    @field_validator("stata_elapsed_ms")
    @classmethod
    def _nonneg_stata_elapsed(cls, v: int | None) -> int | None:
        """Reject negative Stata-side durations; ``None`` is allowed."""
        if v is None or v >= 0:
            return v
        raise ValueError(f"stata_elapsed_ms must be ≥ 0; got {v}")

    @model_validator(mode="after")
    def _consistency(self) -> RunResult:
        """Check that ``ok``, ``rc`` and ``error`` agree (SCHEMA.md §3.1)."""
        err = self.error
        if self.ok:
            if err is not None:
                raise ValueError("ok=True requires error to be None (SCHEMA.md §3.1)")
            if self.rc != 0:
                raise ValueError(f"ok=True requires rc=0; got {self.rc}")
            return self

        if err is None:
            raise ValueError(
                "ok=False requires error to be non-None (SCHEMA.md §3.1)"
            )
        if err.rc != self.rc:
            raise ValueError(
                f"top-level rc ({self.rc}) must equal error.rc ({err.rc})"
            )
        return self
|
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
"""Stata Jupyter kernel — exposes the v1.0 stata_code pipeline.
|
|
2
|
+
|
|
3
|
+
The kernel uses `runner.execute()` for every cell. Defaults are tuned for
|
|
4
|
+
human/notebook use rather than agent use:
|
|
5
|
+
- `include_full_log=True`: full log shown in stdout (no head/tail truncation)
|
|
6
|
+
- `include_graphs="inline"`: graph bytes embedded for direct rendering
|
|
7
|
+
- `session_id="main"`: single-session unless the kernel is configured with
|
|
8
|
+
multiple kernel specs
|
|
9
|
+
|
|
10
|
+
Install via `python -m stata_code.kernel install --user`.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import json
|
|
16
|
+
import sys
|
|
17
|
+
import traceback
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from tempfile import TemporaryDirectory
|
|
20
|
+
from typing import Any
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
from ipykernel.kernelbase import Kernel
|
|
24
|
+
|
|
25
|
+
_HAS_IPYKERNEL = True
|
|
26
|
+
except ImportError:
|
|
27
|
+
Kernel = object # type: ignore[misc,assignment]
|
|
28
|
+
_HAS_IPYKERNEL = False
|
|
29
|
+
|
|
30
|
+
from stata_code.core._runtime import PystataNotAvailable
|
|
31
|
+
from stata_code.core.runner import execute
|
|
32
|
+
from stata_code.core.schema import RunResult
|
|
33
|
+
|
|
34
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
35
|
+
# Static keyword / help tables (carried over verbatim — independent of
|
|
36
|
+
# pipeline; used by do_complete / do_inspect)
|
|
37
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Completion candidates used by StataKernel.do_complete. Every entry carries
# the single leading space of the original table.
STATA_KEYWORDS: tuple[str, ...] = tuple(
    " " + word
    for word in (
        # command prefixes
        "quietly", "noisily", "capture",
        # looking at data
        "summarize", "summarize, detail", "describe", "browse",
        "list", "inspect", "count", "assert",
        # creating / changing variables
        "generate", "egen", "replace", "recode", "destring", "tostring",
        # combining datasets
        "merge", "append", "joinby", "cross",
        # ordering and restructuring
        "sort", "gsort", "by", "bysort", "collapse", "contract", "stack",
        "reshape", "xpose", "fillin",
        # estimation and stored estimates
        "regress", "logistic", "probit", "tobit", "ivreg", "areg",
        "xtreg", "logit", "ologit", "oprobit", "mlogit",
        "estimates", "eststo", "esttab", "estpost",
        # labels
        "label", "label variable", "label define", "label values",
        # data management and file I/O
        "keep", "drop", "use", "save", "clear", "insheet", "infile",
        "infix", "import", "export", "outfile", "outreg",
        # graphics and output
        "graph", "graph bar", "graph box", "graph twoway", "graph export",
        "display", "putexcel", "putdocx",
        # macros and programming
        "tempfile", "tempvar", "global", "local",
        "foreach", "forvalues", "while", "if", "else", "continue",
        "set", "update", "restore", "preserve",
        "version", "mata", "python",
    )
)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Short syntax summaries shown by StataKernel.do_inspect, keyed by the
# lower-cased command name. Each value is "<syntax>\n\n<description>".
STATA_HELP: dict[str, str] = dict(
    summarize="summarize [varlist] [if] [in] [weight] [, detail]\n\nCompute summary statistics.",
    regress="regress depvar [indepvars] [if] [in] [weight] [, options]\n\nLinear regression.",
    logistic="logistic depvar [indepvars] [if] [in] [weight] [, options]\n\nLogistic regression.",
    generate="generate newvar = exp\n\ngenerate creates a new variable.",
    replace="replace oldvar = exp [if] [in]\n\nreplace replaces the values of an existing variable.",
    merge="merge [n] 1:1 varlist using filename [, options]\n\nmerge joins data from disk.",
    graph="graph [type] plot [if] [in] [, options]\n\ngraph creates twoway plots.",
    by="by varlist: command\n\nby repeats command for each subset of data.",
)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
76
|
+
# Kernel
|
|
77
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class StataKernel(Kernel if _HAS_IPYKERNEL else object):
    """Jupyter kernel that sends every cell through ``runner.execute()``.

    Cells run with ``include_full_log=True`` and ``include_graphs="inline"``.
    Log text is streamed to stdout, warnings and errors to stderr, and inline
    graphs are published as ``display_data``. ``_last_result`` keeps the
    ``RunResult`` of the most recent cell for which the runner returned one.
    """

    protocol_version = "5.3"
    implementation = "stata_code.kernel"
    # Reported in kernel_info; kept in step with the stata_code release (0.3.0).
    implementation_version = "0.3.0"
    language_info: dict[str, Any] = {
        "name": "stata",
        "codemirror_mode": "stata",
        "file_extension": ".do",
        "mimetype": "text/x-stata",
        "pygments_lexer": "stata",
        "version": "18.0",
    }
    banner = "Stata kernel (stata_code) — backed by the v1.0 runner pipeline"
    help_links = [{"text": "Stata Help", "url": "https://www.stata.com/help.cgi?"}]

    _last_result: RunResult | None = None

    # ── Execution ──────────────────────────────────────────────────────────

    def do_execute(
        self,
        code: str,
        silent: bool = False,
        store_history: bool = True,
        user_expressions: dict[str, Any] | None = None,
        allow_stdin: bool = False,
    ) -> dict[str, Any]:
        """Run one cell and return the ``execute_reply`` content dict.

        Output (including failure messages) is published only when ``silent``
        is false. ``store_history``, ``user_expressions`` and ``allow_stdin``
        are accepted for protocol compatibility but are not used.
        """
        if not _HAS_IPYKERNEL:
            return self._error_reply("ipykernel not installed")

        source = code.strip()
        if not source:
            # Whitespace-only cell: succeed without dispatching to the runner.
            return self._ok_reply()

        try:
            result = execute(
                source,
                include_full_log=True,
                include_graphs="inline",
            )
        except PystataNotAvailable as exc:
            return self._failure_reply(f"Stata not available: {exc}", silent)
        except Exception as exc:  # noqa: BLE001 - surface any runner failure in the cell
            traceback.print_exc()
            return self._failure_reply(str(exc) or type(exc).__name__, silent)

        self._last_result = result
        if not silent:
            self._show_result(result)
        return self._reply(result)

    def _show_result(self, result: RunResult) -> None:
        """Publish a result's log, warnings, inline graphs and error report."""
        log = result.log
        if log.head:
            self._stream("stdout", log.head + "\n")
        if log.tail:
            # LogInfo only permits a non-empty tail when the log was truncated
            # (and then requires ``ref``), so show where the full log lives and
            # the end of the log as well.
            self._stream(
                "stdout",
                f"... [log truncated; full log: {log.ref}] ...\n{log.tail}\n",
            )
        for w in result.warnings:
            self._stream("stderr", f"[{w.kind}] {w.message}\n")
        for graph in result.graphs:
            if graph.inline:
                self._publish_image(graph.inline, graph.format.value)
        if result.error is not None:
            self._stream("stderr", self._format_error(result) + "\n")

    # ── Reply helpers ──────────────────────────────────────────────────────

    def _ok_reply(self) -> dict[str, Any]:
        """Build a successful ``execute_reply`` content dict."""
        return {
            "status": "ok",
            "execution_count": self.execution_count,
            "payload": [],
            "user_expressions": {},
        }

    def _reply(self, r: RunResult) -> dict[str, Any]:
        """Translate a ``RunResult`` into ``execute_reply`` content."""
        if r.error is None:
            return self._ok_reply()
        return {
            "status": "error",
            "execution_count": self.execution_count,
            "ename": f"StataError({r.error.kind.value})",
            "evalue": r.error.message,
            "traceback": [self._format_error(r)],
        }

    def _format_error(self, r: RunResult) -> str:
        """Render ``r.error`` as a short multi-line report.

        Raises:
            ValueError: if ``r`` carries no error.
        """
        e = r.error
        if e is None:
            raise ValueError("_format_error called on a RunResult without an error")
        parts = [f"!!! Stata error: {e.kind.value} (rc={e.rc})", f" {e.message}"]
        if e.line is not None:
            parts.append(f" at line {e.line}: {e.context.failing!r}")
        parts.extend(f" → {s.action}" for s in e.suggestions)
        return "\n".join(parts)

    def _failure_reply(self, msg: str, silent: bool) -> dict[str, Any]:
        """Show *msg* on stderr (unless *silent*) and return an error reply.

        Notebook frontends render what arrives on IOPub, not the
        execute_reply content, so without the stream the user would get no
        explanation for the failed cell.
        """
        if not silent:
            self._stream("stderr", msg + "\n")
        return self._error_reply(msg)

    def _error_reply(self, msg: str) -> dict[str, Any]:
        """Build an ``execute_reply`` for a failure that produced no RunResult."""
        return {
            "status": "error",
            # The plain-object fallback (no ipykernel) has no execution_count.
            "execution_count": getattr(self, "execution_count", 0),
            "ename": "RuntimeError",
            "evalue": msg,
            "traceback": [msg],
        }

    def _stream(self, name: str, text: str) -> None:
        """Send *text* on the ``stream`` channel *name*; no-op for empty text."""
        if not text:
            return
        try:
            self.send_response(
                self.iopub_socket, "stream", {"name": name, "text": text}
            )
        except Exception:  # noqa: BLE001
            pass  # non-kernel context (tests)

    def _publish_image(self, b64_data: str, fmt: str) -> None:
        """Publish one graph as ``display_data`` (best effort).

        ``b64_data`` is the base64 payload from ``GraphInfo.inline``. Binary
        formats are sent base64-encoded. SVG is carried as text in Jupyter
        MIME bundles, so it is decoded to the XML string first. Unknown
        formats fall back to ``image/png``.
        """
        import base64

        mime = {
            "png": "image/png",
            "svg": "image/svg+xml",
            "pdf": "application/pdf",
        }.get(fmt, "image/png")
        payload = b64_data
        if mime == "image/svg+xml":
            try:
                payload = base64.b64decode(b64_data).decode("utf-8")
            except ValueError:
                # binascii.Error and UnicodeDecodeError are ValueErrors; keep
                # the payload as received rather than dropping the graph.
                payload = b64_data
        try:
            self.send_response(
                self.iopub_socket,
                "display_data",
                {
                    "data": {mime: payload, "text/plain": f"[graph: {fmt}]"},
                    "metadata": {},
                },
            )
        except Exception:  # noqa: BLE001
            pass  # non-kernel context (tests)

    # ── Completion / Inspection ────────────────────────────────────────────

    def do_complete(self, code: str, cursor_pos: int) -> dict[str, Any]:
        """Complete the token ending at *cursor_pos* against STATA_KEYWORDS.

        The token runs back from the cursor to the nearest space, tab,
        newline, ``(`` or ``,``. Matches are returned without the table's
        leading space, so accepting one does not insert a stray blank.
        """
        prefix = code[:cursor_pos]
        start = len(prefix)
        while start > 0 and prefix[start - 1] not in " \t\n\r(,":
            start -= 1
        token = prefix[start:]
        matches = sorted(
            {kw.strip() for kw in STATA_KEYWORDS if kw.strip().startswith(token)}
        )
        return {
            "status": "ok",
            "matches": matches,
            "cursor_start": start,
            "cursor_end": len(prefix),
            "metadata": {},
        }

    def do_inspect(
        self,
        code: str,
        cursor_pos: int,
        detail_level: int = 0,
        omit_sections: Any = (),
    ) -> dict[str, Any]:
        """Look up the alphanumeric word under the cursor in STATA_HELP.

        The word extends both left and right of *cursor_pos*. The reply
        carries the help text under ``data["text/plain"]``, as the
        inspect_reply message requires. The older ``name``/``documentation``
        keys are kept for existing callers. ``detail_level`` and
        ``omit_sections`` are ignored; the latter is accepted for ipykernel
        versions whose ``Kernel.do_inspect`` takes that extra argument.
        """
        pos = min(max(cursor_pos, 0), len(code))
        start = pos
        while start > 0 and code[start - 1].isalnum():
            start -= 1
        end = pos
        while end < len(code) and code[end].isalnum():
            end += 1
        word = code[start:end]
        found = STATA_HELP.get(word.lower())
        if not found:
            return {"status": "ok", "found": False, "data": {}, "metadata": {}}
        return {
            "status": "ok",
            "found": True,
            "data": {"text/plain": found},
            "metadata": {},
            "name": word,
            "documentation": found,
            "cursor_start": start,
            "cursor_end": end,
        }

    def do_kernel_info(self) -> dict[str, Any]:
        """Return the static kernel_info fields defined on the class."""
        return {
            "protocol_version": self.protocol_version,
            "implementation": self.implementation,
            "implementation_version": self.implementation_version,
            "language_info": self.language_info,
            "banner": self.banner,
            "help_links": self.help_links,
        }
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
252
|
+
# Kernel installation CLI
|
|
253
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def install_kernel(user: bool = True, system: bool = False) -> None:
    """Write a ``kernel.json`` for this kernel and register it with Jupyter.

    Args:
        user: install into the current user's Jupyter data directory
            (the default).
        system: when true, takes precedence over *user* and requests a
            non-user install through ``KernelSpecManager``.

    The spec launches ``<resolved sys.executable> -m stata_code.kernel -f
    {connection_file}`` and is installed under the name ``stata``, replacing
    any existing spec of that name.
    """
    from jupyter_client.kernelspec import KernelSpecManager

    interpreter = str(Path(sys.executable).resolve())
    spec = {
        "argv": [interpreter, "-m", "stata_code.kernel", "-f", "{connection_file}"],
        "display_name": "Stata",
        "language": "stata",
        "metadata": {"debugger": False},
    }

    per_user = user if not system else False
    with TemporaryDirectory(prefix="stata_code_kernel_") as staging:
        spec_path = Path(staging) / "kernel.json"
        spec_path.write_text(json.dumps(spec, indent=2))
        manager = KernelSpecManager()
        installed_at = manager.install_kernel_spec(
            staging,
            kernel_name="stata",
            user=per_user,
            replace=True,
        )

    print(f"Kernel installed to: {installed_at}")
    print("Restart Jupyter and select 'Stata' as the kernel.")
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def run_main() -> None:
    """Entry point for the ``stata-code-kernel`` console script.

    Usage::

        stata-code-kernel install [--user|--system]   # install kernel spec
        stata-code-kernel -f <connection_file>        # launch the kernel
                                                      # (Jupyter calls this)

    ``install`` is recognised by inspecting ``sys.argv`` directly rather
    than with argparse subparsers, because Jupyter launches the kernel with
    connection-file flags that a subparser setup would not accept.
    """
    import argparse

    argv = sys.argv
    wants_help = len(argv) == 1 or argv[1] in {"-h", "--help"}
    if wants_help:
        print(
            "usage: stata-code-kernel install [--user|--system]\n"
            " stata-code-kernel -f <connection_file>\n\n"
            "Install or launch the Stata Jupyter kernel."
        )
        return

    if argv[1] == "install":
        parser = argparse.ArgumentParser(prog="stata-code-kernel install")
        where = parser.add_mutually_exclusive_group()
        where.add_argument("--user", dest="user", action="store_true", default=True)
        where.add_argument("--system", dest="user", action="store_false")
        opts = parser.parse_args(argv[2:])
        install_kernel(user=opts.user, system=not opts.user)
        return

    # Anything else (normally "-f <connection_file>") launches the kernel;
    # IPKernelApp parses the command line itself.
    from ipykernel.kernelapp import IPKernelApp

    IPKernelApp.launch_instance(kernel_class=StataKernel)


if __name__ == "__main__":  # pragma: no cover
    run_main()
|