quantalogic 0.59.3__py3-none-any.whl → 0.61.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. quantalogic/agent.py +268 -24
  2. quantalogic/agent_config.py +5 -5
  3. quantalogic/agent_factory.py +2 -2
  4. quantalogic/codeact/__init__.py +0 -0
  5. quantalogic/codeact/agent.py +499 -0
  6. quantalogic/codeact/cli.py +232 -0
  7. quantalogic/codeact/constants.py +9 -0
  8. quantalogic/codeact/events.py +78 -0
  9. quantalogic/codeact/llm_util.py +76 -0
  10. quantalogic/codeact/prompts/error_format.j2 +11 -0
  11. quantalogic/codeact/prompts/generate_action.j2 +26 -0
  12. quantalogic/codeact/prompts/generate_program.j2 +39 -0
  13. quantalogic/codeact/prompts/response_format.j2 +11 -0
  14. quantalogic/codeact/tools_manager.py +135 -0
  15. quantalogic/codeact/utils.py +135 -0
  16. quantalogic/coding_agent.py +2 -2
  17. quantalogic/create_custom_agent.py +26 -78
  18. quantalogic/prompts/chat_system_prompt.j2 +10 -7
  19. quantalogic/prompts/code_2_system_prompt.j2 +190 -0
  20. quantalogic/prompts/code_system_prompt.j2 +142 -0
  21. quantalogic/prompts/doc_system_prompt.j2 +178 -0
  22. quantalogic/prompts/legal_2_system_prompt.j2 +218 -0
  23. quantalogic/prompts/legal_system_prompt.j2 +140 -0
  24. quantalogic/prompts/system_prompt.j2 +6 -2
  25. quantalogic/prompts/tools_prompt.j2 +2 -4
  26. quantalogic/prompts.py +23 -4
  27. quantalogic/python_interpreter/__init__.py +23 -0
  28. quantalogic/python_interpreter/assignment_visitors.py +63 -0
  29. quantalogic/python_interpreter/base_visitors.py +20 -0
  30. quantalogic/python_interpreter/class_visitors.py +22 -0
  31. quantalogic/python_interpreter/comprehension_visitors.py +172 -0
  32. quantalogic/python_interpreter/context_visitors.py +59 -0
  33. quantalogic/python_interpreter/control_flow_visitors.py +88 -0
  34. quantalogic/python_interpreter/exception_visitors.py +109 -0
  35. quantalogic/python_interpreter/exceptions.py +39 -0
  36. quantalogic/python_interpreter/execution.py +202 -0
  37. quantalogic/python_interpreter/function_utils.py +386 -0
  38. quantalogic/python_interpreter/function_visitors.py +209 -0
  39. quantalogic/python_interpreter/import_visitors.py +28 -0
  40. quantalogic/python_interpreter/interpreter_core.py +358 -0
  41. quantalogic/python_interpreter/literal_visitors.py +74 -0
  42. quantalogic/python_interpreter/misc_visitors.py +148 -0
  43. quantalogic/python_interpreter/operator_visitors.py +108 -0
  44. quantalogic/python_interpreter/scope.py +10 -0
  45. quantalogic/python_interpreter/visit_handlers.py +110 -0
  46. quantalogic/server/agent_server.py +1 -1
  47. quantalogic/tools/__init__.py +6 -3
  48. quantalogic/tools/action_gen.py +366 -0
  49. quantalogic/tools/duckduckgo_search_tool.py +1 -0
  50. quantalogic/tools/execute_bash_command_tool.py +114 -57
  51. quantalogic/tools/file_tracker_tool.py +49 -0
  52. quantalogic/tools/google_packages/google_news_tool.py +3 -0
  53. quantalogic/tools/image_generation/dalle_e.py +89 -137
  54. quantalogic/tools/python_tool.py +13 -0
  55. quantalogic/tools/rag_tool/__init__.py +2 -9
  56. quantalogic/tools/rag_tool/document_rag_sources_.py +728 -0
  57. quantalogic/tools/rag_tool/ocr_pdf_markdown.py +144 -0
  58. quantalogic/tools/replace_in_file_tool.py +1 -1
  59. quantalogic/tools/{search_definition_names.py → search_definition_names_tool.py} +2 -2
  60. quantalogic/tools/terminal_capture_tool.py +293 -0
  61. quantalogic/tools/tool.py +120 -22
  62. quantalogic/tools/utilities/__init__.py +2 -0
  63. quantalogic/tools/utilities/download_file_tool.py +3 -5
  64. quantalogic/tools/utilities/llm_tool.py +283 -0
  65. quantalogic/tools/utilities/selenium_tool.py +296 -0
  66. quantalogic/tools/utilities/vscode_tool.py +1 -1
  67. quantalogic/tools/web_navigation/__init__.py +5 -0
  68. quantalogic/tools/web_navigation/web_tool.py +145 -0
  69. quantalogic/tools/write_file_tool.py +72 -36
  70. quantalogic/utils/__init__.py +0 -1
  71. quantalogic/utils/test_python_interpreter.py +119 -0
  72. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/METADATA +7 -2
  73. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/RECORD +76 -35
  74. quantalogic/tools/rag_tool/document_metadata.py +0 -15
  75. quantalogic/tools/rag_tool/query_response.py +0 -20
  76. quantalogic/tools/rag_tool/rag_tool.py +0 -566
  77. quantalogic/tools/rag_tool/rag_tool_beta.py +0 -264
  78. quantalogic/utils/python_interpreter.py +0 -905
  79. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/LICENSE +0 -0
  80. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/WHEEL +0 -0
  81. {quantalogic-0.59.3.dist-info → quantalogic-0.61.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,39 @@
1
+ import ast
2
+ from typing import Any, List
3
+
4
class ReturnException(Exception):
    """Control-flow signal carrying the value of an executed ``return``.

    The function-call wrappers catch this exception and hand ``value`` back
    to their caller instead of letting it propagate.
    """

    def __init__(self, value: Any) -> None:
        # Payload delivered to whoever catches the signal.
        self.value: Any = value
7
+
8
class BreakException(Exception):
    """Control-flow signal presumably used to model ``break`` in interpreted loops."""
10
+
11
class ContinueException(Exception):
    """Control-flow signal presumably used to model ``continue`` in interpreted loops."""
13
+
14
class BaseExceptionGroup(Exception):
    """Bundle several exceptions under one summary message.

    ``str()`` renders as ``"<message>: <exc1>, <exc2>, ..."``.

    NOTE(review): this name shadows the Python 3.11+ builtin
    ``BaseExceptionGroup`` in any module that imports it; the two types are
    unrelated.
    """

    def __init__(self, message: str, exceptions: List[Exception]):
        super().__init__(message)
        self.message = message
        self.exceptions = exceptions

    def __str__(self):
        details = ", ".join(map(str, self.exceptions))
        return f"{self.message}: {details}"
22
+
23
class WrappedException(Exception):
    """Pair an exception raised by interpreted code with its source position.

    ``args`` keeps the caller-supplied ``message``. The ``message`` attribute
    shown by ``__str__`` is taken from the original exception instead: its
    first argument when it has any, otherwise its string form.
    """

    def __init__(self, message: str, original_exception: Exception, lineno: int, col: int, context_line: str):
        super().__init__(message)
        self.original_exception: Exception = original_exception
        self.lineno: int = lineno
        self.col: int = col
        self.context_line: str = context_line
        first_args = original_exception.args
        # args[0] need not be a str (e.g. KeyError keys); the f-string below copes.
        self.message = first_args[0] if first_args else str(original_exception)

    def __str__(self):
        return f"Error line {self.lineno}, col {self.col}:\n{self.context_line}\nDescription: {self.message}"
34
+
35
def has_await(node: ast.AST) -> bool:
    """Return True when *node* or any node beneath it is an ``await`` expression."""
    return any(isinstance(descendant, ast.Await) for descendant in ast.walk(node))
@@ -0,0 +1,202 @@
1
+ import ast
2
+ import asyncio
3
+ import textwrap
4
+ import time
5
+ from dataclasses import dataclass
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ from .interpreter_core import ASTInterpreter
9
+ from .function_utils import Function, AsyncFunction
10
+ from .exceptions import WrappedException
11
+
12
@dataclass
class AsyncExecutionResult:
    """Outcome record returned by ``execute_async``.

    On success ``execute_async`` fills ``result`` and leaves ``error`` as None;
    on failure it sets ``result`` to None and puts a message in ``error``.
    """

    result: Any  # value produced by the program or its entry point
    error: Optional[str]  # formatted error message, or None on success
    execution_time: float  # elapsed seconds around the run
    local_variables: Optional[Dict[str, Any]] = None  # captured variables, if any
18
+
19
def optimize_ast(tree: ast.AST) -> ast.AST:
    """Perform constant folding and basic optimizations on the AST.

    Two bottom-up rewrites are applied:

    * A ``BinOp`` whose operands are both int/float constants is replaced by a
      single ``Constant`` for ``+``, ``-``, ``*``, and for ``/`` when the divisor
      is non-zero. Division by zero is left alone so the error still surfaces
      at run time.
    * An ``if`` statement whose test is a constant is replaced by the statements
      of the branch that would run. A constant-false test with no ``else``
      removes the statement entirely.

    Folded constants inherit the source location of the node they replace, so
    error reporting based on ``lineno``/``col_offset`` keeps working.

    Args:
        tree: AST to optimize, typically the ``ast.Module`` returned by
            ``ast.parse``. It is modified in place.

    Returns:
        The optimized tree.
    """
    class ConstantFolder(ast.NodeTransformer):
        def visit_BinOp(self, node):
            self.generic_visit(node)
            if not (isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant)):
                return node
            left, right = node.left.value, node.right.value
            if not (isinstance(left, (int, float)) and isinstance(right, (int, float))):
                return node
            if isinstance(node.op, ast.Add):
                folded = left + right
            elif isinstance(node.op, ast.Sub):
                folded = left - right
            elif isinstance(node.op, ast.Mult):
                folded = left * right
            elif isinstance(node.op, ast.Div) and right != 0:
                folded = left / right
            else:
                return node
            # Keep lineno/col_offset so later error messages can point at the source.
            return ast.copy_location(ast.Constant(value=folded), node)

        def visit_If(self, node):
            self.generic_visit(node)
            if not isinstance(node.test, ast.Constant):
                return node
            # Returning a list splices these statements into the parent's
            # statement list (NodeTransformer semantics); an empty list deletes
            # the ``if`` altogether. Nesting an ast.Module here would be invalid.
            return node.body if node.test.value else node.orelse

    optimized = ConstantFolder().visit(tree)
    if isinstance(optimized, list):
        # The root itself was a foldable ``if``: wrap its statements in a module.
        optimized = ast.Module(body=optimized, type_ignores=[])
    return optimized
47
+
48
class ControlledEventLoop:
    """Own a lazily created private event loop and tear it down on request.

    ``get_loop`` creates the private loop on first use and then keeps returning
    it. ``cleanup`` cancels its tasks, closes it and forgets it. ``run_task``
    awaits on the caller's currently running loop, not on the private one.
    """

    def __init__(self):
        self._loop = None
        self._created = False
        self._lock = asyncio.Lock()

    async def get_loop(self) -> asyncio.AbstractEventLoop:
        """Return the private loop, creating it on the first call."""
        async with self._lock:
            if self._loop is not None:
                return self._loop
            self._loop = asyncio.new_event_loop()
            self._created = True
            return self._loop

    async def cleanup(self):
        """Cancel pending tasks of the private loop, then close and drop it."""
        async with self._lock:
            private_loop = self._loop
            if not (self._created and private_loop and not private_loop.is_closed()):
                return
            for pending in asyncio.all_tasks(private_loop):
                pending.cancel()
            await asyncio.gather(*asyncio.all_tasks(private_loop), return_exceptions=True)
            private_loop.close()
            self._loop = None
            self._created = False

    async def run_task(self, coro, timeout: float) -> Any:
        """Await *coro*, raising asyncio.TimeoutError if it exceeds *timeout* seconds."""
        outcome = await asyncio.wait_for(coro, timeout=timeout)
        return outcome
74
+
75
async def execute_async(
    code: str,
    entry_point: Optional[str] = None,
    args: Optional[Tuple] = None,
    kwargs: Optional[Dict[str, Any]] = None,
    timeout: float = 30,
    allowed_modules: Optional[List[str]] = None,
    namespace: Optional[Dict[str, Any]] = None,
    max_memory_mb: int = 1024
) -> AsyncExecutionResult:
    """Interpret *code* with ASTInterpreter and report the outcome.

    The module body is executed exactly once, bounded by *timeout*.

    * With *entry_point*: the named top-level callable is then invoked with
      *args*/*kwargs*, each awaited step also bounded by *timeout*, and its
      value becomes the result.
    * Without *entry_point*: the module's own result is returned, together
      with the module-level variables whose names do not start with ``__``.

    Args:
        code: Source to run; common leading indentation is removed first.
        entry_point: Optional name of a callable defined by *code* to invoke.
        args: Positional arguments for *entry_point* (default: none).
        kwargs: Keyword arguments for *entry_point* (default: none).
        timeout: Time limit in seconds, applied separately to each awaited step.
        allowed_modules: Modules the interpreted code may import; defaults to
            ``['asyncio']``.
        namespace: Initial names for the interpreter. It is copied, and any
            ``asyncio`` entry is dropped.
        max_memory_mb: Memory limit forwarded to ASTInterpreter.

    Returns:
        An AsyncExecutionResult. Exceptions (including timeouts) raised while
        parsing or running are not propagated. They are reported through
        ``error``, with ``result`` set to None.
    """
    # Avoid a shared mutable default; None keeps the historical default list.
    if allowed_modules is None:
        allowed_modules = ['asyncio']
    # perf_counter is monotonic, unlike time.time(), so durations cannot go negative.
    start_time = time.perf_counter()
    event_loop_manager = ControlledEventLoop()

    try:
        ast_tree = optimize_ast(ast.parse(textwrap.dedent(code)))
        loop = await event_loop_manager.get_loop()

        # Copy the caller's namespace and strip direct asyncio access from it.
        safe_namespace = namespace.copy() if namespace else {}
        safe_namespace.pop('asyncio', None)

        interpreter = ASTInterpreter(
            allowed_modules=allowed_modules,
            restrict_os=True,
            namespace=safe_namespace,
            max_memory_mb=max_memory_mb,
            source=code  # Pass source code for better error context
        )
        interpreter.loop = loop

        # Run the module body exactly once. Its value is the result when no
        # entry point is requested; re-running it would repeat side effects.
        module_result = await event_loop_manager.run_task(
            interpreter.execute_async(ast_tree), timeout=timeout
        )

        if entry_point:
            func = interpreter.env_stack[0].get(entry_point)
            if not func:
                raise NameError(f"Function '{entry_point}' not found in the code")
            args = args or ()
            kwargs = kwargs or {}
            local_vars: Dict[str, Any] = {}
            if isinstance(func, AsyncFunction):
                # AsyncFunction.__call__ returns a (result, local_vars) tuple.
                execution_result = await event_loop_manager.run_task(func(*args, **kwargs), timeout=timeout)
                if isinstance(execution_result, tuple) and len(execution_result) == 2:
                    result, local_vars = execution_result
                else:
                    result = execution_result
            elif asyncio.iscoroutinefunction(func) or isinstance(func, Function):
                # Native coroutine functions return their value as-is (never
                # unpacked). Interpreted Function calls are coroutines as well.
                # Both are bounded by the timeout.
                result = await event_loop_manager.run_task(func(*args, **kwargs), timeout=timeout)
            else:
                result = func(*args, **kwargs)
                if asyncio.iscoroutine(result):
                    result = await event_loop_manager.run_task(result, timeout=timeout)
            if asyncio.iscoroutine(result):
                result = await event_loop_manager.run_task(result, timeout=timeout)
            filtered_local_vars = local_vars or {}
        else:
            result = module_result
            filtered_local_vars = {
                k: v for k, v in interpreter.env_stack[-1].items() if not k.startswith('__')
            }

        return AsyncExecutionResult(
            result=result,
            error=None,
            execution_time=time.perf_counter() - start_time,
            local_variables=filtered_local_vars
        )
    except asyncio.TimeoutError:
        return AsyncExecutionResult(
            result=None,
            error=f'TimeoutError: Execution exceeded {timeout} seconds',
            execution_time=time.perf_counter() - start_time
        )
    except WrappedException as e:
        return AsyncExecutionResult(
            result=None,
            error=str(e),
            execution_time=time.perf_counter() - start_time
        )
    except Exception as e:
        # Report the wrapped exception's type when one is attached.
        error_type = type(getattr(e, 'original_exception', e)).__name__
        error_msg = f'{error_type}: {str(e)}'
        if hasattr(e, 'lineno') and hasattr(e, 'col_offset'):
            error_msg += f' at line {e.lineno}, col {e.col_offset}'
        return AsyncExecutionResult(
            result=None,
            error=error_msg,
            execution_time=time.perf_counter() - start_time
        )
    finally:
        await event_loop_manager.cleanup()
172
+
173
def interpret_ast(ast_tree: ast.AST, allowed_modules: List[str], source: str = "", restrict_os: bool = False, namespace: Optional[Dict[str, Any]] = None) -> Any:
    """Synchronously interpret an already-parsed AST and return its value.

    The tree is constant-folded with optimize_ast. It is then visited by a
    fresh ASTInterpreter (``wrap_exceptions=True``) on a new event loop created
    for this call. That loop is the current event loop only for the duration
    of the call. Afterwards it is closed and the current loop is reset to
    ``None``, as ``asyncio.run`` does.

    Must not be called from a thread whose event loop is already running;
    ``run_until_complete`` raises RuntimeError in that case.

    Args:
        ast_tree: Parsed program; modified in place by optimize_ast.
        allowed_modules: Module names the program may import.
        source: Original source text, used for error context.
        restrict_os: Forwarded to ASTInterpreter.
        namespace: Initial names. It is copied, and any ``asyncio`` entry is removed.

    Returns:
        Whatever the interpreter produced for the tree.
    """
    ast_tree = optimize_ast(ast_tree)
    event_loop_manager = ControlledEventLoop()

    # Remove asyncio from namespace
    safe_namespace = namespace.copy() if namespace else {}
    safe_namespace.pop('asyncio', None)

    interpreter = ASTInterpreter(allowed_modules=allowed_modules, source=source, restrict_os=restrict_os, namespace=safe_namespace)

    async def run_interpreter():
        try:
            interpreter.loop = await event_loop_manager.get_loop()
            return await interpreter.visit(ast_tree, wrap_exceptions=True)
        finally:
            # The manager's loop is created but never run; close it so it does not leak.
            await event_loop_manager.cleanup()

    # Create the loop outside try so the finally clause never sees an unbound name.
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        interpreter.loop = loop
        return loop.run_until_complete(run_interpreter())
    finally:
        # Do not leave a closed loop installed as the thread's current loop.
        asyncio.set_event_loop(None)
        if not loop.is_closed():
            loop.close()
198
+
199
def interpret_code(source_code: str, allowed_modules: List[str], restrict_os: bool = False, namespace: Optional[Dict[str, Any]] = None) -> Any:
    """Parse *source_code* and run it synchronously through ``interpret_ast``.

    The text is dedented and stripped before parsing. That normalized text is
    also what is forwarded as ``source`` for error context.
    """
    normalized = textwrap.dedent(source_code).strip()
    return interpret_ast(
        ast.parse(normalized),
        allowed_modules,
        source=normalized,
        restrict_os=restrict_os,
        namespace=namespace,
    )
@@ -0,0 +1,386 @@
1
+ import ast
2
+ import asyncio
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from .interpreter_core import ASTInterpreter
6
+ from .exceptions import ReturnException, BreakException, ContinueException
7
+
8
class GeneratorWrapper:
    """Iterator facade over a generator that remembers when it has finished.

    Once the wrapped generator raises StopIteration, or ``close()`` is called,
    ``closed`` becomes True. From then on every ``next``/``send``/``throw``
    raises StopIteration immediately, without touching the generator.
    """

    def __init__(self, gen):
        self.gen = gen
        self.closed = False

    def __iter__(self):
        return self

    def _step(self, advance):
        # Shared guard for next/send/throw: refuse to resume a finished
        # generator, and record exhaustion the first time it is observed.
        if self.closed:
            raise StopIteration
        try:
            return advance()
        except StopIteration:
            self.closed = True
            raise

    def __next__(self):
        return self._step(lambda: next(self.gen))

    def send(self, value):
        return self._step(lambda: self.gen.send(value))

    def throw(self, exc):
        return self._step(lambda: self.gen.throw(exc))

    def close(self):
        self.closed = True
        self.gen.close()
46
+
47
class Function:
    """Callable wrapper for a ``def`` statement evaluated by the AST interpreter.

    Calling an instance binds the arguments into a fresh local frame, stacks
    it on top of the captured closure, spawns a child interpreter for that
    environment and runs the body. ``__call__`` is a coroutine, so callers must
    await it. Without an explicit ``return``, the call yields whatever the
    interpreter produced for the last executed body statement.

    Args:
        node: The FunctionDef being wrapped.
        closure: Environment frames visible at definition time. The list is
            copied; the frames themselves are shared.
        interpreter: Interpreter used to spawn a child interpreter per call.
        pos_kw_params: Names of positional-or-keyword parameters, in order.
        vararg_name: Name of the ``*args`` parameter, if any.
        kwonly_params: Names of keyword-only parameters.
        kwarg_name: Name of the ``**kwargs`` parameter, if any.
        pos_defaults: Default values keyed by positional parameter name.
        kw_defaults: Default values keyed by keyword-only parameter name.
    """

    def __init__(self, node: ast.FunctionDef, closure: List[Dict[str, Any]], interpreter: ASTInterpreter,
                 pos_kw_params: List[str], vararg_name: Optional[str], kwonly_params: List[str],
                 kwarg_name: Optional[str], pos_defaults: Dict[str, Any], kw_defaults: Dict[str, Any]) -> None:
        self.node: ast.FunctionDef = node
        self.closure: List[Dict[str, Any]] = closure[:]
        self.interpreter: ASTInterpreter = interpreter
        self.pos_kw_params = pos_kw_params
        self.vararg_name = vararg_name
        self.kwonly_params = kwonly_params
        self.kwarg_name = kwarg_name
        self.pos_defaults = pos_defaults
        self.kw_defaults = kw_defaults
        self.defining_class = None
        self.is_generator = self._has_own_yield(node)
        self.generator_state = None  # Added for generator protocol

    @staticmethod
    def _has_own_yield(node: ast.AST) -> bool:
        """Return True if *node*'s body yields, ignoring yields in nested scopes."""
        pending = list(getattr(node, 'body', []))
        while pending:
            child = pending.pop()
            if isinstance(child, (ast.Yield, ast.YieldFrom)):
                return True
            # A yield inside a nested def/lambda/class belongs to that scope.
            if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda, ast.ClassDef)):
                continue
            pending.extend(ast.iter_child_nodes(child))
        return False

    def _bind_arguments(self, args: tuple, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the local frame for one call, following Python's binding rules."""
        name = self.node.name
        num_pos = len(self.pos_kw_params)
        if len(args) > num_pos and not self.vararg_name:
            raise TypeError(f"Function '{name}' takes {num_pos} positional arguments but {len(args)} were given")

        local_frame: Dict[str, Any] = dict(zip(self.pos_kw_params, args))
        if self.vararg_name:
            # *args is always a tuple, matching CPython.
            local_frame[self.vararg_name] = tuple(args[num_pos:])

        extra_kwargs: Dict[str, Any] = {}
        for key, value in kwargs.items():
            if key in self.pos_kw_params or key in self.kwonly_params:
                if key in local_frame:
                    raise TypeError(f"Function '{name}' got multiple values for argument '{key}'")
                local_frame[key] = value
            elif self.kwarg_name:
                extra_kwargs[key] = value
            else:
                raise TypeError(f"Function '{name}' got an unexpected keyword argument '{key}'")
        if self.kwarg_name:
            # **kwargs is bound even when empty, so the body can always reference it.
            local_frame[self.kwarg_name] = extra_kwargs

        for param in self.pos_kw_params:
            if param not in local_frame and param in self.pos_defaults:
                local_frame[param] = self.pos_defaults[param]
        for param in self.kwonly_params:
            if param not in local_frame and param in self.kw_defaults:
                local_frame[param] = self.kw_defaults[param]

        missing_args = [param for param in self.pos_kw_params if param not in local_frame]
        missing_args += [param for param in self.kwonly_params if param not in local_frame]
        if missing_args:
            raise TypeError(f"Function '{name}' missing required arguments: {', '.join(missing_args)}")

        # Make the function visible under its own name (for recursion) unless a
        # parameter with that name was bound above.
        local_frame.setdefault(name, self)
        return local_frame

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Bind the arguments and run the body.

        Returns:
            A GeneratorWrapper for generator functions. Otherwise, the value of
            the executed ``return``, or the last body statement's value.

        Raises:
            TypeError: For too many positional arguments, duplicate or unexpected
                keyword arguments, or missing required arguments.
        """
        local_frame = self._bind_arguments(args, kwargs)
        if self.pos_kw_params and self.pos_kw_params[0] == 'self' and args:
            # Record the method being executed for the interpreter's use.
            local_frame['__current_method__'] = self

        new_env_stack: List[Dict[str, Any]] = self.closure[:]
        new_env_stack.append(local_frame)
        new_interp: ASTInterpreter = self.interpreter.spawn_from_env(new_env_stack)

        if self.defining_class and args:
            new_interp.current_class = self.defining_class
            new_interp.current_instance = args[0]

        if self.is_generator:
            # NOTE(review): generator() is an *async* generator, yet GeneratorWrapper
            # drives it with next()/send()/throw(), which async generators do not
            # implement. Confirm how callers consume the returned wrapper.
            # Only yields that form whole top-level body statements are recognised.
            async def generator():
                for body_stmt in self.node.body:
                    if isinstance(body_stmt, ast.Expr) and isinstance(body_stmt.value, ast.Yield):
                        value = await new_interp.visit(body_stmt.value, wrap_exceptions=True)
                        yield value
                    elif isinstance(body_stmt, ast.Expr) and isinstance(body_stmt.value, ast.YieldFrom):
                        sub_iterable = await new_interp.visit(body_stmt.value, wrap_exceptions=True)
                        for v in sub_iterable:
                            yield v
                    else:
                        await new_interp.visit(body_stmt, wrap_exceptions=True)
            return GeneratorWrapper(generator())

        last_value = None
        try:
            for stmt in self.node.body:
                last_value = await new_interp.visit(stmt, wrap_exceptions=True)
        except ReturnException as ret:
            return ret.value
        return last_value

    def _send_sync(self, gen, value):
        """Advance *gen*: plain next() when *value* is None, otherwise send(value)."""
        return next(gen) if value is None else gen.send(value)

    def _throw_sync(self, gen, exc):
        """Raise *exc* inside *gen* and return the next value it yields."""
        return gen.throw(exc)

    def __get__(self, instance: Any, owner: Any):
        """Descriptor hook: accessed on an instance, return a coroutine bound to it."""
        if instance is None:
            return self

        async def method(*args: Any, **kwargs: Any) -> Any:
            return await self(instance, *args, **kwargs)
        method.__self__ = instance
        return method
163
+
164
class AsyncFunction:
    """Callable wrapper for an ``async def`` evaluated by the AST interpreter.

    Awaiting a call returns a ``(value, local_variables)`` tuple:

    * ``value`` is the executed ``return`` value, or the last body statement's
      value.
    * ``local_variables`` is the call's local frame, minus names starting
      with ``__``.

    Args:
        node: The AsyncFunctionDef being wrapped.
        closure: Environment frames visible at definition time. The list is
            copied; the frames themselves are shared.
        interpreter: Interpreter used to spawn a child interpreter per call.
        pos_kw_params: Names of positional-or-keyword parameters, in order.
        vararg_name: Name of the ``*args`` parameter, if any.
        kwonly_params: Names of keyword-only parameters.
        kwarg_name: Name of the ``**kwargs`` parameter, if any.
        pos_defaults: Default values keyed by positional parameter name.
        kw_defaults: Default values keyed by keyword-only parameter name.
    """

    def __init__(self, node: ast.AsyncFunctionDef, closure: List[Dict[str, Any]], interpreter: ASTInterpreter,
                 pos_kw_params: List[str], vararg_name: Optional[str], kwonly_params: List[str],
                 kwarg_name: Optional[str], pos_defaults: Dict[str, Any], kw_defaults: Dict[str, Any]) -> None:
        self.node: ast.AsyncFunctionDef = node
        self.closure: List[Dict[str, Any]] = closure[:]
        self.interpreter: ASTInterpreter = interpreter
        self.pos_kw_params = pos_kw_params
        self.vararg_name = vararg_name
        self.kwonly_params = kwonly_params
        self.kwarg_name = kwarg_name
        self.pos_defaults = pos_defaults
        self.kw_defaults = kw_defaults

    def _bind_arguments(self, args: tuple, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the local frame for one call, following Python's binding rules."""
        name = self.node.name
        num_pos = len(self.pos_kw_params)
        if len(args) > num_pos and not self.vararg_name:
            raise TypeError(f"Async function '{name}' takes {num_pos} positional arguments but {len(args)} were given")

        local_frame: Dict[str, Any] = dict(zip(self.pos_kw_params, args))
        if self.vararg_name:
            # *args is always a tuple, matching CPython.
            local_frame[self.vararg_name] = tuple(args[num_pos:])

        extra_kwargs: Dict[str, Any] = {}
        for key, value in kwargs.items():
            if key in self.pos_kw_params or key in self.kwonly_params:
                if key in local_frame:
                    raise TypeError(f"Async function '{name}' got multiple values for argument '{key}'")
                local_frame[key] = value
            elif self.kwarg_name:
                extra_kwargs[key] = value
            else:
                raise TypeError(f"Async function '{name}' got an unexpected keyword argument '{key}'")
        if self.kwarg_name:
            # **kwargs is bound even when empty, so the body can always reference it.
            local_frame[self.kwarg_name] = extra_kwargs

        for param in self.pos_kw_params:
            if param not in local_frame and param in self.pos_defaults:
                local_frame[param] = self.pos_defaults[param]
        for param in self.kwonly_params:
            if param not in local_frame and param in self.kw_defaults:
                local_frame[param] = self.kw_defaults[param]

        missing_args = [param for param in self.pos_kw_params if param not in local_frame]
        missing_args += [param for param in self.kwonly_params if param not in local_frame]
        if missing_args:
            raise TypeError(f"Async function '{name}' missing required arguments: {', '.join(missing_args)}")

        # Make the function visible under its own name (for recursion) unless a
        # parameter with that name was bound above.
        local_frame.setdefault(name, self)
        return local_frame

    async def __call__(self, *args: Any, **kwargs: Any) -> tuple[Any, Dict[str, Any]]:
        """Bind the arguments, run the body and return ``(value, local_variables)``.

        Raises:
            TypeError: For too many positional arguments, duplicate or unexpected
                keyword arguments, or missing required arguments.
        """
        local_frame = self._bind_arguments(args, kwargs)
        new_env_stack: List[Dict[str, Any]] = self.closure[:]
        new_env_stack.append(local_frame)
        new_interp: ASTInterpreter = self.interpreter.spawn_from_env(new_env_stack)
        last_value = None
        try:
            for stmt in self.node.body:
                last_value = await new_interp.visit(stmt, wrap_exceptions=True)
            return last_value, {k: v for k, v in local_frame.items() if not k.startswith('__')}
        except ReturnException as ret:
            return ret.value, {k: v for k, v in local_frame.items() if not k.startswith('__')}
        finally:
            new_env_stack.pop()
231
+
232
class AsyncGeneratorFunction:
    """Callable wrapper for an ``async def`` containing ``yield`` (an async generator).

    Awaiting a call binds the arguments and returns a native async generator.
    Consume it with ``async for``, resume it with ``asend``/``athrow``, and stop
    it early with ``aclose()``. Only yields that form whole top-level body
    statements are produced; a ``return`` in the body ends the iteration.

    Args:
        node: The AsyncFunctionDef being wrapped.
        closure: Environment frames visible at definition time. The list is
            copied; the frames themselves are shared.
        interpreter: Interpreter used to spawn a child interpreter per call.
        pos_kw_params: Names of positional-or-keyword parameters, in order.
        vararg_name: Name of the ``*args`` parameter, if any.
        kwonly_params: Names of keyword-only parameters.
        kwarg_name: Name of the ``**kwargs`` parameter, if any.
        pos_defaults: Default values keyed by positional parameter name.
        kw_defaults: Default values keyed by keyword-only parameter name.
    """

    def __init__(self, node: ast.AsyncFunctionDef, closure: List[Dict[str, Any]], interpreter: ASTInterpreter,
                 pos_kw_params: List[str], vararg_name: Optional[str], kwonly_params: List[str],
                 kwarg_name: Optional[str], pos_defaults: Dict[str, Any], kw_defaults: Dict[str, Any]) -> None:
        self.node: ast.AsyncFunctionDef = node
        self.closure: List[Dict[str, Any]] = closure[:]
        self.interpreter: ASTInterpreter = interpreter
        self.pos_kw_params = pos_kw_params
        self.vararg_name = vararg_name
        self.kwonly_params = kwonly_params
        self.kwarg_name = kwarg_name
        self.pos_defaults = pos_defaults
        self.kw_defaults = kw_defaults
        # Kept for compatibility; generator state now lives in each native
        # async generator object, so this attribute is no longer updated.
        self.generator_state = None

    def _bind_arguments(self, args: tuple, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the local frame for one call, following Python's binding rules."""
        name = self.node.name
        num_pos = len(self.pos_kw_params)
        if len(args) > num_pos and not self.vararg_name:
            raise TypeError(f"Async generator '{name}' takes {num_pos} positional arguments but {len(args)} were given")

        local_frame: Dict[str, Any] = dict(zip(self.pos_kw_params, args))
        if self.vararg_name:
            # *args is always a tuple, matching CPython.
            local_frame[self.vararg_name] = tuple(args[num_pos:])

        extra_kwargs: Dict[str, Any] = {}
        for key, value in kwargs.items():
            if key in self.pos_kw_params or key in self.kwonly_params:
                if key in local_frame:
                    raise TypeError(f"Async generator '{name}' got multiple values for argument '{key}'")
                local_frame[key] = value
            elif self.kwarg_name:
                extra_kwargs[key] = value
            else:
                raise TypeError(f"Async generator '{name}' got an unexpected keyword argument '{key}'")
        if self.kwarg_name:
            # **kwargs is bound even when empty, so the body can always reference it.
            local_frame[self.kwarg_name] = extra_kwargs

        for param in self.pos_kw_params:
            if param not in local_frame and param in self.pos_defaults:
                local_frame[param] = self.pos_defaults[param]
        for param in self.kwonly_params:
            if param not in local_frame and param in self.kw_defaults:
                local_frame[param] = self.kw_defaults[param]

        missing_args = [param for param in self.pos_kw_params if param not in local_frame]
        missing_args += [param for param in self.kwonly_params if param not in local_frame]
        if missing_args:
            raise TypeError(f"Async generator '{name}' missing required arguments: {', '.join(missing_args)}")

        # Make the function visible under its own name (for recursion) unless a
        # parameter with that name was bound above.
        local_frame.setdefault(name, self)
        return local_frame

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Bind the arguments and return a new async generator over the body.

        Raises:
            TypeError: For too many positional arguments, duplicate or unexpected
                keyword arguments, or missing required arguments.
        """
        local_frame = self._bind_arguments(args, kwargs)
        new_env_stack: List[Dict[str, Any]] = self.closure[:]
        new_env_stack.append(local_frame)
        new_interp: ASTInterpreter = self.interpreter.spawn_from_env(new_env_stack)

        async def generator():
            # Finishing (returning) is how an async generator signals exhaustion.
            # Raising StopAsyncIteration here would surface as a RuntimeError (PEP 525).
            try:
                for stmt in self.node.body:
                    if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Yield):
                        value = await new_interp.visit(stmt.value, wrap_exceptions=True)
                        yield value
                    elif isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.YieldFrom):
                        sub_iterable = await new_interp.visit(stmt.value, wrap_exceptions=True)
                        async for v in sub_iterable:
                            yield v
                    else:
                        await new_interp.visit(stmt, wrap_exceptions=True)
            except ReturnException:
                return

        # Return the native async generator as-is. Its type does not allow
        # assigning send/throw/close attributes, and it already offers
        # asend/athrow/aclose.
        return generator()

    async def _send(self, gen, value):
        """Resume async generator *gen* with *value*; StopAsyncIteration propagates when it is exhausted."""
        return await gen.asend(value)

    async def _throw(self, gen, exc):
        """Raise *exc* inside async generator *gen*; StopAsyncIteration propagates when it finishes."""
        return await gen.athrow(exc)
328
+
329
class LambdaFunction:
    """Callable wrapper for a ``lambda`` expression evaluated by the AST interpreter.

    Awaiting a call binds the arguments into a fresh local frame on top of the
    captured closure and returns the value of the lambda's body expression.

    Args:
        node: The Lambda node being wrapped.
        closure: Environment frames visible at definition time. The list is
            copied; the frames themselves are shared.
        interpreter: Interpreter used to spawn a child interpreter per call.
        pos_kw_params: Names of positional-or-keyword parameters, in order.
        vararg_name: Name of the ``*args`` parameter, if any.
        kwonly_params: Names of keyword-only parameters.
        kwarg_name: Name of the ``**kwargs`` parameter, if any.
        pos_defaults: Default values keyed by positional parameter name.
        kw_defaults: Default values keyed by keyword-only parameter name.
    """

    def __init__(self, node: ast.Lambda, closure: List[Dict[str, Any]], interpreter: ASTInterpreter,
                 pos_kw_params: List[str], vararg_name: Optional[str], kwonly_params: List[str],
                 kwarg_name: Optional[str], pos_defaults: Dict[str, Any], kw_defaults: Dict[str, Any]) -> None:
        self.node: ast.Lambda = node
        self.closure: List[Dict[str, Any]] = closure[:]
        self.interpreter: ASTInterpreter = interpreter
        self.pos_kw_params = pos_kw_params
        self.vararg_name = vararg_name
        self.kwonly_params = kwonly_params
        self.kwarg_name = kwarg_name
        self.pos_defaults = pos_defaults
        self.kw_defaults = kw_defaults

    def _bind_arguments(self, args: tuple, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the local frame for one call, following Python's binding rules."""
        num_pos = len(self.pos_kw_params)
        if len(args) > num_pos and not self.vararg_name:
            raise TypeError(f"Lambda takes {num_pos} positional arguments but {len(args)} were given")

        local_frame: Dict[str, Any] = dict(zip(self.pos_kw_params, args))
        if self.vararg_name:
            # *args is always a tuple, matching CPython.
            local_frame[self.vararg_name] = tuple(args[num_pos:])

        extra_kwargs: Dict[str, Any] = {}
        for key, value in kwargs.items():
            if key in self.pos_kw_params or key in self.kwonly_params:
                if key in local_frame:
                    raise TypeError(f"Lambda got multiple values for argument '{key}'")
                local_frame[key] = value
            elif self.kwarg_name:
                extra_kwargs[key] = value
            else:
                raise TypeError(f"Lambda got an unexpected keyword argument '{key}'")
        if self.kwarg_name:
            # **kwargs is bound even when empty, so the body can always reference it.
            local_frame[self.kwarg_name] = extra_kwargs

        for param in self.pos_kw_params:
            if param not in local_frame and param in self.pos_defaults:
                local_frame[param] = self.pos_defaults[param]
        for param in self.kwonly_params:
            if param not in local_frame and param in self.kw_defaults:
                local_frame[param] = self.kw_defaults[param]

        missing_args = [param for param in self.pos_kw_params if param not in local_frame]
        missing_args += [param for param in self.kwonly_params if param not in local_frame]
        if missing_args:
            raise TypeError(f"Lambda missing required arguments: {', '.join(missing_args)}")
        return local_frame

    async def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Bind the arguments and return the value of the lambda body.

        Raises:
            TypeError: For too many positional arguments, duplicate or unexpected
                keyword arguments, or missing required arguments.
        """
        local_frame = self._bind_arguments(args, kwargs)
        new_env_stack: List[Dict[str, Any]] = self.closure[:]
        new_env_stack.append(local_frame)
        new_interp: ASTInterpreter = self.interpreter.spawn_from_env(new_env_stack)
        return await new_interp.visit(self.node.body, wrap_exceptions=True)