isa-model 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +354 -0
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model/inference/backends/__init__.py +53 -0
- isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model/inference/backends/container_services.py +104 -0
- isa_model/inference/backends/local_services.py +72 -0
- isa_model/inference/backends/openai_client.py +130 -0
- isa_model/inference/backends/replicate_client.py +197 -0
- isa_model/inference/backends/third_party_services.py +239 -0
- isa_model/inference/backends/triton_client.py +97 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model/inference/client_sdk/client.py +134 -0
- isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +87 -0
- isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/base_tts_service.py +66 -0
- isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model/inference/services/llm/__init__.py +16 -0
- isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0.dist-info/METADATA +116 -0
- isa_model-0.1.0.dist-info/RECORD +117 -0
- isa_model-0.1.0.dist-info/WHEEL +5 -0
- isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
- isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/inference/services/audio/whisper_service.py
@@ -0,0 +1,186 @@
+import logging
+import asyncio
+import io
+import numpy as np
+from typing import Dict, Any, Optional, Union, BinaryIO
+
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.backends.triton_client import TritonClient
+
+logger = logging.getLogger(__name__)
+
+
+class WhisperService(BaseService):
+    """
+    Service for Whisper speech-to-text using Triton Inference Server.
+    """
+
+    def __init__(self, triton_url: str = "localhost:8001", model_name: str = "whisper"):
+        """
+        Initialize the Whisper service.
+
+        Args:
+            triton_url: URL of the Triton Inference Server
+            model_name: Name of the model in Triton
+        """
+        super().__init__()
+        self.triton_url = triton_url
+        self.model_name = model_name
+        self.client = None
+
+        # Default configuration
+        self.default_config = {
+            "language": "en",
+            "sampling_rate": 16000
+        }
+
+        self.logger = logger
+
+    async def load(self) -> None:
+        """
+        Load the client connection to Triton.
+        """
+        if self.is_loaded():
+            return
+
+        try:
+            from tritonclient.http import InferenceServerClient
+
+            # Create Triton client
+            self.logger.info(f"Connecting to Triton server at {self.triton_url}")
+            self.client = TritonClient(self.triton_url)
+
+            # Check if model is ready
+            if not await self.client.is_model_ready(self.model_name):
+                self.logger.error(f"Model {self.model_name} is not ready on Triton server")
+                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
+
+            self._loaded = True
+            self.logger.info(f"Connected to Triton for model {self.model_name}")
+
+        except Exception as e:
+            self.logger.error(f"Failed to connect to Triton: {str(e)}")
+            raise
+
+    async def unload(self) -> None:
+        """
+        Unload the client connection.
+        """
+        if not self.is_loaded():
+            return
+
+        self.client = None
+        self._loaded = False
+        self.logger.info("Triton client connection closed")
+
+    async def transcribe(self,
+                         audio: Union[str, BinaryIO, bytes, np.ndarray],
+                         language: str = "en",
+                         config: Optional[Dict[str, Any]] = None) -> str:
+        """
+        Transcribe audio to text using Triton.
+
+        Args:
+            audio: Audio input (file path, file-like object, bytes, or numpy array)
+            language: Language code (e.g., "en", "fr")
+            config: Additional configuration parameters
+
+        Returns:
+            Transcribed text
+        """
+        if not self.is_loaded():
+            await self.load()
+
+        # Process audio to get numpy array
+        audio_array = await self._process_audio_input(audio)
+
+        # Get configuration
+        merged_config = self.default_config.copy()
+        if config:
+            merged_config.update(config)
+
+        # Override language if provided
+        if language:
+            merged_config["language"] = language
+
+        try:
+            # Prepare inputs
+            inputs = {
+                "audio_input": audio_array,
+                "language": np.array([merged_config["language"]], dtype=np.object_)
+            }
+
+            # Run inference
+            result = await self.client.infer(
+                model_name=self.model_name,
+                inputs=inputs,
+                outputs=["text_output"]
+            )
+
+            # Extract transcription
+            transcription = result["text_output"][0].decode('utf-8')
+
+            return transcription
+
+        except Exception as e:
+            self.logger.error(f"Error during Whisper transcription: {str(e)}")
+            raise
+
+    async def _process_audio_input(self, audio: Union[str, BinaryIO, bytes, np.ndarray]) -> np.ndarray:
+        """
+        Process different types of audio inputs into a numpy array.
+
+        Args:
+            audio: Audio input (file path, file-like object, bytes, or numpy array)
+
+        Returns:
+            Numpy array of the audio
+        """
+        if isinstance(audio, np.ndarray):
+            return audio
+
+        try:
+            import librosa
+
+            if isinstance(audio, str):
+                # File path
+                y, sr = librosa.load(audio, sr=self.default_config["sampling_rate"])
+                return y.astype(np.float32)
+
+            elif isinstance(audio, (io.IOBase, BinaryIO)):
+                # File-like object
+                audio.seek(0)
+                y, sr = librosa.load(audio, sr=self.default_config["sampling_rate"])
+                return y.astype(np.float32)
+
+            elif isinstance(audio, bytes):
+                # Bytes
+                with io.BytesIO(audio) as audio_bytes:
+                    y, sr = librosa.load(audio_bytes, sr=self.default_config["sampling_rate"])
+                    return y.astype(np.float32)
+
+            else:
+                raise ValueError(f"Unsupported audio type: {type(audio)}")
+
+        except ImportError:
+            self.logger.error("librosa not installed. Please install with: pip install librosa")
+            raise
+        except Exception as e:
+            self.logger.error(f"Error processing audio: {str(e)}")
+            raise
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """
+        Get information about the model.
+
+        Returns:
+            Dictionary containing model information
+        """
+        return {
+            "name": self.model_name,
+            "type": "speech",
+            "backend": "triton",
+            "url": self.triton_url,
+            "loaded": self.is_loaded(),
+            "config": self.default_config
+        }
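
A minimal usage sketch for the WhisperService above, assuming a running Triton server that serves a "whisper" model with the audio_input/language/text_output tensors shown in this hunk; the server address and audio path are placeholders. Note that the BaseService defined later in this diff takes (provider, model_name) and does not define is_loaded(), so the sketch also assumes a compatible base class.

# Hedged usage sketch; addresses and file names are illustrative only.
import asyncio

from isa_model.inference.services.audio.whisper_service import WhisperService

async def main():
    service = WhisperService(triton_url="localhost:8001", model_name="whisper")
    await service.load()                        # connects and checks model readiness
    text = await service.transcribe("sample.wav", language="en")
    print(text)
    print(service.get_model_info())             # metadata dict returned by the service
    await service.unload()

asyncio.run(main())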
isa_model/inference/services/audio/yyds_audio_service.py
@@ -0,0 +1,71 @@
+from typing import Dict, Any
+import tempfile
+import os
+from openai import AsyncOpenAI
+from tenacity import retry, stop_after_attempt, wait_exponential
+from ...base_service import BaseService
+from ...base_provider import BaseProvider
+from app.config.config_manager import config_manager
+
+logger = config_manager.get_logger(__name__)
+
+class YYDSAudioService(BaseService):
+    """Audio model service wrapper for YYDS"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str):
+        super().__init__(provider, model_name)
+        # Initialize the AsyncOpenAI client
+        self._client = AsyncOpenAI(
+            api_key=self.config.get('api_key'),
+            base_url=self.config.get('base_url')
+        )
+        self.language = self.config.get('language', None)
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        """Return the underlying OpenAI client"""
+        return self._client
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        reraise=True
+    )
+    async def transcribe(self, audio_data: bytes) -> Dict[str, Any]:
+        """Transcribe audio data
+
+        Args:
+            audio_data: Raw audio bytes
+
+        Returns:
+            Dict[str, Any]: Dictionary containing the transcribed text
+        """
+        try:
+            # Write the audio data to a temporary file
+            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
+                temp_file.write(audio_data)
+                temp_file.flush()
+
+                # Open the file in binary mode for the API request
+                with open(temp_file.name, 'rb') as audio_file:
+                    # Only include the language parameter when a valid ISO-639-1 code is set
+                    params = {
+                        'model': self.model_name,
+                        'file': audio_file,
+                    }
+                    if self.language and isinstance(self.language, str):
+                        params['language'] = self.language
+
+                    response = await self._client.audio.transcriptions.create(**params)
+
+                # Clean up the temporary file
+                os.unlink(temp_file.name)
+
+                # Return a dictionary containing the transcribed text
+                return {
+                    "text": response.text
+                }
+
+        except Exception as e:
+            logger.error(f"Error in audio transcription: {e}")
+            raise
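
A minimal usage sketch for YYDSAudioService, assuming a provider whose get_config() supplies 'api_key' and 'base_url' for an OpenAI-compatible endpoint; the model name and audio file are placeholders, and the concrete provider class is not shown in this diff, so it is passed in rather than imported.

# Hedged usage sketch; the provider object and model name are assumptions.
import asyncio

from isa_model.inference.services.audio.yyds_audio_service import YYDSAudioService

async def transcribe_file(provider, path: str) -> str:
    service = YYDSAudioService(provider, model_name="whisper-1")  # placeholder model name
    with open(path, "rb") as f:
        result = await service.transcribe(f.read())               # retried up to 3 times via tenacity
    return result["text"]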
isa_model/inference/services/base_service.py
@@ -0,0 +1,106 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
+from isa_model.inference.providers.base_provider import BaseProvider
+
+T = TypeVar('T')  # Generic type for responses
+
+class BaseService(ABC):
+    """Base class for all AI services"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str):
+        self.provider = provider
+        self.model_name = model_name
+        self.config = provider.get_config()
+
+    def __await__(self):
+        """Make the service awaitable"""
+        yield
+        return self
+
+class BaseLLMService(BaseService):
+    """Base class for LLM services"""
+
+    @abstractmethod
+    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
+        """Universal invocation method"""
+        pass
+
+    @abstractmethod
+    async def achat(self, messages: List[Dict[str, str]]) -> T:
+        """Chat completion method"""
+        pass
+
+    @abstractmethod
+    async def acompletion(self, prompt: str) -> T:
+        """Text completion method"""
+        pass
+
+    @abstractmethod
+    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
+        """Generate multiple completions"""
+        pass
+
+    @abstractmethod
+    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
+        """Stream chat responses"""
+        pass
+
+    @abstractmethod
+    def get_token_usage(self) -> Any:
+        """Get total token usage statistics"""
+        pass
+
+    @abstractmethod
+    def get_last_token_usage(self) -> Dict[str, int]:
+        """Get token usage from last request"""
+        pass
+
+class BaseEmbeddingService(BaseService):
+    """Base class for embedding services"""
+
+    @abstractmethod
+    async def create_text_embedding(self, text: str) -> List[float]:
+        """Create embedding for single text"""
+        pass
+
+    @abstractmethod
+    async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Create embeddings for multiple texts"""
+        pass
+
+    @abstractmethod
+    async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+        """Create text chunks with embeddings"""
+        pass
+
+    @abstractmethod
+    async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
+        """Compute similarity between two embeddings"""
+        pass
+
+    @abstractmethod
+    async def close(self):
+        """Cleanup resources"""
+        pass
+
+class BaseRerankService(BaseService):
+    """Base class for reranking services"""
+
+    @abstractmethod
+    async def rerank(
+        self,
+        query: str,
+        documents: List[Dict],
+        top_k: int = 5
+    ) -> List[Dict]:
+        """Rerank documents based on query relevance"""
+        pass
+
+    @abstractmethod
+    async def rerank_texts(
+        self,
+        query: str,
+        texts: List[str]
+    ) -> List[Dict]:
+        """Rerank raw texts based on query relevance"""
+        pass
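
To make the abstract surface above concrete, the sketch below shows an illustrative, non-shipping subclass of BaseLLMService; the package's real implementations live under isa_model/inference/services/llm/ (e.g. ollama_llm_service.py, openai_llm_service.py), and this toy class only exists to show which methods a concrete service must provide.

# Illustrative sketch only; not part of the wheel.
from typing import Any, Dict, List, AsyncGenerator

from isa_model.inference.services.base_service import BaseLLMService

class EchoLLMService(BaseLLMService):
    """Toy service that echoes its input, showing the abstract API."""

    async def ainvoke(self, prompt) -> str:
        return prompt if isinstance(prompt, str) else str(prompt)

    async def achat(self, messages: List[Dict[str, str]]) -> str:
        return messages[-1]["content"]

    async def acompletion(self, prompt: str) -> str:
        return prompt

    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[str]:
        return [messages[-1]["content"]] * n

    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
        for token in messages[-1]["content"].split():
            yield token

    def get_token_usage(self) -> Any:
        return {"total_tokens": 0}

    def get_last_token_usage(self) -> Dict[str, int]:
        return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}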
isa_model/inference/services/base_tts_service.py
@@ -0,0 +1,66 @@
+from abc import abstractmethod
+from typing import Dict, Any, Optional, Union, BinaryIO
+from .base_service import BaseService
+
+class BaseTTSService(BaseService):
+    """Base class for Text-to-Speech services"""
+
+    @abstractmethod
+    async def generate_speech(
+        self,
+        text: str,
+        voice_id: Optional[str] = None,
+        language: Optional[str] = None,
+        speed: float = 1.0,
+        options: Optional[Dict[str, Any]] = None
+    ) -> bytes:
+        """
+        Generate speech from text
+
+        Args:
+            text: The text to convert to speech
+            voice_id: Optional voice identifier
+            language: Optional language code
+            speed: Speech speed factor (1.0 is normal speed)
+            options: Additional model-specific options
+
+        Returns:
+            Audio data as bytes
+        """
+        pass
+
+    @abstractmethod
+    async def save_to_file(
+        self,
+        text: str,
+        output_file: Union[str, BinaryIO],
+        voice_id: Optional[str] = None,
+        language: Optional[str] = None,
+        speed: float = 1.0,
+        options: Optional[Dict[str, Any]] = None
+    ) -> str:
+        """
+        Generate speech and save to file
+
+        Args:
+            text: The text to convert to speech
+            output_file: Path to output file or file-like object
+            voice_id: Optional voice identifier
+            language: Optional language code
+            speed: Speech speed factor (1.0 is normal speed)
+            options: Additional model-specific options
+
+        Returns:
+            Path to the saved file
+        """
+        pass
+
+    @abstractmethod
+    async def get_available_voices(self) -> Dict[str, Any]:
+        """
+        Get available voices for the TTS service
+
+        Returns:
+            Dictionary of available voices with their details
+        """
+        pass
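
An illustrative sketch of how a concrete subclass could satisfy this interface; the synthesis step is left as a placeholder and only the save_to_file/generate_speech relationship is shown. This is not one of the package's shipped TTS services (those live under isa_model/inference/services/audio/).

# Illustrative sketch only; the synthesis backend call is a placeholder.
from typing import Any, Dict, Optional, Union, BinaryIO

from isa_model.inference.services.base_tts_service import BaseTTSService

class FileBackedTTSService(BaseTTSService):
    async def generate_speech(self, text, voice_id=None, language=None,
                              speed=1.0, options=None) -> bytes:
        raise NotImplementedError("placeholder: call the real synthesis backend here")

    async def save_to_file(self, text, output_file, voice_id=None, language=None,
                           speed=1.0, options=None) -> str:
        # Reuse generate_speech, then write the bytes to a path or file-like object.
        audio = await self.generate_speech(text, voice_id, language, speed, options)
        if isinstance(output_file, str):
            with open(output_file, "wb") as f:
                f.write(audio)
            return output_file
        output_file.write(audio)
        return getattr(output_file, "name", "<buffer>")

    async def get_available_voices(self) -> Dict[str, Any]:
        return {}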
isa_model/inference/services/embedding/bge_service.py
@@ -0,0 +1,183 @@
+import logging
+import asyncio
+import numpy as np
+from typing import Dict, List, Any, Optional, Union
+
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.backends.triton_client import TritonClient
+
+logger = logging.getLogger(__name__)
+
+
+class BgeEmbeddingService(BaseService):
+    """
+    Service for BGE embedding using Triton Inference Server.
+    """
+
+    def __init__(self, triton_url: str = "localhost:8001", model_name: str = "bge_embed"):
+        """
+        Initialize the BGE embedding service.
+
+        Args:
+            triton_url: URL of the Triton Inference Server
+            model_name: Name of the model in Triton
+        """
+        super().__init__()
+        self.triton_url = triton_url
+        self.model_name = model_name
+        self.client = None
+
+        # Default configuration
+        self.default_config = {
+            "normalize": True
+        }
+
+        self.logger = logger
+
+    async def load(self) -> None:
+        """
+        Load the client connection to Triton.
+        """
+        if self.is_loaded():
+            return
+
+        try:
+            # Create Triton client
+            self.logger.info(f"Connecting to Triton server at {self.triton_url}")
+            self.client = TritonClient(self.triton_url)
+
+            # Check if model is ready
+            if not await self.client.is_model_ready(self.model_name):
+                self.logger.error(f"Model {self.model_name} is not ready on Triton server")
+                raise RuntimeError(f"Model {self.model_name} is not ready on Triton server")
+
+            self._loaded = True
+            self.logger.info(f"Connected to Triton for model {self.model_name}")
+
+        except Exception as e:
+            self.logger.error(f"Failed to connect to Triton: {str(e)}")
+            raise
+
+    async def unload(self) -> None:
+        """
+        Unload the client connection.
+        """
+        if not self.is_loaded():
+            return
+
+        self.client = None
+        self._loaded = False
+        self.logger.info("Triton client connection closed")
+
+    async def embed(self,
+                    texts: Union[str, List[str]],
+                    normalize: Optional[bool] = None) -> np.ndarray:
+        """
+        Generate embeddings for texts using Triton.
+
+        Args:
+            texts: Single text or list of texts to embed
+            normalize: Whether to normalize embeddings (if None, use default)
+
+        Returns:
+            Numpy array of embeddings, shape [batch_size, embedding_dim]
+        """
+        if not self.is_loaded():
+            await self.load()
+
+        # Handle single text input
+        if isinstance(texts, str):
+            texts = [texts]
+
+        # Use default normalize setting if not specified
+        if normalize is None:
+            normalize = self.default_config["normalize"]
+
+        try:
+            # Prepare inputs
+            inputs = {
+                "text_input": texts,
+                "normalize": np.array([normalize], dtype=bool)
+            }
+
+            # Run inference
+            result = await self.client.infer(
+                model_name=self.model_name,
+                inputs=inputs,
+                outputs=["embedding"]
+            )
+
+            # Extract embeddings
+            embeddings = result["embedding"]
+
+            return embeddings
+
+        except Exception as e:
+            self.logger.error(f"Error during embedding generation: {str(e)}")
+            raise
+
+    async def similarity(self,
+                         text1: str,
+                         text2: str,
+                         normalize: Optional[bool] = None) -> float:
+        """
+        Calculate the similarity between two texts.
+
+        Args:
+            text1: First text
+            text2: Second text
+            normalize: Whether to normalize embeddings (if None, use default)
+
+        Returns:
+            Cosine similarity score (float between -1 and 1)
+        """
+        # Generate embeddings for both texts
+        embeddings = await self.embed([text1, text2], normalize=normalize)
+
+        # Calculate cosine similarity
+        from sklearn.metrics.pairwise import cosine_similarity
+        similarity = cosine_similarity(embeddings[0:1], embeddings[1:2])[0][0]
+
+        return float(similarity)
+
+    async def batch_similarity(self,
+                               queries: List[str],
+                               documents: List[str],
+                               normalize: Optional[bool] = None) -> np.ndarray:
+        """
+        Calculate similarities between queries and documents.
+
+        Args:
+            queries: List of query texts
+            documents: List of document texts
+            normalize: Whether to normalize embeddings (if None, use default)
+
+        Returns:
+            Numpy array of similarity scores, shape [len(queries), len(documents)]
+        """
+        # Generate embeddings for queries and documents
+        query_embeddings = await self.embed(queries, normalize=normalize)
+        doc_embeddings = await self.embed(documents, normalize=normalize)
+
+        # Calculate cosine similarities
+        from sklearn.metrics.pairwise import cosine_similarity
+        similarities = cosine_similarity(query_embeddings, doc_embeddings)
+
+        return similarities
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """
+        Get information about the model.
+
+        Returns:
+            Dictionary containing model information
+        """
+        return {
+            "name": self.model_name,
+            "type": "embedding",
+            "backend": "triton",
+            "url": self.triton_url,
+            "loaded": self.is_loaded(),
+            "embedding_dim": 1024,  # Typical for BGE models
+            "config": self.default_config
+        }
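
A minimal usage sketch for BgeEmbeddingService, assuming a Triton server that serves the "bge_embed" model with the text_input/normalize/embedding tensors shown above; the server address and example texts are placeholders, and scikit-learn must be installed for the similarity helpers.

# Hedged usage sketch; addresses and texts are illustrative only.
import asyncio

from isa_model.inference.services.embedding.bge_service import BgeEmbeddingService

async def main():
    service = BgeEmbeddingService(triton_url="localhost:8001", model_name="bge_embed")
    vectors = await service.embed(["hello world", "bonjour le monde"])
    print(vectors.shape)                         # expected (2, embedding_dim), e.g. (2, 1024)
    score = await service.similarity("hello world", "hi there")
    print(score)                                 # cosine similarity in [-1, 1]
    await service.unload()

asyncio.run(main())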
isa_model/inference/services/embedding/ollama_embed_service.py
@@ -0,0 +1,85 @@
+import logging
+from typing import List, Dict, Any, Optional
+from isa_model.inference.services.base_service import BaseEmbeddingService
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.backends.local_services import OllamaBackendClient
+
+logger = logging.getLogger(__name__)
+
+class OllamaEmbedService(BaseEmbeddingService):
+    """Ollama embedding service using backend client"""
+
+    def __init__(self, provider: 'BaseProvider', model_name: str = "bge-m3", backend: Optional[OllamaBackendClient] = None):
+        super().__init__(provider, model_name)
+
+        # Use provided backend or create new one
+        if backend:
+            self.backend = backend
+        else:
+            host = self.config.get("host", "localhost")
+            port = self.config.get("port", 11434)
+            self.backend = OllamaBackendClient(host, port)
+
+        logger.info(f"Initialized OllamaEmbedService with model {model_name}")
+
+    async def create_text_embedding(self, text: str) -> List[float]:
+        """Create embedding for text"""
+        try:
+            payload = {
+                "model": self.model_name,
+                "prompt": text
+            }
+            response = await self.backend.post("/api/embeddings", payload)
+            return response["embedding"]
+
+        except Exception as e:
+            logger.error(f"Error creating text embedding: {e}")
+            raise
+
+    async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
+        """Create embeddings for multiple texts"""
+        embeddings = []
+        for text in texts:
+            embedding = await self.create_text_embedding(text)
+            embeddings.append(embedding)
+        return embeddings
+
+    async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
+        """Create text chunks with embeddings"""
+        # Simple implementation: split the text into fixed-size chunks
+        chunk_size = 200  # number of words per chunk
+        chunks = []
+
+        # Split on whitespace into words
+        words = text.split()
+
+        # Build the chunks
+        for i in range(0, len(words), chunk_size):
+            chunk_text = " ".join(words[i:i+chunk_size])
+            embedding = await self.create_text_embedding(chunk_text)
+
+            chunk = {
+                "text": chunk_text,
+                "embedding": embedding,
+                "metadata": metadata or {}
+            }
+            chunks.append(chunk)
+
+        return chunks
+
+    async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
+        """Compute the cosine similarity between two embedding vectors"""
+        # Straightforward cosine similarity implementation
+        dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
+        norm1 = sum(a * a for a in embedding1) ** 0.5
+        norm2 = sum(b * b for b in embedding2) ** 0.5
+
+        if norm1 * norm2 == 0:
+            return 0.0
+
+        return dot_product / (norm1 * norm2)
+
+    async def close(self):
+        """Close the backend client"""
+        await self.backend.close()
+
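
A minimal usage sketch for OllamaEmbedService, assuming an Ollama server on localhost:11434 with the bge-m3 model pulled; the provider argument must supply get_config(), and since the concrete provider class name is not shown in this diff, it is passed in rather than constructed here.

# Hedged usage sketch; the provider object and example texts are assumptions.
import asyncio

from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService

async def main(provider):
    service = OllamaEmbedService(provider, model_name="bge-m3")
    emb1 = await service.create_text_embedding("machine learning")
    emb2 = await service.create_text_embedding("deep learning")
    print(await service.compute_similarity(emb1, emb2))   # cosine similarity in [-1, 1]
    await service.close()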