LightAgent 0.3.3__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,8 +3,8 @@
3
3
 
4
4
  """
5
5
  作者: [weego/WXAI-Team]
6
- 版本: 0.3.3
7
- 最后更新: 2025-05-05
6
+ 版本: 0.4.0
7
+ 最后更新: 2025-06-12
8
8
  """
9
9
 
10
10
  import asyncio
@@ -21,7 +21,7 @@ from contextlib import AsyncExitStack
21
21
  from copy import deepcopy
22
22
  from datetime import datetime
23
23
  from functools import partial
24
- from typing import List, Dict, Any, Callable, Union, Optional, Generator, AsyncGenerator
24
+ from typing import List, Dict, Any, Callable, Union, Optional, Generator, AsyncGenerator, Protocol
25
25
  from uuid import uuid4
26
26
 
27
27
  import httpx
@@ -30,34 +30,39 @@ from mcp.client.sse import sse_client
30
30
  from mcp.client.stdio import stdio_client
31
31
  from openai.types.chat import ChatCompletionChunk
32
32
 
33
- # 全局工具注册表
34
- _FUNCTION_MAPPINGS = {} # 工具名称 -> 工具函数
35
- _FUNCTION_INFO = {} # 工具名称 -> 工具info信息
36
- _OPENAI_FUNCTION_SCHEMAS = [] # OpenAI 格式的工具描述
37
- _PROMPT_FUNCTION_SCHEMAS = [] # prompt 格式的工具描述
38
33
 
39
- __version__ = "0.3.3" # 你可以根据需要设置版本号
34
+ __version__ = "0.4.0" # 你可以根据需要设置版本号
40
35
 
41
36
 
42
37
  # openai.langfuse_auth_check()
43
38
 
39
+ # 1. 定义内存接口协议
40
+ class MemoryProtocol(Protocol):
41
+ def store(self, data: str, user_id: str) -> Any:
42
+ ...
44
43
 
45
- def register_tool_manually(tools: List[Union[str, Callable]]) -> bool:
46
- """
47
- 手动注册多个工具,从函数属性中提取工具信息
48
- :param tools: 工具函数列表
49
- """
50
- for func in tools:
44
+ def retrieve(self, query: str, user_id: str) -> List[Any]:
45
+ ...
46
+
47
+ class ToolRegistry:
48
+ """集中管理工具注册表,避免全局变量"""
49
+
50
+ def __init__(self):
51
+ self.function_mappings = {} # 工具名称 -> 工具函数
52
+ self.function_info = {} # 工具名称 -> 工具info信息
53
+ self.openai_function_schemas = [] # OpenAI 格式的工具描述
54
+
55
+ def register_tool(self, func: Callable) -> bool:
56
+ """注册单个工具"""
51
57
  if not hasattr(func, "tool_info"):
52
- # raise ValueError(f"Function `{func.__name__}` does not have tool_info attribute.")
53
- continue
58
+ return False
54
59
 
55
60
  tool_info = func.tool_info
56
61
  tool_name = tool_info["tool_name"]
57
62
 
58
- # 注册到全局字典
59
- _FUNCTION_INFO[tool_name] = tool_info
60
- _FUNCTION_MAPPINGS[tool_name] = func # 注册工具
63
+ # 注册到字典
64
+ self.function_info[tool_name] = tool_info
65
+ self.function_mappings[tool_name] = func
61
66
 
62
67
  # 构建 OpenAI 格式的工具描述
63
68
  tool_params_openai = {}
@@ -83,174 +88,194 @@ def register_tool_manually(tools: List[Union[str, Callable]]) -> bool:
83
88
  }
84
89
  }
85
90
 
86
- _OPENAI_FUNCTION_SCHEMAS.append(tool_def_openai)
87
- return True
91
+ self.openai_function_schemas.append(tool_def_openai)
92
+ return True
88
93
 
94
+ def register_tools(self, tools: List[Callable]) -> bool:
95
+ """批量注册工具"""
96
+ success = True
97
+ for func in tools:
98
+ if not self.register_tool(func):
99
+ success = False
100
+ return success
89
101
 
90
- def load_tool(tool_name: str, tools_directory: str = "tools"):
91
- """
92
- 根据工具名称从 tools 目录中加载对应的工具函数
93
- """
94
- tool_path = os.path.join(tools_directory, f"{tool_name}.py")
95
- if not os.path.exists(tool_path):
96
- raise FileNotFoundError(f"Tool '{tool_name}' not found in {tools_directory}")
102
+ def get_tools(self) -> List[Dict[str, Any]]:
103
+ """获取所有工具的描述(OpenAI 格式)"""
104
+ return deepcopy(self.openai_function_schemas)
97
105
 
98
- # 动态加载模块
99
- spec = importlib.util.spec_from_file_location(tool_name, tool_path)
100
- module = importlib.util.module_from_spec(spec)
101
- spec.loader.exec_module(module)
106
+ def get_tools_str(self) -> str:
107
+ """将工具描述转换为格式化的 JSON 字符串"""
108
+ return json.dumps(self.openai_function_schemas, indent=4, ensure_ascii=False)
102
109
 
103
- # 获取工具函数
104
- if hasattr(module, tool_name):
105
- tool_func = getattr(module, tool_name)
106
- if callable(tool_func) and hasattr(tool_func, "tool_info"):
107
- return tool_func
108
- raise AttributeError(f"Tool '{tool_name}' is not properly defined in {tool_path}")
110
+ def filter_tools(self, tool_reflection_result: str) -> List[Dict]:
111
+ """根据内容过滤工具"""
112
+ try:
113
+ # 安全解析可能包含 Markdown 代码块的 JSON
114
+ refined_content = tool_reflection_result.strip()
115
+ if refined_content.startswith('```json') and refined_content.endswith('```'):
116
+ refined_content = refined_content[7:-3].strip()
109
117
 
118
+ parsed_data = json.loads(refined_content)
119
+ valid_tools = {tool["name"].strip().lower() for tool in parsed_data.get("tools", [])}
110
120
 
111
- async def dispatch_tool(tool_name: str, tool_params: Dict[str, Any]) -> Union[
112
- str, Generator[str, None, None], AsyncGenerator[str, None]]:
113
- """
114
- 调用工具执行,支持同步/异步工具及流式输出。
115
- """
116
- if tool_name not in _FUNCTION_MAPPINGS:
117
- return f"Tool `{tool_name}` not found."
121
+ return [
122
+ schema for schema in self.openai_function_schemas
123
+ if isinstance(schema, dict) and
124
+ schema.get("function", {}).get("name", "").strip().lower() in valid_tools
125
+ ]
126
+ except (json.JSONDecodeError, KeyError, AttributeError) as e:
127
+ raise ValueError(f"工具过滤失败: {str(e)}") from e
118
128
 
119
- tool_call = _FUNCTION_MAPPINGS[tool_name]
120
- try:
121
- # 处理不同类型的流式输出
122
- # 区分同步/异步工具
123
- if inspect.iscoroutinefunction(tool_call):
124
- # result = await tool_call(**tool_params)
125
- # 将参数以字典形式传递给包装器
126
- result = await tool_call(**tool_params) if inspect.iscoroutinefunction(tool_call) else tool_call(
127
- **tool_params)
128
- else:
129
- result = tool_call(**tool_params)
130
129
 
131
- # 处理不同类型的流式输出
132
- if inspect.isasyncgen(result):
133
- return async_stream_generator(result)
134
- elif inspect.isgenerator(result):
135
- return stream_generator(result)
136
- else:
130
+ class ToolLoader:
131
+ """工具加载器,支持动态加载和缓存"""
132
+
133
+ def __init__(self, tools_directory: str = "tools"):
134
+ self.tools_directory = tools_directory
135
+ self.loaded_tools = {}
136
+
137
+ def load_tool(self, tool_name: str) -> Callable:
138
+ """加载单个工具"""
139
+ if tool_name in self.loaded_tools:
140
+ return self.loaded_tools[tool_name]
141
+
142
+ tool_path = os.path.join(self.tools_directory, f"{tool_name}.py")
143
+ if not os.path.exists(tool_path):
144
+ raise FileNotFoundError(f"Tool '{tool_name}' not found in {tool_path}")
145
+
146
+ # 动态加载模块
147
+ spec = importlib.util.spec_from_file_location(tool_name, tool_path)
148
+ module = importlib.util.module_from_spec(spec)
149
+ spec.loader.exec_module(module)
150
+
151
+ # 获取工具函数
152
+ if hasattr(module, tool_name):
153
+ tool_func = getattr(module, tool_name)
154
+ if callable(tool_func) and hasattr(tool_func, "tool_info"):
155
+ self.loaded_tools[tool_name] = tool_func
156
+ return tool_func
157
+
158
+ raise AttributeError(f"Tool '{tool_name}' is not properly defined in {tool_path}")
159
+
160
+ def load_tools(self, tool_names: List[str]) -> Dict[str, Callable]:
161
+ """批量加载工具"""
162
+ for tool_name in tool_names:
163
+ if tool_name not in self.loaded_tools:
164
+ self.load_tool(tool_name)
165
+ return self.loaded_tools
166
+
167
+
168
+ class AsyncToolDispatcher:
169
+ """异步工具调度器"""
170
+
171
+ async def dispatch(self, tool_name: str, tool_params: Dict[str, Any]) -> Union[
172
+ str, Generator[str, None, None], AsyncGenerator[str, None]]:
173
+ """调用工具执行,支持同步/异步工具及流式输出"""
174
+ if tool_name not in self.function_mappings:
175
+ return f"Tool `{tool_name}` not found."
176
+
177
+ tool_call = self.function_mappings[tool_name]
178
+ try:
179
+ # 处理不同类型的工具
180
+ if inspect.iscoroutinefunction(tool_call):
181
+ result = await tool_call(**tool_params)
182
+ elif inspect.isasyncgenfunction(tool_call):
183
+ result = tool_call(**tool_params)
184
+ else:
185
+ result = tool_call(**tool_params)
186
+
187
+ # 处理流式输出
188
+ if inspect.isasyncgen(result):
189
+ return self.async_stream_generator(result)
190
+ elif inspect.isgenerator(result):
191
+ return self.stream_generator(result)
137
192
  return str(result)
138
- except Exception as e:
139
- return traceback.format_exc()
193
+ except Exception as e:
194
+ return f"Tool call error: {str(e)}\n{traceback.format_exc()}"
140
195
 
196
+ async def async_stream_generator(self, async_gen: AsyncGenerator) -> AsyncGenerator[str, None]:
197
+ async for chunk in async_gen:
198
+ yield chunk
141
199
 
142
- async def async_stream_generator(async_gen: AsyncGenerator) -> AsyncGenerator[str, None]:
143
- async for chunk in async_gen:
144
- yield chunk
200
+ def stream_generator(self, sync_gen: Generator) -> Generator[str, None, None]:
201
+ for chunk in sync_gen:
202
+ yield chunk
145
203
 
146
204
 
147
- def stream_generator(sync_gen: Generator) -> Generator[str, None, None]:
148
- for chunk in sync_gen:
149
- yield chunk
205
+ class LoggerManager:
206
+ """集中管理日志系统"""
150
207
 
208
+ def __init__(self, name: str, debug: bool, log_level: str, log_file: Optional[str] = None):
209
+ self.name = name
210
+ self.debug = debug
211
+ self.logger = self._setup_logger(log_level, log_file)
212
+ self.traceid = ""
151
213
 
152
- def dispatch_tool_old(tool_name: str, tool_params: Dict[str, Any]) -> str:
153
- """
154
- 调用工具执行
155
- """
156
- if tool_name not in _FUNCTION_MAPPINGS:
157
- return f"Tool `{tool_name}` not found."
214
+ def _setup_logger(self, log_level: str, log_file: Optional[str] = None) -> logging.Logger:
215
+ logger = logging.getLogger(self.name)
216
+ logger.setLevel(log_level.upper())
217
+ logger.propagate = False
158
218
 
159
- tool_call = _FUNCTION_MAPPINGS[tool_name]
160
- try:
161
- # print(f"Calling tool: {tool_name} with params: {tool_params}") # 调试信息
162
- return str(tool_call(**tool_params))
163
- except Exception as e:
164
- # print(f"Tool call failed: {e}") # 调试信息
165
- return traceback.format_exc()
166
-
167
-
168
- def get_tools() -> List[Dict[str, Any]]:
169
- """
170
- 获取所有工具的描述(OpenAI 格式)
171
- """
172
- return deepcopy(_OPENAI_FUNCTION_SCHEMAS)
173
-
174
-
175
- def get_tools_str() -> str:
176
- """
177
- 将 _OPENAI_FUNCTION_SCHEMAS 转换为格式化的 JSON 字符串。
178
- Returns:
179
- str: 格式化的 JSON 字符串。
180
- """
181
- # 使用 json.dumps 将字典转换为格式化的 JSON 字符串
182
- tools_str = json.dumps(_OPENAI_FUNCTION_SCHEMAS, indent=4, ensure_ascii=False)
183
- return tools_str
184
-
185
-
186
- def filter_tools_schemas(refined_content: str) -> json:
187
- """
188
- 根据refined_content中的工具列表过滤全局_OPENAI_FUNCTION_SCHEMAS
189
- :param refined_content: 包含工具列表的JSON字符串
190
- """
191
- # global _OPENAI_FUNCTION_SCHEMAS # 声明操作全局变量
192
- """安全解析可能包含 Markdown 代码块的 JSON"""
193
- refined_content = refined_content.strip()
194
- if refined_content.startswith('```json') and refined_content.endswith('```'):
195
- refined_content = refined_content[7:-3].strip() # 去除 ```json 和 ```
196
- try:
197
- # 解析工具列表
198
- parsed_data: Dict[str, List[Dict]] = json.loads(refined_content)
199
- valid_tools = {tool["name"].strip().lower() for tool in parsed_data.get("tools", [])}
200
-
201
- # 原地过滤操作
202
- filtered_schemas: List[Dict] = []
203
- for schema in _OPENAI_FUNCTION_SCHEMAS:
204
- if not isinstance(schema, dict):
205
- continue
219
+ formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
206
220
 
207
- # 深度检查结构
208
- function_info = schema.get("function", {})
209
- if isinstance(function_info, dict):
210
- schema_name = function_info.get("name", "").strip().lower()
211
- if schema_name in valid_tools:
212
- filtered_schemas.append(schema)
221
+ if self.debug:
222
+ console_handler = logging.StreamHandler()
223
+ console_handler.setFormatter(formatter)
224
+ logger.addHandler(console_handler)
213
225
 
214
- # 直接替换全局变量内容
215
- # _OPENAI_FUNCTION_SCHEMAS[:] = filtered_schemas
216
- return filtered_schemas
226
+ if log_file:
227
+ # 确保 log 目录存在
228
+ log_dir = os.path.dirname(log_file)
229
+ if log_dir and not os.path.exists(log_dir):
230
+ os.makedirs(log_dir)
217
231
 
218
- except (json.JSONDecodeError, KeyError, AttributeError) as e:
219
- # 错误处理:清空工具列表并记录日志
220
- # _OPENAI_FUNCTION_SCHEMAS.clear()
221
- raise ValueError(f"工具过滤失败: {str(e)}") from e
232
+ file_handler = logging.FileHandler(log_file)
233
+ file_handler.setFormatter(formatter)
234
+ logger.addHandler(file_handler)
235
+
236
+ return logger
237
+
238
+ def log(self, level: str, action: str, data: Any):
239
+ """记录日志"""
240
+ if not self.debug:
241
+ return
242
+
243
+ trace_info = f"[TraceID: {self.traceid}] " if self.traceid else ""
244
+ log_message = f"{trace_info}{action}: {data}"
245
+
246
+ if level == "DEBUG":
247
+ self.logger.debug(log_message)
248
+ elif level == "INFO":
249
+ self.logger.info(log_message)
250
+ elif level == "ERROR":
251
+ self.logger.error(log_message)
252
+
253
+ def set_traceid(self, traceid: str):
254
+ """设置当前跟踪ID"""
255
+ self.traceid = traceid
222
256
 
223
257
 
224
258
  class MCPClientManager:
225
259
  """增强版MCP客户端管理器"""
226
260
 
227
- def __init__(self, config: dict):
261
+ def __init__(self, config: dict, tool_registry: ToolRegistry):
228
262
  self.config = config
263
+ self.tool_registry = tool_registry
229
264
  self.session: Optional[ClientSession] = None
230
265
  self.exit_stack = AsyncExitStack()
231
- self._streams_context = None
232
- self._session_context = None
233
- self.server_sessions = {} # 存储不同服务器的会话
234
-
235
- async def _call_tool_wrapper(self, tool_name: str, target_server: str, **kwargs):
236
- """参数转换适配器"""
237
- return await self.call_tool(
238
- tool_name=tool_name,
239
- arguments=kwargs,
240
- target_server=target_server
241
- )
266
+ self.server_sessions = {}
242
267
 
243
268
  async def _create_session(self, server_name: str, config: dict):
244
269
  """创建并管理会话上下文"""
245
270
  if 'url' in config:
246
271
  # SSE 服务器连接
247
- self._streams_context = sse_client(
272
+ streams_context = sse_client(
248
273
  url=config['url'],
249
274
  headers=config.get('headers', {})
250
275
  )
251
- streams = await self.exit_stack.enter_async_context(self._streams_context)
252
- self._session_context = ClientSession(*streams)
253
- self.session = await self.exit_stack.enter_async_context(self._session_context)
276
+ streams = await self.exit_stack.enter_async_context(streams_context)
277
+ session_context = ClientSession(*streams)
278
+ self.session = await self.exit_stack.enter_async_context(session_context)
254
279
  else:
255
280
  # 标准输入输出服务器连接
256
281
  server_params = StdioServerParameters(
@@ -260,23 +285,19 @@ class MCPClientManager:
260
285
  )
261
286
  transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
262
287
  stdio, write = transport
263
- self._session_context = ClientSession(stdio, write)
264
- self.session = await self.exit_stack.enter_async_context(self._session_context)
288
+ session_context = ClientSession(stdio, write)
289
+ self.session = await self.exit_stack.enter_async_context(session_context)
265
290
 
266
291
  await self.session.initialize()
267
292
  self.server_sessions[server_name] = self.session
268
293
 
269
294
  async def cleanup(self):
270
295
  """清理所有会话资源"""
271
- await self.exit_stack.__aexit__(None, None, None)
296
+ await self.exit_stack.aclose()
272
297
  self.server_sessions.clear()
273
298
 
274
299
  async def register_mcp_tool(self) -> bool:
275
- """
276
- 自动注册所有MCP服务的工具到全局字典
277
- :param config: MCP服务配置(与call_tool使用的相同配置)
278
- :return: 是否至少成功注册一个工具
279
- """
300
+ """自动注册所有MCP服务的工具"""
280
301
  registered_count = 0
281
302
  enabled_servers = [
282
303
  (name, config)
@@ -286,15 +307,10 @@ class MCPClientManager:
286
307
 
287
308
  for server_name, config in enabled_servers:
288
309
  try:
289
- # 创建会话连接
290
- # print(server_name,config)
291
310
  await self._create_session(server_name, config)
292
-
293
- # 获取工具列表
294
311
  tools_response = await self.session.list_tools()
295
- print(f"🔍 Registering tools for server : {server_name} ...")
312
+ print(f"🔍 Registering MCP tools for server : {server_name} ...")
296
313
 
297
- # 注册工具处理逻辑
298
314
  for tool in tools_response.tools:
299
315
  try:
300
316
  # 构建工具元数据
@@ -316,9 +332,9 @@ class MCPClientManager:
316
332
  "required": param_name in required_fields
317
333
  })
318
334
 
319
- # 注册到全局字典
320
- _FUNCTION_INFO[tool.name] = tool_info
321
- _FUNCTION_MAPPINGS[tool.name] = partial(
335
+ # 注册到工具注册表
336
+ self.tool_registry.function_info[tool.name] = tool_info
337
+ self.tool_registry.function_mappings[tool.name] = partial(
322
338
  self._call_tool_wrapper,
323
339
  tool_name=tool.name,
324
340
  target_server=server_name
@@ -340,31 +356,27 @@ class MCPClientManager:
340
356
  }
341
357
  }
342
358
  }
343
- _OPENAI_FUNCTION_SCHEMAS.append(openai_schema)
344
-
359
+ self.tool_registry.openai_function_schemas.append(openai_schema)
345
360
  registered_count += 1
346
- print(f"✅ The registered tool : {tool.name}")
347
-
361
+ print(f"✅ The registered MCP tool : {tool.name}")
348
362
  except Exception as e:
349
- print(f"⚠️ 工具 {tool.name} 注册失败: {str(e)}")
350
363
  continue
351
- print(f"🟢 {registered_count} tools have been registered")
352
-
353
364
  except Exception as e:
354
- print(f"🔴 服务器 {server_name} 连接失败: {str(e)}")
355
365
  continue
356
- # 清理
366
+
357
367
  await self.cleanup()
358
368
  return registered_count > 0
359
369
 
370
+ async def _call_tool_wrapper(self, tool_name: str, target_server: str, **kwargs):
371
+ """参数转换适配器"""
372
+ return await self.call_tool(
373
+ tool_name=tool_name,
374
+ arguments=kwargs,
375
+ target_server=target_server
376
+ )
377
+
360
378
  async def call_tool(self, tool_name: str, arguments: dict, target_server: str = None):
361
- """
362
- 通用工具调用方法
363
- :param tool_name: 要调用的工具名称
364
- :param arguments: 工具参数字典
365
- :param target_server: 指定服务器名称(可选)
366
- :return: 工具调用结果
367
- """
379
+ """通用工具调用方法"""
368
380
  enabled_servers = [
369
381
  (name, config)
370
382
  for name, config in self.config["mcpServers"].items()
@@ -376,15 +388,11 @@ class MCPClientManager:
376
388
 
377
389
  for server_name, config in enabled_servers:
378
390
  try:
379
- # 复用已建立的会话
380
391
  session = self.server_sessions.get(server_name)
381
- # print(111,server_name,session)
382
- # print(222,server_name,config)
383
392
  if not session:
384
393
  await self._create_session(server_name, config)
385
394
  session = self.session
386
395
 
387
- # 获取工具列表
388
396
  tools = await session.list_tools()
389
397
  available_tools = {t.name: t for t in tools.tools}
390
398
 
@@ -395,31 +403,26 @@ class MCPClientManager:
395
403
 
396
404
  # 执行调用
397
405
  result = await session.call_tool(tool_name, arguments)
398
- # print(f"mcp工具运行结果: {result.content[0].text}")
399
- # 调用完成清理session
400
406
  await self.cleanup()
401
407
  return {
402
408
  "server": server_name,
403
409
  "tool": tool_name,
404
410
  "result": result.content[0].text
405
411
  }
406
-
407
412
  except Exception as e:
408
- print(f"调用服务器 {server_name} 失败: {str(e)}")
409
413
  continue
410
414
 
411
415
  raise ValueError(f"工具 {tool_name} 在可用服务器中未找到")
412
416
 
413
417
  def _validate_arguments(self, arguments: dict, schema: dict):
414
- """简单参数校验(可选)"""
418
+ """简单参数校验"""
415
419
  required_fields = schema.get("required", [])
416
420
  for field in required_fields:
417
421
  if field not in arguments:
418
422
  raise ValueError(f"缺少必要参数: {field}")
419
423
 
420
-
421
424
  class LightAgent:
422
- __version__ = "0.3.3" # 将版本号放在类中
425
+ __version__ = "0.4.0" # 将版本号放在类中
423
426
 
424
427
  def __init__(
425
428
  self,
@@ -431,7 +434,7 @@ class LightAgent:
431
434
  api_key: str | None = None, # 模型 api key
432
435
  base_url: str | httpx.URL | None = None, # 模型 base url
433
436
  websocket_base_url: str | httpx.URL | None = None, # 模型 websocket base url
434
- memory: str | None = None, # 支持外部传入记忆模块
437
+ memory: Optional[MemoryProtocol] = None, # 支持外部传入记忆模块
435
438
  tree_of_thought: bool = False, # 是否启用链式思考
436
439
  tot_model: str | None = None, # 链式思考模型
437
440
  tot_api_key: str | None = None, # 链式思考模型API密钥
@@ -466,6 +469,13 @@ class LightAgent:
466
469
  :param log_file: 日志文件路径。
467
470
  :param tracetools: log跟踪工具。
468
471
  """
472
+
473
+ # 初始化核心组件
474
+ self.tool_registry = ToolRegistry()
475
+ self.tool_loader = ToolLoader()
476
+ self.tool_dispatcher = AsyncToolDispatcher()
477
+ self.tool_dispatcher.function_mappings = self.tool_registry.function_mappings
478
+
469
479
  self.mcp_setting = None
470
480
  self.mcp_client = None
471
481
  if not model:
@@ -500,7 +510,14 @@ class LightAgent:
500
510
  if debug:
501
511
  self.log_file = os.path.join(log_dir, log_file)
502
512
  # Set up the logger
503
- self.logger = self._setup_logger(log_level, self.log_file)
513
+ # 初始化日志系统
514
+ self.logger = LoggerManager(
515
+ name=self.name,
516
+ debug=debug,
517
+ log_level=log_level,
518
+ log_file=log_file
519
+ )
520
+
504
521
  if tools is None:
505
522
  self.tools = []
506
523
  if tools:
@@ -515,49 +532,54 @@ class LightAgent:
515
532
  )
516
533
  self.api_key = api_key
517
534
  self.websocket_base_url = websocket_base_url
518
-
519
- if base_url is None:
520
- base_url = f"https://api.openai.com/v1"
535
+ self.base_url = base_url or os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"
521
536
 
522
537
  if self.tree_of_thought:
523
538
  if tot_api_key is None:
524
- tot_api_key = api_key
539
+ tot_api_key = self.api_key
525
540
  if tot_base_url is None:
526
- tot_base_url = base_url
541
+ tot_base_url = self.base_url
527
542
  if not tot_model:
528
543
  tot_model = "deepseek-r1" # 默认思维推理模型为deepseek-r1
529
544
  self.tot_model = tot_model
530
545
 
531
- if tracetools is None:
532
- self.tracetools = []
546
+ # 初始化客户端
547
+ self._initialize_clients(tracetools, tot_api_key, tot_base_url, tot_model)
548
+
549
+ def _initialize_clients(self, tracetools, tot_api_key, tot_base_url, tot_model):
550
+ """初始化 OpenAI 客户端"""
533
551
  if tracetools:
534
- self.tracetools = tracetools
535
- # 初始化工具列表
536
552
  from langfuse.openai import openai as la_openai
537
- la_openai.langfuse_public_key = self.tracetools['TraceToolConfig']['langfuse_public_key']
538
- la_openai.langfuse_secret_key = self.tracetools['TraceToolConfig']['langfuse_secret_key']
539
- la_openai.langfuse_enabled = self.tracetools['TraceToolConfig'][
540
- 'langfuse_enabled'] # Default is True, set to False to disable Langfuse
541
- la_openai.langfuse_host = self.tracetools['TraceToolConfig']['langfuse_host'] # 🇪🇺 EU region
542
- la_openai.base_url = base_url
553
+ la_openai.langfuse_public_key = tracetools['TraceToolConfig']['langfuse_public_key']
554
+ la_openai.langfuse_secret_key = tracetools['TraceToolConfig']['langfuse_secret_key']
555
+ la_openai.langfuse_enabled = tracetools['TraceToolConfig']['langfuse_enabled']
556
+ la_openai.langfuse_host = tracetools['TraceToolConfig']['langfuse_host']
557
+ la_openai.base_url = self.base_url
543
558
  la_openai.api_key = self.api_key
544
559
  self.client = la_openai
560
+
545
561
  if self.tree_of_thought:
546
- la_openai.base_url = tot_base_url
547
- la_openai.api_key = tot_api_key
562
+ la_openai.base_url = tot_base_url or self.base_url
563
+ la_openai.api_key = tot_api_key or self.api_key
548
564
  self.tot_client = la_openai
549
565
  else:
550
566
  from openai import OpenAI as la_openai
551
567
  self.client = la_openai(
552
- base_url=base_url,
568
+ base_url=self.base_url,
553
569
  api_key=self.api_key
554
570
  )
555
571
  if self.tree_of_thought:
556
572
  self.tot_client = la_openai(
557
- base_url=tot_base_url,
558
- api_key=tot_api_key
573
+ base_url=tot_base_url or self.base_url,
574
+ api_key=tot_api_key or self.api_key
559
575
  )
560
576
 
577
+ def get_tools(self) -> List[Dict[str, Any]]:
578
+ """
579
+ 获取所有工具的描述(OpenAI 格式)
580
+ """
581
+ return deepcopy(self.tool_registry.get_tools())
582
+
561
583
  def get_tool(self, tool_name: str) -> Callable:
562
584
  """
563
585
  用于外部可以获取已加载的工具函数
@@ -568,111 +590,19 @@ class LightAgent:
568
590
  return self.loaded_tools[tool_name]
569
591
  raise ValueError(f"Tool `{tool_name}` is not loaded.")
570
592
 
571
- def get_tools(self) -> List[str]:
572
- """
573
- 用于外部可以获取已加载的工具函数列表
574
- :return: 工具函数
575
- """
576
- return list(_FUNCTION_MAPPINGS.keys())
577
-
578
593
  def load_tools(self, tool_names: List[Union[str, Callable]], tools_directory: str = "tools"):
579
- """
580
- 根据工具名称列表加载对应的工具函数,并注册到全局工具注册表中
581
- """
582
- for tool_name in tool_names:
583
- try:
584
- tool_func = load_tool(tool_name, tools_directory)
585
- # globals()[tool_name] = tool_func # 添加到全局命名空间
586
- self.loaded_tools[tool_name] = tool_func # 存储工具函数
587
- # print(f"Tool `{tool_name}` loaded successfully and added to _loaded_tools.") # 调试信息
588
-
589
- # 注册工具函数
590
- if hasattr(tool_func, "tool_info"):
591
- tool_info = tool_func.tool_info
592
- _FUNCTION_INFO[tool_name] = tool_info # 注册工具info信息
593
- _FUNCTION_MAPPINGS[tool_name] = tool_func
594
-
595
- # 构建 OpenAI 格式的工具描述
596
- tool_params_openai = {}
597
- tool_required = []
598
- for param in tool_info["tool_params"]:
599
- tool_params_openai[param["name"]] = {
600
- "type": param["type"],
601
- "description": param["description"],
602
- }
603
- if param["required"]:
604
- tool_required.append(param["name"])
605
-
606
- tool_def_openai = {
607
- "type": "function",
608
- "function": {
609
- "name": tool_name,
610
- "description": tool_info["tool_description"],
611
- "parameters": {
612
- "type": "object",
613
- "properties": tool_params_openai,
614
- "required": tool_required,
615
- },
616
- }
617
- }
618
- _OPENAI_FUNCTION_SCHEMAS.append(tool_def_openai)
619
-
620
- self.log("DEBUG", "load_tools success", {"tools": tool_name})
621
- except Exception as e:
622
- if register_tool_manually([tool_name]):
623
- self.log("DEBUG", "register_tool_manually success", {"tools": tool_name})
624
- else:
625
- self.log("DEBUG", "load_tools error", {"e": e})
626
-
627
- def _setup_logger(self, log_level: str, log_file: Optional[str] = None) -> logging.Logger:
628
- """
629
- 设置日志记录器。
630
-
631
- :param log_level: 日志级别(INFO, DEBUG, ERROR)。
632
- :param log_file: 日志文件路径。
633
- :return: 配置好的日志记录器。
634
- """
635
- logger = logging.getLogger(self.name)
636
- logger.setLevel(log_level.upper())
637
- logger.propagate = False # 禁用传播到根日志记录器
638
-
639
- formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")
640
-
641
- # 输出到控制台
642
- # 仅在调试模式下输出到控制台
643
- if self.debug:
644
- console_handler = logging.StreamHandler()
645
- console_handler.setFormatter(formatter)
646
- logger.addHandler(console_handler)
647
-
648
- # 输出到文件(如果指定了文件路径)
649
- if log_file:
650
- file_handler = logging.FileHandler(log_file)
651
- file_handler.setFormatter(formatter)
652
- logger.addHandler(file_handler)
653
-
654
- return logger
655
-
656
- def log(self, level, action, data):
657
- """
658
- 记录日志。
659
-
660
- :param level: 日志级别(INFO, DEBUG, ERROR)。
661
- :param action: 日志动作(如 chat, call_tool, retrieve_memory)。
662
- :param data: 日志数据。
663
- """
664
- if not self.debug:
665
- return
666
- if self.traceid is not None:
667
- log_message = f"[TraceID: {self.traceid}] {action}: {data}"
668
- else:
669
- log_message = f"{action}: {data}"
670
- if level == "DEBUG":
671
- self.logger.debug(log_message)
672
- elif level == "INFO":
673
- self.logger.info(log_message)
674
- elif level == "ERROR":
675
- self.logger.error(log_message)
594
+ """加载并注册工具"""
595
+ for tool in tool_names:
596
+ if isinstance(tool, str):
597
+ try:
598
+ tool_func = self.tool_loader.load_tool(tool)
599
+ self.tool_registry.register_tool(tool_func)
600
+ self.logger.log("DEBUG", "load_tools", {"tool": tool, "status": "success"})
601
+ except Exception as e:
602
+ self.logger.log("ERROR", "load_tools", {"tool": tool, "error": str(e)})
603
+ elif callable(tool) and hasattr(tool, "tool_info"):
604
+ if self.tool_registry.register_tool(tool):
605
+ self.logger.log("DEBUG", "register_tool", {"tool": tool.__name__, "status": "success"})
676
606
 
677
607
  async def setup_mcp(
678
608
  self,
@@ -682,9 +612,9 @@ class LightAgent:
682
612
  self.mcp_setting = mcp_setting
683
613
  """单独初始化 MCP 模块"""
684
614
  if self.mcp_setting and not self.mcp_client:
685
- self.mcp_client = MCPClientManager(self.mcp_setting)
615
+ self.mcp_client = MCPClientManager(self.mcp_setting, self.tool_registry)
686
616
  await self.mcp_client.register_mcp_tool()
687
- self.log("INFO", "setup_mcp", "MCP 模块初始化成功")
617
+ self.logger.log("INFO", "setup_mcp", "MCP 模块初始化成功")
688
618
 
689
619
  def run(
690
620
  self,
@@ -708,22 +638,13 @@ class LightAgent:
708
638
  :param metadata: 元数据。
709
639
  :return: 代理的回复。
710
640
  """
711
- self.traceid = uuid4().hex
712
- self.log("INFO", "run", {"query": query, "user_id": user_id, "light_swarm": light_swarm, "stream": stream})
713
- if history is None:
714
- history = []
715
- # 构建消息列表,先添加系统提示信息
716
- params = {}
717
-
718
- # 1. 判断是否需要转移任务
719
- # if light_swarm:
720
- # intent = self._detect_intent(query, light_swarm)
721
- # if intent and intent.get("transfer_to"):
722
- # target_agent_name = intent["transfer_to"]
723
- # self.log("INFO", "detect_intent", {"intent": intent})
724
- # print(light_swarm.agents[target_agent_name])
725
- # self._transfer_to_agent(light_swarm.agents[target_agent_name], query, stream=stream)
726
- # return # 立即结束当前生成器
641
+ # 设置跟踪ID
642
+ traceid = uuid4().hex
643
+ self.logger.set_traceid(traceid)
644
+ self.logger.log("INFO", "run_start", {"query": query, "user_id": user_id, "stream": stream})
645
+
646
+ # 初始化历史记录
647
+ history = history or []
727
648
 
728
649
  # 0. 判断是否需要转移任务
729
650
  if light_swarm:
@@ -735,66 +656,75 @@ class LightAgent:
735
656
  now = datetime.now()
736
657
  current_date = now.strftime("%Y-%m-%d")
737
658
  current_time = now.strftime("%H:%M:%S")
738
- system_prompt = f"##代理名称:{self.name} ##代理指令 /n{self.instructions} ##身份 /n {self.role} /n 请一步一步思考来完成用户的要求。尽可能完成用户的回答,如果有补充信息,请参考补充信息来调用工具,直到获取所有满足用户的提问所需的答案。 /n 今日的日期: {current_date} 当前时间: {current_time}"
739
- params = dict(model=self.model, stream=stream)
740
- memory = ''
741
-
742
- # 2.添加langfuse的session
743
- if self.tracetools:
744
- params["session_id"] = self.traceid
745
- self.log("DEBUG", "Query Trace ID", {"query": query})
746
-
747
- # 3. 从记忆中检索相关内容&保存记忆
748
- if self.memory:
749
- related_memories = self.memory.retrieve(query=query, user_id=user_id)
750
- memory = memory + self._build_context(related_memories)
751
- self.memory.store(data=query, user_id=user_id)
752
- if self.self_learning:
753
- agent_memories = self.memory.retrieve(query=query, user_id=self.name)
754
- memory = memory + self._build_agent_memory(agent_memories)
755
- self.memory.store(data=query, user_id=self.name)
756
-
757
- query = f"{memory}\n##用户提问:\n{query}"
758
- # print(query)
659
+ system_prompt = (
660
+ f"##代理名称:{self.name}\n"
661
+ f"##代理指令:{self.instructions}\n"
662
+ f"##身份:{self.role}\n"
663
+ f"请一步一步思考来完成用户的要求。尽可能完成用户的回答,如果有补充信息,请参考补充信息来调用工具,直到获取所有满足用户的提问所需的答案。\n"
664
+ f"今日的日期: {current_date} 当前时间: {current_time}"
665
+ )
666
+ # 添加记忆上下文
667
+ query = self._add_memory_context(query, user_id)
759
668
 
760
- # 4. 思维链
669
+ # 思维链处理
761
670
  active_tools = []
762
671
  if self.tree_of_thought:
763
- tot_response, active_tools = self.run_thought(query=query)
764
- system_prompt = system_prompt + f" /n ##以下是问题的补充说明 /n {tot_response}"
765
- self.log("DEBUG", "tree_of_thought", {"response": tot_response, "active_tools": active_tools})
672
+ tot_response, active_tools = self.run_thought(query)
673
+ system_prompt += f"\n##以下是问题的补充说明\n{tot_response}"
674
+ self.logger.log("DEBUG", "tree_of_thought", {"response": tot_response, "active_tools": active_tools})
766
675
 
767
- # 5. 拼接tools工具
768
- # 带类型校验 自适应工具机制
769
- try:
770
- tools = active_tools if (
771
- len(active_tools) > 0
772
- ) else get_tools()
773
- except TypeError:
774
- tools = get_tools()
775
- # 带类型校验 自适应工具机制
776
- # tools = get_tools() # v0.2.X的工具选取机制
676
+ # 准备API参数
677
+ params = {
678
+ "model": self.model,
679
+ "messages": [{"role": "system", "content": system_prompt}] + history + [{"role": "user", "content": query}],
680
+ "stream": stream
681
+ }
682
+
683
+ # 添加工具
684
+ tools = active_tools or self.tool_registry.get_tools()
777
685
  if tools:
778
- self.log("DEBUG", "register_tools", {"tools": list(_FUNCTION_MAPPINGS.keys())})
779
- self.log("DEBUG", "active_tools", {"tools": tools})
780
686
  params["tools"] = tools
781
687
  params["tool_choice"] = "auto"
782
688
 
783
- # 6. 调用核心运行逻辑
784
- params["messages"] = [{"role": "system", "content": system_prompt}]
785
- # 将历史对话添加到消息列表中
786
- for item in history:
787
- params["messages"].append({"role": item["role"], "content": item["content"]})
788
- # 最后添加当前用户的查询信息
789
- params["messages"].append({"role": "user", "content": query})
689
+ # 添加跟踪会话
690
+ if hasattr(self, 'tracetools') and self.tracetools:
691
+ params["session_id"] = traceid
790
692
 
693
+ # 调用模型
791
694
  response = self.client.chat.completions.create(**params)
695
+ return self._core_run_logic(response, params, stream, max_retry)
696
+
697
+ def _add_memory_context(self, query: str, user_id: str) -> str:
698
+ """添加记忆上下文"""
699
+ if not self.memory:
700
+ return query
701
+
702
+ context = ""
703
+ related_memories = self.memory.retrieve(query=query, user_id=user_id)
704
+ if related_memories and related_memories.get("results"):
705
+ context += "\n##用户偏好\n用户之前提到了:\n" + "\n".join(
706
+ [m["memory"] for m in related_memories["results"]]
707
+ )
708
+ self.memory.store(data=query, user_id=user_id)
792
709
 
793
- result = self._core_run_logic(response, params, stream, max_retry)
710
+ if self.self_learning:
711
+ agent_memories = self.memory.retrieve(query=query, user_id=self.name)
712
+ if agent_memories and agent_memories.get("results"):
713
+ context += "\n##问题相关补充信息:\n" + "\n".join(
714
+ [m["memory"] for m in agent_memories["results"]]
715
+ )
716
+ self.memory.store(data=query, user_id=self.name)
794
717
 
795
- return result
718
+ return f"{context}\n##用户提问:\n{query}" if context else query
796
719
 
797
- def _run_logic_non_stream(self, response, params, max_retry) -> Union[str, None]:
720
+ def _core_run_logic(self, response, params, stream, max_retry) -> Union[Generator[str, None, None], str]:
721
+ """核心运行逻辑"""
722
+ if stream:
723
+ return self._run_stream_logic(response, params, max_retry)
724
+ else:
725
+ return self._run_non_stream_logic(response, params, max_retry)
726
+
727
+ def _run_non_stream_logic(self, response, params, max_retry) -> Union[str, None]:
798
728
  """
799
729
  非流式处理逻辑。
800
730
  """
@@ -806,7 +736,7 @@ class LightAgent:
806
736
  output = ""
807
737
  function_call_name = ""
808
738
  tool_calls = response.choices[0].message.tool_calls
809
- self.log("DEBUG", "non_stream tool_calls", {"tool_calls": tool_calls})
739
+ self.logger.log("DEBUG", "non_stream tool_calls", {"tool_calls": tool_calls})
810
740
 
811
741
  # 遍历所有工具调用
812
742
  for tool_call in tool_calls:
@@ -814,13 +744,13 @@ class LightAgent:
814
744
 
815
745
  # 尝试自动修复常见转义问题
816
746
  fixed_args = function_call.arguments.replace('\\"', '"').replace('\\\\', '\\')
817
- self.log("DEBUG", "non_stream function_call", {"function_call": fixed_args})
747
+ self.logger.log("DEBUG", "non_stream function_call", {"function_call": fixed_args})
818
748
 
819
749
  # 解析函数参数
820
750
  function_args = json.loads(fixed_args)
821
751
 
822
752
  # 调用工具并获取响应
823
- tool_response = asyncio.run(dispatch_tool(function_call.name, function_args))
753
+ tool_response = asyncio.run(self.tool_dispatcher.dispatch(function_call.name, function_args))
824
754
  function_call_name = function_call.name
825
755
  combined_response = ""
826
756
  single_tool_response = ""
@@ -863,13 +793,13 @@ class LightAgent:
863
793
  pass # 如果不是 JSON 字符串,保持原样
864
794
  single_tool_response = combined_response # 处理单个工具的方法
865
795
 
866
- self.log("INFO", "non_stream single_tool_response", {"single_tool_response": single_tool_response})
796
+ self.logger.log("INFO", "non_stream single_tool_response", {"single_tool_response": single_tool_response})
867
797
 
868
798
  # 将单个工具的响应结果添加到列表中
869
799
  tool_responses.append(single_tool_response)
870
800
 
871
801
  # 将所有工具调用的结果合并为一个字符串
872
- self.log("DEBUG", "non_stream tool_responses", {"tool_responses": tool_responses})
802
+ self.logger.log("DEBUG", "non_stream tool_responses", {"tool_responses": tool_responses})
873
803
 
874
804
  combined_tool_response = "\n".join(tool_responses)
875
805
 
@@ -889,14 +819,14 @@ class LightAgent:
889
819
  else:
890
820
  # 返回最终回复
891
821
  reply = response.choices[0].message.content
892
- self.log("INFO", "non_stream final_reply", {"reply": reply})
822
+ self.logger.log("INFO", "non_stream final_reply", {"reply": reply})
893
823
  return reply
894
824
 
895
825
  # 更新响应
896
826
  if function_call_name == 'finish':
897
827
  return # 如果最后调用了finish工具,则结束生成器
898
828
  # print("params:",params)
899
- self.log("DEBUG", "non_stream chat-completions params", {"params": params})
829
+ self.logger.log("DEBUG", "non_stream chat-completions params", {"params": params})
900
830
 
901
831
  try:
902
832
  response = self.client.chat.completions.create(**params)
@@ -904,20 +834,17 @@ class LightAgent:
904
834
  print(f"An error occurred: {e}")
905
835
 
906
836
  # 重试次数用尽
907
- self.log("ERROR", "max_retry_reached", {"message": "Failed to generate a valid response."})
837
+ self.logger.log("ERROR", "max_retry_reached", {"message": "Failed to generate a valid response."})
908
838
  return "Failed to generate a valid response."
909
839
 
910
- def _run_logic_stream(self, response, params, max_retry) -> Generator[str, None, None]:
911
- """
912
- 流式处理逻辑。
913
- """
840
+ def _run_stream_logic(self, response, params, max_retry) -> Generator[str, None, None]:
841
+ """流式处理逻辑"""
914
842
  for _ in range(max_retry):
915
843
  # 初始化变量
916
844
  output = ""
917
- function_call_name = ""
918
- function_call_arguments = ""
919
845
  tool_calls = [] # 用于存储所有工具调用的信息
920
846
  tool_responses = [] # 用于存储所有工具调用的结果
847
+ finish_called = False # 标记是否调用了finish工具
921
848
 
922
849
  for chunk in response:
923
850
  content = chunk.choices[0].delta.content or ""
@@ -936,7 +863,7 @@ class LightAgent:
936
863
 
937
864
  # 如果工具调用信息尚未记录,初始化一个空字典
938
865
  if len(tool_calls) <= tool_call_index:
939
- tool_calls.append({"name": "", "arguments": "", "index": tool_call_index})
866
+ tool_calls.append({"name": "", "arguments": "", "index": tool_call_index, "title": ""})
940
867
 
941
868
  # 更新工具调用的名称
942
869
  if hasattr(tool_call_delta.function, "name") and tool_call_delta.function.name:
@@ -946,35 +873,48 @@ class LightAgent:
946
873
  if hasattr(tool_call_delta.function, "arguments") and tool_call_delta.function.arguments:
947
874
  tool_calls[tool_call_index]["arguments"] += tool_call_delta.function.arguments
948
875
 
949
- except (IndexError, AttributeError, KeyError):
950
- pass
876
+ except (IndexError, AttributeError, KeyError) as e:
877
+ self.logger.log("ERROR", "tool_call_error", {
878
+ "error": str(e),
879
+ "traceback": traceback.format_exc()
880
+ })
951
881
 
952
882
  # 如果流式输出结束
953
- if chunk.choices[0].finish_reason == "stop" and not any(tool_call["name"] for tool_call in tool_calls):
954
- self.log("INFO", "stream_response", {"output": output})
883
+ finish_reason = chunk.choices[0].finish_reason if chunk.choices else None
884
+ if finish_reason == "stop" and not any(tc["name"] for tc in tool_calls):
885
+ self.logger.log("INFO", "stream_response", {"output": output})
955
886
  return # 结束生成器
956
887
 
957
888
  # 如果工具调用结束
958
- elif chunk.choices[0].finish_reason == "tool_calls" or (
959
- chunk.choices[0].finish_reason == "stop" and any(
960
- tool_call["name"] for tool_call in tool_calls)):
889
+ elif finish_reason in ("tool_calls", "stop") and any(tc["name"] for tc in tool_calls):
961
890
  # 遍历所有工具调用
962
- self.log("DEBUG", "stream tool_calls", {"tool_calls": tool_calls})
891
+ self.logger.log("DEBUG", "stream tool_calls", {"tool_calls": tool_calls})
963
892
  for tool_call in tool_calls:
964
893
  if tool_call["name"]: # 确保工具调用有名称
965
- function_call = {
966
- "name": tool_call["name"],
967
- "title": _FUNCTION_INFO.get(tool_call["name"], {}).get('tool_title') or '',
968
- "arguments": tool_call["arguments"],
894
+ tool_name = tool_call["name"]
895
+ arguments = tool_call["arguments"]
896
+
897
+ # 从注册表中获取工具标题
898
+ tool_info = self.tool_registry.function_info.get(tool_name, {})
899
+ tool_title = tool_info.get("tool_title") or ""
900
+
901
+ # 更新工具调用信息
902
+ tool_call["title"] = tool_title
903
+
904
+ # 记录调用工具
905
+ tool_call_info = {
906
+ "name": tool_name,
907
+ "title": tool_title,
908
+ "arguments": arguments,
969
909
  }
970
- self.log("INFO", "stream function_call", {"function_call": function_call})
910
+ self.logger.log("INFO", "stream function_call", {"tool_call_start": tool_call_info})
971
911
  # 将工具的调用信息推送给开发者
972
- yield function_call
912
+ yield tool_call_info
973
913
 
974
914
  # 解析参数并调用工具
975
915
  try:
976
916
  # 使用正则表达式将多个 JSON 对象拆分开
977
- json_objects = re.findall(r'\{.*?\}', function_call["arguments"])
917
+ json_objects = re.findall(r'\{.*?\}', tool_call_info["arguments"])
978
918
 
979
919
  # 解析每个 JSON 对象并调用工具
980
920
  # for json_obj in json_objects:
@@ -985,12 +925,15 @@ class LightAgent:
985
925
  for json_obj in json_objects:
986
926
  # 尝试自动修复常见转义问题
987
927
  fixed_args = json_obj.replace('\\"', '"').replace('\\\\', '\\')
988
- self.log("DEBUG", "stream fixed_args", {"fixed_args": fixed_args})
928
+ self.logger.log("DEBUG", "stream fixed_args", {"fixed_args": fixed_args})
989
929
 
930
+ # 解析参数
990
931
  function_args = json.loads(fixed_args)
991
- # tool_response = dispatch_tool(function_call["name"], function_args)
992
- tool_response = asyncio.run(dispatch_tool(function_call["name"], function_args))
993
- function_call_name = function_call["name"]
932
+
933
+ # 调用工具
934
+ tool_response = asyncio.run(self.tool_dispatcher.dispatch(tool_name, function_args))
935
+
936
+ # 处理不同类型的工具响应
994
937
  combined_response = ""
995
938
  single_tool_response = ""
996
939
 
@@ -1003,15 +946,14 @@ class LightAgent:
1003
946
  yield chunk
1004
947
  else:
1005
948
  tool_output = {
1006
- "name": tool_call["name"],
1007
- "title": _FUNCTION_INFO.get(tool_call["name"], {}).get(
1008
- 'tool_title') or '',
949
+ "name": tool_name,
950
+ "title": tool_title,
1009
951
  "output": chunk,
1010
952
  }
1011
- self.log("DEBUG", "stream tool_output", {"tool_output": tool_output})
953
+ self.logger.log("DEBUG", "stream tool_output", {"tool_output": tool_output})
1012
954
  yield tool_output
1013
955
  # 将工具的调用信息推送给开发者
1014
- if function_call_name == 'finish':
956
+ if tool_name == 'finish':
1015
957
  content = chunk.choices[0].delta.content or ""
1016
958
  combined_response += content # 将每个 chunk 叠加
1017
959
  else:
@@ -1019,25 +961,54 @@ class LightAgent:
1019
961
  single_tool_response = combined_response # 处理单个工具的方法
1020
962
  else:
1021
963
  # print(f"Non-streaming response from tool: {tool_response}")
1022
- combined_response = tool_response
964
+ combined_response = str(tool_response)
1023
965
  single_tool_response = combined_response # 处理单个工具的方法
1024
- self.log("INFO", "stream single_tool_response",
966
+ tool_output = {
967
+ "name": tool_name,
968
+ "title": tool_title,
969
+ "output": combined_response
970
+ }
971
+ yield tool_output
972
+
973
+ # 记录工具响应
974
+ self.logger.log("INFO", "stream single_tool_response",
1025
975
  {"single_tool_response": single_tool_response})
1026
- # 将单个工具的响应结果添加到列表中
976
+
977
+ # 将单个工具的响应结果保存到列表中
1027
978
  tool_responses.append(single_tool_response)
1028
979
 
1029
- except json.JSONDecodeError as e:
1030
- self.log("ERROR", "json_decode_error",
1031
- {"error": str(e), "arguments": function_call["arguments"]})
1032
- continue
980
+ # 检查是否调用了finish工具
981
+ if tool_name == 'finish':
982
+ finish_called = True
983
+ self.logger.log("INFO", "finish_tool_called", {"response": combined_response})
1033
984
 
1034
- # 将所有工具调用的结果合并为一个字符串
985
+ except json.JSONDecodeError as e:
986
+ error_msg = f"JSON解析错误: {str(e)}\n参数: {arguments}"
987
+ self.logger.log("ERROR", "json_decode_error", {"tool": tool_name, "title": tool_title, "error": error_msg})
988
+ tool_responses.append(error_msg)
989
+ yield {"name": tool_name, "title": tool_title, "error": error_msg}
990
+
991
+ except Exception as e:
992
+ error_msg = f"工具调用错误: {str(e)}\n{traceback.format_exc()}"
993
+ self.logger.log("ERROR", "tool_execution_error", {
994
+ "tool": tool_name,
995
+ "title": tool_title,
996
+ "error": error_msg
997
+ })
998
+ tool_responses.append(error_msg)
999
+ yield {"name": tool_name, "title": tool_title, "error": error_msg}
1000
+
1001
+ # 如果调用了finish工具,则结束处理
1002
+ if finish_called:
1003
+ return
1004
+
1005
+ # 准备下一轮请求
1035
1006
  combined_tool_response = "\n".join(tool_responses)
1036
1007
  tool_str = json.dumps(
1037
1008
  [{"name": tool_call["name"], "arguments": tool_call["arguments"]} for tool_call in tool_calls],
1038
1009
  ensure_ascii=False)
1039
1010
 
1040
- # 将工具调用和响应添加到消息列表中
1011
+ # 添加工具调用和响应到消息历史
1041
1012
  params["messages"].append(
1042
1013
  {
1043
1014
  "role": "assistant",
@@ -1051,26 +1022,15 @@ class LightAgent:
1051
1022
  }
1052
1023
  )
1053
1024
 
1025
+ # 创建新的响应流
1026
+ self.logger.log("DEBUG", "stream next_request_params", {"params": params})
1027
+ response = self.client.chat.completions.create(**params)
1054
1028
  break
1055
1029
 
1056
- # 更新响应
1057
- if function_call_name == 'finish':
1058
- return # 如果最后调用了finish工具,则结束生成器
1059
- self.log("DEBUG", "stream chat-completions params", {"params": params})
1060
- response = self.client.chat.completions.create(**params)
1061
-
1062
1030
  # 重试次数用尽
1063
- self.log("ERROR", "max_retry_reached", {"message": "Failed to generate a valid response."})
1031
+ self.logger.log("ERROR", "max_retry_reached", {"message": "Failed to generate a valid response."})
1064
1032
  yield "Failed to generate a valid response."
1065
1033
 
1066
- def _core_run_logic(self, response, params, stream, max_retry) -> Union[Generator[str, None, None], str]:
1067
- """
1068
- 核心运行逻辑。
1069
- """
1070
- if stream:
1071
- return self._run_logic_stream(response, params, max_retry)
1072
- else:
1073
- return self._run_logic_non_stream(response, params, max_retry)
1074
1034
 
1075
1035
  def _handle_task_transfer(
1076
1036
  self,
@@ -1089,9 +1049,9 @@ class LightAgent:
1089
1049
  intent = self._detect_intent(query, light_swarm)
1090
1050
  if intent and intent.get("transfer_to"):
1091
1051
  target_agent_name = intent["transfer_to"]
1092
- self.log("INFO", "detect_intent", {"intent": intent})
1052
+ self.logger.log("INFO", "detect_intent", {"intent": intent})
1093
1053
  if target_agent_name == self.name:
1094
- self.log("INFO", "self_transfer_detected", {"target_agent": target_agent_name})
1054
+ self.logger.log("INFO", "self_transfer_detected", {"target_agent": target_agent_name})
1095
1055
  return None # 如果是自身,直接返回 None
1096
1056
  if stream:
1097
1057
  return self._handle_task_transfer_stream(light_swarm.agents[target_agent_name], query, light_swarm)
@@ -1113,18 +1073,18 @@ class LightAgent:
1113
1073
  :param light_swarm: LightSwarm 实例。
1114
1074
  :return: 生成器,用于流式输出。
1115
1075
  """
1116
- self.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
1076
+ self.logger.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
1117
1077
 
1118
1078
  # 检查目标代理是否有效
1119
1079
  if not hasattr(target_agent, 'run'):
1120
- self.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
1080
+ self.logger.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
1121
1081
  yield "Failed to transfer task: invalid target agent"
1122
1082
  return
1123
1083
 
1124
1084
  try:
1125
1085
  yield from target_agent.run(context, light_swarm=light_swarm, stream=True)
1126
1086
  except Exception as e:
1127
- self.log("ERROR", "run_failed", {"error": str(e)})
1087
+ self.logger.log("ERROR", "run_failed", {"error": str(e)})
1128
1088
  raise # 重新抛出异常以便调试
1129
1089
 
1130
1090
  def _handle_task_transfer_non_stream(
@@ -1141,11 +1101,11 @@ class LightAgent:
1141
1101
  :param light_swarm: LightSwarm 实例。
1142
1102
  :return: 字符串,表示非流式输出结果。
1143
1103
  """
1144
- self.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
1104
+ self.logger.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
1145
1105
 
1146
1106
  # 检查目标代理是否有效
1147
1107
  if not hasattr(target_agent, 'run'):
1148
- self.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
1108
+ self.logger.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
1149
1109
  return "Failed to transfer task: invalid target agent"
1150
1110
 
1151
1111
  try:
@@ -1154,7 +1114,7 @@ class LightAgent:
1154
1114
  return "".join(result) # 将生成器转换为字符串
1155
1115
  return result
1156
1116
  except Exception as e:
1157
- self.log("ERROR", "run_failed", {"error": str(e)})
1117
+ self.logger.log("ERROR", "run_failed", {"error": str(e)})
1158
1118
  raise # 重新抛出异常以便调试
1159
1119
 
1160
1120
  def _build_context(self, related_memories):
@@ -1171,7 +1131,7 @@ class LightAgent:
1171
1131
  return ""
1172
1132
 
1173
1133
  prompt = f"\n##用户偏好 \n用户之前提到了\n{memory_context}。"
1174
- self.log("DEBUG", "related_memories", {"memory_context": memory_context})
1134
+ self.logger.log("DEBUG", "related_memories", {"memory_context": memory_context})
1175
1135
  return prompt
1176
1136
 
1177
1137
  def _build_agent_memory(self, agent_memories):
@@ -1189,20 +1149,20 @@ class LightAgent:
1189
1149
  return ""
1190
1150
 
1191
1151
  prompt = f"\n##以下是解决该问题的相关补充信息:\n{memory_context}。"
1192
- self.log("DEBUG", "agent_memories", {"memory_context": memory_context})
1152
+ self.logger.log("DEBUG", "agent_memories", {"memory_context": memory_context})
1193
1153
  return prompt
1194
1154
 
1195
1155
  def run_thought(self, query: str) -> tuple:
1196
1156
  """使用思维树的方式 让大模型先根据get_tools_str生成一个解答用户query的工具使用计划"""
1197
1157
  tot_model = self.tot_model # self.model
1198
- tools = get_tools_str()
1158
+ tools = self.tool_registry.get_tools_str()
1199
1159
  if not isinstance(tools, str):
1200
1160
  tools = str(tools) # 确保 tools 是字符串
1201
1161
  now = datetime.now()
1202
1162
  current_date = now.strftime("%Y-%m-%d")
1203
1163
  current_time = now.strftime("%H:%M:%S")
1204
1164
  system_prompt = f"""你是一个智能助手,请根据用户输入的问题,结合工具使用计划,生成一个思维树,并按照思维树依次调用工具步骤,最终生成一个最终回答。\n 今日的日期: {current_date} 当前时间: {current_time} \n 工具列表: {tools}"""
1205
- self.log("DEBUG", "run_thought", {"system_prompt": system_prompt})
1165
+ self.logger.log("DEBUG", "run_thought", {"system_prompt": system_prompt})
1206
1166
 
1207
1167
  try:
1208
1168
  # 1. 第一次请求,生成初始的工具使用计划
@@ -1210,20 +1170,20 @@ class LightAgent:
1210
1170
  messages=[{"role": "system", "content": system_prompt}, {"role": "user", "content": query}],
1211
1171
  stream=False)
1212
1172
  response = self.tot_client.chat.completions.create(**params)
1213
- initial_content = response.choices[0].message.content
1214
- self.log("DEBUG", "initial_response", {"response": initial_content})
1173
+ thought_response = response.choices[0].message.content
1174
+ self.logger.log("DEBUG", "thought_response", {"response": thought_response})
1215
1175
 
1216
1176
  # 2. 第二次请求,请求大模型反思并生成新的工具使用规划
1217
1177
  reflection_prompt = "请反思你的回答,请严格按照<工具列表>中的工具来规划,不可以创造其他新的工具。请输出新的任务规划,不要输出其他分析和回答。"
1218
1178
  reflection_params = dict(model=tot_model, messages=[
1219
1179
  {"role": "user", "content": f"{system_prompt} /n 开始思考问题: {query}"},
1220
- {"role": "assistant", "content": initial_content},
1180
+ {"role": "assistant", "content": thought_response},
1221
1181
  {"role": "user", "content": reflection_prompt}
1222
1182
  ], stream=False)
1223
- self.log("DEBUG", "reflection_params", {"params": reflection_params})
1183
+ self.logger.log("DEBUG", "reflection_params", {"params": reflection_params})
1224
1184
  reflection_response = self.tot_client.chat.completions.create(**reflection_params)
1225
1185
  refined_content = reflection_response.choices[0].message.content
1226
- self.log("DEBUG", "refined_response", {"response": refined_content})
1186
+ self.logger.log("DEBUG", "reflection_response", {"response": refined_content})
1227
1187
 
1228
1188
  # 获取工具的使用集合
1229
1189
  tool_reflection_prompt = """请严格按以下要求执行:
@@ -1245,21 +1205,21 @@ class LightAgent:
1245
1205
  stream=False
1246
1206
  )
1247
1207
 
1248
- self.log("DEBUG", "tool_reflection_params", {"params": tool_reflection_params})
1208
+ self.logger.log("DEBUG", "tool_reflection_params", {"params": tool_reflection_params})
1249
1209
  tool_reflection_response = self.tot_client.chat.completions.create(**tool_reflection_params)
1250
1210
  tool_reflection_result = tool_reflection_response.choices[0].message.content
1251
- self.log("DEBUG", "tool_reflection_result", {"result": tool_reflection_result})
1211
+ self.logger.log("DEBUG", "tool_reflection_result", {"result": tool_reflection_result})
1252
1212
 
1253
1213
  # 3.执行自适应工具过滤
1254
1214
  current_tools = []
1255
1215
  if self.filter_tools:
1256
- current_tools = filter_tools_schemas(tool_reflection_result)
1257
- self.log("DEBUG", "current_tools", {"get_tools": current_tools})
1216
+ current_tools = self.tool_registry.filter_tools(tool_reflection_result)
1217
+ self.logger.log("DEBUG", "current_tools", {"get_tools": current_tools})
1258
1218
 
1259
1219
  return refined_content, current_tools
1260
1220
 
1261
1221
  except Exception as e:
1262
- self.log("ERROR", "run_thought_failure", {"error": str(e)})
1222
+ self.logger.log("ERROR", "run_thought_failure", {"error": str(e)})
1263
1223
  raise RuntimeError(f"思维链执行失败: {str(e)}") from e
1264
1224
 
1265
1225
  def _detect_intent(self, query: str, light_swarm=None) -> Optional[Dict]:
@@ -1296,7 +1256,7 @@ class LightAgent:
1296
1256
  messages=[{"role": "system", "content": prompt}]
1297
1257
  )
1298
1258
  intent = response.choices[0].message.content
1299
- self.log("DEBUG", "detect_intent", {"intent": intent})
1259
+ self.logger.log("DEBUG", "detect_intent", {"intent": intent})
1300
1260
 
1301
1261
  # # 使用正则表达式解析意图
1302
1262
  # match = re.search(r"transfer to (\w+)", intent, re.IGNORECASE)
@@ -1328,11 +1288,11 @@ class LightAgent:
1328
1288
  :param stream: 是否启用流式输出。
1329
1289
  :return: 如果 stream=True,返回生成器;否则返回完整结果字符串。
1330
1290
  """
1331
- self.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
1291
+ self.logger.log("INFO", "transfer_to_agent", {"from": self.name, "to": target_agent.name, "context": context})
1332
1292
 
1333
1293
  # 检查目标代理是否有效
1334
1294
  if not hasattr(target_agent, 'run'):
1335
- self.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
1295
+ self.logger.log("ERROR", "invalid_target_agent", {"target_agent": target_agent})
1336
1296
  return "Failed to transfer task: invalid target agent"
1337
1297
  #
1338
1298
  # # 调用目标代理的 run 方法
@@ -1352,7 +1312,7 @@ class LightAgent:
1352
1312
  return "".join(result) # 将生成器转换为字符串
1353
1313
  return result
1354
1314
  except Exception as e:
1355
- self.log("ERROR", "run_failed", {"error": str(e)})
1315
+ self.logger.log("ERROR", "run_failed", {"error": str(e)})
1356
1316
  raise # 重新抛出异常以便调试
1357
1317
 
1358
1318
  def create_tool(self, user_input: str, tools_directory: str = "tools"):
@@ -1429,19 +1389,19 @@ get_weather.tool_info = {
1429
1389
  tool_code = tool_data.get("tool_code")
1430
1390
 
1431
1391
  if not tool_name or not tool_code:
1432
- self.log("ERROR", "invalid_tool_data", {"tool_data": tool_data})
1392
+ self.logger.log("ERROR", "invalid_tool_data", {"tool_data": tool_data})
1433
1393
  continue
1434
1394
 
1435
1395
  # 保存生成的代码到 tools 目录
1436
1396
  tool_path = os.path.join(tools_directory, f"{tool_name}.py")
1437
1397
  with open(tool_path, "w", encoding="utf-8") as f:
1438
1398
  f.write(tool_code)
1439
- self.log("INFO", "tool_created", {"tool_name": tool_name, "tool_path": tool_path})
1399
+ self.logger.log("INFO", "tool_created", {"tool_name": tool_name, "tool_path": tool_path})
1440
1400
 
1441
1401
  # 自动加载新创建的工具
1442
1402
  self.load_tools([tool_name], tools_directory)
1443
1403
  except Exception as e:
1444
- self.log("ERROR", "tool_creation_failed", {"error": str(e)})
1404
+ self.logger.log("ERROR", "tool_creation_failed", {"error": str(e)})
1445
1405
 
1446
1406
 
1447
1407
  class LightSwarm:
@@ -1457,11 +1417,11 @@ class LightSwarm:
1457
1417
  for agent in agents:
1458
1418
  if agent.name in self.agents:
1459
1419
  # print(f"Agent '{agent.name}' is already registered.")
1460
- agent.log("INFO", "register_agent", {"agent_name": agent.name, "status": "already_registered"})
1420
+ agent.logger.log("INFO", "register_agent", {"agent_name": agent.name, "status": "already_registered"})
1461
1421
  else:
1462
1422
  self.agents[agent.name] = agent
1463
1423
  # print(f"Agent '{agent.name}' registered.")
1464
- agent.log("INFO", "register_agent", {"agent_name": agent.name, "status": "registered"})
1424
+ agent.logger.log("INFO", "register_agent", {"agent_name": agent.name, "status": "registered"})
1465
1425
 
1466
1426
  def run(self, agent: LightAgent, query: str, stream=False):
1467
1427
  """