isa-model 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/core/model_manager.py +69 -4
- isa_model/inference/ai_factory.py +335 -46
- isa_model/inference/billing_tracker.py +406 -0
- isa_model/inference/providers/base_provider.py +51 -4
- isa_model/inference/providers/ollama_provider.py +37 -18
- isa_model/inference/providers/openai_provider.py +65 -36
- isa_model/inference/providers/replicate_provider.py +42 -30
- isa_model/inference/services/audio/base_stt_service.py +21 -2
- isa_model/inference/services/audio/openai_realtime_service.py +353 -0
- isa_model/inference/services/audio/openai_stt_service.py +252 -0
- isa_model/inference/services/audio/openai_tts_service.py +48 -9
- isa_model/inference/services/audio/replicate_tts_service.py +239 -0
- isa_model/inference/services/base_service.py +36 -1
- isa_model/inference/services/embedding/openai_embed_service.py +223 -0
- isa_model/inference/services/llm/base_llm_service.py +88 -192
- isa_model/inference/services/llm/llm_adapter.py +459 -0
- isa_model/inference/services/llm/ollama_llm_service.py +111 -185
- isa_model/inference/services/llm/openai_llm_service.py +115 -360
- isa_model/inference/services/vision/helpers/image_utils.py +4 -3
- isa_model/inference/services/vision/ollama_vision_service.py +11 -3
- isa_model/inference/services/vision/openai_vision_service.py +275 -41
- isa_model/inference/services/vision/replicate_image_gen_service.py +233 -205
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/METADATA +1 -1
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/RECORD +26 -21
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/WHEEL +0 -0
- {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/top_level.txt +0 -0
isa_model/core/model_manager.py
CHANGED
@@ -2,7 +2,7 @@ from typing import Dict, Optional, List, Any
 import logging
 from pathlib import Path
 from huggingface_hub import hf_hub_download, snapshot_download
-from huggingface_hub.
+from huggingface_hub.errors import HfHubHTTPError
 from .model_storage import ModelStorage, LocalModelStorage
 from .model_registry import ModelRegistry, ModelType, ModelCapability
 
@@ -11,19 +11,81 @@ logger = logging.getLogger(__name__)
 class ModelManager:
     """Model management service for handling model downloads, versions, and caching"""
 
+    # Unified model pricing info (per 1M tokens)
+    MODEL_PRICING = {
+        # OpenAI Models
+        "openai": {
+            "gpt-4o-mini": {"input": 0.15, "output": 0.6},
+            "gpt-4.1-mini": {"input": 0.4, "output": 1.6},
+            "gpt-4.1-nano": {"input": 0.1, "output": 0.4},
+            "gpt-4o": {"input": 5.0, "output": 15.0},
+            "gpt-4-turbo": {"input": 10.0, "output": 30.0},
+            "gpt-4": {"input": 30.0, "output": 60.0},
+            "gpt-3.5-turbo": {"input": 0.5, "output": 1.5},
+            "text-embedding-3-small": {"input": 0.02, "output": 0.0},
+            "text-embedding-3-large": {"input": 0.13, "output": 0.0},
+            "whisper-1": {"input": 6.0, "output": 0.0},
+            "tts-1": {"input": 15.0, "output": 0.0},
+            "tts-1-hd": {"input": 30.0, "output": 0.0},
+        },
+        # Ollama Models (free local models)
+        "ollama": {
+            "llama3.2:3b-instruct-fp16": {"input": 0.0, "output": 0.0},
+            "llama3.2-vision:latest": {"input": 0.0, "output": 0.0},
+            "bge-m3": {"input": 0.0, "output": 0.0},
+        },
+        # Replicate Models
+        "replicate": {
+            "black-forest-labs/flux-schnell": {"input": 3.0, "output": 0.0},  # $3 per 1000 images
+            "black-forest-labs/flux-kontext-pro": {"input": 40.0, "output": 0.0},  # $0.04 per image = $40 per 1000 images
+            "meta/meta-llama-3-8b-instruct": {"input": 0.05, "output": 0.25},
+            "kokoro-82m": {"input": 0.0, "output": 0.4},  # ~$0.0004 per second
+            "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13": {"input": 0.0, "output": 0.4},
+        }
+    }
+
     def __init__(self,
                  storage: Optional[ModelStorage] = None,
                  registry: Optional[ModelRegistry] = None):
         self.storage = storage or LocalModelStorage()
         self.registry = registry or ModelRegistry()
 
+    def get_model_pricing(self, provider: str, model_name: str) -> Dict[str, float]:
+        """Get pricing info for a model"""
+        return self.MODEL_PRICING.get(provider, {}).get(model_name, {"input": 0.0, "output": 0.0})
+
+    def calculate_cost(self, provider: str, model_name: str, input_tokens: int, output_tokens: int) -> float:
+        """Calculate the cost of a request"""
+        pricing = self.get_model_pricing(provider, model_name)
+        input_cost = (input_tokens / 1_000_000) * pricing["input"]
+        output_cost = (output_tokens / 1_000_000) * pricing["output"]
+        return input_cost + output_cost
+
+    def get_cheapest_model(self, provider: str, model_type: str = "llm") -> Optional[str]:
+        """Get the cheapest model for a provider"""
+        provider_models = self.MODEL_PRICING.get(provider, {})
+        if not provider_models:
+            return None
+
+        # Compute each model's average cost (assuming a 1:1 input/output ratio)
+        cheapest_model = None
+        lowest_cost = float('inf')
+
+        for model_name, pricing in provider_models.items():
+            avg_cost = (pricing["input"] + pricing["output"]) / 2
+            if avg_cost < lowest_cost:
+                lowest_cost = avg_cost
+                cheapest_model = model_name
+
+        return cheapest_model
+
     async def get_model(self,
                         model_id: str,
                         repo_id: str,
                         model_type: ModelType,
                         capabilities: List[ModelCapability],
                         revision: Optional[str] = None,
-                        force_download: bool = False) -> Path:
+                        force_download: bool = False) -> Optional[Path]:
         """
         Get model files, downloading if necessary
 
@@ -36,7 +98,7 @@ class ModelManager:
             force_download: Force re-download even if cached
 
         Returns:
-            Path to the model files
+            Path to the model files or None if failed
         """
         # Check if model is already downloaded
         if not force_download:
@@ -80,7 +142,10 @@ class ModelManager:
 
         except HfHubHTTPError as e:
             logger.error(f"Failed to download model {model_id}: {e}")
-
+            return None
+        except Exception as e:
+            logger.error(f"Unexpected error downloading model {model_id}: {e}")
+            return None
 
     async def list_models(self) -> List[Dict[str, Any]]:
         """List all downloaded models with their metadata"""
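
The new pricing helpers compose into a per-request cost estimate. A minimal usage sketch against the methods added above (the token counts are invented for illustration):

from isa_model.core.model_manager import ModelManager

manager = ModelManager()

# Per-1M-token rates; unknown models fall back to {"input": 0.0, "output": 0.0}
pricing = manager.get_model_pricing("openai", "gpt-4.1-nano")  # {"input": 0.1, "output": 0.4}

# (12_000 / 1_000_000) * 0.1 + (3_000 / 1_000_000) * 0.4 = 0.0024 USD
cost = manager.calculate_cost("openai", "gpt-4.1-nano", input_tokens=12_000, output_tokens=3_000)

# Cheapest by average of input/output rates. Note that in this version
# get_cheapest_model never reads its model_type argument, so non-LLM entries
# compete too: for "openai" it returns "text-embedding-3-small" (avg 0.01).
cheapest = manager.get_cheapest_model("openai")
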
isa_model/inference/ai_factory.py
CHANGED
@@ -3,21 +3,27 @@
 
 """
 Simplified AI Factory for creating inference services
-Uses the new service architecture with proper base classes
+Uses the new service architecture with proper base classes and centralized API key management
 """
 
-from typing import Dict, Type, Any, Optional, Tuple, List
+from typing import Dict, Type, Any, Optional, Tuple, List, TYPE_CHECKING, cast
 import logging
-import os
 from isa_model.inference.providers.base_provider import BaseProvider
 from isa_model.inference.services.base_service import BaseService
 from isa_model.inference.base import ModelType
+from isa_model.inference.services.vision.base_vision_service import BaseVisionService
+from isa_model.inference.services.vision.base_image_gen_service import BaseImageGenService
+
+if TYPE_CHECKING:
+    from isa_model.inference.services.audio.base_stt_service import BaseSTTService
+    from isa_model.inference.services.audio.base_tts_service import BaseTTSService
 
 logger = logging.getLogger(__name__)
 
 class AIFactory:
     """
     Simplified Factory for creating AI services with proper inheritance hierarchy
+    API key management is handled by individual providers
     """
 
     _instance = None
@@ -49,7 +55,7 @@ class AIFactory:
             # Register Replicate services
             self._register_replicate_services()
 
-            logger.info("AI Factory initialized with
+            logger.info("AI Factory initialized with centralized provider API key management")
 
         except Exception as e:
             logger.error(f"Error initializing services: {e}")
@@ -79,10 +85,15 @@ class AIFactory:
             from isa_model.inference.providers.openai_provider import OpenAIProvider
             from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
             from isa_model.inference.services.audio.openai_tts_service import OpenAITTSService
+            from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+            from isa_model.inference.services.embedding.openai_embed_service import OpenAIEmbedService
+            from isa_model.inference.services.vision.openai_vision_service import OpenAIVisionService
 
             self.register_provider('openai', OpenAIProvider)
             self.register_service('openai', ModelType.LLM, OpenAILLMService)
             self.register_service('openai', ModelType.AUDIO, OpenAITTSService)
+            self.register_service('openai', ModelType.EMBEDDING, OpenAIEmbedService)
+            self.register_service('openai', ModelType.VISION, OpenAIVisionService)
 
             logger.info("OpenAI services registered successfully")
 
@@ -94,9 +105,11 @@ class AIFactory:
         try:
             from isa_model.inference.providers.replicate_provider import ReplicateProvider
             from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateImageGenService
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
 
             self.register_provider('replicate', ReplicateProvider)
             self.register_service('replicate', ModelType.VISION, ReplicateImageGenService)
+            self.register_service('replicate', ModelType.AUDIO, ReplicateTTSService)
 
             logger.info("Replicate services registered successfully")
 
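
Both registrations follow the same two-step pattern: register the provider class, then one service class per ModelType. A sketch of how a custom backend would plug in, with MyProvider and MyLLMService as hypothetical names:

from isa_model.inference.ai_factory import AIFactory
from isa_model.inference.base import ModelType

factory = AIFactory()
factory.register_provider('myprovider', MyProvider)                   # hypothetical BaseProvider subclass
factory.register_service('myprovider', ModelType.LLM, MyLLMService)   # hypothetical BaseService subclass
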
@@ -114,7 +127,7 @@ class AIFactory:
 
     def create_service(self, provider_name: str, model_type: ModelType,
                       model_name: str, config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Create a service instance"""
+        """Create a service instance with provider-managed configuration"""
         try:
             cache_key = f"{provider_name}_{model_type}_{model_name}"
 
@@ -133,8 +146,8 @@ class AIFactory:
                     f"No service registered for provider '{provider_name}' and model type '{model_type}'"
                 )
 
-            # Create provider
-            provider = provider_class(config=config
+            # Create provider with user config (provider handles .env loading)
+            provider = provider_class(config=config)
             service = service_class(provider=provider, model_name=model_name)
 
             self._cached_services[cache_key] = service
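
create_service keys its cache on provider, model type, and model name, so repeated requests for the same triple should reuse one instance. A sketch of the implied behavior (this hunk only shows the cache write; the lookup on the earlier path is assumed):

from isa_model.inference.ai_factory import AIFactory
from isa_model.inference.base import ModelType

factory = AIFactory()
a = factory.create_service("openai", ModelType.LLM, "gpt-4.1-nano")
b = factory.create_service("openai", ModelType.LLM, "gpt-4.1-nano")
assert a is b  # same f"{provider_name}_{model_type}_{model_name}" cache key
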
@@ -144,81 +157,254 @@ class AIFactory:
             logger.error(f"Error creating service: {e}")
             raise
 
-    # Convenient methods for common services
-    def get_llm_service(self, model_name: str =
+    # Convenient methods for common services with updated defaults
+    def get_llm_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
                        config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get a LLM service instance
+        Get a LLM service instance with automatic defaults
 
         Args:
-            model_name: Name of the model to use
-            provider: Provider name ('
-            config: Optional configuration dictionary
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-nano", Ollama="llama3.2:3b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+                   Can include: streaming=True/False, temperature, max_tokens, etc.
 
         Returns:
             LLM service instance
         """
-
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-nano"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2:3b-instruct-fp16"
+            final_provider = provider
+        else:
+            # Default provider selection - OpenAI with cheapest model
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-nano"
+            else:
+                final_model_name = model_name or "llama3.2:3b-instruct-fp16"
+
+        return self.create_service(final_provider, ModelType.LLM, final_model_name, config)
 
-    def get_embedding_service(self, model_name: str =
+    def get_embedding_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
                              config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get an embedding service instance
+        Get an embedding service instance with automatic defaults
 
         Args:
-            model_name: Name of the model to use
-            provider: Provider name ('ollama')
-            config: Optional configuration dictionary
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
 
         Returns:
             Embedding service instance
         """
-
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "text-embedding-3-small"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "bge-m3"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "text-embedding-3-small"
+            else:
+                final_model_name = model_name or "bge-m3"
+
+        return self.create_service(final_provider, ModelType.EMBEDDING, final_model_name, config)
 
-    def get_vision_service(self, model_name: str, provider: str,
-                          config: Optional[Dict[str, Any]] = None) ->
+    def get_vision_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                          config: Optional[Dict[str, Any]] = None) -> BaseVisionService:
         """
-        Get a vision service instance
+        Get a vision service instance with automatic defaults
 
         Args:
-            model_name: Name of the model to use
-            provider: Provider name ('
-            config: Optional configuration dictionary
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-mini", Ollama="gemma3:4b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
 
         Returns:
             Vision service instance
         """
-
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-mini"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2-vision:latest"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-mini"
+            else:
+                final_model_name = model_name or "llama3.2-vision:latest"
+
+        return cast(BaseVisionService, self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_image_generation_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                                    config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
+        """
+        Get an image generation service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "black-forest-labs/flux-schnell" for production)
+            provider: Provider name (defaults to 'replicate')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Image generation service instance
+        """
+        # Set defaults based on provider
+        final_provider = provider or "replicate"
+        if final_provider == "replicate":
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        else:
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))
 
-    def
-
+    def get_img(self, type: str = "t2i", model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
         """
-        Get an image generation service
+        Get an image generation service with type-specific defaults
 
         Args:
-
-
+            type: Image generation type:
+                - "t2i" (text-to-image): Uses flux-schnell ($3 per 1000 images)
+                - "i2i" (image-to-image): Uses flux-kontext-pro ($0.04 per image)
+            model_name: Optional model name override
+            provider: Provider name (defaults to 'replicate')
             config: Optional configuration dictionary
 
         Returns:
             Image generation service instance
+
+        Usage:
+            # Text-to-image (default)
+            img_service = AIFactory().get_img()
+            img_service = AIFactory().get_img(type="t2i")
+
+            # Image-to-image
+            img_service = AIFactory().get_img(type="i2i")
+
+            # Custom model
+            img_service = AIFactory().get_img(type="t2i", model_name="custom-model")
         """
-
+        # Set defaults based on type
+        final_provider = provider or "replicate"
+
+        if type == "t2i":
+            # Text-to-image: flux-schnell
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        elif type == "i2i":
+            # Image-to-image: flux-kontext-pro
+            final_model_name = model_name or "black-forest-labs/flux-kontext-pro"
+        else:
+            raise ValueError(f"Unknown image generation type: {type}. Use 't2i' or 'i2i'")
+
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))
 
-    def get_audio_service(self, model_name: str =
+    def get_audio_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
                          config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get an audio service instance
+        Get an audio service instance (TTS) with automatic defaults
 
         Args:
-            model_name: Name of the model to use
-            provider: Provider name ('openai')
-            config: Optional configuration dictionary
+            model_name: Name of the model to use (defaults: OpenAI="tts-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
 
         Returns:
             Audio service instance
         """
-
+        # Set defaults based on provider
+        final_provider = provider or "openai"
+        if final_provider == "openai":
+            final_model_name = model_name or "tts-1"
+        else:
+            final_model_name = model_name or "tts-1"
+
+        return self.create_service(final_provider, ModelType.AUDIO, final_model_name, config)
+
+    def get_tts_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                       config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
+        """
+        Get a Text-to-Speech service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+        """
+        # Set defaults based on provider
+        if provider == "replicate":
+            model_name = model_name or "kokoro-82m"
+        elif provider == "openai":
+            model_name = model_name or "tts-1"
+        else:
+            # Default provider selection
+            provider = provider or "replicate"
+            if provider == "replicate":
+                model_name = model_name or "kokoro-82m"
+            else:
+                model_name = model_name or "tts-1"
+
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "tts-1"
+
+        if provider == "replicate":
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
+            from isa_model.inference.providers.replicate_provider import ReplicateProvider
+
+            # Use full model name for Replicate
+            if model_name == "kokoro-82m":
+                model_name = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"
+
+            provider_instance = ReplicateProvider(config=config)
+            return ReplicateTTSService(provider=provider_instance, model_name=model_name)
+        else:
+            return cast('BaseTTSService', self.get_audio_service(model_name, provider, config))
+
+    def get_stt_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                       config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
+        """
+        Get a Speech-to-Text service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            STT service instance
+        """
+        # Set defaults based on provider
+        provider = provider or "openai"
+        if provider == "openai":
+            model_name = model_name or "whisper-1"
+
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "whisper-1"
+
+        from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+        from isa_model.inference.providers.openai_provider import OpenAIProvider
+
+        # Create provider and service directly with config
+        provider_instance = OpenAIProvider(config=config)
+        return OpenAISTTService(provider=provider_instance, model_name=model_name)
 
     def get_available_services(self) -> Dict[str, List[str]]:
         """Get information about available services"""
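
The new audio getters resolve provider defaults before instantiating anything. A minimal usage sketch, assuming API keys are supplied via .env as the docstrings above describe:

from isa_model.inference.ai_factory import AIFactory

factory = AIFactory()

# Default TTS: Replicate "kokoro-82m", expanded internally to the pinned
# "jaaari/kokoro-82m:f559560..." version
tts = factory.get_tts_service()

# Development TTS via OpenAI (tts-1)
tts_dev = factory.get_tts_service(provider="openai")

# Default STT: OpenAI whisper-1
stt = factory.get_stt_service()
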
@@ -241,16 +427,90 @@ class AIFactory:
             cls._instance = cls()
         return cls._instance
 
-    # Alias
-    def get_llm(self, model_name: str =
+    # Alias method for cleaner API
+    def get_llm(self, model_name: Optional[str] = None, provider: Optional[str] = None,
                 config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """
+        """
+        Alias for get_llm_service with cleaner naming
+
+        Usage:
+            llm = AIFactory().get_llm()  # Uses gpt-4.1-nano by default
+            llm = AIFactory().get_llm(model_name="llama3.2", provider="ollama")
+            llm = AIFactory().get_llm(model_name="gpt-4.1-mini", provider="openai", config={"streaming": True})
+        """
         return self.get_llm_service(model_name, provider, config)
 
-    def
-                config: Optional[Dict[str, Any]] = None) ->
-        """
+    def get_embed(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                  config: Optional[Dict[str, Any]] = None) -> 'BaseEmbedService':
+        """
+        Get embedding service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Embedding service instance
+
+        Usage:
+            # Default (OpenAI text-embedding-3-small)
+            embed = AIFactory().get_embed()
+
+            # Custom model
+            embed = AIFactory().get_embed(model_name="text-embedding-3-large", provider="openai")
+
+            # Development (Ollama)
+            embed = AIFactory().get_embed(provider="ollama")
+        """
         return self.get_embedding_service(model_name, provider, config)
+
+    def get_stt(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
+        """
+        Get Speech-to-Text service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            STT service instance
+
+        Usage:
+            # Default (OpenAI whisper-1)
+            stt = AIFactory().get_stt()
+
+            # Custom configuration
+            stt = AIFactory().get_stt(model_name="whisper-1", provider="openai")
+        """
+        return self.get_stt_service(model_name, provider, config)
+
+    def get_tts(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
+        """
+        Get Text-to-Speech service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+
+        Usage:
+            # Default (Replicate kokoro-82m)
+            tts = AIFactory().get_tts()
+
+            # Development (OpenAI tts-1)
+            tts = AIFactory().get_tts(provider="openai")
+
+            # Custom model
+            tts = AIFactory().get_tts(model_name="tts-1-hd", provider="openai")
+        """
+        return self.get_tts_service(model_name, provider, config)
 
     def get_vision_model(self, model_name: str, provider: str,
                         config: Optional[Dict[str, Any]] = None) -> BaseService:
@@ -258,4 +518,33 @@
         if provider == "replicate":
             return self.get_image_generation_service(model_name, provider, config)
         else:
-            return self.get_vision_service(model_name, provider, config)
+            return self.get_vision_service(model_name, provider, config)
+
+    def get_vision(
+        self,
+        model_name: Optional[str] = None,
+        provider: Optional[str] = None,
+        config: Optional[Dict[str, Any]] = None
+    ) -> 'BaseVisionService':
+        """
+        Get vision service with automatic defaults
+
+        Args:
+            model_name: Model name (default: gpt-4.1-nano)
+            provider: Provider name (default: openai)
+            config: Optional configuration override
+
+        Returns:
+            Vision service instance
+        """
+        # Set defaults
+        if provider is None:
+            provider = "openai"
+        if model_name is None:
+            model_name = "gpt-4.1-nano"
+
+        return self.get_vision_service(
+            model_name=model_name,
+            provider=provider,
+            config=config
+        )