tamar-model-client 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -131,7 +131,7 @@ class OpenAIImagesInput(BaseModel):
      n: Optional[int] | NotGiven = NOT_GIVEN
      quality: Literal["standard", "hd"] | NotGiven = NOT_GIVEN
      response_format: Optional[Literal["url", "b64_json"]] | NotGiven = NOT_GIVEN
-     size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] | NotGiven = NOT_GIVEN
+     size: Optional[Literal["256x256", "512x512", "1024x1024", "1536x1024", "1024x1536", "auto"]]
      style: Optional[Literal["vivid", "natural"]] | NotGiven = NOT_GIVEN
      user: str | NotGiven = NOT_GIVEN
      extra_headers: Headers | None = None
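Note on the hunk above: 0.1.13 replaces the "1792x1024"/"1024x1792" size options with "1536x1024", "1024x1536", and "auto". A minimal sketch of passing a new value through, assuming OpenAIImagesInput lives in tamar_model_client.schemas.inputs (per the 0.1.13 RECORD below) and carries a prompt field; only the fields shown in the hunk are confirmed by this diff:

    from tamar_model_client.schemas.inputs import OpenAIImagesInput

    # "1536x1024" is only accepted from 0.1.13 on; 0.1.11 would reject it.
    image_input = OpenAIImagesInput(
        prompt="a lighthouse at dusk",  # assumed field, not visible in this hunk
        size="1536x1024",
    )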
@@ -301,7 +301,7 @@ class BatchModelRequestItem(ModelRequestInput):
      def validate_by_provider_and_invoke_type(self) -> "BatchModelRequestItem":
          """Dynamically validate the concrete input-model fields based on provider and invoke_type."""
          # Dynamically collect the allowed fields
-         base_allowed = {"provider", "channel", "invoke_type", "user_context"}
+         base_allowed = {"provider", "channel", "invoke_type", "user_context", "custom_id"}
          google_allowed = base_allowed | set(GoogleGenAiInput.model_fields.keys())
          openai_responses_allowed = base_allowed | set(OpenAIResponsesInput.model_fields.keys())
          openai_chat_allowed = base_allowed | set(OpenAIChatCompletionsInput.model_fields.keys())
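With "custom_id" added to base_allowed, a BatchModelRequestItem that sets custom_id now passes the validator's unsupported-field check instead of being flagged. A minimal sketch, assuming the import path from the 0.1.13 RECORD; the model name is illustrative:

    from tamar_model_client.schemas.inputs import BatchModelRequestItem

    item = BatchModelRequestItem(
        provider="openai",                 # coerced to ProviderType.OPENAI (str enum)
        invoke_type="chat-completions",
        model="gpt-4o",                    # illustrative model name
        messages=[{"role": "user", "content": "hi"}],
        custom_id="order-42",              # correlates this item with its batch result
    )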
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: tamar-model-client
- Version: 0.1.11
+ Version: 0.1.13
  Summary: A Python SDK for interacting with the Model Manager gRPC service
  Home-page: http://gitlab.tamaredge.top/project-tap/AgentOS/model-manager-client
  Author: Oscar Ou
@@ -0,0 +1,19 @@
+ tamar_model_client/__init__.py,sha256=LMECAuDARWHV1XzH3msoDXcyurS2eihRQmBy26_PUE0,328
+ tamar_model_client/async_client.py,sha256=gmZ2xMHO_F-Vtg3OK7B_yf-gtI-WH2NU2LzC6YO_t7k,19649
+ tamar_model_client/auth.py,sha256=gbwW5Aakeb49PMbmYvrYlVx1mfyn1LEDJ4qQVs-9DA4,438
+ tamar_model_client/exceptions.py,sha256=jYU494OU_NeIa4X393V-Y73mTNm0JZ9yZApnlOM9CJQ,332
+ tamar_model_client/sync_client.py,sha256=o8b20fQUvtMq1gWax3_dfOpputYT4l9pRTz6cHdB0lg,4006
+ tamar_model_client/enums/__init__.py,sha256=3cYYn8ztNGBa_pI_5JGRVYf2QX8fkBVWdjID1PLvoBQ,182
+ tamar_model_client/enums/channel.py,sha256=wCzX579nNpTtwzGeS6S3Ls0UzVAgsOlfy4fXMzQTCAw,199
+ tamar_model_client/enums/invoke.py,sha256=WufImoN_87ZjGyzYitZkhNNFefWJehKfLtyP-DTBYlA,267
+ tamar_model_client/enums/providers.py,sha256=L_bX75K6KnWURoFizoitZ1Ybza7bmYDqXecNzNpgIrI,165
+ tamar_model_client/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ tamar_model_client/generated/model_service_pb2.py,sha256=RI6wNSmgmylzWPedFfPxx938UzS7kcPR58YTzYshcL8,3066
+ tamar_model_client/generated/model_service_pb2_grpc.py,sha256=k4tIbp3XBxdyuOVR18Ung_4SUryONB51UYf_uUEl6V4,5145
+ tamar_model_client/schemas/__init__.py,sha256=AxuI-TcvA4OMTj2FtK4wAItvz9LrK_293pu3cmMLE7k,394
+ tamar_model_client/schemas/inputs.py,sha256=yQzidAsRYi4GWEC-4hRaL5Ovo-wZA-ma-74j2LrxGM0,18719
+ tamar_model_client/schemas/outputs.py,sha256=M_fcqUtXPJnfiLabHlyA8BorlC5pYkf5KLjXO1ysKIQ,1031
+ tamar_model_client-0.1.13.dist-info/METADATA,sha256=MXuzkyBGqK2-yE72kq8rSq41Mc_QoYLE8cOxhsTP4_U,16566
+ tamar_model_client-0.1.13.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
+ tamar_model_client-0.1.13.dist-info/top_level.txt,sha256=_LfDhPv_fvON0PoZgQuo4M7EjoWtxPRoQOBJziJmip8,19
+ tamar_model_client-0.1.13.dist-info/RECORD,,
@@ -1,11 +0,0 @@
- from .sync_client import ModelManagerClient
- from .async_client import AsyncModelManagerClient
- from .exceptions import ModelManagerClientError, ConnectionError, ValidationError
-
- __all__ = [
-     "ModelManagerClient",
-     "AsyncModelManagerClient",
-     "ModelManagerClientError",
-     "ConnectionError",
-     "ValidationError",
- ]
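The deletions that follow are the legacy model_manager_client package tree, which shipped alongside tamar_model_client in 0.1.11 (see the old RECORD at the end of this diff; the deleted pb2_grpc and inputs modules below still import from model_manager_client.*). Only tamar_model_client remains in 0.1.13, so imports of the legacy name would need updating, for example:

    # 0.1.11 exposed the same API under both names; 0.1.13 keeps only this one.
    from tamar_model_client import ModelManagerClient, AsyncModelManagerClient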
@@ -1,419 +0,0 @@
- import asyncio
- import atexit
- import json
- import logging
- import os
-
- import grpc
- from typing import Optional, AsyncIterator, Union, Iterable
-
- from openai import NOT_GIVEN
- from pydantic import BaseModel
-
- from .auth import JWTAuthHandler
- from .enums import ProviderType, InvokeType
- from .exceptions import ConnectionError, ValidationError
- from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
- from .generated import model_service_pb2, model_service_pb2_grpc
- from .schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput
-
- if not logging.getLogger().hasHandlers():
-     # Configure the logging format
-     logging.basicConfig(
-         level=logging.INFO,
-         format="%(asctime)s [%(levelname)s] %(message)s",
-     )
-
- logger = logging.getLogger(__name__)
-
-
- class AsyncModelManagerClient:
-     def __init__(
-             self,
-             server_address: Optional[str] = None,
-             jwt_secret_key: Optional[str] = None,
-             jwt_token: Optional[str] = None,
-             default_payload: Optional[dict] = None,
-             token_expires_in: int = 3600,
-             max_retries: int = 3,  # maximum retry attempts
-             retry_delay: float = 1.0,  # initial retry delay (seconds)
-     ):
-         # Server address
-         self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
-         if not self.server_address:
-             raise ValueError("Server address must be provided via argument or environment variable.")
-         self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))
-
-         # JWT configuration
-         self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
-         self.jwt_handler = JWTAuthHandler(self.jwt_secret_key)
-         self.jwt_token = jwt_token  # token supplied by the caller (optional)
-         self.default_payload = default_payload
-         self.token_expires_in = token_expires_in
-
-         # === TLS/Authority configuration ===
-         self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
-         self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")
-
-         # === Retry configuration ===
-         self.max_retries = max_retries if max_retries is not None else int(
-             os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
-         self.retry_delay = retry_delay if retry_delay is not None else float(
-             os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))
-
-         # === gRPC channel state ===
-         self.channel: Optional[grpc.aio.Channel] = None
-         self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
-         self._closed = False
-         atexit.register(self._safe_sync_close)  # auto-close on process exit
-
-     def _build_auth_metadata(self) -> list:
-         if not self.jwt_token and self.jwt_handler:
-             self.jwt_token = self.jwt_handler.encode_token(self.default_payload, expires_in=self.token_expires_in)
-         return [("authorization", f"Bearer {self.jwt_token}")] if self.jwt_token else []
-
-     async def _ensure_initialized(self):
-         """Initialize the gRPC channel, with TLS support and a retry mechanism."""
-         if self.channel and self.stub:
-             return
-
-         retry_count = 0
-         options = []
-         if self.default_authority:
-             options.append(("grpc.default_authority", self.default_authority))
-
-         while retry_count <= self.max_retries:
-             try:
-                 if self.use_tls:
-                     credentials = grpc.ssl_channel_credentials()
-                     self.channel = grpc.aio.secure_channel(
-                         self.server_address,
-                         credentials,
-                         options=options
-                     )
-                     logger.info("🔐 Using secure gRPC channel (TLS enabled)")
-                 else:
-                     self.channel = grpc.aio.insecure_channel(
-                         self.server_address,
-                         options=options
-                     )
-                     logger.info("🔓 Using insecure gRPC channel (TLS disabled)")
-                 await self.channel.channel_ready()
-                 self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
-                 logger.info(f"✅ gRPC channel initialized to {self.server_address}")
-                 return
-             except grpc.FutureTimeoutError as e:
-                 logger.warning(f"❌ gRPC channel initialization timed out: {str(e)}")
-             except grpc.RpcError as e:
-                 logger.warning(f"❌ gRPC channel initialization failed: {str(e)}")
-             except Exception as e:
-                 logger.warning(f"❌ Unexpected error during channel initialization: {str(e)}")
-
-             retry_count += 1
-             if retry_count > self.max_retries:
-                 raise ConnectionError(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.")
-
-             # Exponential backoff: delay = retry_delay * (2 ^ (retry_count - 1))
-             delay = self.retry_delay * (2 ** (retry_count - 1))
-             logger.info(f"🚀 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...")
-             await asyncio.sleep(delay)
-
-     async def _stream(self, model_request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
-         try:
-             async for response in self.stub.Invoke(model_request, metadata=metadata, timeout=invoke_timeout):
-                 yield ModelResponse(
-                     content=response.content,
-                     usage=json.loads(response.usage) if response.usage else None,
-                     raw_response=json.loads(response.raw_response) if response.raw_response else None,
-                     error=response.error or None,
-                 )
-         except grpc.RpcError as e:
-             raise ConnectionError(f"gRPC call failed: {str(e)}")
-         except Exception as e:
-             raise ValidationError(f"Invalid input: {str(e)}")
-
-     async def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None) -> Union[
-         ModelResponse, AsyncIterator[ModelResponse]]:
-         """
-         Generic model invocation method.
-
-         Args:
-             model_request: the ModelRequest object carrying the request parameters.
-
-         Yields:
-             ModelResponse: a streaming or non-streaming model response
-
-         Raises:
-             ValidationError: input validation failed.
-             ConnectionError: failed to connect to the server.
-         """
-         await self._ensure_initialized()
-
-         if not self.default_payload:
-             self.default_payload = {
-                 "org_id": model_request.user_context.org_id or "",
-                 "user_id": model_request.user_context.user_id or ""
-             }
-
-         # Choose which input fields to use based on provider/invoke_type
-         try:
-             if model_request.provider == ProviderType.GOOGLE:
-                 allowed_fields = GoogleGenAiInput.model_fields.keys()
-             elif model_request.provider in {ProviderType.OPENAI, ProviderType.AZURE}:
-                 if model_request.invoke_type in {InvokeType.RESPONSES, InvokeType.GENERATION}:
-                     allowed_fields = OpenAIResponsesInput.model_fields.keys()
-                 elif model_request.invoke_type == InvokeType.CHAT_COMPLETIONS:
-                     allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
-                 else:
-                     raise ValueError(f"Unsupported invoke type: {model_request.invoke_type}")
-             else:
-                 raise ValueError(f"Unsupported provider: {model_request.provider}")
-
-             # Convert the ModelRequest to a dict, keeping only base + allowed fields
-             model_request_dict = model_request.model_dump(exclude_unset=True)
-
-             grpc_request_kwargs = {}
-             for field in allowed_fields:
-                 if field in model_request_dict:
-                     value = model_request_dict[field]
-
-                     # Skip fields with NotGiven or None (unless explicitly allowed)
-                     if value is NOT_GIVEN or value is None:
-                         continue
-
-                     # Special handling: custom BaseModel instances and similar types
-                     if isinstance(value, BaseModel):
-                         grpc_request_kwargs[field] = value.model_dump()
-                     # Custom OpenAI / Google objects usually expose a dict() method
-                     elif hasattr(value, "dict") and callable(value.dict):
-                         grpc_request_kwargs[field] = value.dict()
-                     # For lists, the elements may themselves be custom objects
-                     elif isinstance(value, Iterable) and not isinstance(value, (str, bytes, dict)):
-                         new_list = []
-                         for item in value:
-                             if isinstance(item, BaseModel):
-                                 new_list.append(item.model_dump())
-                             elif hasattr(item, "dict") and callable(item.dict):
-                                 new_list.append(item.dict())
-                             elif isinstance(item, dict):
-                                 # Handle nested dictionaries
-                                 nested_dict = {}
-                                 for k, v in item.items():
-                                     if isinstance(v, BaseModel):
-                                         nested_dict[k] = v.model_dump()
-                                     elif hasattr(v, "dict") and callable(v.dict):
-                                         nested_dict[k] = v.dict()
-                                     else:
-                                         nested_dict[k] = v
-                                 new_list.append(nested_dict)
-                             else:
-                                 new_list.append(item)
-                         grpc_request_kwargs[field] = new_list
-                     # For dicts, handle inner elements the same way
-                     elif isinstance(value, dict):
-                         new_dict = {}
-                         for k, v in value.items():
-                             if isinstance(v, BaseModel):
-                                 new_dict[k] = v.model_dump()
-                             elif hasattr(v, "dict") and callable(v.dict):
-                                 new_dict[k] = v.dict()
-                             else:
-                                 new_dict[k] = v
-                         grpc_request_kwargs[field] = new_dict
-                     else:
-                         grpc_request_kwargs[field] = value
-
-             request = model_service_pb2.ModelRequestItem(
-                 provider=model_request.provider.value,
-                 channel=model_request.channel.value,
-                 invoke_type=model_request.invoke_type.value,
-                 stream=model_request.stream or False,
-                 org_id=model_request.user_context.org_id or "",
-                 user_id=model_request.user_context.user_id or "",
-                 client_type=model_request.user_context.client_type or "",
-                 extra=grpc_request_kwargs
-             )
-
-         except Exception as e:
-             raise ValueError(f"Failed to build request: {str(e)}") from e
-
-         metadata = self._build_auth_metadata()
-
-         invoke_timeout = timeout or self.default_invoke_timeout
-         if model_request.stream:
-             return self._stream(request, metadata, invoke_timeout)
-         else:
-             async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
-                 return ModelResponse(
-                     content=response.content,
-                     usage=json.loads(response.usage) if response.usage else None,
-                     raw_response=json.loads(response.raw_response) if response.raw_response else None,
-                     error=response.error or None,
-                     custom_id=None,
-                     request_id=response.request_id if response.request_id else None,
-                 )
-
-     async def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None) -> \
-             BatchModelResponse:
-         """
-         Batch model invocation interface
-
-         Args:
-             batch_request_model: a BatchModelRequest carrying multiple items
-             timeout: call timeout in seconds
-
-         Returns:
-             BatchModelResponse: the result of the batch request
-         """
-         await self._ensure_initialized()
-
-         if not self.default_payload:
-             self.default_payload = {
-                 "org_id": batch_request_model.user_context.org_id or "",
-                 "user_id": batch_request_model.user_context.user_id or ""
-             }
-
-         metadata = self._build_auth_metadata()
-
-         # Build the batch request
-         items = []
-         for model_request_item in batch_request_model.items:
-             # Choose which input fields to use based on provider/invoke_type
-             try:
-                 if model_request_item.provider == ProviderType.GOOGLE:
-                     allowed_fields = GoogleGenAiInput.model_fields.keys()
-                 elif model_request_item.provider in {ProviderType.OPENAI, ProviderType.AZURE}:
-                     if model_request_item.invoke_type in {InvokeType.RESPONSES, InvokeType.GENERATION}:
-                         allowed_fields = OpenAIResponsesInput.model_fields.keys()
-                     elif model_request_item.invoke_type == InvokeType.CHAT_COMPLETIONS:
-                         allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
-                     else:
-                         raise ValueError(f"Unsupported invoke type: {model_request_item.invoke_type}")
-                 else:
-                     raise ValueError(f"Unsupported provider: {model_request_item.provider}")
-
-                 # Convert the ModelRequest to a dict, keeping only base + allowed fields
-                 model_request_dict = model_request_item.model_dump(exclude_unset=True)
-
-                 grpc_request_kwargs = {}
-                 for field in allowed_fields:
-                     if field in model_request_dict:
-                         value = model_request_dict[field]
-
-                         # Skip fields with NotGiven or None (unless explicitly allowed)
-                         if value is NOT_GIVEN or value is None:
-                             continue
-
-                         # Special handling: custom BaseModel instances and similar types
-                         if isinstance(value, BaseModel):
-                             grpc_request_kwargs[field] = value.model_dump()
-                         # Custom OpenAI / Google objects usually expose a dict() method
-                         elif hasattr(value, "dict") and callable(value.dict):
-                             grpc_request_kwargs[field] = value.dict()
-                         # For lists, the elements may themselves be custom objects
-                         elif isinstance(value, Iterable) and not isinstance(value, (str, bytes, dict)):
-                             new_list = []
-                             for item in value:
-                                 if isinstance(item, BaseModel):
-                                     new_list.append(item.model_dump())
-                                 elif hasattr(item, "dict") and callable(item.dict):
-                                     new_list.append(item.dict())
-                                 elif isinstance(item, dict):
-                                     # Handle nested dictionaries
-                                     nested_dict = {}
-                                     for k, v in item.items():
-                                         if isinstance(v, BaseModel):
-                                             nested_dict[k] = v.model_dump()
-                                         elif hasattr(v, "dict") and callable(v.dict):
-                                             nested_dict[k] = v.dict()
-                                         else:
-                                             nested_dict[k] = v
-                                     new_list.append(nested_dict)
-                                 else:
-                                     new_list.append(item)
-                             grpc_request_kwargs[field] = new_list
-                         # For dicts, handle inner elements the same way
-                         elif isinstance(value, dict):
-                             new_dict = {}
-                             for k, v in value.items():
-                                 if isinstance(v, BaseModel):
-                                     new_dict[k] = v.model_dump()
-                                 elif hasattr(v, "dict") and callable(v.dict):
-                                     new_dict[k] = v.dict()
-                                 else:
-                                     new_dict[k] = v
-                             grpc_request_kwargs[field] = new_dict
-                         else:
-                             grpc_request_kwargs[field] = value
-
-                 items.append(model_service_pb2.ModelRequestItem(
-                     provider=model_request_item.provider.value,
-                     channel=model_request_item.channel.value,
-                     invoke_type=model_request_item.invoke_type.value,
-                     stream=model_request_item.stream or False,
-                     custom_id=model_request_item.custom_id or "",
-                     priority=model_request_item.priority or 1,
-                     org_id=batch_request_model.user_context.org_id or "",
-                     user_id=batch_request_model.user_context.user_id or "",
-                     client_type=batch_request_model.user_context.client_type or "",
-                     extra=grpc_request_kwargs,
-                 ))
-
-             except Exception as e:
-                 raise ValueError(f"Failed to build request: {str(e)}, item={model_request_item.custom_id}") from e
-
-         try:
-             # Timeout handling
-             invoke_timeout = timeout or self.default_invoke_timeout
-
-             # Call the gRPC interface
-             response = await self.stub.BatchInvoke(
-                 model_service_pb2.ModelRequest(items=items),
-                 timeout=invoke_timeout,
-                 metadata=metadata
-             )
-
-             result = []
-             for res_item in response.items:
-                 result.append(ModelResponse(
-                     content=res_item.content,
-                     usage=json.loads(res_item.usage) if res_item.usage else None,
-                     raw_response=json.loads(res_item.raw_response) if res_item.raw_response else None,
-                     error=res_item.error or None,
-                     custom_id=res_item.custom_id if res_item.custom_id else None
-                 ))
-             return BatchModelResponse(
-                 request_id=response.request_id if response.request_id else None,
-                 responses=result
-             )
-         except grpc.RpcError as e:
-             raise ConnectionError(f"BatchInvoke failed: {str(e)}")
-
-     async def close(self):
-         """Close the gRPC channel."""
-         if self.channel and not self._closed:
-             await self.channel.close()
-             self._closed = True
-             await self.channel.close()
-             logger.info("✅ gRPC channel closed")
-
-     def _safe_sync_close(self):
-         """Auto-close the channel on process exit (event-loop compatible)."""
-         if self.channel and not self._closed:
-             try:
-                 loop = asyncio.get_event_loop()
-                 if loop.is_running():
-                     loop.create_task(self.close())
-                 else:
-                     loop.run_until_complete(self.close())
-             except Exception as e:
-                 logger.warning(f"❌ gRPC channel close failed at exit: {e}")
-
-     async def __aenter__(self):
-         """Support async with: initialize the connection automatically."""
-         await self._ensure_initialized()
-         return self
-
-     async def __aexit__(self, exc_type, exc_val, exc_tb):
-         """Support async with: close the connection automatically."""
-         await self.close()
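For orientation, a minimal usage sketch against the client API shown above (the surviving tamar_model_client.async_client is assumed to expose the same surface; the address, secret, and model name are illustrative):

    import asyncio

    from tamar_model_client import AsyncModelManagerClient
    from tamar_model_client.schemas import ModelRequest, UserContext

    async def main():
        # Settings may also come from the MODEL_MANAGER_SERVER_* env vars read in __init__.
        async with AsyncModelManagerClient(
            server_address="localhost:50051",
            jwt_secret_key="dev-secret",
        ) as client:
            request = ModelRequest(
                provider="openai",
                invoke_type="chat-completions",
                model="gpt-4o",  # illustrative model name
                messages=[{"role": "user", "content": "ping"}],
                user_context=UserContext(org_id="org-1", user_id="u-1", client_type="demo"),
            )
            response = await client.invoke(request)  # non-streaming: a single ModelResponse
            print(response.content)

    asyncio.run(main())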
@@ -1,14 +0,0 @@
- import time
- import jwt
-
-
- # JWT handler class
- class JWTAuthHandler:
-     def __init__(self, secret_key: str):
-         self.secret_key = secret_key
-
-     def encode_token(self, payload: dict, expires_in: int = 3600) -> str:
-         """Generate a JWT token with an expiry time."""
-         payload = payload.copy()
-         payload["exp"] = int(time.time()) + expires_in
-         return jwt.encode(payload, self.secret_key, algorithm="HS256")
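The handler above is thin enough to verify end to end with PyJWT directly; per the RECORDs, tamar_model_client/auth.py in 0.1.13 has the same sha256 as this deleted copy, so the behaviour should match (the secret is illustrative):

    import jwt  # PyJWT

    from tamar_model_client.auth import JWTAuthHandler

    handler = JWTAuthHandler("dev-secret")
    token = handler.encode_token({"org_id": "org-1"}, expires_in=60)

    # encode_token copies the payload and adds an "exp" claim before signing with HS256
    claims = jwt.decode(token, "dev-secret", algorithms=["HS256"])
    assert claims["org_id"] == "org-1" and "exp" in claims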
@@ -1,8 +0,0 @@
- """
- Enum type definitions
- """
- from .channel import Channel
- from .invoke import InvokeType
- from .providers import ProviderType
-
- __all__ = ["ProviderType", "InvokeType", "Channel"]
@@ -1,11 +0,0 @@
- from enum import Enum
-
-
- class Channel(str, Enum):
-     """Channel enum"""
-     OPENAI = "openai"
-     VERTEXAI = "vertexai"
-     AI_STUDIO = "ai-studio"
-
-     # Default
-     NORMAL = "normal"
@@ -1,10 +0,0 @@
- from enum import Enum
-
-
- class InvokeType(str, Enum):
-     """Model invocation type enum"""
-     RESPONSES = "responses"
-     CHAT_COMPLETIONS = "chat-completions"
-
-     # Default
-     GENERATION = "generation"
@@ -1,8 +0,0 @@
- from enum import Enum
-
-
- class ProviderType(str, Enum):
-     """Model provider type enum"""
-     OPENAI = "openai"
-     GOOGLE = "google"
-     AZURE = "azure"
@@ -1,11 +0,0 @@
- class ModelManagerClientError(Exception):
-     """Base exception for Model Manager Client errors"""
-     pass
-
- class ConnectionError(ModelManagerClientError):
-     """Raised when connection to gRPC server fails"""
-     pass
-
- class ValidationError(ModelManagerClientError):
-     """Raised when input validation fails"""
-     pass
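These three classes survive unchanged in tamar_model_client/exceptions.py (identical sha256 in both RECORDs), giving callers a narrow-to-broad handling ladder; a sketch:

    from tamar_model_client.exceptions import (
        ModelManagerClientError,
        ConnectionError,   # note: shadows the built-in ConnectionError when imported like this
        ValidationError,
    )

    try:
        ...  # e.g. client.invoke(request)
    except ConnectionError:
        ...  # transport failure: retry or surface
    except ValidationError:
        ...  # bad input: fix the request payload
    except ModelManagerClientError:
        ...  # catch-all for any other client error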
File without changes
@@ -1,45 +0,0 @@
- # -*- coding: utf-8 -*-
- # Generated by the protocol buffer compiler. DO NOT EDIT!
- # NO CHECKED-IN PROTOBUF GENCODE
- # source: model_service.proto
- # Protobuf Python Version: 5.29.0
- """Generated protocol buffer code."""
- from google.protobuf import descriptor as _descriptor
- from google.protobuf import descriptor_pool as _descriptor_pool
- from google.protobuf import runtime_version as _runtime_version
- from google.protobuf import symbol_database as _symbol_database
- from google.protobuf.internal import builder as _builder
- _runtime_version.ValidateProtobufRuntimeVersion(
-     _runtime_version.Domain.PUBLIC,
-     5,
-     29,
-     0,
-     '',
-     'model_service.proto'
- )
- # @@protoc_insertion_point(imports)
-
- _sym_db = _symbol_database.Default()
-
-
- from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
-
-
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13model_service.proto\x12\rmodel_service\x1a\x1cgoogle/protobuf/struct.proto\"\x82\x02\n\x10ModelRequestItem\x12\x10\n\x08provider\x18\x01 \x01(\t\x12\x0f\n\x07\x63hannel\x18\x02 \x01(\t\x12\x13\n\x0binvoke_type\x18\x03 \x01(\t\x12\x0e\n\x06stream\x18\x04 \x01(\x08\x12\x0e\n\x06org_id\x18\x05 \x01(\t\x12\x0f\n\x07user_id\x18\x06 \x01(\t\x12\x13\n\x0b\x63lient_type\x18\x07 \x01(\t\x12\x15\n\x08priority\x18\x08 \x01(\x05H\x00\x88\x01\x01\x12\x16\n\tcustom_id\x18\t \x01(\tH\x01\x88\x01\x01\x12&\n\x05\x65xtra\x18\n \x01(\x0b\x32\x17.google.protobuf.StructB\x0b\n\t_priorityB\x0c\n\n_custom_id\">\n\x0cModelRequest\x12.\n\x05items\x18\x01 \x03(\x0b\x32\x1f.model_service.ModelRequestItem\"\xa6\x01\n\x11ModelResponseItem\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\t\x12\r\n\x05usage\x18\x02 \x01(\t\x12\x14\n\x0craw_response\x18\x03 \x01(\t\x12\r\n\x05\x65rror\x18\x04 \x01(\t\x12\x16\n\tcustom_id\x18\x05 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_custom_idB\r\n\x0b_request_id\"T\n\rModelResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12/\n\x05items\x18\x02 \x03(\x0b\x32 .model_service.ModelResponseItem2\xa7\x01\n\x0cModelService\x12M\n\x06Invoke\x12\x1f.model_service.ModelRequestItem\x1a .model_service.ModelResponseItem0\x01\x12H\n\x0b\x42\x61tchInvoke\x12\x1b.model_service.ModelRequest\x1a\x1c.model_service.ModelResponseb\x06proto3')
-
- _globals = globals()
- _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
- _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'model_service_pb2', _globals)
- if not _descriptor._USE_C_DESCRIPTORS:
-     DESCRIPTOR._loaded_options = None
-     _globals['_MODELREQUESTITEM']._serialized_start=69
-     _globals['_MODELREQUESTITEM']._serialized_end=327
-     _globals['_MODELREQUEST']._serialized_start=329
-     _globals['_MODELREQUEST']._serialized_end=391
-     _globals['_MODELRESPONSEITEM']._serialized_start=394
-     _globals['_MODELRESPONSEITEM']._serialized_end=560
-     _globals['_MODELRESPONSE']._serialized_start=562
-     _globals['_MODELRESPONSE']._serialized_end=646
-     _globals['_MODELSERVICE']._serialized_start=649
-     _globals['_MODELSERVICE']._serialized_end=816
- # @@protoc_insertion_point(module_scope)
@@ -1,145 +0,0 @@
- # Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
- """Client and server classes corresponding to protobuf-defined services."""
- import grpc
- import warnings
-
- import model_manager_client.generated.model_service_pb2 as model__service__pb2
-
- GRPC_GENERATED_VERSION = '1.71.0'
- GRPC_VERSION = grpc.__version__
- _version_not_supported = False
-
- try:
-     from grpc._utilities import first_version_is_lower
-     _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
- except ImportError:
-     _version_not_supported = True
-
- if _version_not_supported:
-     raise RuntimeError(
-         f'The grpc package installed is at version {GRPC_VERSION},'
-         + f' but the generated code in model_service_pb2_grpc.py depends on'
-         + f' grpcio>={GRPC_GENERATED_VERSION}.'
-         + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
-         + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
-     )
-
-
- class ModelServiceStub(object):
-     """gRPC service (interface) definition
-     """
-
-     def __init__(self, channel):
-         """Constructor.
-
-         Args:
-             channel: A grpc.Channel.
-         """
-         self.Invoke = channel.unary_stream(
-                 '/model_service.ModelService/Invoke',
-                 request_serializer=model__service__pb2.ModelRequestItem.SerializeToString,
-                 response_deserializer=model__service__pb2.ModelResponseItem.FromString,
-                 _registered_method=True)
-         self.BatchInvoke = channel.unary_unary(
-                 '/model_service.ModelService/BatchInvoke',
-                 request_serializer=model__service__pb2.ModelRequest.SerializeToString,
-                 response_deserializer=model__service__pb2.ModelResponse.FromString,
-                 _registered_method=True)
-
-
- class ModelServiceServicer(object):
-     """gRPC service (interface) definition
-     """
-
-     def Invoke(self, request, context):
-         """Single request + streaming response
-         """
-         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-         context.set_details('Method not implemented!')
-         raise NotImplementedError('Method not implemented!')
-
-     def BatchInvoke(self, request, context):
-         """Batch invocation interface; streaming not supported
-         """
-         context.set_code(grpc.StatusCode.UNIMPLEMENTED)
-         context.set_details('Method not implemented!')
-         raise NotImplementedError('Method not implemented!')
-
-
- def add_ModelServiceServicer_to_server(servicer, server):
-     rpc_method_handlers = {
-             'Invoke': grpc.unary_stream_rpc_method_handler(
-                     servicer.Invoke,
-                     request_deserializer=model__service__pb2.ModelRequestItem.FromString,
-                     response_serializer=model__service__pb2.ModelResponseItem.SerializeToString,
-             ),
-             'BatchInvoke': grpc.unary_unary_rpc_method_handler(
-                     servicer.BatchInvoke,
-                     request_deserializer=model__service__pb2.ModelRequest.FromString,
-                     response_serializer=model__service__pb2.ModelResponse.SerializeToString,
-             ),
-     }
-     generic_handler = grpc.method_handlers_generic_handler(
-             'model_service.ModelService', rpc_method_handlers)
-     server.add_generic_rpc_handlers((generic_handler,))
-     server.add_registered_method_handlers('model_service.ModelService', rpc_method_handlers)
-
-
- # This class is part of an EXPERIMENTAL API.
- class ModelService(object):
-     """gRPC service (interface) definition
-     """
-
-     @staticmethod
-     def Invoke(request,
-             target,
-             options=(),
-             channel_credentials=None,
-             call_credentials=None,
-             insecure=False,
-             compression=None,
-             wait_for_ready=None,
-             timeout=None,
-             metadata=None):
-         return grpc.experimental.unary_stream(
-             request,
-             target,
-             '/model_service.ModelService/Invoke',
-             model__service__pb2.ModelRequestItem.SerializeToString,
-             model__service__pb2.ModelResponseItem.FromString,
-             options,
-             channel_credentials,
-             insecure,
-             call_credentials,
-             compression,
-             wait_for_ready,
-             timeout,
-             metadata,
-             _registered_method=True)
-
-     @staticmethod
-     def BatchInvoke(request,
-             target,
-             options=(),
-             channel_credentials=None,
-             call_credentials=None,
-             insecure=False,
-             compression=None,
-             wait_for_ready=None,
-             timeout=None,
-             metadata=None):
-         return grpc.experimental.unary_unary(
-             request,
-             target,
-             '/model_service.ModelService/BatchInvoke',
-             model__service__pb2.ModelRequest.SerializeToString,
-             model__service__pb2.ModelResponse.FromString,
-             options,
-             channel_credentials,
-             insecure,
-             call_credentials,
-             compression,
-             wait_for_ready,
-             timeout,
-             metadata,
-             _registered_method=True)
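The clients in this package wrap exactly this generated stub; a bare-metal sketch against a hypothetical endpoint, using the message and service names from the descriptor above:

    import grpc

    from tamar_model_client.generated import model_service_pb2, model_service_pb2_grpc

    channel = grpc.insecure_channel("localhost:50051")  # hypothetical endpoint
    stub = model_service_pb2_grpc.ModelServiceStub(channel)

    # Invoke is unary-stream: one ModelRequestItem in, a stream of ModelResponseItem out.
    request = model_service_pb2.ModelRequestItem(provider="openai", invoke_type="generation")
    for item in stub.Invoke(request):
        print(item.content)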
@@ -1,17 +0,0 @@
- """
- Schema definitions for the API
- """
-
- from .inputs import UserContext, ModelRequest, BatchModelRequestItem, BatchModelRequest
- from .outputs import ModelResponse, BatchModelResponse
-
- __all__ = [
-     # Model Inputs
-     "UserContext",
-     "ModelRequest",
-     "BatchModelRequestItem",
-     "BatchModelRequest",
-     # Model Outputs
-     "ModelResponse",
-     "BatchModelResponse",
- ]
@@ -1,294 +0,0 @@
- import httpx
- from google.genai import types
- from openai import NotGiven, NOT_GIVEN
- from openai._types import Headers, Query, Body
- from openai.types import ChatModel, Metadata, ReasoningEffort, ResponsesModel, Reasoning
- from openai.types.chat import ChatCompletionMessageParam, ChatCompletionAudioParam, completion_create_params, \
-     ChatCompletionPredictionContentParam, ChatCompletionStreamOptionsParam, ChatCompletionToolChoiceOptionParam, \
-     ChatCompletionToolParam
- from openai.types.responses import ResponseInputParam, ResponseIncludable, ResponseTextConfigParam, \
-     response_create_params, ToolParam
- from pydantic import BaseModel, model_validator
- from typing import List, Optional, Union, Iterable, Dict, Literal
-
- from model_manager_client.enums import ProviderType, InvokeType
- from model_manager_client.enums.channel import Channel
-
-
- class UserContext(BaseModel):
-     org_id: str  # organization id
-     user_id: str  # user id
-     client_type: str  # client type: records which service the request came from
-
-
- class GoogleGenAiInput(BaseModel):
-     model: str
-     contents: Union[types.ContentListUnion, types.ContentListUnionDict]
-     config: Optional[types.GenerateContentConfigOrDict] = None
-
-     model_config = {
-         "arbitrary_types_allowed": True
-     }
-
-
- class OpenAIResponsesInput(BaseModel):
-     input: Union[str, ResponseInputParam]
-     model: ResponsesModel
-     include: Optional[List[ResponseIncludable]] | NotGiven = NOT_GIVEN
-     instructions: Optional[str] | NotGiven = NOT_GIVEN
-     max_output_tokens: Optional[int] | NotGiven = NOT_GIVEN
-     metadata: Optional[Metadata] | NotGiven = NOT_GIVEN
-     parallel_tool_calls: Optional[bool] | NotGiven = NOT_GIVEN
-     previous_response_id: Optional[str] | NotGiven = NOT_GIVEN
-     reasoning: Optional[Reasoning] | NotGiven = NOT_GIVEN
-     store: Optional[bool] | NotGiven = NOT_GIVEN
-     stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN
-     temperature: Optional[float] | NotGiven = NOT_GIVEN
-     text: ResponseTextConfigParam | NotGiven = NOT_GIVEN
-     tool_choice: response_create_params.ToolChoice | NotGiven = NOT_GIVEN
-     tools: Iterable[ToolParam] | NotGiven = NOT_GIVEN
-     top_p: Optional[float] | NotGiven = NOT_GIVEN
-     truncation: Optional[Literal["auto", "disabled"]] | NotGiven = NOT_GIVEN
-     user: str | NotGiven = NOT_GIVEN
-     extra_headers: Headers | None = None
-     extra_query: Query | None = None
-     extra_body: Body | None = None
-     timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN
-
-     model_config = {
-         "arbitrary_types_allowed": True
-     }
-
-
- class OpenAIChatCompletionsInput(BaseModel):
-     messages: Iterable[ChatCompletionMessageParam]
-     model: Union[str, ChatModel]
-     audio: Optional[ChatCompletionAudioParam] | NotGiven = NOT_GIVEN
-     frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN
-     function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN
-     functions: Iterable[completion_create_params.Function] | NotGiven = NOT_GIVEN
-     logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN
-     logprobs: Optional[bool] | NotGiven = NOT_GIVEN
-     max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN
-     max_tokens: Optional[int] | NotGiven = NOT_GIVEN
-     metadata: Optional[Metadata] | NotGiven = NOT_GIVEN
-     modalities: Optional[List[Literal["text", "audio"]]] | NotGiven = NOT_GIVEN
-     n: Optional[int] | NotGiven = NOT_GIVEN
-     parallel_tool_calls: bool | NotGiven = NOT_GIVEN
-     prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN
-     presence_penalty: Optional[float] | NotGiven = NOT_GIVEN
-     reasoning_effort: Optional[ReasoningEffort] | NotGiven = NOT_GIVEN
-     response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN
-     seed: Optional[int] | NotGiven = NOT_GIVEN
-     service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN
-     stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN
-     store: Optional[bool] | NotGiven = NOT_GIVEN
-     stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN
-     stream_options: Optional[ChatCompletionStreamOptionsParam] | NotGiven = NOT_GIVEN
-     temperature: Optional[float] | NotGiven = NOT_GIVEN
-     tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN
-     tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN
-     top_logprobs: Optional[int] | NotGiven = NOT_GIVEN
-     top_p: Optional[float] | NotGiven = NOT_GIVEN
-     user: str | NotGiven = NOT_GIVEN
-     web_search_options: completion_create_params.WebSearchOptions | NotGiven = NOT_GIVEN
-     extra_headers: Headers | None = None
-     extra_query: Query | None = None
-     extra_body: Body | None = None
-     timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN
-
-     model_config = {
-         "arbitrary_types_allowed": True
-     }
-
-
- class BaseRequest(BaseModel):
-     provider: ProviderType  # provider, e.g. "openai", "google"
-     channel: Channel = Channel.NORMAL  # channel: vendors ship different SDKs; this selects which SDK to call
-     invoke_type: InvokeType = InvokeType.TEXT_GENERATION  # invocation type: generation = generative model call
-
-
- class ModelRequestInput(BaseRequest):
-     # Unified model field
-     model: Optional[Union[str, ResponsesModel, ChatModel]] = None
-
-     # OpenAI Responses Input
-     input: Optional[Union[str, ResponseInputParam]] = None
-     include: Optional[Union[List[ResponseIncludable], NotGiven]] = NOT_GIVEN
-     instructions: Optional[Union[str, NotGiven]] = NOT_GIVEN
-     max_output_tokens: Optional[Union[int, NotGiven]] = NOT_GIVEN
-     metadata: Optional[Union[Metadata, NotGiven]] = NOT_GIVEN
-     parallel_tool_calls: Optional[Union[bool, NotGiven]] = NOT_GIVEN
-     previous_response_id: Optional[Union[str, NotGiven]] = NOT_GIVEN
-     reasoning: Optional[Union[Reasoning, NotGiven]] = NOT_GIVEN
-     store: Optional[Union[bool, NotGiven]] = NOT_GIVEN
-     stream: Optional[Union[Literal[False], Literal[True], NotGiven]] = NOT_GIVEN
-     temperature: Optional[Union[float, NotGiven]] = NOT_GIVEN
-     text: Optional[Union[ResponseTextConfigParam, NotGiven]] = NOT_GIVEN
-     tool_choice: Optional[
-         Union[response_create_params.ToolChoice, ChatCompletionToolChoiceOptionParam, NotGiven]] = NOT_GIVEN
-     tools: Optional[Union[Iterable[ToolParam], Iterable[ChatCompletionToolParam], NotGiven]] = NOT_GIVEN
-     top_p: Optional[Union[float, NotGiven]] = NOT_GIVEN
-     truncation: Optional[Union[Literal["auto", "disabled"], NotGiven]] = NOT_GIVEN
-     user: Optional[Union[str, NotGiven]] = NOT_GIVEN
-
-     extra_headers: Optional[Union[Headers, None]] = None
-     extra_query: Optional[Union[Query, None]] = None
-     extra_body: Optional[Union[Body, None]] = None
-     timeout: Optional[Union[float, httpx.Timeout, None, NotGiven]] = NOT_GIVEN
-
-     # OpenAI Chat Completions Input
-     messages: Optional[Iterable[ChatCompletionMessageParam]] = None
-     audio: Optional[Union[ChatCompletionAudioParam, NotGiven]] = NOT_GIVEN
-     frequency_penalty: Optional[Union[float, NotGiven]] = NOT_GIVEN
-     function_call: Optional[Union[completion_create_params.FunctionCall, NotGiven]] = NOT_GIVEN
-     functions: Optional[Union[Iterable[completion_create_params.Function], NotGiven]] = NOT_GIVEN
-     logit_bias: Optional[Union[Dict[str, int], NotGiven]] = NOT_GIVEN
-     logprobs: Optional[Union[bool, NotGiven]] = NOT_GIVEN
-     max_completion_tokens: Optional[Union[int, NotGiven]] = NOT_GIVEN
-     modalities: Optional[Union[List[Literal["text", "audio"]], NotGiven]] = NOT_GIVEN
-     n: Optional[Union[int, NotGiven]] = NOT_GIVEN
-     prediction: Optional[Union[ChatCompletionPredictionContentParam, NotGiven]] = NOT_GIVEN
-     presence_penalty: Optional[Union[float, NotGiven]] = NOT_GIVEN
-     reasoning_effort: Optional[Union[ReasoningEffort, NotGiven]] = NOT_GIVEN
-     response_format: Optional[Union[completion_create_params.ResponseFormat, NotGiven]] = NOT_GIVEN
-     seed: Optional[Union[int, NotGiven]] = NOT_GIVEN
-     service_tier: Optional[Union[Literal["auto", "default"], NotGiven]] = NOT_GIVEN
-     stop: Optional[Union[Optional[str], List[str], None, NotGiven]] = NOT_GIVEN
-     top_logprobs: Optional[Union[int, NotGiven]] = NOT_GIVEN
-     web_search_options: Optional[Union[completion_create_params.WebSearchOptions, NotGiven]] = NOT_GIVEN
-     stream_options: Optional[Union[ChatCompletionStreamOptionsParam, NotGiven]] = NOT_GIVEN
-
-     # Google GenAI Input
-     contents: Optional[Union[types.ContentListUnion, types.ContentListUnionDict]] = None
-     config: Optional[types.GenerateContentConfigOrDict] = None
-
-     model_config = {
-         "arbitrary_types_allowed": True
-     }
-
-
- class ModelRequest(ModelRequestInput):
-     user_context: UserContext  # user info
-
-     @model_validator(mode="after")
-     def validate_by_provider_and_invoke_type(self) -> "ModelRequest":
-         """Dynamically validate the concrete input-model fields based on provider and invoke_type."""
-         # Dynamically collect the allowed fields
-         base_allowed = ["provider", "channel", "invoke_type", "user_context"]
-         google_allowed = set(base_allowed) | set(GoogleGenAiInput.model_fields.keys())
-         openai_responses_allowed = set(base_allowed) | set(OpenAIResponsesInput.model_fields.keys())
-         openai_chat_allowed = set(base_allowed) | set(OpenAIChatCompletionsInput.model_fields.keys())
-
-         # Required fields of the original input models
-         google_required_fields = {"model", "contents"}
-         openai_responses_required_fields = {"input", "model"}
-         openai_chat_required_fields = {"messages", "model"}
-
-         # Choose the field sets to validate
-         if self.provider == ProviderType.GOOGLE:
-             expected_fields = google_required_fields
-             allowed_fields = google_allowed
-         elif self.provider == ProviderType.OPENAI or self.provider == ProviderType.AZURE:
-             if self.invoke_type == InvokeType.RESPONSES or self.invoke_type == InvokeType.TEXT_GENERATION:
-                 expected_fields = openai_responses_required_fields
-                 allowed_fields = openai_responses_allowed
-             elif self.invoke_type == InvokeType.CHAT_COMPLETIONS:
-                 expected_fields = openai_chat_required_fields
-                 allowed_fields = openai_chat_allowed
-             else:
-                 raise ValueError(f"Unsupported invoke type: {self.invoke_type}")
-         else:
-             raise ValueError(f"Unsupported provider: {self.provider}")
-
-         # Check for missing required fields
-         missing = []
-         for field in expected_fields:
-             if getattr(self, field, None) is None:
-                 missing.append(field)
-
-         if missing:
-             raise ValueError(
-                 f"{self.provider}({self.invoke_type}) request is missing required fields: {missing}"
-             )
-
-         # Check for unsupported fields
-         illegal_fields = []
-         for name, value in self.__dict__.items():
-             if name in {"provider", "channel", "invoke_type", "stream"}:
-                 continue
-             if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
-                 illegal_fields.append(name)
-
-         if illegal_fields:
-             raise ValueError(
-                 f"{self.provider}({self.invoke_type}) contains unsupported fields: {illegal_fields}"
-             )
-
-         return self
-
-
- class BatchModelRequestItem(ModelRequestInput):
-     custom_id: Optional[str] = None
-     priority: Optional[int] = None  # (optional, reserved) execution priority for batch calls
-
-     @model_validator(mode="after")
-     def validate_by_provider_and_invoke_type(self) -> "BatchModelRequestItem":
-         """Dynamically validate the concrete input-model fields based on provider and invoke_type."""
-         # Dynamically collect the allowed fields
-         base_allowed = ["provider", "channel", "invoke_type", "custom_id", "priority"]
-         google_allowed = set(base_allowed) | set(GoogleGenAiInput.model_fields.keys())
-         openai_responses_allowed = set(base_allowed) | set(OpenAIResponsesInput.model_fields.keys())
-         openai_chat_allowed = set(base_allowed) | set(OpenAIChatCompletionsInput.model_fields.keys())
-
-         # Required fields of the original input models
-         google_required_fields = {"model", "contents"}
-         openai_responses_required_fields = {"input", "model"}
-         openai_chat_required_fields = {"messages", "model"}
-
-         # Choose the field sets to validate
-         if self.provider == ProviderType.GOOGLE:
-             expected_fields = google_required_fields
-             allowed_fields = google_allowed
-         elif self.provider == ProviderType.OPENAI or self.provider == ProviderType.AZURE:
-             if self.invoke_type == InvokeType.RESPONSES or self.invoke_type == InvokeType.TEXT_GENERATION:
-                 expected_fields = openai_responses_required_fields
-                 allowed_fields = openai_responses_allowed
-             elif self.invoke_type == InvokeType.CHAT_COMPLETIONS:
-                 expected_fields = openai_chat_required_fields
-                 allowed_fields = openai_chat_allowed
-             else:
-                 raise ValueError(f"Unsupported invoke type: {self.invoke_type}")
-         else:
-             raise ValueError(f"Unsupported provider: {self.provider}")
-
-         # Check for missing required fields
-         missing = []
-         for field in expected_fields:
-             if getattr(self, field, None) is None:
-                 missing.append(field)
-
-         if missing:
-             raise ValueError(
-                 f"{self.provider}({self.invoke_type}) request is missing required fields: {missing}"
-             )
-
-         # Check for unsupported fields
-         illegal_fields = []
-         for name, value in self.__dict__.items():
-             if name in {"provider", "channel", "invoke_type", "stream"}:
-                 continue
-             if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
-                 illegal_fields.append(name)
-
-         if illegal_fields:
-             raise ValueError(
-                 f"{self.provider}({self.invoke_type}) contains unsupported fields: {illegal_fields}"
-             )
-
-         return self
-
-
- class BatchModelRequest(BaseModel):
-     user_context: UserContext  # user info
-     items: List[BatchModelRequestItem]  # list of batch request items
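The validators above raise ValueError for missing or unsupported fields, which pydantic v2 surfaces as a ValidationError; a sketch of the failure mode against the surviving tamar_model_client schemas (assumed to have the same shape; the model name is illustrative):

    from pydantic import ValidationError as PydanticValidationError

    from tamar_model_client.schemas import ModelRequest, UserContext

    ctx = UserContext(org_id="org-1", user_id="u-1", client_type="demo")

    try:
        # No "messages" for an OpenAI chat-completions request: reported as a missing required field.
        ModelRequest(provider="openai", invoke_type="chat-completions",
                     model="gpt-4o", user_context=ctx)
    except PydanticValidationError as e:
        print(e)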
@@ -1,24 +0,0 @@
- from typing import Any, Iterator, Optional, Union, Dict, List
-
- from pydantic import BaseModel, ConfigDict
-
-
- class BaseResponse(BaseModel):
-     model_config = ConfigDict(arbitrary_types_allowed=True)
-
-     content: Optional[str] = None  # text output content
-     usage: Optional[Dict] = None  # tokens / request cost, etc. (JSON)
-     stream_response: Optional[Union[Iterator[str], Any]] = None  # for streaming responses (sync or async)
-     raw_response: Optional[Union[Dict, List]] = None  # raw structure returned by the model vendor (JSON)
-     error: Optional[Any] = None  # error message
-     custom_id: Optional[str] = None  # custom ID for correlating batch results
-
-
- class ModelResponse(BaseResponse):
-     model_config = ConfigDict(arbitrary_types_allowed=True)
-     request_id: Optional[str] = None  # request ID for tracing
-
-
- class BatchModelResponse(BaseModel):
-     request_id: Optional[str] = None  # request ID for tracing
-     responses: Optional[List[BaseResponse]] = None  # list of responses for the batch request
@@ -1,111 +0,0 @@
- import asyncio
- import atexit
- import logging
- from typing import Optional, Union, Iterator
-
- from .async_client import AsyncModelManagerClient
- from .schemas import ModelRequest, BatchModelRequest, ModelResponse, BatchModelResponse
-
- logger = logging.getLogger(__name__)
-
-
- class ModelManagerClient:
-     """
-     Synchronous model manager client for non-async environments (e.g. Flask, Django, scripts).
-     Wraps AsyncModelManagerClient internally and handles event-loop compatibility.
-     """
-     _loop: Optional[asyncio.AbstractEventLoop] = None
-
-     def __init__(
-             self,
-             server_address: Optional[str] = None,
-             jwt_secret_key: Optional[str] = None,
-             jwt_token: Optional[str] = None,
-             default_payload: Optional[dict] = None,
-             token_expires_in: int = 3600,
-             max_retries: int = 3,
-             retry_delay: float = 1.0,
-     ):
-         # Initialize the global event loop, created only once
-         if not ModelManagerClient._loop:
-             try:
-                 ModelManagerClient._loop = asyncio.get_running_loop()
-             except RuntimeError:
-                 ModelManagerClient._loop = asyncio.new_event_loop()
-                 asyncio.set_event_loop(ModelManagerClient._loop)
-
-         self._loop = ModelManagerClient._loop
-
-         self._async_client = AsyncModelManagerClient(
-             server_address=server_address,
-             jwt_secret_key=jwt_secret_key,
-             jwt_token=jwt_token,
-             default_payload=default_payload,
-             token_expires_in=token_expires_in,
-             max_retries=max_retries,
-             retry_delay=retry_delay,
-         )
-         atexit.register(self._safe_sync_close)
-
-     def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None) -> Union[
-         ModelResponse, Iterator[ModelResponse]]:
-         """
-         Synchronously invoke a single model task
-         """
-         if model_request.stream:
-             async def stream():
-                 async for r in await self._async_client.invoke(model_request, timeout=timeout):
-                     yield r
-
-             return self._sync_wrap_async_generator(stream())
-         return self._run_async(self._async_client.invoke(model_request, timeout=timeout))
-
-     def invoke_batch(self, batch_model_request: BatchModelRequest,
-                      timeout: Optional[float] = None) -> BatchModelResponse:
-         """
-         Synchronously invoke a batch model task
-         """
-         return self._run_async(self._async_client.invoke_batch(batch_model_request, timeout=timeout))
-
-     def close(self):
-         """Manually close the gRPC channel."""
-         self._run_async(self._async_client.close())
-
-     def _safe_sync_close(self):
-         """Auto-close on exit."""
-         try:
-             self._run_async(self._async_client.close())
-             logger.info("✅ gRPC channel closed at exit")
-         except Exception as e:
-             logger.warning(f"❌ gRPC channel close failed at exit: {e}")
-
-     def _run_async(self, coro):
-         """Run a coroutine, compatible with an already-running event loop."""
-         try:
-             loop = asyncio.get_running_loop()
-             import nest_asyncio
-             nest_asyncio.apply()
-             return loop.run_until_complete(coro)
-         except RuntimeError:
-             return self._loop.run_until_complete(coro)
-
-     def _sync_wrap_async_generator(self, async_gen_func):
-         """
-         Convert an async generator into a sync generator, yielding item by item.
-         """
-         loop = self._loop
-
-         # The async generator object
-         agen = async_gen_func
-
-         class SyncGenerator:
-             def __iter__(self_inner):
-                 return self_inner
-
-             def __next__(self_inner):
-                 try:
-                     return loop.run_until_complete(agen.__anext__())
-                 except StopAsyncIteration:
-                     raise StopIteration
-
-         return SyncGenerator()
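A minimal sketch of the streaming path through this wrapper, which is where _sync_wrap_async_generator comes into play (the surviving tamar_model_client.sync_client is assumed to expose the same surface; address, secret, and model name are illustrative):

    from tamar_model_client import ModelManagerClient
    from tamar_model_client.schemas import ModelRequest, UserContext

    client = ModelManagerClient(server_address="localhost:50051", jwt_secret_key="dev-secret")

    request = ModelRequest(
        provider="openai",
        invoke_type="chat-completions",
        model="gpt-4o",  # illustrative model name
        messages=[{"role": "user", "content": "ping"}],
        stream=True,
        user_context=UserContext(org_id="org-1", user_id="u-1", client_type="demo"),
    )

    # stream=True: invoke() returns the synchronous generator built above.
    for chunk in client.invoke(request):
        print(chunk.content, end="")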
@@ -1,34 +0,0 @@
- model_manager_client/__init__.py,sha256=LsqGh8ARtH9PQijbUjjrvHHmG09YwY4jmejAtlqV9ng,336
- model_manager_client/async_client.py,sha256=4vl4wLMucTqJ8moAZb0bonKYC6gFKwASZFQJIbCiBM4,20599
- model_manager_client/auth.py,sha256=gbwW5Aakeb49PMbmYvrYlVx1mfyn1LEDJ4qQVs-9DA4,438
- model_manager_client/exceptions.py,sha256=jYU494OU_NeIa4X393V-Y73mTNm0JZ9yZApnlOM9CJQ,332
- model_manager_client/sync_client.py,sha256=rLap64kk4rvAGJQsB7OXH565PW35xlMiSXh0iQTnJiM,4024
- model_manager_client/enums/__init__.py,sha256=3cYYn8ztNGBa_pI_5JGRVYf2QX8fkBVWdjID1PLvoBQ,182
- model_manager_client/enums/channel.py,sha256=wCzX579nNpTtwzGeS6S3Ls0UzVAgsOlfy4fXMzQTCAw,199
- model_manager_client/enums/invoke.py,sha256=9C5BxyAd4En-PSscOMynhfDa5WavGaSSOVFSYQGerK4,215
- model_manager_client/enums/providers.py,sha256=L_bX75K6KnWURoFizoitZ1Ybza7bmYDqXecNzNpgIrI,165
- model_manager_client/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- model_manager_client/generated/model_service_pb2.py,sha256=ST84YYQk8x6UtQKIx6HprUxH5uGU4i3LhC8b-lHUQtA,3066
- model_manager_client/generated/model_service_pb2_grpc.py,sha256=BzsINWQeACVnVzLVV0PgieZA25C2-EklMKlA-W50c6Y,5147
- model_manager_client/schemas/__init__.py,sha256=AxuI-TcvA4OMTj2FtK4wAItvz9LrK_293pu3cmMLE7k,394
- model_manager_client/schemas/inputs.py,sha256=3HUxnbuyQbuvMz1C46zydFYz-iEvLAUWVzOx7-eKS_I,14338
- model_manager_client/schemas/outputs.py,sha256=M_fcqUtXPJnfiLabHlyA8BorlC5pYkf5KLjXO1ysKIQ,1031
- tamar_model_client/__init__.py,sha256=LMECAuDARWHV1XzH3msoDXcyurS2eihRQmBy26_PUE0,328
- tamar_model_client/async_client.py,sha256=gmZ2xMHO_F-Vtg3OK7B_yf-gtI-WH2NU2LzC6YO_t7k,19649
- tamar_model_client/auth.py,sha256=gbwW5Aakeb49PMbmYvrYlVx1mfyn1LEDJ4qQVs-9DA4,438
- tamar_model_client/exceptions.py,sha256=jYU494OU_NeIa4X393V-Y73mTNm0JZ9yZApnlOM9CJQ,332
- tamar_model_client/sync_client.py,sha256=o8b20fQUvtMq1gWax3_dfOpputYT4l9pRTz6cHdB0lg,4006
- tamar_model_client/enums/__init__.py,sha256=3cYYn8ztNGBa_pI_5JGRVYf2QX8fkBVWdjID1PLvoBQ,182
- tamar_model_client/enums/channel.py,sha256=wCzX579nNpTtwzGeS6S3Ls0UzVAgsOlfy4fXMzQTCAw,199
- tamar_model_client/enums/invoke.py,sha256=WufImoN_87ZjGyzYitZkhNNFefWJehKfLtyP-DTBYlA,267
- tamar_model_client/enums/providers.py,sha256=L_bX75K6KnWURoFizoitZ1Ybza7bmYDqXecNzNpgIrI,165
- tamar_model_client/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- tamar_model_client/generated/model_service_pb2.py,sha256=RI6wNSmgmylzWPedFfPxx938UzS7kcPR58YTzYshcL8,3066
- tamar_model_client/generated/model_service_pb2_grpc.py,sha256=k4tIbp3XBxdyuOVR18Ung_4SUryONB51UYf_uUEl6V4,5145
- tamar_model_client/schemas/__init__.py,sha256=AxuI-TcvA4OMTj2FtK4wAItvz9LrK_293pu3cmMLE7k,394
- tamar_model_client/schemas/inputs.py,sha256=Y9zzt-RoRklkxxe_3VJbZvPghJ00KUjHtFUmD0pCdHs,18721
- tamar_model_client/schemas/outputs.py,sha256=M_fcqUtXPJnfiLabHlyA8BorlC5pYkf5KLjXO1ysKIQ,1031
- tamar_model_client-0.1.11.dist-info/METADATA,sha256=Ia4eGAZVs3vebAQxIENipL-XfJ7_CXWag4OwFU3V5GA,16566
- tamar_model_client-0.1.11.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
- tamar_model_client-0.1.11.dist-info/top_level.txt,sha256=_LfDhPv_fvON0PoZgQuo4M7EjoWtxPRoQOBJziJmip8,19
- tamar_model_client-0.1.11.dist-info/RECORD,,