tamar-model-client 0.1.20__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,6 +22,7 @@ import asyncio
  import atexit
  import json
  import logging
+ import random
  import time
  from typing import Optional, AsyncIterator, Union

@@ -31,7 +32,7 @@ from grpc import RpcError
  from .core import (
  generate_request_id,
  set_request_id,
- setup_logger,
+ get_protected_logger,
  MAX_MESSAGE_LENGTH
  )
  from .core.base_client import BaseClient
@@ -42,12 +43,13 @@ from .exceptions import ConnectionError, TamarModelException
  from .error_handler import EnhancedRetryHandler
  from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
  from .generated import model_service_pb2, model_service_pb2_grpc
+ from .core.http_fallback import AsyncHttpFallbackMixin

- # Configure the logger
- logger = setup_logger(__name__)
+ # Configure the logger (using the protected logger)
+ logger = get_protected_logger(__name__)


- class AsyncTamarModelClient(BaseClient):
+ class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
  """
  Tamar Model Client asynchronous client

@@ -127,6 +129,10 @@ class AsyncTamarModelClient(BaseClient):
  self._closed = True
  logger.info("🔒 gRPC channel closed",
  extra={"log_type": "info", "data": {"status": "closed"}})
+
+ # Clean up the HTTP session (if any)
+ if self.resilient_enabled:
+ await self._cleanup_http_session()

  async def __aenter__(self):
  """Async context manager entry"""
@@ -255,32 +261,152 @@ class AsyncTamarModelClient(BaseClient):
  AsyncIterator: streaming response iterator
  """
  last_exception = None
+ context = {
+ 'method': 'stream',
+ 'client_version': 'async',
+ }

  for attempt in range(self.max_retries + 1):
  try:
+ context['retry_count'] = attempt
  # Try to create the stream
  async for item in func(*args, **kwargs):
  yield item
  return

  except RpcError as e:
- last_exception = e
+ # Use smart retry decision logic
+ context['retry_count'] = attempt
+
+ # Build the error context and decide whether to retry
+ from .exceptions import ErrorContext, get_retry_policy
+ error_context = ErrorContext(e, context)
+ error_code = e.code()
+ policy = get_retry_policy(error_code)
+ retryable = policy.get('retryable', False)
+
+ should_retry = False
  if attempt < self.max_retries:
+ if retryable == True:
+ should_retry = True
+ elif retryable == 'conditional':
+ # Conditional retry; handle CANCELLED specially
+ if error_code == grpc.StatusCode.CANCELLED:
+ should_retry = error_context.is_network_cancelled()
+ else:
+ should_retry = self._check_error_details_for_retry(e)
+
+ if should_retry:
+ log_data = {
+ "log_type": "info",
+ "request_id": context.get('request_id'),
+ "data": {
+ "error_code": e.code().name if e.code() else 'UNKNOWN',
+ "retry_count": attempt,
+ "max_retries": self.max_retries,
+ "method": "stream"
+ }
+ }
  logger.warning(
- f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
- extra={"retry_count": attempt, "error_code": str(e.code())}
+ f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()} (will retry)",
+ extra=log_data
  )
- await asyncio.sleep(self.retry_delay * (attempt + 1))
+
+ # Compute the backoff delay
+ delay = self._calculate_backoff(attempt, error_code)
+ await asyncio.sleep(delay)
  else:
+ # Not retryable, or the maximum number of retries has been reached
+ log_data = {
+ "log_type": "info",
+ "request_id": context.get('request_id'),
+ "data": {
+ "error_code": e.code().name if e.code() else 'UNKNOWN',
+ "retry_count": attempt,
+ "max_retries": self.max_retries,
+ "method": "stream",
+ "will_retry": False
+ }
+ }
+ logger.error(
+ f"Stream failed: {e.code()} (no retry)",
+ extra=log_data
+ )
+ last_exception = self.error_handler.handle_error(e, context)
  break
+
+ last_exception = e
+
  except Exception as e:
+ context['retry_count'] = attempt
  raise TamarModelException(str(e)) from e

  if last_exception:
- raise self.error_handler.handle_error(last_exception, {"retry_count": self.max_retries})
+ if isinstance(last_exception, TamarModelException):
+ raise last_exception
+ else:
+ raise self.error_handler.handle_error(last_exception, context)
  else:
  raise TamarModelException("Unknown streaming error occurred")

+ def _check_error_details_for_retry(self, error: RpcError) -> bool:
+ """Inspect the error details to decide whether to retry"""
+ error_message = error.details().lower() if error.details() else ""
+
+ # Retryable error patterns
+ retryable_patterns = [
+ 'temporary', 'timeout', 'unavailable',
+ 'connection', 'network', 'try again'
+ ]
+
+ for pattern in retryable_patterns:
+ if pattern in error_message:
+ return True
+
+ return False
+
+ def _calculate_backoff(self, attempt: int, error_code = None) -> float:
+ """
+ Compute the backoff delay, supporting different backoff strategies
+
+ Args:
+ attempt: current retry attempt number
+ error_code: gRPC status code, used to select the backoff strategy
+ """
+ max_delay = 60.0
+ base_delay = self.retry_delay
+
+ # Look up the retry policy for this error
+ if error_code:
+ from .exceptions import get_retry_policy
+ policy = get_retry_policy(error_code)
+ backoff_type = policy.get('backoff', 'exponential')
+ use_jitter = policy.get('jitter', False)
+ else:
+ backoff_type = 'exponential'
+ use_jitter = False
+
+ # Compute the delay based on the backoff type
+ if backoff_type == 'linear':
+ # Linear backoff: delay * (attempt + 1)
+ delay = min(base_delay * (attempt + 1), max_delay)
+ else:
+ # Exponential backoff: delay * 2^attempt
+ delay = min(base_delay * (2 ** attempt), max_delay)
+
+ # Add jitter
+ if use_jitter:
+ jitter_factor = 0.2  # Wider jitter range to reduce contention
+ jitter = random.uniform(0, delay * jitter_factor)
+ delay += jitter
+ else:
+ # Small default jitter to avoid fully synchronized retries
+ jitter_factor = 0.05
+ jitter = random.uniform(0, delay * jitter_factor)
+ delay += jitter
+
+ return delay
+
  async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
  """
  Handle the streaming response
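For reference, a minimal standalone sketch of the delay schedule that _calculate_backoff produces (assuming retry_delay=1.0; in the real client the backoff type and jitter flag come from get_retry_policy, so the values here are only illustrative):

import random

def sketch_backoff(attempt, base_delay=1.0, backoff_type='exponential',
                   use_jitter=False, max_delay=60.0):
    # Linear grows as base * (attempt + 1); exponential as base * 2**attempt, capped at max_delay.
    if backoff_type == 'linear':
        delay = min(base_delay * (attempt + 1), max_delay)
    else:
        delay = min(base_delay * (2 ** attempt), max_delay)
    # 20% jitter when the policy asks for it, otherwise a small 5% jitter to de-synchronize clients.
    jitter_factor = 0.2 if use_jitter else 0.05
    return delay + random.uniform(0, delay * jitter_factor)

# Attempts 0-3 with the defaults: roughly 1s, 2s, 4s, 8s plus jitter.
print([round(sketch_backoff(a), 2) for a in range(4)])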
@@ -424,6 +550,12 @@ class AsyncTamarModelClient(BaseClient):
  ValidationError: input validation failed.
  ConnectionError: failed to connect to the server.
  """
+ # If resilience is enabled and the circuit breaker is open, go straight to HTTP
+ if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
+ if self.http_fallback_url:
+ logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback")
+ return await self._invoke_http_fallback(model_request, timeout, request_id)
+
  await self._ensure_initialized()

  if not self.default_payload:
@@ -491,9 +623,14 @@ class AsyncTamarModelClient(BaseClient):
  "data": ResponseHandler.build_log_data(model_request, result)
  }
  )
+
+ # Record the success (if circuit breaking is enabled)
+ if self.resilient_enabled and self.circuit_breaker:
+ self.circuit_breaker.record_success()
+
  return result

- except grpc.RpcError as e:
+ except (ConnectionError, grpc.RpcError) as e:
  duration = time.time() - start_time
  error_message = f"❌ Invoke gRPC failed: {str(e)}"
  logger.error(error_message, exc_info=True,
@@ -506,6 +643,18 @@ class AsyncTamarModelClient(BaseClient):
  error=e
  )
  })
+
+ # Record the failure and attempt fallback (if circuit breaking is enabled)
+ if self.resilient_enabled and self.circuit_breaker:
+ # Pass the error code to the circuit breaker for smarter failure accounting
+ error_code = e.code() if hasattr(e, 'code') else None
+ self.circuit_breaker.record_failure(error_code)
+
+ # Fall back to HTTP if fallback is possible
+ if self.http_fallback_url and self.circuit_breaker.should_fallback():
+ logger.warning(f"🔻 gRPC failed, falling back to HTTP: {str(e)}")
+ return await self._invoke_http_fallback(model_request, timeout, request_id)
+
  raise e
  except Exception as e:
  duration = time.time() - start_time
@@ -0,0 +1,140 @@
+ """
+ Circuit Breaker implementation for resilient client
+
+ This module provides a thread-safe circuit breaker pattern implementation
+ to handle failures gracefully and prevent cascading failures.
+ """
+
+ import time
+ import logging
+ from enum import Enum
+ from threading import Lock
+ from typing import Optional
+
+ from .core.logging_setup import get_protected_logger
+
+ logger = get_protected_logger(__name__)
+
+
+ class CircuitState(Enum):
+ """Circuit breaker states"""
+ CLOSED = "closed"  # Normal operation
+ OPEN = "open"  # Circuit is broken, requests fail fast
+ HALF_OPEN = "half_open"  # Testing if service has recovered
+
+
+ class CircuitBreaker:
+ """
+ Thread-safe circuit breaker implementation
+
+ The circuit breaker prevents cascading failures by failing fast when
+ a service is unavailable, and automatically recovers when the service
+ becomes available again.
+ """
+
+ def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
+ """
+ Initialize the circuit breaker
+
+ Args:
+ failure_threshold: Number of consecutive failures before opening circuit
+ recovery_timeout: Seconds to wait before attempting recovery
+ """
+ self.failure_threshold = failure_threshold
+ self.recovery_timeout = recovery_timeout
+ self.failure_count = 0
+ self.last_failure_time: Optional[float] = None
+ self.state = CircuitState.CLOSED
+ self._lock = Lock()
+
+ @property
+ def is_open(self) -> bool:
+ """Check if circuit breaker is open"""
+ with self._lock:
+ if self.state == CircuitState.OPEN:
+ # Check if we should attempt recovery
+ if (self.last_failure_time and
+ time.time() - self.last_failure_time > self.recovery_timeout):
+ self.state = CircuitState.HALF_OPEN
+ logger.info("🔄 Circuit breaker entering HALF_OPEN state")
+ return False
+ return True
+ return False
+
+ def record_success(self) -> None:
+ """Record a successful request"""
+ with self._lock:
+ if self.state == CircuitState.HALF_OPEN:
+ # Success in half-open state means service has recovered
+ self.state = CircuitState.CLOSED
+ self.failure_count = 0
+ logger.info("🔺 Circuit breaker recovered to CLOSED state")
+ elif self.state == CircuitState.CLOSED and self.failure_count > 0:
+ # Reset failure count on success
+ self.failure_count = 0
+
+ def record_failure(self, error_code=None) -> None:
+ """
+ Record a failed request
+
+ Args:
+ error_code: gRPC error code for failure classification
+ """
+ with self._lock:
+ # Some error types are not counted toward the circuit breaker, or carry a lower weight
+ if error_code and self._should_ignore_for_circuit_breaker(error_code):
+ return
+
+ # ABORTED carries a lower weight because it usually indicates a transient concurrency issue
+ import grpc
+ if error_code == grpc.StatusCode.ABORTED:
+ # An ABORTED error only counts as half a failure
+ self.failure_count += 0.5
+ else:
+ self.failure_count += 1
+
+ self.last_failure_time = time.time()
+
+ if self.failure_count >= self.failure_threshold:
+ if self.state != CircuitState.OPEN:
+ self.state = CircuitState.OPEN
+ logger.warning(
+ f"🔻 Circuit breaker OPENED after {self.failure_count} failures",
+ extra={
+ "failure_count": self.failure_count,
+ "threshold": self.failure_threshold,
+ "trigger_error": error_code.name if error_code else "unknown"
+ }
+ )
+
+ def _should_ignore_for_circuit_breaker(self, error_code) -> bool:
+ """
+ Decide whether an error should be ignored by the circuit breaker
+
+ Some errors should not trip the circuit:
+ - requests actively cancelled by the client
+ - authentication-related errors (they do not mean the service is unavailable)
+ """
+ import grpc
+ ignored_codes = {
+ grpc.StatusCode.UNAUTHENTICATED,  # Authentication problem, not a service problem
+ grpc.StatusCode.PERMISSION_DENIED,  # Permission problem, not a service problem
+ grpc.StatusCode.INVALID_ARGUMENT,  # Invalid argument, not a service problem
+ }
+ return error_code in ignored_codes
+
+ def should_fallback(self) -> bool:
+ """Check if fallback should be used"""
+ return self.is_open and self.state != CircuitState.HALF_OPEN
+
+ def get_state(self) -> str:
+ """Get current circuit state"""
+ return self.state.value
+
+ def reset(self) -> None:
+ """Reset circuit breaker to initial state"""
+ with self._lock:
+ self.state = CircuitState.CLOSED
+ self.failure_count = 0
+ self.last_failure_time = None
+ logger.info("🔄 Circuit breaker reset to CLOSED state")
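A short usage sketch of the new CircuitBreaker (assuming the wheel imports as the tamar_model_client package; the threshold is lowered to 2 only for the demonstration):

import grpc
from tamar_model_client.circuit_breaker import CircuitBreaker

cb = CircuitBreaker(failure_threshold=2, recovery_timeout=30)

# UNAUTHENTICATED is ignored by design: it signals a client-side problem, not an unhealthy service.
cb.record_failure(grpc.StatusCode.UNAUTHENTICATED)
print(cb.failure_count, cb.get_state())  # 0 closed

# UNAVAILABLE counts as a full failure (ABORTED would count as 0.5); two of them trip the breaker.
cb.record_failure(grpc.StatusCode.UNAVAILABLE)
cb.record_failure(grpc.StatusCode.UNAVAILABLE)
print(cb.is_open, cb.should_fallback())  # True True

# After recovery_timeout elapses, reading is_open moves the breaker to HALF_OPEN,
# and a subsequent record_success() closes it again.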
@@ -16,6 +16,9 @@ from .utils import (
  from .logging_setup import (
  setup_logger,
  RequestIdFilter,
+ TamarLoggerAdapter,
+ get_protected_logger,
+ reset_logger_config,
  MAX_MESSAGE_LENGTH
  )

@@ -30,5 +33,8 @@ __all__ = [
  # Logging
  'setup_logger',
  'RequestIdFilter',
+ 'TamarLoggerAdapter',
+ 'get_protected_logger',
+ 'reset_logger_config',
  'MAX_MESSAGE_LENGTH',
  ]
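With the new exports, callers can obtain the same protected logger the client now uses internally (a minimal sketch; the import path assumes the wheel installs as tamar_model_client):

from tamar_model_client.core import get_protected_logger

logger = get_protected_logger(__name__)
logger.info("hello from the protected logger")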
@@ -12,7 +12,7 @@ from abc import ABC, abstractmethod

  from ..auth import JWTAuthHandler
  from ..error_handler import GrpcErrorHandler, ErrorRecoveryStrategy
- from .logging_setup import MAX_MESSAGE_LENGTH, setup_logger
+ from .logging_setup import MAX_MESSAGE_LENGTH, get_protected_logger


  class BaseClient(ABC):
@@ -79,7 +79,7 @@ class BaseClient(ABC):
  os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))

  # === Logging configuration ===
- self.logger = setup_logger(logger_name or __name__)
+ self.logger = get_protected_logger(logger_name or __name__)

  # === Error handler ===
  self.error_handler = GrpcErrorHandler(self.logger)
@@ -87,6 +87,9 @@

  # === Connection state ===
  self._closed = False
+
+ # === Circuit breaker / fallback configuration ===
+ self._init_resilient_features()

  def build_channel_options(self) -> list:
  """
@@ -165,4 +168,54 @@
  @abstractmethod
  def __exit__(self, exc_type, exc_val, exc_tb):
  """Exit the context manager (implemented by subclasses)"""
- pass
+ pass
+
+ def _init_resilient_features(self):
+ """Initialize circuit breaker / fallback features"""
+ # Whether circuit breaking and fallback are enabled
+ self.resilient_enabled = os.getenv('MODEL_CLIENT_RESILIENT_ENABLED', 'false').lower() == 'true'
+
+ if self.resilient_enabled:
+ # HTTP fallback URL
+ self.http_fallback_url = os.getenv('MODEL_CLIENT_HTTP_FALLBACK_URL')
+
+ if not self.http_fallback_url:
+ self.logger.warning("🔶 Resilient mode enabled but MODEL_CLIENT_HTTP_FALLBACK_URL not set")
+ self.resilient_enabled = False
+ return
+
+ # Initialize the circuit breaker
+ from ..circuit_breaker import CircuitBreaker
+ self.circuit_breaker = CircuitBreaker(
+ failure_threshold=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_THRESHOLD', '5')),
+ recovery_timeout=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_TIMEOUT', '60'))
+ )
+
+ # HTTP client (lazily initialized)
+ self._http_client = None
+ self._http_session = None  # Used by the async client
+
+ self.logger.info(
+ "🛡️ Resilient mode enabled",
+ extra={
+ "http_fallback_url": self.http_fallback_url,
+ "circuit_breaker_threshold": self.circuit_breaker.failure_threshold,
+ "circuit_breaker_timeout": self.circuit_breaker.recovery_timeout
+ }
+ )
+ else:
+ self.circuit_breaker = None
+ self.http_fallback_url = None
+
+ def get_resilient_metrics(self):
+ """Return circuit breaker / fallback metrics"""
+ if not self.resilient_enabled or not self.circuit_breaker:
+ return None
+
+ return {
+ "enabled": self.resilient_enabled,
+ "circuit_state": self.circuit_breaker.get_state(),
+ "failure_count": self.circuit_breaker.failure_count,
+ "last_failure_time": self.circuit_breaker.last_failure_time,
+ "http_fallback_url": self.http_fallback_url
+ }
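The resilient features are driven entirely by environment variables read in _init_resilient_features, so they can be enabled without code changes (a sketch; the fallback URL is a placeholder):

import os

# Opt in to circuit breaking and HTTP fallback before constructing the client.
os.environ['MODEL_CLIENT_RESILIENT_ENABLED'] = 'true'
os.environ['MODEL_CLIENT_HTTP_FALLBACK_URL'] = 'https://example.com/fallback'  # placeholder URL
os.environ['MODEL_CLIENT_CIRCUIT_BREAKER_THRESHOLD'] = '5'  # failures before the circuit opens
os.environ['MODEL_CLIENT_CIRCUIT_BREAKER_TIMEOUT'] = '60'   # seconds before a recovery attempt

# Any client built on BaseClient then exposes the breaker state, e.g.:
# client.get_resilient_metrics() -> {"enabled": True, "circuit_state": "closed", ...}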