tamar-model-client 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/__init__.py +2 -0
- tamar_model_client/async_client.py +430 -539
- tamar_model_client/core/__init__.py +34 -0
- tamar_model_client/core/base_client.py +168 -0
- tamar_model_client/core/logging_setup.py +84 -0
- tamar_model_client/core/request_builder.py +221 -0
- tamar_model_client/core/response_handler.py +136 -0
- tamar_model_client/core/utils.py +171 -0
- tamar_model_client/error_handler.py +283 -0
- tamar_model_client/exceptions.py +371 -7
- tamar_model_client/json_formatter.py +36 -1
- tamar_model_client/logging_icons.py +60 -0
- tamar_model_client/sync_client.py +473 -485
- {tamar_model_client-0.1.18.dist-info → tamar_model_client-0.1.20.dist-info}/METADATA +217 -61
- tamar_model_client-0.1.20.dist-info/RECORD +33 -0
- {tamar_model_client-0.1.18.dist-info → tamar_model_client-0.1.20.dist-info}/top_level.txt +1 -0
- tests/__init__.py +1 -0
- tests/stream_hanging_analysis.py +357 -0
- tests/test_google_azure_final.py +448 -0
- tests/test_simple.py +235 -0
- tamar_model_client-0.1.18.dist-info/RECORD +0 -21
- {tamar_model_client-0.1.18.dist-info → tamar_model_client-0.1.20.dist-info}/WHEEL +0 -0
@@ -0,0 +1,34 @@
|
|
1
|
+
"""
|
2
|
+
Core components for Tamar Model Client
|
3
|
+
|
4
|
+
This package contains shared components used by both sync and async clients.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from .utils import (
|
8
|
+
is_effective_value,
|
9
|
+
serialize_value,
|
10
|
+
remove_none_from_dict,
|
11
|
+
generate_request_id,
|
12
|
+
set_request_id,
|
13
|
+
get_request_id
|
14
|
+
)
|
15
|
+
|
16
|
+
from .logging_setup import (
|
17
|
+
setup_logger,
|
18
|
+
RequestIdFilter,
|
19
|
+
MAX_MESSAGE_LENGTH
|
20
|
+
)
|
21
|
+
|
22
|
+
__all__ = [
|
23
|
+
# Utils
|
24
|
+
'is_effective_value',
|
25
|
+
'serialize_value',
|
26
|
+
'remove_none_from_dict',
|
27
|
+
'generate_request_id',
|
28
|
+
'set_request_id',
|
29
|
+
'get_request_id',
|
30
|
+
# Logging
|
31
|
+
'setup_logger',
|
32
|
+
'RequestIdFilter',
|
33
|
+
'MAX_MESSAGE_LENGTH',
|
34
|
+
]
|
@@ -0,0 +1,168 @@
|
|
1
|
+
"""
|
2
|
+
Base client class for Tamar Model Client
|
3
|
+
|
4
|
+
This module provides the base client class with shared initialization logic
|
5
|
+
and configuration management for both sync and async clients.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import os
|
9
|
+
import logging
|
10
|
+
from typing import Optional, Dict, Any
|
11
|
+
from abc import ABC, abstractmethod
|
12
|
+
|
13
|
+
from ..auth import JWTAuthHandler
|
14
|
+
from ..error_handler import GrpcErrorHandler, ErrorRecoveryStrategy
|
15
|
+
from .logging_setup import MAX_MESSAGE_LENGTH, setup_logger
|
16
|
+
|
17
|
+
|
18
|
+
class BaseClient(ABC):
|
19
|
+
"""
|
20
|
+
基础客户端抽象类
|
21
|
+
|
22
|
+
提供同步和异步客户端的共享功能:
|
23
|
+
- 配置管理
|
24
|
+
- 认证设置
|
25
|
+
- 连接选项构建
|
26
|
+
- 错误处理器初始化
|
27
|
+
"""
|
28
|
+
|
29
|
+
def __init__(
|
30
|
+
self,
|
31
|
+
server_address: Optional[str] = None,
|
32
|
+
jwt_secret_key: Optional[str] = None,
|
33
|
+
jwt_token: Optional[str] = None,
|
34
|
+
default_payload: Optional[dict] = None,
|
35
|
+
token_expires_in: int = 3600,
|
36
|
+
max_retries: Optional[int] = None,
|
37
|
+
retry_delay: Optional[float] = None,
|
38
|
+
logger_name: str = None,
|
39
|
+
):
|
40
|
+
"""
|
41
|
+
初始化基础客户端
|
42
|
+
|
43
|
+
Args:
|
44
|
+
server_address: gRPC 服务器地址,格式为 "host:port"
|
45
|
+
jwt_secret_key: JWT 签名密钥,用于生成认证令牌
|
46
|
+
jwt_token: 预生成的 JWT 令牌(可选)
|
47
|
+
default_payload: JWT 令牌的默认载荷
|
48
|
+
token_expires_in: JWT 令牌过期时间(秒)
|
49
|
+
max_retries: 最大重试次数(默认从环境变量读取)
|
50
|
+
retry_delay: 初始重试延迟(秒,默认从环境变量读取)
|
51
|
+
logger_name: 日志记录器名称
|
52
|
+
|
53
|
+
Raises:
|
54
|
+
ValueError: 当服务器地址未提供时
|
55
|
+
"""
|
56
|
+
# === 服务端地址配置 ===
|
57
|
+
self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
|
58
|
+
if not self.server_address:
|
59
|
+
raise ValueError("Server address must be provided via argument or environment variable.")
|
60
|
+
|
61
|
+
# 默认调用超时时间
|
62
|
+
self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))
|
63
|
+
|
64
|
+
# === JWT 认证配置 ===
|
65
|
+
self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
|
66
|
+
self.jwt_handler = JWTAuthHandler(self.jwt_secret_key) if self.jwt_secret_key else None
|
67
|
+
self.jwt_token = jwt_token # 用户传入的预生成 Token(可选)
|
68
|
+
self.default_payload = default_payload
|
69
|
+
self.token_expires_in = token_expires_in
|
70
|
+
|
71
|
+
# === TLS/Authority 配置 ===
|
72
|
+
self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
|
73
|
+
self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")
|
74
|
+
|
75
|
+
# === 重试配置 ===
|
76
|
+
self.max_retries = max_retries if max_retries is not None else int(
|
77
|
+
os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
|
78
|
+
self.retry_delay = retry_delay if retry_delay is not None else float(
|
79
|
+
os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))
|
80
|
+
|
81
|
+
# === 日志配置 ===
|
82
|
+
self.logger = setup_logger(logger_name or __name__)
|
83
|
+
|
84
|
+
# === 错误处理器 ===
|
85
|
+
self.error_handler = GrpcErrorHandler(self.logger)
|
86
|
+
self.recovery_strategy = ErrorRecoveryStrategy(self)
|
87
|
+
|
88
|
+
# === 连接状态 ===
|
89
|
+
self._closed = False
|
90
|
+
|
91
|
+
def build_channel_options(self) -> list:
|
92
|
+
"""
|
93
|
+
构建 gRPC 通道选项
|
94
|
+
|
95
|
+
Returns:
|
96
|
+
list: gRPC 通道配置选项列表
|
97
|
+
|
98
|
+
包含的配置:
|
99
|
+
- 消息大小限制
|
100
|
+
- Keepalive 设置(30秒ping间隔,10秒超时)
|
101
|
+
- 连接生命周期管理(1小时最大连接时间)
|
102
|
+
- 性能优化选项(带宽探测、内置重试)
|
103
|
+
"""
|
104
|
+
options = [
|
105
|
+
# 消息大小限制
|
106
|
+
('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
|
107
|
+
('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
|
108
|
+
|
109
|
+
# Keepalive 核心配置
|
110
|
+
('grpc.keepalive_time_ms', 30000), # 30秒发送一次 keepalive ping
|
111
|
+
('grpc.keepalive_timeout_ms', 10000), # ping 响应超时时间 10秒
|
112
|
+
('grpc.keepalive_permit_without_calls', True), # 空闲时也发送 keepalive
|
113
|
+
('grpc.http2.max_pings_without_data', 2), # 无数据时最大 ping 次数
|
114
|
+
|
115
|
+
# 连接管理增强配置
|
116
|
+
('grpc.http2.min_time_between_pings_ms', 10000), # ping 最小间隔 10秒
|
117
|
+
('grpc.http2.max_connection_idle_ms', 300000), # 最大空闲时间 5分钟
|
118
|
+
('grpc.http2.max_connection_age_ms', 3600000), # 连接最大生存时间 1小时
|
119
|
+
('grpc.http2.max_connection_age_grace_ms', 5000), # 优雅关闭时间 5秒
|
120
|
+
|
121
|
+
# 性能相关配置
|
122
|
+
('grpc.http2.bdp_probe', 1), # 启用带宽延迟探测
|
123
|
+
('grpc.enable_retries', 1), # 启用内置重试
|
124
|
+
]
|
125
|
+
|
126
|
+
if self.default_authority:
|
127
|
+
options.append(("grpc.default_authority", self.default_authority))
|
128
|
+
|
129
|
+
return options
|
130
|
+
|
131
|
+
def _build_auth_metadata(self, request_id: str) -> list:
|
132
|
+
"""
|
133
|
+
构建认证元数据
|
134
|
+
|
135
|
+
为每个请求构建包含认证信息和请求ID的gRPC元数据。
|
136
|
+
JWT令牌会在每次请求时重新生成以确保有效性。
|
137
|
+
|
138
|
+
Args:
|
139
|
+
request_id: 当前请求的唯一标识符
|
140
|
+
|
141
|
+
Returns:
|
142
|
+
list: gRPC元数据列表,包含请求ID和认证令牌
|
143
|
+
"""
|
144
|
+
metadata = [("x-request-id", request_id)] # 将 request_id 添加到 headers
|
145
|
+
|
146
|
+
if self.jwt_handler:
|
147
|
+
self.jwt_token = self.jwt_handler.encode_token(
|
148
|
+
self.default_payload,
|
149
|
+
expires_in=self.token_expires_in
|
150
|
+
)
|
151
|
+
metadata.append(("authorization", f"Bearer {self.jwt_token}"))
|
152
|
+
|
153
|
+
return metadata
|
154
|
+
|
155
|
+
@abstractmethod
|
156
|
+
def close(self):
|
157
|
+
"""关闭客户端连接(由子类实现)"""
|
158
|
+
pass
|
159
|
+
|
160
|
+
@abstractmethod
|
161
|
+
def __enter__(self):
|
162
|
+
"""进入上下文管理器(由子类实现)"""
|
163
|
+
pass
|
164
|
+
|
165
|
+
@abstractmethod
|
166
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
167
|
+
"""退出上下文管理器(由子类实现)"""
|
168
|
+
pass
|
@@ -0,0 +1,84 @@
|
|
1
|
+
"""
|
2
|
+
Logging configuration for Tamar Model Client
|
3
|
+
|
4
|
+
This module provides centralized logging setup for both sync and async clients.
|
5
|
+
It includes request ID tracking, JSON formatting, and consistent log configuration.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import logging
|
9
|
+
from typing import Optional
|
10
|
+
|
11
|
+
from ..json_formatter import JSONFormatter
|
12
|
+
from .utils import get_request_id
|
13
|
+
|
14
|
+
# gRPC 消息长度限制(32位系统兼容)
|
15
|
+
MAX_MESSAGE_LENGTH = 2 ** 31 - 1
|
16
|
+
|
17
|
+
|
18
|
+
class RequestIdFilter(logging.Filter):
|
19
|
+
"""
|
20
|
+
自定义日志过滤器,向日志记录中添加 request_id
|
21
|
+
|
22
|
+
这个过滤器从 ContextVar 中获取当前请求的 ID,
|
23
|
+
并将其添加到日志记录中,便于追踪和调试。
|
24
|
+
"""
|
25
|
+
|
26
|
+
def filter(self, record):
|
27
|
+
"""
|
28
|
+
过滤日志记录,添加 request_id 字段
|
29
|
+
|
30
|
+
Args:
|
31
|
+
record: 日志记录对象
|
32
|
+
|
33
|
+
Returns:
|
34
|
+
bool: 总是返回 True,表示记录应被处理
|
35
|
+
"""
|
36
|
+
# 从 ContextVar 中获取当前的 request_id
|
37
|
+
record.request_id = get_request_id()
|
38
|
+
return True
|
39
|
+
|
40
|
+
|
41
|
+
def setup_logger(logger_name: str, level: int = logging.INFO) -> logging.Logger:
|
42
|
+
"""
|
43
|
+
设置并配置logger
|
44
|
+
|
45
|
+
为指定的logger配置处理器、格式化器和过滤器。
|
46
|
+
如果logger已经有处理器,则不会重复配置。
|
47
|
+
|
48
|
+
Args:
|
49
|
+
logger_name: logger的名称
|
50
|
+
level: 日志级别,默认为 INFO
|
51
|
+
|
52
|
+
Returns:
|
53
|
+
logging.Logger: 配置好的logger实例
|
54
|
+
|
55
|
+
特性:
|
56
|
+
- 使用 JSON 格式化器提供结构化日志输出
|
57
|
+
- 添加请求ID过滤器用于请求追踪
|
58
|
+
- 避免重复配置
|
59
|
+
"""
|
60
|
+
logger = logging.getLogger(logger_name)
|
61
|
+
|
62
|
+
# 仅在没有处理器时配置,避免重复配置
|
63
|
+
if not logger.hasHandlers():
|
64
|
+
# 创建控制台日志处理器
|
65
|
+
console_handler = logging.StreamHandler()
|
66
|
+
|
67
|
+
# 使用自定义的 JSON 格式化器,提供结构化日志输出
|
68
|
+
formatter = JSONFormatter()
|
69
|
+
console_handler.setFormatter(formatter)
|
70
|
+
|
71
|
+
# 为logger添加处理器
|
72
|
+
logger.addHandler(console_handler)
|
73
|
+
|
74
|
+
# 设置日志级别
|
75
|
+
logger.setLevel(level)
|
76
|
+
|
77
|
+
# 添加自定义的请求ID过滤器,用于请求追踪
|
78
|
+
logger.addFilter(RequestIdFilter())
|
79
|
+
|
80
|
+
# 关键:设置 propagate = False,防止日志传播到父logger
|
81
|
+
# 这样可以避免测试脚本的日志格式影响客户端日志
|
82
|
+
logger.propagate = False
|
83
|
+
|
84
|
+
return logger
|
@@ -0,0 +1,221 @@
|
|
1
|
+
"""
|
2
|
+
Request building logic for Tamar Model Client
|
3
|
+
|
4
|
+
This module handles the construction of gRPC request objects from
|
5
|
+
model request objects, including provider-specific field validation.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import json
|
9
|
+
from typing import Dict, Any, Set
|
10
|
+
|
11
|
+
from ..enums import ProviderType, InvokeType
|
12
|
+
from ..generated import model_service_pb2
|
13
|
+
from ..schemas.inputs import (
|
14
|
+
ModelRequest,
|
15
|
+
BatchModelRequest,
|
16
|
+
BatchModelRequestItem,
|
17
|
+
UserContext,
|
18
|
+
GoogleGenAiInput,
|
19
|
+
GoogleVertexAIImagesInput,
|
20
|
+
OpenAIResponsesInput,
|
21
|
+
OpenAIChatCompletionsInput,
|
22
|
+
OpenAIImagesInput,
|
23
|
+
OpenAIImagesEditInput
|
24
|
+
)
|
25
|
+
from .utils import is_effective_value, serialize_value, remove_none_from_dict
|
26
|
+
|
27
|
+
|
28
|
+
class RequestBuilder:
|
29
|
+
"""
|
30
|
+
请求构建器
|
31
|
+
|
32
|
+
负责将高级的 ModelRequest 对象转换为 gRPC 协议所需的请求对象,
|
33
|
+
包括参数验证、序列化和提供商特定的字段处理。
|
34
|
+
"""
|
35
|
+
|
36
|
+
@staticmethod
|
37
|
+
def get_allowed_fields(provider: ProviderType, invoke_type: InvokeType) -> Set[str]:
|
38
|
+
"""
|
39
|
+
获取特定提供商和调用类型组合所允许的字段
|
40
|
+
|
41
|
+
Args:
|
42
|
+
provider: 提供商类型
|
43
|
+
invoke_type: 调用类型
|
44
|
+
|
45
|
+
Returns:
|
46
|
+
Set[str]: 允许的字段名集合
|
47
|
+
|
48
|
+
Raises:
|
49
|
+
ValueError: 当提供商和调用类型组合不受支持时
|
50
|
+
"""
|
51
|
+
match (provider, invoke_type):
|
52
|
+
case (ProviderType.GOOGLE, InvokeType.GENERATION):
|
53
|
+
return set(GoogleGenAiInput.model_fields.keys())
|
54
|
+
case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
|
55
|
+
return set(GoogleVertexAIImagesInput.model_fields.keys())
|
56
|
+
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
|
57
|
+
return set(OpenAIResponsesInput.model_fields.keys())
|
58
|
+
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
|
59
|
+
return set(OpenAIChatCompletionsInput.model_fields.keys())
|
60
|
+
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
|
61
|
+
return set(OpenAIImagesInput.model_fields.keys())
|
62
|
+
case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
|
63
|
+
return set(OpenAIImagesEditInput.model_fields.keys())
|
64
|
+
case _:
|
65
|
+
raise ValueError(
|
66
|
+
f"Unsupported provider/invoke_type combination: {provider} + {invoke_type}"
|
67
|
+
)
|
68
|
+
|
69
|
+
@staticmethod
|
70
|
+
def build_grpc_extra_fields(model_request: ModelRequest) -> Dict[str, Any]:
|
71
|
+
"""
|
72
|
+
构建 gRPC 请求的额外字段
|
73
|
+
|
74
|
+
根据提供商和调用类型,过滤并序列化请求中的参数。
|
75
|
+
|
76
|
+
Args:
|
77
|
+
model_request: 模型请求对象
|
78
|
+
|
79
|
+
Returns:
|
80
|
+
Dict[str, Any]: 序列化后的额外字段字典
|
81
|
+
|
82
|
+
Raises:
|
83
|
+
ValueError: 当构建请求失败时
|
84
|
+
"""
|
85
|
+
try:
|
86
|
+
# 获取允许的字段集合
|
87
|
+
allowed_fields = RequestBuilder.get_allowed_fields(
|
88
|
+
model_request.provider,
|
89
|
+
model_request.invoke_type
|
90
|
+
)
|
91
|
+
|
92
|
+
# 将 ModelRequest 转换为字典,只包含已设置的字段
|
93
|
+
model_request_dict = model_request.model_dump(exclude_unset=True)
|
94
|
+
|
95
|
+
# 构建 gRPC 请求参数
|
96
|
+
grpc_request_kwargs = {}
|
97
|
+
for field in allowed_fields:
|
98
|
+
if field in model_request_dict:
|
99
|
+
value = model_request_dict[field]
|
100
|
+
|
101
|
+
# 跳过无效的值
|
102
|
+
if not is_effective_value(value):
|
103
|
+
continue
|
104
|
+
|
105
|
+
# 序列化不支持的类型
|
106
|
+
grpc_request_kwargs[field] = serialize_value(value)
|
107
|
+
|
108
|
+
# 清理序列化后的参数中的 None 值
|
109
|
+
grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
|
110
|
+
|
111
|
+
return grpc_request_kwargs
|
112
|
+
|
113
|
+
except Exception as e:
|
114
|
+
raise ValueError(f"构建请求失败: {str(e)}") from e
|
115
|
+
|
116
|
+
@staticmethod
|
117
|
+
def build_single_request(model_request: ModelRequest) -> model_service_pb2.ModelRequestItem:
|
118
|
+
"""
|
119
|
+
构建单个模型请求的 gRPC 对象
|
120
|
+
|
121
|
+
Args:
|
122
|
+
model_request: 模型请求对象
|
123
|
+
|
124
|
+
Returns:
|
125
|
+
model_service_pb2.ModelRequestItem: gRPC 请求对象
|
126
|
+
|
127
|
+
Raises:
|
128
|
+
ValueError: 当构建请求失败时
|
129
|
+
"""
|
130
|
+
# 构建额外字段
|
131
|
+
extra_fields = RequestBuilder.build_grpc_extra_fields(model_request)
|
132
|
+
|
133
|
+
# 创建 gRPC 请求对象
|
134
|
+
return model_service_pb2.ModelRequestItem(
|
135
|
+
provider=model_request.provider.value,
|
136
|
+
channel=model_request.channel.value if model_request.channel else "",
|
137
|
+
invoke_type=model_request.invoke_type.value,
|
138
|
+
stream=model_request.stream or False,
|
139
|
+
org_id=model_request.user_context.org_id or "",
|
140
|
+
user_id=model_request.user_context.user_id or "",
|
141
|
+
client_type=model_request.user_context.client_type or "",
|
142
|
+
extra=extra_fields
|
143
|
+
)
|
144
|
+
|
145
|
+
@staticmethod
|
146
|
+
def build_batch_request_item(
|
147
|
+
batch_item: "BatchModelRequestItem",
|
148
|
+
user_context: "UserContext"
|
149
|
+
) -> model_service_pb2.ModelRequestItem:
|
150
|
+
"""
|
151
|
+
构建批量请求中的单个项目
|
152
|
+
|
153
|
+
Args:
|
154
|
+
batch_item: 批量请求项
|
155
|
+
user_context: 用户上下文(来自父BatchModelRequest)
|
156
|
+
|
157
|
+
Returns:
|
158
|
+
model_service_pb2.ModelRequestItem: gRPC 请求对象
|
159
|
+
"""
|
160
|
+
# 构建额外字段
|
161
|
+
extra_fields = RequestBuilder.build_grpc_extra_fields(batch_item)
|
162
|
+
|
163
|
+
# 添加 custom_id 如果存在
|
164
|
+
if hasattr(batch_item, 'custom_id') and batch_item.custom_id:
|
165
|
+
request_item = model_service_pb2.ModelRequestItem(
|
166
|
+
provider=batch_item.provider.value,
|
167
|
+
channel=batch_item.channel.value if batch_item.channel else "",
|
168
|
+
invoke_type=batch_item.invoke_type.value,
|
169
|
+
stream=batch_item.stream or False,
|
170
|
+
org_id=user_context.org_id or "",
|
171
|
+
user_id=user_context.user_id or "",
|
172
|
+
client_type=user_context.client_type or "",
|
173
|
+
custom_id=batch_item.custom_id,
|
174
|
+
extra=extra_fields
|
175
|
+
)
|
176
|
+
else:
|
177
|
+
request_item = model_service_pb2.ModelRequestItem(
|
178
|
+
provider=batch_item.provider.value,
|
179
|
+
channel=batch_item.channel.value if batch_item.channel else "",
|
180
|
+
invoke_type=batch_item.invoke_type.value,
|
181
|
+
stream=batch_item.stream or False,
|
182
|
+
org_id=user_context.org_id or "",
|
183
|
+
user_id=user_context.user_id or "",
|
184
|
+
client_type=user_context.client_type or "",
|
185
|
+
extra=extra_fields
|
186
|
+
)
|
187
|
+
|
188
|
+
# 添加 priority 如果存在
|
189
|
+
if hasattr(batch_item, 'priority') and batch_item.priority is not None:
|
190
|
+
request_item.priority = batch_item.priority
|
191
|
+
|
192
|
+
return request_item
|
193
|
+
|
194
|
+
@staticmethod
|
195
|
+
def build_batch_request(batch_request: BatchModelRequest) -> model_service_pb2.ModelRequest:
|
196
|
+
"""
|
197
|
+
构建批量请求的 gRPC 对象
|
198
|
+
|
199
|
+
Args:
|
200
|
+
batch_request: 批量请求对象
|
201
|
+
|
202
|
+
Returns:
|
203
|
+
model_service_pb2.ModelRequest: gRPC 批量请求对象
|
204
|
+
|
205
|
+
Raises:
|
206
|
+
ValueError: 当构建请求失败时
|
207
|
+
"""
|
208
|
+
items = []
|
209
|
+
|
210
|
+
for batch_item in batch_request.items:
|
211
|
+
# 为每个请求项构建 gRPC 对象,传入 user_context
|
212
|
+
request_item = RequestBuilder.build_batch_request_item(
|
213
|
+
batch_item,
|
214
|
+
batch_request.user_context
|
215
|
+
)
|
216
|
+
items.append(request_item)
|
217
|
+
|
218
|
+
# 创建批量请求对象
|
219
|
+
return model_service_pb2.ModelRequest(
|
220
|
+
items=items
|
221
|
+
)
|
@@ -0,0 +1,136 @@
|
|
1
|
+
"""
|
2
|
+
Response handling logic for Tamar Model Client
|
3
|
+
|
4
|
+
This module provides utilities for processing gRPC responses and
|
5
|
+
converting them to client response objects.
|
6
|
+
"""
|
7
|
+
|
8
|
+
import json
|
9
|
+
from typing import Optional, Dict, Any
|
10
|
+
|
11
|
+
from ..schemas import ModelResponse, BatchModelResponse
|
12
|
+
|
13
|
+
|
14
|
+
class ResponseHandler:
|
15
|
+
"""
|
16
|
+
响应处理器
|
17
|
+
|
18
|
+
负责将 gRPC 响应转换为客户端响应对象,
|
19
|
+
包括 JSON 解析、错误处理和数据结构转换。
|
20
|
+
"""
|
21
|
+
|
22
|
+
@staticmethod
|
23
|
+
def build_model_response(grpc_response) -> ModelResponse:
|
24
|
+
"""
|
25
|
+
从 gRPC 响应构建 ModelResponse 对象
|
26
|
+
|
27
|
+
Args:
|
28
|
+
grpc_response: gRPC 服务返回的响应对象
|
29
|
+
|
30
|
+
Returns:
|
31
|
+
ModelResponse: 客户端响应对象
|
32
|
+
"""
|
33
|
+
return ModelResponse(
|
34
|
+
content=grpc_response.content,
|
35
|
+
usage=ResponseHandler._parse_json_field(grpc_response.usage),
|
36
|
+
error=grpc_response.error or None,
|
37
|
+
raw_response=ResponseHandler._parse_json_field(grpc_response.raw_response),
|
38
|
+
request_id=grpc_response.request_id if grpc_response.request_id else None,
|
39
|
+
)
|
40
|
+
|
41
|
+
@staticmethod
|
42
|
+
def build_batch_response(grpc_response) -> BatchModelResponse:
|
43
|
+
"""
|
44
|
+
从 gRPC 批量响应构建 BatchModelResponse 对象
|
45
|
+
|
46
|
+
Args:
|
47
|
+
grpc_response: gRPC 服务返回的批量响应对象
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
BatchModelResponse: 客户端批量响应对象
|
51
|
+
"""
|
52
|
+
responses = []
|
53
|
+
for response_item in grpc_response.items:
|
54
|
+
model_response = ResponseHandler.build_model_response(response_item)
|
55
|
+
responses.append(model_response)
|
56
|
+
|
57
|
+
return BatchModelResponse(
|
58
|
+
responses=responses,
|
59
|
+
request_id=grpc_response.request_id if grpc_response.request_id else None
|
60
|
+
)
|
61
|
+
|
62
|
+
@staticmethod
|
63
|
+
def _parse_json_field(json_str: Optional[str]) -> Optional[Dict[str, Any]]:
|
64
|
+
"""
|
65
|
+
安全地解析 JSON 字符串
|
66
|
+
|
67
|
+
Args:
|
68
|
+
json_str: 待解析的 JSON 字符串
|
69
|
+
|
70
|
+
Returns:
|
71
|
+
Optional[Dict[str, Any]]: 解析后的字典,或 None(如果输入为空)
|
72
|
+
"""
|
73
|
+
if not json_str:
|
74
|
+
return None
|
75
|
+
|
76
|
+
try:
|
77
|
+
return json.loads(json_str)
|
78
|
+
except json.JSONDecodeError:
|
79
|
+
# 如果解析失败,返回原始字符串作为错误信息
|
80
|
+
return {"error": "JSON parse error", "raw": json_str}
|
81
|
+
|
82
|
+
@staticmethod
|
83
|
+
def build_log_data(
|
84
|
+
model_request,
|
85
|
+
response: Optional[ModelResponse] = None,
|
86
|
+
duration: Optional[float] = None,
|
87
|
+
error: Optional[Exception] = None,
|
88
|
+
stream_stats: Optional[Dict[str, Any]] = None
|
89
|
+
) -> Dict[str, Any]:
|
90
|
+
"""
|
91
|
+
构建日志数据
|
92
|
+
|
93
|
+
为请求和响应日志构建结构化的数据字典。
|
94
|
+
|
95
|
+
Args:
|
96
|
+
model_request: 原始请求对象
|
97
|
+
response: 响应对象(可选)
|
98
|
+
duration: 请求持续时间(秒)
|
99
|
+
error: 错误对象(可选)
|
100
|
+
stream_stats: 流式响应统计信息(可选)
|
101
|
+
|
102
|
+
Returns:
|
103
|
+
Dict[str, Any]: 日志数据字典
|
104
|
+
"""
|
105
|
+
data = {
|
106
|
+
"provider": model_request.provider.value,
|
107
|
+
"invoke_type": model_request.invoke_type.value,
|
108
|
+
"model": getattr(model_request, 'model', None),
|
109
|
+
"stream": getattr(model_request, 'stream', False),
|
110
|
+
}
|
111
|
+
|
112
|
+
# 添加用户上下文信息(如果有)
|
113
|
+
if hasattr(model_request, 'user_context'):
|
114
|
+
data.update({
|
115
|
+
"org_id": model_request.user_context.org_id,
|
116
|
+
"user_id": model_request.user_context.user_id,
|
117
|
+
"client_type": model_request.user_context.client_type
|
118
|
+
})
|
119
|
+
|
120
|
+
# 添加响应信息
|
121
|
+
if response:
|
122
|
+
if hasattr(response, 'content') and response.content:
|
123
|
+
data["content_length"] = len(response.content)
|
124
|
+
if hasattr(response, 'usage'):
|
125
|
+
data["usage"] = response.usage
|
126
|
+
|
127
|
+
# 添加流式响应统计
|
128
|
+
if stream_stats:
|
129
|
+
data.update(stream_stats)
|
130
|
+
|
131
|
+
# 添加错误信息
|
132
|
+
if error:
|
133
|
+
data["error_type"] = type(error).__name__
|
134
|
+
data["error_message"] = str(error)
|
135
|
+
|
136
|
+
return data
|