ziya 0.1.49__py3-none-any.whl → 0.1.50__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of ziya might be problematic.
- app/agents/.agent.py.swp +0 -0
- app/agents/agent.py +315 -113
- app/agents/models.py +439 -0
- app/agents/prompts.py +32 -4
- app/main.py +70 -7
- app/server.py +403 -14
- app/utils/code_util.py +641 -215
- pyproject.toml +2 -3
- templates/asset-manifest.json +18 -20
- templates/index.html +1 -1
- templates/static/css/{main.87f30840.css → main.2bddf34e.css} +2 -2
- templates/static/css/main.2bddf34e.css.map +1 -0
- templates/static/js/46907.90c6a4f3.chunk.js +2 -0
- templates/static/js/46907.90c6a4f3.chunk.js.map +1 -0
- templates/static/js/56122.1d6a5c10.chunk.js +3 -0
- templates/static/js/56122.1d6a5c10.chunk.js.LICENSE.txt +9 -0
- templates/static/js/56122.1d6a5c10.chunk.js.map +1 -0
- templates/static/js/83953.61a908f4.chunk.js +3 -0
- templates/static/js/83953.61a908f4.chunk.js.map +1 -0
- templates/static/js/88261.1e90079d.chunk.js +3 -0
- templates/static/js/88261.1e90079d.chunk.js.map +1 -0
- templates/static/js/{96603.863a8f96.chunk.js → 96603.18c5d644.chunk.js} +2 -2
- templates/static/js/{96603.863a8f96.chunk.js.map → 96603.18c5d644.chunk.js.map} +1 -1
- templates/static/js/{97902.75670155.chunk.js → 97902.d1e262d6.chunk.js} +3 -3
- templates/static/js/{97902.75670155.chunk.js.map → 97902.d1e262d6.chunk.js.map} +1 -1
- templates/static/js/main.9b2b2b57.js +3 -0
- templates/static/js/{main.ee8b3c96.js.LICENSE.txt → main.9b2b2b57.js.LICENSE.txt} +8 -2
- templates/static/js/main.9b2b2b57.js.map +1 -0
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/METADATA +4 -5
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/RECORD +36 -35
- templates/static/css/main.87f30840.css.map +0 -1
- templates/static/js/23416.c33f07ab.chunk.js +0 -3
- templates/static/js/23416.c33f07ab.chunk.js.map +0 -1
- templates/static/js/3799.fedb612f.chunk.js +0 -2
- templates/static/js/3799.fedb612f.chunk.js.map +0 -1
- templates/static/js/46907.4a730107.chunk.js +0 -2
- templates/static/js/46907.4a730107.chunk.js.map +0 -1
- templates/static/js/64754.cf383335.chunk.js +0 -2
- templates/static/js/64754.cf383335.chunk.js.map +0 -1
- templates/static/js/88261.33450351.chunk.js +0 -3
- templates/static/js/88261.33450351.chunk.js.map +0 -1
- templates/static/js/main.ee8b3c96.js +0 -3
- templates/static/js/main.ee8b3c96.js.map +0 -1
- /templates/static/js/{23416.c33f07ab.chunk.js.LICENSE.txt → 83953.61a908f4.chunk.js.LICENSE.txt} +0 -0
- /templates/static/js/{88261.33450351.chunk.js.LICENSE.txt → 88261.1e90079d.chunk.js.LICENSE.txt} +0 -0
- /templates/static/js/{97902.75670155.chunk.js.LICENSE.txt → 97902.d1e262d6.chunk.js.LICENSE.txt} +0 -0
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/LICENSE +0 -0
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/WHEEL +0 -0
- {ziya-0.1.49.dist-info → ziya-0.1.50.dist-info}/entry_points.txt +0 -0
app/agents/.agent.py.swp
ADDED
Binary file
app/agents/agent.py
CHANGED
@@ -5,17 +5,22 @@ from typing import Dict, List, Tuple, Set, Union, Optional, Any
 import json
 import time
 import botocore
+import asyncio
+import tiktoken
 from langchain.agents import AgentExecutor
 from langchain.agents.format_scratchpad import format_xml
 from langchain_aws import ChatBedrock
+from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
+from google.api_core.exceptions import ResourceExhausted
 from langchain_community.document_loaders import TextLoader
 from langchain_core.agents import AgentFinish
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, BaseMessage
 from langchain_core.outputs import Generation
 from langchain_core.runnables import RunnablePassthrough, Runnable
 from pydantic import BaseModel, Field
 
 from app.agents.prompts import conversational_prompt
+from app.agents.models import ModelManager
 
 from app.utils.sanitizer_util import clean_backtick_sequences
 
@@ -24,115 +29,141 @@ from app.utils.print_tree_util import print_file_tree
 from app.utils.file_utils import is_binary_file
 from app.utils.file_state_manager import FileStateManager
 
-import tiktoken
-import anthropic
-from anthropic import Anthropic
-
 
 def clean_chat_history(chat_history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
     """Clean chat history by removing invalid messages and normalizing content."""
-
-
-
-
-
-
-
-
+    if not chat_history or not isinstance(chat_history, list):
+        return []
+    try:
+        cleaned = []
+        for human, ai in chat_history:
+            # Skip pairs with empty messages
+            if not isinstance(human, str) or not isinstance(ai, str):
+                logger.warning(f"Skipping invalid message pair: human='{human}', ai='{ai}'")
+                continue
+            human_clean = human.strip() if human else ""
+            ai_clean = ai.strip() if ai else ""
+            if not human_clean or not ai_clean:
+                logger.warning(f"Skipping empty message pair")
+                continue
+            cleaned.append((human.strip(), ai.strip()))
+        return cleaned
+    except Exception as e:
+        logger.error(f"Error cleaning chat history: {str(e)}")
+        logger.error(f"Raw chat history: {chat_history}")
+        return cleaned
 
 def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List[Union[HumanMessage, AIMessage]]:
-    logger.info(f"
+    logger.info(f"Chat history type: {type(chat_history)}")
     cleaned_history = clean_chat_history(chat_history)
     buffer = []
-
-
-
+    logger.debug("Message format before conversion:")
+    try:
+        for human, ai in cleaned_history:
+            if human and isinstance(human, str):
+                logger.debug(f"Human message type: {type(human)}, content: {human[:100]}")
+
+                try:
+                    buffer.append(HumanMessage(content=str(human)))
+                except Exception as e:
+                    logger.error(f"Error creating HumanMessage: {str(e)}")
+            if ai and isinstance(ai, str):
+                logger.debug(f"AI message type: {type(ai)}, content: {ai[:100]}")
+                try:
+                    buffer.append(AIMessage(content=str(ai)))
+                except Exception as e:
+                    logger.error(f"Error creating AIMessage: {str(e)}")
+
+    except Exception as e:
+        logger.error(f"Error formatting chat history: {str(e)}")
+        logger.error(f"Problematic chat history: {chat_history}")
+        return []
+
+
+    logger.debug(f"Final formatted messages: {[type(m).__name__ for m in buffer]}")
     return buffer
 
 def parse_output(message):
     """Parse and sanitize the output from the language model."""
     try:
-        #
-
-
+        # Get the content based on the object type
+        content = None
+        if hasattr(message, 'text'):
+            content = message.text
+        elif hasattr(message, 'content'):
+            content = message.content
         else:
-            content =
-
-        if
-            #
+            content = str(message)
+    finally:
+        if content:
+            # If content is a method (from Gemini), get the actual content
+            if callable(content):
+                try:
+                    content = content()
+                except Exception as e:
+                    logger.error(f"Error calling content method: {e}")
             try:
+                # Check if this is an error message
                 error_data = json.loads(content)
                 if error_data.get('error') == 'validation_error':
                     logger.info(f"Detected validation error in output: {content}")
                     return AgentFinish(return_values={"output": content}, log=content)
-            except
+            except json.JSONDecodeError:
                 pass
-
-
-
-
-
-
-    return AgentFinish(return_values={"output": str(message)}, log=str(message))
-
-aws_profile = os.environ.get("ZIYA_AWS_PROFILE")
-if aws_profile:
-    logger.info(f"Using AWS Profile: {aws_profile}")
-else:
-    logger.info("No AWS profile specified via --aws-profile flag, using default credentials")
-model_id = {
-    "sonnet3.7": "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
-    "sonnet3.5": "us.anthropic.claude-3-5-sonnet-20240620-v1:0",
-    "sonnet3.5-v2": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
-    "opus": "us.anthropic.claude-3-opus-20240229-v1:0",
-    "sonnet": "us.anthropic.claude-3-sonnet-20240229-v1:0",
-    "haiku": "us.anthropic.claude-3-haiku-20240307-v1:0",
-}[os.environ.get("ZIYA_AWS_MODEL", "sonnet3.5-v2")]
-logger.info(f"Using Claude Model: {model_id}")
-
-model = ChatBedrock(
-    model_id=model_id,
-    model_kwargs={"max_tokens": 4096, "temperature": 0.3, "top_k": 15},
-    credentials_profile_name=aws_profile if aws_profile else None,
-    config=botocore.config.Config(
-        read_timeout=900,
-        retries={
-            'max_attempts': 3,
-            'total_max_attempts': 5
-        }
-        # retry_mode is not supported in this version
-    )
-)
+            # If not an error, clean and return the content
+            text = clean_backtick_sequences(content)
+            # Log using the same content we extracted
+            logger.info(f"parse_output received content size: {len(content)} chars, returning size: {len(text)} chars")
+            return AgentFinish(return_values={"output": text}, log=text)
+        return AgentFinish(return_values={"output": ""}, log="")
 
 # Create a wrapper class that adds retries
 class RetryingChatBedrock(Runnable):
     def __init__(self, model):
         self.model = model
+        self.provider = os.environ.get("ZIYA_ENDPOINT", "bedrock")
+
+    def _debug_input(self, input: Any):
+        """Debug log input structure"""
+        logger.info(f"Input type: {type(input)}")
+        if hasattr(input, 'to_messages'):
+            logger.info("ChatPromptValue detected, messages:")
+            messages = input.to_messages()
+            for i, msg in enumerate(messages):
+                logger.info(f"Message {i}:")
+                logger.info(f"  Type: {type(msg)}")
+                logger.info(f"  Content type: {type(msg.content)}")
+                logger.info(f"  Content: {msg.content}")
+        elif isinstance(input, dict):
+            logger.info(f"Input keys: {input.keys()}")
+            if 'messages' in input:
+                logger.info("Messages content:")
+                for i, msg in enumerate(input['messages']):
+                    logger.info(f"Message {i}: type={type(msg)}, content={msg}")
+        else:
+            logger.info(f"Raw input: {input}")
 
     def bind(self, **kwargs):
         return RetryingChatBedrock(self.model.bind(**kwargs))
 
+
     def get_num_tokens(self, text: str) -> int:
-
-        Custom token counting function using anthropic's tokenizer.
-        Falls back to tiktoken if anthropic is not available.
-        """
-        try:
-            client = Anthropic()
-            count = client.count_tokens(text)
-            return count
-        except Exception as e:
-            logger.warning(f"Failed to use anthropic tokenizer: {str(e)}")
-            try:
-                return len(tiktoken.get_encoding("cl100k_base").encode(text))
-            except Exception as e:
-                logger.error(f"Failed to count tokens: {str(e)}")
-                return len(text.split())  # Rough approximation
+        return self.model.get_num_tokens(text)
 
     def __getattr__(self, name: str):
         # Delegate any unknown attributes to the underlying model
         return getattr(self.model, name)
 
+    def _get_provider_format(self) -> str:
+        """Get the message format requirements for current provider."""
+        # Can be extended for other providers
+        return self.provider
+
+    def _convert_to_messages(self, input_value: Any) -> Union[str, List[Dict[str, str]]]:
+        """Convert input to messages format expected by provider."""
+        if isinstance(input_value, (str, list)):
+            return input_value
+
     async def _handle_stream_error(self, e: Exception):
         """Handle stream errors by yielding an error message."""
         yield Generation(
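
A note on the reworked parse_output above: it probes the provider's message object for a .text attribute first, falls back to .content, and treats output that parses as a JSON object with error == 'validation_error' as an error payload rather than model text. A minimal standalone sketch of that pattern, with a hypothetical Stub class standing in for a provider message object:

import json

def extract_content(message):
    # Same attribute probing as parse_output: prefer .text, fall back
    # to .content, else stringify the whole object.
    if hasattr(message, 'text'):
        return message.text
    if hasattr(message, 'content'):
        return message.content
    return str(message)

def classify(content):
    # A validation-error payload is a JSON object whose 'error' field is
    # 'validation_error'; anything that fails to parse is ordinary output.
    try:
        if json.loads(content).get('error') == 'validation_error':
            return 'error'
    except (json.JSONDecodeError, AttributeError):
        pass
    return 'output'

class Stub:  # hypothetical provider message object, not part of ziya
    content = '{"error": "validation_error", "detail": "bad input"}'

print(classify(extract_content(Stub())))          # -> error
print(classify(extract_content('plain answer')))  # -> output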
@@ -143,6 +174,24 @@ class RetryingChatBedrock(Runnable):
         )
         return
 
+    def _prepare_input(self, input: Any) -> Dict:
+        """Convert input to format expected by Bedrock."""
+        logger.info("Preparing input for Bedrock")
+
+        if hasattr(input, 'to_messages'):
+            # Handle ChatPromptValue
+            messages = input.to_messages()
+            logger.debug(f"Model type: {type(self.model)}")
+            logger.debug(f"Original messages: {messages}")
+
+            # Filter out empty messages but keep the original message types
+            filtered_messages = [
+                msg for msg in messages
+                if self._format_message_content(msg)
+            ]
+
+            return filtered_messages
+
     async def _handle_validation_error(self, e: Exception):
         """Handle validation errors by yielding an error message."""
         error_chunk = Generation(
@@ -153,42 +202,189 @@ class RetryingChatBedrock(Runnable):
         """Check if this is a streaming operation."""
         return hasattr(func, '__name__') and func.__name__ == 'astream'
 
-    def
-
+    def _format_message_content(self, message: Any) -> str:
+        """Ensure message content is properly formatted as a string."""
 
-
-
+        logger.info(f"Formatting message: type={type(message)}")
+        if isinstance(message, dict):
+            logger.info(f"Dict message keys: {message.keys()}")
+            if 'content' in message:
+                logger.info(f"Content type: {type(message['content'])}")
+                logger.info(f"Content value: {message['content']}")
+        try:
+            # Handle different message formats
+            if isinstance(message, dict):
+                content = message.get('content', '')
+            elif hasattr(message, 'content'):
+                content = message.content
+            else:
+                content = str(message)
+            # Ensure content is a string
+            if not isinstance(content, str):
+                if content is None:
+                    return ""
+                content = str(content)
+
+            return content.strip()
+        except Exception as e:
+            logger.error(f"Error formatting message content: {str(e)}")
+            return ""
+
+    def _prepare_messages_for_provider(self, input: Any) -> List[Dict[str, str]]:
+        formatted_messages = []
+
+        # Convert input to messages list
+        if hasattr(input, 'to_messages'):
+            messages = list(input.to_messages())
+            logger.debug(f"Converting ChatPromptValue to messages: {len(messages)} messages")
+        elif isinstance(input, (list, tuple)):
+            messages = list(input)
+        else:
+            messages = [input]
+
+        # Process messages in order
+        logger.debug(f"Processing {len(messages)} messages")
+        for msg in messages:
+            # Extract role and content
+            if isinstance(msg, (SystemMessage, HumanMessage, AIMessage)):
+                if isinstance(msg, SystemMessage):
+                    role = 'system'
+                elif isinstance(msg, HumanMessage):
+                    role = 'user'
+                else:
+                    role = 'assistant'
+                content = msg.content
+            elif isinstance(msg, dict) and 'content' in msg:
+                role = msg.get('role', 'user')
+                content = msg['content']
+            else:
+                role = 'user'
+                content = str(msg)
+
+            logger.debug(f"Message type: {type(msg)}, role: {role}, content type: {type(content)}")
+
+            # Skip empty assistant messages
+            if role == 'assistant' and not content:
+                continue
+
+            # Ensure content is a non-empty string
+            content = str(content).strip()
+            if not content:
+                continue
 
-
+            formatted_messages.append({
+                'role': role,
+                'content': content
+            })
+
+        return formatted_messages
+
+
+    @property
+    def _is_chat_model(self):
+        return isinstance(self.model, ChatBedrock)
+
+    async def astream(self, input: Any, config: Optional[Dict] = None, **kwargs):
+        """Stream responses with retries and proper message formatting."""
         max_retries = 3
         retry_delay = 1
 
         for attempt in range(max_retries):
+            logger.info(f"Attempt {attempt + 1} of {max_retries}")
             try:
-
-
-
-
-
-
+
+
+                # Convert input to messages if needed
+                if hasattr(input, 'to_messages'):
+                    messages = input.to_messages()
+                    logger.debug(f"Using messages from ChatPromptValue: {len(messages)} messages")
                 else:
-
-
-
-
-
-
-
-
-
-
-
-                    })
+                    messages = input
+                    logger.debug(f"Using input directly: {type(input)}")
+
+                # Filter out empty messages
+                if isinstance(messages, list):
+                    messages = [
+                        msg for msg in messages
+                        if isinstance(msg, BaseMessage) and msg.content
+                    ]
+                    if not messages:
+                        raise ValueError("No valid messages with content")
+                    logger.debug(f"Filtered to {len(messages)} non-empty messages")
+
+                async for chunk in self.model.astream(messages, config, **kwargs):
+                    yield chunk
+
+                break  # Success, exit retry loop
+
+
+            except ResourceExhausted as e:
+                logger.error(f"Google API quota exceeded: {str(e)}")
+                yield Generation(
+                    text=json.dumps({
+                        "error": "quota_exceeded",
+                        "detail": "API quota has been exceeded. Please try again in a few minutes."
+                    })
+                )
+                return
+
+            except Exception as e:
+                logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
+                if attempt == max_retries - 1:
+                    yield Generation(text=json.dumps({"error": "stream_error", "detail": str(e)}))
                     return
-
+                await asyncio.sleep(retry_delay * (attempt + 1))
 
+    def _format_messages(self, input_messages: List[Any]) -> List[Dict[str, str]]:
+        """Format messages according to provider requirements."""
+        provider = self._get_provider_format()
+        formatted = []
 
-
+        try:
+            for msg in input_messages:
+                if isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
+                    # Convert LangChain messages based on provider
+                    if provider == "bedrock":
+                        role = "user" if isinstance(msg, HumanMessage) else \
+                               "assistant" if isinstance(msg, AIMessage) else \
+                               "system"
+                    else:
+                        # Default/fallback format
+                        role = msg.__class__.__name__.lower().replace('message', '')
+
+                    content = self._format_message_content(msg)
+                elif isinstance(msg, dict) and "role" in msg and "content" in msg:
+                    # Already in provider format
+                    role = msg["role"]
+                    content = self._format_message_content(msg["content"])
+                else:
+                    logger.warning(f"Unknown message format: {type(msg)}")
+                    role = "user"  # Default to user role
+                    content = self._format_message_content(msg)
+
+                formatted.append({"role": role, "content": content})
+        except Exception as e:
+            logger.error(f"Error formatting messages: {str(e)}")
+            raise
+    def _validate_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
+        """Remove any messages with empty content."""
+        return [msg for msg in messages if msg.get('content')]
+
+    def invoke(self, input: Any, config: Optional[Dict] = None, **kwargs) -> Any:
+        try:
+            if isinstance(input, dict) and "messages" in input:
+                messages = self._convert_to_messages(input["messages"])
+                input = {**input, "messages": messages}
+            return self.model.invoke(input, config, **kwargs)
+        except Exception as e:
+            self.logger.error(f"Error in invoke: {str(e)}")
+            raise
+
+    async def ainvoke(self, input: Any, config: Optional[Dict] = None, **kwargs) -> Any:
+        return await self.model.ainvoke(input, config, **kwargs)
+
+# Initialize the model using the ModelManager
+model = RetryingChatBedrock(ModelManager.initialize_model())
 
 file_state_manager = FileStateManager()
 
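
The astream rewrite above is the heart of this hunk: a bounded retry loop with linear backoff (retry_delay * (attempt + 1)) that converts terminal failures into a JSON error chunk instead of raising into the stream consumer. A minimal runnable sketch of that control flow, with a hypothetical flaky_stream in place of the real model (stream_with_retries is illustrative, not a ziya function):

import asyncio
import json

async def stream_with_retries(stream_factory, max_retries=3, retry_delay=1):
    # Mirror of the astream control flow: bounded retries, linear backoff,
    # and a JSON error chunk instead of an exception when all attempts fail.
    for attempt in range(max_retries):
        try:
            async for chunk in stream_factory():
                yield chunk
            return  # success: stop retrying
        except Exception as e:
            if attempt == max_retries - 1:
                yield json.dumps({"error": "stream_error", "detail": str(e)})
                return
            await asyncio.sleep(retry_delay * (attempt + 1))

async def flaky_stream(_state={"calls": 0}):
    # Hypothetical source that fails once, then streams normally.
    _state["calls"] += 1
    if _state["calls"] == 1:
        raise RuntimeError("transient failure")
    for token in ("hello", " world"):
        yield token

async def main():
    async for chunk in stream_with_retries(flaky_stream):
        print(chunk, end="")
    print()

asyncio.run(main())

As in the loop above, a failure after some chunks have already been yielded replays the stream from the start on the next attempt, so consumers may see duplicated chunks.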
@@ -314,7 +510,13 @@ def extract_codebase(x):
     return codebase
 
 def log_output(x):
-
+    """Log output in a consistent format."""
+    try:
+        output = x.content if hasattr(x, 'content') else str(x)
+        logger.info(f"Final output size: {len(output)} chars, first 100 chars: {output[:100]}")
+    except Exception as e:
+        logger.error(f"Error in log_output: {str(e)}")
+        output = str(x)
     return x
 
 def log_codebase_wrapper(x):
@@ -322,7 +524,8 @@ def log_codebase_wrapper(x):
     logger.info(f"Codebase before prompt: {len(codebase)} chars")
     file_count = len([l for l in codebase.split('\n') if l.startswith('File: ')])
     logger.info(f"Number of files in codebase before prompt: {file_count}")
-
+    file_lines = [l for l in codebase.split('\n') if l.startswith('File: ')]
+    logger.info("Files in codebase before prompt:\n" + "\n".join(file_lines))
     return codebase
 
 # Define the agent chain
@@ -330,17 +533,17 @@ agent = (
     {
         "codebase": log_codebase_wrapper,
        "question": lambda x: x["question"],
-        "
-        "
+        "chat_history": lambda x: _format_chat_history(x.get("chat_history", [])),
+        "agent_scratchpad": lambda x: [
+            AIMessage(content=format_xml([]))
+        ],
     }
     | conversational_prompt
-    | (lambda x: (
-        logger.info(f"Template population check:") or
-        logger.info(f"System message contains codebase section: {'---------------------------------------' in str(x)}") or
-        logger.info(f"Number of 'File:' markers in system message: {str(x).count('File:')}") or
-        x))
     | llm_with_stop
-    |
+    | (lambda x: AgentFinish(
+        return_values={"output": x.content if hasattr(x, 'content') else str(x)},
+        log=""
+    ))
     | log_output
 )
 
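
In the chain above, chat_history arrives as (human, ai) string pairs and _format_chat_history turns them into an alternating message list before the prompt is populated. A self-contained sketch of that conversion, using stub message classes so it runs without LangChain installed:

# Stub message classes standing in for langchain_core.messages.
class HumanMessage:
    def __init__(self, content): self.content = content

class AIMessage:
    def __init__(self, content): self.content = content

def format_chat_history(chat_history):
    # (human, ai) string pairs -> alternating message list; pairs with a
    # missing or empty side are dropped, as in clean_chat_history above.
    buffer = []
    for human, ai in chat_history or []:
        if not (isinstance(human, str) and isinstance(ai, str)):
            continue
        if not human.strip() or not ai.strip():
            continue
        buffer.append(HumanMessage(human.strip()))
        buffer.append(AIMessage(ai.strip()))
    return buffer

history = [
    ("What does agent.py do?", "It wires the prompt into the model."),
    ("", "dropped because the human side is empty"),
]
print([type(m).__name__ for m in format_chat_history(history)])
# -> ['HumanMessage', 'AIMessage']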
@@ -380,9 +583,8 @@ def update_and_return(input_data: Dict[str, Any]) -> Dict[str, Any]:
 agent_executor = AgentExecutor(
     agent=agent,
     tools=[],
-    verbose=
+    verbose=False,
     handle_parsing_errors=True,
-    max_iterations=3
 ).with_types(input_type=AgentInput) | RunnablePassthrough(update_and_return)
 
 # Chain the executor with the state update
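
A note on the composition at the end: piping the executor into RunnablePassthrough(update_and_return) runs the function for its side effects while forwarding the executor's output unchanged. A small sketch of that LCEL pattern, with a RunnableLambda standing in for the real AgentExecutor (fake_executor is illustrative only):

from langchain_core.runnables import RunnableLambda, RunnablePassthrough

def update_and_return(output):
    # Side effect only (the real function updates file state);
    # RunnablePassthrough forwards the input unchanged.
    print(f"observed: {output}")

# A stand-in runnable in place of the real AgentExecutor.
fake_executor = RunnableLambda(lambda x: {"output": x["question"].upper()})

chain = fake_executor | RunnablePassthrough(update_and_return)
print(chain.invoke({"question": "hello"}))
# prints 'observed: ...' and then {'output': 'HELLO'}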
|