tamar-model-client 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -22,6 +22,7 @@ import asyncio
 import atexit
 import json
 import logging
+import random
 import time
 from typing import Optional, AsyncIterator, Union

@@ -31,7 +32,7 @@ from grpc import RpcError
 from .core import (
     generate_request_id,
     set_request_id,
-    setup_logger,
+    get_protected_logger,
     MAX_MESSAGE_LENGTH
 )
 from .core.base_client import BaseClient
@@ -42,12 +43,13 @@ from .exceptions import ConnectionError, TamarModelException
 from .error_handler import EnhancedRetryHandler
 from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
 from .generated import model_service_pb2, model_service_pb2_grpc
+from .core.http_fallback import AsyncHttpFallbackMixin

-# Configure the logger
-logger = setup_logger(__name__)
+# Configure the logger (using the protected logger)
+logger = get_protected_logger(__name__)


-class AsyncTamarModelClient(BaseClient):
+class AsyncTamarModelClient(BaseClient, AsyncHttpFallbackMixin):
     """
     Tamar Model Client asynchronous client

@@ -127,6 +129,10 @@ class AsyncTamarModelClient(BaseClient):
         self._closed = True
         logger.info("🔒 gRPC channel closed",
                     extra={"log_type": "info", "data": {"status": "closed"}})
+
+        # Clean up the HTTP session (if any)
+        if self.resilient_enabled:
+            await self._cleanup_http_session()

     async def __aenter__(self):
         """Async context manager entry"""
@@ -238,7 +244,16 @@ class AsyncTamarModelClient(BaseClient):
         Raises:
             TamarModelException: raised when all retries fail
         """
-        return await self.retry_handler.execute_with_retry(func, *args, **kwargs)
+        # Extract request_id from kwargs (if present), then remove it
+        request_id = kwargs.pop('request_id', None) or get_request_id()
+
+        # Build a context that carries the request_id
+        context = {
+            'method': func.__name__ if hasattr(func, '__name__') else 'unknown',
+            'client_version': 'async',
+            'request_id': request_id,
+        }
+        return await self.retry_handler.execute_with_retry(func, *args, context=context, **kwargs)

     async def _retry_request_stream(self, func, *args, **kwargs):
         """
@@ -254,33 +269,174 @@ class AsyncTamarModelClient(BaseClient):
         Returns:
             AsyncIterator: iterator over the streaming response
         """
+        # Record the method start time
+        import time
+        method_start_time = time.time()
+
+        # Extract request_id from kwargs (if present), then remove it
+        request_id = kwargs.pop('request_id', None) or get_request_id()
+
         last_exception = None
+        context = {
+            'method': 'stream',
+            'client_version': 'async',
+            'request_id': request_id,
+        }

         for attempt in range(self.max_retries + 1):
             try:
+                context['retry_count'] = attempt
                 # Try to create the stream
                 async for item in func(*args, **kwargs):
                     yield item
                 return

             except RpcError as e:
-                last_exception = e
-                if attempt < self.max_retries:
+                # Use the smarter retry decision
+                context['retry_count'] = attempt
+
+                # Build the error context and decide whether to retry
+                from .exceptions import ErrorContext, get_retry_policy
+                error_context = ErrorContext(e, context)
+                error_code = e.code()
+                policy = get_retry_policy(error_code)
+
+                # Check the per-error max_attempts setting first
+                # max_attempts is the maximum number of retries (excluding the initial request)
+                error_max_attempts = policy.get('max_attempts', self.max_retries)
+                if attempt >= error_max_attempts:
+                    should_retry = False
+                elif attempt >= self.max_retries:
+                    should_retry = False
+                else:
+                    retryable = policy.get('retryable', False)
+                    if retryable == True:
+                        should_retry = True
+                    elif retryable == 'conditional':
+                        # Conditional retry; handle CANCELLED specially
+                        if error_code == grpc.StatusCode.CANCELLED:
+                            should_retry = error_context.is_network_cancelled()
+                        else:
+                            should_retry = self._check_error_details_for_retry(e)
+                    else:
+                        should_retry = False
+
+                if should_retry:
+                    current_duration = time.time() - method_start_time
+                    log_data = {
+                        "log_type": "info",
+                        "request_id": context.get('request_id'),
+                        "data": {
+                            "error_code": e.code().name if e.code() else 'UNKNOWN',
+                            "retry_count": attempt,
+                            "max_retries": self.max_retries,
+                            "method": "stream"
+                        },
+                        "duration": current_duration
+                    }
                     logger.warning(
-                        f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
-                        extra={"retry_count": attempt, "error_code": str(e.code())}
+                        f"Stream attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()} (will retry)",
+                        extra=log_data
                     )
-                    await asyncio.sleep(self.retry_delay * (attempt + 1))
+
+                    # Compute the backoff delay
+                    delay = self._calculate_backoff(attempt, error_code)
+                    await asyncio.sleep(delay)
                 else:
+                    # No retry, or the maximum number of retries has been reached
+                    current_duration = time.time() - method_start_time
+                    log_data = {
+                        "log_type": "info",
+                        "request_id": context.get('request_id'),
+                        "data": {
+                            "error_code": e.code().name if e.code() else 'UNKNOWN',
+                            "retry_count": attempt,
+                            "max_retries": self.max_retries,
+                            "method": "stream",
+                            "will_retry": False
+                        },
+                        "duration": current_duration
+                    }
+                    logger.error(
+                        f"Stream failed: {e.code()} (no retry)",
+                        extra=log_data
+                    )
+                    context['duration'] = current_duration
+                    last_exception = self.error_handler.handle_error(e, context)
                     break
+
+                last_exception = e
+
             except Exception as e:
+                context['retry_count'] = attempt
                 raise TamarModelException(str(e)) from e

         if last_exception:
-            raise self.error_handler.handle_error(last_exception, {"retry_count": self.max_retries})
+            if isinstance(last_exception, TamarModelException):
+                raise last_exception
+            else:
+                raise self.error_handler.handle_error(last_exception, context)
         else:
             raise TamarModelException("Unknown streaming error occurred")

+    def _check_error_details_for_retry(self, error: RpcError) -> bool:
+        """Inspect the error details to decide whether to retry"""
+        error_message = error.details().lower() if error.details() else ""
+
+        # Retryable error patterns
+        retryable_patterns = [
+            'temporary', 'timeout', 'unavailable',
+            'connection', 'network', 'try again'
+        ]
+
+        for pattern in retryable_patterns:
+            if pattern in error_message:
+                return True
+
+        return False
+
+    def _calculate_backoff(self, attempt: int, error_code=None) -> float:
+        """
+        Compute the backoff delay, supporting different backoff strategies
+
+        Args:
+            attempt: the current retry count
+            error_code: gRPC status code, used to pick the backoff strategy
+        """
+        max_delay = 60.0
+        base_delay = self.retry_delay
+
+        # Look up the retry policy for this error
+        if error_code:
+            from .exceptions import get_retry_policy
+            policy = get_retry_policy(error_code)
+            backoff_type = policy.get('backoff', 'exponential')
+            use_jitter = policy.get('jitter', False)
+        else:
+            backoff_type = 'exponential'
+            use_jitter = False
+
+        # Compute the delay according to the backoff type
+        if backoff_type == 'linear':
+            # Linear backoff: delay * (attempt + 1)
+            delay = min(base_delay * (attempt + 1), max_delay)
+        else:
+            # Exponential backoff: delay * 2^attempt
+            delay = min(base_delay * (2 ** attempt), max_delay)
+
+        # Add jitter
+        if use_jitter:
+            jitter_factor = 0.2  # Wider jitter range to reduce contention
+            jitter = random.uniform(0, delay * jitter_factor)
+            delay += jitter
+        else:
+            # Small default jitter to avoid fully synchronized retries
+            jitter_factor = 0.05
+            jitter = random.uniform(0, delay * jitter_factor)
+            delay += jitter
+
+        return delay
+

     async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
         """
         Handle the streaming response
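Note: the backoff arithmetic introduced above can be sanity-checked with a small standalone sketch (illustrative only, not part of the package). Assuming a base delay of 1 second and the default 5% jitter, the exponential branch yields roughly 1 s, 2 s, 4 s, 8 s, capped at 60 s:

import random

base_delay = 1.0   # illustrative stand-in for self.retry_delay
max_delay = 60.0

for attempt in range(4):
    delay = min(base_delay * (2 ** attempt), max_delay)   # exponential branch
    delay += random.uniform(0, delay * 0.05)               # default small jitter
    print(f"attempt {attempt}: sleep ~{delay:.2f}s")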
@@ -328,7 +484,7 @@ class AsyncTamarModelClient(BaseClient):
         chunk_count = 0

         # Use the retry logic to obtain the stream generator
-        stream_generator = self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
+        stream_generator = self._retry_request_stream(self._stream, request, metadata, invoke_timeout, request_id=get_request_id())

         try:
             async for response in stream_generator:
@@ -424,6 +580,12 @@ class AsyncTamarModelClient(BaseClient):
             ValidationError: input validation failed.
             ConnectionError: failed to connect to the server.
         """
+        # If circuit breaking is enabled and the breaker is open, go straight to HTTP
+        if self.resilient_enabled and self.circuit_breaker and self.circuit_breaker.is_open:
+            if self.http_fallback_url:
+                logger.warning("🔻 Circuit breaker is OPEN, using HTTP fallback")
+                return await self._invoke_http_fallback(model_request, timeout, request_id)
+
         await self._ensure_initialized()

         if not self.default_payload:
@@ -477,7 +639,7 @@ class AsyncTamarModelClient(BaseClient):
                 # For streaming responses, return the logging wrapper directly
                 return self._stream_with_logging(request, metadata, invoke_timeout, start_time, model_request)
             else:
-                result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
+                result = await self._retry_request(self._invoke_request, request, metadata, invoke_timeout, request_id=request_id)

                 # Log the success of the non-streaming response
                 duration = time.time() - start_time
@@ -491,9 +653,14 @@ class AsyncTamarModelClient(BaseClient):
                         "data": ResponseHandler.build_log_data(model_request, result)
                     }
                 )
+
+                # Record the success (if circuit breaking is enabled)
+                if self.resilient_enabled and self.circuit_breaker:
+                    self.circuit_breaker.record_success()
+
                 return result

-        except grpc.RpcError as e:
+        except (ConnectionError, grpc.RpcError) as e:
             duration = time.time() - start_time
             error_message = f"❌ Invoke gRPC failed: {str(e)}"
             logger.error(error_message, exc_info=True,
@@ -506,6 +673,18 @@ class AsyncTamarModelClient(BaseClient):
                                 error=e
                             )
                         })
+
+            # Record the failure and attempt fallback (if circuit breaking is enabled)
+            if self.resilient_enabled and self.circuit_breaker:
+                # Pass the error code to the circuit breaker for smarter failure accounting
+                error_code = e.code() if hasattr(e, 'code') else None
+                self.circuit_breaker.record_failure(error_code)
+
+                # Fall back if fallback is possible
+                if self.http_fallback_url and self.circuit_breaker.should_fallback():
+                    logger.warning(f"🔻 gRPC failed, falling back to HTTP: {str(e)}")
+                    return await self._invoke_http_fallback(model_request, timeout, request_id)
+
             raise e
         except Exception as e:
             duration = time.time() - start_time
@@ -590,7 +769,8 @@ class AsyncTamarModelClient(BaseClient):
                 self.stub.BatchInvoke,
                 batch_request,
                 metadata=metadata,
-                timeout=invoke_timeout
+                timeout=invoke_timeout,
+                request_id=request_id
             )

             # Build the response object
@@ -0,0 +1,140 @@
+"""
+Circuit Breaker implementation for resilient client
+
+This module provides a thread-safe circuit breaker pattern implementation
+to handle failures gracefully and prevent cascading failures.
+"""
+
+import time
+import logging
+from enum import Enum
+from threading import Lock
+from typing import Optional
+
+from .core.logging_setup import get_protected_logger
+
+logger = get_protected_logger(__name__)
+
+
+class CircuitState(Enum):
+    """Circuit breaker states"""
+    CLOSED = "closed"        # Normal operation
+    OPEN = "open"            # Circuit is broken, requests fail fast
+    HALF_OPEN = "half_open"  # Testing if service has recovered
+
+
+class CircuitBreaker:
+    """
+    Thread-safe circuit breaker implementation
+
+    The circuit breaker prevents cascading failures by failing fast when
+    a service is unavailable, and automatically recovers when the service
+    becomes available again.
+    """
+
+    def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
+        """
+        Initialize the circuit breaker
+
+        Args:
+            failure_threshold: Number of consecutive failures before opening circuit
+            recovery_timeout: Seconds to wait before attempting recovery
+        """
+        self.failure_threshold = failure_threshold
+        self.recovery_timeout = recovery_timeout
+        self.failure_count = 0
+        self.last_failure_time: Optional[float] = None
+        self.state = CircuitState.CLOSED
+        self._lock = Lock()
+
+    @property
+    def is_open(self) -> bool:
+        """Check if circuit breaker is open"""
+        with self._lock:
+            if self.state == CircuitState.OPEN:
+                # Check if we should attempt recovery
+                if (self.last_failure_time and
+                        time.time() - self.last_failure_time > self.recovery_timeout):
+                    self.state = CircuitState.HALF_OPEN
+                    logger.info("🔄 Circuit breaker entering HALF_OPEN state")
+                    return False
+                return True
+            return False
+
+    def record_success(self) -> None:
+        """Record a successful request"""
+        with self._lock:
+            if self.state == CircuitState.HALF_OPEN:
+                # Success in half-open state means service has recovered
+                self.state = CircuitState.CLOSED
+                self.failure_count = 0
+                logger.info("🔺 Circuit breaker recovered to CLOSED state")
+            elif self.state == CircuitState.CLOSED and self.failure_count > 0:
+                # Reset failure count on success
+                self.failure_count = 0
+
+    def record_failure(self, error_code=None) -> None:
+        """
+        Record a failed request
+
+        Args:
+            error_code: gRPC error code for failure classification
+        """
+        with self._lock:
+            # Some error types are excluded from the breaker statistics or carry a lower weight
+            if error_code and self._should_ignore_for_circuit_breaker(error_code):
+                return
+
+            # ABORTED errors carry a lower weight because they are usually transient concurrency issues
+            import grpc
+            if error_code == grpc.StatusCode.ABORTED:
+                # An ABORTED error only counts as half a failure
+                self.failure_count += 0.5
+            else:
+                self.failure_count += 1
+
+            self.last_failure_time = time.time()
+
+            if self.failure_count >= self.failure_threshold:
+                if self.state != CircuitState.OPEN:
+                    self.state = CircuitState.OPEN
+                    logger.warning(
+                        f"🔻 Circuit breaker OPENED after {self.failure_count} failures",
+                        extra={
+                            "failure_count": self.failure_count,
+                            "threshold": self.failure_threshold,
+                            "trigger_error": error_code.name if error_code else "unknown"
+                        }
+                    )
+
+    def _should_ignore_for_circuit_breaker(self, error_code) -> bool:
+        """
+        Decide whether the circuit breaker should ignore this error
+
+        Some errors should not trip the breaker:
+        - requests actively cancelled by the client
+        - authentication-related errors (they do not mean the service is unavailable)
+        """
+        import grpc
+        ignored_codes = {
+            grpc.StatusCode.UNAUTHENTICATED,    # Authentication issue, not a service issue
+            grpc.StatusCode.PERMISSION_DENIED,  # Permission issue, not a service issue
+            grpc.StatusCode.INVALID_ARGUMENT,   # Invalid argument, not a service issue
+        }
+        return error_code in ignored_codes
+
+    def should_fallback(self) -> bool:
+        """Check if fallback should be used"""
+        return self.is_open and self.state != CircuitState.HALF_OPEN
+
+    def get_state(self) -> str:
+        """Get current circuit state"""
+        return self.state.value
+
+    def reset(self) -> None:
+        """Reset circuit breaker to initial state"""
+        with self._lock:
+            self.state = CircuitState.CLOSED
+            self.failure_count = 0
+            self.last_failure_time = None
+            logger.info("🔄 Circuit breaker reset to CLOSED state")
@@ -16,6 +16,9 @@ from .utils import (
 from .logging_setup import (
     setup_logger,
     RequestIdFilter,
+    TamarLoggerAdapter,
+    get_protected_logger,
+    reset_logger_config,
     MAX_MESSAGE_LENGTH
 )

@@ -30,5 +33,8 @@ __all__ = [
     # Logging
     'setup_logger',
     'RequestIdFilter',
+    'TamarLoggerAdapter',
+    'get_protected_logger',
+    'reset_logger_config',
     'MAX_MESSAGE_LENGTH',
 ]
@@ -12,7 +12,7 @@ from abc import ABC, abstractmethod

 from ..auth import JWTAuthHandler
 from ..error_handler import GrpcErrorHandler, ErrorRecoveryStrategy
-from .logging_setup import MAX_MESSAGE_LENGTH, setup_logger
+from .logging_setup import MAX_MESSAGE_LENGTH, get_protected_logger


 class BaseClient(ABC):
@@ -79,7 +79,7 @@ class BaseClient(ABC):
             os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))

         # === Logging configuration ===
-        self.logger = setup_logger(logger_name or __name__)
+        self.logger = get_protected_logger(logger_name or __name__)

         # === Error handler ===
         self.error_handler = GrpcErrorHandler(self.logger)
@@ -87,6 +87,9 @@ class BaseClient(ABC):

         # === Connection state ===
         self._closed = False
+
+        # === Circuit-breaking / fallback configuration ===
+        self._init_resilient_features()

     def build_channel_options(self) -> list:
         """
@@ -165,4 +168,54 @@ class BaseClient(ABC):
     @abstractmethod
     def __exit__(self, exc_type, exc_val, exc_tb):
         """Exit the context manager (implemented by subclasses)"""
-        pass
+        pass
+
+    def _init_resilient_features(self):
+        """Initialize the circuit-breaking / fallback features"""
+        # Whether circuit breaking / fallback is enabled
+        self.resilient_enabled = os.getenv('MODEL_CLIENT_RESILIENT_ENABLED', 'false').lower() == 'true'
+
+        if self.resilient_enabled:
+            # HTTP fallback URL
+            self.http_fallback_url = os.getenv('MODEL_CLIENT_HTTP_FALLBACK_URL')
+
+            if not self.http_fallback_url:
+                self.logger.warning("🔶 Resilient mode enabled but MODEL_CLIENT_HTTP_FALLBACK_URL not set")
+                self.resilient_enabled = False
+                return
+
+            # Initialize the circuit breaker
+            from ..circuit_breaker import CircuitBreaker
+            self.circuit_breaker = CircuitBreaker(
+                failure_threshold=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_THRESHOLD', '5')),
+                recovery_timeout=int(os.getenv('MODEL_CLIENT_CIRCUIT_BREAKER_TIMEOUT', '60'))
+            )
+
+            # HTTP client (lazily initialized)
+            self._http_client = None
+            self._http_session = None  # used by the async client
+
+            self.logger.info(
+                "🛡️ Resilient mode enabled",
+                extra={
+                    "http_fallback_url": self.http_fallback_url,
+                    "circuit_breaker_threshold": self.circuit_breaker.failure_threshold,
+                    "circuit_breaker_timeout": self.circuit_breaker.recovery_timeout
+                }
+            )
+        else:
+            self.circuit_breaker = None
+            self.http_fallback_url = None
+
+    def get_resilient_metrics(self):
+        """Return the circuit-breaking / fallback metrics"""
+        if not self.resilient_enabled or not self.circuit_breaker:
+            return None
+
+        return {
+            "enabled": self.resilient_enabled,
+            "circuit_state": self.circuit_breaker.get_state(),
+            "failure_count": self.circuit_breaker.failure_count,
+            "last_failure_time": self.circuit_breaker.last_failure_time,
+            "http_fallback_url": self.http_fallback_url
+        }
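Taken together, the resilient mode added to BaseClient is driven entirely by environment variables read in _init_resilient_features. A configuration sketch (the URL is a placeholder; the variable names and metric keys come from the diff above):

import os

os.environ["MODEL_CLIENT_RESILIENT_ENABLED"] = "true"                      # turn the feature on
os.environ["MODEL_CLIENT_HTTP_FALLBACK_URL"] = "https://fallback.example"  # placeholder URL; required
os.environ["MODEL_CLIENT_CIRCUIT_BREAKER_THRESHOLD"] = "5"                 # failures before the circuit opens
os.environ["MODEL_CLIENT_CIRCUIT_BREAKER_TIMEOUT"] = "60"                  # seconds before a recovery attempt

# Any client built on BaseClient then exposes the breaker state, e.g.:
# client.get_resilient_metrics()
# -> {"enabled": True, "circuit_state": "closed", "failure_count": 0, ...}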