isa-model 0.1.0__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +1 -1
- isa_model/core/model_registry.py +273 -46
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/eval/__init__.py +56 -0
- isa_model/eval/benchmarks.py +469 -0
- isa_model/eval/factory.py +582 -0
- isa_model/eval/metrics.py +628 -0
- isa_model/inference/ai_factory.py +98 -93
- isa_model/inference/providers/openai_provider.py +21 -7
- isa_model/inference/providers/replicate_provider.py +18 -5
- isa_model/inference/providers/triton_provider.py +1 -1
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/{yyds_audio_service.py → openai_tts_service.py} +4 -4
- isa_model/inference/services/embedding/ollama_embed_service.py +48 -36
- isa_model/inference/services/llm/__init__.py +0 -4
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +1 -10
- isa_model/inference/services/llm/openai_llm_service.py +70 -61
- isa_model/inference/services/vision/__init__.py +1 -1
- isa_model/inference/services/vision/ollama_vision_service.py +4 -4
- isa_model/inference/services/vision/{yyds_vision_service.py → openai_vision_service.py} +5 -5
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/training/__init__.py +44 -0
- isa_model/training/factory.py +393 -0
- isa_model-0.1.1.dist-info/METADATA +327 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/RECORD +35 -60
- isa_model/deployment/mlflow_gateway/__init__.py +0 -8
- isa_model/deployment/mlflow_gateway/start_gateway.py +0 -65
- isa_model/deployment/unified_multimodal_client.py +0 -341
- isa_model/inference/adapter/triton_adapter.py +0 -453
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +0 -188
- isa_model/inference/backends/Pytorch/gemma_backend.py +0 -167
- isa_model/inference/backends/Pytorch/llama_backend.py +0 -166
- isa_model/inference/backends/Pytorch/whisper_backend.py +0 -194
- isa_model/inference/backends/__init__.py +0 -53
- isa_model/inference/backends/base_backend_client.py +0 -26
- isa_model/inference/backends/container_services.py +0 -104
- isa_model/inference/backends/local_services.py +0 -72
- isa_model/inference/backends/openai_client.py +0 -130
- isa_model/inference/backends/replicate_client.py +0 -197
- isa_model/inference/backends/third_party_services.py +0 -239
- isa_model/inference/backends/triton_client.py +0 -97
- isa_model/inference/client_sdk/client.py +0 -134
- isa_model/inference/client_sdk/client_data_std.py +0 -34
- isa_model/inference/client_sdk/client_sdk_schema.py +0 -16
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +0 -174
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +0 -250
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +0 -76
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +0 -195
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +0 -83
- isa_model/inference/services/audio/fish_speech/handler.py +0 -215
- isa_model/inference/services/audio/runpod_tts_fish_service.py +0 -212
- isa_model/inference/services/audio/triton_speech_service.py +0 -138
- isa_model/inference/services/audio/whisper_service.py +0 -186
- isa_model/inference/services/base_tts_service.py +0 -66
- isa_model/inference/services/embedding/bge_service.py +0 -183
- isa_model/inference/services/embedding/ollama_rerank_service.py +0 -118
- isa_model/inference/services/embedding/onnx_rerank_service.py +0 -73
- isa_model/inference/services/llm/gemma_service.py +0 -143
- isa_model/inference/services/llm/llama_service.py +0 -143
- isa_model/inference/services/llm/replicate_llm_service.py +0 -179
- isa_model/inference/services/llm/triton_llm_service.py +0 -230
- isa_model/inference/services/vision/replicate_vision_service.py +0 -241
- isa_model/inference/services/vision/triton_vision_service.py +0 -199
- isa_model-0.1.0.dist-info/METADATA +0 -116
- /isa_model/inference/{client_sdk/__init__.py → services/embedding/openai_embed_service.py} +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/WHEEL +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/licenses/LICENSE +0 -0
- {isa_model-0.1.0.dist-info → isa_model-0.1.1.dist-info}/top_level.txt +0 -0
@@ -1,453 +0,0 @@
|
|
1
|
-
#!/usr/bin/env python3
|
2
|
-
"""
|
3
|
-
Multimodal OpenAI-Compatible Adapter for Triton Inference Server
|
4
|
-
|
5
|
-
This adapter translates between OpenAI API format and Triton Inference Server format.
|
6
|
-
It supports multiple modalities (text, image, voice) through a unified API.
|
7
|
-
|
8
|
-
Features:
|
9
|
-
- Chat completions API (text)
|
10
|
-
- Image generation API
|
11
|
-
- Audio transcription API
|
12
|
-
- Embeddings API
|
13
|
-
|
14
|
-
The adapter routes requests to the appropriate Triton model based on the task.
|
15
|
-
"""
|
16
|
-
|
17
|
-
import os
|
18
|
-
import json
|
19
|
-
import time
|
20
|
-
import base64
|
21
|
-
import logging
|
22
|
-
import requests
|
23
|
-
import tempfile
|
24
|
-
import uvicorn
|
25
|
-
import uuid
|
26
|
-
from typing import List, Dict, Any, Optional, Union
|
27
|
-
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Body, BackgroundTasks
|
28
|
-
from fastapi.responses import StreamingResponse, FileResponse
|
29
|
-
from fastapi.middleware.cors import CORSMiddleware
|
30
|
-
from pydantic import BaseModel, Field
|
31
|
-
import numpy as np
|
32
|
-
from datetime import datetime
|
33
|
-
|
34
|
-
# Configure logging
|
35
|
-
logging.basicConfig(level=logging.INFO)
|
36
|
-
logger = logging.getLogger(__name__)
|
37
|
-
|
38
|
-
# Initialize FastAPI app
|
39
|
-
app = FastAPI(title="Multimodal OpenAI-Compatible Adapter")
|
40
|
-
|
41
|
-
# Add CORS middleware
|
42
|
-
app.add_middleware(
|
43
|
-
CORSMiddleware,
|
44
|
-
allow_origins=["*"],
|
45
|
-
allow_credentials=True,
|
46
|
-
allow_methods=["*"],
|
47
|
-
allow_headers=["*"],
|
48
|
-
)
|
49
|
-
|
50
|
-
# Constants
|
51
|
-
TRITON_URL = os.environ.get("TRITON_URL", "http://localhost:8000")
|
52
|
-
DEFAULT_TEXT_MODEL = os.environ.get("DEFAULT_TEXT_MODEL", "llama3_cpu")
|
53
|
-
DEFAULT_IMAGE_MODEL = os.environ.get("DEFAULT_IMAGE_MODEL", "stable_diffusion")
|
54
|
-
DEFAULT_AUDIO_MODEL = os.environ.get("DEFAULT_AUDIO_MODEL", "whisper_tiny")
|
55
|
-
DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "bge_m3")
|
56
|
-
DEFAULT_VISION_MODEL = os.environ.get("DEFAULT_VISION_MODEL", "gemma3_4b")
|
57
|
-
|
58
|
-
# ===== Schema Definitions =====
|
59
|
-
|
60
|
-
class ChatMessage(BaseModel):
|
61
|
-
role: str
|
62
|
-
content: Union[str, List[Dict[str, Any]]]
|
63
|
-
name: Optional[str] = None
|
64
|
-
|
65
|
-
class ChatCompletionRequest(BaseModel):
|
66
|
-
model: str
|
67
|
-
messages: List[ChatMessage]
|
68
|
-
temperature: Optional[float] = 0.7
|
69
|
-
top_p: Optional[float] = 1.0
|
70
|
-
max_tokens: Optional[int] = 100
|
71
|
-
stream: Optional[bool] = False
|
72
|
-
stop: Optional[Union[str, List[str]]] = None
|
73
|
-
|
74
|
-
class ImageGenerationRequest(BaseModel):
|
75
|
-
model: str
|
76
|
-
prompt: str
|
77
|
-
n: Optional[int] = 1
|
78
|
-
size: Optional[str] = "1024x1024"
|
79
|
-
response_format: Optional[str] = "url"
|
80
|
-
|
81
|
-
class AudioTranscriptionRequest(BaseModel):
|
82
|
-
model: str
|
83
|
-
file: str # Base64 encoded audio
|
84
|
-
response_format: Optional[str] = "text"
|
85
|
-
language: Optional[str] = "en"
|
86
|
-
|
87
|
-
class EmbeddingRequest(BaseModel):
|
88
|
-
model: str
|
89
|
-
input: Union[str, List[str]]
|
90
|
-
encoding_format: Optional[str] = "float"
|
91
|
-
|
92
|
-
# ===== Helper Functions =====
|
93
|
-
|
94
|
-
def generate_response_id(prefix: str = "res") -> str:
|
95
|
-
"""Generate a unique response ID."""
|
96
|
-
return f"{prefix}-{uuid.uuid4()}"
|
97
|
-
|
98
|
-
def format_chat_response(content: str, model: str) -> Dict[str, Any]:
|
99
|
-
"""Format chat completion response in OpenAI format."""
|
100
|
-
return {
|
101
|
-
"id": generate_response_id("chatcmpl"),
|
102
|
-
"object": "chat.completion",
|
103
|
-
"created": int(time.time()),
|
104
|
-
"model": model,
|
105
|
-
"choices": [
|
106
|
-
{
|
107
|
-
"index": 0,
|
108
|
-
"message": {
|
109
|
-
"role": "assistant",
|
110
|
-
"content": content
|
111
|
-
},
|
112
|
-
"finish_reason": "stop"
|
113
|
-
}
|
114
|
-
],
|
115
|
-
"usage": {
|
116
|
-
"prompt_tokens": 0, # We don't track these yet
|
117
|
-
"completion_tokens": 0,
|
118
|
-
"total_tokens": 0
|
119
|
-
}
|
120
|
-
}
|
121
|
-
|
122
|
-
def format_image_response(image_data: str, model: str) -> Dict[str, Any]:
|
123
|
-
"""Format image generation response in OpenAI format."""
|
124
|
-
return {
|
125
|
-
"created": int(time.time()),
|
126
|
-
"data": [
|
127
|
-
{
|
128
|
-
"url": f"data:image/png;base64,{image_data}"
|
129
|
-
}
|
130
|
-
]
|
131
|
-
}
|
132
|
-
|
133
|
-
def format_audio_response(text: str, model: str) -> Dict[str, Any]:
|
134
|
-
"""Format audio transcription response in OpenAI format."""
|
135
|
-
return {
|
136
|
-
"text": text
|
137
|
-
}
|
138
|
-
|
139
|
-
def format_embedding_response(embeddings: List[List[float]], model: str) -> Dict[str, Any]:
|
140
|
-
"""Format embedding response in OpenAI format."""
|
141
|
-
data = []
|
142
|
-
for i, embedding in enumerate(embeddings):
|
143
|
-
data.append({
|
144
|
-
"object": "embedding",
|
145
|
-
"embedding": embedding,
|
146
|
-
"index": i
|
147
|
-
})
|
148
|
-
|
149
|
-
return {
|
150
|
-
"object": "list",
|
151
|
-
"data": data,
|
152
|
-
"model": model,
|
153
|
-
"usage": {
|
154
|
-
"prompt_tokens": 0,
|
155
|
-
"total_tokens": 0
|
156
|
-
}
|
157
|
-
}
|
158
|
-
|
159
|
-
def extract_content_from_messages(messages: List[ChatMessage]) -> Dict[str, Any]:
|
160
|
-
"""Extract content from messages for Triton input."""
|
161
|
-
formatted_content = ""
|
162
|
-
image_data = None
|
163
|
-
|
164
|
-
for msg in messages:
|
165
|
-
# Handle both string content and list of content parts
|
166
|
-
if isinstance(msg.content, str):
|
167
|
-
content = msg.content
|
168
|
-
formatted_content += f"{msg.role.capitalize()}: {content}\n"
|
169
|
-
else:
|
170
|
-
# For multimodal content, extract text and image parts
|
171
|
-
text_parts = []
|
172
|
-
for part in msg.content:
|
173
|
-
if part.get("type") == "text":
|
174
|
-
text_parts.append(part.get("text", ""))
|
175
|
-
elif part.get("type") == "image_url":
|
176
|
-
# Extract image from URL (assuming base64 encoded)
|
177
|
-
image_url = part.get("image_url", {}).get("url", "")
|
178
|
-
if image_url.startswith("data:image/"):
|
179
|
-
# Extract the base64 part
|
180
|
-
image_data = image_url.split(",")[1]
|
181
|
-
|
182
|
-
# Add text parts to formatted content
|
183
|
-
content = " ".join(text_parts)
|
184
|
-
formatted_content += f"{msg.role.capitalize()}: {content}\n"
|
185
|
-
|
186
|
-
formatted_content += "Assistant:"
|
187
|
-
return {"text": formatted_content, "image": image_data}
|
188
|
-
|
189
|
-
# ===== API Routes =====
|
190
|
-
|
191
|
-
@app.post("/v1/chat/completions")
|
192
|
-
async def chat_completions(request: ChatCompletionRequest):
|
193
|
-
"""Handle chat completion requests."""
|
194
|
-
logger.info(f"Received request: {request.dict()}")
|
195
|
-
|
196
|
-
# Extract the formatted content from messages
|
197
|
-
content = extract_content_from_messages(request.messages)
|
198
|
-
input_text = content["text"]
|
199
|
-
image_data = content["image"]
|
200
|
-
|
201
|
-
# Use requested model or default
|
202
|
-
model = request.model if request.model != "default" else DEFAULT_TEXT_MODEL
|
203
|
-
|
204
|
-
# Prepare request for Triton
|
205
|
-
triton_request = {
|
206
|
-
"inputs": [
|
207
|
-
{
|
208
|
-
"name": "text_input",
|
209
|
-
"shape": [1, 1],
|
210
|
-
"datatype": "BYTES",
|
211
|
-
"data": [input_text]
|
212
|
-
},
|
213
|
-
{
|
214
|
-
"name": "max_tokens",
|
215
|
-
"shape": [1, 1],
|
216
|
-
"datatype": "INT32",
|
217
|
-
"data": [request.max_tokens]
|
218
|
-
},
|
219
|
-
{
|
220
|
-
"name": "temperature",
|
221
|
-
"shape": [1, 1],
|
222
|
-
"datatype": "FP32",
|
223
|
-
"data": [request.temperature]
|
224
|
-
}
|
225
|
-
]
|
226
|
-
}
|
227
|
-
|
228
|
-
# Add image input if available and using vision model
|
229
|
-
if image_data is not None and model == "gemma3_4b":
|
230
|
-
try:
|
231
|
-
# Decode base64 image
|
232
|
-
from PIL import Image
|
233
|
-
import io
|
234
|
-
import numpy as np
|
235
|
-
|
236
|
-
# Decode and preprocess image
|
237
|
-
image_bytes = base64.b64decode(image_data)
|
238
|
-
image = Image.open(io.BytesIO(image_bytes))
|
239
|
-
|
240
|
-
# Resize to expected size (224x224 for most vision models)
|
241
|
-
image = image.resize((224, 224))
|
242
|
-
|
243
|
-
# Convert to RGB if not already
|
244
|
-
if image.mode != "RGB":
|
245
|
-
image = image.convert("RGB")
|
246
|
-
|
247
|
-
# Convert to numpy array and normalize
|
248
|
-
image_array = np.array(image).astype(np.float32) / 255.0
|
249
|
-
|
250
|
-
# Reorder from HWC to CHW format
|
251
|
-
image_array = np.transpose(image_array, (2, 0, 1))
|
252
|
-
|
253
|
-
# Add image input to Triton request
|
254
|
-
triton_request["inputs"].append({
|
255
|
-
"name": "image_input",
|
256
|
-
"shape": list(image_array.shape),
|
257
|
-
"datatype": "FP32",
|
258
|
-
"data": image_array.flatten().tolist()
|
259
|
-
})
|
260
|
-
|
261
|
-
logger.info("Added image input to request")
|
262
|
-
except Exception as e:
|
263
|
-
logger.error(f"Error processing image: {str(e)}")
|
264
|
-
|
265
|
-
logger.info(f"Sending to Triton: {triton_request}")
|
266
|
-
|
267
|
-
# Send to Triton
|
268
|
-
try:
|
269
|
-
response = requests.post(
|
270
|
-
f"{TRITON_URL}/v2/models/{model}/infer",
|
271
|
-
json=triton_request
|
272
|
-
)
|
273
|
-
response.raise_for_status()
|
274
|
-
triton_response = response.json()
|
275
|
-
logger.info(f"Triton response status: {response.status_code}")
|
276
|
-
logger.info(f"Triton response: {triton_response}")
|
277
|
-
|
278
|
-
# Extract text output
|
279
|
-
output_data = triton_response["outputs"][0]["data"][0]
|
280
|
-
|
281
|
-
# Format response
|
282
|
-
return format_chat_response(output_data, model)
|
283
|
-
|
284
|
-
except Exception as e:
|
285
|
-
logger.error(f"Error calling Triton: {str(e)}")
|
286
|
-
raise HTTPException(status_code=500, detail=f"Error calling model: {str(e)}")
|
287
|
-
|
288
|
-
@app.post("/v1/images/generations")
|
289
|
-
async def generate_images(request: ImageGenerationRequest):
|
290
|
-
"""Handle image generation requests."""
|
291
|
-
logger.info(f"Received image generation request: {request.dict()}")
|
292
|
-
|
293
|
-
# Use requested model or default
|
294
|
-
model = request.model if request.model != "default" else DEFAULT_IMAGE_MODEL
|
295
|
-
|
296
|
-
# For demo purposes - in a real implementation, this would call the Triton image model
|
297
|
-
# Here we'll just simulate image generation with a placeholder
|
298
|
-
try:
|
299
|
-
# Simulate Triton call (replace with actual call to Triton when image model is available)
|
300
|
-
# Return a placeholder image for demonstration
|
301
|
-
with open("placeholder.png", "rb") as f:
|
302
|
-
image_data = base64.b64encode(f.read()).decode("utf-8")
|
303
|
-
|
304
|
-
return format_image_response(image_data, model)
|
305
|
-
|
306
|
-
except Exception as e:
|
307
|
-
logger.error(f"Error generating image: {str(e)}")
|
308
|
-
raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")
|
309
|
-
|
310
|
-
@app.post("/v1/audio/transcriptions")
|
311
|
-
async def transcribe_audio(request: AudioTranscriptionRequest):
|
312
|
-
"""Handle audio transcription requests."""
|
313
|
-
logger.info(f"Received audio transcription request: {request.dict()}")
|
314
|
-
|
315
|
-
# Use requested model or default
|
316
|
-
model = request.model if request.model != "default" else DEFAULT_AUDIO_MODEL
|
317
|
-
|
318
|
-
try:
|
319
|
-
# Decode the base64 audio
|
320
|
-
audio_data = base64.b64decode(request.file)
|
321
|
-
|
322
|
-
# Save to temporary file
|
323
|
-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
|
324
|
-
temp_file.write(audio_data)
|
325
|
-
temp_file_path = temp_file.name
|
326
|
-
|
327
|
-
# Load and preprocess audio for Whisper
|
328
|
-
import librosa
|
329
|
-
import numpy as np
|
330
|
-
|
331
|
-
# Load audio file and resample to 16kHz for Whisper
|
332
|
-
audio_array, _ = librosa.load(temp_file_path, sr=16000, mono=True)
|
333
|
-
|
334
|
-
# Prepare request for Triton
|
335
|
-
triton_request = {
|
336
|
-
"inputs": [
|
337
|
-
{
|
338
|
-
"name": "audio_input",
|
339
|
-
"shape": [len(audio_array)],
|
340
|
-
"datatype": "FP32",
|
341
|
-
"data": audio_array.tolist()
|
342
|
-
}
|
343
|
-
]
|
344
|
-
}
|
345
|
-
|
346
|
-
# Add language if provided
|
347
|
-
if hasattr(request, 'language') and request.language:
|
348
|
-
triton_request["inputs"].append({
|
349
|
-
"name": "language",
|
350
|
-
"shape": [1, 1],
|
351
|
-
"datatype": "BYTES",
|
352
|
-
"data": [request.language]
|
353
|
-
})
|
354
|
-
|
355
|
-
# Clean up temp file
|
356
|
-
os.unlink(temp_file_path)
|
357
|
-
|
358
|
-
# Send to Triton
|
359
|
-
response = requests.post(
|
360
|
-
f"{TRITON_URL}/v2/models/{model}/infer",
|
361
|
-
json=triton_request
|
362
|
-
)
|
363
|
-
response.raise_for_status()
|
364
|
-
triton_response = response.json()
|
365
|
-
|
366
|
-
# Extract text output
|
367
|
-
transcription = triton_response["outputs"][0]["data"][0]
|
368
|
-
|
369
|
-
return format_audio_response(transcription, model)
|
370
|
-
|
371
|
-
except Exception as e:
|
372
|
-
logger.error(f"Error transcribing audio: {str(e)}")
|
373
|
-
# Fallback response
|
374
|
-
return format_audio_response(
|
375
|
-
"This is a placeholder transcription. In production, this would be generated by the Whisper model.",
|
376
|
-
model
|
377
|
-
)
|
378
|
-
|
379
|
-
@app.post("/v1/embeddings")
|
380
|
-
async def create_embeddings(request: EmbeddingRequest):
|
381
|
-
"""Handle embedding requests."""
|
382
|
-
logger.info(f"Received embedding request: {request.dict()}")
|
383
|
-
|
384
|
-
# Use requested model or default
|
385
|
-
model = request.model if request.model != "default" else DEFAULT_EMBEDDING_MODEL
|
386
|
-
|
387
|
-
# Convert input to list if it's a single string
|
388
|
-
inputs = request.input if isinstance(request.input, list) else [request.input]
|
389
|
-
|
390
|
-
try:
|
391
|
-
# Process each input text
|
392
|
-
all_embeddings = []
|
393
|
-
|
394
|
-
for text in inputs:
|
395
|
-
# Prepare request for Triton
|
396
|
-
triton_request = {
|
397
|
-
"inputs": [
|
398
|
-
{
|
399
|
-
"name": "text_input",
|
400
|
-
"shape": [1, 1],
|
401
|
-
"datatype": "BYTES",
|
402
|
-
"data": [text]
|
403
|
-
}
|
404
|
-
]
|
405
|
-
}
|
406
|
-
|
407
|
-
# Send to Triton
|
408
|
-
response = requests.post(
|
409
|
-
f"{TRITON_URL}/v2/models/{model}/infer",
|
410
|
-
json=triton_request
|
411
|
-
)
|
412
|
-
response.raise_for_status()
|
413
|
-
triton_response = response.json()
|
414
|
-
|
415
|
-
# Extract embedding
|
416
|
-
embedding = triton_response["outputs"][0]["data"]
|
417
|
-
all_embeddings.append(embedding)
|
418
|
-
|
419
|
-
return format_embedding_response(all_embeddings, model)
|
420
|
-
|
421
|
-
except Exception as e:
|
422
|
-
logger.error(f"Error creating embeddings: {str(e)}")
|
423
|
-
|
424
|
-
# Fallback - return random embeddings
|
425
|
-
embeddings = []
|
426
|
-
for _ in inputs:
|
427
|
-
# Generate a random embedding vector of dimension 1024 (BGE-M3)
|
428
|
-
embedding = np.random.normal(0, 1, 1024).tolist()
|
429
|
-
embeddings.append(embedding)
|
430
|
-
|
431
|
-
return format_embedding_response(embeddings, model)
|
432
|
-
|
433
|
-
@app.get("/health")
|
434
|
-
async def health_check():
|
435
|
-
"""Health check endpoint."""
|
436
|
-
return {"status": "healthy"}
|
437
|
-
|
438
|
-
# ===== Main =====
|
439
|
-
|
440
|
-
if __name__ == "__main__":
|
441
|
-
# Create placeholder image for demo
|
442
|
-
try:
|
443
|
-
if not os.path.exists("placeholder.png"):
|
444
|
-
# Create a simple 256x256 black image
|
445
|
-
import numpy as np
|
446
|
-
from PIL import Image
|
447
|
-
img = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
|
448
|
-
img.save("placeholder.png")
|
449
|
-
except ImportError:
|
450
|
-
logger.warning("PIL not installed. Cannot create placeholder image.")
|
451
|
-
|
452
|
-
# Start server
|
453
|
-
uvicorn.run(app, host="0.0.0.0", port=8300)
|
@@ -1,188 +0,0 @@
|
|
1
|
-
import os
|
2
|
-
import logging
|
3
|
-
import torch
|
4
|
-
import numpy as np
|
5
|
-
from typing import Dict, List, Any, Optional, Union
|
6
|
-
|
7
|
-
logger = logging.getLogger(__name__)
|
8
|
-
|
9
|
-
|
10
|
-
class BgeEmbedBackend:
|
11
|
-
"""
|
12
|
-
PyTorch backend for the BGE embedding model.
|
13
|
-
"""
|
14
|
-
|
15
|
-
def __init__(self, model_path: Optional[str] = None, device: str = "auto"):
|
16
|
-
"""
|
17
|
-
Initialize the BGE embedding backend.
|
18
|
-
|
19
|
-
Args:
|
20
|
-
model_path: Path to the model
|
21
|
-
device: Device to run the model on ("cpu", "cuda", or "auto")
|
22
|
-
"""
|
23
|
-
self.model_path = model_path or os.environ.get("BGE_MODEL_PATH", "/models/Bge-m3")
|
24
|
-
self.device = device if device != "auto" else ("cuda" if torch.cuda.is_available() else "cpu")
|
25
|
-
self.model = None
|
26
|
-
self.tokenizer = None
|
27
|
-
self._loaded = False
|
28
|
-
|
29
|
-
# Default configuration
|
30
|
-
self.config = {
|
31
|
-
"normalize": True,
|
32
|
-
"max_length": 512,
|
33
|
-
"pooling_method": "cls" # Use CLS token for sentence embedding
|
34
|
-
}
|
35
|
-
|
36
|
-
self.logger = logger
|
37
|
-
|
38
|
-
def load(self) -> None:
|
39
|
-
"""
|
40
|
-
Load the model and tokenizer.
|
41
|
-
"""
|
42
|
-
if self._loaded:
|
43
|
-
return
|
44
|
-
|
45
|
-
try:
|
46
|
-
from transformers import AutoModel, AutoTokenizer
|
47
|
-
|
48
|
-
# Load tokenizer
|
49
|
-
self.logger.info(f"Loading BGE tokenizer from {self.model_path}")
|
50
|
-
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
|
51
|
-
|
52
|
-
# Load model
|
53
|
-
self.logger.info(f"Loading BGE model on {self.device}")
|
54
|
-
if self.device == "cpu":
|
55
|
-
self.model = AutoModel.from_pretrained(
|
56
|
-
self.model_path,
|
57
|
-
torch_dtype=torch.float32,
|
58
|
-
device_map="auto"
|
59
|
-
)
|
60
|
-
else: # cuda
|
61
|
-
self.model = AutoModel.from_pretrained(
|
62
|
-
self.model_path,
|
63
|
-
torch_dtype=torch.float16, # Use half precision on GPU
|
64
|
-
device_map="auto"
|
65
|
-
)
|
66
|
-
|
67
|
-
self.model.eval()
|
68
|
-
self._loaded = True
|
69
|
-
self.logger.info("BGE model loaded successfully")
|
70
|
-
|
71
|
-
except Exception as e:
|
72
|
-
self.logger.error(f"Failed to load BGE model: {str(e)}")
|
73
|
-
raise
|
74
|
-
|
75
|
-
def unload(self) -> None:
|
76
|
-
"""
|
77
|
-
Unload the model and tokenizer.
|
78
|
-
"""
|
79
|
-
if not self._loaded:
|
80
|
-
return
|
81
|
-
|
82
|
-
self.model = None
|
83
|
-
self.tokenizer = None
|
84
|
-
self._loaded = False
|
85
|
-
|
86
|
-
# Force garbage collection
|
87
|
-
import gc
|
88
|
-
gc.collect()
|
89
|
-
|
90
|
-
if self.device == "cuda":
|
91
|
-
torch.cuda.empty_cache()
|
92
|
-
|
93
|
-
self.logger.info("BGE model unloaded")
|
94
|
-
|
95
|
-
def embed(self,
|
96
|
-
texts: Union[str, List[str]],
|
97
|
-
normalize: Optional[bool] = None) -> np.ndarray:
|
98
|
-
"""
|
99
|
-
Generate embeddings for texts.
|
100
|
-
|
101
|
-
Args:
|
102
|
-
texts: Single text or list of texts to embed
|
103
|
-
normalize: Whether to normalize embeddings (if None, use default)
|
104
|
-
|
105
|
-
Returns:
|
106
|
-
Numpy array of embeddings, shape [batch_size, embedding_dim]
|
107
|
-
"""
|
108
|
-
if not self._loaded:
|
109
|
-
self.load()
|
110
|
-
|
111
|
-
# Handle single text input
|
112
|
-
if isinstance(texts, str):
|
113
|
-
texts = [texts]
|
114
|
-
|
115
|
-
# Use default normalize setting if not specified
|
116
|
-
if normalize is None:
|
117
|
-
normalize = self.config["normalize"]
|
118
|
-
|
119
|
-
try:
|
120
|
-
# Tokenize the texts
|
121
|
-
inputs = self.tokenizer(
|
122
|
-
texts,
|
123
|
-
padding=True,
|
124
|
-
truncation=True,
|
125
|
-
max_length=self.config["max_length"],
|
126
|
-
return_tensors="pt"
|
127
|
-
).to(self.device)
|
128
|
-
|
129
|
-
# Generate embeddings
|
130
|
-
with torch.no_grad():
|
131
|
-
outputs = self.model(**inputs)
|
132
|
-
|
133
|
-
# Use [CLS] token embedding as the sentence embedding
|
134
|
-
embeddings = outputs.last_hidden_state[:, 0, :]
|
135
|
-
|
136
|
-
# Normalize embeddings if required
|
137
|
-
if normalize:
|
138
|
-
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
139
|
-
|
140
|
-
# Convert to numpy array
|
141
|
-
embeddings_np = embeddings.cpu().numpy()
|
142
|
-
|
143
|
-
return embeddings_np
|
144
|
-
|
145
|
-
except Exception as e:
|
146
|
-
self.logger.error(f"Error during BGE embedding generation: {str(e)}")
|
147
|
-
raise
|
148
|
-
|
149
|
-
def get_model_info(self) -> Dict[str, Any]:
|
150
|
-
"""
|
151
|
-
Get information about the model.
|
152
|
-
|
153
|
-
Returns:
|
154
|
-
Dictionary containing model information
|
155
|
-
"""
|
156
|
-
return {
|
157
|
-
"name": "bge-m3",
|
158
|
-
"type": "embedding",
|
159
|
-
"device": self.device,
|
160
|
-
"path": self.model_path,
|
161
|
-
"loaded": self._loaded,
|
162
|
-
"embedding_dim": 1024, # Typical for BGE models
|
163
|
-
"config": self.config
|
164
|
-
}
|
165
|
-
|
166
|
-
def similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
|
167
|
-
"""
|
168
|
-
Calculate cosine similarity between two embeddings.
|
169
|
-
|
170
|
-
Args:
|
171
|
-
embedding1: First embedding vector
|
172
|
-
embedding2: Second embedding vector
|
173
|
-
|
174
|
-
Returns:
|
175
|
-
Cosine similarity score (float between -1 and 1)
|
176
|
-
"""
|
177
|
-
from sklearn.metrics.pairwise import cosine_similarity
|
178
|
-
|
179
|
-
# Reshape if needed
|
180
|
-
if embedding1.ndim == 1:
|
181
|
-
embedding1 = embedding1.reshape(1, -1)
|
182
|
-
if embedding2.ndim == 1:
|
183
|
-
embedding2 = embedding2.reshape(1, -1)
|
184
|
-
|
185
|
-
# Calculate cosine similarity
|
186
|
-
similarity = cosine_similarity(embedding1, embedding2)[0][0]
|
187
|
-
|
188
|
-
return float(similarity)
|