isa-model 0.3.91__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. isa_model/client.py +732 -573
  2. isa_model/core/cache/redis_cache.py +401 -0
  3. isa_model/core/config/config_manager.py +53 -10
  4. isa_model/core/config.py +1 -1
  5. isa_model/core/database/__init__.py +1 -0
  6. isa_model/core/database/migrations.py +277 -0
  7. isa_model/core/database/supabase_client.py +123 -0
  8. isa_model/core/models/__init__.py +37 -0
  9. isa_model/core/models/model_billing_tracker.py +60 -88
  10. isa_model/core/models/model_manager.py +36 -18
  11. isa_model/core/models/model_repo.py +44 -38
  12. isa_model/core/models/model_statistics_tracker.py +234 -0
  13. isa_model/core/models/model_storage.py +0 -1
  14. isa_model/core/models/model_version_manager.py +959 -0
  15. isa_model/core/pricing_manager.py +2 -249
  16. isa_model/core/resilience/circuit_breaker.py +366 -0
  17. isa_model/core/security/secrets.py +358 -0
  18. isa_model/core/services/__init__.py +2 -4
  19. isa_model/core/services/intelligent_model_selector.py +101 -370
  20. isa_model/core/storage/hf_storage.py +1 -1
  21. isa_model/core/types.py +7 -0
  22. isa_model/deployment/cloud/modal/isa_audio_chatTTS_service.py +520 -0
  23. isa_model/deployment/cloud/modal/isa_audio_fish_service.py +0 -0
  24. isa_model/deployment/cloud/modal/isa_audio_openvoice_service.py +758 -0
  25. isa_model/deployment/cloud/modal/isa_audio_service_v2.py +1044 -0
  26. isa_model/deployment/cloud/modal/isa_embed_rerank_service.py +296 -0
  27. isa_model/deployment/cloud/modal/isa_video_hunyuan_service.py +423 -0
  28. isa_model/deployment/cloud/modal/isa_vision_ocr_service.py +519 -0
  29. isa_model/deployment/cloud/modal/isa_vision_qwen25_service.py +709 -0
  30. isa_model/deployment/cloud/modal/isa_vision_table_service.py +467 -323
  31. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +607 -180
  32. isa_model/deployment/cloud/modal/isa_vision_ui_service_optimized.py +660 -0
  33. isa_model/deployment/core/deployment_manager.py +6 -4
  34. isa_model/deployment/services/auto_hf_modal_deployer.py +894 -0
  35. isa_model/eval/benchmarks/__init__.py +27 -0
  36. isa_model/eval/benchmarks/multimodal_datasets.py +460 -0
  37. isa_model/eval/benchmarks.py +244 -12
  38. isa_model/eval/evaluators/__init__.py +8 -2
  39. isa_model/eval/evaluators/audio_evaluator.py +727 -0
  40. isa_model/eval/evaluators/embedding_evaluator.py +742 -0
  41. isa_model/eval/evaluators/vision_evaluator.py +564 -0
  42. isa_model/eval/example_evaluation.py +395 -0
  43. isa_model/eval/factory.py +272 -5
  44. isa_model/eval/isa_benchmarks.py +700 -0
  45. isa_model/eval/isa_integration.py +582 -0
  46. isa_model/eval/metrics.py +159 -6
  47. isa_model/eval/tests/unit/test_basic.py +396 -0
  48. isa_model/inference/ai_factory.py +44 -8
  49. isa_model/inference/services/audio/__init__.py +21 -0
  50. isa_model/inference/services/audio/base_realtime_service.py +225 -0
  51. isa_model/inference/services/audio/isa_tts_service.py +0 -0
  52. isa_model/inference/services/audio/openai_realtime_service.py +320 -124
  53. isa_model/inference/services/audio/openai_stt_service.py +32 -6
  54. isa_model/inference/services/base_service.py +17 -1
  55. isa_model/inference/services/embedding/__init__.py +13 -0
  56. isa_model/inference/services/embedding/base_embed_service.py +111 -8
  57. isa_model/inference/services/embedding/isa_embed_service.py +305 -0
  58. isa_model/inference/services/embedding/openai_embed_service.py +2 -4
  59. isa_model/inference/services/embedding/tests/test_embedding.py +222 -0
  60. isa_model/inference/services/img/__init__.py +2 -2
  61. isa_model/inference/services/img/base_image_gen_service.py +24 -7
  62. isa_model/inference/services/img/replicate_image_gen_service.py +84 -422
  63. isa_model/inference/services/img/services/replicate_face_swap.py +193 -0
  64. isa_model/inference/services/img/services/replicate_flux.py +226 -0
  65. isa_model/inference/services/img/services/replicate_flux_kontext.py +219 -0
  66. isa_model/inference/services/img/services/replicate_sticker_maker.py +249 -0
  67. isa_model/inference/services/img/tests/test_img_client.py +297 -0
  68. isa_model/inference/services/llm/base_llm_service.py +30 -6
  69. isa_model/inference/services/llm/helpers/llm_adapter.py +63 -9
  70. isa_model/inference/services/llm/ollama_llm_service.py +2 -1
  71. isa_model/inference/services/llm/openai_llm_service.py +652 -55
  72. isa_model/inference/services/llm/yyds_llm_service.py +2 -1
  73. isa_model/inference/services/vision/__init__.py +5 -5
  74. isa_model/inference/services/vision/base_vision_service.py +118 -185
  75. isa_model/inference/services/vision/helpers/image_utils.py +11 -5
  76. isa_model/inference/services/vision/isa_vision_service.py +573 -0
  77. isa_model/inference/services/vision/tests/test_ocr_client.py +284 -0
  78. isa_model/serving/api/fastapi_server.py +88 -16
  79. isa_model/serving/api/middleware/auth.py +311 -0
  80. isa_model/serving/api/middleware/security.py +278 -0
  81. isa_model/serving/api/routes/analytics.py +486 -0
  82. isa_model/serving/api/routes/deployments.py +339 -0
  83. isa_model/serving/api/routes/evaluations.py +579 -0
  84. isa_model/serving/api/routes/logs.py +430 -0
  85. isa_model/serving/api/routes/settings.py +582 -0
  86. isa_model/serving/api/routes/unified.py +324 -165
  87. isa_model/serving/api/startup.py +304 -0
  88. isa_model/serving/modal_proxy_server.py +249 -0
  89. isa_model/training/__init__.py +100 -6
  90. isa_model/training/core/__init__.py +4 -1
  91. isa_model/training/examples/intelligent_training_example.py +281 -0
  92. isa_model/training/intelligent/__init__.py +25 -0
  93. isa_model/training/intelligent/decision_engine.py +643 -0
  94. isa_model/training/intelligent/intelligent_factory.py +888 -0
  95. isa_model/training/intelligent/knowledge_base.py +751 -0
  96. isa_model/training/intelligent/resource_optimizer.py +839 -0
  97. isa_model/training/intelligent/task_classifier.py +576 -0
  98. isa_model/training/storage/__init__.py +24 -0
  99. isa_model/training/storage/core_integration.py +439 -0
  100. isa_model/training/storage/training_repository.py +552 -0
  101. isa_model/training/storage/training_storage.py +628 -0
  102. {isa_model-0.3.91.dist-info → isa_model-0.4.0.dist-info}/METADATA +13 -1
  103. isa_model-0.4.0.dist-info/RECORD +182 -0
  104. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +0 -766
  105. isa_model/deployment/cloud/modal/register_models.py +0 -321
  106. isa_model/inference/adapter/unified_api.py +0 -248
  107. isa_model/inference/services/helpers/stacked_config.py +0 -148
  108. isa_model/inference/services/img/flux_professional_service.py +0 -603
  109. isa_model/inference/services/img/helpers/base_stacked_service.py +0 -274
  110. isa_model/inference/services/others/table_transformer_service.py +0 -61
  111. isa_model/inference/services/vision/doc_analysis_service.py +0 -640
  112. isa_model/inference/services/vision/helpers/base_stacked_service.py +0 -274
  113. isa_model/inference/services/vision/ui_analysis_service.py +0 -823
  114. isa_model/scripts/inference_tracker.py +0 -283
  115. isa_model/scripts/mlflow_manager.py +0 -379
  116. isa_model/scripts/model_registry.py +0 -465
  117. isa_model/scripts/register_models.py +0 -370
  118. isa_model/scripts/register_models_with_embeddings.py +0 -510
  119. isa_model/scripts/start_mlflow.py +0 -95
  120. isa_model/scripts/training_tracker.py +0 -257
  121. isa_model-0.3.91.dist-info/RECORD +0 -138
  122. {isa_model-0.3.91.dist-info → isa_model-0.4.0.dist-info}/WHEEL +0 -0
  123. {isa_model-0.3.91.dist-info → isa_model-0.4.0.dist-info}/top_level.txt +0 -0
isa_model/deployment/cloud/modal/register_models.py (deleted)
@@ -1,321 +0,0 @@
- """
- Model Registration Script for UI Analysis Pipeline
-
- Registers the latest versions of UI analysis models in the core model registry
- Prepares models for Modal deployment with proper version management
- """
-
- import asyncio
- from pathlib import Path
- import sys
- import os
-
- # Add project root to path
- project_root = Path(__file__).parent.parent.parent.parent
- sys.path.insert(0, str(project_root))
-
- from isa_model.core.model_manager import ModelManager
- from isa_model.core.model_repo import ModelRegistry, ModelType, ModelCapability
-
- async def register_ui_analysis_models():
-     """Register UI analysis models with latest versions"""
-
-     # Initialize model manager and registry
-     model_manager = ModelManager()
-
-     print("🔧 Registering UI Analysis Models...")
-
-     # Debug: Check available capabilities
-     print("Available capabilities:")
-     for cap in ModelCapability:
-         print(f" - {cap.name}: {cap.value}")
-     print()
-
-     # Model definitions with latest versions from HuggingFace
-     models_to_register = [
-         {
-             "model_id": "omniparser-v2.0",
-             "repo_id": "microsoft/OmniParser",
-             "model_type": ModelType.VISION,
-             "capabilities": [
-                 ModelCapability.UI_DETECTION,
-                 ModelCapability.IMAGE_ANALYSIS,
-                 ModelCapability.IMAGE_UNDERSTANDING
-             ],
-             "revision": "main",  # Latest version
-             "metadata": {
-                 "description": "Microsoft OmniParser v2.0 - Advanced UI element detection",
-                 "provider": "microsoft",
-                 "model_family": "omniparser",
-                 "version": "2.0",
-                 "paper": "https://arxiv.org/abs/2408.00203",
-                 "huggingface_url": "https://huggingface.co/microsoft/OmniParser",
-                 "use_case": "UI element detection and parsing",
-                 "input_format": "image",
-                 "output_format": "structured_elements",
-                 "gpu_memory_mb": 8192,
-                 "inference_time_ms": 500
-             }
-         },
-         {
-             "model_id": "table-transformer-v1.1-detection",
-             "repo_id": "microsoft/table-transformer-detection",
-             "model_type": ModelType.VISION,
-             "capabilities": [
-                 ModelCapability.TABLE_DETECTION,
-                 ModelCapability.IMAGE_ANALYSIS
-             ],
-             "revision": "main",
-             "metadata": {
-                 "description": "Microsoft Table Transformer v1.1 - Table detection model",
-                 "provider": "microsoft",
-                 "model_family": "table-transformer",
-                 "version": "1.1",
-                 "paper": "https://arxiv.org/abs/2110.00061",
-                 "huggingface_url": "https://huggingface.co/microsoft/table-transformer-detection",
-                 "use_case": "Table detection in documents and images",
-                 "input_format": "image",
-                 "output_format": "bounding_boxes",
-                 "gpu_memory_mb": 4096,
-                 "inference_time_ms": 300
-             }
-         },
-         {
-             "model_id": "table-transformer-v1.1-structure",
-             "repo_id": "microsoft/table-transformer-structure-recognition",
-             "model_type": ModelType.VISION,
-             "capabilities": [
-                 ModelCapability.TABLE_STRUCTURE_RECOGNITION,
-                 ModelCapability.IMAGE_ANALYSIS
-             ],
-             "revision": "main",
-             "metadata": {
-                 "description": "Microsoft Table Transformer v1.1 - Table structure recognition",
-                 "provider": "microsoft",
-                 "model_family": "table-transformer",
-                 "version": "1.1",
-                 "paper": "https://arxiv.org/abs/2110.00061",
-                 "huggingface_url": "https://huggingface.co/microsoft/table-transformer-structure-recognition",
-                 "use_case": "Table structure recognition and cell extraction",
-                 "input_format": "image",
-                 "output_format": "table_structure",
-                 "gpu_memory_mb": 4096,
-                 "inference_time_ms": 400
-             }
-         },
-         {
-             "model_id": "paddleocr-v3.0",
-             "repo_id": "PaddlePaddle/PaddleOCR",
-             "model_type": ModelType.VISION,
-             "capabilities": [
-                 ModelCapability.OCR,
-                 ModelCapability.IMAGE_ANALYSIS
-             ],
-             "revision": "release/2.8",
-             "metadata": {
-                 "description": "PaddleOCR v3.0 - Multilingual OCR model",
-                 "provider": "paddlepaddle",
-                 "model_family": "paddleocr",
-                 "version": "3.0",
-                 "github_url": "https://github.com/PaddlePaddle/PaddleOCR",
-                 "huggingface_url": "https://huggingface.co/PaddlePaddle/PaddleOCR",
-                 "use_case": "Text extraction from images",
-                 "input_format": "image",
-                 "output_format": "text_with_coordinates",
-                 "languages": ["en", "ch", "multilingual"],
-                 "gpu_memory_mb": 2048,
-                 "inference_time_ms": 200
-             }
-         },
-         {
-             "model_id": "yolov8n-fallback",
-             "repo_id": "ultralytics/yolov8",
-             "model_type": ModelType.VISION,
-             "capabilities": [
-                 ModelCapability.IMAGE_ANALYSIS,
-                 ModelCapability.UI_DETECTION  # As fallback
-             ],
-             "revision": "main",
-             "metadata": {
-                 "description": "YOLOv8 Nano - Fallback object detection model",
-                 "provider": "ultralytics",
-                 "model_family": "yolo",
-                 "version": "8.0",
-                 "github_url": "https://github.com/ultralytics/ultralytics",
-                 "use_case": "General object detection (fallback for UI elements)",
-                 "input_format": "image",
-                 "output_format": "bounding_boxes",
-                 "gpu_memory_mb": 1024,
-                 "inference_time_ms": 50
-             }
-         }
-     ]
-
-     # Register each model
-     registration_results = []
-
-     for model_config in models_to_register:
-         print(f"\n📝 Registering {model_config['model_id']}...")
-
-         try:
-             # Register model in registry (without downloading)
-             success = model_manager.registry.register_model(
-                 model_id=model_config['model_id'],
-                 model_type=model_config['model_type'],
-                 capabilities=model_config['capabilities'],
-                 metadata={
-                     **model_config['metadata'],
-                     'repo_id': model_config['repo_id'],
-                     'revision': model_config['revision'],
-                     'registered_at': 'auto',
-                     'download_status': 'not_downloaded'
-                 }
-             )
-
-             if success:
-                 print(f"✅ Successfully registered {model_config['model_id']}")
-                 registration_results.append({
-                     'model_id': model_config['model_id'],
-                     'status': 'success'
-                 })
-             else:
-                 print(f"❌ Failed to register {model_config['model_id']}")
-                 registration_results.append({
-                     'model_id': model_config['model_id'],
-                     'status': 'failed'
-                 })
-
-         except Exception as e:
-             print(f"❌ Error registering {model_config['model_id']}: {e}")
-             registration_results.append({
-                 'model_id': model_config['model_id'],
-                 'status': 'error',
-                 'error': str(e)
-             })
-
-     # Print summary
-     print(f"\n📊 Registration Summary:")
-     successful = [r for r in registration_results if r['status'] == 'success']
-     failed = [r for r in registration_results if r['status'] != 'success']
-
-     print(f"✅ Successfully registered: {len(successful)} models")
-     for result in successful:
-         print(f" - {result['model_id']}")
-
-     if failed:
-         print(f"❌ Failed to register: {len(failed)} models")
-         for result in failed:
-             error_msg = f" ({result.get('error', 'unknown error')})" if 'error' in result else ""
-             print(f" - {result['model_id']}{error_msg}")
-
-     return registration_results
-
- async def verify_model_registry():
-     """Verify registered models and their capabilities"""
-
-     model_manager = ModelManager()
-
-     print(f"\n🔍 Verifying Model Registry...")
-
-     # Check models by capability
-     capabilities_to_check = [
-         ModelCapability.UI_DETECTION,
-         ModelCapability.OCR,
-         ModelCapability.TABLE_DETECTION,
-         ModelCapability.TABLE_STRUCTURE_RECOGNITION
-     ]
-
-     for capability in capabilities_to_check:
-         models = model_manager.registry.get_models_by_capability(capability)
-         print(f"\n📋 Models with {capability.value} capability:")
-
-         if models:
-             for model_id, model_info in models.items():
-                 metadata = model_info.get('metadata', {})
-                 version = metadata.get('version', 'unknown')
-                 provider = metadata.get('provider', 'unknown')
-                 print(f" ✅ {model_id} (v{version}, {provider})")
-         else:
-             print(f" ❌ No models found for {capability.value}")
-
-     # Print overall stats
-     stats = model_manager.registry.get_stats()
-     print(f"\n📈 Registry Statistics:")
-     print(f" Total models: {stats['total_models']}")
-     print(f" Models by type: {stats['models_by_type']}")
-     print(f" Models by capability: {stats['models_by_capability']}")
-
- def get_model_for_capability(capability: ModelCapability) -> str:
-     """Get the best model for a specific capability"""
-
-     model_manager = ModelManager()
-     models = model_manager.registry.get_models_by_capability(capability)
-
-     if not models:
-         return None
-
-     # Priority order for UI analysis models
-     priority_order = {
-         ModelCapability.UI_DETECTION: [
-             "omniparser-v2.0",
-             "yolov8n-fallback"
-         ],
-         ModelCapability.OCR: [
-             "paddleocr-v3.0"
-         ],
-         ModelCapability.TABLE_DETECTION: [
-             "table-transformer-v1.1-detection"
-         ],
-         ModelCapability.TABLE_STRUCTURE_RECOGNITION: [
-             "table-transformer-v1.1-structure"
-         ]
-     }
-
-     preferred_models = priority_order.get(capability, [])
-
-     # Return the first available preferred model
-     for model_id in preferred_models:
-         if model_id in models:
-             return model_id
-
-     # Fallback to first available model
-     return list(models.keys())[0] if models else None
-
- async def main():
-     """Main registration workflow"""
-
-     print("🚀 ISA Model Registry - UI Analysis Models Registration")
-     print("=" * 60)
-
-     try:
-         # Register models
-         results = await register_ui_analysis_models()
-
-         # Verify registration
-         await verify_model_registry()
-
-         print(f"\n🎉 Model registration completed!")
-         print(f" Use ModelManager.get_model() to download and use models")
-         print(f" Use get_model_for_capability() to get recommended models")
-
-         # Show usage example
-         print(f"\n💡 Usage Example:")
-         print(f" from isa_model.core.model_manager import ModelManager")
-         print(f" from isa_model.core.model_repo import ModelCapability")
-         print(f" ")
-         print(f" manager = ModelManager()")
-         print(f" ui_model_path = await manager.get_model(")
-         print(f" model_id='omniparser-v2.0',")
-         print(f" repo_id='microsoft/OmniParser',")
-         print(f" model_type=ModelType.VISION,")
-         print(f" capabilities=[ModelCapability.UI_DETECTION]")
-         print(f" )")
-
-     except Exception as e:
-         print(f"❌ Registration failed: {e}")
-         return False
-
-     return True
-
- if __name__ == "__main__":
-     asyncio.run(main())
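
The deleted script was a thin wrapper around the registry API: every entry ultimately flowed into one registry.register_model(...) call. For orientation after the removal, here is a minimal one-off registration sketch. The keyword arguments mirror the deleted script; the import paths are an assumption inferred from the 0.4.0 file list above (model_manager.py and model_repo.py now sit under isa_model/core/models/), not a confirmed 0.4.0 API.

# Hypothetical sketch, not the package's documented API: assumes ModelManager,
# ModelType, and ModelCapability moved to isa_model.core.models.* in 0.4.0
# (per the file list) and that register_model kept the signature used above.
from isa_model.core.models.model_manager import ModelManager
from isa_model.core.models.model_repo import ModelType, ModelCapability

manager = ModelManager()
ok = manager.registry.register_model(
    model_id="omniparser-v2.0",                    # same ID as the deleted script
    model_type=ModelType.VISION,
    capabilities=[ModelCapability.UI_DETECTION],
    metadata={"repo_id": "microsoft/OmniParser", "revision": "main"},
)
print("registered" if ok else "registration failed")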
isa_model/inference/adapter/unified_api.py (deleted)
@@ -1,248 +0,0 @@
- import os
- import json
- import logging
- from typing import Dict, List, Any, Optional, Union
- from fastapi import FastAPI, HTTPException, Depends, Request
- from pydantic import BaseModel, Field
-
- from isa_model.inference.ai_factory import AIFactory
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger("unified_api")
-
- # Create FastAPI app
- app = FastAPI(
-     title="Unified AI Model API",
-     description="API for inference with Llama3-8B, Gemma3-4B, Whisper, and BGE-M3 models",
-     version="1.0.0"
- )
-
- # Models
- class ChatMessage(BaseModel):
-     role: str = Field(..., description="Role of the message sender (system, user, assistant)")
-     content: str = Field(..., description="Content of the message")
-
- class ChatCompletionRequest(BaseModel):
-     model: str = Field(..., description="Model ID to use (llama, gemma)")
-     messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
-     temperature: Optional[float] = Field(0.7, description="Sampling temperature")
-     max_tokens: Optional[int] = Field(512, description="Maximum number of tokens to generate")
-     top_p: Optional[float] = Field(0.9, description="Top-p sampling parameter")
-     top_k: Optional[int] = Field(50, description="Top-k sampling parameter")
-
- class ChatCompletionResponse(BaseModel):
-     model: str = Field(..., description="Model used for completion")
-     choices: List[Dict[str, Any]] = Field(..., description="Generated completions")
-     usage: Dict[str, int] = Field(..., description="Token usage statistics")
-
- class EmbeddingRequest(BaseModel):
-     model: str = Field(..., description="Model ID to use (bge_embed)")
-     input: Union[str, List[str]] = Field(..., description="Text to embed")
-     normalize: Optional[bool] = Field(True, description="Whether to normalize embeddings")
-
- class TranscriptionRequest(BaseModel):
-     model: str = Field(..., description="Model ID to use (whisper)")
-     audio: str = Field(..., description="Base64-encoded audio data or URL")
-     language: Optional[str] = Field("en", description="Language code")
-
- # Factory for creating services
- ai_factory = AIFactory()
-
- # Dependency to get LLM service
- async def get_llm_service(model: str):
-     if model == "llama":
-         return await ai_factory.get_llm_service("llama")
-     elif model == "gemma":
-         return await ai_factory.get_llm_service("gemma")
-     else:
-         raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
-
- # Dependency to get embedding service
- async def get_embedding_service(model: str):
-     if model == "bge_embed":
-         return await ai_factory.get_embedding_service("bge_embed")
-     else:
-         raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
-
- # Dependency to get speech service
- async def get_speech_service(model: str):
-     if model == "whisper":
-         return await ai_factory.get_speech_service("whisper")
-     else:
-         raise HTTPException(status_code=400, detail=f"Unsupported model: {model}")
-
- # Endpoints
- @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
- async def chat_completion(request: ChatCompletionRequest):
-     """Generate chat completion"""
-     try:
-         # Get the appropriate service
-         service = await get_llm_service(request.model)
-
-         # Format messages
-         formatted_messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
-
-         # Extract system prompt if present
-         system_prompt = None
-         if formatted_messages and formatted_messages[0]["role"] == "system":
-             system_prompt = formatted_messages[0]["content"]
-             formatted_messages = formatted_messages[1:]
-
-         # Get user prompt (last user message)
-         user_prompt = ""
-         for msg in reversed(formatted_messages):
-             if msg["role"] == "user":
-                 user_prompt = msg["content"]
-                 break
-
-         if not user_prompt:
-             raise HTTPException(status_code=400, detail="No user message found")
-
-         # Set generation config
-         generation_config = {
-             "temperature": request.temperature,
-             "max_new_tokens": request.max_tokens,
-             "top_p": request.top_p,
-             "top_k": request.top_k
-         }
-
-         # Generate completion
-         completion = await service.generate(
-             prompt=user_prompt,
-             system_prompt=system_prompt,
-             generation_config=generation_config
-         )
-
-         # Format response
-         response = {
-             "model": request.model,
-             "choices": [
-                 {
-                     "message": {
-                         "role": "assistant",
-                         "content": completion
-                     },
-                     "finish_reason": "stop",
-                     "index": 0
-                 }
-             ],
-             "usage": {
-                 "prompt_tokens": len(user_prompt.split()),
-                 "completion_tokens": len(completion.split()),
-                 "total_tokens": len(user_prompt.split()) + len(completion.split())
-             }
-         }
-
-         return response
-
-     except Exception as e:
-         logger.error(f"Error in chat completion: {str(e)}")
-         raise HTTPException(status_code=500, detail=str(e))
-
- @app.post("/v1/embeddings")
- async def create_embedding(request: EmbeddingRequest):
-     """Generate embeddings for text"""
-     try:
-         # Get the embedding service
-         service = await get_embedding_service("bge_embed")
-
-         # Generate embeddings
-         if isinstance(request.input, str):
-             embeddings = await service.embed(request.input, normalize=request.normalize)
-             data = [{"embedding": embeddings[0].tolist(), "index": 0}]
-         else:
-             embeddings = await service.embed(request.input, normalize=request.normalize)
-             data = [{"embedding": emb.tolist(), "index": i} for i, emb in enumerate(embeddings)]
-
-         # Format response
-         response = {
-             "model": request.model,
-             "data": data,
-             "usage": {
-                 "prompt_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input])),
-                 "total_tokens": sum(len(text.split()) for text in (request.input if isinstance(request.input, list) else [request.input]))
-             }
-         }
-
-         return response
-
-     except Exception as e:
-         logger.error(f"Error in embedding generation: {str(e)}")
-         raise HTTPException(status_code=500, detail=str(e))
-
- @app.post("/v1/audio/transcriptions")
- async def transcribe_audio(request: TranscriptionRequest):
-     """Transcribe audio to text"""
-     try:
-         import base64
-
-         # Get the speech service
-         service = await get_speech_service("whisper")
-
-         # Process audio
-         if request.audio.startswith(("http://", "https://")):
-             # URL - download audio
-             import requests
-             audio_data = requests.get(request.audio).content
-         else:
-             # Base64 - decode
-             audio_data = base64.b64decode(request.audio)
-
-         # Transcribe
-         transcription = await service.transcribe(
-             audio=audio_data,
-             language=request.language
-         )
-
-         # Format response
-         response = {
-             "model": request.model,
-             "text": transcription
-         }
-
-         return response
-
-     except Exception as e:
-         logger.error(f"Error in audio transcription: {str(e)}")
-         raise HTTPException(status_code=500, detail=str(e))
-
- # Health check endpoint
- @app.get("/health")
- async def health_check():
-     """Health check endpoint"""
-     return {"status": "healthy"}
-
- # Model info endpoint
- @app.get("/v1/models")
- async def list_models():
-     """List available models"""
-     models = [
-         {
-             "id": "llama",
-             "type": "llm",
-             "description": "Llama3-8B language model"
-         },
-         {
-             "id": "gemma",
-             "type": "llm",
-             "description": "Gemma3-4B language model"
-         },
-         {
-             "id": "whisper",
-             "type": "speech",
-             "description": "Whisper-tiny speech-to-text model"
-         },
-         {
-             "id": "bge_embed",
-             "type": "embedding",
-             "description": "BGE-M3 text embedding model"
-         }
-     ]
-
-     return {"data": models}
-
- # Main entry point
- if __name__ == "__main__":
-     import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=8080)
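
Anyone who scripted against this removed adapter can read its full contract in the diff above; a minimal client sketch for that old surface follows. It assumes a local server on port 8080 (the default in the module's __main__ block), and the request fields come straight from ChatCompletionRequest and EmbeddingRequest. In 0.4.0 the unified entry point appears to move to isa_model/serving/api/routes/unified.py, so these paths apply only to 0.3.91.

# Minimal client for the removed unified_api endpoints (0.3.91 behavior only).
# Field names mirror ChatCompletionRequest / EmbeddingRequest in the diff above.
import requests

BASE = "http://localhost:8080"  # default bind from the module's __main__ block

chat = requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "model": "llama",
        "messages": [
            {"role": "system", "content": "You are concise."},
            {"role": "user", "content": "Summarize what UI element detection does."},
        ],
        "temperature": 0.7,
        "max_tokens": 128,
    },
    timeout=60,
)
print(chat.json()["choices"][0]["message"]["content"])

emb = requests.post(
    f"{BASE}/v1/embeddings",
    json={"model": "bge_embed", "input": ["hello", "world"], "normalize": True},
    timeout=60,
)
print(len(emb.json()["data"]), "embeddings returned")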