isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (76)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_registry.py +273 -46
  3. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  4. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  5. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  6. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  7. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  8. isa_model/eval/__init__.py +56 -0
  9. isa_model/eval/benchmarks.py +469 -0
  10. isa_model/eval/factory.py +582 -0
  11. isa_model/eval/metrics.py +628 -0
  12. isa_model/inference/ai_factory.py +98 -93
  13. isa_model/inference/providers/openai_provider.py +21 -7
  14. isa_model/inference/providers/replicate_provider.py +18 -5
  15. isa_model/inference/providers/triton_provider.py +1 -1
  16. isa_model/inference/services/audio/base_stt_service.py +91 -0
  17. isa_model/inference/services/audio/base_tts_service.py +136 -0
  18. isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
  19. isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
  20. isa_model/inference/services/llm/__init__.py +0 -4
  21. isa_model/inference/services/llm/base_llm_service.py +134 -0
  22. isa_model/inference/services/llm/ollama_llm_service.py +1 -10
  23. isa_model/inference/services/llm/openai_llm_service.py +70 -61
  24. isa_model/inference/services/vision/__init__.py +1 -1
  25. isa_model/inference/services/vision/ollama_vision_service.py +4 -4
  26. isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
  27. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  28. isa_model/training/__init__.py +44 -0
  29. isa_model/training/factory.py +393 -0
  30. isa_model-0.1.1.dist-info/METADATA +327 -0
  31. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
  32. isa_model/deployment/mlflow_gateway/__init__.py +0 -8
  33. isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
  34. isa_model/deployment/unified_multimodal_client.py +0 -341
  35. isa_model/inference/adapter/triton_adapter.py +0 -453
  36. isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
  37. isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
  38. isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
  39. isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
  40. isa_model/inference/backends/__init__.py +0 -53
  41. isa_model/inference/backends/base_backend_client.py +0 -26
  42. isa_model/inference/backends/container_services.py +0 -104
  43. isa_model/inference/backends/local_services.py +0 -72
  44. isa_model/inference/backends/openai_client.py +0 -130
  45. isa_model/inference/backends/replicate_client.py +0 -197
  46. isa_model/inference/backends/third_party_services.py +0 -239
  47. isa_model/inference/backends/triton_client.py +0 -97
  48. isa_model/inference/client_sdk/client.py +0 -134
  49. isa_model/inference/client_sdk/client_data_std.py +0 -34
  50. isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
  51. isa_model/inference/client_sdk/exceptions.py +0 -0
  52. isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
  53. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
  54. isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
  55. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
  56. isa_model/inference/providers/vllm_provider.py +0 -0
  57. isa_model/inference/providers/yyds_provider.py +0 -83
  58. isa_model/inference/services/audio/fish_speech/handler.py +0 -215
  59. isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
  60. isa_model/inference/services/audio/triton_speech_service.py +0 -138
  61. isa_model/inference/services/audio/whisper_service.py +0 -186
  62. isa_model/inference/services/base_tts_service.py +0 -66
  63. isa_model/inference/services/embedding/bge_service.py +0 -183
  64. isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
  65. isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
  66. isa_model/inference/services/llm/gemma_service.py +0 -143
  67. isa_model/inference/services/llm/llama_service.py +0 -143
  68. isa_model/inference/services/llm/replicate_llm_service.py +0 -179
  69. isa_model/inference/services/llm/triton_llm_service.py +0 -230
  70. isa_model/inference/services/vision/replicate_vision_service.py +0 -241
  71. isa_model/inference/services/vision/triton_vision_service.py +0 -199
  72. isa_model-0.1.0.dist-info/METADATA +0 -116
  73. /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
  74. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
  75. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
  76. {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
isa_model/inference/adapter/triton_adapter.py (file removed in 0.1.1)
@@ -1,453 +0,0 @@
- #!/usr/bin/env python3
- """
- Multimodal OpenAI-Compatible Adapter for Triton Inference Server
-
- This adapter translates between OpenAI API format and Triton Inference Server format.
- It supports multiple modalities (text, image, voice) through a unified API.
-
- Features:
- - Chat completions API (text)
- - Image generation API
- - Audio transcription API
- - Embeddings API
-
- The adapter routes requests to the appropriate Triton model based on the task.
- """
-
- import os
- import json
- import time
- import base64
- import logging
- import requests
- import tempfile
- import uvicorn
- import uuid
- from typing import List, Dict, Any, Optional, Union
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Body, BackgroundTasks
- from fastapi.responses import StreamingResponse, FileResponse
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel, Field
- import numpy as np
- from datetime import datetime
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # Initialize FastAPI app
- app = FastAPI(title="Multimodal OpenAI-Compatible Adapter")
-
- # Add CORS middleware
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Constants
- TRITON_URL = os.environ.get("TRITON_URL", "http://localhost:8000")
- DEFAULT_TEXT_MODEL = os.environ.get("DEFAULT_TEXT_MODEL", "llama3_cpu")
- DEFAULT_IMAGE_MODEL = os.environ.get("DEFAULT_IMAGE_MODEL", "stable_diffusion")
- DEFAULT_AUDIO_MODEL = os.environ.get("DEFAULT_AUDIO_MODEL", "whisper_tiny")
- DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "bge_m3")
- DEFAULT_VISION_MODEL = os.environ.get("DEFAULT_VISION_MODEL", "gemma3_4b")
-
- # ===== Schema Definitions =====
-
- class ChatMessage(BaseModel):
-     role: str
-     content: Union[str, List[Dict[str, Any]]]
-     name: Optional[str] = None
-
- class ChatCompletionRequest(BaseModel):
-     model: str
-     messages: List[ChatMessage]
-     temperature: Optional[float] = 0.7
-     top_p: Optional[float] = 1.0
-     max_tokens: Optional[int] = 100
-     stream: Optional[bool] = False
-     stop: Optional[Union[str, List[str]]] = None
-
- class ImageGenerationRequest(BaseModel):
-     model: str
-     prompt: str
-     n: Optional[int] = 1
-     size: Optional[str] = "1024x1024"
-     response_format: Optional[str] = "url"
-
- class AudioTranscriptionRequest(BaseModel):
-     model: str
-     file: str  # Base64 encoded audio
-     response_format: Optional[str] = "text"
-     language: Optional[str] = "en"
-
- class EmbeddingRequest(BaseModel):
-     model: str
-     input: Union[str, List[str]]
-     encoding_format: Optional[str] = "float"
-
- # ===== Helper Functions =====
-
- def generate_response_id(prefix: str = "res") -> str:
-     """Generate a unique response ID."""
-     return f"{prefix}-{uuid.uuid4()}"
-
- def format_chat_response(content: str, model: str) -> Dict[str, Any]:
-     """Format chat completion response in OpenAI format."""
-     return {
-         "id": generate_response_id("chatcmpl"),
-         "object": "chat.completion",
-         "created": int(time.time()),
-         "model": model,
-         "choices": [
-             {
-                 "index": 0,
-                 "message": {
-                     "role": "assistant",
-                     "content": content
-                 },
-                 "finish_reason": "stop"
-             }
-         ],
-         "usage": {
-             "prompt_tokens": 0,  # We don't track these yet
-             "completion_tokens": 0,
-             "total_tokens": 0
-         }
-     }
-
- def format_image_response(image_data: str, model: str) -> Dict[str, Any]:
-     """Format image generation response in OpenAI format."""
-     return {
-         "created": int(time.time()),
-         "data": [
-             {
-                 "url": f"data:image/png;base64,{image_data}"
-             }
-         ]
-     }
-
- def format_audio_response(text: str, model: str) -> Dict[str, Any]:
-     """Format audio transcription response in OpenAI format."""
-     return {
-         "text": text
-     }
-
- def format_embedding_response(embeddings: List[List[float]], model: str) -> Dict[str, Any]:
-     """Format embedding response in OpenAI format."""
-     data = []
-     for i, embedding in enumerate(embeddings):
-         data.append({
-             "object": "embedding",
-             "embedding": embedding,
-             "index": i
-         })
-
-     return {
-         "object": "list",
-         "data": data,
-         "model": model,
-         "usage": {
-             "prompt_tokens": 0,
-             "total_tokens": 0
-         }
-     }
-
- def extract_content_from_messages(messages: List[ChatMessage]) -> Dict[str, Any]:
-     """Extract content from messages for Triton input."""
-     formatted_content = ""
-     image_data = None
-
-     for msg in messages:
-         # Handle both string content and list of content parts
-         if isinstance(msg.content, str):
-             content = msg.content
-             formatted_content += f"{msg.role.capitalize()}: {content}\n"
-         else:
-             # For multimodal content, extract text and image parts
-             text_parts = []
-             for part in msg.content:
-                 if part.get("type") == "text":
-                     text_parts.append(part.get("text", ""))
-                 elif part.get("type") == "image_url":
-                     # Extract image from URL (assuming base64 encoded)
-                     image_url = part.get("image_url", {}).get("url", "")
-                     if image_url.startswith("data:image/"):
-                         # Extract the base64 part
-                         image_data = image_url.split(",")[1]
-
-             # Add text parts to formatted content
-             content = " ".join(text_parts)
-             formatted_content += f"{msg.role.capitalize()}: {content}\n"
-
-     formatted_content += "Assistant:"
-     return {"text": formatted_content, "image": image_data}
-
- # ===== API Routes =====
-
- @app.post("/v1/chat/completions")
- async def chat_completions(request: ChatCompletionRequest):
-     """Handle chat completion requests."""
-     logger.info(f"Received request: {request.dict()}")
-
-     # Extract the formatted content from messages
-     content = extract_content_from_messages(request.messages)
-     input_text = content["text"]
-     image_data = content["image"]
-
-     # Use requested model or default
-     model = request.model if request.model != "default" else DEFAULT_TEXT_MODEL
-
-     # Prepare request for Triton
-     triton_request = {
-         "inputs": [
-             {
-                 "name": "text_input",
-                 "shape": [1, 1],
-                 "datatype": "BYTES",
-                 "data": [input_text]
-             },
-             {
-                 "name": "max_tokens",
-                 "shape": [1, 1],
-                 "datatype": "INT32",
-                 "data": [request.max_tokens]
-             },
-             {
-                 "name": "temperature",
-                 "shape": [1, 1],
-                 "datatype": "FP32",
-                 "data": [request.temperature]
-             }
-         ]
-     }
-
-     # Add image input if available and using vision model
-     if image_data is not None and model == "gemma3_4b":
-         try:
-             # Decode base64 image
-             from PIL import Image
-             import io
-             import numpy as np
-
-             # Decode and preprocess image
-             image_bytes = base64.b64decode(image_data)
-             image = Image.open(io.BytesIO(image_bytes))
-
-             # Resize to expected size (224x224 for most vision models)
-             image = image.resize((224, 224))
-
-             # Convert to RGB if not already
-             if image.mode != "RGB":
-                 image = image.convert("RGB")
-
-             # Convert to numpy array and normalize
-             image_array = np.array(image).astype(np.float32) / 255.0
-
-             # Reorder from HWC to CHW format
-             image_array = np.transpose(image_array, (2, 0, 1))
-
-             # Add image input to Triton request
-             triton_request["inputs"].append({
-                 "name": "image_input",
-                 "shape": list(image_array.shape),
-                 "datatype": "FP32",
-                 "data": image_array.flatten().tolist()
-             })
-
-             logger.info("Added image input to request")
-         except Exception as e:
-             logger.error(f"Error processing image: {str(e)}")
-
-     logger.info(f"Sending to Triton: {triton_request}")
-
-     # Send to Triton
-     try:
-         response = requests.post(
-             f"{TRITON_URL}/v2/models/{model}/infer",
-             json=triton_request
-         )
-         response.raise_for_status()
-         triton_response = response.json()
-         logger.info(f"Triton response status: {response.status_code}")
-         logger.info(f"Triton response: {triton_response}")
-
-         # Extract text output
-         output_data = triton_response["outputs"][0]["data"][0]
-
-         # Format response
-         return format_chat_response(output_data, model)
-
-     except Exception as e:
-         logger.error(f"Error calling Triton: {str(e)}")
-         raise HTTPException(status_code=500, detail=f"Error calling model: {str(e)}")
-
- @app.post("/v1/images/generations")
- async def generate_images(request: ImageGenerationRequest):
-     """Handle image generation requests."""
-     logger.info(f"Received image generation request: {request.dict()}")
-
-     # Use requested model or default
-     model = request.model if request.model != "default" else DEFAULT_IMAGE_MODEL
-
-     # For demo purposes - in a real implementation, this would call the Triton image model
-     # Here we'll just simulate image generation with a placeholder
-     try:
-         # Simulate Triton call (replace with actual call to Triton when image model is available)
-         # Return a placeholder image for demonstration
-         with open("placeholder.png", "rb") as f:
-             image_data = base64.b64encode(f.read()).decode("utf-8")
-
-         return format_image_response(image_data, model)
-
-     except Exception as e:
-         logger.error(f"Error generating image: {str(e)}")
-         raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")
-
- @app.post("/v1/audio/transcriptions")
- async def transcribe_audio(request: AudioTranscriptionRequest):
-     """Handle audio transcription requests."""
-     logger.info(f"Received audio transcription request: {request.dict()}")
-
-     # Use requested model or default
-     model = request.model if request.model != "default" else DEFAULT_AUDIO_MODEL
-
-     try:
-         # Decode the base64 audio
-         audio_data = base64.b64decode(request.file)
-
-         # Save to temporary file
-         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
-             temp_file.write(audio_data)
-             temp_file_path = temp_file.name
-
-         # Load and preprocess audio for Whisper
-         import librosa
-         import numpy as np
-
-         # Load audio file and resample to 16kHz for Whisper
-         audio_array, _ = librosa.load(temp_file_path, sr=16000, mono=True)
-
-         # Prepare request for Triton
-         triton_request = {
-             "inputs": [
-                 {
-                     "name": "audio_input",
-                     "shape": [len(audio_array)],
-                     "datatype": "FP32",
-                     "data": audio_array.tolist()
-                 }
-             ]
-         }
-
-         # Add language if provided
-         if hasattr(request, 'language') and request.language:
-             triton_request["inputs"].append({
-                 "name": "language",
-                 "shape": [1, 1],
-                 "datatype": "BYTES",
-                 "data": [request.language]
-             })
-
-         # Clean up temp file
-         os.unlink(temp_file_path)
-
-         # Send to Triton
-         response = requests.post(
-             f"{TRITON_URL}/v2/models/{model}/infer",
-             json=triton_request
-         )
-         response.raise_for_status()
-         triton_response = response.json()
-
-         # Extract text output
-         transcription = triton_response["outputs"][0]["data"][0]
-
-         return format_audio_response(transcription, model)
-
-     except Exception as e:
-         logger.error(f"Error transcribing audio: {str(e)}")
-         # Fallback response
-         return format_audio_response(
-             "This is a placeholder transcription. In production, this would be generated by the Whisper model.",
-             model
-         )
-
- @app.post("/v1/embeddings")
- async def create_embeddings(request: EmbeddingRequest):
-     """Handle embedding requests."""
-     logger.info(f"Received embedding request: {request.dict()}")
-
-     # Use requested model or default
-     model = request.model if request.model != "default" else DEFAULT_EMBEDDING_MODEL
-
-     # Convert input to list if it's a single string
-     inputs = request.input if isinstance(request.input, list) else [request.input]
-
-     try:
-         # Process each input text
-         all_embeddings = []
-
-         for text in inputs:
-             # Prepare request for Triton
-             triton_request = {
-                 "inputs": [
-                     {
-                         "name": "text_input",
-                         "shape": [1, 1],
-                         "datatype": "BYTES",
-                         "data": [text]
-                     }
-                 ]
-             }
-
-             # Send to Triton
-             response = requests.post(
-                 f"{TRITON_URL}/v2/models/{model}/infer",
-                 json=triton_request
-             )
-             response.raise_for_status()
-             triton_response = response.json()
-
-             # Extract embedding
-             embedding = triton_response["outputs"][0]["data"]
-             all_embeddings.append(embedding)
-
-         return format_embedding_response(all_embeddings, model)
-
-     except Exception as e:
-         logger.error(f"Error creating embeddings: {str(e)}")
-
-         # Fallback - return random embeddings
-         embeddings = []
-         for _ in inputs:
-             # Generate a random embedding vector of dimension 1024 (BGE-M3)
-             embedding = np.random.normal(0, 1, 1024).tolist()
-             embeddings.append(embedding)
-
-         return format_embedding_response(embeddings, model)
-
- @app.get("/health")
- async def health_check():
-     """Health check endpoint."""
-     return {"status": "healthy"}
-
- # ===== Main =====
-
- if __name__ == "__main__":
-     # Create placeholder image for demo
-     try:
-         if not os.path.exists("placeholder.png"):
-             # Create a simple 256x256 black image
-             import numpy as np
-             from PIL import Image
-             img = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
-             img.save("placeholder.png")
-     except ImportError:
-         logger.warning("PIL not installed. Cannot create placeholder image.")
-
-     # Start server
-     uvicorn.run(app, host="0.0.0.0", port=8300)
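For context, the removed `triton_adapter.py` above exposed OpenAI-compatible routes and forwarded them to Triton's KServe v2 `/v2/models/{model}/infer` endpoint. A minimal client sketch against that adapter might look like the following; it assumes the 0.1.0 adapter is running locally on port 8300 (as in its `__main__` block) with a reachable Triton text model behind it, and is not part of the package itself.

```python
# Hypothetical client call against the removed adapter's chat route.
import requests

payload = {
    "model": "default",  # "default" falls back to DEFAULT_TEXT_MODEL ("llama3_cpu")
    "messages": [
        {"role": "user", "content": "Hello, what can you do?"}
    ],
    "temperature": 0.7,
    "max_tokens": 100,
}

# The adapter formats the Triton output into an OpenAI-style chat.completion object.
resp = requests.post("http://localhost:8300/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```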
isa_model/inference/backends/Pytorch/bge_embed_backend.py (file removed in 0.1.1)
@@ -1,188 +0,0 @@
- import os
- import logging
- import torch
- import numpy as np
- from typing import Dict, List, Any, Optional, Union
-
- logger = logging.getLogger(__name__)
-
-
- class BgeEmbedBackend:
-     """
-     PyTorch backend for the BGE embedding model.
-     """
-
-     def __init__(self, model_path: Optional[str] = None, device: str = "auto"):
-         """
-         Initialize the BGE embedding backend.
-
-         Args:
-             model_path: Path to the model
-             device: Device to run the model on ("cpu", "cuda", or "auto")
-         """
-         self.model_path = model_path or os.environ.get("BGE_MODEL_PATH", "/models/Bge-m3")
-         self.device = device if device != "auto" else ("cuda" if torch.cuda.is_available() else "cpu")
-         self.model = None
-         self.tokenizer = None
-         self._loaded = False
-
-         # Default configuration
-         self.config = {
-             "normalize": True,
-             "max_length": 512,
-             "pooling_method": "cls"  # Use CLS token for sentence embedding
-         }
-
-         self.logger = logger
-
-     def load(self) -> None:
-         """
-         Load the model and tokenizer.
-         """
-         if self._loaded:
-             return
-
-         try:
-             from transformers import AutoModel, AutoTokenizer
-
-             # Load tokenizer
-             self.logger.info(f"Loading BGE tokenizer from {self.model_path}")
-             self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
-
-             # Load model
-             self.logger.info(f"Loading BGE model on {self.device}")
-             if self.device == "cpu":
-                 self.model = AutoModel.from_pretrained(
-                     self.model_path,
-                     torch_dtype=torch.float32,
-                     device_map="auto"
-                 )
-             else:  # cuda
-                 self.model = AutoModel.from_pretrained(
-                     self.model_path,
-                     torch_dtype=torch.float16,  # Use half precision on GPU
-                     device_map="auto"
-                 )
-
-             self.model.eval()
-             self._loaded = True
-             self.logger.info("BGE model loaded successfully")
-
-         except Exception as e:
-             self.logger.error(f"Failed to load BGE model: {str(e)}")
-             raise
-
-     def unload(self) -> None:
-         """
-         Unload the model and tokenizer.
-         """
-         if not self._loaded:
-             return
-
-         self.model = None
-         self.tokenizer = None
-         self._loaded = False
-
-         # Force garbage collection
-         import gc
-         gc.collect()
-
-         if self.device == "cuda":
-             torch.cuda.empty_cache()
-
-         self.logger.info("BGE model unloaded")
-
-     def embed(self,
-               texts: Union[str, List[str]],
-               normalize: Optional[bool] = None) -> np.ndarray:
-         """
-         Generate embeddings for texts.
-
-         Args:
-             texts: Single text or list of texts to embed
-             normalize: Whether to normalize embeddings (if None, use default)
-
-         Returns:
-             Numpy array of embeddings, shape [batch_size, embedding_dim]
-         """
-         if not self._loaded:
-             self.load()
-
-         # Handle single text input
-         if isinstance(texts, str):
-             texts = [texts]
-
-         # Use default normalize setting if not specified
-         if normalize is None:
-             normalize = self.config["normalize"]
-
-         try:
-             # Tokenize the texts
-             inputs = self.tokenizer(
-                 texts,
-                 padding=True,
-                 truncation=True,
-                 max_length=self.config["max_length"],
-                 return_tensors="pt"
-             ).to(self.device)
-
-             # Generate embeddings
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-
-             # Use [CLS] token embedding as the sentence embedding
-             embeddings = outputs.last_hidden_state[:, 0, :]
-
-             # Normalize embeddings if required
-             if normalize:
-                 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-
-             # Convert to numpy array
-             embeddings_np = embeddings.cpu().numpy()
-
-             return embeddings_np
-
-         except Exception as e:
-             self.logger.error(f"Error during BGE embedding generation: {str(e)}")
-             raise
-
-     def get_model_info(self) -> Dict[str, Any]:
-         """
-         Get information about the model.
-
-         Returns:
-             Dictionary containing model information
-         """
-         return {
-             "name": "bge-m3",
-             "type": "embedding",
-             "device": self.device,
-             "path": self.model_path,
-             "loaded": self._loaded,
-             "embedding_dim": 1024,  # Typical for BGE models
-             "config": self.config
-         }
-
-     def similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
-         """
-         Calculate cosine similarity between two embeddings.
-
-         Args:
-             embedding1: First embedding vector
-             embedding2: Second embedding vector
-
-         Returns:
-             Cosine similarity score (float between -1 and 1)
-         """
-         from sklearn.metrics.pairwise import cosine_similarity
-
-         # Reshape if needed
-         if embedding1.ndim == 1:
-             embedding1 = embedding1.reshape(1, -1)
-         if embedding2.ndim == 1:
-             embedding2 = embedding2.reshape(1, -1)
-
-         # Calculate cosine similarity
-         similarity = cosine_similarity(embedding1, embedding2)[0][0]
-
-         return float(similarity)
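For reference, the removed `BgeEmbedBackend` above was a self-contained PyTorch wrapper around a local BGE-M3 checkpoint. A usage sketch is shown below; it assumes the 0.1.0 package, a checkpoint available at the path expected by the backend, and an import path derived from the deleted file's location.

```python
# Usage sketch for the removed BgeEmbedBackend (present only in 0.1.0).
from isa_model.inference.backends.Pytorch.bge_embed_backend import BgeEmbedBackend

# device="auto" picks CUDA when available, otherwise CPU.
backend = BgeEmbedBackend(model_path="/models/Bge-m3", device="auto")
backend.load()  # loads tokenizer and model via transformers AutoTokenizer/AutoModel

# embed() accepts a single string or a list and returns a [batch, 1024] numpy array
# of (by default L2-normalized) CLS-token embeddings.
vectors = backend.embed(["What is BGE-M3?", "BGE-M3 is a multilingual embedding model."])
score = backend.similarity(vectors[0], vectors[1])  # cosine similarity in [-1, 1]
print(vectors.shape, round(score, 3))

backend.unload()  # drops the model and empties the CUDA cache when on GPU
```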