isa_model-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/mlflow_gateway/__init__.py +8 -0
  12. isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
  13. isa_model/deployment/unified_multimodal_client.py +341 -0
  14. isa_model/inference/__init__.py +11 -0
  15. isa_model/inference/adapter/triton_adapter.py +453 -0
  16. isa_model/inference/adapter/unified_api.py +248 -0
  17. isa_model/inference/ai_factory.py +354 -0
  18. isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
  19. isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
  20. isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
  21. isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
  22. isa_model/inference/backends/__init__.py +53 -0
  23. isa_model/inference/backends/base_backend_client.py +26 -0
  24. isa_model/inference/backends/container_services.py +104 -0
  25. isa_model/inference/backends/local_services.py +72 -0
  26. isa_model/inference/backends/openai_client.py +130 -0
  27. isa_model/inference/backends/replicate_client.py +197 -0
  28. isa_model/inference/backends/third_party_services.py +239 -0
  29. isa_model/inference/backends/triton_client.py +97 -0
  30. isa_model/inference/base.py +46 -0
  31. isa_model/inference/client_sdk/__init__.py +0 -0
  32. isa_model/inference/client_sdk/client.py +134 -0
  33. isa_model/inference/client_sdk/client_data_std.py +34 -0
  34. isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
  35. isa_model/inference/client_sdk/exceptions.py +0 -0
  36. isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
  37. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
  38. isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
  39. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
  40. isa_model/inference/providers/__init__.py +19 -0
  41. isa_model/inference/providers/base_provider.py +30 -0
  42. isa_model/inference/providers/model_cache_manager.py +341 -0
  43. isa_model/inference/providers/ollama_provider.py +73 -0
  44. isa_model/inference/providers/openai_provider.py +87 -0
  45. isa_model/inference/providers/replicate_provider.py +94 -0
  46. isa_model/inference/providers/triton_provider.py +439 -0
  47. isa_model/inference/providers/vllm_provider.py +0 -0
  48. isa_model/inference/providers/yyds_provider.py +83 -0
  49. isa_model/inference/services/__init__.py +14 -0
  50. isa_model/inference/services/audio/fish_speech/handler.py +215 -0
  51. isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
  52. isa_model/inference/services/audio/triton_speech_service.py +138 -0
  53. isa_model/inference/services/audio/whisper_service.py +186 -0
  54. isa_model/inference/services/audio/yyds_audio_service.py +71 -0
  55. isa_model/inference/services/base_service.py +106 -0
  56. isa_model/inference/services/base_tts_service.py +66 -0
  57. isa_model/inference/services/embedding/bge_service.py +183 -0
  58. isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
  59. isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
  60. isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
  61. isa_model/inference/services/llm/__init__.py +16 -0
  62. isa_model/inference/services/llm/gemma_service.py +143 -0
  63. isa_model/inference/services/llm/llama_service.py +143 -0
  64. isa_model/inference/services/llm/ollama_llm_service.py +108 -0
  65. isa_model/inference/services/llm/openai_llm_service.py +129 -0
  66. isa_model/inference/services/llm/replicate_llm_service.py +179 -0
  67. isa_model/inference/services/llm/triton_llm_service.py +230 -0
  68. isa_model/inference/services/others/table_transformer_service.py +61 -0
  69. isa_model/inference/services/vision/__init__.py +12 -0
  70. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  71. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  72. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  73. isa_model/inference/services/vision/replicate_vision_service.py +241 -0
  74. isa_model/inference/services/vision/triton_vision_service.py +199 -0
  75. isa_model/inference/services/vision/yyds_vision_service.py +80 -0
  76. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  77. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  78. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  79. isa_model/scripts/inference_tracker.py +283 -0
  80. isa_model/scripts/mlflow_manager.py +379 -0
  81. isa_model/scripts/model_registry.py +465 -0
  82. isa_model/scripts/start_mlflow.py +95 -0
  83. isa_model/scripts/training_tracker.py +257 -0
  84. isa_model/training/engine/llama_factory/__init__.py +39 -0
  85. isa_model/training/engine/llama_factory/config.py +115 -0
  86. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  87. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  88. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  89. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  90. isa_model/training/engine/llama_factory/factory.py +331 -0
  91. isa_model/training/engine/llama_factory/rl.py +254 -0
  92. isa_model/training/engine/llama_factory/trainer.py +171 -0
  93. isa_model/training/image_model/configs/create_config.py +37 -0
  94. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  95. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  96. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  97. isa_model/training/image_model/prepare_upload.py +17 -0
  98. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  99. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  100. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  101. isa_model/training/image_model/train/train.py +42 -0
  102. isa_model/training/image_model/train/train_flux.py +41 -0
  103. isa_model/training/image_model/train/train_lora.py +57 -0
  104. isa_model/training/image_model/train_main.py +25 -0
  105. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  106. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  107. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  108. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  109. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  110. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  111. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  112. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  113. isa_model-0.1.0.dist-info/METADATA +116 -0
  114. isa_model-0.1.0.dist-info/RECORD +117 -0
  115. isa_model-0.1.0.dist-info/WHEEL +5 -0
  116. isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
  117. isa_model-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,354 @@
+from typing import Dict, Type, Any, Optional, Tuple
+import logging
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.base import ModelType
+import os
+
+from isa_model.inference.services.llm.llama_service import LlamaService
+from isa_model.inference.services.llm.gemma_service import GemmaService
+from isa_model.inference.services.audio.whisper_service import WhisperService
+from isa_model.inference.services.embedding.bge_service import BgeEmbeddingService
+
+# Set up basic logging configuration
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class AIFactory:
+    """
+    Factory for creating AI services based on the Single Model pattern.
+    """
+
+    _instance = None
+    _is_initialized = False
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        """Initialize the AI Factory."""
+        # __init__ runs on every AIFactory() call, so all shared state is
+        # guarded; otherwise the service caches would be wiped each time.
+        if not self._is_initialized:
+            self.triton_url = os.environ.get("TRITON_URL", "localhost:8001")
+
+            # Caches for services (singleton pattern)
+            self._llm_services = {}
+            self._embedding_services = {}
+            self._speech_services = {}
+
+            self._providers: Dict[str, Type[BaseProvider]] = {}
+            self._services: Dict[Tuple[str, ModelType], Type[BaseService]] = {}
+            self._cached_services: Dict[str, BaseService] = {}
+            self._initialize_defaults()
+            AIFactory._is_initialized = True
+
+    def _initialize_defaults(self):
+        """Initialize default providers and services"""
+        try:
+            # Import providers and services
+            from isa_model.inference.providers.ollama_provider import OllamaProvider
+            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
+            from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService
+
+            # Register Ollama provider and services
+            self.register_provider('ollama', OllamaProvider)
+            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
+            self.register_service('ollama', ModelType.LLM, OllamaLLMService)
+
+            # Register OpenAI provider and services
+            try:
+                from isa_model.inference.providers.openai_provider import OpenAIProvider
+                from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
+
+                self.register_provider('openai', OpenAIProvider)
+                self.register_service('openai', ModelType.LLM, OpenAILLMService)
+                logger.info("OpenAI services registered successfully")
+            except ImportError as e:
+                logger.warning(f"OpenAI services not available: {e}")
+
+            # Register Replicate provider and services
+            try:
+                from isa_model.inference.providers.replicate_provider import ReplicateProvider
+                from isa_model.inference.services.llm.replicate_llm_service import ReplicateLLMService
+                from isa_model.inference.services.vision.replicate_vision_service import ReplicateVisionService
+
+                self.register_provider('replicate', ReplicateProvider)
+                self.register_service('replicate', ModelType.LLM, ReplicateLLMService)
+                self.register_service('replicate', ModelType.VISION, ReplicateVisionService)
+                logger.info("Replicate services registered successfully")
+            except ImportError as e:
+                logger.warning(f"Replicate services not available: {e}")
+
+            # Try to register Triton services
+            try:
+                from isa_model.inference.providers.triton_provider import TritonProvider
+                from isa_model.inference.services.llm.triton_llm_service import TritonLLMService
+                from isa_model.inference.services.vision.triton_vision_service import TritonVisionService
+                from isa_model.inference.services.audio.triton_speech_service import TritonSpeechService
+
+                self.register_provider('triton', TritonProvider)
+                self.register_service('triton', ModelType.LLM, TritonLLMService)
+                self.register_service('triton', ModelType.VISION, TritonVisionService)
+                self.register_service('triton', ModelType.AUDIO, TritonSpeechService)
+                logger.info("Triton services registered successfully")
+            except ImportError as e:
+                logger.warning(f"Triton services not available: {e}")
+
+            # Register HuggingFace-based direct LLM service for Llama3-8B
+            try:
+                from isa_model.inference.llm.llama3_service import Llama3Service
+                # Register as a standalone service for direct access
+                self._cached_services["llama3"] = Llama3Service()
+                logger.info("Llama3-8B service registered successfully")
+            except ImportError as e:
+                logger.warning(f"Llama3-8B service not available: {e}")
+
+            # Register HuggingFace-based direct Vision service for Gemma3-4B
+            try:
+                from isa_model.inference.vision.gemma3_service import Gemma3VisionService
+                # Register as a standalone service for direct access
+                self._cached_services["gemma3"] = Gemma3VisionService()
+                logger.info("Gemma3-4B Vision service registered successfully")
+            except ImportError as e:
+                logger.warning(f"Gemma3-4B Vision service not available: {e}")
+
+            # Register HuggingFace-based direct Speech service for Whisper Tiny
+            try:
+                from isa_model.inference.speech.whisper_service import WhisperService
+                # Register as a standalone service for direct access
+                self._cached_services["whisper"] = WhisperService()
+                logger.info("Whisper Tiny Speech service registered successfully")
+            except ImportError as e:
+                logger.warning(f"Whisper Tiny Speech service not available: {e}")
+
+            logger.info("Default AI providers and services initialized with backend architecture")
+        except Exception as e:
+            logger.error(f"Error initializing default providers and services: {e}")
+            # Don't raise - allow the factory to work even if some services fail to load
+            logger.warning("Some services may not be available due to import errors")
+
+    def register_provider(self, name: str, provider_class: Type[BaseProvider]) -> None:
+        """Register an AI provider"""
+        self._providers[name] = provider_class
+
+    def register_service(self, provider_name: str, model_type: ModelType,
+                         service_class: Type[BaseService]) -> None:
+        """Register a service type with its provider"""
+        self._services[(provider_name, model_type)] = service_class
+
+    def create_service(self, provider_name: str, model_type: ModelType,
+                       model_name: str, config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Create a service instance"""
+        try:
+            cache_key = f"{provider_name}_{model_type}_{model_name}"
+
+            if cache_key in self._cached_services:
+                return self._cached_services[cache_key]
+
+            # Base configuration
+            base_config = {
+                "log_level": "INFO"
+            }
+
+            # Merge configurations
+            service_config = {**base_config, **(config or {})}
+
+            # Create the provider and service
+            provider_class = self._providers[provider_name]
+            service_class = self._services.get((provider_name, model_type))
+
+            if not service_class:
+                raise ValueError(
+                    f"No service registered for provider {provider_name} and model type {model_type}"
+                )
+
+            provider = provider_class(config=service_config)
+            service = service_class(provider=provider, model_name=model_name)
+
+            self._cached_services[cache_key] = service
+            return service
+
+        except Exception as e:
+            logger.error(f"Error creating service: {e}")
+            raise
+
+    # Convenience methods for common services
+    def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
+                config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Get an LLM service instance"""
+
+        # Special case for the Llama3-8B direct service
+        if model_name.lower() in ["llama3", "llama3-8b", "meta-llama-3"]:
+            if "llama3" in self._cached_services:
+                return self._cached_services["llama3"]
+
+        basic_config = {
+            "temperature": 0
+        }
+        if config:
+            basic_config.update(config)
+        return self.create_service(provider, ModelType.LLM, model_name, basic_config)
+
+    def get_vision_model(self, model_name: str = "gemma3-4b", provider: str = "triton",
+                         config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Get a vision model service instance"""
+
+        # Special case for the Gemma3-4B direct service
+        if model_name.lower() in ["gemma3", "gemma3-4b", "gemma3-vision"]:
+            if "gemma3" in self._cached_services:
+                return self._cached_services["gemma3"]
+
+        # Special case for Replicate's image generation models
+        if provider == "replicate" and "/" in model_name:
+            basic_config = {
+                "api_token": os.environ.get("REPLICATE_API_TOKEN", ""),
+                "guidance_scale": 7.5,
+                "num_inference_steps": 30
+            }
+            if config:
+                basic_config.update(config)
+            return self.create_service(provider, ModelType.VISION, model_name, basic_config)
+
+        basic_config = {
+            "temperature": 0.7,
+            "max_new_tokens": 512
+        }
+        if config:
+            basic_config.update(config)
+        return self.create_service(provider, ModelType.VISION, model_name, basic_config)
+
+    def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
+                      config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Get an embedding service instance"""
+        return self.create_service(provider, ModelType.EMBEDDING, model_name, config)
+
+    def get_rerank(self, model_name: str = "bge-m3", provider: str = "ollama",
+                   config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Get a rerank service instance"""
+        return self.create_service(provider, ModelType.RERANK, model_name, config)
+
+    def get_embed_service(self, model_name: str = "bge-m3", provider: str = "ollama",
+                          config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Get an embedding service instance"""
+        return self.get_embedding(model_name, provider, config)
+
+    def get_speech_model(self, model_name: str = "whisper_tiny", provider: str = "triton",
+                         config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Get a speech-to-text model service instance"""
+
+        # Special case for the Whisper Tiny direct service
+        if model_name.lower() in ["whisper", "whisper_tiny", "whisper-tiny"]:
+            if "whisper" in self._cached_services:
+                return self._cached_services["whisper"]
+
+        basic_config = {
+            "language": "en",
+            "task": "transcribe"
+        }
+        if config:
+            basic_config.update(config)
+        return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)
+
+    async def get_llm_service(self, model_name: str) -> Any:
+        """
+        Get an LLM service for the specified model.
+
+        Args:
+            model_name: Name of the model
+
+        Returns:
+            LLM service instance
+        """
+        if model_name in self._llm_services:
+            return self._llm_services[model_name]
+
+        if model_name == "llama":
+            service = LlamaService(triton_url=self.triton_url, model_name="llama")
+            await service.load()
+            self._llm_services[model_name] = service
+            return service
+        elif model_name == "gemma":
+            service = GemmaService(triton_url=self.triton_url, model_name="gemma")
+            await service.load()
+            self._llm_services[model_name] = service
+            return service
+        else:
+            raise ValueError(f"Unsupported LLM model: {model_name}")
+
+    async def get_embedding_service(self, model_name: str) -> Any:
+        """
+        Get an embedding service for the specified model.
+
+        Args:
+            model_name: Name of the model
+
+        Returns:
+            Embedding service instance
+        """
+        if model_name in self._embedding_services:
+            return self._embedding_services[model_name]
+
+        if model_name == "bge_embed":
+            service = BgeEmbeddingService(triton_url=self.triton_url, model_name="bge_embed")
+            await service.load()
+            self._embedding_services[model_name] = service
+            return service
+        else:
+            raise ValueError(f"Unsupported embedding model: {model_name}")
+
+    async def get_speech_service(self, model_name: str) -> Any:
+        """
+        Get a speech service for the specified model.
+
+        Args:
+            model_name: Name of the model
+
+        Returns:
+            Speech service instance
+        """
+        if model_name in self._speech_services:
+            return self._speech_services[model_name]
+
+        if model_name == "whisper":
+            service = WhisperService(triton_url=self.triton_url, model_name="whisper")
+            await service.load()
+            self._speech_services[model_name] = service
+            return service
+        else:
+            raise ValueError(f"Unsupported speech model: {model_name}")
+
+    def get_model_info(self, model_type: Optional[str] = None) -> Dict[str, Any]:
+        """
+        Get information about available models.
+
+        Args:
+            model_type: Optional filter for model type
+
+        Returns:
+            Dict of model information
+        """
+        models = {
+            "llm": [
+                {"name": "llama", "description": "Llama3-8B language model"},
+                {"name": "gemma", "description": "Gemma3-4B language model"}
+            ],
+            "embedding": [
+                {"name": "bge_embed", "description": "BGE-M3 text embedding model"}
+            ],
+            "speech": [
+                {"name": "whisper", "description": "Whisper-tiny speech-to-text model"}
+            ]
+        }
+
+        if model_type:
+            return {model_type: models.get(model_type, [])}
+        return models
+
+    @classmethod
+    def get_instance(cls) -> 'AIFactory':
+        """Get the singleton instance"""
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._instance
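
For orientation, here is a minimal usage sketch of the factory above (not part of the package; the model names are the defaults from the code, and it assumes the Ollama provider registered in _initialize_defaults imported cleanly and a server is reachable at its configured endpoint):

    from isa_model.inference.ai_factory import AIFactory

    factory = AIFactory.get_instance()

    # Services are cached by (provider, model type, model name), so repeated
    # calls return the same instance instead of rebuilding the service.
    llm = factory.get_llm(model_name="llama3.1", provider="ollama",
                          config={"temperature": 0.2})
    assert llm is factory.get_llm(model_name="llama3.1", provider="ollama")

    # Embedding services come from the same provider/service registry.
    embedder = factory.get_embedding(model_name="bge-m3", provider="ollama")

Because __new__ and the class-level _is_initialized flag make AIFactory a process-wide singleton, every caller shares the same provider registry and service caches.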
@@ -0,0 +1,188 @@
+import os
+import logging
+import torch
+import numpy as np
+from typing import Dict, List, Any, Optional, Union
+
+logger = logging.getLogger(__name__)
+
+
+class BgeEmbedBackend:
+    """
+    PyTorch backend for the BGE embedding model.
+    """
+
+    def __init__(self, model_path: Optional[str] = None, device: str = "auto"):
+        """
+        Initialize the BGE embedding backend.
+
+        Args:
+            model_path: Path to the model
+            device: Device to run the model on ("cpu", "cuda", or "auto")
+        """
+        self.model_path = model_path or os.environ.get("BGE_MODEL_PATH", "/models/Bge-m3")
+        self.device = device if device != "auto" else ("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = None
+        self.tokenizer = None
+        self._loaded = False
+
+        # Default configuration
+        self.config = {
+            "normalize": True,
+            "max_length": 512,
+            "pooling_method": "cls"  # Use the CLS token for the sentence embedding
+        }
+
+        self.logger = logger
+
+    def load(self) -> None:
+        """
+        Load the model and tokenizer.
+        """
+        if self._loaded:
+            return
+
+        try:
+            from transformers import AutoModel, AutoTokenizer
+
+            # Load tokenizer
+            self.logger.info(f"Loading BGE tokenizer from {self.model_path}")
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
+
+            # Load model
+            self.logger.info(f"Loading BGE model on {self.device}")
+            if self.device == "cpu":
+                # Full precision; loading without device_map keeps the weights
+                # on the CPU even when a GPU happens to be present
+                self.model = AutoModel.from_pretrained(
+                    self.model_path,
+                    torch_dtype=torch.float32
+                )
+            else:  # cuda
+                self.model = AutoModel.from_pretrained(
+                    self.model_path,
+                    torch_dtype=torch.float16,  # Use half precision on GPU
+                    device_map="auto"
+                )
+
+            self.model.eval()
+            self._loaded = True
+            self.logger.info("BGE model loaded successfully")
+
+        except Exception as e:
+            self.logger.error(f"Failed to load BGE model: {str(e)}")
+            raise
+
+    def unload(self) -> None:
+        """
+        Unload the model and tokenizer.
+        """
+        if not self._loaded:
+            return
+
+        self.model = None
+        self.tokenizer = None
+        self._loaded = False
+
+        # Force garbage collection
+        import gc
+        gc.collect()
+
+        if self.device == "cuda":
+            torch.cuda.empty_cache()
+
+        self.logger.info("BGE model unloaded")
+
+    def embed(self,
+              texts: Union[str, List[str]],
+              normalize: Optional[bool] = None) -> np.ndarray:
+        """
+        Generate embeddings for texts.
+
+        Args:
+            texts: Single text or list of texts to embed
+            normalize: Whether to normalize embeddings (if None, use the default)
+
+        Returns:
+            Numpy array of embeddings, shape [batch_size, embedding_dim]
+        """
+        if not self._loaded:
+            self.load()
+
+        # Handle single text input
+        if isinstance(texts, str):
+            texts = [texts]
+
+        # Use the default normalize setting if not specified
+        if normalize is None:
+            normalize = self.config["normalize"]
+
+        try:
+            # Tokenize the texts
+            inputs = self.tokenizer(
+                texts,
+                padding=True,
+                truncation=True,
+                max_length=self.config["max_length"],
+                return_tensors="pt"
+            ).to(self.device)
+
+            # Generate embeddings
+            with torch.no_grad():
+                outputs = self.model(**inputs)
+
+            # Use the [CLS] token embedding as the sentence embedding
+            embeddings = outputs.last_hidden_state[:, 0, :]
+
+            # Normalize embeddings if required
+            if normalize:
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+            # Convert to a numpy array
+            embeddings_np = embeddings.cpu().numpy()
+
+            return embeddings_np
+
+        except Exception as e:
+            self.logger.error(f"Error during BGE embedding generation: {str(e)}")
+            raise
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """
+        Get information about the model.
+
+        Returns:
+            Dictionary containing model information
+        """
+        return {
+            "name": "bge-m3",
+            "type": "embedding",
+            "device": self.device,
+            "path": self.model_path,
+            "loaded": self._loaded,
+            "embedding_dim": 1024,  # Typical for BGE models
+            "config": self.config
+        }
+
+    def similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
+        """
+        Calculate the cosine similarity between two embeddings.
+
+        Args:
+            embedding1: First embedding vector
+            embedding2: Second embedding vector
+
+        Returns:
+            Cosine similarity score (float between -1 and 1)
+        """
+        from sklearn.metrics.pairwise import cosine_similarity
+
+        # Reshape 1-D vectors to [1, dim], as cosine_similarity expects 2-D input
+        if embedding1.ndim == 1:
+            embedding1 = embedding1.reshape(1, -1)
+        if embedding2.ndim == 1:
+            embedding2 = embedding2.reshape(1, -1)
+
+        # Calculate cosine similarity
+        similarity = cosine_similarity(embedding1, embedding2)[0][0]
+
+        return float(similarity)
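
As above, a minimal usage sketch for this backend (not part of the package; it assumes the BGE-M3 weights are available at the default path or via the BGE_MODEL_PATH environment variable, with torch, transformers, and scikit-learn installed):

    from isa_model.inference.backends.Pytorch.bge_embed_backend import BgeEmbedBackend

    backend = BgeEmbedBackend(device="auto")  # resolves to CUDA if available, else CPU
    backend.load()

    # embed() accepts a single string or a list of strings and returns a
    # [batch_size, embedding_dim] numpy array (1024-dim for BGE-M3).
    vecs = backend.embed(["what is a vector database?",
                          "vector databases store embeddings"])
    print(vecs.shape)

    # Cosine similarity between the two (already normalized) embeddings.
    print(backend.similarity(vecs[0], vecs[1]))

    backend.unload()  # drops the model and empties the CUDA cache on GPU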