isa_model-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. isa_model/__init__.py +5 -0
  2. isa_model/core/model_manager.py +143 -0
  3. isa_model/core/model_registry.py +115 -0
  4. isa_model/core/model_router.py +226 -0
  5. isa_model/core/model_storage.py +133 -0
  6. isa_model/core/model_version.py +0 -0
  7. isa_model/core/resource_manager.py +202 -0
  8. isa_model/core/storage/hf_storage.py +0 -0
  9. isa_model/core/storage/local_storage.py +0 -0
  10. isa_model/core/storage/minio_storage.py +0 -0
  11. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +120 -0
  12. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +18 -0
  13. isa_model/deployment/gpu_int8_ds8/app/server.py +66 -0
  14. isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +43 -0
  15. isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +35 -0
  16. isa_model/inference/__init__.py +11 -0
  17. isa_model/inference/adapter/unified_api.py +248 -0
  18. isa_model/inference/ai_factory.py +359 -0
  19. isa_model/inference/base.py +46 -0
  20. isa_model/inference/providers/__init__.py +19 -0
  21. isa_model/inference/providers/base_provider.py +30 -0
  22. isa_model/inference/providers/model_cache_manager.py +341 -0
  23. isa_model/inference/providers/ollama_provider.py +73 -0
  24. isa_model/inference/providers/openai_provider.py +101 -0
  25. isa_model/inference/providers/replicate_provider.py +107 -0
  26. isa_model/inference/providers/triton_provider.py +439 -0
  27. isa_model/inference/services/__init__.py +14 -0
  28. isa_model/inference/services/audio/base_stt_service.py +91 -0
  29. isa_model/inference/services/audio/base_tts_service.py +136 -0
  30. isa_model/inference/services/audio/openai_tts_service.py +71 -0
  31. isa_model/inference/services/base_service.py +106 -0
  32. isa_model/inference/services/embedding/ollama_embed_service.py +97 -0
  33. isa_model/inference/services/embedding/openai_embed_service.py +0 -0
  34. isa_model/inference/services/llm/__init__.py +12 -0
  35. isa_model/inference/services/llm/base_llm_service.py +134 -0
  36. isa_model/inference/services/llm/ollama_llm_service.py +99 -0
  37. isa_model/inference/services/llm/openai_llm_service.py +138 -0
  38. isa_model/inference/services/others/table_transformer_service.py +61 -0
  39. isa_model/inference/services/vision/__init__.py +12 -0
  40. isa_model/inference/services/vision/helpers/image_utils.py +58 -0
  41. isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
  42. isa_model/inference/services/vision/ollama_vision_service.py +60 -0
  43. isa_model/inference/services/vision/openai_vision_service.py +80 -0
  44. isa_model/inference/services/vision/replicate_image_gen_service.py +185 -0
  45. isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
  46. isa_model/inference/utils/conversion/onnx_converter.py +0 -0
  47. isa_model/inference/utils/conversion/torch_converter.py +0 -0
  48. isa_model/scripts/inference_tracker.py +283 -0
  49. isa_model/scripts/mlflow_manager.py +379 -0
  50. isa_model/scripts/model_registry.py +465 -0
  51. isa_model/scripts/start_mlflow.py +95 -0
  52. isa_model/scripts/training_tracker.py +257 -0
  53. isa_model/training/engine/llama_factory/__init__.py +39 -0
  54. isa_model/training/engine/llama_factory/config.py +115 -0
  55. isa_model/training/engine/llama_factory/data_adapter.py +284 -0
  56. isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
  57. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
  58. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
  59. isa_model/training/engine/llama_factory/factory.py +331 -0
  60. isa_model/training/engine/llama_factory/rl.py +254 -0
  61. isa_model/training/engine/llama_factory/trainer.py +171 -0
  62. isa_model/training/image_model/configs/create_config.py +37 -0
  63. isa_model/training/image_model/configs/create_flux_config.py +26 -0
  64. isa_model/training/image_model/configs/create_lora_config.py +21 -0
  65. isa_model/training/image_model/prepare_massed_compute.py +97 -0
  66. isa_model/training/image_model/prepare_upload.py +17 -0
  67. isa_model/training/image_model/raw_data/create_captions.py +16 -0
  68. isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
  69. isa_model/training/image_model/raw_data/pre_processing.py +200 -0
  70. isa_model/training/image_model/train/train.py +42 -0
  71. isa_model/training/image_model/train/train_flux.py +41 -0
  72. isa_model/training/image_model/train/train_lora.py +57 -0
  73. isa_model/training/image_model/train_main.py +25 -0
  74. isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
  75. isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
  76. isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
  77. isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
  78. isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
  79. isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
  80. isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
  81. isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
  82. isa_model-0.0.1.dist-info/METADATA +327 -0
  83. isa_model-0.0.1.dist-info/RECORD +86 -0
  84. isa_model-0.0.1.dist-info/WHEEL +5 -0
  85. isa_model-0.0.1.dist-info/licenses/LICENSE +21 -0
  86. isa_model-0.0.1.dist-info/top_level.txt +1 -0
isa_model/inference/adapter/unified_api.py
@@ -0,0 +1,248 @@
+ import os
+ import json
+ import logging
+ from typing import Dict, List, Any, Optional, Union
+ from fastapi import FastAPI, HTTPException, Depends, Request
+ from pydantic import BaseModel, Field
+
+ from isa_model.inference.ai_factory import AIFactory
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger("unified_api")
+
+ # Create FastAPI app
+ app = FastAPI(
+     title="Unified AI Model API",
+     description="API for inference with Llama3-8B, Gemma3-4B, Whisper, and BGE-M3 models",
+     version="1.0.0"
+ )
+
+ # Models
+ class ChatMessage(BaseModel):
+     role: str = Field(..., description="Role of the message sender (system, user, assistant)")
+     content: str = Field(..., description="Content of the message")
+
+ class ChatCompletionRequest(BaseModel):
+     model: str = Field(..., description="Model ID to use (llama, gemma)")
+     messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
+     temperature: Optional[float] = Field(0.7, description="Sampling temperature")
+     max_tokens: Optional[int] = Field(512, description="Maximum number of tokens to generate")
+     top_p: Optional[float] = Field(0.9, description="Top-p sampling parameter")
+     top_k: Optional[int] = Field(50, description="Top-k sampling parameter")
+
+ class ChatCompletionResponse(BaseModel):
+     model: str = Field(..., description="Model used for completion")
+     choices: List[Dict[str, Any]] = Field(..., description="Generated completions")
+     usage: Dict[str, int] = Field(..., description="Token usage statistics")
+
+ class EmbeddingRequest(BaseModel):
+     model: str = Field(..., description="Model ID to use (bge_embed)")
+     input: Union[str, List[str]] = Field(..., description="Text to embed")
+     normalize: Optional[bool] = Field(True, description="Whether to normalize embeddings")
+
+ class TranscriptionRequest(BaseModel):
+     model: str = Field(..., description="Model ID to use (whisper)")
+     audio: str = Field(..., description="Base64-encoded audio data or URL")
+     language: Optional[str] = Field("en", description="Language code")
+
+ # Factory for creating services
+ ai_factory = AIFactory()
+
+ # Dependency to get LLM service
+ async def get_llm_service(model: str):
+     if model == "llama":
+         return await ai_factory.get_llm_service("llama")
+     elif model == "gemma":
+         return await ai_factory.get_llm_service("gemma")
+     else:
+         raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
+
+ # Dependency to get embedding service
+ async def get_embedding_service(model: str):
+     if model == "bge_embed":
+         return await ai_factory.get_embedding_service("bge_embed")
+     else:
+         raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
+
+ # Dependency to get speech service
+ async def get_speech_service(model: str):
+     if model == "whisper":
+         return await ai_factory.get_speech_service("whisper")
+     else:
+         raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
+
+ # Endpoints
+ @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+ async def chat_completion(request: ChatCompletionRequest):
+     """Generate chat completion"""
+     try:
+         # Get the appropriate service
+         service = await get_llm_service(request.model)
+
+         # Format messages
+         formatted_messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
+
+         # Extract system prompt if present
+         system_prompt = None
+         if formatted_messages and formatted_messages[0]["role"] == "system":
+             system_prompt = formatted_messages[0]["content"]
+             formatted_messages = formatted_messages[1:]
+
+         # Get user prompt (last user message)
+         user_prompt = ""
+         for msg in reversed(formatted_messages):
+             if msg["role"] == "user":
+                 user_prompt = msg["content"]
+                 break
+
+         if not user_prompt:
+             raise HTTPException(status_code=400, detail="No user message found")
+
+         # Set generation config
+         generation_config = {
+             "temperature": request.temperature,
+             "max_new_tokens": request.max_tokens,
+             "top_p": request.top_p,
+             "top_k": request.top_k
+         }
+
+         # Generate completion
+         completion = await service.generate(
+             prompt=user_prompt,
+             system_prompt=system_prompt,
+             generation_config=generation_config
+         )
+
+         # Format response
+         response = {
+             "model": request.model,
+             "choices": [
+                 {
+                     "message": {
+                         "role": "assistant",
+                         "content": completion
+                     },
+                     "finish_reason": "stop",
+                     "index": 0
+                 }
+             ],
+             "usage": {
+                 "prompt_tokens": len(user_prompt.split()),
+                 "completion_tokens": len(completion.split()),
+                 "total_tokens": len(user_prompt.split()) + len(completion.split())
+             }
+         }
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error in chat completion: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/v1/embeddings")
+ async def create_embedding(request: EmbeddingRequest):
+     """Generate embeddings for text"""
+     try:
+         # Get the embedding service
+         service = await get_embedding_service("bge_embed")
+
+         # Generate embeddings
+         if isinstance(request.input, str):
+             embeddings = await service.embed(request.input, normalize=request.normalize)
+             data = [{"embedding": embeddings[0].tolist(), "index": 0}]
+         else:
+             embeddings = await service.embed(request.input, normalize=request.normalize)
+             data = [{"embedding": emb.tolist(), "index": i} for i, emb in enumerate(embeddings)]
+
+         # Format response
+         response = {
+             "model": request.model,
+             "data": data,
+             "usage": {
+                 "prompt_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input])),
+                 "total_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input]))
+             }
+         }
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error in embedding generation: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/v1/audio/transcriptions")
+ async def transcribe_audio(request: TranscriptionRequest):
+     """Transcribe audio to text"""
+     try:
+         import base64
+
+         # Get the speech service
+         service = await get_speech_service("whisper")
+
+         # Process audio
+         if request.audio.startswith(("http://", "https://")):
+             # URL - download audio
+             import requests
+             audio_data = requests.get(request.audio).content
+         else:
+             # Base64 - decode
+             audio_data = base64.b64decode(request.audio)
+
+         # Transcribe
+         transcription = await service.transcribe(
+             audio=audio_data,
+             language=request.language
+         )
+
+         # Format response
+         response = {
+             "model": request.model,
+             "text": transcription
+         }
+
+         return response
+
+     except Exception as e:
+         logger.error(f"Error in audio transcription: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
+
+ # Health check endpoint
+ @app.get("/health")
+ async def health_check():
+     """Health check endpoint"""
+     return {"status": "healthy"}
+
+ # Model info endpoint
+ @app.get("/v1/models")
+ async def list_models():
+     """List available models"""
+     models = [
+         {
+             "id": "llama",
+             "type": "llm",
+             "description": "Llama3-8B language model"
+         },
+         {
+             "id": "gemma",
+             "type": "llm",
+             "description": "Gemma3-4B language model"
+         },
+         {
+             "id": "whisper",
+             "type": "speech",
+             "description": "Whisper-tiny speech-to-text model"
+         },
+         {
+             "id": "bge_embed",
+             "type": "embedding",
+             "description": "BGE-M3 text embedding model"
+         }
+     ]
+
+     return {"data": models}
+
+ # Main entry point
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8080)
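For reference, a minimal client sketch against the endpoints defined above. It assumes the server is running locally on the port configured in unified_api.py (8080); the paths and payload shapes come directly from the request models, while the base URL and prompt text are illustrative:

    import requests

    BASE_URL = "http://localhost:8080"  # assumed local deployment

    # Chat completion against the "llama" model
    chat = requests.post(f"{BASE_URL}/v1/chat/completions", json={
        "model": "llama",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What does this package do?"},
        ],
        "temperature": 0.7,
        "max_tokens": 256,
    })
    chat.raise_for_status()
    print(chat.json()["choices"][0]["message"]["content"])

    # Batch embeddings via the BGE-M3 service
    emb = requests.post(f"{BASE_URL}/v1/embeddings", json={
        "model": "bge_embed",
        "input": ["first document", "second document"],
        "normalize": True,
    })
    print(len(emb.json()["data"]))  # one embedding per input string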
isa_model/inference/ai_factory.py
@@ -0,0 +1,359 @@
+ from typing import Dict, Type, Any, Optional, Tuple
+ import logging
+ from isa_model.inference.providers.base_provider import BaseProvider
+ from isa_model.inference.services.base_service import BaseService
+ from isa_model.inference.base import ModelType
+ import os
+
+ # Configure basic logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class AIFactory:
+     """
+     Factory for creating AI services based on the Single Model pattern.
+     """
+
+     _instance = None
+     _is_initialized = False
+
+     def __new__(cls):
+         if cls._instance is None:
+             cls._instance = super().__new__(cls)
+         return cls._instance
+
+     def __init__(self):
+         """Initialize the AI Factory."""
+         self.triton_url = os.environ.get("TRITON_URL", "http://localhost:8000")
+
+         if not self._is_initialized:
+             # Caches for services (singleton pattern); created only once so
+             # repeated AIFactory() calls do not discard existing services
+             self._llm_services = {}
+             self._embedding_services = {}
+             self._speech_services = {}
+             self._providers: Dict[str, Type[BaseProvider]] = {}
+             self._services: Dict[Tuple[str, ModelType], Type[BaseService]] = {}
+             self._cached_services: Dict[str, BaseService] = {}
+             self._initialize_defaults()
+             AIFactory._is_initialized = True
+
+     def _initialize_defaults(self):
+         """Initialize default providers and services"""
+         try:
+             # Import providers and services
+             from isa_model.inference.providers.ollama_provider import OllamaProvider
+             from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
+             from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService
+
+             # Register Ollama provider and services
+             self.register_provider('ollama', OllamaProvider)
+             self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
+             self.register_service('ollama', ModelType.LLM, OllamaLLMService)
+
+             # Register OpenAI provider and services
+             try:
+                 from isa_model.inference.providers.openai_provider import OpenAIProvider
+                 from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
+
+                 self.register_provider('openai', OpenAIProvider)
+                 self.register_service('openai', ModelType.LLM, OpenAILLMService)
+                 logger.info("OpenAI services registered successfully")
+             except ImportError as e:
+                 logger.warning(f"OpenAI services not available: {e}")
+
+             # Register Replicate provider and services
+             try:
+                 from isa_model.inference.providers.replicate_provider import ReplicateProvider
+                 from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService
+
+                 self.register_provider('replicate', ReplicateProvider)
+                 self.register_service('replicate', ModelType.VISION, ReplicateVisionService)
+                 logger.info("Replicate provider and vision service registered successfully")
+             except ImportError as e:
+                 logger.warning(f"Replicate services not available: {e}")
+             except Exception as e:
+                 logger.warning(f"Error registering Replicate services: {e}")
+
+             # Try to register Triton services
+             try:
+                 from isa_model.inference.providers.triton_provider import TritonProvider
+
+                 self.register_provider('triton', TritonProvider)
+                 logger.info("Triton provider registered successfully")
+
+             except ImportError as e:
+                 logger.warning(f"Triton provider not available: {e}")
+
+             logger.info("Default AI providers and services initialized with backend architecture")
+         except Exception as e:
+             logger.error(f"Error initializing default providers and services: {e}")
+             # Don't raise - allow factory to work even if some services fail to load
+             logger.warning("Some services may not be available due to import errors")
+
+     def register_provider(self, name: str, provider_class: Type[BaseProvider]) -> None:
+         """Register an AI provider"""
+         self._providers[name] = provider_class
+
+     def register_service(self, provider_name: str, model_type: ModelType,
+                          service_class: Type[BaseService]) -> None:
+         """Register a service type with its provider"""
+         self._services[(provider_name, model_type)] = service_class
+
+     def create_service(self, provider_name: str, model_type: ModelType,
+                        model_name: str, config: Optional[Dict[str, Any]] = None) -> BaseService:
+         """Create a service instance"""
+         try:
+             cache_key = f"{provider_name}_{model_type}_{model_name}"
+
+             if cache_key in self._cached_services:
+                 return self._cached_services[cache_key]
+
+             # Base configuration
+             base_config = {
+                 "log_level": "INFO"
+             }
+
+             # Merge configurations
+             service_config = {**base_config, **(config or {})}
+
+             # Create the provider and service
+             provider_class = self._providers[provider_name]
+             service_class = self._services.get((provider_name, model_type))
+
+             if not service_class:
+                 raise ValueError(
+                     f"No service registered for provider {provider_name} and model type {model_type}"
+                 )
+
+             provider = provider_class(config=service_config)
+             service = service_class(provider=provider, model_name=model_name)
+
+             self._cached_services[cache_key] = service
+             return service
+
+         except Exception as e:
+             logger.error(f"Error creating service: {e}")
+             raise
+
+     # Convenient methods for common services
+     def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
+                 config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
+         """
+         Get an LLM service instance
+
+         Args:
+             model_name: Name of the model to use
+             provider: Provider name ('ollama', 'openai', 'replicate', etc.)
+             config: Optional configuration dictionary
+             api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
+
+         Returns:
+             LLM service instance
+
+         Example:
+             # Using with API key directly
+             llm = AIFactory.get_instance().get_llm(
+                 model_name="gpt-4o-mini",
+                 provider="openai",
+                 api_key="your-api-key-here"
+             )
+
+             # Using without API key (will use environment variable)
+             llm = AIFactory.get_instance().get_llm(
+                 model_name="gpt-4o-mini",
+                 provider="openai"
+             )
+         """
+
+         # Special case for DeepSeek service
+         if model_name.lower() in ["deepseek", "deepseek-r1", "qwen3-8b"]:
+             if "deepseek" in self._cached_services:
+                 return self._cached_services["deepseek"]
+
+         # Special case for Llama3-8B direct service
+         if model_name.lower() in ["llama3", "llama3-8b", "meta-llama-3"]:
+             if "llama3" in self._cached_services:
+                 return self._cached_services["llama3"]
+
+         basic_config: Dict[str, Any] = {
+             "temperature": 0
+         }
+
+         # Add API key to config if provided
+         if api_key:
+             if provider == "openai":
+                 basic_config["api_key"] = api_key
+             elif provider == "replicate":
+                 basic_config["api_token"] = api_key
+             else:
+                 logger.warning(f"API key provided but provider '{provider}' may not support it")
+
+         if config:
+             basic_config.update(config)
+         return self.create_service(provider, ModelType.LLM, model_name, basic_config)
+
+     def get_vision_model(self, model_name: str = "gemma3-4b", provider: str = "triton",
+                          config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
+         """
+         Get a vision model service instance
+
+         Args:
+             model_name: Name of the model to use
+             provider: Provider name ('openai', 'replicate', 'triton', etc.)
+             config: Optional configuration dictionary
+             api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
+
+         Returns:
+             Vision service instance
+
+         Example:
+             # Using with API key directly
+             vision = AIFactory.get_instance().get_vision_model(
+                 model_name="gpt-4o",
+                 provider="openai",
+                 api_key="your-api-key-here"
+             )
+
+             # Using Replicate for image generation
+             image_gen = AIFactory.get_instance().get_vision_model(
+                 model_name="stability-ai/sdxl",
+                 provider="replicate",
+                 api_key="your-replicate-token"
+             )
+         """
+
+         # Special case for Gemma3-4B direct service
+         if model_name.lower() in ["gemma3", "gemma3-4b", "gemma3-vision"]:
+             if "gemma3" in self._cached_services:
+                 return self._cached_services["gemma3"]
+
+         # Special case for Replicate's image generation models
+         if provider == "replicate" and "/" in model_name:
+             replicate_config: Dict[str, Any] = {
+                 "guidance_scale": 7.5,
+                 "num_inference_steps": 30
+             }
+
+             # Add API key if provided
+             if api_key:
+                 replicate_config["api_token"] = api_key
+
+             if config:
+                 replicate_config.update(config)
+             return self.create_service(provider, ModelType.VISION, model_name, replicate_config)
+
+         basic_config: Dict[str, Any] = {
+             "temperature": 0.7,
+             "max_new_tokens": 512
+         }
+
+         # Add API key to config if provided
+         if api_key:
+             if provider == "openai":
+                 basic_config["api_key"] = api_key
+             elif provider == "replicate":
+                 basic_config["api_token"] = api_key
+             else:
+                 logger.warning(f"API key provided but provider '{provider}' may not support it")
+
+         if config:
+             basic_config.update(config)
+         return self.create_service(provider, ModelType.VISION, model_name, basic_config)
+
+     def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
+                       config: Optional[Dict[str, Any]] = None) -> BaseService:
+         """Get an embedding service instance"""
+         return self.create_service(provider, ModelType.EMBEDDING, model_name, config)
+
+     def get_rerank(self, model_name: str = "bge-m3", provider: str = "ollama",
+                    config: Optional[Dict[str, Any]] = None) -> BaseService:
+         """Get a rerank service instance"""
+         return self.create_service(provider, ModelType.RERANK, model_name, config)
+
+     def get_embed_service(self, model_name: str = "bge-m3", provider: str = "ollama",
+                           config: Optional[Dict[str, Any]] = None) -> BaseService:
+         """Get an embedding service instance (alias for get_embedding)"""
+         return self.get_embedding(model_name, provider, config)
+
+     def get_speech_model(self, model_name: str = "whisper_tiny", provider: str = "triton",
+                          config: Optional[Dict[str, Any]] = None) -> BaseService:
+         """Get a speech-to-text model service instance"""
+
+         # Special case for Whisper Tiny direct service
+         if model_name.lower() in ["whisper", "whisper_tiny", "whisper-tiny"]:
+             if "whisper" in self._cached_services:
+                 return self._cached_services["whisper"]
+
+         basic_config = {
+             "language": "en",
+             "task": "transcribe"
+         }
+         if config:
+             basic_config.update(config)
+         return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)
+
+     async def get_embedding_service(self, model_name: str) -> Any:
+         """
+         Get an embedding service for the specified model.
+
+         Args:
+             model_name: Name of the model
+
+         Returns:
+             Embedding service instance
+         """
+         if model_name in self._embedding_services:
+             return self._embedding_services[model_name]
+         else:
+             raise ValueError(f"Unsupported embedding model: {model_name}")
+
+     async def get_speech_service(self, model_name: str) -> Any:
+         """
+         Get a speech service for the specified model.
+
+         Args:
+             model_name: Name of the model
+
+         Returns:
+             Speech service instance
+         """
+         if model_name in self._speech_services:
+             return self._speech_services[model_name]
+         else:
+             # Mirror get_embedding_service: fail loudly instead of returning None
+             raise ValueError(f"Unsupported speech model: {model_name}")
+
+     def get_model_info(self, model_type: Optional[str] = None) -> Dict[str, Any]:
+         """
+         Get information about available models.
+
+         Args:
+             model_type: Optional filter for model type
+
+         Returns:
+             Dict of model information
+         """
+         models = {
+             "llm": [
+                 {"name": "deepseek", "description": "DeepSeek-R1-0528-Qwen3-8B language model"},
+                 {"name": "llama", "description": "Llama3-8B language model"},
+                 {"name": "gemma", "description": "Gemma3-4B language model"}
+             ],
+             "embedding": [
+                 {"name": "bge_embed", "description": "BGE-M3 text embedding model"}
+             ],
+             "speech": [
+                 {"name": "whisper", "description": "Whisper-tiny speech-to-text model"}
+             ]
+         }
+
+         if model_type:
+             return {model_type: models.get(model_type, [])}
+         return models
+
+     @classmethod
+     def get_instance(cls) -> 'AIFactory':
+         """Get the singleton instance"""
+         if cls._instance is None:
+             cls._instance = cls()
+         return cls._instance
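A short usage sketch for the factory above, relying only on methods defined in this file. The model and provider names are the factory's own defaults or the examples from its docstrings; the API key value is a placeholder:

    from isa_model.inference.ai_factory import AIFactory

    factory = AIFactory.get_instance()

    # Default local LLM (llama3.1 via the Ollama provider)
    llm = factory.get_llm()

    # OpenAI-backed LLM with an explicit key (placeholder value)
    openai_llm = factory.get_llm(model_name="gpt-4o-mini", provider="openai",
                                 api_key="sk-...")

    # Embedding service (bge-m3 via Ollama); repeated calls with the same
    # provider/model-type/model-name triple return the cached instance
    embedder = factory.get_embedding()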
isa_model/inference/base.py
@@ -0,0 +1,46 @@
+ """
+ Base definitions for the Inference layer.
+ """
+
+ from enum import Enum, auto
+ from typing import Dict, List, Optional, Any, Union, TypeVar, Generic
+
+ T = TypeVar('T')
+
+
+ class ModelType(str, Enum):
+     """Types of AI models supported by the framework."""
+     LLM = "llm"
+     EMBEDDING = "embedding"
+     VISION = "vision"
+     AUDIO = "audio"
+     OCR = "ocr"
+     TTS = "tts"
+     RERANK = "rerank"
+     MULTIMODAL = "multimodal"
+
+
+ class Capability(str, Enum):
+     """Capabilities supported by models."""
+     CHAT = "chat"
+     COMPLETION = "completion"
+     EMBEDDING = "embedding"
+     IMAGE_GENERATION = "image_generation"
+     IMAGE_CLASSIFICATION = "image_classification"
+     OBJECT_DETECTION = "object_detection"
+     SPEECH_TO_TEXT = "speech_to_text"
+     TEXT_TO_SPEECH = "text_to_speech"
+     OCR = "ocr"
+     RERANKING = "reranking"
+     MULTIMODAL_UNDERSTANDING = "multimodal_understanding"
+
+
+ class RoutingStrategy(str, Enum):
+     """Routing strategies for distributing requests among model replicas."""
+     ROUND_ROBIN = "round_robin"
+     WEIGHTED_ROUND_ROBIN = "weighted_round_robin"
+     LEAST_CONNECTIONS = "least_connections"
+     RESPONSE_TIME = "response_time"
+     RANDOM = "random"
+     CONSISTENT_HASH = "consistent_hash"
+     DYNAMIC_LOAD_BALANCING = "dynamic_load_balancing"
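Because these are str-valued enums, members compare equal to their plain string values, which is what lets AIFactory key its service registry on (provider, ModelType) tuples while configs and logs use bare strings. A small illustration (the registry dict here is hypothetical):

    from isa_model.inference.base import ModelType, Capability

    assert ModelType.LLM == "llm"            # str subclass: compares to plain strings
    assert Capability.CHAT.value == "chat"

    # Keying a registry by (provider, model type), as AIFactory does internally
    registry = {("ollama", ModelType.LLM): "OllamaLLMService"}
    print(registry[("ollama", ModelType.LLM)])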
isa_model/inference/providers/__init__.py
@@ -0,0 +1,19 @@
+ """
+ Providers - Components for integrating with different model providers
+
+ File: isa_model/inference/providers/__init__.py
+ This module contains provider implementations for different AI model backends.
+ """
+
+ from .base_provider import BaseProvider
+
+ __all__ = [
+     "BaseProvider",
+ ]
+
+ # Provider implementations can be imported individually as needed
+ # from .triton_provider import TritonProvider
+ # from .ollama_provider import OllamaProvider
+ # from .yyds_provider import YYDSProvider
+ # from .openai_provider import OpenAIProvider
+ # from .replicate_provider import ReplicateProvider
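A sketch of wiring a custom provider into the factory. BaseProvider's full interface lives in base_provider.py, which this diff does not expand; the only contract visible here is that AIFactory instantiates providers with a config dict, so the hypothetical subclass below simply stores it:

    from typing import Any, Dict, Optional

    from isa_model.inference.providers import BaseProvider
    from isa_model.inference.ai_factory import AIFactory

    class MyProvider(BaseProvider):  # hypothetical custom provider
        def __init__(self, config: Optional[Dict[str, Any]] = None):
            # Assumption: BaseProvider imposes no further required init;
            # consult base_provider.py for any abstract methods to implement.
            self.config = config or {}

    AIFactory.get_instance().register_provider("my_provider", MyProvider)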