isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.1.1.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
@@ -1,239 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Third-party Services Backend - External API services with wrappers.
|
3
|
-
Examples: OpenAI, Anthropic, Cohere, Google AI, Azure OpenAI
|
4
|
-
"""
|
5
|
-
|
6
|
-
import aiohttp
|
7
|
-
import json
|
8
|
-
from typing import Dict, Any, List, Optional
|
9
|
-
from .base_backend_client import BaseBackendClient
|
10
|
-
|
11
|
-
|
12
|
-
class OpenAIClient(BaseBackendClient):
    """Thin async wrapper around the OpenAI REST API.

    Every request opens its own short-lived ``aiohttp.ClientSession`` and
    returns the decoded JSON body unchanged. HTTP error statuses are not
    raised, so callers must inspect the returned dict themselves.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.openai.com/v1"):
        self.api_key = api_key
        # Drop any trailing slash so endpoint paths can be appended as "/...".
        self.base_url = base_url.rstrip('/')
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    async def _post_json(self, path: str, body: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``body`` as JSON to ``base_url + path`` and return the decoded reply."""
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{self.base_url}{path}", json=body, headers=self.headers) as reply:
                return await reply.json()

    async def generate_completion(self, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
        """Call ``POST /completions``.

        ``max_tokens`` defaults to 100 and ``temperature`` to 0.7. Any extra
        keyword arguments are merged into the request body and override
        those defaults.
        """
        body = {
            "model": model,
            "prompt": prompt,
            "max_tokens": kwargs.get("max_tokens", 100),
            "temperature": kwargs.get("temperature", 0.7),
        }
        body.update(kwargs)
        return await self._post_json("/completions", body)

    async def generate_chat_completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """Call ``POST /chat/completions`` with the same defaults as ``generate_completion``."""
        body = {
            "model": model,
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", 100),
            "temperature": kwargs.get("temperature", 0.7),
        }
        body.update(kwargs)
        return await self._post_json("/chat/completions", body)

    async def generate_embeddings(self, model: str, input_text: str, **kwargs) -> Dict[str, Any]:
        """Call ``POST /embeddings``; extra keyword arguments go straight into the body."""
        body = {"model": model, "input": input_text}
        body.update(kwargs)
        return await self._post_json("/embeddings", body)

    async def health_check(self) -> bool:
        """Return True when ``GET /models`` answers 200; any failure yields False."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(f"{self.base_url}/models", headers=self.headers) as reply:
                    return reply.status == 200
        except Exception:
            return False
|
80
|
-
|
81
|
-
|
82
|
-
class AnthropicClient(BaseBackendClient):
    """Thin async wrapper around the Anthropic Messages API.

    Each request opens a short-lived ``aiohttp.ClientSession`` and returns the
    decoded JSON body. HTTP error statuses are not raised.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.anthropic.com/v1"):
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
            # The Messages API requires an explicit API version header.
            "anthropic-version": "2023-06-01",
        }

    async def generate_chat_completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """Call ``POST /messages``.

        ``messages`` may use the OpenAI convention of ``{"role": "system"}``
        entries. The Messages API does not accept a system role inside
        ``messages``, so those entries are removed from the list. Their
        contents are joined into the top-level ``system`` parameter, after
        any ``system`` value passed explicitly via kwargs. ``max_tokens``
        defaults to 100. Other kwargs are merged into the request body.
        """
        system_parts = [m.get("content", "") for m in messages if m.get("role") == "system"]
        chat_messages = [m for m in messages if m.get("role") != "system"]
        payload = {
            "model": model,
            "messages": chat_messages,
            "max_tokens": kwargs.get("max_tokens", 100),
            **kwargs,
        }
        if system_parts:
            explicit = kwargs.get("system")
            if isinstance(explicit, list):
                # Content-block form of ``system``: append the extracted prompts as text blocks.
                payload["system"] = explicit + [{"type": "text", "text": part} for part in system_parts]
            else:
                payload["system"] = "\n\n".join(([explicit] if explicit else []) + system_parts)
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.base_url}/messages",
                json=payload,
                headers=self.headers
            ) as response:
                return await response.json()

    async def health_check(self) -> bool:
        """Return True if ``GET /models`` answers 200 with the configured key.

        Querying the bare base URL (and accepting a 404) only proved network
        reachability and never validated the credentials.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(f"{self.base_url}/models", headers=self.headers) as response:
                    return response.status == 200
        except Exception:
            return False
|
119
|
-
|
120
|
-
|
121
|
-
class CohereClient(BaseBackendClient):
    """Thin async wrapper around the Cohere v1 REST API.

    Each request opens a short-lived ``aiohttp.ClientSession`` and returns the
    decoded JSON body. HTTP error statuses are not raised.
    """

    def __init__(self, api_key: str, base_url: str = "https://api.cohere.ai/v1"):
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

    async def generate_completion(self, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
        """Call ``POST /generate``.

        ``max_tokens`` defaults to 100 and ``temperature`` to 0.7. Extra
        kwargs are merged into the request body and override those defaults.
        NOTE(review): Cohere marks Generate as a legacy endpoint superseded by
        Chat. Confirm it is still served for the models in use.
        """
        async with aiohttp.ClientSession() as session:
            payload = {
                "model": model,
                "prompt": prompt,
                "max_tokens": kwargs.get("max_tokens", 100),
                "temperature": kwargs.get("temperature", 0.7),
                **kwargs
            }
            async with session.post(
                f"{self.base_url}/generate",
                json=payload,
                headers=self.headers
            ) as response:
                return await response.json()

    async def generate_embeddings(self, model: str, texts: List[str], **kwargs) -> Dict[str, Any]:
        """Call ``POST /embed`` for a batch of texts.

        Extra kwargs go straight into the body. NOTE(review): newer embed
        models may require ``input_type``; pass it via kwargs if so.
        """
        async with aiohttp.ClientSession() as session:
            payload = {
                "model": model,
                "texts": texts,
                **kwargs
            }
            async with session.post(
                f"{self.base_url}/embed",
                json=payload,
                headers=self.headers
            ) as response:
                return await response.json()

    async def health_check(self) -> bool:
        """Return True if Cohere's ``/check-api-key`` endpoint answers 200.

        That endpoint must be called with POST; a GET request is not accepted
        there, so the check could not succeed.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(f"{self.base_url}/check-api-key", headers=self.headers) as response:
                    return response.status == 200
        except Exception:
            return False
|
172
|
-
|
173
|
-
|
174
|
-
class AzureOpenAIClient(BaseBackendClient):
    """Thin async wrapper around an Azure OpenAI resource.

    Requests target a *deployment name* rather than a model name. Every URL
    carries the ``api-version`` query parameter fixed at construction time.
    """

    def __init__(self, api_key: str, endpoint: str, api_version: str = "2023-12-01-preview"):
        self.api_key = api_key
        self.endpoint = endpoint.rstrip('/')
        self.api_version = api_version
        # Azure authenticates with an "api-key" header rather than a Bearer token.
        self.headers = {"api-key": api_key, "Content-Type": "application/json"}

    async def generate_chat_completion(self, deployment_name: str, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
        """POST a chat completion to the given deployment and return the JSON reply.

        ``max_tokens`` defaults to 100 and ``temperature`` to 0.7. Extra
        kwargs are merged into the body and override them.
        """
        async with aiohttp.ClientSession() as session:
            body = {
                "messages": messages,
                "max_tokens": kwargs.get("max_tokens", 100),
                "temperature": kwargs.get("temperature", 0.7),
            }
            body.update(kwargs)
            chat_url = (
                f"{self.endpoint}/openai/deployments/{deployment_name}"
                f"/chat/completions?api-version={self.api_version}"
            )
            async with session.post(chat_url, json=body, headers=self.headers) as reply:
                return await reply.json()

    async def health_check(self) -> bool:
        """Return True when the resource's models listing answers 200; False on any failure."""
        try:
            async with aiohttp.ClientSession() as session:
                models_url = f"{self.endpoint}/openai/models?api-version={self.api_version}"
                async with session.get(models_url, headers=self.headers) as reply:
                    return reply.status == 200
        except Exception:
            return False
|
208
|
-
|
209
|
-
|
210
|
-
class GoogleAIClient(BaseBackendClient):
    """Thin async wrapper around the Google AI (Gemini) REST API.

    Each request opens a short-lived ``aiohttp.ClientSession`` and returns the
    decoded JSON body. HTTP error statuses are not raised.
    """

    def __init__(self, api_key: str, base_url: str = "https://generativelanguage.googleapis.com/v1"):
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')
        # Send the key in the x-goog-api-key header instead of a "?key=" query
        # parameter, so the secret never appears in request URLs (and thus in
        # proxy/access logs or error messages that echo the URL).
        self.headers = {"x-goog-api-key": api_key}

    @staticmethod
    def _model_path(model: str) -> str:
        """Return ``models/<id>`` whether *model* is a bare id or already prefixed."""
        return model if model.startswith("models/") else f"models/{model}"

    async def generate_completion(self, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
        """Call ``generateContent`` with a single text part.

        Only ``max_tokens`` (default 100) and ``temperature`` (default 0.7)
        are forwarded, mapped into ``generationConfig``. Other kwargs are
        ignored.
        """
        payload = {
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {
                "maxOutputTokens": kwargs.get("max_tokens", 100),
                "temperature": kwargs.get("temperature", 0.7),
            }
        }
        url = f"{self.base_url}/{self._model_path(model)}:generateContent"
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload, headers=self.headers) as response:
                return await response.json()

    async def health_check(self) -> bool:
        """Return True if the models listing answers 200 with the configured key."""
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(f"{self.base_url}/models", headers=self.headers) as response:
                    return response.status == 200
        except Exception:
            return False
|
@@ -1,97 +0,0 @@
|
|
1
|
-
|
2
|
-
import aiohttp
|
3
|
-
import json
|
4
|
-
from typing import Dict, Any, List, Optional, AsyncGenerator
|
5
|
-
from .base_backend_client import BaseBackendClient
|
6
|
-
|
7
|
-
|
8
|
-
class TritonBackendClient(BaseBackendClient):
    """Pure connection client for Triton Inference Server's HTTP API.

    Holds one lazily created ``aiohttp.ClientSession`` (120 s total timeout).
    The session is reused across requests until :meth:`close` is called.
    """

    def __init__(self, url: str = "localhost:8000", protocol: str = "http"):
        """
        Args:
            url: ``host:port`` (``http://`` is prepended) or a full URL.
            protocol: Stored as ``self.protocol``; not used by this class.
        """
        self.base_url = f"http://{url}" if not url.startswith("http") else url
        self.protocol = protocol
        self._session = None

    async def _get_session(self):
        """Return the shared session, (re)creating it if missing or already closed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
        return self._session

    async def post(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` as JSON to ``base_url + endpoint``.

        Raises ``aiohttp.ClientResponseError`` for HTTP status >= 400.
        Returns the decoded JSON reply.
        """
        session = await self._get_session()
        async with session.post(f"{self.base_url}{endpoint}", json=payload) as response:
            response.raise_for_status()
            return await response.json()

    async def get(self, endpoint: str) -> Dict[str, Any]:
        """GET ``base_url + endpoint``.

        Raises ``aiohttp.ClientResponseError`` for HTTP status >= 400.
        Returns the decoded JSON reply.
        """
        session = await self._get_session()
        async with session.get(f"{self.base_url}{endpoint}") as response:
            response.raise_for_status()
            return await response.json()

    async def _status_ok(self, endpoint: str) -> bool:
        """Return True if GET ``endpoint`` answers with a status below 400; never raises.

        Health/readiness endpoints convey their result through the status
        code with an empty body. The body is deliberately not JSON-decoded,
        since decoding an empty body can raise and would turn "ready" into
        False.
        """
        try:
            session = await self._get_session()
            async with session.get(f"{self.base_url}{endpoint}") as response:
                return response.status < 400
        except Exception:
            return False

    async def model_ready(self, model_name: str) -> bool:
        """Check if model is ready"""
        return await self._status_ok(f"/v2/models/{model_name}/ready")

    async def model_metadata(self, model_name: str) -> Dict[str, Any]:
        """Get model metadata (decoded JSON from ``/v2/models/{model_name}``)."""
        return await self.get(f"/v2/models/{model_name}")

    async def server_ready(self) -> bool:
        """Check if server is ready"""
        return await self._status_ok("/v2/health/ready")

    async def health_check(self) -> bool:
        """Check server health (alias for :meth:`server_ready`)."""
        return await self.server_ready()

    async def close(self):
        """Close the HTTP session; a later request will open a new one."""
        if self._session:
            await self._session.close()
            self._session = None
|
65
|
-
|
66
|
-
|
67
|
-
# Keep old class name for backward compatibility
class TritonClient(TritonBackendClient):
    """Legacy name for :class:`TritonBackendClient`, kept for old call sites.

    The old ``infer``/``stream`` interface is no longer implemented. Both
    methods raise ``NotImplementedError``.
    """

    def __init__(self, backend_connector_config: Dict = None, config: Dict = None):
        # Only the "url" key of the connector config is honoured; ``config``
        # is accepted for signature compatibility and ignored.
        connector = backend_connector_config or {}
        super().__init__(connector.get("url", "localhost:8000"))

    async def infer(self,
                    model_runtime_config: Dict,
                    unified_request_payload: Dict,
                    task_type: str,
                    request_id: str) -> Dict:
        """Removed legacy entry point; always raises NotImplementedError."""
        raise NotImplementedError("Use direct HTTP methods instead")

    async def stream(self,
                     model_runtime_config: Dict,
                     unified_request_payload: Dict,
                     task_type: str,
                     request_id: str) -> AsyncGenerator[Dict, None]:
        """Removed legacy streaming entry point.

        This is an async generator, so the NotImplementedError surfaces on
        the first iteration rather than at call time.
        """
        raise NotImplementedError("Use direct HTTP methods instead")
        yield  # unreachable; its presence makes this an async generator
|
@@ -1,134 +0,0 @@
|
|
1
|
-
# Universal Inference Client
|
2
|
-
"""
|
3
|
-
旨在为开发者提供一个高级、统一且易于使用的Python客户端库,用于与“通用推理平台”(Universal Inference Platform,
|
4
|
-
即您正在构建的整个系统)进行交互。该客户端封装了与平台后端 orchestrator_adapter 服务通信的所有复杂性,
|
5
|
-
允许用户通过面向任务(task-oriented)的方法来调用各种AI模型(语言、视觉、语音等),
|
6
|
-
而无需关心这些模型具体由哪个推理引擎(PyTorch, vLLM, Triton, Ollama)承载,或者它们是本地部署的模型还是外部API服务。
|
7
|
-
"""
|
8
|
-
|
9
|
-
|
10
|
-
import httpx
|
11
|
-
from typing import Dict, List, Union, Optional, AsyncGenerator
|
12
|
-
from .client_sdk_schema import *
|
13
|
-
|
14
|
-
class UniversalInferenceClient:
    """High-level async client for the Universal Inference Platform.

    Intended to hide all communication with the platform's orchestrator
    adapter, so callers can invoke language, vision and audio models through
    task-oriented methods. Callers need not know which engine (PyTorch, vLLM,
    Triton, Ollama, or an external API) serves them.

    The task methods are placeholders and currently raise
    ``NotImplementedError``. Their ``Unified*`` return annotations are quoted
    because several of those names (e.g. ``UnifiedEmbeddingResponse``) are
    not defined by ``client_sdk_schema`` yet. Evaluating them eagerly would
    raise ``NameError`` while this class is being defined.

    The underlying ``httpx.AsyncClient`` is long-lived. Release it with
    :meth:`aclose` or by using this object as an async context manager.
    """

    # NOTE(review): a hard-coded credential is kept only as a backward-compatible
    # default; pass adapter_key explicitly (e.g. from configuration) instead.
    DEFAULT_ADAPTER_URL = "http://adapter.isa_model.com/api/v1"
    DEFAULT_ADAPTER_KEY = "isa_model_adapter"

    def __init__(self, adapter_url: Optional[str] = None, adapter_key: Optional[str] = None):
        """
        Args:
            adapter_url: Base URL of the adapter API (default ``DEFAULT_ADAPTER_URL``).
            adapter_key: Bearer token for the adapter (default ``DEFAULT_ADAPTER_KEY``).
        """
        self.adapter_url = adapter_url if adapter_url is not None else self.DEFAULT_ADAPTER_URL
        self.adapter_key = adapter_key if adapter_key is not None else self.DEFAULT_ADAPTER_KEY
        self.client = httpx.AsyncClient(
            base_url=self.adapter_url,
            headers={"Authorization": f"Bearer {self.adapter_key}"}
        )

    async def aclose(self) -> None:
        """Close the underlying HTTP client."""
        await self.client.aclose()

    async def __aenter__(self) -> "UniversalInferenceClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    async def _make_request(self,
                            method: str,
                            url: str,
                            params: Optional[Dict] = None,
                            data: Optional[Dict] = None,
                            headers: Optional[Dict] = None,
                            **kwargs) -> Dict:
        """Send a request to the adapter service and return its decoded JSON body.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            url: Path relative to ``adapter_url`` (or an absolute URL).
            params: Query parameters.
            data: Request body passed to httpx as ``data``.
            headers: Extra headers. The caller's dict is not modified.
            **kwargs: Forwarded to ``httpx.AsyncClient.request``.

        Raises:
            httpx.HTTPStatusError: If the adapter answers with a 4xx/5xx status.
        """
        # Copy so the caller's dict is never mutated.
        request_headers = dict(headers or {})
        request_headers["Authorization"] = f"Bearer {self.adapter_key}"

        # Use the long-lived client directly: ``async with self.client`` would
        # close it after this request, and a closed httpx client cannot be reused.
        response = await self.client.request(
            method,
            url,
            params=params,
            data=data,
            headers=request_headers,
            **kwargs
        )
        response.raise_for_status()
        return response.json()

    async def invoke(self,
                     model_id: str,
                     raw_task_payload: Dict,
                     stream: bool = False,
                     **kwargs) -> Union[Dict, AsyncGenerator[Dict, None]]:
        raise NotImplementedError("invoke() is not implemented yet")

    async def chat(self,
                   model_id: str,
                   messages: List[Dict[str, str]],
                   stream: bool = False,
                   temperature: float = 0.7,
                   max_tokens: int = 1000) -> "Union[UnifiedChatResponse, AsyncGenerator[UnifiedChatResponse, None]]":
        raise NotImplementedError("chat() is not implemented yet")

    async def generate_text(self,
                            model_id: str,
                            prompt: str,
                            stream: bool = False,
                            temperature: float = 0.7,
                            max_tokens: int = 1000) -> "Union[UnifiedTextResponse, AsyncGenerator[UnifiedTextChunk, None]]":
        raise NotImplementedError("generate_text() is not implemented yet")

    async def embed(self,
                    model_id: str,
                    inputs: Union[str, List[str]],
                    input_type: str = "document",
                    **kwargs) -> "UnifiedEmbeddingResponse":
        raise NotImplementedError("embed() is not implemented yet")

    async def rerank(self,
                     model_id: str,
                     query: str,
                     documents: List[Union[str, Dict]],
                     top_k: Optional[int] = None,
                     **kwargs) -> "UnifiedRerankResponse":
        raise NotImplementedError("rerank() is not implemented yet")

    async def transcribe_audio(self,
                               model_id: str,
                               audio_data: bytes,
                               language: Optional[str] = None,
                               **kwargs) -> "UnifiedAudioTranscriptionResponse":
        raise NotImplementedError("transcribe_audio() is not implemented yet")

    async def generate_speech(self,
                              model_id: str,
                              text: str,
                              voice_id: Optional[str] = None,
                              **kwargs) -> "UnifiedSpeechGenerationResponse":
        raise NotImplementedError("generate_speech() is not implemented yet")

    async def analyze_image(self,
                            model_id: str,
                            image_data: bytes,
                            query: str,
                            **kwargs) -> "UnifiedImageAnalysisResponse":
        raise NotImplementedError("analyze_image() is not implemented yet")

    async def generate_image(self,
                             model_id: str,
                             prompt: str,
                             **kwargs) -> "UnifiedImageGenerationResponse":
        raise NotImplementedError("generate_image() is not implemented yet")

    async def generate_video(self,
                             model_id: str,
                             prompt: str,
                             **kwargs) -> "UnifiedVideoGenerationResponse":
        raise NotImplementedError("generate_video() is not implemented yet")

    async def generate_audio(self,
                             model_id: str,
                             text: str,
                             voice_id: Optional[str] = None,
                             **kwargs) -> "UnifiedAudioGenerationResponse":
        raise NotImplementedError("generate_audio() is not implemented yet")
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
@@ -1,34 +0,0 @@
|
|
1
|
-
# Inside UniversalInferenceClient.chat method
# Sketch: normalise the caller-supplied ``messages`` (plain role/content dicts
# or LangChain message objects) into the platform's standard message dicts.
# NOTE(review): assumes ``messages`` and the LangChain classes BaseMessage,
# HumanMessage, AIMessage, SystemMessage and ToolMessage are already in scope,
# presumably from langchain_core.messages — confirm.
import json

processed_messages = []
for user_msg in messages:  # messages is what the end-user provided to the SDK
    if isinstance(user_msg, dict) and "role" in user_msg and "content" in user_msg:
        # Already in the standard format; pass through unchanged.
        processed_messages.append(user_msg)
    elif isinstance(user_msg, BaseMessage):  # If user passes LangChain messages directly
        # Serialize BaseMessage to our standard dict format
        msg_dict = {"role": "", "content": user_msg.content}
        if isinstance(user_msg, HumanMessage):
            msg_dict["role"] = "user"
        elif isinstance(user_msg, AIMessage):
            msg_dict["role"] = "assistant"
            tool_calls = getattr(user_msg, "tool_calls", None)
            if tool_calls:
                # Serialize tool_calls to OpenAI-style function-call dicts;
                # assumes each entry is dict-like with id/name/args keys.
                msg_dict["tool_calls"] = [
                    {
                        "id": tc.get("id"),
                        "type": "function",
                        "function": {"name": tc.get("name"), "arguments": json.dumps(tc.get("args", {}))},
                    }
                    for tc in tool_calls
                ]
        elif isinstance(user_msg, SystemMessage):
            msg_dict["role"] = "system"
        elif isinstance(user_msg, ToolMessage):
            msg_dict["role"] = "tool"
            msg_dict["tool_call_id"] = user_msg.tool_call_id
        elif isinstance(getattr(user_msg, "role", None), str) and user_msg.role:
            # Generic messages that carry an explicit role string keep it.
            msg_dict["role"] = user_msg.role
        else:
            # An empty role would produce a message no backend can interpret,
            # so fail loudly instead of sending it.
            raise ValueError(
                f"Unsupported message type in chat method input: {type(user_msg).__name__}"
            )
        processed_messages.append(msg_dict)
    # ... (add more flexible input handling if needed, e.g., a list of tuples)
    else:
        raise ValueError("Unsupported message format in chat method input.")

# task_specific_payload for orchestrator would be:
# task_payload = {"messages": processed_messages}
# Then call self._invoke_orchestrator(model_id, task_payload, ...)
|
@@ -1,16 +0,0 @@
|
|
1
|
-
|
2
|
-
"""
|
3
|
-
在这里定义(使用Pydantic或Protobuf)一个标准的 ChatMessageSchema,它包含 role: str, content: Optional[str],
|
4
|
-
tool_calls: Optional[List[ToolCallSchema]], tool_call_id: Optional[str] 等字段。
|
5
|
-
UnifiedRestInvokeRequest (或gRPC的 UnifiedRequest) 中的 task_specific_payload 字段,在处理聊天任务时,
|
6
|
-
其内部的 "messages" 键对应的值就是 List[ChatMessageSchema]。
|
7
|
-
|
8
|
-
"""
|
9
|
-
|
10
|
-
from pydantic import BaseModel
|
11
|
-
|
12
|
-
class UnifiedChatResponse(BaseModel):
    # Placeholder response model: no fields are declared yet.
    # TODO: define the chat response schema. NOTE(review): client.py also
    # star-imports other Unified* names (e.g. UnifiedEmbeddingResponse) that
    # this module does not define yet.
    ...
|
14
|
-
|
15
|
-
class UnifiedTextResponse(BaseModel):
    # Placeholder response model for plain text generation: no fields are
    # declared yet. TODO: define the text response schema.
    ...
|
File without changes
|