isa-model 0.0.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_manager.py +69 -4
  3. isa_model/core/model_registry.py +273 -46
  4. isa_model/core/storage/hf_storage.py +419 -0
  5. isa_model/deployment/__init__.py +52 -0
  6. isa_model/deployment/core/__init__.py +34 -0
  7. isa_model/deployment/core/deployment_config.py +356 -0
  8. isa_model/deployment/core/deployment_manager.py +549 -0
  9. isa_model/deployment/core/isa_deployment_service.py +401 -0
  10. isa_model/eval/factory.py +381 -140
  11. isa_model/inference/ai_factory.py +427 -236
  12. isa_model/inference/billing_tracker.py +406 -0
  13. isa_model/inference/providers/base_provider.py +51 -4
  14. isa_model/inference/providers/ml_provider.py +50 -0
  15. isa_model/inference/providers/ollama_provider.py +37 -18
  16. isa_model/inference/providers/openai_provider.py +65 -36
  17. isa_model/inference/providers/replicate_provider.py +42 -30
  18. isa_model/inference/services/audio/base_stt_service.py +21 -2
  19. isa_model/inference/services/audio/openai_realtime_service.py +353 -0
  20. isa_model/inference/services/audio/openai_stt_service.py +252 -0
  21. isa_model/inference/services/audio/openai_tts_service.py +149 -9
  22. isa_model/inference/services/audio/replicate_tts_service.py +239 -0
  23. isa_model/inference/services/base_service.py +36 -1
  24. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  25. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  26. isa_model/inference/services/embedding/openai_embed_service.py +223 -0
  27. isa_model/inference/services/llm/__init__.py +2 -0
  28. isa_model/inference/services/llm/base_llm_service.py +158 -86
  29. isa_model/inference/services/llm/llm_adapter.py +414 -0
  30. isa_model/inference/services/llm/ollama_llm_service.py +252 -63
  31. isa_model/inference/services/llm/openai_llm_service.py +231 -93
  32. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  33. isa_model/inference/services/ml/base_ml_service.py +78 -0
  34. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  35. isa_model/inference/services/vision/__init__.py +3 -3
  36. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  37. isa_model/inference/services/vision/base_vision_service.py +177 -0
  38. isa_model/inference/services/vision/helpers/image_utils.py +4 -3
  39. isa_model/inference/services/vision/ollama_vision_service.py +151 -17
  40. isa_model/inference/services/vision/openai_vision_service.py +275 -41
  41. isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
  42. isa_model/training/__init__.py +62 -32
  43. isa_model/training/cloud/__init__.py +22 -0
  44. isa_model/training/cloud/job_orchestrator.py +402 -0
  45. isa_model/training/cloud/runpod_trainer.py +454 -0
  46. isa_model/training/cloud/storage_manager.py +482 -0
  47. isa_model/training/core/__init__.py +23 -0
  48. isa_model/training/core/config.py +181 -0
  49. isa_model/training/core/dataset.py +222 -0
  50. isa_model/training/core/trainer.py +720 -0
  51. isa_model/training/core/utils.py +213 -0
  52. isa_model/training/factory.py +229 -198
  53. isa_model-0.3.1.dist-info/METADATA +465 -0
  54. isa_model-0.3.1.dist-info/RECORD +91 -0
  55. isa_model/core/model_router.py +0 -226
  56. isa_model/core/model_version.py +0 -0
  57. isa_model/core/resource_manager.py +0 -202
  58. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  59. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  60. isa_model/training/engine/llama_factory/__init__.py +0 -39
  61. isa_model/training/engine/llama_factory/config.py +0 -115
  62. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  63. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  64. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  65. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  66. isa_model/training/engine/llama_factory/factory.py +0 -331
  67. isa_model/training/engine/llama_factory/rl.py +0 -254
  68. isa_model/training/engine/llama_factory/trainer.py +0 -171
  69. isa_model/training/image_model/configs/create_config.py +0 -37
  70. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  71. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  72. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  73. isa_model/training/image_model/prepare_upload.py +0 -17
  74. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  75. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  76. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  77. isa_model/training/image_model/train/train.py +0 -42
  78. isa_model/training/image_model/train/train_flux.py +0 -41
  79. isa_model/training/image_model/train/train_lora.py +0 -57
  80. isa_model/training/image_model/train_main.py +0 -25
  81. isa_model-0.0.2.dist-info/METADATA +0 -327
  82. isa_model-0.0.2.dist-info/RECORD +0 -92
  83. isa_model-0.0.2.dist-info/licenses/LICENSE +0 -21
  84. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  85. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  86. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  87. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  88. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  89. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  90. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  91. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  92. {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
  93. {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,252 @@
1
+ import logging
2
+ import aiohttp
3
+ from typing import Dict, Any, List, Union, Optional, BinaryIO
4
+ from openai import AsyncOpenAI
5
+ from tenacity import retry, stop_after_attempt, wait_exponential
6
+
7
+ from isa_model.inference.services.audio.base_stt_service import BaseSTTService
8
+ from isa_model.inference.providers.base_provider import BaseProvider
9
+ from isa_model.inference.billing_tracker import ServiceType
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ class OpenAISTTService(BaseSTTService):
14
+ """
15
+ OpenAI Speech-to-Text service using whisper-1 model.
16
+ Supports transcription and translation to English.
17
+ """
18
+
19
def __init__(self, provider: 'BaseProvider', model_name: str = "whisper-1"):
    """Create the async OpenAI client from the provider's full configuration.

    Args:
        provider: Provider whose ``get_full_config()`` supplies ``api_key``
            and, optionally, ``base_url``, ``organization`` and
            ``max_file_size``.
        model_name: Model identifier handed to the base service.

    Raises:
        ValueError: If the API key is missing or the client cannot be
            created. A missing key is raised inside the ``try`` and is
            therefore re-wrapped with the generic initialization message.
    """
    super().__init__(provider, model_name)

    # The full configuration is required because it carries the credentials.
    settings = provider.get_full_config()

    try:
        api_key = settings.get("api_key")
        if not api_key:
            raise ValueError("OpenAI API key not found in provider configuration")

        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=settings.get("base_url", "https://api.openai.com/v1"),
            organization=settings.get("organization"),
        )
        logger.info(f"Initialized OpenAISTTService with model '{self.model_name}'")
    except Exception as init_error:
        logger.error(f"Failed to initialize OpenAI client: {init_error}")
        raise ValueError(
            f"Failed to initialize OpenAI client. Check your API key configuration: {init_error}"
        ) from init_error

    # Upload limit (bytes) and the file extensions this service advertises.
    default_limit = 25 * 1024 * 1024  # 25MB
    self.max_file_size = settings.get('max_file_size', default_limit)
    self.supported_formats = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm']
45
+
46
@retry(
    reraise=True,
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
)
async def _download_audio(self, audio_url: str) -> bytes:
    """Fetch the raw bytes behind ``audio_url``.

    The call is attempted up to three times with exponential back-off; the
    last error is re-raised unchanged.

    Raises:
        ValueError: If the server answers with any status other than 200.
    """
    async with aiohttp.ClientSession() as http, http.get(audio_url) as reply:
        if reply.status != 200:
            raise ValueError(f"Failed to download audio from {audio_url}: {reply.status}")
        return await reply.read()
59
+
60
async def transcribe(
    self,
    audio_file: Union[str, BinaryIO],
    language: Optional[str] = None,
    prompt: Optional[str] = None
) -> Dict[str, Any]:
    """Transcribe audio to text with the configured model.

    Args:
        audio_file: An ``http(s)`` URL, a local file path, or an open binary
            file-like object.
        language: Optional language hint forwarded to the API when set.
        prompt: Optional prompt forwarded to the API when set.

    Returns:
        Dict with ``text``, ``language``, ``duration``, ``segments``,
        ``confidence`` (always ``None``) and the raw ``usage`` value taken
        from the response.

    Raises:
        ValueError: If the audio is larger than ``self.max_file_size``.
        Exception: Any download, file or API error is logged and re-raised.
    """
    import os
    from urllib.parse import urlparse

    def _usage_count(usage_value: Any, key: str) -> int:
        """Read one counter from ``usage`` whether it is a mapping or an object."""
        if not usage_value:
            return 0
        if isinstance(usage_value, dict):
            return usage_value.get(key, 0) or 0
        # SDK responses may expose usage as an object without ``.get``.
        return getattr(usage_value, key, 0) or 0

    try:
        # Load the audio bytes and pick the filename sent with the upload.
        if isinstance(audio_file, str):
            if audio_file.startswith(('http://', 'https://')):
                audio_data = await self._download_audio(audio_file)
                # Use only the URL path: a query string or fragment would
                # otherwise end up in the filename and hide the extension.
                filename = os.path.basename(urlparse(audio_file).path) or 'audio.wav'
            else:
                with open(audio_file, 'rb') as f:
                    audio_data = f.read()
                filename = audio_file
        else:
            audio_data = audio_file.read()
            filename = getattr(audio_file, 'name', 'audio.wav')

        # Reject oversized uploads before calling the API.
        if len(audio_data) > self.max_file_size:
            raise ValueError(f"Audio file size ({len(audio_data)} bytes) exceeds maximum ({self.max_file_size} bytes)")

        kwargs = {
            "model": self.model_name,
            "file": (filename, audio_data),
            "response_format": "verbose_json"
        }
        if language:
            kwargs["language"] = language
        if prompt:
            kwargs["prompt"] = prompt

        response = await self.client.audio.transcriptions.create(**kwargs)

        # Token counters for billing; absent or empty usage counts as zero.
        usage = getattr(response, 'usage', {})
        input_tokens = _usage_count(usage, 'input_tokens')
        output_tokens = _usage_count(usage, 'output_tokens')

        # Audio is additionally billed by duration, expressed in minutes.
        duration = getattr(response, 'duration', None)
        try:
            duration_minutes = float(duration) / 60.0 if duration else 0
        except (TypeError, ValueError):
            # A non-numeric duration must not fail an otherwise good result.
            duration_minutes = 0

        self._track_usage(
            service_type=ServiceType.AUDIO_STT,
            operation="transcribe",
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_units=duration_minutes,  # Duration in minutes
            metadata={
                "language": language,
                "model": self.model_name,
                "file_size": len(audio_data)
            }
        )

        return {
            "text": response.text,
            "language": getattr(response, 'language', language or 'unknown'),
            "duration": duration,
            "segments": getattr(response, 'segments', []),
            "confidence": None,  # whisper-1 doesn't provide confidence scores
            "usage": usage  # Include usage information
        }

    except Exception as e:
        logger.error(f"Error transcribing audio: {e}")
        raise
138
+
139
async def translate(
    self,
    audio_file: Union[str, BinaryIO]
) -> Dict[str, Any]:
    """Translate spoken audio into English text.

    The audio is read exactly once; only the API request itself is retried
    (up to three attempts with exponential back-off). Retrying the whole
    method would re-read an already consumed file object and upload empty
    audio, and would pointlessly retry size or file-system errors.

    Args:
        audio_file: An ``http(s)`` URL, a local file path, or an open binary
            file-like object.

    Returns:
        Dict with ``text``, ``detected_language``, ``duration``,
        ``segments`` and ``confidence`` (always ``None``).

    Raises:
        ValueError: If the audio is larger than ``self.max_file_size``.
        Exception: Any download, file or API error is logged and re-raised.
    """
    import os
    from urllib.parse import urlparse

    try:
        # Load the audio bytes, accepting the same kinds of input as transcribe().
        if isinstance(audio_file, str):
            if audio_file.startswith(('http://', 'https://')):
                audio_data = await self._download_audio(audio_file)
                filename = os.path.basename(urlparse(audio_file).path) or 'audio.wav'
            else:
                with open(audio_file, 'rb') as f:
                    audio_data = f.read()
                filename = audio_file
        else:
            audio_data = audio_file.read()
            filename = getattr(audio_file, 'name', 'audio.wav')

        # Reject oversized uploads before calling the API.
        if len(audio_data) > self.max_file_size:
            raise ValueError(f"Audio file size ({len(audio_data)} bytes) exceeds maximum ({self.max_file_size} bytes)")

        @retry(
            stop=stop_after_attempt(3),
            wait=wait_exponential(multiplier=1, min=4, max=10),
            reraise=True
        )
        async def _request_translation():
            """Send the already-loaded audio; safe to repeat."""
            return await self.client.audio.translations.create(
                model=self.model_name,
                file=(filename, audio_data),
                response_format="verbose_json"
            )

        response = await _request_translation()

        # NOTE(review): unlike transcribe(), no usage is tracked here —
        # confirm whether translations should be billed as well.
        return {
            "text": response.text,
            "detected_language": getattr(response, 'language', 'unknown'),
            "duration": getattr(response, 'duration', None),
            "segments": getattr(response, 'segments', []),
            "confidence": None  # Whisper doesn't provide confidence scores
        }

    except Exception as e:
        logger.error(f"Error translating audio: {e}")
        raise
184
+
185
async def transcribe_batch(
    self,
    audio_files: List[Union[str, BinaryIO]],
    language: Optional[str] = None,
    prompt: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Transcribe several audio inputs one after another.

    A failure on one input does not abort the batch: the error is logged
    and a placeholder entry carrying an ``error`` message takes its place.

    Returns:
        One result dict per input, in input order.
    """
    async def _transcribe_one(item):
        # Convert a per-item failure into a placeholder result.
        try:
            return await self.transcribe(item, language, prompt)
        except Exception as failure:
            logger.error(f"Error transcribing audio file: {failure}")
            return {
                "text": "",
                "language": "unknown",
                "duration": None,
                "segments": [],
                "confidence": None,
                "error": str(failure)
            }

    # Awaited sequentially so the inputs are processed in order.
    return [await _transcribe_one(item) for item in audio_files]
210
+
211
async def detect_language(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
    """Report the language of an audio input.

    Runs a full transcription without a language hint and returns the
    language from that result.

    Returns:
        Dict with ``language``, a fixed ``confidence`` of ``1.0`` and an
        empty ``alternatives`` list.
    """
    try:
        transcription = await self.transcribe(audio_file, language=None)
        detected = transcription["language"]
    except Exception as failure:
        logger.error(f"Error detecting language: {failure}")
        raise

    report = {"language": detected}
    report["confidence"] = 1.0  # Whisper is generally confident
    report["alternatives"] = []  # Whisper doesn't provide alternatives
    return report
226
+
227
def get_supported_formats(self) -> List[str]:
    """Return a fresh copy of the supported audio format list.

    A copy is handed out so callers cannot mutate the service's own list.
    """
    formats = self.supported_formats
    return formats[:]
230
+
231
def get_supported_languages(self) -> List[str]:
    """Return the language codes this service advertises.

    The codes are listed alphabetically; a new list is built on every call.
    """
    # Whisper supports 99+ languages
    codes = (
        "af am ar as az ba be bg bn bo br bs ca "
        "cs cy da de el en es et eu fa fi fo fr "
        "gl gu ha haw he hi hr ht hu hy id is it "
        "ja jw ka kk km kn ko la lb ln lo lt lv "
        "mg mi mk ml mn mr ms mt my ne nl nn no "
        "oc pa pl ps pt ro ru sa sd si sk sl sn "
        "so sq sr su sv sw ta te tg th tk tl tr "
        "tt uk ur uz vi yi yo zh"
    )
    return codes.split()
244
+
245
def get_max_file_size(self) -> int:
    """Return the upload size limit, in bytes, configured for this service."""
    limit = self.max_file_size
    return limit
248
+
249
async def close(self):
    """Close the underlying OpenAI client and log that it was closed."""
    openai_client = self.client
    await openai_client.close()
    logger.info("OpenAISTTService client has been closed.")
@@ -1,25 +1,42 @@
1
- from typing import Dict, Any
1
+ from typing import Dict, Any, List, Optional
2
2
  import tempfile
3
3
  import os
4
4
  from openai import AsyncOpenAI
5
5
  from tenacity import retry, stop_after_attempt, wait_exponential
6
- from isa_model.inference.services.base_service import BaseService
6
+ from isa_model.inference.services.audio.base_tts_service import BaseTTSService
7
7
  from isa_model.inference.providers.base_provider import BaseProvider
8
+ from isa_model.inference.billing_tracker import ServiceType
8
9
  import logging
9
10
 
10
11
  logger = logging.getLogger(__name__)
11
12
 
12
- class YYDSAudioService(BaseService):
13
+ class OpenAITTSService(BaseTTSService):
13
14
  """Audio model service wrapper for YYDS"""
14
15
 
15
16
  def __init__(self, provider: 'BaseProvider', model_name: str):
16
17
  super().__init__(provider, model_name)
17
- # 初始化 AsyncOpenAI 客户端
18
- self._client = AsyncOpenAI(
19
- api_key=self.config.get('api_key'),
20
- base_url=self.config.get('base_url')
21
- )
22
- self.language = self.config.get('language', None)
18
+
19
+ # Get full configuration from provider (including sensitive data)
20
+ provider_config = provider.get_full_config()
21
+
22
+ # Initialize AsyncOpenAI client with provider configuration
23
+ try:
24
+ if not provider_config.get("api_key"):
25
+ raise ValueError("OpenAI API key not found in provider configuration")
26
+
27
+ self._client = AsyncOpenAI(
28
+ api_key=provider_config["api_key"],
29
+ base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
30
+ organization=provider_config.get("organization")
31
+ )
32
+
33
+ logger.info(f"Initialized OpenAITTSService with model '{self.model_name}'")
34
+
35
+ except Exception as e:
36
+ logger.error(f"Failed to initialize OpenAI client: {e}")
37
+ raise ValueError(f"Failed to initialize OpenAI client. Check your API key configuration: {e}") from e
38
+
39
+ self.language = provider_config.get('language', None)
23
40
 
24
41
  @property
25
42
  def client(self) -> AsyncOpenAI:
@@ -69,3 +86,126 @@ class YYDSAudioService(BaseService):
69
86
  except Exception as e:
70
87
  logger.error(f"Error in audio transcription: {e}")
71
88
  raise
89
+
90
+ # Implement the abstract methods of BaseTTSService
91
async def synthesize_speech(
    self,
    text: str,
    voice: Optional[str] = None,
    speed: float = 1.0,
    pitch: float = 1.0,
    format: str = "mp3"
) -> Dict[str, Any]:
    """Synthesize speech from text using OpenAI TTS.

    Args:
        text: Text to speak.
        voice: Voice id; falls back to ``"alloy"`` when not given.
        speed: Playback speed forwarded to the API.
        pitch: Accepted for interface compatibility; not sent to the API.
        format: Audio container/codec forwarded as ``response_format``.

    Returns:
        Dict with ``audio_data`` (the response content), ``format``, an
        estimated ``duration`` in seconds and ``sample_rate``.

    Raises:
        Exception: Any API error is logged and re-raised.
    """
    try:
        # Use the model this service was configured with, so the request
        # matches the model recorded in the billing metadata below; fall
        # back to "tts-1" only when no model name is set.
        model = self.model_name or "tts-1"
        selected_voice = voice or "alloy"

        response = await self._client.audio.speech.create(
            model=model,
            voice=selected_voice,  # type: ignore
            input=text,
            response_format=format,  # type: ignore
            speed=speed
        )

        audio_data = response.content

        # Estimate audio duration for billing (rough estimation: ~150 words per minute)
        words = len(text.split())
        estimated_duration_seconds = (words / 150.0) * 60.0 / speed

        # Track usage for billing (OpenAI TTS is token-based: $15 per 1M characters)
        self._track_usage(
            service_type=ServiceType.AUDIO_TTS,
            operation="synthesize_speech",
            input_tokens=len(text),  # Characters as input tokens
            output_tokens=0,
            input_units=len(text),  # Text length
            output_units=estimated_duration_seconds,  # Audio duration in seconds
            metadata={
                "model": model,
                "voice": selected_voice,
                "speed": speed,
                "format": format,
                "text_length": len(text),
                "estimated_duration_seconds": estimated_duration_seconds
            }
        )

        return {
            "audio_data": audio_data,
            "format": format,
            "duration": estimated_duration_seconds,
            "sample_rate": 24000  # Default for OpenAI TTS
        }

    except Exception as e:
        logger.error(f"Error in speech synthesis: {e}")
        raise
143
+
144
async def synthesize_speech_to_file(
    self,
    text: str,
    output_path: str,
    voice: Optional[str] = None,
    speed: float = 1.0,
    pitch: float = 1.0,
    format: str = "mp3"
) -> Dict[str, Any]:
    """Synthesize speech and write the audio bytes to ``output_path``.

    Returns:
        Dict with ``file_path`` plus the ``duration`` and ``sample_rate``
        reported by ``synthesize_speech``.
    """
    synthesized = await self.synthesize_speech(text, voice, speed, pitch, format)

    with open(output_path, 'wb') as sink:
        sink.write(synthesized["audio_data"])

    summary = {"file_path": output_path}
    for key in ("duration", "sample_rate"):
        summary[key] = synthesized[key]
    return summary
164
+
165
async def synthesize_speech_batch(
    self,
    texts: List[str],
    voice: Optional[str] = None,
    speed: float = 1.0,
    pitch: float = 1.0,
    format: str = "mp3"
) -> List[Dict[str, Any]]:
    """Synthesize each text in turn and collect the results in input order.

    The first failing text aborts the batch: its exception propagates.
    """
    # Awaited sequentially, one request at a time.
    return [
        await self.synthesize_speech(item, voice, speed, pitch, format)
        for item in texts
    ]
179
+
180
def get_available_voices(self) -> List[Dict[str, Any]]:
    """Return descriptors for the built-in OpenAI voices.

    Every descriptor has ``id``, ``name``, ``language``, ``gender`` and
    ``age``; new dicts are created on each call.
    """
    # (voice id, gender); the display name is the capitalized id.
    catalog = (
        ("alloy", "neutral"),
        ("echo", "male"),
        ("fable", "neutral"),
        ("onyx", "male"),
        ("nova", "female"),
        ("shimmer", "female"),
    )
    return [
        {
            "id": voice_id,
            "name": voice_id.capitalize(),
            "language": "en-US",
            "gender": gender,
            "age": "adult",
        }
        for voice_id, gender in catalog
    ]
190
+
191
def get_supported_formats(self) -> List[str]:
    """Return the output audio formats this service advertises."""
    formats = ("mp3", "opus", "aac", "flac")
    return list(formats)
194
+
195
def get_voice_info(self, voice_id: str) -> Dict[str, Any]:
    """Look up one voice and enrich it with a description and sample rate.

    Returns:
        The voice descriptor extended with ``description`` and
        ``sample_rate``, or an empty dict when ``voice_id`` is unknown.
    """
    by_id = {}
    for entry in self.get_available_voices():
        by_id[entry["id"]] = entry

    details = by_id.get(voice_id, {})
    if not details:
        return details

    details["description"] = f"OpenAI {details['name']} voice"
    details["sample_rate"] = 24000
    return details
207
+
208
async def close(self):
    """Close the underlying client when it exposes a ``close`` method."""
    if not hasattr(self._client, 'close'):
        return
    await self._client.close()
@@ -0,0 +1,239 @@
1
+ import logging
2
+ from typing import Dict, Any, List, Optional, BinaryIO
3
+ import replicate
4
+ from tenacity import retry, stop_after_attempt, wait_exponential
5
+
6
+ from isa_model.inference.services.audio.base_tts_service import BaseTTSService
7
+ from isa_model.inference.providers.base_provider import BaseProvider
8
+ from isa_model.inference.billing_tracker import ServiceType
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class ReplicateTTSService(BaseTTSService):
13
+ """
14
+ Replicate Text-to-Speech service using Kokoro model.
15
+ High-quality voice synthesis with multiple voice options.
16
+ """
17
+
18
def __init__(self, provider: 'BaseProvider', model_name: str = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"):
    """Configure the Replicate token and the voice defaults.

    Args:
        provider: Provider whose ``get_full_config()`` supplies
            ``api_token`` or ``replicate_api_token``.
        model_name: Replicate model reference handed to the base service.

    Raises:
        ValueError: If neither token key holds a value.
    """
    super().__init__(provider, model_name)

    # The full configuration is required because it carries the credentials.
    settings = provider.get_full_config()

    self.api_token = settings.get('api_token') or settings.get('replicate_api_token')
    if not self.api_token:
        raise ValueError("Replicate API token not found in provider configuration")

    # The replicate library reads its token from the process environment.
    import os
    os.environ['REPLICATE_API_TOKEN'] = self.api_token

    # Voices accepted for the Kokoro model, and the defaults used when the
    # caller does not choose.
    self.available_voices = [
        "af_bella",
        "af_nicole",
        "af_sarah",
        "af_sky",
        "am_adam",
        "am_michael",
    ]
    self.default_voice = "af_nicole"
    self.default_speed = 1.0

    logger.info(f"Initialized ReplicateTTSService with model '{self.model_name}'")
43
+
44
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    reraise=True
)
async def synthesize_speech(
    self,
    text: str,
    voice: Optional[str] = None,
    speed: float = 1.0,
    pitch: Optional[float] = None,
    volume: Optional[float] = None
) -> Dict[str, Any]:
    """Synthesize speech from text using the Kokoro model on Replicate.

    Args:
        text: Text to speak.
        voice: Voice id; unknown or missing voices fall back to
            ``self.default_voice`` (a warning is logged for unknown ones).
        speed: Requested speed; the value sent to the model is clamped to
            the range 0.5–2.0.
        pitch: Accepted but not sent to the model.
        volume: Accepted but not sent to the model.

    Returns:
        Dict with ``audio_url``, ``text``, ``voice``, the requested
        ``speed``, an estimated ``duration_seconds`` and ``metadata``.

    Raises:
        Exception: Any error is logged and re-raised; the whole call is
            attempted up to three times with exponential back-off.
    """
    def _extract_url(model_output: Any) -> str:
        """Turn the model output (string, file-like object or list) into a URL string."""
        try:
            if isinstance(model_output, str):
                return model_output
            if hasattr(model_output, 'url'):
                # Handle FileOutput object
                return str(getattr(model_output, 'url', model_output))
            if isinstance(model_output, list) and len(model_output) > 0:
                first_output = model_output[0]
                if hasattr(first_output, 'url'):
                    return str(getattr(first_output, 'url', first_output))
                return str(first_output)
            # Convert to string as fallback
            return str(model_output)
        except Exception:
            # Safe fallback
            return str(model_output)

    try:
        # Validate and set voice
        selected_voice = voice or self.default_voice
        if selected_voice not in self.available_voices:
            logger.warning(f"Voice '{selected_voice}' not available, using default '{self.default_voice}'")
            selected_voice = self.default_voice

        # Clamp once and reuse: the model receives this value, so the
        # duration estimate must be based on it too. Dividing by the raw
        # request would fail for speed=0 and go negative for speed<0.
        effective_speed = max(0.5, min(2.0, speed))

        input_params = {
            "text": text,
            "voice": selected_voice,
            "speed": effective_speed
        }

        logger.info(f"Synthesizing speech with voice '{selected_voice}' and speed {speed}")

        # Run the model
        output = await replicate.async_run(self.model_name, input=input_params)
        audio_url = _extract_url(output)

        # Estimate audio duration for billing (rough estimation: ~150 words per minute)
        words = len(text.split())
        estimated_duration_seconds = (words / 150.0) * 60.0 / effective_speed

        # Track usage for billing
        self._track_usage(
            service_type=ServiceType.AUDIO_TTS,
            operation="synthesize_speech",
            input_tokens=0,
            output_tokens=0,
            input_units=len(text),  # Text length
            output_units=estimated_duration_seconds,  # Audio duration in seconds
            metadata={
                "model": self.model_name,
                "voice": selected_voice,
                "speed": speed,
                "text_length": len(text),
                "estimated_duration_seconds": estimated_duration_seconds
            }
        )

        result = {
            "audio_url": audio_url,
            "text": text,
            "voice": selected_voice,
            "speed": speed,
            "duration_seconds": estimated_duration_seconds,
            "metadata": {
                "model": self.model_name,
                "provider": "replicate",
                "voice_options": self.available_voices
            }
        }

        logger.info(f"Speech synthesis completed: {audio_url}")
        return result

    except Exception as e:
        logger.error(f"Error synthesizing speech: {e}")
        raise
137
+
138
async def synthesize_speech_to_file(
    self,
    text: str,
    output_path: str,
    voice: Optional[str] = None,
    speed: float = 1.0,
    pitch: Optional[float] = None,
    volume: Optional[float] = None
) -> Dict[str, Any]:
    """Synthesize speech, download the audio and store it at ``output_path``.

    Returns:
        The ``synthesize_speech`` result extended with ``output_path`` and
        ``file_size`` (bytes written).

    Raises:
        Exception: Any synthesis, download or write error is logged and
            re-raised.
    """
    try:
        synthesized = await self.synthesize_speech(text, voice, speed, pitch, volume)
        source_url = synthesized["audio_url"]

        # Imported lazily, only when a file is actually written.
        import aiohttp
        import aiofiles

        async with aiohttp.ClientSession() as http, http.get(source_url) as reply:
            reply.raise_for_status()
            payload = await reply.read()

        async with aiofiles.open(output_path, 'wb') as sink:
            await sink.write(payload)

        synthesized["output_path"] = output_path
        synthesized["file_size"] = len(payload)

        logger.info(f"Audio saved to: {output_path}")
        return synthesized

    except Exception as failure:
        logger.error(f"Error saving audio to file: {failure}")
        raise
174
+
175
async def synthesize_speech_batch(
    self,
    texts: List[str],
    voice: Optional[str] = None,
    speed: float = 1.0,
    pitch: float = 1.0,
    format: str = "wav"
) -> List[Dict[str, Any]]:
    """Synthesize several texts one after another.

    Only ``voice`` and ``speed`` are forwarded; ``pitch`` and ``format``
    are accepted but unused. A failing text does not abort the batch: the
    error is logged and a placeholder entry with ``audio_url`` set to
    ``None`` and an ``error`` message takes its place.

    Returns:
        One result dict per text, in input order.
    """
    async def _synthesize_one(item):
        # Convert a per-item failure into a placeholder result.
        try:
            return await self.synthesize_speech(item, voice, speed)
        except Exception as failure:
            logger.error(f"Error synthesizing text '{item[:50]}...': {failure}")
            return {
                "audio_url": None,
                "text": item,
                "voice": voice or self.default_voice,
                "speed": speed,
                "error": str(failure)
            }

    # Awaited sequentially so the texts are processed in order.
    return [await _synthesize_one(item) for item in texts]
201
+
202
def get_available_voices(self) -> List[Dict[str, Any]]:
    """Return one descriptor per entry in ``self.available_voices``.

    The display name is derived from the id (underscores become spaces,
    title-cased); the gender comes from ``get_voice_info`` and defaults to
    ``"unknown"``.
    """
    return [
        {
            "id": voice_id,
            "name": voice_id.replace("_", " ").title(),
            "language": "en-US",
            "gender": self.get_voice_info(voice_id).get("gender", "unknown"),
            "age": "adult",
        }
        for voice_id in self.available_voices
    ]
215
+
216
def get_supported_formats(self) -> List[str]:
    """Return the output audio formats this service advertises."""
    formats = ("wav", "mp3")  # Kokoro typically outputs WAV
    return list(formats)
219
+
220
def get_voice_info(self, voice_id: str) -> Dict[str, Any]:
    """Describe a single voice.

    Returns:
        An ``error`` dict when ``voice_id`` is not in
        ``self.available_voices``; otherwise the voice's metadata, or a
        generic "not available" descriptor when no metadata is recorded.
    """
    if voice_id not in self.available_voices:
        return {"error": f"Voice '{voice_id}' not available"}

    # (id, display name, gender, description) for the voices with metadata.
    described = (
        ("af_bella", "Bella", "female", "Warm, friendly female voice"),
        ("af_nicole", "Nicole", "female", "Clear, professional female voice"),
        ("af_sarah", "Sarah", "female", "Gentle, expressive female voice"),
        ("af_sky", "Sky", "female", "Bright, energetic female voice"),
        ("am_adam", "Adam", "male", "Deep, authoritative male voice"),
        ("am_michael", "Michael", "male", "Smooth, conversational male voice"),
    )
    for known_id, name, gender, description in described:
        if known_id == voice_id:
            return {
                "id": known_id,
                "name": name,
                "gender": gender,
                "language": "en-US",
                "description": description,
                "sample_rate": 22050,
            }

    # Listed as available but without recorded metadata.
    return {
        "id": voice_id,
        "gender": "unknown",
        "language": "en-US",
        "description": "Voice information not available",
        "sample_rate": 22050,
    }
236
+
237
async def close(self):
    """Log that the service was cleaned up; nothing else is released here."""
    message = "ReplicateTTSService resources cleaned up"
    logger.info(message)