tamar-model-client 0.1.27__py3-none-any.whl → 0.1.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -67,7 +67,7 @@ class GrpcErrorHandler:
             log_data['duration'] = context['duration']
 
         self.logger.error(
-            f"gRPC Error occurred: {error_context.error_code.name if error_context.error_code else 'UNKNOWN'}",
+            f"gRPC Error occurred: {error_context.error_code.name if error_context.error_code else 'UNKNOWN'}",
             extra=log_data
         )
 
@@ -151,14 +151,14 @@ class ErrorRecoveryStrategy:
 
     async def handle_token_refresh(self, error_context: ErrorContext):
         """Handle a token refresh"""
-        self.client.logger.info("Attempting to refresh JWT token")
+        self.client.logger.info("🔄 Attempting to refresh JWT token")
         # The client is expected to implement a _refresh_jwt_token method
         if hasattr(self.client, '_refresh_jwt_token'):
             await self.client._refresh_jwt_token()
 
     async def handle_reconnect(self, error_context: ErrorContext):
         """Handle a reconnect"""
-        self.client.logger.info("Attempting to reconnect channel")
+        self.client.logger.info("🔄 Attempting to reconnect channel")
         # The client is expected to implement a _reconnect_channel method
         if hasattr(self.client, '_reconnect_channel'):
             await self.client._reconnect_channel()
@@ -170,7 +170,7 @@ class ErrorRecoveryStrategy:
 
     async def handle_circuit_break(self, error_context: ErrorContext):
         """Handle circuit breaking"""
-        self.client.logger.warning("Circuit breaker activated")
+        self.client.logger.warning("⚠️ Circuit breaker activated")
         # Circuit-breaking logic can be implemented here
         pass
 
@@ -230,19 +230,107 @@ class EnhancedRetryHandler:
             except (grpc.RpcError, grpc.aio.AioRpcError) as e:
                 # Create the error context
                 error_context = ErrorContext(e, context)
+                current_duration = time.time() - method_start_time
+                context['duration'] = current_duration
 
                 # Decide whether this call can be retried
-                if not self._should_retry(e, attempt):
+                should_retry = self._should_retry(e, attempt)
+
+                # Check whether a fast fallback should be attempted (requires a client reference injected from outside)
+                should_try_fallback = False
+                if hasattr(self.error_handler, 'client') and hasattr(self.error_handler.client, '_should_try_fallback'):
+                    should_try_fallback = self.error_handler.client._should_try_fallback(e.code(), attempt)
+
+                if should_try_fallback:
+                    # Attempt a fast fallback to HTTP
+                    logger.warning(
+                        f"🚀 Fast fallback triggered for {e.code().name} after {attempt + 1} attempts",
+                        extra={
+                            "log_type": "fast_fallback",
+                            "request_id": error_context.request_id,
+                            "data": {
+                                "error_code": e.code().name,
+                                "attempt": attempt,
+                                "fallback_reason": "immediate" if hasattr(self.error_handler.client, 'immediate_fallback_errors') and e.code() in self.error_handler.client.immediate_fallback_errors else "after_retries"
+                            }
+                        }
+                    )
+
+                    try:
+                        # Attempt the HTTP fallback (the required parameters come from the context)
+                        if hasattr(self.error_handler, 'client'):
+                            # Check whether this is a batch request
+                            if hasattr(self.error_handler.client, '_current_batch_request'):
+                                batch_request = self.error_handler.client._current_batch_request
+                                origin_request_id = getattr(self.error_handler.client, '_current_origin_request_id', None)
+                                timeout = context.get('timeout')
+                                request_id = context.get('request_id')
+
+                                # Attempt the batch HTTP fallback
+                                result = await self.error_handler.client._invoke_batch_http_fallback(batch_request, timeout, request_id, origin_request_id)
+                            elif hasattr(self.error_handler.client, '_current_model_request'):
+                                model_request = self.error_handler.client._current_model_request
+                                origin_request_id = getattr(self.error_handler.client, '_current_origin_request_id', None)
+                                timeout = context.get('timeout')
+                                request_id = context.get('request_id')
+
+                                # Attempt the HTTP fallback
+                                result = await self.error_handler.client._invoke_http_fallback(model_request, timeout, request_id, origin_request_id)
+
+                        logger.info(
+                            f"✅ Fast fallback to HTTP succeeded",
+                            extra={
+                                "log_type": "fast_fallback_success",
+                                "request_id": request_id,
+                                "data": {
+                                    "grpc_attempts": attempt + 1,
+                                    "fallback_duration": time.time() - method_start_time
+                                }
+                            }
+                        )
+
+                        return result
+                    except Exception as fallback_error:
+                        # Fallback failed; log it but continue with the original retry logic
+                        logger.warning(
+                            f"⚠️ Fast fallback to HTTP failed: {str(fallback_error)}",
+                            extra={
+                                "log_type": "fast_fallback_failed",
+                                "request_id": error_context.request_id,
+                                "data": {
+                                    "fallback_error": str(fallback_error),
+                                    "will_continue_grpc_retry": should_retry and attempt < self.max_retries
+                                }
+                            }
+                        )
+
+                if not should_retry:
                     # Not retryable or the maximum number of retries has been reached
-                    current_duration = time.time() - method_start_time
-                    context['duration'] = current_duration
+                    # Log the final failure
+                    log_data = {
+                        "log_type": "info",
+                        "request_id": error_context.request_id,
+                        "data": {
+                            "error_code": error_context.error_code.name if error_context.error_code else 'UNKNOWN',
+                            "error_message": error_context.error_message,
+                            "retry_count": attempt,
+                            "max_retries": self.max_retries,
+                            "category": error_context._get_error_category(),
+                            "is_retryable": False,
+                            "method": error_context.method,
+                            "final_failure": True
+                        },
+                        "duration": current_duration
+                    }
+                    error_detail = f" - {error_context.error_message}" if error_context.error_message else ""
+                    logger.warning(
+                        f"⚠️ Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (no more retries)",
+                        extra=log_data
+                    )
                     last_exception = self.error_handler.handle_error(e, context)
                     break
 
-                # Compute the elapsed time so far
-                current_duration = time.time() - method_start_time
-
-                # Log the retry
+                # Retryable; log the retry
                 log_data = {
                     "log_type": "info",
                     "request_id": error_context.request_id,
@@ -252,13 +340,16 @@ class EnhancedRetryHandler:
                         "retry_count": attempt,
                         "max_retries": self.max_retries,
                         "category": error_context._get_error_category(),
-                        "is_retryable": True,  # since it is being retried, it is retryable
-                        "method": error_context.method
+                        "is_retryable": True,
+                        "method": error_context.method,
+                        "will_retry": True,
+                        "fallback_attempted": should_try_fallback
                     },
                     "duration": current_duration
                 }
+                error_detail = f" - {error_context.error_message}" if error_context.error_message else ""
                 logger.warning(
-                    f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
+                    f"🔄 Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (will retry)",
                     extra=log_data
                 )
 
@@ -267,8 +358,8 @@ class EnhancedRetryHandler:
                     delay = self._calculate_backoff(attempt)
                     await asyncio.sleep(delay)
 
-                context['duration'] = current_duration
-                last_exception = self.error_handler.handle_error(e, context)
+                # Save the exception for later use
+                last_exception = e
 
             except Exception as e:
                 # Non-gRPC error; wrap and raise immediately
@@ -280,7 +371,11 @@ class EnhancedRetryHandler:
 
         # Raise the last exception
         if last_exception:
-            raise last_exception
+            if isinstance(last_exception, TamarModelException):
+                raise last_exception
+            else:
+                # Raw gRPC exceptions need to be wrapped
+                raise self.error_handler.handle_error(last_exception, context)
         else:
             raise TamarModelException("Unknown error occurred")
 
@@ -57,5 +57,14 @@ class JSONFormatter(logging.Formatter):
         if hasattr(record, "trace"):
             log_data["trace"] = getattr(record, "trace")
 
+        # Add exception information (if any)
+        if record.exc_info:
+            import traceback
+            log_data["exception"] = {
+                "type": record.exc_info[0].__name__ if record.exc_info[0] else None,
+                "message": str(record.exc_info[1]) if record.exc_info[1] else None,
+                "traceback": traceback.format_exception(*record.exc_info)
+            }
+
         # Use the safe JSON encoder
         return json.dumps(log_data, ensure_ascii=False, cls=SafeJSONEncoder)
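Note: a minimal sketch of how the new exc_info handling above would be exercised, assuming JSONFormatter can be attached to a standard logging handler (the exact import path inside tamar-model-client is not shown in this diff):

    import logging

    # from ... import JSONFormatter  # import path not shown in this diff
    handler = logging.StreamHandler()
    handler.setFormatter(JSONFormatter())
    demo_logger = logging.getLogger("jsonformatter-demo")
    demo_logger.addHandler(handler)

    try:
        1 / 0
    except ZeroDivisionError:
        # With the hunk above applied, the emitted JSON now carries an
        # "exception" object with "type", "message", and "traceback" fields.
        demo_logger.error("division failed", exc_info=True)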
@@ -31,6 +31,7 @@ import grpc
 from .core import (
     generate_request_id,
     set_request_id,
+    set_origin_request_id,
     get_protected_logger,
     MAX_MESSAGE_LENGTH,
     get_request_id,
@@ -158,7 +159,7 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
         # If the channel exists but is unhealthy, log it
         if self.channel and self.stub:
             logger.warning(
-                "Channel exists but unhealthy, will recreate",
+                "⚠️ Channel exists but unhealthy, will recreate",
                 extra={
                     "log_type": "channel_recreate",
                     "data": {
@@ -186,7 +187,7 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
                                 "data": {"tls_enabled": True, "server_address": self.server_address}})
         else:
             self.channel = grpc.insecure_channel(
-                self.server_address,
+                f"dns:///{self.server_address}",
                 options=options
             )
             logger.info("🔓 Using insecure gRPC channel (TLS disabled)",
@@ -237,7 +238,7 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
             # If the channel is shut down or failing, it needs to be rebuilt
             if state in [grpc.ChannelConnectivity.SHUTDOWN,
                          grpc.ChannelConnectivity.TRANSIENT_FAILURE]:
-                logger.warning(f"Channel in unhealthy state: {state}",
+                logger.warning(f"⚠️ Channel in unhealthy state: {state}",
                                extra={"log_type": "info",
                                       "data": {"channel_state": str(state)}})
                 return False
@@ -245,7 +246,7 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
             # If there have been several recent errors, the channel also needs to be rebuilt
             if self._channel_error_count > 3 and self._last_channel_error_time:
                 if time.time() - self._last_channel_error_time < 60:  # within the last 60 seconds
-                    logger.warning("Too many channel errors recently, marking as unhealthy",
+                    logger.warning("⚠️ Too many channel errors recently, marking as unhealthy",
                                    extra={"log_type": "info",
                                           "data": {"error_count": self._channel_error_count}})
                     return False
@@ -253,7 +254,7 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
             return True
 
         except Exception as e:
-            logger.error(f"Error checking channel health: {e}",
+            logger.error(f"Error checking channel health: {e}",
                          extra={"log_type": "info",
                                 "data": {"error": str(e)}})
             return False
@@ -269,10 +270,10 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
         if self.channel:
             try:
                 self.channel.close()
-                logger.info("Closed unhealthy channel",
+                logger.info("🔚 Closed unhealthy channel",
                             extra={"log_type": "info"})
             except Exception as e:
-                logger.warning(f"Error closing channel: {e}",
+                logger.warning(f"⚠️ Error closing channel: {e}",
                                extra={"log_type": "info"})
 
         # Clear the references
@@ -283,7 +284,7 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
         self._channel_error_count = 0
         self._last_channel_error_time = None
 
-        logger.info("Recreating gRPC channel...",
+        logger.info("🔄 Recreating gRPC channel...",
                     extra={"log_type": "info"})
 
     def _record_channel_error(self, error: grpc.RpcError):
@@ -311,7 +312,7 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
 
         # Log detailed error information
         logger.warning(
-            f"Channel error recorded: {error.code().name}",
+            f"⚠️ Channel error recorded: {error.code().name}",
             extra={
                 "log_type": "channel_error",
                 "data": {
@@ -352,20 +353,15 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
             except grpc.RpcError as e:
                 # Use the new error-handling logic
                 context['retry_count'] = attempt
+                current_duration = time.time() - method_start_time
 
                 # Decide whether this call can be retried
                 should_retry = self._should_retry(e, attempt)
-                if not should_retry or attempt >= self.max_retries:
-                    # Not retryable or the maximum number of retries has been reached
-                    current_duration = time.time() - method_start_time
-                    context['duration'] = current_duration
-                    last_exception = self.error_handler.handle_error(e, context)
-                    break
-
-                # Compute the elapsed time so far
-                current_duration = time.time() - method_start_time
 
-                # Special handling for CANCELLED errors
+                # Record the channel error
+                self._record_channel_error(e)
+
+                # Special log handling for CANCELLED errors
                 if e.code() == grpc.StatusCode.CANCELLED:
                     channel_state = None
                     if self.channel:
@@ -375,7 +371,7 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
                             channel_state = "UNKNOWN"
 
                     logger.warning(
-                        f"CANCELLED error detected, channel state: {channel_state}",
+                        f"⚠️ CANCELLED error detected, channel state: {channel_state}",
                         extra={
                             "log_type": "cancelled_debug",
                             "request_id": context.get('request_id'),
@@ -389,20 +385,125 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
                         }
                     )
 
-                # Log the retry
+                # Check whether a fast fallback should be attempted
+                should_try_fallback = self._should_try_fallback(e.code(), attempt)
+
+                if should_try_fallback:
+                    # Attempt a fast fallback to HTTP
+                    logger.warning(
+                        f"🚀 Fast fallback triggered for {e.code().name} after {attempt + 1} attempts",
+                        extra={
+                            "log_type": "fast_fallback",
+                            "request_id": context.get('request_id'),
+                            "data": {
+                                "error_code": e.code().name,
+                                "attempt": attempt,
+                                "fallback_reason": "immediate" if e.code() in self.immediate_fallback_errors else "after_retries"
+                            }
+                        }
+                    )
+
+                    try:
+                        # Extract the parameters needed for the fallback from kwargs
+                        fallback_kwargs = kwargs.copy()
+
+                        # For _invoke_request, model_request needs to be recovered
+                        if func.__name__ == '_invoke_request' and len(args) >= 3:
+                            # args layout: (request, metadata, invoke_timeout)
+                            # model_request has to be recovered from the original parameters
+                            if hasattr(self, '_current_model_request'):
+                                model_request = self._current_model_request
+                                origin_request_id = getattr(self, '_current_origin_request_id', None)
+                                timeout = args[2] if len(args) > 2 else None
+                                request_id = context.get('request_id')
+
+                                # Attempt the HTTP fallback
+                                result = self._invoke_http_fallback(model_request, timeout, request_id, origin_request_id)
+                        # For BatchInvoke, the batch fallback has to be used
+                        elif func.__name__ == 'BatchInvoke' and hasattr(self, '_current_batch_request'):
+                            batch_request = self._current_batch_request
+                            origin_request_id = getattr(self, '_current_origin_request_id', None)
+                            timeout = fallback_kwargs.get('timeout')
+                            request_id = context.get('request_id')
+
+                            # Attempt the batch HTTP fallback
+                            result = self._invoke_batch_http_fallback(batch_request, timeout, request_id, origin_request_id)
+                        else:
+                            # Otherwise the fallback cannot be handled
+                            raise ValueError(f"Unable to perform HTTP fallback for {func.__name__}")
+
+                        logger.info(
+                            f"✅ Fast fallback to HTTP succeeded",
+                            extra={
+                                "log_type": "fast_fallback_success",
+                                "request_id": request_id,
+                                "data": {
+                                    "grpc_attempts": attempt + 1,
+                                    "fallback_duration": time.time() - method_start_time
+                                }
+                            }
+                        )
+
+                        return result
+                    except Exception as fallback_error:
+                        # Fallback failed; log it but continue with the original retry logic
+                        logger.warning(
+                            f"⚠️ Fast fallback to HTTP failed: {str(fallback_error)}",
+                            extra={
+                                "log_type": "fast_fallback_failed",
+                                "request_id": context.get('request_id'),
+                                "data": {
+                                    "fallback_error": str(fallback_error),
+                                    "will_continue_grpc_retry": should_retry and attempt < self.max_retries
+                                }
+                            }
+                        )
+
+                if not should_retry or attempt >= self.max_retries:
+                    # Not retryable or the maximum number of retries has been reached
+                    context['duration'] = current_duration
+
+                    # Log the final failure
+                    log_data = {
+                        "log_type": "info",
+                        "request_id": context.get('request_id'),
+                        "data": {
+                            "error_code": e.code().name if e.code() else 'UNKNOWN',
+                            "error_details": e.details() if hasattr(e, 'details') else '',
+                            "retry_count": attempt,
+                            "max_retries": self.max_retries,
+                            "method": context.get('method', 'unknown'),
+                            "final_failure": True
+                        },
+                        "duration": current_duration
+                    }
+                    error_detail = f" - {e.details()}" if e.details() else ""
+                    logger.warning(
+                        f"⚠️ Final attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (no more retries)",
+                        extra=log_data
+                    )
+
+                    last_exception = self.error_handler.handle_error(e, context)
+                    break
+
+                # Retryable; log the retry
                 log_data = {
                     "log_type": "info",
                     "request_id": context.get('request_id'),
                     "data": {
                         "error_code": e.code().name if e.code() else 'UNKNOWN',
+                        "error_details": e.details() if hasattr(e, 'details') else '',
                         "retry_count": attempt,
                         "max_retries": self.max_retries,
-                        "method": context.get('method', 'unknown')
+                        "method": context.get('method', 'unknown'),
+                        "will_retry": True,
+                        "fallback_attempted": should_try_fallback
                     },
                     "duration": current_duration
                 }
+                error_detail = f" - {e.details()}" if e.details() else ""
                 logger.warning(
-                    f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
+                    f"🔄 Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (will retry)",
                     extra=log_data
                 )
 
@@ -410,12 +511,9 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
                 if attempt < self.max_retries:
                     delay = self._calculate_backoff(attempt, e.code())
                     time.sleep(delay)
-
-                context['duration'] = current_duration
-                last_exception = self.error_handler.handle_error(e, context)
 
-                # Record the channel error
-                self._record_channel_error(e)
+                # Save the exception for later use
+                last_exception = e
 
            except Exception as e:
                 # Non-gRPC error; wrap and raise immediately
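Note: both retry paths above call self._should_try_fallback(...) and consult self.immediate_fallback_errors, neither of which appears elsewhere in this diff. The following is only an illustrative sketch of what such a helper could look like, based solely on the names referenced above; the actual implementation shipped in tamar-model-client may differ:

    # Hypothetical sketch, not the package's real code.
    # Assumes immediate_fallback_errors is a set of grpc.StatusCode values and that
    # an http_fallback_url must be configured before any HTTP fallback makes sense.
    def _should_try_fallback(self, status_code, attempt):
        if not self.http_fallback_url:
            return False
        # Some error codes trigger an immediate fallback ("immediate" in the logs above)...
        if status_code in self.immediate_fallback_errors:
            return True
        # ...others only once the gRPC retries are exhausted ("after_retries").
        return attempt >= self.max_retries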
@@ -589,6 +687,7 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
                         "request_id": context.get('request_id'),
                         "data": {
                             "error_code": e.code().name if e.code() else 'UNKNOWN',
+                            "error_details": e.details() if hasattr(e, 'details') else '',
                             "retry_count": attempt,
                             "max_retries": self.max_retries,
                             "method": "stream",
@@ -596,8 +695,9 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
                         },
                         "duration": current_duration
                     }
+                    error_detail = f" - {e.details()}" if e.details() else ""
                     logger.error(
-                        f"Stream failed: {e.code()} (no retry)",
+                        f"Stream failed: {e.code()}{error_detail} (no retry)",
                         extra=log_data
                     )
                     context['duration'] = current_duration
@@ -610,14 +710,16 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
                         "request_id": context.get('request_id'),
                         "data": {
                             "error_code": e.code().name if e.code() else 'UNKNOWN',
+                            "error_details": e.details() if hasattr(e, 'details') else '',
                             "retry_count": attempt,
                             "max_retries": self.max_retries,
                             "method": "stream"
                         },
                         "duration": current_duration
                     }
+                    error_detail = f" - {e.details()}" if e.details() else ""
                     logger.warning(
-                        f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()} (will retry)",
+                        f"🔄 Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (will retry)",
                         extra=log_data
                     )
 
@@ -803,7 +905,12 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
         if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
             if self.http_fallback_url:
                 logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback")
-                return self._invoke_http_fallback(model_request, timeout, request_id)
+                # origin_request_id has not been computed yet at this point, so compute it first
+                temp_origin_request_id = None
+                temp_request_id = request_id
+                if request_id:
+                    temp_request_id, temp_origin_request_id = self._request_id_manager.get_composite_id(request_id)
+                return self._invoke_http_fallback(model_request, timeout, temp_request_id, temp_origin_request_id)
 
         self._ensure_initialized()
 
@@ -823,6 +930,8 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
             request_id = generate_request_id()
 
         set_request_id(request_id)
+        if origin_request_id:
+            set_origin_request_id(origin_request_id)
         metadata = self._build_auth_metadata(request_id, origin_request_id)
 
         # Build the log data
@@ -874,7 +983,17 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
                     request_id=request_id
                 )
             else:
-                result = self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id)
+                # Store model_request and origin_request_id for the retry method to use
+                self._current_model_request = model_request
+                self._current_origin_request_id = origin_request_id
+                try:
+                    result = self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id)
+                finally:
+                    # Clean up the temporary storage
+                    if hasattr(self, '_current_model_request'):
+                        delattr(self, '_current_model_request')
+                    if hasattr(self, '_current_origin_request_id'):
+                        delattr(self, '_current_origin_request_id')
 
             # Log the success of the non-streaming response
             duration = time.time() - start_time
@@ -922,16 +1041,11 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
             if isinstance(e, grpc.RpcError):
                 self._record_channel_error(e)
 
-            # Record the failure and try to fall back (if circuit breaking is enabled)
+            # Record the failure (if circuit breaking is enabled)
             if self.resilient_enabled and self.circuit_breaker:
                 # Pass the error code to the circuit breaker for smarter failure accounting
                 error_code = e.code() if hasattr(e, 'code') else None
                 self.circuit_breaker.record_failure(error_code)
-
-                # Fall back if a fallback is possible
-                if self.http_fallback_url and self.circuit_breaker.should_fallback():
-                    logger.warning(f"🔻 gRPC failed, falling back to HTTP: {str(e)}")
-                    return self._invoke_http_fallback(model_request, timeout, request_id)
 
             raise e
         except Exception as e:
@@ -961,6 +1075,17 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
         Returns:
             BatchModelResponse: the result of the batch request
         """
+        # If circuit breaking is enabled and the breaker is open, go straight to HTTP
+        if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
+            if self.http_fallback_url:
+                logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback for batch request")
+                # origin_request_id has not been computed yet at this point, so compute it first
+                temp_origin_request_id = None
+                temp_request_id = request_id
+                if request_id:
+                    temp_request_id, temp_origin_request_id = self._request_id_manager.get_composite_id(request_id)
+                return self._invoke_batch_http_fallback(batch_request_model, timeout, temp_request_id, temp_origin_request_id)
+
         self._ensure_initialized()
 
         if not self.default_payload:
@@ -979,6 +1104,8 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
             request_id = generate_request_id()
 
         set_request_id(request_id)
+        if origin_request_id:
+            set_origin_request_id(origin_request_id)
         metadata = self._build_auth_metadata(request_id, origin_request_id)
 
         # Build the log data
@@ -1025,6 +1152,11 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
 
         try:
             invoke_timeout = timeout or self.default_invoke_timeout
+
+            # Save the batch request information for the fallback
+            self._current_batch_request = batch_request_model
+            self._current_origin_request_id = origin_request_id
+
             batch_response = self._retry_request(
                 self.stub.BatchInvoke,
                 batch_request,
@@ -1067,6 +1199,13 @@ class TamarModelClient(BaseClient, HttpFallbackMixin):
                                 "batch_size": len(batch_request_model.items)
                             }
                         })
+
+            # Record the failure (if circuit breaking is enabled)
+            if self.resilient_enabled and self.circuit_breaker:
+                # Pass the error code to the circuit breaker for smarter failure accounting
+                error_code = e.code() if hasattr(e, 'code') else None
+                self.circuit_breaker.record_failure(error_code)
+
             raise e
         except Exception as e:
             duration = time.time() - start_time
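Note: taken together, the client-side hunks above replace the old "fall back to HTTP only after a gRPC failure" branch with a three-stage path: a pre-check when the circuit breaker is already open, a fast fallback attempted from inside the retry loop, and plain failure recording afterwards. A compressed, purely illustrative sketch of that flow (method names are taken from the diff, but signatures and bodies are simplified and are not the package's literal code):

    import grpc

    def invoke_flow(self, model_request, timeout=None, request_id=None):
        # 1. Breaker already open: skip gRPC entirely and go straight to HTTP.
        if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
            if self.http_fallback_url:
                return self._invoke_http_fallback(model_request, timeout, request_id, None)

        try:
            # 2. Normal gRPC path; _retry_request may itself fall back to HTTP early
            #    via _should_try_fallback() while it is still retrying.
            return self._retry_request(self._invoke_request, model_request, timeout)
        except grpc.RpcError as e:
            # 3. Failures now only feed the circuit breaker; the old inline
            #    "gRPC failed, falling back to HTTP" branch was removed.
            if self.resilient_enabled and self.circuit_breaker:
                self.circuit_breaker.record_failure(e.code())
            raise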