isa-model 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model-0.1.0/LICENSE +21 -0
- isa_model-0.1.0/MANIFEST.in +3 -0
- isa_model-0.1.0/PKG-INFO +116 -0
- isa_model-0.1.0/README.md +86 -0
- isa_model-0.1.0/isa_model/__init__.py +5 -0
- isa_model-0.1.0/isa_model/core/model_manager.py +143 -0
- isa_model-0.1.0/isa_model/core/model_registry.py +115 -0
- isa_model-0.1.0/isa_model/core/model_router.py +226 -0
- isa_model-0.1.0/isa_model/core/model_storage.py +133 -0
- isa_model-0.1.0/isa_model/core/model_version.py +0 -0
- isa_model-0.1.0/isa_model/core/resource_manager.py +202 -0
- isa_model-0.1.0/isa_model/core/storage/hf_storage.py +0 -0
- isa_model-0.1.0/isa_model/core/storage/local_storage.py +0 -0
- isa_model-0.1.0/isa_model/core/storage/minio_storage.py +0 -0
- isa_model-0.1.0/isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model-0.1.0/isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model-0.1.0/isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model-0.1.0/isa_model/inference/__init__.py +11 -0
- isa_model-0.1.0/isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model-0.1.0/isa_model/inference/adapter/unified_api.py +248 -0
- isa_model-0.1.0/isa_model/inference/ai_factory.py +354 -0
- isa_model-0.1.0/isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model-0.1.0/isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model-0.1.0/isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model-0.1.0/isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model-0.1.0/isa_model/inference/backends/__init__.py +53 -0
- isa_model-0.1.0/isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model-0.1.0/isa_model/inference/backends/container_services.py +104 -0
- isa_model-0.1.0/isa_model/inference/backends/local_services.py +72 -0
- isa_model-0.1.0/isa_model/inference/backends/openai_client.py +130 -0
- isa_model-0.1.0/isa_model/inference/backends/replicate_client.py +197 -0
- isa_model-0.1.0/isa_model/inference/backends/third_party_services.py +239 -0
- isa_model-0.1.0/isa_model/inference/backends/triton_client.py +97 -0
- isa_model-0.1.0/isa_model/inference/base.py +46 -0
- isa_model-0.1.0/isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model-0.1.0/isa_model/inference/client_sdk/client.py +134 -0
- isa_model-0.1.0/isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model-0.1.0/isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model-0.1.0/isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model-0.1.0/isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model-0.1.0/isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model-0.1.0/isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model-0.1.0/isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model-0.1.0/isa_model/inference/providers/__init__.py +19 -0
- isa_model-0.1.0/isa_model/inference/providers/base_provider.py +30 -0
- isa_model-0.1.0/isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model-0.1.0/isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model-0.1.0/isa_model/inference/providers/openai_provider.py +87 -0
- isa_model-0.1.0/isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model-0.1.0/isa_model/inference/providers/triton_provider.py +439 -0
- isa_model-0.1.0/isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model-0.1.0/isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model-0.1.0/isa_model/inference/services/__init__.py +14 -0
- isa_model-0.1.0/isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model-0.1.0/isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model-0.1.0/isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model-0.1.0/isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model-0.1.0/isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model-0.1.0/isa_model/inference/services/base_service.py +106 -0
- isa_model-0.1.0/isa_model/inference/services/base_tts_service.py +66 -0
- isa_model-0.1.0/isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model-0.1.0/isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model-0.1.0/isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model-0.1.0/isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model-0.1.0/isa_model/inference/services/llm/__init__.py +16 -0
- isa_model-0.1.0/isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model-0.1.0/isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model-0.1.0/isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model-0.1.0/isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model-0.1.0/isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model-0.1.0/isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model-0.1.0/isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model-0.1.0/isa_model/inference/services/vision/__init__.py +12 -0
- isa_model-0.1.0/isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model-0.1.0/isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model-0.1.0/isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model-0.1.0/isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model-0.1.0/isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model-0.1.0/isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model-0.1.0/isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model-0.1.0/isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model-0.1.0/isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model-0.1.0/isa_model/scripts/inference_tracker.py +283 -0
- isa_model-0.1.0/isa_model/scripts/mlflow_manager.py +379 -0
- isa_model-0.1.0/isa_model/scripts/model_registry.py +465 -0
- isa_model-0.1.0/isa_model/scripts/start_mlflow.py +95 -0
- isa_model-0.1.0/isa_model/scripts/training_tracker.py +257 -0
- isa_model-0.1.0/isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model-0.1.0/isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model-0.1.0/isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model-0.1.0/isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model-0.1.0/isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model-0.1.0/isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model-0.1.0/isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model-0.1.0/isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model-0.1.0/isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model-0.1.0/isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model-0.1.0/isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model-0.1.0/isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model-0.1.0/isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model-0.1.0/isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model-0.1.0/isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model-0.1.0/isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model-0.1.0/isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model-0.1.0/isa_model/training/image_model/train/train.py +42 -0
- isa_model-0.1.0/isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model-0.1.0/isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model-0.1.0/isa_model/training/image_model/train_main.py +25 -0
- isa_model-0.1.0/isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model-0.1.0/isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model-0.1.0/isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model-0.1.0/isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model-0.1.0/isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model-0.1.0/isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model-0.1.0/isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model-0.1.0/isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0/isa_model.egg-info/PKG-INFO +116 -0
- isa_model-0.1.0/isa_model.egg-info/SOURCES.txt +122 -0
- isa_model-0.1.0/isa_model.egg-info/dependency_links.txt +1 -0
- isa_model-0.1.0/isa_model.egg-info/requires.txt +14 -0
- isa_model-0.1.0/isa_model.egg-info/top_level.txt +1 -0
- isa_model-0.1.0/pyproject.toml +50 -0
- isa_model-0.1.0/setup.cfg +4 -0
- isa_model-0.1.0/setup.py +4 -0
isa_model-0.1.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023-2024 isA_Model Contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
isa_model-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,116 @@
+Metadata-Version: 2.4
+Name: isa-model
+Version: 0.1.0
+Summary: Unified AI model serving framework
+Author-email: isA_Model Contributors <your.email@example.com>
+License: MIT
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: Operating System :: OS Independent
+Classifier: Programming Language :: Python :: 3
+Classifier: License :: OSI Approved :: MIT License
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: fastapi>=0.95.0
+Requires-Dist: numpy>=1.20.0
+Requires-Dist: httpx>=0.23.0
+Requires-Dist: pydantic>=2.0.0
+Requires-Dist: uvicorn>=0.22.0
+Requires-Dist: requests>=2.28.0
+Requires-Dist: aiohttp>=3.8.0
+Requires-Dist: transformers>=4.30.0
+Requires-Dist: langchain-core>=0.1.0
+Requires-Dist: tritonclient[grpc,http]>=2.30.0
+Requires-Dist: huggingface-hub>=0.16.0
+Requires-Dist: kubernetes>=25.3.0
+Requires-Dist: mlflow>=2.4.0
+Requires-Dist: torch>=2.0.0
+Dynamic: license-file
+
+# isA_Model - AI Service Factory
+
+isA_Model is a lightweight AI service factory for managing and calling different AI models and service providers through a unified interface.
+
+## Features
+
+- Support for multiple AI providers (Ollama, OpenAI, Replicate, Triton)
+- Unified API interface
+- Flexible factory pattern
+- Async support
+- Singleton pattern with efficient caching
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+```
+
+## Quick Start
+
+Using the AI factory is simple:
+
+```python
+from isa_model.inference.ai_factory import AIFactory
+from isa_model.inference.base import ModelType
+
+# Get the factory instance
+factory = AIFactory()
+
+# LLM example - using Ollama
+llm = factory.get_llm(model_name="llama3.1", provider="ollama")
+response = await llm.generate("Hello, please introduce yourself.")
+print(response)
+
+# Image generation example - using Replicate
+vision_service = factory.get_vision_model(
+    model_name="stability-ai/sdxl:c221b2b8ef527988fb59bf24a8b97c4561f1c671f73bd389f866bfb27c061316",
+    provider="replicate",
+    config={"api_token": "your_replicate_token"}
+)
+result = await vision_service.generate_image({
+    "prompt": "A beautiful sunset over mountains",
+    "num_inference_steps": 25
+})
+print(result["urls"])
+```
+
+## Factory Architecture
+
+isA_Model uses a three-layer architecture:
+
+1. **Client layer** - application code
+2. **Service layer** - model service implementations (LLM, image, embedding, etc.)
+3. **Provider layer** - underlying API integrations (Ollama, OpenAI, Replicate, etc.)
+
+### Main Components
+
+- `AIFactory` - central factory class providing access to models and services
+- `BaseService` - base class for all services
+- `BaseProvider` - base class for all providers
+- Concrete service implementations - e.g. `ReplicateVisionService`, `OllamaLLMService`
+
+## Supported Model Types
+
+- **LLM** - large language models
+- **VISION** - image generation and analysis
+- **EMBEDDING** - text embedding
+- **AUDIO** - speech recognition
+- **RERANK** - reranking
+
+## Examples
+
+See the `test_*.py` files for more usage examples.
+
+## Environment Variables
+
+Add API keys and other configuration to the `.env.local` file:
+
+```
+OPENAI_API_KEY=your_openai_key
+REPLICATE_API_TOKEN=your_replicate_token
+```
+
+## License
+
+MIT
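The description above tells callers to put credentials in `.env.local`, but python-dotenv is not among the Requires-Dist entries, so loading that file is left to the application. A minimal sketch, assuming python-dotenv is installed (an assumption, not a declared dependency of isa-model):

```python
# Hypothetical loader for the .env.local file described above.
# python-dotenv is NOT a declared dependency of isa-model 0.1.0;
# this only illustrates one way an application could populate the
# environment variables the providers read.
import os

from dotenv import load_dotenv

load_dotenv(".env.local")  # no-op if the file does not exist

openai_key = os.environ.get("OPENAI_API_KEY")
replicate_token = os.environ.get("REPLICATE_API_TOKEN")
print("Replicate token configured:", replicate_token is not None)
```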
isa_model-0.1.0/README.md
ADDED
@@ -0,0 +1,86 @@
+# isA_Model - AI Service Factory
+
+isA_Model is a lightweight AI service factory for managing and calling different AI models and service providers through a unified interface.
+
+## Features
+
+- Support for multiple AI providers (Ollama, OpenAI, Replicate, Triton)
+- Unified API interface
+- Flexible factory pattern
+- Async support
+- Singleton pattern with efficient caching
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+```
+
+## Quick Start
+
+Using the AI factory is simple:
+
+```python
+from isa_model.inference.ai_factory import AIFactory
+from isa_model.inference.base import ModelType
+
+# Get the factory instance
+factory = AIFactory()
+
+# LLM example - using Ollama
+llm = factory.get_llm(model_name="llama3.1", provider="ollama")
+response = await llm.generate("Hello, please introduce yourself.")
+print(response)
+
+# Image generation example - using Replicate
+vision_service = factory.get_vision_model(
+    model_name="stability-ai/sdxl:c221b2b8ef527988fb59bf24a8b97c4561f1c671f73bd389f866bfb27c061316",
+    provider="replicate",
+    config={"api_token": "your_replicate_token"}
+)
+result = await vision_service.generate_image({
+    "prompt": "A beautiful sunset over mountains",
+    "num_inference_steps": 25
+})
+print(result["urls"])
+```
+
+## Factory Architecture
+
+isA_Model uses a three-layer architecture:
+
+1. **Client layer** - application code
+2. **Service layer** - model service implementations (LLM, image, embedding, etc.)
+3. **Provider layer** - underlying API integrations (Ollama, OpenAI, Replicate, etc.)
+
+### Main Components
+
+- `AIFactory` - central factory class providing access to models and services
+- `BaseService` - base class for all services
+- `BaseProvider` - base class for all providers
+- Concrete service implementations - e.g. `ReplicateVisionService`, `OllamaLLMService`
+
+## Supported Model Types
+
+- **LLM** - large language models
+- **VISION** - image generation and analysis
+- **EMBEDDING** - text embedding
+- **AUDIO** - speech recognition
+- **RERANK** - reranking
+
+## Examples
+
+See the `test_*.py` files for more usage examples.
+
+## Environment Variables
+
+Add API keys and other configuration to the `.env.local` file:
+
+```
+OPENAI_API_KEY=your_openai_key
+REPLICATE_API_TOKEN=your_replicate_token
+```
+
+## License
+
+MIT
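Note that the quick-start snippet calls `await` at module level, which only works in an async REPL (for example `python -m asyncio`). To run it as a plain script it needs an async entry point; a minimal sketch, assuming `AIFactory.get_llm` and `generate` behave exactly as shown above:

```python
import asyncio

from isa_model.inference.ai_factory import AIFactory

async def main() -> None:
    # Same calls as the README quick start, wrapped in a coroutine
    # so they can be driven by asyncio.run().
    factory = AIFactory()
    llm = factory.get_llm(model_name="llama3.1", provider="ollama")
    response = await llm.generate("Hello, please introduce yourself.")
    print(response)

if __name__ == "__main__":
    asyncio.run(main())
```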
isa_model-0.1.0/isa_model/core/model_manager.py
ADDED
@@ -0,0 +1,143 @@
+from typing import Dict, Optional, List, Any
+import logging
+from pathlib import Path
+from huggingface_hub import hf_hub_download, snapshot_download
+from huggingface_hub.utils import HfHubHTTPError
+from .model_storage import ModelStorage, LocalModelStorage
+from .model_registry import ModelRegistry, ModelType, ModelCapability
+
+logger = logging.getLogger(__name__)
+
+class ModelManager:
+    """Model management service for handling model downloads, versions, and caching"""
+
+    def __init__(self,
+                 storage: Optional[ModelStorage] = None,
+                 registry: Optional[ModelRegistry] = None):
+        self.storage = storage or LocalModelStorage()
+        self.registry = registry or ModelRegistry()
+
+    async def get_model(self,
+                        model_id: str,
+                        repo_id: str,
+                        model_type: ModelType,
+                        capabilities: List[ModelCapability],
+                        revision: Optional[str] = None,
+                        force_download: bool = False) -> Path:
+        """
+        Get model files, downloading if necessary
+
+        Args:
+            model_id: Unique identifier for the model
+            repo_id: Hugging Face repository ID
+            model_type: Type of model (LLM, embedding, etc.)
+            capabilities: List of model capabilities
+            revision: Specific model version/tag
+            force_download: Force re-download even if cached
+
+        Returns:
+            Path to the model files
+        """
+        # Check if model is already downloaded
+        if not force_download:
+            model_path = await self.storage.load_model(model_id)
+            if model_path:
+                logger.info(f"Using cached model {model_id}")
+                return model_path
+
+        try:
+            # Download model files
+            logger.info(f"Downloading model {model_id} from {repo_id}")
+            model_dir = Path(f"./models/temp/{model_id}")
+            model_dir.mkdir(parents=True, exist_ok=True)
+
+            snapshot_download(
+                repo_id=repo_id,
+                revision=revision,
+                local_dir=model_dir,
+                local_dir_use_symlinks=False
+            )
+
+            # Save model and metadata
+            metadata = {
+                "repo_id": repo_id,
+                "revision": revision,
+                "downloaded_at": str(Path(model_dir).stat().st_mtime)
+            }
+
+            # Register model
+            self.registry.register_model(
+                model_id=model_id,
+                model_type=model_type,
+                capabilities=capabilities,
+                metadata=metadata
+            )
+
+            # Save model files
+            await self.storage.save_model(model_id, str(model_dir), metadata)
+
+            return await self.storage.load_model(model_id)
+
+        except HfHubHTTPError as e:
+            logger.error(f"Failed to download model {model_id}: {e}")
+            raise
+
+    async def list_models(self) -> List[Dict[str, Any]]:
+        """List all downloaded models with their metadata"""
+        models = await self.storage.list_models()
+        return [
+            {
+                "model_id": model_id,
+                **metadata,
+                **(self.registry.get_model_info(model_id) or {})
+            }
+            for model_id, metadata in models.items()
+        ]
+
+    async def remove_model(self, model_id: str) -> bool:
+        """Remove a model and its metadata"""
+        try:
+            # Remove from storage
+            storage_success = await self.storage.delete_model(model_id)
+
+            # Unregister from registry
+            registry_success = self.registry.unregister_model(model_id)
+
+            return storage_success and registry_success
+
+        except Exception as e:
+            logger.error(f"Failed to remove model {model_id}: {e}")
+            return False
+
+    async def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
+        """Get information about a specific model"""
+        storage_info = await self.storage.get_metadata(model_id)
+        registry_info = self.registry.get_model_info(model_id)
+
+        if not storage_info and not registry_info:
+            return None
+
+        return {
+            **(storage_info or {}),
+            **(registry_info or {})
+        }
+
+    async def update_model(self,
+                           model_id: str,
+                           repo_id: str,
+                           model_type: ModelType,
+                           capabilities: List[ModelCapability],
+                           revision: Optional[str] = None) -> bool:
+        """Update a model to a new version"""
+        try:
+            return bool(await self.get_model(
+                model_id=model_id,
+                repo_id=repo_id,
+                model_type=model_type,
+                capabilities=capabilities,
+                revision=revision,
+                force_download=True
+            ))
+        except Exception as e:
+            logger.error(f"Failed to update model {model_id}: {e}")
+            return False
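`ModelManager.get_model` folds cache lookup, Hugging Face download, registry registration, and storage persistence into a single call. A hedged usage sketch, assuming the default `LocalModelStorage` and JSON-file `ModelRegistry` backends; the `repo_id` below is an illustrative Hugging Face repository, not one bundled with the package:

```python
import asyncio

from isa_model.core.model_manager import ModelManager
from isa_model.core.model_registry import ModelType, ModelCapability

async def main() -> None:
    manager = ModelManager()  # defaults: LocalModelStorage + ModelRegistry

    # Downloads on first call, serves from local storage afterwards.
    model_path = await manager.get_model(
        model_id="bge-small",
        repo_id="BAAI/bge-small-en-v1.5",  # illustrative repo_id
        model_type=ModelType.EMBEDDING,
        capabilities=[ModelCapability.EMBEDDING],
    )
    print("Model files at:", model_path)

    # Merged view of storage metadata and registry info.
    print(await manager.get_model_info("bge-small"))

if __name__ == "__main__":
    asyncio.run(main())
```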
isa_model-0.1.0/isa_model/core/model_registry.py
ADDED
@@ -0,0 +1,115 @@
+from typing import Dict, List, Optional, Any
+from enum import Enum
+import logging
+from pathlib import Path
+import json
+
+logger = logging.getLogger(__name__)
+
+class ModelCapability(str, Enum):
+    """Model capabilities"""
+    TEXT_GENERATION = "text_generation"
+    CHAT = "chat"
+    EMBEDDING = "embedding"
+    RERANKING = "reranking"
+    REASONING = "reasoning"
+    IMAGE_GENERATION = "image_generation"
+    IMAGE_ANALYSIS = "image_analysis"
+    AUDIO_TRANSCRIPTION = "audio_transcription"
+    IMAGE_UNDERSTANDING = "image_understanding"
+
+class ModelType(str, Enum):
+    """Model types"""
+    LLM = "llm"
+    EMBEDDING = "embedding"
+    RERANK = "rerank"
+    IMAGE = "image"
+    AUDIO = "audio"
+    VIDEO = "video"
+    VISION = "vision"
+
+class ModelRegistry:
+    """Registry for model metadata and capabilities"""
+
+    def __init__(self, registry_file: str = "./models/model_registry.json"):
+        self.registry_file = Path(registry_file)
+        self.registry: Dict[str, Dict[str, Any]] = {}
+        self._load_registry()
+
+    def _load_registry(self):
+        """Load model registry from file"""
+        if self.registry_file.exists():
+            with open(self.registry_file, 'r') as f:
+                self.registry = json.load(f)
+        else:
+            self.registry = {}
+            self._save_registry()
+
+    def _save_registry(self):
+        """Save model registry to file"""
+        self.registry_file.parent.mkdir(parents=True, exist_ok=True)
+        with open(self.registry_file, 'w') as f:
+            json.dump(self.registry, f, indent=2)
+
+    def register_model(self,
+                       model_id: str,
+                       model_type: ModelType,
+                       capabilities: List[ModelCapability],
+                       metadata: Dict[str, Any]) -> bool:
+        """Register a model with its capabilities and metadata"""
+        try:
+            self.registry[model_id] = {
+                "type": model_type,
+                "capabilities": [cap.value for cap in capabilities],
+                "metadata": metadata
+            }
+            self._save_registry()
+            logger.info(f"Registered model {model_id}")
+            return True
+        except Exception as e:
+            logger.error(f"Failed to register model {model_id}: {e}")
+            return False
+
+    def unregister_model(self, model_id: str) -> bool:
+        """Unregister a model"""
+        try:
+            if model_id in self.registry:
+                del self.registry[model_id]
+                self._save_registry()
+                logger.info(f"Unregistered model {model_id}")
+                return True
+            return False
+        except Exception as e:
+            logger.error(f"Failed to unregister model {model_id}: {e}")
+            return False
+
+    def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
+        """Get model information"""
+        return self.registry.get(model_id)
+
+    def get_models_by_type(self, model_type: ModelType) -> Dict[str, Dict[str, Any]]:
+        """Get all models of a specific type"""
+        return {
+            model_id: info
+            for model_id, info in self.registry.items()
+            if info["type"] == model_type
+        }
+
+    def get_models_by_capability(self, capability: ModelCapability) -> Dict[str, Dict[str, Any]]:
+        """Get all models with a specific capability"""
+        return {
+            model_id: info
+            for model_id, info in self.registry.items()
+            if capability.value in info["capabilities"]
+        }
+
+    def has_capability(self, model_id: str, capability: ModelCapability) -> bool:
+        """Check if a model has a specific capability"""
+        model_info = self.get_model_info(model_id)
+        if not model_info:
+            return False
+        return capability.value in model_info["capabilities"]
+
+    def list_models(self) -> Dict[str, Dict[str, Any]]:
+        """List all registered models"""
+        return self.registry
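Because the registry persists to a plain JSON file, `ModelRegistry` can be exercised standalone as well. A minimal sketch; the model ID and metadata values are illustrative, not entries shipped with the package:

```python
from isa_model.core.model_registry import (
    ModelCapability,
    ModelRegistry,
    ModelType,
)

registry = ModelRegistry(registry_file="./models/model_registry.json")

# Record a model's type, capabilities, and free-form metadata;
# register_model() writes the JSON file immediately.
registry.register_model(
    model_id="llama3.1",
    model_type=ModelType.LLM,
    capabilities=[ModelCapability.TEXT_GENERATION, ModelCapability.CHAT],
    metadata={"repo_id": "meta-llama/Llama-3.1-8B-Instruct"},  # illustrative
)

# Query by capability, then check a single model.
print(sorted(registry.get_models_by_capability(ModelCapability.CHAT)))
print(registry.has_capability("llama3.1", ModelCapability.CHAT))  # True
```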