tamar-model-client 0.1.20__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only and reflects the changes between the two versions.
@@ -0,0 +1,249 @@
+ """
+ HTTP fallback functionality for resilient clients
+
+ This module provides mixin classes for HTTP-based fallback when gRPC
+ connections fail, supporting both synchronous and asynchronous clients.
+ """
+
+ import json
+ import logging
+ from typing import Optional, Iterator, AsyncIterator, Dict, Any
+
+ from . import generate_request_id, get_protected_logger
+ from ..schemas import ModelRequest, ModelResponse
+
+ logger = get_protected_logger(__name__)
+
+
+ class HttpFallbackMixin:
+     """HTTP fallback functionality for synchronous clients"""
+
+     def _ensure_http_client(self) -> None:
+         """Ensure HTTP client is initialized"""
+         if not hasattr(self, '_http_client') or not self._http_client:
+             import requests
+             self._http_client = requests.Session()
+
+             # Set authentication header if available
+             # Note: JWT token will be set per request in headers
+
+             # Set default headers
+             self._http_client.headers.update({
+                 'Content-Type': 'application/json',
+                 'User-Agent': 'TamarModelClient/1.0'
+             })
+
+     def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
+         """Convert ModelRequest to HTTP payload format"""
+         payload = {
+             "provider": model_request.provider.value,
+             "model": model_request.model,
+             "user_context": model_request.user_context.model_dump(),
+             "stream": model_request.stream
+         }
+
+         # Add provider-specific fields
+         if hasattr(model_request, 'messages') and model_request.messages:
+             payload['messages'] = model_request.messages
+         if hasattr(model_request, 'contents') and model_request.contents:
+             payload['contents'] = model_request.contents
+
+         # Add optional fields
+         if model_request.channel:
+             payload['channel'] = model_request.channel.value
+         if model_request.invoke_type:
+             payload['invoke_type'] = model_request.invoke_type.value
+
+         # Add extra parameters
+         if hasattr(model_request, 'model_extra') and model_request.model_extra:
+             for key, value in model_request.model_extra.items():
+                 if key not in payload:
+                     payload[key] = value
+
+         return payload
+
+     def _handle_http_stream(self, url: str, payload: Dict[str, Any],
+                             timeout: Optional[float], request_id: str, headers: Dict[str, str]) -> Iterator[ModelResponse]:
+         """Handle HTTP streaming response"""
+         import requests
+
+         response = self._http_client.post(
+             url,
+             json=payload,
+             timeout=timeout or 30,
+             headers=headers,
+             stream=True
+         )
+         response.raise_for_status()
+
+         # Parse SSE stream
+         for line in response.iter_lines():
+             if line:
+                 line_str = line.decode('utf-8')
+                 if line_str.startswith('data: '):
+                     data_str = line_str[6:]
+                     if data_str == '[DONE]':
+                         break
+                     try:
+                         data = json.loads(data_str)
+                         yield ModelResponse(**data)
+                     except json.JSONDecodeError:
+                         logger.warning(f"Failed to parse streaming response: {data_str}")
+
+     def _invoke_http_fallback(self, model_request: ModelRequest,
+                               timeout: Optional[float] = None,
+                               request_id: Optional[str] = None) -> Any:
+         """HTTP fallback implementation"""
+         self._ensure_http_client()
+
+         # Generate request ID if not provided
+         if not request_id:
+             request_id = generate_request_id()
+
+         # Log fallback usage
+         logger.warning(
+             f"🔻 Using HTTP fallback for request",
+             extra={
+                 "request_id": request_id,
+                 "provider": model_request.provider.value,
+                 "model": model_request.model,
+                 "fallback_url": self.http_fallback_url
+             }
+         )
+
+         # Convert to HTTP format
+         http_payload = self._convert_to_http_format(model_request)
+
+         # Construct URL
+         url = f"{self.http_fallback_url}/v1/invoke"
+
+         # Build headers with authentication
+         headers = {'X-Request-ID': request_id}
+         if hasattr(self, 'jwt_token') and self.jwt_token:
+             headers['Authorization'] = f'Bearer {self.jwt_token}'
+
+         if model_request.stream:
+             # Return streaming iterator
+             return self._handle_http_stream(url, http_payload, timeout, request_id, headers)
+         else:
+             # Non-streaming request
+             response = self._http_client.post(
+                 url,
+                 json=http_payload,
+                 timeout=timeout or 30,
+                 headers=headers
+             )
+             response.raise_for_status()
+
+             # Parse response
+             data = response.json()
+             return ModelResponse(**data)
+
+
+ class AsyncHttpFallbackMixin:
+     """HTTP fallback functionality for asynchronous clients"""
+
+     async def _ensure_http_client(self) -> None:
+         """Ensure async HTTP client is initialized"""
+         if not hasattr(self, '_http_session') or not self._http_session:
+             import aiohttp
+             self._http_session = aiohttp.ClientSession(
+                 headers={
+                     'Content-Type': 'application/json',
+                     'User-Agent': 'AsyncTamarModelClient/1.0'
+                 }
+             )
+
+             # Note: JWT token will be set per request in headers
+
+     def _convert_to_http_format(self, model_request: ModelRequest) -> Dict[str, Any]:
+         """Convert ModelRequest to HTTP payload format (reuse sync version)"""
+         # This method doesn't need to be async, so we can reuse the sync version
+         return HttpFallbackMixin._convert_to_http_format(self, model_request)
+
+     async def _handle_http_stream(self, url: str, payload: Dict[str, Any],
+                                   timeout: Optional[float], request_id: str, headers: Dict[str, str]) -> AsyncIterator[ModelResponse]:
+         """Handle async HTTP streaming response"""
+         import aiohttp
+
+         timeout_obj = aiohttp.ClientTimeout(total=timeout or 30) if timeout else None
+
+         async with self._http_session.post(
+             url,
+             json=payload,
+             timeout=timeout_obj,
+             headers=headers
+         ) as response:
+             response.raise_for_status()
+
+             # Parse SSE stream
+             async for line_bytes in response.content:
+                 if line_bytes:
+                     line_str = line_bytes.decode('utf-8').strip()
+                     if line_str.startswith('data: '):
+                         data_str = line_str[6:]
+                         if data_str == '[DONE]':
+                             break
+                         try:
+                             data = json.loads(data_str)
+                             yield ModelResponse(**data)
+                         except json.JSONDecodeError:
+                             logger.warning(f"Failed to parse streaming response: {data_str}")
+
+     async def _invoke_http_fallback(self, model_request: ModelRequest,
+                                     timeout: Optional[float] = None,
+                                     request_id: Optional[str] = None) -> Any:
+         """Async HTTP fallback implementation"""
+         await self._ensure_http_client()
+
+         # Generate request ID if not provided
+         if not request_id:
+             request_id = generate_request_id()
+
+         # Log fallback usage
+         logger.warning(
+             f"🔻 Using HTTP fallback for request",
+             extra={
+                 "request_id": request_id,
+                 "provider": model_request.provider.value,
+                 "model": model_request.model,
+                 "fallback_url": self.http_fallback_url
+             }
+         )
+
+         # Convert to HTTP format
+         http_payload = self._convert_to_http_format(model_request)
+
+         # Construct URL
+         url = f"{self.http_fallback_url}/v1/invoke"
+
+         # Build headers with authentication
+         headers = {'X-Request-ID': request_id}
+         if hasattr(self, 'jwt_token') and self.jwt_token:
+             headers['Authorization'] = f'Bearer {self.jwt_token}'
+
+         if model_request.stream:
+             # Return async streaming iterator
+             return self._handle_http_stream(url, http_payload, timeout, request_id, headers)
+         else:
+             # Non-streaming request
+             import aiohttp
+             timeout_obj = aiohttp.ClientTimeout(total=timeout or 30) if timeout else None
+
+             async with self._http_session.post(
+                 url,
+                 json=http_payload,
+                 timeout=timeout_obj,
+                 headers=headers
+             ) as response:
+                 response.raise_for_status()
+
+                 # Parse response
+                 data = await response.json()
+                 return ModelResponse(**data)
+
+     async def _cleanup_http_session(self) -> None:
+         """Clean up HTTP session"""
+         if hasattr(self, '_http_session') and self._http_session:
+             await self._http_session.close()
+             self._http_session = None
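For context on what `_handle_http_stream` expects from the fallback endpoint, here is a minimal, self-contained sketch of the same `data: ...` / `[DONE]` server-sent-event framing that both mixins parse. The sample byte lines and the plain-dict results are illustrative stand-ins, not the SDK's `ModelResponse` model.

# Illustrative only: mirrors the SSE framing parsed by _handle_http_stream,
# using plain dicts instead of the SDK's ModelResponse model.
import json
from typing import Iterable, Iterator

def parse_sse_lines(lines: Iterable[bytes]) -> Iterator[dict]:
    """Yield one JSON payload per 'data: ...' line, stopping at '[DONE]'."""
    for line in lines:
        if not line:
            continue                       # skip keep-alive blank lines
        line_str = line.decode('utf-8').strip()
        if not line_str.startswith('data: '):
            continue                       # ignore non-data SSE fields
        data_str = line_str[len('data: '):]
        if data_str == '[DONE]':
            break                          # end-of-stream sentinel
        try:
            yield json.loads(data_str)
        except json.JSONDecodeError:
            pass                           # the mixin logs a warning here instead

if __name__ == '__main__':
    sample = [
        b'data: {"content": "Hel"}',
        b'',
        b'data: {"content": "lo"}',
        b'data: [DONE]',
    ]
    print(list(parse_sse_lines(sample)))   # [{'content': 'Hel'}, {'content': 'lo'}]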
@@ -6,7 +6,8 @@ It includes request ID tracking, JSON formatting, and consistent log configurati
  """
 
  import logging
- from typing import Optional
+ import threading
+ from typing import Optional, Dict
 
  from ..json_formatter import JSONFormatter
  from .utils import get_request_id
@@ -14,6 +15,15 @@ from .utils import get_request_id
  # gRPC message length limit (32-bit system compatibility)
  MAX_MESSAGE_LENGTH = 2 ** 31 - 1
 
+ # Logger name prefix reserved for the SDK
+ TAMAR_LOGGER_PREFIX = "tamar_model_client"
+
+ # Thread-safety lock for logger configuration
+ _logger_lock = threading.Lock()
+
+ # Cache of already configured loggers
+ _configured_loggers: Dict[str, logging.Logger] = {}
+
 
  class RequestIdFilter(logging.Filter):
      """
@@ -38,9 +48,54 @@ class RequestIdFilter:
          return True
 
 
+ class TamarLoggerAdapter:
+     """
+     Logger adapter that keeps the SDK's log format from being modified externally.
+
+     The adapter wraps the original logger and intercepts all logging method calls,
+     ensuring the correct formatter and handlers are used.
+     """
+
+     def __init__(self, logger: logging.Logger):
+         self._logger = logger
+         self._lock = threading.Lock()
+
+     def _ensure_format(self):
+         """Make sure the logger still uses the correct format"""
+         with self._lock:
+             # Check and repair the handlers
+             for handler in self._logger.handlers[:]:
+                 if not isinstance(handler.formatter, JSONFormatter):
+                     handler.setFormatter(JSONFormatter())
+
+             # Make sure propagate is set correctly
+             if self._logger.propagate:
+                 self._logger.propagate = False
+
+     def _log(self, level, msg, *args, **kwargs):
+         """Unified logging method"""
+         self._ensure_format()
+         getattr(self._logger, level)(msg, *args, **kwargs)
+
+     def debug(self, msg, *args, **kwargs):
+         self._log('debug', msg, *args, **kwargs)
+
+     def info(self, msg, *args, **kwargs):
+         self._log('info', msg, *args, **kwargs)
+
+     def warning(self, msg, *args, **kwargs):
+         self._log('warning', msg, *args, **kwargs)
+
+     def error(self, msg, *args, **kwargs):
+         self._log('error', msg, *args, **kwargs)
+
+     def critical(self, msg, *args, **kwargs):
+         self._log('critical', msg, *args, **kwargs)
+
+
  def setup_logger(logger_name: str, level: int = logging.INFO) -> logging.Logger:
      """
-     Set up and configure a logger
+     Set up and configure a logger (kept for backward compatibility)
 
      Configures handlers, a formatter and filters for the given logger.
      If the logger already has handlers, it is not configured again.
@@ -57,28 +112,83 @@ def setup_logger(logger_name: str, level: int = logging.INFO) -> logging.Logger:
      - Adds a request ID filter for request tracing
      - Avoids duplicate configuration
      """
-     logger = logging.getLogger(logger_name)
+     # Make sure the logger name starts with the SDK prefix
+     if not logger_name.startswith(TAMAR_LOGGER_PREFIX):
+         logger_name = f"{TAMAR_LOGGER_PREFIX}.{logger_name}"
 
-     # Only configure when there are no handlers yet, to avoid duplicates
-     if not logger.hasHandlers():
-         # Create the console handler
+     with _logger_lock:
+         # Check the cache first
+         if logger_name in _configured_loggers:
+             return _configured_loggers[logger_name]
+
+         logger = logging.getLogger(logger_name)
+
+         # Forcibly remove all existing handlers
+         logger.handlers.clear()
+
+         # Create a dedicated console handler
          console_handler = logging.StreamHandler()
+         console_handler.setFormatter(JSONFormatter())
 
-         # Use the custom JSON formatter for structured log output
-         formatter = JSONFormatter()
-         console_handler.setFormatter(formatter)
+         # Give the handler a unique name so it can be identified
+         console_handler.name = f"tamar_handler_{id(console_handler)}"
 
-         # Attach the handler to the logger
+         # Attach the handler
          logger.addHandler(console_handler)
 
          # Set the log level
          logger.setLevel(level)
 
-         # Add the custom request ID filter for request tracing
+         # Add the request ID filter
          logger.addFilter(RequestIdFilter())
 
-         # Key setting: propagate = False prevents propagation to the parent logger,
-         # so that a test script's log format cannot affect the client's logs
+         # Key settings:
+         # 1. Do not propagate to the parent logger
          logger.propagate = False
+
+         # 2. Disable external modification (Python 3.8+)
+         if hasattr(logger, 'disabled'):
+             logger.disabled = False
+
+         # Cache the configured logger
+         _configured_loggers[logger_name] = logger
+
+         return logger
+
+
+ def get_protected_logger(logger_name: str, level: int = logging.INFO) -> TamarLoggerAdapter:
+     """
+     Get a protected logger
+
+     Returns a logger adapter that ensures the log format cannot be modified externally.
 
-     return logger
+     Args:
+         logger_name: name of the logger
+         level: log level, defaults to INFO
+
+     Returns:
+         TamarLoggerAdapter: the protected logger adapter
+     """
+     logger = setup_logger(logger_name, level)
+     return TamarLoggerAdapter(logger)
+
+
+ def reset_logger_config(logger_name: str) -> None:
+     """
+     Reset the logger configuration
+
+     Intended for tests or scenarios that need reconfiguration.
+
+     Args:
+         logger_name: name of the logger
+     """
+     if not logger_name.startswith(TAMAR_LOGGER_PREFIX):
+         logger_name = f"{TAMAR_LOGGER_PREFIX}.{logger_name}"
+
+     with _logger_lock:
+         if logger_name in _configured_loggers:
+             del _configured_loggers[logger_name]
+
+         logger = logging.getLogger(logger_name)
+         logger.handlers.clear()
+         logger.filters.clear()
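The idea behind `TamarLoggerAdapter` and `get_protected_logger` is to re-assert the JSON formatter and `propagate = False` immediately before every emit, so outside code that reconfigures logging cannot change the SDK's output format. Below is a rough, self-contained sketch of that pattern; `MiniJSONFormatter` and `ProtectedLogger` are simplified stand-ins for the package's `JSONFormatter` and adapter, trimmed to a single `info` method.

# Stand-alone sketch of the "protected logger" pattern (not the SDK's classes).
import json
import logging
import sys

class MiniJSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        return json.dumps({"level": record.levelname, "message": record.getMessage()})

class ProtectedLogger:
    def __init__(self, name: str):
        self._logger = logging.getLogger(name)
        self._logger.handlers.clear()
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(MiniJSONFormatter())
        self._logger.addHandler(handler)
        self._logger.setLevel(logging.INFO)
        self._logger.propagate = False

    def _ensure_format(self) -> None:
        # Repair anything external code may have changed since the last call.
        for handler in self._logger.handlers:
            if not isinstance(handler.formatter, MiniJSONFormatter):
                handler.setFormatter(MiniJSONFormatter())
        self._logger.propagate = False

    def info(self, msg: str) -> None:
        self._ensure_format()
        self._logger.info(msg)

log = ProtectedLogger("tamar_sketch")
log.info("before tampering")    # JSON output
# Simulate external code swapping the formatter:
logging.getLogger("tamar_sketch").handlers[0].setFormatter(logging.Formatter("%(message)s"))
log.info("after tampering")     # still JSON: the formatter was re-asserted on emit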
@@ -43,9 +43,28 @@ class GrpcErrorHandler:
          error_context = ErrorContext(error, context)
 
          # Log detailed error information
+         # Flatten the important error_context fields into the log's data field
+         log_data = {
+             "log_type": "info",
+             "request_id": error_context.request_id,
+             "data": {
+                 "error_code": error_context.error_code.name if error_context.error_code else 'UNKNOWN',
+                 "error_message": error_context.error_message,
+                 "provider": error_context.provider,
+                 "model": error_context.model,
+                 "method": error_context.method,
+                 "retry_count": error_context.retry_count,
+                 "category": error_context._get_error_category(),
+                 "is_retryable": error_context._is_retryable(),
+                 "suggested_action": error_context._get_suggested_action(),
+                 "debug_string": error_context.error_debug_string,
+                 "is_network_cancelled": error_context.is_network_cancelled() if error_context.error_code == grpc.StatusCode.CANCELLED else None
+             }
+         }
+
          self.logger.error(
-             f"gRPC Error occurred: {error_context.error_code}",
-             extra=error_context.to_dict()
+             f"gRPC Error occurred: {error_context.error_code.name if error_context.error_code else 'UNKNOWN'}",
+             extra=log_data
          )
 
          # Update error statistics
@@ -211,9 +230,22 @@ class EnhancedRetryHandler:
                  break
 
              # Log the retry
+             log_data = {
+                 "log_type": "info",
+                 "request_id": error_context.request_id,
+                 "data": {
+                     "error_code": error_context.error_code.name if error_context.error_code else 'UNKNOWN',
+                     "error_message": error_context.error_message,
+                     "retry_count": attempt,
+                     "max_retries": self.max_retries,
+                     "category": error_context._get_error_category(),
+                     "is_retryable": True,  # we are retrying, so it is retryable
+                     "method": error_context.method
+                 }
+             }
              logger.warning(
                  f"Attempt {attempt + 1}/{self.max_retries + 1} failed: {e.code()}",
-                 extra=error_context.to_dict()
+                 extra=log_data
              )
 
              # Perform the backoff wait
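These two hunks replace `error_context.to_dict()` with a hand-built `log_data` dict so the interesting fields end up under a single `data` key of the log record. A small sketch of that flattening step follows; the dataclass is a stand-in for the SDK's `ErrorContext` (field names follow the hunk, everything else is illustrative):

# Illustrative flattening of an error context into the extra= payload used above.
from dataclasses import dataclass
from typing import Any, Dict, Optional

@dataclass
class FakeErrorContext:                 # stand-in for the SDK's ErrorContext
    request_id: str
    error_code_name: Optional[str]
    error_message: str
    provider: str
    model: str
    method: str
    retry_count: int

def build_log_data(ctx: FakeErrorContext) -> Dict[str, Any]:
    return {
        "log_type": "info",
        "request_id": ctx.request_id,
        "data": {
            "error_code": ctx.error_code_name or "UNKNOWN",
            "error_message": ctx.error_message,
            "provider": ctx.provider,
            "model": ctx.model,
            "method": ctx.method,
            "retry_count": ctx.retry_count,
        },
    }

ctx = FakeErrorContext("req-123", "UNAVAILABLE", "connection reset",
                       "some-provider", "some-model", "invoke", 1)
print(build_log_data(ctx)["data"]["error_code"])   # UNAVAILABLE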
@@ -17,7 +17,10 @@ ERROR_CATEGORIES = {
      'NETWORK': [
          grpc.StatusCode.UNAVAILABLE,
          grpc.StatusCode.DEADLINE_EXCEEDED,
-         grpc.StatusCode.ABORTED,
+         grpc.StatusCode.CANCELLED,  # cancellation caused by a network interruption
+     ],
+     'CONCURRENCY': [
+         grpc.StatusCode.ABORTED,  # concurrency conflicts; separate category for easier monitoring
      ],
      'AUTH': [
          grpc.StatusCode.UNAUTHENTICATED,
@@ -71,6 +74,19 @@ RETRY_POLICY = {
          'action': 'refresh_token',  # special action
          'max_attempts': 1
      },
+     grpc.StatusCode.CANCELLED: {
+         'retryable': True,
+         'backoff': 'linear',       # linear backoff; network issues usually do not need exponential backoff
+         'max_attempts': 2,         # limit retries to avoid excessive retrying
+         'check_details': False     # do not inspect details; retry uniformly
+     },
+     grpc.StatusCode.ABORTED: {
+         'retryable': True,
+         'backoff': 'exponential',  # exponential backoff to avoid worsening contention
+         'max_attempts': 3,         # a moderate number of retries
+         'jitter': True,            # add random jitter to reduce contention
+         'check_details': False
+     },
      # Non-retryable errors
      grpc.StatusCode.INVALID_ARGUMENT: {'retryable': False},
      grpc.StatusCode.NOT_FOUND: {'retryable': False},
@@ -160,6 +176,7 @@ class ErrorContext:
          """Get the suggested handling action"""
          suggestions = {
              'NETWORK': 'Check the network connection and retry later',
+             'CONCURRENCY': 'Concurrency conflict; the system will retry automatically',
              'AUTH': 'Check the credentials; the Token may need to be refreshed',
              'VALIDATION': 'Check that the request parameters are correct',
              'RESOURCE': 'Check resource limits or wait a while',
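The new RETRY_POLICY entries only declare a strategy ('linear' vs 'exponential', optional 'jitter', 'max_attempts'); the delay computation itself is not shown in this diff. As a rough sketch of how such a policy table is typically turned into a wait time — the base delay and jitter range below are assumptions, not values from the SDK:

# Hypothetical delay calculation for policy entries shaped like those in RETRY_POLICY.
import random
from typing import Any, Dict

def compute_delay(policy: Dict[str, Any], attempt: int, base: float = 0.5) -> float:
    """attempt is 0-based; returns seconds to wait before the next try."""
    if policy.get('backoff') == 'exponential':
        delay = base * (2 ** attempt)        # 0.5, 1.0, 2.0, ...
    else:                                     # 'linear' or unspecified
        delay = base * (attempt + 1)          # 0.5, 1.0, 1.5, ...
    if policy.get('jitter'):
        delay += random.uniform(0, base)      # spread retries to reduce contention
    return delay

aborted_policy = {'retryable': True, 'backoff': 'exponential', 'max_attempts': 3, 'jitter': True}
print([round(compute_delay(aborted_policy, a), 2) for a in range(3)])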