tamar-model-client 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,194 @@
1
+ """
2
+ Logging configuration for Tamar Model Client
3
+
4
+ This module provides centralized logging setup for both sync and async clients.
5
+ It includes request ID tracking, JSON formatting, and consistent log configuration.
6
+ """
7
+
8
+ import logging
9
+ import threading
10
+ from typing import Optional, Dict
11
+
12
+ from ..json_formatter import JSONFormatter
13
+ from .utils import get_request_id
14
+
15
# Upper bound for gRPC message sizes: the largest signed 32-bit integer,
# chosen so the value stays valid on 32-bit platforms.
MAX_MESSAGE_LENGTH = (1 << 31) - 1

# Every logger created through this module is named under this prefix.
TAMAR_LOGGER_PREFIX = "tamar_model_client"

# Guards creation, caching and resetting of SDK loggers across threads.
_logger_lock = threading.Lock()

# Loggers already configured by setup_logger, keyed by full logger name.
_configured_loggers: Dict[str, logging.Logger] = {}
26
+
27
+
28
class RequestIdFilter(logging.Filter):
    """Logging filter that stamps each record with the current request ID.

    The ID is looked up via ``get_request_id()`` (from ``.utils``; presumably
    backed by a ContextVar — confirm there). The filter never rejects a
    record; it only adds a ``request_id`` attribute for formatters to use.
    """

    def filter(self, record):
        """Set ``record.request_id`` and let the record through.

        Args:
            record: the ``logging.LogRecord`` being processed.

        Returns:
            bool: always ``True``, so the record is never dropped.
        """
        current_request_id = get_request_id()
        record.request_id = current_request_id
        return True
49
+
50
+
51
class TamarLoggerAdapter:
    """Wrapper around an SDK logger that re-asserts its output settings.

    Before every log call it does two things:

    * replaces the formatter of any attached handler that is not a
      ``JSONFormatter`` with a fresh ``JSONFormatter``;
    * turns propagation off.

    This keeps outside changes to the logger's configuration from altering
    how SDK records are formatted or routed. Records are attributed to the
    adapter's caller, not to the adapter itself (see ``_WRAPPER_DEPTH``).
    """

    # Number of adapter frames between the user's call site and the wrapped
    # ``logging.Logger`` method: the public method (``info`` etc.) plus
    # ``_log``. It is added to ``stacklevel`` so that file, line and function
    # name in the record point at the caller rather than at this class.
    _WRAPPER_DEPTH = 2

    def __init__(self, logger: logging.Logger):
        self._logger = logger
        # Serializes the handler/propagate repair done in _ensure_format.
        self._lock = threading.Lock()

    def _ensure_format(self):
        """Force JSON formatting on all handlers and disable propagation."""
        with self._lock:
            # Walk a snapshot of the handler list.
            for handler in self._logger.handlers[:]:
                if not isinstance(handler.formatter, JSONFormatter):
                    handler.setFormatter(JSONFormatter())

            # Keep SDK records out of ancestor (e.g. root) handlers.
            if self._logger.propagate:
                self._logger.propagate = False

    def _log(self, level, msg, *args, **kwargs):
        """Repair the logger configuration, then call ``Logger.<level>``.

        Args:
            level: name of the ``logging.Logger`` method to invoke
                ("debug", "info", "warning", "error" or "critical").
            msg: the log message (``%``-style format string).
            *args: arguments merged into *msg* by the logging framework.
            **kwargs: forwarded to the logger method. A caller-supplied
                ``stacklevel`` (default 1) is shifted past the adapter's own
                frames.

        Must only be called from the public logging methods of this class,
        because ``_WRAPPER_DEPTH`` assumes exactly that call depth.
        """
        self._ensure_format()
        kwargs["stacklevel"] = kwargs.get("stacklevel", 1) + self._WRAPPER_DEPTH
        getattr(self._logger, level)(msg, *args, **kwargs)

    def debug(self, msg, *args, **kwargs):
        """Log *msg* at DEBUG level."""
        self._log('debug', msg, *args, **kwargs)

    def info(self, msg, *args, **kwargs):
        """Log *msg* at INFO level."""
        self._log('info', msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        """Log *msg* at WARNING level."""
        self._log('warning', msg, *args, **kwargs)

    def error(self, msg, *args, **kwargs):
        """Log *msg* at ERROR level."""
        self._log('error', msg, *args, **kwargs)

    def critical(self, msg, *args, **kwargs):
        """Log *msg* at CRITICAL level."""
        self._log('critical', msg, *args, **kwargs)

    def exception(self, msg, *args, exc_info=True, **kwargs):
        """Log *msg* at ERROR level with exception info attached.

        Mirrors ``logging.Logger.exception``; intended to be called from an
        exception handler. It routes through ``error`` (with ``exc_info``)
        instead of ``Logger.exception`` so the frame depth matches the other
        methods.
        """
        self._log('error', msg, *args, exc_info=exc_info, **kwargs)
+
96
def setup_logger(logger_name: str, level: int = logging.INFO) -> logging.Logger:
    """Return a configured SDK logger, creating it on first use (backward-compatible API).

    The name is prefixed with ``TAMAR_LOGGER_PREFIX`` plus "." unless it
    already starts with that prefix.

    On the first call for a given name, this function:

    * removes every handler already attached to the logger;
    * attaches one ``StreamHandler`` that formats with ``JSONFormatter``;
    * sets *level* and adds a ``RequestIdFilter``;
    * disables propagation and clears the ``disabled`` flag;
    * caches the logger.

    Later calls for the same name return the cached logger untouched; in
    particular, *level* is NOT re-applied.

    Args:
        logger_name: logger name, with or without the SDK prefix.
        level: logging level applied on first configuration (default INFO).

    Returns:
        logging.Logger: the configured logger.
    """
    full_name = (
        logger_name
        if logger_name.startswith(TAMAR_LOGGER_PREFIX)
        else f"{TAMAR_LOGGER_PREFIX}.{logger_name}"
    )

    with _logger_lock:
        cached_logger = _configured_loggers.get(full_name)
        if cached_logger is not None:
            return cached_logger

        sdk_logger = logging.getLogger(full_name)

        # Start from a clean slate: drop whatever handlers were attached before.
        sdk_logger.handlers.clear()

        json_handler = logging.StreamHandler()
        json_handler.setFormatter(JSONFormatter())
        # Tag the handler so it can be recognised as SDK-owned.
        json_handler.name = f"tamar_handler_{id(json_handler)}"
        sdk_logger.addHandler(json_handler)

        sdk_logger.setLevel(level)
        sdk_logger.addFilter(RequestIdFilter())

        # Keep SDK records out of ancestor loggers' handlers.
        sdk_logger.propagate = False

        # Clear the disabled flag in case other code switched this logger off.
        if hasattr(sdk_logger, 'disabled'):
            sdk_logger.disabled = False

        _configured_loggers[full_name] = sdk_logger
        return sdk_logger
157
+
158
+
159
def get_protected_logger(logger_name: str, level: int = logging.INFO) -> TamarLoggerAdapter:
    """Return the SDK logger for *logger_name* wrapped in a ``TamarLoggerAdapter``.

    The underlying logger comes from ``setup_logger`` (created and cached on
    first use). The adapter re-applies the JSON formatter and disabled
    propagation before each log call.

    Args:
        logger_name: logger name, with or without the SDK prefix.
        level: logging level used when the logger is first configured.

    Returns:
        TamarLoggerAdapter: adapter around the configured logger.
    """
    return TamarLoggerAdapter(setup_logger(logger_name, level))
174
+
175
+
176
def reset_logger_config(logger_name: str) -> None:
    """Forget the cached configuration for an SDK logger.

    Intended for tests or for reconfiguration. It does three things:

    * drops the logger from the ``setup_logger`` cache;
    * removes all of its handlers;
    * removes all of its filters.

    The next ``setup_logger`` call for the same name therefore configures the
    logger again from scratch. Level and ``propagate`` are left as they are.

    Args:
        logger_name: logger name, with or without the SDK prefix.
    """
    full_name = (
        logger_name
        if logger_name.startswith(TAMAR_LOGGER_PREFIX)
        else f"{TAMAR_LOGGER_PREFIX}.{logger_name}"
    )

    with _logger_lock:
        _configured_loggers.pop(full_name, None)

        target_logger = logging.getLogger(full_name)
        target_logger.handlers.clear()
        target_logger.filters.clear()
@@ -0,0 +1,221 @@
1
+ """
2
+ Request building logic for Tamar Model Client
3
+
4
+ This module handles the construction of gRPC request objects from
5
+ model request objects, including provider-specific field validation.
6
+ """
7
+
8
+ import json
9
+ from typing import Dict, Any, Set
10
+
11
+ from ..enums import ProviderType, InvokeType
12
+ from ..generated import model_service_pb2
13
+ from ..schemas.inputs import (
14
+ ModelRequest,
15
+ BatchModelRequest,
16
+ BatchModelRequestItem,
17
+ UserContext,
18
+ GoogleGenAiInput,
19
+ GoogleVertexAIImagesInput,
20
+ OpenAIResponsesInput,
21
+ OpenAIChatCompletionsInput,
22
+ OpenAIImagesInput,
23
+ OpenAIImagesEditInput
24
+ )
25
+ from .utils import is_effective_value, serialize_value, remove_none_from_dict
26
+
27
+
28
+ class RequestBuilder:
29
+ """
30
+ 请求构建器
31
+
32
+ 负责将高级的 ModelRequest 对象转换为 gRPC 协议所需的请求对象,
33
+ 包括参数验证、序列化和提供商特定的字段处理。
34
+ """
35
+
36
+ @staticmethod
37
+ def get_allowed_fields(provider: ProviderType, invoke_type: InvokeType) -> Set[str]:
38
+ """
39
+ 获取特定提供商和调用类型组合所允许的字段
40
+
41
+ Args:
42
+ provider: 提供商类型
43
+ invoke_type: 调用类型
44
+
45
+ Returns:
46
+ Set[str]: 允许的字段名集合
47
+
48
+ Raises:
49
+ ValueError: 当提供商和调用类型组合不受支持时
50
+ """
51
+ match (provider, invoke_type):
52
+ case (ProviderType.GOOGLE, InvokeType.GENERATION):
53
+ return set(GoogleGenAiInput.model_fields.keys())
54
+ case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
55
+ return set(GoogleVertexAIImagesInput.model_fields.keys())
56
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
57
+ return set(OpenAIResponsesInput.model_fields.keys())
58
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
59
+ return set(OpenAIChatCompletionsInput.model_fields.keys())
60
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
61
+ return set(OpenAIImagesInput.model_fields.keys())
62
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
63
+ return set(OpenAIImagesEditInput.model_fields.keys())
64
+ case _:
65
+ raise ValueError(
66
+ f"Unsupported provider/invoke_type combination: {provider} + {invoke_type}"
67
+ )
68
+
69
+ @staticmethod
70
+ def build_grpc_extra_fields(model_request: ModelRequest) -> Dict[str, Any]:
71
+ """
72
+ 构建 gRPC 请求的额外字段
73
+
74
+ 根据提供商和调用类型,过滤并序列化请求中的参数。
75
+
76
+ Args:
77
+ model_request: 模型请求对象
78
+
79
+ Returns:
80
+ Dict[str, Any]: 序列化后的额外字段字典
81
+
82
+ Raises:
83
+ ValueError: 当构建请求失败时
84
+ """
85
+ try:
86
+ # 获取允许的字段集合
87
+ allowed_fields = RequestBuilder.get_allowed_fields(
88
+ model_request.provider,
89
+ model_request.invoke_type
90
+ )
91
+
92
+ # 将 ModelRequest 转换为字典,只包含已设置的字段
93
+ model_request_dict = model_request.model_dump(exclude_unset=True)
94
+
95
+ # 构建 gRPC 请求参数
96
+ grpc_request_kwargs = {}
97
+ for field in allowed_fields:
98
+ if field in model_request_dict:
99
+ value = model_request_dict[field]
100
+
101
+ # 跳过无效的值
102
+ if not is_effective_value(value):
103
+ continue
104
+
105
+ # 序列化不支持的类型
106
+ grpc_request_kwargs[field] = serialize_value(value)
107
+
108
+ # 清理序列化后的参数中的 None 值
109
+ grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
110
+
111
+ return grpc_request_kwargs
112
+
113
+ except Exception as e:
114
+ raise ValueError(f"构建请求失败: {str(e)}") from e
115
+
116
+ @staticmethod
117
+ def build_single_request(model_request: ModelRequest) -> model_service_pb2.ModelRequestItem:
118
+ """
119
+ 构建单个模型请求的 gRPC 对象
120
+
121
+ Args:
122
+ model_request: 模型请求对象
123
+
124
+ Returns:
125
+ model_service_pb2.ModelRequestItem: gRPC 请求对象
126
+
127
+ Raises:
128
+ ValueError: 当构建请求失败时
129
+ """
130
+ # 构建额外字段
131
+ extra_fields = RequestBuilder.build_grpc_extra_fields(model_request)
132
+
133
+ # 创建 gRPC 请求对象
134
+ return model_service_pb2.ModelRequestItem(
135
+ provider=model_request.provider.value,
136
+ channel=model_request.channel.value if model_request.channel else "",
137
+ invoke_type=model_request.invoke_type.value,
138
+ stream=model_request.stream or False,
139
+ org_id=model_request.user_context.org_id or "",
140
+ user_id=model_request.user_context.user_id or "",
141
+ client_type=model_request.user_context.client_type or "",
142
+ extra=extra_fields
143
+ )
144
+
145
+ @staticmethod
146
+ def build_batch_request_item(
147
+ batch_item: "BatchModelRequestItem",
148
+ user_context: "UserContext"
149
+ ) -> model_service_pb2.ModelRequestItem:
150
+ """
151
+ 构建批量请求中的单个项目
152
+
153
+ Args:
154
+ batch_item: 批量请求项
155
+ user_context: 用户上下文(来自父BatchModelRequest)
156
+
157
+ Returns:
158
+ model_service_pb2.ModelRequestItem: gRPC 请求对象
159
+ """
160
+ # 构建额外字段
161
+ extra_fields = RequestBuilder.build_grpc_extra_fields(batch_item)
162
+
163
+ # 添加 custom_id 如果存在
164
+ if hasattr(batch_item, 'custom_id') and batch_item.custom_id:
165
+ request_item = model_service_pb2.ModelRequestItem(
166
+ provider=batch_item.provider.value,
167
+ channel=batch_item.channel.value if batch_item.channel else "",
168
+ invoke_type=batch_item.invoke_type.value,
169
+ stream=batch_item.stream or False,
170
+ org_id=user_context.org_id or "",
171
+ user_id=user_context.user_id or "",
172
+ client_type=user_context.client_type or "",
173
+ custom_id=batch_item.custom_id,
174
+ extra=extra_fields
175
+ )
176
+ else:
177
+ request_item = model_service_pb2.ModelRequestItem(
178
+ provider=batch_item.provider.value,
179
+ channel=batch_item.channel.value if batch_item.channel else "",
180
+ invoke_type=batch_item.invoke_type.value,
181
+ stream=batch_item.stream or False,
182
+ org_id=user_context.org_id or "",
183
+ user_id=user_context.user_id or "",
184
+ client_type=user_context.client_type or "",
185
+ extra=extra_fields
186
+ )
187
+
188
+ # 添加 priority 如果存在
189
+ if hasattr(batch_item, 'priority') and batch_item.priority is not None:
190
+ request_item.priority = batch_item.priority
191
+
192
+ return request_item
193
+
194
+ @staticmethod
195
+ def build_batch_request(batch_request: BatchModelRequest) -> model_service_pb2.ModelRequest:
196
+ """
197
+ 构建批量请求的 gRPC 对象
198
+
199
+ Args:
200
+ batch_request: 批量请求对象
201
+
202
+ Returns:
203
+ model_service_pb2.ModelRequest: gRPC 批量请求对象
204
+
205
+ Raises:
206
+ ValueError: 当构建请求失败时
207
+ """
208
+ items = []
209
+
210
+ for batch_item in batch_request.items:
211
+ # 为每个请求项构建 gRPC 对象,传入 user_context
212
+ request_item = RequestBuilder.build_batch_request_item(
213
+ batch_item,
214
+ batch_request.user_context
215
+ )
216
+ items.append(request_item)
217
+
218
+ # 创建批量请求对象
219
+ return model_service_pb2.ModelRequest(
220
+ items=items
221
+ )
@@ -0,0 +1,136 @@
1
+ """
2
+ Response handling logic for Tamar Model Client
3
+
4
+ This module provides utilities for processing gRPC responses and
5
+ converting them to client response objects.
6
+ """
7
+
8
+ import json
9
+ from typing import Optional, Dict, Any
10
+
11
+ from ..schemas import ModelResponse, BatchModelResponse
12
+
13
+
14
class ResponseHandler:
    """Convert gRPC responses into the client's response models.

    Also builds the structured dictionaries used for request/response logging.
    """

    @staticmethod
    def build_model_response(grpc_response) -> "ModelResponse":
        """Build a ``ModelResponse`` from a single gRPC response message.

        ``usage`` and ``raw_response`` are expected to be JSON strings and are
        decoded with ``_parse_json_field``. Empty ``error`` / ``request_id``
        values become ``None``.

        Args:
            grpc_response: response message returned by the gRPC service.

        Returns:
            ModelResponse: the client-side response object.
        """
        return ModelResponse(
            content=grpc_response.content,
            usage=ResponseHandler._parse_json_field(grpc_response.usage),
            error=grpc_response.error or None,
            raw_response=ResponseHandler._parse_json_field(grpc_response.raw_response),
            request_id=grpc_response.request_id or None,
        )

    @staticmethod
    def build_batch_response(grpc_response) -> "BatchModelResponse":
        """Build a ``BatchModelResponse`` from a gRPC batch response.

        Each item is converted with ``build_model_response``, preserving order.

        Args:
            grpc_response: batch response message returned by the gRPC service.

        Returns:
            BatchModelResponse: the client-side batch response object.
        """
        responses = [
            ResponseHandler.build_model_response(response_item)
            for response_item in grpc_response.items
        ]
        return BatchModelResponse(
            responses=responses,
            request_id=grpc_response.request_id or None,
        )

    @staticmethod
    def _parse_json_field(json_str: Optional[str]) -> Optional[Dict[str, Any]]:
        """Decode a JSON string without raising on malformed input.

        Args:
            json_str: JSON text to decode; may be empty or ``None``.

        Returns:
            ``None`` for empty input. Otherwise the decoded value (normally a
            dict; ``json.loads`` returns whatever top-level value the text
            holds). Malformed JSON yields
            ``{"error": "JSON parse error", "raw": json_str}``.
        """
        if not json_str:
            return None

        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            # Keep the undecodable text so it still shows up in the response.
            return {"error": "JSON parse error", "raw": json_str}

    @staticmethod
    def build_log_data(
        model_request,
        response: Optional["ModelResponse"] = None,
        duration: Optional[float] = None,
        error: Optional[Exception] = None,
        stream_stats: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Build the structured data dictionary for request/response logs.

        Args:
            model_request: the original request object.
            response: the response, if one was received.
            duration: request duration in seconds. Stored under
                ``"duration"`` when given.
            error: the exception raised, if any.
            stream_stats: streaming statistics merged into the result
                (these may override earlier keys).

        Returns:
            Dict[str, Any]: the log data.
        """
        data = {
            "provider": model_request.provider.value,
            "invoke_type": model_request.invoke_type.value,
            "model": getattr(model_request, 'model', None),
            "stream": getattr(model_request, 'stream', False),
        }

        # The request may have no user_context, or have it set to None. Skip
        # the user fields in that case instead of raising AttributeError
        # from logging code.
        user_context = getattr(model_request, 'user_context', None)
        if user_context is not None:
            data.update({
                "org_id": user_context.org_id,
                "user_id": user_context.user_id,
                "client_type": user_context.client_type
            })

        if response:
            if hasattr(response, 'content') and response.content:
                data["content_length"] = len(response.content)
            if hasattr(response, 'usage'):
                data["usage"] = response.usage

        # Previously accepted but silently dropped.
        if duration is not None:
            data["duration"] = duration

        if stream_stats:
            data.update(stream_stats)

        if error:
            data["error_type"] = type(error).__name__
            data["error_message"] = str(error)

        return data