LightAgent 0.3.3__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {lightagent-0.3.3 → lightagent-0.4.0}/LightAgent/la_core.py +442 -482
- {lightagent-0.3.3 → lightagent-0.4.0}/PKG-INFO +19 -15
- {lightagent-0.3.3 → lightagent-0.4.0}/README.de.md +1 -1
- {lightagent-0.3.3 → lightagent-0.4.0}/README.es.md +1 -1
- {lightagent-0.3.3 → lightagent-0.4.0}/README.fr.md +1 -1
- {lightagent-0.3.3 → lightagent-0.4.0}/README.ja.md +1 -1
- {lightagent-0.3.3 → lightagent-0.4.0}/README.ko.md +1 -1
- {lightagent-0.3.3 → lightagent-0.4.0}/README.md +4 -2
- {lightagent-0.3.3 → lightagent-0.4.0}/README.pt.md +1 -1
- {lightagent-0.3.3 → lightagent-0.4.0}/README.ru.md +1 -1
- {lightagent-0.3.3 → lightagent-0.4.0}/README.zh-CN.md +4 -2
- {lightagent-0.3.3 → lightagent-0.4.0}/pyproject.toml +4 -4
- {lightagent-0.3.3 → lightagent-0.4.0}/LICENSE +0 -0
- {lightagent-0.3.3 → lightagent-0.4.0}/LightAgent/__init__.py +0 -0
|
@@ -3,8 +3,8 @@
|
|
|
3
3
|
|
|
4
4
|
"""
|
|
5
5
|
作者: [weego/WXAI-Team]
|
|
6
|
-
版本: 0.
|
|
7
|
-
最后更新: 2025-
|
|
6
|
+
版本: 0.4.0
|
|
7
|
+
最后更新: 2025-06-12
|
|
8
8
|
"""
|
|
9
9
|
|
|
10
10
|
import asyncio
|
|
@@ -21,7 +21,7 @@ from contextlib import AsyncExitStack
|
|
|
21
21
|
from copy import deepcopy
|
|
22
22
|
from datetime import datetime
|
|
23
23
|
from functools import partial
|
|
24
|
-
from typing import List, Dict, Any, Callable, Union, Optional, Generator, AsyncGenerator
|
|
24
|
+
from typing import List, Dict, Any, Callable, Union, Optional, Generator, AsyncGenerator, Protocol
|
|
25
25
|
from uuid import uuid4
|
|
26
26
|
|
|
27
27
|
import httpx
|
|
@@ -30,34 +30,39 @@ from mcp.client.sse import sse_client
|
|
|
30
30
|
from mcp.client.stdio import stdio_client
|
|
31
31
|
from openai.types.chat import ChatCompletionChunk
|
|
32
32
|
|
|
33
|
-
# 全局工具注册表
|
|
34
|
-
_FUNCTION_MAPPINGS = {} # 工具名称 -> 工具函数
|
|
35
|
-
_FUNCTION_INFO = {} # 工具名称 -> 工具info信息
|
|
36
|
-
_OPENAI_FUNCTION_SCHEMAS = [] # OpenAI 格式的工具描述
|
|
37
|
-
_PROMPT_FUNCTION_SCHEMAS = [] # prompt 格式的工具描述
|
|
38
33
|
|
|
39
|
-
__version__ = "0.
|
|
34
|
+
__version__ = "0.4.0" # 你可以根据需要设置版本号
|
|
40
35
|
|
|
41
36
|
|
|
42
37
|
# openai.langfuse_auth_check()
|
|
43
38
|
|
|
39
|
+
# 1. 定义内存接口协议
|
|
40
|
+
class MemoryProtocol(Protocol):
|
|
41
|
+
def store(self, data: str, user_id: str) -> Any:
|
|
42
|
+
...
|
|
44
43
|
|
|
45
|
-
def
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
"""
|
|
50
|
-
|
|
44
|
+
def retrieve(self, query: str, user_id: str) -> List[Any]:
|
|
45
|
+
...
|
|
46
|
+
|
|
47
|
+
class ToolRegistry:
|
|
48
|
+
"""集中管理工具注册表,避免全局变量"""
|
|
49
|
+
|
|
50
|
+
def __init__(self):
|
|
51
|
+
self.function_mappings = {} # 工具名称 -> 工具函数
|
|
52
|
+
self.function_info = {} # 工具名称 -> 工具info信息
|
|
53
|
+
self.openai_function_schemas = [] # OpenAI 格式的工具描述
|
|
54
|
+
|
|
55
|
+
def register_tool(self, func: Callable) -> bool:
|
|
56
|
+
"""注册单个工具"""
|
|
51
57
|
if not hasattr(func, "tool_info"):
|
|
52
|
-
|
|
53
|
-
continue
|
|
58
|
+
return False
|
|
54
59
|
|
|
55
60
|
tool_info = func.tool_info
|
|
56
61
|
tool_name = tool_info["tool_name"]
|
|
57
62
|
|
|
58
|
-
#
|
|
59
|
-
|
|
60
|
-
|
|
63
|
+
# 注册到字典
|
|
64
|
+
self.function_info[tool_name] = tool_info
|
|
65
|
+
self.function_mappings[tool_name] = func
|
|
61
66
|
|
|
62
67
|
# 构建 OpenAI 格式的工具描述
|
|
63
68
|
tool_params_openai = {}
|
|
@@ -83,174 +88,194 @@ def register_tool_manually(tools: List[Union[str, Callable]]) -> bool:
|
|
|
83
88
|
}
|
|
84
89
|
}
|
|
85
90
|
|
|
86
|
-
|
|
87
|
-
|
|
91
|
+
self.openai_function_schemas.append(tool_def_openai)
|
|
92
|
+
return True
|
|
88
93
|
|
|
94
|
+
def register_tools(self, tools: List[Callable]) -> bool:
|
|
95
|
+
"""批量注册工具"""
|
|
96
|
+
success = True
|
|
97
|
+
for func in tools:
|
|
98
|
+
if not self.register_tool(func):
|
|
99
|
+
success = False
|
|
100
|
+
return success
|
|
89
101
|
|
|
90
|
-
def
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
"""
|
|
94
|
-
tool_path = os.path.join(tools_directory, f"{tool_name}.py")
|
|
95
|
-
if not os.path.exists(tool_path):
|
|
96
|
-
raise FileNotFoundError(f"Tool '{tool_name}' not found in {tools_directory}")
|
|
102
|
+
def get_tools(self) -> List[Dict[str, Any]]:
|
|
103
|
+
"""获取所有工具的描述(OpenAI 格式)"""
|
|
104
|
+
return deepcopy(self.openai_function_schemas)
|
|
97
105
|
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
spec.loader.exec_module(module)
|
|
106
|
+
def get_tools_str(self) -> str:
|
|
107
|
+
"""将工具描述转换为格式化的 JSON 字符串"""
|
|
108
|
+
return json.dumps(self.openai_function_schemas, indent=4, ensure_ascii=False)
|
|
102
109
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
110
|
+
def filter_tools(self, tool_reflection_result: str) -> List[Dict]:
|
|
111
|
+
"""根据内容过滤工具"""
|
|
112
|
+
try:
|
|
113
|
+
# 安全解析可能包含 Markdown 代码块的 JSON
|
|
114
|
+
refined_content = tool_reflection_result.strip()
|
|
115
|
+
if refined_content.startswith('```json') and refined_content.endswith('```'):
|
|
116
|
+
refined_content = refined_content[7:-3].strip()
|
|
109
117
|
|
|
118
|
+
parsed_data = json.loads(refined_content)
|
|
119
|
+
valid_tools = {tool["name"].strip().lower() for tool in parsed_data.get("tools", [])}
|
|
110
120
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
121
|
+
return [
|
|
122
|
+
schema for schema in self.openai_function_schemas
|
|
123
|
+
if isinstance(schema, dict) and
|
|
124
|
+
schema.get("function", {}).get("name", "").strip().lower() in valid_tools
|
|
125
|
+
]
|
|
126
|
+
except (json.JSONDecodeError, KeyError, AttributeError) as e:
|
|
127
|
+
raise ValueError(f"工具过滤失败: {str(e)}") from e
|
|
118
128
|
|
|
119
|
-
tool_call = _FUNCTION_MAPPINGS[tool_name]
|
|
120
|
-
try:
|
|
121
|
-
# 处理不同类型的流式输出
|
|
122
|
-
# 区分同步/异步工具
|
|
123
|
-
if inspect.iscoroutinefunction(tool_call):
|
|
124
|
-
# result = await tool_call(**tool_params)
|
|
125
|
-
# 将参数以字典形式传递给包装器
|
|
126
|
-
result = await tool_call(**tool_params) if inspect.iscoroutinefunction(tool_call) else tool_call(
|
|
127
|
-
**tool_params)
|
|
128
|
-
else:
|
|
129
|
-
result = tool_call(**tool_params)
|
|
130
129
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
130
|
+
class ToolLoader:
|
|
131
|
+
"""工具加载器,支持动态加载和缓存"""
|
|
132
|
+
|
|
133
|
+
def __init__(self, tools_directory: str = "tools"):
|
|
134
|
+
self.tools_directory = tools_directory
|
|
135
|
+
self.loaded_tools = {}
|
|
136
|
+
|
|
137
|
+
def load_tool(self, tool_name: str) -> Callable:
|
|
138
|
+
"""加载单个工具"""
|
|
139
|
+
if tool_name in self.loaded_tools:
|
|
140
|
+
return self.loaded_tools[tool_name]
|
|
141
|
+
|
|
142
|
+
tool_path = os.path.join(self.tools_directory, f"{tool_name}.py")
|
|
143
|
+
if not os.path.exists(tool_path):
|
|
144
|
+
raise FileNotFoundError(f"Tool '{tool_name}' not found in {tool_path}")
|
|
145
|
+
|
|
146
|
+
# 动态加载模块
|
|
147
|
+
spec = importlib.util.spec_from_file_location(tool_name, tool_path)
|
|
148
|
+
module = importlib.util.module_from_spec(spec)
|
|
149
|
+
spec.loader.exec_module(module)
|
|
150
|
+
|
|
151
|
+
# 获取工具函数
|
|
152
|
+
if hasattr(module, tool_name):
|
|
153
|
+
tool_func = getattr(module, tool_name)
|
|
154
|
+
if callable(tool_func) and hasattr(tool_func, "tool_info"):
|
|
155
|
+
self.loaded_tools[tool_name] = tool_func
|
|
156
|
+
return tool_func
|
|
157
|
+
|
|
158
|
+
raise AttributeError(f"Tool '{tool_name}' is not properly defined in {tool_path}")
|
|
159
|
+
|
|
160
|
+
def load_tools(self, tool_names: List[str]) -> Dict[str, Callable]:
|
|
161
|
+
"""批量加载工具"""
|
|
162
|
+
for tool_name in tool_names:
|
|
163
|
+
if tool_name not in self.loaded_tools:
|
|
164
|
+
self.load_tool(tool_name)
|
|
165
|
+
return self.loaded_tools
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class AsyncToolDispatcher:
|
|
169
|
+
"""异步工具调度器"""
|
|
170
|
+
|
|
171
|
+
async def dispatch(self, tool_name: str, tool_params: Dict[str, Any]) -> Union[
|
|
172
|
+
str, Generator[str, None, None], AsyncGenerator[str, None]]:
|
|
173
|
+
"""调用工具执行,支持同步/异步工具及流式输出"""
|
|
174
|
+
if tool_name not in self.function_mappings:
|
|
175
|
+
return f"Tool `{tool_name}` not found."
|
|
176
|
+
|
|
177
|
+
tool_call = self.function_mappings[tool_name]
|
|
178
|
+
try:
|
|
179
|
+
# 处理不同类型的工具
|
|
180
|
+
if inspect.iscoroutinefunction(tool_call):
|
|
181
|
+
result = await tool_call(**tool_params)
|
|
182
|
+
elif inspect.isasyncgenfunction(tool_call):
|
|
183
|
+
result = tool_call(**tool_params)
|
|
184
|
+
else:
|
|
185
|
+
result = tool_call(**tool_params)
|
|
186
|
+
|
|
187
|
+
# 处理流式输出
|
|
188
|
+
if inspect.isasyncgen(result):
|
|
189
|
+
return self.async_stream_generator(result)
|
|
190
|
+
elif inspect.isgenerator(result):
|
|
191
|
+
return self.stream_generator(result)
|
|
137
192
|
return str(result)
|
|
138
|
-
|
|
139
|
-
|
|
193
|
+
except Exception as e:
|
|
194
|
+
return f"Tool call error: {str(e)}\n{traceback.format_exc()}"
|
|
140
195
|
|
|
196
|
+
async def async_stream_generator(self, async_gen: AsyncGenerator) -> AsyncGenerator[str, None]:
|
|
197
|
+
async for chunk in async_gen:
|
|
198
|
+
yield chunk
|
|
141
199
|
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
200
|
+
def stream_generator(self, sync_gen: Generator) -> Generator[str, None, None]:
|
|
201
|
+
for chunk in sync_gen:
|
|
202
|
+
yield chunk
|
|
145
203
|
|
|
146
204
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
yield chunk
|
|
205
|
+
class LoggerManager:
|
|
206
|
+
"""集中管理日志系统"""
|
|
150
207
|
|
|
208
|
+
def __init__(self, name: str, debug: bool, log_level: str, log_file: Optional[str] = None):
|
|
209
|
+
self.name = name
|
|
210
|
+
self.debug = debug
|
|
211
|
+
self.logger = self._setup_logger(log_level, log_file)
|
|
212
|
+
self.traceid = ""
|
|
151
213
|
|
|
152
|
-
def
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
if tool_name not in _FUNCTION_MAPPINGS:
|
|
157
|
-
return f"Tool `{tool_name}` not found."
|
|
214
|
+
def _setup_logger(self, log_level: str, log_file: Optional[str] = None) -> logging.Logger:
|
|
215
|
+
logger = logging.getLogger(self.name)
|
|
216
|
+
logger.setLevel(log_level.upper())
|
|
217
|
+
logger.propagate = False
|
|
158
218
|
|
|
159
|
-
|
|
160
|
-
try:
|
|
161
|
-
# print(f"Calling tool: {tool_name} with params: {tool_params}") # 调试信息
|
|
162
|
-
return str(tool_call(**tool_params))
|
|
163
|
-
except Exception as e:
|
|
164
|
-
# print(f"Tool call failed: {e}") # 调试信息
|
|
165
|
-
return traceback.format_exc()
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
def get_tools() -> List[Dict[str, Any]]:
|
|
169
|
-
"""
|
|
170
|
-
获取所有工具的描述(OpenAI 格式)
|
|
171
|
-
"""
|
|
172
|
-
return deepcopy(_OPENAI_FUNCTION_SCHEMAS)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
def get_tools_str() -> str:
|
|
176
|
-
"""
|
|
177
|
-
将 _OPENAI_FUNCTION_SCHEMAS 转换为格式化的 JSON 字符串。
|
|
178
|
-
Returns:
|
|
179
|
-
str: 格式化的 JSON 字符串。
|
|
180
|
-
"""
|
|
181
|
-
# 使用 json.dumps 将字典转换为格式化的 JSON 字符串
|
|
182
|
-
tools_str = json.dumps(_OPENAI_FUNCTION_SCHEMAS, indent=4, ensure_ascii=False)
|
|
183
|
-
return tools_str
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
def filter_tools_schemas(refined_content: str) -> json:
|
|
187
|
-
"""
|
|
188
|
-
根据refined_content中的工具列表过滤全局_OPENAI_FUNCTION_SCHEMAS
|
|
189
|
-
:param refined_content: 包含工具列表的JSON字符串
|
|
190
|
-
"""
|
|
191
|
-
# global _OPENAI_FUNCTION_SCHEMAS # 声明操作全局变量
|
|
192
|
-
"""安全解析可能包含 Markdown 代码块的 JSON"""
|
|
193
|
-
refined_content = refined_content.strip()
|
|
194
|
-
if refined_content.startswith('```json') and refined_content.endswith('```'):
|
|
195
|
-
refined_content = refined_content[7:-3].strip() # 去除 ```json 和 ```
|
|
196
|
-
try:
|
|
197
|
-
# 解析工具列表
|
|
198
|
-
parsed_data: Dict[str, List[Dict]] = json.loads(refined_content)
|
|
199
|
-
valid_tools = {tool["name"].strip().lower() for tool in parsed_data.get("tools", [])}
|
|
200
|
-
|
|
201
|
-
# 原地过滤操作
|
|
202
|
-
filtered_schemas: List[Dict] = []
|
|
203
|
-
for schema in _OPENAI_FUNCTION_SCHEMAS:
|
|
204
|
-
if not isinstance(schema, dict):
|
|
205
|
-
continue
|
|
219
|
+
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
|
|
206
220
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
if schema_name in valid_tools:
|
|
212
|
-
filtered_schemas.append(schema)
|
|
221
|
+
if self.debug:
|
|
222
|
+
console_handler = logging.StreamHandler()
|
|
223
|
+
console_handler.setFormatter(formatter)
|
|
224
|
+
logger.addHandler(console_handler)
|
|
213
225
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
226
|
+
if log_file:
|
|
227
|
+
# 确保 log 目录存在
|
|
228
|
+
log_dir = os.path.dirname(log_file)
|
|
229
|
+
if log_dir and not os.path.exists(log_dir):
|
|
230
|
+
os.makedirs(log_dir)
|
|
217
231
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
232
|
+
file_handler = logging.FileHandler(log_file)
|
|
233
|
+
file_handler.setFormatter(formatter)
|
|
234
|
+
logger.addHandler(file_handler)
|
|
235
|
+
|
|
236
|
+
return logger
|
|
237
|
+
|
|
238
|
+
def log(self, level: str, action: str, data: Any):
|
|
239
|
+
"""记录日志"""
|
|
240
|
+
if not self.debug:
|
|
241
|
+
return
|
|
242
|
+
|
|
243
|
+
trace_info = f"[TraceID: {self.traceid}] " if self.traceid else ""
|
|
244
|
+
log_message = f"{trace_info}{action}: {data}"
|
|
245
|
+
|
|
246
|
+
if level == "DEBUG":
|
|
247
|
+
self.logger.debug(log_message)
|
|
248
|
+
elif level == "INFO":
|
|
249
|
+
self.logger.info(log_message)
|
|
250
|
+
elif level == "ERROR":
|
|
251
|
+
self.logger.error(log_message)
|
|
252
|
+
|
|
253
|
+
def set_traceid(self, traceid: str):
|
|
254
|
+
"""设置当前跟踪ID"""
|
|
255
|
+
self.traceid = traceid
|
|
222
256
|
|
|
223
257
|
|
|
224
258
|
class MCPClientManager:
|
|
225
259
|
"""增强版MCP客户端管理器"""
|
|
226
260
|
|
|
227
|
-
def __init__(self, config: dict):
|
|
261
|
+
def __init__(self, config: dict, tool_registry: ToolRegistry):
|
|
228
262
|
self.config = config
|
|
263
|
+
self.tool_registry = tool_registry
|
|
229
264
|
self.session: Optional[ClientSession] = None
|
|
230
265
|
self.exit_stack = AsyncExitStack()
|
|
231
|
-
self.
|
|
232
|
-
self._session_context = None
|
|
233
|
-
self.server_sessions = {} # 存储不同服务器的会话
|
|
234
|
-
|
|
235
|
-
async def _call_tool_wrapper(self, tool_name: str, target_server: str, **kwargs):
|
|
236
|
-
"""参数转换适配器"""
|
|
237
|
-
return await self.call_tool(
|
|
238
|
-
tool_name=tool_name,
|
|
239
|
-
arguments=kwargs,
|
|
240
|
-
target_server=target_server
|
|
241
|
-
)
|
|
266
|
+
self.server_sessions = {}
|
|
242
267
|
|
|
243
268
|
async def _create_session(self, server_name: str, config: dict):
|
|
244
269
|
"""创建并管理会话上下文"""
|
|
245
270
|
if 'url' in config:
|
|
246
271
|
# SSE 服务器连接
|
|
247
|
-
|
|
272
|
+
streams_context = sse_client(
|
|
248
273
|
url=config['url'],
|
|
249
274
|
headers=config.get('headers', {})
|
|
250
275
|
)
|
|
251
|
-
streams = await self.exit_stack.enter_async_context(
|
|
252
|
-
|
|
253
|
-
self.session = await self.exit_stack.enter_async_context(
|
|
276
|
+
streams = await self.exit_stack.enter_async_context(streams_context)
|
|
277
|
+
session_context = ClientSession(*streams)
|
|
278
|
+
self.session = await self.exit_stack.enter_async_context(session_context)
|
|
254
279
|
else:
|
|
255
280
|
# 标准输入输出服务器连接
|
|
256
281
|
server_params = StdioServerParameters(
|
|
@@ -260,23 +285,19 @@ class MCPClientManager:
|
|
|
260
285
|
)
|
|
261
286
|
transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
|
262
287
|
stdio, write = transport
|
|
263
|
-
|
|
264
|
-
self.session = await self.exit_stack.enter_async_context(
|
|
288
|
+
session_context = ClientSession(stdio, write)
|
|
289
|
+
self.session = await self.exit_stack.enter_async_context(session_context)
|
|
265
290
|
|
|
266
291
|
await self.session.initialize()
|
|
267
292
|
self.server_sessions[server_name] = self.session
|
|
268
293
|
|
|
269
294
|
async def cleanup(self):
|
|
270
295
|
"""清理所有会话资源"""
|
|
271
|
-
await self.exit_stack.
|
|
296
|
+
await self.exit_stack.aclose()
|
|
272
297
|
self.server_sessions.clear()
|
|
273
298
|
|
|
274
299
|
async def register_mcp_tool(self) -> bool:
|
|
275
|
-
"""
|
|
276
|
-
自动注册所有MCP服务的工具到全局字典
|
|
277
|
-
:param config: MCP服务配置(与call_tool使用的相同配置)
|
|
278
|
-
:return: 是否至少成功注册一个工具
|
|
279
|
-
"""
|
|
300
|
+
"""自动注册所有MCP服务的工具"""
|
|
280
301
|
registered_count = 0
|
|
281
302
|
enabled_servers = [
|
|
282
303
|
(name, config)
|
|
@@ -286,15 +307,10 @@ class MCPClientManager:
|
|
|
286
307
|
|
|
287
308
|
for server_name, config in enabled_servers:
|
|
288
309
|
try:
|
|
289
|
-
# 创建会话连接
|
|
290
|
-
# print(server_name,config)
|
|
291
310
|
await self._create_session(server_name, config)
|
|
292
|
-
|
|
293
|
-
# 获取工具列表
|
|
294
311
|
tools_response = await self.session.list_tools()
|
|
295
|
-
print(f"🔍 Registering tools for server : {server_name} ...")
|
|
312
|
+
print(f"🔍 Registering MCP tools for server : {server_name} ...")
|
|
296
313
|
|
|
297
|
-
# 注册工具处理逻辑
|
|
298
314
|
for tool in tools_response.tools:
|
|
299
315
|
try:
|
|
300
316
|
# 构建工具元数据
|
|
@@ -316,9 +332,9 @@ class MCPClientManager:
|
|
|
316
332
|
"required": param_name in required_fields
|
|
317
333
|
})
|
|
318
334
|
|
|
319
|
-
#
|
|
320
|
-
|
|
321
|
-
|
|
335
|
+
# 注册到工具注册表
|
|
336
|
+
self.tool_registry.function_info[tool.name] = tool_info
|
|
337
|
+
self.tool_registry.function_mappings[tool.name] = partial(
|
|
322
338
|
self._call_tool_wrapper,
|
|
323
339
|
tool_name=tool.name,
|
|
324
340
|
target_server=server_name
|
|
@@ -340,31 +356,27 @@ class MCPClientManager:
|
|
|
340
356
|
}
|
|
341
357
|
}
|
|
342
358
|
}
|
|
343
|
-
|
|
344
|
-
|
|
359
|
+
self.tool_registry.openai_function_schemas.append(openai_schema)
|
|
345
360
|
registered_count += 1
|
|
346
|
-
print(f"✅ The registered tool : {tool.name}")
|
|
347
|
-
|
|
361
|
+
print(f"✅ The registered MCP tool : {tool.name}")
|
|
348
362
|
except Exception as e:
|
|
349
|
-
print(f"⚠️ 工具 {tool.name} 注册失败: {str(e)}")
|
|
350
363
|
continue
|
|
351
|
-
print(f"🟢 {registered_count} tools have been registered")
|
|
352
|
-
|
|
353
364
|
except Exception as e:
|
|
354
|
-
print(f"🔴 服务器 {server_name} 连接失败: {str(e)}")
|
|
355
365
|
continue
|
|
356
|
-
|
|
366
|
+
|
|
357
367
|
await self.cleanup()
|
|
358
368
|
return registered_count > 0
|
|
359
369
|
|
|
370
|
+
async def _call_tool_wrapper(self, tool_name: str, target_server: str, **kwargs):
|
|
371
|
+
"""参数转换适配器"""
|
|
372
|
+
return await self.call_tool(
|
|
373
|
+
tool_name=tool_name,
|
|
374
|
+
arguments=kwargs,
|
|
375
|
+
target_server=target_server
|
|
376
|
+
)
|
|
377
|
+
|
|
360
378
|
async def call_tool(self, tool_name: str, arguments: dict, target_server: str = None):
|
|
361
|
-
"""
|
|
362
|
-
通用工具调用方法
|
|
363
|
-
:param tool_name: 要调用的工具名称
|
|
364
|
-
:param arguments: 工具参数字典
|
|
365
|
-
:param target_server: 指定服务器名称(可选)
|
|
366
|
-
:return: 工具调用结果
|
|
367
|
-
"""
|
|
379
|
+
"""通用工具调用方法"""
|
|
368
380
|
enabled_servers = [
|
|
369
381
|
(name, config)
|
|
370
382
|
for name, config in self.config["mcpServers"].items()
|
|
@@ -376,15 +388,11 @@ class MCPClientManager:
|
|
|
376
388
|
|
|
377
389
|
for server_name, config in enabled_servers:
|
|
378
390
|
try:
|
|
379
|
-
# 复用已建立的会话
|
|
380
391
|
session = self.server_sessions.get(server_name)
|
|
381
|
-
# print(111,server_name,session)
|
|
382
|
-
# print(222,server_name,config)
|
|
383
392
|
if not session:
|
|
384
393
|
await self._create_session(server_name, config)
|
|
385
394
|
session = self.session
|
|
386
395
|
|
|
387
|
-
# 获取工具列表
|
|
388
396
|
tools = await session.list_tools()
|
|
389
397
|
available_tools = {t.name: t for t in tools.tools}
|
|
390
398
|
|
|
@@ -395,31 +403,26 @@ class MCPClientManager:
|
|
|
395
403
|
|
|
396
404
|
# 执行调用
|
|
397
405
|
result = await session.call_tool(tool_name, arguments)
|
|
398
|
-
# print(f"mcp工具运行结果: {result.content[0].text}")
|
|
399
|
-
# 调用完成清理session
|
|
400
406
|
await self.cleanup()
|
|
401
407
|
return {
|
|
402
408
|
"server": server_name,
|
|
403
409
|
"tool": tool_name,
|
|
404
410
|
"result": result.content[0].text
|
|
405
411
|
}
|
|
406
|
-
|
|
407
412
|
except Exception as e:
|
|
408
|
-
print(f"调用服务器 {server_name} 失败: {str(e)}")
|
|
409
413
|
continue
|
|
410
414
|
|
|
411
415
|
raise ValueError(f"工具 {tool_name} 在可用服务器中未找到")
|
|
412
416
|
|
|
413
417
|
def _validate_arguments(self, arguments: dict, schema: dict):
|
|
414
|
-
"""
|
|
418
|
+
"""简单参数校验"""
|
|
415
419
|
required_fields = schema.get("required", [])
|
|
416
420
|
for field in required_fields:
|
|
417
421
|
if field not in arguments:
|
|
418
422
|
raise ValueError(f"缺少必要参数: {field}")
|
|
419
423
|
|
|
420
|
-
|
|
421
424
|
class LightAgent:
|
|
422
|
-
__version__ = "0.
|
|
425
|
+
__version__ = "0.4.0" # 将版本号放在类中
|
|
423
426
|
|
|
424
427
|
def __init__(
|
|
425
428
|
self,
|
|
@@ -431,7 +434,7 @@ class LightAgent:
|
|
|
431
434
|
api_key: str | None = None, # 模型 api key
|
|
432
435
|
base_url: str | httpx.URL | None = None, # 模型 base url
|
|
433
436
|
websocket_base_url: str | httpx.URL | None = None, # 模型 websocket base url
|
|
434
|
-
memory:
|
|
437
|
+
memory: Optional[MemoryProtocol] = None, # 支持外部传入记忆模块
|
|
435
438
|
tree_of_thought: bool = False, # 是否启用链式思考
|
|
436
439
|
tot_model: str | None = None, # 链式思考模型
|
|
437
440
|
tot_api_key: str | None = None, # 链式思考模型API密钥
|
|
@@ -466,6 +469,13 @@ class LightAgent:
|
|
|
466
469
|
:param log_file: 日志文件路径。
|
|
467
470
|
:param tracetools: log跟踪工具。
|
|
468
471
|
"""
|
|
472
|
+
|
|
473
|
+
# 初始化核心组件
|
|
474
|
+
self.tool_registry = ToolRegistry()
|
|
475
|
+
self.tool_loader = ToolLoader()
|
|
476
|
+
self.tool_dispatcher = AsyncToolDispatcher()
|
|
477
|
+
self.tool_dispatcher.function_mappings = self.tool_registry.function_mappings
|
|
478
|
+
|
|
469
479
|
self.mcp_setting = None
|
|
470
480
|
self.mcp_client = None
|
|
471
481
|
if not model:
|
|
@@ -500,7 +510,14 @@ class LightAgent:
|
|
|
500
510
|
if debug:
|
|
501
511
|
self.log_file = os.path.join(log_dir, log_file)
|
|
502
512
|
# Set up the logger
|
|
503
|
-
|
|
513
|
+
# 初始化日志系统
|
|
514
|
+
self.logger = LoggerManager(
|
|
515
|
+
name=self.name,
|
|
516
|
+
debug=debug,
|
|
517
|
+
log_level=log_level,
|
|
518
|
+
log_file=log_file
|
|
519
|
+
)
|
|
520
|
+
|
|
504
521
|
if tools is None:
|
|
505
522
|
self.tools = []
|
|
506
523
|
if tools:
|
|
@@ -515,49 +532,54 @@ class LightAgent:
|
|
|
515
532
|
)
|
|
516
533
|
self.api_key = api_key
|
|
517
534
|
self.websocket_base_url = websocket_base_url
|
|
518
|
-
|
|
519
|
-
if base_url is None:
|
|
520
|
-
base_url = f"https://api.openai.com/v1"
|
|
535
|
+
self.base_url = base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
|
|
521
536
|
|
|
522
537
|
if self.tree_of_thought:
|
|
523
538
|
if tot_api_key is None:
|
|
524
|
-
tot_api_key = api_key
|
|
539
|
+
tot_api_key = self.api_key
|
|
525
540
|
if tot_base_url is None:
|
|
526
|
-
tot_base_url = base_url
|
|
541
|
+
tot_base_url = self.base_url
|
|
527
542
|
if not tot_model:
|
|
528
543
|
tot_model = "deepseek-r1" # 默认思维推理模型为deepseek-r1
|
|
529
544
|
self.tot_model = tot_model
|
|
530
545
|
|
|
531
|
-
|
|
532
|
-
|
|
546
|
+
# 初始化客户端
|
|
547
|
+
self._initialize_clients(tracetools, tot_api_key, tot_base_url, tot_model)
|
|
548
|
+
|
|
549
|
+
def _initialize_clients(self, tracetools, tot_api_key, tot_base_url, tot_model):
|
|
550
|
+
"""初始化 OpenAI 客户端"""
|
|
533
551
|
if tracetools:
|
|
534
|
-
self.tracetools = tracetools
|
|
535
|
-
# 初始化工具列表
|
|
536
552
|
from langfuse.openai import openai as la_openai
|
|
537
|
-
la_openai.langfuse_public_key =
|
|
538
|
-
la_openai.langfuse_secret_key =
|
|
539
|
-
la_openai.langfuse_enabled =
|
|
540
|
-
|
|
541
|
-
la_openai.
|
|
542
|
-
la_openai.base_url = base_url
|
|
553
|
+
la_openai.langfuse_public_key = tracetools['TraceToolConfig']['langfuse_public_key']
|
|
554
|
+
la_openai.langfuse_secret_key = tracetools['TraceToolConfig']['langfuse_secret_key']
|
|
555
|
+
la_openai.langfuse_enabled = tracetools['TraceToolConfig']['langfuse_enabled']
|
|
556
|
+
la_openai.langfuse_host = tracetools['TraceToolConfig']['langfuse_host']
|
|
557
|
+
la_openai.base_url = self.base_url
|
|
543
558
|
la_openai.api_key = self.api_key
|
|
544
559
|
self.client = la_openai
|
|
560
|
+
|
|
545
561
|
if self.tree_of_thought:
|
|
546
|
-
la_openai.base_url = tot_base_url
|
|
547
|
-
la_openai.api_key = tot_api_key
|
|
562
|
+
la_openai.base_url = tot_base_url or self.base_url
|
|
563
|
+
la_openai.api_key = tot_api_key or self.api_key
|
|
548
564
|
self.tot_client = la_openai
|
|
549
565
|
else:
|
|
550
566
|
from openai import OpenAI as la_openai
|
|
551
567
|
self.client = la_openai(
|
|
552
|
-
base_url=base_url,
|
|
568
|
+
base_url=self.base_url,
|
|
553
569
|
api_key=self.api_key
|
|
554
570
|
)
|
|
555
571
|
if self.tree_of_thought:
|
|
556
572
|
self.tot_client = la_openai(
|
|
557
|
-
base_url=tot_base_url,
|
|
558
|
-
api_key=tot_api_key
|
|
573
|
+
base_url=tot_base_url or self.base_url,
|
|
574
|
+
api_key=tot_api_key or self.api_key
|
|
559
575
|
)
|
|
560
576
|
|
|
577
|
+
def get_tools(self) -> List[Dict[str, Any]]:
|
|
578
|
+
"""
|
|
579
|
+
获取所有工具的描述(OpenAI 格式)
|
|
580
|
+
"""
|
|
581
|
+
return deepcopy(self.tool_registry.get_tools())
|
|
582
|
+
|
|
561
583
|
def get_tool(self, tool_name: str) -> Callable:
|
|
562
584
|
"""
|
|
563
585
|
用于外部可以获取已加载的工具函数
|
|
@@ -568,111 +590,19 @@ class LightAgent:
|
|
|
568
590
|
return self.loaded_tools[tool_name]
|
|
569
591
|
raise ValueError(f"Tool `{tool_name}` is not loaded.")
|
|
570
592
|
|
|
571
|
-
def get_tools(self) -> List[str]:
|
|
572
|
-
"""
|
|
573
|
-
用于外部可以获取已加载的工具函数列表
|
|
574
|
-
:return: 工具函数
|
|
575
|
-
"""
|
|
576
|
-
return list(_FUNCTION_MAPPINGS.keys())
|
|
577
|
-
|
|
578
593
|
def load_tools(self, tool_names: List[Union[str, Callable]], tools_directory: str = "tools"):
|
|
579
|
-
"""
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
tool_info = tool_func.tool_info
|
|
592
|
-
_FUNCTION_INFO[tool_name] = tool_info # 注册工具info信息
|
|
593
|
-
_FUNCTION_MAPPINGS[tool_name] = tool_func
|
|
594
|
-
|
|
595
|
-
# 构建 OpenAI 格式的工具描述
|
|
596
|
-
tool_params_openai = {}
|
|
597
|
-
tool_required = []
|
|
598
|
-
for param in tool_info["tool_params"]:
|
|
599
|
-
tool_params_openai[param["name"]] = {
|
|
600
|
-
"type": param["type"],
|
|
601
|
-
"description": param["description"],
|
|
602
|
-
}
|
|
603
|
-
if param["required"]:
|
|
604
|
-
tool_required.append(param["name"])
|
|
605
|
-
|
|
606
|
-
tool_def_openai = {
|
|
607
|
-
"type": "function",
|
|
608
|
-
"function": {
|
|
609
|
-
"name": tool_name,
|
|
610
|
-
"description": tool_info["tool_description"],
|
|
611
|
-
"parameters": {
|
|
612
|
-
"type": "object",
|
|
613
|
-
"properties": tool_params_openai,
|
|
614
|
-
"required": tool_required,
|
|
615
|
-
},
|
|
616
|
-
}
|
|
617
|
-
}
|
|
618
|
-
_OPENAI_FUNCTION_SCHEMAS.append(tool_def_openai)
|
|
619
|
-
|
|
620
|
-
self.log("DEBUG", "load_tools success", {"tools": tool_name})
|
|
621
|
-
except Exception as e:
|
|
622
|
-
if register_tool_manually([tool_name]):
|
|
623
|
-
self.log("DEBUG", "register_tool_manually success", {"tools": tool_name})
|
|
624
|
-
else:
|
|
625
|
-
self.log("DEBUG", "load_tools error", {"e": e})
|
|
626
|
-
|
|
627
|
-
def _setup_logger(self, log_level: str, log_file: Optional[str] = None) -> logging.Logger:
|
|
628
|
-
"""
|
|
629
|
-
设置日志记录器。
|
|
630
|
-
|
|
631
|
-
:param log_level: 日志级别(INFO, DEBUG, ERROR)。
|
|
632
|
-
:param log_file: 日志文件路径。
|
|
633
|
-
:return: 配置好的日志记录器。
|
|
634
|
-
"""
|
|
635
|
-
logger = logging.getLogger(self.name)
|
|
636
|
-
logger.setLevel(log_level.upper())
|
|
637
|
-
logger.propagate = False # 禁用传播到根日志记录器
|
|
638
|
-
|
|
639
|
-
formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
|
|
640
|
-
|
|
641
|
-
# 输出到控制台
|
|
642
|
-
# 仅在调试模式下输出到控制台
|
|
643
|
-
if self.debug:
|
|
644
|
-
console_handler = logging.StreamHandler()
|
|
645
|
-
console_handler.setFormatter(formatter)
|
|
646
|
-
logger.addHandler(console_handler)
|
|
647
|
-
|
|
648
|
-
# 输出到文件(如果指定了文件路径)
|
|
649
|
-
if log_file:
|
|
650
|
-
file_handler = logging.FileHandler(log_file)
|
|
651
|
-
file_handler.setFormatter(formatter)
|
|
652
|
-
logger.addHandler(file_handler)
|
|
653
|
-
|
|
654
|
-
return logger
|
|
655
|
-
|
|
656
|
-
def log(self, level, action, data):
|
|
657
|
-
"""
|
|
658
|
-
记录日志。
|
|
659
|
-
|
|
660
|
-
:param level: 日志级别(INFO, DEBUG, ERROR)。
|
|
661
|
-
:param action: 日志动作(如 chat, call_tool, retrieve_memory)。
|
|
662
|
-
:param data: 日志数据。
|
|
663
|
-
"""
|
|
664
|
-
if not self.debug:
|
|
665
|
-
return
|
|
666
|
-
if self.traceid is not None:
|
|
667
|
-
log_message = f"[TraceID: {self.traceid}] {action}: {data}"
|
|
668
|
-
else:
|
|
669
|
-
log_message = f"{action}: {data}"
|
|
670
|
-
if level == "DEBUG":
|
|
671
|
-
self.logger.debug(log_message)
|
|
672
|
-
elif level == "INFO":
|
|
673
|
-
self.logger.info(log_message)
|
|
674
|
-
elif level == "ERROR":
|
|
675
|
-
self.logger.error(log_message)
|
|
594
|
+
"""加载并注册工具"""
|
|
595
|
+
for tool in tool_names:
|
|
596
|
+
if isinstance(tool, str):
|
|
597
|
+
try:
|
|
598
|
+
tool_func = self.tool_loader.load_tool(tool)
|
|
599
|
+
self.tool_registry.register_tool(tool_func)
|
|
600
|
+
self.logger.log("DEBUG", "load_tools", {"tool": tool, "status": "success"})
|
|
601
|
+
except Exception as e:
|
|
602
|
+
self.logger.log("ERROR", "load_tools", {"tool": tool, "error": str(e)})
|
|
603
|
+
elif callable(tool) and hasattr(tool, "tool_info"):
|
|
604
|
+
if self.tool_registry.register_tool(tool):
|
|
605
|
+
self.logger.log("DEBUG", "register_tool", {"tool": tool.__name__, "status": "success"})
|
|
676
606
|
|
|
677
607
|
async def setup_mcp(
|
|
678
608
|
self,
|
|
@@ -682,9 +612,9 @@ class LightAgent:
|
|
|
682
612
|
self.mcp_setting = mcp_setting
|
|
683
613
|
"""单独初始化 MCP 模块"""
|
|
684
614
|
if self.mcp_setting and not self.mcp_client:
|
|
685
|
-
self.mcp_client = MCPClientManager(self.mcp_setting)
|
|
615
|
+
self.mcp_client = MCPClientManager(self.mcp_setting, self.tool_registry)
|
|
686
616
|
await self.mcp_client.register_mcp_tool()
|
|
687
|
-
self.log("INFO", "setup_mcp", "MCP 模块初始化成功")
|
|
617
|
+
self.logger.log("INFO", "setup_mcp", "MCP 模块初始化成功")
|
|
688
618
|
|
|
689
619
|
def run(
|
|
690
620
|
self,
|
|
@@ -708,22 +638,13 @@ class LightAgent:
|
|
|
708
638
|
:param metadata: 元数据。
|
|
709
639
|
:return: 代理的回复。
|
|
710
640
|
"""
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
# 1. 判断是否需要转移任务
|
|
719
|
-
# if light_swarm:
|
|
720
|
-
# intent = self._detect_intent(query, light_swarm)
|
|
721
|
-
# if intent and intent.get("transfer_to"):
|
|
722
|
-
# target_agent_name = intent["transfer_to"]
|
|
723
|
-
# self.log("INFO", "detect_intent", {"intent": intent})
|
|
724
|
-
# print(light_swarm.agents[target_agent_name])
|
|
725
|
-
# self._transfer_to_agent(light_swarm.agents[target_agent_name], query, stream=stream)
|
|
726
|
-
# return # 立即结束当前生成器
|
|
641
|
+
# 设置跟踪ID
|
|
642
|
+
traceid = uuid4().hex
|
|
643
|
+
self.logger.set_traceid(traceid)
|
|
644
|
+
self.logger.log("INFO", "run_start", {"query": query, "user_id": user_id, "stream": stream})
|
|
645
|
+
|
|
646
|
+
# 初始化历史记录
|
|
647
|
+
history = history or []
|
|
727
648
|
|
|
728
649
|
# 0. 判断是否需要转移任务
|
|
729
650
|
if light_swarm:
|
|
@@ -735,66 +656,75 @@ class LightAgent:
|
|
|
735
656
|
now = datetime.now()
|
|
736
657
|
current_date = now.strftime("%Y-%m-%d")
|
|
737
658
|
current_time = now.strftime("%H:%M:%S")
|
|
738
|
-
system_prompt =
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
# 3. 从记忆中检索相关内容&保存记忆
|
|
748
|
-
if self.memory:
|
|
749
|
-
related_memories = self.memory.retrieve(query=query, user_id=user_id)
|
|
750
|
-
memory = memory + self._build_context(related_memories)
|
|
751
|
-
self.memory.store(data=query, user_id=user_id)
|
|
752
|
-
if self.self_learning:
|
|
753
|
-
agent_memories = self.memory.retrieve(query=query, user_id=self.name)
|
|
754
|
-
memory = memory + self._build_agent_memory(agent_memories)
|
|
755
|
-
self.memory.store(data=query, user_id=self.name)
|
|
756
|
-
|
|
757
|
-
query = f"{memory}\n##用户提问:\n{query}"
|
|
758
|
-
# print(query)
|
|
659
|
+
system_prompt = (
|
|
660
|
+
f"##代理名称:{self.name}\n"
|
|
661
|
+
f"##代理指令:{self.instructions}\n"
|
|
662
|
+
f"##身份:{self.role}\n"
|
|
663
|
+
f"请一步一步思考来完成用户的要求。尽可能完成用户的回答,如果有补充信息,请参考补充信息来调用工具,直到获取所有满足用户的提问所需的答案。\n"
|
|
664
|
+
f"今日的日期: {current_date} 当前时间: {current_time}"
|
|
665
|
+
)
|
|
666
|
+
# 添加记忆上下文
|
|
667
|
+
query = self._add_memory_context(query, user_id)
|
|
759
668
|
|
|
760
|
-
#
|
|
669
|
+
# 思维链处理
|
|
761
670
|
active_tools = []
|
|
762
671
|
if self.tree_of_thought:
|
|
763
|
-
tot_response, active_tools = self.run_thought(query
|
|
764
|
-
system_prompt
|
|
765
|
-
self.log("DEBUG", "tree_of_thought", {"response": tot_response, "active_tools": active_tools})
|
|
672
|
+
tot_response, active_tools = self.run_thought(query)
|
|
673
|
+
system_prompt += f"\n##以下是问题的补充说明\n{tot_response}"
|
|
674
|
+
self.logger.log("DEBUG", "tree_of_thought", {"response": tot_response, "active_tools": active_tools})
|
|
766
675
|
|
|
767
|
-
#
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
# tools = get_tools() # v0.2.X的工具选取机制
|
|
676
|
+
# 准备API参数
|
|
677
|
+
params = {
|
|
678
|
+
"model": self.model,
|
|
679
|
+
"messages": [{"role": "system", "content": system_prompt}] + history + [{"role": "user", "content": query}],
|
|
680
|
+
"stream": stream
|
|
681
|
+
}
|
|
682
|
+
|
|
683
|
+
# 添加工具
|
|
684
|
+
tools = active_tools or self.tool_registry.get_tools()
|
|
777
685
|
if tools:
|
|
778
|
-
self.log("DEBUG", "register_tools", {"tools": list(_FUNCTION_MAPPINGS.keys())})
|
|
779
|
-
self.log("DEBUG", "active_tools", {"tools": tools})
|
|
780
686
|
params["tools"] = tools
|
|
781
687
|
params["tool_choice"] = "auto"
|
|
782
688
|
|
|
783
|
-
#
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
for item in history:
|
|
787
|
-
params["messages"].append({"role": item["role"], "content": item["content"]})
|
|
788
|
-
# 最后添加当前用户的查询信息
|
|
789
|
-
params["messages"].append({"role": "user", "content": query})
|
|
689
|
+
# 添加跟踪会话
|
|
690
|
+
if hasattr(self, 'tracetools') and self.tracetools:
|
|
691
|
+
params["session_id"] = traceid
|
|
790
692
|
|
|
693
|
+
# 调用模型
|
|
791
694
|
response = self.client.chat.completions.create(**params)
|
|
695
|
+
return self._core_run_logic(response, params, stream, max_retry)
|
|
696
|
+
|
|
697
|
+
def _add_memory_context(self, query: str, user_id: str) -> str:
|
|
698
|
+
"""添加记忆上下文"""
|
|
699
|
+
if not self.memory:
|
|
700
|
+
return query
|
|
701
|
+
|
|
702
|
+
context = ""
|
|
703
|
+
related_memories = self.memory.retrieve(query=query, user_id=user_id)
|
|
704
|
+
if related_memories and related_memories.get("results"):
|
|
705
|
+
context += "\n##用户偏好\n用户之前提到了:\n" + "\n".join(
|
|
706
|
+
[m["memory"] for m in related_memories["results"]]
|
|
707
|
+
)
|
|
708
|
+
self.memory.store(data=query, user_id=user_id)
|
|
792
709
|
|
|
793
|
-
|
|
710
|
+
if self.self_learning:
|
|
711
|
+
agent_memories = self.memory.retrieve(query=query, user_id=self.name)
|
|
712
|
+
if agent_memories and agent_memories.get("results"):
|
|
713
|
+
context += "\n##问题相关补充信息:\n" + "\n".join(
|
|
714
|
+
[m["memory"] for m in agent_memories["results"]]
|
|
715
|
+
)
|
|
716
|
+
self.memory.store(data=query, user_id=self.name)
|
|
794
717
|
|
|
795
|
-
return
|
|
718
|
+
return f"{context}\n##用户提问:\n{query}" if context else query
|
|
796
719
|
|
|
797
|
-
def
|
|
720
|
+
def _core_run_logic(self, response, params, stream, max_retry) -> Union[Generator[str, None, None], str]:
|
|
721
|
+
"""核心运行逻辑"""
|
|
722
|
+
if stream:
|
|
723
|
+
return self._run_stream_logic(response, params, max_retry)
|
|
724
|
+
else:
|
|
725
|
+
return self._run_non_stream_logic(response, params, max_retry)
|
|
726
|
+
|
|
727
|
+
def _run_non_stream_logic(self, response, params, max_retry) -> Union[str, None]:
|
|
798
728
|
"""
|
|
799
729
|
非流式处理逻辑。
|
|
800
730
|
"""
|
|
@@ -806,7 +736,7 @@ class LightAgent:
|
|
|
806
736
|
output = ""
|
|
807
737
|
function_call_name = ""
|
|
808
738
|
tool_calls = response.choices[0].message.tool_calls
|
|
809
|
-
self.log("DEBUG", "non_stream tool_calls", {"tool_calls": tool_calls})
|
|
739
|
+
self.logger.log("DEBUG", "non_stream tool_calls", {"tool_calls": tool_calls})
|
|
810
740
|
|
|
811
741
|
# 遍历所有工具调用
|
|
812
742
|
for tool_call in tool_calls:
|
|
@@ -814,13 +744,13 @@ class LightAgent:
|
|
|
814
744
|
|
|
815
745
|
# 尝试自动修复常见转义问题
|
|
816
746
|
fixed_args = function_call.arguments.replace('\\"', '"').replace('\\\\', '\\')
|
|
817
|
-
self.log("DEBUG", "non_stream function_call", {"function_call": fixed_args})
|
|
747
|
+
self.logger.log("DEBUG", "non_stream function_call", {"function_call": fixed_args})
|
|
818
748
|
|
|
819
749
|
# 解析函数参数
|
|
820
750
|
function_args = json.loads(fixed_args)
|
|
821
751
|
|
|
822
752
|
# 调用工具并获取响应
|
|
823
|
-
tool_response = asyncio.run(
|
|
753
|
+
tool_response = asyncio.run(self.tool_dispatcher.dispatch(function_call.name, function_args))
|
|
824
754
|
function_call_name = function_call.name
|
|
825
755
|
combined_response = ""
|
|
826
756
|
single_tool_response = ""
|
|
@@ -863,13 +793,13 @@ class LightAgent:
|
|
|
863
793
|
pass # 如果不是 JSON 字符串,保持原样
|
|
864
794
|
single_tool_response = combined_response # 处理单个工具的方法
|
|
865
795
|
|
|
866
|
-
self.log("INFO", "non_stream single_tool_response", {"single_tool_response": single_tool_response})
|
|
796
|
+
self.logger.log("INFO", "non_stream single_tool_response", {"single_tool_response": single_tool_response})
|
|
867
797
|
|
|
868
798
|
# 将单个工具的响应结果添加到列表中
|
|
869
799
|
tool_responses.append(single_tool_response)
|
|
870
800
|
|
|
871
801
|
# 将所有工具调用的结果合并为一个字符串
|
|
872
|
-
self.log("DEBUG", "non_stream tool_responses", {"tool_responses": tool_responses})
|
|
802
|
+
self.logger.log("DEBUG", "non_stream tool_responses", {"tool_responses": tool_responses})
|
|
873
803
|
|
|
874
804
|
combined_tool_response = "\n".join(tool_responses)
|
|
875
805
|
|
|
@@ -889,14 +819,14 @@ class LightAgent:
|
|
|
889
819
|
else:
|
|
890
820
|
# 返回最终回复
|
|
891
821
|
reply = response.choices[0].message.content
|
|
892
|
-
self.log("INFO", "non_stream final_reply", {"reply": reply})
|
|
822
|
+
self.logger.log("INFO", "non_stream final_reply", {"reply": reply})
|
|
893
823
|
return reply
|
|
894
824
|
|
|
895
825
|
# 更新响应
|
|
896
826
|
if function_call_name == 'finish':
|
|
897
827
|
return # 如果最后调用了finish工具,则结束生成器
|
|
898
828
|
# print("params:",params)
|
|
899
|
-
self.log("DEBUG", "non_stream chat-completions params", {"params": params})
|
|
829
|
+
self.logger.log("DEBUG", "non_stream chat-completions params", {"params": params})
|
|
900
830
|
|
|
901
831
|
try:
|
|
902
832
|
response = self.client.chat.completions.create(**params)
|
|
@@ -904,20 +834,17 @@ class LightAgent:
|
|
|
904
834
|
print(f"An error occurred: {e}")
|
|
905
835
|
|
|
906
836
|
# 重试次数用尽
|
|
907
|
-
self.log("ERROR", "max_retry_reached", {"message": "Failed to generate a valid response."})
|
|
837
|
+
self.logger.log("ERROR", "max_retry_reached", {"message": "Failed to generate a valid response."})
|
|
908
838
|
return "Failed to generate a valid response."
|
|
909
839
|
|
|
910
|
-
def
|
|
911
|
-
"""
|
|
912
|
-
流式处理逻辑。
|
|
913
|
-
"""
|
|
840
|
+
def _run_stream_logic(self, response, params, max_retry) -> Generator[str, None, None]:
|
|
841
|
+
"""流式处理逻辑"""
|
|
914
842
|
for _ in range(max_retry):
|
|
915
843
|
# 初始化变量
|
|
916
844
|
output = ""
|
|
917
|
-
function_call_name = ""
|
|
918
|
-
function_call_arguments = ""
|
|
919
845
|
tool_calls = [] # 用于存储所有工具调用的信息
|
|
920
846
|
tool_responses = [] # 用于存储所有工具调用的结果
|
|
847
|
+
finish_called = False # 标记是否调用了finish工具
|
|
921
848
|
|
|
922
849
|
for chunk in response:
|
|
923
850
|
content = chunk.choices[0].delta.content or ""
|
|
@@ -936,7 +863,7 @@ class LightAgent:
|
|
|
936
863
|
|
|
937
864
|
# 如果工具调用信息尚未记录,初始化一个空字典
|
|
938
865
|
if len(tool_calls) <= tool_call_index:
|
|
939
|
-
tool_calls.append({"name": "", "arguments": "", "index": tool_call_index})
|
|
866
|
+
tool_calls.append({"name": "", "arguments": "", "index": tool_call_index, "title": ""})
|
|
940
867
|
|
|
941
868
|
# 更新工具调用的名称
|
|
942
869
|
if hasattr(tool_call_delta.function, "name") and tool_call_delta.function.name:
|
|
@@ -946,35 +873,48 @@ class LightAgent:
|
|
|
946
873
|
if hasattr(tool_call_delta.function, "arguments") and tool_call_delta.function.arguments:
|
|
947
874
|
tool_calls[tool_call_index]["arguments"] += tool_call_delta.function.arguments
|
|
948
875
|
|
|
949
|
-
except (IndexError, AttributeError, KeyError):
|
|
950
|
-
|
|
876
|
+
except (IndexError, AttributeError, KeyError) as e:
|
|
877
|
+
self.logger.log("ERROR", "tool_call_error", {
|
|
878
|
+
"error": str(e),
|
|
879
|
+
"traceback": traceback.format_exc()
|
|
880
|
+
})
|
|
951
881
|
|
|
952
882
|
# 如果流式输出结束
|
|
953
|
-
|
|
954
|
-
|
|
883
|
+
finish_reason = chunk.choices[0].finish_reason if chunk.choices else None
|
|
884
|
+
if finish_reason == "stop" and not any(tc["name"] for tc in tool_calls):
|
|
885
|
+
self.logger.log("INFO", "stream_response", {"output": output})
|
|
955
886
|
return # 结束生成器
|
|
956
887
|
|
|
957
888
|
# 如果工具调用结束
|
|
958
|
-
elif
|
|
959
|
-
chunk.choices[0].finish_reason == "stop" and any(
|
|
960
|
-
tool_call["name"] for tool_call in tool_calls)):
|
|
889
|
+
elif finish_reason in ("tool_calls", "stop") and any(tc["name"] for tc in tool_calls):
|
|
961
890
|
# 遍历所有工具调用
|
|
962
|
-
self.log("DEBUG", "stream tool_calls", {"tool_calls": tool_calls})
|
|
891
|
+
self.logger.log("DEBUG", "stream tool_calls", {"tool_calls": tool_calls})
|
|
963
892
|
for tool_call in tool_calls:
|
|
964
893
|
if tool_call["name"]: # 确保工具调用有名称
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
894
|
+
tool_name = tool_call["name"]
|
|
895
|
+
arguments = tool_call["arguments"]
|
|
896
|
+
|
|
897
|
+
# 从注册表中获取工具标题
|
|
898
|
+
tool_info = self.tool_registry.function_info.get(tool_name, {})
|
|
899
|
+
tool_title = tool_info.get("tool_title") or ""
|
|
900
|
+
|
|
901
|
+
# 更新工具调用信息
|
|
902
|
+
tool_call["title"] = tool_title
|
|
903
|
+
|
|
904
|
+
# 记录调用工具
|
|
905
|
+
tool_call_info = {
|
|
906
|
+
"name": tool_name,
|
|
907
|
+
"title": tool_title,
|
|
908
|
+
"arguments": arguments,
|
|
969
909
|
}
|
|
970
|
-
self.log("INFO", "stream function_call", {"
|
|
910
|
+
self.logger.log("INFO", "stream function_call", {"tool_call_start": tool_call_info})
|
|
971
911
|
# 将工具的调用信息推送给开发者
|
|
972
|
-
yield
|
|
912
|
+
yield tool_call_info
|
|
973
913
|
|
|
974
914
|
# 解析参数并调用工具
|
|
975
915
|
try:
|
|
976
916
|
# 使用正则表达式将多个 JSON 对象拆分开
|
|
977
|
-
json_objects = re.findall(r'\{.*?\}',
|
|
917
|
+
json_objects = re.findall(r'\{.*?\}', tool_call_info["arguments"])
|
|
978
918
|
|
|
979
919
|
# 解析每个 JSON 对象并调用工具
|
|
980
920
|
# for json_obj in json_objects:
|
|
@@ -985,12 +925,15 @@ class LightAgent:
|
|
|
985
925
|
for json_obj in json_objects:
|
|
986
926
|
# 尝试自动修复常见转义问题
|
|
987
927
|
fixed_args = json_obj.replace('\\"', '"').replace('\\\\', '\\')
|
|
988
|
-
self.log("DEBUG", "stream fixed_args", {"fixed_args": fixed_args})
|
|
928
|
+
self.logger.log("DEBUG", "stream fixed_args", {"fixed_args": fixed_args})
|
|
989
929
|
|
|
930
|
+
# 解析参数
|
|
990
931
|
function_args = json.loads(fixed_args)
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
932
|
+
|
|
933
|
+
# 调用工具
|
|
934
|
+
tool_response = asyncio.run(self.tool_dispatcher.dispatch(tool_name, function_args))
|
|
935
|
+
|
|
936
|
+
# 处理不同类型的工具响应
|
|
994
937
|
combined_response = ""
|
|
995
938
|
single_tool_response = ""
|
|
996
939
|
|
|
@@ -1003,15 +946,14 @@ class LightAgent:
|
|
|
1003
946
|
yield chunk
|
|
1004
947
|
else:
|
|
1005
948
|
tool_output = {
|
|
1006
|
-
"name":
|
|
1007
|
-
"title":
|
|
1008
|
-
'tool_title') or '',
|
|
949
|
+
"name": tool_name,
|
|
950
|
+
"title": tool_title,
|
|
1009
951
|
"output": chunk,
|
|
1010
952
|
}
|
|
1011
|
-
self.log("DEBUG", "stream tool_output", {"tool_output": tool_output})
|
|
953
|
+
self.logger.log("DEBUG", "stream tool_output", {"tool_output": tool_output})
|
|
1012
954
|
yield tool_output
|
|
1013
955
|
# 将工具的调用信息推送给开发者
|
|
1014
|
-
if
|
|
956
|
+
if tool_name == 'finish':
|
|
1015
957
|
content = chunk.choices[0].delta.content or ""
|
|
1016
958
|
combined_response += content # 将每个 chunk 叠加
|
|
1017
959
|
else:
|
|
@@ -1019,25 +961,54 @@ class LightAgent:
|
|
|
1019
961
|
single_tool_response = combined_response # 处理单个工具的方法
|
|
1020
962
|
else:
|
|
1021
963
|
# print(f"Non-streaming response from tool: {tool_response}")
|
|
1022
|
-
combined_response = tool_response
|
|
964
|
+
combined_response = str(tool_response)
|
|
1023
965
|
single_tool_response = combined_response # 处理单个工具的方法
|
|
1024
|
-
|
|
966
|
+
tool_output = {
|
|
967
|
+
"name": tool_name,
|
|
968
|
+
"title": tool_title,
|
|
969
|
+
"output": combined_response
|
|
970
|
+
}
|
|
971
|
+
yield tool_output
|
|
972
|
+
|
|
973
|
+
# 记录工具响应
|
|
974
|
+
self.logger.log("INFO", "stream single_tool_response",
|
|
1025
975
|
{"single_tool_response": single_tool_response})
|
|
1026
|
-
|
|
976
|
+
|
|
977
|
+
# 将单个工具的响应结果保存到列表中
|
|
1027
978
|
tool_responses.append(single_tool_response)
|
|
1028
979
|
|
|
1029
|
-
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
980
|
+
# 检查是否调用了finish工具
|
|
981
|
+
if tool_name == 'finish':
|
|
982
|
+
finish_called = True
|
|
983
|
+
self.logger.log("INFO", "finish_tool_called", {"response": combined_response})
|
|
1033
984
|
|
|
1034
|
-
|
|
985
|
+
except json.JSONDecodeError as e:
|
|
986
|
+
error_msg = f"JSON解析错误: {str(e)}\n参数: {arguments}"
|
|
987
|
+
self.logger.log("ERROR", "json_decode_error", {"tool": tool_name, "title": tool_title, "error": error_msg})
|
|
988
|
+
tool_responses.append(error_msg)
|
|
989
|
+
yield {"name": tool_name, "title": tool_title, "error": error_msg}
|
|
990
|
+
|
|
991
|
+
except Exception as e:
|
|
992
|
+
error_msg = f"工具调用错误: {str(e)}\n{traceback.format_exc()}"
|
|
993
|
+
self.logger.log("ERROR", "tool_execution_error", {
|
|
994
|
+
"tool": tool_name,
|
|
995
|
+
"title": tool_title,
|
|
996
|
+
"error": error_msg
|
|
997
|
+
})
|
|
998
|
+
tool_responses.append(error_msg)
|
|
999
|
+
yield {"name": tool_name, "title": tool_title, "error": error_msg}
|
|
1000
|
+
|
|
1001
|
+
# 如果调用了finish工具,则结束处理
|
|
1002
|
+
if finish_called:
|
|
1003
|
+
return
|
|
1004
|
+
|
|
1005
|
+
# 准备下一轮请求
|
|
1035
1006
|
combined_tool_response = "\n".join(tool_responses)
|
|
1036
1007
|
tool_str = json.dumps(
|
|
1037
1008
|
[{"name": tool_call["name"], "arguments": tool_call["arguments"]} for tool_call in tool_calls],
|
|
1038
1009
|
ensure_ascii=False)
|
|
1039
1010
|
|
|
1040
|
-
#
|
|
1011
|
+
# 添加工具调用和响应到消息历史
|
|
1041
1012
|
params["messages"].append(
|
|
1042
1013
|
{
|
|
1043
1014
|
"role": "assistant",
|
|
@@ -1051,26 +1022,15 @@ class LightAgent:
|
|
|
1051
1022
|
}
|
|
1052
1023
|
)
|
|
1053
1024
|
|
|
1025
|
+
# 创建新的响应流
|
|
1026
|
+
self.logger.log("DEBUG", "stream next_request_params", {"params": params})
|
|
1027
|
+
response = self.client.chat.completions.create(**params)
|
|
1054
1028
|
break
|
|
1055
1029
|
|
|
1056
|
-
# 更新响应
|
|
1057
|
-
if function_call_name == 'finish':
|
|
1058
|
-
return # 如果最后调用了finish工具,则结束生成器
|
|
1059
|
-
self.log("DEBUG", "stream chat-completions params", {"params": params})
|
|
1060
|
-
response = self.client.chat.completions.create(**params)
|
|
1061
|
-
|
|
1062
1030
|
# 重试次数用尽
|
|
1063
|
-
self.log("ERROR", "max_retry_reached", {"message": "Failed to generate a valid response."})
|
|
1031
|
+
self.logger.log("ERROR", "max_retry_reached", {"message": "Failed to generate a valid response."})
|
|
1064
1032
|
yield "Failed to generate a valid response."
|
|
1065
1033
|
|
|
1066
|
-
def _core_run_logic(self, response, params, stream, max_retry) -> Union[Generator[str, None, None], str]:
|
|
1067
|
-
"""
|
|
1068
|
-
核心运行逻辑。
|
|
1069
|
-
"""
|
|
1070
|
-
if stream:
|
|
1071
|
-
return self._run_logic_stream(response, params, max_retry)
|
|
1072
|
-
else:
|
|
1073
|
-
return self._run_logic_non_stream(response, params, max_retry)
|
|
1074
1034
|
|
|
1075
1035
|
def _handle_task_transfer(
|
|
1076
1036
|
self,
|
|
@@ -1089,9 +1049,9 @@ class LightAgent:
|
|
|
1089
1049
|
intent = self._detect_intent(query, light_swarm)
|
|
1090
1050
|
if intent and intent.get("transfer_to"):
|
|
1091
1051
|
target_agent_name = intent["transfer_to"]
|
|
1092
|
-
self.log("INFO", "detect_intent", {"intent": intent})
|
|
1052
|
+
self.logger.log("INFO", "detect_intent", {"intent": intent})
|
|
1093
1053
|
if target_agent_name == self.name:
|
|
1094
|
-
self.log("INFO", "self_transfer_detected", {"target_agent": target_agent_name})
|
|
1054
|
+
self.logger.log("INFO", "self_transfer_detected", {"target_agent": target_agent_name})
|
|
1095
1055
|
return None # 如果是自身,直接返回 None
|
|
1096
1056
|
if stream:
|
|
1097
1057
|
return self._handle_task_transfer_stream(light_swarm.agents[target_agent_name], query, light_swarm)
|
|
@@ -1113,18 +1073,18 @@ class LightAgent:
|
|
|
1113
1073
|
:param light_swarm: LightSwarm 实例。
|
|
1114
1074
|
:return: 生成器,用于流式输出。
|
|
1115
1075
|
"""
|
|
1116
|
-
self.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
|
|
1076
|
+
self.logger.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
|
|
1117
1077
|
|
|
1118
1078
|
# 检查目标代理是否有效
|
|
1119
1079
|
if not hasattr(target_agent, 'run'):
|
|
1120
|
-
self.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
|
|
1080
|
+
self.logger.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
|
|
1121
1081
|
yield "Failed to transfer task: invalid target agent"
|
|
1122
1082
|
return
|
|
1123
1083
|
|
|
1124
1084
|
try:
|
|
1125
1085
|
yield from target_agent.run(context, light_swarm=light_swarm, stream=True)
|
|
1126
1086
|
except Exception as e:
|
|
1127
|
-
self.log("ERROR", "run_failed", {"error": str(e)})
|
|
1087
|
+
self.logger.log("ERROR", "run_failed", {"error": str(e)})
|
|
1128
1088
|
raise # 重新抛出异常以便调试
|
|
1129
1089
|
|
|
1130
1090
|
def _handle_task_transfer_non_stream(
|
|
@@ -1141,11 +1101,11 @@ class LightAgent:
|
|
|
1141
1101
|
:param light_swarm: LightSwarm 实例。
|
|
1142
1102
|
:return: 字符串,表示非流式输出结果。
|
|
1143
1103
|
"""
|
|
1144
|
-
self.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
|
|
1104
|
+
self.logger.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
|
|
1145
1105
|
|
|
1146
1106
|
# 检查目标代理是否有效
|
|
1147
1107
|
if not hasattr(target_agent, 'run'):
|
|
1148
|
-
self.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
|
|
1108
|
+
self.logger.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
|
|
1149
1109
|
return "Failed to transfer task: invalid target agent"
|
|
1150
1110
|
|
|
1151
1111
|
try:
|
|
@@ -1154,7 +1114,7 @@ class LightAgent:
|
|
|
1154
1114
|
return "".join(result) # 将生成器转换为字符串
|
|
1155
1115
|
return result
|
|
1156
1116
|
except Exception as e:
|
|
1157
|
-
self.log("ERROR", "run_failed", {"error": str(e)})
|
|
1117
|
+
self.logger.log("ERROR", "run_failed", {"error": str(e)})
|
|
1158
1118
|
raise # 重新抛出异常以便调试
|
|
1159
1119
|
|
|
1160
1120
|
def _build_context(self, related_memories):
|
|
@@ -1171,7 +1131,7 @@ class LightAgent:
|
|
|
1171
1131
|
return ""
|
|
1172
1132
|
|
|
1173
1133
|
prompt = f"\n##用户偏好 \n用户之前提到了\n{memory_context}。"
|
|
1174
|
-
self.log("DEBUG", "related_memories", {"memory_context": memory_context})
|
|
1134
|
+
self.logger.log("DEBUG", "related_memories", {"memory_context": memory_context})
|
|
1175
1135
|
return prompt
|
|
1176
1136
|
|
|
1177
1137
|
def _build_agent_memory(self, agent_memories):
|
|
@@ -1189,20 +1149,20 @@ class LightAgent:
|
|
|
1189
1149
|
return ""
|
|
1190
1150
|
|
|
1191
1151
|
prompt = f"\n##以下是解决该问题的相关补充信息:\n{memory_context}。"
|
|
1192
|
-
self.log("DEBUG", "agent_memories", {"memory_context": memory_context})
|
|
1152
|
+
self.logger.log("DEBUG", "agent_memories", {"memory_context": memory_context})
|
|
1193
1153
|
return prompt
|
|
1194
1154
|
|
|
1195
1155
|
def run_thought(self, query: str) -> tuple:
|
|
1196
1156
|
"""使用思维树的方式 让大模型先根据get_tools_str生成一个解答用户query的工具使用计划"""
|
|
1197
1157
|
tot_model = self.tot_model # self.model
|
|
1198
|
-
tools = get_tools_str()
|
|
1158
|
+
tools = self.tool_registry.get_tools_str()
|
|
1199
1159
|
if not isinstance(tools, str):
|
|
1200
1160
|
tools = str(tools) # 确保 tools 是字符串
|
|
1201
1161
|
now = datetime.now()
|
|
1202
1162
|
current_date = now.strftime("%Y-%m-%d")
|
|
1203
1163
|
current_time = now.strftime("%H:%M:%S")
|
|
1204
1164
|
system_prompt = f"""你是一个智能助手,请根据用户输入的问题,结合工具使用计划,生成一个思维树,并按照思维树依次调用工具步骤,最终生成一个最终回答。\n 今日的日期: {current_date} 当前时间: {current_time} \n 工具列表: {tools}"""
|
|
1205
|
-
self.log("DEBUG", "run_thought", {"system_prompt": system_prompt})
|
|
1165
|
+
self.logger.log("DEBUG", "run_thought", {"system_prompt": system_prompt})
|
|
1206
1166
|
|
|
1207
1167
|
try:
|
|
1208
1168
|
# 1. 第一次请求,生成初始的工具使用计划
|
|
@@ -1210,20 +1170,20 @@ class LightAgent:
|
|
|
1210
1170
|
messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": query}],
|
|
1211
1171
|
stream=False)
|
|
1212
1172
|
response = self.tot_client.chat.completions.create(**params)
|
|
1213
|
-
|
|
1214
|
-
self.log("DEBUG", "
|
|
1173
|
+
thought_response = response.choices[0].message.content
|
|
1174
|
+
self.logger.log("DEBUG", "thought_response", {"response": thought_response})
|
|
1215
1175
|
|
|
1216
1176
|
# 2. 第二次请求,请求大模型反思并生成新的工具使用规划
|
|
1217
1177
|
reflection_prompt = "请反思你的回答,请严格按照<工具列表>中的工具来规划,不可以创造其他新的工具。请输出新的任务规划,不要输出其他分析和回答。"
|
|
1218
1178
|
reflection_params = dict(model=tot_model, messages=[
|
|
1219
1179
|
{"role": "user", "content": f"{system_prompt} /n 开始思考问题: {query}"},
|
|
1220
|
-
{"role": "assistant", "content":
|
|
1180
|
+
{"role": "assistant", "content": thought_response},
|
|
1221
1181
|
{"role": "user", "content": reflection_prompt}
|
|
1222
1182
|
], stream=False)
|
|
1223
|
-
self.log("DEBUG", "reflection_params", {"params": reflection_params})
|
|
1183
|
+
self.logger.log("DEBUG", "reflection_params", {"params": reflection_params})
|
|
1224
1184
|
reflection_response = self.tot_client.chat.completions.create(**reflection_params)
|
|
1225
1185
|
refined_content = reflection_response.choices[0].message.content
|
|
1226
|
-
self.log("DEBUG", "
|
|
1186
|
+
self.logger.log("DEBUG", "reflection_response", {"response": refined_content})
|
|
1227
1187
|
|
|
1228
1188
|
# 获取工具的使用集合
|
|
1229
1189
|
tool_reflection_prompt = """请严格按以下要求执行:
|
|
@@ -1245,21 +1205,21 @@ class LightAgent:
|
|
|
1245
1205
|
stream=False
|
|
1246
1206
|
)
|
|
1247
1207
|
|
|
1248
|
-
self.log("DEBUG", "tool_reflection_params", {"params": tool_reflection_params})
|
|
1208
|
+
self.logger.log("DEBUG", "tool_reflection_params", {"params": tool_reflection_params})
|
|
1249
1209
|
tool_reflection_response = self.tot_client.chat.completions.create(**tool_reflection_params)
|
|
1250
1210
|
tool_reflection_result = tool_reflection_response.choices[0].message.content
|
|
1251
|
-
self.log("DEBUG", "tool_reflection_result", {"result": tool_reflection_result})
|
|
1211
|
+
self.logger.log("DEBUG", "tool_reflection_result", {"result": tool_reflection_result})
|
|
1252
1212
|
|
|
1253
1213
|
# 3.执行自适应工具过滤
|
|
1254
1214
|
current_tools = []
|
|
1255
1215
|
if self.filter_tools:
|
|
1256
|
-
current_tools =
|
|
1257
|
-
self.log("DEBUG", "current_tools", {"get_tools": current_tools})
|
|
1216
|
+
current_tools = self.tool_registry.filter_tools(tool_reflection_result)
|
|
1217
|
+
self.logger.log("DEBUG", "current_tools", {"get_tools": current_tools})
|
|
1258
1218
|
|
|
1259
1219
|
return refined_content, current_tools
|
|
1260
1220
|
|
|
1261
1221
|
except Exception as e:
|
|
1262
|
-
self.log("ERROR", "run_thought_failure", {"error": str(e)})
|
|
1222
|
+
self.logger.log("ERROR", "run_thought_failure", {"error": str(e)})
|
|
1263
1223
|
raise RuntimeError(f"思维链执行失败: {str(e)}") from e
|
|
1264
1224
|
|
|
1265
1225
|
def _detect_intent(self, query: str, light_swarm=None) -> Optional[Dict]:
|
|
@@ -1296,7 +1256,7 @@ class LightAgent:
|
|
|
1296
1256
|
messages=[{"role": "system", "content": prompt}]
|
|
1297
1257
|
)
|
|
1298
1258
|
intent = response.choices[0].message.content
|
|
1299
|
-
self.log("DEBUG", "detect_intent", {"intent": intent})
|
|
1259
|
+
self.logger.log("DEBUG", "detect_intent", {"intent": intent})
|
|
1300
1260
|
|
|
1301
1261
|
# # 使用正则表达式解析意图
|
|
1302
1262
|
# match = re.search(r"transfer to (\w+)", intent, re.IGNORECASE)
|
|
@@ -1328,11 +1288,11 @@ class LightAgent:
|
|
|
1328
1288
|
:param stream: 是否启用流式输出。
|
|
1329
1289
|
:return: 如果 stream=True,返回生成器;否则返回完整结果字符串。
|
|
1330
1290
|
"""
|
|
1331
|
-
self.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
|
|
1291
|
+
self.logger.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
|
|
1332
1292
|
|
|
1333
1293
|
# 检查目标代理是否有效
|
|
1334
1294
|
if not hasattr(target_agent, 'run'):
|
|
1335
|
-
self.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
|
|
1295
|
+
self.logger.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
|
|
1336
1296
|
return "Failed to transfer task: invalid target agent"
|
|
1337
1297
|
#
|
|
1338
1298
|
# # 调用目标代理的 run 方法
|
|
@@ -1352,7 +1312,7 @@ class LightAgent:
|
|
|
1352
1312
|
return "".join(result) # 将生成器转换为字符串
|
|
1353
1313
|
return result
|
|
1354
1314
|
except Exception as e:
|
|
1355
|
-
self.log("ERROR", "run_failed", {"error": str(e)})
|
|
1315
|
+
self.logger.log("ERROR", "run_failed", {"error": str(e)})
|
|
1356
1316
|
raise # 重新抛出异常以便调试
|
|
1357
1317
|
|
|
1358
1318
|
def create_tool(self, user_input: str, tools_directory: str = "tools"):
|
|
@@ -1429,19 +1389,19 @@ get_weather.tool_info = {
|
|
|
1429
1389
|
tool_code = tool_data.get("tool_code")
|
|
1430
1390
|
|
|
1431
1391
|
if not tool_name or not tool_code:
|
|
1432
|
-
self.log("ERROR", "invalid_tool_data", {"tool_data": tool_data})
|
|
1392
|
+
self.logger.log("ERROR", "invalid_tool_data", {"tool_data": tool_data})
|
|
1433
1393
|
continue
|
|
1434
1394
|
|
|
1435
1395
|
# 保存生成的代码到 tools 目录
|
|
1436
1396
|
tool_path = os.path.join(tools_directory, f"{tool_name}.py")
|
|
1437
1397
|
with open(tool_path, "w", encoding="utf-8") as f:
|
|
1438
1398
|
f.write(tool_code)
|
|
1439
|
-
self.log("INFO", "tool_created", {"tool_name": tool_name, "tool_path": tool_path})
|
|
1399
|
+
self.logger.log("INFO", "tool_created", {"tool_name": tool_name, "tool_path": tool_path})
|
|
1440
1400
|
|
|
1441
1401
|
# 自动加载新创建的工具
|
|
1442
1402
|
self.load_tools([tool_name], tools_directory)
|
|
1443
1403
|
except Exception as e:
|
|
1444
|
-
self.log("ERROR", "tool_creation_failed", {"error": str(e)})
|
|
1404
|
+
self.logger.log("ERROR", "tool_creation_failed", {"error": str(e)})
|
|
1445
1405
|
|
|
1446
1406
|
|
|
1447
1407
|
class LightSwarm:
|
|
@@ -1457,11 +1417,11 @@ class LightSwarm:
|
|
|
1457
1417
|
for agent in agents:
|
|
1458
1418
|
if agent.name in self.agents:
|
|
1459
1419
|
# print(f"Agent '{agent.name}' is already registered.")
|
|
1460
|
-
agent.log("INFO", "register_agent", {"agent_name": agent.name, "status": "already_registered"})
|
|
1420
|
+
agent.logger.log("INFO", "register_agent", {"agent_name": agent.name, "status": "already_registered"})
|
|
1461
1421
|
else:
|
|
1462
1422
|
self.agents[agent.name] = agent
|
|
1463
1423
|
# print(f"Agent '{agent.name}' registered.")
|
|
1464
|
-
agent.log("INFO", "register_agent", {"agent_name": agent.name, "status": "registered"})
|
|
1424
|
+
agent.logger.log("INFO", "register_agent", {"agent_name": agent.name, "status": "registered"})
|
|
1465
1425
|
|
|
1466
1426
|
def run(self, agent: LightAgent, query: str, stream=False):
|
|
1467
1427
|
"""
|