tamar-model-client 0.1.27__py3-none-any.whl → 0.1.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -32,6 +32,7 @@ from grpc import RpcError
32
32
  from .core import (
33
33
  generate_request_id,
34
34
  set_request_id,
35
+ set_origin_request_id,
35
36
  get_protected_logger,
36
37
  MAX_MESSAGE_LENGTH,
37
38
  get_request_id,
@@ -102,7 +103,6 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
102
103
  self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
103
104
  self._channel_error_count = 0
104
105
  self._last_channel_error_time = None
105
- self._channel_lock = asyncio.Lock() # 异步锁
106
106
 
107
107
  # === Request ID 管理 ===
108
108
  self._request_id_manager = RequestIdManager()
@@ -113,6 +113,9 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
113
113
  base_delay=self.retry_delay
114
114
  )
115
115
 
116
+ # 设置client引用,用于快速降级
117
+ self.retry_handler.error_handler.client = self
118
+
116
119
  # 注册退出时的清理函数
117
120
  atexit.register(self._cleanup_atexit)
118
121
 
@@ -190,7 +193,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
190
193
  # 如果 channel 存在但不健康,记录日志
191
194
  if self.channel and self.stub:
192
195
  logger.warning(
193
- "Channel exists but unhealthy, will recreate",
196
+ "⚠️ Channel exists but unhealthy, will recreate",
194
197
  extra={
195
198
  "log_type": "channel_recreate",
196
199
  "data": {
@@ -218,7 +221,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
218
221
  "data": {"tls_enabled": True, "server_address": self.server_address}})
219
222
  else:
220
223
  self.channel = grpc.aio.insecure_channel(
221
- self.server_address,
224
+ f"dns:///{self.server_address}",
222
225
  options=options
223
226
  )
224
227
  logger.info("🔓 Using insecure gRPC channel (TLS disabled)",
@@ -268,7 +271,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
268
271
  # 如果处于关闭或失败状态,需要重建
269
272
  if state in [grpc.ChannelConnectivity.SHUTDOWN,
270
273
  grpc.ChannelConnectivity.TRANSIENT_FAILURE]:
271
- logger.warning(f"Channel in unhealthy state: {state}",
274
+ logger.warning(f"⚠️ Channel in unhealthy state: {state}",
272
275
  extra={"log_type": "info",
273
276
  "data": {"channel_state": str(state)}})
274
277
  return False
@@ -276,7 +279,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
276
279
  # 如果最近有多次错误,也需要重建
277
280
  if self._channel_error_count > 3 and self._last_channel_error_time:
278
281
  if time.time() - self._last_channel_error_time < 60: # 60秒内
279
- logger.warning("Too many channel errors recently, marking as unhealthy",
282
+ logger.warning("⚠️ Too many channel errors recently, marking as unhealthy",
280
283
  extra={"log_type": "info",
281
284
  "data": {"error_count": self._channel_error_count}})
282
285
  return False
@@ -284,7 +287,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
284
287
  return True
285
288
 
286
289
  except Exception as e:
287
- logger.error(f"Error checking channel health: {e}",
290
+ logger.error(f"Error checking channel health: {e}",
288
291
  extra={"log_type": "info",
289
292
  "data": {"error": str(e)}})
290
293
  return False
@@ -295,27 +298,26 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
295
298
 
296
299
  关闭旧的 channel 并创建新的连接
297
300
  """
298
- async with self._channel_lock:
299
- # 关闭旧 channel
300
- if self.channel:
301
- try:
302
- await self.channel.close()
303
- logger.info("Closed unhealthy channel",
304
- extra={"log_type": "info"})
305
- except Exception as e:
306
- logger.warning(f"Error closing channel: {e}",
307
- extra={"log_type": "info"})
308
-
309
- # 清空引用
310
- self.channel = None
311
- self.stub = None
312
-
313
- # 重置错误计数
314
- self._channel_error_count = 0
315
- self._last_channel_error_time = None
316
-
317
- logger.info("Recreating gRPC channel...",
318
- extra={"log_type": "info"})
301
+ # 关闭旧 channel
302
+ if self.channel:
303
+ try:
304
+ await self.channel.close()
305
+ logger.info("🔚 Closed unhealthy channel",
306
+ extra={"log_type": "info"})
307
+ except Exception as e:
308
+ logger.warning(f"⚠️ Error closing channel: {e}",
309
+ extra={"log_type": "info"})
310
+
311
+ # 清空引用
312
+ self.channel = None
313
+ self.stub = None
314
+
315
+ # 重置错误计数
316
+ self._channel_error_count = 0
317
+ self._last_channel_error_time = None
318
+
319
+ logger.info("🔄 Recreating gRPC channel...",
320
+ extra={"log_type": "info"})
319
321
 
320
322
  def _record_channel_error(self, error: grpc.RpcError):
321
323
  """
@@ -342,7 +344,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
342
344
 
343
345
  # 记录详细的错误信息
344
346
  logger.warning(
345
- f"Channel error recorded: {error.code().name}",
347
+ f"⚠️ Channel error recorded: {error.code().name}",
346
348
  extra={
347
349
  "log_type": "channel_error",
348
350
  "data": {
@@ -453,7 +455,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
453
455
  is_network_cancelled = error_context.is_network_cancelled()
454
456
 
455
457
  logger.warning(
456
- f"CANCELLED error in stream, channel state: {channel_state}",
458
+ f"⚠️ CANCELLED error in stream, channel state: {channel_state}",
457
459
  extra={
458
460
  "log_type": "cancelled_debug",
459
461
  "request_id": context.get('request_id'),
@@ -481,14 +483,16 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
481
483
  "request_id": context.get('request_id'),
482
484
  "data": {
483
485
  "error_code": e.code().name if e.code() else 'UNKNOWN',
486
+ "error_details": e.details() if hasattr(e, 'details') else '',
484
487
  "retry_count": attempt,
485
488
  "max_retries": self.max_retries,
486
489
  "method": "stream"
487
490
  },
488
491
  "duration": current_duration
489
492
  }
493
+ error_detail = f" - {e.details()}" if e.details() else ""
490
494
  logger.warning(
491
- f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()} (will retry)",
495
+ f"🔄 Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (will retry)",
492
496
  extra=log_data
493
497
  )
494
498
 
@@ -503,6 +507,7 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
503
507
  "request_id": context.get('request_id'),
504
508
  "data": {
505
509
  "error_code": e.code().name if e.code() else 'UNKNOWN',
510
+ "error_details": e.details() if hasattr(e, 'details') else '',
506
511
  "retry_count": attempt,
507
512
  "max_retries": self.max_retries,
508
513
  "method": "stream",
@@ -510,8 +515,9 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
510
515
  },
511
516
  "duration": current_duration
512
517
  }
513
- logger.error(
514
- f"Stream failed: {e.code()} (no retry)",
518
+ error_detail = f" - {e.details()}" if e.details() else ""
519
+ logger.warning(
520
+ f"⚠️ Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}{error_detail} (no more retries)",
515
521
  extra=log_data
516
522
  )
517
523
  context['duration'] = current_duration
@@ -739,7 +745,12 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
739
745
  if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
740
746
  if self.http_fallback_url:
741
747
  logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback")
742
- return await self._invoke_http_fallback(model_request, timeout, request_id)
748
+ # 在这里还没有计算origin_request_id,所以先计算
749
+ temp_origin_request_id = None
750
+ temp_request_id = request_id
751
+ if request_id:
752
+ temp_request_id, temp_origin_request_id = self._request_id_manager.get_composite_id(request_id)
753
+ return await self._invoke_http_fallback(model_request, timeout, temp_request_id, temp_origin_request_id)
743
754
 
744
755
  await self._ensure_initialized()
745
756
 
@@ -759,6 +770,8 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
759
770
  request_id = generate_request_id()
760
771
 
761
772
  set_request_id(request_id)
773
+ if origin_request_id:
774
+ set_origin_request_id(origin_request_id)
762
775
  metadata = self._build_auth_metadata(request_id, origin_request_id)
763
776
 
764
777
  # 构建日志数据
@@ -806,7 +819,17 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
806
819
  # 对于流式响应,直接返回带日志记录的包装器
807
820
  return self._stream_with_logging(request, metadata, invoke_timeout, start_time, model_request)
808
821
  else:
809
- result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id)
822
+ # 存储model_request和origin_request_id供重试方法使用
823
+ self._current_model_request = model_request
824
+ self._current_origin_request_id = origin_request_id
825
+ try:
826
+ result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id)
827
+ finally:
828
+ # 清理临时存储
829
+ if hasattr(self, '_current_model_request'):
830
+ delattr(self, '_current_model_request')
831
+ if hasattr(self, '_current_origin_request_id'):
832
+ delattr(self, '_current_origin_request_id')
810
833
 
811
834
  # 记录非流式响应的成功日志
812
835
  duration = time.time() - start_time
@@ -854,16 +877,11 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
854
877
  if isinstance(e, grpc.RpcError):
855
878
  self._record_channel_error(e)
856
879
 
857
- # 记录失败并尝试降级(如果启用了熔断)
880
+ # 记录失败(如果启用了熔断)
858
881
  if self.resilient_enabled and self.circuit_breaker:
859
882
  # 将错误码传递给熔断器,用于智能失败统计
860
883
  error_code = e.code() if hasattr(e, 'code') else None
861
884
  self.circuit_breaker.record_failure(error_code)
862
-
863
- # 如果可以降级,则降级
864
- if self.http_fallback_url and self.circuit_breaker.should_fallback():
865
- logger.warning(f"🔻 gRPC failed, falling back to HTTP: {str(e)}")
866
- return await self._invoke_http_fallback(model_request, timeout, request_id)
867
885
 
868
886
  raise e
869
887
  except Exception as e:
@@ -893,6 +911,17 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
893
911
  Returns:
894
912
  BatchModelResponse: 批量请求的结果
895
913
  """
914
+ # 如果启用了熔断且熔断器打开,直接走 HTTP
915
+ if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
916
+ if self.http_fallback_url:
917
+ logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback for batch request")
918
+ # 在这里还没有计算origin_request_id,所以先计算
919
+ temp_origin_request_id = None
920
+ temp_request_id = request_id
921
+ if request_id:
922
+ temp_request_id, temp_origin_request_id = self._request_id_manager.get_composite_id(request_id)
923
+ return await self._invoke_batch_http_fallback(batch_request_model, timeout, temp_request_id, temp_origin_request_id)
924
+
896
925
  await self._ensure_initialized()
897
926
 
898
927
  if not self.default_payload:
@@ -911,6 +940,8 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
911
940
  request_id = generate_request_id()
912
941
 
913
942
  set_request_id(request_id)
943
+ if origin_request_id:
944
+ set_origin_request_id(origin_request_id)
914
945
  metadata = self._build_auth_metadata(request_id, origin_request_id)
915
946
 
916
947
  # 构建日志数据
@@ -957,6 +988,11 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
957
988
 
958
989
  try:
959
990
  invoke_timeout = timeout or self.default_invoke_timeout
991
+
992
+ # 保存批量请求信息用于降级
993
+ self._current_batch_request = batch_request_model
994
+ self._current_origin_request_id = origin_request_id
995
+
960
996
  batch_response = await self._retry_request(
961
997
  self.stub.BatchInvoke,
962
998
  batch_request,
@@ -999,6 +1035,13 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
999
1035
  "batch_size": len(batch_request_model.items)
1000
1036
  }
1001
1037
  })
1038
+
1039
+ # 记录失败(如果启用了熔断)
1040
+ if self.resilient_enabled and self.circuit_breaker:
1041
+ # 将错误码传递给熔断器,用于智能失败统计
1042
+ error_code = e.code() if hasattr(e, 'code') else None
1043
+ self.circuit_breaker.record_failure(error_code)
1044
+
1002
1045
  raise e
1003
1046
  except Exception as e:
1004
1047
  duration = time.time() - start_time
@@ -101,9 +101,12 @@ class CircuitBreaker:
101
101
  logger.warning(
102
102
  f"🔻 Circuit breaker OPENED after {self.failure_count} failures",
103
103
  extra={
104
- "failure_count": self.failure_count,
105
- "threshold": self.failure_threshold,
106
- "trigger_error": error_code.name if error_code else "unknown"
104
+ "log_type": "info",
105
+ "data": {
106
+ "failure_count": self.failure_count,
107
+ "threshold": self.failure_threshold,
108
+ "trigger_error": error_code.name if error_code else "unknown"
109
+ }
107
110
  }
108
111
  )
109
112
 
@@ -10,7 +10,9 @@ from .utils import (
10
10
  remove_none_from_dict,
11
11
  generate_request_id,
12
12
  set_request_id,
13
- get_request_id
13
+ get_request_id,
14
+ set_origin_request_id,
15
+ get_origin_request_id
14
16
  )
15
17
 
16
18
  from .logging_setup import (
@@ -32,6 +34,8 @@ __all__ = [
32
34
  'generate_request_id',
33
35
  'set_request_id',
34
36
  'get_request_id',
37
+ 'set_origin_request_id',
38
+ 'get_origin_request_id',
35
39
  # Logging
36
40
  'setup_logger',
37
41
  'RequestIdFilter',
@@ -6,8 +6,7 @@ and configuration management for both sync and async clients.
6
6
  """
7
7
 
8
8
  import os
9
- import logging
10
- from typing import Optional, Dict, Any
9
+ from typing import Optional
11
10
  from abc import ABC, abstractmethod
12
11
 
13
12
  from ..auth import JWTAuthHandler
@@ -25,7 +24,7 @@ class BaseClient(ABC):
25
24
  - 连接选项构建
26
25
  - 错误处理器初始化
27
26
  """
28
-
27
+
29
28
  def __init__(
30
29
  self,
31
30
  server_address: Optional[str] = None,
@@ -57,40 +56,43 @@ class BaseClient(ABC):
57
56
  self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
58
57
  if not self.server_address:
59
58
  raise ValueError("Server address must be provided via argument or environment variable.")
60
-
59
+
61
60
  # 默认调用超时时间
62
61
  self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))
63
-
62
+
64
63
  # === JWT 认证配置 ===
65
64
  self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
66
65
  self.jwt_handler = JWTAuthHandler(self.jwt_secret_key) if self.jwt_secret_key else None
67
66
  self.jwt_token = jwt_token # 用户传入的预生成 Token(可选)
68
67
  self.default_payload = default_payload
69
68
  self.token_expires_in = token_expires_in
70
-
69
+
71
70
  # === TLS/Authority 配置 ===
72
71
  self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
73
72
  self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")
74
-
73
+
75
74
  # === 重试配置 ===
76
75
  self.max_retries = max_retries if max_retries is not None else int(
77
76
  os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 6))
78
77
  self.retry_delay = retry_delay if retry_delay is not None else float(
79
78
  os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))
80
-
79
+
81
80
  # === 日志配置 ===
82
81
  self.logger = get_protected_logger(logger_name or __name__)
83
-
82
+
84
83
  # === 错误处理器 ===
85
84
  self.error_handler = GrpcErrorHandler(self.logger)
86
85
  self.recovery_strategy = ErrorRecoveryStrategy(self)
87
-
86
+
88
87
  # === 连接状态 ===
89
88
  self._closed = False
90
-
89
+
91
90
  # === 熔断降级配置 ===
92
91
  self._init_resilient_features()
93
-
92
+
93
+ # === 快速降级配置 ===
94
+ self._init_fast_fallback_config()
95
+
94
96
  def build_channel_options(self) -> list:
95
97
  """
96
98
  构建 gRPC 通道选项
@@ -108,29 +110,40 @@ class BaseClient(ABC):
108
110
  # 消息大小限制
109
111
  ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
110
112
  ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
111
-
113
+
112
114
  # Keepalive 核心配置
113
115
  ('grpc.keepalive_time_ms', 30000), # 30秒发送一次 keepalive ping
114
116
  ('grpc.keepalive_timeout_ms', 10000), # ping 响应超时时间 10秒
115
117
  ('grpc.keepalive_permit_without_calls', True), # 空闲时也发送 keepalive
116
118
  ('grpc.http2.max_pings_without_data', 2), # 无数据时最大 ping 次数
117
-
119
+
118
120
  # 连接管理增强配置
119
121
  ('grpc.http2.min_time_between_pings_ms', 10000), # ping 最小间隔 10秒
120
122
  ('grpc.http2.max_connection_idle_ms', 300000), # 最大空闲时间 5分钟
121
123
  ('grpc.http2.max_connection_age_ms', 3600000), # 连接最大生存时间 1小时
122
124
  ('grpc.http2.max_connection_age_grace_ms', 5000), # 优雅关闭时间 5秒
123
-
125
+
124
126
  # 性能相关配置
125
127
  ('grpc.http2.bdp_probe', 1), # 启用带宽延迟探测
126
128
  ('grpc.enable_retries', 1), # 启用内置重试
129
+
130
+ # 启用连接池配置(如果 gRPC 客户端支持)
131
+ ('grpc.keepalive_time_ms', 30000), # 保持活跃的连接时间(30秒)
132
+ ('grpc.keepalive_timeout_ms', 10000), # ping 响应超时时间(10秒)
133
+ ('grpc.max_connection_idle_ms', 300000), # 连接最大空闲时间(5分钟)
134
+
135
+ # 设置资源配额
136
+ ('grpc.resource_quota_size', 1048576000), # 设置资源配额为1GB
137
+
138
+ # 启用负载均衡配置
139
+ ('grpc.lb_policy_name', 'round_robin'), # 设置负载均衡策略为 round_robin(轮询)
127
140
  ]
128
-
141
+
129
142
  if self.default_authority:
130
143
  options.append(("grpc.default_authority", self.default_authority))
131
-
144
+
132
145
  return options
133
-
146
+
134
147
  def _build_auth_metadata(self, request_id: str, origin_request_id: Optional[str] = None) -> list:
135
148
  """
136
149
  构建认证元数据
@@ -146,81 +159,164 @@ class BaseClient(ABC):
146
159
  list: gRPC元数据列表,包含请求ID和认证令牌
147
160
  """
148
161
  metadata = [("x-request-id", request_id)] # 将 request_id 添加到 headers
149
-
162
+
150
163
  # 如果有原始请求ID,也添加到 headers
151
164
  if origin_request_id:
152
165
  metadata.append(("x-origin-request-id", origin_request_id))
153
-
166
+
154
167
  if self.jwt_handler:
155
168
  self.jwt_token = self.jwt_handler.encode_token(
156
- self.default_payload,
169
+ self.default_payload,
157
170
  expires_in=self.token_expires_in
158
171
  )
159
172
  metadata.append(("authorization", f"Bearer {self.jwt_token}"))
160
-
173
+
161
174
  return metadata
162
-
175
+
163
176
  @abstractmethod
164
177
  def close(self):
165
178
  """关闭客户端连接(由子类实现)"""
166
179
  pass
167
-
180
+
168
181
  @abstractmethod
169
182
  def __enter__(self):
170
183
  """进入上下文管理器(由子类实现)"""
171
184
  pass
172
-
185
+
173
186
  @abstractmethod
174
187
  def __exit__(self, exc_type, exc_val, exc_tb):
175
188
  """退出上下文管理器(由子类实现)"""
176
189
  pass
177
-
190
+
178
191
  def _init_resilient_features(self):
179
192
  """初始化熔断降级特性"""
180
193
  # 是否启用熔断降级
181
194
  self.resilient_enabled = os.getenv('MODEL_CLIENT_RESILIENT_ENABLED', 'false').lower() == 'true'
182
-
195
+
183
196
  if self.resilient_enabled:
184
197
  # HTTP 降级地址
185
198
  self.http_fallback_url = os.getenv('MODEL_CLIENT_HTTP_FALLBACK_URL')
186
-
199
+
187
200
  if not self.http_fallback_url:
188
201
  self.logger.warning("🔶 Resilient mode enabled but MODEL_CLIENT_HTTP_FALLBACK_URL not set")
189
202
  self.resilient_enabled = False
190
203
  return
191
-
204
+
192
205
  # 初始化熔断器
193
206
  from ..circuit_breaker import CircuitBreaker
194
207
  self.circuit_breaker = CircuitBreaker(
195
208
  failure_threshold=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_THRESHOLD', '5')),
196
209
  recovery_timeout=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_TIMEOUT', '60'))
197
210
  )
198
-
211
+
199
212
  # HTTP 客户端(延迟初始化)
200
213
  self._http_client = None
201
214
  self._http_session = None # 异步客户端使用
202
-
215
+
203
216
  self.logger.info(
204
217
  "🛡️ Resilient mode enabled",
205
218
  extra={
206
- "http_fallback_url": self.http_fallback_url,
207
- "circuit_breaker_threshold": self.circuit_breaker.failure_threshold,
208
- "circuit_breaker_timeout": self.circuit_breaker.recovery_timeout
219
+ "log_type": "info",
220
+ "data": {
221
+ "http_fallback_url": self.http_fallback_url,
222
+ "circuit_breaker_threshold": self.circuit_breaker.failure_threshold,
223
+ "circuit_breaker_timeout": self.circuit_breaker.recovery_timeout
224
+ }
209
225
  }
210
226
  )
211
227
  else:
212
228
  self.circuit_breaker = None
213
229
  self.http_fallback_url = None
214
-
230
+ self._http_client = None
231
+ self._http_session = None
232
+
215
233
  def get_resilient_metrics(self):
216
234
  """获取熔断降级指标"""
217
235
  if not self.resilient_enabled or not self.circuit_breaker:
218
236
  return None
219
-
237
+
220
238
  return {
221
239
  "enabled": self.resilient_enabled,
222
- "circuit_state": self.circuit_breaker.get_state(),
223
- "failure_count": self.circuit_breaker.failure_count,
224
- "last_failure_time": self.circuit_breaker.last_failure_time,
240
+ "circuit_breaker": {
241
+ "state": self.circuit_breaker.get_state(),
242
+ "failure_count": self.circuit_breaker.failure_count,
243
+ "last_failure_time": self.circuit_breaker.last_failure_time,
244
+ "failure_threshold": self.circuit_breaker.failure_threshold,
245
+ "recovery_timeout": self.circuit_breaker.recovery_timeout
246
+ },
225
247
  "http_fallback_url": self.http_fallback_url
226
- }
248
+ }
249
+
250
def _init_fast_fallback_config(self):
    """Initialise "fast fallback" settings (switch gRPC -> HTTP without waiting for the breaker).

    Environment variables read:
        MODEL_CLIENT_FAST_FALLBACK_ENABLED: ``"true"`` enables fast fallback
            (default ``"false"``).
        MODEL_CLIENT_FALLBACK_AFTER_RETRIES: attempt number from which errors
            in neither list below may fall back (default ``1``).
        MODEL_CLIENT_IMMEDIATE_FALLBACK_ERRORS: comma-separated
            ``grpc.StatusCode`` names that fall back immediately (default
            ``UNAVAILABLE,DEADLINE_EXCEEDED,CANCELLED``).
        MODEL_CLIENT_NEVER_FALLBACK_ERRORS: comma-separated
            ``grpc.StatusCode`` names that never fall back (default
            ``UNAUTHENTICATED,PERMISSION_DENIED,INVALID_ARGUMENT``).

    Sets ``fast_fallback_enabled``, ``fallback_after_retries``,
    ``immediate_fallback_errors`` and ``never_fallback_errors``. The last two
    are sets of ``grpc.StatusCode``. Names are matched case-insensitively;
    empty entries are ignored and unknown names are skipped with a warning.

    Raises:
        ValueError: if MODEL_CLIENT_FALLBACK_AFTER_RETRIES is not an integer.
    """
    import grpc

    self.fast_fallback_enabled = os.getenv('MODEL_CLIENT_FAST_FALLBACK_ENABLED', 'false').lower() == 'true'

    # Maximum gRPC attempts before falling back for "unclassified" errors.
    self.fallback_after_retries = int(os.getenv('MODEL_CLIENT_FALLBACK_AFTER_RETRIES', '1'))

    def parse_status_codes(env_name, default):
        """Turn a comma-separated list of StatusCode names into a set of StatusCode members."""
        codes = set()
        for raw in os.getenv(env_name, default).split(','):
            name = raw.strip()
            if not name:
                continue  # tolerate empty entries such as a trailing comma
            # Look up enum members only. hasattr() would also accept unrelated
            # class attributes (e.g. "name", "__class__").
            code = grpc.StatusCode.__members__.get(name.upper())
            if code is None:
                # A typo would otherwise silently change fallback behaviour.
                self.logger.warning(
                    f"⚠️ Ignoring unknown gRPC status code '{name}' in {env_name}",
                    extra={"log_type": "info", "data": {"env_var": env_name, "value": name}}
                )
                continue
            codes.add(code)
        return codes

    self.immediate_fallback_errors = parse_status_codes(
        'MODEL_CLIENT_IMMEDIATE_FALLBACK_ERRORS',
        'UNAVAILABLE,DEADLINE_EXCEEDED,CANCELLED'
    )
    self.never_fallback_errors = parse_status_codes(
        'MODEL_CLIENT_NEVER_FALLBACK_ERRORS',
        'UNAUTHENTICATED,PERMISSION_DENIED,INVALID_ARGUMENT'
    )

    if self.fast_fallback_enabled:
        self.logger.info(
            "🚀 Fast fallback enabled",
            extra={
                "log_type": "info",
                "data": {
                    "fallback_after_retries": self.fallback_after_retries,
                    # sorted() keeps the log line stable; set order is hash-dependent.
                    "immediate_fallback_errors": sorted(e.name for e in self.immediate_fallback_errors),
                    "never_fallback_errors": sorted(e.name for e in self.never_fallback_errors)
                }
            }
        )
293
+
294
+ def _should_try_fallback(self, error_code, attempt: int) -> bool:
295
+ """
296
+ 判断是否应该尝试降级
297
+
298
+ Args:
299
+ error_code: gRPC错误码
300
+ attempt: 当前重试次数
301
+
302
+ Returns:
303
+ bool: 是否应该尝试降级
304
+ """
305
+ # 未启用快速降级
306
+ if not self.fast_fallback_enabled:
307
+ return False
308
+
309
+ # 未启用熔断降级功能
310
+ if not self.resilient_enabled or not self.http_fallback_url:
311
+ return False
312
+
313
+ # 永不降级的错误类型
314
+ if error_code in self.never_fallback_errors:
315
+ return False
316
+
317
+ # 立即降级的错误类型
318
+ if error_code in self.immediate_fallback_errors:
319
+ return True
320
+
321
+ # 其他错误在达到重试次数后降级
322
+ return attempt >= self.fallback_after_retries