tamar-model-client 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/__init__.py +2 -0
- tamar_model_client/async_client.py +1 -1
- tamar_model_client/core/base_client.py +4 -0
- tamar_model_client/core/response_handler.py +86 -3
- tamar_model_client/schemas/outputs.py +19 -0
- tamar_model_client/sync_client.py +1 -1
- tamar_model_client/tool_call_helper.py +169 -0
- {tamar_model_client-0.2.6.dist-info → tamar_model_client-0.2.7.dist-info}/METADATA +1 -1
- {tamar_model_client-0.2.6.dist-info → tamar_model_client-0.2.7.dist-info}/RECORD +13 -11
- tests/test_google_azure_final.py +21 -21
- tests/test_tool_call_enhancement.py +571 -0
- {tamar_model_client-0.2.6.dist-info → tamar_model_client-0.2.7.dist-info}/WHEEL +0 -0
- {tamar_model_client-0.2.6.dist-info → tamar_model_client-0.2.7.dist-info}/top_level.txt +0 -0
tamar_model_client/__init__.py
CHANGED
@@ -3,6 +3,7 @@ from .async_client import AsyncTamarModelClient
|
|
3
3
|
from .exceptions import ModelManagerClientError, ConnectionError, ValidationError
|
4
4
|
from .json_formatter import JSONFormatter
|
5
5
|
from . import logging_icons
|
6
|
+
from .tool_call_helper import ToolCallHelper
|
6
7
|
|
7
8
|
__all__ = [
|
8
9
|
"TamarModelClient",
|
@@ -12,4 +13,5 @@ __all__ = [
|
|
12
13
|
"ValidationError",
|
13
14
|
"JSONFormatter",
|
14
15
|
"logging_icons",
|
16
|
+
"ToolCallHelper",
|
15
17
|
]
|
@@ -623,7 +623,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
623
623
|
origin_request_id
|
624
624
|
)
|
625
625
|
stream_iter = self.stub.Invoke(request, metadata=fresh_metadata, timeout=invoke_timeout).__aiter__()
|
626
|
-
chunk_timeout =
|
626
|
+
chunk_timeout = self.stream_chunk_timeout # 单个数据块的超时时间
|
627
627
|
|
628
628
|
try:
|
629
629
|
while True:
|
@@ -296,6 +296,10 @@ class BaseClient(ABC):
|
|
296
296
|
if hasattr(grpc.StatusCode, error_name):
|
297
297
|
self.never_fallback_errors.add(getattr(grpc.StatusCode, error_name))
|
298
298
|
|
299
|
+
# 流式响应单个数据块的超时时间(秒)
|
300
|
+
# AI模型生成可能需要更长时间,默认设置为120秒
|
301
|
+
self.stream_chunk_timeout = float(os.getenv('MODEL_CLIENT_STREAM_CHUNK_TIMEOUT', '120.0'))
|
302
|
+
|
299
303
|
if self.fast_fallback_enabled:
|
300
304
|
self.logger.info(
|
301
305
|
"🚀 Fast fallback enabled",
|
@@ -22,20 +22,87 @@ class ResponseHandler:
|
|
22
22
|
@staticmethod
def build_model_response(grpc_response) -> ModelResponse:
    """
    Build an enriched ModelResponse from a gRPC response.

    In addition to copying content / usage / error / request_id, this:
    1. Extracts ``tool_calls`` (OpenAI SDK compatible).
    2. Extracts ``finish_reason`` (OpenAI SDK compatible).
    3. Converts Google ``candidates`` payloads into the OpenAI shape.

    Only the first choice / candidate is inspected. Payload fragments with an
    unexpected type (e.g. a ``null`` message, a ``null`` parts list or a
    non-object part) are skipped instead of raising, so a malformed raw
    response does not break response construction.

    Args:
        grpc_response: Response object returned by the gRPC service.

    Returns:
        ModelResponse: Enriched client-side response object.
    """
    raw_response = ResponseHandler._parse_json_field(grpc_response.raw_response)

    # Extract tool_calls and finish_reason
    tool_calls = None
    finish_reason = None

    if isinstance(raw_response, dict):
        choices = raw_response.get('choices')
        candidates = raw_response.get('candidates')

        if choices:
            # OpenAI / Azure OpenAI format
            choice = choices[0] if isinstance(choices, list) else None
            if isinstance(choice, dict):
                message = choice.get('message')
                if isinstance(message, dict) and 'tool_calls' in message:
                    tool_calls = message['tool_calls']
                if 'finish_reason' in choice:
                    finish_reason = choice['finish_reason']

        elif candidates:
            # Google AI format
            candidate = candidates[0] if isinstance(candidates, list) else None
            if isinstance(candidate, dict):
                content = candidate.get('content')
                parts = content.get('parts') if isinstance(content, dict) else None

                google_tool_calls = []
                # `i` is the part index, so generated ids stay unique within
                # the response even if the same function is called twice.
                for i, part in enumerate(parts if isinstance(parts, list) else []):
                    function_call = part.get('functionCall') if isinstance(part, dict) else None
                    if not isinstance(function_call, dict):
                        continue
                    name = function_call.get('name')
                    # Convert to the OpenAI-compatible tool call shape
                    google_tool_calls.append({
                        'id': f"call_{i}_{name or 'unknown'}",
                        'type': 'function',
                        'function': {
                            'name': name or '',
                            # Google returns args as an object while OpenAI uses a JSON
                            # string; a null/missing args object means "no arguments".
                            # ensure_ascii=False keeps non-ASCII arguments readable.
                            'arguments': json.dumps(function_call.get('args') or {}, ensure_ascii=False),
                        },
                    })

                if google_tool_calls:
                    tool_calls = google_tool_calls

                google_reason = candidate.get('finishReason')
                if isinstance(google_reason, str):
                    # Map Google finish reasons onto OpenAI's vocabulary;
                    # unmapped reasons are passed through lower-cased.
                    finish_reason_mapping = {
                        'STOP': 'stop',
                        'MAX_TOKENS': 'length',
                        'SAFETY': 'content_filter',
                        'RECITATION': 'content_filter'
                    }
                    finish_reason = finish_reason_mapping.get(google_reason, google_reason.lower())

    # Any response that carries tool calls reports finish_reason 'tool_calls'.
    if tool_calls:
        finish_reason = 'tool_calls'

    return ModelResponse(
        content=grpc_response.content,
        usage=ResponseHandler._parse_json_field(grpc_response.usage),
        error=grpc_response.error or None,
        raw_response=raw_response,
        request_id=grpc_response.request_id if grpc_response.request_id else None,
        tool_calls=tool_calls,
        finish_reason=finish_reason
    )
|
40
107
|
|
41
108
|
@staticmethod
|
@@ -117,12 +184,28 @@ class ResponseHandler:
|
|
117
184
|
"client_type": model_request.user_context.client_type
|
118
185
|
})
|
119
186
|
|
187
|
+
# 添加请求中的 tool 信息
|
188
|
+
if hasattr(model_request, 'tools') and model_request.tools:
|
189
|
+
data["tools_count"] = len(model_request.tools) if isinstance(model_request.tools, list) else 1
|
190
|
+
data["has_tools"] = True
|
191
|
+
|
192
|
+
if hasattr(model_request, 'tool_choice') and model_request.tool_choice:
|
193
|
+
data["tool_choice"] = str(model_request.tool_choice)
|
194
|
+
|
120
195
|
# 添加响应信息
|
121
196
|
if response:
|
122
197
|
if hasattr(response, 'content') and response.content:
|
123
198
|
data["content_length"] = len(response.content)
|
124
199
|
if hasattr(response, 'usage'):
|
125
200
|
data["usage"] = response.usage
|
201
|
+
|
202
|
+
# 新增:tool_calls 相关日志
|
203
|
+
if hasattr(response, 'tool_calls') and response.tool_calls:
|
204
|
+
data["tool_calls_count"] = len(response.tool_calls)
|
205
|
+
data["has_tool_calls"] = True
|
206
|
+
|
207
|
+
if hasattr(response, 'finish_reason') and response.finish_reason:
|
208
|
+
data["finish_reason"] = response.finish_reason
|
126
209
|
|
127
210
|
# 添加流式响应统计
|
128
211
|
if stream_stats:
|
@@ -15,8 +15,27 @@ class BaseResponse(BaseModel):
|
|
15
15
|
|
16
16
|
|
17
17
|
class ModelResponse(BaseResponse):
    """Model response with OpenAI-SDK-style tool-call fields.

    Adds ``tool_calls`` and ``finish_reason`` on top of the base response,
    plus a convenience check for whether any tool call is present.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Request ID used to trace the request.
    request_id: Optional[str] = None

    # Tool calls, mirroring OpenAI SDK ``message.tool_calls``.
    tool_calls: Optional[List[Dict[str, Any]]] = None

    # Why generation stopped, mirroring OpenAI SDK ``choice.finish_reason``.
    finish_reason: Optional[str] = None

    def has_tool_calls(self) -> bool:
        """Report whether this response carries at least one tool call.

        Returns:
            bool: False when ``tool_calls`` is None or an empty list.
        """
        return bool(self.tool_calls)
|
20
39
|
|
21
40
|
|
22
41
|
class BatchModelResponse(BaseModel):
|
@@ -0,0 +1,169 @@
|
|
1
|
+
"""
|
2
|
+
Tool Call 实用工具类
|
3
|
+
|
4
|
+
提供简化 Tool Call 使用的基础工具方法,减少常见错误和样板代码。
|
5
|
+
注意:本工具类仅提供数据处理便利,不包含自动执行功能。
|
6
|
+
"""
|
7
|
+
|
8
|
+
import json
|
9
|
+
from typing import List, Dict, Any, Optional
|
10
|
+
|
11
|
+
from .schemas.outputs import ModelResponse
|
12
|
+
|
13
|
+
|
14
|
+
class ToolCallHelper:
    """Utility helpers for tool calls.

    Provides plain data-shaping methods that mirror the tool and message
    formats used by the OpenAI SDK. These helpers never execute tools.
    """

    @staticmethod
    def create_function_tool(
        name: str,
        description: str,
        parameters: Dict[str, Any],
        strict: Optional[bool] = None
    ) -> Dict[str, Any]:
        """Create a function tool definition in the OpenAI SDK tool format.

        Args:
            name: Function name.
            description: Function description shown to the model.
            parameters: JSON Schema describing the function's parameters.
            strict: Whether to enable strict mode (OpenAI Structured Outputs).
                The ``strict`` key is omitted entirely when this is ``None``.

        Returns:
            Dict[str, Any]: ``{"type": "function", "function": {...}}``, the
            same shape as OpenAI's ``ChatCompletionToolParam``.

        Example:
            >>> weather_tool = ToolCallHelper.create_function_tool(
            ...     name="get_weather",
            ...     description="Get the weather for a city",
            ...     parameters={
            ...         "type": "object",
            ...         "properties": {
            ...             "location": {"type": "string", "description": "City name"}
            ...         },
            ...         "required": ["location"]
            ...     }
            ... )
        """
        function_def: Dict[str, Any] = {
            "name": name,
            "description": description,
            "parameters": parameters,
        }
        if strict is not None:
            function_def["strict"] = strict

        return {"type": "function", "function": function_def}

    @staticmethod
    def create_tool_response_message(
        tool_call_id: str,
        content: str,
        name: Optional[str] = None
    ) -> Dict[str, Any]:
        """Create a ``role="tool"`` message carrying a tool's result.

        Args:
            tool_call_id: ID of the tool call this message answers.
            content: The tool's execution result.
            name: Optional tool name; only included when non-empty.

        Returns:
            Dict[str, Any]: A message dict shaped like OpenAI's tool message
            (``ChatCompletionToolMessageParam``).

        Example:
            >>> tool_message = ToolCallHelper.create_tool_response_message(
            ...     tool_call_id="call_123",
            ...     content="Beijing: sunny, 25°C",
            ...     name="get_weather"
            ... )
        """
        message: Dict[str, Any] = {
            "role": "tool",
            "tool_call_id": tool_call_id,
            "content": content,
        }

        if name:
            message["name"] = name

        return message

    @staticmethod
    def parse_function_arguments(tool_call: Dict[str, Any]) -> Dict[str, Any]:
        """Safely parse a function tool call's arguments into a dict.

        OpenAI returns ``function.arguments`` as a JSON string that callers
        would otherwise have to ``json.loads`` by hand.

        Accepted ``arguments`` values:
        - a JSON object string (the normal case);
        - an already-decoded dict, returned as-is;
        - missing, ``None``, blank, or JSON ``null``, all treated as "no
          arguments" and returned as ``{}``.

        Args:
            tool_call: Tool call dict (``{"type": "function", "function": {...}}``).

        Returns:
            Dict[str, Any]: The parsed arguments.

        Raises:
            ValueError: If the tool type is not ``"function"``, or the
                arguments are not valid JSON, or they decode to something
                other than a JSON object.

        Example:
            >>> tool_call = response.tool_calls[0]
            >>> arguments = ToolCallHelper.parse_function_arguments(tool_call)
            >>> print(arguments["location"])
        """
        if tool_call.get("type") != "function":
            raise ValueError(f"不支持的工具类型: {tool_call.get('type')}")

        # `function` may be absent or explicitly null in provider payloads.
        function = tool_call.get("function") or {}
        arguments = function.get("arguments")

        # Some providers hand back already-decoded arguments.
        if isinstance(arguments, dict):
            return arguments

        # A missing, null, or blank argument string means a zero-argument call.
        if arguments is None or (isinstance(arguments, str) and not arguments.strip()):
            return {}

        try:
            parsed = json.loads(arguments)
        except (json.JSONDecodeError, TypeError) as e:
            raise ValueError(f"解析工具参数失败: {arguments}") from e

        if parsed is None:
            return {}
        if not isinstance(parsed, dict):
            # e.g. '[1, 2]' decodes fine but is not a keyword-argument mapping.
            raise ValueError(f"解析工具参数失败: {arguments}")

        return parsed

    @staticmethod
    def build_messages_with_tool_response(
        original_messages: List[Dict[str, Any]],
        assistant_message: "ModelResponse",
        tool_responses: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Build the follow-up message list after tools have been executed.

        The result is the original messages, then an assistant message (with
        ``tool_calls`` when the response has any), then the tool responses.
        Neither the input list nor the response's ``tool_calls`` list is
        mutated or shared with the result.

        Args:
            original_messages: Messages of the original request.
            assistant_message: Assistant response that contains the tool calls.
            tool_responses: Tool messages, e.g. from
                ``create_tool_response_message``.

        Returns:
            List[Dict[str, Any]]: A new message list.

        Example:
            >>> new_messages = ToolCallHelper.build_messages_with_tool_response(
            ...     original_messages=request.messages,
            ...     assistant_message=response,
            ...     tool_responses=[tool_message]
            ... )
            >>> # Then build and send a new request with these messages.
        """
        new_messages = list(original_messages)

        assistant_msg: Dict[str, Any] = {
            "role": "assistant",
            "content": assistant_message.content or "",
        }

        if assistant_message.has_tool_calls():
            # Copy so later edits to the message list don't alter the response.
            assistant_msg["tool_calls"] = list(assistant_message.tool_calls)

        new_messages.append(assistant_msg)
        new_messages.extend(tool_responses)

        return new_messages
|
@@ -1,20 +1,21 @@
|
|
1
|
-
tamar_model_client/__init__.py,sha256=
|
2
|
-
tamar_model_client/async_client.py,sha256=
|
1
|
+
tamar_model_client/__init__.py,sha256=PpU0HWhbgp6jb7cksBcEOUNalHmJ14WfkjhZGVBK1-U,513
|
2
|
+
tamar_model_client/async_client.py,sha256=u3fikItsFfzqwSYsvH7GIrHXJGSE1oeJ7ubXnx9G8QE,46452
|
3
3
|
tamar_model_client/auth.py,sha256=DrtnFpG0ZKFUnTnV_Y-FuLRiC2kobcgg0W5Gr1ywg1k,1398
|
4
4
|
tamar_model_client/circuit_breaker.py,sha256=Y3AVp7WzVYU-ubcmovKsJ8DRJbbO4G7vdZgSjnwcWJQ,5550
|
5
5
|
tamar_model_client/error_handler.py,sha256=y7EipcqkXbCecSAOsnoSP3SH7hvZSNF_NUHooTi3hP0,18364
|
6
6
|
tamar_model_client/exceptions.py,sha256=EOr4JMYI7hVszRvNYJ1JqsUNpVmd16T2KpJ0MkFTsUE,13073
|
7
7
|
tamar_model_client/json_formatter.py,sha256=XT8XPMKKM2M22tuYR2e1rvWHcpz3UD9iLLgGPsGOjCI,2410
|
8
8
|
tamar_model_client/logging_icons.py,sha256=MRTZ1Xvkep9ce_jdltj54_XZUXvIpQ95soRNmLdJ4qw,1837
|
9
|
-
tamar_model_client/sync_client.py,sha256=
|
9
|
+
tamar_model_client/sync_client.py,sha256=3BpvW_lexbBapqi4u37LAE00gQlwK3p-aAC7K0ubPTM,53781
|
10
|
+
tamar_model_client/tool_call_helper.py,sha256=MbrrKuQ89uP0AojVtxefCbiKI11QKCoOB6GHMr_brjk,5704
|
10
11
|
tamar_model_client/utils.py,sha256=9gJm71UuQhyyBCgo6gvMjv74xepOlw6AiwuSzea2CL0,5595
|
11
12
|
tamar_model_client/core/__init__.py,sha256=RMiZjV1S4csWPLxB_JfdOea8fYPz97Oj3humQSBw1OI,1054
|
12
|
-
tamar_model_client/core/base_client.py,sha256=
|
13
|
+
tamar_model_client/core/base_client.py,sha256=4_QDpXemqTGVO4jPREe0ou9aW-a9_XD3P_E72h8VGDU,14098
|
13
14
|
tamar_model_client/core/http_fallback.py,sha256=2N7-N_TZrtffDjuv9s3-CD8Xy7qw9AuI5xeWGUnGQ0w,22217
|
14
15
|
tamar_model_client/core/logging_setup.py,sha256=-MXzTR4Ax50H16cbq1jCXbxgayf5fZ0U3o0--fMmxD8,6692
|
15
16
|
tamar_model_client/core/request_builder.py,sha256=aplTEXGgeipn-dRCdUptHYWkT9c4zjKmbmI8Ckbv_sM,8516
|
16
17
|
tamar_model_client/core/request_id_manager.py,sha256=S-Mliaby9zN_bx-B85FvVnttal-w0skkjy2ZvWoQ5vw,3689
|
17
|
-
tamar_model_client/core/response_handler.py,sha256=
|
18
|
+
tamar_model_client/core/response_handler.py,sha256=dmF5GXvBqwtssgVyJDkQc_MxgjBaoY3yOKOk2gJumbE,8613
|
18
19
|
tamar_model_client/core/utils.py,sha256=AcbsGfNQEaZLYI4OZJs-BdmJgxAoLUC5LFoiYmji820,5875
|
19
20
|
tamar_model_client/enums/__init__.py,sha256=3cYYn8ztNGBa_pI_5JGRVYf2QX8fkBVWdjID1PLvoBQ,182
|
20
21
|
tamar_model_client/enums/channel.py,sha256=wCzX579nNpTtwzGeS6S3Ls0UzVAgsOlfy4fXMzQTCAw,199
|
@@ -25,14 +26,15 @@ tamar_model_client/generated/model_service_pb2.py,sha256=RI6wNSmgmylzWPedFfPxx93
|
|
25
26
|
tamar_model_client/generated/model_service_pb2_grpc.py,sha256=k4tIbp3XBxdyuOVR18Ung_4SUryONB51UYf_uUEl6V4,5145
|
26
27
|
tamar_model_client/schemas/__init__.py,sha256=j5XaUGc5Wy669IRNCey56sryCK8QrYbSTl8DSbCqc94,437
|
27
28
|
tamar_model_client/schemas/inputs.py,sha256=Hyl3f-oGT3ooUSs4yn7rHWLJHlltXsQwLSwnlNJx1sE,17503
|
28
|
-
tamar_model_client/schemas/outputs.py,sha256=
|
29
|
+
tamar_model_client/schemas/outputs.py,sha256=Dpd2I-gVr1U154r31OfJ7UrYV5oVh7A5b25cAMKT374,1672
|
29
30
|
tests/__init__.py,sha256=kbmImddLDwdqlkkmkyKtl4bQy_ipe-R8eskpaBylU9w,38
|
30
31
|
tests/stream_hanging_analysis.py,sha256=W3W48IhQbNAR6-xvMpoWZvnWOnr56CTaH4-aORNBuD4,14807
|
31
32
|
tests/test_circuit_breaker.py,sha256=nhEBnyXFjIYjRWlUdu7Z9PnPq48ypbBK6fxN6deHedw,12172
|
32
|
-
tests/test_google_azure_final.py,sha256=
|
33
|
+
tests/test_google_azure_final.py,sha256=QY_1fPK3qrgvP-3rXWeiLc8lQMgCWtR2U-aTcl-xK3I,67009
|
33
34
|
tests/test_logging_issue.py,sha256=JTMbotfHpAEPMBj73pOwxPn-Zn4QVQJX6scMz48FRDQ,2427
|
34
35
|
tests/test_simple.py,sha256=Xf0U-J9_xn_LzUsmYu06suK0_7DrPeko8OHoHldsNxE,7169
|
35
|
-
|
36
|
-
tamar_model_client-0.2.
|
37
|
-
tamar_model_client-0.2.
|
38
|
-
tamar_model_client-0.2.
|
36
|
+
tests/test_tool_call_enhancement.py,sha256=vhhT2_Ni9B0emycT1q7qHn4P9KqDN8UNPzFswTMa4_c,21591
|
37
|
+
tamar_model_client-0.2.7.dist-info/METADATA,sha256=7yk65vRlF1VJRkxZzBKoRROvliNE39wlVogxjR9Fo5w,44637
|
38
|
+
tamar_model_client-0.2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
39
|
+
tamar_model_client-0.2.7.dist-info/top_level.txt,sha256=f1I-S8iWN-cgv4gB8gxRg9jJOTJMumvm4oGKVPfGg6A,25
|
40
|
+
tamar_model_client-0.2.7.dist-info/RECORD,,
|
tests/test_google_azure_final.py
CHANGED
@@ -1471,36 +1471,36 @@ async def main():
|
|
1471
1471
|
|
1472
1472
|
try:
|
1473
1473
|
# 同步测试
|
1474
|
-
|
1475
|
-
|
1476
|
-
|
1474
|
+
test_google_ai_studio()
|
1475
|
+
test_google_vertex_ai()
|
1476
|
+
test_azure_openai()
|
1477
1477
|
|
1478
1478
|
# 新增:图像生成测试
|
1479
|
-
|
1480
|
-
|
1479
|
+
test_google_genai_image_generation()
|
1480
|
+
test_google_vertex_ai_image_generation()
|
1481
1481
|
|
1482
1482
|
# 同步批量测试
|
1483
|
-
|
1483
|
+
test_sync_batch_requests()
|
1484
1484
|
|
1485
1485
|
# 异步流式测试
|
1486
|
-
|
1487
|
-
|
1486
|
+
await asyncio.wait_for(test_google_streaming(), timeout=60.0)
|
1487
|
+
await asyncio.wait_for(test_azure_streaming(), timeout=60.0)
|
1488
1488
|
|
1489
|
-
|
1490
|
-
|
1491
|
-
|
1492
|
-
|
1493
|
-
#
|
1494
|
-
|
1495
|
-
|
1496
|
-
#
|
1497
|
-
|
1489
|
+
#:异步图像生成测试
|
1490
|
+
await asyncio.wait_for(test_google_genai_image_generation_async(), timeout=120.0)
|
1491
|
+
await asyncio.wait_for(test_google_vertex_ai_image_generation_async(), timeout=120.0)
|
1492
|
+
|
1493
|
+
# 异步批量测试
|
1494
|
+
await asyncio.wait_for(test_batch_requests(), timeout=120.0)
|
1495
|
+
|
1496
|
+
# 新增:图像生成批量测试
|
1497
|
+
await asyncio.wait_for(test_image_generation_batch(), timeout=180.0)
|
1498
1498
|
|
1499
1499
|
# 同步并发测试
|
1500
|
-
|
1501
|
-
|
1502
|
-
#
|
1503
|
-
|
1500
|
+
test_concurrent_requests(2) # 测试150个并发请求
|
1501
|
+
|
1502
|
+
# 异步并发测试
|
1503
|
+
await test_async_concurrent_requests(2) # 测试50个异步并发请求(复用连接)
|
1504
1504
|
|
1505
1505
|
print("\n✅ 测试完成")
|
1506
1506
|
|
@@ -0,0 +1,571 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Tool Call Enhancement 测试脚本
|
4
|
+
|
5
|
+
测试 Tool Call 功能的增强实现,包括:
|
6
|
+
1. ModelResponse 的 tool_calls 和 finish_reason 字段
|
7
|
+
2. ToolCallHelper 工具类的便利方法
|
8
|
+
3. ResponseHandler 的自动提取功能
|
9
|
+
"""
|
10
|
+
|
11
|
+
import asyncio
|
12
|
+
import json
|
13
|
+
import logging
|
14
|
+
import os
|
15
|
+
import sys
|
16
|
+
from unittest.mock import Mock
|
17
|
+
|
18
|
+
# Dedicated logger for this test script. propagate=False keeps its records
# from also being emitted by handlers attached to ancestor loggers.
test_handler = logging.StreamHandler()
test_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))

test_logger = logging.getLogger('test_tool_call_enhancement')
test_logger.propagate = False
test_logger.setLevel(logging.INFO)
test_logger.addHandler(test_handler)

logger = test_logger
|
28
|
+
|
29
|
+
# Tool function implementations used by the tool-call demos below.
def get_weather(location: str) -> str:
    """Return a canned weather description for ``location`` (test stub).

    Args:
        location: City name.

    Returns:
        The canned description for known cities; otherwise a generic
        fine-weather sentence that embeds ``location``.
    """
    # Mock weather data keyed by city name.
    canned_reports = {
        "北京": "北京今天晴天,温度25°C,微风",
        "上海": "上海今天多云,温度28°C,湿度较高",
        "广州": "广州今天阴天,温度32°C,有雷阵雨",
        "深圳": "深圳今天晴天,温度30°C,空气质量良好",
        "杭州": "杭州今天小雨,温度22°C,建议带伞",
        "成都": "成都今天阴天,温度26°C,空气湿润"
    }

    if location in canned_reports:
        return canned_reports[location]

    # Fallback for cities without canned data.
    return f"{location}今天天气晴朗,温度适宜"
|
51
|
+
|
52
|
+
# Test environment: point the client at a plaintext (no TLS) gRPC server on localhost.
os.environ.update({
    'MODEL_MANAGER_SERVER_GRPC_USE_TLS': "false",
    'MODEL_MANAGER_SERVER_ADDRESS': "localhost:50051",
    'MODEL_MANAGER_SERVER_JWT_SECRET_KEY': "model-manager-server-jwt-key",
})
|
56
|
+
|
57
|
+
# 导入客户端模块
|
58
|
+
try:
|
59
|
+
from tamar_model_client import TamarModelClient, AsyncTamarModelClient
|
60
|
+
from tamar_model_client.schemas import ModelRequest, UserContext
|
61
|
+
from tamar_model_client.enums import ProviderType, InvokeType, Channel
|
62
|
+
|
63
|
+
# 为了调试,临时启用 SDK 的日志输出
|
64
|
+
os.environ['TAMAR_MODEL_CLIENT_LOG_LEVEL'] = 'INFO'
|
65
|
+
|
66
|
+
except ImportError as e:
|
67
|
+
logger.error(f"导入模块失败: {e}")
|
68
|
+
sys.exit(1)
|
69
|
+
|
70
|
+
|
71
|
+
def test_model_response_enhancement():
    """Test the tool-call enhancements on ModelResponse.

    Covers ``has_tool_calls()`` with a populated list, with ``tool_calls``
    left as None, and with an explicitly empty list. Progress is printed; on
    failure the error is printed and then re-raised so the test actually fails
    instead of silently passing.
    """
    print("\n📋 测试 ModelResponse 增强功能...")

    try:
        from tamar_model_client.schemas.outputs import ModelResponse

        # Case 1: response carrying a tool call
        response_with_tools = ModelResponse(
            content="I need to call some tools.",
            tool_calls=[
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"location": "Beijing"}'}
                }
            ],
            finish_reason="tool_calls"
        )

        assert response_with_tools.has_tool_calls() is True
        assert len(response_with_tools.tool_calls) == 1
        print(" ✅ 有工具调用的情况测试通过")

        # Case 2: no tool calls at all
        response_without_tools = ModelResponse(
            content="Here is the answer.",
            finish_reason="stop"
        )

        assert response_without_tools.has_tool_calls() is False
        assert response_without_tools.tool_calls is None
        print(" ✅ 无工具调用的情况测试通过")

        # Case 3: explicitly empty tool-call list
        response_empty_tools = ModelResponse(
            content="Here is the answer.",
            tool_calls=[],
            finish_reason="stop"
        )

        assert response_empty_tools.has_tool_calls() is False
        print(" ✅ 空工具调用列表的情况测试通过")

        print("✅ ModelResponse 增强功能测试全部通过")

    except Exception as e:
        print(f"❌ ModelResponse 增强功能测试失败: {str(e)}")
        # Re-raise: printing alone would let a failed assertion pass the test.
        raise
|
119
|
+
|
120
|
+
|
121
|
+
def test_tool_call_helper():
    """Test the ToolCallHelper convenience methods.

    Exercises tool-definition creation, argument parsing, tool-response
    message creation and follow-up message list building. On failure the
    error is printed and then re-raised so the test actually fails.
    """
    print("\n🔧 测试 ToolCallHelper 工具类...")

    try:
        from tamar_model_client import ToolCallHelper
        from tamar_model_client.schemas.outputs import ModelResponse

        # Creating a function tool definition
        tool = ToolCallHelper.create_function_tool(
            name="test_func",
            description="测试函数",
            parameters={
                "type": "object",
                "properties": {"param1": {"type": "string"}},
                "required": ["param1"]
            }
        )

        assert tool["type"] == "function"
        assert tool["function"]["name"] == "test_func"
        assert tool["function"]["description"] == "测试函数"
        assert "param1" in tool["function"]["parameters"]["properties"]
        print(" ✅ 创建函数工具测试通过")

        # Parsing function arguments
        tool_call = {
            "type": "function",
            "function": {
                "name": "test_func",
                "arguments": '{"location": "Beijing", "unit": "celsius"}'
            }
        }

        args = ToolCallHelper.parse_function_arguments(tool_call)
        assert args["location"] == "Beijing"
        assert args["unit"] == "celsius"
        print(" ✅ 解析函数参数测试通过")

        # Creating a tool response message
        response_msg = ToolCallHelper.create_tool_response_message(
            "call_123",
            "Tool execution result",
            "test_tool"
        )

        assert response_msg["role"] == "tool"
        assert response_msg["tool_call_id"] == "call_123"
        assert response_msg["content"] == "Tool execution result"
        assert response_msg["name"] == "test_tool"
        print(" ✅ 创建工具响应消息测试通过")

        # Building the follow-up message list with tool responses
        original_messages = [
            {"role": "user", "content": "What's the weather?"}
        ]

        assistant_response = ModelResponse(
            content="I'll check the weather for you.",
            tool_calls=[
                {
                    "id": "call_weather",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"location": "Beijing"}'}
                }
            ],
            finish_reason="tool_calls"
        )

        tool_responses = [
            {
                "role": "tool",
                "tool_call_id": "call_weather",
                "content": "Beijing: Sunny, 25°C"
            }
        ]

        new_messages = ToolCallHelper.build_messages_with_tool_response(
            original_messages, assistant_response, tool_responses
        )

        assert len(new_messages) == 3
        assert new_messages[0]["role"] == "user"
        assert new_messages[1]["role"] == "assistant"
        assert new_messages[1]["tool_calls"] == assistant_response.tool_calls
        assert new_messages[2]["role"] == "tool"
        assert new_messages[2]["tool_call_id"] == "call_weather"
        print(" ✅ 构建消息列表测试通过")

        print("✅ ToolCallHelper 工具类测试全部通过")

    except Exception as e:
        print(f"❌ ToolCallHelper 工具类测试失败: {str(e)}")
        # Re-raise: printing alone would let a failed assertion pass the test.
        raise
|
214
|
+
|
215
|
+
|
216
|
+
def test_response_handler_enhancement():
    """Test ResponseHandler's automatic tool-call extraction.

    Feeds mocked gRPC responses in OpenAI and Google formats through
    ``ResponseHandler.build_model_response`` and checks the extracted
    ``tool_calls`` / ``finish_reason``. On failure the error is printed and
    then re-raised so the test actually fails.
    """
    print("\n🔄 测试 ResponseHandler 增强功能...")

    try:
        from tamar_model_client.core.response_handler import ResponseHandler

        # OpenAI-format tool call extraction
        mock_grpc_response = Mock()
        mock_grpc_response.content = ""
        mock_grpc_response.usage = None
        mock_grpc_response.error = None
        mock_grpc_response.request_id = "req_123"
        mock_grpc_response.raw_response = json.dumps({
            "choices": [
                {
                    "message": {
                        "tool_calls": [
                            {
                                "id": "call_456",
                                "type": "function",
                                "function": {
                                    "name": "get_weather",
                                    "arguments": '{"location": "Shanghai"}'
                                }
                            }
                        ]
                    },
                    "finish_reason": "tool_calls"
                }
            ]
        })

        response = ResponseHandler.build_model_response(mock_grpc_response)

        assert response.has_tool_calls() is True
        assert response.finish_reason == "tool_calls"
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0]["function"]["name"] == "get_weather"
        print(" ✅ OpenAI 格式 tool calls 提取测试通过")

        # Google-format conversion
        mock_google_response = Mock()
        mock_google_response.content = ""
        mock_google_response.usage = None
        mock_google_response.error = None
        mock_google_response.request_id = "req_456"
        mock_google_response.raw_response = json.dumps({
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {
                                "functionCall": {
                                    "name": "get_weather",
                                    "args": {"location": "Guangzhou"}
                                }
                            }
                        ]
                    },
                    "finishReason": "STOP"
                }
            ]
        })

        google_response = ResponseHandler.build_model_response(mock_google_response)

        assert google_response.has_tool_calls() is True
        # STOP is overridden to tool_calls because a function call is present.
        assert google_response.finish_reason == "tool_calls"
        assert len(google_response.tool_calls) == 1

        tool_call = google_response.tool_calls[0]
        assert tool_call["type"] == "function"
        assert tool_call["function"]["name"] == "get_weather"
        assert "call_0_get_weather" in tool_call["id"]

        # Google's args object must round-trip through the JSON arguments string.
        args = json.loads(tool_call["function"]["arguments"])
        assert args["location"] == "Guangzhou"
        print(" ✅ Google 格式转换测试通过")

        print("✅ ResponseHandler 增强功能测试全部通过")

    except Exception as e:
        print(f"❌ ResponseHandler 增强功能测试失败: {str(e)}")
        # Re-raise: printing alone would let a failed assertion pass the test.
        raise
|
301
|
+
|
302
|
+
|
303
|
+
def test_openai_tool_call():
    """Smoke-test an OpenAI (Azure) chat-completions request with tools.

    Needs a live model-manager server. Errors are printed rather than raised,
    so this acts as a best-effort demo: it prints the response and, for each
    returned tool call, the parsed arguments and the local ``get_weather``
    result.
    """
    print("\n🔧 测试 OpenAI Tool Call...")

    try:
        from tamar_model_client import ToolCallHelper

        client = TamarModelClient()

        # NOTE(review): this is the flat tool shape (name/parameters at the top
        # level), unlike the nested {"type": "function", "function": {...}} shape
        # built by ToolCallHelper.create_function_tool. Confirm which one the
        # server expects for CHAT_COMPLETIONS.
        request = ModelRequest(
            provider=ProviderType.AZURE,
            invoke_type=InvokeType.CHAT_COMPLETIONS,
            model="gpt-4o-mini",
            messages=[
                {"role": "user", "content": "北京今天天气如何?"}
            ],
            tools=[
                {
                    "type": "function",
                    "name": "get_weather",
                    "description": "获取指定城市的天气信息",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {
                                "type": "string",
                                "description": "城市名称"
                            }
                        },
                        "required": ["location"]
                    },
                    "strict": None
                }
            ],
            tool_choice="auto",
            user_context=UserContext(
                user_id="test_user",
                org_id="test_org",
                client_type="tool_call_test"
            )
        )

        response = client.invoke(request)

        print("✅ OpenAI Tool Call 测试成功")
        print(f" 响应类型: {type(response)}")
        print(f" 是否有 tool calls: {response.has_tool_calls()}")
        print(f" finish_reason: {response.finish_reason}")

        if response.has_tool_calls():
            print(f" tool_calls 数量: {len(response.tool_calls)}")
            for i, tool_call in enumerate(response.tool_calls):
                function_name = tool_call['function']['name']
                print(f" 工具 {i+1}: {function_name}")
                args = ToolCallHelper.parse_function_arguments(tool_call)
                print(f" 参数: {args}")

                # Demonstrate executing the matching local tool function
                if function_name == "get_weather":
                    result = get_weather(args['location'])
                    print(f" 执行结果: {result}")

        print(f" 响应内容: {str(response.content)[:200]}...")

    except Exception as e:
        print(f"❌ OpenAI Tool Call 测试失败: {str(e)}")
|
368
|
+
|
369
|
+
|
370
|
+
def test_google_tool_call():
    """Smoke-test a Google generation request with function declarations.

    Needs a live model-manager server. Errors are printed rather than raised,
    so this acts as a best-effort demo: it prints the response and, for each
    returned tool call, the parsed arguments and the local ``get_weather``
    result.
    """
    print("\n🔧 测试 Google Tool Call...")

    try:
        from tamar_model_client import ToolCallHelper

        client = TamarModelClient()

        request = ModelRequest(
            provider=ProviderType.GOOGLE,
            invoke_type=InvokeType.GENERATION,
            model="tamar-google-gemini-flash-lite",
            contents=[
                {"role": "user", "parts": [{"text": "上海今天天气如何?"}]}
            ],
            config={
                "tools": [
                    {
                        "functionDeclarations": [
                            {
                                "name": "get_weather",
                                "description": "获取指定城市的天气信息",
                                "parameters": {
                                    "type": "object",
                                    "properties": {
                                        "location": {
                                            "type": "string",
                                            "description": "城市名称"
                                        }
                                    },
                                    "required": ["location"]
                                }
                            }
                        ]
                    }
                ]
            },
            user_context=UserContext(
                user_id="test_user",
                org_id="test_org",
                client_type="google_tool_test"
            )
        )

        response = client.invoke(request)

        print("✅ Google Tool Call 测试成功")
        print(f" 响应类型: {type(response)}")
        print(f" 是否有 tool calls: {response.has_tool_calls()}")
        print(f" finish_reason: {response.finish_reason}")

        if response.has_tool_calls():
            print(f" tool_calls 数量: {len(response.tool_calls)}")
            for i, tool_call in enumerate(response.tool_calls):
                function_name = tool_call['function']['name']
                print(f" 工具 {i+1}: {function_name}")
                args = ToolCallHelper.parse_function_arguments(tool_call)
                print(f" 参数: {args}")

                # Demonstrate executing the matching local tool function
                if function_name == "get_weather":
                    result = get_weather(args['location'])
                    print(f" 执行结果: {result}")

        print(f" 响应内容: {str(response.content)[:200]}...")

    except Exception as e:
        print(f"❌ Google Tool Call 测试失败: {str(e)}")
|
438
|
+
|
439
|
+
|
440
|
+
async def test_async_tool_call_workflow():
    """Run a full async tool-call round trip against the Azure chat endpoint.

    Steps:
    1. Send a CHAT_COMPLETIONS request that offers a ``get_weather`` tool.
    2. If the response contains tool calls, execute each one locally and wrap
       the result with ``ToolCallHelper.create_tool_response_message``.
    3. Build the follow-up message list with
       ``ToolCallHelper.build_messages_with_tool_response``.
    4. Send the follow-up request and print the start of the final answer.

    Any exception is caught and reported with a printed failure line rather
    than re-raised, so the calling script can continue.
    """
    print("\n🔄 测试异步工具调用工作流程...")

    try:
        from tamar_model_client import ToolCallHelper

        async with AsyncTamarModelClient() as client:
            # 1. Send the request that offers the tool.
            initial_messages = [
                {"role": "user", "content": "深圳今天天气怎么样?"}
            ]

            request = ModelRequest(
                provider=ProviderType.AZURE,
                invoke_type=InvokeType.CHAT_COMPLETIONS,
                model="gpt-4o-mini",
                messages=initial_messages,
                # NOTE(review): this is a flat tool schema (name/parameters at
                # the top level, plus "strict"). The Chat Completions API
                # normally nests these under a "function" key. Confirm that the
                # client/server translates this shape.
                tools=[
                    {
                        "type": "function",
                        "name": "get_weather",
                        "description": "获取天气信息",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "location": {"type": "string", "description": "城市名称"}
                            },
                            "required": ["location"]
                        },
                        "strict": None
                    }
                ],
                tool_choice="auto",
                user_context=UserContext(
                    user_id="async_test_user",
                    org_id="test_org",
                    client_type="async_tool_test"
                )
            )

            response = await client.invoke(request)

            print(" 步骤 1: 初始请求完成")
            print(f" 是否需要工具调用: {response.has_tool_calls()}")

            if response.has_tool_calls():
                # 2. Execute each requested tool locally. Every tool call gets
                # a reply (even on bad input), so each tool_call id in the
                # assistant message has a matching tool message in the
                # follow-up. This is presumably required by the chat API.
                tool_responses = []
                for tool_call in response.tool_calls:
                    function_name = tool_call["function"]["name"]
                    args = ToolCallHelper.parse_function_arguments(tool_call)

                    # Dispatch on the function name.
                    if function_name == "get_weather":
                        location = args.get("location")
                        if location is None:
                            weather_result = "缺少参数: location"
                        else:
                            weather_result = get_weather(location)
                    else:
                        weather_result = f"未知函数: {function_name}"

                    tool_responses.append(
                        ToolCallHelper.create_tool_response_message(
                            tool_call["id"],
                            weather_result,
                            function_name
                        )
                    )

                # 3. Build the new message list including the tool replies.
                new_messages = ToolCallHelper.build_messages_with_tool_response(
                    initial_messages,
                    response,
                    tool_responses
                )

                # 4. Send the follow-up request carrying the tool replies.
                follow_up_request = ModelRequest(
                    provider=ProviderType.AZURE,
                    invoke_type=InvokeType.CHAT_COMPLETIONS,
                    model="gpt-4o-mini",
                    messages=new_messages,
                    user_context=UserContext(
                        user_id="async_test_user",
                        org_id="test_org",
                        client_type="async_tool_test"
                    )
                )

                final_response = await client.invoke(follow_up_request)

                print(" 步骤 2: 工具调用模拟完成")
                print(" 步骤 3: 最终回复生成完成")
                # str() keeps a None content from raising TypeError on slicing.
                print(f" 最终回复: {str(final_response.content)[:100]}...")

                print("✅ 异步工具调用工作流程测试成功")
            else:
                print(f" 模型没有请求工具调用,直接回复: {response.content}")

    except Exception as e:
        print(f"❌ 异步工具调用工作流程测试失败: {str(e)}")
|
539
|
+
|
540
|
+
|
541
|
+
async def main():
    """Script entry point: run the offline unit checks, then the live scenarios.

    A KeyboardInterrupt or any other exception is reported with a printed
    message instead of propagating. The closing line is always printed.
    """
    print("🚀 Tool Call Enhancement 功能测试")
    print("=" * 50)

    try:
        # Offline unit checks.
        print("\n📋 运行单元测试...")
        unit_checks = (
            test_model_response_enhancement,
            test_tool_call_helper,
            test_response_handler_enhancement,
        )
        for run_check in unit_checks:
            run_check()
        print("\n✅ 所有单元测试通过")

        # Scenarios that call the real model service.
        print("\n🌐 运行真实场景测试...")
        sync_scenarios = (test_openai_tool_call, test_google_tool_call)
        for run_scenario in sync_scenarios:
            run_scenario()
        await test_async_tool_call_workflow()

        print("\n✅ 所有测试完成")

    except KeyboardInterrupt:
        print("\n⚠️ 测试被用户中断")
    except Exception as exc:
        print(f"\n❌ 测试执行出错: {exc}")
    finally:
        print("🏁 测试程序已退出")


if __name__ == "__main__":
    # Drive the async entry point on a fresh event loop.
    asyncio.run(main())
|
File without changes
|
File without changes
|