quantalogic 0.33.0__py3-none-any.whl → 0.33.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,835 @@
import ast
import builtins
import operator
import sys
import textwrap
import types
from typing import Any, Callable, Dict, List, Optional, Tuple

# Exception used to signal a "return" from a function call.
class ReturnException(Exception):
    """Carry the value of an interpreted ``return`` statement to the caller.

    Attributes:
        value: The returned value (``None`` for a bare ``return``).
    """

    def __init__(self, value: Any) -> None:
        # Forward the value to Exception so ``args`` is populated: pickle and
        # copy rebuild exceptions as ``cls(*args)``, which needs the value.
        super().__init__(value)
        self.value: Any = value

# Exceptions used for loop control.
class BreakException(Exception):
    """Signal raised by an interpreted ``break``; the enclosing loop visitor stops iterating."""

class ContinueException(Exception):
    """Signal raised by an interpreted ``continue``; the loop visitor skips to the next iteration."""

# The main interpreter class.
class ASTInterpreter:
    """Tree-walking evaluator for a restricted subset of Python.

    Variables live in ``env_stack``: a list of dict frames where index 0 is the
    global frame and the last entry is the current (innermost) scope.  Interpreted
    ``import`` statements may only name modules listed in ``allowed_modules``.

    Any error raised while evaluating a node is re-raised once, as a plain
    ``Exception`` whose message gives the line, column and source text of the
    innermost failing node; the original exception is kept as ``__cause__``.

    NOTE(security): this is not a sandbox.  The global frame exposes every
    builtin (``open``, ``eval``, ``exec``, ``getattr``, ...), so interpreted code
    can still reach the real import machinery and the filesystem.  Untrusted code
    needs an additional isolation boundary (separate process/container).
    """

    # Bookkeeping keys stored in frames; never copied into class namespaces.
    _SCOPE_MARKERS: Tuple[str, ...] = ("__builtins__", "__global_names__", "__nonlocal_names__")

    _BINARY_OPERATORS: Dict[type, Callable[[Any, Any], Any]] = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.MatMult: operator.matmul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
        ast.LShift: operator.lshift,
        ast.RShift: operator.rshift,
        ast.BitOr: operator.or_,
        ast.BitXor: operator.xor,
        ast.BitAnd: operator.and_,
    }

    # In-place variants for augmented assignment: ``lst += [x]`` must mutate the
    # existing object (visible through every alias), exactly as Python does.
    _INPLACE_OPERATORS: Dict[type, Callable[[Any, Any], Any]] = {
        ast.Add: operator.iadd,
        ast.Sub: operator.isub,
        ast.Mult: operator.imul,
        ast.MatMult: operator.imatmul,
        ast.Div: operator.itruediv,
        ast.FloorDiv: operator.ifloordiv,
        ast.Mod: operator.imod,
        ast.Pow: operator.ipow,
        ast.LShift: operator.ilshift,
        ast.RShift: operator.irshift,
        ast.BitOr: operator.ior,
        ast.BitXor: operator.ixor,
        ast.BitAnd: operator.iand,
    }

    _UNARY_OPERATORS: Dict[type, Callable[[Any], Any]] = {
        ast.UAdd: operator.pos,
        ast.USub: operator.neg,
        ast.Not: operator.not_,
        ast.Invert: operator.invert,
    }

    _COMPARISON_OPERATORS: Dict[type, Callable[[Any, Any], Any]] = {
        ast.Eq: operator.eq,
        ast.NotEq: operator.ne,
        ast.Lt: operator.lt,
        ast.LtE: operator.le,
        ast.Gt: operator.gt,
        ast.GtE: operator.ge,
        ast.Is: operator.is_,
        ast.IsNot: operator.is_not,
        ast.In: lambda item, container: item in container,
        ast.NotIn: lambda item, container: item not in container,
    }

    def __init__(
        self,
        allowed_modules: List[str],
        env_stack: Optional[List[Dict[str, Any]]] = None,
        source: Optional[str] = None
    ) -> None:
        """Create an interpreter.

        :param allowed_modules: Module names interpreted code may import; each
            is imported eagerly (ImportError propagates for unknown modules).
        :param env_stack: Existing frame stack to run on; a fresh global frame
            holding the allowed modules is created when omitted.
        :param source: Program text, used only to quote lines in error messages.
        """
        self.allowed_modules: List[str] = allowed_modules
        self.modules: Dict[str, Any] = {mod: __import__(mod) for mod in allowed_modules}
        if env_stack is None:
            # Fresh global frame: allowed modules first, builtins layered on top.
            self.env_stack: List[Dict[str, Any]] = [dict(self.modules)]
        else:
            self.env_stack = env_stack
        self.source_lines: Optional[List[str]] = source.splitlines() if source is not None else None
        self._prepare_global_frame()

    def _prepare_global_frame(self) -> None:
        """Install builtins (once) and decimal helpers into the global frame."""
        global_frame = self.env_stack[0]
        if "__builtins__" not in global_frame:
            safe_builtins: Dict[str, Any] = dict(vars(builtins))
            safe_builtins["__import__"] = self.safe_import
            global_frame["__builtins__"] = safe_builtins
            # Make builtin names (print, len, set, ...) resolvable as variables.
            global_frame.update(safe_builtins)
        global_frame.setdefault("set", set)
        if "decimal" in self.modules:
            dec = self.modules["decimal"]
            # setdefault: spawned interpreters share this frame and must not
            # clobber user globals that happen to use these names.
            for name in ("Decimal", "getcontext", "setcontext", "localcontext", "Context"):
                global_frame.setdefault(name, getattr(dec, name))

    # This safe __import__ only allows modules explicitly provided.
    def safe_import(
        self,
        name: str,
        globals: Optional[Dict[str, Any]] = None,
        locals: Optional[Dict[str, Any]] = None,
        fromlist: Tuple[str, ...] = (),
        level: int = 0
    ) -> Any:
        """``__import__`` replacement that only returns pre-imported allowed modules.

        :raises ImportError: if ``name`` is not in ``allowed_modules``.
        """
        if name not in self.allowed_modules:
            error_msg = f"Import Error: Module '{name}' is not allowed. Only {self.allowed_modules} are permitted."
            raise ImportError(error_msg)
        return self.modules[name]

    def spawn_from_env(self, env_stack: List[Dict[str, Any]]) -> "ASTInterpreter":
        """Create an interpreter running on ``env_stack`` that shares this one's setup.

        The already-imported modules and split source lines are reused instead
        of re-importing and re-splitting the whole program on every call.
        """
        child = ASTInterpreter.__new__(ASTInterpreter)
        child.allowed_modules = self.allowed_modules
        child.modules = self.modules
        child.env_stack = env_stack
        child.source_lines = self.source_lines
        child._prepare_global_frame()
        return child

    def get_variable(self, name: str) -> Any:
        """Look ``name`` up from the innermost frame outwards.

        :raises NameError: if no frame defines it.
        """
        for frame in reversed(self.env_stack):
            if name in frame:
                return frame[name]
        raise NameError(f"Name {name} is not defined.")

    def set_variable(self, name: str, value: Any) -> None:
        """Bind ``name`` in the innermost frame (ignores global/nonlocal)."""
        self.env_stack[-1][name] = value

    def _store_name(self, name: str, value: Any) -> None:
        """Bind ``name`` honouring the current frame's global/nonlocal declarations."""
        frame = self.env_stack[-1]
        if name in frame.get("__global_names__", ()):
            self.env_stack[0][name] = value
            return
        if name in frame.get("__nonlocal_names__", ()):
            # Nearest enclosing frame that defines the name, excluding globals.
            for outer in reversed(self.env_stack[1:-1]):
                if name in outer:
                    outer[name] = value
                    return
            raise NameError(f"No binding for nonlocal '{name}' found")
        frame[name] = value

    def assign(self, target: ast.AST, value: Any) -> None:
        """Store ``value`` into an assignment target (name, tuple/list, attribute, subscript)."""
        if isinstance(target, ast.Name):
            self._store_name(target.id, value)
        elif isinstance(target, (ast.Tuple, ast.List)):
            self._assign_sequence(target.elts, value)
        elif isinstance(target, ast.Attribute):
            setattr(self.visit(target.value), target.attr, value)
        elif isinstance(target, ast.Subscript):
            container = self.visit(target.value)
            container[self.visit(target.slice)] = value
        else:
            raise Exception("Unsupported assignment target type: " + str(type(target)))

    def _assign_sequence(self, targets: List[ast.expr], value: Any) -> None:
        """Destructure ``value`` into ``targets``, supporting one starred target."""
        star_positions = [i for i, t in enumerate(targets) if isinstance(t, ast.Starred)]
        if len(star_positions) > 1:
            raise Exception("Multiple starred expressions not supported")
        # Materialise once so generators/iterators unpack too, and so a starred
        # target receives a list, as in Python.
        items = list(value)
        if not star_positions:
            if len(targets) != len(items):
                raise ValueError("Unpacking mismatch")
            for sub_target, item in zip(targets, items):
                self.assign(sub_target, item)
            return
        star = star_positions[0]
        after_count = len(targets) - star - 1
        if star + after_count > len(items):
            raise ValueError("Unpacking mismatch")
        for sub_target, item in zip(targets[:star], items):
            self.assign(sub_target, item)
        end = len(items) - after_count
        self.assign(targets[star].value, items[star:end])
        for sub_target, item in zip(targets[star + 1:], items[end:]):
            self.assign(sub_target, item)

    def visit(self, node: ast.AST) -> Any:
        """Dispatch to ``visit_<NodeType>`` and annotate errors with their location.

        Control-flow signals pass through untouched.  Other errors are wrapped
        once (at the innermost node); already-wrapped errors propagate as-is.
        """
        method = getattr(self, "visit_" + node.__class__.__name__, self.generic_visit)
        try:
            return method(node)
        except (ReturnException, BreakException, ContinueException):
            raise
        except Exception as e:
            if getattr(e, "_ast_interpreter_wrapped", False):
                raise
            lineno = getattr(node, "lineno", None)
            col = getattr(node, "col_offset", None)
            lineno = lineno if lineno is not None else 1
            col = col if col is not None else 0
            context_line = ""
            if self.source_lines and 1 <= lineno <= len(self.source_lines):
                context_line = self.source_lines[lineno - 1]
            wrapped = Exception(
                f"Error line {lineno}, col {col}:\n{context_line}\nDescription: {str(e)}"
            )
            # Marker lets outer visits (and other interpreter instances) skip re-wrapping.
            wrapped._ast_interpreter_wrapped = True
            raise wrapped from e

    @staticmethod
    def _unwrap_exception(error: BaseException) -> BaseException:
        """Return the exception the interpreted code actually raised."""
        while getattr(error, "_ast_interpreter_wrapped", False) and error.__cause__ is not None:
            error = error.__cause__
        return error

    def generic_visit(self, node: ast.AST) -> Any:
        """Fallback for node types without a visitor: always raises."""
        lineno = getattr(node, "lineno", None)
        context_line = ""
        if self.source_lines and lineno is not None and 1 <= lineno <= len(self.source_lines):
            context_line = self.source_lines[lineno - 1]
        raise Exception(
            f"Unsupported AST node type: {node.__class__.__name__} at line {lineno}.\nContext: {context_line}"
        )

    def _resolve_module(self, name: str) -> Any:
        """Return the (sub)module object for an allowed, possibly dotted, module name."""
        # __import__("a.b") returns the top-level package; walk down to "a.b".
        module = self.modules[name]
        for part in name.split(".")[1:]:
            module = getattr(module, part)
        return module

    def visit_Import(self, node: ast.Import) -> None:
        """Process ``import x [as y]``; only allowed modules can be imported."""
        for alias in node.names:
            module_name: str = alias.name
            if module_name not in self.allowed_modules:
                raise Exception(f"Import Error: Module '{module_name}' is not allowed. Only {self.allowed_modules} are permitted.")
            if alias.asname is not None:
                self._store_name(alias.asname, self._resolve_module(module_name))
            else:
                # ``import a.b`` binds the top-level package ``a``, like Python.
                self._store_name(module_name.split(".")[0], self.modules[module_name])

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Process ``from x import a [as b]`` for allowed modules (no ``*``)."""
        if not node.module:
            raise Exception("Import Error: Missing module name in 'from ... import ...' statement")
        if node.module not in self.allowed_modules:
            raise Exception(f"Import Error: Module '{node.module}' is not allowed. Only {self.allowed_modules} are permitted.")
        module = self._resolve_module(node.module)
        for alias in node.names:
            if alias.name == "*":
                raise Exception("Import Error: 'from ... import *' is not supported.")
            self._store_name(alias.asname or alias.name, getattr(module, alias.name))

    def _run_comprehension(self, generators: List[ast.comprehension], emit: Callable[[], None]) -> None:
        """Run nested comprehension loops in one private scope, calling ``emit`` per element.

        The scope frame is always popped, even when evaluation raises, so a
        caught error cannot leave stray frames on ``env_stack``.
        """
        self.env_stack.append({})
        try:
            self._comprehension_loop(generators, 0, emit)
        finally:
            self.env_stack.pop()

    def _comprehension_loop(self, generators: List[ast.comprehension], index: int, emit: Callable[[], None]) -> None:
        """Iterate ``generators[index]`` and recurse into the remaining clauses."""
        if index == len(generators):
            emit()
            return
        clause = generators[index]
        for item in self.visit(clause.iter):
            self.assign(clause.target, item)
            if all(self.visit(condition) for condition in clause.ifs):
                self._comprehension_loop(generators, index + 1, emit)

    def visit_ListComp(self, node: ast.ListComp) -> List[Any]:
        """Evaluate ``[elt for ... in ... if ...]``; loop variables do not leak."""
        result: List[Any] = []
        self._run_comprehension(node.generators, lambda: result.append(self.visit(node.elt)))
        return result

    def visit_Module(self, node: ast.Module) -> Any:
        """Run every statement; return global ``result`` if set, else the last statement's value."""
        last_value: Any = None
        for stmt in node.body:
            last_value = self.visit(stmt)
        return self.env_stack[0].get("result", last_value)

    def visit_Expr(self, node: ast.Expr) -> Any:
        return self.visit(node.value)

    def visit_Constant(self, node: ast.Constant) -> Any:
        return node.value

    def visit_Name(self, node: ast.Name) -> Any:
        """Load a variable's value; a Store context returns the identifier."""
        if isinstance(node.ctx, ast.Load):
            return self.get_variable(node.id)
        elif isinstance(node.ctx, ast.Store):
            return node.id
        else:
            raise Exception("Unsupported context for Name")

    def visit_BinOp(self, node: ast.BinOp) -> Any:
        left: Any = self.visit(node.left)
        right: Any = self.visit(node.right)
        apply = self._BINARY_OPERATORS.get(type(node.op))
        if apply is None:
            raise Exception("Unsupported binary operator: " + str(node.op))
        return apply(left, right)

    def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
        operand: Any = self.visit(node.operand)
        apply = self._UNARY_OPERATORS.get(type(node.op))
        if apply is None:
            raise Exception("Unsupported unary operator: " + str(node.op))
        return apply(operand)

    def visit_Assign(self, node: ast.Assign) -> None:
        value: Any = self.visit(node.value)
        for target in node.targets:
            self.assign(target, value)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
        """Handle ``x: T = v``; the annotation itself is not evaluated."""
        if node.value is not None:
            self.assign(node.target, self.visit(node.value))

    def visit_AugAssign(self, node: ast.AugAssign) -> Any:
        """Apply ``target op= value`` with in-place semantics; returns the stored result.

        The target's container/key is evaluated only once, as in Python.
        """
        apply = self._INPLACE_OPERATORS.get(type(node.op))
        if apply is None:
            raise Exception("Unsupported augmented operator: " + str(node.op))
        target = node.target
        if isinstance(target, ast.Name):
            result = apply(self.get_variable(target.id), self.visit(node.value))
            self._store_name(target.id, result)
        elif isinstance(target, ast.Attribute):
            obj = self.visit(target.value)
            result = apply(getattr(obj, target.attr), self.visit(node.value))
            setattr(obj, target.attr, result)
        elif isinstance(target, ast.Subscript):
            container = self.visit(target.value)
            key = self.visit(target.slice)
            result = apply(container[key], self.visit(node.value))
            container[key] = result
        else:
            raise Exception("Unsupported assignment target type: " + str(type(target)))
        return result

    def visit_Compare(self, node: ast.Compare) -> Any:
        """Evaluate a (possibly chained) comparison like ``a < b and b < c``.

        Only intermediate results are truth-tested, so a single comparison
        returns the raw operator result (e.g. a numpy boolean array).
        """
        left: Any = self.visit(node.left)
        outcome: Any = True
        last = len(node.ops) - 1
        for position, (op, comparator) in enumerate(zip(node.ops, node.comparators)):
            right: Any = self.visit(comparator)
            compare = self._COMPARISON_OPERATORS.get(type(op))
            if compare is None:
                raise Exception("Unsupported comparison operator: " + str(op))
            outcome = compare(left, right)
            if position < last and not outcome:
                return outcome
            left = right
        return outcome

    def visit_BoolOp(self, node: ast.BoolOp) -> Any:
        """Short-circuit ``and``/``or`` returning the deciding operand, as Python does."""
        value: Any = None
        if isinstance(node.op, ast.And):
            for operand in node.values:
                value = self.visit(operand)
                if not value:
                    return value
            return value
        elif isinstance(node.op, ast.Or):
            for operand in node.values:
                value = self.visit(operand)
                if value:
                    return value
            return value
        else:
            raise Exception("Unsupported boolean operator: " + str(node.op))

    def visit_If(self, node: ast.If) -> Any:
        """Run the selected branch; return the value of its last statement."""
        branch = node.body if self.visit(node.test) else node.orelse
        result = None
        for stmt in branch:
            result = self.visit(stmt)
        return result

    def visit_While(self, node: ast.While) -> None:
        """Run a while loop; the ``else`` clause runs only if no ``break`` occurred."""
        while self.visit(node.test):
            try:
                for stmt in node.body:
                    self.visit(stmt)
            except BreakException:
                break
            except ContinueException:
                continue
        else:
            for stmt in node.orelse:
                self.visit(stmt)

    def visit_For(self, node: ast.For) -> None:
        """Run a for loop; the ``else`` clause runs only if no ``break`` occurred."""
        for item in self.visit(node.iter):
            self.assign(node.target, item)
            try:
                for stmt in node.body:
                    self.visit(stmt)
            except BreakException:
                break
            except ContinueException:
                continue
        else:
            for stmt in node.orelse:
                self.visit(stmt)

    def visit_Break(self, node: ast.Break) -> None:
        raise BreakException()

    def visit_Continue(self, node: ast.Continue) -> None:
        raise ContinueException()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Create a Function closure, apply decorators bottom-up, and bind its name."""
        # Decorators are evaluated before the function object is created.
        decorators = [self.visit(expr) for expr in node.decorator_list]
        # Copy the list but share the frames, so later bindings stay visible.
        closure: List[Dict[str, Any]] = self.env_stack[:]
        func: Any = Function(node, closure, self)
        for decorator in reversed(decorators):
            func = decorator(func)
        self._store_name(node.name, func)

    def _evaluate_elements(self, elements: List[ast.expr]) -> List[Any]:
        """Evaluate a list of expressions, expanding ``*iterable`` entries."""
        values: List[Any] = []
        for element in elements:
            if isinstance(element, ast.Starred):
                values.extend(self.visit(element.value))
            else:
                values.append(self.visit(element))
        return values

    def visit_Call(self, node: ast.Call) -> Any:
        """Call a function, supporting ``*args`` and ``**kwargs`` unpacking."""
        func = self.visit(node.func)
        args = self._evaluate_elements(node.args)
        kwargs: Dict[str, Any] = {}
        for keyword in node.keywords:
            if keyword.arg is None:
                mapping = self.visit(keyword.value)
                pairs = [(key, mapping[key]) for key in mapping.keys()]
            else:
                pairs = [(keyword.arg, self.visit(keyword.value))]
            for key, value in pairs:
                if key in kwargs:
                    raise TypeError(f"got multiple values for keyword argument '{key}'")
                kwargs[key] = value
        return func(*args, **kwargs)

    def visit_Return(self, node: ast.Return) -> None:
        value: Any = self.visit(node.value) if node.value is not None else None
        raise ReturnException(value)

    def visit_Lambda(self, node: ast.Lambda) -> Any:
        closure: List[Dict[str, Any]] = self.env_stack[:]
        return LambdaFunction(node, closure, self)

    def visit_List(self, node: ast.List) -> List[Any]:
        return self._evaluate_elements(node.elts)

    def visit_Tuple(self, node: ast.Tuple) -> Tuple[Any, ...]:
        return tuple(self._evaluate_elements(node.elts))

    def visit_Dict(self, node: ast.Dict) -> Dict[Any, Any]:
        """Build a dict literal; a ``None`` key marks ``**mapping`` unpacking."""
        result: Dict[Any, Any] = {}
        for key_node, value_node in zip(node.keys, node.values):
            if key_node is None:
                result.update(self.visit(value_node))
            else:
                key = self.visit(key_node)
                result[key] = self.visit(value_node)
        return result

    def visit_Set(self, node: ast.Set) -> set:
        return set(self._evaluate_elements(node.elts))

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        value: Any = self.visit(node.value)
        return getattr(value, node.attr)

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        value: Any = self.visit(node.value)
        slice_val: Any = self.visit(node.slice)
        return value[slice_val]

    def visit_Slice(self, node: ast.Slice) -> slice:
        lower: Any = self.visit(node.lower) if node.lower is not None else None
        upper: Any = self.visit(node.upper) if node.upper is not None else None
        step: Any = self.visit(node.step) if node.step is not None else None
        return slice(lower, upper, step)

    # For compatibility with older AST versions (ast.Index is deprecated, so it
    # is not referenced in the annotation).
    def visit_Index(self, node: Any) -> Any:
        return self.visit(node.value)

    def visit_Pass(self, node: ast.Pass) -> None:
        return None

    def visit_TypeIgnore(self, node: ast.TypeIgnore) -> None:
        pass

    def _find_handler(self, handlers: List[ast.ExceptHandler], error: BaseException) -> Optional[ast.ExceptHandler]:
        """Return the first handler whose type matches ``error`` (bare ``except`` = Exception)."""
        for handler in handlers:
            expected = Exception if handler.type is None else self.visit(handler.type)
            if isinstance(error, expected):
                return handler
        return None

    def visit_Try(self, node: ast.Try) -> Any:
        """Execute try/except/else/finally with Python semantics.

        Handlers are matched against the exception the interpreted code raised
        (not the location wrapper).  return/break/continue are never caught.
        """
        result: Any = None
        try:
            try:
                for stmt in node.body:
                    result = self.visit(stmt)
            except (ReturnException, BreakException, ContinueException):
                raise
            except Exception as exc:
                error = self._unwrap_exception(exc)
                handler = self._find_handler(node.handlers, error)
                if handler is None:
                    raise
                if handler.name:
                    self._store_name(handler.name, error)
                for stmt in handler.body:
                    result = self.visit(stmt)
            else:
                for stmt in node.orelse:
                    result = self.visit(stmt)
        finally:
            for stmt in node.finalbody:
                self.visit(stmt)
        return result

    def visit_Nonlocal(self, node: ast.Nonlocal) -> None:
        """Record names that assignments in this frame must write to an enclosing frame."""
        self.env_stack[-1].setdefault("__nonlocal_names__", set()).update(node.names)

    def visit_JoinedStr(self, node: ast.JoinedStr) -> str:
        """Evaluate an f-string by concatenating its literal and formatted parts."""
        return "".join(self.visit(value) for value in node.values)

    def visit_FormattedValue(self, node: ast.FormattedValue) -> str:
        """Apply the ``!r``/``!s``/``!a`` conversion and the format spec of an f-string field."""
        value: Any = self.visit(node.value)
        if node.conversion == ord("r"):
            value = repr(value)
        elif node.conversion == ord("s"):
            value = str(value)
        elif node.conversion == ord("a"):
            value = ascii(value)
        spec = self.visit(node.format_spec) if node.format_spec is not None else ""
        return format(value, spec)

    def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:
        """Return a generator over the expression's values.

        Evaluation is deferred until the first ``next()`` and then done eagerly:
        the comprehension scope lives on the shared ``env_stack`` and must be
        popped before control returns to the consumer.
        """
        def generator() -> Any:
            values: List[Any] = []
            self._run_comprehension(node.generators, lambda: values.append(self.visit(node.elt)))
            yield from values

        return generator()

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Create a class from names bound in its body, honouring bases, keywords and decorators."""
        decorators = [self.visit(expr) for expr in node.decorator_list]
        bases = tuple(self.visit(base) for base in node.bases)
        keywords = {kw.arg: self.visit(kw.value) for kw in node.keywords}
        # Empty frame: only names defined in the body become class attributes;
        # enclosing names stay readable through the outer frames.
        self.env_stack.append({})
        try:
            for stmt in node.body:
                self.visit(stmt)
            body_namespace = {
                k: v for k, v in self.env_stack[-1].items()
                if k not in self._SCOPE_MARKERS
            }
        finally:
            self.env_stack.pop()

        def populate(namespace: Dict[str, Any]) -> None:
            if "__module__" not in body_namespace:
                # Keep the module name a direct type() call from here would record.
                namespace["__module__"] = __name__
            # Item-by-item so metaclass namespaces (e.g. Enum) see every assignment.
            for name, value in body_namespace.items():
                namespace[name] = value

        # new_class resolves the metaclass and __prepare__ like a class statement.
        new_class: Any = types.new_class(node.name, bases, keywords, populate)
        for decorator in reversed(decorators):
            new_class = decorator(new_class)
        self._store_name(node.name, new_class)

    def visit_With(self, node: ast.With) -> None:
        """Run a with-statement; every context manager entered is also exited."""
        self._run_with_items(node.items, 0, node.body)

    def _run_with_items(self, items: List[ast.withitem], index: int, body: List[ast.stmt]) -> None:
        """Enter ``items[index:]`` as nested context managers, then run ``body``."""
        if index == len(items):
            for stmt in body:
                self.visit(stmt)
            return
        item = items[index]
        manager = self.visit(item.context_expr)
        entered = manager.__enter__()
        try:
            if item.optional_vars is not None:
                self.assign(item.optional_vars, entered)
            self._run_with_items(items, index + 1, body)
        except (ReturnException, BreakException, ContinueException):
            # return/break/continue leave the block normally.
            manager.__exit__(None, None, None)
            raise
        except Exception as exc:
            error = self._unwrap_exception(exc)
            if not manager.__exit__(type(error), error, error.__traceback__):
                raise
        else:
            manager.__exit__(None, None, None)

    def visit_Raise(self, node: ast.Raise) -> None:
        """Raise an exception; a bare ``raise`` re-raises the one being handled."""
        if node.exc is None:
            active = sys.exc_info()[1]
            if active is not None:
                raise active
            raise Exception("Raise with no exception specified")
        exc = self.visit(node.exc)
        if node.cause is not None:
            raise exc from self.visit(node.cause)
        raise exc

    def visit_Global(self, node: ast.Global) -> None:
        """Record names that assignments in this frame must write to the global frame."""
        self.env_stack[-1].setdefault("__global_names__", set()).update(node.names)

    def visit_IfExp(self, node: ast.IfExp) -> Any:
        return self.visit(node.body) if self.visit(node.test) else self.visit(node.orelse)

    def visit_DictComp(self, node: ast.DictComp) -> Dict[Any, Any]:
        """Evaluate ``{k: v for ...}`` (key evaluated before value)."""
        result: Dict[Any, Any] = {}

        def emit() -> None:
            key = self.visit(node.key)
            result[key] = self.visit(node.value)

        self._run_comprehension(node.generators, emit)
        return result

    def visit_SetComp(self, node: ast.SetComp) -> set:
        """Evaluate ``{elt for ...}``."""
        result: set = set()
        self._run_comprehension(node.generators, lambda: result.add(self.visit(node.elt)))
        return result

    def visit_Assert(self, node: ast.Assert) -> None:
        """Raise AssertionError (with the optional message) when the test is falsy."""
        if not self.visit(node.test):
            if node.msg is None:
                raise AssertionError()
            raise AssertionError(self.visit(node.msg))

# Class to represent a user-defined function.
class Function:
    """Callable produced by an interpreted ``def`` statement.

    Each call binds its arguments into a fresh local frame, appends that frame
    to the captured closure and runs the body in an interpreter obtained from
    ``interpreter.spawn_from_env``.
    """

    # Placeholder in ``kw_defaults`` for keyword-only parameters without a default.
    _NO_DEFAULT = object()

    def __init__(self, node: ast.FunctionDef, closure: List[Dict[str, Any]], interpreter: "ASTInterpreter") -> None:
        self.node: ast.FunctionDef = node
        # Shallow copy of the list only: the frame dicts stay shared with the
        # defining scope, so names bound there later (including this function's
        # own, possibly decorated, name) are visible when the body runs.
        self.closure: List[Dict[str, Any]] = self.env_stack_reference(closure)
        self.interpreter: "ASTInterpreter" = interpreter
        self.__name__: str = node.name
        # Default values are evaluated once, at definition time, as in Python.
        self.defaults: List[Any] = [interpreter.visit(expr) for expr in node.args.defaults]
        self.kw_defaults: List[Any] = [
            self._NO_DEFAULT if expr is None else interpreter.visit(expr)
            for expr in node.args.kw_defaults
        ]

    # Helper to simply return the given environment stack (shallow copy of list refs).
    def env_stack_reference(self, env_stack: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return env_stack[:]

    def _bind_arguments(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Map call arguments onto parameter names following Python's rules.

        Supports positional-only, regular, ``*args``, keyword-only and
        ``**kwargs`` parameters together with their defaults.

        :raises TypeError: on missing, surplus, duplicate or unexpected arguments.
        """
        spec = self.node.args
        posonly = [param.arg for param in getattr(spec, "posonlyargs", [])]
        positional = posonly + [param.arg for param in spec.args]
        kwonly = [param.arg for param in spec.kwonlyargs]
        frame: Dict[str, Any] = {}

        if len(args) > len(positional) and spec.vararg is None:
            raise TypeError("Too many arguments provided")
        frame.update(zip(positional, args))
        if spec.vararg is not None:
            frame[spec.vararg.arg] = tuple(args[len(positional):])

        extra: Dict[str, Any] = {}
        for key, value in kwargs.items():
            if key in kwonly or (key in positional and key not in posonly):
                if key in frame:
                    raise TypeError(f"Multiple values for argument '{key}'")
                frame[key] = value
            elif spec.kwarg is not None:
                extra[key] = value
            else:
                raise TypeError(f"Unexpected keyword argument '{key}'")

        # Defaults align with the last len(defaults) positional parameters.
        first_default = len(positional) - len(self.defaults)
        for index, name in enumerate(positional):
            if name not in frame:
                if index < first_default:
                    raise TypeError("Not enough arguments provided")
                frame[name] = self.defaults[index - first_default]
        for name, default in zip(kwonly, self.kw_defaults):
            if name not in frame:
                if default is self._NO_DEFAULT:
                    raise TypeError(f"Missing keyword-only argument '{name}'")
                frame[name] = default
        if spec.kwarg is not None:
            frame[spec.kwarg.arg] = extra
        return frame

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the body; return the value of a ``return`` statement, else ``None``."""
        local_frame = self._bind_arguments(args, kwargs)
        new_env_stack: List[Dict[str, Any]] = self.closure[:] + [local_frame]
        new_interp = self.interpreter.spawn_from_env(new_env_stack)
        try:
            for stmt in self.node.body:
                new_interp.visit(stmt)
        except ReturnException as ret:
            return ret.value
        # Falling off the end of a Python function returns None.
        return None

    def __get__(self, instance: Any, owner: Any) -> Any:
        """Bind to ``instance`` for method calls; class access returns the function itself."""
        if instance is None:
            return self

        def method(*args: Any, **kwargs: Any) -> Any:
            return self(instance, *args, **kwargs)
        return method

    def __repr__(self) -> str:
        return f"<Function {self.node.name}>"

# Class to represent a lambda function.
class LambdaFunction:
    """Callable produced by an interpreted ``lambda`` expression.

    A call binds the arguments into a fresh frame on top of the captured
    closure and evaluates the body expression in a spawned interpreter.
    """

    # Placeholder in ``kw_defaults`` for keyword-only parameters without a default.
    _NO_DEFAULT = object()

    def __init__(self, node: ast.Lambda, closure: List[Dict[str, Any]], interpreter: "ASTInterpreter") -> None:
        self.node: ast.Lambda = node
        # List copied, frames shared with the defining scope.
        self.closure: List[Dict[str, Any]] = self.env_stack_reference(closure)
        self.interpreter: "ASTInterpreter" = interpreter
        self.__name__: str = "<lambda>"
        # Default values are evaluated once, when the lambda is created.
        self.defaults: List[Any] = [interpreter.visit(expr) for expr in node.args.defaults]
        self.kw_defaults: List[Any] = [
            self._NO_DEFAULT if expr is None else interpreter.visit(expr)
            for expr in node.args.kw_defaults
        ]

    def env_stack_reference(self, env_stack: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return env_stack[:]

    def _bind_arguments(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Map call arguments onto the lambda's parameters following Python's rules.

        :raises TypeError: on missing, surplus, duplicate or unexpected arguments.
        """
        spec = self.node.args
        posonly = [param.arg for param in getattr(spec, "posonlyargs", [])]
        positional = posonly + [param.arg for param in spec.args]
        kwonly = [param.arg for param in spec.kwonlyargs]
        frame: Dict[str, Any] = {}

        if len(args) > len(positional) and spec.vararg is None:
            raise TypeError("Too many arguments for lambda")
        frame.update(zip(positional, args))
        if spec.vararg is not None:
            frame[spec.vararg.arg] = tuple(args[len(positional):])

        extra: Dict[str, Any] = {}
        for key, value in kwargs.items():
            if key in kwonly or (key in positional and key not in posonly):
                if key in frame:
                    raise TypeError(f"Multiple values for lambda argument '{key}'")
                frame[key] = value
            elif spec.kwarg is not None:
                extra[key] = value
            else:
                raise TypeError(f"Unexpected keyword argument '{key}' for lambda")

        # Defaults align with the last len(defaults) positional parameters.
        first_default = len(positional) - len(self.defaults)
        for index, name in enumerate(positional):
            if name not in frame:
                if index < first_default:
                    raise TypeError("Not enough arguments for lambda")
                frame[name] = self.defaults[index - first_default]
        for name, default in zip(kwonly, self.kw_defaults):
            if name not in frame:
                if default is self._NO_DEFAULT:
                    raise TypeError(f"Missing keyword-only argument '{name}' for lambda")
                frame[name] = default
        if spec.kwarg is not None:
            frame[spec.kwarg.arg] = extra
        return frame

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Evaluate the lambda body with the bound arguments and return its value."""
        local_frame = self._bind_arguments(args, kwargs)
        new_interp = self.interpreter.spawn_from_env(self.closure[:] + [local_frame])
        return new_interp.visit(self.node.body)

    def __get__(self, instance: Any, owner: Any) -> Any:
        """Bind like a Python function when stored as a class attribute."""
        if instance is None:
            return self

        def method(*args: Any, **kwargs: Any) -> Any:
            return self(instance, *args, **kwargs)
        return method

    def __repr__(self) -> str:
        return "<LambdaFunction>"

# The main function to interpret an AST.
def interpret_ast(ast_tree: Any, allowed_modules: List[str], source: str = "") -> Any:
    """Evaluate a parsed module and return its result.

    Trees containing ``yield``/``yield from`` cannot be run by the tree-walking
    interpreter (it cannot suspend a frame), so they are executed with
    ``exec`` under a reduced builtins table; everything else goes through
    ``ASTInterpreter``.

    :param ast_tree: Module tree produced by ``ast.parse`` (mode ``"exec"``).
    :param allowed_modules: Module names the code may import.
    :param source: Original source text, used only for error messages.
    :return: The global ``result`` (exec path) or ``ASTInterpreter.visit``'s value.
    """
    if any(isinstance(node, (ast.Yield, ast.YieldFrom)) for node in ast.walk(ast_tree)):
        return _exec_generator_code(ast_tree, allowed_modules)
    interpreter = ASTInterpreter(allowed_modules=allowed_modules, source=source)
    return interpreter.visit(ast_tree)


def _exec_generator_code(ast_tree: Any, allowed_modules: List[str]) -> Any:
    """Run ``ast_tree`` with exec() and restricted builtins; return its ``result`` global.

    NOTE(security): exec() runs real Python; removing builtins narrows but does
    not eliminate escape routes (e.g. object-graph introspection).
    """

    def restricted_import(name: str, globals_: Any = None, locals_: Any = None, fromlist: Any = (), level: int = 0) -> Any:
        # Same policy as ASTInterpreter.safe_import: only allowed modules.
        if name not in allowed_modules:
            raise ImportError(f"Import Error: Module '{name}' is not allowed. Only {allowed_modules} are permitted.")
        return __import__(name, globals_, locals_, fromlist, level)

    safe_builtins: Dict[str, Any] = {
        "range": range,
        "len": len,
        "print": print,
        "__import__": restricted_import,
        "ZeroDivisionError": ZeroDivisionError,
        "ValueError": ValueError,
        "NameError": NameError,
        "TypeError": TypeError,
        "KeyError": KeyError,
        "IndexError": IndexError,
        "StopIteration": StopIteration,
        "list": list,
        "dict": dict,
        "tuple": tuple,
        "set": set,
        "float": float,
        "int": int,
        "bool": bool,
        "str": str,
        "sum": sum,
        "min": min,
        "max": max,
        "abs": abs,
        "round": round,
        "enumerate": enumerate,
        "zip": zip,
        "map": map,
        "filter": filter,
        "sorted": sorted,
        "reversed": reversed,
        "any": any,
        "all": all,
        "iter": iter,
        "next": next,
        "isinstance": isinstance,
        "Exception": Exception,
    }
    namespace: Dict[str, Any] = {"__builtins__": safe_builtins}
    for mod in allowed_modules:
        namespace[mod] = __import__(mod)
    # One dict for globals and locals: top-level definitions must be visible
    # from inside functions/generators, which resolve names via their globals.
    exec(compile(ast_tree, "<string>", "exec"), namespace)
    return namespace.get("result", None)

# Convenience entry point: parse source text and interpret it with a
# restricted set of importable modules.
def interpret_code(source_code: str, allowed_modules: List[str]) -> Any:
    """Parse and run ``source_code``, letting it import only ``allowed_modules``.

    The text is dedented first so indented snippets (such as triple-quoted
    literals) parse cleanly; the same dedented text is used for error context.

    :param source_code: The Python source code to interpret.
    :param allowed_modules: Names of the modules the code may import.
    :return: Whatever ``interpret_ast`` produces for the parsed module.
    """
    normalized_source = textwrap.dedent(source_code)
    return interpret_ast(ast.parse(normalized_source), allowed_modules, source=normalized_source)

if __name__ == "__main__":
    print("Script is running!")

    # Example 1: a user-defined function plus an allowed module (math only).
    square_demo: str = """
import math
def square(x):
    return x * x

y = square(5)
z = math.sqrt(y)
z
"""
    try:
        print("Result:", interpret_code(square_demo, allowed_modules=["math"]))
    except Exception as error:
        print("Interpreter error:", error)

    print("Second example:")

    # Example 2: list comprehensions over a numpy array. numpy must be
    # installed; otherwise the failure is reported by the handler below.
    transform_demo: str = textwrap.dedent("""
import math
import numpy as np
def transform_array(x):
    # Apply square root
    sqrt_vals = [math.sqrt(val) for val in x]

    # Apply sine function
    sin_vals = [math.sin(val) for val in sqrt_vals]

    # Apply exponential
    exp_vals = [math.exp(val) for val in sin_vals]

    return exp_vals

array_input = np.array([1, 4, 9, 16, 25])
result = transform_array(array_input)
result
""")
    print("About to parse source code")
    try:
        parsed_tree = ast.parse(transform_demo)
        print("Source code parsed successfully")
        # Both math and numpy are permitted here.
        print("Result:", interpret_ast(parsed_tree, allowed_modules=["math", "numpy"], source=transform_demo))
    except Exception as error:
        print("Interpreter error:", error)