recursive-llm-ts 2.0.12 → 3.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +60 -43
- package/bin/rlm-go +0 -0
- package/dist/bridge-factory.d.ts +1 -1
- package/dist/bridge-factory.js +44 -14
- package/dist/bridge-interface.d.ts +1 -0
- package/dist/bunpy-bridge.d.ts +3 -4
- package/dist/bunpy-bridge.js +11 -164
- package/dist/go-bridge.d.ts +5 -0
- package/dist/go-bridge.js +136 -0
- package/dist/rlm-bridge.js +3 -1
- package/go/README.md +347 -0
- package/go/cmd/rlm/main.go +63 -0
- package/go/go.mod +12 -0
- package/go/go.sum +57 -0
- package/go/integration_test.sh +169 -0
- package/go/internal/rlm/benchmark_test.go +168 -0
- package/go/internal/rlm/errors.go +83 -0
- package/go/internal/rlm/openai.go +128 -0
- package/go/internal/rlm/parser.go +53 -0
- package/go/internal/rlm/parser_test.go +202 -0
- package/go/internal/rlm/prompt.go +68 -0
- package/go/internal/rlm/repl.go +260 -0
- package/go/internal/rlm/repl_test.go +291 -0
- package/go/internal/rlm/rlm.go +142 -0
- package/go/internal/rlm/types.go +108 -0
- package/go/test_mock.sh +90 -0
- package/go/test_rlm.sh +41 -0
- package/go/test_simple.sh +78 -0
- package/package.json +6 -9
- package/scripts/build-go-binary.js +41 -0
- package/recursive-llm/pyproject.toml +0 -70
- package/recursive-llm/src/rlm/__init__.py +0 -14
- package/recursive-llm/src/rlm/core.py +0 -322
- package/recursive-llm/src/rlm/parser.py +0 -93
- package/recursive-llm/src/rlm/prompts.py +0 -50
- package/recursive-llm/src/rlm/repl.py +0 -235
- package/recursive-llm/src/rlm/types.py +0 -37
- package/scripts/install-python-deps.js +0 -101
|
@@ -1,322 +0,0 @@
|
|
|
1
|
-
"""Core RLM implementation."""
|
|
2
|
-
|
|
3
|
-
import asyncio
|
|
4
|
-
import re
|
|
5
|
-
from typing import Optional, Dict, Any, List
|
|
6
|
-
|
|
7
|
-
import litellm
|
|
8
|
-
|
|
9
|
-
from .types import Message
|
|
10
|
-
from .repl import REPLExecutor, REPLError
|
|
11
|
-
from .prompts import build_system_prompt
|
|
12
|
-
from .parser import parse_response, is_final
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
class RLMError(Exception):
    """Common base class for errors raised by the RLM driver loop."""
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class MaxIterationsError(RLMError):
    """Raised when the REPL loop runs out of iterations before a FINAL answer."""
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
class MaxDepthError(RLMError):
    """Raised when an RLM is asked to run at or beyond its maximum recursion depth."""
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class RLM:
    """
    Recursive Language Model.

    Runs an LLM in a loop against a Python REPL. Each model response is either
    a final answer (FINAL(...) / FINAL_VAR(...)) or code. Code is executed
    against the stored context, and its output is fed back as the next user
    message. Code can call ``recursive_llm(sub_query, sub_context)`` to run a
    child RLM one level deeper.
    """

    def __init__(
        self,
        model: str,
        recursive_model: Optional[str] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        max_depth: int = 5,
        max_iterations: int = 30,
        _current_depth: int = 0,
        **llm_kwargs: Any
    ):
        """
        Initialize RLM.

        Args:
            model: Model name (e.g., "gpt-4o", "claude-sonnet-4", "ollama/llama3.2")
            recursive_model: Optional cheaper model for recursive calls. For
                compatibility with recursive-llm-ts this may instead be the
                whole config dict. In that case the other settings are read
                from it, and its remaining keys become LiteLLM parameters.
            api_base: Optional API base URL
            api_key: Optional API key
            max_depth: Maximum recursion depth
            max_iterations: Maximum REPL iterations per call
            _current_depth: Internal current depth tracker
            **llm_kwargs: Additional LiteLLM parameters
        """
        if isinstance(recursive_model, dict):
            # Patch for recursive-llm-ts bug where config is passed as 2nd positional arg
            config = recursive_model

            def setting(key: str, fallback: Any) -> Any:
                # A key that is present but None falls back to the keyword
                # value instead of overriding it (int(None) would crash).
                value = config.get(key)
                return fallback if value is None else value

            # Same rule as the keyword path below: empty/None means "use model".
            self.recursive_model = config.get('recursive_model') or model
            self.api_base = setting('api_base', api_base)
            self.api_key = setting('api_key', api_key)
            self.max_depth = int(setting('max_depth', max_depth))
            self.max_iterations = int(setting('max_iterations', max_iterations))

            # Extract other llm kwargs; explicit keyword arguments win.
            excluded = {'recursive_model', 'api_base', 'api_key', 'max_depth', 'max_iterations'}
            self.llm_kwargs = {k: v for k, v in config.items() if k not in excluded}
            self.llm_kwargs.update(llm_kwargs)
        else:
            self.recursive_model = recursive_model or model
            self.api_base = api_base
            self.api_key = api_key
            self.max_depth = max_depth
            self.max_iterations = max_iterations
            self.llm_kwargs = llm_kwargs

        self._current_depth = _current_depth
        self.model = model

        self.repl = REPLExecutor()

        # Stats
        self._llm_calls = 0
        self._iterations = 0

    def completion(
        self,
        query: str = "",
        context: str = "",
        **kwargs: Any
    ) -> str:
        """
        Sync wrapper for acompletion.

        Uses asyncio.run(), so it cannot be called from inside a running event
        loop. Use acompletion() there.

        Args:
            query: User query (optional if query is in context)
            context: Context to process (optional, can pass query here)
            **kwargs: Additional LiteLLM parameters

        Returns:
            Final answer string

        Examples:
            # Standard usage
            rlm.completion(query="Summarize this", context=document)

            # Query in context (RLM will extract task)
            rlm.completion(context="Summarize this document: ...")

            # Single string (treat as context)
            rlm.completion("Process this text and extract dates")
        """
        # acompletion() already moves a lone query into `context`.
        return asyncio.run(self.acompletion(query, context, **kwargs))

    async def acompletion(
        self,
        query: str = "",
        context: str = "",
        **kwargs: Any
    ) -> str:
        """
        Main async completion method.

        Args:
            query: User query (optional if query is in context)
            context: Context to process (optional, can pass query here)
            **kwargs: Additional LiteLLM parameters

        Returns:
            Final answer string

        Raises:
            MaxIterationsError: If max iterations exceeded
            MaxDepthError: If max recursion depth exceeded

        Examples:
            # Explicit query and context
            await rlm.acompletion(query="What is this?", context=doc)

            # Query embedded in context
            await rlm.acompletion(context="Extract all dates from: ...")

            # LLM will figure out the task
            await rlm.acompletion(context=document_with_instructions)
        """
        # If only query provided, treat it as context
        if query and not context:
            context = query
            query = ""
        if self._current_depth >= self.max_depth:
            raise MaxDepthError(f"Max recursion depth ({self.max_depth}) exceeded")

        # Initialize REPL environment
        repl_env = self._build_repl_env(query, context)

        # Build initial messages
        system_prompt = build_system_prompt(len(context), self._current_depth)
        messages: List[Message] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query}
        ]

        # Main loop
        for iteration in range(self.max_iterations):
            self._iterations = iteration + 1

            # Call LLM
            response = await self._call_llm(messages, **kwargs)

            # A FINAL marker that does not resolve to an answer (e.g.
            # FINAL_VAR of an unknown name) falls through and runs as code.
            if is_final(response):
                answer = parse_response(response, repl_env)
                if answer is not None:
                    return answer

            # Execute code in REPL; failures are reported back to the model.
            try:
                exec_result = self.repl.execute(response, repl_env)
            except REPLError as e:
                exec_result = f"Error: {str(e)}"
            except Exception as e:
                exec_result = f"Unexpected error: {str(e)}"

            # Add to conversation
            messages.append({"role": "assistant", "content": response})
            messages.append({"role": "user", "content": exec_result})

        raise MaxIterationsError(
            f"Max iterations ({self.max_iterations}) exceeded without FINAL()"
        )

    async def _call_llm(
        self,
        messages: List[Message],
        **kwargs: Any
    ) -> str:
        """
        Call LLM API.

        Args:
            messages: Conversation messages
            **kwargs: Additional parameters (can override model here)

        Returns:
            LLM response text. This is "" when the provider returned no text
            content, so downstream string parsing never receives None.
        """
        self._llm_calls += 1

        # Choose model based on depth
        default_model = self.model if self._current_depth == 0 else self.recursive_model

        # Allow override via kwargs
        model = kwargs.pop('model', default_model)

        # Merge kwargs; explicit api_base/api_key settings take precedence.
        call_kwargs = {**self.llm_kwargs, **kwargs}
        if self.api_base:
            call_kwargs['api_base'] = self.api_base
        if self.api_key:
            call_kwargs['api_key'] = self.api_key

        # Call LiteLLM
        response = await litellm.acompletion(
            model=model,
            messages=messages,
            **call_kwargs
        )

        # Extract text; message.content may be None for some responses.
        content = response.choices[0].message.content
        return content if content is not None else ""

    def _build_repl_env(self, query: str, context: str) -> Dict[str, Any]:
        """
        Build REPL environment.

        Args:
            query: User query
            context: Context string

        Returns:
            Environment dict exposing context, query, recursive_llm and re
        """
        env: Dict[str, Any] = {
            'context': context,
            'query': query,
            'recursive_llm': self._make_recursive_fn(),
            're': re,  # Whitelist re module
        }
        return env

    def _make_recursive_fn(self) -> Any:
        """
        Create recursive LLM function for REPL.

        Returns:
            Synchronous callable ``(sub_query, sub_context) -> str`` that runs
            a child RLM one level deeper.
        """
        async def recursive_llm(sub_query: str, sub_context: str) -> str:
            """
            Recursively process sub-context.

            Args:
                sub_query: Query for sub-context
                sub_context: Sub-context to process

            Returns:
                Answer from recursive call, or a notice string if the next
                level would reach max_depth
            """
            if self._current_depth + 1 >= self.max_depth:
                return f"Max recursion depth ({self.max_depth}) reached"

            # Create sub-RLM with increased depth
            sub_rlm = RLM(
                model=self.recursive_model,
                recursive_model=self.recursive_model,
                api_base=self.api_base,
                api_key=self.api_key,
                max_depth=self.max_depth,
                max_iterations=self.max_iterations,
                _current_depth=self._current_depth + 1,
                **self.llm_kwargs
            )

            return await sub_rlm.acompletion(sub_query, sub_context)

        # Wrap in sync function for REPL compatibility
        def sync_recursive_llm(sub_query: str, sub_context: str) -> str:
            """Sync wrapper for recursive_llm."""
            # Only the loop probe belongs in the try: a RuntimeError raised by
            # the recursive call itself must propagate. Catching it here would
            # re-run the whole call via asyncio.run inside a running loop.
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No running loop, safe to use asyncio.run
                return asyncio.run(recursive_llm(sub_query, sub_context))

            # A loop is running in this thread (the REPL is executed
            # synchronously from acompletion), so asyncio.run() would fail
            # here. Run the child on its own loop in a worker thread and block.
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(
                    asyncio.run,
                    recursive_llm(sub_query, sub_context)
                )
                return future.result()

        return sync_recursive_llm

    @property
    def stats(self) -> Dict[str, int]:
        """Get execution statistics."""
        return {
            'llm_calls': self._llm_calls,
            'iterations': self._iterations,
            'depth': self._current_depth,
        }
|
|
@@ -1,93 +0,0 @@
|
|
|
1
|
-
"""Parse FINAL() and FINAL_VAR() statements from LLM responses."""
|
|
2
|
-
|
|
3
|
-
import re
|
|
4
|
-
from typing import Optional, Dict, Any
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def extract_final(response: str) -> Optional[str]:
    """
    Extract answer from FINAL() statement.

    Forms are tried in this order: triple double quotes, triple single quotes,
    double quotes, single quotes. Triple-quoted answers may span several lines.
    They end at the FIRST closing triple quote, as a Python string literal
    would. Otherwise a later triple-quoted string in the same response would
    be swallowed into the answer.

    Args:
        response: LLM response text

    Returns:
        Extracted answer (whitespace-stripped) or None if not found
    """
    patterns = (
        r'FINAL\s*\(\s*"""(.*?)"""',  # FINAL("""answer""") - triple double quotes
        r"FINAL\s*\(\s*'''(.*?)'''",  # FINAL('''answer''') - triple single quotes
        r'FINAL\s*\(\s*"([^"]*)"',    # FINAL("answer") - double quotes
        r"FINAL\s*\(\s*'([^']*)'",    # FINAL('answer') - single quotes
    )

    for pattern in patterns:
        match = re.search(pattern, response, re.DOTALL)
        if match:
            return match.group(1).strip()

    return None
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def extract_final_var(response: str, env: Dict[str, Any]) -> Optional[str]:
    """
    Resolve a FINAL_VAR(name) marker against the REPL environment.

    Only a bare identifier between the parentheses is recognised (surrounding
    whitespace allowed).

    Args:
        response: LLM response text
        env: REPL environment with variables

    Returns:
        ``str()`` of the named variable's value (a stored None becomes
        "None"), or None when there is no marker or the name is not defined.
    """
    marker = re.search(r'FINAL_VAR\s*\(\s*(\w+)\s*\)', response)
    if marker is None:
        return None

    name = marker.group(1)
    return str(env[name]) if name in env else None
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def is_final(response: str) -> bool:
    """
    Check if response contains FINAL() or FINAL_VAR().

    Whitespace is allowed between the keyword and the opening parenthesis,
    which is the same shape extract_final() and extract_final_var() accept.
    A response like ``FINAL ("x")`` is therefore recognised as an answer
    instead of being executed as code.

    Args:
        response: LLM response text (empty or None counts as not final)

    Returns:
        True if response contains final statement
    """
    if not response:
        return False
    return re.search(r'FINAL(?:_VAR)?\s*\(', response) is not None
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def parse_response(response: str, env: Dict[str, Any]) -> Optional[str]:
    """
    Return the final answer carried by a response, if any.

    A literal FINAL(...) answer takes priority. FINAL_VAR(name) is consulted
    only when no FINAL(...) answer was found. An empty-string FINAL answer
    still counts as found.

    Args:
        response: LLM response text
        env: REPL environment used to resolve FINAL_VAR names

    Returns:
        Final answer or None
    """
    final_answer = extract_final(response)
    if final_answer is None:
        final_answer = extract_final_var(response, env)
    return final_answer
|
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
"""System prompt templates for RLM."""
|
|
2
|
-
|
|
3
|
-
|
|
4
|
-
def build_system_prompt(context_size: int, depth: int = 0) -> str:
    """
    Render the minimal (paper-style) system prompt for an RLM turn.

    Args:
        context_size: Length of the context in characters, shown with
            thousands separators
        depth: Current recursion depth, shown on the last line

    Returns:
        The prompt text as newline-joined lines, with no trailing newline
    """
    # The query line contains the literal text {query}, braces included.
    # It is not substituted with the actual query.
    lines = [
        "You are a Recursive Language Model. You interact with context through a Python REPL environment.",
        "",
        f"The context is stored in variable `context` (not in this prompt). Size: {context_size:,} characters.",
        "",
        "Available in environment:",
        "- context: str (the document to analyze)",
        '- query: str (the question: "{query}")',
        "- recursive_llm(sub_query, sub_context) -> str (recursively process sub-context)",
        "- re: already imported regex module (use re.findall, re.search, etc.)",
        "",
        "Write Python code to answer the query. The last expression or print() output will be shown to you.",
        "",
        "Examples:",
        "- print(context[:100]) # See first 100 chars",
        "- errors = re.findall(r'ERROR', context) # Find all ERROR",
        "- count = len(errors); print(count) # Count and show",
        "",
        'When you have the answer, use FINAL("answer") - this is NOT a function, just write it as text.',
        "",
        f"Depth: {depth}",
    ]
    return "\n".join(lines)
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def build_user_prompt(query: str) -> str:
    """
    Produce the user-turn prompt for a query.

    Args:
        query: User's question

    Returns:
        The query text, passed through unchanged
    """
    user_prompt = query
    return user_prompt
|
|
@@ -1,235 +0,0 @@
|
|
|
1
|
-
"""Safe REPL executor using RestrictedPython."""
|
|
2
|
-
|
|
3
|
-
import io
|
|
4
|
-
import sys
|
|
5
|
-
from typing import Dict, Any, Optional
|
|
6
|
-
from RestrictedPython import compile_restricted_exec, safe_globals, limited_builtins, utility_builtins
|
|
7
|
-
from RestrictedPython.Guards import guarded_iter_unpack_sequence, safer_getattr
|
|
8
|
-
from RestrictedPython.PrintCollector import PrintCollector
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
class REPLError(Exception):
    """Raised when a REPL snippet fails to compile or raises while running."""
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
class REPLExecutor:
    """Safe Python code executor built on RestrictedPython."""

    # A snippet's trailing bare expression is assigned to this variable so its
    # value can be shown after execution. RestrictedPython rejects names
    # starting with "_", so this one deliberately does not.
    _LAST_VALUE_NAME = "rlm_repl_last_value"

    def __init__(self, timeout: int = 5, max_output_chars: int = 2000):
        """
        Initialize REPL executor.

        Args:
            timeout: Execution timeout in seconds (not currently enforced)
            max_output_chars: Maximum characters to return (truncate if longer)
        """
        self.timeout = timeout
        self.max_output_chars = max_output_chars

    def execute(self, code: str, env: Dict[str, Any]) -> str:
        """
        Execute Python code in restricted environment.

        The snippet (after unwrapping any markdown fence) is compiled with
        RestrictedPython and run once. It runs in a single namespace made of
        the safe globals plus ``env``, so functions defined in the snippet can
        see REPL variables such as ``context``. Variables the snippet creates,
        rebinds or deletes are written back to ``env``, including when it
        fails part-way; this is how state persists between iterations. If the
        last statement is a bare expression, its value is appended to the
        output, as in an interactive REPL.

        Args:
            code: Python code to execute
            env: Environment with context, query, recursive_llm, etc.
                Updated in place.

        Returns:
            Captured output (stdout, print() calls, trailing expression value),
            stripped and truncated to ``max_output_chars``. Returns "No code to
            execute" or "Code executed successfully (no output)" when there is
            nothing to show.

        Raises:
            REPLError: "Compilation error: ..." if RestrictedPython rejects the
                code, "Execution error: ..." if running it raises.
        """
        # Filter out code blocks if present (LLM might wrap code)
        code = self._extract_code(code)

        if not code.strip():
            return "No code to execute"

        restricted_globals = self._build_globals(env)
        source, has_trailing_value = self._capture_last_expression(code)

        # Capture stdout (restored in the finally clause)
        old_stdout = sys.stdout
        sys.stdout = captured_output = io.StringIO()

        try:
            byte_code = compile_restricted_exec(source)
            if byte_code.errors:
                raise REPLError(f"Compilation error: {', '.join(byte_code.errors)}")

            # One dict serves as globals AND locals. With separate dicts, exec()
            # behaves like a class body, and functions defined in the snippet
            # could not see REPL variables.
            namespace = {**restricted_globals, **env}
            try:
                exec(byte_code.code, namespace)
            finally:
                # Persist variables even if the snippet failed part-way.
                self._write_back(namespace, restricted_globals, env)

            output = captured_output.getvalue()

            # print() output is collected by the PrintCollector bound to
            # `_print`. The namespace is fresh per run, so prints from earlier
            # runs are never reported again.
            print_collector = namespace.get('_print')
            if print_collector is not None and hasattr(print_collector, 'txt'):
                output += ''.join(print_collector.txt)

            if has_trailing_value:
                result = namespace.get(self._LAST_VALUE_NAME)
                if result is not None:
                    output += str(result) + '\n'

            return self._format_output(output)

        except REPLError:
            # Already descriptive; don't re-wrap as an execution error.
            raise
        except Exception as e:
            raise REPLError(f"Execution error: {str(e)}") from e

        finally:
            sys.stdout = old_stdout

    def _capture_last_expression(self, code: str) -> tuple:
        """
        Prepare ``code`` so a trailing bare expression is evaluated exactly once.

        When the last top-level statement is an expression (e.g.
        ``error_count`` or ``recursive_llm(q, c)``), it is rewritten into an
        assignment to ``_LAST_VALUE_NAME``. The single restricted exec then
        both runs it and keeps its value. The expression is never evaluated a
        second time and stays under the RestrictedPython guards.

        Args:
            code: Python source to prepare

        Returns:
            ``(ast.Module, True)`` when an expression was captured. Otherwise
            ``(code, False)`` with the source unchanged. This includes source
            that does not parse, so RestrictedPython reports the syntax error.
        """
        import ast

        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            return code, False

        if not tree.body or not isinstance(tree.body[-1], ast.Expr):
            return code, False

        last = tree.body[-1]
        tree.body[-1] = ast.copy_location(
            ast.Assign(
                targets=[ast.Name(id=self._LAST_VALUE_NAME, ctx=ast.Store())],
                value=last.value,
            ),
            last,
        )
        ast.fix_missing_locations(tree)
        return tree, True

    def _write_back(
        self,
        namespace: Dict[str, Any],
        restricted_globals: Dict[str, Any],
        env: Dict[str, Any],
    ) -> None:
        """
        Copy the snippet's variable changes from ``namespace`` into ``env``.

        Names the snippet deleted are removed from ``env``, and names it
        created or rebound are stored. Unchanged safe globals are not copied,
        and neither are ``__builtins__``, the print collector or the captured
        trailing value.
        """
        internal = {'__builtins__', '_print', self._LAST_VALUE_NAME}

        for name in [n for n in env if n not in namespace]:
            del env[name]

        for name, value in namespace.items():
            if name in internal:
                continue
            if name in env or name not in restricted_globals or restricted_globals[name] is not value:
                env[name] = value

    def _format_output(self, output: str) -> str:
        """
        Normalize captured output before it is shown to the model.

        Whitespace is stripped before the length check. Trailing newlines
        therefore never trigger truncation, and whitespace-only output counts
        as no output.
        """
        text = output.strip()
        if not text:
            return "Code executed successfully (no output)"

        # Truncate output if too long (as per paper: "truncated version of output")
        if len(text) > self.max_output_chars:
            truncated = text[:self.max_output_chars]
            return f"{truncated}\n\n[Output truncated: {len(text)} chars total, showing first {self.max_output_chars}]"

        return text

    def _extract_code(self, text: str) -> str:
        """
        Extract code from markdown code blocks if present.

        The first ```python block is preferred. Otherwise the first fenced
        block of any kind is used. A single-word info string on the opening
        fence line (e.g. ``py``, ``python3``, ``Python``) is dropped rather
        than executed as code.

        Args:
            text: Raw text that might contain code

        Returns:
            Extracted code, or ``text`` unchanged when it has no complete fence
        """
        import re

        block = re.search(r'```python(.*?)```', text, re.DOTALL)
        if block is None:
            block = re.search(r'```(.*?)```', text, re.DOTALL)
        if block is None:
            return text

        body = block.group(1)
        first_line, newline, rest = body.partition('\n')
        if newline and re.fullmatch(r'[\w.+-]+', first_line.strip()):
            body = rest
        return body.strip()

    def _build_globals(self, env: Dict[str, Any]) -> Dict[str, Any]:
        """
        Build restricted globals for safe execution.

        Args:
            env: User environment

        Returns:
            Safe globals dict
        """
        restricted_globals = safe_globals.copy()
        restricted_globals.update(limited_builtins)
        restricted_globals.update(utility_builtins)

        # Add guards
        restricted_globals['_iter_unpack_sequence_'] = guarded_iter_unpack_sequence
        restricted_globals['_getattr_'] = safer_getattr
        restricted_globals['_getitem_'] = lambda obj, index: obj[index]
        restricted_globals['_getiter_'] = iter
        restricted_globals['_print_'] = PrintCollector

        # Add additional safe builtins
        restricted_globals.update({
            # Types
            'len': len,
            'str': str,
            'int': int,
            'float': float,
            'bool': bool,
            'list': list,
            'dict': dict,
            'tuple': tuple,
            'set': set,
            'frozenset': frozenset,
            'bytes': bytes,
            'bytearray': bytearray,

            # Iteration
            'range': range,
            'enumerate': enumerate,
            'zip': zip,
            'map': map,
            'filter': filter,
            'reversed': reversed,
            'iter': iter,
            'next': next,

            # Aggregation
            'sorted': sorted,
            'sum': sum,
            'min': min,
            'max': max,
            'any': any,
            'all': all,

            # Math
            'abs': abs,
            'round': round,
            'pow': pow,
            'divmod': divmod,

            # String/repr
            'chr': chr,
            'ord': ord,
            'hex': hex,
            'oct': oct,
            'bin': bin,
            'repr': repr,
            'ascii': ascii,
            'format': format,

            # Type checking
            'isinstance': isinstance,
            'issubclass': issubclass,
            'callable': callable,
            'type': type,
            'hasattr': hasattr,

            # Constants
            'True': True,
            'False': False,
            'None': None,
        })

        # Add safe standard library modules
        # These are read-only and don't allow file/network access
        import re
        import json
        import math
        from datetime import datetime, timedelta
        from collections import Counter, defaultdict

        restricted_globals.update({
            're': re,  # Regex (read-only)
            'json': json,  # JSON parsing (read-only)
            'math': math,  # Math functions
            'datetime': datetime,  # Date parsing
            'timedelta': timedelta,  # Time deltas
            'Counter': Counter,  # Counting helper
            'defaultdict': defaultdict,  # Dict with defaults
        })

        return restricted_globals
|