tamar-model-client 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/__init__.py +11 -0
- tamar_model_client/async_client.py +419 -0
- tamar_model_client/auth.py +14 -0
- tamar_model_client/enums/__init__.py +8 -0
- tamar_model_client/enums/channel.py +11 -0
- tamar_model_client/enums/invoke.py +10 -0
- tamar_model_client/enums/providers.py +8 -0
- tamar_model_client/exceptions.py +11 -0
- tamar_model_client/generated/__init__.py +0 -0
- tamar_model_client/generated/model_service_pb2.py +45 -0
- tamar_model_client/generated/model_service_pb2_grpc.py +145 -0
- tamar_model_client/schemas/__init__.py +17 -0
- tamar_model_client/schemas/inputs.py +294 -0
- tamar_model_client/schemas/outputs.py +24 -0
- tamar_model_client/sync_client.py +111 -0
- {tamar_model_client-0.1.1.dist-info → tamar_model_client-0.1.2.dist-info}/METADATA +61 -90
- tamar_model_client-0.1.2.dist-info/RECORD +34 -0
- tamar_model_client-0.1.2.dist-info/top_level.txt +1 -0
- tamar_model_client-0.1.1.dist-info/RECORD +0 -19
- tamar_model_client-0.1.1.dist-info/top_level.txt +0 -1
- {tamar_model_client-0.1.1.dist-info → tamar_model_client-0.1.2.dist-info}/WHEEL +0 -0
tamar_model_client/generated/model_service_pb2_grpc.py
@@ -0,0 +1,145 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings

import tamar_model_client.generated.model_service_pb2 as model__service__pb2

GRPC_GENERATED_VERSION = '1.71.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + f' but the generated code in model_service_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )


class ModelServiceStub(object):
    """gRPC service (interface) definition
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        self.Invoke = channel.unary_stream(
                '/model_service.ModelService/Invoke',
                request_serializer=model__service__pb2.ModelRequestItem.SerializeToString,
                response_deserializer=model__service__pb2.ModelResponseItem.FromString,
                _registered_method=True)
        self.BatchInvoke = channel.unary_unary(
                '/model_service.ModelService/BatchInvoke',
                request_serializer=model__service__pb2.ModelRequest.SerializeToString,
                response_deserializer=model__service__pb2.ModelResponse.FromString,
                _registered_method=True)


class ModelServiceServicer(object):
    """gRPC service (interface) definition
    """

    def Invoke(self, request, context):
        """Single request + streaming response
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def BatchInvoke(self, request, context):
        """Batch invocation; streaming is not supported
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')


def add_ModelServiceServicer_to_server(servicer, server):
    rpc_method_handlers = {
            'Invoke': grpc.unary_stream_rpc_method_handler(
                    servicer.Invoke,
                    request_deserializer=model__service__pb2.ModelRequestItem.FromString,
                    response_serializer=model__service__pb2.ModelResponseItem.SerializeToString,
            ),
            'BatchInvoke': grpc.unary_unary_rpc_method_handler(
                    servicer.BatchInvoke,
                    request_deserializer=model__service__pb2.ModelRequest.FromString,
                    response_serializer=model__service__pb2.ModelResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'model_service.ModelService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('model_service.ModelService', rpc_method_handlers)


# This class is part of an EXPERIMENTAL API.
class ModelService(object):
    """gRPC service (interface) definition
    """

    @staticmethod
    def Invoke(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_stream(
            request,
            target,
            '/model_service.ModelService/Invoke',
            model__service__pb2.ModelRequestItem.SerializeToString,
            model__service__pb2.ModelResponseItem.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def BatchInvoke(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/model_service.ModelService/BatchInvoke',
            model__service__pb2.ModelRequest.SerializeToString,
            model__service__pb2.ModelResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
tamar_model_client/schemas/__init__.py
@@ -0,0 +1,17 @@
"""
Schema definitions for the API
"""

from .inputs import UserContext, ModelRequest, BatchModelRequestItem, BatchModelRequest
from .outputs import ModelResponse, BatchModelResponse

__all__ = [
    # Model Inputs
    "UserContext",
    "ModelRequest",
    "BatchModelRequestItem",
    "BatchModelRequest",
    # Model Outputs
    "ModelResponse",
    "BatchModelResponse",
]
tamar_model_client/schemas/inputs.py
@@ -0,0 +1,294 @@
import httpx
from google.genai import types
from openai import NotGiven, NOT_GIVEN
from openai._types import Headers, Query, Body
from openai.types import ChatModel, Metadata, ReasoningEffort, ResponsesModel, Reasoning
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionAudioParam, completion_create_params, \
    ChatCompletionPredictionContentParam, ChatCompletionStreamOptionsParam, ChatCompletionToolChoiceOptionParam, \
    ChatCompletionToolParam
from openai.types.responses import ResponseInputParam, ResponseIncludable, ResponseTextConfigParam, \
    response_create_params, ToolParam
from pydantic import BaseModel, model_validator
from typing import List, Optional, Union, Iterable, Dict, Literal

from tamar_model_client.enums import ProviderType, InvokeType
from tamar_model_client.enums.channel import Channel


class UserContext(BaseModel):
    org_id: str  # organization ID
    user_id: str  # user ID
    client_type: str  # client type: records which service the request came from


class GoogleGenAiInput(BaseModel):
    model: str
    contents: Union[types.ContentListUnion, types.ContentListUnionDict]
    config: Optional[types.GenerateContentConfigOrDict] = None

    model_config = {
        "arbitrary_types_allowed": True
    }


class OpenAIResponsesInput(BaseModel):
    input: Union[str, ResponseInputParam]
    model: ResponsesModel
    include: Optional[List[ResponseIncludable]] | NotGiven = NOT_GIVEN
    instructions: Optional[str] | NotGiven = NOT_GIVEN
    max_output_tokens: Optional[int] | NotGiven = NOT_GIVEN
    metadata: Optional[Metadata] | NotGiven = NOT_GIVEN
    parallel_tool_calls: Optional[bool] | NotGiven = NOT_GIVEN
    previous_response_id: Optional[str] | NotGiven = NOT_GIVEN
    reasoning: Optional[Reasoning] | NotGiven = NOT_GIVEN
    store: Optional[bool] | NotGiven = NOT_GIVEN
    stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN
    temperature: Optional[float] | NotGiven = NOT_GIVEN
    text: ResponseTextConfigParam | NotGiven = NOT_GIVEN
    tool_choice: response_create_params.ToolChoice | NotGiven = NOT_GIVEN
    tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN
    top_p: Optional[float] | NotGiven = NOT_GIVEN
    truncation: Optional[Literal["auto", "disabled"]] | NotGiven = NOT_GIVEN
    user: str | NotGiven = NOT_GIVEN
    extra_headers: Headers | None = None
    extra_query: Query | None = None
    extra_body: Body | None = None
    timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN

    model_config = {
        "arbitrary_types_allowed": True
    }


class OpenAIChatCompletionsInput(BaseModel):
    messages: Iterable[ChatCompletionMessageParam]
    model: Union[str, ChatModel]
    audio: Optional[ChatCompletionAudioParam] | NotGiven = NOT_GIVEN
    frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN
    function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN
    functions: Iterable[completion_create_params.Function] | NotGiven = NOT_GIVEN
    logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN
    logprobs: Optional[bool] | NotGiven = NOT_GIVEN
    max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN
    max_tokens: Optional[int] | NotGiven = NOT_GIVEN
    metadata: Optional[Metadata] | NotGiven = NOT_GIVEN
    modalities: Optional[List[Literal["text", "audio"]]] | NotGiven = NOT_GIVEN
    n: Optional[int] | NotGiven = NOT_GIVEN
    parallel_tool_calls: bool | NotGiven = NOT_GIVEN
    prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN
    presence_penalty: Optional[float] | NotGiven = NOT_GIVEN
    reasoning_effort: Optional[ReasoningEffort] | NotGiven = NOT_GIVEN
    response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN
    seed: Optional[int] | NotGiven = NOT_GIVEN
    service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN
    stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN
    store: Optional[bool] | NotGiven = NOT_GIVEN
    stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN
    stream_options: Optional[ChatCompletionStreamOptionsParam] | NotGiven = NOT_GIVEN
    temperature: Optional[float] | NotGiven = NOT_GIVEN
    tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN
    tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN
    top_logprobs: Optional[int] | NotGiven = NOT_GIVEN
    top_p: Optional[float] | NotGiven = NOT_GIVEN
    user: str | NotGiven = NOT_GIVEN
    web_search_options: completion_create_params.WebSearchOptions | NotGiven = NOT_GIVEN
    extra_headers: Headers | None = None
    extra_query: Query | None = None
    extra_body: Body | None = None
    timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN

    model_config = {
        "arbitrary_types_allowed": True
    }


class BaseRequest(BaseModel):
    provider: ProviderType  # provider, e.g. "openai", "google"
    channel: Channel = Channel.NORMAL  # channel: different providers expose different SDKs; this selects which SDK to call
    invoke_type: InvokeType = InvokeType.GENERATION  # invocation type: generation = generative model call


class ModelRequestInput(BaseRequest):
    # merged "model" field, shared across providers
    model: Optional[Union[str, ResponsesModel, ChatModel]] = None

    # OpenAI Responses Input
    input: Optional[Union[str, ResponseInputParam]] = None
    include: Optional[Union[List[ResponseIncludable], NotGiven]] = NOT_GIVEN
    instructions: Optional[Union[str, NotGiven]] = NOT_GIVEN
    max_output_tokens: Optional[Union[int, NotGiven]] = NOT_GIVEN
    metadata: Optional[Union[Metadata, NotGiven]] = NOT_GIVEN
    parallel_tool_calls: Optional[Union[bool, NotGiven]] = NOT_GIVEN
    previous_response_id: Optional[Union[str, NotGiven]] = NOT_GIVEN
    reasoning: Optional[Union[Reasoning, NotGiven]] = NOT_GIVEN
    store: Optional[Union[bool, NotGiven]] = NOT_GIVEN
    stream: Optional[Union[Literal[False], Literal[True], NotGiven]] = NOT_GIVEN
    temperature: Optional[Union[float, NotGiven]] = NOT_GIVEN
    text: Optional[Union[ResponseTextConfigParam, NotGiven]] = NOT_GIVEN
    tool_choice: Optional[
        Union[response_create_params.ToolChoice, ChatCompletionToolChoiceOptionParam, NotGiven]] = NOT_GIVEN
    tools: Optional[Union[Iterable[ToolParam], Iterable[ChatCompletionToolParam], NotGiven]] = NOT_GIVEN
    top_p: Optional[Union[float, NotGiven]] = NOT_GIVEN
    truncation: Optional[Union[Literal["auto", "disabled"], NotGiven]] = NOT_GIVEN
    user: Optional[Union[str, NotGiven]] = NOT_GIVEN

    extra_headers: Optional[Union[Headers, None]] = None
    extra_query: Optional[Union[Query, None]] = None
    extra_body: Optional[Union[Body, None]] = None
    timeout: Optional[Union[float, httpx.Timeout, None, NotGiven]] = NOT_GIVEN

    # OpenAI Chat Completions Input
    messages: Optional[Iterable[ChatCompletionMessageParam]] = None
    audio: Optional[Union[ChatCompletionAudioParam, NotGiven]] = NOT_GIVEN
    frequency_penalty: Optional[Union[float, NotGiven]] = NOT_GIVEN
    function_call: Optional[Union[completion_create_params.FunctionCall, NotGiven]] = NOT_GIVEN
    functions: Optional[Union[Iterable[completion_create_params.Function], NotGiven]] = NOT_GIVEN
    logit_bias: Optional[Union[Dict[str, int], NotGiven]] = NOT_GIVEN
    logprobs: Optional[Union[bool, NotGiven]] = NOT_GIVEN
    max_completion_tokens: Optional[Union[int, NotGiven]] = NOT_GIVEN
    modalities: Optional[Union[List[Literal["text", "audio"]], NotGiven]] = NOT_GIVEN
    n: Optional[Union[int, NotGiven]] = NOT_GIVEN
    prediction: Optional[Union[ChatCompletionPredictionContentParam, NotGiven]] = NOT_GIVEN
    presence_penalty: Optional[Union[float, NotGiven]] = NOT_GIVEN
    reasoning_effort: Optional[Union[ReasoningEffort, NotGiven]] = NOT_GIVEN
    response_format: Optional[Union[completion_create_params.ResponseFormat, NotGiven]] = NOT_GIVEN
    seed: Optional[Union[int, NotGiven]] = NOT_GIVEN
    service_tier: Optional[Union[Literal["auto", "default"], NotGiven]] = NOT_GIVEN
    stop: Optional[Union[Optional[str], List[str], None, NotGiven]] = NOT_GIVEN
    top_logprobs: Optional[Union[int, NotGiven]] = NOT_GIVEN
    web_search_options: Optional[Union[completion_create_params.WebSearchOptions, NotGiven]] = NOT_GIVEN
    stream_options: Optional[Union[ChatCompletionStreamOptionsParam, NotGiven]] = NOT_GIVEN

    # Google GenAI Input
    contents: Optional[Union[types.ContentListUnion, types.ContentListUnionDict]] = None
    config: Optional[types.GenerateContentConfigOrDict] = None

    model_config = {
        "arbitrary_types_allowed": True
    }


class ModelRequest(ModelRequestInput):
    user_context: UserContext  # user information

    @model_validator(mode="after")
    def validate_by_provider_and_invoke_type(self) -> "ModelRequest":
        """Dynamically validate the concrete input-model fields based on provider and invoke_type."""
        # dynamically collect the allowed fields
        base_allowed = ["provider", "channel", "invoke_type", "user_context"]
        google_allowed = set(base_allowed) | set(GoogleGenAiInput.model_fields.keys())
        openai_responses_allowed = set(base_allowed) | set(OpenAIResponsesInput.model_fields.keys())
        openai_chat_allowed = set(base_allowed) | set(OpenAIChatCompletionsInput.model_fields.keys())

        # required fields of each original input model
        google_required_fields = {"model", "contents"}
        openai_responses_required_fields = {"input", "model"}
        openai_chat_required_fields = {"messages", "model"}

        # choose the field sets to validate against
        if self.provider == ProviderType.GOOGLE:
            expected_fields = google_required_fields
            allowed_fields = google_allowed
        elif self.provider == ProviderType.OPENAI or self.provider == ProviderType.AZURE:
            if self.invoke_type == InvokeType.RESPONSES or self.invoke_type == InvokeType.GENERATION:
                expected_fields = openai_responses_required_fields
                allowed_fields = openai_responses_allowed
            elif self.invoke_type == InvokeType.CHAT_COMPLETIONS:
                expected_fields = openai_chat_required_fields
                allowed_fields = openai_chat_allowed
            else:
                raise ValueError(f"Unsupported invoke type: {self.invoke_type}")
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

        # check for missing required fields
        missing = []
        for field in expected_fields:
            if getattr(self, field, None) is None:
                missing.append(field)

        if missing:
            raise ValueError(
                f"Required fields missing for {self.provider}({self.invoke_type}) request: {missing}"
            )

        # check for fields that are not allowed
        illegal_fields = []
        for name, value in self.__dict__.items():
            if name in {"provider", "channel", "invoke_type", "stream"}:
                continue
            if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
                illegal_fields.append(name)

        if illegal_fields:
            raise ValueError(
                f"Unsupported fields for {self.provider}({self.invoke_type}): {illegal_fields}"
            )

        return self


class BatchModelRequestItem(ModelRequestInput):
    custom_id: Optional[str] = None
    priority: Optional[int] = None  # (optional, reserved) execution priority within a batch call

    @model_validator(mode="after")
    def validate_by_provider_and_invoke_type(self) -> "BatchModelRequestItem":
        """Dynamically validate the concrete input-model fields based on provider and invoke_type."""
        # dynamically collect the allowed fields
        base_allowed = ["provider", "channel", "invoke_type", "custom_id", "priority"]
        google_allowed = set(base_allowed) | set(GoogleGenAiInput.model_fields.keys())
        openai_responses_allowed = set(base_allowed) | set(OpenAIResponsesInput.model_fields.keys())
        openai_chat_allowed = set(base_allowed) | set(OpenAIChatCompletionsInput.model_fields.keys())

        # required fields of each original input model
        google_required_fields = {"model", "contents"}
        openai_responses_required_fields = {"input", "model"}
        openai_chat_required_fields = {"messages", "model"}

        # choose the field sets to validate against
        if self.provider == ProviderType.GOOGLE:
            expected_fields = google_required_fields
            allowed_fields = google_allowed
        elif self.provider == ProviderType.OPENAI or self.provider == ProviderType.AZURE:
            if self.invoke_type == InvokeType.RESPONSES or self.invoke_type == InvokeType.GENERATION:
                expected_fields = openai_responses_required_fields
                allowed_fields = openai_responses_allowed
            elif self.invoke_type == InvokeType.CHAT_COMPLETIONS:
                expected_fields = openai_chat_required_fields
                allowed_fields = openai_chat_allowed
            else:
                raise ValueError(f"Unsupported invoke type: {self.invoke_type}")
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

        # check for missing required fields
        missing = []
        for field in expected_fields:
            if getattr(self, field, None) is None:
                missing.append(field)

        if missing:
            raise ValueError(
                f"Required fields missing for {self.provider}({self.invoke_type}) request: {missing}"
            )

        # check for fields that are not allowed
        illegal_fields = []
        for name, value in self.__dict__.items():
            if name in {"provider", "channel", "invoke_type", "stream"}:
                continue
            if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
                illegal_fields.append(name)

        if illegal_fields:
            raise ValueError(
                f"Unsupported fields for {self.provider}({self.invoke_type}): {illegal_fields}"
            )

        return self


class BatchModelRequest(BaseModel):
    user_context: UserContext  # user information
    items: List[BatchModelRequestItem]  # list of batch request items
tamar_model_client/schemas/outputs.py
@@ -0,0 +1,24 @@
from typing import Any, Iterator, Optional, Union, Dict, List

from pydantic import BaseModel, ConfigDict


class BaseResponse(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    content: Optional[str] = None  # text output content
    usage: Optional[Dict] = None  # tokens / request cost, etc. (JSON)
    stream_response: Optional[Union[Iterator[str], Any]] = None  # for streaming responses (sync or async)
    raw_response: Optional[Union[Dict, List]] = None  # raw structure returned by the model provider (JSON)
    error: Optional[Any] = None  # error information
    custom_id: Optional[str] = None  # custom ID for correlating results in batch requests


class ModelResponse(BaseResponse):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    request_id: Optional[str] = None  # request ID, used for request tracing


class BatchModelResponse(BaseModel):
    request_id: Optional[str] = None  # request ID, used for request tracing
    responses: Optional[List[BaseResponse]] = None  # list of responses for the batch request
tamar_model_client/sync_client.py
@@ -0,0 +1,111 @@
import asyncio
import atexit
import logging
from typing import Optional, Union, Iterator

from .async_client import AsyncTamarModelClient
from .schemas import ModelRequest, BatchModelRequest, ModelResponse, BatchModelResponse

logger = logging.getLogger(__name__)


class TamarModelClient:
    """
    Synchronous model client for non-async environments (e.g. Flask, Django, scripts).
    Wraps AsyncTamarModelClient internally and handles event-loop compatibility.
    """
    _loop: Optional[asyncio.AbstractEventLoop] = None

    def __init__(
            self,
            server_address: Optional[str] = None,
            jwt_secret_key: Optional[str] = None,
            jwt_token: Optional[str] = None,
            default_payload: Optional[dict] = None,
            token_expires_in: int = 3600,
            max_retries: int = 3,
            retry_delay: float = 1.0,
    ):
        # initialize the shared event loop; created only once
        if not TamarModelClient._loop:
            try:
                TamarModelClient._loop = asyncio.get_running_loop()
            except RuntimeError:
                TamarModelClient._loop = asyncio.new_event_loop()
                asyncio.set_event_loop(TamarModelClient._loop)

        self._loop = TamarModelClient._loop

        self._async_client = AsyncTamarModelClient(
            server_address=server_address,
            jwt_secret_key=jwt_secret_key,
            jwt_token=jwt_token,
            default_payload=default_payload,
            token_expires_in=token_expires_in,
            max_retries=max_retries,
            retry_delay=retry_delay,
        )
        atexit.register(self._safe_sync_close)

    def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None) -> Union[
        ModelResponse, Iterator[ModelResponse]]:
        """
        Invoke a single model task synchronously.
        """
        if model_request.stream:
            async def stream():
                async for r in await self._async_client.invoke(model_request, timeout=timeout):
                    yield r

            return self._sync_wrap_async_generator(stream())
        return self._run_async(self._async_client.invoke(model_request, timeout=timeout))

    def invoke_batch(self, batch_model_request: BatchModelRequest,
                     timeout: Optional[float] = None) -> BatchModelResponse:
        """
        Invoke a batch of model tasks synchronously.
        """
        return self._run_async(self._async_client.invoke_batch(batch_model_request, timeout=timeout))

    def close(self):
        """Close the gRPC channel manually."""
        self._run_async(self._async_client.close())

    def _safe_sync_close(self):
        """Close automatically at interpreter exit."""
        try:
            self._run_async(self._async_client.close())
            logger.info("✅ gRPC channel closed at exit")
        except Exception as e:
            logger.warning(f"❌ gRPC channel close failed at exit: {e}")

    def _run_async(self, coro):
        """Run a coroutine, tolerating an already-running event loop."""
        try:
            loop = asyncio.get_running_loop()
            import nest_asyncio
            nest_asyncio.apply()
            return loop.run_until_complete(coro)
        except RuntimeError:
            return self._loop.run_until_complete(coro)

    def _sync_wrap_async_generator(self, async_gen_func):
        """
        Convert an async generator into a sync generator, yielding item by item.
        """
        loop = self._loop

        # the async generator object
        agen = async_gen_func

        class SyncGenerator:
            def __iter__(self_inner):
                return self_inner

            def __next__(self_inner):
                try:
                    return loop.run_until_complete(agen.__anext__())
                except StopAsyncIteration:
                    raise StopIteration

        return SyncGenerator()