isa-model 0.1.0 (isa_model-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/mlflow_gateway/__init__.py +8 -0
  12. isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
  13. isa_model/deployment/unified_multimodal_client.py +341 -0
  14. isa_model/inference/__init__.py +11 -0
  15. isa_model/inference/adapter/triton_adapter.py +453 -0
  16. isa_model/inference/adapter/unified_api.py +248 -0
  17. isa_model/inference/ai_factory.py +354 -0
  18. isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
  19. isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
  20. isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
  21. isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
  22. isa_model/inference/backends/__init__.py +53 -0
  23. isa_model/inference/backends/base_backend_client.py +26 -0
  24. isa_model/inference/backends/container_services.py +104 -0
  25. isa_model/inference/backends/local_services.py +72 -0
  26. isa_model/inference/backends/openai_client.py +130 -0
  27. isa_model/inference/backends/replicate_client.py +197 -0
  28. isa_model/inference/backends/third_party_services.py +239 -0
  29. isa_model/inference/backends/triton_client.py +97 -0
  30. isa_model/inference/base.py +46 -0
  31. isa_model/inference/client_sdk/__init__.py +0 -0
  32. isa_model/inference/client_sdk/client.py +134 -0
  33. isa_model/inference/client_sdk/client_data_std.py +34 -0
  34. isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
  35. isa_model/inference/client_sdk/exceptions.py +0 -0
  36. isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
  37. isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
  38. isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
  39. isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
  40. isa_model/inference/providers/__init__.py +19 -0
  41. isa_model/inference/providers/base_provider.py +30 -0
  42. isa_model/inference/providers/model_cache_manager.py +341 -0
  43. isa_model/inference/providers/ollama_provider.py +73 -0
  44. isa_model/inference/providers/openai_provider.py +87 -0
  45. isa_model/inference/providers/replicate_provider.py +94 -0
  46. isa_model/inference/providers/triton_provider.py +439 -0
  47. isa_model/inference/providers/vllm_provider.py +0 -0
  48. isa_model/inference/providers/yyds_provider.py +83 -0
  49. isa_model/inference/services/__init__.py +14 -0
  50. isa_model/inference/services/audio/fish_speech/handler.py +215 -0
  51. isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
  52. isa_model/inference/services/audio/triton_speech_service.py +138 -0
  53. isa_model/inference/services/audio/whisper_service.py +186 -0
  54. isa_model/inference/services/audio/yyds_audio_service.py +71 -0
  55. isa_model/inference/services/base_service.py +106 -0
  56. isa_model/inference/services/base_tts_service.py +66 -0
  57. isa_model/inference/services/embedding/bge_service.py +183 -0
  58. isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
  59. isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
  60. isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
  61. isa_model/inference/services/llm/__init__.py +16 -0
  62. isa_model/inference/services/llm/gemma_service.py +143 -0
  63. isa_model/inference/services/llm/llama_service.py +143 -0
  64. isa_model/inference/services/llm/ollama_llm_service.py +108 -0
  65. isa_model/inference/services/llm/openai_llm_service.py +129 -0
  66. isa_model/inference/services/llm/replicate_llm_service.py +179 -0
  67. isa_model/inference/services/llm/triton_llm_service.py +230 -0
  68. isa_model/inference/services/others/table_transformer_service.py +61 -0
  69. isa_model/inference/services/vision/__init__.py +12 -0
  70. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  71. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  72. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  73. isa_model/inference/services/vision/replicate_vision_service.py +241 -0
  74. isa_model/inference/services/vision/triton_vision_service.py +199 -0
  75. isa_model/inference/services/vision/yyds_vision_service.py +80 -0
  76. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  77. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  78. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  79. isa_model/scripts/inference_tracker.py +283 -0
  80. isa_model/scripts/mlflow_manager.py +379 -0
  81. isa_model/scripts/model_registry.py +465 -0
  82. isa_model/scripts/start_mlflow.py +95 -0
  83. isa_model/scripts/training_tracker.py +257 -0
  84. isa_model/training/engine/llama_factory/__init__.py +39 -0
  85. isa_model/training/engine/llama_factory/config.py +115 -0
  86. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  87. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  88. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  89. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  90. isa_model/training/engine/llama_factory/factory.py +331 -0
  91. isa_model/training/engine/llama_factory/rl.py +254 -0
  92. isa_model/training/engine/llama_factory/trainer.py +171 -0
  93. isa_model/training/image_model/configs/create_config.py +37 -0
  94. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  95. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  96. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  97. isa_model/training/image_model/prepare_upload.py +17 -0
  98. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  99. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  100. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  101. isa_model/training/image_model/train/train.py +42 -0
  102. isa_model/training/image_model/train/train_flux.py +41 -0
  103. isa_model/training/image_model/train/train_lora.py +57 -0
  104. isa_model/training/image_model/train_main.py +25 -0
  105. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  106. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  107. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  108. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  109. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  110. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  111. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  112. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  113. isa_model-0.1.0.dist-info/METADATA +116 -0
  114. isa_model-0.1.0.dist-info/RECORD +117 -0
  115. isa_model-0.1.0.dist-info/WHEEL +5 -0
  116. isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
  117. isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/inference/adapter/triton_adapter.py
@@ -0,0 +1,453 @@
+ #!/usr/bin/env python3
+ """
+ Multimodal OpenAI-Compatible Adapter for Triton Inference Server
+
+ This adapter translates between OpenAI API format and Triton Inference Server format.
+ It supports multiple modalities (text, image, voice) through a unified API.
+
+ Features:
+ - Chat completions API (text)
+ - Image generation API
+ - Audio transcription API
+ - Embeddings API
+
+ The adapter routes requests to the appropriate Triton model based on the task.
+ """
+
+ import os
+ import json
+ import time
+ import base64
+ import logging
+ import requests
+ import tempfile
+ import uvicorn
+ import uuid
+ from typing import List, Dict, Any, Optional, Union
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Body, BackgroundTasks
+ from fastapi.responses import StreamingResponse, FileResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, Field
+ import numpy as np
+ from datetime import datetime
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Initialize FastAPI app
+ app = FastAPI(title="Multimodal OpenAI-Compatible Adapter")
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Constants
+ TRITON_URL = os.environ.get("TRITON_URL", "http://localhost:8000")
+ DEFAULT_TEXT_MODEL = os.environ.get("DEFAULT_TEXT_MODEL", "llama3_cpu")
+ DEFAULT_IMAGE_MODEL = os.environ.get("DEFAULT_IMAGE_MODEL", "stable_diffusion")
+ DEFAULT_AUDIO_MODEL = os.environ.get("DEFAULT_AUDIO_MODEL", "whisper_tiny")
+ DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "bge_m3")
+ DEFAULT_VISION_MODEL = os.environ.get("DEFAULT_VISION_MODEL", "gemma3_4b")
+
+ # ===== Schema Definitions =====
+
+ class ChatMessage(BaseModel):
+     role: str
+     content: Union[str, List[Dict[str, Any]]]
+     name: Optional[str] = None
+
+ class ChatCompletionRequest(BaseModel):
+     model: str
+     messages: List[ChatMessage]
+     temperature: Optional[float] = 0.7
+     top_p: Optional[float] = 1.0
+     max_tokens: Optional[int] = 100
+     stream: Optional[bool] = False
+     stop: Optional[Union[str, List[str]]] = None
+
+ class ImageGenerationRequest(BaseModel):
+     model: str
+     prompt: str
+     n: Optional[int] = 1
+     size: Optional[str] = "1024x1024"
+     response_format: Optional[str] = "url"
+
+ class AudioTranscriptionRequest(BaseModel):
+     model: str
+     file: str # Base64 encoded audio
+     response_format: Optional[str] = "text"
+     language: Optional[str] = "en"
+
+ class EmbeddingRequest(BaseModel):
+     model: str
+     input: Union[str, List[str]]
+     encoding_format: Optional[str] = "float"
+
+ # ===== Helper Functions =====
+
+ def generate_response_id(prefix: str = "res") -> str:
+     """Generate a unique response ID."""
+     return f"{prefix}-{uuid.uuid4()}"
+
+ def format_chat_response(content: str, model: str) -> Dict[str, Any]:
+     """Format chat completion response in OpenAI format."""
+     return {
+         "id": generate_response_id("chatcmpl"),
+         "object": "chat.completion",
+         "created": int(time.time()),
+         "model": model,
+         "choices": [
+             {
+                 "index": 0,
+                 "message": {
+                     "role": "assistant",
+                     "content": content
+                 },
+                 "finish_reason": "stop"
+             }
+         ],
+         "usage": {
+             "prompt_tokens": 0, # We don't track these yet
+             "completion_tokens": 0,
+             "total_tokens": 0
+         }
+     }
+
+ def format_image_response(image_data: str, model: str) -> Dict[str, Any]:
+     """Format image generation response in OpenAI format."""
+     return {
+         "created": int(time.time()),
+         "data": [
+             {
+                 "url": f"data:image/png;base64,{image_data}"
+             }
+         ]
+     }
+
+ def format_audio_response(text: str, model: str) -> Dict[str, Any]:
+     """Format audio transcription response in OpenAI format."""
+     return {
+         "text": text
+     }
+
+ def format_embedding_response(embeddings: List[List[float]], model: str) -> Dict[str, Any]:
+     """Format embedding response in OpenAI format."""
+     data = []
+     for i, embedding in enumerate(embeddings):
+         data.append({
+             "object": "embedding",
+             "embedding": embedding,
+             "index": i
+         })
+
+     return {
+         "object": "list",
+         "data": data,
+         "model": model,
+         "usage": {
+             "prompt_tokens": 0,
+             "total_tokens": 0
+         }
+     }
+
+ def extract_content_from_messages(messages: List[ChatMessage]) -> Dict[str, Any]:
+     """Extract content from messages for Triton input."""
+     formatted_content = ""
+     image_data = None
+
+     for msg in messages:
+         # Handle both string content and list of content parts
+         if isinstance(msg.content, str):
+             content = msg.content
+             formatted_content += f"{msg.role.capitalize()}: {content}\n"
+         else:
+             # For multimodal content, extract text and image parts
+             text_parts = []
+             for part in msg.content:
+                 if part.get("type") == "text":
+                     text_parts.append(part.get("text", ""))
+                 elif part.get("type") == "image_url":
+                     # Extract image from URL (assuming base64 encoded)
+                     image_url = part.get("image_url", {}).get("url", "")
+                     if image_url.startswith("data:image/"):
+                         # Extract the base64 part
+                         image_data = image_url.split(",")[1]
+
+             # Add text parts to formatted content
+             content = " ".join(text_parts)
+             formatted_content += f"{msg.role.capitalize()}: {content}\n"
+
+     formatted_content += "Assistant:"
+     return {"text": formatted_content, "image": image_data}
+
+ # ===== API Routes =====
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(request: ChatCompletionRequest):
+     """Handle chat completion requests."""
+     logger.info(f"Received request: {request.dict()}")
+
+     # Extract the formatted content from messages
+     content = extract_content_from_messages(request.messages)
+     input_text = content["text"]
+     image_data = content["image"]
+
+     # Use requested model or default
+     model = request.model if request.model != "default" else DEFAULT_TEXT_MODEL
+
+     # Prepare request for Triton
+     triton_request = {
+         "inputs": [
+             {
+                 "name": "text_input",
+                 "shape": [1, 1],
+                 "datatype": "BYTES",
+                 "data": [input_text]
+             },
+             {
+                 "name": "max_tokens",
+                 "shape": [1, 1],
+                 "datatype": "INT32",
+                 "data": [request.max_tokens]
+             },
+             {
+                 "name": "temperature",
+                 "shape": [1, 1],
+                 "datatype": "FP32",
+                 "data": [request.temperature]
+             }
+         ]
+     }
+
+     # Add image input if available and using vision model
+     if image_data is not None and model == "gemma3_4b":
+         try:
+             # Decode base64 image
+             from PIL import Image
+             import io
+             import numpy as np
+
+             # Decode and preprocess image
+             image_bytes = base64.b64decode(image_data)
+             image = Image.open(io.BytesIO(image_bytes))
+
+             # Resize to expected size (224x224 for most vision models)
+             image = image.resize((224, 224))
+
+             # Convert to RGB if not already
+             if image.mode != "RGB":
+                 image = image.convert("RGB")
+
+             # Convert to numpy array and normalize
+             image_array = np.array(image).astype(np.float32) / 255.0
+
+             # Reorder from HWC to CHW format
+             image_array = np.transpose(image_array, (2, 0, 1))
+
+             # Add image input to Triton request
+             triton_request["inputs"].append({
+                 "name": "image_input",
+                 "shape": list(image_array.shape),
+                 "datatype": "FP32",
+                 "data": image_array.flatten().tolist()
+             })
+
+             logger.info("Added image input to request")
+         except Exception as e:
+             logger.error(f"Error processing image: {str(e)}")
+
+     logger.info(f"Sending to Triton: {triton_request}")
+
+     # Send to Triton
+     try:
+         response = requests.post(
+             f"{TRITON_URL}/v2/models/{model}/infer",
+             json=triton_request
+         )
+         response.raise_for_status()
+         triton_response = response.json()
+         logger.info(f"Triton response status: {response.status_code}")
+         logger.info(f"Triton response: {triton_response}")
+
+         # Extract text output
+         output_data = triton_response["outputs"][0]["data"][0]
+
+         # Format response
+         return format_chat_response(output_data, model)
+
+     except Exception as e:
+         logger.error(f"Error calling Triton: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Error calling model: {str(e)}")
+
+ @app.post("/v1/images/generations")
+ async def generate_images(request: ImageGenerationRequest):
+     """Handle image generation requests."""
+     logger.info(f"Received image generation request: {request.dict()}")
+
+     # Use requested model or default
+     model = request.model if request.model != "default" else DEFAULT_IMAGE_MODEL
+
+     # For demo purposes - in a real implementation, this would call the Triton image model
+     # Here we'll just simulate image generation with a placeholder
+     try:
+         # Simulate Triton call (replace with actual call to Triton when image model is available)
+         # Return a placeholder image for demonstration
+         with open("placeholder.png", "rb") as f:
+             image_data = base64.b64encode(f.read()).decode("utf-8")
+
+         return format_image_response(image_data, model)
+
+     except Exception as e:
+         logger.error(f"Error generating image: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")
+
+ @app.post("/v1/audio/transcriptions")
+ async def transcribe_audio(request: AudioTranscriptionRequest):
+     """Handle audio transcription requests."""
+     logger.info(f"Received audio transcription request: {request.dict()}")
+
+     # Use requested model or default
+     model = request.model if request.model != "default" else DEFAULT_AUDIO_MODEL
+
+     try:
+         # Decode the base64 audio
+         audio_data = base64.b64decode(request.file)
+
+         # Save to temporary file
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
+             temp_file.write(audio_data)
+             temp_file_path = temp_file.name
+
+         # Load and preprocess audio for Whisper
+         import librosa
+         import numpy as np
+
+         # Load audio file and resample to 16kHz for Whisper
+         audio_array, _ = librosa.load(temp_file_path, sr=16000, mono=True)
+
+         # Prepare request for Triton
+         triton_request = {
+             "inputs": [
+                 {
+                     "name": "audio_input",
+                     "shape": [len(audio_array)],
+                     "datatype": "FP32",
+                     "data": audio_array.tolist()
+                 }
+             ]
+         }
+
+         # Add language if provided
+         if hasattr(request, 'language') and request.language:
+             triton_request["inputs"].append({
+                 "name": "language",
+                 "shape": [1, 1],
+                 "datatype": "BYTES",
+                 "data": [request.language]
+             })
+
+         # Clean up temp file
+         os.unlink(temp_file_path)
+
+         # Send to Triton
+         response = requests.post(
+             f"{TRITON_URL}/v2/models/{model}/infer",
+             json=triton_request
+         )
+         response.raise_for_status()
+         triton_response = response.json()
+
+         # Extract text output
+         transcription = triton_response["outputs"][0]["data"][0]
+
+         return format_audio_response(transcription, model)
+
+     except Exception as e:
+         logger.error(f"Error transcribing audio: {str(e)}")
+         # Fallback response
+         return format_audio_response(
+             "This is a placeholder transcription. In production, this would be generated by the Whisper model.",
+             model
+         )
+
+ @app.post("/v1/embeddings")
+ async def create_embeddings(request: EmbeddingRequest):
+     """Handle embedding requests."""
+     logger.info(f"Received embedding request: {request.dict()}")
+
+     # Use requested model or default
+     model = request.model if request.model != "default" else DEFAULT_EMBEDDING_MODEL
+
+     # Convert input to list if it's a single string
+     inputs = request.input if isinstance(request.input, list) else [request.input]
+
+     try:
+         # Process each input text
+         all_embeddings = []
+
+         for text in inputs:
+             # Prepare request for Triton
+             triton_request = {
+                 "inputs": [
+                     {
+                         "name": "text_input",
+                         "shape": [1, 1],
+                         "datatype": "BYTES",
+                         "data": [text]
+                     }
+                 ]
+             }
+
+             # Send to Triton
+             response = requests.post(
+                 f"{TRITON_URL}/v2/models/{model}/infer",
+                 json=triton_request
+             )
+             response.raise_for_status()
+             triton_response = response.json()
+
+             # Extract embedding
+             embedding = triton_response["outputs"][0]["data"]
+             all_embeddings.append(embedding)
+
+         return format_embedding_response(all_embeddings, model)
+
+     except Exception as e:
+         logger.error(f"Error creating embeddings: {str(e)}")
+
+         # Fallback - return random embeddings
+         embeddings = []
+         for _ in inputs:
+             # Generate a random embedding vector of dimension 1024 (BGE-M3)
+             embedding = np.random.normal(0, 1, 1024).tolist()
+             embeddings.append(embedding)
+
+         return format_embedding_response(embeddings, model)
+
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint."""
+     return {"status": "healthy"}
+
+ # ===== Main =====
+
+ if __name__ == "__main__":
+     # Create placeholder image for demo
+     try:
+         if not os.path.exists("placeholder.png"):
+             # Create a simple 256x256 black image
+             import numpy as np
+             from PIL import Image
+             img = Image.fromarray(np.zeros((256, 256, 3), dtype=np.uint8))
+             img.save("placeholder.png")
+     except ImportError:
+         logger.warning("PIL not installed. Cannot create placeholder image.")
+
+     # Start server
+     uvicorn.run(app, host="0.0.0.0", port=8300)
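
For reference, the sketch below (not part of the package diff) shows one way a client might call this adapter once it is running. It assumes a local deployment on the adapter's default port 8300 and that the Triton model behind DEFAULT_TEXT_MODEL is loaded and reachable; the endpoint path and payload fields come from the schemas defined above.

import requests

# Hypothetical local deployment of the adapter above (8300 is the port it binds in __main__).
ADAPTER_URL = "http://localhost:8300"

# "default" is resolved to DEFAULT_TEXT_MODEL inside the chat_completions handler.
payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello, who are you?"}],
    "max_tokens": 64,
    "temperature": 0.7,
}

resp = requests.post(f"{ADAPTER_URL}/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])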
isa_model/inference/adapter/unified_api.py
@@ -0,0 +1,248 @@
+ import os
+ import json
+ import logging
+ from typing import Dict, List, Any, Optional, Union
+ from fastapi import FastAPI, HTTPException, Depends, Request
+ from pydantic import BaseModel, Field
+
+ from isa_model.inference.ai_factory import AIFactory
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger("unified_api")
+
+ # Create FastAPI app
+ app = FastAPI(
+     title="Unified AI Model API",
+     description="API for inference with Llama3-8B, Gemma3-4B, Whisper, and BGE-M3 models",
+     version="1.0.0"
+ )
+
+ # Models
+ class ChatMessage(BaseModel):
+     role: str = Field(..., description="Role of the message sender (system, user, assistant)")
+     content: str = Field(..., description="Content of the message")
+
+ class ChatCompletionRequest(BaseModel):
+     model: str = Field(..., description="Model ID to use (llama, gemma)")
+     messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
+     temperature: Optional[float] = Field(0.7, description="Sampling temperature")
+     max_tokens: Optional[int] = Field(512, description="Maximum number of tokens to generate")
+     top_p: Optional[float] = Field(0.9, description="Top-p sampling parameter")
+     top_k: Optional[int] = Field(50, description="Top-k sampling parameter")
+
+ class ChatCompletionResponse(BaseModel):
+     model: str = Field(..., description="Model used for completion")
+     choices: List[Dict[str, Any]] = Field(..., description="Generated completions")
+     usage: Dict[str, int] = Field(..., description="Token usage statistics")
+
+ class EmbeddingRequest(BaseModel):
+     model: str = Field(..., description="Model ID to use (bge_embed)")
+     input: Union[str, List[str]] = Field(..., description="Text to embed")
+     normalize: Optional[bool] = Field(True, description="Whether to normalize embeddings")
+
+ class TranscriptionRequest(BaseModel):
+     model: str = Field(..., description="Model ID to use (whisper)")
+     audio: str = Field(..., description="Base64-encoded audio data or URL")
+     language: Optional[str] = Field("en", description="Language code")
+
+ # Factory for creating services
+ ai_factory = AIFactory()
+
+ # Dependency to get LLM service
+ async def get_llm_service(model: str):
+     if model == "llama":
+         return await ai_factory.get_llm_service("llama")
+     elif model == "gemma":
+         return await ai_factory.get_llm_service("gemma")
+     else:
+         raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
+
+ # Dependency to get embedding service
+ async def get_embedding_service(model: str):
+     if model == "bge_embed":
+         return await ai_factory.get_embedding_service("bge_embed")
+     else:
+         raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
+
+ # Dependency to get speech service
+ async def get_speech_service(model: str):
+     if model == "whisper":
+         return await ai_factory.get_speech_service("whisper")
+     else:
+         raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
+
+ # Endpoints
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+ async def chat_completion(request: ChatCompletionRequest):
+     """Generate chat completion"""
+     try:
+         # Get the appropriate service
+         service = await get_llm_service(request.model)
+
+         # Format messages
+         formatted_messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
+
+         # Extract system prompt if present
+         system_prompt = None
+         if formatted_messages and formatted_messages[0]["role"] == "system":
+             system_prompt = formatted_messages[0]["content"]
+             formatted_messages = formatted_messages[1:]
+
+         # Get user prompt (last user message)
+         user_prompt = ""
+         for msg in reversed(formatted_messages):
+             if msg["role"] == "user":
+                 user_prompt = msg["content"]
+                 break
+
+         if not user_prompt:
+             raise HTTPException(status_code=400, detail="No user message found")
+
+         # Set generation config
+         generation_config = {
+             "temperature": request.temperature,
+             "max_new_tokens": request.max_tokens,
+             "top_p": request.top_p,
+             "top_k": request.top_k
+         }
+
+         # Generate completion
+         completion = await service.generate(
+             prompt=user_prompt,
+             system_prompt=system_prompt,
+             generation_config=generation_config
+         )
+
+         # Format response
+         response = {
+             "model": request.model,
+             "choices": [
+                 {
+                     "message": {
+                         "role": "assistant",
+                         "content": completion
+                     },
+                     "finish_reason": "stop",
+                     "index": 0
+                 }
+             ],
+             "usage": {
+                 "prompt_tokens": len(user_prompt.split()),
+                 "completion_tokens": len(completion.split()),
+                 "total_tokens": len(user_prompt.split()) + len(completion.split())
+             }
+         }
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error in chat completion: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/v1/embeddings")
+ async def create_embedding(request: EmbeddingRequest):
+     """Generate embeddings for text"""
+     try:
+         # Get the embedding service
+         service = await get_embedding_service("bge_embed")
+
+         # Generate embeddings
+         if isinstance(request.input, str):
+             embeddings = await service.embed(request.input, normalize=request.normalize)
+             data = [{"embedding": embeddings[0].tolist(), "index": 0}]
+         else:
+             embeddings = await service.embed(request.input, normalize=request.normalize)
+             data = [{"embedding": emb.tolist(), "index": i} for i, emb in enumerate(embeddings)]
+
+         # Format response
+         response = {
+             "model": request.model,
+             "data": data,
+             "usage": {
+                 "prompt_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input])),
+                 "total_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input]))
+             }
+         }
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error in embedding generation: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/v1/audio/transcriptions")
+ async def transcribe_audio(request: TranscriptionRequest):
+     """Transcribe audio to text"""
+     try:
+         import base64
+
+         # Get the speech service
+         service = await get_speech_service("whisper")
+
+         # Process audio
+         if request.audio.startswith(("http://", "https://")):
+             # URL - download audio
+             import requests
+             audio_data = requests.get(request.audio).content
+         else:
+             # Base64 - decode
+             audio_data = base64.b64decode(request.audio)
+
+         # Transcribe
+         transcription = await service.transcribe(
+             audio=audio_data,
+             language=request.language
+         )
+
+         # Format response
+         response = {
+             "model": request.model,
+             "text": transcription
+         }
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error in audio transcription: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ # Health check endpoint
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     return {"status": "healthy"}
+
+ # Model info endpoint
+ @app.get("/v1/models")
+ async def list_models():
+     """List available models"""
+     models = [
+         {
+             "id": "llama",
+             "type": "llm",
+             "description": "Llama3-8B language model"
+         },
+         {
+             "id": "gemma",
+             "type": "llm",
+             "description": "Gemma3-4B language model"
+         },
+         {
+             "id": "whisper",
+             "type": "speech",
+             "description": "Whisper-tiny speech-to-text model"
+         },
+         {
+             "id": "bge_embed",
+             "type": "embedding",
+             "description": "BGE-M3 text embedding model"
+         }
+     ]
+
+     return {"data": models}
+
+ # Main entry point
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8080)
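
For reference, a similar client sketch against unified_api.py (again, not part of the package diff): it assumes the service is running locally on its default port 8080 and that AIFactory can construct the named backends; the model ids and endpoint paths come from the code above.

import requests

# Hypothetical local deployment of the unified API (8080 is the port it binds in __main__).
API_URL = "http://localhost:8080"

# /v1/models lists the ids registered above: llama, gemma, whisper, bge_embed.
print(requests.get(f"{API_URL}/v1/models", timeout=10).json())

# Chat completion against the "llama" model id; usage counts are whitespace token estimates.
chat = requests.post(
    f"{API_URL}/v1/chat/completions",
    json={
        "model": "llama",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize what this API does."},
        ],
        "max_tokens": 128,
    },
    timeout=120,
)
chat.raise_for_status()
print(chat.json()["choices"][0]["message"]["content"])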