tamar-model-client 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/async_client.py +195 -15
- tamar_model_client/circuit_breaker.py +140 -0
- tamar_model_client/core/__init__.py +6 -0
- tamar_model_client/core/base_client.py +56 -3
- tamar_model_client/core/http_fallback.py +249 -0
- tamar_model_client/core/logging_setup.py +124 -14
- tamar_model_client/error_handler.py +60 -6
- tamar_model_client/exceptions.py +49 -1
- tamar_model_client/sync_client.py +239 -27
- {tamar_model_client-0.1.20.dist-info → tamar_model_client-0.1.22.dist-info}/METADATA +73 -1
- {tamar_model_client-0.1.20.dist-info → tamar_model_client-0.1.22.dist-info}/RECORD +15 -12
- tests/test_google_azure_final.py +325 -63
- tests/test_logging_issue.py +75 -0
- {tamar_model_client-0.1.20.dist-info → tamar_model_client-0.1.22.dist-info}/WHEEL +0 -0
- {tamar_model_client-0.1.20.dist-info → tamar_model_client-0.1.22.dist-info}/top_level.txt +0 -0
tamar_model_client/core/http_fallback.py
ADDED
@@ -0,0 +1,249 @@
+"""
+HTTP fallback functionality for resilient clients
+
+This module provides mixin classes for HTTP-based fallback when gRPC
+connections fail, supporting both synchronous and asynchronous clients.
+"""
+
+import json
+import logging
+from typing import Optional, Iterator, AsyncIterator, Dict, Any
+
+from . import generate_request_id, get_protected_logger
+from ..schemas import ModelRequest, ModelResponse
+
+logger = get_protected_logger(__name__)
+
+
+class HttpFallbackMixin:
+    """HTTP fallback functionality for synchronous clients"""
+
+    def _ensure_http_client(self) -> None:
+        """Ensure HTTP client is initialized"""
+        if not hasattr(self, '_http_client') or not self._http_client:
+            import requests
+            self._http_client = requests.Session()
+
+            # Set authentication header if available
+            # Note: JWT token will be set per request in headers
+
+            # Set default headers
+            self._http_client.headers.update({
+                'Content-Type': 'application/json',
+                'User-Agent': 'TamarModelClient/1.0'
+            })
+
+    def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
+        """Convert ModelRequest to HTTP payload format"""
+        payload = {
+            "provider": model_request.provider.value,
+            "model": model_request.model,
+            "user_context": model_request.user_context.model_dump(),
+            "stream": model_request.stream
+        }
+
+        # Add provider-specific fields
+        if hasattr(model_request, 'messages') and model_request.messages:
+            payload['messages'] = model_request.messages
+        if hasattr(model_request, 'contents') and model_request.contents:
+            payload['contents'] = model_request.contents
+
+        # Add optional fields
+        if model_request.channel:
+            payload['channel'] = model_request.channel.value
+        if model_request.invoke_type:
+            payload['invoke_type'] = model_request.invoke_type.value
+
+        # Add extra parameters
+        if hasattr(model_request, 'model_extra') and model_request.model_extra:
+            for key, value in model_request.model_extra.items():
+                if key not in payload:
+                    payload[key] = value
+
+        return payload
+
+    def _handle_http_stream(self, url: str, payload: Dict[str, Any],
+                            timeout: Optional[float], request_id: str, headers: Dict[str, str]) -> Iterator[ModelResponse]:
+        """Handle HTTP streaming response"""
+        import requests
+
+        response = self._http_client.post(
+            url,
+            json=payload,
+            timeout=timeout or 30,
+            headers=headers,
+            stream=True
+        )
+        response.raise_for_status()
+
+        # Parse SSE stream
+        for line in response.iter_lines():
+            if line:
+                line_str = line.decode('utf-8')
+                if line_str.startswith('data: '):
+                    data_str = line_str[6:]
+                    if data_str == '[DONE]':
+                        break
+                    try:
+                        data = json.loads(data_str)
+                        yield ModelResponse(**data)
+                    except json.JSONDecodeError:
+                        logger.warning(f"Failed to parse streaming response: {data_str}")
+
+    def _invoke_http_fallback(self, model_request: ModelRequest,
+                              timeout: Optional[float] = None,
+                              request_id: Optional[str] = None) -> Any:
+        """HTTP fallback implementation"""
+        self._ensure_http_client()
+
+        # Generate request ID if not provided
+        if not request_id:
+            request_id = generate_request_id()
+
+        # Log fallback usage
+        logger.warning(
+            f"🔻 Using HTTP fallback for request",
+            extra={
+                "request_id": request_id,
+                "provider": model_request.provider.value,
+                "model": model_request.model,
+                "fallback_url": self.http_fallback_url
+            }
+        )
+
+        # Convert to HTTP format
+        http_payload = self._convert_to_http_format(model_request)
+
+        # Construct URL
+        url = f"{self.http_fallback_url}/v1/invoke"
+
+        # Build headers with authentication
+        headers = {'X-Request-ID': request_id}
+        if hasattr(self, 'jwt_token') and self.jwt_token:
+            headers['Authorization'] = f'Bearer {self.jwt_token}'
+
+        if model_request.stream:
+            # Return streaming iterator
+            return self._handle_http_stream(url, http_payload, timeout, request_id, headers)
+        else:
+            # Non-streaming request
+            response = self._http_client.post(
+                url,
+                json=http_payload,
+                timeout=timeout or 30,
+                headers=headers
+            )
+            response.raise_for_status()
+
+            # Parse response
+            data = response.json()
+            return ModelResponse(**data)
+
+
+class AsyncHttpFallbackMixin:
+    """HTTP fallback functionality for asynchronous clients"""
+
+    async def _ensure_http_client(self) -> None:
+        """Ensure async HTTP client is initialized"""
+        if not hasattr(self, '_http_session') or not self._http_session:
+            import aiohttp
+            self._http_session = aiohttp.ClientSession(
+                headers={
+                    'Content-Type': 'application/json',
+                    'User-Agent': 'AsyncTamarModelClient/1.0'
+                }
+            )
+
+            # Note: JWT token will be set per request in headers
+
+    def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
+        """Convert ModelRequest to HTTP payload format (reuse sync version)"""
+        # This method doesn't need to be async, so we can reuse the sync version
+        return HttpFallbackMixin._convert_to_http_format(self, model_request)
+
+    async def _handle_http_stream(self, url: str, payload: Dict[str, Any],
+                                  timeout: Optional[float], request_id: str, headers: Dict[str, str]) -> AsyncIterator[ModelResponse]:
+        """Handle async HTTP streaming response"""
+        import aiohttp
+
+        timeout_obj = aiohttp.ClientTimeout(total=timeout or 30) if timeout else None
+
+        async with self._http_session.post(
+            url,
+            json=payload,
+            timeout=timeout_obj,
+            headers=headers
+        ) as response:
+            response.raise_for_status()
+
+            # Parse SSE stream
+            async for line_bytes in response.content:
+                if line_bytes:
+                    line_str = line_bytes.decode('utf-8').strip()
+                    if line_str.startswith('data: '):
+                        data_str = line_str[6:]
+                        if data_str == '[DONE]':
+                            break
+                        try:
+                            data = json.loads(data_str)
+                            yield ModelResponse(**data)
+                        except json.JSONDecodeError:
+                            logger.warning(f"Failed to parse streaming response: {data_str}")
+
+    async def _invoke_http_fallback(self, model_request: ModelRequest,
+                                    timeout: Optional[float] = None,
+                                    request_id: Optional[str] = None) -> Any:
+        """Async HTTP fallback implementation"""
+        await self._ensure_http_client()
+
+        # Generate request ID if not provided
+        if not request_id:
+            request_id = generate_request_id()
+
+        # Log fallback usage
+        logger.warning(
+            f"🔻 Using HTTP fallback for request",
+            extra={
+                "request_id": request_id,
+                "provider": model_request.provider.value,
+                "model": model_request.model,
+                "fallback_url": self.http_fallback_url
+            }
+        )
+
+        # Convert to HTTP format
+        http_payload = self._convert_to_http_format(model_request)
+
+        # Construct URL
+        url = f"{self.http_fallback_url}/v1/invoke"
+
+        # Build headers with authentication
+        headers = {'X-Request-ID': request_id}
+        if hasattr(self, 'jwt_token') and self.jwt_token:
+            headers['Authorization'] = f'Bearer {self.jwt_token}'
+
+        if model_request.stream:
+            # Return async streaming iterator
+            return self._handle_http_stream(url, http_payload, timeout, request_id, headers)
+        else:
+            # Non-streaming request
+            import aiohttp
+            timeout_obj = aiohttp.ClientTimeout(total=timeout or 30) if timeout else None
+
+            async with self._http_session.post(
+                url,
+                json=http_payload,
+                timeout=timeout_obj,
+                headers=headers
+            ) as response:
+                response.raise_for_status()
+
+                # Parse response
+                data = await response.json()
+                return ModelResponse(**data)
+
+    async def _cleanup_http_session(self) -> None:
+        """Clean up HTTP session"""
+        if hasattr(self, '_http_session') and self._http_session:
+            await self._http_session.close()
+            self._http_session = None
tamar_model_client/core/logging_setup.py
CHANGED
@@ -6,7 +6,8 @@ It includes request ID tracking, JSON formatting, and consistent log configuration
 """
 
 import logging
-from typing import Optional
+import threading
+from typing import Optional, Dict
 
 from ..json_formatter import JSONFormatter
 from .utils import get_request_id
@@ -14,6 +15,15 @@ from .utils import get_request_id
 # gRPC message length limit (32-bit compatibility)
 MAX_MESSAGE_LENGTH = 2 ** 31 - 1
 
+# Logger name prefix reserved for the SDK
+TAMAR_LOGGER_PREFIX = "tamar_model_client"
+
+# Thread-safety lock for logger configuration
+_logger_lock = threading.Lock()
+
+# Cache of already configured loggers
+_configured_loggers: Dict[str, logging.Logger] = {}
+
 
 class RequestIdFilter(logging.Filter):
     """
@@ -38,9 +48,54 @@ class RequestIdFilter(logging.Filter):
         return True
 
 
+class TamarLoggerAdapter:
+    """
+    Logger adapter that keeps the SDK's log format from being modified externally.
+
+    The adapter wraps the original logger and intercepts every logging method call,
+    ensuring that the correct formatter and handlers are used.
+    """
+
+    def __init__(self, logger: logging.Logger):
+        self._logger = logger
+        self._lock = threading.Lock()
+
+    def _ensure_format(self):
+        """Make sure the logger uses the correct format"""
+        with self._lock:
+            # Check and repair the handlers
+            for handler in self._logger.handlers[:]:
+                if not isinstance(handler.formatter, JSONFormatter):
+                    handler.setFormatter(JSONFormatter())
+
+            # Make sure propagate stays disabled
+            if self._logger.propagate:
+                self._logger.propagate = False
+
+    def _log(self, level, msg, *args, **kwargs):
+        """Single entry point for all log levels"""
+        self._ensure_format()
+        getattr(self._logger, level)(msg, *args, **kwargs)
+
+    def debug(self, msg, *args, **kwargs):
+        self._log('debug', msg, *args, **kwargs)
+
+    def info(self, msg, *args, **kwargs):
+        self._log('info', msg, *args, **kwargs)
+
+    def warning(self, msg, *args, **kwargs):
+        self._log('warning', msg, *args, **kwargs)
+
+    def error(self, msg, *args, **kwargs):
+        self._log('error', msg, *args, **kwargs)
+
+    def critical(self, msg, *args, **kwargs):
+        self._log('critical', msg, *args, **kwargs)
+
+
 def setup_logger(logger_name: str, level: int = logging.INFO) -> logging.Logger:
     """
-    Set up and configure a logger
+    Set up and configure a logger (kept for backward compatibility)
 
     Configures handlers, a formatter, and filters for the specified logger.
     If the logger already has handlers, it is not configured again.
@@ -57,28 +112,83 @@ def setup_logger(logger_name: str, level: int = logging.INFO) -> logging.Logger:
     - Adds a request ID filter for request tracing
     - Avoids duplicate configuration
     """
-    logger = logging.getLogger(logger_name)
+    # Make sure the logger name starts with the SDK prefix
+    if not logger_name.startswith(TAMAR_LOGGER_PREFIX):
+        logger_name = f"{TAMAR_LOGGER_PREFIX}.{logger_name}"
 
-
-
-
+    with _logger_lock:
+        # Check the cache first
+        if logger_name in _configured_loggers:
+            return _configured_loggers[logger_name]
+
+        logger = logging.getLogger(logger_name)
+
+        # Forcibly clear any existing handlers
+        logger.handlers.clear()
+
+        # Create a dedicated console handler
         console_handler = logging.StreamHandler()
+        console_handler.setFormatter(JSONFormatter())
 
-    #
-
-    console_handler.setFormatter(formatter)
+        # Give the handler a unique name so it can be identified later
+        console_handler.name = f"tamar_handler_{id(console_handler)}"
 
-    #
+        # Add the handler
         logger.addHandler(console_handler)
 
        # Set the log level
         logger.setLevel(level)
 
-    #
+        # Add the request ID filter
         logger.addFilter(RequestIdFilter())
 
-    #
-    #
+        # Key settings:
+        # 1. Do not propagate to the parent logger
         logger.propagate = False
+
+        # 2. Guard against external tampering (Python 3.8+)
+        if hasattr(logger, 'disabled'):
+            logger.disabled = False
+
+        # Cache the configured logger
+        _configured_loggers[logger_name] = logger
+
+        return logger
+
+
+def get_protected_logger(logger_name: str, level: int = logging.INFO) -> TamarLoggerAdapter:
+    """
+    Get a protected logger.
+
+    Returns a logger adapter that keeps the log format from being modified externally.
 
-    return logger
+    Args:
+        logger_name: name of the logger
+        level: log level, defaults to INFO
+
+    Returns:
+        TamarLoggerAdapter: the protected logger adapter
+    """
+    logger = setup_logger(logger_name, level)
+    return TamarLoggerAdapter(logger)
+
+
+def reset_logger_config(logger_name: str) -> None:
+    """
+    Reset a logger's configuration.
+
+    Intended for tests or scenarios that need to reconfigure logging.
+
+    Args:
+        logger_name: name of the logger
+    """
+    if not logger_name.startswith(TAMAR_LOGGER_PREFIX):
+        logger_name = f"{TAMAR_LOGGER_PREFIX}.{logger_name}"
+
+    with _logger_lock:
+        if logger_name in _configured_loggers:
+            del _configured_loggers[logger_name]
+
+        logger = logging.getLogger(logger_name)
+        logger.handlers.clear()
+        logger.filters.clear()
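Taken together, these changes mean SDK modules obtain their logger through get_protected_logger, which namespaces it under tamar_model_client, attaches a JSONFormatter console handler, disables propagation, and caches the result. A small usage illustration follows, assuming the functions behave as shown in the diff; the name "demo" is arbitrary.

import logging

from tamar_model_client.core.logging_setup import (
    get_protected_logger,
    reset_logger_config,
)

# Returns a TamarLoggerAdapter wrapping the logger "tamar_model_client.demo".
log = get_protected_logger("demo", level=logging.DEBUG)

# The adapter re-applies JSONFormatter and propagate=False before every emit,
# so later changes made by application code are undone on the next call.
log.info("fallback demo message")

# Tests can drop the cached configuration and start from a clean logger.
reset_logger_config("demo")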
tamar_model_client/error_handler.py
CHANGED
@@ -43,9 +43,32 @@ class GrpcErrorHandler:
         error_context = ErrorContext(error, context)
 
         # Log the error in detail
+        # Flatten the important error_context fields into the log's data payload
+        log_data = {
+            "log_type": "info",
+            "request_id": error_context.request_id,
+            "data": {
+                "error_code": error_context.error_code.name if error_context.error_code else 'UNKNOWN',
+                "error_message": error_context.error_message,
+                "provider": error_context.provider,
+                "model": error_context.model,
+                "method": error_context.method,
+                "retry_count": error_context.retry_count,
+                "category": error_context._get_error_category(),
+                "is_retryable": error_context._is_retryable(),
+                "suggested_action": error_context._get_suggested_action(),
+                "debug_string": error_context.error_debug_string,
+                "is_network_cancelled": error_context.is_network_cancelled() if error_context.error_code == grpc.StatusCode.CANCELLED else None
+            }
+        }
+
+        # If the caller's context carries a duration, include it in the log
+        if 'duration' in context:
+            log_data['duration'] = context['duration']
+
         self.logger.error(
-            f"gRPC Error occurred: {error_context.error_code}",
-            extra=
+            f"gRPC Error occurred: {error_context.error_code.name if error_context.error_code else 'UNKNOWN'}",
+            extra=log_data
         )
 
         # Update error statistics
@@ -192,6 +215,10 @@ class EnhancedRetryHandler:
         Raises:
             TamarModelException: the wrapped exception
         """
+        # Record the start time
+        import time
+        method_start_time = time.time()
+
         context = context or {}
         last_exception = None
 
@@ -207,13 +234,32 @@
                 # Decide whether a retry is possible
                 if not self._should_retry(e, attempt):
                     # Not retryable, or the retry limit has been reached
+                    current_duration = time.time() - method_start_time
+                    context['duration'] = current_duration
                     last_exception = self.error_handler.handle_error(e, context)
                     break
+
+                # Compute the time spent so far
+                current_duration = time.time() - method_start_time
 
                 # Log the retry
+                log_data = {
+                    "log_type": "info",
+                    "request_id": error_context.request_id,
+                    "data": {
+                        "error_code": error_context.error_code.name if error_context.error_code else 'UNKNOWN',
+                        "error_message": error_context.error_message,
+                        "retry_count": attempt,
+                        "max_retries": self.max_retries,
+                        "category": error_context._get_error_category(),
+                        "is_retryable": True,  # it is being retried, so it is retryable
+                        "method": error_context.method
+                    },
+                    "duration": current_duration
+                }
                 logger.warning(
                     f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
-                    extra=
+                    extra=log_data
                 )
 
                 # Perform the backoff wait
@@ -221,6 +267,7 @@
                 delay = self._calculate_backoff(attempt)
                 await asyncio.sleep(delay)
 
+                context['duration'] = current_duration
                 last_exception = self.error_handler.handle_error(e, context)
 
             except Exception as e:
@@ -239,12 +286,19 @@
 
     def _should_retry(self, error: grpc.RpcError, attempt: int) -> bool:
         """Decide whether a retry should happen"""
-        if attempt >= self.max_retries:
-            return False
-
         error_code = error.code()
         policy = RETRY_POLICY.get(error_code, {})
 
+        # First check the per-error max_attempts configuration
+        # (max_attempts is the maximum number of retries, excluding the initial request)
+        error_max_attempts = policy.get('max_attempts', self.max_retries)
+        if attempt >= error_max_attempts:
+            return False
+
+        # Then check the global max_retries
+        if attempt >= self.max_retries:
+            return False
+
         # Check the basic retry policy
         retryable = policy.get('retryable', False)
         if retryable == False:
tamar_model_client/exceptions.py
CHANGED
@@ -17,7 +17,10 @@ ERROR_CATEGORIES = {
     'NETWORK': [
         grpc.StatusCode.UNAVAILABLE,
         grpc.StatusCode.DEADLINE_EXCEEDED,
-        grpc.StatusCode.
+        grpc.StatusCode.CANCELLED,  # cancellations caused by network interruption
+    ],
+    'CONCURRENCY': [
+        grpc.StatusCode.ABORTED,  # concurrency conflicts; a separate category makes monitoring easier
     ],
     'AUTH': [
         grpc.StatusCode.UNAUTHENTICATED,
@@ -71,6 +74,19 @@ RETRY_POLICY = {
         'action': 'refresh_token',  # special action
         'max_attempts': 1
     },
+    grpc.StatusCode.CANCELLED: {
+        'retryable': True,
+        'backoff': 'linear',  # linear backoff; network problems usually do not need exponential backoff
+        'max_attempts': 2,  # maximum retries (excluding the initial request), i.e. 3 attempts in total
+        'check_details': False  # do not inspect the details; always retry
+    },
+    grpc.StatusCode.ABORTED: {
+        'retryable': True,
+        'backoff': 'exponential',  # exponential backoff to avoid worsening the contention
+        'max_attempts': 3,  # a moderate number of retries
+        'jitter': True,  # add random jitter to reduce contention
+        'check_details': False
+    },
     # Non-retryable errors
     grpc.StatusCode.INVALID_ARGUMENT: {'retryable': False},
     grpc.StatusCode.NOT_FOUND: {'retryable': False},
@@ -160,6 +176,7 @@ class ErrorContext:
         """Get the suggested handling action"""
         suggestions = {
             'NETWORK': 'Check the network connection and retry later',
+            'CONCURRENCY': 'Concurrency conflict; the system retries automatically',
             'AUTH': 'Check the credentials; the token may need to be refreshed',
             'VALIDATION': 'Check whether the request parameters are correct',
             'RESOURCE': 'Check resource limits or wait for a while',
@@ -167,6 +184,37 @@
             'DATA': 'Data corrupted or lost; check the input data',
         }
         return suggestions.get(self._get_error_category(), 'Unknown error; please contact technical support')
+
+    def is_network_cancelled(self) -> bool:
+        """
+        Determine whether a CANCELLED error was caused by a network interruption.
+
+        Returns:
+            bool: True if the CANCELLED status was caused by a network interruption
+        """
+        if self.error_code != grpc.StatusCode.CANCELLED:
+            return False
+
+        # Check whether the error message contains network-related keywords
+        error_msg = (self.error_message or '').lower()
+        debug_msg = (self.error_debug_string or '').lower()
+
+        network_patterns = [
+            'connection reset',
+            'connection refused',
+            'connection closed',
+            'network unreachable',
+            'broken pipe',
+            'socket closed',
+            'eof',
+            'transport'
+        ]
+
+        for pattern in network_patterns:
+            if pattern in error_msg or pattern in debug_msg:
+                return True
+
+        return False
 
 
 # ===== Exception class hierarchy =====
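The new is_network_cancelled helper is a plain text heuristic: it only scans the error message and debug string for connection-related keywords. Here is a standalone restatement of that check for illustration; the helper name below is hypothetical, and the pattern list mirrors the diff.

# Keyword heuristic mirroring is_network_cancelled(); hypothetical helper name.
NETWORK_PATTERNS = [
    'connection reset', 'connection refused', 'connection closed',
    'network unreachable', 'broken pipe', 'socket closed', 'eof', 'transport',
]

def looks_like_network_cancel(error_message, debug_string=''):
    text = f"{error_message} {debug_string}".lower()
    return any(pattern in text for pattern in NETWORK_PATTERNS)

print(looks_like_network_cancel("Connection reset by peer"))     # True
print(looks_like_network_cancel("Cancelled before completion"))  # False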
|