tamar-model-client 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/async_client.py +580 -540
- tamar_model_client/circuit_breaker.py +140 -0
- tamar_model_client/core/__init__.py +40 -0
- tamar_model_client/core/base_client.py +221 -0
- tamar_model_client/core/http_fallback.py +249 -0
- tamar_model_client/core/logging_setup.py +194 -0
- tamar_model_client/core/request_builder.py +221 -0
- tamar_model_client/core/response_handler.py +136 -0
- tamar_model_client/core/utils.py +171 -0
- tamar_model_client/error_handler.py +315 -0
- tamar_model_client/exceptions.py +388 -7
- tamar_model_client/json_formatter.py +36 -1
- tamar_model_client/sync_client.py +590 -486
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/METADATA +289 -61
- tamar_model_client-0.1.21.dist-info/RECORD +35 -0
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/top_level.txt +1 -0
- tests/__init__.py +1 -0
- tests/stream_hanging_analysis.py +357 -0
- tests/test_google_azure_final.py +448 -0
- tests/test_simple.py +235 -0
- tamar_model_client-0.1.19.dist-info/RECORD +0 -22
- {tamar_model_client-0.1.19.dist-info → tamar_model_client-0.1.21.dist-info}/WHEEL +0 -0
@@ -0,0 +1,171 @@
|
|
1
|
+
"""
|
2
|
+
Common utility functions for Tamar Model Client
|
3
|
+
|
4
|
+
This module contains shared utility functions used by both sync and async clients.
|
5
|
+
Most functions in this module are pure; the request-ID helpers read and write a context variable.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import base64
|
9
|
+
import uuid
|
10
|
+
from typing import Any, Iterable
|
11
|
+
from contextvars import ContextVar
|
12
|
+
|
13
|
+
from openai import NOT_GIVEN
|
14
|
+
from pydantic import BaseModel
|
15
|
+
|
16
|
+
# Keep the request ID in a ContextVar so it can be traced in both async and sync call chains
|
17
|
+
_request_id: ContextVar[str] = ContextVar('request_id', default='-')
|
18
|
+
|
19
|
+
|
20
|
+
def is_effective_value(value) -> bool:
    """Report whether ``value`` carries any meaningful content.

    Used to drop empty, ``None`` and ``NOT_GIVEN`` parameters so that only
    real values are sent to the server.

    Args:
        value: The value to inspect.

    Returns:
        bool: ``True`` when the value is effective, ``False`` otherwise.

    Values treated as ineffective:
        - ``None`` and ``NOT_GIVEN``
        - strings that are empty or contain only whitespace
        - empty ``bytes``
        - dicts in which no value is effective (including ``{}``)
        - lists in which no element is effective (including ``[]``)
    """
    if value is None or value is NOT_GIVEN:
        return False

    if isinstance(value, str):
        # Whitespace-only text counts as empty.
        return bool(value.strip())

    if isinstance(value, bytes):
        return bool(value)

    if isinstance(value, dict):
        # One effective value anywhere inside is enough.
        return any(is_effective_value(entry) for entry in value.values())

    if isinstance(value, list):
        return any(is_effective_value(element) for element in value)

    # Every other type (int, float, bool, ...) is effective once it is not None.
    return True
|
65
|
+
|
66
|
+
|
67
|
+
def serialize_value(value):
    """Recursively convert ``value`` into plain data for the gRPC payload.

    Args:
        value: The value to serialize.

    Returns:
        The serialized value, or ``None`` when ``is_effective_value`` rejects
        the input.

    Conversions, checked in this order:
        - ``BaseModel`` -> result of ``model_dump()``, serialized again
        - object with a callable ``dict`` attribute -> result of ``dict()``,
          serialized again
        - ``dict`` -> dict with every value serialized (keys are untouched)
        - any other iterable except ``str``/``bytes`` -> list of serialized items
        - ``bytes`` -> the text ``"bytes:"`` followed by the base64 encoding
        - anything else -> returned unchanged
    """
    if not is_effective_value(value):
        return None

    if isinstance(value, BaseModel):
        return serialize_value(value.model_dump())

    dict_method = getattr(value, "dict", None)
    if callable(dict_method):
        return serialize_value(dict_method())

    if isinstance(value, dict):
        return {key: serialize_value(item) for key, item in value.items()}

    # Lists are iterables too, so one check covers lists, tuples, sets, generators...
    is_text_or_bytes = isinstance(value, (str, bytes))
    if isinstance(value, Iterable) and not is_text_or_bytes:
        return [serialize_value(item) for item in value]

    if isinstance(value, bytes):
        encoded = base64.b64encode(value).decode('utf-8')
        return f"bytes:{encoded}"

    return value
|
100
|
+
|
101
|
+
|
102
|
+
def remove_none_from_dict(data: Any) -> Any:
    """Recursively drop dict entries whose value is ``None``.

    Dicts are rebuilt without their ``None``-valued keys, and lists are
    walked so that dicts nested inside them are cleaned as well. ``None``
    items that sit directly in a list are kept.

    Args:
        data: The structure to clean (dict, list, or any other value).

    Returns:
        A cleaned copy for dicts and lists; any other value is returned as is.

    Example:
        >>> remove_none_from_dict({"a": 1, "b": None, "c": {"d": None, "e": 2}})
        {'a': 1, 'c': {'e': 2}}
    """
    if isinstance(data, list):
        # Recurse into every element; the list keeps its length.
        return [remove_none_from_dict(element) for element in data]

    if isinstance(data, dict):
        return {
            name: remove_none_from_dict(entry)
            for name, entry in data.items()
            if entry is not None
        }

    # Scalars and other types pass through untouched.
    return data
|
134
|
+
|
135
|
+
|
136
|
+
def generate_request_id():
    """Create a new unique request ID.

    The ID is a random UUID4, meant for tracing and debugging a single
    request through its lifetime.

    Returns:
        str: A UUID string of the form "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx".
    """
    request_uuid = uuid.uuid4()
    return str(request_uuid)
|
147
|
+
|
148
|
+
|
149
|
+
def set_request_id(request_id: str):
    """Store ``request_id`` as the request ID of the current context.

    The value is written to the module's ``ContextVar``, so code running
    later in the same context can read it back (for example for log tracing).

    Args:
        request_id: The request ID string to store.
    """
    _request_id.set(request_id)
|
160
|
+
|
161
|
+
|
162
|
+
def get_request_id() -> str:
    """Return the request ID of the current context.

    Reads the module's ``ContextVar``; when no ID has been set in this
    context the variable's default, ``'-'``, is returned.

    Returns:
        str: The current request ID, or the default value.
    """
    current_id = _request_id.get()
    return current_id
|
@@ -0,0 +1,315 @@
|
|
1
|
+
"""
|
2
|
+
gRPC 错误处理器
|
3
|
+
|
4
|
+
提供统一的错误处理、恢复策略和重试逻辑。
|
5
|
+
"""
|
6
|
+
|
7
|
+
import asyncio
|
8
|
+
import random
|
9
|
+
import grpc
|
10
|
+
import logging
|
11
|
+
from typing import Optional, Dict, Any, Callable, Union
|
12
|
+
from collections import defaultdict
|
13
|
+
|
14
|
+
from .exceptions import (
|
15
|
+
ErrorContext, TamarModelException,
|
16
|
+
NetworkException, ConnectionException, TimeoutException,
|
17
|
+
AuthenticationException, TokenExpiredException, PermissionDeniedException,
|
18
|
+
ValidationException, InvalidParameterException,
|
19
|
+
RateLimitException, ProviderException,
|
20
|
+
ERROR_CATEGORIES, RETRY_POLICY, ErrorStats
|
21
|
+
)
|
22
|
+
|
23
|
+
|
24
|
+
logger = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
class GrpcErrorHandler:
    """Unified gRPC error handler.

    For each error it builds an ``ErrorContext``, writes a structured error
    log, records the status code in ``ErrorStats`` and returns the matching
    ``TamarModelException`` (sub)class instance.
    """

    def __init__(self, client_logger: Optional[logging.Logger] = None):
        """Create a handler.

        Args:
            client_logger: Logger to write to; the module-level ``logger`` is
                used when none is supplied.
        """
        self.error_stats = ErrorStats()
        self.logger = client_logger or logger

    def handle_error(self, error: Union[grpc.RpcError, Exception], context: dict) -> TamarModelException:
        """Run the unified error-handling flow for one failure.

        Steps:
            1. build the error context
            2. log the error
            3. update the error statistics
            4. pick the exception type
            5. return the exception (it is returned, not raised)

        Args:
            error: The gRPC (or other) error that occurred.
            context: Request context passed on to ``ErrorContext``.

        Returns:
            TamarModelException: The exception that represents this error.
        """
        error_context = ErrorContext(error, context)
        status_name = error_context.error_code.name if error_context.error_code else 'UNKNOWN'

        # The key facts of the error context are flattened into the log's
        # "data" field so they show up as structured fields.
        log_data = {
            "log_type": "info",
            "request_id": error_context.request_id,
            "data": {
                "error_code": status_name,
                "error_message": error_context.error_message,
                "provider": error_context.provider,
                "model": error_context.model,
                "method": error_context.method,
                "retry_count": error_context.retry_count,
                "category": error_context._get_error_category(),
                "is_retryable": error_context._is_retryable(),
                "suggested_action": error_context._get_suggested_action(),
                "debug_string": error_context.error_debug_string,
                # Only asked for CANCELLED errors; None for every other code.
                "is_network_cancelled": (
                    error_context.is_network_cancelled()
                    if error_context.error_code == grpc.StatusCode.CANCELLED
                    else None
                ),
            },
        }
        self.logger.error(f"gRPC Error occurred: {status_name}", extra=log_data)

        # Errors without a status code are not counted.
        if error_context.error_code:
            self.error_stats.record_error(error_context.error_code)

        return self._create_exception(error_context)

    def _create_exception(self, error_context: ErrorContext) -> TamarModelException:
        """Build the exception that matches the context's status code.

        The categories are checked in a fixed order: auth, network,
        validation, RESOURCE_EXHAUSTED, provider, then the generic fallback.
        """
        error_code = error_context.error_code

        if not error_code:
            return TamarModelException(error_context)

        # Authentication / authorization errors
        if error_code in ERROR_CATEGORIES['AUTH']:
            if error_code == grpc.StatusCode.UNAUTHENTICATED:
                return TokenExpiredException(error_context)
            return PermissionDeniedException(error_context)

        # Network errors
        if error_code in ERROR_CATEGORIES['NETWORK']:
            if error_code == grpc.StatusCode.DEADLINE_EXCEEDED:
                return TimeoutException(error_context)
            return ConnectionException(error_context)

        # Validation errors
        if error_code in ERROR_CATEGORIES['VALIDATION']:
            return InvalidParameterException(error_context)

        # Resource errors
        if error_code == grpc.StatusCode.RESOURCE_EXHAUSTED:
            return RateLimitException(error_context)

        # Provider errors
        if error_code in ERROR_CATEGORIES['PROVIDER']:
            return ProviderException(error_context)

        # Anything else
        return TamarModelException(error_context)

    def get_error_stats(self) -> Dict[str, Any]:
        """Return the statistics reported by the underlying ``ErrorStats``."""
        stats = self.error_stats.get_stats()
        return stats

    def reset_stats(self):
        """Reset the error statistics."""
        self.error_stats.reset()
|
121
|
+
|
122
|
+
|
123
|
+
class ErrorRecoveryStrategy:
    """Error recovery strategy.

    Looks up the retry policy for an error's status code and runs the
    recovery handler named by the policy's ``action`` entry.
    """

    # Policy action name -> name of the handler method on this class.
    RECOVERY_ACTIONS = {
        'refresh_token': 'handle_token_refresh',
        'reconnect': 'handle_reconnect',
        'backoff': 'handle_backoff',
        'circuit_break': 'handle_circuit_break',
    }

    def __init__(self, client):
        """Remember the client whose logger and recovery hooks are used."""
        self.client = client

    async def recover_from_error(self, error_context: ErrorContext):
        """Run the recovery action configured for the error's status code.

        Does nothing when the context has no status code, when the policy
        defines no action, or when the action has no registered handler.
        """
        if not error_context.error_code:
            return

        policy = RETRY_POLICY.get(error_context.error_code, {})
        action = policy.get('action')
        if not action or action not in self.RECOVERY_ACTIONS:
            return

        recover = getattr(self, self.RECOVERY_ACTIONS[action])
        await recover(error_context)

    async def handle_token_refresh(self, error_context: ErrorContext):
        """Refresh the JWT token through the client, if it supports that."""
        self.client.logger.info("Attempting to refresh JWT token")
        # The client is expected to provide _refresh_jwt_token; skipped otherwise.
        if hasattr(self.client, '_refresh_jwt_token'):
            await self.client._refresh_jwt_token()

    async def handle_reconnect(self, error_context: ErrorContext):
        """Reconnect the channel through the client, if it supports that."""
        self.client.logger.info("Attempting to reconnect channel")
        # The client is expected to provide _reconnect_channel; skipped otherwise.
        if hasattr(self.client, '_reconnect_channel'):
            await self.client._reconnect_channel()

    async def handle_backoff(self, error_context: ErrorContext):
        """Sleep for the backoff time derived from the context's retry count."""
        pause = self._calculate_backoff(error_context.retry_count)
        await asyncio.sleep(pause)

    async def handle_circuit_break(self, error_context: ErrorContext):
        """Log that the circuit breaker fired.

        Only a warning is emitted; actual circuit-breaking logic could be
        added here.
        """
        self.client.logger.warning("Circuit breaker activated")

    def _calculate_backoff(self, retry_count: int) -> float:
        """Return the backoff time for ``retry_count``.

        Exponential backoff starting at 1.0 and capped at 60.0, plus a random
        jitter of up to 10% of the capped delay.
        """
        base_delay = 1.0
        max_delay = 60.0
        jitter_factor = 0.1

        capped = min(base_delay * (2 ** retry_count), max_delay)
        return capped + random.uniform(0, capped * jitter_factor)
|
182
|
+
|
183
|
+
|
184
|
+
class EnhancedRetryHandler:
    """Enhanced retry handler.

    Runs an async callable, retries gRPC failures according to
    ``RETRY_POLICY`` with exponential backoff, and raises a
    ``TamarModelException`` once the call cannot be retried any more.
    """

    def __init__(self, max_retries: int = 3, base_delay: float = 1.0):
        """Create a retry handler.

        Args:
            max_retries: Number of retries after the first attempt.
            base_delay: Backoff delay before the first retry; doubled for
                every further attempt.
        """
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.error_handler = GrpcErrorHandler()

    async def execute_with_retry(
            self,
            func: Callable,
            *args,
            context: Optional[Dict[str, Any]] = None,
            **kwargs
    ):
        """Await ``func`` and retry it on retryable gRPC errors.

        Args:
            func: The async callable to execute.
            *args: Positional arguments for ``func``.
            context: Request context; ``retry_count`` is written into it
                before every attempt.
            **kwargs: Keyword arguments for ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            TamarModelException: The wrapped error. The gRPC or Python
                exception that triggered it is attached as ``__cause__``.
        """
        context = context or {}
        last_exception = None
        # The final raise happens after the except block has ended, where
        # Python no longer links the original error automatically. It is kept
        # here so it can be chained explicitly.
        last_cause = None

        for attempt in range(self.max_retries + 1):
            try:
                context['retry_count'] = attempt
                return await func(*args, **kwargs)

            except (grpc.RpcError, grpc.aio.AioRpcError) as e:
                last_cause = e
                error_context = ErrorContext(e, context)

                if not self._should_retry(e, attempt):
                    # Not retryable, or the retry budget is used up.
                    last_exception = self.error_handler.handle_error(e, context)
                    break

                log_data = {
                    "log_type": "info",
                    "request_id": error_context.request_id,
                    "data": {
                        "error_code": error_context.error_code.name if error_context.error_code else 'UNKNOWN',
                        "error_message": error_context.error_message,
                        "retry_count": attempt,
                        "max_retries": self.max_retries,
                        "category": error_context._get_error_category(),
                        "is_retryable": True,  # we are retrying, so it is retryable
                        "method": error_context.method
                    }
                }
                logger.warning(
                    f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
                    extra=log_data
                )

                # Back off before the next attempt.
                if attempt < self.max_retries:
                    delay = self._calculate_backoff(attempt)
                    await asyncio.sleep(delay)

                # Every failed attempt is also logged and counted by the handler.
                last_exception = self.error_handler.handle_error(e, context)

            except Exception as e:
                # Non-gRPC errors are never retried: wrap and stop.
                last_cause = e
                context['retry_count'] = attempt
                error_context = ErrorContext(None, context)
                error_context.error_message = str(e)
                last_exception = TamarModelException(error_context)
                break

        if last_exception:
            raise last_exception from last_cause
        else:
            # Only reachable when the loop body never ran (negative max_retries).
            raise TamarModelException("Unknown error occurred")

    def _should_retry(self, error: grpc.RpcError, attempt: int) -> bool:
        """Decide whether ``error`` on attempt number ``attempt`` may be retried.

        Returns ``False`` once ``attempt`` reaches ``max_retries``. Otherwise
        the ``retryable`` entry of the status code's policy decides: ``True``,
        ``False`` (also the default), or ``'conditional'``, which defers to
        the error details.
        """
        if attempt >= self.max_retries:
            return False

        policy = RETRY_POLICY.get(error.code(), {})
        retryable = policy.get('retryable', False)

        if retryable == 'conditional':
            # Conditional retry: the error details decide.
            return self._check_conditional_retry(error)

        # Equality rather than truthiness on purpose: any unrecognised policy
        # value is treated as "do not retry".
        return retryable == True  # noqa: E712

    def _check_conditional_retry(self, error: grpc.RpcError) -> bool:
        """Return ``True`` when the error details contain a retryable keyword."""
        details = error.details()
        error_message = details.lower() if details else ""

        # Substrings that mark an internal error as transient.
        retryable_patterns = (
            'temporary', 'timeout', 'unavailable',
            'connection', 'network', 'try again',
        )
        return any(pattern in error_message for pattern in retryable_patterns)

    def _calculate_backoff(self, attempt: int) -> float:
        """Return the backoff time before the retry that follows ``attempt``.

        Exponential backoff from ``base_delay``, capped at 60.0, plus a
        random jitter of up to 10% of the capped delay.
        """
        max_delay = 60.0
        jitter_factor = 0.1

        delay = min(self.base_delay * (2 ** attempt), max_delay)
        jitter = random.uniform(0, delay * jitter_factor)
        return delay + jitter
|