quantalogic 0.59.3__py3-none-any.whl → 0.61.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. quantalogic/agent.py +268 -24
  2. quantalogic/agent_config.py +5 -5
  3. quantalogic/agent_factory.py +2 -2
  4. quantalogic/codeact/__init__.py +0 -0
  5. quantalogic/codeact/agent.py +499 -0
  6. quantalogic/codeact/cli.py +232 -0
  7. quantalogic/codeact/constants.py +9 -0
  8. quantalogic/codeact/events.py +78 -0
  9. quantalogic/codeact/llm_util.py +76 -0
  10. quantalogic/codeact/prompts/error_format.j2 +11 -0
  11. quantalogic/codeact/prompts/generate_action.j2 +26 -0
  12. quantalogic/codeact/prompts/generate_program.j2 +39 -0
  13. quantalogic/codeact/prompts/response_format.j2 +11 -0
  14. quantalogic/codeact/tools_manager.py +135 -0
  15. quantalogic/codeact/utils.py +135 -0
  16. quantalogic/coding_agent.py +2 -2
  17. quantalogic/create_custom_agent.py +26 -78
  18. quantalogic/prompts/chat_system_prompt.j2 +10 -7
  19. quantalogic/prompts/code_2_system_prompt.j2 +190 -0
  20. quantalogic/prompts/code_system_prompt.j2 +142 -0
  21. quantalogic/prompts/doc_system_prompt.j2 +178 -0
  22. quantalogic/prompts/legal_2_system_prompt.j2 +218 -0
  23. quantalogic/prompts/legal_system_prompt.j2 +140 -0
  24. quantalogic/prompts/system_prompt.j2 +6 -2
  25. quantalogic/prompts/tools_prompt.j2 +2 -4
  26. quantalogic/prompts.py +23 -4
  27. quantalogic/python_interpreter/__init__.py +23 -0
  28. quantalogic/python_interpreter/assignment_visitors.py +63 -0
  29. quantalogic/python_interpreter/base_visitors.py +20 -0
  30. quantalogic/python_interpreter/class_visitors.py +22 -0
  31. quantalogic/python_interpreter/comprehension_visitors.py +172 -0
  32. quantalogic/python_interpreter/context_visitors.py +59 -0
  33. quantalogic/python_interpreter/control_flow_visitors.py +88 -0
  34. quantalogic/python_interpreter/exception_visitors.py +109 -0
  35. quantalogic/python_interpreter/exceptions.py +39 -0
  36. quantalogic/python_interpreter/execution.py +202 -0
  37. quantalogic/python_interpreter/function_utils.py +386 -0
  38. quantalogic/python_interpreter/function_visitors.py +209 -0
  39. quantalogic/python_interpreter/import_visitors.py +28 -0
  40. quantalogic/python_interpreter/interpreter_core.py +358 -0
  41. quantalogic/python_interpreter/literal_visitors.py +74 -0
  42. quantalogic/python_interpreter/misc_visitors.py +148 -0
  43. quantalogic/python_interpreter/operator_visitors.py +108 -0
  44. quantalogic/python_interpreter/scope.py +10 -0
  45. quantalogic/python_interpreter/visit_handlers.py +110 -0
  46. quantalogic/server/agent_server.py +1 -1
  47. quantalogic/tools/__init__.py +6 -3
  48. quantalogic/tools/action_gen.py +366 -0
  49. quantalogic/tools/duckduckgo_search_tool.py +1 -0
  50. quantalogic/tools/execute_bash_command_tool.py +114 -57
  51. quantalogic/tools/file_tracker_tool.py +49 -0
  52. quantalogic/tools/google_packages/google_news_tool.py +3 -0
  53. quantalogic/tools/image_generation/dalle_e.py +89 -137
  54. quantalogic/tools/python_tool.py +13 -0
  55. quantalogic/tools/rag_tool/__init__.py +2 -9
  56. quantalogic/tools/rag_tool/document_rag_sources_.py +728 -0
  57. quantalogic/tools/rag_tool/ocr_pdf_markdown.py +144 -0
  58. quantalogic/tools/replace_in_file_tool.py +1 -1
  59. quantalogic/tools/{search_definition_names.py → search_definition_names_tool.py} +2 -2
  60. quantalogic/tools/terminal_capture_tool.py +293 -0
  61. quantalogic/tools/tool.py +120 -22
  62. quantalogic/tools/utilities/__init__.py +2 -0
  63. quantalogic/tools/utilities/download_file_tool.py +3 -5
  64. quantalogic/tools/utilities/llm_tool.py +283 -0
  65. quantalogic/tools/utilities/selenium_tool.py +296 -0
  66. quantalogic/tools/utilities/vscode_tool.py +1 -1
  67. quantalogic/tools/web_navigation/__init__.py +5 -0
  68. quantalogic/tools/web_navigation/web_tool.py +145 -0
  69. quantalogic/tools/write_file_tool.py +72 -36
  70. quantalogic/utils/__init__.py +0 -1
  71. quantalogic/utils/test_python_interpreter.py +119 -0
  72. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/METADATA +7 -2
  73. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/RECORD +76 -35
  74. quantalogic/tools/rag_tool/document_metadata.py +0 -15
  75. quantalogic/tools/rag_tool/query_response.py +0 -20
  76. quantalogic/tools/rag_tool/rag_tool.py +0 -566
  77. quantalogic/tools/rag_tool/rag_tool_beta.py +0 -264
  78. quantalogic/utils/python_interpreter.py +0 -905
  79. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/LICENSE +0 -0
  80. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/WHEEL +0 -0
  81. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/entry_points.txt +0 -0
@@ -7,6 +7,10 @@ Expert ReAct AI Agent implementing OODA (Observe-Orient-Decide-Act) loop with co
7
7
  ### Input Protocol
8
8
  Task Format: <task>task_description</task>
9
9
 
10
+ ### FORBIDDEN
11
+ - Never install software or run installation commands (e.g. apt-get), and never install packages with npm, pnpm, pip, yarn, or any other package manager.
12
+ - Instead, provide the command so the user can run it in their own environment if they want to.
13
+
10
14
  ### Cognitive Framework
11
15
  1. 🔍 **OBSERVE**: Gather essential data
12
16
  2. 🧭 **ORIENT**: Analyze context briefly
@@ -75,8 +79,8 @@ Task Format: <task>task_description</task>
75
79
  - 🛠️ **Tools**: {{ tools }}
76
80
  - 🌐 **Environment**: {{ environment }}
77
81
 
78
- ### Execution Guidelines
79
- . 🎯 Focus on task objectives
82
+ ### Execution Guidelines
83
+ 1. 🎯 Focus on task objectives
80
84
  2. 📊 Use data-driven decisions
81
85
  3. 🔄 Optimize with feedback loops
82
86
  4. ⚡ Maximize efficiency via interpolation
@@ -7,7 +7,5 @@ Instructions:
7
7
  1. Select ONE tool per message
8
8
  2. You will receive the tool's output in the next user response
9
9
  3. Choose the most appropriate tool for each step
10
- 4. If it's not asked to write on files, don't use write_file tool
11
- 5. If files are written, then use tool to display the prepared download link
12
- 6. Give the final full answer using all the variables
13
- 7. Use task_complete tool to confirm task completion with the full content of the final answer
10
+ 4. Give the final full answer using all the variables
11
+ 5. Use task_complete tool to confirm task completion with the full content of the final answer
quantalogic/prompts.py CHANGED
@@ -1,20 +1,34 @@
1
1
  import os
2
2
  from pathlib import Path
3
+ from typing import Dict
3
4
 
4
5
  from jinja2 import Environment, FileSystemLoader
6
+ from loguru import logger
5
7
 
6
8
  from quantalogic.version import get_version
7
9
 
10
# Map agent modes to their system prompt templates (file names inside the
# package's ``prompts`` directory). ``system_prompt()`` uses
# ``system_prompt.j2`` for unknown modes and when a template fails to load.
SYSTEM_PROMPTS: Dict[str, str] = {
    "react": "system_prompt.j2",
    # NOTE(review): this release ships prompts/chat_system_prompt.j2; the
    # previous value "chat_prompt.j2" does not appear in the release's file
    # list, so "chat" mode silently fell back to the ReAct prompt (with only a
    # warning). Confirm against the packaged prompts/ directory.
    "chat": "chat_system_prompt.j2",
    "code": "code_system_prompt.j2",
    "code_enhanced": "code_2_system_prompt.j2",
    "legal": "legal_system_prompt.j2",
    "legal_enhanced": "legal_2_system_prompt.j2",
    "doc": "doc_system_prompt.j2",
    "default": "system_prompt.j2",  # Fallback template
}
8
21
 
9
- def system_prompt(tools: str, environment: str, expertise: str = ""):
22
+ def system_prompt(tools: str, environment: str, expertise: str = "", agent_mode: str = "react"):
10
23
  """System prompt for the ReAct chatbot with enhanced cognitive architecture.
11
24
 
12
- Uses a Jinja2 template from the prompts directory.
25
+ Uses a Jinja2 template from the prompts directory based on agent_mode.
13
26
 
14
27
  Args:
15
28
  tools: Available tools for the agent
16
29
  environment: Environment information
17
30
  expertise: Domain expertise information
31
+ agent_mode: Mode to determine which system prompt to use
18
32
 
19
33
  Returns:
20
34
  str: The rendered system prompt
@@ -26,8 +40,13 @@ def system_prompt(tools: str, environment: str, expertise: str = ""):
26
40
  template_dir = current_dir / 'prompts'
27
41
  env = Environment(loader=FileSystemLoader(template_dir))
28
42
 
29
- # Load the template
30
- template = env.get_template('system_prompt.j2')
43
+ # Get template name based on agent mode, fallback to default if not found
44
+ template_name = SYSTEM_PROMPTS.get(agent_mode, "system_prompt.j2")
45
+ try:
46
+ template = env.get_template(template_name)
47
+ except Exception as e:
48
+ logger.warning(f"Template {template_name} not found, using default")
49
+ template = env.get_template("system_prompt.j2")
31
50
 
32
51
  # Render the template with the provided variables
33
52
  return template.render(
@@ -0,0 +1,23 @@
1
# quantalogic/python_interpreter/__init__.py
"""Public API of the quantalogic AST-walking Python interpreter package.

Re-exports the interpreter class, its execution entry points, the control-flow
and error-signalling exceptions, the callable wrappers and the scope type.
"""
from .exceptions import BreakException, ContinueException, ReturnException, WrappedException, has_await
from .execution import AsyncExecutionResult, execute_async, interpret_ast, interpret_code
from .function_utils import AsyncFunction, Function, LambdaFunction
from .interpreter_core import ASTInterpreter
from .scope import Scope

__all__ = [
    # Interpreter and execution entry points
    'ASTInterpreter',
    'execute_async',
    'interpret_ast',
    'interpret_code',
    'AsyncExecutionResult',
    # Control-flow / error signalling
    'ReturnException',
    'BreakException',
    'ContinueException',
    'WrappedException',
    'has_await',
    # Callable wrappers
    'Function',
    'AsyncFunction',
    'LambdaFunction',
    # Scoping
    'Scope',
]
@@ -0,0 +1,63 @@
1
+ import ast
2
+ from typing import Any
3
+
4
+ from .interpreter_core import ASTInterpreter
5
+
6
async def visit_Assign(self: ASTInterpreter, node: ast.Assign, wrap_exceptions: bool = True) -> None:
    """Execute ``t1 = t2 = ... = value``.

    The right-hand side is evaluated once, then bound to each target from left
    to right. Subscript targets are stored directly with item assignment; every
    other kind of target is delegated to ``self.assign``.
    """
    rhs: Any = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    for tgt in node.targets:
        if not isinstance(tgt, ast.Subscript):
            await self.assign(tgt, rhs)
            continue
        container = await self.visit(tgt.value, wrap_exceptions=wrap_exceptions)
        index = await self.visit(tgt.slice, wrap_exceptions=wrap_exceptions)
        container[index] = rhs
15
+
16
async def visit_AugAssign(self: ASTInterpreter, node: ast.AugAssign, wrap_exceptions: bool = True) -> Any:
    """Execute an augmented assignment such as ``x += y`` or ``a[i] <<= n``.

    The operation uses the in-place functions from :mod:`operator`
    (``operator.iadd`` and friends). These call the left operand's in-place
    method (``__iadd__`` ...) when it defines one and fall back to the binary
    operator otherwise. So, as in CPython, ``lst += [x]`` mutates ``lst`` in
    place and other references to the same list see the change.

    For subscript targets the container and key expressions are evaluated
    exactly once. The new value is stored back with item assignment, the same
    way ``visit_Assign`` stores subscript targets. Name targets are read with
    ``self.get_variable``; other targets are read by visiting them. Both are
    written back through ``self.assign``.

    Returns:
        The new value bound to the target.

    Raises:
        Exception: If the operator is not a supported augmented operator.
    """
    import operator

    in_place_ops = {
        ast.Add: operator.iadd,
        ast.Sub: operator.isub,
        ast.Mult: operator.imul,
        ast.MatMult: operator.imatmul,
        ast.Div: operator.itruediv,
        ast.FloorDiv: operator.ifloordiv,
        ast.Mod: operator.imod,
        ast.Pow: operator.ipow,
        ast.BitAnd: operator.iand,
        ast.BitOr: operator.ior,
        ast.BitXor: operator.ixor,
        ast.LShift: operator.ilshift,
        ast.RShift: operator.irshift,
    }

    target = node.target
    if isinstance(target, ast.Subscript):
        # Evaluate container and key once; CPython does not re-evaluate them for the store.
        container = await self.visit(target.value, wrap_exceptions=wrap_exceptions)
        key = await self.visit(target.slice, wrap_exceptions=wrap_exceptions)
        current_val: Any = container[key]
    elif isinstance(target, ast.Name):
        current_val = self.get_variable(target.id)
    else:
        current_val = await self.visit(target, wrap_exceptions=wrap_exceptions)
    right_val: Any = await self.visit(node.value, wrap_exceptions=wrap_exceptions)

    op = node.op
    op_func = in_place_ops.get(type(op))
    if op_func is None:
        raise Exception("Unsupported augmented operator: " + str(op))
    result: Any = op_func(current_val, right_val)

    if isinstance(target, ast.Subscript):
        container[key] = result
    else:
        await self.assign(target, result)
    return result
51
+
52
async def visit_AnnAssign(self: ASTInterpreter, node: ast.AnnAssign, wrap_exceptions: bool = True) -> None:
    """Execute an annotated assignment ``target: annotation [= value]``.

    The value (if present) is evaluated first, then the annotation. For a plain
    name target the annotation is recorded in ``self.type_hints``.

    Binding follows CPython. A bare annotation without a value (``x: int``)
    binds nothing. Any explicit value is bound, even when it is ``None`` and
    even for non-simple targets such as ``obj.attr: int = None``.
    """
    value = await self.visit(node.value, wrap_exceptions=wrap_exceptions) if node.value is not None else None
    annotation = await self.visit(node.annotation, wrap_exceptions=True)
    if isinstance(node.target, ast.Name):
        self.type_hints[node.target.id] = annotation
    # Decide on the presence of the value *expression*, not on the evaluated
    # value: ``x: int`` must not create ``x``, while ``x: T = None`` must.
    if node.value is not None:
        await self.assign(node.target, value)
59
+
60
async def visit_NamedExpr(self: ASTInterpreter, node: ast.NamedExpr, wrap_exceptions: bool = True) -> Any:
    """Evaluate a walrus expression ``(target := value)``.

    The value is evaluated, bound to the target through ``self.assign``, and
    returned as the value of the whole expression.
    """
    bound = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    await self.assign(node.target, bound)
    return bound
@@ -0,0 +1,20 @@
1
+ import ast
2
+ from typing import Any
3
+
4
+ from .exceptions import WrappedException
5
+ from .interpreter_core import ASTInterpreter
6
+
7
async def visit_Module(self: ASTInterpreter, node: ast.Module, wrap_exceptions: bool = True) -> Any:
    """Run the module's top-level statements in order.

    Every statement is visited with ``wrap_exceptions=True`` regardless of the
    argument. Returns the value of the last statement, or None for an empty
    module.
    """
    outcome = None
    for statement in node.body:
        outcome = await self.visit(statement, wrap_exceptions=True)
    return outcome
12
+
13
async def visit_Expr(self: ASTInterpreter, node: ast.Expr, wrap_exceptions: bool = True) -> Any:
    """Evaluate an expression statement and return the expression's value."""
    inner = node.value
    return await self.visit(inner, wrap_exceptions=wrap_exceptions)
15
+
16
async def visit_Pass(self: ASTInterpreter, node: ast.Pass, wrap_exceptions: bool = True) -> None:
    """Execute ``pass``: a no-op that evaluates to None."""
    return
18
+
19
async def visit_TypeIgnore(self: ASTInterpreter, node: ast.TypeIgnore, wrap_exceptions: bool = True) -> None:
    """Ignore a ``# type: ignore`` marker node; it has no runtime effect."""
    return None
@@ -0,0 +1,22 @@
1
+ import ast
2
+ from typing import Any, Dict, List
3
+
4
+ from .function_utils import Function
5
+ from .interpreter_core import ASTInterpreter
6
+
7
async def visit_ClassDef(self: ASTInterpreter, node: ast.ClassDef, wrap_exceptions: bool = True) -> Any:
    """Execute a ``class`` statement and bind the class in the enclosing frame.

    Evaluation order follows CPython:

    1. decorator expressions, top to bottom;
    2. base classes and class keywords (``metaclass=...``, ``**mapping`` and
       any other keywords, which are forwarded to the metaclass);
    3. the class body, in a fresh scope frame that is always popped;
    4. class creation via the metaclass (``type`` unless ``metaclass=`` was
       given), from the body's names minus ``__builtins__``;
    5. decorators applied bottom-up; the final object is bound to the class
       name in the enclosing frame and returned.

    Interpreter ``Function`` objects found in the class namespace get their
    ``defining_class`` set to the newly created class.
    """
    import inspect

    decorators = [await self.visit(dec, wrap_exceptions=True) for dec in node.decorator_list]
    # Evaluate bases/keywords *before* pushing the class frame, so a failing
    # lookup cannot leave that frame on the stack.
    bases = [await self.visit(base, wrap_exceptions=True) for base in node.bases]
    class_kwargs: Dict[str, Any] = {}
    for kw in node.keywords:
        kw_value = await self.visit(kw.value, wrap_exceptions=True)
        if kw.arg is None:  # ``class C(**mapping)``
            class_kwargs.update(kw_value)
        else:
            class_kwargs[kw.arg] = kw_value
    metaclass = class_kwargs.pop("metaclass", type)

    self.env_stack.append({})
    try:
        for stmt in node.body:
            await self.visit(stmt, wrap_exceptions=True)
        class_dict = {k: v for k, v in self.env_stack[-1].items() if k not in ["__builtins__"]}
    finally:
        self.env_stack.pop()

    cls = metaclass(node.name, tuple(bases), class_dict, **class_kwargs)
    for value in class_dict.values():
        if isinstance(value, Function):
            value.defining_class = cls

    for decorator in reversed(decorators):
        decorated = decorator(cls)
        # NOTE(review): interpreter-defined decorators may return an awaitable
        # when called; await it so the bound name gets the real result.
        if inspect.isawaitable(decorated):
            decorated = await decorated
        cls = decorated

    self.env_stack[-1][node.name] = cls
    return cls
@@ -0,0 +1,172 @@
1
+ import ast
2
+ from typing import Any, Dict, List
3
+
4
+ from .interpreter_core import ASTInterpreter
5
+ from .exceptions import WrappedException
6
+
7
def _not_iterable_error(self: ASTInterpreter, node: ast.AST, iterable: Any, exc: TypeError) -> Any:
    """Build the WrappedException reported when a comprehension ``for`` clause gets a non-iterable.

    Position and source line are taken from the comprehension node ``node``.
    """
    lineno = getattr(node, "lineno", 1)
    col = getattr(node, "col_offset", 0)
    context_line = self.source_lines[lineno - 1] if self.source_lines and lineno <= len(self.source_lines) else ""
    return WrappedException(f"Object {iterable} is not iterable", exc, lineno, col, context_line)


async def _iterate_items(self: ASTInterpreter, node: ast.AST, iterable: Any):
    """Yield the items of a sync or async iterable.

    Only a failure to obtain an iterator is reported as "not iterable". A
    TypeError raised later, while items are produced or while the
    comprehension body is evaluated, propagates unchanged.
    """
    if hasattr(iterable, '__aiter__'):
        async for item in iterable:
            yield item
        return
    try:
        iterator = iter(iterable)
    except TypeError as e:
        raise _not_iterable_error(self, node, iterable, e) from e
    for item in iterator:
        yield item


def _comprehension_values(self: ASTInterpreter, node: ast.AST, wrap_exceptions: bool, produce) -> Any:
    """Drive a comprehension's ``for``/``if`` clauses; return an async generator of ``produce()`` results.

    Each ``for`` clause binds its target in a new frame copied from the frame
    above it. Those frames live in a private list. They are pushed onto
    ``self.env_stack`` only while this comprehension evaluates one of its own
    sub-expressions, and are removed again afterwards, also on error. So:

    * the interpreter's scope stack is restored when an exception escapes;
    * code that consumes a lazily iterated generator expression never runs
      inside (and never loses writes to) the comprehension's frames.

    ``if`` clauses are evaluated left to right and stop at the first false one.

    Args:
        node: A ListComp/SetComp/DictComp/GeneratorExp node (uses ``.generators``).
        wrap_exceptions: Forwarded when evaluating the ``for`` iterables.
        produce: Zero-argument coroutine function evaluated once per accepted
            binding; its results are what the returned generator yields.
    """
    frames: List[Dict[str, Any]] = []

    async def in_scope(thunk):
        # Expose our frames for the duration of one evaluation only.
        depth = len(self.env_stack)
        self.env_stack.extend(frames)
        try:
            return await thunk()
        finally:
            del self.env_stack[depth:]

    async def conditions_hold(comp):
        for if_clause in comp.ifs:
            # The lambda is awaited immediately, so late binding is not an issue.
            if not await in_scope(lambda: self.visit(if_clause, wrap_exceptions=True)):
                return False
        return True

    async def rec(gen_idx: int):
        if gen_idx == len(node.generators):
            yield await in_scope(produce)
            return
        comp = node.generators[gen_idx]
        iterable = await in_scope(lambda: self.visit(comp.iter, wrap_exceptions=wrap_exceptions))
        async for item in _iterate_items(self, node, iterable):
            outer = frames[-1] if frames else self.env_stack[-1]
            frames.append(outer.copy())
            try:
                await in_scope(lambda: self.assign(comp.target, item))
                if await conditions_hold(comp):
                    async for value in rec(gen_idx + 1):
                        yield value
            finally:
                frames.pop()

    return rec(0)


async def visit_ListComp(self: ASTInterpreter, node: ast.ListComp, wrap_exceptions: bool = True) -> List[Any]:
    """Evaluate ``[elt for ... if ...]`` eagerly and return the resulting list."""
    async def produce():
        return await self.visit(node.elt, wrap_exceptions=wrap_exceptions)

    return [value async for value in _comprehension_values(self, node, wrap_exceptions, produce)]


async def visit_DictComp(self: ASTInterpreter, node: ast.DictComp, wrap_exceptions: bool = True) -> Dict[Any, Any]:
    """Evaluate ``{key: value for ... if ...}`` eagerly; the key is evaluated before the value."""
    async def produce():
        key = await self.visit(node.key, wrap_exceptions=True)
        val = await self.visit(node.value, wrap_exceptions=True)
        return key, val

    result: Dict[Any, Any] = {}
    async for key, val in _comprehension_values(self, node, wrap_exceptions, produce):
        result[key] = val
    return result


async def visit_SetComp(self: ASTInterpreter, node: ast.SetComp, wrap_exceptions: bool = True) -> set:
    """Evaluate ``{elt for ... if ...}`` eagerly and return the resulting set."""
    async def produce():
        return await self.visit(node.elt, wrap_exceptions=True)

    return {value async for value in _comprehension_values(self, node, wrap_exceptions, produce)}


async def visit_GeneratorExp(self: ASTInterpreter, node: ast.GeneratorExp, wrap_exceptions: bool = True) -> Any:
    """Return a lazy async generator yielding ``elt`` for each accepted binding.

    Nothing is evaluated until the generator is iterated. While the generator
    is suspended between items, none of its binding frames are on
    ``self.env_stack``.
    """
    async def produce():
        return await self.visit(node.elt, wrap_exceptions=True)

    return _comprehension_values(self, node, wrap_exceptions, produce)
@@ -0,0 +1,59 @@
1
+ import ast
2
+ from typing import Any
3
+
4
+ from .exceptions import ReturnException
5
+ from .interpreter_core import ASTInterpreter
6
+
7
async def visit_With(self: ASTInterpreter, node: ast.With, wrap_exceptions: bool = True) -> Any:
    """Execute a ``with`` statement following CPython's context-manager protocol.

    Multiple items behave like nested ``with`` blocks:

    * Managers are entered in order. If evaluating or entering a later item,
      or binding an ``as`` target, fails, the managers already entered are
      still exited with that exception.
    * On exit, managers are unwound innermost-first. An in-flight exception is
      passed to each ``__exit__`` until one returns a true value. From then on
      it is suppressed and the remaining (outer) managers see
      ``(None, None, None)``.
    * An exception raised by an ``__exit__`` replaces whatever was in flight,
      including a pending ``return``.
    * A ``ReturnException`` (the interpreter's ``return`` signal) exits every
      manager with ``(None, None, None)`` and is then re-raised.

    Returns:
        The value of the last body statement that completed (None if none did).
    """
    result = None
    entered = []  # managers whose __enter__ succeeded, outermost first
    pending = None  # exception currently propagating out of the statement
    returning = None  # ReturnException to re-raise after unwinding
    try:
        for item in node.items:
            ctx = await self.visit(item.context_expr, wrap_exceptions=wrap_exceptions)
            val = ctx.__enter__()
            # Register before binding: a failing ``as`` target must still exit this manager.
            entered.append(ctx)
            if item.optional_vars:
                await self.assign(item.optional_vars, val)
        for stmt in node.body:
            result = await self.visit(stmt, wrap_exceptions=wrap_exceptions)
    except ReturnException as ret:
        returning = ret
    except BaseException as e:
        pending = e

    for ctx in reversed(entered):
        try:
            if pending is None:
                ctx.__exit__(None, None, None)
            elif ctx.__exit__(type(pending), pending, pending.__traceback__):
                pending = None  # suppressed: outer managers see a clean exit
        except BaseException as exit_error:
            pending = exit_error
            returning = None

    if pending is not None:
        raise pending
    if returning is not None:
        raise returning
    return result
33
+
34
async def visit_AsyncWith(self: ASTInterpreter, node: ast.AsyncWith, wrap_exceptions: bool = True) -> Any:
    """Execute an ``async with`` statement following CPython's async context-manager protocol.

    Multiple items behave like nested ``async with`` blocks:

    * Managers are entered in order with ``await __aenter__()``. If
      evaluating or entering a later item, or binding an ``as`` target, fails,
      the managers already entered are still exited with that exception.
    * On exit, managers are unwound innermost-first with ``await
      __aexit__(...)``. An in-flight exception is passed to each until one
      returns a true value; after that it is suppressed and outer managers see
      ``(None, None, None)``.
    * An exception raised by an ``__aexit__`` replaces whatever was in
      flight, including a pending ``return``.
    * A ``ReturnException`` (the interpreter's ``return`` signal) exits every
      manager with ``(None, None, None)`` and is then re-raised.

    Returns:
        The value of the last body statement that completed (None if none did).
    """
    result = None
    entered = []  # managers whose __aenter__ succeeded, outermost first
    pending = None  # exception currently propagating out of the statement
    returning = None  # ReturnException to re-raise after unwinding
    try:
        for item in node.items:
            ctx = await self.visit(item.context_expr, wrap_exceptions=wrap_exceptions)
            val = await ctx.__aenter__()
            # Register before binding: a failing ``as`` target must still exit this manager.
            entered.append(ctx)
            if item.optional_vars:
                await self.assign(item.optional_vars, val)
        for stmt in node.body:
            result = await self.visit(stmt, wrap_exceptions=wrap_exceptions)
    except ReturnException as ret:
        returning = ret
    except BaseException as e:
        pending = e

    for ctx in reversed(entered):
        try:
            if pending is None:
                await ctx.__aexit__(None, None, None)
            elif await ctx.__aexit__(type(pending), pending, pending.__traceback__):
                pending = None  # suppressed: outer managers see a clean exit
        except BaseException as exit_error:
            pending = exit_error
            returning = None

    if pending is not None:
        raise pending
    if returning is not None:
        raise returning
    return result
@@ -0,0 +1,88 @@
1
+ import ast
2
+ from typing import Any
3
+
4
+ from .exceptions import BreakException, ContinueException, ReturnException
5
+ from .interpreter_core import ASTInterpreter
6
+
7
async def visit_If(self: ASTInterpreter, node: ast.If, wrap_exceptions: bool = True) -> Any:
    """Execute an ``if``/``else`` statement.

    The test is evaluated once. The chosen branch (``body`` when truthy,
    ``orelse`` otherwise) runs in order, and the value of its last statement is
    returned. An empty branch yields None.
    """
    condition = await self.visit(node.test, wrap_exceptions=wrap_exceptions)
    chosen = node.body if condition else node.orelse
    last = None
    for statement in chosen:
        last = await self.visit(statement, wrap_exceptions=wrap_exceptions)
    return last
18
+
19
async def visit_While(self: ASTInterpreter, node: ast.While, wrap_exceptions: bool = True) -> None:
    """Execute a ``while`` loop with ``break``, ``continue`` and ``else`` support.

    ``BreakException`` ends the loop and ``ContinueException`` re-tests the
    condition. As in CPython (and consistent with ``visit_For``), the ``else``
    block runs only when the loop ends because its condition became false, not
    when it is left via ``break``.
    """
    # Python's own while/else gives exactly the required "no else after break" rule.
    while await self.visit(node.test, wrap_exceptions=wrap_exceptions):
        try:
            for stmt in node.body:
                await self.visit(stmt, wrap_exceptions=wrap_exceptions)
        except BreakException:
            break
        except ContinueException:
            continue
    else:
        for stmt in node.orelse:
            await self.visit(stmt, wrap_exceptions=wrap_exceptions)
30
+
31
async def visit_For(self: ASTInterpreter, node: ast.For, wrap_exceptions: bool = True) -> None:
    """Execute ``for target in iterable: ... [else: ...]``.

    Async iterables (those with ``__aiter__``) are consumed with ``async for``;
    anything else with a plain ``for``. ``BreakException`` ends the loop and
    skips ``else``. ``ContinueException`` moves on to the next item. ``else``
    runs when the iterable is exhausted.
    """
    iterable: Any = await self.visit(node.iter, wrap_exceptions=wrap_exceptions)

    async def run_iteration(element: Any) -> bool:
        # Bind the loop target, run the body; report whether ``break`` happened.
        await self.assign(node.target, element)
        try:
            for statement in node.body:
                await self.visit(statement, wrap_exceptions=wrap_exceptions)
        except BreakException:
            return True
        except ContinueException:
            pass
        return False

    if hasattr(iterable, '__aiter__'):
        async for element in iterable:
            if await run_iteration(element):
                return
    else:
        for element in iterable:
            if await run_iteration(element):
                return

    for statement in node.orelse:
        await self.visit(statement, wrap_exceptions=wrap_exceptions)
59
+
60
async def visit_AsyncFor(self: ASTInterpreter, node: ast.AsyncFor, wrap_exceptions: bool = True) -> None:
    """Execute ``async for target in iterable: ... [else: ...]``.

    ``BreakException`` ends the loop and skips ``else``. ``ContinueException``
    moves on to the next item. ``else`` runs when the async iterable is
    exhausted.
    """
    source = await self.visit(node.iter, wrap_exceptions=wrap_exceptions)
    async for element in source:
        await self.assign(node.target, element)
        try:
            for statement in node.body:
                await self.visit(statement, wrap_exceptions=wrap_exceptions)
        except BreakException:
            break
        except ContinueException:
            continue
    else:
        for statement in node.orelse:
            await self.visit(statement, wrap_exceptions=wrap_exceptions)
76
+
77
async def visit_Break(self: ASTInterpreter, node: ast.Break, wrap_exceptions: bool = True) -> None:
    """Signal ``break`` to the enclosing loop visitor by raising BreakException."""
    signal = BreakException()
    raise signal
79
+
80
async def visit_Continue(self: ASTInterpreter, node: ast.Continue, wrap_exceptions: bool = True) -> None:
    """Signal ``continue`` to the enclosing loop visitor by raising ContinueException."""
    signal = ContinueException()
    raise signal
82
+
83
async def visit_Return(self: ASTInterpreter, node: ast.Return, wrap_exceptions: bool = True) -> None:
    """Signal ``return [value]`` by raising ReturnException carrying the value (None if omitted)."""
    if node.value is None:
        raise ReturnException(None)
    returned: Any = await self.visit(node.value, wrap_exceptions=wrap_exceptions)
    raise ReturnException(returned)
86
+
87
async def visit_IfExp(self: ASTInterpreter, node: ast.IfExp, wrap_exceptions: bool = True) -> Any:
    """Evaluate ``body if test else orelse``; only the selected branch is evaluated."""
    if await self.visit(node.test, wrap_exceptions=wrap_exceptions):
        return await self.visit(node.body, wrap_exceptions=wrap_exceptions)
    return await self.visit(node.orelse, wrap_exceptions=wrap_exceptions)
@@ -0,0 +1,109 @@
1
+ import ast
2
+ from typing import Any, Optional, Tuple
3
+
4
+ from .exceptions import BaseExceptionGroup, ReturnException, WrappedException
5
+ from .interpreter_core import ASTInterpreter
6
+
7
async def visit_Try(self: ASTInterpreter, node: ast.Try, wrap_exceptions: bool = True) -> Any:
    """Execute a ``try`` statement with ``except``, ``else`` and ``finally`` clauses.

    * The body is visited with ``wrap_exceptions=False``.
    * ``ReturnException`` is never matched by handlers. It propagates after
      ``finally`` has run.
    * Any other ``Exception`` is first unwrapped from ``WrappedException``.
      It is then tested against the handlers in order, with each handler's
      type resolved by ``self._resolve_exception_type``.
      - The first matching handler binds its ``as`` name to the unwrapped
        exception and runs. If its last statement produced a non-None value,
        that value becomes the result.
      - When no handler matches, the caught exception is re-raised unchanged.
    * ``else`` runs only when the body raised nothing, and its last value
      becomes the result.
    * ``finally`` always runs; its values are discarded.

    Returns:
        The last relevant statement value, as described above.
    """
    outcome: Any = None
    try:
        for statement in node.body:
            outcome = await self.visit(statement, wrap_exceptions=False)
    except ReturnException:
        raise
    except Exception as caught:
        actual = caught.original_exception if isinstance(caught, WrappedException) else caught
        for handler in node.handlers:
            handled_type = await self._resolve_exception_type(handler.type)
            if not handled_type or not isinstance(actual, handled_type):
                continue
            if handler.name:
                self.set_variable(handler.name, actual)
            last_value = None
            for statement in handler.body:
                last_value = await self.visit(statement, wrap_exceptions=True)
            if last_value is not None:
                outcome = last_value
            break
        else:
            raise
    else:
        for statement in node.orelse:
            outcome = await self.visit(statement, wrap_exceptions=True)
    finally:
        for statement in node.finalbody:
            await self.visit(statement, wrap_exceptions=True)
    return outcome
39
+
40
+ async def visit_TryStar(self: ASTInterpreter, node: ast.TryStar, wrap_exceptions: bool = True) -> Any:
41
+ result: Any = None
42
+ exc_info: Optional[Tuple] = None
43
+
44
+ try:
45
+ for stmt in node.body:
46
+ result = await self.visit(stmt, wrap_exceptions=False)
47
+ except BaseException as e:
48
+ exc_info = (type(e), e, e.__traceback__)
49
+ handled = False
50
+ if isinstance(e, BaseExceptionGroup):
51
+ remaining_exceptions = []
52
+ for handler in node.handlers:
53
+ if handler.type is None:
54
+ exc_type = BaseException
55
+ elif isinstance(handler.type, ast.Name):
56
+ exc_type = self.get_variable(handler.type.id)
57
+ else:
58
+ exc_type = await self.visit(handler.type, wrap_exceptions=True)
59
+ matching_exceptions = [ex for ex in e.exceptions if isinstance(ex, exc_type)]
60
+ if matching_exceptions:
61
+ if handler.name:
62
+ self.set_variable(handler.name, BaseExceptionGroup("", matching_exceptions))
63
+ for stmt in handler.body:
64
+ result = await self.visit(stmt, wrap_exceptions=True)
65
+ handled = True
66
+ remaining_exceptions.extend([ex for ex in e.exceptions if not isinstance(ex, exc_type)])
67
+ if remaining_exceptions and not handled:
68
+ raise BaseExceptionGroup("Uncaught exceptions", remaining_exceptions)
69
+ if handled:
70
+ exc_info = None
71
+ else:
72
+ for handler in node.handlers:
73
+ if handler.type is None:
74
+ exc_type = BaseException
75
+ elif isinstance(handler.type, ast.Name):
76
+ exc_type = self.get_variable(handler.type.id)
77
+ else:
78
+ exc_type = await self.visit(handler.type, wrap_exceptions=True)
79
+ if exc_info and issubclass(exc_info[0], exc_type):
80
+ if handler.name:
81
+ self.set_variable(handler.name, exc_info[1])
82
+ for stmt in handler.body:
83
+ result = await self.visit(stmt, wrap_exceptions=True)
84
+ exc_info = None
85
+ handled = True
86
+ break
87
+ if exc_info and not handled:
88
+ raise exc_info[1]
89
+ else:
90
+ for stmt in node.orelse:
91
+ result = await self.visit(stmt, wrap_exceptions=True)
92
+ finally:
93
+ for stmt in node.finalbody:
94
+ try:
95
+ await self.visit(stmt, wrap_exceptions=True)
96
+ except ReturnException:
97
+ raise
98
+ except Exception:
99
+ if exc_info:
100
+ raise exc_info[1]
101
+ raise
102
+
103
+ return result
104
+
105
async def visit_Raise(self: ASTInterpreter, node: ast.Raise, wrap_exceptions: bool = True) -> None:
    """Execute a ``raise`` statement, including ``raise X from Y``.

    The exception expression is evaluated and raised. When a ``from`` clause
    is present, its value becomes the explicit cause. ``from None`` suppresses
    the implicit context, exactly as in CPython.

    Raises:
        The evaluated exception. Non-exception values (other than None) raise
        the TypeError CPython uses for them. A bare ``raise``, or an exception
        expression evaluating to None, raises
        ``Exception("Raise with no exception specified")``; re-raising the
        active exception is not supported by this visitor.
    """
    exc = await self.visit(node.exc, wrap_exceptions=wrap_exceptions) if node.exc else None
    if exc is None:
        raise Exception("Raise with no exception specified")
    if node.cause is not None:
        cause = await self.visit(node.cause, wrap_exceptions=wrap_exceptions)
        raise exc from cause
    raise exc