tamar-model-client 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/async_client.py +580 -540
- tamar_model_client/circuit_breaker.py +140 -0
- tamar_model_client/core/__init__.py +40 -0
- tamar_model_client/core/base_client.py +221 -0
- tamar_model_client/core/http_fallback.py +249 -0
- tamar_model_client/core/logging_setup.py +194 -0
- tamar_model_client/core/request_builder.py +221 -0
- tamar_model_client/core/response_handler.py +136 -0
- tamar_model_client/core/utils.py +171 -0
- tamar_model_client/error_handler.py +315 -0
- tamar_model_client/exceptions.py +388 -7
- tamar_model_client/json_formatter.py +36 -1
- tamar_model_client/sync_client.py +590 -486
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/METADATA +289 -61
- tamar_model_client-0.1.21.dist-info/RECORD +35 -0
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/top_level.txt +1 -0
- tests/__init__.py +1 -0
- tests/stream_hanging_analysis.py +357 -0
- tests/test_google_azure_final.py +448 -0
- tests/test_simple.py +235 -0
- tamar_model_client-0.1.19.dist-info/RECORD +0 -22
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/WHEEL +0 -0
@@ -1,215 +1,153 @@
|
|
1
|
-
|
1
|
+
"""
|
2
|
+
Tamar Model Client 同步客户端实现
|
3
|
+
|
4
|
+
本模块实现了同步的 gRPC 客户端,用于与 Model Manager Server 进行通信。
|
5
|
+
提供了与异步客户端相同的功能,但使用同步 API,适合在同步环境中使用。
|
6
|
+
|
7
|
+
主要功能:
|
8
|
+
- 同步 gRPC 通信
|
9
|
+
- JWT 认证
|
10
|
+
- 自动重试和错误处理
|
11
|
+
- 连接池管理
|
12
|
+
- 详细的日志记录
|
13
|
+
|
14
|
+
使用示例:
|
15
|
+
with TamarModelClient() as client:
|
16
|
+
request = ModelRequest(...)
|
17
|
+
response = client.invoke(request)
|
18
|
+
|
19
|
+
注意:对于需要高并发的场景,建议使用 AsyncTamarModelClient
|
20
|
+
"""
|
21
|
+
|
2
22
|
import json
|
3
23
|
import logging
|
4
|
-
import
|
24
|
+
import random
|
5
25
|
import time
|
6
|
-
import
|
7
|
-
import grpc
|
8
|
-
from typing import Optional, Union, Iterable, Iterator
|
9
|
-
from contextvars import ContextVar
|
26
|
+
from typing import Optional, Union, Iterator
|
10
27
|
|
11
|
-
|
12
|
-
from pydantic import BaseModel
|
28
|
+
import grpc
|
13
29
|
|
14
|
-
from .
|
15
|
-
|
16
|
-
|
30
|
+
from .core import (
|
31
|
+
generate_request_id,
|
32
|
+
set_request_id,
|
33
|
+
get_protected_logger,
|
34
|
+
MAX_MESSAGE_LENGTH
|
35
|
+
)
|
36
|
+
from .core.base_client import BaseClient
|
37
|
+
from .core.request_builder import RequestBuilder
|
38
|
+
from .core.response_handler import ResponseHandler
|
39
|
+
from .exceptions import ConnectionError, TamarModelException
|
17
40
|
from .generated import model_service_pb2, model_service_pb2_grpc
|
18
41
|
from .schemas import BatchModelResponse, ModelResponse
|
19
|
-
from .schemas.inputs import
|
20
|
-
|
21
|
-
from .json_formatter import JSONFormatter
|
22
|
-
|
23
|
-
logger = logging.getLogger(__name__)
|
24
|
-
|
25
|
-
_request_id: ContextVar[str] = ContextVar('request_id', default='-')
|
26
|
-
|
27
|
-
|
28
|
-
class RequestIdFilter(logging.Filter):
|
29
|
-
"""自定义日志过滤器,向日志中添加 request_id"""
|
30
|
-
|
31
|
-
def filter(self, record):
|
32
|
-
# 从 ContextVar 中获取当前的 request_id
|
33
|
-
record.request_id = _request_id.get()
|
34
|
-
return True
|
42
|
+
from .schemas.inputs import BatchModelRequest, ModelRequest
|
43
|
+
from .core.http_fallback import HttpFallbackMixin
|
35
44
|
|
45
|
+
# 配置日志记录器(使用受保护的logger)
|
46
|
+
logger = get_protected_logger(__name__)
|
36
47
|
|
37
|
-
if not logger.hasHandlers():
|
38
|
-
# 创建日志处理器,输出到控制台
|
39
|
-
console_handler = logging.StreamHandler()
|
40
48
|
|
41
|
-
|
42
|
-
formatter = JSONFormatter()
|
43
|
-
console_handler.setFormatter(formatter)
|
44
|
-
|
45
|
-
# 为当前记录器添加处理器
|
46
|
-
logger.addHandler(console_handler)
|
47
|
-
|
48
|
-
# 设置日志级别
|
49
|
-
logger.setLevel(logging.INFO)
|
50
|
-
|
51
|
-
# 将自定义的 RequestIdFilter 添加到 logger 中
|
52
|
-
logger.addFilter(RequestIdFilter())
|
53
|
-
|
54
|
-
MAX_MESSAGE_LENGTH = 2 ** 31 - 1 # 对于32位系统
|
55
|
-
|
56
|
-
|
57
|
-
def is_effective_value(value) -> bool:
|
49
|
+
class TamarModelClient(BaseClient, HttpFallbackMixin):
|
58
50
|
"""
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
"""递归处理单个值,处理BaseModel, dict, list, bytes"""
|
87
|
-
if not is_effective_value(value):
|
88
|
-
return None
|
89
|
-
if isinstance(value, BaseModel):
|
90
|
-
return serialize_value(value.model_dump())
|
91
|
-
if hasattr(value, "dict") and callable(value.dict):
|
92
|
-
return serialize_value(value.dict())
|
93
|
-
if isinstance(value, dict):
|
94
|
-
return {k: serialize_value(v) for k, v in value.items()}
|
95
|
-
if isinstance(value, list) or (isinstance(value, Iterable) and not isinstance(value, (str, bytes))):
|
96
|
-
return [serialize_value(v) for v in value]
|
97
|
-
if isinstance(value, bytes):
|
98
|
-
return f"bytes:{base64.b64encode(value).decode('utf-8')}"
|
99
|
-
return value
|
100
|
-
|
101
|
-
|
102
|
-
from typing import Any
|
103
|
-
|
104
|
-
|
105
|
-
def remove_none_from_dict(data: Any) -> Any:
|
106
|
-
"""
|
107
|
-
遍历 dict/list,递归删除 value 为 None 的字段
|
51
|
+
Tamar Model Client 同步客户端
|
52
|
+
|
53
|
+
提供与 Model Manager Server 的同步通信能力,支持:
|
54
|
+
- 单个和批量模型调用
|
55
|
+
- 流式和非流式响应
|
56
|
+
- 自动重试和错误恢复
|
57
|
+
- JWT 认证
|
58
|
+
- 连接池管理
|
59
|
+
|
60
|
+
使用示例:
|
61
|
+
# 基本用法
|
62
|
+
client = TamarModelClient()
|
63
|
+
client.connect()
|
64
|
+
|
65
|
+
request = ModelRequest(...)
|
66
|
+
response = client.invoke(request)
|
67
|
+
|
68
|
+
# 上下文管理器用法(推荐)
|
69
|
+
with TamarModelClient() as client:
|
70
|
+
response = client.invoke(request)
|
71
|
+
|
72
|
+
环境变量配置:
|
73
|
+
MODEL_MANAGER_SERVER_ADDRESS: gRPC 服务器地址
|
74
|
+
MODEL_MANAGER_SERVER_JWT_SECRET_KEY: JWT 密钥
|
75
|
+
MODEL_MANAGER_SERVER_GRPC_USE_TLS: 是否使用 TLS
|
76
|
+
MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES: 最大重试次数
|
77
|
+
MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY: 重试延迟
|
108
78
|
"""
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
def set_request_id(request_id: str):
|
129
|
-
"""设置当前请求的 request_id"""
|
130
|
-
_request_id.set(request_id)
|
131
|
-
|
132
|
-
|
133
|
-
class TamarModelClient:
|
134
|
-
def __init__(
|
135
|
-
self,
|
136
|
-
server_address: Optional[str] = None,
|
137
|
-
jwt_secret_key: Optional[str] = None,
|
138
|
-
jwt_token: Optional[str] = None,
|
139
|
-
default_payload: Optional[dict] = None,
|
140
|
-
token_expires_in: int = 3600,
|
141
|
-
max_retries: Optional[int] = None, # 最大重试次数
|
142
|
-
retry_delay: Optional[float] = None, # 初始重试延迟(秒)
|
143
|
-
):
|
144
|
-
self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
|
145
|
-
if not self.server_address:
|
146
|
-
raise ValueError("Server address must be provided via argument or environment variable.")
|
147
|
-
self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))
|
148
|
-
|
149
|
-
# JWT 配置
|
150
|
-
self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
|
151
|
-
self.jwt_handler = JWTAuthHandler(self.jwt_secret_key)
|
152
|
-
self.jwt_token = jwt_token # 用户传入的 Token(可选)
|
153
|
-
self.default_payload = default_payload
|
154
|
-
self.token_expires_in = token_expires_in
|
155
|
-
|
156
|
-
# === TLS/Authority 配置 ===
|
157
|
-
self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
|
158
|
-
self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")
|
159
|
-
|
160
|
-
# === 重试配置 ===
|
161
|
-
self.max_retries = max_retries if max_retries is not None else int(
|
162
|
-
os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
|
163
|
-
self.retry_delay = retry_delay if retry_delay is not None else float(
|
164
|
-
os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))
|
165
|
-
|
166
|
-
# === gRPC 通道相关 ===
|
79
|
+
|
80
|
+
def __init__(self, **kwargs):
|
81
|
+
"""
|
82
|
+
初始化同步客户端
|
83
|
+
|
84
|
+
参数继承自 BaseClient,包括:
|
85
|
+
- server_address: gRPC 服务器地址
|
86
|
+
- jwt_secret_key: JWT 签名密钥
|
87
|
+
- jwt_token: 预生成的 JWT 令牌
|
88
|
+
- default_payload: JWT 令牌的默认载荷
|
89
|
+
- token_expires_in: JWT 令牌过期时间
|
90
|
+
- max_retries: 最大重试次数
|
91
|
+
- retry_delay: 初始重试延迟
|
92
|
+
"""
|
93
|
+
super().__init__(logger_name=__name__, **kwargs)
|
94
|
+
|
95
|
+
# === gRPC 通道和连接管理 ===
|
167
96
|
self.channel: Optional[grpc.Channel] = None
|
168
97
|
self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
|
169
|
-
self._closed = False
|
170
98
|
|
171
|
-
def
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
else:
|
189
|
-
logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True,
|
190
|
-
extra={"log_type": "info", "data": {"error_code": str(e.code()) if hasattr(e, 'code') else None, "retryable": False}})
|
191
|
-
raise
|
99
|
+
def close(self):
|
100
|
+
"""
|
101
|
+
关闭客户端连接
|
102
|
+
|
103
|
+
优雅地关闭 gRPC 通道并清理资源。
|
104
|
+
建议在程序结束前调用此方法,或使用上下文管理器自动管理。
|
105
|
+
"""
|
106
|
+
if self.channel and not self._closed:
|
107
|
+
self.channel.close()
|
108
|
+
self._closed = True
|
109
|
+
logger.info("🔒 gRPC channel closed",
|
110
|
+
extra={"log_type": "info", "data": {"status": "closed"}})
|
111
|
+
|
112
|
+
def __enter__(self):
|
113
|
+
"""上下文管理器入口"""
|
114
|
+
self.connect()
|
115
|
+
return self
|
192
116
|
|
193
|
-
def
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
117
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
118
|
+
"""上下文管理器出口"""
|
119
|
+
self.close()
|
120
|
+
|
121
|
+
def connect(self):
|
122
|
+
"""
|
123
|
+
显式连接到服务器
|
124
|
+
|
125
|
+
建立与 gRPC 服务器的连接。通常不需要手动调用,
|
126
|
+
因为 invoke 方法会自动确保连接已建立。
|
127
|
+
"""
|
128
|
+
self._ensure_initialized()
|
199
129
|
|
200
130
|
def _ensure_initialized(self):
|
201
|
-
"""
|
131
|
+
"""
|
132
|
+
初始化gRPC通道
|
133
|
+
|
134
|
+
确保gRPC通道和存根已正确初始化。如果初始化失败,
|
135
|
+
会进行重试,支持TLS配置和完整的keepalive选项。
|
136
|
+
|
137
|
+
连接配置包括:
|
138
|
+
- 消息大小限制
|
139
|
+
- Keepalive设置(30秒ping间隔,10秒超时)
|
140
|
+
- 连接生命周期管理(1小时最大连接时间)
|
141
|
+
- 性能优化选项(带宽探测、内置重试)
|
142
|
+
|
143
|
+
Raises:
|
144
|
+
ConnectionError: 当达到最大重试次数仍无法连接时
|
145
|
+
"""
|
202
146
|
if self.channel and self.stub:
|
203
147
|
return
|
204
148
|
|
205
149
|
retry_count = 0
|
206
|
-
options =
|
207
|
-
('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
|
208
|
-
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
|
209
|
-
('grpc.keepalive_permit_without_calls', True) # 即使没有活跃请求也保持连接
|
210
|
-
]
|
211
|
-
if self.default_authority:
|
212
|
-
options.append(("grpc.default_authority", self.default_authority))
|
150
|
+
options = self.build_channel_options()
|
213
151
|
|
214
152
|
while retry_count <= self.max_retries:
|
215
153
|
try:
|
@@ -221,61 +159,303 @@ class TamarModelClient:
|
|
221
159
|
options=options
|
222
160
|
)
|
223
161
|
logger.info("🔐 Using secure gRPC channel (TLS enabled)",
|
224
|
-
|
162
|
+
extra={"log_type": "info",
|
163
|
+
"data": {"tls_enabled": True, "server_address": self.server_address}})
|
225
164
|
else:
|
226
165
|
self.channel = grpc.insecure_channel(
|
227
166
|
self.server_address,
|
228
167
|
options=options
|
229
168
|
)
|
230
169
|
logger.info("🔓 Using insecure gRPC channel (TLS disabled)",
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
170
|
+
extra={"log_type": "info",
|
171
|
+
"data": {"tls_enabled": False, "server_address": self.server_address}})
|
172
|
+
|
173
|
+
# 等待通道就绪
|
174
|
+
grpc.channel_ready_future(self.channel).result(timeout=10)
|
236
175
|
self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
|
237
176
|
logger.info(f"✅ gRPC channel initialized to {self.server_address}",
|
238
|
-
|
177
|
+
extra={"log_type": "info",
|
178
|
+
"data": {"status": "success", "server_address": self.server_address}})
|
239
179
|
return
|
180
|
+
|
240
181
|
except grpc.FutureTimeoutError as e:
|
241
182
|
logger.error(f"❌ gRPC channel initialization timed out: {str(e)}", exc_info=True,
|
242
|
-
|
183
|
+
extra={"log_type": "info",
|
184
|
+
"data": {"error_type": "timeout", "server_address": self.server_address}})
|
243
185
|
except grpc.RpcError as e:
|
244
186
|
logger.error(f"❌ gRPC channel initialization failed: {str(e)}", exc_info=True,
|
245
|
-
|
187
|
+
extra={"log_type": "info",
|
188
|
+
"data": {"error_type": "grpc_error", "server_address": self.server_address}})
|
246
189
|
except Exception as e:
|
247
|
-
logger.error(f"❌ Unexpected error during channel initialization: {str(e)}", exc_info=True,
|
248
|
-
|
249
|
-
|
190
|
+
logger.error(f"❌ Unexpected error during gRPC channel initialization: {str(e)}", exc_info=True,
|
191
|
+
extra={"log_type": "info",
|
192
|
+
"data": {"error_type": "unknown", "server_address": self.server_address}})
|
193
|
+
|
250
194
|
retry_count += 1
|
251
|
-
if retry_count
|
252
|
-
|
253
|
-
|
254
|
-
|
195
|
+
if retry_count <= self.max_retries:
|
196
|
+
time.sleep(self.retry_delay * retry_count)
|
197
|
+
|
198
|
+
raise ConnectionError(f"Failed to connect to {self.server_address} after {self.max_retries} retries")
|
199
|
+
|
200
|
+
def _retry_request(self, func, *args, **kwargs):
|
201
|
+
"""
|
202
|
+
使用增强的错误处理器进行重试(同步版本)
|
203
|
+
"""
|
204
|
+
# 构建请求上下文
|
205
|
+
context = {
|
206
|
+
'method': func.__name__ if hasattr(func, '__name__') else 'unknown',
|
207
|
+
'client_version': 'sync',
|
208
|
+
}
|
209
|
+
|
210
|
+
last_exception = None
|
211
|
+
|
212
|
+
for attempt in range(self.max_retries + 1):
|
213
|
+
try:
|
214
|
+
context['retry_count'] = attempt
|
215
|
+
return func(*args, **kwargs)
|
216
|
+
|
217
|
+
except grpc.RpcError as e:
|
218
|
+
# 使用新的错误处理逻辑
|
219
|
+
context['retry_count'] = attempt
|
220
|
+
|
221
|
+
# 判断是否可以重试
|
222
|
+
should_retry = self._should_retry(e, attempt)
|
223
|
+
if not should_retry or attempt >= self.max_retries:
|
224
|
+
# 不可重试或已达到最大重试次数
|
225
|
+
last_exception = self.error_handler.handle_error(e, context)
|
226
|
+
break
|
227
|
+
|
228
|
+
# 记录重试日志
|
229
|
+
log_data = {
|
230
|
+
"log_type": "info",
|
231
|
+
"request_id": context.get('request_id'),
|
232
|
+
"data": {
|
233
|
+
"error_code": e.code().name if e.code() else 'UNKNOWN',
|
234
|
+
"retry_count": attempt,
|
235
|
+
"max_retries": self.max_retries,
|
236
|
+
"method": context.get('method', 'unknown')
|
237
|
+
}
|
238
|
+
}
|
239
|
+
logger.warning(
|
240
|
+
f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
|
241
|
+
extra=log_data
|
242
|
+
)
|
243
|
+
|
244
|
+
# 执行退避等待
|
245
|
+
if attempt < self.max_retries:
|
246
|
+
delay = self._calculate_backoff(attempt, e.code())
|
247
|
+
time.sleep(delay)
|
248
|
+
|
249
|
+
last_exception = self.error_handler.handle_error(e, context)
|
250
|
+
|
251
|
+
except Exception as e:
|
252
|
+
# 非 gRPC 错误,直接包装抛出
|
253
|
+
context['retry_count'] = attempt
|
254
|
+
last_exception = TamarModelException(str(e))
|
255
|
+
break
|
256
|
+
|
257
|
+
# 抛出最后的异常
|
258
|
+
if last_exception:
|
259
|
+
raise last_exception
|
260
|
+
else:
|
261
|
+
raise TamarModelException("Unknown error occurred")
|
262
|
+
|
263
|
+
def _calculate_backoff(self, attempt: int, error_code: grpc.StatusCode = None) -> float:
|
264
|
+
"""
|
265
|
+
计算退避时间,支持不同的退避策略
|
266
|
+
|
267
|
+
Args:
|
268
|
+
attempt: 当前重试次数
|
269
|
+
error_code: gRPC错误码,用于确定退避策略
|
270
|
+
"""
|
271
|
+
max_delay = 60.0
|
272
|
+
base_delay = self.retry_delay
|
273
|
+
|
274
|
+
# 获取错误的重试策略
|
275
|
+
if error_code:
|
276
|
+
from .exceptions import get_retry_policy
|
277
|
+
policy = get_retry_policy(error_code)
|
278
|
+
backoff_type = policy.get('backoff', 'exponential')
|
279
|
+
use_jitter = policy.get('jitter', False)
|
280
|
+
else:
|
281
|
+
backoff_type = 'exponential'
|
282
|
+
use_jitter = False
|
283
|
+
|
284
|
+
# 根据退避类型计算延迟
|
285
|
+
if backoff_type == 'linear':
|
286
|
+
# 线性退避:delay * (attempt + 1)
|
287
|
+
delay = min(base_delay * (attempt + 1), max_delay)
|
288
|
+
else:
|
289
|
+
# 指数退避:delay * 2^attempt
|
290
|
+
delay = min(base_delay * (2 ** attempt), max_delay)
|
291
|
+
|
292
|
+
# 添加抖动
|
293
|
+
if use_jitter:
|
294
|
+
jitter_factor = 0.2 # 增加抖动范围,减少竞争
|
295
|
+
jitter = random.uniform(0, delay * jitter_factor)
|
296
|
+
delay += jitter
|
297
|
+
else:
|
298
|
+
# 默认的小量抖动,避免完全同步
|
299
|
+
jitter_factor = 0.05
|
300
|
+
jitter = random.uniform(0, delay * jitter_factor)
|
301
|
+
delay += jitter
|
302
|
+
|
303
|
+
return delay
|
255
304
|
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
305
|
+
def _retry_request_stream(self, func, *args, **kwargs):
|
306
|
+
"""
|
307
|
+
流式请求的重试逻辑(同步版本)
|
308
|
+
|
309
|
+
对于流式响应,需要特殊的重试处理,因为流不能简单地重新执行。
|
310
|
+
|
311
|
+
Args:
|
312
|
+
func: 生成流的函数
|
313
|
+
*args: 函数参数
|
314
|
+
**kwargs: 函数关键字参数
|
315
|
+
|
316
|
+
Yields:
|
317
|
+
流式响应的每个元素
|
318
|
+
"""
|
319
|
+
last_exception = None
|
320
|
+
context = {
|
321
|
+
'method': 'stream',
|
322
|
+
'client_version': 'sync',
|
323
|
+
}
|
324
|
+
|
325
|
+
for attempt in range(self.max_retries + 1):
|
326
|
+
try:
|
327
|
+
context['retry_count'] = attempt
|
328
|
+
# 尝试创建流
|
329
|
+
for item in func(*args, **kwargs):
|
330
|
+
yield item
|
331
|
+
return
|
332
|
+
|
333
|
+
except grpc.RpcError as e:
|
334
|
+
# 使用智能重试判断
|
335
|
+
context['retry_count'] = attempt
|
336
|
+
|
337
|
+
# 判断是否应该重试
|
338
|
+
should_retry = self._should_retry(e, attempt)
|
339
|
+
if not should_retry or attempt >= self.max_retries:
|
340
|
+
# 不重试或已达到最大重试次数
|
341
|
+
log_data = {
|
342
|
+
"log_type": "info",
|
343
|
+
"request_id": context.get('request_id'),
|
344
|
+
"data": {
|
345
|
+
"error_code": e.code().name if e.code() else 'UNKNOWN',
|
346
|
+
"retry_count": attempt,
|
347
|
+
"max_retries": self.max_retries,
|
348
|
+
"method": "stream",
|
349
|
+
"will_retry": False
|
350
|
+
}
|
351
|
+
}
|
352
|
+
logger.error(
|
353
|
+
f"Stream failed: {e.code()} (no retry)",
|
354
|
+
extra=log_data
|
355
|
+
)
|
356
|
+
last_exception = self.error_handler.handle_error(e, context)
|
357
|
+
break
|
358
|
+
|
359
|
+
# 记录重试日志
|
360
|
+
log_data = {
|
361
|
+
"log_type": "info",
|
362
|
+
"request_id": context.get('request_id'),
|
363
|
+
"data": {
|
364
|
+
"error_code": e.code().name if e.code() else 'UNKNOWN',
|
365
|
+
"retry_count": attempt,
|
366
|
+
"max_retries": self.max_retries,
|
367
|
+
"method": "stream"
|
368
|
+
}
|
369
|
+
}
|
370
|
+
logger.warning(
|
371
|
+
f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()} (will retry)",
|
372
|
+
extra=log_data
|
373
|
+
)
|
374
|
+
|
375
|
+
# 执行退避等待
|
376
|
+
if attempt < self.max_retries:
|
377
|
+
delay = self._calculate_backoff(attempt, e.code())
|
378
|
+
time.sleep(delay)
|
379
|
+
|
380
|
+
last_exception = e
|
381
|
+
|
382
|
+
except Exception as e:
|
383
|
+
context['retry_count'] = attempt
|
384
|
+
raise TamarModelException(str(e)) from e
|
385
|
+
|
386
|
+
if last_exception:
|
387
|
+
if isinstance(last_exception, TamarModelException):
|
388
|
+
raise last_exception
|
389
|
+
else:
|
390
|
+
raise self.error_handler.handle_error(last_exception, context)
|
391
|
+
else:
|
392
|
+
raise TamarModelException("Unknown streaming error occurred")
|
261
393
|
|
262
394
|
def _stream(self, request, metadata, invoke_timeout) -> Iterator[ModelResponse]:
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
395
|
+
"""
|
396
|
+
处理流式响应
|
397
|
+
|
398
|
+
Args:
|
399
|
+
request: gRPC 请求对象
|
400
|
+
metadata: 请求元数据
|
401
|
+
invoke_timeout: 总体超时时间
|
402
|
+
|
403
|
+
Yields:
|
404
|
+
ModelResponse: 流式响应的每个数据块
|
405
|
+
|
406
|
+
Raises:
|
407
|
+
TimeoutError: 当等待下一个数据块超时时
|
408
|
+
"""
|
409
|
+
import threading
|
410
|
+
import queue
|
411
|
+
|
412
|
+
# 创建队列用于线程间通信
|
413
|
+
response_queue = queue.Queue()
|
414
|
+
exception_queue = queue.Queue()
|
415
|
+
|
416
|
+
def fetch_responses():
|
417
|
+
"""在单独线程中获取流式响应"""
|
418
|
+
try:
|
419
|
+
for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
|
420
|
+
response_queue.put(response)
|
421
|
+
response_queue.put(None) # 标记流结束
|
422
|
+
except Exception as e:
|
423
|
+
exception_queue.put(e)
|
424
|
+
response_queue.put(None)
|
425
|
+
|
426
|
+
# 启动响应获取线程
|
427
|
+
fetch_thread = threading.Thread(target=fetch_responses)
|
428
|
+
fetch_thread.daemon = True
|
429
|
+
fetch_thread.start()
|
430
|
+
|
431
|
+
chunk_timeout = 30.0 # 单个数据块的超时时间
|
432
|
+
|
433
|
+
while True:
|
434
|
+
# 检查是否有异常
|
435
|
+
if not exception_queue.empty():
|
436
|
+
raise exception_queue.get()
|
437
|
+
|
438
|
+
try:
|
439
|
+
# 等待下一个响应,带超时
|
440
|
+
response = response_queue.get(timeout=chunk_timeout)
|
441
|
+
|
442
|
+
if response is None:
|
443
|
+
# 流结束
|
444
|
+
break
|
445
|
+
|
446
|
+
yield ResponseHandler.build_model_response(response)
|
447
|
+
|
448
|
+
except queue.Empty:
|
449
|
+
raise TimeoutError(f"流式响应在等待下一个数据块时超时 ({chunk_timeout}s)")
|
450
|
+
|
451
|
+
def _stream_with_logging(self, request, metadata, invoke_timeout, start_time, model_request) -> Iterator[
|
452
|
+
ModelResponse]:
|
453
|
+
"""流式响应的包装器,用于记录完整的响应日志并处理重试"""
|
274
454
|
total_content = ""
|
275
455
|
final_usage = None
|
276
456
|
error_occurred = None
|
277
457
|
chunk_count = 0
|
278
|
-
|
458
|
+
|
279
459
|
try:
|
280
460
|
for response in self._stream(request, metadata, invoke_timeout):
|
281
461
|
chunk_count += 1
|
@@ -286,26 +466,46 @@ class TamarModelClient:
|
|
286
466
|
if response.error:
|
287
467
|
error_occurred = response.error
|
288
468
|
yield response
|
289
|
-
|
290
|
-
#
|
469
|
+
|
470
|
+
# 流式响应完成,记录日志
|
291
471
|
duration = time.time() - start_time
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
"
|
296
|
-
|
297
|
-
|
298
|
-
|
299
|
-
"
|
300
|
-
"
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
472
|
+
if error_occurred:
|
473
|
+
# 流式响应中包含错误
|
474
|
+
logger.warning(
|
475
|
+
f"⚠️ Stream completed with errors | chunks: {chunk_count}",
|
476
|
+
extra={
|
477
|
+
"log_type": "response",
|
478
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
479
|
+
"duration": duration,
|
480
|
+
"data": ResponseHandler.build_log_data(
|
481
|
+
model_request,
|
482
|
+
stream_stats={
|
483
|
+
"chunks_count": chunk_count,
|
484
|
+
"total_length": len(total_content),
|
485
|
+
"usage": final_usage,
|
486
|
+
"error": error_occurred
|
487
|
+
}
|
488
|
+
)
|
306
489
|
}
|
307
|
-
|
308
|
-
|
490
|
+
)
|
491
|
+
else:
|
492
|
+
# 流式响应成功完成
|
493
|
+
logger.info(
|
494
|
+
f"✅ Stream completed successfully | chunks: {chunk_count}",
|
495
|
+
extra={
|
496
|
+
"log_type": "response",
|
497
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
498
|
+
"duration": duration,
|
499
|
+
"data": ResponseHandler.build_log_data(
|
500
|
+
model_request,
|
501
|
+
stream_stats={
|
502
|
+
"chunks_count": chunk_count,
|
503
|
+
"total_length": len(total_content),
|
504
|
+
"usage": final_usage
|
505
|
+
}
|
506
|
+
)
|
507
|
+
}
|
508
|
+
)
|
309
509
|
except Exception as e:
|
310
510
|
# 流式响应出错,记录错误日志
|
311
511
|
duration = time.time() - start_time
|
@@ -316,29 +516,23 @@ class TamarModelClient:
|
|
316
516
|
"log_type": "response",
|
317
517
|
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
318
518
|
"duration": duration,
|
319
|
-
"data":
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
}
|
519
|
+
"data": ResponseHandler.build_log_data(
|
520
|
+
model_request,
|
521
|
+
error=e,
|
522
|
+
stream_stats={
|
523
|
+
"chunks_count": chunk_count,
|
524
|
+
"partial_content_length": len(total_content)
|
525
|
+
}
|
526
|
+
)
|
328
527
|
}
|
329
528
|
)
|
330
529
|
raise
|
331
530
|
|
332
531
|
def _invoke_request(self, request, metadata, invoke_timeout):
|
532
|
+
"""执行单个非流式请求"""
|
333
533
|
response = self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout)
|
334
534
|
for response in response:
|
335
|
-
return
|
336
|
-
content=response.content,
|
337
|
-
usage=json.loads(response.usage) if response.usage else None,
|
338
|
-
error=response.error or None,
|
339
|
-
raw_response=json.loads(response.raw_response) if response.raw_response else None,
|
340
|
-
request_id=response.request_id if response.request_id else None,
|
341
|
-
)
|
535
|
+
return ResponseHandler.build_model_response(response)
|
342
536
|
|
343
537
|
def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None, request_id: Optional[str] = None) -> \
|
344
538
|
Union[ModelResponse, Iterator[ModelResponse]]:
|
@@ -356,6 +550,12 @@ class TamarModelClient:
|
|
356
550
|
ValidationError: 输入验证失败。
|
357
551
|
ConnectionError: 连接服务端失败。
|
358
552
|
"""
|
553
|
+
# 如果启用了熔断且熔断器打开,直接走 HTTP
|
554
|
+
if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
|
555
|
+
if self.http_fallback_url:
|
556
|
+
logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback")
|
557
|
+
return self._invoke_http_fallback(model_request, timeout, request_id)
|
558
|
+
|
359
559
|
self._ensure_initialized()
|
360
560
|
|
361
561
|
if not self.default_payload:
|
@@ -365,9 +565,9 @@ class TamarModelClient:
|
|
365
565
|
}
|
366
566
|
|
367
567
|
if not request_id:
|
368
|
-
request_id = generate_request_id()
|
369
|
-
set_request_id(request_id)
|
370
|
-
metadata = self._build_auth_metadata(request_id)
|
568
|
+
request_id = generate_request_id()
|
569
|
+
set_request_id(request_id)
|
570
|
+
metadata = self._build_auth_metadata(request_id)
|
371
571
|
|
372
572
|
# 记录开始日志
|
373
573
|
start_time = time.time()
|
@@ -376,129 +576,102 @@ class TamarModelClient:
|
|
376
576
|
extra={
|
377
577
|
"log_type": "request",
|
378
578
|
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
379
|
-
"data":
|
380
|
-
"provider": model_request.provider.value,
|
381
|
-
"invoke_type": model_request.invoke_type.value,
|
382
|
-
"model": model_request.model,
|
383
|
-
"stream": model_request.stream,
|
384
|
-
"org_id": model_request.user_context.org_id,
|
385
|
-
"user_id": model_request.user_context.user_id,
|
386
|
-
"client_type": model_request.user_context.client_type
|
387
|
-
}
|
579
|
+
"data": ResponseHandler.build_log_data(model_request)
|
388
580
|
})
|
389
581
|
|
390
|
-
# 动态根据 provider/invoke_type 决定使用哪个 input 字段
|
391
582
|
try:
|
392
|
-
#
|
393
|
-
|
394
|
-
|
395
|
-
case (ProviderType.GOOGLE, InvokeType.GENERATION):
|
396
|
-
allowed_fields = GoogleGenAiInput.model_fields.keys()
|
397
|
-
case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
|
398
|
-
allowed_fields = GoogleVertexAIImagesInput.model_fields.keys()
|
399
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
|
400
|
-
allowed_fields = OpenAIResponsesInput.model_fields.keys()
|
401
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
|
402
|
-
allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
|
403
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
|
404
|
-
allowed_fields = OpenAIImagesInput.model_fields.keys()
|
405
|
-
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
|
406
|
-
allowed_fields = OpenAIImagesEditInput.model_fields.keys()
|
407
|
-
case _:
|
408
|
-
raise ValueError(
|
409
|
-
f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")
|
410
|
-
|
411
|
-
# 将 ModelRequest 转 dict,过滤只保留 base + allowed 的字段
|
412
|
-
model_request_dict = model_request.model_dump(exclude_unset=True)
|
413
|
-
|
414
|
-
grpc_request_kwargs = {}
|
415
|
-
for field in allowed_fields:
|
416
|
-
if field in model_request_dict:
|
417
|
-
value = model_request_dict[field]
|
418
|
-
|
419
|
-
# 跳过无效的值
|
420
|
-
if not is_effective_value(value):
|
421
|
-
continue
|
422
|
-
|
423
|
-
# 序列化grpc不支持的类型
|
424
|
-
grpc_request_kwargs[field] = serialize_value(value)
|
425
|
-
|
426
|
-
# 清理 serialize后的 grpc_request_kwargs
|
427
|
-
grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
|
428
|
-
|
429
|
-
request = model_service_pb2.ModelRequestItem(
|
430
|
-
provider=model_request.provider.value,
|
431
|
-
channel=model_request.channel.value,
|
432
|
-
invoke_type=model_request.invoke_type.value,
|
433
|
-
stream=model_request.stream or False,
|
434
|
-
org_id=model_request.user_context.org_id or "",
|
435
|
-
user_id=model_request.user_context.user_id or "",
|
436
|
-
client_type=model_request.user_context.client_type or "",
|
437
|
-
extra=grpc_request_kwargs
|
438
|
-
)
|
439
|
-
|
583
|
+
# 构建 gRPC 请求
|
584
|
+
request = RequestBuilder.build_single_request(model_request)
|
585
|
+
|
440
586
|
except Exception as e:
|
587
|
+
duration = time.time() - start_time
|
588
|
+
logger.error(
|
589
|
+
f"❌ Request build failed: {str(e)}",
|
590
|
+
exc_info=True,
|
591
|
+
extra={
|
592
|
+
"log_type": "response",
|
593
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
594
|
+
"duration": duration,
|
595
|
+
"data": {
|
596
|
+
"provider": model_request.provider.value,
|
597
|
+
"invoke_type": model_request.invoke_type.value,
|
598
|
+
"model": getattr(model_request, 'model', None),
|
599
|
+
"error_type": "build_error",
|
600
|
+
"error_message": str(e)
|
601
|
+
}
|
602
|
+
}
|
603
|
+
)
|
441
604
|
raise ValueError(f"构建请求失败: {str(e)}") from e
|
442
605
|
|
443
606
|
try:
|
444
607
|
invoke_timeout = timeout or self.default_invoke_timeout
|
445
608
|
if model_request.stream:
|
446
|
-
#
|
447
|
-
return self.
|
609
|
+
# 对于流式响应,使用重试包装器
|
610
|
+
return self._retry_request_stream(
|
611
|
+
self._stream_with_logging,
|
612
|
+
request, metadata, invoke_timeout, start_time, model_request
|
613
|
+
)
|
448
614
|
else:
|
449
615
|
result = self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
|
450
|
-
|
616
|
+
|
451
617
|
# 记录非流式响应的成功日志
|
452
618
|
duration = time.time() - start_time
|
619
|
+
content_length = len(result.content) if result.content else 0
|
453
620
|
logger.info(
|
454
|
-
f"✅ Request completed
|
621
|
+
f"✅ Request completed | content_length: {content_length}",
|
455
622
|
extra={
|
456
623
|
"log_type": "response",
|
457
624
|
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
458
625
|
"duration": duration,
|
459
|
-
"data":
|
460
|
-
"provider": model_request.provider.value,
|
461
|
-
"invoke_type": model_request.invoke_type.value,
|
462
|
-
"model": model_request.model,
|
463
|
-
"stream": False,
|
464
|
-
"content_length": len(result.content) if result.content else 0,
|
465
|
-
"usage": result.usage
|
466
|
-
}
|
626
|
+
"data": ResponseHandler.build_log_data(model_request, result)
|
467
627
|
}
|
468
628
|
)
|
629
|
+
|
630
|
+
# 记录成功(如果启用了熔断)
|
631
|
+
if self.resilient_enabled and self.circuit_breaker:
|
632
|
+
self.circuit_breaker.record_success()
|
633
|
+
|
469
634
|
return result
|
470
|
-
|
635
|
+
|
636
|
+
except (ConnectionError, grpc.RpcError) as e:
|
471
637
|
duration = time.time() - start_time
|
472
638
|
error_message = f"❌ Invoke gRPC failed: {str(e)}"
|
473
639
|
logger.error(error_message, exc_info=True,
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
481
|
-
|
482
|
-
|
483
|
-
|
484
|
-
|
485
|
-
|
640
|
+
extra={
|
641
|
+
"log_type": "response",
|
642
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
643
|
+
"duration": duration,
|
644
|
+
"data": ResponseHandler.build_log_data(
|
645
|
+
model_request,
|
646
|
+
error=e
|
647
|
+
)
|
648
|
+
})
|
649
|
+
|
650
|
+
# 记录失败并尝试降级(如果启用了熔断)
|
651
|
+
if self.resilient_enabled and self.circuit_breaker:
|
652
|
+
# 将错误码传递给熔断器,用于智能失败统计
|
653
|
+
error_code = e.code() if hasattr(e, 'code') else None
|
654
|
+
self.circuit_breaker.record_failure(error_code)
|
655
|
+
|
656
|
+
# 如果可以降级,则降级
|
657
|
+
if self.http_fallback_url and self.circuit_breaker.should_fallback():
|
658
|
+
logger.warning(f"🔻 gRPC failed, falling back to HTTP: {str(e)}")
|
659
|
+
return self._invoke_http_fallback(model_request, timeout, request_id)
|
660
|
+
|
486
661
|
raise e
|
487
662
|
except Exception as e:
|
488
663
|
duration = time.time() - start_time
|
489
664
|
error_message = f"❌ Invoke other error: {str(e)}"
|
490
665
|
logger.error(error_message, exc_info=True,
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
}
|
501
|
-
})
|
666
|
+
extra={
|
667
|
+
"log_type": "response",
|
668
|
+
"uri": f"/invoke/{model_request.provider.value}/{model_request.invoke_type.value}",
|
669
|
+
"duration": duration,
|
670
|
+
"data": ResponseHandler.build_log_data(
|
671
|
+
model_request,
|
672
|
+
error=e
|
673
|
+
)
|
674
|
+
})
|
502
675
|
raise e
|
503
676
|
|
504
677
|
def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None,
|
@@ -513,7 +686,6 @@ class TamarModelClient:
|
|
513
686
|
Returns:
|
514
687
|
BatchModelResponse: 批量请求的结果
|
515
688
|
"""
|
516
|
-
|
517
689
|
self._ensure_initialized()
|
518
690
|
|
519
691
|
if not self.default_payload:
|
@@ -523,9 +695,9 @@ class TamarModelClient:
|
|
523
695
|
}
|
524
696
|
|
525
697
|
if not request_id:
|
526
|
-
request_id = generate_request_id()
|
527
|
-
set_request_id(request_id)
|
528
|
-
metadata = self._build_auth_metadata(request_id)
|
698
|
+
request_id = generate_request_id()
|
699
|
+
set_request_id(request_id)
|
700
|
+
metadata = self._build_auth_metadata(request_id)
|
529
701
|
|
530
702
|
# 记录开始日志
|
531
703
|
start_time = time.time()
|
@@ -542,151 +714,83 @@ class TamarModelClient:
|
|
542
714
|
}
|
543
715
|
})
|
544
716
|
|
545
|
-
|
546
|
-
|
547
|
-
|
548
|
-
|
549
|
-
|
550
|
-
|
551
|
-
|
552
|
-
|
553
|
-
|
554
|
-
|
555
|
-
|
556
|
-
|
557
|
-
|
558
|
-
|
559
|
-
|
560
|
-
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
# 将 ModelRequest 转 dict,过滤只保留 base + allowed 的字段
|
568
|
-
model_request_dict = model_request_item.model_dump(exclude_unset=True)
|
569
|
-
|
570
|
-
grpc_request_kwargs = {}
|
571
|
-
for field in allowed_fields:
|
572
|
-
if field in model_request_dict:
|
573
|
-
value = model_request_dict[field]
|
574
|
-
|
575
|
-
# 跳过无效的值
|
576
|
-
if not is_effective_value(value):
|
577
|
-
continue
|
578
|
-
|
579
|
-
# 序列化grpc不支持的类型
|
580
|
-
grpc_request_kwargs[field] = serialize_value(value)
|
581
|
-
|
582
|
-
# 清理 serialize后的 grpc_request_kwargs
|
583
|
-
grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
|
584
|
-
|
585
|
-
items.append(model_service_pb2.ModelRequestItem(
|
586
|
-
provider=model_request_item.provider.value,
|
587
|
-
channel=model_request_item.channel.value,
|
588
|
-
invoke_type=model_request_item.invoke_type.value,
|
589
|
-
stream=model_request_item.stream or False,
|
590
|
-
custom_id=model_request_item.custom_id or "",
|
591
|
-
priority=model_request_item.priority or 1,
|
592
|
-
org_id=batch_request_model.user_context.org_id or "",
|
593
|
-
user_id=batch_request_model.user_context.user_id or "",
|
594
|
-
client_type=batch_request_model.user_context.client_type or "",
|
595
|
-
extra=grpc_request_kwargs,
|
596
|
-
))
|
597
|
-
|
598
|
-
except Exception as e:
|
599
|
-
raise ValueError(f"构建请求失败: {str(e)},item={model_request_item.custom_id}") from e
|
717
|
+
try:
|
718
|
+
# 构建批量请求
|
719
|
+
batch_request = RequestBuilder.build_batch_request(batch_request_model)
|
720
|
+
|
721
|
+
except Exception as e:
|
722
|
+
duration = time.time() - start_time
|
723
|
+
logger.error(
|
724
|
+
f"❌ Batch request build failed: {str(e)}",
|
725
|
+
exc_info=True,
|
726
|
+
extra={
|
727
|
+
"log_type": "response",
|
728
|
+
"uri": "/batch_invoke",
|
729
|
+
"duration": duration,
|
730
|
+
"data": {
|
731
|
+
"batch_size": len(batch_request_model.items),
|
732
|
+
"error_type": "build_error",
|
733
|
+
"error_message": str(e)
|
734
|
+
}
|
735
|
+
}
|
736
|
+
)
|
737
|
+
raise ValueError(f"构建批量请求失败: {str(e)}") from e
|
600
738
|
|
601
739
|
try:
|
602
|
-
# 超时处理逻辑
|
603
740
|
invoke_timeout = timeout or self.default_invoke_timeout
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
result = []
|
610
|
-
for res_item in response.items:
|
611
|
-
result.append(ModelResponse(
|
612
|
-
content=res_item.content,
|
613
|
-
usage=json.loads(res_item.usage) if res_item.usage else None,
|
614
|
-
raw_response=json.loads(res_item.raw_response) if res_item.raw_response else None,
|
615
|
-
error=res_item.error or None,
|
616
|
-
custom_id=res_item.custom_id if res_item.custom_id else None
|
617
|
-
))
|
618
|
-
batch_response = BatchModelResponse(
|
619
|
-
request_id=response.request_id if response.request_id else None,
|
620
|
-
responses=result
|
741
|
+
batch_response = self._retry_request(
|
742
|
+
self.stub.BatchInvoke,
|
743
|
+
batch_request,
|
744
|
+
metadata=metadata,
|
745
|
+
timeout=invoke_timeout
|
621
746
|
)
|
622
|
-
|
747
|
+
|
748
|
+
# 构建响应对象
|
749
|
+
result = ResponseHandler.build_batch_response(batch_response)
|
750
|
+
|
623
751
|
# 记录成功日志
|
624
752
|
duration = time.time() - start_time
|
625
753
|
logger.info(
|
626
|
-
f"✅ Batch
|
754
|
+
f"✅ Batch Request completed | batch_size: {len(result.responses)}",
|
627
755
|
extra={
|
628
756
|
"log_type": "response",
|
629
757
|
"uri": "/batch_invoke",
|
630
758
|
"duration": duration,
|
631
759
|
"data": {
|
632
|
-
"batch_size": len(
|
633
|
-
"
|
760
|
+
"batch_size": len(result.responses),
|
761
|
+
"success_count": sum(1 for item in result.responses if not item.error),
|
762
|
+
"error_count": sum(1 for item in result.responses if item.error)
|
634
763
|
}
|
635
|
-
}
|
636
|
-
|
637
|
-
return
|
764
|
+
})
|
765
|
+
|
766
|
+
return result
|
767
|
+
|
638
768
|
except grpc.RpcError as e:
|
639
769
|
duration = time.time() - start_time
|
640
|
-
error_message = f"❌
|
770
|
+
error_message = f"❌ Batch invoke gRPC failed: {str(e)}"
|
641
771
|
logger.error(error_message, exc_info=True,
|
642
|
-
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
772
|
+
extra={
|
773
|
+
"log_type": "response",
|
774
|
+
"uri": "/batch_invoke",
|
775
|
+
"duration": duration,
|
776
|
+
"data": {
|
777
|
+
"error_type": "grpc_error",
|
778
|
+
"error_code": str(e.code()) if hasattr(e, 'code') else None,
|
779
|
+
"batch_size": len(batch_request_model.items)
|
780
|
+
}
|
781
|
+
})
|
652
782
|
raise e
|
653
783
|
except Exception as e:
|
654
784
|
duration = time.time() - start_time
|
655
|
-
error_message = f"❌
|
785
|
+
error_message = f"❌ Batch invoke other error: {str(e)}"
|
656
786
|
logger.error(error_message, exc_info=True,
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
raise e
|
667
|
-
|
668
|
-
def close(self):
|
669
|
-
"""关闭 gRPC 通道"""
|
670
|
-
if self.channel and not self._closed:
|
671
|
-
self.channel.close()
|
672
|
-
self._closed = True
|
673
|
-
logger.info("✅ gRPC channel closed",
|
674
|
-
extra={"log_type": "info", "data": {"status": "success"}})
|
675
|
-
|
676
|
-
def _safe_sync_close(self):
|
677
|
-
"""进程退出时自动关闭 channel(事件循环处理兼容)"""
|
678
|
-
if self.channel and not self._closed:
|
679
|
-
try:
|
680
|
-
self.close() # 直接调用关闭方法
|
681
|
-
except Exception as e:
|
682
|
-
logger.warning(f"⚠️ gRPC channel close failed at exit: {e}",
|
683
|
-
extra={"log_type": "info", "data": {"status": "failed", "error": str(e)}})
|
684
|
-
|
685
|
-
def __enter__(self):
|
686
|
-
"""同步初始化连接"""
|
687
|
-
self._ensure_initialized()
|
688
|
-
return self
|
689
|
-
|
690
|
-
def __exit__(self, exc_type, exc_val, exc_tb):
|
691
|
-
"""同步关闭连接"""
|
692
|
-
self.close()
|
787
|
+
extra={
|
788
|
+
"log_type": "response",
|
789
|
+
"uri": "/batch_invoke",
|
790
|
+
"duration": duration,
|
791
|
+
"data": {
|
792
|
+
"error_type": "other_error",
|
793
|
+
"batch_size": len(batch_request_model.items)
|
794
|
+
}
|
795
|
+
})
|
796
|
+
raise e
|