tamar-model-client 0.1.27__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -32,6 +32,7 @@ from grpc import RpcError
32
32
  from .core import (
33
33
  generate_request_id,
34
34
  set_request_id,
35
+ set_origin_request_id,
35
36
  get_protected_logger,
36
37
  MAX_MESSAGE_LENGTH,
37
38
  get_request_id,
@@ -113,6 +114,9 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
113
114
  base_delay=self.retry_delay
114
115
  )
115
116
 
117
+ # 设置client引用,用于快速降级
118
+ self.retry_handler.error_handler.client = self
119
+
116
120
  # 注册退出时的清理函数
117
121
  atexit.register(self._cleanup_atexit)
118
122
 
@@ -739,7 +743,12 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
739
743
  if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
740
744
  if self.http_fallback_url:
741
745
  logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback")
742
- return await self._invoke_http_fallback(model_request, timeout, request_id)
746
+ # 在这里还没有计算origin_request_id,所以先计算
747
+ temp_origin_request_id = None
748
+ temp_request_id = request_id
749
+ if request_id:
750
+ temp_request_id, temp_origin_request_id = self._request_id_manager.get_composite_id(request_id)
751
+ return await self._invoke_http_fallback(model_request, timeout, temp_request_id, temp_origin_request_id)
743
752
 
744
753
  await self._ensure_initialized()
745
754
 
@@ -759,6 +768,8 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
759
768
  request_id = generate_request_id()
760
769
 
761
770
  set_request_id(request_id)
771
+ if origin_request_id:
772
+ set_origin_request_id(origin_request_id)
762
773
  metadata = self._build_auth_metadata(request_id, origin_request_id)
763
774
 
764
775
  # 构建日志数据
@@ -806,7 +817,17 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
806
817
  # 对于流式响应,直接返回带日志记录的包装器
807
818
  return self._stream_with_logging(request, metadata, invoke_timeout, start_time, model_request)
808
819
  else:
809
- result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id)
820
+ # 存储model_request和origin_request_id供重试方法使用
821
+ self._current_model_request = model_request
822
+ self._current_origin_request_id = origin_request_id
823
+ try:
824
+ result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id)
825
+ finally:
826
+ # 清理临时存储
827
+ if hasattr(self, '_current_model_request'):
828
+ delattr(self, '_current_model_request')
829
+ if hasattr(self, '_current_origin_request_id'):
830
+ delattr(self, '_current_origin_request_id')
810
831
 
811
832
  # 记录非流式响应的成功日志
812
833
  duration = time.time() - start_time
@@ -854,16 +875,11 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
854
875
  if isinstance(e, grpc.RpcError):
855
876
  self._record_channel_error(e)
856
877
 
857
- # 记录失败并尝试降级(如果启用了熔断)
878
+ # 记录失败(如果启用了熔断)
858
879
  if self.resilient_enabled and self.circuit_breaker:
859
880
  # 将错误码传递给熔断器,用于智能失败统计
860
881
  error_code = e.code() if hasattr(e, 'code') else None
861
882
  self.circuit_breaker.record_failure(error_code)
862
-
863
- # 如果可以降级,则降级
864
- if self.http_fallback_url and self.circuit_breaker.should_fallback():
865
- logger.warning(f"🔻 gRPC failed, falling back to HTTP: {str(e)}")
866
- return await self._invoke_http_fallback(model_request, timeout, request_id)
867
883
 
868
884
  raise e
869
885
  except Exception as e:
@@ -893,6 +909,17 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
893
909
  Returns:
894
910
  BatchModelResponse: 批量请求的结果
895
911
  """
912
+ # 如果启用了熔断且熔断器打开,直接走 HTTP
913
+ if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
914
+ if self.http_fallback_url:
915
+ logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback for batch request")
916
+ # 在这里还没有计算origin_request_id,所以先计算
917
+ temp_origin_request_id = None
918
+ temp_request_id = request_id
919
+ if request_id:
920
+ temp_request_id, temp_origin_request_id = self._request_id_manager.get_composite_id(request_id)
921
+ return await self._invoke_batch_http_fallback(batch_request_model, timeout, temp_request_id, temp_origin_request_id)
922
+
896
923
  await self._ensure_initialized()
897
924
 
898
925
  if not self.default_payload:
@@ -911,6 +938,8 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
911
938
  request_id = generate_request_id()
912
939
 
913
940
  set_request_id(request_id)
941
+ if origin_request_id:
942
+ set_origin_request_id(origin_request_id)
914
943
  metadata = self._build_auth_metadata(request_id, origin_request_id)
915
944
 
916
945
  # 构建日志数据
@@ -957,6 +986,11 @@ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
957
986
 
958
987
  try:
959
988
  invoke_timeout = timeout or self.default_invoke_timeout
989
+
990
+ # 保存批量请求信息用于降级
991
+ self._current_batch_request = batch_request_model
992
+ self._current_origin_request_id = origin_request_id
993
+
960
994
  batch_response = await self._retry_request(
961
995
  self.stub.BatchInvoke,
962
996
  batch_request,
@@ -101,9 +101,12 @@ class CircuitBreaker:
101
101
  logger.warning(
102
102
  f"🔻 Circuit breaker OPENED after {self.failure_count} failures",
103
103
  extra={
104
- "failure_count": self.failure_count,
105
- "threshold": self.failure_threshold,
106
- "trigger_error": error_code.name if error_code else "unknown"
104
+ "log_type": "info",
105
+ "data": {
106
+ "failure_count": self.failure_count,
107
+ "threshold": self.failure_threshold,
108
+ "trigger_error": error_code.name if error_code else "unknown"
109
+ }
107
110
  }
108
111
  )
109
112
 
@@ -10,7 +10,9 @@ from .utils import (
10
10
  remove_none_from_dict,
11
11
  generate_request_id,
12
12
  set_request_id,
13
- get_request_id
13
+ get_request_id,
14
+ set_origin_request_id,
15
+ get_origin_request_id
14
16
  )
15
17
 
16
18
  from .logging_setup import (
@@ -32,6 +34,8 @@ __all__ = [
32
34
  'generate_request_id',
33
35
  'set_request_id',
34
36
  'get_request_id',
37
+ 'set_origin_request_id',
38
+ 'get_origin_request_id',
35
39
  # Logging
36
40
  'setup_logger',
37
41
  'RequestIdFilter',
@@ -6,8 +6,7 @@ and configuration management for both sync and async clients.
6
6
  """
7
7
 
8
8
  import os
9
- import logging
10
- from typing import Optional, Dict, Any
9
+ from typing import Optional
11
10
  from abc import ABC, abstractmethod
12
11
 
13
12
  from ..auth import JWTAuthHandler
@@ -25,7 +24,7 @@ class BaseClient(ABC):
25
24
  - 连接选项构建
26
25
  - 错误处理器初始化
27
26
  """
28
-
27
+
29
28
  def __init__(
30
29
  self,
31
30
  server_address: Optional[str] = None,
@@ -57,40 +56,43 @@ class BaseClient(ABC):
57
56
  self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
58
57
  if not self.server_address:
59
58
  raise ValueError("Server address must be provided via argument or environment variable.")
60
-
59
+
61
60
  # 默认调用超时时间
62
61
  self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))
63
-
62
+
64
63
  # === JWT 认证配置 ===
65
64
  self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
66
65
  self.jwt_handler = JWTAuthHandler(self.jwt_secret_key) if self.jwt_secret_key else None
67
66
  self.jwt_token = jwt_token # 用户传入的预生成 Token(可选)
68
67
  self.default_payload = default_payload
69
68
  self.token_expires_in = token_expires_in
70
-
69
+
71
70
  # === TLS/Authority 配置 ===
72
71
  self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
73
72
  self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")
74
-
73
+
75
74
  # === 重试配置 ===
76
75
  self.max_retries = max_retries if max_retries is not None else int(
77
76
  os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 6))
78
77
  self.retry_delay = retry_delay if retry_delay is not None else float(
79
78
  os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))
80
-
79
+
81
80
  # === 日志配置 ===
82
81
  self.logger = get_protected_logger(logger_name or __name__)
83
-
82
+
84
83
  # === 错误处理器 ===
85
84
  self.error_handler = GrpcErrorHandler(self.logger)
86
85
  self.recovery_strategy = ErrorRecoveryStrategy(self)
87
-
86
+
88
87
  # === 连接状态 ===
89
88
  self._closed = False
90
-
89
+
91
90
  # === 熔断降级配置 ===
92
91
  self._init_resilient_features()
93
-
92
+
93
+ # === 快速降级配置 ===
94
+ self._init_fast_fallback_config()
95
+
94
96
  def build_channel_options(self) -> list:
95
97
  """
96
98
  构建 gRPC 通道选项
@@ -108,29 +110,43 @@ class BaseClient(ABC):
108
110
  # 消息大小限制
109
111
  ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
110
112
  ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
111
-
113
+
112
114
  # Keepalive 核心配置
113
115
  ('grpc.keepalive_time_ms', 30000), # 30秒发送一次 keepalive ping
114
116
  ('grpc.keepalive_timeout_ms', 10000), # ping 响应超时时间 10秒
115
117
  ('grpc.keepalive_permit_without_calls', True), # 空闲时也发送 keepalive
116
118
  ('grpc.http2.max_pings_without_data', 2), # 无数据时最大 ping 次数
117
-
119
+
118
120
  # 连接管理增强配置
119
121
  ('grpc.http2.min_time_between_pings_ms', 10000), # ping 最小间隔 10秒
120
122
  ('grpc.http2.max_connection_idle_ms', 300000), # 最大空闲时间 5分钟
121
123
  ('grpc.http2.max_connection_age_ms', 3600000), # 连接最大生存时间 1小时
122
124
  ('grpc.http2.max_connection_age_grace_ms', 5000), # 优雅关闭时间 5秒
123
-
125
+
124
126
  # 性能相关配置
125
127
  ('grpc.http2.bdp_probe', 1), # 启用带宽延迟探测
126
128
  ('grpc.enable_retries', 1), # 启用内置重试
129
+
130
+ # 启用连接池配置(如果 gRPC 客户端支持)
131
+ ('grpc.keepalive_time_ms', 30000), # 保持活跃的连接时间(30秒)
132
+ ('grpc.keepalive_timeout_ms', 10000), # ping 响应超时时间(10秒)
133
+ ('grpc.max_connection_idle_ms', 300000), # 连接最大空闲时间(5分钟)
134
+
135
+ # 设置资源配额
136
+ ('grpc.resource_quota_size', 1048576000), # 设置资源配额为1GB
137
+
138
+ # 启用负载均衡配置
139
+ ('grpc.lb_policy', 'round_robin'), # 设置负载均衡策略为 round_robin(轮询)
140
+
141
+ # 启用详细的日志记录
142
+ ('grpc.debug', 1), # 启用 gRPC 的调试日志,记录更多的连接和请求信息
127
143
  ]
128
-
144
+
129
145
  if self.default_authority:
130
146
  options.append(("grpc.default_authority", self.default_authority))
131
-
147
+
132
148
  return options
133
-
149
+
134
150
  def _build_auth_metadata(self, request_id: str, origin_request_id: Optional[str] = None) -> list:
135
151
  """
136
152
  构建认证元数据
@@ -146,81 +162,160 @@ class BaseClient(ABC):
146
162
  list: gRPC元数据列表,包含请求ID和认证令牌
147
163
  """
148
164
  metadata = [("x-request-id", request_id)] # 将 request_id 添加到 headers
149
-
165
+
150
166
  # 如果有原始请求ID,也添加到 headers
151
167
  if origin_request_id:
152
168
  metadata.append(("x-origin-request-id", origin_request_id))
153
-
169
+
154
170
  if self.jwt_handler:
155
171
  self.jwt_token = self.jwt_handler.encode_token(
156
- self.default_payload,
172
+ self.default_payload,
157
173
  expires_in=self.token_expires_in
158
174
  )
159
175
  metadata.append(("authorization", f"Bearer {self.jwt_token}"))
160
-
176
+
161
177
  return metadata
162
-
178
+
163
179
  @abstractmethod
164
180
  def close(self):
165
181
  """关闭客户端连接(由子类实现)"""
166
182
  pass
167
-
183
+
168
184
  @abstractmethod
169
185
  def __enter__(self):
170
186
  """进入上下文管理器(由子类实现)"""
171
187
  pass
172
-
188
+
173
189
  @abstractmethod
174
190
  def __exit__(self, exc_type, exc_val, exc_tb):
175
191
  """退出上下文管理器(由子类实现)"""
176
192
  pass
177
-
193
+
178
194
  def _init_resilient_features(self):
179
195
  """初始化熔断降级特性"""
180
196
  # 是否启用熔断降级
181
197
  self.resilient_enabled = os.getenv('MODEL_CLIENT_RESILIENT_ENABLED', 'false').lower() == 'true'
182
-
198
+
183
199
  if self.resilient_enabled:
184
200
  # HTTP 降级地址
185
201
  self.http_fallback_url = os.getenv('MODEL_CLIENT_HTTP_FALLBACK_URL')
186
-
202
+
187
203
  if not self.http_fallback_url:
188
204
  self.logger.warning("🔶 Resilient mode enabled but MODEL_CLIENT_HTTP_FALLBACK_URL not set")
189
205
  self.resilient_enabled = False
190
206
  return
191
-
207
+
192
208
  # 初始化熔断器
193
209
  from ..circuit_breaker import CircuitBreaker
194
210
  self.circuit_breaker = CircuitBreaker(
195
211
  failure_threshold=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_THRESHOLD', '5')),
196
212
  recovery_timeout=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_TIMEOUT', '60'))
197
213
  )
198
-
214
+
199
215
  # HTTP 客户端(延迟初始化)
200
216
  self._http_client = None
201
217
  self._http_session = None # 异步客户端使用
202
-
218
+
203
219
  self.logger.info(
204
220
  "🛡️ Resilient mode enabled",
205
221
  extra={
206
- "http_fallback_url": self.http_fallback_url,
207
- "circuit_breaker_threshold": self.circuit_breaker.failure_threshold,
208
- "circuit_breaker_timeout": self.circuit_breaker.recovery_timeout
222
+ "log_type": "info",
223
+ "data": {
224
+ "http_fallback_url": self.http_fallback_url,
225
+ "circuit_breaker_threshold": self.circuit_breaker.failure_threshold,
226
+ "circuit_breaker_timeout": self.circuit_breaker.recovery_timeout
227
+ }
209
228
  }
210
229
  )
211
230
  else:
212
231
  self.circuit_breaker = None
213
232
  self.http_fallback_url = None
214
-
233
+ self._http_client = None
234
+ self._http_session = None
235
+
215
236
  def get_resilient_metrics(self):
216
237
  """获取熔断降级指标"""
217
238
  if not self.resilient_enabled or not self.circuit_breaker:
218
239
  return None
219
-
240
+
220
241
  return {
221
242
  "enabled": self.resilient_enabled,
222
243
  "circuit_state": self.circuit_breaker.get_state(),
223
244
  "failure_count": self.circuit_breaker.failure_count,
224
245
  "last_failure_time": self.circuit_breaker.last_failure_time,
225
246
  "http_fallback_url": self.http_fallback_url
226
- }
247
+ }
248
+
249
+ def _init_fast_fallback_config(self):
250
+ """初始化快速降级配置"""
251
+ import grpc
252
+
253
+ # 是否启用快速降级
254
+ self.fast_fallback_enabled = os.getenv('MODEL_CLIENT_FAST_FALLBACK_ENABLED', 'false').lower() == 'true'
255
+
256
+ # 降级前的最大gRPC重试次数
257
+ self.fallback_after_retries = int(os.getenv('MODEL_CLIENT_FALLBACK_AFTER_RETRIES', '1'))
258
+
259
+ # 立即降级的错误码配置
260
+ immediate_fallback_errors = os.getenv('MODEL_CLIENT_IMMEDIATE_FALLBACK_ERRORS',
261
+ 'UNAVAILABLE,DEADLINE_EXCEEDED,CANCELLED')
262
+ self.immediate_fallback_errors = set()
263
+
264
+ if immediate_fallback_errors:
265
+ for error_name in immediate_fallback_errors.split(','):
266
+ error_name = error_name.strip()
267
+ if hasattr(grpc.StatusCode, error_name):
268
+ self.immediate_fallback_errors.add(getattr(grpc.StatusCode, error_name))
269
+
270
+ # 永不降级的错误码
271
+ never_fallback_errors = os.getenv('MODEL_CLIENT_NEVER_FALLBACK_ERRORS',
272
+ 'UNAUTHENTICATED,PERMISSION_DENIED,INVALID_ARGUMENT')
273
+ self.never_fallback_errors = set()
274
+
275
+ if never_fallback_errors:
276
+ for error_name in never_fallback_errors.split(','):
277
+ error_name = error_name.strip()
278
+ if hasattr(grpc.StatusCode, error_name):
279
+ self.never_fallback_errors.add(getattr(grpc.StatusCode, error_name))
280
+
281
+ if self.fast_fallback_enabled:
282
+ self.logger.info(
283
+ "🚀 Fast fallback enabled",
284
+ extra={
285
+ "data": {
286
+ "fallback_after_retries": self.fallback_after_retries,
287
+ "immediate_fallback_errors": [e.name for e in self.immediate_fallback_errors],
288
+ "never_fallback_errors": [e.name for e in self.never_fallback_errors]
289
+ }
290
+ }
291
+ )
292
+
293
+ def _should_try_fallback(self, error_code, attempt: int) -> bool:
294
+ """
295
+ 判断是否应该尝试降级
296
+
297
+ Args:
298
+ error_code: gRPC错误码
299
+ attempt: 当前重试次数
300
+
301
+ Returns:
302
+ bool: 是否应该尝试降级
303
+ """
304
+ # 未启用快速降级
305
+ if not self.fast_fallback_enabled:
306
+ return False
307
+
308
+ # 未启用熔断降级功能
309
+ if not self.resilient_enabled or not self.http_fallback_url:
310
+ return False
311
+
312
+ # 永不降级的错误类型
313
+ if error_code in self.never_fallback_errors:
314
+ return False
315
+
316
+ # 立即降级的错误类型
317
+ if error_code in self.immediate_fallback_errors:
318
+ return True
319
+
320
+ # 其他错误在达到重试次数后降级
321
+ return attempt >= self.fallback_after_retries