isa-model 0.0.4__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/storage/hf_storage.py +419 -0
  3. isa_model/deployment/__init__.py +52 -0
  4. isa_model/deployment/core/__init__.py +34 -0
  5. isa_model/deployment/core/deployment_config.py +356 -0
  6. isa_model/deployment/core/deployment_manager.py +549 -0
  7. isa_model/deployment/core/isa_deployment_service.py +401 -0
  8. isa_model/eval/factory.py +381 -140
  9. isa_model/inference/ai_factory.py +142 -240
  10. isa_model/inference/providers/ml_provider.py +50 -0
  11. isa_model/inference/services/audio/openai_tts_service.py +104 -3
  12. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  13. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  14. isa_model/inference/services/llm/__init__.py +2 -0
  15. isa_model/inference/services/llm/base_llm_service.py +111 -1
  16. isa_model/inference/services/llm/ollama_llm_service.py +234 -26
  17. isa_model/inference/services/llm/openai_llm_service.py +180 -26
  18. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  19. isa_model/inference/services/ml/base_ml_service.py +78 -0
  20. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  21. isa_model/inference/services/vision/__init__.py +3 -3
  22. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  23. isa_model/inference/services/vision/base_vision_service.py +177 -0
  24. isa_model/inference/services/vision/ollama_vision_service.py +143 -17
  25. isa_model/inference/services/vision/replicate_image_gen_service.py +139 -7
  26. isa_model/training/__init__.py +62 -32
  27. isa_model/training/cloud/__init__.py +22 -0
  28. isa_model/training/cloud/job_orchestrator.py +402 -0
  29. isa_model/training/cloud/runpod_trainer.py +454 -0
  30. isa_model/training/cloud/storage_manager.py +482 -0
  31. isa_model/training/core/__init__.py +23 -0
  32. isa_model/training/core/config.py +181 -0
  33. isa_model/training/core/dataset.py +222 -0
  34. isa_model/training/core/trainer.py +720 -0
  35. isa_model/training/core/utils.py +213 -0
  36. isa_model/training/factory.py +229 -198
  37. isa_model-0.0.8.dist-info/METADATA +465 -0
  38. isa_model-0.0.8.dist-info/RECORD +86 -0
  39. isa_model/core/model_router.py +0 -226
  40. isa_model/core/model_version.py +0 -0
  41. isa_model/core/resource_manager.py +0 -202
  42. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  43. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  44. isa_model/training/engine/llama_factory/__init__.py +0 -39
  45. isa_model/training/engine/llama_factory/config.py +0 -115
  46. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  47. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  48. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  49. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  50. isa_model/training/engine/llama_factory/factory.py +0 -331
  51. isa_model/training/engine/llama_factory/rl.py +0 -254
  52. isa_model/training/engine/llama_factory/trainer.py +0 -171
  53. isa_model/training/image_model/configs/create_config.py +0 -37
  54. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  55. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  56. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  57. isa_model/training/image_model/prepare_upload.py +0 -17
  58. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  59. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  60. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  61. isa_model/training/image_model/train/train.py +0 -42
  62. isa_model/training/image_model/train/train_flux.py +0 -41
  63. isa_model/training/image_model/train/train_lora.py +0 -57
  64. isa_model/training/image_model/train_main.py +0 -25
  65. isa_model-0.0.4.dist-info/METADATA +0 -327
  66. isa_model-0.0.4.dist-info/RECORD +0 -92
  67. isa_model-0.0.4.dist-info/licenses/LICENSE +0 -21
  68. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  69. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  70. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  71. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  72. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  73. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  74. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  75. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  76. {isa_model-0.0.4.dist-info → isa_model-0.0.8.dist-info}/WHEEL +0 -0
  77. {isa_model-0.0.4.dist-info → isa_model-0.0.8.dist-info}/top_level.txt +0 -0
isa_model/inference/ai_factory.py
@@ -1,17 +1,23 @@
- from typing import Dict, Type, Any, Optional, Tuple
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ """
+ Simplified AI Factory for creating inference services
+ Uses the new service architecture with proper base classes
+ """
+
+ from typing import Dict, Type, Any, Optional, Tuple, List
  import logging
+ import os
  from isa_model.inference.providers.base_provider import BaseProvider
  from isa_model.inference.services.base_service import BaseService
  from isa_model.inference.base import ModelType
- import os

- # Set up basic logging configuration
- logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

  class AIFactory:
      """
-     Factory for creating AI services based on the Single Model pattern.
+     Simplified Factory for creating AI services with proper inheritance hierarchy
      """

      _instance = None
@@ -24,72 +30,78 @@ class AIFactory:
      def __init__(self):
          """Initialize the AI Factory."""
-         self.triton_url = os.environ.get("TRITON_URL", "http://localhost:8000")
-
-         # Cache for services (singleton pattern)
-         self._llm_services = {}
-         self._embedding_services = {}
-         self._speech_services = {}
-
          if not self._is_initialized:
              self._providers: Dict[str, Type[BaseProvider]] = {}
              self._services: Dict[Tuple[str, ModelType], Type[BaseService]] = {}
              self._cached_services: Dict[str, BaseService] = {}
-             self._initialize_defaults()
+             self._initialize_services()
              AIFactory._is_initialized = True

-     def _initialize_defaults(self):
-         """Initialize default providers and services"""
+     def _initialize_services(self):
+         """Initialize available providers and services"""
+         try:
+             # Register Ollama services
+             self._register_ollama_services()
+
+             # Register OpenAI services
+             self._register_openai_services()
+
+             # Register Replicate services
+             self._register_replicate_services()
+
+             logger.info("AI Factory initialized with simplified service architecture")
+
+         except Exception as e:
+             logger.error(f"Error initializing services: {e}")
+             logger.warning("Some services may not be available")
+
+     def _register_ollama_services(self):
+         """Register Ollama provider and services"""
          try:
-             # Import providers and services
              from isa_model.inference.providers.ollama_provider import OllamaProvider
-             from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
             from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService
+             from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
+             from isa_model.inference.services.vision.ollama_vision_service import OllamaVisionService

-             # Register Ollama provider and services
              self.register_provider('ollama', OllamaProvider)
-             self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
              self.register_service('ollama', ModelType.LLM, OllamaLLMService)
+             self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
+             self.register_service('ollama', ModelType.VISION, OllamaVisionService)

-             # Register OpenAI provider and services
-             try:
-                 from isa_model.inference.providers.openai_provider import OpenAIProvider
-                 from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
-
-                 self.register_provider('openai', OpenAIProvider)
-                 self.register_service('openai', ModelType.LLM, OpenAILLMService)
-                 logger.info("OpenAI services registered successfully")
-             except ImportError as e:
-                 logger.warning(f"OpenAI services not available: {e}")
+             logger.info("Ollama services registered successfully")

-             # Register Replicate provider and services
-             try:
-                 from isa_model.inference.providers.replicate_provider import ReplicateProvider
-                 from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService
-
-                 self.register_provider('replicate', ReplicateProvider)
-                 self.register_service('replicate', ModelType.VISION, ReplicateVisionService)
-                 logger.info("Replicate provider and vision service registered successfully")
-             except ImportError as e:
-                 logger.warning(f"Replicate services not available: {e}")
-             except Exception as e:
-                 logger.warning(f"Error registering Replicate services: {e}")
+         except ImportError as e:
+             logger.warning(f"Ollama services not available: {e}")
+
+     def _register_openai_services(self):
+         """Register OpenAI provider and services"""
+         try:
+             from isa_model.inference.providers.openai_provider import OpenAIProvider
+             from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
+             from isa_model.inference.services.audio.openai_tts_service import OpenAITTSService

-             # Try to register Triton services
-             try:
-                 from isa_model.inference.providers.triton_provider import TritonProvider
-
-                 self.register_provider('triton', TritonProvider)
-                 logger.info("Triton provider registered successfully")
-
-             except ImportError as e:
-                 logger.warning(f"Triton provider not available: {e}")
+             self.register_provider('openai', OpenAIProvider)
+             self.register_service('openai', ModelType.LLM, OpenAILLMService)
+             self.register_service('openai', ModelType.AUDIO, OpenAITTSService)

-             logger.info("Default AI providers and services initialized with backend architecture")
-         except Exception as e:
-             logger.error(f"Error initializing default providers and services: {e}")
-             # Don't raise - allow factory to work even if some services fail to load
-             logger.warning("Some services may not be available due to import errors")
+             logger.info("OpenAI services registered successfully")
+
+         except ImportError as e:
+             logger.warning(f"OpenAI services not available: {e}")
+
+     def _register_replicate_services(self):
+         """Register Replicate provider and services"""
+         try:
+             from isa_model.inference.providers.replicate_provider import ReplicateProvider
+             from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateImageGenService
+
+             self.register_provider('replicate', ReplicateProvider)
+             self.register_service('replicate', ModelType.VISION, ReplicateImageGenService)
+
+             logger.info("Replicate services registered successfully")
+
+         except ImportError as e:
+             logger.warning(f"Replicate services not available: {e}")

      def register_provider(self, name: str, provider_class: Type[BaseProvider]) -> None:
          """Register an AI provider"""
@@ -109,24 +121,20 @@ class AIFactory:
              if cache_key in self._cached_services:
                  return self._cached_services[cache_key]

-             # Base configuration
-             base_config = {
-                 "log_level": "INFO"
-             }
-
-             # Merge configurations
-             service_config = {**base_config, **(config or {})}
-
-             # Create provider and service
-             provider_class = self._providers[provider_name]
+             # Get provider and service classes
+             provider_class = self._providers.get(provider_name)
              service_class = self._services.get((provider_name, model_type))

+             if not provider_class:
+                 raise ValueError(f"No provider registered for '{provider_name}'")
+
              if not service_class:
                  raise ValueError(
-                     f"No service registered for provider {provider_name} and model type {model_type}"
+                     f"No service registered for provider '{provider_name}' and model type '{model_type}'"
                  )

-             provider = provider_class(config=service_config)
+             # Create provider and service
+             provider = provider_class(config=config or {})
              service = service_class(provider=provider, model_name=model_name)

              self._cached_services[cache_key] = service
@@ -137,223 +145,117 @@ class AIFactory:
              raise

      # Convenient methods for common services
-     def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
-                 config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
+     def get_llm_service(self, model_name: str = "llama3.1", provider: str = "ollama",
+                         config: Optional[Dict[str, Any]] = None) -> BaseService:
          """
          Get a LLM service instance

          Args:
              model_name: Name of the model to use
-             provider: Provider name ('ollama', 'openai', 'replicate', etc.)
+             provider: Provider name ('ollama', 'openai')
              config: Optional configuration dictionary
-             api_key: Optional API key for the provider (OpenAI, Replicate, etc.)

          Returns:
              LLM service instance
-
-         Example:
-             # Using with API key directly
-             llm = AIFactory.get_instance().get_llm(
-                 model_name="gpt-4o-mini",
-                 provider="openai",
-                 api_key="your-api-key-here"
-             )
-
-             # Using without API key (will use environment variable)
-             llm = AIFactory.get_instance().get_llm(
-                 model_name="gpt-4o-mini",
-                 provider="openai"
-             )
          """
-
-         # Special case for DeepSeek service
-         if model_name.lower() in ["deepseek", "deepseek-r1", "qwen3-8b"]:
-             if "deepseek" in self._cached_services:
-                 return self._cached_services["deepseek"]
-
-         # Special case for Llama3-8B direct service
-         if model_name.lower() in ["llama3", "llama3-8b", "meta-llama-3"]:
-             if "llama3" in self._cached_services:
-                 return self._cached_services["llama3"]
-
-         basic_config: Dict[str, Any] = {
-             "temperature": 0
-         }
-
-         # Add API key to config if provided
-         if api_key:
-             if provider == "openai":
-                 basic_config["api_key"] = api_key
-             elif provider == "replicate":
-                 basic_config["api_token"] = api_key
-             else:
-                 logger.warning(f"API key provided but provider '{provider}' may not support it")
-
-         if config:
-             basic_config.update(config)
-         return self.create_service(provider, ModelType.LLM, model_name, basic_config)
+         return self.create_service(provider, ModelType.LLM, model_name, config)

-     def get_vision_model(self, model_name: str = "gemma3-4b", provider: str = "triton",
-                          config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
+     def get_embedding_service(self, model_name: str = "bge-m3", provider: str = "ollama",
+                               config: Optional[Dict[str, Any]] = None) -> BaseService:
          """
-         Get a vision model service instance
+         Get an embedding service instance

          Args:
              model_name: Name of the model to use
-             provider: Provider name ('openai', 'replicate', 'triton', etc.)
+             provider: Provider name ('ollama')
              config: Optional configuration dictionary
-             api_key: Optional API key for the provider (OpenAI, Replicate, etc.)

          Returns:
-             Vision service instance
-
-         Example:
-             # Using with API key directly
-             vision = AIFactory.get_instance().get_vision_model(
-                 model_name="gpt-4o",
-                 provider="openai",
-                 api_key="your-api-key-here"
-             )
-
-             # Using Replicate for image generation
-             image_gen = AIFactory.get_instance().get_vision_model(
-                 model_name="stability-ai/sdxl",
-                 provider="replicate",
-                 api_key="your-replicate-token"
-             )
+             Embedding service instance
          """
-
-         # Special case for Gemma3-4B direct service
-         if model_name.lower() in ["gemma3", "gemma3-4b", "gemma3-vision"]:
-             if "gemma3" in self._cached_services:
-                 return self._cached_services["gemma3"]
-
-         # Special case for Replicate's image generation models
-         if provider == "replicate" and "/" in model_name:
-             replicate_config: Dict[str, Any] = {
-                 "guidance_scale": 7.5,
-                 "num_inference_steps": 30
-             }
-
-             # Add API key if provided
-             if api_key:
-                 replicate_config["api_token"] = api_key
-
-             if config:
-                 replicate_config.update(config)
-             return self.create_service(provider, ModelType.VISION, model_name, replicate_config)
-
-         basic_config: Dict[str, Any] = {
-             "temperature": 0.7,
-             "max_new_tokens": 512
-         }
-
-         # Add API key to config if provided
-         if api_key:
-             if provider == "openai":
-                 basic_config["api_key"] = api_key
-             elif provider == "replicate":
-                 basic_config["api_token"] = api_key
-             else:
-                 logger.warning(f"API key provided but provider '{provider}' may not support it")
-
-         if config:
-             basic_config.update(config)
-         return self.create_service(provider, ModelType.VISION, model_name, basic_config)
-
-     def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
-                       config: Optional[Dict[str, Any]] = None) -> BaseService:
-         """Get an embedding service instance"""
          return self.create_service(provider, ModelType.EMBEDDING, model_name, config)

-     def get_rerank(self, model_name: str = "bge-m3", provider: str = "ollama",
-                    config: Optional[Dict[str, Any]] = None) -> BaseService:
-         """Get a rerank service instance"""
-         return self.create_service(provider, ModelType.RERANK, model_name, config)
-
-     def get_embed_service(self, model_name: str = "bge-m3", provider: str = "ollama",
-                           config: Optional[Dict[str, Any]] = None) -> BaseService:
-         """Get an embedding service instance"""
-         return self.get_embedding(model_name, provider, config)
-
-     def get_speech_model(self, model_name: str = "whisper_tiny", provider: str = "triton",
-                          config: Optional[Dict[str, Any]] = None) -> BaseService:
-         """Get a speech-to-text model service instance"""
-
-         # Special case for Whisper Tiny direct service
-         if model_name.lower() in ["whisper", "whisper_tiny", "whisper-tiny"]:
-             if "whisper" in self._cached_services:
-                 return self._cached_services["whisper"]
-
-         basic_config = {
-             "language": "en",
-             "task": "transcribe"
-         }
-         if config:
-             basic_config.update(config)
-         return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)
-
-     async def get_embedding_service(self, model_name: str) -> Any:
+     def get_vision_service(self, model_name: str, provider: str,
+                            config: Optional[Dict[str, Any]] = None) -> BaseService:
          """
-         Get an embedding service for the specified model.
+         Get a vision service instance

          Args:
-             model_name: Name of the model
+             model_name: Name of the model to use
+             provider: Provider name ('ollama', 'replicate')
+             config: Optional configuration dictionary

          Returns:
-             Embedding service instance
+             Vision service instance
          """
-         if model_name in self._embedding_services:
-             return self._embedding_services[model_name]
-
-         else:
-             raise ValueError(f"Unsupported embedding model: {model_name}")
+         return self.create_service(provider, ModelType.VISION, model_name, config)

-     async def get_speech_service(self, model_name: str) -> Any:
+     def get_image_generation_service(self, model_name: str, provider: str = "replicate",
+                                      config: Optional[Dict[str, Any]] = None) -> BaseService:
          """
-         Get a speech service for the specified model.
+         Get an image generation service instance

          Args:
-             model_name: Name of the model
+             model_name: Name of the model to use (e.g., "stability-ai/sdxl")
+             provider: Provider name ('replicate')
+             config: Optional configuration dictionary

          Returns:
-             Speech service instance
+             Image generation service instance
          """
-         if model_name in self._speech_services:
-             return self._speech_services[model_name]
-
+         return self.create_service(provider, ModelType.VISION, model_name, config)

-     def get_model_info(self, model_type: Optional[str] = None) -> Dict[str, Any]:
+     def get_audio_service(self, model_name: str = "tts-1", provider: str = "openai",
+                           config: Optional[Dict[str, Any]] = None) -> BaseService:
          """
-         Get information about available models.
+         Get an audio service instance

          Args:
-             model_type: Optional filter for model type
+             model_name: Name of the model to use
+             provider: Provider name ('openai')
+             config: Optional configuration dictionary

          Returns:
-             Dict of model information
+             Audio service instance
          """
-         models = {
-             "llm": [
-                 {"name": "deepseek", "description": "DeepSeek-R1-0528-Qwen3-8B language model"},
-                 {"name": "llama", "description": "Llama3-8B language model"},
-                 {"name": "gemma", "description": "Gemma3-4B language model"}
-             ],
-             "embedding": [
-                 {"name": "bge_embed", "description": "BGE-M3 text embedding model"}
-             ],
-             "speech": [
-                 {"name": "whisper", "description": "Whisper-tiny speech-to-text model"}
-             ]
-         }
-
-         if model_type:
-             return {model_type: models.get(model_type, [])}
-         return models
+         return self.create_service(provider, ModelType.AUDIO, model_name, config)
+
+     def get_available_services(self) -> Dict[str, List[str]]:
+         """Get information about available services"""
+         services = {}
+         for (provider, model_type), service_class in self._services.items():
+             if provider not in services:
+                 services[provider] = []
+             services[provider].append(f"{model_type.value}: {service_class.__name__}")
+         return services
+
+     def clear_cache(self):
+         """Clear the service cache"""
+         self._cached_services.clear()
+         logger.info("Service cache cleared")

      @classmethod
      def get_instance(cls) -> 'AIFactory':
          """Get the singleton instance"""
          if cls._instance is None:
              cls._instance = cls()
-         return cls._instance
+         return cls._instance
+
+     # Alias methods for backward compatibility with tests
+     def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
+                 config: Optional[Dict[str, Any]] = None) -> BaseService:
+         """Alias for get_llm_service"""
+         return self.get_llm_service(model_name, provider, config)
+
+     def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
+                       config: Optional[Dict[str, Any]] = None) -> BaseService:
+         """Alias for get_embedding_service"""
+         return self.get_embedding_service(model_name, provider, config)
+
+     def get_vision_model(self, model_name: str, provider: str,
+                          config: Optional[Dict[str, Any]] = None) -> BaseService:
+         """Alias for get_vision_service and get_image_generation_service"""
+         if provider == "replicate":
+             return self.get_image_generation_service(model_name, provider, config)
+         else:
+             return self.get_vision_service(model_name, provider, config)
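Taken together, the ai_factory.py changes replace the 0.0.4 grab-bag of model-specific helpers (get_llm, get_vision_model, get_speech_model, get_model_info) with uniform get_*_service methods plus thin backward-compatibility aliases. Below is a minimal usage sketch based only on the signatures shown above; the "api_key" config key is an assumption, since the explicit api_key parameter was removed and this diff does not show how credentials reach a provider:

# Usage sketch (not part of the diff): exercising the renamed factory methods.
from isa_model.inference.ai_factory import AIFactory

factory = AIFactory.get_instance()

# Providers whose dependencies failed to import are skipped at registration,
# so inspect what is actually available first.
print(factory.get_available_services())

# Defaults match the new signatures above.
llm = factory.get_llm_service(model_name="llama3.1", provider="ollama")
embedder = factory.get_embedding_service(model_name="bge-m3", provider="ollama")

# Provider configuration now travels through `config`; the "api_key" key is
# an assumed convention, not something this diff defines.
gpt = factory.get_llm_service("gpt-4o-mini", "openai", config={"api_key": "sk-..."})

Since create_service caches instances in _cached_services, repeated calls with the same provider, model type, and model name return the same service object until clear_cache() is called.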
isa_model/inference/providers/ml_provider.py
@@ -0,0 +1,50 @@
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from isa_model.inference.base import ModelType, Capability
+ from typing import Dict, List, Any
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+ class MLProvider(BaseProvider):
+     """Provider for traditional ML models"""
+
+     def __init__(self, config=None):
+         default_config = {
+             "model_directory": "./models/ml",
+             "cache_models": True,
+             "max_cache_size": 5
+         }
+
+         merged_config = {**default_config, **(config or {})}
+         super().__init__(config=merged_config)
+         self.name = "ml"
+
+         logger.info(f"Initialized MLProvider with model directory: {self.config['model_directory']}")
+
+     def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+         """Get provider capabilities"""
+         return {
+             ModelType.LLM: [],  # ML models are not LLMs
+             ModelType.EMBEDDING: [],
+             ModelType.VISION: [],
+             "ML": [  # Custom model type for traditional ML
+                 "CLASSIFICATION",
+                 "REGRESSION",
+                 "CLUSTERING",
+                 "FEATURE_EXTRACTION"
+             ]
+         }
+
+     def get_models(self, model_type: str = "ML") -> List[str]:
+         """Get available ML models"""
+         # In practice, this would scan the model directory
+         return [
+             "fraud_detection_rf",
+             "customer_churn_xgb",
+             "price_prediction_lr",
+             "recommendation_kmeans"
+         ]
+
+     def get_config(self) -> Dict[str, Any]:
+         """Get provider configuration"""
+         return self.config
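The new ml_provider.py is scaffolding rather than a working integration: get_models returns a hardcoded list, and get_capabilities mixes the string key "ML" into an otherwise ModelType-keyed dict. A small sketch of how it would be driven as-is; the overridden model_directory value is an invented example:

# Illustrative only: the constructor and method names come from the diff,
# but "./models/custom" is a made-up path and the returned model ids are
# the hardcoded placeholders shown above.
from isa_model.inference.providers.ml_provider import MLProvider

provider = MLProvider(config={"model_directory": "./models/custom"})
print(provider.get_models())   # ['fraud_detection_rf', 'customer_churn_xgb', ...]
print(provider.get_config())   # defaults merged with the override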
isa_model/inference/services/audio/openai_tts_service.py
@@ -1,15 +1,15 @@
- from typing import Dict, Any
+ from typing import Dict, Any, List, Optional
  import tempfile
  import os
  from openai import AsyncOpenAI
  from tenacity import retry, stop_after_attempt, wait_exponential
- from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.services.audio.base_tts_service import BaseTTSService
  from isa_model.inference.providers.base_provider import BaseProvider
  import logging

  logger = logging.getLogger(__name__)

- class YYDSAudioService(BaseService):
+ class OpenAITTSService(BaseTTSService):
      """Audio model service wrapper for YYDS"""

      def __init__(self, provider: 'BaseProvider', model_name: str):
@@ -69,3 +69,104 @@ class YYDSAudioService(BaseService):
          except Exception as e:
              logger.error(f"Error in audio transcription: {e}")
              raise
+
+     # Implement the abstract methods of BaseTTSService
+     async def synthesize_speech(
+         self,
+         text: str,
+         voice: Optional[str] = None,
+         speed: float = 1.0,
+         pitch: float = 1.0,
+         format: str = "mp3"
+     ) -> Dict[str, Any]:
+         """Synthesize speech from text using OpenAI TTS"""
+         try:
+             response = await self._client.audio.speech.create(
+                 model="tts-1",
+                 voice=voice or "alloy",
+                 input=text,
+                 response_format=format,
+                 speed=speed
+             )
+
+             audio_data = response.content
+
+             return {
+                 "audio_data": audio_data,
+                 "format": format,
+                 "duration": 0.0,  # OpenAI doesn't provide duration
+                 "sample_rate": 24000  # Default for OpenAI TTS
+             }
+
+         except Exception as e:
+             logger.error(f"Error in speech synthesis: {e}")
+             raise
+
+     async def synthesize_speech_to_file(
+         self,
+         text: str,
+         output_path: str,
+         voice: Optional[str] = None,
+         speed: float = 1.0,
+         pitch: float = 1.0,
+         format: str = "mp3"
+     ) -> Dict[str, Any]:
+         """Synthesize speech and save to file"""
+         result = await self.synthesize_speech(text, voice, speed, pitch, format)
+
+         with open(output_path, 'wb') as f:
+             f.write(result["audio_data"])
+
+         return {
+             "file_path": output_path,
+             "duration": result["duration"],
+             "sample_rate": result["sample_rate"]
+         }
+
+     async def synthesize_speech_batch(
+         self,
+         texts: List[str],
+         voice: Optional[str] = None,
+         speed: float = 1.0,
+         pitch: float = 1.0,
+         format: str = "mp3"
+     ) -> List[Dict[str, Any]]:
+         """Synthesize speech for multiple texts"""
+         results = []
+         for text in texts:
+             result = await self.synthesize_speech(text, voice, speed, pitch, format)
+             results.append(result)
+         return results
+
+     def get_available_voices(self) -> List[Dict[str, Any]]:
+         """Get list of available OpenAI voices"""
+         return [
+             {"id": "alloy", "name": "Alloy", "language": "en-US", "gender": "neutral", "age": "adult"},
+             {"id": "echo", "name": "Echo", "language": "en-US", "gender": "male", "age": "adult"},
+             {"id": "fable", "name": "Fable", "language": "en-US", "gender": "neutral", "age": "adult"},
+             {"id": "onyx", "name": "Onyx", "language": "en-US", "gender": "male", "age": "adult"},
+             {"id": "nova", "name": "Nova", "language": "en-US", "gender": "female", "age": "adult"},
+             {"id": "shimmer", "name": "Shimmer", "language": "en-US", "gender": "female", "age": "adult"}
+         ]
+
+     def get_supported_formats(self) -> List[str]:
+         """Get list of supported audio formats"""
+         return ["mp3", "opus", "aac", "flac"]
+
+     def get_voice_info(self, voice_id: str) -> Dict[str, Any]:
+         """Get detailed information about a specific voice"""
+         voices = {voice["id"]: voice for voice in self.get_available_voices()}
+         voice_info = voices.get(voice_id, {})
+
+         if voice_info:
+             voice_info.update({
+                 "description": f"OpenAI {voice_info['name']} voice",
+                 "sample_rate": 24000
+             })
+
+         return voice_info
+
+     async def close(self):
+         """Cleanup resources"""
+         if hasattr(self._client, 'close'):
+             await self._client.close()
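openai_tts_service.py retrofits the old YYDSAudioService into an OpenAITTSService that implements BaseTTSService. Two quirks are visible in the added code: synthesize_speech hardcodes model="tts-1" regardless of the model_name the service was created with, and pitch is accepted but never forwarded to the API. A usage sketch follows, assuming the factory wiring from ai_factory.py above and that a valid OpenAI API key reaches the provider (credential flow is not shown in this diff):

# Sketch only: synthesize a clip to disk through the new TTS surface.
import asyncio
from isa_model.inference.ai_factory import AIFactory

async def main():
    tts = AIFactory.get_instance().get_audio_service(model_name="tts-1", provider="openai")
    result = await tts.synthesize_speech_to_file(
        "Hello from isa_model 0.0.8",   # text
        "hello.mp3",                    # output_path
        voice="nova",                   # one of the six voices listed above
        format="mp3"                    # see get_supported_formats()
    )
    print(result["file_path"], result["sample_rate"])
    await tts.close()

asyncio.run(main())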