tamar-model-client 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/async_client.py +44 -11
- tamar_model_client/enums/invoke.py +2 -1
- tamar_model_client/schemas/inputs.py +54 -123
- tamar_model_client/sync_client.py +16 -15
- tamar_model_client/utils.py +118 -0
- {tamar_model_client-0.1.15.dist-info → tamar_model_client-0.1.17.dist-info}/METADATA +9 -9
- {tamar_model_client-0.1.15.dist-info → tamar_model_client-0.1.17.dist-info}/RECORD +9 -8
- {tamar_model_client-0.1.15.dist-info → tamar_model_client-0.1.17.dist-info}/WHEEL +0 -0
- {tamar_model_client-0.1.15.dist-info → tamar_model_client-0.1.17.dist-info}/top_level.txt +0 -0
tamar_model_client/async_client.py

@@ -19,7 +19,7 @@ from .exceptions import ConnectionError
 from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
 from .generated import model_service_pb2, model_service_pb2_grpc
 from .schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput, \
-    GoogleVertexAIImagesInput, OpenAIImagesInput
+    GoogleVertexAIImagesInput, OpenAIImagesInput, OpenAIImagesEditInput
 
 logger = logging.getLogger(__name__)
 
@@ -203,6 +203,37 @@ class AsyncTamarModelClient:
                     logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True)
                     raise
 
+    async def _retry_request_stream(self, func, *args, **kwargs):
+        retry_count = 0
+        while retry_count < self.max_retries:
+            try:
+                return func(*args, **kwargs)
+            except (grpc.aio.AioRpcError, grpc.RpcError) as e:
+                # Retry with exponential backoff when the RPC was cancelled
+                if isinstance(e, grpc.aio.AioRpcError) and e.code() == grpc.StatusCode.CANCELLED:
+                    retry_count += 1
+                    logger.warning(f"❌ RPC cancelled, retrying {retry_count}/{self.max_retries}...")
+                    if retry_count < self.max_retries:
+                        delay = self.retry_delay * (2 ** (retry_count - 1))
+                        await asyncio.sleep(delay)
+                    else:
+                        logger.error("❌ Max retry reached for CANCELLED")
+                        raise
+                # Handle other RPC error types, e.g. transient connection issues or server timeouts
+                elif isinstance(e, grpc.RpcError) and e.code() in {grpc.StatusCode.UNAVAILABLE,
+                                                                   grpc.StatusCode.DEADLINE_EXCEEDED}:
+                    retry_count += 1
+                    logger.warning(f"❌ gRPC error {e.code()}, retrying {retry_count}/{self.max_retries}...")
+                    if retry_count < self.max_retries:
+                        delay = self.retry_delay * (2 ** (retry_count - 1))
+                        await asyncio.sleep(delay)
+                    else:
+                        logger.error(f"❌ Max retry reached for {e.code()}")
+                        raise
+                else:
+                    logger.error(f"❌ Non-retryable gRPC error: {e}", exc_info=True)
+                    raise
+
     def _build_auth_metadata(self, request_id: str) -> list:
         # if not self.jwt_token and self.jwt_handler:
         # Changed: generate a token on every request
@@ -263,25 +294,23 @@ class AsyncTamarModelClient:
             logger.info(f"🚀 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...")
             await asyncio.sleep(delay)
 
-    async def _stream_inner(self, model_request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
-
-        async for response in self.stub.Invoke(model_request, metadata=metadata, timeout=invoke_timeout):
+    async def _stream(self, request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
+        async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
             yield ModelResponse(
                 content=response.content,
                 usage=json.loads(response.usage) if response.usage else None,
-                raw_response=json.loads(response.raw_response) if response.raw_response else None,
                 error=response.error or None,
+                raw_response=json.loads(response.raw_response) if response.raw_response else None,
+                request_id=response.request_id if response.request_id else None,
             )
 
-    async def _stream(self, model_request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
-        return await self._retry_request(self._stream_inner, model_request, metadata, invoke_timeout)
-
     async def _invoke_request(self, request, metadata, invoke_timeout):
         async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
             return ModelResponse(
                 content=response.content,
                 usage=json.loads(response.usage) if response.usage else None,
                 error=response.error or None,
+                raw_response=json.loads(response.raw_response) if response.raw_response else None,
                 request_id=response.request_id if response.request_id else None,
             )
 
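`_stream` is now a plain async generator with no retry wrapper of its own; `invoke` hands it to `_retry_request_stream` (see the hunk at old line 376 / new line 407 below). Since calling an async generator function only creates the iterator, the wrapper's `return func(*args, **kwargs)` needs no `await`; a minimal standalone illustration of that behavior:

```python
import asyncio

async def numbers():
    for i in range(3):
        yield i

async def main():
    gen = numbers()      # no await: this just creates the async iterator
    async for n in gen:  # the generator body only runs during iteration
        print(n)

asyncio.run(main())
```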
@@ -317,7 +346,7 @@ class AsyncTamarModelClient:
 
         # Log the request start
         logger.info(
-            f"🔵 Request Start | request_id: {request_id} | provider: {model_request.provider} | invoke_type: {model_request.invoke_type}
+            f"🔵 Request Start | request_id: {request_id} | provider: {model_request.provider} | invoke_type: {model_request.invoke_type}")
 
         # Dynamically choose the input field set based on provider/invoke_type
         try:
@@ -334,6 +363,8 @@ class AsyncTamarModelClient:
                     allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
                 case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
                     allowed_fields = OpenAIImagesInput.model_fields.keys()
+                case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
+                    allowed_fields = OpenAIImagesEditInput.model_fields.keys()
                 case _:
                     raise ValueError(
                         f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")
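The dispatch above uses structural pattern matching over the `(provider, invoke_type)` tuple, with `|` or-patterns grouping OpenAI and Azure into one case. A standalone sketch of the same pattern shape; the enum and string values here are illustrative, not the package's:

```python
from enum import Enum

class Provider(str, Enum):
    OPENAI = "openai"
    AZURE = "azure"

def allowed_input(provider: Provider, invoke_type: str) -> str:
    # An or-pattern lets one case cover both OpenAI-compatible providers.
    match (provider, invoke_type):
        case ((Provider.OPENAI | Provider.AZURE), "image_edit_generation"):
            return "OpenAIImagesEditInput"
        case _:
            raise ValueError("unsupported combination")

print(allowed_input(Provider.AZURE, "image_edit_generation"))  # OpenAIImagesEditInput
```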
@@ -373,7 +404,7 @@ class AsyncTamarModelClient:
         try:
             invoke_timeout = timeout or self.default_invoke_timeout
             if model_request.stream:
-                return await self._stream
+                return await self._retry_request_stream(self._stream, request, metadata, invoke_timeout)
             else:
                 return await self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
         except grpc.RpcError as e:
@@ -414,7 +445,7 @@ class AsyncTamarModelClient:
 
         # Log the request start
         logger.info(
-            f"🔵 Batch Request Start | request_id: {request_id} | batch_size: {len(batch_request_model.items)}
+            f"🔵 Batch Request Start | request_id: {request_id} | batch_size: {len(batch_request_model.items)}")
 
         # Build the batch request
         items = []
@@ -432,6 +463,8 @@ class AsyncTamarModelClient:
                     allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
                 case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
                     allowed_fields = OpenAIImagesInput.model_fields.keys()
+                case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
+                    allowed_fields = OpenAIImagesEditInput.model_fields.keys()
                 case _:
                     raise ValueError(
                         f"Unsupported provider/invoke_type combination: {model_request_item.provider} + {model_request_item.invoke_type}")
tamar_model_client/schemas/inputs.py

@@ -1,18 +1,22 @@
+import mimetypes
+import os
+
 import httpx
 from google.genai import types
 from openai import NotGiven, NOT_GIVEN
-from openai._types import Headers, Query, Body
+from openai._types import Headers, Query, Body, FileTypes
 from openai.types import ChatModel, Metadata, ReasoningEffort, ResponsesModel, Reasoning, ImageModel
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionAudioParam, completion_create_params, \
     ChatCompletionPredictionContentParam, ChatCompletionStreamOptionsParam, ChatCompletionToolChoiceOptionParam, \
     ChatCompletionToolParam
 from openai.types.responses import ResponseInputParam, ResponseIncludable, ResponseTextConfigParam, \
     response_create_params, ToolParam
-from pydantic import BaseModel, model_validator
-from typing import List, Optional, Union, Iterable, Dict, Literal
+from pydantic import BaseModel, model_validator, field_validator
+from typing import List, Optional, Union, Iterable, Dict, Literal, IO
 
 from tamar_model_client.enums import ProviderType, InvokeType
 from tamar_model_client.enums.channel import Channel
+from tamar_model_client.utils import convert_file_field, validate_fields_by_provider_and_invoke_type
 
 
 class UserContext(BaseModel):
@@ -149,6 +153,29 @@ class OpenAIImagesInput(BaseModel):
     }
 
 
+class OpenAIImagesEditInput(BaseModel):
+    image: Union[FileTypes, List[FileTypes]]
+    prompt: str
+    background: Optional[Literal["transparent", "opaque", "auto"]] | NotGiven = NOT_GIVEN
+    mask: FileTypes | NotGiven = NOT_GIVEN
+    model: Union[str, ImageModel, None] | NotGiven = NOT_GIVEN
+    n: Optional[int] | NotGiven = NOT_GIVEN
+    quality: Optional[Literal["standard", "low", "medium", "high", "auto"]] | NotGiven = NOT_GIVEN
+    response_format: Optional[Literal["url", "b64_json"]] | NotGiven = NOT_GIVEN
+    size: Optional[Literal["256x256", "512x512", "1024x1024", "1536x1024", "1024x1536", "auto"]] | NotGiven = NOT_GIVEN
+    user: str | NotGiven = NOT_GIVEN
+    # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+    # The extra values given here take precedence over values defined on the client or passed to this method.
+    extra_headers: Headers | None = None
+    extra_query: Query | None = None
+    extra_body: Body | None = None
+    timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN
+
+    model_config = {
+        "arbitrary_types_allowed": True
+    }
+
+
 class BaseRequest(BaseModel):
     provider: ProviderType  # Provider, e.g. "openai", "google"
     channel: Channel = Channel.NORMAL  # Channel: providers ship different SDKs; this selects which SDK to call
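With the new input model, an image-edit call would carry `OpenAIImagesEditInput`-shaped fields on a `ModelRequest`. A hedged sketch: the `InvokeType.IMAGE_EDIT_GENERATION` member comes from the `enums/invoke.py` change summarized at the top but not shown in a hunk, and `UserContext`'s required fields are not visible in this diff, so all values below are placeholders:

```python
from tamar_model_client.enums import ProviderType, InvokeType
from tamar_model_client.schemas.inputs import ModelRequest, UserContext

# Placeholder values throughout; UserContext() assumes defaults for illustration.
request = ModelRequest(
    provider=ProviderType.OPENAI,
    invoke_type=InvokeType.IMAGE_EDIT_GENERATION,
    user_context=UserContext(),
    image=open("input.png", "rb"),  # normalized by convert_file_field (utils.py below)
    prompt="Replace the background with a beach",
)
```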
@@ -212,8 +239,11 @@ class ModelRequestInput(BaseRequest):
     contents: Optional[Union[types.ContentListUnion, types.ContentListUnionDict]] = None
     config: Optional[types.GenerateContentConfigOrDict] = None
 
-    # Merged fields from OpenAIImagesInput + GoogleVertexAIImagesInput
+    # Merged fields from OpenAIImagesInput + OpenAIImagesEditInput + GoogleVertexAIImagesInput
+    image: Optional[Union[FileTypes, List[FileTypes]]] = None
     prompt: Optional[str] = None
+    background: Optional[Literal["transparent", "opaque", "auto"]] | NotGiven = NOT_GIVEN
+    mask: FileTypes | NotGiven = NOT_GIVEN
     negative_prompt: Optional[str] = None
     aspect_ratio: Optional[Literal["1:1", "9:16", "16:9", "4:3", "3:4"]] = None
     guidance_scale: Optional[float] = None
@@ -223,7 +253,8 @@ class ModelRequestInput(BaseRequest):
     safety_filter_level: Optional[Literal["block_most", "block_some", "block_few", "block_fewest"]] = None
     person_generation: Optional[Literal["dont_allow", "allow_adult", "allow_all"]] = None
     quality: Optional[Literal["standard", "hd"]] | NotGiven = NOT_GIVEN
-    size: Optional[Literal[
+    size: Optional[Literal[
+        "auto", "1024x1024", "1536x1024", "1024x1536", "256x256", "512x512", "1792x1024", "1024x1792"]] | NotGiven = NOT_GIVEN
     style: Optional[Literal["vivid", "natural"]] | NotGiven = NOT_GIVEN
     number_of_images: Optional[int] = None  # Google usage
 
@@ -231,71 +262,26 @@ class ModelRequestInput(BaseRequest):
         "arbitrary_types_allowed": True
     }
 
+    @field_validator("image", mode="before")
+    @classmethod
+    def validate_image(cls, v):
+        return convert_file_field(v)
+
+    @field_validator("mask", mode="before")
+    @classmethod
+    def validate_mask(cls, v):
+        return convert_file_field(v)
+
 
 class ModelRequest(ModelRequestInput):
     user_context: UserContext  # User information
 
     @model_validator(mode="after")
     def validate_by_provider_and_invoke_type(self) -> "ModelRequest":
-
-
-
-
-        openai_responses_allowed = base_allowed | set(OpenAIResponsesInput.model_fields.keys())
-        openai_chat_allowed = base_allowed | set(OpenAIChatCompletionsInput.model_fields.keys())
-        openai_images_allowed = base_allowed | set(OpenAIImagesInput.model_fields.keys())
-        google_vertexai_images_allowed = base_allowed | set(GoogleVertexAIImagesInput.model_fields.keys())
-
-        # Required fields per model type
-        google_required_fields = {"model", "contents"}
-        google_vertexai_image_required_fields = {"model", "prompt"}
-
-        openai_responses_required_fields = {"input", "model"}
-        openai_chat_required_fields = {"messages", "model"}
-        openai_image_required_fields = {"prompt"}
-
-        # Choose the field sets to validate
-        # Dynamic branching logic
-        match (self.provider, self.invoke_type):
-            case (ProviderType.GOOGLE, InvokeType.GENERATION):
-                allowed_fields = google_allowed
-                expected_fields = google_required_fields
-            case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
-                allowed_fields = google_vertexai_images_allowed
-                expected_fields = google_vertexai_image_required_fields
-            case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
-                allowed_fields = openai_responses_allowed
-                expected_fields = openai_responses_required_fields
-            case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
-                allowed_fields = openai_chat_allowed
-                expected_fields = openai_chat_required_fields
-            case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
-                allowed_fields = openai_images_allowed
-                expected_fields = openai_image_required_fields
-            case _:
-                raise ValueError(f"Unsupported provider/invoke_type combination: {self.provider} + {self.invoke_type}")
-
-        # Check for missing required fields
-        missing = [field for field in expected_fields if getattr(self, field, None) is None]
-        if missing:
-            raise ValueError(
-                f"Missing required fields for provider={self.provider} and invoke_type={self.invoke_type}: {missing}")
-
-        # Check for illegal fields
-        illegal_fields = []
-        valid_fields = {"provider", "channel", "invoke_type"} if self.invoke_type == InvokeType.IMAGE_GENERATION else {
-            "provider", "channel", "invoke_type", "stream"}
-        for name, value in self.__dict__.items():
-            if name in valid_fields:
-                continue
-            if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
-                illegal_fields.append(name)
-
-        if illegal_fields:
-            raise ValueError(
-                f"Unsupported fields for provider={self.provider} and invoke_type={self.invoke_type}: {illegal_fields}")
-
-        return self
+        return validate_fields_by_provider_and_invoke_type(
+            instance=self,
+            extra_allowed_fields={"provider", "channel", "invoke_type", "user_context"},
+        )
 
 
 class BatchModelRequestItem(ModelRequestInput):
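Both validators now delegate to `validate_fields_by_provider_and_invoke_type` (added in `utils.py` below), which keeps the old semantics: the allowed set is the union of the matching input model's fields and the request-level extras, and any other non-`None` field raises `ValueError`. A standalone distillation of that check, with abbreviated field names for illustration:

```python
# Distilled from validate_fields_by_provider_and_invoke_type; not the actual helper.
extra_allowed = {"provider", "channel", "invoke_type", "user_context"}
image_edit_fields = {"image", "prompt", "mask", "model"}  # subset, for illustration
allowed = extra_allowed | image_edit_fields

payload = {"provider": "openai", "invoke_type": "image_edit", "image": b"png", "messages": ["hi"]}
illegal = [k for k, v in payload.items() if k not in allowed and v is not None]
print(illegal)  # ['messages']
```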
@@ -304,65 +290,10 @@ class BatchModelRequestItem(ModelRequestInput):
 
     @model_validator(mode="after")
     def validate_by_provider_and_invoke_type(self) -> "BatchModelRequestItem":
-
-
-
-
-        openai_responses_allowed = base_allowed | set(OpenAIResponsesInput.model_fields.keys())
-        openai_chat_allowed = base_allowed | set(OpenAIChatCompletionsInput.model_fields.keys())
-        openai_images_allowed = base_allowed | set(OpenAIImagesInput.model_fields.keys())
-        google_vertexai_images_allowed = base_allowed | set(GoogleVertexAIImagesInput.model_fields.keys())
-
-        # Required fields per model type
-        google_required_fields = {"model", "contents"}
-        google_vertexai_image_required_fields = {"model", "prompt"}
-
-        openai_responses_required_fields = {"input", "model"}
-        openai_chat_required_fields = {"messages", "model"}
-        openai_image_required_fields = {"prompt"}
-
-        # Choose the field sets to validate
-        # Dynamic branching logic
-        match (self.provider, self.invoke_type):
-            case (ProviderType.GOOGLE, InvokeType.GENERATION):
-                allowed_fields = google_allowed
-                expected_fields = google_required_fields
-            case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
-                allowed_fields = google_vertexai_images_allowed
-                expected_fields = google_vertexai_image_required_fields
-            case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
-                allowed_fields = openai_responses_allowed
-                expected_fields = openai_responses_required_fields
-            case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
-                allowed_fields = openai_chat_allowed
-                expected_fields = openai_chat_required_fields
-            case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
-                allowed_fields = openai_images_allowed
-                expected_fields = openai_image_required_fields
-            case _:
-                raise ValueError(f"Unsupported provider/invoke_type combination: {self.provider} + {self.invoke_type}")
-
-        # Check for missing required fields
-        missing = [field for field in expected_fields if getattr(self, field, None) is None]
-        if missing:
-            raise ValueError(
-                f"Missing required fields for provider={self.provider} and invoke_type={self.invoke_type}: {missing}")
-
-        # Check for illegal fields
-        illegal_fields = []
-        valid_fields = {"provider", "channel", "invoke_type"} if self.invoke_type == InvokeType.IMAGE_GENERATION else {
-            "provider", "channel", "invoke_type", "stream"}
-        for name, value in self.__dict__.items():
-            if name in valid_fields:
-                continue
-            if name not in allowed_fields and value is not None and not isinstance(value, NotGiven):
-                illegal_fields.append(name)
-
-        if illegal_fields:
-            raise ValueError(
-                f"Unsupported fields for provider={self.provider} and invoke_type={self.invoke_type}: {illegal_fields}")
-
-        return self
+        return validate_fields_by_provider_and_invoke_type(
+            instance=self,
+            extra_allowed_fields={"provider", "channel", "invoke_type", "user_context", "custom_id"},
+        )
 
 
 class BatchModelRequest(BaseModel):
tamar_model_client/sync_client.py

@@ -17,7 +17,7 @@ from .exceptions import ConnectionError
 from .generated import model_service_pb2, model_service_pb2_grpc
 from .schemas import BatchModelResponse, ModelResponse
 from .schemas.inputs import GoogleGenAiInput, GoogleVertexAIImagesInput, OpenAIResponsesInput, \
-    OpenAIChatCompletionsInput, OpenAIImagesInput, BatchModelRequest, ModelRequest
+    OpenAIChatCompletionsInput, OpenAIImagesInput, OpenAIImagesEditInput, BatchModelRequest, ModelRequest
 
 logger = logging.getLogger(__name__)
 
@@ -247,20 +247,16 @@ class TamarModelClient:
             logger.info(f"🚀 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...")
             time.sleep(delay)  # Blocking sleep in sync version
 
-    def _stream_inner(self, model_request, metadata, invoke_timeout) -> Iterator[ModelResponse]:
-
-        response = self.stub.Invoke(model_request, metadata=metadata, timeout=invoke_timeout)
-        for res in response:
+    def _stream(self, request, metadata, invoke_timeout) -> Iterator[ModelResponse]:
+        for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
             yield ModelResponse(
-                content=
-                usage=json.loads(
-
-
+                content=response.content,
+                usage=json.loads(response.usage) if response.usage else None,
+                error=response.error or None,
+                raw_response=json.loads(response.raw_response) if response.raw_response else None,
+                request_id=response.request_id if response.request_id else None,
             )
 
-    def _stream(self, model_request, metadata, invoke_timeout) -> Iterator[ModelResponse]:
-        return self._retry_request(self._stream_inner, model_request, metadata, invoke_timeout)
-
     def _invoke_request(self, request, metadata, invoke_timeout):
         response = self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout)
         for response in response:
@@ -268,6 +264,7 @@ class TamarModelClient:
                 content=response.content,
                 usage=json.loads(response.usage) if response.usage else None,
                 error=response.error or None,
+                raw_response=json.loads(response.raw_response) if response.raw_response else None,
                 request_id=response.request_id if response.request_id else None,
             )
 
@@ -302,7 +299,7 @@ class TamarModelClient:
 
         # Log the request start
         logger.info(
-            f"🔵 Request Start |
+            f"🔵 Request Start |provider: {model_request.provider} | invoke_type: {model_request.invoke_type}")
 
         # Dynamically choose the input field set based on provider/invoke_type
         try:
@@ -319,6 +316,8 @@ class TamarModelClient:
                     allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
                 case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
                     allowed_fields = OpenAIImagesInput.model_fields.keys()
+                case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
+                    allowed_fields = OpenAIImagesEditInput.model_fields.keys()
                 case _:
                     raise ValueError(
                         f"Unsupported provider/invoke_type combination: {model_request.provider} + {model_request.invoke_type}")
@@ -358,7 +357,7 @@ class TamarModelClient:
         try:
             invoke_timeout = timeout or self.default_invoke_timeout
             if model_request.stream:
-                return self._stream
+                return self._retry_request(self._stream, request, metadata, invoke_timeout)
             else:
                 return self._retry_request(self._invoke_request, request, metadata, invoke_timeout)
         except grpc.RpcError as e:
@@ -398,7 +397,7 @@ class TamarModelClient:
 
         # Log the request start
         logger.info(
-            f"🔵 Batch Request Start |
+            f"🔵 Batch Request Start | batch_size: {len(batch_request_model.items)}")
 
         # Build the batch request
         items = []
@@ -416,6 +415,8 @@ class TamarModelClient:
                     allowed_fields = OpenAIChatCompletionsInput.model_fields.keys()
                 case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
                     allowed_fields = OpenAIImagesInput.model_fields.keys()
+                case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
+                    allowed_fields = OpenAIImagesEditInput.model_fields.keys()
                 case _:
                     raise ValueError(
                         f"Unsupported provider/invoke_type combination: {model_request_item.provider} + {model_request_item.invoke_type}")
tamar_model_client/utils.py (new file)

@@ -0,0 +1,118 @@
+from openai import NotGiven
+from pydantic import BaseModel
+from typing import Any
+import os, mimetypes
+
+def convert_file_field(value: Any) -> Any:
+    def is_file_like(obj):
+        return hasattr(obj, "read") and callable(obj.read)
+
+    def infer_mimetype(filename: str) -> str:
+        mime, _ = mimetypes.guess_type(filename)
+        return mime or "application/octet-stream"
+
+    def convert_item(item):
+        if is_file_like(item):
+            filename = os.path.basename(getattr(item, "name", "file.png"))
+            content_type = infer_mimetype(filename)
+            content = item.read()
+            if hasattr(item, "seek"):
+                item.seek(0)
+            return (filename, content, content_type)
+        elif isinstance(item, tuple):
+            parts = list(item)
+            if len(parts) > 1:
+                maybe_file = parts[1]
+                if is_file_like(maybe_file):
+                    content = maybe_file.read()
+                    if hasattr(maybe_file, "seek"):
+                        maybe_file.seek(0)
+                    parts[1] = content
+                elif not isinstance(maybe_file, (bytes, bytearray)):
+                    raise ValueError(f"Unsupported second element in tuple: {type(maybe_file)}")
+                if len(parts) == 2:
+                    parts.append(infer_mimetype(os.path.basename(parts[0] or "file.png")))
+            return tuple(parts)
+        else:
+            return item
+
+    if value is None:
+        return value
+    elif isinstance(value, list):
+        return [convert_item(v) for v in value]
+    else:
+        return convert_item(value)
+
+
+def validate_fields_by_provider_and_invoke_type(
+        instance: BaseModel,
+        extra_allowed_fields: set[str],
+        extra_required_fields: set[str] = set()
+) -> BaseModel:
+    """
+    Generic field validation: dynamically checks allowed and required fields based on provider and invoke_type.
+    Applies to both ModelRequest and BatchModelRequestItem.
+    """
+    from tamar_model_client.enums import ProviderType, InvokeType
+    from tamar_model_client.schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput, \
+        OpenAIImagesInput, OpenAIImagesEditInput, GoogleVertexAIImagesInput
+
+    google_allowed = extra_allowed_fields | set(GoogleGenAiInput.model_fields)
+    openai_responses_allowed = extra_allowed_fields | set(OpenAIResponsesInput.model_fields)
+    openai_chat_allowed = extra_allowed_fields | set(OpenAIChatCompletionsInput.model_fields)
+    openai_images_allowed = extra_allowed_fields | set(OpenAIImagesInput.model_fields)
+    openai_images_edit_allowed = extra_allowed_fields | set(OpenAIImagesEditInput.model_fields)
+    google_vertexai_images_allowed = extra_allowed_fields | set(GoogleVertexAIImagesInput.model_fields)
+
+    google_required = {"model", "contents"}
+    google_vertex_required = {"model", "prompt"}
+    openai_resp_required = {"input", "model"}
+    openai_chat_required = {"messages", "model"}
+    openai_img_required = {"prompt"}
+    openai_edit_required = {"image", "prompt"}
+
+    match (instance.provider, instance.invoke_type):
+        case (ProviderType.GOOGLE, InvokeType.GENERATION):
+            allowed = google_allowed
+            required = google_required
+        case (ProviderType.GOOGLE, InvokeType.IMAGE_GENERATION):
+            allowed = google_vertexai_images_allowed
+            required = google_vertex_required
+        case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.RESPONSES | InvokeType.GENERATION):
+            allowed = openai_responses_allowed
+            required = openai_resp_required
+        case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.CHAT_COMPLETIONS):
+            allowed = openai_chat_allowed
+            required = openai_chat_required
+        case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_GENERATION):
+            allowed = openai_images_allowed
+            required = openai_img_required
+        case ((ProviderType.OPENAI | ProviderType.AZURE), InvokeType.IMAGE_EDIT_GENERATION):
+            allowed = openai_images_edit_allowed
+            required = openai_edit_required
+        case _:
+            raise ValueError(f"Unsupported provider/invoke_type: {instance.provider} + {instance.invoke_type}")
+
+    required = required | extra_required_fields
+
+    missing = [f for f in required if getattr(instance, f, None) is None]
+    if missing:
+        raise ValueError(
+            f"Missing required fields for provider={instance.provider} and invoke_type={instance.invoke_type}: {missing}")
+
+    illegal = []
+    valid_fields = {"provider", "channel", "invoke_type"}
+    if getattr(instance, "stream", None) is not None:
+        valid_fields.add("stream")
+
+    for k, v in instance.__dict__.items():
+        if k in valid_fields:
+            continue
+        if k not in allowed and v is not None and not isinstance(v, NotGiven):
+            illegal.append(k)
+
+    if illegal:
+        raise ValueError(
+            f"Unsupported fields for provider={instance.provider} and invoke_type={instance.invoke_type}: {illegal}")
+
+    return instance
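`convert_file_field` normalizes file-like values into `(filename, bytes, mimetype)` tuples and rewinds the source stream afterwards. A quick self-contained check; the `NamedBytes` helper and its contents are illustrative:

```python
import io
from tamar_model_client.utils import convert_file_field

class NamedBytes(io.BytesIO):
    """BytesIO subclass carrying a filename, like an open file object."""
    name = "photo.png"

buf = NamedBytes(b"\x89PNG fake image bytes")
filename, content, mime = convert_file_field(buf)
print(filename, mime)  # photo.png image/png
print(buf.tell())      # 0: the stream was rewound after reading
```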
{tamar_model_client-0.1.15.dist-info → tamar_model_client-0.1.17.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: tamar-model-client
-Version: 0.1.15
+Version: 0.1.17
 Summary: A Python SDK for interacting with the Model Manager gRPC service
 Home-page: http://gitlab.tamaredge.top/project-tap/AgentOS/model-manager-client
 Author: Oscar Ou
@@ -273,13 +273,13 @@ async def main():
     )
 
     # Send the request and get the response
-
-
-
-
-
-
-
+    async for r in await client.invoke(model_request):
+        if r.error:
+            print(f"Error: {r.error}")
+        else:
+            print(f"Response: {r.content}")
+        if r.usage:
+            print(f"Token usage: {r.usage}")
 
 
 # Run the async example
@@ -531,7 +531,7 @@ python make_grpc.py
 ### Deploy to pip
 ```bash
 python setup.py sdist bdist_wheel
-twine
+twine upload dist/*
 
 ```
 
{tamar_model_client-0.1.15.dist-info → tamar_model_client-0.1.17.dist-info}/RECORD

@@ -1,19 +1,20 @@
 tamar_model_client/__init__.py,sha256=LMECAuDARWHV1XzH3msoDXcyurS2eihRQmBy26_PUE0,328
-tamar_model_client/async_client.py,sha256=
+tamar_model_client/async_client.py,sha256=fTQMLWz7DxW1fynmfUxlS3anmYOxv6giVUGq6ZG4kzk,25972
 tamar_model_client/auth.py,sha256=gbwW5Aakeb49PMbmYvrYlVx1mfyn1LEDJ4qQVs-9DA4,438
 tamar_model_client/exceptions.py,sha256=jYU494OU_NeIa4X393V-Y73mTNm0JZ9yZApnlOM9CJQ,332
-tamar_model_client/sync_client.py,sha256=
+tamar_model_client/sync_client.py,sha256=-Gbx1DP4LRFZZZd4sKpY5Fi-_WHZEVayl1ABD3k7O6I,22748
+tamar_model_client/utils.py,sha256=Kn6pFz9GEC96H4eejEax66AkzvsrXI3WCSDtgDjnVTI,5238
 tamar_model_client/enums/__init__.py,sha256=3cYYn8ztNGBa_pI_5JGRVYf2QX8fkBVWdjID1PLvoBQ,182
 tamar_model_client/enums/channel.py,sha256=wCzX579nNpTtwzGeS6S3Ls0UzVAgsOlfy4fXMzQTCAw,199
-tamar_model_client/enums/invoke.py,sha256=
+tamar_model_client/enums/invoke.py,sha256=Up87myAg4-0SDJV5a82ggPDpYHSLEtIco8BF_5Ph1nY,322
 tamar_model_client/enums/providers.py,sha256=L_bX75K6KnWURoFizoitZ1Ybza7bmYDqXecNzNpgIrI,165
 tamar_model_client/generated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 tamar_model_client/generated/model_service_pb2.py,sha256=RI6wNSmgmylzWPedFfPxx938UzS7kcPR58YTzYshcL8,3066
 tamar_model_client/generated/model_service_pb2_grpc.py,sha256=k4tIbp3XBxdyuOVR18Ung_4SUryONB51UYf_uUEl6V4,5145
 tamar_model_client/schemas/__init__.py,sha256=AxuI-TcvA4OMTj2FtK4wAItvz9LrK_293pu3cmMLE7k,394
-tamar_model_client/schemas/inputs.py,sha256=
+tamar_model_client/schemas/inputs.py,sha256=dz1m8NbUIxA99JXZc8WlyzbKpDuz1lEzx3VghC33zYI,14625
 tamar_model_client/schemas/outputs.py,sha256=M_fcqUtXPJnfiLabHlyA8BorlC5pYkf5KLjXO1ysKIQ,1031
-tamar_model_client-0.1.15.dist-info/METADATA,sha256=
-tamar_model_client-0.1.15.dist-info/WHEEL,sha256=
-tamar_model_client-0.1.15.dist-info/top_level.txt,sha256=
-tamar_model_client-0.1.15.dist-info/RECORD,,
+tamar_model_client-0.1.17.dist-info/METADATA,sha256=eP0oGK9qWIJXNu0YX2Q4LjERRPTWI0Hn8JgOJCJWW1w,16562
+tamar_model_client-0.1.17.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+tamar_model_client-0.1.17.dist-info/top_level.txt,sha256=_LfDhPv_fvON0PoZgQuo4M7EjoWtxPRoQOBJziJmip8,19
+tamar_model_client-0.1.17.dist-info/RECORD,,
{tamar_model_client-0.1.15.dist-info → tamar_model_client-0.1.17.dist-info}/WHEEL: file without changes
{tamar_model_client-0.1.15.dist-info → tamar_model_client-0.1.17.dist-info}/top_level.txt: file without changes