isa-model: isa_model-0.3.4-py3-none-any.whl → isa_model-0.3.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. isa_model/__init__.py +30 -1
  2. isa_model/client.py +770 -0
  3. isa_model/core/config/__init__.py +16 -0
  4. isa_model/core/config/config_manager.py +514 -0
  5. isa_model/core/config.py +426 -0
  6. isa_model/core/models/model_billing_tracker.py +476 -0
  7. isa_model/core/models/model_manager.py +399 -0
  8. isa_model/core/models/model_repo.py +343 -0
  9. isa_model/core/pricing_manager.py +426 -0
  10. isa_model/core/services/__init__.py +19 -0
  11. isa_model/core/services/intelligent_model_selector.py +547 -0
  12. isa_model/core/types.py +291 -0
  13. isa_model/deployment/__init__.py +2 -0
  14. isa_model/deployment/cloud/__init__.py +9 -0
  15. isa_model/deployment/cloud/modal/__init__.py +10 -0
  16. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +766 -0
  17. isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
  18. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +406 -0
  19. isa_model/deployment/cloud/modal/register_models.py +321 -0
  20. isa_model/deployment/runtime/deployed_service.py +338 -0
  21. isa_model/deployment/services/__init__.py +9 -0
  22. isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
  23. isa_model/deployment/services/model_service.py +332 -0
  24. isa_model/deployment/services/service_monitor.py +356 -0
  25. isa_model/deployment/services/service_registry.py +527 -0
  26. isa_model/eval/__init__.py +80 -44
  27. isa_model/eval/config/__init__.py +10 -0
  28. isa_model/eval/config/evaluation_config.py +108 -0
  29. isa_model/eval/evaluators/__init__.py +18 -0
  30. isa_model/eval/evaluators/base_evaluator.py +503 -0
  31. isa_model/eval/evaluators/llm_evaluator.py +472 -0
  32. isa_model/eval/factory.py +417 -709
  33. isa_model/eval/infrastructure/__init__.py +24 -0
  34. isa_model/eval/infrastructure/experiment_tracker.py +466 -0
  35. isa_model/eval/metrics.py +191 -21
  36. isa_model/inference/ai_factory.py +187 -387
  37. isa_model/inference/providers/modal_provider.py +109 -0
  38. isa_model/inference/providers/yyds_provider.py +108 -0
  39. isa_model/inference/services/__init__.py +2 -1
  40. isa_model/inference/services/audio/base_stt_service.py +65 -1
  41. isa_model/inference/services/audio/base_tts_service.py +75 -1
  42. isa_model/inference/services/audio/openai_stt_service.py +189 -151
  43. isa_model/inference/services/audio/openai_tts_service.py +12 -10
  44. isa_model/inference/services/audio/replicate_tts_service.py +61 -56
  45. isa_model/inference/services/base_service.py +55 -55
  46. isa_model/inference/services/embedding/base_embed_service.py +65 -1
  47. isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
  48. isa_model/inference/services/embedding/openai_embed_service.py +8 -10
  49. isa_model/inference/services/helpers/stacked_config.py +148 -0
  50. isa_model/inference/services/img/__init__.py +18 -0
  51. isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -35
  52. isa_model/inference/services/img/flux_professional_service.py +603 -0
  53. isa_model/inference/services/img/helpers/base_stacked_service.py +274 -0
  54. isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +210 -69
  55. isa_model/inference/services/llm/__init__.py +3 -3
  56. isa_model/inference/services/llm/base_llm_service.py +519 -35
  57. isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +40 -0
  58. isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
  59. isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
  60. isa_model/inference/services/llm/ollama_llm_service.py +150 -15
  61. isa_model/inference/services/llm/openai_llm_service.py +134 -31
  62. isa_model/inference/services/llm/yyds_llm_service.py +255 -0
  63. isa_model/inference/services/vision/__init__.py +38 -4
  64. isa_model/inference/services/vision/base_vision_service.py +241 -96
  65. isa_model/inference/services/vision/disabled/isA_vision_service.py +500 -0
  66. isa_model/inference/services/vision/doc_analysis_service.py +640 -0
  67. isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
  68. isa_model/inference/services/vision/helpers/image_utils.py +272 -3
  69. isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
  70. isa_model/inference/services/vision/openai_vision_service.py +109 -170
  71. isa_model/inference/services/vision/replicate_vision_service.py +508 -0
  72. isa_model/inference/services/vision/ui_analysis_service.py +823 -0
  73. isa_model/scripts/register_models.py +370 -0
  74. isa_model/scripts/register_models_with_embeddings.py +510 -0
  75. isa_model/serving/__init__.py +19 -0
  76. isa_model/serving/api/__init__.py +10 -0
  77. isa_model/serving/api/fastapi_server.py +89 -0
  78. isa_model/serving/api/middleware/__init__.py +9 -0
  79. isa_model/serving/api/middleware/request_logger.py +88 -0
  80. isa_model/serving/api/routes/__init__.py +5 -0
  81. isa_model/serving/api/routes/health.py +82 -0
  82. isa_model/serving/api/routes/llm.py +19 -0
  83. isa_model/serving/api/routes/ui_analysis.py +223 -0
  84. isa_model/serving/api/routes/unified.py +202 -0
  85. isa_model/serving/api/routes/vision.py +19 -0
  86. isa_model/serving/api/schemas/__init__.py +17 -0
  87. isa_model/serving/api/schemas/common.py +33 -0
  88. isa_model/serving/api/schemas/ui_analysis.py +78 -0
  89. {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
  90. isa_model-0.3.6.dist-info/RECORD +147 -0
  91. isa_model/core/model_manager.py +0 -208
  92. isa_model/core/model_registry.py +0 -342
  93. isa_model/inference/billing_tracker.py +0 -406
  94. isa_model/inference/services/llm/triton_llm_service.py +0 -481
  95. isa_model/inference/services/vision/ollama_vision_service.py +0 -194
  96. isa_model-0.3.4.dist-info/RECORD +0 -91
  97. /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
  98. /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
  99. {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
  100. {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
--- a/isa_model/inference/services/base_service.py
+++ b/isa_model/inference/services/base_service.py
@@ -1,19 +1,50 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import track_usage, ServiceType, Provider
+from ...core.models.model_manager import ModelManager
+from ...core.config.config_manager import ConfigManager
+from ...core.types import Provider, ServiceType
 
 T = TypeVar('T')  # Generic type for responses
 
 class BaseService(ABC):
-    """Base class for all AI services"""
+    """Base class for all AI services - now uses centralized managers"""
 
-    def __init__(self, provider: 'BaseProvider', model_name: str):
-        self.provider = provider
+    def __init__(self,
+                 provider_name: str,
+                 model_name: str,
+                 model_manager: Optional[ModelManager] = None,
+                 config_manager: Optional[ConfigManager] = None):
+        self.provider_name = provider_name
         self.model_name = model_name
-        self.config = provider.get_full_config()
+        self.model_manager = model_manager or ModelManager()
+        self.config_manager = config_manager or ConfigManager()
 
-    def _track_usage(
+        # Validate provider is configured
+        if not self.config_manager.is_provider_enabled(provider_name):
+            raise ValueError(f"Provider {provider_name} is not configured or enabled")
+
+    def get_api_key(self) -> str:
+        """Get API key for the provider"""
+        api_key = self.config_manager.get_provider_api_key(self.provider_name)
+        if not api_key:
+            raise ValueError(f"No API key configured for provider {self.provider_name}")
+        return api_key
+
+    def get_provider_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        config = self.config_manager.get_provider_config(self.provider_name)
+        if not config:
+            return {}
+
+        return {
+            "api_key": config.api_key,
+            "api_base_url": config.api_base_url,
+            "organization": config.organization,
+            "rate_limit_rpm": config.rate_limit_rpm,
+            "rate_limit_tpm": config.rate_limit_tpm,
+        }
+
+    async def _track_usage(
         self,
         service_type: Union[str, ServiceType],
         operation: str,
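This first hunk replaces the injected BaseProvider object with a provider name plus shared ModelManager/ConfigManager instances. A minimal sketch of subclass wiring under the new signature (the subclass name and defaults are hypothetical; the helper methods are the ones added above):

    # Hypothetical subclass under the new BaseService signature.
    # Assumes "openai" is enabled in ConfigManager; otherwise __init__
    # raises ValueError("Provider openai is not configured or enabled").
    from isa_model.inference.services.base_service import BaseService

    class MyEmbedService(BaseService):
        def __init__(self, provider_name: str = "openai",
                     model_name: str = "text-embedding-3-small", **kwargs):
            super().__init__(provider_name, model_name, **kwargs)
            # Credentials now come from the centralized config manager,
            # not from a provider object passed in by the caller.
            self.api_key = self.get_api_key()
            self.base_url = self.get_provider_config().get("api_base_url")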
@@ -23,23 +54,30 @@ class BaseService(ABC):
         output_units: Optional[float] = None,
         metadata: Optional[Dict[str, Any]] = None
     ):
-        """Track usage for billing purposes"""
+        """Track usage for billing purposes using centralized billing tracker"""
         try:
-            # Determine provider name - try multiple attributes
-            provider_name = getattr(self.provider, 'name', None) or \
-                            getattr(self.provider, 'provider_name', None) or \
-                            getattr(self.provider, '__class__', type(None)).__name__.lower().replace('provider', '') or \
-                            'unknown'
+            # Calculate cost using centralized pricing
+            cost_usd = None
+            if input_tokens is not None and output_tokens is not None:
+                cost_usd = self.model_manager.calculate_cost(
+                    provider=self.provider_name,
+                    model_name=self.model_name,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens
+                )
 
-            track_usage(
-                provider=provider_name,
-                service_type=service_type,
-                model_name=self.model_name,
+            # Track usage through model manager
+            self.model_manager.billing_tracker.track_model_usage(
+                model_id=self.model_name,
+                operation_type="inference",
+                provider=self.provider_name,
+                service_type=service_type if isinstance(service_type, str) else service_type.value,
                 operation=operation,
                 input_tokens=input_tokens,
                 output_tokens=output_tokens,
                 input_units=input_units,
                 output_units=output_units,
+                cost_usd=cost_usd,
                 metadata=metadata
            )
        except Exception as e:
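Note that _track_usage is now a coroutine, so every call site must await it; the OpenAI and Ollama hunks below make exactly that change. A sketch of the new call pattern inside a service method (token counts are illustrative):

    # service_type accepts a plain string or a ServiceType enum value.
    await self._track_usage(
        service_type="embedding",
        operation="create_text_embedding",
        input_tokens=152,
        output_tokens=0,
        metadata={"model": self.model_name},
    )
    # When both token counts are provided, cost is computed internally via
    # self.model_manager.calculate_cost(...) before the usage record is stored.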
@@ -52,44 +90,6 @@
             yield
         return self
 
-class BaseLLMService(BaseService):
-    """Base class for LLM services"""
-
-    @abstractmethod
-    async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
-        """Universal invocation method"""
-        pass
-
-    @abstractmethod
-    async def achat(self, messages: List[Dict[str, str]]) -> T:
-        """Chat completion method"""
-        pass
-
-    @abstractmethod
-    async def acompletion(self, prompt: str) -> T:
-        """Text completion method"""
-        pass
-
-    @abstractmethod
-    async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
-        """Generate multiple completions"""
-        pass
-
-    @abstractmethod
-    async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
-        """Stream chat responses"""
-        pass
-
-    @abstractmethod
-    def get_token_usage(self) -> Any:
-        """Get total token usage statistics"""
-        pass
-
-    @abstractmethod
-    def get_last_token_usage(self) -> Dict[str, int]:
-        """Get token usage from last request"""
-        pass
-
 class BaseEmbeddingService(BaseService):
     """Base class for embedding services"""
 

--- a/isa_model/inference/services/embedding/base_embed_service.py
+++ b/isa_model/inference/services/embedding/base_embed_service.py
@@ -3,7 +3,71 @@ from typing import Dict, Any, List, Union, Optional
 from isa_model.inference.services.base_service import BaseService
 
 class BaseEmbedService(BaseService):
-    """Base class for embedding services"""
+    """Base class for embedding services with unified task dispatch"""
+
+    async def invoke(
+        self,
+        input_data: Union[str, List[str]],
+        task: Optional[str] = None,
+        **kwargs
+    ) -> Union[List[float], List[List[float]], List[Dict[str, Any]], Dict[str, Any]]:
+        """
+        Unified task dispatch method - the base class provides the generic implementation
+
+        Args:
+            input_data: Input data, either:
+                - str: a single text
+                - List[str]: multiple texts (batch processing)
+            task: Task type; several embedding tasks are supported
+            **kwargs: Additional task-specific parameters
+
+        Returns:
+            Various types depending on task
+        """
+        task = task or "embed"
+
+        # ==================== Embedding generation tasks ====================
+        if task == "embed":
+            if isinstance(input_data, list):
+                return await self.create_text_embeddings(input_data)
+            else:
+                return await self.create_text_embedding(input_data)
+        elif task == "embed_batch":
+            if not isinstance(input_data, list):
+                input_data = [input_data]
+            return await self.create_text_embeddings(input_data)
+        elif task == "chunk_and_embed":
+            if isinstance(input_data, list):
+                raise ValueError("chunk_and_embed task requires single text input")
+            return await self.create_chunks(input_data, kwargs.get("metadata"))
+        elif task == "similarity":
+            embedding1 = kwargs.get("embedding1")
+            embedding2 = kwargs.get("embedding2")
+            if not embedding1 or not embedding2:
+                raise ValueError("similarity task requires embedding1 and embedding2 parameters")
+            similarity = await self.compute_similarity(embedding1, embedding2)
+            return {"similarity": similarity}
+        elif task == "find_similar":
+            query_embedding = kwargs.get("query_embedding")
+            candidate_embeddings = kwargs.get("candidate_embeddings")
+            if not query_embedding or not candidate_embeddings:
+                raise ValueError("find_similar task requires query_embedding and candidate_embeddings parameters")
+            return await self.find_similar_texts(
+                query_embedding,
+                candidate_embeddings,
+                kwargs.get("top_k", 5)
+            )
+        else:
+            raise NotImplementedError(f"{self.__class__.__name__} does not support task: {task}")
+
+    def get_supported_tasks(self) -> List[str]:
+        """
+        Get the list of supported tasks
+
+        Returns:
+            List of supported task names
+        """
+        return ["embed", "embed_batch", "chunk_and_embed", "similarity", "find_similar"]
 
     @abstractmethod
     async def create_text_embedding(self, text: str) -> List[float]:
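The new invoke method gives every embedding service a single task-keyed entry point. A brief usage sketch against any BaseEmbedService subclass (the `service` instance and the example vectors are placeholders):

    vec = await service.invoke("hello world")                    # task defaults to "embed"
    vecs = await service.invoke(["a", "b"], task="embed_batch")  # always a list of vectors
    result = await service.invoke(
        "",  # input_data is unused by this task
        task="similarity",
        embedding1=[0.1, 0.2],
        embedding2=[0.2, 0.1],
    )  # -> {"similarity": <float>}
    service.get_supported_tasks()
    # ['embed', 'embed_batch', 'chunk_and_embed', 'similarity', 'find_similar']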

--- a/isa_model/inference/services/embedding/ollama_embed_service.py
+++ b/isa_model/inference/services/embedding/ollama_embed_service.py
@@ -3,44 +3,65 @@ import httpx
 import asyncio
 from typing import List, Dict, Any, Optional
 
-# Keep the imports and framework structure you specified
 from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
-from isa_model.inference.providers.base_provider import BaseProvider
 
 logger = logging.getLogger(__name__)
 
 class OllamaEmbedService(BaseEmbedService):
     """
-    Ollama embedding service.
-    This class follows the base service architecture but talks to the Ollama API
-    through its own HTTP client instead of relying on an injected backend object.
+    Ollama embedding service with unified architecture.
+    Uses direct HTTP client communication with Ollama API.
     """
 
-    def __init__(self, provider: 'BaseProvider', model_name: str = "bge-m3"):
-        # Stay compatible with the base class and the provider
-        super().__init__(provider, model_name)
+    def __init__(self, provider_name: str, model_name: str = "bge-m3", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Read configuration from self.config inherited from the base class
-        host = self.config.get("host", "localhost")
-        port = self.config.get("port", 11434)
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
-        # Create and hold a dedicated httpx client instance
-        base_url = f"http://{host}:{port}"
-        self.client = httpx.AsyncClient(base_url=base_url, timeout=30.0)
+        # Initialize HTTP client with provider configuration
+        try:
+            host = provider_config.get("host", "localhost")
+            port = provider_config.get("port", 11434)
+            base_url = f"http://{host}:{port}"
+
+            self.client = httpx.AsyncClient(base_url=base_url, timeout=30.0)
 
-        logger.info(f"Initialized OllamaEmbedService with model '{self.model_name}' at {base_url}")
+            logger.info(f"Initialized OllamaEmbedService with model '{self.model_name}' at {base_url}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Ollama client: {e}")
+            raise ValueError(f"Failed to initialize Ollama client: {e}") from e
 
     async def create_text_embedding(self, text: str) -> List[float]:
-        """Create an embedding for a single text"""
+        """Create embedding for single text"""
         try:
             payload = {
                 "model": self.model_name,
                 "prompt": text
             }
-            # Use our own client instance instead of self.backend
+
             response = await self.client.post("/api/embeddings", json=payload)
-            response.raise_for_status()  # check whether the request succeeded
-            return response.json()["embedding"]
+            response.raise_for_status()
+
+            result = response.json()
+            embedding = result["embedding"]
+
+            # Track usage for billing (estimate token usage for Ollama)
+            estimated_tokens = len(text.split()) * 1.3  # Rough estimation
+            await self._track_usage(
+                service_type="embedding",
+                operation="create_text_embedding",
+                input_tokens=int(estimated_tokens),
+                output_tokens=0,
+                metadata={
+                    "model": self.model_name,
+                    "text_length": len(text),
+                    "estimated_tokens": int(estimated_tokens)
+                }
+            )
+
+            return embedding
 
         except httpx.RequestError as e:
             logger.error(f"An error occurred while requesting {e.request.url!r}: {e}")
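Since Ollama's embeddings endpoint does not report token usage, the service now estimates it at roughly 1.3 tokens per whitespace-delimited word before tracking. A worked example of the arithmetic:

    text = "The quick brown fox jumps over the lazy dog"
    estimated_tokens = len(text.split()) * 1.3   # 9 words -> 11.7
    int(estimated_tokens)                        # billed as 11 input tokens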
@@ -50,41 +71,70 @@ class OllamaEmbedService(BaseEmbedService):
             raise
 
     async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """Concurrently create embeddings for multiple texts"""
+        """Create embeddings for multiple texts concurrently"""
         if not texts:
             return []
 
         tasks = [self.create_text_embedding(text) for text in texts]
         embeddings = await asyncio.gather(*tasks)
+
+        # Track batch usage for billing
+        total_estimated_tokens = sum(len(text.split()) * 1.3 for text in texts)
+        await self._track_usage(
+            service_type="embedding",
+            operation="create_text_embeddings",
+            input_tokens=int(total_estimated_tokens),
+            output_tokens=0,
+            metadata={
+                "model": self.model_name,
+                "batch_size": len(texts),
+                "total_text_length": sum(len(t) for t in texts),
+                "estimated_tokens": int(total_estimated_tokens)
+            }
+        )
+
         return embeddings
 
     async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
-        """Split the text into chunks and create an embedding for each chunk"""
-        chunk_size = 200  # number of words
-        words = text.split()
-        chunk_texts = [" ".join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
+        """Create text chunks with embeddings"""
+        chunk_size = 200  # words
+        overlap = 50  # word overlap between chunks
 
-        if not chunk_texts:
+        words = text.split()
+        if not words:
             return []
-
-        embeddings = await self.create_text_embeddings(chunk_texts)
 
-        chunks = [
-            {
+        chunks = []
+        chunk_texts = []
+
+        for i in range(0, len(words), chunk_size - overlap):
+            chunk_words = words[i:i + chunk_size]
+            chunk_text = " ".join(chunk_words)
+            chunk_texts.append(chunk_text)
+
+            chunks.append({
                 "text": chunk_text,
-                "embedding": emb,
+                "start_index": i,
+                "end_index": min(i + chunk_size, len(words)),
                 "metadata": metadata or {}
-            }
-            for chunk_text, emb in zip(chunk_texts, embeddings)
-        ]
-
+            })
+
+        # Get embeddings for all chunks
+        embeddings = await self.create_text_embeddings(chunk_texts)
+
+        # Add embeddings to chunks
+        for chunk, embedding in zip(chunks, embeddings):
+            chunk["embedding"] = embedding
+
         return chunks
 
     async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
-        """Compute the cosine similarity between two embedding vectors"""
+        """Compute cosine similarity between two embeddings"""
+        import math
+
         dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
-        norm1 = sum(a * a for a in embedding1) ** 0.5
-        norm2 = sum(b * b for b in embedding2) ** 0.5
+        norm1 = math.sqrt(sum(a * a for a in embedding1))
+        norm2 = math.sqrt(sum(b * b for b in embedding2))
 
         if norm1 * norm2 == 0:
             return 0.0
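create_chunks now slides a 200-word window in steps of chunk_size - overlap = 150 words. A quick check of the boundary arithmetic for a hypothetical 500-word input:

    chunk_size, overlap, n_words = 200, 50, 500
    for i in range(0, n_words, chunk_size - overlap):
        print(i, min(i + chunk_size, n_words))
    # 0 200
    # 150 350
    # 300 500
    # 450 500  <- the trailing chunk is fully contained in the previous one

One consequence of this stepping: when the word count lands just past a step boundary, the final chunk re-embeds text already covered by the previous chunk.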
@@ -99,9 +149,13 @@ class OllamaEmbedService(BaseEmbedService):
     ) -> List[Dict[str, Any]]:
         """Find most similar texts based on embeddings"""
         similarities = []
+
         for i, candidate in enumerate(candidate_embeddings):
             similarity = await self.compute_similarity(query_embedding, candidate)
-            similarities.append({"index": i, "similarity": similarity})
+            similarities.append({
+                "index": i,
+                "similarity": similarity
+            })
 
         # Sort by similarity in descending order and return top_k
         similarities.sort(key=lambda x: x["similarity"], reverse=True)
@@ -109,15 +163,21 @@ class OllamaEmbedService(BaseEmbedService):
 
     def get_embedding_dimension(self) -> int:
         """Get the dimension of embeddings produced by this service"""
-        # BGE-M3 produces 1024-dimensional embeddings
-        return 1024
+        # Model-specific dimensions
+        model_dimensions = {
+            "bge-m3": 1024,
+            "bge-large": 1024,
+            "all-minilm": 384,
+            "nomic-embed-text": 768
+        }
+        return model_dimensions.get(self.model_name, 1024)
 
     def get_max_input_length(self) -> int:
         """Get maximum input text length supported"""
-        # BGE-M3 supports up to 8192 tokens
+        # Most Ollama embedding models support up to 8192 tokens
         return 8192
 
     async def close(self):
-        """Close the built-in HTTP client"""
+        """Cleanup resources"""
         await self.client.aclose()
-        logger.info("OllamaEmbedService's internal client has been closed.")
+        logger.info("OllamaEmbedService client has been closed.")

--- a/isa_model/inference/services/embedding/openai_embed_service.py
+++ b/isa_model/inference/services/embedding/openai_embed_service.py
@@ -5,8 +5,6 @@ from openai import AsyncOpenAI
 from tenacity import retry, stop_after_attempt, wait_exponential
 
 from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import ServiceType
 
 logger = logging.getLogger(__name__)
 
@@ -16,11 +14,11 @@ class OpenAIEmbedService(BaseEmbedService):
     Provides high-quality embeddings for production use.
     """
 
-    def __init__(self, provider: 'BaseProvider', model_name: str = "text-embedding-3-small"):
-        super().__init__(provider, model_name)
+    def __init__(self, provider_name: str, model_name: str = "text-embedding-3-small", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Get full configuration from provider (including sensitive data)
-        provider_config = provider.get_full_config()
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Initialize AsyncOpenAI client with provider configuration
         try:
@@ -67,8 +65,8 @@
             usage = getattr(response, 'usage', None)
             if usage:
                 total_tokens = getattr(usage, 'total_tokens', 0)
-                self._track_usage(
-                    service_type=ServiceType.EMBEDDING,
+                await self._track_usage(
+                    service_type="embedding",
                     operation="create_text_embedding",
                     input_tokens=total_tokens,
                     output_tokens=0,
@@ -112,8 +110,8 @@
             usage = getattr(response, 'usage', None)
             if usage:
                 total_tokens = getattr(usage, 'total_tokens', 0)
-                self._track_usage(
-                    service_type=ServiceType.EMBEDDING,
+                await self._track_usage(
+                    service_type="embedding",
                     operation="create_text_embeddings",
                     input_tokens=total_tokens,
                     output_tokens=0,

--- /dev/null
+++ b/isa_model/inference/services/helpers/stacked_config.py
@@ -0,0 +1,148 @@
+"""
+Configuration system for stacked services
+"""
+
+from typing import Dict, Any, List, Optional
+from dataclasses import dataclass, field
+from enum import Enum
+
+# Define stacked service specific layer types
+class StackedLayerType(Enum):
+    """Types of processing layers for stacked services"""
+    INTELLIGENCE = "intelligence"        # High-level understanding
+    DETECTION = "detection"              # Element/object detection
+    CLASSIFICATION = "classification"    # Detailed classification
+    VALIDATION = "validation"            # Result validation
+    TRANSFORMATION = "transformation"    # Data transformation
+    GENERATION = "generation"            # Content generation
+    ENHANCEMENT = "enhancement"          # Quality enhancement
+    CONTROL = "control"                  # Precise control/refinement
+    UPSCALING = "upscaling"              # Resolution enhancement
+
+@dataclass
+class LayerConfig:
+    """Configuration for a processing layer"""
+    name: str
+    layer_type: StackedLayerType
+    service_type: str  # e.g., 'vision', 'llm'
+    model_name: str
+    parameters: Dict[str, Any]
+    depends_on: List[str]  # Layer dependencies
+    timeout: float = 30.0
+    retry_count: int = 1
+    fallback_enabled: bool = True
+
+@dataclass
+class LayerResult:
+    """Result from a processing layer"""
+    layer_name: str
+    success: bool
+    data: Any
+    metadata: Dict[str, Any]
+    execution_time: float
+    error: Optional[str] = None
+
+class WorkflowType(Enum):
+    """Predefined workflow types"""
+    UI_ANALYSIS_FAST = "ui_analysis_fast"
+    UI_ANALYSIS_ACCURATE = "ui_analysis_accurate"
+    UI_ANALYSIS_COMPREHENSIVE = "ui_analysis_comprehensive"
+    SEARCH_PAGE_ANALYSIS = "search_page_analysis"
+    CONTENT_EXTRACTION = "content_extraction"
+    FORM_INTERACTION = "form_interaction"
+    NAVIGATION_ANALYSIS = "navigation_analysis"
+    CUSTOM = "custom"
+
+@dataclass
+class StackedServiceConfig:
+    """Configuration for a stacked service workflow"""
+    name: str
+    workflow_type: WorkflowType
+    layers: List[LayerConfig] = field(default_factory=list)
+    global_timeout: float = 120.0
+    parallel_execution: bool = False
+    fail_fast: bool = False
+    metadata: Dict[str, Any] = field(default_factory=dict)
+
+class ConfigManager:
+    """Manager for stacked service configurations"""
+
+    PREDEFINED_CONFIGS = {
+        WorkflowType.UI_ANALYSIS_FAST: {
+            "name": "Fast UI Analysis",
+            "layers": [
+                LayerConfig(
+                    name="page_intelligence",
+                    layer_type=StackedLayerType.INTELLIGENCE,
+                    service_type="vision",
+                    model_name="gpt-4.1-nano",
+                    parameters={"max_tokens": 300},
+                    depends_on=[],
+                    timeout=10.0,
+                    fallback_enabled=True
+                ),
+                LayerConfig(
+                    name="element_detection",
+                    layer_type=StackedLayerType.DETECTION,
+                    service_type="vision",
+                    model_name="omniparser",
+                    parameters={
+                        "imgsz": 480,
+                        "box_threshold": 0.08,
+                        "iou_threshold": 0.2
+                    },
+                    depends_on=["page_intelligence"],
+                    timeout=15.0,
+                    fallback_enabled=True
+                ),
+                LayerConfig(
+                    name="element_classification",
+                    layer_type=StackedLayerType.CLASSIFICATION,
+                    service_type="vision",
+                    model_name="gpt-4.1-nano",
+                    parameters={"max_tokens": 200},
+                    depends_on=["page_intelligence", "element_detection"],
+                    timeout=20.0,
+                    fallback_enabled=False
+                )
+            ],
+            "global_timeout": 60.0,
+            "parallel_execution": False,
+            "fail_fast": False,
+            "metadata": {
+                "description": "Fast UI analysis optimized for speed",
+                "expected_time": "30-45 seconds",
+                "accuracy": "medium"
+            }
+        }
+    }
+
+    @classmethod
+    def get_config(cls, workflow_type: WorkflowType) -> StackedServiceConfig:
+        """Get predefined configuration for a workflow type"""
+        if workflow_type not in cls.PREDEFINED_CONFIGS:
+            raise ValueError(f"Unknown workflow type: {workflow_type}")
+
+        config_data = cls.PREDEFINED_CONFIGS[workflow_type]
+
+        return StackedServiceConfig(
+            name=config_data["name"],
+            workflow_type=workflow_type,
+            layers=config_data["layers"],
+            global_timeout=config_data["global_timeout"],
+            parallel_execution=config_data["parallel_execution"],
+            fail_fast=config_data["fail_fast"],
+            metadata=config_data["metadata"]
+        )
+
+# Convenience function for quick access
+def get_ui_analysis_config(speed: str = "accurate") -> StackedServiceConfig:
+    """Get UI analysis configuration by speed preference"""
+    speed_mapping = {
+        "fast": WorkflowType.UI_ANALYSIS_FAST,
+        "accurate": WorkflowType.UI_ANALYSIS_ACCURATE,
+        "comprehensive": WorkflowType.UI_ANALYSIS_COMPREHENSIVE
+    }
+
+    workflow_type = speed_mapping.get(speed.lower(), WorkflowType.UI_ANALYSIS_ACCURATE)
+    return ConfigManager.get_config(workflow_type)
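A short sketch of how the new module is consumed (note that only UI_ANALYSIS_FAST has a predefined entry in this hunk, so requesting other workflow types raises ValueError until their configs are added):

    from isa_model.inference.services.helpers.stacked_config import (
        ConfigManager, WorkflowType, get_ui_analysis_config,
    )

    config = ConfigManager.get_config(WorkflowType.UI_ANALYSIS_FAST)
    for layer in config.layers:
        print(layer.name, layer.model_name, layer.depends_on)
    # page_intelligence gpt-4.1-nano []
    # element_detection omniparser ['page_intelligence']
    # element_classification gpt-4.1-nano ['page_intelligence', 'element_detection']

    fast = get_ui_analysis_config("fast")  # maps the string onto the same workflow type

This ConfigManager is a different class from isa_model.core.config.config_manager.ConfigManager used by BaseService; import it under an alias if both are needed in one module.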

--- /dev/null
+++ b/isa_model/inference/services/img/__init__.py
@@ -0,0 +1,18 @@
+"""
+Image Generation Services
+
+This module contains services for image generation, separate from vision understanding.
+Including stacked services for complex image generation pipelines.
+"""
+
+from .base_image_gen_service import BaseImageGenService
+from .replicate_image_gen_service import ReplicateImageGenService
+
+# Stacked Image Generation Services
+from .flux_professional_service import FluxProfessionalService
+
+__all__ = [
+    'BaseImageGenService',
+    'ReplicateImageGenService',
+    'FluxProfessionalService'
+]
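With generation split out of the vision package, image-generation imports now come from the new img package:

    # Import paths follow the img package introduced in this diff.
    from isa_model.inference.services.img import (
        BaseImageGenService,
        ReplicateImageGenService,
        FluxProfessionalService,
    )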