pydantic-ai-rlm 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydantic_ai_rlm/__init__.py +32 -0
- pydantic_ai_rlm/agent.py +161 -0
- pydantic_ai_rlm/dependencies.py +47 -0
- pydantic_ai_rlm/logging.py +274 -0
- pydantic_ai_rlm/prompts.py +118 -0
- pydantic_ai_rlm/py.typed +0 -0
- pydantic_ai_rlm/repl.py +481 -0
- pydantic_ai_rlm/toolset.py +168 -0
- pydantic_ai_rlm/utils.py +47 -0
- pydantic_ai_rlm-0.1.0.dist-info/METADATA +344 -0
- pydantic_ai_rlm-0.1.0.dist-info/RECORD +13 -0
- pydantic_ai_rlm-0.1.0.dist-info/WHEEL +4 -0
- pydantic_ai_rlm-0.1.0.dist-info/licenses/LICENSE +21 -0
pydantic_ai_rlm/repl.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import contextlib
|
|
4
|
+
import io
|
|
5
|
+
import json
|
|
6
|
+
import os
|
|
7
|
+
import shutil
|
|
8
|
+
import sys
|
|
9
|
+
import tempfile
|
|
10
|
+
import threading
|
|
11
|
+
import time
|
|
12
|
+
from contextlib import contextmanager
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from typing import Any, ClassVar
|
|
15
|
+
|
|
16
|
+
from pydantic_ai import ModelRequest
|
|
17
|
+
from pydantic_ai.direct import model_request_sync
|
|
18
|
+
from pydantic_ai.messages import TextPart
|
|
19
|
+
|
|
20
|
+
from .dependencies import ContextType, RLMConfig
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class REPLResult:
    """Outcome of one code execution inside the REPL sandbox."""

    stdout: str  # captured standard output
    stderr: str  # captured standard error
    locals: dict[str, Any]  # user-visible variables after the run
    execution_time: float  # wall-clock duration of the run, in seconds
    success: bool = True  # False when the executed code raised

    def __str__(self) -> str:
        # Preview only the first 100 characters of each stream to keep the
        # summary short.
        out_head = self.stdout[:100]
        err_head = self.stderr[:100]
        return "REPLResult(success={}, stdout={}..., stderr={}...)".format(
            self.success, out_head, err_head
        )
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class REPLEnvironment:
    """
    Sandboxed Python execution environment for RLM.

    Provides a safe environment where the LLM can execute Python code
    to analyze large contexts. The context is loaded as a variable
    accessible within the REPL.

    Key features:
    - Sandboxed execution with restricted built-ins
    - Persistent state across multiple executions
    - Stdout/stderr capture
    - Configurable security settings

    NOTE(review): SAFE_BUILTINS exposes the real ``__import__``, so executed
    code can import any installed module (including os/sys/subprocess). The
    builtin filtering guards against accidents; it is not a hard security
    boundary — confirm this matches the intended threat model.
    """

    # Safe built-ins that don't allow dangerous operations
    SAFE_BUILTINS: ClassVar[dict[str, Any]] = {
        # Core types
        "print": print,
        "len": len,
        "str": str,
        "int": int,
        "float": float,
        "bool": bool,
        "list": list,
        "dict": dict,
        "set": set,
        "tuple": tuple,
        "type": type,
        "isinstance": isinstance,
        "issubclass": issubclass,
        # Iteration
        "range": range,
        "enumerate": enumerate,
        "zip": zip,
        "map": map,
        "filter": filter,
        "sorted": sorted,
        "reversed": reversed,
        "iter": iter,
        "next": next,
        # Math
        "min": min,
        "max": max,
        "sum": sum,
        "abs": abs,
        "round": round,
        "pow": pow,
        "divmod": divmod,
        # String/char
        "chr": chr,
        "ord": ord,
        "hex": hex,
        "bin": bin,
        "oct": oct,
        "repr": repr,
        "ascii": ascii,
        "format": format,
        # Collections
        "any": any,
        "all": all,
        "slice": slice,
        "hash": hash,
        "id": id,
        "callable": callable,
        # Attribute access
        "hasattr": hasattr,
        "getattr": getattr,
        "setattr": setattr,
        "delattr": delattr,
        "dir": dir,
        "vars": vars,
        # Binary
        "bytes": bytes,
        "bytearray": bytearray,
        "memoryview": memoryview,
        "complex": complex,
        # OOP
        "super": super,
        "property": property,
        "staticmethod": staticmethod,
        "classmethod": classmethod,
        "object": object,
        # Exceptions (for try/except blocks)
        "Exception": Exception,
        "ValueError": ValueError,
        "TypeError": TypeError,
        "KeyError": KeyError,
        "IndexError": IndexError,
        "AttributeError": AttributeError,
        "RuntimeError": RuntimeError,
        "StopIteration": StopIteration,
        "AssertionError": AssertionError,
        "NotImplementedError": NotImplementedError,
        # Allow imports (controlled)
        # NOTE(review): this is the unrestricted builtin __import__ — "controlled"
        # here means "explicitly granted", not "filtered".
        "__import__": __import__,
    }

    # Additional built-ins when file access is enabled
    FILE_ACCESS_BUILTINS: ClassVar[dict[str, Any]] = {
        "open": open,
        "FileNotFoundError": FileNotFoundError,
        "OSError": OSError,
        "IOError": IOError,
    }

    # Built-ins that are always blocked.
    # Values are None (not deletions): calling e.g. eval() inside the REPL
    # raises "TypeError: 'NoneType' object is not callable".
    BLOCKED_BUILTINS: ClassVar[dict[str, None]] = {
        "eval": None,
        "exec": None,
        "compile": None,
        "globals": None,
        "locals": None,
        "input": None,
        "__builtins__": None,
    }

    def __init__(
        self,
        context: ContextType,
        config: RLMConfig | None = None,
    ):
        """
        Initialize the REPL environment.

        Args:
            context: The context data to make available as 'context' variable
            config: Configuration options for the REPL
        """
        self.config = config or RLMConfig()
        # Recorded at construction; not referenced elsewhere in this class
        # (execute() restores the cwd via _temp_working_directory instead).
        self.original_cwd = os.getcwd()
        # Per-instance scratch directory; also the cwd during execute() and
        # where the context file is written. Removed by cleanup().
        self.temp_dir = tempfile.mkdtemp(prefix="rlm_repl_")
        # Serializes execute() calls; required because stdout/stderr and the
        # process cwd are swapped globally during execution.
        self._lock = threading.Lock()
        # User variables that persist across execute() calls.
        self.locals: dict[str, Any] = {}

        # Setup globals with safe built-ins
        self.globals: dict[str, Any] = {
            "__builtins__": self._create_builtins(),
        }

        # llm_query() is only available when a sub-model is configured.
        if self.config.sub_model:
            self._setup_llm_query()

        # Load context into environment
        self._load_context(context)

    def _create_builtins(self) -> dict[str, Any]:
        """Create the built-ins dict based on config.

        Layering: SAFE_BUILTINS, then file-access builtins, then the blocked
        entries last so a block always wins over an earlier grant.
        """
        builtins = dict(self.SAFE_BUILTINS)

        # Always include file access builtins - needed for context loading
        # and generally useful for data analysis. The temp directory
        # provides sandboxing.
        # NOTE(review): open() is not restricted to temp_dir — absolute paths
        # still reach the whole filesystem; confirm that is acceptable.
        builtins.update(self.FILE_ACCESS_BUILTINS)

        # Apply blocked builtins
        builtins.update(self.BLOCKED_BUILTINS)

        return builtins

    def _setup_llm_query(self) -> None:
        """
        Set up the llm_query function for the REPL environment.

        Installs a closure into self.globals so REPL code can call
        llm_query(prompt). The call is synchronous and blocks the executing
        thread until the sub-model responds.
        """
        # Imported lazily to avoid a module-level import cycle with .logging.
        from .logging import get_logger

        def llm_query(prompt: str) -> str:
            """
            Query a sub-LLM with the given prompt.

            This function allows you to delegate analysis tasks to another
            LLM, which is useful for processing large contexts in chunks.

            Args:
                prompt: The prompt to send to the sub-LLM

            Returns:
                The sub-LLM's response as a string
            """
            logger = get_logger()

            try:
                if not self.config.sub_model:
                    return "Error: No sub-model configured"

                # Log the query
                logger.log_llm_query(prompt)

                result = model_request_sync(
                    self.config.sub_model,
                    [ModelRequest.user_text_prompt(prompt)],
                )
                # Extract text from the response parts; non-text parts
                # (e.g. tool calls) are silently dropped.
                text_parts = [part.content for part in result.parts if isinstance(part, TextPart)]
                response = "".join(text_parts) if text_parts else ""

                # Log the response
                logger.log_llm_response(response)

                return response
            except Exception as e:
                # Errors are returned as a string rather than raised so a
                # failed sub-query does not abort the whole REPL execution.
                return f"Error querying sub-LLM: {e!s}"

        # Add llm_query to globals
        self.globals["llm_query"] = llm_query

    def _load_context(self, context: ContextType) -> None:
        """
        Load context data into the REPL environment.

        The context is written to a file and then loaded into the
        'context' variable in the REPL namespace.

        Strings become context.txt (loaded as str); anything else is
        serialized as JSON (context.json, loaded via json.load).
        NOTE(review): json.dump(default=str) silently stringifies
        non-JSON-serializable values, so round-tripping is lossy for
        custom types — confirm that is acceptable.
        """
        if isinstance(context, str):
            # Text context
            context_path = os.path.join(self.temp_dir, "context.txt")
            with open(context_path, "w", encoding="utf-8") as f:
                f.write(context)

            # NOTE(review): the path is interpolated into source code; a quote
            # character in the temp path would break this. mkdtemp() paths are
            # normally safe, but confirm on exotic TMPDIR settings.
            load_code = f"""
with open(r'{context_path}', 'r', encoding='utf-8') as f:
    context = f.read()
"""
        else:
            # JSON context (dict or list)
            context_path = os.path.join(self.temp_dir, "context.json")
            with open(context_path, "w", encoding="utf-8") as f:
                json.dump(context, f, indent=2, default=str)

            load_code = f"""
import json
with open(r'{context_path}', 'r', encoding='utf-8') as f:
    context = json.load(f)
"""

        # Execute the load code to populate 'context' variable
        self._execute_internal(load_code)

    def _execute_internal(self, code: str) -> None:
        """Execute code internally without capturing output."""
        # Run against a merged snapshot so globals stay pristine and only
        # new names are promoted into self.locals below.
        combined = {**self.globals, **self.locals}
        exec(code, combined, combined)

        # Update locals with new variables.
        # Underscore-prefixed names (including the __builtins__ key that
        # exec() injects) are deliberately skipped here — unlike execute(),
        # which copies them; see NOTE(review) there.
        for key, value in combined.items():
            if key not in self.globals and not key.startswith("_"):
                self.locals[key] = value

    @contextmanager
    def _capture_output(self):
        """Capture stdout/stderr into StringIO buffers.

        This swaps the process-global sys.stdout/sys.stderr, so it is only
        safe while self._lock serializes executions; concurrent prints from
        other threads in the process would also be captured.
        """
        old_stdout = sys.stdout
        old_stderr = sys.stderr

        stdout_buffer = io.StringIO()
        stderr_buffer = io.StringIO()

        try:
            sys.stdout = stdout_buffer
            sys.stderr = stderr_buffer
            yield stdout_buffer, stderr_buffer
        finally:
            # Always restore the real streams, even if execution raised.
            sys.stdout = old_stdout
            sys.stderr = old_stderr

    @contextmanager
    def _temp_working_directory(self):
        """Context manager to temporarily change working directory.

        Runs the body with cwd set to the per-instance temp dir so relative
        file operations in REPL code land in the sandbox directory. The cwd
        is process-global, hence also protected by self._lock.
        """
        old_cwd = os.getcwd()
        try:
            os.chdir(self.temp_dir)
            yield
        finally:
            os.chdir(old_cwd)

    def execute(self, code: str) -> REPLResult:
        """
        Execute Python code in the REPL environment.

        Args:
            code: Python code to execute

        Returns:
            REPLResult with stdout, stderr, locals, and timing
        """
        start_time = time.time()
        success = True
        stdout_content = ""
        stderr_content = ""

        with (
            self._lock,
            self._capture_output() as (stdout_buffer, stderr_buffer),
            self._temp_working_directory(),
        ):
            try:
                # Split into imports and other code so imports land in
                # self.globals and persist across executions.
                # NOTE(review): this matches ANY line whose stripped text
                # starts with "import "/"from " — including imports indented
                # inside a function body, which get hoisted out of their
                # scope and removed from the surrounding block. Confirm this
                # trade-off is intended.
                lines = code.split("\n")
                import_lines = []
                other_lines = []

                for line in lines:
                    stripped = line.strip()
                    # The startswith("#") guard is redundant after the
                    # import/from check, but harmless.
                    if stripped.startswith(("import ", "from ")) and not stripped.startswith("#"):
                        import_lines.append(line)
                    else:
                        other_lines.append(line)

                # Execute imports in globals
                if import_lines:
                    import_code = "\n".join(import_lines)
                    exec(import_code, self.globals, self.globals)

                # Execute rest of code
                if other_lines:
                    other_code = "\n".join(other_lines)
                    combined = {**self.globals, **self.locals}

                    # Try to evaluate last expression for display
                    self._execute_with_expression_display(other_code, other_lines, combined)

                    # Update locals.
                    # NOTE(review): unlike _execute_internal, underscore-
                    # prefixed names are NOT filtered here, so keys such as
                    # the exec-injected __builtins__ can leak into
                    # self.locals — confirm whether that is intended.
                    for key, value in combined.items():
                        if key not in self.globals:
                            self.locals[key] = value

                stdout_content = stdout_buffer.getvalue()
                stderr_content = stderr_buffer.getvalue()

            except Exception as e:
                # Execution errors are reported through the result rather
                # than raised, so the caller always gets a REPLResult.
                success = False
                stderr_content = stderr_buffer.getvalue() + f"\nError: {e!s}"
                stdout_content = stdout_buffer.getvalue()

        execution_time = time.time() - start_time

        # Truncate output if needed
        max_chars = self.config.truncate_output_chars
        if len(stdout_content) > max_chars:
            stdout_content = stdout_content[:max_chars] + "\n... (output truncated)"
        if len(stderr_content) > max_chars:
            stderr_content = stderr_content[:max_chars] + "\n... (output truncated)"

        return REPLResult(
            stdout=stdout_content,
            stderr=stderr_content,
            locals=dict(self.locals),
            execution_time=execution_time,
            success=success,
        )

    def _execute_with_expression_display(
        self,
        code: str,
        lines: list[str],
        namespace: dict[str, Any],
    ) -> None:
        """
        Execute code, displaying the last expression's value if applicable.

        This mimics notebook/REPL behavior where the last expression is
        automatically displayed.

        The "is this an expression?" test below is a line-based heuristic,
        not a parse: it can misclassify (see notes inline), but misfires are
        recovered by the SyntaxError/NameError fallback to plain exec().
        """
        # Find non-comment, non-empty lines
        non_comment_lines = [line for line in lines if line.strip() and not line.strip().startswith("#")]

        if not non_comment_lines:
            exec(code, namespace, namespace)
            return

        last_line = non_comment_lines[-1].strip()

        # Check if last line is an expression (not a statement)
        is_expression = (
            not last_line.startswith(
                (
                    "import ",
                    "from ",
                    "def ",
                    "class ",
                    "if ",
                    "for ",
                    "while ",
                    "try:",
                    "with ",
                    "return ",
                    "yield ",
                    "raise ",
                    "break",
                    "continue",
                    "pass",
                    "assert ",
                    "del ",
                    "global ",
                    "nonlocal ",
                )
            )
            # NOTE(review): this also rejects comparisons ("a == b") and
            # keyword arguments ("f(x=1)") as "assignments", so those fall
            # through to plain exec without display — confirm acceptable.
            and "=" not in last_line.split("#")[0]  # Not assignment
            and not last_line.endswith(":")  # Not control structure
            and not last_line.startswith("print(")  # Not explicit print
        )

        if is_expression and len(non_comment_lines) > 0:
            try:
                # Execute all but last line
                if len(non_comment_lines) > 1:
                    # Find where last line starts (first line whose stripped
                    # text matches — duplicates resolve to the earliest one).
                    last_line_idx = None
                    for i, line in enumerate(lines):
                        if line.strip() == last_line:
                            last_line_idx = i
                            break

                    # Index 0 and "not found" are treated the same here
                    # (falsy); in both cases there is nothing to pre-execute.
                    if last_line_idx and last_line_idx > 0:
                        statements = "\n".join(lines[:last_line_idx])
                        exec(statements, namespace, namespace)

                # Evaluate and print last expression; None is suppressed,
                # matching interactive-REPL behavior.
                result = eval(last_line, namespace, namespace)
                if result is not None:
                    print(repr(result))

            except (SyntaxError, NameError):
                # Fall back to normal execution
                exec(code, namespace, namespace)
        else:
            exec(code, namespace, namespace)

    def cleanup(self) -> None:
        """Clean up temporary directory.

        Best-effort: any failure (already removed, permissions, interpreter
        shutdown) is suppressed so cleanup never raises.
        """
        with contextlib.suppress(Exception):
            shutil.rmtree(self.temp_dir)

    def __del__(self):
        """Destructor to ensure cleanup."""
        self.cleanup()
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
|
|
5
|
+
from pydantic_ai import RunContext
|
|
6
|
+
from pydantic_ai.toolsets import FunctionToolset
|
|
7
|
+
|
|
8
|
+
from .dependencies import RLMConfig, RLMDependencies
|
|
9
|
+
from .logging import get_logger
|
|
10
|
+
from .repl import REPLEnvironment, REPLResult
|
|
11
|
+
from .utils import format_repl_result
|
|
12
|
+
|
|
13
|
+
EXECUTE_CODE_DESCRIPTION = """
|
|
14
|
+
Execute Python code in a sandboxed REPL environment.
|
|
15
|
+
|
|
16
|
+
## Environment
|
|
17
|
+
- A `context` variable is pre-loaded with the data to analyze
|
|
18
|
+
- Variables persist between executions within the same session
|
|
19
|
+
- Standard library modules are available (json, re, collections, etc.)
|
|
20
|
+
- Use print() to display output
|
|
21
|
+
|
|
22
|
+
## When to Use
|
|
23
|
+
- Analyzing or processing structured data (JSON, dicts, lists)
|
|
24
|
+
- Performing calculations or data transformations
|
|
25
|
+
- Extracting specific information from large datasets
|
|
26
|
+
- Testing hypotheses about the data structure
|
|
27
|
+
|
|
28
|
+
## Best Practices
|
|
29
|
+
1. Start by exploring the context: `print(type(context))`, `print(len(context))`
|
|
30
|
+
2. Break complex operations into smaller steps
|
|
31
|
+
3. Use print() liberally to understand intermediate results
|
|
32
|
+
4. Handle potential errors gracefully with try/except
|
|
33
|
+
|
|
34
|
+
## Available Functions
|
|
35
|
+
- `llm_query(prompt)`: Query the LLM for reasoning assistance (if configured)
|
|
36
|
+
- Important: Do not use `llm_query` in the first code execution. Use it only after you have
|
|
37
|
+
explored the context and identified specific sections that need semantic analysis.
|
|
38
|
+
|
|
39
|
+
## Example
|
|
40
|
+
```python
|
|
41
|
+
# Explore the data
|
|
42
|
+
print(f"Context type: {type(context)}")
|
|
43
|
+
print(f"Keys: {list(context.keys()) if isinstance(context, dict) else 'N/A'}")
|
|
44
|
+
|
|
45
|
+
# Process and extract information
|
|
46
|
+
if isinstance(context, dict):
|
|
47
|
+
for key, value in context.items():
|
|
48
|
+
print(f"{key}: {type(value)}")
|
|
49
|
+
```
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# Global registry to track REPL environments for cleanup.
# Keyed by id(ctx.deps) — see _get_or_create_repl in create_rlm_toolset.
# NOTE(review): entries live until cleanup_repl_environments() is called, and
# id() values can be recycled once a deps object is garbage collected, which
# could hand a later run a stale environment — confirm intended lifetime.
_repl_registry: dict[int, REPLEnvironment] = {}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def create_rlm_toolset(
    *,
    code_timeout: float = 60.0,
    sub_model: str | None = None,
    toolset_id: str | None = None,
) -> FunctionToolset[RLMDependencies]:
    """Create an RLM toolset for code execution in a sandboxed REPL.

    This toolset provides an `execute_code` tool that allows AI agents to
    run Python code with access to a `context` variable containing data to analyze.

    Args:
        code_timeout: Timeout in seconds for code execution. Defaults to 60.0.
        sub_model: Model to use for llm_query() within the REPL environment.
        toolset_id: Optional unique identifier for the toolset.

    Returns:
        FunctionToolset compatible with any pydantic-ai agent.

    Example (basic usage):
        ```python
        from pydantic_ai import Agent
        from pydantic_ai_rlm import create_rlm_toolset, RLMDependencies

        toolset = create_rlm_toolset()
        agent = Agent("openai:gpt-5", toolsets=[toolset])

        deps = RLMDependencies(context={"users": [...]})
        result = await agent.run("Analyze the user data", deps=deps)
        ```

    Example (with timeout and sub-model):
        ```python
        from pydantic_ai_rlm import create_rlm_toolset, RLMDependencies, RLMConfig

        toolset = create_rlm_toolset(
            code_timeout=120.0,
            sub_model="openai:gpt-5-mini",
        )
        agent = Agent("openai:gpt-5", toolsets=[toolset])

        deps = RLMDependencies(
            context=large_dataset,
            config=RLMConfig(),
        )
        result = await agent.run("Process this dataset", deps=deps)
        ```

    Example (with toolset composition):
        ```python
        from pydantic_ai_rlm import create_rlm_toolset

        rlm_toolset = create_rlm_toolset().prefixed("rlm")
        # Tool will be named 'rlm_execute_code'
        ```
    """
    toolset: FunctionToolset[RLMDependencies] = FunctionToolset(id=toolset_id)

    def _get_or_create_repl(ctx: RunContext[RLMDependencies]) -> REPLEnvironment:
        """Get or create the REPL environment cached for this run's deps object."""
        # NOTE(review): keyed by id(ctx.deps); ids can be recycled after the
        # deps object is garbage collected, which could reuse a stale REPL.
        deps_id = id(ctx.deps)

        if deps_id not in _repl_registry:
            config = ctx.deps.config or RLMConfig()
            # Override sub_model from factory if set and not already in config.
            # NOTE(review): this constructs a fresh RLMConfig, discarding any
            # other settings the caller placed on ctx.deps.config — confirm
            # whether those should be carried over instead.
            if sub_model and not config.sub_model:
                config = RLMConfig(
                    sub_model=sub_model,
                )
            _repl_registry[deps_id] = REPLEnvironment(
                context=ctx.deps.context,
                config=config,
            )

        return _repl_registry[deps_id]

    @toolset.tool(description=EXECUTE_CODE_DESCRIPTION)
    async def execute_code(ctx: RunContext[RLMDependencies], code: str) -> str:
        """Run `code` in the per-deps REPL off the event loop, with a timeout."""
        repl_env = _get_or_create_repl(ctx)
        logger = get_logger()

        # Log the code being executed
        logger.log_code_execution(code)

        try:
            # REPLEnvironment.execute is synchronous (and holds a lock), so
            # run it in the default executor to avoid blocking the event loop.
            loop = asyncio.get_running_loop()
            result: REPLResult = await asyncio.wait_for(
                loop.run_in_executor(None, repl_env.execute, code),
                timeout=code_timeout,
            )

            # Log the result
            logger.log_result(result)

            return format_repl_result(result)

        except asyncio.TimeoutError:
            # Fix: asyncio.wait_for raises asyncio.TimeoutError, which is only
            # aliased to the builtin TimeoutError on Python 3.11+. Catching the
            # asyncio name is correct on every version; catching the builtin
            # (as before) would miss timeouts on <3.11.
            # NOTE: the executor thread keeps running after the timeout; only
            # the await is abandoned.
            return f"Error: Code execution timed out after {code_timeout} seconds."
        except Exception as e:
            # Report tool-level failures as a string so the agent can react.
            return f"Error executing code: {e!s}"

    return toolset
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def cleanup_repl_environments() -> None:
    """Dispose of every REPL environment created by this module.

    Call this once all agent runs have completed so each environment's
    temporary directory is removed and the registry is emptied.
    """
    # Snapshot the values so clearing the registry afterwards is safe.
    for environment in list(_repl_registry.values()):
        environment.cleanup()
    _repl_registry.clear()
|
pydantic_ai_rlm/utils.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from .repl import REPLResult
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def format_repl_result(result: REPLResult, max_var_display: int = 200) -> str:
    """
    Format a REPL execution result for display to the LLM.

    Args:
        result: The REPLResult from code execution
        max_var_display: Maximum characters to show per variable value

    Returns:
        Formatted string suitable for LLM consumption; the literal
        "Code executed successfully (no output)" when there is no stdout,
        no stderr, and no user variables to report.
    """
    parts: list[str] = []

    if result.stdout.strip():
        parts.append(f"Output:\n{result.stdout}")

    if result.stderr.strip():
        parts.append(f"Errors:\n{result.stderr}")

    # Show created/modified variables (excluding internal ones)
    user_vars = {k: v for k, v in result.locals.items() if not k.startswith("_") and k not in ("context", "json", "re", "os")}

    if user_vars:
        var_summaries = []
        for name, value in user_vars.items():
            try:
                value_str = repr(value)
                if len(value_str) > max_var_display:
                    value_str = value_str[:max_var_display] + "..."
                var_summaries.append(f"  {name} = {value_str}")
            except Exception:
                # repr() itself can raise for exotic objects; fall back to
                # showing just the type name.
                var_summaries.append(f"  {name} = <{type(value).__name__}>")

        if var_summaries:
            parts.append("Variables:\n" + "\n".join(var_summaries))

    # Bug fix: check for emptiness BEFORE appending the execution time.
    # Previously the time was appended unconditionally first, so `parts` was
    # never empty and the "no output" message below was unreachable.
    if not parts:
        return "Code executed successfully (no output)"

    parts.append(f"Execution time: {result.execution_time:.3f}s")

    return "\n\n".join(parts)
|