tamar-model-client 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
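For orientation, the headline change in 0.1.21 is that AsyncTamarModelClient is rebuilt on a shared core (BaseClient, RequestBuilder, ResponseHandler), gains an EnhancedRetryHandler, and adds an HTTP fallback path behind a circuit breaker. The sketch below is a minimal illustration of the client surface as documented in the new module docstring; it is not part of the package, the import paths are assumed from the distribution name, and the ModelRequest fields are elided exactly as the docstring elides them:

    import asyncio

    from tamar_model_client import AsyncTamarModelClient  # import path assumed
    from tamar_model_client.schemas import ModelRequest

    async def main():
        # Recommended: the async context manager connects on entry and closes on exit.
        async with AsyncTamarModelClient() as client:
            request = ModelRequest(...)  # provider, invoke_type, model, etc.; see the schemas module
            response = await client.invoke(request)
            print(response.content)

    asyncio.run(main())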
@@ -1,273 +1,186 @@
+ """
+ Tamar Model Client async client implementation
+
+ This module implements the asynchronous gRPC client used to communicate with
+ the Model Manager Server. It supports single requests, batch requests, and
+ streaming responses, and provides full error handling and retry mechanisms.
+
+ Main features:
+ - Asynchronous gRPC communication
+ - JWT authentication
+ - Automatic retry and error handling
+ - Streaming response support
+ - Connection pool management
+ - Detailed logging
+
+ Usage example:
+     async with AsyncTamarModelClient() as client:
+         request = ModelRequest(...)
+         response = await client.invoke(request)
+ """
+
  import asyncio
  import atexit
- import base64
  import json
  import logging
- import os
+ import random
  import time
- import uuid
- from contextvars import ContextVar
+ from typing import Optional, AsyncIterator, Union

  import grpc
- from typing import Optional, AsyncIterator, Union, Iterable
-
- from openai import NOT_GIVEN
- from pydantic import BaseModel
-
- from .auth import JWTAuthHandler
+ from grpc import RpcError
+
+ from .core import (
+     generate_request_id,
+     set_request_id,
+     get_protected_logger,
+     MAX_MESSAGE_LENGTH
+ )
+ from .core.base_client import BaseClient
+ from .core.request_builder import RequestBuilder
+ from .core.response_handler import ResponseHandler
  from .enums import ProviderType, InvokeType
- from .exceptions import ConnectionError
+ from .exceptions import ConnectionError, TamarModelException
+ from .error_handler import EnhancedRetryHandler
  from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
  from .generated import model_service_pb2, model_service_pb2_grpc
- from .schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput, \
-     GoogleVertexAIImagesInput, OpenAIImagesInput, OpenAIImagesEditInput
- from .json_formatter import JSONFormatter
-
- logger = logging.getLogger(__name__)
-
- # Use a ContextVar to track the per-request ID
- _request_id: ContextVar[str] = ContextVar('request_id', default='-')
-
-
- class RequestIdFilter(logging.Filter):
-     """Custom logging filter that adds the request_id to log records"""
-
-     def filter(self, record):
-         # Read the current request_id from the ContextVar
-         record.request_id = _request_id.get()
-         return True
-
+ from .core.http_fallback import AsyncHttpFallbackMixin

- if not logger.hasHandlers():
-     # Create a handler that logs to the console
-     console_handler = logging.StreamHandler()
+ # Configure the logger (using the protected logger)
+ logger = get_protected_logger(__name__)

-     # Use the JSON formatter
-     formatter = JSONFormatter()
-     console_handler.setFormatter(formatter)

-     # Attach the handler to this logger
-     logger.addHandler(console_handler)
-
-     # Set the log level
-     logger.setLevel(logging.INFO)
-
-     # Add the custom RequestIdFilter to the logger
-     logger.addFilter(RequestIdFilter())
-
- MAX_MESSAGE_LENGTH = 2 ** 31 - 1  # for 32-bit systems
-
-
- def is_effective_value(value) -> bool:
+ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
      """
-     Recursively determine whether a value is a meaningful, effective value
+     Tamar Model Client async client
+
+     Provides asynchronous communication with the Model Manager Server, supporting:
+     - Single and batch model calls
+     - Streaming and non-streaming responses
+     - Automatic retry and error recovery
+     - JWT authentication
+     - Connection pool management
+
+     Usage example:
+         # Basic usage
+         client = AsyncTamarModelClient()
+         await client.connect()
+
+         request = ModelRequest(...)
+         response = await client.invoke(request)
+
+         # Context manager usage (recommended)
+         async with AsyncTamarModelClient() as client:
+             response = await client.invoke(request)
+
+     Environment variables:
+         MODEL_MANAGER_SERVER_ADDRESS: gRPC server address
+         MODEL_MANAGER_SERVER_JWT_SECRET_KEY: JWT secret key
+         MODEL_MANAGER_SERVER_GRPC_USE_TLS: whether to use TLS
+         MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES: maximum number of retries
+         MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY: retry delay
      """
-     if value is None or value is NOT_GIVEN:
-         return False
-
-     if isinstance(value, str):
-         return value.strip() != ""
-
-     if isinstance(value, bytes):
-         return len(value) > 0
-
-     if isinstance(value, dict):
-         for v in value.values():
-             if is_effective_value(v):
-                 return True
-         return False
-
-     if isinstance(value, list):
-         for item in value:
-             if is_effective_value(item):
-                 return True
-         return False
-
-     return True  # other types (int/float/bool) count as effective as long as they are not None

+     def __init__(self, **kwargs):
+         """
+         Initialize the async client
+
+         Parameters are inherited from BaseClient and include:
+         - server_address: gRPC server address
+         - jwt_secret_key: JWT signing key
+         - jwt_token: pre-generated JWT token
+         - default_payload: default payload for the JWT token
+         - token_expires_in: JWT token expiry time
+         - max_retries: maximum number of retries
+         - retry_delay: initial retry delay
+         """
+         super().__init__(logger_name=__name__, **kwargs)
+
+         # === gRPC channel and connection management ===
+         self.channel: Optional[grpc.aio.Channel] = None
+         self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
+
+         # === Enhanced retry handler ===
+         self.retry_handler = EnhancedRetryHandler(
+             max_retries=self.max_retries,
+             base_delay=self.retry_delay
+         )
+
+         # Register the cleanup function to run at process exit
+         atexit.register(self._cleanup_atexit)

- def serialize_value(value):
-     """Recursively process a single value, handling BaseModel, dict, list, bytes"""
-     if not is_effective_value(value):
-         return None
-     if isinstance(value, BaseModel):
-         return serialize_value(value.model_dump())
-     if hasattr(value, "dict") and callable(value.dict):
-         return serialize_value(value.dict())
-     if isinstance(value, dict):
-         return {k: serialize_value(v) for k, v in value.items()}
-     if isinstance(value, list) or (isinstance(value, Iterable) and not isinstance(value, (str, bytes))):
-         return [serialize_value(v) for v in value]
-     if isinstance(value, bytes):
-         return f"bytes:{base64.b64encode(value).decode('utf-8')}"
-     return value
+     def _cleanup_atexit(self):
+         """Cleanup function run when the process exits"""
+         if self.channel and not self._closed:
+             try:
+                 asyncio.create_task(self.close())
+             except RuntimeError:
+                 # Ignore the error if the event loop is already closed
+                 pass

+     async def close(self):
+         """
+         Close the client connection
+
+         Gracefully closes the gRPC channel and cleans up resources.
+         Call this before the program exits, or use the context manager
+         to handle it automatically.
+         """
+         if self.channel and not self._closed:
+             await self.channel.close()
+             self._closed = True
+             logger.info("🔒 gRPC channel closed",
+                         extra={"log_type": "info", "data": {"status": "closed"}})
+
+         # Clean up the HTTP session (if any)
+         if self.resilient_enabled:
+             await self._cleanup_http_session()

- from typing import Any
+     async def __aenter__(self):
+         """Async context manager entry"""
+         await self.connect()
+         return self

+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         """Async context manager exit"""
+         await self.close()

- def remove_none_from_dict(data: Any) -> Any:
-     """
-     Walk a dict/list and recursively drop fields whose value is None
-     """
-     if isinstance(data, dict):
-         new_dict = {}
-         for key, value in data.items():
-             if value is None:
-                 continue
-             cleaned_value = remove_none_from_dict(value)
-             new_dict[key] = cleaned_value
-         return new_dict
-     elif isinstance(data, list):
-         return [remove_none_from_dict(item) for item in data]
-     else:
-         return data
-
-
- def generate_request_id():
-     """Generate a unique request_id"""
-     return str(uuid.uuid4())
-
-
- def set_request_id(request_id: str):
-     """Set the request_id for the current request"""
-     _request_id.set(request_id)
-
-
- class AsyncTamarModelClient:
-     def __init__(
-             self,
-             server_address: Optional[str] = None,
-             jwt_secret_key: Optional[str] = None,
-             jwt_token: Optional[str] = None,
-             default_payload: Optional[dict] = None,
-             token_expires_in: int = 3600,
-             max_retries: Optional[int] = None,  # maximum number of retries
-             retry_delay: Optional[float] = None,  # initial retry delay (seconds)
-     ):
-         # Server address
-         self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
-         if not self.server_address:
-             raise ValueError("Server address must be provided via argument or environment variable.")
-         self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))
-
-         # JWT configuration
-         self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
-         self.jwt_handler = JWTAuthHandler(self.jwt_secret_key)
-         self.jwt_token = jwt_token  # token supplied by the caller (optional)
-         self.default_payload = default_payload
-         self.token_expires_in = token_expires_in
-
-         # === TLS/authority configuration ===
-         self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
-         self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")
-
-         # === Retry configuration ===
-         self.max_retries = max_retries if max_retries is not None else int(
-             os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
-         self.retry_delay = retry_delay if retry_delay is not None else float(
-             os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))
-
-         # === gRPC channel state ===
-         self.channel: Optional[grpc.aio.Channel] = None
-         self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
-         self._closed = False
-         atexit.register(self._safe_sync_close)  # close automatically on process exit
+     def __enter__(self):
+         """Sync context manager entry (not supported)"""
+         raise TypeError("Use 'async with' for AsyncTamarModelClient")

-     async def _retry_request(self, func, *args, **kwargs):
-         retry_count = 0
-         while retry_count < self.max_retries:
-             try:
-                 return await func(*args, **kwargs)
-             except (grpc.aio.AioRpcError, grpc.RpcError) as e:
-                 # Retry with exponential backoff when the RPC was cancelled
-                 if isinstance(e, grpc.aio.AioRpcError) and e.code() == grpc.StatusCode.CANCELLED:
-                     retry_count += 1
-                     logger.warning(f"⚠️ RPC cancelled, retrying {retry_count}/{self.max_retries}...",
-                                    extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": "CANCELLED"}})
-                     if retry_count < self.max_retries:
-                         delay = self.retry_delay * (2 ** (retry_count - 1))
-                         await asyncio.sleep(delay)
-                     else:
-                         logger.error("❌ Max retry reached for CANCELLED",
-                                      extra={"log_type": "info", "data": {"error_code": "CANCELLED", "max_retries_reached": True}})
-                         raise
-                 # Other RPC error types, e.g. transient connection problems or server timeouts
-                 elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
-                                                                    grpc.StatusCode.DEADLINE_EXCEEDED}:
-                     retry_count += 1
-                     logger.warning(f"⚠️ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...",
-                                    extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": str(e.code())}})
-                     if retry_count < self.max_retries:
-                         delay = self.retry_delay * (2 ** (retry_count - 1))
-                         await asyncio.sleep(delay)
-                     else:
-                         logger.error(f"❌ Max retry reached for {e.code()}",
-                                      extra={"log_type": "info", "data": {"error_code": str(e.code()), "max_retries_reached": True}})
-                         raise
-                 else:
-                     logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True,
-                                  extra={"log_type": "info", "data": {"error_code": str(e.code()) if hasattr(e, 'code') else None, "retryable": False}})
-                     raise
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         """Sync context manager exit (not supported)"""
+         pass

-     async def _retry_request_stream(self, func, *args, **kwargs):
-         retry_count = 0
-         while retry_count < self.max_retries:
-             try:
-                 return func(*args, **kwargs)
-             except (grpc.aio.AioRpcError, grpc.RpcError) as e:
-                 # Retry with exponential backoff when the RPC was cancelled
-                 if isinstance(e, grpc.aio.AioRpcError) and e.code() == grpc.StatusCode.CANCELLED:
-                     retry_count += 1
-                     logger.warning(f"⚠️ RPC cancelled, retrying {retry_count}/{self.max_retries}...",
-                                    extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": "CANCELLED"}})
-                     if retry_count < self.max_retries:
-                         delay = self.retry_delay * (2 ** (retry_count - 1))
-                         await asyncio.sleep(delay)
-                     else:
-                         logger.error("❌ Max retry reached for CANCELLED",
-                                      extra={"log_type": "info", "data": {"error_code": "CANCELLED", "max_retries_reached": True}})
-                         raise
-                 # Other RPC error types, e.g. transient connection problems or server timeouts
-                 elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
-                                                                    grpc.StatusCode.DEADLINE_EXCEEDED}:
-                     retry_count += 1
-                     logger.warning(f"⚠️ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...",
-                                    extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": str(e.code())}})
-                     if retry_count < self.max_retries:
-                         delay = self.retry_delay * (2 ** (retry_count - 1))
-                         await asyncio.sleep(delay)
-                     else:
-                         logger.error(f"❌ Max retry reached for {e.code()}",
-                                      extra={"log_type": "info", "data": {"error_code": str(e.code()), "max_retries_reached": True}})
-                         raise
-                 else:
-                     logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True,
-                                  extra={"log_type": "info", "data": {"error_code": str(e.code()) if hasattr(e, 'code') else None, "retryable": False}})
-                     raise
-
-     def _build_auth_metadata(self, request_id: str) -> list:
-         # if not self.jwt_token and self.jwt_handler:
-         # Changed to generate a fresh token on every request
-         metadata = [("x-request-id", request_id)]  # add the request_id to the headers
-         if self.jwt_handler:
-             self.jwt_token = self.jwt_handler.encode_token(self.default_payload, expires_in=self.token_expires_in)
-             metadata.append(("authorization", f"Bearer {self.jwt_token}"))
-         return metadata
+     async def connect(self):
+         """
+         Connect to the server explicitly
+
+         Establishes the connection to the gRPC server. Calling this manually
+         is usually unnecessary, because invoke ensures the connection itself.
+         """
+         await self._ensure_initialized()

      async def _ensure_initialized(self):
-         """Initialize the gRPC channel, with TLS support and a retry mechanism"""
+         """
+         Initialize the gRPC channel
+
+         Ensures the gRPC channel and stub are properly initialized, retrying
+         on failure, with TLS configuration and full keepalive options.
+
+         The connection configuration covers:
+         - Message size limits
+         - Keepalive settings (30s ping interval, 10s timeout)
+         - Connection lifecycle management (1 hour maximum connection age)
+         - Performance options (bandwidth probing, built-in retries)
+
+         Raises:
+             ConnectionError: when the maximum number of retries is reached
+                 and the connection still cannot be established
+         """
          if self.channel and self.stub:
              return

          retry_count = 0
-         options = [
-             ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
-             ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
-             ('grpc.keepalive_permit_without_calls', True)  # keep the connection alive even with no active requests
-         ]
-         if self.default_authority:
-             options.append(("grpc.default_authority", self.default_authority))
+         options = self.build_channel_options()

          while retry_count <= self.max_retries:
              try:
@@ -279,60 +192,272 @@ class AsyncTamarModelClient:
                      options=options
                  )
                  logger.info("🔐 Using secure gRPC channel (TLS enabled)",
-                             extra={"log_type": "info", "data": {"tls_enabled": True, "server_address": self.server_address}})
+                             extra={"log_type": "info",
+                                    "data": {"tls_enabled": True, "server_address": self.server_address}})
                  else:
                      self.channel = grpc.aio.insecure_channel(
                          self.server_address,
                          options=options
                      )
                      logger.info("🔓 Using insecure gRPC channel (TLS disabled)",
-                                 extra={"log_type": "info", "data": {"tls_enabled": False, "server_address": self.server_address}})
+                                 extra={"log_type": "info",
+                                        "data": {"tls_enabled": False, "server_address": self.server_address}})
+
                  await self.channel.channel_ready()
                  self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
                  logger.info(f"✅ gRPC channel initialized to {self.server_address}",
-                             extra={"log_type": "info", "data": {"status": "success", "server_address": self.server_address}})
+                             extra={"log_type": "info",
+                                    "data": {"status": "success", "server_address": self.server_address}})
                  return
+
              except grpc.FutureTimeoutError as e:
                  logger.error(f"❌ gRPC channel initialization timed out: {str(e)}", exc_info=True,
-                              extra={"log_type": "info", "data": {"error_type": "timeout", "server_address": self.server_address}})
+                              extra={"log_type": "info",
+                                     "data": {"error_type": "timeout", "server_address": self.server_address}})
              except grpc.RpcError as e:
                  logger.error(f"❌ gRPC channel initialization failed: {str(e)}", exc_info=True,
-                              extra={"log_type": "info", "data": {"error_type": "rpc_error", "server_address": self.server_address}})
+                              extra={"log_type": "info",
+                                     "data": {"error_type": "grpc_error", "server_address": self.server_address}})
              except Exception as e:
-                 logger.error(f"❌ Unexpected error during channel initialization: {str(e)}", exc_info=True,
-                              extra={"log_type": "info", "data": {"error_type": "unexpected", "server_address": self.server_address}})
-
+                 logger.error(f"❌ Unexpected error during gRPC channel initialization: {str(e)}", exc_info=True,
+                              extra={"log_type": "info",
+                                     "data": {"error_type": "unknown", "server_address": self.server_address}})
+
              retry_count += 1
-             if retry_count > self.max_retries:
-                 logger.error(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.", exc_info=True,
-                              extra={"log_type": "info", "data": {"max_retries_reached": True, "server_address": self.server_address}})
-                 raise ConnectionError(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.")
+             if retry_count <= self.max_retries:
+                 await asyncio.sleep(self.retry_delay * retry_count)

-             # Exponential backoff: delay = retry_delay * (2 ^ (retry_count - 1))
-             delay = self.retry_delay * (2 ** (retry_count - 1))
-             logger.warning(f"🔄 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...",
-                            extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "delay": delay}})
-             await asyncio.sleep(delay)
+         raise ConnectionError(f"Failed to connect to {self.server_address} after {self.max_retries} retries")

-     async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
-         async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
-             yield ModelResponse(
-                 content=response.content,
-                 usage=json.loads(response.usage) if response.usage else None,
-                 error=response.error or None,
-                 raw_response=json.loads(response.raw_response) if response.raw_response else None,
-                 request_id=response.request_id if response.request_id else None,
-             )
+     async def _retry_request(self, func, *args, **kwargs):
+         """
+         Execute a request using the enhanced retry handler
+
+         Args:
+             func: async function to execute
+             *args: positional arguments for the function
+             **kwargs: keyword arguments for the function
+
+         Returns:
+             The result of the function call
+
+         Raises:
+             TamarModelException: when all retries have failed
+         """
+         return await self.retry_handler.execute_with_retry(func, *args, **kwargs)
+
+     async def _retry_request_stream(self, func, *args, **kwargs):
+         """
+         Retry logic for streaming requests
+
+         Streaming responses need special retry handling, because a stream
+         cannot simply be re-executed.
+
+         Args:
+             func: async function that produces the stream
+             *args: positional arguments for the function
+             **kwargs: keyword arguments for the function
+
+         Returns:
+             AsyncIterator: the streaming response iterator
+         """
+         last_exception = None
+         context = {
+             'method': 'stream',
+             'client_version': 'async',
+         }
+
+         for attempt in range(self.max_retries + 1):
+             try:
+                 context['retry_count'] = attempt
+                 # Try to create the stream
+                 async for item in func(*args, **kwargs):
+                     yield item
+                 return
+
+             except RpcError as e:
+                 # Use smart retry classification
+                 context['retry_count'] = attempt
+
+                 # Build an error context and decide whether to retry
+                 from .exceptions import ErrorContext, get_retry_policy
+                 error_context = ErrorContext(e, context)
+                 error_code = e.code()
+                 policy = get_retry_policy(error_code)
+                 retryable = policy.get('retryable', False)
+
+                 should_retry = False
+                 if attempt < self.max_retries:
+                     if retryable == True:
+                         should_retry = True
+                     elif retryable == 'conditional':
+                         # Conditional retry; CANCELLED gets special handling
+                         if error_code == grpc.StatusCode.CANCELLED:
+                             should_retry = error_context.is_network_cancelled()
+                         else:
+                             should_retry = self._check_error_details_for_retry(e)
+
+                 if should_retry:
+                     log_data = {
+                         "log_type": "info",
+                         "request_id": context.get('request_id'),
+                         "data": {
+                             "error_code": e.code().name if e.code() else 'UNKNOWN',
+                             "retry_count": attempt,
+                             "max_retries": self.max_retries,
+                             "method": "stream"
+                         }
+                     }
+                     logger.warning(
+                         f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()} (will retry)",
+                         extra=log_data
+                     )
+
+                     # Compute the backoff delay
+                     delay = self._calculate_backoff(attempt, error_code)
+                     await asyncio.sleep(delay)
+                 else:
+                     # No retry, or the maximum number of retries was reached
+                     log_data = {
+                         "log_type": "info",
+                         "request_id": context.get('request_id'),
+                         "data": {
+                             "error_code": e.code().name if e.code() else 'UNKNOWN',
+                             "retry_count": attempt,
+                             "max_retries": self.max_retries,
+                             "method": "stream",
+                             "will_retry": False
+                         }
+                     }
+                     logger.error(
+                         f"Stream failed: {e.code()} (no retry)",
+                         extra=log_data
+                     )
+                     last_exception = self.error_handler.handle_error(e, context)
+                     break
+
+                 last_exception = e
+
+             except Exception as e:
+                 context['retry_count'] = attempt
+                 raise TamarModelException(str(e)) from e
+
+         if last_exception:
+             if isinstance(last_exception, TamarModelException):
+                 raise last_exception
+             else:
+                 raise self.error_handler.handle_error(last_exception, context)
+         else:
+             raise TamarModelException("Unknown streaming error occurred")
+
+     def _check_error_details_for_retry(self, error: RpcError) -> bool:
+         """Inspect the error details to decide whether to retry"""
+         error_message = error.details().lower() if error.details() else ""
+
+         # Retryable error patterns
+         retryable_patterns = [
+             'temporary', 'timeout', 'unavailable',
+             'connection', 'network', 'try again'
+         ]
+
+         for pattern in retryable_patterns:
+             if pattern in error_message:
+                 return True
+
+         return False

-     async def _stream_with_logging(self, request, metadata, invoke_timeout, start_time, model_request) -> AsyncIterator[ModelResponse]:
-         """Wrapper around streaming responses that logs the complete response"""
+     def _calculate_backoff(self, attempt: int, error_code=None) -> float:
+         """
+         Compute the backoff delay, supporting different backoff strategies
+
+         Args:
+             attempt: current retry attempt
+             error_code: gRPC status code, used to pick the backoff strategy
+         """
+         max_delay = 60.0
+         base_delay = self.retry_delay
+
+         # Look up the retry policy for this error
+         if error_code:
+             from .exceptions import get_retry_policy
+             policy = get_retry_policy(error_code)
+             backoff_type = policy.get('backoff', 'exponential')
+             use_jitter = policy.get('jitter', False)
+         else:
+             backoff_type = 'exponential'
+             use_jitter = False
+
+         # Compute the delay according to the backoff type
+         if backoff_type == 'linear':
+             # Linear backoff: delay * (attempt + 1)
+             delay = min(base_delay * (attempt + 1), max_delay)
+         else:
+             # Exponential backoff: delay * 2^attempt
+             delay = min(base_delay * (2 ** attempt), max_delay)
+
+         # Add jitter
+         if use_jitter:
+             jitter_factor = 0.2  # wider jitter range to reduce contention
+             jitter = random.uniform(0, delay * jitter_factor)
+             delay += jitter
+         else:
+             # Small default jitter to avoid fully synchronized retries
+             jitter_factor = 0.05
+             jitter = random.uniform(0, delay * jitter_factor)
+             delay += jitter
+
+         return delay
+
+     async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
+         """
+         Handle a streaming response
+
+         Includes chunk-level timeout protection so the stream cannot hang.
+
+         Args:
+             request: gRPC request object
+             metadata: request metadata
+             invoke_timeout: overall timeout
+
+         Yields:
+             ModelResponse: each chunk of the streaming response
+
+         Raises:
+             TimeoutError: when waiting for the next chunk times out
+         """
+         stream_iter = self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout).__aiter__()
+         chunk_timeout = 30.0  # timeout for a single chunk
+
+         try:
+             while True:
+                 try:
+                     # Guard the fetch of each chunk with a timeout
+                     response = await asyncio.wait_for(
+                         stream_iter.__anext__(),
+                         timeout=chunk_timeout
+                     )
+                     yield ResponseHandler.build_model_response(response)
+
+                 except asyncio.TimeoutError:
+                     raise TimeoutError(f"Streaming response timed out while waiting for the next chunk ({chunk_timeout}s)")
+
+                 except StopAsyncIteration:
+                     break  # normal end of stream
+         except Exception as e:
+             raise
+
+     async def _stream_with_logging(self, request, metadata, invoke_timeout, start_time, model_request) -> AsyncIterator[
+         ModelResponse]:
+         """Streaming response wrapper that logs the full response and handles retries"""
          total_content = ""
          final_usage = None
          error_occurred = None
          chunk_count = 0
-
+
+         # Obtain the stream generator through the retry logic
+         stream_generator = self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
+
          try:
-             async for response in self._stream(request, metadata, invoke_timeout):
+             async for response in stream_generator:
                  chunk_count += 1
                  if response.content:
                      total_content += response.content
@@ -341,26 +466,46 @@ class AsyncTamarModelClient:
                  if response.error:
                      error_occurred = response.error
                  yield response
-
-             # Stream finished; log the success
+
+             # Stream finished; log the result
              duration = time.time() - start_time
-             logger.info(
-                 f"✅ Stream completed successfully | chunks: {chunk_count}",
-                 extra={
-                     "log_type": "response",
-                     "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
-                     "duration": duration,
-                     "data": {
-                         "provider": model_request.provider.value,
-                         "invoke_type": model_request.invoke_type.value,
-                         "model": model_request.model,
-                         "stream": True,
-                         "chunks_count": chunk_count,
-                         "total_length": len(total_content),
-                         "usage": final_usage
+             if error_occurred:
+                 # The stream contained an error
+                 logger.warning(
+                     f"⚠️ Stream completed with errors | chunks: {chunk_count}",
+                     extra={
+                         "log_type": "response",
+                         "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
+                         "duration": duration,
+                         "data": ResponseHandler.build_log_data(
+                             model_request,
+                             stream_stats={
+                                 "chunks_count": chunk_count,
+                                 "total_length": len(total_content),
+                                 "usage": final_usage,
+                                 "error": error_occurred
+                             }
+                         )
                      }
-                 }
-             )
+                 )
+             else:
+                 # The stream completed successfully
+                 logger.info(
+                     f"✅ Stream completed successfully | chunks: {chunk_count}",
+                     extra={
+                         "log_type": "response",
+                         "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
+                         "duration": duration,
+                         "data": ResponseHandler.build_log_data(
+                             model_request,
+                             stream_stats={
+                                 "chunks_count": chunk_count,
+                                 "total_length": len(total_content),
+                                 "usage": final_usage
+                             }
+                         )
+                     }
+                 )
          except Exception as e:
              # The stream failed; log the error
              duration = time.time() - start_time
@@ -371,28 +516,22 @@ class AsyncTamarModelClient:
                      "log_type": "response",
                      "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
                      "duration": duration,
-                     "data": {
-                         "provider": model_request.provider.value,
-                         "invoke_type": model_request.invoke_type.value,
-                         "model": model_request.model,
-                         "stream": True,
-                         "chunks_count": chunk_count,
-                         "error_type": type(e).__name__,
-                         "partial_content_length": len(total_content)
-                     }
+                     "data": ResponseHandler.build_log_data(
+                         model_request,
+                         error=e,
+                         stream_stats={
+                             "chunks_count": chunk_count,
+                             "partial_content_length": len(total_content)
+                         }
+                     )
                  }
              )
              raise

      async def _invoke_request(self, request, metadata, invoke_timeout):
+         """Execute a single non-streaming request"""
          async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
-             return ModelResponse(
-                 content=response.content,
-                 usage=json.loads(response.usage) if response.usage else None,
-                 error=response.error or None,
-                 raw_response=json.loads(response.raw_response) if response.raw_response else None,
-                 request_id=response.request_id if response.request_id else None,
-             )
+             return ResponseHandler.build_model_response(response)

      async def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None,
                       request_id: Optional[str] = None) -> Union[
@@ -411,6 +550,12 @@ class AsyncTamarModelClient:
              ValidationError: input validation failed.
              ConnectionError: failed to connect to the server.
          """
+         # If resilience is enabled and the circuit breaker is open, go straight to HTTP
+         if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
+             if self.http_fallback_url:
+                 logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback")
+                 return await self._invoke_http_fallback(model_request, timeout, request_id)
+
          await self._ensure_initialized()

          if not self.default_payload:
@@ -420,9 +565,9 @@ class AsyncTamarModelClient:
              }

          if not request_id:
-             request_id = generate_request_id()  # generate a new request_id
-         set_request_id(request_id)  # set the request_id for the current request
-         metadata = self._build_auth_metadata(request_id)  # add the request_id to the request headers
+             request_id = generate_request_id()
+         set_request_id(request_id)
+         metadata = self._build_auth_metadata(request_id)

          # Log the start of the request
          start_time = time.time()
@@ -431,135 +576,103 @@ class AsyncTamarModelClient:
              extra={
                  "log_type": "request",
                  "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
-                 "data": {
-                     "provider": model_request.provider.value,
-                     "invoke_type": model_request.invoke_type.value,
-                     "model": model_request.model,
-                     "stream": model_request.stream,
-                     "org_id": model_request.user_context.org_id,
-                     "user_id": model_request.user_context.user_id,
-                     "client_type": model_request.user_context.client_type
-                 }
+                 "data": ResponseHandler.build_log_data(model_request)
              })

-         # Dynamically pick which input fields to use based on provider/invoke_type
          try:
-             # Select the set of fields to validate
-             # Dynamic dispatch logic
-             match (model_request.provider, model_request.invoke_type):
-                 case (ProviderType.GOOGLE, InvokeType.GENERATION):
-                     allowed_fields = GoogleGenAiInput.model_fields.keys()
-                 case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
-                     allowed_fields = GoogleVertexAIImagesInput.model_fields.keys()
-                 case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
-                     allowed_fields = OpenAIResponsesInput.model_fields.keys()
-                 case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
-                     allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
-                 case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
-                     allowed_fields = OpenAIImagesInput.model_fields.keys()
-                 case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
-                     allowed_fields = OpenAIImagesEditInput.model_fields.keys()
-                 case _:
-                     raise ValueError(
-                         f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")
-
-             # Convert the ModelRequest to a dict, keeping only the base + allowed fields
-             model_request_dict = model_request.model_dump(exclude_unset=True)
-
-             grpc_request_kwargs = {}
-             for field in allowed_fields:
-                 if field in model_request_dict:
-                     value = model_request_dict[field]
-
-                     # Skip ineffective values
-                     if not is_effective_value(value):
-                         continue
-
-                     # Serialize types that gRPC does not support
-                     grpc_request_kwargs[field] = serialize_value(value)
-
-             # Clean up grpc_request_kwargs after serialization
-             grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
-
-             request = model_service_pb2.ModelRequestItem(
-                 provider=model_request.provider.value,
-                 channel=model_request.channel.value,
-                 invoke_type=model_request.invoke_type.value,
-                 stream=model_request.stream or False,
-                 org_id=model_request.user_context.org_id or "",
-                 user_id=model_request.user_context.user_id or "",
-                 client_type=model_request.user_context.client_type or "",
-                 extra=grpc_request_kwargs
-             )
-
+             # Build the gRPC request
+             request = RequestBuilder.build_single_request(model_request)
+
          except Exception as e:
+             duration = time.time() - start_time
+             logger.error(
+                 f"❌ Request build failed: {str(e)}",
+                 exc_info=True,
+                 extra={
+                     "log_type": "response",
+                     "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
+                     "duration": duration,
+                     "data": {
+                         "provider": model_request.provider.value,
+                         "invoke_type": model_request.invoke_type.value,
+                         "model": getattr(model_request, 'model', None),
+                         "error_type": "build_error",
+                         "error_message": str(e)
+                     }
+                 }
+             )
              raise ValueError(f"Failed to build the request: {str(e)}") from e

          try:
              invoke_timeout = timeout or self.default_invoke_timeout
              if model_request.stream:
-                 # For streaming responses, use the logging wrapper
-                 stream_generator = await self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
+                 # For streaming responses, return the logging wrapper directly
                  return self._stream_with_logging(request, metadata, invoke_timeout, start_time, model_request)
              else:
                  result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
-
+
                  # Log the success of the non-streaming response
                  duration = time.time() - start_time
+                 content_length = len(result.content) if result.content else 0
                  logger.info(
-                     f"✅ Request completed successfully",
+                     f"✅ Request completed | content_length: {content_length}",
                      extra={
                          "log_type": "response",
                          "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
                          "duration": duration,
-                         "data": {
-                             "provider": model_request.provider.value,
-                             "invoke_type": model_request.invoke_type.value,
-                             "model": model_request.model,
-                             "stream": False,
-                             "content_length": len(result.content) if result.content else 0,
-                             "usage": result.usage
-                         }
+                         "data": ResponseHandler.build_log_data(model_request, result)
                      }
                  )
+
+                 # Record the success (if the circuit breaker is enabled)
+                 if self.resilient_enabled and self.circuit_breaker:
+                     self.circuit_breaker.record_success()
+
                  return result
-         except grpc.RpcError as e:
+
+         except (ConnectionError, grpc.RpcError) as e:
              duration = time.time() - start_time
              error_message = f"❌ Invoke gRPC failed: {str(e)}"
              logger.error(error_message, exc_info=True,
-                          extra={
-                              "log_type": "response",
-                              "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
-                              "duration": duration,
-                              "data": {
-                                  "error_type": "grpc_error",
-                                  "error_code": str(e.code()) if hasattr(e, 'code') else None,
-                                  "provider": model_request.provider.value,
-                                  "invoke_type": model_request.invoke_type.value,
-                                  "model": model_request.model
-                              }
-                          })
+                          extra={
+                              "log_type": "response",
+                              "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
+                              "duration": duration,
+                              "data": ResponseHandler.build_log_data(
+                                  model_request,
+                                  error=e
+                              )
+                          })
+
+             # Record the failure and try to degrade (if the circuit breaker is enabled)
+             if self.resilient_enabled and self.circuit_breaker:
+                 # Pass the status code to the circuit breaker for smarter failure accounting
+                 error_code = e.code() if hasattr(e, 'code') else None
+                 self.circuit_breaker.record_failure(error_code)
+
+                 # Fall back to HTTP when a fallback is available
+                 if self.http_fallback_url and self.circuit_breaker.should_fallback():
+                     logger.warning(f"🔻 gRPC failed, falling back to HTTP: {str(e)}")
+                     return await self._invoke_http_fallback(model_request, timeout, request_id)
+
              raise e
          except Exception as e:
              duration = time.time() - start_time
              error_message = f"❌ Invoke other error: {str(e)}"
              logger.error(error_message, exc_info=True,
-                          extra={
-                              "log_type": "response",
-                              "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
-                              "duration": duration,
-                              "data": {
-                                  "error_type": "other_error",
-                                  "provider": model_request.provider.value,
-                                  "invoke_type": model_request.invoke_type.value,
-                                  "model": model_request.model
-                              }
-                          })
+                          extra={
+                              "log_type": "response",
+                              "uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
+                              "duration": duration,
+                              "data": ResponseHandler.build_log_data(
+                                  model_request,
+                                  error=e
+                              )
+                          })
              raise e

      async def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None,
-                            request_id: Optional[str] = None) -> \
-             BatchModelResponse:
+                            request_id: Optional[str] = None) -> BatchModelResponse:
          """
          Batch model invocation interface

@@ -570,7 +683,6 @@ class AsyncTamarModelClient:
          Returns:
              BatchModelResponse: the result of the batch request
          """
-
          await self._ensure_initialized()

          if not self.default_payload:
@@ -580,9 +692,9 @@ class AsyncTamarModelClient:
              }

          if not request_id:
-             request_id = generate_request_id()  # generate a new request_id
-         set_request_id(request_id)  # set the request_id for the current request
-         metadata = self._build_auth_metadata(request_id)  # add the request_id to the request headers
+             request_id = generate_request_id()
+         set_request_id(request_id)
+         metadata = self._build_auth_metadata(request_id)

          # Log the start of the request
          start_time = time.time()
@@ -599,155 +711,83 @@ class AsyncTamarModelClient:
              }
          })

-         # Build the batch request
-         items = []
-         for model_request_item in batch_request_model.items:
-             # Dynamically pick which input fields to use based on provider/invoke_type
-             try:
-                 match (model_request_item.provider, model_request_item.invoke_type):
-                     case (ProviderType.GOOGLE, InvokeType.GENERATION):
-                         allowed_fields = GoogleGenAiInput.model_fields.keys()
-                     case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
-                         allowed_fields = GoogleVertexAIImagesInput.model_fields.keys()
-                     case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
-                         allowed_fields = OpenAIResponsesInput.model_fields.keys()
-                     case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
-                         allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
-                     case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
-                         allowed_fields = OpenAIImagesInput.model_fields.keys()
-                     case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
-                         allowed_fields = OpenAIImagesEditInput.model_fields.keys()
-                     case _:
-                         raise ValueError(
-                             f"Unsupported provider/invoke_type combination: {model_request_item.provider} + {model_request_item.invoke_type}")
-
-                 # Convert the ModelRequest to a dict, keeping only the base + allowed fields
-                 model_request_dict = model_request_item.model_dump(exclude_unset=True)
-
-                 grpc_request_kwargs = {}
-                 for field in allowed_fields:
-                     if field in model_request_dict:
-                         value = model_request_dict[field]
-
-                         # Skip ineffective values
-                         if not is_effective_value(value):
-                             continue
-
-                         # Serialize types that gRPC does not support
-                         grpc_request_kwargs[field] = serialize_value(value)
-
-                 # Clean up grpc_request_kwargs after serialization
-                 grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
-
-                 items.append(model_service_pb2.ModelRequestItem(
-                     provider=model_request_item.provider.value,
-                     channel=model_request_item.channel.value,
-                     invoke_type=model_request_item.invoke_type.value,
-                     stream=model_request_item.stream or False,
-                     custom_id=model_request_item.custom_id or "",
-                     priority=model_request_item.priority or 1,
-                     org_id=batch_request_model.user_context.org_id or "",
-                     user_id=batch_request_model.user_context.user_id or "",
-                     client_type=batch_request_model.user_context.client_type or "",
-                     extra=grpc_request_kwargs,
-                 ))
-
-             except Exception as e:
-                 raise ValueError(f"Failed to build the request: {str(e)}, item={model_request_item.custom_id}") from e
+         try:
+             # Build the batch request
+             batch_request = RequestBuilder.build_batch_request(batch_request_model)
+
+         except Exception as e:
+             duration = time.time() - start_time
+             logger.error(
+                 f"❌ Batch request build failed: {str(e)}",
+                 exc_info=True,
+                 extra={
+                     "log_type": "response",
+                     "uri": "/batch_invoke",
+                     "duration": duration,
+                     "data": {
+                         "batch_size": len(batch_request_model.items),
+                         "error_type": "build_error",
+                         "error_message": str(e)
+                     }
+                 }
+             )
+             raise ValueError(f"Failed to build the batch request: {str(e)}") from e

          try:
-             # Timeout handling
              invoke_timeout = timeout or self.default_invoke_timeout
-
-             # Call the gRPC interface
-             response = await self._retry_request(self.stub.BatchInvoke, model_service_pb2.ModelRequest(items=items),
-                                                  timeout=invoke_timeout, metadata=metadata)
-
-             result = []
-             for res_item in response.items:
-                 result.append(ModelResponse(
-                     content=res_item.content,
-                     usage=json.loads(res_item.usage) if res_item.usage else None,
-                     raw_response=json.loads(res_item.raw_response) if res_item.raw_response else None,
-                     error=res_item.error or None,
-                     custom_id=res_item.custom_id if res_item.custom_id else None
-                 ))
-             batch_response = BatchModelResponse(
-                 request_id=response.request_id if response.request_id else None,
-                 responses=result
+             batch_response = await self._retry_request(
+                 self.stub.BatchInvoke,
+                 batch_request,
+                 metadata=metadata,
+                 timeout=invoke_timeout
              )
-
+
+             # Build the response object
+             result = ResponseHandler.build_batch_response(batch_response)
+
              # Log the success
              duration = time.time() - start_time
              logger.info(
-                 f"✅ Batch request completed successfully",
+                 f"✅ Batch Request completed | batch_size: {len(result.responses)}",
                  extra={
                      "log_type": "response",
                      "uri": "/batch_invoke",
                      "duration": duration,
                      "data": {
-                         "batch_size": len(batch_request_model.items),
-                         "responses_count": len(result)
+                         "batch_size": len(result.responses),
+                         "success_count": sum(1 for item in result.responses if not item.error),
+                         "error_count": sum(1 for item in result.responses if item.error)
                      }
-                 }
-             )
-             return batch_response
+                 })
+
+             return result
+
          except grpc.RpcError as e:
              duration = time.time() - start_time
-             error_message = f"❌ BatchInvoke gRPC failed: {str(e)}"
+             error_message = f"❌ Batch invoke gRPC failed: {str(e)}"
              logger.error(error_message, exc_info=True,
-                          extra={
-                              "log_type": "response",
-                              "uri": "/batch_invoke",
-                              "duration": duration,
-                              "data": {
-                                  "error_type": "grpc_error",
-                                  "error_code": str(e.code()) if hasattr(e, 'code') else None,
-                                  "batch_size": len(batch_request_model.items)
-                              }
-                          })
+                          extra={
+                              "log_type": "response",
+                              "uri": "/batch_invoke",
+                              "duration": duration,
+                              "data": {
+                                  "error_type": "grpc_error",
+                                  "error_code": str(e.code()) if hasattr(e, 'code') else None,
+                                  "batch_size": len(batch_request_model.items)
+                              }
+                          })
              raise e
          except Exception as e:
              duration = time.time() - start_time
-             error_message = f"❌ BatchInvoke other error: {str(e)}"
+             error_message = f"❌ Batch invoke other error: {str(e)}"
              logger.error(error_message, exc_info=True,
-                          extra={
-                              "log_type": "response",
-                              "uri": "/batch_invoke",
-                              "duration": duration,
-                              "data": {
-                                  "error_type": "other_error",
-                                  "batch_size": len(batch_request_model.items)
-                              }
-                          })
-             raise e
-
-     async def close(self):
-         """Close the gRPC channel"""
-         if self.channel and not self._closed:
-             await self.channel.close()
-             self._closed = True
-             logger.info("✅ gRPC channel closed",
-                         extra={"log_type": "info", "data": {"status": "success"}})
-
-     def _safe_sync_close(self):
-         """Close the channel automatically at process exit (event-loop compatible)"""
-         if self.channel and not self._closed:
-             try:
-                 loop = asyncio.get_event_loop()
-                 if loop.is_running():
-                     loop.create_task(self.close())
-                 else:
-                     loop.run_until_complete(self.close())
-             except Exception as e:
-                 logger.warning(f"⚠️ gRPC channel close failed at exit: {e}",
-                                extra={"log_type": "info", "data": {"status": "failed", "error": str(e)}})
-
-     async def __aenter__(self):
-         """Support async with: initialize the connection automatically"""
-         await self._ensure_initialized()
-         return self
-
-     async def __aexit__(self, exc_type, exc_val, exc_tb):
-         """Support async with: close the connection automatically"""
-         await self.close()
+                          extra={
+                              "log_type": "response",
+                              "uri": "/batch_invoke",
+                              "duration": duration,
+                              "data": {
+                                  "error_type": "other_error",
+                                  "batch_size": len(batch_request_model.items)
+                              }
+                          })
+             raise e
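
As a reference for the retry behavior introduced in this version, the following standalone sketch reproduces the arithmetic of _calculate_backoff with the get_retry_policy lookup stubbed out. The base delay of 1.0s mirrors the MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY default from the removed 0.1.19 code and is an assumption here, not a value taken from 0.1.21:

    import random

    def backoff(attempt: int, base_delay: float = 1.0,
                backoff_type: str = 'exponential', use_jitter: bool = False) -> float:
        # Same arithmetic as _calculate_backoff above, minus the policy lookup.
        max_delay = 60.0
        if backoff_type == 'linear':
            delay = min(base_delay * (attempt + 1), max_delay)   # 1s, 2s, 3s, ...
        else:
            delay = min(base_delay * (2 ** attempt), max_delay)  # 1s, 2s, 4s, ...
        # Jittered policies spread retries over a 20% window; all others get 5%.
        jitter_factor = 0.2 if use_jitter else 0.05
        return delay + random.uniform(0, delay * jitter_factor)

    for attempt in range(4):
        print(f"attempt {attempt}: ~{backoff(attempt):.2f}s")

With these assumed defaults the schedule is roughly 1s, 2s, 4s, 8s plus up to 5% jitter, capped at 60s.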