isa-model 0.3.5__py3-none-any.whl → 0.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- isa_model/__init__.py +30 -1
- isa_model/client.py +937 -0
- isa_model/core/config/__init__.py +16 -0
- isa_model/core/config/config_manager.py +514 -0
- isa_model/core/config.py +426 -0
- isa_model/core/models/model_billing_tracker.py +476 -0
- isa_model/core/models/model_manager.py +399 -0
- isa_model/core/{storage/supabase_storage.py → models/model_repo.py} +72 -73
- isa_model/core/pricing_manager.py +426 -0
- isa_model/core/services/__init__.py +19 -0
- isa_model/core/services/intelligent_model_selector.py +547 -0
- isa_model/core/types.py +291 -0
- isa_model/deployment/__init__.py +2 -0
- isa_model/deployment/cloud/modal/isa_vision_doc_service.py +157 -3
- isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
- isa_model/deployment/cloud/modal/isa_vision_ui_service.py +104 -3
- isa_model/deployment/cloud/modal/register_models.py +321 -0
- isa_model/deployment/runtime/deployed_service.py +338 -0
- isa_model/deployment/services/__init__.py +9 -0
- isa_model/deployment/services/auto_deploy_vision_service.py +538 -0
- isa_model/deployment/services/model_service.py +332 -0
- isa_model/deployment/services/service_monitor.py +356 -0
- isa_model/deployment/services/service_registry.py +527 -0
- isa_model/deployment/services/simple_auto_deploy_vision_service.py +275 -0
- isa_model/eval/__init__.py +80 -44
- isa_model/eval/config/__init__.py +10 -0
- isa_model/eval/config/evaluation_config.py +108 -0
- isa_model/eval/evaluators/__init__.py +18 -0
- isa_model/eval/evaluators/base_evaluator.py +503 -0
- isa_model/eval/evaluators/llm_evaluator.py +472 -0
- isa_model/eval/factory.py +417 -709
- isa_model/eval/infrastructure/__init__.py +24 -0
- isa_model/eval/infrastructure/experiment_tracker.py +466 -0
- isa_model/eval/metrics.py +191 -21
- isa_model/inference/ai_factory.py +257 -601
- isa_model/inference/services/audio/base_stt_service.py +65 -1
- isa_model/inference/services/audio/base_tts_service.py +75 -1
- isa_model/inference/services/audio/openai_stt_service.py +189 -151
- isa_model/inference/services/audio/openai_tts_service.py +12 -10
- isa_model/inference/services/audio/replicate_tts_service.py +61 -56
- isa_model/inference/services/base_service.py +55 -17
- isa_model/inference/services/embedding/base_embed_service.py +65 -1
- isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
- isa_model/inference/services/embedding/openai_embed_service.py +8 -10
- isa_model/inference/services/helpers/stacked_config.py +148 -0
- isa_model/inference/services/img/__init__.py +18 -0
- isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -1
- isa_model/inference/services/{stacked → img}/flux_professional_service.py +25 -1
- isa_model/inference/services/{stacked → img/helpers}/base_stacked_service.py +40 -35
- isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +44 -31
- isa_model/inference/services/llm/__init__.py +3 -3
- isa_model/inference/services/llm/base_llm_service.py +492 -40
- isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
- isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
- isa_model/inference/services/llm/ollama_llm_service.py +51 -17
- isa_model/inference/services/llm/openai_llm_service.py +70 -19
- isa_model/inference/services/llm/yyds_llm_service.py +24 -23
- isa_model/inference/services/vision/__init__.py +38 -4
- isa_model/inference/services/vision/base_vision_service.py +218 -117
- isa_model/inference/services/vision/{isA_vision_service.py → disabled/isA_vision_service.py} +98 -0
- isa_model/inference/services/{stacked → vision}/doc_analysis_service.py +1 -1
- isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
- isa_model/inference/services/vision/helpers/image_utils.py +272 -3
- isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
- isa_model/inference/services/vision/openai_vision_service.py +104 -307
- isa_model/inference/services/vision/replicate_vision_service.py +140 -325
- isa_model/inference/services/{stacked → vision}/ui_analysis_service.py +2 -498
- isa_model/scripts/register_models.py +370 -0
- isa_model/scripts/register_models_with_embeddings.py +510 -0
- isa_model/serving/api/fastapi_server.py +6 -1
- isa_model/serving/api/routes/unified.py +274 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/METADATA +4 -1
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/RECORD +78 -53
- isa_model/config/__init__.py +0 -9
- isa_model/config/config_manager.py +0 -213
- isa_model/core/model_manager.py +0 -213
- isa_model/core/model_registry.py +0 -375
- isa_model/core/vision_models_init.py +0 -116
- isa_model/inference/billing_tracker.py +0 -406
- isa_model/inference/services/llm/triton_llm_service.py +0 -481
- isa_model/inference/services/stacked/__init__.py +0 -26
- isa_model/inference/services/stacked/config.py +0 -426
- isa_model/inference/services/vision/ollama_vision_service.py +0 -194
- /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
- /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
- /isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/WHEEL +0 -0
- {isa_model-0.3.5.dist-info → isa_model-0.3.7.dist-info}/top_level.txt +0 -0
isa_model/inference/services/audio/replicate_tts_service.py

```diff
@@ -4,42 +4,45 @@ import replicate
 from tenacity import retry, stop_after_attempt, wait_exponential
 
 from isa_model.inference.services.audio.base_tts_service import BaseTTSService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import ServiceType
 
 logger = logging.getLogger(__name__)
 
 class ReplicateTTSService(BaseTTSService):
     """
-    Replicate Text-to-Speech service using Kokoro model.
+    Replicate Text-to-Speech service using Kokoro model with unified architecture.
     High-quality voice synthesis with multiple voice options.
     """
 
-    def __init__(self, ...
-        super().__init__(...
+    def __init__(self, provider_name: str, model_name: str = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Get ...
-        provider_config = ...
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Set up Replicate API token from provider configuration
-        ...
+        try:
+            self.api_token = provider_config.get('api_key') or provider_config.get('replicate_api_token')
+            if not self.api_token:
+                raise ValueError("Replicate API token not found in provider configuration")
+
+            # Set environment variable for replicate library
+            import os
+            os.environ['REPLICATE_API_TOKEN'] = self.api_token
+
+            # Available voices for Kokoro model
+            self.available_voices = [
+                "af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael"
+            ]
+
+            # Default settings
+            self.default_voice = "af_nicole"
+            self.default_speed = 1.0
+
+            logger.info(f"Initialized ReplicateTTSService with model '{self.model_name}'")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Replicate client: {e}")
+            raise ValueError(f"Failed to initialize Replicate client: {e}") from e
 
     @retry(
         stop=stop_after_attempt(3),
@@ -51,8 +54,8 @@ class ReplicateTTSService(BaseTTSService):
         text: str,
         voice: Optional[str] = None,
         speed: float = 1.0,
-        pitch: ...
+        pitch: float = 1.0,
+        format: str = "wav"
     ) -> Dict[str, Any]:
         """Synthesize speech from text using Kokoro model"""
         try:
@@ -99,8 +102,8 @@ class ReplicateTTSService(BaseTTSService):
             estimated_duration_seconds = (words / 150.0) * 60.0 / speed
 
             # Track usage for billing
-            self._track_usage(
-                service_type=...
+            await self._track_usage(
+                service_type="audio_tts",
                 operation="synthesize_speech",
                 input_tokens=0,
                 output_tokens=0,
@@ -115,15 +118,24 @@ class ReplicateTTSService(BaseTTSService):
                 }
             )
 
+            # Download audio data for return format consistency
+            import aiohttp
+            async with aiohttp.ClientSession() as session:
+                async with session.get(audio_url) as response:
+                    response.raise_for_status()
+                    audio_data = await response.read()
+
             result = {
-                "...
+                "audio_data": audio_data,
+                "format": "wav",  # Kokoro typically outputs WAV
+                "duration": estimated_duration_seconds,
+                "sample_rate": 22050,
+                "audio_url": audio_url,  # Keep URL for reference
                 "metadata": {
                     "model": self.model_name,
                     "provider": "replicate",
+                    "voice": selected_voice,
+                    "speed": speed,
                     "voice_options": self.available_voices
                 }
             }
@@ -137,36 +149,29 @@ class ReplicateTTSService(BaseTTSService):
 
     async def synthesize_speech_to_file(
         self,
         text: str,
         output_path: str,
         voice: Optional[str] = None,
         speed: float = 1.0,
-        pitch: ...
+        pitch: float = 1.0,
+        format: str = "wav"
     ) -> Dict[str, Any]:
         """Synthesize speech and save to file"""
         try:
-            # Get ...
-            result = await self.synthesize_speech(text, voice, speed, pitch, ...
+            # Get synthesis result
+            result = await self.synthesize_speech(text, voice, speed, pitch, format)
+            audio_data = result["audio_data"]
 
-            # ...
+            # Save audio data to file
+            with open(output_path, 'wb') as f:
+                f.write(audio_data)
 
-            ...
-            await f.write(audio_data)
-            result["output_path"] = output_path
-            result["file_size"] = len(audio_data)
-            logger.info(f"Audio saved to: {output_path}")
-            return result
+            return {
+                "file_path": output_path,
+                "duration": result["duration"],
+                "sample_rate": result["sample_rate"],
+                "file_size": len(audio_data)
+            }
 
         except Exception as e:
             logger.error(f"Error saving audio to file: {e}")
```
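The rewritten `synthesize_speech_to_file` now downloads the generated audio itself and returns `file_path`, `duration`, `sample_rate`, and `file_size` instead of mutating the synthesis result. A minimal usage sketch against the new constructor signature (assumes a `replicate` provider is enabled in the centralized config with a valid API token; the text and file name are illustrative):

```python
import asyncio

from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService

async def main():
    # The constructor now takes a provider name; the model defaults
    # to the pinned Kokoro version shown in the diff above.
    tts = ReplicateTTSService(provider_name="replicate")

    info = await tts.synthesize_speech_to_file(
        text="Hello from isa-model 0.3.7",
        output_path="hello.wav",
        voice="af_nicole",  # one of tts.available_voices
    )
    print(info)  # {"file_path": "hello.wav", "duration": ..., "sample_rate": 22050, "file_size": ...}

asyncio.run(main())
```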
isa_model/inference/services/base_service.py

```diff
@@ -1,19 +1,50 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
-from ...
-from ...
+from ...core.models.model_manager import ModelManager
+from ...core.config.config_manager import ConfigManager
+from ...core.types import Provider, ServiceType
 
 T = TypeVar('T')  # Generic type for responses
 
 class BaseService(ABC):
-    """Base class for all AI services"""
+    """Base class for all AI services - now uses centralized managers"""
 
-    def __init__(self, ...
+    def __init__(self,
+                 provider_name: str,
+                 model_name: str,
+                 model_manager: Optional[ModelManager] = None,
+                 config_manager: Optional[ConfigManager] = None):
+        self.provider_name = provider_name
         self.model_name = model_name
-        self....
+        self.model_manager = model_manager or ModelManager()
+        self.config_manager = config_manager or ConfigManager()
 
-        ...
+        # Validate provider is configured
+        if not self.config_manager.is_provider_enabled(provider_name):
+            raise ValueError(f"Provider {provider_name} is not configured or enabled")
+
+    def get_api_key(self) -> str:
+        """Get API key for the provider"""
+        api_key = self.config_manager.get_provider_api_key(self.provider_name)
+        if not api_key:
+            raise ValueError(f"No API key configured for provider {self.provider_name}")
+        return api_key
+
+    def get_provider_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        config = self.config_manager.get_provider_config(self.provider_name)
+        if not config:
+            return {}
+
+        return {
+            "api_key": config.api_key,
+            "api_base_url": config.api_base_url,
+            "organization": config.organization,
+            "rate_limit_rpm": config.rate_limit_rpm,
+            "rate_limit_tpm": config.rate_limit_tpm,
+        }
+
+    async def _track_usage(
         self,
         service_type: Union[str, ServiceType],
         operation: str,
@@ -23,23 +54,30 @@ class BaseService(ABC):
         output_units: Optional[float] = None,
         metadata: Optional[Dict[str, Any]] = None
     ):
-        """Track usage for billing purposes"""
+        """Track usage for billing purposes using centralized billing tracker"""
         try:
-            # ...
+            # Calculate cost using centralized pricing
+            cost_usd = None
+            if input_tokens is not None and output_tokens is not None:
+                cost_usd = self.model_manager.calculate_cost(
+                    provider=self.provider_name,
+                    model_name=self.model_name,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens
+                )
 
-            ...
+            # Track usage through model manager
+            self.model_manager.billing_tracker.track_model_usage(
+                model_id=self.model_name,
+                operation_type="inference",
+                provider=self.provider_name,
+                service_type=service_type if isinstance(service_type, str) else service_type.value,
                 operation=operation,
                 input_tokens=input_tokens,
                 output_tokens=output_tokens,
                 input_units=input_units,
                 output_units=output_units,
+                cost_usd=cost_usd,
                 metadata=metadata
            )
        except Exception as e:
```
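Concrete services no longer receive an injected provider object; they resolve configuration and billing through the shared managers. A sketch of what a subclass looks like under the new contract (the `EchoService` class and its `run` method are hypothetical, for illustration only):

```python
from typing import Any, Dict

from isa_model.inference.services.base_service import BaseService

class EchoService(BaseService):
    """Hypothetical subclass showing the new BaseService contract."""

    async def run(self, prompt: str) -> Dict[str, Any]:
        # Credentials and endpoints come from the centralized ConfigManager.
        config = self.get_provider_config()
        api_key = self.get_api_key()  # raises if the provider has no key configured

        output = prompt  # placeholder for a real provider API call

        # _track_usage is now async; ModelManager computes cost from token counts.
        await self._track_usage(
            service_type="text",
            operation="run",
            input_tokens=len(prompt.split()),
            output_tokens=len(output.split()),
        )
        return {"output": output, "model": self.model_name}
```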
isa_model/inference/services/embedding/base_embed_service.py

```diff
@@ -3,7 +3,71 @@ from typing import Dict, Any, List, Union, Optional
 from isa_model.inference.services.base_service import BaseService
 
 class BaseEmbedService(BaseService):
-    """Base class for embedding services"""
+    """Base class for embedding services with unified task dispatch"""
+
+    async def invoke(
+        self,
+        input_data: Union[str, List[str]],
+        task: Optional[str] = None,
+        **kwargs
+    ) -> Union[List[float], List[List[float]], List[Dict[str, Any]], Dict[str, Any]]:
+        """
+        Unified task dispatch method - the base class provides the generic implementation
+
+        Args:
+            input_data: Input data, either:
+                - str: a single text
+                - List[str]: multiple texts (batch processing)
+            task: Task type; several embedding tasks are supported
+            **kwargs: Additional task-specific parameters
+
+        Returns:
+            Various types depending on task
+        """
+        task = task or "embed"
+
+        # ==================== Embedding generation tasks ====================
+        if task == "embed":
+            if isinstance(input_data, list):
+                return await self.create_text_embeddings(input_data)
+            else:
+                return await self.create_text_embedding(input_data)
+        elif task == "embed_batch":
+            if not isinstance(input_data, list):
+                input_data = [input_data]
+            return await self.create_text_embeddings(input_data)
+        elif task == "chunk_and_embed":
+            if isinstance(input_data, list):
+                raise ValueError("chunk_and_embed task requires single text input")
+            return await self.create_chunks(input_data, kwargs.get("metadata"))
+        elif task == "similarity":
+            embedding1 = kwargs.get("embedding1")
+            embedding2 = kwargs.get("embedding2")
+            if not embedding1 or not embedding2:
+                raise ValueError("similarity task requires embedding1 and embedding2 parameters")
+            similarity = await self.compute_similarity(embedding1, embedding2)
+            return {"similarity": similarity}
+        elif task == "find_similar":
+            query_embedding = kwargs.get("query_embedding")
+            candidate_embeddings = kwargs.get("candidate_embeddings")
+            if not query_embedding or not candidate_embeddings:
+                raise ValueError("find_similar task requires query_embedding and candidate_embeddings parameters")
+            return await self.find_similar_texts(
+                query_embedding,
+                candidate_embeddings,
+                kwargs.get("top_k", 5)
+            )
+        else:
+            raise NotImplementedError(f"{self.__class__.__name__} does not support task: {task}")
+
+    def get_supported_tasks(self) -> List[str]:
+        """
+        Get the list of supported tasks
+
+        Returns:
+            List of supported task names
+        """
+        return ["embed", "embed_batch", "chunk_and_embed", "similarity", "find_similar"]
 
     @abstractmethod
     async def create_text_embedding(self, text: str) -> List[float]:
```
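With the dispatch in place, every embedding operation can be routed through `invoke`. A sketch of the call patterns (assumes an enabled `ollama` provider and a local Ollama serving `bge-m3`; any concrete `BaseEmbedService` subclass works the same way):

```python
import asyncio

from isa_model.inference.services.embedding.ollama_embed_service import OllamaEmbedService

async def main():
    service = OllamaEmbedService(provider_name="ollama")

    vec = await service.invoke("hello world")                         # task defaults to "embed"
    vecs = await service.invoke(["a", "b", "c"], task="embed_batch")  # List[List[float]]

    # similarity/find_similar read their embeddings from kwargs; input_data is unused.
    score = await service.invoke("", task="similarity", embedding1=vec, embedding2=vecs[0])
    top = await service.invoke("", task="find_similar",
                               query_embedding=vec,
                               candidate_embeddings=vecs,
                               top_k=2)
    print(score["similarity"], [c["index"] for c in top])

    await service.close()

asyncio.run(main())
```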
isa_model/inference/services/embedding/ollama_embed_service.py

```diff
@@ -3,44 +3,65 @@ import httpx
 import asyncio
 from typing import List, Dict, Any, Optional
 
-# Keep the imports and framework structure you specified
 from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
-from isa_model.inference.providers.base_provider import BaseProvider
 
 logger = logging.getLogger(__name__)
 
 class OllamaEmbedService(BaseEmbedService):
     """
-    Ollama embedding service.
-
-    ... without depending on an injected backend object.
+    Ollama embedding service with unified architecture.
+    Uses direct HTTP client communication with Ollama API.
     """
 
-    def __init__(self, ...
-        super().__init__(provider, model_name)
+    def __init__(self, provider_name: str, model_name: str = "bge-m3", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # ...
-        port = self.config.get("port", 11434)
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
-        # ...
+        # Initialize HTTP client with provider configuration
+        try:
+            host = provider_config.get("host", "localhost")
+            port = provider_config.get("port", 11434)
+            base_url = f"http://{host}:{port}"
+
+            self.client = httpx.AsyncClient(base_url=base_url, timeout=30.0)
 
-        ...
+            logger.info(f"Initialized OllamaEmbedService with model '{self.model_name}' at {base_url}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Ollama client: {e}")
+            raise ValueError(f"Failed to initialize Ollama client: {e}") from e
 
     async def create_text_embedding(self, text: str) -> List[float]:
-        """...
+        """Create embedding for single text"""
         try:
             payload = {
                 "model": self.model_name,
                 "prompt": text
             }
 
             response = await self.client.post("/api/embeddings", json=payload)
             response.raise_for_status()
-            ...
+
+            result = response.json()
+            embedding = result["embedding"]
+
+            # Track usage for billing (estimate token usage for Ollama)
+            estimated_tokens = len(text.split()) * 1.3  # Rough estimation
+            await self._track_usage(
+                service_type="embedding",
+                operation="create_text_embedding",
+                input_tokens=int(estimated_tokens),
+                output_tokens=0,
+                metadata={
+                    "model": self.model_name,
+                    "text_length": len(text),
+                    "estimated_tokens": int(estimated_tokens)
+                }
+            )
+
+            return embedding
 
         except httpx.RequestError as e:
             logger.error(f"An error occurred while requesting {e.request.url!r}: {e}")
@@ -50,41 +71,70 @@ class OllamaEmbedService(BaseEmbedService):
             raise
 
     async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """...
+        """Create embeddings for multiple texts concurrently"""
         if not texts:
             return []
 
         tasks = [self.create_text_embedding(text) for text in texts]
         embeddings = await asyncio.gather(*tasks)
+
+        # Track batch usage for billing
+        total_estimated_tokens = sum(len(text.split()) * 1.3 for text in texts)
+        await self._track_usage(
+            service_type="embedding",
+            operation="create_text_embeddings",
+            input_tokens=int(total_estimated_tokens),
+            output_tokens=0,
+            metadata={
+                "model": self.model_name,
+                "batch_size": len(texts),
+                "total_text_length": sum(len(t) for t in texts),
+                "estimated_tokens": int(total_estimated_tokens)
+            }
+        )
+
         return embeddings
 
     async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
-        """...
-        chunk_size = 200  # ...
-        ...
-        chunk_texts = [" ".join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
+        """Create text chunks with embeddings"""
+        chunk_size = 200  # words
+        overlap = 50  # word overlap between chunks
 
-        ...
+        words = text.split()
+        if not words:
             return []
 
-        embeddings = await self.create_text_embeddings(chunk_texts)
-
-        chunks = [...
+        chunks = []
+        chunk_texts = []
+
+        for i in range(0, len(words), chunk_size - overlap):
+            chunk_words = words[i:i + chunk_size]
+            chunk_text = " ".join(chunk_words)
+            chunk_texts.append(chunk_text)
+
+            chunks.append({
                 "text": chunk_text,
-                "...
+                "start_index": i,
+                "end_index": min(i + chunk_size, len(words)),
                 "metadata": metadata or {}
-            }...
+            })
+
+        # Get embeddings for all chunks
+        embeddings = await self.create_text_embeddings(chunk_texts)
+
+        # Add embeddings to chunks
+        for chunk, embedding in zip(chunks, embeddings):
+            chunk["embedding"] = embedding
+
         return chunks
 
     async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
-        """...
+        """Compute cosine similarity between two embeddings"""
+        import math
+
         dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
-        norm1 = sum(a * a for a in embedding1)
-        norm2 = sum(b * b for b in embedding2)
+        norm1 = math.sqrt(sum(a * a for a in embedding1))
+        norm2 = math.sqrt(sum(b * b for b in embedding2))
 
         if norm1 * norm2 == 0:
             return 0.0
```
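Note the fix in `compute_similarity`: the 0.3.5 version summed squared components without taking square roots, so it effectively returned dot(a, b) / (|a|² · |b|²) rather than cosine similarity. The corrected formula can be sanity-checked in isolation:

```python
import math

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return 0.0 if na * nb == 0 else dot / (na * nb)

# Parallel vectors must have cosine similarity 1.0 regardless of magnitude.
print(cosine([1.0, 0.0], [3.0, 0.0]))  # 1.0; the old un-rooted formula gave 3/9 ≈ 0.33
```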
```diff
@@ -99,9 +149,13 @@ class OllamaEmbedService(BaseEmbedService):
     ) -> List[Dict[str, Any]]:
         """Find most similar texts based on embeddings"""
         similarities = []
+
         for i, candidate in enumerate(candidate_embeddings):
             similarity = await self.compute_similarity(query_embedding, candidate)
-            similarities.append({...
+            similarities.append({
+                "index": i,
+                "similarity": similarity
+            })
 
         # Sort by similarity in descending order and return top_k
         similarities.sort(key=lambda x: x["similarity"], reverse=True)
@@ -109,15 +163,21 @@ class OllamaEmbedService(BaseEmbedService):
 
     def get_embedding_dimension(self) -> int:
         """Get the dimension of embeddings produced by this service"""
-        # ...
-        ...
+        # Model-specific dimensions
+        model_dimensions = {
+            "bge-m3": 1024,
+            "bge-large": 1024,
+            "all-minilm": 384,
+            "nomic-embed-text": 768
+        }
+        return model_dimensions.get(self.model_name, 1024)
 
     def get_max_input_length(self) -> int:
         """Get maximum input text length supported"""
-        # ...
+        # Most Ollama embedding models support up to 8192 tokens
         return 8192
 
     async def close(self):
-        """...
+        """Cleanup resources"""
         await self.client.aclose()
-        logger.info("OllamaEmbedService ...
+        logger.info("OllamaEmbedService client has been closed.")
```
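One caveat in the new dimension table: `get_embedding_dimension` matches the exact model name, so Ollama-style tags fall through to the 1024 default, as this standalone check shows:

```python
# Same lookup table as in get_embedding_dimension above.
model_dimensions = {"bge-m3": 1024, "bge-large": 1024, "all-minilm": 384, "nomic-embed-text": 768}

print(model_dimensions.get("bge-m3", 1024))             # 1024
print(model_dimensions.get("all-minilm", 1024))         # 384
print(model_dimensions.get("all-minilm:latest", 1024))  # 1024 -- tagged name misses the exact-match lookup
```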
isa_model/inference/services/embedding/openai_embed_service.py

```diff
@@ -5,8 +5,6 @@ from openai import AsyncOpenAI
 from tenacity import retry, stop_after_attempt, wait_exponential
 
 from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import ServiceType
 
 logger = logging.getLogger(__name__)
 
@@ -16,11 +14,11 @@ class OpenAIEmbedService(BaseEmbedService):
     Provides high-quality embeddings for production use.
     """
 
-    def __init__(self, ...
-        super().__init__(...
+    def __init__(self, provider_name: str, model_name: str = "text-embedding-3-small", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Get ...
-        provider_config = ...
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Initialize AsyncOpenAI client with provider configuration
         try:
@@ -67,8 +65,8 @@ class OpenAIEmbedService(BaseEmbedService):
             usage = getattr(response, 'usage', None)
             if usage:
                 total_tokens = getattr(usage, 'total_tokens', 0)
-                self._track_usage(
-                    service_type=...
+                await self._track_usage(
+                    service_type="embedding",
                     operation="create_text_embedding",
                     input_tokens=total_tokens,
                     output_tokens=0,
@@ -112,8 +110,8 @@ class OpenAIEmbedService(BaseEmbedService):
             usage = getattr(response, 'usage', None)
             if usage:
                 total_tokens = getattr(usage, 'total_tokens', 0)
-                self._track_usage(
-                    service_type=...
+                await self._track_usage(
+                    service_type="embedding",
                     operation="create_text_embeddings",
                     input_tokens=total_tokens,
                     output_tokens=0,
```