isa-model 0.1.1__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +142 -240
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/services/audio/openai_tts_service.py +104 -3
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +111 -1
- isa_model/inference/services/llm/ollama_llm_service.py +234 -26
- isa_model/inference/services/llm/openai_llm_service.py +225 -28
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/ollama_vision_service.py +143 -17
- isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.2.8.dist-info/METADATA +465 -0
- isa_model-0.2.8.dist-info/RECORD +86 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.1.1.dist-info/METADATA +0 -327
- isa_model-0.1.1.dist-info/RECORD +0 -92
- isa_model-0.1.1.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/WHEEL +0 -0
- {isa_model-0.1.1.dist-info → isa_model-0.2.8.dist-info}/top_level.txt +0 -0
isa_model/inference/ai_factory.py:

```diff
@@ -1,17 +1,23 @@
-
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Simplified AI Factory for creating inference services
+Uses the new service architecture with proper base classes
+"""
+
+from typing import Dict, Type, Any, Optional, Tuple, List
 import logging
+import os
 from isa_model.inference.providers.base_provider import BaseProvider
 from isa_model.inference.services.base_service import BaseService
 from isa_model.inference.base import ModelType
-import os
 
-# Set up basic logging configuration
-logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class AIFactory:
     """
-    Factory for creating AI services
+    Simplified Factory for creating AI services with proper inheritance hierarchy
     """
 
     _instance = None
@@ -24,72 +30,78 @@ class AIFactory:
 
     def __init__(self):
         """Initialize the AI Factory."""
-        self.triton_url = os.environ.get("TRITON_URL", "http://localhost:8000")
-
-        # Cache for services (singleton pattern)
-        self._llm_services = {}
-        self._embedding_services = {}
-        self._speech_services = {}
-
         if not self._is_initialized:
             self._providers: Dict[str, Type[BaseProvider]] = {}
             self._services: Dict[Tuple[str, ModelType], Type[BaseService]] = {}
             self._cached_services: Dict[str, BaseService] = {}
-            self.
+            self._initialize_services()
             AIFactory._is_initialized = True
 
-    def
-        """Initialize
+    def _initialize_services(self):
+        """Initialize available providers and services"""
+        try:
+            # Register Ollama services
+            self._register_ollama_services()
+
+            # Register OpenAI services
+            self._register_openai_services()
+
+            # Register Replicate services
+            self._register_replicate_services()
+
+            logger.info("AI Factory initialized with simplified service architecture")
+
+        except Exception as e:
+            logger.error(f"Error initializing services: {e}")
+            logger.warning("Some services may not be available")
+
+    def _register_ollama_services(self):
+        """Register Ollama provider and services"""
         try:
-            # Import providers and services
             from isa_model.inference.providers.ollama_provider import OllamaProvider
-            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
             from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService
+            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
+            from isa_model.inference.services.vision.ollama_vision_service import OllamaVisionService
 
-            # Register Ollama provider and services
             self.register_provider('ollama', OllamaProvider)
-            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
             self.register_service('ollama', ModelType.LLM, OllamaLLMService)
+            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
+            self.register_service('ollama', ModelType.VISION, OllamaVisionService)
 
-
-            try:
-                from isa_model.inference.providers.openai_provider import OpenAIProvider
-                from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
-
-                self.register_provider('openai', OpenAIProvider)
-                self.register_service('openai', ModelType.LLM, OpenAILLMService)
-                logger.info("OpenAI services registered successfully")
-            except ImportError as e:
-                logger.warning(f"OpenAI services not available: {e}")
+            logger.info("Ollama services registered successfully")
 
-
-
-
-
-
-
-
-
-
-            logger.warning(f"Replicate services not available: {e}")
-        except Exception as e:
-            logger.warning(f"Error registering Replicate services: {e}")
+        except ImportError as e:
+            logger.warning(f"Ollama services not available: {e}")
+
+    def _register_openai_services(self):
+        """Register OpenAI provider and services"""
+        try:
+            from isa_model.inference.providers.openai_provider import OpenAIProvider
+            from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
+            from isa_model.inference.services.audio.openai_tts_service import OpenAITTSService
 
-
-
-
-
-            self.register_provider('triton', TritonProvider)
-            logger.info("Triton provider registered successfully")
-
-        except ImportError as e:
-            logger.warning(f"Triton provider not available: {e}")
+            self.register_provider('openai', OpenAIProvider)
+            self.register_service('openai', ModelType.LLM, OpenAILLMService)
+            self.register_service('openai', ModelType.AUDIO, OpenAITTSService)
 
-            logger.info("
-
-
-
-
+            logger.info("OpenAI services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"OpenAI services not available: {e}")
+
+    def _register_replicate_services(self):
+        """Register Replicate provider and services"""
+        try:
+            from isa_model.inference.providers.replicate_provider import ReplicateProvider
+            from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateImageGenService
+
+            self.register_provider('replicate', ReplicateProvider)
+            self.register_service('replicate', ModelType.VISION, ReplicateImageGenService)
+
+            logger.info("Replicate services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"Replicate services not available: {e}")
 
     def register_provider(self, name: str, provider_class: Type[BaseProvider]) -> None:
         """Register an AI provider"""
@@ -109,24 +121,20 @@ class AIFactory:
             if cache_key in self._cached_services:
                 return self._cached_services[cache_key]
 
-            #
-
-                "log_level": "INFO"
-            }
-
-            # Merge the configurations
-            service_config = {**base_config, **(config or {})}
-
-            # Create the provider and service
-            provider_class = self._providers[provider_name]
+            # Get provider and service classes
+            provider_class = self._providers.get(provider_name)
             service_class = self._services.get((provider_name, model_type))
 
+            if not provider_class:
+                raise ValueError(f"No provider registered for '{provider_name}'")
+
             if not service_class:
                 raise ValueError(
-                    f"No service registered for provider {provider_name} and model type {model_type}"
+                    f"No service registered for provider '{provider_name}' and model type '{model_type}'"
                 )
 
-            provider
+            # Create provider and service
+            provider = provider_class(config=config or {})
             service = service_class(provider=provider, model_name=model_name)
 
             self._cached_services[cache_key] = service
@@ -137,223 +145,117 @@ class AIFactory:
             raise
 
     # Convenient methods for common services
-    def
-
+    def get_llm_service(self, model_name: str = "llama3.1", provider: str = "ollama",
+                        config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
         Get a LLM service instance
 
         Args:
             model_name: Name of the model to use
-            provider: Provider name ('ollama', 'openai'
+            provider: Provider name ('ollama', 'openai')
             config: Optional configuration dictionary
-            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
 
         Returns:
             LLM service instance
-
-        Example:
-            # Using with API key directly
-            llm = AIFactory.get_instance().get_llm(
-                model_name="gpt-4o-mini",
-                provider="openai",
-                api_key="your-api-key-here"
-            )
-
-            # Using without API key (will use environment variable)
-            llm = AIFactory.get_instance().get_llm(
-                model_name="gpt-4o-mini",
-                provider="openai"
-            )
         """
-
-        # Special case for DeepSeek service
-        if model_name.lower() in ["deepseek", "deepseek-r1", "qwen3-8b"]:
-            if "deepseek" in self._cached_services:
-                return self._cached_services["deepseek"]
-
-        # Special case for Llama3-8B direct service
-        if model_name.lower() in ["llama3", "llama3-8b", "meta-llama-3"]:
-            if "llama3" in self._cached_services:
-                return self._cached_services["llama3"]
-
-        basic_config: Dict[str, Any] = {
-            "temperature": 0
-        }
-
-        # Add API key to config if provided
-        if api_key:
-            if provider == "openai":
-                basic_config["api_key"] = api_key
-            elif provider == "replicate":
-                basic_config["api_token"] = api_key
-            else:
-                logger.warning(f"API key provided but provider '{provider}' may not support it")
-
-        if config:
-            basic_config.update(config)
-        return self.create_service(provider, ModelType.LLM, model_name, basic_config)
+        return self.create_service(provider, ModelType.LLM, model_name, config)
 
-    def
-
+    def get_embedding_service(self, model_name: str = "bge-m3", provider: str = "ollama",
+                              config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get
+        Get an embedding service instance
 
         Args:
             model_name: Name of the model to use
-            provider: Provider name ('
+            provider: Provider name ('ollama')
            config: Optional configuration dictionary
-            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
 
         Returns:
-
-
-        Example:
-            # Using with API key directly
-            vision = AIFactory.get_instance().get_vision_model(
-                model_name="gpt-4o",
-                provider="openai",
-                api_key="your-api-key-here"
-            )
-
-            # Using Replicate for image generation
-            image_gen = AIFactory.get_instance().get_vision_model(
-                model_name="stability-ai/sdxl",
-                provider="replicate",
-                api_key="your-replicate-token"
-            )
+            Embedding service instance
         """
-
-        # Special case for Gemma3-4B direct service
-        if model_name.lower() in ["gemma3", "gemma3-4b", "gemma3-vision"]:
-            if "gemma3" in self._cached_services:
-                return self._cached_services["gemma3"]
-
-        # Special case for Replicate's image generation models
-        if provider == "replicate" and "/" in model_name:
-            replicate_config: Dict[str, Any] = {
-                "guidance_scale": 7.5,
-                "num_inference_steps": 30
-            }
-
-            # Add API key if provided
-            if api_key:
-                replicate_config["api_token"] = api_key
-
-            if config:
-                replicate_config.update(config)
-            return self.create_service(provider, ModelType.VISION, model_name, replicate_config)
-
-        basic_config: Dict[str, Any] = {
-            "temperature": 0.7,
-            "max_new_tokens": 512
-        }
-
-        # Add API key to config if provided
-        if api_key:
-            if provider == "openai":
-                basic_config["api_key"] = api_key
-            elif provider == "replicate":
-                basic_config["api_token"] = api_key
-            else:
-                logger.warning(f"API key provided but provider '{provider}' may not support it")
-
-        if config:
-            basic_config.update(config)
-        return self.create_service(provider, ModelType.VISION, model_name, basic_config)
-
-    def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
-                      config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Get an embedding service instance"""
         return self.create_service(provider, ModelType.EMBEDDING, model_name, config)
 
-    def
-
-        """Get a rerank service instance"""
-        return self.create_service(provider, ModelType.RERANK, model_name, config)
-
-    def get_embed_service(self, model_name: str = "bge-m3", provider: str = "ollama",
-                          config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Get an embedding service instance"""
-        return self.get_embedding(model_name, provider, config)
-
-    def get_speech_model(self, model_name: str = "whisper_tiny", provider: str = "triton",
-                         config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Get a speech-to-text model service instance"""
-
-        # Special case for Whisper Tiny direct service
-        if model_name.lower() in ["whisper", "whisper_tiny", "whisper-tiny"]:
-            if "whisper" in self._cached_services:
-                return self._cached_services["whisper"]
-
-        basic_config = {
-            "language": "en",
-            "task": "transcribe"
-        }
-        if config:
-            basic_config.update(config)
-        return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)
-
-    async def get_embedding_service(self, model_name: str) -> Any:
+    def get_vision_service(self, model_name: str, provider: str,
+                           config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get
+        Get a vision service instance
 
         Args:
-            model_name: Name of the model
+            model_name: Name of the model to use
+            provider: Provider name ('ollama', 'replicate')
+            config: Optional configuration dictionary
 
         Returns:
-
+            Vision service instance
         """
-
-            return self._embedding_services[model_name]
-
-        else:
-            raise ValueError(f"Unsupported embedding model: {model_name}")
+        return self.create_service(provider, ModelType.VISION, model_name, config)
 
-
+    def get_image_generation_service(self, model_name: str, provider: str = "replicate",
+                                     config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get
+        Get an image generation service instance
 
         Args:
-            model_name: Name of the model
+            model_name: Name of the model to use (e.g., "stability-ai/sdxl")
+            provider: Provider name ('replicate')
+            config: Optional configuration dictionary
 
         Returns:
-
+            Image generation service instance
         """
-
-            return self._speech_services[model_name]
-
+        return self.create_service(provider, ModelType.VISION, model_name, config)
 
-    def
+    def get_audio_service(self, model_name: str = "tts-1", provider: str = "openai",
+                          config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get
+        Get an audio service instance
 
         Args:
-
+            model_name: Name of the model to use
+            provider: Provider name ('openai')
+            config: Optional configuration dictionary
 
         Returns:
-
+            Audio service instance
         """
-
-
-
-
-
-
-
-
-        ]
-
-
-
-
-
-
-            return {model_type: models.get(model_type, [])}
-        return models
+        return self.create_service(provider, ModelType.AUDIO, model_name, config)
+
+    def get_available_services(self) -> Dict[str, List[str]]:
+        """Get information about available services"""
+        services = {}
+        for (provider, model_type), service_class in self._services.items():
+            if provider not in services:
+                services[provider] = []
+            services[provider].append(f"{model_type.value}: {service_class.__name__}")
+        return services
+
+    def clear_cache(self):
+        """Clear the service cache"""
+        self._cached_services.clear()
+        logger.info("Service cache cleared")
 
     @classmethod
     def get_instance(cls) -> 'AIFactory':
         """Get the singleton instance"""
        if cls._instance is None:
             cls._instance = cls()
-        return cls._instance
+        return cls._instance
+
+    # Alias methods for backward compatibility with tests
+    def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
+                config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Alias for get_llm_service"""
+        return self.get_llm_service(model_name, provider, config)
+
+    def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
+                      config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Alias for get_embedding_service"""
+        return self.get_embedding_service(model_name, provider, config)
+
+    def get_vision_model(self, model_name: str, provider: str,
+                         config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Alias for get_vision_service and get_image_generation_service"""
+        if provider == "replicate":
+            return self.get_image_generation_service(model_name, provider, config)
+        else:
+            return self.get_vision_service(model_name, provider, config)
```
isa_model/inference/providers/ml_provider.py (new file):

```diff
@@ -0,0 +1,50 @@
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.base import ModelType, Capability
+from typing import Dict, List, Any
+import logging
+
+logger = logging.getLogger(__name__)
+
+class MLProvider(BaseProvider):
+    """Provider for traditional ML models"""
+
+    def __init__(self, config=None):
+        default_config = {
+            "model_directory": "./models/ml",
+            "cache_models": True,
+            "max_cache_size": 5
+        }
+
+        merged_config = {**default_config, **(config or {})}
+        super().__init__(config=merged_config)
+        self.name = "ml"
+
+        logger.info(f"Initialized MLProvider with model directory: {self.config['model_directory']}")
+
+    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+        """Get provider capabilities"""
+        return {
+            ModelType.LLM: [],  # ML models are not LLMs
+            ModelType.EMBEDDING: [],
+            ModelType.VISION: [],
+            "ML": [  # Custom model type for traditional ML
+                "CLASSIFICATION",
+                "REGRESSION",
+                "CLUSTERING",
+                "FEATURE_EXTRACTION"
+            ]
+        }
+
+    def get_models(self, model_type: str = "ML") -> List[str]:
+        """Get available ML models"""
+        # In practice, this would scan the model directory
+        return [
+            "fraud_detection_rf",
+            "customer_churn_xgb",
+            "price_prediction_lr",
+            "recommendation_kmeans"
+        ]
+
+    def get_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        return self.config
```
isa_model/inference/services/audio/openai_tts_service.py:

```diff
@@ -1,15 +1,15 @@
-from typing import Dict, Any
+from typing import Dict, Any, List, Optional
 import tempfile
 import os
 from openai import AsyncOpenAI
 from tenacity import retry, stop_after_attempt, wait_exponential
-from isa_model.inference.services.
+from isa_model.inference.services.audio.base_tts_service import BaseTTSService
 from isa_model.inference.providers.base_provider import BaseProvider
 import logging
 
 logger = logging.getLogger(__name__)
 
-class
+class OpenAITTSService(BaseTTSService):
     """Audio model service wrapper for YYDS"""
 
     def __init__(self, provider: 'BaseProvider', model_name: str):
@@ -69,3 +69,104 @@ class YYDSAudioService(BaseService):
         except Exception as e:
             logger.error(f"Error in audio transcription: {e}")
             raise
+
+    # Implement the abstract methods of BaseTTSService
+    async def synthesize_speech(
+        self,
+        text: str,
+        voice: Optional[str] = None,
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        format: str = "mp3"
+    ) -> Dict[str, Any]:
+        """Synthesize speech from text using OpenAI TTS"""
+        try:
+            response = await self._client.audio.speech.create(
+                model="tts-1",
+                voice=voice or "alloy",
+                input=text,
+                response_format=format,
+                speed=speed
+            )
+
+            audio_data = response.content
+
+            return {
+                "audio_data": audio_data,
+                "format": format,
+                "duration": 0.0,  # OpenAI doesn't provide duration
+                "sample_rate": 24000  # Default for OpenAI TTS
+            }
+
+        except Exception as e:
+            logger.error(f"Error in speech synthesis: {e}")
+            raise
+
+    async def synthesize_speech_to_file(
+        self,
+        text: str,
+        output_path: str,
+        voice: Optional[str] = None,
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        format: str = "mp3"
+    ) -> Dict[str, Any]:
+        """Synthesize speech and save to file"""
+        result = await self.synthesize_speech(text, voice, speed, pitch, format)
+
+        with open(output_path, 'wb') as f:
+            f.write(result["audio_data"])
+
+        return {
+            "file_path": output_path,
+            "duration": result["duration"],
+            "sample_rate": result["sample_rate"]
+        }
+
+    async def synthesize_speech_batch(
+        self,
+        texts: List[str],
+        voice: Optional[str] = None,
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        format: str = "mp3"
+    ) -> List[Dict[str, Any]]:
+        """Synthesize speech for multiple texts"""
+        results = []
+        for text in texts:
+            result = await self.synthesize_speech(text, voice, speed, pitch, format)
+            results.append(result)
+        return results
+
+    def get_available_voices(self) -> List[Dict[str, Any]]:
+        """Get list of available OpenAI voices"""
+        return [
+            {"id": "alloy", "name": "Alloy", "language": "en-US", "gender": "neutral", "age": "adult"},
+            {"id": "echo", "name": "Echo", "language": "en-US", "gender": "male", "age": "adult"},
+            {"id": "fable", "name": "Fable", "language": "en-US", "gender": "neutral", "age": "adult"},
+            {"id": "onyx", "name": "Onyx", "language": "en-US", "gender": "male", "age": "adult"},
+            {"id": "nova", "name": "Nova", "language": "en-US", "gender": "female", "age": "adult"},
+            {"id": "shimmer", "name": "Shimmer", "language": "en-US", "gender": "female", "age": "adult"}
+        ]
+
+    def get_supported_formats(self) -> List[str]:
+        """Get list of supported audio formats"""
+        return ["mp3", "opus", "aac", "flac"]
+
+    def get_voice_info(self, voice_id: str) -> Dict[str, Any]:
+        """Get detailed information about a specific voice"""
+        voices = {voice["id"]: voice for voice in self.get_available_voices()}
+        voice_info = voices.get(voice_id, {})
+
+        if voice_info:
+            voice_info.update({
+                "description": f"OpenAI {voice_info['name']} voice",
+                "sample_rate": 24000
+            })
+
+        return voice_info
+
+    async def close(self):
+        """Cleanup resources"""
+        if hasattr(self._client, 'close'):
+            await self._client.close()
```