isa-model 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. isa_model/core/model_manager.py +69 -4
  2. isa_model/inference/ai_factory.py +335 -46
  3. isa_model/inference/billing_tracker.py +406 -0
  4. isa_model/inference/providers/base_provider.py +51 -4
  5. isa_model/inference/providers/ollama_provider.py +37 -18
  6. isa_model/inference/providers/openai_provider.py +65 -36
  7. isa_model/inference/providers/replicate_provider.py +42 -30
  8. isa_model/inference/services/audio/base_stt_service.py +21 -2
  9. isa_model/inference/services/audio/openai_realtime_service.py +353 -0
  10. isa_model/inference/services/audio/openai_stt_service.py +252 -0
  11. isa_model/inference/services/audio/openai_tts_service.py +48 -9
  12. isa_model/inference/services/audio/replicate_tts_service.py +239 -0
  13. isa_model/inference/services/base_service.py +36 -1
  14. isa_model/inference/services/embedding/openai_embed_service.py +223 -0
  15. isa_model/inference/services/llm/base_llm_service.py +88 -192
  16. isa_model/inference/services/llm/llm_adapter.py +459 -0
  17. isa_model/inference/services/llm/ollama_llm_service.py +111 -185
  18. isa_model/inference/services/llm/openai_llm_service.py +115 -360
  19. isa_model/inference/services/vision/helpers/image_utils.py +4 -3
  20. isa_model/inference/services/vision/ollama_vision_service.py +11 -3
  21. isa_model/inference/services/vision/openai_vision_service.py +275 -41
  22. isa_model/inference/services/vision/replicate_image_gen_service.py +233 -205
  23. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/METADATA +1 -1
  24. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/RECORD +26 -21
  25. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/WHEEL +0 -0
  26. {isa_model-0.3.0.dist-info → isa_model-0.3.2.dist-info}/top_level.txt +0 -0
isa_model/core/model_manager.py

@@ -2,7 +2,7 @@ from typing import Dict, Optional, List, Any
 import logging
 from pathlib import Path
 from huggingface_hub import hf_hub_download, snapshot_download
-from huggingface_hub.utils import HfHubHTTPError
+from huggingface_hub.errors import HfHubHTTPError
 from .model_storage import ModelStorage, LocalModelStorage
 from .model_registry import ModelRegistry, ModelType, ModelCapability
 
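The only change in this hunk is the import path: newer huggingface_hub releases expose HfHubHTTPError from huggingface_hub.errors rather than huggingface_hub.utils. A fallback import that tolerates both layouts would look roughly like this (a sketch for context, not part of this diff):

    try:
        from huggingface_hub.errors import HfHubHTTPError  # newer huggingface_hub layout
    except ImportError:
        from huggingface_hub.utils import HfHubHTTPError   # older releases exported it here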
@@ -11,19 +11,81 @@ logger = logging.getLogger(__name__)
 class ModelManager:
     """Model management service for handling model downloads, versions, and caching"""
 
+    # Unified model pricing information (per 1M tokens)
+    MODEL_PRICING = {
+        # OpenAI Models
+        "openai": {
+            "gpt-4o-mini": {"input": 0.15, "output": 0.6},
+            "gpt-4.1-mini": {"input": 0.4, "output": 1.6},
+            "gpt-4.1-nano": {"input": 0.1, "output": 0.4},
+            "gpt-4o": {"input": 5.0, "output": 15.0},
+            "gpt-4-turbo": {"input": 10.0, "output": 30.0},
+            "gpt-4": {"input": 30.0, "output": 60.0},
+            "gpt-3.5-turbo": {"input": 0.5, "output": 1.5},
+            "text-embedding-3-small": {"input": 0.02, "output": 0.0},
+            "text-embedding-3-large": {"input": 0.13, "output": 0.0},
+            "whisper-1": {"input": 6.0, "output": 0.0},
+            "tts-1": {"input": 15.0, "output": 0.0},
+            "tts-1-hd": {"input": 30.0, "output": 0.0},
+        },
+        # Ollama Models (free local models)
+        "ollama": {
+            "llama3.2:3b-instruct-fp16": {"input": 0.0, "output": 0.0},
+            "llama3.2-vision:latest": {"input": 0.0, "output": 0.0},
+            "bge-m3": {"input": 0.0, "output": 0.0},
+        },
+        # Replicate Models
+        "replicate": {
+            "black-forest-labs/flux-schnell": {"input": 3.0, "output": 0.0},  # $3 per 1000 images
+            "black-forest-labs/flux-kontext-pro": {"input": 40.0, "output": 0.0},  # $0.04 per image = $40 per 1000 images
+            "meta/meta-llama-3-8b-instruct": {"input": 0.05, "output": 0.25},
+            "kokoro-82m": {"input": 0.0, "output": 0.4},  # ~$0.0004 per second
+            "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13": {"input": 0.0, "output": 0.4},
+        }
+    }
+
     def __init__(self,
                  storage: Optional[ModelStorage] = None,
                  registry: Optional[ModelRegistry] = None):
         self.storage = storage or LocalModelStorage()
         self.registry = registry or ModelRegistry()
 
+    def get_model_pricing(self, provider: str, model_name: str) -> Dict[str, float]:
+        """Get pricing information for a model"""
+        return self.MODEL_PRICING.get(provider, {}).get(model_name, {"input": 0.0, "output": 0.0})
+
+    def calculate_cost(self, provider: str, model_name: str, input_tokens: int, output_tokens: int) -> float:
+        """Calculate the cost of a request"""
+        pricing = self.get_model_pricing(provider, model_name)
+        input_cost = (input_tokens / 1_000_000) * pricing["input"]
+        output_cost = (output_tokens / 1_000_000) * pricing["output"]
+        return input_cost + output_cost
+
+    def get_cheapest_model(self, provider: str, model_type: str = "llm") -> Optional[str]:
+        """Get the cheapest model for a provider"""
+        provider_models = self.MODEL_PRICING.get(provider, {})
+        if not provider_models:
+            return None
+
+        # Compute each model's average cost (assuming a 1:1 input/output token ratio)
+        cheapest_model = None
+        lowest_cost = float('inf')
+
+        for model_name, pricing in provider_models.items():
+            avg_cost = (pricing["input"] + pricing["output"]) / 2
+            if avg_cost < lowest_cost:
+                lowest_cost = avg_cost
+                cheapest_model = model_name
+
+        return cheapest_model
+
     async def get_model(self,
                         model_id: str,
                         repo_id: str,
                         model_type: ModelType,
                         capabilities: List[ModelCapability],
                         revision: Optional[str] = None,
-                        force_download: bool = False) -> Path:
+                        force_download: bool = False) -> Optional[Path]:
         """
         Get model files, downloading if necessary
 
@@ -36,7 +98,7 @@ class ModelManager:
             force_download: Force re-download even if cached
 
         Returns:
-            Path to the model files
+            Path to the model files or None if failed
         """
         # Check if model is already downloaded
         if not force_download:
@@ -80,7 +142,10 @@ class ModelManager:
 
         except HfHubHTTPError as e:
             logger.error(f"Failed to download model {model_id}: {e}")
-            raise
+            return None
+        except Exception as e:
+            logger.error(f"Unexpected error downloading model {model_id}: {e}")
+            return None
 
     async def list_models(self) -> List[Dict[str, Any]]:
         """List all downloaded models with their metadata"""
isa_model/inference/ai_factory.py

@@ -3,21 +3,27 @@
 
 """
 Simplified AI Factory for creating inference services
-Uses the new service architecture with proper base classes
+Uses the new service architecture with proper base classes and centralized API key management
 """
 
-from typing import Dict, Type, Any, Optional, Tuple, List
+from typing import Dict, Type, Any, Optional, Tuple, List, TYPE_CHECKING, cast
 import logging
-import os
 from isa_model.inference.providers.base_provider import BaseProvider
 from isa_model.inference.services.base_service import BaseService
 from isa_model.inference.base import ModelType
+from isa_model.inference.services.vision.base_vision_service import BaseVisionService
+from isa_model.inference.services.vision.base_image_gen_service import BaseImageGenService
+
+if TYPE_CHECKING:
+    from isa_model.inference.services.audio.base_stt_service import BaseSTTService
+    from isa_model.inference.services.audio.base_tts_service import BaseTTSService
 
 logger = logging.getLogger(__name__)
 
 class AIFactory:
     """
     Simplified Factory for creating AI services with proper inheritance hierarchy
+    API key management is handled by individual providers
     """
 
     _instance = None
@@ -49,7 +55,7 @@ class AIFactory:
             # Register Replicate services
             self._register_replicate_services()
 
-            logger.info("AI Factory initialized with simplified service architecture")
+            logger.info("AI Factory initialized with centralized provider API key management")
 
         except Exception as e:
             logger.error(f"Error initializing services: {e}")
@@ -79,10 +85,15 @@ class AIFactory:
             from isa_model.inference.providers.openai_provider import OpenAIProvider
             from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
             from isa_model.inference.services.audio.openai_tts_service import OpenAITTSService
+            from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+            from isa_model.inference.services.embedding.openai_embed_service import OpenAIEmbedService
+            from isa_model.inference.services.vision.openai_vision_service import OpenAIVisionService
 
             self.register_provider('openai', OpenAIProvider)
             self.register_service('openai', ModelType.LLM, OpenAILLMService)
             self.register_service('openai', ModelType.AUDIO, OpenAITTSService)
+            self.register_service('openai', ModelType.EMBEDDING, OpenAIEmbedService)
+            self.register_service('openai', ModelType.VISION, OpenAIVisionService)
 
             logger.info("OpenAI services registered successfully")
 
@@ -94,9 +105,11 @@ class AIFactory:
         try:
             from isa_model.inference.providers.replicate_provider import ReplicateProvider
             from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateImageGenService
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
 
             self.register_provider('replicate', ReplicateProvider)
             self.register_service('replicate', ModelType.VISION, ReplicateImageGenService)
+            self.register_service('replicate', ModelType.AUDIO, ReplicateTTSService)
 
             logger.info("Replicate services registered successfully")
 
@@ -114,7 +127,7 @@ class AIFactory:
 
     def create_service(self, provider_name: str, model_type: ModelType,
                        model_name: str, config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Create a service instance"""
+        """Create a service instance with provider-managed configuration"""
         try:
             cache_key = f"{provider_name}_{model_type}_{model_name}"
 
@@ -133,8 +146,8 @@ class AIFactory:
                     f"No service registered for provider '{provider_name}' and model type '{model_type}'"
                 )
 
-            # Create provider and service
-            provider = provider_class(config=config or {})
+            # Create provider with user config (provider handles .env loading)
+            provider = provider_class(config=config)
             service = service_class(provider=provider, model_name=model_name)
 
             self._cached_services[cache_key] = service
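Dropping `config or {}` in favor of passing `config` straight through, even when it is None, shifts responsibility for credentials onto the provider, consistent with the base_provider.py changes (+51 -4) that are not shown in this excerpt. A hypothetical sketch of what provider-side .env loading could look like (the class and key names here are illustrative, not the package's actual API):

    import os
    from typing import Any, Dict, Optional

    from dotenv import load_dotenv  # python-dotenv, assumed available

    class ExampleProvider:  # hypothetical provider, for illustration only
        def __init__(self, config: Optional[Dict[str, Any]] = None):
            load_dotenv()  # pull OPENAI_API_KEY etc. from a local .env
            env_config = {"api_key": os.getenv("OPENAI_API_KEY")}
            # Explicit user config overrides environment-derived values
            self.config = {**env_config, **(config or {})}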
@@ -144,81 +157,254 @@ class AIFactory:
             logger.error(f"Error creating service: {e}")
             raise
 
-    # Convenient methods for common services
-    def get_llm_service(self, model_name: str = "llama3.1", provider: str = "ollama",
+    # Convenient methods for common services with updated defaults
+    def get_llm_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
                         config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get a LLM service instance
+        Get a LLM service instance with automatic defaults
 
         Args:
-            model_name: Name of the model to use
-            provider: Provider name ('ollama', 'openai')
-            config: Optional configuration dictionary
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-nano", Ollama="llama3.2:3b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+                    Can include: streaming=True/False, temperature, max_tokens, etc.
 
         Returns:
             LLM service instance
         """
-        return self.create_service(provider, ModelType.LLM, model_name, config)
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-nano"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2:3b-instruct-fp16"
+            final_provider = provider
+        else:
+            # Default provider selection - OpenAI with cheapest model
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-nano"
+            else:
+                final_model_name = model_name or "llama3.2:3b-instruct-fp16"
+
+        return self.create_service(final_provider, ModelType.LLM, final_model_name, config)
 
-    def get_embedding_service(self, model_name: str = "bge-m3", provider: str = "ollama",
+    def get_embedding_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
                               config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get an embedding service instance
+        Get an embedding service instance with automatic defaults
 
         Args:
-            model_name: Name of the model to use
-            provider: Provider name ('ollama')
-            config: Optional configuration dictionary
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
 
         Returns:
             Embedding service instance
         """
-        return self.create_service(provider, ModelType.EMBEDDING, model_name, config)
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "text-embedding-3-small"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "bge-m3"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "text-embedding-3-small"
+            else:
+                final_model_name = model_name or "bge-m3"
+
+        return self.create_service(final_provider, ModelType.EMBEDDING, final_model_name, config)
 
-    def get_vision_service(self, model_name: str, provider: str,
-                           config: Optional[Dict[str, Any]] = None) -> BaseService:
+    def get_vision_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                           config: Optional[Dict[str, Any]] = None) -> BaseVisionService:
         """
-        Get a vision service instance
+        Get a vision service instance with automatic defaults
 
         Args:
-            model_name: Name of the model to use
-            provider: Provider name ('ollama', 'replicate')
-            config: Optional configuration dictionary
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-mini", Ollama="gemma3:4b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
 
         Returns:
             Vision service instance
         """
-        return self.create_service(provider, ModelType.VISION, model_name, config)
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-mini"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2-vision:latest"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-mini"
+            else:
+                final_model_name = model_name or "llama3.2-vision:latest"
+
+        return cast(BaseVisionService, self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_image_generation_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                                     config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
+        """
+        Get an image generation service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "black-forest-labs/flux-schnell" for production)
+            provider: Provider name (defaults to 'replicate')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Image generation service instance
+        """
+        # Set defaults based on provider
+        final_provider = provider or "replicate"
+        if final_provider == "replicate":
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        else:
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))
 
-    def get_image_generation_service(self, model_name: str, provider: str = "replicate",
-                                     config: Optional[Dict[str, Any]] = None) -> BaseService:
+    def get_img(self, type: str = "t2i", model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
         """
-        Get an image generation service instance
+        Get an image generation service with type-specific defaults
 
         Args:
-            model_name: Name of the model to use (e.g., "stability-ai/sdxl")
-            provider: Provider name ('replicate')
+            type: Image generation type:
+                - "t2i" (text-to-image): Uses flux-schnell ($3 per 1000 images)
+                - "i2i" (image-to-image): Uses flux-kontext-pro ($0.04 per image)
+            model_name: Optional model name override
+            provider: Provider name (defaults to 'replicate')
             config: Optional configuration dictionary
 
         Returns:
             Image generation service instance
+
+        Usage:
+            # Text-to-image (default)
+            img_service = AIFactory().get_img()
+            img_service = AIFactory().get_img(type="t2i")
+
+            # Image-to-image
+            img_service = AIFactory().get_img(type="i2i")
+
+            # Custom model
+            img_service = AIFactory().get_img(type="t2i", model_name="custom-model")
         """
-        return self.create_service(provider, ModelType.VISION, model_name, config)
+        # Set defaults based on type
+        final_provider = provider or "replicate"
+
+        if type == "t2i":
+            # Text-to-image: flux-schnell
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        elif type == "i2i":
+            # Image-to-image: flux-kontext-pro
+            final_model_name = model_name or "black-forest-labs/flux-kontext-pro"
+        else:
+            raise ValueError(f"Unknown image generation type: {type}. Use 't2i' or 'i2i'")
+
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))
 
-    def get_audio_service(self, model_name: str = "tts-1", provider: str = "openai",
+    def get_audio_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
                           config: Optional[Dict[str, Any]] = None) -> BaseService:
         """
-        Get an audio service instance
+        Get an audio service instance (TTS) with automatic defaults
 
         Args:
-            model_name: Name of the model to use
-            provider: Provider name ('openai')
-            config: Optional configuration dictionary
+            model_name: Name of the model to use (defaults: OpenAI="tts-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
 
         Returns:
             Audio service instance
         """
-        return self.create_service(provider, ModelType.AUDIO, model_name, config)
+        # Set defaults based on provider
+        final_provider = provider or "openai"
+        if final_provider == "openai":
+            final_model_name = model_name or "tts-1"
+        else:
+            final_model_name = model_name or "tts-1"
+
+        return self.create_service(final_provider, ModelType.AUDIO, final_model_name, config)
+
+    def get_tts_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
+        """
+        Get a Text-to-Speech service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+        """
+        # Set defaults based on provider
+        if provider == "replicate":
+            model_name = model_name or "kokoro-82m"
+        elif provider == "openai":
+            model_name = model_name or "tts-1"
+        else:
+            # Default provider selection
+            provider = provider or "replicate"
+            if provider == "replicate":
+                model_name = model_name or "kokoro-82m"
+            else:
+                model_name = model_name or "tts-1"
+
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "tts-1"
+
+        if provider == "replicate":
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
+            from isa_model.inference.providers.replicate_provider import ReplicateProvider
+
+            # Use full model name for Replicate
+            if model_name == "kokoro-82m":
+                model_name = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"
+
+            provider_instance = ReplicateProvider(config=config)
+            return ReplicateTTSService(provider=provider_instance, model_name=model_name)
+        else:
+            return cast('BaseTTSService', self.get_audio_service(model_name, provider, config))
+
+    def get_stt_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
+        """
+        Get a Speech-to-Text service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            STT service instance
+        """
+        # Set defaults based on provider
+        provider = provider or "openai"
+        if provider == "openai":
+            model_name = model_name or "whisper-1"
+
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "whisper-1"
+
+        from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+        from isa_model.inference.providers.openai_provider import OpenAIProvider
+
+        # Create provider and service directly with config
+        provider_instance = OpenAIProvider(config=config)
+        return OpenAISTTService(provider=provider_instance, model_name=model_name)
 
     def get_available_services(self) -> Dict[str, List[str]]:
         """Get information about available services"""
@@ -241,16 +427,90 @@ class AIFactory:
             cls._instance = cls()
         return cls._instance
 
-    # Alias methods for backward compatibility with tests
-    def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
+    # Alias method for cleaner API
+    def get_llm(self, model_name: Optional[str] = None, provider: Optional[str] = None,
                 config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Alias for get_llm_service"""
+        """
+        Alias for get_llm_service with cleaner naming
+
+        Usage:
+            llm = AIFactory().get_llm()  # Uses gpt-4.1-nano by default
+            llm = AIFactory().get_llm(model_name="llama3.2", provider="ollama")
+            llm = AIFactory().get_llm(model_name="gpt-4.1-mini", provider="openai", config={"streaming": True})
+        """
         return self.get_llm_service(model_name, provider, config)
 
-    def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
-                      config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Alias for get_embedding_service"""
+    def get_embed(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                  config: Optional[Dict[str, Any]] = None) -> 'BaseEmbedService':
+        """
+        Get embedding service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Embedding service instance
+
+        Usage:
+            # Default (OpenAI text-embedding-3-small)
+            embed = AIFactory().get_embed()
+
+            # Custom model
+            embed = AIFactory().get_embed(model_name="text-embedding-3-large", provider="openai")
+
+            # Development (Ollama)
+            embed = AIFactory().get_embed(provider="ollama")
+        """
         return self.get_embedding_service(model_name, provider, config)
+
+    def get_stt(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
+        """
+        Get Speech-to-Text service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            STT service instance
+
+        Usage:
+            # Default (OpenAI whisper-1)
+            stt = AIFactory().get_stt()
+
+            # Custom configuration
+            stt = AIFactory().get_stt(model_name="whisper-1", provider="openai")
+        """
+        return self.get_stt_service(model_name, provider, config)
+
+    def get_tts(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
+        """
+        Get Text-to-Speech service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+
+        Usage:
+            # Default (Replicate kokoro-82m)
+            tts = AIFactory().get_tts()
+
+            # Development (OpenAI tts-1)
+            tts = AIFactory().get_tts(provider="openai")
+
+            # Custom model
+            tts = AIFactory().get_tts(model_name="tts-1-hd", provider="openai")
+        """
+        return self.get_tts_service(model_name, provider, config)
 
     def get_vision_model(self, model_name: str, provider: str,
                          config: Optional[Dict[str, Any]] = None) -> BaseService:
@@ -258,4 +518,33 @@ class AIFactory:
         if provider == "replicate":
             return self.get_image_generation_service(model_name, provider, config)
         else:
-            return self.get_vision_service(model_name, provider, config)
+            return self.get_vision_service(model_name, provider, config)
+
+    def get_vision(
+        self,
+        model_name: Optional[str] = None,
+        provider: Optional[str] = None,
+        config: Optional[Dict[str, Any]] = None
+    ) -> 'BaseVisionService':
+        """
+        Get vision service with automatic defaults
+
+        Args:
+            model_name: Model name (default: gpt-4.1-nano)
+            provider: Provider name (default: openai)
+            config: Optional configuration override
+
+        Returns:
+            Vision service instance
+        """
+        # Set defaults
+        if provider is None:
+            provider = "openai"
+        if model_name is None:
+            model_name = "gpt-4.1-nano"
+
+        return self.get_vision_service(
+            model_name=model_name,
+            provider=provider,
+            config=config
+        )
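The new get_vision alias completes the short-name getter family; a one-line usage sketch (its default model is gpt-4.1-nano, whereas calling get_vision_service directly defaults to gpt-4.1-mini):

    vision = AIFactory().get_vision()  # openai / gpt-4.1-nano, per this alias's own defaults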