isa-model 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package as they appear in their respective public registries. It is provided for informational purposes only.
- isa_model/__init__.py +1 -1
- isa_model/core/model_manager.py +69 -4
- isa_model/core/storage/hf_storage.py +419 -0
- isa_model/deployment/__init__.py +52 -0
- isa_model/deployment/core/__init__.py +34 -0
- isa_model/deployment/core/deployment_config.py +356 -0
- isa_model/deployment/core/deployment_manager.py +549 -0
- isa_model/deployment/core/isa_deployment_service.py +401 -0
- isa_model/eval/factory.py +381 -140
- isa_model/inference/ai_factory.py +427 -236
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ml_provider.py +50 -0
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +149 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/base_embed_service.py +112 -0
- isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/__init__.py +2 -0
- isa_model/inference/services/llm/base_llm_service.py +158 -86
- isa_model/inference/services/llm/llm_adapter.py +414 -0
- isa_model/inference/services/llm/ollama_llm_service.py +252 -63
- isa_model/inference/services/llm/openai_llm_service.py +231 -93
- isa_model/inference/services/llm/triton_llm_service.py +481 -0
- isa_model/inference/services/ml/base_ml_service.py +78 -0
- isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
- isa_model/inference/services/vision/__init__.py +3 -3
- isa_model/inference/services/vision/base_image_gen_service.py +161 -0
- isa_model/inference/services/vision/base_vision_service.py +177 -0
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +151 -17
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
- isa_model/training/__init__.py +62 -32
- isa_model/training/cloud/__init__.py +22 -0
- isa_model/training/cloud/job_orchestrator.py +402 -0
- isa_model/training/cloud/runpod_trainer.py +454 -0
- isa_model/training/cloud/storage_manager.py +482 -0
- isa_model/training/core/__init__.py +23 -0
- isa_model/training/core/config.py +181 -0
- isa_model/training/core/dataset.py +222 -0
- isa_model/training/core/trainer.py +720 -0
- isa_model/training/core/utils.py +213 -0
- isa_model/training/factory.py +229 -198
- isa_model-0.3.1.dist-info/METADATA +465 -0
- isa_model-0.3.1.dist-info/RECORD +91 -0
- isa_model/core/model_router.py +0 -226
- isa_model/core/model_version.py +0 -0
- isa_model/core/resource_manager.py +0 -202
- isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
- isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
- isa_model/training/engine/llama_factory/__init__.py +0 -39
- isa_model/training/engine/llama_factory/config.py +0 -115
- isa_model/training/engine/llama_factory/data_adapter.py +0 -284
- isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
- isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
- isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
- isa_model/training/engine/llama_factory/factory.py +0 -331
- isa_model/training/engine/llama_factory/rl.py +0 -254
- isa_model/training/engine/llama_factory/trainer.py +0 -171
- isa_model/training/image_model/configs/create_config.py +0 -37
- isa_model/training/image_model/configs/create_flux_config.py +0 -26
- isa_model/training/image_model/configs/create_lora_config.py +0 -21
- isa_model/training/image_model/prepare_massed_compute.py +0 -97
- isa_model/training/image_model/prepare_upload.py +0 -17
- isa_model/training/image_model/raw_data/create_captions.py +0 -16
- isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
- isa_model/training/image_model/raw_data/pre_processing.py +0 -200
- isa_model/training/image_model/train/train.py +0 -42
- isa_model/training/image_model/train/train_flux.py +0 -41
- isa_model/training/image_model/train/train_lora.py +0 -57
- isa_model/training/image_model/train_main.py +0 -25
- isa_model-0.2.0.dist-info/METADATA +0 -327
- isa_model-0.2.0.dist-info/RECORD +0 -92
- isa_model-0.2.0.dist-info/licenses/LICENSE +0 -21
- /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
- /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
- {isa_model-0.2.0.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
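
The headline change in this release is the rewritten `AIFactory` in `isa_model/inference/ai_factory.py` (+427 −236), shown in full below: provider registration moves into per-provider `_register_*_services()` methods, API keys are resolved by the providers themselves, and short aliases (`get_llm`, `get_embed`, `get_stt`, `get_tts`, `get_img`, `get_vision`) fall back to default models. A minimal usage sketch assembled from the docstrings in the diff that follows; the defaults and model names are taken from the diff, not verified against the released wheel:

```python
from isa_model.inference.ai_factory import AIFactory

factory = AIFactory.get_instance()  # explicit singleton accessor; the docstrings also call AIFactory() directly

llm = factory.get_llm()                      # OpenAI "gpt-4.1-nano" by default
local = factory.get_llm(provider="ollama")   # "llama3.2:3b-instruct-fp16"
embed = factory.get_embed()                  # OpenAI "text-embedding-3-small"
stt = factory.get_stt()                      # OpenAI "whisper-1"
tts = factory.get_tts(provider="openai")     # "tts-1"; default provider is Replicate ("kokoro-82m")
img = factory.get_img(type="i2i")            # Replicate "black-forest-labs/flux-kontext-pro"
```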
isa_model/inference/ai_factory.py

@@ -1,17 +1,29 @@
-
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Simplified AI Factory for creating inference services
+Uses the new service architecture with proper base classes and centralized API key management
+"""
+
+from typing import Dict, Type, Any, Optional, Tuple, List, TYPE_CHECKING, cast
 import logging
 from isa_model.inference.providers.base_provider import BaseProvider
 from isa_model.inference.services.base_service import BaseService
 from isa_model.inference.base import ModelType
-import
+from isa_model.inference.services.vision.base_vision_service import BaseVisionService
+from isa_model.inference.services.vision.base_image_gen_service import BaseImageGenService
+
+if TYPE_CHECKING:
+    from isa_model.inference.services.audio.base_stt_service import BaseSTTService
+    from isa_model.inference.services.audio.base_tts_service import BaseTTSService

-# Set up basic logging configuration
-logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 class AIFactory:
     """
-    Factory for creating AI services
+    Simplified Factory for creating AI services with proper inheritance hierarchy
+    API key management is handled by individual providers
     """

     _instance = None
@@ -24,72 +36,85 @@ class AIFactory:

     def __init__(self):
         """Initialize the AI Factory."""
-        self.triton_url = os.environ.get("TRITON_URL", "http://localhost:8000")
-
-        # Cache for services (singleton pattern)
-        self._llm_services = {}
-        self._embedding_services = {}
-        self._speech_services = {}
-
         if not self._is_initialized:
             self._providers: Dict[str, Type[BaseProvider]] = {}
             self._services: Dict[Tuple[str, ModelType], Type[BaseService]] = {}
             self._cached_services: Dict[str, BaseService] = {}
-            self.
+            self._initialize_services()
             AIFactory._is_initialized = True

-    def
-        """Initialize
+    def _initialize_services(self):
+        """Initialize available providers and services"""
+        try:
+            # Register Ollama services
+            self._register_ollama_services()
+
+            # Register OpenAI services
+            self._register_openai_services()
+
+            # Register Replicate services
+            self._register_replicate_services()
+
+            logger.info("AI Factory initialized with centralized provider API key management")
+
+        except Exception as e:
+            logger.error(f"Error initializing services: {e}")
+            logger.warning("Some services may not be available")
+
+    def _register_ollama_services(self):
+        """Register Ollama provider and services"""
         try:
-            # Import providers and services
             from isa_model.inference.providers.ollama_provider import OllamaProvider
-            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
             from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService
+            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
+            from isa_model.inference.services.vision.ollama_vision_service import OllamaVisionService

-            # Register Ollama provider and services
             self.register_provider('ollama', OllamaProvider)
-            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
             self.register_service('ollama', ModelType.LLM, OllamaLLMService)
+            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
+            self.register_service('ollama', ModelType.VISION, OllamaVisionService)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            logger.info("Ollama services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"Ollama services not available: {e}")
+
+    def _register_openai_services(self):
+        """Register OpenAI provider and services"""
+        try:
+            from isa_model.inference.providers.openai_provider import OpenAIProvider
+            from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
+            from isa_model.inference.services.audio.openai_tts_service import OpenAITTSService
+            from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+            from isa_model.inference.services.embedding.openai_embed_service import OpenAIEmbedService
+            from isa_model.inference.services.vision.openai_vision_service import OpenAIVisionService
+
+            self.register_provider('openai', OpenAIProvider)
+            self.register_service('openai', ModelType.LLM, OpenAILLMService)
+            self.register_service('openai', ModelType.AUDIO, OpenAITTSService)
+            self.register_service('openai', ModelType.EMBEDDING, OpenAIEmbedService)
+            self.register_service('openai', ModelType.VISION, OpenAIVisionService)
+
+            logger.info("OpenAI services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"OpenAI services not available: {e}")
+
+    def _register_replicate_services(self):
+        """Register Replicate provider and services"""
+        try:
+            from isa_model.inference.providers.replicate_provider import ReplicateProvider
+            from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateImageGenService
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
+
+            self.register_provider('replicate', ReplicateProvider)
+            self.register_service('replicate', ModelType.VISION, ReplicateImageGenService)
+            self.register_service('replicate', ModelType.AUDIO, ReplicateTTSService)
+
+            logger.info("Replicate services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"Replicate services not available: {e}")

     def register_provider(self, name: str, provider_class: Type[BaseProvider]) -> None:
         """Register an AI provider"""
@@ -102,31 +127,27 @@ class AIFactory:

     def create_service(self, provider_name: str, model_type: ModelType,
                        model_name: str, config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Create a service instance"""
+        """Create a service instance with provider-managed configuration"""
         try:
             cache_key = f"{provider_name}_{model_type}_{model_name}"

             if cache_key in self._cached_services:
                 return self._cached_services[cache_key]

-            #
-
-                "log_level": "INFO"
-            }
-
-            # Merge configurations
-            service_config = {**base_config, **(config or {})}
-
-            # Create provider and service
-            provider_class = self._providers[provider_name]
+            # Get provider and service classes
+            provider_class = self._providers.get(provider_name)
             service_class = self._services.get((provider_name, model_type))

+            if not provider_class:
+                raise ValueError(f"No provider registered for '{provider_name}'")
+
             if not service_class:
                 raise ValueError(
-                    f"No service registered for provider {provider_name} and model type {model_type}"
+                    f"No service registered for provider '{provider_name}' and model type '{model_type}'"
                 )

-            provider
+            # Create provider with user config (provider handles .env loading)
+            provider = provider_class(config=config)
             service = service_class(provider=provider, model_name=model_name)

             self._cached_services[cache_key] = service
@@ -136,224 +157,394 @@ class AIFactory:
             logger.error(f"Error creating service: {e}")
             raise

-    # Convenient methods for common services
-    def
-
+    # Convenient methods for common services with updated defaults
+    def get_llm_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get a LLM service instance
+        Get a LLM service instance with automatic defaults

         Args:
-            model_name: Name of the model to use
-            provider: Provider name (
-            config: Optional configuration dictionary
-
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-nano", Ollama="llama3.2:3b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+                    Can include: streaming=True/False, temperature, max_tokens, etc.

         Returns:
             LLM service instance
-
-        Example:
-            # Using with API key directly
-            llm = AIFactory.get_instance().get_llm(
-                model_name="gpt-4o-mini",
-                provider="openai",
-                api_key="your-api-key-here"
-            )
-
-            # Using without API key (will use environment variable)
-            llm = AIFactory.get_instance().get_llm(
-                model_name="gpt-4o-mini",
-                provider="openai"
-            )
         """
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-nano"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2:3b-instruct-fp16"
+            final_provider = provider
+        else:
+            # Default provider selection - OpenAI with cheapest model
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-nano"
+            else:
+                final_model_name = model_name or "llama3.2:3b-instruct-fp16"

-
-
-
-
+        return self.create_service(final_provider, ModelType.LLM, final_model_name, config)
+
+    def get_embedding_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                              config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Get an embedding service instance with automatic defaults

-
-
-
-
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Embedding service instance
+        """
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "text-embedding-3-small"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "bge-m3"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "text-embedding-3-small"
+            else:
+                final_model_name = model_name or "bge-m3"

-
-
-
+        return self.create_service(final_provider, ModelType.EMBEDDING, final_model_name, config)
+
+    def get_vision_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                           config: Optional[Dict[str, Any]] = None) -> BaseVisionService:
+        """
+        Get a vision service instance with automatic defaults

-
-
-
-
-
-
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-mini", Ollama="gemma3:4b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Vision service instance
+        """
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-mini"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2-vision:latest"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-mini"
             else:
-
+                final_model_name = model_name or "llama3.2-vision:latest"
+
+        return cast(BaseVisionService, self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_image_generation_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                                     config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
+        """
+        Get an image generation service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "black-forest-labs/flux-schnell" for production)
+            provider: Provider name (defaults to 'replicate')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Image generation service instance
+        """
+        # Set defaults based on provider
+        final_provider = provider or "replicate"
+        if final_provider == "replicate":
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        else:
+            final_model_name = model_name or "black-forest-labs/flux-schnell"

-
-            basic_config.update(config)
-        return self.create_service(provider, ModelType.LLM, model_name, basic_config)
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))

-    def
-
+    def get_img(self, type: str = "t2i", model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
         """
-        Get
+        Get an image generation service with type-specific defaults

         Args:
-
-
+            type: Image generation type:
+                - "t2i" (text-to-image): Uses flux-schnell ($3 per 1000 images)
+                - "i2i" (image-to-image): Uses flux-kontext-pro ($0.04 per image)
+            model_name: Optional model name override
+            provider: Provider name (defaults to 'replicate')
             config: Optional configuration dictionary
-            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)

         Returns:
-
+            Image generation service instance

-
-        #
-
-
-
-
-            )
-
-        #
-
-                model_name="stability-ai/sdxl",
-                provider="replicate",
-                api_key="your-replicate-token"
-            )
+        Usage:
+            # Text-to-image (default)
+            img_service = AIFactory().get_img()
+            img_service = AIFactory().get_img(type="t2i")
+
+            # Image-to-image
+            img_service = AIFactory().get_img(type="i2i")
+
+            # Custom model
+            img_service = AIFactory().get_img(type="t2i", model_name="custom-model")
         """
+        # Set defaults based on type
+        final_provider = provider or "replicate"
+
+        if type == "t2i":
+            # Text-to-image: flux-schnell
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        elif type == "i2i":
+            # Image-to-image: flux-kontext-pro
+            final_model_name = model_name or "black-forest-labs/flux-kontext-pro"
+        else:
+            raise ValueError(f"Unknown image generation type: {type}. Use 't2i' or 'i2i'")

-
-
-
-
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_audio_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                          config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Get an audio service instance (TTS) with automatic defaults

-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="tts-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Audio service instance
+        """
+        # Set defaults based on provider
+        final_provider = provider or "openai"
+        if final_provider == "openai":
+            final_model_name = model_name or "tts-1"
+        else:
+            final_model_name = model_name or "tts-1"

-
-
-
-
+        return self.create_service(final_provider, ModelType.AUDIO, final_model_name, config)
+
+    def get_tts_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
+        """
+        Get a Text-to-Speech service instance with automatic defaults

-
-
-
-
-
-
+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+        """
+        # Set defaults based on provider
+        if provider == "replicate":
+            model_name = model_name or "kokoro-82m"
+        elif provider == "openai":
+            model_name = model_name or "tts-1"
+        else:
+            # Default provider selection
+            provider = provider or "replicate"
+            if provider == "replicate":
+                model_name = model_name or "kokoro-82m"
             else:
-
+                model_name = model_name or "tts-1"

-
-
-
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "tts-1"
+
+        if provider == "replicate":
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
+            from isa_model.inference.providers.replicate_provider import ReplicateProvider
+
+            # Use full model name for Replicate
+            if model_name == "kokoro-82m":
+                model_name = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"
+
+            provider_instance = ReplicateProvider(config=config)
+            return ReplicateTTSService(provider=provider_instance, model_name=model_name)
+        else:
+            return cast('BaseTTSService', self.get_audio_service(model_name, provider, config))

-    def
-
-        """
-
+    def get_stt_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
+        """
+        Get a Speech-to-Text service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            STT service instance
+        """
+        # Set defaults based on provider
+        provider = provider or "openai"
+        if provider == "openai":
+            model_name = model_name or "whisper-1"
+
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "whisper-1"
+
+        from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+        from isa_model.inference.providers.openai_provider import OpenAIProvider
+
+        # Create provider and service directly with config
+        provider_instance = OpenAIProvider(config=config)
+        return OpenAISTTService(provider=provider_instance, model_name=model_name)

-    def
-
-
-
+    def get_available_services(self) -> Dict[str, List[str]]:
+        """Get information about available services"""
+        services = {}
+        for (provider, model_type), service_class in self._services.items():
+            if provider not in services:
+                services[provider] = []
+            services[provider].append(f"{model_type.value}: {service_class.__name__}")
+        return services

-    def
-
-
-
+    def clear_cache(self):
+        """Clear the service cache"""
+        self._cached_services.clear()
+        logger.info("Service cache cleared")

-
-
-        """Get
-
-
-
-
-
+    @classmethod
+    def get_instance(cls) -> 'AIFactory':
+        """Get the singleton instance"""
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._instance
+
+    # Alias method for cleaner API
+    def get_llm(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Alias for get_llm_service with cleaner naming

-
-        "
-
-
-
-
-        return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)
+        Usage:
+            llm = AIFactory().get_llm()  # Uses gpt-4.1-nano by default
+            llm = AIFactory().get_llm(model_name="llama3.2", provider="ollama")
+            llm = AIFactory().get_llm(model_name="gpt-4.1-mini", provider="openai", config={"streaming": True})
+        """
+        return self.get_llm_service(model_name, provider, config)

-
+    def get_embed(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                  config: Optional[Dict[str, Any]] = None) -> 'BaseEmbedService':
         """
-        Get
+        Get embedding service with automatic defaults

         Args:
-            model_name: Name of the model
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)

         Returns:
             Embedding service instance
+
+        Usage:
+            # Default (OpenAI text-embedding-3-small)
+            embed = AIFactory().get_embed()
+
+            # Custom model
+            embed = AIFactory().get_embed(model_name="text-embedding-3-large", provider="openai")
+
+            # Development (Ollama)
+            embed = AIFactory().get_embed(provider="ollama")
         """
-
-
-
-
-            raise ValueError(f"Unsupported embedding model: {model_name}")
-
-    async def get_speech_service(self, model_name: str) -> Any:
+        return self.get_embedding_service(model_name, provider, config)
+
+    def get_stt(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
         """
-        Get
+        Get Speech-to-Text service with automatic defaults

         Args:
-            model_name: Name of the model
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)

         Returns:
-
+            STT service instance
+
+        Usage:
+            # Default (OpenAI whisper-1)
+            stt = AIFactory().get_stt()
+
+            # Custom configuration
+            stt = AIFactory().get_stt(model_name="whisper-1", provider="openai")
+        """
+        return self.get_stt_service(model_name, provider, config)
+
+    def get_tts(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
         """
-
-        return self._speech_services[model_name]
+        Get Text-to-Speech service with automatic defaults

+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+
+        Usage:
+            # Default (Replicate kokoro-82m)
+            tts = AIFactory().get_tts()
+
+            # Development (OpenAI tts-1)
+            tts = AIFactory().get_tts(provider="openai")
+
+            # Custom model
+            tts = AIFactory().get_tts(model_name="tts-1-hd", provider="openai")
+        """
+        return self.get_tts_service(model_name, provider, config)
+
+    def get_vision_model(self, model_name: str, provider: str,
+                         config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Alias for get_vision_service and get_image_generation_service"""
+        if provider == "replicate":
+            return self.get_image_generation_service(model_name, provider, config)
+        else:
+            return self.get_vision_service(model_name, provider, config)

-    def
+    def get_vision(
+        self,
+        model_name: Optional[str] = None,
+        provider: Optional[str] = None,
+        config: Optional[Dict[str, Any]] = None
+    ) -> 'BaseVisionService':
         """
-        Get
+        Get vision service with automatic defaults

         Args:
-
+            model_name: Model name (default: gpt-4.1-nano)
+            provider: Provider name (default: openai)
+            config: Optional configuration override

         Returns:
-
+            Vision service instance
         """
-
-
-
-
-
-            ],
-            "embedding": [
-                {"name": "bge_embed", "description": "BGE-M3 text embedding model"}
-            ],
-            "speech": [
-                {"name": "whisper", "description": "Whisper-tiny speech-to-text model"}
-            ]
-        }
+        # Set defaults
+        if provider is None:
+            provider = "openai"
+        if model_name is None:
+            model_name = "gpt-4.1-nano"

-
-
-
-
-
-    def get_instance(cls) -> 'AIFactory':
-        """Get the singleton instance"""
-        if cls._instance is None:
-            cls._instance = cls()
-        return cls._instance
+        return self.get_vision_service(
+            model_name=model_name,
+            provider=provider,
+            config=config
+        )