quantalogic 0.33.1__py3-none-any.whl → 0.33.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantalogic/agent.py +2 -7
- quantalogic/agent_config.py +5 -1
- quantalogic/coding_agent.py +8 -2
- quantalogic/generative_model.py +7 -2
- quantalogic/prompts.py +2 -21
- quantalogic/tools/__init__.py +6 -1
- quantalogic/tools/read_html_tool.py +5 -3
- quantalogic/tools/safe_python_interpreter_tool.py +213 -0
- quantalogic/tools/sequence_tool.py +285 -0
- quantalogic/utils/__init__.py +1 -0
- quantalogic/utils/python_interpreter.py +835 -0
- {quantalogic-0.33.1.dist-info → quantalogic-0.33.3.dist-info}/METADATA +1 -1
- {quantalogic-0.33.1.dist-info → quantalogic-0.33.3.dist-info}/RECORD +16 -13
- {quantalogic-0.33.1.dist-info → quantalogic-0.33.3.dist-info}/LICENSE +0 -0
- {quantalogic-0.33.1.dist-info → quantalogic-0.33.3.dist-info}/WHEEL +0 -0
- {quantalogic-0.33.1.dist-info → quantalogic-0.33.3.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,835 @@
|
|
1
|
+
import ast
import builtins
import operator
import textwrap
from typing import Any, List, Dict, Optional, Tuple
|
5
|
+
|
6
|
+
# Exception used to signal a "return" from a function call.
|
7
|
+
class ReturnException(Exception):
    """Control-flow signal raised when interpreted code executes ``return``.

    The returned object travels on the ``value`` attribute so that the code
    driving a function body can catch the signal and hand the value back.
    """

    def __init__(self, value: Any) -> None:
        # ``Exception.__init__`` is intentionally not invoked here, exactly as
        # before; only the payload attribute is recorded.
        self.value: Any = value
|
10
|
+
|
11
|
+
# Exceptions used for loop control.
|
12
|
+
class BreakException(Exception):
    """Control-flow signal raised when interpreted code executes ``break``."""
|
14
|
+
|
15
|
+
class ContinueException(Exception):
    """Control-flow signal raised when interpreted code executes ``continue``."""
|
17
|
+
|
18
|
+
# The main interpreter class.
|
19
|
+
class ASTInterpreter:
    """Tree-walking interpreter for a subset of Python with an import allow-list.

    Each AST node type is handled by a ``visit_<NodeName>`` method.  Variables
    live in ``env_stack``, a list of dict frames searched from the innermost
    (last) frame outwards; frame 0 is the global frame.

    Errors raised while evaluating a node are re-raised as a plain
    ``Exception`` whose message carries the line, column and source text of
    the innermost failing node; the original exception stays reachable through
    ``__cause__``.
    """

    # Attribute set on exceptions already decorated by ``visit`` so that outer
    # ``visit`` frames pass them through instead of wrapping them again.
    _WRAPPED_MARKER = "_ast_interpreter_wrapped"

    _BINARY_OPERATORS: Dict[type, Any] = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.MatMult: operator.matmul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.LShift: operator.lshift,
        ast.RShift: operator.rshift,
        ast.BitOr: operator.or_,
        ast.BitXor: operator.xor,
        ast.BitAnd: operator.and_,
    }

    # In-place variants: mutable operands (lists, sets, ...) are updated in
    # place, immutable ones fall back to the ordinary binary operation.
    _AUGMENTED_OPERATORS: Dict[type, Any] = {
        ast.Add: operator.iadd,
        ast.Sub: operator.isub,
        ast.Mult: operator.imul,
        ast.MatMult: operator.imatmul,
        ast.Div: operator.itruediv,
        ast.FloorDiv: operator.ifloordiv,
        ast.Mod: operator.imod,
        ast.Pow: operator.ipow,
        ast.LShift: operator.ilshift,
        ast.RShift: operator.irshift,
        ast.BitOr: operator.ior,
        ast.BitXor: operator.ixor,
        ast.BitAnd: operator.iand,
    }

    _UNARY_OPERATORS: Dict[type, Any] = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
        ast.Not: operator.not_,
        ast.Invert: operator.invert,
    }

    _COMPARE_OPERATORS: Dict[type, Any] = {
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
        ast.Lt: operator.lt,
        ast.LtE: operator.le,
        ast.Gt: operator.gt,
        ast.GtE: operator.ge,
        ast.Is: operator.is_,
        ast.IsNot: operator.is_not,
        ast.In: lambda item, container: item in container,
        ast.NotIn: lambda item, container: item not in container,
    }

    def __init__(
        self,
        allowed_modules: List[str],
        env_stack: Optional[List[Dict[str, Any]]] = None,
        source: Optional[str] = None
    ) -> None:
        """Create an interpreter.

        :param allowed_modules: Names of the only modules interpreted code may import.
        :param env_stack: Existing frame stack to run on (used for function
            calls); when omitted a fresh global frame is created.
        :param source: Source text, used only to quote the failing line in errors.
        """
        self.allowed_modules: List[str] = allowed_modules
        # Import only the allowed modules.  ``__import__`` returns the
        # top-level package for dotted names; ``_resolve_module`` walks down
        # to the submodule when one is needed.
        self.modules: Dict[str, Any] = {mod: __import__(mod) for mod in allowed_modules}
        # Exceptions currently being handled by interpreted ``except`` blocks
        # (innermost last); lets a bare ``raise`` re-raise the active one.
        self._handled_exceptions: List[BaseException] = []

        if env_stack is None:
            # Global frame: allowed modules first, then the builtins.
            self.env_stack: List[Dict[str, Any]] = [dict(self.modules)]
            self._install_builtins(self.env_stack[0])
        else:
            self.env_stack = env_stack
            # Ensure the global frame has the builtins.
            if "__builtins__" not in self.env_stack[0]:
                self._install_builtins(self.env_stack[0])
            if "set" not in self.env_stack[0]:
                self.env_stack[0]["set"] = self.env_stack[0]["__builtins__"]["set"]

        # Source lines are kept for error reporting only.
        self.source_lines: Optional[List[str]] = source.splitlines() if source is not None else None

        # Expose the usual decimal names directly when the module is allowed.
        if "decimal" in self.modules:
            dec = self.modules["decimal"]
            for name in ("Decimal", "getcontext", "setcontext", "localcontext", "Context"):
                self.env_stack[0][name] = getattr(dec, name)

    def _install_builtins(self, frame: Dict[str, Any]) -> None:
        """Copy the builtins into ``frame`` with ``__import__`` replaced by ``safe_import``."""
        # NOTE(security): every builtin is exposed to interpreted code,
        # including eval, exec, open and getattr; only ``__import__`` is
        # restricted.  This must not be treated as a sandbox for hostile code.
        safe_builtins: Dict[str, Any] = dict(vars(builtins))
        safe_builtins["__import__"] = self.safe_import
        frame["__builtins__"] = safe_builtins
        # Make builtin names (like set) directly available as variables.
        frame.update(safe_builtins)

    def _resolve_module(self, name: str) -> Any:
        """Return the module object for ``name``, descending into submodules for dotted names."""
        module = self.modules[name]
        for part in name.split(".")[1:]:
            module = getattr(module, part)
        return module

    def safe_import(
        self,
        name: str,
        globals: Optional[Dict[str, Any]] = None,
        locals: Optional[Dict[str, Any]] = None,
        fromlist: Tuple[str, ...] = (),
        level: int = 0
    ) -> Any:
        """Replacement for ``__import__`` that only serves the allowed modules.

        :raises ImportError: if ``name`` is not in ``allowed_modules``.
        """
        if name not in self.allowed_modules:
            error_msg = f"Import Error: Module '{name}' is not allowed. Only {self.allowed_modules} are permitted."
            raise ImportError(error_msg)
        # Mirror ``__import__``: the submodule when names are imported from
        # it, the top-level package otherwise.
        return self._resolve_module(name) if fromlist else self.modules[name]

    def spawn_from_env(self, env_stack: List[Dict[str, Any]]) -> "ASTInterpreter":
        """Create a new interpreter with the same settings running on ``env_stack``."""
        child = ASTInterpreter(
            self.allowed_modules,
            env_stack,
            source="\n".join(self.source_lines) if self.source_lines else None
        )
        # Share the handled-exception stack so that a bare ``raise`` inside a
        # function called from an ``except`` block sees the active exception.
        child._handled_exceptions = self._handled_exceptions
        return child

    def get_variable(self, name: str) -> Any:
        """Look ``name`` up from the innermost frame outwards.

        :raises NameError: if no frame defines ``name``.
        """
        for frame in reversed(self.env_stack):
            if name in frame:
                return frame[name]
        raise NameError(f"Name {name} is not defined.")

    def set_variable(self, name: str, value: Any) -> None:
        """Bind ``name`` in the innermost frame."""
        self.env_stack[-1][name] = value

    def _frame_for_store(self, name: str) -> Dict[str, Any]:
        """Pick the frame a plain-name assignment must write to."""
        local = self.env_stack[-1]
        if name in local.get("__global_names__", ()):
            return self.env_stack[0]
        if name in local.get("__nonlocal_names__", ()):
            # Nearest enclosing frame that defines the name, excluding both
            # the local frame and the global frame.
            for frame in self.env_stack[-2:0:-1]:
                if name in frame:
                    return frame
        return local

    def _assign_sequence(self, targets: List[ast.AST], value: Any) -> None:
        """Unpack ``value`` onto ``targets``, supporting one starred target."""
        # Materialise first: any iterable may be unpacked, not just sequences.
        items = list(value)
        starred = [i for i, elt in enumerate(targets) if isinstance(elt, ast.Starred)]
        if len(starred) > 1:
            raise Exception("Multiple starred expressions not supported")
        if not starred:
            if len(targets) != len(items):
                raise ValueError("Unpacking mismatch")
            for target, item in zip(targets, items):
                self.assign(target, item)
            return
        star_index = starred[0]
        tail = len(targets) - star_index - 1
        if star_index + tail > len(items):
            raise ValueError("Unpacking mismatch")
        split = len(items) - tail
        for target, item in zip(targets[:star_index], items):
            self.assign(target, item)
        # The starred name always receives a list, as in Python.
        self.assign(targets[star_index].value, items[star_index:split])
        for target, item in zip(targets[star_index + 1:], items[split:]):
            self.assign(target, item)

    def assign(self, target: ast.AST, value: Any) -> None:
        """Store ``value`` into an assignment target.

        Handles plain names (honouring ``global`` / ``nonlocal``
        declarations), tuple/list destructuring, attributes and subscripts.
        """
        if isinstance(target, ast.Name):
            self._frame_for_store(target.id)[target.id] = value
        elif isinstance(target, (ast.Tuple, ast.List)):
            self._assign_sequence(target.elts, value)
        elif isinstance(target, ast.Attribute):
            setattr(self.visit(target.value), target.attr, value)
        elif isinstance(target, ast.Subscript):
            container = self.visit(target.value)
            container[self.visit(target.slice)] = value
        else:
            raise Exception("Unsupported assignment target type: " + str(type(target)))

    def visit(self, node: ast.AST) -> Any:
        """Dispatch ``node`` to its ``visit_<NodeName>`` method.

        Control-flow signals pass through untouched.  Any other exception is
        wrapped once, at the innermost failing node, in an ``Exception`` that
        reports the source location; already wrapped errors are re-raised
        unchanged so the message is not nested at every level of the tree.
        """
        method = getattr(self, "visit_" + node.__class__.__name__, self.generic_visit)
        try:
            return method(node)
        except (ReturnException, BreakException, ContinueException):
            raise
        except Exception as e:
            if getattr(e, self._WRAPPED_MARKER, False):
                raise
            lineno = getattr(node, "lineno", None)
            col = getattr(node, "col_offset", None)
            lineno = lineno if lineno is not None else 1
            col = col if col is not None else 0
            context_line = ""
            if self.source_lines and 1 <= lineno <= len(self.source_lines):
                context_line = self.source_lines[lineno - 1]
            wrapped = Exception(
                f"Error line {lineno}, col {col}:\n{context_line}\nDescription: {str(e)}"
            )
            setattr(wrapped, self._WRAPPED_MARKER, True)
            raise wrapped from e

    def _unwrap(self, exc: BaseException) -> BaseException:
        """Return the exception originally raised underneath ``visit`` wrappers."""
        while getattr(exc, self._WRAPPED_MARKER, False) and exc.__cause__ is not None:
            exc = exc.__cause__
        return exc

    def generic_visit(self, node: ast.AST) -> Any:
        """Fallback for node types without a visitor; always raises."""
        lineno = getattr(node, "lineno", None)
        context_line = ""
        if self.source_lines and lineno is not None and 1 <= lineno <= len(self.source_lines):
            context_line = self.source_lines[lineno - 1]
        raise Exception(
            f"Unsupported AST node type: {node.__class__.__name__} at line {lineno}.\nContext: {context_line}"
        )

    # --- Imports ---
    def visit_Import(self, node: ast.Import) -> None:
        """Process ``import x`` / ``import x.y as z``; only allowed modules may be imported."""
        for alias in node.names:
            module_name: str = alias.name
            if module_name not in self.allowed_modules:
                raise Exception(f"Import Error: Module '{module_name}' is not allowed. Only {self.allowed_modules} are permitted.")
            if alias.asname is not None:
                # ``import a.b as c`` binds the submodule itself.
                self.set_variable(alias.asname, self._resolve_module(module_name))
            else:
                # ``import a.b`` binds the top-level package under ``a``.
                self.set_variable(module_name.split(".")[0], self.modules[module_name])

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Process ``from module import name``; star imports are rejected."""
        if not node.module:
            raise Exception("Import Error: Missing module name in 'from ... import ...' statement")
        if node.module not in self.allowed_modules:
            raise Exception(f"Import Error: Module '{node.module}' is not allowed. Only {self.allowed_modules} are permitted.")
        module = self._resolve_module(node.module)
        for alias in node.names:
            if alias.name == "*":
                raise Exception("Import Error: 'from ... import *' is not supported.")
            self.set_variable(alias.asname if alias.asname else alias.name, getattr(module, alias.name))

    # --- Comprehensions ---
    def _eval_elements(self, elements: List[ast.AST]) -> List[Any]:
        """Evaluate expressions in order, expanding ``*iterable`` entries."""
        values: List[Any] = []
        for element in elements:
            if isinstance(element, ast.Starred):
                values.extend(self.visit(element.value))
            else:
                values.append(self.visit(element))
        return values

    def _run_comprehension(self, generators: List[ast.comprehension], emit: Any) -> None:
        """Drive the ``for``/``if`` clauses, calling ``emit()`` per surviving combination.

        Every level runs in a copy of the current innermost frame so loop
        variables do not leak; frames are popped even when evaluation fails.
        """
        def descend(depth: int) -> None:
            if depth == len(generators):
                emit()
                return
            clause = generators[depth]
            for item in self.visit(clause.iter):
                self.env_stack.append(self.env_stack[-1].copy())
                try:
                    self.assign(clause.target, item)
                    if all(self.visit(condition) for condition in clause.ifs):
                        descend(depth + 1)
                finally:
                    self.env_stack.pop()

        self.env_stack.append(self.env_stack[-1].copy())
        try:
            descend(0)
        finally:
            self.env_stack.pop()

    def visit_ListComp(self, node: ast.ListComp) -> List[Any]:
        """Evaluate ``[elt for ... in ... if ...]``."""
        result: List[Any] = []
        self._run_comprehension(node.generators, lambda: result.append(self.visit(node.elt)))
        return result

    def visit_DictComp(self, node: ast.DictComp) -> Dict[Any, Any]:
        """Evaluate ``{key: value for ...}``."""
        result: Dict[Any, Any] = {}

        def emit() -> None:
            key = self.visit(node.key)
            result[key] = self.visit(node.value)

        self._run_comprehension(node.generators, emit)
        return result

    def visit_SetComp(self, node: ast.SetComp) -> set:
        """Evaluate ``{elt for ...}``."""
        result: set = set()
        self._run_comprehension(node.generators, lambda: result.add(self.visit(node.elt)))
        return result

    def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:
        """Return a generator for ``(elt for ...)``.

        Nothing is evaluated until the first item is requested; at that point
        all items are computed at once and then yielded one by one.
        """
        def generator():
            values: List[Any] = []
            self._run_comprehension(node.generators, lambda: values.append(self.visit(node.elt)))
            yield from values

        return generator()

    # --- Statements and expressions ---
    def visit_Module(self, node: ast.Module) -> Any:
        """Run every statement; return the global ``result`` variable if set, else the last statement's value."""
        last_value: Any = None
        for stmt in node.body:
            last_value = self.visit(stmt)
        return self.env_stack[0].get("result", last_value)

    def visit_Expr(self, node: ast.Expr) -> Any:
        return self.visit(node.value)

    def visit_Constant(self, node: ast.Constant) -> Any:
        return node.value

    def visit_Name(self, node: ast.Name) -> Any:
        if isinstance(node.ctx, ast.Load):
            return self.get_variable(node.id)
        elif isinstance(node.ctx, ast.Store):
            return node.id
        else:
            raise Exception("Unsupported context for Name")

    def visit_BinOp(self, node: ast.BinOp) -> Any:
        left: Any = self.visit(node.left)
        right: Any = self.visit(node.right)
        apply = self._BINARY_OPERATORS.get(type(node.op))
        if apply is None:
            raise Exception("Unsupported binary operator: " + str(node.op))
        return apply(left, right)

    def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
        operand: Any = self.visit(node.operand)
        apply = self._UNARY_OPERATORS.get(type(node.op))
        if apply is None:
            raise Exception("Unsupported unary operator: " + str(node.op))
        return apply(operand)

    def visit_Assign(self, node: ast.Assign) -> None:
        value: Any = self.visit(node.value)
        for target in node.targets:
            self.assign(target, value)

    def visit_AugAssign(self, node: ast.AugAssign) -> Any:
        """Evaluate ``target op= value`` using the in-place operator and return the stored value.

        For attribute and subscript targets the container (and key) is
        evaluated only once.
        """
        apply = self._AUGMENTED_OPERATORS.get(type(node.op))
        if apply is None:
            raise Exception("Unsupported augmented operator: " + str(node.op))
        target = node.target
        if isinstance(target, ast.Name):
            result: Any = apply(self.get_variable(target.id), self.visit(node.value))
            self.assign(target, result)
        elif isinstance(target, ast.Attribute):
            owner = self.visit(target.value)
            current = getattr(owner, target.attr)
            result = apply(current, self.visit(node.value))
            setattr(owner, target.attr, result)
        elif isinstance(target, ast.Subscript):
            container = self.visit(target.value)
            key = self.visit(target.slice)
            current = container[key]
            result = apply(current, self.visit(node.value))
            container[key] = result
        else:
            result = apply(self.visit(target), self.visit(node.value))
            self.assign(target, result)
        return result

    def visit_Compare(self, node: ast.Compare) -> Any:
        """Evaluate a (possibly chained) comparison.

        As in Python, a chain stops at the first falsy link and returns that
        link's result; otherwise the last link's result is returned as is, so
        objects with element-wise comparison keep their own result type.
        """
        left: Any = self.visit(node.left)
        outcome: Any = True
        last = len(node.ops) - 1
        for position, (op, comparator) in enumerate(zip(node.ops, node.comparators)):
            right: Any = self.visit(comparator)
            compare = self._COMPARE_OPERATORS.get(type(op))
            if compare is None:
                raise Exception("Unsupported comparison operator: " + str(op))
            outcome = compare(left, right)
            # Only intermediate links are truth-tested (short-circuit).
            if position < last and not outcome:
                return outcome
            left = right
        return outcome

    def visit_BoolOp(self, node: ast.BoolOp) -> Any:
        """Evaluate ``and`` / ``or`` with short-circuiting, returning the deciding operand."""
        if isinstance(node.op, ast.And):
            stop_when = False
        elif isinstance(node.op, ast.Or):
            stop_when = True
        else:
            raise Exception("Unsupported boolean operator: " + str(node.op))
        result: Any = None
        for value in node.values:
            result = self.visit(value)
            if bool(result) is stop_when:
                break
        return result

    def visit_If(self, node: ast.If) -> Any:
        """Run the selected branch and return the value of its last statement."""
        branch = node.body if self.visit(node.test) else node.orelse
        result = None
        for stmt in branch:
            result = self.visit(stmt)
        return result

    def visit_While(self, node: ast.While) -> None:
        while self.visit(node.test):
            try:
                for stmt in node.body:
                    self.visit(stmt)
            except BreakException:
                break
            except ContinueException:
                continue
        else:
            # The else clause is skipped when the loop ends through break.
            for stmt in node.orelse:
                self.visit(stmt)

    def visit_For(self, node: ast.For) -> None:
        iter_obj: Any = self.visit(node.iter)
        for item in iter_obj:
            self.assign(node.target, item)
            try:
                for stmt in node.body:
                    self.visit(stmt)
            except BreakException:
                break
            except ContinueException:
                continue
        else:
            # The else clause is skipped when the loop ends through break.
            for stmt in node.orelse:
                self.visit(stmt)

    def visit_Break(self, node: ast.Break) -> None:
        raise BreakException()

    def visit_Continue(self, node: ast.Continue) -> None:
        raise ContinueException()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Bind a ``Function`` for ``def``, applying decorators innermost first."""
        # The closure shares the frame dicts (no copies) so later changes to
        # enclosing variables stay visible to the function.
        closure: List[Dict[str, Any]] = self.env_stack[:]
        func: Any = Function(node, closure, self)
        for decorator in reversed(node.decorator_list):
            func = self.visit(decorator)(func)
        self.set_variable(node.name, func)

    def visit_Call(self, node: ast.Call) -> Any:
        """Evaluate a call, including ``*args`` and ``**kwargs`` unpacking."""
        func = self.visit(node.func)
        args: List[Any] = self._eval_elements(node.args)
        kwargs: Dict[str, Any] = {}
        for keyword in node.keywords:
            if keyword.arg is None:
                # ``**mapping`` in the call.
                kwargs.update(self.visit(keyword.value))
            else:
                kwargs[keyword.arg] = self.visit(keyword.value)
        return func(*args, **kwargs)

    def visit_Return(self, node: ast.Return) -> None:
        value: Any = self.visit(node.value) if node.value is not None else None
        raise ReturnException(value)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        closure: List[Dict[str, Any]] = self.env_stack[:]
        return LambdaFunction(node, closure, self)

    def visit_List(self, node: ast.List) -> List[Any]:
        return self._eval_elements(node.elts)

    def visit_Tuple(self, node: ast.Tuple) -> Tuple[Any, ...]:
        return tuple(self._eval_elements(node.elts))

    def visit_Dict(self, node: ast.Dict) -> Dict[Any, Any]:
        result: Dict[Any, Any] = {}
        for key_node, value_node in zip(node.keys, node.values):
            if key_node is None:
                # ``**mapping`` inside a dict display.
                result.update(self.visit(value_node))
            else:
                key = self.visit(key_node)
                result[key] = self.visit(value_node)
        return result

    def visit_Set(self, node: ast.Set) -> set:
        return set(self._eval_elements(node.elts))

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        # NOTE(security): attribute access is unrestricted, dunder names included.
        value: Any = self.visit(node.value)
        return getattr(value, node.attr)

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        value: Any = self.visit(node.value)
        slice_val: Any = self.visit(node.slice)
        return value[slice_val]

    def visit_Slice(self, node: ast.Slice) -> slice:
        lower: Any = self.visit(node.lower) if node.lower is not None else None
        upper: Any = self.visit(node.upper) if node.upper is not None else None
        step: Any = self.visit(node.step) if node.step is not None else None
        return slice(lower, upper, step)

    # For compatibility with older AST versions.  The parameter is not
    # annotated with ``ast.Index`` because that class is deprecated.
    def visit_Index(self, node: Any) -> Any:
        return self.visit(node.value)

    def visit_Pass(self, node: ast.Pass) -> None:
        # Simply ignore 'pass' statements.
        return None

    def visit_TypeIgnore(self, node: ast.TypeIgnore) -> None:
        pass

    def visit_Try(self, node: ast.Try) -> Any:
        """Execute ``try`` / ``except`` / ``else`` / ``finally``.

        Handlers are matched against the exception originally raised (not the
        location wrapper added by ``visit``), and ``as name`` binds that
        original exception.  Control-flow signals (return / break / continue)
        are never offered to handlers.  An unhandled error is re-raised with
        its location wrapper intact.
        """
        result: Any = None
        exc_info: Optional[tuple] = None

        try:
            for stmt in node.body:
                result = self.visit(stmt)
        except (ReturnException, BreakException, ContinueException):
            # Not errors: let them propagate (the finally clause still runs).
            raise
        except Exception as e:
            exc_info = (type(e), e, e.__traceback__)
            original = self._unwrap(e)
            for handler in node.handlers:
                # A bare ``except:`` matches any Exception.
                exc_type = Exception if handler.type is None else self.visit(handler.type)
                if issubclass(type(original), exc_type):
                    if handler.name:
                        self.set_variable(handler.name, original)
                    self._handled_exceptions.append(original)
                    try:
                        for stmt in handler.body:
                            result = self.visit(stmt)
                    finally:
                        self._handled_exceptions.pop()
                    exc_info = None  # Mark as handled
                    break
            if exc_info:
                raise exc_info[1]
        else:
            for stmt in node.orelse:
                result = self.visit(stmt)
        finally:
            for stmt in node.finalbody:
                try:
                    self.visit(stmt)
                except (ReturnException, BreakException, ContinueException):
                    raise
                except Exception:
                    # An error in the finally clause does not replace a
                    # pending unhandled error.
                    if exc_info:
                        raise exc_info[1]
                    raise

        return result

    def visit_Nonlocal(self, node: ast.Nonlocal) -> None:
        """Record names whose assignments must go to an enclosing function frame."""
        self.env_stack[-1].setdefault("__nonlocal_names__", set()).update(node.names)

    def visit_JoinedStr(self, node: ast.JoinedStr) -> str:
        # f-string: concatenate the literal and formatted parts.
        return "".join(self.visit(value) for value in node.values)

    def visit_FormattedValue(self, node: ast.FormattedValue) -> str:
        """Format one ``{expr!conv:spec}`` field of an f-string."""
        value: Any = self.visit(node.value)
        # ``conversion`` is -1 for none, otherwise the code point of s / r / a.
        if node.conversion == ord("r"):
            value = repr(value)
        elif node.conversion == ord("s"):
            value = str(value)
        elif node.conversion == ord("a"):
            value = ascii(value)
        spec = self.visit(node.format_spec) if node.format_spec is not None else ""
        return format(value, spec)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Create a class from its body, honouring base classes and decorators.

        The body runs in a fresh frame, so only names defined in the body
        become class attributes.
        """
        bases = tuple(self.visit(base) for base in node.bases)
        self.env_stack.append({})
        try:
            for stmt in node.body:
                self.visit(stmt)
            bookkeeping = ("__builtins__", "__global_names__", "__nonlocal_names__")
            class_dict = {
                k: v for k, v in self.env_stack[-1].items()
                if k not in bookkeeping
            }
        finally:
            self.env_stack.pop()
        new_class: Any = type(node.name, bases, class_dict)
        for decorator in reversed(node.decorator_list):
            new_class = self.visit(decorator)(new_class)
        self.set_variable(node.name, new_class)

    def visit_With(self, node: ast.With) -> None:
        """Execute a ``with`` block.

        Every entered context manager is exited, in reverse order.  Errors
        from the body are offered to ``__exit__`` (which may suppress them);
        return / break / continue count as a normal exit.
        """
        entered: List[Any] = []
        pending: Optional[BaseException] = None
        try:
            for item in node.items:
                ctx = self.visit(item.context_expr)
                val = ctx.__enter__()
                entered.append(ctx)
                if item.optional_vars is not None:
                    self.assign(item.optional_vars, val)
            for stmt in node.body:
                self.visit(stmt)
        except Exception as e:
            pending = e

        for ctx in reversed(entered):
            try:
                if pending is None or isinstance(
                    pending, (ReturnException, BreakException, ContinueException)
                ):
                    ctx.__exit__(None, None, None)
                elif ctx.__exit__(type(pending), pending, pending.__traceback__):
                    pending = None  # Suppressed by the context manager.
            except Exception as exit_error:
                # An error raised by __exit__ replaces whatever was pending.
                pending = exit_error
        if pending is not None:
            raise pending

    def visit_Raise(self, node: ast.Raise) -> None:
        """Execute ``raise``, ``raise exc`` or ``raise exc from cause``."""
        if node.exc is None:
            # Bare raise: re-raise the exception currently being handled.
            if self._handled_exceptions:
                raise self._handled_exceptions[-1]
            raise Exception("Raise with no exception specified")
        exc = self.visit(node.exc)
        if not exc:
            raise Exception("Raise with no exception specified")
        if node.cause is not None:
            raise exc from self.visit(node.cause)
        raise exc

    def visit_Global(self, node: ast.Global) -> None:
        """Record names whose assignments must go to the global frame."""
        self.env_stack[-1].setdefault("__global_names__", set()).update(node.names)

    def visit_IfExp(self, node: ast.IfExp) -> Any:
        return self.visit(node.body) if self.visit(node.test) else self.visit(node.orelse)
|
667
|
+
|
668
|
+
# Class to represent a user-defined function.
|
669
|
+
class Function:
    """Callable created for an interpreted ``def`` statement.

    Calling the object binds the arguments into a new local frame, pushes that
    frame onto a copy of the captured closure and runs the body on a spawned
    interpreter.  Positional, keyword, default, ``*args``, keyword-only and
    ``**kwargs`` parameters are supported.
    """

    def __init__(self, node: ast.FunctionDef, closure: List[Dict[str, Any]], interpreter: "ASTInterpreter") -> None:
        """Capture the definition.

        :param node: The ``def`` statement.
        :param closure: Frame stack visible at definition time.
        :param interpreter: Interpreter that executed the definition; it
            evaluates the default values and spawns the per-call interpreters.
        """
        self.node: ast.FunctionDef = node
        # Shallow copy: the list is private, the frame dicts stay shared.
        self.closure: List[Dict[str, Any]] = self.env_stack_reference(closure)
        self.interpreter: "ASTInterpreter" = interpreter
        # Defaults are evaluated once, at definition time, as in Python.
        spec = node.args
        self._defaults: List[Any] = [interpreter.visit(default) for default in spec.defaults]
        self._kw_defaults: Dict[str, Any] = {
            arg.arg: interpreter.visit(default)
            for arg, default in zip(spec.kwonlyargs, spec.kw_defaults)
            if default is not None
        }

    def env_stack_reference(self, env_stack: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return a shallow copy of ``env_stack`` (same frame dicts, new list)."""
        return env_stack[:]  # shallow

    def _bind_arguments(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Map call arguments onto parameter names.

        :raises TypeError: on missing, surplus, duplicated or unknown arguments.
        """
        spec = self.node.args
        positional_only = list(getattr(spec, "posonlyargs", []))
        names = [arg.arg for arg in positional_only + list(spec.args)]
        bound: Dict[str, Any] = {}

        if len(args) > len(names) and spec.vararg is None:
            raise TypeError("Too many arguments provided")
        bound.update(zip(names, args))
        if spec.vararg is not None:
            bound[spec.vararg.arg] = tuple(args[len(names):])

        remaining = dict(kwargs)
        # Keywords may fill positional-or-keyword parameters only.
        for name in names[len(positional_only):]:
            if name in remaining:
                if name in bound:
                    raise TypeError(f"Got multiple values for argument '{name}'")
                bound[name] = remaining.pop(name)

        # Defaults belong to the trailing positional parameters.
        first_default = len(names) - len(self._defaults)
        for index, name in enumerate(names):
            if name in bound:
                continue
            if index < first_default:
                raise TypeError("Not enough arguments provided")
            bound[name] = self._defaults[index - first_default]

        for arg in spec.kwonlyargs:
            if arg.arg in remaining:
                bound[arg.arg] = remaining.pop(arg.arg)
            elif arg.arg in self._kw_defaults:
                bound[arg.arg] = self._kw_defaults[arg.arg]
            else:
                raise TypeError(f"Missing required keyword-only argument '{arg.arg}'")

        if spec.kwarg is not None:
            bound[spec.kwarg.arg] = remaining
        elif remaining:
            raise TypeError("Unexpected keyword arguments: " + ", ".join(sorted(remaining)))
        return bound

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the function body.

        Returns the value of an executed ``return``; when the body finishes
        without one, the value of its last statement is returned.
        """
        new_env_stack: List[Dict[str, Any]] = self.closure[:]
        # Bind the function into its own local frame for recursion; a
        # parameter of the same name takes precedence.
        local_frame: Dict[str, Any] = {self.node.name: self}
        local_frame.update(self._bind_arguments(args, kwargs))
        new_env_stack.append(local_frame)
        new_interp = self.interpreter.spawn_from_env(new_env_stack)
        try:
            for stmt in self.node.body[:-1]:
                new_interp.visit(stmt)
            return new_interp.visit(self.node.body[-1])
        except ReturnException as ret:
            return ret.value

    def __get__(self, instance: Any, owner: Any):
        """Descriptor hook: bind ``instance`` as first argument for method calls."""
        if instance is None:
            # Accessed on the class itself: hand back the plain function.
            return self

        def method(*args: Any, **kwargs: Any) -> Any:
            return self(instance, *args, **kwargs)

        return method
|
709
|
+
|
710
|
+
# Class to represent a lambda function.
|
711
|
+
class LambdaFunction:
    """Callable created for an interpreted ``lambda`` expression.

    Calling the object binds the arguments into a new local frame on top of a
    copy of the captured closure and evaluates the lambda body on a spawned
    interpreter.  Positional, keyword, default, ``*args``, keyword-only and
    ``**kwargs`` parameters are supported.
    """

    def __init__(self, node: ast.Lambda, closure: List[Dict[str, Any]], interpreter: "ASTInterpreter") -> None:
        """Capture the lambda.

        :param node: The ``lambda`` expression.
        :param closure: Frame stack visible where the lambda was created.
        :param interpreter: Interpreter that evaluated the lambda; it
            evaluates the default values and spawns the per-call interpreters.
        """
        self.node: ast.Lambda = node
        self.closure: List[Dict[str, Any]] = self.env_stack_reference(closure)
        self.interpreter: "ASTInterpreter" = interpreter
        # Defaults are evaluated once, when the lambda is created, as in Python.
        spec = node.args
        self._defaults: List[Any] = [interpreter.visit(default) for default in spec.defaults]
        self._kw_defaults: Dict[str, Any] = {
            arg.arg: interpreter.visit(default)
            for arg, default in zip(spec.kwonlyargs, spec.kw_defaults)
            if default is not None
        }

    def env_stack_reference(self, env_stack: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return a shallow copy of ``env_stack`` (same frame dicts, new list)."""
        return env_stack[:]

    def _bind_arguments(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Map call arguments onto parameter names.

        :raises TypeError: on missing, surplus, duplicated or unknown arguments.
        """
        spec = self.node.args
        positional_only = list(getattr(spec, "posonlyargs", []))
        names = [arg.arg for arg in positional_only + list(spec.args)]
        bound: Dict[str, Any] = {}

        if len(args) > len(names) and spec.vararg is None:
            raise TypeError("Too many arguments for lambda")
        bound.update(zip(names, args))
        if spec.vararg is not None:
            bound[spec.vararg.arg] = tuple(args[len(names):])

        remaining = dict(kwargs)
        # Keywords may fill positional-or-keyword parameters only.
        for name in names[len(positional_only):]:
            if name in remaining:
                if name in bound:
                    raise TypeError(f"Lambda got multiple values for argument '{name}'")
                bound[name] = remaining.pop(name)

        # Defaults belong to the trailing positional parameters.
        first_default = len(names) - len(self._defaults)
        for index, name in enumerate(names):
            if name in bound:
                continue
            if index < first_default:
                raise TypeError("Not enough arguments for lambda")
            bound[name] = self._defaults[index - first_default]

        for arg in spec.kwonlyargs:
            if arg.arg in remaining:
                bound[arg.arg] = remaining.pop(arg.arg)
            elif arg.arg in self._kw_defaults:
                bound[arg.arg] = self._kw_defaults[arg.arg]
            else:
                raise TypeError(f"Lambda missing required keyword-only argument '{arg.arg}'")

        if spec.kwarg is not None:
            bound[spec.kwarg.arg] = remaining
        elif remaining:
            raise TypeError("Lambda got unexpected keyword arguments: " + ", ".join(sorted(remaining)))
        return bound

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Evaluate the lambda body with the given arguments and return its value."""
        new_env_stack: List[Dict[str, Any]] = self.closure[:]
        new_env_stack.append(self._bind_arguments(args, kwargs))
        new_interp = self.interpreter.spawn_from_env(new_env_stack)
        return new_interp.visit(self.node.body)
|
734
|
+
|
735
|
+
# The main function to interpret an AST.
|
736
|
+
def interpret_ast(ast_tree: Any, allowed_modules: list[str], source: str = "") -> Any:
    """Interpret a parsed module with imports limited to ``allowed_modules``.

    Code without ``yield`` runs on ``ASTInterpreter``.  Code containing
    ``yield`` / ``yield from`` (which the interpreter does not implement) is
    compiled and executed natively with a reduced set of builtins.

    :param ast_tree: Module produced by ``ast.parse``.
    :param allowed_modules: Names of the modules the code may import.
    :param source: Source text, used by the interpreter for error messages.
    :return: The interpreter's result, or in the native path the value of the
        top-level variable ``result`` (``None`` if unset).
    """
    uses_yield = any(
        isinstance(node, (ast.Yield, ast.YieldFrom)) for node in ast.walk(ast_tree)
    )
    if not uses_yield:
        interpreter = ASTInterpreter(allowed_modules=allowed_modules, source=source)
        return interpreter.visit(ast_tree)

    modules = {mod: __import__(mod) for mod in allowed_modules}

    def restricted_import(name, _globals=None, _locals=None, fromlist=(), level=0):
        # Same allow-list as the interpreter path: the real __import__ must
        # not be reachable from the executed code.
        if name not in modules:
            raise ImportError(
                f"Import Error: Module '{name}' is not allowed. Only {allowed_modules} are permitted."
            )
        return modules[name]

    namespace: dict = {
        "__builtins__": {
            "range": range,
            "len": len,
            "print": print,
            "__import__": restricted_import,
            "ZeroDivisionError": ZeroDivisionError,
            "ValueError": ValueError,
            "NameError": NameError,
            "TypeError": TypeError,
            "list": list,
            "dict": dict,
            "tuple": tuple,
            "set": set,
            "float": float,
            "int": int,
            "bool": bool,
            "Exception": Exception
        }
    }
    namespace.update(modules)
    # NOTE(security): exec runs the code natively; the reduced builtins and
    # restricted import narrow what it can reach but are not a sandbox.
    # A single namespace serves as globals and locals so that top-level
    # functions can see each other.
    exec(compile(ast_tree, "<string>", "exec"), namespace)
    return namespace.get("result", None)
|
771
|
+
|
772
|
+
# A helper function which takes a Python code string and a list of allowed module names,
|
773
|
+
# then parses and interprets the code.
|
774
|
+
def interpret_code(source_code: str, allowed_modules: List[str]) -> Any:
    """
    Parse and run a Python source string under an import allow-list.

    :param source_code: Python source text; common leading indentation is removed first.
    :param allowed_modules: Names of the modules the code is permitted to import.
    :return: Whatever ``interpret_ast`` returns for the parsed code.
    """
    # Normalise indentation so that indented snippets parse as top-level code.
    normalized = textwrap.dedent(source_code)
    return interpret_ast(ast.parse(normalized), allowed_modules, source=normalized)
|
786
|
+
|
787
|
+
if __name__ == "__main__":
    # Small manual demo: two snippets are run through the interpreter and the
    # outcome (or the interpreter error) is printed.
    print("Script is running!")

    # First snippet: a user-defined function plus the math module.
    source_code_1: str = """
import math
def square(x):
    return x * x

y = square(5)
z = math.sqrt(y)
z
"""
    # Only "math" is allowed here.
    try:
        result_1: Any = interpret_code(source_code_1, allowed_modules=["math"])
    except Exception as e:
        print("Interpreter error:", e)
    else:
        print("Result:", result_1)

    print("Second example:")

    # Second snippet: several list comprehensions applied to a numpy array.
    source_code_2: str = """
import math
import numpy as np
def transform_array(x):
    # Apply square root
    sqrt_vals = [math.sqrt(val) for val in x]

    # Apply sine function
    sin_vals = [math.sin(val) for val in sqrt_vals]

    # Apply exponential
    exp_vals = [math.exp(val) for val in sin_vals]

    return exp_vals

array_input = np.array([1, 4, 9, 16, 25])
result = transform_array(array_input)
result
"""
    print("About to parse source code")
    try:
        tree_2: ast.AST = ast.parse(textwrap.dedent(source_code_2))
        print("Source code parsed successfully")
        # Allow both math and numpy.
        result_2: Any = interpret_ast(
            tree_2,
            allowed_modules=["math", "numpy"],
            source=textwrap.dedent(source_code_2),
        )
        print("Result:", result_2)
    except Exception as e:
        print("Interpreter error:", e)
|