isa-model 0.0.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff shows the changes between two package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (93)
  1. isa_model/__init__.py +1 -1
  2. isa_model/core/model_manager.py +69 -4
  3. isa_model/core/model_registry.py +273 -46
  4. isa_model/core/storage/hf_storage.py +419 -0
  5. isa_model/deployment/__init__.py +52 -0
  6. isa_model/deployment/core/__init__.py +34 -0
  7. isa_model/deployment/core/deployment_config.py +356 -0
  8. isa_model/deployment/core/deployment_manager.py +549 -0
  9. isa_model/deployment/core/isa_deployment_service.py +401 -0
  10. isa_model/eval/factory.py +381 -140
  11. isa_model/inference/ai_factory.py +427 -236
  12. isa_model/inference/billing_tracker.py +406 -0
  13. isa_model/inference/providers/base_provider.py +51 -4
  14. isa_model/inference/providers/ml_provider.py +50 -0
  15. isa_model/inference/providers/ollama_provider.py +37 -18
  16. isa_model/inference/providers/openai_provider.py +65 -36
  17. isa_model/inference/providers/replicate_provider.py +42 -30
  18. isa_model/inference/services/audio/base_stt_service.py +21 -2
  19. isa_model/inference/services/audio/openai_realtime_service.py +353 -0
  20. isa_model/inference/services/audio/openai_stt_service.py +252 -0
  21. isa_model/inference/services/audio/openai_tts_service.py +149 -9
  22. isa_model/inference/services/audio/replicate_tts_service.py +239 -0
  23. isa_model/inference/services/base_service.py +36 -1
  24. isa_model/inference/services/embedding/base_embed_service.py +112 -0
  25. isa_model/inference/services/embedding/ollama_embed_service.py +28 -2
  26. isa_model/inference/services/embedding/openai_embed_service.py +223 -0
  27. isa_model/inference/services/llm/__init__.py +2 -0
  28. isa_model/inference/services/llm/base_llm_service.py +158 -86
  29. isa_model/inference/services/llm/llm_adapter.py +414 -0
  30. isa_model/inference/services/llm/ollama_llm_service.py +252 -63
  31. isa_model/inference/services/llm/openai_llm_service.py +231 -93
  32. isa_model/inference/services/llm/triton_llm_service.py +481 -0
  33. isa_model/inference/services/ml/base_ml_service.py +78 -0
  34. isa_model/inference/services/ml/sklearn_ml_service.py +140 -0
  35. isa_model/inference/services/vision/__init__.py +3 -3
  36. isa_model/inference/services/vision/base_image_gen_service.py +161 -0
  37. isa_model/inference/services/vision/base_vision_service.py +177 -0
  38. isa_model/inference/services/vision/helpers/image_utils.py +4 -3
  39. isa_model/inference/services/vision/ollama_vision_service.py +151 -17
  40. isa_model/inference/services/vision/openai_vision_service.py +275 -41
  41. isa_model/inference/services/vision/replicate_image_gen_service.py +278 -118
  42. isa_model/training/__init__.py +62 -32
  43. isa_model/training/cloud/__init__.py +22 -0
  44. isa_model/training/cloud/job_orchestrator.py +402 -0
  45. isa_model/training/cloud/runpod_trainer.py +454 -0
  46. isa_model/training/cloud/storage_manager.py +482 -0
  47. isa_model/training/core/__init__.py +23 -0
  48. isa_model/training/core/config.py +181 -0
  49. isa_model/training/core/dataset.py +222 -0
  50. isa_model/training/core/trainer.py +720 -0
  51. isa_model/training/core/utils.py +213 -0
  52. isa_model/training/factory.py +229 -198
  53. isa_model-0.3.1.dist-info/METADATA +465 -0
  54. isa_model-0.3.1.dist-info/RECORD +91 -0
  55. isa_model/core/model_router.py +0 -226
  56. isa_model/core/model_version.py +0 -0
  57. isa_model/core/resource_manager.py +0 -202
  58. isa_model/deployment/gpu_fp16_ds8/models/deepseek_r1/1/model.py +0 -120
  59. isa_model/deployment/gpu_fp16_ds8/scripts/download_model.py +0 -18
  60. isa_model/training/engine/llama_factory/__init__.py +0 -39
  61. isa_model/training/engine/llama_factory/config.py +0 -115
  62. isa_model/training/engine/llama_factory/data_adapter.py +0 -284
  63. isa_model/training/engine/llama_factory/examples/__init__.py +0 -6
  64. isa_model/training/engine/llama_factory/examples/finetune_with_tracking.py +0 -185
  65. isa_model/training/engine/llama_factory/examples/rlhf_with_tracking.py +0 -163
  66. isa_model/training/engine/llama_factory/factory.py +0 -331
  67. isa_model/training/engine/llama_factory/rl.py +0 -254
  68. isa_model/training/engine/llama_factory/trainer.py +0 -171
  69. isa_model/training/image_model/configs/create_config.py +0 -37
  70. isa_model/training/image_model/configs/create_flux_config.py +0 -26
  71. isa_model/training/image_model/configs/create_lora_config.py +0 -21
  72. isa_model/training/image_model/prepare_massed_compute.py +0 -97
  73. isa_model/training/image_model/prepare_upload.py +0 -17
  74. isa_model/training/image_model/raw_data/create_captions.py +0 -16
  75. isa_model/training/image_model/raw_data/create_lora_captions.py +0 -20
  76. isa_model/training/image_model/raw_data/pre_processing.py +0 -200
  77. isa_model/training/image_model/train/train.py +0 -42
  78. isa_model/training/image_model/train/train_flux.py +0 -41
  79. isa_model/training/image_model/train/train_lora.py +0 -57
  80. isa_model/training/image_model/train_main.py +0 -25
  81. isa_model-0.0.2.dist-info/METADATA +0 -327
  82. isa_model-0.0.2.dist-info/RECORD +0 -92
  83. isa_model-0.0.2.dist-info/licenses/LICENSE +0 -21
  84. /isa_model/training/{llm_model/annotation → annotation}/annotation_schema.py +0 -0
  85. /isa_model/training/{llm_model/annotation → annotation}/processors/annotation_processor.py +0 -0
  86. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_manager.py +0 -0
  87. /isa_model/training/{llm_model/annotation → annotation}/storage/dataset_schema.py +0 -0
  88. /isa_model/training/{llm_model/annotation → annotation}/tests/test_annotation_flow.py +0 -0
  89. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio copy.py +0 -0
  90. /isa_model/training/{llm_model/annotation → annotation}/tests/test_minio_upload.py +0 -0
  91. /isa_model/training/{llm_model/annotation → annotation}/views/annotation_controller.py +0 -0
  92. {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/WHEEL +0 -0
  93. {isa_model-0.0.2.dist-info → isa_model-0.3.1.dist-info}/top_level.txt +0 -0
--- a/isa_model/inference/ai_factory.py
+++ b/isa_model/inference/ai_factory.py
@@ -1,17 +1,29 @@
-from typing import Dict, Type, Any, Optional, Tuple
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+"""
+Simplified AI Factory for creating inference services
+Uses the new service architecture with proper base classes and centralized API key management
+"""
+
+from typing import Dict, Type, Any, Optional, Tuple, List, TYPE_CHECKING, cast
 import logging
 from isa_model.inference.providers.base_provider import BaseProvider
 from isa_model.inference.services.base_service import BaseService
 from isa_model.inference.base import ModelType
-import os
+from isa_model.inference.services.vision.base_vision_service import BaseVisionService
+from isa_model.inference.services.vision.base_image_gen_service import BaseImageGenService
+
+if TYPE_CHECKING:
+    from isa_model.inference.services.audio.base_stt_service import BaseSTTService
+    from isa_model.inference.services.audio.base_tts_service import BaseTTSService
 
-# Set up basic logging configuration
-logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 class AIFactory:
     """
-    Factory for creating AI services based on the Single Model pattern.
+    Simplified Factory for creating AI services with proper inheritance hierarchy
+    API key management is handled by individual providers
     """
 
     _instance = None
@@ -24,72 +36,85 @@ class AIFactory:
 
     def __init__(self):
         """Initialize the AI Factory."""
-        self.triton_url = os.environ.get("TRITON_URL", "http://localhost:8000")
-
-        # Cache for services (singleton pattern)
-        self._llm_services = {}
-        self._embedding_services = {}
-        self._speech_services = {}
-
         if not self._is_initialized:
             self._providers: Dict[str, Type[BaseProvider]] = {}
             self._services: Dict[Tuple[str, ModelType], Type[BaseService]] = {}
             self._cached_services: Dict[str, BaseService] = {}
-            self._initialize_defaults()
+            self._initialize_services()
             AIFactory._is_initialized = True
 
-    def _initialize_defaults(self):
-        """Initialize default providers and services"""
+    def _initialize_services(self):
+        """Initialize available providers and services"""
+        try:
+            # Register Ollama services
+            self._register_ollama_services()
+
+            # Register OpenAI services
+            self._register_openai_services()
+
+            # Register Replicate services
+            self._register_replicate_services()
+
+            logger.info("AI Factory initialized with centralized provider API key management")
+
+        except Exception as e:
+            logger.error(f"Error initializing services: {e}")
+            logger.warning("Some services may not be available")
+
+    def _register_ollama_services(self):
+        """Register Ollama provider and services"""
         try:
-            # Import providers and services
             from isa_model.inference.providers.ollama_provider import OllamaProvider
-            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
             from isa_model.inference.services.llm.ollama_llm_service import OllamaLLMService
+            from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService
+            from isa_model.inference.services.vision.ollama_vision_service import OllamaVisionService
 
-            # Register Ollama provider and services
             self.register_provider('ollama', OllamaProvider)
-            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
             self.register_service('ollama', ModelType.LLM, OllamaLLMService)
+            self.register_service('ollama', ModelType.EMBEDDING, OllamaEmbedService)
+            self.register_service('ollama', ModelType.VISION, OllamaVisionService)
 
-            # Register OpenAI provider and services
-            try:
-                from isa_model.inference.providers.openai_provider import OpenAIProvider
-                from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
-
-                self.register_provider('openai', OpenAIProvider)
-                self.register_service('openai', ModelType.LLM, OpenAILLMService)
-                logger.info("OpenAI services registered successfully")
-            except ImportError as e:
-                logger.warning(f"OpenAI services not available: {e}")
-
-            # Register Replicate provider and services
-            try:
-                from isa_model.inference.providers.replicate_provider import ReplicateProvider
-                from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateVisionService
-
-                self.register_provider('replicate', ReplicateProvider)
-                self.register_service('replicate', ModelType.VISION, ReplicateVisionService)
-                logger.info("Replicate provider and vision service registered successfully")
-            except ImportError as e:
-                logger.warning(f"Replicate services not available: {e}")
-            except Exception as e:
-                logger.warning(f"Error registering Replicate services: {e}")
-
-            # Try to register Triton services
-            try:
-                from isa_model.inference.providers.triton_provider import TritonProvider
-
-                self.register_provider('triton', TritonProvider)
-                logger.info("Triton provider registered successfully")
-
-            except ImportError as e:
-                logger.warning(f"Triton provider not available: {e}")
-
-            logger.info("Default AI providers and services initialized with backend architecture")
-        except Exception as e:
-            logger.error(f"Error initializing default providers and services: {e}")
-            # Don't raise - allow factory to work even if some services fail to load
-            logger.warning("Some services may not be available due to import errors")
+            logger.info("Ollama services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"Ollama services not available: {e}")
+
+    def _register_openai_services(self):
+        """Register OpenAI provider and services"""
+        try:
+            from isa_model.inference.providers.openai_provider import OpenAIProvider
+            from isa_model.inference.services.llm.openai_llm_service import OpenAILLMService
+            from isa_model.inference.services.audio.openai_tts_service import OpenAITTSService
+            from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+            from isa_model.inference.services.embedding.openai_embed_service import OpenAIEmbedService
+            from isa_model.inference.services.vision.openai_vision_service import OpenAIVisionService
+
+            self.register_provider('openai', OpenAIProvider)
+            self.register_service('openai', ModelType.LLM, OpenAILLMService)
+            self.register_service('openai', ModelType.AUDIO, OpenAITTSService)
+            self.register_service('openai', ModelType.EMBEDDING, OpenAIEmbedService)
+            self.register_service('openai', ModelType.VISION, OpenAIVisionService)
+
+            logger.info("OpenAI services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"OpenAI services not available: {e}")
+
+    def _register_replicate_services(self):
+        """Register Replicate provider and services"""
+        try:
+            from isa_model.inference.providers.replicate_provider import ReplicateProvider
+            from isa_model.inference.services.vision.replicate_image_gen_service import ReplicateImageGenService
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
+
+            self.register_provider('replicate', ReplicateProvider)
+            self.register_service('replicate', ModelType.VISION, ReplicateImageGenService)
+            self.register_service('replicate', ModelType.AUDIO, ReplicateTTSService)
+
+            logger.info("Replicate services registered successfully")
+
+        except ImportError as e:
+            logger.warning(f"Replicate services not available: {e}")
 
     def register_provider(self, name: str, provider_class: Type[BaseProvider]) -> None:
         """Register an AI provider"""
@@ -102,31 +127,27 @@
 
     def create_service(self, provider_name: str, model_type: ModelType,
                        model_name: str, config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Create a service instance"""
+        """Create a service instance with provider-managed configuration"""
         try:
             cache_key = f"{provider_name}_{model_type}_{model_name}"
 
             if cache_key in self._cached_services:
                 return self._cached_services[cache_key]
 
-            # Base configuration
-            base_config = {
-                "log_level": "INFO"
-            }
-
-            # Merge configurations
-            service_config = {**base_config, **(config or {})}
-
-            # Create provider and service
-            provider_class = self._providers[provider_name]
+            # Get provider and service classes
+            provider_class = self._providers.get(provider_name)
             service_class = self._services.get((provider_name, model_type))
 
+            if not provider_class:
+                raise ValueError(f"No provider registered for '{provider_name}'")
+
             if not service_class:
                 raise ValueError(
-                    f"No service registered for provider {provider_name} and model type {model_type}"
+                    f"No service registered for provider '{provider_name}' and model type '{model_type}'"
                 )
 
-            provider = provider_class(config=service_config)
+            # Create provider with user config (provider handles .env loading)
+            provider = provider_class(config=config)
             service = service_class(provider=provider, model_name=model_name)
 
             self._cached_services[cache_key] = service
@@ -136,224 +157,394 @@
             logger.error(f"Error creating service: {e}")
             raise
 
-    # Convenient methods for common services
-    def get_llm(self, model_name: str = "llama3.1", provider: str = "ollama",
-                config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
-        """
-        Get a LLM service instance
-
-        Args:
-            model_name: Name of the model to use
-            provider: Provider name ('ollama', 'openai', 'replicate', etc.)
-            config: Optional configuration dictionary
-            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
-
-        Returns:
-            LLM service instance
-
-        Example:
-            # Using with API key directly
-            llm = AIFactory.get_instance().get_llm(
-                model_name="gpt-4o-mini",
-                provider="openai",
-                api_key="your-api-key-here"
-            )
-
-            # Using without API key (will use environment variable)
-            llm = AIFactory.get_instance().get_llm(
-                model_name="gpt-4o-mini",
-                provider="openai"
-            )
-        """
-
-        # Special case for DeepSeek service
-        if model_name.lower() in ["deepseek", "deepseek-r1", "qwen3-8b"]:
-            if "deepseek" in self._cached_services:
-                return self._cached_services["deepseek"]
-
-        # Special case for Llama3-8B direct service
-        if model_name.lower() in ["llama3", "llama3-8b", "meta-llama-3"]:
-            if "llama3" in self._cached_services:
-                return self._cached_services["llama3"]
-
-        basic_config: Dict[str, Any] = {
-            "temperature": 0
-        }
-
-        # Add API key to config if provided
-        if api_key:
-            if provider == "openai":
-                basic_config["api_key"] = api_key
-            elif provider == "replicate":
-                basic_config["api_token"] = api_key
-            else:
-                logger.warning(f"API key provided but provider '{provider}' may not support it")
-
-        if config:
-            basic_config.update(config)
-        return self.create_service(provider, ModelType.LLM, model_name, basic_config)
-
-    def get_vision_model(self, model_name: str = "gemma3-4b", provider: str = "triton",
-                         config: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None) -> BaseService:
-        """
-        Get a vision model service instance
-
-        Args:
-            model_name: Name of the model to use
-            provider: Provider name ('openai', 'replicate', 'triton', etc.)
-            config: Optional configuration dictionary
-            api_key: Optional API key for the provider (OpenAI, Replicate, etc.)
-
-        Returns:
-            Vision service instance
-
-        Example:
-            # Using with API key directly
-            vision = AIFactory.get_instance().get_vision_model(
-                model_name="gpt-4o",
-                provider="openai",
-                api_key="your-api-key-here"
-            )
-
-            # Using Replicate for image generation
-            image_gen = AIFactory.get_instance().get_vision_model(
-                model_name="stability-ai/sdxl",
-                provider="replicate",
-                api_key="your-replicate-token"
-            )
-        """
-
-        # Special case for Gemma3-4B direct service
-        if model_name.lower() in ["gemma3", "gemma3-4b", "gemma3-vision"]:
-            if "gemma3" in self._cached_services:
-                return self._cached_services["gemma3"]
-
-        # Special case for Replicate's image generation models
-        if provider == "replicate" and "/" in model_name:
-            replicate_config: Dict[str, Any] = {
-                "guidance_scale": 7.5,
-                "num_inference_steps": 30
-            }
-
-            # Add API key if provided
-            if api_key:
-                replicate_config["api_token"] = api_key
-
-            if config:
-                replicate_config.update(config)
-            return self.create_service(provider, ModelType.VISION, model_name, replicate_config)
-
-        basic_config: Dict[str, Any] = {
-            "temperature": 0.7,
-            "max_new_tokens": 512
-        }
-
-        # Add API key to config if provided
-        if api_key:
-            if provider == "openai":
-                basic_config["api_key"] = api_key
-            elif provider == "replicate":
-                basic_config["api_token"] = api_key
-            else:
-                logger.warning(f"API key provided but provider '{provider}' may not support it")
-
-        if config:
-            basic_config.update(config)
-        return self.create_service(provider, ModelType.VISION, model_name, basic_config)
-
-    def get_embedding(self, model_name: str = "bge-m3", provider: str = "ollama",
-                      config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Get an embedding service instance"""
-        return self.create_service(provider, ModelType.EMBEDDING, model_name, config)
-
-    def get_rerank(self, model_name: str = "bge-m3", provider: str = "ollama",
-                   config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Get a rerank service instance"""
-        return self.create_service(provider, ModelType.RERANK, model_name, config)
-
-    def get_embed_service(self, model_name: str = "bge-m3", provider: str = "ollama",
-                          config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Get an embedding service instance"""
-        return self.get_embedding(model_name, provider, config)
-
-    def get_speech_model(self, model_name: str = "whisper_tiny", provider: str = "triton",
-                         config: Optional[Dict[str, Any]] = None) -> BaseService:
-        """Get a speech-to-text model service instance"""
-
-        # Special case for Whisper Tiny direct service
-        if model_name.lower() in ["whisper", "whisper_tiny", "whisper-tiny"]:
-            if "whisper" in self._cached_services:
-                return self._cached_services["whisper"]
-
-        basic_config = {
-            "language": "en",
-            "task": "transcribe"
-        }
-        if config:
-            basic_config.update(config)
-        return self.create_service(provider, ModelType.AUDIO, model_name, basic_config)
-
-    async def get_embedding_service(self, model_name: str) -> Any:
-        """
-        Get an embedding service for the specified model.
-
-        Args:
-            model_name: Name of the model
-
-        Returns:
-            Embedding service instance
-        """
-        if model_name in self._embedding_services:
-            return self._embedding_services[model_name]
-
-        else:
-            raise ValueError(f"Unsupported embedding model: {model_name}")
-
-    async def get_speech_service(self, model_name: str) -> Any:
-        """
-        Get a speech service for the specified model.
-
-        Args:
-            model_name: Name of the model
-
-        Returns:
-            Speech service instance
-        """
-        if model_name in self._speech_services:
-            return self._speech_services[model_name]
-
-
-    def get_model_info(self, model_type: Optional[str] = None) -> Dict[str, Any]:
-        """
-        Get information about available models.
-
-        Args:
-            model_type: Optional filter for model type
-
-        Returns:
-            Dict of model information
-        """
-        models = {
-            "llm": [
-                {"name": "deepseek", "description": "DeepSeek-R1-0528-Qwen3-8B language model"},
-                {"name": "llama", "description": "Llama3-8B language model"},
-                {"name": "gemma", "description": "Gemma3-4B language model"}
-            ],
-            "embedding": [
-                {"name": "bge_embed", "description": "BGE-M3 text embedding model"}
-            ],
-            "speech": [
-                {"name": "whisper", "description": "Whisper-tiny speech-to-text model"}
-            ]
-        }
-
-        if model_type:
-            return {model_type: models.get(model_type, [])}
-        return models
-
-    @classmethod
-    def get_instance(cls) -> 'AIFactory':
-        """Get the singleton instance"""
-        if cls._instance is None:
-            cls._instance = cls()
-        return cls._instance
+    # Convenient methods for common services with updated defaults
+    def get_llm_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Get a LLM service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-nano", Ollama="llama3.2:3b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+                    Can include: streaming=True/False, temperature, max_tokens, etc.
+
+        Returns:
+            LLM service instance
+        """
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-nano"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2:3b-instruct-fp16"
+            final_provider = provider
+        else:
+            # Default provider selection - OpenAI with cheapest model
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-nano"
+            else:
+                final_model_name = model_name or "llama3.2:3b-instruct-fp16"
+
+        return self.create_service(final_provider, ModelType.LLM, final_model_name, config)
+
+    def get_embedding_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                              config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Get an embedding service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Embedding service instance
+        """
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "text-embedding-3-small"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "bge-m3"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "text-embedding-3-small"
+            else:
+                final_model_name = model_name or "bge-m3"
+
+        return self.create_service(final_provider, ModelType.EMBEDDING, final_model_name, config)
+
+    def get_vision_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                           config: Optional[Dict[str, Any]] = None) -> BaseVisionService:
+        """
+        Get a vision service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="gpt-4.1-mini", Ollama="gemma3:4b")
+            provider: Provider name (defaults to 'openai' for production, 'ollama' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Vision service instance
+        """
+        # Set defaults based on provider
+        if provider == "openai":
+            final_model_name = model_name or "gpt-4.1-mini"
+            final_provider = provider
+        elif provider == "ollama":
+            final_model_name = model_name or "llama3.2-vision:latest"
+            final_provider = provider
+        else:
+            # Default provider selection
+            final_provider = provider or "openai"
+            if final_provider == "openai":
+                final_model_name = model_name or "gpt-4.1-mini"
+            else:
+                final_model_name = model_name or "llama3.2-vision:latest"
+
+        return cast(BaseVisionService, self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_image_generation_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                                     config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
+        """
+        Get an image generation service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "black-forest-labs/flux-schnell" for production)
+            provider: Provider name (defaults to 'replicate')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Image generation service instance
+        """
+        # Set defaults based on provider
+        final_provider = provider or "replicate"
+        if final_provider == "replicate":
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        else:
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_img(self, type: str = "t2i", model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseImageGenService':
+        """
+        Get an image generation service with type-specific defaults
+
+        Args:
+            type: Image generation type:
+                - "t2i" (text-to-image): Uses flux-schnell ($3 per 1000 images)
+                - "i2i" (image-to-image): Uses flux-kontext-pro ($0.04 per image)
+            model_name: Optional model name override
+            provider: Provider name (defaults to 'replicate')
+            config: Optional configuration dictionary
+
+        Returns:
+            Image generation service instance
+
+        Usage:
+            # Text-to-image (default)
+            img_service = AIFactory().get_img()
+            img_service = AIFactory().get_img(type="t2i")
+
+            # Image-to-image
+            img_service = AIFactory().get_img(type="i2i")
+
+            # Custom model
+            img_service = AIFactory().get_img(type="t2i", model_name="custom-model")
+        """
+        # Set defaults based on type
+        final_provider = provider or "replicate"
+
+        if type == "t2i":
+            # Text-to-image: flux-schnell
+            final_model_name = model_name or "black-forest-labs/flux-schnell"
+        elif type == "i2i":
+            # Image-to-image: flux-kontext-pro
+            final_model_name = model_name or "black-forest-labs/flux-kontext-pro"
+        else:
+            raise ValueError(f"Unknown image generation type: {type}. Use 't2i' or 'i2i'")
+
+        return cast('BaseImageGenService', self.create_service(final_provider, ModelType.VISION, final_model_name, config))
+
+    def get_audio_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                          config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Get an audio service instance (TTS) with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="tts-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Audio service instance
+        """
+        # Set defaults based on provider
+        final_provider = provider or "openai"
+        if final_provider == "openai":
+            final_model_name = model_name or "tts-1"
+        else:
+            final_model_name = model_name or "tts-1"
+
+        return self.create_service(final_provider, ModelType.AUDIO, final_model_name, config)
+
+    def get_tts_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
+        """
+        Get a Text-to-Speech service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+        """
+        # Set defaults based on provider
+        if provider == "replicate":
+            model_name = model_name or "kokoro-82m"
+        elif provider == "openai":
+            model_name = model_name or "tts-1"
+        else:
+            # Default provider selection
+            provider = provider or "replicate"
+            if provider == "replicate":
+                model_name = model_name or "kokoro-82m"
+            else:
+                model_name = model_name or "tts-1"
+
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "tts-1"
+
+        if provider == "replicate":
+            from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService
+            from isa_model.inference.providers.replicate_provider import ReplicateProvider
+
+            # Use full model name for Replicate
+            if model_name == "kokoro-82m":
+                model_name = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"
+
+            provider_instance = ReplicateProvider(config=config)
+            return ReplicateTTSService(provider=provider_instance, model_name=model_name)
+        else:
+            return cast('BaseTTSService', self.get_audio_service(model_name, provider, config))
+
+    def get_stt_service(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                        config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
+        """
+        Get a Speech-to-Text service instance with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            STT service instance
+        """
+        # Set defaults based on provider
+        provider = provider or "openai"
+        if provider == "openai":
+            model_name = model_name or "whisper-1"
+
+        # Ensure model_name is never None
+        if model_name is None:
+            model_name = "whisper-1"
+
+        from isa_model.inference.services.audio.openai_stt_service import OpenAISTTService
+        from isa_model.inference.providers.openai_provider import OpenAIProvider
+
+        # Create provider and service directly with config
+        provider_instance = OpenAIProvider(config=config)
+        return OpenAISTTService(provider=provider_instance, model_name=model_name)
+
+    def get_available_services(self) -> Dict[str, List[str]]:
+        """Get information about available services"""
+        services = {}
+        for (provider, model_type), service_class in self._services.items():
+            if provider not in services:
+                services[provider] = []
+            services[provider].append(f"{model_type.value}: {service_class.__name__}")
+        return services
+
+    def clear_cache(self):
+        """Clear the service cache"""
+        self._cached_services.clear()
+        logger.info("Service cache cleared")
+
+    @classmethod
+    def get_instance(cls) -> 'AIFactory':
+        """Get the singleton instance"""
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._instance
+
+    # Alias method for cleaner API
+    def get_llm(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """
+        Alias for get_llm_service with cleaner naming
+
+        Usage:
+            llm = AIFactory().get_llm()  # Uses gpt-4.1-nano by default
+            llm = AIFactory().get_llm(model_name="llama3.2", provider="ollama")
+            llm = AIFactory().get_llm(model_name="gpt-4.1-mini", provider="openai", config={"streaming": True})
+        """
+        return self.get_llm_service(model_name, provider, config)
+
+    def get_embed(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                  config: Optional[Dict[str, Any]] = None) -> 'BaseEmbedService':
+        """
+        Get embedding service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: OpenAI="text-embedding-3-small", Ollama="bge-m3")
+            provider: Provider name (defaults to 'openai' for production)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            Embedding service instance
+
+        Usage:
+            # Default (OpenAI text-embedding-3-small)
+            embed = AIFactory().get_embed()
+
+            # Custom model
+            embed = AIFactory().get_embed(model_name="text-embedding-3-large", provider="openai")
+
+            # Development (Ollama)
+            embed = AIFactory().get_embed(provider="ollama")
+        """
+        return self.get_embedding_service(model_name, provider, config)
+
+    def get_stt(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseSTTService':
+        """
+        Get Speech-to-Text service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: "whisper-1")
+            provider: Provider name (defaults to 'openai')
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            STT service instance
+
+        Usage:
+            # Default (OpenAI whisper-1)
+            stt = AIFactory().get_stt()
+
+            # Custom configuration
+            stt = AIFactory().get_stt(model_name="whisper-1", provider="openai")
+        """
+        return self.get_stt_service(model_name, provider, config)
+
+    def get_tts(self, model_name: Optional[str] = None, provider: Optional[str] = None,
+                config: Optional[Dict[str, Any]] = None) -> 'BaseTTSService':
+        """
+        Get Text-to-Speech service with automatic defaults
+
+        Args:
+            model_name: Name of the model to use (defaults: Replicate="kokoro-82m", OpenAI="tts-1")
+            provider: Provider name (defaults to 'replicate' for production, 'openai' for dev)
+            config: Optional configuration dictionary (auto-loads from .env if not provided)
+
+        Returns:
+            TTS service instance
+
+        Usage:
+            # Default (Replicate kokoro-82m)
+            tts = AIFactory().get_tts()
+
+            # Development (OpenAI tts-1)
+            tts = AIFactory().get_tts(provider="openai")
+
+            # Custom model
+            tts = AIFactory().get_tts(model_name="tts-1-hd", provider="openai")
+        """
+        return self.get_tts_service(model_name, provider, config)
+
+    def get_vision_model(self, model_name: str, provider: str,
+                         config: Optional[Dict[str, Any]] = None) -> BaseService:
+        """Alias for get_vision_service and get_image_generation_service"""
+        if provider == "replicate":
+            return self.get_image_generation_service(model_name, provider, config)
+        else:
+            return self.get_vision_service(model_name, provider, config)
+
+    def get_vision(
+        self,
+        model_name: Optional[str] = None,
+        provider: Optional[str] = None,
+        config: Optional[Dict[str, Any]] = None
+    ) -> 'BaseVisionService':
+        """
+        Get vision service with automatic defaults
+
+        Args:
+            model_name: Model name (default: gpt-4.1-nano)
+            provider: Provider name (default: openai)
+            config: Optional configuration override
+
+        Returns:
+            Vision service instance
+        """
+        # Set defaults
+        if provider is None:
+            provider = "openai"
+        if model_name is None:
+            model_name = "gpt-4.1-nano"
+
+        return self.get_vision_service(
+            model_name=model_name,
+            provider=provider,
+            config=config
+        )