quantalogic 0.60.0__py3-none-any.whl → 0.61.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantalogic/agent_config.py +5 -5
- quantalogic/agent_factory.py +2 -2
- quantalogic/codeact/__init__.py +0 -0
- quantalogic/codeact/agent.py +499 -0
- quantalogic/codeact/cli.py +232 -0
- quantalogic/codeact/constants.py +9 -0
- quantalogic/codeact/events.py +78 -0
- quantalogic/codeact/llm_util.py +76 -0
- quantalogic/codeact/prompts/error_format.j2 +11 -0
- quantalogic/codeact/prompts/generate_action.j2 +26 -0
- quantalogic/codeact/prompts/generate_program.j2 +39 -0
- quantalogic/codeact/prompts/response_format.j2 +11 -0
- quantalogic/codeact/tools_manager.py +135 -0
- quantalogic/codeact/utils.py +135 -0
- quantalogic/coding_agent.py +2 -2
- quantalogic/python_interpreter/__init__.py +23 -0
- quantalogic/python_interpreter/assignment_visitors.py +63 -0
- quantalogic/python_interpreter/base_visitors.py +20 -0
- quantalogic/python_interpreter/class_visitors.py +22 -0
- quantalogic/python_interpreter/comprehension_visitors.py +172 -0
- quantalogic/python_interpreter/context_visitors.py +59 -0
- quantalogic/python_interpreter/control_flow_visitors.py +88 -0
- quantalogic/python_interpreter/exception_visitors.py +109 -0
- quantalogic/python_interpreter/exceptions.py +39 -0
- quantalogic/python_interpreter/execution.py +202 -0
- quantalogic/python_interpreter/function_utils.py +386 -0
- quantalogic/python_interpreter/function_visitors.py +209 -0
- quantalogic/python_interpreter/import_visitors.py +28 -0
- quantalogic/python_interpreter/interpreter_core.py +358 -0
- quantalogic/python_interpreter/literal_visitors.py +74 -0
- quantalogic/python_interpreter/misc_visitors.py +148 -0
- quantalogic/python_interpreter/operator_visitors.py +108 -0
- quantalogic/python_interpreter/scope.py +10 -0
- quantalogic/python_interpreter/visit_handlers.py +110 -0
- quantalogic/tools/__init__.py +5 -4
- quantalogic/tools/action_gen.py +366 -0
- quantalogic/tools/python_tool.py +13 -0
- quantalogic/tools/{search_definition_names.py → search_definition_names_tool.py} +2 -2
- quantalogic/tools/tool.py +116 -22
- quantalogic/utils/__init__.py +0 -1
- quantalogic/utils/test_python_interpreter.py +119 -0
- {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/METADATA +7 -2
- {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/RECORD +46 -14
- quantalogic/utils/python_interpreter.py +0 -905
- {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/LICENSE +0 -0
- {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/WHEEL +0 -0
- {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,135 @@
|
|
1
|
+
import ast
|
2
|
+
import inspect
|
3
|
+
from functools import wraps
|
4
|
+
from typing import Any, Callable, Tuple
|
5
|
+
|
6
|
+
from loguru import logger
|
7
|
+
from lxml import etree
|
8
|
+
|
9
|
+
|
10
|
+
def log_async_tool(verb: str):
    """Build a decorator that logs the lifecycle of an async tool function.

    The wrapped coroutine function logs a start message, then a line made of
    ``verb`` followed by every bound argument (defaults applied) rendered as
    ``name=value``, then a finish message. If the tool raises, the failure is
    logged and the exception re-raised unchanged, consistent with
    ``log_tool_method``.

    Args:
        verb: Word or phrase prefixed to the argument listing in the log.

    Returns:
        A decorator for ``async def`` functions.
    """

    def decorator(func: Callable) -> Callable:
        # The signature never changes, so compute it once per decorated function.
        sig = inspect.signature(func)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            logger.info(f"Starting tool: {func.__name__}")
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()
            logger.info(f"{verb} {', '.join(f'{k}={v}' for k, v in bound_args.arguments.items())}")
            try:
                result = await func(*args, **kwargs)
            except Exception as e:
                # Previously a failing tool left no trace after "Starting tool".
                logger.error(f"Tool {func.__name__} failed: {e}")
                raise
            logger.info(f"Finished tool: {func.__name__}")
            return result

        return wrapper

    return decorator
|
25
|
+
|
26
|
+
|
27
|
+
def log_tool_method(func: Callable) -> Callable:
    """Decorate an async ``Tool`` method so its start, finish and failure are logged.

    The tool is identified in log messages by ``self.name``. Exceptions are
    logged and then re-raised unchanged.

    Args:
        func: The ``async`` method to wrap.

    Returns:
        The wrapping coroutine function.
    """

    @wraps(func)
    async def wrapper(self, *args, **kwargs):
        # Positional arguments are forwarded as well; the former
        # ``(self, **kwargs)`` signature rejected them with a TypeError.
        logger.info(f"Starting tool: {self.name}")
        try:
            result = await func(self, *args, **kwargs)
        except Exception as e:
            logger.error(f"Tool {self.name} failed: {e}")
            raise
        logger.info(f"Finished tool: {self.name}")
        return result

    return wrapper
|
40
|
+
|
41
|
+
|
42
|
+
def validate_xml(xml_string: str) -> bool:
    """Return True if ``xml_string`` is well-formed XML.

    On failure the parser error is logged and False is returned.

    The text is parsed as UTF-8 bytes because lxml raises ``ValueError``
    (not ``XMLSyntaxError``) for ``str`` input carrying an
    ``<?xml ... encoding="..."?>`` declaration. That ValueError used to escape
    this function instead of yielding a boolean.

    Args:
        xml_string: The XML document to check.

    Returns:
        True when the document parses, False otherwise.
    """
    try:
        data = xml_string.encode("utf-8") if isinstance(xml_string, str) else xml_string
        etree.fromstring(data)
        return True
    except (etree.XMLSyntaxError, ValueError) as e:
        # ValueError also covers UnicodeEncodeError (e.g. lone surrogates).
        logger.error(f"XML validation failed: {e}")
        return False
|
50
|
+
|
51
|
+
|
52
|
+
def validate_code(code: str) -> bool:
    """Check whether ``code`` parses and defines an ``async def main`` anywhere.

    The whole syntax tree is searched, so a nested ``async def main`` also
    counts.

    Args:
        code: Python source text.

    Returns:
        True if the source parses and contains an ``async def main``; False if
        it does not, or if it cannot be parsed at all.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # Before Python 3.12, ast.parse raises ValueError (not SyntaxError)
        # for source containing NUL bytes; treat that as invalid code too.
        return False
    return any(
        isinstance(node, ast.AsyncFunctionDef) and node.name == "main"
        for node in ast.walk(tree)
    )
|
60
|
+
|
61
|
+
|
62
|
+
def format_xml_element(tag: str, value: Any, **attribs) -> etree.Element:
    """Create an XML element whose text is ``str(value)``, stored as CDATA when possible.

    Args:
        tag: Element tag name.
        value: Element content. ``None`` leaves the element without text.
        **attribs: XML attributes forwarded to ``etree.Element``.

    Returns:
        The new lxml element.
    """
    elem = etree.Element(tag, **attribs)
    if value is not None:
        text = str(value)
        # lxml's CDATA() raises ValueError when the data contains "]]>".
        # Fall back to a plain text node, which lxml escapes on serialisation.
        elem.text = etree.CDATA(text) if "]]>" not in text else text
    return elem
|
67
|
+
|
68
|
+
|
69
|
+
class XMLResultHandler:
    """Utility class for formatting execution results as XML and parsing XML back.

    All methods are static; the class serves only as a namespace.
    """

    @staticmethod
    def format_execution_result(result) -> str:
        """Serialise an execution result as a pretty-printed ``<ExecutionResult>`` document.

        ``result`` must expose ``result``, ``error``, ``execution_time`` and
        ``local_variables`` attributes (presumably an ``AsyncExecutionResult`` —
        confirm against callers).

        Emitted children:

        * ``Status``: "Success" or "Error".
        * ``Value``: the result, or the error when the result is falsy.
        * ``ExecutionTime``.
        * ``Completed``: always "true" or "false".
        * ``FinalAnswer``: present when the result is a string starting with
          "Task completed:".
        * ``Variables``: non-callable, non-dunder locals, each truncated to
          5000 characters.
        """
        root = etree.Element("ExecutionResult")
        root.append(format_xml_element("Status", "Success" if not result.error else "Error"))
        root.append(format_xml_element("Value", result.result or result.error))
        root.append(format_xml_element("ExecutionTime", f"{result.execution_time:.2f} seconds"))

        # Force a real bool so <Completed> is always "true"/"false". The former
        # ``result.result and ...`` produced "none"/"" for empty results and
        # raised AttributeError for non-string results.
        completed = isinstance(result.result, str) and result.result.startswith("Task completed:")
        root.append(format_xml_element("Completed", str(completed).lower()))

        if completed:
            final_answer = result.result[len("Task completed:"):].strip()
            root.append(format_xml_element("FinalAnswer", final_answer))

        if result.local_variables:
            vars_elem = etree.SubElement(root, "Variables")
            for k, v in result.local_variables.items():
                if callable(v) or k.startswith("__"):
                    continue
                text = str(v)
                if len(text) > 5000:
                    text = text[:5000] + "... (truncated)"
                vars_elem.append(format_xml_element("Variable", text, name=k))
        return etree.tostring(root, pretty_print=True, encoding="unicode")

    @staticmethod
    def format_result_summary(result_xml: str) -> str:
        """Render an ``<ExecutionResult>`` document as a bulleted plain-text summary.

        Missing fields are shown as "N/A". If ``result_xml`` is not well-formed
        XML, the failure is logged and the input is returned unchanged.
        """
        try:
            root = etree.fromstring(result_xml)
        except etree.XMLSyntaxError:
            logger.error(f"Failed to parse XML: {result_xml}")
            return result_xml

        lines = [
            f"- Status: {root.findtext('Status', 'N/A')}",
            f"- Value: {root.findtext('Value', 'N/A')}",
            f"- Execution Time: {root.findtext('ExecutionTime', 'N/A')}",
            f"- Completed: {root.findtext('Completed', 'N/A').capitalize()}",
        ]
        if final_answer := root.findtext("FinalAnswer"):
            lines.append(f"- Final Answer: {final_answer}")

        if (vars_elem := root.find("Variables")) is not None:
            lines.append("- Variables:")
            # An empty <Variable/> has ``text is None``; guard before .strip().
            lines.extend(
                f" - {var.get('name', 'unknown')}: {(var.text or '').strip() or 'N/A'}"
                for var in vars_elem.findall("Variable")
            )
        return "\n".join(lines)

    @staticmethod
    def parse_response(response: str) -> Tuple[str, str]:
        """Parse an XML response and return its ``(Thought, Code)`` texts.

        Missing elements yield empty strings.

        Raises:
            ValueError: if ``response`` is not well-formed XML. The lxml error
                is chained as the cause.
        """
        try:
            root = etree.fromstring(response)
        except etree.XMLSyntaxError as e:
            raise ValueError(f"Failed to parse XML: {e}") from e
        thought = root.findtext("Thought") or ""
        code = root.findtext("Code") or ""
        return thought, code

    @staticmethod
    def extract_result_value(result: str) -> str:
        """Return the ``<Value>`` text from a result document.

        Returns "" when the element is missing or the XML cannot be parsed.
        """
        try:
            return etree.fromstring(result).findtext("Value") or ""
        except etree.XMLSyntaxError:
            return ""
|
quantalogic/coding_agent.py
CHANGED
@@ -15,7 +15,7 @@ from quantalogic.tools import (
|
|
15
15
|
ReadHTMLTool,
|
16
16
|
ReplaceInFileTool,
|
17
17
|
RipgrepTool,
|
18
|
-
|
18
|
+
SearchDefinitionNamesTool,
|
19
19
|
TaskCompleteTool,
|
20
20
|
WriteFileTool,
|
21
21
|
)
|
@@ -71,7 +71,7 @@ def create_coding_agent(
|
|
71
71
|
# Code navigation and search tools
|
72
72
|
ListDirectoryTool(), # Lists directory contents
|
73
73
|
RipgrepTool(), # Searches code with regex
|
74
|
-
|
74
|
+
SearchDefinitionNamesTool(), # Finds code definitions
|
75
75
|
# Specialized language model tools
|
76
76
|
ReadFileTool(),
|
77
77
|
ExecuteBashCommandTool(),
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# quantalogic/python_interpreter/__init__.py
|
2
|
+
from .exceptions import BreakException, ContinueException, ReturnException, WrappedException, has_await
|
3
|
+
from .execution import AsyncExecutionResult, execute_async, interpret_ast, interpret_code
|
4
|
+
from .function_utils import AsyncFunction, Function, LambdaFunction
|
5
|
+
from .interpreter_core import ASTInterpreter
|
6
|
+
from .scope import Scope
|
7
|
+
|
8
|
+
# Names re-exported as this package's public API (what ``from ... import *`` exposes);
# each one is imported from a submodule above.
__all__ = [
    'ASTInterpreter',
    'execute_async',
    'interpret_ast',
    'interpret_code',
    'AsyncExecutionResult',
    'ReturnException',
    'BreakException',
    'ContinueException',
    'WrappedException',
    'has_await',
    'Function',
    'AsyncFunction',
    'LambdaFunction',
    'Scope',
]
|
@@ -0,0 +1,63 @@
|
|
1
|
+
import ast
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
from .interpreter_core import ASTInterpreter
|
5
|
+
|
6
|
+
async def visit_Assign(self: ASTInterpreter, node: ast.Assign, wrap_exceptions: bool = True) -> None:
    """Execute ``t1 = t2 = ... = value``.

    The right-hand side is evaluated once and then bound to each target in
    source order. Subscript targets are stored directly with item assignment.
    Every other target kind is delegated to ``self.assign``.
    """
    rhs: Any = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    for tgt in node.targets:
        if not isinstance(tgt, ast.Subscript):
            await self.assign(tgt, rhs)
            continue
        container = await self.visit(tgt.value, wrap_exceptions=wrap_exceptions)
        index = await self.visit(tgt.slice, wrap_exceptions=wrap_exceptions)
        container[index] = rhs
|
15
|
+
|
16
|
+
async def visit_AugAssign(self: ASTInterpreter, node: ast.AugAssign, wrap_exceptions: bool = True) -> Any:
    """Execute an augmented assignment such as ``x += y`` and return the new value.

    The current target value is read first (by name lookup for plain names,
    otherwise by visiting the target), then the right-hand side is evaluated.
    The matching in-place function from :mod:`operator` is applied. As in
    CPython, mutable objects are therefore updated in place (after
    ``a = b = []; a += [1]``, ``b == [1]``), while immutable objects fall back
    to the binary operator. The result is re-bound to the target.

    Raises:
        Exception: if the operator has no in-place counterpart here.
    """
    import operator  # function-scope import keeps this edit self-contained

    inplace_ops = {
        ast.Add: operator.iadd,
        ast.Sub: operator.isub,
        ast.Mult: operator.imul,
        ast.MatMult: operator.imatmul,
        ast.Div: operator.itruediv,
        ast.FloorDiv: operator.ifloordiv,
        ast.Mod: operator.imod,
        ast.Pow: operator.ipow,
        ast.BitAnd: operator.iand,
        ast.BitOr: operator.ior,
        ast.BitXor: operator.ixor,
        ast.LShift: operator.ilshift,
        ast.RShift: operator.irshift,
    }

    current_val: Any
    if isinstance(node.target, ast.Name):
        current_val = self.get_variable(node.target.id)
    else:
        current_val = await self.visit(node.target, wrap_exceptions=wrap_exceptions)
    right_val: Any = await self.visit(node.value, wrap_exceptions=wrap_exceptions)

    op_func = inplace_ops.get(type(node.op))
    if op_func is None:
        raise Exception("Unsupported augmented operator: " + str(node.op))
    result: Any = op_func(current_val, right_val)
    await self.assign(node.target, result)
    return result
|
51
|
+
|
52
|
+
async def visit_AnnAssign(self: ASTInterpreter, node: ast.AnnAssign, wrap_exceptions: bool = True) -> None:
    """Execute an annotated assignment ``target: annotation [= value]``.

    The value (if any) is evaluated first, then the annotation. For a plain
    name target, the evaluated annotation is recorded in ``self.type_hints``.
    The target is bound only when the statement has a value. A bare
    ``x: int`` declares but does not bind ``x``, matching Python; the former
    ``or node.simple`` condition bound it to None.
    """
    value = await self.visit(node.value, wrap_exceptions=wrap_exceptions) if node.value else None
    annotation = await self.visit(node.annotation, wrap_exceptions=True)
    if isinstance(node.target, ast.Name):
        self.type_hints[node.target.id] = annotation
    # Test the AST node, not the evaluated value, so ``x: int = None`` still binds.
    if node.value is not None:
        await self.assign(node.target, value)
|
59
|
+
|
60
|
+
async def visit_NamedExpr(self: ASTInterpreter, node: ast.NamedExpr, wrap_exceptions: bool = True) -> Any:
    """Evaluate the walrus expression ``target := value``.

    The value is bound to the target and also returned as the expression's result.
    """
    evaluated = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    await self.assign(node.target, evaluated)
    return evaluated
|
@@ -0,0 +1,20 @@
|
|
1
|
+
import ast
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
from .exceptions import WrappedException
|
5
|
+
from .interpreter_core import ASTInterpreter
|
6
|
+
|
7
|
+
async def visit_Module(self: ASTInterpreter, node: ast.Module, wrap_exceptions: bool = True) -> Any:
    """Run a module's top-level statements in order.

    Statements are always visited with ``wrap_exceptions=True``, whatever the
    caller passed. Returns the value produced by the last statement, or None
    for an empty module.
    """
    outcome = None
    statements = node.body
    for statement in statements:
        outcome = await self.visit(statement, wrap_exceptions=True)
    return outcome
|
12
|
+
|
13
|
+
async def visit_Expr(self: ASTInterpreter, node: ast.Expr, wrap_exceptions: bool = True) -> Any:
    """Evaluate an expression statement and return the expression's value."""
    inner = node.value
    return await self.visit(inner, wrap_exceptions=wrap_exceptions)
|
15
|
+
|
16
|
+
async def visit_Pass(self: ASTInterpreter, node: ast.Pass, wrap_exceptions: bool = True) -> None:
    """Execute ``pass``: do nothing and produce None."""
    return
|
18
|
+
|
19
|
+
async def visit_TypeIgnore(self: ASTInterpreter, node: ast.TypeIgnore, wrap_exceptions: bool = True) -> None:
    """Ignore a ``# type: ignore`` marker node; it has no runtime effect."""
|
@@ -0,0 +1,22 @@
|
|
1
|
+
import ast
|
2
|
+
from typing import Any, Dict, List
|
3
|
+
|
4
|
+
from .function_utils import Function
|
5
|
+
from .interpreter_core import ASTInterpreter
|
6
|
+
|
7
|
+
async def visit_ClassDef(self: ASTInterpreter, node: ast.ClassDef, wrap_exceptions: bool = True) -> Any:
    """Execute a ``class`` statement and bind the new class in the enclosing frame.

    Steps:

    1. Evaluate the base expressions.
    2. Run the body in a fresh frame pushed on ``env_stack``.
    3. Build the class with ``type`` from that frame's contents, minus
       ``__builtins__``.
    4. Set ``defining_class`` on every interpreted ``Function`` in the
       namespace.

    Class decorators and class keywords (e.g. ``metaclass=``) are not handled
    here.

    Returns:
        The newly created class.
    """
    # Bases are evaluated before the class frame is pushed. This matches
    # Python's evaluation order, and it guarantees an exception here cannot
    # leave an orphaned frame: the old code pushed outside the try/finally.
    bases = [await self.visit(base, wrap_exceptions=True) for base in node.bases]
    self.env_stack.append({})
    try:
        for stmt in node.body:
            await self.visit(stmt, wrap_exceptions=True)
        class_dict = {k: v for k, v in self.env_stack[-1].items() if k != "__builtins__"}
        cls = type(node.name, tuple(bases), class_dict)
        for value in class_dict.values():
            if isinstance(value, Function):
                value.defining_class = cls
        # env_stack[-2] is the frame that was current when the class statement ran.
        self.env_stack[-2][node.name] = cls
        return cls
    finally:
        self.env_stack.pop()
|
@@ -0,0 +1,172 @@
|
|
1
|
+
import ast
|
2
|
+
from typing import Any, Dict, List
|
3
|
+
|
4
|
+
from .interpreter_core import ASTInterpreter
|
5
|
+
from .exceptions import WrappedException
|
6
|
+
|
7
|
+
async def visit_ListComp(self: ASTInterpreter, node: ast.ListComp, wrap_exceptions: bool = True) -> List[Any]:
    """Evaluate a list comprehension.

    The comprehension runs in a copy of the current frame pushed on
    ``env_stack``. Each item of each ``for`` clause gets its own frame copy on
    top of that, so loop targets do not leak into the enclosing scope. Every
    frame is popped in a ``finally``, so an exception cannot leave
    ``env_stack`` corrupted. Both sync iterables and async ones (anything with
    ``__aiter__``) are accepted. ``if`` clauses short-circuit as in Python.

    Returns:
        The list of evaluated elements.

    Raises:
        WrappedException: if a ``for`` clause's iterable is not iterable.
    """
    result: List[Any] = []

    async def run_clause(gen_idx: int) -> None:
        # Past the last ``for`` clause: every target is bound, so emit one element.
        if gen_idx == len(node.generators):
            result.append(await self.visit(node.elt, wrap_exceptions=wrap_exceptions))
            return
        comp = node.generators[gen_idx]
        iterable = await self.visit(comp.iter, wrap_exceptions=wrap_exceptions)
        if hasattr(iterable, '__aiter__'):
            async for item in iterable:
                await bind_and_descend(comp, gen_idx, item)
            return
        # Guard only the iter() call. A TypeError raised while evaluating the
        # element or a condition must not be reported as "not iterable".
        try:
            iterator = iter(iterable)
        except TypeError as e:
            lineno = getattr(node, "lineno", 1)
            col = getattr(node, "col_offset", 0)
            context_line = self.source_lines[lineno - 1] if self.source_lines and lineno <= len(self.source_lines) else ""
            raise WrappedException(f"Object {iterable} is not iterable", e, lineno, col, context_line) from e
        for item in iterator:
            await bind_and_descend(comp, gen_idx, item)

    async def bind_and_descend(comp: ast.comprehension, gen_idx: int, item: Any) -> None:
        # Bind ``item`` in a fresh frame, check the ``if`` clauses, then recurse.
        self.env_stack.append(self.env_stack[-1].copy())
        try:
            await self.assign(comp.target, item)
            for if_clause in comp.ifs:
                if not await self.visit(if_clause, wrap_exceptions=True):
                    return
            await run_clause(gen_idx + 1)
        finally:
            self.env_stack.pop()

    self.env_stack.append(self.env_stack[-1].copy())
    try:
        await run_clause(0)
    finally:
        self.env_stack.pop()
    return result
|
47
|
+
|
48
|
+
async def visit_DictComp(self: ASTInterpreter, node: ast.DictComp, wrap_exceptions: bool = True) -> Dict[Any, Any]:
    """Evaluate a dict comprehension.

    The comprehension runs in a copy of the current frame pushed on
    ``env_stack``. Each item of each ``for`` clause gets its own frame copy on
    top. Every frame is popped in a ``finally``, so an exception cannot leave
    ``env_stack`` corrupted. Both sync and async (``__aiter__``) iterables are
    accepted. ``if`` clauses short-circuit as in Python. Later duplicate keys
    overwrite earlier ones.

    Returns:
        The resulting dict.

    Raises:
        WrappedException: if a ``for`` clause's iterable is not iterable.
    """
    result: Dict[Any, Any] = {}

    async def run_clause(gen_idx: int) -> None:
        # Past the last ``for`` clause: every target is bound, so emit one entry.
        if gen_idx == len(node.generators):
            key = await self.visit(node.key, wrap_exceptions=True)
            val = await self.visit(node.value, wrap_exceptions=True)
            result[key] = val
            return
        comp = node.generators[gen_idx]
        iterable = await self.visit(comp.iter, wrap_exceptions=wrap_exceptions)
        if hasattr(iterable, '__aiter__'):
            async for item in iterable:
                await bind_and_descend(comp, gen_idx, item)
            return
        # Guard only the iter() call, so unrelated TypeErrors are not misreported.
        try:
            iterator = iter(iterable)
        except TypeError as e:
            lineno = getattr(node, "lineno", 1)
            col = getattr(node, "col_offset", 0)
            context_line = self.source_lines[lineno - 1] if self.source_lines and lineno <= len(self.source_lines) else ""
            raise WrappedException(f"Object {iterable} is not iterable", e, lineno, col, context_line) from e
        for item in iterator:
            await bind_and_descend(comp, gen_idx, item)

    async def bind_and_descend(comp: ast.comprehension, gen_idx: int, item: Any) -> None:
        # Bind ``item`` in a fresh frame, check the ``if`` clauses, then recurse.
        self.env_stack.append(self.env_stack[-1].copy())
        try:
            await self.assign(comp.target, item)
            for if_clause in comp.ifs:
                if not await self.visit(if_clause, wrap_exceptions=True):
                    return
            await run_clause(gen_idx + 1)
        finally:
            self.env_stack.pop()

    self.env_stack.append(self.env_stack[-1].copy())
    try:
        await run_clause(0)
    finally:
        self.env_stack.pop()
    return result
|
89
|
+
|
90
|
+
async def visit_SetComp(self: ASTInterpreter, node: ast.SetComp, wrap_exceptions: bool = True) -> set:
    """Evaluate a set comprehension.

    The comprehension runs in a copy of the current frame pushed on
    ``env_stack``. Each item of each ``for`` clause gets its own frame copy on
    top. Every frame is popped in a ``finally``, so an exception cannot leave
    ``env_stack`` corrupted. Both sync and async (``__aiter__``) iterables are
    accepted. ``if`` clauses short-circuit as in Python.

    Returns:
        The resulting set.

    Raises:
        WrappedException: if a ``for`` clause's iterable is not iterable.
    """
    result: set = set()

    async def run_clause(gen_idx: int) -> None:
        # Past the last ``for`` clause: every target is bound, so emit one element.
        if gen_idx == len(node.generators):
            result.add(await self.visit(node.elt, wrap_exceptions=True))
            return
        comp = node.generators[gen_idx]
        iterable = await self.visit(comp.iter, wrap_exceptions=wrap_exceptions)
        if hasattr(iterable, '__aiter__'):
            async for item in iterable:
                await bind_and_descend(comp, gen_idx, item)
            return
        # Guard only the iter() call, so unrelated TypeErrors are not misreported.
        try:
            iterator = iter(iterable)
        except TypeError as e:
            lineno = getattr(node, "lineno", 1)
            col = getattr(node, "col_offset", 0)
            context_line = self.source_lines[lineno - 1] if self.source_lines and lineno <= len(self.source_lines) else ""
            raise WrappedException(f"Object {iterable} is not iterable", e, lineno, col, context_line) from e
        for item in iterator:
            await bind_and_descend(comp, gen_idx, item)

    async def bind_and_descend(comp: ast.comprehension, gen_idx: int, item: Any) -> None:
        # Bind ``item`` in a fresh frame, check the ``if`` clauses, then recurse.
        self.env_stack.append(self.env_stack[-1].copy())
        try:
            await self.assign(comp.target, item)
            for if_clause in comp.ifs:
                if not await self.visit(if_clause, wrap_exceptions=True):
                    return
            await run_clause(gen_idx + 1)
        finally:
            self.env_stack.pop()

    self.env_stack.append(self.env_stack[-1].copy())
    try:
        await run_clause(0)
    finally:
        self.env_stack.pop()
    return result
|
129
|
+
|
130
|
+
async def visit_GeneratorExp(self: ASTInterpreter, node: ast.GeneratorExp, wrap_exceptions: bool = True) -> Any:
    """Build a lazy async generator for a generator expression.

    Nothing is evaluated until the returned async generator is iterated. Its
    scope is a snapshot (copy) of the frame that is current when the
    expression is created, not whatever frame is current when it is consumed.

    Frames for loop targets are pushed on ``env_stack`` only while interpreter
    code is running, and they are always popped before a value is yielded.
    The consumer therefore never executes with this generator's frames on top
    of the stack; previously, assignments made by the consuming loop body
    landed in the generator's frame and were lost. ``if`` clauses
    short-circuit as in Python.

    Returns:
        An async generator object.

    Raises (during iteration):
        WrappedException: if a ``for`` clause's iterable is not iterable.
    """
    base_frame: Dict[str, Any] = self.env_stack[-1].copy()

    async def evaluate_in(frame: Dict[str, Any], expr: ast.AST, wrap: bool) -> Any:
        # Make ``frame`` the innermost scope just long enough to evaluate ``expr``.
        self.env_stack.append(frame)
        try:
            return await self.visit(expr, wrap_exceptions=wrap)
        finally:
            self.env_stack.pop()

    async def bind_and_filter(frame: Dict[str, Any], comp: ast.comprehension, item: Any) -> bool:
        # Bind ``item`` into ``frame``; return True when every ``if`` clause holds.
        self.env_stack.append(frame)
        try:
            await self.assign(comp.target, item)
            for if_clause in comp.ifs:
                if not await self.visit(if_clause, wrap_exceptions=True):
                    return False
            return True
        finally:
            self.env_stack.pop()

    async def rec(gen_idx: int, frame: Dict[str, Any]):
        if gen_idx == len(node.generators):
            value = await evaluate_in(frame, node.elt, True)
            yield value  # no interpreter frame is on env_stack at this point
            return
        comp = node.generators[gen_idx]
        iterable = await evaluate_in(frame, comp.iter, wrap_exceptions)
        if hasattr(iterable, '__aiter__'):
            async for item in iterable:
                item_frame = frame.copy()
                if await bind_and_filter(item_frame, comp, item):
                    async for val in rec(gen_idx + 1, item_frame):
                        yield val
            return
        # Guard only the iter() call, so unrelated TypeErrors are not misreported.
        try:
            iterator = iter(iterable)
        except TypeError as e:
            lineno = getattr(node, "lineno", 1)
            col = getattr(node, "col_offset", 0)
            context_line = self.source_lines[lineno - 1] if self.source_lines and lineno <= len(self.source_lines) else ""
            raise WrappedException(f"Object {iterable} is not iterable", e, lineno, col, context_line) from e
        for item in iterator:
            item_frame = frame.copy()
            if await bind_and_filter(item_frame, comp, item):
                async for val in rec(gen_idx + 1, item_frame):
                    yield val

    return rec(0, base_frame)
|
@@ -0,0 +1,59 @@
|
|
1
|
+
import ast
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
from .exceptions import ReturnException
|
5
|
+
from .interpreter_core import ASTInterpreter
|
6
|
+
|
7
|
+
async def visit_With(self: ASTInterpreter, node: ast.With, wrap_exceptions: bool = True) -> Any:
    """Execute a ``with`` statement with CPython's context-manager semantics.

    Managers are evaluated and entered left to right and exited in reverse
    order through a ``contextlib.ExitStack``. This gives the guarantees of
    nested ``with`` blocks that the previous hand-written version lacked:

    * if evaluating or entering a later manager fails, the already-entered
      ones are still exited;
    * an exception suppressed by an inner ``__exit__`` is neither shown to
      outer managers nor re-raised;
    * every manager is exited, even when an inner ``__exit__`` lets the
      exception propagate.

    ``return``/``break``/``continue`` inside the body (the interpreter's
    control-flow exceptions) exit the managers with ``(None, None, None)``,
    as in Python, and then propagate.

    Returns:
        The value of the last body statement executed, or None.
    """
    import contextlib

    from .exceptions import BreakException, ContinueException

    result = None
    pending_signal = None
    with contextlib.ExitStack() as stack:
        for item in node.items:
            ctx = await self.visit(item.context_expr, wrap_exceptions=wrap_exceptions)
            val = stack.enter_context(ctx)
            if item.optional_vars:
                await self.assign(item.optional_vars, val)
        try:
            for stmt in node.body:
                result = await self.visit(stmt, wrap_exceptions=wrap_exceptions)
        except (ReturnException, BreakException, ContinueException) as signal:
            # Control flow, not an error: let the managers exit cleanly first.
            pending_signal = signal
    if pending_signal is not None:
        raise pending_signal
    return result
|
33
|
+
|
34
|
+
async def visit_AsyncWith(self: ASTInterpreter, node: ast.AsyncWith, wrap_exceptions: bool = True) -> Any:
    """Execute an ``async with`` statement with CPython's context-manager semantics.

    Managers are evaluated and entered (``__aenter__``) left to right and
    exited (``__aexit__``) in reverse order through a
    ``contextlib.AsyncExitStack``. Consequently:

    * a failure while evaluating or entering a later manager still exits the
      earlier ones;
    * an exception suppressed by an inner ``__aexit__`` is neither passed
      outward nor re-raised;
    * every manager is exited, even when an inner one propagates the
      exception.

    ``return``/``break``/``continue`` in the body exit the managers with
    ``(None, None, None)`` and then propagate.

    Returns:
        The value of the last body statement executed, or None.
    """
    import contextlib

    from .exceptions import BreakException, ContinueException

    result = None
    pending_signal = None
    async with contextlib.AsyncExitStack() as stack:
        for item in node.items:
            ctx = await self.visit(item.context_expr, wrap_exceptions=wrap_exceptions)
            val = await stack.enter_async_context(ctx)
            if item.optional_vars:
                await self.assign(item.optional_vars, val)
        try:
            for stmt in node.body:
                result = await self.visit(stmt, wrap_exceptions=wrap_exceptions)
        except (ReturnException, BreakException, ContinueException) as signal:
            # Control flow, not an error: let the managers exit cleanly first.
            pending_signal = signal
    if pending_signal is not None:
        raise pending_signal
    return result
|
@@ -0,0 +1,88 @@
|
|
1
|
+
import ast
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
from .exceptions import BreakException, ContinueException, ReturnException
|
5
|
+
from .interpreter_core import ASTInterpreter
|
6
|
+
|
7
|
+
async def visit_If(self: ASTInterpreter, node: ast.If, wrap_exceptions: bool = True) -> Any:
    """Execute an ``if``/``else`` statement.

    The test is evaluated once, then either ``body`` or ``orelse`` is run.
    Returns the value of the chosen branch's last statement, or None when
    that branch is empty.
    """
    condition = await self.visit(node.test, wrap_exceptions=wrap_exceptions)
    statements = node.body if condition else node.orelse
    last = None
    for stmt in statements:
        last = await self.visit(stmt, wrap_exceptions=wrap_exceptions)
    return last
|
18
|
+
|
19
|
+
async def visit_While(self: ASTInterpreter, node: ast.While, wrap_exceptions: bool = True) -> None:
    """Execute a ``while`` loop, including its optional ``else`` clause.

    ``BreakException`` (from ``break``) ends the loop. ``ContinueException``
    (from ``continue``) skips to the next test. As in Python, the ``else``
    clause runs only when the test becomes falsy, never after ``break``; the
    previous version ran it unconditionally. Python's own ``while ... else``
    provides exactly that semantics here.
    """
    while await self.visit(node.test, wrap_exceptions=wrap_exceptions):
        try:
            for stmt in node.body:
                await self.visit(stmt, wrap_exceptions=wrap_exceptions)
        except BreakException:
            break
        except ContinueException:
            continue
    else:
        for stmt in node.orelse:
            await self.visit(stmt, wrap_exceptions=wrap_exceptions)
|
30
|
+
|
31
|
+
async def visit_For(self: ASTInterpreter, node: ast.For, wrap_exceptions: bool = True) -> None:
    """Execute a ``for`` loop over a sync or async iterable, with ``else``.

    If the evaluated iterable defines ``__aiter__`` it is consumed with
    ``async for``; otherwise a plain ``for`` is used. Each item is bound to
    the target before the body runs. ``break`` ends the loop and skips the
    ``else`` clause; ``continue`` moves on to the next item.
    """
    iter_obj: Any = await self.visit(node.iter, wrap_exceptions=wrap_exceptions)

    async def run_body(item: Any) -> bool:
        # Bind the item and execute the body; True means the body hit ``break``.
        await self.assign(node.target, item)
        try:
            for stmt in node.body:
                await self.visit(stmt, wrap_exceptions=wrap_exceptions)
        except BreakException:
            return True
        except ContinueException:
            pass
        return False

    broke = False
    if hasattr(iter_obj, '__aiter__'):
        async for item in iter_obj:
            if await run_body(item):
                broke = True
                break
    else:
        for item in iter_obj:
            if await run_body(item):
                broke = True
                break

    if not broke:
        for stmt in node.orelse:
            await self.visit(stmt, wrap_exceptions=wrap_exceptions)
|
59
|
+
|
60
|
+
async def visit_AsyncFor(self: ASTInterpreter, node: ast.AsyncFor, wrap_exceptions: bool = True) -> None:
    """Execute ``async for`` with its optional ``else`` clause.

    Each value produced by the async iterable is bound to the target before
    the body runs. ``break`` ends the loop and skips ``else``; ``continue``
    proceeds to the next value. Python's ``async for ... else`` supplies the
    same "else only without break" rule.
    """
    aiterable = await self.visit(node.iter, wrap_exceptions=wrap_exceptions)
    async for value in aiterable:
        await self.assign(node.target, value)
        try:
            for stmt in node.body:
                await self.visit(stmt, wrap_exceptions=wrap_exceptions)
        except BreakException:
            break
        except ContinueException:
            continue
    else:
        for stmt in node.orelse:
            await self.visit(stmt, wrap_exceptions=wrap_exceptions)
|
76
|
+
|
77
|
+
async def visit_Break(self: ASTInterpreter, node: ast.Break, wrap_exceptions: bool = True) -> None:
    """Execute ``break`` by raising ``BreakException`` for the enclosing loop visitor to catch."""
    signal = BreakException()
    raise signal
|
79
|
+
|
80
|
+
async def visit_Continue(self: ASTInterpreter, node: ast.Continue, wrap_exceptions: bool = True) -> None:
    """Execute ``continue`` by raising ``ContinueException`` for the enclosing loop visitor to catch."""
    signal = ContinueException()
    raise signal
|
82
|
+
|
83
|
+
async def visit_Return(self: ASTInterpreter, node: ast.Return, wrap_exceptions: bool = True) -> None:
    """Execute ``return [value]`` by raising ``ReturnException``.

    The exception carries the evaluated value, or None for a bare ``return``.
    """
    if node.value is None:
        raise ReturnException(None)
    returned: Any = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    raise ReturnException(returned)
|
86
|
+
|
87
|
+
async def visit_IfExp(self: ASTInterpreter, node: ast.IfExp, wrap_exceptions: bool = True) -> Any:
    """Evaluate the conditional expression ``body if test else orelse``.

    The test is evaluated first; only the selected branch is evaluated after it.
    """
    if await self.visit(node.test, wrap_exceptions=wrap_exceptions):
        return await self.visit(node.body, wrap_exceptions=wrap_exceptions)
    return await self.visit(node.orelse, wrap_exceptions=wrap_exceptions)
|