tamar-model-client 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,8 @@ import base64
4
4
  import json
5
5
  import logging
6
6
  import os
7
+ import uuid
8
+ from contextvars import ContextVar
7
9
 
8
10
  import grpc
9
11
  from typing import Optional, AsyncIterator, Union, Iterable
@@ -13,21 +15,44 @@ from pydantic import BaseModel
13
15
 
14
16
  from .auth import JWTAuthHandler
15
17
  from .enums import ProviderType, InvokeType
16
- from .exceptions import ConnectionError, ValidationError
18
+ from .exceptions import ConnectionError
17
19
  from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
18
20
  from .generated import model_service_pb2, model_service_pb2_grpc
19
21
  from .schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput, \
20
22
  GoogleVertexAIImagesInput, OpenAIImagesInput
21
23
 
22
- if not logging.getLogger().hasHandlers():
23
- # 配置日志格式
24
- logging.basicConfig(
25
- level=logging.INFO,
26
- format="%(asctime)s [%(levelname)s] %(message)s",
27
- )
28
-
29
24
  logger = logging.getLogger(__name__)
30
25
 
26
+ # 使用 contextvars 管理请求ID
27
+ _request_id: ContextVar[str] = ContextVar('request_id', default='-')
28
+
29
+
30
class RequestIdFilter(logging.Filter):
    """Logging filter that stamps every record with the current request ID."""

    def filter(self, record):
        # Read the ID from the module-level ContextVar and expose it as
        # ``record.request_id`` so a formatter can use ``%(request_id)s``.
        setattr(record, "request_id", _request_id.get())
        # Never drop a record; this filter only annotates.
        return True
37
+
38
+
39
+ if not logger.hasHandlers():
40
+ # 创建日志处理器,输出到控制台
41
+ console_handler = logging.StreamHandler()
42
+
43
+ # 设置日志格式
44
+ formatter = logging.Formatter('%(asctime)s [%(levelname)s] [%(request_id)s] %(message)s')
45
+ console_handler.setFormatter(formatter)
46
+
47
+ # 为当前记录器添加处理器
48
+ logger.addHandler(console_handler)
49
+
50
+ # 设置日志级别
51
+ logger.setLevel(logging.INFO)
52
+
53
+ # 将自定义的 RequestIdFilter 添加到 logger 中
54
+ logger.addFilter(RequestIdFilter())
55
+
31
56
  MAX_MESSAGE_LENGTH = 2 ** 31 - 1 # 对于32位系统
32
57
 
33
58
 
@@ -97,6 +122,16 @@ def remove_none_from_dict(data: Any) -> Any:
97
122
  return data
98
123
 
99
124
 
125
def generate_request_id():
    """Return a new random (version 4) UUID, rendered as a string, to use as a request ID."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
128
+
129
+
130
def set_request_id(request_id: str):
    """Bind ``request_id`` to the current context so log records can carry it.

    Returns:
        The token produced by ``ContextVar.set``. Passing it to
        ``_request_id.reset(token)`` restores the previous request ID once the
        request is finished. Callers that ignore the return value behave
        exactly as before.
    """
    return _request_id.set(request_id)
133
+
134
+
100
135
  class AsyncTamarModelClient:
101
136
  def __init__(
102
137
  self,
@@ -105,8 +140,8 @@ class AsyncTamarModelClient:
105
140
  jwt_token: Optional[str] = None,
106
141
  default_payload: Optional[dict] = None,
107
142
  token_expires_in: int = 3600,
108
- max_retries: int = 3, # 最大重试次数
109
- retry_delay: float = 1.0, # 初始重试延迟(秒)
143
+ max_retries: Optional[int] = None, # 最大重试次数
144
+ retry_delay: Optional[float] = None, # 初始重试延迟(秒)
110
145
  ):
111
146
  # 服务端地址
112
147
  self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
@@ -137,12 +172,76 @@ class AsyncTamarModelClient:
137
172
  self._closed = False
138
173
  atexit.register(self._safe_sync_close) # 注册进程退出自动关闭
139
174
 
140
- def _build_auth_metadata(self) -> list:
175
async def _retry_request(self, func, *args, **kwargs):
    """Await ``func(*args, **kwargs)``, retrying transient gRPC failures.

    Retried failures are CANCELLED (only when raised as
    ``grpc.aio.AioRpcError``), UNAVAILABLE and DEADLINE_EXCEEDED. Between
    attempts the coroutine sleeps ``retry_delay * 2 ** (failures - 1)``
    seconds (exponential backoff). Any other gRPC error is logged and
    re-raised immediately.

    ``func`` is always attempted at least once: with ``max_retries <= 0`` the
    previous loop body never ran and the method returned None without ever
    issuing the call.

    Args:
        func: Callable returning an awaitable (e.g. a stub method).
        *args, **kwargs: Forwarded to ``func`` unchanged.

    Returns:
        Whatever the awaited ``func`` call returns.

    Raises:
        grpc.RpcError: The last error once the attempts are exhausted, or the
            first non-retryable error.
    """
    max_attempts = max(1, self.max_retries)
    failures = 0
    while True:
        try:
            return await func(*args, **kwargs)
        except (grpc.aio.AioRpcError, grpc.RpcError) as e:
            code = e.code()
            cancelled = isinstance(e, grpc.aio.AioRpcError) and code == grpc.StatusCode.CANCELLED
            transient = code in {grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED}

            if not (cancelled or transient):
                logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True)
                raise

            failures += 1
            if cancelled:
                logger.warning(f"❌ RPC cancelled, retrying {failures}/{self.max_retries}...")
            else:
                logger.warning(f"❌ gRPC error {code}, retrying {failures}/{self.max_retries}...")

            if failures >= max_attempts:
                if cancelled:
                    logger.error("❌ Max retry reached for CANCELLED")
                else:
                    logger.error(f"❌ Max retry reached for {code}")
                raise

            # Exponential backoff before the next attempt.
            await asyncio.sleep(self.retry_delay * (2 ** (failures - 1)))
205
+
206
async def _retry_request_stream(self, func, *args, **kwargs):
    """Wrap a streaming call so transient gRPC failures are actually retried.

    ``func(*args, **kwargs)`` is expected to return an async iterable (as the
    ``_stream`` async generator does). Merely calling an async generator
    function cannot raise, because errors only surface while it is iterated.
    The retry handling therefore wraps the iteration itself instead of the
    bare call, which previously made the retry branches unreachable.

    A failed stream is restarted only if nothing has been yielded yet;
    restarting after partial output would replay items the consumer has
    already seen. Retried codes and backoff match ``_retry_request``:
    CANCELLED (``grpc.aio.AioRpcError`` only), UNAVAILABLE and
    DEADLINE_EXCEEDED, sleeping ``retry_delay * 2 ** (failures - 1)`` seconds.

    Args:
        func: Callable returning an async iterable.
        *args, **kwargs: Forwarded to ``func`` unchanged.

    Returns:
        An async iterator over the items produced by ``func``.
    """

    async def _guarded_stream():
        max_attempts = max(1, self.max_retries)  # always try at least once
        failures = 0
        while True:
            emitted = False
            try:
                async for item in func(*args, **kwargs):
                    emitted = True
                    yield item
                return
            except (grpc.aio.AioRpcError, grpc.RpcError) as e:
                code = e.code()
                cancelled = isinstance(e, grpc.aio.AioRpcError) and code == grpc.StatusCode.CANCELLED
                transient = code in {grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED}

                if emitted:
                    logger.error(f"❌ gRPC stream failed after partial output, not retrying: {e}",
                                 exc_info=True)
                    raise
                if not (cancelled or transient):
                    logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True)
                    raise

                failures += 1
                if cancelled:
                    logger.warning(f"❌ RPC cancelled, retrying {failures}/{self.max_retries}...")
                else:
                    logger.warning(f"❌ gRPC error {code}, retrying {failures}/{self.max_retries}...")

                if failures >= max_attempts:
                    if cancelled:
                        logger.error("❌ Max retry reached for CANCELLED")
                    else:
                        logger.error(f"❌ Max retry reached for {code}")
                    raise

                # Exponential backoff before restarting the stream.
                await asyncio.sleep(self.retry_delay * (2 ** (failures - 1)))

    return _guarded_stream()
236
+
237
+ def _build_auth_metadata(self, request_id: str) -> list:
141
238
  # if not self.jwt_token and self.jwt_handler:
142
239
  # 更改为每次请求都生成一次token
240
+ metadata = [("x-request-id", request_id)] # 将 request_id 添加到 headers
143
241
  if self.jwt_handler:
144
242
  self.jwt_token = self.jwt_handler.encode_token(self.default_payload, expires_in=self.token_expires_in)
145
- return [("authorization", f"Bearer {self.jwt_token}")] if self.jwt_token else []
243
+ metadata.append(("authorization", f"Bearer {self.jwt_token}"))
244
+ return metadata
146
245
 
147
246
  async def _ensure_initialized(self):
148
247
  """初始化 gRPC 通道,支持 TLS 与重试机制"""
@@ -195,23 +294,36 @@ class AsyncTamarModelClient:
195
294
  logger.info(f"🚀 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...")
196
295
  await asyncio.sleep(delay)
197
296
 
198
- async def _stream(self, model_request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
199
- async for response in self.stub.Invoke(model_request, metadata=metadata, timeout=invoke_timeout):
297
+ async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
298
+ async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
200
299
  yield ModelResponse(
201
300
  content=response.content,
202
301
  usage=json.loads(response.usage) if response.usage else None,
302
+ error=response.error or None,
203
303
  raw_response=json.loads(response.raw_response) if response.raw_response else None,
304
+ request_id=response.request_id if response.request_id else None,
305
+ )
306
+
307
+ async def _invoke_request(self, request, metadata, invoke_timeout):
308
+ async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
309
+ return ModelResponse(
310
+ content=response.content,
311
+ usage=json.loads(response.usage) if response.usage else None,
204
312
  error=response.error or None,
313
+ raw_response=json.loads(response.raw_response) if response.raw_response else None,
314
+ request_id=response.request_id if response.request_id else None,
205
315
  )
206
316
 
207
- async def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None) -> Union[
317
+ async def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None,
318
+ request_id: Optional[str] = None) -> Union[
208
319
  ModelResponse, AsyncIterator[ModelResponse]]:
209
320
  """
210
321
  通用调用模型方法。
211
322
 
212
323
  Args:
213
324
  model_request: ModelRequest 对象,包含请求参数。
214
-
325
+ timeout: Optional[float]
326
+ request_id: Optional[str]
215
327
  Yields:
216
328
  ModelResponse: 支持流式或非流式的模型响应
217
329
 
@@ -227,6 +339,15 @@ class AsyncTamarModelClient:
227
339
  "user_id": model_request.user_context.user_id or ""
228
340
  }
229
341
 
342
+ if not request_id:
343
+ request_id = generate_request_id() # 生成一个新的 request_id
344
+ set_request_id(request_id) # 设置当前请求的 request_id
345
+ metadata = self._build_auth_metadata(request_id) # 将 request_id 加入到请求头
346
+
347
+ # 记录开始日志
348
+ logger.info(
349
+ f"🔵 Request Start | request_id: {request_id} | provider: {model_request.provider} | invoke_type: {model_request.invoke_type} | model_request: {model_request}")
350
+
230
351
  # 动态根据 provider/invoke_type 决定使用哪个 input 字段
231
352
  try:
232
353
  # 选择需要校验的字段集合
@@ -278,22 +399,12 @@ class AsyncTamarModelClient:
278
399
  except Exception as e:
279
400
  raise ValueError(f"构建请求失败: {str(e)}") from e
280
401
 
281
- metadata = self._build_auth_metadata()
282
-
283
402
  try:
284
403
  invoke_timeout = timeout or self.default_invoke_timeout
285
404
  if model_request.stream:
286
- return self._stream(request, metadata, invoke_timeout)
405
+ return await self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
287
406
  else:
288
- async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
289
- return ModelResponse(
290
- content=response.content,
291
- usage=json.loads(response.usage) if response.usage else None,
292
- raw_response=json.loads(response.raw_response) if response.raw_response else None,
293
- error=response.error or None,
294
- custom_id=None,
295
- request_id=response.request_id if response.request_id else None,
296
- )
407
+ return await self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
297
408
  except grpc.RpcError as e:
298
409
  error_message = f"❌ Invoke gRPC failed: {str(e)}"
299
410
  logger.error(error_message, exc_info=True)
@@ -303,7 +414,8 @@ class AsyncTamarModelClient:
303
414
  logger.error(error_message, exc_info=True)
304
415
  raise e
305
416
 
306
- async def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None) -> \
417
+ async def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None,
418
+ request_id: Optional[str] = None) -> \
307
419
  BatchModelResponse:
308
420
  """
309
421
  批量模型调用接口
@@ -311,10 +423,11 @@ class AsyncTamarModelClient:
311
423
  Args:
312
424
  batch_request_model: 多条 BatchModelRequest 输入
313
425
  timeout: 调用超时,单位秒
314
-
426
+ request_id: 请求id
315
427
  Returns:
316
428
  BatchModelResponse: 批量请求的结果
317
429
  """
430
+
318
431
  await self._ensure_initialized()
319
432
 
320
433
  if not self.default_payload:
@@ -323,7 +436,14 @@ class AsyncTamarModelClient:
323
436
  "user_id": batch_request_model.user_context.user_id or ""
324
437
  }
325
438
 
326
- metadata = self._build_auth_metadata()
439
+ if not request_id:
440
+ request_id = generate_request_id() # 生成一个新的 request_id
441
+ set_request_id(request_id) # 设置当前请求的 request_id
442
+ metadata = self._build_auth_metadata(request_id) # 将 request_id 加入到请求头
443
+
444
+ # 记录开始日志
445
+ logger.info(
446
+ f"🔵 Batch Request Start | request_id: {request_id} | batch_size: {len(batch_request_model.items)} | batch_request_model: {batch_request_model}")
327
447
 
328
448
  # 构造批量请求
329
449
  items = []
@@ -384,11 +504,8 @@ class AsyncTamarModelClient:
384
504
  invoke_timeout = timeout or self.default_invoke_timeout
385
505
 
386
506
  # 调用 gRPC 接口
387
- response = await self.stub.BatchInvoke(
388
- model_service_pb2.ModelRequest(items=items),
389
- timeout=invoke_timeout,
390
- metadata=metadata
391
- )
507
+ response = await self._retry_request(self.stub.BatchInvoke, model_service_pb2.ModelRequest(items=items),
508
+ timeout=invoke_timeout, metadata=metadata)
392
509
 
393
510
  result = []
394
511
  for res_item in response.items:
@@ -417,7 +534,6 @@ class AsyncTamarModelClient:
417
534
  if self.channel and not self._closed:
418
535
  await self.channel.close()
419
536
  self._closed = True
420
- await self.channel.close()
421
537
  logger.info("✅ gRPC channel closed")
422
538
 
423
539
  def _safe_sync_close(self):
@@ -1,21 +1,135 @@
1
- import asyncio
2
- import atexit
1
+ import base64
2
+ import json
3
3
  import logging
4
- from typing import Optional, Union, Iterator
4
+ import os
5
+ import time
6
+ import uuid
7
+ import grpc
8
+ from typing import Optional, Union, Iterable, Iterator
9
+ from contextvars import ContextVar
5
10
 
6
- from .async_client import AsyncTamarModelClient
7
- from .schemas import ModelRequest, BatchModelRequest, ModelResponse, BatchModelResponse
11
+ from openai import NOT_GIVEN
12
+ from pydantic import BaseModel
13
+
14
+ from .auth import JWTAuthHandler
15
+ from .enums import ProviderType, InvokeType
16
+ from .exceptions import ConnectionError
17
+ from .generated import model_service_pb2, model_service_pb2_grpc
18
+ from .schemas import BatchModelResponse, ModelResponse
19
+ from .schemas.inputs import GoogleGenAiInput, GoogleVertexAIImagesInput, OpenAIResponsesInput, \
20
+ OpenAIChatCompletionsInput, OpenAIImagesInput, BatchModelRequest, ModelRequest
8
21
 
9
22
  logger = logging.getLogger(__name__)
10
23
 
24
+ _request_id: ContextVar[str] = ContextVar('request_id', default='-')
11
25
 
12
- class TamarModelClient:
26
+
27
class RequestIdFilter(logging.Filter):
    """Logging filter that stamps every record with the current request ID."""

    def filter(self, record):
        # Read the ID from the module-level ContextVar and expose it as
        # ``record.request_id`` so a formatter can use ``%(request_id)s``.
        setattr(record, "request_id", _request_id.get())
        # Never drop a record; this filter only annotates.
        return True
34
+
35
+
36
+ if not logger.hasHandlers():
37
+ # 创建日志处理器,输出到控制台
38
+ console_handler = logging.StreamHandler()
39
+
40
+ # 设置日志格式
41
+ formatter = logging.Formatter('%(asctime)s [%(levelname)s] [%(request_id)s] %(message)s')
42
+ console_handler.setFormatter(formatter)
43
+
44
+ # 为当前记录器添加处理器
45
+ logger.addHandler(console_handler)
46
+
47
+ # 设置日志级别
48
+ logger.setLevel(logging.INFO)
49
+
50
+ # 将自定义的 RequestIdFilter 添加到 logger 中
51
+ logger.addFilter(RequestIdFilter())
52
+
53
+ MAX_MESSAGE_LENGTH = 2 ** 31 - 1 # 对于32位系统
54
+
55
+
56
+ def is_effective_value(value) -> bool:
13
57
  """
14
- 同步版本的模型管理客户端,用于非异步环境(如 Flask、Django、脚本)。
15
- 内部封装 AsyncTamarModelClient 并处理事件循环兼容性。
58
+ 递归判断value是否是有意义的有效值
16
59
  """
17
- _loop: Optional[asyncio.AbstractEventLoop] = None
60
+ if value is None or value is NOT_GIVEN:
61
+ return False
62
+
63
+ if isinstance(value, str):
64
+ return value.strip() != ""
65
+
66
+ if isinstance(value, bytes):
67
+ return len(value) > 0
68
+
69
+ if isinstance(value, dict):
70
+ for v in value.values():
71
+ if is_effective_value(v):
72
+ return True
73
+ return False
74
+
75
+ if isinstance(value, list):
76
+ for item in value:
77
+ if is_effective_value(item):
78
+ return True
79
+ return False
80
+
81
+ return True # 其他类型(int/float/bool)只要不是None就算有效
82
+
83
+
84
def serialize_value(value):
    """Recursively convert ``value`` into plain, serialisable data.

    Values that ``is_effective_value`` rejects become None. Pydantic models
    (and any object exposing a callable ``dict``) are dumped and processed
    again, dicts and other iterables are converted element by element, and
    bytes are rendered as ``"bytes:<base64>"``. Everything else is returned
    unchanged.
    """
    if not is_effective_value(value):
        return None

    if isinstance(value, BaseModel):
        dumped = value.model_dump()
        return serialize_value(dumped)

    if hasattr(value, "dict") and callable(value.dict):
        return serialize_value(value.dict())

    if isinstance(value, dict):
        converted = {}
        for key, item in value.items():
            converted[key] = serialize_value(item)
        return converted

    # bytes and str are iterable but must not be expanded element-wise.
    if isinstance(value, bytes):
        encoded = base64.b64encode(value).decode('utf-8')
        return f"bytes:{encoded}"
    if isinstance(value, str):
        return value

    if isinstance(value, (list, Iterable)):
        return [serialize_value(element) for element in value]

    return value
99
+
100
+
101
+ from typing import Any
18
102
 
103
+
104
def remove_none_from_dict(data: Any) -> Any:
    """Recursively drop dict entries whose value is None.

    Dicts are rebuilt without their None-valued keys, and lists are walked so
    that nested dicts are cleaned too. None items that sit directly inside a
    list are kept, since only dict fields are removed. Any other value is
    returned unchanged. The input is never mutated; new containers are
    returned.
    """
    if isinstance(data, dict):
        return {
            key: remove_none_from_dict(value)
            for key, value in data.items()
            if value is not None
        }
    if isinstance(data, list):
        return [remove_none_from_dict(item) for item in data]
    return data
120
+
121
+
122
def generate_request_id():
    """Return a new random (version 4) UUID, rendered as a string, to use as a request ID."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
125
+
126
+
127
def set_request_id(request_id: str):
    """Bind ``request_id`` to the current context so log records can carry it.

    Returns:
        The token produced by ``ContextVar.set``. Passing it to
        ``_request_id.reset(token)`` restores the previous request ID once the
        request is finished. Callers that ignore the return value behave
        exactly as before.
    """
    return _request_id.set(request_id)
130
+
131
+
132
+ class TamarModelClient:
19
133
  def __init__(
20
134
  self,
21
135
  server_address: Optional[str] = None,
@@ -23,89 +137,370 @@ class TamarModelClient:
23
137
  jwt_token: Optional[str] = None,
24
138
  default_payload: Optional[dict] = None,
25
139
  token_expires_in: int = 3600,
26
- max_retries: int = 3,
27
- retry_delay: float = 1.0,
140
+ max_retries: Optional[int] = None, # 最大重试次数
141
+ retry_delay: Optional[float] = None, # 初始重试延迟(秒)
28
142
  ):
29
- # 初始化全局事件循环,仅创建一次
30
- if not TamarModelClient._loop:
143
+ self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
144
+ if not self.server_address:
145
+ raise ValueError("Server address must be provided via argument or environment variable.")
146
+ self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))
147
+
148
+ # JWT 配置
149
+ self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
150
+ self.jwt_handler = JWTAuthHandler(self.jwt_secret_key)
151
+ self.jwt_token = jwt_token # 用户传入的 Token(可选)
152
+ self.default_payload = default_payload
153
+ self.token_expires_in = token_expires_in
154
+
155
+ # === TLS/Authority 配置 ===
156
+ self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
157
+ self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")
158
+
159
+ # === 重试配置 ===
160
+ self.max_retries = max_retries if max_retries is not None else int(
161
+ os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
162
+ self.retry_delay = retry_delay if retry_delay is not None else float(
163
+ os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))
164
+
165
+ # === gRPC 通道相关 ===
166
+ self.channel: Optional[grpc.Channel] = None
167
+ self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
168
+ self._closed = False
169
+
170
+ def _retry_request(self, func, *args, **kwargs):
171
+ retry_count = 0
172
+ while retry_count < self.max_retries:
31
173
  try:
32
- TamarModelClient._loop = asyncio.get_running_loop()
33
- except RuntimeError:
34
- TamarModelClient._loop = asyncio.new_event_loop()
35
- asyncio.set_event_loop(TamarModelClient._loop)
36
-
37
- self._loop = TamarModelClient._loop
38
-
39
- self._async_client = AsyncTamarModelClient(
40
- server_address=server_address,
41
- jwt_secret_key=jwt_secret_key,
42
- jwt_token=jwt_token,
43
- default_payload=default_payload,
44
- token_expires_in=token_expires_in,
45
- max_retries=max_retries,
46
- retry_delay=retry_delay,
47
- )
48
- atexit.register(self._safe_sync_close)
49
-
50
- def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None) -> Union[
51
- ModelResponse, Iterator[ModelResponse]]:
52
- """
53
- 同步调用单个模型任务
54
- """
55
- if model_request.stream:
56
- async def stream():
57
- async for r in await self._async_client.invoke(model_request, timeout=timeout):
58
- yield r
174
+ return func(*args, **kwargs)
175
+ except (grpc.RpcError) as e:
176
+ if e.code() in {grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED}:
177
+ retry_count += 1
178
+ logger.error(f"❌ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...")
179
+ if retry_count < self.max_retries:
180
+ delay = self.retry_delay * (2 ** (retry_count - 1))
181
+ time.sleep(delay)
182
+ else:
183
+ logger.error(f"❌ Max retry reached for {e.code()}")
184
+ raise
185
+ else:
186
+ logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True)
187
+ raise
188
+
189
def _build_auth_metadata(self, request_id: str) -> list:
    """Build the gRPC call metadata for one request.

    The result always carries ``x-request-id``. When a JWT handler is
    configured, a fresh token is minted from ``default_payload`` on every call
    and stored on ``self.jwt_token``. The ``authorization`` entry is added
    only when a token is actually available, so a caller-supplied token is
    still sent without a handler and a literal ``Bearer None`` is never
    emitted.

    Args:
        request_id: Identifier propagated to the server in ``x-request-id``.

    Returns:
        A list of ``(key, value)`` metadata tuples.
    """
    metadata = [("x-request-id", request_id)]
    if self.jwt_handler:
        self.jwt_token = self.jwt_handler.encode_token(self.default_payload, expires_in=self.token_expires_in)
    if self.jwt_token:
        metadata.append(("authorization", f"Bearer {self.jwt_token}"))
    return metadata
195
+
196
+ def _ensure_initialized(self):
197
+ """初始化 gRPC 通道,支持 TLS 与重试机制"""
198
+ if self.channel and self.stub:
199
+ return
200
+
201
+ retry_count = 0
202
+ options = [
203
+ ('grpc.max_send_message_length', MAX_MESSAGE_LENGTH),
204
+ ('grpc.max_receive_message_length', MAX_MESSAGE_LENGTH),
205
+ ('grpc.keepalive_permit_without_calls', True) # 即使没有活跃请求也保持连接
206
+ ]
207
+ if self.default_authority:
208
+ options.append(("grpc.default_authority", self.default_authority))
209
+
210
+ while retry_count <= self.max_retries:
211
+ try:
212
+ if self.use_tls:
213
+ credentials = grpc.ssl_channel_credentials()
214
+ self.channel = grpc.secure_channel(
215
+ self.server_address,
216
+ credentials,
217
+ options=options
218
+ )
219
+ logger.info("🔐 Using secure gRPC channel (TLS enabled)")
220
+ else:
221
+ self.channel = grpc.insecure_channel(
222
+ self.server_address,
223
+ options=options
224
+ )
225
+ logger.info("🔓 Using insecure gRPC channel (TLS disabled)")
226
+
227
+ # Wait for the channel to be ready (synchronously)
228
+ grpc.channel_ready_future(self.channel).result() # This is blocking in sync mode
229
+
230
+ self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
231
+ logger.info(f"✅ gRPC channel initialized to {self.server_address}")
232
+ return
233
+ except grpc.FutureTimeoutError as e:
234
+ logger.error(f"❌ gRPC channel initialization timed out: {str(e)}", exc_info=True)
235
+ except grpc.RpcError as e:
236
+ logger.error(f"❌ gRPC channel initialization failed: {str(e)}", exc_info=True)
237
+ except Exception as e:
238
+ logger.error(f"❌ Unexpected error during channel initialization: {str(e)}", exc_info=True)
239
+
240
+ retry_count += 1
241
+ if retry_count > self.max_retries:
242
+ logger.error(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.", exc_info=True)
243
+ raise ConnectionError(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.")
244
+
245
+ # 指数退避:延迟时间 = retry_delay * (2 ^ (retry_count - 1))
246
+ delay = self.retry_delay * (2 ** (retry_count - 1))
247
+ logger.info(f"🚀 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...")
248
+ time.sleep(delay) # Blocking sleep in sync version
249
+
250
+ def _stream(self, request, metadata, invoke_timeout) -> Iterator[ModelResponse]:
251
+ for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
252
+ yield ModelResponse(
253
+ content=response.content,
254
+ usage=json.loads(response.usage) if response.usage else None,
255
+ error=response.error or None,
256
+ raw_response=json.loads(response.raw_response) if response.raw_response else None,
257
+ request_id=response.request_id if response.request_id else None,
258
+ )
59
259
 
60
- return self._sync_wrap_async_generator(stream())
61
- return self._run_async(self._async_client.invoke(model_request, timeout=timeout))
260
def _invoke_request(self, request, metadata, invoke_timeout):
    """Issue a non-streaming call and return its first response message.

    The result of ``stub.Invoke`` is iterated like a stream, and only the
    first message is converted into a ``ModelResponse``.

    Returns:
        The converted first message, or None when the call yields no message
        at all.
    """
    # Keep the stream and the per-message name distinct; the loop variable
    # used to rebind the very object being iterated.
    responses = self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout)
    for message in responses:
        return ModelResponse(
            content=message.content,
            usage=json.loads(message.usage) if message.usage else None,
            error=message.error or None,
            raw_response=json.loads(message.raw_response) if message.raw_response else None,
            request_id=message.request_id if message.request_id else None,
        )
    # The stream ended without producing a message.
    return None
62
270
 
63
- def invoke_batch(self, batch_model_request: BatchModelRequest,
64
- timeout: Optional[float] = None) -> BatchModelResponse:
271
+ def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None, request_id: Optional[str] = None) -> \
272
+ Union[ModelResponse, Iterator[ModelResponse]]:
65
273
  """
66
- 同步调用批量模型任务
274
+ 通用调用模型方法。
275
+
276
+ Args:
277
+ model_request: ModelRequest 对象,包含请求参数。
278
+ timeout: Optional[float]
279
+ request_id: Optional[str]
280
+ Yields:
281
+ ModelResponse: 支持流式或非流式的模型响应
282
+
283
+ Raises:
284
+ ValidationError: 输入验证失败。
285
+ ConnectionError: 连接服务端失败。
67
286
  """
68
- return self._run_async(self._async_client.invoke_batch(batch_model_request, timeout=timeout))
287
+ self._ensure_initialized()
69
288
 
70
- def close(self):
71
- """手动关闭 gRPC 通道"""
72
- self._run_async(self._async_client.close())
289
+ if not self.default_payload:
290
+ self.default_payload = {
291
+ "org_id": model_request.user_context.org_id or "",
292
+ "user_id": model_request.user_context.user_id or ""
293
+ }
73
294
 
74
- def _safe_sync_close(self):
75
- """退出时自动关闭"""
295
+ if not request_id:
296
+ request_id = generate_request_id() # 生成一个新的 request_id
297
+ set_request_id(request_id) # 设置当前请求的 request_id
298
+ metadata = self._build_auth_metadata(request_id) # 将 request_id 加入到请求头
299
+
300
+ # 记录开始日志
301
+ logger.info(
302
+ f"🔵 Request Start | request_id: {request_id} | provider: {model_request.provider} | invoke_type: {model_request.invoke_type} | model_request: {model_request}")
303
+
304
+ # 动态根据 provider/invoke_type 决定使用哪个 input 字段
76
305
  try:
77
- self._run_async(self._async_client.close())
78
- logger.info("✅ gRPC channel closed at exit")
306
+ # 选择需要校验的字段集合
307
+ # 动态分支逻辑
308
+ match (model_request.provider, model_request.invoke_type):
309
+ case (ProviderType.GOOGLE, InvokeType.GENERATION):
310
+ allowed_fields = GoogleGenAiInput.model_fields.keys()
311
+ case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
312
+ allowed_fields = GoogleVertexAIImagesInput.model_fields.keys()
313
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
314
+ allowed_fields = OpenAIResponsesInput.model_fields.keys()
315
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
316
+ allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
317
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
318
+ allowed_fields = OpenAIImagesInput.model_fields.keys()
319
+ case _:
320
+ raise ValueError(
321
+ f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")
322
+
323
+ # 将 ModelRequest 转 dict,过滤只保留 base + allowed 的字段
324
+ model_request_dict = model_request.model_dump(exclude_unset=True)
325
+
326
+ grpc_request_kwargs = {}
327
+ for field in allowed_fields:
328
+ if field in model_request_dict:
329
+ value = model_request_dict[field]
330
+
331
+ # 跳过无效的值
332
+ if not is_effective_value(value):
333
+ continue
334
+
335
+ # 序列化grpc不支持的类型
336
+ grpc_request_kwargs[field] = serialize_value(value)
337
+
338
+ # 清理 serialize后的 grpc_request_kwargs
339
+ grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
340
+
341
+ request = model_service_pb2.ModelRequestItem(
342
+ provider=model_request.provider.value,
343
+ channel=model_request.channel.value,
344
+ invoke_type=model_request.invoke_type.value,
345
+ stream=model_request.stream or False,
346
+ org_id=model_request.user_context.org_id or "",
347
+ user_id=model_request.user_context.user_id or "",
348
+ client_type=model_request.user_context.client_type or "",
349
+ extra=grpc_request_kwargs
350
+ )
351
+
79
352
  except Exception as e:
80
- logger.warning(f" gRPC channel close failed at exit: {e}")
353
+ raise ValueError(f"构建请求失败: {str(e)}") from e
81
354
 
82
- def _run_async(self, coro):
83
- """统一运行协程,兼容已存在的事件循环"""
84
355
  try:
85
- loop = asyncio.get_running_loop()
86
- import nest_asyncio
87
- nest_asyncio.apply()
88
- return loop.run_until_complete(coro)
89
- except RuntimeError:
90
- return self._loop.run_until_complete(coro)
91
-
92
- def _sync_wrap_async_generator(self, async_gen_func):
356
+ invoke_timeout = timeout or self.default_invoke_timeout
357
+ if model_request.stream:
358
+ return self._retry_request(self._stream, request, metadata, invoke_timeout)
359
+ else:
360
+ return self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
361
+ except grpc.RpcError as e:
362
+ error_message = f"❌ Invoke gRPC failed: {str(e)}"
363
+ logger.error(error_message, exc_info=True)
364
+ raise e
365
+ except Exception as e:
366
+ error_message = f"❌ Invoke other error: {str(e)}"
367
+ logger.error(error_message, exc_info=True)
368
+ raise e
369
+
370
+ def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None,
371
+ request_id: Optional[str] = None) -> BatchModelResponse:
93
372
  """
94
- 将 async generator 转换为同步 generator,逐条 yield。
373
+ 批量模型调用接口
374
+
375
+ Args:
376
+ batch_request_model: 多条 BatchModelRequest 输入
377
+ timeout: 调用超时,单位秒
378
+ request_id: 请求id
379
+ Returns:
380
+ BatchModelResponse: 批量请求的结果
95
381
  """
96
- loop = self._loop
97
382
 
98
- # 创建异步生成器对象
99
- agen = async_gen_func
383
+ self._ensure_initialized()
100
384
 
101
- class SyncGenerator:
102
- def __iter__(self_inner):
103
- return self_inner
385
+ if not self.default_payload:
386
+ self.default_payload = {
387
+ "org_id": batch_request_model.user_context.org_id or "",
388
+ "user_id": batch_request_model.user_context.user_id or ""
389
+ }
390
+
391
+ if not request_id:
392
+ request_id = generate_request_id() # 生成一个新的 request_id
393
+ set_request_id(request_id) # 设置当前请求的 request_id
394
+ metadata = self._build_auth_metadata(request_id) # 将 request_id 加入到请求头
395
+
396
+ # 记录开始日志
397
+ logger.info(
398
+ f"🔵 Batch Request Start | request_id: {request_id} | batch_size: {len(batch_request_model.items)} | batch_request_model: {batch_request_model}")
399
+
400
+ # 构造批量请求
401
+ items = []
402
+ for model_request_item in batch_request_model.items:
403
+ # 动态根据 provider/invoke_type 决定使用哪个 input 字段
404
+ try:
405
+ match (model_request_item.provider, model_request_item.invoke_type):
406
+ case (ProviderType.GOOGLE, InvokeType.GENERATION):
407
+ allowed_fields = GoogleGenAiInput.model_fields.keys()
408
+ case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
409
+ allowed_fields = GoogleVertexAIImagesInput.model_fields.keys()
410
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
411
+ allowed_fields = OpenAIResponsesInput.model_fields.keys()
412
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
413
+ allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
414
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
415
+ allowed_fields = OpenAIImagesInput.model_fields.keys()
416
+ case _:
417
+ raise ValueError(
418
+ f"Unsupported provider/invoke_type combination: {model_request_item.provider} + {model_request_item.invoke_type}")
419
+
420
+ # 将 ModelRequest 转 dict,过滤只保留 base + allowed 的字段
421
+ model_request_dict = model_request_item.model_dump(exclude_unset=True)
422
+
423
+ grpc_request_kwargs = {}
424
+ for field in allowed_fields:
425
+ if field in model_request_dict:
426
+ value = model_request_dict[field]
427
+
428
+ # 跳过无效的值
429
+ if not is_effective_value(value):
430
+ continue
431
+
432
+ # 序列化grpc不支持的类型
433
+ grpc_request_kwargs[field] = serialize_value(value)
434
+
435
+ # 清理 serialize后的 grpc_request_kwargs
436
+ grpc_request_kwargs = remove_none_from_dict(grpc_request_kwargs)
437
+
438
+ items.append(model_service_pb2.ModelRequestItem(
439
+ provider=model_request_item.provider.value,
440
+ channel=model_request_item.channel.value,
441
+ invoke_type=model_request_item.invoke_type.value,
442
+ stream=model_request_item.stream or False,
443
+ custom_id=model_request_item.custom_id or "",
444
+ priority=model_request_item.priority or 1,
445
+ org_id=batch_request_model.user_context.org_id or "",
446
+ user_id=batch_request_model.user_context.user_id or "",
447
+ client_type=batch_request_model.user_context.client_type or "",
448
+ extra=grpc_request_kwargs,
449
+ ))
450
+
451
+ except Exception as e:
452
+ raise ValueError(f"构建请求失败: {str(e)},item={model_request_item.custom_id}") from e
453
+
454
+ try:
455
+ # 超时处理逻辑
456
+ invoke_timeout = timeout or self.default_invoke_timeout
457
+
458
+ # 调用 gRPC 接口
459
+ response = self._retry_request(self.stub.BatchInvoke, model_service_pb2.ModelRequest(items=items),
460
+ timeout=invoke_timeout, metadata=metadata)
461
+
462
+ result = []
463
+ for res_item in response.items:
464
+ result.append(ModelResponse(
465
+ content=res_item.content,
466
+ usage=json.loads(res_item.usage) if res_item.usage else None,
467
+ raw_response=json.loads(res_item.raw_response) if res_item.raw_response else None,
468
+ error=res_item.error or None,
469
+ custom_id=res_item.custom_id if res_item.custom_id else None
470
+ ))
471
+ return BatchModelResponse(
472
+ request_id=response.request_id if response.request_id else None,
473
+ responses=result
474
+ )
475
+ except grpc.RpcError as e:
476
+ error_message = f"❌ BatchInvoke gRPC failed: {str(e)}"
477
+ logger.error(error_message, exc_info=True)
478
+ raise e
479
+ except Exception as e:
480
+ error_message = f"❌ BatchInvoke other error: {str(e)}"
481
+ logger.error(error_message, exc_info=True)
482
+ raise e
483
+
484
+ def close(self):
485
+ """关闭 gRPC 通道"""
486
+ if self.channel and not self._closed:
487
+ self.channel.close()
488
+ self._closed = True
489
+ logger.info("✅ gRPC channel closed")
490
+
491
+ def _safe_sync_close(self):
492
+ """进程退出时自动关闭 channel(事件循环处理兼容)"""
493
+ if self.channel and not self._closed:
494
+ try:
495
+ self.close() # 直接调用关闭方法
496
+ except Exception as e:
497
+ logger.error(f"❌ gRPC channel close failed at exit: {e}")
104
498
 
105
- def __next__(self_inner):
106
- try:
107
- return loop.run_until_complete(agen.__anext__())
108
- except StopAsyncIteration:
109
- raise StopIteration
499
+ def __enter__(self):
500
+ """同步初始化连接"""
501
+ self._ensure_initialized()
502
+ return self
110
503
 
111
- return SyncGenerator()
504
+ def __exit__(self, exc_type, exc_val, exc_tb):
505
+ """同步关闭连接"""
506
+ self.close()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tamar-model-client
3
- Version: 0.1.14
3
+ Version: 0.1.16
4
4
  Summary: A Python SDK for interacting with the Model Manager gRPC service
5
5
  Home-page: http://gitlab.tamaredge.top/project-tap/AgentOS/model-manager-client
6
6
  Author: Oscar Ou
@@ -273,13 +273,13 @@ async def main():
273
273
  )
274
274
 
275
275
  # 发送请求并获取响应
276
- response = await client.invoke(request_data)
277
- if response.error:
278
- print(f"错误: {response.error}")
279
- else:
280
- print(f"响应: {response.content}")
281
- if response.usage:
282
- print(f"Token 使用情况: {response.usage}")
276
+ async for r in await client.invoke(model_request):
277
+ if r.error:
278
+ print(f"错误: {r.error}")
279
+ else:
280
+ print(f"响应: {r.content}")
281
+ if r.usage:
282
+ print(f"Token 使用情况: {r.usage}")
283
283
 
284
284
 
285
285
  # 运行异步示例
@@ -531,7 +531,7 @@ python make_grpc.py
531
531
  ### 部署到 pip
532
532
  ```bash
533
533
  python setup.py sdist bdist_wheel
534
- twine check dist/*
534
+ twine upload dist/*
535
535
 
536
536
  ```
537
537
 
@@ -1,8 +1,8 @@
1
1
  tamar_model_client/__init__.py,sha256=LMECAuDARWHV1XzH3msoDXcyurS2eihRQmBy26_PUE0,328
2
- tamar_model_client/async_client.py,sha256=gmZ2xMHO_F-Vtg3OK7B_yf-gtI-WH2NU2LzC6YO_t7k,19649
2
+ tamar_model_client/async_client.py,sha256=K14GigYdcsHQg83PP1YH3wxxZEUwvFlIFMWdFfegnhc,25655
3
3
  tamar_model_client/auth.py,sha256=gbwW5Aakeb49PMbmYvrYlVx1mfyn1LEDJ4qQVs-9DA4,438
4
4
  tamar_model_client/exceptions.py,sha256=jYU494OU_NeIa4X393V-Y73mTNm0JZ9yZApnlOM9CJQ,332
5
- tamar_model_client/sync_client.py,sha256=o8b20fQUvtMq1gWax3_dfOpputYT4l9pRTz6cHdB0lg,4006
5
+ tamar_model_client/sync_client.py,sha256=B4itGuFy1T6g2pnC-95RbaaOqtRIYLeW9eah-CRFRM0,22486
6
6
  tamar_model_client/enums/__init__.py,sha256=3cYYn8ztNGBa_pI_5JGRVYf2QX8fkBVWdjID1PLvoBQ,182
7
7
  tamar_model_client/enums/channel.py,sha256=wCzX579nNpTtwzGeS6S3Ls0UzVAgsOlfy4fXMzQTCAw,199
8
8
  tamar_model_client/enums/invoke.py,sha256=WufImoN_87ZjGyzYitZkhNNFefWJehKfLtyP-DTBYlA,267
@@ -13,7 +13,7 @@ tamar_model_client/generated/model_service_pb2_grpc.py,sha256=k4tIbp3XBxdyuOVR18
13
13
  tamar_model_client/schemas/__init__.py,sha256=AxuI-TcvA4OMTj2FtK4wAItvz9LrK_293pu3cmMLE7k,394
14
14
  tamar_model_client/schemas/inputs.py,sha256=AlvjTRp_UGnbmqzv4OJ3RjH4UGErzSNfKS8Puj6oEXQ,19088
15
15
  tamar_model_client/schemas/outputs.py,sha256=M_fcqUtXPJnfiLabHlyA8BorlC5pYkf5KLjXO1ysKIQ,1031
16
- tamar_model_client-0.1.14.dist-info/METADATA,sha256=XB9fzmRzMJM2UL8udezQf6PHy103GgtwICGmlFUnn4U,16566
17
- tamar_model_client-0.1.14.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
18
- tamar_model_client-0.1.14.dist-info/top_level.txt,sha256=_LfDhPv_fvON0PoZgQuo4M7EjoWtxPRoQOBJziJmip8,19
19
- tamar_model_client-0.1.14.dist-info/RECORD,,
16
+ tamar_model_client-0.1.16.dist-info/METADATA,sha256=YaPEPgdIVcJVSZ55rzx-G5TtjHTT0teXJspOz5O3vyE,16562
17
+ tamar_model_client-0.1.16.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
18
+ tamar_model_client-0.1.16.dist-info/top_level.txt,sha256=_LfDhPv_fvON0PoZgQuo4M7EjoWtxPRoQOBJziJmip8,19
19
+ tamar_model_client-0.1.16.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.0.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5