isa-model 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {isa_model-0.3.0 → isa_model-0.3.2}/PKG-INFO +1 -1
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/core/model_manager.py +69 -4
- isa_model-0.3.2/isa_model/inference/ai_factory.py +550 -0
- isa_model-0.3.2/isa_model/inference/billing_tracker.py +406 -0
- isa_model-0.3.2/isa_model/inference/providers/base_provider.py +77 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model-0.3.2/isa_model/inference/providers/openai_provider.py +130 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/providers/replicate_provider.py +42 -30
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model-0.3.2/isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model-0.3.2/isa_model/inference/services/audio/openai_stt_service.py +252 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/audio/openai_tts_service.py +48 -9
- isa_model-0.3.2/isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/base_service.py +36 -1
- isa_model-0.3.2/isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model-0.3.2/isa_model/inference/services/llm/base_llm_service.py +140 -0
- isa_model-0.3.2/isa_model/inference/services/llm/llm_adapter.py +459 -0
- isa_model-0.3.2/isa_model/inference/services/llm/ollama_llm_service.py +233 -0
- isa_model-0.3.2/isa_model/inference/services/llm/openai_llm_service.py +199 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/vision/ollama_vision_service.py +11 -3
- isa_model-0.3.2/isa_model/inference/services/vision/openai_vision_service.py +314 -0
- isa_model-0.3.2/isa_model/inference/services/vision/replicate_image_gen_service.py +345 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model.egg-info/PKG-INFO +1 -1
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model.egg-info/SOURCES.txt +6 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/pyproject.toml +1 -1
- isa_model-0.3.2/tests/test_all_services.py +531 -0
- isa_model-0.3.0/isa_model/inference/ai_factory.py +0 -261
- isa_model-0.3.0/isa_model/inference/providers/base_provider.py +0 -30
- isa_model-0.3.0/isa_model/inference/providers/openai_provider.py +0 -101
- isa_model-0.3.0/isa_model/inference/services/embedding/openai_embed_service.py +0 -0
- isa_model-0.3.0/isa_model/inference/services/llm/base_llm_service.py +0 -244
- isa_model-0.3.0/isa_model/inference/services/llm/ollama_llm_service.py +0 -307
- isa_model-0.3.0/isa_model/inference/services/llm/openai_llm_service.py +0 -444
- isa_model-0.3.0/isa_model/inference/services/vision/openai_vision_service.py +0 -80
- isa_model-0.3.0/isa_model/inference/services/vision/replicate_image_gen_service.py +0 -317
- {isa_model-0.3.0 → isa_model-0.3.2}/MANIFEST.in +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/README.md +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/core/model_registry.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/core/model_storage.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/core/storage/hf_storage.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/core/storage/local_storage.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/core/storage/minio_storage.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/deployment/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/deployment/core/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/deployment/core/deployment_config.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/deployment/core/deployment_manager.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/deployment/core/isa_deployment_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/deployment/gpu_int8_ds8/app/server.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/deployment/gpu_int8_ds8/scripts/test_client.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/deployment/gpu_int8_ds8/scripts/test_client_os.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/eval/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/eval/benchmarks.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/eval/factory.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/eval/metrics.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/adapter/unified_api.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/base.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/providers/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/providers/ml_provider.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/providers/model_cache_manager.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/providers/triton_provider.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/audio/base_tts_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/embedding/base_embed_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/embedding/ollama_embed_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/llm/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/llm/triton_llm_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/ml/base_ml_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/ml/sklearn_ml_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/others/table_transformer_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/vision/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/vision/base_image_gen_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/vision/base_vision_service.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/services/vision/helpers/text_splitter.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/utils/conversion/bge_rerank_convert.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/utils/conversion/onnx_converter.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/inference/utils/conversion/torch_converter.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/scripts/inference_tracker.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/scripts/mlflow_manager.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/scripts/model_registry.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/scripts/start_mlflow.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/scripts/training_tracker.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/annotation/annotation_schema.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/annotation/processors/annotation_processor.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/annotation/storage/dataset_manager.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/annotation/storage/dataset_schema.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/annotation/tests/test_annotation_flow.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/annotation/tests/test_minio copy.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/annotation/tests/test_minio_upload.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/annotation/views/annotation_controller.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/cloud/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/cloud/job_orchestrator.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/cloud/runpod_trainer.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/cloud/storage_manager.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/core/__init__.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/core/config.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/core/dataset.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/core/trainer.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/core/utils.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model/training/factory.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model.egg-info/dependency_links.txt +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model.egg-info/requires.txt +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/isa_model.egg-info/top_level.txt +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/setup.cfg +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/setup.py +0 -0
- {isa_model-0.3.0 → isa_model-0.3.2}/tests/test_training_setup.py +0 -0
{isa_model-0.3.0 → isa_model-0.3.2}/isa_model/core/model_manager.py

```diff
@@ -2,7 +2,7 @@ from typing import Dict, Optional, List, Any
 import logging
 from pathlib import Path
 from huggingface_hub import hf_hub_download, snapshot_download
-from huggingface_hub.
+from huggingface_hub.errors import HfHubHTTPError
 from .model_storage import ModelStorage, LocalModelStorage
 from .model_registry import ModelRegistry, ModelType, ModelCapability
 
@@ -11,19 +11,81 @@ logger = logging.getLogger(__name__)
 class ModelManager:
     """Model management service for handling model downloads, versions, and caching"""
 
+    # Unified model pricing information (per 1M tokens)
+    MODEL_PRICING = {
+        # OpenAI Models
+        "openai": {
+            "gpt-4o-mini": {"input": 0.15, "output": 0.6},
+            "gpt-4.1-mini": {"input": 0.4, "output": 1.6},
+            "gpt-4.1-nano": {"input": 0.1, "output": 0.4},
+            "gpt-4o": {"input": 5.0, "output": 15.0},
+            "gpt-4-turbo": {"input": 10.0, "output": 30.0},
+            "gpt-4": {"input": 30.0, "output": 60.0},
+            "gpt-3.5-turbo": {"input": 0.5, "output": 1.5},
+            "text-embedding-3-small": {"input": 0.02, "output": 0.0},
+            "text-embedding-3-large": {"input": 0.13, "output": 0.0},
+            "whisper-1": {"input": 6.0, "output": 0.0},
+            "tts-1": {"input": 15.0, "output": 0.0},
+            "tts-1-hd": {"input": 30.0, "output": 0.0},
+        },
+        # Ollama Models (free local models)
+        "ollama": {
+            "llama3.2:3b-instruct-fp16": {"input": 0.0, "output": 0.0},
+            "llama3.2-vision:latest": {"input": 0.0, "output": 0.0},
+            "bge-m3": {"input": 0.0, "output": 0.0},
+        },
+        # Replicate Models
+        "replicate": {
+            "black-forest-labs/flux-schnell": {"input": 3.0, "output": 0.0},  # $3 per 1000 images
+            "black-forest-labs/flux-kontext-pro": {"input": 40.0, "output": 0.0},  # $0.04 per image = $40 per 1000 images
+            "meta/meta-llama-3-8b-instruct": {"input": 0.05, "output": 0.25},
+            "kokoro-82m": {"input": 0.0, "output": 0.4},  # ~$0.0004 per second
+            "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13": {"input": 0.0, "output": 0.4},
+        }
+    }
+
     def __init__(self, 
                  storage: Optional[ModelStorage] = None,
                  registry: Optional[ModelRegistry] = None):
         self.storage = storage or LocalModelStorage()
         self.registry = registry or ModelRegistry()
 
+    def get_model_pricing(self, provider: str, model_name: str) -> Dict[str, float]:
+        """Get pricing information for a model"""
+        return self.MODEL_PRICING.get(provider, {}).get(model_name, {"input": 0.0, "output": 0.0})
+
+    def calculate_cost(self, provider: str, model_name: str, input_tokens: int, output_tokens: int) -> float:
+        """Calculate the cost of a request"""
+        pricing = self.get_model_pricing(provider, model_name)
+        input_cost = (input_tokens / 1_000_000) * pricing["input"]
+        output_cost = (output_tokens / 1_000_000) * pricing["output"]
+        return input_cost + output_cost
+
+    def get_cheapest_model(self, provider: str, model_type: str = "llm") -> Optional[str]:
+        """Get the cheapest model for a provider"""
+        provider_models = self.MODEL_PRICING.get(provider, {})
+        if not provider_models:
+            return None
+
+        # Calculate each model's average cost (assuming a 1:1 input/output ratio)
+        cheapest_model = None
+        lowest_cost = float('inf')
+
+        for model_name, pricing in provider_models.items():
+            avg_cost = (pricing["input"] + pricing["output"]) / 2
+            if avg_cost < lowest_cost:
+                lowest_cost = avg_cost
+                cheapest_model = model_name
+
+        return cheapest_model
+
     async def get_model(self, 
                         model_id: str,
                         repo_id: str,
                         model_type: ModelType,
                         capabilities: List[ModelCapability],
                         revision: Optional[str] = None,
-                        force_download: bool = False) -> Path:
+                        force_download: bool = False) -> Optional[Path]:
         """
         Get model files, downloading if necessary
 
@@ -36,7 +98,7 @@ class ModelManager:
             force_download: Force re-download even if cached
 
         Returns:
-            Path to the model files
+            Path to the model files or None if failed
         """
         # Check if model is already downloaded
         if not force_download:
@@ -80,7 +142,10 @@ class ModelManager:
 
         except HfHubHTTPError as e:
             logger.error(f"Failed to download model {model_id}: {e}")
-
+            return None
+        except Exception as e:
+            logger.error(f"Unexpected error downloading model {model_id}: {e}")
+            return None
 
     async def list_models(self) -> List[Dict[str, Any]]:
         """List all downloaded models with their metadata"""
```
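The new pricing helpers reduce cost accounting to simple arithmetic: cost = (tokens ÷ 1,000,000) × per-million rate, summed over input and output. A minimal sketch of how the 0.3.2 helpers would be exercised, assuming `ModelManager`'s default `LocalModelStorage` and `ModelRegistry` construct cleanly in your environment:

```python
# Sketch only: exercises the pricing helpers added to ModelManager in 0.3.2.
from isa_model.core.model_manager import ModelManager

manager = ModelManager()  # assumes the default storage/registry initialize locally

# gpt-4o-mini: $0.15 per 1M input tokens, $0.60 per 1M output tokens
cost = manager.calculate_cost("openai", "gpt-4o-mini",
                              input_tokens=10_000, output_tokens=2_000)
# (10_000 / 1_000_000) * 0.15 + (2_000 / 1_000_000) * 0.60 = 0.0015 + 0.0012
print(f"${cost:.4f}")  # -> $0.0027

# Unknown models fall back to {"input": 0.0, "output": 0.0}
assert manager.calculate_cost("openai", "no-such-model", 1000, 1000) == 0.0

# All Ollama entries cost 0.0, so the first entry wins the cheapest-model scan
print(manager.get_cheapest_model("ollama"))  # -> llama3.2:3b-instruct-fp16
```

Note that `get_cheapest_model` ignores its `model_type` argument and compares a flat input/output average, so zero-cost local models always rank first.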
isa_model-0.3.2/isa_model/inference/ai_factory.py (new file)

```diff
@@ -0,0 +1,550 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Simplified AI Factory for creating inference services
+Uses the new service architecture with proper base classes and centralized API key management
+"""
+
+from typing import Dict, Type, Any, Optional, Tuple, List, TYPE_CHECKING, cast
+import logging
+from isa_model.inference.providers.base_provider import BaseProvider
+from isa_model.inference.services.base_service import BaseService
+from isa_model.inference.base import ModelType
+from isa_model.inference.services.vision.base_vision_service import BaseVisionService
+from isa_model.inference.services.vision.base_image_gen_service import BaseImageGenService
+
+if TYPE_CHECKING:
+    from isa_model.inference.services.audio.base_stt_service import BaseSTTService
+    from isa_model.inference.services.audio.base_tts_service import BaseTTSService
+
+logger = logging.getLogger(__name__)
+
+class AIFactory:
+    """
+    Simplified Factory for creating AI services with proper inheritance hierarchy
+    API key management is handled by individual providers
+    """
+
+    _instance = None
+    _is_initialized = False
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        """Initialize the AI Factory."""
+        if not self._is_initialized:
+            self._providers: Dict[str, Type[BaseProvider]] = {}
+            self._services: Dict[Tuple[str, ModelType], Type[BaseService]] = {}
+            self._cached_services: Dict[str, BaseService] = {}
+            self._initialize_services()
+            AIFactory._is_initialized = True
+
+    def _initialize_services(self):
+        """Initialize available providers and services"""
+        try:
+            # Register Ollama services
+            self._register_ollama_services()
+
+            # Register OpenAI services
+            self._register_openai_services()
+
+            # Register Replicate services
+            self._register_replicate_services()
+
+            logger.info("AI Factory initialized with centralized provider API key management")
+
+        except Exception as e:
+            logger.error(f"Error initializing services: {e}")
+            logger.warning("Some services may not be available")
+
+    def _register_ollama_services(self):
+        """Register Ollama provider and services"""
+        try:
+            from isa_model.inference.providers.ollama_provider import OllamaProvider
+            from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService
+            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
+            from isa_model.inference.services.vision.ollama_vision_service import OllamaVisionService
+
+            self.register_provider('ollama', OllamaProvider)
+            self.register_service('ollama', ModelType.LLM, OllamaLLMService)
+            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
+            self.register_service('ollama', ModelType.VISION, OllamaVisionService)
+
+            logger.info("Ollama services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"Ollama services not available: {e}")
+
+    def _register_openai_services(self):
+        """Register OpenAI provider and services"""
+        try:
+            from isa_model.inference.providers.openai_provider import OpenAIProvider
+            from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
+            from isa_model.inference.services.audio.openai_tts_service import OpenAITTSService
+            from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+            from isa_model.inference.services.embedding.openai_embed_service import OpenAIEmbedService
+            from isa_model.inference.services.vision.openai_vision_service import OpenAIVisionService
+
+            self.register_provider('openai', OpenAIProvider)
+            self.register_service('openai', ModelType.LLM, OpenAILLMService)
+            self.register_service('openai', ModelType.AUDIO, OpenAITTSService)
+            self.register_service('openai', ModelType.EMBEDDING, OpenAIEmbedService)
+            self.register_service('openai', ModelType.VISION, OpenAIVisionService)
+
+            logger.info("OpenAI services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"OpenAI services not available: {e}")
+
+    def _register_replicate_services(self):
+        """Register Replicate provider and services"""
+        try:
+            from isa_model.inference.providers.replicate_provider import ReplicateProvider
+            from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateImageGenService
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
+
+            self.register_provider('replicate', ReplicateProvider)
+            self.register_service('replicate', ModelType.VISION, ReplicateImageGenService)
+            self.register_service('replicate', ModelType.AUDIO, ReplicateTTSService)
+
+            logger.info("Replicate services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"Replicate services not available: {e}")
+
+    def register_provider(self, name: str, provider_class: Type[BaseProvider]) -> None:
+        """Register an AI provider"""
+        self._providers[name] = provider_class
+
+    def register_service(self, provider_name: str, model_type: ModelType,
+                         service_class: Type[BaseService]) -> None:
+        """Register a service type with its provider"""
+        self._services[(provider_name, model_type)] = service_class
+
+    def create_service(self, provider_name: str, model_type: ModelType,
+                       model_name: str, config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Create a service instance with provider-managed configuration"""
+        try:
+            cache_key = f"{provider_name}_{model_type}_{model_name}"
+
+            if cache_key in self._cached_services:
+                return self._cached_services[cache_key]
+
+            # Get provider and service classes
+            provider_class = self._providers.get(provider_name)
+            service_class = self._services.get((provider_name, model_type))
+
+            if not provider_class:
+                raise ValueError(f"No provider registered for '{provider_name}'")
+
+            if not service_class:
+                raise ValueError(
+                    f"No service registered for provider '{provider_name}' and model type '{model_type}'"
+                )
+
+            # Create provider with user config (provider handles .env loading)
+            provider = provider_class(config=config)
+            service = service_class(provider=provider, model_name=model_name)
+
+            self._cached_services[cache_key] = service
+            return service
+
+        except Exception as e:
+            logger.error(f"Error creating service: {e}")
+            raise
+
+    # Convenient methods for common services with updated defaults
+    def get_llm_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Get a LLM service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-nano", Ollama="llama3.2:3b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+                    Can include: streaming=True/False, temperature, max_tokens, etc.
+
+        Returns:
+            LLM service instance
+        """
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-nano"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2:3b-instruct-fp16"
+            final_provider = provider
+        else:
+            # Default provider selection - OpenAI with cheapest model
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-nano"
+            else:
+                final_model_name = model_name or "llama3.2:3b-instruct-fp16"
+
+        return self.create_service(final_provider, ModelType.LLM, final_model_name, config)
+
+    def get_embedding_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                              config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Get an embedding service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Embedding service instance
+        """
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "text-embedding-3-small"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "bge-m3"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "text-embedding-3-small"
+            else:
+                final_model_name = model_name or "bge-m3"
+
+        return self.create_service(final_provider, ModelType.EMBEDDING, final_model_name, config)
+
+    def get_vision_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                           config: Optional[Dict[str, Any]] = None) -> BaseVisionService:
+        """
+        Get a vision service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-mini", Ollama="gemma3:4b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Vision service instance
+        """
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-mini"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2-vision:latest"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-mini"
+            else:
+                final_model_name = model_name or "llama3.2-vision:latest"
+
+        return cast(BaseVisionService, self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_image_generation_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                                     config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
+        """
+        Get an image generation service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "black-forest-labs/flux-schnell" for production)
+            provider: Provider name (defaults to 'replicate')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Image generation service instance
+        """
+        # Set defaults based on provider
+        final_provider = provider or "replicate"
+        if final_provider == "replicate":
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        else:
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_img(self, type: str = "t2i", model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
+        """
+        Get an image generation service with type-specific defaults
+
+        Args:
+            type: Image generation type:
+                - "t2i" (text-to-image): Uses flux-schnell ($3 per 1000 images)
+                - "i2i" (image-to-image): Uses flux-kontext-pro ($0.04 per image)
+            model_name: Optional model name override
+            provider: Provider name (defaults to 'replicate')
+            config: Optional configuration dictionary
+
+        Returns:
+            Image generation service instance
+
+        Usage:
+            # Text-to-image (default)
+            img_service = AIFactory().get_img()
+            img_service = AIFactory().get_img(type="t2i")
+
+            # Image-to-image
+            img_service = AIFactory().get_img(type="i2i")
+
+            # Custom model
+            img_service = AIFactory().get_img(type="t2i", model_name="custom-model")
+        """
+        # Set defaults based on type
+        final_provider = provider or "replicate"
+
+        if type == "t2i":
+            # Text-to-image: flux-schnell
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        elif type == "i2i":
+            # Image-to-image: flux-kontext-pro
+            final_model_name = model_name or "black-forest-labs/flux-kontext-pro"
+        else:
+            raise ValueError(f"Unknown image generation type: {type}. Use 't2i' or 'i2i'")
+
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_audio_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                          config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Get an audio service instance (TTS) with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="tts-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Audio service instance
+        """
+        # Set defaults based on provider
+        final_provider = provider or "openai"
+        if final_provider == "openai":
+            final_model_name = model_name or "tts-1"
+        else:
+            final_model_name = model_name or "tts-1"
+
+        return self.create_service(final_provider, ModelType.AUDIO, final_model_name, config)
+
+    def get_tts_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
+        """
+        Get a Text-to-Speech service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+        """
+        # Set defaults based on provider
+        if provider == "replicate":
+            model_name = model_name or "kokoro-82m"
+        elif provider == "openai":
+            model_name = model_name or "tts-1"
+        else:
+            # Default provider selection
+            provider = provider or "replicate"
+            if provider == "replicate":
+                model_name = model_name or "kokoro-82m"
+            else:
+                model_name = model_name or "tts-1"
+
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "tts-1"
+
+        if provider == "replicate":
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
+            from isa_model.inference.providers.replicate_provider import ReplicateProvider
+
+            # Use full model name for Replicate
+            if model_name == "kokoro-82m":
+                model_name = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"
+
+            provider_instance = ReplicateProvider(config=config)
+            return ReplicateTTSService(provider=provider_instance, model_name=model_name)
+        else:
+            return cast('BaseTTSService', self.get_audio_service(model_name, provider, config))
+
+    def get_stt_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
+        """
+        Get a Speech-to-Text service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            STT service instance
+        """
+        # Set defaults based on provider
+        provider = provider or "openai"
+        if provider == "openai":
+            model_name = model_name or "whisper-1"
+
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "whisper-1"
+
+        from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+        from isa_model.inference.providers.openai_provider import OpenAIProvider
+
+        # Create provider and service directly with config
+        provider_instance = OpenAIProvider(config=config)
+        return OpenAISTTService(provider=provider_instance, model_name=model_name)
+
+    def get_available_services(self) -> Dict[str, List[str]]:
+        """Get information about available services"""
+        services = {}
+        for (provider, model_type), service_class in self._services.items():
+            if provider not in services:
+                services[provider] = []
+            services[provider].append(f"{model_type.value}: {service_class.__name__}")
+        return services
+
+    def clear_cache(self):
+        """Clear the service cache"""
+        self._cached_services.clear()
+        logger.info("Service cache cleared")
+
+    @classmethod
+    def get_instance(cls) -> 'AIFactory':
+        """Get the singleton instance"""
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._instance
+
+    # Alias method for cleaner API
+    def get_llm(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Alias for get_llm_service with cleaner naming
+
+        Usage:
+            llm = AIFactory().get_llm()  # Uses gpt-4.1-nano by default
+            llm = AIFactory().get_llm(model_name="llama3.2", provider="ollama")
+            llm = AIFactory().get_llm(model_name="gpt-4.1-mini", provider="openai", config={"streaming": True})
+        """
+        return self.get_llm_service(model_name, provider, config)
+
+    def get_embed(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                  config: Optional[Dict[str, Any]] = None) -> 'BaseEmbedService':
+        """
+        Get embedding service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Embedding service instance
+
+        Usage:
+            # Default (OpenAI text-embedding-3-small)
+            embed = AIFactory().get_embed()
+
+            # Custom model
+            embed = AIFactory().get_embed(model_name="text-embedding-3-large", provider="openai")
+
+            # Development (Ollama)
+            embed = AIFactory().get_embed(provider="ollama")
+        """
+        return self.get_embedding_service(model_name, provider, config)
+
+    def get_stt(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
+        """
+        Get Speech-to-Text service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            STT service instance
+
+        Usage:
+            # Default (OpenAI whisper-1)
+            stt = AIFactory().get_stt()
+
+            # Custom configuration
+            stt = AIFactory().get_stt(model_name="whisper-1", provider="openai")
+        """
+        return self.get_stt_service(model_name, provider, config)
+
+    def get_tts(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
+        """
+        Get Text-to-Speech service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+
+        Usage:
+            # Default (Replicate kokoro-82m)
+            tts = AIFactory().get_tts()
+
+            # Development (OpenAI tts-1)
+            tts = AIFactory().get_tts(provider="openai")
+
+            # Custom model
+            tts = AIFactory().get_tts(model_name="tts-1-hd", provider="openai")
+        """
+        return self.get_tts_service(model_name, provider, config)
+
+    def get_vision_model(self, model_name: str, provider: str,
+                         config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Alias for get_vision_service and get_image_generation_service"""
+        if provider == "replicate":
+            return self.get_image_generation_service(model_name, provider, config)
+        else:
+            return self.get_vision_service(model_name, provider, config)
+
+    def get_vision(
+        self,
+        model_name: Optional[str] = None,
+        provider: Optional[str] = None,
+        config: Optional[Dict[str, Any]] = None
+    ) -> 'BaseVisionService':
+        """
+        Get vision service with automatic defaults
+
+        Args:
+            model_name: Model name (default: gpt-4.1-nano)
+            provider: Provider name (default: openai)
+            config: Optional configuration override
+
+        Returns:
+            Vision service instance
+        """
+        # Set defaults
+        if provider is None:
+            provider = "openai"
+        if model_name is None:
+            model_name = "gpt-4.1-nano"
+
+        return self.get_vision_service(
+            model_name=model_name,
+            provider=provider,
+            config=config
+        )
```
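Taken together, the new factory gives one construction path for every modality. Below is a construction-only usage sketch based on the defaults documented above; it assumes provider credentials are available via .env (as the providers' config loading expects) and that the optional provider dependencies import cleanly. The returned services' call methods live in the service classes elsewhere in this diff and are not exercised here.

```python
# Construction-only sketch of the 0.3.2 AIFactory surface; no network calls are
# made until a service method is invoked. Assumes API keys are present in .env.
from isa_model.inference.ai_factory import AIFactory

factory = AIFactory.get_instance()  # AIFactory() would return the same singleton

llm   = factory.get_llm()                     # openai / gpt-4.1-nano
embed = factory.get_embed(provider="ollama")  # ollama / bge-m3 (free, local)
t2i   = factory.get_img(type="t2i")           # replicate / flux-schnell
i2i   = factory.get_img(type="i2i")           # replicate / flux-kontext-pro
tts   = factory.get_tts(provider="openai")    # openai / tts-1
stt   = factory.get_stt()                     # openai / whisper-1

# create_service() caches by (provider, model_type, model_name), so repeated
# lookups return the same instance until clear_cache() is called.
assert factory.get_llm() is llm
print(factory.get_available_services())
```

One caveat worth knowing: `get_stt_service` and the Replicate branch of `get_tts_service` construct their services directly rather than through `create_service`, so those instances bypass the service cache.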