isa_model-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +354 -0
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model/inference/backends/__init__.py +53 -0
- isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model/inference/backends/container_services.py +104 -0
- isa_model/inference/backends/local_services.py +72 -0
- isa_model/inference/backends/openai_client.py +130 -0
- isa_model/inference/backends/replicate_client.py +197 -0
- isa_model/inference/backends/third_party_services.py +239 -0
- isa_model/inference/backends/triton_client.py +97 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model/inference/client_sdk/client.py +134 -0
- isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +87 -0
- isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/base_tts_service.py +66 -0
- isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model/inference/services/llm/__init__.py +16 -0
- isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0.dist-info/METADATA +116 -0
- isa_model-0.1.0.dist-info/RECORD +117 -0
- isa_model-0.1.0.dist-info/WHEEL +5 -0
- isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
- isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/inference/providers/__init__.py
@@ -0,0 +1,19 @@
+"""
+Providers - Components for integrating with different model providers
+
+File: isa_model/inference/providers/__init__.py
+This module contains provider implementations for different AI model backends.
+"""
+
+from .base_provider import BaseProvider
+
+__all__ = [
+    "BaseProvider",
+]
+
+# Provider implementations can be imported individually as needed
+# from .triton_provider import TritonProvider
+# from .ollama_provider import OllamaProvider
+# from .yyds_provider import YYDSProvider
+# from .openai_provider import OpenAIProvider
+# from .replicate_provider import ReplicateProvider
isa_model/inference/providers/base_provider.py
@@ -0,0 +1,30 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Any, Optional
+
+from isa_model.inference.base import ModelType, Capability
+
+class BaseProvider(ABC):
+    """Base class for all AI providers"""
+
+    def __init__(self, config: Optional[Dict[str, Any]] = None):
+        self.config = config or {}
+
+    @abstractmethod
+    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+        """Get provider capabilities by model type"""
+        pass
+
+    @abstractmethod
+    def get_models(self, model_type: ModelType) -> List[str]:
+        """Get available models for given type"""
+        pass
+
+    @abstractmethod
+    def get_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        return self.config
+
+    @abstractmethod
+    def is_reasoning_model(self, model_name: str) -> bool:
+        """Check if the model is optimized for reasoning tasks"""
+        pass
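For orientation, a minimal sketch of what a concrete subclass of BaseProvider has to implement. The DummyProvider class and the values it returns are illustrative only and are not part of the package.

from typing import Any, Dict, List

from isa_model.inference.base import Capability, ModelType
from isa_model.inference.providers.base_provider import BaseProvider


class DummyProvider(BaseProvider):
    """Hypothetical provider used only to illustrate the abstract interface."""

    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
        # Advertise a single chat-capable LLM type.
        return {ModelType.LLM: [Capability.CHAT]}

    def get_models(self, model_type: ModelType) -> List[str]:
        # Expose one placeholder model name.
        return ["dummy-llm"] if model_type == ModelType.LLM else []

    def get_config(self) -> Dict[str, Any]:
        return self.config

    def is_reasoning_model(self, model_name: str) -> bool:
        return False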
isa_model/inference/providers/model_cache_manager.py
@@ -0,0 +1,341 @@
+from typing import Dict, List, Any, Optional
+import aiohttp
+import logging
+import asyncio
+from collections import OrderedDict
+import os
+import json
+import hashlib
+from pathlib import Path
+from isa_model.inference.base import ModelType
+
+logger = logging.getLogger(__name__)
+
+class ModelCacheManager:
+    """Manages loading/unloading of models on a Triton server, with polling-mode support"""
+
+    def __init__(self, cache_size: int = 5, model_repository: str = "/models"):
+        """
+        Initialize the model cache manager
+
+        Args:
+            cache_size: Maximum number of cached models
+            model_repository: Path to the model repository
+        """
+        self.cache_size = cache_size
+        self.model_repository = model_repository
+
+        # LRU cache backed by an OrderedDict
+        self.model_cache = OrderedDict()
+
+        # Server configuration
+        self.server_config = {
+            "polling_enabled": True,  # Polling mode enabled by default (suited to multi-model scenarios)
+            "triton_url": "localhost:8000",
+            "openai_api_url": "localhost:9000"
+        }
+
+        # Model type mapping
+        self.model_type_map = {
+            ModelType.LLM: "llm",
+            ModelType.EMBEDDING: "embedding",
+            ModelType.VISION: "vision",
+            ModelType.RERANK: "rerank"
+        }
+
+        logger.info(f"Initialized ModelCacheManager, cache size: {cache_size}, model repository: {model_repository}")
+
+    async def detect_server_mode(self):
+        """Detect whether the Triton server is running in polling mode"""
+        try:
+            # Try loading an arbitrary model to detect the mode
+            models = await self._get_repository_models()
+            if not models:
+                logger.warning("Could not fetch the model list; unable to detect server mode")
+                return
+
+            test_model = models[0]
+            url = f"http://{self.server_config['triton_url']}/v2/repository/models/{test_model}/load"
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url) as response:
+                    response_text = await response.text()
+
+                    if response.status == 503 and "polling is enabled" in response_text:
+                        self.server_config["polling_enabled"] = True
+                        logger.info("Detected Triton server running in polling mode (multi-model mode)")
+                    elif response.status == 200:
+                        self.server_config["polling_enabled"] = False
+                        logger.info("Detected Triton server running in explicit-load mode (single-model mode)")
+                    else:
+                        logger.warning(f"Could not determine server mode, status code: {response.status}")
+        except Exception as e:
+            logger.error(f"Error while detecting server mode: {e}")
+
+    async def load_model(self, model_name: str, model_type: ModelType) -> bool:
+        """
+        Load a model onto the Triton server
+
+        Args:
+            model_name: Name of the model
+            model_type: Type of the model
+
+        Returns:
+            bool: True on success, False on failure
+        """
+        # On the first load, detect the server mode
+        if not hasattr(self, '_mode_detected'):
+            await self.detect_server_mode()
+            self._mode_detected = True
+
+        if model_name in self.model_cache:
+            # Model already loaded; move it to the end of the LRU cache
+            self.model_cache.move_to_end(model_name)
+            logger.info(f"Model {model_name} already in cache, moved to end")
+            return True
+
+        try:
+            # Check whether the model is already loaded on the server
+            is_loaded = await self._check_model_loaded(model_name)
+            if is_loaded:
+                logger.info(f"Model {model_name} is already loaded on the server")
+                self.model_cache[model_name] = {
+                    "type": model_type,
+                    "load_time": asyncio.get_event_loop().time()
+                }
+                return True
+
+            # In polling mode we cannot load models manually
+            if self.server_config["polling_enabled"]:
+                # Check whether the model exists
+                exists = await self._check_model_exists(model_name)
+                if exists:
+                    logger.warning(f"Server is in polling mode; cannot manually load model {model_name}, but the model exists")
+                    # Assume the model will be loaded via polling
+                    return True
+                else:
+                    logger.error(f"Model {model_name} does not exist in the server repository")
+                    return False
+            else:
+                # In non-polling mode, models can be loaded manually
+                # If the cache is full, unload the least recently used model
+                if len(self.model_cache) >= self.cache_size:
+                    lru_model, _ = self.model_cache.popitem(last=False)
+                    await self._unload_from_triton(lru_model)
+                    logger.info(f"Unloaded LRU model {lru_model} from cache")
+
+                # Load the new model
+                success = await self._load_to_triton(model_name)
+                if success:
+                    self.model_cache[model_name] = {
+                        "type": model_type,
+                        "load_time": asyncio.get_event_loop().time()
+                    }
+                    logger.info(f"Successfully loaded model {model_name}")
+                    return True
+                else:
+                    logger.error(f"Failed to load model {model_name}")
+                    return False
+
+        except Exception as e:
+            logger.error(f"Error while loading model {model_name}: {e}")
+            return False
+
+    async def unload_model(self, model_name: str) -> bool:
+        """Unload a model"""
+        # In polling mode we cannot unload models manually
+        if self.server_config["polling_enabled"]:
+            logger.warning(f"Server is in polling mode; cannot manually unload model {model_name}")
+            return True
+
+        if model_name not in self.model_cache:
+            logger.warning(f"Model {model_name} is not in the cache; nothing to unload")
+            return True
+
+        try:
+            # Unload the model
+            success = await self._unload_from_triton(model_name)
+            if success:
+                # Remove it from the cache
+                self.model_cache.pop(model_name, None)
+                logger.info(f"Successfully unloaded model {model_name}")
+                return True
+            else:
+                logger.error(f"Failed to unload model {model_name}")
+                return False
+
+        except Exception as e:
+            logger.error(f"Error while unloading model {model_name}: {e}")
+            return False
+
+    async def _load_to_triton(self, model_name: str) -> bool:
+        """Send a load-model request to the Triton server"""
+        try:
+            logger.info(f"Attempting to load model {model_name} onto the Triton server")
+
+            url = f"http://{self.server_config['triton_url']}/v2/repository/models/{model_name}/load"
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url) as response:
+                    response_text = await response.text()
+
+                    if response.status == 200:
+                        logger.info(f"Successfully loaded model {model_name}")
+                        return True
+                    elif response.status == 400:
+                        # The model may already be loaded
+                        logger.info(f"Model {model_name} may already be loaded: {response_text}")
+                        return True
+                    elif response.status == 503 and "polling is enabled" in response_text:
+                        # Polling mode detected
+                        self.server_config["polling_enabled"] = True
+                        logger.warning(f"Server is in polling mode; cannot manually load the model: {response_text}")
+                        # Check whether the model exists
+                        return await self._check_model_exists(model_name)
+                    else:
+                        logger.error(f"Failed to load model {model_name}: Status {response.status}, Response: {response_text}")
+                        return False
+
+        except Exception as e:
+            logger.error(f"Error sending load request for model {model_name} to the Triton API: {e}")
+            return False
+
+    async def _check_model_loaded(self, model_name: str) -> bool:
+        """Check whether a model is already loaded"""
+        try:
+            url = f"http://{self.server_config['triton_url']}/v2/models/{model_name}/ready"
+
+            async with aiohttp.ClientSession() as session:
+                async with session.get(url) as response:
+                    if response.status == 200:
+                        logger.info(f"Model {model_name} is loaded")
+                        return True
+                    else:
+                        logger.info(f"Model {model_name} is not loaded, status code: {response.status}")
+                        return False
+        except Exception as e:
+            logger.error(f"Error checking whether model {model_name} is loaded: {e}")
+            return False
+
+    async def _check_model_exists(self, model_name: str) -> bool:
+        """Check whether a model exists in the repository"""
+        try:
+            url = f"http://{self.server_config['triton_url']}/v2/repository/index"
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url) as response:
+                    if response.status == 200:
+                        models = await response.json()
+                        model_names = [model["name"] for model in models]
+                        exists = model_name in model_names
+                        logger.info(f"Model {model_name} {'exists' if exists else 'does not exist'} in the repository")
+                        logger.info(f"Available models: {model_names}")
+                        return exists
+                    else:
+                        logger.error(f"Failed to check model existence: {response.status}")
+                        return False
+        except Exception as e:
+            logger.error(f"Error while checking model existence: {e}")
+            return False
+
+    async def _unload_from_triton(self, model_name: str) -> bool:
+        """Unload a model from the Triton server"""
+        try:
+            url = f"http://{self.server_config['triton_url']}/v2/repository/models/{model_name}/unload"
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url) as response:
+                    response_text = await response.text()
+
+                    if response.status == 200:
+                        logger.info(f"Successfully unloaded model {model_name}")
+                        return True
+                    elif response.status == 503 and "polling is enabled" in response_text:
+                        # Polling mode detected
+                        self.server_config["polling_enabled"] = True
+                        logger.warning(f"Server is in polling mode; cannot manually unload the model: {response_text}")
+                        return True
+                    else:
+                        logger.error(f"Failed to unload model {model_name}: Status {response.status}, Response: {response_text}")
+                        return False
+        except Exception as e:
+            logger.error(f"Error sending unload request for model {model_name} to the Triton API: {e}")
+            return False
+
+    def list_available_models(self, model_type: ModelType = None) -> List[str]:
+        """
+        List available models
+
+        Args:
+            model_type: Filter by model type
+
+        Returns:
+            List of model names
+        """
+        try:
+            # Fetch the model list
+            models = asyncio.run(self._get_repository_models())
+
+            if not models:
+                logger.warning("No models found in the repository or unable to connect to the server")
+                return []
+
+            # If no model type is specified, return all models
+            if model_type is None:
+                return models
+
+            # Simple filters based on naming conventions
+            if model_type == ModelType.LLM:
+                # Return models whose names contain these keywords
+                llm_keywords = ["llama", "mistral", "gemma", "qwen", "phi", "gpt", "falcon"]
+                return [m for m in models if any(kw in m.lower() for kw in llm_keywords)]
+            elif model_type == ModelType.EMBEDDING:
+                embed_keywords = ["embed", "bge", "e5", "text-embedding"]
+                return [m for m in models if any(kw in m.lower() for kw in embed_keywords)]
+            elif model_type == ModelType.VISION:
+                vision_keywords = ["clip", "vision", "multimodal", "image"]
+                return [m for m in models if any(kw in m.lower() for kw in vision_keywords)]
+            elif model_type == ModelType.RERANK:
+                rerank_keywords = ["rerank", "cross-encoder"]
+                return [m for m in models if any(kw in m.lower() for kw in rerank_keywords)]
+            else:
+                return []
+
+        except Exception as e:
+            logger.error(f"Error while listing models: {e}")
+            return []
+
+    async def _get_repository_models(self) -> List[str]:
+        """Fetch the model list from the Triton server"""
+        try:
+            url = f"http://{self.server_config['triton_url']}/v2/repository/index"
+
+            async with aiohttp.ClientSession() as session:
+                async with session.post(url) as response:
+                    if response.status == 200:
+                        models = await response.json()
+                        return [model["name"] for model in models]
+                    else:
+                        logger.error(f"Failed to fetch models: Status {response.status}")
+                        return []
+        except Exception as e:
+            logger.error(f"Error while fetching repository models: {e}")
+            return []
+
+    async def get_openai_models(self) -> List[Dict[str, Any]]:
+        """Fetch available models from the OpenAI-compatible API"""
+        try:
+            url = f"http://{self.server_config['openai_api_url']}/v1/models"
+
+            async with aiohttp.ClientSession() as session:
+                async with session.get(url) as response:
+                    if response.status == 200:
+                        result = await response.json()
+                        logger.info(f"Fetched {len(result.get('data', []))} models from the OpenAI API")
+                        return result.get("data", [])
+                    else:
+                        logger.error(f"Failed to fetch OpenAI models: Status {response.status}")
+                        return []
+        except Exception as e:
+            logger.error(f"Error while fetching OpenAI models: {e}")
+            return []
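A minimal usage sketch of the cache manager, assuming a Triton server is reachable at the default localhost:8000 address; the model name "gemma" is a placeholder.

import asyncio

from isa_model.inference.base import ModelType
from isa_model.inference.providers.model_cache_manager import ModelCacheManager


async def main() -> None:
    manager = ModelCacheManager(cache_size=3, model_repository="/models")
    # In polling mode this only verifies the model exists in the repository;
    # in explicit-load mode it issues a load request and evicts the LRU model if needed.
    loaded = await manager.load_model("gemma", ModelType.LLM)
    print("loaded:", loaded)
    if loaded:
        await manager.unload_model("gemma")


if __name__ == "__main__":
    asyncio.run(main())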
isa_model/inference/providers/ollama_provider.py
@@ -0,0 +1,73 @@
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.base import ModelType, Capability
+from typing import Dict, List, Any
+import logging
+
+logger = logging.getLogger(__name__)
+
+class OllamaProvider(BaseProvider):
+    """Provider for Ollama API"""
+
+    def __init__(self, config=None):
+        """
+        Initialize the Ollama Provider
+
+        Args:
+            config (dict, optional): Configuration for the provider
+                - base_url: Base URL for Ollama API (default: http://localhost:11434)
+                - timeout: Timeout for API calls in seconds
+        """
+        default_config = {
+            "base_url": "http://localhost:11434",
+            "timeout": 60,
+            "stream": True,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "max_tokens": 2048,
+            "keep_alive": "5m"
+        }
+
+        # Merge default config with provided config
+        merged_config = {**default_config, **(config or {})}
+
+        super().__init__(config=merged_config)
+        self.name = "ollama"
+
+        logger.info(f"Initialized OllamaProvider with URL: {self.config['base_url']}")
+
+    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+        """Get provider capabilities by model type"""
+        return {
+            ModelType.LLM: [
+                Capability.CHAT,
+                Capability.COMPLETION
+            ],
+            ModelType.EMBEDDING: [
+                Capability.EMBEDDING
+            ],
+            ModelType.VISION: [
+                Capability.IMAGE_UNDERSTANDING
+            ]
+        }
+
+    def get_models(self, model_type: ModelType) -> List[str]:
+        """Get available models for given type"""
+        # Placeholder: In real implementation, this would query Ollama API
+        if model_type == ModelType.LLM:
+            return ["llama3", "mistral", "phi3", "llama3.1", "codellama", "gemma"]
+        elif model_type == ModelType.EMBEDDING:
+            return ["bge-m3", "nomic-embed-text"]
+        elif model_type == ModelType.VISION:
+            return ["llava", "bakllava", "llama3-vision"]
+        else:
+            return []
+
+    def get_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        return self.config
+
+    def is_reasoning_model(self, model_name: str) -> bool:
+        """Check if the model is optimized for reasoning tasks"""
+        # Default implementation: consider larger models as reasoning-capable
+        reasoning_models = ["llama3.1", "llama3", "claude3", "gpt4", "mixtral", "gemma", "palm2"]
+        return any(rm in model_name.lower() for rm in reasoning_models)
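A short sketch of how this provider might be instantiated and queried, assuming a local Ollama instance; the printed values come from the hard-coded placeholder lists above.

from isa_model.inference.base import ModelType
from isa_model.inference.providers.ollama_provider import OllamaProvider

provider = OllamaProvider(config={"base_url": "http://localhost:11434", "temperature": 0.2})
print(provider.get_capabilities())               # capabilities keyed by ModelType
print(provider.get_models(ModelType.LLM))        # ["llama3", "mistral", ...] (static placeholder list)
print(provider.is_reasoning_model("llama3.1"))   # True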
isa_model/inference/providers/openai_provider.py
@@ -0,0 +1,87 @@
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.base import ModelType, Capability
+from typing import Dict, List, Any
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+class OpenAIProvider(BaseProvider):
+    """Provider for OpenAI API"""
+
+    def __init__(self, config=None):
+        """
+        Initialize the OpenAI Provider
+
+        Args:
+            config (dict, optional): Configuration for the provider
+                - api_key: OpenAI API key (default: from environment variable)
+                - api_base: Base URL for OpenAI API (default: https://api.openai.com/v1)
+                - timeout: Timeout for API calls in seconds
+        """
+        default_config = {
+            "api_key": os.environ.get("OPENAI_API_KEY", ""),
+            "api_base": "https://api.openai.com/v1",
+            "timeout": 60,
+            "stream": True,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "max_tokens": 1024
+        }
+
+        # Merge default config with provided config
+        merged_config = {**default_config, **(config or {})}
+
+        super().__init__(config=merged_config)
+        self.name = "openai"
+
+        logger.info(f"Initialized OpenAIProvider with URL: {self.config['api_base']}")
+
+        # Validate API key
+        if not self.config["api_key"]:
+            logger.warning("OpenAI API key not provided. Set OPENAI_API_KEY environment variable or pass in config.")
+
+    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+        """Get provider capabilities by model type"""
+        return {
+            ModelType.LLM: [
+                Capability.CHAT,
+                Capability.COMPLETION
+            ],
+            ModelType.EMBEDDING: [
+                Capability.EMBEDDING
+            ],
+            ModelType.VISION: [
+                Capability.IMAGE_UNDERSTANDING,
+                Capability.MULTIMODAL_UNDERSTANDING
+            ],
+            ModelType.AUDIO: [
+                Capability.SPEECH_TO_TEXT
+            ]
+        }
+
+    def get_models(self, model_type: ModelType) -> List[str]:
+        """Get available models for given type"""
+        if model_type == ModelType.LLM:
+            return ["gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo"]
+        elif model_type == ModelType.EMBEDDING:
+            return ["text-embedding-3-large", "text-embedding-3-small", "text-embedding-ada-002"]
+        elif model_type == ModelType.VISION:
+            return ["gpt-4o", "gpt-4-vision-preview"]
+        elif model_type == ModelType.AUDIO:
+            return ["whisper-1"]
+        else:
+            return []
+
+    def get_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        # Return a copy without sensitive information
+        config_copy = self.config.copy()
+        if "api_key" in config_copy:
+            config_copy["api_key"] = "***" if config_copy["api_key"] else ""
+        return config_copy
+
+    def is_reasoning_model(self, model_name: str) -> bool:
+        """Check if the model is optimized for reasoning tasks"""
+        reasoning_models = ["gpt-4", "gpt-4o", "gpt-4-turbo"]
+        return any(rm in model_name.lower() for rm in reasoning_models)
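A similar sketch for the OpenAI provider; note that get_config() masks the API key, which is read from OPENAI_API_KEY unless passed explicitly.

from isa_model.inference.base import ModelType
from isa_model.inference.providers.openai_provider import OpenAIProvider

provider = OpenAIProvider(config={"max_tokens": 256})
print(provider.get_models(ModelType.EMBEDDING))  # ["text-embedding-3-large", ...]
print(provider.get_config()["api_key"])          # "***" when a key is set, "" otherwise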
isa_model/inference/providers/replicate_provider.py
@@ -0,0 +1,94 @@
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.base import ModelType, Capability
+from typing import Dict, List, Any
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+class ReplicateProvider(BaseProvider):
+    """Provider for Replicate API"""
+
+    def __init__(self, config=None):
+        """
+        Initialize the Replicate Provider
+
+        Args:
+            config (dict, optional): Configuration for the provider
+                - api_token: Replicate API token (default: from environment variable)
+                - timeout: Timeout for API calls in seconds
+        """
+        default_config = {
+            "api_token": os.environ.get("REPLICATE_API_TOKEN", ""),
+            "timeout": 60,
+            "stream": True,
+            "max_tokens": 1024
+        }
+
+        # Merge default config with provided config
+        merged_config = {**default_config, **(config or {})}
+
+        super().__init__(config=merged_config)
+        self.name = "replicate"
+
+        logger.info(f"Initialized ReplicateProvider")
+
+        # Validate API token
+        if not self.config["api_token"]:
+            logger.warning("Replicate API token not provided. Set REPLICATE_API_TOKEN environment variable or pass in config.")
+
+    def get_capabilities(self) -> Dict[ModelType, List[Capability]]:
+        """Get provider capabilities by model type"""
+        return {
+            ModelType.LLM: [
+                Capability.CHAT,
+                Capability.COMPLETION
+            ],
+            ModelType.VISION: [
+                Capability.IMAGE_UNDERSTANDING,
+                Capability.IMAGE_GENERATION,
+                Capability.MULTIMODAL_UNDERSTANDING
+            ],
+            ModelType.AUDIO: [
+                Capability.SPEECH_TO_TEXT,
+                Capability.TEXT_TO_SPEECH
+            ]
+        }
+
+    def get_models(self, model_type: ModelType) -> List[str]:
+        """Get available models for given type"""
+        if model_type == ModelType.LLM:
+            return [
+                "meta/llama-3-70b-instruct",
+                "meta/llama-3-8b-instruct",
+                "anthropic/claude-3-opus-20240229",
+                "anthropic/claude-3-sonnet-20240229"
+            ]
+        elif model_type == ModelType.VISION:
+            return [
+                "stability-ai/sdxl",
+                "stability-ai/stable-diffusion-3-medium",
+                "meta/llama-3-70b-vision",
+                "anthropic/claude-3-opus-20240229",
+                "anthropic/claude-3-sonnet-20240229"
+            ]
+        elif model_type == ModelType.AUDIO:
+            return [
+                "openai/whisper",
+                "suno-ai/bark"
+            ]
+        else:
+            return []
+
+    def get_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        # Return a copy without sensitive information
+        config_copy = self.config.copy()
+        if "api_token" in config_copy:
+            config_copy["api_token"] = "***" if config_copy["api_token"] else ""
+        return config_copy
+
+    def is_reasoning_model(self, model_name: str) -> bool:
+        """Check if the model is optimized for reasoning tasks"""
+        reasoning_models = ["llama-3-70b", "claude-3-opus", "claude-3-sonnet"]
+        return any(rm in model_name.lower() for rm in reasoning_models)
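And a sketch for the Replicate provider, which follows the same pattern with an api_token read from REPLICATE_API_TOKEN; the printed values come from the static lists above.

from isa_model.inference.base import ModelType
from isa_model.inference.providers.replicate_provider import ReplicateProvider

provider = ReplicateProvider()
print(provider.get_models(ModelType.AUDIO))                       # ["openai/whisper", "suno-ai/bark"]
print(provider.is_reasoning_model("meta/llama-3-70b-instruct"))   # True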