quantalogic 0.60.0__py3-none-any.whl → 0.61.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47):
  1. quantalogic/agent_config.py +5 -5
  2. quantalogic/agent_factory.py +2 -2
  3. quantalogic/codeact/__init__.py +0 -0
  4. quantalogic/codeact/agent.py +499 -0
  5. quantalogic/codeact/cli.py +232 -0
  6. quantalogic/codeact/constants.py +9 -0
  7. quantalogic/codeact/events.py +78 -0
  8. quantalogic/codeact/llm_util.py +76 -0
  9. quantalogic/codeact/prompts/error_format.j2 +11 -0
  10. quantalogic/codeact/prompts/generate_action.j2 +26 -0
  11. quantalogic/codeact/prompts/generate_program.j2 +39 -0
  12. quantalogic/codeact/prompts/response_format.j2 +11 -0
  13. quantalogic/codeact/tools_manager.py +135 -0
  14. quantalogic/codeact/utils.py +135 -0
  15. quantalogic/coding_agent.py +2 -2
  16. quantalogic/python_interpreter/__init__.py +23 -0
  17. quantalogic/python_interpreter/assignment_visitors.py +63 -0
  18. quantalogic/python_interpreter/base_visitors.py +20 -0
  19. quantalogic/python_interpreter/class_visitors.py +22 -0
  20. quantalogic/python_interpreter/comprehension_visitors.py +172 -0
  21. quantalogic/python_interpreter/context_visitors.py +59 -0
  22. quantalogic/python_interpreter/control_flow_visitors.py +88 -0
  23. quantalogic/python_interpreter/exception_visitors.py +109 -0
  24. quantalogic/python_interpreter/exceptions.py +39 -0
  25. quantalogic/python_interpreter/execution.py +202 -0
  26. quantalogic/python_interpreter/function_utils.py +386 -0
  27. quantalogic/python_interpreter/function_visitors.py +209 -0
  28. quantalogic/python_interpreter/import_visitors.py +28 -0
  29. quantalogic/python_interpreter/interpreter_core.py +358 -0
  30. quantalogic/python_interpreter/literal_visitors.py +74 -0
  31. quantalogic/python_interpreter/misc_visitors.py +148 -0
  32. quantalogic/python_interpreter/operator_visitors.py +108 -0
  33. quantalogic/python_interpreter/scope.py +10 -0
  34. quantalogic/python_interpreter/visit_handlers.py +110 -0
  35. quantalogic/tools/__init__.py +5 -4
  36. quantalogic/tools/action_gen.py +366 -0
  37. quantalogic/tools/python_tool.py +13 -0
  38. quantalogic/tools/{search_definition_names.py → search_definition_names_tool.py} +2 -2
  39. quantalogic/tools/tool.py +116 -22
  40. quantalogic/utils/__init__.py +0 -1
  41. quantalogic/utils/test_python_interpreter.py +119 -0
  42. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/METADATA +7 -2
  43. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/RECORD +46 -14
  44. quantalogic/utils/python_interpreter.py +0 -905
  45. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/LICENSE +0 -0
  46. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/WHEEL +0 -0
  47. {quantalogic-0.60.0.dist-info → quantalogic-0.61.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,108 @@
1
+ import ast
2
+ from typing import Any
3
+
4
+ from .interpreter_core import ASTInterpreter
5
+
6
async def visit_BinOp(self: "ASTInterpreter", node: ast.BinOp, wrap_exceptions: bool = True) -> Any:
    """Evaluate a binary operation such as ``a + b``, ``a @ b`` or ``a << b``.

    Both operands are evaluated left to right through ``self.visit``. The
    operator is then applied with native Python semantics, so type-specific
    overloads (set ``-``/``|``/``&``, reflected ``__radd__`` etc.) behave
    exactly as in CPython.

    Args:
        node: The ``ast.BinOp`` node to evaluate.
        wrap_exceptions: Forwarded to ``self.visit`` for both operands.

    Returns:
        The result of applying the operator to the evaluated operands.

    Raises:
        Exception: If ``node.op`` is not a recognised binary operator.
        Errors raised by the operation itself (``ZeroDivisionError``,
        ``TypeError``, ...) propagate unchanged.
    """
    left: Any = await self.visit(node.left, wrap_exceptions=wrap_exceptions)
    right: Any = await self.visit(node.right, wrap_exceptions=wrap_exceptions)
    op = node.op
    # Sets need no special-casing: ``-``, ``|`` and ``&`` already dispatch to
    # set difference / union / intersection through the native operators.
    if isinstance(op, ast.Add):
        return left + right
    elif isinstance(op, ast.Sub):
        return left - right
    elif isinstance(op, ast.Mult):
        return left * right
    elif isinstance(op, ast.MatMult):
        # Matrix multiplication operator ``@`` (PEP 465), e.g. for numpy arrays.
        return left @ right
    elif isinstance(op, ast.Div):
        return left / right
    elif isinstance(op, ast.FloorDiv):
        return left // right
    elif isinstance(op, ast.Mod):
        return left % right
    elif isinstance(op, ast.Pow):
        return left**right
    elif isinstance(op, ast.LShift):
        return left << right
    elif isinstance(op, ast.RShift):
        return left >> right
    elif isinstance(op, ast.BitOr):
        return left | right
    elif isinstance(op, ast.BitXor):
        return left ^ right
    elif isinstance(op, ast.BitAnd):
        return left & right
    else:
        raise Exception("Unsupported binary operator: " + str(op))
42
+
43
async def visit_UnaryOp(self: ASTInterpreter, node: ast.UnaryOp, wrap_exceptions: bool = True) -> Any:
    """Evaluate a unary operation: ``+x``, ``-x``, ``not x`` or ``~x``.

    The operand is evaluated first via ``self.visit``. The operator is then
    applied with native Python semantics.

    Raises:
        Exception: If ``node.op`` is not one of the four unary operator nodes.
    """
    value: Any = await self.visit(node.operand, wrap_exceptions=wrap_exceptions)
    unary = node.op
    # The four operator classes are mutually exclusive, so check order is irrelevant.
    if isinstance(unary, ast.Not):
        return not value
    if isinstance(unary, ast.USub):
        return -value
    if isinstance(unary, ast.UAdd):
        return +value
    if isinstance(unary, ast.Invert):
        return ~value
    raise Exception("Unsupported unary operator: " + str(unary))
56
+
57
async def visit_Compare(self: ASTInterpreter, node: ast.Compare, wrap_exceptions: bool = True) -> bool:
    """Evaluate a comparison, including chains such as ``a < b <= c``.

    Follows Python's chaining rules:
    - operands are evaluated left to right, each at most once;
    - evaluation stops at the first link that does not hold, so later
      comparators are never evaluated.

    Args:
        node: The ``ast.Compare`` node to evaluate.
        wrap_exceptions: Forwarded to ``self.visit`` for every operand.

    Returns:
        ``True`` when every link holds, otherwise ``False``.

    Raises:
        Exception: If an operator node is not a recognised comparison.
    """

    def link_holds(cmp: ast.cmpop, lhs: Any, rhs: Any) -> Any:
        # Returns the raw comparison result; the caller tests its truthiness.
        if isinstance(cmp, ast.Eq):
            return lhs == rhs
        if isinstance(cmp, ast.NotEq):
            return lhs != rhs
        if isinstance(cmp, ast.Lt):
            return lhs < rhs
        if isinstance(cmp, ast.LtE):
            return lhs <= rhs
        if isinstance(cmp, ast.Gt):
            return lhs > rhs
        if isinstance(cmp, ast.GtE):
            return lhs >= rhs
        if isinstance(cmp, ast.Is):
            return lhs is rhs
        if isinstance(cmp, ast.IsNot):
            return lhs is not rhs
        if isinstance(cmp, ast.In):
            return lhs in rhs
        if isinstance(cmp, ast.NotIn):
            return lhs not in rhs
        raise Exception("Unsupported comparison operator: " + str(cmp))

    current: Any = await self.visit(node.left, wrap_exceptions=wrap_exceptions)
    for cmp, operand_node in zip(node.ops, node.comparators):
        following: Any = await self.visit(operand_node, wrap_exceptions=wrap_exceptions)
        if not link_holds(cmp, current, following):
            return False
        # The right operand of this link is the left operand of the next one.
        current = following
    return True
95
+
96
async def visit_BoolOp(self: "ASTInterpreter", node: ast.BoolOp, wrap_exceptions: bool = True) -> Any:
    """Evaluate ``and`` / ``or`` with Python's short-circuit semantics.

    As in native Python, the result is the operand value that decided the
    outcome, not a coerced bool:
    - ``x or default`` yields ``default`` when ``x`` is falsy;
    - ``a and b`` yields ``a`` when ``a`` is falsy, otherwise ``b``.

    Operands after the deciding one are never evaluated.

    Args:
        node: The ``ast.BoolOp`` node to evaluate.
        wrap_exceptions: Forwarded to ``self.visit`` for each operand.

    Returns:
        The deciding operand's value. With no operands (never produced by the
        parser) this is ``True`` for ``and`` and ``False`` for ``or``.

    Raises:
        Exception: If ``node.op`` is neither ``ast.And`` nor ``ast.Or``.
    """
    if isinstance(node.op, ast.And):
        result: Any = True
        for value in node.values:
            result = await self.visit(value, wrap_exceptions=wrap_exceptions)
            if not result:
                # First falsy operand decides an ``and`` chain.
                return result
        return result
    elif isinstance(node.op, ast.Or):
        result = False
        for value in node.values:
            result = await self.visit(value, wrap_exceptions=wrap_exceptions)
            if result:
                # First truthy operand decides an ``or`` chain.
                return result
        return result
    else:
        raise Exception("Unsupported boolean operator: " + str(node.op))
@@ -0,0 +1,10 @@
1
# quantalogic/python_interpreter/scope.py
class Scope:
    """Context manager that pushes a fresh variable frame onto an environment stack.

    Entering the context appends a new empty dict to ``env_stack``. Leaving it
    pops the top frame again, also when the body raised.

    Args:
        env_stack: The mutable list of scope dicts to push onto and pop from.
    """

    def __init__(self, env_stack):
        self.env_stack = env_stack

    def __enter__(self):
        """Push a new empty frame and return it, so ``with Scope(s) as frame:`` works."""
        frame = {}
        self.env_stack.append(frame)
        return frame

    def __exit__(self, exc_type, exc_value, traceback):
        """Pop the top frame. Returns ``False`` so exceptions are never suppressed."""
        self.env_stack.pop()
        return False
@@ -0,0 +1,110 @@
1
+ from .base_visitors import (
2
+ visit_Module,
3
+ visit_Expr,
4
+ visit_Pass,
5
+ visit_TypeIgnore,
6
+ )
7
+
8
+ from .import_visitors import (
9
+ visit_Import,
10
+ visit_ImportFrom,
11
+ )
12
+
13
+ from .literal_visitors import (
14
+ visit_Constant,
15
+ visit_Name,
16
+ visit_List,
17
+ visit_Tuple,
18
+ visit_Dict,
19
+ visit_Set,
20
+ visit_Attribute,
21
+ visit_Subscript,
22
+ visit_Slice,
23
+ visit_Index,
24
+ visit_Starred,
25
+ visit_JoinedStr,
26
+ visit_FormattedValue,
27
+ )
28
+
29
+ from .operator_visitors import (
30
+ visit_BinOp,
31
+ visit_UnaryOp,
32
+ visit_Compare,
33
+ visit_BoolOp,
34
+ )
35
+
36
+ from .assignment_visitors import (
37
+ visit_Assign,
38
+ visit_AugAssign,
39
+ visit_AnnAssign,
40
+ visit_NamedExpr,
41
+ )
42
+
43
+ from .control_flow_visitors import (
44
+ visit_If,
45
+ visit_While,
46
+ visit_For,
47
+ visit_AsyncFor,
48
+ visit_Break,
49
+ visit_Continue,
50
+ visit_Return,
51
+ visit_IfExp,
52
+ )
53
+
54
+ from .function_visitors import (
55
+ visit_FunctionDef,
56
+ visit_AsyncFunctionDef,
57
+ visit_AsyncGeneratorDef,
58
+ visit_Call,
59
+ visit_Await,
60
+ visit_Lambda,
61
+ )
62
+
63
+ from .comprehension_visitors import (
64
+ visit_ListComp,
65
+ visit_DictComp,
66
+ visit_SetComp,
67
+ visit_GeneratorExp,
68
+ )
69
+
70
+ from .exception_visitors import (
71
+ visit_Try,
72
+ visit_TryStar,
73
+ visit_Raise,
74
+ )
75
+
76
+ from .class_visitors import (
77
+ visit_ClassDef,
78
+ )
79
+
80
+ from .context_visitors import (
81
+ visit_With,
82
+ visit_AsyncWith,
83
+ )
84
+
85
+ from .misc_visitors import (
86
+ visit_Global,
87
+ visit_Nonlocal,
88
+ visit_Delete,
89
+ visit_Assert,
90
+ visit_Yield,
91
+ visit_YieldFrom,
92
+ visit_Match,
93
+ _match_pattern,
94
+ )
95
+
96
# Public visitor API re-exported from the per-category visitor modules,
# grouped in the same order as the imports above.
__all__ = [
    # base_visitors
    "visit_Module", "visit_Expr", "visit_Pass", "visit_TypeIgnore",
    # import_visitors
    "visit_Import", "visit_ImportFrom",
    # literal_visitors
    "visit_Constant", "visit_Name", "visit_List", "visit_Tuple", "visit_Dict",
    "visit_Set", "visit_Attribute", "visit_Subscript", "visit_Slice", "visit_Index",
    "visit_Starred", "visit_JoinedStr", "visit_FormattedValue",
    # operator_visitors
    "visit_BinOp", "visit_UnaryOp", "visit_Compare", "visit_BoolOp",
    # assignment_visitors
    "visit_Assign", "visit_AugAssign", "visit_AnnAssign", "visit_NamedExpr",
    # control_flow_visitors
    "visit_If", "visit_While", "visit_For", "visit_AsyncFor", "visit_Break",
    "visit_Continue", "visit_Return", "visit_IfExp",
    # function_visitors
    "visit_FunctionDef", "visit_AsyncFunctionDef", "visit_AsyncGeneratorDef",
    "visit_Call", "visit_Await", "visit_Lambda",
    # comprehension_visitors
    "visit_ListComp", "visit_DictComp", "visit_SetComp", "visit_GeneratorExp",
    # exception_visitors
    "visit_Try", "visit_TryStar", "visit_Raise",
    # class_visitors
    "visit_ClassDef",
    # context_visitors
    "visit_With", "visit_AsyncWith",
    # misc_visitors
    "visit_Global", "visit_Nonlocal", "visit_Delete", "visit_Assert",
    "visit_Yield", "visit_YieldFrom", "visit_Match", "_match_pattern",
]
@@ -8,6 +8,7 @@ from .duckduckgo_search_tool import DuckDuckGoSearchTool
8
8
  from .edit_whole_content_tool import EditWholeContentTool
9
9
  from .elixir_tool import ElixirTool
10
10
  from .execute_bash_command_tool import ExecuteBashCommandTool
11
+ from .file_tracker_tool import FileTrackerTool
11
12
  from .grep_app_tool import GrepAppTool
12
13
  from .input_question_tool import InputQuestionTool
13
14
  from .jinja_tool import JinjaTool
@@ -22,17 +23,16 @@ from .read_file_tool import ReadFileTool
22
23
  from .read_html_tool import ReadHTMLTool
23
24
  from .replace_in_file_tool import ReplaceInFileTool
24
25
  from .ripgrep_tool import RipgrepTool
25
- from .safe_python_interpreter_tool import SafePythonInterpreterTool
26
- from .search_definition_names import SearchDefinitionNames
26
+ from .search_definition_names_tool import SearchDefinitionNamesTool
27
27
  from .sequence_tool import SequenceTool
28
28
  from .serpapi_search_tool import SerpApiSearchTool
29
29
  from .sql_query_tool import SQLQueryTool
30
30
  from .task_complete_tool import TaskCompleteTool
31
31
  from .tool import Tool, ToolArgument, create_tool
32
32
  from .unified_diff_tool import UnifiedDiffTool
33
+ from .utils.generate_database_report import generate_database_report
33
34
  from .wikipedia_search_tool import WikipediaSearchTool
34
35
  from .write_file_tool import WriteFileTool
35
- from .file_tracker_tool import FileTrackerTool
36
36
 
37
37
  # Define __all__ to control what gets imported with `from quantalogic.tools import *`
38
38
  __all__ = [
@@ -42,6 +42,7 @@ __all__ = [
42
42
  'EditWholeContentTool',
43
43
  'ElixirTool',
44
44
  'ExecuteBashCommandTool',
45
+ 'generate_database_report',
45
46
  'GrepAppTool',
46
47
  'InputQuestionTool',
47
48
  'JinjaTool',
@@ -57,7 +58,7 @@ __all__ = [
57
58
  'ReplaceInFileTool',
58
59
  'RipgrepTool',
59
60
  'SafePythonInterpreterTool',
60
- 'SearchDefinitionNames',
61
+ 'SearchDefinitionNamesTool',
61
62
  'SequenceTool',
62
63
  'SerpApiSearchTool',
63
64
  'SQLQueryTool',
@@ -0,0 +1,366 @@
1
+ import ast
2
+ import asyncio
3
+ from asyncio import TimeoutError
4
+ from contextlib import AsyncExitStack
5
+ from functools import partial
6
+ from typing import Callable, Dict, List
7
+
8
+ import litellm
9
+ import typer
10
+ from loguru import logger
11
+
12
+ from quantalogic.python_interpreter import execute_async
13
+ from quantalogic.tools.tool import Tool, ToolArgument
14
+
15
# Send DEBUG-and-above records to "action_gen.log" (relative to the current
# working directory), rotating the file once it reaches 10 MB.
# NOTE(review): this sink is added at import time, so merely importing this
# module creates/appends to action_gen.log — consider moving it into main().
logger.add("action_gen.log", level="DEBUG", rotation="10 MB")

# Typer application; CLI commands are registered on it with @app.command().
app = typer.Typer()
20
+
21
# Demo tools exposed to generated programs; each logs the start and end of its execution.
class AddTool(Tool):
    """Tool that adds two integers.

    Arguments ``a`` and ``b`` are coerced with ``int()``. The sum is returned
    as a decimal string.
    """

    def __init__(self):
        operands = [
            ToolArgument(name="a", arg_type="int", description="First number", required=True),
            ToolArgument(name="b", arg_type="int", description="Second number", required=True),
        ]
        super().__init__(
            name="add_tool",
            description="Adds two numbers and returns the sum.",
            return_type="int",
            arguments=operands,
        )

    async def async_execute(self, **kwargs) -> str:
        """Return ``str(int(a) + int(b))``, logging start and finish."""
        logger.info(f"Starting tool execution: {self.name}")
        first, second = kwargs["a"], kwargs["b"]
        logger.info(f"Adding {first} and {second}")
        total = str(int(first) + int(second))
        logger.info(f"Finished tool execution: {self.name}")
        return total
40
+
41
class MultiplyTool(Tool):
    """Tool that multiplies two integers.

    Arguments ``x`` and ``y`` are coerced with ``int()``. The product is
    returned as a decimal string.
    """

    def __init__(self):
        factors = [
            ToolArgument(name="x", arg_type="int", description="First number", required=True),
            ToolArgument(name="y", arg_type="int", description="Second number", required=True),
        ]
        super().__init__(
            name="multiply_tool",
            description="Multiplies two numbers and returns the product.",
            return_type="int",
            arguments=factors,
        )

    async def async_execute(self, **kwargs) -> str:
        """Return ``str(int(x) * int(y))``, logging start and finish."""
        logger.info(f"Starting tool execution: {self.name}")
        lhs, rhs = kwargs["x"], kwargs["y"]
        logger.info(f"Multiplying {lhs} and {rhs}")
        product = str(int(lhs) * int(rhs))
        logger.info(f"Finished tool execution: {self.name}")
        return product
59
+
60
class ConcatTool(Tool):
    """Tool that joins two strings with ``+``.

    No separator is inserted and no type coercion is applied.
    """

    def __init__(self):
        pieces = [
            ToolArgument(name="s1", arg_type="string", description="First string", required=True),
            ToolArgument(name="s2", arg_type="string", description="Second string", required=True),
        ]
        super().__init__(
            name="concat_tool",
            description="Concatenates two strings.",
            return_type="string",
            arguments=pieces,
        )

    async def async_execute(self, **kwargs) -> str:
        """Return ``s1 + s2``, logging start and finish."""
        logger.info(f"Starting tool execution: {self.name}")
        head, tail = kwargs["s1"], kwargs["s2"]
        logger.info(f"Concatenating '{head}' and '{tail}'")
        joined = head + tail
        logger.info(f"Finished tool execution: {self.name}")
        return joined
78
+
79
class AgentTool(Tool):
    """Tool that asks a litellm chat model to answer a prompt.

    The system prompt, user prompt and temperature come from each call's
    keyword arguments. The model identifier is fixed at construction time.

    Args:
        model: litellm model identifier used for every completion request.
    """

    def __init__(self, model: str = "gemini/gemini-2.0-flash"):
        super().__init__(
            name="agent_tool",
            description="Generates text using a language model based on a system prompt and user prompt.",
            arguments=[
                ToolArgument(name="system_prompt", arg_type="string", description="System prompt to guide the model's behavior", required=True),
                ToolArgument(name="prompt", arg_type="string", description="User prompt to generate a response for", required=True),
                ToolArgument(name="temperature", arg_type="float", description="Temperature for generation (0 to 1)", required=True),
            ],
            return_type="string",
        )
        self.model = model

    async def async_execute(self, **kwargs) -> str:
        """Generate text from ``system_prompt`` + ``prompt`` at ``temperature``.

        Returns:
            The model's reply with surrounding whitespace stripped, or ``""``
            if the reply had no content.

        Raises:
            ValueError: If ``temperature`` is outside ``[0, 1]``.
            TimeoutError: If the API call does not finish within 30 seconds.
            RuntimeError: For any other failure of the API call.
        """
        timeout_seconds = 30
        logger.info(f"Starting tool execution: {self.name}")
        system_prompt = kwargs["system_prompt"]
        prompt = kwargs["prompt"]
        temperature = float(kwargs["temperature"])

        if not 0 <= temperature <= 1:
            logger.error(f"Temperature {temperature} is out of range (0-1)")
            raise ValueError("Temperature must be between 0 and 1")

        logger.info(f"Generating text with model {self.model}, temperature {temperature}")
        try:
            logger.debug(f"Making API call to {self.model}")
            # asyncio.wait_for works on all Python 3 versions, unlike asyncio.timeout
            # (3.11+). It cancels the pending request when the deadline expires.
            response = await asyncio.wait_for(
                litellm.acompletion(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": prompt},
                    ],
                    temperature=temperature,
                    max_tokens=1000,
                ),
                timeout=timeout_seconds,
            )
            # message.content may be None (e.g. an empty or filtered reply);
            # treat that as an empty string instead of crashing on .strip().
            generated_text = (response.choices[0].message.content or "").strip()
        except TimeoutError as e:
            error_msg = f"API call to {self.model} timed out after {timeout_seconds} seconds"
            logger.error(error_msg)
            raise TimeoutError(error_msg) from e
        except Exception as e:
            logger.error(f"Failed to generate text with {self.model}: {str(e)}")
            raise RuntimeError(f"Text generation failed: {str(e)}") from e

        logger.debug(f"Generated text: {generated_text}")
        logger.info(f"Finished tool execution: {self.name}")
        return generated_text
132
+
133
def _strip_code_fences(text: str) -> str:
    """Return the program wrapped in triple quotes or a Markdown fence, else *text* unchanged."""
    import re

    if text.startswith('"""') and text.endswith('"""'):
        return text[3:-3]
    if text.startswith("```python") and text.endswith("```"):
        # Whole reply is one fenced block: slice it so backticks inside the code survive.
        return text[9:-3].strip()
    # Fallback: ```py / ```python3 / untagged fences, possibly surrounded by prose.
    fenced = re.search(r"```[ \t]*(?:python3?|py)?[ \t]*\n?(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()
    return text


def _remove_entry_point_code(code: str) -> str:
    """Drop module-level code that would run ``main()`` itself.

    Two constructs are removed:
    - a top-level ``if __name__ == '__main__':`` guard (either quote style),
      together with everything after it;
    - any top-level ``asyncio.run(main())`` line.

    The runtime executes ``main()`` on its own. Returns *code* unchanged when
    neither construct is present.
    """
    import re

    lines = code.splitlines()
    guard = re.compile(r"""^if\s+__name__\s*==\s*(['"])__main__\1\s*:""")
    for index, line in enumerate(lines):
        if guard.match(line):
            lines = lines[:index]
            code = "\n".join(lines).strip()
            logger.warning("Removed unexpected __main__ block from generated code")
            break

    run_main = re.compile(r"^asyncio\.run\(\s*main\(\s*\)\s*\)\s*(?:#.*)?$")
    kept = [line for line in lines if not run_main.match(line)]
    if len(kept) != len(lines):
        code = "\n".join(kept).strip()
        logger.warning("Removed top-level asyncio.run(main()) call from generated code")
    return code


# Ask the model for a program that solves the task using the given tools.
async def generate_program(task_description: str, tools: List[Tool], model: str, max_tokens: int) -> str:
    """
    Asynchronously generate a Python program that solves a given task using a list of tools.

    The tools' docstrings are embedded in the prompt sent to ``litellm``. The
    reply is then unwrapped from any code fence, and any module-level code that
    would run ``main()`` itself is removed.

    Args:
        task_description (str): A description of the task to be solved.
        tools (List[Tool]): A list of Tool objects available for use.
        model (str): The litellm model to use for code generation.
        max_tokens (int): Maximum number of tokens for the generated response.

    Returns:
        str: A string containing a complete Python program.

    Raises:
        typer.BadParameter: If the litellm call fails.
    """
    logger.debug(f"Generating program for task: {task_description}")
    tool_docstrings = "\n\n".join(tool.to_docstring() for tool in tools)

    # NOTE(review): instruction 9 reads as garbled ("...multiline strings" /
    # "string, always start...") — confirm the intended wording.
    prompt = f"""
You are a Python code generator. Your task is to create a Python program that solves the following task:
"{task_description}"

You have access to the following pre-defined async tool functions, as defined with their signatures and descriptions:

{tool_docstrings}

Instructions:
1. Generate a Python program as a single string.
2. Include only the import for asyncio (import asyncio).
3. Define an async function named main() that solves the task.
4. Use the pre-defined tool functions (e.g., add_tool, multiply_tool, concat_tool) directly by calling them with await and the appropriate arguments as specified in their descriptions.
5. Do not redefine the tool functions within the program; assume they are already available in the namespace.
6. Return the program as markdown code block.
7. Strictly exclude asyncio.run(main()) or any code outside the main() function definition, including any 'if __name__ == "__main__":' block, as the runtime will handle execution of main().
8. Do not include explanatory text outside the program string.
9. Express all string variables as multiline strings
string, always start a string at the beginning of a line.
10. Always print the result at the end of the program.

Example task: "Add 5 and 7 and print the result"
Example output:
```python
import asyncio

async def main():
    result = await add_tool(a=5, b=7)
    print(result)
```
"""

    logger.debug(f"Prompt sent to litellm:\n{prompt}")

    try:
        logger.debug(f"Calling litellm with model {model}")
        response = await litellm.acompletion(
            model=model,
            messages=[
                {"role": "system", "content": "You are a Python code generator."},
                {"role": "user", "content": prompt},
            ],
            max_tokens=max_tokens,
            temperature=0.3,
        )
        # message.content may be None (e.g. an empty reply); treat it as "".
        generated_code = (response.choices[0].message.content or "").strip()
        logger.debug("Code generation successful")
    except Exception as e:
        logger.error(f"Failed to generate code: {str(e)}")
        raise typer.BadParameter(f"Failed to generate code with model '{model}': {str(e)}") from e

    generated_code = _strip_code_fences(generated_code)
    return _remove_entry_point_code(generated_code)
219
+
220
# Generate a program for a task, validate it, and run it in the sandboxed interpreter.
async def generate_core(task: str, model: str, max_tokens: int) -> None:
    """
    Core logic to generate and execute a Python program based on a task description.

    Steps:
    1. Validate the CLI inputs.
    2. Ask the model for a program and print it.
    3. Check that the program parses and defines ``async def main()``.
    4. Run ``main`` through ``execute_async``, with the tool functions
       pre-bound in its namespace.

    Execution problems are reported on the console rather than raised.

    Args:
        task (str): The task description to generate a program for.
        model (str): The litellm model to use for generation.
        max_tokens (int): Maximum number of tokens for the generated response.

    Raises:
        typer.BadParameter: If ``task`` is blank or ``max_tokens`` is not positive.
        typer.Exit: With code 1 if program generation fails.
    """
    import builtins

    logger.info(f"Starting generate command for task: {task}")
    if not task.strip():
        logger.error("Task description is empty")
        raise typer.BadParameter("Task description cannot be empty")
    if max_tokens <= 0:
        logger.error("max-tokens must be positive")
        raise typer.BadParameter("max-tokens must be a positive integer")

    # Build each tool once: the instances described to the model are the same
    # ones bound into the execution namespace below.
    add_tool = AddTool()
    multiply_tool = MultiplyTool()
    concat_tool = ConcatTool()
    agent_tool = AgentTool(model=model)
    tools = [add_tool, multiply_tool, concat_tool, agent_tool]

    try:
        program = await generate_program(task, tools, model, max_tokens)
    except Exception as e:
        logger.error(f"Failed to generate program: {str(e)}")
        typer.echo(typer.style(f"Error: {str(e)}", fg=typer.colors.RED))
        raise typer.Exit(code=1)

    logger.debug(f"Generated program:\n{program}")
    typer.echo(typer.style("Generated Python Program:", fg=typer.colors.GREEN, bold=True))
    typer.echo(program)

    # Refuse to execute code that does not parse or lacks the expected entry point.
    try:
        ast_tree = ast.parse(program)
    except SyntaxError as e:
        logger.error(f"Syntax error in generated code: {str(e)}")
        typer.echo(typer.style(f"Syntax error in generated code: {str(e)}", fg=typer.colors.RED))
        return
    has_async_main = any(
        isinstance(node, ast.AsyncFunctionDef) and node.name == "main"
        for node in ast.walk(ast_tree)
    )
    if not has_async_main:
        logger.warning("Generated code lacks an async main() function")
        typer.echo(typer.style("Warning: Generated code lacks an async main() function", fg=typer.colors.YELLOW))
        return

    # Tool names are bound to the tools' async_execute methods, so generated code
    # can call e.g. ``await add_tool(a=5, b=7)``.
    namespace: Dict[str, object] = {
        "asyncio": asyncio,
        "add_tool": add_tool.async_execute,
        "multiply_tool": multiply_tool.async_execute,
        "concat_tool": concat_tool.async_execute,
        "agent_tool": agent_tool.async_execute,
    }

    # ``__builtins__`` is the builtins *dict* (not the module) when this file is
    # imported rather than run as a script, and ``vars()`` rejects a dict.
    # ``dir(builtins)`` works in both cases.
    reserved_names = set(dir(builtins))
    for name in namespace:
        if name in reserved_names and name != "asyncio":
            logger.warning(f"Namespace collision detected: '{name}' shadows a builtin")
            typer.echo(typer.style(f"Warning: Tool name '{name}' shadows a builtin", fg=typer.colors.YELLOW))

    typer.echo("\n" + typer.style("Executing the program:", fg=typer.colors.GREEN, bold=True))
    try:
        logger.debug("Executing generated code with execute_async")
        execution_result = await execute_async(
            code=program,
            timeout=30,
            entry_point="main",
            allowed_modules=["asyncio"],
            namespace=namespace,
        )

        if execution_result.error:
            if "SyntaxError" in execution_result.error:
                logger.error(f"Syntax error: {execution_result.error}")
                typer.echo(typer.style(f"Syntax error: {execution_result.error}", fg=typer.colors.RED))
            elif "TimeoutError" in execution_result.error:
                logger.error(f"Timeout: {execution_result.error}")
                typer.echo(typer.style(f"Timeout: {execution_result.error}", fg=typer.colors.RED))
            else:
                logger.error(f"Runtime error: {execution_result.error}")
                typer.echo(typer.style(f"Runtime error: {execution_result.error}", fg=typer.colors.RED))
        else:
            logger.info(f"Execution completed in {execution_result.execution_time:.2f} seconds")
            typer.echo(typer.style(f"Execution completed in {execution_result.execution_time:.2f} seconds", fg=typer.colors.GREEN))

            if execution_result.result is not None:
                typer.echo("\n" + typer.style("Result:", fg=typer.colors.BLUE, bold=True))
                typer.echo(str(execution_result.result))
    except ValueError as e:
        logger.error(f"Invalid code generated: {str(e)}")
        typer.echo(typer.style(f"Invalid code: {str(e)}", fg=typer.colors.RED))
    except Exception as e:
        logger.error(f"Unexpected execution error: {str(e)}")
        typer.echo(typer.style(f"Unexpected error during execution: {str(e)}", fg=typer.colors.RED))
    else:
        logger.info("Program executed successfully")
331
+
332
@app.command()
def generate(
    task: str = typer.Argument(
        ...,
        help="The task description to generate a program for (e.g., 'Add 5 and 7 and print the result')",
    ),
    model: str = typer.Option(
        "gemini/gemini-2.0-flash",
        "--model",
        "-m",
        help="The litellm model to use for generation (e.g., 'gpt-3.5-turbo', 'gpt-4')",
    ),
    max_tokens: int = typer.Option(
        4000,
        "--max-tokens",
        "-t",
        help="Maximum number of tokens for the generated response (default: 4000)",
    ),
) -> None:
    """Generate and execute a Python program based on a task description"""
    # The docstring above is the CLI help text shown by Typer; keep it user-facing.
    try:
        asyncio.run(generate_core(task, model, max_tokens))
    except typer.Exit:
        # generate_core has already reported the failure. Re-raise so its exit
        # code is kept; typer.Exit derives from Exception, so without this
        # clause the handler below would print a second, uninformative
        # "Error:" line.
        raise
    except Exception as e:
        logger.error(f"Command failed: {str(e)}")
        typer.echo(typer.style(f"Error: {str(e)}", fg=typer.colors.RED))
        raise typer.Exit(code=1)
359
+
360
def main() -> None:
    """Script entry point: emit a debug log line, then hand control to the Typer CLI."""
    logger.debug("Starting script execution")
    app()


# Allow running the module directly, e.g. ``python -m quantalogic.tools.action_gen``.
if __name__ == "__main__":
    main()
@@ -154,6 +154,10 @@ class PythonTool(Tool):
154
154
  script_path = os.path.join(temp_dir, "script.py")
155
155
  self._write_script(script_path, script)
156
156
 
157
+ # Ensure the host directory exists
158
+ if host_dir:
159
+ self._ensure_directory_exists(host_dir)
160
+
157
161
  # Prepare pip install commands
158
162
  pip_install_cmd = self._prepare_install_commands(install_commands)
159
163
 
@@ -417,6 +421,15 @@ class PythonTool(Tool):
417
421
  logger.debug(f"Parsed environment variables: {env_vars}")
418
422
  return env_vars
419
423
 
424
def _ensure_directory_exists(self, directory_path: str) -> None:
    """Ensure *directory_path* exists as a directory, creating parent directories as needed.

    ``exist_ok=True`` makes the call idempotent. It also removes the race where
    another process creates the directory between an existence check and
    ``os.makedirs``.

    Args:
        directory_path (str): The path to the directory to ensure.

    Raises:
        FileExistsError: If *directory_path* exists but is not a directory.
        OSError: If the directory cannot be created (e.g. insufficient permissions).
    """
    os.makedirs(directory_path, exist_ok=True)
432
+
420
433
 
421
434
  if __name__ == "__main__":
422
435
  # Example usage of PythonTool
@@ -24,7 +24,7 @@ logging.basicConfig(level=logging.INFO)
24
24
  logger = logging.getLogger(__name__)
25
25
 
26
26
 
27
- class SearchDefinitionNames(Tool):
27
+ class SearchDefinitionNamesTool(Tool):
28
28
  """Tool for searching definition names in a directory using Tree-sitter.
29
29
 
30
30
  Supports searching for:
@@ -448,7 +448,7 @@ class SearchDefinitionNames(Tool):
448
448
 
449
449
 
450
450
  if __name__ == "__main__":
451
- tool = SearchDefinitionNames()
451
+ tool = SearchDefinitionNamesTool()
452
452
  print(tool.to_markdown())
453
453
 
454
454
  # Example usage with different output formats