isa-model 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. isa_model/__init__.py +30 -1
  2. isa_model/client.py +770 -0
  3. isa_model/core/config/__init__.py +16 -0
  4. isa_model/core/config/config_manager.py +514 -0
  5. isa_model/core/config.py +426 -0
  6. isa_model/core/models/model_billing_tracker.py +476 -0
  7. isa_model/core/models/model_manager.py +399 -0
  8. isa_model/core/models/model_repo.py +343 -0
  9. isa_model/core/pricing_manager.py +426 -0
  10. isa_model/core/services/__init__.py +19 -0
  11. isa_model/core/services/intelligent_model_selector.py +547 -0
  12. isa_model/core/types.py +291 -0
  13. isa_model/deployment/__init__.py +2 -0
  14. isa_model/deployment/cloud/__init__.py +9 -0
  15. isa_model/deployment/cloud/modal/__init__.py +10 -0
  16. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +766 -0
  17. isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
  18. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +406 -0
  19. isa_model/deployment/cloud/modal/register_models.py +321 -0
  20. isa_model/deployment/runtime/deployed_service.py +338 -0
  21. isa_model/deployment/services/__init__.py +9 -0
  22. isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
  23. isa_model/deployment/services/model_service.py +332 -0
  24. isa_model/deployment/services/service_monitor.py +356 -0
  25. isa_model/deployment/services/service_registry.py +527 -0
  26. isa_model/eval/__init__.py +80 -44
  27. isa_model/eval/config/__init__.py +10 -0
  28. isa_model/eval/config/evaluation_config.py +108 -0
  29. isa_model/eval/evaluators/__init__.py +18 -0
  30. isa_model/eval/evaluators/base_evaluator.py +503 -0
  31. isa_model/eval/evaluators/llm_evaluator.py +472 -0
  32. isa_model/eval/factory.py +417 -709
  33. isa_model/eval/infrastructure/__init__.py +24 -0
  34. isa_model/eval/infrastructure/experiment_tracker.py +466 -0
  35. isa_model/eval/metrics.py +191 -21
  36. isa_model/inference/ai_factory.py +187 -387
  37. isa_model/inference/providers/modal_provider.py +109 -0
  38. isa_model/inference/providers/yyds_provider.py +108 -0
  39. isa_model/inference/services/__init__.py +2 -1
  40. isa_model/inference/services/audio/base_stt_service.py +65 -1
  41. isa_model/inference/services/audio/base_tts_service.py +75 -1
  42. isa_model/inference/services/audio/openai_stt_service.py +189 -151
  43. isa_model/inference/services/audio/openai_tts_service.py +12 -10
  44. isa_model/inference/services/audio/replicate_tts_service.py +61 -56
  45. isa_model/inference/services/base_service.py +55 -55
  46. isa_model/inference/services/embedding/base_embed_service.py +65 -1
  47. isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
  48. isa_model/inference/services/embedding/openai_embed_service.py +8 -10
  49. isa_model/inference/services/helpers/stacked_config.py +148 -0
  50. isa_model/inference/services/img/__init__.py +18 -0
  51. isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -35
  52. isa_model/inference/services/img/flux_professional_service.py +603 -0
  53. isa_model/inference/services/img/helpers/base_stacked_service.py +274 -0
  54. isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +210 -69
  55. isa_model/inference/services/llm/__init__.py +3 -3
  56. isa_model/inference/services/llm/base_llm_service.py +519 -35
  57. isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +40 -0
  58. isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
  59. isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
  60. isa_model/inference/services/llm/ollama_llm_service.py +150 -15
  61. isa_model/inference/services/llm/openai_llm_service.py +134 -31
  62. isa_model/inference/services/llm/yyds_llm_service.py +255 -0
  63. isa_model/inference/services/vision/__init__.py +38 -4
  64. isa_model/inference/services/vision/base_vision_service.py +241 -96
  65. isa_model/inference/services/vision/disabled/isA_vision_service.py +500 -0
  66. isa_model/inference/services/vision/doc_analysis_service.py +640 -0
  67. isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
  68. isa_model/inference/services/vision/helpers/image_utils.py +272 -3
  69. isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
  70. isa_model/inference/services/vision/openai_vision_service.py +109 -170
  71. isa_model/inference/services/vision/replicate_vision_service.py +508 -0
  72. isa_model/inference/services/vision/ui_analysis_service.py +823 -0
  73. isa_model/scripts/register_models.py +370 -0
  74. isa_model/scripts/register_models_with_embeddings.py +510 -0
  75. isa_model/serving/__init__.py +19 -0
  76. isa_model/serving/api/__init__.py +10 -0
  77. isa_model/serving/api/fastapi_server.py +89 -0
  78. isa_model/serving/api/middleware/__init__.py +9 -0
  79. isa_model/serving/api/middleware/request_logger.py +88 -0
  80. isa_model/serving/api/routes/__init__.py +5 -0
  81. isa_model/serving/api/routes/health.py +82 -0
  82. isa_model/serving/api/routes/llm.py +19 -0
  83. isa_model/serving/api/routes/ui_analysis.py +223 -0
  84. isa_model/serving/api/routes/unified.py +202 -0
  85. isa_model/serving/api/routes/vision.py +19 -0
  86. isa_model/serving/api/schemas/__init__.py +17 -0
  87. isa_model/serving/api/schemas/common.py +33 -0
  88. isa_model/serving/api/schemas/ui_analysis.py +78 -0
  89. {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
  90. isa_model-0.3.6.dist-info/RECORD +147 -0
  91. isa_model/core/model_manager.py +0 -208
  92. isa_model/core/model_registry.py +0 -342
  93. isa_model/inference/billing_tracker.py +0 -406
  94. isa_model/inference/services/llm/triton_llm_service.py +0 -481
  95. isa_model/inference/services/vision/ollama_vision_service.py +0 -194
  96. isa_model-0.3.4.dist-info/RECORD +0 -91
  97. /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
  98. /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
  99. {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
  100. {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
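
The three diffs reproduced below (openai_stt_service.py, openai_tts_service.py, and replicate_tts_service.py) all follow the same refactoring pattern: services now accept a provider name string and resolve credentials through the centralized config manager rather than receiving a pre-built BaseProvider object, and billing moves from the removed billing_tracker module to an awaited _track_usage call with plain string service types. As a minimal sketch of what this means at a call site (direct instantiation shown for illustration; it assumes the config manager can supply the provider's API key):

# 0.3.4: the caller constructed and passed a provider object
# service = OpenAISTTService(provider=openai_provider, model_name="whisper-1")

# 0.3.6: the caller passes only a provider name; the service resolves its
# own config (API key, base URL) via the centralized config manager
service = OpenAISTTService(provider_name="openai", model_name="whisper-1")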
isa_model/inference/services/audio/openai_stt_service.py

@@ -5,8 +5,6 @@ from openai import AsyncOpenAI
  from tenacity import retry, stop_after_attempt, wait_exponential

  from isa_model.inference.services.audio.base_stt_service import BaseSTTService
- from isa_model.inference.providers.base_provider import BaseProvider
- from isa_model.inference.billing_tracker import ServiceType

  logger = logging.getLogger(__name__)

@@ -14,22 +12,22 @@ class OpenAISTTService(BaseSTTService):
      """
      OpenAI Speech-to-Text service using whisper-1 model.
      Supports transcription and translation to English.
+     Uses the new unified architecture with centralized config management.
      """

-     def __init__(self, provider: 'BaseProvider', model_name: str = "whisper-1"):
-         super().__init__(provider, model_name)
+     def __init__(self, provider_name: str, model_name: str = "whisper-1", **kwargs):
+         super().__init__(provider_name, model_name, **kwargs)

-         # Get full configuration from provider (including sensitive data)
-         provider_config = provider.get_full_config()
+         # Get provider configuration from centralized config manager
+         provider_config = self.get_provider_config()

          # Initialize AsyncOpenAI client with provider configuration
          try:
-             if not provider_config.get("api_key"):
-                 raise ValueError("OpenAI API key not found in provider configuration")
+             api_key = self.get_api_key()

              self.client = AsyncOpenAI(
-                 api_key=provider_config["api_key"],
-                 base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
+                 api_key=api_key,
+                 base_url=provider_config.get("api_base_url", "https://api.openai.com/v1"),
                  organization=provider_config.get("organization")
              )

@@ -48,205 +46,245 @@ class OpenAISTTService(BaseSTTService):
          wait=wait_exponential(multiplier=1, min=4, max=10),
          reraise=True
      )
-     async def _download_audio(self, audio_url: str) -> bytes:
-         """Download audio from URL"""
-         async with aiohttp.ClientSession() as session:
-             async with session.get(audio_url) as response:
-                 if response.status == 200:
-                     return await response.read()
-                 else:
-                     raise ValueError(f"Failed to download audio from {audio_url}: {response.status}")
-
-     async def transcribe(
-         self,
-         audio_file: Union[str, BinaryIO],
-         language: Optional[str] = None,
-         prompt: Optional[str] = None
-     ) -> Dict[str, Any]:
-         """Transcribe audio file to text using whisper-1"""
-         try:
-             # Prepare the audio file
-             if isinstance(audio_file, str):
-                 if audio_file.startswith(('http://', 'https://')):
-                     # Download audio from URL
-                     audio_data = await self._download_audio(audio_file)
-                     filename = audio_file.split('/')[-1] or 'audio.wav'
-                 else:
-                     # Local file path
-                     with open(audio_file, 'rb') as f:
-                         audio_data = f.read()
-                     filename = audio_file
-             else:
-                 audio_data = audio_file.read()
-                 filename = getattr(audio_file, 'name', 'audio.wav')
-
-             # Check file size
-             if len(audio_data) > self.max_file_size:
-                 raise ValueError(f"Audio file size ({len(audio_data)} bytes) exceeds maximum ({self.max_file_size} bytes)")
+     async def transcribe(self, audio_file: Union[str, BinaryIO], language: Optional[str] = None, prompt: Optional[str] = None) -> Dict[str, Any]:
+         """
+         Transcribe audio file to text using OpenAI's Whisper model.
+
+         Args:
+             audio_file: Path to audio file or file-like object
+             language: Optional language code for better accuracy
+             **kwargs: Additional parameters for the transcription API

-             # Prepare transcription parameters
-             kwargs = {
+         Returns:
+             Dict containing transcription result and metadata
+         """
+         try:
+             # Prepare request parameters
+             transcription_params = {
                  "model": self.model_name,
-                 "file": (filename, audio_data),
                  "response_format": "verbose_json"
              }

              if language:
-                 kwargs["language"] = language
-             if prompt:
-                 kwargs["prompt"] = prompt
+                 transcription_params["language"] = language

-             # Transcribe audio
-             response = await self.client.audio.transcriptions.create(**kwargs)
+             # Add optional parameters
+             if prompt:
+                 transcription_params["prompt"] = prompt

-             # Track usage for billing
-             usage = getattr(response, 'usage', {})
-             input_tokens = usage.get('input_tokens', 0) if usage else 0
-             output_tokens = usage.get('output_tokens', 0) if usage else 0
+             # Handle file input
+             if isinstance(audio_file, str):
+                 with open(audio_file, "rb") as f:
+                     transcription = await self.client.audio.transcriptions.create(
+                         file=f,
+                         **transcription_params
+                     )
+             else:
+                 transcription = await self.client.audio.transcriptions.create(
+                     file=audio_file,
+                     **transcription_params
+                 )

-             # For audio, also track duration in minutes
-             duration_minutes = getattr(response, 'duration', 0) / 60.0 if getattr(response, 'duration', 0) else 0
+             # Extract usage information for billing
+             result = {
+                 "text": transcription.text,
+                 "language": getattr(transcription, 'language', language),
+                 "duration": getattr(transcription, 'duration', None),
+                 "segments": getattr(transcription, 'segments', []),
+                 "usage": {
+                     "input_units": getattr(transcription, 'duration', 1),  # Duration in seconds
+                     "output_tokens": len(transcription.text.split()) if transcription.text else 0
+                 }
+             }

-             self._track_usage(
-                 service_type=ServiceType.AUDIO_STT,
+             # Track usage for billing
+             await self._track_usage(
+                 service_type="audio_stt",
                  operation="transcribe",
-                 input_tokens=input_tokens,
-                 output_tokens=output_tokens,
-                 input_units=duration_minutes,  # Duration in minutes
+                 input_units=result["usage"]["input_units"],
+                 output_tokens=result["usage"]["output_tokens"],
                  metadata={
-                     "language": language,
-                     "model": self.model_name,
-                     "file_size": len(audio_data)
+                     "language": result.get("language"),
+                     "model_name": self.model_name,
+                     "provider": self.provider_name
                  }
              )

-             # Format response
-             result = {
-                 "text": response.text,
-                 "language": getattr(response, 'language', language or 'unknown'),
-                 "duration": getattr(response, 'duration', None),
-                 "segments": getattr(response, 'segments', []),
-                 "confidence": None,  # whisper-1 doesn't provide confidence scores
-                 "usage": usage  # Include usage information
-             }
-
              return result

          except Exception as e:
-             logger.error(f"Error transcribing audio: {e}")
+             logger.error(f"Transcription failed: {e}")
              raise
-
+
      @retry(
          stop=stop_after_attempt(3),
          wait=wait_exponential(multiplier=1, min=4, max=10),
          reraise=True
      )
-     async def translate(
-         self,
-         audio_file: Union[str, BinaryIO]
-     ) -> Dict[str, Any]:
-         """Translate audio file to English text"""
+     async def translate(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
+         """
+         Translate audio file to English text using OpenAI's Whisper model.
+
+         Args:
+             audio_file: Path to audio file or file-like object
+             **kwargs: Additional parameters for the translation API
+
+         Returns:
+             Dict containing translation result and metadata
+         """
          try:
-             # Prepare the audio file
-             if isinstance(audio_file, str):
-                 with open(audio_file, 'rb') as f:
-                     audio_data = f.read()
-                 filename = audio_file
-             else:
-                 audio_data = audio_file.read()
-                 filename = getattr(audio_file, 'name', 'audio.wav')
+             # Prepare request parameters
+             translation_params = {
+                 "model": self.model_name,
+                 "response_format": "verbose_json"
+             }

-             # Check file size
-             if len(audio_data) > self.max_file_size:
-                 raise ValueError(f"Audio file size ({len(audio_data)} bytes) exceeds maximum ({self.max_file_size} bytes)")
+             # No additional parameters for translation

-             # Translate audio to English
-             response = await self.client.audio.translations.create(
-                 model=self.model_name,
-                 file=(filename, audio_data),
-                 response_format="verbose_json"
-             )
+             # Handle file input
+             if isinstance(audio_file, str):
+                 with open(audio_file, "rb") as f:
+                     translation = await self.client.audio.translations.create(
+                         file=f,
+                         **translation_params
+                     )
+             else:
+                 translation = await self.client.audio.translations.create(
+                     file=audio_file,
+                     **translation_params
+                 )

-             # Format response
+             # Extract usage information for billing
              result = {
-                 "text": response.text,
-                 "detected_language": getattr(response, 'language', 'unknown'),
-                 "duration": getattr(response, 'duration', None),
-                 "segments": getattr(response, 'segments', []),
-                 "confidence": None  # Whisper doesn't provide confidence scores
+                 "text": translation.text,
+                 "language": "en",  # Translation is always to English
+                 "duration": getattr(translation, 'duration', None),
+                 "segments": getattr(translation, 'segments', []),
+                 "usage": {
+                     "input_units": getattr(translation, 'duration', 1),  # Duration in seconds
+                     "output_tokens": len(translation.text.split()) if translation.text else 0
+                 }
              }

+             # Track usage for billing
+             await self._track_usage(
+                 service_type="audio_stt",
+                 operation="translate",
+                 input_units=result["usage"]["input_units"],
+                 output_tokens=result["usage"]["output_tokens"],
+                 metadata={
+                     "target_language": "en",
+                     "model_name": self.model_name,
+                     "provider": self.provider_name
+                 }
+             )
+
              return result

          except Exception as e:
-             logger.error(f"Error translating audio: {e}")
+             logger.error(f"Translation failed: {e}")
              raise
-
-     async def transcribe_batch(
-         self,
-         audio_files: List[Union[str, BinaryIO]],
-         language: Optional[str] = None,
-         prompt: Optional[str] = None
-     ) -> List[Dict[str, Any]]:
-         """Transcribe multiple audio files"""
-         results = []
+
+     async def transcribe_batch(self, audio_files: List[Union[str, BinaryIO]], language: Optional[str] = None, prompt: Optional[str] = None) -> List[Dict[str, Any]]:
+         """
+         Transcribe multiple audio files in batch.

+         Args:
+             audio_files: List of audio file paths or file-like objects
+             language: Optional language code for better accuracy
+             **kwargs: Additional parameters for the transcription API
+
+         Returns:
+             List of transcription results
+         """
+         results = []
          for audio_file in audio_files:
              try:
                  result = await self.transcribe(audio_file, language, prompt)
                  results.append(result)
              except Exception as e:
-                 logger.error(f"Error transcribing audio file: {e}")
+                 logger.error(f"Failed to transcribe {audio_file}: {e}")
                  results.append({
-                     "text": "",
-                     "language": "unknown",
-                     "duration": None,
-                     "segments": [],
-                     "confidence": None,
-                     "error": str(e)
+                     "error": str(e),
+                     "file": str(audio_file),
+                     "text": None
                  })

          return results
-
+
      async def detect_language(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
-         """Detect language of audio file"""
+         """
+         Detect the language of an audio file.
+
+         Args:
+             audio_file: Path to audio file or file-like object
+             **kwargs: Additional parameters
+
+         Returns:
+             Dict containing detected language and confidence
+         """
          try:
-             # Transcribe with language detection
-             result = await self.transcribe(audio_file, language=None)
+             # Use transcription with language detection - need to access client directly
+             transcription = await self.client.audio.transcriptions.create(
+                 file=audio_file if not isinstance(audio_file, str) else open(audio_file, "rb"),
+                 model=self.model_name,
+                 response_format="verbose_json"
+             )
+
+             result = {
+                 "text": transcription.text,
+                 "language": getattr(transcription, 'language', "unknown")
+             }

              return {
-                 "language": result["language"],
-                 "confidence": 1.0,  # Whisper is generally confident
-                 "alternatives": []  # Whisper doesn't provide alternatives
+                 "language": result.get("language", "unknown"),
+                 "confidence": 1.0,  # OpenAI doesn't provide confidence scores
+                 "text_sample": result.get("text", "")[:100] if result.get("text") else ""
              }

          except Exception as e:
-             logger.error(f"Error detecting language: {e}")
-             raise
-
+             logger.error(f"Language detection failed: {e}")
+             return {
+                 "language": "unknown",
+                 "confidence": 0.0,
+                 "error": str(e)
+             }
+
      def get_supported_formats(self) -> List[str]:
-         """Get list of supported audio formats"""
-         return self.supported_formats.copy()
+         """
+         Get list of supported audio formats.
+
+         Returns:
+             List of supported file extensions
+         """
+         return self.supported_formats

      def get_supported_languages(self) -> List[str]:
-         """Get list of supported language codes"""
-         # Whisper supports 99+ languages
+         """
+         Get list of supported language codes for OpenAI Whisper.
+
+         Returns:
+             List of supported language codes
+         """
          return [
-             'af', 'am', 'ar', 'as', 'az', 'ba', 'be', 'bg', 'bn', 'bo', 'br', 'bs', 'ca',
-             'cs', 'cy', 'da', 'de', 'el', 'en', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr',
-             'gl', 'gu', 'ha', 'haw', 'he', 'hi', 'hr', 'ht', 'hu', 'hy', 'id', 'is', 'it',
-             'ja', 'jw', 'ka', 'kk', 'km', 'kn', 'ko', 'la', 'lb', 'ln', 'lo', 'lt', 'lv',
-             'mg', 'mi', 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'my', 'ne', 'nl', 'nn', 'no',
-             'oc', 'pa', 'pl', 'ps', 'pt', 'ro', 'ru', 'sa', 'sd', 'si', 'sk', 'sl', 'sn',
-             'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', 'tk', 'tl', 'tr',
-             'tt', 'uk', 'ur', 'uz', 'vi', 'yi', 'yo', 'zh'
+             'af', 'ar', 'hy', 'az', 'be', 'bs', 'bg', 'ca', 'zh', 'hr', 'cs', 'da',
+             'nl', 'en', 'et', 'fi', 'fr', 'gl', 'de', 'el', 'he', 'hi', 'hu', 'is',
+             'id', 'it', 'ja', 'kn', 'kk', 'ko', 'lv', 'lt', 'mk', 'ms', 'mr', 'mi',
+             'ne', 'no', 'fa', 'pl', 'pt', 'ro', 'ru', 'sr', 'sk', 'sl', 'es', 'sw',
+             'sv', 'tl', 'ta', 'th', 'tr', 'uk', 'ur', 'vi', 'cy'
          ]
-
+
      def get_max_file_size(self) -> int:
-         """Get maximum file size in bytes"""
+         """
+         Get maximum file size limit in bytes.
+
+         Returns:
+             Maximum file size in bytes
+         """
          return self.max_file_size
-
+
      async def close(self):
          """Cleanup resources"""
-         await self.client.close()
-         logger.info("OpenAISTTService client has been closed.")
+         if hasattr(self.client, 'close'):
+             await self.client.close()
+         logger.info("OpenAI STT service closed")
isa_model/inference/services/audio/openai_tts_service.py

@@ -4,20 +4,18 @@ import os
  from openai import AsyncOpenAI
  from tenacity import retry, stop_after_attempt, wait_exponential
  from isa_model.inference.services.audio.base_tts_service import BaseTTSService
- from isa_model.inference.providers.base_provider import BaseProvider
- from isa_model.inference.billing_tracker import ServiceType
  import logging

  logger = logging.getLogger(__name__)

  class OpenAITTSService(BaseTTSService):
-     """Audio model service wrapper for YYDS"""
+     """OpenAI TTS service with unified architecture"""

-     def __init__(self, provider: 'BaseProvider', model_name: str):
-         super().__init__(provider, model_name)
+     def __init__(self, provider_name: str, model_name: str = "tts-1", **kwargs):
+         super().__init__(provider_name, model_name, **kwargs)

-         # Get full configuration from provider (including sensitive data)
-         provider_config = provider.get_full_config()
+         # Get configuration from centralized config manager
+         provider_config = self.get_provider_config()

          # Initialize AsyncOpenAI client with provider configuration
          try:
@@ -113,8 +111,8 @@ class OpenAITTSService(BaseTTSService):
              estimated_duration_seconds = (words / 150.0) * 60.0 / speed

              # Track usage for billing (OpenAI TTS is token-based: $15 per 1M characters)
-             self._track_usage(
-                 service_type=ServiceType.AUDIO_TTS,
+             await self._track_usage(
+                 service_type="audio_tts",
                  operation="synthesize_speech",
                  input_tokens=len(text),  # Characters as input tokens
                  output_tokens=0,
@@ -130,8 +128,12 @@ class OpenAITTSService(BaseTTSService):
                  }
              )

+             # For HTTP API compatibility, encode audio data as base64
+             import base64
+             audio_base64 = base64.b64encode(audio_data).decode('utf-8')
+
              return {
-                 "audio_data": audio_data,
+                 "audio_data_base64": audio_base64,  # Base64 encoded for JSON compatibility
                  "format": format,
                  "duration": estimated_duration_seconds,
                  "sample_rate": 24000  # Default for OpenAI TTS
isa_model/inference/services/audio/replicate_tts_service.py

@@ -4,42 +4,45 @@ import replicate
  from tenacity import retry, stop_after_attempt, wait_exponential

  from isa_model.inference.services.audio.base_tts_service import BaseTTSService
- from isa_model.inference.providers.base_provider import BaseProvider
- from isa_model.inference.billing_tracker import ServiceType

  logger = logging.getLogger(__name__)

  class ReplicateTTSService(BaseTTSService):
      """
-     Replicate Text-to-Speech service using Kokoro model.
+     Replicate Text-to-Speech service using Kokoro model with unified architecture.
      High-quality voice synthesis with multiple voice options.
      """

-     def __init__(self, provider: 'BaseProvider', model_name: str = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"):
-         super().__init__(provider, model_name)
+     def __init__(self, provider_name: str, model_name: str = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13", **kwargs):
+         super().__init__(provider_name, model_name, **kwargs)

-         # Get full configuration from provider (including sensitive data)
-         provider_config = provider.get_full_config()
+         # Get configuration from centralized config manager
+         provider_config = self.get_provider_config()

          # Set up Replicate API token from provider configuration
-         self.api_token = provider_config.get('api_token') or provider_config.get('replicate_api_token')
-         if not self.api_token:
-             raise ValueError("Replicate API token not found in provider configuration")
-
-         # Set environment variable for replicate library
-         import os
-         os.environ['REPLICATE_API_TOKEN'] = self.api_token
-
-         # Available voices for Kokoro model
-         self.available_voices = [
-             "af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael"
-         ]
-
-         # Default settings
-         self.default_voice = "af_nicole"
-         self.default_speed = 1.0
-
-         logger.info(f"Initialized ReplicateTTSService with model '{self.model_name}'")
+         try:
+             self.api_token = provider_config.get('api_key') or provider_config.get('replicate_api_token')
+             if not self.api_token:
+                 raise ValueError("Replicate API token not found in provider configuration")
+
+             # Set environment variable for replicate library
+             import os
+             os.environ['REPLICATE_API_TOKEN'] = self.api_token
+
+             # Available voices for Kokoro model
+             self.available_voices = [
+                 "af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael"
+             ]
+
+             # Default settings
+             self.default_voice = "af_nicole"
+             self.default_speed = 1.0
+
+             logger.info(f"Initialized ReplicateTTSService with model '{self.model_name}'")
+
+         except Exception as e:
+             logger.error(f"Failed to initialize Replicate client: {e}")
+             raise ValueError(f"Failed to initialize Replicate client: {e}") from e

      @retry(
          stop=stop_after_attempt(3),
@@ -51,8 +54,8 @@ class ReplicateTTSService(BaseTTSService):
          text: str,
          voice: Optional[str] = None,
          speed: float = 1.0,
-         pitch: Optional[float] = None,
-         volume: Optional[float] = None
+         pitch: float = 1.0,
+         format: str = "wav"
      ) -> Dict[str, Any]:
          """Synthesize speech from text using Kokoro model"""
          try:
@@ -99,8 +102,8 @@ class ReplicateTTSService(BaseTTSService):
              estimated_duration_seconds = (words / 150.0) * 60.0 / speed

              # Track usage for billing
-             self._track_usage(
-                 service_type=ServiceType.AUDIO_TTS,
+             await self._track_usage(
+                 service_type="audio_tts",
                  operation="synthesize_speech",
                  input_tokens=0,
                  output_tokens=0,
@@ -115,15 +118,24 @@
                  }
              )

+             # Download audio data for return format consistency
+             import aiohttp
+             async with aiohttp.ClientSession() as session:
+                 async with session.get(audio_url) as response:
+                     response.raise_for_status()
+                     audio_data = await response.read()
+
              result = {
-                 "audio_url": audio_url,
-                 "text": text,
-                 "voice": selected_voice,
-                 "speed": speed,
-                 "duration_seconds": estimated_duration_seconds,
+                 "audio_data": audio_data,
+                 "format": "wav",  # Kokoro typically outputs WAV
+                 "duration": estimated_duration_seconds,
+                 "sample_rate": 22050,
+                 "audio_url": audio_url,  # Keep URL for reference
                  "metadata": {
                      "model": self.model_name,
                      "provider": "replicate",
+                     "voice": selected_voice,
+                     "speed": speed,
                      "voice_options": self.available_voices
                  }
              }
@@ -137,36 +149,29 @@ class ReplicateTTSService(BaseTTSService):

      async def synthesize_speech_to_file(
          self,
-         text: str,
+         text: str,
          output_path: str,
          voice: Optional[str] = None,
          speed: float = 1.0,
-         pitch: Optional[float] = None,
-         volume: Optional[float] = None
+         pitch: float = 1.0,
+         format: str = "wav"
      ) -> Dict[str, Any]:
          """Synthesize speech and save to file"""
          try:
-             # Get audio URL
-             result = await self.synthesize_speech(text, voice, speed, pitch, volume)
-             audio_url = result["audio_url"]
+             # Get synthesis result
+             result = await self.synthesize_speech(text, voice, speed, pitch, format)
+             audio_data = result["audio_data"]

-             # Download and save audio
-             import aiohttp
-             import aiofiles
+             # Save audio data to file
+             with open(output_path, 'wb') as f:
+                 f.write(audio_data)

-             async with aiohttp.ClientSession() as session:
-                 async with session.get(audio_url) as response:
-                     response.raise_for_status()
-                     audio_data = await response.read()
-
-             async with aiofiles.open(output_path, 'wb') as f:
-                 await f.write(audio_data)
-
-             result["output_path"] = output_path
-             result["file_size"] = len(audio_data)
-
-             logger.info(f"Audio saved to: {output_path}")
-             return result
+             return {
+                 "file_path": output_path,
+                 "duration": result["duration"],
+                 "sample_rate": result["sample_rate"],
+                 "file_size": len(audio_data)
+             }

          except Exception as e:
              logger.error(f"Error saving audio to file: {e}")