bizyengine-1.2.50-py3-none-any.whl → bizyengine-1.2.51-py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- bizyengine/bizy_server/errno.py +21 -0
- bizyengine/bizy_server/server.py +119 -159
- bizyengine/bizybot/__init__.py +12 -0
- bizyengine/bizybot/client.py +774 -0
- bizyengine/bizybot/config.py +129 -0
- bizyengine/bizybot/coordinator.py +556 -0
- bizyengine/bizybot/exceptions.py +186 -0
- bizyengine/bizybot/mcp/__init__.py +3 -0
- bizyengine/bizybot/mcp/manager.py +520 -0
- bizyengine/bizybot/mcp/models.py +46 -0
- bizyengine/bizybot/mcp/registry.py +129 -0
- bizyengine/bizybot/mcp/routing.py +378 -0
- bizyengine/bizybot/models.py +344 -0
- bizyengine/core/common/client.py +0 -1
- bizyengine/version.txt +1 -1
- {bizyengine-1.2.50.dist-info → bizyengine-1.2.51.dist-info}/METADATA +2 -1
- {bizyengine-1.2.50.dist-info → bizyengine-1.2.51.dist-info}/RECORD +19 -8
- {bizyengine-1.2.50.dist-info → bizyengine-1.2.51.dist-info}/WHEEL +0 -0
- {bizyengine-1.2.50.dist-info → bizyengine-1.2.51.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,774 @@ bizyengine/bizybot/client.py

```python
"""
LLM client module - wraps OpenAI API calls
"""

import asyncio
import json
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, AsyncIterator, Dict, List, Optional, Union

import aiohttp
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion

from bizyengine.bizybot.exceptions import (
    LLMAPIError,
    LLMResponseError,
    LLMTimeoutError,
    ToolValidationError,
)
from bizyengine.core.common.env_var import BIZYAIR_X_SERVER


@dataclass
class LLMConfig:
    """LLM configuration"""

    api_key: str
    base_url: str = BIZYAIR_X_SERVER
    model: str = "moonshotai/Kimi-K2-Instruct"
    temperature: float = 0.7
    max_tokens: Optional[int] = None
    timeout: float = 30.0


@dataclass
class ToolFunction:
    """Tool function definition"""

    name: str
    arguments: str  # JSON string format


@dataclass
class ToolCall:
    """Tool call"""

    id: str
    type: str  # currently only "function" is supported
    function: ToolFunction
    result: Optional[dict] = None
    error: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert to dict format"""
        return {
            "id": self.id,
            "type": self.type,
            "function": {
                "name": self.function.name,
                "arguments": self.function.arguments,
            },
        }


@dataclass
class Message:
    """Message"""

    role: str  # "user", "assistant", "system", "tool"
    content: Optional[str] = None
    reasoning_content: Optional[str] = None  # supports reasoning models
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None
    timestamp: datetime = field(default_factory=datetime.now)

    def to_openai_format(self) -> dict:
        """Convert to the OpenAI API format"""
        msg = {"role": self.role}

        if self.content is not None:
            msg["content"] = self.content

        if self.tool_calls:
            msg["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": tc.type,
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    },
                }
                for tc in self.tool_calls
            ]

        if self.tool_call_id:
            msg["tool_call_id"] = self.tool_call_id

        return msg


@dataclass
class Usage:
    """Usage statistics"""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class LLMResponse:
    """LLM response"""

    id: str
    choices: List["ResponseChoice"]
    usage: Optional[Usage] = None
    created: Optional[int] = None
    model: Optional[str] = None
    object: str = "chat.completion"


@dataclass
class ResponseChoice:
    """Response choice"""

    message: Message
    finish_reason: Optional[str] = None  # "stop", "eos", "length", "tool_calls"
```
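The dataclasses above are plain containers: `Message.to_openai_format()` and `ToolCall.to_dict()` drop the local-only fields (`timestamp`, `result`, `error`) and keep just the keys the chat API expects. A minimal serialization sketch, assuming the module is importable as `bizyengine.bizybot.client` (this example is editorial, not part of the diff):

```python
# Editorial sketch - not part of the package. Serializes the dataclasses above.
from bizyengine.bizybot.client import Message, ToolCall, ToolFunction

call = ToolCall(
    id="call_0",
    type="function",
    function=ToolFunction(name="get_weather", arguments='{"city": "Berlin"}'),
)
msg = Message(role="assistant", tool_calls=[call])

# Local-only fields (timestamp, result, error) are omitted from the output.
print(msg.to_openai_format())
# {'role': 'assistant', 'tool_calls': [{'id': 'call_0', 'type': 'function',
#  'function': {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'}}]}
```

The file continues with `ToolCallProcessor`: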
```python
class ToolCallProcessor:
    """Tool call processor"""

    def extract_tool_calls(self, message: Message) -> List[ToolCall]:
        """Extract tool calls from a message"""
        if not message.tool_calls:
            return []

        tool_calls = []
        for tc in message.tool_calls:
            try:
                # Validate the tool call format
                if tc.type != "function":
                    continue

                # Parse and validate the arguments JSON (validation only; the
                # result is not kept, to avoid an unused-variable warning)
                self._parse_and_validate_arguments(tc.function.arguments)

                tool_call = ToolCall(
                    id=tc.id,
                    type=tc.type,
                    function=ToolFunction(
                        name=tc.function.name, arguments=tc.function.arguments
                    ),
                )
                tool_calls.append(tool_call)

            except Exception:
                # On a parse error, continue with the remaining tool calls
                continue

        return tool_calls

    def _parse_and_validate_arguments(self, arguments_str: str) -> dict:
        """Parse and validate tool arguments"""
        try:
            arguments = json.loads(arguments_str)
            if not isinstance(arguments, dict):
                raise ToolValidationError("Arguments must be a JSON object")
            return arguments
        except json.JSONDecodeError as e:
            raise ToolValidationError(f"Invalid JSON in tool arguments: {e}") from e

    def validate_tool_arguments(self, tool_call: ToolCall, tool_schema: dict) -> dict:
        """Validate tool arguments against the tool's schema"""
        try:
            arguments = json.loads(tool_call.function.arguments)

            # Basic type check
            if not isinstance(arguments, dict):
                raise ToolValidationError("Tool arguments must be a JSON object")

            # If a schema is available, validate in more detail
            if tool_schema and "parameters" in tool_schema:
                self._validate_against_schema(arguments, tool_schema["parameters"])

            return arguments

        except json.JSONDecodeError as e:
            raise ToolValidationError(f"Invalid JSON in tool arguments: {e}") from e
        except ToolValidationError:
            raise
        except Exception as e:
            raise ToolValidationError(f"Tool argument validation failed: {e}") from e

    def _validate_against_schema(self, arguments: dict, schema: dict) -> None:
        """Validate arguments against a JSON Schema"""
        # Basic validation - check required parameters
        if "required" in schema:
            for required_field in schema["required"]:
                if required_field not in arguments:
                    raise ToolValidationError(
                        f"Missing required parameter: {required_field}"
                    )

        # Check parameter types (simplified)
        if "properties" in schema:
            for param_name, param_value in arguments.items():
                if param_name in schema["properties"]:
                    expected_type = schema["properties"][param_name].get("type")
                    if expected_type and not self._check_type(
                        param_value, expected_type
                    ):
                        raise ToolValidationError(
                            f"Parameter {param_name} has wrong type, expected {expected_type}"
                        )

    def _check_type(self, value: Any, expected_type: str) -> bool:
        """Check whether a value matches the expected type"""
        type_mapping = {
            "string": str,
            "number": (int, float),
            "integer": int,
            "boolean": bool,
            "array": list,
            "object": dict,
        }

        expected_python_type = type_mapping.get(expected_type)
        if expected_python_type:
            return isinstance(value, expected_python_type)

        return True  # Unknown type: skip the check
```
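`validate_tool_arguments` above implements only a simplified slice of JSON Schema: required keys plus a flat type check. A short sketch of how a type mismatch surfaces; the `get_weather` schema is a made-up example, not from the package:

```python
# Editorial sketch - not part of the package. The schema below is hypothetical.
from bizyengine.bizybot.client import ToolCall, ToolCallProcessor, ToolFunction
from bizyengine.bizybot.exceptions import ToolValidationError

processor = ToolCallProcessor()
call = ToolCall(
    id="call_0",
    type="function",
    function=ToolFunction(name="get_weather", arguments='{"city": 42}'),
)
schema = {
    "parameters": {
        "required": ["city"],
        "properties": {"city": {"type": "string"}},
    }
}

try:
    processor.validate_tool_arguments(call, schema)
except ToolValidationError as e:
    print(e)  # "city" is an int, so the simplified type check rejects it
```

Note that unknown `type` strings pass validation (`_check_type` returns `True` for them), so the check is permissive by design. The file continues with `StreamingToolCallHandler`: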
```python
class StreamingToolCallHandler:
    """Streaming tool call handler"""

    def __init__(self):
        self.partial_tool_calls = {}  # stores partial tool call data
        self.tool_processor = ToolCallProcessor()

    def process_streaming_tool_calls(self, chunk: dict) -> Optional[List[ToolCall]]:
        """Process tool calls in a streaming response"""
        # Handle tool call data
        if "tool_calls" in chunk:
            tool_calls_delta = chunk["tool_calls"]

            for tc_delta in tool_calls_delta:
                # Get the tool call's index for matching (a key field in
                # OpenAI streaming responses)
                call_index = tc_delta.get("index", 0)
                call_id = tc_delta.get("id")

                # Use the index as the primary identifier, since the ID may be
                # empty in a streaming response
                primary_key = f"call_{call_index}"

                # Accumulate tool call data
                if primary_key not in self.partial_tool_calls:
                    self.partial_tool_calls[primary_key] = {
                        "id": call_id or "",  # start empty until the real ID arrives
                        "index": call_index,
                        "type": tc_delta.get("type", "function"),
                        "function": {"name": "", "arguments": ""},
                    }

                # Update the real ID if one is provided - it usually appears
                # in the first chunk
                if call_id and call_id != "":
                    self.partial_tool_calls[primary_key]["id"] = call_id

                # Update the function name
                if (
                    "function" in tc_delta
                    and "name" in tc_delta["function"]
                    and tc_delta["function"]["name"]
                ):
                    self.partial_tool_calls[primary_key]["function"][
                        "name"
                    ] += tc_delta["function"]["name"]

                # Accumulate arguments
                if (
                    "function" in tc_delta
                    and "arguments" in tc_delta["function"]
                    and tc_delta["function"]["arguments"]
                ):
                    self.partial_tool_calls[primary_key]["function"][
                        "arguments"
                    ] += tc_delta["function"]["arguments"]

        # Check for complete tool calls - finalize once finish_reason arrives
        if chunk.get("finish_reason") == "tool_calls":
            return self._finalize_tool_calls()

        return None

    def _finalize_tool_calls(self) -> Optional[List[ToolCall]]:
        """
        Finalize partial tool calls - turn the previously accumulated tool call
        data into complete ToolCall objects for execution.
        Tool execution - the tools are run via the MCP server.
        Continue the conversation - tool results are returned to the LLM so it
        can keep generating a response.
        """
        if not self.partial_tool_calls:
            return None

        completed_calls = []
        for primary_key, call_data in self.partial_tool_calls.items():
            try:
                # Validate that the arguments are complete JSON
                # Make sure there is a valid ID
                # Build the complete ToolCall object
                arguments_str = call_data["function"]["arguments"]
                if arguments_str:  # validate only when arguments are non-empty
                    self.tool_processor._parse_and_validate_arguments(arguments_str)

                # Ensure a valid ID, falling back to primary_key if none was set
                tool_id = call_data["id"] if call_data["id"] else primary_key

                # Create the tool call object
                tool_call = ToolCall(
                    id=tool_id,
                    type=call_data["type"],
                    function=ToolFunction(
                        name=call_data["function"]["name"], arguments=arguments_str
                    ),
                )
                completed_calls.append(tool_call)

            except (json.JSONDecodeError, ValueError):
                pass

        # Clear the completed calls
        self.partial_tool_calls.clear()

        return completed_calls if completed_calls else None

    def get_partial_tool_calls(self) -> Dict[str, dict]:
        """Get the current partial tool call state (for debugging)"""
        return self.partial_tool_calls.copy()

    def clear_partial_calls(self) -> None:
        """Clear partial tool call data"""
        self.partial_tool_calls.clear()
```
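The handler above keys fragments by the delta's `index` field, since IDs can be empty mid-stream, and finalizes only when `finish_reason == "tool_calls"` arrives. A sketch with hand-written chunks that mimic the normalized dicts produced by `LLMClient._process_streaming_response` further down (the chunk contents are fabricated for illustration):

```python
# Editorial sketch - not part of the package. The chunks mimic the normalized
# streaming format; real ones come from LLMClient._process_streaming_response.
from bizyengine.bizybot.client import StreamingToolCallHandler

handler = StreamingToolCallHandler()
chunks = [
    {"tool_calls": [{"index": 0, "id": "call_abc", "type": "function",
                     "function": {"name": "get_weather", "arguments": ""}}]},
    {"tool_calls": [{"index": 0, "function": {"arguments": '{"city": '}}]},
    {"tool_calls": [{"index": 0, "function": {"arguments": '"Berlin"}'}}]},
    {"finish_reason": "tool_calls"},  # triggers finalization
]

completed = None
for chunk in chunks:
    completed = handler.process_streaming_tool_calls(chunk) or completed

print(completed[0].function.name, completed[0].function.arguments)
# get_weather {"city": "Berlin"}
```

The file ends with the `LLMClient` facade that ties these pieces together: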
```python
class LLMClient:
    """LLM client - wraps OpenAI API calls"""

    def __init__(self, config: LLMConfig):
        """Initialize the LLM client"""
        self.config = config
        self._client = None
        self._session = None
        self._streaming_handler = StreamingToolCallHandler()
        self._tool_processor = ToolCallProcessor()

    @property
    def client(self) -> AsyncOpenAI:
        """Get the OpenAI client instance"""
        if self._client is None:
            self._client = AsyncOpenAI(
                api_key=self.config.api_key,
                base_url=self.config.base_url,
                timeout=self.config.timeout,
            )
        return self._client

    @property
    def session(self) -> aiohttp.ClientSession:
        """Get the aiohttp session instance"""
        if self._session is None:
            timeout = aiohttp.ClientTimeout(total=self.config.timeout)
            self._session = aiohttp.ClientSession(timeout=timeout)
        return self._session

    async def chat_completion(
        self,
        messages: List[dict],
        tools: Optional[List[dict]] = None,
        stream: bool = False,
        **kwargs,
    ) -> Union[LLMResponse, AsyncIterator[dict]]:
        """
        Send a chat completion request

        Args:
            messages: list of messages
            tools: list of available tools
            stream: whether to stream the response
            **kwargs: additional parameters

        Returns:
            An LLMResponse, or an iterator over streaming response chunks
        """
        try:
            # Merge configuration parameters
            params = {
                "model": kwargs.get("model", self.config.model),
                "messages": messages,
                "temperature": kwargs.get("temperature", self.config.temperature),
                "stream": stream,
            }

            if self.config.max_tokens:
                params["max_tokens"] = kwargs.get("max_tokens", self.config.max_tokens)

            if tools:
                params["tools"] = tools

            try:
                if stream:
                    # Streaming response
                    stream_response = await self.client.chat.completions.create(
                        **params
                    )
                    result = self._process_streaming_response(stream_response)
                    return result
                else:
                    # Non-streaming response
                    response = await self.client.chat.completions.create(**params)
                    result = self._parse_completion_response(response)
                    return result
            except Exception:
                raise

        except asyncio.TimeoutError as e:
            raise LLMTimeoutError(
                f"Request timeout after {self.config.timeout}s"
            ) from e
        except Exception as e:
            # Determine if it's an API error with status code
            status_code = getattr(e, "status_code", None)
            response_body = getattr(e, "response", None)
            if response_body:
                response_body = str(response_body)

            raise LLMAPIError(
                f"API call failed: {str(e)}",
                status_code=status_code,
                response_body=response_body,
            ) from e

    def _parse_completion_response(self, response: ChatCompletion) -> LLMResponse:
        """Parse a complete (non-streaming) response"""
        choices = []
        for choice in response.choices:
            # Parse tool calls
            tool_calls = None
            if choice.message.tool_calls:
                tool_calls = [
                    ToolCall(
                        id=tc.id,
                        type=tc.type,
                        function=ToolFunction(
                            name=tc.function.name, arguments=tc.function.arguments
                        ),
                    )
                    for tc in choice.message.tool_calls
                ]

            message = Message(
                role=choice.message.role,
                content=choice.message.content,
                tool_calls=tool_calls,
            )

            choices.append(
                ResponseChoice(message=message, finish_reason=choice.finish_reason)
            )

        usage = None
        if response.usage:
            usage = Usage(
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                total_tokens=response.usage.total_tokens,
            )

        return LLMResponse(
            id=response.id,
            choices=choices,
            usage=usage,
            created=response.created,
            model=response.model,
            object=response.object,
        )

    async def _process_streaming_response(self, stream) -> AsyncIterator[dict]:
        """Take the raw, complex data stream and normalize it on the fly into a
        cleaner, standardized format; arguments are not accumulated here -
        incremental argument deltas are passed through as-is"""
        try:
            async for chunk in stream:

                if chunk.choices:
                    choice = chunk.choices[0]
                    delta = choice.delta

                    result = {
                        "type": "chunk",
                        "id": chunk.id,
                        "created": chunk.created,
                        "model": chunk.model,
                    }

                    # Handle content deltas
                    if delta.content:
                        result["content"] = delta.content

                    # Handle reasoning content deltas (if supported)
                    if hasattr(delta, "reasoning_content") and delta.reasoning_content:
                        result["reasoning_content"] = delta.reasoning_content

                    # Handle tool calls - supports tool_calls in the delta
                    if hasattr(delta, "tool_calls") and delta.tool_calls:
                        result["tool_calls"] = [
                            {
                                "id": tc.id if tc.id else "",
                                "index": getattr(tc, "index", 0),  # include the index field
                                "type": tc.type if tc.type else "function",
                                "function": {
                                    "name": (
                                        tc.function.name
                                        if tc.function and tc.function.name
                                        else ""
                                    ),
                                    "arguments": (
                                        tc.function.arguments
                                        if tc.function and tc.function.arguments
                                        else ""
                                    ),
                                },
                            }
                            for tc in delta.tool_calls
                        ]

                    # Handle tool calls on the full message (some APIs may
                    # return them here)
                    if (
                        hasattr(choice, "message")
                        and hasattr(choice.message, "tool_calls")
                        and choice.message.tool_calls
                    ):
                        result["tool_calls"] = [
                            {
                                "id": tc.id if tc.id else "",
                                "type": tc.type if tc.type else "function",
                                "function": {
                                    "name": (
                                        tc.function.name
                                        if tc.function and tc.function.name
                                        else ""
                                    ),
                                    "arguments": (
                                        tc.function.arguments
                                        if tc.function and tc.function.arguments
                                        else ""
                                    ),
                                },
                            }
                            for tc in choice.message.tool_calls
                        ]

                    # Handle the finish reason
                    if choice.finish_reason:
                        result["finish_reason"] = choice.finish_reason

                    yield result

        except Exception as e:
            raise LLMResponseError(f"Streaming response error: {str(e)}") from e

    async def parse_streaming_response(
        self, stream: AsyncIterator[bytes]
    ) -> AsyncIterator[dict]:
        """
        Parse an SSE streaming response

        Args:
            stream: byte stream iterator

        Yields:
            Parsed response data
        """
        try:
            buffer = ""
            async for chunk in stream:
                if isinstance(chunk, bytes):
                    chunk = chunk.decode("utf-8")

                buffer += chunk
                lines = buffer.split("\n")
                buffer = lines[-1]  # keep the last line (it may be incomplete)

                for line in lines[:-1]:
                    line = line.strip()
                    if line.startswith("data: "):
                        data_str = line[6:].strip()

                        if data_str == "[DONE]":
                            return

                        try:
                            chunk_data = json.loads(data_str)

                            # Parse the streaming chunk
                            if chunk_data.get("object") == "chat.completion.chunk":
                                yield self._process_sse_chunk(chunk_data)

                        except json.JSONDecodeError:
                            continue

        except Exception as e:
            raise LLMResponseError(f"Streaming parse error: {str(e)}") from e

    def _process_sse_chunk(self, chunk_data: dict) -> dict:
        """Process a single SSE data chunk"""
        choice = chunk_data["choices"][0]
        delta = choice.get("delta", {})

        result = {
            "type": "chunk",
            "id": chunk_data["id"],
            "created": chunk_data.get("created"),
            "model": chunk_data.get("model"),
        }

        # Handle content deltas
        if "content" in delta and delta["content"]:
            result["content"] = delta["content"]

        # Handle reasoning content deltas
        if "reasoning_content" in delta and delta["reasoning_content"]:
            result["reasoning_content"] = delta["reasoning_content"]

        # Handle tool calls
        if "tool_calls" in delta and delta["tool_calls"]:
            result["tool_calls"] = delta["tool_calls"]

        # Handle the finish reason
        if choice.get("finish_reason"):
            result["finish_reason"] = choice["finish_reason"]

        return result

    async def handle_streaming_with_reconnect(
        self,
        messages: List[dict],
        tools: Optional[List[dict]] = None,
        max_retries: int = 3,
        **kwargs,
    ) -> AsyncIterator[dict]:
        """
        Streaming response handling with a reconnect mechanism

        Args:
            messages: list of messages
            tools: list of available tools
            max_retries: maximum number of retries
            **kwargs: additional parameters

        Yields:
            Streaming response data
        """
        retry_count = 0

        while retry_count <= max_retries:
            try:
                stream = await self.chat_completion(
                    messages=messages, tools=tools, stream=True, **kwargs
                )

                async for chunk in stream:
                    yield chunk

                # Completed successfully; exit the retry loop
                break

            except (LLMTimeoutError,):
                retry_count += 1
                if retry_count > max_retries:
                    raise

                wait_time = min(2**retry_count, 30)  # exponential backoff, capped at 30s
                await asyncio.sleep(wait_time)

            except Exception:
                raise

    async def parse_tool_calls(self, response: LLMResponse) -> List[ToolCall]:
        """
        Parse tool calls out of an LLM response

        Args:
            response: LLM response object

        Returns:
            List of tool calls
        """
        tool_calls = []

        for choice in response.choices:
            if choice.message.tool_calls:
                extracted_calls = self._tool_processor.extract_tool_calls(
                    choice.message
                )
                tool_calls.extend(extracted_calls)

        return tool_calls

    async def validate_tool_arguments(
        self, tool_call: ToolCall, tool_schema: dict
    ) -> dict:
        """
        Validate tool call arguments

        Args:
            tool_call: tool call object
            tool_schema: the tool's JSON Schema

        Returns:
            The validated argument dict

        Raises:
            ValueError: argument validation failed
        """
        return self._tool_processor.validate_tool_arguments(tool_call, tool_schema)

    def extract_tool_calls_from_message(self, message: Message) -> List[ToolCall]:
        """
        Extract tool calls from a message

        Args:
            message: message object

        Returns:
            List of tool calls
        """
        return self._tool_processor.extract_tool_calls(message)

    def process_streaming_tool_calls_incremental(
        self, chunk: dict
    ) -> Optional[List[ToolCall]]:
        """
        Incrementally process streaming tool calls

        Args:
            chunk: streaming response chunk

        Returns:
            The completed tool calls, if any
        """
        return self._streaming_handler.process_streaming_tool_calls(chunk)

    def get_streaming_tool_call_status(self) -> Dict[str, dict]:
        """Get the current streaming tool call state"""
        return self._streaming_handler.get_partial_tool_calls()

    def clear_streaming_tool_calls(self) -> None:
        """Clear streaming tool call state"""
        self._streaming_handler.clear_partial_calls()

    def update_config(self, config: LLMConfig) -> None:
        """Update the configuration"""
        self.config = config
        # Reset the client so the new config takes effect
        self._client = None

    async def close(self) -> None:
        """Close client connections"""
        if self._client:
            await self._client.close()
            self._client = None

        if self._session:
            await self._session.close()
            self._session = None
```
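Putting it together, one plausible way to drive `LLMClient` in both modes is sketched below; the API key, prompt, and tool definition are placeholders, and the model and base URL fall back to the `LLMConfig` defaults:

```python
# Editorial sketch - not part of the package. Key and tool are placeholders.
import asyncio

from bizyengine.bizybot.client import LLMClient, LLMConfig


async def main() -> None:
    client = LLMClient(LLMConfig(api_key="sk-..."))
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }]
    messages = [{"role": "user", "content": "Weather in Berlin?"}]

    # Non-streaming: a single LLMResponse; parse_tool_calls walks every choice.
    response = await client.chat_completion(messages=messages, tools=tools)
    for call in await client.parse_tool_calls(response):
        print(call.function.name, call.function.arguments)

    # Streaming: normalized chunk dicts; feed each one to the incremental
    # handler and act once it returns completed calls.
    stream = await client.chat_completion(messages=messages, tools=tools, stream=True)
    async for chunk in stream:
        completed = client.process_streaming_tool_calls_incremental(chunk)
        if completed:
            print("completed:", [c.to_dict() for c in completed])

    await client.close()


if __name__ == "__main__":
    asyncio.run(main())
```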