isa-model 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +30 -1
- isa_model/client.py +770 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/models/model_repo.py +343 -0
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/__init__.py +9 -0
- isa_model/deployment/cloud/modal/__init__.py +10 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +766 -0
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +406 -0
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +187 -387
- isa_model/inference/providers/modal_provider.py +109 -0
- isa_model/inference/providers/yyds_provider.py +108 -0
- isa_model/inference/services/__init__.py +2 -1
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -55
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -35
- isa_model/inference/services/img/flux_professional_service.py +603 -0
- isa_model/inference/services/img/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +210 -69
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +519 -35
- isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +40 -0
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +150 -15
- isa_model/inference/services/llm/openai_llm_service.py +134 -31
- isa_model/inference/services/llm/yyds_llm_service.py +255 -0
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +241 -96
- isa_model/inference/services/vision/disabled/isA_vision_service.py +500 -0
- isa_model/inference/services/vision/doc_analysis_service.py +640 -0
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +109 -170
- isa_model/inference/services/vision/replicate_vision_service.py +508 -0
- isa_model/inference/services/vision/ui_analysis_service.py +823 -0
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/__init__.py +19 -0
- isa_model/serving/api/__init__.py +10 -0
- isa_model/serving/api/fastapi_server.py +89 -0
- isa_model/serving/api/middleware/__init__.py +9 -0
- isa_model/serving/api/middleware/request_logger.py +88 -0
- isa_model/serving/api/routes/__init__.py +5 -0
- isa_model/serving/api/routes/health.py +82 -0
- isa_model/serving/api/routes/llm.py +19 -0
- isa_model/serving/api/routes/ui_analysis.py +223 -0
- isa_model/serving/api/routes/unified.py +202 -0
- isa_model/serving/api/routes/vision.py +19 -0
- isa_model/serving/api/schemas/__init__.py +17 -0
- isa_model/serving/api/schemas/common.py +33 -0
- isa_model/serving/api/schemas/ui_analysis.py +78 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
- isa_model-0.3.6.dist-info/RECORD +147 -0
- isa_model/core/model_manager.py +0 -208
- isa_model/core/model_registry.py +0 -342
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- isa_model-0.3.4.dist-info/RECORD +0 -91
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
- {isa_model-0.3.4.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,50 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
2
|
from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
|
3
|
-
from
|
4
|
-
from
|
3
|
+
from ...core.models.model_manager import ModelManager
|
4
|
+
from ...core.config.config_manager import ConfigManager
|
5
|
+
from ...core.types import Provider, ServiceType
|
5
6
|
|
6
7
|
T = TypeVar('T') # Generic type for responses
|
7
8
|
|
8
9
|
class BaseService(ABC):
|
9
|
-
"""Base class for all AI services"""
|
10
|
+
"""Base class for all AI services - now uses centralized managers"""
|
10
11
|
|
11
|
-
def __init__(self,
|
12
|
-
|
12
|
+
def __init__(self,
|
13
|
+
provider_name: str,
|
14
|
+
model_name: str,
|
15
|
+
model_manager: Optional[ModelManager] = None,
|
16
|
+
config_manager: Optional[ConfigManager] = None):
|
17
|
+
self.provider_name = provider_name
|
13
18
|
self.model_name = model_name
|
14
|
-
self.
|
19
|
+
self.model_manager = model_manager or ModelManager()
|
20
|
+
self.config_manager = config_manager or ConfigManager()
|
15
21
|
|
16
|
-
|
22
|
+
# Validate provider is configured
|
23
|
+
if not self.config_manager.is_provider_enabled(provider_name):
|
24
|
+
raise ValueError(f"Provider {provider_name} is not configured or enabled")
|
25
|
+
|
26
|
+
def get_api_key(self) -> str:
    """Return the API key configured for this service's provider.

    The key is looked up through ``self.config_manager`` using
    ``self.provider_name``.

    Returns:
        The API key reported by the config manager.

    Raises:
        ValueError: If the config manager returns a falsy value
            (``None`` or an empty string) for this provider.
    """
    api_key = self.config_manager.get_provider_api_key(self.provider_name)
    if not api_key:
        raise ValueError(f"No API key configured for provider {self.provider_name}")
    return api_key
|
32
|
+
|
33
|
+
def get_provider_config(self) -> Dict[str, Any]:
    """Return this provider's configuration as a plain dictionary.

    The provider config object is fetched from ``self.config_manager``
    using ``self.provider_name`` and flattened to a dict.

    Returns:
        A dict with the keys ``api_key``, ``api_base_url``,
        ``organization``, ``rate_limit_rpm`` and ``rate_limit_tpm``,
        or an empty dict when the config manager returns nothing for
        this provider.
    """
    config = self.config_manager.get_provider_config(self.provider_name)
    if not config:
        return {}

    # Only these five attributes are exposed; any other attribute on the
    # config object is not part of the returned dict.
    return {
        "api_key": config.api_key,
        "api_base_url": config.api_base_url,
        "organization": config.organization,
        "rate_limit_rpm": config.rate_limit_rpm,
        "rate_limit_tpm": config.rate_limit_tpm,
    }
|
46
|
+
|
47
|
+
async def _track_usage(
|
17
48
|
self,
|
18
49
|
service_type: Union[str, ServiceType],
|
19
50
|
operation: str,
|
@@ -23,23 +54,30 @@ class BaseService(ABC):
|
|
23
54
|
output_units: Optional[float] = None,
|
24
55
|
metadata: Optional[Dict[str, Any]] = None
|
25
56
|
):
|
26
|
-
"""Track usage for billing purposes"""
|
57
|
+
"""Track usage for billing purposes using centralized billing tracker"""
|
27
58
|
try:
|
28
|
-
#
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
59
|
+
# Calculate cost using centralized pricing
|
60
|
+
cost_usd = None
|
61
|
+
if input_tokens is not None and output_tokens is not None:
|
62
|
+
cost_usd = self.model_manager.calculate_cost(
|
63
|
+
provider=self.provider_name,
|
64
|
+
model_name=self.model_name,
|
65
|
+
input_tokens=input_tokens,
|
66
|
+
output_tokens=output_tokens
|
67
|
+
)
|
33
68
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
69
|
+
# Track usage through model manager
|
70
|
+
self.model_manager.billing_tracker.track_model_usage(
|
71
|
+
model_id=self.model_name,
|
72
|
+
operation_type="inference",
|
73
|
+
provider=self.provider_name,
|
74
|
+
service_type=service_type if isinstance(service_type, str) else service_type.value,
|
38
75
|
operation=operation,
|
39
76
|
input_tokens=input_tokens,
|
40
77
|
output_tokens=output_tokens,
|
41
78
|
input_units=input_units,
|
42
79
|
output_units=output_units,
|
80
|
+
cost_usd=cost_usd,
|
43
81
|
metadata=metadata
|
44
82
|
)
|
45
83
|
except Exception as e:
|
@@ -52,44 +90,6 @@ class BaseService(ABC):
|
|
52
90
|
yield
|
53
91
|
return self
|
54
92
|
|
55
|
-
class BaseLLMService(BaseService):
|
56
|
-
"""Base class for LLM services"""
|
57
|
-
|
58
|
-
@abstractmethod
|
59
|
-
async def ainvoke(self, prompt: Union[str, List[Dict[str, str]], Any]) -> T:
|
60
|
-
"""Universal invocation method"""
|
61
|
-
pass
|
62
|
-
|
63
|
-
@abstractmethod
|
64
|
-
async def achat(self, messages: List[Dict[str, str]]) -> T:
|
65
|
-
"""Chat completion method"""
|
66
|
-
pass
|
67
|
-
|
68
|
-
@abstractmethod
|
69
|
-
async def acompletion(self, prompt: str) -> T:
|
70
|
-
"""Text completion method"""
|
71
|
-
pass
|
72
|
-
|
73
|
-
@abstractmethod
|
74
|
-
async def agenerate(self, messages: List[Dict[str, str]], n: int = 1) -> List[T]:
|
75
|
-
"""Generate multiple completions"""
|
76
|
-
pass
|
77
|
-
|
78
|
-
@abstractmethod
|
79
|
-
async def astream_chat(self, messages: List[Dict[str, str]]) -> AsyncGenerator[str, None]:
|
80
|
-
"""Stream chat responses"""
|
81
|
-
pass
|
82
|
-
|
83
|
-
@abstractmethod
|
84
|
-
def get_token_usage(self) -> Any:
|
85
|
-
"""Get total token usage statistics"""
|
86
|
-
pass
|
87
|
-
|
88
|
-
@abstractmethod
|
89
|
-
def get_last_token_usage(self) -> Dict[str, int]:
|
90
|
-
"""Get token usage from last request"""
|
91
|
-
pass
|
92
|
-
|
93
93
|
class BaseEmbeddingService(BaseService):
|
94
94
|
"""Base class for embedding services"""
|
95
95
|
|
@@ -3,7 +3,71 @@ from typing import Dict, Any, List, Union, Optional
|
|
3
3
|
from isa_model.inference.services.base_service import BaseService
|
4
4
|
|
5
5
|
class BaseEmbedService(BaseService):
|
6
|
-
"""Base class for embedding services"""
|
6
|
+
"""Base class for embedding services with unified task dispatch"""
|
7
|
+
|
8
|
+
async def invoke(
    self,
    input_data: Union[str, List[str]],
    task: Optional[str] = None,
    **kwargs
) -> Union[List[float], List[List[float]], List[Dict[str, Any]], Dict[str, Any]]:
    """Dispatch an embedding task to the matching service method.

    This is the generic implementation provided by the base class.

    Args:
        input_data: Either a single text (``str``) or several texts
            (``List[str]``, processed as a batch).
        task: Task name. ``None`` (or any falsy value) means ``"embed"``.
            Handled tasks: ``"embed"``, ``"embed_batch"``,
            ``"chunk_and_embed"``, ``"similarity"``, ``"find_similar"``.
        **kwargs: Task-specific parameters:
            ``metadata`` (chunk_and_embed);
            ``embedding1`` / ``embedding2`` (similarity);
            ``query_embedding`` / ``candidate_embeddings`` and optional
            ``top_k``, default 5 (find_similar).

    Returns:
        Whatever the underlying method returns for the task. For
        ``"similarity"`` the score is wrapped as ``{"similarity": score}``.

    Raises:
        ValueError: If ``chunk_and_embed`` receives a list, or a required
            embedding argument is missing or empty.
        NotImplementedError: If ``task`` is not one of the handled names.
    """

    def _is_missing(value: Any) -> bool:
        # A required embedding counts as missing when it is None or has
        # zero length. len() is used rather than truthiness because
        # array-like embeddings may refuse bool() with an unrelated
        # "ambiguous truth value" error.
        if value is None:
            return True
        try:
            return len(value) == 0
        except TypeError:
            # Unsized objects keep the original truthiness semantics.
            return not value

    task = task or "embed"

    # ==================== Embedding generation tasks ====================
    if task == "embed":
        if isinstance(input_data, list):
            return await self.create_text_embeddings(input_data)
        return await self.create_text_embedding(input_data)

    if task == "embed_batch":
        # A lone text is promoted to a one-element batch.
        if not isinstance(input_data, list):
            input_data = [input_data]
        return await self.create_text_embeddings(input_data)

    if task == "chunk_and_embed":
        if isinstance(input_data, list):
            raise ValueError("chunk_and_embed task requires single text input")
        return await self.create_chunks(input_data, kwargs.get("metadata"))

    if task == "similarity":
        embedding1 = kwargs.get("embedding1")
        embedding2 = kwargs.get("embedding2")
        if _is_missing(embedding1) or _is_missing(embedding2):
            raise ValueError("similarity task requires embedding1 and embedding2 parameters")
        similarity = await self.compute_similarity(embedding1, embedding2)
        return {"similarity": similarity}

    if task == "find_similar":
        query_embedding = kwargs.get("query_embedding")
        candidate_embeddings = kwargs.get("candidate_embeddings")
        if _is_missing(query_embedding) or _is_missing(candidate_embeddings):
            raise ValueError("find_similar task requires query_embedding and candidate_embeddings parameters")
        return await self.find_similar_texts(
            query_embedding,
            candidate_embeddings,
            kwargs.get("top_k", 5)
        )

    raise NotImplementedError(f"{self.__class__.__name__} does not support task: {task}")
|
62
|
+
|
63
|
+
def get_supported_tasks(self) -> List[str]:
    """Return the list of supported task names.

    Returns:
        List of supported task names, as a new list on every call.
    """
    return ["embed", "embed_batch", "chunk_and_embed", "similarity", "find_similar"]
|
7
71
|
|
8
72
|
@abstractmethod
|
9
73
|
async def create_text_embedding(self, text: str) -> List[float]:
|
@@ -3,44 +3,65 @@ import httpx
|
|
3
3
|
import asyncio
|
4
4
|
from typing import List, Dict, Any, Optional
|
5
5
|
|
6
|
-
# 保留您指定的导入和框架结构
|
7
6
|
from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
|
8
|
-
from isa_model.inference.providers.base_provider import BaseProvider
|
9
7
|
|
10
8
|
logger = logging.getLogger(__name__)
|
11
9
|
|
12
10
|
class OllamaEmbedService(BaseEmbedService):
|
13
11
|
"""
|
14
|
-
Ollama embedding service.
|
15
|
-
|
16
|
-
而不依赖于注入的 backend 对象。
|
12
|
+
Ollama embedding service with unified architecture.
|
13
|
+
Uses direct HTTP client communication with Ollama API.
|
17
14
|
"""
|
18
15
|
|
19
|
-
def __init__(self,
|
20
|
-
|
21
|
-
super().__init__(provider, model_name)
|
16
|
+
def __init__(self, provider_name: str, model_name: str = "bge-m3", **kwargs):
|
17
|
+
super().__init__(provider_name, model_name, **kwargs)
|
22
18
|
|
23
|
-
#
|
24
|
-
|
25
|
-
port = self.config.get("port", 11434)
|
19
|
+
# Get configuration from centralized config manager
|
20
|
+
provider_config = self.get_provider_config()
|
26
21
|
|
27
|
-
#
|
28
|
-
|
29
|
-
|
22
|
+
# Initialize HTTP client with provider configuration
|
23
|
+
try:
|
24
|
+
host = provider_config.get("host", "localhost")
|
25
|
+
port = provider_config.get("port", 11434)
|
26
|
+
base_url = f"http://{host}:{port}"
|
27
|
+
|
28
|
+
self.client = httpx.AsyncClient(base_url=base_url, timeout=30.0)
|
30
29
|
|
31
|
-
|
30
|
+
logger.info(f"Initialized OllamaEmbedService with model '{self.model_name}' at {base_url}")
|
31
|
+
|
32
|
+
except Exception as e:
|
33
|
+
logger.error(f"Failed to initialize Ollama client: {e}")
|
34
|
+
raise ValueError(f"Failed to initialize Ollama client: {e}") from e
|
32
35
|
|
33
36
|
async def create_text_embedding(self, text: str) -> List[float]:
|
34
|
-
"""
|
37
|
+
"""Create embedding for single text"""
|
35
38
|
try:
|
36
39
|
payload = {
|
37
40
|
"model": self.model_name,
|
38
41
|
"prompt": text
|
39
42
|
}
|
40
|
-
|
43
|
+
|
41
44
|
response = await self.client.post("/api/embeddings", json=payload)
|
42
|
-
response.raise_for_status()
|
43
|
-
|
45
|
+
response.raise_for_status()
|
46
|
+
|
47
|
+
result = response.json()
|
48
|
+
embedding = result["embedding"]
|
49
|
+
|
50
|
+
# Track usage for billing (estimate token usage for Ollama)
|
51
|
+
estimated_tokens = len(text.split()) * 1.3 # Rough estimation
|
52
|
+
await self._track_usage(
|
53
|
+
service_type="embedding",
|
54
|
+
operation="create_text_embedding",
|
55
|
+
input_tokens=int(estimated_tokens),
|
56
|
+
output_tokens=0,
|
57
|
+
metadata={
|
58
|
+
"model": self.model_name,
|
59
|
+
"text_length": len(text),
|
60
|
+
"estimated_tokens": int(estimated_tokens)
|
61
|
+
}
|
62
|
+
)
|
63
|
+
|
64
|
+
return embedding
|
44
65
|
|
45
66
|
except httpx.RequestError as e:
|
46
67
|
logger.error(f"An error occurred while requesting {e.request.url!r}: {e}")
|
@@ -50,41 +71,70 @@ class OllamaEmbedService(BaseEmbedService):
|
|
50
71
|
raise
|
51
72
|
|
52
73
|
async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
|
53
|
-
"""
|
74
|
+
"""Create embeddings for multiple texts concurrently"""
|
54
75
|
if not texts:
|
55
76
|
return []
|
56
77
|
|
57
78
|
tasks = [self.create_text_embedding(text) for text in texts]
|
58
79
|
embeddings = await asyncio.gather(*tasks)
|
80
|
+
|
81
|
+
# Track batch usage for billing
|
82
|
+
total_estimated_tokens = sum(len(text.split()) * 1.3 for text in texts)
|
83
|
+
await self._track_usage(
|
84
|
+
service_type="embedding",
|
85
|
+
operation="create_text_embeddings",
|
86
|
+
input_tokens=int(total_estimated_tokens),
|
87
|
+
output_tokens=0,
|
88
|
+
metadata={
|
89
|
+
"model": self.model_name,
|
90
|
+
"batch_size": len(texts),
|
91
|
+
"total_text_length": sum(len(t) for t in texts),
|
92
|
+
"estimated_tokens": int(total_estimated_tokens)
|
93
|
+
}
|
94
|
+
)
|
95
|
+
|
59
96
|
return embeddings
|
60
97
|
|
61
98
|
async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
|
62
|
-
"""
|
63
|
-
chunk_size = 200 #
|
64
|
-
|
65
|
-
chunk_texts = [" ".join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
|
99
|
+
"""Create text chunks with embeddings"""
|
100
|
+
chunk_size = 200 # words
|
101
|
+
overlap = 50 # word overlap between chunks
|
66
102
|
|
67
|
-
|
103
|
+
words = text.split()
|
104
|
+
if not words:
|
68
105
|
return []
|
69
|
-
|
70
|
-
embeddings = await self.create_text_embeddings(chunk_texts)
|
71
106
|
|
72
|
-
chunks = [
|
73
|
-
|
107
|
+
chunks = []
|
108
|
+
chunk_texts = []
|
109
|
+
|
110
|
+
for i in range(0, len(words), chunk_size - overlap):
|
111
|
+
chunk_words = words[i:i + chunk_size]
|
112
|
+
chunk_text = " ".join(chunk_words)
|
113
|
+
chunk_texts.append(chunk_text)
|
114
|
+
|
115
|
+
chunks.append({
|
74
116
|
"text": chunk_text,
|
75
|
-
"
|
117
|
+
"start_index": i,
|
118
|
+
"end_index": min(i + chunk_size, len(words)),
|
76
119
|
"metadata": metadata or {}
|
77
|
-
}
|
78
|
-
|
79
|
-
|
80
|
-
|
120
|
+
})
|
121
|
+
|
122
|
+
# Get embeddings for all chunks
|
123
|
+
embeddings = await self.create_text_embeddings(chunk_texts)
|
124
|
+
|
125
|
+
# Add embeddings to chunks
|
126
|
+
for chunk, embedding in zip(chunks, embeddings):
|
127
|
+
chunk["embedding"] = embedding
|
128
|
+
|
81
129
|
return chunks
|
82
130
|
|
83
131
|
async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
|
84
|
-
"""
|
132
|
+
"""Compute cosine similarity between two embeddings"""
|
133
|
+
import math
|
134
|
+
|
85
135
|
dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
|
86
|
-
norm1 = sum(a * a for a in embedding1)
|
87
|
-
norm2 = sum(b * b for b in embedding2)
|
136
|
+
norm1 = math.sqrt(sum(a * a for a in embedding1))
|
137
|
+
norm2 = math.sqrt(sum(b * b for b in embedding2))
|
88
138
|
|
89
139
|
if norm1 * norm2 == 0:
|
90
140
|
return 0.0
|
@@ -99,9 +149,13 @@ class OllamaEmbedService(BaseEmbedService):
|
|
99
149
|
) -> List[Dict[str, Any]]:
|
100
150
|
"""Find most similar texts based on embeddings"""
|
101
151
|
similarities = []
|
152
|
+
|
102
153
|
for i, candidate in enumerate(candidate_embeddings):
|
103
154
|
similarity = await self.compute_similarity(query_embedding, candidate)
|
104
|
-
similarities.append({
|
155
|
+
similarities.append({
|
156
|
+
"index": i,
|
157
|
+
"similarity": similarity
|
158
|
+
})
|
105
159
|
|
106
160
|
# Sort by similarity in descending order and return top_k
|
107
161
|
similarities.sort(key=lambda x: x["similarity"], reverse=True)
|
@@ -109,15 +163,21 @@ class OllamaEmbedService(BaseEmbedService):
|
|
109
163
|
|
110
164
|
def get_embedding_dimension(self) -> int:
|
111
165
|
"""Get the dimension of embeddings produced by this service"""
|
112
|
-
#
|
113
|
-
|
166
|
+
# Model-specific dimensions
|
167
|
+
model_dimensions = {
|
168
|
+
"bge-m3": 1024,
|
169
|
+
"bge-large": 1024,
|
170
|
+
"all-minilm": 384,
|
171
|
+
"nomic-embed-text": 768
|
172
|
+
}
|
173
|
+
return model_dimensions.get(self.model_name, 1024)
|
114
174
|
|
115
175
|
def get_max_input_length(self) -> int:
|
116
176
|
"""Get maximum input text length supported"""
|
117
|
-
#
|
177
|
+
# Most Ollama embedding models support up to 8192 tokens
|
118
178
|
return 8192
|
119
179
|
|
120
180
|
async def close(self):
|
121
|
-
"""
|
181
|
+
"""Cleanup resources"""
|
122
182
|
await self.client.aclose()
|
123
|
-
logger.info("OllamaEmbedService
|
183
|
+
logger.info("OllamaEmbedService client has been closed.")
|
@@ -5,8 +5,6 @@ from openai import AsyncOpenAI
|
|
5
5
|
from tenacity import retry, stop_after_attempt, wait_exponential
|
6
6
|
|
7
7
|
from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
|
8
|
-
from isa_model.inference.providers.base_provider import BaseProvider
|
9
|
-
from isa_model.inference.billing_tracker import ServiceType
|
10
8
|
|
11
9
|
logger = logging.getLogger(__name__)
|
12
10
|
|
@@ -16,11 +14,11 @@ class OpenAIEmbedService(BaseEmbedService):
|
|
16
14
|
Provides high-quality embeddings for production use.
|
17
15
|
"""
|
18
16
|
|
19
|
-
def __init__(self,
|
20
|
-
super().__init__(
|
17
|
+
def __init__(self, provider_name: str, model_name: str = "text-embedding-3-small", **kwargs):
|
18
|
+
super().__init__(provider_name, model_name, **kwargs)
|
21
19
|
|
22
|
-
# Get
|
23
|
-
provider_config =
|
20
|
+
# Get configuration from centralized config manager
|
21
|
+
provider_config = self.get_provider_config()
|
24
22
|
|
25
23
|
# Initialize AsyncOpenAI client with provider configuration
|
26
24
|
try:
|
@@ -67,8 +65,8 @@ class OpenAIEmbedService(BaseEmbedService):
|
|
67
65
|
usage = getattr(response, 'usage', None)
|
68
66
|
if usage:
|
69
67
|
total_tokens = getattr(usage, 'total_tokens', 0)
|
70
|
-
self._track_usage(
|
71
|
-
service_type=
|
68
|
+
await self._track_usage(
|
69
|
+
service_type="embedding",
|
72
70
|
operation="create_text_embedding",
|
73
71
|
input_tokens=total_tokens,
|
74
72
|
output_tokens=0,
|
@@ -112,8 +110,8 @@ class OpenAIEmbedService(BaseEmbedService):
|
|
112
110
|
usage = getattr(response, 'usage', None)
|
113
111
|
if usage:
|
114
112
|
total_tokens = getattr(usage, 'total_tokens', 0)
|
115
|
-
self._track_usage(
|
116
|
-
service_type=
|
113
|
+
await self._track_usage(
|
114
|
+
service_type="embedding",
|
117
115
|
operation="create_text_embeddings",
|
118
116
|
input_tokens=total_tokens,
|
119
117
|
output_tokens=0,
|
@@ -0,0 +1,148 @@
|
|
1
|
+
"""
|
2
|
+
Configuration system for stacked services
|
3
|
+
"""
|
4
|
+
|
5
|
+
from typing import Dict, Any, List, Optional
|
6
|
+
from dataclasses import dataclass, field
|
7
|
+
from enum import Enum
|
8
|
+
|
9
|
+
# Define stacked service specific layer types
|
10
|
+
class StackedLayerType(Enum):
    """Types of processing layers for stacked services."""

    INTELLIGENCE = "intelligence"      # High-level understanding
    DETECTION = "detection"            # Element/object detection
    CLASSIFICATION = "classification"  # Detailed classification
    VALIDATION = "validation"          # Result validation
    TRANSFORMATION = "transformation"  # Data transformation
    GENERATION = "generation"          # Content generation
    ENHANCEMENT = "enhancement"        # Quality enhancement
    CONTROL = "control"                # Precise control/refinement
    UPSCALING = "upscaling"            # Resolution enhancement


@dataclass
class LayerConfig:
    """Configuration for a single processing layer.

    ``parameters`` and ``depends_on`` have no defaults, so every layer
    must state them explicitly.
    """

    name: str
    layer_type: StackedLayerType
    service_type: str  # e.g., 'vision', 'llm'
    model_name: str
    parameters: Dict[str, Any]
    depends_on: List[str]  # Layer dependencies
    timeout: float = 30.0  # presumably seconds - TODO confirm against the executor
    retry_count: int = 1
    fallback_enabled: bool = True


@dataclass
class LayerResult:
    """Result produced by one processing layer."""

    layer_name: str
    success: bool
    data: Any
    metadata: Dict[str, Any]
    execution_time: float
    error: Optional[str] = None


class WorkflowType(Enum):
    """Predefined workflow types.

    Only ``UI_ANALYSIS_FAST`` has an entry in
    ``ConfigManager.PREDEFINED_CONFIGS`` in this module.
    """

    UI_ANALYSIS_FAST = "ui_analysis_fast"
    UI_ANALYSIS_ACCURATE = "ui_analysis_accurate"
    UI_ANALYSIS_COMPREHENSIVE = "ui_analysis_comprehensive"
    SEARCH_PAGE_ANALYSIS = "search_page_analysis"
    CONTENT_EXTRACTION = "content_extraction"
    FORM_INTERACTION = "form_interaction"
    NAVIGATION_ANALYSIS = "navigation_analysis"
    CUSTOM = "custom"


@dataclass
class StackedServiceConfig:
    """Configuration for a stacked service workflow."""

    name: str
    workflow_type: WorkflowType
    layers: List[LayerConfig] = field(default_factory=list)
    global_timeout: float = 120.0
    parallel_execution: bool = False
    fail_fast: bool = False
    metadata: Dict[str, Any] = field(default_factory=dict)


class ConfigManager:
    """Manager for stacked service configurations."""

    # Templates keyed by workflow type. get_config() hands out copies,
    # so these objects are never exposed to callers directly.
    PREDEFINED_CONFIGS = {
        WorkflowType.UI_ANALYSIS_FAST: {
            "name": "Fast UI Analysis",
            "layers": [
                LayerConfig(
                    name="page_intelligence",
                    layer_type=StackedLayerType.INTELLIGENCE,
                    service_type="vision",
                    model_name="gpt-4.1-nano",
                    parameters={"max_tokens": 300},
                    depends_on=[],
                    timeout=10.0,
                    fallback_enabled=True
                ),
                LayerConfig(
                    name="element_detection",
                    layer_type=StackedLayerType.DETECTION,
                    service_type="vision",
                    model_name="omniparser",
                    parameters={
                        "imgsz": 480,
                        "box_threshold": 0.08,
                        "iou_threshold": 0.2
                    },
                    depends_on=["page_intelligence"],
                    timeout=15.0,
                    fallback_enabled=True
                ),
                LayerConfig(
                    name="element_classification",
                    layer_type=StackedLayerType.CLASSIFICATION,
                    service_type="vision",
                    model_name="gpt-4.1-nano",
                    parameters={"max_tokens": 200},
                    depends_on=["page_intelligence", "element_detection"],
                    timeout=20.0,
                    fallback_enabled=False
                )
            ],
            "global_timeout": 60.0,
            "parallel_execution": False,
            "fail_fast": False,
            "metadata": {
                "description": "Fast UI analysis optimized for speed",
                "expected_time": "30-45 seconds",
                "accuracy": "medium"
            }
        }
    }

    @classmethod
    def get_config(cls, workflow_type: WorkflowType) -> StackedServiceConfig:
        """Get the predefined configuration for a workflow type.

        Args:
            workflow_type: Workflow whose template should be returned.

        Returns:
            A new ``StackedServiceConfig`` whose ``layers`` and
            ``metadata`` are deep copies of the template, so changing the
            returned object does not alter ``PREDEFINED_CONFIGS``.

        Raises:
            ValueError: If no template exists for ``workflow_type``.
        """
        import copy  # local import: keeps this block self-contained

        if workflow_type not in cls.PREDEFINED_CONFIGS:
            raise ValueError(f"Unknown workflow type: {workflow_type}")

        config_data = cls.PREDEFINED_CONFIGS[workflow_type]

        return StackedServiceConfig(
            name=config_data["name"],
            workflow_type=workflow_type,
            # Deep copies: the template's list, LayerConfig objects and
            # nested parameter dicts are shared class-level state.
            layers=copy.deepcopy(config_data["layers"]),
            global_timeout=config_data["global_timeout"],
            parallel_execution=config_data["parallel_execution"],
            fail_fast=config_data["fail_fast"],
            metadata=copy.deepcopy(config_data["metadata"])
        )


# Convenience function for quick access
def get_ui_analysis_config(speed: str = "accurate") -> StackedServiceConfig:
    """Get a UI analysis configuration by speed preference.

    Args:
        speed: ``"fast"``, ``"accurate"`` or ``"comprehensive"``
            (case-insensitive). Unrecognised values are treated as
            ``"accurate"``.

    Returns:
        The configuration for the preferred workflow. When that workflow
        has no predefined template, the ``UI_ANALYSIS_FAST`` configuration
        is returned instead and a warning is logged; the returned
        object's ``workflow_type`` shows which workflow was actually used.

    Raises:
        ValueError: If neither the preferred workflow nor the fast
            fallback has a predefined template.
    """
    import logging  # local import: keeps this block self-contained

    speed_mapping = {
        "fast": WorkflowType.UI_ANALYSIS_FAST,
        "accurate": WorkflowType.UI_ANALYSIS_ACCURATE,
        "comprehensive": WorkflowType.UI_ANALYSIS_COMPREHENSIVE
    }

    workflow_type = speed_mapping.get(speed.lower(), WorkflowType.UI_ANALYSIS_ACCURATE)

    # The default preference ("accurate") has no template here, so without
    # a fallback the default call would always raise ValueError.
    if workflow_type not in ConfigManager.PREDEFINED_CONFIGS:
        fallback = WorkflowType.UI_ANALYSIS_FAST
        logging.getLogger(__name__).warning(
            "No predefined config for workflow '%s'; falling back to '%s'",
            workflow_type.value,
            fallback.value,
        )
        workflow_type = fallback

    return ConfigManager.get_config(workflow_type)
|
@@ -0,0 +1,18 @@
|
|
1
|
+
"""
|
2
|
+
Image Generation Services
|
3
|
+
|
4
|
+
This module contains services for image generation, separate from vision understanding.
|
5
|
+
Including stacked services for complex image generation pipelines.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from .base_image_gen_service import BaseImageGenService
|
9
|
+
from .replicate_image_gen_service import ReplicateImageGenService
|
10
|
+
|
11
|
+
# Stacked Image Generation Services
|
12
|
+
from .flux_professional_service import FluxProfessionalService
|
13
|
+
|
14
|
+
__all__ = [
|
15
|
+
'BaseImageGenService',
|
16
|
+
'ReplicateImageGenService',
|
17
|
+
'FluxProfessionalService'
|
18
|
+
]
|