isa-model 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/mlflow_gateway/__init__.py +8 -0
  12. isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
  13. isa_model/deployment/unified_multimodal_client.py +341 -0
  14. isa_model/inference/__init__.py +11 -0
  15. isa_model/inference/adapter/triton_adapter.py +453 -0
  16. isa_model/inference/adapter/unified_api.py +248 -0
  17. isa_model/inference/ai_factory.py +354 -0
  18. isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
  19. isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
  20. isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
  21. isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
  22. isa_model/inference/backends/__init__.py +53 -0
  23. isa_model/inference/backends/base_backend_client.py +26 -0
  24. isa_model/inference/backends/container_services.py +104 -0
  25. isa_model/inference/backends/local_services.py +72 -0
  26. isa_model/inference/backends/openai_client.py +130 -0
  27. isa_model/inference/backends/replicate_client.py +197 -0
  28. isa_model/inference/backends/third_party_services.py +239 -0
  29. isa_model/inference/backends/triton_client.py +97 -0
  30. isa_model/inference/base.py +46 -0
  31. isa_model/inference/client_sdk/__init__.py +0 -0
  32. isa_model/inference/client_sdk/client.py +134 -0
  33. isa_model/inference/client_sdk/client_data_std.py +34 -0
  34. isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
  35. isa_model/inference/client_sdk/exceptions.py +0 -0
  36. isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
  37. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
  38. isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
  39. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
  40. isa_model/inference/providers/__init__.py +19 -0
  41. isa_model/inference/providers/base_provider.py +30 -0
  42. isa_model/inference/providers/model_cache_manager.py +341 -0
  43. isa_model/inference/providers/ollama_provider.py +73 -0
  44. isa_model/inference/providers/openai_provider.py +87 -0
  45. isa_model/inference/providers/replicate_provider.py +94 -0
  46. isa_model/inference/providers/triton_provider.py +439 -0
  47. isa_model/inference/providers/vllm_provider.py +0 -0
  48. isa_model/inference/providers/yyds_provider.py +83 -0
  49. isa_model/inference/services/__init__.py +14 -0
  50. isa_model/inference/services/audio/fish_speech/handler.py +215 -0
  51. isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
  52. isa_model/inference/services/audio/triton_speech_service.py +138 -0
  53. isa_model/inference/services/audio/whisper_service.py +186 -0
  54. isa_model/inference/services/audio/yyds_audio_service.py +71 -0
  55. isa_model/inference/services/base_service.py +106 -0
  56. isa_model/inference/services/base_tts_service.py +66 -0
  57. isa_model/inference/services/embedding/bge_service.py +183 -0
  58. isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
  59. isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
  60. isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
  61. isa_model/inference/services/llm/__init__.py +16 -0
  62. isa_model/inference/services/llm/gemma_service.py +143 -0
  63. isa_model/inference/services/llm/llama_service.py +143 -0
  64. isa_model/inference/services/llm/ollama_llm_service.py +108 -0
  65. isa_model/inference/services/llm/openai_llm_service.py +129 -0
  66. isa_model/inference/services/llm/replicate_llm_service.py +179 -0
  67. isa_model/inference/services/llm/triton_llm_service.py +230 -0
  68. isa_model/inference/services/others/table_transformer_service.py +61 -0
  69. isa_model/inference/services/vision/__init__.py +12 -0
  70. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  71. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  72. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  73. isa_model/inference/services/vision/replicate_vision_service.py +241 -0
  74. isa_model/inference/services/vision/triton_vision_service.py +199 -0
  75. isa_model/inference/services/vision/yyds_vision_service.py +80 -0
  76. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  77. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  78. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  79. isa_model/scripts/inference_tracker.py +283 -0
  80. isa_model/scripts/mlflow_manager.py +379 -0
  81. isa_model/scripts/model_registry.py +465 -0
  82. isa_model/scripts/start_mlflow.py +95 -0
  83. isa_model/scripts/training_tracker.py +257 -0
  84. isa_model/training/engine/llama_factory/__init__.py +39 -0
  85. isa_model/training/engine/llama_factory/config.py +115 -0
  86. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  87. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  88. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  89. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  90. isa_model/training/engine/llama_factory/factory.py +331 -0
  91. isa_model/training/engine/llama_factory/rl.py +254 -0
  92. isa_model/training/engine/llama_factory/trainer.py +171 -0
  93. isa_model/training/image_model/configs/create_config.py +37 -0
  94. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  95. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  96. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  97. isa_model/training/image_model/prepare_upload.py +17 -0
  98. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  99. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  100. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  101. isa_model/training/image_model/train/train.py +42 -0
  102. isa_model/training/image_model/train/train_flux.py +41 -0
  103. isa_model/training/image_model/train/train_lora.py +57 -0
  104. isa_model/training/image_model/train_main.py +25 -0
  105. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  106. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  107. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  108. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  109. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  110. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  111. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  112. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  113. isa_model-0.1.0.dist-info/METADATA +116 -0
  114. isa_model-0.1.0.dist-info/RECORD +117 -0
  115. isa_model-0.1.0.dist-info/WHEEL +5 -0
  116. isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
  117. isa_model-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,239 @@
+ """
+ Third-party Services Backend - External API services with wrappers.
+ Examples: OpenAI, Anthropic, Cohere, Google AI, Azure OpenAI
+ """
+
+ import aiohttp
+ import json
+ from typing import Dict, Any, List, Optional
+ from .base_backend_client import BaseBackendClient
+
+
+ class OpenAIClient(BaseBackendClient):
+     """Wrapper for OpenAI API"""
+
+     def __init__(self, api_key: str, base_url: str = "https://api.openai.com/v1"):
+         self.api_key = api_key
+         self.base_url = base_url.rstrip('/')
+         self.headers = {
+             "Authorization": f"Bearer {api_key}",
+             "Content-Type": "application/json"
+         }
+
+     async def generate_completion(self, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
+         """Generate completion using OpenAI API"""
+         async with aiohttp.ClientSession() as session:
+             payload = {
+                 "model": model,
+                 "prompt": prompt,
+                 "max_tokens": kwargs.get("max_tokens", 100),
+                 "temperature": kwargs.get("temperature", 0.7),
+                 **kwargs
+             }
+             async with session.post(
+                 f"{self.base_url}/completions",
+                 json=payload,
+                 headers=self.headers
+             ) as response:
+                 return await response.json()
+
+     async def generate_chat_completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
+         """Generate chat completion using OpenAI API"""
+         async with aiohttp.ClientSession() as session:
+             payload = {
+                 "model": model,
+                 "messages": messages,
+                 "max_tokens": kwargs.get("max_tokens", 100),
+                 "temperature": kwargs.get("temperature", 0.7),
+                 **kwargs
+             }
+             async with session.post(
+                 f"{self.base_url}/chat/completions",
+                 json=payload,
+                 headers=self.headers
+             ) as response:
+                 return await response.json()
+
+     async def generate_embeddings(self, model: str, input_text: str, **kwargs) -> Dict[str, Any]:
+         """Generate embeddings using OpenAI API"""
+         async with aiohttp.ClientSession() as session:
+             payload = {
+                 "model": model,
+                 "input": input_text,
+                 **kwargs
+             }
+             async with session.post(
+                 f"{self.base_url}/embeddings",
+                 json=payload,
+                 headers=self.headers
+             ) as response:
+                 return await response.json()
+
+     async def health_check(self) -> bool:
+         """Check if OpenAI API is accessible"""
+         try:
+             async with aiohttp.ClientSession() as session:
+                 async with session.get(f"{self.base_url}/models", headers=self.headers) as response:
+                     return response.status == 200
+         except Exception:
+             return False
+
+
+ class AnthropicClient(BaseBackendClient):
+     """Wrapper for Anthropic Claude API"""
+
+     def __init__(self, api_key: str, base_url: str = "https://api.anthropic.com/v1"):
+         self.api_key = api_key
+         self.base_url = base_url.rstrip('/')
+         self.headers = {
+             "x-api-key": api_key,
+             "Content-Type": "application/json",
+             "anthropic-version": "2023-06-01"
+         }
+
+     async def generate_chat_completion(self, model: str, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
+         """Generate chat completion using Anthropic API"""
+         async with aiohttp.ClientSession() as session:
+             payload = {
+                 "model": model,
+                 "messages": messages,
+                 "max_tokens": kwargs.get("max_tokens", 100),
+                 **kwargs
+             }
+             async with session.post(
+                 f"{self.base_url}/messages",
+                 json=payload,
+                 headers=self.headers
+             ) as response:
+                 return await response.json()
+
+     async def health_check(self) -> bool:
+         """Check if Anthropic API is accessible"""
+         try:
+             # Anthropic doesn't have a models endpoint, so we'll just check the base URL
+             async with aiohttp.ClientSession() as session:
+                 async with session.get(self.base_url, headers=self.headers) as response:
+                     return response.status in [200, 404]  # 404 is also acceptable for the base URL
+         except Exception:
+             return False
+
+
+ class CohereClient(BaseBackendClient):
+     """Wrapper for Cohere API"""
+
+     def __init__(self, api_key: str, base_url: str = "https://api.cohere.ai/v1"):
+         self.api_key = api_key
+         self.base_url = base_url.rstrip('/')
+         self.headers = {
+             "Authorization": f"Bearer {api_key}",
+             "Content-Type": "application/json"
+         }
+
+     async def generate_completion(self, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
+         """Generate completion using Cohere API"""
+         async with aiohttp.ClientSession() as session:
+             payload = {
+                 "model": model,
+                 "prompt": prompt,
+                 "max_tokens": kwargs.get("max_tokens", 100),
+                 "temperature": kwargs.get("temperature", 0.7),
+                 **kwargs
+             }
+             async with session.post(
+                 f"{self.base_url}/generate",
+                 json=payload,
+                 headers=self.headers
+             ) as response:
+                 return await response.json()
+
+     async def generate_embeddings(self, model: str, texts: List[str], **kwargs) -> Dict[str, Any]:
+         """Generate embeddings using Cohere API"""
+         async with aiohttp.ClientSession() as session:
+             payload = {
+                 "model": model,
+                 "texts": texts,
+                 **kwargs
+             }
+             async with session.post(
+                 f"{self.base_url}/embed",
+                 json=payload,
+                 headers=self.headers
+             ) as response:
+                 return await response.json()
+
+     async def health_check(self) -> bool:
+         """Check if Cohere API is accessible"""
+         try:
+             async with aiohttp.ClientSession() as session:
+                 async with session.get(f"{self.base_url}/check-api-key", headers=self.headers) as response:
+                     return response.status == 200
+         except Exception:
+             return False
+
+
+ class AzureOpenAIClient(BaseBackendClient):
+     """Wrapper for Azure OpenAI API"""
+
+     def __init__(self, api_key: str, endpoint: str, api_version: str = "2023-12-01-preview"):
+         self.api_key = api_key
+         self.endpoint = endpoint.rstrip('/')
+         self.api_version = api_version
+         self.headers = {
+             "api-key": api_key,
+             "Content-Type": "application/json"
+         }
+
+     async def generate_chat_completion(self, deployment_name: str, messages: List[Dict[str, str]], **kwargs) -> Dict[str, Any]:
+         """Generate chat completion using Azure OpenAI API"""
+         async with aiohttp.ClientSession() as session:
+             payload = {
+                 "messages": messages,
+                 "max_tokens": kwargs.get("max_tokens", 100),
+                 "temperature": kwargs.get("temperature", 0.7),
+                 **kwargs
+             }
+             url = f"{self.endpoint}/openai/deployments/{deployment_name}/chat/completions?api-version={self.api_version}"
+             async with session.post(url, json=payload, headers=self.headers) as response:
+                 return await response.json()
+
+     async def health_check(self) -> bool:
+         """Check if Azure OpenAI API is accessible"""
+         try:
+             async with aiohttp.ClientSession() as session:
+                 url = f"{self.endpoint}/openai/models?api-version={self.api_version}"
+                 async with session.get(url, headers=self.headers) as response:
+                     return response.status == 200
+         except Exception:
+             return False
+
+
+ class GoogleAIClient(BaseBackendClient):
+     """Wrapper for Google AI (Gemini) API"""
+
+     def __init__(self, api_key: str, base_url: str = "https://generativelanguage.googleapis.com/v1"):
+         self.api_key = api_key
+         self.base_url = base_url.rstrip('/')
+
+     async def generate_completion(self, model: str, prompt: str, **kwargs) -> Dict[str, Any]:
+         """Generate completion using Google AI API"""
+         async with aiohttp.ClientSession() as session:
+             payload = {
+                 "contents": [{"parts": [{"text": prompt}]}],
+                 "generationConfig": {
+                     "maxOutputTokens": kwargs.get("max_tokens", 100),
+                     "temperature": kwargs.get("temperature", 0.7),
+                 }
+             }
+             url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
+             async with session.post(url, json=payload) as response:
+                 return await response.json()
+
+     async def health_check(self) -> bool:
+         """Check if Google AI API is accessible"""
+         try:
+             async with aiohttp.ClientSession() as session:
+                 url = f"{self.base_url}/models?key={self.api_key}"
+                 async with session.get(url) as response:
+                     return response.status == 200
+         except Exception:
+             return False
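Taken together, these wrappers share one informal contract: the constructor takes credentials, each task method returns the provider's raw JSON, and health_check() answers a bool. A minimal usage sketch, not part of the package itself — it assumes valid API keys and imports the classes via the path shown in the file list above:

import asyncio
from isa_model.inference.backends.third_party_services import OpenAIClient, AnthropicClient

async def main() -> None:
    openai = OpenAIClient(api_key="sk-...")  # placeholder key

    # Every wrapper exposes health_check(), so a router can probe them uniformly.
    if await openai.health_check():
        result = await openai.generate_chat_completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello!"}],
            max_tokens=50,
        )
        # Raw OpenAI response shape; no unwrapping is done by the client.
        print(result["choices"][0]["message"]["content"])

    # Anthropic uses the same call shape; only headers and endpoint differ.
    anthropic = AnthropicClient(api_key="sk-ant-...")
    print("anthropic reachable:", await anthropic.health_check())

asyncio.run(main())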
@@ -0,0 +1,97 @@
+
+ import aiohttp
+ import json
+ from typing import Dict, Any, List, Optional, AsyncGenerator
+ from .base_backend_client import BaseBackendClient
+
+
+ class TritonBackendClient(BaseBackendClient):
+     """Pure connection client for Triton Inference Server"""
+
+     def __init__(self, url: str = "localhost:8000", protocol: str = "http"):
+         self.base_url = f"http://{url}" if not url.startswith("http") else url
+         self.protocol = protocol
+         self._session = None
+
+     async def _get_session(self):
+         """Get or create HTTP session"""
+         if self._session is None:
+             self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
+         return self._session
+
+     async def post(self, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
+         """Make POST request to Triton server"""
+         session = await self._get_session()
+         async with session.post(f"{self.base_url}{endpoint}", json=payload) as response:
+             response.raise_for_status()
+             return await response.json()
+
+     async def get(self, endpoint: str) -> Dict[str, Any]:
+         """Make GET request to Triton server"""
+         session = await self._get_session()
+         async with session.get(f"{self.base_url}{endpoint}") as response:
+             response.raise_for_status()
+             return await response.json()
+
+     async def model_ready(self, model_name: str) -> bool:
+         """Check if model is ready"""
+         try:
+             await self.get(f"/v2/models/{model_name}/ready")
+             return True
+         except Exception:
+             return False
+
+     async def model_metadata(self, model_name: str) -> Dict[str, Any]:
+         """Get model metadata"""
+         return await self.get(f"/v2/models/{model_name}")
+
+     async def server_ready(self) -> bool:
+         """Check if server is ready"""
+         try:
+             await self.get("/v2/health/ready")
+             return True
+         except Exception:
+             return False
+
+     async def health_check(self) -> bool:
+         """Check server health"""
+         return await self.server_ready()
+
+     async def close(self):
+         """Close the HTTP session"""
+         if self._session:
+             await self._session.close()
+             self._session = None
+
+
+ # Keep old class name for backward compatibility
+ class TritonClient(TritonBackendClient):
+     """Backward compatibility alias"""
+
+     def __init__(self, backend_connector_config: Optional[Dict] = None, config: Optional[Dict] = None):
+         if backend_connector_config:
+             url = backend_connector_config.get("url", "localhost:8000")
+         else:
+             url = "localhost:8000"
+         super().__init__(url)
+
+     async def infer(self,
+                     model_runtime_config: Dict,
+                     unified_request_payload: Dict,
+                     task_type: str,
+                     request_id: str) -> Dict:
+         """Legacy method for backward compatibility"""
+         # This is a placeholder for the old interface
+         # New code should use the direct HTTP methods
+         raise NotImplementedError("Use direct HTTP methods instead")
+
+     async def stream(self,
+                      model_runtime_config: Dict,
+                      unified_request_payload: Dict,
+                      task_type: str,
+                      request_id: str) -> AsyncGenerator[Dict, None]:
+         """Legacy method for backward compatibility"""
+         # This is a placeholder for the old interface
+         # New code should use the direct HTTP methods
+         raise NotImplementedError("Use direct HTTP methods instead")
+         yield  # unreachable, but makes this function an async generator
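A short sketch of driving the connection client directly — an assumption-laden example, not package code. It presumes a Triton server on localhost:8000 and a deployed model named "llama", matching the model_repository entries in the file list:

import asyncio
from isa_model.inference.backends.triton_client import TritonBackendClient

async def main() -> None:
    client = TritonBackendClient(url="localhost:8000")
    try:
        # /v2/health/ready returns 200 once the server accepts requests.
        if await client.server_ready() and await client.model_ready("llama"):
            meta = await client.model_metadata("llama")
            # Triton's metadata JSON describes the model's tensor signature.
            print(meta.get("inputs"), meta.get("outputs"))
    finally:
        # The client caches a single aiohttp session; close it explicitly.
        await client.close()

asyncio.run(main())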
@@ -0,0 +1,46 @@
+ """
+ Base definitions for the Inference layer.
+ """
+
+ from enum import Enum, auto
+ from typing import Dict, List, Optional, Any, Union, TypeVar, Generic
+
+ T = TypeVar('T')
+
+
+ class ModelType(str, Enum):
+     """Types of AI models supported by the framework."""
+     LLM = "llm"
+     EMBEDDING = "embedding"
+     VISION = "vision"
+     AUDIO = "audio"
+     OCR = "ocr"
+     TTS = "tts"
+     RERANK = "rerank"
+     MULTIMODAL = "multimodal"
+
+
+ class Capability(str, Enum):
+     """Capabilities supported by models."""
+     CHAT = "chat"
+     COMPLETION = "completion"
+     EMBEDDING = "embedding"
+     IMAGE_GENERATION = "image_generation"
+     IMAGE_CLASSIFICATION = "image_classification"
+     OBJECT_DETECTION = "object_detection"
+     SPEECH_TO_TEXT = "speech_to_text"
+     TEXT_TO_SPEECH = "text_to_speech"
+     OCR = "ocr"
+     RERANKING = "reranking"
+     MULTIMODAL_UNDERSTANDING = "multimodal_understanding"
+
+
+ class RoutingStrategy(str, Enum):
+     """Routing strategies for distributing requests among model replicas."""
+     ROUND_ROBIN = "round_robin"
+     WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
+     LEAST_CONNECTIONS = "least_connections"
+     RESPONSE_TIME = "response_time"
+     RANDOM = "random"
+     CONSISTENT_HASH = "consistent_hash"
+     DYNAMIC_LOAD_BALANCING = "dynamic_load_balancing"
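Because all three enums subclass str, members compare equal to their raw string values and round-trip cleanly through JSON configs. A small illustration (a sketch, not part of the package):

from isa_model.inference.base import ModelType, RoutingStrategy

# str-Enum members compare equal to their plain string values.
assert ModelType.LLM == "llm"
# Lookup by value returns the canonical member.
assert ModelType("embedding") is ModelType.EMBEDDING

# e.g. parsing a routing strategy out of a config file:
strategy = RoutingStrategy("round_robin")
print(strategy.name)  # ROUND_ROBIN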
File without changes
@@ -0,0 +1,134 @@
+ # Universal Inference Client
+ """
+ Aims to give developers a high-level, unified, and easy-to-use Python client
+ library for interacting with the Universal Inference Platform (i.e., the whole
+ system being built here). The client encapsulates all the complexity of
+ communicating with the platform's backend orchestrator_adapter service,
+ letting users invoke AI models of every kind (language, vision, speech, etc.)
+ through task-oriented methods, without caring which inference engine
+ (PyTorch, vLLM, Triton, Ollama) hosts a given model, or whether it is
+ deployed locally or served through an external API.
+ """
+
+
+ import httpx
+ from typing import Dict, List, Union, Optional, AsyncGenerator
+ from .client_sdk_schema import *
+
+ class UniversalInferenceClient:
+     def __init__(self):
+         self.adapter_url = "http://adapter.isa_model.com/api/v1"
+         self.adapter_key = "isa_model_adapter"
+         self.client = httpx.AsyncClient(
+             base_url=self.adapter_url,
+             headers={"Authorization": f"Bearer {self.adapter_key}"}
+         )
+
+     async def _make_request(self,
+                             method: str,
+                             url: str,
+                             params: Optional[Dict] = None,
+                             data: Optional[Dict] = None,
+                             headers: Optional[Dict] = None,
+                             **kwargs) -> Dict:
+         """
+         Make a request to the adapter service and return the decoded JSON body.
+         """
+         headers = headers or {}
+         headers["Authorization"] = f"Bearer {self.adapter_key}"
+
+         # Reuse the shared AsyncClient directly; entering `async with self.client`
+         # here would close it after the first request.
+         response = await self.client.request(
+             method,
+             url,
+             params=params,
+             json=data,
+             headers=headers,
+             **kwargs
+         )
+         response.raise_for_status()
+         return response.json()
+
+     async def invoke(self,
+                      model_id: str,
+                      raw_task_payload: Dict,
+                      stream: bool = False,
+                      **kwargs) -> Union[Dict, AsyncGenerator[Dict, None]]:
+         pass
+
+     async def chat(self,
+                    model_id: str,
+                    messages: List[Dict[str, str]],
+                    stream: bool = False,
+                    temperature: float = 0.7,
+                    max_tokens: int = 1000) -> Union[UnifiedChatResponse, AsyncGenerator[UnifiedChatResponse, None]]:
+         pass
+
+     async def generate_text(self,
+                             model_id: str,
+                             prompt: str,
+                             stream: bool = False,
+                             temperature: float = 0.7,
+                             max_tokens: int = 1000) -> Union[UnifiedTextResponse, AsyncGenerator[UnifiedTextChunk, None]]:
+         pass
+
+     async def embed(self,
+                     model_id: str,
+                     inputs: Union[str, List[str]],
+                     input_type: str = "document",
+                     **kwargs) -> UnifiedEmbeddingResponse:
+         pass
+
+     async def rerank(self,
+                      model_id: str,
+                      query: str,
+                      documents: List[Union[str, Dict]],
+                      top_k: Optional[int] = None,
+                      **kwargs) -> UnifiedRerankResponse:
+         pass
+
+     async def transcribe_audio(self,
+                                model_id: str,
+                                audio_data: bytes,
+                                language: Optional[str] = None,
+                                **kwargs) -> UnifiedAudioTranscriptionResponse:
+         pass
+
+     async def generate_speech(self,
+                               model_id: str,
+                               text: str,
+                               voice_id: Optional[str] = None,
+                               **kwargs) -> UnifiedSpeechGenerationResponse:
+         pass
+
+     async def analyze_image(self,
+                             model_id: str,
+                             image_data: bytes,
+                             query: str,
+                             **kwargs) -> UnifiedImageAnalysisResponse:
+         pass
+
+     async def generate_image(self,
+                              model_id: str,
+                              prompt: str,
+                              **kwargs) -> UnifiedImageGenerationResponse:
+         pass
+
+     async def generate_video(self,
+                              model_id: str,
+                              prompt: str,
+                              **kwargs) -> UnifiedVideoGenerationResponse:
+         pass
+
+     async def generate_audio(self,
+                              model_id: str,
+                              text: str,
+                              voice_id: Optional[str] = None,
+                              **kwargs) -> UnifiedAudioGenerationResponse:
+         pass
+
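Every task method above is still a stub in this release. A sketch of how chat might be wired through _make_request — the /chat/{model_id} route and payload shape are assumptions, not part of the published code:

# Hypothetical implementation sketch; route and payload are assumed.
async def chat(self, model_id, messages, stream=False,
               temperature=0.7, max_tokens=1000):
    payload = {
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": stream,
    }
    # _make_request attaches the bearer token and raises on HTTP errors.
    return await self._make_request("POST", f"/chat/{model_id}", data=payload)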
@@ -0,0 +1,34 @@
+ # Inside UniversalInferenceClient.chat method
+ processed_messages = []
+ for user_msg in messages:  # messages is what the end-user provided to the SDK
+     if isinstance(user_msg, dict) and "role" in user_msg and "content" in user_msg:
+         processed_messages.append(user_msg)
+     elif isinstance(user_msg, BaseMessage):  # If the user passes LangChain messages directly
+         # Serialize BaseMessage to our standard dict format
+         msg_dict = {"role": "", "content": user_msg.content}
+         if isinstance(user_msg, HumanMessage):
+             msg_dict["role"] = "user"
+         elif isinstance(user_msg, AIMessage):
+             msg_dict["role"] = "assistant"
+             if hasattr(user_msg, "tool_calls") and user_msg.tool_calls:
+                 # Serialize tool_calls to a list of dicts
+                 msg_dict["tool_calls"] = [
+                     {"id": tc.get("id"), "type": "function", "function": {"name": tc.get("name"), "arguments": json.dumps(tc.get("args", {}))}}
+                     for tc in user_msg.tool_calls  # Assuming user_msg.tool_calls are already dicts or serializable
+                 ]
+         elif isinstance(user_msg, SystemMessage):
+             msg_dict["role"] = "system"
+         elif isinstance(user_msg, ToolMessage):
+             msg_dict["role"] = "tool"
+             msg_dict["tool_call_id"] = user_msg.tool_call_id
+         else:
+             # Handle other BaseMessage types or raise an error
+             pass
+         processed_messages.append(msg_dict)
+     # ... (add more flexible input handling if needed, e.g., a list of tuples)
+     else:
+         raise ValueError("Unsupported message format in chat method input.")
+
+ # task_specific_payload for the orchestrator would be:
+ # task_payload = {"messages": processed_messages}
+ # Then call self._invoke_orchestrator(model_id, task_payload, ...)
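The snippet assumes json and the LangChain message classes are already in scope; in current LangChain they would come from langchain_core.messages, though this file never imports them. A hypothetical round trip showing the normalization the loop performs:

# Assumed imports (not present in this release's file):
import json
from langchain_core.messages import (
    BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage,
)

# Mixed input the snippet is meant to normalize:
messages = [
    {"role": "system", "content": "You are terse."},  # passed through as-is
    HumanMessage(content="Hi!"),                      # serialized to a dict
]
# After the loop, processed_messages would be:
# [{"role": "system", "content": "You are terse."},
#  {"role": "user", "content": "Hi!"}]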
@@ -0,0 +1,16 @@
+
+ """
+ Define here (with Pydantic or Protobuf) a standard ChatMessageSchema containing
+ fields such as role: str, content: Optional[str],
+ tool_calls: Optional[List[ToolCallSchema]], and tool_call_id: Optional[str].
+ When a chat task is handled, the task_specific_payload field of
+ UnifiedRestInvokeRequest (or the gRPC UnifiedRequest) carries a "messages" key
+ whose value is a List[ChatMessageSchema].
+ """
+
+ from pydantic import BaseModel
+
+ class UnifiedChatResponse(BaseModel):
+     pass
+
+ class UnifiedTextResponse(BaseModel):
+     pass
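The docstring spells out the fields of ChatMessageSchema, so a Pydantic sketch follows directly from it; the defaults and the ToolCallSchema layout are assumptions, not published code:

# Sketch of the schemas the docstring describes; field names follow the
# docstring, everything else is assumed.
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field

class ToolCallSchema(BaseModel):
    id: Optional[str] = None
    type: str = "function"
    # e.g. {"name": ..., "arguments": <JSON string>}, mirroring the
    # serialization in client_data_std.py
    function: Dict[str, Any] = Field(default_factory=dict)

class ChatMessageSchema(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCallSchema]] = None
    tool_call_id: Optional[str] = None

# In a chat task, task_specific_payload would then carry:
# {"messages": [ChatMessageSchema(...), ...]}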
File without changes