tamar-model-client 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/__init__.py +2 -0
- tamar_model_client/async_client.py +430 -539
- tamar_model_client/core/__init__.py +34 -0
- tamar_model_client/core/base_client.py +168 -0
- tamar_model_client/core/logging_setup.py +84 -0
- tamar_model_client/core/request_builder.py +221 -0
- tamar_model_client/core/response_handler.py +136 -0
- tamar_model_client/core/utils.py +171 -0
- tamar_model_client/error_handler.py +283 -0
- tamar_model_client/exceptions.py +371 -7
- tamar_model_client/json_formatter.py +36 -1
- tamar_model_client/logging_icons.py +60 -0
- tamar_model_client/sync_client.py +473 -485
- {tamar_model_client-0.1.18.dist-info → tamar_model_client-0.1.20.dist-info}/METADATA +217 -61
- tamar_model_client-0.1.20.dist-info/RECORD +33 -0
- {tamar_model_client-0.1.18.dist-info → tamar_model_client-0.1.20.dist-info}/top_level.txt +1 -0
- tests/__init__.py +1 -0
- tests/stream_hanging_analysis.py +357 -0
- tests/test_google_azure_final.py +448 -0
- tests/test_simple.py +235 -0
- tamar_model_client-0.1.18.dist-info/RECORD +0 -21
- {tamar_model_client-0.1.18.dist-info → tamar_model_client-0.1.20.dist-info}/WHEEL +0 -0
@@ -1,273 +1,180 @@
|
|
1
|
+
"""
|
2
|
+
Tamar Model Client 异步客户端实现
|
3
|
+
|
4
|
+
本模块实现了异步的 gRPC 客户端,用于与 Model Manager Server 进行通信。
|
5
|
+
支持单个请求、批量请求、流式响应等功能,并提供了完整的错误处理和重试机制。
|
6
|
+
|
7
|
+
主要功能:
|
8
|
+
- 异步 gRPC 通信
|
9
|
+
- JWT 认证
|
10
|
+
- 自动重试和错误处理
|
11
|
+
- 流式响应支持
|
12
|
+
- 连接池管理
|
13
|
+
- 详细的日志记录
|
14
|
+
|
15
|
+
使用示例:
|
16
|
+
async with AsyncTamarModelClient() as client:
|
17
|
+
request = ModelRequest(...)
|
18
|
+
response = await client.invoke(request)
|
19
|
+
"""
|
20
|
+
|
1
21
|
import asyncio
|
2
22
|
import atexit
|
3
|
-
import base64
|
4
23
|
import json
|
5
24
|
import logging
|
6
|
-
import os
|
7
25
|
import time
|
8
|
-
import
|
9
|
-
from contextvars import ContextVar
|
26
|
+
from typing import Optional, AsyncIterator, Union
|
10
27
|
|
11
28
|
import grpc
|
12
|
-
from
|
13
|
-
|
14
|
-
from
|
15
|
-
|
16
|
-
|
17
|
-
|
29
|
+
from grpc import RpcError
|
30
|
+
|
31
|
+
from .core import (
|
32
|
+
generate_request_id,
|
33
|
+
set_request_id,
|
34
|
+
setup_logger,
|
35
|
+
MAX_MESSAGE_LENGTH
|
36
|
+
)
|
37
|
+
from .core.base_client import BaseClient
|
38
|
+
from .core.request_builder import RequestBuilder
|
39
|
+
from .core.response_handler import ResponseHandler
|
18
40
|
from .enums import ProviderType, InvokeType
|
19
|
-
from .exceptions import ConnectionError
|
41
|
+
from .exceptions import ConnectionError, TamarModelException
|
42
|
+
from .error_handler import EnhancedRetryHandler
|
20
43
|
from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
|
21
44
|
from .generated import model_service_pb2, model_service_pb2_grpc
|
22
|
-
from .schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput, \
|
23
|
-
GoogleVertexAIImagesInput, OpenAIImagesInput, OpenAIImagesEditInput
|
24
|
-
from .json_formatter import JSONFormatter
|
25
|
-
|
26
|
-
logger = logging.getLogger(__name__)
|
27
|
-
|
28
|
-
# 使用 contextvars 管理请求ID
|
29
|
-
_request_id: ContextVar[str] = ContextVar('request_id', default='-')
|
30
|
-
|
31
|
-
|
32
|
-
class RequestIdFilter(logging.Filter):
|
33
|
-
"""自定义日志过滤器,向日志中添加 request_id"""
|
34
|
-
|
35
|
-
def filter(self, record):
|
36
|
-
# 从 ContextVar 中获取当前的 request_id
|
37
|
-
record.request_id = _request_id.get()
|
38
|
-
return True
|
39
|
-
|
40
|
-
|
41
|
-
if not logger.hasHandlers():
|
42
|
-
# 创建日志处理器,输出到控制台
|
43
|
-
console_handler = logging.StreamHandler()
|
44
|
-
|
45
|
-
# 使用 JSON 格式化器
|
46
|
-
formatter = JSONFormatter()
|
47
|
-
console_handler.setFormatter(formatter)
|
48
|
-
|
49
|
-
# 为当前记录器添加处理器
|
50
|
-
logger.addHandler(console_handler)
|
51
|
-
|
52
|
-
# 设置日志级别
|
53
|
-
logger.setLevel(logging.INFO)
|
54
|
-
|
55
|
-
# 将自定义的 RequestIdFilter 添加到 logger 中
|
56
|
-
logger.addFilter(RequestIdFilter())
|
57
45
|
|
58
|
-
|
46
|
+
# 配置日志记录器
|
47
|
+
logger = setup_logger(__name__)
|
59
48
|
|
60
49
|
|
61
|
-
|
50
|
+
class AsyncTamarModelClient(BaseClient):
|
62
51
|
"""
|
63
|
-
|
52
|
+
Tamar Model Client 异步客户端
|
53
|
+
|
54
|
+
提供与 Model Manager Server 的异步通信能力,支持:
|
55
|
+
- 单个和批量模型调用
|
56
|
+
- 流式和非流式响应
|
57
|
+
- 自动重试和错误恢复
|
58
|
+
- JWT 认证
|
59
|
+
- 连接池管理
|
60
|
+
|
61
|
+
使用示例:
|
62
|
+
# 基本用法
|
63
|
+
client = AsyncTamarModelClient()
|
64
|
+
await client.connect()
|
65
|
+
|
66
|
+
request = ModelRequest(...)
|
67
|
+
response = await client.invoke(request)
|
68
|
+
|
69
|
+
# 上下文管理器用法(推荐)
|
70
|
+
async with AsyncTamarModelClient() as client:
|
71
|
+
response = await client.invoke(request)
|
72
|
+
|
73
|
+
环境变量配置:
|
74
|
+
MODEL_MANAGER_SERVER_ADDRESS: gRPC 服务器地址
|
75
|
+
MODEL_MANAGER_SERVER_JWT_SECRET_KEY: JWT 密钥
|
76
|
+
MODEL_MANAGER_SERVER_GRPC_USE_TLS: 是否使用 TLS
|
77
|
+
MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES: 最大重试次数
|
78
|
+
MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY: 重试延迟
|
64
79
|
"""
|
65
|
-
if value is None or value is NOT_GIVEN:
|
66
|
-
return False
|
67
|
-
|
68
|
-
if isinstance(value, str):
|
69
|
-
return value.strip() != ""
|
70
|
-
|
71
|
-
if isinstance(value, bytes):
|
72
|
-
return len(value) > 0
|
73
|
-
|
74
|
-
if isinstance(value, dict):
|
75
|
-
for v in value.values():
|
76
|
-
if is_effective_value(v):
|
77
|
-
return True
|
78
|
-
return False
|
79
|
-
|
80
|
-
if isinstance(value, list):
|
81
|
-
for item in value:
|
82
|
-
if is_effective_value(item):
|
83
|
-
return True
|
84
|
-
return False
|
85
|
-
|
86
|
-
return True # 其他类型(int/float/bool)只要不是None就算有效
|
87
80
|
|
81
|
+
def __init__(self, **kwargs):
|
82
|
+
"""
|
83
|
+
初始化异步客户端
|
84
|
+
|
85
|
+
参数继承自 BaseClient,包括:
|
86
|
+
- server_address: gRPC 服务器地址
|
87
|
+
- jwt_secret_key: JWT 签名密钥
|
88
|
+
- jwt_token: 预生成的 JWT 令牌
|
89
|
+
- default_payload: JWT 令牌的默认载荷
|
90
|
+
- token_expires_in: JWT 令牌过期时间
|
91
|
+
- max_retries: 最大重试次数
|
92
|
+
- retry_delay: 初始重试延迟
|
93
|
+
"""
|
94
|
+
super().__init__(logger_name=__name__, **kwargs)
|
95
|
+
|
96
|
+
# === gRPC 通道和连接管理 ===
|
97
|
+
self.channel: Optional[grpc.aio.Channel] = None
|
98
|
+
self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
|
99
|
+
|
100
|
+
# === 增强的重试处理器 ===
|
101
|
+
self.retry_handler = EnhancedRetryHandler(
|
102
|
+
max_retries=self.max_retries,
|
103
|
+
base_delay=self.retry_delay
|
104
|
+
)
|
105
|
+
|
106
|
+
# 注册退出时的清理函数
|
107
|
+
atexit.register(self._cleanup_atexit)
|
88
108
|
|
89
|
-
def
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
if isinstance(value, dict):
|
98
|
-
return {k: serialize_value(v) for k, v in value.items()}
|
99
|
-
if isinstance(value, list) or (isinstance(value, Iterable) and not isinstance(value, (str, bytes))):
|
100
|
-
return [serialize_value(v) for v in value]
|
101
|
-
if isinstance(value, bytes):
|
102
|
-
return f"bytes:{base64.b64encode(value).decode('utf-8')}"
|
103
|
-
return value
|
109
|
+
def _cleanup_atexit(self):
|
110
|
+
"""程序退出时的清理函数"""
|
111
|
+
if self.channel and not self._closed:
|
112
|
+
try:
|
113
|
+
asyncio.create_task(self.close())
|
114
|
+
except RuntimeError:
|
115
|
+
# 如果事件循环已经关闭,忽略错误
|
116
|
+
pass
|
104
117
|
|
118
|
+
async def close(self):
|
119
|
+
"""
|
120
|
+
关闭客户端连接
|
121
|
+
|
122
|
+
优雅地关闭 gRPC 通道并清理资源。
|
123
|
+
建议在程序结束前调用此方法,或使用上下文管理器自动管理。
|
124
|
+
"""
|
125
|
+
if self.channel and not self._closed:
|
126
|
+
await self.channel.close()
|
127
|
+
self._closed = True
|
128
|
+
logger.info("🔒 gRPC channel closed",
|
129
|
+
extra={"log_type": "info", "data": {"status": "closed"}})
|
105
130
|
|
106
|
-
|
131
|
+
async def __aenter__(self):
|
132
|
+
"""异步上下文管理器入口"""
|
133
|
+
await self.connect()
|
134
|
+
return self
|
107
135
|
|
136
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
137
|
+
"""异步上下文管理器出口"""
|
138
|
+
await self.close()
|
108
139
|
|
109
|
-
def
|
110
|
-
|
111
|
-
|
112
|
-
"""
|
113
|
-
if isinstance(data, dict):
|
114
|
-
new_dict = {}
|
115
|
-
for key, value in data.items():
|
116
|
-
if value is None:
|
117
|
-
continue
|
118
|
-
cleaned_value = remove_none_from_dict(value)
|
119
|
-
new_dict[key] = cleaned_value
|
120
|
-
return new_dict
|
121
|
-
elif isinstance(data, list):
|
122
|
-
return [remove_none_from_dict(item) for item in data]
|
123
|
-
else:
|
124
|
-
return data
|
125
|
-
|
126
|
-
|
127
|
-
def generate_request_id():
|
128
|
-
"""生成一个唯一的request_id"""
|
129
|
-
return str(uuid.uuid4())
|
130
|
-
|
131
|
-
|
132
|
-
def set_request_id(request_id: str):
|
133
|
-
"""设置当前请求的 request_id"""
|
134
|
-
_request_id.set(request_id)
|
135
|
-
|
136
|
-
|
137
|
-
class AsyncTamarModelClient:
|
138
|
-
def __init__(
|
139
|
-
self,
|
140
|
-
server_address: Optional[str] = None,
|
141
|
-
jwt_secret_key: Optional[str] = None,
|
142
|
-
jwt_token: Optional[str] = None,
|
143
|
-
default_payload: Optional[dict] = None,
|
144
|
-
token_expires_in: int = 3600,
|
145
|
-
max_retries: Optional[int] = None, # 最大重试次数
|
146
|
-
retry_delay: Optional[float] = None, # 初始重试延迟(秒)
|
147
|
-
):
|
148
|
-
# 服务端地址
|
149
|
-
self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
|
150
|
-
if not self.server_address:
|
151
|
-
raise ValueError("Server address must be provided via argument or environment variable.")
|
152
|
-
self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))
|
153
|
-
|
154
|
-
# JWT 配置
|
155
|
-
self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
|
156
|
-
self.jwt_handler = JWTAuthHandler(self.jwt_secret_key)
|
157
|
-
self.jwt_token = jwt_token # 用户传入的 Token(可选)
|
158
|
-
self.default_payload = default_payload
|
159
|
-
self.token_expires_in = token_expires_in
|
160
|
-
|
161
|
-
# === TLS/Authority 配置 ===
|
162
|
-
self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
|
163
|
-
self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")
|
164
|
-
|
165
|
-
# === 重试配置 ===
|
166
|
-
self.max_retries = max_retries if max_retries is not None else int(
|
167
|
-
os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
|
168
|
-
self.retry_delay = retry_delay if retry_delay is not None else float(
|
169
|
-
os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))
|
170
|
-
|
171
|
-
# === gRPC 通道相关 ===
|
172
|
-
self.channel: Optional[grpc.aio.Channel] = None
|
173
|
-
self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
|
174
|
-
self._closed = False
|
175
|
-
atexit.register(self._safe_sync_close) # 注册进程退出自动关闭
|
140
|
+
def __enter__(self):
|
141
|
+
"""同步上下文管理器入口(不支持)"""
|
142
|
+
raise TypeError("Use 'async with' for AsyncTamarModelClient")
|
176
143
|
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
try:
|
181
|
-
return await func(*args, **kwargs)
|
182
|
-
except (grpc.aio.AioRpcError, grpc.RpcError) as e:
|
183
|
-
# 对于取消的情况进行指数退避重试
|
184
|
-
if isinstance(e, grpc.aio.AioRpcError) and e.code() == grpc.StatusCode.CANCELLED:
|
185
|
-
retry_count += 1
|
186
|
-
logger.info(f"❌ RPC cancelled, retrying {retry_count}/{self.max_retries}...",
|
187
|
-
extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": "CANCELLED"}})
|
188
|
-
if retry_count < self.max_retries:
|
189
|
-
delay = self.retry_delay * (2 ** (retry_count - 1))
|
190
|
-
await asyncio.sleep(delay)
|
191
|
-
else:
|
192
|
-
logger.error("❌ Max retry reached for CANCELLED",
|
193
|
-
extra={"log_type": "info", "data": {"error_code": "CANCELLED", "max_retries_reached": True}})
|
194
|
-
raise
|
195
|
-
# 针对其他 RPC 错误类型,如暂时的连接问题、服务器超时等
|
196
|
-
elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
|
197
|
-
grpc.StatusCode.DEADLINE_EXCEEDED}:
|
198
|
-
retry_count += 1
|
199
|
-
logger.info(f"❌ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...",
|
200
|
-
extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": str(e.code())}})
|
201
|
-
if retry_count < self.max_retries:
|
202
|
-
delay = self.retry_delay * (2 ** (retry_count - 1))
|
203
|
-
await asyncio.sleep(delay)
|
204
|
-
else:
|
205
|
-
logger.error(f"❌ Max retry reached for {e.code()}",
|
206
|
-
extra={"log_type": "info", "data": {"error_code": str(e.code()), "max_retries_reached": True}})
|
207
|
-
raise
|
208
|
-
else:
|
209
|
-
logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True,
|
210
|
-
extra={"log_type": "info", "data": {"error_code": str(e.code()) if hasattr(e, 'code') else None, "retryable": False}})
|
211
|
-
raise
|
144
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
145
|
+
"""同步上下文管理器出口(不支持)"""
|
146
|
+
pass
|
212
147
|
|
213
|
-
async def
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
retry_count += 1
|
222
|
-
logger.info(f"❌ RPC cancelled, retrying {retry_count}/{self.max_retries}...",
|
223
|
-
extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": "CANCELLED"}})
|
224
|
-
if retry_count < self.max_retries:
|
225
|
-
delay = self.retry_delay * (2 ** (retry_count - 1))
|
226
|
-
await asyncio.sleep(delay)
|
227
|
-
else:
|
228
|
-
logger.error("❌ Max retry reached for CANCELLED",
|
229
|
-
extra={"log_type": "info", "data": {"error_code": "CANCELLED", "max_retries_reached": True}})
|
230
|
-
raise
|
231
|
-
# 针对其他 RPC 错误类型,如暂时的连接问题、服务器超时等
|
232
|
-
elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
|
233
|
-
grpc.StatusCode.DEADLINE_EXCEEDED}:
|
234
|
-
retry_count += 1
|
235
|
-
logger.info(f"❌ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...",
|
236
|
-
extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": str(e.code())}})
|
237
|
-
if retry_count < self.max_retries:
|
238
|
-
delay = self.retry_delay * (2 ** (retry_count - 1))
|
239
|
-
await asyncio.sleep(delay)
|
240
|
-
else:
|
241
|
-
logger.error(f"❌ Max retry reached for {e.code()}",
|
242
|
-
extra={"log_type": "info", "data": {"error_code": str(e.code()), "max_retries_reached": True}})
|
243
|
-
raise
|
244
|
-
else:
|
245
|
-
logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True,
|
246
|
-
extra={"log_type": "info", "data": {"error_code": str(e.code()) if hasattr(e, 'code') else None, "retryable": False}})
|
247
|
-
raise
|
248
|
-
|
249
|
-
def _build_auth_metadata(self, request_id: str) -> list:
|
250
|
-
# if not self.jwt_token and self.jwt_handler:
|
251
|
-
# 更改为每次请求都生成一次token
|
252
|
-
metadata = [("x-request-id", request_id)] # 将 request_id 添加到 headers
|
253
|
-
if self.jwt_handler:
|
254
|
-
self.jwt_token = self.jwt_handler.encode_token(self.default_payload, expires_in=self.token_expires_in)
|
255
|
-
metadata.append(("authorization", f"Bearer {self.jwt_token}"))
|
256
|
-
return metadata
|
148
|
+
async def connect(self):
|
149
|
+
"""
|
150
|
+
显式连接到服务器
|
151
|
+
|
152
|
+
建立与 gRPC 服务器的连接。通常不需要手动调用,
|
153
|
+
因为 invoke 方法会自动确保连接已建立。
|
154
|
+
"""
|
155
|
+
await self._ensure_initialized()
|
257
156
|
|
258
157
|
async def _ensure_initialized(self):
|
259
|
-
"""
|
158
|
+
"""
|
159
|
+
初始化gRPC通道
|
160
|
+
|
161
|
+
确保gRPC通道和存根已正确初始化。如果初始化失败,
|
162
|
+
会进行重试,支持TLS配置和完整的keepalive选项。
|
163
|
+
|
164
|
+
连接配置包括:
|
165
|
+
- 消息大小限制
|
166
|
+
- Keepalive设置(30秒ping间隔,10秒超时)
|
167
|
+
- 连接生命周期管理(1小时最大连接时间)
|
168
|
+
- 性能优化选项(带宽探测、内置重试)
|
169
|
+
|
170
|
+
Raises:
|
171
|
+
ConnectionError: 当达到最大重试次数仍无法连接时
|
172
|
+
"""
|
260
173
|
if self.channel and self.stub:
|
261
174
|
return
|
262
175
|
|
263
176
|
retry_count = 0
|
264
|
-
options =
|
265
|
-
('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
|
266
|
-
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
|
267
|
-
('grpc.keepalive_permit_without_calls', True) # 即使没有活跃请求也保持连接
|
268
|
-
]
|
269
|
-
if self.default_authority:
|
270
|
-
options.append(("grpc.default_authority", self.default_authority))
|
177
|
+
options = self.build_channel_options()
|
271
178
|
|
272
179
|
while retry_count <= self.max_retries:
|
273
180
|
try:
|
@@ -279,60 +186,152 @@ class AsyncTamarModelClient:
|
|
279
186
|
options=options
|
280
187
|
)
|
281
188
|
logger.info("🔐 Using secure gRPC channel (TLS enabled)",
|
282
|
-
|
189
|
+
extra={"log_type": "info",
|
190
|
+
"data": {"tls_enabled": True, "server_address": self.server_address}})
|
283
191
|
else:
|
284
192
|
self.channel = grpc.aio.insecure_channel(
|
285
193
|
self.server_address,
|
286
194
|
options=options
|
287
195
|
)
|
288
196
|
logger.info("🔓 Using insecure gRPC channel (TLS disabled)",
|
289
|
-
|
197
|
+
extra={"log_type": "info",
|
198
|
+
"data": {"tls_enabled": False, "server_address": self.server_address}})
|
199
|
+
|
290
200
|
await self.channel.channel_ready()
|
291
201
|
self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
|
292
202
|
logger.info(f"✅ gRPC channel initialized to {self.server_address}",
|
293
|
-
|
203
|
+
extra={"log_type": "info",
|
204
|
+
"data": {"status": "success", "server_address": self.server_address}})
|
294
205
|
return
|
206
|
+
|
295
207
|
except grpc.FutureTimeoutError as e:
|
296
208
|
logger.error(f"❌ gRPC channel initialization timed out: {str(e)}", exc_info=True,
|
297
|
-
|
209
|
+
extra={"log_type": "info",
|
210
|
+
"data": {"error_type": "timeout", "server_address": self.server_address}})
|
298
211
|
except grpc.RpcError as e:
|
299
212
|
logger.error(f"❌ gRPC channel initialization failed: {str(e)}", exc_info=True,
|
300
|
-
|
213
|
+
extra={"log_type": "info",
|
214
|
+
"data": {"error_type": "grpc_error", "server_address": self.server_address}})
|
301
215
|
except Exception as e:
|
302
|
-
logger.error(f"❌ Unexpected error during channel initialization: {str(e)}", exc_info=True,
|
303
|
-
|
304
|
-
|
216
|
+
logger.error(f"❌ Unexpected error during gRPC channel initialization: {str(e)}", exc_info=True,
|
217
|
+
extra={"log_type": "info",
|
218
|
+
"data": {"error_type": "unknown", "server_address": self.server_address}})
|
219
|
+
|
305
220
|
retry_count += 1
|
306
|
-
if retry_count
|
307
|
-
|
308
|
-
extra={"log_type": "info", "data": {"max_retries_reached": True, "server_address": self.server_address}})
|
309
|
-
raise ConnectionError(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.")
|
221
|
+
if retry_count <= self.max_retries:
|
222
|
+
await asyncio.sleep(self.retry_delay * retry_count)
|
310
223
|
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
224
|
+
raise ConnectionError(f"Failed to connect to {self.server_address} after {self.max_retries} retries")
|
225
|
+
|
226
|
+
async def _retry_request(self, func, *args, **kwargs):
|
227
|
+
"""
|
228
|
+
使用增强的重试处理器执行请求
|
229
|
+
|
230
|
+
Args:
|
231
|
+
func: 要执行的异步函数
|
232
|
+
*args: 函数参数
|
233
|
+
**kwargs: 函数关键字参数
|
234
|
+
|
235
|
+
Returns:
|
236
|
+
函数执行结果
|
237
|
+
|
238
|
+
Raises:
|
239
|
+
TamarModelException: 当所有重试都失败时
|
240
|
+
"""
|
241
|
+
return await self.retry_handler.execute_with_retry(func, *args, **kwargs)
|
242
|
+
|
243
|
+
async def _retry_request_stream(self, func, *args, **kwargs):
|
244
|
+
"""
|
245
|
+
流式请求的重试逻辑
|
246
|
+
|
247
|
+
对于流式响应,需要特殊的重试处理,因为流不能简单地重新执行。
|
248
|
+
|
249
|
+
Args:
|
250
|
+
func: 生成流的异步函数
|
251
|
+
*args: 函数参数
|
252
|
+
**kwargs: 函数关键字参数
|
253
|
+
|
254
|
+
Returns:
|
255
|
+
AsyncIterator: 流式响应迭代器
|
256
|
+
"""
|
257
|
+
last_exception = None
|
258
|
+
|
259
|
+
for attempt in range(self.max_retries + 1):
|
260
|
+
try:
|
261
|
+
# 尝试创建流
|
262
|
+
async for item in func(*args, **kwargs):
|
263
|
+
yield item
|
264
|
+
return
|
265
|
+
|
266
|
+
except RpcError as e:
|
267
|
+
last_exception = e
|
268
|
+
if attempt < self.max_retries:
|
269
|
+
logger.warning(
|
270
|
+
f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
|
271
|
+
extra={"retry_count": attempt, "error_code": str(e.code())}
|
272
|
+
)
|
273
|
+
await asyncio.sleep(self.retry_delay * (attempt + 1))
|
274
|
+
else:
|
275
|
+
break
|
276
|
+
except Exception as e:
|
277
|
+
raise TamarModelException(str(e)) from e
|
278
|
+
|
279
|
+
if last_exception:
|
280
|
+
raise self.error_handler.handle_error(last_exception, {"retry_count": self.max_retries})
|
281
|
+
else:
|
282
|
+
raise TamarModelException("Unknown streaming error occurred")
|
316
283
|
|
317
284
|
async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
285
|
+
"""
|
286
|
+
处理流式响应
|
287
|
+
|
288
|
+
包含块级超时保护,防止流式响应挂起。
|
289
|
+
|
290
|
+
Args:
|
291
|
+
request: gRPC 请求对象
|
292
|
+
metadata: 请求元数据
|
293
|
+
invoke_timeout: 总体超时时间
|
294
|
+
|
295
|
+
Yields:
|
296
|
+
ModelResponse: 流式响应的每个数据块
|
297
|
+
|
298
|
+
Raises:
|
299
|
+
TimeoutError: 当等待下一个数据块超时时
|
300
|
+
"""
|
301
|
+
stream_iter = self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout).__aiter__()
|
302
|
+
chunk_timeout = 30.0 # 单个数据块的超时时间
|
303
|
+
|
304
|
+
try:
|
305
|
+
while True:
|
306
|
+
try:
|
307
|
+
# 对每个数据块的获取进行超时保护
|
308
|
+
response = await asyncio.wait_for(
|
309
|
+
stream_iter.__anext__(),
|
310
|
+
timeout=chunk_timeout
|
311
|
+
)
|
312
|
+
yield ResponseHandler.build_model_response(response)
|
313
|
+
|
314
|
+
except asyncio.TimeoutError:
|
315
|
+
raise TimeoutError(f"流式响应在等待下一个数据块时超时 ({chunk_timeout}s)")
|
316
|
+
|
317
|
+
except StopAsyncIteration:
|
318
|
+
break # 正常结束
|
319
|
+
except Exception as e:
|
320
|
+
raise
|
321
|
+
|
322
|
+
async def _stream_with_logging(self, request, metadata, invoke_timeout, start_time, model_request) -> AsyncIterator[
|
323
|
+
ModelResponse]:
|
324
|
+
"""流式响应的包装器,用于记录完整的响应日志并处理重试"""
|
329
325
|
total_content = ""
|
330
326
|
final_usage = None
|
331
327
|
error_occurred = None
|
332
328
|
chunk_count = 0
|
333
|
-
|
329
|
+
|
330
|
+
# 使用重试逻辑获取流生成器
|
331
|
+
stream_generator = self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
|
332
|
+
|
334
333
|
try:
|
335
|
-
async for response in
|
334
|
+
async for response in stream_generator:
|
336
335
|
chunk_count += 1
|
337
336
|
if response.content:
|
338
337
|
total_content += response.content
|
@@ -341,26 +340,46 @@ class AsyncTamarModelClient:
|
|
341
340
|
if response.error:
|
342
341
|
error_occurred = response.error
|
343
342
|
yield response
|
344
|
-
|
345
|
-
#
|
343
|
+
|
344
|
+
# 流式响应完成,记录日志
|
346
345
|
duration = time.time() - start_time
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
"
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
"
|
355
|
-
"
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
346
|
+
if error_occurred:
|
347
|
+
# 流式响应中包含错误
|
348
|
+
logger.warning(
|
349
|
+
f"⚠️ Stream completed with errors | chunks: {chunk_count}",
|
350
|
+
extra={
|
351
|
+
"log_type": "response",
|
352
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
353
|
+
"duration": duration,
|
354
|
+
"data": ResponseHandler.build_log_data(
|
355
|
+
model_request,
|
356
|
+
stream_stats={
|
357
|
+
"chunks_count": chunk_count,
|
358
|
+
"total_length": len(total_content),
|
359
|
+
"usage": final_usage,
|
360
|
+
"error": error_occurred
|
361
|
+
}
|
362
|
+
)
|
361
363
|
}
|
362
|
-
|
363
|
-
|
364
|
+
)
|
365
|
+
else:
|
366
|
+
# 流式响应成功完成
|
367
|
+
logger.info(
|
368
|
+
f"✅ Stream completed successfully | chunks: {chunk_count}",
|
369
|
+
extra={
|
370
|
+
"log_type": "response",
|
371
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
372
|
+
"duration": duration,
|
373
|
+
"data": ResponseHandler.build_log_data(
|
374
|
+
model_request,
|
375
|
+
stream_stats={
|
376
|
+
"chunks_count": chunk_count,
|
377
|
+
"total_length": len(total_content),
|
378
|
+
"usage": final_usage
|
379
|
+
}
|
380
|
+
)
|
381
|
+
}
|
382
|
+
)
|
364
383
|
except Exception as e:
|
365
384
|
# 流式响应出错,记录错误日志
|
366
385
|
duration = time.time() - start_time
|
@@ -371,28 +390,22 @@ class AsyncTamarModelClient:
|
|
371
390
|
"log_type": "response",
|
372
391
|
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
373
392
|
"duration": duration,
|
374
|
-
"data":
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
}
|
393
|
+
"data": ResponseHandler.build_log_data(
|
394
|
+
model_request,
|
395
|
+
error=e,
|
396
|
+
stream_stats={
|
397
|
+
"chunks_count": chunk_count,
|
398
|
+
"partial_content_length": len(total_content)
|
399
|
+
}
|
400
|
+
)
|
383
401
|
}
|
384
402
|
)
|
385
403
|
raise
|
386
404
|
|
387
405
|
async def _invoke_request(self, request, metadata, invoke_timeout):
|
406
|
+
"""执行单个非流式请求"""
|
388
407
|
async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
|
389
|
-
return
|
390
|
-
content=response.content,
|
391
|
-
usage=json.loads(response.usage) if response.usage else None,
|
392
|
-
error=response.error or None,
|
393
|
-
raw_response=json.loads(response.raw_response) if response.raw_response else None,
|
394
|
-
request_id=response.request_id if response.request_id else None,
|
395
|
-
)
|
408
|
+
return ResponseHandler.build_model_response(response)
|
396
409
|
|
397
410
|
async def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None,
|
398
411
|
request_id: Optional[str] = None) -> Union[
|
@@ -420,9 +433,9 @@ class AsyncTamarModelClient:
|
|
420
433
|
}
|
421
434
|
|
422
435
|
if not request_id:
|
423
|
-
request_id = generate_request_id()
|
424
|
-
set_request_id(request_id)
|
425
|
-
metadata = self._build_auth_metadata(request_id)
|
436
|
+
request_id = generate_request_id()
|
437
|
+
set_request_id(request_id)
|
438
|
+
metadata = self._build_auth_metadata(request_id)
|
426
439
|
|
427
440
|
# 记录开始日志
|
428
441
|
start_time = time.time()
|
@@ -431,135 +444,86 @@ class AsyncTamarModelClient:
|
|
431
444
|
extra={
|
432
445
|
"log_type": "request",
|
433
446
|
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
434
|
-
"data":
|
435
|
-
"provider": model_request.provider.value,
|
436
|
-
"invoke_type": model_request.invoke_type.value,
|
437
|
-
"model": model_request.model,
|
438
|
-
"stream": model_request.stream,
|
439
|
-
"org_id": model_request.user_context.org_id,
|
440
|
-
"user_id": model_request.user_context.user_id,
|
441
|
-
"client_type": model_request.user_context.client_type
|
442
|
-
}
|
447
|
+
"data": ResponseHandler.build_log_data(model_request)
|
443
448
|
})
|
444
449
|
|
445
|
-
# 动态根据 provider/invoke_type 决定使用哪个 input 字段
|
446
450
|
try:
|
447
|
-
#
|
448
|
-
|
449
|
-
|
450
|
-
case (ProviderType.GOOGLE, InvokeType.GENERATION):
|
451
|
-
allowed_fields = GoogleGenAiInput.model_fields.keys()
|
452
|
-
case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
|
453
|
-
allowed_fields = GoogleVertexAIImagesInput.model_fields.keys()
|
454
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
|
455
|
-
allowed_fields = OpenAIResponsesInput.model_fields.keys()
|
456
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
|
457
|
-
allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
|
458
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
|
459
|
-
allowed_fields = OpenAIImagesInput.model_fields.keys()
|
460
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
|
461
|
-
allowed_fields = OpenAIImagesEditInput.model_fields.keys()
|
462
|
-
case _:
|
463
|
-
raise ValueError(
|
464
|
-
f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")
|
465
|
-
|
466
|
-
# 将 ModelRequest 转 dict,过滤只保留 base + allowed 的字段
|
467
|
-
model_request_dict = model_request.model_dump(exclude_unset=True)
|
468
|
-
|
469
|
-
grpc_request_kwargs = {}
|
470
|
-
for field in allowed_fields:
|
471
|
-
if field in model_request_dict:
|
472
|
-
value = model_request_dict[field]
|
473
|
-
|
474
|
-
# 跳过无效的值
|
475
|
-
if not is_effective_value(value):
|
476
|
-
continue
|
477
|
-
|
478
|
-
# 序列化grpc不支持的类型
|
479
|
-
grpc_request_kwargs[field] = serialize_value(value)
|
480
|
-
|
481
|
-
# 清理 serialize后的 grpc_request_kwargs
|
482
|
-
grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
|
483
|
-
|
484
|
-
request = model_service_pb2.ModelRequestItem(
|
485
|
-
provider=model_request.provider.value,
|
486
|
-
channel=model_request.channel.value,
|
487
|
-
invoke_type=model_request.invoke_type.value,
|
488
|
-
stream=model_request.stream or False,
|
489
|
-
org_id=model_request.user_context.org_id or "",
|
490
|
-
user_id=model_request.user_context.user_id or "",
|
491
|
-
client_type=model_request.user_context.client_type or "",
|
492
|
-
extra=grpc_request_kwargs
|
493
|
-
)
|
494
|
-
|
451
|
+
# 构建 gRPC 请求
|
452
|
+
request = RequestBuilder.build_single_request(model_request)
|
453
|
+
|
495
454
|
except Exception as e:
|
455
|
+
duration = time.time() - start_time
|
456
|
+
logger.error(
|
457
|
+
f"❌ Request build failed: {str(e)}",
|
458
|
+
exc_info=True,
|
459
|
+
extra={
|
460
|
+
"log_type": "response",
|
461
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
462
|
+
"duration": duration,
|
463
|
+
"data": {
|
464
|
+
"provider": model_request.provider.value,
|
465
|
+
"invoke_type": model_request.invoke_type.value,
|
466
|
+
"model": getattr(model_request, 'model', None),
|
467
|
+
"error_type": "build_error",
|
468
|
+
"error_message": str(e)
|
469
|
+
}
|
470
|
+
}
|
471
|
+
)
|
496
472
|
raise ValueError(f"构建请求失败: {str(e)}") from e
|
497
473
|
|
498
474
|
try:
|
499
475
|
invoke_timeout = timeout or self.default_invoke_timeout
|
500
476
|
if model_request.stream:
|
501
|
-
#
|
502
|
-
stream_generator = await self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
|
477
|
+
# 对于流式响应,直接返回带日志记录的包装器
|
503
478
|
return self._stream_with_logging(request, metadata, invoke_timeout, start_time, model_request)
|
504
479
|
else:
|
505
480
|
result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
|
506
|
-
|
481
|
+
|
507
482
|
# 记录非流式响应的成功日志
|
508
483
|
duration = time.time() - start_time
|
484
|
+
content_length = len(result.content) if result.content else 0
|
509
485
|
logger.info(
|
510
|
-
f"✅ Request completed
|
486
|
+
f"✅ Request completed | content_length: {content_length}",
|
511
487
|
extra={
|
512
488
|
"log_type": "response",
|
513
489
|
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
514
490
|
"duration": duration,
|
515
|
-
"data":
|
516
|
-
"provider": model_request.provider.value,
|
517
|
-
"invoke_type": model_request.invoke_type.value,
|
518
|
-
"model": model_request.model,
|
519
|
-
"stream": False,
|
520
|
-
"content_length": len(result.content) if result.content else 0,
|
521
|
-
"usage": result.usage
|
522
|
-
}
|
491
|
+
"data": ResponseHandler.build_log_data(model_request, result)
|
523
492
|
}
|
524
493
|
)
|
525
494
|
return result
|
495
|
+
|
526
496
|
except grpc.RpcError as e:
|
527
497
|
duration = time.time() - start_time
|
528
498
|
error_message = f"❌ Invoke gRPC failed: {str(e)}"
|
529
499
|
logger.error(error_message, exc_info=True,
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
"model": model_request.model
|
540
|
-
}
|
541
|
-
})
|
500
|
+
extra={
|
501
|
+
"log_type": "response",
|
502
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
503
|
+
"duration": duration,
|
504
|
+
"data": ResponseHandler.build_log_data(
|
505
|
+
model_request,
|
506
|
+
error=e
|
507
|
+
)
|
508
|
+
})
|
542
509
|
raise e
|
543
510
|
except Exception as e:
|
544
511
|
duration = time.time() - start_time
|
545
512
|
error_message = f"❌ Invoke other error: {str(e)}"
|
546
513
|
logger.error(error_message, exc_info=True,
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
}
|
557
|
-
})
|
514
|
+
extra={
|
515
|
+
"log_type": "response",
|
516
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
517
|
+
"duration": duration,
|
518
|
+
"data": ResponseHandler.build_log_data(
|
519
|
+
model_request,
|
520
|
+
error=e
|
521
|
+
)
|
522
|
+
})
|
558
523
|
raise e
|
559
524
|
|
560
525
|
async def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None,
|
561
|
-
request_id: Optional[str] = None) ->
|
562
|
-
BatchModelResponse:
|
526
|
+
request_id: Optional[str] = None) -> BatchModelResponse:
|
563
527
|
"""
|
564
528
|
批量模型调用接口
|
565
529
|
|
@@ -570,7 +534,6 @@ class AsyncTamarModelClient:
|
|
570
534
|
Returns:
|
571
535
|
BatchModelResponse: 批量请求的结果
|
572
536
|
"""
|
573
|
-
|
574
537
|
await self._ensure_initialized()
|
575
538
|
|
576
539
|
if not self.default_payload:
|
@@ -580,9 +543,9 @@ class AsyncTamarModelClient:
|
|
580
543
|
}
|
581
544
|
|
582
545
|
if not request_id:
|
583
|
-
request_id = generate_request_id()
|
584
|
-
set_request_id(request_id)
|
585
|
-
metadata = self._build_auth_metadata(request_id)
|
546
|
+
request_id = generate_request_id()
|
547
|
+
set_request_id(request_id)
|
548
|
+
metadata = self._build_auth_metadata(request_id)
|
586
549
|
|
587
550
|
# 记录开始日志
|
588
551
|
start_time = time.time()
|
@@ -599,155 +562,83 @@ class AsyncTamarModelClient:
|
|
599
562
|
}
|
600
563
|
})
|
601
564
|
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
# 将 ModelRequest 转 dict,过滤只保留 base + allowed 的字段
|
625
|
-
model_request_dict = model_request_item.model_dump(exclude_unset=True)
|
626
|
-
|
627
|
-
grpc_request_kwargs = {}
|
628
|
-
for field in allowed_fields:
|
629
|
-
if field in model_request_dict:
|
630
|
-
value = model_request_dict[field]
|
631
|
-
|
632
|
-
# 跳过无效的值
|
633
|
-
if not is_effective_value(value):
|
634
|
-
continue
|
635
|
-
|
636
|
-
# 序列化grpc不支持的类型
|
637
|
-
grpc_request_kwargs[field] = serialize_value(value)
|
638
|
-
|
639
|
-
# 清理 serialize后的 grpc_request_kwargs
|
640
|
-
grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
|
641
|
-
|
642
|
-
items.append(model_service_pb2.ModelRequestItem(
|
643
|
-
provider=model_request_item.provider.value,
|
644
|
-
channel=model_request_item.channel.value,
|
645
|
-
invoke_type=model_request_item.invoke_type.value,
|
646
|
-
stream=model_request_item.stream or False,
|
647
|
-
custom_id=model_request_item.custom_id or "",
|
648
|
-
priority=model_request_item.priority or 1,
|
649
|
-
org_id=batch_request_model.user_context.org_id or "",
|
650
|
-
user_id=batch_request_model.user_context.user_id or "",
|
651
|
-
client_type=batch_request_model.user_context.client_type or "",
|
652
|
-
extra=grpc_request_kwargs,
|
653
|
-
))
|
654
|
-
|
655
|
-
except Exception as e:
|
656
|
-
raise ValueError(f"构建请求失败: {str(e)},item={model_request_item.custom_id}") from e
|
565
|
+
try:
|
566
|
+
# 构建批量请求
|
567
|
+
batch_request = RequestBuilder.build_batch_request(batch_request_model)
|
568
|
+
|
569
|
+
except Exception as e:
|
570
|
+
duration = time.time() - start_time
|
571
|
+
logger.error(
|
572
|
+
f"❌ Batch request build failed: {str(e)}",
|
573
|
+
exc_info=True,
|
574
|
+
extra={
|
575
|
+
"log_type": "response",
|
576
|
+
"uri": "/batch_invoke",
|
577
|
+
"duration": duration,
|
578
|
+
"data": {
|
579
|
+
"batch_size": len(batch_request_model.items),
|
580
|
+
"error_type": "build_error",
|
581
|
+
"error_message": str(e)
|
582
|
+
}
|
583
|
+
}
|
584
|
+
)
|
585
|
+
raise ValueError(f"构建批量请求失败: {str(e)}") from e
|
657
586
|
|
658
587
|
try:
|
659
|
-
# 超时处理逻辑
|
660
588
|
invoke_timeout = timeout or self.default_invoke_timeout
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
result = []
|
667
|
-
for res_item in response.items:
|
668
|
-
result.append(ModelResponse(
|
669
|
-
content=res_item.content,
|
670
|
-
usage=json.loads(res_item.usage) if res_item.usage else None,
|
671
|
-
raw_response=json.loads(res_item.raw_response) if res_item.raw_response else None,
|
672
|
-
error=res_item.error or None,
|
673
|
-
custom_id=res_item.custom_id if res_item.custom_id else None
|
674
|
-
))
|
675
|
-
batch_response = BatchModelResponse(
|
676
|
-
request_id=response.request_id if response.request_id else None,
|
677
|
-
responses=result
|
589
|
+
batch_response = await self._retry_request(
|
590
|
+
self.stub.BatchInvoke,
|
591
|
+
batch_request,
|
592
|
+
metadata=metadata,
|
593
|
+
timeout=invoke_timeout
|
678
594
|
)
|
679
|
-
|
595
|
+
|
596
|
+
# 构建响应对象
|
597
|
+
result = ResponseHandler.build_batch_response(batch_response)
|
598
|
+
|
680
599
|
# 记录成功日志
|
681
600
|
duration = time.time() - start_time
|
682
601
|
logger.info(
|
683
|
-
f"✅ Batch
|
602
|
+
f"✅ Batch Request completed | batch_size: {len(result.responses)}",
|
684
603
|
extra={
|
685
604
|
"log_type": "response",
|
686
605
|
"uri": "/batch_invoke",
|
687
606
|
"duration": duration,
|
688
607
|
"data": {
|
689
|
-
"batch_size": len(
|
690
|
-
"
|
608
|
+
"batch_size": len(result.responses),
|
609
|
+
"success_count": sum(1 for item in result.responses if not item.error),
|
610
|
+
"error_count": sum(1 for item in result.responses if item.error)
|
691
611
|
}
|
692
|
-
}
|
693
|
-
|
694
|
-
return
|
612
|
+
})
|
613
|
+
|
614
|
+
return result
|
615
|
+
|
695
616
|
except grpc.RpcError as e:
|
696
617
|
duration = time.time() - start_time
|
697
|
-
error_message = f"❌
|
618
|
+
error_message = f"❌ Batch invoke gRPC failed: {str(e)}"
|
698
619
|
logger.error(error_message, exc_info=True,
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
620
|
+
extra={
|
621
|
+
"log_type": "response",
|
622
|
+
"uri": "/batch_invoke",
|
623
|
+
"duration": duration,
|
624
|
+
"data": {
|
625
|
+
"error_type": "grpc_error",
|
626
|
+
"error_code": str(e.code()) if hasattr(e, 'code') else None,
|
627
|
+
"batch_size": len(batch_request_model.items)
|
628
|
+
}
|
629
|
+
})
|
709
630
|
raise e
|
710
631
|
except Exception as e:
|
711
632
|
duration = time.time() - start_time
|
712
|
-
error_message = f"❌
|
633
|
+
error_message = f"❌ Batch invoke other error: {str(e)}"
|
713
634
|
logger.error(error_message, exc_info=True,
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
raise e
|
724
|
-
|
725
|
-
async def close(self):
|
726
|
-
"""关闭 gRPC 通道"""
|
727
|
-
if self.channel and not self._closed:
|
728
|
-
await self.channel.close()
|
729
|
-
self._closed = True
|
730
|
-
logger.info("✅ gRPC channel closed",
|
731
|
-
extra={"log_type": "info", "data": {"status": "success"}})
|
732
|
-
|
733
|
-
def _safe_sync_close(self):
|
734
|
-
"""进程退出时自动关闭 channel(事件循环处理兼容)"""
|
735
|
-
if self.channel and not self._closed:
|
736
|
-
try:
|
737
|
-
loop = asyncio.get_event_loop()
|
738
|
-
if loop.is_running():
|
739
|
-
loop.create_task(self.close())
|
740
|
-
else:
|
741
|
-
loop.run_until_complete(self.close())
|
742
|
-
except Exception as e:
|
743
|
-
logger.info(f"❌ gRPC channel close failed at exit: {e}",
|
744
|
-
extra={"log_type": "info", "data": {"status": "failed", "error": str(e)}})
|
745
|
-
|
746
|
-
async def __aenter__(self):
|
747
|
-
"""支持 async with 自动初始化连接"""
|
748
|
-
await self._ensure_initialized()
|
749
|
-
return self
|
750
|
-
|
751
|
-
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
752
|
-
"""支持 async with 自动关闭连接"""
|
753
|
-
await self.close()
|
635
|
+
extra={
|
636
|
+
"log_type": "response",
|
637
|
+
"uri": "/batch_invoke",
|
638
|
+
"duration": duration,
|
639
|
+
"data": {
|
640
|
+
"error_type": "other_error",
|
641
|
+
"batch_size": len(batch_request_model.items)
|
642
|
+
}
|
643
|
+
})
|
644
|
+
raise e
|