pydantic-ai-rlm 0.1.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,481 @@
+ from __future__ import annotations
+
+ import contextlib
+ import io
+ import json
+ import os
+ import shutil
+ import sys
+ import tempfile
+ import threading
+ import time
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from typing import Any, ClassVar
+
+ from pydantic_ai import ModelRequest
+ from pydantic_ai.direct import model_request_sync
+ from pydantic_ai.messages import TextPart
+
+ from .dependencies import ContextType, RLMConfig
+
+
+ @dataclass
+ class REPLResult:
+     """Result from REPL code execution."""
+
+     stdout: str
+     """Standard output from execution."""
+
+     stderr: str
+     """Standard error from execution."""
+
+     locals: dict[str, Any]
+     """Local variables after execution."""
+
+     execution_time: float
+     """Time taken to execute in seconds."""
+
+     success: bool = True
+     """Whether execution completed without errors."""
+
+     def __str__(self) -> str:
+         return f"REPLResult(success={self.success}, stdout={self.stdout[:100]}..., stderr={self.stderr[:100]}...)"
+
+
+ class REPLEnvironment:
+     """
+     Sandboxed Python execution environment for RLM.
+
+     Provides a safe environment where the LLM can execute Python code
+     to analyze large contexts. The context is loaded as a variable
+     accessible within the REPL.
+
+     Key features:
+     - Sandboxed execution with restricted built-ins
+     - Persistent state across multiple executions
+     - Stdout/stderr capture
+     - Configurable security settings
+     """
+
+     # Safe built-ins that don't allow dangerous operations
+     SAFE_BUILTINS: ClassVar[dict[str, Any]] = {
+         # Core types
+         "print": print,
+         "len": len,
+         "str": str,
+         "int": int,
+         "float": float,
+         "bool": bool,
+         "list": list,
+         "dict": dict,
+         "set": set,
+         "tuple": tuple,
+         "type": type,
+         "isinstance": isinstance,
+         "issubclass": issubclass,
+         # Iteration
+         "range": range,
+         "enumerate": enumerate,
+         "zip": zip,
+         "map": map,
+         "filter": filter,
+         "sorted": sorted,
+         "reversed": reversed,
+         "iter": iter,
+         "next": next,
+         # Math
+         "min": min,
+         "max": max,
+         "sum": sum,
+         "abs": abs,
+         "round": round,
+         "pow": pow,
+         "divmod": divmod,
+         # String/char
+         "chr": chr,
+         "ord": ord,
+         "hex": hex,
+         "bin": bin,
+         "oct": oct,
+         "repr": repr,
+         "ascii": ascii,
+         "format": format,
+         # Collections
+         "any": any,
+         "all": all,
+         "slice": slice,
+         "hash": hash,
+         "id": id,
+         "callable": callable,
+         # Attribute access
+         "hasattr": hasattr,
+         "getattr": getattr,
+         "setattr": setattr,
+         "delattr": delattr,
+         "dir": dir,
+         "vars": vars,
+         # Binary
+         "bytes": bytes,
+         "bytearray": bytearray,
+         "memoryview": memoryview,
+         "complex": complex,
+         # OOP
+         "super": super,
+         "property": property,
+         "staticmethod": staticmethod,
+         "classmethod": classmethod,
+         "object": object,
+         # Exceptions (for try/except blocks)
+         "Exception": Exception,
+         "ValueError": ValueError,
+         "TypeError": TypeError,
+         "KeyError": KeyError,
+         "IndexError": IndexError,
+         "AttributeError": AttributeError,
+         "RuntimeError": RuntimeError,
+         "StopIteration": StopIteration,
+         "AssertionError": AssertionError,
+         "NotImplementedError": NotImplementedError,
+         # Allow imports (controlled)
+         "__import__": __import__,
+     }
+
+     # File-access built-ins (always included; see _create_builtins)
+     FILE_ACCESS_BUILTINS: ClassVar[dict[str, Any]] = {
+         "open": open,
+         "FileNotFoundError": FileNotFoundError,
+         "OSError": OSError,
+         "IOError": IOError,
+     }
+
+     # Built-ins that are always blocked
+     BLOCKED_BUILTINS: ClassVar[dict[str, None]] = {
+         "eval": None,
+         "exec": None,
+         "compile": None,
+         "globals": None,
+         "locals": None,
+         "input": None,
+         "__builtins__": None,
+     }
+
+     def __init__(
+         self,
+         context: ContextType,
+         config: RLMConfig | None = None,
+     ):
+         """
+         Initialize the REPL environment.
+
+         Args:
+             context: The context data to make available as 'context' variable
+             config: Configuration options for the REPL
+         """
+         self.config = config or RLMConfig()
+         self.original_cwd = os.getcwd()
+         self.temp_dir = tempfile.mkdtemp(prefix="rlm_repl_")
+         self._lock = threading.Lock()
+         self.locals: dict[str, Any] = {}
+
+         # Setup globals with safe built-ins
+         self.globals: dict[str, Any] = {
+             "__builtins__": self._create_builtins(),
+         }
+
+         if self.config.sub_model:
+             self._setup_llm_query()
+
+         # Load context into environment
+         self._load_context(context)
+
+     def _create_builtins(self) -> dict[str, Any]:
+         """Create the built-ins dict based on config."""
+         builtins = dict(self.SAFE_BUILTINS)
+
+         # Always include file-access built-ins: they are needed for context
+         # loading and are generally useful for data analysis. The temporary
+         # working directory provides the sandboxing.
+         builtins.update(self.FILE_ACCESS_BUILTINS)
+
+         # Overwrite blocked built-ins with None so calling them raises a TypeError
+         builtins.update(self.BLOCKED_BUILTINS)
+
+         return builtins
+
+     def _setup_llm_query(self) -> None:
+         """
+         Set up the llm_query function for the REPL environment.
+         """
+         from .logging import get_logger
+
+         def llm_query(prompt: str) -> str:
+             """
+             Query a sub-LLM with the given prompt.
+
+             This function allows you to delegate analysis tasks to another
+             LLM, which is useful for processing large contexts in chunks.
+
+             Args:
+                 prompt: The prompt to send to the sub-LLM
+
+             Returns:
+                 The sub-LLM's response as a string
+             """
+             logger = get_logger()
+
+             try:
+                 if not self.config.sub_model:
+                     return "Error: No sub-model configured"
+
+                 # Log the query
+                 logger.log_llm_query(prompt)
+
+                 result = model_request_sync(
+                     self.config.sub_model,
+                     [ModelRequest.user_text_prompt(prompt)],
+                 )
+                 # Extract text from the response parts
+                 text_parts = [part.content for part in result.parts if isinstance(part, TextPart)]
+                 response = "".join(text_parts) if text_parts else ""
+
+                 # Log the response
+                 logger.log_llm_response(response)
+
+                 return response
+             except Exception as e:
+                 return f"Error querying sub-LLM: {e!s}"
+
+         # Add llm_query to globals
+         self.globals["llm_query"] = llm_query
+
+     def _load_context(self, context: ContextType) -> None:
+         """
+         Load context data into the REPL environment.
+
+         The context is written to a file and then loaded into the
+         'context' variable in the REPL namespace.
+         """
+         if isinstance(context, str):
+             # Text context
+             context_path = os.path.join(self.temp_dir, "context.txt")
+             with open(context_path, "w", encoding="utf-8") as f:
+                 f.write(context)
+
+             load_code = f"""
+ with open(r'{context_path}', 'r', encoding='utf-8') as f:
+     context = f.read()
+ """
+         else:
+             # JSON context (dict or list)
+             context_path = os.path.join(self.temp_dir, "context.json")
+             with open(context_path, "w", encoding="utf-8") as f:
+                 json.dump(context, f, indent=2, default=str)
+
+             load_code = f"""
+ import json
+ with open(r'{context_path}', 'r', encoding='utf-8') as f:
+     context = json.load(f)
+ """
+
+         # Execute the load code to populate 'context' variable
+         self._execute_internal(load_code)
+
+     def _execute_internal(self, code: str) -> None:
+         """Execute code internally without capturing output."""
+         combined = {**self.globals, **self.locals}
+         exec(code, combined, combined)
+
+         # Update locals with new variables
+         for key, value in combined.items():
+             if key not in self.globals and not key.startswith("_"):
+                 self.locals[key] = value
+
+     @contextmanager
+     def _capture_output(self):
+         """Thread-safe context manager to capture stdout/stderr."""
+         old_stdout = sys.stdout
+         old_stderr = sys.stderr
+
+         stdout_buffer = io.StringIO()
+         stderr_buffer = io.StringIO()
+
+         try:
+             sys.stdout = stdout_buffer
+             sys.stderr = stderr_buffer
+             yield stdout_buffer, stderr_buffer
+         finally:
+             sys.stdout = old_stdout
+             sys.stderr = old_stderr
+
+     @contextmanager
+     def _temp_working_directory(self):
+         """Context manager to temporarily change working directory."""
+         old_cwd = os.getcwd()
+         try:
+             os.chdir(self.temp_dir)
+             yield
+         finally:
+             os.chdir(old_cwd)
+
+     def execute(self, code: str) -> REPLResult:
+         """
+         Execute Python code in the REPL environment.
+
+         Args:
+             code: Python code to execute
+
+         Returns:
+             REPLResult with stdout, stderr, locals, and timing
+         """
+         start_time = time.time()
+         success = True
+         stdout_content = ""
+         stderr_content = ""
+
+         with (
+             self._lock,
+             self._capture_output() as (stdout_buffer, stderr_buffer),
+             self._temp_working_directory(),
+         ):
+             try:
+                 # Split into imports and other code
+                 lines = code.split("\n")
+                 import_lines = []
+                 other_lines = []
+
+                 for line in lines:
+                     stripped = line.strip()
+                     if stripped.startswith(("import ", "from ")) and not stripped.startswith("#"):
+                         import_lines.append(line)
+                     else:
+                         other_lines.append(line)
+
+                 # Execute imports in globals
+                 if import_lines:
+                     import_code = "\n".join(import_lines)
+                     exec(import_code, self.globals, self.globals)
+
+                 # Execute rest of code
+                 if other_lines:
+                     other_code = "\n".join(other_lines)
+                     combined = {**self.globals, **self.locals}
+
+                     # Try to evaluate last expression for display
+                     self._execute_with_expression_display(other_code, other_lines, combined)
+
+                     # Update locals
+                     for key, value in combined.items():
+                         if key not in self.globals:
+                             self.locals[key] = value
+
+                 stdout_content = stdout_buffer.getvalue()
+                 stderr_content = stderr_buffer.getvalue()
+
+             except Exception as e:
+                 success = False
+                 stderr_content = stderr_buffer.getvalue() + f"\nError: {e!s}"
+                 stdout_content = stdout_buffer.getvalue()
+
+         execution_time = time.time() - start_time
+
+         # Truncate output if needed
+         max_chars = self.config.truncate_output_chars
+         if len(stdout_content) > max_chars:
+             stdout_content = stdout_content[:max_chars] + "\n... (output truncated)"
+         if len(stderr_content) > max_chars:
+             stderr_content = stderr_content[:max_chars] + "\n... (output truncated)"
+
+         return REPLResult(
+             stdout=stdout_content,
+             stderr=stderr_content,
+             locals=dict(self.locals),
+             execution_time=execution_time,
+             success=success,
+         )
+
+     def _execute_with_expression_display(
+         self,
+         code: str,
+         lines: list[str],
+         namespace: dict[str, Any],
+     ) -> None:
+         """
+         Execute code, displaying the last expression's value if applicable.
+
+         This mimics notebook/REPL behavior where the last expression is
+         automatically displayed.
+         """
+         # Find non-comment, non-empty lines
+         non_comment_lines = [line for line in lines if line.strip() and not line.strip().startswith("#")]
+
+         if not non_comment_lines:
+             exec(code, namespace, namespace)
+             return
+
+         last_line = non_comment_lines[-1].strip()
+
+         # Check if last line is an expression (not a statement)
+         is_expression = (
+             not last_line.startswith(
+                 (
+                     "import ",
+                     "from ",
+                     "def ",
+                     "class ",
+                     "if ",
+                     "for ",
+                     "while ",
+                     "try:",
+                     "with ",
+                     "return ",
+                     "yield ",
+                     "raise ",
+                     "break",
+                     "continue",
+                     "pass",
+                     "assert ",
+                     "del ",
+                     "global ",
+                     "nonlocal ",
+                 )
+             )
+             and "=" not in last_line.split("#")[0]  # No '=' (skips assignments, and conservatively also comparisons)
+             and not last_line.endswith(":")  # Not control structure
+             and not last_line.startswith("print(")  # Not explicit print
+         )
+
+         if is_expression and len(non_comment_lines) > 0:
+             try:
+                 # Execute all but last line
+                 if len(non_comment_lines) > 1:
+                     # Find where last line starts
+                     last_line_idx = None
+                     for i, line in enumerate(lines):
+                         if line.strip() == last_line:
+                             last_line_idx = i
+                             break
+
+                     if last_line_idx and last_line_idx > 0:
+                         statements = "\n".join(lines[:last_line_idx])
+                         exec(statements, namespace, namespace)
+
+                 # Evaluate and print last expression
+                 result = eval(last_line, namespace, namespace)
+                 if result is not None:
+                     print(repr(result))
+
+             except (SyntaxError, NameError):
+                 # Fall back to normal execution (may re-run earlier statements)
+                 exec(code, namespace, namespace)
+         else:
+             exec(code, namespace, namespace)
+
+     def cleanup(self) -> None:
+         """Clean up temporary directory."""
+         with contextlib.suppress(Exception):
+             shutil.rmtree(self.temp_dir)
+
+     def __del__(self):
+         """Destructor to ensure cleanup."""
+         self.cleanup()
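
For orientation, the sketch below drives the REPLEnvironment defined above directly, outside of any agent. It is illustrative only and not part of the packaged code; it assumes the module layout implied by the relative imports (pydantic_ai_rlm.repl and pydantic_ai_rlm.dependencies) and that RLMConfig's defaults leave sub_model unset, so llm_query() is not injected.

```python
# Illustrative sketch, not shipped with the package.
from pydantic_ai_rlm.dependencies import RLMConfig  # assumed module path
from pydantic_ai_rlm.repl import REPLEnvironment

repl = REPLEnvironment(
    context={"users": [{"name": "Ada"}, {"name": "Grace"}]},  # stored as context.json
    config=RLMConfig(),
)
try:
    # The trailing bare expression is echoed, mimicking notebook behaviour.
    code = "names = [u['name'] for u in context['users']]\nlen(names)"
    result = repl.execute(code)
    print(result.success)           # True if no exception was raised
    print(result.stdout)            # "2" (repr of the final expression)
    print(result.locals["names"])   # ['Ada', 'Grace'] persists for later calls
finally:
    repl.cleanup()                  # remove the temporary working directory
```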
@@ -0,0 +1,168 @@
+ from __future__ import annotations
+
+ import asyncio
+
+ from pydantic_ai import RunContext
+ from pydantic_ai.toolsets import FunctionToolset
+
+ from .dependencies import RLMConfig, RLMDependencies
+ from .logging import get_logger
+ from .repl import REPLEnvironment, REPLResult
+ from .utils import format_repl_result
+
+ EXECUTE_CODE_DESCRIPTION = """
+ Execute Python code in a sandboxed REPL environment.
+
+ ## Environment
+ - A `context` variable is pre-loaded with the data to analyze
+ - Variables persist between executions within the same session
+ - Standard library modules are available (json, re, collections, etc.)
+ - Use print() to display output
+
+ ## When to Use
+ - Analyzing or processing structured data (JSON, dicts, lists)
+ - Performing calculations or data transformations
+ - Extracting specific information from large datasets
+ - Testing hypotheses about the data structure
+
+ ## Best Practices
+ 1. Start by exploring the context: `print(type(context))`, `print(len(context))`
+ 2. Break complex operations into smaller steps
+ 3. Use print() liberally to understand intermediate results
+ 4. Handle potential errors gracefully with try/except
+
+ ## Available Functions
+ - `llm_query(prompt)`: Query a sub-LLM for reasoning assistance (if configured)
+ - Important: Do not use `llm_query` in the first code execution. Use it only after you have
+   explored the context and identified specific sections that need semantic analysis.
+
+ ## Example
+ ```python
+ # Explore the data
+ print(f"Context type: {type(context)}")
+ print(f"Keys: {list(context.keys()) if isinstance(context, dict) else 'N/A'}")
+
+ # Process and extract information
+ if isinstance(context, dict):
+     for key, value in context.items():
+         print(f"{key}: {type(value)}")
+ ```
+ """
+
+
+ # Global registry to track REPL environments for cleanup
+ _repl_registry: dict[int, REPLEnvironment] = {}
+
+
+ def create_rlm_toolset(
+     *,
+     code_timeout: float = 60.0,
+     sub_model: str | None = None,
+     toolset_id: str | None = None,
+ ) -> FunctionToolset[RLMDependencies]:
+     """Create an RLM toolset for code execution in a sandboxed REPL.
+
+     This toolset provides an `execute_code` tool that allows AI agents to
+     run Python code with access to a `context` variable containing data to analyze.
+
+     Args:
+         code_timeout: Timeout in seconds for code execution. Defaults to 60.0.
+         sub_model: Model to use for llm_query() within the REPL environment.
+         toolset_id: Optional unique identifier for the toolset.
+
+     Returns:
+         FunctionToolset compatible with any pydantic-ai agent.
+
+     Example (basic usage):
+         ```python
+         from pydantic_ai import Agent
+         from pydantic_ai_rlm import create_rlm_toolset, RLMDependencies
+
+         toolset = create_rlm_toolset()
+         agent = Agent("openai:gpt-5", toolsets=[toolset])
+
+         deps = RLMDependencies(context={"users": [...]})
+         result = await agent.run("Analyze the user data", deps=deps)
+         ```
+
+     Example (with timeout and sub-model):
+         ```python
+         from pydantic_ai import Agent
+         from pydantic_ai_rlm import create_rlm_toolset, RLMDependencies, RLMConfig
+
+         toolset = create_rlm_toolset(
+             code_timeout=120.0,
+             sub_model="openai:gpt-5-mini",
+         )
+         agent = Agent("openai:gpt-5", toolsets=[toolset])
+
+         deps = RLMDependencies(
+             context=large_dataset,
+             config=RLMConfig(),
+         )
+         result = await agent.run("Process this dataset", deps=deps)
+         ```
+
+     Example (with toolset composition):
+         ```python
+         from pydantic_ai_rlm import create_rlm_toolset
+
+         rlm_toolset = create_rlm_toolset().prefixed("rlm")
+         # Tool will be named 'rlm_execute_code'
+         ```
+     """
+     toolset: FunctionToolset[RLMDependencies] = FunctionToolset(id=toolset_id)
+
+     def _get_or_create_repl(ctx: RunContext[RLMDependencies]) -> REPLEnvironment:
+         """Get or create REPL environment for this run context."""
+         deps_id = id(ctx.deps)
+
+         if deps_id not in _repl_registry:
+             config = ctx.deps.config or RLMConfig()
+             # Override sub_model from factory if set and not already in config
+             if sub_model and not config.sub_model:
+                 config = RLMConfig(
+                     sub_model=sub_model,
+                 )
+             _repl_registry[deps_id] = REPLEnvironment(
+                 context=ctx.deps.context,
+                 config=config,
+             )
+
+         return _repl_registry[deps_id]
+
+     @toolset.tool(description=EXECUTE_CODE_DESCRIPTION)
+     async def execute_code(ctx: RunContext[RLMDependencies], code: str) -> str:
+         repl_env = _get_or_create_repl(ctx)
+         logger = get_logger()
+
+         # Log the code being executed
+         logger.log_code_execution(code)
+
+         try:
+             loop = asyncio.get_running_loop()
+             result: REPLResult = await asyncio.wait_for(
+                 loop.run_in_executor(None, repl_env.execute, code),
+                 timeout=code_timeout,
+             )
+
+             # Log the result
+             logger.log_result(result)
+
+             return format_repl_result(result)
+
+         except (TimeoutError, asyncio.TimeoutError):  # distinct classes before Python 3.11
+             return f"Error: Code execution timed out after {code_timeout} seconds."
+         except Exception as e:
+             return f"Error executing code: {e!s}"
+
+     return toolset
+
+
+ def cleanup_repl_environments() -> None:
+     """Clean up all REPL environments.
+
+     Call this when you're done with all agent runs to release resources.
+     """
+     for repl_env in _repl_registry.values():
+         repl_env.cleanup()
+     _repl_registry.clear()
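
To show how the factory and the cleanup helper above fit together, here is a hedged end-to-end sketch. It is not part of the package; the model names are copied from the docstring examples, the data is a placeholder, and it assumes cleanup_repl_environments is re-exported from the package root alongside create_rlm_toolset and RLMDependencies (only the latter two are confirmed by the docstrings).

```python
# Illustrative sketch, not shipped with the package.
import asyncio

from pydantic_ai import Agent
from pydantic_ai_rlm import RLMDependencies, create_rlm_toolset
from pydantic_ai_rlm import cleanup_repl_environments  # assumed root re-export


async def main() -> None:
    toolset = create_rlm_toolset(code_timeout=120.0, sub_model="openai:gpt-5-mini")
    agent = Agent("openai:gpt-5", toolsets=[toolset])

    deps = RLMDependencies(context={"orders": [{"id": 1, "total": 42.0}]})
    try:
        result = await agent.run("What is the total value of all orders?", deps=deps)
        print(result.output)  # .output on current pydantic-ai releases
    finally:
        # REPL environments are keyed by the deps object; release their temp dirs.
        cleanup_repl_environments()


asyncio.run(main())
```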
@@ -0,0 +1,47 @@
+ from __future__ import annotations
+
+ from .repl import REPLResult
+
+
+ def format_repl_result(result: REPLResult, max_var_display: int = 200) -> str:
+     """
+     Format a REPL execution result for display to the LLM.
+
+     Args:
+         result: The REPLResult from code execution
+         max_var_display: Maximum characters to show per variable value
+
+     Returns:
+         Formatted string suitable for LLM consumption
+     """
+     parts = []
+
+     if result.stdout.strip():
+         parts.append(f"Output:\n{result.stdout}")
+
+     if result.stderr.strip():
+         parts.append(f"Errors:\n{result.stderr}")
+
+     # Show created/modified variables (excluding internal ones)
+     user_vars = {k: v for k, v in result.locals.items() if not k.startswith("_") and k not in ("context", "json", "re", "os")}
+
+     if user_vars:
+         var_summaries = []
+         for name, value in user_vars.items():
+             try:
+                 value_str = repr(value)
+                 if len(value_str) > max_var_display:
+                     value_str = value_str[:max_var_display] + "..."
+                 var_summaries.append(f" {name} = {value_str}")
+             except Exception:
+                 var_summaries.append(f" {name} = <{type(value).__name__}>")
+
+         if var_summaries:
+             parts.append("Variables:\n" + "\n".join(var_summaries))
+
+     # Check for "no output" before appending the timing line, otherwise this
+     # branch can never be reached.
+     if not parts:
+         return "Code executed successfully (no output)"
+
+     parts.append(f"Execution time: {result.execution_time:.3f}s")
+
+     return "\n\n".join(parts)
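
Finally, a small sketch of the string format_repl_result produces, since this is exactly what the execute_code tool returns to the model. Illustrative only; the module paths mirror the relative imports shown above (pydantic_ai_rlm.repl and pydantic_ai_rlm.utils).

```python
# Illustrative sketch, not shipped with the package.
from pydantic_ai_rlm.repl import REPLResult
from pydantic_ai_rlm.utils import format_repl_result

result = REPLResult(
    stdout="2\n",
    stderr="",
    locals={"names": ["Ada", "Grace"], "_tmp": 1, "context": {"users": []}},
    execution_time=0.0042,
)
print(format_repl_result(result))
# Prints roughly (underscored names and 'context' are filtered out):
#   Output:
#   2
#
#   Variables:
#    names = ['Ada', 'Grace']
#
#   Execution time: 0.004s
```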