isa-model 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_registry.py +273 -46
  3. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  4. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  5. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  6. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  7. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  8. isa_model/eval/__init__.py +56 -0
  9. isa_model/eval/benchmarks.py +469 -0
  10. isa_model/eval/factory.py +582 -0
  11. isa_model/eval/metrics.py +628 -0
  12. isa_model/inference/ai_factory.py +98 -93
  13. isa_model/inference/providers/openai_provider.py +21 -7
  14. isa_model/inference/providers/replicate_provider.py +18 -5
  15. isa_model/inference/providers/triton_provider.py +1 -1
  16. isa_model/inference/services/audio/base_stt_service.py +91 -0
  17. isa_model/inference/services/audio/base_tts_service.py +136 -0
  18. isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
  19. isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
  20. isa_model/inference/services/llm/__init__.py +0 -4
  21. isa_model/inference/services/llm/base_llm_service.py +134 -0
  22. isa_model/inference/services/llm/ollama_llm_service.py +1 -10
  23. isa_model/inference/services/llm/openai_llm_service.py +70 -61
  24. isa_model/inference/services/vision/__init__.py +1 -1
  25. isa_model/inference/services/vision/ollama_vision_service.py +4 -4
  26. isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
  27. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  28. isa_model/training/__init__.py +44 -0
  29. isa_model/training/factory.py +393 -0
  30. isa_model-0.2.0.dist-info/METADATA +327 -0
  31. {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/RECORD +35 -60
  32. isa_model/deployment/mlflow_gateway/__init__.py +0 -8
  33. isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
  34. isa_model/deployment/unified_multimodal_client.py +0 -341
  35. isa_model/inference/adapter/triton_adapter.py +0 -453
  36. isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
  37. isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
  38. isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
  39. isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
  40. isa_model/inference/backends/__init__.py +0 -53
  41. isa_model/inference/backends/base_backend_client.py +0 -26
  42. isa_model/inference/backends/container_services.py +0 -104
  43. isa_model/inference/backends/local_services.py +0 -72
  44. isa_model/inference/backends/openai_client.py +0 -130
  45. isa_model/inference/backends/replicate_client.py +0 -197
  46. isa_model/inference/backends/third_party_services.py +0 -239
  47. isa_model/inference/backends/triton_client.py +0 -97
  48. isa_model/inference/client_sdk/client.py +0 -134
  49. isa_model/inference/client_sdk/client_data_std.py +0 -34
  50. isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
  51. isa_model/inference/client_sdk/exceptions.py +0 -0
  52. isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
  53. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
  54. isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
  55. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
  56. isa_model/inference/providers/vllm_provider.py +0 -0
  57. isa_model/inference/providers/yyds_provider.py +0 -83
  58. isa_model/inference/services/audio/fish_speech/handler.py +0 -215
  59. isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
  60. isa_model/inference/services/audio/triton_speech_service.py +0 -138
  61. isa_model/inference/services/audio/whisper_service.py +0 -186
  62. isa_model/inference/services/base_tts_service.py +0 -66
  63. isa_model/inference/services/embedding/bge_service.py +0 -183
  64. isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
  65. isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
  66. isa_model/inference/services/llm/gemma_service.py +0 -143
  67. isa_model/inference/services/llm/llama_service.py +0 -143
  68. isa_model/inference/services/llm/replicate_llm_service.py +0 -179
  69. isa_model/inference/services/llm/triton_llm_service.py +0 -230
  70. isa_model/inference/services/vision/replicate_vision_service.py +0 -241
  71. isa_model/inference/services/vision/triton_vision_service.py +0 -199
  72. isa_model-0.1.0.dist-info/METADATA +0 -116
  73. /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
  74. {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/WHEEL +0 -0
  75. {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {isa_model-0.1.0.dist-info → isa_model-0.2.0.dist-info}/top_level.txt +0 -0
isa_model/inference/providers/yyds_provider.py (deleted)
@@ -1,83 +0,0 @@
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.base import ModelType, Capability
-from typing import Dict, List, Any
-import logging
-
-logger = logging.getLogger(__name__)
-
-class YYDSProvider(BaseProvider):
-    """Provider for YYDS API (Your YYDS Provider API)"""
-
-    def __init__(self, config=None):
-        """
-        Initialize the YYDS Provider
-
-        Args:
-            config (dict, optional): Configuration for the provider
-                - api_key: API key for authentication
-                - api_base: Base URL for YYDS API
-                - timeout: Timeout for API calls in seconds
-        """
-        default_config = {
-            "api_base": "https://api.yyds.ai/v1",
-            "timeout": 60,
-            "max_retries": 3,
-            "temperature": 0.7,
-            "top_p": 0.9,
-            "max_tokens": 2048
-        }
-
-        # Merge default config with provided config
-        merged_config = {**default_config, **(config or {})}
-
-        super().__init__(config=merged_config)
-        self.name = "yyds"
-
-        # Validate API key
-        api_key = self.config.get("api_key")
-        if not api_key:
-            logger.warning("No API key provided for YYDS Provider. Some operations may fail.")
-
-        logger.info(f"Initialized YYDSProvider with API base: {self.config['api_base']}")
-
-    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
-        """Get provider capabilities by model type"""
-        return {
-            ModelType.LLM: [
-                Capability.CHAT,
-                Capability.COMPLETION
-            ],
-            ModelType.VISION: [
-                Capability.IMAGE_CLASSIFICATION,
-                Capability.IMAGE_UNDERSTANDING
-            ],
-            ModelType.AUDIO: [
-                Capability.SPEECH_TO_TEXT,
-                Capability.TEXT_TO_SPEECH
-            ]
-        }
-
-    def get_models(self, model_type: ModelType) -> List[str]:
-        """Get available models for given type"""
-        # Placeholder: In real implementation, this would query the YYDS API
-        if model_type == ModelType.LLM:
-            return ["yyds-l", "yyds-xl", "yyds-xxl"]
-        elif model_type == ModelType.VISION:
-            return ["yyds-vision", "yyds-multimodal"]
-        elif model_type == ModelType.AUDIO:
-            return ["yyds-speech", "yyds-tts"]
-        else:
-            return []
-
-    def get_config(self) -> Dict[str, Any]:
-        """Get provider configuration"""
-        # Return a copy of the config, without the API key for security
-        config_copy = self.config.copy()
-        if "api_key" in config_copy:
-            config_copy["api_key"] = "***"  # Mask the API key
-        return config_copy
-
-    def is_reasoning_model(self, model_name: str) -> bool:
-        """Check if the model is optimized for reasoning tasks"""
-        # Only the largest models are considered reasoning-capable
-        return model_name in ["yyds-xxl"]
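Note: the removed provider's config handling is a small pattern worth seeing in isolation: defaults are merged with caller overrides via dict unpacking, and later keys win. A minimal standalone sketch of that pattern (the keys below are illustrative, not part of isa_model):

    # Later unpacking wins, so caller-supplied keys override the defaults.
    default_config = {"api_base": "https://api.example.com/v1", "timeout": 60}
    user_config = {"timeout": 120}
    merged_config = {**default_config, **(user_config or {})}
    assert merged_config == {"api_base": "https://api.example.com/v1", "timeout": 120}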
isa_model/inference/services/audio/fish_speech/handler.py (deleted)
@@ -1,215 +0,0 @@
-import os
-import base64
-import tempfile
-import runpod
-import torch
-import torchaudio
-import sys
-import json
-from pathlib import Path
-from typing import Dict, Any, List, Union
-
-# Add Fish-Speech to the Python path
-sys.path.append('/app/fish-speech')
-
-# Import Fish-Speech modules
-from fish_speech.models.fish_speech.model import FishSpeech
-from fish_speech.models.fish_speech.config import FishSpeechConfig
-from fish_speech.utils.audio import load_audio, save_audio
-from fish_speech.utils.tokenizer import load_tokenizer
-
-# Load the model
-MODEL_PATH = "/app/models/fish_speech_v1.4.0.pth"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Initialize the model
-def load_model():
-    print("Loading Fish-Speech model...")
-    config = FishSpeechConfig()
-    model = FishSpeech(config)
-
-    # Load the model weights
-    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
-    model.load_state_dict(checkpoint["model"])
-    model.eval()
-    model.to(DEVICE)
-
-    # Load the tokenizer
-    tokenizer = load_tokenizer()
-
-    print(f"Model loaded successfully on {DEVICE}")
-    return model, tokenizer
-
-model, tokenizer = load_model()
-
-# Download a file from a URL
-async def download_file(url: str) -> str:
-    import aiohttp
-    import os
-
-    # Create a temporary file
-    temp_dir = tempfile.mkdtemp()
-    local_filename = os.path.join(temp_dir, "reference_audio.wav")
-
-    # Download the file
-    async with aiohttp.ClientSession() as session:
-        async with session.get(url) as response:
-            if response.status != 200:
-                raise Exception(f"Failed to download file: {response.status}")
-
-            with open(local_filename, "wb") as f:
-                f.write(await response.read())
-
-    return local_filename
-
-# Generate speech using Fish-Speech
-def generate_speech(
-    text: str,
-    reference_audio_path: str = None,
-    language: str = "auto",
-    speed: float = 1.0,
-    gpt_cond_len: int = 30,
-    max_ref_len: int = 60,
-    enhance_audio: bool = True
-) -> str:
-    """
-    Generate speech using Fish-Speech
-
-    Args:
-        text: Text to convert to speech
-        reference_audio_path: Path to reference audio file for voice cloning
-        language: Language code (auto for auto-detection)
-        speed: Speech speed factor
-        gpt_cond_len: GPT conditioning length
-        max_ref_len: Maximum reference length
-        enhance_audio: Whether to enhance audio quality
-
-    Returns:
-        Path to the generated audio file
-    """
-    print(f"Generating speech for text: {text}")
-
-    # Load reference audio if provided
-    reference = None
-    if reference_audio_path:
-        print(f"Loading reference audio: {reference_audio_path}")
-        reference, sr = load_audio(reference_audio_path)
-        reference = reference.to(DEVICE)
-
-    # Generate speech
-    with torch.no_grad():
-        # Tokenize the text
-        tokens = tokenizer.encode(text, language=language)
-        tokens = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
-
-        # Generate speech
-        output = model.generate(
-            tokens,
-            reference=reference,
-            gpt_cond_latent_length=gpt_cond_len,
-            max_ref_length=max_ref_len,
-            top_k=50,
-            top_p=0.95,
-            temperature=0.7,
-            speed=speed
-        )
-
-    # Get the audio
-    audio = output["audio"]
-
-    # Enhance audio if requested
-    if enhance_audio:
-        # Apply simple normalization
-        audio = audio / torch.max(torch.abs(audio))
-
-    # Save the audio to a temporary file
-    temp_dir = tempfile.mkdtemp()
-    output_path = os.path.join(temp_dir, "output.wav")
-    save_audio(audio.cpu(), output_path, sample_rate=24000)
-
-    print(f"Speech generated successfully: {output_path}")
-    return output_path
-
-# RunPod handler
-def handler(event):
-    """
-    RunPod handler function
-
-    Args:
-        event: RunPod event object
-
-    Returns:
-        Dictionary with the generated audio
-    """
-    try:
-        # Get the input data
-        input_data = event.get("input", {})
-
-        # Extract parameters
-        text_data = input_data.get("text", [])
-        language = input_data.get("language", "auto")
-        speed = input_data.get("speed", 1.0)
-        gpt_cond_len = input_data.get("gpt_cond_len", 30)
-        max_ref_len = input_data.get("max_ref_len", 60)
-        enhance_audio = input_data.get("enhance_audio", True)
-        voice_data = input_data.get("voice", {})
-
-        # Validate input
-        if not text_data:
-            return {"error": "No text provided"}
-
-        # Process each text segment
-        results = []
-
-        # Download reference audio if provided
-        reference_audio_paths = {}
-        for speaker_id, url in voice_data.items():
-            if url:
-                # Use runpod.download_file for synchronous download
-                local_path = runpod.download_file(url)
-                reference_audio_paths[speaker_id] = local_path
-
-        # Process each text segment
-        for speaker_id, text in text_data:
-            # Get reference audio path
-            reference_path = reference_audio_paths.get(speaker_id)
-
-            # Generate speech
-            output_path = generate_speech(
-                text=text,
-                reference_audio_path=reference_path,
-                language=language,
-                speed=speed,
-                gpt_cond_len=gpt_cond_len,
-                max_ref_len=max_ref_len,
-                enhance_audio=enhance_audio
-            )
-
-            # Read the audio file and convert to base64
-            with open(output_path, "rb") as f:
-                audio_data = f.read()
-                audio_base64 = base64.b64encode(audio_data).decode("utf-8")
-
-            # Add to results
-            results.append({
-                "speaker_id": speaker_id,
-                "text": text,
-                "audio_base64": audio_base64
-            })
-
-        # Return the results
-        return {
-            "audio_base64": results[0]["audio_base64"] if results else None,
-            "results": results
-        }
-
-    except Exception as e:
-        import traceback
-        error_message = str(e)
-        stack_trace = traceback.format_exc()
-        print(f"Error: {error_message}")
-        print(f"Stack trace: {stack_trace}")
-        return {"error": error_message, "stack_trace": stack_trace}
-
-# Start the RunPod handler
-runpod.serverless.start({"handler": handler})
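Note: the removed handler's input schema is fully visible above: "text" is a list of [speaker_id, text] pairs and "voice" optionally maps speaker IDs to reference-audio URLs for cloning. A sketch of a direct invocation (illustrative only; it assumes the Fish-Speech weights and dependencies are present, which they are not in 0.2.0, and the URLs are placeholders):

    event = {
        "input": {
            "text": [["speaker_0", "Hello from Fish-Speech."]],
            "voice": {"speaker_0": "https://example.com/reference.wav"},  # optional cloning reference
            "language": "auto",
            "speed": 1.0,
        }
    }
    result = handler(event)  # returns {"audio_base64": ..., "results": [...]} on success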
isa_model/inference/services/audio/runpod_tts_fish_service.py (deleted)
@@ -1,212 +0,0 @@
-import os
-import json
-import aiohttp
-import asyncio
-import base64
-import tempfile
-from typing import Dict, Any, Optional, Union, BinaryIO, List
-from ...base_tts_service import BaseTTSService
-from ...providers.runpod_provider import RunPodProvider
-
-class RunPodTTSFishService(BaseTTSService):
-    """
-    RunPod TTS service using Fish-Speech
-
-    This service uses the Fish-Speech TTS model deployed on RunPod to generate speech from text.
-    Fish-Speech is a state-of-the-art open-source TTS model that supports multilingual text-to-speech
-    and voice cloning capabilities.
-    """
-
-    def __init__(self, provider: RunPodProvider, model_name: str = "fish-speech"):
-        """
-        Initialize the RunPod TTS Fish service
-
-        Args:
-            provider: RunPod provider instance
-            model_name: Model name (default: "fish-speech")
-        """
-        super().__init__(provider, model_name)
-        self.api_key = self.config.get("api_key")
-        self.endpoint_id = self.config.get("endpoint_id")
-        self.base_url = self.config.get("base_url")
-
-        if not self.api_key:
-            raise ValueError("RunPod API key is required")
-
-        if not self.endpoint_id:
-            raise ValueError("RunPod endpoint ID is required")
-
-        self.endpoint_url = f"{self.base_url}/{self.endpoint_id}/run"
-        self.status_url = f"{self.base_url}/{self.endpoint_id}/status"
-        self.headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
-
-        # Default voice reference URLs (can be overridden in the options)
-        self.default_voices = {}
-
-    async def _run_inference(self, payload: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Run inference on the RunPod endpoint
-
-        Args:
-            payload: Request payload
-
-        Returns:
-            Response from the RunPod endpoint
-        """
-        async with aiohttp.ClientSession() as session:
-            async with session.post(
-                self.endpoint_url,
-                headers=self.headers,
-                json={"input": payload}
-            ) as response:
-                if response.status != 200:
-                    error_text = await response.text()
-                    raise Exception(f"RunPod API error: {response.status} - {error_text}")
-
-                result = await response.json()
-                job_id = result.get("id")
-
-                if not job_id:
-                    raise Exception("No job ID returned from RunPod API")
-
-            # Poll for job completion
-            while True:
-                async with session.get(
-                    f"{self.status_url}/{job_id}",
-                    headers=self.headers
-                ) as status_response:
-                    status_data = await status_response.json()
-                    status = status_data.get("status")
-
-                    if status == "COMPLETED":
-                        return status_data.get("output", {})
-                    elif status == "FAILED":
-                        error = status_data.get("error", "Unknown error")
-                        raise Exception(f"RunPod job failed: {error}")
-
-                # Wait before polling again
-                await asyncio.sleep(1)
-
-    async def generate_speech(
-        self,
-        text: str,
-        voice_id: Optional[str] = None,
-        language: Optional[str] = None,
-        speed: float = 1.0,
-        options: Optional[Dict[str, Any]] = None
-    ) -> bytes:
-        """
-        Generate speech from text using Fish-Speech
-
-        Args:
-            text: The text to convert to speech
-            voice_id: Voice identifier (URL to reference audio)
-            language: Language code (auto-detected if not provided)
-            speed: Speech speed factor (1.0 is normal speed)
-            options: Additional options:
-                - gpt_cond_len: GPT conditioning length
-                - max_ref_len: Maximum reference length
-                - enhance_audio: Whether to enhance audio quality
-
-        Returns:
-            Audio data as bytes
-        """
-        options = options or {}
-
-        # Prepare the payload
-        payload = {
-            "text": [[voice_id or "speaker_0", text]],
-            "language": language or "auto",
-            "speed": speed,
-            "gpt_cond_len": options.get("gpt_cond_len", 30),
-            "max_ref_len": options.get("max_ref_len", 60),
-            "enhance_audio": options.get("enhance_audio", True)
-        }
-
-        # Add voice reference
-        voice_url = None
-        if voice_id and voice_id.startswith(("http://", "https://")):
-            voice_url = voice_id
-        elif voice_id and voice_id in self.default_voices:
-            voice_url = self.default_voices[voice_id]
-
-        if voice_url:
-            payload["voice"] = {"speaker_0": voice_url}
-
-        # Run inference
-        result = await self._run_inference(payload)
-
-        # Extract audio data
-        if "audio_base64" in result:
-            return base64.b64decode(result["audio_base64"])
-        elif "audio_url" in result:
-            # Download audio from URL
-            async with aiohttp.ClientSession() as session:
-                async with session.get(result["audio_url"]) as response:
-                    if response.status != 200:
-                        raise Exception(f"Failed to download audio: {response.status}")
-                    return await response.read()
-        else:
-            raise Exception("No audio data in response")
-
-    async def save_to_file(
-        self,
-        text: str,
-        output_file: Union[str, BinaryIO],
-        voice_id: Optional[str] = None,
-        language: Optional[str] = None,
-        speed: float = 1.0,
-        options: Optional[Dict[str, Any]] = None
-    ) -> str:
-        """
-        Generate speech and save to file
-
-        Args:
-            text: The text to convert to speech
-            output_file: Path to output file or file-like object
-            voice_id: Voice identifier
-            language: Language code
-            speed: Speech speed factor
-            options: Additional options
-
-        Returns:
-            Path to the saved file
-        """
-        audio_data = await self.generate_speech(text, voice_id, language, speed, options)
-
-        if isinstance(output_file, str):
-            with open(output_file, "wb") as f:
-                f.write(audio_data)
-            return output_file
-        else:
-            output_file.write(audio_data)
-            if hasattr(output_file, "name"):
-                return output_file.name
-            return "audio.wav"
-
-    async def get_available_voices(self) -> Dict[str, Any]:
-        """
-        Get available voices for the TTS service
-
-        Returns:
-            Dictionary of available voices with their details
-        """
-        # Fish-Speech doesn't have a fixed set of voices as it uses voice cloning
-        # Return the default voices that have been configured
-        return {
-            "voices": list(self.default_voices.keys()),
-            "note": "Fish-Speech supports voice cloning. Provide a URL to a reference audio file to clone a voice."
-        }
-
-    def add_voice(self, voice_id: str, reference_url: str) -> None:
-        """
-        Add a voice to the default voices
-
-        Args:
-            voice_id: Voice identifier
-            reference_url: URL to the reference audio file
-        """
-        self.default_voices[voice_id] = reference_url
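Note: the removed service's _run_inference follows the standard RunPod serverless pattern: POST the payload to .../run, then poll .../status/{job_id} until COMPLETED or FAILED. A condensed standalone sketch of the same flow using only aiohttp (base_url, endpoint_id, and api_key are placeholders you would supply):

    import asyncio
    import aiohttp

    async def run_and_wait(base_url: str, endpoint_id: str, api_key: str, payload: dict) -> dict:
        # Submit the job, then poll its status until it finishes or fails.
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{base_url}/{endpoint_id}/run",
                                    headers=headers, json={"input": payload}) as resp:
                job_id = (await resp.json())["id"]
            while True:
                async with session.get(f"{base_url}/{endpoint_id}/status/{job_id}",
                                       headers=headers) as resp:
                    data = await resp.json()
                if data.get("status") == "COMPLETED":
                    return data.get("output", {})
                if data.get("status") == "FAILED":
                    raise RuntimeError(data.get("error", "Unknown error"))
                await asyncio.sleep(1)  # back off before polling again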
isa_model/inference/services/audio/triton_speech_service.py (deleted)
@@ -1,138 +0,0 @@
-import json
-import logging
-import asyncio
-import io
-import numpy as np
-from typing import Dict, List, Any, AsyncGenerator, Optional, Union, BinaryIO
-
-from isa_model.inference.services.base_service import BaseService
-from isa_model.inference.providers.triton_provider import TritonProvider
-
-logger = logging.getLogger(__name__)
-
-
-class TritonSpeechService(BaseService):
-    """
-    Speech service that uses Triton Inference Server to run speech-to-text inference.
-    """
-
-    def __init__(self, provider: TritonProvider, model_name: str):
-        """
-        Initialize the Triton Speech service.
-
-        Args:
-            provider: The Triton provider
-            model_name: Name of the model in Triton (e.g., "whisper_tiny")
-        """
-        super().__init__(provider, model_name)
-        self.client = None
-
-    async def _initialize_client(self):
-        """Initialize the Triton client"""
-        if self.client is None:
-            self.client = self.provider.create_client()
-
-            # Check if model is ready
-            if not self.provider.is_model_ready(self.model_name):
-                logger.error(f"Model {self.model_name} is not ready on Triton server")
-                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
-
-            logger.info(f"Initialized Triton client for speech model: {self.model_name}")
-
-    async def transcribe(self,
-                         audio: Union[str, BinaryIO, bytes, np.ndarray],
-                         language: str = "en",
-                         config: Optional[Dict[str, Any]] = None) -> str:
-        """
-        Transcribe audio to text using the Triton Inference Server.
-
-        Args:
-            audio: Audio input (file path, file-like object, bytes, or numpy array)
-            language: Language code (e.g., "en", "fr")
-            config: Additional configuration parameters
-
-        Returns:
-            Transcribed text
-        """
-        await self._initialize_client()
-
-        try:
-            import tritonclient.http as httpclient
-
-            # Process audio to get numpy array
-            audio_array = await self._process_audio_input(audio)
-
-            # Create input tensors for audio
-            audio_input = httpclient.InferInput("audio_input", audio_array.shape, "FP32")
-            audio_input.set_data_from_numpy(audio_array)
-            inputs = [audio_input]
-
-            # Add language input
-            language_data = np.array([language], dtype=np.object_)
-            language_input = httpclient.InferInput("language", language_data.shape, "BYTES")
-            language_input.set_data_from_numpy(language_data)
-            inputs.append(language_input)
-
-            # Create output tensor
-            outputs = [httpclient.InferRequestedOutput("text_output")]
-
-            # Send the request
-            response = await asyncio.to_thread(
-                self.client.infer,
-                self.model_name,
-                inputs,
-                outputs=outputs
-            )
-
-            # Process the response
-            output = response.as_numpy("text_output")
-            transcription = output[0].decode('utf-8')
-
-            return transcription
-
-        except Exception as e:
-            logger.error(f"Error during Triton speech inference: {str(e)}")
-            raise
-
-    async def _process_audio_input(self, audio: Union[str, BinaryIO, bytes, np.ndarray]) -> np.ndarray:
-        """
-        Process different types of audio inputs into a numpy array.
-
-        Args:
-            audio: Audio input (file path, file-like object, bytes, or numpy array)
-
-        Returns:
-            Numpy array of the audio
-        """
-        if isinstance(audio, np.ndarray):
-            return audio
-
-        try:
-            import librosa
-
-            if isinstance(audio, str):
-                # File path
-                y, sr = librosa.load(audio, sr=16000)  # Whisper expects 16kHz audio
-                return y.astype(np.float32)
-
-            elif isinstance(audio, (io.IOBase, BinaryIO)):
-                # File-like object
-                audio.seek(0)
-                y, sr = librosa.load(audio, sr=16000)
-                return y.astype(np.float32)
-
-            elif isinstance(audio, bytes):
-                # Bytes
-                with io.BytesIO(audio) as audio_bytes:
-                    y, sr = librosa.load(audio_bytes, sr=16000)
-                    return y.astype(np.float32)
-
-            else:
-                raise ValueError(f"Unsupported audio type: {type(audio)}")
-
-        except ImportError:
-            logger.error("librosa not installed. Please install with: pip install librosa")
-            raise
-        except Exception as e:
-            logger.error(f"Error processing audio: {str(e)}")
-            raise
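Note: the removed Triton service normalized every input to 16 kHz mono float32 before inference, the layout Whisper-style models expect. A standalone sketch of the same preprocessing step, assuming librosa is installed:

    import io
    import numpy as np
    import librosa

    def to_whisper_array(audio_bytes: bytes) -> np.ndarray:
        # Resample arbitrary input audio to the 16 kHz mono float32 array
        # that the removed service fed to Triton as the "audio_input" tensor.
        with io.BytesIO(audio_bytes) as buf:
            y, _sr = librosa.load(buf, sr=16000)
        return y.astype(np.float32)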