aiecs-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of aiecs has been flagged by the registry; see its advisory page for details.

Files changed (90)
  1. aiecs/__init__.py +75 -0
  2. aiecs/__main__.py +41 -0
  3. aiecs/aiecs_client.py +295 -0
  4. aiecs/application/__init__.py +10 -0
  5. aiecs/application/executors/__init__.py +10 -0
  6. aiecs/application/executors/operation_executor.py +341 -0
  7. aiecs/config/__init__.py +15 -0
  8. aiecs/config/config.py +117 -0
  9. aiecs/config/registry.py +19 -0
  10. aiecs/core/__init__.py +46 -0
  11. aiecs/core/interface/__init__.py +34 -0
  12. aiecs/core/interface/execution_interface.py +150 -0
  13. aiecs/core/interface/storage_interface.py +214 -0
  14. aiecs/domain/__init__.py +20 -0
  15. aiecs/domain/context/__init__.py +28 -0
  16. aiecs/domain/context/content_engine.py +982 -0
  17. aiecs/domain/context/conversation_models.py +306 -0
  18. aiecs/domain/execution/__init__.py +12 -0
  19. aiecs/domain/execution/model.py +49 -0
  20. aiecs/domain/task/__init__.py +13 -0
  21. aiecs/domain/task/dsl_processor.py +460 -0
  22. aiecs/domain/task/model.py +50 -0
  23. aiecs/domain/task/task_context.py +257 -0
  24. aiecs/infrastructure/__init__.py +26 -0
  25. aiecs/infrastructure/messaging/__init__.py +13 -0
  26. aiecs/infrastructure/messaging/celery_task_manager.py +341 -0
  27. aiecs/infrastructure/messaging/websocket_manager.py +289 -0
  28. aiecs/infrastructure/monitoring/__init__.py +12 -0
  29. aiecs/infrastructure/monitoring/executor_metrics.py +138 -0
  30. aiecs/infrastructure/monitoring/structured_logger.py +50 -0
  31. aiecs/infrastructure/monitoring/tracing_manager.py +376 -0
  32. aiecs/infrastructure/persistence/__init__.py +12 -0
  33. aiecs/infrastructure/persistence/database_manager.py +286 -0
  34. aiecs/infrastructure/persistence/file_storage.py +671 -0
  35. aiecs/infrastructure/persistence/redis_client.py +162 -0
  36. aiecs/llm/__init__.py +54 -0
  37. aiecs/llm/base_client.py +99 -0
  38. aiecs/llm/client_factory.py +339 -0
  39. aiecs/llm/custom_callbacks.py +228 -0
  40. aiecs/llm/openai_client.py +125 -0
  41. aiecs/llm/vertex_client.py +186 -0
  42. aiecs/llm/xai_client.py +184 -0
  43. aiecs/main.py +351 -0
  44. aiecs/scripts/DEPENDENCY_SYSTEM_SUMMARY.md +241 -0
  45. aiecs/scripts/README_DEPENDENCY_CHECKER.md +309 -0
  46. aiecs/scripts/README_WEASEL_PATCH.md +126 -0
  47. aiecs/scripts/__init__.py +3 -0
  48. aiecs/scripts/dependency_checker.py +825 -0
  49. aiecs/scripts/dependency_fixer.py +348 -0
  50. aiecs/scripts/download_nlp_data.py +348 -0
  51. aiecs/scripts/fix_weasel_validator.py +121 -0
  52. aiecs/scripts/fix_weasel_validator.sh +82 -0
  53. aiecs/scripts/patch_weasel_library.sh +188 -0
  54. aiecs/scripts/quick_dependency_check.py +269 -0
  55. aiecs/scripts/run_weasel_patch.sh +41 -0
  56. aiecs/scripts/setup_nlp_data.sh +217 -0
  57. aiecs/tasks/__init__.py +2 -0
  58. aiecs/tasks/worker.py +111 -0
  59. aiecs/tools/__init__.py +196 -0
  60. aiecs/tools/base_tool.py +202 -0
  61. aiecs/tools/langchain_adapter.py +361 -0
  62. aiecs/tools/task_tools/__init__.py +82 -0
  63. aiecs/tools/task_tools/chart_tool.py +704 -0
  64. aiecs/tools/task_tools/classfire_tool.py +901 -0
  65. aiecs/tools/task_tools/image_tool.py +397 -0
  66. aiecs/tools/task_tools/office_tool.py +600 -0
  67. aiecs/tools/task_tools/pandas_tool.py +565 -0
  68. aiecs/tools/task_tools/report_tool.py +499 -0
  69. aiecs/tools/task_tools/research_tool.py +363 -0
  70. aiecs/tools/task_tools/scraper_tool.py +548 -0
  71. aiecs/tools/task_tools/search_api.py +7 -0
  72. aiecs/tools/task_tools/stats_tool.py +513 -0
  73. aiecs/tools/temp_file_manager.py +126 -0
  74. aiecs/tools/tool_executor/__init__.py +35 -0
  75. aiecs/tools/tool_executor/tool_executor.py +518 -0
  76. aiecs/utils/LLM_output_structor.py +409 -0
  77. aiecs/utils/__init__.py +23 -0
  78. aiecs/utils/base_callback.py +50 -0
  79. aiecs/utils/execution_utils.py +158 -0
  80. aiecs/utils/logging.py +1 -0
  81. aiecs/utils/prompt_loader.py +13 -0
  82. aiecs/utils/token_usage_repository.py +279 -0
  83. aiecs/ws/__init__.py +0 -0
  84. aiecs/ws/socket_server.py +41 -0
  85. aiecs-1.0.0.dist-info/METADATA +610 -0
  86. aiecs-1.0.0.dist-info/RECORD +90 -0
  87. aiecs-1.0.0.dist-info/WHEEL +5 -0
  88. aiecs-1.0.0.dist-info/entry_points.txt +7 -0
  89. aiecs-1.0.0.dist-info/licenses/LICENSE +225 -0
  90. aiecs-1.0.0.dist-info/top_level.txt +1 -0
aiecs/llm/custom_callbacks.py
@@ -0,0 +1,228 @@
+from typing import Any, List, Optional
+import logging
+
+# Import the base callback handler from utils
+from ..utils.base_callback import CustomAsyncCallbackHandler
+# Import LLM types for internal use only
+from .base_client import LLMMessage, LLMResponse
+# Import token usage repository
+from ..utils.token_usage_repository import token_usage_repo
+
+logger = logging.getLogger(__name__)
+
+
+class RedisTokenCallbackHandler(CustomAsyncCallbackHandler):
+    """
+    Concrete token recording callback handler.
+    Responsible for recording token usage after LLM calls by delegating to the repository.
+    """
+
+    def __init__(self, user_id: str, cycle_start_date: Optional[str] = None):
+        if not user_id:
+            raise ValueError("user_id must be provided for RedisTokenCallbackHandler")
+        self.user_id = user_id
+        self.cycle_start_date = cycle_start_date
+        self.start_time = None
+        self.messages = None
+
+    async def on_llm_start(self, messages: List[dict], **kwargs: Any) -> None:
+        """Triggered when LLM call starts"""
+        import time
+        self.start_time = time.time()
+        self.messages = messages
+
+        logger.info(f"[Callback] LLM call started for user '{self.user_id}' with {len(messages)} messages")
+
+    async def on_llm_end(self, response: dict, **kwargs: Any) -> None:
+        """Triggered when LLM call ends successfully"""
+        try:
+            # Record call duration
+            if self.start_time:
+                import time
+                call_duration = time.time() - self.start_time
+                logger.info(f"[Callback] LLM call completed for user '{self.user_id}' in {call_duration:.2f}s")
+
+            # Extract token usage from response dictionary
+            tokens_used = response.get("tokens_used")
+
+            if tokens_used and tokens_used > 0:
+                # Delegate recording work to repository
+                await token_usage_repo.increment_total_usage(
+                    self.user_id,
+                    tokens_used,
+                    self.cycle_start_date
+                )
+
+                logger.info(f"[Callback] Recorded {tokens_used} tokens for user '{self.user_id}'")
+            else:
+                logger.warning(f"[Callback] No token usage data available for user '{self.user_id}'")
+
+        except Exception as e:
+            logger.error(f"[Callback] Failed to record token usage for user '{self.user_id}': {e}")
+            # Don't re-raise exception to avoid affecting main LLM call flow
+
+    async def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
+        """Triggered when LLM call encounters an error"""
+        if self.start_time:
+            import time
+            call_duration = time.time() - self.start_time
+            logger.error(f"[Callback] LLM call failed for user '{self.user_id}' after {call_duration:.2f}s: {error}")
+        else:
+            logger.error(f"[Callback] LLM call failed for user '{self.user_id}': {error}")
+
+
+class DetailedRedisTokenCallbackHandler(CustomAsyncCallbackHandler):
+    """
+    Detailed token recording callback handler.
+    Records separate prompt and completion token usage in addition to total usage.
+    """
+
+    def __init__(self, user_id: str, cycle_start_date: Optional[str] = None):
+        if not user_id:
+            raise ValueError("user_id must be provided for DetailedRedisTokenCallbackHandler")
+        self.user_id = user_id
+        self.cycle_start_date = cycle_start_date
+        self.start_time = None
+        self.messages = None
+        self.prompt_tokens = 0
+
+    async def on_llm_start(self, messages: List[dict], **kwargs: Any) -> None:
+        """Triggered when LLM call starts"""
+        import time
+        self.start_time = time.time()
+        self.messages = messages
+
+        # Estimate input token count
+        self.prompt_tokens = self._estimate_prompt_tokens(messages)
+
+        logger.info(f"[DetailedCallback] LLM call started for user '{self.user_id}' with estimated {self.prompt_tokens} prompt tokens")
+
+    async def on_llm_end(self, response: dict, **kwargs: Any) -> None:
+        """Triggered when LLM call ends successfully"""
+        try:
+            # Record call duration
+            if self.start_time:
+                import time
+                call_duration = time.time() - self.start_time
+                logger.info(f"[DetailedCallback] LLM call completed for user '{self.user_id}' in {call_duration:.2f}s")
+
+            # Extract detailed token information from response
+            prompt_tokens, completion_tokens = self._extract_detailed_tokens(response)
+
+            # Ensure we have valid integers (not None)
+            prompt_tokens = prompt_tokens or 0
+            completion_tokens = completion_tokens or 0
+
+            if prompt_tokens > 0 or completion_tokens > 0:
+                # Use detailed token recording method
+                await token_usage_repo.increment_detailed_usage(
+                    self.user_id,
+                    prompt_tokens,
+                    completion_tokens,
+                    self.cycle_start_date
+                )
+
+                logger.info(f"[DetailedCallback] Recorded detailed tokens for user '{self.user_id}': prompt={prompt_tokens}, completion={completion_tokens}")
+            else:
+                logger.warning(f"[DetailedCallback] No detailed token usage data available for user '{self.user_id}'")
+
+        except Exception as e:
+            logger.error(f"[DetailedCallback] Failed to record detailed token usage for user '{self.user_id}': {e}")
+            # Don't re-raise exception to avoid affecting main LLM call flow
+
+    async def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
+        """Triggered when LLM call encounters an error"""
+        if self.start_time:
+            import time
+            call_duration = time.time() - self.start_time
+            logger.error(f"[DetailedCallback] LLM call failed for user '{self.user_id}' after {call_duration:.2f}s: {error}")
+        else:
+            logger.error(f"[DetailedCallback] LLM call failed for user '{self.user_id}': {error}")
+
+    def _estimate_prompt_tokens(self, messages: List[dict]) -> int:
+        """Estimate token count for input messages"""
+        total_chars = sum(len(msg.get('content', '')) for msg in messages)
+        # Rough estimation: 4 characters ≈ 1 token
+        return total_chars // 4
+
+    def _extract_detailed_tokens(self, response: dict) -> tuple[int, int]:
+        """
+        Extract detailed token information from response dictionary
+
+        Returns:
+            tuple: (prompt_tokens, completion_tokens)
+        """
+        # If response has detailed token information, use it first
+        prompt_tokens = response.get('prompt_tokens') or 0
+        completion_tokens = response.get('completion_tokens') or 0
+
+        if prompt_tokens > 0 and completion_tokens > 0:
+            return prompt_tokens, completion_tokens
+
+        # If only total token count is available, try to allocate
+        tokens_used = response.get('tokens_used') or 0
+        if tokens_used > 0:
+            # Use previously estimated prompt tokens
+            prompt_tokens = self.prompt_tokens
+            completion_tokens = max(0, tokens_used - prompt_tokens)
+            return prompt_tokens, completion_tokens
+
+        # If no token information, try to estimate from response content
+        content = response.get('content', '')
+        if content:
+            completion_tokens = len(content) // 4
+            prompt_tokens = self.prompt_tokens
+            return prompt_tokens, completion_tokens
+
+        return 0, 0
+
+
+class CompositeCallbackHandler(CustomAsyncCallbackHandler):
+    """
+    Composite callback handler that can execute multiple callback handlers simultaneously
+    """
+
+    def __init__(self, handlers: List[CustomAsyncCallbackHandler]):
+        self.handlers = handlers or []
+
+    def add_handler(self, handler: CustomAsyncCallbackHandler):
+        """Add a callback handler"""
+        self.handlers.append(handler)
+
+    async def on_llm_start(self, messages: List[dict], **kwargs: Any) -> None:
+        """Execute start callbacks for all handlers"""
+        for handler in self.handlers:
+            try:
+                await handler.on_llm_start(messages, **kwargs)
+            except Exception as e:
+                logger.error(f"Error in callback handler {type(handler).__name__}.on_llm_start: {e}")
+
+    async def on_llm_end(self, response: dict, **kwargs: Any) -> None:
+        """Execute end callbacks for all handlers"""
+        for handler in self.handlers:
+            try:
+                await handler.on_llm_end(response, **kwargs)
+            except Exception as e:
+                logger.error(f"Error in callback handler {type(handler).__name__}.on_llm_end: {e}")
+
+    async def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
+        """Execute error callbacks for all handlers"""
+        for handler in self.handlers:
+            try:
+                await handler.on_llm_error(error, **kwargs)
+            except Exception as e:
+                logger.error(f"Error in callback handler {type(handler).__name__}.on_llm_error: {e}")
+
+
+# Convenience functions for creating common callback handlers
+def create_token_callback(user_id: str, cycle_start_date: Optional[str] = None) -> RedisTokenCallbackHandler:
+    """Create a basic token recording callback handler"""
+    return RedisTokenCallbackHandler(user_id, cycle_start_date)
+
+def create_detailed_token_callback(user_id: str, cycle_start_date: Optional[str] = None) -> DetailedRedisTokenCallbackHandler:
+    """Create a detailed token recording callback handler"""
+    return DetailedRedisTokenCallbackHandler(user_id, cycle_start_date)
+
+def create_composite_callback(*handlers: CustomAsyncCallbackHandler) -> CompositeCallbackHandler:
+    """Create a composite callback handler"""
+    return CompositeCallbackHandler(list(handlers))
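For orientation, here is a minimal usage sketch of the factory helpers above. It assumes the module is importable as aiecs.llm.custom_callbacks (file 39 in the listing) and that the Redis-backed token_usage_repo it delegates to is reachable; the explicit calls to on_llm_start/on_llm_end stand in for whatever LLM client would normally drive these hooks.

import asyncio
from aiecs.llm.custom_callbacks import (
    create_token_callback,
    create_detailed_token_callback,
    create_composite_callback,
)

async def demo():
    # Combine a total-usage recorder and a detailed recorder for one user.
    callback = create_composite_callback(
        create_token_callback(user_id="user-123"),
        create_detailed_token_callback(user_id="user-123"),
    )

    # An LLM client would normally invoke these hooks around its API call.
    # The response dict keys mirror what the handlers read: tokens_used,
    # prompt_tokens, completion_tokens, content.
    await callback.on_llm_start([{"role": "user", "content": "Hello"}])
    await callback.on_llm_end({
        "content": "Hi there!",
        "tokens_used": 12,
        "prompt_tokens": 5,
        "completion_tokens": 7,
    })  # recording requires the Redis backend behind token_usage_repo

asyncio.run(demo())

Note that CompositeCallbackHandler logs and swallows per-handler exceptions, so one failing recorder does not block the others.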
aiecs/llm/openai_client.py
@@ -0,0 +1,125 @@
+import asyncio
+import logging
+from typing import Dict, Any, Optional, List, AsyncGenerator
+from openai import AsyncOpenAI
+from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
+import httpx
+
+from .base_client import BaseLLMClient, LLMMessage, LLMResponse, ProviderNotAvailableError, RateLimitError
+from aiecs.config.config import get_settings
+
+logger = logging.getLogger(__name__)
+
+class OpenAIClient(BaseLLMClient):
+    """OpenAI provider client"""
+
+    def __init__(self):
+        super().__init__("OpenAI")
+        self.settings = get_settings()
+        self._client: Optional[AsyncOpenAI] = None
+
+        # Token cost estimates (USD per 1K tokens)
+        self.token_costs = {
+            "gpt-4": {"input": 0.03, "output": 0.06},
+            "gpt-4-turbo": {"input": 0.01, "output": 0.03},
+            "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
+            "gpt-4o": {"input": 0.005, "output": 0.015},
+            "gpt-4o-mini": {"input": 0.00015, "output": 0.0006},
+        }
+
+    def _get_client(self) -> AsyncOpenAI:
+        """Lazy initialization of OpenAI client"""
+        if not self._client:
+            if not self.settings.openai_api_key:
+                raise ProviderNotAvailableError("OpenAI API key not configured")
+            self._client = AsyncOpenAI(api_key=self.settings.openai_api_key)
+        return self._client
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=4, max=10),
+        retry=retry_if_exception_type((httpx.RequestError, RateLimitError))
+    )
+    async def generate_text(
+        self,
+        messages: List[LLMMessage],
+        model: Optional[str] = None,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        **kwargs
+    ) -> LLMResponse:
+        """Generate text using OpenAI API"""
+        client = self._get_client()
+        model = model or "gpt-4-turbo"
+
+        # Convert to OpenAI message format
+        openai_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
+
+        try:
+            response = await client.chat.completions.create(
+                model=model,
+                messages=openai_messages,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                **kwargs
+            )
+
+            content = response.choices[0].message.content
+            tokens_used = response.usage.total_tokens if response.usage else None
+
+            # Estimate cost
+            input_tokens = response.usage.prompt_tokens if response.usage else 0
+            output_tokens = response.usage.completion_tokens if response.usage else 0
+            cost = self._estimate_cost(model, input_tokens, output_tokens, self.token_costs)
+
+            return LLMResponse(
+                content=content,
+                provider=self.provider_name,
+                model=model,
+                tokens_used=tokens_used,
+                cost_estimate=cost
+            )
+
+        except Exception as e:
+            if "rate_limit" in str(e).lower():
+                raise RateLimitError(f"OpenAI rate limit exceeded: {str(e)}")
+            raise
+
+    async def stream_text(
+        self,
+        messages: List[LLMMessage],
+        model: Optional[str] = None,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Stream text using OpenAI API"""
+        client = self._get_client()
+        model = model or "gpt-4-turbo"
+
+        openai_messages = [{"role": msg.role, "content": msg.content} for msg in messages]
+
+        try:
+            stream = await client.chat.completions.create(
+                model=model,
+                messages=openai_messages,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                stream=True,
+                **kwargs
+            )
+
+            async for chunk in stream:
+                if chunk.choices[0].delta.content:
+                    yield chunk.choices[0].delta.content
+
+        except Exception as e:
+            if "rate_limit" in str(e).lower():
+                raise RateLimitError(f"OpenAI rate limit exceeded: {str(e)}")
+            raise
+
+    async def close(self):
+        """Clean up resources"""
+        if self._client:
+            await self._client.close()
+            self._client = None
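A minimal call sketch for the client above, assuming the package settings expose a configured openai_api_key and that LLMMessage (from aiecs/llm/base_client.py, file 37) accepts role and content keyword arguments, as its usage in generate_text suggests:

import asyncio
from aiecs.llm.openai_client import OpenAIClient
from aiecs.llm.base_client import LLMMessage

async def demo():
    client = OpenAIClient()
    try:
        # generate_text converts LLMMessage objects into OpenAI chat messages
        # and returns an LLMResponse with content, token and cost estimates.
        response = await client.generate_text(
            messages=[LLMMessage(role="user", content="Summarize what tenacity's retry decorator does.")],
            model="gpt-4o-mini",  # any key present in client.token_costs
            temperature=0.2,
            max_tokens=200,
        )
        print(response.content)
        print(response.tokens_used, response.cost_estimate)
    finally:
        await client.close()

asyncio.run(demo())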
aiecs/llm/vertex_client.py
@@ -0,0 +1,186 @@
+import asyncio
+import logging
+import os
+from typing import Dict, Any, Optional, List, AsyncGenerator
+from vertexai.generative_models import GenerativeModel, HarmCategory, HarmBlockThreshold
+import vertexai
+from google.oauth2 import service_account
+
+from .base_client import BaseLLMClient, LLMMessage, LLMResponse, ProviderNotAvailableError, RateLimitError
+from aiecs.config.config import get_settings
+
+logger = logging.getLogger(__name__)
+
+class VertexAIClient(BaseLLMClient):
+    """Vertex AI provider client"""
+
+    def __init__(self):
+        super().__init__("Vertex")
+        self.settings = get_settings()
+        self._initialized = False
+
+        # Token cost estimates (USD per 1K tokens)
+        self.token_costs = {
+            "gemini-2.5-pro": {"input": 0.00125, "output": 0.00375},
+            "gemini-2.5-flash": {"input": 0.000075, "output": 0.0003},
+        }
+
+    def _init_vertex_ai(self):
+        """Lazy initialization of Vertex AI with proper authentication"""
+        if not self._initialized:
+            if not self.settings.vertex_project_id:
+                raise ProviderNotAvailableError("Vertex AI project ID not configured")
+
+            try:
+                # Set up Google Cloud authentication
+                credentials = None
+
+                # Check if GOOGLE_APPLICATION_CREDENTIALS is configured
+                if self.settings.google_application_credentials:
+                    credentials_path = self.settings.google_application_credentials
+                    if os.path.exists(credentials_path):
+                        # Set the environment variable for Google Cloud SDK
+                        os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = credentials_path
+                        self.logger.info(f"Using Google Cloud credentials from: {credentials_path}")
+                    else:
+                        self.logger.warning(f"Google Cloud credentials file not found: {credentials_path}")
+                        raise ProviderNotAvailableError(f"Google Cloud credentials file not found: {credentials_path}")
+                elif 'GOOGLE_APPLICATION_CREDENTIALS' in os.environ:
+                    self.logger.info("Using Google Cloud credentials from environment variable")
+                else:
+                    self.logger.warning("No Google Cloud credentials configured. Using default authentication.")
+
+                # Initialize Vertex AI
+                vertexai.init(
+                    project=self.settings.vertex_project_id,
+                    location=getattr(self.settings, 'vertex_location', 'us-central1')
+                )
+                self._initialized = True
+                self.logger.info(f"Vertex AI initialized for project {self.settings.vertex_project_id}")
+
+            except Exception as e:
+                raise ProviderNotAvailableError(f"Failed to initialize Vertex AI: {str(e)}")
+
+    async def generate_text(
+        self,
+        messages: List[LLMMessage],
+        model: Optional[str] = None,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        **kwargs
+    ) -> LLMResponse:
+        """Generate text using Vertex AI"""
+        self._init_vertex_ai()
+        model_name = model or "gemini-2.5-pro"
+
+        try:
+            # Use the stable Vertex AI API
+            model_instance = GenerativeModel(model_name)
+
+            # Convert messages to Vertex AI format
+            if len(messages) == 1 and messages[0].role == "user":
+                prompt = messages[0].content
+            else:
+                # For multi-turn conversations, combine messages
+                prompt = "\n".join([f"{msg.role}: {msg.content}" for msg in messages])
+
+            response = await asyncio.get_event_loop().run_in_executor(
+                None,
+                lambda: model_instance.generate_content(
+                    prompt,
+                    generation_config={
+                        "temperature": temperature,
+                        "max_output_tokens": max_tokens or 8192,  # Increased to account for thinking tokens
+                        "top_p": 0.95,
+                        "top_k": 40,
+                    },
+                    safety_settings={
+                        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+                        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+                        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+                        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+                    }
+                )
+            )
+
+            # Handle response content safely
+            try:
+                content = response.text
+                self.logger.debug(f"Vertex AI response received: {content[:100]}...")
+            except ValueError as ve:
+                # Handle cases where response has no content (safety filters, etc.)
+                self.logger.warning(f"Vertex AI response error: {str(ve)}")
+                self.logger.debug(f"Full response object: {response}")
+
+                # Check if response has candidates but no text
+                if hasattr(response, 'candidates') and response.candidates:
+                    candidate = response.candidates[0]
+                    self.logger.debug(f"Candidate finish_reason: {getattr(candidate, 'finish_reason', 'unknown')}")
+
+                    # If finish_reason is MAX_TOKENS, it might be due to thinking tokens
+                    if hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'MAX_TOKENS':
+                        content = "[Response truncated due to token limit - consider increasing max_tokens for Gemini 2.5 models]"
+                        self.logger.warning("Response truncated due to MAX_TOKENS - Gemini 2.5 uses thinking tokens")
+                    elif "no parts" in str(ve).lower() or "safety filters" in str(ve).lower():
+                        content = "[Response blocked by safety filters or has no content]"
+                        self.logger.warning(f"Vertex AI response blocked or empty: {str(ve)}")
+                    else:
+                        content = f"[Response error: {str(ve)}]"
+                else:
+                    content = f"[Response error: {str(ve)}]"
+
+            # Vertex AI doesn't provide detailed token usage in the response
+            tokens_used = self._count_tokens_estimate(prompt + content)
+            cost = self._estimate_cost(
+                model_name,
+                self._count_tokens_estimate(prompt),
+                self._count_tokens_estimate(content),
+                self.token_costs
+            )
+
+            return LLMResponse(
+                content=content,
+                provider=self.provider_name,
+                model=model_name,
+                tokens_used=tokens_used,
+                cost_estimate=cost
+            )
+
+        except Exception as e:
+            if "quota" in str(e).lower() or "limit" in str(e).lower():
+                raise RateLimitError(f"Vertex AI quota exceeded: {str(e)}")
+            # Handle specific Vertex AI response errors
+            if "cannot get the response text" in str(e).lower() or "safety filters" in str(e).lower():
+                self.logger.warning(f"Vertex AI response issue: {str(e)}")
+                # Return a response indicating the issue
+                return LLMResponse(
+                    content="[Response unavailable due to safety filters or content policy]",
+                    provider=self.provider_name,
+                    model=model_name,
+                    tokens_used=self._count_tokens_estimate(prompt),
+                    cost_estimate=0.0
+                )
+            raise
+
+    async def stream_text(
+        self,
+        messages: List[LLMMessage],
+        model: Optional[str] = None,
+        temperature: float = 0.7,
+        max_tokens: Optional[int] = None,
+        **kwargs
+    ) -> AsyncGenerator[str, None]:
+        """Stream text using Vertex AI (simulated streaming)"""
+        # Vertex AI streaming is more complex, for now fall back to non-streaming
+        response = await self.generate_text(messages, model, temperature, max_tokens, **kwargs)
+
+        # Simulate streaming by yielding words
+        words = response.content.split()
+        for word in words:
+            yield word + " "
+            await asyncio.sleep(0.05)  # Small delay to simulate streaming
+
+    async def close(self):
+        """Clean up resources"""
+        # Vertex AI doesn't require explicit cleanup
+        self._initialized = False
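A similar sketch for the Vertex client, assuming vertex_project_id and Google Cloud credentials are configured in the package settings and that LLMMessage accepts role/content keyword arguments as above. Note that stream_text here generates the full response first and only simulates streaming by yielding words.

import asyncio
from aiecs.llm.vertex_client import VertexAIClient
from aiecs.llm.base_client import LLMMessage

async def demo():
    client = VertexAIClient()
    try:
        # Simulated streaming: the full Gemini response is produced first,
        # then yielded word by word with a small delay.
        async for piece in client.stream_text(
            messages=[LLMMessage(role="user", content="Name three uses of Vertex AI.")],
            model="gemini-2.5-flash",  # any key present in client.token_costs
            max_tokens=512,
        ):
            print(piece, end="", flush=True)
    finally:
        await client.close()

asyncio.run(demo())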