isa-model 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26) hide show
  1. isa_model/core/model_manager.py +69 -4
  2. isa_model/inference/ai_factory.py +335 -46
  3. isa_model/inference/billing_tracker.py +406 -0
  4. isa_model/inference/providers/base_provider.py +51 -4
  5. isa_model/inference/providers/ollama_provider.py +37 -18
  6. isa_model/inference/providers/openai_provider.py +65 -36
  7. isa_model/inference/providers/replicate_provider.py +42 -30
  8. isa_model/inference/services/audio/base_stt_service.py +21 -2
  9. isa_model/inference/services/audio/openai_realtime_service.py +353 -0
  10. isa_model/inference/services/audio/openai_stt_service.py +252 -0
  11. isa_model/inference/services/audio/openai_tts_service.py +48 -9
  12. isa_model/inference/services/audio/replicate_tts_service.py +239 -0
  13. isa_model/inference/services/base_service.py +36 -1
  14. isa_model/inference/services/embedding/openai_embed_service.py +223 -0
  15. isa_model/inference/services/llm/base_llm_service.py +88 -192
  16. isa_model/inference/services/llm/llm_adapter.py +459 -0
  17. isa_model/inference/services/llm/ollama_llm_service.py +111 -185
  18. isa_model/inference/services/llm/openai_llm_service.py +115 -360
  19. isa_model/inference/services/vision/helpers/image_utils.py +4 -3
  20. isa_model/inference/services/vision/ollama_vision_service.py +11 -3
  21. isa_model/inference/services/vision/openai_vision_service.py +275 -41
  22. isa_model/inference/services/vision/replicate_image_gen_service.py +233 -205
  23. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/METADATA +1 -1
  24. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/RECORD +26 -21
  25. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/WHEEL +0 -0
  26. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,252 @@
1
+ import logging
2
+ import aiohttp
3
+ from typing import Dict, Any, List, Union, Optional, BinaryIO
4
+ from openai import AsyncOpenAI
5
+ from tenacity import retry, stop_after_attempt, wait_exponential
6
+
7
+ from isa_model.inference.services.audio.base_stt_service import BaseSTTService
8
+ from isa_model.inference.providers.base_provider import BaseProvider
9
+ from isa_model.inference.billing_tracker import ServiceType
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
class OpenAISTTService(BaseSTTService):
    """
    OpenAI Speech-to-Text service using whisper-1 model.
    Supports transcription and translation to English.
    """

    def __init__(self, provider: 'BaseProvider', model_name: str = "whisper-1"):
        """Create the service and its AsyncOpenAI client.

        Args:
            provider: Provider whose ``get_full_config()`` supplies ``api_key``
                and optionally ``base_url``, ``organization`` and
                ``max_file_size``.
            model_name: Name of the speech model to call.

        Raises:
            ValueError: If the API key is missing or the client cannot be
                constructed.
        """
        super().__init__(provider, model_name)

        # Get full configuration from provider (including sensitive data)
        provider_config = provider.get_full_config()

        # Initialize AsyncOpenAI client with provider configuration
        try:
            if not provider_config.get("api_key"):
                raise ValueError("OpenAI API key not found in provider configuration")

            self.client = AsyncOpenAI(
                api_key=provider_config["api_key"],
                base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
                organization=provider_config.get("organization")
            )

            logger.info(f"Initialized OpenAISTTService with model '{self.model_name}'")

        except Exception as e:
            logger.error(f"Failed to initialize OpenAI client: {e}")
            raise ValueError(f"Failed to initialize OpenAI client. Check your API key configuration: {e}") from e

        # Model configurations
        self.max_file_size = provider_config.get('max_file_size', 25 * 1024 * 1024)  # 25MB
        self.supported_formats = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm']

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def _download_audio(self, audio_url: str) -> bytes:
        """Download audio from URL (retried up to three times).

        Raises:
            ValueError: If the server answers with a non-200 status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.get(audio_url) as response:
                if response.status == 200:
                    return await response.read()
                else:
                    raise ValueError(f"Failed to download audio from {audio_url}: {response.status}")

    async def _load_audio(self, audio_file: Union[str, BinaryIO]) -> tuple:
        """Read the audio payload once and return ``(audio_data, filename)``.

        Accepts an http(s) URL, a local file path, or a readable binary
        object. The size limit is enforced here so that every caller shares
        the same check.

        Raises:
            ValueError: If the payload is larger than ``self.max_file_size``.
        """
        import os
        from urllib.parse import urlparse

        if isinstance(audio_file, str):
            if audio_file.startswith(('http://', 'https://')):
                audio_data = await self._download_audio(audio_file)
                # Use only the URL path so a query string or fragment does not
                # end up in the upload filename and hide the file extension.
                filename = os.path.basename(urlparse(audio_file).path) or 'audio.wav'
            else:
                with open(audio_file, 'rb') as f:
                    audio_data = f.read()
                filename = os.path.basename(audio_file) or audio_file
        else:
            audio_data = audio_file.read()
            filename = getattr(audio_file, 'name', 'audio.wav')

        if len(audio_data) > self.max_file_size:
            raise ValueError(f"Audio file size ({len(audio_data)} bytes) exceeds maximum ({self.max_file_size} bytes)")

        return audio_data, filename

    @staticmethod
    def _usage_value(usage: Any, key: str) -> int:
        """Return ``usage[key]`` or ``usage.key``, defaulting to 0.

        The usage payload may be absent, a plain dict, or an object exposing
        attributes, so both access styles are supported.
        """
        if usage is None:
            return 0
        if isinstance(usage, dict):
            value = usage.get(key, 0)
        else:
            value = getattr(usage, key, 0)
        return value or 0

    @staticmethod
    def _duration_minutes(response: Any) -> float:
        """Return the response's ``duration`` converted from seconds to minutes."""
        duration = getattr(response, 'duration', 0)
        if not duration:
            return 0
        try:
            return float(duration) / 60.0
        except (TypeError, ValueError):
            # A malformed duration must not break the request; bill zero.
            return 0

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def _request_translation(self, filename: str, audio_data: bytes) -> Any:
        """Call the translations endpoint; only this network call is retried."""
        return await self.client.audio.translations.create(
            model=self.model_name,
            file=(filename, audio_data),
            response_format="verbose_json"
        )

    async def transcribe(
        self,
        audio_file: Union[str, BinaryIO],
        language: Optional[str] = None,
        prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """Transcribe audio file to text using whisper-1

        Args:
            audio_file: http(s) URL, local path, or readable binary object.
            language: Optional language code passed to the API.
            prompt: Optional prompt passed to the API.

        Returns:
            Dict with ``text``, ``language``, ``duration``, ``segments``,
            ``confidence`` (always ``None``) and ``usage``.

        Raises:
            ValueError: If the audio exceeds the size limit or a download fails.
            Exception: Any error from reading the input or from the API is
                logged and re-raised.
        """
        try:
            audio_data, filename = await self._load_audio(audio_file)

            # Prepare transcription parameters
            kwargs = {
                "model": self.model_name,
                "file": (filename, audio_data),
                "response_format": "verbose_json"
            }

            if language:
                kwargs["language"] = language
            if prompt:
                kwargs["prompt"] = prompt

            # Transcribe audio
            response = await self.client.audio.transcriptions.create(**kwargs)

            # Track usage for billing; usage may be a dict, an object or missing
            usage = getattr(response, 'usage', {})
            input_tokens = self._usage_value(usage, 'input_tokens')
            output_tokens = self._usage_value(usage, 'output_tokens')

            # For audio, also track duration in minutes
            duration_minutes = self._duration_minutes(response)

            self._track_usage(
                service_type=ServiceType.AUDIO_STT,
                operation="transcribe",
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                input_units=duration_minutes,  # Duration in minutes
                metadata={
                    "language": language,
                    "model": self.model_name,
                    "file_size": len(audio_data)
                }
            )

            # Format response
            result = {
                "text": response.text,
                "language": getattr(response, 'language', language or 'unknown'),
                "duration": getattr(response, 'duration', None),
                "segments": getattr(response, 'segments', []),
                "confidence": None,  # whisper-1 doesn't provide confidence scores
                "usage": usage  # Include usage information
            }

            return result

        except Exception as e:
            logger.error(f"Error transcribing audio: {e}")
            raise

    async def translate(
        self,
        audio_file: Union[str, BinaryIO]
    ) -> Dict[str, Any]:
        """Translate audio file to English text

        The input is read exactly once before the API call; only the API call
        itself is retried, so a file-like input is never re-read after it has
        been consumed and non-transient input errors fail immediately.

        Args:
            audio_file: http(s) URL, local path, or readable binary object.

        Returns:
            Dict with ``text``, ``detected_language``, ``duration``,
            ``segments`` and ``confidence`` (always ``None``).

        Raises:
            ValueError: If the audio exceeds the size limit or a download fails.
            Exception: Any error from reading the input or from the API is
                logged and re-raised.
        """
        try:
            audio_data, filename = await self._load_audio(audio_file)

            # Translate audio to English
            response = await self._request_translation(filename, audio_data)

            # Record usage the same way transcribe() does
            usage = getattr(response, 'usage', {})
            self._track_usage(
                service_type=ServiceType.AUDIO_STT,
                operation="translate",
                input_tokens=self._usage_value(usage, 'input_tokens'),
                output_tokens=self._usage_value(usage, 'output_tokens'),
                input_units=self._duration_minutes(response),  # Duration in minutes
                metadata={
                    "model": self.model_name,
                    "file_size": len(audio_data)
                }
            )

            # Format response
            result = {
                "text": response.text,
                "detected_language": getattr(response, 'language', 'unknown'),
                "duration": getattr(response, 'duration', None),
                "segments": getattr(response, 'segments', []),
                "confidence": None  # Whisper doesn't provide confidence scores
            }

            return result

        except Exception as e:
            logger.error(f"Error translating audio: {e}")
            raise

    async def transcribe_batch(
        self,
        audio_files: List[Union[str, BinaryIO]],
        language: Optional[str] = None,
        prompt: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Transcribe multiple audio files

        Files are processed one after another. A failure on one file does not
        abort the batch: its slot holds an empty result with an ``error`` key.
        """
        results = []

        for audio_file in audio_files:
            try:
                result = await self.transcribe(audio_file, language, prompt)
                results.append(result)
            except Exception as e:
                logger.error(f"Error transcribing audio file: {e}")
                results.append({
                    "text": "",
                    "language": "unknown",
                    "duration": None,
                    "segments": [],
                    "confidence": None,
                    "error": str(e)
                })

        return results

    async def detect_language(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
        """Detect language of audio file

        Runs a full transcription without a language hint and reports the
        language from that result. ``confidence`` is a fixed 1.0 placeholder
        and ``alternatives`` is always empty.
        """
        try:
            # Transcribe with language detection
            result = await self.transcribe(audio_file, language=None)

            return {
                "language": result["language"],
                "confidence": 1.0,  # Whisper is generally confident
                "alternatives": []  # Whisper doesn't provide alternatives
            }

        except Exception as e:
            logger.error(f"Error detecting language: {e}")
            raise

    def get_supported_formats(self) -> List[str]:
        """Get list of supported audio formats (a copy; safe to mutate)"""
        return self.supported_formats.copy()

    def get_supported_languages(self) -> List[str]:
        """Get list of supported language codes"""
        # Whisper supports 99+ languages
        return [
            'af', 'am', 'ar', 'as', 'az', 'ba', 'be', 'bg', 'bn', 'bo', 'br', 'bs', 'ca',
            'cs', 'cy', 'da', 'de', 'el', 'en', 'es', 'et', 'eu', 'fa', 'fi', 'fo', 'fr',
            'gl', 'gu', 'ha', 'haw', 'he', 'hi', 'hr', 'ht', 'hu', 'hy', 'id', 'is', 'it',
            'ja', 'jw', 'ka', 'kk', 'km', 'kn', 'ko', 'la', 'lb', 'ln', 'lo', 'lt', 'lv',
            'mg', 'mi', 'mk', 'ml', 'mn', 'mr', 'ms', 'mt', 'my', 'ne', 'nl', 'nn', 'no',
            'oc', 'pa', 'pl', 'ps', 'pt', 'ro', 'ru', 'sa', 'sd', 'si', 'sk', 'sl', 'sn',
            'so', 'sq', 'sr', 'su', 'sv', 'sw', 'ta', 'te', 'tg', 'th', 'tk', 'tl', 'tr',
            'tt', 'uk', 'ur', 'uz', 'vi', 'yi', 'yo', 'zh'
        ]

    def get_max_file_size(self) -> int:
        """Get maximum file size in bytes"""
        return self.max_file_size

    async def close(self):
        """Cleanup resources by closing the underlying OpenAI client"""
        await self.client.close()
        logger.info("OpenAISTTService client has been closed.")
@@ -5,6 +5,7 @@ from openai import AsyncOpenAI
5
5
  from tenacity import retry, stop_after_attempt, wait_exponential
6
6
  from isa_model.inference.services.audio.base_tts_service import BaseTTSService
7
7
  from isa_model.inference.providers.base_provider import BaseProvider
8
+ from isa_model.inference.billing_tracker import ServiceType
8
9
  import logging
9
10
 
10
11
  logger = logging.getLogger(__name__)
@@ -14,12 +15,28 @@ class OpenAITTSService(BaseTTSService):
14
15
 
15
16
  def __init__(self, provider: 'BaseProvider', model_name: str):
16
17
  super().__init__(provider, model_name)
17
- # 初始化 AsyncOpenAI 客户端
18
- self._client = AsyncOpenAI(
19
- api_key=self.config.get('api_key'),
20
- base_url=self.config.get('base_url')
21
- )
22
- self.language = self.config.get('language', None)
18
+
19
+ # Get full configuration from provider (including sensitive data)
20
+ provider_config = provider.get_full_config()
21
+
22
+ # Initialize AsyncOpenAI client with provider configuration
23
+ try:
24
+ if not provider_config.get("api_key"):
25
+ raise ValueError("OpenAI API key not found in provider configuration")
26
+
27
+ self._client = AsyncOpenAI(
28
+ api_key=provider_config["api_key"],
29
+ base_url=provider_config.get("base_url", "https://api.openai.com/v1"),
30
+ organization=provider_config.get("organization")
31
+ )
32
+
33
+ logger.info(f"Initialized OpenAITTSService with model '{self.model_name}'")
34
+
35
+ except Exception as e:
36
+ logger.error(f"Failed to initialize OpenAI client: {e}")
37
+ raise ValueError(f"Failed to initialize OpenAI client. Check your API key configuration: {e}") from e
38
+
39
+ self.language = provider_config.get('language', None)
23
40
 
24
41
  @property
25
42
  def client(self) -> AsyncOpenAI:
@@ -83,18 +100,40 @@ class OpenAITTSService(BaseTTSService):
83
100
  try:
84
101
  response = await self._client.audio.speech.create(
85
102
  model="tts-1",
86
- voice=voice or "alloy",
103
+ voice=voice or "alloy", # type: ignore
87
104
  input=text,
88
- response_format=format,
105
+ response_format=format, # type: ignore
89
106
  speed=speed
90
107
  )
91
108
 
92
109
  audio_data = response.content
93
110
 
111
+ # Estimate audio duration for billing (rough estimation: ~150 words per minute)
112
+ words = len(text.split())
113
+ estimated_duration_seconds = (words / 150.0) * 60.0 / speed
114
+
115
+ # Track usage for billing (OpenAI TTS is token-based: $15 per 1M characters)
116
+ self._track_usage(
117
+ service_type=ServiceType.AUDIO_TTS,
118
+ operation="synthesize_speech",
119
+ input_tokens=len(text), # Characters as input tokens
120
+ output_tokens=0,
121
+ input_units=len(text), # Text length
122
+ output_units=estimated_duration_seconds, # Audio duration in seconds
123
+ metadata={
124
+ "model": self.model_name,
125
+ "voice": voice or "alloy",
126
+ "speed": speed,
127
+ "format": format,
128
+ "text_length": len(text),
129
+ "estimated_duration_seconds": estimated_duration_seconds
130
+ }
131
+ )
132
+
94
133
  return {
95
134
  "audio_data": audio_data,
96
135
  "format": format,
97
- "duration": 0.0, # OpenAI doesn't provide duration
136
+ "duration": estimated_duration_seconds,
98
137
  "sample_rate": 24000 # Default for OpenAI TTS
99
138
  }
100
139
 
@@ -0,0 +1,239 @@
1
+ import logging
2
+ from typing import Dict, Any, List, Optional, BinaryIO
3
+ import replicate
4
+ from tenacity import retry, stop_after_attempt, wait_exponential
5
+
6
+ from isa_model.inference.services.audio.base_tts_service import BaseTTSService
7
+ from isa_model.inference.providers.base_provider import BaseProvider
8
+ from isa_model.inference.billing_tracker import ServiceType
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class ReplicateTTSService(BaseTTSService):
    """
    Replicate Text-to-Speech service using Kokoro model.
    High-quality voice synthesis with multiple voice options.
    """

    # Speed range accepted by synthesize_speech; values outside are clamped.
    _MIN_SPEED = 0.5
    _MAX_SPEED = 2.0

    # Rough speaking rate used to estimate audio duration for billing.
    _WORDS_PER_MINUTE = 150.0

    # Voice metadata (you can expand this with more details)
    _VOICE_INFO = {
        "af_bella": {"id": "af_bella", "name": "Bella", "gender": "female", "language": "en-US", "description": "Warm, friendly female voice", "sample_rate": 22050},
        "af_nicole": {"id": "af_nicole", "name": "Nicole", "gender": "female", "language": "en-US", "description": "Clear, professional female voice", "sample_rate": 22050},
        "af_sarah": {"id": "af_sarah", "name": "Sarah", "gender": "female", "language": "en-US", "description": "Gentle, expressive female voice", "sample_rate": 22050},
        "af_sky": {"id": "af_sky", "name": "Sky", "gender": "female", "language": "en-US", "description": "Bright, energetic female voice", "sample_rate": 22050},
        "am_adam": {"id": "am_adam", "name": "Adam", "gender": "male", "language": "en-US", "description": "Deep, authoritative male voice", "sample_rate": 22050},
        "am_michael": {"id": "am_michael", "name": "Michael", "gender": "male", "language": "en-US", "description": "Smooth, conversational male voice", "sample_rate": 22050}
    }

    def __init__(self, provider: 'BaseProvider', model_name: str = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"):
        """Create the service and export the Replicate API token.

        Args:
            provider: Provider whose ``get_full_config()`` supplies
                ``api_token`` or ``replicate_api_token``.
            model_name: Replicate model reference to run.

        Raises:
            ValueError: If no API token is configured.
        """
        super().__init__(provider, model_name)

        # Get full configuration from provider (including sensitive data)
        provider_config = provider.get_full_config()

        # Set up Replicate API token from provider configuration
        self.api_token = provider_config.get('api_token') or provider_config.get('replicate_api_token')
        if not self.api_token:
            raise ValueError("Replicate API token not found in provider configuration")

        # Set environment variable for replicate library.
        # NOTE: this is process-wide, so it affects every user of the library.
        import os
        os.environ['REPLICATE_API_TOKEN'] = self.api_token

        # Available voices for Kokoro model
        self.available_voices = [
            "af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael"
        ]

        # Default settings
        self.default_voice = "af_nicole"
        self.default_speed = 1.0

        logger.info(f"Initialized ReplicateTTSService with model '{self.model_name}'")

    @staticmethod
    def _extract_audio_url(output: Any) -> str:
        """Turn the model output (string, object with ``url``, or list) into a URL string."""
        try:
            if isinstance(output, str):
                return output
            if hasattr(output, 'url'):
                # Handle FileOutput object
                return str(getattr(output, 'url', output))
            if isinstance(output, list) and len(output) > 0:
                first_output = output[0]
                if hasattr(first_output, 'url'):
                    return str(getattr(first_output, 'url', first_output))
                return str(first_output)
            # Convert to string as fallback
            return str(output)
        except Exception:
            # Safe fallback
            return str(output)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        reraise=True
    )
    async def synthesize_speech(
        self,
        text: str,
        voice: Optional[str] = None,
        speed: float = 1.0,
        pitch: Optional[float] = None,
        volume: Optional[float] = None
    ) -> Dict[str, Any]:
        """Synthesize speech from text using Kokoro model

        Args:
            text: Text to speak.
            voice: Voice id; an unknown or missing voice falls back to the
                default voice.
            speed: Requested speed, clamped to the 0.5-2.0 range.
            pitch: Accepted for interface compatibility; not sent to the model.
            volume: Accepted for interface compatibility; not sent to the model.

        Returns:
            Dict with ``audio_url``, ``text``, ``voice``, ``speed`` (the
            clamped speed actually used), ``duration_seconds`` (an estimate)
            and ``metadata``.

        Raises:
            Exception: Any error from the model run is logged and re-raised
                after the retry attempts are exhausted.
        """
        try:
            # Validate and set voice
            selected_voice = voice or self.default_voice
            if selected_voice not in self.available_voices:
                logger.warning(f"Voice '{selected_voice}' not available, using default '{self.default_voice}'")
                selected_voice = self.default_voice

            # Clamp once and use the clamped value everywhere below, so the
            # estimate and the reported speed describe the audio that is
            # actually produced (and speed=0 can no longer divide by zero).
            effective_speed = max(self._MIN_SPEED, min(self._MAX_SPEED, speed))

            # Prepare input parameters
            input_params = {
                "text": text,
                "voice": selected_voice,
                "speed": effective_speed
            }

            logger.info(f"Synthesizing speech with voice '{selected_voice}' and speed {effective_speed}")

            # Run the model
            output = await replicate.async_run(self.model_name, input=input_params)

            # Handle different output formats
            audio_url = self._extract_audio_url(output)

            # Estimate audio duration for billing (rough estimation: ~150 words per minute)
            words = len(text.split())
            estimated_duration_seconds = (words / self._WORDS_PER_MINUTE) * 60.0 / effective_speed

            # Track usage for billing
            self._track_usage(
                service_type=ServiceType.AUDIO_TTS,
                operation="synthesize_speech",
                input_tokens=0,
                output_tokens=0,
                input_units=len(text),  # Text length
                output_units=estimated_duration_seconds,  # Audio duration in seconds
                metadata={
                    "model": self.model_name,
                    "voice": selected_voice,
                    "speed": effective_speed,
                    "text_length": len(text),
                    "estimated_duration_seconds": estimated_duration_seconds
                }
            )

            result = {
                "audio_url": audio_url,
                "text": text,
                "voice": selected_voice,
                "speed": effective_speed,
                "duration_seconds": estimated_duration_seconds,
                "metadata": {
                    "model": self.model_name,
                    "provider": "replicate",
                    "voice_options": self.available_voices
                }
            }

            logger.info(f"Speech synthesis completed: {audio_url}")
            return result

        except Exception as e:
            logger.error(f"Error synthesizing speech: {e}")
            raise

    async def synthesize_speech_to_file(
        self,
        text: str,
        output_path: str,
        voice: Optional[str] = None,
        speed: float = 1.0,
        pitch: Optional[float] = None,
        volume: Optional[float] = None
    ) -> Dict[str, Any]:
        """Synthesize speech and save to file

        Returns the ``synthesize_speech`` result extended with ``output_path``
        and ``file_size`` (bytes written).

        Raises:
            Exception: Synthesis, download (non-success HTTP status) or file
                write errors are logged and re-raised.
        """
        try:
            # Get audio URL
            result = await self.synthesize_speech(text, voice, speed, pitch, volume)
            audio_url = result["audio_url"]

            # Download and save audio
            import aiohttp
            import aiofiles

            async with aiohttp.ClientSession() as session:
                async with session.get(audio_url) as response:
                    response.raise_for_status()
                    audio_data = await response.read()

            async with aiofiles.open(output_path, 'wb') as f:
                await f.write(audio_data)

            result["output_path"] = output_path
            result["file_size"] = len(audio_data)

            logger.info(f"Audio saved to: {output_path}")
            return result

        except Exception as e:
            logger.error(f"Error saving audio to file: {e}")
            raise

    async def synthesize_speech_batch(
        self,
        texts: List[str],
        voice: Optional[str] = None,
        speed: float = 1.0,
        pitch: float = 1.0,
        format: str = "wav"
    ) -> List[Dict[str, Any]]:
        """Synthesize multiple texts

        Texts are processed one after another. A failure on one text does not
        abort the batch: its slot holds an entry with ``audio_url`` set to
        ``None`` and an ``error`` key. ``pitch`` and ``format`` are accepted
        for interface compatibility and are not forwarded.
        """
        results = []

        for text in texts:
            try:
                result = await self.synthesize_speech(text, voice, speed)
                results.append(result)
            except Exception as e:
                logger.error(f"Error synthesizing text '{text[:50]}...': {e}")
                results.append({
                    "audio_url": None,
                    "text": text,
                    "voice": voice or self.default_voice,
                    "speed": speed,
                    "error": str(e)
                })

        return results

    def get_available_voices(self) -> List[Dict[str, Any]]:
        """Get list of available voices"""
        voices = []
        for voice in self.available_voices:
            voice_info = self.get_voice_info(voice)
            voices.append({
                "id": voice,
                "name": voice.replace("_", " ").title(),
                "language": "en-US",
                "gender": voice_info.get("gender", "unknown"),
                "age": "adult"
            })
        return voices

    def get_supported_formats(self) -> List[str]:
        """Get list of supported audio formats"""
        return ["wav", "mp3"]  # Kokoro typically outputs WAV

    def get_voice_info(self, voice_id: str) -> Dict[str, Any]:
        """Get information about a specific voice

        Returns a fresh dict on every call. An id that is not in
        ``self.available_voices`` yields ``{"error": ...}``; an available id
        without metadata yields a generic placeholder entry.
        """
        if voice_id not in self.available_voices:
            return {"error": f"Voice '{voice_id}' not available"}

        known = self._VOICE_INFO.get(voice_id)
        if known is not None:
            # Copy so callers cannot mutate the shared class-level table
            return dict(known)

        return {"id": voice_id, "gender": "unknown", "language": "en-US", "description": "Voice information not available", "sample_rate": 22050}

    async def close(self):
        """Cleanup resources"""
        logger.info("ReplicateTTSService resources cleaned up")
@@ -1,6 +1,7 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
3
3
  from isa_model.inference.providers.base_provider import BaseProvider
4
+ from isa_model.inference.billing_tracker import track_usage, ServiceType, Provider
4
5
 
5
6
  T = TypeVar('T') # Generic type for responses
6
7
 
@@ -10,7 +11,41 @@ class BaseService(ABC):
10
11
  def __init__(self, provider: 'BaseProvider', model_name: str):
11
12
  self.provider = provider
12
13
  self.model_name = model_name
13
- self.config = provider.get_config()
14
+ self.config = provider.get_full_config()
15
+
16
def _track_usage(
    self,
    service_type: Union[str, ServiceType],
    operation: str,
    input_tokens: Optional[int] = None,
    output_tokens: Optional[int] = None,
    input_units: Optional[float] = None,
    output_units: Optional[float] = None,
    metadata: Optional[Dict[str, Any]] = None
):
    """Record one billable operation; never raises.

    The provider label is the first non-empty value among the provider's
    ``name`` attribute, its ``provider_name`` attribute, and its lower-cased
    class name with ``provider`` removed; ``'unknown'`` is used when all of
    these are empty. Any failure while resolving the label or recording the
    usage is logged as a warning and swallowed.
    """
    try:
        # Resolve the provider label, trying the explicit attributes first
        provider_obj = self.provider
        label = getattr(provider_obj, 'name', None)
        if not label:
            label = getattr(provider_obj, 'provider_name', None)
        if not label:
            provider_cls = getattr(provider_obj, '__class__', type(None))
            label = provider_cls.__name__.lower().replace('provider', '')
        if not label:
            label = 'unknown'

        usage_record = dict(
            provider=label,
            service_type=service_type,
            model_name=self.model_name,
            operation=operation,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_units=input_units,
            output_units=output_units,
            metadata=metadata,
        )
        track_usage(**usage_record)
    except Exception as exc:
        # Don't let billing tracking break the service
        import logging
        logging.getLogger(__name__).warning(f"Failed to track usage: {exc}")
14
49
 
15
50
  def __await__(self):
16
51
  """Make the service awaitable"""