quantalogic 0.60.0__py3-none-any.whl → 0.61.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. quantalogic/agent_config.py +5 -5
  2. quantalogic/agent_factory.py +2 -2
  3. quantalogic/codeact/__init__.py +0 -0
  4. quantalogic/codeact/agent.py +499 -0
  5. quantalogic/codeact/cli.py +232 -0
  6. quantalogic/codeact/constants.py +9 -0
  7. quantalogic/codeact/events.py +78 -0
  8. quantalogic/codeact/llm_util.py +76 -0
  9. quantalogic/codeact/prompts/error_format.j2 +11 -0
  10. quantalogic/codeact/prompts/generate_action.j2 +26 -0
  11. quantalogic/codeact/prompts/generate_program.j2 +39 -0
  12. quantalogic/codeact/prompts/response_format.j2 +11 -0
  13. quantalogic/codeact/tools_manager.py +135 -0
  14. quantalogic/codeact/utils.py +135 -0
  15. quantalogic/coding_agent.py +2 -2
  16. quantalogic/python_interpreter/__init__.py +23 -0
  17. quantalogic/python_interpreter/assignment_visitors.py +63 -0
  18. quantalogic/python_interpreter/base_visitors.py +20 -0
  19. quantalogic/python_interpreter/class_visitors.py +22 -0
  20. quantalogic/python_interpreter/comprehension_visitors.py +172 -0
  21. quantalogic/python_interpreter/context_visitors.py +59 -0
  22. quantalogic/python_interpreter/control_flow_visitors.py +88 -0
  23. quantalogic/python_interpreter/exception_visitors.py +109 -0
  24. quantalogic/python_interpreter/exceptions.py +39 -0
  25. quantalogic/python_interpreter/execution.py +202 -0
  26. quantalogic/python_interpreter/function_utils.py +386 -0
  27. quantalogic/python_interpreter/function_visitors.py +209 -0
  28. quantalogic/python_interpreter/import_visitors.py +28 -0
  29. quantalogic/python_interpreter/interpreter_core.py +358 -0
  30. quantalogic/python_interpreter/literal_visitors.py +74 -0
  31. quantalogic/python_interpreter/misc_visitors.py +148 -0
  32. quantalogic/python_interpreter/operator_visitors.py +108 -0
  33. quantalogic/python_interpreter/scope.py +10 -0
  34. quantalogic/python_interpreter/visit_handlers.py +110 -0
  35. quantalogic/tools/__init__.py +5 -4
  36. quantalogic/tools/action_gen.py +366 -0
  37. quantalogic/tools/python_tool.py +13 -0
  38. quantalogic/tools/{search_definition_names.py → search_definition_names_tool.py} +2 -2
  39. quantalogic/tools/tool.py +116 -22
  40. quantalogic/utils/__init__.py +0 -1
  41. quantalogic/utils/test_python_interpreter.py +119 -0
  42. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/METADATA +7 -2
  43. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/RECORD +46 -14
  44. quantalogic/utils/python_interpreter.py +0 -905
  45. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/LICENSE +0 -0
  46. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/WHEEL +0 -0
  47. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/entry_points.txt +0 -0
@@ -1,905 +0,0 @@
1
- import ast
2
- import builtins
3
- import textwrap
4
- from typing import Any, Dict, List, Optional, Tuple
5
-
6
-
7
- # Exception used to signal a "return" from a function call.
8
class ReturnException(Exception):
    """Control-flow signal raised when interpreted code executes ``return``.

    The returned value travels with the exception in ``value`` until the
    enclosing function call catches it.
    """

    def __init__(self, value: Any) -> None:
        # BaseException already records ``value`` in ``args``; the explicit
        # super() call just makes that visible to readers.
        super().__init__(value)
        self.value: Any = value
11
-
12
-
13
- # Exceptions used for loop control.
14
class BreakException(Exception):
    """Control-flow signal raised when interpreted code executes ``break``."""
16
-
17
-
18
class ContinueException(Exception):
    """Control-flow signal raised when interpreted code executes ``continue``."""
20
-
21
-
22
- # The main interpreter class.
23
class ASTInterpreter:
    """Tree-walking interpreter for a restricted subset of Python.

    Nodes are evaluated against ``env_stack``, a list of plain-dict frames:
    ``env_stack[0]`` is the global frame and ``env_stack[-1]`` the innermost
    local frame.  Only the modules named in ``allowed_modules`` can be
    imported by interpreted code.
    """

    # Keys stored inside a frame to remember ``global`` / ``nonlocal``
    # declarations made in that frame.
    _GLOBAL_MARK = "__global_names__"
    _NONLOCAL_MARK = "__nonlocal_names__"
    # Attribute set on the wrapper exceptions created by ``visit`` so that
    # ``try``/``with`` can recover the exception that was really raised.
    _ORIGIN_ATTR = "_ast_original"

    _BINARY_OPS = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.MatMult: lambda a, b: a @ b,
        ast.Div: lambda a, b: a / b,
        ast.FloorDiv: lambda a, b: a // b,
        ast.Mod: lambda a, b: a % b,
        ast.Pow: lambda a, b: a**b,
        ast.LShift: lambda a, b: a << b,
        ast.RShift: lambda a, b: a >> b,
        ast.BitOr: lambda a, b: a | b,
        ast.BitXor: lambda a, b: a ^ b,
        ast.BitAnd: lambda a, b: a & b,
    }

    _UNARY_OPS = {
        ast.UAdd: lambda a: +a,
        ast.USub: lambda a: -a,
        ast.Not: lambda a: not a,
        ast.Invert: lambda a: ~a,
    }

    _COMPARE_OPS = {
        ast.Eq: lambda a, b: a == b,
        ast.NotEq: lambda a, b: a != b,
        ast.Lt: lambda a, b: a < b,
        ast.LtE: lambda a, b: a <= b,
        ast.Gt: lambda a, b: a > b,
        ast.GtE: lambda a, b: a >= b,
        ast.Is: lambda a, b: a is b,
        ast.IsNot: lambda a, b: a is not b,
        ast.In: lambda a, b: a in b,
        ast.NotIn: lambda a, b: a not in b,
    }

    def __init__(
        self, allowed_modules: List[str], env_stack: Optional[List[Dict[str, Any]]] = None, source: Optional[str] = None
    ) -> None:
        """
        Initialize the AST interpreter with restricted module access and environment stack.

        Args:
            allowed_modules: List of module names allowed to be imported.
            env_stack: Optional pre-existing environment stack; if None, a new one is created.
            source: Optional source code string for error reporting.
        """
        self.allowed_modules: List[str] = allowed_modules
        self.modules: Dict[str, Any] = {}
        # Import only the allowed modules.
        for mod in allowed_modules:
            self.modules[mod] = __import__(mod)

        if env_stack is None:
            # Fresh global frame: allowed modules first, then the builtins.
            self.env_stack: List[Dict[str, Any]] = [{}]
            self.env_stack[0].update(self.modules)
            self._install_builtins(self.env_stack[0])
        else:
            self.env_stack = env_stack
            global_frame = self.env_stack[0]
            # Ensure the supplied global frame has safe builtins.
            if "__builtins__" not in global_frame:
                self._install_builtins(global_frame)
            if "set" not in global_frame:
                global_frame["set"] = global_frame["__builtins__"]["set"]

        # Source lines are only used to add context to error messages.
        self.source_lines: Optional[List[str]] = source.splitlines() if source is not None else None

        # Exceptions currently being handled by an ``except`` body; the top
        # entry is what a bare ``raise`` re-raises.
        self._active_exceptions: List[BaseException] = []

        # Add standard Decimal features if allowed.
        if "decimal" in self.modules:
            dec = self.modules["decimal"]
            for name in ("Decimal", "getcontext", "setcontext", "localcontext", "Context"):
                self.env_stack[0][name] = getattr(dec, name)

    def _install_builtins(self, frame: Dict[str, Any]) -> None:
        """Expose the builtins in ``frame``, with ``__import__`` replaced by ``safe_import``."""
        safe_builtins: Dict[str, Any] = dict(vars(builtins))
        safe_builtins["__import__"] = self.safe_import
        frame["__builtins__"] = safe_builtins
        # Make builtin names (like ``set``) directly resolvable as variables.
        frame.update(safe_builtins)

    def safe_import(
        self,
        name: str,
        globals: Optional[Dict[str, Any]] = None,
        locals: Optional[Dict[str, Any]] = None,
        fromlist: Tuple[str, ...] = (),
        level: int = 0,
    ) -> Any:
        """Restrict imports to only allowed modules.

        Raises:
            ImportError: if ``name`` is not listed in ``allowed_modules``.
        """
        if name not in self.allowed_modules:
            error_msg = f"Import Error: Module '{name}' is not allowed. Only {self.allowed_modules} are permitted."
            raise ImportError(error_msg)
        return self.modules[name]

    def spawn_from_env(self, env_stack: List[Dict[str, Any]]) -> "ASTInterpreter":
        """Spawn a new interpreter with the provided environment stack."""
        return ASTInterpreter(
            self.allowed_modules, env_stack, source="\n".join(self.source_lines) if self.source_lines else None
        )

    def get_variable(self, name: str) -> Any:
        """Retrieve a variable's value, searching frames from innermost to outermost.

        Raises:
            NameError: if no frame defines ``name``.
        """
        for frame in reversed(self.env_stack):
            if name in frame:
                return frame[name]
        raise NameError(f"Name {name} is not defined.")

    def set_variable(self, name: str, value: Any) -> None:
        """Set a variable in the most local environment frame."""
        self.env_stack[-1][name] = value

    def assign(self, target: ast.AST, value: Any) -> None:
        """Assign a value to a target (name, tuple/list pattern, attribute, or subscript)."""
        if isinstance(target, ast.Name):
            self._assign_name(target.id, value)
        elif isinstance(target, (ast.Tuple, ast.List)):
            self._assign_sequence(target, value)
        elif isinstance(target, ast.Attribute):
            obj = self.visit(target.value)
            setattr(obj, target.attr, value)
        elif isinstance(target, ast.Subscript):
            obj = self.visit(target.value)
            key = self.visit(target.slice)
            obj[key] = value
        else:
            raise Exception("Unsupported assignment target type: " + str(type(target)))

    def _assign_name(self, name: str, value: Any) -> None:
        """Bind ``name``, honouring ``global`` / ``nonlocal`` declarations of the current frame."""
        frame = self.env_stack[-1]
        if name in frame.get(self._GLOBAL_MARK, ()):
            self.env_stack[0][name] = value
            return
        if name in frame.get(self._NONLOCAL_MARK, ()):
            # Closures share frame dicts by reference, so writing into the
            # nearest enclosing (non-global) frame is visible to the owner.
            for outer in reversed(self.env_stack[1:-1]):
                if name in outer:
                    outer[name] = value
                    return
        frame[name] = value

    def _assign_sequence(self, target: Any, value: Any) -> None:
        """Destructure ``value`` onto a tuple/list pattern with at most one starred element."""
        star_positions = [i for i, elt in enumerate(target.elts) if isinstance(elt, ast.Starred)]
        if len(star_positions) > 1:
            raise Exception("Multiple starred expressions not supported")
        # Materialize once so any iterable (map, zip, generator, ...) can be unpacked.
        values = list(value)
        if not star_positions:
            if len(target.elts) != len(values):
                raise ValueError("Unpacking mismatch")
            for sub_target, item in zip(target.elts, values):
                self.assign(sub_target, item)
            return
        star = star_positions[0]
        before = target.elts[:star]
        after = target.elts[star + 1 :]
        if len(before) + len(after) > len(values):
            raise ValueError("Unpacking mismatch")
        tail_start = len(values) - len(after)
        for sub_target, item in zip(before, values):
            self.assign(sub_target, item)
        # As in Python, the starred name always receives a list.
        self.assign(target.elts[star].value, values[len(before) : tail_start])
        for sub_target, item in zip(after, values[tail_start:]):
            self.assign(sub_target, item)

    def visit(self, node: ast.AST) -> Any:
        """Dispatch to the appropriate ``visit_*`` method for the AST node.

        Control-flow signals pass through untouched.  Any other exception is
        re-raised as a generic ``Exception`` whose message carries the source
        location; the exception that was originally raised is kept on the
        wrapper so ``try``/``with`` handling can still see its real type.
        """
        method = getattr(self, "visit_" + node.__class__.__name__, self.generic_visit)
        try:
            return method(node)
        except (ReturnException, BreakException, ContinueException):
            raise
        except Exception as e:
            lineno = getattr(node, "lineno", None)
            col = getattr(node, "col_offset", None)
            lineno = lineno if lineno is not None else 1
            col = col if col is not None else 0
            context_line = ""
            if self.source_lines and 1 <= lineno <= len(self.source_lines):
                context_line = self.source_lines[lineno - 1]
            wrapped = Exception(f"Error line {lineno}, col {col}:\n{context_line}\nDescription: {str(e)}")
            setattr(wrapped, self._ORIGIN_ATTR, getattr(e, self._ORIGIN_ATTR, e))
            raise wrapped from e

    def generic_visit(self, node: ast.AST) -> Any:
        """Handle unsupported AST nodes with an error."""
        lineno = getattr(node, "lineno", None)
        context_line = ""
        if self.source_lines and lineno is not None and 1 <= lineno <= len(self.source_lines):
            context_line = self.source_lines[lineno - 1]
        raise Exception(
            f"Unsupported AST node type: {node.__class__.__name__} at line {lineno}.\nContext: {context_line}"
        )

    # --- Imports ---
    def visit_Import(self, node: ast.Import) -> None:
        """Process an ``import`` statement; only allowed modules can be imported."""
        for alias in node.names:
            module_name: str = alias.name
            asname: str = alias.asname if alias.asname is not None else module_name
            if module_name not in self.allowed_modules:
                raise Exception(
                    f"Import Error: Module '{module_name}' is not allowed. Only {self.allowed_modules} are permitted."
                )
            self.set_variable(asname, self.modules[module_name])

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Process 'from ... import ...' statements with restricted module access."""
        if not node.module:
            raise Exception("Import Error: Missing module name in 'from ... import ...' statement")
        if node.module not in self.allowed_modules:
            raise Exception(
                f"Import Error: Module '{node.module}' is not allowed. Only {self.allowed_modules} are permitted."
            )
        for alias in node.names:
            if alias.name == "*":
                raise Exception("Import Error: 'from ... import *' is not supported.")
            asname = alias.asname if alias.asname else alias.name
            attr = getattr(self.modules[node.module], alias.name)
            self.set_variable(asname, attr)

    # --- Comprehensions ---
    def _iterate_comprehension(self, generators: List[ast.comprehension], emit: Any) -> None:
        """Run the ``for``/``if`` clauses of a comprehension, calling ``emit()`` per surviving item.

        The comprehension runs in a copy of the current frame so that its loop
        variables do not leak; frames are popped even when evaluation fails.
        """
        self.env_stack.append(self.env_stack[-1].copy())
        try:
            self._comprehension_level(generators, 0, emit)
        finally:
            self.env_stack.pop()

    def _comprehension_level(self, generators: List[ast.comprehension], index: int, emit: Any) -> None:
        """Evaluate generator clause ``index`` and recurse into the next one."""
        if index == len(generators):
            emit()
            return
        comp = generators[index]
        for item in self.visit(comp.iter):
            # Each iteration gets its own frame inheriting the comprehension scope.
            self.env_stack.append(self.env_stack[-1].copy())
            try:
                self.assign(comp.target, item)
                if all(self.visit(if_clause) for if_clause in comp.ifs):
                    self._comprehension_level(generators, index + 1, emit)
            finally:
                self.env_stack.pop()

    def visit_ListComp(self, node: ast.ListComp) -> List[Any]:
        """Process a list comprehension, e.g., [elt for ... in ... if ...]."""
        result: List[Any] = []
        self._iterate_comprehension(node.generators, lambda: result.append(self.visit(node.elt)))
        return result

    def visit_SetComp(self, node: ast.SetComp) -> set:
        """Handle set comprehensions."""
        result: set = set()
        self._iterate_comprehension(node.generators, lambda: result.add(self.visit(node.elt)))
        return result

    def visit_DictComp(self, node: ast.DictComp) -> Dict[Any, Any]:
        """Handle dictionary comprehensions."""
        result: Dict[Any, Any] = {}

        def emit() -> None:
            key = self.visit(node.key)
            result[key] = self.visit(node.value)

        self._iterate_comprehension(node.generators, emit)
        return result

    def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:
        """Handle generator expressions.

        Nothing is evaluated until the first item is requested; at that point
        all items are computed eagerly and then yielded one by one.
        """

        def generator():
            produced: List[Any] = []
            self._iterate_comprehension(node.generators, lambda: produced.append(self.visit(node.elt)))
            yield from produced

        return generator()

    # --- Statements and expressions ---
    def visit_Module(self, node: ast.Module) -> Any:
        """Execute module body and return the global 'result' variable, or the last statement's value."""
        last_value: Any = None
        for stmt in node.body:
            last_value = self.visit(stmt)
        return self.env_stack[0].get("result", last_value)

    def visit_Expr(self, node: ast.Expr) -> Any:
        """Evaluate an expression statement."""
        return self.visit(node.value)

    def visit_Constant(self, node: ast.Constant) -> Any:
        """Return the value of a constant node."""
        return node.value

    def visit_Name(self, node: ast.Name) -> Any:
        """Handle variable name lookups (Load) or return the identifier (Store)."""
        if isinstance(node.ctx, ast.Load):
            return self.get_variable(node.id)
        if isinstance(node.ctx, ast.Store):
            return node.id
        raise Exception("Unsupported context for Name")

    def visit_BinOp(self, node: ast.BinOp) -> Any:
        """Evaluate binary operations."""
        left: Any = self.visit(node.left)
        right: Any = self.visit(node.right)
        handler = self._BINARY_OPS.get(type(node.op))
        if handler is None:
            raise Exception("Unsupported binary operator: " + str(node.op))
        return handler(left, right)

    def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
        """Evaluate unary operations."""
        operand: Any = self.visit(node.operand)
        handler = self._UNARY_OPS.get(type(node.op))
        if handler is None:
            raise Exception("Unsupported unary operator: " + str(node.op))
        return handler(operand)

    def visit_Assign(self, node: ast.Assign) -> None:
        """Handle assignment statements (the value is evaluated once for all targets)."""
        value: Any = self.visit(node.value)
        for target in node.targets:
            self.assign(target, value)

    def visit_AugAssign(self, node: ast.AugAssign) -> Any:
        """Handle augmented assignments (e.g., +=) using in-place operator semantics.

        In-place operators matter for mutable objects: ``b += [2]`` must
        extend the list that other names may still refer to.
        """
        import operator

        inplace_ops = {
            ast.Add: operator.iadd,
            ast.Sub: operator.isub,
            ast.Mult: operator.imul,
            ast.MatMult: operator.imatmul,
            ast.Div: operator.itruediv,
            ast.FloorDiv: operator.ifloordiv,
            ast.Mod: operator.imod,
            ast.Pow: operator.ipow,
            ast.BitAnd: operator.iand,
            ast.BitOr: operator.ior,
            ast.BitXor: operator.ixor,
            ast.LShift: operator.ilshift,
            ast.RShift: operator.irshift,
        }
        # A Name target has Store context, so read it straight from the environment.
        if isinstance(node.target, ast.Name):
            current_val: Any = self.get_variable(node.target.id)
        else:
            current_val = self.visit(node.target)
        right_val: Any = self.visit(node.value)
        handler = inplace_ops.get(type(node.op))
        if handler is None:
            raise Exception("Unsupported augmented operator: " + str(node.op))
        result: Any = handler(current_val, right_val)
        self.assign(node.target, result)
        return result

    def visit_Compare(self, node: ast.Compare) -> bool:
        """Evaluate (possibly chained) comparison operations."""
        left: Any = self.visit(node.left)
        for op, comparator in zip(node.ops, node.comparators):
            right: Any = self.visit(comparator)
            check = self._COMPARE_OPS.get(type(op))
            if check is None:
                raise Exception("Unsupported comparison operator: " + str(op))
            if not check(left, right):
                return False
            left = right
        return True

    def visit_BoolOp(self, node: ast.BoolOp) -> Any:
        """Evaluate ``and`` / ``or`` with short-circuiting.

        As in Python, the result is the operand that decided the outcome, not
        a coerced bool (``0 or "x"`` is ``"x"``).
        """
        if isinstance(node.op, ast.And):
            stop_on_truthy = False
        elif isinstance(node.op, ast.Or):
            stop_on_truthy = True
        else:
            raise Exception("Unsupported boolean operator: " + str(node.op))
        result: Any = None
        for value in node.values:
            result = self.visit(value)
            if bool(result) is stop_on_truthy:
                return result
        return result

    def visit_If(self, node: ast.If) -> Any:
        """Handle if statements; returns the value of the last statement of the chosen branch."""
        branch = node.body if self.visit(node.test) else node.orelse
        result = None
        for stmt in branch:
            result = self.visit(stmt)
        return result

    def visit_While(self, node: ast.While) -> None:
        """Handle while loops; the ``else`` clause runs only when the loop was not broken."""
        while self.visit(node.test):
            try:
                for stmt in node.body:
                    self.visit(stmt)
            except BreakException:
                break
            except ContinueException:
                continue
        else:
            for stmt in node.orelse:
                self.visit(stmt)

    def visit_For(self, node: ast.For) -> None:
        """Handle for loops; the ``else`` clause runs only when the loop was not broken."""
        for item in self.visit(node.iter):
            self.assign(node.target, item)
            try:
                for stmt in node.body:
                    self.visit(stmt)
            except BreakException:
                break
            except ContinueException:
                continue
        else:
            for stmt in node.orelse:
                self.visit(stmt)

    def visit_Break(self, node: ast.Break) -> None:
        """Handle break statements."""
        raise BreakException()

    def visit_Continue(self, node: ast.Continue) -> None:
        """Handle continue statements."""
        raise ContinueException()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Define a function and store it in the environment."""
        # Capture the current env_stack for a closure without copying inner dicts.
        closure: List[Dict[str, Any]] = self.env_stack[:]
        func = Function(node, closure, self)
        self.set_variable(node.name, func)

    def _eval_elements(self, elts: List[ast.expr]) -> List[Any]:
        """Evaluate a list of expressions, expanding ``*iterable`` entries in place."""
        values: List[Any] = []
        for elt in elts:
            if isinstance(elt, ast.Starred):
                values.extend(self.visit(elt.value))
            else:
                values.append(self.visit(elt))
        return values

    def visit_Call(self, node: ast.Call) -> Any:
        """Handle function calls, including ``*args`` and ``**kwargs`` unpacking."""
        func = self.visit(node.func)
        args: List[Any] = self._eval_elements(node.args)
        kwargs: Dict[str, Any] = {}
        for kw in node.keywords:
            if kw.arg is None:
                # ``**mapping`` appears as a keyword without a name.
                kwargs.update(self.visit(kw.value))
            else:
                kwargs[kw.arg] = self.visit(kw.value)
        return func(*args, **kwargs)

    def visit_Return(self, node: ast.Return) -> None:
        """Handle return statements by raising ``ReturnException``."""
        value: Any = self.visit(node.value) if node.value is not None else None
        raise ReturnException(value)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        """Define a lambda function."""
        closure: List[Dict[str, Any]] = self.env_stack[:]
        return LambdaFunction(node, closure, self)

    def visit_List(self, node: ast.List) -> List[Any]:
        """Evaluate list literals."""
        return self._eval_elements(node.elts)

    def visit_Tuple(self, node: ast.Tuple) -> Tuple[Any, ...]:
        """Evaluate tuple literals."""
        return tuple(self._eval_elements(node.elts))

    def visit_Dict(self, node: ast.Dict) -> Dict[Any, Any]:
        """Evaluate dictionary literals, including ``**mapping`` entries."""
        result: Dict[Any, Any] = {}
        for key_node, value_node in zip(node.keys, node.values):
            if key_node is None:
                # ``{**other}`` is encoded as a ``None`` key.
                result.update(self.visit(value_node))
            else:
                key = self.visit(key_node)
                result[key] = self.visit(value_node)
        return result

    def visit_Set(self, node: ast.Set) -> set:
        """Evaluate set literals."""
        return set(self._eval_elements(node.elts))

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Handle attribute access."""
        value: Any = self.visit(node.value)
        return getattr(value, node.attr)

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        """Handle subscript operations (e.g., list indexing)."""
        value: Any = self.visit(node.value)
        slice_val: Any = self.visit(node.slice)
        return value[slice_val]

    def visit_Slice(self, node: ast.Slice) -> slice:
        """Evaluate slice objects."""
        lower: Any = self.visit(node.lower) if node.lower else None
        upper: Any = self.visit(node.upper) if node.upper else None
        step: Any = self.visit(node.step) if node.step else None
        return slice(lower, upper, step)

    def visit_Index(self, node: ast.AST) -> Any:
        """Handle ``Index`` nodes produced by older AST versions."""
        # Annotated as ast.AST so the deprecated ast.Index name is not
        # touched when the class is defined.
        return self.visit(node.value)

    def visit_Pass(self, node: ast.Pass) -> None:
        """Handle pass statements (do nothing)."""
        return None

    def visit_TypeIgnore(self, node: ast.TypeIgnore) -> None:
        """Handle type ignore statements (do nothing)."""
        return None

    def visit_Try(self, node: ast.Try) -> Any:
        """Handle try/except/else/finally blocks.

        Handlers are matched against the exception that was originally raised
        rather than the location wrapper added by ``visit``.  ``return``,
        ``break`` and ``continue`` signals are never treated as errors.
        """
        result: Any = None
        try:
            try:
                for stmt in node.body:
                    result = self.visit(stmt)
            except (ReturnException, BreakException, ContinueException):
                raise
            except Exception as exc:
                original = getattr(exc, self._ORIGIN_ATTR, exc)
                for handler in node.handlers:
                    if handler.type is None:
                        exc_type: Any = Exception
                    elif isinstance(handler.type, ast.Name):
                        exc_type = self.get_variable(handler.type.id)
                    else:
                        exc_type = self.visit(handler.type)
                    if issubclass(type(original), exc_type):
                        if handler.name:
                            self.set_variable(handler.name, original)
                        # Track the exception so a bare ``raise`` can re-raise it.
                        self._active_exceptions.append(original)
                        try:
                            for stmt in handler.body:
                                result = self.visit(stmt)
                        finally:
                            self._active_exceptions.pop()
                        break
                else:
                    # No handler matched: propagate unchanged.
                    raise
            else:
                for stmt in node.orelse:
                    result = self.visit(stmt)
        finally:
            for stmt in node.finalbody:
                self.visit(stmt)
        return result

    def visit_Nonlocal(self, node: ast.Nonlocal) -> None:
        """Record names declared ``nonlocal`` so assignments update the enclosing frame."""
        self.env_stack[-1].setdefault(self._NONLOCAL_MARK, set()).update(node.names)

    def visit_JoinedStr(self, node: ast.JoinedStr) -> str:
        """Handle f-strings by concatenating all parts."""
        return "".join(self.visit(value) for value in node.values)

    def visit_FormattedValue(self, node: ast.FormattedValue) -> str:
        """Handle ``{value!conv:spec}`` fields within f-strings."""
        value: Any = self.visit(node.value)
        # ``conversion`` is -1 when absent, otherwise the code point of s/r/a.
        if node.conversion == ord("s"):
            value = str(value)
        elif node.conversion == ord("r"):
            value = repr(value)
        elif node.conversion == ord("a"):
            value = ascii(value)
        spec = self.visit(node.format_spec) if node.format_spec is not None else ""
        return format(value, spec)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Handle class definitions (bases honoured; decorators and keywords ignored)."""
        bases = tuple(self.visit(base) for base in node.bases)
        # The body runs in a fresh frame so that only names defined inside the
        # class end up as class attributes; outer names stay resolvable
        # through the rest of the stack.
        class_frame: Dict[str, Any] = {}
        self.env_stack.append(class_frame)
        try:
            for stmt in node.body:
                self.visit(stmt)
        finally:
            self.env_stack.pop()
        hidden = ("__builtins__", self._GLOBAL_MARK, self._NONLOCAL_MARK)
        class_dict = {k: v for k, v in class_frame.items() if k not in hidden}
        new_class = type(node.name, bases, class_dict)
        self.set_variable(node.name, new_class)

    def visit_With(self, node: ast.With) -> None:
        """Handle with statements; the body runs once inside all context managers."""
        self._run_with(node.items, node.body)

    def _run_with(self, items: List[ast.withitem], body: List[ast.stmt]) -> None:
        """Enter ``items[0]``, run the remaining items and the body, then exit it."""
        if not items:
            for stmt in body:
                self.visit(stmt)
            return
        item = items[0]
        ctx = self.visit(item.context_expr)
        val = ctx.__enter__()
        try:
            if item.optional_vars:
                self.assign(item.optional_vars, val)
            self._run_with(items[1:], body)
        except (ReturnException, BreakException, ContinueException):
            # Leaving the block via return/break/continue is a normal exit.
            ctx.__exit__(None, None, None)
            raise
        except Exception as e:
            original = getattr(e, self._ORIGIN_ATTR, e)
            if not ctx.__exit__(type(original), original, original.__traceback__):
                raise
        else:
            ctx.__exit__(None, None, None)

    def visit_Raise(self, node: ast.Raise) -> None:
        """Handle raise statements, including ``raise ... from ...`` and bare re-raise."""
        if node.exc is None:
            if self._active_exceptions:
                raise self._active_exceptions[-1]
            raise Exception("Raise with no exception specified")
        exc = self.visit(node.exc)
        if exc is None:
            raise Exception("Raise with no exception specified")
        if node.cause is not None:
            raise exc from self.visit(node.cause)
        raise exc

    def visit_Global(self, node: ast.Global) -> None:
        """Record names declared ``global`` so assignments update the global frame."""
        self.env_stack[-1].setdefault(self._GLOBAL_MARK, set()).update(node.names)

    def visit_IfExp(self, node: ast.IfExp) -> Any:
        """Handle ternary if expressions."""
        return self.visit(node.body) if self.visit(node.test) else self.visit(node.orelse)
722
-
723
-
724
- # Class to represent a user-defined function.
725
class Function:
    """Callable wrapper around an interpreted ``def`` statement.

    Calling the object binds the arguments into a new local frame, spawns a
    child interpreter on top of the captured closure and executes the body.
    """

    def __init__(self, node: ast.FunctionDef, closure: List[Dict[str, Any]], interpreter: ASTInterpreter) -> None:
        """Initialize a user-defined function.

        Args:
            node: The ``FunctionDef`` node being wrapped.
            closure: Environment stack visible at definition time.
            interpreter: Interpreter that defined the function; used to
                evaluate default values and to spawn the call interpreter.
        """
        self.node: ast.FunctionDef = node
        # Shallow copy to support recursion.
        self.closure: List[Dict[str, Any]] = self.env_stack_reference(closure)
        self.interpreter: ASTInterpreter = interpreter
        # As in Python, default values are evaluated once, at definition time.
        self.defaults: List[Any] = [interpreter.visit(d) for d in node.args.defaults]
        self.kw_defaults: List[Any] = [
            None if d is None else interpreter.visit(d) for d in node.args.kw_defaults
        ]

    def env_stack_reference(self, env_stack: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return a shallow copy of the environment stack (frames are shared, the list is not)."""
        return env_stack[:]

    def _bind_arguments(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Map call arguments onto parameter names.

        Supports positional and keyword arguments, defaults, ``*args``,
        keyword-only parameters and ``**kwargs``.

        Raises:
            TypeError: on missing, duplicated or unexpected arguments.
        """
        spec = self.node.args
        posonly = list(getattr(spec, "posonlyargs", []))
        positional = posonly + list(spec.args)
        frame: Dict[str, Any] = {}

        if len(args) > len(positional) and spec.vararg is None:
            raise TypeError("Too many arguments provided")
        for param, value in zip(positional, args):
            frame[param.arg] = value
        if spec.vararg is not None:
            frame[spec.vararg.arg] = tuple(args[len(positional) :])

        remaining = dict(kwargs)
        # Defaults align with the *last* positional parameters.
        first_default = len(positional) - len(self.defaults)
        for index, param in enumerate(positional):
            name = param.arg
            by_keyword = index >= len(posonly) and name in remaining
            if index < len(args):
                if by_keyword:
                    raise TypeError(f"Multiple values for argument '{name}'")
            elif by_keyword:
                frame[name] = remaining.pop(name)
            elif index >= first_default:
                frame[name] = self.defaults[index - first_default]
            else:
                raise TypeError("Not enough arguments provided")

        for param, default_node, default in zip(spec.kwonlyargs, spec.kw_defaults, self.kw_defaults):
            name = param.arg
            if name in remaining:
                frame[name] = remaining.pop(name)
            elif default_node is not None:
                frame[name] = default
            else:
                raise TypeError(f"Missing keyword-only argument '{name}'")

        if spec.kwarg is not None:
            frame[spec.kwarg.arg] = remaining
        elif remaining:
            raise TypeError(f"Unexpected keyword arguments: {sorted(remaining)}")
        return frame

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the function with given arguments.

        Returns the value carried by a ``return`` statement, or otherwise the
        value produced by the last statement of the body.
        """
        # Bind the function into its own local frame for recursion; a
        # parameter with the same name takes precedence.
        local_frame: Dict[str, Any] = {self.node.name: self}
        local_frame.update(self._bind_arguments(args, kwargs))
        new_env_stack: List[Dict[str, Any]] = self.closure[:]
        new_env_stack.append(local_frame)
        new_interp: ASTInterpreter = self.interpreter.spawn_from_env(new_env_stack)
        try:
            for stmt in self.node.body[:-1]:
                new_interp.visit(stmt)
            return new_interp.visit(self.node.body[-1])
        except ReturnException as ret:
            return ret.value

    def __get__(self, instance: Any, owner: Any):
        """Support method binding: accessed through an instance, prepend it as ``self``."""
        if instance is None:
            # Accessed on the class itself: behave like a plain function.
            return self

        def method(*args: Any, **kwargs: Any) -> Any:
            return self(instance, *args, **kwargs)

        return method
769
-
770
-
771
- # Class to represent a lambda function.
772
class LambdaFunction:
    """Callable wrapper around an interpreted ``lambda`` expression."""

    def __init__(self, node: ast.Lambda, closure: List[Dict[str, Any]], interpreter: ASTInterpreter) -> None:
        """Initialize a lambda function.

        Args:
            node: The ``Lambda`` node being wrapped.
            closure: Environment stack visible where the lambda was created.
            interpreter: Interpreter that created the lambda; used to evaluate
                default values and to spawn the call interpreter.
        """
        self.node: ast.Lambda = node
        self.closure: List[Dict[str, Any]] = self.env_stack_reference(closure)
        self.interpreter: ASTInterpreter = interpreter
        # As in Python, default values are evaluated once, at creation time.
        self.defaults: List[Any] = [interpreter.visit(d) for d in node.args.defaults]
        self.kw_defaults: List[Any] = [
            None if d is None else interpreter.visit(d) for d in node.args.kw_defaults
        ]

    def env_stack_reference(self, env_stack: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return a shallow copy of the environment stack (frames are shared, the list is not)."""
        return env_stack[:]

    def _bind_arguments(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Map call arguments onto the lambda's parameter names.

        Supports positional and keyword arguments, defaults, ``*args``,
        keyword-only parameters and ``**kwargs``.

        Raises:
            TypeError: on missing, duplicated or unexpected arguments.
        """
        spec = self.node.args
        posonly = list(getattr(spec, "posonlyargs", []))
        positional = posonly + list(spec.args)
        frame: Dict[str, Any] = {}

        if len(args) > len(positional) and spec.vararg is None:
            raise TypeError("Too many arguments for lambda")
        for param, value in zip(positional, args):
            frame[param.arg] = value
        if spec.vararg is not None:
            frame[spec.vararg.arg] = tuple(args[len(positional) :])

        remaining = dict(kwargs)
        # Defaults align with the *last* positional parameters.
        first_default = len(positional) - len(self.defaults)
        for index, param in enumerate(positional):
            name = param.arg
            by_keyword = index >= len(posonly) and name in remaining
            if index < len(args):
                if by_keyword:
                    raise TypeError(f"Multiple values for lambda argument '{name}'")
            elif by_keyword:
                frame[name] = remaining.pop(name)
            elif index >= first_default:
                frame[name] = self.defaults[index - first_default]
            else:
                raise TypeError("Not enough arguments for lambda")

        for param, default_node, default in zip(spec.kwonlyargs, spec.kw_defaults, self.kw_defaults):
            name = param.arg
            if name in remaining:
                frame[name] = remaining.pop(name)
            elif default_node is not None:
                frame[name] = default
            else:
                raise TypeError(f"Missing keyword-only lambda argument '{name}'")

        if spec.kwarg is not None:
            frame[spec.kwarg.arg] = remaining
        elif remaining:
            raise TypeError(f"Unexpected keyword arguments for lambda: {sorted(remaining)}")
        return frame

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the lambda with given arguments and return the value of its body expression."""
        local_frame: Dict[str, Any] = self._bind_arguments(args, kwargs)
        new_env_stack: List[Dict[str, Any]] = self.closure[:]
        new_env_stack.append(local_frame)
        new_interp: ASTInterpreter = self.interpreter.spawn_from_env(new_env_stack)
        return new_interp.visit(self.node.body)
798
-
799
-
800
# The main function to interpret an AST.
def interpret_ast(ast_tree: Any, allowed_modules: list[str], source: str = "") -> Any:
    """Interpret an AST with restricted module access.

    Trees containing ``yield`` / ``yield from`` are compiled and run with the
    native ``exec`` against a reduced builtins table; every other tree is
    evaluated by ``ASTInterpreter``.

    Args:
        ast_tree: The parsed module to run.
        allowed_modules: Names of the modules the code may import. In the
            fallback path each one is also pre-imported into the namespace.
        source: Original source text, forwarded to ``ASTInterpreter``.

    Returns:
        For the fallback path, the value bound to the top-level name
        ``result`` (``None`` if absent); otherwise the interpreter's result.

    Raises:
        ImportError: In the fallback path, when the code imports a module
            that is not covered by ``allowed_modules``.
    """
    uses_yield = any(
        isinstance(node, (ast.Yield, ast.YieldFrom)) for node in ast.walk(ast_tree)
    )
    if not uses_yield:
        interpreter = ASTInterpreter(allowed_modules=allowed_modules, source=source)
        return interpreter.visit(ast_tree)

    permitted = tuple(allowed_modules)

    def _restricted_import(name, globals_=None, locals_=None, fromlist=(), level=0):
        """Import hook that only lets through modules covered by the allow-list."""
        if level == 0 and any(
            name == mod or name.startswith(mod + ".") for mod in permitted
        ):
            return __import__(name, globals_, locals_, fromlist, level)
        raise ImportError(f"Import of module '{name}' is not allowed")

    # NOTE(security): this fallback runs the code with the native exec, outside
    # the AST interpreter. Only the builtins listed here are reachable, and the
    # import hook enforces allowed_modules, but it is not a hardened sandbox.
    namespace: dict = {
        "__builtins__": {
            "range": range,
            "len": len,
            "print": print,
            "__import__": _restricted_import,
            "ZeroDivisionError": ZeroDivisionError,
            "ValueError": ValueError,
            "NameError": NameError,
            "TypeError": TypeError,
            "list": list,
            "dict": dict,
            "tuple": tuple,
            "set": set,
            "float": float,
            "int": int,
            "bool": bool,
            "Exception": Exception,
        }
    }
    for mod in allowed_modules:
        namespace[mod] = __import__(mod)

    # A single dict serves as both globals and locals: with two separate dicts
    # top-level definitions land in the locals and are invisible from inside
    # other functions (class-body scoping), so helpers could not call each other.
    exec(compile(ast_tree, "<string>", "exec"), namespace)
    return namespace.get("result", None)
837
-
838
-
839
# Convenience entry point: source text in, interpreted result out.
def interpret_code(source_code: str, allowed_modules: List[str]) -> Any:
    """
    Parse and interpret Python source text under a module allow-list.

    Args:
        source_code: The Python source code to interpret; common leading
            indentation is stripped before parsing.
        allowed_modules: A list of module names that are allowed.
    Returns:
        Whatever ``interpret_ast`` returns for the parsed tree.
    """
    normalized = textwrap.dedent(source_code)
    return interpret_ast(ast.parse(normalized), allowed_modules, source=normalized)
855
-
856
-
857
# Demo entry point: when the module is run as a script, push two sample
# programs through the interpreter and print what comes back.
if __name__ == "__main__":
    print("Script is running!")
    # First sample: defines a function, calls it, and ends on a bare `z`
    # expression (presumably the value the interpreter hands back — confirm
    # against ASTInterpreter).
    source_code_1: str = """
    import math
    def square(x):
        return x * x

    y = square(5)
    z = math.sqrt(y)
    z
    """
    # Only "math" is allowed here.
    try:
        result_1: Any = interpret_code(source_code_1, allowed_modules=["math"])
        print("Result:", result_1)
    except Exception as e:
        # Demo boundary: report any failure instead of aborting the script.
        print("Interpreter error:", e)

    print("Second example:")

    # Define the source code with multiple operations and a list comprehension.
    source_code_2: str = """
    import math
    import numpy as np
    def transform_array(x):
        # Apply square root
        sqrt_vals = [math.sqrt(val) for val in x]

        # Apply sine function
        sin_vals = [math.sin(val) for val in sqrt_vals]

        # Apply exponential
        exp_vals = [math.exp(val) for val in sin_vals]

        return exp_vals

    array_input = np.array([1, 4, 9, 16, 25])
    result = transform_array(array_input)
    result
    """
    print("About to parse source code")
    try:
        # This sample is parsed by hand and passed to interpret_ast directly
        # rather than going through interpret_code.
        tree_2: ast.AST = ast.parse(textwrap.dedent(source_code_2))
        print("Source code parsed successfully")
        # Allow both math and numpy.
        result_2: Any = interpret_ast(tree_2, allowed_modules=["math", "numpy"], source=textwrap.dedent(source_code_2))
        print("Result:", result_2)
    except Exception as e:
        # Covers parse errors as well as interpreter failures.
        print("Interpreter error:", e)