tamar-model-client 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/async_client.py +580 -540
- tamar_model_client/circuit_breaker.py +140 -0
- tamar_model_client/core/__init__.py +40 -0
- tamar_model_client/core/base_client.py +221 -0
- tamar_model_client/core/http_fallback.py +249 -0
- tamar_model_client/core/logging_setup.py +194 -0
- tamar_model_client/core/request_builder.py +221 -0
- tamar_model_client/core/response_handler.py +136 -0
- tamar_model_client/core/utils.py +171 -0
- tamar_model_client/error_handler.py +315 -0
- tamar_model_client/exceptions.py +388 -7
- tamar_model_client/json_formatter.py +36 -1
- tamar_model_client/sync_client.py +590 -486
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/METADATA +289 -61
- tamar_model_client-0.1.21.dist-info/RECORD +35 -0
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/top_level.txt +1 -0
- tests/__init__.py +1 -0
- tests/stream_hanging_analysis.py +357 -0
- tests/test_google_azure_final.py +448 -0
- tests/test_simple.py +235 -0
- tamar_model_client-0.1.19.dist-info/RECORD +0 -22
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/WHEEL +0 -0
@@ -1,273 +1,186 @@
|
|
1
|
+
"""
|
2
|
+
Tamar Model Client 异步客户端实现
|
3
|
+
|
4
|
+
本模块实现了异步的 gRPC 客户端,用于与 Model Manager Server 进行通信。
|
5
|
+
支持单个请求、批量请求、流式响应等功能,并提供了完整的错误处理和重试机制。
|
6
|
+
|
7
|
+
主要功能:
|
8
|
+
- 异步 gRPC 通信
|
9
|
+
- JWT 认证
|
10
|
+
- 自动重试和错误处理
|
11
|
+
- 流式响应支持
|
12
|
+
- 连接池管理
|
13
|
+
- 详细的日志记录
|
14
|
+
|
15
|
+
使用示例:
|
16
|
+
async with AsyncTamarModelClient() as client:
|
17
|
+
request = ModelRequest(...)
|
18
|
+
response = await client.invoke(request)
|
19
|
+
"""
|
20
|
+
|
1
21
|
import asyncio
|
2
22
|
import atexit
|
3
|
-
import base64
|
4
23
|
import json
|
5
24
|
import logging
|
6
|
-
import
|
25
|
+
import random
|
7
26
|
import time
|
8
|
-
import
|
9
|
-
from contextvars import ContextVar
|
27
|
+
from typing import Optional, AsyncIterator, Union
|
10
28
|
|
11
29
|
import grpc
|
12
|
-
from
|
13
|
-
|
14
|
-
from
|
15
|
-
|
16
|
-
|
17
|
-
|
30
|
+
from grpc import RpcError
|
31
|
+
|
32
|
+
from .core import (
|
33
|
+
generate_request_id,
|
34
|
+
set_request_id,
|
35
|
+
get_protected_logger,
|
36
|
+
MAX_MESSAGE_LENGTH
|
37
|
+
)
|
38
|
+
from .core.base_client import BaseClient
|
39
|
+
from .core.request_builder import RequestBuilder
|
40
|
+
from .core.response_handler import ResponseHandler
|
18
41
|
from .enums import ProviderType, InvokeType
|
19
|
-
from .exceptions import ConnectionError
|
42
|
+
from .exceptions import ConnectionError, TamarModelException
|
43
|
+
from .error_handler import EnhancedRetryHandler
|
20
44
|
from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
|
21
45
|
from .generated import model_service_pb2, model_service_pb2_grpc
|
22
|
-
from .
|
23
|
-
GoogleVertexAIImagesInput, OpenAIImagesInput, OpenAIImagesEditInput
|
24
|
-
from .json_formatter import JSONFormatter
|
25
|
-
|
26
|
-
logger = logging.getLogger(__name__)
|
27
|
-
|
28
|
-
# 使用 contextvars 管理请求ID
|
29
|
-
_request_id: ContextVar[str] = ContextVar('request_id', default='-')
|
30
|
-
|
31
|
-
|
32
|
-
class RequestIdFilter(logging.Filter):
|
33
|
-
"""自定义日志过滤器,向日志中添加 request_id"""
|
34
|
-
|
35
|
-
def filter(self, record):
|
36
|
-
# 从 ContextVar 中获取当前的 request_id
|
37
|
-
record.request_id = _request_id.get()
|
38
|
-
return True
|
39
|
-
|
46
|
+
from .core.http_fallback import AsyncHttpFallbackMixin
|
40
47
|
|
41
|
-
|
42
|
-
|
43
|
-
console_handler = logging.StreamHandler()
|
48
|
+
# 配置日志记录器(使用受保护的logger)
|
49
|
+
logger = get_protected_logger(__name__)
|
44
50
|
|
45
|
-
# 使用 JSON 格式化器
|
46
|
-
formatter = JSONFormatter()
|
47
|
-
console_handler.setFormatter(formatter)
|
48
51
|
|
49
|
-
|
50
|
-
logger.addHandler(console_handler)
|
51
|
-
|
52
|
-
# 设置日志级别
|
53
|
-
logger.setLevel(logging.INFO)
|
54
|
-
|
55
|
-
# 将自定义的 RequestIdFilter 添加到 logger 中
|
56
|
-
logger.addFilter(RequestIdFilter())
|
57
|
-
|
58
|
-
MAX_MESSAGE_LENGTH = 2 ** 31 - 1 # 对于32位系统
|
59
|
-
|
60
|
-
|
61
|
-
def is_effective_value(value) -> bool:
|
52
|
+
class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
62
53
|
"""
|
63
|
-
|
54
|
+
Tamar Model Client 异步客户端
|
55
|
+
|
56
|
+
提供与 Model Manager Server 的异步通信能力,支持:
|
57
|
+
- 单个和批量模型调用
|
58
|
+
- 流式和非流式响应
|
59
|
+
- 自动重试和错误恢复
|
60
|
+
- JWT 认证
|
61
|
+
- 连接池管理
|
62
|
+
|
63
|
+
使用示例:
|
64
|
+
# 基本用法
|
65
|
+
client = AsyncTamarModelClient()
|
66
|
+
await client.connect()
|
67
|
+
|
68
|
+
request = ModelRequest(...)
|
69
|
+
response = await client.invoke(request)
|
70
|
+
|
71
|
+
# 上下文管理器用法(推荐)
|
72
|
+
async with AsyncTamarModelClient() as client:
|
73
|
+
response = await client.invoke(request)
|
74
|
+
|
75
|
+
环境变量配置:
|
76
|
+
MODEL_MANAGER_SERVER_ADDRESS: gRPC 服务器地址
|
77
|
+
MODEL_MANAGER_SERVER_JWT_SECRET_KEY: JWT 密钥
|
78
|
+
MODEL_MANAGER_SERVER_GRPC_USE_TLS: 是否使用 TLS
|
79
|
+
MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES: 最大重试次数
|
80
|
+
MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY: 重试延迟
|
64
81
|
"""
|
65
|
-
if value is None or value is NOT_GIVEN:
|
66
|
-
return False
|
67
|
-
|
68
|
-
if isinstance(value, str):
|
69
|
-
return value.strip() != ""
|
70
|
-
|
71
|
-
if isinstance(value, bytes):
|
72
|
-
return len(value) > 0
|
73
|
-
|
74
|
-
if isinstance(value, dict):
|
75
|
-
for v in value.values():
|
76
|
-
if is_effective_value(v):
|
77
|
-
return True
|
78
|
-
return False
|
79
|
-
|
80
|
-
if isinstance(value, list):
|
81
|
-
for item in value:
|
82
|
-
if is_effective_value(item):
|
83
|
-
return True
|
84
|
-
return False
|
85
|
-
|
86
|
-
return True # 其他类型(int/float/bool)只要不是None就算有效
|
87
82
|
|
83
|
+
def __init__(self, **kwargs):
|
84
|
+
"""
|
85
|
+
初始化异步客户端
|
86
|
+
|
87
|
+
参数继承自 BaseClient,包括:
|
88
|
+
- server_address: gRPC 服务器地址
|
89
|
+
- jwt_secret_key: JWT 签名密钥
|
90
|
+
- jwt_token: 预生成的 JWT 令牌
|
91
|
+
- default_payload: JWT 令牌的默认载荷
|
92
|
+
- token_expires_in: JWT 令牌过期时间
|
93
|
+
- max_retries: 最大重试次数
|
94
|
+
- retry_delay: 初始重试延迟
|
95
|
+
"""
|
96
|
+
super().__init__(logger_name=__name__, **kwargs)
|
97
|
+
|
98
|
+
# === gRPC 通道和连接管理 ===
|
99
|
+
self.channel: Optional[grpc.aio.Channel] = None
|
100
|
+
self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
|
101
|
+
|
102
|
+
# === 增强的重试处理器 ===
|
103
|
+
self.retry_handler = EnhancedRetryHandler(
|
104
|
+
max_retries=self.max_retries,
|
105
|
+
base_delay=self.retry_delay
|
106
|
+
)
|
107
|
+
|
108
|
+
# 注册退出时的清理函数
|
109
|
+
atexit.register(self._cleanup_atexit)
|
88
110
|
|
89
|
-
def
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
if isinstance(value, dict):
|
98
|
-
return {k: serialize_value(v) for k, v in value.items()}
|
99
|
-
if isinstance(value, list) or (isinstance(value, Iterable) and not isinstance(value, (str, bytes))):
|
100
|
-
return [serialize_value(v) for v in value]
|
101
|
-
if isinstance(value, bytes):
|
102
|
-
return f"bytes:{base64.b64encode(value).decode('utf-8')}"
|
103
|
-
return value
|
111
|
+
def _cleanup_atexit(self):
|
112
|
+
"""程序退出时的清理函数"""
|
113
|
+
if self.channel and not self._closed:
|
114
|
+
try:
|
115
|
+
asyncio.create_task(self.close())
|
116
|
+
except RuntimeError:
|
117
|
+
# 如果事件循环已经关闭,忽略错误
|
118
|
+
pass
|
104
119
|
|
120
|
+
async def close(self):
|
121
|
+
"""
|
122
|
+
关闭客户端连接
|
123
|
+
|
124
|
+
优雅地关闭 gRPC 通道并清理资源。
|
125
|
+
建议在程序结束前调用此方法,或使用上下文管理器自动管理。
|
126
|
+
"""
|
127
|
+
if self.channel and not self._closed:
|
128
|
+
await self.channel.close()
|
129
|
+
self._closed = True
|
130
|
+
logger.info("🔒 gRPC channel closed",
|
131
|
+
extra={"log_type": "info", "data": {"status": "closed"}})
|
132
|
+
|
133
|
+
# 清理 HTTP session(如果有)
|
134
|
+
if self.resilient_enabled:
|
135
|
+
await self._cleanup_http_session()
|
105
136
|
|
106
|
-
|
137
|
+
async def __aenter__(self):
|
138
|
+
"""异步上下文管理器入口"""
|
139
|
+
await self.connect()
|
140
|
+
return self
|
107
141
|
|
142
|
+
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
143
|
+
"""异步上下文管理器出口"""
|
144
|
+
await self.close()
|
108
145
|
|
109
|
-
def
|
110
|
-
|
111
|
-
|
112
|
-
"""
|
113
|
-
if isinstance(data, dict):
|
114
|
-
new_dict = {}
|
115
|
-
for key, value in data.items():
|
116
|
-
if value is None:
|
117
|
-
continue
|
118
|
-
cleaned_value = remove_none_from_dict(value)
|
119
|
-
new_dict[key] = cleaned_value
|
120
|
-
return new_dict
|
121
|
-
elif isinstance(data, list):
|
122
|
-
return [remove_none_from_dict(item) for item in data]
|
123
|
-
else:
|
124
|
-
return data
|
125
|
-
|
126
|
-
|
127
|
-
def generate_request_id():
|
128
|
-
"""生成一个唯一的request_id"""
|
129
|
-
return str(uuid.uuid4())
|
130
|
-
|
131
|
-
|
132
|
-
def set_request_id(request_id: str):
|
133
|
-
"""设置当前请求的 request_id"""
|
134
|
-
_request_id.set(request_id)
|
135
|
-
|
136
|
-
|
137
|
-
class AsyncTamarModelClient:
|
138
|
-
def __init__(
|
139
|
-
self,
|
140
|
-
server_address: Optional[str] = None,
|
141
|
-
jwt_secret_key: Optional[str] = None,
|
142
|
-
jwt_token: Optional[str] = None,
|
143
|
-
default_payload: Optional[dict] = None,
|
144
|
-
token_expires_in: int = 3600,
|
145
|
-
max_retries: Optional[int] = None, # 最大重试次数
|
146
|
-
retry_delay: Optional[float] = None, # 初始重试延迟(秒)
|
147
|
-
):
|
148
|
-
# 服务端地址
|
149
|
-
self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
|
150
|
-
if not self.server_address:
|
151
|
-
raise ValueError("Server address must be provided via argument or environment variable.")
|
152
|
-
self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))
|
153
|
-
|
154
|
-
# JWT 配置
|
155
|
-
self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
|
156
|
-
self.jwt_handler = JWTAuthHandler(self.jwt_secret_key)
|
157
|
-
self.jwt_token = jwt_token # 用户传入的 Token(可选)
|
158
|
-
self.default_payload = default_payload
|
159
|
-
self.token_expires_in = token_expires_in
|
160
|
-
|
161
|
-
# === TLS/Authority 配置 ===
|
162
|
-
self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
|
163
|
-
self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")
|
164
|
-
|
165
|
-
# === 重试配置 ===
|
166
|
-
self.max_retries = max_retries if max_retries is not None else int(
|
167
|
-
os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
|
168
|
-
self.retry_delay = retry_delay if retry_delay is not None else float(
|
169
|
-
os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))
|
170
|
-
|
171
|
-
# === gRPC 通道相关 ===
|
172
|
-
self.channel: Optional[grpc.aio.Channel] = None
|
173
|
-
self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
|
174
|
-
self._closed = False
|
175
|
-
atexit.register(self._safe_sync_close) # 注册进程退出自动关闭
|
146
|
+
def __enter__(self):
|
147
|
+
"""同步上下文管理器入口(不支持)"""
|
148
|
+
raise TypeError("Use 'async with' for AsyncTamarModelClient")
|
176
149
|
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
try:
|
181
|
-
return await func(*args, **kwargs)
|
182
|
-
except (grpc.aio.AioRpcError, grpc.RpcError) as e:
|
183
|
-
# 对于取消的情况进行指数退避重试
|
184
|
-
if isinstance(e, grpc.aio.AioRpcError) and e.code() == grpc.StatusCode.CANCELLED:
|
185
|
-
retry_count += 1
|
186
|
-
logger.warning(f"⚠️ RPC cancelled, retrying {retry_count}/{self.max_retries}...",
|
187
|
-
extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": "CANCELLED"}})
|
188
|
-
if retry_count < self.max_retries:
|
189
|
-
delay = self.retry_delay * (2 ** (retry_count - 1))
|
190
|
-
await asyncio.sleep(delay)
|
191
|
-
else:
|
192
|
-
logger.error("❌ Max retry reached for CANCELLED",
|
193
|
-
extra={"log_type": "info", "data": {"error_code": "CANCELLED", "max_retries_reached": True}})
|
194
|
-
raise
|
195
|
-
# 针对其他 RPC 错误类型,如暂时的连接问题、服务器超时等
|
196
|
-
elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
|
197
|
-
grpc.StatusCode.DEADLINE_EXCEEDED}:
|
198
|
-
retry_count += 1
|
199
|
-
logger.warning(f"⚠️ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...",
|
200
|
-
extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": str(e.code())}})
|
201
|
-
if retry_count < self.max_retries:
|
202
|
-
delay = self.retry_delay * (2 ** (retry_count - 1))
|
203
|
-
await asyncio.sleep(delay)
|
204
|
-
else:
|
205
|
-
logger.error(f"❌ Max retry reached for {e.code()}",
|
206
|
-
extra={"log_type": "info", "data": {"error_code": str(e.code()), "max_retries_reached": True}})
|
207
|
-
raise
|
208
|
-
else:
|
209
|
-
logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True,
|
210
|
-
extra={"log_type": "info", "data": {"error_code": str(e.code()) if hasattr(e, 'code') else None, "retryable": False}})
|
211
|
-
raise
|
150
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
151
|
+
"""同步上下文管理器出口(不支持)"""
|
152
|
+
pass
|
212
153
|
|
213
|
-
async def
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
retry_count += 1
|
222
|
-
logger.warning(f"⚠️ RPC cancelled, retrying {retry_count}/{self.max_retries}...",
|
223
|
-
extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": "CANCELLED"}})
|
224
|
-
if retry_count < self.max_retries:
|
225
|
-
delay = self.retry_delay * (2 ** (retry_count - 1))
|
226
|
-
await asyncio.sleep(delay)
|
227
|
-
else:
|
228
|
-
logger.error("❌ Max retry reached for CANCELLED",
|
229
|
-
extra={"log_type": "info", "data": {"error_code": "CANCELLED", "max_retries_reached": True}})
|
230
|
-
raise
|
231
|
-
# 针对其他 RPC 错误类型,如暂时的连接问题、服务器超时等
|
232
|
-
elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
|
233
|
-
grpc.StatusCode.DEADLINE_EXCEEDED}:
|
234
|
-
retry_count += 1
|
235
|
-
logger.warning(f"⚠️ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...",
|
236
|
-
extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "error_code": str(e.code())}})
|
237
|
-
if retry_count < self.max_retries:
|
238
|
-
delay = self.retry_delay * (2 ** (retry_count - 1))
|
239
|
-
await asyncio.sleep(delay)
|
240
|
-
else:
|
241
|
-
logger.error(f"❌ Max retry reached for {e.code()}",
|
242
|
-
extra={"log_type": "info", "data": {"error_code": str(e.code()), "max_retries_reached": True}})
|
243
|
-
raise
|
244
|
-
else:
|
245
|
-
logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True,
|
246
|
-
extra={"log_type": "info", "data": {"error_code": str(e.code()) if hasattr(e, 'code') else None, "retryable": False}})
|
247
|
-
raise
|
248
|
-
|
249
|
-
def _build_auth_metadata(self, request_id: str) -> list:
|
250
|
-
# if not self.jwt_token and self.jwt_handler:
|
251
|
-
# 更改为每次请求都生成一次token
|
252
|
-
metadata = [("x-request-id", request_id)] # 将 request_id 添加到 headers
|
253
|
-
if self.jwt_handler:
|
254
|
-
self.jwt_token = self.jwt_handler.encode_token(self.default_payload, expires_in=self.token_expires_in)
|
255
|
-
metadata.append(("authorization", f"Bearer {self.jwt_token}"))
|
256
|
-
return metadata
|
154
|
+
async def connect(self):
|
155
|
+
"""
|
156
|
+
显式连接到服务器
|
157
|
+
|
158
|
+
建立与 gRPC 服务器的连接。通常不需要手动调用,
|
159
|
+
因为 invoke 方法会自动确保连接已建立。
|
160
|
+
"""
|
161
|
+
await self._ensure_initialized()
|
257
162
|
|
258
163
|
async def _ensure_initialized(self):
|
259
|
-
"""
|
164
|
+
"""
|
165
|
+
初始化gRPC通道
|
166
|
+
|
167
|
+
确保gRPC通道和存根已正确初始化。如果初始化失败,
|
168
|
+
会进行重试,支持TLS配置和完整的keepalive选项。
|
169
|
+
|
170
|
+
连接配置包括:
|
171
|
+
- 消息大小限制
|
172
|
+
- Keepalive设置(30秒ping间隔,10秒超时)
|
173
|
+
- 连接生命周期管理(1小时最大连接时间)
|
174
|
+
- 性能优化选项(带宽探测、内置重试)
|
175
|
+
|
176
|
+
Raises:
|
177
|
+
ConnectionError: 当达到最大重试次数仍无法连接时
|
178
|
+
"""
|
260
179
|
if self.channel and self.stub:
|
261
180
|
return
|
262
181
|
|
263
182
|
retry_count = 0
|
264
|
-
options =
|
265
|
-
('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
|
266
|
-
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
|
267
|
-
('grpc.keepalive_permit_without_calls', True) # 即使没有活跃请求也保持连接
|
268
|
-
]
|
269
|
-
if self.default_authority:
|
270
|
-
options.append(("grpc.default_authority", self.default_authority))
|
183
|
+
options = self.build_channel_options()
|
271
184
|
|
272
185
|
while retry_count <= self.max_retries:
|
273
186
|
try:
|
@@ -279,60 +192,272 @@ class AsyncTamarModelClient:
|
|
279
192
|
options=options
|
280
193
|
)
|
281
194
|
logger.info("🔐 Using secure gRPC channel (TLS enabled)",
|
282
|
-
|
195
|
+
extra={"log_type": "info",
|
196
|
+
"data": {"tls_enabled": True, "server_address": self.server_address}})
|
283
197
|
else:
|
284
198
|
self.channel = grpc.aio.insecure_channel(
|
285
199
|
self.server_address,
|
286
200
|
options=options
|
287
201
|
)
|
288
202
|
logger.info("🔓 Using insecure gRPC channel (TLS disabled)",
|
289
|
-
|
203
|
+
extra={"log_type": "info",
|
204
|
+
"data": {"tls_enabled": False, "server_address": self.server_address}})
|
205
|
+
|
290
206
|
await self.channel.channel_ready()
|
291
207
|
self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
|
292
208
|
logger.info(f"✅ gRPC channel initialized to {self.server_address}",
|
293
|
-
|
209
|
+
extra={"log_type": "info",
|
210
|
+
"data": {"status": "success", "server_address": self.server_address}})
|
294
211
|
return
|
212
|
+
|
295
213
|
except grpc.FutureTimeoutError as e:
|
296
214
|
logger.error(f"❌ gRPC channel initialization timed out: {str(e)}", exc_info=True,
|
297
|
-
|
215
|
+
extra={"log_type": "info",
|
216
|
+
"data": {"error_type": "timeout", "server_address": self.server_address}})
|
298
217
|
except grpc.RpcError as e:
|
299
218
|
logger.error(f"❌ gRPC channel initialization failed: {str(e)}", exc_info=True,
|
300
|
-
|
219
|
+
extra={"log_type": "info",
|
220
|
+
"data": {"error_type": "grpc_error", "server_address": self.server_address}})
|
301
221
|
except Exception as e:
|
302
|
-
logger.error(f"❌ Unexpected error during channel initialization: {str(e)}", exc_info=True,
|
303
|
-
|
304
|
-
|
222
|
+
logger.error(f"❌ Unexpected error during gRPC channel initialization: {str(e)}", exc_info=True,
|
223
|
+
extra={"log_type": "info",
|
224
|
+
"data": {"error_type": "unknown", "server_address": self.server_address}})
|
225
|
+
|
305
226
|
retry_count += 1
|
306
|
-
if retry_count
|
307
|
-
|
308
|
-
extra={"log_type": "info", "data": {"max_retries_reached": True, "server_address": self.server_address}})
|
309
|
-
raise ConnectionError(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.")
|
227
|
+
if retry_count <= self.max_retries:
|
228
|
+
await asyncio.sleep(self.retry_delay * retry_count)
|
310
229
|
|
311
|
-
|
312
|
-
delay = self.retry_delay * (2 ** (retry_count - 1))
|
313
|
-
logger.warning(f"🔄 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...",
|
314
|
-
extra={"log_type": "info", "data": {"retry_count": retry_count, "max_retries": self.max_retries, "delay": delay}})
|
315
|
-
await asyncio.sleep(delay)
|
230
|
+
raise ConnectionError(f"Failed to connect to {self.server_address} after {self.max_retries} retries")
|
316
231
|
|
317
|
-
async def
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
232
|
+
async def _retry_request(self, func, *args, **kwargs):
|
233
|
+
"""
|
234
|
+
使用增强的重试处理器执行请求
|
235
|
+
|
236
|
+
Args:
|
237
|
+
func: 要执行的异步函数
|
238
|
+
*args: 函数参数
|
239
|
+
**kwargs: 函数关键字参数
|
240
|
+
|
241
|
+
Returns:
|
242
|
+
函数执行结果
|
243
|
+
|
244
|
+
Raises:
|
245
|
+
TamarModelException: 当所有重试都失败时
|
246
|
+
"""
|
247
|
+
return await self.retry_handler.execute_with_retry(func, *args, **kwargs)
|
248
|
+
|
249
|
+
async def _retry_request_stream(self, func, *args, **kwargs):
|
250
|
+
"""
|
251
|
+
流式请求的重试逻辑
|
252
|
+
|
253
|
+
对于流式响应,需要特殊的重试处理,因为流不能简单地重新执行。
|
254
|
+
|
255
|
+
Args:
|
256
|
+
func: 生成流的异步函数
|
257
|
+
*args: 函数参数
|
258
|
+
**kwargs: 函数关键字参数
|
259
|
+
|
260
|
+
Returns:
|
261
|
+
AsyncIterator: 流式响应迭代器
|
262
|
+
"""
|
263
|
+
last_exception = None
|
264
|
+
context = {
|
265
|
+
'method': 'stream',
|
266
|
+
'client_version': 'async',
|
267
|
+
}
|
268
|
+
|
269
|
+
for attempt in range(self.max_retries + 1):
|
270
|
+
try:
|
271
|
+
context['retry_count'] = attempt
|
272
|
+
# 尝试创建流
|
273
|
+
async for item in func(*args, **kwargs):
|
274
|
+
yield item
|
275
|
+
return
|
276
|
+
|
277
|
+
except RpcError as e:
|
278
|
+
# 使用智能重试判断
|
279
|
+
context['retry_count'] = attempt
|
280
|
+
|
281
|
+
# 创建错误上下文并判断是否应该重试
|
282
|
+
from .exceptions import ErrorContext, get_retry_policy
|
283
|
+
error_context = ErrorContext(e, context)
|
284
|
+
error_code = e.code()
|
285
|
+
policy = get_retry_policy(error_code)
|
286
|
+
retryable = policy.get('retryable', False)
|
287
|
+
|
288
|
+
should_retry = False
|
289
|
+
if attempt < self.max_retries:
|
290
|
+
if retryable == True:
|
291
|
+
should_retry = True
|
292
|
+
elif retryable == 'conditional':
|
293
|
+
# 条件重试,特殊处理 CANCELLED
|
294
|
+
if error_code == grpc.StatusCode.CANCELLED:
|
295
|
+
should_retry = error_context.is_network_cancelled()
|
296
|
+
else:
|
297
|
+
should_retry = self._check_error_details_for_retry(e)
|
298
|
+
|
299
|
+
if should_retry:
|
300
|
+
log_data = {
|
301
|
+
"log_type": "info",
|
302
|
+
"request_id": context.get('request_id'),
|
303
|
+
"data": {
|
304
|
+
"error_code": e.code().name if e.code() else 'UNKNOWN',
|
305
|
+
"retry_count": attempt,
|
306
|
+
"max_retries": self.max_retries,
|
307
|
+
"method": "stream"
|
308
|
+
}
|
309
|
+
}
|
310
|
+
logger.warning(
|
311
|
+
f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()} (will retry)",
|
312
|
+
extra=log_data
|
313
|
+
)
|
314
|
+
|
315
|
+
# 计算退避时间
|
316
|
+
delay = self._calculate_backoff(attempt, error_code)
|
317
|
+
await asyncio.sleep(delay)
|
318
|
+
else:
|
319
|
+
# 不重试或已达到最大重试次数
|
320
|
+
log_data = {
|
321
|
+
"log_type": "info",
|
322
|
+
"request_id": context.get('request_id'),
|
323
|
+
"data": {
|
324
|
+
"error_code": e.code().name if e.code() else 'UNKNOWN',
|
325
|
+
"retry_count": attempt,
|
326
|
+
"max_retries": self.max_retries,
|
327
|
+
"method": "stream",
|
328
|
+
"will_retry": False
|
329
|
+
}
|
330
|
+
}
|
331
|
+
logger.error(
|
332
|
+
f"Stream failed: {e.code()} (no retry)",
|
333
|
+
extra=log_data
|
334
|
+
)
|
335
|
+
last_exception = self.error_handler.handle_error(e, context)
|
336
|
+
break
|
337
|
+
|
338
|
+
last_exception = e
|
339
|
+
|
340
|
+
except Exception as e:
|
341
|
+
context['retry_count'] = attempt
|
342
|
+
raise TamarModelException(str(e)) from e
|
343
|
+
|
344
|
+
if last_exception:
|
345
|
+
if isinstance(last_exception, TamarModelException):
|
346
|
+
raise last_exception
|
347
|
+
else:
|
348
|
+
raise self.error_handler.handle_error(last_exception, context)
|
349
|
+
else:
|
350
|
+
raise TamarModelException("Unknown streaming error occurred")
|
351
|
+
|
352
|
+
def _check_error_details_for_retry(self, error: RpcError) -> bool:
|
353
|
+
"""检查错误详情决定是否重试"""
|
354
|
+
error_message = error.details().lower() if error.details() else ""
|
355
|
+
|
356
|
+
# 可重试的错误模式
|
357
|
+
retryable_patterns = [
|
358
|
+
'temporary', 'timeout', 'unavailable',
|
359
|
+
'connection', 'network', 'try again'
|
360
|
+
]
|
361
|
+
|
362
|
+
for pattern in retryable_patterns:
|
363
|
+
if pattern in error_message:
|
364
|
+
return True
|
365
|
+
|
366
|
+
return False
|
326
367
|
|
327
|
-
|
328
|
-
"""
|
368
|
+
def _calculate_backoff(self, attempt: int, error_code = None) -> float:
|
369
|
+
"""
|
370
|
+
计算退避时间,支持不同的退避策略
|
371
|
+
|
372
|
+
Args:
|
373
|
+
attempt: 当前重试次数
|
374
|
+
error_code: gRPC错误码,用于确定退避策略
|
375
|
+
"""
|
376
|
+
max_delay = 60.0
|
377
|
+
base_delay = self.retry_delay
|
378
|
+
|
379
|
+
# 获取错误的重试策略
|
380
|
+
if error_code:
|
381
|
+
from .exceptions import get_retry_policy
|
382
|
+
policy = get_retry_policy(error_code)
|
383
|
+
backoff_type = policy.get('backoff', 'exponential')
|
384
|
+
use_jitter = policy.get('jitter', False)
|
385
|
+
else:
|
386
|
+
backoff_type = 'exponential'
|
387
|
+
use_jitter = False
|
388
|
+
|
389
|
+
# 根据退避类型计算延迟
|
390
|
+
if backoff_type == 'linear':
|
391
|
+
# 线性退避:delay * (attempt + 1)
|
392
|
+
delay = min(base_delay * (attempt + 1), max_delay)
|
393
|
+
else:
|
394
|
+
# 指数退避:delay * 2^attempt
|
395
|
+
delay = min(base_delay * (2 ** attempt), max_delay)
|
396
|
+
|
397
|
+
# 添加抖动
|
398
|
+
if use_jitter:
|
399
|
+
jitter_factor = 0.2 # 增加抖动范围,减少竞争
|
400
|
+
jitter = random.uniform(0, delay * jitter_factor)
|
401
|
+
delay += jitter
|
402
|
+
else:
|
403
|
+
# 默认的小量抖动,避免完全同步
|
404
|
+
jitter_factor = 0.05
|
405
|
+
jitter = random.uniform(0, delay * jitter_factor)
|
406
|
+
delay += jitter
|
407
|
+
|
408
|
+
return delay
|
409
|
+
|
410
|
+
async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
|
411
|
+
"""
|
412
|
+
处理流式响应
|
413
|
+
|
414
|
+
包含块级超时保护,防止流式响应挂起。
|
415
|
+
|
416
|
+
Args:
|
417
|
+
request: gRPC 请求对象
|
418
|
+
metadata: 请求元数据
|
419
|
+
invoke_timeout: 总体超时时间
|
420
|
+
|
421
|
+
Yields:
|
422
|
+
ModelResponse: 流式响应的每个数据块
|
423
|
+
|
424
|
+
Raises:
|
425
|
+
TimeoutError: 当等待下一个数据块超时时
|
426
|
+
"""
|
427
|
+
stream_iter = self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout).__aiter__()
|
428
|
+
chunk_timeout = 30.0 # 单个数据块的超时时间
|
429
|
+
|
430
|
+
try:
|
431
|
+
while True:
|
432
|
+
try:
|
433
|
+
# 对每个数据块的获取进行超时保护
|
434
|
+
response = await asyncio.wait_for(
|
435
|
+
stream_iter.__anext__(),
|
436
|
+
timeout=chunk_timeout
|
437
|
+
)
|
438
|
+
yield ResponseHandler.build_model_response(response)
|
439
|
+
|
440
|
+
except asyncio.TimeoutError:
|
441
|
+
raise TimeoutError(f"流式响应在等待下一个数据块时超时 ({chunk_timeout}s)")
|
442
|
+
|
443
|
+
except StopAsyncIteration:
|
444
|
+
break # 正常结束
|
445
|
+
except Exception as e:
|
446
|
+
raise
|
447
|
+
|
448
|
+
async def _stream_with_logging(self, request, metadata, invoke_timeout, start_time, model_request) -> AsyncIterator[
|
449
|
+
ModelResponse]:
|
450
|
+
"""流式响应的包装器,用于记录完整的响应日志并处理重试"""
|
329
451
|
total_content = ""
|
330
452
|
final_usage = None
|
331
453
|
error_occurred = None
|
332
454
|
chunk_count = 0
|
333
|
-
|
455
|
+
|
456
|
+
# 使用重试逻辑获取流生成器
|
457
|
+
stream_generator = self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
|
458
|
+
|
334
459
|
try:
|
335
|
-
async for response in
|
460
|
+
async for response in stream_generator:
|
336
461
|
chunk_count += 1
|
337
462
|
if response.content:
|
338
463
|
total_content += response.content
|
@@ -341,26 +466,46 @@ class AsyncTamarModelClient:
|
|
341
466
|
if response.error:
|
342
467
|
error_occurred = response.error
|
343
468
|
yield response
|
344
|
-
|
345
|
-
#
|
469
|
+
|
470
|
+
# 流式响应完成,记录日志
|
346
471
|
duration = time.time() - start_time
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
"
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
"
|
355
|
-
"
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
|
360
|
-
|
472
|
+
if error_occurred:
|
473
|
+
# 流式响应中包含错误
|
474
|
+
logger.warning(
|
475
|
+
f"⚠️ Stream completed with errors | chunks: {chunk_count}",
|
476
|
+
extra={
|
477
|
+
"log_type": "response",
|
478
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
479
|
+
"duration": duration,
|
480
|
+
"data": ResponseHandler.build_log_data(
|
481
|
+
model_request,
|
482
|
+
stream_stats={
|
483
|
+
"chunks_count": chunk_count,
|
484
|
+
"total_length": len(total_content),
|
485
|
+
"usage": final_usage,
|
486
|
+
"error": error_occurred
|
487
|
+
}
|
488
|
+
)
|
361
489
|
}
|
362
|
-
|
363
|
-
|
490
|
+
)
|
491
|
+
else:
|
492
|
+
# 流式响应成功完成
|
493
|
+
logger.info(
|
494
|
+
f"✅ Stream completed successfully | chunks: {chunk_count}",
|
495
|
+
extra={
|
496
|
+
"log_type": "response",
|
497
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
498
|
+
"duration": duration,
|
499
|
+
"data": ResponseHandler.build_log_data(
|
500
|
+
model_request,
|
501
|
+
stream_stats={
|
502
|
+
"chunks_count": chunk_count,
|
503
|
+
"total_length": len(total_content),
|
504
|
+
"usage": final_usage
|
505
|
+
}
|
506
|
+
)
|
507
|
+
}
|
508
|
+
)
|
364
509
|
except Exception as e:
|
365
510
|
# 流式响应出错,记录错误日志
|
366
511
|
duration = time.time() - start_time
|
@@ -371,28 +516,22 @@ class AsyncTamarModelClient:
|
|
371
516
|
"log_type": "response",
|
372
517
|
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
373
518
|
"duration": duration,
|
374
|
-
"data":
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
}
|
519
|
+
"data": ResponseHandler.build_log_data(
|
520
|
+
model_request,
|
521
|
+
error=e,
|
522
|
+
stream_stats={
|
523
|
+
"chunks_count": chunk_count,
|
524
|
+
"partial_content_length": len(total_content)
|
525
|
+
}
|
526
|
+
)
|
383
527
|
}
|
384
528
|
)
|
385
529
|
raise
|
386
530
|
|
387
531
|
async def _invoke_request(self, request, metadata, invoke_timeout):
|
532
|
+
"""执行单个非流式请求"""
|
388
533
|
async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
|
389
|
-
return
|
390
|
-
content=response.content,
|
391
|
-
usage=json.loads(response.usage) if response.usage else None,
|
392
|
-
error=response.error or None,
|
393
|
-
raw_response=json.loads(response.raw_response) if response.raw_response else None,
|
394
|
-
request_id=response.request_id if response.request_id else None,
|
395
|
-
)
|
534
|
+
return ResponseHandler.build_model_response(response)
|
396
535
|
|
397
536
|
async def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None,
|
398
537
|
request_id: Optional[str] = None) -> Union[
|
@@ -411,6 +550,12 @@ class AsyncTamarModelClient:
|
|
411
550
|
ValidationError: 输入验证失败。
|
412
551
|
ConnectionError: 连接服务端失败。
|
413
552
|
"""
|
553
|
+
# 如果启用了熔断且熔断器打开,直接走 HTTP
|
554
|
+
if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
|
555
|
+
if self.http_fallback_url:
|
556
|
+
logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback")
|
557
|
+
return await self._invoke_http_fallback(model_request, timeout, request_id)
|
558
|
+
|
414
559
|
await self._ensure_initialized()
|
415
560
|
|
416
561
|
if not self.default_payload:
|
@@ -420,9 +565,9 @@ class AsyncTamarModelClient:
|
|
420
565
|
}
|
421
566
|
|
422
567
|
if not request_id:
|
423
|
-
request_id = generate_request_id()
|
424
|
-
set_request_id(request_id)
|
425
|
-
metadata = self._build_auth_metadata(request_id)
|
568
|
+
request_id = generate_request_id()
|
569
|
+
set_request_id(request_id)
|
570
|
+
metadata = self._build_auth_metadata(request_id)
|
426
571
|
|
427
572
|
# 记录开始日志
|
428
573
|
start_time = time.time()
|
@@ -431,135 +576,103 @@ class AsyncTamarModelClient:
|
|
431
576
|
extra={
|
432
577
|
"log_type": "request",
|
433
578
|
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
434
|
-
"data":
|
435
|
-
"provider": model_request.provider.value,
|
436
|
-
"invoke_type": model_request.invoke_type.value,
|
437
|
-
"model": model_request.model,
|
438
|
-
"stream": model_request.stream,
|
439
|
-
"org_id": model_request.user_context.org_id,
|
440
|
-
"user_id": model_request.user_context.user_id,
|
441
|
-
"client_type": model_request.user_context.client_type
|
442
|
-
}
|
579
|
+
"data": ResponseHandler.build_log_data(model_request)
|
443
580
|
})
|
444
581
|
|
445
|
-
# 动态根据 provider/invoke_type 决定使用哪个 input 字段
|
446
582
|
try:
|
447
|
-
#
|
448
|
-
|
449
|
-
|
450
|
-
case (ProviderType.GOOGLE, InvokeType.GENERATION):
|
451
|
-
allowed_fields = GoogleGenAiInput.model_fields.keys()
|
452
|
-
case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
|
453
|
-
allowed_fields = GoogleVertexAIImagesInput.model_fields.keys()
|
454
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
|
455
|
-
allowed_fields = OpenAIResponsesInput.model_fields.keys()
|
456
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
|
457
|
-
allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
|
458
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
|
459
|
-
allowed_fields = OpenAIImagesInput.model_fields.keys()
|
460
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
|
461
|
-
allowed_fields = OpenAIImagesEditInput.model_fields.keys()
|
462
|
-
case _:
|
463
|
-
raise ValueError(
|
464
|
-
f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")
|
465
|
-
|
466
|
-
# 将 ModelRequest 转 dict,过滤只保留 base + allowed 的字段
|
467
|
-
model_request_dict = model_request.model_dump(exclude_unset=True)
|
468
|
-
|
469
|
-
grpc_request_kwargs = {}
|
470
|
-
for field in allowed_fields:
|
471
|
-
if field in model_request_dict:
|
472
|
-
value = model_request_dict[field]
|
473
|
-
|
474
|
-
# 跳过无效的值
|
475
|
-
if not is_effective_value(value):
|
476
|
-
continue
|
477
|
-
|
478
|
-
# 序列化grpc不支持的类型
|
479
|
-
grpc_request_kwargs[field] = serialize_value(value)
|
480
|
-
|
481
|
-
# 清理 serialize后的 grpc_request_kwargs
|
482
|
-
grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
|
483
|
-
|
484
|
-
request = model_service_pb2.ModelRequestItem(
|
485
|
-
provider=model_request.provider.value,
|
486
|
-
channel=model_request.channel.value,
|
487
|
-
invoke_type=model_request.invoke_type.value,
|
488
|
-
stream=model_request.stream or False,
|
489
|
-
org_id=model_request.user_context.org_id or "",
|
490
|
-
user_id=model_request.user_context.user_id or "",
|
491
|
-
client_type=model_request.user_context.client_type or "",
|
492
|
-
extra=grpc_request_kwargs
|
493
|
-
)
|
494
|
-
|
583
|
+
# 构建 gRPC 请求
|
584
|
+
request = RequestBuilder.build_single_request(model_request)
|
585
|
+
|
495
586
|
except Exception as e:
|
587
|
+
duration = time.time() - start_time
|
588
|
+
logger.error(
|
589
|
+
f"❌ Request build failed: {str(e)}",
|
590
|
+
exc_info=True,
|
591
|
+
extra={
|
592
|
+
"log_type": "response",
|
593
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
594
|
+
"duration": duration,
|
595
|
+
"data": {
|
596
|
+
"provider": model_request.provider.value,
|
597
|
+
"invoke_type": model_request.invoke_type.value,
|
598
|
+
"model": getattr(model_request, 'model', None),
|
599
|
+
"error_type": "build_error",
|
600
|
+
"error_message": str(e)
|
601
|
+
}
|
602
|
+
}
|
603
|
+
)
|
496
604
|
raise ValueError(f"构建请求失败: {str(e)}") from e
|
497
605
|
|
498
606
|
try:
|
499
607
|
invoke_timeout = timeout or self.default_invoke_timeout
|
500
608
|
if model_request.stream:
|
501
|
-
#
|
502
|
-
stream_generator = await self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
|
609
|
+
# 对于流式响应,直接返回带日志记录的包装器
|
503
610
|
return self._stream_with_logging(request, metadata, invoke_timeout, start_time, model_request)
|
504
611
|
else:
|
505
612
|
result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
|
506
|
-
|
613
|
+
|
507
614
|
# 记录非流式响应的成功日志
|
508
615
|
duration = time.time() - start_time
|
616
|
+
content_length = len(result.content) if result.content else 0
|
509
617
|
logger.info(
|
510
|
-
f"✅ Request completed
|
618
|
+
f"✅ Request completed | content_length: {content_length}",
|
511
619
|
extra={
|
512
620
|
"log_type": "response",
|
513
621
|
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
514
622
|
"duration": duration,
|
515
|
-
"data":
|
516
|
-
"provider": model_request.provider.value,
|
517
|
-
"invoke_type": model_request.invoke_type.value,
|
518
|
-
"model": model_request.model,
|
519
|
-
"stream": False,
|
520
|
-
"content_length": len(result.content) if result.content else 0,
|
521
|
-
"usage": result.usage
|
522
|
-
}
|
623
|
+
"data": ResponseHandler.build_log_data(model_request, result)
|
523
624
|
}
|
524
625
|
)
|
626
|
+
|
627
|
+
# 记录成功(如果启用了熔断)
|
628
|
+
if self.resilient_enabled and self.circuit_breaker:
|
629
|
+
self.circuit_breaker.record_success()
|
630
|
+
|
525
631
|
return result
|
526
|
-
|
632
|
+
|
633
|
+
except (ConnectionError, grpc.RpcError) as e:
|
527
634
|
duration = time.time() - start_time
|
528
635
|
error_message = f"❌ Invoke gRPC failed: {str(e)}"
|
529
636
|
logger.error(error_message, exc_info=True,
|
530
|
-
|
531
|
-
|
532
|
-
|
533
|
-
|
534
|
-
|
535
|
-
|
536
|
-
|
537
|
-
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
637
|
+
extra={
|
638
|
+
"log_type": "response",
|
639
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
640
|
+
"duration": duration,
|
641
|
+
"data": ResponseHandler.build_log_data(
|
642
|
+
model_request,
|
643
|
+
error=e
|
644
|
+
)
|
645
|
+
})
|
646
|
+
|
647
|
+
# 记录失败并尝试降级(如果启用了熔断)
|
648
|
+
if self.resilient_enabled and self.circuit_breaker:
|
649
|
+
# 将错误码传递给熔断器,用于智能失败统计
|
650
|
+
error_code = e.code() if hasattr(e, 'code') else None
|
651
|
+
self.circuit_breaker.record_failure(error_code)
|
652
|
+
|
653
|
+
# 如果可以降级,则降级
|
654
|
+
if self.http_fallback_url and self.circuit_breaker.should_fallback():
|
655
|
+
logger.warning(f"🔻 gRPC failed, falling back to HTTP: {str(e)}")
|
656
|
+
return await self._invoke_http_fallback(model_request, timeout, request_id)
|
657
|
+
|
542
658
|
raise e
|
543
659
|
except Exception as e:
|
544
660
|
duration = time.time() - start_time
|
545
661
|
error_message = f"❌ Invoke other error: {str(e)}"
|
546
662
|
logger.error(error_message, exc_info=True,
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
}
|
557
|
-
})
|
663
|
+
extra={
|
664
|
+
"log_type": "response",
|
665
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
666
|
+
"duration": duration,
|
667
|
+
"data": ResponseHandler.build_log_data(
|
668
|
+
model_request,
|
669
|
+
error=e
|
670
|
+
)
|
671
|
+
})
|
558
672
|
raise e
|
559
673
|
|
560
674
|
async def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None,
|
561
|
-
request_id: Optional[str] = None) ->
|
562
|
-
BatchModelResponse:
|
675
|
+
request_id: Optional[str] = None) -> BatchModelResponse:
|
563
676
|
"""
|
564
677
|
批量模型调用接口
|
565
678
|
|
@@ -570,7 +683,6 @@ class AsyncTamarModelClient:
|
|
570
683
|
Returns:
|
571
684
|
BatchModelResponse: 批量请求的结果
|
572
685
|
"""
|
573
|
-
|
574
686
|
await self._ensure_initialized()
|
575
687
|
|
576
688
|
if not self.default_payload:
|
@@ -580,9 +692,9 @@ class AsyncTamarModelClient:
|
|
580
692
|
}
|
581
693
|
|
582
694
|
if not request_id:
|
583
|
-
request_id = generate_request_id()
|
584
|
-
set_request_id(request_id)
|
585
|
-
metadata = self._build_auth_metadata(request_id)
|
695
|
+
request_id = generate_request_id()
|
696
|
+
set_request_id(request_id)
|
697
|
+
metadata = self._build_auth_metadata(request_id)
|
586
698
|
|
587
699
|
# 记录开始日志
|
588
700
|
start_time = time.time()
|
@@ -599,155 +711,83 @@ class AsyncTamarModelClient:
|
|
599
711
|
}
|
600
712
|
})
|
601
713
|
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
|
622
|
-
|
623
|
-
|
624
|
-
# 将 ModelRequest 转 dict,过滤只保留 base + allowed 的字段
|
625
|
-
model_request_dict = model_request_item.model_dump(exclude_unset=True)
|
626
|
-
|
627
|
-
grpc_request_kwargs = {}
|
628
|
-
for field in allowed_fields:
|
629
|
-
if field in model_request_dict:
|
630
|
-
value = model_request_dict[field]
|
631
|
-
|
632
|
-
# 跳过无效的值
|
633
|
-
if not is_effective_value(value):
|
634
|
-
continue
|
635
|
-
|
636
|
-
# 序列化grpc不支持的类型
|
637
|
-
grpc_request_kwargs[field] = serialize_value(value)
|
638
|
-
|
639
|
-
# 清理 serialize后的 grpc_request_kwargs
|
640
|
-
grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
|
641
|
-
|
642
|
-
items.append(model_service_pb2.ModelRequestItem(
|
643
|
-
provider=model_request_item.provider.value,
|
644
|
-
channel=model_request_item.channel.value,
|
645
|
-
invoke_type=model_request_item.invoke_type.value,
|
646
|
-
stream=model_request_item.stream or False,
|
647
|
-
custom_id=model_request_item.custom_id or "",
|
648
|
-
priority=model_request_item.priority or 1,
|
649
|
-
org_id=batch_request_model.user_context.org_id or "",
|
650
|
-
user_id=batch_request_model.user_context.user_id or "",
|
651
|
-
client_type=batch_request_model.user_context.client_type or "",
|
652
|
-
extra=grpc_request_kwargs,
|
653
|
-
))
|
654
|
-
|
655
|
-
except Exception as e:
|
656
|
-
raise ValueError(f"构建请求失败: {str(e)},item={model_request_item.custom_id}") from e
|
714
|
+
try:
|
715
|
+
# 构建批量请求
|
716
|
+
batch_request = RequestBuilder.build_batch_request(batch_request_model)
|
717
|
+
|
718
|
+
except Exception as e:
|
719
|
+
duration = time.time() - start_time
|
720
|
+
logger.error(
|
721
|
+
f"❌ Batch request build failed: {str(e)}",
|
722
|
+
exc_info=True,
|
723
|
+
extra={
|
724
|
+
"log_type": "response",
|
725
|
+
"uri": "/batch_invoke",
|
726
|
+
"duration": duration,
|
727
|
+
"data": {
|
728
|
+
"batch_size": len(batch_request_model.items),
|
729
|
+
"error_type": "build_error",
|
730
|
+
"error_message": str(e)
|
731
|
+
}
|
732
|
+
}
|
733
|
+
)
|
734
|
+
raise ValueError(f"构建批量请求失败: {str(e)}") from e
|
657
735
|
|
658
736
|
try:
|
659
|
-
# 超时处理逻辑
|
660
737
|
invoke_timeout = timeout or self.default_invoke_timeout
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
result = []
|
667
|
-
for res_item in response.items:
|
668
|
-
result.append(ModelResponse(
|
669
|
-
content=res_item.content,
|
670
|
-
usage=json.loads(res_item.usage) if res_item.usage else None,
|
671
|
-
raw_response=json.loads(res_item.raw_response) if res_item.raw_response else None,
|
672
|
-
error=res_item.error or None,
|
673
|
-
custom_id=res_item.custom_id if res_item.custom_id else None
|
674
|
-
))
|
675
|
-
batch_response = BatchModelResponse(
|
676
|
-
request_id=response.request_id if response.request_id else None,
|
677
|
-
responses=result
|
738
|
+
batch_response = await self._retry_request(
|
739
|
+
self.stub.BatchInvoke,
|
740
|
+
batch_request,
|
741
|
+
metadata=metadata,
|
742
|
+
timeout=invoke_timeout
|
678
743
|
)
|
679
|
-
|
744
|
+
|
745
|
+
# 构建响应对象
|
746
|
+
result = ResponseHandler.build_batch_response(batch_response)
|
747
|
+
|
680
748
|
# 记录成功日志
|
681
749
|
duration = time.time() - start_time
|
682
750
|
logger.info(
|
683
|
-
f"✅ Batch
|
751
|
+
f"✅ Batch Request completed | batch_size: {len(result.responses)}",
|
684
752
|
extra={
|
685
753
|
"log_type": "response",
|
686
754
|
"uri": "/batch_invoke",
|
687
755
|
"duration": duration,
|
688
756
|
"data": {
|
689
|
-
"batch_size": len(
|
690
|
-
"
|
757
|
+
"batch_size": len(result.responses),
|
758
|
+
"success_count": sum(1 for item in result.responses if not item.error),
|
759
|
+
"error_count": sum(1 for item in result.responses if item.error)
|
691
760
|
}
|
692
|
-
}
|
693
|
-
|
694
|
-
return
|
761
|
+
})
|
762
|
+
|
763
|
+
return result
|
764
|
+
|
695
765
|
except grpc.RpcError as e:
|
696
766
|
duration = time.time() - start_time
|
697
|
-
error_message = f"❌
|
767
|
+
error_message = f"❌ Batch invoke gRPC failed: {str(e)}"
|
698
768
|
logger.error(error_message, exc_info=True,
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
|
708
|
-
|
769
|
+
extra={
|
770
|
+
"log_type": "response",
|
771
|
+
"uri": "/batch_invoke",
|
772
|
+
"duration": duration,
|
773
|
+
"data": {
|
774
|
+
"error_type": "grpc_error",
|
775
|
+
"error_code": str(e.code()) if hasattr(e, 'code') else None,
|
776
|
+
"batch_size": len(batch_request_model.items)
|
777
|
+
}
|
778
|
+
})
|
709
779
|
raise e
|
710
780
|
except Exception as e:
|
711
781
|
duration = time.time() - start_time
|
712
|
-
error_message = f"❌
|
782
|
+
error_message = f"❌ Batch invoke other error: {str(e)}"
|
713
783
|
logger.error(error_message, exc_info=True,
|
714
|
-
|
715
|
-
|
716
|
-
|
717
|
-
|
718
|
-
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
raise e
|
724
|
-
|
725
|
-
async def close(self):
|
726
|
-
"""关闭 gRPC 通道"""
|
727
|
-
if self.channel and not self._closed:
|
728
|
-
await self.channel.close()
|
729
|
-
self._closed = True
|
730
|
-
logger.info("✅ gRPC channel closed",
|
731
|
-
extra={"log_type": "info", "data": {"status": "success"}})
|
732
|
-
|
733
|
-
def _safe_sync_close(self):
|
734
|
-
"""进程退出时自动关闭 channel(事件循环处理兼容)"""
|
735
|
-
if self.channel and not self._closed:
|
736
|
-
try:
|
737
|
-
loop = asyncio.get_event_loop()
|
738
|
-
if loop.is_running():
|
739
|
-
loop.create_task(self.close())
|
740
|
-
else:
|
741
|
-
loop.run_until_complete(self.close())
|
742
|
-
except Exception as e:
|
743
|
-
logger.warning(f"⚠️ gRPC channel close failed at exit: {e}",
|
744
|
-
extra={"log_type": "info", "data": {"status": "failed", "error": str(e)}})
|
745
|
-
|
746
|
-
async def __aenter__(self):
|
747
|
-
"""支持 async with 自动初始化连接"""
|
748
|
-
await self._ensure_initialized()
|
749
|
-
return self
|
750
|
-
|
751
|
-
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
752
|
-
"""支持 async with 自动关闭连接"""
|
753
|
-
await self.close()
|
784
|
+
extra={
|
785
|
+
"log_type": "response",
|
786
|
+
"uri": "/batch_invoke",
|
787
|
+
"duration": duration,
|
788
|
+
"data": {
|
789
|
+
"error_type": "other_error",
|
790
|
+
"batch_size": len(batch_request_model.items)
|
791
|
+
}
|
792
|
+
})
|
793
|
+
raise e
|