lm-deluge 0.0.80__py3-none-any.whl → 0.0.82__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lm_deluge/__init__.py +1 -2
- lm_deluge/api_requests/anthropic.py +2 -1
- lm_deluge/api_requests/base.py +13 -0
- lm_deluge/api_requests/gemini.py +1 -1
- lm_deluge/api_requests/openai.py +3 -2
- lm_deluge/client.py +16 -11
- lm_deluge/llm_tools/__init__.py +12 -5
- lm_deluge/pipelines/__init__.py +11 -0
- lm_deluge/{llm_tools → pipelines}/score.py +2 -2
- lm_deluge/{llm_tools → pipelines}/translate.py +5 -3
- lm_deluge/prompt.py +105 -0
- lm_deluge/request_context.py +2 -2
- lm_deluge/{tool.py → tool/__init__.py} +531 -314
- lm_deluge/tool/prefab/__init__.py +29 -0
- lm_deluge/tool/prefab/batch_tool.py +156 -0
- lm_deluge/{llm_tools → tool/prefab}/filesystem.py +1 -1
- lm_deluge/tool/prefab/memory.py +190 -0
- lm_deluge/tool/prefab/otc/__init__.py +165 -0
- lm_deluge/tool/prefab/otc/executor.py +281 -0
- lm_deluge/tool/prefab/otc/parse.py +188 -0
- lm_deluge/{llm_tools → tool/prefab}/sandbox.py +251 -61
- lm_deluge/{llm_tools → tool/prefab}/todos.py +1 -1
- lm_deluge/tool/prefab/tool_search.py +169 -0
- lm_deluge/tracker.py +16 -13
- {lm_deluge-0.0.80.dist-info → lm_deluge-0.0.82.dist-info}/METADATA +2 -3
- {lm_deluge-0.0.80.dist-info → lm_deluge-0.0.82.dist-info}/RECORD +34 -28
- lm_deluge/presets/cerebras.py +0 -17
- lm_deluge/presets/meta.py +0 -13
- /lm_deluge/{llm_tools → pipelines}/classify.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/extract.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/locate.py +0 -0
- /lm_deluge/{llm_tools → pipelines}/ocr.py +0 -0
- /lm_deluge/{llm_tools → tool/prefab}/subagents.py +0 -0
- {lm_deluge-0.0.80.dist-info → lm_deluge-0.0.82.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.80.dist-info → lm_deluge-0.0.82.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.80.dist-info → lm_deluge-0.0.82.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
from typing import Any, Callable
|
|
4
|
+
|
|
5
|
+
from lm_deluge.tool import Tool
|
|
6
|
+
|
|
7
|
+
from .parse import SAFE_BUILTINS, OTCExecutionError, validate_code
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OutputCapture:
    """Captures print() output during OTC execution.

    Installed as the replacement for the builtin ``print`` inside the
    sandboxed exec environment. Each call is recorded as one line; the full
    transcript is retrieved with :meth:`get_output`.
    """

    def __init__(self):
        # One entry per print() call; joined with newlines in get_output().
        self.outputs: list[str] = []

    def print(self, *args, **kwargs):
        """Replacement print function that captures output.

        Honors the ``sep`` keyword like the builtin ``print`` (the original
        implementation accepted kwargs but silently ignored them, so
        ``print("a", "b", sep="-")`` was captured as ``"a b"``). Other
        keywords (``end``, ``file``, ``flush``) are accepted but ignored:
        each call is captured as its own line, so ``end`` has no meaning here.
        """
        sep = kwargs.get("sep")
        if sep is None:
            # Builtin print treats sep=None the same as the default " ".
            sep = " "
        output = sep.join(str(arg) for arg in args)
        self.outputs.append(output)

    def get_output(self) -> str:
        """Return everything printed so far, one line per print() call."""
        return "\n".join(self.outputs)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class PendingResult:
    """Lazy placeholder standing in for a tool call's eventual result.

    Holds a reference to the executor's shared results mapping. Once the
    result for this call id lands in that mapping, every read-style access
    transparently forwards to the concrete value; until then, any such
    access raises ``RuntimeError`` containing "not yet available" — the
    exact signal the executor's retry loop keys on.
    """

    def __init__(self, call_id: int, results: dict[int, Any]):
        self._call_id = call_id
        self._results = results

    def _require_result(self) -> Any:
        # EAFP: a single lookup; the KeyError is translated into the
        # sentinel RuntimeError the executor expects.
        try:
            return self._results[self._call_id]
        except KeyError:
            raise RuntimeError(
                f"Result for call {self._call_id} not yet available"
            ) from None

    def is_ready(self) -> bool:
        """True once the executor has stored a result for this call."""
        return self._call_id in self._results

    # --- forwarding dunders ------------------------------------------------
    # Every access below forces resolution, so touching an unresolved
    # placeholder surfaces the "not yet available" RuntimeError.

    def __repr__(self) -> str:
        return repr(self._require_result())

    def __str__(self) -> str:
        return str(self._require_result())

    def __getattr__(self, name: str) -> Any:
        # Only reached for attributes not set in __init__, so no recursion.
        return getattr(self._require_result(), name)

    def __getitem__(self, key: Any) -> Any:
        return self._require_result()[key]

    def __iter__(self):
        return iter(self._require_result())

    def __len__(self) -> int:
        return len(self._require_result())

    def __bool__(self) -> bool:
        return bool(self._require_result())
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class OTCExecutor:
    """Executes OTC code with access to tools.

    Execution model: tool calls inside the user code return PendingResult
    placeholders instead of running immediately. The code is exec'd
    repeatedly; whenever a placeholder is actually *used* before its result
    exists, a RuntimeError ("not yet available") aborts the pass, all queued
    calls are executed in parallel, and the code is re-exec'd with the
    results now populated. Call ids are deterministic per pass (reset to 0
    each iteration), so a re-run maps each call site to the same id.
    """

    def __init__(self, tools: list[Tool]):
        # Map of tool name -> Tool for dispatch, plus the name set used
        # by validate_code() to whitelist tool calls.
        self.tools = {tool.name: tool for tool in tools}
        self.tool_names = set(self.tools.keys())

    def _contains_unresolved(self, value: Any) -> bool:
        """Check if a value (possibly nested) contains an unresolved PendingResult."""
        if isinstance(value, PendingResult):
            return not value.is_ready()
        if isinstance(value, list):
            return any(self._contains_unresolved(item) for item in value)
        if isinstance(value, tuple):
            return any(self._contains_unresolved(item) for item in value)
        if isinstance(value, set):
            return any(self._contains_unresolved(item) for item in value)
        if isinstance(value, dict):
            # Only values are scanned; dict keys are assumed hashable
            # primitives, not placeholders.
            return any(self._contains_unresolved(v) for v in value.values())
        return False

    def _resolve_dependencies(self, value: Any, results: dict[int, Any]) -> Any:
        """Replace PendingResult placeholders with concrete values.

        Recurses through list/tuple/set/dict containers. Assumes every
        placeholder is already resolved (callers check _contains_unresolved
        first); an unresolved one raises via _require_result().
        NOTE(review): the ``results`` parameter is unused — each
        PendingResult carries its own reference to the results mapping.
        """
        if isinstance(value, PendingResult):
            return value._require_result()
        if isinstance(value, list):
            return [self._resolve_dependencies(v, results) for v in value]
        if isinstance(value, tuple):
            return tuple(self._resolve_dependencies(v, results) for v in value)
        if isinstance(value, set):
            return {self._resolve_dependencies(v, results) for v in value}
        if isinstance(value, dict):
            return {k: self._resolve_dependencies(v, results) for k, v in value.items()}
        return value

    def _resolve_output_value(self, value: Any, results: dict[int, Any]) -> Any:
        """Resolve PendingResult placeholders when building the final output.

        NOTE(review): structurally identical to _resolve_dependencies (and
        its ``results`` parameter is likewise unused); kept separate here.
        """
        if isinstance(value, PendingResult):
            return value._require_result()
        if isinstance(value, list):
            return [self._resolve_output_value(v, results) for v in value]
        if isinstance(value, tuple):
            return tuple(self._resolve_output_value(v, results) for v in value)
        if isinstance(value, set):
            return {self._resolve_output_value(v, results) for v in value}
        if isinstance(value, dict):
            return {k: self._resolve_output_value(v, results) for k, v in value.items()}
        return value

    def _make_sync_tool_wrapper(
        self,
        tool: Tool,
        pending_calls: list,
        results: dict[int, Any],
        call_state: dict[str, int],
        pending_call_ids: set[int],
    ) -> Callable:
        """Create a sync wrapper that queues tool calls for later execution.

        The wrapper is injected into the exec globals under the tool's name.
        It never runs the tool itself: it assigns a deterministic call id
        (execution-order counter in ``call_state``), queues the call unless
        a result for that id already exists from a previous pass, and
        returns a PendingResult placeholder.
        """

        def wrapper(*args, **kwargs):
            # Convert positional args to kwargs using tool parameter order
            if args and tool.parameters:
                param_names = list(tool.parameters.keys())
                for i, arg in enumerate(args):
                    if i < len(param_names):
                        kwargs[param_names[i]] = arg

            # Ensure we don't pass unresolved PendingResult objects as arguments
            # (message must contain "not yet available" — execute() matches on it
            # to trigger a resolve-and-retry pass).
            if self._contains_unresolved(kwargs):
                raise RuntimeError("Result for call dependency not yet available")

            # Resolve any PendingResult values before queueing
            resolved_kwargs = self._resolve_dependencies(kwargs, results)

            # Generate a deterministic call ID based on execution order
            call_id = call_state["next_id"]
            call_state["next_id"] += 1

            # Avoid re-queueing calls that already have results or are pending
            if call_id not in results and call_id not in pending_call_ids:
                pending_call_ids.add(call_id)
                pending_calls.append(
                    {
                        "id": call_id,
                        "tool": tool.name,
                        "kwargs": resolved_kwargs,
                    }
                )

            # Return a placeholder that will be resolved later
            return PendingResult(call_id, results)

        return wrapper

    async def _execute_pending_calls(self, pending_calls: list, results: dict) -> None:
        """Execute all pending tool calls in parallel.

        Results (or ``{"error": ...}`` dicts on failure — per-call errors
        are captured rather than raised, so one failing call doesn't abort
        the batch) are stored into ``results`` keyed by call id, and the
        pending queue is cleared.
        """
        if not pending_calls:
            return

        async def execute_one(call: dict) -> tuple[int, Any]:
            tool = self.tools[call["tool"]]
            try:
                if asyncio.iscoroutinefunction(tool.run):
                    result = await tool.run(**call["kwargs"])
                elif tool.run is not None:
                    result = tool.run(**call["kwargs"])
                else:
                    raise OTCExecutionError("tool is not executable")

                # Try to parse as JSON if it's a string
                if isinstance(result, str):
                    try:
                        result = json.loads(result)
                    except json.JSONDecodeError:
                        pass  # Keep as string

                return call["id"], result
            except Exception as e:
                return call["id"], {"error": str(e)}

        # Execute all in parallel
        call_results = await asyncio.gather(
            *[execute_one(call) for call in pending_calls]
        )

        # Store results
        for call_id, result in call_results:
            results[call_id] = result

        # Clear pending
        pending_calls.clear()

    async def execute(self, code: str) -> str:
        """Execute OTC code and return the final output.

        The execution model:
        1. Parse and validate the code
        2. Execute line-by-line, collecting tool calls
        3. When we hit a point where results are needed, execute pending calls
        4. Continue until done
        5. Return captured output or final expression value
        """
        # Validate (raises OTCSecurityError on forbidden constructs).
        tree = validate_code(code, self.tool_names)

        # Set up execution environment. These containers are shared by
        # reference with the tool wrappers and PendingResult placeholders.
        pending_calls: list = []
        results: dict = {}
        output_capture = OutputCapture()
        pending_call_ids: set[int] = set()
        call_state = {"next_id": 0}

        # Create tool wrappers
        tool_wrappers = {
            name: self._make_sync_tool_wrapper(
                tool, pending_calls, results, call_state, pending_call_ids
            )
            for name, tool in self.tools.items()
        }

        # Build globals: restricted builtins with print captured, plus the
        # tool wrappers exposed under their tool names.
        exec_globals = {
            "__builtins__": {**SAFE_BUILTINS, "print": output_capture.print},
            "json": json,  # Allow json for output formatting
            **tool_wrappers,
        }

        # Persists across retry passes, so earlier assignments survive.
        exec_locals: dict = {}

        # Execute the code
        # We need to handle the deferred execution pattern:
        # Tool calls return PendingResult objects, and we need to resolve them
        # before they're actually used.

        # Strategy: Execute the whole thing, catching any "not yet available" errors,
        # then execute pending calls and retry until done.

        max_iterations = 100  # Prevent infinite loops

        for _ in range(max_iterations):
            # Reset call sequencing and pending tracking for this pass, so a
            # re-run assigns each call site the same deterministic id.
            call_state["next_id"] = 0
            pending_call_ids.clear()
            try:
                exec(compile(tree, "<otc>", "exec"), exec_globals, exec_locals)
                # If we get here, execution completed
                # Execute any remaining pending calls (though their results won't be used)
                await self._execute_pending_calls(pending_calls, results)
                pending_call_ids.clear()
                break

            except RuntimeError as e:
                if "not yet available" in str(e):
                    # Need to resolve pending calls and retry
                    await self._execute_pending_calls(pending_calls, results)
                    pending_call_ids.clear()
                    # Continue the loop to retry
                else:
                    raise OTCExecutionError(f"Runtime error: {e}")

            except Exception as e:
                raise OTCExecutionError(f"Execution error: {type(e).__name__}: {e}")

        else:
            # for/else: only reached when the loop never hit `break`.
            raise OTCExecutionError("Execution exceeded maximum iterations")

        # Get output
        output = output_capture.get_output()

        # If no print output, try to get the last expression value
        if not output and exec_locals:
            # Look for a 'result' variable or the last assigned value
            if "result" in exec_locals:
                result = self._resolve_output_value(exec_locals["result"], results)
                if isinstance(result, str):
                    output = result
                else:
                    output = json.dumps(result, default=str, indent=2)

        return output if output else "Composition completed with no output"
|
|
@@ -0,0 +1,188 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
|
|
3
|
+
# Whitelist of builtins exposed to sandboxed OTC code. Installed by the
# executor as the exec __builtins__ (with "print" overridden to a capturing
# replacement). Anything not listed here is simply unavailable.
SAFE_BUILTINS = {
    # Types
    "bool": bool,
    "int": int,
    "float": float,
    "str": str,
    "list": list,
    "dict": dict,
    "tuple": tuple,
    "set": set,
    "frozenset": frozenset,
    "type": type,
    # Functions
    "abs": abs,
    "all": all,
    "any": any,
    "bin": bin,
    "chr": chr,
    "divmod": divmod,
    "enumerate": enumerate,
    "filter": filter,
    "format": format,
    # NOTE: hasattr is allowed while getattr is forbidden (see FORBIDDEN_CALLS).
    "hasattr": hasattr,
    "hash": hash,
    "hex": hex,
    "isinstance": isinstance,
    "issubclass": issubclass,
    "iter": iter,
    "len": len,
    "map": map,
    "max": max,
    "min": min,
    "next": next,
    "oct": oct,
    "ord": ord,
    "pow": pow,
    "print": print,  # Captured for output
    "range": range,
    "repr": repr,
    "reversed": reversed,
    "round": round,
    "slice": slice,
    "sorted": sorted,
    "sum": sum,
    "zip": zip,
    # Constants
    # NOTE: True/False/None are keywords in Python 3 and never resolved via
    # __builtins__, so these three entries are inert but harmless.
    "True": True,
    "False": False,
    "None": None,
    # Exceptions (for try/except)
    "Exception": Exception,
    "ValueError": ValueError,
    "TypeError": TypeError,
    "KeyError": KeyError,
    "IndexError": IndexError,
    "AttributeError": AttributeError,
    "RuntimeError": RuntimeError,
    "StopIteration": StopIteration,
}
|
|
62
|
+
|
|
63
|
+
# AST nodes that are NOT allowed. Checked by ASTValidator.visit() via an
# exact `type(node) in FORBIDDEN_NODES` membership test on every node.
FORBIDDEN_NODES = {
    ast.Import,
    ast.ImportFrom,
    ast.Global,
    ast.Nonlocal,
    ast.AsyncWith,  # We control async, not user code
    ast.Yield,
    ast.YieldFrom,
    ast.ClassDef,  # No class definitions
}
|
|
74
|
+
|
|
75
|
+
# Forbidden function calls. Checked by ASTValidator.visit_Call() for direct
# name calls only (``eval(...)``); attribute-based escape routes are handled
# separately by FORBIDDEN_ATTRIBUTES.
FORBIDDEN_CALLS = {
    "eval",
    "exec",
    "compile",
    "open",
    "input",
    "__import__",
    "globals",
    "locals",
    "vars",
    "dir",
    "getattr",
    "setattr",
    "delattr",
    "breakpoint",
    "exit",
    "quit",
}
|
|
94
|
+
|
|
95
|
+
# Forbidden attribute access patterns. Checked by ASTValidator.visit_Attribute()
# by exact attribute name; these are the common introspection hooks used for
# sandbox escapes (e.g. reaching object.__subclasses__ or a function's
# __globals__).
FORBIDDEN_ATTRIBUTES = {
    "__class__",
    "__bases__",
    "__subclasses__",
    "__mro__",
    "__code__",
    "__globals__",
    "__builtins__",
    "__import__",
    "__dict__",
    "__module__",
    "__reduce__",
    "__reduce_ex__",
}
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class OTCSecurityError(Exception):
    """Raised when code violates OTC security constraints."""
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class OTCExecutionError(Exception):
    """Raised when code execution fails."""
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class ASTValidator(ast.NodeVisitor):
    """Validates that an AST doesn't contain forbidden constructs.

    Collects human-readable error strings instead of raising on the first
    violation, so callers (validate_code) can report every problem at once.
    """

    def __init__(self, allowed_tool_names: set[str]):
        # NOTE(review): currently only stored — no check restricts calls to
        # these names; call filtering is done via FORBIDDEN_CALLS instead.
        self.allowed_tool_names = allowed_tool_names
        self.errors: list[str] = []

    def visit(self, node: ast.AST) -> None:
        # Screen every node against the forbidden node types.
        if type(node) in FORBIDDEN_NODES:
            self.errors.append(
                f"Forbidden construct: {type(node).__name__} at line {getattr(node, 'lineno', '?')}"
            )

        # BUG FIX: dispatch through NodeVisitor.visit so that visit_Call and
        # visit_Attribute actually run. The previous implementation called
        # self.generic_visit(node) here, which recurses into children but
        # bypasses the per-node-type visitor methods entirely — so the
        # FORBIDDEN_CALLS and FORBIDDEN_ATTRIBUTES checks never executed and
        # code like eval("...") or x.__class__ passed validation.
        super().visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        # Reject direct calls to dangerous builtins (eval, exec, open, ...).
        if isinstance(node.func, ast.Name):
            if node.func.id in FORBIDDEN_CALLS:
                self.errors.append(
                    f"Forbidden function call: {node.func.id} at line {node.lineno}"
                )

        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        # Block introspection attributes that enable sandbox escape.
        if node.attr in FORBIDDEN_ATTRIBUTES:
            self.errors.append(
                f"Forbidden attribute access: {node.attr} at line {node.lineno}"
            )

        # Also check for dunder access patterns beyond a small whitelist of
        # benign protocol methods.
        if node.attr.startswith("__") and node.attr.endswith("__"):
            if node.attr not in {"__len__", "__iter__", "__next__", "__contains__"}:
                self.errors.append(
                    f"Forbidden dunder access: {node.attr} at line {node.lineno}"
                )

        self.generic_visit(node)

    def validate(self, tree: ast.AST) -> list[str]:
        """Validate the AST and return list of errors."""
        self.errors = []
        self.visit(tree)
        return self.errors
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def validate_code(code: str, allowed_tool_names: set[str]) -> ast.Module:
    """Parse and validate code, returning AST if valid.

    Args:
        code: OTC source code to check.
        allowed_tool_names: tool names the code may call; forwarded to the
            validator.

    Returns:
        The parsed ``ast.Module`` when no violations are found.

    Raises:
        OTCSecurityError: if the code fails to parse, or if the validator
            reports any forbidden construct, call, or attribute access.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        # Chain the original SyntaxError so tracebacks retain the exact
        # parse location instead of silently discarding it.
        raise OTCSecurityError(f"Syntax error: {e}") from e

    validator = ASTValidator(allowed_tool_names)
    errors = validator.validate(tree)

    if errors:
        raise OTCSecurityError(
            "Security violations:\n" + "\n".join(f" - {e}" for e in errors)
        )

    return tree
|