isa_model-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +5 -0
- isa_model/core/model_manager.py +143 -0
- isa_model/core/model_registry.py +115 -0
- isa_model/core/model_router.py +226 -0
- isa_model/core/model_storage.py +133 -0
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +202 -0
- isa_model/core/storage/hf_storage.py +0 -0
- isa_model/core/storage/local_storage.py +0 -0
- isa_model/core/storage/minio_storage.py +0 -0
- isa_model/deployment/mlflow_gateway/__init__.py +8 -0
- isa_model/deployment/mlflow_gateway/start_gateway.py +65 -0
- isa_model/deployment/unified_multimodal_client.py +341 -0
- isa_model/inference/__init__.py +11 -0
- isa_model/inference/adapter/triton_adapter.py +453 -0
- isa_model/inference/adapter/unified_api.py +248 -0
- isa_model/inference/ai_factory.py +354 -0
- isa_model/inference/backends/Pytorch/bge_embed_backend.py +188 -0
- isa_model/inference/backends/Pytorch/gemma_backend.py +167 -0
- isa_model/inference/backends/Pytorch/llama_backend.py +166 -0
- isa_model/inference/backends/Pytorch/whisper_backend.py +194 -0
- isa_model/inference/backends/__init__.py +53 -0
- isa_model/inference/backends/base_backend_client.py +26 -0
- isa_model/inference/backends/container_services.py +104 -0
- isa_model/inference/backends/local_services.py +72 -0
- isa_model/inference/backends/openai_client.py +130 -0
- isa_model/inference/backends/replicate_client.py +197 -0
- isa_model/inference/backends/third_party_services.py +239 -0
- isa_model/inference/backends/triton_client.py +97 -0
- isa_model/inference/base.py +46 -0
- isa_model/inference/client_sdk/__init__.py +0 -0
- isa_model/inference/client_sdk/client.py +134 -0
- isa_model/inference/client_sdk/client_data_std.py +34 -0
- isa_model/inference/client_sdk/client_sdk_schema.py +16 -0
- isa_model/inference/client_sdk/exceptions.py +0 -0
- isa_model/inference/engine/triton/model_repository/bge/1/model.py +174 -0
- isa_model/inference/engine/triton/model_repository/gemma/1/model.py +250 -0
- isa_model/inference/engine/triton/model_repository/llama/1/model.py +76 -0
- isa_model/inference/engine/triton/model_repository/whisper/1/model.py +195 -0
- isa_model/inference/providers/__init__.py +19 -0
- isa_model/inference/providers/base_provider.py +30 -0
- isa_model/inference/providers/model_cache_manager.py +341 -0
- isa_model/inference/providers/ollama_provider.py +73 -0
- isa_model/inference/providers/openai_provider.py +87 -0
- isa_model/inference/providers/replicate_provider.py +94 -0
- isa_model/inference/providers/triton_provider.py +439 -0
- isa_model/inference/providers/vllm_provider.py +0 -0
- isa_model/inference/providers/yyds_provider.py +83 -0
- isa_model/inference/services/__init__.py +14 -0
- isa_model/inference/services/audio/fish_speech/handler.py +215 -0
- isa_model/inference/services/audio/runpod_tts_fish_service.py +212 -0
- isa_model/inference/services/audio/triton_speech_service.py +138 -0
- isa_model/inference/services/audio/whisper_service.py +186 -0
- isa_model/inference/services/audio/yyds_audio_service.py +71 -0
- isa_model/inference/services/base_service.py +106 -0
- isa_model/inference/services/base_tts_service.py +66 -0
- isa_model/inference/services/embedding/bge_service.py +183 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +85 -0
- isa_model/inference/services/embedding/ollama_rerank_service.py +118 -0
- isa_model/inference/services/embedding/onnx_rerank_service.py +73 -0
- isa_model/inference/services/llm/__init__.py +16 -0
- isa_model/inference/services/llm/gemma_service.py +143 -0
- isa_model/inference/services/llm/llama_service.py +143 -0
- isa_model/inference/services/llm/ollama_llm_service.py +108 -0
- isa_model/inference/services/llm/openai_llm_service.py +129 -0
- isa_model/inference/services/llm/replicate_llm_service.py +179 -0
- isa_model/inference/services/llm/triton_llm_service.py +230 -0
- isa_model/inference/services/others/table_transformer_service.py +61 -0
- isa_model/inference/services/vision/__init__.py +12 -0
- isa_model/inference/services/vision/helpers/image_utils.py +58 -0
- isa_model/inference/services/vision/helpers/text_splitter.py +46 -0
- isa_model/inference/services/vision/ollama_vision_service.py +60 -0
- isa_model/inference/services/vision/replicate_vision_service.py +241 -0
- isa_model/inference/services/vision/triton_vision_service.py +199 -0
- isa_model/inference/services/vision/yyds_vision_service.py +80 -0
- isa_model/inference/utils/conversion/bge_rerank_convert.py +73 -0
- isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- isa_model/inference/utils/conversion/torch_converter.py +0 -0
- isa_model/scripts/inference_tracker.py +283 -0
- isa_model/scripts/mlflow_manager.py +379 -0
- isa_model/scripts/model_registry.py +465 -0
- isa_model/scripts/start_mlflow.py +95 -0
- isa_model/scripts/training_tracker.py +257 -0
- isa_model/training/engine/llama_factory/__init__.py +39 -0
- isa_model/training/engine/llama_factory/config.py +115 -0
- isa_model/training/engine/llama_factory/data_adapter.py +284 -0
- isa_model/training/engine/llama_factory/examples/__init__.py +6 -0
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +185 -0
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +163 -0
- isa_model/training/engine/llama_factory/factory.py +331 -0
- isa_model/training/engine/llama_factory/rl.py +254 -0
- isa_model/training/engine/llama_factory/trainer.py +171 -0
- isa_model/training/image_model/configs/create_config.py +37 -0
- isa_model/training/image_model/configs/create_flux_config.py +26 -0
- isa_model/training/image_model/configs/create_lora_config.py +21 -0
- isa_model/training/image_model/prepare_massed_compute.py +97 -0
- isa_model/training/image_model/prepare_upload.py +17 -0
- isa_model/training/image_model/raw_data/create_captions.py +16 -0
- isa_model/training/image_model/raw_data/create_lora_captions.py +20 -0
- isa_model/training/image_model/raw_data/pre_processing.py +200 -0
- isa_model/training/image_model/train/train.py +42 -0
- isa_model/training/image_model/train/train_flux.py +41 -0
- isa_model/training/image_model/train/train_lora.py +57 -0
- isa_model/training/image_model/train_main.py +25 -0
- isa_model/training/llm_model/annotation/annotation_schema.py +47 -0
- isa_model/training/llm_model/annotation/processors/annotation_processor.py +126 -0
- isa_model/training/llm_model/annotation/storage/dataset_manager.py +131 -0
- isa_model/training/llm_model/annotation/storage/dataset_schema.py +44 -0
- isa_model/training/llm_model/annotation/tests/test_annotation_flow.py +109 -0
- isa_model/training/llm_model/annotation/tests/test_minio copy.py +113 -0
- isa_model/training/llm_model/annotation/tests/test_minio_upload.py +43 -0
- isa_model/training/llm_model/annotation/views/annotation_controller.py +158 -0
- isa_model-0.1.0.dist-info/METADATA +116 -0
- isa_model-0.1.0.dist-info/RECORD +117 -0
- isa_model-0.1.0.dist-info/WHEEL +5 -0
- isa_model-0.1.0.dist-info/licenses/LICENSE +21 -0
- isa_model-0.1.0.dist-info/top_level.txt +1 -0
isa_model/core/resource_manager.py
@@ -0,0 +1,202 @@
+import time
+import threading
+import logging
+from typing import Dict, List, Any, Optional
+import psutil
+try:
+    import torch
+    import nvidia_smi
+    HAS_GPU = torch.cuda.is_available()
+except ImportError:
+    HAS_GPU = False
+
+logger = logging.getLogger(__name__)
+
+class ResourceManager:
+    """
+    Monitors system resources and manages model loading/unloading
+    to prevent resource exhaustion.
+    """
+
+    def __init__(self, model_registry, monitoring_interval=30):
+        self.registry = model_registry
+        self.monitoring_interval = monitoring_interval  # seconds
+
+        self.max_memory_percent = 90  # Maximum memory usage percentage
+        self.max_gpu_memory_percent = 90  # Maximum GPU memory usage percentage
+
+        self._stop_event = threading.Event()
+        self._monitor_thread = None
+
+        # Track resource usage over time
+        self.resource_history = {
+            "timestamps": [],
+            "cpu_percent": [],
+            "memory_percent": [],
+            "gpu_utilization": [],
+            "gpu_memory_percent": []
+        }
+
+        # Initialize GPU monitoring if available
+        if HAS_GPU:
+            try:
+                nvidia_smi.nvmlInit()
+                self.gpu_count = torch.cuda.device_count()
+                logger.info(f"Initialized GPU monitoring with {self.gpu_count} devices")
+            except Exception as e:
+                logger.warning(f"Failed to initialize NVIDIA SMI: {str(e)}")
+                self.gpu_count = 0
+        else:
+            self.gpu_count = 0
+
+        logger.info("Initialized ResourceManager")
+
+    def start_monitoring(self):
+        """Start the resource monitoring thread"""
+        if self._monitor_thread is not None and self._monitor_thread.is_alive():
+            logger.warning("Resource monitoring already running")
+            return
+
+        self._stop_event.clear()
+        self._monitor_thread = threading.Thread(
+            target=self._monitor_resources,
+            daemon=True
+        )
+        self._monitor_thread.start()
+        logger.info("Started resource monitoring thread")
+
+    def stop_monitoring(self):
+        """Stop the resource monitoring thread"""
+        if self._monitor_thread is not None:
+            self._stop_event.set()
+            self._monitor_thread.join(timeout=5)
+            self._monitor_thread = None
+            logger.info("Stopped resource monitoring thread")
+
+    def _monitor_resources(self):
+        """Monitor system resources in a loop"""
+        while not self._stop_event.is_set():
+            try:
+                # Get current resource usage
+                cpu_percent = psutil.cpu_percent(interval=1)
+                memory = psutil.virtual_memory()
+                memory_percent = memory.percent
+
+                # GPU monitoring
+                gpu_utilization = 0
+                gpu_memory_percent = 0
+
+                if self.gpu_count > 0:
+                    gpu_util_sum = 0
+                    gpu_mem_percent_sum = 0
+
+                    for i in range(self.gpu_count):
+                        try:
+                            handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
+                            util = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
+                            mem_info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
+
+                            gpu_util_sum += util.gpu
+                            gpu_mem_percent = (mem_info.used / mem_info.total) * 100
+                            gpu_mem_percent_sum += gpu_mem_percent
+                        except Exception as e:
+                            logger.error(f"Error getting GPU {i} stats: {str(e)}")
+
+                    if self.gpu_count > 0:
+                        gpu_utilization = gpu_util_sum / self.gpu_count
+                        gpu_memory_percent = gpu_mem_percent_sum / self.gpu_count
+
+                # Record history (keep last 60 samples)
+                now = time.time()
+                self.resource_history["timestamps"].append(now)
+                self.resource_history["cpu_percent"].append(cpu_percent)
+                self.resource_history["memory_percent"].append(memory_percent)
+                self.resource_history["gpu_utilization"].append(gpu_utilization)
+                self.resource_history["gpu_memory_percent"].append(gpu_memory_percent)
+
+                # Trim history to last 60 samples
+                max_history = 60
+                if len(self.resource_history["timestamps"]) > max_history:
+                    for key in self.resource_history:
+                        self.resource_history[key] = self.resource_history[key][-max_history:]
+
+                # Check if we need to unload models
+                self._check_resource_constraints(memory_percent, gpu_memory_percent)
+
+                # Log current usage
+                logger.debug(
+                    f"Resource usage - CPU: {cpu_percent:.1f}%, Memory: {memory_percent:.1f}%, "
+                    f"GPU: {gpu_utilization:.1f}%, GPU Memory: {gpu_memory_percent:.1f}%"
+                )
+
+                # Wait for next check
+                self._stop_event.wait(self.monitoring_interval)
+
+            except Exception as e:
+                logger.error(f"Error in resource monitoring: {str(e)}")
+                # Wait a bit before retrying
+                self._stop_event.wait(5)
+
+    def _check_resource_constraints(self, memory_percent, gpu_memory_percent):
+        """Check if we need to unload models due to resource constraints"""
+        # Check memory usage
+        if memory_percent > self.max_memory_percent:
+            logger.warning(
+                f"Memory usage ({memory_percent:.1f}%) exceeds threshold ({self.max_memory_percent}%). "
+                "Unloading least used model."
+            )
+            # This would trigger model unloading
+            # self.registry._evict_least_used_model()
+
+        # Check GPU memory usage
+        if HAS_GPU and gpu_memory_percent > self.max_gpu_memory_percent:
+            logger.warning(
+                f"GPU memory usage ({gpu_memory_percent:.1f}%) exceeds threshold ({self.max_gpu_memory_percent}%). "
+                "Unloading least used model."
+            )
+            # This would trigger model unloading
+            # self.registry._evict_least_used_model()
+
+    def get_resource_usage(self) -> Dict[str, Any]:
+        """Get current resource usage stats"""
+        try:
+            cpu_percent = psutil.cpu_percent(interval=0.1)
+            memory = psutil.virtual_memory()
+            memory_percent = memory.percent
+
+            result = {
+                "cpu_percent": cpu_percent,
+                "memory_total_gb": memory.total / (1024**3),
+                "memory_available_gb": memory.available / (1024**3),
+                "memory_percent": memory_percent,
+                "gpus": []
+            }
+
+            # GPU stats
+            if HAS_GPU:
+                for i in range(self.gpu_count):
+                    try:
+                        handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i)
+                        name = nvidia_smi.nvmlDeviceGetName(handle)
+                        util = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
+                        mem_info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle)
+                        temp = nvidia_smi.nvmlDeviceGetTemperature(
+                            handle, nvidia_smi.NVML_TEMPERATURE_GPU
+                        )
+
+                        result["gpus"].append({
+                            "index": i,
+                            "name": name,
+                            "utilization_percent": util.gpu,
+                            "memory_total_gb": mem_info.total / (1024**3),
+                            "memory_used_gb": mem_info.used / (1024**3),
+                            "memory_percent": (mem_info.used / mem_info.total) * 100,
+                            "temperature_c": temp
+                        })
+                    except Exception as e:
+                        logger.error(f"Error getting GPU {i} stats: {str(e)}")
+
+            return result
+        except Exception as e:
+            logger.error(f"Error getting resource usage: {str(e)}")
+            return {"error": str(e)}
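The class above only samples and logs; the eviction hooks are left commented out. A minimal usage sketch follows, in which DummyRegistry is a hypothetical stand-in for the model registry object the constructor expects (the real one lives in isa_model/core/model_registry.py, not shown here):

import time
from isa_model.core.resource_manager import ResourceManager

class DummyRegistry:
    """Hypothetical placeholder for isa_model's model registry."""
    pass

manager = ResourceManager(model_registry=DummyRegistry(), monitoring_interval=10)
manager.start_monitoring()           # spawns the daemon monitoring thread
time.sleep(30)                       # let a few samples accumulate
print(manager.get_resource_usage())  # CPU/memory snapshot, plus per-GPU stats if available
manager.stop_monitoring()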
File without changes
File without changes
File without changes
isa_model/deployment/mlflow_gateway/start_gateway.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+"""
+MLflow Gateway starter script.
+Replaces the custom adapter with industry-standard MLflow Gateway.
+
+Usage:
+    python -m isa_model.deployment.mlflow_gateway.start_gateway
+"""
+
+import os
+import sys
+import logging
+import subprocess
+from pathlib import Path
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+
+def start_mlflow_gateway():
+    """Start MLflow Gateway with our configuration."""
+
+    # Get the directory containing this script
+    script_dir = Path(__file__).parent
+    config_file = script_dir / "gateway_config.yaml"
+
+    if not config_file.exists():
+        logger.error(f"Gateway config file not found: {config_file}")
+        sys.exit(1)
+
+    # Set environment variables
+    os.environ["MLFLOW_GATEWAY_CONFIG_PATH"] = str(config_file)
+
+    # MLflow Gateway command
+    cmd = [
+        "mlflow", "gateway", "start",
+        "--config-path", str(config_file),
+        "--host", "0.0.0.0",
+        "--port", "8000"
+    ]
+
+    logger.info("🚀 Starting MLflow Gateway...")
+    logger.info(f"📁 Config file: {config_file}")
+    logger.info(f"🌐 Server: http://localhost:8000")
+    logger.info(f"📚 Docs: http://localhost:8000/docs")
+
+    try:
+        # Start the gateway
+        subprocess.run(cmd, check=True)
+    except KeyboardInterrupt:
+        logger.info("MLflow Gateway stopped by user")
+    except subprocess.CalledProcessError as e:
+        logger.error(f"MLflow Gateway failed to start: {e}")
+        sys.exit(1)
+    except Exception as e:
+        logger.error(f"Unexpected error: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    start_mlflow_gateway()
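Note that the script shells out to the `mlflow` CLI and looks for a gateway_config.yaml next to the module; that config file is not part of the wheel's file list above, so it would have to be provided separately. A minimal sketch of invoking the starter programmatically, assuming mlflow is installed and the config file exists:

# Sketch only: requires the `mlflow` CLI on PATH and a gateway_config.yaml
# alongside start_gateway.py (not included in this package).
from isa_model.deployment.mlflow_gateway.start_gateway import start_mlflow_gateway

start_mlflow_gateway()  # blocks on the gateway subprocess; Ctrl+C stops it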
isa_model/deployment/unified_multimodal_client.py
@@ -0,0 +1,341 @@
+#!/usr/bin/env python3
+"""
+Unified Multimodal Client
+
+This client provides a unified interface to different model types and modalities,
+abstracting away the complexity of different backends and deployment strategies.
+
+Features:
+- Text generation (chat completion)
+- Image generation
+- Audio transcription
+- Embeddings
+
+Usage:
+    from isa_model.deployment.unified_multimodal_client import UnifiedClient
+
+    client = UnifiedClient()
+
+    # Text generation
+    response = client.chat_completion("What is MLflow?")
+
+    # Image generation
+    image_data = client.generate_image("A beautiful mountain landscape")
+
+    # Audio transcription
+    transcription = client.transcribe_audio(audio_base64)
+
+    # Embeddings
+    embeddings = client.get_embeddings("This is a test sentence.")
+"""
+
+import os
+import json
+import base64
+import requests
+import tempfile
+from typing import List, Dict, Any, Optional, Union
+from dataclasses import dataclass
+from PIL import Image
+import io
+
+@dataclass
+class DeploymentConfig:
+    """Deployment configuration for a model type"""
+    name: str
+    endpoint: str
+    api_key: Optional[str] = None
+
+class UnifiedClient:
+    """Unified client for multimodal AI models"""
+
+    def __init__(self, adapter_url: str = "http://localhost:8300"):
+        """Initialize the client with the adapter URL"""
+        self.adapter_url = adapter_url
+
+        # Configure deployment endpoints - directly to multimodal adapter
+        self.deployments = {
+            "text": DeploymentConfig(
+                name="default",
+                endpoint=f"{adapter_url}/v1/chat/completions"
+            ),
+            "image": DeploymentConfig(
+                name="default",
+                endpoint=f"{adapter_url}/v1/images/generations"
+            ),
+            "audio": DeploymentConfig(
+                name="default",
+                endpoint=f"{adapter_url}/v1/audio/transcriptions"
+            ),
+            "embeddings": DeploymentConfig(
+                name="default",
+                endpoint=f"{adapter_url}/v1/embeddings"
+            )
+        }
+
+    def _make_request(self,
+                      deployment_type: str,
+                      payload: Dict[str, Any],
+                      files: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+        """Make a request to the specified deployment type"""
+        if deployment_type not in self.deployments:
+            raise ValueError(f"Unsupported deployment type: {deployment_type}")
+
+        deployment = self.deployments[deployment_type]
+
+        headers = {
+            "Content-Type": "application/json"
+        }
+
+        if deployment.api_key:
+            headers["Authorization"] = f"Bearer {deployment.api_key}"
+
+        try:
+            if files:
+                # For multipart/form-data requests
+                response = requests.post(
+                    deployment.endpoint,
+                    data=payload,
+                    files=files
+                )
+            else:
+                # Ensure model is included in the payload
+                if "model" not in payload:
+                    payload["model"] = deployment.name
+
+                # For JSON requests
+                response = requests.post(
+                    deployment.endpoint,
+                    json=payload,
+                    headers=headers
+                )
+
+            response.raise_for_status()
+            return response.json()
+
+        except Exception as e:
+            print(f"Error calling {deployment_type} endpoint: {str(e)}")
+            print(f"Response: {response.text if 'response' in locals() else 'No response'}")
+            raise
+
+    def chat_completion(self,
+                        prompt: str,
+                        system_prompt: Optional[str] = None,
+                        max_tokens: int = 100,
+                        temperature: float = 0.7) -> str:
+        """Generate a chat completion response"""
+        messages = []
+
+        if system_prompt:
+            messages.append({
+                "role": "system",
+                "content": system_prompt
+            })
+
+        messages.append({
+            "role": "user",
+            "content": prompt
+        })
+
+        payload = {
+            "messages": messages,
+            "max_tokens": max_tokens,
+            "temperature": temperature
+        }
+
+        response = self._make_request("text", payload)
+
+        if "choices" in response and len(response["choices"]) > 0:
+            return response["choices"][0]["message"]["content"]
+        else:
+            return "Error: No response generated"
+
+    def generate_image(self,
+                       prompt: str,
+                       n: int = 1,
+                       size: str = "1024x1024") -> str:
+        """Generate an image from a text prompt"""
+        payload = {
+            "prompt": prompt,
+            "n": n,
+            "size": size
+        }
+
+        response = self._make_request("image", payload)
+
+        if "data" in response and len(response["data"]) > 0:
+            # Return the base64 data URL
+            return response["data"][0]["url"]
+        else:
+            return "Error: No image generated"
+
+    def save_image(self, image_data_url: str, output_path: str) -> None:
+        """Save a base64 image data URL to a file"""
+        if image_data_url.startswith("data:image"):
+            # Extract the base64 part from the data URL
+            base64_data = image_data_url.split(",")[1]
+
+            # Decode the base64 data
+            image_data = base64.b64decode(base64_data)
+
+            # Save the image
+            with open(output_path, "wb") as f:
+                f.write(image_data)
+
+            print(f"Image saved to {output_path}")
+        else:
+            raise ValueError("Invalid image data URL format")
+
+    def transcribe_audio(self,
+                         audio_data: Union[str, bytes],
+                         language: str = "en") -> str:
+        """
+        Transcribe audio to text
+
+        Parameters:
+        - audio_data: Either a base64 encoded string or raw bytes
+        - language: Language code
+
+        Returns:
+        - Transcribed text
+        """
+        # Convert bytes to base64 if needed
+        if isinstance(audio_data, bytes):
+            audio_base64 = base64.b64encode(audio_data).decode("utf-8")
+        else:
+            # Assume it's already base64 encoded
+            audio_base64 = audio_data
+
+        payload = {
+            "file": audio_base64,
+            "language": language
+        }
+
+        response = self._make_request("audio", payload)
+
+        if "text" in response:
+            return response["text"]
+        else:
+            return "Error: No transcription generated"
+
+    def transcribe_audio_file(self,
+                              file_path: str,
+                              language: str = "en") -> str:
+        """
+        Transcribe an audio file to text
+
+        Parameters:
+        - file_path: Path to the audio file
+        - language: Language code
+
+        Returns:
+        - Transcribed text
+        """
+        with open(file_path, "rb") as f:
+            audio_data = f.read()
+
+        return self.transcribe_audio(audio_data, language)
+
+    def get_embeddings(self,
+                       text: Union[str, List[str]]) -> List[List[float]]:
+        """
+        Get embeddings for text or a list of texts
+
+        Parameters:
+        - text: Either a single string or a list of strings
+
+        Returns:
+        - List of embedding vectors
+        """
+        payload = {
+            "input": text
+        }
+
+        response = self._make_request("embeddings", payload)
+
+        if "data" in response:
+            return [item["embedding"] for item in response["data"]]
+        else:
+            return []
+
+    def similarity(self, text1: str, text2: str) -> float:
+        """
+        Calculate the cosine similarity between two texts
+
+        Parameters:
+        - text1: First text
+        - text2: Second text
+
+        Returns:
+        - Cosine similarity (0-1)
+        """
+        import numpy as np
+
+        # Get embeddings for both texts
+        embeddings = self.get_embeddings([text1, text2])
+
+        if len(embeddings) != 2:
+            raise ValueError("Failed to get embeddings for both texts")
+
+        # Calculate cosine similarity
+        embedding1 = np.array(embeddings[0])
+        embedding2 = np.array(embeddings[1])
+
+        cos_sim = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2))
+        return float(cos_sim)
+
+    def health_check(self) -> bool:
+        """Check if the adapter is healthy"""
+        try:
+            response = requests.get(f"{self.adapter_url}/health")
+            return response.status_code == 200
+        except Exception as e:
+            print(f"Health check failed: {str(e)}")
+            return False
+
+if __name__ == "__main__":
+    # Test the client
+    client = UnifiedClient()
+
+    print("\n===== Unified Multimodal Client Demo =====")
+
+    # Check health
+    if not client.health_check():
+        print("Adapter is not healthy. Make sure it's running.")
+        exit(1)
+
+    # Test chat completion
+    print("\nTesting chat completion...")
+    response = client.chat_completion(
+        "What are the key benefits of MLflow?",
+        system_prompt="You are a helpful AI assistant specializing in machine learning.",
+        max_tokens=150
+    )
+    print(f"\nResponse: {response}")
+
+    # Test embeddings
+    print("\nTesting embeddings...")
+    embeddings = client.get_embeddings("What is MLflow?")
+    print(f"Embedding dimensionality: {len(embeddings[0])}")
+    print(f"First 5 values: {embeddings[0][:5]}")
+
+    # Test similarity
+    print("\nTesting similarity...")
+    similarity = client.similarity(
+        "MLflow is a platform for managing machine learning workflows.",
+        "MLflow helps data scientists track experiments and deploy models."
+    )
+    print(f"Similarity: {similarity:.4f}")
+
+    # Test image generation
+    print("\nTesting image generation...")
+    image_url = client.generate_image("A beautiful mountain landscape")
+    print(f"Image URL: {image_url[:30]}...")
+
+    # Test audio transcription
+    print("\nTesting audio transcription...")
+    dummy_audio = base64.b64encode(b"dummy audio data").decode("utf-8")
+    transcription = client.transcribe_audio(dummy_audio)
+    print(f"Transcription: {transcription}")
+
+    print("\n===== Demo Complete =====")
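Beyond the calls shown in the module docstring, the client also exposes file-oriented helpers and a health check. A short sketch using them (file paths are placeholders; the multimodal adapter must already be serving on localhost:8300):

from isa_model.deployment.unified_multimodal_client import UnifiedClient

client = UnifiedClient()  # defaults to http://localhost:8300
if client.health_check():
    # Transcribe a local file; "meeting.wav" is a placeholder path
    text = client.transcribe_audio_file("meeting.wav", language="en")
    # Generate an image and persist the returned data:image base64 URL to disk
    image_url = client.generate_image("A beautiful mountain landscape")
    client.save_image(image_url, "landscape.png")
    print(text)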
isa_model/inference/__init__.py
@@ -0,0 +1,11 @@
+"""
+Inference module for isA_Model
+
+File: isa_model/inference/__init__.py
+This module provides the main inference components for the IsA Model system.
+"""
+
+from .ai_factory import AIFactory
+from .base import ModelType, Capability, RoutingStrategy
+
+__all__ = ["AIFactory", "ModelType", "Capability", "RoutingStrategy"]
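With these re-exports, downstream code can import the factory and enums from the package root rather than from their defining modules (their APIs are not shown in this diff):

from isa_model.inference import AIFactory, ModelType, Capability, RoutingStrategy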