isa_model-0.3.5-py3-none-any.whl → isa_model-0.3.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. isa_model/__init__.py +30 -1
  2. isa_model/client.py +770 -0
  3. isa_model/core/config/__init__.py +16 -0
  4. isa_model/core/config/config_manager.py +514 -0
  5. isa_model/core/config.py +426 -0
  6. isa_model/core/models/model_billing_tracker.py +476 -0
  7. isa_model/core/models/model_manager.py +399 -0
  8. isa_model/core/{storage/supabase_storage.py → models/model_repo.py} +72 -73
  9. isa_model/core/pricing_manager.py +426 -0
  10. isa_model/core/services/__init__.py +19 -0
  11. isa_model/core/services/intelligent_model_selector.py +547 -0
  12. isa_model/core/types.py +291 -0
  13. isa_model/deployment/__init__.py +2 -0
  14. isa_model/deployment/cloud/modal/isa_vision_doc_service.py +157 -3
  15. isa_model/deployment/cloud/modal/isa_vision_table_service.py +532 -0
  16. isa_model/deployment/cloud/modal/isa_vision_ui_service.py +104 -3
  17. isa_model/deployment/cloud/modal/register_models.py +321 -0
  18. isa_model/deployment/runtime/deployed_service.py +338 -0
  19. isa_model/deployment/services/__init__.py +9 -0
  20. isa_model/deployment/services/auto_deploy_vision_service.py +537 -0
  21. isa_model/deployment/services/model_service.py +332 -0
  22. isa_model/deployment/services/service_monitor.py +356 -0
  23. isa_model/deployment/services/service_registry.py +527 -0
  24. isa_model/eval/__init__.py +80 -44
  25. isa_model/eval/config/__init__.py +10 -0
  26. isa_model/eval/config/evaluation_config.py +108 -0
  27. isa_model/eval/evaluators/__init__.py +18 -0
  28. isa_model/eval/evaluators/base_evaluator.py +503 -0
  29. isa_model/eval/evaluators/llm_evaluator.py +472 -0
  30. isa_model/eval/factory.py +417 -709
  31. isa_model/eval/infrastructure/__init__.py +24 -0
  32. isa_model/eval/infrastructure/experiment_tracker.py +466 -0
  33. isa_model/eval/metrics.py +191 -21
  34. isa_model/inference/ai_factory.py +181 -605
  35. isa_model/inference/services/audio/base_stt_service.py +65 -1
  36. isa_model/inference/services/audio/base_tts_service.py +75 -1
  37. isa_model/inference/services/audio/openai_stt_service.py +189 -151
  38. isa_model/inference/services/audio/openai_tts_service.py +12 -10
  39. isa_model/inference/services/audio/replicate_tts_service.py +61 -56
  40. isa_model/inference/services/base_service.py +55 -17
  41. isa_model/inference/services/embedding/base_embed_service.py +65 -1
  42. isa_model/inference/services/embedding/ollama_embed_service.py +103 -43
  43. isa_model/inference/services/embedding/openai_embed_service.py +8 -10
  44. isa_model/inference/services/helpers/stacked_config.py +148 -0
  45. isa_model/inference/services/img/__init__.py +18 -0
  46. isa_model/inference/services/{vision → img}/base_image_gen_service.py +80 -1
  47. isa_model/inference/services/{stacked → img}/flux_professional_service.py +25 -1
  48. isa_model/inference/services/{stacked → img/helpers}/base_stacked_service.py +40 -35
  49. isa_model/inference/services/{vision → img}/replicate_image_gen_service.py +44 -31
  50. isa_model/inference/services/llm/__init__.py +3 -3
  51. isa_model/inference/services/llm/base_llm_service.py +492 -40
  52. isa_model/inference/services/llm/helpers/llm_prompts.py +258 -0
  53. isa_model/inference/services/llm/helpers/llm_utils.py +280 -0
  54. isa_model/inference/services/llm/ollama_llm_service.py +51 -17
  55. isa_model/inference/services/llm/openai_llm_service.py +70 -19
  56. isa_model/inference/services/llm/yyds_llm_service.py +24 -23
  57. isa_model/inference/services/vision/__init__.py +38 -4
  58. isa_model/inference/services/vision/base_vision_service.py +218 -117
  59. isa_model/inference/services/vision/{isA_vision_service.py → disabled/isA_vision_service.py} +98 -0
  60. isa_model/inference/services/{stacked → vision}/doc_analysis_service.py +1 -1
  61. isa_model/inference/services/vision/helpers/base_stacked_service.py +274 -0
  62. isa_model/inference/services/vision/helpers/image_utils.py +272 -3
  63. isa_model/inference/services/vision/helpers/vision_prompts.py +297 -0
  64. isa_model/inference/services/vision/openai_vision_service.py +104 -307
  65. isa_model/inference/services/vision/replicate_vision_service.py +140 -325
  66. isa_model/inference/services/{stacked → vision}/ui_analysis_service.py +2 -498
  67. isa_model/scripts/register_models.py +370 -0
  68. isa_model/scripts/register_models_with_embeddings.py +510 -0
  69. isa_model/serving/api/fastapi_server.py +6 -1
  70. isa_model/serving/api/routes/unified.py +202 -0
  71. {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/METADATA +4 -1
  72. {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/RECORD +77 -53
  73. isa_model/config/__init__.py +0 -9
  74. isa_model/config/config_manager.py +0 -213
  75. isa_model/core/model_manager.py +0 -213
  76. isa_model/core/model_registry.py +0 -375
  77. isa_model/core/vision_models_init.py +0 -116
  78. isa_model/inference/billing_tracker.py +0 -406
  79. isa_model/inference/services/llm/triton_llm_service.py +0 -481
  80. isa_model/inference/services/stacked/__init__.py +0 -26
  81. isa_model/inference/services/stacked/config.py +0 -426
  82. isa_model/inference/services/vision/ollama_vision_service.py +0 -194
  83. /isa_model/core/{model_storage.py → models/model_storage.py} +0 -0
  84. /isa_model/inference/services/{vision → embedding}/helpers/text_splitter.py +0 -0
  85. /isa_model/inference/services/llm/{llm_adapter.py → helpers/llm_adapter.py} +0 -0
  86. {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/WHEEL +0 -0
  87. {isa_model-0.3.5.dist-info → isa_model-0.3.6.dist-info}/top_level.txt +0 -0
isa_model/inference/services/audio/replicate_tts_service.py

@@ -4,42 +4,45 @@ import replicate
 from tenacity import retry, stop_after_attempt, wait_exponential
 
 from isa_model.inference.services.audio.base_tts_service import BaseTTSService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import ServiceType
 
 logger = logging.getLogger(__name__)
 
 class ReplicateTTSService(BaseTTSService):
     """
-    Replicate Text-to-Speech service using Kokoro model.
+    Replicate Text-to-Speech service using Kokoro model with unified architecture.
     High-quality voice synthesis with multiple voice options.
     """
 
-    def __init__(self, provider: 'BaseProvider', model_name: str = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13"):
-        super().__init__(provider, model_name)
+    def __init__(self, provider_name: str, model_name: str = "jaaari/kokoro-82m:f559560eb822dc509045f3921a1921234918b91739db4bf3daab2169b71c7a13", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Get full configuration from provider (including sensitive data)
-        provider_config = provider.get_full_config()
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Set up Replicate API token from provider configuration
-        self.api_token = provider_config.get('api_token') or provider_config.get('replicate_api_token')
-        if not self.api_token:
-            raise ValueError("Replicate API token not found in provider configuration")
-
-        # Set environment variable for replicate library
-        import os
-        os.environ['REPLICATE_API_TOKEN'] = self.api_token
-
-        # Available voices for Kokoro model
-        self.available_voices = [
-            "af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael"
-        ]
-
-        # Default settings
-        self.default_voice = "af_nicole"
-        self.default_speed = 1.0
-
-        logger.info(f"Initialized ReplicateTTSService with model '{self.model_name}'")
+        try:
+            self.api_token = provider_config.get('api_key') or provider_config.get('replicate_api_token')
+            if not self.api_token:
+                raise ValueError("Replicate API token not found in provider configuration")
+
+            # Set environment variable for replicate library
+            import os
+            os.environ['REPLICATE_API_TOKEN'] = self.api_token
+
+            # Available voices for Kokoro model
+            self.available_voices = [
+                "af_bella", "af_nicole", "af_sarah", "af_sky", "am_adam", "am_michael"
+            ]
+
+            # Default settings
+            self.default_voice = "af_nicole"
+            self.default_speed = 1.0
+
+            logger.info(f"Initialized ReplicateTTSService with model '{self.model_name}'")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Replicate client: {e}")
+            raise ValueError(f"Failed to initialize Replicate client: {e}") from e
 
     @retry(
         stop=stop_after_attempt(3),
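
The constructor migration above recurs across every service in this release: services now take a provider name string instead of a BaseProvider instance and resolve credentials through the centralized ConfigManager. A minimal usage sketch, assuming a "replicate" provider is already configured (these call sites are illustrative, not taken from the package's docs):

    from isa_model.inference.services.audio.replicate_tts_service import ReplicateTTSService

    # 0.3.5: ReplicateTTSService(provider)  -- required a BaseProvider object
    # 0.3.6: the provider is referenced by name and resolved via ConfigManager
    tts = ReplicateTTSService(provider_name="replicate")
    print(tts.available_voices)  # ['af_bella', 'af_nicole', 'af_sarah', ...]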
@@ -51,8 +54,8 @@ class ReplicateTTSService(BaseTTSService):
         text: str,
         voice: Optional[str] = None,
         speed: float = 1.0,
-        pitch: Optional[float] = None,
-        volume: Optional[float] = None
+        pitch: float = 1.0,
+        format: str = "wav"
     ) -> Dict[str, Any]:
         """Synthesize speech from text using Kokoro model"""
         try:
@@ -99,8 +102,8 @@ class ReplicateTTSService(BaseTTSService):
             estimated_duration_seconds = (words / 150.0) * 60.0 / speed
 
             # Track usage for billing
-            self._track_usage(
-                service_type=ServiceType.AUDIO_TTS,
+            await self._track_usage(
+                service_type="audio_tts",
                 operation="synthesize_speech",
                 input_tokens=0,
                 output_tokens=0,
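
This is a pattern repeated throughout the diff: usage tracking moved from the old billing_tracker's ServiceType enum to plain strings, and _track_usage became a coroutine. A representative before/after sketch of the call-site change (abbreviated, not exhaustive):

    # 0.3.5 (synchronous, enum-typed):
    self._track_usage(service_type=ServiceType.AUDIO_TTS, operation="synthesize_speech")

    # 0.3.6 (awaited, string-typed):
    await self._track_usage(service_type="audio_tts", operation="synthesize_speech")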
@@ -115,15 +118,24 @@ class ReplicateTTSService(BaseTTSService):
                 }
             )
 
+            # Download audio data for return format consistency
+            import aiohttp
+            async with aiohttp.ClientSession() as session:
+                async with session.get(audio_url) as response:
+                    response.raise_for_status()
+                    audio_data = await response.read()
+
             result = {
-                "audio_url": audio_url,
-                "text": text,
-                "voice": selected_voice,
-                "speed": speed,
-                "duration_seconds": estimated_duration_seconds,
+                "audio_data": audio_data,
+                "format": "wav",  # Kokoro typically outputs WAV
+                "duration": estimated_duration_seconds,
+                "sample_rate": 22050,
+                "audio_url": audio_url,  # Keep URL for reference
                 "metadata": {
                     "model": self.model_name,
                     "provider": "replicate",
+                    "voice": selected_voice,
+                    "speed": speed,
                     "voice_options": self.available_voices
                 }
             }
@@ -137,36 +149,29 @@ class ReplicateTTSService(BaseTTSService):
 
     async def synthesize_speech_to_file(
         self,
-        text: str, 
+        text: str,
         output_path: str,
         voice: Optional[str] = None,
         speed: float = 1.0,
-        pitch: Optional[float] = None,
-        volume: Optional[float] = None
+        pitch: float = 1.0,
+        format: str = "wav"
     ) -> Dict[str, Any]:
         """Synthesize speech and save to file"""
         try:
-            # Get audio URL
-            result = await self.synthesize_speech(text, voice, speed, pitch, volume)
-            audio_url = result["audio_url"]
+            # Get synthesis result
+            result = await self.synthesize_speech(text, voice, speed, pitch, format)
+            audio_data = result["audio_data"]
 
-            # Download and save audio
-            import aiohttp
-            import aiofiles
+            # Save audio data to file
+            with open(output_path, 'wb') as f:
+                f.write(audio_data)
 
-            async with aiohttp.ClientSession() as session:
-                async with session.get(audio_url) as response:
-                    response.raise_for_status()
-                    audio_data = await response.read()
-
-            async with aiofiles.open(output_path, 'wb') as f:
-                await f.write(audio_data)
-
-            result["output_path"] = output_path
-            result["file_size"] = len(audio_data)
-
-            logger.info(f"Audio saved to: {output_path}")
-            return result
+            return {
+                "file_path": output_path,
+                "duration": result["duration"],
+                "sample_rate": result["sample_rate"],
+                "file_size": len(audio_data)
+            }
 
         except Exception as e:
            logger.error(f"Error saving audio to file: {e}")
isa_model/inference/services/base_service.py

@@ -1,19 +1,50 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Any, List, Union, AsyncGenerator, TypeVar, Optional
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import track_usage, ServiceType, Provider
+from ...core.models.model_manager import ModelManager
+from ...core.config.config_manager import ConfigManager
+from ...core.types import Provider, ServiceType
 
 T = TypeVar('T')  # Generic type for responses
 
 class BaseService(ABC):
-    """Base class for all AI services"""
+    """Base class for all AI services - now uses centralized managers"""
 
-    def __init__(self, provider: 'BaseProvider', model_name: str):
-        self.provider = provider
+    def __init__(self,
+                 provider_name: str,
+                 model_name: str,
+                 model_manager: Optional[ModelManager] = None,
+                 config_manager: Optional[ConfigManager] = None):
+        self.provider_name = provider_name
         self.model_name = model_name
-        self.config = provider.get_full_config()
+        self.model_manager = model_manager or ModelManager()
+        self.config_manager = config_manager or ConfigManager()
 
-    def _track_usage(
+        # Validate provider is configured
+        if not self.config_manager.is_provider_enabled(provider_name):
+            raise ValueError(f"Provider {provider_name} is not configured or enabled")
+
+    def get_api_key(self) -> str:
+        """Get API key for the provider"""
+        api_key = self.config_manager.get_provider_api_key(self.provider_name)
+        if not api_key:
+            raise ValueError(f"No API key configured for provider {self.provider_name}")
+        return api_key
+
+    def get_provider_config(self) -> Dict[str, Any]:
+        """Get provider configuration"""
+        config = self.config_manager.get_provider_config(self.provider_name)
+        if not config:
+            return {}
+
+        return {
+            "api_key": config.api_key,
+            "api_base_url": config.api_base_url,
+            "organization": config.organization,
+            "rate_limit_rpm": config.rate_limit_rpm,
+            "rate_limit_tpm": config.rate_limit_tpm,
+        }
+
+    async def _track_usage(
         self,
         service_type: Union[str, ServiceType],
         operation: str,
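
Where 0.3.5 services read self.config from an injected provider object, a 0.3.6 service resolves everything by name through ConfigManager. A hedged sketch of what a subclass now sees inside its methods (field names as listed in the diff; values depend on local configuration):

    cfg = self.get_provider_config()
    # {'api_key': '...', 'api_base_url': '...', 'organization': None,
    #  'rate_limit_rpm': ..., 'rate_limit_tpm': ...}
    key = self.get_api_key()  # raises ValueError if no key is configured for the provider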
@@ -23,23 +54,30 @@ class BaseService(ABC):
         output_units: Optional[float] = None,
         metadata: Optional[Dict[str, Any]] = None
     ):
-        """Track usage for billing purposes"""
+        """Track usage for billing purposes using centralized billing tracker"""
         try:
-            # Determine provider name - try multiple attributes
-            provider_name = getattr(self.provider, 'name', None) or \
-                            getattr(self.provider, 'provider_name', None) or \
-                            getattr(self.provider, '__class__', type(None)).__name__.lower().replace('provider', '') or \
-                            'unknown'
+            # Calculate cost using centralized pricing
+            cost_usd = None
+            if input_tokens is not None and output_tokens is not None:
+                cost_usd = self.model_manager.calculate_cost(
+                    provider=self.provider_name,
+                    model_name=self.model_name,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens
+                )
 
-            track_usage(
-                provider=provider_name,
-                service_type=service_type,
-                model_name=self.model_name,
+            # Track usage through model manager
+            self.model_manager.billing_tracker.track_model_usage(
+                model_id=self.model_name,
+                operation_type="inference",
+                provider=self.provider_name,
+                service_type=service_type if isinstance(service_type, str) else service_type.value,
                 operation=operation,
                 input_tokens=input_tokens,
                 output_tokens=output_tokens,
                 input_units=input_units,
                 output_units=output_units,
+                cost_usd=cost_usd,
                 metadata=metadata
             )
         except Exception as e:
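
Taken together, BaseService now wires every service into ModelManager for pricing and billing. An illustrative subclass, assuming an enabled "openai" provider; EchoService and its run method are hypothetical, written only to show the contract:

    from isa_model.inference.services.base_service import BaseService

    class EchoService(BaseService):
        async def run(self, text: str) -> str:
            n = len(text.split())
            await self._track_usage(   # a coroutine in 0.3.6; cost is computed centrally
                service_type="text",
                operation="echo",
                input_tokens=n,
                output_tokens=n,
            )
            return text

    svc = EchoService(provider_name="openai", model_name="gpt-4o-mini")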
isa_model/inference/services/embedding/base_embed_service.py

@@ -3,7 +3,71 @@ from typing import Dict, Any, List, Union, Optional
 from isa_model.inference.services.base_service import BaseService
 
 class BaseEmbedService(BaseService):
-    """Base class for embedding services"""
+    """Base class for embedding services with unified task dispatch"""
+
+    async def invoke(
+        self,
+        input_data: Union[str, List[str]],
+        task: Optional[str] = None,
+        **kwargs
+    ) -> Union[List[float], List[List[float]], List[Dict[str, Any]], Dict[str, Any]]:
+        """
+        Unified task dispatch method - the base class provides a generic implementation
+
+        Args:
+            input_data: input data, either:
+                - str: a single text
+                - List[str]: multiple texts (batch processing)
+            task: task type; several embedding tasks are supported
+            **kwargs: task-specific extra parameters
+
+        Returns:
+            Various types depending on task
+        """
+        task = task or "embed"
+
+        # ==================== Embedding generation tasks ====================
+        if task == "embed":
+            if isinstance(input_data, list):
+                return await self.create_text_embeddings(input_data)
+            else:
+                return await self.create_text_embedding(input_data)
+        elif task == "embed_batch":
+            if not isinstance(input_data, list):
+                input_data = [input_data]
+            return await self.create_text_embeddings(input_data)
+        elif task == "chunk_and_embed":
+            if isinstance(input_data, list):
+                raise ValueError("chunk_and_embed task requires single text input")
+            return await self.create_chunks(input_data, kwargs.get("metadata"))
+        elif task == "similarity":
+            embedding1 = kwargs.get("embedding1")
+            embedding2 = kwargs.get("embedding2")
+            if not embedding1 or not embedding2:
+                raise ValueError("similarity task requires embedding1 and embedding2 parameters")
+            similarity = await self.compute_similarity(embedding1, embedding2)
+            return {"similarity": similarity}
+        elif task == "find_similar":
+            query_embedding = kwargs.get("query_embedding")
+            candidate_embeddings = kwargs.get("candidate_embeddings")
+            if not query_embedding or not candidate_embeddings:
+                raise ValueError("find_similar task requires query_embedding and candidate_embeddings parameters")
+            return await self.find_similar_texts(
+                query_embedding,
+                candidate_embeddings,
+                kwargs.get("top_k", 5)
+            )
+        else:
+            raise NotImplementedError(f"{self.__class__.__name__} does not support task: {task}")
+
+    def get_supported_tasks(self) -> List[str]:
+        """
+        Get the list of supported tasks
+
+        Returns:
+            List of supported task names
+        """
+        return ["embed", "embed_batch", "chunk_and_embed", "similarity", "find_similar"]
 
     @abstractmethod
     async def create_text_embedding(self, text: str) -> List[float]:
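
The dispatch gives every embedding backend a single entry point. A usage sketch (run inside an async function; svc stands for any configured BaseEmbedService subclass, such as the Ollama or OpenAI services below):

    emb = await svc.invoke("hello world")                    # defaults to task="embed"
    vecs = await svc.invoke(["a", "b"], task="embed_batch")  # always a list of vectors
    out = await svc.invoke("", task="similarity",
                           embedding1=[1.0, 0.0], embedding2=[0.0, 1.0])
    print(out["similarity"])                                 # 0.0 for orthogonal vectors
    print(svc.get_supported_tasks())                         # ['embed', 'embed_batch', ...]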
isa_model/inference/services/embedding/ollama_embed_service.py

@@ -3,44 +3,65 @@ import httpx
 import asyncio
 from typing import List, Dict, Any, Optional
 
-# 保留您指定的导入和框架结构
 from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
-from isa_model.inference.providers.base_provider import BaseProvider
 
 logger = logging.getLogger(__name__)
 
 class OllamaEmbedService(BaseEmbedService):
     """
-    Ollama embedding service.
-    此类遵循基础服务架构,但使用其自己的 HTTP 客户端与 Ollama API 通信,
-    而不依赖于注入的 backend 对象。
+    Ollama embedding service with unified architecture.
+    Uses direct HTTP client communication with Ollama API.
     """
 
-    def __init__(self, provider: 'BaseProvider', model_name: str = "bge-m3"):
-        # 保持对基类和 provider 的兼容
-        super().__init__(provider, model_name)
+    def __init__(self, provider_name: str, model_name: str = "bge-m3", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # 从基类继承的 self.config 中获取配置
-        host = self.config.get("host", "localhost")
-        port = self.config.get("port", 11434)
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
-        # 创建并持有自己的 httpx 客户端实例
-        base_url = f"http://{host}:{port}"
-        self.client = httpx.AsyncClient(base_url=base_url, timeout=30.0)
+        # Initialize HTTP client with provider configuration
+        try:
+            host = provider_config.get("host", "localhost")
+            port = provider_config.get("port", 11434)
+            base_url = f"http://{host}:{port}"
+
+            self.client = httpx.AsyncClient(base_url=base_url, timeout=30.0)
 
-        logger.info(f"Initialized OllamaEmbedService with model '{self.model_name}' at {base_url}")
+            logger.info(f"Initialized OllamaEmbedService with model '{self.model_name}' at {base_url}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize Ollama client: {e}")
+            raise ValueError(f"Failed to initialize Ollama client: {e}") from e
 
     async def create_text_embedding(self, text: str) -> List[float]:
-        """为单个文本创建 embedding"""
+        """Create embedding for single text"""
        try:
             payload = {
                 "model": self.model_name,
                 "prompt": text
             }
-            # 使用自己的 client 实例,而不是 self.backend
+
             response = await self.client.post("/api/embeddings", json=payload)
-            response.raise_for_status()  # 检查请求是否成功
-            return response.json()["embedding"]
+            response.raise_for_status()
+
+            result = response.json()
+            embedding = result["embedding"]
+
+            # Track usage for billing (estimate token usage for Ollama)
+            estimated_tokens = len(text.split()) * 1.3  # Rough estimation
+            await self._track_usage(
+                service_type="embedding",
+                operation="create_text_embedding",
+                input_tokens=int(estimated_tokens),
+                output_tokens=0,
+                metadata={
+                    "model": self.model_name,
+                    "text_length": len(text),
+                    "estimated_tokens": int(estimated_tokens)
+                }
+            )
+
+            return embedding
 
         except httpx.RequestError as e:
             logger.error(f"An error occurred while requesting {e.request.url!r}: {e}")
@@ -50,41 +71,70 @@ class OllamaEmbedService(BaseEmbedService):
             raise
 
     async def create_text_embeddings(self, texts: List[str]) -> List[List[float]]:
-        """为多个文本并发地创建 embeddings"""
+        """Create embeddings for multiple texts concurrently"""
         if not texts:
             return []
 
         tasks = [self.create_text_embedding(text) for text in texts]
         embeddings = await asyncio.gather(*tasks)
+
+        # Track batch usage for billing
+        total_estimated_tokens = sum(len(text.split()) * 1.3 for text in texts)
+        await self._track_usage(
+            service_type="embedding",
+            operation="create_text_embeddings",
+            input_tokens=int(total_estimated_tokens),
+            output_tokens=0,
+            metadata={
+                "model": self.model_name,
+                "batch_size": len(texts),
+                "total_text_length": sum(len(t) for t in texts),
+                "estimated_tokens": int(total_estimated_tokens)
+            }
+        )
+
         return embeddings
 
     async def create_chunks(self, text: str, metadata: Optional[Dict] = None) -> List[Dict]:
-        """将文本分块并为每个块创建 embedding"""
-        chunk_size = 200  # 单词数量
-        words = text.split()
-        chunk_texts = [" ".join(words[i:i+chunk_size]) for i in range(0, len(words), chunk_size)]
+        """Create text chunks with embeddings"""
+        chunk_size = 200  # words
+        overlap = 50  # word overlap between chunks
 
-        if not chunk_texts:
+        words = text.split()
+        if not words:
             return []
-
-        embeddings = await self.create_text_embeddings(chunk_texts)
 
-        chunks = [
-            {
+        chunks = []
+        chunk_texts = []
+
+        for i in range(0, len(words), chunk_size - overlap):
+            chunk_words = words[i:i + chunk_size]
+            chunk_text = " ".join(chunk_words)
+            chunk_texts.append(chunk_text)
+
+            chunks.append({
                 "text": chunk_text,
-                "embedding": emb,
+                "start_index": i,
+                "end_index": min(i + chunk_size, len(words)),
                 "metadata": metadata or {}
-            }
-            for chunk_text, emb in zip(chunk_texts, embeddings)
-        ]
-
+            })
+
+        # Get embeddings for all chunks
+        embeddings = await self.create_text_embeddings(chunk_texts)
+
+        # Add embeddings to chunks
+        for chunk, embedding in zip(chunks, embeddings):
+            chunk["embedding"] = embedding
+
         return chunks
 
     async def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
-        """计算两个嵌入向量之间的余弦相似度"""
+        """Compute cosine similarity between two embeddings"""
+        import math
+
         dot_product = sum(a * b for a, b in zip(embedding1, embedding2))
-        norm1 = sum(a * a for a in embedding1) ** 0.5
-        norm2 = sum(b * b for b in embedding2) ** 0.5
+        norm1 = math.sqrt(sum(a * a for a in embedding1))
+        norm2 = math.sqrt(sum(b * b for b in embedding2))
 
         if norm1 * norm2 == 0:
             return 0.0
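
The new chunker strides by chunk_size - overlap, so consecutive windows share 50 words. A quick check of the window arithmetic (plain Python, no service needed):

    chunk_size, overlap = 200, 50
    words = [f"w{i}" for i in range(500)]
    starts = list(range(0, len(words), chunk_size - overlap))
    print(starts)  # [0, 150, 300, 450]
    print([(s, min(s + chunk_size, len(words))) for s in starts])
    # [(0, 200), (150, 350), (300, 500), (450, 500)]

Note that the final 50-word window is already fully covered by the preceding chunk, a side effect of striding past the end of the text.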
@@ -99,9 +149,13 @@ class OllamaEmbedService(BaseEmbedService):
     ) -> List[Dict[str, Any]]:
         """Find most similar texts based on embeddings"""
         similarities = []
+
         for i, candidate in enumerate(candidate_embeddings):
             similarity = await self.compute_similarity(query_embedding, candidate)
-            similarities.append({"index": i, "similarity": similarity})
+            similarities.append({
+                "index": i,
+                "similarity": similarity
+            })
 
         # Sort by similarity in descending order and return top_k
         similarities.sort(key=lambda x: x["similarity"], reverse=True)
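
The ranking step is plain sort-and-slice. A standalone illustration of the selection logic (scores are hypothetical):

    similarities = [{"index": 0, "similarity": 0.12},
                    {"index": 1, "similarity": 0.87},
                    {"index": 2, "similarity": 0.55}]
    similarities.sort(key=lambda x: x["similarity"], reverse=True)
    top_k = 2
    print([s["index"] for s in similarities[:top_k]])  # [1, 2]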
@@ -109,15 +163,21 @@ class OllamaEmbedService(BaseEmbedService):
 
     def get_embedding_dimension(self) -> int:
         """Get the dimension of embeddings produced by this service"""
-        # BGE-M3 produces 1024-dimensional embeddings
-        return 1024
+        # Model-specific dimensions
+        model_dimensions = {
+            "bge-m3": 1024,
+            "bge-large": 1024,
+            "all-minilm": 384,
+            "nomic-embed-text": 768
+        }
+        return model_dimensions.get(self.model_name, 1024)
 
     def get_max_input_length(self) -> int:
         """Get maximum input text length supported"""
-        # BGE-M3 supports up to 8192 tokens
+        # Most Ollama embedding models support up to 8192 tokens
         return 8192
 
     async def close(self):
-        """关闭内置的 HTTP 客户端"""
+        """Cleanup resources"""
         await self.client.aclose()
-        logger.info("OllamaEmbedService's internal client has been closed.")
+        logger.info("OllamaEmbedService client has been closed.")
isa_model/inference/services/embedding/openai_embed_service.py

@@ -5,8 +5,6 @@ from openai import AsyncOpenAI
 from tenacity import retry, stop_after_attempt, wait_exponential
 
 from isa_model.inference.services.embedding.base_embed_service import BaseEmbedService
-from isa_model.inference.providers.base_provider import BaseProvider
-from isa_model.inference.billing_tracker import ServiceType
 
 logger = logging.getLogger(__name__)
 
@@ -16,11 +14,11 @@ class OpenAIEmbedService(BaseEmbedService):
     Provides high-quality embeddings for production use.
     """
 
-    def __init__(self, provider: 'BaseProvider', model_name: str = "text-embedding-3-small"):
-        super().__init__(provider, model_name)
+    def __init__(self, provider_name: str, model_name: str = "text-embedding-3-small", **kwargs):
+        super().__init__(provider_name, model_name, **kwargs)
 
-        # Get full configuration from provider (including sensitive data)
-        provider_config = provider.get_full_config()
+        # Get configuration from centralized config manager
+        provider_config = self.get_provider_config()
 
         # Initialize AsyncOpenAI client with provider configuration
         try:
@@ -67,8 +65,8 @@ class OpenAIEmbedService(BaseEmbedService):
             usage = getattr(response, 'usage', None)
             if usage:
                 total_tokens = getattr(usage, 'total_tokens', 0)
-                self._track_usage(
-                    service_type=ServiceType.EMBEDDING,
+                await self._track_usage(
+                    service_type="embedding",
                     operation="create_text_embedding",
                     input_tokens=total_tokens,
                     output_tokens=0,
@@ -112,8 +110,8 @@ class OpenAIEmbedService(BaseEmbedService):
             usage = getattr(response, 'usage', None)
             if usage:
                 total_tokens = getattr(usage, 'total_tokens', 0)
-                self._track_usage(
-                    service_type=ServiceType.EMBEDDING,
+                await self._track_usage(
+                    service_type="embedding",
                     operation="create_text_embeddings",
                     input_tokens=total_tokens,
                     output_tokens=0,
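
The OpenAI service follows the same migration: name-based construction, centralized config, and awaited string-typed usage tracking driven by real token counts from the API response. A closing sketch, assuming an enabled "openai" provider with an API key configured (illustrative):

    import asyncio
    from isa_model.inference.services.embedding.openai_embed_service import OpenAIEmbedService

    async def main():
        svc = OpenAIEmbedService(provider_name="openai")  # default: text-embedding-3-small
        vec = await svc.create_text_embedding("hello")
        print(len(vec))  # 1536 for text-embedding-3-small

    asyncio.run(main())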