isa-model 0.3.9__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- isa_model/__init__.py +1 -1
- isa_model/client.py +732 -565
- isa_model/core/cache/redis_cache.py +401 -0
- isa_model/core/config/config_manager.py +53 -10
- isa_model/core/config.py +1 -1
- isa_model/core/database/__init__.py +1 -0
- isa_model/core/database/migrations.py +277 -0
- isa_model/core/database/supabase_client.py +123 -0
- isa_model/core/models/__init__.py +37 -0
- isa_model/core/models/model_billing_tracker.py +60 -88
- isa_model/core/models/model_manager.py +36 -18
- isa_model/core/models/model_repo.py +44 -38
- isa_model/core/models/model_statistics_tracker.py +234 -0
- isa_model/core/models/model_storage.py +0 -1
- isa_model/core/models/model_version_manager.py +959 -0
- isa_model/core/pricing_manager.py +2 -249
- isa_model/core/resilience/circuit_breaker.py +366 -0
- isa_model/core/security/secrets.py +358 -0
- isa_model/core/services/__init__.py +2 -4
- isa_model/core/services/intelligent_model_selector.py +101 -370
- isa_model/core/storage/hf_storage.py +1 -1
- isa_model/core/types.py +7 -0
- isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
- isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
- isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
- isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
- isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
- isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
- isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
- isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
- isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
- isa_model/deployment/core/deployment_manager.py +6 -4
- isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
- isa_model/eval/benchmarks/__init__.py +27 -0
- isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
- isa_model/eval/benchmarks.py +244 -12
- isa_model/eval/evaluators/__init__.py +8 -2
- isa_model/eval/evaluators/audio_evaluator.py +727 -0
- isa_model/eval/evaluators/embedding_evaluator.py +742 -0
- isa_model/eval/evaluators/vision_evaluator.py +564 -0
- isa_model/eval/example_evaluation.py +395 -0
- isa_model/eval/factory.py +272 -5
- isa_model/eval/isa_benchmarks.py +700 -0
- isa_model/eval/isa_integration.py +582 -0
- isa_model/eval/metrics.py +159 -6
- isa_model/eval/tests/unit/test_basic.py +396 -0
- isa_model/inference/ai_factory.py +44 -8
- isa_model/inference/services/audio/__init__.py +21 -0
- isa_model/inference/services/audio/base_realtime_service.py +225 -0
- isa_model/inference/services/audio/isa_tts_service.py +0 -0
- isa_model/inference/services/audio/openai_realtime_service.py +320 -124
- isa_model/inference/services/audio/openai_stt_service.py +32 -6
- isa_model/inference/services/base_service.py +17 -1
- isa_model/inference/services/embedding/__init__.py +13 -0
- isa_model/inference/services/embedding/base_embed_service.py +111 -8
- isa_model/inference/services/embedding/isa_embed_service.py +305 -0
- isa_model/inference/services/embedding/openai_embed_service.py +2 -4
- isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
- isa_model/inference/services/img/__init__.py +2 -2
- isa_model/inference/services/img/base_image_gen_service.py +24 -7
- isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
- isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
- isa_model/inference/services/img/services/replicate_flux.py +226 -0
- isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
- isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
- isa_model/inference/services/img/tests/test_img_client.py +297 -0
- isa_model/inference/services/llm/base_llm_service.py +30 -6
- isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
- isa_model/inference/services/llm/ollama_llm_service.py +2 -1
- isa_model/inference/services/llm/openai_llm_service.py +652 -55
- isa_model/inference/services/llm/yyds_llm_service.py +2 -1
- isa_model/inference/services/vision/__init__.py +5 -5
- isa_model/inference/services/vision/base_vision_service.py +118 -185
- isa_model/inference/services/vision/helpers/image_utils.py +11 -5
- isa_model/inference/services/vision/isa_vision_service.py +573 -0
- isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
- isa_model/serving/api/fastapi_server.py +88 -16
- isa_model/serving/api/middleware/auth.py +311 -0
- isa_model/serving/api/middleware/security.py +278 -0
- isa_model/serving/api/routes/analytics.py +486 -0
- isa_model/serving/api/routes/deployments.py +339 -0
- isa_model/serving/api/routes/evaluations.py +579 -0
- isa_model/serving/api/routes/logs.py +430 -0
- isa_model/serving/api/routes/settings.py +582 -0
- isa_model/serving/api/routes/unified.py +324 -165
- isa_model/serving/api/startup.py +304 -0
- isa_model/serving/modal_proxy_server.py +249 -0
- isa_model/training/__init__.py +100 -6
- isa_model/training/core/__init__.py +4 -1
- isa_model/training/examples/intelligent_training_example.py +281 -0
- isa_model/training/intelligent/__init__.py +25 -0
- isa_model/training/intelligent/decision_engine.py +643 -0
- isa_model/training/intelligent/intelligent_factory.py +888 -0
- isa_model/training/intelligent/knowledge_base.py +751 -0
- isa_model/training/intelligent/resource_optimizer.py +839 -0
- isa_model/training/intelligent/task_classifier.py +576 -0
- isa_model/training/storage/__init__.py +24 -0
- isa_model/training/storage/core_integration.py +439 -0
- isa_model/training/storage/training_repository.py +552 -0
- isa_model/training/storage/training_storage.py +628 -0
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
- isa_model-0.4.0.dist-info/RECORD +182 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
- isa_model/deployment/cloud/modal/register_models.py +0 -321
- isa_model/inference/adapter/unified_api.py +0 -248
- isa_model/inference/services/helpers/stacked_config.py +0 -148
- isa_model/inference/services/img/flux_professional_service.py +0 -603
- isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/others/table_transformer_service.py +0 -61
- isa_model/inference/services/vision/doc_analysis_service.py +0 -640
- isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
- isa_model/inference/services/vision/ui_analysis_service.py +0 -823
- isa_model/scripts/inference_tracker.py +0 -283
- isa_model/scripts/mlflow_manager.py +0 -379
- isa_model/scripts/model_registry.py +0 -465
- isa_model/scripts/register_models.py +0 -370
- isa_model/scripts/register_models_with_embeddings.py +0 -510
- isa_model/scripts/start_mlflow.py +0 -95
- isa_model/scripts/training_tracker.py +0 -257
- isa_model-0.3.9.dist-info/RECORD +0 -138
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
- {isa_model-0.3.9.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
isa_model/deployment/cloud/modal/register_models.py (deleted)
@@ -1,321 +0,0 @@
"""
Model Registration Script for UI Analysis Pipeline

Registers the latest versions of UI analysis models in the core model registry
Prepares models for Modal deployment with proper version management
"""

import asyncio
from pathlib import Path
import sys
import os

# Add project root to path
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from isa_model.core.model_manager import ModelManager
from isa_model.core.model_repo import ModelRegistry, ModelType, ModelCapability

async def register_ui_analysis_models():
    """Register UI analysis models with latest versions"""

    # Initialize model manager and registry
    model_manager = ModelManager()

    print("🔧 Registering UI Analysis Models...")

    # Debug: Check available capabilities
    print("Available capabilities:")
    for cap in ModelCapability:
        print(f"  - {cap.name}: {cap.value}")
    print()

    # Model definitions with latest versions from HuggingFace
    models_to_register = [
        {
            "model_id": "omniparser-v2.0",
            "repo_id": "microsoft/OmniParser",
            "model_type": ModelType.VISION,
            "capabilities": [
                ModelCapability.UI_DETECTION,
                ModelCapability.IMAGE_ANALYSIS,
                ModelCapability.IMAGE_UNDERSTANDING
            ],
            "revision": "main",  # Latest version
            "metadata": {
                "description": "Microsoft OmniParser v2.0 - Advanced UI element detection",
                "provider": "microsoft",
                "model_family": "omniparser",
                "version": "2.0",
                "paper": "https://arxiv.org/abs/2408.00203",
                "huggingface_url": "https://huggingface.co/microsoft/OmniParser",
                "use_case": "UI element detection and parsing",
                "input_format": "image",
                "output_format": "structured_elements",
                "gpu_memory_mb": 8192,
                "inference_time_ms": 500
            }
        },
        {
            "model_id": "table-transformer-v1.1-detection",
            "repo_id": "microsoft/table-transformer-detection",
            "model_type": ModelType.VISION,
            "capabilities": [
                ModelCapability.TABLE_DETECTION,
                ModelCapability.IMAGE_ANALYSIS
            ],
            "revision": "main",
            "metadata": {
                "description": "Microsoft Table Transformer v1.1 - Table detection model",
                "provider": "microsoft",
                "model_family": "table-transformer",
                "version": "1.1",
                "paper": "https://arxiv.org/abs/2110.00061",
                "huggingface_url": "https://huggingface.co/microsoft/table-transformer-detection",
                "use_case": "Table detection in documents and images",
                "input_format": "image",
                "output_format": "bounding_boxes",
                "gpu_memory_mb": 4096,
                "inference_time_ms": 300
            }
        },
        {
            "model_id": "table-transformer-v1.1-structure",
            "repo_id": "microsoft/table-transformer-structure-recognition",
            "model_type": ModelType.VISION,
            "capabilities": [
                ModelCapability.TABLE_STRUCTURE_RECOGNITION,
                ModelCapability.IMAGE_ANALYSIS
            ],
            "revision": "main",
            "metadata": {
                "description": "Microsoft Table Transformer v1.1 - Table structure recognition",
                "provider": "microsoft",
                "model_family": "table-transformer",
                "version": "1.1",
                "paper": "https://arxiv.org/abs/2110.00061",
                "huggingface_url": "https://huggingface.co/microsoft/table-transformer-structure-recognition",
                "use_case": "Table structure recognition and cell extraction",
                "input_format": "image",
                "output_format": "table_structure",
                "gpu_memory_mb": 4096,
                "inference_time_ms": 400
            }
        },
        {
            "model_id": "paddleocr-v3.0",
            "repo_id": "PaddlePaddle/PaddleOCR",
            "model_type": ModelType.VISION,
            "capabilities": [
                ModelCapability.OCR,
                ModelCapability.IMAGE_ANALYSIS
            ],
            "revision": "release/2.8",
            "metadata": {
                "description": "PaddleOCR v3.0 - Multilingual OCR model",
                "provider": "paddlepaddle",
                "model_family": "paddleocr",
                "version": "3.0",
                "github_url": "https://github.com/PaddlePaddle/PaddleOCR",
                "huggingface_url": "https://huggingface.co/PaddlePaddle/PaddleOCR",
                "use_case": "Text extraction from images",
                "input_format": "image",
                "output_format": "text_with_coordinates",
                "languages": ["en", "ch", "multilingual"],
                "gpu_memory_mb": 2048,
                "inference_time_ms": 200
            }
        },
        {
            "model_id": "yolov8n-fallback",
            "repo_id": "ultralytics/yolov8",
            "model_type": ModelType.VISION,
            "capabilities": [
                ModelCapability.IMAGE_ANALYSIS,
                ModelCapability.UI_DETECTION  # As fallback
            ],
            "revision": "main",
            "metadata": {
                "description": "YOLOv8 Nano - Fallback object detection model",
                "provider": "ultralytics",
                "model_family": "yolo",
                "version": "8.0",
                "github_url": "https://github.com/ultralytics/ultralytics",
                "use_case": "General object detection (fallback for UI elements)",
                "input_format": "image",
                "output_format": "bounding_boxes",
                "gpu_memory_mb": 1024,
                "inference_time_ms": 50
            }
        }
    ]

    # Register each model
    registration_results = []

    for model_config in models_to_register:
        print(f"\n📝 Registering {model_config['model_id']}...")

        try:
            # Register model in registry (without downloading)
            success = model_manager.registry.register_model(
                model_id=model_config['model_id'],
                model_type=model_config['model_type'],
                capabilities=model_config['capabilities'],
                metadata={
                    **model_config['metadata'],
                    'repo_id': model_config['repo_id'],
                    'revision': model_config['revision'],
                    'registered_at': 'auto',
                    'download_status': 'not_downloaded'
                }
            )

            if success:
                print(f"✅ Successfully registered {model_config['model_id']}")
                registration_results.append({
                    'model_id': model_config['model_id'],
                    'status': 'success'
                })
            else:
                print(f"❌ Failed to register {model_config['model_id']}")
                registration_results.append({
                    'model_id': model_config['model_id'],
                    'status': 'failed'
                })

        except Exception as e:
            print(f"❌ Error registering {model_config['model_id']}: {e}")
            registration_results.append({
                'model_id': model_config['model_id'],
                'status': 'error',
                'error': str(e)
            })

    # Print summary
    print(f"\n📊 Registration Summary:")
    successful = [r for r in registration_results if r['status'] == 'success']
    failed = [r for r in registration_results if r['status'] != 'success']

    print(f"✅ Successfully registered: {len(successful)} models")
    for result in successful:
        print(f"  - {result['model_id']}")

    if failed:
        print(f"❌ Failed to register: {len(failed)} models")
        for result in failed:
            error_msg = f" ({result.get('error', 'unknown error')})" if 'error' in result else ""
            print(f"  - {result['model_id']}{error_msg}")

    return registration_results

async def verify_model_registry():
    """Verify registered models and their capabilities"""

    model_manager = ModelManager()

    print(f"\n🔍 Verifying Model Registry...")

    # Check models by capability
    capabilities_to_check = [
        ModelCapability.UI_DETECTION,
        ModelCapability.OCR,
        ModelCapability.TABLE_DETECTION,
        ModelCapability.TABLE_STRUCTURE_RECOGNITION
    ]

    for capability in capabilities_to_check:
        models = model_manager.registry.get_models_by_capability(capability)
        print(f"\n📋 Models with {capability.value} capability:")

        if models:
            for model_id, model_info in models.items():
                metadata = model_info.get('metadata', {})
                version = metadata.get('version', 'unknown')
                provider = metadata.get('provider', 'unknown')
                print(f"  ✅ {model_id} (v{version}, {provider})")
        else:
            print(f"  ❌ No models found for {capability.value}")

    # Print overall stats
    stats = model_manager.registry.get_stats()
    print(f"\n📈 Registry Statistics:")
    print(f"  Total models: {stats['total_models']}")
    print(f"  Models by type: {stats['models_by_type']}")
    print(f"  Models by capability: {stats['models_by_capability']}")

def get_model_for_capability(capability: ModelCapability) -> str:
    """Get the best model for a specific capability"""

    model_manager = ModelManager()
    models = model_manager.registry.get_models_by_capability(capability)

    if not models:
        return None

    # Priority order for UI analysis models
    priority_order = {
        ModelCapability.UI_DETECTION: [
            "omniparser-v2.0",
            "yolov8n-fallback"
        ],
        ModelCapability.OCR: [
            "paddleocr-v3.0"
        ],
        ModelCapability.TABLE_DETECTION: [
            "table-transformer-v1.1-detection"
        ],
        ModelCapability.TABLE_STRUCTURE_RECOGNITION: [
            "table-transformer-v1.1-structure"
        ]
    }

    preferred_models = priority_order.get(capability, [])

    # Return the first available preferred model
    for model_id in preferred_models:
        if model_id in models:
            return model_id

    # Fallback to first available model
    return list(models.keys())[0] if models else None

async def main():
    """Main registration workflow"""

    print("🚀 ISA Model Registry - UI Analysis Models Registration")
    print("=" * 60)

    try:
        # Register models
        results = await register_ui_analysis_models()

        # Verify registration
        await verify_model_registry()

        print(f"\n🎉 Model registration completed!")
        print(f"   Use ModelManager.get_model() to download and use models")
        print(f"   Use get_model_for_capability() to get recommended models")

        # Show usage example
        print(f"\n💡 Usage Example:")
        print(f"   from isa_model.core.model_manager import ModelManager")
        print(f"   from isa_model.core.model_repo import ModelCapability")
        print(f"   ")
        print(f"   manager = ModelManager()")
        print(f"   ui_model_path = await manager.get_model(")
        print(f"       model_id='omniparser-v2.0',")
        print(f"       repo_id='microsoft/OmniParser',")
        print(f"       model_type=ModelType.VISION,")
        print(f"       capabilities=[ModelCapability.UI_DETECTION]")
        print(f"   )")

    except Exception as e:
        print(f"❌ Registration failed: {e}")
        return False

    return True

if __name__ == "__main__":
    asyncio.run(main())
isa_model/inference/adapter/unified_api.py (deleted)
@@ -1,248 +0,0 @@
import os
import json
import logging
from typing import Dict, List, Any, Optional, Union
from fastapi import FastAPI, HTTPException, Depends, Request
from pydantic import BaseModel, Field

from isa_model.inference.ai_factory import AIFactory

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("unified_api")

# Create FastAPI app
app = FastAPI(
    title="Unified AI Model API",
    description="API for inference with Llama3-8B, Gemma3-4B, Whisper, and BGE-M3 models",
    version="1.0.0"
)

# Models
class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message sender (system, user, assistant)")
    content: str = Field(..., description="Content of the message")

class ChatCompletionRequest(BaseModel):
    model: str = Field(..., description="Model ID to use (llama, gemma)")
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    temperature: Optional[float] = Field(0.7, description="Sampling temperature")
    max_tokens: Optional[int] = Field(512, description="Maximum number of tokens to generate")
    top_p: Optional[float] = Field(0.9, description="Top-p sampling parameter")
    top_k: Optional[int] = Field(50, description="Top-k sampling parameter")

class ChatCompletionResponse(BaseModel):
    model: str = Field(..., description="Model used for completion")
    choices: List[Dict[str, Any]] = Field(..., description="Generated completions")
    usage: Dict[str, int] = Field(..., description="Token usage statistics")

class EmbeddingRequest(BaseModel):
    model: str = Field(..., description="Model ID to use (bge_embed)")
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    normalize: Optional[bool] = Field(True, description="Whether to normalize embeddings")

class TranscriptionRequest(BaseModel):
    model: str = Field(..., description="Model ID to use (whisper)")
    audio: str = Field(..., description="Base64-encoded audio data or URL")
    language: Optional[str] = Field("en", description="Language code")

# Factory for creating services
ai_factory = AIFactory()

# Dependency to get LLM service
async def get_llm_service(model: str):
    if model == "llama":
        return await ai_factory.get_llm_service("llama")
    elif model == "gemma":
        return await ai_factory.get_llm_service("gemma")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")

# Dependency to get embedding service
async def get_embedding_service(model: str):
    if model == "bge_embed":
        return await ai_factory.get_embedding_service("bge_embed")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")

# Dependency to get speech service
async def get_speech_service(model: str):
    if model == "whisper":
        return await ai_factory.get_speech_service("whisper")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")

# Endpoints
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
    """Generate chat completion"""
    try:
        # Get the appropriate service
        service = await get_llm_service(request.model)

        # Format messages
        formatted_messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]

        # Extract system prompt if present
        system_prompt = None
        if formatted_messages and formatted_messages[0]["role"] == "system":
            system_prompt = formatted_messages[0]["content"]
            formatted_messages = formatted_messages[1:]

        # Get user prompt (last user message)
        user_prompt = ""
        for msg in reversed(formatted_messages):
            if msg["role"] == "user":
                user_prompt = msg["content"]
                break

        if not user_prompt:
            raise HTTPException(status_code=400, detail="No user message found")

        # Set generation config
        generation_config = {
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "top_p": request.top_p,
            "top_k": request.top_k
        }

        # Generate completion
        completion = await service.generate(
            prompt=user_prompt,
            system_prompt=system_prompt,
            generation_config=generation_config
        )

        # Format response
        response = {
            "model": request.model,
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": completion
                    },
                    "finish_reason": "stop",
                    "index": 0
                }
            ],
            "usage": {
                "prompt_tokens": len(user_prompt.split()),
                "completion_tokens": len(completion.split()),
                "total_tokens": len(user_prompt.split()) + len(completion.split())
            }
        }

        return response

    except Exception as e:
        logger.error(f"Error in chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest):
    """Generate embeddings for text"""
    try:
        # Get the embedding service
        service = await get_embedding_service("bge_embed")

        # Generate embeddings
        if isinstance(request.input, str):
            embeddings = await service.embed(request.input, normalize=request.normalize)
            data = [{"embedding": embeddings[0].tolist(), "index": 0}]
        else:
            embeddings = await service.embed(request.input, normalize=request.normalize)
            data = [{"embedding": emb.tolist(), "index": i} for i, emb in enumerate(embeddings)]

        # Format response
        response = {
            "model": request.model,
            "data": data,
            "usage": {
                "prompt_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input])),
                "total_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input]))
            }
        }

        return response

    except Exception as e:
        logger.error(f"Error in embedding generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/v1/audio/transcriptions")
async def transcribe_audio(request: TranscriptionRequest):
    """Transcribe audio to text"""
    try:
        import base64

        # Get the speech service
        service = await get_speech_service("whisper")

        # Process audio
        if request.audio.startswith(("http://", "https://")):
            # URL - download audio
            import requests
            audio_data = requests.get(request.audio).content
        else:
            # Base64 - decode
            audio_data = base64.b64decode(request.audio)

        # Transcribe
        transcription = await service.transcribe(
            audio=audio_data,
            language=request.language
        )

        # Format response
        response = {
            "model": request.model,
            "text": transcription
        }

        return response

    except Exception as e:
        logger.error(f"Error in audio transcription: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# Health check endpoint
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}

# Model info endpoint
@app.get("/v1/models")
async def list_models():
    """List available models"""
    models = [
        {
            "id": "llama",
            "type": "llm",
            "description": "Llama3-8B language model"
        },
        {
            "id": "gemma",
            "type": "llm",
            "description": "Gemma3-4B language model"
        },
        {
            "id": "whisper",
            "type": "speech",
            "description": "Whisper-tiny speech-to-text model"
        },
        {
            "id": "bge_embed",
            "type": "embedding",
            "description": "BGE-M3 text embedding model"
        }
    ]

    return {"data": models}

# Main entry point
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8080)
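For context on what this removal drops: the adapter exposed OpenAI-style routes (/v1/chat/completions, /v1/embeddings, /v1/audio/transcriptions, /v1/models, /health). A minimal client sketch against the old server follows, assuming it ran on localhost:8080 as in its __main__ block; the payload fields come straight from the ChatCompletionRequest model above, and the prompt text is illustrative only:

import requests

# Chat completion against the removed unified API; fields mirror ChatCompletionRequest
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "model": "llama",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello."},
        ],
        "temperature": 0.7,
        "max_tokens": 128,
    },
    timeout=60,
)
resp.raise_for_status()
# The server returned an OpenAI-style choices list
print(resp.json()["choices"][0]["message"]["content"])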