isa-model 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/core/model_manager.py +69 -4
- isa_model/inference/ai_factory.py +335 -46
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +48 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/base_llm_service.py +88 -192
- isa_model/inference/services/llm/llm_adapter.py +459 -0
- isa_model/inference/services/llm/ollama_llm_service.py +111 -185
- isa_model/inference/services/llm/openai_llm_service.py +115 -360
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +11 -3
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +233 -205
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/METADATA +1 -1
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/RECORD +26 -21
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/WHEEL +0 -0
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,252 @@
|
|
1
|
+
import logging
|
2
|
+
import aiohttp
|
3
|
+
from typing import Dict, Any, List, Union, Optional, BinaryIO
|
4
|
+
from openai import AsyncOpenAI
|
5
|
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
6
|
+
|
7
|
+
from isa_model.inference.services.audio.base_stt_service import BaseSTTService
|
8
|
+
from isa_model.inference.providers.base_provider import BaseProvider
|
9
|
+
from isa_model.inference.billing_tracker import ServiceType
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
class OpenAISTTService(BaseSTTService):
    """
    OpenAI Speech-to-Text service using whisper-1 model.
    Supports transcription and translation to English.

    Audio may be supplied as a local file path, an http(s) URL, or an open
    binary file-like object. The whole payload is read into memory and
    checked against ``max_file_size`` before it is uploaded.
    """

    def __init__(self, provider: 'BaseProvider', model_name: str = "whisper-1"):
        """
        Build the service and its AsyncOpenAI client.

        Args:
            provider: Provider whose ``get_full_config()`` supplies ``api_key``
                and optionally ``base_url``, ``organization``, ``max_file_size``.
            model_name: Speech model identifier sent with every request.

        Raises:
            ValueError: If the API key is missing or the client cannot be created.
        """
        super().__init__(provider, model_name)

        # Get full configuration from provider (including sensitive data)
        provider_config = provider.get_full_config()

        # Initialize AsyncOpenAI client with provider configuration
        try:
            if not provider_config.get("api_key"):
                raise ValueError("OpenAI API key not found in provider configuration")

            self.client = AsyncOpenAI(
                api_key=provider_config["api_key"],
                base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
                organization=provider_config.get("organization")
            )

            logger.info(f"Initialized OpenAISTTService with model '{self.model_name}'")

        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {e}")
            raise ValueError(f"Failed to initialize OpenAI client. Check your API key configuration: {e}") from e

        # Model configurations
        self.max_file_size = provider_config.get('max_file_size', 25 * 1024 * 1024)  # 25MB
        self.supported_formats = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm']

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def _download_audio(self, audio_url: str) -> bytes:
        """Download audio from URL, raising ValueError on any non-200 status."""
        async with aiohttp.ClientSession() as session:
            async with session.get(audio_url) as response:
                if response.status == 200:
                    return await response.read()
                raise ValueError(f"Failed to download audio from {audio_url}: {response.status}")

    async def _load_audio(self, audio_file: Union[str, BinaryIO]) -> tuple:
        """
        Resolve ``audio_file`` into a ``(filename, audio_bytes)`` pair.

        The filename is reduced to its final path component so that the
        extension is visible to the API: a URL's query string/fragment is
        dropped and a local path does not leak its directory.
        """
        # Local imports: these modules are not part of the file's import block.
        import os
        from urllib.parse import urlparse

        if isinstance(audio_file, str):
            if audio_file.startswith(('http://', 'https://')):
                audio_data = await self._download_audio(audio_file)
                filename = os.path.basename(urlparse(audio_file).path) or 'audio.wav'
            else:
                with open(audio_file, 'rb') as f:
                    audio_data = f.read()
                filename = os.path.basename(audio_file) or 'audio.wav'
        else:
            audio_data = audio_file.read()
            # ``name`` can be missing (BytesIO) or an int (fd-backed files).
            raw_name = getattr(audio_file, 'name', None)
            if isinstance(raw_name, str) and raw_name:
                filename = os.path.basename(raw_name) or 'audio.wav'
            else:
                filename = 'audio.wav'

        return filename, audio_data

    def _check_file_size(self, audio_data: bytes) -> None:
        """Raise ValueError when the payload exceeds ``self.max_file_size``."""
        if len(audio_data) > self.max_file_size:
            raise ValueError(f"Audio file size ({len(audio_data)} bytes) exceeds maximum ({self.max_file_size} bytes)")

    @staticmethod
    def _usage_to_dict(usage: Any) -> Dict[str, Any]:
        """
        Normalise a response ``usage`` value to a plain dict.

        The value may be absent, a dict, or an SDK model object; calling
        ``.get`` on the latter would raise AttributeError.
        """
        if not usage:
            return {}
        if isinstance(usage, dict):
            return dict(usage)
        for dumper_name in ("model_dump", "dict"):
            dumper = getattr(usage, dumper_name, None)
            if callable(dumper):
                return dict(dumper())
        return dict(getattr(usage, "__dict__", {}))

    @staticmethod
    def _duration_minutes(response: Any) -> float:
        """Return the response's ``duration`` (seconds) in minutes, 0.0 if unusable."""
        raw_duration = getattr(response, 'duration', None)
        try:
            return float(raw_duration) / 60.0 if raw_duration else 0.0
        except (TypeError, ValueError):
            return 0.0

    async def transcribe(
        self,
        audio_file: Union[str, BinaryIO],
        language: Optional[str] = None,
        prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Transcribe audio file to text using whisper-1.

        Args:
            audio_file: Local path, http(s) URL, or binary file-like object.
            language: Optional language code forwarded to the API.
            prompt: Optional prompt text forwarded to the API.

        Returns:
            Dict with ``text``, ``language``, ``duration``, ``segments``,
            ``confidence`` (always None) and ``usage`` (plain dict).

        Raises:
            ValueError: If the audio exceeds ``max_file_size`` or a URL
                download fails. Other errors are logged and re-raised.
        """
        try:
            # Prepare the audio file
            filename, audio_data = await self._load_audio(audio_file)
            self._check_file_size(audio_data)

            # Prepare transcription parameters
            kwargs = {
                "model": self.model_name,
                "file": (filename, audio_data),
                "response_format": "verbose_json"
            }

            if language:
                kwargs["language"] = language
            if prompt:
                kwargs["prompt"] = prompt

            # Transcribe audio
            response = await self.client.audio.transcriptions.create(**kwargs)

            # Track usage for billing
            usage = self._usage_to_dict(getattr(response, 'usage', None))

            self._track_usage(
                service_type=ServiceType.AUDIO_STT,
                operation="transcribe",
                input_tokens=usage.get('input_tokens', 0) or 0,
                output_tokens=usage.get('output_tokens', 0) or 0,
                input_units=self._duration_minutes(response),  # Duration in minutes
                metadata={
                    "language": language,
                    "model": self.model_name,
                    "file_size": len(audio_data)
                }
            )

            # Format response
            return {
                "text": response.text,
                "language": getattr(response, 'language', language or 'unknown'),
                "duration": getattr(response, 'duration', None),
                "segments": getattr(response, 'segments', []),
                "confidence": None,  # whisper-1 doesn't provide confidence scores
                "usage": usage  # Include usage information
            }

        except Exception as e:
            logger.error(f"Error transcribing audio: {e}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def _request_translation(self, filename: str, audio_data: bytes) -> Any:
        """Send one translation request; only this network call is retried."""
        return await self.client.audio.translations.create(
            model=self.model_name,
            file=(filename, audio_data),
            response_format="verbose_json"
        )

    async def translate(
        self,
        audio_file: Union[str, BinaryIO]
    ) -> Dict[str, Any]:
        """
        Translate audio file to English text.

        The audio is read exactly once and only the API request is retried,
        so a retry never re-reads an already exhausted file object and a
        size-limit failure is reported immediately.

        Args:
            audio_file: Local path, http(s) URL, or binary file-like object.

        Returns:
            Dict with ``text``, ``detected_language``, ``duration``,
            ``segments`` and ``confidence`` (always None).

        Raises:
            ValueError: If the audio exceeds ``max_file_size`` or a URL
                download fails. Other errors are logged and re-raised.
        """
        try:
            # Prepare the audio file
            filename, audio_data = await self._load_audio(audio_file)
            self._check_file_size(audio_data)

            # Translate audio to English
            response = await self._request_translation(filename, audio_data)

            # Track usage for billing, mirroring transcribe()
            self._track_usage(
                service_type=ServiceType.AUDIO_STT,
                operation="translate",
                input_tokens=0,
                output_tokens=0,
                input_units=self._duration_minutes(response),  # Duration in minutes
                metadata={
                    "model": self.model_name,
                    "file_size": len(audio_data)
                }
            )

            # Format response
            return {
                "text": response.text,
                "detected_language": getattr(response, 'language', 'unknown'),
                "duration": getattr(response, 'duration', None),
                "segments": getattr(response, 'segments', []),
                "confidence": None  # Whisper doesn't provide confidence scores
            }

        except Exception as e:
            logger.error(f"Error translating audio: {e}")
            raise

    async def transcribe_batch(
        self,
        audio_files: List[Union[str, BinaryIO]],
        language: Optional[str] = None,
        prompt: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Transcribe multiple audio files one after another.

        A failure on one file does not abort the batch: its slot holds an
        empty result with an ``error`` message, so the output list always
        lines up with ``audio_files``.
        """
        results = []

        for audio_file in audio_files:
            try:
                result = await self.transcribe(audio_file, language, prompt)
                results.append(result)
            except Exception as e:
                logger.error(f"Error transcribing audio file: {e}")
                results.append({
                    "text": "",
                    "language": "unknown",
                    "duration": None,
                    "segments": [],
                    "confidence": None,
                    "error": str(e)
                })

        return results

    async def detect_language(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
        """
        Detect language of audio file by running a full transcription.

        ``confidence`` is a fixed placeholder of 1.0, not a measured score.
        """
        try:
            # Transcribe with language detection
            result = await self.transcribe(audio_file, language=None)

            return {
                "language": result["language"],
                "confidence": 1.0,  # Whisper is generally confident
                "alternatives": []  # Whisper doesn't provide alternatives
            }

        except Exception as e:
            logger.error(f"Error detecting language: {e}")
            raise

    def get_supported_formats(self) -> List[str]:
        """Get list of supported audio formats (a copy, safe to mutate)."""
        return self.supported_formats.copy()

    def get_supported_languages(self) -> List[str]:
        """Get list of supported language codes"""
        # Whisper supports 99+ languages
        return [
            'af', 'am', 'ar', 'as', 'az', 'ba', 'be', 'bg', 'bn', 'bo', 'br', 'bs', 'ca',
            'cs', 'cy', 'da', 'de', 'el', 'en', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr',
            'gl', 'gu', 'ha', 'haw', 'he', 'hi', 'hr', 'ht', 'hu', 'hy', 'id', 'is', 'it',
            'ja', 'jw', 'ka', 'kk', 'km', 'kn', 'ko', 'la', 'lb', 'ln', 'lo', 'lt', 'lv',
            'mg', 'mi', 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'my', 'ne', 'nl', 'nn', 'no',
            'oc', 'pa', 'pl', 'ps', 'pt', 'ro', 'ru', 'sa', 'sd', 'si', 'sk', 'sl', 'sn',
            'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', 'tk', 'tl', 'tr',
            'tt', 'uk', 'ur', 'uz', 'vi', 'yi', 'yo', 'zh'
        ]

    def get_max_file_size(self) -> int:
        """Get maximum file size in bytes"""
        return self.max_file_size

    async def close(self):
        """Cleanup resources by closing the underlying AsyncOpenAI client."""
        await self.client.close()
        logger.info("OpenAISTTService client has been closed.")
|
@@ -5,6 +5,7 @@ from openai import AsyncOpenAI
|
|
5
5
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
6
6
|
from isa_model.inference.services.audio.base_tts_service import BaseTTSService
|
7
7
|
from isa_model.inference.providers.base_provider import BaseProvider
|
8
|
+
from isa_model.inference.billing_tracker import ServiceType
|
8
9
|
import logging
|
9
10
|
|
10
11
|
logger = logging.getLogger(__name__)
|
@@ -14,12 +15,28 @@ class OpenAITTSService(BaseTTSService):
|
|
14
15
|
|
15
16
|
def __init__(self, provider: 'BaseProvider', model_name: str):
|
16
17
|
super().__init__(provider, model_name)
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
18
|
+
|
19
|
+
# Get full configuration from provider (including sensitive data)
|
20
|
+
provider_config = provider.get_full_config()
|
21
|
+
|
22
|
+
# Initialize AsyncOpenAI client with provider configuration
|
23
|
+
try:
|
24
|
+
if not provider_config.get("api_key"):
|
25
|
+
raise ValueError("OpenAI API key not found in provider configuration")
|
26
|
+
|
27
|
+
self._client = AsyncOpenAI(
|
28
|
+
api_key=provider_config["api_key"],
|
29
|
+
base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
|
30
|
+
organization=provider_config.get("organization")
|
31
|
+
)
|
32
|
+
|
33
|
+
logger.info(f"Initialized OpenAITTSService with model '{self.model_name}'")
|
34
|
+
|
35
|
+
except Exception as e:
|
36
|
+
logger.error(f"Failed to initialize OpenAI client: {e}")
|
37
|
+
raise ValueError(f"Failed to initialize OpenAI client. Check your API key configuration: {e}") from e
|
38
|
+
|
39
|
+
self.language = provider_config.get('language', None)
|
23
40
|
|
24
41
|
@property
|
25
42
|
def client(self) -> AsyncOpenAI:
|
@@ -83,18 +100,40 @@ class OpenAITTSService(BaseTTSService):
|
|
83
100
|
try:
|
84
101
|
response = await self._client.audio.speech.create(
|
85
102
|
model="tts-1",
|
86
|
-
voice=voice or "alloy",
|
103
|
+
voice=voice or "alloy", # type: ignore
|
87
104
|
input=text,
|
88
|
-
response_format=format,
|
105
|
+
response_format=format, # type: ignore
|
89
106
|
speed=speed
|
90
107
|
)
|
91
108
|
|
92
109
|
audio_data = response.content
|
93
110
|
|
111
|
+
# Estimate audio duration for billing (rough estimation: ~150 words per minute)
|
112
|
+
words = len(text.split())
|
113
|
+
estimated_duration_seconds = (words / 150.0) * 60.0 / speed
|
114
|
+
|
115
|
+
# Track usage for billing (OpenAI TTS is token-based: $15 per 1M characters)
|
116
|
+
self._track_usage(
|
117
|
+
service_type=ServiceType.AUDIO_TTS,
|
118
|
+
operation="synthesize_speech",
|
119
|
+
input_tokens=len(text), # Characters as input tokens
|
120
|
+
output_tokens=0,
|
121
|
+
input_units=len(text), # Text length
|
122
|
+
output_units=estimated_duration_seconds, # Audio duration in seconds
|
123
|
+
metadata={
|
124
|
+
"model": self.model_name,
|
125
|
+
"voice": voice or "alloy",
|
126
|
+
"speed": speed,
|
127
|
+
"format": format,
|
128
|
+
"text_length": len(text),
|
129
|
+
"estimated_duration_seconds": estimated_duration_seconds
|
130
|
+
}
|
131
|
+
)
|
132
|
+
|
94
133
|
return {
|
95
134
|
"audio_data": audio_data,
|
96
135
|
"format": format,
|
97
|
-
"duration":
|
136
|
+
"duration": estimated_duration_seconds,
|
98
137
|
"sample_rate": 24000 # Default for OpenAI TTS
|
99
138
|
}
|
100
139
|
|
@@ -0,0 +1,239 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Dict, Any, List, Optional, BinaryIO
|
3
|
+
import replicate
|
4
|
+
from tenacity import retry, stop_after_attempt, wait_exponential
|
5
|
+
|
6
|
+
from isa_model.inference.services.audio.base_tts_service import BaseTTSService
|
7
|
+
from isa_model.inference.providers.base_provider import BaseProvider
|
8
|
+
from isa_model.inference.billing_tracker import ServiceType
|
9
|
+
|
10
|
+
logger = logging.getLogger(__name__)
|
11
|
+
|
12
|
+
class ReplicateTTSService(BaseTTSService):
    """
    Replicate Text-to-Speech service using Kokoro model.
    High-quality voice synthesis with multiple voice options.

    Synthesis returns a URL to the generated audio; durations are estimates
    derived from the word count, not measured from the audio.
    """

    # Speed range accepted by synthesize_speech; values outside are clamped.
    _MIN_SPEED = 0.5
    _MAX_SPEED = 2.0

    # Voice metadata (you can expand this with more details)
    _VOICE_INFO: Dict[str, Dict[str, Any]] = {
        "af_bella": {"id": "af_bella", "name": "Bella", "gender": "female", "language": "en-US", "description": "Warm, friendly female voice", "sample_rate": 22050},
        "af_nicole": {"id": "af_nicole", "name": "Nicole", "gender": "female", "language": "en-US", "description": "Clear, professional female voice", "sample_rate": 22050},
        "af_sarah": {"id": "af_sarah", "name": "Sarah", "gender": "female", "language": "en-US", "description": "Gentle, expressive female voice", "sample_rate": 22050},
        "af_sky": {"id": "af_sky", "name": "Sky", "gender": "female", "language": "en-US", "description": "Bright, energetic female voice", "sample_rate": 22050},
        "am_adam": {"id": "am_adam", "name": "Adam", "gender": "male", "language": "en-US", "description": "Deep, authoritative male voice", "sample_rate": 22050},
        "am_michael": {"id": "am_michael", "name": "Michael", "gender": "male", "language": "en-US", "description": "Smooth, conversational male voice", "sample_rate": 22050}
    }

    def __init__(self, provider: 'BaseProvider', model_name: str = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"):
        """
        Configure the service from the provider's full configuration.

        Args:
            provider: Provider whose ``get_full_config()`` supplies
                ``api_token`` or ``replicate_api_token``.
            model_name: Replicate model reference passed to ``async_run``.

        Raises:
            ValueError: If no API token is configured.
        """
        super().__init__(provider, model_name)

        # Get full configuration from provider (including sensitive data)
        provider_config = provider.get_full_config()

        # Set up Replicate API token from provider configuration
        self.api_token = provider_config.get('api_token') or provider_config.get('replicate_api_token')
        if not self.api_token:
            raise ValueError("Replicate API token not found in provider configuration")

        # Set environment variable for replicate library
        # NOTE: this is process-wide; the last constructed service wins.
        import os
        os.environ['REPLICATE_API_TOKEN'] = self.api_token

        # Available voices for Kokoro model
        self.available_voices = [
            "af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael"
        ]

        # Default settings
        self.default_voice = "af_nicole"
        self.default_speed = 1.0

        logger.info(f"Initialized ReplicateTTSService with model '{self.model_name}'")

    @staticmethod
    def _extract_audio_url(output: Any) -> str:
        """
        Reduce the model output to a URL string.

        Handles a plain string, an object exposing ``url``, or a non-empty
        list of either; anything else falls back to ``str(output)``.
        """
        try:
            if isinstance(output, str):
                return output
            if hasattr(output, 'url'):
                # Handle FileOutput object
                return str(getattr(output, 'url', output))
            if isinstance(output, list) and len(output) > 0:
                first_output = output[0]
                if hasattr(first_output, 'url'):
                    return str(getattr(first_output, 'url', first_output))
                return str(first_output)
            # Convert to string as fallback
            return str(output)
        except Exception:
            # Safe fallback
            return str(output)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def synthesize_speech(
        self,
        text: str,
        voice: Optional[str] = None,
        speed: float = 1.0,
        pitch: Optional[float] = None,
        volume: Optional[float] = None
    ) -> Dict[str, Any]:
        """
        Synthesize speech from text using Kokoro model.

        Args:
            text: Text to speak.
            voice: Voice id; unknown or missing voices fall back to the default.
            speed: Requested speed, clamped to 0.5-2.0 before use.
            pitch: Accepted for interface compatibility; not sent to the model.
            volume: Accepted for interface compatibility; not sent to the model.

        Returns:
            Dict with ``audio_url``, ``text``, ``voice``, ``speed`` (as
            requested), ``duration_seconds`` (estimate) and ``metadata``.
        """
        try:
            # Validate and set voice
            selected_voice = voice or self.default_voice
            if selected_voice not in self.available_voices:
                logger.warning(f"Voice '{selected_voice}' not available, using default '{self.default_voice}'")
                selected_voice = self.default_voice

            # Clamp speed between 0.5 and 2.0. The clamped value is used for
            # everything below so that speed=0 cannot cause a division by zero
            # after the (paid) model run, which would also trigger a retry.
            effective_speed = max(self._MIN_SPEED, min(self._MAX_SPEED, speed))

            # Prepare input parameters
            input_params = {
                "text": text,
                "voice": selected_voice,
                "speed": effective_speed
            }

            logger.info(f"Synthesizing speech with voice '{selected_voice}' and speed {effective_speed}")

            # Run the model
            output = await replicate.async_run(self.model_name, input=input_params)

            # Handle different output formats
            audio_url = self._extract_audio_url(output)

            # Estimate audio duration for billing (rough estimation: ~150 words per minute)
            words = len(text.split())
            estimated_duration_seconds = (words / 150.0) * 60.0 / effective_speed

            # Track usage for billing
            self._track_usage(
                service_type=ServiceType.AUDIO_TTS,
                operation="synthesize_speech",
                input_tokens=0,
                output_tokens=0,
                input_units=len(text),  # Text length
                output_units=estimated_duration_seconds,  # Audio duration in seconds
                metadata={
                    "model": self.model_name,
                    "voice": selected_voice,
                    "speed": speed,
                    "text_length": len(text),
                    "estimated_duration_seconds": estimated_duration_seconds
                }
            )

            result = {
                "audio_url": audio_url,
                "text": text,
                "voice": selected_voice,
                "speed": speed,
                "duration_seconds": estimated_duration_seconds,
                "metadata": {
                    "model": self.model_name,
                    "provider": "replicate",
                    "voice_options": self.available_voices
                }
            }

            logger.info(f"Speech synthesis completed: {audio_url}")
            return result

        except Exception as e:
            logger.error(f"Error synthesizing speech: {e}")
            raise

    async def synthesize_speech_to_file(
        self,
        text: str,
        output_path: str,
        voice: Optional[str] = None,
        speed: float = 1.0,
        pitch: Optional[float] = None,
        volume: Optional[float] = None
    ) -> Dict[str, Any]:
        """
        Synthesize speech and save to file.

        Returns the ``synthesize_speech`` result extended with
        ``output_path`` and ``file_size`` (bytes written).
        """
        try:
            # Get audio URL
            result = await self.synthesize_speech(text, voice, speed, pitch, volume)
            audio_url = result["audio_url"]

            # Download and save audio
            import aiohttp
            import aiofiles

            async with aiohttp.ClientSession() as session:
                async with session.get(audio_url) as response:
                    response.raise_for_status()
                    audio_data = await response.read()

            async with aiofiles.open(output_path, 'wb') as f:
                await f.write(audio_data)

            result["output_path"] = output_path
            result["file_size"] = len(audio_data)

            logger.info(f"Audio saved to: {output_path}")
            return result

        except Exception as e:
            logger.error(f"Error saving audio to file: {e}")
            raise

    async def synthesize_speech_batch(
        self,
        texts: List[str],
        voice: Optional[str] = None,
        speed: float = 1.0,
        pitch: float = 1.0,
        format: str = "wav"
    ) -> List[Dict[str, Any]]:
        """
        Synthesize multiple texts one after another.

        ``pitch`` and ``format`` are accepted but not forwarded. A failed
        item yields a dict with ``audio_url`` None and an ``error`` message,
        so the output list always lines up with ``texts``.
        """
        results = []

        for text in texts:
            try:
                result = await self.synthesize_speech(text, voice, speed)
                results.append(result)
            except Exception as e:
                logger.error(f"Error synthesizing text '{text[:50]}...': {e}")
                results.append({
                    "audio_url": None,
                    "text": text,
                    "voice": voice or self.default_voice,
                    "speed": speed,
                    "error": str(e)
                })

        return results

    def get_available_voices(self) -> List[Dict[str, Any]]:
        """Get list of available voices as summary dicts."""
        voices = []
        for voice in self.available_voices:
            voice_info = self.get_voice_info(voice)
            voices.append({
                "id": voice,
                "name": voice.replace("_", " ").title(),
                "language": "en-US",
                "gender": voice_info.get("gender", "unknown"),
                "age": "adult"
            })
        return voices

    def get_supported_formats(self) -> List[str]:
        """Get list of supported audio formats"""
        return ["wav", "mp3"]  # Kokoro typically outputs WAV

    def get_voice_info(self, voice_id: str) -> Dict[str, Any]:
        """
        Get information about a specific voice.

        Returns an ``{"error": ...}`` dict for voices not in
        ``available_voices``, and generic metadata for an available voice
        that has no entry in the metadata table.
        """
        if voice_id not in self.available_voices:
            return {"error": f"Voice '{voice_id}' not available"}

        known = self._VOICE_INFO.get(voice_id)
        if known is not None:
            # Copy so callers cannot mutate the shared class-level table.
            return dict(known)

        return {"id": voice_id, "gender": "unknown", "language": "en-US", "description": "Voice information not available", "sample_rate": 22050}

    async def close(self):
        """Cleanup resources (nothing to release; logs only)."""
        logger.info("ReplicateTTSService resources cleaned up")
|
@@ -1,6 +1,7 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
2
|
from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
|
3
3
|
from isa_model.inference.providers.base_provider import BaseProvider
|
4
|
+
from isa_model.inference.billing_tracker import track_usage, ServiceType, Provider
|
4
5
|
|
5
6
|
T = TypeVar('T') # Generic type for responses
|
6
7
|
|
@@ -10,7 +11,41 @@ class BaseService(ABC):
|
|
10
11
|
def __init__(self, provider: 'BaseProvider', model_name: str):
|
11
12
|
self.provider = provider
|
12
13
|
self.model_name = model_name
|
13
|
-
self.config = provider.
|
14
|
+
self.config = provider.get_full_config()
|
15
|
+
|
16
|
+
def _track_usage(
    self,
    service_type: Union[str, ServiceType],
    operation: str,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
    input_units: Optional[float] = None,
    output_units: Optional[float] = None,
    metadata: Optional[Dict[str, Any]] = None
):
    """
    Record one usage event with the billing tracker.

    The provider label is the first truthy value among the provider's
    ``name`` attribute, its ``provider_name`` attribute, and its lowercased
    class name with 'provider' removed; 'unknown' is the last resort.

    Any exception raised while recording is logged as a warning and
    suppressed, so billing problems never interrupt the calling service.
    """
    try:
        # Resolve the provider label through successive fallbacks.
        provider_label = getattr(self.provider, 'name', None)
        if not provider_label:
            provider_label = getattr(self.provider, 'provider_name', None)
        if not provider_label:
            provider_cls = getattr(self.provider, '__class__', type(None))
            provider_label = provider_cls.__name__.lower().replace('provider', '')
        if not provider_label:
            provider_label = 'unknown'

        track_usage(
            provider=provider_label,
            service_type=service_type,
            model_name=self.model_name,
            operation=operation,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_units=input_units,
            output_units=output_units,
            metadata=metadata
        )
    except Exception as e:
        # Billing is best-effort: report the failure and carry on.
        import logging
        billing_log = logging.getLogger(__name__)
        billing_log.warning(f"Failed to track usage: {e}")
|
14
49
|
|
15
50
|
def __await__(self):
|
16
51
|
"""Make the service awaitable"""
|