tamar-model-client 0.1.19__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,171 @@
1
+ """
2
+ Common utility functions for Tamar Model Client
3
+
4
+ This module contains shared utility functions used by both sync and async clients.
5
+ All functions in this module are pure functions without side effects.
6
+ """
7
+
8
+ import base64
9
+ import uuid
10
+ from typing import Any, Iterable
11
+ from contextvars import ContextVar
12
+
13
+ from openai import NOT_GIVEN
14
+ from pydantic import BaseModel
15
+
16
+ # 使用 contextvars 管理请求ID,支持异步和同步上下文中的请求追踪
17
+ _request_id: ContextVar[str] = ContextVar('request_id', default='-')
18
+
19
+
20
+ def is_effective_value(value) -> bool:
21
+ """
22
+ 递归判断值是否为有效值
23
+
24
+ 用于过滤掉空值、None、NOT_GIVEN 等无意义的参数,
25
+ 确保只有有效的参数被发送到服务器。
26
+
27
+ Args:
28
+ value: 待检查的值
29
+
30
+ Returns:
31
+ bool: True 表示值有效,False 表示值无效
32
+
33
+ 处理的无效值类型:
34
+ - None 和 NOT_GIVEN
35
+ - 空字符串(仅包含空白字符)
36
+ - 空字节序列
37
+ - 空字典(所有值都无效)
38
+ - 空列表(所有元素都无效)
39
+ """
40
+ if value is None or value is NOT_GIVEN:
41
+ return False
42
+
43
+ if isinstance(value, str):
44
+ return value.strip() != ""
45
+
46
+ if isinstance(value, bytes):
47
+ return len(value) > 0
48
+
49
+ if isinstance(value, dict):
50
+ # 递归检查字典中的所有值
51
+ for v in value.values():
52
+ if is_effective_value(v):
53
+ return True
54
+ return False
55
+
56
+ if isinstance(value, list):
57
+ # 递归检查列表中的所有元素
58
+ for item in value:
59
+ if is_effective_value(item):
60
+ return True
61
+ return False
62
+
63
+ # 其他类型(int/float/bool)只要不是 None 就算有效
64
+ return True
65
+
66
+
67
+ def serialize_value(value):
68
+ """
69
+ 递归序列化值,处理各种复杂数据类型
70
+
71
+ 将 Pydantic 模型、字典、列表、字节等复杂类型转换为
72
+ 可以发送给 gRPC 服务的简单类型。
73
+
74
+ Args:
75
+ value: 待序列化的值
76
+
77
+ Returns:
78
+ 序列化后的值,如果值无效则返回 None
79
+
80
+ 支持的类型转换:
81
+ - BaseModel -> dict (通过 model_dump)
82
+ - bytes -> base64 字符串
83
+ - dict -> 递归处理所有键值对
84
+ - list -> 递归处理所有元素
85
+ - 其他类型 -> 直接返回
86
+ """
87
+ if not is_effective_value(value):
88
+ return None
89
+ if isinstance(value, BaseModel):
90
+ return serialize_value(value.model_dump())
91
+ if hasattr(value, "dict") and callable(value.dict):
92
+ return serialize_value(value.dict())
93
+ if isinstance(value, dict):
94
+ return {k: serialize_value(v) for k, v in value.items()}
95
+ if isinstance(value, list) or (isinstance(value, Iterable) and not isinstance(value, (str, bytes))):
96
+ return [serialize_value(v) for v in value]
97
+ if isinstance(value, bytes):
98
+ return f"bytes:{base64.b64encode(value).decode('utf-8')}"
99
+ return value
100
+
101
+
102
+ def remove_none_from_dict(data: Any) -> Any:
103
+ """
104
+ 递归清理数据结构中的 None 值
105
+
106
+ 遍历字典和列表,移除所有值为 None 的字段,
107
+ 确保发送给服务器的数据结构干净整洁。
108
+
109
+ Args:
110
+ data: 待清理的数据(dict、list 或其他类型)
111
+
112
+ Returns:
113
+ 清理后的数据结构
114
+
115
+ 示例:
116
+ >>> remove_none_from_dict({"a": 1, "b": None, "c": {"d": None, "e": 2}})
117
+ {"a": 1, "c": {"e": 2}}
118
+ """
119
+ if isinstance(data, dict):
120
+ new_dict = {}
121
+ for key, value in data.items():
122
+ if value is None:
123
+ continue
124
+ # 递归清理嵌套结构
125
+ cleaned_value = remove_none_from_dict(value)
126
+ new_dict[key] = cleaned_value
127
+ return new_dict
128
+ elif isinstance(data, list):
129
+ # 递归处理列表中的每个元素
130
+ return [remove_none_from_dict(item) for item in data]
131
+ else:
132
+ # 其他类型直接返回
133
+ return data
134
+
135
+
136
+ def generate_request_id():
137
+ """
138
+ 生成唯一的请求ID
139
+
140
+ 使用 UUID4 生成全局唯一的请求标识符,
141
+ 用于追踪和调试单个请求的生命周期。
142
+
143
+ Returns:
144
+ str: 格式为 "xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx" 的UUID字符串
145
+ """
146
+ return str(uuid.uuid4())
147
+
148
+
149
+ def set_request_id(request_id: str):
150
+ """
151
+ 设置当前上下文的请求ID
152
+
153
+ 在 ContextVar 中设置请求ID,使得在整个异步调用链中
154
+ 都能访问到同一个请求ID,便于日志追踪。
155
+
156
+ Args:
157
+ request_id: 要设置的请求ID字符串
158
+ """
159
+ _request_id.set(request_id)
160
+
161
+
162
+ def get_request_id() -> str:
163
+ """
164
+ 获取当前上下文的请求ID
165
+
166
+ 从 ContextVar 中获取当前的请求ID,如果没有设置则返回默认值 '-'
167
+
168
+ Returns:
169
+ str: 当前的请求ID或默认值
170
+ """
171
+ return _request_id.get()
@@ -0,0 +1,283 @@
1
+ """
2
+ gRPC 错误处理器
3
+
4
+ 提供统一的错误处理、恢复策略和重试逻辑。
5
+ """
6
+
7
+ import asyncio
8
+ import random
9
+ import grpc
10
+ import logging
11
+ from typing import Optional, Dict, Any, Callable, Union
12
+ from collections import defaultdict
13
+
14
+ from .exceptions import (
15
+ ErrorContext, TamarModelException,
16
+ NetworkException, ConnectionException, TimeoutException,
17
+ AuthenticationException, TokenExpiredException, PermissionDeniedException,
18
+ ValidationException, InvalidParameterException,
19
+ RateLimitException, ProviderException,
20
+ ERROR_CATEGORIES, RETRY_POLICY, ErrorStats
21
+ )
22
+
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class GrpcErrorHandler:
28
+ """统一的 gRPC 错误处理器"""
29
+
30
+ def __init__(self, client_logger: Optional[logging.Logger] = None):
31
+ self.logger = client_logger or logger
32
+ self.error_stats = ErrorStats()
33
+
34
+ def handle_error(self, error: Union[grpc.RpcError, Exception], context: dict) -> TamarModelException:
35
+ """
36
+ 统一错误处理流程:
37
+ 1. 创建错误上下文
38
+ 2. 记录错误日志
39
+ 3. 更新错误统计
40
+ 4. 决定错误类型
41
+ 5. 返回相应异常
42
+ """
43
+ error_context = ErrorContext(error, context)
44
+
45
+ # 记录详细错误日志
46
+ self.logger.error(
47
+ f"gRPC Error occurred: {error_context.error_code}",
48
+ extra=error_context.to_dict()
49
+ )
50
+
51
+ # 更新错误统计
52
+ if error_context.error_code:
53
+ self.error_stats.record_error(error_context.error_code)
54
+
55
+ # 根据错误类型返回相应异常
56
+ return self._create_exception(error_context)
57
+
58
+ def _create_exception(self, error_context: ErrorContext) -> TamarModelException:
59
+ """根据错误上下文创建相应的异常"""
60
+ error_code = error_context.error_code
61
+
62
+ if not error_code:
63
+ return TamarModelException(error_context)
64
+
65
+ # 认证相关错误
66
+ if error_code in ERROR_CATEGORIES['AUTH']:
67
+ if error_code == grpc.StatusCode.UNAUTHENTICATED:
68
+ return TokenExpiredException(error_context)
69
+ else:
70
+ return PermissionDeniedException(error_context)
71
+
72
+ # 网络相关错误
73
+ elif error_code in ERROR_CATEGORIES['NETWORK']:
74
+ if error_code == grpc.StatusCode.DEADLINE_EXCEEDED:
75
+ return TimeoutException(error_context)
76
+ else:
77
+ return ConnectionException(error_context)
78
+
79
+ # 验证相关错误
80
+ elif error_code in ERROR_CATEGORIES['VALIDATION']:
81
+ return InvalidParameterException(error_context)
82
+
83
+ # 资源相关错误
84
+ elif error_code == grpc.StatusCode.RESOURCE_EXHAUSTED:
85
+ return RateLimitException(error_context)
86
+
87
+ # 服务商相关错误
88
+ elif error_code in ERROR_CATEGORIES['PROVIDER']:
89
+ return ProviderException(error_context)
90
+
91
+ # 默认错误
92
+ else:
93
+ return TamarModelException(error_context)
94
+
95
+ def get_error_stats(self) -> Dict[str, Any]:
96
+ """获取错误统计信息"""
97
+ return self.error_stats.get_stats()
98
+
99
+ def reset_stats(self):
100
+ """重置错误统计"""
101
+ self.error_stats.reset()
102
+
103
+
104
+ class ErrorRecoveryStrategy:
105
+ """错误恢复策略"""
106
+
107
+ RECOVERY_ACTIONS = {
108
+ 'refresh_token': 'handle_token_refresh',
109
+ 'reconnect': 'handle_reconnect',
110
+ 'backoff': 'handle_backoff',
111
+ 'circuit_break': 'handle_circuit_break',
112
+ }
113
+
114
+ def __init__(self, client):
115
+ self.client = client
116
+
117
+ async def recover_from_error(self, error_context: ErrorContext):
118
+ """根据错误类型执行恢复动作"""
119
+ if not error_context.error_code:
120
+ return
121
+
122
+ policy = RETRY_POLICY.get(error_context.error_code, {})
123
+
124
+ if action := policy.get('action'):
125
+ if action in self.RECOVERY_ACTIONS:
126
+ handler = getattr(self, self.RECOVERY_ACTIONS[action])
127
+ await handler(error_context)
128
+
129
+ async def handle_token_refresh(self, error_context: ErrorContext):
130
+ """处理 Token 刷新"""
131
+ self.client.logger.info("Attempting to refresh JWT token")
132
+ # 这里需要客户端实现 _refresh_jwt_token 方法
133
+ if hasattr(self.client, '_refresh_jwt_token'):
134
+ await self.client._refresh_jwt_token()
135
+
136
+ async def handle_reconnect(self, error_context: ErrorContext):
137
+ """处理重连"""
138
+ self.client.logger.info("Attempting to reconnect channel")
139
+ # 这里需要客户端实现 _reconnect_channel 方法
140
+ if hasattr(self.client, '_reconnect_channel'):
141
+ await self.client._reconnect_channel()
142
+
143
+ async def handle_backoff(self, error_context: ErrorContext):
144
+ """处理退避等待"""
145
+ wait_time = self._calculate_backoff(error_context.retry_count)
146
+ await asyncio.sleep(wait_time)
147
+
148
+ async def handle_circuit_break(self, error_context: ErrorContext):
149
+ """处理熔断"""
150
+ self.client.logger.warning("Circuit breaker activated")
151
+ # 这里可以实现熔断逻辑
152
+ pass
153
+
154
+ def _calculate_backoff(self, retry_count: int) -> float:
155
+ """计算退避时间"""
156
+ base_delay = 1.0
157
+ max_delay = 60.0
158
+ jitter_factor = 0.1
159
+
160
+ delay = min(base_delay * (2 ** retry_count), max_delay)
161
+ jitter = random.uniform(0, delay * jitter_factor)
162
+ return delay + jitter
163
+
164
+
165
+ class EnhancedRetryHandler:
166
+ """增强的重试处理器"""
167
+
168
+ def __init__(self, max_retries: int = 3, base_delay: float = 1.0):
169
+ self.max_retries = max_retries
170
+ self.base_delay = base_delay
171
+ self.error_handler = GrpcErrorHandler()
172
+
173
+ async def execute_with_retry(
174
+ self,
175
+ func: Callable,
176
+ *args,
177
+ context: Optional[Dict[str, Any]] = None,
178
+ **kwargs
179
+ ):
180
+ """
181
+ 执行函数并处理重试
182
+
183
+ Args:
184
+ func: 要执行的函数
185
+ *args: 函数参数
186
+ context: 请求上下文信息
187
+ **kwargs: 函数关键字参数
188
+
189
+ Returns:
190
+ 函数执行结果
191
+
192
+ Raises:
193
+ TamarModelException: 包装后的异常
194
+ """
195
+ context = context or {}
196
+ last_exception = None
197
+
198
+ for attempt in range(self.max_retries + 1):
199
+ try:
200
+ context['retry_count'] = attempt
201
+ return await func(*args, **kwargs)
202
+
203
+ except (grpc.RpcError, grpc.aio.AioRpcError) as e:
204
+ # 创建错误上下文
205
+ error_context = ErrorContext(e, context)
206
+
207
+ # 判断是否可以重试
208
+ if not self._should_retry(e, attempt):
209
+ # 不可重试或已达到最大重试次数
210
+ last_exception = self.error_handler.handle_error(e, context)
211
+ break
212
+
213
+ # 记录重试日志
214
+ logger.warning(
215
+ f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
216
+ extra=error_context.to_dict()
217
+ )
218
+
219
+ # 执行退避等待
220
+ if attempt < self.max_retries:
221
+ delay = self._calculate_backoff(attempt)
222
+ await asyncio.sleep(delay)
223
+
224
+ last_exception = self.error_handler.handle_error(e, context)
225
+
226
+ except Exception as e:
227
+ # 非 gRPC 错误,直接包装抛出
228
+ context['retry_count'] = attempt
229
+ error_context = ErrorContext(None, context)
230
+ error_context.error_message = str(e)
231
+ last_exception = TamarModelException(error_context)
232
+ break
233
+
234
+ # 抛出最后的异常
235
+ if last_exception:
236
+ raise last_exception
237
+ else:
238
+ raise TamarModelException("Unknown error occurred")
239
+
240
+ def _should_retry(self, error: grpc.RpcError, attempt: int) -> bool:
241
+ """判断是否应该重试"""
242
+ if attempt >= self.max_retries:
243
+ return False
244
+
245
+ error_code = error.code()
246
+ policy = RETRY_POLICY.get(error_code, {})
247
+
248
+ # 检查基本重试策略
249
+ retryable = policy.get('retryable', False)
250
+ if retryable == False:
251
+ return False
252
+ elif retryable == True:
253
+ return True
254
+ elif retryable == 'conditional':
255
+ # 条件重试,需要检查错误详情
256
+ return self._check_conditional_retry(error)
257
+
258
+ return False
259
+
260
+ def _check_conditional_retry(self, error: grpc.RpcError) -> bool:
261
+ """检查条件重试"""
262
+ error_message = error.details().lower() if error.details() else ""
263
+
264
+ # 一些可重试的内部错误模式
265
+ retryable_patterns = [
266
+ 'temporary', 'timeout', 'unavailable',
267
+ 'connection', 'network', 'try again'
268
+ ]
269
+
270
+ for pattern in retryable_patterns:
271
+ if pattern in error_message:
272
+ return True
273
+
274
+ return False
275
+
276
+ def _calculate_backoff(self, attempt: int) -> float:
277
+ """计算退避时间"""
278
+ max_delay = 60.0
279
+ jitter_factor = 0.1
280
+
281
+ delay = min(self.base_delay * (2 ** attempt), max_delay)
282
+ jitter = random.uniform(0, delay * jitter_factor)
283
+ return delay + jitter