tamar-model-client 0.1.21__py3-none-any.whl → 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -33,7 +33,7 @@ from .core import (
     generate_request_id,
     set_request_id,
     get_protected_logger,
-    MAX_MESSAGE_LENGTH
+    MAX_MESSAGE_LENGTH, get_request_id
 )
 from .core.base_client import BaseClient
 from .core.request_builder import RequestBuilder
@@ -244,7 +244,16 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
         Raises:
             TamarModelException: when all retries have failed
         """
-        return await self.retry_handler.execute_with_retry(func, *args, **kwargs)
+        # Extract request_id from kwargs (if present) and remove it
+        request_id = kwargs.pop('request_id', None) or get_request_id()
+
+        # Build a context that carries the request_id
+        context = {
+            'method': func.__name__ if hasattr(func, '__name__') else 'unknown',
+            'client_version': 'async',
+            'request_id': request_id,
+        }
+        return await self.retry_handler.execute_with_retry(func, *args, context=context, **kwargs)
 
     async def _retry_request_stream(self, func, *args, **kwargs):
         """
@@ -260,10 +269,18 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
         Returns:
             AsyncIterator: streaming response iterator
         """
+        # Record the method start time
+        import time
+        method_start_time = time.time()
+
+        # Extract request_id from kwargs (if present) and remove it
+        request_id = kwargs.pop('request_id', None) or get_request_id()
+
         last_exception = None
         context = {
             'method': 'stream',
             'client_version': 'async',
+            'request_id': request_id,
         }
 
         for attempt in range(self.max_retries + 1):
@@ -283,10 +300,16 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
                 error_context = ErrorContext(e, context)
                 error_code = e.code()
                 policy = get_retry_policy(error_code)
-                retryable = policy.get('retryable', False)
 
-                should_retry = False
-                if attempt < self.max_retries:
+                # Check the per-error max_attempts configuration first
+                # max_attempts is the maximum number of retries (excluding the initial request)
+                error_max_attempts = policy.get('max_attempts', self.max_retries)
+                if attempt >= error_max_attempts:
+                    should_retry = False
+                elif attempt >= self.max_retries:
+                    should_retry = False
+                else:
+                    retryable = policy.get('retryable', False)
                     if retryable == True:
                         should_retry = True
                     elif retryable == 'conditional':
@@ -295,8 +318,11 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
                             should_retry = error_context.is_network_cancelled()
                         else:
                             should_retry = self._check_error_details_for_retry(e)
+                else:
+                    should_retry = False
 
                 if should_retry:
+                    current_duration = time.time() - method_start_time
                     log_data = {
                         "log_type": "info",
                         "request_id": context.get('request_id'),
@@ -305,7 +331,8 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
                             "retry_count": attempt,
                             "max_retries": self.max_retries,
                             "method": "stream"
-                        }
+                        },
+                        "duration": current_duration
                     }
                     logger.warning(
                         f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()} (will retry)",
@@ -317,6 +344,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
                     await asyncio.sleep(delay)
                 else:
                     # Do not retry, or the maximum number of retries has been reached
+                    current_duration = time.time() - method_start_time
                     log_data = {
                         "log_type": "info",
                         "request_id": context.get('request_id'),
@@ -326,12 +354,14 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
                             "max_retries": self.max_retries,
                             "method": "stream",
                             "will_retry": False
-                        }
+                        },
+                        "duration": current_duration
                     }
                     logger.error(
                         f"Stream failed: {e.code()} (no retry)",
                         extra=log_data
                     )
+                    context['duration'] = current_duration
                     last_exception = self.error_handler.handle_error(e, context)
                     break
 
@@ -454,7 +484,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
         chunk_count = 0
 
         # Obtain the stream generator via the retry logic
-        stream_generator = self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
+        stream_generator = self._retry_request_stream(self._stream, request, metadata, invoke_timeout, request_id=get_request_id())
 
         try:
             async for response in stream_generator:
@@ -609,7 +639,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
                 # For streaming responses, return the logging wrapper directly
                 return self._stream_with_logging(request, metadata, invoke_timeout, start_time, model_request)
             else:
-                result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
+                result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id)
 
                 # Log success for the non-streaming response
                 duration = time.time() - start_time
@@ -739,7 +769,8 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
                 self.stub.BatchInvoke,
                 batch_request,
                 metadata=metadata,
-                timeout=invoke_timeout
+                timeout=invoke_timeout,
+                request_id=request_id
             )
 
             # Build the response object
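
Taken together, the async-client hunks above let a request_id flow from invoke/BatchInvoke into the retry wrappers and, via the retry context, into every retry log line. Below is a minimal, self-contained sketch of that propagation pattern; the names (get_request_id, execute_with_retry, _retry_request) mirror the diff, but the bodies are simplified stand-ins rather than the package implementation.

```python
# Sketch only: simplified stand-ins for the names used in the diff above.
import asyncio
import contextvars

_request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id", default="-")

def get_request_id() -> str:
    return _request_id.get()

async def execute_with_retry(func, *args, context=None, **kwargs):
    # The real handler retries on gRPC errors; here we only surface the context.
    context = context or {}
    print(f"calling {context.get('method')} with request_id={context.get('request_id')}")
    return await func(*args, **kwargs)

async def _retry_request(func, *args, **kwargs):
    # Pop request_id so it is not forwarded to the wrapped call itself.
    request_id = kwargs.pop("request_id", None) or get_request_id()
    context = {
        "method": getattr(func, "__name__", "unknown"),
        "client_version": "async",
        "request_id": request_id,
    }
    return await execute_with_retry(func, *args, context=context, **kwargs)

async def _invoke_request(payload):
    return {"ok": True, "payload": payload}

async def main():
    _request_id.set("req-123")
    print(await _retry_request(_invoke_request, {"q": "ping"}))

asyncio.run(main())
```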
@@ -11,6 +11,7 @@ import logging
 from typing import Optional, Dict, Any, Callable, Union
 from collections import defaultdict
 
+from .core import get_protected_logger
 from .exceptions import (
     ErrorContext, TamarModelException,
     NetworkException, ConnectionException, TimeoutException,
@@ -20,17 +21,16 @@ from .exceptions import (
     ERROR_CATEGORIES, RETRY_POLICY, ErrorStats
 )
 
-
-logger = logging.getLogger(__name__)
+logger = get_protected_logger(__name__)
 
 
 class GrpcErrorHandler:
     """Unified gRPC error handler"""
-
+
     def __init__(self, client_logger: Optional[logging.Logger] = None):
         self.logger = client_logger or logger
         self.error_stats = ErrorStats()
-
+
     def handle_error(self, error: Union[grpc.RpcError, Exception], context: dict) -> TamarModelException:
         """
         Unified error-handling flow:
@@ -41,7 +41,7 @@ class GrpcErrorHandler:
         5. Return the corresponding exception
         """
         error_context = ErrorContext(error, context)
-
+
         # Log the detailed error
         # Flatten the key error_context fields into the log's data field
         log_data = {
@@ -61,60 +61,64 @@ class GrpcErrorHandler:
                 "is_network_cancelled": error_context.is_network_cancelled() if error_context.error_code == grpc.StatusCode.CANCELLED else None
             }
         }
-
+
+        # If the context carries a duration, add it to the log
+        if 'duration' in context:
+            log_data['duration'] = context['duration']
+
         self.logger.error(
             f"gRPC Error occurred: {error_context.error_code.name if error_context.error_code else 'UNKNOWN'}",
             extra=log_data
         )
-
+
         # Update the error statistics
         if error_context.error_code:
             self.error_stats.record_error(error_context.error_code)
-
+
         # Return the appropriate exception for the error type
         return self._create_exception(error_context)
-
+
     def _create_exception(self, error_context: ErrorContext) -> TamarModelException:
         """Create the appropriate exception from the error context"""
         error_code = error_context.error_code
-
+
         if not error_code:
             return TamarModelException(error_context)
-
+
         # Authentication-related errors
         if error_code in ERROR_CATEGORIES['AUTH']:
             if error_code == grpc.StatusCode.UNAUTHENTICATED:
                 return TokenExpiredException(error_context)
             else:
                 return PermissionDeniedException(error_context)
-
+
         # Network-related errors
         elif error_code in ERROR_CATEGORIES['NETWORK']:
             if error_code == grpc.StatusCode.DEADLINE_EXCEEDED:
                 return TimeoutException(error_context)
             else:
                 return ConnectionException(error_context)
-
+
         # Validation-related errors
         elif error_code in ERROR_CATEGORIES['VALIDATION']:
             return InvalidParameterException(error_context)
-
+
         # Resource-related errors
         elif error_code == grpc.StatusCode.RESOURCE_EXHAUSTED:
             return RateLimitException(error_context)
-
+
         # Provider-related errors
         elif error_code in ERROR_CATEGORIES['PROVIDER']:
             return ProviderException(error_context)
-
+
         # Default error
         else:
             return TamarModelException(error_context)
-
+
     def get_error_stats(self) -> Dict[str, Any]:
         """Get the error statistics"""
         return self.error_stats.get_stats()
-
+
     def reset_stats(self):
         """Reset the error statistics"""
         self.error_stats.reset()
@@ -122,60 +126,60 @@ class GrpcErrorHandler:
 
 class ErrorRecoveryStrategy:
     """Error recovery strategy"""
-
+
     RECOVERY_ACTIONS = {
         'refresh_token': 'handle_token_refresh',
         'reconnect': 'handle_reconnect',
         'backoff': 'handle_backoff',
         'circuit_break': 'handle_circuit_break',
     }
-
+
     def __init__(self, client):
         self.client = client
-
+
     async def recover_from_error(self, error_context: ErrorContext):
         """Execute a recovery action based on the error type"""
         if not error_context.error_code:
             return
-
+
         policy = RETRY_POLICY.get(error_context.error_code, {})
-
+
         if action := policy.get('action'):
             if action in self.RECOVERY_ACTIONS:
                 handler = getattr(self, self.RECOVERY_ACTIONS[action])
                 await handler(error_context)
-
+
     async def handle_token_refresh(self, error_context: ErrorContext):
         """Handle token refresh"""
         self.client.logger.info("Attempting to refresh JWT token")
         # The client is expected to implement a _refresh_jwt_token method
         if hasattr(self.client, '_refresh_jwt_token'):
             await self.client._refresh_jwt_token()
-
+
     async def handle_reconnect(self, error_context: ErrorContext):
         """Handle reconnection"""
         self.client.logger.info("Attempting to reconnect channel")
         # The client is expected to implement a _reconnect_channel method
         if hasattr(self.client, '_reconnect_channel'):
             await self.client._reconnect_channel()
-
+
     async def handle_backoff(self, error_context: ErrorContext):
         """Handle backoff waiting"""
         wait_time = self._calculate_backoff(error_context.retry_count)
         await asyncio.sleep(wait_time)
-
+
     async def handle_circuit_break(self, error_context: ErrorContext):
         """Handle circuit breaking"""
         self.client.logger.warning("Circuit breaker activated")
         # Circuit-breaker logic can be implemented here
         pass
-
+
     def _calculate_backoff(self, retry_count: int) -> float:
         """Calculate the backoff delay"""
         base_delay = 1.0
         max_delay = 60.0
         jitter_factor = 0.1
-
+
         delay = min(base_delay * (2 ** retry_count), max_delay)
         jitter = random.uniform(0, delay * jitter_factor)
         return delay + jitter
@@ -183,18 +187,18 @@ class ErrorRecoveryStrategy:
 
 class EnhancedRetryHandler:
     """Enhanced retry handler"""
-
+
     def __init__(self, max_retries: int = 3, base_delay: float = 1.0):
         self.max_retries = max_retries
         self.base_delay = base_delay
         self.error_handler = GrpcErrorHandler()
-
+
     async def execute_with_retry(
-        self,
-        func: Callable,
-        *args,
-        context: Optional[Dict[str, Any]] = None,
-        **kwargs
+            self,
+            func: Callable,
+            *args,
+            context: Optional[Dict[str, Any]] = None,
+            **kwargs
     ):
         """
         Execute a function and handle retries
@@ -211,24 +215,33 @@
         Raises:
             TamarModelException: the wrapped exception
         """
+        # Record the start time
+        import time
+        method_start_time = time.time()
+
         context = context or {}
         last_exception = None
-
+
         for attempt in range(self.max_retries + 1):
             try:
                 context['retry_count'] = attempt
                 return await func(*args, **kwargs)
-
+
             except (grpc.RpcError, grpc.aio.AioRpcError) as e:
                 # Create the error context
                 error_context = ErrorContext(e, context)
-
+
                 # Determine whether a retry is possible
                 if not self._should_retry(e, attempt):
                     # Not retryable, or the maximum number of retries has been reached
+                    current_duration = time.time() - method_start_time
+                    context['duration'] = current_duration
                     last_exception = self.error_handler.handle_error(e, context)
                     break
-
+
+                # Compute the elapsed time so far
+                current_duration = time.time() - method_start_time
+
                 # Log the retry
                 log_data = {
                     "log_type": "info",
@@ -241,20 +254,22 @@
                         "category": error_context._get_error_category(),
                         "is_retryable": True,  # since we are retrying, it is retryable
                         "method": error_context.method
-                    }
+                    },
+                    "duration": current_duration
                 }
                 logger.warning(
                     f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
                     extra=log_data
                 )
-
+
                 # Perform the backoff wait
                 if attempt < self.max_retries:
                     delay = self._calculate_backoff(attempt)
                     await asyncio.sleep(delay)
-
+
+                context['duration'] = current_duration
                 last_exception = self.error_handler.handle_error(e, context)
-
+
             except Exception as e:
                 # Non-gRPC error: wrap it and raise directly
                 context['retry_count'] = attempt
@@ -262,21 +277,28 @@
                 error_context.error_message = str(e)
                 last_exception = TamarModelException(error_context)
                 break
-
+
         # Raise the last exception
         if last_exception:
             raise last_exception
         else:
             raise TamarModelException("Unknown error occurred")
-
+
     def _should_retry(self, error: grpc.RpcError, attempt: int) -> bool:
         """Decide whether a retry should happen"""
-        if attempt >= self.max_retries:
-            return False
-
         error_code = error.code()
         policy = RETRY_POLICY.get(error_code, {})
-
+
+        # Check the per-error max_attempts configuration first
+        # max_attempts is the maximum number of retries (excluding the initial request)
+        error_max_attempts = policy.get('max_attempts', self.max_retries)
+        if attempt >= error_max_attempts:
+            return False
+
+        # Then check the global max_retries
+        if attempt >= self.max_retries:
+            return False
+
         # Check the basic retry policy
         retryable = policy.get('retryable', False)
         if retryable == False:
@@ -286,30 +308,30 @@
         elif retryable == 'conditional':
             # Conditional retry: the error details need to be checked
             return self._check_conditional_retry(error)
-
+
         return False
-
+
     def _check_conditional_retry(self, error: grpc.RpcError) -> bool:
         """Check whether a conditional retry applies"""
         error_message = error.details().lower() if error.details() else ""
-
+
         # Some retryable internal error patterns
         retryable_patterns = [
-            'temporary', 'timeout', 'unavailable',
+            'temporary', 'timeout', 'unavailable',
            'connection', 'network', 'try again'
         ]
-
+
         for pattern in retryable_patterns:
             if pattern in error_message:
                 return True
-
+
         return False
-
+
     def _calculate_backoff(self, attempt: int) -> float:
         """Calculate the backoff delay"""
         max_delay = 60.0
         jitter_factor = 0.1
-
+
         delay = min(self.base_delay * (2 ** attempt), max_delay)
         jitter = random.uniform(0, delay * jitter_factor)
-        return delay + jitter
+        return delay + jitter
@@ -77,7 +77,7 @@ RETRY_POLICY = {
     grpc.StatusCode.CANCELLED: {
         'retryable': True,
         'backoff': 'linear',  # linear backoff; network problems usually do not need exponential backoff
-        'max_attempts': 2,  # limit the number of retries to avoid excessive retrying
+        'max_attempts': 2,  # maximum number of retries (excluding the initial request), so 3 attempts in total
         'check_details': False  # do not inspect the details; retry uniformly
     },
     grpc.StatusCode.ABORTED: {
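
The reworded comment, together with the reworked _should_retry above, pins down the counting convention: max_attempts caps retries only, and the global max_retries still applies on top of it. A small, hedged sketch of that arithmetic (a standalone helper, not the package's code):

```python
# Standalone sketch of the retry-count arithmetic described above.
# `attempt` is 0 when the initial request fails, 1 when the first retry fails, and so on.
def may_retry(attempt: int, policy_max_attempts: int, global_max_retries: int,
              retryable: bool) -> bool:
    if attempt >= policy_max_attempts:   # per-error cap: counts retries only
        return False
    if attempt >= global_max_retries:    # global cap still applies
        return False
    return retryable

# CANCELLED with max_attempts=2 under a global max_retries=3:
# attempts 0 and 1 may be retried, attempt 2 may not -> at most 3 calls in total.
assert [may_retry(a, 2, 3, True) for a in range(4)] == [True, True, False, False]
```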
@@ -184,6 +184,37 @@ class ErrorContext:
             'DATA': '数据损坏或丢失,请检查输入数据',
         }
         return suggestions.get(self._get_error_category(), '未知错误,请联系技术支持')
+
+    def is_network_cancelled(self) -> bool:
+        """
+        Determine whether a CANCELLED error was caused by a network interruption
+
+        Returns:
+            bool: True if the CANCELLED error was caused by a network interruption
+        """
+        if self.error_code != grpc.StatusCode.CANCELLED:
+            return False
+
+        # Check whether the error message contains network-related keywords
+        error_msg = (self.error_message or '').lower()
+        debug_msg = (self.error_debug_string or '').lower()
+
+        network_patterns = [
+            'connection reset',
+            'connection refused',
+            'connection closed',
+            'network unreachable',
+            'broken pipe',
+            'socket closed',
+            'eof',
+            'transport'
+        ]
+
+        for pattern in network_patterns:
+            if pattern in error_msg or pattern in debug_msg:
+                return True
+
+        return False
 
 
 # ===== Exception class hierarchy =====
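
The new is_network_cancelled() helper decides whether a CANCELLED status should be treated as a transient network failure by substring-matching the error text. A standalone sketch of that check follows; the pattern list mirrors the diff, while the function and its signature are illustrative only (the real method reads error_message and error_debug_string from the ErrorContext).

```python
# Sketch of the keyword check performed by is_network_cancelled() above.
NETWORK_PATTERNS = (
    "connection reset", "connection refused", "connection closed",
    "network unreachable", "broken pipe", "socket closed", "eof", "transport",
)

def looks_like_network_cancel(error_message: str, debug_string: str = "") -> bool:
    # Case-insensitive substring match over both message fields.
    text = f"{error_message} {debug_string}".lower()
    return any(pattern in text for pattern in NETWORK_PATTERNS)

print(looks_like_network_cancel("Connection reset by peer"))      # True
print(looks_like_network_cancel("client cancelled the request"))  # False
```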