isa-model 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +354 -0
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model/inference/backends/__init__.py +53 -0
- isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model/inference/backends/container_services.py +104 -0
- isa_model/inference/backends/local_services.py +72 -0
- isa_model/inference/backends/openai_client.py +130 -0
- isa_model/inference/backends/replicate_client.py +197 -0
- isa_model/inference/backends/third_party_services.py +239 -0
- isa_model/inference/backends/triton_client.py +97 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model/inference/client_sdk/client.py +134 -0
- isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +87 -0
- isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/base_tts_service.py +66 -0
- isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model/inference/services/llm/__init__.py +16 -0
- isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0.dist-info/METADATA +116 -0
- isa_model-0.1.0.dist-info/RECORD +117 -0
- isa_model-0.1.0.dist-info/WHEEL +5 -0
- isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
- isa_model-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,453 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
"""
|
3
|
+
Multimodal OpenAI-Compatible Adapter for Triton Inference Server
|
4
|
+
|
5
|
+
This adapter translates between OpenAI API format and Triton Inference Server format.
|
6
|
+
It supports multiple modalities (text, image, voice) through a unified API.
|
7
|
+
|
8
|
+
Features:
|
9
|
+
- Chat completions API (text)
|
10
|
+
- Image generation API
|
11
|
+
- Audio transcription API
|
12
|
+
- Embeddings API
|
13
|
+
|
14
|
+
The adapter routes requests to the appropriate Triton model based on the task.
|
15
|
+
"""
|
16
|
+
|
17
|
+
import os
|
18
|
+
import json
|
19
|
+
import time
|
20
|
+
import base64
|
21
|
+
import logging
|
22
|
+
import requests
|
23
|
+
import tempfile
|
24
|
+
import uvicorn
|
25
|
+
import uuid
|
26
|
+
from typing import List, Dict, Any, Optional, Union
|
27
|
+
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Body, BackgroundTasks
|
28
|
+
from fastapi.responses import StreamingResponse, FileResponse
|
29
|
+
from fastapi.middleware.cors import CORSMiddleware
|
30
|
+
from pydantic import BaseModel, Field
|
31
|
+
import numpy as np
|
32
|
+
from datetime import datetime
|
33
|
+
|
34
|
+
# Configure logging for the adapter process.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Multimodal OpenAI-Compatible Adapter")

# Add CORS middleware.
# NOTE(review): every origin, method and header is allowed *together with*
# credentials. That is convenient for a local demo but far too permissive for
# a service reachable from browsers; restrict allow_origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Constants -- each one can be overridden through an environment variable.
# Base URL of the Triton HTTP server; routes POST to /v2/models/<name>/infer.
TRITON_URL = os.environ.get("TRITON_URL", "http://localhost:8000")
# Triton model names substituted when a request sends model="default".
DEFAULT_TEXT_MODEL = os.environ.get("DEFAULT_TEXT_MODEL", "llama3_cpu")
# Image generation currently returns a placeholder image rather than calling this model.
DEFAULT_IMAGE_MODEL = os.environ.get("DEFAULT_IMAGE_MODEL", "stable_diffusion")
DEFAULT_AUDIO_MODEL = os.environ.get("DEFAULT_AUDIO_MODEL", "whisper_tiny")
DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "bge_m3")
# Name of the vision-capable (text + image) model.
DEFAULT_VISION_MODEL = os.environ.get("DEFAULT_VISION_MODEL", "gemma3_4b")
|
57
|
+
|
58
|
+
# ===== Schema Definitions =====
|
59
|
+
|
60
|
+
# OpenAI-style chat message. `content` is either plain text or a list of
# content parts such as {"type": "text", "text": ...} or
# {"type": "image_url", "image_url": {"url": "data:image/...;base64,..."}}.
class ChatMessage(BaseModel):
    role: str  # e.g. "system", "user", "assistant"
    content: Union[str, List[Dict[str, Any]]]
    name: Optional[str] = None  # accepted for OpenAI compatibility; not read by this adapter


# Body of POST /v1/chat/completions.
# NOTE: only max_tokens and temperature are forwarded to Triton; top_p, stream
# and stop are accepted for compatibility but are not used.
class ChatCompletionRequest(BaseModel):
    model: str  # Triton model name, or "default" for DEFAULT_TEXT_MODEL
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = 100
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None


# Body of POST /v1/images/generations. Image generation currently returns a
# placeholder image, so n, size and response_format are not used.
class ImageGenerationRequest(BaseModel):
    model: str
    prompt: str
    n: Optional[int] = 1
    size: Optional[str] = "1024x1024"
    response_format: Optional[str] = "url"


# JSON body of POST /v1/audio/transcriptions: the audio arrives as base64 text
# rather than as a file upload.
class AudioTranscriptionRequest(BaseModel):
    model: str
    file: str  # Base64 encoded audio
    response_format: Optional[str] = "text"  # not used; the response is always {"text": ...}
    language: Optional[str] = "en"  # sent to Triton as a "language" input when non-empty


# Body of POST /v1/embeddings.
class EmbeddingRequest(BaseModel):
    model: str
    input: Union[str, List[str]]  # a single text or a batch of texts
    encoding_format: Optional[str] = "float"  # not used; vectors are returned as float lists
|
91
|
+
|
92
|
+
# ===== Helper Functions =====
|
93
|
+
|
94
|
+
def generate_response_id(prefix: str = "res") -> str:
    """Return a fresh identifier of the form ``"<prefix>-<uuid4>"``.

    Args:
        prefix: Leading tag, e.g. ``"chatcmpl"`` for chat completions.
    """
    unique_part = str(uuid.uuid4())
    return "-".join((prefix, unique_part))
|
97
|
+
|
98
|
+
def format_chat_response(content: str, model: str) -> Dict[str, Any]:
    """Wrap generated text in an OpenAI ``chat.completion`` payload.

    Args:
        content: Assistant text produced by the model.
        model: Model name echoed back to the client.

    Returns:
        A single-choice completion with ``finish_reason="stop"``. Token usage
        is not tracked, so every usage counter is reported as 0.
    """
    assistant_message = {"role": "assistant", "content": content}
    choice = {"index": 0, "message": assistant_message, "finish_reason": "stop"}
    usage = dict.fromkeys(("prompt_tokens", "completion_tokens", "total_tokens"), 0)

    return {
        "id": generate_response_id("chatcmpl"),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
        "usage": usage,
    }
|
121
|
+
|
122
|
+
def format_image_response(image_data: str, model: str) -> Dict[str, Any]:
    """Wrap a base64-encoded PNG in an OpenAI image-generation payload.

    The image is returned inline as a ``data:`` URL under the ``url`` key.
    ``model`` is accepted for symmetry with the other formatters but is not
    included in the payload.
    """
    data_url = f"data:image/png;base64,{image_data}"
    created_at = int(time.time())
    return {"created": created_at, "data": [{"url": data_url}]}
|
132
|
+
|
133
|
+
def format_audio_response(text: str, model: str) -> Dict[str, Any]:
    """Build an OpenAI transcription payload containing only the text.

    ``model`` is unused and kept for symmetry with the other formatters.
    """
    payload: Dict[str, Any] = dict(text=text)
    return payload
|
138
|
+
|
139
|
+
def format_embedding_response(embeddings: List[List[float]], model: str) -> Dict[str, Any]:
    """Wrap embedding vectors in an OpenAI ``list`` payload.

    Args:
        embeddings: One vector per input, in input order.
        model: Model name echoed back to the client.

    Returns:
        ``{"object": "list", "data": [...], "model": ..., "usage": ...}``. Each
        entry carries its position as ``index``; usage counters are always 0.
    """
    items = [
        {"object": "embedding", "embedding": vector, "index": position}
        for position, vector in enumerate(embeddings)
    ]
    return {
        "object": "list",
        "data": items,
        "model": model,
        "usage": {"prompt_tokens": 0, "total_tokens": 0},
    }
|
158
|
+
|
159
|
+
def extract_content_from_messages(messages: "List[ChatMessage]") -> Dict[str, Any]:
    """Flatten chat messages into a single text prompt plus an optional image.

    Each message becomes one ``"<Role>: <text>"`` line, and the prompt ends
    with ``"Assistant:"`` so the model continues as the assistant. For
    list-style (multimodal) content, text parts are joined with spaces. An
    ``image_url`` part whose URL is a base64 ``data:image/...`` URL supplies
    the image; if several images are present, the last one wins. Malformed
    data URLs (no comma, or an empty payload) are ignored instead of raising.

    Args:
        messages: Chat messages exposing ``role`` and ``content`` attributes.

    Returns:
        ``{"text": prompt, "image": base64_payload_or_None}``.
    """
    lines: List[str] = []
    image_data: Optional[str] = None

    for msg in messages:
        if isinstance(msg.content, str):
            text = msg.content
        else:
            text_parts = []
            for part in msg.content:
                part_type = part.get("type")
                if part_type == "text":
                    text_parts.append(part.get("text", ""))
                elif part_type == "image_url":
                    # OpenAI sends {"url": ...}; some clients send the URL string directly.
                    raw = part.get("image_url") or {}
                    url = raw if isinstance(raw, str) else (raw.get("url") or "")
                    if url.startswith("data:image/"):
                        # partition() never raises, unlike split(",")[1] on a comma-less URL.
                        _, sep, payload = url.partition(",")
                        if sep and payload:
                            image_data = payload
            text = " ".join(text_parts)
        lines.append(f"{msg.role.capitalize()}: {text}\n")

    lines.append("Assistant:")
    return {"text": "".join(lines), "image": image_data}
|
188
|
+
|
189
|
+
# ===== API Routes =====
|
190
|
+
|
191
|
+
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Handle OpenAI-style chat completion requests by forwarding them to Triton.

    The conversation is flattened into one prompt (``text_input``) and sent
    together with ``max_tokens`` and ``temperature``. When the request carries
    a base64 ``data:image/...`` URL and targets the vision model, the image is
    resized to 224x224, converted to RGB, scaled to [0, 1], and attached as a
    CHW FP32 ``image_input`` tensor. An unusable image is logged and skipped.

    ``top_p``, ``stop`` and ``stream`` are accepted for compatibility but are
    not forwarded, and the response is never streamed.

    The Triton timeout in seconds comes from ``TRITON_TIMEOUT`` (default 120).

    Raises:
        HTTPException: 500 when Triton is unreachable, times out, returns an
            error status, or returns an unexpected payload.
    """
    import asyncio
    import functools

    # Log only metadata: the raw body can hold large base64 images and user data.
    logger.info(
        "Received chat request: model=%s, messages=%d, max_tokens=%s, temperature=%s",
        request.model, len(request.messages), request.max_tokens, request.temperature,
    )

    content = extract_content_from_messages(request.messages)
    input_text = content["text"]
    image_data = content["image"]

    # "default" is an alias for the configured text model.
    model = request.model if request.model != "default" else DEFAULT_TEXT_MODEL

    triton_request = {
        "inputs": [
            {"name": "text_input", "shape": [1, 1], "datatype": "BYTES", "data": [input_text]},
            {"name": "max_tokens", "shape": [1, 1], "datatype": "INT32", "data": [request.max_tokens]},
            {"name": "temperature", "shape": [1, 1], "datatype": "FP32", "data": [request.temperature]},
        ]
    }

    # Only the vision model takes an image tensor. "gemma3_4b" stays accepted
    # explicitly so behavior is unchanged when DEFAULT_VISION_MODEL is overridden.
    if image_data is not None and model in ("gemma3_4b", DEFAULT_VISION_MODEL):
        try:
            from PIL import Image
            import io

            image_bytes = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_bytes))
            image = image.resize((224, 224))
            if image.mode != "RGB":
                image = image.convert("RGB")

            image_array = np.array(image).astype(np.float32) / 255.0
            image_array = np.transpose(image_array, (2, 0, 1))  # HWC -> CHW

            triton_request["inputs"].append({
                "name": "image_input",
                "shape": list(image_array.shape),
                "datatype": "FP32",
                "data": image_array.flatten().tolist(),
            })
            logger.info("Added image input to request")
        except Exception as e:
            # Best effort: continue with a text-only request.
            logger.error(f"Error processing image: {str(e)}")

    # Logging the full payload would dump ~150k floats per image; names suffice.
    logger.info(
        "Sending to Triton model %s with inputs %s",
        model, [item["name"] for item in triton_request["inputs"]],
    )

    try:
        timeout = float(os.environ.get("TRITON_TIMEOUT", "120"))
        # requests is synchronous; run it in the default thread pool so a slow
        # model does not stall every other request on the event loop.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            functools.partial(
                requests.post,
                f"{TRITON_URL}/v2/models/{model}/infer",
                json=triton_request,
                timeout=timeout,
            ),
        )
        response.raise_for_status()
        triton_response = response.json()
        logger.info(f"Triton response status: {response.status_code}")
        logger.debug("Triton response: %s", triton_response)

        output_data = triton_response["outputs"][0]["data"][0]
        return format_chat_response(output_data, model)
    except Exception as e:
        logger.exception(f"Error calling Triton: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error calling model: {str(e)}") from e
|
287
|
+
|
288
|
+
@app.post("/v1/images/generations")
async def generate_images(request: ImageGenerationRequest):
    """Handle image generation requests (placeholder implementation).

    No Triton image model is called yet. Every request returns the same
    placeholder PNG as a base64 data URL. ``placeholder.png`` in the working
    directory is used when it exists. Otherwise a 256x256 black PNG is built
    in memory, so the endpoint also works when the app is not started through
    this module's ``__main__`` block. ``n``, ``size`` and ``response_format``
    are accepted but ignored.

    Raises:
        HTTPException: 500 if the placeholder image cannot be produced.
    """
    import struct
    import zlib

    def _black_png(width: int = 256, height: int = 256) -> bytes:
        """Encode an all-black RGB image as a minimal PNG using only the stdlib."""

        def _chunk(tag: bytes, payload: bytes) -> bytes:
            # PNG chunk layout: length, type, data, CRC-32 over type + data.
            crc = zlib.crc32(tag + payload) & 0xFFFFFFFF
            return struct.pack(">I", len(payload)) + tag + payload + struct.pack(">I", crc)

        # 8-bit depth, colour type 2 (RGB), default compression/filter, no interlace.
        header = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0)
        # Each scanline is filter byte 0 followed by width*3 zero bytes.
        scanlines = (b"\x00" * (1 + width * 3)) * height
        return (
            b"\x89PNG\r\n\x1a\n"
            + _chunk(b"IHDR", header)
            + _chunk(b"IDAT", zlib.compress(scanlines))
            + _chunk(b"IEND", b"")
        )

    logger.info(f"Received image generation request: {request.dict()}")

    # Use requested model or default
    model = request.model if request.model != "default" else DEFAULT_IMAGE_MODEL

    # TODO: call the Triton image model once one is deployed.
    try:
        try:
            with open("placeholder.png", "rb") as f:
                png_bytes = f.read()
        except FileNotFoundError:
            # The file is only written by the __main__ block, so it is missing
            # when the app is served by an external ASGI server.
            png_bytes = _black_png()

        image_data = base64.b64encode(png_bytes).decode("utf-8")
        return format_image_response(image_data, model)
    except Exception as e:
        logger.error(f"Error generating image: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}") from e
|
309
|
+
|
310
|
+
@app.post("/v1/audio/transcriptions")
async def transcribe_audio(request: AudioTranscriptionRequest):
    """Transcribe base64-encoded audio with the speech model on Triton.

    The audio is written to a temporary file, then loaded with librosa as
    mono 16 kHz samples. It is sent as an FP32 ``audio_input`` tensor, plus a
    ``language`` input when one is given. The Triton timeout in seconds comes
    from ``TRITON_TIMEOUT`` (default 120).

    On any failure the error is logged and a fixed placeholder transcription
    is returned with HTTP 200.
    """
    import asyncio

    # Log metadata only: the base64 audio payload can be many megabytes.
    logger.info(
        "Received audio transcription request: model=%s, language=%s, payload_chars=%d",
        request.model, request.language, len(request.file),
    )

    # Use requested model or default
    model = request.model if request.model != "default" else DEFAULT_AUDIO_MODEL

    def _transcribe_sync() -> str:
        """Decode, resample and send the audio to Triton (blocking)."""
        audio_data = base64.b64decode(request.file)

        import librosa

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            temp_file.write(audio_data)
            temp_file_path = temp_file.name
        try:
            # Resample to 16 kHz mono for Whisper.
            audio_array, _ = librosa.load(temp_file_path, sr=16000, mono=True)
        finally:
            # Remove the temp file even when decoding fails.
            os.unlink(temp_file_path)

        triton_request = {
            "inputs": [
                {
                    "name": "audio_input",
                    "shape": [len(audio_array)],
                    "datatype": "FP32",
                    "data": audio_array.tolist(),
                }
            ]
        }
        if request.language:
            triton_request["inputs"].append({
                "name": "language",
                "shape": [1, 1],
                "datatype": "BYTES",
                "data": [request.language],
            })

        response = requests.post(
            f"{TRITON_URL}/v2/models/{model}/infer",
            json=triton_request,
            timeout=float(os.environ.get("TRITON_TIMEOUT", "120")),
        )
        response.raise_for_status()
        return response.json()["outputs"][0]["data"][0]

    try:
        # Audio decoding and the HTTP call both block; keep them off the event loop.
        transcription = await asyncio.get_running_loop().run_in_executor(None, _transcribe_sync)
        return format_audio_response(transcription, model)
    except Exception as e:
        logger.exception(f"Error transcribing audio: {str(e)}")
        # Fallback response.
        # NOTE(review): a fabricated transcript with HTTP 200 hides failures from
        # clients; consider raising HTTPException(500) outside of demos.
        return format_audio_response(
            "This is a placeholder transcription. In production, this would be generated by the Whisper model.",
            model,
        )
|
378
|
+
|
379
|
+
@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingRequest):
    """Embed one text or a batch of texts with a Triton embedding model.

    Each text is sent as its own Triton request (``text_input``, BYTES). The
    first output tensor is returned as that text's vector, in input order.
    The Triton timeout in seconds comes from ``TRITON_TIMEOUT`` (default 120).

    On any failure the error is logged and random N(0, 1) vectors of length
    1024 are returned instead, with HTTP 200.
    """
    import asyncio

    # Use requested model or default
    model = request.model if request.model != "default" else DEFAULT_EMBEDDING_MODEL
    # Convert input to list if it's a single string
    inputs = request.input if isinstance(request.input, list) else [request.input]

    # Log metadata only; the texts themselves may be large or sensitive.
    logger.info("Received embedding request: model=%s, inputs=%d", request.model, len(inputs))

    def _embed_sync() -> List[List[float]]:
        """Call Triton once per text (blocking) and collect the vectors."""
        timeout = float(os.environ.get("TRITON_TIMEOUT", "120"))
        url = f"{TRITON_URL}/v2/models/{model}/infer"
        vectors = []
        # A session reuses the HTTP connection across the per-text requests.
        with requests.Session() as session:
            for text in inputs:
                triton_request = {
                    "inputs": [
                        {"name": "text_input", "shape": [1, 1], "datatype": "BYTES", "data": [text]}
                    ]
                }
                response = session.post(url, json=triton_request, timeout=timeout)
                response.raise_for_status()
                vectors.append(response.json()["outputs"][0]["data"])
        return vectors

    try:
        # requests is synchronous; keep it off the event loop.
        all_embeddings = await asyncio.get_running_loop().run_in_executor(None, _embed_sync)
        return format_embedding_response(all_embeddings, model)
    except Exception as e:
        logger.exception(f"Error creating embeddings: {str(e)}")
        # Fallback - return random embeddings (dimension 1024, as for BGE-M3).
        # NOTE(review): random vectors silently pollute any downstream vector
        # store; consider raising HTTPException(500) outside of demos.
        logger.warning("Returning %d random placeholder embedding(s) for model %s", len(inputs), model)
        embeddings = [np.random.normal(0, 1, 1024).tolist() for _ in inputs]
        return format_embedding_response(embeddings, model)
|
432
|
+
|
433
|
+
@app.get("/health")
async def health_check():
    """Liveness probe for the adapter process itself (Triton is not contacted)."""
    status = {"status": "healthy"}
    return status
|
437
|
+
|
438
|
+
# ===== Main =====
|
439
|
+
|
440
|
+
if __name__ == "__main__":
    # Demo convenience: make sure the placeholder image read by
    # /v1/images/generations exists before the server starts.
    placeholder_path = "placeholder.png"
    try:
        if not os.path.exists(placeholder_path):
            import numpy as np
            from PIL import Image

            black_pixels = np.zeros((256, 256, 3), dtype=np.uint8)  # 256x256 RGB, all zeros
            Image.fromarray(black_pixels).save(placeholder_path)
    except ImportError:
        logger.warning("PIL not installed. Cannot create placeholder image.")

    # Serve on all interfaces, port 8300.
    uvicorn.run(app, host="0.0.0.0", port=8300)
|
@@ -0,0 +1,248 @@
|
|
1
|
+
import os
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
from typing import Dict, List, Any, Optional, Union
|
5
|
+
from fastapi import FastAPI, HTTPException, Depends, Request
|
6
|
+
from pydantic import BaseModel, Field
|
7
|
+
|
8
|
+
from isa_model.inference.ai_factory import AIFactory
|
9
|
+
|
10
|
+
# Configure logging
logging.basicConfig(level=logging.INFO)
# Named logger used by every endpoint in this module.
logger = logging.getLogger("unified_api")

# Create FastAPI app. The description names the backing models; the model IDs
# that clients actually send are "llama", "gemma", "whisper" and "bge_embed".
app = FastAPI(
    title="Unified AI Model API",
    description="API for inference with Llama3-8B, Gemma3-4B, Whisper, and BGE-M3 models",
    version="1.0.0"
)
|
20
|
+
|
21
|
+
# Models
|
22
|
+
# A single chat turn. Content is plain text only (no multimodal parts).
class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message sender (system, user, assistant)")
    content: str = Field(..., description="Content of the message")


# Request body of POST /v1/chat/completions. The sampling fields are passed to
# the LLM service as its generation_config (max_tokens as "max_new_tokens").
class ChatCompletionRequest(BaseModel):
    model: str = Field(..., description="Model ID to use (llama, gemma)")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    temperature: Optional[float] = Field(0.7, description="Sampling temperature")
    max_tokens: Optional[int] = Field(512, description="Maximum number of tokens to generate")
    top_p: Optional[float] = Field(0.9, description="Top-p sampling parameter")
    top_k: Optional[int] = Field(50, description="Top-k sampling parameter")


# Response schema declared as response_model for POST /v1/chat/completions.
class ChatCompletionResponse(BaseModel):
    model: str = Field(..., description="Model used for completion")
    choices: List[Dict[str, Any]] = Field(..., description="Generated completions")
    usage: Dict[str, int] = Field(..., description="Token usage statistics")


# Request body of POST /v1/embeddings.
class EmbeddingRequest(BaseModel):
    # NOTE(review): the handler always uses the "bge_embed" service; this value is only echoed back.
    model: str = Field(..., description="Model ID to use (bge_embed)")
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    normalize: Optional[bool] = Field(True, description="Whether to normalize embeddings")  # forwarded to service.embed


# Request body of POST /v1/audio/transcriptions.
class TranscriptionRequest(BaseModel):
    # NOTE(review): the handler always uses the "whisper" service; this value is only echoed back.
    model: str = Field(..., description="Model ID to use (whisper)")
    # http:// and https:// values are downloaded by the server; anything else is base64-decoded.
    audio: str = Field(..., description="Base64-encoded audio data or URL")
    language: Optional[str] = Field("en", description="Language code")
|
48
|
+
|
49
|
+
# Single factory instance shared by every request handler in this module.
ai_factory = AIFactory()


async def get_llm_service(model: str):
    """Resolve the LLM service for ``model`` ("llama" or "gemma").

    Raises:
        HTTPException: 400 for any other model ID.
    """
    if model not in ("llama", "gemma"):
        raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
    return await ai_factory.get_llm_service(model)


async def get_embedding_service(model: str):
    """Resolve the embedding service; only "bge_embed" is supported.

    Raises:
        HTTPException: 400 for any other model ID.
    """
    if model != "bge_embed":
        raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
    return await ai_factory.get_embedding_service("bge_embed")


async def get_speech_service(model: str):
    """Resolve the speech-to-text service; only "whisper" is supported.

    Raises:
        HTTPException: 400 for any other model ID.
    """
    if model != "whisper":
        raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
    return await ai_factory.get_speech_service("whisper")
|
74
|
+
|
75
|
+
# Endpoints
|
76
|
+
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
    """Generate a chat completion with the requested local LLM.

    Only a leading ``system`` message (used as the system prompt) and the most
    recent ``user`` message reach the model; other turns are dropped. Token
    usage is approximated by whitespace word counts.

    Raises:
        HTTPException: 400 for an unsupported model or when no user message is
            present; 500 for any failure inside the model service.
    """
    try:
        # Raises HTTPException(400) for unknown model IDs.
        service = await get_llm_service(request.model)

        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # Extract system prompt if present
        system_prompt = None
        if messages and messages[0]["role"] == "system":
            system_prompt = messages[0]["content"]
            messages = messages[1:]

        # Most recent user message, or "" if there is none.
        user_prompt = next(
            (msg["content"] for msg in reversed(messages) if msg["role"] == "user"), ""
        )
        if not user_prompt:
            raise HTTPException(status_code=400, detail="No user message found")

        generation_config = {
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "top_p": request.top_p,
            "top_k": request.top_k,
        }

        completion = await service.generate(
            prompt=user_prompt,
            system_prompt=system_prompt,
            generation_config=generation_config,
        )

        prompt_tokens = len(user_prompt.split())
        completion_tokens = len(completion.split())
        return {
            "model": request.model,
            "choices": [
                {
                    "message": {"role": "assistant", "content": completion},
                    "finish_reason": "stop",
                    "index": 0,
                }
            ],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
        }

    except HTTPException:
        # Keep deliberate client errors (400) instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
142
|
+
|
143
|
+
@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest):
    """Generate embeddings for one text or a list of texts.

    NOTE(review): the "bge_embed" service is always used; ``request.model`` is
    only echoed back in the response.

    Usage counts are whitespace word counts, not tokenizer tokens.

    Raises:
        HTTPException: 500 if the embedding service fails; HTTPExceptions
            raised along the way keep their own status code.
    """
    try:
        service = await get_embedding_service("bge_embed")

        single_text = isinstance(request.input, str)
        texts = [request.input] if single_text else request.input

        embeddings = await service.embed(request.input, normalize=request.normalize)
        if single_text:
            # Presumably the service returns a batch even for one string; keep its first row.
            data = [{"embedding": embeddings[0].tolist(), "index": 0}]
        else:
            data = [{"embedding": emb.tolist(), "index": i} for i, emb in enumerate(embeddings)]

        word_count = sum(len(text.split()) for text in texts)
        return {
            "model": request.model,
            "data": data,
            "usage": {"prompt_tokens": word_count, "total_tokens": word_count},
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in embedding generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
173
|
+
|
174
|
+
@app.post("/v1/audio/transcriptions")
async def transcribe_audio(request: TranscriptionRequest):
    """Transcribe audio supplied as base64 data or as an http(s) URL.

    NOTE(review): the "whisper" service is always used; ``request.model`` is
    only echoed back in the response.

    Raises:
        HTTPException: 400 for invalid base64 or an audio URL that cannot be
            fetched; 500 if transcription itself fails.
    """
    import asyncio
    import base64
    import functools

    try:
        service = await get_speech_service("whisper")

        if request.audio.startswith(("http://", "https://")):
            import requests

            # SECURITY NOTE(review): the server fetches an arbitrary
            # client-supplied URL (SSRF risk); restrict allowed hosts before
            # exposing this endpoint publicly.
            try:
                # requests is blocking; run it in a worker thread, bounded by a timeout.
                download = await asyncio.get_running_loop().run_in_executor(
                    None, functools.partial(requests.get, request.audio, timeout=30)
                )
                # Do not pass an error page to the model as if it were audio.
                download.raise_for_status()
            except requests.RequestException as e:
                raise HTTPException(status_code=400, detail=f"Could not fetch audio URL: {e}") from e
            audio_data = download.content
        else:
            try:
                audio_data = base64.b64decode(request.audio)
            except ValueError as e:  # binascii.Error is a ValueError subclass
                raise HTTPException(status_code=400, detail=f"Invalid base64 audio: {e}") from e

        transcription = await service.transcribe(audio=audio_data, language=request.language)
        return {"model": request.model, "text": transcription}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in audio transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
209
|
+
|
210
|
+
# Health check endpoint
|
211
|
+
@app.get("/health")
async def health_check():
    """Liveness probe; reports that the API process is up without touching any model."""
    status = {"status": "healthy"}
    return status
|
215
|
+
|
216
|
+
# Model info endpoint
|
217
|
+
@app.get("/v1/models")
async def list_models():
    """List the model IDs this API accepts, with their type and a description."""
    catalog = (
        ("llama", "llm", "Llama3-8B language model"),
        ("gemma", "llm", "Gemma3-4B language model"),
        ("whisper", "speech", "Whisper-tiny speech-to-text model"),
        ("bge_embed", "embedding", "BGE-M3 text embedding model"),
    )
    models = [
        {"id": model_id, "type": model_type, "description": description}
        for model_id, model_type, description in catalog
    ]
    return {"data": models}
|
244
|
+
|
245
|
+
# Main entry point
|
246
|
+
if __name__ == "__main__":
    # Run the API directly with uvicorn on all interfaces, port 8080.
    import uvicorn

    uvicorn.run(app, port=8080, host="0.0.0.0")
|