lm-deluge 0.0.87__py3-none-any.whl → 0.0.89__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lm_deluge/api_requests/gemini.py +19 -7
- lm_deluge/models/google.py +13 -0
- lm_deluge/tool/prefab/__init__.py +9 -1
- lm_deluge/tool/prefab/full_text_search/__init__.py +285 -0
- lm_deluge/tool/prefab/full_text_search/tantivy_index.py +396 -0
- lm_deluge/tool/prefab/rlm/__init__.py +296 -0
- lm_deluge/tool/prefab/rlm/executor.py +349 -0
- lm_deluge/tool/prefab/rlm/parse.py +144 -0
- lm_deluge/tool/prefab/sandbox.py +908 -0
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/METADATA +12 -1
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/RECORD +14 -9
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.87.dist-info → lm_deluge-0.0.89.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,349 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RLM (Recursive Language Model) code executor.
|
|
3
|
+
|
|
4
|
+
Executes Python code with access to a context variable and lm() function
|
|
5
|
+
for recursive language model calls.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Callable
|
|
12
|
+
|
|
13
|
+
from .parse import (
|
|
14
|
+
RLM_MODULES,
|
|
15
|
+
RLM_SAFE_BUILTINS,
|
|
16
|
+
RLMExecutionError,
|
|
17
|
+
validate_rlm_code,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from lm_deluge.client import _LLMClient
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class OutputCapture:
    """Captures print() output during execution.

    Installed as the ``print`` builtin of the sandbox. Every call appends one
    text chunk (the joined arguments followed by ``end``) to ``outputs``;
    ``get_output()`` concatenates the chunks.
    """

    def __init__(self):
        # One raw chunk per print() call, each already including its ``end``.
        self.outputs: list[str] = []

    def print(self, *args, **kwargs):
        """Replacement print function that captures output.

        Honors ``sep`` and ``end`` like the builtin ``print`` (``None`` means
        the default). ``file`` and ``flush`` are accepted and ignored, because
        all output is kept in memory.

        Raises:
            TypeError: if ``sep`` or ``end`` is neither None nor a string.
        """
        sep = kwargs.get("sep")
        end = kwargs.get("end")
        if sep is None:
            sep = " "
        if end is None:
            end = "\n"
        if not isinstance(sep, str):
            raise TypeError(f"sep must be None or a string, not {type(sep).__name__}")
        if not isinstance(end, str):
            raise TypeError(f"end must be None or a string, not {type(end).__name__}")
        self.outputs.append(sep.join(str(arg) for arg in args) + end)

    def get_output(self) -> str:
        """Return everything printed so far.

        One trailing newline is dropped, so plain print() calls produce
        newline-separated lines with no dangling terminator.
        """
        return "".join(self.outputs).removesuffix("\n")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class PendingLMResult:
    """Placeholder for an lm() call result that hasn't completed yet.

    Once its call has finished, the object behaves like the result string:
    printing, formatting, comparison, hashing, concatenation, indexing and
    string methods all operate on that string. Any use before then raises
    ``RuntimeError`` with "not yet available" in the message, which the
    executor treats as the signal to run queued lm() calls and replay.
    """

    def __init__(self, call_id: int, results: dict[int, str]):
        self._call_id = call_id
        # Dict shared with the executor, which stores completed results in it.
        self._results = results

    def _require_result(self) -> str:
        """Return the completed result, or raise RuntimeError if it isn't ready."""
        if self._call_id not in self._results:
            raise RuntimeError(f"LM result for call {self._call_id} not yet available")
        return self._results[self._call_id]

    def is_ready(self) -> bool:
        """Return True once the result for this call has been stored."""
        return self._call_id in self._results

    def __repr__(self) -> str:
        return repr(self._require_result())

    def __str__(self) -> str:
        return str(self._require_result())

    def __format__(self, format_spec: str) -> str:
        # Lets f"{res:>20}" apply string format specs to the result.
        return format(self._require_result(), format_spec)

    def __getattr__(self, name: str) -> Any:
        # Only reached for names not found normally. Our own slots are missing
        # only on a half-built instance (e.g. inside copy/unpickling before
        # __init__ ran); raising AttributeError there avoids infinite recursion
        # through _require_result().
        if name in ("_call_id", "_results"):
            raise AttributeError(name)
        return getattr(self._require_result(), name)

    def __getitem__(self, key: Any) -> Any:
        return self._require_result()[key]

    def __iter__(self):
        return iter(self._require_result())

    def __len__(self) -> int:
        return len(self._require_result())

    def __bool__(self) -> bool:
        return bool(self._require_result())

    def __add__(self, other):
        return self._require_result() + other

    def __radd__(self, other):
        return other + self._require_result()

    def __mul__(self, other):
        return self._require_result() * other

    def __rmul__(self, other):
        return other * self._require_result()

    def __mod__(self, other):
        return self._require_result() % other

    def __contains__(self, item):
        return item in self._require_result()

    # Comparisons act on the result string. Without them ``lm(q) == "yes"``
    # would fall back to identity and silently evaluate to False.
    def __eq__(self, other):
        return self._require_result() == other

    def __ne__(self, other):
        return self._require_result() != other

    def __lt__(self, other):
        return self._require_result() < other

    def __le__(self, other):
        return self._require_result() <= other

    def __gt__(self, other):
        return self._require_result() > other

    def __ge__(self, other):
        return self._require_result() >= other

    def __hash__(self) -> int:
        # Defining __eq__ would otherwise make instances unhashable; hash like
        # the result so equal placeholders and strings share dict/set slots.
        return hash(self._require_result())

    def __int__(self) -> int:
        return int(self._require_result())

    def __float__(self) -> float:
        return float(self._require_result())
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class FinalAnswer(Exception):
    """Control-flow signal used by final()/final_var() to end execution.

    The resolved answer travels on the ``answer`` attribute; the exception
    message itself is fixed.
    """

    def __init__(self, answer: Any):
        super().__init__("Final answer signaled")
        self.answer = answer
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _resolve_value(value: Any, results: dict[int, str]) -> Any:
    """Recursively resolve any PendingLMResult placeholders in a value.

    Walks lists, tuples, sets and dicts (both keys and values), rebuilding
    each container with placeholders replaced by their result strings. Any
    other value is returned unchanged.

    Args:
        value: The value to resolve.
        results: Passed through the recursion for API compatibility. Each
            placeholder reads from the results dict it was created with.

    Raises:
        RuntimeError: if a placeholder's result is not available yet.
    """
    if isinstance(value, PendingLMResult):
        return value._require_result()
    if isinstance(value, list):
        return [_resolve_value(item, results) for item in value]
    if isinstance(value, tuple):
        return tuple(_resolve_value(item, results) for item in value)
    if isinstance(value, dict):
        # Keys may be placeholders too (e.g. {lm(q): count}). Resolving them
        # lets the answer serialize as JSON instead of falling back to repr().
        return {
            _resolve_value(key, results): _resolve_value(item, results)
            for key, item in value.items()
        }
    if isinstance(value, set):
        return {_resolve_value(item, results) for item in value}
    return value
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def _contains_unresolved(value: Any) -> bool:
    """Report whether *value* holds a PendingLMResult that is still waiting.

    Searches lists, tuples, sets and dict values (not dict keys). Any other
    value counts as fully resolved.
    """
    if isinstance(value, PendingLMResult):
        return not value.is_ready()
    if isinstance(value, dict):
        children = value.values()
    elif isinstance(value, (list, tuple, set)):
        children = value
    else:
        return False
    for child in children:
        if _contains_unresolved(child):
            return True
    return False
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class RLMExecutor:
    """Executes RLM code with access to context and lm() calls.

    ``lm(prompt)`` does not block. It queues the prompt and returns a
    ``PendingLMResult``. When the code needs a value that has not arrived yet,
    the placeholder raises ``RuntimeError`` mentioning "not yet available".
    The executor then runs every queued call and replays the code from the
    start. Calls are matched to results by the order in which they occur
    within a pass, so the code should issue lm() calls in a deterministic
    order.

    Names the code binds, other than the injected helpers, persist across
    execute() calls until reset() is called.
    """

    def __init__(
        self,
        context: str,
        client: _LLMClient,
        context_var_name: str = "CONTEXT",
        max_lm_calls_per_execution: int = 20,
    ):
        """Initialize the RLM executor.

        Args:
            context: The long context string to analyze
            client: LLMClient for making recursive lm() calls
            context_var_name: Variable name for the context (default: "CONTEXT")
            max_lm_calls_per_execution: Maximum distinct lm() calls allowed per
                execute() call, counted across all replay passes
        """
        self.context = context
        self.client = client
        self.context_var_name = context_var_name
        self.max_lm_calls_per_execution = max_lm_calls_per_execution

        # Persistent state across execute() calls
        self._persistent_locals: dict[str, Any] = {}

    def _make_lm_wrapper(
        self,
        pending_lm_calls: list[dict],
        lm_results: dict[int, str],
        call_state: dict[str, int],
        pending_call_ids: set[int],
    ) -> Callable[[str], PendingLMResult]:
        """Create the lm(prompt) wrapper function.

        Each call gets the next sequential id for the current pass. An id that
        already has a result (from an earlier pass) or is already queued is
        not queued again.
        """

        def lm_call(prompt: str) -> PendingLMResult:
            # A prompt built from an unfinished result cannot be sent yet.
            # Raise the deferral signal so the executor resolves it first.
            if _contains_unresolved(prompt):
                raise RuntimeError("LM result for call dependency not yet available")

            call_id = call_state["next_lm_id"]
            call_state["next_lm_id"] += 1

            # Only queue if not already completed or pending
            if call_id not in lm_results and call_id not in pending_call_ids:
                # Count every distinct call in this execute(): results already
                # fetched in earlier passes plus the calls queued right now.
                issued = len(lm_results) + len(pending_lm_calls)
                if issued >= self.max_lm_calls_per_execution:
                    raise RuntimeError(
                        f"Too many lm() calls in single execution "
                        f"(max {self.max_lm_calls_per_execution})"
                    )
                pending_call_ids.add(call_id)
                pending_lm_calls.append(
                    {
                        "id": call_id,
                        "prompt": str(prompt),
                    }
                )

            return PendingLMResult(call_id, lm_results)

        return lm_call

    def _make_final_func(
        self, exec_namespace: dict[str, Any], lm_results: dict[int, str]
    ) -> Callable[[Any], None]:
        """Create final(answer) function.

        ``final`` resolves placeholders inside *answer* and raises FinalAnswer.
        *exec_namespace* is accepted but not used.
        """

        def final_func(answer: Any) -> None:
            resolved = _resolve_value(answer, lm_results)
            raise FinalAnswer(resolved)

        return final_func

    def _make_final_var_func(
        self, exec_namespace: dict[str, Any], lm_results: dict[int, str]
    ) -> Callable[[str], None]:
        """Create final_var(varname) function.

        ``final_var`` looks the name up in *exec_namespace*, resolves any
        placeholders in its value and raises FinalAnswer. An unknown name
        raises RuntimeError.
        """

        def final_var_func(varname: str) -> None:
            if varname not in exec_namespace:
                raise RuntimeError(f"Variable '{varname}' not found")
            value = exec_namespace[varname]
            resolved = _resolve_value(value, lm_results)
            raise FinalAnswer(resolved)

        return final_var_func

    async def _execute_pending_lm_calls(
        self,
        pending_calls: list[dict],
        results: dict[int, str],
    ) -> None:
        """Execute all pending lm() calls in parallel.

        Every queued prompt is submitted with ``client.start_nowait`` before
        any is awaited. Each completion is stored in *results* under its call
        id; failures are stored as an ``"Error: ..."`` string. *pending_calls*
        is emptied at the end.
        """
        if not pending_calls:
            return

        from lm_deluge.prompt import Conversation

        # Start all calls in parallel using start_nowait
        task_mapping: list[tuple[int, int]] = []  # (call_id, task_id)
        for call in pending_calls:
            conv = Conversation.user(call["prompt"])
            task_id = self.client.start_nowait(conv)
            task_mapping.append((call["id"], task_id))

        # Wait for all to complete
        for call_id, task_id in task_mapping:
            try:
                response = await self.client.wait_for(task_id)
                results[call_id] = response.completion or "(no response)"
            except Exception as e:
                # A failed call becomes an error string instead of aborting the
                # whole execution, so the code can still inspect it.
                results[call_id] = f"Error: {e}"

        # Clear the pending list
        pending_calls.clear()

    def _format_answer(self, value: Any) -> str:
        """Format the final answer as a string (JSON for non-strings if possible)."""
        if isinstance(value, str):
            return value
        try:
            return json.dumps(value, default=str, indent=2)
        except Exception:
            return str(value)

    def _persist_locals(
        self, exec_namespace: dict[str, Any], system_keys: frozenset[str]
    ) -> None:
        """Replace the persistent state with the user-bound names of a run.

        Injected helpers (*system_keys*) are excluded. Names the code deleted
        (``del x``) are dropped from the persistent state as well.
        """
        user_state = {
            key: value
            for key, value in exec_namespace.items()
            if key not in system_keys
        }
        self._persistent_locals.clear()
        self._persistent_locals.update(user_state)

    async def execute(self, code: str) -> tuple[str, bool]:
        """Execute RLM code.

        Args:
            code: Python code to execute

        Returns:
            Tuple of (output_string, is_final) where is_final indicates
            whether FINAL()/FINAL_VAR() was called.

        Raises:
            RLMExecutionError: if the code fails, exceeds the lm() call limit,
                waits on a value no queued call can supply, or needs more than
                the maximum number of replay passes.
        """
        # Validate the code
        tree = validate_rlm_code(code)

        # Set up execution environment
        pending_lm_calls: list[dict] = []
        lm_results: dict[int, str] = {}
        pending_call_ids: set[int] = set()
        call_state = {"next_lm_id": 0}
        output_capture = OutputCapture()

        # Create the lm() wrapper
        lm_wrapper = self._make_lm_wrapper(
            pending_lm_calls, lm_results, call_state, pending_call_ids
        )

        # A single dict serves as both globals and locals, so variables are
        # visible inside nested scopes (list comprehensions, functions, etc.).
        # final/final_var hold a reference to this same dict object.
        exec_namespace: dict[str, Any] = {}
        system_entries: dict[str, Any] = {
            "__builtins__": {**RLM_SAFE_BUILTINS, "print": output_capture.print},
            self.context_var_name: self.context,
            "lm": lm_wrapper,
            "json": json,  # Explicitly include json
            **RLM_MODULES,
            "final": self._make_final_func(exec_namespace, lm_results),
            "final_var": self._make_final_var_func(exec_namespace, lm_results),
        }
        # Only the injected names are "system" keys. Previously persisted user
        # variables must stay out of this set, or their new values would never
        # be persisted again.
        system_keys = frozenset(system_entries)
        initial_namespace = {**self._persistent_locals, **system_entries}

        # Execute with retry loop for deferred lm() resolution
        max_iterations = 50
        compiled = compile(tree, "<rlm>", "exec")

        for _ in range(max_iterations):
            # Every pass replays the code from scratch. Restore the starting
            # namespace and drop output from an aborted pass, so prints and
            # rebinding such as ``total = total + 1`` are not applied twice.
            # In-place mutation of persisted objects cannot be undone here.
            exec_namespace.clear()
            exec_namespace.update(initial_namespace)
            output_capture.outputs.clear()
            call_state["next_lm_id"] = 0
            pending_call_ids.clear()

            try:
                exec(compiled, exec_namespace)

                # Execution completed - run any remaining pending calls
                await self._execute_pending_lm_calls(pending_lm_calls, lm_results)
                self._persist_locals(exec_namespace, system_keys)
                break

            except FinalAnswer as fa:
                # FINAL() or FINAL_VAR() was called
                self._persist_locals(exec_namespace, system_keys)
                return (self._format_answer(fa.answer), True)

            except RuntimeError as e:
                if "not yet available" not in str(e):
                    raise RLMExecutionError(f"Runtime error: {e}") from e
                if not pending_lm_calls:
                    # Nothing queued can supply the missing value (e.g. a
                    # placeholder persisted from an earlier execute() whose
                    # call never ran), so replaying would fail the same way.
                    raise RLMExecutionError(
                        f"Runtime error: {e} (no queued lm() call can supply it)"
                    ) from e
                # Resolve the queued lm() calls, then replay from the start
                await self._execute_pending_lm_calls(pending_lm_calls, lm_results)

            except Exception as e:
                raise RLMExecutionError(
                    f"Execution error: {type(e).__name__}: {e}"
                ) from e

        else:
            raise RLMExecutionError(
                f"Execution exceeded maximum iterations ({max_iterations})"
            )

        # Get output
        output = output_capture.get_output()

        # If no print output, check for result variable
        if not output and "result" in exec_namespace:
            try:
                result_value = _resolve_value(exec_namespace["result"], lm_results)
            except RuntimeError as e:
                raise RLMExecutionError(f"Runtime error: {e}") from e
            output = self._format_answer(result_value)

        return (output or "Execution completed with no output", False)

    def reset(self) -> None:
        """Reset the persistent state."""
        self._persistent_locals.clear()
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""
|
|
2
|
+
RLM (Recursive Language Model) code parsing and validation.
|
|
3
|
+
|
|
4
|
+
Extends OTC's security model with additional modules for context analysis.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import ast
|
|
8
|
+
import collections
|
|
9
|
+
import json
|
|
10
|
+
import math
|
|
11
|
+
import re
|
|
12
|
+
|
|
13
|
+
# Import OTC's base security definitions
|
|
14
|
+
from ..otc.parse import (
|
|
15
|
+
FORBIDDEN_CALLS,
|
|
16
|
+
SAFE_BUILTINS,
|
|
17
|
+
ASTValidator,
|
|
18
|
+
OTCSecurityError,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
# RLM uses the same builtins as OTC. A copy is taken so that adding or
# removing entries for RLM does not alter OTC's SAFE_BUILTINS. The executor
# spreads this mapping into the sandbox's __builtins__ and adds its own print.
RLM_SAFE_BUILTINS = SAFE_BUILTINS.copy()
|
|
23
|
+
|
|
24
|
+
# Standard-library modules RLM code may import. They are already injected
# into the execution namespace, so import statements naming them are not
# executed as real imports.
RLM_ALLOWED_IMPORTS = {"re", "math", "collections", "json"}

# Names injected into the execution globals: each allowed module under its
# own name, followed by the most commonly used collections types.
RLM_MODULES = {module.__name__: module for module in (re, math, collections, json)}
RLM_MODULES.update(
    (type_name, getattr(collections, type_name))
    for type_name in ("Counter", "defaultdict", "deque", "namedtuple", "OrderedDict")
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class RLMSecurityError(OTCSecurityError):
    """Signals that submitted RLM code broke a sandbox security rule.

    Raised for syntax errors and validation failures before any code runs.
    """
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class RLMExecutionError(Exception):
    """Signals that validated RLM code failed while it was running."""
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class RLMASTValidator(ASTValidator):
    """Validates RLM code with additional checks.

    Import statements for allowed modules are accepted; validate_rlm_code
    passes them to ImportStripper afterwards. Imports of any other module,
    and any relative import, are recorded as errors. Direct calls to names in
    FORBIDDEN_CALLS are recorded as errors too.
    """

    def __init__(self, allowed_names: set[str] | None = None):
        # RLM exposes no tool functions, so the base validator gets none.
        super().__init__(allowed_tool_names=set())
        # Stored for callers; not consulted by the checks in this class.
        self.allowed_names = allowed_names or set()

    def visit(self, node: ast.AST) -> None:
        """Check import nodes here and delegate every other node to the parent."""
        # Check imports - allowed ones will be stripped later, disallowed ones error
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name not in RLM_ALLOWED_IMPORTS:
                    self.errors.append(
                        f"Forbidden import: {alias.name} at line {node.lineno}. "
                        f"Available modules: {', '.join(sorted(RLM_ALLOWED_IMPORTS))}"
                    )
            self.generic_visit(node)
            return

        if isinstance(node, ast.ImportFrom):
            # level > 0 marks a relative import (``from .json import x``). It
            # never names a pre-loaded standard module, even when the bare
            # module name matches, so reject it as well.
            if node.level or node.module not in RLM_ALLOWED_IMPORTS:
                source = "." * node.level + (node.module or "")
                self.errors.append(
                    f"Forbidden import: from {source} at line {node.lineno}. "
                    f"Available modules: {', '.join(sorted(RLM_ALLOWED_IMPORTS))}"
                )
            self.generic_visit(node)
            return

        # For all other nodes, use parent validation
        super().visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        """Record direct calls to names listed in FORBIDDEN_CALLS."""
        if isinstance(node.func, ast.Name):
            if node.func.id in FORBIDDEN_CALLS:
                self.errors.append(
                    f"Forbidden function call: {node.func.id} at line {node.lineno}"
                )
        self.generic_visit(node)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class ImportStripper(ast.NodeTransformer):
    """Rewrites import statements for pre-loaded modules.

    The allowed modules already exist in the execution namespace under their
    own names, so a plain ``import re`` is removed. Imports that would bind a
    different name become plain assignments, so that name actually exists:

    * ``import re as rx``                -> ``rx = re``
    * ``from collections import deque``  -> ``deque = collections.deque``
    * ``from json import loads as ld``   -> ``ld = json.loads``

    ``from <allowed> import *`` is still dropped without binding anything.
    Imports of other modules, and relative imports, are left in place.
    """

    @staticmethod
    def _bind(target: str, value: ast.expr, origin: ast.stmt) -> ast.Assign:
        """Build ``target = value`` placed at *origin*'s source location."""
        assign = ast.Assign(
            targets=[ast.Name(id=target, ctx=ast.Store())],
            value=value,
        )
        return ast.fix_missing_locations(ast.copy_location(assign, origin))

    def visit_Import(self, node: ast.Import) -> ast.AST | list[ast.AST] | None:
        # Keep only imports of non-allowed modules (which will error at validation)
        remaining = [
            alias for alias in node.names if alias.name not in RLM_ALLOWED_IMPORTS
        ]
        # ``import re as rx``: the module is pre-loaded as ``re``, so alias it.
        bindings = [
            self._bind(alias.asname, ast.Name(id=alias.name, ctx=ast.Load()), node)
            for alias in node.names
            if alias.name in RLM_ALLOWED_IMPORTS
            and alias.asname
            and alias.asname != alias.name
        ]
        if remaining:
            node.names = remaining
            return [node, *bindings] if bindings else node
        # Returning None removes the statement; a list replaces it.
        return bindings or None

    def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST | list[ast.AST] | None:
        # Relative imports and non-allowed modules are not ours to rewrite.
        if node.level or node.module not in RLM_ALLOWED_IMPORTS:
            return node
        bindings = []
        for alias in node.names:
            if alias.name == "*":
                # Star imports cannot be enumerated here; keep them a no-op.
                continue
            member = ast.Attribute(
                value=ast.Name(id=node.module, ctx=ast.Load()),
                attr=alias.name,
                ctx=ast.Load(),
            )
            bindings.append(self._bind(alias.asname or alias.name, member, node))
        return bindings or None
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def validate_rlm_code(code: str) -> ast.Module:
    """Parse *code*, enforce the RLM security rules, and return the AST.

    Validation runs on the untouched tree. Afterwards the tree is passed
    through ImportStripper, which handles imports of the pre-loaded modules
    (re, math, collections, json), and missing node locations are filled in.

    Raises:
        RLMSecurityError: on a syntax error or any security violation.
    """
    try:
        module = ast.parse(code)
    except SyntaxError as err:
        raise RLMSecurityError(f"Syntax error: {err}")

    violations = RLMASTValidator().validate(module)
    if violations:
        bullets = "\n".join(f" - {violation}" for violation in violations)
        raise RLMSecurityError("Security violations:\n" + bullets)

    module = ImportStripper().visit(module)
    ast.fix_missing_locations(module)
    return module
|