isa-model 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +30 -1
- isa_model/client.py +937 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/{storage/supabase_storage.py → models/model_repo.py} +72 -73
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +157 -3
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +104 -3
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +538 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/deployment/services/simple_auto_deploy_vision_service.py +275 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +257 -601
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -17
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -1
- isa_model/inference/services/{stacked → img}/flux_professional_service.py +25 -1
- isa_model/inference/services/{stacked → img/helpers}/base_stacked_service.py +40 -35
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +44 -31
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +492 -40
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +51 -17
- isa_model/inference/services/llm/openai_llm_service.py +70 -19
- isa_model/inference/services/llm/yyds_llm_service.py +24 -23
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +218 -117
- isa_model/inference/services/vision/{isA_vision_service.py → disabled/isA_vision_service.py} +98 -0
- isa_model/inference/services/{stacked → vision}/doc_analysis_service.py +1 -1
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +104 -307
- isa_model/inference/services/vision/replicate_vision_service.py +140 -325
- isa_model/inference/services/{stacked → vision}/ui_analysis_service.py +2 -498
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/api/fastapi_server.py +6 -1
- isa_model/serving/api/routes/unified.py +274 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/METADATA +4 -1
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/RECORD +78 -53
- isa_model/config/__init__.py +0 -9
- isa_model/config/config_manager.py +0 -213
- isa_model/core/model_manager.py +0 -213
- isa_model/core/model_registry.py +0 -375
- isa_model/core/vision_models_init.py +0 -116
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/stacked/__init__.py +0 -26
- isa_model/inference/services/stacked/config.py +0 -426
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- /isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/WHEEL +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,399 @@
|
|
1
|
+
from typing import Dict, Optional, List, Any
|
2
|
+
import logging
|
3
|
+
from pathlib import Path
|
4
|
+
from datetime import datetime
|
5
|
+
from huggingface_hub import hf_hub_download, snapshot_download
|
6
|
+
from huggingface_hub.errors import HfHubHTTPError
|
7
|
+
from .model_storage import ModelStorage, LocalModelStorage
|
8
|
+
from .model_repo import ModelRegistry, ModelType, ModelCapability
|
9
|
+
from .model_billing_tracker import ModelBillingTracker, ModelOperationType
|
10
|
+
from ..pricing_manager import PricingManager
|
11
|
+
from ..config import ConfigManager
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
class ModelManager:
    """
    Model lifecycle management service

    Handles the complete model lifecycle:
    - Model registration and metadata management
    - Model downloads, versions, and caching
    - Cost tracking and billing across all operations
    - Integration with model training, evaluation, and deployment
    """

    def __init__(self,
                 storage: Optional[ModelStorage] = None,
                 registry: Optional[ModelRegistry] = None,
                 billing_tracker: Optional[ModelBillingTracker] = None,
                 pricing_manager: Optional[PricingManager] = None,
                 config_manager: Optional[ConfigManager] = None):
        """
        Wire up the collaborating services.

        Any collaborator that is omitted (or falsy) is replaced by a
        default-constructed instance. The default billing tracker is built
        on top of whichever registry ends up being used.
        """
        self.storage = storage or LocalModelStorage()
        self.registry = registry or ModelRegistry()
        self.billing_tracker = billing_tracker or ModelBillingTracker(model_registry=self.registry)
        self.pricing_manager = pricing_manager or PricingManager()
        self.config_manager = config_manager or ConfigManager()

    def get_model_pricing(self, provider: str, model_name: str) -> Dict[str, float]:
        """
        Get pricing information for a model.

        Returns:
            A dict with "input" and "output" costs taken from the pricing
            manager's record, or zeros for both when no record is found.
        """
        pricing = self.pricing_manager.get_model_pricing(provider, model_name)
        if pricing:
            return {"input": pricing.input_cost, "output": pricing.output_cost}
        return {"input": 0.0, "output": 0.0}

    def calculate_cost(self, provider: str, model_name: str, input_tokens: int, output_tokens: int) -> float:
        """
        Calculate the cost of a request.

        Token counts are forwarded to the pricing manager as its
        input/output units.
        """
        return self.pricing_manager.calculate_cost(
            provider=provider,
            model_name=model_name,
            input_units=input_tokens,
            output_units=output_tokens
        )

    def get_cheapest_model(self, provider: str, model_type: str = "llm") -> Optional[str]:
        """
        Get the cheapest model for a provider.

        NOTE(review): ``model_type`` is accepted for interface compatibility
        but is not forwarded to the pricing manager; the lookup is always
        done with unit_type="token".

        Returns:
            The "model_name" entry of the pricing manager's result, or None
            when the pricing manager returns nothing.
        """
        result = self.pricing_manager.get_cheapest_model(
            provider=provider,
            unit_type="token",
            min_input_units=1000  # Assume 1K tokens for comparison
        )
        return result["model_name"] if result else None

    async def get_model(self,
                        model_id: str,
                        repo_id: str,
                        model_type: ModelType,
                        capabilities: List[ModelCapability],
                        revision: Optional[str] = None,
                        force_download: bool = False) -> Optional[Path]:
        """
        Get model files, downloading if necessary

        Args:
            model_id: Unique identifier for the model
            repo_id: Hugging Face repository ID
            model_type: Type of model (LLM, embedding, etc.)
            capabilities: List of model capabilities
            revision: Specific model version/tag
            force_download: Force re-download even if cached

        Returns:
            Path to the model files or None if failed
        """
        # Check if model is already downloaded
        if not force_download:
            model_path = await self.storage.load_model(model_id)
            if model_path:
                logger.info(f"Using cached model {model_id}")
                return model_path

        try:
            # Download model files
            logger.info(f"Downloading model {model_id} from {repo_id}")
            model_dir = Path(f"./models/temp/{model_id}")
            model_dir.mkdir(parents=True, exist_ok=True)

            snapshot_download(
                repo_id=repo_id,
                revision=revision,
                local_dir=model_dir,
                local_dir_use_symlinks=False
            )

            # Save model and metadata. "downloaded_at" is the download
            # directory's modification time (st_mtime) rendered as a string.
            metadata = {
                "repo_id": repo_id,
                "revision": revision,
                "downloaded_at": str(model_dir.stat().st_mtime)
            }

            # Register model
            self.registry.register_model(
                model_id=model_id,
                model_type=model_type,
                capabilities=capabilities,
                metadata=metadata
            )

            # Save model files
            await self.storage.save_model(model_id, str(model_dir), metadata)

            return await self.storage.load_model(model_id)

        except HfHubHTTPError as e:
            logger.error(f"Failed to download model {model_id}: {e}")
            return None
        except Exception as e:
            # Deliberate best-effort: any other failure is logged and
            # reported to the caller as None rather than raised.
            logger.error(f"Unexpected error downloading model {model_id}: {e}")
            return None

    async def list_models(self) -> List[Dict[str, Any]]:
        """
        List all downloaded models with their metadata

        Each entry merges the storage metadata with the registry info for
        the same model id; on key collisions the registry value wins.
        """
        models = await self.storage.list_models()
        return [
            {
                "model_id": model_id,
                **metadata,
                **(self.registry.get_model_info(model_id) or {})
            }
            for model_id, metadata in models.items()
        ]

    async def remove_model(self, model_id: str) -> bool:
        """
        Remove a model and its metadata

        Both the storage deletion and the registry unregistration are
        attempted; the result is truthy only if both report success.
        Any exception is logged and reported as False.
        """
        try:
            # Remove from storage
            storage_success = await self.storage.delete_model(model_id)

            # Unregister from registry
            registry_success = self.registry.unregister_model(model_id)

            return storage_success and registry_success

        except Exception as e:
            logger.error(f"Failed to remove model {model_id}: {e}")
            return False

    async def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get information about a specific model

        Returns:
            Storage metadata merged with registry info (registry values win
            on key collisions), or None when neither source knows the model.
        """
        storage_info = await self.storage.get_metadata(model_id)
        registry_info = self.registry.get_model_info(model_id)

        if not storage_info and not registry_info:
            return None

        return {
            **(storage_info or {}),
            **(registry_info or {})
        }

    async def update_model(self,
                           model_id: str,
                           repo_id: str,
                           model_type: ModelType,
                           capabilities: List[ModelCapability],
                           revision: Optional[str] = None) -> bool:
        """
        Update a model to a new version

        Implemented as a forced re-download through ``get_model``.

        Returns:
            True if ``get_model`` produced a path, False otherwise.
        """
        try:
            return bool(await self.get_model(
                model_id=model_id,
                repo_id=repo_id,
                model_type=model_type,
                capabilities=capabilities,
                revision=revision,
                force_download=True
            ))
        except Exception as e:
            logger.error(f"Failed to update model {model_id}: {e}")
            return False

    # === MODEL LIFECYCLE MANAGEMENT ===

    async def register_model_for_lifecycle(
        self,
        model_id: str,
        model_type: ModelType,
        capabilities: List[ModelCapability],
        provider: str = "custom",
        provider_model_name: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Register a model for lifecycle management

        Args:
            model_id: Unique identifier for the model
            model_type: Type of model (LLM, embedding, etc.)
            capabilities: List of model capabilities
            provider: Provider name for billing
            provider_model_name: Provider-specific model name for pricing
            metadata: Additional metadata (not modified by this call)

        Returns:
            True if registration successful
        """
        try:
            # Prepare metadata with billing info. Work on a copy so the
            # caller's dict is never mutated.
            full_metadata = dict(metadata or {})
            full_metadata.update({
                "provider": provider,
                "provider_model_name": provider_model_name or model_id,
                "registered_for_lifecycle": True,
                "lifecycle_stage": "registered"
            })

            # Register in model registry
            success = self.registry.register_model(
                model_id=model_id,
                model_type=model_type,
                capabilities=capabilities,
                metadata=full_metadata
            )

            if success:
                # Track registration operation
                self.billing_tracker.track_model_usage(
                    model_id=model_id,
                    operation_type=ModelOperationType.STORAGE,
                    provider=provider,
                    service_type="model_management",
                    operation="register_model",
                    metadata={"stage": "registration"}
                )

                logger.info(f"Successfully registered model {model_id} for lifecycle management")

            return success

        except Exception as e:
            logger.error(f"Failed to register model {model_id} for lifecycle: {e}")
            return False

    def track_model_usage(
        self,
        model_id: str,
        operation_type: ModelOperationType,
        provider: str,
        service_type: str,
        operation: str,
        input_tokens: Optional[int] = None,
        output_tokens: Optional[int] = None,
        input_units: Optional[float] = None,
        output_units: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Track model usage and costs

        This method should be called by:
        - Training services when training a model
        - Evaluation services when evaluating a model
        - Deployment services when deploying a model
        - Inference services when using a model for inference

        All arguments are forwarded unchanged to the billing tracker, and
        its return value is passed back to the caller.
        """
        return self.billing_tracker.track_model_usage(
            model_id=model_id,
            operation_type=operation_type,
            provider=provider,
            service_type=service_type,
            operation=operation,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_units=input_units,
            output_units=output_units,
            metadata=metadata
        )

    async def update_model_stage(
        self,
        model_id: str,
        new_stage: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update model lifecycle stage

        Args:
            model_id: Model identifier
            new_stage: New lifecycle stage (training, evaluation, deployment, production, retired)
            metadata: Additional metadata for this stage

        Returns:
            True if update successful
        """
        try:
            # Get current model info
            model_info = self.registry.get_model_info(model_id)
            if not model_info:
                logger.error(f"Model {model_id} not found in registry")
                return False

            # Update metadata with new stage. Copy first so the dict handed
            # back by the registry is left untouched, and tolerate a missing
            # or None "metadata" entry.
            current_metadata = dict(model_info.get("metadata") or {})
            current_metadata.update({
                "lifecycle_stage": new_stage,
                "stage_updated_at": str(datetime.now()),
                # Caller-supplied keys are applied last and therefore take
                # precedence over the two stage keys above.
                **(metadata or {})
            })

            # Update in registry (re-registering with the stored type and
            # capabilities converted back to their enum forms)
            success = self.registry.register_model(
                model_id=model_id,
                model_type=ModelType(model_info["type"]),
                capabilities=[ModelCapability(cap) for cap in model_info["capabilities"]],
                metadata=current_metadata
            )

            if success:
                logger.info(f"Updated model {model_id} to stage: {new_stage}")

            return success

        except Exception as e:
            logger.error(f"Failed to update model {model_id} stage: {e}")
            return False

    def get_model_lifecycle_summary(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get complete lifecycle summary for a model including costs

        Returns:
            Dictionary with model info, lifecycle stage, and billing summary,
            or None when the model is not in the registry or an error occurs.
            The stage is reported as "unknown" when none is recorded.
        """
        try:
            # Get model info from registry
            model_info = self.registry.get_model_info(model_id)
            if not model_info:
                return None

            # Get billing summary from tracker
            billing_summary = self.billing_tracker.get_model_usage_summary(model_id)

            # Tolerate a missing or None "metadata" entry
            stage_metadata = model_info.get("metadata") or {}

            return {
                "model_id": model_id,
                "model_info": model_info,
                "billing_summary": billing_summary,
                "current_stage": stage_metadata.get("lifecycle_stage", "unknown")
            }

        except Exception as e:
            logger.error(f"Failed to get lifecycle summary for {model_id}: {e}")
            return None

    def list_models_by_stage(self, stage: str) -> List[Dict[str, Any]]:
        """
        List all models in a specific lifecycle stage

        Args:
            stage: Lifecycle stage to filter by

        Returns:
            List of model dictionaries (empty on error)
        """
        try:
            all_models = self.registry.list_models()

            # A record whose "metadata" is missing or None simply has no
            # stage and is filtered out instead of aborting the listing.
            return [
                {"model_id": model_id, **model_info}
                for model_id, model_info in all_models.items()
                if (model_info.get("metadata") or {}).get("lifecycle_stage") == stage
            ]

        except Exception as e:
            logger.error(f"Failed to list models by stage {stage}: {e}")
            return []

    def get_billing_summary_by_operation(self, operation_type: ModelOperationType) -> Dict[str, Any]:
        """Get billing summary for a specific operation type"""
        return self.billing_tracker.get_operation_summary(operation_type)

    def print_model_costs(self, model_id: str):
        """Print cost summary for a specific model"""
        self.billing_tracker.print_model_summary(model_id)
|
@@ -1,71 +1,93 @@
|
|
1
1
|
"""
|
2
|
-
|
2
|
+
Unified Model Registry with Supabase Backend
|
3
3
|
|
4
|
-
|
5
|
-
|
4
|
+
Simplified architecture using only Supabase for model metadata and capabilities.
|
5
|
+
No SQLite support - uses unified configuration management.
|
6
6
|
"""
|
7
7
|
|
8
8
|
import os
|
9
9
|
import json
|
10
10
|
import logging
|
11
|
-
from typing import
|
11
|
+
from typing import Dict, List, Optional, Any
|
12
|
+
from enum import Enum
|
12
13
|
from datetime import datetime
|
13
|
-
from pathlib import Path
|
14
14
|
|
15
15
|
try:
|
16
16
|
from supabase import create_client, Client
|
17
|
-
from dotenv import load_dotenv
|
18
17
|
SUPABASE_AVAILABLE = True
|
19
18
|
except ImportError:
|
20
19
|
SUPABASE_AVAILABLE = False
|
21
20
|
|
22
|
-
from ..
|
21
|
+
from ..config import ConfigManager
|
23
22
|
|
24
23
|
logger = logging.getLogger(__name__)
|
25
24
|
|
26
|
-
class
|
27
|
-
"""
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
""
|
25
|
+
class ModelCapability(str, Enum):
    """Capabilities a registered model can advertise.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    # Text
    TEXT_GENERATION = "text_generation"
    CHAT = "chat"
    EMBEDDING = "embedding"
    RERANKING = "reranking"
    REASONING = "reasoning"

    # Image and audio
    IMAGE_GENERATION = "image_generation"
    IMAGE_ANALYSIS = "image_analysis"
    AUDIO_TRANSCRIPTION = "audio_transcription"
    IMAGE_UNDERSTANDING = "image_understanding"

    # Document / UI vision
    UI_DETECTION = "ui_detection"
    OCR = "ocr"
    TABLE_DETECTION = "table_detection"
    TABLE_STRUCTURE_RECOGNITION = "table_structure_recognition"
|
40
|
+
|
41
|
+
class ModelType(str, Enum):
    """Kinds of model the registry distinguishes.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    LLM = "llm"
    EMBEDDING = "embedding"
    RERANK = "rerank"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    VISION = "vision"
|
50
|
+
|
51
|
+
class ModelRegistry:
|
52
|
+
"""Unified Model Registry with Supabase backend"""
|
32
53
|
|
33
54
|
def __init__(self):
|
34
55
|
if not SUPABASE_AVAILABLE:
|
35
56
|
raise ImportError("supabase-py is required. Install with: pip install supabase")
|
36
57
|
|
37
|
-
#
|
38
|
-
|
58
|
+
# Get configuration from unified ConfigManager
|
59
|
+
self.config_manager = ConfigManager()
|
60
|
+
global_config = self.config_manager.get_global_config()
|
39
61
|
|
40
|
-
|
41
|
-
self.
|
62
|
+
# Get Supabase configuration from database config
|
63
|
+
self.supabase_url = global_config.database.supabase_url or os.getenv("SUPABASE_URL")
|
64
|
+
self.supabase_key = global_config.database.supabase_key or os.getenv("SUPABASE_ANON_KEY") or os.getenv("SERVICE_ROLE_KEY")
|
42
65
|
|
43
66
|
if not self.supabase_url or not self.supabase_key:
|
44
|
-
raise ValueError("SUPABASE_URL and SUPABASE_ANON_KEY must be
|
67
|
+
raise ValueError("SUPABASE_URL and SUPABASE_ANON_KEY (or SERVICE_ROLE_KEY) must be configured")
|
45
68
|
|
46
69
|
# Initialize Supabase client
|
47
70
|
self.supabase: Client = create_client(self.supabase_url, self.supabase_key)
|
48
71
|
|
49
|
-
#
|
72
|
+
# Verify connection
|
50
73
|
self._ensure_tables()
|
51
74
|
|
52
|
-
logger.info("
|
75
|
+
logger.info("Model registry initialized with Supabase backend")
|
53
76
|
|
54
77
|
def _ensure_tables(self):
|
55
78
|
"""Ensure required tables exist in Supabase"""
|
56
|
-
# Note: In production, these tables should be created via Supabase migrations
|
57
|
-
# This is just for development/initialization
|
58
79
|
try:
|
59
80
|
# Check if models table exists by trying to query it
|
60
81
|
result = self.supabase.table('models').select('model_id').limit(1).execute()
|
82
|
+
logger.debug("Models table verified")
|
61
83
|
except Exception as e:
|
62
84
|
logger.warning(f"Models table might not exist: {e}")
|
63
|
-
# In production,
|
85
|
+
# In production, tables should be created via Supabase migrations
|
64
86
|
|
65
87
|
def register_model(self,
|
66
88
|
model_id: str,
|
67
|
-
model_type:
|
68
|
-
capabilities: List[
|
89
|
+
model_type: ModelType,
|
90
|
+
capabilities: List[ModelCapability],
|
69
91
|
metadata: Dict[str, Any]) -> bool:
|
70
92
|
"""Register a model with its capabilities and metadata"""
|
71
93
|
try:
|
@@ -74,7 +96,7 @@ class SupabaseModelRegistry:
|
|
74
96
|
# Prepare model data
|
75
97
|
model_data = {
|
76
98
|
'model_id': model_id,
|
77
|
-
'model_type': model_type,
|
99
|
+
'model_type': model_type.value,
|
78
100
|
'metadata': json.dumps(metadata),
|
79
101
|
'created_at': current_time,
|
80
102
|
'updated_at': current_time
|
@@ -95,7 +117,7 @@ class SupabaseModelRegistry:
|
|
95
117
|
capability_data = [
|
96
118
|
{
|
97
119
|
'model_id': model_id,
|
98
|
-
'capability': capability,
|
120
|
+
'capability': capability.value,
|
99
121
|
'created_at': current_time
|
100
122
|
}
|
101
123
|
for capability in capabilities
|
@@ -107,7 +129,7 @@ class SupabaseModelRegistry:
|
|
107
129
|
logger.error(f"Failed to insert capabilities for {model_id}")
|
108
130
|
return False
|
109
131
|
|
110
|
-
logger.info(f"Successfully registered model {model_id}")
|
132
|
+
logger.info(f"Successfully registered model {model_id} with {len(capabilities)} capabilities")
|
111
133
|
return True
|
112
134
|
|
113
135
|
except Exception as e:
|
@@ -159,10 +181,10 @@ class SupabaseModelRegistry:
|
|
159
181
|
logger.error(f"Failed to get model info for {model_id}: {e}")
|
160
182
|
return None
|
161
183
|
|
162
|
-
def get_models_by_type(self, model_type:
|
184
|
+
def get_models_by_type(self, model_type: ModelType) -> Dict[str, Dict[str, Any]]:
|
163
185
|
"""Get all models of a specific type"""
|
164
186
|
try:
|
165
|
-
models_result = self.supabase.table('models').select('*').eq('model_type', model_type).execute()
|
187
|
+
models_result = self.supabase.table('models').select('*').eq('model_type', model_type.value).execute()
|
166
188
|
|
167
189
|
result = {}
|
168
190
|
for model in models_result.data:
|
@@ -186,49 +208,20 @@ class SupabaseModelRegistry:
|
|
186
208
|
logger.error(f"Failed to get models by type {model_type}: {e}")
|
187
209
|
return {}
|
188
210
|
|
189
|
-
def get_models_by_capability(self, capability:
|
211
|
+
def get_models_by_capability(self, capability: ModelCapability) -> Dict[str, Dict[str, Any]]:
|
190
212
|
"""Get all models with a specific capability"""
|
191
213
|
try:
|
192
|
-
#
|
193
|
-
|
194
|
-
SELECT DISTINCT m.*, mc.capability
|
195
|
-
FROM models m
|
196
|
-
INNER JOIN model_capabilities mc ON m.model_id = mc.model_id
|
197
|
-
WHERE mc.capability = %s
|
198
|
-
"""
|
199
|
-
|
200
|
-
# Use RPC for complex queries
|
201
|
-
result = self.supabase.rpc('get_models_by_capability', {'capability_name': capability}).execute()
|
202
|
-
|
203
|
-
if result.data:
|
204
|
-
models_dict = {}
|
205
|
-
for row in result.data:
|
206
|
-
model_id = row['model_id']
|
207
|
-
if model_id not in models_dict:
|
208
|
-
# Get all capabilities for this model
|
209
|
-
cap_result = self.supabase.table('model_capabilities').select('capability').eq('model_id', model_id).execute()
|
210
|
-
capabilities = [cap['capability'] for cap in cap_result.data]
|
211
|
-
|
212
|
-
models_dict[model_id] = {
|
213
|
-
"type": row["model_type"],
|
214
|
-
"capabilities": capabilities,
|
215
|
-
"metadata": json.loads(row["metadata"]) if row["metadata"] else {},
|
216
|
-
"created_at": row["created_at"],
|
217
|
-
"updated_at": row["updated_at"]
|
218
|
-
}
|
219
|
-
|
220
|
-
return models_dict
|
221
|
-
|
222
|
-
# Fallback: manual join if RPC not available
|
223
|
-
cap_result = self.supabase.table('model_capabilities').select('model_id').eq('capability', capability).execute()
|
214
|
+
# Get model IDs with specific capability
|
215
|
+
cap_result = self.supabase.table('model_capabilities').select('model_id').eq('capability', capability.value).execute()
|
224
216
|
model_ids = [row['model_id'] for row in cap_result.data]
|
225
217
|
|
226
218
|
if not model_ids:
|
227
219
|
return {}
|
228
220
|
|
221
|
+
# Get model details
|
229
222
|
models_result = self.supabase.table('models').select('*').in_('model_id', model_ids).execute()
|
230
223
|
|
231
|
-
|
224
|
+
result = {}
|
232
225
|
for model in models_result.data:
|
233
226
|
model_id = model["model_id"]
|
234
227
|
|
@@ -236,7 +229,7 @@ class SupabaseModelRegistry:
|
|
236
229
|
all_caps_result = self.supabase.table('model_capabilities').select('capability').eq('model_id', model_id).execute()
|
237
230
|
capabilities = [cap['capability'] for cap in all_caps_result.data]
|
238
231
|
|
239
|
-
|
232
|
+
result[model_id] = {
|
240
233
|
"type": model["model_type"],
|
241
234
|
"capabilities": capabilities,
|
242
235
|
"metadata": json.loads(model["metadata"]) if model["metadata"] else {},
|
@@ -244,16 +237,16 @@ class SupabaseModelRegistry:
|
|
244
237
|
"updated_at": model["updated_at"]
|
245
238
|
}
|
246
239
|
|
247
|
-
return
|
240
|
+
return result
|
248
241
|
|
249
242
|
except Exception as e:
|
250
243
|
logger.error(f"Failed to get models by capability {capability}: {e}")
|
251
244
|
return {}
|
252
245
|
|
253
|
-
def has_capability(self, model_id: str, capability:
|
246
|
+
def has_capability(self, model_id: str, capability: ModelCapability) -> bool:
|
254
247
|
"""Check if a model has a specific capability"""
|
255
248
|
try:
|
256
|
-
result = self.supabase.table('model_capabilities').select('model_id').eq('model_id', model_id).eq('capability', capability).execute()
|
249
|
+
result = self.supabase.table('model_capabilities').select('model_id').eq('model_id', model_id).eq('capability', capability.value).execute()
|
257
250
|
|
258
251
|
return len(result.data) > 0
|
259
252
|
|
@@ -295,13 +288,19 @@ class SupabaseModelRegistry:
|
|
295
288
|
total_result = self.supabase.table('models').select('model_id', count='exact').execute()
|
296
289
|
total_models = total_result.count if total_result.count is not None else 0
|
297
290
|
|
298
|
-
# Count by type
|
299
|
-
|
300
|
-
type_counts = {
|
291
|
+
# Count by type (manual aggregation since RPC might not exist)
|
292
|
+
models_result = self.supabase.table('models').select('model_type').execute()
|
293
|
+
type_counts = {}
|
294
|
+
for model in models_result.data:
|
295
|
+
model_type = model['model_type']
|
296
|
+
type_counts[model_type] = type_counts.get(model_type, 0) + 1
|
301
297
|
|
302
298
|
# Count by capability
|
303
|
-
|
304
|
-
capability_counts = {
|
299
|
+
caps_result = self.supabase.table('model_capabilities').select('capability').execute()
|
300
|
+
capability_counts = {}
|
301
|
+
for cap in caps_result.data:
|
302
|
+
capability = cap['capability']
|
303
|
+
capability_counts[capability] = capability_counts.get(capability, 0) + 1
|
305
304
|
|
306
305
|
return {
|
307
306
|
"total_models": total_models,
|