isa_model-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +354 -0
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model/inference/backends/__init__.py +53 -0
- isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model/inference/backends/container_services.py +104 -0
- isa_model/inference/backends/local_services.py +72 -0
- isa_model/inference/backends/openai_client.py +130 -0
- isa_model/inference/backends/replicate_client.py +197 -0
- isa_model/inference/backends/third_party_services.py +239 -0
- isa_model/inference/backends/triton_client.py +97 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model/inference/client_sdk/client.py +134 -0
- isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +87 -0
- isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/base_tts_service.py +66 -0
- isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model/inference/services/llm/__init__.py +16 -0
- isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0.dist-info/METADATA +116 -0
- isa_model-0.1.0.dist-info/RECORD +117 -0
- isa_model-0.1.0.dist-info/WHEEL +5 -0
- isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
- isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/inference/ai_factory.py
@@ -0,0 +1,354 @@

import logging
import os
from typing import Dict, Type, Any, Optional, Tuple

from isa_model.inference.providers.base_provider import BaseProvider
from isa_model.inference.services.base_service import BaseService
from isa_model.inference.base import ModelType

from isa_model.inference.services.llm.llama_service import LlamaService
from isa_model.inference.services.llm.gemma_service import GemmaService
from isa_model.inference.services.audio.whisper_service import WhisperService
from isa_model.inference.services.embedding.bge_service import BgeEmbeddingService

# Set up basic logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AIFactory:
    """
    Factory for creating AI services based on the Single Model pattern.
    """

    _instance = None
    _is_initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize the AI Factory."""
        self.triton_url = os.environ.get("TRITON_URL", "localhost:8001")

        # Caches for Triton-backed services (keyed by model name)
        self._llm_services = {}
        self._embedding_services = {}
        self._speech_services = {}

        if not self._is_initialized:
            self._providers: Dict[str, Type[BaseProvider]] = {}
            self._services: Dict[Tuple[str, ModelType], Type[BaseService]] = {}
            self._cached_services: Dict[str, BaseService] = {}
            self._initialize_defaults()
            AIFactory._is_initialized = True

    def _initialize_defaults(self):
        """Initialize default providers and services."""
        try:
            # Ollama provider and services (the defaults)
            from isa_model.inference.providers.ollama_provider import OllamaProvider
            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
            from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService

            self.register_provider('ollama', OllamaProvider)
            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
            self.register_service('ollama', ModelType.LLM, OllamaLLMService)

            # Register OpenAI provider and services
            try:
                from isa_model.inference.providers.openai_provider import OpenAIProvider
                from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService

                self.register_provider('openai', OpenAIProvider)
                self.register_service('openai', ModelType.LLM, OpenAILLMService)
                logger.info("OpenAI services registered successfully")
            except ImportError as e:
                logger.warning(f"OpenAI services not available: {e}")

            # Register Replicate provider and services
            try:
                from isa_model.inference.providers.replicate_provider import ReplicateProvider
                from isa_model.inference.services.llm.replicate_llm_service import ReplicateLLMService
                from isa_model.inference.services.vision.replicate_vision_service import ReplicateVisionService

                self.register_provider('replicate', ReplicateProvider)
                self.register_service('replicate', ModelType.LLM, ReplicateLLMService)
                self.register_service('replicate', ModelType.VISION, ReplicateVisionService)
                logger.info("Replicate services registered successfully")
            except ImportError as e:
                logger.warning(f"Replicate services not available: {e}")

            # Try to register Triton services
            try:
                from isa_model.inference.providers.triton_provider import TritonProvider
                from isa_model.inference.services.llm.triton_llm_service import TritonLLMService
                from isa_model.inference.services.vision.triton_vision_service import TritonVisionService
                from isa_model.inference.services.audio.triton_speech_service import TritonSpeechService

                self.register_provider('triton', TritonProvider)
                self.register_service('triton', ModelType.LLM, TritonLLMService)
                self.register_service('triton', ModelType.VISION, TritonVisionService)
                self.register_service('triton', ModelType.AUDIO, TritonSpeechService)
                logger.info("Triton services registered successfully")
            except ImportError as e:
                logger.warning(f"Triton services not available: {e}")

            # Register HuggingFace-based direct LLM service for Llama3-8B
            try:
                from isa_model.inference.llm.llama3_service import Llama3Service
                # Cache as a standalone service for direct access
                self._cached_services["llama3"] = Llama3Service()
                logger.info("Llama3-8B service registered successfully")
            except ImportError as e:
                logger.warning(f"Llama3-8B service not available: {e}")

            # Register HuggingFace-based direct vision service for Gemma3-4B
            try:
                from isa_model.inference.vision.gemma3_service import Gemma3VisionService
                # Cache as a standalone service for direct access
                self._cached_services["gemma3"] = Gemma3VisionService()
                logger.info("Gemma3-4B Vision service registered successfully")
            except ImportError as e:
                logger.warning(f"Gemma3-4B Vision service not available: {e}")

            # Register HuggingFace-based direct speech service for Whisper Tiny
            try:
                from isa_model.inference.speech.whisper_service import WhisperService
                # Cache as a standalone service for direct access
                self._cached_services["whisper"] = WhisperService()
                logger.info("Whisper Tiny Speech service registered successfully")
            except ImportError as e:
                logger.warning(f"Whisper Tiny Speech service not available: {e}")

            logger.info("Default AI providers and services initialized with backend architecture")
        except Exception as e:
            logger.error(f"Error initializing default providers and services: {e}")
            # Don't raise - allow the factory to work even if some services fail to load
            logger.warning("Some services may not be available due to import errors")

    def register_provider(self, name: str, provider_class: Type[BaseProvider]) -> None:
        """Register an AI provider."""
        self._providers[name] = provider_class

    def register_service(self, provider_name: str, model_type: ModelType,
                         service_class: Type[BaseService]) -> None:
        """Register a service type with its provider."""
        self._services[(provider_name, model_type)] = service_class

    def create_service(self, provider_name: str, model_type: ModelType,
                       model_name: str, config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Create a service instance."""
        try:
            cache_key = f"{provider_name}_{model_type}_{model_name}"

            if cache_key in self._cached_services:
                return self._cached_services[cache_key]

            # Base configuration
            base_config = {
                "log_level": "INFO"
            }

            # Merge the caller's configuration over the base configuration
            service_config = {**base_config, **(config or {})}

            # Create the provider and service
            provider_class = self._providers[provider_name]
            service_class = self._services.get((provider_name, model_type))

            if not service_class:
                raise ValueError(
                    f"No service registered for provider {provider_name} and model type {model_type}"
                )

            provider = provider_class(config=service_config)
            service = service_class(provider=provider, model_name=model_name)

            self._cached_services[cache_key] = service
            return service

        except Exception as e:
            logger.error(f"Error creating service: {e}")
            raise

    # Convenience methods for common services
    def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
                config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get an LLM service instance"""

        # Special case for the Llama3-8B direct service
        if model_name.lower() in ["llama3", "llama3-8b", "meta-llama-3"]:
            if "llama3" in self._cached_services:
                return self._cached_services["llama3"]

        basic_config = {
            "temperature": 0
        }
        if config:
            basic_config.update(config)
        return self.create_service(provider, ModelType.LLM, model_name, basic_config)

    def get_vision_model(self, model_name: str = "gemma3-4b", provider: str = "triton",
                         config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get a vision model service instance"""

        # Special case for the Gemma3-4B direct service
        if model_name.lower() in ["gemma3", "gemma3-4b", "gemma3-vision"]:
            if "gemma3" in self._cached_services:
                return self._cached_services["gemma3"]

        # Special case for Replicate's image generation models
        if provider == "replicate" and "/" in model_name:
            basic_config = {
                "api_token": os.environ.get("REPLICATE_API_TOKEN", ""),
                "guidance_scale": 7.5,
                "num_inference_steps": 30
            }
            if config:
                basic_config.update(config)
            return self.create_service(provider, ModelType.VISION, model_name, basic_config)

        basic_config = {
            "temperature": 0.7,
            "max_new_tokens": 512
        }
        if config:
            basic_config.update(config)
        return self.create_service(provider, ModelType.VISION, model_name, basic_config)

    def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
                      config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get an embedding service instance"""
        return self.create_service(provider, ModelType.EMBEDDING, model_name, config)

    def get_rerank(self, model_name: str = "bge-m3", provider: str = "ollama",
                   config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get a rerank service instance"""
        return self.create_service(provider, ModelType.RERANK, model_name, config)

    def get_embed_service(self, model_name: str = "bge-m3", provider: str = "ollama",
                          config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get an embedding service instance (alias for get_embedding)"""
        return self.get_embedding(model_name, provider, config)

    def get_speech_model(self, model_name: str = "whisper_tiny", provider: str = "triton",
                         config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get a speech-to-text model service instance"""

        # Special case for the Whisper Tiny direct service
        if model_name.lower() in ["whisper", "whisper_tiny", "whisper-tiny"]:
            if "whisper" in self._cached_services:
                return self._cached_services["whisper"]

        basic_config = {
            "language": "en",
            "task": "transcribe"
        }
        if config:
            basic_config.update(config)
        return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)

    async def get_llm_service(self, model_name: str) -> Any:
        """
        Get a Triton-backed LLM service for the specified model.

        Args:
            model_name: Name of the model

        Returns:
            LLM service instance
        """
        if model_name in self._llm_services:
            return self._llm_services[model_name]

        if model_name == "llama":
            service = LlamaService(triton_url=self.triton_url, model_name="llama")
        elif model_name == "gemma":
            service = GemmaService(triton_url=self.triton_url, model_name="gemma")
        else:
            raise ValueError(f"Unsupported LLM model: {model_name}")

        await service.load()
        self._llm_services[model_name] = service
        return service

    async def get_embedding_service(self, model_name: str) -> Any:
        """
        Get a Triton-backed embedding service for the specified model.

        Args:
            model_name: Name of the model

        Returns:
            Embedding service instance
        """
        if model_name in self._embedding_services:
            return self._embedding_services[model_name]

        if model_name == "bge_embed":
            service = BgeEmbeddingService(triton_url=self.triton_url, model_name="bge_embed")
            await service.load()
            self._embedding_services[model_name] = service
            return service
        else:
            raise ValueError(f"Unsupported embedding model: {model_name}")

    async def get_speech_service(self, model_name: str) -> Any:
        """
        Get a Triton-backed speech service for the specified model.

        Args:
            model_name: Name of the model

        Returns:
            Speech service instance
        """
        if model_name in self._speech_services:
            return self._speech_services[model_name]

        if model_name == "whisper":
            service = WhisperService(triton_url=self.triton_url, model_name="whisper")
            await service.load()
            self._speech_services[model_name] = service
            return service
        else:
            raise ValueError(f"Unsupported speech model: {model_name}")

    def get_model_info(self, model_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Get information about available models.

        Args:
            model_type: Optional filter for model type

        Returns:
            Dict of model information
        """
        models = {
            "llm": [
                {"name": "llama", "description": "Llama3-8B language model"},
                {"name": "gemma", "description": "Gemma3-4B language model"}
            ],
            "embedding": [
                {"name": "bge_embed", "description": "BGE-M3 text embedding model"}
            ],
            "speech": [
                {"name": "whisper", "description": "Whisper-tiny speech-to-text model"}
            ]
        }

        if model_type:
            return {model_type: models.get(model_type, [])}
        return models

    @classmethod
    def get_instance(cls) -> 'AIFactory':
        """Get the singleton instance."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
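As a quick usage sketch (not part of the package): assuming the default Ollama provider and services above import cleanly and an Ollama server is reachable, the factory's synchronous accessors could be exercised like this, with model names taken from the default arguments in the signatures above.

    from isa_model.inference.ai_factory import AIFactory

    factory = AIFactory.get_instance()

    # LLM from the default Ollama provider; create_service caches the instance,
    # so a second call with the same provider/type/model returns the same object
    llm = factory.get_llm(model_name="llama3.1", provider="ollama",
                          config={"temperature": 0.2})

    # Embedding service from the same provider
    embedder = factory.get_embedding(model_name="bge-m3", provider="ollama")

Note the split in the API: the async accessors (get_llm_service, get_embedding_service, get_speech_service) bypass the provider registry entirely and construct Triton-backed services against TRITON_URL instead.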
isa_model/inference/backends/Pytorch/bge_embed_backend.py
@@ -0,0 +1,188 @@

import os
import logging
import torch
import numpy as np
from typing import Dict, List, Any, Optional, Union

logger = logging.getLogger(__name__)


class BgeEmbedBackend:
    """
    PyTorch backend for the BGE embedding model.
    """

    def __init__(self, model_path: Optional[str] = None, device: str = "auto"):
        """
        Initialize the BGE embedding backend.

        Args:
            model_path: Path to the model
            device: Device to run the model on ("cpu", "cuda", or "auto")
        """
        self.model_path = model_path or os.environ.get("BGE_MODEL_PATH", "/models/Bge-m3")
        self.device = device if device != "auto" else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None
        self._loaded = False

        # Default configuration
        self.config = {
            "normalize": True,
            "max_length": 512,
            "pooling_method": "cls"  # Use the CLS token for the sentence embedding
        }

        self.logger = logger

    def load(self) -> None:
        """
        Load the model and tokenizer.
        """
        if self._loaded:
            return

        try:
            from transformers import AutoModel, AutoTokenizer

            # Load tokenizer
            self.logger.info(f"Loading BGE tokenizer from {self.model_path}")
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

            # Load model: full precision on CPU, half precision on GPU
            self.logger.info(f"Loading BGE model on {self.device}")
            dtype = torch.float32 if self.device == "cpu" else torch.float16
            self.model = AutoModel.from_pretrained(
                self.model_path,
                torch_dtype=dtype,
                device_map="auto"
            )

            self.model.eval()
            self._loaded = True
            self.logger.info("BGE model loaded successfully")

        except Exception as e:
            self.logger.error(f"Failed to load BGE model: {str(e)}")
            raise

    def unload(self) -> None:
        """
        Unload the model and tokenizer.
        """
        if not self._loaded:
            return

        self.model = None
        self.tokenizer = None
        self._loaded = False

        # Force garbage collection
        import gc
        gc.collect()

        if self.device == "cuda":
            torch.cuda.empty_cache()

        self.logger.info("BGE model unloaded")

    def embed(self,
              texts: Union[str, List[str]],
              normalize: Optional[bool] = None) -> np.ndarray:
        """
        Generate embeddings for texts.

        Args:
            texts: Single text or list of texts to embed
            normalize: Whether to normalize embeddings (if None, use the default)

        Returns:
            Numpy array of embeddings, shape [batch_size, embedding_dim]
        """
        if not self._loaded:
            self.load()

        # Handle single-text input
        if isinstance(texts, str):
            texts = [texts]

        # Fall back to the default normalize setting if not specified
        if normalize is None:
            normalize = self.config["normalize"]

        try:
            # Tokenize the texts
            inputs = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=self.config["max_length"],
                return_tensors="pt"
            ).to(self.device)

            # Generate embeddings
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Use the [CLS] token embedding as the sentence embedding
            embeddings = outputs.last_hidden_state[:, 0, :]

            # L2-normalize embeddings if required
            if normalize:
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

            # Convert to a numpy array
            return embeddings.cpu().numpy()

        except Exception as e:
            self.logger.error(f"Error during BGE embedding generation: {str(e)}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the model.

        Returns:
            Dictionary containing model information
        """
        return {
            "name": "bge-m3",
            "type": "embedding",
            "device": self.device,
            "path": self.model_path,
            "loaded": self._loaded,
            "embedding_dim": 1024,  # Typical for BGE models
            "config": self.config
        }

    def similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """
        Calculate the cosine similarity between two embeddings.

        Args:
            embedding1: First embedding vector
            embedding2: Second embedding vector

        Returns:
            Cosine similarity score (float between -1 and 1)
        """
        from sklearn.metrics.pairwise import cosine_similarity

        # Reshape 1-D vectors to [1, dim]; cosine_similarity expects 2-D input
        if embedding1.ndim == 1:
            embedding1 = embedding1.reshape(1, -1)
        if embedding2.ndim == 1:
            embedding2 = embedding2.reshape(1, -1)

        return float(cosine_similarity(embedding1, embedding2)[0][0])
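A minimal sketch of driving this backend directly, assuming torch, transformers, and scikit-learn are installed and a BGE-M3 checkpoint is available locally (the path below is a placeholder; by default the backend reads BGE_MODEL_PATH or falls back to /models/Bge-m3):

    from isa_model.inference.backends.Pytorch.bge_embed_backend import BgeEmbedBackend

    backend = BgeEmbedBackend(model_path="./models/bge-m3", device="auto")

    # embed() lazily calls load() on first use and returns a [batch, dim] array
    vectors = backend.embed(["hello world", "greetings, planet"])
    score = backend.similarity(vectors[0], vectors[1])
    print(f"cosine similarity: {score:.3f}")

    backend.unload()  # drop model weights and, on GPU, empty the CUDA cache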