quantalogic 0.60.0__py3-none-any.whl → 0.61.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. quantalogic/agent_config.py +5 -5
  2. quantalogic/agent_factory.py +2 -2
  3. quantalogic/codeact/__init__.py +0 -0
  4. quantalogic/codeact/agent.py +499 -0
  5. quantalogic/codeact/cli.py +232 -0
  6. quantalogic/codeact/constants.py +9 -0
  7. quantalogic/codeact/events.py +78 -0
  8. quantalogic/codeact/llm_util.py +76 -0
  9. quantalogic/codeact/prompts/error_format.j2 +11 -0
  10. quantalogic/codeact/prompts/generate_action.j2 +26 -0
  11. quantalogic/codeact/prompts/generate_program.j2 +39 -0
  12. quantalogic/codeact/prompts/response_format.j2 +11 -0
  13. quantalogic/codeact/tools_manager.py +135 -0
  14. quantalogic/codeact/utils.py +135 -0
  15. quantalogic/coding_agent.py +2 -2
  16. quantalogic/python_interpreter/__init__.py +23 -0
  17. quantalogic/python_interpreter/assignment_visitors.py +63 -0
  18. quantalogic/python_interpreter/base_visitors.py +20 -0
  19. quantalogic/python_interpreter/class_visitors.py +22 -0
  20. quantalogic/python_interpreter/comprehension_visitors.py +172 -0
  21. quantalogic/python_interpreter/context_visitors.py +59 -0
  22. quantalogic/python_interpreter/control_flow_visitors.py +88 -0
  23. quantalogic/python_interpreter/exception_visitors.py +109 -0
  24. quantalogic/python_interpreter/exceptions.py +39 -0
  25. quantalogic/python_interpreter/execution.py +202 -0
  26. quantalogic/python_interpreter/function_utils.py +386 -0
  27. quantalogic/python_interpreter/function_visitors.py +209 -0
  28. quantalogic/python_interpreter/import_visitors.py +28 -0
  29. quantalogic/python_interpreter/interpreter_core.py +358 -0
  30. quantalogic/python_interpreter/literal_visitors.py +74 -0
  31. quantalogic/python_interpreter/misc_visitors.py +148 -0
  32. quantalogic/python_interpreter/operator_visitors.py +108 -0
  33. quantalogic/python_interpreter/scope.py +10 -0
  34. quantalogic/python_interpreter/visit_handlers.py +110 -0
  35. quantalogic/tools/__init__.py +5 -4
  36. quantalogic/tools/action_gen.py +366 -0
  37. quantalogic/tools/python_tool.py +13 -0
  38. quantalogic/tools/{search_definition_names.py → search_definition_names_tool.py} +2 -2
  39. quantalogic/tools/tool.py +116 -22
  40. quantalogic/tools/utils/generate_database_report.py +2 -2
  41. quantalogic/utils/__init__.py +0 -1
  42. quantalogic/utils/test_python_interpreter.py +119 -0
  43. {quantalogic-0.60.0.dist-info → quantalogic-0.61.1.dist-info}/METADATA +8 -2
  44. {quantalogic-0.60.0.dist-info → quantalogic-0.61.1.dist-info}/RECORD +47 -15
  45. quantalogic/utils/python_interpreter.py +0 -905
  46. {quantalogic-0.60.0.dist-info → quantalogic-0.61.1.dist-info}/LICENSE +0 -0
  47. {quantalogic-0.60.0.dist-info → quantalogic-0.61.1.dist-info}/WHEEL +0 -0
  48. {quantalogic-0.60.0.dist-info → quantalogic-0.61.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,135 @@
1
+ import ast
2
+ import inspect
3
+ from functools import wraps
4
+ from typing import Any, Callable, Tuple
5
+
6
+ from loguru import logger
7
+ from lxml import etree
8
+
9
+
10
def log_async_tool(verb: str):
    """Decorator factory for consistent async tool logging.

    The wrapped coroutine function logs a "Starting tool" line, then a line
    made of ``verb`` followed by every bound argument (defaults applied),
    then a "Finished tool" line once it returns.  If the tool raises, the
    failure is logged and the exception is re-raised unchanged.

    Args:
        verb: Text prefixed to the argument listing in the log.

    Returns:
        A decorator for ``async def`` functions.
    """
    def decorator(func: Callable) -> Callable:
        # The signature of ``func`` never changes, so inspect it once here
        # instead of on every call.
        sig = inspect.signature(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            logger.info(f"Starting tool: {func.__name__}")
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()
            logger.info(f"{verb} {', '.join(f'{k}={v}' for k, v in bound_args.arguments.items())}")
            try:
                result = await func(*args, **kwargs)
            except Exception as e:
                # Same failure logging as log_tool_method; without it a crash
                # only shows up as a missing "Finished tool" line.
                logger.error(f"Tool {func.__name__} failed: {e}")
                raise
            logger.info(f"Finished tool: {func.__name__}")
            return result
        return wrapper
    return decorator
25
+
26
+
27
def log_tool_method(func: Callable) -> Callable:
    """Decorator for logging async Tool class methods.

    Logs "Starting tool" / "Finished tool" lines identified by ``self.name``.
    If the method raises, an error line is logged and the exception is
    re-raised unchanged.

    Both positional and keyword arguments after ``self`` are forwarded to the
    wrapped method, so ``tool.method(x)`` works exactly like ``tool.method(arg=x)``.
    """
    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        logger.info(f"Starting tool: {self.name}")
        try:
            result = await func(self, *args, **kwargs)
        except Exception as e:
            logger.error(f"Tool {self.name} failed: {e}")
            raise
        logger.info(f"Finished tool: {self.name}")
        return result
    return wrapper
40
+
41
+
42
def validate_xml(xml_string: str) -> bool:
    """Return True if ``xml_string`` is well-formed XML, otherwise False.

    Failures are logged at error level.  lxml raises ``XMLSyntaxError`` for
    malformed markup, but ``ValueError`` for a ``str`` that carries an XML
    encoding declaration and for non-string input (e.g. ``None``).  Both
    are reported as invalid rather than escaping to the caller.
    """
    try:
        etree.fromstring(xml_string)
    except (etree.XMLSyntaxError, ValueError) as e:
        logger.error(f"XML validation failed: {e}")
        return False
    return True
50
+
51
+
52
def validate_code(code: str) -> bool:
    """Return True if ``code`` parses and defines a top-level ``async def main()``.

    Only module-level definitions count: an ``async def main`` nested inside
    a class or another function is not reachable as the program's entry
    point.  Unparseable source yields False.  Besides ``SyntaxError``,
    Python versions before 3.12 raise ``ValueError`` for source containing
    null bytes, so that is treated as invalid too.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        return False
    return any(
        isinstance(stmt, ast.AsyncFunctionDef) and stmt.name == "main"
        for stmt in tree.body
    )
60
+
61
+
62
def format_xml_element(tag: str, value: Any, **attribs) -> etree.Element:
    """Create an XML element holding ``str(value)`` with optional attributes.

    The text is normally wrapped in CDATA so markup-heavy values stay
    readable.  A CDATA section cannot contain its terminator ``]]>``, and
    lxml raises ``ValueError`` if asked to build one.  Such values are
    therefore stored as ordinary (escaped) text; the parsed text is the same
    either way.  A ``None`` value leaves the element empty.

    Args:
        tag: Element tag name.
        value: Content to store; converted with ``str()``.
        **attribs: Attributes to set on the element (e.g. ``name="x"``).
    """
    elem = etree.Element(tag, **attribs)
    if value is not None:
        text = str(value)
        elem.text = text if "]]>" in text else etree.CDATA(text)
    return elem
67
+
68
+
69
class XMLResultHandler:
    """Utility class for handling XML formatting and parsing."""

    @staticmethod
    def format_execution_result(result) -> str:
        """Format an execution result as an ``<ExecutionResult>`` XML document.

        ``result`` must expose ``result``, ``error``, ``execution_time`` and
        ``local_variables`` attributes.  Presumably this is an
        ``AsyncExecutionResult``; confirm against callers.

        A string result starting with "Task completed:" is flagged as
        completed, and the remainder is emitted as ``<FinalAnswer>``.
        Non-callable local variables whose names do not start with "__" are
        listed, each truncated to 5000 characters.
        """
        root = etree.Element("ExecutionResult")
        root.append(format_xml_element("Status", "Success" if not result.error else "Error"))
        # Fall back to the error only when there is no result at all, so
        # falsy results such as 0 or False are still reported.
        value = result.result if result.result is not None else result.error
        root.append(format_xml_element("Value", value))
        root.append(format_xml_element("ExecutionTime", f"{result.execution_time:.2f} seconds"))

        marker = "Task completed:"
        # Always a real bool, so the element reads "true"/"false" (a None
        # result previously rendered as "none"), and never calls
        # .startswith on a non-string result.
        completed = isinstance(result.result, str) and result.result.startswith(marker)
        root.append(format_xml_element("Completed", str(completed).lower()))

        if completed:
            final_answer = result.result[len(marker):].strip()
            root.append(format_xml_element("FinalAnswer", final_answer))

        if result.local_variables:
            vars_elem = etree.SubElement(root, "Variables")
            for k, v in result.local_variables.items():
                if callable(v) or k.startswith("__"):
                    continue
                text = str(v)
                if len(text) > 5000:
                    text = text[:5000] + "... (truncated)"
                vars_elem.append(format_xml_element("Variable", text, name=k))
        return etree.tostring(root, pretty_print=True, encoding="unicode")

    @staticmethod
    def format_result_summary(result_xml: str) -> str:
        """Format an ``<ExecutionResult>`` XML document into a readable summary.

        If the input cannot be parsed, the error is logged and ``result_xml``
        is returned unchanged.
        """
        try:
            root = etree.fromstring(result_xml)
        except (etree.XMLSyntaxError, ValueError):
            logger.error(f"Failed to parse XML: {result_xml}")
            return result_xml
        lines = [
            f"- Status: {root.findtext('Status', 'N/A')}",
            f"- Value: {root.findtext('Value', 'N/A')}",
            f"- Execution Time: {root.findtext('ExecutionTime', 'N/A')}",
            f"- Completed: {root.findtext('Completed', 'N/A').capitalize()}"
        ]
        if final_answer := root.findtext("FinalAnswer"):
            lines.append(f"- Final Answer: {final_answer}")

        if (vars_elem := root.find("Variables")) is not None:
            lines.append("- Variables:")
            # An empty CDATA section parses back with text=None, so guard
            # before calling strip().
            lines.extend(f" - {var.get('name', 'unknown')}: {(var.text or '').strip() or 'N/A'}"
                         for var in vars_elem.findall("Variable"))
        return "\n".join(lines)

    @staticmethod
    def parse_response(response: str) -> Tuple[str, str]:
        """Parse an XML response and return its ``(thought, code)`` texts.

        Missing ``<Thought>`` or ``<Code>`` elements yield empty strings.

        Raises:
            ValueError: If ``response`` cannot be parsed as XML.
        """
        try:
            root = etree.fromstring(response)
        except (etree.XMLSyntaxError, ValueError) as e:
            raise ValueError(f"Failed to parse XML: {e}") from e
        thought = root.findtext("Thought") or ""
        code = root.findtext("Code") or ""
        return thought, code

    @staticmethod
    def extract_result_value(result: str) -> str:
        """Return the ``<Value>`` text of a result document, or "" if unavailable."""
        try:
            return etree.fromstring(result).findtext("Value") or ""
        except (etree.XMLSyntaxError, ValueError):
            return ""
@@ -15,7 +15,7 @@ from quantalogic.tools import (
15
15
  ReadHTMLTool,
16
16
  ReplaceInFileTool,
17
17
  RipgrepTool,
18
- SearchDefinitionNames,
18
+ SearchDefinitionNamesTool,
19
19
  TaskCompleteTool,
20
20
  WriteFileTool,
21
21
  )
@@ -71,7 +71,7 @@ def create_coding_agent(
71
71
  # Code navigation and search tools
72
72
  ListDirectoryTool(), # Lists directory contents
73
73
  RipgrepTool(), # Searches code with regex
74
- SearchDefinitionNames(), # Finds code definitions
74
+ SearchDefinitionNamesTool(), # Finds code definitions
75
75
  # Specialized language model tools
76
76
  ReadFileTool(),
77
77
  ExecuteBashCommandTool(),
@@ -0,0 +1,23 @@
1
+ # quantalogic/python_interpreter/__init__.py
2
+ from .exceptions import BreakException, ContinueException, ReturnException, WrappedException, has_await
3
+ from .execution import AsyncExecutionResult, execute_async, interpret_ast, interpret_code
4
+ from .function_utils import AsyncFunction, Function, LambdaFunction
5
+ from .interpreter_core import ASTInterpreter
6
+ from .scope import Scope
7
+
8
# Names re-exported by this package (each is imported above from a submodule);
# `from <package> import *` exposes exactly these.
__all__ = [
    'ASTInterpreter',
    'execute_async',
    'interpret_ast',
    'interpret_code',
    'AsyncExecutionResult',
    'ReturnException',
    'BreakException',
    'ContinueException',
    'WrappedException',
    'has_await',
    'Function',
    'AsyncFunction',
    'LambdaFunction',
    'Scope',
]
@@ -0,0 +1,63 @@
1
+ import ast
2
+ from typing import Any
3
+
4
+ from .interpreter_core import ASTInterpreter
5
+
6
async def visit_Assign(self: ASTInterpreter, node: ast.Assign, wrap_exceptions: bool = True) -> None:
    """Execute ``t1 = t2 = ... = value``.

    The right-hand side is evaluated exactly once, then stored into each
    target from left to right.  Subscript targets are written directly via
    item assignment (container first, then key); all other target kinds are
    handed to ``self.assign``.
    """
    rhs: Any = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    for tgt in node.targets:
        if not isinstance(tgt, ast.Subscript):
            await self.assign(tgt, rhs)
            continue
        container = await self.visit(tgt.value, wrap_exceptions=wrap_exceptions)
        index = await self.visit(tgt.slice, wrap_exceptions=wrap_exceptions)
        container[index] = rhs
15
+
16
async def visit_AugAssign(self: ASTInterpreter, node: ast.AugAssign, wrap_exceptions: bool = True) -> Any:
    """Execute an augmented assignment such as ``x += y`` or ``a[i] <<= n``.

    The in-place operator functions from :mod:`operator` (``iadd``, ``imul``,
    ...) are used, as CPython does.  Mutable objects are therefore updated in
    place: ``lst += (1, 2)`` extends the list and every alias sees the change.
    Types that only implement the binary dunder fall back to it automatically.

    For subscript targets the container and key expressions are evaluated
    once and reused for both the read and the write, matching Python.

    Returns:
        The value stored into the target.

    Raises:
        Exception: If the operator node is not a recognised augmented operator.
    """
    import operator

    inplace_ops = {
        ast.Add: operator.iadd,
        ast.Sub: operator.isub,
        ast.Mult: operator.imul,
        ast.MatMult: operator.imatmul,
        ast.Div: operator.itruediv,
        ast.FloorDiv: operator.ifloordiv,
        ast.Mod: operator.imod,
        ast.Pow: operator.ipow,
        ast.BitAnd: operator.iand,
        ast.BitOr: operator.ior,
        ast.BitXor: operator.ixor,
        ast.LShift: operator.ilshift,
        ast.RShift: operator.irshift,
    }

    target = node.target
    if isinstance(target, ast.Subscript):
        container = await self.visit(target.value, wrap_exceptions=wrap_exceptions)
        key = await self.visit(target.slice, wrap_exceptions=wrap_exceptions)
        current_val: Any = container[key]
    elif isinstance(target, ast.Name):
        current_val = self.get_variable(target.id)
    else:
        current_val = await self.visit(target, wrap_exceptions=wrap_exceptions)
    right_val: Any = await self.visit(node.value, wrap_exceptions=wrap_exceptions)

    op_func = inplace_ops.get(type(node.op))
    if op_func is None:
        raise Exception("Unsupported augmented operator: " + str(node.op))
    result: Any = op_func(current_val, right_val)

    if isinstance(target, ast.Subscript):
        # Reuse the already-evaluated container and key (consistent with
        # visit_Assign's direct item assignment).
        container[key] = result
    else:
        await self.assign(target, result)
    return result
51
+
52
async def visit_AnnAssign(self: ASTInterpreter, node: ast.AnnAssign, wrap_exceptions: bool = True) -> None:
    """Execute an annotated assignment such as ``x: int = 5`` or ``x: int``.

    The value (when present) is evaluated first, then the annotation.  For
    plain-name targets, the evaluated annotation is recorded in
    ``self.type_hints``.

    The target is bound only when the statement has a value, as in Python.
    A bare ``x: int`` records the hint without creating or overwriting
    ``x``, while ``obj.attr: T = None`` does store ``None``.
    """
    value = await self.visit(node.value, wrap_exceptions=wrap_exceptions) if node.value else None
    annotation = await self.visit(node.annotation, wrap_exceptions=True)
    if isinstance(node.target, ast.Name):
        self.type_hints[node.target.id] = annotation
    if node.value is not None:
        await self.assign(node.target, value)
59
+
60
async def visit_NamedExpr(self: ASTInterpreter, node: ast.NamedExpr, wrap_exceptions: bool = True) -> Any:
    """Evaluate an assignment expression ``(target := value)``.

    The value is bound to the target and is also the result of the expression.
    """
    bound = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    await self.assign(node.target, bound)
    return bound
@@ -0,0 +1,20 @@
1
+ import ast
2
+ from typing import Any
3
+
4
+ from .exceptions import WrappedException
5
+ from .interpreter_core import ASTInterpreter
6
+
7
async def visit_Module(self: ASTInterpreter, node: ast.Module, wrap_exceptions: bool = True) -> Any:
    """Execute the top-level statements of a module in order.

    Every statement is visited with ``wrap_exceptions=True``, independent of
    the argument received.  The value of the final statement is returned, or
    ``None`` when the module body is empty.
    """
    if not node.body:
        return None
    *leading, final = node.body
    for statement in leading:
        await self.visit(statement, wrap_exceptions=True)
    return await self.visit(final, wrap_exceptions=True)
12
+
13
async def visit_Expr(self: ASTInterpreter, node: ast.Expr, wrap_exceptions: bool = True) -> Any:
    """Evaluate an expression statement and return the expression's value."""
    value = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    return value
15
+
16
async def visit_Pass(self: ASTInterpreter, node: ast.Pass, wrap_exceptions: bool = True) -> None:
    """Execute ``pass``: no effect, evaluates to ``None``."""
18
+
19
async def visit_TypeIgnore(self: ASTInterpreter, node: ast.TypeIgnore, wrap_exceptions: bool = True) -> None:
    """Handle an ``ast.TypeIgnore`` node; it has no runtime effect."""
    return None
@@ -0,0 +1,22 @@
1
+ import ast
2
+ from typing import Any, Dict, List
3
+
4
+ from .function_utils import Function
5
+ from .interpreter_core import ASTInterpreter
6
+
7
async def visit_ClassDef(self: ASTInterpreter, node: ast.ClassDef, wrap_exceptions: bool = True) -> Any:
    """Execute a ``class`` statement and bind the new class in the enclosing frame.

    Base-class expressions are evaluated in the enclosing scope before the
    class-body frame is pushed, so a failing base expression cannot leave a
    stray frame on ``env_stack``.  The body then runs in a fresh frame.  That
    frame's contents, minus ``__builtins__``, become the class namespace
    passed to ``type(name, bases, namespace)``.  Interpreted ``Function``
    members get ``defining_class`` set to the new class.

    Note: ``node.decorator_list`` and ``node.keywords`` (e.g. ``metaclass=``)
    are not evaluated here.

    Returns:
        The newly created class.
    """
    bases = [await self.visit(base, wrap_exceptions=True) for base in node.bases]
    self.env_stack.append({})
    try:
        for stmt in node.body:
            await self.visit(stmt, wrap_exceptions=True)
        class_dict = {k: v for k, v in self.env_stack[-1].items() if k not in ["__builtins__"]}
        cls = type(node.name, tuple(bases), class_dict)
        for value in class_dict.values():
            if isinstance(value, Function):
                value.defining_class = cls
        # [-2] is the enclosing frame while the class-body frame is on top.
        self.env_stack[-2][node.name] = cls
        return cls
    finally:
        self.env_stack.pop()
@@ -0,0 +1,172 @@
1
+ import ast
2
+ from typing import Any, Dict, List
3
+
4
+ from .interpreter_core import ASTInterpreter
5
+ from .exceptions import WrappedException
6
+
7
async def visit_ListComp(self: ASTInterpreter, node: ast.ListComp, wrap_exceptions: bool = True) -> List[Any]:
    """Evaluate a list comprehension ``[elt for ... in ... if ...]``.

    The comprehension runs in a copy of the current frame, and each item gets
    a further copy, so loop variables never leak into the enclosing scope.
    Every pushed frame is popped in a ``finally``.  An exception from an
    iterable, a target, an ``if`` clause or the element therefore leaves
    ``env_stack`` balanced for any surrounding ``try``.

    ``if`` clauses short-circuit as in Python: later clauses are skipped
    once one is falsy.  Only a failure of ``iter()`` itself is reported as
    "not iterable"; a ``TypeError`` raised while evaluating the element or a
    filter propagates as-is.
    """
    result: List[Any] = []

    async def passes_filters(comp: ast.comprehension) -> bool:
        # Evaluate the `if` clauses left to right, stopping at the first falsy one.
        for if_clause in comp.ifs:
            if not await self.visit(if_clause, wrap_exceptions=True):
                return False
        return True

    async def run_item(gen_idx: int, comp: ast.comprehension, item: Any) -> None:
        # Bind one item in its own frame and recurse into the next clause.
        self.env_stack.append(self.env_stack[-1].copy())
        try:
            await self.assign(comp.target, item)
            if await passes_filters(comp):
                await rec(gen_idx + 1)
        finally:
            self.env_stack.pop()

    async def rec(gen_idx: int) -> None:
        if gen_idx == len(node.generators):
            result.append(await self.visit(node.elt, wrap_exceptions=wrap_exceptions))
            return
        comp = node.generators[gen_idx]
        iterable = await self.visit(comp.iter, wrap_exceptions=wrap_exceptions)
        if hasattr(iterable, '__aiter__'):
            async for item in iterable:
                await run_item(gen_idx, comp, item)
            return
        try:
            iterator = iter(iterable)
        except TypeError as e:
            lineno = getattr(node, "lineno", 1)
            col = getattr(node, "col_offset", 0)
            context_line = self.source_lines[lineno - 1] if self.source_lines and lineno <= len(self.source_lines) else ""
            raise WrappedException(f"Object {iterable} is not iterable", e, lineno, col, context_line) from e
        for item in iterator:
            await run_item(gen_idx, comp, item)

    self.env_stack.append(self.env_stack[-1].copy())
    try:
        await rec(0)
    finally:
        self.env_stack.pop()
    return result
47
+
48
async def visit_DictComp(self: ASTInterpreter, node: ast.DictComp, wrap_exceptions: bool = True) -> Dict[Any, Any]:
    """Evaluate a dict comprehension ``{key: value for ... in ... if ...}``.

    The comprehension runs in a copy of the current frame, and each item gets
    a further copy, so loop variables never leak into the enclosing scope.
    Every pushed frame is popped in a ``finally``, so an exception anywhere
    inside leaves ``env_stack`` balanced.  For each entry the key is
    evaluated before the value.

    ``if`` clauses short-circuit as in Python.  Only a failure of ``iter()``
    itself is reported as "not iterable"; a ``TypeError`` from the key,
    value or a filter propagates as-is.
    """
    result: Dict[Any, Any] = {}

    async def passes_filters(comp: ast.comprehension) -> bool:
        # Evaluate the `if` clauses left to right, stopping at the first falsy one.
        for if_clause in comp.ifs:
            if not await self.visit(if_clause, wrap_exceptions=True):
                return False
        return True

    async def run_item(gen_idx: int, comp: ast.comprehension, item: Any) -> None:
        # Bind one item in its own frame and recurse into the next clause.
        self.env_stack.append(self.env_stack[-1].copy())
        try:
            await self.assign(comp.target, item)
            if await passes_filters(comp):
                await rec(gen_idx + 1)
        finally:
            self.env_stack.pop()

    async def rec(gen_idx: int) -> None:
        if gen_idx == len(node.generators):
            key = await self.visit(node.key, wrap_exceptions=True)
            val = await self.visit(node.value, wrap_exceptions=True)
            result[key] = val
            return
        comp = node.generators[gen_idx]
        iterable = await self.visit(comp.iter, wrap_exceptions=wrap_exceptions)
        if hasattr(iterable, '__aiter__'):
            async for item in iterable:
                await run_item(gen_idx, comp, item)
            return
        try:
            iterator = iter(iterable)
        except TypeError as e:
            lineno = getattr(node, "lineno", 1)
            col = getattr(node, "col_offset", 0)
            context_line = self.source_lines[lineno - 1] if self.source_lines and lineno <= len(self.source_lines) else ""
            raise WrappedException(f"Object {iterable} is not iterable", e, lineno, col, context_line) from e
        for item in iterator:
            await run_item(gen_idx, comp, item)

    self.env_stack.append(self.env_stack[-1].copy())
    try:
        await rec(0)
    finally:
        self.env_stack.pop()
    return result
89
+
90
async def visit_SetComp(self: ASTInterpreter, node: ast.SetComp, wrap_exceptions: bool = True) -> set:
    """Evaluate a set comprehension ``{elt for ... in ... if ...}``.

    The comprehension runs in a copy of the current frame, and each item gets
    a further copy, so loop variables never leak into the enclosing scope.
    Every pushed frame is popped in a ``finally``, so an exception anywhere
    inside leaves ``env_stack`` balanced.

    ``if`` clauses short-circuit as in Python.  Only a failure of ``iter()``
    itself is reported as "not iterable"; a ``TypeError`` from the element
    or a filter propagates as-is.
    """
    result = set()

    async def passes_filters(comp: ast.comprehension) -> bool:
        # Evaluate the `if` clauses left to right, stopping at the first falsy one.
        for if_clause in comp.ifs:
            if not await self.visit(if_clause, wrap_exceptions=True):
                return False
        return True

    async def run_item(gen_idx: int, comp: ast.comprehension, item: Any) -> None:
        # Bind one item in its own frame and recurse into the next clause.
        self.env_stack.append(self.env_stack[-1].copy())
        try:
            await self.assign(comp.target, item)
            if await passes_filters(comp):
                await rec(gen_idx + 1)
        finally:
            self.env_stack.pop()

    async def rec(gen_idx: int) -> None:
        if gen_idx == len(node.generators):
            result.add(await self.visit(node.elt, wrap_exceptions=True))
            return
        comp = node.generators[gen_idx]
        iterable = await self.visit(comp.iter, wrap_exceptions=wrap_exceptions)
        if hasattr(iterable, '__aiter__'):
            async for item in iterable:
                await run_item(gen_idx, comp, item)
            return
        try:
            iterator = iter(iterable)
        except TypeError as e:
            lineno = getattr(node, "lineno", 1)
            col = getattr(node, "col_offset", 0)
            context_line = self.source_lines[lineno - 1] if self.source_lines and lineno <= len(self.source_lines) else ""
            raise WrappedException(f"Object {iterable} is not iterable", e, lineno, col, context_line) from e
        for item in iterator:
            await run_item(gen_idx, comp, item)

    self.env_stack.append(self.env_stack[-1].copy())
    try:
        await rec(0)
    finally:
        self.env_stack.pop()
    return result
129
+
130
async def visit_GeneratorExp(self: ASTInterpreter, node: ast.GeneratorExp, wrap_exceptions: bool = True) -> Any:
    """Evaluate a generator expression into a lazy async generator.

    Nothing is evaluated until the result is iterated with ``async for``.
    Each item of the outermost clause is bound in a copy of the frame that
    is current when the consumer pulls the value.  Items of inner clauses
    are bound in a copy of their outer item's frame.  Loop variables thus
    stay private to the expression.

    Frames are pushed only around the generator's own evaluations (iterables,
    targets, ``if`` clauses, the element) and are always popped before a
    ``yield``.  This matters because the consumer runs while the generator
    is suspended.  A frame still pushed at a ``yield`` would capture the
    consumer's own assignments, e.g. the body of a ``for`` loop iterating
    the generator.

    ``if`` clauses short-circuit as in Python.  Only a failure of ``iter()``
    itself is reported as "not iterable".
    """

    async def in_frame(frame: Any, func: Any, *args: Any, **kwargs: Any) -> Any:
        # Await func(*args, **kwargs) with `frame` temporarily on top of
        # env_stack; None means "evaluate in whatever frame is current".
        if frame is None:
            return await func(*args, **kwargs)
        self.env_stack.append(frame)
        try:
            return await func(*args, **kwargs)
        finally:
            self.env_stack.pop()

    async def bind_and_filter(comp: ast.comprehension, item: Any) -> bool:
        # Bind the loop target, then apply the `if` clauses left to right,
        # stopping at the first falsy one.
        await self.assign(comp.target, item)
        for if_clause in comp.ifs:
            if not await self.visit(if_clause, wrap_exceptions=True):
                return False
        return True

    async def from_sync(iterator: Any) -> Any:
        # Adapt a plain iterator so both kinds can be consumed with `async for`.
        for item in iterator:
            yield item

    async def rec(gen_idx: int, parent: Any) -> Any:
        if gen_idx == len(node.generators):
            yield await in_frame(parent, self.visit, node.elt, wrap_exceptions=True)
            return
        comp = node.generators[gen_idx]
        iterable = await in_frame(parent, self.visit, comp.iter, wrap_exceptions=wrap_exceptions)
        if hasattr(iterable, '__aiter__'):
            items = iterable
        else:
            try:
                items = from_sync(iter(iterable))
            except TypeError as e:
                lineno = getattr(node, "lineno", 1)
                col = getattr(node, "col_offset", 0)
                context_line = self.source_lines[lineno - 1] if self.source_lines and lineno <= len(self.source_lines) else ""
                raise WrappedException(f"Object {iterable} is not iterable", e, lineno, col, context_line) from e
        async for item in items:
            # Outermost clause (parent is None): copy the frame current at pull time.
            item_frame = (self.env_stack[-1] if parent is None else parent).copy()
            if await in_frame(item_frame, bind_and_filter, comp, item):
                async for value in rec(gen_idx + 1, item_frame):
                    yield value

    return rec(0, None)
@@ -0,0 +1,59 @@
1
+ import ast
2
+ from typing import Any
3
+
4
+ from .exceptions import ReturnException
5
+ from .interpreter_core import ASTInterpreter
6
+
7
async def visit_With(self: ASTInterpreter, node: ast.With, wrap_exceptions: bool = True) -> Any:
    """Execute a ``with`` statement using Python's context-manager semantics.

    ``with a as x, b as y: body`` behaves like the nested form
    ``with a as x: with b as y: body``:

    * managers are entered left to right and exited right to left;
    * if evaluating or entering a later manager (or binding its target)
      fails, the managers already entered are still exited with that error;
    * every entered manager's ``__exit__`` runs, even when an inner one lets
      the exception propagate or raises a new one (which then replaces it);
    * a truthy ``__exit__`` result suppresses the exception: outer managers
      then receive ``(None, None, None)`` and execution continues after the
      statement;
    * ``return``/``break``/``continue`` (the interpreter's control-flow
      exceptions) exit every manager with ``(None, None, None)`` and then
      propagate.

    Returns:
        The value of the last body statement that completed, or ``None``.
    """
    from .exceptions import BreakException, ContinueException

    result = None
    entered = []  # managers whose __enter__ succeeded, in entry order
    flow_signal = None  # pending return/break/continue, re-raised after exiting
    error = None  # pending exception, handed to each __exit__ in turn
    try:
        for item in node.items:
            ctx = await self.visit(item.context_expr, wrap_exceptions=wrap_exceptions)
            val = ctx.__enter__()
            entered.append(ctx)
            if item.optional_vars:
                await self.assign(item.optional_vars, val)
        for stmt in node.body:
            result = await self.visit(stmt, wrap_exceptions=wrap_exceptions)
    except (ReturnException, BreakException, ContinueException) as signal:
        flow_signal = signal
    except BaseException as exc:
        error = exc

    for ctx in reversed(entered):
        try:
            if error is None:
                ctx.__exit__(None, None, None)
            elif ctx.__exit__(type(error), error, error.__traceback__):
                error = None  # suppressed by this manager
        except BaseException as exc:
            error = exc  # an exception raised by __exit__ replaces the pending one

    if error is not None:
        raise error
    if flow_signal is not None:
        raise flow_signal
    return result
33
+
34
async def visit_AsyncWith(self: ASTInterpreter, node: ast.AsyncWith, wrap_exceptions: bool = True) -> Any:
    """Execute an ``async with`` statement using Python's context-manager semantics.

    ``async with a as x, b as y: body`` behaves like the nested form:

    * managers are entered (``await __aenter__``) left to right and exited
      (``await __aexit__``) right to left;
    * if evaluating or entering a later manager (or binding its target)
      fails, the managers already entered are still exited with that error;
    * every entered manager's ``__aexit__`` runs, even when an inner one lets
      the exception propagate or raises a new one (which then replaces it);
    * a truthy ``__aexit__`` result suppresses the exception: outer managers
      then receive ``(None, None, None)`` and execution continues after the
      statement;
    * ``return``/``break``/``continue`` (the interpreter's control-flow
      exceptions) exit every manager with ``(None, None, None)`` and then
      propagate.

    Returns:
        The value of the last body statement that completed, or ``None``.
    """
    from .exceptions import BreakException, ContinueException

    result = None
    entered = []  # managers whose __aenter__ succeeded, in entry order
    flow_signal = None  # pending return/break/continue, re-raised after exiting
    error = None  # pending exception, handed to each __aexit__ in turn
    try:
        for item in node.items:
            ctx = await self.visit(item.context_expr, wrap_exceptions=wrap_exceptions)
            val = await ctx.__aenter__()
            entered.append(ctx)
            if item.optional_vars:
                await self.assign(item.optional_vars, val)
        for stmt in node.body:
            result = await self.visit(stmt, wrap_exceptions=wrap_exceptions)
    except (ReturnException, BreakException, ContinueException) as signal:
        flow_signal = signal
    except BaseException as exc:
        error = exc

    for ctx in reversed(entered):
        try:
            if error is None:
                await ctx.__aexit__(None, None, None)
            elif await ctx.__aexit__(type(error), error, error.__traceback__):
                error = None  # suppressed by this manager
        except BaseException as exc:
            error = exc  # an exception raised by __aexit__ replaces the pending one

    if error is not None:
        raise error
    if flow_signal is not None:
        raise flow_signal
    return result
@@ -0,0 +1,88 @@
1
+ import ast
2
+ from typing import Any
3
+
4
+ from .exceptions import BreakException, ContinueException, ReturnException
5
+ from .interpreter_core import ASTInterpreter
6
+
7
async def visit_If(self: ASTInterpreter, node: ast.If, wrap_exceptions: bool = True) -> Any:
    """Execute an ``if`` statement.

    The body runs when the test is truthy; otherwise the ``orelse`` list runs
    (an ``elif`` appears there as a nested ``If``).  Returns the value of the
    last statement in the chosen branch, or ``None`` if that branch is empty.
    """
    condition = await self.visit(node.test, wrap_exceptions=wrap_exceptions)
    chosen = node.body if condition else node.orelse
    last = None
    for stmt in chosen:
        last = await self.visit(stmt, wrap_exceptions=wrap_exceptions)
    return last
18
+
19
async def visit_While(self: ASTInterpreter, node: ast.While, wrap_exceptions: bool = True) -> None:
    """Execute a ``while`` loop together with its optional ``else`` clause.

    The test is re-evaluated before every iteration.  ``break`` and
    ``continue`` in the body arrive as ``BreakException`` /
    ``ContinueException``.  As in Python, the ``else`` block runs only when
    the loop ends because the test became falsy, not when it is left via
    ``break``.
    """
    while await self.visit(node.test, wrap_exceptions=wrap_exceptions):
        try:
            for stmt in node.body:
                await self.visit(stmt, wrap_exceptions=wrap_exceptions)
        except BreakException:
            break
        except ContinueException:
            continue
    else:
        # Reached only when the Python-level loop was not exited via `break`.
        for stmt in node.orelse:
            await self.visit(stmt, wrap_exceptions=wrap_exceptions)
30
+
31
async def visit_For(self: ASTInterpreter, node: ast.For, wrap_exceptions: bool = True) -> None:
    """Execute a ``for`` loop over a sync or async iterable.

    Objects exposing ``__aiter__`` are consumed with ``async for``; anything
    else is consumed with a regular ``for``.  ``break``/``continue`` in the
    body arrive as ``BreakException``/``ContinueException``.  The ``else``
    block runs only when the loop was not left via ``break``.
    """
    iterable: Any = await self.visit(node.iter, wrap_exceptions=wrap_exceptions)

    async def run_body() -> bool:
        # One pass over the body; False signals that `break` was executed.
        try:
            for stmt in node.body:
                await self.visit(stmt, wrap_exceptions=wrap_exceptions)
        except BreakException:
            return False
        except ContinueException:
            pass
        return True

    if hasattr(iterable, '__aiter__'):
        async for element in iterable:
            await self.assign(node.target, element)
            if not await run_body():
                return
    else:
        for element in iterable:
            await self.assign(node.target, element)
            if not await run_body():
                return

    for stmt in node.orelse:
        await self.visit(stmt, wrap_exceptions=wrap_exceptions)
59
+
60
async def visit_AsyncFor(self: ASTInterpreter, node: ast.AsyncFor, wrap_exceptions: bool = True) -> None:
    """Execute an ``async for`` loop together with its optional ``else`` clause.

    ``break``/``continue`` in the body arrive as ``BreakException`` /
    ``ContinueException``.  The ``else`` block runs only when iteration is
    exhausted without a ``break``.
    """
    source = await self.visit(node.iter, wrap_exceptions=wrap_exceptions)
    async for element in source:
        await self.assign(node.target, element)
        try:
            for stmt in node.body:
                await self.visit(stmt, wrap_exceptions=wrap_exceptions)
        except BreakException:
            break
        except ContinueException:
            continue
    else:
        for stmt in node.orelse:
            await self.visit(stmt, wrap_exceptions=wrap_exceptions)
76
+
77
async def visit_Break(self: ASTInterpreter, node: ast.Break, wrap_exceptions: bool = True) -> None:
    """Signal ``break`` to the innermost enclosing loop visitor."""
    signal = BreakException()
    raise signal
79
+
80
async def visit_Continue(self: ASTInterpreter, node: ast.Continue, wrap_exceptions: bool = True) -> None:
    """Signal ``continue`` to the innermost enclosing loop visitor."""
    signal = ContinueException()
    raise signal
82
+
83
async def visit_Return(self: ASTInterpreter, node: ast.Return, wrap_exceptions: bool = True) -> None:
    """Signal ``return`` by raising ``ReturnException`` with the result value.

    A bare ``return`` carries ``None``.
    """
    if node.value is None:
        raise ReturnException(None)
    returned: Any = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    raise ReturnException(returned)
86
+
87
async def visit_IfExp(self: ASTInterpreter, node: ast.IfExp, wrap_exceptions: bool = True) -> Any:
    """Evaluate ``body if test else orelse``.

    The test is evaluated first; only the selected branch is then evaluated.
    """
    if await self.visit(node.test, wrap_exceptions=wrap_exceptions):
        return await self.visit(node.body, wrap_exceptions=wrap_exceptions)
    return await self.visit(node.orelse, wrap_exceptions=wrap_exceptions)