tamar-model-client 0.1.15__tar.gz → 0.1.17__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/PKG-INFO +9 -9
  2. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/README.md +8 -8
  3. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/setup.py +1 -1
  4. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/async_client.py +44 -11
  5. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/enums/invoke.py +2 -1
  6. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/schemas/inputs.py +54 -123
  7. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/sync_client.py +16 -15
  8. tamar_model_client-0.1.17/tamar_model_client/utils.py +118 -0
  9. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client.egg-info/PKG-INFO +9 -9
  10. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client.egg-info/SOURCES.txt +1 -0
  11. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/setup.cfg +0 -0
  12. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/__init__.py +0 -0
  13. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/auth.py +0 -0
  14. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/enums/__init__.py +0 -0
  15. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/enums/channel.py +0 -0
  16. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/enums/providers.py +0 -0
  17. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/exceptions.py +0 -0
  18. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/generated/__init__.py +0 -0
  19. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/generated/model_service_pb2.py +0 -0
  20. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/generated/model_service_pb2_grpc.py +0 -0
  21. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/schemas/__init__.py +0 -0
  22. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client/schemas/outputs.py +0 -0
  23. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client.egg-info/dependency_links.txt +0 -0
  24. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client.egg-info/requires.txt +0 -0
  25. {tamar_model_client-0.1.15 → tamar_model_client-0.1.17}/tamar_model_client.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tamar-model-client
3
- Version: 0.1.15
3
+ Version: 0.1.17
4
4
  Summary: A Python SDK for interacting with the Model Manager gRPC service
5
5
  Home-page: http://gitlab.tamaredge.top/project-tap/AgentOS/model-manager-client
6
6
  Author: Oscar Ou
@@ -273,13 +273,13 @@ async def main():
273
273
  )
274
274
 
275
275
  # 发送请求并获取响应
276
- response = await client.invoke(request_data)
277
- if response.error:
278
- print(f"错误: {response.error}")
279
- else:
280
- print(f"响应: {response.content}")
281
- if response.usage:
282
- print(f"Token 使用情况: {response.usage}")
276
+ async for r in await client.invoke(model_request):
277
+ if r.error:
278
+ print(f"错误: {r.error}")
279
+ else:
280
+ print(f"响应: {r.content}")
281
+ if r.usage:
282
+ print(f"Token 使用情况: {r.usage}")
283
283
 
284
284
 
285
285
  # 运行异步示例
@@ -531,7 +531,7 @@ python make_grpc.py
531
531
  ### 部署到 pip
532
532
  ```bash
533
533
  python setup.py sdist bdist_wheel
534
- twine check dist/*
534
+ twine upload dist/*
535
535
 
536
536
  ```
537
537
 
@@ -243,13 +243,13 @@ async def main():
243
243
  )
244
244
 
245
245
  # 发送请求并获取响应
246
- response = await client.invoke(request_data)
247
- if response.error:
248
- print(f"错误: {response.error}")
249
- else:
250
- print(f"响应: {response.content}")
251
- if response.usage:
252
- print(f"Token 使用情况: {response.usage}")
246
+ async for r in await client.invoke(model_request):
247
+ if r.error:
248
+ print(f"错误: {r.error}")
249
+ else:
250
+ print(f"响应: {r.content}")
251
+ if r.usage:
252
+ print(f"Token 使用情况: {r.usage}")
253
253
 
254
254
 
255
255
  # 运行异步示例
@@ -501,7 +501,7 @@ python make_grpc.py
501
501
  ### 部署到 pip
502
502
  ```bash
503
503
  python setup.py sdist bdist_wheel
504
- twine check dist/*
504
+ twine upload dist/*
505
505
 
506
506
  ```
507
507
 
@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
2
2
 
3
3
  setup(
4
4
  name="tamar-model-client",
5
- version="0.1.15",
5
+ version="0.1.17",
6
6
  description="A Python SDK for interacting with the Model Manager gRPC service",
7
7
  author="Oscar Ou",
8
8
  author_email="oscar.ou@tamaredge.ai",
@@ -19,7 +19,7 @@ from .exceptions import ConnectionError
19
19
  from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
20
20
  from .generated import model_service_pb2, model_service_pb2_grpc
21
21
  from .schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput, \
22
- GoogleVertexAIImagesInput, OpenAIImagesInput
22
+ GoogleVertexAIImagesInput, OpenAIImagesInput, OpenAIImagesEditInput
23
23
 
24
24
  logger = logging.getLogger(__name__)
25
25
 
@@ -203,6 +203,37 @@ class AsyncTamarModelClient:
203
203
  logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True)
204
204
  raise
205
205
 
206
+ async def _retry_request_stream(self, func, *args, **kwargs):
207
+ retry_count = 0
208
+ while retry_count < self.max_retries:
209
+ try:
210
+ return func(*args, **kwargs)
211
+ except (grpc.aio.AioRpcError, grpc.RpcError) as e:
212
+ # 对于取消的情况进行指数退避重试
213
+ if isinstance(e, grpc.aio.AioRpcError) and e.code() == grpc.StatusCode.CANCELLED:
214
+ retry_count += 1
215
+ logger.warning(f"❌ RPC cancelled, retrying {retry_count}/{self.max_retries}...")
216
+ if retry_count < self.max_retries:
217
+ delay = self.retry_delay * (2 ** (retry_count - 1))
218
+ await asyncio.sleep(delay)
219
+ else:
220
+ logger.error("❌ Max retry reached for CANCELLED")
221
+ raise
222
+ # 针对其他 RPC 错误类型,如暂时的连接问题、服务器超时等
223
+ elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
224
+ grpc.StatusCode.DEADLINE_EXCEEDED}:
225
+ retry_count += 1
226
+ logger.warning(f"❌ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...")
227
+ if retry_count < self.max_retries:
228
+ delay = self.retry_delay * (2 ** (retry_count - 1))
229
+ await asyncio.sleep(delay)
230
+ else:
231
+ logger.error(f"❌ Max retry reached for {e.code()}")
232
+ raise
233
+ else:
234
+ logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True)
235
+ raise
236
+
206
237
  def _build_auth_metadata(self, request_id: str) -> list:
207
238
  # if not self.jwt_token and self.jwt_handler:
208
239
  # 更改为每次请求都生成一次token
@@ -263,25 +294,23 @@ class AsyncTamarModelClient:
263
294
  logger.info(f"🚀 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...")
264
295
  await asyncio.sleep(delay)
265
296
 
266
- async def _stream_inner(self, model_request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
267
- """Inner function to handle the actual streaming gRPC call."""
268
- async for response in self.stub.Invoke(model_request, metadata=metadata, timeout=invoke_timeout):
297
+ async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
298
+ async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
269
299
  yield ModelResponse(
270
300
  content=response.content,
271
301
  usage=json.loads(response.usage) if response.usage else None,
272
- raw_response=json.loads(response.raw_response) if response.raw_response else None,
273
302
  error=response.error or None,
303
+ raw_response=json.loads(response.raw_response) if response.raw_response else None,
304
+ request_id=response.request_id if response.request_id else None,
274
305
  )
275
306
 
276
- async def _stream(self, model_request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
277
- return await self._retry_request(self._stream_inner, model_request, metadata, invoke_timeout)
278
-
279
307
  async def _invoke_request(self, request, metadata, invoke_timeout):
280
308
  async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
281
309
  return ModelResponse(
282
310
  content=response.content,
283
311
  usage=json.loads(response.usage) if response.usage else None,
284
312
  error=response.error or None,
313
+ raw_response=json.loads(response.raw_response) if response.raw_response else None,
285
314
  request_id=response.request_id if response.request_id else None,
286
315
  )
287
316
 
@@ -317,7 +346,7 @@ class AsyncTamarModelClient:
317
346
 
318
347
  # 记录开始日志
319
348
  logger.info(
320
- f"🔵 Request Start | request_id: {request_id} | provider: {model_request.provider} | invoke_type: {model_request.invoke_type} | model_request: {model_request}")
349
+ f"🔵 Request Start | request_id: {request_id} | provider: {model_request.provider} | invoke_type: {model_request.invoke_type}")
321
350
 
322
351
  # 动态根据 provider/invoke_type 决定使用哪个 input 字段
323
352
  try:
@@ -334,6 +363,8 @@ class AsyncTamarModelClient:
334
363
  allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
335
364
  case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
336
365
  allowed_fields = OpenAIImagesInput.model_fields.keys()
366
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
367
+ allowed_fields = OpenAIImagesEditInput.model_fields.keys()
337
368
  case _:
338
369
  raise ValueError(
339
370
  f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")
@@ -373,7 +404,7 @@ class AsyncTamarModelClient:
373
404
  try:
374
405
  invoke_timeout = timeout or self.default_invoke_timeout
375
406
  if model_request.stream:
376
- return await self._stream(request, metadata, invoke_timeout)
407
+ return await self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
377
408
  else:
378
409
  return await self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
379
410
  except grpc.RpcError as e:
@@ -414,7 +445,7 @@ class AsyncTamarModelClient:
414
445
 
415
446
  # 记录开始日志
416
447
  logger.info(
417
- f"🔵 Batch Request Start | request_id: {request_id} | batch_size: {len(batch_request_model.items)} | batch_request_model: {batch_request_model}")
448
+ f"🔵 Batch Request Start | request_id: {request_id} | batch_size: {len(batch_request_model.items)}")
418
449
 
419
450
  # 构造批量请求
420
451
  items = []
@@ -432,6 +463,8 @@ class AsyncTamarModelClient:
432
463
  allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
433
464
  case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
434
465
  allowed_fields = OpenAIImagesInput.model_fields.keys()
466
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
467
+ allowed_fields = OpenAIImagesEditInput.model_fields.keys()
435
468
  case _:
436
469
  raise ValueError(
437
470
  f"Unsupported provider/invoke_type combination: {model_request_item.provider} + {model_request_item.invoke_type}")
@@ -7,4 +7,5 @@ class InvokeType(str, Enum):
7
7
  CHAT_COMPLETIONS = "chat-completions"
8
8
 
9
9
  GENERATION = "generation" # 生成类,默认的值
10
- IMAGE_GENERATION = "image-generation"
10
+ IMAGE_GENERATION = "image-generation"
11
+ IMAGE_EDIT_GENERATION = "image-edit-generation"
@@ -1,18 +1,22 @@
1
+ import mimetypes
2
+ import os
3
+
1
4
  import httpx
2
5
  from google.genai import types
3
6
  from openai import NotGiven, NOT_GIVEN
4
- from openai._types import Headers, Query, Body
7
+ from openai._types import Headers, Query, Body, FileTypes
5
8
  from openai.types import ChatModel, Metadata, ReasoningEffort, ResponsesModel, Reasoning, ImageModel
6
9
  from openai.types.chat import ChatCompletionMessageParam, ChatCompletionAudioParam, completion_create_params, \
7
10
  ChatCompletionPredictionContentParam, ChatCompletionStreamOptionsParam, ChatCompletionToolChoiceOptionParam, \
8
11
  ChatCompletionToolParam
9
12
  from openai.types.responses import ResponseInputParam, ResponseIncludable, ResponseTextConfigParam, \
10
13
  response_create_params, ToolParam
11
- from pydantic import BaseModel, model_validator
12
- from typing import List, Optional, Union, Iterable, Dict, Literal
14
+ from pydantic import BaseModel, model_validator, field_validator
15
+ from typing import List, Optional, Union, Iterable, Dict, Literal, IO
13
16
 
14
17
  from tamar_model_client.enums import ProviderType, InvokeType
15
18
  from tamar_model_client.enums.channel import Channel
19
+ from tamar_model_client.utils import convert_file_field, validate_fields_by_provider_and_invoke_type
16
20
 
17
21
 
18
22
  class UserContext(BaseModel):
@@ -149,6 +153,29 @@ class OpenAIImagesInput(BaseModel):
149
153
  }
150
154
 
151
155
 
156
+ class OpenAIImagesEditInput(BaseModel):
157
+ image: Union[FileTypes, List[FileTypes]]
158
+ prompt: str
159
+ background: Optional[Literal["transparent", "opaque", "auto"]] | NotGiven = NOT_GIVEN
160
+ mask: FileTypes | NotGiven = NOT_GIVEN
161
+ model: Union[str, ImageModel, None] | NotGiven = NOT_GIVEN
162
+ n: Optional[int] | NotGiven = NOT_GIVEN
163
+ quality: Optional[Literal["standard", "low", "medium", "high", "auto"]] | NotGiven = NOT_GIVEN
164
+ response_format: Optional[Literal["url", "b64_json"]] | NotGiven = NOT_GIVEN
165
+ size: Optional[Literal["256x256", "512x512", "1024x1024", "1536x1024", "1024x1536", "auto"]] | NotGiven = NOT_GIVEN
166
+ user: str | NotGiven = NOT_GIVEN
167
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
168
+ # The extra values given here take precedence over values defined on the client or passed to this method.
169
+ extra_headers: Headers | None = None
170
+ extra_query: Query | None = None
171
+ extra_body: Body | None = None
172
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN
173
+
174
+ model_config = {
175
+ "arbitrary_types_allowed": True
176
+ }
177
+
178
+
152
179
  class BaseRequest(BaseModel):
153
180
  provider: ProviderType # 供应商,如 "openai", "google" 等
154
181
  channel: Channel = Channel.NORMAL # 渠道:不同服务商之前有不同的调用SDK,这里指定是调用哪个SDK
@@ -212,8 +239,11 @@ class ModelRequestInput(BaseRequest):
212
239
  contents: Optional[Union[types.ContentListUnion, types.ContentListUnionDict]] = None
213
240
  config: Optional[types.GenerateContentConfigOrDict] = None
214
241
 
215
- # OpenAIImagesInput + GoogleVertexAIImagesInput 合并字段
242
+ # OpenAIImagesInput + OpenAIImagesEditInput + GoogleVertexAIImagesInput 合并字段
243
+ image: Optional[Union[FileTypes, List[FileTypes]]] = None
216
244
  prompt: Optional[str] = None
245
+ background: Optional[Literal["transparent", "opaque", "auto"]] | NotGiven = NOT_GIVEN
246
+ mask: FileTypes | NotGiven = NOT_GIVEN
217
247
  negative_prompt: Optional[str] = None
218
248
  aspect_ratio: Optional[Literal["1:1", "9:16", "16:9", "4:3", "3:4"]] = None
219
249
  guidance_scale: Optional[float] = None
@@ -223,7 +253,8 @@ class ModelRequestInput(BaseRequest):
223
253
  safety_filter_level: Optional[Literal["block_most", "block_some", "block_few", "block_fewest"]] = None
224
254
  person_generation: Optional[Literal["dont_allow", "allow_adult", "allow_all"]] = None
225
255
  quality: Optional[Literal["standard", "hd"]] | NotGiven = NOT_GIVEN
226
- size: Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]] | NotGiven = NOT_GIVEN
256
+ size: Optional[Literal[
257
+ "auto", "1024x1024", "1536x1024", "1024x1536", "256x256", "512x512", "1792x1024", "1024x1792"]] | NotGiven = NOT_GIVEN
227
258
  style: Optional[Literal["vivid", "natural"]] | NotGiven = NOT_GIVEN
228
259
  number_of_images: Optional[int] = None # Google 用法
229
260
 
@@ -231,71 +262,26 @@ class ModelRequestInput(BaseRequest):
231
262
  "arbitrary_types_allowed": True
232
263
  }
233
264
 
265
+ @field_validator("image", mode="before")
266
+ @classmethod
267
+ def validate_image(cls, v):
268
+ return convert_file_field(v)
269
+
270
+ @field_validator("mask", mode="before")
271
+ @classmethod
272
+ def validate_mask(cls, v):
273
+ return convert_file_field(v)
274
+
234
275
 
235
276
  class ModelRequest(ModelRequestInput):
236
277
  user_context: UserContext # 用户信息
237
278
 
238
279
  @model_validator(mode="after")
239
280
  def validate_by_provider_and_invoke_type(self) -> "ModelRequest":
240
- """根据 provider 和 invoke_type 动态校验具体输入模型字段。"""
241
- # 动态获取 allowed fields
242
- base_allowed = {"provider", "channel", "invoke_type", "user_context"}
243
- google_allowed = base_allowed | set(GoogleGenAiInput.model_fields.keys())
244
- openai_responses_allowed = base_allowed | set(OpenAIResponsesInput.model_fields.keys())
245
- openai_chat_allowed = base_allowed | set(OpenAIChatCompletionsInput.model_fields.keys())
246
- openai_images_allowed = base_allowed | set(OpenAIImagesInput.model_fields.keys())
247
- google_vertexai_images_allowed = base_allowed | set(GoogleVertexAIImagesInput.model_fields.keys())
248
-
249
- # 各模型类型必填字段
250
- google_required_fields = {"model", "contents"}
251
- google_vertexai_image_required_fields = {"model", "prompt"}
252
-
253
- openai_responses_required_fields = {"input", "model"}
254
- openai_chat_required_fields = {"messages", "model"}
255
- openai_image_required_fields = {"prompt"}
256
-
257
- # 选择需要校验的字段集合
258
- # 动态分支逻辑
259
- match (self.provider, self.invoke_type):
260
- case (ProviderType.GOOGLE, InvokeType.GENERATION):
261
- allowed_fields = google_allowed
262
- expected_fields = google_required_fields
263
- case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
264
- allowed_fields = google_vertexai_images_allowed
265
- expected_fields = google_vertexai_image_required_fields
266
- case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
267
- allowed_fields = openai_responses_allowed
268
- expected_fields = openai_responses_required_fields
269
- case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
270
- allowed_fields = openai_chat_allowed
271
- expected_fields = openai_chat_required_fields
272
- case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
273
- allowed_fields = openai_images_allowed
274
- expected_fields = openai_image_required_fields
275
- case _:
276
- raise ValueError(f"Unsupported provider/invoke_type combination: {self.provider} + {self.invoke_type}")
277
-
278
- # 校验必填字段是否缺失
279
- missing = [field for field in expected_fields if getattr(self, field, None) is None]
280
- if missing:
281
- raise ValueError(
282
- f"Missing required fields for provider={self.provider} and invoke_type={self.invoke_type}: {missing}")
283
-
284
- # 检查是否有非法字段
285
- illegal_fields = []
286
- valid_fields = {"provider", "channel", "invoke_type"} if self.invoke_type == InvokeType.IMAGE_GENERATION else {
287
- "provider", "channel", "invoke_type", "stream"}
288
- for name, value in self.__dict__.items():
289
- if name in valid_fields:
290
- continue
291
- if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
292
- illegal_fields.append(name)
293
-
294
- if illegal_fields:
295
- raise ValueError(
296
- f"Unsupported fields for provider={self.provider} and invoke_type={self.invoke_type}: {illegal_fields}")
297
-
298
- return self
281
+ return validate_fields_by_provider_and_invoke_type(
282
+ instance=self,
283
+ extra_allowed_fields={"provider", "channel", "invoke_type", "user_context"},
284
+ )
299
285
 
300
286
 
301
287
  class BatchModelRequestItem(ModelRequestInput):
@@ -304,65 +290,10 @@ class BatchModelRequestItem(ModelRequestInput):
304
290
 
305
291
  @model_validator(mode="after")
306
292
  def validate_by_provider_and_invoke_type(self) -> "BatchModelRequestItem":
307
- """根据 provider 和 invoke_type 动态校验具体输入模型字段。"""
308
- # 动态获取 allowed fields
309
- base_allowed = {"provider", "channel", "invoke_type", "user_context", "custom_id"}
310
- google_allowed = base_allowed | set(GoogleGenAiInput.model_fields.keys())
311
- openai_responses_allowed = base_allowed | set(OpenAIResponsesInput.model_fields.keys())
312
- openai_chat_allowed = base_allowed | set(OpenAIChatCompletionsInput.model_fields.keys())
313
- openai_images_allowed = base_allowed | set(OpenAIImagesInput.model_fields.keys())
314
- google_vertexai_images_allowed = base_allowed | set(GoogleVertexAIImagesInput.model_fields.keys())
315
-
316
- # 各模型类型必填字段
317
- google_required_fields = {"model", "contents"}
318
- google_vertexai_image_required_fields = {"model", "prompt"}
319
-
320
- openai_responses_required_fields = {"input", "model"}
321
- openai_chat_required_fields = {"messages", "model"}
322
- openai_image_required_fields = {"prompt"}
323
-
324
- # 选择需要校验的字段集合
325
- # 动态分支逻辑
326
- match (self.provider, self.invoke_type):
327
- case (ProviderType.GOOGLE, InvokeType.GENERATION):
328
- allowed_fields = google_allowed
329
- expected_fields = google_required_fields
330
- case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
331
- allowed_fields = google_vertexai_images_allowed
332
- expected_fields = google_vertexai_image_required_fields
333
- case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
334
- allowed_fields = openai_responses_allowed
335
- expected_fields = openai_responses_required_fields
336
- case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
337
- allowed_fields = openai_chat_allowed
338
- expected_fields = openai_chat_required_fields
339
- case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
340
- allowed_fields = openai_images_allowed
341
- expected_fields = openai_image_required_fields
342
- case _:
343
- raise ValueError(f"Unsupported provider/invoke_type combination: {self.provider} + {self.invoke_type}")
344
-
345
- # 校验必填字段是否缺失
346
- missing = [field for field in expected_fields if getattr(self, field, None) is None]
347
- if missing:
348
- raise ValueError(
349
- f"Missing required fields for provider={self.provider} and invoke_type={self.invoke_type}: {missing}")
350
-
351
- # 检查是否有非法字段
352
- illegal_fields = []
353
- valid_fields = {"provider", "channel", "invoke_type"} if self.invoke_type == InvokeType.IMAGE_GENERATION else {
354
- "provider", "channel", "invoke_type", "stream"}
355
- for name, value in self.__dict__.items():
356
- if name in valid_fields:
357
- continue
358
- if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
359
- illegal_fields.append(name)
360
-
361
- if illegal_fields:
362
- raise ValueError(
363
- f"Unsupported fields for provider={self.provider} and invoke_type={self.invoke_type}: {illegal_fields}")
364
-
365
- return self
293
+ return validate_fields_by_provider_and_invoke_type(
294
+ instance=self,
295
+ extra_allowed_fields={"provider", "channel", "invoke_type", "user_context", "custom_id"},
296
+ )
366
297
 
367
298
 
368
299
  class BatchModelRequest(BaseModel):
@@ -17,7 +17,7 @@ from .exceptions import ConnectionError
17
17
  from .generated import model_service_pb2, model_service_pb2_grpc
18
18
  from .schemas import BatchModelResponse, ModelResponse
19
19
  from .schemas.inputs import GoogleGenAiInput, GoogleVertexAIImagesInput, OpenAIResponsesInput, \
20
- OpenAIChatCompletionsInput, OpenAIImagesInput, BatchModelRequest, ModelRequest
20
+ OpenAIChatCompletionsInput, OpenAIImagesInput, OpenAIImagesEditInput, BatchModelRequest, ModelRequest
21
21
 
22
22
  logger = logging.getLogger(__name__)
23
23
 
@@ -247,20 +247,16 @@ class TamarModelClient:
247
247
  logger.info(f"🚀 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...")
248
248
  time.sleep(delay) # Blocking sleep in sync version
249
249
 
250
- def _stream_inner(self, model_request, metadata, invoke_timeout) -> Iterator[ModelResponse]:
251
- """Inner function to handle the actual streaming gRPC call."""
252
- response = self.stub.Invoke(model_request, metadata=metadata, timeout=invoke_timeout)
253
- for res in response:
250
+ def _stream(self, request, metadata, invoke_timeout) -> Iterator[ModelResponse]:
251
+ for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
254
252
  yield ModelResponse(
255
- content=res.content,
256
- usage=json.loads(res.usage) if res.usage else None,
257
- raw_response=json.loads(res.raw_response) if res.raw_response else None,
258
- error=res.error or None,
253
+ content=response.content,
254
+ usage=json.loads(response.usage) if response.usage else None,
255
+ error=response.error or None,
256
+ raw_response=json.loads(response.raw_response) if response.raw_response else None,
257
+ request_id=response.request_id if response.request_id else None,
259
258
  )
260
259
 
261
- def _stream(self, model_request, metadata, invoke_timeout) -> Iterator[ModelResponse]:
262
- return self._retry_request(self._stream_inner, model_request, metadata, invoke_timeout)
263
-
264
260
  def _invoke_request(self, request, metadata, invoke_timeout):
265
261
  response = self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout)
266
262
  for response in response:
@@ -268,6 +264,7 @@ class TamarModelClient:
268
264
  content=response.content,
269
265
  usage=json.loads(response.usage) if response.usage else None,
270
266
  error=response.error or None,
267
+ raw_response=json.loads(response.raw_response) if response.raw_response else None,
271
268
  request_id=response.request_id if response.request_id else None,
272
269
  )
273
270
 
@@ -302,7 +299,7 @@ class TamarModelClient:
302
299
 
303
300
  # 记录开始日志
304
301
  logger.info(
305
- f"🔵 Request Start | request_id: {request_id} | provider: {model_request.provider} | invoke_type: {model_request.invoke_type} | model_request: {model_request}")
302
+ f"🔵 Request Start |provider: {model_request.provider} | invoke_type: {model_request.invoke_type}")
306
303
 
307
304
  # 动态根据 provider/invoke_type 决定使用哪个 input 字段
308
305
  try:
@@ -319,6 +316,8 @@ class TamarModelClient:
319
316
  allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
320
317
  case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
321
318
  allowed_fields = OpenAIImagesInput.model_fields.keys()
319
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
320
+ allowed_fields = OpenAIImagesEditInput.model_fields.keys()
322
321
  case _:
323
322
  raise ValueError(
324
323
  f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")
@@ -358,7 +357,7 @@ class TamarModelClient:
358
357
  try:
359
358
  invoke_timeout = timeout or self.default_invoke_timeout
360
359
  if model_request.stream:
361
- return self._stream(request, metadata, invoke_timeout)
360
+ return self._retry_request(self._stream, request, metadata, invoke_timeout)
362
361
  else:
363
362
  return self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
364
363
  except grpc.RpcError as e:
@@ -398,7 +397,7 @@ class TamarModelClient:
398
397
 
399
398
  # 记录开始日志
400
399
  logger.info(
401
- f"🔵 Batch Request Start | request_id: {request_id} | batch_size: {len(batch_request_model.items)} | batch_request_model: {batch_request_model}")
400
+ f"🔵 Batch Request Start | batch_size: {len(batch_request_model.items)}")
402
401
 
403
402
  # 构造批量请求
404
403
  items = []
@@ -416,6 +415,8 @@ class TamarModelClient:
416
415
  allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
417
416
  case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
418
417
  allowed_fields = OpenAIImagesInput.model_fields.keys()
418
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
419
+ allowed_fields = OpenAIImagesEditInput.model_fields.keys()
419
420
  case _:
420
421
  raise ValueError(
421
422
  f"Unsupported provider/invoke_type combination: {model_request_item.provider} + {model_request_item.invoke_type}")
@@ -0,0 +1,118 @@
1
+ from openai import NotGiven
2
+ from pydantic import BaseModel
3
+ from typing import Any
4
+ import os, mimetypes
5
+
6
+ def convert_file_field(value: Any) -> Any:
7
+ def is_file_like(obj):
8
+ return hasattr(obj, "read") and callable(obj.read)
9
+
10
+ def infer_mimetype(filename: str) -> str:
11
+ mime, _ = mimetypes.guess_type(filename)
12
+ return mime or "application/octet-stream"
13
+
14
+ def convert_item(item):
15
+ if is_file_like(item):
16
+ filename = os.path.basename(getattr(item, "name", "file.png"))
17
+ content_type = infer_mimetype(filename)
18
+ content = item.read()
19
+ if hasattr(item, "seek"):
20
+ item.seek(0)
21
+ return (filename, content, content_type)
22
+ elif isinstance(item, tuple):
23
+ parts = list(item)
24
+ if len(parts) > 1:
25
+ maybe_file = parts[1]
26
+ if is_file_like(maybe_file):
27
+ content = maybe_file.read()
28
+ if hasattr(maybe_file, "seek"):
29
+ maybe_file.seek(0)
30
+ parts[1] = content
31
+ elif not isinstance(maybe_file, (bytes, bytearray)):
32
+ raise ValueError(f"Unsupported second element in tuple: {type(maybe_file)}")
33
+ if len(parts) == 2:
34
+ parts.append(infer_mimetype(os.path.basename(parts[0] or "file.png")))
35
+ return tuple(parts)
36
+ else:
37
+ return item
38
+
39
+ if value is None:
40
+ return value
41
+ elif isinstance(value, list):
42
+ return [convert_item(v) for v in value]
43
+ else:
44
+ return convert_item(value)
45
+
46
+
47
+ def validate_fields_by_provider_and_invoke_type(
48
+ instance: BaseModel,
49
+ extra_allowed_fields: set[str],
50
+ extra_required_fields: set[str] = set()
51
+ ) -> BaseModel:
52
+ """
53
+ 通用的字段校验逻辑,根据 provider 和 invoke_type 动态检查字段合法性和必填字段。
54
+ 适用于 ModelRequest 和 BatchModelRequestItem。
55
+ """
56
+ from tamar_model_client.enums import ProviderType, InvokeType
57
+ from tamar_model_client.schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput, \
58
+ OpenAIImagesInput, OpenAIImagesEditInput, GoogleVertexAIImagesInput
59
+
60
+ google_allowed = extra_allowed_fields | set(GoogleGenAiInput.model_fields)
61
+ openai_responses_allowed = extra_allowed_fields | set(OpenAIResponsesInput.model_fields)
62
+ openai_chat_allowed = extra_allowed_fields | set(OpenAIChatCompletionsInput.model_fields)
63
+ openai_images_allowed = extra_allowed_fields | set(OpenAIImagesInput.model_fields)
64
+ openai_images_edit_allowed = extra_allowed_fields | set(OpenAIImagesEditInput.model_fields)
65
+ google_vertexai_images_allowed = extra_allowed_fields | set(GoogleVertexAIImagesInput.model_fields)
66
+
67
+ google_required = {"model", "contents"}
68
+ google_vertex_required = {"model", "prompt"}
69
+ openai_resp_required = {"input", "model"}
70
+ openai_chat_required = {"messages", "model"}
71
+ openai_img_required = {"prompt"}
72
+ openai_edit_required = {"image", "prompt"}
73
+
74
+ match (instance.provider, instance.invoke_type):
75
+ case (ProviderType.GOOGLE, InvokeType.GENERATION):
76
+ allowed = google_allowed
77
+ required = google_required
78
+ case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
79
+ allowed = google_vertexai_images_allowed
80
+ required = google_vertex_required
81
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
82
+ allowed = openai_responses_allowed
83
+ required = openai_resp_required
84
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
85
+ allowed = openai_chat_allowed
86
+ required = openai_chat_required
87
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
88
+ allowed = openai_images_allowed
89
+ required = openai_img_required
90
+ case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
91
+ allowed = openai_images_edit_allowed
92
+ required = openai_edit_required
93
+ case _:
94
+ raise ValueError(f"Unsupported provider/invoke_type: {instance.provider} + {instance.invoke_type}")
95
+
96
+ required = required | extra_required_fields
97
+
98
+ missing = [f for f in required if getattr(instance, f, None) is None]
99
+ if missing:
100
+ raise ValueError(
101
+ f"Missing required fields for provider={instance.provider} and invoke_type={instance.invoke_type}: {missing}")
102
+
103
+ illegal = []
104
+ valid_fields = {"provider", "channel", "invoke_type"}
105
+ if getattr(instance, "stream", None) is not None:
106
+ valid_fields.add("stream")
107
+
108
+ for k, v in instance.__dict__.items():
109
+ if k in valid_fields:
110
+ continue
111
+ if k not in allowed and v is not None and not isinstance(v, NotGiven):
112
+ illegal.append(k)
113
+
114
+ if illegal:
115
+ raise ValueError(
116
+ f"Unsupported fields for provider={instance.provider} and invoke_type={instance.invoke_type}: {illegal}")
117
+
118
+ return instance
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tamar-model-client
3
- Version: 0.1.15
3
+ Version: 0.1.17
4
4
  Summary: A Python SDK for interacting with the Model Manager gRPC service
5
5
  Home-page: http://gitlab.tamaredge.top/project-tap/AgentOS/model-manager-client
6
6
  Author: Oscar Ou
@@ -273,13 +273,13 @@ async def main():
273
273
  )
274
274
 
275
275
  # 发送请求并获取响应
276
- response = await client.invoke(request_data)
277
- if response.error:
278
- print(f"错误: {response.error}")
279
- else:
280
- print(f"响应: {response.content}")
281
- if response.usage:
282
- print(f"Token 使用情况: {response.usage}")
276
+ async for r in await client.invoke(model_request):
277
+ if r.error:
278
+ print(f"错误: {r.error}")
279
+ else:
280
+ print(f"响应: {r.content}")
281
+ if r.usage:
282
+ print(f"Token 使用情况: {r.usage}")
283
283
 
284
284
 
285
285
  # 运行异步示例
@@ -531,7 +531,7 @@ python make_grpc.py
531
531
  ### 部署到 pip
532
532
  ```bash
533
533
  python setup.py sdist bdist_wheel
534
- twine check dist/*
534
+ twine upload dist/*
535
535
 
536
536
  ```
537
537
 
@@ -5,6 +5,7 @@ tamar_model_client/async_client.py
5
5
  tamar_model_client/auth.py
6
6
  tamar_model_client/exceptions.py
7
7
  tamar_model_client/sync_client.py
8
+ tamar_model_client/utils.py
8
9
  tamar_model_client.egg-info/PKG-INFO
9
10
  tamar_model_client.egg-info/SOURCES.txt
10
11
  tamar_model_client.egg-info/dependency_links.txt