jarvis-ai-assistant 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jarvis-ai-assistant might be problematic. Click here for more details.
- jarvis/.jarvis +1 -0
- jarvis/__init__.py +3 -0
- jarvis/__pycache__/agent.cpython-313.pyc +0 -0
- jarvis/__pycache__/models.cpython-313.pyc +0 -0
- jarvis/__pycache__/tools.cpython-313.pyc +0 -0
- jarvis/__pycache__/utils.cpython-313.pyc +0 -0
- jarvis/agent.py +100 -0
- jarvis/main.py +161 -0
- jarvis/models.py +112 -0
- jarvis/tools/__init__.py +22 -0
- jarvis/tools/__pycache__/__init__.cpython-313.pyc +0 -0
- jarvis/tools/__pycache__/base.cpython-313.pyc +0 -0
- jarvis/tools/__pycache__/file_ops.cpython-313.pyc +0 -0
- jarvis/tools/__pycache__/python_script.cpython-313.pyc +0 -0
- jarvis/tools/__pycache__/rag.cpython-313.pyc +0 -0
- jarvis/tools/__pycache__/search.cpython-313.pyc +0 -0
- jarvis/tools/__pycache__/shell.cpython-313.pyc +0 -0
- jarvis/tools/__pycache__/user_confirmation.cpython-313.pyc +0 -0
- jarvis/tools/__pycache__/user_interaction.cpython-313.pyc +0 -0
- jarvis/tools/__pycache__/webpage.cpython-313.pyc +0 -0
- jarvis/tools/base.py +155 -0
- jarvis/tools/file_ops.py +106 -0
- jarvis/tools/python_script.py +150 -0
- jarvis/tools/rag.py +154 -0
- jarvis/tools/search.py +48 -0
- jarvis/tools/shell.py +67 -0
- jarvis/tools/user_confirmation.py +58 -0
- jarvis/tools/user_interaction.py +86 -0
- jarvis/tools/webpage.py +90 -0
- jarvis/utils.py +105 -0
- jarvis_ai_assistant-0.1.0.dist-info/METADATA +125 -0
- jarvis_ai_assistant-0.1.0.dist-info/RECORD +35 -0
- jarvis_ai_assistant-0.1.0.dist-info/WHEEL +5 -0
- jarvis_ai_assistant-0.1.0.dist-info/entry_points.txt +2 -0
- jarvis_ai_assistant-0.1.0.dist-info/top_level.txt +1 -0
jarvis/.jarvis
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
- 查看当前git仓库状态,如果有未提交的变更,自动根据当前仓库的文件修改内容总结生成一个git commit,然后执行git push推送到远端仓库,全流程不要询问用户
|
jarvis/__init__.py
ADDED
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
jarvis/agent.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import subprocess
|
|
3
|
+
from typing import Dict, Any, List, Optional
|
|
4
|
+
from .tools import ToolRegistry
|
|
5
|
+
from .utils import Spinner, PrettyOutput, OutputType, get_multiline_input
|
|
6
|
+
from .models import BaseModel, OllamaModel
|
|
7
|
+
import re
|
|
8
|
+
import os
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
|
|
11
|
+
class Agent:
    """Conversational agent that alternates model calls and tool executions.

    The history starts with a system prompt that embeds the tool registry's
    help text. Each :meth:`run` resets the history to that prompt and then
    loops:
    1. Call the model.
    2. Execute any tool calls it returned and feed the results back.
    3. Once the model answers without tool calls, ask the user for a follow-up.
    """

    def __init__(self, model: BaseModel, tool_registry: ToolRegistry):
        """Create an agent.

        Args:
            model: Chat model used to produce responses.
            tool_registry: Registry providing tool schemas, the prompt help
                text and tool-call execution.
        """
        self.model = model
        self.tool_registry = tool_registry
        # Compiled tool-call pattern.
        # NOTE(review): not used inside this class; extraction is done by the
        # model classes.
        self.tool_call_pattern = re.compile(r'<tool_call>\s*({[^}]+})\s*</tool_call>')
        self.messages = [
            {
                "role": "system",
                "content": """You are a rigorous AI assistant, all data must be obtained through tools, and no fabrication or speculation is allowed. """ + "\n" + self.tool_registry.tool_help_text()
            }
        ]
        self.spinner = Spinner()

    def _call_model(self, messages: List[Dict], use_tools: bool = True) -> Dict:
        """Call the model, showing a spinner while waiting.

        Args:
            messages: Conversation history to send.
            use_tools: When True, pass every registered tool schema to the model.

        Returns:
            The model's response dict. The model classes in this package return
            ``{"message": {"content": ..., "tool_calls": [...]}}``.

        Raises:
            Exception: Wraps any error raised by the model; the original error
                is kept as ``__cause__``.
        """
        self.spinner.start()
        try:
            return self.model.chat(
                messages=messages,
                tools=self.tool_registry.get_all_tools() if use_tools else None
            )
        except Exception as e:
            raise Exception(f"模型调用失败: {str(e)}") from e
        finally:
            self.spinner.stop()

    def run(self, user_input: str):
        """Handle one task, starting from a fresh history.

        The model is called repeatedly. While it returns tool calls, they are
        executed and their output is appended as a ``tool`` message. When it
        replies without tool calls, the reply is shown and the user may send a
        follow-up. The task ends on an empty follow-up, or after printing an
        error raised while handling a response.

        Args:
            user_input: The initial task description.
        """
        # Each task starts from the system prompt only.
        self.clear_history()
        self.messages.append({
            "role": "user",
            "content": user_input
        })

        while True:
            try:
                response = self._call_model(self.messages)
                message = response["message"]
                content = message.get("content", "")
                # A missing or None tool_calls entry means "no tool calls".
                tool_calls = message.get("tool_calls") or []

                # Record the assistant turn, including any tool calls it made.
                self.messages.append({
                    "role": "assistant",
                    "content": content,
                    "tool_calls": tool_calls
                })

                if tool_calls:
                    # Show any text the model produced alongside its tool calls.
                    if content:
                        PrettyOutput.print(content, OutputType.SYSTEM)

                    tool_result = self.tool_registry.handle_tool_calls(tool_calls)
                    PrettyOutput.print(tool_result, OutputType.RESULT)

                    # Feed the tool output back so the model can continue.
                    self.messages.append({
                        "role": "tool",
                        "content": tool_result
                    })
                    continue

                if content:
                    PrettyOutput.print(content, OutputType.SYSTEM)

                # The model answered without tool calls: let the user follow up.
                user_input = get_multiline_input("您可以继续输入,或输入空行结束当前任务")
                if not user_input:
                    PrettyOutput.print("===============任务结束===============", OutputType.INFO)
                    break

                self.messages.append({
                    "role": "user",
                    "content": user_input
                })

            except Exception as e:
                error_msg = f"处理响应时出错: {str(e)}"
                PrettyOutput.print(error_msg, OutputType.ERROR)
                # Retrying with an unchanged history would hit the same failure
                # again and again (e.g. an unreachable model server), so end
                # the task instead of looping forever.
                break

    def clear_history(self):
        """Reset the history to only the system prompt (the first message)."""
        self.messages = [self.messages[0]]
|
jarvis/main.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Command line interface for Jarvis."""
|
|
3
|
+
|
|
4
|
+
import argparse
|
|
5
|
+
import yaml
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
# 添加父目录到Python路径以支持导入
|
|
11
|
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
12
|
+
|
|
13
|
+
from jarvis.agent import Agent
|
|
14
|
+
from jarvis.tools import ToolRegistry
|
|
15
|
+
from jarvis.models import DDGSModel, OllamaModel
|
|
16
|
+
from jarvis.utils import PrettyOutput, OutputType, get_multiline_input
|
|
17
|
+
|
|
18
|
+
# Supported platforms: for each, the selectable models and the model used when
# --model is omitted. Key order is the order of the --platform choices.
SUPPORTED_PLATFORMS = dict(
    ollama=dict(
        models=["llama3.2", "qwen2.5:14b"],
        default="qwen2.5:14b",
    ),
    ddgs=dict(
        models=["gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"],
        default="gpt-4o-mini",
    ),
)
|
|
29
|
+
|
|
30
|
+
def load_tasks() -> list:
    """Load predefined tasks from a ``.jarvis`` file in the current directory.

    The file is parsed as YAML and must contain a list. Every truthy entry is
    converted to a string. A missing, empty or comment-only file simply means
    "no tasks" and produces no warning.

    Returns:
        List of task strings. It is empty when the file is missing, empty, not
        a list, or cannot be read or parsed (errors are printed).
    """
    if not os.path.exists(".jarvis"):
        return []

    try:
        with open(".jarvis", "r", encoding="utf-8") as f:
            text = f.read()

        # An empty file is "no tasks", not a malformed task list.
        if not text.strip():
            return []

        # safe_load: the file comes from the working directory, so never use
        # the unsafe yaml.load here.
        tasks = yaml.safe_load(text)
        if tasks is None:  # e.g. the file contains only YAML comments
            return []

        if not isinstance(tasks, list):
            PrettyOutput.print("Warning: .jarvis file should contain a list of tasks", OutputType.ERROR)
            return []

        # Convert all tasks to strings and filter out empty ones.
        return [str(task) for task in tasks if task]
    except Exception as e:
        PrettyOutput.print(f"Error loading .jarvis file: {str(e)}", OutputType.ERROR)
        return []
|
|
47
|
+
|
|
48
|
+
def select_task(tasks: list) -> str:
    """Offer the predefined tasks as a numbered menu and return the user's pick.

    Re-prompts on non-numeric or out-of-range input.

    Args:
        tasks: Task strings to choose from.

    Returns:
        The chosen task, or ``""`` if ``tasks`` is empty or the user enters an
        empty line or ``0``.
    """
    if not tasks:
        return ""

    PrettyOutput.print("\nFound predefined tasks:", OutputType.INFO)
    for number, task in enumerate(tasks, start=1):
        PrettyOutput.print(f"[{number}] {task}", OutputType.INFO)
    PrettyOutput.print("[0] Skip predefined tasks", OutputType.INFO)

    while True:
        raw = input("\nSelect a task number (0 to skip): ").strip()
        if raw == "":
            return ""

        try:
            number = int(raw)
        except ValueError:
            PrettyOutput.print("Please enter a valid number.", OutputType.ERROR)
            continue

        if number == 0:
            return ""
        if 0 < number <= len(tasks):
            return tasks[number - 1]
        PrettyOutput.print("Invalid choice. Please try again.", OutputType.ERROR)
|
|
73
|
+
|
|
74
|
+
def main():
    """Main entry point for Jarvis.

    Steps:
    1. Parse command-line options.
    2. Validate ``--model`` against ``SUPPORTED_PLATFORMS``.
    3. Build the model and the agent.
    4. Either run one task chosen from the ``.jarvis`` file, or enter an
       interactive loop.

    Returns:
        Exit code: 0 on normal exit (including Ctrl-C), 1 when the requested
        model is not supported by the platform or initialization fails.
    """
    parser = argparse.ArgumentParser(description="Jarvis AI Assistant")

    # Platform selection.
    parser.add_argument(
        "--platform",
        choices=list(SUPPORTED_PLATFORMS.keys()),
        default="ollama",
        help="选择运行平台 (默认: ollama)"
    )

    # Model selection; defaults to the platform's default model.
    parser.add_argument(
        "--model",
        help="选择模型 (默认: 根据平台自动选择)"
    )

    # Base URL of the Ollama API (only used by the ollama platform).
    parser.add_argument(
        "--api-base",
        default="http://localhost:11434",
        help="Ollama API基础URL (仅用于Ollama平台, 默认: http://localhost:11434)"
    )

    args = parser.parse_args()

    # Validate an explicit model, or fall back to the platform default.
    platform_config = SUPPORTED_PLATFORMS[args.platform]
    if args.model:
        if args.model not in platform_config["models"]:
            supported_models = ", ".join(platform_config["models"])
            PrettyOutput.print(
                f"错误: 平台 {args.platform} 不支持模型 {args.model}\n"
                f"支持的模型: {supported_models}",
                OutputType.ERROR
            )
            return 1
    else:
        args.model = platform_config["default"]

    try:
        # Create the model instance for the selected platform.
        if args.platform == "ollama":
            model = OllamaModel(
                model_name=args.model,
                api_base=args.api_base
            )
            platform_name = f"Ollama ({args.model})"
        else:  # ddgs
            model = DDGSModel(model_name=args.model)
            platform_name = f"DuckDuckGo Search ({args.model})"

        tool_registry = ToolRegistry()
        agent = Agent(model, tool_registry)

        PrettyOutput.print(f"Jarvis 已初始化 - {platform_name}", OutputType.SYSTEM)

        # A selected predefined task runs once, then the program exits.
        tasks = load_tasks()
        if tasks:
            selected_task = select_task(tasks)
            if selected_task:
                PrettyOutput.print(f"\n执行任务: {selected_task}", OutputType.INFO)
                agent.run(selected_task)
                return 0

        # Interactive mode: an empty line ends the session.
        while True:
            try:
                user_input = get_multiline_input("请输入您的任务(输入空行退出):")
                if not user_input:
                    break
                agent.run(user_input)
            except (KeyboardInterrupt, EOFError):
                # EOF must end the loop too: if get_multiline_input raises
                # EOFError, the generic handler below would re-prompt forever.
                print("\n正在退出...")
                break
            except Exception as e:
                PrettyOutput.print(f"错误: {str(e)}", OutputType.ERROR)

    except KeyboardInterrupt:
        # Ctrl-C during initialization, task selection or a predefined task:
        # exit the same way interactive mode does.
        print("\n正在退出...")
        return 0
    except Exception as e:
        PrettyOutput.print(f"初始化错误: {str(e)}", OutputType.ERROR)
        return 1

    return 0
|
|
159
|
+
|
|
160
|
+
if __name__ == "__main__":
    # sys.exit instead of the site-provided exit(), which is missing when Python
    # runs with -S and in some embedded or frozen interpreters.
    sys.exit(main())
|
jarvis/models.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import re
|
|
3
|
+
import time
|
|
4
|
+
from typing import Dict, List, Optional
|
|
5
|
+
from duckduckgo_search import DDGS
|
|
6
|
+
import ollama
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
|
|
9
|
+
from .utils import OutputType, PrettyOutput
|
|
10
|
+
|
|
11
|
+
class BaseModel(ABC):
    """Base class for chat models.

    Subclasses implement :meth:`chat`. :meth:`extract_tool_calls` parses
    ``<tool_call>`` sections out of plain-text replies.
    """

    # One <tool_call>...</tool_call> section. The payload is decoded with
    # json.loads instead of a brace-matching regex, so arguments of any nesting
    # depth and braces inside JSON strings (e.g. Python code) are handled.
    # Non-greedy, so several calls in one reply stay separate. DOTALL lets the
    # payload span multiple lines.
    _TOOL_CALL_PATTERN = re.compile(r'<tool_call>\s*(.*?)\s*</tool_call>', re.DOTALL)

    @abstractmethod
    def chat(self, messages: List[Dict], tools: Optional[List[Dict]] = None) -> Dict:
        """Run one chat turn.

        Args:
            messages: Conversation history (dicts with ``role`` and ``content``).
            tools: Tool definitions to offer to the model, or None.

        Returns:
            A response dict. The implementations in this module return
            ``{"message": {"content": ..., "tool_calls": [...]}}``.
        """
        pass

    @staticmethod
    def extract_tool_calls(content: str) -> List[Dict]:
        """Extract tool calls written as ``<tool_call>{json}</tool_call>``.

        Each payload must be a JSON object with ``name`` and ``arguments``.
        Invalid JSON or objects missing those keys are reported and skipped.

        Args:
            content: Model reply text. None or empty yields no calls.

        Returns:
            List of ``{"function": {"name": ..., "arguments": ...}}`` dicts, in
            the order they appear in ``content``.
        """
        tool_calls = []
        if not content:
            return tool_calls

        for match in BaseModel._TOOL_CALL_PATTERN.finditer(content):
            tool_call_str = match.group(1).strip()
            try:
                tool_call = json.loads(tool_call_str)
            except json.JSONDecodeError as e:
                PrettyOutput.print(f"JSON解析错误: {str(e)}", OutputType.ERROR)
                PrettyOutput.print(f"解析内容: {tool_call_str}", OutputType.ERROR)
                continue

            if isinstance(tool_call, dict) and "name" in tool_call and "arguments" in tool_call:
                tool_calls.append({
                    "function": {
                        "name": tool_call["name"],
                        "arguments": tool_call["arguments"]
                    }
                })
            else:
                PrettyOutput.print(f"无效的工具调用格式: {tool_call_str}", OutputType.ERROR)

        return tool_calls
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class DDGSModel(BaseModel):
    """Chat model backed by DuckDuckGo AI Chat through ``duckduckgo_search.DDGS``.

    ``DDGS.chat`` is called with one text prompt. The tool list and the whole
    conversation are therefore flattened into that prompt, and tool calls are
    parsed back out of the reply text with ``BaseModel.extract_tool_calls``.
    """

    def __init__(self, model_name: str = "gpt-4o-mini"):
        """Create a DDGS chat model.

        Args:
            model_name: DDGS chat model to use, one of:
                [1]: gpt-4o-mini
                [2]: claude-3-haiku
                [3]: llama-3.1-70b
                [4]: mixtral-8x7b
        """
        self.model_name = model_name

    def __make_prompt(self, messages: List[Dict], tools: Optional[List[Dict]] = None) -> str:
        """Flatten the tool definitions and the conversation into one prompt.

        Args:
            messages: Conversation history; each item needs ``role`` and ``content``.
            tools: Tool definitions in ``{"function": {...}}`` form, or None.

        Returns:
            The prompt text.
        """
        parts = ["You are an AI Agent skilled in utilizing tools and planning tasks. Based on the task input by the user and the list of available tools, you output the tool invocation methods in a specified format. The user will provide feedback on the results of the tool execution, allowing you to continue analyzing and ultimately complete the user's designated task. Below is the list of tools and their usage methods. Let's use them step by step to accomplish the user's task.\n"]
        # tools is None when the caller disables tools (use_tools=False).
        for tool in tools or []:
            parts.append(f"- Tool: {tool['function']['name']}\n")
            parts.append(f"  Description: {tool['function']['description']}\n")
            parts.append(f"  Arguments: {tool['function']['parameters']}\n")
        for message in messages:
            parts.append(f"[{message['role']}]: {message['content']}\n")
        return "".join(parts)

    def chat(self, messages: List[Dict], tools: Optional[List[Dict]] = None) -> Dict:
        """Send the conversation to DDGS chat and parse tool calls from the reply.

        Args:
            messages: Conversation history.
            tools: Tool definitions to describe in the prompt, or None.

        Returns:
            ``{"message": {"content": reply_text, "tool_calls": [...]}}``.
        """
        ddgs = DDGS()
        prompt = self.__make_prompt(messages, tools)
        # Pass the configured model explicitly. Otherwise DDGS.chat uses its own
        # default model and the --model selection is silently ignored.
        content = ddgs.chat(prompt, model=self.model_name)
        tool_calls = BaseModel.extract_tool_calls(content)
        return {
            "message": {
                "content": content,
                "tool_calls": tool_calls
            }
        }
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class OllamaModel(BaseModel):
    """Chat model backed by an Ollama server (``ollama.Client``)."""

    def __init__(self, model_name: str = "qwen2.5:14b", api_base: str = "http://localhost:11434"):
        """Create an Ollama-backed model.

        Args:
            model_name: Ollama model tag to chat with.
            api_base: Base URL of the Ollama server.
        """
        self.model_name = model_name
        self.api_base = api_base
        self.client = ollama.Client(host=api_base)

    def chat(self, messages: List[Dict], tools: Optional[List[Dict]] = None) -> Dict:
        """Call the Ollama chat API and normalize the response.

        Native tool calls from the API are preferred. If there are none,
        ``<tool_call>`` sections in the reply text are parsed instead.

        Args:
            messages: Conversation history.
            tools: Tool definitions passed to Ollama, or None.

        Returns:
            ``{"message": {"content": str, "tool_calls": [...]}}``.

        Raises:
            Exception: Wraps any client error; the original is kept as ``__cause__``.
        """
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=messages,
                tools=tools
            )

            # The client's message content can be None (e.g. a reply consisting
            # only of tool calls). Normalize to "" so the text fallback parser
            # and callers always receive a string.
            content = response.message.content or ""
            tool_calls = response.message.tool_calls or BaseModel.extract_tool_calls(content)

            # Convert to the response shape shared by the models in this module.
            return {
                "message": {
                    "content": content,
                    "tool_calls": tool_calls
                }
            }
        except Exception as e:
            raise Exception(f"Ollama API调用失败: {str(e)}") from e
|
jarvis/tools/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from .base import Tool, ToolRegistry
|
|
2
|
+
from .python_script import PythonScript
|
|
3
|
+
from .file_ops import FileOperationTool
|
|
4
|
+
from .search import SearchTool
|
|
5
|
+
from .shell import ShellTool
|
|
6
|
+
from .user_interaction import UserInteractionTool
|
|
7
|
+
from .user_confirmation import UserConfirmationTool
|
|
8
|
+
from .rag import RAGTool
|
|
9
|
+
from .webpage import WebpageTool
|
|
10
|
+
|
|
11
|
+
# Public API of jarvis.tools.
# NOTE(review): this package imports ``PythonScript`` from .python_script, while
# ToolRegistry._register_default_tools imports ``PythonScriptTool`` from the
# same module. Confirm which class name python_script.py actually defines.
__all__ = [
    # Registry infrastructure
    'Tool', 'ToolRegistry',
    # Built-in tools
    'PythonScript', 'FileOperationTool', 'SearchTool', 'ShellTool',
    'UserInteractionTool', 'UserConfirmationTool', 'RAGTool', 'WebpageTool',
]
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
jarvis/tools/base.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
from typing import Dict, Any, List, Optional, Callable
|
|
2
|
+
import json
|
|
3
|
+
from ..utils import PrettyOutput, OutputType
|
|
4
|
+
|
|
5
|
+
class Tool:
    """A named tool: description, parameter schema and the callable behind it."""

    def __init__(self, name: str, description: str, parameters: Dict, func: Callable):
        """Store the tool metadata and the callable that implements it."""
        self.name, self.description = name, description
        self.parameters = parameters
        self.func = func

    def to_dict(self) -> Dict:
        """Return the tool definition in Ollama's ``{"type": "function", ...}`` format."""
        function_spec = {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
        return {"type": "function", "function": function_spec}

    def execute(self, arguments: Dict) -> Dict[str, Any]:
        """Invoke the wrapped callable with the argument dict and return its result."""
        handler = self.func
        return handler(arguments)
|
|
26
|
+
|
|
27
|
+
class ToolRegistry:
    """Registry of the tools the agent may call.

    Tools are stored by name. The registry:
    - exposes their definitions for the model (:meth:`get_all_tools`);
    - executes tool calls produced by the model (:meth:`handle_tool_calls`);
    - renders the tool section of the system prompt (:meth:`tool_help_text`).
    """

    # Static part of the system-prompt tool section (guidelines and call format).
    # NOTE(review): the guidelines mention tool names (search, read_webpage) that
    # are not checked against the registered tools.
    _HELP_GUIDE = """Guidelines:
1. Always verify information through tools
2. Use search + read_webpage for research
3. Use Python/shell for data processing
4. Ask user when information is missing

Tool Call Format:
<tool_call>
{
    "name": "tool_name",
    "arguments": {
        "param1": "value1"
    }
}
</tool_call>

Example:
<tool_call>
{
    "name": "search",
    "arguments": {
        "query": "Python GIL",
        "max_results": 3
    }
}
</tool_call>"""

    def __init__(self):
        """Create a registry pre-populated with the default tools."""
        self.tools: Dict[str, "Tool"] = {}
        self._register_default_tools()

    def _register_default_tools(self):
        """Register all built-in tools."""
        # Imported locally; presumably to avoid circular imports with the tool
        # modules (TODO confirm).
        from .search import SearchTool
        from .shell import ShellTool
        from .user_interaction import UserInteractionTool
        from .user_confirmation import UserConfirmationTool
        # NOTE(review): jarvis/tools/__init__.py imports ``PythonScript`` from
        # .python_script while this imports ``PythonScriptTool``; confirm which
        # name the module defines.
        from .python_script import PythonScriptTool
        from .file_ops import FileOperationTool
        from .rag import RAGTool
        from .webpage import WebpageTool

        tools = [
            SearchTool(),
            ShellTool(),
            UserInteractionTool(),
            UserConfirmationTool(),
            PythonScriptTool(),
            FileOperationTool(),
            RAGTool(),
            WebpageTool(),
        ]

        for tool in tools:
            self.register_tool(
                name=tool.name,
                description=tool.description,
                parameters=tool.parameters,
                func=tool.execute
            )

    def register_tool(self, name: str, description: str, parameters: Dict, func: Callable):
        """Register (or replace) the tool called ``name``."""
        self.tools[name] = Tool(name, description, parameters, func)

    def get_tool(self, name: str) -> Optional["Tool"]:
        """Return the tool called ``name``, or None if it is not registered."""
        return self.tools.get(name)

    def get_all_tools(self) -> List[Dict]:
        """Return every tool's definition in Ollama tool format."""
        return [tool.to_dict() for tool in self.tools.values()]

    def execute_tool(self, name: str, arguments: Dict) -> Dict[str, Any]:
        """Execute the tool called ``name``.

        Returns:
            The tool's result dict, or ``{"success": False, "error": ...}`` when
            no such tool is registered.
        """
        tool = self.get_tool(name)
        if tool is None:
            return {"success": False, "error": f"Tool {name} does not exist"}
        return tool.execute(arguments)

    def handle_tool_calls(self, tool_calls: List[Dict]) -> str:
        """Execute each tool call and return the combined textual result.

        Every call contributes exactly one entry. A call with malformed JSON
        arguments, or a tool that raises, is reported in place and does not
        discard the other calls' results.

        Args:
            tool_calls: Items shaped like
                ``{"function": {"name": ..., "arguments": dict | str}}``.
                String arguments are JSON-decoded.

        Returns:
            The per-call outputs joined with newlines.
        """
        results = []
        for tool_call in tool_calls:
            name = tool_call["function"]["name"]
            args = tool_call["function"]["arguments"]
            if isinstance(args, str):
                try:
                    args = json.loads(args)
                except json.JSONDecodeError:
                    # Report this call and keep processing the remaining ones.
                    results.append(f"Invalid JSON in arguments for tool {name}")
                    continue

            PrettyOutput.print(f"Calling tool: {name}", OutputType.INFO)
            if isinstance(args, dict):
                for key, value in args.items():
                    PrettyOutput.print(f"  - {key}: {value}", OutputType.INFO)
            else:
                PrettyOutput.print(f"  Arguments: {args}", OutputType.INFO)
            PrettyOutput.print("", OutputType.INFO)

            try:
                result = self.execute_tool(name, args)
            except Exception as e:
                # Turn a crashing tool into a failure report the model can act on.
                result = {"success": False, "error": str(e)}

            if result.get("success"):
                # Tolerate tools that omit stdout/stderr in their result dict.
                output_parts = [f"Result:\n{result.get('stdout', '')}"]
                stderr = result.get("stderr", "")
                if stderr:
                    output_parts.append(f"Errors:\n{stderr}")
                output = "\n\n".join(output_parts)
            else:
                output = f"Execution failed: {result.get('error', 'unknown error')}"

            results.append(output)
        return "\n".join(results)

    def tool_help_text(self) -> str:
        """Return the tool section of the system prompt.

        The tool list is generated from the registered tools, so the names the
        model is told about always match what :meth:`execute_tool` accepts.
        """
        tool_lines = [
            f"{index}. {tool.name}: {tool.description}"
            for index, tool in enumerate(self.tools.values(), start=1)
        ]
        return "Available Tools:\n\n" + "\n".join(tool_lines) + "\n\n" + self._HELP_GUIDE
|