travel-agent-cli 0.1.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +76 -21
- package/bin/cli.js +14 -11
- package/package.json +8 -4
- package/python/agents/__init__.py +19 -0
- package/python/agents/analysis_agent.py +234 -0
- package/python/agents/base.py +377 -0
- package/python/agents/collector_agent.py +304 -0
- package/python/agents/manager_agent.py +251 -0
- package/python/agents/planning_agent.py +161 -0
- package/python/agents/product_agent.py +672 -0
- package/python/agents/report_agent.py +172 -0
- package/python/analyzers/__init__.py +10 -0
- package/python/analyzers/hot_score.py +123 -0
- package/python/analyzers/ranker.py +225 -0
- package/python/analyzers/route_planner.py +86 -0
- package/python/cli/commands.py +254 -0
- package/python/collectors/__init__.py +14 -0
- package/python/collectors/ota/ctrip.py +120 -0
- package/python/collectors/ota/fliggy.py +152 -0
- package/python/collectors/weibo.py +235 -0
- package/python/collectors/wenlv.py +155 -0
- package/python/collectors/xiaohongshu.py +170 -0
- package/python/config/__init__.py +30 -0
- package/python/config/models.py +119 -0
- package/python/config/prompts.py +105 -0
- package/python/config/settings.py +172 -0
- package/python/export/__init__.py +6 -0
- package/python/export/report.py +192 -0
- package/python/main.py +632 -0
- package/python/pyproject.toml +51 -0
- package/python/scheduler/tasks.py +77 -0
- package/python/tools/fliggy_mcp.py +553 -0
- package/python/tools/flyai_tools.py +251 -0
- package/python/tools/mcp_tools.py +412 -0
- package/python/utils/__init__.py +9 -0
- package/python/utils/http.py +73 -0
- package/python/utils/storage.py +288 -0
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
"""Agent 基类模块 - 支持多 LLM 提供商
|
|
2
|
+
|
|
3
|
+
提供统一的 Agent 接口,支持 Anthropic/OpenAI/DeepSeek/Azure/Ollama 等厂商
|
|
4
|
+
"""
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Dict, Any, Optional, List, Callable
|
|
7
|
+
from config.settings import get_settings, LLM_PROVIDERS
|
|
8
|
+
from tools.mcp_tools import Tool, ToolRegistry, ToolHandler, build_tool
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseAgent(ABC):
    """Base class for all agents, with pluggable LLM providers.

    Every concrete agent inherits from this class and implements
    ``execute_local``.

    Two execution modes are supported:
        1. Plain-text chat (default).
        2. Tool Use mode (requires ``available_tools``; Anthropic only).

    Providers handled by ``_init_client``:
        - anthropic (claude-sonnet-4-6, claude-opus-4-6, ...)
        - openai (gpt-4o, gpt-4-turbo, ...)
        - deepseek (deepseek-chat, deepseek-coder; OpenAI-compatible endpoint)
        - azure (Azure OpenAI)
        - ollama (local deployment; OpenAI-compatible endpoint)
        - qwen (DashScope OpenAI-compatible endpoint)
    """

    name: str = "base_agent"
    role: str = "助手"
    goal: str = "帮助用户完成任务"

    # Subclasses declare their tools here (MCP format, built with
    # build_tool()) together with the matching handler callables.
    # The annotations are quoted (lazy) forward references to the
    # tools.mcp_tools types.
    available_tools: "Dict[str, Tool]" = {}
    tool_handlers: "Dict[str, ToolHandler]" = {}

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        use_tools: bool = False,
    ):
        """Initialise the agent.

        Args:
            provider: LLM provider (anthropic/openai/deepseek/azure/ollama/qwen).
                Falls back to the configured provider, then to "anthropic".
            model: Model name; when empty the configured default is used.
            use_tools: Whether Tool Use mode is enabled by default.
        """
        self.settings = get_settings()

        # Resolve provider and model: explicit argument > settings > default.
        self.provider = provider or self.settings.llm_provider or "anthropic"
        self.model = model or self.settings.llm_model or self.settings.get_active_model()
        self.use_tools = use_tools

        # Tool registry for this instance.
        self.tool_registry = ToolRegistry()

        # The client is None when the provider is unavailable; execute()
        # then falls back to execute_local().
        self.client = self._init_client()

        self._register_tools()

    def _init_client(self):
        """Create the SDK client for ``self.provider``.

        Returns:
            The client object, or None when the provider is not configured,
            is unknown, or the client could not be created. A message is
            printed in each failure case.
        """
        if not self.settings.is_provider_configured(self.provider):
            print(f"[{self.name}] 警告:{self.provider} 未配置,将使用降级模式")
            return None

        try:
            if self.provider == "anthropic":
                from anthropic import Anthropic
                return Anthropic(api_key=self.settings.anthropic_api_key)

            elif self.provider == "openai":
                from openai import OpenAI
                return OpenAI(api_key=self.settings.openai_api_key)

            elif self.provider == "deepseek":
                from openai import OpenAI  # DeepSeek exposes an OpenAI-compatible API
                return OpenAI(
                    api_key=self.settings.deepseek_api_key,
                    base_url="https://api.deepseek.com",
                )

            elif self.provider == "azure":
                from openai import AzureOpenAI
                return AzureOpenAI(
                    api_key=self.settings.azure_openai_api_key,
                    azure_endpoint=self.settings.azure_openai_endpoint,
                    api_version=self.settings.azure_openai_api_version,
                )

            elif self.provider == "ollama":
                from openai import OpenAI  # Ollama exposes an OpenAI-compatible API
                return OpenAI(
                    base_url=self.settings.ollama_base_url,
                    api_key="ollama",  # placeholder; Ollama needs no real key
                )

            elif self.provider == "qwen":
                from openai import OpenAI  # Qwen exposes an OpenAI-compatible API
                return OpenAI(
                    api_key=self.settings.dashscope_api_key,
                    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
                )

            else:
                print(f"[{self.name}] 未知提供商:{self.provider}")
                return None

        except ImportError as e:
            print(f"[{self.name}] 导入客户端库失败:{e}")
            print("请安装对应的依赖:pip install anthropic openai")
            return None

        except Exception as e:
            # Best-effort: any construction failure degrades to local mode.
            print(f"[{self.name}] 初始化客户端失败:{e}")
            return None

    def _register_tools(self):
        """Register the declared tools with ``self.tool_registry``.

        Each entry of ``available_tools`` is registered together with the
        handler of the same name from ``tool_handlers``. Tools that have no
        handler are skipped.
        """
        for tool_name, tool_def in self.available_tools.items():
            handler = self.tool_handlers.get(tool_name)
            if handler:
                self.tool_registry.register(tool_def, handler)

    def _get_tool_definitions(self) -> "List[Tool]":
        """Return the registered tool definitions (MCP / Claude API format)."""
        return self.tool_registry.get_definitions()

    @staticmethod
    def _dump_model(item: Any) -> Any:
        """Convert a model-like object to plain data.

        Uses ``model_dump()`` when present, otherwise ``dict()``; objects
        exposing neither are returned unchanged.
        """
        if hasattr(item, "model_dump"):
            return item.model_dump()
        if hasattr(item, "dict"):
            return item.dict()
        return item

    def _build_prompt(self, task: str, context: Dict[str, Any]) -> str:
        """Build the user prompt.

        Args:
            task: Task description.
            context: Context values. Model-like values are converted to
                plain data; lists of model-like values are converted too
                and truncated to their first 10 items.

        Returns:
            The complete prompt string.
        """
        context_str = ""
        if context:
            import json

            clean_context = {}
            for key, value in context.items():
                if hasattr(value, "dict") or hasattr(value, "model_dump"):
                    clean_context[key] = self._dump_model(value)
                elif (
                    isinstance(value, list)
                    and value
                    # Accept items exposing either serialisation method,
                    # consistent with the single-value branch above.
                    and (hasattr(value[0], "model_dump") or hasattr(value[0], "dict"))
                ):
                    clean_context[key] = [self._dump_model(item) for item in value[:10]]
                else:
                    clean_context[key] = value

            # default=str stringifies anything json cannot encode natively.
            context_str = json.dumps(clean_context, ensure_ascii=False, indent=2, default=str)

        return (
            f"你是一个{self.role}。\n"
            f"你的目标:{self.goal}\n"
            "\n"
            "当前任务:\n"
            f"{task}\n"
            "\n"
            "上下文信息:\n"
            f"{context_str or '无'}\n"
            "\n"
            "请完成你的任务,直接输出结果:"
        )

    async def _execute_tool(self, tool_name: str, tool_input: Dict[str, Any]) -> str:
        """Run one tool through ``self.tool_registry``.

        Args:
            tool_name: Name of the tool.
            tool_input: Keyword arguments for the tool.

        Returns:
            The tool result as a string, or an error message. Errors are
            returned rather than raised so they can be passed back to the
            model as the tool result.
        """
        try:
            result = await self.tool_registry.call(tool_name, **tool_input)
            return str(result)
        except ValueError:
            # NOTE(review): any ValueError is reported as "unknown tool",
            # including one raised inside a handler — confirm this matches
            # how ToolRegistry.call signals a missing tool.
            return f"错误:未知工具 {tool_name}"
        except Exception as e:
            return f"工具执行失败:{e}"

    async def _run_local(self, task: str, context: Dict[str, Any]) -> str:
        """Run ``execute_local`` whether it is a coroutine or a plain method."""
        import inspect

        outcome = self.execute_local(task, context)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome

    async def execute(
        self,
        task: str,
        context: Optional[Dict[str, Any]] = None,
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
        use_tool_mode: Optional[bool] = None,
    ) -> str:
        """Execute a task.

        Args:
            task: Task description.
            context: Context values included in the prompt.
            system_prompt: Optional system prompt.
            max_tokens: Maximum number of output tokens.
            use_tool_mode: Whether to use Tool Use mode; None means use the
                instance setting.

        Returns:
            The agent's result. Without a client, or when the API call
            fails, the result of ``execute_local`` is returned instead.
        """
        context = context or {}
        use_tool = use_tool_mode if use_tool_mode is not None else self.use_tools

        # No client: run locally.
        if not self.client:
            return await self._run_local(task, context)

        # Tool Use mode.
        if use_tool and self.available_tools:
            return await self._execute_with_tools(
                task, context, system_prompt, max_tokens
            )

        # Plain-text mode.
        prompt = self._build_prompt(task, context)

        try:
            if self.provider == "anthropic":
                return await self._call_anthropic(prompt, system_prompt, max_tokens)
            return await self._call_openai_compatible(prompt, system_prompt, max_tokens)
        except Exception as e:
            print(f"[{self.name}] API 调用失败:{e},降级为本地执行")
            return await self._run_local(task, context)

    async def _call_anthropic(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int
    ) -> str:
        """Call the Anthropic Messages API and return the first block's text.

        NOTE: the client built in ``_init_client`` is the synchronous one,
        so this call blocks the event loop while the request is in flight.
        """
        request: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": max_tokens,
            "messages": [{"role": "user", "content": prompt}],
        }
        # The Messages API takes the system prompt as a top-level argument;
        # "system" is not a valid role inside `messages`.
        if system_prompt:
            request["system"] = system_prompt

        response = self.client.messages.create(**request)
        return response.content[0].text

    async def _call_openai_compatible(
        self,
        prompt: str,
        system_prompt: Optional[str],
        max_tokens: int
    ) -> str:
        """Call an OpenAI-compatible chat API (OpenAI/DeepSeek/Azure/Ollama/Qwen).

        NOTE: the client built in ``_init_client`` is the synchronous one,
        so this call blocks the event loop while the request is in flight.
        """
        messages = [{"role": "user", "content": prompt}]
        if system_prompt:
            messages.insert(0, {"role": "system", "content": system_prompt})

        response = self.client.chat.completions.create(
            model=self.model,
            messages=messages,
            max_tokens=max_tokens,
        )
        # message.content may be None; keep the declared str return type.
        return response.choices[0].message.content or ""

    async def _execute_with_tools(
        self,
        task: str,
        context: Dict[str, Any],
        system_prompt: Optional[str],
        max_tokens: int = 2048
    ) -> str:
        """Execute a task in Tool Use mode (Anthropic only).

        Sends the prompt with the tool definitions, runs every tool the
        model requests, sends the results back once, and returns the text
        of both responses joined by newlines. Only one tool round is
        performed: tool requests in the second response are ignored.

        Falls back to plain mode for other providers or on any error.
        """
        if self.provider != "anthropic":
            print(f"[{self.name}] Tool Use 仅支持 Anthropic,降级为普通模式")
            return await self.execute(task, context, system_prompt, max_tokens, use_tool_mode=False)

        prompt = self._build_prompt(task, context)
        messages: List[Dict[str, Any]] = [{"role": "user", "content": prompt}]

        try:
            # `request` keeps a reference to `messages`, so the follow-up
            # call below sees the appended turns and the same system prompt.
            request: Dict[str, Any] = {
                "model": self.model,
                "max_tokens": max_tokens,
                "system": system_prompt or f"你是一个{self.role}。{self.goal}",
                "messages": messages,
            }
            # Omit `tools` entirely rather than passing None.
            if self.available_tools:
                request["tools"] = self._get_tool_definitions()

            response = self.client.messages.create(**request)

            final_result = []
            tool_results = []

            for block in response.content:
                if block.type == "text":
                    final_result.append(block.text)
                elif block.type == "tool_use":
                    tool_output = await self._execute_tool(block.name, block.input)
                    tool_results.append({
                        "type": "tool_result",
                        "tool_use_id": block.id,
                        "content": tool_output,
                    })

            if tool_results:
                # Replay the model's turn as an assistant message, then
                # answer it with the tool results as the next user turn.
                messages.append({"role": "assistant", "content": response.content})
                messages.append({"role": "user", "content": tool_results})

                follow_up = self.client.messages.create(**request)

                for block in follow_up.content:
                    if block.type == "text":
                        final_result.append(block.text)

            return "\n".join(final_result)

        except Exception as e:
            print(f"[{self.name}] Tool Use 执行失败:{e},降级为普通模式")
            return await self.execute(task, context, system_prompt, max_tokens, use_tool_mode=False)

    @abstractmethod
    async def execute_local(self, task: str, context: Dict[str, Any]) -> str:
        """Local fallback used when the API is unavailable.

        ``execute`` awaits this method, so implementations should be
        coroutines; plain (non-async) overrides are still accepted.

        Args:
            task: Task description.
            context: Context values.

        Returns:
            The result of the local execution.
        """
        pass

    def switch_provider(self, provider: str, model: Optional[str] = None):
        """Switch to another LLM provider and rebuild the client.

        Args:
            provider: The new provider.
            model: Optional model name. When omitted the current model is
                kept, so callers should pass a model valid for the new
                provider.
        """
        self.provider = provider
        if model:
            self.model = model
        self.client = self._init_client()
        print(f"[{self.name}] 已切换到 {provider}/{self.model}")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(provider={self.provider}, model={self.model})"
|
|
@@ -0,0 +1,304 @@
|
|
|
1
|
+
"""采集 Agent - 负责数据采集任务
|
|
2
|
+
|
|
3
|
+
支持 Tool Use 模式,Claude 可以自主决定调用哪个采集工具
|
|
4
|
+
"""
|
|
5
|
+
from typing import Dict, Any, List
|
|
6
|
+
from agents.base import BaseAgent
|
|
7
|
+
from tools.mcp_tools import build_tool, string_property, integer_property
|
|
8
|
+
from collectors.xiaohongshu import XiaohongshuCollector
|
|
9
|
+
from collectors.weibo import WeiboCollector
|
|
10
|
+
from collectors.wenlv import WenlvCollector
|
|
11
|
+
from collectors.ota.fliggy import FliggyCollector
|
|
12
|
+
from collectors.ota.ctrip import CtripCollector
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class CollectionAgent(BaseAgent):
    """Data-collection agent.

    Responsibilities:
        - Collect trending social-media posts (Xiaohongshu, Weibo).
        - Collect official culture-and-tourism bureau ("wenlv") information.
        - Collect OTA flight and hotel prices.

    Tools exposed to the model (MCP format):
        - search_social_media: search social media.
        - collect_wenlv_info: collect culture-and-tourism information.
        - search_flights: search flights.
        - search_hotels: search hotels.
    """

    name = "collection_agent"
    role = "旅行数据采集专家"
    goal = "从多个来源采集高质量的旅行相关数据,包括社交媒体热点、官方政策和机酒价格"

    # Tool definitions (MCP format).
    available_tools = {
        "search_social_media": build_tool(
            name="search_social_media",
            description="搜索社交媒体平台上的旅行相关内容",
            properties={
                "keyword": string_property(
                    "搜索关键词,例如 '三亚旅行'、'海岛游'"
                ),
                "platform": string_property(
                    "目标平台:xiaohongshu(小红书), weibo(微博), all(全部)",
                    enum=["xiaohongshu", "weibo", "all"],
                    default="all"
                ),
                "limit": integer_property(
                    "返回结果数量限制",
                    minimum=1,
                    maximum=100,
                    default=20
                )
            },
            required=["keyword"]
        ),
        "collect_wenlv_info": build_tool(
            name="collect_wenlv_info",
            description="采集各地文旅局官网的政策、活动和推荐路线信息",
            properties={
                "region": string_property(
                    "可选的地区名称,例如 '云南'、'浙江',不传则采集全部"
                )
            },
            required=[]
        ),
        "search_flights": build_tool(
            name="search_flights",
            description="搜索航班价格信息",
            properties={
                "departure_city": string_property(
                    "出发城市,例如 '北京'、'上海'"
                ),
                "arrival_city": string_property(
                    "到达城市,例如 '三亚'、'成都'"
                ),
                "date": string_property(
                    "出发日期,格式 YYYY-MM-DD,不传则搜索近期"
                )
            },
            required=["departure_city", "arrival_city"]
        ),
        "search_hotels": build_tool(
            name="search_hotels",
            description="搜索酒店价格信息",
            properties={
                "destination": string_property(
                    "目的地城市,例如 '三亚'、'成都'"
                ),
                "price_min": integer_property(
                    "最低价格(元/晚)"
                ),
                "price_max": integer_property(
                    "最高价格(元/晚)"
                )
            },
            required=["destination"]
        )
    }

    # Source name -> collector class, consulted by execute_local().
    # NOTE(review): the "fliggy" and "ctrip" keys are assumed source names;
    # confirm against the callers that populate context["sources"].
    collectors = {
        "xiaohongshu": XiaohongshuCollector,
        "weibo": WeiboCollector,
        "wenlv": WenlvCollector,
        "fliggy": FliggyCollector,
        "ctrip": CtripCollector,
    }

    # Place names recognised by _extract_destinations_from_posts().
    _COMMON_DESTINATIONS = (
        "三亚", "云南", "大理", "丽江", "四川", "成都",
        "北京", "上海", "浙江", "杭州", "江苏", "苏州",
        "陕西", "西安", "广西", "桂林", "海南", "西藏",
        "新疆", "甘肃", "青海", "黑龙江", "哈尔滨",
    )

    def __init__(self, provider: str = None, model: str = None, use_tools: bool = True):
        """Create the agent; Tool Use mode is enabled by default."""
        # The handlers must exist before the base initialiser runs, because
        # it calls _register_tools(), which reads self.tool_handlers.
        self.tool_handlers = {
            "search_social_media": self._handle_search_social_media,
            "collect_wenlv_info": self._handle_collect_wenlv,
            "search_flights": self._handle_search_flights,
            "search_hotels": self._handle_search_hotels,
        }

        super().__init__(provider, model, use_tools)

    async def execute_local(self, task: str, context: Dict[str, Any]) -> str:
        """Local fallback: report the requested sources as collected.

        No collector is actually run; each known source listed in
        ``context["sources"]`` (default: xiaohongshu, weibo, wenlv) yields
        one placeholder line. Unknown sources are skipped.
        """
        sources = context.get("sources", ["xiaohongshu", "weibo", "wenlv"])
        lines = [
            f"✓ {source}: 采集完成(模拟数据)"
            for source in sources
            if source in self.collectors
        ]
        return "采集结果:\n" + "\n".join(lines)

    # ========== Shared helpers ==========

    @staticmethod
    def _to_json(payload: Dict[str, Any]) -> str:
        """Serialise a tool result; non-JSON values are stringified."""
        import json

        return json.dumps(payload, ensure_ascii=False, default=str)

    async def _search_platforms(
        self,
        keyword: str,
        platforms: List[str],
        limit: int = None
    ) -> Dict[str, Any]:
        """Search the requested social platforms and merge the posts.

        When ``limit`` is None each collector's own default size is used.
        """
        result = {"posts": [], "stats": {}}

        if "xiaohongshu" in platforms:
            # The two collectors name their size argument differently.
            size_args = {} if limit is None else {"page_size": limit}
            posts = await XiaohongshuCollector().search(keyword=keyword, **size_args)
            result["posts"].extend([p.model_dump() for p in posts])
            result["stats"]["xiaohongshu_count"] = len(posts)

        if "weibo" in platforms:
            size_args = {} if limit is None else {"count": limit}
            posts = await WeiboCollector().search(keyword=keyword, **size_args)
            result["posts"].extend([p.model_dump() for p in posts])
            result["stats"]["weibo_count"] = len(posts)

        return result

    # ========== Tool handlers ==========

    async def _handle_search_social_media(
        self,
        keyword: str,
        platform: str = "all",
        limit: int = 20
    ) -> str:
        """Handle the search_social_media tool; returns a JSON string."""
        platforms = ["xiaohongshu", "weibo"] if platform == "all" else [platform]
        return self._to_json(await self._search_platforms(keyword, platforms, limit))

    async def _handle_collect_wenlv(self, region: str = None) -> str:
        """Handle the collect_wenlv_info tool; returns a JSON string."""
        return self._to_json(await self.collect_wenlv(region))

    async def _handle_search_flights(
        self,
        departure_city: str,
        arrival_city: str,
        date: str = None
    ) -> str:
        """Handle the search_flights tool; returns a JSON string."""
        collector = FliggyCollector()
        flights = await collector.search_flights(departure_city, arrival_city, date)

        return self._to_json({
            "flights": [f.model_dump() for f in flights],
            "stats": {"count": len(flights)}
        })

    async def _handle_search_hotels(
        self,
        destination: str,
        price_min: int = None,
        price_max: int = None
    ) -> str:
        """Handle the search_hotels tool; returns a JSON string."""
        collector = FliggyCollector()
        hotels = await collector.search_hotels(
            destination,
            price_min=price_min,
            price_max=price_max
        )

        return self._to_json({
            "hotels": [h.model_dump() for h in hotels],
            "stats": {"count": len(hotels)}
        })

    # ========== High-level methods (called by the manager) ==========

    async def collect_social_media(
        self,
        keyword: str = "旅行",
        sources: List[str] = None
    ) -> Dict[str, Any]:
        """Collect social-media posts.

        Args:
            keyword: Search keyword.
            sources: Platforms to query; defaults to xiaohongshu and weibo.

        Returns:
            ``{"posts": [...], "stats": {"<platform>_count": n, ...}}``.
        """
        if sources is None:
            sources = ["xiaohongshu", "weibo"]

        return await self._search_platforms(keyword, sources)

    async def collect_wenlv(self, region: str = None) -> Dict[str, Any]:
        """Collect culture-and-tourism bureau information.

        Args:
            region: Optional region name; when given, only entries whose
                ``region`` contains it are kept.

        Returns:
            ``{"infos": [...], "stats": {"wenlv_count": n}}``.
        """
        collector = WenlvCollector()
        infos = await collector.collect()

        if region:
            infos = [i for i in infos if region in i.region]

        return {
            "infos": [i.model_dump() for i in infos],
            "stats": {"wenlv_count": len(infos)}
        }

    async def collect_ota(
        self,
        destinations: List[str],
        departure_city: str = "北京"
    ) -> Dict[str, Any]:
        """Collect OTA flight and hotel data.

        Only the first five destinations are queried.

        Returns:
            ``{"flights": [...], "hotels": [...], "stats": {...}}``.
        """
        result = {"flights": [], "hotels": [], "stats": {}}

        flight_collector = FliggyCollector()
        hotel_collector = FliggyCollector()

        for dest in destinations[:5]:
            flights = await flight_collector.search_flights(departure_city, dest)
            hotels = await hotel_collector.search_hotels(dest)
            result["flights"].extend([f.model_dump() for f in flights])
            result["hotels"].extend([h.model_dump() for h in hotels])

        result["stats"]["flight_count"] = len(result["flights"])
        result["stats"]["hotel_count"] = len(result["hotels"])

        return result

    async def collect_all(
        self,
        keyword: str = "旅行",
        include_ota: bool = True
    ) -> Dict[str, Any]:
        """Run the full collection flow.

        Social media and wenlv data are always collected. When
        ``include_ota`` is true, destinations found in the social posts
        drive the flight and hotel search.
        """
        result = {
            "social_media": await self.collect_social_media(keyword),
            "wenlv": await self.collect_wenlv(),
        }

        if include_ota:
            destinations = self._extract_destinations_from_posts(
                result["social_media"]["posts"]
            )
            result["ota"] = await self.collect_ota(destinations)

        return result

    def _extract_destinations_from_posts(self, posts: List[Dict]) -> List[str]:
        """Extract known destination names mentioned in the posts.

        Returns at most 10 distinct names in first-seen order, so the
        result (and the destinations collect_ota ends up querying) is the
        same on every run for the same posts.
        """
        # A dict keeps insertion order, giving an ordered de-duplication.
        found: Dict[str, None] = {}
        for post in posts:
            text = f"{post.get('title', '')} {post.get('content', '')}"
            for dest in self._COMMON_DESTINATIONS:
                if dest in text:
                    found.setdefault(dest, None)

        return list(found)[:10]
|