isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_registry.py +273 -46
  3. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  4. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  5. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  6. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  7. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  8. isa_model/eval/__init__.py +56 -0
  9. isa_model/eval/benchmarks.py +469 -0
  10. isa_model/eval/factory.py +582 -0
  11. isa_model/eval/metrics.py +628 -0
  12. isa_model/inference/ai_factory.py +98 -93
  13. isa_model/inference/providers/openai_provider.py +21 -7
  14. isa_model/inference/providers/replicate_provider.py +18 -5
  15. isa_model/inference/providers/triton_provider.py +1 -1
  16. isa_model/inference/services/audio/base_stt_service.py +91 -0
  17. isa_model/inference/services/audio/base_tts_service.py +136 -0
  18. isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
  19. isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
  20. isa_model/inference/services/llm/__init__.py +0 -4
  21. isa_model/inference/services/llm/base_llm_service.py +134 -0
  22. isa_model/inference/services/llm/ollama_llm_service.py +1 -10
  23. isa_model/inference/services/llm/openai_llm_service.py +70 -61
  24. isa_model/inference/services/vision/__init__.py +1 -1
  25. isa_model/inference/services/vision/ollama_vision_service.py +4 -4
  26. isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
  27. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  28. isa_model/training/__init__.py +44 -0
  29. isa_model/training/factory.py +393 -0
  30. isa_model-0.1.1.dist-info/METADATA +327 -0
  31. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
  32. isa_model/deployment/mlflow_gateway/__init__.py +0 -8
  33. isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
  34. isa_model/deployment/unified_multimodal_client.py +0 -341
  35. isa_model/inference/adapter/triton_adapter.py +0 -453
  36. isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
  37. isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
  38. isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
  39. isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
  40. isa_model/inference/backends/__init__.py +0 -53
  41. isa_model/inference/backends/base_backend_client.py +0 -26
  42. isa_model/inference/backends/container_services.py +0 -104
  43. isa_model/inference/backends/local_services.py +0 -72
  44. isa_model/inference/backends/openai_client.py +0 -130
  45. isa_model/inference/backends/replicate_client.py +0 -197
  46. isa_model/inference/backends/third_party_services.py +0 -239
  47. isa_model/inference/backends/triton_client.py +0 -97
  48. isa_model/inference/client_sdk/client.py +0 -134
  49. isa_model/inference/client_sdk/client_data_std.py +0 -34
  50. isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
  51. isa_model/inference/client_sdk/exceptions.py +0 -0
  52. isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
  53. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
  54. isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
  55. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
  56. isa_model/inference/providers/vllm_provider.py +0 -0
  57. isa_model/inference/providers/yyds_provider.py +0 -83
  58. isa_model/inference/services/audio/fish_speech/handler.py +0 -215
  59. isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
  60. isa_model/inference/services/audio/triton_speech_service.py +0 -138
  61. isa_model/inference/services/audio/whisper_service.py +0 -186
  62. isa_model/inference/services/base_tts_service.py +0 -66
  63. isa_model/inference/services/embedding/bge_service.py +0 -183
  64. isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
  65. isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
  66. isa_model/inference/services/llm/gemma_service.py +0 -143
  67. isa_model/inference/services/llm/llama_service.py +0 -143
  68. isa_model/inference/services/llm/replicate_llm_service.py +0 -179
  69. isa_model/inference/services/llm/triton_llm_service.py +0 -230
  70. isa_model/inference/services/vision/replicate_vision_service.py +0 -241
  71. isa_model/inference/services/vision/triton_vision_service.py +0 -199
  72. isa_model-0.1.0.dist-info/METADATA +0 -116
  73. /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
  74. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
  75. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
  76. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
isa_model/inference/ai_factory.py

@@ -5,11 +5,6 @@ from isa_model.inference.services.base_service import BaseService
 from isa_model.inference.base import ModelType
 import os
 
-from isa_model.inference.services.llm.llama_service import LlamaService
-from isa_model.inference.services.llm.gemma_service import GemmaService
-from isa_model.inference.services.audio.whisper_service import WhisperService
-from isa_model.inference.services.embedding.bge_service import BgeEmbeddingService
-
 # Set up basic logging configuration
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -29,7 +24,7 @@ class AIFactory:
 
     def __init__(self):
         """Initialize the AI Factory."""
-        self.triton_url = os.environ.get("TRITON_URL", "localhost:8001")
+        self.triton_url = os.environ.get("TRITON_URL", "http://localhost:8000")
 
         # Cache for services (singleton pattern)
         self._llm_services = {}
@@ -70,58 +65,25 @@ class AIFactory:
         # Register Replicate provider and services
         try:
             from isa_model.inference.providers.replicate_provider import ReplicateProvider
-            from isa_model.inference.services.llm.replicate_llm_service import ReplicateLLMService
-            from isa_model.inference.services.vision.replicate_vision_service import ReplicateVisionService
+            from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService
 
             self.register_provider('replicate', ReplicateProvider)
-            self.register_service('replicate', ModelType.LLM, ReplicateLLMService)
             self.register_service('replicate', ModelType.VISION, ReplicateVisionService)
-            logger.info("Replicate services registered successfully")
+            logger.info("Replicate provider and vision service registered successfully")
         except ImportError as e:
             logger.warning(f"Replicate services not available: {e}")
+        except Exception as e:
+            logger.warning(f"Error registering Replicate services: {e}")
 
         # Try to register Triton services
         try:
             from isa_model.inference.providers.triton_provider import TritonProvider
-            from isa_model.inference.services.llm.triton_llm_service import TritonLLMService
-            from isa_model.inference.services.vision.triton_vision_service import TritonVisionService
-            from isa_model.inference.services.audio.triton_speech_service import TritonSpeechService
 
             self.register_provider('triton', TritonProvider)
-            self.register_service('triton', ModelType.LLM, TritonLLMService)
-            self.register_service('triton', ModelType.VISION, TritonVisionService)
-            self.register_service('triton', ModelType.AUDIO, TritonSpeechService)
-            logger.info("Triton services registered successfully")
-
-            # Register HuggingFace-based direct LLM service for Llama3-8B
-            try:
-                from isa_model.inference.llm.llama3_service import Llama3Service
-                # Register as a standalone service for direct access
-                self._cached_services["llama3"] = Llama3Service()
-                logger.info("Llama3-8B service registered successfully")
-            except ImportError as e:
-                logger.warning(f"Llama3-8B service not available: {e}")
-
-            # Register HuggingFace-based direct Vision service for Gemma3-4B
-            try:
-                from isa_model.inference.vision.gemma3_service import Gemma3VisionService
-                # Register as a standalone service for direct access
-                self._cached_services["gemma3"] = Gemma3VisionService()
-                logger.info("Gemma3-4B Vision service registered successfully")
-            except ImportError as e:
-                logger.warning(f"Gemma3-4B Vision service not available: {e}")
-
-            # Register HuggingFace-based direct Speech service for Whisper Tiny
-            try:
-                from isa_model.inference.speech.whisper_service import WhisperService
-                # Register as a standalone service for direct access
-                self._cached_services["whisper"] = WhisperService()
-                logger.info("Whisper Tiny Speech service registered successfully")
-            except ImportError as e:
-                logger.warning(f"Whisper Tiny Speech service not available: {e}")
+            logger.info("Triton provider registered successfully")
 
         except ImportError as e:
-            logger.warning(f"Triton services not available: {e}")
+            logger.warning(f"Triton provider not available: {e}")
 
         logger.info("Default AI providers and services initialized with backend architecture")
     except Exception as e:
@@ -176,24 +138,90 @@ class AIFactory:
 
     # Convenient methods for common services
     def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
-                config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Get a LLM service instance"""
+                config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
+        """
+        Get a LLM service instance
+
+        Args:
+            model_name: Name of the model to use
+            provider: Provider name ('ollama', 'openai', 'replicate', etc.)
+            config: Optional configuration dictionary
+            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
+
+        Returns:
+            LLM service instance
+
+        Example:
+            # Using with API key directly
+            llm = AIFactory.get_instance().get_llm(
+                model_name="gpt-4o-mini",
+                provider="openai",
+                api_key="your-api-key-here"
+            )
+
+            # Using without API key (will use environment variable)
+            llm = AIFactory.get_instance().get_llm(
+                model_name="gpt-4o-mini",
+                provider="openai"
+            )
+        """
+
+        # Special case for DeepSeek service
+        if model_name.lower() in ["deepseek", "deepseek-r1", "qwen3-8b"]:
+            if "deepseek" in self._cached_services:
+                return self._cached_services["deepseek"]
 
         # Special case for Llama3-8B direct service
         if model_name.lower() in ["llama3", "llama3-8b", "meta-llama-3"]:
             if "llama3" in self._cached_services:
                 return self._cached_services["llama3"]
 
-        basic_config = {
+        basic_config: Dict[str, Any] = {
             "temperature": 0
         }
+
+        # Add API key to config if provided
+        if api_key:
+            if provider == "openai":
+                basic_config["api_key"] = api_key
+            elif provider == "replicate":
+                basic_config["api_token"] = api_key
+            else:
+                logger.warning(f"API key provided but provider '{provider}' may not support it")
+
         if config:
             basic_config.update(config)
         return self.create_service(provider, ModelType.LLM, model_name, basic_config)
 
     def get_vision_model(self, model_name: str = "gemma3-4b", provider: str = "triton",
-                         config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Get a vision model service instance"""
+                         config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
+        """
+        Get a vision model service instance
+
+        Args:
+            model_name: Name of the model to use
+            provider: Provider name ('openai', 'replicate', 'triton', etc.)
+            config: Optional configuration dictionary
+            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
+
+        Returns:
+            Vision service instance
+
+        Example:
+            # Using with API key directly
+            vision = AIFactory.get_instance().get_vision_model(
+                model_name="gpt-4o",
+                provider="openai",
+                api_key="your-api-key-here"
+            )
+
+            # Using Replicate for image generation
+            image_gen = AIFactory.get_instance().get_vision_model(
+                model_name="stability-ai/sdxl",
+                provider="replicate",
+                api_key="your-replicate-token"
+            )
+        """
 
         # Special case for Gemma3-4B direct service
         if model_name.lower() in ["gemma3", "gemma3-4b", "gemma3-vision"]:
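In the get_llm hunk above, the single api_key argument is mapped onto the provider-specific config key ("api_key" for OpenAI, "api_token" for Replicate), and basic_config.update(config) runs last, so an explicit config dict always wins over both the factory defaults and the api_key shortcut. A short usage sketch built from the docstring's own example; the key value is a placeholder:

    from isa_model.inference.ai_factory import AIFactory

    llm = AIFactory.get_instance().get_llm(
        model_name="gpt-4o-mini",
        provider="openai",
        api_key="sk-...",                # becomes basic_config["api_key"]
        config={"temperature": 0.3},     # overrides the factory default of 0
    )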
@@ -202,19 +230,33 @@ class AIFactory:
 
         # Special case for Replicate's image generation models
         if provider == "replicate" and "/" in model_name:
-            basic_config = {
-                "api_token": os.environ.get("REPLICATE_API_TOKEN", ""),
+            replicate_config: Dict[str, Any] = {
                 "guidance_scale": 7.5,
                 "num_inference_steps": 30
             }
+
+            # Add API key if provided
+            if api_key:
+                replicate_config["api_token"] = api_key
+
             if config:
-                basic_config.update(config)
-            return self.create_service(provider, ModelType.VISION, model_name, basic_config)
+                replicate_config.update(config)
+            return self.create_service(provider, ModelType.VISION, model_name, replicate_config)
 
-        basic_config = {
+        basic_config: Dict[str, Any] = {
             "temperature": 0.7,
             "max_new_tokens": 512
         }
+
+        # Add API key to config if provided
+        if api_key:
+            if provider == "openai":
+                basic_config["api_key"] = api_key
+            elif provider == "replicate":
+                basic_config["api_token"] = api_key
+            else:
+                logger.warning(f"API key provided but provider '{provider}' may not support it")
+
         if config:
             basic_config.update(config)
         return self.create_service(provider, ModelType.VISION, model_name, basic_config)
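Replicate model names containing a "/" take the replicate_config path above, with guidance_scale 7.5 and 30 inference steps as defaults; after this change the API token comes from the api_key argument (or the provider's environment fallback) rather than being read from REPLICATE_API_TOKEN inside the factory. A sketch based on the docstring example, with a placeholder token:

    from isa_model.inference.ai_factory import AIFactory

    image_gen = AIFactory.get_instance().get_vision_model(
        model_name="stability-ai/sdxl",
        provider="replicate",
        api_key="r8_...",                        # becomes replicate_config["api_token"]
        config={"num_inference_steps": 50},      # overrides the default of 30
    )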
@@ -251,32 +293,6 @@ class AIFactory:
             basic_config.update(config)
         return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)
 
-    async def get_llm_service(self, model_name: str) -> Any:
-        """
-        Get an LLM service for the specified model.
-
-        Args:
-            model_name: Name of the model
-
-        Returns:
-            LLM service instance
-        """
-        if model_name in self._llm_services:
-            return self._llm_services[model_name]
-
-        if model_name == "llama":
-            service = LlamaService(triton_url=self.triton_url, model_name="llama")
-            await service.load()
-            self._llm_services[model_name] = service
-            return service
-        elif model_name == "gemma":
-            service = GemmaService(triton_url=self.triton_url, model_name="gemma")
-            await service.load()
-            self._llm_services[model_name] = service
-            return service
-        else:
-            raise ValueError(f"Unsupported LLM model: {model_name}")
-
     async def get_embedding_service(self, model_name: str) -> Any:
         """
         Get an embedding service for the specified model.
@@ -290,11 +306,6 @@ class AIFactory:
         if model_name in self._embedding_services:
             return self._embedding_services[model_name]
 
-        if model_name == "bge_embed":
-            service = BgeEmbeddingService(triton_url=self.triton_url, model_name="bge_embed")
-            await service.load()
-            self._embedding_services[model_name] = service
-            return service
         else:
             raise ValueError(f"Unsupported embedding model: {model_name}")
 
@@ -311,13 +322,6 @@ class AIFactory:
         if model_name in self._speech_services:
             return self._speech_services[model_name]
 
-        if model_name == "whisper":
-            service = WhisperService(triton_url=self.triton_url, model_name="whisper")
-            await service.load()
-            self._speech_services[model_name] = service
-            return service
-        else:
-            raise ValueError(f"Unsupported speech model: {model_name}")
 
     def get_model_info(self, model_type: Optional[str] = None) -> Dict[str, Any]:
         """
@@ -331,6 +335,7 @@ class AIFactory:
         """
         models = {
             "llm": [
+                {"name": "deepseek", "description": "DeepSeek-R1-0528-Qwen3-8B language model"},
                 {"name": "llama", "description": "Llama3-8B language model"},
                 {"name": "gemma", "description": "Gemma3-4B language model"}
             ],
isa_model/inference/providers/openai_provider.py

@@ -15,13 +15,13 @@ class OpenAIProvider(BaseProvider):
 
         Args:
             config (dict, optional): Configuration for the provider
-                - api_key: OpenAI API key (default: from environment variable)
+                - api_key: OpenAI API key (can be passed here or via environment variable)
                 - api_base: Base URL for OpenAI API (default: https://api.openai.com/v1)
                 - timeout: Timeout for API calls in seconds
         """
         default_config = {
-            "api_key": os.environ.get("OPENAI_API_KEY", ""),
-            "api_base": "https://api.openai.com/v1",
+            "api_key": "",  # Will be set from config or environment
+            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
             "timeout": 60,
             "stream": True,
             "temperature": 0.7,
@@ -32,14 +32,28 @@ class OpenAIProvider(BaseProvider):
         # Merge default config with provided config
         merged_config = {**default_config, **(config or {})}
 
+        # Set API key from config first, then fallback to environment variable
+        if not merged_config["api_key"]:
+            merged_config["api_key"] = os.environ.get("OPENAI_API_KEY", "")
+
         super().__init__(config=merged_config)
         self.name = "openai"
 
         logger.info(f"Initialized OpenAIProvider with URL: {self.config['api_base']}")
 
-        # Validate API key
+        # Only warn if no API key is provided at all
         if not self.config["api_key"]:
-            logger.warning("OpenAI API key not provided. Set OPENAI_API_KEY environment variable or pass in config.")
+            logger.info("OpenAI API key not provided. You can set it via OPENAI_API_KEY environment variable or pass it in the config when creating services.")
+
+    def set_api_key(self, api_key: str):
+        """
+        Set the API key after initialization
+
+        Args:
+            api_key: OpenAI API key
+        """
+        self.config["api_key"] = api_key
+        logger.info("OpenAI API key updated")
 
     def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
         """Get provider capabilities by model type"""
@@ -52,7 +66,7 @@ class OpenAIProvider(BaseProvider):
                 Capability.EMBEDDING
             ],
             ModelType.VISION: [
-                Capability.IMAGE_UNDERSTANDING,
+                Capability.IMAGE_GENERATION,
                 Capability.MULTIMODAL_UNDERSTANDING
             ],
             ModelType.AUDIO: [
@@ -63,7 +77,7 @@ class OpenAIProvider(BaseProvider):
     def get_models(self, model_type: ModelType) -> List[str]:
         """Get available models for given type"""
         if model_type == ModelType.LLM:
-            return ["gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
+            return ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
         elif model_type == ModelType.EMBEDDING:
             return ["text-embedding-3-large", "text-embedding-3-small", "text-embedding-ada-002"]
         elif model_type == ModelType.VISION:
isa_model/inference/providers/replicate_provider.py

@@ -15,11 +15,11 @@ class ReplicateProvider(BaseProvider):
 
         Args:
             config (dict, optional): Configuration for the provider
-                - api_token: Replicate API token (default: from environment variable)
+                - api_token: Replicate API token (can be passed here or via environment variable)
                 - timeout: Timeout for API calls in seconds
         """
         default_config = {
-            "api_token": os.environ.get("REPLICATE_API_TOKEN", ""),
+            "api_token": "",  # Will be set from config or environment
             "timeout": 60,
             "stream": True,
             "max_tokens": 1024
@@ -28,14 +28,28 @@ class ReplicateProvider(BaseProvider):
         # Merge default config with provided config
         merged_config = {**default_config, **(config or {})}
 
+        # Set API token from config first, then fallback to environment variable
+        if not merged_config["api_token"]:
+            merged_config["api_token"] = os.environ.get("REPLICATE_API_TOKEN", "")
+
         super().__init__(config=merged_config)
         self.name = "replicate"
 
         logger.info(f"Initialized ReplicateProvider")
 
-        # Validate API token
+        # Only warn if no API token is provided at all
         if not self.config["api_token"]:
-            logger.warning("Replicate API token not provided. Set REPLICATE_API_TOKEN environment variable or pass in config.")
+            logger.info("Replicate API token not provided. You can set it via REPLICATE_API_TOKEN environment variable or pass it in the config when creating services.")
+
+    def set_api_token(self, api_token: str):
+        """
+        Set the API token after initialization
+
+        Args:
+            api_token: Replicate API token
+        """
+        self.config["api_token"] = api_token
+        logger.info("Replicate API token updated")
 
     def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
         """Get provider capabilities by model type"""
@@ -45,7 +59,6 @@ class ReplicateProvider(BaseProvider):
                 Capability.COMPLETION
             ],
             ModelType.VISION: [
-                Capability.IMAGE_UNDERSTANDING,
                 Capability.IMAGE_GENERATION,
                 Capability.MULTIMODAL_UNDERSTANDING
             ],
isa_model/inference/providers/triton_provider.py

@@ -29,7 +29,7 @@ class TritonProvider(BaseProvider):
 
         # Default configuration
         self.default_config = {
-            "server_url": os.environ.get("TRITON_SERVER_URL", "localhost:8000"),
+            "server_url": os.environ.get("TRITON_SERVER_URL", "http://localhost:8000"),
             "model_repository": os.environ.get(
                 "MODEL_REPOSITORY",
                 os.path.join(os.getcwd(), "models/triton/model_repository")
isa_model/inference/services/audio/base_stt_service.py (new file)

@@ -0,0 +1,91 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Union, Optional, BinaryIO
+from isa_model.inference.services.base_service import BaseService
+
+class BaseSTTService(BaseService):
+    """Base class for Speech-to-Text services"""
+
+    @abstractmethod
+    async def transcribe_audio(
+        self,
+        audio_file: Union[str, BinaryIO],
+        language: Optional[str] = None,
+        prompt: Optional[str] = None
+    ) -> Dict[str, Any]:
+        """
+        Transcribe audio file to text
+
+        Args:
+            audio_file: Path to audio file or file-like object
+            language: Language code (e.g., 'en', 'es', 'fr')
+            prompt: Optional prompt to guide transcription
+
+        Returns:
+            Dict containing transcription results with keys:
+            - text: The transcribed text
+            - language: Detected/specified language
+            - confidence: Confidence score (if available)
+            - segments: Time-segmented transcription (if available)
+        """
+        pass
+
+    @abstractmethod
+    async def transcribe_audio_batch(
+        self,
+        audio_files: List[Union[str, BinaryIO]],
+        language: Optional[str] = None,
+        prompt: Optional[str] = None
+    ) -> List[Dict[str, Any]]:
+        """
+        Transcribe multiple audio files
+
+        Args:
+            audio_files: List of audio file paths or file-like objects
+            language: Language code (e.g., 'en', 'es', 'fr')
+            prompt: Optional prompt to guide transcription
+
+        Returns:
+            List of transcription results
+        """
+        pass
+
+    @abstractmethod
+    async def detect_language(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
+        """
+        Detect language of audio file
+
+        Args:
+            audio_file: Path to audio file or file-like object
+
+        Returns:
+            Dict containing language detection results with keys:
+            - language: Detected language code
+            - confidence: Confidence score
+            - alternatives: List of alternative languages with scores
+        """
+        pass
+
+    @abstractmethod
+    def get_supported_formats(self) -> List[str]:
+        """
+        Get list of supported audio formats
+
+        Returns:
+            List of supported file extensions (e.g., ['mp3', 'wav', 'flac'])
+        """
+        pass
+
+    @abstractmethod
+    def get_supported_languages(self) -> List[str]:
+        """
+        Get list of supported language codes
+
+        Returns:
+            List of supported language codes (e.g., ['en', 'es', 'fr'])
+        """
+        pass
+
+    @abstractmethod
+    async def close(self):
+        """Cleanup resources"""
+        pass
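A concrete STT backend subclasses BaseSTTService and fills in each abstract method. The hypothetical sketch below only illustrates the contract: the return values are hard-coded placeholders rather than real model output, and whether the constructor needs provider/model arguments depends on BaseService, which this diff does not show:

    from typing import Dict, Any, List
    from isa_model.inference.services.audio.base_stt_service import BaseSTTService

    class DummySTTService(BaseSTTService):
        """Placeholder STT service returning canned results."""

        async def transcribe_audio(self, audio_file, language=None, prompt=None) -> Dict[str, Any]:
            # A real service would run the audio through a model here.
            return {"text": "", "language": language or "en", "confidence": None, "segments": []}

        async def transcribe_audio_batch(self, audio_files, language=None, prompt=None) -> List[Dict[str, Any]]:
            # Batch is just the per-file call applied sequentially here.
            return [await self.transcribe_audio(f, language, prompt) for f in audio_files]

        async def detect_language(self, audio_file) -> Dict[str, Any]:
            return {"language": "en", "confidence": 1.0, "alternatives": []}

        def get_supported_formats(self) -> List[str]:
            return ["mp3", "wav", "flac"]

        def get_supported_languages(self) -> List[str]:
            return ["en", "es", "fr"]

        async def close(self):
            pass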
isa_model/inference/services/audio/base_tts_service.py (new file)

@@ -0,0 +1,136 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Union, Optional, BinaryIO
+from isa_model.inference.services.base_service import BaseService
+
+class BaseTTSService(BaseService):
+    """Base class for Text-to-Speech services"""
+
+    @abstractmethod
+    async def synthesize_speech(
+        self,
+        text: str,
+        voice: Optional[str] = None,
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        format: str = "mp3"
+    ) -> Dict[str, Any]:
+        """
+        Synthesize speech from text
+
+        Args:
+            text: Input text to convert to speech
+            voice: Voice ID or name to use
+            speed: Speech speed multiplier (0.5-2.0)
+            pitch: Pitch adjustment (-1.0 to 1.0)
+            format: Audio format ('mp3', 'wav', 'ogg')
+
+        Returns:
+            Dict containing synthesis results with keys:
+            - audio_data: Binary audio data
+            - format: Audio format
+            - duration: Audio duration in seconds
+            - sample_rate: Audio sample rate
+        """
+        pass
+
+    @abstractmethod
+    async def synthesize_speech_to_file(
+        self,
+        text: str,
+        output_path: str,
+        voice: Optional[str] = None,
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        format: str = "mp3"
+    ) -> Dict[str, Any]:
+        """
+        Synthesize speech and save directly to file
+
+        Args:
+            text: Input text to convert to speech
+            output_path: Path to save the audio file
+            voice: Voice ID or name to use
+            speed: Speech speed multiplier (0.5-2.0)
+            pitch: Pitch adjustment (-1.0 to 1.0)
+            format: Audio format ('mp3', 'wav', 'ogg')
+
+        Returns:
+            Dict containing synthesis results with keys:
+            - file_path: Path to saved audio file
+            - duration: Audio duration in seconds
+            - sample_rate: Audio sample rate
+        """
+        pass
+
+    @abstractmethod
+    async def synthesize_speech_batch(
+        self,
+        texts: List[str],
+        voice: Optional[str] = None,
+        speed: float = 1.0,
+        pitch: float = 1.0,
+        format: str = "mp3"
+    ) -> List[Dict[str, Any]]:
+        """
+        Synthesize speech for multiple texts
+
+        Args:
+            texts: List of input texts to convert to speech
+            voice: Voice ID or name to use
+            speed: Speech speed multiplier (0.5-2.0)
+            pitch: Pitch adjustment (-1.0 to 1.0)
+            format: Audio format ('mp3', 'wav', 'ogg')
+
+        Returns:
+            List of synthesis result dictionaries
+        """
+        pass
+
+    @abstractmethod
+    def get_available_voices(self) -> List[Dict[str, Any]]:
+        """
+        Get list of available voices
+
+        Returns:
+            List of voice information dictionaries with keys:
+            - id: Voice identifier
+            - name: Human-readable voice name
+            - language: Language code (e.g., 'en-US', 'es-ES')
+            - gender: Voice gender ('male', 'female', 'neutral')
+            - age: Voice age category ('adult', 'child', 'elderly')
+        """
+        pass
+
+    @abstractmethod
+    def get_supported_formats(self) -> List[str]:
+        """
+        Get list of supported audio formats
+
+        Returns:
+            List of supported file extensions (e.g., ['mp3', 'wav', 'ogg'])
+        """
+        pass
+
+    @abstractmethod
+    def get_voice_info(self, voice_id: str) -> Dict[str, Any]:
+        """
+        Get detailed information about a specific voice
+
+        Args:
+            voice_id: Voice identifier
+
+        Returns:
+            Dict containing voice information:
+            - id: Voice identifier
+            - name: Human-readable voice name
+            - language: Language code
+            - gender: Voice gender
+            - description: Voice description
+            - sample_rate: Default sample rate
+        """
+        pass
+
+    @abstractmethod
+    async def close(self):
+        """Cleanup resources"""
+        pass
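The synthesize_speech_to_file contract pairs naturally with synthesize_speech: the result keys differ only in file_path versus audio_data. A hypothetical method fragment for a subclass, implementing the file variant on top of the in-memory one (error handling and the remaining abstract methods omitted; assumes audio_data is raw bytes as the docstring above describes):

    from typing import Dict, Any, Optional

    async def synthesize_speech_to_file(
        self, text: str, output_path: str, voice: Optional[str] = None,
        speed: float = 1.0, pitch: float = 1.0, format: str = "mp3"
    ) -> Dict[str, Any]:
        result = await self.synthesize_speech(text, voice, speed, pitch, format)
        with open(output_path, "wb") as f:
            f.write(result["audio_data"])        # binary payload per the contract above
        return {
            "file_path": output_path,
            "duration": result.get("duration"),
            "sample_rate": result.get("sample_rate"),
        }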
isa_model/inference/services/audio/openai_tts_service.py (renamed from yyds_audio_service.py)

@@ -3,11 +3,11 @@ import tempfile
 import os
 from openai import AsyncOpenAI
 from tenacity import retry, stop_after_attempt, wait_exponential
-from ...base_service import BaseService
-from ...base_provider import BaseProvider
-from app.config.config_manager import config_manager
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.providers.base_provider import BaseProvider
+import logging
 
-logger = config_manager.get_logger(__name__)
+logger = logging.getLogger(__name__)
 
 class YYDSAudioService(BaseService):
     """Audio model service wrapper for YYDS"""