tamar-model-client 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,249 @@
1
+ """
2
+ HTTP fallback functionality for resilient clients
3
+
4
+ This module provides mixin classes for HTTP-based fallback when gRPC
5
+ connections fail, supporting both synchronous and asynchronous clients.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ from typing import Optional, Iterator, AsyncIterator, Dict, Any
11
+
12
+ from . import generate_request_id, get_protected_logger
13
+ from ..schemas import ModelRequest, ModelResponse
14
+
15
+ logger = get_protected_logger(__name__)
16
+
17
+
18
class HttpFallbackMixin:
    """HTTP fallback functionality for synchronous clients.

    Provides a degraded-mode transport: when the gRPC channel fails, requests
    are replayed over plain HTTP(S) against ``self.http_fallback_url``.

    Host-class requirements (mixin contract):
      - ``http_fallback_url``: base URL of the HTTP fallback service
      - ``jwt_token`` (optional): bearer token attached to each request
    """

    def _ensure_http_client(self) -> None:
        """Lazily create the shared ``requests.Session`` used for fallback calls."""
        if not hasattr(self, '_http_client') or not self._http_client:
            # Imported lazily so `requests` is only needed when fallback triggers.
            import requests
            self._http_client = requests.Session()

            # Note: JWT token will be set per request in headers

            # Default headers shared by every fallback request
            self._http_client.headers.update({
                'Content-Type': 'application/json',
                'User-Agent': 'TamarModelClient/1.0'
            })

    def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
        """Convert a ModelRequest into the JSON payload expected by the HTTP API.

        Args:
            model_request: validated request object.

        Returns:
            A JSON-serializable dict mirroring the gRPC request fields.
        """
        payload = {
            "provider": model_request.provider.value,
            "model": model_request.model,
            "user_context": model_request.user_context.model_dump(),
            "stream": model_request.stream
        }

        # Provider-specific fields (chat-style `messages` vs `contents`)
        if hasattr(model_request, 'messages') and model_request.messages:
            payload['messages'] = model_request.messages
        if hasattr(model_request, 'contents') and model_request.contents:
            payload['contents'] = model_request.contents

        # Optional routing fields
        if model_request.channel:
            payload['channel'] = model_request.channel.value
        if model_request.invoke_type:
            payload['invoke_type'] = model_request.invoke_type.value

        # Forward extra (pydantic model_extra) fields without overwriting known keys
        if hasattr(model_request, 'model_extra') and model_request.model_extra:
            for key, value in model_request.model_extra.items():
                if key not in payload:
                    payload[key] = value

        return payload

    def _handle_http_stream(self, url: str, payload: Dict[str, Any],
                            timeout: Optional[float], request_id: str, headers: Dict[str, str]) -> Iterator[ModelResponse]:
        """Issue a streaming POST and yield ``ModelResponse`` chunks parsed from SSE.

        Non-JSON data lines are logged and skipped; a ``[DONE]`` sentinel
        terminates the stream.
        """
        # NOTE: the previous version imported `requests` here without using it.
        response = self._http_client.post(
            url,
            json=payload,
            timeout=timeout or 30,
            headers=headers,
            stream=True
        )
        response.raise_for_status()

        # Parse the Server-Sent-Events stream line by line
        for line in response.iter_lines():
            if line:
                line_str = line.decode('utf-8')
                if line_str.startswith('data: '):
                    data_str = line_str[6:]
                    if data_str == '[DONE]':
                        break
                    try:
                        data = json.loads(data_str)
                        yield ModelResponse(**data)
                    except json.JSONDecodeError:
                        logger.warning(f"Failed to parse streaming response: {data_str}")

    def _invoke_http_fallback(self, model_request: ModelRequest,
                              timeout: Optional[float] = None,
                              request_id: Optional[str] = None) -> Any:
        """Send the request over HTTP instead of gRPC.

        Args:
            model_request: the request to send.
            timeout: per-request timeout in seconds (defaults to 30).
            request_id: correlation id; generated when omitted.

        Returns:
            A ``ModelResponse`` for non-streaming calls, or an iterator of
            ``ModelResponse`` chunks when ``model_request.stream`` is True.
        """
        self._ensure_http_client()

        # Generate request ID if not provided
        if not request_id:
            request_id = generate_request_id()

        # Log fallback usage so degraded-mode traffic is visible in monitoring
        logger.warning(
            "🔻 Using HTTP fallback for request",
            extra={
                "request_id": request_id,
                "provider": model_request.provider.value,
                "model": model_request.model,
                "fallback_url": self.http_fallback_url
            }
        )

        # Convert to HTTP format
        http_payload = self._convert_to_http_format(model_request)

        # Construct URL
        url = f"{self.http_fallback_url}/v1/invoke"

        # Build headers with authentication
        headers = {'X-Request-ID': request_id}
        if hasattr(self, 'jwt_token') and self.jwt_token:
            headers['Authorization'] = f'Bearer {self.jwt_token}'

        if model_request.stream:
            # Return the lazy streaming iterator (generator runs on first next())
            return self._handle_http_stream(url, http_payload, timeout, request_id, headers)

        # Non-streaming request
        response = self._http_client.post(
            url,
            json=http_payload,
            timeout=timeout or 30,
            headers=headers
        )
        response.raise_for_status()

        # Parse response
        data = response.json()
        return ModelResponse(**data)
141
+
142
+
143
class AsyncHttpFallbackMixin:
    """HTTP fallback functionality for asynchronous clients.

    Async counterpart of ``HttpFallbackMixin`` built on ``aiohttp``.

    Host-class requirements (mixin contract):
      - ``http_fallback_url``: base URL of the HTTP fallback service
      - ``jwt_token`` (optional): bearer token attached to each request
    """

    async def _ensure_http_client(self) -> None:
        """Lazily create the shared ``aiohttp.ClientSession`` used for fallback calls."""
        if not hasattr(self, '_http_session') or not self._http_session:
            # Imported lazily so `aiohttp` is only needed when fallback triggers.
            import aiohttp
            self._http_session = aiohttp.ClientSession(
                headers={
                    'Content-Type': 'application/json',
                    'User-Agent': 'AsyncTamarModelClient/1.0'
                }
            )

            # Note: JWT token will be set per request in headers

    def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
        """Convert ModelRequest to HTTP payload format (delegates to sync version)."""
        # Pure data transformation — no need for an async variant.
        return HttpFallbackMixin._convert_to_http_format(self, model_request)

    async def _handle_http_stream(self, url: str, payload: Dict[str, Any],
                                  timeout: Optional[float], request_id: str, headers: Dict[str, str]) -> AsyncIterator[ModelResponse]:
        """Issue a streaming POST and yield ``ModelResponse`` chunks parsed from SSE.

        Non-JSON data lines are logged and skipped; a ``[DONE]`` sentinel
        terminates the stream.
        """
        import aiohttp

        # Always apply an explicit timeout (30s default) so the async path
        # matches the sync client instead of falling back to aiohttp's
        # session-level default.
        timeout_obj = aiohttp.ClientTimeout(total=timeout or 30)

        async with self._http_session.post(
            url,
            json=payload,
            timeout=timeout_obj,
            headers=headers
        ) as response:
            response.raise_for_status()

            # Parse the Server-Sent-Events stream line by line
            async for line_bytes in response.content:
                if line_bytes:
                    line_str = line_bytes.decode('utf-8').strip()
                    if line_str.startswith('data: '):
                        data_str = line_str[6:]
                        if data_str == '[DONE]':
                            break
                        try:
                            data = json.loads(data_str)
                            yield ModelResponse(**data)
                        except json.JSONDecodeError:
                            logger.warning(f"Failed to parse streaming response: {data_str}")

    async def _invoke_http_fallback(self, model_request: ModelRequest,
                                    timeout: Optional[float] = None,
                                    request_id: Optional[str] = None) -> Any:
        """Send the request over HTTP instead of gRPC (async).

        Args:
            model_request: the request to send.
            timeout: per-request timeout in seconds (defaults to 30).
            request_id: correlation id; generated when omitted.

        Returns:
            A ``ModelResponse`` for non-streaming calls, or an async iterator
            of ``ModelResponse`` chunks when ``model_request.stream`` is True.
        """
        await self._ensure_http_client()

        # Generate request ID if not provided
        if not request_id:
            request_id = generate_request_id()

        # Log fallback usage so degraded-mode traffic is visible in monitoring
        logger.warning(
            "🔻 Using HTTP fallback for request",
            extra={
                "request_id": request_id,
                "provider": model_request.provider.value,
                "model": model_request.model,
                "fallback_url": self.http_fallback_url
            }
        )

        # Convert to HTTP format
        http_payload = self._convert_to_http_format(model_request)

        # Construct URL
        url = f"{self.http_fallback_url}/v1/invoke"

        # Build headers with authentication
        headers = {'X-Request-ID': request_id}
        if hasattr(self, 'jwt_token') and self.jwt_token:
            headers['Authorization'] = f'Bearer {self.jwt_token}'

        if model_request.stream:
            # Return the async streaming iterator
            return self._handle_http_stream(url, http_payload, timeout, request_id, headers)

        # Non-streaming request
        import aiohttp
        # Same 30s default as the sync client (see _handle_http_stream).
        timeout_obj = aiohttp.ClientTimeout(total=timeout or 30)

        async with self._http_session.post(
            url,
            json=http_payload,
            timeout=timeout_obj,
            headers=headers
        ) as response:
            response.raise_for_status()

            # Parse response
            data = await response.json()
            return ModelResponse(**data)

    async def _cleanup_http_session(self) -> None:
        """Close and drop the aiohttp session (call on client shutdown)."""
        if hasattr(self, '_http_session') and self._http_session:
            await self._http_session.close()
            self._http_session = None
@@ -6,7 +6,8 @@ It includes request ID tracking, JSON formatting, and consistent log configurati
6
6
  """
7
7
 
8
8
  import logging
9
- from typing import Optional
9
+ import threading
10
+ from typing import Optional, Dict
10
11
 
11
12
  from ..json_formatter import JSONFormatter
12
13
  from .utils import get_request_id
@@ -14,6 +15,15 @@ from .utils import get_request_id
14
15
  # gRPC 消息长度限制(32位系统兼容)
15
16
  MAX_MESSAGE_LENGTH = 2 ** 31 - 1
16
17
 
18
+ # SDK 专用的 logger 名称前缀
19
+ TAMAR_LOGGER_PREFIX = "tamar_model_client"
20
+
21
+ # 线程安全的 logger 配置锁
22
+ _logger_lock = threading.Lock()
23
+
24
+ # 已配置的 logger 缓存
25
+ _configured_loggers: Dict[str, logging.Logger] = {}
26
+
17
27
 
18
28
  class RequestIdFilter(logging.Filter):
19
29
  """
@@ -38,9 +48,54 @@ class RequestIdFilter(logging.Filter):
38
48
  return True
39
49
 
40
50
 
51
class TamarLoggerAdapter:
    """Logger adapter that protects the SDK's log format from external changes.

    Wraps a configured ``logging.Logger`` and, before every emit, re-asserts
    the JSON formatter and ``propagate = False`` in case third-party code
    reconfigured the underlying logger after setup.
    """

    def __init__(self, logger: logging.Logger):
        self._logger = logger
        # Guards concurrent format checks/repairs from multiple threads.
        self._lock = threading.Lock()

    def _ensure_format(self):
        """Re-apply the JSON formatter and propagate setting if tampered with."""
        with self._lock:
            # Repair any handler whose formatter was swapped out externally.
            for handler in self._logger.handlers[:]:
                if not isinstance(handler.formatter, JSONFormatter):
                    handler.setFormatter(JSONFormatter())

            # Keep SDK logs from propagating into the host app's root logger.
            if self._logger.propagate:
                self._logger.propagate = False

    def _log(self, level, msg, *args, **kwargs):
        """Common emit path: fix the format first, then delegate to the logger."""
        self._ensure_format()
        getattr(self._logger, level)(msg, *args, **kwargs)

    def debug(self, msg, *args, **kwargs):
        self._log('debug', msg, *args, **kwargs)

    def info(self, msg, *args, **kwargs):
        self._log('info', msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        self._log('warning', msg, *args, **kwargs)

    def error(self, msg, *args, **kwargs):
        self._log('error', msg, *args, **kwargs)

    def critical(self, msg, *args, **kwargs):
        self._log('critical', msg, *args, **kwargs)

    def exception(self, msg, *args, **kwargs):
        # Parity with logging.Logger: error-level with exception info attached.
        # Added so code inside `except` blocks can use the standard idiom.
        self._ensure_format()
        self._logger.exception(msg, *args, **kwargs)
94
+
95
+
41
96
def setup_logger(logger_name: str, level: int = logging.INFO) -> logging.Logger:
    """Set up and configure a namespaced SDK logger (kept for backward compatibility).

    Installs a console handler with the JSON formatter plus a request-id
    filter. Configured loggers are cached, so repeated calls are cheap and
    never duplicate handlers.

    Args:
        logger_name: logger name; automatically prefixed with the SDK namespace.
        level: logging level, defaults to ``logging.INFO``.

    Returns:
        The configured ``logging.Logger``.
    """
    # Force the logger into the SDK namespace so it cannot collide with
    # (or be reconfigured through) the host application's logger tree.
    if not logger_name.startswith(TAMAR_LOGGER_PREFIX):
        logger_name = f"{TAMAR_LOGGER_PREFIX}.{logger_name}"

    with _logger_lock:
        # Return the cached logger if it was already configured.
        if logger_name in _configured_loggers:
            return _configured_loggers[logger_name]

        logger = logging.getLogger(logger_name)

        # Drop any handlers installed by other code to guarantee our format.
        logger.handlers.clear()

        # Dedicated console handler with structured JSON output.
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(JSONFormatter())

        # Give the handler a unique name so it is easy to identify later.
        console_handler.name = f"tamar_handler_{id(console_handler)}"

        # Add handler
        logger.addHandler(console_handler)

        # Set log level
        logger.setLevel(level)

        # Attach the request-id filter used for request tracing.
        logger.addFilter(RequestIdFilter())

        # Key settings:
        # 1. Do not propagate to parent loggers, so host-app formatting
        #    cannot leak into SDK logs.
        logger.propagate = False

        # 2. Re-enable the logger in case logging.config previously disabled
        #    it (e.g. fileConfig with disable_existing_loggers=True). The
        #    attribute always exists on logging.Logger, so no hasattr guard
        #    is needed.
        logger.disabled = False

        # Cache the configured logger.
        _configured_loggers[logger_name] = logger

        return logger
157
+
158
+
159
def get_protected_logger(logger_name: str, level: int = logging.INFO) -> TamarLoggerAdapter:
    """Return a protected logger for the given name.

    The returned adapter guards the SDK's JSON log format against external
    reconfiguration of the underlying ``logging.Logger``.

    Args:
        logger_name: logger name (SDK prefix added automatically).
        level: logging level, defaults to ``logging.INFO``.

    Returns:
        TamarLoggerAdapter: adapter wrapping the configured logger.
    """
    return TamarLoggerAdapter(setup_logger(logger_name, level))
174
+
175
+
176
def reset_logger_config(logger_name: str) -> None:
    """Reset the configuration of an SDK logger.

    Intended for tests or scenarios that need to reconfigure from scratch:
    removes the cache entry and strips all handlers and filters.

    Args:
        logger_name: logger name (SDK prefix added automatically).
    """
    if not logger_name.startswith(TAMAR_LOGGER_PREFIX):
        logger_name = f"{TAMAR_LOGGER_PREFIX}.{logger_name}"

    with _logger_lock:
        # Drop from the cache (no-op when absent), then strip the logger bare.
        _configured_loggers.pop(logger_name, None)

        target = logging.getLogger(logger_name)
        target.handlers.clear()
        target.filters.clear()
@@ -43,9 +43,32 @@ class GrpcErrorHandler:
43
43
  error_context = ErrorContext(error, context)
44
44
 
45
45
  # 记录详细错误日志
46
+ # 将error_context的重要信息平铺到日志的data字段中
47
+ log_data = {
48
+ "log_type": "info",
49
+ "request_id": error_context.request_id,
50
+ "data": {
51
+ "error_code": error_context.error_code.name if error_context.error_code else 'UNKNOWN',
52
+ "error_message": error_context.error_message,
53
+ "provider": error_context.provider,
54
+ "model": error_context.model,
55
+ "method": error_context.method,
56
+ "retry_count": error_context.retry_count,
57
+ "category": error_context._get_error_category(),
58
+ "is_retryable": error_context._is_retryable(),
59
+ "suggested_action": error_context._get_suggested_action(),
60
+ "debug_string": error_context.error_debug_string,
61
+ "is_network_cancelled": error_context.is_network_cancelled() if error_context.error_code == grpc.StatusCode.CANCELLED else None
62
+ }
63
+ }
64
+
65
+ # 如果上下文中有 duration,添加到日志中
66
+ if 'duration' in context:
67
+ log_data['duration'] = context['duration']
68
+
46
69
  self.logger.error(
47
- f"gRPC Error occurred: {error_context.error_code}",
48
- extra=error_context.to_dict()
70
+ f"gRPC Error occurred: {error_context.error_code.name if error_context.error_code else 'UNKNOWN'}",
71
+ extra=log_data
49
72
  )
50
73
 
51
74
  # 更新错误统计
@@ -192,6 +215,10 @@ class EnhancedRetryHandler:
192
215
  Raises:
193
216
  TamarModelException: 包装后的异常
194
217
  """
218
+ # 记录开始时间
219
+ import time
220
+ method_start_time = time.time()
221
+
195
222
  context = context or {}
196
223
  last_exception = None
197
224
 
@@ -207,13 +234,32 @@ class EnhancedRetryHandler:
207
234
  # 判断是否可以重试
208
235
  if not self._should_retry(e, attempt):
209
236
  # 不可重试或已达到最大重试次数
237
+ current_duration = time.time() - method_start_time
238
+ context['duration'] = current_duration
210
239
  last_exception = self.error_handler.handle_error(e, context)
211
240
  break
241
+
242
+ # 计算当前耗时
243
+ current_duration = time.time() - method_start_time
212
244
 
213
245
  # 记录重试日志
246
+ log_data = {
247
+ "log_type": "info",
248
+ "request_id": error_context.request_id,
249
+ "data": {
250
+ "error_code": error_context.error_code.name if error_context.error_code else 'UNKNOWN',
251
+ "error_message": error_context.error_message,
252
+ "retry_count": attempt,
253
+ "max_retries": self.max_retries,
254
+ "category": error_context._get_error_category(),
255
+ "is_retryable": True, # 既然在重试,说明是可重试的
256
+ "method": error_context.method
257
+ },
258
+ "duration": current_duration
259
+ }
214
260
  logger.warning(
215
261
  f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
216
- extra=error_context.to_dict()
262
+ extra=log_data
217
263
  )
218
264
 
219
265
  # 执行退避等待
@@ -221,6 +267,7 @@ class EnhancedRetryHandler:
221
267
  delay = self._calculate_backoff(attempt)
222
268
  await asyncio.sleep(delay)
223
269
 
270
+ context['duration'] = current_duration
224
271
  last_exception = self.error_handler.handle_error(e, context)
225
272
 
226
273
  except Exception as e:
@@ -239,12 +286,19 @@ class EnhancedRetryHandler:
239
286
 
240
287
  def _should_retry(self, error: grpc.RpcError, attempt: int) -> bool:
241
288
  """判断是否应该重试"""
242
- if attempt >= self.max_retries:
243
- return False
244
-
245
289
  error_code = error.code()
246
290
  policy = RETRY_POLICY.get(error_code, {})
247
291
 
292
+ # 先检查错误级别的 max_attempts 配置
293
+ # max_attempts 表示最大重试次数(不包括初始请求)
294
+ error_max_attempts = policy.get('max_attempts', self.max_retries)
295
+ if attempt >= error_max_attempts:
296
+ return False
297
+
298
+ # 再检查全局的 max_retries
299
+ if attempt >= self.max_retries:
300
+ return False
301
+
248
302
  # 检查基本重试策略
249
303
  retryable = policy.get('retryable', False)
250
304
  if retryable == False:
@@ -17,7 +17,10 @@ ERROR_CATEGORIES = {
17
17
  'NETWORK': [
18
18
  grpc.StatusCode.UNAVAILABLE,
19
19
  grpc.StatusCode.DEADLINE_EXCEEDED,
20
- grpc.StatusCode.ABORTED,
20
+ grpc.StatusCode.CANCELLED, # 网络中断导致的取消
21
+ ],
22
+ 'CONCURRENCY': [
23
+ grpc.StatusCode.ABORTED, # 并发冲突,单独分类便于监控
21
24
  ],
22
25
  'AUTH': [
23
26
  grpc.StatusCode.UNAUTHENTICATED,
@@ -71,6 +74,19 @@ RETRY_POLICY = {
71
74
  'action': 'refresh_token', # 特殊动作
72
75
  'max_attempts': 1
73
76
  },
77
+ grpc.StatusCode.CANCELLED: {
78
+ 'retryable': True,
79
+ 'backoff': 'linear', # 线性退避,网络问题通常不需要指数退避
80
+ 'max_attempts': 2, # 最大重试次数(不包括初始请求),总共会尝试3次
81
+ 'check_details': False # 不检查详细信息,统一重试
82
+ },
83
+ grpc.StatusCode.ABORTED: {
84
+ 'retryable': True,
85
+ 'backoff': 'exponential', # 指数退避,避免加剧并发竞争
86
+ 'max_attempts': 3, # 适中的重试次数
87
+ 'jitter': True, # 添加随机延迟,减少竞争
88
+ 'check_details': False
89
+ },
74
90
  # 不可重试的错误
75
91
  grpc.StatusCode.INVALID_ARGUMENT: {'retryable': False},
76
92
  grpc.StatusCode.NOT_FOUND: {'retryable': False},
@@ -160,6 +176,7 @@ class ErrorContext:
160
176
  """获取建议的处理动作"""
161
177
  suggestions = {
162
178
  'NETWORK': '检查网络连接,稍后重试',
179
+ 'CONCURRENCY': '并发冲突,系统会自动重试',
163
180
  'AUTH': '检查认证信息,可能需要刷新 Token',
164
181
  'VALIDATION': '检查请求参数是否正确',
165
182
  'RESOURCE': '检查资源限制或等待一段时间',
@@ -167,6 +184,37 @@ class ErrorContext:
167
184
  'DATA': '数据损坏或丢失,请检查输入数据',
168
185
  }
169
186
  return suggestions.get(self._get_error_category(), '未知错误,请联系技术支持')
187
+
188
+ def is_network_cancelled(self) -> bool:
189
+ """
190
+ 判断 CANCELLED 错误是否由网络中断导致
191
+
192
+ Returns:
193
+ bool: 如果是网络中断导致的 CANCELLED 返回 True
194
+ """
195
+ if self.error_code != grpc.StatusCode.CANCELLED:
196
+ return False
197
+
198
+ # 检查错误消息中是否包含网络相关的关键词
199
+ error_msg = (self.error_message or '').lower()
200
+ debug_msg = (self.error_debug_string or '').lower()
201
+
202
+ network_patterns = [
203
+ 'connection reset',
204
+ 'connection refused',
205
+ 'connection closed',
206
+ 'network unreachable',
207
+ 'broken pipe',
208
+ 'socket closed',
209
+ 'eof',
210
+ 'transport'
211
+ ]
212
+
213
+ for pattern in network_patterns:
214
+ if pattern in error_msg or pattern in debug_msg:
215
+ return True
216
+
217
+ return False
170
218
 
171
219
 
172
220
  # ===== 异常类层级 =====