isa-model 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +30 -1
- isa_model/client.py +937 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/{storage/supabase_storage.py → models/model_repo.py} +72 -73
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +157 -3
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +104 -3
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +538 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/deployment/services/simple_auto_deploy_vision_service.py +275 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +257 -601
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -17
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -1
- isa_model/inference/services/{stacked → img}/flux_professional_service.py +25 -1
- isa_model/inference/services/{stacked → img/helpers}/base_stacked_service.py +40 -35
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +44 -31
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +492 -40
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +51 -17
- isa_model/inference/services/llm/openai_llm_service.py +70 -19
- isa_model/inference/services/llm/yyds_llm_service.py +24 -23
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +218 -117
- isa_model/inference/services/vision/{isA_vision_service.py → disabled/isA_vision_service.py} +98 -0
- isa_model/inference/services/{stacked → vision}/doc_analysis_service.py +1 -1
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +104 -307
- isa_model/inference/services/vision/replicate_vision_service.py +140 -325
- isa_model/inference/services/{stacked → vision}/ui_analysis_service.py +2 -498
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/api/fastapi_server.py +6 -1
- isa_model/serving/api/routes/unified.py +274 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/METADATA +4 -1
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/RECORD +78 -53
- isa_model/config/__init__.py +0 -9
- isa_model/config/config_manager.py +0 -213
- isa_model/core/model_manager.py +0 -213
- isa_model/core/model_registry.py +0 -375
- isa_model/core/vision_models_init.py +0 -116
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/stacked/__init__.py +0 -26
- isa_model/inference/services/stacked/config.py +0 -426
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- /isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/WHEEL +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/top_level.txt +0 -0
isa_model/config/config_manager.py
DELETED
@@ -1,213 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Configuration Manager
|
3
|
-
|
4
|
-
Central configuration management with environment support
|
5
|
-
"""
|
6
|
-
|
7
|
-
import os
|
8
|
-
import yaml
|
9
|
-
from typing import Dict, Any, Optional
|
10
|
-
from pathlib import Path
|
11
|
-
from dataclasses import dataclass
|
12
|
-
import logging
|
13
|
-
|
14
|
-
logger = logging.getLogger(__name__)
|
15
|
-
|
16
|
-
@dataclass
class ConfigSection:
    """Base configuration section.

    Subclasses are dataclasses whose fields are the section's settings.
    """

    def to_dict(self) -> Dict[str, Any]:
        """Return the section's fields as a new dictionary.

        The result is a deep copy (``dataclasses.asdict``), so callers may
        modify it, including nested lists, without changing the live
        configuration object.
        """
        # Local import: the module header only imports `dataclass`.
        from dataclasses import asdict
        return asdict(self)
|
21
|
-
|
22
|
-
@dataclass
class DeploymentConfig(ConfigSection):
    """Deployment target plus Modal resource settings."""

    # Target platform; expected values: replicate, modal, aws, local.
    platform: str = "replicate"

    # Modal-specific settings.
    modal_app_name: str = "isa-ui-analysis"
    modal_gpu_type: str = "A100-40GB"
    modal_memory: int = 32 * 1024  # = 32768; presumably MiB — TODO confirm
    modal_timeout: int = 30 * 60  # = 1800; presumably seconds — TODO confirm
    modal_keep_warm: int = 1
|
31
|
-
|
32
|
-
@dataclass
class ModelConfig(ConfigSection):
    """Model selection and inference settings."""

    # Model identifiers.
    ui_detection_model: str = "microsoft/omniparser-v2"  # UI element detection
    ui_planning_model: str = "gpt-4o-mini"  # UI planning
    fallback_detection: str = "yolov8n"  # alternative detector name

    # Inference tuning knobs.
    quantization: bool = False
    batch_size: int = 1
|
40
|
-
|
41
|
-
@dataclass
class ServingConfig(ConfigSection):
    """HTTP serving configuration.

    ``cors_origins`` defaults to ``["*"]``. Leaving it unset, or passing
    ``None`` explicitly (e.g. ``cors_origins: null`` in YAML), gives each
    instance its own fresh ``["*"]`` list.
    """

    host: str = "0.0.0.0"
    port: int = 8000
    workers: int = 1
    reload: bool = False
    log_level: str = "info"
    # None is a "not set" sentinel: a list literal cannot be a dataclass
    # default, so the real default is filled in by __post_init__.
    cors_origins: Optional[list] = None

    def __post_init__(self):
        """Replace an unset ``cors_origins`` with ``["*"]``."""
        if self.cors_origins is None:
            self.cors_origins = ["*"]
|
54
|
-
|
55
|
-
@dataclass
class APIConfig(ConfigSection):
    """API limits, response caching and authentication switch."""

    rate_limit: int = 100  # requests per minute
    max_file_size: int = 10 << 20  # 10 MB, i.e. 10 * 1024 * 1024 bytes
    cache_ttl: int = 60 * 60  # 1 hour, in seconds
    enable_auth: bool = False
|
62
|
-
|
63
|
-
@dataclass
class ISAConfig:
    """Complete configuration: the environment name plus one object per section."""

    environment: str
    deployment: DeploymentConfig
    models: ModelConfig
    serving: ServingConfig
    api: APIConfig

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, delegating each section to its own ``to_dict``.

        Keys appear in the order: environment, deployment, models, serving, api.
        """
        section_names = ("deployment", "models", "serving", "api")
        result: Dict[str, Any] = {"environment": self.environment}
        result.update((name, getattr(self, name).to_dict()) for name in section_names)
        return result
|
80
|
-
|
81
|
-
class ConfigManager:
    """Configuration manager with environment support.

    ``__new__`` always returns the same instance, so every ``ConfigManager()``
    call shares one configuration. It is built on first construction from, in
    increasing priority:

    1. dataclass defaults,
    2. the YAML file for the ``ISA_ENV`` environment (default ``"development"``),
    3. a few ``ISA_*`` environment variables (see ``_apply_env_overrides``).
    """

    _instance = None
    _config = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(ConfigManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ConfigManager() call; only load the first time.
        if self._config is None:
            self._load_config()

    def _load_config(self):
        """Load configuration from defaults, the environment's YAML file and env vars.

        If the YAML file cannot be read or parsed, or its top level is not a
        mapping, a warning is logged and the defaults are used instead.
        """
        env = os.getenv("ISA_ENV", "development")

        # Default configurations
        default_config = {
            "deployment": DeploymentConfig(),
            "models": ModelConfig(),
            "serving": ServingConfig(),
            "api": APIConfig()
        }

        # Load environment-specific configuration
        config_file = self._get_config_file(env)
        if config_file and config_file.exists():
            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    # safe_load returns None for an empty document: treat as "no overrides".
                    file_config = yaml.safe_load(f) or {}
                if not isinstance(file_config, dict):
                    raise ValueError(
                        f"top-level YAML value must be a mapping, got {type(file_config).__name__}"
                    )

                # Merge configurations
                self._config = self._merge_configs(default_config, file_config, env)
                logger.info(f"Loaded configuration for environment: {env}")

            except Exception as e:
                logger.warning(f"Failed to load config file {config_file}: {e}")
                self._config = ISAConfig(environment=env, **default_config)
        else:
            logger.info(f"No config file found for {env}, using defaults")
            self._config = ISAConfig(environment=env, **default_config)

        # Override with environment variables
        self._apply_env_overrides()

    def _get_config_file(self, env: str) -> Optional[Path]:
        """Return the first existing YAML file for ``env``, or None.

        Search order: ``environments/<env>.yaml`` next to this module, then
        ``config/<env>.yaml`` and ``config_<env>.yaml`` under the current
        working directory.
        """
        possible_paths = [
            Path(__file__).parent / "environments" / f"{env}.yaml",
            Path.cwd() / "config" / f"{env}.yaml",
            Path.cwd() / f"config_{env}.yaml"
        ]
        return next((path for path in possible_paths if path.exists()), None)

    def _merge_configs(self, default: Dict, file_config: Dict, env: str) -> ISAConfig:
        """Overlay the sections of ``file_config`` on the ``default`` section objects.

        Args:
            default: Maps "deployment", "models", "serving" and "api" to the
                default section instances.
            file_config: Parsed YAML mapping. Missing or null sections keep
                their defaults.
            env: Environment name stored on the result.

        Returns:
            A new ISAConfig. The default section objects are not modified.
        """
        file_config = file_config or {}
        return ISAConfig(
            environment=env,
            deployment=self._merge_section(default["deployment"], file_config.get("deployment"), "deployment"),
            models=self._merge_section(default["models"], file_config.get("models"), "models"),
            serving=self._merge_section(default["serving"], file_config.get("serving"), "serving"),
            api=self._merge_section(default["api"], file_config.get("api"), "api"),
        )

    @staticmethod
    def _merge_section(default_section, overrides, section_name: str):
        """Return a new section of ``default_section``'s class with ``overrides`` applied.

        ``None`` overrides keep all defaults. A non-mapping value, or keys that
        are not fields of the section, are logged and ignored rather than
        invalidating the whole configuration file.
        """
        data = dict(default_section.__dict__)
        if overrides is not None:
            if not isinstance(overrides, dict):
                logger.warning(
                    f"Ignoring '{section_name}' config section: expected a mapping, "
                    f"got {type(overrides).__name__}"
                )
            else:
                unknown = [key for key in overrides if key not in data]
                if unknown:
                    logger.warning(
                        f"Ignoring unknown keys in '{section_name}' config section: "
                        f"{', '.join(map(str, unknown))}"
                    )
                data.update((key, value) for key, value in overrides.items() if key in data)
        return type(default_section)(**data)

    def _apply_env_overrides(self):
        """Apply non-empty ``ISA_*`` environment variables on top of the loaded config.

        An ``ISA_SERVING_PORT`` that is not an integer is logged and ignored.
        Previously it raised ValueError, and since the module-level manager is
        created at import time, that broke importing this module.
        """
        # Deployment overrides
        platform = os.getenv("ISA_DEPLOYMENT_PLATFORM")
        if platform:
            self._config.deployment.platform = platform

        # Model overrides
        detection_model = os.getenv("ISA_UI_DETECTION_MODEL")
        if detection_model:
            self._config.models.ui_detection_model = detection_model

        # Serving overrides
        port = os.getenv("ISA_SERVING_PORT")
        if port:
            try:
                self._config.serving.port = int(port)
            except ValueError:
                logger.warning(
                    f"Ignoring invalid ISA_SERVING_PORT value {port!r}; "
                    f"keeping port {self._config.serving.port}"
                )

        host = os.getenv("ISA_SERVING_HOST")
        if host:
            self._config.serving.host = host

    def get_config(self) -> ISAConfig:
        """Get current configuration"""
        return self._config

    def reload(self):
        """Discard the current configuration and load it again from files and env vars."""
        self._config = None
        self._load_config()
|
203
|
-
|
204
|
-
# Shared manager instance. Constructing it loads the configuration when this
# module is first imported.
_config_manager = ConfigManager()


def get_config() -> ISAConfig:
    """Return the configuration currently held by the shared manager."""
    manager = _config_manager
    return manager.get_config()


def reload_config():
    """Rebuild the shared manager's configuration from files and environment variables."""
    manager = _config_manager
    manager.reload()
|
isa_model/core/model_manager.py
DELETED
@@ -1,213 +0,0 @@
|
|
1
|
-
from typing import Dict, Optional, List, Any
|
2
|
-
import logging
|
3
|
-
from pathlib import Path
|
4
|
-
from huggingface_hub import hf_hub_download, snapshot_download
|
5
|
-
from huggingface_hub.errors import HfHubHTTPError
|
6
|
-
from .model_storage import ModelStorage, LocalModelStorage
|
7
|
-
from .model_registry import ModelRegistry, ModelType, ModelCapability
|
8
|
-
|
9
|
-
logger = logging.getLogger(__name__)
|
10
|
-
|
11
|
-
class ModelManager:
    """Model management service for handling model downloads, versions, and caching.

    Model files live in a ``ModelStorage`` backend, and type/capability
    metadata in a ``ModelRegistry``. The class also carries a static pricing
    table used for cost estimation and cheapest-model lookup.
    """

    # Unified model pricing table, keyed by provider and then model name.
    # Prices are nominally per 1M tokens; entries priced in other units
    # (images, seconds of audio) are noted inline.
    MODEL_PRICING = {
        # OpenAI Models
        "openai": {
            "gpt-4o-mini": {"input": 0.15, "output": 0.6},
            "gpt-4.1-mini": {"input": 0.4, "output": 1.6},
            "gpt-4.1-nano": {"input": 0.1, "output": 0.4},
            "gpt-4o": {"input": 5.0, "output": 15.0},
            "gpt-4-turbo": {"input": 10.0, "output": 30.0},
            "gpt-4": {"input": 30.0, "output": 60.0},
            "gpt-3.5-turbo": {"input": 0.5, "output": 1.5},
            "text-embedding-3-small": {"input": 0.02, "output": 0.0},
            "text-embedding-3-large": {"input": 0.13, "output": 0.0},
            "whisper-1": {"input": 6.0, "output": 0.0},
            "tts-1": {"input": 15.0, "output": 0.0},
            "tts-1-hd": {"input": 30.0, "output": 0.0},
        },
        # Ollama Models (free local models)
        "ollama": {
            "llama3.2:3b-instruct-fp16": {"input": 0.0, "output": 0.0},
            "llama3.2-vision:latest": {"input": 0.0, "output": 0.0},
            "bge-m3": {"input": 0.0, "output": 0.0},
        },
        # Replicate Models
        "replicate": {
            "black-forest-labs/flux-schnell": {"input": 3.0, "output": 0.0},  # $3 per 1000 images
            "black-forest-labs/flux-kontext-pro": {"input": 40.0, "output": 0.0},  # $0.04 per image = $40 per 1000 images
            "meta/meta-llama-3-8b-instruct": {"input": 0.05, "output": 0.25},
            "kokoro-82m": {"input": 0.0, "output": 0.4},  # ~$0.0004 per second
            "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13": {"input": 0.0, "output": 0.4},
        },
        # YYDS Models
        "yyds": {
            "claude-sonnet-4-20250514": {"input": 4.5, "output": 22.5},  # $0.0045/1K = $4.5/1M, $0.0225/1K = $22.5/1M
            "claude-3-5-sonnet-20241022": {"input": 3.0, "output": 15.0},  # $0.003/1K = $3.0/1M, $0.015/1K = $15.0/1M
        }
    }

    # (substring, model type) hints used by _infer_model_type to classify the
    # names in MODEL_PRICING. Checked in order; the first match wins, and a
    # name matching none of them is treated as "llm".
    # NOTE(review): these labels are local to this class; align them with
    # ModelType values if callers pass those.
    _MODEL_TYPE_HINTS = (
        ("embedding", "embedding"),
        ("bge-", "embedding"),
        ("whisper", "stt"),
        ("tts", "tts"),
        ("kokoro", "tts"),
        ("flux", "image"),
        ("vision", "vision"),
    )

    def __init__(self,
                 storage: Optional[ModelStorage] = None,
                 registry: Optional[ModelRegistry] = None):
        """Create a manager; defaults to ``LocalModelStorage`` and a new ``ModelRegistry``."""
        self.storage = storage or LocalModelStorage()
        self.registry = registry or ModelRegistry()

    def get_model_pricing(self, provider: str, model_name: str) -> Dict[str, float]:
        """Get pricing information for a model.

        Returns:
            A new ``{"input": ..., "output": ...}`` dict. It is a copy, so
            callers cannot alter MODEL_PRICING through it. Unknown providers or
            models are priced at 0.0.
        """
        pricing = self.MODEL_PRICING.get(provider, {}).get(model_name)
        if pricing is None:
            return {"input": 0.0, "output": 0.0}
        return dict(pricing)

    def calculate_cost(self, provider: str, model_name: str, input_tokens: int, output_tokens: int) -> float:
        """Calculate the cost of a request from per-1M-token input/output prices."""
        pricing = self.get_model_pricing(provider, model_name)
        input_cost = (input_tokens / 1_000_000) * pricing["input"]
        output_cost = (output_tokens / 1_000_000) * pricing["output"]
        return input_cost + output_cost

    @classmethod
    def _infer_model_type(cls, model_name: str) -> str:
        """Classify a model name as "llm", "embedding", "stt", "tts", "image" or "vision" via name hints."""
        lowered = model_name.lower()
        for hint, model_type in cls._MODEL_TYPE_HINTS:
            if hint in lowered:
                return model_type
        return "llm"

    def get_cheapest_model(self, provider: str, model_type: Optional[str] = "llm") -> Optional[str]:
        """Get the cheapest model of ``model_type`` offered by ``provider``.

        Cost is the average of the input and output prices (assumes a 1:1
        input:output ratio). Ties resolve to the model listed first in
        MODEL_PRICING.

        Args:
            provider: Key into MODEL_PRICING (e.g. "openai").
            model_type: Type as classified by ``_infer_model_type``. Pass None
                to consider every model of the provider.

        Returns:
            The model name, or None if the provider is unknown or has no model
            of the requested type.
        """
        provider_models = self.MODEL_PRICING.get(provider, {})
        candidates = [
            name for name in provider_models
            if model_type is None or self._infer_model_type(name) == model_type
        ]
        if not candidates:
            return None

        def average_cost(name: str) -> float:
            pricing = provider_models[name]
            return (pricing["input"] + pricing["output"]) / 2

        return min(candidates, key=average_cost)

    async def get_model(self,
                        model_id: str,
                        repo_id: str,
                        model_type: ModelType,
                        capabilities: List[ModelCapability],
                        revision: Optional[str] = None,
                        force_download: bool = False) -> Optional[Path]:
        """
        Get model files, downloading if necessary

        Args:
            model_id: Unique identifier for the model
            repo_id: Hugging Face repository ID
            model_type: Type of model (LLM, embedding, etc.)
            capabilities: List of model capabilities
            revision: Specific model version/tag
            force_download: Force re-download even if cached

        Returns:
            Path to the model files or None if failed
        """
        import asyncio
        import functools

        # Check if model is already downloaded
        if not force_download:
            model_path = await self.storage.load_model(model_id)
            if model_path:
                logger.info(f"Using cached model {model_id}")
                return model_path

        try:
            # Download model files
            logger.info(f"Downloading model {model_id} from {repo_id}")
            model_dir = Path(f"./models/temp/{model_id}")
            model_dir.mkdir(parents=True, exist_ok=True)

            # snapshot_download is a blocking network call. Run it in the
            # default thread pool so it does not stall the event loop.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                functools.partial(
                    snapshot_download,
                    repo_id=repo_id,
                    revision=revision,
                    local_dir=model_dir,
                    local_dir_use_symlinks=False,
                ),
            )

            # Save model and metadata
            metadata = {
                "repo_id": repo_id,
                "revision": revision,
                # mtime of the download directory, as a string.
                "downloaded_at": str(model_dir.stat().st_mtime)
            }

            # Register model
            self.registry.register_model(
                model_id=model_id,
                model_type=model_type,
                capabilities=capabilities,
                metadata=metadata
            )

            # Save model files
            await self.storage.save_model(model_id, str(model_dir), metadata)

            return await self.storage.load_model(model_id)

        except HfHubHTTPError as e:
            logger.error(f"Failed to download model {model_id}: {e}")
            return None
        except Exception as e:
            # Keep the traceback: this branch covers bugs, not expected HTTP failures.
            logger.exception(f"Unexpected error downloading model {model_id}: {e}")
            return None

    async def list_models(self) -> List[Dict[str, Any]]:
        """List all downloaded models with their metadata.

        Storage metadata is merged with registry info; registry values win on
        key clashes.
        """
        models = await self.storage.list_models()
        return [
            {
                "model_id": model_id,
                **metadata,
                **(self.registry.get_model_info(model_id) or {})
            }
            for model_id, metadata in models.items()
        ]

    async def remove_model(self, model_id: str) -> bool:
        """Remove a model and its metadata; True only if both storage and registry succeed."""
        try:
            # Remove from storage
            storage_success = await self.storage.delete_model(model_id)

            # Unregister from registry
            registry_success = self.registry.unregister_model(model_id)

            return storage_success and registry_success

        except Exception as e:
            logger.exception(f"Failed to remove model {model_id}: {e}")
            return False

    async def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
        """Get merged storage and registry information about a model, or None if neither knows it."""
        storage_info = await self.storage.get_metadata(model_id)
        registry_info = self.registry.get_model_info(model_id)

        if not storage_info and not registry_info:
            return None

        return {
            **(storage_info or {}),
            **(registry_info or {})
        }

    async def update_model(self,
                           model_id: str,
                           repo_id: str,
                           model_type: ModelType,
                           capabilities: List[ModelCapability],
                           revision: Optional[str] = None) -> bool:
        """Force re-download of a model (optionally at ``revision``); True on success."""
        try:
            return bool(await self.get_model(
                model_id=model_id,
                repo_id=repo_id,
                model_type=model_type,
                capabilities=capabilities,
                revision=revision,
                force_download=True
            ))
        except Exception as e:
            logger.exception(f"Failed to update model {model_id}: {e}")
            return False
|