tamar-model-client 0.1.8__tar.gz → 0.1.16__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25) hide show
  1. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/PKG-INFO +17 -10
  2. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/README.md +14 -7
  3. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/setup.py +3 -3
  4. tamar_model_client-0.1.16/tamar_model_client/async_client.py +558 -0
  5. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/generated/model_service_pb2.py +3 -3
  6. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/generated/model_service_pb2_grpc.py +1 -1
  7. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/schemas/inputs.py +7 -2
  8. tamar_model_client-0.1.8/tamar_model_client/async_client.py → tamar_model_client-0.1.16/tamar_model_client/sync_client.py +168 -92
  9. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client.egg-info/PKG-INFO +17 -10
  10. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client.egg-info/requires.txt +2 -2
  11. tamar_model_client-0.1.8/tamar_model_client/sync_client.py +0 -111
  12. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/setup.cfg +0 -0
  13. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/__init__.py +0 -0
  14. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/auth.py +0 -0
  15. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/enums/__init__.py +0 -0
  16. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/enums/channel.py +0 -0
  17. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/enums/invoke.py +0 -0
  18. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/enums/providers.py +0 -0
  19. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/exceptions.py +0 -0
  20. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/generated/__init__.py +0 -0
  21. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/schemas/__init__.py +0 -0
  22. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client/schemas/outputs.py +0 -0
  23. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client.egg-info/SOURCES.txt +0 -0
  24. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client.egg-info/dependency_links.txt +0 -0
  25. {tamar_model_client-0.1.8 → tamar_model_client-0.1.16}/tamar_model_client.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tamar-model-client
3
- Version: 0.1.8
3
+ Version: 0.1.16
4
4
  Summary: A Python SDK for interacting with the Model Manager gRPC service
5
5
  Home-page: http://gitlab.tamaredge.top/project-tap/AgentOS/model-manager-client
6
6
  Author: Oscar Ou
@@ -11,8 +11,8 @@ Classifier: License :: OSI Approved :: MIT License
11
11
  Classifier: Operating System :: OS Independent
12
12
  Requires-Python: >=3.8
13
13
  Description-Content-Type: text/markdown
14
- Requires-Dist: grpcio
15
- Requires-Dist: grpcio-tools
14
+ Requires-Dist: grpcio~=1.67.1
15
+ Requires-Dist: grpcio-tools~=1.67.1
16
16
  Requires-Dist: pydantic
17
17
  Requires-Dist: PyJWT
18
18
  Requires-Dist: nest_asyncio
@@ -273,13 +273,13 @@ async def main():
273
273
  )
274
274
 
275
275
  # 发送请求并获取响应
276
- response = await client.invoke(request_data)
277
- if response.error:
278
- print(f"错误: {response.error}")
279
- else:
280
- print(f"响应: {response.content}")
281
- if response.usage:
282
- print(f"Token 使用情况: {response.usage}")
276
+ async for r in await client.invoke(model_request):
277
+ if r.error:
278
+ print(f"错误: {r.error}")
279
+ else:
280
+ print(f"响应: {r.content}")
281
+ if r.usage:
282
+ print(f"Token 使用情况: {r.usage}")
283
283
 
284
284
 
285
285
  # 运行异步示例
@@ -528,6 +528,13 @@ pip install -e .
528
528
  python make_grpc.py
529
529
  ```
530
530
 
531
+ ### 部署到 pip
532
+ ```bash
533
+ python setup.py sdist bdist_wheel
534
+ twine upload dist/*
535
+
536
+ ```
537
+
531
538
  ## 许可证
532
539
 
533
540
  MIT License
@@ -243,13 +243,13 @@ async def main():
243
243
  )
244
244
 
245
245
  # 发送请求并获取响应
246
- response = await client.invoke(request_data)
247
- if response.error:
248
- print(f"错误: {response.error}")
249
- else:
250
- print(f"响应: {response.content}")
251
- if response.usage:
252
- print(f"Token 使用情况: {response.usage}")
246
+ async for r in await client.invoke(model_request):
247
+ if r.error:
248
+ print(f"错误: {r.error}")
249
+ else:
250
+ print(f"响应: {r.content}")
251
+ if r.usage:
252
+ print(f"Token 使用情况: {r.usage}")
253
253
 
254
254
 
255
255
  # 运行异步示例
@@ -498,6 +498,13 @@ pip install -e .
498
498
  python make_grpc.py
499
499
  ```
500
500
 
501
+ ### 部署到 pip
502
+ ```bash
503
+ python setup.py sdist bdist_wheel
504
+ twine upload dist/*
505
+
506
+ ```
507
+
501
508
  ## 许可证
502
509
 
503
510
  MIT License
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
2
2
 
3
3
  setup(
4
4
  name="tamar-model-client",
5
- version="0.1.8",
5
+ version="0.1.16",
6
6
  description="A Python SDK for interacting with the Model Manager gRPC service",
7
7
  author="Oscar Ou",
8
8
  author_email="oscar.ou@tamaredge.ai",
@@ -12,8 +12,8 @@ setup(
12
12
  "tamar_model_client": ["generated/*.py"], # 包含 gRPC 生成文件
13
13
  },
14
14
  install_requires=[
15
- "grpcio",
16
- "grpcio-tools",
15
+ "grpcio~=1.67.1",
16
+ "grpcio-tools~=1.67.1",
17
17
  "pydantic",
18
18
  "PyJWT",
19
19
  "nest_asyncio",
@@ -0,0 +1,558 @@
1
+ import asyncio
2
+ import atexit
3
+ import base64
4
+ import json
5
+ import logging
6
+ import os
7
+ import uuid
8
+ from contextvars import ContextVar
9
+
10
+ import grpc
11
+ from typing import Optional, AsyncIterator, Union, Iterable
12
+
13
+ from openai import NOT_GIVEN
14
+ from pydantic import BaseModel
15
+
16
+ from .auth import JWTAuthHandler
17
+ from .enums import ProviderType, InvokeType
18
+ from .exceptions import ConnectionError
19
+ from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
20
+ from .generated import model_service_pb2, model_service_pb2_grpc
21
+ from .schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput, \
22
+ GoogleVertexAIImagesInput, OpenAIImagesInput
23
+
24
logger = logging.getLogger(__name__)

# Per-request id managed via contextvars so that concurrent asyncio tasks
# each see their own id; '-' is the placeholder before any request is active.
_request_id: ContextVar[str] = ContextVar('request_id', default='-')
28
+
29
+
30
class RequestIdFilter(logging.Filter):
    """Logging filter that stamps every record with the current request id.

    The id is read from the context-local ``_request_id`` variable so each
    concurrent task logs its own id.
    """

    def filter(self, record):
        """Attach the context-local request id to *record*; never drop records."""
        setattr(record, "request_id", _request_id.get())
        return True
37
+
38
+
39
if not logger.hasHandlers():
    # Send log output to the console.
    console_handler = logging.StreamHandler()

    # Format includes the per-request id injected by RequestIdFilter.
    formatter = logging.Formatter('%(asctime)s [%(levelname)s] [%(request_id)s] %(message)s')
    console_handler.setFormatter(formatter)

    # Attach the handler to this module's logger.
    logger.addHandler(console_handler)

    # Default verbosity.
    logger.setLevel(logging.INFO)

    # Make every record carry a request_id attribute (required by the format above).
    logger.addFilter(RequestIdFilter())

MAX_MESSAGE_LENGTH = 2 ** 31 - 1  # max gRPC message size (largest 32-bit signed int)
57
+
58
+
59
def is_effective_value(value) -> bool:
    """Recursively decide whether *value* carries meaningful data.

    ``None``/``NOT_GIVEN``, blank strings, empty bytes, and containers whose
    members are all ineffective count as "no value". Every other type
    (int/float/bool, ...) is effective as long as it is not None.
    """
    if value is None or value is NOT_GIVEN:
        return False

    if isinstance(value, str):
        return bool(value.strip())

    if isinstance(value, bytes):
        return len(value) > 0

    if isinstance(value, dict):
        return any(is_effective_value(v) for v in value.values())

    if isinstance(value, list):
        return any(is_effective_value(item) for item in value)

    return True
85
+
86
+
87
def serialize_value(value):
    """Recursively convert *value* into gRPC-friendly primitives.

    Handles pydantic ``BaseModel``, objects exposing a callable ``.dict()``,
    dicts, non-string iterables, and bytes (base64-wrapped). Values judged
    ineffective by :func:`is_effective_value` collapse to ``None``.
    """
    if not is_effective_value(value):
        return None
    if isinstance(value, BaseModel):
        return serialize_value(value.model_dump())
    dict_method = getattr(value, "dict", None)
    if callable(dict_method):
        return serialize_value(dict_method())
    if isinstance(value, dict):
        return {key: serialize_value(item) for key, item in value.items()}
    if isinstance(value, str):
        return value
    if isinstance(value, bytes):
        # bytes are not JSON/protobuf friendly — wrap as a tagged base64 string
        return f"bytes:{base64.b64encode(value).decode('utf-8')}"
    if isinstance(value, Iterable):
        return [serialize_value(item) for item in value]
    return value
102
+
103
+
104
+ from typing import Any
105
+
106
+
107
def remove_none_from_dict(data: Any) -> Any:
    """Recursively drop dict entries whose value is ``None``.

    Lists are walked element by element but ``None`` elements are kept, so
    list positions are preserved; any non-container value passes through
    unchanged.
    """
    if isinstance(data, dict):
        return {
            key: remove_none_from_dict(value)
            for key, value in data.items()
            if value is not None
        }
    if isinstance(data, list):
        return [remove_none_from_dict(element) for element in data]
    return data
123
+
124
+
125
def generate_request_id():
    """Return a fresh UUID4 string used to correlate one request's log lines."""
    return f"{uuid.uuid4()}"
128
+
129
+
130
def set_request_id(request_id: str):
    """Bind *request_id* to the current context so RequestIdFilter logs it."""
    _request_id.set(request_id)
133
+
134
+
135
class AsyncTamarModelClient:
    """Async gRPC client for the Model Manager service.

    Responsibilities visible in this class:
    - lazy channel initialization with optional TLS and exponential-backoff retry,
    - per-request JWT auth metadata and request-id propagation,
    - streaming and non-streaming model invocation plus batch invocation,
    - best-effort channel shutdown at interpreter exit.
    """

    def __init__(
            self,
            server_address: Optional[str] = None,
            jwt_secret_key: Optional[str] = None,
            jwt_token: Optional[str] = None,
            default_payload: Optional[dict] = None,
            token_expires_in: int = 3600,
            max_retries: Optional[int] = None,  # maximum retry attempts
            retry_delay: Optional[float] = None,  # initial retry delay (seconds)
    ):
        """Configure the client; no network I/O happens here.

        Every argument falls back to a MODEL_MANAGER_SERVER_* environment
        variable when omitted.

        Raises:
            ValueError: if no server address is given via argument or env var.
        """
        # Server address
        self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
        if not self.server_address:
            raise ValueError("Server address must be provided via argument or environment variable.")
        self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))

        # JWT configuration
        self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
        self.jwt_handler = JWTAuthHandler(self.jwt_secret_key)
        self.jwt_token = jwt_token  # caller-supplied token (optional)
        self.default_payload = default_payload
        self.token_expires_in = token_expires_in

        # === TLS / authority configuration ===
        self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
        self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")

        # === Retry configuration ===
        self.max_retries = max_retries if max_retries is not None else int(
            os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
        self.retry_delay = retry_delay if retry_delay is not None else float(
            os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))

        # === gRPC channel state ===
        self.channel: Optional[grpc.aio.Channel] = None
        self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
        self._closed = False
        atexit.register(self._safe_sync_close)  # auto-close the channel on process exit

    async def _retry_request(self, func, *args, **kwargs):
        """Await ``func(*args, **kwargs)``, retrying transient gRPC failures.

        CANCELLED, UNAVAILABLE and DEADLINE_EXCEEDED are retried up to
        ``max_retries`` times with exponential backoff
        (``retry_delay * 2**(attempt-1)``); any other gRPC error re-raises
        immediately.
        """
        retry_count = 0
        while retry_count < self.max_retries:
            try:
                return await func(*args, **kwargs)
            except (grpc.aio.AioRpcError, grpc.RpcError) as e:
                # Exponential backoff retry when the RPC was cancelled
                if isinstance(e, grpc.aio.AioRpcError) and e.code() == grpc.StatusCode.CANCELLED:
                    retry_count += 1
                    logger.warning(f"❌ RPC cancelled, retrying {retry_count}/{self.max_retries}...")
                    if retry_count < self.max_retries:
                        delay = self.retry_delay * (2 ** (retry_count - 1))
                        await asyncio.sleep(delay)
                    else:
                        logger.error("❌ Max retry reached for CANCELLED")
                        raise
                # Other retryable RPC errors: transient connectivity issues, server timeouts
                elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
                                                                   grpc.StatusCode.DEADLINE_EXCEEDED}:
                    retry_count += 1
                    logger.warning(f"❌ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...")
                    if retry_count < self.max_retries:
                        delay = self.retry_delay * (2 ** (retry_count - 1))
                        await asyncio.sleep(delay)
                    else:
                        logger.error(f"❌ Max retry reached for {e.code()}")
                        raise
                else:
                    logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True)
                    raise

    async def _retry_request_stream(self, func, *args, **kwargs):
        """Create a streaming call via ``func(...)`` with the same retry policy.

        NOTE(review): ``func`` is called but not awaited/iterated here, so only
        errors raised while *creating* the call can be caught; errors raised
        during stream iteration happen in the caller's ``async for`` and are
        NOT retried by this wrapper — confirm whether that is intended.
        """
        retry_count = 0
        while retry_count < self.max_retries:
            try:
                return func(*args, **kwargs)
            except (grpc.aio.AioRpcError, grpc.RpcError) as e:
                # Exponential backoff retry when the RPC was cancelled
                if isinstance(e, grpc.aio.AioRpcError) and e.code() == grpc.StatusCode.CANCELLED:
                    retry_count += 1
                    logger.warning(f"❌ RPC cancelled, retrying {retry_count}/{self.max_retries}...")
                    if retry_count < self.max_retries:
                        delay = self.retry_delay * (2 ** (retry_count - 1))
                        await asyncio.sleep(delay)
                    else:
                        logger.error("❌ Max retry reached for CANCELLED")
                        raise
                # Other retryable RPC errors: transient connectivity issues, server timeouts
                elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
                                                                   grpc.StatusCode.DEADLINE_EXCEEDED}:
                    retry_count += 1
                    logger.warning(f"❌ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...")
                    if retry_count < self.max_retries:
                        delay = self.retry_delay * (2 ** (retry_count - 1))
                        await asyncio.sleep(delay)
                    else:
                        logger.error(f"❌ Max retry reached for {e.code()}")
                        raise
                else:
                    logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True)
                    raise

    def _build_auth_metadata(self, request_id: str) -> list:
        """Build gRPC metadata: the request id plus a freshly minted JWT.

        A new token is generated on every request (the cached ``jwt_token``
        is overwritten each time).
        """
        # if not self.jwt_token and self.jwt_handler:
        # Changed: generate a fresh token on every request
        metadata = [("x-request-id", request_id)]  # propagate request_id in headers
        if self.jwt_handler:
            self.jwt_token = self.jwt_handler.encode_token(self.default_payload, expires_in=self.token_expires_in)
            metadata.append(("authorization", f"Bearer {self.jwt_token}"))
        return metadata

    async def _ensure_initialized(self):
        """Initialize the gRPC channel (TLS optional) with retry and backoff.

        Idempotent: returns immediately once channel and stub exist.

        Raises:
            ConnectionError: after ``max_retries`` failed initialization attempts.
        """
        if self.channel and self.stub:
            return

        retry_count = 0
        options = [
            ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
            ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
            ('grpc.keepalive_permit_without_calls', True)  # keep the connection alive even with no active calls
        ]
        if self.default_authority:
            options.append(("grpc.default_authority", self.default_authority))

        while retry_count <= self.max_retries:
            try:
                if self.use_tls:
                    credentials = grpc.ssl_channel_credentials()
                    self.channel = grpc.aio.secure_channel(
                        self.server_address,
                        credentials,
                        options=options
                    )
                    logger.info("🔐 Using secure gRPC channel (TLS enabled)")
                else:
                    self.channel = grpc.aio.insecure_channel(
                        self.server_address,
                        options=options
                    )
                    logger.info("🔓 Using insecure gRPC channel (TLS disabled)")
                await self.channel.channel_ready()
                self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
                logger.info(f"✅ gRPC channel initialized to {self.server_address}")
                return
            except grpc.FutureTimeoutError as e:
                logger.error(f"❌ gRPC channel initialization timed out: {str(e)}", exc_info=True)
            except grpc.RpcError as e:
                logger.error(f"❌ gRPC channel initialization failed: {str(e)}", exc_info=True)
            except Exception as e:
                logger.error(f"❌ Unexpected error during channel initialization: {str(e)}", exc_info=True)

            retry_count += 1
            if retry_count > self.max_retries:
                logger.error(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.", exc_info=True)
                raise ConnectionError(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.")

            # Exponential backoff: delay = retry_delay * (2 ** (retry_count - 1))
            delay = self.retry_delay * (2 ** (retry_count - 1))
            logger.info(f"🚀 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...")
            await asyncio.sleep(delay)

    async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
        """Yield one ModelResponse per streamed server message."""
        async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
            yield ModelResponse(
                content=response.content,
                usage=json.loads(response.usage) if response.usage else None,
                error=response.error or None,
                raw_response=json.loads(response.raw_response) if response.raw_response else None,
                request_id=response.request_id if response.request_id else None,
            )

    async def _invoke_request(self, request, metadata, invoke_timeout):
        """Return the first streamed message as a single ModelResponse (non-stream path)."""
        async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
            return ModelResponse(
                content=response.content,
                usage=json.loads(response.usage) if response.usage else None,
                error=response.error or None,
                raw_response=json.loads(response.raw_response) if response.raw_response else None,
                request_id=response.request_id if response.request_id else None,
            )

    async def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None,
                     request_id: Optional[str] = None) -> Union[
        ModelResponse, AsyncIterator[ModelResponse]]:
        """Invoke a model through the gRPC service.

        Selects the allowed input fields from ``model_request`` based on the
        provider/invoke_type combination, builds the protobuf request, and
        dispatches either a streaming or a single-shot call.

        Args:
            model_request: ModelRequest object carrying the request parameters.
            timeout: per-call timeout in seconds (defaults to
                ``default_invoke_timeout``).
            request_id: correlation id; generated if not supplied.

        Returns:
            ModelResponse for non-streaming requests, or an async iterator of
            ModelResponse for streaming requests.

        Raises:
            ValueError: unsupported provider/invoke_type or request build failure.
            ConnectionError: the channel could not be initialized.
        """
        await self._ensure_initialized()

        if not self.default_payload:
            self.default_payload = {
                "org_id": model_request.user_context.org_id or "",
                "user_id": model_request.user_context.user_id or ""
            }

        if not request_id:
            request_id = generate_request_id()  # mint a new request_id
        set_request_id(request_id)  # bind it to the current context
        metadata = self._build_auth_metadata(request_id)  # carry it in request headers

        # Log request start
        logger.info(
            f"🔵 Request Start | request_id: {request_id} | provider: {model_request.provider} | invoke_type: {model_request.invoke_type} | model_request: {model_request}")

        # Pick the input field set based on provider/invoke_type
        try:
            # Choose the field collection to validate against
            # Dynamic dispatch
            match (model_request.provider, model_request.invoke_type):
                case (ProviderType.GOOGLE, InvokeType.GENERATION):
                    allowed_fields = GoogleGenAiInput.model_fields.keys()
                case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
                    allowed_fields = GoogleVertexAIImagesInput.model_fields.keys()
                case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
                    allowed_fields = OpenAIResponsesInput.model_fields.keys()
                case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
                    allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
                case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
                    allowed_fields = OpenAIImagesInput.model_fields.keys()
                case _:
                    raise ValueError(
                        f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")

            # Dump ModelRequest to a dict; keep only base + allowed fields
            model_request_dict = model_request.model_dump(exclude_unset=True)

            grpc_request_kwargs = {}
            for field in allowed_fields:
                if field in model_request_dict:
                    value = model_request_dict[field]

                    # Skip ineffective values
                    if not is_effective_value(value):
                        continue

                    # Serialize types gRPC cannot carry natively
                    grpc_request_kwargs[field] = serialize_value(value)

            # Strip None leftovers from the serialized kwargs
            grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)

            request = model_service_pb2.ModelRequestItem(
                provider=model_request.provider.value,
                channel=model_request.channel.value,
                invoke_type=model_request.invoke_type.value,
                stream=model_request.stream or False,
                org_id=model_request.user_context.org_id or "",
                user_id=model_request.user_context.user_id or "",
                client_type=model_request.user_context.client_type or "",
                extra=grpc_request_kwargs
            )

        except Exception as e:
            raise ValueError(f"构建请求失败: {str(e)}") from e

        try:
            invoke_timeout = timeout or self.default_invoke_timeout
            if model_request.stream:
                return await self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
            else:
                return await self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
        except grpc.RpcError as e:
            error_message = f"❌ Invoke gRPC failed: {str(e)}"
            logger.error(error_message, exc_info=True)
            raise e
        except Exception as e:
            error_message = f"❌ Invoke other error: {str(e)}"
            logger.error(error_message, exc_info=True)
            raise e

    async def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None,
                           request_id: Optional[str] = None) -> \
            BatchModelResponse:
        """Invoke multiple model requests in one gRPC BatchInvoke call.

        Args:
            batch_request_model: BatchModelRequest holding the request items.
            timeout: per-call timeout in seconds.
            request_id: correlation id; generated if not supplied.

        Returns:
            BatchModelResponse: the aggregated batch result.

        Raises:
            ValueError: unsupported provider/invoke_type or item build failure.
            ConnectionError: the channel could not be initialized.
        """

        await self._ensure_initialized()

        if not self.default_payload:
            self.default_payload = {
                "org_id": batch_request_model.user_context.org_id or "",
                "user_id": batch_request_model.user_context.user_id or ""
            }

        if not request_id:
            request_id = generate_request_id()  # mint a new request_id
        set_request_id(request_id)  # bind it to the current context
        metadata = self._build_auth_metadata(request_id)  # carry it in request headers

        # Log request start
        logger.info(
            f"🔵 Batch Request Start | request_id: {request_id} | batch_size: {len(batch_request_model.items)} | batch_request_model: {batch_request_model}")

        # Build the batch request item by item
        items = []
        for model_request_item in batch_request_model.items:
            # Pick the input field set based on provider/invoke_type
            try:
                match (model_request_item.provider, model_request_item.invoke_type):
                    case (ProviderType.GOOGLE, InvokeType.GENERATION):
                        allowed_fields = GoogleGenAiInput.model_fields.keys()
                    case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
                        allowed_fields = GoogleVertexAIImagesInput.model_fields.keys()
                    case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
                        allowed_fields = OpenAIResponsesInput.model_fields.keys()
                    case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
                        allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
                    case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
                        allowed_fields = OpenAIImagesInput.model_fields.keys()
                    case _:
                        raise ValueError(
                            f"Unsupported provider/invoke_type combination: {model_request_item.provider} + {model_request_item.invoke_type}")

                # Dump ModelRequest to a dict; keep only base + allowed fields
                model_request_dict = model_request_item.model_dump(exclude_unset=True)

                grpc_request_kwargs = {}
                for field in allowed_fields:
                    if field in model_request_dict:
                        value = model_request_dict[field]

                        # Skip ineffective values
                        if not is_effective_value(value):
                            continue

                        # Serialize types gRPC cannot carry natively
                        grpc_request_kwargs[field] = serialize_value(value)

                # Strip None leftovers from the serialized kwargs
                grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)

                items.append(model_service_pb2.ModelRequestItem(
                    provider=model_request_item.provider.value,
                    channel=model_request_item.channel.value,
                    invoke_type=model_request_item.invoke_type.value,
                    stream=model_request_item.stream or False,
                    custom_id=model_request_item.custom_id or "",
                    priority=model_request_item.priority or 1,
                    org_id=batch_request_model.user_context.org_id or "",
                    user_id=batch_request_model.user_context.user_id or "",
                    client_type=batch_request_model.user_context.client_type or "",
                    extra=grpc_request_kwargs,
                ))

            except Exception as e:
                raise ValueError(f"构建请求失败: {str(e)},item={model_request_item.custom_id}") from e

        try:
            # Resolve the effective timeout
            invoke_timeout = timeout or self.default_invoke_timeout

            # Call the gRPC endpoint (with retry)
            response = await self._retry_request(self.stub.BatchInvoke, model_service_pb2.ModelRequest(items=items),
                                                 timeout=invoke_timeout, metadata=metadata)

            result = []
            for res_item in response.items:
                result.append(ModelResponse(
                    content=res_item.content,
                    usage=json.loads(res_item.usage) if res_item.usage else None,
                    raw_response=json.loads(res_item.raw_response) if res_item.raw_response else None,
                    error=res_item.error or None,
                    custom_id=res_item.custom_id if res_item.custom_id else None
                ))
            return BatchModelResponse(
                request_id=response.request_id if response.request_id else None,
                responses=result
            )
        except grpc.RpcError as e:
            error_message = f"❌ BatchInvoke gRPC failed: {str(e)}"
            logger.error(error_message, exc_info=True)
            raise e
        except Exception as e:
            error_message = f"❌ BatchInvoke other error: {str(e)}"
            logger.error(error_message, exc_info=True)
            raise e

    async def close(self):
        """Close the gRPC channel once; subsequent calls are no-ops."""
        if self.channel and not self._closed:
            await self.channel.close()
            self._closed = True
            logger.info("✅ gRPC channel closed")

    def _safe_sync_close(self):
        """Close the channel at interpreter exit, tolerating event-loop states.

        Registered via ``atexit``; schedules ``close()`` on a running loop or
        drives it to completion on a stopped one, swallowing (but logging)
        any failure.
        """
        if self.channel and not self._closed:
            try:
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    loop.create_task(self.close())
                else:
                    loop.run_until_complete(self.close())
            except Exception as e:
                logger.warning(f"❌ gRPC channel close failed at exit: {e}")

    async def __aenter__(self):
        """Support ``async with``: initialize the connection on entry."""
        await self._ensure_initialized()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Support ``async with``: close the connection on exit."""
        await self.close()
@@ -2,7 +2,7 @@
2
2
  # Generated by the protocol buffer compiler. DO NOT EDIT!
3
3
  # NO CHECKED-IN PROTOBUF GENCODE
4
4
  # source: model_service.proto
5
- # Protobuf Python Version: 5.29.0
5
+ # Protobuf Python Version: 5.27.2
6
6
  """Generated protocol buffer code."""
7
7
  from google.protobuf import descriptor as _descriptor
8
8
  from google.protobuf import descriptor_pool as _descriptor_pool
@@ -12,8 +12,8 @@ from google.protobuf.internal import builder as _builder
12
12
  _runtime_version.ValidateProtobufRuntimeVersion(
13
13
  _runtime_version.Domain.PUBLIC,
14
14
  5,
15
- 29,
16
- 0,
15
+ 27,
16
+ 2,
17
17
  '',
18
18
  'model_service.proto'
19
19
  )