tamar-model-client 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,145 @@
+ # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+ """Client and server classes corresponding to protobuf-defined services."""
+ import grpc
+ import warnings
+
+ import model_manager_client.generated.model_service_pb2 as model__service__pb2
+
+ GRPC_GENERATED_VERSION = '1.71.0'
+ GRPC_VERSION = grpc.__version__
+ _version_not_supported = False
+
+ try:
+     from grpc._utilities import first_version_is_lower
+     _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
+ except ImportError:
+     _version_not_supported = True
+
+ if _version_not_supported:
+     raise RuntimeError(
+         f'The grpc package installed is at version {GRPC_VERSION},'
+         + f' but the generated code in model_service_pb2_grpc.py depends on'
+         + f' grpcio>={GRPC_GENERATED_VERSION}.'
+         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
+         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
+     )
+
+
+ class ModelServiceStub(object):
+     """gRPC service (interface) definition
+     """
+
+     def __init__(self, channel):
+         """Constructor.
+
+         Args:
+             channel: A grpc.Channel.
+         """
+         self.Invoke = channel.unary_stream(
+             '/model_service.ModelService/Invoke',
+             request_serializer=model__service__pb2.ModelRequestItem.SerializeToString,
+             response_deserializer=model__service__pb2.ModelResponseItem.FromString,
+             _registered_method=True)
+         self.BatchInvoke = channel.unary_unary(
+             '/model_service.ModelService/BatchInvoke',
+             request_serializer=model__service__pb2.ModelRequest.SerializeToString,
+             response_deserializer=model__service__pb2.ModelResponse.FromString,
+             _registered_method=True)
+
+
+ class ModelServiceServicer(object):
+     """gRPC service (interface) definition
+     """
+
+     def Invoke(self, request, context):
+         """Single request + streaming response
+         """
+         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+         context.set_details('Method not implemented!')
+         raise NotImplementedError('Method not implemented!')
+
+     def BatchInvoke(self, request, context):
+         """Batch invocation; streaming is not supported
+         """
+         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+         context.set_details('Method not implemented!')
+         raise NotImplementedError('Method not implemented!')
+
+
+ def add_ModelServiceServicer_to_server(servicer, server):
+     rpc_method_handlers = {
+         'Invoke': grpc.unary_stream_rpc_method_handler(
+             servicer.Invoke,
+             request_deserializer=model__service__pb2.ModelRequestItem.FromString,
+             response_serializer=model__service__pb2.ModelResponseItem.SerializeToString,
+         ),
+         'BatchInvoke': grpc.unary_unary_rpc_method_handler(
+             servicer.BatchInvoke,
+             request_deserializer=model__service__pb2.ModelRequest.FromString,
+             response_serializer=model__service__pb2.ModelResponse.SerializeToString,
+         ),
+     }
+     generic_handler = grpc.method_handlers_generic_handler(
+         'model_service.ModelService', rpc_method_handlers)
+     server.add_generic_rpc_handlers((generic_handler,))
+     server.add_registered_method_handlers('model_service.ModelService', rpc_method_handlers)
+
+
+ # This class is part of an EXPERIMENTAL API.
+ class ModelService(object):
+     """gRPC service (interface) definition
+     """
+
+     @staticmethod
+     def Invoke(request,
+                target,
+                options=(),
+                channel_credentials=None,
+                call_credentials=None,
+                insecure=False,
+                compression=None,
+                wait_for_ready=None,
+                timeout=None,
+                metadata=None):
+         return grpc.experimental.unary_stream(
+             request,
+             target,
+             '/model_service.ModelService/Invoke',
+             model__service__pb2.ModelRequestItem.SerializeToString,
+             model__service__pb2.ModelResponseItem.FromString,
+             options,
+             channel_credentials,
+             insecure,
+             call_credentials,
+             compression,
+             wait_for_ready,
+             timeout,
+             metadata,
+             _registered_method=True)
+
+     @staticmethod
+     def BatchInvoke(request,
+                target,
+                options=(),
+                channel_credentials=None,
+                call_credentials=None,
+                insecure=False,
+                compression=None,
+                wait_for_ready=None,
+                timeout=None,
+                metadata=None):
+         return grpc.experimental.unary_unary(
+             request,
+             target,
+             '/model_service.ModelService/BatchInvoke',
+             model__service__pb2.ModelRequest.SerializeToString,
+             model__service__pb2.ModelResponse.FromString,
+             options,
+             channel_credentials,
+             insecure,
+             call_credentials,
+             compression,
+             wait_for_ready,
+             timeout,
+             metadata,
+             _registered_method=True)
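
For orientation, here is a minimal sketch of driving this generated stub directly. The endpoint is a placeholder, and the fields of ModelRequestItem are defined in model_service_pb2 (not shown in this diff), so the request is left empty here; in practice the higher-level client shown later constructs it:

    import grpc
    import model_manager_client.generated.model_service_pb2 as pb2
    import model_manager_client.generated.model_service_pb2_grpc as pb2_grpc

    # Placeholder endpoint; the real address comes from client configuration.
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = pb2_grpc.ModelServiceStub(channel)
        # Invoke is unary-stream: one request item in, a stream of response items out.
        for item in stub.Invoke(pb2.ModelRequestItem()):
            print(item)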
@@ -0,0 +1,17 @@
+ """
+ Schema definitions for the API
+ """
+
+ from .inputs import UserContext, ModelRequest, BatchModelRequestItem, BatchModelRequest
+ from .outputs import ModelResponse, BatchModelResponse
+
+ __all__ = [
+     # Model Inputs
+     "UserContext",
+     "ModelRequest",
+     "BatchModelRequestItem",
+     "BatchModelRequest",
+     # Model Outputs
+     "ModelResponse",
+     "BatchModelResponse",
+ ]
@@ -0,0 +1,294 @@
+ import httpx
+ from google.genai import types
+ from openai import NotGiven, NOT_GIVEN
+ from openai._types import Headers, Query, Body
+ from openai.types import ChatModel, Metadata, ReasoningEffort, ResponsesModel, Reasoning
+ from openai.types.chat import ChatCompletionMessageParam, ChatCompletionAudioParam, completion_create_params, \
+     ChatCompletionPredictionContentParam, ChatCompletionStreamOptionsParam, ChatCompletionToolChoiceOptionParam, \
+     ChatCompletionToolParam
+ from openai.types.responses import ResponseInputParam, ResponseIncludable, ResponseTextConfigParam, \
+     response_create_params, ToolParam
+ from pydantic import BaseModel, model_validator
+ from typing import List, Optional, Union, Iterable, Dict, Literal
+
+ from model_manager_client.enums import ProviderType, InvokeType
+ from model_manager_client.enums.channel import Channel
+
+
+ class UserContext(BaseModel):
+     org_id: str  # organization ID
+     user_id: str  # user ID
+     client_type: str  # client type; records which service the request came from
+
+
+ class GoogleGenAiInput(BaseModel):
+     model: str
+     contents: Union[types.ContentListUnion, types.ContentListUnionDict]
+     config: Optional[types.GenerateContentConfigOrDict] = None
+
+     model_config = {
+         "arbitrary_types_allowed": True
+     }
+
+
+ class OpenAIResponsesInput(BaseModel):
+     input: Union[str, ResponseInputParam]
+     model: ResponsesModel
+     include: Optional[List[ResponseIncludable]] | NotGiven = NOT_GIVEN
+     instructions: Optional[str] | NotGiven = NOT_GIVEN
+     max_output_tokens: Optional[int] | NotGiven = NOT_GIVEN
+     metadata: Optional[Metadata] | NotGiven = NOT_GIVEN
+     parallel_tool_calls: Optional[bool] | NotGiven = NOT_GIVEN
+     previous_response_id: Optional[str] | NotGiven = NOT_GIVEN
+     reasoning: Optional[Reasoning] | NotGiven = NOT_GIVEN
+     store: Optional[bool] | NotGiven = NOT_GIVEN
+     stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN
+     temperature: Optional[float] | NotGiven = NOT_GIVEN
+     text: ResponseTextConfigParam | NotGiven = NOT_GIVEN
+     tool_choice: response_create_params.ToolChoice | NotGiven = NOT_GIVEN
+     tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN
+     top_p: Optional[float] | NotGiven = NOT_GIVEN
+     truncation: Optional[Literal["auto", "disabled"]] | NotGiven = NOT_GIVEN
+     user: str | NotGiven = NOT_GIVEN
+     extra_headers: Headers | None = None
+     extra_query: Query | None = None
+     extra_body: Body | None = None
+     timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN
+
+     model_config = {
+         "arbitrary_types_allowed": True
+     }
+
+
+ class OpenAIChatCompletionsInput(BaseModel):
+     messages: Iterable[ChatCompletionMessageParam]
+     model: Union[str, ChatModel]
+     audio: Optional[ChatCompletionAudioParam] | NotGiven = NOT_GIVEN
+     frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN
+     function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN
+     functions: Iterable[completion_create_params.Function] | NotGiven = NOT_GIVEN
+     logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN
+     logprobs: Optional[bool] | NotGiven = NOT_GIVEN
+     max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN
+     max_tokens: Optional[int] | NotGiven = NOT_GIVEN
+     metadata: Optional[Metadata] | NotGiven = NOT_GIVEN
+     modalities: Optional[List[Literal["text", "audio"]]] | NotGiven = NOT_GIVEN
+     n: Optional[int] | NotGiven = NOT_GIVEN
+     parallel_tool_calls: bool | NotGiven = NOT_GIVEN
+     prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN
+     presence_penalty: Optional[float] | NotGiven = NOT_GIVEN
+     reasoning_effort: Optional[ReasoningEffort] | NotGiven = NOT_GIVEN
+     response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN
+     seed: Optional[int] | NotGiven = NOT_GIVEN
+     service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN
+     stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN
+     store: Optional[bool] | NotGiven = NOT_GIVEN
+     stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN
+     stream_options: Optional[ChatCompletionStreamOptionsParam] | NotGiven = NOT_GIVEN
+     temperature: Optional[float] | NotGiven = NOT_GIVEN
+     tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN
+     tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN
+     top_logprobs: Optional[int] | NotGiven = NOT_GIVEN
+     top_p: Optional[float] | NotGiven = NOT_GIVEN
+     user: str | NotGiven = NOT_GIVEN
+     web_search_options: completion_create_params.WebSearchOptions | NotGiven = NOT_GIVEN
+     extra_headers: Headers | None = None
+     extra_query: Query | None = None
+     extra_body: Body | None = None
+     timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN
+
+     model_config = {
+         "arbitrary_types_allowed": True
+     }
+
+
+ class BaseRequest(BaseModel):
+     provider: ProviderType  # provider, e.g. "openai", "google"
+     channel: Channel = Channel.NORMAL  # channel: providers are called through different SDKs; this selects which SDK to use
+     invoke_type: InvokeType = InvokeType.GENERATION  # invocation type: generation - generative model call
+
+
+ class ModelRequestInput(BaseRequest):
+     # merged "model" field (shared across providers)
+     model: Optional[Union[str, ResponsesModel, ChatModel]] = None
+
+     # OpenAI Responses Input
+     input: Optional[Union[str, ResponseInputParam]] = None
+     include: Optional[Union[List[ResponseIncludable], NotGiven]] = NOT_GIVEN
+     instructions: Optional[Union[str, NotGiven]] = NOT_GIVEN
+     max_output_tokens: Optional[Union[int, NotGiven]] = NOT_GIVEN
+     metadata: Optional[Union[Metadata, NotGiven]] = NOT_GIVEN
+     parallel_tool_calls: Optional[Union[bool, NotGiven]] = NOT_GIVEN
+     previous_response_id: Optional[Union[str, NotGiven]] = NOT_GIVEN
+     reasoning: Optional[Union[Reasoning, NotGiven]] = NOT_GIVEN
+     store: Optional[Union[bool, NotGiven]] = NOT_GIVEN
+     stream: Optional[Union[Literal[False], Literal[True], NotGiven]] = NOT_GIVEN
+     temperature: Optional[Union[float, NotGiven]] = NOT_GIVEN
+     text: Optional[Union[ResponseTextConfigParam, NotGiven]] = NOT_GIVEN
+     tool_choice: Optional[
+         Union[response_create_params.ToolChoice, ChatCompletionToolChoiceOptionParam, NotGiven]] = NOT_GIVEN
+     tools: Optional[Union[Iterable[ToolParam], Iterable[ChatCompletionToolParam], NotGiven]] = NOT_GIVEN
+     top_p: Optional[Union[float, NotGiven]] = NOT_GIVEN
+     truncation: Optional[Union[Literal["auto", "disabled"], NotGiven]] = NOT_GIVEN
+     user: Optional[Union[str, NotGiven]] = NOT_GIVEN
+
+     extra_headers: Optional[Union[Headers, None]] = None
+     extra_query: Optional[Union[Query, None]] = None
+     extra_body: Optional[Union[Body, None]] = None
+     timeout: Optional[Union[float, httpx.Timeout, None, NotGiven]] = NOT_GIVEN
+
+     # OpenAI Chat Completions Input
+     messages: Optional[Iterable[ChatCompletionMessageParam]] = None
+     audio: Optional[Union[ChatCompletionAudioParam, NotGiven]] = NOT_GIVEN
+     frequency_penalty: Optional[Union[float, NotGiven]] = NOT_GIVEN
+     function_call: Optional[Union[completion_create_params.FunctionCall, NotGiven]] = NOT_GIVEN
+     functions: Optional[Union[Iterable[completion_create_params.Function], NotGiven]] = NOT_GIVEN
+     logit_bias: Optional[Union[Dict[str, int], NotGiven]] = NOT_GIVEN
+     logprobs: Optional[Union[bool, NotGiven]] = NOT_GIVEN
+     max_completion_tokens: Optional[Union[int, NotGiven]] = NOT_GIVEN
+     modalities: Optional[Union[List[Literal["text", "audio"]], NotGiven]] = NOT_GIVEN
+     n: Optional[Union[int, NotGiven]] = NOT_GIVEN
+     prediction: Optional[Union[ChatCompletionPredictionContentParam, NotGiven]] = NOT_GIVEN
+     presence_penalty: Optional[Union[float, NotGiven]] = NOT_GIVEN
+     reasoning_effort: Optional[Union[ReasoningEffort, NotGiven]] = NOT_GIVEN
+     response_format: Optional[Union[completion_create_params.ResponseFormat, NotGiven]] = NOT_GIVEN
+     seed: Optional[Union[int, NotGiven]] = NOT_GIVEN
+     service_tier: Optional[Union[Literal["auto", "default"], NotGiven]] = NOT_GIVEN
+     stop: Optional[Union[Optional[str], List[str], None, NotGiven]] = NOT_GIVEN
+     top_logprobs: Optional[Union[int, NotGiven]] = NOT_GIVEN
+     web_search_options: Optional[Union[completion_create_params.WebSearchOptions, NotGiven]] = NOT_GIVEN
+     stream_options: Optional[Union[ChatCompletionStreamOptionsParam, NotGiven]] = NOT_GIVEN
+
+     # Google GenAI Input
+     contents: Optional[Union[types.ContentListUnion, types.ContentListUnionDict]] = None
+     config: Optional[types.GenerateContentConfigOrDict] = None
+
+     model_config = {
+         "arbitrary_types_allowed": True
+     }
+
+
+ class ModelRequest(ModelRequestInput):
+     user_context: UserContext  # user information
+
+     @model_validator(mode="after")
+     def validate_by_provider_and_invoke_type(self) -> "ModelRequest":
+         """Dynamically validate the concrete input-model fields based on provider and invoke_type."""
+         # Dynamically collect the allowed fields
+         base_allowed = ["provider", "channel", "invoke_type", "user_context"]
+         google_allowed = set(base_allowed) | set(GoogleGenAiInput.model_fields.keys())
+         openai_responses_allowed = set(base_allowed) | set(OpenAIResponsesInput.model_fields.keys())
+         openai_chat_allowed = set(base_allowed) | set(OpenAIChatCompletionsInput.model_fields.keys())
+
+         # Required fields for each underlying input model
+         google_required_fields = {"model", "contents"}
+         openai_responses_required_fields = {"input", "model"}
+         openai_chat_required_fields = {"messages", "model"}
+
+         # Select the field sets to validate against
+         if self.provider == ProviderType.GOOGLE:
+             expected_fields = google_required_fields
+             allowed_fields = google_allowed
+         elif self.provider == ProviderType.OPENAI or self.provider == ProviderType.AZURE:
+             if self.invoke_type == InvokeType.RESPONSES or self.invoke_type == InvokeType.GENERATION:
+                 expected_fields = openai_responses_required_fields
+                 allowed_fields = openai_responses_allowed
+             elif self.invoke_type == InvokeType.CHAT_COMPLETIONS:
+                 expected_fields = openai_chat_required_fields
+                 allowed_fields = openai_chat_allowed
+             else:
+                 raise ValueError(f"Unsupported invoke type: {self.invoke_type}")
+         else:
+             raise ValueError(f"Unsupported provider: {self.provider}")
+
+         # Check for missing required fields
+         missing = []
+         for field in expected_fields:
+             if getattr(self, field, None) is None:
+                 missing.append(field)
+
+         if missing:
+             raise ValueError(
+                 f"{self.provider}({self.invoke_type}) request is missing required fields: {missing}"
+             )
+
+         # Check for unsupported fields
+         illegal_fields = []
+         for name, value in self.__dict__.items():
+             if name in {"provider", "channel", "invoke_type", "stream"}:
+                 continue
+             if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
+                 illegal_fields.append(name)
+
+         if illegal_fields:
+             raise ValueError(
+                 f"{self.provider}({self.invoke_type}) contains unsupported fields: {illegal_fields}"
+             )
+
+         return self
+
+
+ class BatchModelRequestItem(ModelRequestInput):
+     custom_id: Optional[str] = None
+     priority: Optional[int] = None  # (optional, reserved) execution priority within a batch invocation
+
+     @model_validator(mode="after")
+     def validate_by_provider_and_invoke_type(self) -> "BatchModelRequestItem":
+         """Dynamically validate the concrete input-model fields based on provider and invoke_type."""
+         # Dynamically collect the allowed fields
+         base_allowed = ["provider", "channel", "invoke_type", "custom_id", "priority"]
+         google_allowed = set(base_allowed) | set(GoogleGenAiInput.model_fields.keys())
+         openai_responses_allowed = set(base_allowed) | set(OpenAIResponsesInput.model_fields.keys())
+         openai_chat_allowed = set(base_allowed) | set(OpenAIChatCompletionsInput.model_fields.keys())
+
+         # Required fields for each underlying input model
+         google_required_fields = {"model", "contents"}
+         openai_responses_required_fields = {"input", "model"}
+         openai_chat_required_fields = {"messages", "model"}
+
+         # Select the field sets to validate against
+         if self.provider == ProviderType.GOOGLE:
+             expected_fields = google_required_fields
+             allowed_fields = google_allowed
+         elif self.provider == ProviderType.OPENAI or self.provider == ProviderType.AZURE:
+             if self.invoke_type == InvokeType.RESPONSES or self.invoke_type == InvokeType.GENERATION:
+                 expected_fields = openai_responses_required_fields
+                 allowed_fields = openai_responses_allowed
+             elif self.invoke_type == InvokeType.CHAT_COMPLETIONS:
+                 expected_fields = openai_chat_required_fields
+                 allowed_fields = openai_chat_allowed
+             else:
+                 raise ValueError(f"Unsupported invoke type: {self.invoke_type}")
+         else:
+             raise ValueError(f"Unsupported provider: {self.provider}")
+
+         # Check for missing required fields
+         missing = []
+         for field in expected_fields:
+             if getattr(self, field, None) is None:
+                 missing.append(field)
+
+         if missing:
+             raise ValueError(
+                 f"{self.provider}({self.invoke_type}) request is missing required fields: {missing}"
+             )
+
+         # Check for unsupported fields
+         illegal_fields = []
+         for name, value in self.__dict__.items():
+             if name in {"provider", "channel", "invoke_type", "stream"}:
+                 continue
+             if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
+                 illegal_fields.append(name)
+
+         if illegal_fields:
+             raise ValueError(
+                 f"{self.provider}({self.invoke_type}) contains unsupported fields: {illegal_fields}"
+             )
+
+         return self
+
+
+ class BatchModelRequest(BaseModel):
+     user_context: UserContext  # user information
+     items: List[BatchModelRequestItem]  # list of batch request items
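
To make the validator concrete, here is a small sketch of a request that passes validation, and one change that would fail it. ProviderType.OPENAI and InvokeType.CHAT_COMPLETIONS are the enum members the validator itself references; the schema import path follows the package __init__ shown earlier:

    from model_manager_client.enums import ProviderType, InvokeType
    from model_manager_client.schemas import UserContext, ModelRequest

    ctx = UserContext(org_id="org-1", user_id="user-1", client_type="demo-service")

    # Chat Completions path: "messages" and "model" are the required fields.
    req = ModelRequest(
        provider=ProviderType.OPENAI,
        invoke_type=InvokeType.CHAT_COMPLETIONS,
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hello"}],
        user_context=ctx,
    )

    # Setting a Google-only field such as contents=... on this same request
    # would raise a ValueError from validate_by_provider_and_invoke_type.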
@@ -0,0 +1,24 @@
+ from typing import Any, Iterator, Optional, Union, Dict, List
+
+ from pydantic import BaseModel, ConfigDict
+
+
+ class BaseResponse(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
+     content: Optional[str] = None  # text output content
+     usage: Optional[Dict] = None  # tokens / request cost, etc. (JSON)
+     stream_response: Optional[Union[Iterator[str], Any]] = None  # for streaming responses (sync or async)
+     raw_response: Optional[Union[Dict, List]] = None  # raw structure returned by the model provider (JSON)
+     error: Optional[Any] = None  # error information
+     custom_id: Optional[str] = None  # custom ID, used to correlate results in batch requests
+
+
+ class ModelResponse(BaseResponse):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+     request_id: Optional[str] = None  # request ID, used to trace the request
+
+
+ class BatchModelResponse(BaseModel):
+     request_id: Optional[str] = None  # request ID, used to trace the request
+     responses: Optional[List[BaseResponse]] = None  # list of responses for the batch request
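
As a brief illustration of how these fields are meant to be consumed, a sketch that correlates batch results by custom_id (the field comments above define the semantics; the BatchModelResponse would come from the client's invoke_batch):

    def summarize(batch: BatchModelResponse) -> None:
        # A populated "error" marks a failed item; otherwise "content" holds the text output.
        for resp in batch.responses or []:
            if resp.error is not None:
                print(f"{resp.custom_id}: failed -> {resp.error}")
            else:
                print(f"{resp.custom_id}: {resp.content} (usage={resp.usage})")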
@@ -0,0 +1,111 @@
+ import asyncio
+ import atexit
+ import logging
+ from typing import Optional, Union, Iterator
+
+ from .async_client import AsyncModelManagerClient
+ from .schemas import ModelRequest, BatchModelRequest, ModelResponse, BatchModelResponse
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelManagerClient:
+     """
+     Synchronous model-manager client for non-async environments (e.g. Flask, Django, scripts).
+     Wraps AsyncModelManagerClient internally and handles event-loop compatibility.
+     """
+     _loop: Optional[asyncio.AbstractEventLoop] = None
+
+     def __init__(
+             self,
+             server_address: Optional[str] = None,
+             jwt_secret_key: Optional[str] = None,
+             jwt_token: Optional[str] = None,
+             default_payload: Optional[dict] = None,
+             token_expires_in: int = 3600,
+             max_retries: int = 3,
+             retry_delay: float = 1.0,
+     ):
+         # Initialize the shared event loop; created only once
+         if not ModelManagerClient._loop:
+             try:
+                 ModelManagerClient._loop = asyncio.get_running_loop()
+             except RuntimeError:
+                 ModelManagerClient._loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(ModelManagerClient._loop)
+
+         self._loop = ModelManagerClient._loop
+
+         self._async_client = AsyncModelManagerClient(
+             server_address=server_address,
+             jwt_secret_key=jwt_secret_key,
+             jwt_token=jwt_token,
+             default_payload=default_payload,
+             token_expires_in=token_expires_in,
+             max_retries=max_retries,
+             retry_delay=retry_delay,
+         )
+         atexit.register(self._safe_sync_close)
+
+     def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None) -> Union[
+             ModelResponse, Iterator[ModelResponse]]:
+         """
+         Synchronously invoke a single model task
+         """
+         if model_request.stream:
+             async def stream():
+                 async for r in await self._async_client.invoke(model_request, timeout=timeout):
+                     yield r
+
+             return self._sync_wrap_async_generator(stream())
+         return self._run_async(self._async_client.invoke(model_request, timeout=timeout))
+
+     def invoke_batch(self, batch_model_request: BatchModelRequest,
+                      timeout: Optional[float] = None) -> BatchModelResponse:
+         """
+         Synchronously invoke a batch of model tasks
+         """
+         return self._run_async(self._async_client.invoke_batch(batch_model_request, timeout=timeout))
+
+     def close(self):
+         """Manually close the gRPC channel"""
+         self._run_async(self._async_client.close())
+
+     def _safe_sync_close(self):
+         """Close automatically at interpreter exit"""
+         try:
+             self._run_async(self._async_client.close())
+             logger.info("✅ gRPC channel closed at exit")
+         except Exception as e:
+             logger.warning(f"❌ gRPC channel close failed at exit: {e}")
+
+     def _run_async(self, coro):
+         """Run a coroutine uniformly, staying compatible with an already-running event loop"""
+         try:
+             loop = asyncio.get_running_loop()
+             # A loop is already running (e.g. in Jupyter); nest_asyncio makes
+             # run_until_complete re-entrant on that loop.
+             import nest_asyncio
+             nest_asyncio.apply()
+             return loop.run_until_complete(coro)
+         except RuntimeError:
+             return self._loop.run_until_complete(coro)
+
+     def _sync_wrap_async_generator(self, async_gen_func):
+         """
+         Wrap an async generator as a synchronous generator, yielding item by item.
+         """
+         loop = self._loop
+
+         # The async generator object passed in
+         agen = async_gen_func
+
+         class SyncGenerator:
+             def __iter__(self_inner):
+                 return self_inner
+
+             def __next__(self_inner):
+                 try:
+                     return loop.run_until_complete(agen.__anext__())
+                 except StopAsyncIteration:
+                     raise StopIteration
+
+         return SyncGenerator()
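
Finally, a usage sketch for this synchronous wrapper. The server address and token are placeholders and the import path is assumed from the package layout; stream=True exercises the SyncGenerator path above, while leaving stream unset returns a single ModelResponse:

    from model_manager_client.sync_client import ModelManagerClient  # import path assumed
    from model_manager_client.enums import ProviderType, InvokeType
    from model_manager_client.schemas import ModelRequest, UserContext

    client = ModelManagerClient(server_address="localhost:50051", jwt_token="<token>")

    request = ModelRequest(
        provider=ProviderType.OPENAI,
        invoke_type=InvokeType.CHAT_COMPLETIONS,
        model="gpt-4o",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
        user_context=UserContext(org_id="org-1", user_id="user-1", client_type="script"),
    )

    # With stream=True, invoke() returns a synchronous iterator that drives
    # the underlying async generator one chunk at a time.
    for chunk in client.invoke(request):
        print(chunk.content, end="", flush=True)

    client.close()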