tamar-model-client 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tamar_model_client/__init__.py +11 -0
- tamar_model_client/async_client.py +414 -0
- tamar_model_client/auth.py +14 -0
- tamar_model_client/enums/__init__.py +8 -0
- tamar_model_client/enums/channel.py +11 -0
- tamar_model_client/enums/invoke.py +10 -0
- tamar_model_client/enums/providers.py +8 -0
- tamar_model_client/exceptions.py +11 -0
- tamar_model_client/generated/__init__.py +0 -0
- tamar_model_client/generated/model_service_pb2.py +45 -0
- tamar_model_client/generated/model_service_pb2_grpc.py +145 -0
- tamar_model_client/schemas/__init__.py +17 -0
- tamar_model_client/schemas/inputs.py +294 -0
- tamar_model_client/schemas/outputs.py +24 -0
- tamar_model_client/sync_client.py +111 -0
- {tamar_model_client-0.1.1.dist-info → tamar_model_client-0.1.3.dist-info}/METADATA +61 -90
- tamar_model_client-0.1.3.dist-info/RECORD +34 -0
- tamar_model_client-0.1.3.dist-info/top_level.txt +1 -0
- tamar_model_client-0.1.1.dist-info/RECORD +0 -19
- tamar_model_client-0.1.1.dist-info/top_level.txt +0 -1
- {tamar_model_client-0.1.1.dist-info → tamar_model_client-0.1.3.dist-info}/WHEEL +0 -0
@@ -0,0 +1,11 @@
|
|
1
|
+
from .sync_client import TamarModelClient
|
2
|
+
from .async_client import AsyncTamarModelClient
|
3
|
+
from .exceptions import ModelManagerClientError, ConnectionError, ValidationError
|
4
|
+
|
5
|
+
__all__ = [
|
6
|
+
"TamarModelClient",
|
7
|
+
"AsyncTamarModelClient",
|
8
|
+
"ModelManagerClientError",
|
9
|
+
"ConnectionError",
|
10
|
+
"ValidationError",
|
11
|
+
]
|
@@ -0,0 +1,414 @@
|
|
1
|
+
import asyncio
|
2
|
+
import atexit
|
3
|
+
import base64
|
4
|
+
import json
|
5
|
+
import logging
|
6
|
+
import os
|
7
|
+
|
8
|
+
import grpc
|
9
|
+
from typing import Optional, AsyncIterator, Union, Iterable
|
10
|
+
|
11
|
+
from openai import NOT_GIVEN
|
12
|
+
from pydantic import BaseModel
|
13
|
+
|
14
|
+
from .auth import JWTAuthHandler
|
15
|
+
from .enums import ProviderType, InvokeType
|
16
|
+
from .exceptions import ConnectionError, ValidationError
|
17
|
+
from .schemas import ModelRequest, ModelResponse, BatchModelRequest, BatchModelResponse
|
18
|
+
from .generated import model_service_pb2, model_service_pb2_grpc
|
19
|
+
from .schemas.inputs import GoogleGenAiInput, OpenAIResponsesInput, OpenAIChatCompletionsInput
|
20
|
+
|
21
|
+
if not logging.getLogger().hasHandlers():
    # Configure the log format.
    # Only runs when the root logger has no handlers yet, so a logging setup
    # already installed by the host application is left untouched.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
    )
|
27
|
+
|
28
|
+
logger = logging.getLogger(__name__)
|
29
|
+
|
30
|
+
|
31
|
+
def is_effective_value(value) -> bool:
    """Recursively decide whether ``value`` carries any meaningful content.

    Rules, applied in order:

    * ``None`` and the ``NOT_GIVEN`` sentinel are never effective.
    * A string is effective when it is non-empty after stripping whitespace.
    * ``bytes`` are effective when non-empty.
    * A dict (by its values) or a list (by its items) is effective when at
      least one member is itself effective; empty containers are not.
    * Anything else (int, float, bool, ...) is effective.
    """
    if value is None or value is NOT_GIVEN:
        return False

    if isinstance(value, str):
        return bool(value.strip())

    if isinstance(value, bytes):
        return bool(value)

    if isinstance(value, (dict, list)):
        # Dicts are judged by their values, lists by their items.
        members = value.values() if isinstance(value, dict) else value
        return any(is_effective_value(member) for member in members)

    # Any other type counts as effective as long as it is not None.
    return True
|
57
|
+
|
58
|
+
|
59
|
+
def serialize_value(value):
    """Recursively convert ``value`` into plain, transport-friendly data.

    * Ineffective values (see ``is_effective_value``) become ``None``.
    * Pydantic models are dumped with ``model_dump()`` and processed again.
    * Objects exposing a callable ``dict`` attribute are converted through it.
    * Dicts are processed value by value; keys are kept as they are.
    * ``bytes`` become the string ``"bytes:<base64>"``.
    * Any other non-string iterable becomes a list of processed items.
    * Everything else is returned unchanged.
    """
    if not is_effective_value(value):
        return None

    if isinstance(value, BaseModel):
        return serialize_value(value.model_dump())

    to_dict = getattr(value, "dict", None)
    if callable(to_dict):
        return serialize_value(to_dict())

    if isinstance(value, dict):
        return {key: serialize_value(member) for key, member in value.items()}

    if isinstance(value, bytes):
        encoded = base64.b64encode(value).decode('utf-8')
        return f"bytes:{encoded}"

    if isinstance(value, str):
        # Strings are iterable but must be passed through as scalars.
        return value

    if isinstance(value, Iterable):
        return [serialize_value(member) for member in value]

    return value
|
74
|
+
|
75
|
+
|
76
|
+
from typing import Any
|
77
|
+
|
78
|
+
|
79
|
+
def remove_none_from_dict(data: Any) -> Any:
    """Recursively drop dict entries whose value is ``None``.

    Dicts are rebuilt without their ``None``-valued keys and lists are walked
    element by element. ``None`` items that sit directly inside a list are
    kept; only dict entries are removed. Any other value is returned as is.
    """
    if isinstance(data, dict):
        return {
            name: remove_none_from_dict(member)
            for name, member in data.items()
            if member is not None
        }

    if isinstance(data, list):
        return list(map(remove_none_from_dict, data))

    return data
|
95
|
+
|
96
|
+
|
97
|
+
class AsyncTamarModelClient:
    """Asynchronous gRPC client for the model service.

    The gRPC channel is created lazily on the first call (or on entering
    ``async with``) and closed by ``close()`` / ``__aexit__``. Requests are
    authenticated with a ``Bearer`` JWT, either supplied by the caller or
    signed locally when a secret key is configured.
    """

    def __init__(
            self,
            server_address: Optional[str] = None,
            jwt_secret_key: Optional[str] = None,
            jwt_token: Optional[str] = None,
            default_payload: Optional[dict] = None,
            token_expires_in: int = 3600,
            max_retries: int = 3,  # maximum number of connection retries
            retry_delay: float = 1.0,  # initial retry delay in seconds
    ):
        """Store the configuration; no network connection is opened here.

        Args:
            server_address: gRPC target; falls back to the
                ``MODEL_MANAGER_SERVER_ADDRESS`` environment variable.
            jwt_secret_key: Secret used to sign tokens; falls back to
                ``MODEL_MANAGER_SERVER_JWT_SECRET_KEY``.
            jwt_token: Pre-built token supplied by the caller (optional).
            default_payload: Claims to put into a locally signed token.
            token_expires_in: Lifetime in seconds of a locally signed token.
            max_retries: Number of connection retries after the first attempt.
            retry_delay: Initial retry delay in seconds (doubled on each retry).

        Raises:
            ValueError: If no server address is given by argument or environment.
        """
        # Server address
        self.server_address = server_address or os.getenv("MODEL_MANAGER_SERVER_ADDRESS")
        if not self.server_address:
            raise ValueError("Server address must be provided via argument or environment variable.")
        self.default_invoke_timeout = float(os.getenv("MODEL_MANAGER_SERVER_INVOKE_TIMEOUT", 30.0))

        # JWT configuration
        self.jwt_secret_key = jwt_secret_key or os.getenv("MODEL_MANAGER_SERVER_JWT_SECRET_KEY")
        self.jwt_handler = JWTAuthHandler(self.jwt_secret_key)
        self.jwt_token = jwt_token  # token passed in by the user (optional)
        self.default_payload = default_payload
        self.token_expires_in = token_expires_in

        # === TLS / authority configuration ===
        self.use_tls = os.getenv("MODEL_MANAGER_SERVER_GRPC_USE_TLS", "true").lower() == "true"
        self.default_authority = os.getenv("MODEL_MANAGER_SERVER_GRPC_DEFAULT_AUTHORITY")

        # === Retry configuration ===
        # NOTE(review): both parameters default to non-None values, so the
        # environment fallbacks below are only used when a caller passes None
        # explicitly. Confirm whether the defaults were meant to be None.
        self.max_retries = max_retries if max_retries is not None else int(
            os.getenv("MODEL_MANAGER_SERVER_GRPC_MAX_RETRIES", 3))
        self.retry_delay = retry_delay if retry_delay is not None else float(
            os.getenv("MODEL_MANAGER_SERVER_GRPC_RETRY_DELAY", 1.0))

        # === gRPC channel state ===
        self.channel: Optional[grpc.aio.Channel] = None
        self.stub: Optional[model_service_pb2_grpc.ModelServiceStub] = None
        self._closed = False
        atexit.register(self._safe_sync_close)  # close the channel automatically at process exit

    def _build_auth_metadata(self) -> list:
        """Return the gRPC metadata carrying the bearer token, or ``[]``.

        A token is signed locally only when none is set yet AND a secret key
        is configured. ``jwt_handler`` is always constructed, so checking the
        handler alone would attempt to sign with a missing key.
        """
        if not self.jwt_token and self.jwt_handler and self.jwt_secret_key:
            # default_payload may still be None here; sign an empty claim set then.
            self.jwt_token = self.jwt_handler.encode_token(
                self.default_payload or {}, expires_in=self.token_expires_in)
        # NOTE(review): a locally signed token is cached and never re-issued,
        # so a client living longer than token_expires_in keeps sending the
        # same token. Confirm whether the server tolerates that.
        return [("authorization", f"Bearer {self.jwt_token}")] if self.jwt_token else []

    async def _discard_channel(self):
        """Close and forget a channel whose initialization failed."""
        failed_channel, self.channel = self.channel, None
        self.stub = None
        if failed_channel is None:
            return
        try:
            await failed_channel.close()
        except Exception as e:
            # Best effort: a failure to clean up must not mask the original error.
            logger.warning(f"❌ Failed to close unready gRPC channel: {e}")

    async def _ensure_initialized(self):
        """Create the gRPC channel and stub if needed, with TLS and retries.

        Retries use exponential backoff: ``retry_delay * 2 ** (attempt - 1)``.

        Raises:
            ConnectionError: If the channel cannot be initialized after
                ``max_retries`` retries.
        """
        if self.channel and self.stub:
            return

        options = []
        if self.default_authority:
            options.append(("grpc.default_authority", self.default_authority))

        retry_count = 0
        while retry_count <= self.max_retries:
            try:
                if self.use_tls:
                    credentials = grpc.ssl_channel_credentials()
                    self.channel = grpc.aio.secure_channel(
                        self.server_address,
                        credentials,
                        options=options
                    )
                    logger.info("🔐 Using secure gRPC channel (TLS enabled)")
                else:
                    self.channel = grpc.aio.insecure_channel(
                        self.server_address,
                        options=options
                    )
                    logger.info("🔓 Using insecure gRPC channel (TLS disabled)")
                # NOTE(review): channel_ready() is awaited without a timeout here;
                # confirm whether an unreachable server should fail after a bound.
                await self.channel.channel_ready()
                self.stub = model_service_pb2_grpc.ModelServiceStub(self.channel)
                logger.info(f"✅ gRPC channel initialized to {self.server_address}")
                return
            except grpc.FutureTimeoutError as e:
                logger.warning(f"❌ gRPC channel initialization timed out: {str(e)}")
            except grpc.RpcError as e:
                logger.warning(f"❌ gRPC channel initialization failed: {str(e)}")
            except Exception as e:
                logger.warning(f"❌ Unexpected error during channel initialization: {str(e)}")

            # Release the channel of the failed attempt so it is not leaked
            # when the next attempt creates a new one.
            await self._discard_channel()

            retry_count += 1
            if retry_count > self.max_retries:
                raise ConnectionError(f"❌ Failed to initialize gRPC channel after {self.max_retries} retries.")

            # Exponential backoff: delay = retry_delay * (2 ^ (retry_count - 1))
            delay = self.retry_delay * (2 ** (retry_count - 1))
            logger.info(f"🚀 Retrying connection (attempt {retry_count}/{self.max_retries}) after {delay:.2f}s delay...")
            await asyncio.sleep(delay)

    @staticmethod
    def _select_input_fields(request_model):
        """Return the input field names allowed for the request's provider/invoke type.

        Raises:
            ValueError: For an unsupported provider or invoke type.
        """
        if request_model.provider == ProviderType.GOOGLE:
            return GoogleGenAiInput.model_fields.keys()
        if request_model.provider in {ProviderType.OPENAI, ProviderType.AZURE}:
            if request_model.invoke_type in {InvokeType.RESPONSES, InvokeType.GENERATION}:
                return OpenAIResponsesInput.model_fields.keys()
            if request_model.invoke_type == InvokeType.CHAT_COMPLETIONS:
                return OpenAIChatCompletionsInput.model_fields.keys()
            raise ValueError(f"暂不支持的调用类型: {request_model.invoke_type}")
        raise ValueError(f"暂不支持的提供商: {request_model.provider}")

    @classmethod
    def _build_extra(cls, request_model) -> dict:
        """Build the ``extra`` payload: allowed, explicitly set, effective fields only."""
        allowed_fields = cls._select_input_fields(request_model)

        # Dump the request and keep only the explicitly set, allowed fields.
        model_request_dict = request_model.model_dump(exclude_unset=True)

        grpc_request_kwargs = {}
        for field in allowed_fields:
            if field not in model_request_dict:
                continue
            value = model_request_dict[field]
            # Skip values that carry no content.
            if not is_effective_value(value):
                continue
            # Serialize types that gRPC cannot carry directly.
            grpc_request_kwargs[field] = serialize_value(value)

        # Drop the None entries that serialization may have produced.
        return remove_none_from_dict(grpc_request_kwargs)

    @staticmethod
    def _common_response_fields(item) -> dict:
        """Decode the fields shared by every response item (JSON strings are parsed)."""
        return {
            "content": item.content,
            "usage": json.loads(item.usage) if item.usage else None,
            "raw_response": json.loads(item.raw_response) if item.raw_response else None,
            "error": item.error or None,
        }

    async def _stream(self, model_request, metadata, invoke_timeout) -> AsyncIterator[ModelResponse]:
        """Yield one ``ModelResponse`` per item of the server stream.

        Raises:
            ConnectionError: If the gRPC call fails.
            ValidationError: For any other error raised while streaming.
        """
        try:
            async for response in self.stub.Invoke(model_request, metadata=metadata, timeout=invoke_timeout):
                yield ModelResponse(
                    **self._common_response_fields(response),
                    # Populated for consistency with the non-streaming path.
                    request_id=response.request_id if response.request_id else None,
                )
        except grpc.RpcError as e:
            raise ConnectionError(f"gRPC call failed: {str(e)}") from e
        except Exception as e:
            raise ValidationError(f"Invalid input: {str(e)}") from e

    async def invoke(self, model_request: ModelRequest, timeout: Optional[float] = None) -> Union[
        ModelResponse, AsyncIterator[ModelResponse]]:
        """Invoke a model with a single request.

        Args:
            model_request: The request parameters.
            timeout: Call timeout in seconds; defaults to ``default_invoke_timeout``.

        Returns:
            An async iterator of ``ModelResponse`` when ``model_request.stream``
            is truthy, otherwise the first ``ModelResponse`` received (``None``
            if the server sends nothing).

        Raises:
            ValueError: If the request cannot be built.
            ConnectionError: If the connection cannot be established, or if a
                streaming call fails.
        """
        await self._ensure_initialized()

        if not self.default_payload:
            self.default_payload = {
                "org_id": model_request.user_context.org_id or "",
                "user_id": model_request.user_context.user_id or ""
            }

        # The input fields sent depend on the provider / invoke type.
        try:
            grpc_request_kwargs = self._build_extra(model_request)
            request = model_service_pb2.ModelRequestItem(
                provider=model_request.provider.value,
                channel=model_request.channel.value,
                invoke_type=model_request.invoke_type.value,
                stream=model_request.stream or False,
                org_id=model_request.user_context.org_id or "",
                user_id=model_request.user_context.user_id or "",
                client_type=model_request.user_context.client_type or "",
                extra=grpc_request_kwargs
            )
        except Exception as e:
            raise ValueError(f"构建请求失败: {str(e)}") from e

        metadata = self._build_auth_metadata()

        invoke_timeout = timeout or self.default_invoke_timeout
        if model_request.stream:
            return self._stream(request, metadata, invoke_timeout)

        # NOTE(review): unlike _stream and invoke_batch, a grpc.RpcError raised
        # here propagates unwrapped. Left as is so existing callers keep
        # catching the same exception type.
        async for response in self.stub.Invoke(request, metadata=metadata, timeout=invoke_timeout):
            # Non-streaming: only the first response item is returned.
            return ModelResponse(
                **self._common_response_fields(response),
                custom_id=None,
                request_id=response.request_id if response.request_id else None,
            )

    async def invoke_batch(self, batch_request_model: BatchModelRequest, timeout: Optional[float] = None) -> \
            BatchModelResponse:
        """Invoke several model requests in one batch call.

        Args:
            batch_request_model: The batch of requests and the shared user context.
            timeout: Call timeout in seconds; defaults to ``default_invoke_timeout``.

        Returns:
            BatchModelResponse: The batch request id and the per-item responses.

        Raises:
            ValueError: If an item cannot be turned into a request.
            ConnectionError: If the connection or the gRPC call fails.
        """
        await self._ensure_initialized()

        if not self.default_payload:
            self.default_payload = {
                "org_id": batch_request_model.user_context.org_id or "",
                "user_id": batch_request_model.user_context.user_id or ""
            }

        metadata = self._build_auth_metadata()

        # Build the batch request.
        items = []
        for model_request_item in batch_request_model.items:
            try:
                grpc_request_kwargs = self._build_extra(model_request_item)
                items.append(model_service_pb2.ModelRequestItem(
                    provider=model_request_item.provider.value,
                    channel=model_request_item.channel.value,
                    invoke_type=model_request_item.invoke_type.value,
                    stream=model_request_item.stream or False,
                    custom_id=model_request_item.custom_id or "",
                    priority=model_request_item.priority or 1,
                    org_id=batch_request_model.user_context.org_id or "",
                    user_id=batch_request_model.user_context.user_id or "",
                    client_type=batch_request_model.user_context.client_type or "",
                    extra=grpc_request_kwargs,
                ))
            except Exception as e:
                raise ValueError(f"构建请求失败: {str(e)},item={model_request_item.custom_id}") from e

        try:
            invoke_timeout = timeout or self.default_invoke_timeout

            # Call the gRPC endpoint.
            response = await self.stub.BatchInvoke(
                model_service_pb2.ModelRequest(items=items),
                timeout=invoke_timeout,
                metadata=metadata
            )

            result = [
                ModelResponse(
                    **self._common_response_fields(res_item),
                    custom_id=res_item.custom_id if res_item.custom_id else None
                )
                for res_item in response.items
            ]
            return BatchModelResponse(
                request_id=response.request_id if response.request_id else None,
                responses=result
            )
        except grpc.RpcError as e:
            raise ConnectionError(f"BatchInvoke failed: {str(e)}") from e

    async def close(self):
        """Close the gRPC channel once; later calls are no-ops."""
        if self.channel and not self._closed:
            # The channel is closed exactly once.
            await self.channel.close()
            self._closed = True
            logger.info("✅ gRPC channel closed")

    def _safe_sync_close(self):
        """Close the channel at process exit, whether or not the event loop is running."""
        if self.channel and not self._closed:
            try:
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    loop.create_task(self.close())
                else:
                    loop.run_until_complete(self.close())
            except Exception as e:
                logger.warning(f"❌ gRPC channel close failed at exit: {e}")

    async def __aenter__(self):
        """Support ``async with``: initialize the connection on entry."""
        await self._ensure_initialized()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Support ``async with``: close the connection on exit."""
        await self.close()
|
@@ -0,0 +1,14 @@
|
|
1
|
+
import time
|
2
|
+
import jwt
|
3
|
+
|
4
|
+
|
5
|
+
# JWT handling
class JWTAuthHandler:
    """Sign JWTs with a shared secret key using the HS256 algorithm."""

    def __init__(self, secret_key: str):
        """Remember the secret key used to sign tokens."""
        self.secret_key = secret_key

    def encode_token(self, payload: dict, expires_in: int = 3600) -> str:
        """Create a JWT carrying ``payload`` plus an ``exp`` expiry claim.

        Args:
            payload: Claims to embed. It is copied, never mutated; ``None`` or
                an empty mapping yields a token with only the ``exp`` claim.
            expires_in: Token lifetime in seconds, counted from now.

        Returns:
            The encoded token as a string.
        """
        # Copy so the caller's dict is untouched; tolerate a missing payload.
        claims = dict(payload) if payload else {}
        claims["exp"] = int(time.time()) + expires_in
        token = jwt.encode(claims, self.secret_key, algorithm="HS256")
        # PyJWT 1.x returns bytes while 2.x returns str; always hand back str
        # so the token can be formatted into an Authorization header.
        if isinstance(token, bytes):
            token = token.decode("utf-8")
        return token
|
@@ -0,0 +1,11 @@
|
|
1
|
+
class ModelManagerClientError(Exception):
    """Root of the exception hierarchy raised by the Model Manager client."""
|
4
|
+
|
5
|
+
class ConnectionError(ModelManagerClientError):
    """Signals that connecting to, or calling, the gRPC server failed."""
|
8
|
+
|
9
|
+
class ValidationError(ModelManagerClientError):
    """Signals that the supplied input failed validation."""
|
File without changes
|
@@ -0,0 +1,45 @@
|
|
1
|
+
# -*- coding: utf-8 -*-
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
4
|
+
# source: model_service.proto
|
5
|
+
# Protobuf Python Version: 5.29.0
|
6
|
+
"""Generated protocol buffer code."""
|
7
|
+
from google.protobuf import descriptor as _descriptor
|
8
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
10
|
+
from google.protobuf import symbol_database as _symbol_database
|
11
|
+
from google.protobuf.internal import builder as _builder
|
12
|
+
_runtime_version.ValidateProtobufRuntimeVersion(
|
13
|
+
_runtime_version.Domain.PUBLIC,
|
14
|
+
5,
|
15
|
+
29,
|
16
|
+
0,
|
17
|
+
'',
|
18
|
+
'model_service.proto'
|
19
|
+
)
|
20
|
+
# @@protoc_insertion_point(imports)
|
21
|
+
|
22
|
+
_sym_db = _symbol_database.Default()
|
23
|
+
|
24
|
+
|
25
|
+
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
|
26
|
+
|
27
|
+
|
28
|
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13model_service.proto\x12\rmodel_service\x1a\x1cgoogle/protobuf/struct.proto\"\x82\x02\n\x10ModelRequestItem\x12\x10\n\x08provider\x18\x01 \x01(\t\x12\x0f\n\x07\x63hannel\x18\x02 \x01(\t\x12\x13\n\x0binvoke_type\x18\x03 \x01(\t\x12\x0e\n\x06stream\x18\x04 \x01(\x08\x12\x0e\n\x06org_id\x18\x05 \x01(\t\x12\x0f\n\x07user_id\x18\x06 \x01(\t\x12\x13\n\x0b\x63lient_type\x18\x07 \x01(\t\x12\x15\n\x08priority\x18\x08 \x01(\x05H\x00\x88\x01\x01\x12\x16\n\tcustom_id\x18\t \x01(\tH\x01\x88\x01\x01\x12&\n\x05\x65xtra\x18\n \x01(\x0b\x32\x17.google.protobuf.StructB\x0b\n\t_priorityB\x0c\n\n_custom_id\">\n\x0cModelRequest\x12.\n\x05items\x18\x01 \x03(\x0b\x32\x1f.model_service.ModelRequestItem\"\xa6\x01\n\x11ModelResponseItem\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\t\x12\r\n\x05usage\x18\x02 \x01(\t\x12\x14\n\x0craw_response\x18\x03 \x01(\t\x12\r\n\x05\x65rror\x18\x04 \x01(\t\x12\x16\n\tcustom_id\x18\x05 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x06 \x01(\tH\x01\x88\x01\x01\x42\x0c\n\n_custom_idB\r\n\x0b_request_id\"T\n\rModelResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12/\n\x05items\x18\x02 \x03(\x0b\x32 .model_service.ModelResponseItem2\xa7\x01\n\x0cModelService\x12M\n\x06Invoke\x12\x1f.model_service.ModelRequestItem\x1a .model_service.ModelResponseItem0\x01\x12H\n\x0b\x42\x61tchInvoke\x12\x1b.model_service.ModelRequest\x1a\x1c.model_service.ModelResponseb\x06proto3')
|
29
|
+
|
30
|
+
_globals = globals()
|
31
|
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
32
|
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'model_service_pb2', _globals)
|
33
|
+
if not _descriptor._USE_C_DESCRIPTORS:
|
34
|
+
DESCRIPTOR._loaded_options = None
|
35
|
+
_globals['_MODELREQUESTITEM']._serialized_start=69
|
36
|
+
_globals['_MODELREQUESTITEM']._serialized_end=327
|
37
|
+
_globals['_MODELREQUEST']._serialized_start=329
|
38
|
+
_globals['_MODELREQUEST']._serialized_end=391
|
39
|
+
_globals['_MODELRESPONSEITEM']._serialized_start=394
|
40
|
+
_globals['_MODELRESPONSEITEM']._serialized_end=560
|
41
|
+
_globals['_MODELRESPONSE']._serialized_start=562
|
42
|
+
_globals['_MODELRESPONSE']._serialized_end=646
|
43
|
+
_globals['_MODELSERVICE']._serialized_start=649
|
44
|
+
_globals['_MODELSERVICE']._serialized_end=816
|
45
|
+
# @@protoc_insertion_point(module_scope)
|