weakincentives 0.3.0-py3-none-any.whl → 0.4.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
- weakincentives/__init__.py +1 -1
- weakincentives/adapters/__init__.py +3 -2
- weakincentives/examples/code_review_prompt.py +12 -3
- weakincentives/examples/code_review_session.py +7 -3
- weakincentives/prompt/markdown.py +1 -1
- weakincentives/prompt/prompt.py +7 -7
- weakincentives/prompt/structured_output.py +2 -2
- weakincentives/prompt/tool.py +2 -2
- weakincentives/serde/dataclass_serde.py +16 -14
- weakincentives/session/session.py +5 -0
- weakincentives/tools/__init__.py +12 -0
- weakincentives/tools/asteval.py +698 -0
- weakincentives/tools/vfs.py +59 -56
- weakincentives-0.4.0.dist-info/METADATA +490 -0
- {weakincentives-0.3.0.dist-info → weakincentives-0.4.0.dist-info}/RECORD +17 -16
- weakincentives-0.3.0.dist-info/METADATA +0 -231
- {weakincentives-0.3.0.dist-info → weakincentives-0.4.0.dist-info}/WHEEL +0 -0
- {weakincentives-0.3.0.dist-info → weakincentives-0.4.0.dist-info}/licenses/LICENSE +0 -0
weakincentives/tools/asteval.py
@@ -0,0 +1,698 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Sandboxed Python expression evaluation backed by :mod:`asteval`."""
+
+from __future__ import annotations
+
+import ast
+import builtins
+import contextlib
+import io
+import json
+import logging
+import math
+import statistics
+import sys
+import threading
+from collections.abc import Callable, Iterable, Mapping
+from dataclasses import dataclass, field
+from datetime import UTC, datetime
+from types import MappingProxyType
+from typing import TYPE_CHECKING, Final, Literal, TextIO, cast
+
+from ..prompt.markdown import MarkdownSection
+from ..prompt.tool import Tool, ToolResult
+from ..session import Session, select_latest
+from ..session.session import DataEvent
+from .errors import ToolValidationError
+from .vfs import VfsFile, VfsPath, VirtualFileSystem
+
+ExpressionMode = Literal["expr", "statements"]
+
+_logger = logging.getLogger(__name__)
+
+_MAX_CODE_LENGTH: Final[int] = 2_000
+_MAX_STREAM_LENGTH: Final[int] = 4_096
+_MAX_WRITE_LENGTH: Final[int] = 48_000
+_MAX_PATH_DEPTH: Final[int] = 16
+_MAX_SEGMENT_LENGTH: Final[int] = 80
+_ASCII: Final[str] = "ascii"
+_TIMEOUT_SECONDS: Final[float] = 5.0
+
+_SAFE_GLOBALS: Final[Mapping[str, object]] = MappingProxyType(
+    {
+        "abs": abs,
+        "len": len,
+        "min": min,
+        "max": max,
+        "print": print,
+        "range": range,
+        "round": round,
+        "sum": sum,
+        "str": str,
+        "math": math,
+        "statistics": MappingProxyType(
+            {
+                "mean": statistics.mean,
+                "median": statistics.median,
+                "pstdev": statistics.pstdev,
+                "stdev": statistics.stdev,
+                "variance": statistics.variance,
+            }
+        ),
+        "PI": math.pi,
+        "TAU": math.tau,
+        "E": math.e,
+    }
+)
+
+_EVAL_TEMPLATE: Final[str] = (
+    "Use the Python evaluation tool for quick calculations and one-off scripts.\n"
+    "- Keep code concise (<=2,000 characters) and prefer expression mode unless you need statements.\n"
+    "- Pre-load files via `reads`, or call `read_text(path)` inside code to fetch VFS files.\n"
+    "- Stage edits with `write_text(path, content, mode)` or declare them in `writes`. Content must be ASCII.\n"
+    "- Globals accept JSON-encoded strings and are parsed before execution.\n"
+    "- Execution stops after five seconds; design code to finish quickly."
+)
+
+
+@dataclass(slots=True, frozen=True)
+class EvalFileRead:
+    path: VfsPath
+
+
+@dataclass(slots=True, frozen=True)
+class EvalFileWrite:
+    path: VfsPath
+    content: str
+    mode: Literal["create", "overwrite", "append"] = "create"
+
+
+@dataclass(slots=True, frozen=True)
+class EvalParams:
+    code: str
+    mode: ExpressionMode = "expr"
+    globals: dict[str, str] = field(default_factory=dict)
+    reads: tuple[EvalFileRead, ...] = field(default_factory=tuple)
+    writes: tuple[EvalFileWrite, ...] = field(default_factory=tuple)
+
+
+@dataclass(slots=True, frozen=True)
+class EvalResult:
+    value_repr: str | None
+    stdout: str
+    stderr: str
+    globals: dict[str, str]
+    reads: tuple[EvalFileRead, ...]
+    writes: tuple[EvalFileWrite, ...]
+
+
+@dataclass(slots=True, frozen=True)
+class _AstevalSectionParams:
+    """Placeholder params container for the asteval section."""
+
+    pass
+
+
+def _now() -> datetime:
+    value = datetime.now(UTC)
+    microsecond = value.microsecond - value.microsecond % 1000
+    return value.replace(microsecond=microsecond, tzinfo=UTC)
+
+
+def _truncate_stream(text: str) -> str:
+    if len(text) <= _MAX_STREAM_LENGTH:
+        return text
+    suffix = "..."
+    keep = _MAX_STREAM_LENGTH - len(suffix)
+    return f"{text[:keep]}{suffix}"
+
+
+def _format_preview(value: str | None, *, empty: str, limit: int = 160) -> str:
+    if not value:
+        return empty
+    normalized = value.replace("\n", "\\n")
+    if len(normalized) <= limit:
+        return normalized
+    return f"{normalized[: limit - 3]}..."
+
+
+def _extract_error_reason(stderr: str) -> str | None:
+    text = stderr.strip()
+    if not text:  # pragma: no cover - defensive fallback
+        return None
+    if ": " in text:
+        candidate = text.split(": ")[-1].strip()
+        if candidate:
+            trimmed = candidate.rstrip(").")
+            return trimmed or candidate
+    return text
+
+
+def _ensure_ascii(value: str, label: str) -> None:
+    try:
+        value.encode(_ASCII)
+    except UnicodeEncodeError as error:  # pragma: no cover - defensive guard
+        raise ToolValidationError(f"{label} must be ASCII text.") from error
+
+
+def _normalize_segments(raw_segments: Iterable[str]) -> tuple[str, ...]:
+    segments: list[str] = []
+    for raw_segment in raw_segments:
+        stripped = raw_segment.strip()
+        if not stripped:
+            continue
+        if stripped.startswith("/"):
+            raise ToolValidationError("Absolute paths are not allowed in the VFS.")
+        for piece in stripped.split("/"):
+            if not piece:
+                continue
+            if piece in {".", ".."}:
+                raise ToolValidationError("Path segments may not include '.' or '..'.")
+            _ensure_ascii(piece, "path segment")
+            if len(piece) > _MAX_SEGMENT_LENGTH:
+                raise ToolValidationError(
+                    "Path segments must be 80 characters or fewer."
+                )
+            segments.append(piece)
+    if len(segments) > _MAX_PATH_DEPTH:
+        raise ToolValidationError("Path depth exceeds the allowed limit (16 segments).")
+    return tuple(segments)
+
+
+def _normalize_vfs_path(path: VfsPath) -> VfsPath:
+    return VfsPath(_normalize_segments(path.segments))
+
+
+def _require_file(snapshot: VirtualFileSystem, path: VfsPath) -> VfsFile:
+    normalized = _normalize_vfs_path(path)
+    for file in snapshot.files:
+        if file.path.segments == normalized.segments:
+            return file
+    raise ToolValidationError("File does not exist in the virtual filesystem.")
+
+
+def _normalize_code(code: str) -> str:
+    if len(code) > _MAX_CODE_LENGTH:
+        raise ToolValidationError("Code exceeds maximum length of 2,000 characters.")
+    for char in code:
+        code_point = ord(char)
+        if code_point < 32 and char not in {"\n", "\t"}:
+            raise ToolValidationError("Code contains unsupported control characters.")
+    return code
+
+
+def _normalize_write(write: EvalFileWrite) -> EvalFileWrite:
+    path = _normalize_vfs_path(write.path)
+    content = write.content
+    _ensure_ascii(content, "write content")
+    if len(content) > _MAX_WRITE_LENGTH:
+        raise ToolValidationError(
+            "Content exceeds maximum length of 48,000 characters."
+        )
+    mode = write.mode
+    if mode not in {"create", "overwrite", "append"}:
+        raise ToolValidationError("Unsupported write mode requested.")
+    return EvalFileWrite(path=path, content=content, mode=mode)
+
+
+def _normalize_reads(reads: Iterable[EvalFileRead]) -> tuple[EvalFileRead, ...]:
+    normalized: list[EvalFileRead] = []
+    seen: set[tuple[str, ...]] = set()
+    for read in reads:
+        path = _normalize_vfs_path(read.path)
+        key = path.segments
+        if key in seen:
+            raise ToolValidationError("Duplicate read targets detected.")
+        seen.add(key)
+        normalized.append(EvalFileRead(path=path))
+    return tuple(normalized)
+
+
+def _normalize_writes(writes: Iterable[EvalFileWrite]) -> tuple[EvalFileWrite, ...]:
+    normalized: list[EvalFileWrite] = []
+    seen: set[tuple[str, ...]] = set()
+    for write in writes:
+        normalized_write = _normalize_write(write)
+        key = normalized_write.path.segments
+        if key in seen:
+            raise ToolValidationError("Duplicate write targets detected.")
+        seen.add(key)
+        normalized.append(normalized_write)
+    return tuple(normalized)
+
+
+def _alias_for_path(path: VfsPath) -> str:
+    return "/".join(path.segments)
+
+
+def _format_value(value: object) -> str:
+    if isinstance(value, str):
+        return value
+    if isinstance(value, (int, float)) and not isinstance(value, bool):
+        return json.dumps(value)
+    if isinstance(value, bool) or value is None:
+        return json.dumps(value)
+    return f"!repr:{value!r}"
+
+
+def _merge_globals(
+    initial: Mapping[str, object], updates: Mapping[str, object]
+) -> dict[str, object]:
+    merged = dict(initial)
+    merged.update(updates)
+    return merged
+
+
+def _apply_writes(
+    snapshot: VirtualFileSystem, writes: Iterable[EvalFileWrite]
+) -> VirtualFileSystem:
+    files = list(snapshot.files)
+    timestamp = _now()
+    for write in writes:
+        existing_index = next(
+            (index for index, file in enumerate(files) if file.path == write.path),
+            None,
+        )
+        existing = files[existing_index] if existing_index is not None else None
+        if write.mode == "create" and existing is not None:
+            raise ToolValidationError("File already exists; use overwrite or append.")
+        if write.mode in {"overwrite", "append"} and existing is None:
+            raise ToolValidationError("File does not exist for the requested mode.")
+        if write.mode == "append" and existing is not None:
+            content = existing.content + write.content
+            created_at = existing.created_at
+            version = existing.version + 1
+        elif existing is not None:
+            content = write.content
+            created_at = existing.created_at
+            version = existing.version + 1
+        else:
+            content = write.content
+            created_at = timestamp
+            version = 1
+        size_bytes = len(content.encode("utf-8"))
+        updated = VfsFile(
+            path=write.path,
+            content=content,
+            encoding="utf-8",
+            size_bytes=size_bytes,
+            version=version,
+            created_at=created_at,
+            updated_at=timestamp,
+        )
+        if existing_index is not None:
+            files.pop(existing_index)
+        files.append(updated)
+    files.sort(key=lambda file: file.path.segments)
+    return VirtualFileSystem(files=tuple(files))
+
+
+def _parse_string_path(path: str) -> VfsPath:
+    if not path.strip():
+        raise ToolValidationError("Path must be non-empty.")
+    return VfsPath(_normalize_segments((path,)))
+
+
+def _build_eval_globals(
+    snapshot: VirtualFileSystem, reads: tuple[EvalFileRead, ...]
+) -> dict[str, str]:
+    values: dict[str, str] = {}
+    for read in reads:
+        alias = _alias_for_path(read.path)
+        file = _require_file(snapshot, read.path)
+        values[alias] = file.content
+    return values
+
+
+def _parse_user_globals(payload: Mapping[str, str]) -> dict[str, object]:
+    parsed: dict[str, object] = {}
+    for name, encoded in payload.items():
+        identifier = name.strip()
+        if not identifier:
+            raise ToolValidationError("Global variable names must be non-empty.")
+        if not identifier.isidentifier():
+            raise ToolValidationError(f"Invalid global variable name '{identifier}'.")
+        try:
+            parsed_value = json.loads(encoded)
+        except json.JSONDecodeError as error:
+            raise ToolValidationError(
+                f"Invalid JSON for global '{identifier}'."
+            ) from error
+        parsed[identifier] = parsed_value
+    return parsed
+
+
+if TYPE_CHECKING:
+    from asteval import Interpreter
+
+
+def _sanitize_interpreter(interpreter: Interpreter) -> None:
+    try:
+        import asteval  # type: ignore
+    except ModuleNotFoundError as error:  # pragma: no cover - configuration guard
+        raise RuntimeError("asteval dependency is not installed.") from error
+
+    for name in getattr(asteval, "ALL_DISALLOWED", ()):  # pragma: no cover - defensive
+        interpreter.symtable.pop(name, None)
+    node_handlers = getattr(interpreter, "node_handlers", None)
+    if isinstance(node_handlers, dict):
+        for key in ("Eval", "Exec", "Import", "ImportFrom"):
+            node_handlers.pop(key, None)
+
+
+def _create_interpreter() -> Interpreter:
+    try:
+        from asteval import Interpreter  # type: ignore
+    except ModuleNotFoundError as error:  # pragma: no cover - configuration guard
+        raise RuntimeError("asteval dependency is not installed.") from error
+
+    interpreter = Interpreter(use_numpy=False, minimal=True)
+    interpreter.symtable = dict(_SAFE_GLOBALS)
+    _sanitize_interpreter(interpreter)
+    return interpreter
+
+
+def _execute_with_timeout(
+    func: Callable[[], object],
+) -> tuple[bool, object | None, str]:
+    if sys.platform != "win32":  # pragma: no branch - platform check
+        import signal
+
+        timed_out = False
+
+        def handler(signum: int, frame: object | None) -> None:  # noqa: ARG001
+            nonlocal timed_out
+            timed_out = True
+            raise TimeoutError
+
+        previous = signal.signal(signal.SIGALRM, handler)
+        signal.setitimer(signal.ITIMER_REAL, _TIMEOUT_SECONDS)
+        try:
+            value = func()
+            return False, value, ""
+        except TimeoutError:
+            return True, None, "Execution timed out."
+        finally:
+            signal.setitimer(signal.ITIMER_REAL, 0)
+            signal.signal(signal.SIGALRM, previous)
+    timeout_message = "Execution timed out."
+    result_container: dict[str, object | None] = {}
+    error_container: dict[str, str] = {"message": ""}
+    completed = threading.Event()
+
+    def runner() -> None:
+        try:
+            result_container["value"] = func()
+        except TimeoutError:
+            error_container["message"] = timeout_message
+        except Exception as error:  # pragma: no cover - forwarded later
+            result_container["error"] = error
+        finally:
+            completed.set()
+
+    thread = threading.Thread(target=runner, daemon=True)
+    thread.start()
+    completed.wait(_TIMEOUT_SECONDS)
+    if not completed.is_set():
+        return True, None, timeout_message
+    if "error" in result_container:
+        error = cast(Exception, result_container["error"])
+        raise error
+    return False, result_container.get("value"), error_container["message"]
+
+
+class _AstevalToolSuite:
+    def __init__(self, *, session: Session) -> None:
+        self._session = session
+
+    def run(self, params: EvalParams) -> ToolResult[EvalResult]:
+        code = _normalize_code(params.code)
+        mode = params.mode
+        if mode not in {"expr", "statements"}:
+            raise ToolValidationError("Unsupported evaluation mode.")
+        reads = _normalize_reads(params.reads)
+        writes = _normalize_writes(params.writes)
+        read_paths = {read.path.segments for read in reads}
+        write_paths = {write.path.segments for write in writes}
+        if read_paths & write_paths:
+            raise ToolValidationError("Reads and writes must not target the same path.")
+
+        snapshot = (
+            select_latest(self._session, VirtualFileSystem) or VirtualFileSystem()
+        )
+        read_globals = _build_eval_globals(snapshot, reads)
+        user_globals = _parse_user_globals(params.globals)
+
+        interpreter = _create_interpreter()
+        stdout_buffer = io.StringIO()
+        stderr_buffer = io.StringIO()
+        write_queue: list[EvalFileWrite] = list(writes)
+        helper_writes: list[EvalFileWrite] = []
+        write_targets = {write.path.segments for write in write_queue}
+        builtin_print = builtins.print
+
+        def sandbox_print(
+            *args: object,
+            sep: str | None = " ",
+            end: str | None = "\n",
+            file: TextIO | None = None,
+            flush: bool = False,
+        ) -> None:
+            if file is not None:  # pragma: no cover - requires custom injected writer
+                builtin_print(*args, sep=sep, end=end, file=file, flush=flush)
+                return
+            actual_sep = " " if sep is None else sep
+            actual_end = "\n" if end is None else end
+            if not isinstance(actual_sep, str):
+                raise TypeError("sep must be None or a string.")
+            if not isinstance(actual_end, str):
+                raise TypeError("end must be None or a string.")
+            text = actual_sep.join(str(arg) for arg in args)
+            stdout_buffer.write(text)
+            stdout_buffer.write(actual_end)
+            if flush:
+                stdout_buffer.flush()
+
+        if mode == "expr":
+            try:
+                ast.parse(code, mode="eval")
+            except SyntaxError as error:
+                raise ToolValidationError(
+                    "Expression mode requires a single expression."
+                ) from error
+
+        def read_text(path: str) -> str:
+            normalized = _normalize_vfs_path(_parse_string_path(path))
+            file = _require_file(snapshot, normalized)
+            return file.content
+
+        def write_text(path: str, content: str, mode: str = "create") -> None:
+            normalized_path = _normalize_vfs_path(_parse_string_path(path))
+            helper_write = _normalize_write(
+                EvalFileWrite(
+                    path=normalized_path,
+                    content=content,
+                    mode=cast(Literal["create", "overwrite", "append"], mode),
+                )
+            )
+            key = helper_write.path.segments
+            if key in read_paths:
+                raise ToolValidationError(
+                    "Writes queued during execution must not target read paths."
+                )
+            if key in write_targets:
+                raise ToolValidationError("Duplicate write targets detected.")
+            write_targets.add(key)
+            helper_writes.append(helper_write)
+
+        symtable = interpreter.symtable
+        symtable.update(user_globals)
+        symtable["vfs_reads"] = dict(read_globals)
+        symtable["read_text"] = read_text
+        symtable["write_text"] = write_text
+        symtable["print"] = sandbox_print
+
+        all_keys = set(symtable)
+        captured_errors: list[str] = []
+        value_repr: str | None = None
+        stderr_text = ""
+        try:
+            with (
+                contextlib.redirect_stdout(stdout_buffer),
+                contextlib.redirect_stderr(stderr_buffer),
+            ):
+                interpreter.error = []
+
+                def runner() -> object:
+                    return interpreter.eval(code)
+
+                timed_out, result, timeout_error = _execute_with_timeout(runner)
+                if timed_out:
+                    stderr_text = timeout_error
+                elif interpreter.error:
+                    captured_errors.extend(str(err) for err in interpreter.error)
+                if not timed_out and not captured_errors and not stderr_text:
+                    value_repr = None if result is None else repr(result)
+        except ToolValidationError:  # pragma: no cover - interpreter wraps tool errors
+            raise
+        except Exception as error:  # pragma: no cover - runtime exception
+            captured_errors.append(str(error))
+        stdout = _truncate_stream(stdout_buffer.getvalue())
+        stderr_raw = (
+            stderr_text or "\n".join(captured_errors) or stderr_buffer.getvalue()
+        )
+        stderr = _truncate_stream(stderr_raw)
+
+        param_writes = tuple(write_queue)
+        pending_writes = bool(write_queue or helper_writes)
+        value_preview = _format_preview(value_repr, empty="none")
+        stdout_preview = _format_preview(stdout, empty="empty")
+        stderr_preview = _format_preview(stderr, empty="empty")
+        if stderr and not value_repr:
+            final_writes: tuple[EvalFileWrite, ...] = ()
+            writes_summary = "discarded" if pending_writes else "none"
+            error_reason = _format_preview(
+                _extract_error_reason(stderr), empty="unknown"
+            )
+            message = (
+                "Evaluation failed. "
+                f"value={value_preview}; stdout={stdout_preview}; "
+                f"stderr={stderr_preview}; error={error_reason}; "
+                f"writes={writes_summary}."
+            )
+        else:
+            format_context = {
+                key: value for key, value in symtable.items() if not key.startswith("_")
+            }
+            resolved_param_writes: list[EvalFileWrite] = []
+            for write in param_writes:
+                try:
+                    resolved_content = write.content.format_map(format_context)
+                except KeyError as error:
+                    missing = error.args[0]
+                    raise ToolValidationError(
+                        f"Missing template variable '{missing}' in write request."
+                    ) from error
+                resolved_param_writes.append(
+                    _normalize_write(
+                        EvalFileWrite(
+                            path=write.path,
+                            content=resolved_content,
+                            mode=write.mode,
+                        )
+                    )
+                )
+            final_writes = tuple(resolved_param_writes + helper_writes)
+            seen_targets: set[tuple[str, ...]] = set()
+            for write in final_writes:
+                key = write.path.segments
+                if key in seen_targets:
+                    raise ToolValidationError(
+                        "Duplicate write targets detected."
+                    )  # pragma: no cover - upstream checks prevent duplicates
+                seen_targets.add(key)
+            if final_writes:
+                aliases = ["/".join(write.path.segments) for write in final_writes[:3]]
+                if len(final_writes) > 3:
+                    aliases.append(f"+{len(final_writes) - 3} more")
+                writes_summary = f"{len(final_writes)} file(s): {', '.join(aliases)}"
+            else:
+                writes_summary = "none"
+            message = (
+                "Evaluation succeeded. "
+                f"value={value_preview}; stdout={stdout_preview}; "
+                f"stderr={stderr_preview}; writes={writes_summary}."
+            )
+
+        globals_payload: dict[str, str] = {}
+        visible_keys = {
+            key for key in symtable if key not in all_keys and not key.startswith("_")
+        }
+        visible_keys.update(user_globals.keys())
+        for key in visible_keys:
+            globals_payload[key] = _format_value(symtable.get(key))
+        globals_payload.update(
+            {f"vfs:{alias}": content for alias, content in read_globals.items()}
+        )
+
+        result = EvalResult(
+            value_repr=value_repr,
+            stdout=stdout,
+            stderr=stderr,
+            globals=globals_payload,
+            reads=reads,
+            writes=final_writes,
+        )
+
+        _logger.debug(
+            "asteval.run",
+            extra={
+                "event": "asteval.run",
+                "mode": mode,
+                "stdout_len": len(stdout),
+                "stderr_len": len(stderr),
+                "write_count": len(final_writes),
+                "code_preview": code[:200],
+            },
+        )
+
+        return ToolResult(message=message, value=result)
+
+
+def _make_eval_result_reducer() -> Callable[
+    [tuple[VirtualFileSystem, ...], DataEvent], tuple[VirtualFileSystem, ...]
+]:
+    def reducer(
+        slice_values: tuple[VirtualFileSystem, ...], event: DataEvent
+    ) -> tuple[VirtualFileSystem, ...]:
+        previous = slice_values[-1] if slice_values else VirtualFileSystem()
+        value = cast(EvalResult, event.value)
+        if not value.writes:
+            return (previous,)
+        snapshot = _apply_writes(previous, value.writes)
+        return (snapshot,)
+
+    return reducer
+
+
+class AstevalSection(MarkdownSection[_AstevalSectionParams]):
+    """Prompt section exposing the :mod:`asteval` evaluation tool."""
+
+    def __init__(self, *, session: Session) -> None:
+        self._session = session
+        session.register_reducer(
+            EvalResult, _make_eval_result_reducer(), slice_type=VirtualFileSystem
+        )
+        tool_suite = _AstevalToolSuite(session=session)
+        tool = Tool[EvalParams, EvalResult](
+            name="evaluate_python",
+            description="Evaluate a short Python expression in a sandboxed environment with optional VFS access.",
+            handler=tool_suite.run,
+        )
+        super().__init__(
+            title="Python Evaluation Tool",
+            key="tools.asteval",
+            template=_EVAL_TEMPLATE,
+            default_params=_AstevalSectionParams(),
+            tools=(tool,),
+        )
+
+
+__all__ = [
+    "AstevalSection",
+    "EvalFileRead",
+    "EvalFileWrite",
+    "EvalParams",
+    "EvalResult",
+]
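
Of note in the new module is _execute_with_timeout, which enforces the five-second budget with a SIGALRM interval timer on POSIX and falls back to a daemon thread plus a deadline wait elsewhere. The following is a self-contained sketch of that dual pattern; the names run_with_timeout and TIMEOUT_SECONDS are illustrative, not part of the package.

import signal
import sys
import threading

TIMEOUT_SECONDS = 5.0  # mirrors _TIMEOUT_SECONDS in the diff

def run_with_timeout(func):
    """Return (timed_out, value); a sketch of the diff's dual strategy."""
    if sys.platform != "win32":
        def handler(signum, frame):
            raise TimeoutError
        previous = signal.signal(signal.SIGALRM, handler)
        signal.setitimer(signal.ITIMER_REAL, TIMEOUT_SECONDS)
        try:
            return False, func()
        except TimeoutError:
            return True, None
        finally:
            signal.setitimer(signal.ITIMER_REAL, 0)  # cancel the timer
            signal.signal(signal.SIGALRM, previous)  # restore the old handler
    # Fallback (e.g. Windows): run func on a daemon thread, wait with a deadline.
    box = {}
    done = threading.Event()
    def runner():
        try:
            box["value"] = func()
        finally:
            done.set()
    threading.Thread(target=runner, daemon=True).start()
    done.wait(TIMEOUT_SECONDS)
    if not done.is_set():
        return True, None  # the thread keeps running; it cannot be killed
    return False, box.get("value")

print(run_with_timeout(lambda: sum(range(10))))  # (False, 45)

Neither strategy can truly kill runaway code: the alarm only interrupts pure-Python execution in the main thread, and the thread fallback merely abandons the work while reporting a timeout, which is why the tool also caps code length and urges fast-finishing snippets.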
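For a sense of the tool's surface, here is a hypothetical invocation sketch based on the dataclasses above. It assumes an already-constructed Session (its constructor is not shown in this diff) and calls the private _AstevalToolSuite directly for illustration; in normal use the evaluate_python tool is registered by instantiating AstevalSection(session=...).

from weakincentives.tools.asteval import EvalParams, _AstevalToolSuite

# `session` is assumed to exist; the Session API is defined elsewhere in the package.
suite = _AstevalToolSuite(session=session)

params = EvalParams(
    # `statistics` is exposed as a mapping of functions, so it is subscripted
    # rather than attribute-accessed; `scale` arrives via `globals` as JSON.
    code="statistics['mean']([1, 2, 3]) * scale",
    mode="expr",
    globals={"scale": "2.0"},
)
result = suite.run(params)
print(result.message)
# e.g. "Evaluation succeeded. value=4.0; stdout=empty; stderr=empty; writes=none."

The returned ToolResult carries an EvalResult whose writes, if any, are folded into the session's VirtualFileSystem slice by the reducer registered in AstevalSection.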