isa_model-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
- isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
- isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +359 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +101 -0
- isa_model/inference/providers/replicate_provider.py +107 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/base_stt_service.py +91 -0
- isa_model/inference/services/audio/base_tts_service.py +136 -0
- isa_model/inference/services/audio/openai_tts_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
- isa_model/inference/services/embedding/openai_embed_service.py +0 -0
- isa_model/inference/services/llm/__init__.py +12 -0
- isa_model/inference/services/llm/base_llm_service.py +134 -0
- isa_model/inference/services/llm/ollama_llm_service.py +99 -0
- isa_model/inference/services/llm/openai_llm_service.py +138 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/openai_vision_service.py +80 -0
- isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.0.1.dist-info/METADATA +327 -0
- isa_model-0.0.1.dist-info/RECORD +86 -0
- isa_model-0.0.1.dist-info/WHEEL +5 -0
- isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
- isa_model-0.0.1.dist-info/top_level.txt +1 -0
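The sections below reproduce four of the new modules in this wheel: the FastAPI adapter (`unified_api.py`), the service factory (`ai_factory.py`), the shared enums (`base.py`), and the providers package `__init__`. For orientation, here is a minimal usage sketch; it is not part of the package and assumes the wheel is installed (for example `pip install isa_model-0.0.1-py3-none-any.whl`) together with its runtime dependencies and a local Ollama server, using only names defined in the diffs below.

```python
# Minimal sketch (not part of the package): obtain the singleton factory and
# Ollama-backed services using the defaults defined in ai_factory.py below.
from isa_model.inference.ai_factory import AIFactory

factory = AIFactory.get_instance()
llm = factory.get_llm(model_name="llama3.1", provider="ollama")
embed = factory.get_embedding(model_name="bge-m3", provider="ollama")
```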
isa_model/inference/adapter/unified_api.py
@@ -0,0 +1,248 @@
```python
import os
import json
import logging
from typing import Dict, List, Any, Optional, Union
from fastapi import FastAPI, HTTPException, Depends, Request
from pydantic import BaseModel, Field

from isa_model.inference.ai_factory import AIFactory

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("unified_api")

# Create FastAPI app
app = FastAPI(
    title="Unified AI Model API",
    description="API for inference with Llama3-8B, Gemma3-4B, Whisper, and BGE-M3 models",
    version="1.0.0"
)

# Models
class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message sender (system, user, assistant)")
    content: str = Field(..., description="Content of the message")

class ChatCompletionRequest(BaseModel):
    model: str = Field(..., description="Model ID to use (llama, gemma)")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    temperature: Optional[float] = Field(0.7, description="Sampling temperature")
    max_tokens: Optional[int] = Field(512, description="Maximum number of tokens to generate")
    top_p: Optional[float] = Field(0.9, description="Top-p sampling parameter")
    top_k: Optional[int] = Field(50, description="Top-k sampling parameter")

class ChatCompletionResponse(BaseModel):
    model: str = Field(..., description="Model used for completion")
    choices: List[Dict[str, Any]] = Field(..., description="Generated completions")
    usage: Dict[str, int] = Field(..., description="Token usage statistics")

class EmbeddingRequest(BaseModel):
    model: str = Field(..., description="Model ID to use (bge_embed)")
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    normalize: Optional[bool] = Field(True, description="Whether to normalize embeddings")

class TranscriptionRequest(BaseModel):
    model: str = Field(..., description="Model ID to use (whisper)")
    audio: str = Field(..., description="Base64-encoded audio data or URL")
    language: Optional[str] = Field("en", description="Language code")

# Factory for creating services
ai_factory = AIFactory()

# Dependency to get LLM service
async def get_llm_service(model: str):
    if model == "llama":
        return await ai_factory.get_llm_service("llama")
    elif model == "gemma":
        return await ai_factory.get_llm_service("gemma")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")

# Dependency to get embedding service
async def get_embedding_service(model: str):
    if model == "bge_embed":
        return await ai_factory.get_embedding_service("bge_embed")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")

# Dependency to get speech service
async def get_speech_service(model: str):
    if model == "whisper":
        return await ai_factory.get_speech_service("whisper")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")

# Endpoints
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
    """Generate chat completion"""
    try:
        # Get the appropriate service
        service = await get_llm_service(request.model)

        # Format messages
        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # Extract system prompt if present
        system_prompt = None
        if formatted_messages and formatted_messages[0]["role"] == "system":
            system_prompt = formatted_messages[0]["content"]
            formatted_messages = formatted_messages[1:]

        # Get user prompt (last user message)
        user_prompt = ""
        for msg in reversed(formatted_messages):
            if msg["role"] == "user":
                user_prompt = msg["content"]
                break

        if not user_prompt:
            raise HTTPException(status_code=400, detail="No user message found")

        # Set generation config
        generation_config = {
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "top_p": request.top_p,
            "top_k": request.top_k
        }

        # Generate completion
        completion = await service.generate(
            prompt=user_prompt,
            system_prompt=system_prompt,
            generation_config=generation_config
        )

        # Format response
        response = {
            "model": request.model,
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": completion
                    },
                    "finish_reason": "stop",
                    "index": 0
                }
            ],
            "usage": {
                "prompt_tokens": len(user_prompt.split()),
                "completion_tokens": len(completion.split()),
                "total_tokens": len(user_prompt.split()) + len(completion.split())
            }
        }

        return response

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest):
    """Generate embeddings for text"""
    try:
        # Get the embedding service
        service = await get_embedding_service("bge_embed")

        # Generate embeddings
        if isinstance(request.input, str):
            embeddings = await service.embed(request.input, normalize=request.normalize)
            data = [{"embedding": embeddings[0].tolist(), "index": 0}]
        else:
            embeddings = await service.embed(request.input, normalize=request.normalize)
            data = [{"embedding": emb.tolist(), "index": i} for i, emb in enumerate(embeddings)]

        # Format response
        response = {
            "model": request.model,
            "data": data,
            "usage": {
                "prompt_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input])),
                "total_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input]))
            }
        }

        return response

    except Exception as e:
        logger.error(f"Error in embedding generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/audio/transcriptions")
async def transcribe_audio(request: TranscriptionRequest):
    """Transcribe audio to text"""
    try:
        import base64

        # Get the speech service
        service = await get_speech_service("whisper")

        # Process audio
        if request.audio.startswith(("http://", "https://")):
            # URL - download audio
            import requests
            audio_data = requests.get(request.audio).content
        else:
            # Base64 - decode
            audio_data = base64.b64decode(request.audio)

        # Transcribe
        transcription = await service.transcribe(
            audio=audio_data,
            language=request.language
        )

        # Format response
        response = {
            "model": request.model,
            "text": transcription
        }

        return response

    except Exception as e:
        logger.error(f"Error in audio transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# Health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}

# Model info endpoint
@app.get("/v1/models")
async def list_models():
    """List available models"""
    models = [
        {
            "id": "llama",
            "type": "llm",
            "description": "Llama3-8B language model"
        },
        {
            "id": "gemma",
            "type": "llm",
            "description": "Gemma3-4B language model"
        },
        {
            "id": "whisper",
            "type": "speech",
            "description": "Whisper-tiny speech-to-text model"
        },
        {
            "id": "bge_embed",
            "type": "embedding",
            "description": "BGE-M3 text embedding model"
        }
    ]

    return {"data": models}

# Main entry point
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8080)
```
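A client-side sketch of exercising the adapter above once it is running. The payload values are hypothetical; the base URL and request fields come from the `uvicorn.run` call and the Pydantic request models in the diff.

```python
# Hypothetical client for the unified API above; uses only the request fields
# defined in ChatCompletionRequest and EmbeddingRequest.
import requests

BASE_URL = "http://localhost:8080"  # matches uvicorn.run(app, host="0.0.0.0", port=8080)

chat = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    json={
        "model": "llama",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize the isa_model package in one sentence."},
        ],
        "temperature": 0.7,
        "max_tokens": 128,
    },
)
print(chat.json()["choices"][0]["message"]["content"])

emb = requests.post(
    f"{BASE_URL}/v1/embeddings",
    json={"model": "bge_embed", "input": ["hello world", "isa_model"]},
)
print(len(emb.json()["data"]))
```

Note that the `usage` numbers returned by the adapter are whitespace token counts (`len(text.split())`), so they are only rough approximations rather than true tokenizer counts.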
isa_model/inference/ai_factory.py
@@ -0,0 +1,359 @@
```python
from typing import Dict, Type, Any, Optional, Tuple
import logging
from isa_model.inference.providers.base_provider import BaseProvider
from isa_model.inference.services.base_service import BaseService
from isa_model.inference.base import ModelType
import os

# Basic logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class AIFactory:
    """
    Factory for creating AI services based on the Single Model pattern.
    """

    _instance = None
    _is_initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize the AI Factory."""
        self.triton_url = os.environ.get("TRITON_URL", "http://localhost:8000")

        # Cache for services (singleton pattern)
        self._llm_services = {}
        self._embedding_services = {}
        self._speech_services = {}

        if not self._is_initialized:
            self._providers: Dict[str, Type[BaseProvider]] = {}
            self._services: Dict[Tuple[str, ModelType], Type[BaseService]] = {}
            self._cached_services: Dict[str, BaseService] = {}
            self._initialize_defaults()
            AIFactory._is_initialized = True

    def _initialize_defaults(self):
        """Initialize default providers and services"""
        try:
            # Import providers and services
            from isa_model.inference.providers.ollama_provider import OllamaProvider
            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
            from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService

            # Register Ollama provider and services
            self.register_provider('ollama', OllamaProvider)
            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
            self.register_service('ollama', ModelType.LLM, OllamaLLMService)

            # Register OpenAI provider and services
            try:
                from isa_model.inference.providers.openai_provider import OpenAIProvider
                from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService

                self.register_provider('openai', OpenAIProvider)
                self.register_service('openai', ModelType.LLM, OpenAILLMService)
                logger.info("OpenAI services registered successfully")
            except ImportError as e:
                logger.warning(f"OpenAI services not available: {e}")

            # Register Replicate provider and services
            try:
                from isa_model.inference.providers.replicate_provider import ReplicateProvider
                from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService

                self.register_provider('replicate', ReplicateProvider)
                self.register_service('replicate', ModelType.VISION, ReplicateVisionService)
                logger.info("Replicate provider and vision service registered successfully")
            except ImportError as e:
                logger.warning(f"Replicate services not available: {e}")
            except Exception as e:
                logger.warning(f"Error registering Replicate services: {e}")

            # Try to register Triton services
            try:
                from isa_model.inference.providers.triton_provider import TritonProvider

                self.register_provider('triton', TritonProvider)
                logger.info("Triton provider registered successfully")

            except ImportError as e:
                logger.warning(f"Triton provider not available: {e}")

            logger.info("Default AI providers and services initialized with backend architecture")
        except Exception as e:
            logger.error(f"Error initializing default providers and services: {e}")
            # Don't raise - allow factory to work even if some services fail to load
            logger.warning("Some services may not be available due to import errors")

    def register_provider(self, name: str, provider_class: Type[BaseProvider]) -> None:
        """Register an AI provider"""
        self._providers[name] = provider_class

    def register_service(self, provider_name: str, model_type: ModelType,
                         service_class: Type[BaseService]) -> None:
        """Register a service type with its provider"""
        self._services[(provider_name, model_type)] = service_class

    def create_service(self, provider_name: str, model_type: ModelType,
                       model_name: str, config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Create a service instance"""
        try:
            cache_key = f"{provider_name}_{model_type}_{model_name}"

            if cache_key in self._cached_services:
                return self._cached_services[cache_key]

            # Base configuration
            base_config = {
                "log_level": "INFO"
            }

            # Merge configuration
            service_config = {**base_config, **(config or {})}

            # Create provider and service
            provider_class = self._providers[provider_name]
            service_class = self._services.get((provider_name, model_type))

            if not service_class:
                raise ValueError(
                    f"No service registered for provider {provider_name} and model type {model_type}"
                )

            provider = provider_class(config=service_config)
            service = service_class(provider=provider, model_name=model_name)

            self._cached_services[cache_key] = service
            return service

        except Exception as e:
            logger.error(f"Error creating service: {e}")
            raise

    # Convenient methods for common services
    def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
                config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
        """
        Get a LLM service instance

        Args:
            model_name: Name of the model to use
            provider: Provider name ('ollama', 'openai', 'replicate', etc.)
            config: Optional configuration dictionary
            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)

        Returns:
            LLM service instance

        Example:
            # Using with API key directly
            llm = AIFactory.get_instance().get_llm(
                model_name="gpt-4o-mini",
                provider="openai",
                api_key="your-api-key-here"
            )

            # Using without API key (will use environment variable)
            llm = AIFactory.get_instance().get_llm(
                model_name="gpt-4o-mini",
                provider="openai"
            )
        """

        # Special case for DeepSeek service
        if model_name.lower() in ["deepseek", "deepseek-r1", "qwen3-8b"]:
            if "deepseek" in self._cached_services:
                return self._cached_services["deepseek"]

        # Special case for Llama3-8B direct service
        if model_name.lower() in ["llama3", "llama3-8b", "meta-llama-3"]:
            if "llama3" in self._cached_services:
                return self._cached_services["llama3"]

        basic_config: Dict[str, Any] = {
            "temperature": 0
        }

        # Add API key to config if provided
        if api_key:
            if provider == "openai":
                basic_config["api_key"] = api_key
            elif provider == "replicate":
                basic_config["api_token"] = api_key
            else:
                logger.warning(f"API key provided but provider '{provider}' may not support it")

        if config:
            basic_config.update(config)
        return self.create_service(provider, ModelType.LLM, model_name, basic_config)

    def get_vision_model(self, model_name: str = "gemma3-4b", provider: str = "triton",
                         config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
        """
        Get a vision model service instance

        Args:
            model_name: Name of the model to use
            provider: Provider name ('openai', 'replicate', 'triton', etc.)
            config: Optional configuration dictionary
            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)

        Returns:
            Vision service instance

        Example:
            # Using with API key directly
            vision = AIFactory.get_instance().get_vision_model(
                model_name="gpt-4o",
                provider="openai",
                api_key="your-api-key-here"
            )

            # Using Replicate for image generation
            image_gen = AIFactory.get_instance().get_vision_model(
                model_name="stability-ai/sdxl",
                provider="replicate",
                api_key="your-replicate-token"
            )
        """

        # Special case for Gemma3-4B direct service
        if model_name.lower() in ["gemma3", "gemma3-4b", "gemma3-vision"]:
            if "gemma3" in self._cached_services:
                return self._cached_services["gemma3"]

        # Special case for Replicate's image generation models
        if provider == "replicate" and "/" in model_name:
            replicate_config: Dict[str, Any] = {
                "guidance_scale": 7.5,
                "num_inference_steps": 30
            }

            # Add API key if provided
            if api_key:
                replicate_config["api_token"] = api_key

            if config:
                replicate_config.update(config)
            return self.create_service(provider, ModelType.VISION, model_name, replicate_config)

        basic_config: Dict[str, Any] = {
            "temperature": 0.7,
            "max_new_tokens": 512
        }

        # Add API key to config if provided
        if api_key:
            if provider == "openai":
                basic_config["api_key"] = api_key
            elif provider == "replicate":
                basic_config["api_token"] = api_key
            else:
                logger.warning(f"API key provided but provider '{provider}' may not support it")

        if config:
            basic_config.update(config)
        return self.create_service(provider, ModelType.VISION, model_name, basic_config)

    def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
                      config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get an embedding service instance"""
        return self.create_service(provider, ModelType.EMBEDDING, model_name, config)

    def get_rerank(self, model_name: str = "bge-m3", provider: str = "ollama",
                   config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get a rerank service instance"""
        return self.create_service(provider, ModelType.RERANK, model_name, config)

    def get_embed_service(self, model_name: str = "bge-m3", provider: str = "ollama",
                          config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get an embedding service instance"""
        return self.get_embedding(model_name, provider, config)

    def get_speech_model(self, model_name: str = "whisper_tiny", provider: str = "triton",
                         config: Optional[Dict[str, Any]] = None) -> BaseService:
        """Get a speech-to-text model service instance"""

        # Special case for Whisper Tiny direct service
        if model_name.lower() in ["whisper", "whisper_tiny", "whisper-tiny"]:
            if "whisper" in self._cached_services:
                return self._cached_services["whisper"]

        basic_config = {
            "language": "en",
            "task": "transcribe"
        }
        if config:
            basic_config.update(config)
        return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)

    async def get_embedding_service(self, model_name: str) -> Any:
        """
        Get an embedding service for the specified model.

        Args:
            model_name: Name of the model

        Returns:
            Embedding service instance
        """
        if model_name in self._embedding_services:
            return self._embedding_services[model_name]

        else:
            raise ValueError(f"Unsupported embedding model: {model_name}")

    async def get_speech_service(self, model_name: str) -> Any:
        """
        Get a speech service for the specified model.

        Args:
            model_name: Name of the model

        Returns:
            Speech service instance
        """
        if model_name in self._speech_services:
            return self._speech_services[model_name]
        # Note: unlike get_embedding_service, unknown model names fall through
        # here and the method implicitly returns None.

    def get_model_info(self, model_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Get information about available models.

        Args:
            model_type: Optional filter for model type

        Returns:
            Dict of model information
        """
        models = {
            "llm": [
                {"name": "deepseek", "description": "DeepSeek-R1-0528-Qwen3-8B language model"},
                {"name": "llama", "description": "Llama3-8B language model"},
                {"name": "gemma", "description": "Gemma3-4B language model"}
            ],
            "embedding": [
                {"name": "bge_embed", "description": "BGE-M3 text embedding model"}
            ],
            "speech": [
                {"name": "whisper", "description": "Whisper-tiny speech-to-text model"}
            ]
        }

        if model_type:
            return {model_type: models.get(model_type, [])}
        return models

    @classmethod
    def get_instance(cls) -> 'AIFactory':
        """Get the singleton instance"""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
```
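`create_service` instantiates registered classes as `provider_class(config=...)` and `service_class(provider=..., model_name=...)`, so third-party backends can be plugged in through `register_provider` / `register_service`. A sketch with hypothetical `EchoProvider` / `EchoLLMService` classes follows; the real base classes in `base_provider.py` and `base_service.py` are not part of this excerpt, so only the constructor shapes the factory relies on are assumed.

```python
# Hypothetical plug-in backend; real implementations would subclass BaseProvider /
# BaseService, whose definitions are outside this diff.
import asyncio

from isa_model.inference.ai_factory import AIFactory
from isa_model.inference.base import ModelType


class EchoProvider:
    def __init__(self, config=None):
        self.config = config or {}


class EchoLLMService:
    def __init__(self, provider, model_name):
        self.provider = provider
        self.model_name = model_name

    async def generate(self, prompt, system_prompt=None, generation_config=None):
        # Stand-in for a real backend call.
        return f"[{self.model_name}] {prompt}"


factory = AIFactory.get_instance()
factory.register_provider("echo", EchoProvider)
factory.register_service("echo", ModelType.LLM, EchoLLMService)

service = factory.get_llm(model_name="echo-1", provider="echo")
print(asyncio.run(service.generate("hello")))
```

Because `create_service` caches by `provider_model-type_model-name`, repeated `get_llm` calls with the same arguments return the same service instance.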
isa_model/inference/base.py
@@ -0,0 +1,46 @@
```python
"""
Base definitions for the Inference layer.
"""

from enum import Enum, auto
from typing import Dict, List, Optional, Any, Union, TypeVar, Generic

T = TypeVar('T')


class ModelType(str, Enum):
    """Types of AI models supported by the framework."""
    LLM = "llm"
    EMBEDDING = "embedding"
    VISION = "vision"
    AUDIO = "audio"
    OCR = "ocr"
    TTS = "tts"
    RERANK = "rerank"
    MULTIMODAL = "multimodal"


class Capability(str, Enum):
    """Capabilities supported by models."""
    CHAT = "chat"
    COMPLETION = "completion"
    EMBEDDING = "embedding"
    IMAGE_GENERATION = "image_generation"
    IMAGE_CLASSIFICATION = "image_classification"
    OBJECT_DETECTION = "object_detection"
    SPEECH_TO_TEXT = "speech_to_text"
    TEXT_TO_SPEECH = "text_to_speech"
    OCR = "ocr"
    RERANKING = "reranking"
    MULTIMODAL_UNDERSTANDING = "multimodal_understanding"


class RoutingStrategy(str, Enum):
    """Routing strategies for distributing requests among model replicas."""
    ROUND_ROBIN = "round_robin"
    WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
    LEAST_CONNECTIONS = "least_connections"
    RESPONSE_TIME = "response_time"
    RANDOM = "random"
    CONSISTENT_HASH = "consistent_hash"
    DYNAMIC_LOAD_BALANCING = "dynamic_load_balancing"
```
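Since these enums mix in `str`, members compare equal to their raw string values and can be recovered from configuration strings, which is what makes them convenient as registry keys in `ai_factory.py`. A small illustration (assertions only, no package state touched):

```python
from isa_model.inference.base import Capability, ModelType, RoutingStrategy

# str-mixin enums compare equal to their raw string values...
assert ModelType.EMBEDDING == "embedding"
assert Capability.SPEECH_TO_TEXT.value == "speech_to_text"

# ...which mirrors how AIFactory keys services by (provider_name, ModelType).
registry = {("ollama", ModelType.LLM): "OllamaLLMService"}
assert registry[("ollama", ModelType.LLM)] == "OllamaLLMService"

# Members can also be recovered from plain config strings by value lookup.
strategy = RoutingStrategy("round_robin")
assert strategy is RoutingStrategy.ROUND_ROBIN
```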
isa_model/inference/providers/__init__.py
@@ -0,0 +1,19 @@
```python
"""
Providers - Components for integrating with different model providers

File: isa_model/inference/providers/__init__.py

This module contains provider implementations for different AI model backends.
"""

from .base_provider import BaseProvider

__all__ = [
    "BaseProvider",
]

# Provider implementations can be imported individually as needed
# from .triton_provider import TritonProvider
# from .ollama_provider import OllamaProvider
# from .yyds_provider import YYDSProvider
# from .openai_provider import OpenAIProvider
# from .replicate_provider import ReplicateProvider
```