aiecs-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of aiecs might be problematic.
- aiecs/__init__.py +75 -0
- aiecs/__main__.py +41 -0
- aiecs/aiecs_client.py +295 -0
- aiecs/application/__init__.py +10 -0
- aiecs/application/executors/__init__.py +10 -0
- aiecs/application/executors/operation_executor.py +341 -0
- aiecs/config/__init__.py +15 -0
- aiecs/config/config.py +117 -0
- aiecs/config/registry.py +19 -0
- aiecs/core/__init__.py +46 -0
- aiecs/core/interface/__init__.py +34 -0
- aiecs/core/interface/execution_interface.py +150 -0
- aiecs/core/interface/storage_interface.py +214 -0
- aiecs/domain/__init__.py +20 -0
- aiecs/domain/context/__init__.py +28 -0
- aiecs/domain/context/content_engine.py +982 -0
- aiecs/domain/context/conversation_models.py +306 -0
- aiecs/domain/execution/__init__.py +12 -0
- aiecs/domain/execution/model.py +49 -0
- aiecs/domain/task/__init__.py +13 -0
- aiecs/domain/task/dsl_processor.py +460 -0
- aiecs/domain/task/model.py +50 -0
- aiecs/domain/task/task_context.py +257 -0
- aiecs/infrastructure/__init__.py +26 -0
- aiecs/infrastructure/messaging/__init__.py +13 -0
- aiecs/infrastructure/messaging/celery_task_manager.py +341 -0
- aiecs/infrastructure/messaging/websocket_manager.py +289 -0
- aiecs/infrastructure/monitoring/__init__.py +12 -0
- aiecs/infrastructure/monitoring/executor_metrics.py +138 -0
- aiecs/infrastructure/monitoring/structured_logger.py +50 -0
- aiecs/infrastructure/monitoring/tracing_manager.py +376 -0
- aiecs/infrastructure/persistence/__init__.py +12 -0
- aiecs/infrastructure/persistence/database_manager.py +286 -0
- aiecs/infrastructure/persistence/file_storage.py +671 -0
- aiecs/infrastructure/persistence/redis_client.py +162 -0
- aiecs/llm/__init__.py +54 -0
- aiecs/llm/base_client.py +99 -0
- aiecs/llm/client_factory.py +339 -0
- aiecs/llm/custom_callbacks.py +228 -0
- aiecs/llm/openai_client.py +125 -0
- aiecs/llm/vertex_client.py +186 -0
- aiecs/llm/xai_client.py +184 -0
- aiecs/main.py +351 -0
- aiecs/scripts/DEPENDENCY_SYSTEM_SUMMARY.md +241 -0
- aiecs/scripts/README_DEPENDENCY_CHECKER.md +309 -0
- aiecs/scripts/README_WEASEL_PATCH.md +126 -0
- aiecs/scripts/__init__.py +3 -0
- aiecs/scripts/dependency_checker.py +825 -0
- aiecs/scripts/dependency_fixer.py +348 -0
- aiecs/scripts/download_nlp_data.py +348 -0
- aiecs/scripts/fix_weasel_validator.py +121 -0
- aiecs/scripts/fix_weasel_validator.sh +82 -0
- aiecs/scripts/patch_weasel_library.sh +188 -0
- aiecs/scripts/quick_dependency_check.py +269 -0
- aiecs/scripts/run_weasel_patch.sh +41 -0
- aiecs/scripts/setup_nlp_data.sh +217 -0
- aiecs/tasks/__init__.py +2 -0
- aiecs/tasks/worker.py +111 -0
- aiecs/tools/__init__.py +196 -0
- aiecs/tools/base_tool.py +202 -0
- aiecs/tools/langchain_adapter.py +361 -0
- aiecs/tools/task_tools/__init__.py +82 -0
- aiecs/tools/task_tools/chart_tool.py +704 -0
- aiecs/tools/task_tools/classfire_tool.py +901 -0
- aiecs/tools/task_tools/image_tool.py +397 -0
- aiecs/tools/task_tools/office_tool.py +600 -0
- aiecs/tools/task_tools/pandas_tool.py +565 -0
- aiecs/tools/task_tools/report_tool.py +499 -0
- aiecs/tools/task_tools/research_tool.py +363 -0
- aiecs/tools/task_tools/scraper_tool.py +548 -0
- aiecs/tools/task_tools/search_api.py +7 -0
- aiecs/tools/task_tools/stats_tool.py +513 -0
- aiecs/tools/temp_file_manager.py +126 -0
- aiecs/tools/tool_executor/__init__.py +35 -0
- aiecs/tools/tool_executor/tool_executor.py +518 -0
- aiecs/utils/LLM_output_structor.py +409 -0
- aiecs/utils/__init__.py +23 -0
- aiecs/utils/base_callback.py +50 -0
- aiecs/utils/execution_utils.py +158 -0
- aiecs/utils/logging.py +1 -0
- aiecs/utils/prompt_loader.py +13 -0
- aiecs/utils/token_usage_repository.py +279 -0
- aiecs/ws/__init__.py +0 -0
- aiecs/ws/socket_server.py +41 -0
- aiecs-1.0.0.dist-info/METADATA +610 -0
- aiecs-1.0.0.dist-info/RECORD +90 -0
- aiecs-1.0.0.dist-info/WHEEL +5 -0
- aiecs-1.0.0.dist-info/entry_points.txt +7 -0
- aiecs-1.0.0.dist-info/licenses/LICENSE +225 -0
- aiecs-1.0.0.dist-info/top_level.txt +1 -0
aiecs/llm/custom_callbacks.py
@@ -0,0 +1,228 @@
from typing import Any, List, Optional
import logging

# Import the base callback handler from utils
from ..utils.base_callback import CustomAsyncCallbackHandler
# Import LLM types for internal use only
from .base_client import LLMMessage, LLMResponse
# Import token usage repository
from ..utils.token_usage_repository import token_usage_repo

logger = logging.getLogger(__name__)


class RedisTokenCallbackHandler(CustomAsyncCallbackHandler):
    """
    Concrete token recording callback handler.
    Responsible for recording token usage after LLM calls by delegating to the repository.
    """

    def __init__(self, user_id: str, cycle_start_date: Optional[str] = None):
        if not user_id:
            raise ValueError("user_id must be provided for RedisTokenCallbackHandler")
        self.user_id = user_id
        self.cycle_start_date = cycle_start_date
        self.start_time = None
        self.messages = None

    async def on_llm_start(self, messages: List[dict], **kwargs: Any) -> None:
        """Triggered when LLM call starts"""
        import time
        self.start_time = time.time()
        self.messages = messages

        logger.info(f"[Callback] LLM call started for user '{self.user_id}' with {len(messages)} messages")

    async def on_llm_end(self, response: dict, **kwargs: Any) -> None:
        """Triggered when LLM call ends successfully"""
        try:
            # Record call duration
            if self.start_time:
                import time
                call_duration = time.time() - self.start_time
                logger.info(f"[Callback] LLM call completed for user '{self.user_id}' in {call_duration:.2f}s")

            # Extract token usage from response dictionary
            tokens_used = response.get("tokens_used")

            if tokens_used and tokens_used > 0:
                # Delegate recording work to repository
                await token_usage_repo.increment_total_usage(
                    self.user_id,
                    tokens_used,
                    self.cycle_start_date
                )

                logger.info(f"[Callback] Recorded {tokens_used} tokens for user '{self.user_id}'")
            else:
                logger.warning(f"[Callback] No token usage data available for user '{self.user_id}'")

        except Exception as e:
            logger.error(f"[Callback] Failed to record token usage for user '{self.user_id}': {e}")
            # Don't re-raise exception to avoid affecting main LLM call flow

    async def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        """Triggered when LLM call encounters an error"""
        if self.start_time:
            import time
            call_duration = time.time() - self.start_time
            logger.error(f"[Callback] LLM call failed for user '{self.user_id}' after {call_duration:.2f}s: {error}")
        else:
            logger.error(f"[Callback] LLM call failed for user '{self.user_id}': {error}")


class DetailedRedisTokenCallbackHandler(CustomAsyncCallbackHandler):
    """
    Detailed token recording callback handler.
    Records separate prompt and completion token usage in addition to total usage.
    """

    def __init__(self, user_id: str, cycle_start_date: Optional[str] = None):
        if not user_id:
            raise ValueError("user_id must be provided for DetailedRedisTokenCallbackHandler")
        self.user_id = user_id
        self.cycle_start_date = cycle_start_date
        self.start_time = None
        self.messages = None
        self.prompt_tokens = 0

    async def on_llm_start(self, messages: List[dict], **kwargs: Any) -> None:
        """Triggered when LLM call starts"""
        import time
        self.start_time = time.time()
        self.messages = messages

        # Estimate input token count
        self.prompt_tokens = self._estimate_prompt_tokens(messages)

        logger.info(f"[DetailedCallback] LLM call started for user '{self.user_id}' with estimated {self.prompt_tokens} prompt tokens")

    async def on_llm_end(self, response: dict, **kwargs: Any) -> None:
        """Triggered when LLM call ends successfully"""
        try:
            # Record call duration
            if self.start_time:
                import time
                call_duration = time.time() - self.start_time
                logger.info(f"[DetailedCallback] LLM call completed for user '{self.user_id}' in {call_duration:.2f}s")

            # Extract detailed token information from response
            prompt_tokens, completion_tokens = self._extract_detailed_tokens(response)

            # Ensure we have valid integers (not None)
            prompt_tokens = prompt_tokens or 0
            completion_tokens = completion_tokens or 0

            if prompt_tokens > 0 or completion_tokens > 0:
                # Use detailed token recording method
                await token_usage_repo.increment_detailed_usage(
                    self.user_id,
                    prompt_tokens,
                    completion_tokens,
                    self.cycle_start_date
                )

                logger.info(f"[DetailedCallback] Recorded detailed tokens for user '{self.user_id}': prompt={prompt_tokens}, completion={completion_tokens}")
            else:
                logger.warning(f"[DetailedCallback] No detailed token usage data available for user '{self.user_id}'")

        except Exception as e:
            logger.error(f"[DetailedCallback] Failed to record detailed token usage for user '{self.user_id}': {e}")
            # Don't re-raise exception to avoid affecting main LLM call flow

    async def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        """Triggered when LLM call encounters an error"""
        if self.start_time:
            import time
            call_duration = time.time() - self.start_time
            logger.error(f"[DetailedCallback] LLM call failed for user '{self.user_id}' after {call_duration:.2f}s: {error}")
        else:
            logger.error(f"[DetailedCallback] LLM call failed for user '{self.user_id}': {error}")

    def _estimate_prompt_tokens(self, messages: List[dict]) -> int:
        """Estimate token count for input messages"""
        total_chars = sum(len(msg.get('content', '')) for msg in messages)
        # Rough estimation: 4 characters ≈ 1 token
        return total_chars // 4

    def _extract_detailed_tokens(self, response: dict) -> tuple[int, int]:
        """
        Extract detailed token information from response dictionary

        Returns:
            tuple: (prompt_tokens, completion_tokens)
        """
        # If response has detailed token information, use it first
        prompt_tokens = response.get('prompt_tokens') or 0
        completion_tokens = response.get('completion_tokens') or 0

        if prompt_tokens > 0 and completion_tokens > 0:
            return prompt_tokens, completion_tokens

        # If only total token count is available, try to allocate
        tokens_used = response.get('tokens_used') or 0
        if tokens_used > 0:
            # Use previously estimated prompt tokens
            prompt_tokens = self.prompt_tokens
            completion_tokens = max(0, tokens_used - prompt_tokens)
            return prompt_tokens, completion_tokens

        # If no token information, try to estimate from response content
        content = response.get('content', '')
        if content:
            completion_tokens = len(content) // 4
            prompt_tokens = self.prompt_tokens
            return prompt_tokens, completion_tokens

        return 0, 0


class CompositeCallbackHandler(CustomAsyncCallbackHandler):
    """
    Composite callback handler that can execute multiple callback handlers simultaneously
    """

    def __init__(self, handlers: List[CustomAsyncCallbackHandler]):
        self.handlers = handlers or []

    def add_handler(self, handler: CustomAsyncCallbackHandler):
        """Add a callback handler"""
        self.handlers.append(handler)

    async def on_llm_start(self, messages: List[dict], **kwargs: Any) -> None:
        """Execute start callbacks for all handlers"""
        for handler in self.handlers:
            try:
                await handler.on_llm_start(messages, **kwargs)
            except Exception as e:
                logger.error(f"Error in callback handler {type(handler).__name__}.on_llm_start: {e}")

    async def on_llm_end(self, response: dict, **kwargs: Any) -> None:
        """Execute end callbacks for all handlers"""
        for handler in self.handlers:
            try:
                await handler.on_llm_end(response, **kwargs)
            except Exception as e:
                logger.error(f"Error in callback handler {type(handler).__name__}.on_llm_end: {e}")

    async def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        """Execute error callbacks for all handlers"""
        for handler in self.handlers:
            try:
                await handler.on_llm_error(error, **kwargs)
            except Exception as e:
                logger.error(f"Error in callback handler {type(handler).__name__}.on_llm_error: {e}")


# Convenience functions for creating common callback handlers
def create_token_callback(user_id: str, cycle_start_date: Optional[str] = None) -> RedisTokenCallbackHandler:
    """Create a basic token recording callback handler"""
    return RedisTokenCallbackHandler(user_id, cycle_start_date)

def create_detailed_token_callback(user_id: str, cycle_start_date: Optional[str] = None) -> DetailedRedisTokenCallbackHandler:
    """Create a detailed token recording callback handler"""
    return DetailedRedisTokenCallbackHandler(user_id, cycle_start_date)

def create_composite_callback(*handlers: CustomAsyncCallbackHandler) -> CompositeCallbackHandler:
    """Create a composite callback handler"""
    return CompositeCallbackHandler(list(handlers))
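The factory helpers at the bottom of custom_callbacks.py are meant to be composed: CompositeCallbackHandler fans each lifecycle hook out to its children and logs, rather than propagates, any individual failure. A minimal driving sketch follows; the user ID and the message/response dictionaries are placeholders, and a reachable Redis backend behind token_usage_repo is assumed for the increments to actually persist.

import asyncio

from aiecs.llm.custom_callbacks import (
    create_composite_callback,
    create_detailed_token_callback,
    create_token_callback,
)


async def main() -> None:
    # Record both the coarse total and the prompt/completion split for one user.
    # "user-123" is a placeholder user ID.
    callback = create_composite_callback(
        create_token_callback(user_id="user-123"),
        create_detailed_token_callback(user_id="user-123"),
    )

    # An LLM client would normally drive these hooks; calling them directly
    # shows the expected shapes: plain dicts for messages and for the response.
    await callback.on_llm_start([{"role": "user", "content": "Hello"}])
    await callback.on_llm_end(
        {"content": "Hi!", "tokens_used": 42, "prompt_tokens": 10, "completion_tokens": 32}
    )


asyncio.run(main())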
aiecs/llm/openai_client.py
@@ -0,0 +1,125 @@
import asyncio
import logging
from typing import Dict, Any, Optional, List, AsyncGenerator
from openai import AsyncOpenAI
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
import httpx

from .base_client import BaseLLMClient, LLMMessage, LLMResponse, ProviderNotAvailableError, RateLimitError
from aiecs.config.config import get_settings

logger = logging.getLogger(__name__)

class OpenAIClient(BaseLLMClient):
    """OpenAI provider client"""

    def __init__(self):
        super().__init__("OpenAI")
        self.settings = get_settings()
        self._client: Optional[AsyncOpenAI] = None

        # Token cost estimates (USD per 1K tokens)
        self.token_costs = {
            "gpt-4": {"input": 0.03, "output": 0.06},
            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
            "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
            "gpt-4o": {"input": 0.005, "output": 0.015},
            "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
        }

    def _get_client(self) -> AsyncOpenAI:
        """Lazy initialization of OpenAI client"""
        if not self._client:
            if not self.settings.openai_api_key:
                raise ProviderNotAvailableError("OpenAI API key not configured")
            self._client = AsyncOpenAI(api_key=self.settings.openai_api_key)
        return self._client

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type((httpx.RequestError, RateLimitError))
    )
    async def generate_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Generate text using OpenAI API"""
        client = self._get_client()
        model = model or "gpt-4-turbo"

        # Convert to OpenAI message format
        openai_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        try:
            response = await client.chat.completions.create(
                model=model,
                messages=openai_messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )

            content = response.choices[0].message.content
            tokens_used = response.usage.total_tokens if response.usage else None

            # Estimate cost
            input_tokens = response.usage.prompt_tokens if response.usage else 0
            output_tokens = response.usage.completion_tokens if response.usage else 0
            cost = self._estimate_cost(model, input_tokens, output_tokens, self.token_costs)

            return LLMResponse(
                content=content,
                provider=self.provider_name,
                model=model,
                tokens_used=tokens_used,
                cost_estimate=cost
            )

        except Exception as e:
            if "rate_limit" in str(e).lower():
                raise RateLimitError(f"OpenAI rate limit exceeded: {str(e)}")
            raise

    async def stream_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Stream text using OpenAI API"""
        client = self._get_client()
        model = model or "gpt-4-turbo"

        openai_messages = [{"role": msg.role, "content": msg.content} for msg in messages]

        try:
            stream = await client.chat.completions.create(
                model=model,
                messages=openai_messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True,
                **kwargs
            )

            async for chunk in stream:
                if chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content

        except Exception as e:
            if "rate_limit" in str(e).lower():
                raise RateLimitError(f"OpenAI rate limit exceeded: {str(e)}")
            raise

    async def close(self):
        """Clean up resources"""
        if self._client:
            await self._client.close()
            self._client = None
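Driving the client directly is straightforward once an OpenAI API key is present in the aiecs settings. A sketch under two assumptions: openai_api_key is configured, and LLMMessage in base_client can be constructed with role and content keyword arguments (only those two attributes are read by the client).

import asyncio

from aiecs.llm.base_client import LLMMessage
from aiecs.llm.openai_client import OpenAIClient


async def main() -> None:
    client = OpenAIClient()
    # LLMMessage(role=..., content=...) is assumed from how the client reads .role/.content.
    messages = [LLMMessage(role="user", content="Explain retries in one sentence.")]

    # One-shot completion; the model defaults to "gpt-4-turbo" when omitted.
    response = await client.generate_text(messages, model="gpt-4o-mini", max_tokens=200)
    print(response.content, response.tokens_used, response.cost_estimate)

    # Streaming variant: chunks arrive as plain strings.
    async for chunk in client.stream_text(messages, model="gpt-4o-mini"):
        print(chunk, end="", flush=True)

    await client.close()


asyncio.run(main())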
aiecs/llm/vertex_client.py
@@ -0,0 +1,186 @@
import asyncio
import logging
import os
from typing import Dict, Any, Optional, List, AsyncGenerator
from vertexai.generative_models import GenerativeModel, HarmCategory, HarmBlockThreshold
import vertexai
from google.oauth2 import service_account

from .base_client import BaseLLMClient, LLMMessage, LLMResponse, ProviderNotAvailableError, RateLimitError
from aiecs.config.config import get_settings

logger = logging.getLogger(__name__)

class VertexAIClient(BaseLLMClient):
    """Vertex AI provider client"""

    def __init__(self):
        super().__init__("Vertex")
        self.settings = get_settings()
        self._initialized = False

        # Token cost estimates (USD per 1K tokens)
        self.token_costs = {
            "gemini-2.5-pro": {"input": 0.00125, "output": 0.00375},
            "gemini-2.5-flash": {"input": 0.000075, "output": 0.0003},
        }

    def _init_vertex_ai(self):
        """Lazy initialization of Vertex AI with proper authentication"""
        if not self._initialized:
            if not self.settings.vertex_project_id:
                raise ProviderNotAvailableError("Vertex AI project ID not configured")

            try:
                # Set up Google Cloud authentication
                credentials = None

                # Check if GOOGLE_APPLICATION_CREDENTIALS is configured
                if self.settings.google_application_credentials:
                    credentials_path = self.settings.google_application_credentials
                    if os.path.exists(credentials_path):
                        # Set the environment variable for Google Cloud SDK
                        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path
                        self.logger.info(f"Using Google Cloud credentials from: {credentials_path}")
                    else:
                        self.logger.warning(f"Google Cloud credentials file not found: {credentials_path}")
                        raise ProviderNotAvailableError(f"Google Cloud credentials file not found: {credentials_path}")
                elif 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
                    self.logger.info("Using Google Cloud credentials from environment variable")
                else:
                    self.logger.warning("No Google Cloud credentials configured. Using default authentication.")

                # Initialize Vertex AI
                vertexai.init(
                    project=self.settings.vertex_project_id,
                    location=getattr(self.settings, 'vertex_location', 'us-central1')
                )
                self._initialized = True
                self.logger.info(f"Vertex AI initialized for project {self.settings.vertex_project_id}")

            except Exception as e:
                raise ProviderNotAvailableError(f"Failed to initialize Vertex AI: {str(e)}")

    async def generate_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> LLMResponse:
        """Generate text using Vertex AI"""
        self._init_vertex_ai()
        model_name = model or "gemini-2.5-pro"

        try:
            # Use the stable Vertex AI API
            model_instance = GenerativeModel(model_name)

            # Convert messages to Vertex AI format
            if len(messages) == 1 and messages[0].role == "user":
                prompt = messages[0].content
            else:
                # For multi-turn conversations, combine messages
                prompt = "\n".join([f"{msg.role}: {msg.content}" for msg in messages])

            response = await asyncio.get_event_loop().run_in_executor(
                None,
                lambda: model_instance.generate_content(
                    prompt,
                    generation_config={
                        "temperature": temperature,
                        "max_output_tokens": max_tokens or 8192,  # Increased to account for thinking tokens
                        "top_p": 0.95,
                        "top_k": 40,
                    },
                    safety_settings={
                        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    }
                )
            )

            # Handle response content safely
            try:
                content = response.text
                self.logger.debug(f"Vertex AI response received: {content[:100]}...")
            except ValueError as ve:
                # Handle cases where response has no content (safety filters, etc.)
                self.logger.warning(f"Vertex AI response error: {str(ve)}")
                self.logger.debug(f"Full response object: {response}")

                # Check if response has candidates but no text
                if hasattr(response, 'candidates') and response.candidates:
                    candidate = response.candidates[0]
                    self.logger.debug(f"Candidate finish_reason: {getattr(candidate, 'finish_reason', 'unknown')}")

                    # If finish_reason is MAX_TOKENS, it might be due to thinking tokens
                    if hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'MAX_TOKENS':
                        content = "[Response truncated due to token limit - consider increasing max_tokens for Gemini 2.5 models]"
                        self.logger.warning("Response truncated due to MAX_TOKENS - Gemini 2.5 uses thinking tokens")
                    elif "no parts" in str(ve).lower() or "safety filters" in str(ve).lower():
                        content = "[Response blocked by safety filters or has no content]"
                        self.logger.warning(f"Vertex AI response blocked or empty: {str(ve)}")
                    else:
                        content = f"[Response error: {str(ve)}]"
                else:
                    content = f"[Response error: {str(ve)}]"

            # Vertex AI doesn't provide detailed token usage in the response
            tokens_used = self._count_tokens_estimate(prompt + content)
            cost = self._estimate_cost(
                model_name,
                self._count_tokens_estimate(prompt),
                self._count_tokens_estimate(content),
                self.token_costs
            )

            return LLMResponse(
                content=content,
                provider=self.provider_name,
                model=model_name,
                tokens_used=tokens_used,
                cost_estimate=cost
            )

        except Exception as e:
            if "quota" in str(e).lower() or "limit" in str(e).lower():
                raise RateLimitError(f"Vertex AI quota exceeded: {str(e)}")
            # Handle specific Vertex AI response errors
            if "cannot get the response text" in str(e).lower() or "safety filters" in str(e).lower():
                self.logger.warning(f"Vertex AI response issue: {str(e)}")
                # Return a response indicating the issue
                return LLMResponse(
                    content="[Response unavailable due to safety filters or content policy]",
                    provider=self.provider_name,
                    model=model_name,
                    tokens_used=self._count_tokens_estimate(prompt),
                    cost_estimate=0.0
                )
            raise

    async def stream_text(
        self,
        messages: List[LLMMessage],
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> AsyncGenerator[str, None]:
        """Stream text using Vertex AI (simulated streaming)"""
        # Vertex AI streaming is more complex, for now fall back to non-streaming
        response = await self.generate_text(messages, model, temperature, max_tokens, **kwargs)

        # Simulate streaming by yielding words
        words = response.content.split()
        for word in words:
            yield word + " "
            await asyncio.sleep(0.05)  # Small delay to simulate streaming

    async def close(self):
        """Clean up resources"""
        # Vertex AI doesn't require explicit cleanup
        self._initialized = False
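The Vertex client follows the same interface, but initialization is lazy and authentication-sensitive: the first generate_text call runs _init_vertex_ai(), so a missing vertex_project_id or credentials file surfaces there as ProviderNotAvailableError. A sketch with the same LLMMessage keyword-construction assumption as above, plus a configured Google Cloud project and credentials:

import asyncio

from aiecs.llm.base_client import LLMMessage
from aiecs.llm.vertex_client import VertexAIClient


async def main() -> None:
    client = VertexAIClient()
    # LLMMessage(role=..., content=...) is assumed from how the client reads .role/.content.
    messages = [LLMMessage(role="user", content="One sentence about Gemini, please.")]

    # Defaults to "gemini-2.5-pro" when no model is given; a generous max_tokens
    # leaves room for Gemini 2.5 thinking tokens, as the MAX_TOKENS handling above warns.
    response = await client.generate_text(messages, model="gemini-2.5-flash", max_tokens=2048)
    print(response.content, response.cost_estimate)

    # stream_text is simulated: the full response is generated first,
    # then re-emitted word by word with a small delay.
    async for chunk in client.stream_text(messages, model="gemini-2.5-flash"):
        print(chunk, end="", flush=True)

    await client.close()


asyncio.run(main())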