travel-agent-cli 0.2.0 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/cli.js +6 -6
- package/package.json +2 -2
- package/python/agents/__init__.py +19 -0
- package/python/agents/analysis_agent.py +234 -0
- package/python/agents/base.py +377 -0
- package/python/agents/collector_agent.py +304 -0
- package/python/agents/manager_agent.py +251 -0
- package/python/agents/planning_agent.py +161 -0
- package/python/agents/product_agent.py +672 -0
- package/python/agents/report_agent.py +172 -0
- package/python/analyzers/__init__.py +10 -0
- package/python/analyzers/hot_score.py +123 -0
- package/python/analyzers/ranker.py +225 -0
- package/python/analyzers/route_planner.py +86 -0
- package/python/cli/commands.py +254 -0
- package/python/collectors/__init__.py +14 -0
- package/python/collectors/ota/ctrip.py +120 -0
- package/python/collectors/ota/fliggy.py +152 -0
- package/python/collectors/weibo.py +235 -0
- package/python/collectors/wenlv.py +155 -0
- package/python/collectors/xiaohongshu.py +170 -0
- package/python/config/__init__.py +30 -0
- package/python/config/models.py +119 -0
- package/python/config/prompts.py +105 -0
- package/python/config/settings.py +172 -0
- package/python/export/__init__.py +6 -0
- package/python/export/report.py +192 -0
- package/python/main.py +632 -0
- package/python/pyproject.toml +51 -0
- package/python/scheduler/tasks.py +77 -0
- package/python/tools/fliggy_mcp.py +553 -0
- package/python/tools/flyai_tools.py +251 -0
- package/python/tools/mcp_tools.py +412 -0
- package/python/utils/__init__.py +9 -0
- package/python/utils/http.py +73 -0
- package/python/utils/storage.py +288 -0
- package/scripts/postinstall.js +59 -65
package/bin/cli.js
CHANGED
|
@@ -11,8 +11,8 @@ const fs = require('fs');
|
|
|
11
11
|
|
|
12
12
|
// 获取包的安装路径
|
|
13
13
|
const packagePath = path.join(__dirname, '..');
|
|
14
|
-
const
|
|
15
|
-
const mainPy = path.join(
|
|
14
|
+
const pythonDir = path.join(packagePath, 'python');
|
|
15
|
+
const mainPy = path.join(pythonDir, 'main.py');
|
|
16
16
|
|
|
17
17
|
// 检查 Python 是否可用
|
|
18
18
|
function findPython() {
|
|
@@ -34,9 +34,9 @@ function findPython() {
|
|
|
34
34
|
// 检查虚拟环境是否存在
|
|
35
35
|
function getVenvPython() {
|
|
36
36
|
const venvPaths = [
|
|
37
|
-
path.join(
|
|
38
|
-
path.join(
|
|
39
|
-
path.join(
|
|
37
|
+
path.join(pythonDir, 'venv', 'bin', 'python'),
|
|
38
|
+
path.join(pythonDir, '.venv', 'bin', 'python'),
|
|
39
|
+
path.join(pythonDir, 'venv', 'Scripts', 'python.exe'),
|
|
40
40
|
];
|
|
41
41
|
|
|
42
42
|
for (const venvPath of venvPaths) {
|
|
@@ -118,7 +118,7 @@ function run() {
|
|
|
118
118
|
const pyArgs = [mainPy, ...args];
|
|
119
119
|
const child = spawn(pythonCmd, pyArgs, {
|
|
120
120
|
stdio: 'inherit',
|
|
121
|
-
cwd:
|
|
121
|
+
cwd: pythonDir
|
|
122
122
|
});
|
|
123
123
|
|
|
124
124
|
child.on('error', (err) => {
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "travel-agent-cli",
|
|
3
|
-
"version": "0.2.
|
|
3
|
+
"version": "0.2.2",
|
|
4
4
|
"description": "AI 驱动的旅行目的地推荐 Agent - 命令行工具(集成 FlyAI 旅行搜索)",
|
|
5
5
|
"bin": {
|
|
6
6
|
"travel-agent": "bin/cli.js",
|
|
@@ -28,7 +28,7 @@
|
|
|
28
28
|
"license": "MIT",
|
|
29
29
|
"repository": {
|
|
30
30
|
"type": "git",
|
|
31
|
-
"url": "https://github.com/your-username/travel-agent.git",
|
|
31
|
+
"url": "git+https://github.com/your-username/travel-agent.git",
|
|
32
32
|
"directory": "npm-package"
|
|
33
33
|
},
|
|
34
34
|
"engines": {
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Agents 模块 - 多 Agent 系统
|
|
2
|
+
|
|
3
|
+
导出所有 Agent 类
|
|
4
|
+
"""
|
|
5
|
+
from agents.base import BaseAgent
|
|
6
|
+
from agents.collector_agent import CollectionAgent
|
|
7
|
+
from agents.analysis_agent import AnalysisAgent
|
|
8
|
+
from agents.planning_agent import PlanningAgent
|
|
9
|
+
from agents.report_agent import ReportAgent
|
|
10
|
+
from agents.manager_agent import ManagerAgent
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"BaseAgent",
|
|
14
|
+
"CollectionAgent",
|
|
15
|
+
"AnalysisAgent",
|
|
16
|
+
"PlanningAgent",
|
|
17
|
+
"ReportAgent",
|
|
18
|
+
"ManagerAgent",
|
|
19
|
+
]
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""分析 Agent - 负责数据分析和目的地排序"""
|
|
2
|
+
from typing import Dict, Any, List, Optional
|
|
3
|
+
import json
|
|
4
|
+
from agents.base import BaseAgent
|
|
5
|
+
from analyzers.hot_score import HotScoreCalculator
|
|
6
|
+
from config.models import SocialPost, WenlvInfo
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AnalysisAgent(BaseAgent):
    """Analysis agent: scores, ranks and recommends travel destinations.

    Responsibilities:
    - compute a popularity ("hot") score for each destination
    - rank destinations by their combined score
    - produce recommendation reasons (via the LLM, or a local fallback)

    Post / wenlv (culture & tourism) records may be plain dicts or
    attribute-style objects such as ``SocialPost`` / ``WenlvInfo`` instances;
    both forms are accepted everywhere in this class.
    """

    name = "analysis_agent"
    role = "旅行数据分析师"
    goal = "分析采集的数据,识别热门目的地,给出专业评分和推荐理由"

    # Destination keywords searched for in post titles/contents by the local
    # extractor. Matching is plain substring matching, so overlapping names
    # (e.g. "青海" inside "青海湖") are each counted.
    COMMON_DESTINATIONS = (
        "三亚", "云南", "大理", "丽江", "西双版纳",
        "四川", "成都", "九寨沟", "川西",
        "北京", "上海", "广州", "深圳",
        "浙江", "杭州", "乌镇", "苏州", "南京",
        "江苏", "陕西", "西安",
        "广西", "桂林", "阳朔",
        "海南", "西藏", "拉萨", "新疆", "喀纳斯",
        "甘肃", "敦煌", "青海", "青海湖",
        "黑龙江", "哈尔滨", "吉林", "长白山",
    )

    # Number of recommendations the local fallback produces when the context
    # does not carry a valid positive ``top_n``.
    DEFAULT_TOP_N = 10

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        use_tools: bool = False,
    ):
        """Create the agent and its hot-score calculator.

        Args:
            provider: LLM provider name, forwarded to ``BaseAgent``.
            model: model name, forwarded to ``BaseAgent``.
            use_tools: whether to enable tool-use mode, forwarded to ``BaseAgent``.
        """
        super().__init__(provider, model, use_tools)
        self.score_calculator = HotScoreCalculator()

    @staticmethod
    def _field(item: Any, key: str) -> str:
        """Read a text field from a dict or an attribute-style object.

        Missing keys/attributes and ``None`` values map to ``""`` so callers can
        safely use ``in`` checks and string formatting on the result.
        """
        if isinstance(item, dict):
            value = item.get(key)
        else:
            value = getattr(item, key, None)
        return "" if value is None else str(value)

    @classmethod
    def _resolve_top_n(cls, value: Any) -> int:
        """Return ``value`` when it is a positive int, else ``DEFAULT_TOP_N``."""
        if isinstance(value, int) and not isinstance(value, bool) and value > 0:
            return value
        return cls.DEFAULT_TOP_N

    async def execute_local(self, task: str, context: Dict[str, Any]) -> str:
        """Offline fallback: build recommendations without calling an LLM.

        If ``context["analysis"]`` holds an ``analyze_destinations`` result, its
        already-ranked destinations are reused. Otherwise destinations are
        extracted from ``context["posts"]`` and ``context["wenlv_infos"]``.
        ``context["top_n"]`` (default ``DEFAULT_TOP_N``) caps the count.

        Args:
            task: task description (unused by the local path).
            context: input data as described above.

        Returns:
            A JSON string ``{"destinations": [...], "summary": ...}``.
        """
        posts = context.get("posts") or []
        wenlv_infos = context.get("wenlv_infos") or []
        analysis = context.get("analysis") or {}
        # generate_recommendations passes its top_n through the context so the
        # fallback honours the requested count instead of a fixed 10.
        top_n = self._resolve_top_n(context.get("top_n"))

        if analysis:
            ranked = analysis.get("destinations") or []
            recommendations = [
                {
                    "name": dest.get("name", "未知目的地"),
                    "rank": rank,
                    "score": (dest.get("scores") or {}).get("total_score", 5.0),
                    "reason": f"基于 {dest.get('post_count', 0)} 条社交媒体内容和 {dest.get('wenlv_count', 0)} 条文旅信息推荐",
                    "estimated_cost": "¥3000-5000",
                    "suggested_days": 5,
                    "best_time": "春秋季节",
                    "highlights": ["特色美食", "文化体验", "自然风光"],
                }
                for rank, dest in enumerate(ranked[:top_n], 1)
            ]
        else:
            # Simplified local analysis: keyword/region frequency only.
            names = self._extract_destinations(posts, wenlv_infos)
            recommendations = [
                {
                    "name": dest_name,
                    "rank": rank,
                    "score": 5.0,
                    "reason": "本地分析结果",
                    "estimated_cost": "¥3000-5000",
                    "suggested_days": 5,
                    "best_time": "春秋季节",
                    "highlights": ["特色美食", "文化体验"],
                }
                for rank, dest_name in enumerate(names[:top_n], 1)
            ]

        result = {
            "destinations": recommendations,
            "summary": "本地分析推荐结果",
        }
        return json.dumps(result, ensure_ascii=False)

    def _extract_destinations(
        self,
        posts: List[Any],
        wenlv_infos: List[Any],
    ) -> List[str]:
        """Extract candidate destinations, most popular first.

        A post counts once for every keyword in ``COMMON_DESTINATIONS`` found in
        its title or content. A wenlv record counts for its ``region`` unless
        the region is empty or "全国" (nationwide). Popularity is
        ``2 * post_hits + 3 * wenlv_hits``. Ties keep first-seen order because
        ``sorted`` is stable.

        Args:
            posts: social-media posts (dicts or objects with title/content).
            wenlv_infos: wenlv records (dicts or objects with a region).

        Returns:
            Destination names ordered by descending popularity.
        """
        counts: Dict[str, Dict[str, int]] = {}

        for post in posts:
            text = f"{self._field(post, 'title')} {self._field(post, 'content')}"
            for dest in self.COMMON_DESTINATIONS:
                if dest in text:
                    counts.setdefault(dest, {"posts": 0, "wenlv": 0})["posts"] += 1

        for info in wenlv_infos:
            region = self._field(info, "region")
            if region and region != "全国":
                counts.setdefault(region, {"posts": 0, "wenlv": 0})["wenlv"] += 1

        ranked = sorted(
            counts.items(),
            key=lambda item: item[1]["posts"] * 2 + item[1]["wenlv"] * 3,
            reverse=True,
        )
        return [dest for dest, _ in ranked]

    async def analyze_destinations(
        self,
        posts: List[Any],
        wenlv_infos: List[Any],
        flights: Optional[List[Dict]] = None,
        hotels: Optional[List[Dict]] = None,
    ) -> Dict[str, Any]:
        """Score every extracted destination and rank by total score.

        Args:
            posts: social-media posts (dicts or ``SocialPost``-like objects).
            wenlv_infos: wenlv records (dicts or ``WenlvInfo``-like objects).
            flights: flight data; currently unused, kept for API compatibility.
            hotels: hotel data; currently unused, kept for API compatibility.

        Returns:
            ``{"destinations": [...], "total_analyzed": int,
            "top_destinations": [first 5 names]}``. Each destination entry
            carries ``name``, ``scores`` (the calculator's result),
            ``post_count`` and ``wenlv_count``.
        """
        posts = posts or []
        wenlv_infos = wenlv_infos or []
        destinations = self._extract_destinations(posts, wenlv_infos)

        scored_destinations = []
        for dest in destinations:
            related_posts = [
                p for p in posts
                if dest in self._field(p, "title") or dest in self._field(p, "content")
            ]
            related_wenlv = [
                w for w in wenlv_infos
                if dest in self._field(w, "title")
                or dest in self._field(w, "content")
                or dest == self._field(w, "region")
            ]

            # The calculator works on model objects; dict records are converted here.
            score_info = self.score_calculator.calculate_destination_score(
                destination=dest,
                posts=[SocialPost(**p) if isinstance(p, dict) else p for p in related_posts],
                wenlv_infos=[WenlvInfo(**w) if isinstance(w, dict) else w for w in related_wenlv],
            )

            scored_destinations.append({
                "name": dest,
                "scores": score_info,
                "post_count": len(related_posts),
                "wenlv_count": len(related_wenlv),
            })

        scored_destinations.sort(
            key=lambda x: x["scores"]["total_score"],
            reverse=True,
        )

        return {
            "destinations": scored_destinations,
            "total_analyzed": len(scored_destinations),
            "top_destinations": [d["name"] for d in scored_destinations[:5]],
        }

    async def generate_recommendations(
        self,
        analysis_result: Dict[str, Any],
        top_n: int = 10,
    ) -> str:
        """Ask the LLM for TOP-N recommendations built from ``analysis_result``.

        If the LLM is unavailable, ``execute_local`` produces the result; it
        receives ``top_n`` through the context so the count is honoured.

        Args:
            analysis_result: output of ``analyze_destinations``.
            top_n: number of destinations to recommend.

        Returns:
            Recommendations as a JSON string (LLM output or local fallback).
        """
        task = f"""基于以下分析结果,生成 TOP{top_n}旅行目的地推荐。

对于每个目的地,请提供:
1. 排名
2. 目的地名称
3. 综合得分(基于提供的分数)
4. 推荐理由(100-200 字)
5. 预估费用范围
6. 建议游玩天数
7. 最佳旅行时间
8. 亮点特色(2-3 个)

分析数据:
{json.dumps(analysis_result, ensure_ascii=False, indent=2)}

请输出 JSON 格式:
{{
  "destinations": [
    {{
      "name": "目的地",
      "rank": 1,
      "score": 8.5,
      "reason": "推荐理由...",
      "estimated_cost": "¥3000-5000",
      "suggested_days": 5,
      "best_time": "3-5 月",
      "highlights": ["亮点 1", "亮点 2"]
    }}
  ],
  "summary": "整体总结"
}}
"""

        return await self.execute(task, {"analysis": analysis_result, "top_n": top_n})
|
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
"""Agent 基类模块 - 支持多 LLM 提供商
|
|
2
|
+
|
|
3
|
+
提供统一的 Agent 接口,支持 Anthropic/OpenAI/DeepSeek/Azure/Ollama 等厂商
|
|
4
|
+
"""
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Dict, Any, Optional, List, Callable
|
|
7
|
+
from config.settings import get_settings, LLM_PROVIDERS
|
|
8
|
+
from tools.mcp_tools import Tool, ToolRegistry, ToolHandler, build_tool
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseAgent(ABC):
    """Base class for all agents, with multi-provider LLM support.

    Subclasses set ``name`` / ``role`` / ``goal`` and must implement the
    coroutine :meth:`execute_local`. That method is the offline fallback used
    when no LLM client could be created or an API call fails.

    Two execution modes:
    1. plain text chat (default)
    2. tool use: Anthropic only; needs ``available_tools`` plus matching
       ``tool_handlers``

    Providers handled by ``_init_client``: anthropic, openai, deepseek,
    azure, ollama and qwen. All non-Anthropic providers go through the
    OpenAI SDK.
    """

    name: str = "base_agent"
    role: str = "助手"
    goal: str = "帮助用户完成任务"

    # Subclasses may declare tools (MCP format, built with build_tool()) and
    # their handlers, both keyed by tool name. These class-level dicts are only
    # read here, never mutated, so sharing the default {} is safe.
    available_tools: Dict[str, Tool] = {}
    tool_handlers: Dict[str, ToolHandler] = {}

    # Maximum number of tool-call round trips per _execute_with_tools call.
    # This bounds the loop if the model keeps requesting tools.
    max_tool_rounds: int = 5

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        use_tools: bool = False,
    ):
        """Initialise settings, the LLM client and the tool registry.

        Args:
            provider: LLM provider (anthropic/openai/deepseek/azure/ollama/qwen).
                Defaults to ``settings.llm_provider``, then "anthropic".
            model: model name. Defaults to ``settings.llm_model``, then
                ``settings.get_active_model()``.
            use_tools: whether tool-use mode is enabled by default.
        """
        self.settings = get_settings()

        # NOTE(review): settings.llm_model is used even when ``provider``
        # overrides settings.llm_provider, so the resolved model may belong to
        # a different provider. Confirm against config.settings.
        self.provider = provider or self.settings.llm_provider or "anthropic"
        self.model = model or self.settings.llm_model or self.settings.get_active_model()
        self.use_tools = use_tools

        self.tool_registry = ToolRegistry()

        # None means "no usable client"; execute() then uses execute_local().
        self.client = self._init_client()

        self._register_tools()

    def _init_client(self):
        """Create the SDK client for ``self.provider``.

        Returns:
            The client object, or None when the provider is not configured,
            is unknown, its SDK cannot be imported, or construction fails.
        """
        if not self.settings.is_provider_configured(self.provider):
            print(f"[{self.name}] 警告:{self.provider} 未配置,将使用降级模式")
            return None

        try:
            if self.provider == "anthropic":
                from anthropic import Anthropic
                return Anthropic(api_key=self.settings.anthropic_api_key)

            elif self.provider == "openai":
                from openai import OpenAI
                return OpenAI(api_key=self.settings.openai_api_key)

            elif self.provider == "deepseek":
                from openai import OpenAI  # DeepSeek exposes an OpenAI-compatible API
                return OpenAI(
                    api_key=self.settings.deepseek_api_key,
                    base_url="https://api.deepseek.com"
                )

            elif self.provider == "azure":
                from openai import AzureOpenAI
                return AzureOpenAI(
                    api_key=self.settings.azure_openai_api_key,
                    azure_endpoint=self.settings.azure_openai_endpoint,
                    api_version=self.settings.azure_openai_api_version,
                )

            elif self.provider == "ollama":
                from openai import OpenAI  # Ollama exposes an OpenAI-compatible API
                return OpenAI(
                    base_url=self.settings.ollama_base_url,
                    api_key="ollama",  # placeholder; Ollama does not check keys
                )

            elif self.provider == "qwen":
                from openai import OpenAI  # DashScope compatible-mode endpoint
                return OpenAI(
                    api_key=self.settings.dashscope_api_key,
                    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
                )

            else:
                print(f"[{self.name}] 未知提供商:{self.provider}")
                return None

        except ImportError as e:
            print(f"[{self.name}] 导入客户端库失败:{e}")
            print("请安装对应的依赖:pip install anthropic openai")
            return None

        except Exception as e:
            print(f"[{self.name}] 初始化客户端失败:{e}")
            return None

    def _register_tools(self):
        """Register each tool in ``available_tools`` that has a handler.

        Tools without a matching entry in ``tool_handlers`` are skipped.
        """
        for tool_name, tool_def in self.available_tools.items():
            handler = self.tool_handlers.get(tool_name)
            if handler:
                self.tool_registry.register(tool_def, handler)

    def _get_tool_definitions(self) -> List[Tool]:
        """Return the registered tool definitions (MCP / Claude API format)."""
        return self.tool_registry.get_definitions()

    @staticmethod
    def _to_plain(value: Any) -> Any:
        """Convert a pydantic-style model to a dict; return other values unchanged."""
        if hasattr(value, "model_dump"):
            return value.model_dump()
        if hasattr(value, "dict"):
            return value.dict()
        return value

    def _build_prompt(self, task: str, context: Dict[str, Any]) -> str:
        """Build the user prompt from the agent's role/goal, the task and context.

        Model objects in ``context`` are converted to dicts. Lists whose first
        element is a model are truncated to 10 items to keep the prompt small.
        Everything else is JSON-encoded as-is, with ``str`` as the fallback
        for non-JSON types.

        Args:
            task: task description.
            context: extra data to embed in the prompt.

        Returns:
            The complete prompt string.
        """
        context_str = ""
        if context:
            import json

            clean_context = {}
            for key, value in context.items():
                if hasattr(value, "dict") or hasattr(value, "model_dump"):
                    clean_context[key] = self._to_plain(value)
                elif (
                    isinstance(value, list)
                    and len(value) > 0
                    and (hasattr(value[0], "dict") or hasattr(value[0], "model_dump"))
                ):
                    # _to_plain leaves non-model items untouched, so mixed lists no longer crash.
                    clean_context[key] = [self._to_plain(item) for item in value[:10]]
                else:
                    clean_context[key] = value

            context_str = json.dumps(clean_context, ensure_ascii=False, indent=2, default=str)

        return f"""你是一个{self.role}。
你的目标:{self.goal}

当前任务:
{task}

上下文信息:
{context_str if context_str else "无"}

请完成你的任务,直接输出结果:"""

    async def _execute_tool(self, tool_name: str, tool_input: Dict[str, Any]) -> str:
        """Run one tool through ``tool_registry`` and stringify its result.

        Args:
            tool_name: tool name requested by the model.
            tool_input: keyword arguments for the tool.

        Returns:
            The stringified result, or an error message. Errors are returned
            to the model rather than raised.
        """
        try:
            result = await self.tool_registry.call(tool_name, **tool_input)
            return str(result)
        except ValueError as e:
            # ToolRegistry.call presumably raises ValueError for unknown tools.
            # A ValueError raised by a handler of a tool we did register must
            # not be misreported as "unknown tool".
            if tool_name in self.available_tools and tool_name in self.tool_handlers:
                return f"工具执行失败:{e}"
            return f"错误:未知工具 {tool_name}"
        except Exception as e:
            return f"工具执行失败:{e}"

    async def execute(
        self,
        task: str,
        context: Optional[Dict[str, Any]] = None,
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
        use_tool_mode: Optional[bool] = None,
    ) -> str:
        """Run a task with the configured LLM, falling back to local execution.

        Args:
            task: task description.
            context: extra data for the prompt (and for execute_local).
            system_prompt: optional system prompt.
            max_tokens: maximum number of output tokens.
            use_tool_mode: force tool-use mode on/off; None uses ``self.use_tools``.

        Returns:
            The model's text output, or ``execute_local``'s result when no
            client exists or the API call raised.
        """
        context = context or {}
        use_tool = use_tool_mode if use_tool_mode is not None else self.use_tools

        if not self.client:
            return await self.execute_local(task, context)

        if use_tool and self.available_tools:
            return await self._execute_with_tools(
                task, context, system_prompt, max_tokens
            )

        prompt = self._build_prompt(task, context)

        # NOTE: the SDK clients are synchronous, so these calls block the
        # event loop while waiting for the API.
        try:
            if self.provider == "anthropic":
                return await self._call_anthropic(prompt, system_prompt, max_tokens)
            else:
                return await self._call_openai_compatible(prompt, system_prompt, max_tokens)

        except Exception as e:
            print(f"[{self.name}] API 调用失败:{e},降级为本地执行")
            return await self.execute_local(task, context)

    async def _call_anthropic(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int
    ) -> str:
        """Call the Anthropic Messages API and return all text blocks concatenated.

        The Messages API takes the system prompt as the top-level ``system``
        parameter. A ``"system"`` role inside ``messages`` is rejected.
        """
        request: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens,
            "messages": [{"role": "user", "content": prompt}],
        }
        if system_prompt:
            request["system"] = system_prompt

        response = self.client.messages.create(**request)
        # The first content block is not guaranteed to be text; collect every text block.
        return "".join(
            block.text for block in response.content
            if getattr(block, "type", None) == "text"
        )

    async def _call_openai_compatible(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int
    ) -> str:
        """Call an OpenAI-compatible chat API (OpenAI/DeepSeek/Azure/Ollama/Qwen)."""
        messages = [{"role": "user", "content": prompt}]
        if system_prompt:
            messages.insert(0, {"role": "system", "content": system_prompt})

        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_tokens,
        )
        # message.content may be None (e.g. refusals / tool calls); keep the str contract.
        return response.choices[0].message.content or ""

    async def _execute_with_tools(
        self,
        task: str,
        context: Dict[str, Any],
        system_prompt: Optional[str],
        max_tokens: int = 2048
    ) -> str:
        """Run the task in tool-use mode (Anthropic only).

        The model is called repeatedly. Requested tools are executed and their
        results are sent back, for at most ``max_tool_rounds`` rounds. Text from
        every response is joined with newlines. Non-Anthropic providers, and
        any failure, fall back to plain text mode via ``execute``.
        """
        if self.provider != "anthropic":
            print(f"[{self.name}] Tool Use 仅支持 Anthropic,降级为普通模式")
            return await self.execute(task, context, system_prompt, max_tokens, use_tool_mode=False)

        prompt = self._build_prompt(task, context)
        messages: List[Dict[str, Any]] = [{"role": "user", "content": prompt}]

        # Shared request settings. The system prompt is sent on every round,
        # not only the first.
        request: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens,
            "system": system_prompt or f"你是一个{self.role}。{self.goal}",
        }
        tool_definitions = self._get_tool_definitions()
        if tool_definitions:
            request["tools"] = tool_definitions

        try:
            final_result: List[str] = []

            for round_index in range(self.max_tool_rounds + 1):
                response = self.client.messages.create(messages=messages, **request)
                # On the last allowed round, further tool requests are ignored.
                allow_tools = round_index < self.max_tool_rounds

                tool_results = []
                for content in response.content:
                    if content.type == "text":
                        final_result.append(content.text)
                    elif content.type == "tool_use" and allow_tools:
                        tool_result = await self._execute_tool(content.name, content.input)
                        tool_results.append({
                            "type": "tool_result",
                            "tool_use_id": content.id,
                            "content": tool_result,
                        })

                if not tool_results:
                    break

                # Echo the assistant turn as a role/content message (not the raw
                # Message object), followed by the matching tool_result blocks.
                messages.append({"role": "assistant", "content": response.content})
                messages.append({"role": "user", "content": tool_results})

            return "\n".join(final_result)

        except Exception as e:
            print(f"[{self.name}] Tool Use 执行失败:{e},降级为普通模式")
            return await self.execute(task, context, system_prompt, max_tokens, use_tool_mode=False)

    @abstractmethod
    async def execute_local(self, task: str, context: Dict[str, Any]) -> str:
        """Offline fallback used when the LLM is unavailable or fails.

        Declared as a coroutine because ``execute`` always awaits it.

        Args:
            task: task description.
            context: context data passed to ``execute``.

        Returns:
            The execution result.
        """
        pass

    def switch_provider(self, provider: str, model: Optional[str] = None):
        """Switch to another LLM provider and rebuild the client.

        Args:
            provider: the new provider name.
            model: optional new model name. When omitted, the current model is
                kept. NOTE(review): that model may not exist on the new provider.
        """
        self.provider = provider
        if model:
            self.model = model
        self.client = self._init_client()
        print(f"[{self.name}] 已切换到 {provider}/{self.model}")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(provider={self.provider}, model={self.model})"
|