tamar-model-client 0.1.28__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/async_client.py +71 -42
- tamar_model_client/auth.py +31 -2
- tamar_model_client/core/base_client.py +29 -11
- tamar_model_client/core/http_fallback.py +101 -17
- tamar_model_client/error_handler.py +8 -6
- tamar_model_client/json_formatter.py +9 -0
- tamar_model_client/sync_client.py +59 -24
- {tamar_model_client-0.1.28.dist-info → tamar_model_client-0.2.0.dist-info}/METADATA +496 -7
- {tamar_model_client-0.1.28.dist-info → tamar_model_client-0.2.0.dist-info}/RECORD +13 -12
- tests/test_circuit_breaker.py +269 -0
- tests/test_google_azure_final.py +589 -5
- {tamar_model_client-0.1.28.dist-info → tamar_model_client-0.2.0.dist-info}/WHEEL +0 -0
- {tamar_model_client-0.1.28.dist-info → tamar_model_client-0.2.0.dist-info}/top_level.txt +0 -0
@@ -103,7 +103,6 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
103
103
|
self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
|
104
104
|
self._channel_error_count = 0
|
105
105
|
self._last_channel_error_time = None
|
106
|
-
self._channel_lock = asyncio.Lock() # 异步锁
|
107
106
|
|
108
107
|
# === Request ID 管理 ===
|
109
108
|
self._request_id_manager = RequestIdManager()
|
@@ -194,7 +193,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
194
193
|
# 如果 channel 存在但不健康,记录日志
|
195
194
|
if self.channel and self.stub:
|
196
195
|
logger.warning(
|
197
|
-
"Channel exists but unhealthy, will recreate",
|
196
|
+
"⚠️ Channel exists but unhealthy, will recreate",
|
198
197
|
extra={
|
199
198
|
"log_type": "channel_recreate",
|
200
199
|
"data": {
|
@@ -222,7 +221,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
222
221
|
"data": {"tls_enabled": True, "server_address": self.server_address}})
|
223
222
|
else:
|
224
223
|
self.channel = grpc.aio.insecure_channel(
|
225
|
-
self.server_address,
|
224
|
+
f"dns:///{self.server_address}",
|
226
225
|
options=options
|
227
226
|
)
|
228
227
|
logger.info("🔓 Using insecure gRPC channel (TLS disabled)",
|
@@ -272,7 +271,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
272
271
|
# 如果处于关闭或失败状态,需要重建
|
273
272
|
if state in [grpc.ChannelConnectivity.SHUTDOWN,
|
274
273
|
grpc.ChannelConnectivity.TRANSIENT_FAILURE]:
|
275
|
-
logger.warning(f"Channel in unhealthy state: {state}",
|
274
|
+
logger.warning(f"⚠️ Channel in unhealthy state: {state}",
|
276
275
|
extra={"log_type": "info",
|
277
276
|
"data": {"channel_state": str(state)}})
|
278
277
|
return False
|
@@ -280,7 +279,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
280
279
|
# 如果最近有多次错误,也需要重建
|
281
280
|
if self._channel_error_count > 3 and self._last_channel_error_time:
|
282
281
|
if time.time() - self._last_channel_error_time < 60: # 60秒内
|
283
|
-
logger.warning("Too many channel errors recently, marking as unhealthy",
|
282
|
+
logger.warning("⚠️ Too many channel errors recently, marking as unhealthy",
|
284
283
|
extra={"log_type": "info",
|
285
284
|
"data": {"error_count": self._channel_error_count}})
|
286
285
|
return False
|
@@ -288,7 +287,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
288
287
|
return True
|
289
288
|
|
290
289
|
except Exception as e:
|
291
|
-
logger.error(f"Error checking channel health: {e}",
|
290
|
+
logger.error(f"❌ Error checking channel health: {e}",
|
292
291
|
extra={"log_type": "info",
|
293
292
|
"data": {"error": str(e)}})
|
294
293
|
return False
|
@@ -299,27 +298,26 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
299
298
|
|
300
299
|
关闭旧的 channel 并创建新的连接
|
301
300
|
"""
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
extra={"log_type": "info"})
|
301
|
+
# 关闭旧 channel
|
302
|
+
if self.channel:
|
303
|
+
try:
|
304
|
+
await self.channel.close()
|
305
|
+
logger.info("🔚 Closed unhealthy channel",
|
306
|
+
extra={"log_type": "info"})
|
307
|
+
except Exception as e:
|
308
|
+
logger.warning(f"⚠️ Error closing channel: {e}",
|
309
|
+
extra={"log_type": "info"})
|
310
|
+
|
311
|
+
# 清空引用
|
312
|
+
self.channel = None
|
313
|
+
self.stub = None
|
314
|
+
|
315
|
+
# 重置错误计数
|
316
|
+
self._channel_error_count = 0
|
317
|
+
self._last_channel_error_time = None
|
318
|
+
|
319
|
+
logger.info("🔄 Recreating gRPC channel...",
|
320
|
+
extra={"log_type": "info"})
|
323
321
|
|
324
322
|
def _record_channel_error(self, error: grpc.RpcError):
|
325
323
|
"""
|
@@ -346,7 +344,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
346
344
|
|
347
345
|
# 记录详细的错误信息
|
348
346
|
logger.warning(
|
349
|
-
f"Channel error recorded: {error.code().name}",
|
347
|
+
f"⚠️ Channel error recorded: {error.code().name}",
|
350
348
|
extra={
|
351
349
|
"log_type": "channel_error",
|
352
350
|
"data": {
|
@@ -457,7 +455,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
457
455
|
is_network_cancelled = error_context.is_network_cancelled()
|
458
456
|
|
459
457
|
logger.warning(
|
460
|
-
f"CANCELLED error in stream, channel state: {channel_state}",
|
458
|
+
f"⚠️ CANCELLED error in stream, channel state: {channel_state}",
|
461
459
|
extra={
|
462
460
|
"log_type": "cancelled_debug",
|
463
461
|
"request_id": context.get('request_id'),
|
@@ -485,14 +483,16 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
485
483
|
"request_id": context.get('request_id'),
|
486
484
|
"data": {
|
487
485
|
"error_code": e.code().name if e.code() else 'UNKNOWN',
|
486
|
+
"error_details": e.details() if hasattr(e, 'details') else '',
|
488
487
|
"retry_count": attempt,
|
489
488
|
"max_retries": self.max_retries,
|
490
489
|
"method": "stream"
|
491
490
|
},
|
492
491
|
"duration": current_duration
|
493
492
|
}
|
493
|
+
error_detail = f" - {e.details()}" if e.details() else ""
|
494
494
|
logger.warning(
|
495
|
-
f"
|
495
|
+
f"🔄 Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (will retry)",
|
496
496
|
extra=log_data
|
497
497
|
)
|
498
498
|
|
@@ -507,6 +507,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
507
507
|
"request_id": context.get('request_id'),
|
508
508
|
"data": {
|
509
509
|
"error_code": e.code().name if e.code() else 'UNKNOWN',
|
510
|
+
"error_details": e.details() if hasattr(e, 'details') else '',
|
510
511
|
"retry_count": attempt,
|
511
512
|
"max_retries": self.max_retries,
|
512
513
|
"method": "stream",
|
@@ -514,8 +515,9 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
514
515
|
},
|
515
516
|
"duration": current_duration
|
516
517
|
}
|
517
|
-
|
518
|
-
|
518
|
+
error_detail = f" - {e.details()}" if e.details() else ""
|
519
|
+
logger.warning(
|
520
|
+
f"⚠️ Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (no more retries)",
|
519
521
|
extra=log_data
|
520
522
|
)
|
521
523
|
context['duration'] = current_duration
|
@@ -596,7 +598,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
596
598
|
|
597
599
|
return delay
|
598
600
|
|
599
|
-
async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
|
601
|
+
async def _stream(self, request, metadata, invoke_timeout, request_id=None, origin_request_id=None) -> AsyncIterator[ModelResponse]:
|
600
602
|
"""
|
601
603
|
处理流式响应
|
602
604
|
|
@@ -604,8 +606,10 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
604
606
|
|
605
607
|
Args:
|
606
608
|
request: gRPC 请求对象
|
607
|
-
metadata:
|
609
|
+
metadata: 请求元数据(为了兼容性保留,但会被忽略)
|
608
610
|
invoke_timeout: 总体超时时间
|
611
|
+
request_id: 请求ID
|
612
|
+
origin_request_id: 原始请求ID
|
609
613
|
|
610
614
|
Yields:
|
611
615
|
ModelResponse: 流式响应的每个数据块
|
@@ -613,7 +617,12 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
613
617
|
Raises:
|
614
618
|
TimeoutError: 当等待下一个数据块超时时
|
615
619
|
"""
|
616
|
-
|
620
|
+
# 每次调用时重新生成metadata,确保JWT token是最新的
|
621
|
+
fresh_metadata = self._build_auth_metadata(
|
622
|
+
request_id or get_request_id(),
|
623
|
+
origin_request_id
|
624
|
+
)
|
625
|
+
stream_iter = self.stub.Invoke(request, metadata=fresh_metadata, timeout=invoke_timeout).__aiter__()
|
617
626
|
chunk_timeout = 30.0 # 单个数据块的超时时间
|
618
627
|
|
619
628
|
try:
|
@@ -634,7 +643,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
634
643
|
except Exception as e:
|
635
644
|
raise
|
636
645
|
|
637
|
-
async def _stream_with_logging(self, request, metadata, invoke_timeout, start_time, model_request) -> AsyncIterator[
|
646
|
+
async def _stream_with_logging(self, request, metadata, invoke_timeout, start_time, model_request, request_id=None, origin_request_id=None) -> AsyncIterator[
|
638
647
|
ModelResponse]:
|
639
648
|
"""流式响应的包装器,用于记录完整的响应日志并处理重试"""
|
640
649
|
total_content = ""
|
@@ -643,7 +652,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
643
652
|
chunk_count = 0
|
644
653
|
|
645
654
|
# 使用重试逻辑获取流生成器
|
646
|
-
stream_generator = self._retry_request_stream(self._stream, request, metadata, invoke_timeout, request_id=get_request_id())
|
655
|
+
stream_generator = self._retry_request_stream(self._stream, request, metadata, invoke_timeout, request_id=request_id or get_request_id(), origin_request_id=origin_request_id)
|
647
656
|
|
648
657
|
try:
|
649
658
|
async for response in stream_generator:
|
@@ -717,9 +726,22 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
717
726
|
)
|
718
727
|
raise
|
719
728
|
|
720
|
-
async def _invoke_request(self, request, metadata, invoke_timeout):
|
721
|
-
"""执行单个非流式请求
|
722
|
-
|
729
|
+
async def _invoke_request(self, request, metadata, invoke_timeout, request_id=None, origin_request_id=None):
|
730
|
+
"""执行单个非流式请求
|
731
|
+
|
732
|
+
Args:
|
733
|
+
request: gRPC请求对象
|
734
|
+
metadata: 请求元数据(为了兼容性保留,但会被忽略)
|
735
|
+
invoke_timeout: 请求超时时间
|
736
|
+
request_id: 请求ID
|
737
|
+
origin_request_id: 原始请求ID
|
738
|
+
"""
|
739
|
+
# 每次调用时重新生成metadata,确保JWT token是最新的
|
740
|
+
fresh_metadata = self._build_auth_metadata(
|
741
|
+
request_id or get_request_id(),
|
742
|
+
origin_request_id
|
743
|
+
)
|
744
|
+
async for response in self.stub.Invoke(request, metadata=fresh_metadata, timeout=invoke_timeout):
|
723
745
|
return ResponseHandler.build_model_response(response)
|
724
746
|
|
725
747
|
async def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None,
|
@@ -815,13 +837,13 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
815
837
|
invoke_timeout = timeout or self.default_invoke_timeout
|
816
838
|
if model_request.stream:
|
817
839
|
# 对于流式响应,直接返回带日志记录的包装器
|
818
|
-
return self._stream_with_logging(request, metadata, invoke_timeout, start_time, model_request)
|
840
|
+
return self._stream_with_logging(request, metadata, invoke_timeout, start_time, model_request, request_id, origin_request_id)
|
819
841
|
else:
|
820
842
|
# 存储model_request和origin_request_id供重试方法使用
|
821
843
|
self._current_model_request = model_request
|
822
844
|
self._current_origin_request_id = origin_request_id
|
823
845
|
try:
|
824
|
-
result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id)
|
846
|
+
result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id, origin_request_id=origin_request_id)
|
825
847
|
finally:
|
826
848
|
# 清理临时存储
|
827
849
|
if hasattr(self, '_current_model_request'):
|
@@ -1033,6 +1055,13 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
|
|
1033
1055
|
"batch_size": len(batch_request_model.items)
|
1034
1056
|
}
|
1035
1057
|
})
|
1058
|
+
|
1059
|
+
# 记录失败(如果启用了熔断)
|
1060
|
+
if self.resilient_enabled and self.circuit_breaker:
|
1061
|
+
# 将错误码传递给熔断器,用于智能失败统计
|
1062
|
+
error_code = e.code() if hasattr(e, 'code') else None
|
1063
|
+
self.circuit_breaker.record_failure(error_code)
|
1064
|
+
|
1036
1065
|
raise e
|
1037
1066
|
except Exception as e:
|
1038
1067
|
duration = time.time() - start_time
|
tamar_model_client/auth.py
CHANGED
@@ -1,14 +1,43 @@
|
|
1
1
|
import time
|
2
2
|
import jwt
|
3
|
+
from typing import Optional
|
3
4
|
|
4
5
|
|
5
6
|
# JWT 处理类
|
6
7
|
class JWTAuthHandler:
|
7
8
|
def __init__(self, secret_key: str):
|
8
9
|
self.secret_key = secret_key
|
10
|
+
self._token_cache: Optional[str] = None
|
11
|
+
self._token_exp_time: Optional[int] = None
|
9
12
|
|
10
13
|
def encode_token(self, payload: dict, expires_in: int = 3600) -> str:
|
11
14
|
"""生成带过期时间的 JWT Token"""
|
12
15
|
payload = payload.copy()
|
13
|
-
|
14
|
-
|
16
|
+
exp_time = int(time.time()) + expires_in
|
17
|
+
payload["exp"] = exp_time
|
18
|
+
token = jwt.encode(payload, self.secret_key, algorithm="HS256")
|
19
|
+
|
20
|
+
# 缓存token和过期时间
|
21
|
+
self._token_cache = token
|
22
|
+
self._token_exp_time = exp_time
|
23
|
+
|
24
|
+
return token
|
25
|
+
|
26
|
+
def is_token_expiring_soon(self, buffer_seconds: int = 60) -> bool:
|
27
|
+
"""检查token是否即将过期
|
28
|
+
|
29
|
+
Args:
|
30
|
+
buffer_seconds: 提前多少秒认为token即将过期,默认60秒
|
31
|
+
|
32
|
+
Returns:
|
33
|
+
bool: True表示token即将过期或已过期
|
34
|
+
"""
|
35
|
+
if not self._token_exp_time:
|
36
|
+
return True
|
37
|
+
|
38
|
+
current_time = int(time.time())
|
39
|
+
return current_time >= (self._token_exp_time - buffer_seconds)
|
40
|
+
|
41
|
+
def get_cached_token(self) -> Optional[str]:
|
42
|
+
"""获取缓存的token"""
|
43
|
+
return self._token_cache
|
@@ -136,10 +136,7 @@ class BaseClient(ABC):
|
|
136
136
|
('grpc.resource_quota_size', 1048576000), # 设置资源配额为1GB
|
137
137
|
|
138
138
|
# 启用负载均衡配置
|
139
|
-
('grpc.
|
140
|
-
|
141
|
-
# 启用详细的日志记录
|
142
|
-
('grpc.debug', 1), # 启用 gRPC 的调试日志,记录更多的连接和请求信息
|
139
|
+
('grpc.lb_policy_name', 'round_robin'), # 设置负载均衡策略为 round_robin(轮询)
|
143
140
|
]
|
144
141
|
|
145
142
|
if self.default_authority:
|
@@ -168,10 +165,27 @@ class BaseClient(ABC):
|
|
168
165
|
metadata.append(("x-origin-request-id", origin_request_id))
|
169
166
|
|
170
167
|
if self.jwt_handler:
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
168
|
+
# 检查token是否即将过期,如果是则刷新
|
169
|
+
if self.jwt_handler.is_token_expiring_soon():
|
170
|
+
self.jwt_token = self.jwt_handler.encode_token(
|
171
|
+
self.default_payload,
|
172
|
+
expires_in=self.token_expires_in
|
173
|
+
)
|
174
|
+
else:
|
175
|
+
# 使用缓存的token
|
176
|
+
cached_token = self.jwt_handler.get_cached_token()
|
177
|
+
if cached_token:
|
178
|
+
self.jwt_token = cached_token
|
179
|
+
else:
|
180
|
+
# 如果没有缓存,生成新token
|
181
|
+
self.jwt_token = self.jwt_handler.encode_token(
|
182
|
+
self.default_payload,
|
183
|
+
expires_in=self.token_expires_in
|
184
|
+
)
|
185
|
+
|
186
|
+
metadata.append(("authorization", f"Bearer {self.jwt_token}"))
|
187
|
+
elif self.jwt_token:
|
188
|
+
# 使用用户提供的预生成token
|
175
189
|
metadata.append(("authorization", f"Bearer {self.jwt_token}"))
|
176
190
|
|
177
191
|
return metadata
|
@@ -240,9 +254,13 @@ class BaseClient(ABC):
|
|
240
254
|
|
241
255
|
return {
|
242
256
|
"enabled": self.resilient_enabled,
|
243
|
-
"
|
244
|
-
|
245
|
-
|
257
|
+
"circuit_breaker": {
|
258
|
+
"state": self.circuit_breaker.get_state(),
|
259
|
+
"failure_count": self.circuit_breaker.failure_count,
|
260
|
+
"last_failure_time": self.circuit_breaker.last_failure_time,
|
261
|
+
"failure_threshold": self.circuit_breaker.failure_threshold,
|
262
|
+
"recovery_timeout": self.circuit_breaker.recovery_timeout
|
263
|
+
},
|
246
264
|
"http_fallback_url": self.http_fallback_url
|
247
265
|
}
|
248
266
|
|
@@ -15,6 +15,59 @@ from ..schemas import ModelRequest, ModelResponse
|
|
15
15
|
logger = get_protected_logger(__name__)
|
16
16
|
|
17
17
|
|
18
|
+
def safe_serialize(obj: Any) -> Any:
|
19
|
+
"""
|
20
|
+
安全地序列化对象,避免 Pydantic ValidatorIterator 序列化问题
|
21
|
+
"""
|
22
|
+
if obj is None:
|
23
|
+
return None
|
24
|
+
|
25
|
+
# 处理基本类型
|
26
|
+
if isinstance(obj, (str, int, float, bool)):
|
27
|
+
return obj
|
28
|
+
|
29
|
+
# 处理列表
|
30
|
+
if isinstance(obj, (list, tuple)):
|
31
|
+
return [safe_serialize(item) for item in obj]
|
32
|
+
|
33
|
+
# 处理字典
|
34
|
+
if isinstance(obj, dict):
|
35
|
+
return {key: safe_serialize(value) for key, value in obj.items()}
|
36
|
+
|
37
|
+
# 处理 Pydantic 模型
|
38
|
+
if hasattr(obj, 'model_dump'):
|
39
|
+
try:
|
40
|
+
return obj.model_dump(exclude_unset=True)
|
41
|
+
except Exception:
|
42
|
+
# 如果 model_dump 失败,尝试手动提取字段
|
43
|
+
try:
|
44
|
+
if hasattr(obj, '__dict__'):
|
45
|
+
return {k: safe_serialize(v) for k, v in obj.__dict__.items()
|
46
|
+
if not k.startswith('_') and not callable(v)}
|
47
|
+
elif hasattr(obj, '__slots__'):
|
48
|
+
return {slot: safe_serialize(getattr(obj, slot, None))
|
49
|
+
for slot in obj.__slots__ if hasattr(obj, slot)}
|
50
|
+
except Exception:
|
51
|
+
pass
|
52
|
+
|
53
|
+
# 处理 Pydantic v1 模型
|
54
|
+
if hasattr(obj, 'dict'):
|
55
|
+
try:
|
56
|
+
return obj.dict(exclude_unset=True)
|
57
|
+
except Exception:
|
58
|
+
pass
|
59
|
+
|
60
|
+
# 处理枚举
|
61
|
+
if hasattr(obj, 'value'):
|
62
|
+
return obj.value
|
63
|
+
|
64
|
+
# 最后的尝试:转换为字符串
|
65
|
+
try:
|
66
|
+
return str(obj)
|
67
|
+
except Exception:
|
68
|
+
return None
|
69
|
+
|
70
|
+
|
18
71
|
class HttpFallbackMixin:
|
19
72
|
"""HTTP fallback functionality for synchronous clients
|
20
73
|
|
@@ -43,30 +96,37 @@ class HttpFallbackMixin:
|
|
43
96
|
|
44
97
|
def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
|
45
98
|
"""Convert ModelRequest to HTTP payload format"""
|
99
|
+
# Use safe serialization to avoid Pydantic ValidatorIterator issues
|
46
100
|
payload = {
|
47
|
-
"provider": model_request.provider
|
48
|
-
"model": model_request.model,
|
49
|
-
"user_context": model_request.user_context
|
50
|
-
"stream": model_request.stream
|
101
|
+
"provider": safe_serialize(model_request.provider),
|
102
|
+
"model": safe_serialize(model_request.model),
|
103
|
+
"user_context": safe_serialize(model_request.user_context),
|
104
|
+
"stream": safe_serialize(model_request.stream)
|
51
105
|
}
|
52
106
|
|
53
107
|
# Add provider-specific fields
|
54
108
|
if hasattr(model_request, 'messages') and model_request.messages:
|
55
|
-
payload['messages'] = model_request.messages
|
109
|
+
payload['messages'] = safe_serialize(model_request.messages)
|
56
110
|
if hasattr(model_request, 'contents') and model_request.contents:
|
57
|
-
payload['contents'] = model_request.contents
|
111
|
+
payload['contents'] = safe_serialize(model_request.contents)
|
58
112
|
|
59
113
|
# Add optional fields
|
60
114
|
if model_request.channel:
|
61
|
-
payload['channel'] = model_request.channel
|
115
|
+
payload['channel'] = safe_serialize(model_request.channel)
|
62
116
|
if model_request.invoke_type:
|
63
|
-
payload['invoke_type'] = model_request.invoke_type
|
117
|
+
payload['invoke_type'] = safe_serialize(model_request.invoke_type)
|
64
118
|
|
65
|
-
# Add
|
119
|
+
# Add config parameters safely
|
120
|
+
if hasattr(model_request, 'config') and model_request.config:
|
121
|
+
payload['config'] = safe_serialize(model_request.config)
|
122
|
+
|
123
|
+
# Add extra parameters safely
|
66
124
|
if hasattr(model_request, 'model_extra') and model_request.model_extra:
|
67
|
-
|
68
|
-
|
69
|
-
|
125
|
+
serialized_extra = safe_serialize(model_request.model_extra)
|
126
|
+
if isinstance(serialized_extra, dict):
|
127
|
+
for key, value in serialized_extra.items():
|
128
|
+
if key not in payload:
|
129
|
+
payload[key] = value
|
70
130
|
|
71
131
|
return payload
|
72
132
|
|
@@ -96,7 +156,7 @@ class HttpFallbackMixin:
|
|
96
156
|
data = json.loads(data_str)
|
97
157
|
yield ModelResponse(**data)
|
98
158
|
except json.JSONDecodeError:
|
99
|
-
logger.warning(f"Failed to parse streaming response: {data_str}")
|
159
|
+
logger.warning(f"⚠️ Failed to parse streaming response: {data_str}")
|
100
160
|
|
101
161
|
def _invoke_http_fallback(self, model_request: ModelRequest,
|
102
162
|
timeout: Optional[float] = None,
|
@@ -262,14 +322,35 @@ class AsyncHttpFallbackMixin:
|
|
262
322
|
"""
|
263
323
|
|
264
324
|
async def _ensure_http_client(self) -> None:
|
265
|
-
"""Ensure async HTTP client is initialized"""
|
325
|
+
"""Ensure async HTTP client is initialized in the correct event loop"""
|
326
|
+
import asyncio
|
327
|
+
import aiohttp
|
328
|
+
|
329
|
+
# Get current event loop
|
330
|
+
current_loop = asyncio.get_running_loop()
|
331
|
+
|
332
|
+
# Check if we need to recreate the session
|
333
|
+
need_new_session = False
|
334
|
+
|
266
335
|
if not hasattr(self, '_http_session') or not self._http_session:
|
267
|
-
|
336
|
+
need_new_session = True
|
337
|
+
elif hasattr(self, '_http_session_loop') and self._http_session_loop != current_loop:
|
338
|
+
# Session was created in a different event loop
|
339
|
+
logger.warning("🔄 HTTP session bound to different event loop, recreating...")
|
340
|
+
# Close old session if possible
|
341
|
+
try:
|
342
|
+
await self._http_session.close()
|
343
|
+
except Exception as e:
|
344
|
+
logger.debug(f"Error closing old session: {e}")
|
345
|
+
need_new_session = True
|
346
|
+
|
347
|
+
if need_new_session:
|
268
348
|
self._http_session = aiohttp.ClientSession(
|
269
349
|
headers={
|
270
350
|
'User-Agent': 'AsyncTamarModelClient/1.0'
|
271
351
|
}
|
272
352
|
)
|
353
|
+
self._http_session_loop = current_loop
|
273
354
|
|
274
355
|
# Note: JWT token will be set per request in headers
|
275
356
|
|
@@ -305,7 +386,7 @@ class AsyncHttpFallbackMixin:
|
|
305
386
|
data = json.loads(data_str)
|
306
387
|
yield ModelResponse(**data)
|
307
388
|
except json.JSONDecodeError:
|
308
|
-
logger.warning(f"Failed to parse streaming response: {data_str}")
|
389
|
+
logger.warning(f"⚠️ Failed to parse streaming response: {data_str}")
|
309
390
|
|
310
391
|
async def _invoke_http_fallback(self, model_request: ModelRequest,
|
311
392
|
timeout: Optional[float] = None,
|
@@ -339,6 +420,7 @@ class AsyncHttpFallbackMixin:
|
|
339
420
|
|
340
421
|
# Convert to HTTP format
|
341
422
|
http_payload = self._convert_to_http_format(model_request)
|
423
|
+
print(http_payload)
|
342
424
|
|
343
425
|
# Construct URL
|
344
426
|
url = f"{self.http_fallback_url}/v1/invoke"
|
@@ -467,4 +549,6 @@ class AsyncHttpFallbackMixin:
|
|
467
549
|
"""Clean up HTTP session"""
|
468
550
|
if hasattr(self, '_http_session') and self._http_session:
|
469
551
|
await self._http_session.close()
|
470
|
-
self._http_session = None
|
552
|
+
self._http_session = None
|
553
|
+
if hasattr(self, '_http_session_loop'):
|
554
|
+
self._http_session_loop = None
|
@@ -67,7 +67,7 @@ class GrpcErrorHandler:
|
|
67
67
|
log_data['duration'] = context['duration']
|
68
68
|
|
69
69
|
self.logger.error(
|
70
|
-
f"gRPC Error occurred: {error_context.error_code.name if error_context.error_code else 'UNKNOWN'}",
|
70
|
+
f"❌ gRPC Error occurred: {error_context.error_code.name if error_context.error_code else 'UNKNOWN'}",
|
71
71
|
extra=log_data
|
72
72
|
)
|
73
73
|
|
@@ -151,14 +151,14 @@ class ErrorRecoveryStrategy:
|
|
151
151
|
|
152
152
|
async def handle_token_refresh(self, error_context: ErrorContext):
|
153
153
|
"""处理 Token 刷新"""
|
154
|
-
self.client.logger.info("Attempting to refresh JWT token")
|
154
|
+
self.client.logger.info("🔄 Attempting to refresh JWT token")
|
155
155
|
# 这里需要客户端实现 _refresh_jwt_token 方法
|
156
156
|
if hasattr(self.client, '_refresh_jwt_token'):
|
157
157
|
await self.client._refresh_jwt_token()
|
158
158
|
|
159
159
|
async def handle_reconnect(self, error_context: ErrorContext):
|
160
160
|
"""处理重连"""
|
161
|
-
self.client.logger.info("Attempting to reconnect channel")
|
161
|
+
self.client.logger.info("🔄 Attempting to reconnect channel")
|
162
162
|
# 这里需要客户端实现 _reconnect_channel 方法
|
163
163
|
if hasattr(self.client, '_reconnect_channel'):
|
164
164
|
await self.client._reconnect_channel()
|
@@ -170,7 +170,7 @@ class ErrorRecoveryStrategy:
|
|
170
170
|
|
171
171
|
async def handle_circuit_break(self, error_context: ErrorContext):
|
172
172
|
"""处理熔断"""
|
173
|
-
self.client.logger.warning("Circuit breaker activated")
|
173
|
+
self.client.logger.warning("⚠️ Circuit breaker activated")
|
174
174
|
# 这里可以实现熔断逻辑
|
175
175
|
pass
|
176
176
|
|
@@ -322,8 +322,9 @@ class EnhancedRetryHandler:
|
|
322
322
|
},
|
323
323
|
"duration": current_duration
|
324
324
|
}
|
325
|
+
error_detail = f" - {error_context.error_message}" if error_context.error_message else ""
|
325
326
|
logger.warning(
|
326
|
-
f"
|
327
|
+
f"⚠️ Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (no more retries)",
|
327
328
|
extra=log_data
|
328
329
|
)
|
329
330
|
last_exception = self.error_handler.handle_error(e, context)
|
@@ -346,8 +347,9 @@ class EnhancedRetryHandler:
|
|
346
347
|
},
|
347
348
|
"duration": current_duration
|
348
349
|
}
|
350
|
+
error_detail = f" - {error_context.error_message}" if error_context.error_message else ""
|
349
351
|
logger.warning(
|
350
|
-
f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()} (will retry)",
|
352
|
+
f"🔄 Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (will retry)",
|
351
353
|
extra=log_data
|
352
354
|
)
|
353
355
|
|
@@ -57,5 +57,14 @@ class JSONFormatter(logging.Formatter):
|
|
57
57
|
if hasattr(record, "trace"):
|
58
58
|
log_data["trace"] = getattr(record, "trace")
|
59
59
|
|
60
|
+
# 添加异常信息(如果有的话)
|
61
|
+
if record.exc_info:
|
62
|
+
import traceback
|
63
|
+
log_data["exception"] = {
|
64
|
+
"type": record.exc_info[0].__name__ if record.exc_info[0] else None,
|
65
|
+
"message": str(record.exc_info[1]) if record.exc_info[1] else None,
|
66
|
+
"traceback": traceback.format_exception(*record.exc_info)
|
67
|
+
}
|
68
|
+
|
60
69
|
# 使用安全的 JSON 编码器
|
61
70
|
return json.dumps(log_data, ensure_ascii=False, cls=SafeJSONEncoder)
|