isa-model 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. isa_model/__init__.py +30 -1
  2. isa_model/client.py +937 -0
  3. isa_model/core/config/__init__.py +16 -0
  4. isa_model/core/config/config_manager.py +514 -0
  5. isa_model/core/config.py +426 -0
  6. isa_model/core/models/model_billing_tracker.py +476 -0
  7. isa_model/core/models/model_manager.py +399 -0
  8. isa_model/core/{storage/supabase_storage.py → models/model_repo.py} +72 -73
  9. isa_model/core/pricing_manager.py +426 -0
  10. isa_model/core/services/__init__.py +19 -0
  11. isa_model/core/services/intelligent_model_selector.py +547 -0
  12. isa_model/core/types.py +291 -0
  13. isa_model/deployment/__init__.py +2 -0
  14. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +157 -3
  15. isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
  16. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +104 -3
  17. isa_model/deployment/cloud/modal/register_models.py +321 -0
  18. isa_model/deployment/runtime/deployed_service.py +338 -0
  19. isa_model/deployment/services/__init__.py +9 -0
  20. isa_model/deployment/services/auto_deploy_vision_service.py +538 -0
  21. isa_model/deployment/services/model_service.py +332 -0
  22. isa_model/deployment/services/service_monitor.py +356 -0
  23. isa_model/deployment/services/service_registry.py +527 -0
  24. isa_model/deployment/services/simple_auto_deploy_vision_service.py +275 -0
  25. isa_model/eval/__init__.py +80 -44
  26. isa_model/eval/config/__init__.py +10 -0
  27. isa_model/eval/config/evaluation_config.py +108 -0
  28. isa_model/eval/evaluators/__init__.py +18 -0
  29. isa_model/eval/evaluators/base_evaluator.py +503 -0
  30. isa_model/eval/evaluators/llm_evaluator.py +472 -0
  31. isa_model/eval/factory.py +417 -709
  32. isa_model/eval/infrastructure/__init__.py +24 -0
  33. isa_model/eval/infrastructure/experiment_tracker.py +466 -0
  34. isa_model/eval/metrics.py +191 -21
  35. isa_model/inference/ai_factory.py +257 -601
  36. isa_model/inference/services/audio/base_stt_service.py +65 -1
  37. isa_model/inference/services/audio/base_tts_service.py +75 -1
  38. isa_model/inference/services/audio/openai_stt_service.py +189 -151
  39. isa_model/inference/services/audio/openai_tts_service.py +12 -10
  40. isa_model/inference/services/audio/replicate_tts_service.py +61 -56
  41. isa_model/inference/services/base_service.py +55 -17
  42. isa_model/inference/services/embedding/base_embed_service.py +65 -1
  43. isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
  44. isa_model/inference/services/embedding/openai_embed_service.py +8 -10
  45. isa_model/inference/services/helpers/stacked_config.py +148 -0
  46. isa_model/inference/services/img/__init__.py +18 -0
  47. isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -1
  48. isa_model/inference/services/{stacked → img}/flux_professional_service.py +25 -1
  49. isa_model/inference/services/{stacked → img/helpers}/base_stacked_service.py +40 -35
  50. isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +44 -31
  51. isa_model/inference/services/llm/__init__.py +3 -3
  52. isa_model/inference/services/llm/base_llm_service.py +492 -40
  53. isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
  54. isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
  55. isa_model/inference/services/llm/ollama_llm_service.py +51 -17
  56. isa_model/inference/services/llm/openai_llm_service.py +70 -19
  57. isa_model/inference/services/llm/yyds_llm_service.py +24 -23
  58. isa_model/inference/services/vision/__init__.py +38 -4
  59. isa_model/inference/services/vision/base_vision_service.py +218 -117
  60. isa_model/inference/services/vision/{isA_vision_service.py → disabled/isA_vision_service.py} +98 -0
  61. isa_model/inference/services/{stacked → vision}/doc_analysis_service.py +1 -1
  62. isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
  63. isa_model/inference/services/vision/helpers/image_utils.py +272 -3
  64. isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
  65. isa_model/inference/services/vision/openai_vision_service.py +104 -307
  66. isa_model/inference/services/vision/replicate_vision_service.py +140 -325
  67. isa_model/inference/services/{stacked → vision}/ui_analysis_service.py +2 -498
  68. isa_model/scripts/register_models.py +370 -0
  69. isa_model/scripts/register_models_with_embeddings.py +510 -0
  70. isa_model/serving/api/fastapi_server.py +6 -1
  71. isa_model/serving/api/routes/unified.py +274 -0
  72. {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/METADATA +4 -1
  73. {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/RECORD +78 -53
  74. isa_model/config/__init__.py +0 -9
  75. isa_model/config/config_manager.py +0 -213
  76. isa_model/core/model_manager.py +0 -213
  77. isa_model/core/model_registry.py +0 -375
  78. isa_model/core/vision_models_init.py +0 -116
  79. isa_model/inference/billing_tracker.py +0 -406
  80. isa_model/inference/services/llm/triton_llm_service.py +0 -481
  81. isa_model/inference/services/stacked/__init__.py +0 -26
  82. isa_model/inference/services/stacked/config.py +0 -426
  83. isa_model/inference/services/vision/ollama_vision_service.py +0 -194
  84. /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
  85. /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
  86. /isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +0 -0
  87. {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/WHEEL +0 -0
  88. {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,399 @@
1
+ from typing import Dict, Optional, List, Any
2
+ import logging
3
+ from pathlib import Path
4
+ from datetime import datetime
5
+ from huggingface_hub import hf_hub_download, snapshot_download
6
+ from huggingface_hub.errors import HfHubHTTPError
7
+ from .model_storage import ModelStorage, LocalModelStorage
8
+ from .model_repo import ModelRegistry, ModelType, ModelCapability
9
+ from .model_billing_tracker import ModelBillingTracker, ModelOperationType
10
+ from ..pricing_manager import PricingManager
11
+ from ..config import ConfigManager
12
+
13
logger = logging.getLogger(__name__)


class ModelManager:
    """
    Model lifecycle management service.

    Handles the complete model lifecycle:
    - Model registration and metadata management
    - Model downloads, versions, and caching
    - Cost tracking and billing across all operations
    - Integration with model training, evaluation, and deployment
    """

    def __init__(self,
                 storage: Optional[ModelStorage] = None,
                 registry: Optional[ModelRegistry] = None,
                 billing_tracker: Optional[ModelBillingTracker] = None,
                 pricing_manager: Optional[PricingManager] = None,
                 config_manager: Optional[ConfigManager] = None):
        """
        Wire up the collaborators, creating defaults for any not supplied.

        Args:
            storage: Backend for model files; defaults to ``LocalModelStorage()``.
            registry: Model metadata registry; defaults to ``ModelRegistry()``.
            billing_tracker: Usage/cost tracker; defaults to a
                ``ModelBillingTracker`` constructed with this manager's registry.
            pricing_manager: Pricing lookups; defaults to ``PricingManager()``.
            config_manager: Configuration access; defaults to ``ConfigManager()``.
        """
        self.storage = storage or LocalModelStorage()
        self.registry = registry or ModelRegistry()
        # Constructed with the registry chosen above, so the tracker and the manager share it.
        self.billing_tracker = billing_tracker or ModelBillingTracker(model_registry=self.registry)
        self.pricing_manager = pricing_manager or PricingManager()
        self.config_manager = config_manager or ConfigManager()

    def get_model_pricing(self, provider: str, model_name: str) -> Dict[str, float]:
        """
        Get pricing information for a model.

        Returns:
            ``{"input": input_cost, "output": output_cost}`` from the pricing
            manager, or ``{"input": 0.0, "output": 0.0}`` when no pricing is known.
        """
        pricing = self.pricing_manager.get_model_pricing(provider, model_name)
        if pricing:
            return {"input": pricing.input_cost, "output": pricing.output_cost}
        return {"input": 0.0, "output": 0.0}

    def calculate_cost(self, provider: str, model_name: str, input_tokens: int, output_tokens: int) -> float:
        """
        Calculate the cost of a request.

        Token counts are forwarded to the pricing manager as input/output units.
        """
        return self.pricing_manager.calculate_cost(
            provider=provider,
            model_name=model_name,
            input_units=input_tokens,
            output_units=output_tokens
        )

    def get_cheapest_model(self, provider: str, model_type: str = "llm") -> Optional[str]:
        """
        Get the name of the cheapest token-priced model for a provider.

        Args:
            provider: Provider whose models are compared.
            model_type: Accepted for API compatibility; currently not used in the lookup.

        Returns:
            The ``model_name`` of the cheapest model, or None if none was found.
        """
        result = self.pricing_manager.get_cheapest_model(
            provider=provider,
            unit_type="token",
            min_input_units=1000  # Assume 1K tokens for comparison
        )
        return result["model_name"] if result else None

    async def get_model(self,
                        model_id: str,
                        repo_id: str,
                        model_type: ModelType,
                        capabilities: List[ModelCapability],
                        revision: Optional[str] = None,
                        force_download: bool = False) -> Optional[Path]:
        """
        Get model files, downloading them from the Hugging Face Hub if necessary.

        Args:
            model_id: Unique identifier for the model
            repo_id: Hugging Face repository ID
            model_type: Type of model (LLM, embedding, etc.)
            capabilities: List of model capabilities
            revision: Specific model version/tag
            force_download: Force re-download even if cached

        Returns:
            The path reported by ``storage.load_model`` for the model, or None if
            downloading or saving failed.
        """
        # Serve from storage unless a fresh download was explicitly requested.
        if not force_download:
            model_path = await self.storage.load_model(model_id)
            if model_path:
                logger.info(f"Using cached model {model_id}")
                return model_path

        try:
            logger.info(f"Downloading model {model_id} from {repo_id}")
            # Staging directory, relative to the current working directory.
            model_dir = Path(f"./models/temp/{model_id}")
            model_dir.mkdir(parents=True, exist_ok=True)

            # NOTE(review): `local_dir_use_symlinks` is deprecated in recent
            # huggingface_hub releases; confirm the pinned version before removing it.
            snapshot_download(
                repo_id=repo_id,
                revision=revision,
                local_dir=model_dir,
                local_dir_use_symlinks=False
            )

            metadata = {
                "repo_id": repo_id,
                "revision": revision,
                # The staging directory's st_mtime as a string; not necessarily
                # the moment this particular download finished.
                "downloaded_at": str(model_dir.stat().st_mtime)
            }

            registered = self.registry.register_model(
                model_id=model_id,
                model_type=model_type,
                capabilities=capabilities,
                metadata=metadata
            )
            if not registered:
                # Previously ignored silently; files are still saved so the download is not lost.
                logger.warning(f"Registry did not accept model {model_id}; saving files anyway")

            await self.storage.save_model(model_id, str(model_dir), metadata)

            return await self.storage.load_model(model_id)

        except HfHubHTTPError as e:
            logger.error(f"Failed to download model {model_id}: {e}")
            return None
        except Exception as e:
            # logger.exception keeps the traceback for unexpected failures.
            logger.exception(f"Unexpected error downloading model {model_id}: {e}")
            return None

    async def list_models(self) -> List[Dict[str, Any]]:
        """
        List all stored models with their metadata.

        Storage metadata is merged with registry info; on key collisions the
        registry values win.
        """
        models = await self.storage.list_models()
        return [
            {
                "model_id": model_id,
                **metadata,
                **(self.registry.get_model_info(model_id) or {})
            }
            for model_id, metadata in models.items()
        ]

    async def remove_model(self, model_id: str) -> bool:
        """
        Remove a model from storage and unregister it.

        Returns:
            True only if both the storage delete and the registry unregister
            report success; False on any failure or exception.
        """
        try:
            storage_success = await self.storage.delete_model(model_id)
            registry_success = self.registry.unregister_model(model_id)
            return storage_success and registry_success

        except Exception as e:
            logger.error(f"Failed to remove model {model_id}: {e}")
            return False

    async def get_model_info(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get merged storage and registry information for a model.

        Returns:
            A dict combining storage metadata and registry info (registry wins on
            key collisions), or None if neither source knows the model.
        """
        storage_info = await self.storage.get_metadata(model_id)
        registry_info = self.registry.get_model_info(model_id)

        if not storage_info and not registry_info:
            return None

        return {
            **(storage_info or {}),
            **(registry_info or {})
        }

    async def update_model(self,
                           model_id: str,
                           repo_id: str,
                           model_type: ModelType,
                           capabilities: List[ModelCapability],
                           revision: Optional[str] = None) -> bool:
        """
        Update a model to a new version by forcing a re-download.

        Returns:
            True if ``get_model`` produced a (truthy) path, otherwise False.
        """
        try:
            return bool(await self.get_model(
                model_id=model_id,
                repo_id=repo_id,
                model_type=model_type,
                capabilities=capabilities,
                revision=revision,
                force_download=True
            ))
        except Exception as e:
            logger.error(f"Failed to update model {model_id}: {e}")
            return False

    # === MODEL LIFECYCLE MANAGEMENT ===

    async def register_model_for_lifecycle(
        self,
        model_id: str,
        model_type: ModelType,
        capabilities: List[ModelCapability],
        provider: str = "custom",
        provider_model_name: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Register a model for lifecycle management.

        Args:
            model_id: Unique identifier for the model
            model_type: Type of model (LLM, embedding, etc.)
            capabilities: List of model capabilities
            provider: Provider name for billing
            provider_model_name: Provider-specific model name for pricing
                (defaults to ``model_id``)
            metadata: Additional metadata; not modified by this call

        Returns:
            True if the registry accepted the registration. A failure while
            recording the billing event is logged but does not change the result.
        """
        try:
            # Build a new dict so the caller's `metadata` is never mutated;
            # the lifecycle fields override any same-named caller keys.
            full_metadata = {
                **(metadata or {}),
                "provider": provider,
                "provider_model_name": provider_model_name or model_id,
                "registered_for_lifecycle": True,
                "lifecycle_stage": "registered",
            }

            success = self.registry.register_model(
                model_id=model_id,
                model_type=model_type,
                capabilities=capabilities,
                metadata=full_metadata
            )

            if success:
                # Billing is best-effort: the model is already registered, so a
                # tracking failure must not be reported as a registration failure.
                try:
                    self.billing_tracker.track_model_usage(
                        model_id=model_id,
                        operation_type=ModelOperationType.STORAGE,
                        provider=provider,
                        service_type="model_management",
                        operation="register_model",
                        metadata={"stage": "registration"}
                    )
                except Exception as e:
                    logger.warning(f"Failed to track registration of model {model_id}: {e}")

                logger.info(f"Successfully registered model {model_id} for lifecycle management")

            return success

        except Exception as e:
            logger.error(f"Failed to register model {model_id} for lifecycle: {e}")
            return False

    def track_model_usage(
        self,
        model_id: str,
        operation_type: ModelOperationType,
        provider: str,
        service_type: str,
        operation: str,
        input_tokens: Optional[int] = None,
        output_tokens: Optional[int] = None,
        input_units: Optional[float] = None,
        output_units: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Track model usage and costs by delegating to the billing tracker.

        This method should be called by:
        - Training services when training a model
        - Evaluation services when evaluating a model
        - Deployment services when deploying a model
        - Inference services when using a model for inference

        Returns:
            Whatever ``billing_tracker.track_model_usage`` returns.
        """
        return self.billing_tracker.track_model_usage(
            model_id=model_id,
            operation_type=operation_type,
            provider=provider,
            service_type=service_type,
            operation=operation,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            input_units=input_units,
            output_units=output_units,
            metadata=metadata
        )

    async def update_model_stage(
        self,
        model_id: str,
        new_stage: str,
        metadata: Optional[Dict[str, Any]] = None
    ) -> bool:
        """
        Update a model's lifecycle stage by re-registering it with new metadata.

        Args:
            model_id: Model identifier
            new_stage: New lifecycle stage (training, evaluation, deployment, production, retired)
            metadata: Additional metadata for this stage; it cannot override
                ``lifecycle_stage`` or ``stage_updated_at``

        Returns:
            True if update successful
        """
        try:
            model_info = self.registry.get_model_info(model_id)
            if not model_info:
                logger.error(f"Model {model_id} not found in registry")
                return False

            # Build a fresh dict (no in-place mutation of the registry's data).
            # Stage fields go last so `new_stage` always wins over caller keys,
            # keeping the stored stage consistent with the log message below.
            stage_metadata = {
                **(model_info.get("metadata") or {}),
                **(metadata or {}),
                "lifecycle_stage": new_stage,
                "stage_updated_at": str(datetime.now()),
            }

            # Re-register with the existing type/capabilities and the updated metadata.
            success = self.registry.register_model(
                model_id=model_id,
                model_type=ModelType(model_info["type"]),
                capabilities=[ModelCapability(cap) for cap in model_info["capabilities"]],
                metadata=stage_metadata
            )

            if success:
                logger.info(f"Updated model {model_id} to stage: {new_stage}")

            return success

        except Exception as e:
            logger.error(f"Failed to update model {model_id} stage: {e}")
            return False

    def get_model_lifecycle_summary(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get the lifecycle summary for a model, including costs.

        Returns:
            Dictionary with model info, billing summary and current stage
            (``"unknown"`` if none is recorded), or None if the model is not
            registered or a lookup fails.
        """
        try:
            model_info = self.registry.get_model_info(model_id)
            if not model_info:
                return None

            billing_summary = self.billing_tracker.get_model_usage_summary(model_id)

            return {
                "model_id": model_id,
                "model_info": model_info,
                "billing_summary": billing_summary,
                # `or {}` also tolerates a stored metadata value of None.
                "current_stage": (model_info.get("metadata") or {}).get("lifecycle_stage", "unknown")
            }

        except Exception as e:
            logger.error(f"Failed to get lifecycle summary for {model_id}: {e}")
            return None

    def list_models_by_stage(self, stage: str) -> List[Dict[str, Any]]:
        """
        List all registered models in a specific lifecycle stage.

        Args:
            stage: Lifecycle stage to filter by

        Returns:
            List of ``{"model_id": ..., **model_info}`` dicts; empty on error.
        """
        try:
            all_models = self.registry.list_models()
            return [
                {"model_id": model_id, **model_info}
                for model_id, model_info in all_models.items()
                # `or {}` so a model with metadata=None is skipped rather than aborting the listing.
                if (model_info.get("metadata") or {}).get("lifecycle_stage") == stage
            ]

        except Exception as e:
            logger.error(f"Failed to list models by stage {stage}: {e}")
            return []

    def get_billing_summary_by_operation(self, operation_type: ModelOperationType) -> Dict[str, Any]:
        """Get the billing summary for a specific operation type."""
        return self.billing_tracker.get_operation_summary(operation_type)

    def print_model_costs(self, model_id: str):
        """Print the cost summary for a specific model via the billing tracker."""
        self.billing_tracker.print_model_summary(model_id)
@@ -1,71 +1,93 @@
1
1
  """
2
- Supabase Storage Implementation for Model Registry
2
+ Unified Model Registry with Supabase Backend
3
3
 
4
- Uses Supabase as the backend database for model metadata and capabilities
5
- Supports the full model lifecycle with cloud-based storage
4
+ Simplified architecture using only Supabase for model metadata and capabilities.
5
+ No SQLite support - uses unified configuration management.
6
6
  """
7
7
 
8
8
  import os
9
9
  import json
10
10
  import logging
11
- from typing import Optional, Dict, Any, List
11
+ from typing import Dict, List, Optional, Any
12
+ from enum import Enum
12
13
  from datetime import datetime
13
- from pathlib import Path
14
14
 
15
15
  try:
16
16
  from supabase import create_client, Client
17
- from dotenv import load_dotenv
18
17
  SUPABASE_AVAILABLE = True
19
18
  except ImportError:
20
19
  SUPABASE_AVAILABLE = False
21
20
 
22
- from ..model_storage import ModelStorage
21
+ from ..config import ConfigManager
23
22
 
24
23
  logger = logging.getLogger(__name__)
25
24
 
26
- class SupabaseModelRegistry:
27
- """
28
- Supabase-based model registry for metadata and capabilities
29
-
30
- Replaces SQLite with cloud-based Supabase database
31
- """
25
class ModelCapability(str, Enum):
    """
    Capabilities a model can be registered with.

    Subclasses ``str``, so members compare equal to their raw string values.
    ``ModelRegistry`` persists ``capability.value`` in the ``capability``
    column of the ``model_capabilities`` table.
    """
    TEXT_GENERATION = "text_generation"
    CHAT = "chat"
    EMBEDDING = "embedding"
    RERANKING = "reranking"
    REASONING = "reasoning"
    IMAGE_GENERATION = "image_generation"
    IMAGE_ANALYSIS = "image_analysis"
    AUDIO_TRANSCRIPTION = "audio_transcription"
    IMAGE_UNDERSTANDING = "image_understanding"
    UI_DETECTION = "ui_detection"
    OCR = "ocr"
    TABLE_DETECTION = "table_detection"
    TABLE_STRUCTURE_RECOGNITION = "table_structure_recognition"
40
+
41
class ModelType(str, Enum):
    """
    Model categories accepted by ``ModelRegistry.register_model``.

    Subclasses ``str``, so members compare equal to their raw string values.
    ``ModelRegistry`` persists ``model_type.value`` in the ``model_type``
    column of the ``models`` table.
    """
    LLM = "llm"
    EMBEDDING = "embedding"
    RERANK = "rerank"
    IMAGE = "image"
    AUDIO = "audio"
    VIDEO = "video"
    VISION = "vision"
50
+
51
+ class ModelRegistry:
52
+ """Unified Model Registry with Supabase backend"""
32
53
 
33
54
  def __init__(self):
34
55
  if not SUPABASE_AVAILABLE:
35
56
  raise ImportError("supabase-py is required. Install with: pip install supabase")
36
57
 
37
- # Load environment variables
38
- load_dotenv()
58
+ # Get configuration from unified ConfigManager
59
+ self.config_manager = ConfigManager()
60
+ global_config = self.config_manager.get_global_config()
39
61
 
40
- self.supabase_url = os.getenv("SUPABASE_URL")
41
- self.supabase_key = os.getenv("SUPABASE_ANON_KEY")
62
+ # Get Supabase configuration from database config
63
+ self.supabase_url = global_config.database.supabase_url or os.getenv("SUPABASE_URL")
64
+ self.supabase_key = global_config.database.supabase_key or os.getenv("SUPABASE_ANON_KEY") or os.getenv("SERVICE_ROLE_KEY")
42
65
 
43
66
  if not self.supabase_url or not self.supabase_key:
44
- raise ValueError("SUPABASE_URL and SUPABASE_ANON_KEY must be set in environment")
67
+ raise ValueError("SUPABASE_URL and SUPABASE_ANON_KEY (or SERVICE_ROLE_KEY) must be configured")
45
68
 
46
69
  # Initialize Supabase client
47
70
  self.supabase: Client = create_client(self.supabase_url, self.supabase_key)
48
71
 
49
- # Initialize tables if needed
72
+ # Verify connection
50
73
  self._ensure_tables()
51
74
 
52
- logger.info("Supabase model registry initialized")
75
+ logger.info("Model registry initialized with Supabase backend")
53
76
 
54
77
  def _ensure_tables(self):
55
78
  """Ensure required tables exist in Supabase"""
56
- # Note: In production, these tables should be created via Supabase migrations
57
- # This is just for development/initialization
58
79
  try:
59
80
  # Check if models table exists by trying to query it
60
81
  result = self.supabase.table('models').select('model_id').limit(1).execute()
82
+ logger.debug("Models table verified")
61
83
  except Exception as e:
62
84
  logger.warning(f"Models table might not exist: {e}")
63
- # In production, you would run proper migrations here
85
+ # In production, tables should be created via Supabase migrations
64
86
 
65
87
  def register_model(self,
66
88
  model_id: str,
67
- model_type: str,
68
- capabilities: List[str],
89
+ model_type: ModelType,
90
+ capabilities: List[ModelCapability],
69
91
  metadata: Dict[str, Any]) -> bool:
70
92
  """Register a model with its capabilities and metadata"""
71
93
  try:
@@ -74,7 +96,7 @@ class SupabaseModelRegistry:
74
96
  # Prepare model data
75
97
  model_data = {
76
98
  'model_id': model_id,
77
- 'model_type': model_type,
99
+ 'model_type': model_type.value,
78
100
  'metadata': json.dumps(metadata),
79
101
  'created_at': current_time,
80
102
  'updated_at': current_time
@@ -95,7 +117,7 @@ class SupabaseModelRegistry:
95
117
  capability_data = [
96
118
  {
97
119
  'model_id': model_id,
98
- 'capability': capability,
120
+ 'capability': capability.value,
99
121
  'created_at': current_time
100
122
  }
101
123
  for capability in capabilities
@@ -107,7 +129,7 @@ class SupabaseModelRegistry:
107
129
  logger.error(f"Failed to insert capabilities for {model_id}")
108
130
  return False
109
131
 
110
- logger.info(f"Successfully registered model {model_id}")
132
+ logger.info(f"Successfully registered model {model_id} with {len(capabilities)} capabilities")
111
133
  return True
112
134
 
113
135
  except Exception as e:
@@ -159,10 +181,10 @@ class SupabaseModelRegistry:
159
181
  logger.error(f"Failed to get model info for {model_id}: {e}")
160
182
  return None
161
183
 
162
- def get_models_by_type(self, model_type: str) -> Dict[str, Dict[str, Any]]:
184
+ def get_models_by_type(self, model_type: ModelType) -> Dict[str, Dict[str, Any]]:
163
185
  """Get all models of a specific type"""
164
186
  try:
165
- models_result = self.supabase.table('models').select('*').eq('model_type', model_type).execute()
187
+ models_result = self.supabase.table('models').select('*').eq('model_type', model_type.value).execute()
166
188
 
167
189
  result = {}
168
190
  for model in models_result.data:
@@ -186,49 +208,20 @@ class SupabaseModelRegistry:
186
208
  logger.error(f"Failed to get models by type {model_type}: {e}")
187
209
  return {}
188
210
 
189
- def get_models_by_capability(self, capability: str) -> Dict[str, Dict[str, Any]]:
211
+ def get_models_by_capability(self, capability: ModelCapability) -> Dict[str, Dict[str, Any]]:
190
212
  """Get all models with a specific capability"""
191
213
  try:
192
- # Join query to get models with specific capability
193
- query = """
194
- SELECT DISTINCT m.*, mc.capability
195
- FROM models m
196
- INNER JOIN model_capabilities mc ON m.model_id = mc.model_id
197
- WHERE mc.capability = %s
198
- """
199
-
200
- # Use RPC for complex queries
201
- result = self.supabase.rpc('get_models_by_capability', {'capability_name': capability}).execute()
202
-
203
- if result.data:
204
- models_dict = {}
205
- for row in result.data:
206
- model_id = row['model_id']
207
- if model_id not in models_dict:
208
- # Get all capabilities for this model
209
- cap_result = self.supabase.table('model_capabilities').select('capability').eq('model_id', model_id).execute()
210
- capabilities = [cap['capability'] for cap in cap_result.data]
211
-
212
- models_dict[model_id] = {
213
- "type": row["model_type"],
214
- "capabilities": capabilities,
215
- "metadata": json.loads(row["metadata"]) if row["metadata"] else {},
216
- "created_at": row["created_at"],
217
- "updated_at": row["updated_at"]
218
- }
219
-
220
- return models_dict
221
-
222
- # Fallback: manual join if RPC not available
223
- cap_result = self.supabase.table('model_capabilities').select('model_id').eq('capability', capability).execute()
214
+ # Get model IDs with specific capability
215
+ cap_result = self.supabase.table('model_capabilities').select('model_id').eq('capability', capability.value).execute()
224
216
  model_ids = [row['model_id'] for row in cap_result.data]
225
217
 
226
218
  if not model_ids:
227
219
  return {}
228
220
 
221
+ # Get model details
229
222
  models_result = self.supabase.table('models').select('*').in_('model_id', model_ids).execute()
230
223
 
231
- result_dict = {}
224
+ result = {}
232
225
  for model in models_result.data:
233
226
  model_id = model["model_id"]
234
227
 
@@ -236,7 +229,7 @@ class SupabaseModelRegistry:
236
229
  all_caps_result = self.supabase.table('model_capabilities').select('capability').eq('model_id', model_id).execute()
237
230
  capabilities = [cap['capability'] for cap in all_caps_result.data]
238
231
 
239
- result_dict[model_id] = {
232
+ result[model_id] = {
240
233
  "type": model["model_type"],
241
234
  "capabilities": capabilities,
242
235
  "metadata": json.loads(model["metadata"]) if model["metadata"] else {},
@@ -244,16 +237,16 @@ class SupabaseModelRegistry:
244
237
  "updated_at": model["updated_at"]
245
238
  }
246
239
 
247
- return result_dict
240
+ return result
248
241
 
249
242
  except Exception as e:
250
243
  logger.error(f"Failed to get models by capability {capability}: {e}")
251
244
  return {}
252
245
 
253
- def has_capability(self, model_id: str, capability: str) -> bool:
246
+ def has_capability(self, model_id: str, capability: ModelCapability) -> bool:
254
247
  """Check if a model has a specific capability"""
255
248
  try:
256
- result = self.supabase.table('model_capabilities').select('model_id').eq('model_id', model_id).eq('capability', capability).execute()
249
+ result = self.supabase.table('model_capabilities').select('model_id').eq('model_id', model_id).eq('capability', capability.value).execute()
257
250
 
258
251
  return len(result.data) > 0
259
252
 
@@ -295,13 +288,19 @@ class SupabaseModelRegistry:
295
288
  total_result = self.supabase.table('models').select('model_id', count='exact').execute()
296
289
  total_models = total_result.count if total_result.count is not None else 0
297
290
 
298
- # Count by type
299
- type_result = self.supabase.rpc('get_model_type_counts').execute()
300
- type_counts = {row['model_type']: row['count'] for row in type_result.data} if type_result.data else {}
291
+ # Count by type (manual aggregation since RPC might not exist)
292
+ models_result = self.supabase.table('models').select('model_type').execute()
293
+ type_counts = {}
294
+ for model in models_result.data:
295
+ model_type = model['model_type']
296
+ type_counts[model_type] = type_counts.get(model_type, 0) + 1
301
297
 
302
298
  # Count by capability
303
- cap_result = self.supabase.rpc('get_capability_counts').execute()
304
- capability_counts = {row['capability']: row['count'] for row in cap_result.data} if cap_result.data else {}
299
+ caps_result = self.supabase.table('model_capabilities').select('capability').execute()
300
+ capability_counts = {}
301
+ for cap in caps_result.data:
302
+ capability = cap['capability']
303
+ capability_counts[capability] = capability_counts.get(capability, 0) + 1
305
304
 
306
305
  return {
307
306
  "total_models": total_models,