isa-model 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.2.0.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/top_level.txt +0 -0
isa_model/inference/ai_factory.py

@@ -5,11 +5,6 @@ from isa_model.inference.services.base_service import BaseService
 from isa_model.inference.base import ModelType
 import os
 
-from isa_model.inference.services.llm.llama_service import LlamaService
-from isa_model.inference.services.llm.gemma_service import GemmaService
-from isa_model.inference.services.audio.whisper_service import WhisperService
-from isa_model.inference.services.embedding.bge_service import BgeEmbeddingService
-
 # Set up basic logging configuration
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
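With these imports gone, ai_factory.py no longer references the deleted Triton-era service classes (compare the removed files in the listing above); callers go through the factory's generic accessors instead. A minimal sketch, assuming the `AIFactory.get_instance()` singleton accessor that this release's docstrings show:

```python
from isa_model.inference.ai_factory import AIFactory

# Defaults match get_llm()'s signature in this release; no direct service
# imports (LlamaService, WhisperService, ...) are needed any more.
factory = AIFactory.get_instance()
llm = factory.get_llm(model_name="llama3.1", provider="ollama")
```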
@@ -29,7 +24,7 @@ class AIFactory:
 
     def __init__(self):
         """Initialize the AI Factory."""
-        self.triton_url = os.environ.get("TRITON_URL", "localhost:8000")
+        self.triton_url = os.environ.get("TRITON_URL", "http://localhost:8000")
 
         # Cache for services (singleton pattern)
         self._llm_services = {}
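The Triton URL default now includes the scheme, so the configured value is a complete URL. A small illustration of why that matters (the `/v2/health/ready` path is Triton's standard HTTP readiness endpoint, not something shown in this diff):

```python
import os
import urllib.request

# Same lookup the factory performs; the default is now a well-formed URL.
triton_url = os.environ.get("TRITON_URL", "http://localhost:8000")

# With the old bare "localhost:8000" default, urlopen would reject the string
# ("unknown url type"); with an explicit scheme it parses cleanly.
with urllib.request.urlopen(f"{triton_url}/v2/health/ready") as resp:
    print(resp.status)
```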
@@ -70,58 +65,25 @@ class AIFactory:
             # Register Replicate provider and services
             try:
                 from isa_model.inference.providers.replicate_provider import ReplicateProvider
-                from isa_model.inference.services.llm.replicate_llm_service import ReplicateLLMService
-                from isa_model.inference.services.vision.replicate_vision_service import ReplicateVisionService
+                from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService
 
                 self.register_provider('replicate', ReplicateProvider)
-                self.register_service('replicate', ModelType.LLM, ReplicateLLMService)
                 self.register_service('replicate', ModelType.VISION, ReplicateVisionService)
-                logger.info("Replicate services registered successfully")
+                logger.info("Replicate provider and vision service registered successfully")
             except ImportError as e:
                 logger.warning(f"Replicate services not available: {e}")
+            except Exception as e:
+                logger.warning(f"Error registering Replicate services: {e}")
 
             # Try to register Triton services
             try:
                 from isa_model.inference.providers.triton_provider import TritonProvider
-                from isa_model.inference.services.llm.triton_llm_service import TritonLLMService
-                from isa_model.inference.services.vision.triton_vision_service import TritonVisionService
-                from isa_model.inference.services.audio.triton_speech_service import TritonSpeechService
 
                 self.register_provider('triton', TritonProvider)
-
-                self.register_service('triton', ModelType.VISION, TritonVisionService)
-                self.register_service('triton', ModelType.AUDIO, TritonSpeechService)
-                logger.info("Triton services registered successfully")
-
-                # Register HuggingFace-based direct LLM service for Llama3-8B
-                try:
-                    from isa_model.inference.llm.llama3_service import Llama3Service
-                    # Register as a standalone service for direct access
-                    self._cached_services["llama3"] = Llama3Service()
-                    logger.info("Llama3-8B service registered successfully")
-                except ImportError as e:
-                    logger.warning(f"Llama3-8B service not available: {e}")
-
-                # Register HuggingFace-based direct Vision service for Gemma3-4B
-                try:
-                    from isa_model.inference.vision.gemma3_service import Gemma3VisionService
-                    # Register as a standalone service for direct access
-                    self._cached_services["gemma3"] = Gemma3VisionService()
-                    logger.info("Gemma3-4B Vision service registered successfully")
-                except ImportError as e:
-                    logger.warning(f"Gemma3-4B Vision service not available: {e}")
-
-                # Register HuggingFace-based direct Speech service for Whisper Tiny
-                try:
-                    from isa_model.inference.speech.whisper_service import WhisperService
-                    # Register as a standalone service for direct access
-                    self._cached_services["whisper"] = WhisperService()
-                    logger.info("Whisper Tiny Speech service registered successfully")
-                except ImportError as e:
-                    logger.warning(f"Whisper Tiny Speech service not available: {e}")
+                logger.info("Triton provider registered successfully")
 
             except ImportError as e:
-                logger.warning(f"Triton services not available: {e}")
+                logger.warning(f"Triton provider not available: {e}")
 
             logger.info("Default AI providers and services initialized with backend architecture")
         except Exception as e:
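The reworked block keeps each optional provider in its own try/except and adds a broad `except Exception` so a registration-time failure degrades to a warning instead of aborting factory construction. A condensed sketch of the same idiom, using only names that appear in the hunk above:

```python
import logging

from isa_model.inference.base import ModelType

logger = logging.getLogger(__name__)

def register_optional_providers(factory) -> None:
    """Register providers whose third-party dependencies may be absent."""
    try:
        from isa_model.inference.providers.replicate_provider import ReplicateProvider
        from isa_model.inference.services.vision.replicate_image_gen_service import (
            ReplicateVisionService,
        )
        factory.register_provider('replicate', ReplicateProvider)
        factory.register_service('replicate', ModelType.VISION, ReplicateVisionService)
    except ImportError as e:
        # Missing optional dependency: degrade to a warning.
        logger.warning(f"Replicate services not available: {e}")
    except Exception as e:
        # New in 0.2.0: registration-time failures no longer propagate.
        logger.warning(f"Error registering Replicate services: {e}")
```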
@@ -176,24 +138,90 @@ class AIFactory:
 
     # Convenient methods for common services
     def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
-                config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """
+                config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
+        """
+        Get a LLM service instance
+
+        Args:
+            model_name: Name of the model to use
+            provider: Provider name ('ollama', 'openai', 'replicate', etc.)
+            config: Optional configuration dictionary
+            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
+
+        Returns:
+            LLM service instance
+
+        Example:
+            # Using with API key directly
+            llm = AIFactory.get_instance().get_llm(
+                model_name="gpt-4o-mini",
+                provider="openai",
+                api_key="your-api-key-here"
+            )
+
+            # Using without API key (will use environment variable)
+            llm = AIFactory.get_instance().get_llm(
+                model_name="gpt-4o-mini",
+                provider="openai"
+            )
+        """
+
+        # Special case for DeepSeek service
+        if model_name.lower() in ["deepseek", "deepseek-r1", "qwen3-8b"]:
+            if "deepseek" in self._cached_services:
+                return self._cached_services["deepseek"]
 
         # Special case for Llama3-8B direct service
         if model_name.lower() in ["llama3", "llama3-8b", "meta-llama-3"]:
             if "llama3" in self._cached_services:
                 return self._cached_services["llama3"]
 
-        basic_config = {
+        basic_config: Dict[str, Any] = {
             "temperature": 0
         }
+
+        # Add API key to config if provided
+        if api_key:
+            if provider == "openai":
+                basic_config["api_key"] = api_key
+            elif provider == "replicate":
+                basic_config["api_token"] = api_key
+            else:
+                logger.warning(f"API key provided but provider '{provider}' may not support it")
+
         if config:
             basic_config.update(config)
         return self.create_service(provider, ModelType.LLM, model_name, basic_config)
 
     def get_vision_model(self, model_name: str = "gemma3-4b", provider: str = "triton",
-                         config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """
+                         config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
+        """
+        Get a vision model service instance
+
+        Args:
+            model_name: Name of the model to use
+            provider: Provider name ('openai', 'replicate', 'triton', etc.)
+            config: Optional configuration dictionary
+            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
+
+        Returns:
+            Vision service instance
+
+        Example:
+            # Using with API key directly
+            vision = AIFactory.get_instance().get_vision_model(
+                model_name="gpt-4o",
+                provider="openai",
+                api_key="your-api-key-here"
+            )
+
+            # Using Replicate for image generation
+            image_gen = AIFactory.get_instance().get_vision_model(
+                model_name="stability-ai/sdxl",
+                provider="replicate",
+                api_key="your-replicate-token"
+            )
+        """
 
         # Special case for Gemma3-4B direct service
         if model_name.lower() in ["gemma3", "gemma3-4b", "gemma3-vision"]:
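`get_llm` now takes `api_key` directly and, per the branches above, folds it into the service config under the provider-specific key before `create_service` runs (`get_vision_model` mirrors this in the next hunk). The two calls in this sketch are therefore equivalent; the key value is a placeholder:

```python
from isa_model.inference.ai_factory import AIFactory

factory = AIFactory.get_instance()

# Passing the key explicitly...
llm = factory.get_llm(model_name="gpt-4o-mini", provider="openai",
                      api_key="sk-placeholder")

# ...is equivalent to threading it through config yourself, since the
# api_key branch writes basic_config["api_key"] for the OpenAI provider.
llm = factory.get_llm(model_name="gpt-4o-mini", provider="openai",
                      config={"api_key": "sk-placeholder"})
```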
@@ -202,19 +230,33 @@ class AIFactory:
 
         # Special case for Replicate's image generation models
         if provider == "replicate" and "/" in model_name:
-            replicate_config = {
-                "api_token": os.environ.get("REPLICATE_API_TOKEN", ""),
+            replicate_config: Dict[str, Any] = {
                 "guidance_scale": 7.5,
                 "num_inference_steps": 30
             }
+
+            # Add API key if provided
+            if api_key:
+                replicate_config["api_token"] = api_key
+
             if config:
-                replicate_config.update(config)
-            return self.create_service(provider, ModelType.VISION, model_name,
+                replicate_config.update(config)
+            return self.create_service(provider, ModelType.VISION, model_name, replicate_config)
 
-        basic_config = {
+        basic_config: Dict[str, Any] = {
             "temperature": 0.7,
             "max_new_tokens": 512
         }
+
+        # Add API key to config if provided
+        if api_key:
+            if provider == "openai":
+                basic_config["api_key"] = api_key
+            elif provider == "replicate":
+                basic_config["api_token"] = api_key
+            else:
+                logger.warning(f"API key provided but provider '{provider}' may not support it")
+
         if config:
             basic_config.update(config)
         return self.create_service(provider, ModelType.VISION, model_name, basic_config)
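For Replicate image-generation models (any `model_name` containing a slash), the factory seeds diffusion defaults and then applies the caller's config over them via `replicate_config.update(config)`. A hedged sketch:

```python
from isa_model.inference.ai_factory import AIFactory

factory = AIFactory.get_instance()

image_gen = factory.get_vision_model(
    model_name="stability-ai/sdxl",       # the "/" routes into the Replicate branch
    provider="replicate",
    api_key="r8-placeholder",             # stored as replicate_config["api_token"]
    config={"num_inference_steps": 50},   # overrides the seeded default of 30
)
```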
@@ -251,32 +293,6 @@ class AIFactory:
             basic_config.update(config)
         return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)
 
-    async def get_llm_service(self, model_name: str) -> Any:
-        """
-        Get an LLM service for the specified model.
-
-        Args:
-            model_name: Name of the model
-
-        Returns:
-            LLM service instance
-        """
-        if model_name in self._llm_services:
-            return self._llm_services[model_name]
-
-        if model_name == "llama":
-            service = LlamaService(triton_url=self.triton_url, model_name="llama")
-            await service.load()
-            self._llm_services[model_name] = service
-            return service
-        elif model_name == "gemma":
-            service = GemmaService(triton_url=self.triton_url, model_name="gemma")
-            await service.load()
-            self._llm_services[model_name] = service
-            return service
-        else:
-            raise ValueError(f"Unsupported LLM model: {model_name}")
-
     async def get_embedding_service(self, model_name: str) -> Any:
         """
         Get an embedding service for the specified model.
@@ -290,11 +306,6 @@ class AIFactory:
         if model_name in self._embedding_services:
             return self._embedding_services[model_name]
 
-        if model_name == "bge_embed":
-            service = BgeEmbeddingService(triton_url=self.triton_url, model_name="bge_embed")
-            await service.load()
-            self._embedding_services[model_name] = service
-            return service
         else:
             raise ValueError(f"Unsupported embedding model: {model_name}")
 
@@ -311,13 +322,6 @@ class AIFactory:
         if model_name in self._speech_services:
             return self._speech_services[model_name]
 
-        if model_name == "whisper":
-            service = WhisperService(triton_url=self.triton_url, model_name="whisper")
-            await service.load()
-            self._speech_services[model_name] = service
-            return service
-        else:
-            raise ValueError(f"Unsupported speech model: {model_name}")
 
     def get_model_info(self, model_type: Optional[str] = None) -> Dict[str, Any]:
         """
@@ -331,6 +335,7 @@ class AIFactory:
         """
         models = {
             "llm": [
+                {"name": "deepseek", "description": "DeepSeek-R1-0528-Qwen3-8B language model"},
                 {"name": "llama", "description": "Llama3-8B language model"},
                 {"name": "gemma", "description": "Gemma3-4B language model"}
             ],
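`get_model_info` now lists the DeepSeek entry ahead of Llama3 and Gemma3. A usage sketch; only the `llm` bucket is visible in this hunk, so the overall return shape (a dict keyed by model type) is an assumption:

```python
from isa_model.inference.ai_factory import AIFactory

factory = AIFactory.get_instance()
info = factory.get_model_info(model_type="llm")
for model in info["llm"]:
    print(f"{model['name']}: {model['description']}")
# deepseek: DeepSeek-R1-0528-Qwen3-8B language model
# llama: Llama3-8B language model
# gemma: Gemma3-4B language model
```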
isa_model/inference/providers/openai_provider.py

@@ -15,13 +15,13 @@ class OpenAIProvider(BaseProvider):
 
         Args:
             config (dict, optional): Configuration for the provider
-                - api_key: OpenAI API key (
+                - api_key: OpenAI API key (can be passed here or via environment variable)
                 - api_base: Base URL for OpenAI API (default: https://api.openai.com/v1)
                 - timeout: Timeout for API calls in seconds
         """
         default_config = {
-            "api_key": os.environ.get("OPENAI_API_KEY", ""),
-            "api_base": "https://api.openai.com/v1",
+            "api_key": "",  # Will be set from config or environment
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
             "timeout": 60,
             "stream": True,
             "temperature": 0.7,
@@ -32,14 +32,28 @@ class OpenAIProvider(BaseProvider):
 
         # Merge default config with provided config
         merged_config = {**default_config, **(config or {})}
 
+        # Set API key from config first, then fallback to environment variable
+        if not merged_config["api_key"]:
+            merged_config["api_key"] = os.environ.get("OPENAI_API_KEY", "")
+
         super().__init__(config=merged_config)
         self.name = "openai"
 
         logger.info(f"Initialized OpenAIProvider with URL: {self.config['api_base']}")
 
-        #
+        # Only warn if no API key is provided at all
         if not self.config["api_key"]:
-            logger.
+            logger.info("OpenAI API key not provided. You can set it via OPENAI_API_KEY environment variable or pass it in the config when creating services.")
+
+    def set_api_key(self, api_key: str):
+        """
+        Set the API key after initialization
+
+        Args:
+            api_key: OpenAI API key
+        """
+        self.config["api_key"] = api_key
+        logger.info("OpenAI API key updated")
 
     def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
         """Get provider capabilities by model type"""
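The new `set_api_key` makes credentials late-bindable: a provider built without a key now logs at info level instead of warning, and the key can be injected before the first call. A sketch, assuming `OpenAIProvider` can be constructed standalone:

```python
from isa_model.inference.providers.openai_provider import OpenAIProvider

# No key in the config and (assumed here) none in the environment:
# construction succeeds and only an info-level message is logged.
provider = OpenAIProvider()
provider.set_api_key("sk-placeholder")
assert provider.config["api_key"] == "sk-placeholder"
```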
@@ -52,7 +66,7 @@ class OpenAIProvider(BaseProvider):
                 Capability.EMBEDDING
             ],
             ModelType.VISION: [
-                Capability.IMAGE_UNDERSTANDING,
+                Capability.IMAGE_GENERATION,
                 Capability.MULTIMODAL_UNDERSTANDING
             ],
             ModelType.AUDIO: [
@@ -63,7 +77,7 @@ class OpenAIProvider(BaseProvider):
     def get_models(self, model_type: ModelType) -> List[str]:
         """Get available models for given type"""
         if model_type == ModelType.LLM:
-            return ["gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
+            return ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
         elif model_type == ModelType.EMBEDDING:
             return ["text-embedding-3-large", "text-embedding-3-small", "text-embedding-ada-002"]
         elif model_type == ModelType.VISION:
isa_model/inference/providers/replicate_provider.py

@@ -15,11 +15,11 @@ class ReplicateProvider(BaseProvider):
 
         Args:
             config (dict, optional): Configuration for the provider
-                - api_token: Replicate API token (
+                - api_token: Replicate API token (can be passed here or via environment variable)
                 - timeout: Timeout for API calls in seconds
         """
         default_config = {
-            "api_token": os.environ.get("REPLICATE_API_TOKEN", ""),
+            "api_token": "",  # Will be set from config or environment
             "timeout": 60,
             "stream": True,
             "max_tokens": 1024
@@ -28,14 +28,28 @@ class ReplicateProvider(BaseProvider):
 
         # Merge default config with provided config
         merged_config = {**default_config, **(config or {})}
 
+        # Set API token from config first, then fallback to environment variable
+        if not merged_config["api_token"]:
+            merged_config["api_token"] = os.environ.get("REPLICATE_API_TOKEN", "")
+
         super().__init__(config=merged_config)
         self.name = "replicate"
 
         logger.info(f"Initialized ReplicateProvider")
 
-        #
+        # Only warn if no API token is provided at all
         if not self.config["api_token"]:
-            logger.
+            logger.info("Replicate API token not provided. You can set it via REPLICATE_API_TOKEN environment variable or pass it in the config when creating services.")
+
+    def set_api_token(self, api_token: str):
+        """
+        Set the API token after initialization
+
+        Args:
+            api_token: Replicate API token
+        """
+        self.config["api_token"] = api_token
+        logger.info("Replicate API token updated")
 
     def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
         """Get provider capabilities by model type"""
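Both providers now apply the same precedence: an explicit config value wins, and the environment variable is consulted only when the merged config value is empty. A sketch of the resulting behavior (token values are placeholders):

```python
import os

from isa_model.inference.providers.replicate_provider import ReplicateProvider

os.environ["REPLICATE_API_TOKEN"] = "r8-from-env"

explicit = ReplicateProvider(config={"api_token": "r8-explicit"})
assert explicit.config["api_token"] == "r8-explicit"   # config beats the environment

fallback = ReplicateProvider()
assert fallback.config["api_token"] == "r8-from-env"   # env fills the empty default
```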
@@ -45,7 +59,6 @@ class ReplicateProvider(BaseProvider):
                 Capability.COMPLETION
             ],
             ModelType.VISION: [
-                Capability.IMAGE_UNDERSTANDING,
                 Capability.IMAGE_GENERATION,
                 Capability.MULTIMODAL_UNDERSTANDING
             ],
isa_model/inference/providers/triton_provider.py

@@ -29,7 +29,7 @@ class TritonProvider(BaseProvider):
 
         # Default configuration
         self.default_config = {
-            "server_url": os.environ.get("TRITON_SERVER_URL", "localhost:8000"),
+            "server_url": os.environ.get("TRITON_SERVER_URL", "http://localhost:8000"),
             "model_repository": os.environ.get(
                 "MODEL_REPOSITORY",
                 os.path.join(os.getcwd(), "models/triton/model_repository")
isa_model/inference/services/audio/base_stt_service.py (new file)

@@ -0,0 +1,91 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Union, Optional, BinaryIO
+from isa_model.inference.services.base_service import BaseService
+
+class BaseSTTService(BaseService):
+    """Base class for Speech-to-Text services"""
+
+    @abstractmethod
+    async def transcribe_audio(
+        self,
+        audio_file: Union[str, BinaryIO],
+        language: Optional[str] = None,
+        prompt: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """
+        Transcribe audio file to text
+
+        Args:
+            audio_file: Path to audio file or file-like object
+            language: Language code (e.g., 'en', 'es', 'fr')
+            prompt: Optional prompt to guide transcription
+
+        Returns:
+            Dict containing transcription results with keys:
+            - text: The transcribed text
+            - language: Detected/specified language
+            - confidence: Confidence score (if available)
+            - segments: Time-segmented transcription (if available)
+        """
+        pass
+
+    @abstractmethod
+    async def transcribe_audio_batch(
+        self,
+        audio_files: List[Union[str, BinaryIO]],
+        language: Optional[str] = None,
+        prompt: Optional[str] = None
+    ) -> List[Dict[str, Any]]:
+        """
+        Transcribe multiple audio files
+
+        Args:
+            audio_files: List of audio file paths or file-like objects
+            language: Language code (e.g., 'en', 'es', 'fr')
+            prompt: Optional prompt to guide transcription
+
+        Returns:
+            List of transcription results
+        """
+        pass
+
+    @abstractmethod
+    async def detect_language(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
+        """
+        Detect language of audio file
+
+        Args:
+            audio_file: Path to audio file or file-like object
+
+        Returns:
+            Dict containing language detection results with keys:
+            - language: Detected language code
+            - confidence: Confidence score
+            - alternatives: List of alternative languages with scores
+        """
+        pass
+
+    @abstractmethod
+    def get_supported_formats(self) -> List[str]:
+        """
+        Get list of supported audio formats
+
+        Returns:
+            List of supported file extensions (e.g., ['mp3', 'wav', 'flac'])
+        """
+        pass
+
+    @abstractmethod
+    def get_supported_languages(self) -> List[str]:
+        """
+        Get list of supported language codes
+
+        Returns:
+            List of supported language codes (e.g., ['en', 'es', 'fr'])
+        """
+        pass
+
+    @abstractmethod
+    async def close(self):
+        """Cleanup resources"""
+        pass
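To make the contract concrete, here is a hypothetical minimal subclass; `EchoSTTService` and its canned return values are illustrative only, and any constructor arguments required by `BaseService` are not shown in this diff:

```python
from typing import Any, BinaryIO, Dict, List, Optional, Union

class EchoSTTService(BaseSTTService):
    """Toy implementation that satisfies the abstract STT contract."""

    async def transcribe_audio(self, audio_file: Union[str, BinaryIO],
                               language: Optional[str] = None,
                               prompt: Optional[str] = None) -> Dict[str, Any]:
        # A real backend would decode the audio; this stub returns the
        # documented result shape with empty content.
        return {"text": "", "language": language or "en",
                "confidence": 1.0, "segments": []}

    async def transcribe_audio_batch(self, audio_files: List[Union[str, BinaryIO]],
                                     language: Optional[str] = None,
                                     prompt: Optional[str] = None) -> List[Dict[str, Any]]:
        return [await self.transcribe_audio(f, language, prompt) for f in audio_files]

    async def detect_language(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
        return {"language": "en", "confidence": 1.0, "alternatives": []}

    def get_supported_formats(self) -> List[str]:
        return ["mp3", "wav", "flac"]

    def get_supported_languages(self) -> List[str]:
        return ["en"]

    async def close(self) -> None:
        pass
```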
isa_model/inference/services/audio/base_tts_service.py (new file)

@@ -0,0 +1,136 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Union, Optional, BinaryIO
+from isa_model.inference.services.base_service import BaseService
+
+class BaseTTSService(BaseService):
+    """Base class for Text-to-Speech services"""
+
+    @abstractmethod
+    async def synthesize_speech(
+        self,
+        text: str,
+        voice: Optional[str] = None,
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        format: str = "mp3"
+    ) -> Dict[str, Any]:
+        """
+        Synthesize speech from text
+
+        Args:
+            text: Input text to convert to speech
+            voice: Voice ID or name to use
+            speed: Speech speed multiplier (0.5-2.0)
+            pitch: Pitch adjustment (-1.0 to 1.0)
+            format: Audio format ('mp3', 'wav', 'ogg')
+
+        Returns:
+            Dict containing synthesis results with keys:
+            - audio_data: Binary audio data
+            - format: Audio format
+            - duration: Audio duration in seconds
+            - sample_rate: Audio sample rate
+        """
+        pass
+
+    @abstractmethod
+    async def synthesize_speech_to_file(
+        self,
+        text: str,
+        output_path: str,
+        voice: Optional[str] = None,
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        format: str = "mp3"
+    ) -> Dict[str, Any]:
+        """
+        Synthesize speech and save directly to file
+
+        Args:
+            text: Input text to convert to speech
+            output_path: Path to save the audio file
+            voice: Voice ID or name to use
+            speed: Speech speed multiplier (0.5-2.0)
+            pitch: Pitch adjustment (-1.0 to 1.0)
+            format: Audio format ('mp3', 'wav', 'ogg')
+
+        Returns:
+            Dict containing synthesis results with keys:
+            - file_path: Path to saved audio file
+            - duration: Audio duration in seconds
+            - sample_rate: Audio sample rate
+        """
+        pass
+
+    @abstractmethod
+    async def synthesize_speech_batch(
+        self,
+        texts: List[str],
+        voice: Optional[str] = None,
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        format: str = "mp3"
+    ) -> List[Dict[str, Any]]:
+        """
+        Synthesize speech for multiple texts
+
+        Args:
+            texts: List of input texts to convert to speech
+            voice: Voice ID or name to use
+            speed: Speech speed multiplier (0.5-2.0)
+            pitch: Pitch adjustment (-1.0 to 1.0)
+            format: Audio format ('mp3', 'wav', 'ogg')
+
+        Returns:
+            List of synthesis result dictionaries
+        """
+        pass
+
+    @abstractmethod
+    def get_available_voices(self) -> List[Dict[str, Any]]:
+        """
+        Get list of available voices
+
+        Returns:
+            List of voice information dictionaries with keys:
+            - id: Voice identifier
+            - name: Human-readable voice name
+            - language: Language code (e.g., 'en-US', 'es-ES')
+            - gender: Voice gender ('male', 'female', 'neutral')
+            - age: Voice age category ('adult', 'child', 'elderly')
+        """
+        pass
+
+    @abstractmethod
+    def get_supported_formats(self) -> List[str]:
+        """
+        Get list of supported audio formats
+
+        Returns:
+            List of supported file extensions (e.g., ['mp3', 'wav', 'ogg'])
+        """
+        pass
+
+    @abstractmethod
+    def get_voice_info(self, voice_id: str) -> Dict[str, Any]:
+        """
+        Get detailed information about a specific voice
+
+        Args:
+            voice_id: Voice identifier
+
+        Returns:
+            Dict containing voice information:
+            - id: Voice identifier
+            - name: Human-readable voice name
+            - language: Language code
+            - gender: Voice gender
+            - description: Voice description
+            - sample_rate: Default sample rate
+        """
+        pass
+
+    @abstractmethod
+    async def close(self):
+        """Cleanup resources"""
+        pass
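And a usage sketch against the TTS contract; any concrete `BaseTTSService` (for example the OpenAI-backed service renamed below) would slot in, with constructor details outside the scope of this diff:

```python
import asyncio

async def speak(tts) -> None:
    """tts: any BaseTTSService implementation."""
    # Pick the first advertised voice; the result keys follow the
    # synthesize_speech_to_file docstring above.
    voice = tts.get_available_voices()[0]["id"]
    result = await tts.synthesize_speech_to_file(
        text="Hello from isa-model 0.2.0",
        output_path="hello.mp3",
        voice=voice,
        speed=1.0,
        format="mp3",
    )
    print(result["file_path"], result["duration"])
    await tts.close()
```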
isa_model/inference/services/audio/yyds_audio_service.py → openai_tts_service.py

@@ -3,11 +3,11 @@ import tempfile
 import os
 from openai import AsyncOpenAI
 from tenacity import retry, stop_after_attempt, wait_exponential
-from
-from
-
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.providers.base_provider import BaseProvider
+import logging
 
-logger =
+logger = logging.getLogger(__name__)
 
 class YYDSAudioService(BaseService):
     """Audio model service wrapper for YYDS"""
|