isa-model 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +359 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +101 -0
- isa_model/inference/providers/replicate_provider.py +107 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/openai_tts_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
- isa_model/inference/services/embedding/openai_embed_service.py +0 -0
- isa_model/inference/services/llm/__init__.py +12 -0
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +99 -0
- isa_model/inference/services/llm/openai_llm_service.py +138 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/openai_vision_service.py +80 -0
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.0.1.dist-info/METADATA +327 -0
- isa_model-0.0.1.dist-info/RECORD +86 -0
- isa_model-0.0.1.dist-info/WHEEL +5 -0
- isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
- isa_model-0.0.1.dist-info/top_level.txt +1 -0
isa_model/inference/providers/triton_provider.py
@@ -0,0 +1,439 @@
import os
import logging
import json
import numpy as np
import base64
from typing import Dict, Any, Optional, List, Union

from isa_model.inference.providers.base_provider import BaseProvider
from isa_model.inference.base import ModelType, Capability
from isa_model.inference.providers.model_cache_manager import ModelCacheManager
import asyncio

# Set up logging
logger = logging.getLogger(__name__)

class TritonProvider(BaseProvider):
    """
    Provider for Triton Inference Server models.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the Triton provider.

        Args:
            config: Configuration for the provider
        """
        super().__init__(config or {})

        # Default configuration
        self.default_config = {
            "server_url": os.environ.get("TRITON_SERVER_URL", "http://localhost:8000"),
            "model_repository": os.environ.get(
                "MODEL_REPOSITORY",
                os.path.join(os.getcwd(), "models/triton/model_repository")
            ),
            "http_headers": {},
            "verbose": True,
            "client_timeout": 300.0,  # 5-minute timeout
            "max_batch_size": 8,
            "max_sequence_length": 2048,
            "temperature": 0.7,
            "top_p": 0.9,
            "model_cache_size": 5,  # LRU cache size
            "tokenizer_name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        }

        # Merge provided config with defaults
        self.config = {**self.default_config, **self.config}

        # Set up logging
        log_level = self.config.get("log_level", "INFO")
        numeric_level = getattr(logging, log_level.upper(), logging.INFO)
        logger.setLevel(numeric_level)

        logger.info(f"Initialized Triton provider with URL: {self.config['server_url']}")

        # Initialize model cache manager
        self.model_cache = ModelCacheManager(
            cache_size=self.config.get("model_cache_size"),
            model_repository=self.config.get("model_repository")
        )

        # For MLflow Gateway compatibility; read from the merged config so a
        # None `config` argument does not raise AttributeError
        self.triton_url = self.config.get("triton_url", "localhost:8001")

    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
        """Get provider capabilities by model type"""
        return {
            ModelType.LLM: [
                Capability.CHAT,
                Capability.COMPLETION
            ],
            ModelType.EMBEDDING: [
                Capability.EMBEDDING
            ],
            ModelType.VISION: [
                Capability.IMAGE_UNDERSTANDING
            ]
        }

    def get_models(self, model_type: ModelType) -> List[str]:
        """Get available models for the given type"""
        # Query the model cache manager for available models
        return self.model_cache.list_available_models(model_type)

    async def load_model(self, model_name: str, model_type: ModelType) -> bool:
        """Load a model into the Triton server via the model cache manager"""
        return await self.model_cache.load_model(model_name, model_type)

    async def unload_model(self, model_name: str) -> bool:
        """Unload a model from the Triton server"""
        return await self.model_cache.unload_model(model_name)

    def get_config(self) -> Dict[str, Any]:
        """
        Get the configuration for this provider.

        Returns:
            Provider configuration
        """
        return self.config

    def create_client(self):
        """
        Create a Triton client instance.

        Returns:
            Triton HTTP client
        """
        try:
            import tritonclient.http as httpclient

            server_url = self.config.get("triton_url", self.config["server_url"])

            client = httpclient.InferenceServerClient(
                url=server_url,
                verbose=self.config["verbose"],
                connection_timeout=self.config["client_timeout"],
                network_timeout=self.config["client_timeout"]
            )

            return client
        except ImportError:
            logger.error("tritonclient package not installed. Please install it with: pip install tritonclient[http]")
            raise
        except Exception as e:
            logger.error(f"Error creating Triton client: {str(e)}")
            raise

    def is_server_live(self) -> bool:
        """
        Check if the Triton server is live.

        Returns:
            True if the server is live, False otherwise
        """
        try:
            client = self.create_client()
            return client.is_server_live()
        except Exception as e:
            logger.error(f"Error checking server liveness: {str(e)}")
            return False

    def is_model_ready(self, model_name: str) -> bool:
        """
        Check if a model is ready on the Triton server.

        Args:
            model_name: Name of the model

        Returns:
            True if the model is ready, False otherwise
        """
        try:
            client = self.create_client()
            return client.is_model_ready(model_name)
        except Exception as e:
            logger.error(f"Error checking model readiness: {str(e)}")
            return False

    def get_model_metadata(self, model_name: str) -> Dict[str, Any]:
        """
        Get metadata for a model.

        Args:
            model_name: Name of the model

        Returns:
            Model metadata
        """
        try:
            client = self.create_client()
            metadata = client.get_model_metadata(model_name)
            return metadata
        except Exception as e:
            logger.error(f"Error getting model metadata: {str(e)}")
            raise

    def get_model_config(self, model_name: str) -> Dict[str, Any]:
        """
        Get configuration for a model.

        Args:
            model_name: Name of the model

        Returns:
            Model configuration
        """
        try:
            client = self.create_client()
            config = client.get_model_config(model_name)
            return config
        except Exception as e:
            logger.error(f"Error getting model config: {str(e)}")
            raise

    def is_reasoning_model(self, model_name: str) -> bool:
        """Check if the model is optimized for reasoning tasks"""
        # Simple heuristic; could be enhanced to check model metadata
        return "reasoning" in model_name.lower() or model_name.lower() in ["llama3", "mistral"]

    # Methods for MLflow Gateway compatibility

    async def completions(self, prompt: str, model_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Generate completions for MLflow Gateway.

        Args:
            prompt: User prompt text
            model_name: Name of the model to use
            params: Additional parameters

        Returns:
            Completion response
        """
        try:
            import tritonclient.http as httpclient

            # Create client
            client = self.create_client()

            # Generation config
            generation_config = {
                "temperature": params.get("temperature", 0.7),
                "max_new_tokens": params.get("max_tokens", 512),
                "top_p": params.get("top_p", 0.9),
                "top_k": params.get("top_k", 50),
            }

            # Prepare inputs
            inputs = []

            # Add prompt input
            prompt_data = np.array([prompt], dtype=np.object_)
            prompt_input = httpclient.InferInput("prompt", prompt_data.shape, "BYTES")
            prompt_input.set_data_from_numpy(prompt_data)
            inputs.append(prompt_input)

            # Add system prompt if provided
            if "system_prompt" in params:
                system_data = np.array([params["system_prompt"]], dtype=np.object_)
                system_input = httpclient.InferInput("system_prompt", system_data.shape, "BYTES")
                system_input.set_data_from_numpy(system_data)
                inputs.append(system_input)

            # Add generation config
            config_data = np.array([json.dumps(generation_config)], dtype=np.object_)
            config_input = httpclient.InferInput("generation_config", config_data.shape, "BYTES")
            config_input.set_data_from_numpy(config_data)
            inputs.append(config_input)

            # Create output
            outputs = [httpclient.InferRequestedOutput("text_output")]

            # Run the blocking inference call in a worker thread so the event
            # loop is not blocked
            response = await asyncio.to_thread(
                client.infer,
                model_name,
                inputs,
                outputs=outputs
            )

            # Process response
            output = response.as_numpy("text_output")
            text = output[0].decode('utf-8')

            return {
                "completion": text,
                "metadata": {
                    "model": model_name,
                    "provider": "triton",
                    # Whitespace splitting is only a rough token estimate
                    "token_usage": {
                        "prompt_tokens": len(prompt.split()),
                        "completion_tokens": len(text.split()),
                        "total_tokens": len(prompt.split()) + len(text.split())
                    }
                }
            }

        except Exception as e:
            logger.error(f"Error during completion: {str(e)}")
            return {
                "error": str(e),
                "metadata": {
                    "model": model_name,
                    "provider": "triton"
                }
            }

    async def embeddings(self, text: Union[str, List[str]], model_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Generate embeddings for MLflow Gateway.

        Args:
            text: Text or list of texts to embed
            model_name: Name of the model to use
            params: Additional parameters

        Returns:
            Embedding response
        """
        try:
            import tritonclient.http as httpclient

            # Create client
            client = self.create_client()

            # Normalize parameter
            normalize = params.get("normalize", True)

            # Handle input text (convert to list if it's a single string)
            text_list = [text] if isinstance(text, str) else text

            # Add text input
            text_data = np.array(text_list, dtype=np.object_)
            text_input = httpclient.InferInput("text_input", text_data.shape, "BYTES")
            text_input.set_data_from_numpy(text_data)

            # Add normalize parameter
            normalize_data = np.array([normalize], dtype=bool)
            normalize_input = httpclient.InferInput("normalize", normalize_data.shape, "BOOL")
            normalize_input.set_data_from_numpy(normalize_data)

            # Create inputs
            inputs = [text_input, normalize_input]

            # Create output
            outputs = [httpclient.InferRequestedOutput("embedding")]

            # Run inference
            response = await asyncio.to_thread(
                client.infer,
                model_name,
                inputs,
                outputs=outputs
            )

            # Process response
            embedding_output = response.as_numpy("embedding")

            return {
                "embedding": embedding_output.tolist(),
                "metadata": {
                    "model": model_name,
                    "provider": "triton",
                    "dimensions": embedding_output.shape[-1]
                }
            }

        except Exception as e:
            logger.error(f"Error during embedding: {str(e)}")
            return {
                "error": str(e),
                "metadata": {
                    "model": model_name,
                    "provider": "triton"
                }
            }

    async def speech_to_text(self, audio: str, model_name: str, params: Dict[str, Any]) -> Dict[str, Any]:
        """
        Transcribe audio for MLflow Gateway.

        Args:
            audio: Base64-encoded audio data or URL
            model_name: Name of the model to use
            params: Additional parameters

        Returns:
            Transcription response
        """
        try:
            import tritonclient.http as httpclient

            # Create client
            client = self.create_client()

            # Decode audio from base64 or download from URL
            if audio.startswith(("http://", "https://")):
                import requests
                audio_data = requests.get(audio).content
            else:
                audio_data = base64.b64decode(audio)

            # Language parameter
            language = params.get("language", "en")

            # Process audio to get a numpy array, resampled to 16 kHz
            import io
            import librosa

            with io.BytesIO(audio_data) as audio_bytes:
                audio_array, _ = librosa.load(audio_bytes, sr=16000)
                audio_array = audio_array.astype(np.float32)

            # Create inputs
            audio_input = httpclient.InferInput("audio_input", audio_array.shape, "FP32")
            audio_input.set_data_from_numpy(audio_array)

            language_data = np.array([language], dtype=np.object_)
            language_input = httpclient.InferInput("language", language_data.shape, "BYTES")
            language_input.set_data_from_numpy(language_data)

            inputs = [audio_input, language_input]

            # Create output
            outputs = [httpclient.InferRequestedOutput("text_output")]

            # Run inference
            response = await asyncio.to_thread(
                client.infer,
                model_name,
                inputs,
                outputs=outputs
            )

            # Process response
            output = response.as_numpy("text_output")
            transcription = output[0].decode('utf-8')

            return {
                "text": transcription,
                "metadata": {
                    "model": model_name,
                    "provider": "triton",
                    "language": language
                }
            }

        except Exception as e:
            logger.error(f"Error during speech-to-text: {str(e)}")
            return {
                "error": str(e),
                "metadata": {
                    "model": model_name,
                    "provider": "triton"
                }
            }
isa_model/inference/services/__init__.py
@@ -0,0 +1,14 @@
"""
Services - Service implementations for different model types

File: isa_model/inference/services/__init__.py

This module contains service implementations for different AI model types.
"""

from .base_service import BaseService, BaseLLMService, BaseEmbeddingService

__all__ = [
    "BaseService",
    "BaseLLMService",
    "BaseEmbeddingService"
]
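
These re-exports let call sites import the base classes from the services package directly; for example:

# Equivalent to: from isa_model.inference.services.base_service import BaseLLMService
from isa_model.inference.services import BaseLLMService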
isa_model/inference/services/audio/base_stt_service.py
@@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Union, Optional, BinaryIO
from isa_model.inference.services.base_service import BaseService

class BaseSTTService(BaseService):
    """Base class for Speech-to-Text services"""

    @abstractmethod
    async def transcribe_audio(
        self,
        audio_file: Union[str, BinaryIO],
        language: Optional[str] = None,
        prompt: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Transcribe an audio file to text

        Args:
            audio_file: Path to audio file or file-like object
            language: Language code (e.g., 'en', 'es', 'fr')
            prompt: Optional prompt to guide transcription

        Returns:
            Dict containing transcription results with keys:
            - text: The transcribed text
            - language: Detected/specified language
            - confidence: Confidence score (if available)
            - segments: Time-segmented transcription (if available)
        """
        pass

    @abstractmethod
    async def transcribe_audio_batch(
        self,
        audio_files: List[Union[str, BinaryIO]],
        language: Optional[str] = None,
        prompt: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """
        Transcribe multiple audio files

        Args:
            audio_files: List of audio file paths or file-like objects
            language: Language code (e.g., 'en', 'es', 'fr')
            prompt: Optional prompt to guide transcription

        Returns:
            List of transcription results
        """
        pass

    @abstractmethod
    async def detect_language(self, audio_file: Union[str, BinaryIO]) -> Dict[str, Any]:
        """
        Detect the language of an audio file

        Args:
            audio_file: Path to audio file or file-like object

        Returns:
            Dict containing language detection results with keys:
            - language: Detected language code
            - confidence: Confidence score
            - alternatives: List of alternative languages with scores
        """
        pass

    @abstractmethod
    def get_supported_formats(self) -> List[str]:
        """
        Get list of supported audio formats

        Returns:
            List of supported file extensions (e.g., ['mp3', 'wav', 'flac'])
        """
        pass

    @abstractmethod
    def get_supported_languages(self) -> List[str]:
        """
        Get list of supported language codes

        Returns:
            List of supported language codes (e.g., ['en', 'es', 'fr'])
        """
        pass

    @abstractmethod
    async def close(self):
        """Clean up resources"""
        pass
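
To make the abstract contract concrete, here is a hypothetical stub subclass. The class name and canned return values are illustrative and not part of the package, and the inherited BaseService constructor (not shown in this diff) is left untouched:

from typing import Dict, Any, List

from isa_model.inference.services.audio.base_stt_service import BaseSTTService

class EchoSTTService(BaseSTTService):
    """Illustrative stub that satisfies every abstract method with canned data."""

    async def transcribe_audio(self, audio_file, language=None, prompt=None) -> Dict[str, Any]:
        return {"text": "(stub transcription)", "language": language or "en",
                "confidence": None, "segments": []}

    async def transcribe_audio_batch(self, audio_files, language=None, prompt=None) -> List[Dict[str, Any]]:
        # Reuse the single-file path for each entry
        return [await self.transcribe_audio(f, language, prompt) for f in audio_files]

    async def detect_language(self, audio_file) -> Dict[str, Any]:
        return {"language": "en", "confidence": 1.0, "alternatives": []}

    def get_supported_formats(self) -> List[str]:
        return ["mp3", "wav", "flac"]

    def get_supported_languages(self) -> List[str]:
        return ["en"]

    async def close(self):
        pass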
isa_model/inference/services/audio/base_tts_service.py
@@ -0,0 +1,136 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Union, Optional, BinaryIO
from isa_model.inference.services.base_service import BaseService

class BaseTTSService(BaseService):
    """Base class for Text-to-Speech services"""

    @abstractmethod
    async def synthesize_speech(
        self,
        text: str,
        voice: Optional[str] = None,
        speed: float = 1.0,
        pitch: float = 1.0,
        format: str = "mp3"
    ) -> Dict[str, Any]:
        """
        Synthesize speech from text

        Args:
            text: Input text to convert to speech
            voice: Voice ID or name to use
            speed: Speech speed multiplier (0.5-2.0)
            pitch: Pitch adjustment (-1.0 to 1.0)
            format: Audio format ('mp3', 'wav', 'ogg')

        Returns:
            Dict containing synthesis results with keys:
            - audio_data: Binary audio data
            - format: Audio format
            - duration: Audio duration in seconds
            - sample_rate: Audio sample rate
        """
        pass

    @abstractmethod
    async def synthesize_speech_to_file(
        self,
        text: str,
        output_path: str,
        voice: Optional[str] = None,
        speed: float = 1.0,
        pitch: float = 1.0,
        format: str = "mp3"
    ) -> Dict[str, Any]:
        """
        Synthesize speech and save it directly to a file

        Args:
            text: Input text to convert to speech
            output_path: Path to save the audio file
            voice: Voice ID or name to use
            speed: Speech speed multiplier (0.5-2.0)
            pitch: Pitch adjustment (-1.0 to 1.0)
            format: Audio format ('mp3', 'wav', 'ogg')

        Returns:
            Dict containing synthesis results with keys:
            - file_path: Path to saved audio file
            - duration: Audio duration in seconds
            - sample_rate: Audio sample rate
        """
        pass

    @abstractmethod
    async def synthesize_speech_batch(
        self,
        texts: List[str],
        voice: Optional[str] = None,
        speed: float = 1.0,
        pitch: float = 1.0,
        format: str = "mp3"
    ) -> List[Dict[str, Any]]:
        """
        Synthesize speech for multiple texts

        Args:
            texts: List of input texts to convert to speech
            voice: Voice ID or name to use
            speed: Speech speed multiplier (0.5-2.0)
            pitch: Pitch adjustment (-1.0 to 1.0)
            format: Audio format ('mp3', 'wav', 'ogg')

        Returns:
            List of synthesis result dictionaries
        """
        pass

    @abstractmethod
    def get_available_voices(self) -> List[Dict[str, Any]]:
        """
        Get list of available voices

        Returns:
            List of voice information dictionaries with keys:
            - id: Voice identifier
            - name: Human-readable voice name
            - language: Language code (e.g., 'en-US', 'es-ES')
            - gender: Voice gender ('male', 'female', 'neutral')
            - age: Voice age category ('adult', 'child', 'elderly')
        """
        pass

    @abstractmethod
    def get_supported_formats(self) -> List[str]:
        """
        Get list of supported audio formats

        Returns:
            List of supported file extensions (e.g., ['mp3', 'wav', 'ogg'])
        """
        pass

    @abstractmethod
    def get_voice_info(self, voice_id: str) -> Dict[str, Any]:
        """
        Get detailed information about a specific voice

        Args:
            voice_id: Voice identifier

        Returns:
            Dict containing voice information:
            - id: Voice identifier
            - name: Human-readable voice name
            - language: Language code
            - gender: Voice gender
            - description: Voice description
            - sample_rate: Default sample rate
        """
        pass

    @abstractmethod
    async def close(self):
        """Clean up resources"""
        pass
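
And a hedged usage sketch for the TTS interface, written against any concrete BaseTTSService implementation passed in as `tts`; no specific backend is assumed:

async def speak(tts) -> None:
    """Synthesize a greeting with whatever BaseTTSService implementation is passed in."""
    voices = tts.get_available_voices()
    voice_id = voices[0]["id"] if voices else None  # fall back to the backend default

    result = await tts.synthesize_speech_to_file(
        text="Hello from isa_model.",
        output_path="hello.mp3",
        voice=voice_id,
        speed=1.0,
        pitch=1.0,
        format="mp3",
    )
    print(f"Wrote {result['file_path']} ({result['duration']:.1f}s at {result['sample_rate']} Hz)")

    await tts.close()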