isa-model 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- isa_model/__init__.py +30 -1
- isa_model/client.py +770 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/models/model_repo.py +343 -0
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/__init__.py +9 -0
- isa_model/deployment/cloud/modal/__init__.py +10 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +766 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +406 -0
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +187 -387
- isa_model/inference/providers/modal_provider.py +109 -0
- isa_model/inference/providers/yyds_provider.py +108 -0
- isa_model/inference/services/__init__.py +2 -1
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -55
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -35
- isa_model/inference/services/img/flux_professional_service.py +603 -0
- isa_model/inference/services/img/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +210 -69
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +519 -35
- isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +40 -0
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +150 -15
- isa_model/inference/services/llm/openai_llm_service.py +134 -31
- isa_model/inference/services/llm/yyds_llm_service.py +255 -0
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +241 -96
- isa_model/inference/services/vision/disabled/isA_vision_service.py +500 -0
- isa_model/inference/services/vision/doc_analysis_service.py +640 -0
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +109 -170
- isa_model/inference/services/vision/replicate_vision_service.py +508 -0
- isa_model/inference/services/vision/ui_analysis_service.py +823 -0
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/__init__.py +19 -0
- isa_model/serving/api/__init__.py +10 -0
- isa_model/serving/api/fastapi_server.py +89 -0
- isa_model/serving/api/middleware/__init__.py +9 -0
- isa_model/serving/api/middleware/request_logger.py +88 -0
- isa_model/serving/api/routes/__init__.py +5 -0
- isa_model/serving/api/routes/health.py +82 -0
- isa_model/serving/api/routes/llm.py +19 -0
- isa_model/serving/api/routes/ui_analysis.py +223 -0
- isa_model/serving/api/routes/unified.py +202 -0
- isa_model/serving/api/routes/vision.py +19 -0
- isa_model/serving/api/schemas/__init__.py +17 -0
- isa_model/serving/api/schemas/common.py +33 -0
- isa_model/serving/api/schemas/ui_analysis.py +78 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
- isa_model-0.3.6.dist-info/RECORD +147 -0
- isa_model/core/model_manager.py +0 -208
- isa_model/core/model_registry.py +0 -342
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- isa_model-0.3.4.dist-info/RECORD +0 -91
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
--- isa_model/inference/services/audio/openai_stt_service.py (0.3.4)
+++ isa_model/inference/services/audio/openai_stt_service.py (0.3.6)
@@ -5,8 +5,6 @@ from openai import AsyncOpenAI
 from tenacity import retry, stop_after_attempt, wait_exponential
 
 from isa_model.inference.services.audio.base_stt_service import BaseSTTService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import ServiceType
 
 logger = logging.getLogger(__name__)
 
@@ -14,22 +12,22 @@ class OpenAISTTService(BaseSTTService):
     """
     OpenAI Speech-to-Text service using whisper-1 model.
     Supports transcription and translation to English.
+    Uses the new unified architecture with centralized config management.
     """
 
-    def __init__(self, …
-        super().__init__(…
+    def __init__(self, provider_name: str, model_name: str = "whisper-1", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Get …
-        provider_config = …
+        # Get provider configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Initialize AsyncOpenAI client with provider configuration
         try:
-            …
-            raise ValueError("OpenAI API key not found in provider configuration")
+            api_key = self.get_api_key()
 
             self.client = AsyncOpenAI(
-                api_key=…
-                base_url=provider_config.get("…
+                api_key=api_key,
+                base_url=provider_config.get("api_base_url", "https://api.openai.com/v1"),
                 organization=provider_config.get("organization")
             )
 
@@ -48,205 +46,245 @@ class OpenAISTTService(BaseSTTService):
         wait=wait_exponential(multiplier=1, min=4, max=10),
         reraise=True
     )
-    async def …
-        """…
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-    async def transcribe(
-        self,
-        audio_file: Union[str, BinaryIO],
-        language: Optional[str] = None,
-        prompt: Optional[str] = None
-    ) -> Dict[str, Any]:
-        """Transcribe audio file to text using whisper-1"""
-        try:
-            # Prepare the audio file
-            if isinstance(audio_file, str):
-                if audio_file.startswith(('http://', 'https://')):
-                    # Download audio from URL
-                    audio_data = await self._download_audio(audio_file)
-                    filename = audio_file.split('/')[-1] or 'audio.wav'
-                else:
-                    # Local file path
-                    with open(audio_file, 'rb') as f:
-                        audio_data = f.read()
-                    filename = audio_file
-            else:
-                audio_data = audio_file.read()
-                filename = getattr(audio_file, 'name', 'audio.wav')
-
-            # Check file size
-            if len(audio_data) > self.max_file_size:
-                raise ValueError(f"Audio file size ({len(audio_data)} bytes) exceeds maximum ({self.max_file_size} bytes)")
+    async def transcribe(self, audio_file: Union[str, BinaryIO], language: Optional[str] = None, prompt: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Transcribe audio file to text using OpenAI's Whisper model.
+
+        Args:
+            audio_file: Path to audio file or file-like object
+            language: Optional language code for better accuracy
+            **kwargs: Additional parameters for the transcription API
 
-            …
-            …
+        Returns:
+            Dict containing transcription result and metadata
+        """
+        try:
+            # Prepare request parameters
+            transcription_params = {
                 "model": self.model_name,
-                "file": (filename, audio_data),
                 "response_format": "verbose_json"
             }
 
             if language:
-                …
-            if prompt:
-                kwargs["prompt"] = prompt
+                transcription_params["language"] = language
 
-            # …
-            …
+            # Add optional parameters
+            if prompt:
+                transcription_params["prompt"] = prompt
 
-            # …
-            …
-            …
-            …
+            # Handle file input
+            if isinstance(audio_file, str):
+                with open(audio_file, "rb") as f:
+                    transcription = await self.client.audio.transcriptions.create(
+                        file=f,
+                        **transcription_params
+                    )
+            else:
+                transcription = await self.client.audio.transcriptions.create(
+                    file=audio_file,
+                    **transcription_params
+                )
 
-            # …
-            …
+            # Extract usage information for billing
+            result = {
+                "text": transcription.text,
+                "language": getattr(transcription, 'language', language),
+                "duration": getattr(transcription, 'duration', None),
+                "segments": getattr(transcription, 'segments', []),
+                "usage": {
+                    "input_units": getattr(transcription, 'duration', 1),  # Duration in seconds
+                    "output_tokens": len(transcription.text.split()) if transcription.text else 0
+                }
+            }
 
-            …
-            …
+            # Track usage for billing
+            await self._track_usage(
+                service_type="audio_stt",
                 operation="transcribe",
-                …
-                output_tokens=output_tokens,
-                input_units=duration_minutes,  # Duration in minutes
+                input_units=result["usage"]["input_units"],
+                output_tokens=result["usage"]["output_tokens"],
                 metadata={
-                    "language": language,
-                    "…
-                    "…
+                    "language": result.get("language"),
+                    "model_name": self.model_name,
+                    "provider": self.provider_name
                 }
             )
 
-            # Format response
-            result = {
-                "text": response.text,
-                "language": getattr(response, 'language', language or 'unknown'),
-                "duration": getattr(response, 'duration', None),
-                "segments": getattr(response, 'segments', []),
-                "confidence": None,  # whisper-1 doesn't provide confidence scores
-                "usage": usage  # Include usage information
-            }
-
             return result
 
         except Exception as e:
-            logger.error(f"…
+            logger.error(f"Transcription failed: {e}")
             raise
-
+
     @retry(
         stop=stop_after_attempt(3),
         wait=wait_exponential(multiplier=1, min=4, max=10),
         reraise=True
     )
-    async def translate(
-        …
-        …
-        …
-        …
+    async def translate(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
+        """
+        Translate audio file to English text using OpenAI's Whisper model.
+
+        Args:
+            audio_file: Path to audio file or file-like object
+            **kwargs: Additional parameters for the translation API
+
+        Returns:
+            Dict containing translation result and metadata
+        """
         try:
-            # Prepare …
-            …
-            …
-            …
-            …
-            else:
-                audio_data = audio_file.read()
-                filename = getattr(audio_file, 'name', 'audio.wav')
+            # Prepare request parameters
+            translation_params = {
+                "model": self.model_name,
+                "response_format": "verbose_json"
+            }
 
-            # …
-            if len(audio_data) > self.max_file_size:
-                raise ValueError(f"Audio file size ({len(audio_data)} bytes) exceeds maximum ({self.max_file_size} bytes)")
+            # No additional parameters for translation
 
-            # …
-            …
-            …
-            …
-            …
-            …
+            # Handle file input
+            if isinstance(audio_file, str):
+                with open(audio_file, "rb") as f:
+                    translation = await self.client.audio.translations.create(
+                        file=f,
+                        **translation_params
+                    )
+            else:
+                translation = await self.client.audio.translations.create(
+                    file=audio_file,
+                    **translation_params
+                )
 
-            # …
+            # Extract usage information for billing
             result = {
-                "text": …
-                "…
-                "duration": getattr(…
-                "segments": getattr(…
-                "…
+                "text": translation.text,
+                "language": "en",  # Translation is always to English
+                "duration": getattr(translation, 'duration', None),
+                "segments": getattr(translation, 'segments', []),
+                "usage": {
+                    "input_units": getattr(translation, 'duration', 1),  # Duration in seconds
+                    "output_tokens": len(translation.text.split()) if translation.text else 0
+                }
             }
 
+            # Track usage for billing
+            await self._track_usage(
+                service_type="audio_stt",
+                operation="translate",
+                input_units=result["usage"]["input_units"],
+                output_tokens=result["usage"]["output_tokens"],
+                metadata={
+                    "target_language": "en",
+                    "model_name": self.model_name,
+                    "provider": self.provider_name
+                }
+            )
+
             return result
 
         except Exception as e:
-            logger.error(f"…
+            logger.error(f"Translation failed: {e}")
             raise
-
-    async def transcribe_batch(
-        …
-        …
-        language: Optional[str] = None,
-        prompt: Optional[str] = None
-    ) -> List[Dict[str, Any]]:
-        """Transcribe multiple audio files"""
-        results = []
+
+    async def transcribe_batch(self, audio_files: List[Union[str, BinaryIO]], language: Optional[str] = None, prompt: Optional[str] = None) -> List[Dict[str, Any]]:
+        """
+        Transcribe multiple audio files in batch.
 
+        Args:
+            audio_files: List of audio file paths or file-like objects
+            language: Optional language code for better accuracy
+            **kwargs: Additional parameters for the transcription API
+
+        Returns:
+            List of transcription results
+        """
+        results = []
        for audio_file in audio_files:
             try:
                 result = await self.transcribe(audio_file, language, prompt)
                 results.append(result)
             except Exception as e:
-                logger.error(f"…
+                logger.error(f"Failed to transcribe {audio_file}: {e}")
                 results.append({
-                    "…
-                    "…
-                    "…
-                    "segments": [],
-                    "confidence": None,
-                    "error": str(e)
+                    "error": str(e),
+                    "file": str(audio_file),
+                    "text": None
                 })
 
         return results
-
+
     async def detect_language(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
-        """…
+        """
+        Detect the language of an audio file.
+
+        Args:
+            audio_file: Path to audio file or file-like object
+            **kwargs: Additional parameters
+
+        Returns:
+            Dict containing detected language and confidence
+        """
         try:
-            # …
-            …
+            # Use transcription with language detection - need to access client directly
+            transcription = await self.client.audio.transcriptions.create(
+                file=audio_file if not isinstance(audio_file, str) else open(audio_file, "rb"),
+                model=self.model_name,
+                response_format="verbose_json"
+            )
+
+            result = {
+                "text": transcription.text,
+                "language": getattr(transcription, 'language', "unknown")
+            }
 
             return {
-                "language": result…
-                "confidence": 1.0,  # …
-                "…
+                "language": result.get("language", "unknown"),
+                "confidence": 1.0,  # OpenAI doesn't provide confidence scores
+                "text_sample": result.get("text", "")[:100] if result.get("text") else ""
             }
 
         except Exception as e:
-            logger.error(f"…
-            …
-            …
+            logger.error(f"Language detection failed: {e}")
+            return {
+                "language": "unknown",
+                "confidence": 0.0,
+                "error": str(e)
+            }
+
     def get_supported_formats(self) -> List[str]:
-        """…
-        …
+        """
+        Get list of supported audio formats.
+
+        Returns:
+            List of supported file extensions
+        """
+        return self.supported_formats
 
     def get_supported_languages(self) -> List[str]:
-        """…
-        …
+        """
+        Get list of supported language codes for OpenAI Whisper.
+
+        Returns:
+            List of supported language codes
+        """
         return [
-            'af', '…
-            '…
-            '…
-            '…
-            '…
-            'oc', 'pa', 'pl', 'ps', 'pt', 'ro', 'ru', 'sa', 'sd', 'si', 'sk', 'sl', 'sn',
-            'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', 'tk', 'tl', 'tr',
-            'tt', 'uk', 'ur', 'uz', 'vi', 'yi', 'yo', 'zh'
+            'af', 'ar', 'hy', 'az', 'be', 'bs', 'bg', 'ca', 'zh', 'hr', 'cs', 'da',
+            'nl', 'en', 'et', 'fi', 'fr', 'gl', 'de', 'el', 'he', 'hi', 'hu', 'is',
+            'id', 'it', 'ja', 'kn', 'kk', 'ko', 'lv', 'lt', 'mk', 'ms', 'mr', 'mi',
+            'ne', 'no', 'fa', 'pl', 'pt', 'ro', 'ru', 'sr', 'sk', 'sl', 'es', 'sw',
+            'sv', 'tl', 'ta', 'th', 'tr', 'uk', 'ur', 'vi', 'cy'
         ]
-
+
     def get_max_file_size(self) -> int:
-        """…
+        """
+        Get maximum file size limit in bytes.
+
+        Returns:
+            Maximum file size in bytes
+        """
         return self.max_file_size
-
+
     async def close(self):
         """Cleanup resources"""
-        …
-        …
+        if hasattr(self.client, 'close'):
+            await self.client.close()
+        logger.info("OpenAI STT service closed")
--- isa_model/inference/services/audio/openai_tts_service.py (0.3.4)
+++ isa_model/inference/services/audio/openai_tts_service.py (0.3.6)
@@ -4,20 +4,18 @@ import os
 from openai import AsyncOpenAI
 from tenacity import retry, stop_after_attempt, wait_exponential
 from isa_model.inference.services.audio.base_tts_service import BaseTTSService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import ServiceType
 import logging
 
 logger = logging.getLogger(__name__)
 
 class OpenAITTSService(BaseTTSService):
-    """…
+    """OpenAI TTS service with unified architecture"""
 
-    def __init__(self, …
-        super().__init__(…
+    def __init__(self, provider_name: str, model_name: str = "tts-1", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Get …
-        provider_config = …
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Initialize AsyncOpenAI client with provider configuration
         try:
@@ -113,8 +111,8 @@ class OpenAITTSService(BaseTTSService):
         estimated_duration_seconds = (words / 150.0) * 60.0 / speed
 
         # Track usage for billing (OpenAI TTS is token-based: $15 per 1M characters)
-        self._track_usage(
-            service_type=…
+        await self._track_usage(
+            service_type="audio_tts",
             operation="synthesize_speech",
             input_tokens=len(text),  # Characters as input tokens
             output_tokens=0,
@@ -130,8 +128,12 @@ class OpenAITTSService(BaseTTSService):
             }
         )
 
+        # For HTTP API compatibility, encode audio data as base64
+        import base64
+        audio_base64 = base64.b64encode(audio_data).decode('utf-8')
+
         return {
-            "…
+            "audio_data_base64": audio_base64,  # Base64 encoded for JSON compatibility
             "format": format,
             "duration": estimated_duration_seconds,
             "sample_rate": 24000  # Default for OpenAI TTS
--- isa_model/inference/services/audio/replicate_tts_service.py (0.3.4)
+++ isa_model/inference/services/audio/replicate_tts_service.py (0.3.6)
@@ -4,42 +4,45 @@ import replicate
 from tenacity import retry, stop_after_attempt, wait_exponential
 
 from isa_model.inference.services.audio.base_tts_service import BaseTTSService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import ServiceType
 
 logger = logging.getLogger(__name__)
 
 class ReplicateTTSService(BaseTTSService):
     """
-    Replicate Text-to-Speech service using Kokoro model.
+    Replicate Text-to-Speech service using Kokoro model with unified architecture.
     High-quality voice synthesis with multiple voice options.
     """
 
-    def __init__(self, …
-        super().__init__(…
+    def __init__(self, provider_name: str, model_name: str = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Get …
-        provider_config = …
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Set up Replicate API token from provider configuration
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        try:
+            self.api_token = provider_config.get('api_key') or provider_config.get('replicate_api_token')
+            if not self.api_token:
+                raise ValueError("Replicate API token not found in provider configuration")
+
+            # Set environment variable for replicate library
+            import os
+            os.environ['REPLICATE_API_TOKEN'] = self.api_token
+
+            # Available voices for Kokoro model
+            self.available_voices = [
+                "af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael"
+            ]
+
+            # Default settings
+            self.default_voice = "af_nicole"
+            self.default_speed = 1.0
+
+            logger.info(f"Initialized ReplicateTTSService with model '{self.model_name}'")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Replicate client: {e}")
+            raise ValueError(f"Failed to initialize Replicate client: {e}") from e
 
     @retry(
         stop=stop_after_attempt(3),
@@ -51,8 +54,8 @@ class ReplicateTTSService(BaseTTSService):
         text: str,
         voice: Optional[str] = None,
         speed: float = 1.0,
-        pitch: …
-        …
+        pitch: float = 1.0,
+        format: str = "wav"
     ) -> Dict[str, Any]:
         """Synthesize speech from text using Kokoro model"""
         try:
@@ -99,8 +102,8 @@ class ReplicateTTSService(BaseTTSService):
             estimated_duration_seconds = (words / 150.0) * 60.0 / speed
 
             # Track usage for billing
-            self._track_usage(
-                service_type=…
+            await self._track_usage(
+                service_type="audio_tts",
                 operation="synthesize_speech",
                 input_tokens=0,
                 output_tokens=0,
@@ -115,15 +118,24 @@ class ReplicateTTSService(BaseTTSService):
                 }
             )
 
+            # Download audio data for return format consistency
+            import aiohttp
+            async with aiohttp.ClientSession() as session:
+                async with session.get(audio_url) as response:
+                    response.raise_for_status()
+                    audio_data = await response.read()
+
             result = {
-                "…
-                "…
-                "…
-                "…
-                "…
+                "audio_data": audio_data,
+                "format": "wav",  # Kokoro typically outputs WAV
+                "duration": estimated_duration_seconds,
+                "sample_rate": 22050,
+                "audio_url": audio_url,  # Keep URL for reference
                 "metadata": {
                     "model": self.model_name,
                     "provider": "replicate",
+                    "voice": selected_voice,
+                    "speed": speed,
                     "voice_options": self.available_voices
                 }
             }
@@ -137,36 +149,29 @@ class ReplicateTTSService(BaseTTSService):
 
     async def synthesize_speech_to_file(
         self,
-        text: str,
+        text: str, 
         output_path: str,
         voice: Optional[str] = None,
         speed: float = 1.0,
-        pitch: …
-        …
+        pitch: float = 1.0,
+        format: str = "wav"
     ) -> Dict[str, Any]:
         """Synthesize speech and save to file"""
         try:
-            # Get …
-            result = await self.synthesize_speech(text, voice, speed, pitch, …
-            …
+            # Get synthesis result
+            result = await self.synthesize_speech(text, voice, speed, pitch, format)
+            audio_data = result["audio_data"]
 
-            # …
-            …
-            …
+            # Save audio data to file
+            with open(output_path, 'wb') as f:
+                f.write(audio_data)
 
-            …
-            …
-            …
-            …
-            …
-            …
-                await f.write(audio_data)
-
-            result["output_path"] = output_path
-            result["file_size"] = len(audio_data)
-
-            logger.info(f"Audio saved to: {output_path}")
-            return result
+            return {
+                "file_path": output_path,
+                "duration": result["duration"],
+                "sample_rate": result["sample_rate"],
+                "file_size": len(audio_data)
+            }
 
         except Exception as e:
             logger.error(f"Error saving audio to file: {e}")