isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.1.1.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
--- a/isa_model/inference/providers/yyds_provider.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.base import ModelType, Capability
-from typing import Dict, List, Any
-import logging
-
-logger = logging.getLogger(__name__)
-
-class YYDSProvider(BaseProvider):
-    """Provider for YYDS API (Your YYDS Provider API)"""
-
-    def __init__(self, config=None):
-        """
-        Initialize the YYDS Provider
-
-        Args:
-            config (dict, optional): Configuration for the provider
-                - api_key: API key for authentication
-                - api_base: Base URL for YYDS API
-                - timeout: Timeout for API calls in seconds
-        """
-        default_config = {
-            "api_base": "https://api.yyds.ai/v1",
-            "timeout": 60,
-            "max_retries": 3,
-            "temperature": 0.7,
-            "top_p": 0.9,
-            "max_tokens": 2048
-        }
-
-        # Merge default config with provided config
-        merged_config = {**default_config, **(config or {})}
-
-        super().__init__(config=merged_config)
-        self.name = "yyds"
-
-        # Validate API key
-        api_key = self.config.get("api_key")
-        if not api_key:
-            logger.warning("No API key provided for YYDS Provider. Some operations may fail.")
-
-        logger.info(f"Initialized YYDSProvider with API base: {self.config['api_base']}")
-
-    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
-        """Get provider capabilities by model type"""
-        return {
-            ModelType.LLM: [
-                Capability.CHAT,
-                Capability.COMPLETION
-            ],
-            ModelType.VISION: [
-                Capability.IMAGE_CLASSIFICATION,
-                Capability.IMAGE_UNDERSTANDING
-            ],
-            ModelType.AUDIO: [
-                Capability.SPEECH_TO_TEXT,
-                Capability.TEXT_TO_SPEECH
-            ]
-        }
-
-    def get_models(self, model_type: ModelType) -> List[str]:
-        """Get available models for given type"""
-        # Placeholder: In real implementation, this would query the YYDS API
-        if model_type == ModelType.LLM:
-            return ["yyds-l", "yyds-xl", "yyds-xxl"]
-        elif model_type == ModelType.VISION:
-            return ["yyds-vision", "yyds-multimodal"]
-        elif model_type == ModelType.AUDIO:
-            return ["yyds-speech", "yyds-tts"]
-        else:
-            return []
-
-    def get_config(self) -> Dict[str, Any]:
-        """Get provider configuration"""
-        # Return a copy of the config, without the API key for security
-        config_copy = self.config.copy()
-        if "api_key" in config_copy:
-            config_copy["api_key"] = "***"  # Mask the API key
-        return config_copy
-
-    def is_reasoning_model(self, model_name: str) -> bool:
-        """Check if the model is optimized for reasoning tasks"""
-        # Only the largest models are considered reasoning-capable
-        return model_name in ["yyds-xxl"]
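For context, the removed provider was consumed through the BaseProvider interface. A minimal usage sketch, using only the methods and enums visible in the deleted file above (the API key value is illustrative):

    from isa_model.inference.base import ModelType

    provider = YYDSProvider(config={"api_key": "illustrative-key"})
    print(provider.get_models(ModelType.LLM))        # ["yyds-l", "yyds-xl", "yyds-xxl"]
    print(provider.get_config())                     # config copy with api_key masked as "***"
    print(provider.is_reasoning_model("yyds-xxl"))   # True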
--- a/isa_model/inference/services/audio/fish_speech/handler.py
+++ /dev/null
@@ -1,215 +0,0 @@
-import os
-import base64
-import tempfile
-import runpod
-import torch
-import torchaudio
-import sys
-import json
-from pathlib import Path
-from typing import Dict, Any, List, Union
-
-# Add Fish-Speech to the Python path
-sys.path.append('/app/fish-speech')
-
-# Import Fish-Speech modules
-from fish_speech.models.fish_speech.model import FishSpeech
-from fish_speech.models.fish_speech.config import FishSpeechConfig
-from fish_speech.utils.audio import load_audio, save_audio
-from fish_speech.utils.tokenizer import load_tokenizer
-
-# Load the model
-MODEL_PATH = "/app/models/fish_speech_v1.4.0.pth"
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Initialize the model
-def load_model():
-    print("Loading Fish-Speech model...")
-    config = FishSpeechConfig()
-    model = FishSpeech(config)
-
-    # Load the model weights
-    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
-    model.load_state_dict(checkpoint["model"])
-    model.eval()
-    model.to(DEVICE)
-
-    # Load the tokenizer
-    tokenizer = load_tokenizer()
-
-    print(f"Model loaded successfully on {DEVICE}")
-    return model, tokenizer
-
-model, tokenizer = load_model()
-
-# Download a file from a URL
-async def download_file(url: str) -> str:
-    import aiohttp
-    import os
-
-    # Create a temporary file
-    temp_dir = tempfile.mkdtemp()
-    local_filename = os.path.join(temp_dir, "reference_audio.wav")
-
-    # Download the file
-    async with aiohttp.ClientSession() as session:
-        async with session.get(url) as response:
-            if response.status != 200:
-                raise Exception(f"Failed to download file: {response.status}")
-
-            with open(local_filename, "wb") as f:
-                f.write(await response.read())
-
-    return local_filename
-
-# Generate speech using Fish-Speech
-def generate_speech(
-    text: str,
-    reference_audio_path: str = None,
-    language: str = "auto",
-    speed: float = 1.0,
-    gpt_cond_len: int = 30,
-    max_ref_len: int = 60,
-    enhance_audio: bool = True
-) -> str:
-    """
-    Generate speech using Fish-Speech
-
-    Args:
-        text: Text to convert to speech
-        reference_audio_path: Path to reference audio file for voice cloning
-        language: Language code (auto for auto-detection)
-        speed: Speech speed factor
-        gpt_cond_len: GPT conditioning length
-        max_ref_len: Maximum reference length
-        enhance_audio: Whether to enhance audio quality
-
-    Returns:
-        Path to the generated audio file
-    """
-    print(f"Generating speech for text: {text}")
-
-    # Load reference audio if provided
-    reference = None
-    if reference_audio_path:
-        print(f"Loading reference audio: {reference_audio_path}")
-        reference, sr = load_audio(reference_audio_path)
-        reference = reference.to(DEVICE)
-
-    # Generate speech
-    with torch.no_grad():
-        # Tokenize the text
-        tokens = tokenizer.encode(text, language=language)
-        tokens = torch.tensor(tokens, dtype=torch.long, device=DEVICE).unsqueeze(0)
-
-        # Generate speech
-        output = model.generate(
-            tokens,
-            reference=reference,
-            gpt_cond_latent_length=gpt_cond_len,
-            max_ref_length=max_ref_len,
-            top_k=50,
-            top_p=0.95,
-            temperature=0.7,
-            speed=speed
-        )
-
-    # Get the audio
-    audio = output["audio"]
-
-    # Enhance audio if requested
-    if enhance_audio:
-        # Apply simple normalization
-        audio = audio / torch.max(torch.abs(audio))
-
-    # Save the audio to a temporary file
-    temp_dir = tempfile.mkdtemp()
-    output_path = os.path.join(temp_dir, "output.wav")
-    save_audio(audio.cpu(), output_path, sample_rate=24000)
-
-    print(f"Speech generated successfully: {output_path}")
-    return output_path
-
-# RunPod handler
-def handler(event):
-    """
-    RunPod handler function
-
-    Args:
-        event: RunPod event object
-
-    Returns:
-        Dictionary with the generated audio
-    """
-    try:
-        # Get the input data
-        input_data = event.get("input", {})
-
-        # Extract parameters
-        text_data = input_data.get("text", [])
-        language = input_data.get("language", "auto")
-        speed = input_data.get("speed", 1.0)
-        gpt_cond_len = input_data.get("gpt_cond_len", 30)
-        max_ref_len = input_data.get("max_ref_len", 60)
-        enhance_audio = input_data.get("enhance_audio", True)
-        voice_data = input_data.get("voice", {})
-
-        # Validate input
-        if not text_data:
-            return {"error": "No text provided"}
-
-        # Process each text segment
-        results = []
-
-        # Download reference audio if provided
-        reference_audio_paths = {}
-        for speaker_id, url in voice_data.items():
-            if url:
-                # Use runpod.download_file for synchronous download
-                local_path = runpod.download_file(url)
-                reference_audio_paths[speaker_id] = local_path
-
-        # Process each text segment
-        for speaker_id, text in text_data:
-            # Get reference audio path
-            reference_path = reference_audio_paths.get(speaker_id)
-
-            # Generate speech
-            output_path = generate_speech(
-                text=text,
-                reference_audio_path=reference_path,
-                language=language,
-                speed=speed,
-                gpt_cond_len=gpt_cond_len,
-                max_ref_len=max_ref_len,
-                enhance_audio=enhance_audio
-            )
-
-            # Read the audio file and convert to base64
-            with open(output_path, "rb") as f:
-                audio_data = f.read()
-                audio_base64 = base64.b64encode(audio_data).decode("utf-8")
-
-            # Add to results
-            results.append({
-                "speaker_id": speaker_id,
-                "text": text,
-                "audio_base64": audio_base64
-            })
-
-        # Return the results
-        return {
-            "audio_base64": results[0]["audio_base64"] if results else None,
-            "results": results
-        }
-
-    except Exception as e:
-        import traceback
-        error_message = str(e)
-        stack_trace = traceback.format_exc()
-        print(f"Error: {error_message}")
-        print(f"Stack trace: {stack_trace}")
-        return {"error": error_message, "stack_trace": stack_trace}
-
-# Start the RunPod handler
-runpod.serverless.start({"handler": handler})
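The removed handler expected a RunPod event whose "input" carried (speaker_id, text) pairs plus optional per-speaker reference-audio URLs. A sketch of that payload, reconstructed from the handler's own input parsing (the text and URL values are illustrative):

    event = {
        "input": {
            "text": [["speaker_0", "Hello world"]],                 # (speaker_id, text) pairs
            "voice": {"speaker_0": "https://example.com/ref.wav"},  # reference audio per speaker
            "language": "auto",
            "speed": 1.0,
            "gpt_cond_len": 30,
            "max_ref_len": 60,
            "enhance_audio": True,
        }
    }
    result = handler(event)    # {"audio_base64": "...", "results": [...]}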
--- a/isa_model/inference/services/audio/runpod_tts_fish_service.py
+++ /dev/null
@@ -1,212 +0,0 @@
-import os
-import json
-import aiohttp
-import asyncio
-import base64
-import tempfile
-from typing import Dict, Any, Optional, Union, BinaryIO, List
-from ...base_tts_service import BaseTTSService
-from ...providers.runpod_provider import RunPodProvider
-
-class RunPodTTSFishService(BaseTTSService):
-    """
-    RunPod TTS service using Fish-Speech
-
-    This service uses the Fish-Speech TTS model deployed on RunPod to generate speech from text.
-    Fish-Speech is a state-of-the-art open-source TTS model that supports multilingual text-to-speech
-    and voice cloning capabilities.
-    """
-
-    def __init__(self, provider: RunPodProvider, model_name: str = "fish-speech"):
-        """
-        Initialize the RunPod TTS Fish service
-
-        Args:
-            provider: RunPod provider instance
-            model_name: Model name (default: "fish-speech")
-        """
-        super().__init__(provider, model_name)
-        self.api_key = self.config.get("api_key")
-        self.endpoint_id = self.config.get("endpoint_id")
-        self.base_url = self.config.get("base_url")
-
-        if not self.api_key:
-            raise ValueError("RunPod API key is required")
-
-        if not self.endpoint_id:
-            raise ValueError("RunPod endpoint ID is required")
-
-        self.endpoint_url = f"{self.base_url}/{self.endpoint_id}/run"
-        self.status_url = f"{self.base_url}/{self.endpoint_id}/status"
-        self.headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
-
-        # Default voice reference URLs (can be overridden in the options)
-        self.default_voices = {}
-
-    async def _run_inference(self, payload: Dict[str, Any]) -> Dict[str, Any]:
-        """
-        Run inference on the RunPod endpoint
-
-        Args:
-            payload: Request payload
-
-        Returns:
-            Response from the RunPod endpoint
-        """
-        async with aiohttp.ClientSession() as session:
-            async with session.post(
-                self.endpoint_url,
-                headers=self.headers,
-                json={"input": payload}
-            ) as response:
-                if response.status != 200:
-                    error_text = await response.text()
-                    raise Exception(f"RunPod API error: {response.status} - {error_text}")
-
-                result = await response.json()
-                job_id = result.get("id")
-
-                if not job_id:
-                    raise Exception("No job ID returned from RunPod API")
-
-            # Poll for job completion
-            while True:
-                async with session.get(
-                    f"{self.status_url}/{job_id}",
-                    headers=self.headers
-                ) as status_response:
-                    status_data = await status_response.json()
-                    status = status_data.get("status")
-
-                    if status == "COMPLETED":
-                        return status_data.get("output", {})
-                    elif status == "FAILED":
-                        error = status_data.get("error", "Unknown error")
-                        raise Exception(f"RunPod job failed: {error}")
-
-                # Wait before polling again
-                await asyncio.sleep(1)
-
-    async def generate_speech(
-        self,
-        text: str,
-        voice_id: Optional[str] = None,
-        language: Optional[str] = None,
-        speed: float = 1.0,
-        options: Optional[Dict[str, Any]] = None
-    ) -> bytes:
-        """
-        Generate speech from text using Fish-Speech
-
-        Args:
-            text: The text to convert to speech
-            voice_id: Voice identifier (URL to reference audio)
-            language: Language code (auto-detected if not provided)
-            speed: Speech speed factor (1.0 is normal speed)
-            options: Additional options:
-                - gpt_cond_len: GPT conditioning length
-                - max_ref_len: Maximum reference length
-                - enhance_audio: Whether to enhance audio quality
-
-        Returns:
-            Audio data as bytes
-        """
-        options = options or {}
-
-        # Prepare the payload
-        payload = {
-            "text": [[voice_id or "speaker_0", text]],
-            "language": language or "auto",
-            "speed": speed,
-            "gpt_cond_len": options.get("gpt_cond_len", 30),
-            "max_ref_len": options.get("max_ref_len", 60),
-            "enhance_audio": options.get("enhance_audio", True)
-        }
-
-        # Add voice reference
-        voice_url = None
-        if voice_id and voice_id.startswith(("http://", "https://")):
-            voice_url = voice_id
-        elif voice_id and voice_id in self.default_voices:
-            voice_url = self.default_voices[voice_id]
-
-        if voice_url:
-            payload["voice"] = {"speaker_0": voice_url}
-
-        # Run inference
-        result = await self._run_inference(payload)
-
-        # Extract audio data
-        if "audio_base64" in result:
-            return base64.b64decode(result["audio_base64"])
-        elif "audio_url" in result:
-            # Download audio from URL
-            async with aiohttp.ClientSession() as session:
-                async with session.get(result["audio_url"]) as response:
-                    if response.status != 200:
-                        raise Exception(f"Failed to download audio: {response.status}")
-                    return await response.read()
-        else:
-            raise Exception("No audio data in response")
-
-    async def save_to_file(
-        self,
-        text: str,
-        output_file: Union[str, BinaryIO],
-        voice_id: Optional[str] = None,
-        language: Optional[str] = None,
-        speed: float = 1.0,
-        options: Optional[Dict[str, Any]] = None
-    ) -> str:
-        """
-        Generate speech and save to file
-
-        Args:
-            text: The text to convert to speech
-            output_file: Path to output file or file-like object
-            voice_id: Voice identifier
-            language: Language code
-            speed: Speech speed factor
-            options: Additional options
-
-        Returns:
-            Path to the saved file
-        """
-        audio_data = await self.generate_speech(text, voice_id, language, speed, options)
-
-        if isinstance(output_file, str):
-            with open(output_file, "wb") as f:
-                f.write(audio_data)
-            return output_file
-        else:
-            output_file.write(audio_data)
-            if hasattr(output_file, "name"):
-                return output_file.name
-            return "audio.wav"
-
-    async def get_available_voices(self) -> Dict[str, Any]:
-        """
-        Get available voices for the TTS service
-
-        Returns:
-            Dictionary of available voices with their details
-        """
-        # Fish-Speech doesn't have a fixed set of voices as it uses voice cloning
-        # Return the default voices that have been configured
-        return {
-            "voices": list(self.default_voices.keys()),
-            "note": "Fish-Speech supports voice cloning. Provide a URL to a reference audio file to clone a voice."
-        }
-
-    def add_voice(self, voice_id: str, reference_url: str) -> None:
-        """
-        Add a voice to the default voices
-
-        Args:
-            voice_id: Voice identifier
-            reference_url: URL to the reference audio file
-        """
-        self.default_voices[voice_id] = reference_url
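A sketch of how this service was driven before removal, assuming RunPodProvider accepts a config dict carrying the api_key, endpoint_id, and base_url keys that the constructor above reads (all values illustrative):

    import asyncio

    async def main():
        provider = RunPodProvider(config={
            "api_key": "illustrative-key",
            "endpoint_id": "abc123",
            "base_url": "https://api.runpod.ai/v2",   # assumed endpoint base
        })
        service = RunPodTTSFishService(provider)
        service.add_voice("narrator", "https://example.com/ref.wav")
        path = await service.save_to_file("Hello!", "out.wav", voice_id="narrator")
        print(path)    # "out.wav"

    asyncio.run(main())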
--- a/isa_model/inference/services/audio/triton_speech_service.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import json
-import logging
-import asyncio
-import io
-import numpy as np
-from typing import Dict, List, Any, AsyncGenerator, Optional, Union, BinaryIO
-
-from isa_model.inference.services.base_service import BaseService
-from isa_model.inference.providers.triton_provider import TritonProvider
-
-logger = logging.getLogger(__name__)
-
-
-class TritonSpeechService(BaseService):
-    """
-    Speech service that uses Triton Inference Server to run speech-to-text inference.
-    """
-
-    def __init__(self, provider: TritonProvider, model_name: str):
-        """
-        Initialize the Triton Speech service.
-
-        Args:
-            provider: The Triton provider
-            model_name: Name of the model in Triton (e.g., "whisper_tiny")
-        """
-        super().__init__(provider, model_name)
-        self.client = None
-
-    async def _initialize_client(self):
-        """Initialize the Triton client"""
-        if self.client is None:
-            self.client = self.provider.create_client()
-
-            # Check if model is ready
-            if not self.provider.is_model_ready(self.model_name):
-                logger.error(f"Model {self.model_name} is not ready on Triton server")
-                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
-
-            logger.info(f"Initialized Triton client for speech model: {self.model_name}")
-
-    async def transcribe(self,
-                         audio: Union[str, BinaryIO, bytes, np.ndarray],
-                         language: str = "en",
-                         config: Optional[Dict[str, Any]] = None) -> str:
-        """
-        Transcribe audio to text using the Triton Inference Server.
-
-        Args:
-            audio: Audio input (file path, file-like object, bytes, or numpy array)
-            language: Language code (e.g., "en", "fr")
-            config: Additional configuration parameters
-
-        Returns:
-            Transcribed text
-        """
-        await self._initialize_client()
-
-        try:
-            import tritonclient.http as httpclient
-
-            # Process audio to get numpy array
-            audio_array = await self._process_audio_input(audio)
-
-            # Create input tensors for audio
-            audio_input = httpclient.InferInput("audio_input", audio_array.shape, "FP32")
-            audio_input.set_data_from_numpy(audio_array)
-            inputs = [audio_input]
-
-            # Add language input
-            language_data = np.array([language], dtype=np.object_)
-            language_input = httpclient.InferInput("language", language_data.shape, "BYTES")
-            language_input.set_data_from_numpy(language_data)
-            inputs.append(language_input)
-
-            # Create output tensor
-            outputs = [httpclient.InferRequestedOutput("text_output")]
-
-            # Send the request
-            response = await asyncio.to_thread(
-                self.client.infer,
-                self.model_name,
-                inputs,
-                outputs=outputs
-            )
-
-            # Process the response
-            output = response.as_numpy("text_output")
-            transcription = output[0].decode('utf-8')
-
-            return transcription
-
-        except Exception as e:
-            logger.error(f"Error during Triton speech inference: {str(e)}")
-            raise
-
-    async def _process_audio_input(self, audio: Union[str, BinaryIO, bytes, np.ndarray]) -> np.ndarray:
-        """
-        Process different types of audio inputs into a numpy array.
-
-        Args:
-            audio: Audio input (file path, file-like object, bytes, or numpy array)
-
-        Returns:
-            Numpy array of the audio
-        """
-        if isinstance(audio, np.ndarray):
-            return audio
-
-        try:
-            import librosa
-
-            if isinstance(audio, str):
-                # File path
-                y, sr = librosa.load(audio, sr=16000)  # Whisper expects 16kHz audio
-                return y.astype(np.float32)
-
-            elif isinstance(audio, (io.IOBase, BinaryIO)):
-                # File-like object
-                audio.seek(0)
-                y, sr = librosa.load(audio, sr=16000)
-                return y.astype(np.float32)
-
-            elif isinstance(audio, bytes):
-                # Bytes
-                with io.BytesIO(audio) as audio_bytes:
-                    y, sr = librosa.load(audio_bytes, sr=16000)
-                    return y.astype(np.float32)
-
-            else:
-                raise ValueError(f"Unsupported audio type: {type(audio)}")
-
-        except ImportError:
-            logger.error("librosa not installed. Please install with: pip install librosa")
-            raise
-        except Exception as e:
-            logger.error(f"Error processing audio: {str(e)}")
-            raise
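A sketch of a transcription call against this removed service, assuming a TritonProvider pointed at a running Triton server that serves the whisper_tiny model named in the docstring (the provider config key is illustrative):

    import asyncio

    async def main():
        provider = TritonProvider(config={"url": "localhost:8000"})  # assumed config shape
        service = TritonSpeechService(provider, model_name="whisper_tiny")
        text = await service.transcribe("speech.wav", language="en")
        print(text)

    asyncio.run(main())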